aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDavid Anderson <dave@natulte.net>2015-08-21 13:38:22 -0700
committerDavid Anderson <dave@natulte.net>2015-08-21 16:37:45 -0700
commit122ddb16de825ed3d989d25d4d7b2d2e278abdf6 (patch)
tree43d4cad7936f67a709c32b80439fef32805ddebf
parentImplement support for passing Go functions as custom functions to SQLite. (diff)
downloadgolite-122ddb16de825ed3d989d25d4d7b2d2e278abdf6.tar.gz
golite-122ddb16de825ed3d989d25d4d7b2d2e278abdf6.tar.xz
Move argument converters to callback.go, and optimize return value handling.
A call now doesn't have to do any reflection, it just blindly invokes a bunch of argument and return value handlers to execute the translation, and the safety of the translation is determined at registration time.
-rw-r--r--callback.go200
-rw-r--r--callback_test.go97
-rw-r--r--sqlite3.go122
-rw-r--r--sqlite3_test.go102
4 files changed, 367 insertions, 154 deletions
diff --git a/callback.go b/callback.go
index 938d7fe..1692106 100644
--- a/callback.go
+++ b/callback.go
@@ -5,12 +5,25 @@
package sqlite3
+// You can't export a Go function to C and have definitions in the C
+// preamble in the same file, so we have to have callbackTrampoline in
+// its own file. Because we need a separate file anyway, the support
+// code for SQLite custom functions is in here.
+
/*
#include <sqlite3-binding.h>
+
+void _sqlite3_result_text(sqlite3_context* ctx, const char* s);
+void _sqlite3_result_blob(sqlite3_context* ctx, const void* b, int l);
*/
import "C"
-import "unsafe"
+import (
+ "errors"
+ "fmt"
+ "reflect"
+ "unsafe"
+)
//export callbackTrampoline
func callbackTrampoline(ctx *C.sqlite3_context, argc int, argv **C.sqlite3_value) {
@@ -18,3 +31,188 @@ func callbackTrampoline(ctx *C.sqlite3_context, argc int, argv **C.sqlite3_value
fi := (*functionInfo)(unsafe.Pointer(C.sqlite3_user_data(ctx)))
fi.Call(ctx, args)
}
+
+// This is only here so that tests can refer to it.
+type callbackArgRaw C.sqlite3_value
+
+type callbackArgConverter func(*C.sqlite3_value) (reflect.Value, error)
+
+type callbackArgCast struct {
+ f callbackArgConverter
+ typ reflect.Type
+}
+
+func (c callbackArgCast) Run(v *C.sqlite3_value) (reflect.Value, error) {
+ val, err := c.f(v)
+ if err != nil {
+ return reflect.Value{}, err
+ }
+ if !val.Type().ConvertibleTo(c.typ) {
+ return reflect.Value{}, fmt.Errorf("cannot convert %s to %s", val.Type(), c.typ)
+ }
+ return val.Convert(c.typ), nil
+}
+
+func callbackArgInt64(v *C.sqlite3_value) (reflect.Value, error) {
+ if C.sqlite3_value_type(v) != C.SQLITE_INTEGER {
+ return reflect.Value{}, fmt.Errorf("argument must be an INTEGER")
+ }
+ return reflect.ValueOf(int64(C.sqlite3_value_int64(v))), nil
+}
+
+func callbackArgBool(v *C.sqlite3_value) (reflect.Value, error) {
+ if C.sqlite3_value_type(v) != C.SQLITE_INTEGER {
+ return reflect.Value{}, fmt.Errorf("argument must be an INTEGER")
+ }
+ i := int64(C.sqlite3_value_int64(v))
+ val := false
+ if i != 0 {
+ val = true
+ }
+ return reflect.ValueOf(val), nil
+}
+
+func callbackArgFloat64(v *C.sqlite3_value) (reflect.Value, error) {
+ if C.sqlite3_value_type(v) != C.SQLITE_FLOAT {
+ return reflect.Value{}, fmt.Errorf("argument must be a FLOAT")
+ }
+ return reflect.ValueOf(float64(C.sqlite3_value_double(v))), nil
+}
+
+func callbackArgBytes(v *C.sqlite3_value) (reflect.Value, error) {
+ switch C.sqlite3_value_type(v) {
+ case C.SQLITE_BLOB:
+ l := C.sqlite3_value_bytes(v)
+ p := C.sqlite3_value_blob(v)
+ return reflect.ValueOf(C.GoBytes(p, l)), nil
+ case C.SQLITE_TEXT:
+ l := C.sqlite3_value_bytes(v)
+ c := unsafe.Pointer(C.sqlite3_value_text(v))
+ return reflect.ValueOf(C.GoBytes(c, l)), nil
+ default:
+ return reflect.Value{}, fmt.Errorf("argument must be BLOB or TEXT")
+ }
+}
+
+func callbackArgString(v *C.sqlite3_value) (reflect.Value, error) {
+ switch C.sqlite3_value_type(v) {
+ case C.SQLITE_BLOB:
+ l := C.sqlite3_value_bytes(v)
+ p := (*C.char)(C.sqlite3_value_blob(v))
+ return reflect.ValueOf(C.GoStringN(p, l)), nil
+ case C.SQLITE_TEXT:
+ c := (*C.char)(unsafe.Pointer(C.sqlite3_value_text(v)))
+ return reflect.ValueOf(C.GoString(c)), nil
+ default:
+ return reflect.Value{}, fmt.Errorf("argument must be BLOB or TEXT")
+ }
+}
+
+func callbackArg(typ reflect.Type) (callbackArgConverter, error) {
+ switch typ.Kind() {
+ case reflect.Slice:
+ if typ.Elem().Kind() != reflect.Uint8 {
+ return nil, errors.New("the only supported slice type is []byte")
+ }
+ return callbackArgBytes, nil
+ case reflect.String:
+ return callbackArgString, nil
+ case reflect.Bool:
+ return callbackArgBool, nil
+ case reflect.Int64:
+ return callbackArgInt64, nil
+ case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Int, reflect.Uint:
+ c := callbackArgCast{callbackArgInt64, typ}
+ return c.Run, nil
+ case reflect.Float64:
+ return callbackArgFloat64, nil
+ case reflect.Float32:
+ c := callbackArgCast{callbackArgFloat64, typ}
+ return c.Run, nil
+ default:
+ return nil, fmt.Errorf("don't know how to convert to %s", typ)
+ }
+}
+
+type callbackRetConverter func(*C.sqlite3_context, reflect.Value) error
+
+func callbackRetInteger(ctx *C.sqlite3_context, v reflect.Value) error {
+ switch v.Type().Kind() {
+ case reflect.Int64:
+ case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Int, reflect.Uint:
+ v = v.Convert(reflect.TypeOf(int64(0)))
+ case reflect.Bool:
+ b := v.Interface().(bool)
+ if b {
+ v = reflect.ValueOf(int64(1))
+ } else {
+ v = reflect.ValueOf(int64(0))
+ }
+ default:
+ return fmt.Errorf("cannot convert %s to INTEGER", v.Type())
+ }
+
+ C.sqlite3_result_int64(ctx, C.sqlite3_int64(v.Interface().(int64)))
+ return nil
+}
+
+func callbackRetFloat(ctx *C.sqlite3_context, v reflect.Value) error {
+ switch v.Type().Kind() {
+ case reflect.Float64:
+ case reflect.Float32:
+ v = v.Convert(reflect.TypeOf(float64(0)))
+ default:
+ return fmt.Errorf("cannot convert %s to FLOAT", v.Type())
+ }
+
+ C.sqlite3_result_double(ctx, C.double(v.Interface().(float64)))
+ return nil
+}
+
+func callbackRetBlob(ctx *C.sqlite3_context, v reflect.Value) error {
+ if v.Type().Kind() != reflect.Slice || v.Type().Elem().Kind() != reflect.Uint8 {
+ return fmt.Errorf("cannot convert %s to BLOB", v.Type())
+ }
+ i := v.Interface()
+ if i == nil || len(i.([]byte)) == 0 {
+ C.sqlite3_result_null(ctx)
+ } else {
+ bs := i.([]byte)
+ C._sqlite3_result_blob(ctx, unsafe.Pointer(&bs[0]), C.int(len(bs)))
+ }
+ return nil
+}
+
+func callbackRetText(ctx *C.sqlite3_context, v reflect.Value) error {
+ if v.Type().Kind() != reflect.String {
+ return fmt.Errorf("cannot convert %s to TEXT", v.Type())
+ }
+ C._sqlite3_result_text(ctx, C.CString(v.Interface().(string)))
+ return nil
+}
+
+func callbackRet(typ reflect.Type) (callbackRetConverter, error) {
+ switch typ.Kind() {
+ case reflect.Slice:
+ if typ.Elem().Kind() != reflect.Uint8 {
+ return nil, errors.New("the only supported slice type is []byte")
+ }
+ return callbackRetBlob, nil
+ case reflect.String:
+ return callbackRetText, nil
+ case reflect.Bool, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Int, reflect.Uint:
+ return callbackRetInteger, nil
+ case reflect.Float32, reflect.Float64:
+ return callbackRetFloat, nil
+ default:
+ return nil, fmt.Errorf("don't know how to convert to %s", typ)
+ }
+}
+
+// Test support code. Tests are not allowed to import "C", so we can't
+// declare any functions that use C.sqlite3_value.
+func callbackSyntheticForTests(v reflect.Value, err error) callbackArgConverter {
+ return func(*C.sqlite3_value) (reflect.Value, error) {
+ return v, err
+ }
+}
diff --git a/callback_test.go b/callback_test.go
new file mode 100644
index 0000000..5c61f44
--- /dev/null
+++ b/callback_test.go
@@ -0,0 +1,97 @@
+package sqlite3
+
+import (
+ "errors"
+ "math"
+ "reflect"
+ "testing"
+)
+
+func TestCallbackArgCast(t *testing.T) {
+ intConv := callbackSyntheticForTests(reflect.ValueOf(int64(math.MaxInt64)), nil)
+ floatConv := callbackSyntheticForTests(reflect.ValueOf(float64(math.MaxFloat64)), nil)
+ errConv := callbackSyntheticForTests(reflect.Value{}, errors.New("test"))
+
+ tests := []struct {
+ f callbackArgConverter
+ o reflect.Value
+ }{
+ {intConv, reflect.ValueOf(int8(-1))},
+ {intConv, reflect.ValueOf(int16(-1))},
+ {intConv, reflect.ValueOf(int32(-1))},
+ {intConv, reflect.ValueOf(uint8(math.MaxUint8))},
+ {intConv, reflect.ValueOf(uint16(math.MaxUint16))},
+ {intConv, reflect.ValueOf(uint32(math.MaxUint32))},
+ // Special case, int64->uint64 is only 1<<63 - 1, not 1<<64 - 1
+ {intConv, reflect.ValueOf(uint64(math.MaxInt64))},
+ {floatConv, reflect.ValueOf(float32(math.Inf(1)))},
+ }
+
+ for _, test := range tests {
+ conv := callbackArgCast{test.f, test.o.Type()}
+ val, err := conv.Run(nil)
+ if err != nil {
+ t.Errorf("Couldn't convert to %s: %s", test.o.Type(), err)
+ } else if !reflect.DeepEqual(val.Interface(), test.o.Interface()) {
+ t.Errorf("Unexpected result from converting to %s: got %v, want %v", test.o.Type(), val.Interface(), test.o.Interface())
+ }
+ }
+
+ conv := callbackArgCast{errConv, reflect.TypeOf(int8(0))}
+ _, err := conv.Run(nil)
+ if err == nil {
+ t.Errorf("Expected error during callbackArgCast, but got none")
+ }
+}
+
+func TestCallbackConverters(t *testing.T) {
+ tests := []struct {
+ v interface{}
+ err bool
+ }{
+ // Unfortunately, we can't tell which converter was returned,
+ // but we can at least check which types can be converted.
+ {[]byte{0}, false},
+ {"text", false},
+ {true, false},
+ {int8(0), false},
+ {int16(0), false},
+ {int32(0), false},
+ {int64(0), false},
+ {uint8(0), false},
+ {uint16(0), false},
+ {uint32(0), false},
+ {uint64(0), false},
+ {int(0), false},
+ {uint(0), false},
+ {float64(0), false},
+ {float32(0), false},
+
+ {func() {}, true},
+ {complex64(complex(0, 0)), true},
+ {complex128(complex(0, 0)), true},
+ {struct{}{}, true},
+ {map[string]string{}, true},
+ {[]string{}, true},
+ {(*int8)(nil), true},
+ {make(chan int), true},
+ }
+
+ for _, test := range tests {
+ _, err := callbackArg(reflect.TypeOf(test.v))
+ if test.err && err == nil {
+ t.Errorf("Expected an error when converting %s, got no error", reflect.TypeOf(test.v))
+ } else if !test.err && err != nil {
+ t.Errorf("Expected converter when converting %s, got error: %s", reflect.TypeOf(test.v), err)
+ }
+ }
+
+ for _, test := range tests {
+ _, err := callbackRet(reflect.TypeOf(test.v))
+ if test.err && err == nil {
+ t.Errorf("Expected an error when converting %s, got no error", reflect.TypeOf(test.v))
+ } else if !test.err && err != nil {
+ t.Errorf("Expected converter when converting %s, got error: %s", reflect.TypeOf(test.v), err)
+ }
+ }
+}
diff --git a/sqlite3.go b/sqlite3.go
index f995589..174a3ee 100644
--- a/sqlite3.go
+++ b/sqlite3.go
@@ -166,7 +166,8 @@ type SQLiteRows struct {
type functionInfo struct {
f reflect.Value
- argConverters []func(*C.sqlite3_value) (reflect.Value, error)
+ argConverters []callbackArgConverter
+ retConverter callbackRetConverter
}
func (fi *functionInfo) error(ctx *C.sqlite3_context, err error) {
@@ -193,58 +194,11 @@ func (fi *functionInfo) Call(ctx *C.sqlite3_context, argv []*C.sqlite3_value) {
return
}
- res := ret[0].Interface()
- // Normalize ret to one of the types sqlite knows.
- switch r := res.(type) {
- case int64, float64, []byte, string:
- // Already the right type
- case bool:
- if r {
- res = int64(1)
- } else {
- res = int64(0)
- }
- case int:
- res = int64(r)
- case uint:
- res = int64(r)
- case uint8:
- res = int64(r)
- case uint16:
- res = int64(r)
- case uint32:
- res = int64(r)
- case uint64:
- res = int64(r)
- case int8:
- res = int64(r)
- case int16:
- res = int64(r)
- case int32:
- res = int64(r)
- case float32:
- res = float64(r)
- default:
- fi.error(ctx, errors.New("cannot convert returned type to sqlite type"))
+ err := fi.retConverter(ctx, ret[0])
+ if err != nil {
+ fi.error(ctx, err)
return
}
-
- switch r := res.(type) {
- case int64:
- C.sqlite3_result_int64(ctx, C.sqlite3_int64(r))
- case float64:
- C.sqlite3_result_double(ctx, C.double(r))
- case []byte:
- if len(r) == 0 {
- C.sqlite3_result_null(ctx)
- } else {
- C._sqlite3_result_blob(ctx, unsafe.Pointer(&r[0]), C.int(len(r)))
- }
- case string:
- C._sqlite3_result_text(ctx, C.CString(r))
- default:
- panic("unreachable")
- }
}
// Commit transaction.
@@ -261,10 +215,10 @@ func (tx *SQLiteTx) Rollback() error {
// RegisterFunc makes a Go function available as a SQLite function.
//
-// The function must accept only arguments of type int64, float64,
-// []byte or string, and return one value of any numeric type except
-// complex, bool, []byte or string. Optionally, an error can be
-// provided as a second return value.
+// The function can accept arguments of any real numeric type
+// (i.e. not complex), as well as []byte and string. It must return a
+// value of one of those types, and optionally an error as a second
+// value.
//
// If pure is true. SQLite will assume that the function's return
// value depends only on its inputs, and make more aggressive
@@ -287,59 +241,19 @@ func (c *SQLiteConn) RegisterFunc(name string, impl interface{}, pure bool) erro
}
for i := 0; i < t.NumIn(); i++ {
- arg := t.In(i)
- var conv func(*C.sqlite3_value) (reflect.Value, error)
- switch arg.Kind() {
- case reflect.Int64:
- conv = func(v *C.sqlite3_value) (reflect.Value, error) {
- if C.sqlite3_value_type(v) != C.SQLITE_INTEGER {
- return reflect.Value{}, fmt.Errorf("Argument %d to %s must be an INTEGER", i+1, name)
- }
- return reflect.ValueOf(int64(C.sqlite3_value_int64(v))), nil
- }
- case reflect.Float64:
- conv = func(v *C.sqlite3_value) (reflect.Value, error) {
- if C.sqlite3_value_type(v) != C.SQLITE_FLOAT {
- return reflect.Value{}, fmt.Errorf("Argument %d to %s must be a FLOAT", i+1, name)
- }
- return reflect.ValueOf(float64(C.sqlite3_value_double(v))), nil
- }
- case reflect.Slice:
- if arg.Elem().Kind() != reflect.Uint8 {
- return errors.New("The only supported slice type is []byte")
- }
- conv = func(v *C.sqlite3_value) (reflect.Value, error) {
- switch C.sqlite3_value_type(v) {
- case C.SQLITE_BLOB:
- l := C.sqlite3_value_bytes(v)
- p := C.sqlite3_value_blob(v)
- return reflect.ValueOf(C.GoBytes(p, l)), nil
- case C.SQLITE_TEXT:
- l := C.sqlite3_value_bytes(v)
- c := unsafe.Pointer(C.sqlite3_value_text(v))
- return reflect.ValueOf(C.GoBytes(c, l)), nil
- default:
- return reflect.Value{}, fmt.Errorf("Argument %d to %s must be BLOB or TEXT", i+1, name)
- }
- }
- case reflect.String:
- conv = func(v *C.sqlite3_value) (reflect.Value, error) {
- switch C.sqlite3_value_type(v) {
- case C.SQLITE_BLOB:
- l := C.sqlite3_value_bytes(v)
- p := (*C.char)(C.sqlite3_value_blob(v))
- return reflect.ValueOf(C.GoStringN(p, l)), nil
- case C.SQLITE_TEXT:
- c := (*C.char)(unsafe.Pointer(C.sqlite3_value_text(v)))
- return reflect.ValueOf(C.GoString(c)), nil
- default:
- return reflect.Value{}, fmt.Errorf("Argument %d to %s must be BLOB or TEXT", i+1, name)
- }
- }
+ conv, err := callbackArg(t.In(i))
+ if err != nil {
+ return err
}
fi.argConverters = append(fi.argConverters, conv)
}
+ conv, err := callbackRet(t.Out(0))
+ if err != nil {
+ return err
+ }
+ fi.retConverter = conv
+
// fi must outlast the database connection, or we'll have dangling pointers.
c.funcs = append(c.funcs, &fi)
diff --git a/sqlite3_test.go b/sqlite3_test.go
index a58e373..e8dfe5c 100644
--- a/sqlite3_test.go
+++ b/sqlite3_test.go
@@ -15,6 +15,7 @@ import (
"net/url"
"os"
"path/filepath"
+ "reflect"
"regexp"
"strings"
"sync"
@@ -1060,25 +1061,41 @@ func TestDateTimeNow(t *testing.T) {
}
func TestFunctionRegistration(t *testing.T) {
- custom_add := func(a, b int64) (int64, error) {
- return a + b, nil
- }
- custom_regex := func(s, re string) bool {
- matched, err := regexp.MatchString(re, s)
- if err != nil {
- // We should really return the error here, but this
- // function is also testing single return value functions.
- panic("Bad regexp")
- }
- return matched
+ addi_8_16_32 := func(a int8, b int16) int32 { return int32(a) + int32(b) }
+ addi_64 := func(a, b int64) int64 { return a + b }
+ addu_8_16_32 := func(a uint8, b uint16) uint32 { return uint32(a) + uint32(b) }
+ addu_64 := func(a, b uint64) uint64 { return a + b }
+ addiu := func(a int, b uint) int64 { return int64(a) + int64(b) }
+ addf_32_64 := func(a float32, b float64) float64 { return float64(a) + b }
+ not := func(a bool) bool { return !a }
+ regex := func(re, s string) (bool, error) {
+ return regexp.MatchString(re, s)
}
sql.Register("sqlite3_FunctionRegistration", &SQLiteDriver{
ConnectHook: func(conn *SQLiteConn) error {
- if err := conn.RegisterFunc("custom_add", custom_add, true); err != nil {
+ if err := conn.RegisterFunc("addi_8_16_32", addi_8_16_32, true); err != nil {
+ return err
+ }
+ if err := conn.RegisterFunc("addi_64", addi_64, true); err != nil {
+ return err
+ }
+ if err := conn.RegisterFunc("addu_8_16_32", addu_8_16_32, true); err != nil {
return err
}
- if err := conn.RegisterFunc("regexp", custom_regex, true); err != nil {
+ if err := conn.RegisterFunc("addu_64", addu_64, true); err != nil {
+ return err
+ }
+ if err := conn.RegisterFunc("addiu", addiu, true); err != nil {
+ return err
+ }
+ if err := conn.RegisterFunc("addf_32_64", addf_32_64, true); err != nil {
+ return err
+ }
+ if err := conn.RegisterFunc("not", not, true); err != nil {
+ return err
+ }
+ if err := conn.RegisterFunc("regex", regex, true); err != nil {
return err
}
return nil
@@ -1090,42 +1107,29 @@ func TestFunctionRegistration(t *testing.T) {
}
defer db.Close()
- additions := []struct {
- a, b, c int64
+ ops := []struct {
+ query string
+ expected interface{}
}{
- {1, 1, 2},
- {1, 3, 4},
- {1, -1, 0},
- }
-
- for _, add := range additions {
- var i int64
- err = db.QueryRow("SELECT custom_add($1, $2)", add.a, add.b).Scan(&i)
+ {"SELECT addi_8_16_32(1,2)", int32(3)},
+ {"SELECT addi_64(1,2)", int64(3)},
+ {"SELECT addu_8_16_32(1,2)", uint32(3)},
+ {"SELECT addu_64(1,2)", uint64(3)},
+ {"SELECT addiu(1,2)", int64(3)},
+ {"SELECT addf_32_64(1.5,1.5)", float64(3)},
+ {"SELECT not(1)", false},
+ {"SELECT not(0)", true},
+ {`SELECT regex("^foo.*", "foobar")`, true},
+ {`SELECT regex("^foo.*", "barfoobar")`, false},
+ }
+
+ for _, op := range ops {
+ ret := reflect.New(reflect.TypeOf(op.expected))
+ err = db.QueryRow(op.query).Scan(ret.Interface())
if err != nil {
- t.Fatal("Failed to call custom_add:", err)
- }
- if i != add.c {
- t.Fatalf("custom_add returned the wrong value, got %d, want %d", i, add.c)
- }
- }
-
- regexes := []struct {
- re, in string
- out bool
- }{
- {".*", "foo", true},
- {"^foo.*", "foobar", true},
- {"^foo.*", "barfoo", false},
- }
-
- for _, re := range regexes {
- var b bool
- err = db.QueryRow("SELECT regexp($1, $2)", re.in, re.re).Scan(&b)
- if err != nil {
- t.Fatal("Failed to call regexp:", err)
- }
- if b != re.out {
- t.Fatalf("regexp returned the wrong value, got %v, want %v", b, re.out)
+ t.Errorf("Query %q failed: %s", op.query, err)
+ } else if !reflect.DeepEqual(ret.Elem().Interface(), op.expected) {
+ t.Errorf("Query %q returned wrong value: got %v (%T), want %v (%T)", op.query, ret.Elem().Interface(), ret.Elem().Interface(), op.expected, op.expected)
}
}
}
@@ -1134,8 +1138,8 @@ var customFunctionOnce sync.Once
func BenchmarkCustomFunctions(b *testing.B) {
customFunctionOnce.Do(func() {
- custom_add := func(a, b int64) (int64, error) {
- return a + b, nil
+ custom_add := func(a, b int64) int64 {
+ return a + b
}
sql.Register("sqlite3_BenchmarkCustomFunctions", &SQLiteDriver{