aboutsummaryrefslogtreecommitdiff
path: root/sqlite3_test
diff options
context:
space:
mode:
authorDavid Anderson <dave@natulte.net>2015-08-20 23:08:48 -0700
committerDavid Anderson <dave@natulte.net>2015-08-21 13:39:50 -0700
commitcf8fa0af80e0d227c79ef2b4635e8d0d77432275 (patch)
treeaa5d09e0d949847240ef50f3da2ad1b99f5cefe2 /sqlite3_test
parentMerge pull request #228 from whiter4bbit/added_icu_support (diff)
downloadgolite-cf8fa0af80e0d227c79ef2b4635e8d0d77432275.tar.gz
golite-cf8fa0af80e0d227c79ef2b4635e8d0d77432275.tar.xz
Implement support for passing Go functions as custom functions to SQLite.
Fixes #226.
Diffstat (limited to '')
-rw-r--r--sqlite3_test.go108
-rw-r--r--sqlite3_test/sqltest.go6
2 files changed, 111 insertions, 3 deletions
diff --git a/sqlite3_test.go b/sqlite3_test.go
index 423f30e..a58e373 100644
--- a/sqlite3_test.go
+++ b/sqlite3_test.go
@@ -15,7 +15,9 @@ import (
"net/url"
"os"
"path/filepath"
+ "regexp"
"strings"
+ "sync"
"testing"
"time"
@@ -1056,3 +1058,109 @@ func TestDateTimeNow(t *testing.T) {
t.Fatal("Failed to scan datetime:", err)
}
}
+
+func TestFunctionRegistration(t *testing.T) {
+ custom_add := func(a, b int64) (int64, error) {
+ return a + b, nil
+ }
+ custom_regex := func(s, re string) bool {
+ matched, err := regexp.MatchString(re, s)
+ if err != nil {
+ // We should really return the error here, but this
+ // function is also testing single return value functions.
+ panic("Bad regexp")
+ }
+ return matched
+ }
+
+ sql.Register("sqlite3_FunctionRegistration", &SQLiteDriver{
+ ConnectHook: func(conn *SQLiteConn) error {
+ if err := conn.RegisterFunc("custom_add", custom_add, true); err != nil {
+ return err
+ }
+ if err := conn.RegisterFunc("regexp", custom_regex, true); err != nil {
+ return err
+ }
+ return nil
+ },
+ })
+ db, err := sql.Open("sqlite3_FunctionRegistration", ":memory:")
+ if err != nil {
+ t.Fatal("Failed to open database:", err)
+ }
+ defer db.Close()
+
+ additions := []struct {
+ a, b, c int64
+ }{
+ {1, 1, 2},
+ {1, 3, 4},
+ {1, -1, 0},
+ }
+
+ for _, add := range additions {
+ var i int64
+ err = db.QueryRow("SELECT custom_add($1, $2)", add.a, add.b).Scan(&i)
+ if err != nil {
+ t.Fatal("Failed to call custom_add:", err)
+ }
+ if i != add.c {
+ t.Fatalf("custom_add returned the wrong value, got %d, want %d", i, add.c)
+ }
+ }
+
+ regexes := []struct {
+ re, in string
+ out bool
+ }{
+ {".*", "foo", true},
+ {"^foo.*", "foobar", true},
+ {"^foo.*", "barfoo", false},
+ }
+
+ for _, re := range regexes {
+ var b bool
+ err = db.QueryRow("SELECT regexp($1, $2)", re.in, re.re).Scan(&b)
+ if err != nil {
+ t.Fatal("Failed to call regexp:", err)
+ }
+ if b != re.out {
+ t.Fatalf("regexp returned the wrong value, got %v, want %v", b, re.out)
+ }
+ }
+}
+
+var customFunctionOnce sync.Once
+
+func BenchmarkCustomFunctions(b *testing.B) {
+ customFunctionOnce.Do(func() {
+ custom_add := func(a, b int64) (int64, error) {
+ return a + b, nil
+ }
+
+ sql.Register("sqlite3_BenchmarkCustomFunctions", &SQLiteDriver{
+ ConnectHook: func(conn *SQLiteConn) error {
+ // Impure function to force sqlite to reexecute it each time.
+ if err := conn.RegisterFunc("custom_add", custom_add, false); err != nil {
+ return err
+ }
+ return nil
+ },
+ })
+ })
+
+ db, err := sql.Open("sqlite3_BenchmarkCustomFunctions", ":memory:")
+ if err != nil {
+ b.Fatal("Failed to open database:", err)
+ }
+ defer db.Close()
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ var i int64
+ err = db.QueryRow("SELECT custom_add(1,2)").Scan(&i)
+ if err != nil {
+ b.Fatal("Failed to run custom add:", err)
+ }
+ }
+}
diff --git a/sqlite3_test/sqltest.go b/sqlite3_test/sqltest.go
index fc82782..782e15f 100644
--- a/sqlite3_test/sqltest.go
+++ b/sqlite3_test/sqltest.go
@@ -318,7 +318,7 @@ func BenchmarkQuery(b *testing.B) {
var i int
var f float64
var s string
-// var t time.Time
+ // var t time.Time
if err := db.QueryRow("select null, 1, 1.1, 'foo'").Scan(&n, &i, &f, &s); err != nil {
panic(err)
}
@@ -331,7 +331,7 @@ func BenchmarkParams(b *testing.B) {
var i int
var f float64
var s string
-// var t time.Time
+ // var t time.Time
if err := db.QueryRow("select ?, ?, ?, ?", nil, 1, 1.1, "foo").Scan(&n, &i, &f, &s); err != nil {
panic(err)
}
@@ -350,7 +350,7 @@ func BenchmarkStmt(b *testing.B) {
var i int
var f float64
var s string
-// var t time.Time
+ // var t time.Time
if err := st.QueryRow(nil, 1, 1.1, "foo").Scan(&n, &i, &f, &s); err != nil {
panic(err)
}