aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--callback.go19
-rw-r--r--callback_test.go12
-rw-r--r--sqlite3_test.go57
3 files changed, 88 insertions, 0 deletions
diff --git a/callback.go b/callback.go
index b020fe3..d305691 100644
--- a/callback.go
+++ b/callback.go
@@ -353,6 +353,20 @@ func callbackRetNil(ctx *C.sqlite3_context, v reflect.Value) error {
return nil
}
+func callbackRetGeneric(ctx *C.sqlite3_context, v reflect.Value) error {
+ if v.IsNil() {
+ C.sqlite3_result_null(ctx)
+ return nil
+ }
+
+ cb, err := callbackRet(v.Elem().Type())
+ if err != nil {
+ return err
+ }
+
+ return cb(ctx, v.Elem())
+}
+
func callbackRet(typ reflect.Type) (callbackRetConverter, error) {
switch typ.Kind() {
case reflect.Interface:
@@ -360,6 +374,11 @@ func callbackRet(typ reflect.Type) (callbackRetConverter, error) {
if typ.Implements(errorInterface) {
return callbackRetNil, nil
}
+
+ if typ.NumMethod() == 0 {
+ return callbackRetGeneric, nil
+ }
+
fallthrough
case reflect.Slice:
if typ.Elem().Kind() != reflect.Uint8 {
diff --git a/callback_test.go b/callback_test.go
index 714ed60..b09122a 100644
--- a/callback_test.go
+++ b/callback_test.go
@@ -102,3 +102,15 @@ func TestCallbackConverters(t *testing.T) {
}
}
}
+
+func TestCallbackReturnAny(t *testing.T) {
+ udf := func() interface{} {
+ return 1
+ }
+
+ typ := reflect.TypeOf(udf)
+ _, err := callbackRet(typ.Out(0))
+ if err != nil {
+ t.Errorf("Expected valid callback for any return type, got: %s", err)
+ }
+}
diff --git a/sqlite3_test.go b/sqlite3_test.go
index c86aba4..9ee87e7 100644
--- a/sqlite3_test.go
+++ b/sqlite3_test.go
@@ -1449,6 +1449,63 @@ func TestAggregatorRegistration(t *testing.T) {
}
}
+type mode struct {
+ counts map[interface{}]int
+ top interface{}
+ topCount int
+}
+
+func newMode() *mode {
+ return &mode{
+ counts: map[interface{}]int{},
+ }
+}
+
+func (m *mode) Step(x interface{}) {
+ m.counts[x]++
+ c := m.counts[x]
+ if c > m.topCount {
+ m.top = x
+ m.topCount = c
+ }
+}
+
+func (m *mode) Done() interface{} {
+ return m.top
+}
+
+func TestAggregatorRegistration_GenericReturn(t *testing.T) {
+ sql.Register("sqlite3_AggregatorRegistration_GenericReturn", &SQLiteDriver{
+ ConnectHook: func(conn *SQLiteConn) error {
+ return conn.RegisterAggregator("mode", newMode, true)
+ },
+ })
+ db, err := sql.Open("sqlite3_AggregatorRegistration_GenericReturn", ":memory:")
+ if err != nil {
+ t.Fatal("Failed to open database:", err)
+ }
+ defer db.Close()
+
+ _, err = db.Exec("create table foo (department integer, profits integer)")
+ if err != nil {
+ t.Fatal("Failed to create table:", err)
+ }
+ _, err = db.Exec("insert into foo values (1, 10), (1, 20), (1, 45), (2, 42), (2, 115), (2, 20)")
+ if err != nil {
+ t.Fatal("Failed to insert records:", err)
+ }
+
+ var mode int
+ err = db.QueryRow("select mode(profits) from foo").Scan(&mode)
+ if err != nil {
+ t.Fatal("MODE query error:", err)
+ }
+
+ if mode != 20 {
+ t.Fatal("Got incorrect mode. Wanted 20, got: ", mode)
+ }
+}
+
func rot13(r rune) rune {
switch {
case r >= 'A' && r <= 'Z':