aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--callback.go24
-rw-r--r--sqlite3.go13
-rw-r--r--sqlite3_test.go28
3 files changed, 60 insertions, 5 deletions
diff --git a/callback.go b/callback.go
index 1692106..b1704fe 100644
--- a/callback.go
+++ b/callback.go
@@ -108,8 +108,32 @@ func callbackArgString(v *C.sqlite3_value) (reflect.Value, error) {
}
}
+func callbackArgGeneric(v *C.sqlite3_value) (reflect.Value, error) {
+ switch C.sqlite3_value_type(v) {
+ case C.SQLITE_INTEGER:
+ return callbackArgInt64(v)
+ case C.SQLITE_FLOAT:
+ return callbackArgFloat64(v)
+ case C.SQLITE_TEXT:
+ return callbackArgString(v)
+ case C.SQLITE_BLOB:
+ return callbackArgBytes(v)
+ case C.SQLITE_NULL:
+ // Interpret NULL as a nil byte slice.
+ var ret []byte
+ return reflect.ValueOf(ret), nil
+ default:
+ panic("unreachable")
+ }
+}
+
func callbackArg(typ reflect.Type) (callbackArgConverter, error) {
switch typ.Kind() {
+ case reflect.Interface:
+ if typ.NumMethod() != 0 {
+ return nil, errors.New("the only supported interface type is interface{}")
+ }
+ return callbackArgGeneric, nil
case reflect.Slice:
if typ.Elem().Kind() != reflect.Uint8 {
return nil, errors.New("the only supported slice type is []byte")
diff --git a/sqlite3.go b/sqlite3.go
index 8bb9826..73e67e3 100644
--- a/sqlite3.go
+++ b/sqlite3.go
@@ -232,11 +232,14 @@ func (tx *SQLiteTx) Rollback() error {
// RegisterFunc makes a Go function available as a SQLite function.
//
-// The function can accept arguments of any real numeric type
-// (i.e. not complex), as well as []byte and string. It must return a
-// value of one of those types, and optionally an error as a second
-// value. Variadic functions are allowed, if the variadic argument is
-// one of the allowed types.
+// The Go function can have arguments of the following types: any
+// numeric type except complex, bool, []byte, string and
+// interface{}. interface{} arguments are given the direct translation
+// of the SQLite data type: int64 for INTEGER, float64 for FLOAT,
+// []byte for BLOB, string for TEXT.
+//
+// The function can additionally be variadic, as long as the type of
+// the variadic argument is one of the above.
//
// If pure is true. SQLite will assume that the function's return
// value depends only on its inputs, and make more aggressive
diff --git a/sqlite3_test.go b/sqlite3_test.go
index a563c08..62db05b 100644
--- a/sqlite3_test.go
+++ b/sqlite3_test.go
@@ -1071,6 +1071,20 @@ func TestFunctionRegistration(t *testing.T) {
regex := func(re, s string) (bool, error) {
return regexp.MatchString(re, s)
}
+ generic := func(a interface{}) int64 {
+ switch a.(type) {
+ case int64:
+ return 1
+ case float64:
+ return 2
+ case []byte:
+ return 3
+ case string:
+ return 4
+ default:
+ panic("unreachable")
+ }
+ }
variadic := func(a, b int64, c ...int64) int64 {
ret := a + b
for _, d := range c {
@@ -1078,6 +1092,9 @@ func TestFunctionRegistration(t *testing.T) {
}
return ret
}
+ variadicGeneric := func(a ...interface{}) int64 {
+ return int64(len(a))
+ }
sql.Register("sqlite3_FunctionRegistration", &SQLiteDriver{
ConnectHook: func(conn *SQLiteConn) error {
@@ -1105,9 +1122,15 @@ func TestFunctionRegistration(t *testing.T) {
if err := conn.RegisterFunc("regex", regex, true); err != nil {
return err
}
+ if err := conn.RegisterFunc("generic", generic, true); err != nil {
+ return err
+ }
if err := conn.RegisterFunc("variadic", variadic, true); err != nil {
return err
}
+ if err := conn.RegisterFunc("variadicGeneric", variadicGeneric, true); err != nil {
+ return err
+ }
return nil
},
})
@@ -1131,9 +1154,14 @@ func TestFunctionRegistration(t *testing.T) {
{"SELECT not(0)", true},
{`SELECT regex("^foo.*", "foobar")`, true},
{`SELECT regex("^foo.*", "barfoobar")`, false},
+ {"SELECT generic(1)", int64(1)},
+ {"SELECT generic(1.1)", int64(2)},
+ {`SELECT generic(NULL)`, int64(3)},
+ {`SELECT generic("foo")`, int64(4)},
{"SELECT variadic(1,2)", int64(3)},
{"SELECT variadic(1,2,3,4)", int64(10)},
{"SELECT variadic(1,1,1,1,1,1,1,1,1,1)", int64(10)},
+ {`SELECT variadicGeneric(1,"foo",2.3, NULL)`, int64(4)},
}
for _, op := range ops {