aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--callback.go52
-rw-r--r--sqlite3.go5
2 files changed, 52 insertions, 5 deletions
diff --git a/callback.go b/callback.go
index ee9d40c..e2bf3c6 100644
--- a/callback.go
+++ b/callback.go
@@ -24,29 +24,75 @@ import (
"fmt"
"math"
"reflect"
+ "sync"
"unsafe"
)
//export callbackTrampoline
func callbackTrampoline(ctx *C.sqlite3_context, argc int, argv **C.sqlite3_value) {
args := (*[(math.MaxInt32 - 1) / unsafe.Sizeof((*C.sqlite3_value)(nil))]*C.sqlite3_value)(unsafe.Pointer(argv))[:argc:argc]
- fi := (*functionInfo)(unsafe.Pointer(C.sqlite3_user_data(ctx)))
+ fi := lookupHandle(uintptr(C.sqlite3_user_data(ctx))).(*functionInfo)
fi.Call(ctx, args)
}
//export stepTrampoline
func stepTrampoline(ctx *C.sqlite3_context, argc int, argv **C.sqlite3_value) {
args := (*[(math.MaxInt32 - 1) / unsafe.Sizeof((*C.sqlite3_value)(nil))]*C.sqlite3_value)(unsafe.Pointer(argv))[:argc:argc]
- ai := (*aggInfo)(unsafe.Pointer(C.sqlite3_user_data(ctx)))
+ ai := lookupHandle(uintptr(C.sqlite3_user_data(ctx))).(*aggInfo)
ai.Step(ctx, args)
}
//export doneTrampoline
func doneTrampoline(ctx *C.sqlite3_context) {
- ai := (*aggInfo)(unsafe.Pointer(C.sqlite3_user_data(ctx)))
+ handle := uintptr(C.sqlite3_user_data(ctx))
+ ai := lookupHandle(handle).(*aggInfo)
ai.Done(ctx)
}
+// Use handles to avoid passing Go pointers to C.
+
+type handleVal struct {
+ db *SQLiteConn
+ val interface{}
+}
+
+var handleLock sync.Mutex
+var handleVals = make(map[uintptr]handleVal)
+var handleIndex uintptr = 100
+
+func newHandle(db *SQLiteConn, v interface{}) uintptr {
+ handleLock.Lock()
+ defer handleLock.Unlock()
+ i := handleIndex
+ handleIndex++
+ handleVals[i] = handleVal{db, v}
+ return i
+}
+
+func lookupHandle(handle uintptr) interface{} {
+ handleLock.Lock()
+ defer handleLock.Unlock()
+ r, ok := handleVals[handle]
+ if !ok {
+ if handle >= 100 && handle < handleIndex {
+ panic("deleted handle")
+ } else {
+ panic("invalid handle")
+ }
+ }
+ return r.val
+}
+
+func deleteHandles(db *SQLiteConn) {
+ handleLock.Lock()
+ defer handleLock.Unlock()
+ for handle, val := range handleVals {
+ if val.db == db {
+ delete(handleVals, handle)
+ }
+ }
+}
+
// This is only here so that tests can refer to it.
type callbackArgRaw C.sqlite3_value
diff --git a/sqlite3.go b/sqlite3.go
index eae71ab..0a6f136 100644
--- a/sqlite3.go
+++ b/sqlite3.go
@@ -367,7 +367,7 @@ func (c *SQLiteConn) RegisterFunc(name string, impl interface{}, pure bool) erro
if pure {
opts |= C.SQLITE_DETERMINISTIC
}
- rv := C._sqlite3_create_function(c.db, cname, C.int(numArgs), C.int(opts), C.uintptr_t(uintptr(unsafe.Pointer(&fi))), (*[0]byte)(unsafe.Pointer(C.callbackTrampoline)), nil, nil)
+ rv := C._sqlite3_create_function(c.db, cname, C.int(numArgs), C.int(opts), C.uintptr_t(newHandle(c, &fi)), (*[0]byte)(unsafe.Pointer(C.callbackTrampoline)), nil, nil)
if rv != C.SQLITE_OK {
return c.lastError()
}
@@ -492,7 +492,7 @@ func (c *SQLiteConn) RegisterAggregator(name string, impl interface{}, pure bool
if pure {
opts |= C.SQLITE_DETERMINISTIC
}
- rv := C._sqlite3_create_function(c.db, cname, C.int(stepNArgs), C.int(opts), C.uintptr_t(uintptr(unsafe.Pointer(&ai))), nil, (*[0]byte)(unsafe.Pointer(C.stepTrampoline)), (*[0]byte)(unsafe.Pointer(C.doneTrampoline)))
+ rv := C._sqlite3_create_function(c.db, cname, C.int(stepNArgs), C.int(opts), C.uintptr_t(newHandle(c, &ai)), nil, (*[0]byte)(unsafe.Pointer(C.stepTrampoline)), (*[0]byte)(unsafe.Pointer(C.doneTrampoline)))
if rv != C.SQLITE_OK {
return c.lastError()
}
@@ -705,6 +705,7 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
// Close the connection.
func (c *SQLiteConn) Close() error {
+ deleteHandles(c)
rv := C.sqlite3_close_v2(c.db)
if rv != C.SQLITE_OK {
return c.lastError()