From cf8fa0af80e0d227c79ef2b4635e8d0d77432275 Mon Sep 17 00:00:00 2001 From: David Anderson Date: Thu, 20 Aug 2015 23:08:48 -0700 Subject: Implement support for passing Go functions as custom functions to SQLite. Fixes #226. --- sqlite3_test.go | 108 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 108 insertions(+) (limited to 'sqlite3_test.go') diff --git a/sqlite3_test.go b/sqlite3_test.go index 423f30e..a58e373 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -15,7 +15,9 @@ import ( "net/url" "os" "path/filepath" + "regexp" "strings" + "sync" "testing" "time" @@ -1056,3 +1058,109 @@ func TestDateTimeNow(t *testing.T) { t.Fatal("Failed to scan datetime:", err) } } + +func TestFunctionRegistration(t *testing.T) { + custom_add := func(a, b int64) (int64, error) { + return a + b, nil + } + custom_regex := func(s, re string) bool { + matched, err := regexp.MatchString(re, s) + if err != nil { + // We should really return the error here, but this + // function is also testing single return value functions. + panic("Bad regexp") + } + return matched + } + + sql.Register("sqlite3_FunctionRegistration", &SQLiteDriver{ + ConnectHook: func(conn *SQLiteConn) error { + if err := conn.RegisterFunc("custom_add", custom_add, true); err != nil { + return err + } + if err := conn.RegisterFunc("regexp", custom_regex, true); err != nil { + return err + } + return nil + }, + }) + db, err := sql.Open("sqlite3_FunctionRegistration", ":memory:") + if err != nil { + t.Fatal("Failed to open database:", err) + } + defer db.Close() + + additions := []struct { + a, b, c int64 + }{ + {1, 1, 2}, + {1, 3, 4}, + {1, -1, 0}, + } + + for _, add := range additions { + var i int64 + err = db.QueryRow("SELECT custom_add($1, $2)", add.a, add.b).Scan(&i) + if err != nil { + t.Fatal("Failed to call custom_add:", err) + } + if i != add.c { + t.Fatalf("custom_add returned the wrong value, got %d, want %d", i, add.c) + } + } + + regexes := []struct { + re, in string + out bool + }{ + {".*", "foo", true}, + {"^foo.*", "foobar", true}, + {"^foo.*", "barfoo", false}, + } + + for _, re := range regexes { + var b bool + err = db.QueryRow("SELECT regexp($1, $2)", re.in, re.re).Scan(&b) + if err != nil { + t.Fatal("Failed to call regexp:", err) + } + if b != re.out { + t.Fatalf("regexp returned the wrong value, got %v, want %v", b, re.out) + } + } +} + +var customFunctionOnce sync.Once + +func BenchmarkCustomFunctions(b *testing.B) { + customFunctionOnce.Do(func() { + custom_add := func(a, b int64) (int64, error) { + return a + b, nil + } + + sql.Register("sqlite3_BenchmarkCustomFunctions", &SQLiteDriver{ + ConnectHook: func(conn *SQLiteConn) error { + // Impure function to force sqlite to reexecute it each time. + if err := conn.RegisterFunc("custom_add", custom_add, false); err != nil { + return err + } + return nil + }, + }) + }) + + db, err := sql.Open("sqlite3_BenchmarkCustomFunctions", ":memory:") + if err != nil { + b.Fatal("Failed to open database:", err) + } + defer db.Close() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + var i int64 + err = db.QueryRow("SELECT custom_add(1,2)").Scan(&i) + if err != nil { + b.Fatal("Failed to run custom add:", err) + } + } +} -- cgit v1.2.3 From 122ddb16de825ed3d989d25d4d7b2d2e278abdf6 Mon Sep 17 00:00:00 2001 From: David Anderson Date: Fri, 21 Aug 2015 13:38:22 -0700 Subject: Move argument converters to callback.go, and optimize return value handling. A call now doesn't have to do any reflection, it just blindly invokes a bunch of argument and return value handlers to execute the translation, and the safety of the translation is determined at registration time. --- callback.go | 200 ++++++++++++++++++++++++++++++++++++++++++++++++++++++- callback_test.go | 97 +++++++++++++++++++++++++++ sqlite3.go | 122 +++++---------------------------- sqlite3_test.go | 102 ++++++++++++++-------------- 4 files changed, 367 insertions(+), 154 deletions(-) create mode 100644 callback_test.go (limited to 'sqlite3_test.go') diff --git a/callback.go b/callback.go index 938d7fe..1692106 100644 --- a/callback.go +++ b/callback.go @@ -5,12 +5,25 @@ package sqlite3 +// You can't export a Go function to C and have definitions in the C +// preamble in the same file, so we have to have callbackTrampoline in +// its own file. Because we need a separate file anyway, the support +// code for SQLite custom functions is in here. + /* #include + +void _sqlite3_result_text(sqlite3_context* ctx, const char* s); +void _sqlite3_result_blob(sqlite3_context* ctx, const void* b, int l); */ import "C" -import "unsafe" +import ( + "errors" + "fmt" + "reflect" + "unsafe" +) //export callbackTrampoline func callbackTrampoline(ctx *C.sqlite3_context, argc int, argv **C.sqlite3_value) { @@ -18,3 +31,188 @@ func callbackTrampoline(ctx *C.sqlite3_context, argc int, argv **C.sqlite3_value fi := (*functionInfo)(unsafe.Pointer(C.sqlite3_user_data(ctx))) fi.Call(ctx, args) } + +// This is only here so that tests can refer to it. +type callbackArgRaw C.sqlite3_value + +type callbackArgConverter func(*C.sqlite3_value) (reflect.Value, error) + +type callbackArgCast struct { + f callbackArgConverter + typ reflect.Type +} + +func (c callbackArgCast) Run(v *C.sqlite3_value) (reflect.Value, error) { + val, err := c.f(v) + if err != nil { + return reflect.Value{}, err + } + if !val.Type().ConvertibleTo(c.typ) { + return reflect.Value{}, fmt.Errorf("cannot convert %s to %s", val.Type(), c.typ) + } + return val.Convert(c.typ), nil +} + +func callbackArgInt64(v *C.sqlite3_value) (reflect.Value, error) { + if C.sqlite3_value_type(v) != C.SQLITE_INTEGER { + return reflect.Value{}, fmt.Errorf("argument must be an INTEGER") + } + return reflect.ValueOf(int64(C.sqlite3_value_int64(v))), nil +} + +func callbackArgBool(v *C.sqlite3_value) (reflect.Value, error) { + if C.sqlite3_value_type(v) != C.SQLITE_INTEGER { + return reflect.Value{}, fmt.Errorf("argument must be an INTEGER") + } + i := int64(C.sqlite3_value_int64(v)) + val := false + if i != 0 { + val = true + } + return reflect.ValueOf(val), nil +} + +func callbackArgFloat64(v *C.sqlite3_value) (reflect.Value, error) { + if C.sqlite3_value_type(v) != C.SQLITE_FLOAT { + return reflect.Value{}, fmt.Errorf("argument must be a FLOAT") + } + return reflect.ValueOf(float64(C.sqlite3_value_double(v))), nil +} + +func callbackArgBytes(v *C.sqlite3_value) (reflect.Value, error) { + switch C.sqlite3_value_type(v) { + case C.SQLITE_BLOB: + l := C.sqlite3_value_bytes(v) + p := C.sqlite3_value_blob(v) + return reflect.ValueOf(C.GoBytes(p, l)), nil + case C.SQLITE_TEXT: + l := C.sqlite3_value_bytes(v) + c := unsafe.Pointer(C.sqlite3_value_text(v)) + return reflect.ValueOf(C.GoBytes(c, l)), nil + default: + return reflect.Value{}, fmt.Errorf("argument must be BLOB or TEXT") + } +} + +func callbackArgString(v *C.sqlite3_value) (reflect.Value, error) { + switch C.sqlite3_value_type(v) { + case C.SQLITE_BLOB: + l := C.sqlite3_value_bytes(v) + p := (*C.char)(C.sqlite3_value_blob(v)) + return reflect.ValueOf(C.GoStringN(p, l)), nil + case C.SQLITE_TEXT: + c := (*C.char)(unsafe.Pointer(C.sqlite3_value_text(v))) + return reflect.ValueOf(C.GoString(c)), nil + default: + return reflect.Value{}, fmt.Errorf("argument must be BLOB or TEXT") + } +} + +func callbackArg(typ reflect.Type) (callbackArgConverter, error) { + switch typ.Kind() { + case reflect.Slice: + if typ.Elem().Kind() != reflect.Uint8 { + return nil, errors.New("the only supported slice type is []byte") + } + return callbackArgBytes, nil + case reflect.String: + return callbackArgString, nil + case reflect.Bool: + return callbackArgBool, nil + case reflect.Int64: + return callbackArgInt64, nil + case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Int, reflect.Uint: + c := callbackArgCast{callbackArgInt64, typ} + return c.Run, nil + case reflect.Float64: + return callbackArgFloat64, nil + case reflect.Float32: + c := callbackArgCast{callbackArgFloat64, typ} + return c.Run, nil + default: + return nil, fmt.Errorf("don't know how to convert to %s", typ) + } +} + +type callbackRetConverter func(*C.sqlite3_context, reflect.Value) error + +func callbackRetInteger(ctx *C.sqlite3_context, v reflect.Value) error { + switch v.Type().Kind() { + case reflect.Int64: + case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Int, reflect.Uint: + v = v.Convert(reflect.TypeOf(int64(0))) + case reflect.Bool: + b := v.Interface().(bool) + if b { + v = reflect.ValueOf(int64(1)) + } else { + v = reflect.ValueOf(int64(0)) + } + default: + return fmt.Errorf("cannot convert %s to INTEGER", v.Type()) + } + + C.sqlite3_result_int64(ctx, C.sqlite3_int64(v.Interface().(int64))) + return nil +} + +func callbackRetFloat(ctx *C.sqlite3_context, v reflect.Value) error { + switch v.Type().Kind() { + case reflect.Float64: + case reflect.Float32: + v = v.Convert(reflect.TypeOf(float64(0))) + default: + return fmt.Errorf("cannot convert %s to FLOAT", v.Type()) + } + + C.sqlite3_result_double(ctx, C.double(v.Interface().(float64))) + return nil +} + +func callbackRetBlob(ctx *C.sqlite3_context, v reflect.Value) error { + if v.Type().Kind() != reflect.Slice || v.Type().Elem().Kind() != reflect.Uint8 { + return fmt.Errorf("cannot convert %s to BLOB", v.Type()) + } + i := v.Interface() + if i == nil || len(i.([]byte)) == 0 { + C.sqlite3_result_null(ctx) + } else { + bs := i.([]byte) + C._sqlite3_result_blob(ctx, unsafe.Pointer(&bs[0]), C.int(len(bs))) + } + return nil +} + +func callbackRetText(ctx *C.sqlite3_context, v reflect.Value) error { + if v.Type().Kind() != reflect.String { + return fmt.Errorf("cannot convert %s to TEXT", v.Type()) + } + C._sqlite3_result_text(ctx, C.CString(v.Interface().(string))) + return nil +} + +func callbackRet(typ reflect.Type) (callbackRetConverter, error) { + switch typ.Kind() { + case reflect.Slice: + if typ.Elem().Kind() != reflect.Uint8 { + return nil, errors.New("the only supported slice type is []byte") + } + return callbackRetBlob, nil + case reflect.String: + return callbackRetText, nil + case reflect.Bool, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Int, reflect.Uint: + return callbackRetInteger, nil + case reflect.Float32, reflect.Float64: + return callbackRetFloat, nil + default: + return nil, fmt.Errorf("don't know how to convert to %s", typ) + } +} + +// Test support code. Tests are not allowed to import "C", so we can't +// declare any functions that use C.sqlite3_value. +func callbackSyntheticForTests(v reflect.Value, err error) callbackArgConverter { + return func(*C.sqlite3_value) (reflect.Value, error) { + return v, err + } +} diff --git a/callback_test.go b/callback_test.go new file mode 100644 index 0000000..5c61f44 --- /dev/null +++ b/callback_test.go @@ -0,0 +1,97 @@ +package sqlite3 + +import ( + "errors" + "math" + "reflect" + "testing" +) + +func TestCallbackArgCast(t *testing.T) { + intConv := callbackSyntheticForTests(reflect.ValueOf(int64(math.MaxInt64)), nil) + floatConv := callbackSyntheticForTests(reflect.ValueOf(float64(math.MaxFloat64)), nil) + errConv := callbackSyntheticForTests(reflect.Value{}, errors.New("test")) + + tests := []struct { + f callbackArgConverter + o reflect.Value + }{ + {intConv, reflect.ValueOf(int8(-1))}, + {intConv, reflect.ValueOf(int16(-1))}, + {intConv, reflect.ValueOf(int32(-1))}, + {intConv, reflect.ValueOf(uint8(math.MaxUint8))}, + {intConv, reflect.ValueOf(uint16(math.MaxUint16))}, + {intConv, reflect.ValueOf(uint32(math.MaxUint32))}, + // Special case, int64->uint64 is only 1<<63 - 1, not 1<<64 - 1 + {intConv, reflect.ValueOf(uint64(math.MaxInt64))}, + {floatConv, reflect.ValueOf(float32(math.Inf(1)))}, + } + + for _, test := range tests { + conv := callbackArgCast{test.f, test.o.Type()} + val, err := conv.Run(nil) + if err != nil { + t.Errorf("Couldn't convert to %s: %s", test.o.Type(), err) + } else if !reflect.DeepEqual(val.Interface(), test.o.Interface()) { + t.Errorf("Unexpected result from converting to %s: got %v, want %v", test.o.Type(), val.Interface(), test.o.Interface()) + } + } + + conv := callbackArgCast{errConv, reflect.TypeOf(int8(0))} + _, err := conv.Run(nil) + if err == nil { + t.Errorf("Expected error during callbackArgCast, but got none") + } +} + +func TestCallbackConverters(t *testing.T) { + tests := []struct { + v interface{} + err bool + }{ + // Unfortunately, we can't tell which converter was returned, + // but we can at least check which types can be converted. + {[]byte{0}, false}, + {"text", false}, + {true, false}, + {int8(0), false}, + {int16(0), false}, + {int32(0), false}, + {int64(0), false}, + {uint8(0), false}, + {uint16(0), false}, + {uint32(0), false}, + {uint64(0), false}, + {int(0), false}, + {uint(0), false}, + {float64(0), false}, + {float32(0), false}, + + {func() {}, true}, + {complex64(complex(0, 0)), true}, + {complex128(complex(0, 0)), true}, + {struct{}{}, true}, + {map[string]string{}, true}, + {[]string{}, true}, + {(*int8)(nil), true}, + {make(chan int), true}, + } + + for _, test := range tests { + _, err := callbackArg(reflect.TypeOf(test.v)) + if test.err && err == nil { + t.Errorf("Expected an error when converting %s, got no error", reflect.TypeOf(test.v)) + } else if !test.err && err != nil { + t.Errorf("Expected converter when converting %s, got error: %s", reflect.TypeOf(test.v), err) + } + } + + for _, test := range tests { + _, err := callbackRet(reflect.TypeOf(test.v)) + if test.err && err == nil { + t.Errorf("Expected an error when converting %s, got no error", reflect.TypeOf(test.v)) + } else if !test.err && err != nil { + t.Errorf("Expected converter when converting %s, got error: %s", reflect.TypeOf(test.v), err) + } + } +} diff --git a/sqlite3.go b/sqlite3.go index f995589..174a3ee 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -166,7 +166,8 @@ type SQLiteRows struct { type functionInfo struct { f reflect.Value - argConverters []func(*C.sqlite3_value) (reflect.Value, error) + argConverters []callbackArgConverter + retConverter callbackRetConverter } func (fi *functionInfo) error(ctx *C.sqlite3_context, err error) { @@ -193,58 +194,11 @@ func (fi *functionInfo) Call(ctx *C.sqlite3_context, argv []*C.sqlite3_value) { return } - res := ret[0].Interface() - // Normalize ret to one of the types sqlite knows. - switch r := res.(type) { - case int64, float64, []byte, string: - // Already the right type - case bool: - if r { - res = int64(1) - } else { - res = int64(0) - } - case int: - res = int64(r) - case uint: - res = int64(r) - case uint8: - res = int64(r) - case uint16: - res = int64(r) - case uint32: - res = int64(r) - case uint64: - res = int64(r) - case int8: - res = int64(r) - case int16: - res = int64(r) - case int32: - res = int64(r) - case float32: - res = float64(r) - default: - fi.error(ctx, errors.New("cannot convert returned type to sqlite type")) + err := fi.retConverter(ctx, ret[0]) + if err != nil { + fi.error(ctx, err) return } - - switch r := res.(type) { - case int64: - C.sqlite3_result_int64(ctx, C.sqlite3_int64(r)) - case float64: - C.sqlite3_result_double(ctx, C.double(r)) - case []byte: - if len(r) == 0 { - C.sqlite3_result_null(ctx) - } else { - C._sqlite3_result_blob(ctx, unsafe.Pointer(&r[0]), C.int(len(r))) - } - case string: - C._sqlite3_result_text(ctx, C.CString(r)) - default: - panic("unreachable") - } } // Commit transaction. @@ -261,10 +215,10 @@ func (tx *SQLiteTx) Rollback() error { // RegisterFunc makes a Go function available as a SQLite function. // -// The function must accept only arguments of type int64, float64, -// []byte or string, and return one value of any numeric type except -// complex, bool, []byte or string. Optionally, an error can be -// provided as a second return value. +// The function can accept arguments of any real numeric type +// (i.e. not complex), as well as []byte and string. It must return a +// value of one of those types, and optionally an error as a second +// value. // // If pure is true. SQLite will assume that the function's return // value depends only on its inputs, and make more aggressive @@ -287,59 +241,19 @@ func (c *SQLiteConn) RegisterFunc(name string, impl interface{}, pure bool) erro } for i := 0; i < t.NumIn(); i++ { - arg := t.In(i) - var conv func(*C.sqlite3_value) (reflect.Value, error) - switch arg.Kind() { - case reflect.Int64: - conv = func(v *C.sqlite3_value) (reflect.Value, error) { - if C.sqlite3_value_type(v) != C.SQLITE_INTEGER { - return reflect.Value{}, fmt.Errorf("Argument %d to %s must be an INTEGER", i+1, name) - } - return reflect.ValueOf(int64(C.sqlite3_value_int64(v))), nil - } - case reflect.Float64: - conv = func(v *C.sqlite3_value) (reflect.Value, error) { - if C.sqlite3_value_type(v) != C.SQLITE_FLOAT { - return reflect.Value{}, fmt.Errorf("Argument %d to %s must be a FLOAT", i+1, name) - } - return reflect.ValueOf(float64(C.sqlite3_value_double(v))), nil - } - case reflect.Slice: - if arg.Elem().Kind() != reflect.Uint8 { - return errors.New("The only supported slice type is []byte") - } - conv = func(v *C.sqlite3_value) (reflect.Value, error) { - switch C.sqlite3_value_type(v) { - case C.SQLITE_BLOB: - l := C.sqlite3_value_bytes(v) - p := C.sqlite3_value_blob(v) - return reflect.ValueOf(C.GoBytes(p, l)), nil - case C.SQLITE_TEXT: - l := C.sqlite3_value_bytes(v) - c := unsafe.Pointer(C.sqlite3_value_text(v)) - return reflect.ValueOf(C.GoBytes(c, l)), nil - default: - return reflect.Value{}, fmt.Errorf("Argument %d to %s must be BLOB or TEXT", i+1, name) - } - } - case reflect.String: - conv = func(v *C.sqlite3_value) (reflect.Value, error) { - switch C.sqlite3_value_type(v) { - case C.SQLITE_BLOB: - l := C.sqlite3_value_bytes(v) - p := (*C.char)(C.sqlite3_value_blob(v)) - return reflect.ValueOf(C.GoStringN(p, l)), nil - case C.SQLITE_TEXT: - c := (*C.char)(unsafe.Pointer(C.sqlite3_value_text(v))) - return reflect.ValueOf(C.GoString(c)), nil - default: - return reflect.Value{}, fmt.Errorf("Argument %d to %s must be BLOB or TEXT", i+1, name) - } - } + conv, err := callbackArg(t.In(i)) + if err != nil { + return err } fi.argConverters = append(fi.argConverters, conv) } + conv, err := callbackRet(t.Out(0)) + if err != nil { + return err + } + fi.retConverter = conv + // fi must outlast the database connection, or we'll have dangling pointers. c.funcs = append(c.funcs, &fi) diff --git a/sqlite3_test.go b/sqlite3_test.go index a58e373..e8dfe5c 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -15,6 +15,7 @@ import ( "net/url" "os" "path/filepath" + "reflect" "regexp" "strings" "sync" @@ -1060,25 +1061,41 @@ func TestDateTimeNow(t *testing.T) { } func TestFunctionRegistration(t *testing.T) { - custom_add := func(a, b int64) (int64, error) { - return a + b, nil - } - custom_regex := func(s, re string) bool { - matched, err := regexp.MatchString(re, s) - if err != nil { - // We should really return the error here, but this - // function is also testing single return value functions. - panic("Bad regexp") - } - return matched + addi_8_16_32 := func(a int8, b int16) int32 { return int32(a) + int32(b) } + addi_64 := func(a, b int64) int64 { return a + b } + addu_8_16_32 := func(a uint8, b uint16) uint32 { return uint32(a) + uint32(b) } + addu_64 := func(a, b uint64) uint64 { return a + b } + addiu := func(a int, b uint) int64 { return int64(a) + int64(b) } + addf_32_64 := func(a float32, b float64) float64 { return float64(a) + b } + not := func(a bool) bool { return !a } + regex := func(re, s string) (bool, error) { + return regexp.MatchString(re, s) } sql.Register("sqlite3_FunctionRegistration", &SQLiteDriver{ ConnectHook: func(conn *SQLiteConn) error { - if err := conn.RegisterFunc("custom_add", custom_add, true); err != nil { + if err := conn.RegisterFunc("addi_8_16_32", addi_8_16_32, true); err != nil { + return err + } + if err := conn.RegisterFunc("addi_64", addi_64, true); err != nil { + return err + } + if err := conn.RegisterFunc("addu_8_16_32", addu_8_16_32, true); err != nil { return err } - if err := conn.RegisterFunc("regexp", custom_regex, true); err != nil { + if err := conn.RegisterFunc("addu_64", addu_64, true); err != nil { + return err + } + if err := conn.RegisterFunc("addiu", addiu, true); err != nil { + return err + } + if err := conn.RegisterFunc("addf_32_64", addf_32_64, true); err != nil { + return err + } + if err := conn.RegisterFunc("not", not, true); err != nil { + return err + } + if err := conn.RegisterFunc("regex", regex, true); err != nil { return err } return nil @@ -1090,42 +1107,29 @@ func TestFunctionRegistration(t *testing.T) { } defer db.Close() - additions := []struct { - a, b, c int64 + ops := []struct { + query string + expected interface{} }{ - {1, 1, 2}, - {1, 3, 4}, - {1, -1, 0}, - } - - for _, add := range additions { - var i int64 - err = db.QueryRow("SELECT custom_add($1, $2)", add.a, add.b).Scan(&i) + {"SELECT addi_8_16_32(1,2)", int32(3)}, + {"SELECT addi_64(1,2)", int64(3)}, + {"SELECT addu_8_16_32(1,2)", uint32(3)}, + {"SELECT addu_64(1,2)", uint64(3)}, + {"SELECT addiu(1,2)", int64(3)}, + {"SELECT addf_32_64(1.5,1.5)", float64(3)}, + {"SELECT not(1)", false}, + {"SELECT not(0)", true}, + {`SELECT regex("^foo.*", "foobar")`, true}, + {`SELECT regex("^foo.*", "barfoobar")`, false}, + } + + for _, op := range ops { + ret := reflect.New(reflect.TypeOf(op.expected)) + err = db.QueryRow(op.query).Scan(ret.Interface()) if err != nil { - t.Fatal("Failed to call custom_add:", err) - } - if i != add.c { - t.Fatalf("custom_add returned the wrong value, got %d, want %d", i, add.c) - } - } - - regexes := []struct { - re, in string - out bool - }{ - {".*", "foo", true}, - {"^foo.*", "foobar", true}, - {"^foo.*", "barfoo", false}, - } - - for _, re := range regexes { - var b bool - err = db.QueryRow("SELECT regexp($1, $2)", re.in, re.re).Scan(&b) - if err != nil { - t.Fatal("Failed to call regexp:", err) - } - if b != re.out { - t.Fatalf("regexp returned the wrong value, got %v, want %v", b, re.out) + t.Errorf("Query %q failed: %s", op.query, err) + } else if !reflect.DeepEqual(ret.Elem().Interface(), op.expected) { + t.Errorf("Query %q returned wrong value: got %v (%T), want %v (%T)", op.query, ret.Elem().Interface(), ret.Elem().Interface(), op.expected, op.expected) } } } @@ -1134,8 +1138,8 @@ var customFunctionOnce sync.Once func BenchmarkCustomFunctions(b *testing.B) { customFunctionOnce.Do(func() { - custom_add := func(a, b int64) (int64, error) { - return a + b, nil + custom_add := func(a, b int64) int64 { + return a + b } sql.Register("sqlite3_BenchmarkCustomFunctions", &SQLiteDriver{ -- cgit v1.2.3 From 566f63a43a314f8dcd758dba8c40dc11edc27a5e Mon Sep 17 00:00:00 2001 From: David Anderson Date: Fri, 21 Aug 2015 16:34:55 -0700 Subject: Implement support for variadic functions. Currently, the variadic part must all be the same type, because there's no "generic" arg converter. --- sqlite3.go | 52 ++++++++++++++++++++++++++++++++++++++++++---------- sqlite3_test.go | 13 +++++++++++++ 2 files changed, 55 insertions(+), 10 deletions(-) (limited to 'sqlite3_test.go') diff --git a/sqlite3.go b/sqlite3.go index 174a3ee..8bb9826 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -165,9 +165,10 @@ type SQLiteRows struct { } type functionInfo struct { - f reflect.Value - argConverters []callbackArgConverter - retConverter callbackRetConverter + f reflect.Value + argConverters []callbackArgConverter + variadicConverter callbackArgConverter + retConverter callbackRetConverter } func (fi *functionInfo) error(ctx *C.sqlite3_context, err error) { @@ -178,7 +179,12 @@ func (fi *functionInfo) error(ctx *C.sqlite3_context, err error) { func (fi *functionInfo) Call(ctx *C.sqlite3_context, argv []*C.sqlite3_value) { var args []reflect.Value - for i, arg := range argv { + + if len(argv) < len(fi.argConverters) { + fi.error(ctx, fmt.Errorf("function requires at least %d arguments", len(fi.argConverters))) + } + + for i, arg := range argv[:len(fi.argConverters)] { v, err := fi.argConverters[i](arg) if err != nil { fi.error(ctx, err) @@ -187,6 +193,17 @@ func (fi *functionInfo) Call(ctx *C.sqlite3_context, argv []*C.sqlite3_value) { args = append(args, v) } + if fi.variadicConverter != nil { + for _, arg := range argv[len(fi.argConverters):] { + v, err := fi.variadicConverter(arg) + if err != nil { + fi.error(ctx, err) + return + } + args = append(args, v) + } + } + ret := fi.f.Call(args) if len(ret) == 2 && ret[1].Interface() != nil { @@ -218,7 +235,8 @@ func (tx *SQLiteTx) Rollback() error { // The function can accept arguments of any real numeric type // (i.e. not complex), as well as []byte and string. It must return a // value of one of those types, and optionally an error as a second -// value. +// value. Variadic functions are allowed, if the variadic argument is +// one of the allowed types. // // If pure is true. SQLite will assume that the function's return // value depends only on its inputs, and make more aggressive @@ -230,9 +248,6 @@ func (c *SQLiteConn) RegisterFunc(name string, impl interface{}, pure bool) erro if t.Kind() != reflect.Func { return errors.New("Non-function passed to RegisterFunc") } - if t.IsVariadic() { - return errors.New("Variadic SQLite functions are not supported") - } if t.NumOut() != 1 && t.NumOut() != 2 { return errors.New("SQLite functions must return 1 or 2 values") } @@ -240,7 +255,12 @@ func (c *SQLiteConn) RegisterFunc(name string, impl interface{}, pure bool) erro return errors.New("Second return value of SQLite function must be error") } - for i := 0; i < t.NumIn(); i++ { + numArgs := t.NumIn() + if t.IsVariadic() { + numArgs-- + } + + for i := 0; i < numArgs; i++ { conv, err := callbackArg(t.In(i)) if err != nil { return err @@ -248,6 +268,18 @@ func (c *SQLiteConn) RegisterFunc(name string, impl interface{}, pure bool) erro fi.argConverters = append(fi.argConverters, conv) } + if t.IsVariadic() { + conv, err := callbackArg(t.In(numArgs).Elem()) + if err != nil { + return err + } + fi.variadicConverter = conv + // Pass -1 to sqlite so that it allows any number of + // arguments. The call helper verifies that the minimum number + // of arguments is present for variadic functions. + numArgs = -1 + } + conv, err := callbackRet(t.Out(0)) if err != nil { return err @@ -263,7 +295,7 @@ func (c *SQLiteConn) RegisterFunc(name string, impl interface{}, pure bool) erro if pure { opts |= C.SQLITE_DETERMINISTIC } - rv := C.sqlite3_create_function_v2(c.db, cname, C.int(t.NumIn()), C.int(opts), unsafe.Pointer(&fi), (*[0]byte)(unsafe.Pointer(C.callbackTrampoline)), nil, nil, nil) + rv := C.sqlite3_create_function_v2(c.db, cname, C.int(numArgs), C.int(opts), unsafe.Pointer(&fi), (*[0]byte)(unsafe.Pointer(C.callbackTrampoline)), nil, nil, nil) if rv != C.SQLITE_OK { return c.lastError() } diff --git a/sqlite3_test.go b/sqlite3_test.go index e8dfe5c..a563c08 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -1071,6 +1071,13 @@ func TestFunctionRegistration(t *testing.T) { regex := func(re, s string) (bool, error) { return regexp.MatchString(re, s) } + variadic := func(a, b int64, c ...int64) int64 { + ret := a + b + for _, d := range c { + ret += d + } + return ret + } sql.Register("sqlite3_FunctionRegistration", &SQLiteDriver{ ConnectHook: func(conn *SQLiteConn) error { @@ -1098,6 +1105,9 @@ func TestFunctionRegistration(t *testing.T) { if err := conn.RegisterFunc("regex", regex, true); err != nil { return err } + if err := conn.RegisterFunc("variadic", variadic, true); err != nil { + return err + } return nil }, }) @@ -1121,6 +1131,9 @@ func TestFunctionRegistration(t *testing.T) { {"SELECT not(0)", true}, {`SELECT regex("^foo.*", "foobar")`, true}, {`SELECT regex("^foo.*", "barfoobar")`, false}, + {"SELECT variadic(1,2)", int64(3)}, + {"SELECT variadic(1,2,3,4)", int64(10)}, + {"SELECT variadic(1,1,1,1,1,1,1,1,1,1)", int64(10)}, } for _, op := range ops { -- cgit v1.2.3 From b037a616903746de8e647f53503d4edca29192ec Mon Sep 17 00:00:00 2001 From: David Anderson Date: Fri, 21 Aug 2015 17:12:18 -0700 Subject: Add support for interface{} arguments in Go SQLite functions. This enabled support for functions like Foo(a interface{}) and Bar(a ...interface{}). --- callback.go | 24 ++++++++++++++++++++++++ sqlite3.go | 13 ++++++++----- sqlite3_test.go | 28 ++++++++++++++++++++++++++++ 3 files changed, 60 insertions(+), 5 deletions(-) (limited to 'sqlite3_test.go') diff --git a/callback.go b/callback.go index 1692106..b1704fe 100644 --- a/callback.go +++ b/callback.go @@ -108,8 +108,32 @@ func callbackArgString(v *C.sqlite3_value) (reflect.Value, error) { } } +func callbackArgGeneric(v *C.sqlite3_value) (reflect.Value, error) { + switch C.sqlite3_value_type(v) { + case C.SQLITE_INTEGER: + return callbackArgInt64(v) + case C.SQLITE_FLOAT: + return callbackArgFloat64(v) + case C.SQLITE_TEXT: + return callbackArgString(v) + case C.SQLITE_BLOB: + return callbackArgBytes(v) + case C.SQLITE_NULL: + // Interpret NULL as a nil byte slice. + var ret []byte + return reflect.ValueOf(ret), nil + default: + panic("unreachable") + } +} + func callbackArg(typ reflect.Type) (callbackArgConverter, error) { switch typ.Kind() { + case reflect.Interface: + if typ.NumMethod() != 0 { + return nil, errors.New("the only supported interface type is interface{}") + } + return callbackArgGeneric, nil case reflect.Slice: if typ.Elem().Kind() != reflect.Uint8 { return nil, errors.New("the only supported slice type is []byte") diff --git a/sqlite3.go b/sqlite3.go index 8bb9826..73e67e3 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -232,11 +232,14 @@ func (tx *SQLiteTx) Rollback() error { // RegisterFunc makes a Go function available as a SQLite function. // -// The function can accept arguments of any real numeric type -// (i.e. not complex), as well as []byte and string. It must return a -// value of one of those types, and optionally an error as a second -// value. Variadic functions are allowed, if the variadic argument is -// one of the allowed types. +// The Go function can have arguments of the following types: any +// numeric type except complex, bool, []byte, string and +// interface{}. interface{} arguments are given the direct translation +// of the SQLite data type: int64 for INTEGER, float64 for FLOAT, +// []byte for BLOB, string for TEXT. +// +// The function can additionally be variadic, as long as the type of +// the variadic argument is one of the above. // // If pure is true. SQLite will assume that the function's return // value depends only on its inputs, and make more aggressive diff --git a/sqlite3_test.go b/sqlite3_test.go index a563c08..62db05b 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -1071,6 +1071,20 @@ func TestFunctionRegistration(t *testing.T) { regex := func(re, s string) (bool, error) { return regexp.MatchString(re, s) } + generic := func(a interface{}) int64 { + switch a.(type) { + case int64: + return 1 + case float64: + return 2 + case []byte: + return 3 + case string: + return 4 + default: + panic("unreachable") + } + } variadic := func(a, b int64, c ...int64) int64 { ret := a + b for _, d := range c { @@ -1078,6 +1092,9 @@ func TestFunctionRegistration(t *testing.T) { } return ret } + variadicGeneric := func(a ...interface{}) int64 { + return int64(len(a)) + } sql.Register("sqlite3_FunctionRegistration", &SQLiteDriver{ ConnectHook: func(conn *SQLiteConn) error { @@ -1105,9 +1122,15 @@ func TestFunctionRegistration(t *testing.T) { if err := conn.RegisterFunc("regex", regex, true); err != nil { return err } + if err := conn.RegisterFunc("generic", generic, true); err != nil { + return err + } if err := conn.RegisterFunc("variadic", variadic, true); err != nil { return err } + if err := conn.RegisterFunc("variadicGeneric", variadicGeneric, true); err != nil { + return err + } return nil }, }) @@ -1131,9 +1154,14 @@ func TestFunctionRegistration(t *testing.T) { {"SELECT not(0)", true}, {`SELECT regex("^foo.*", "foobar")`, true}, {`SELECT regex("^foo.*", "barfoobar")`, false}, + {"SELECT generic(1)", int64(1)}, + {"SELECT generic(1.1)", int64(2)}, + {`SELECT generic(NULL)`, int64(3)}, + {`SELECT generic("foo")`, int64(4)}, {"SELECT variadic(1,2)", int64(3)}, {"SELECT variadic(1,2,3,4)", int64(10)}, {"SELECT variadic(1,1,1,1,1,1,1,1,1,1)", int64(10)}, + {`SELECT variadicGeneric(1,"foo",2.3, NULL)`, int64(4)}, } for _, op := range ops { -- cgit v1.2.3 From 26917df7a6010a157123c4bf60e3d57eff2948e4 Mon Sep 17 00:00:00 2001 From: David Anderson Date: Fri, 21 Aug 2015 20:31:41 -0700 Subject: Implement support for aggregation functions implemented in Go. --- _example/go_custom_funcs/go_custom_funcs | Bin 0 -> 6601208 bytes _example/go_custom_funcs/main.go | 133 +++++++++++++++++ callback.go | 47 ++++++ sqlite3.go | 243 ++++++++++++++++++++++++++----- sqlite3_test.go | 59 ++++++++ 5 files changed, 449 insertions(+), 33 deletions(-) create mode 100755 _example/go_custom_funcs/go_custom_funcs create mode 100644 _example/go_custom_funcs/main.go (limited to 'sqlite3_test.go') diff --git a/_example/go_custom_funcs/go_custom_funcs b/_example/go_custom_funcs/go_custom_funcs new file mode 100755 index 0000000..b6be764 Binary files /dev/null and b/_example/go_custom_funcs/go_custom_funcs differ diff --git a/_example/go_custom_funcs/main.go b/_example/go_custom_funcs/main.go new file mode 100644 index 0000000..85657e6 --- /dev/null +++ b/_example/go_custom_funcs/main.go @@ -0,0 +1,133 @@ +package main + +import ( + "database/sql" + "fmt" + "log" + "math" + "math/rand" + + sqlite "github.com/mattn/go-sqlite3" +) + +// Computes x^y +func pow(x, y int64) int64 { + return int64(math.Pow(float64(x), float64(y))) +} + +// Computes the bitwise exclusive-or of all its arguments +func xor(xs ...int64) int64 { + var ret int64 + for _, x := range xs { + ret ^= x + } + return ret +} + +// Returns a random number. It's actually deterministic here because +// we don't seed the RNG, but it's an example of a non-pure function +// from SQLite's POV. +func getrand() int64 { + return rand.Int63() +} + +// Computes the standard deviation of a GROUPed BY set of values +type stddev struct { + xs []int64 + // Running average calculation + sum int64 + n int64 +} + +func newStddev() *stddev { return &stddev{} } + +func (s *stddev) Step(x int64) { + s.xs = append(s.xs, x) + s.sum += x + s.n++ +} + +func (s *stddev) Done() float64 { + mean := float64(s.sum) / float64(s.n) + var sqDiff []float64 + for _, x := range s.xs { + sqDiff = append(sqDiff, math.Pow(float64(x)-mean, 2)) + } + var dev float64 + for _, x := range sqDiff { + dev += x + } + dev /= float64(len(sqDiff)) + return math.Sqrt(dev) +} + +func main() { + sql.Register("sqlite3_custom", &sqlite.SQLiteDriver{ + ConnectHook: func(conn *sqlite.SQLiteConn) error { + if err := conn.RegisterFunc("pow", pow, true); err != nil { + return err + } + if err := conn.RegisterFunc("xor", xor, true); err != nil { + return err + } + if err := conn.RegisterFunc("rand", getrand, false); err != nil { + return err + } + if err := conn.RegisterAggregator("stddev", newStddev, true); err != nil { + return err + } + return nil + }, + }) + + db, err := sql.Open("sqlite3_custom", ":memory:") + if err != nil { + log.Fatal("Failed to open database:", err) + } + defer db.Close() + + var i int64 + err = db.QueryRow("SELECT pow(2,3)").Scan(&i) + if err != nil { + log.Fatal("POW query error:", err) + } + fmt.Println("pow(2,3) =", i) // 8 + + err = db.QueryRow("SELECT xor(1,2,3,4,5,6)").Scan(&i) + if err != nil { + log.Fatal("XOR query error:", err) + } + fmt.Println("xor(1,2,3,4,5) =", i) // 7 + + err = db.QueryRow("SELECT rand()").Scan(&i) + if err != nil { + log.Fatal("RAND query error:", err) + } + fmt.Println("rand() =", i) // pseudorandom + + _, err = db.Exec("create table foo (department integer, profits integer)") + if err != nil { + log.Fatal("Failed to create table:", err) + } + _, err = db.Exec("insert into foo values (1, 10), (1, 20), (1, 45), (2, 42), (2, 115)") + if err != nil { + log.Fatal("Failed to insert records:", err) + } + + rows, err := db.Query("select department, stddev(profits) from foo group by department") + if err != nil { + log.Fatal("STDDEV query error:", err) + } + defer rows.Close() + for rows.Next() { + var dept int64 + var dev float64 + if err := rows.Scan(&dept, &dev); err != nil { + log.Fatal(err) + } + fmt.Printf("dept=%d stddev=%f\n", dept, dev) + } + if err := rows.Err(); err != nil { + log.Fatal(err) + } +} diff --git a/callback.go b/callback.go index b1704fe..61fc8d1 100644 --- a/callback.go +++ b/callback.go @@ -12,6 +12,7 @@ package sqlite3 /* #include +#include void _sqlite3_result_text(sqlite3_context* ctx, const char* s); void _sqlite3_result_blob(sqlite3_context* ctx, const void* b, int l); @@ -32,6 +33,19 @@ func callbackTrampoline(ctx *C.sqlite3_context, argc int, argv **C.sqlite3_value fi.Call(ctx, args) } +//export stepTrampoline +func stepTrampoline(ctx *C.sqlite3_context, argc int, argv **C.sqlite3_value) { + args := (*[1 << 30]*C.sqlite3_value)(unsafe.Pointer(argv))[:argc:argc] + ai := (*aggInfo)(unsafe.Pointer(C.sqlite3_user_data(ctx))) + ai.Step(ctx, args) +} + +//export doneTrampoline +func doneTrampoline(ctx *C.sqlite3_context) { + ai := (*aggInfo)(unsafe.Pointer(C.sqlite3_user_data(ctx))) + ai.Done(ctx) +} + // This is only here so that tests can refer to it. type callbackArgRaw C.sqlite3_value @@ -158,6 +172,33 @@ func callbackArg(typ reflect.Type) (callbackArgConverter, error) { } } +func callbackConvertArgs(argv []*C.sqlite3_value, converters []callbackArgConverter, variadic callbackArgConverter) ([]reflect.Value, error) { + var args []reflect.Value + + if len(argv) < len(converters) { + return nil, fmt.Errorf("function requires at least %d arguments", len(converters)) + } + + for i, arg := range argv[:len(converters)] { + v, err := converters[i](arg) + if err != nil { + return nil, err + } + args = append(args, v) + } + + if variadic != nil { + for _, arg := range argv[len(converters):] { + v, err := variadic(arg) + if err != nil { + return nil, err + } + args = append(args, v) + } + } + return args, nil +} + type callbackRetConverter func(*C.sqlite3_context, reflect.Value) error func callbackRetInteger(ctx *C.sqlite3_context, v reflect.Value) error { @@ -233,6 +274,12 @@ func callbackRet(typ reflect.Type) (callbackRetConverter, error) { } } +func callbackError(ctx *C.sqlite3_context, err error) { + cstr := C.CString(err.Error()) + defer C.free(unsafe.Pointer(cstr)) + C.sqlite3_result_error(ctx, cstr, -1) +} + // Test support code. Tests are not allowed to import "C", so we can't // declare any functions that use C.sqlite3_value. func callbackSyntheticForTests(v reflect.Value, err error) callbackArgConverter { diff --git a/sqlite3.go b/sqlite3.go index 73e67e3..8d2faca 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -75,6 +75,8 @@ void _sqlite3_result_blob(sqlite3_context* ctx, const void* b, int l) { } void callbackTrampoline(sqlite3_context*, int, sqlite3_value**); +void stepTrampoline(sqlite3_context*, int, sqlite3_value**); +void doneTrampoline(sqlite3_context*); */ import "C" import ( @@ -127,10 +129,11 @@ type SQLiteDriver struct { // Conn struct. type SQLiteConn struct { - db *C.sqlite3 - loc *time.Location - txlock string - funcs []*functionInfo + db *C.sqlite3 + loc *time.Location + txlock string + funcs []*functionInfo + aggregators []*aggInfo } // Tx struct. @@ -171,49 +174,96 @@ type functionInfo struct { retConverter callbackRetConverter } -func (fi *functionInfo) error(ctx *C.sqlite3_context, err error) { - cstr := C.CString(err.Error()) - defer C.free(unsafe.Pointer(cstr)) - C.sqlite3_result_error(ctx, cstr, -1) -} - func (fi *functionInfo) Call(ctx *C.sqlite3_context, argv []*C.sqlite3_value) { - var args []reflect.Value + args, err := callbackConvertArgs(argv, fi.argConverters, fi.variadicConverter) + if err != nil { + callbackError(ctx, err) + return + } + + ret := fi.f.Call(args) - if len(argv) < len(fi.argConverters) { - fi.error(ctx, fmt.Errorf("function requires at least %d arguments", len(fi.argConverters))) + if len(ret) == 2 && ret[1].Interface() != nil { + callbackError(ctx, ret[1].Interface().(error)) + return } - for i, arg := range argv[:len(fi.argConverters)] { - v, err := fi.argConverters[i](arg) - if err != nil { - fi.error(ctx, err) - return - } - args = append(args, v) + err = fi.retConverter(ctx, ret[0]) + if err != nil { + callbackError(ctx, err) + return } +} - if fi.variadicConverter != nil { - for _, arg := range argv[len(fi.argConverters):] { - v, err := fi.variadicConverter(arg) - if err != nil { - fi.error(ctx, err) - return - } - args = append(args, v) +type aggInfo struct { + constructor reflect.Value + + // Active aggregator objects for aggregations in flight. The + // aggregators are indexed by a counter stored in the aggregation + // user data space provided by sqlite. + active map[int64]reflect.Value + next int64 + + stepArgConverters []callbackArgConverter + stepVariadicConverter callbackArgConverter + + doneRetConverter callbackRetConverter +} + +func (ai *aggInfo) agg(ctx *C.sqlite3_context) (int64, reflect.Value, error) { + aggIdx := (*int64)(C.sqlite3_aggregate_context(ctx, C.int(8))) + if *aggIdx == 0 { + *aggIdx = ai.next + ret := ai.constructor.Call(nil) + if len(ret) == 2 && ret[1].Interface() != nil { + return 0, reflect.Value{}, ret[1].Interface().(error) } + if ret[0].IsNil() { + return 0, reflect.Value{}, errors.New("aggregator constructor returned nil state") + } + ai.next++ + ai.active[*aggIdx] = ret[0] } + return *aggIdx, ai.active[*aggIdx], nil +} - ret := fi.f.Call(args) +func (ai *aggInfo) Step(ctx *C.sqlite3_context, argv []*C.sqlite3_value) { + _, agg, err := ai.agg(ctx) + if err != nil { + callbackError(ctx, err) + return + } + + args, err := callbackConvertArgs(argv, ai.stepArgConverters, ai.stepVariadicConverter) + if err != nil { + callbackError(ctx, err) + return + } + + ret := agg.MethodByName("Step").Call(args) + if len(ret) == 1 && ret[0].Interface() != nil { + callbackError(ctx, ret[0].Interface().(error)) + return + } +} + +func (ai *aggInfo) Done(ctx *C.sqlite3_context) { + idx, agg, err := ai.agg(ctx) + if err != nil { + callbackError(ctx, err) + return + } + defer func() { delete(ai.active, idx) }() + ret := agg.MethodByName("Done").Call(nil) if len(ret) == 2 && ret[1].Interface() != nil { - fi.error(ctx, ret[1].Interface().(error)) + callbackError(ctx, ret[1].Interface().(error)) return } - err := fi.retConverter(ctx, ret[0]) + err = ai.doneRetConverter(ctx, ret[0]) if err != nil { - fi.error(ctx, err) + callbackError(ctx, err) return } } @@ -244,6 +294,8 @@ func (tx *SQLiteTx) Rollback() error { // If pure is true. SQLite will assume that the function's return // value depends only on its inputs, and make more aggressive // optimizations in its queries. +// +// See _example/go_custom_funcs for a detailed example. func (c *SQLiteConn) RegisterFunc(name string, impl interface{}, pure bool) error { var fi functionInfo fi.f = reflect.ValueOf(impl) @@ -298,7 +350,132 @@ func (c *SQLiteConn) RegisterFunc(name string, impl interface{}, pure bool) erro if pure { opts |= C.SQLITE_DETERMINISTIC } - rv := C.sqlite3_create_function_v2(c.db, cname, C.int(numArgs), C.int(opts), unsafe.Pointer(&fi), (*[0]byte)(unsafe.Pointer(C.callbackTrampoline)), nil, nil, nil) + rv := C.sqlite3_create_function(c.db, cname, C.int(numArgs), C.int(opts), unsafe.Pointer(&fi), (*[0]byte)(unsafe.Pointer(C.callbackTrampoline)), nil, nil) + if rv != C.SQLITE_OK { + return c.lastError() + } + return nil +} + +// RegisterAggregator makes a Go type available as a SQLite aggregation function. +// +// Because aggregation is incremental, it's implemented in Go with a +// type that has 2 methods: func Step(values) accumulates one row of +// data into the accumulator, and func Done() ret finalizes and +// returns the aggregate value. "values" and "ret" may be any type +// supported by RegisterFunc. +// +// RegisterAggregator takes as implementation a constructor function +// that constructs an instance of the aggregator type each time an +// aggregation begins. The constructor must return a pointer to a +// type, or an interface that implements Step() and Done(). +// +// The constructor function and the Step/Done methods may optionally +// return an error in addition to their other return values. +// +// See _example/go_custom_funcs for a detailed example. +func (c *SQLiteConn) RegisterAggregator(name string, impl interface{}, pure bool) error { + var ai aggInfo + ai.constructor = reflect.ValueOf(impl) + t := ai.constructor.Type() + if t.Kind() != reflect.Func { + return errors.New("non-function passed to RegisterAggregator") + } + if t.NumOut() != 1 && t.NumOut() != 2 { + return errors.New("SQLite aggregator constructors must return 1 or 2 values") + } + if t.NumOut() == 2 && !t.Out(1).Implements(reflect.TypeOf((*error)(nil)).Elem()) { + return errors.New("Second return value of SQLite function must be error") + } + if t.NumIn() != 0 { + return errors.New("SQLite aggregator constructors must not have arguments") + } + + agg := t.Out(0) + switch agg.Kind() { + case reflect.Ptr, reflect.Interface: + default: + return errors.New("SQlite aggregator constructor must return a pointer object") + } + stepFn, found := agg.MethodByName("Step") + if !found { + return errors.New("SQlite aggregator doesn't have a Step() function") + } + step := stepFn.Type + if step.NumOut() != 0 && step.NumOut() != 1 { + return errors.New("SQlite aggregator Step() function must return 0 or 1 values") + } + if step.NumOut() == 1 && !step.Out(0).Implements(reflect.TypeOf((*error)(nil)).Elem()) { + return errors.New("type of SQlite aggregator Step() return value must be error") + } + + stepNArgs := step.NumIn() + start := 0 + if agg.Kind() == reflect.Ptr { + // Skip over the method receiver + stepNArgs-- + start++ + } + if step.IsVariadic() { + stepNArgs-- + } + for i := start; i < start+stepNArgs; i++ { + conv, err := callbackArg(step.In(i)) + if err != nil { + return err + } + ai.stepArgConverters = append(ai.stepArgConverters, conv) + } + if step.IsVariadic() { + conv, err := callbackArg(t.In(start + stepNArgs).Elem()) + if err != nil { + return err + } + ai.stepVariadicConverter = conv + // Pass -1 to sqlite so that it allows any number of + // arguments. The call helper verifies that the minimum number + // of arguments is present for variadic functions. + stepNArgs = -1 + } + + doneFn, found := agg.MethodByName("Done") + if !found { + return errors.New("SQlite aggregator doesn't have a Done() function") + } + done := doneFn.Type + doneNArgs := done.NumIn() + if agg.Kind() == reflect.Ptr { + // Skip over the method receiver + doneNArgs-- + } + if doneNArgs != 0 { + return errors.New("SQlite aggregator Done() function must have no arguments") + } + if done.NumOut() != 1 && done.NumOut() != 2 { + return errors.New("SQLite aggregator Done() function must return 1 or 2 values") + } + if done.NumOut() == 2 && !done.Out(1).Implements(reflect.TypeOf((*error)(nil)).Elem()) { + return errors.New("second return value of SQLite aggregator Done() function must be error") + } + + conv, err := callbackRet(done.Out(0)) + if err != nil { + return err + } + ai.doneRetConverter = conv + ai.active = make(map[int64]reflect.Value) + ai.next = 1 + + // ai must outlast the database connection, or we'll have dangling pointers. + c.aggregators = append(c.aggregators, &ai) + + cname := C.CString(name) + defer C.free(unsafe.Pointer(cname)) + opts := C.SQLITE_UTF8 + if pure { + opts |= C.SQLITE_DETERMINISTIC + } + rv := C.sqlite3_create_function(c.db, cname, C.int(stepNArgs), C.int(opts), unsafe.Pointer(&ai), nil, (*[0]byte)(unsafe.Pointer(C.stepTrampoline)), (*[0]byte)(unsafe.Pointer(C.doneTrampoline))) if rv != C.SQLITE_OK { return c.lastError() } diff --git a/sqlite3_test.go b/sqlite3_test.go index 62db05b..74d3de1 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -1175,6 +1175,65 @@ func TestFunctionRegistration(t *testing.T) { } } +type sumAggregator int64 + +func (s *sumAggregator) Step(x int64) { + *s += sumAggregator(x) +} + +func (s *sumAggregator) Done() int64 { + return int64(*s) +} + +func TestAggregatorRegistration(t *testing.T) { + customSum := func() *sumAggregator { + var ret sumAggregator + return &ret + } + + sql.Register("sqlite3_AggregatorRegistration", &SQLiteDriver{ + ConnectHook: func(conn *SQLiteConn) error { + if err := conn.RegisterAggregator("customSum", customSum, true); err != nil { + return err + } + return nil + }, + }) + db, err := sql.Open("sqlite3_AggregatorRegistration", ":memory:") + if err != nil { + t.Fatal("Failed to open database:", err) + } + defer db.Close() + + _, err = db.Exec("create table foo (department integer, profits integer)") + if err != nil { + t.Fatal("Failed to create table:", err) + } + + _, err = db.Exec("insert into foo values (1, 10), (1, 20), (2, 42)") + if err != nil { + t.Fatal("Failed to insert records:", err) + } + + tests := []struct { + dept, sum int64 + }{ + {1, 30}, + {2, 42}, + } + + for _, test := range tests { + var ret int64 + err = db.QueryRow("select customSum(profits) from foo where department = $1 group by department", test.dept).Scan(&ret) + if err != nil { + t.Fatal("Query failed:", err) + } + if ret != test.sum { + t.Fatalf("Custom sum returned wrong value, got %d, want %d", ret, test.sum) + } + } +} + var customFunctionOnce sync.Once func BenchmarkCustomFunctions(b *testing.B) { -- cgit v1.2.3