diff options
Diffstat (limited to 'sqlite3_test.go')
-rw-r--r-- | sqlite3_test.go | 221 |
1 files changed, 218 insertions, 3 deletions
diff --git a/sqlite3_test.go b/sqlite3_test.go index 43f7b4f..90bee51 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -13,7 +13,11 @@ import ( "io/ioutil" "net/url" "os" + "path/filepath" + "reflect" + "regexp" "strings" + "sync" "testing" "time" @@ -87,13 +91,13 @@ func TestReadonly(t *testing.T) { tempFilename := TempFilename(t) defer os.Remove(tempFilename) - db1, err := sql.Open("sqlite3", "file:" + tempFilename) + db1, err := sql.Open("sqlite3", "file:"+tempFilename) if err != nil { t.Fatal(err) } db1.Exec("CREATE TABLE test (x int, y float)") - db2, err := sql.Open("sqlite3", "file:" + tempFilename + "?mode=ro") + db2, err := sql.Open("sqlite3", "file:"+tempFilename+"?mode=ro") if err != nil { t.Fatal(err) } @@ -792,10 +796,12 @@ func TestTimezoneConversion(t *testing.T) { } func TestSuite(t *testing.T) { - db, err := sql.Open("sqlite3", ":memory:") + tempFilename := TempFilename() + db, err := sql.Open("sqlite3", tempFilename+"?_busy_timeout=99999") if err != nil { t.Fatal(err) } + defer os.Remove(tempFilename) defer db.Close() sqlite3_test.RunTests(t, db, sqlite3_test.SQLITE) @@ -1079,3 +1085,212 @@ func TestDateTimeNow(t *testing.T) { t.Fatal("Failed to scan datetime:", err) } } + +func TestFunctionRegistration(t *testing.T) { + addi_8_16_32 := func(a int8, b int16) int32 { return int32(a) + int32(b) } + addi_64 := func(a, b int64) int64 { return a + b } + addu_8_16_32 := func(a uint8, b uint16) uint32 { return uint32(a) + uint32(b) } + addu_64 := func(a, b uint64) uint64 { return a + b } + addiu := func(a int, b uint) int64 { return int64(a) + int64(b) } + addf_32_64 := func(a float32, b float64) float64 { return float64(a) + b } + not := func(a bool) bool { return !a } + regex := func(re, s string) (bool, error) { + return regexp.MatchString(re, s) + } + generic := func(a interface{}) int64 { + switch a.(type) { + case int64: + return 1 + case float64: + return 2 + case []byte: + return 3 + case string: + return 4 + default: + panic("unreachable") + } + } + variadic := func(a, b int64, c ...int64) int64 { + ret := a + b + for _, d := range c { + ret += d + } + return ret + } + variadicGeneric := func(a ...interface{}) int64 { + return int64(len(a)) + } + + sql.Register("sqlite3_FunctionRegistration", &SQLiteDriver{ + ConnectHook: func(conn *SQLiteConn) error { + if err := conn.RegisterFunc("addi_8_16_32", addi_8_16_32, true); err != nil { + return err + } + if err := conn.RegisterFunc("addi_64", addi_64, true); err != nil { + return err + } + if err := conn.RegisterFunc("addu_8_16_32", addu_8_16_32, true); err != nil { + return err + } + if err := conn.RegisterFunc("addu_64", addu_64, true); err != nil { + return err + } + if err := conn.RegisterFunc("addiu", addiu, true); err != nil { + return err + } + if err := conn.RegisterFunc("addf_32_64", addf_32_64, true); err != nil { + return err + } + if err := conn.RegisterFunc("not", not, true); err != nil { + return err + } + if err := conn.RegisterFunc("regex", regex, true); err != nil { + return err + } + if err := conn.RegisterFunc("generic", generic, true); err != nil { + return err + } + if err := conn.RegisterFunc("variadic", variadic, true); err != nil { + return err + } + if err := conn.RegisterFunc("variadicGeneric", variadicGeneric, true); err != nil { + return err + } + return nil + }, + }) + db, err := sql.Open("sqlite3_FunctionRegistration", ":memory:") + if err != nil { + t.Fatal("Failed to open database:", err) + } + defer db.Close() + + ops := []struct { + query string + expected interface{} + }{ + {"SELECT addi_8_16_32(1,2)", int32(3)}, + {"SELECT addi_64(1,2)", int64(3)}, + {"SELECT addu_8_16_32(1,2)", uint32(3)}, + {"SELECT addu_64(1,2)", uint64(3)}, + {"SELECT addiu(1,2)", int64(3)}, + {"SELECT addf_32_64(1.5,1.5)", float64(3)}, + {"SELECT not(1)", false}, + {"SELECT not(0)", true}, + {`SELECT regex("^foo.*", "foobar")`, true}, + {`SELECT regex("^foo.*", "barfoobar")`, false}, + {"SELECT generic(1)", int64(1)}, + {"SELECT generic(1.1)", int64(2)}, + {`SELECT generic(NULL)`, int64(3)}, + {`SELECT generic("foo")`, int64(4)}, + {"SELECT variadic(1,2)", int64(3)}, + {"SELECT variadic(1,2,3,4)", int64(10)}, + {"SELECT variadic(1,1,1,1,1,1,1,1,1,1)", int64(10)}, + {`SELECT variadicGeneric(1,"foo",2.3, NULL)`, int64(4)}, + } + + for _, op := range ops { + ret := reflect.New(reflect.TypeOf(op.expected)) + err = db.QueryRow(op.query).Scan(ret.Interface()) + if err != nil { + t.Errorf("Query %q failed: %s", op.query, err) + } else if !reflect.DeepEqual(ret.Elem().Interface(), op.expected) { + t.Errorf("Query %q returned wrong value: got %v (%T), want %v (%T)", op.query, ret.Elem().Interface(), ret.Elem().Interface(), op.expected, op.expected) + } + } +} + +type sumAggregator int64 + +func (s *sumAggregator) Step(x int64) { + *s += sumAggregator(x) +} + +func (s *sumAggregator) Done() int64 { + return int64(*s) +} + +func TestAggregatorRegistration(t *testing.T) { + customSum := func() *sumAggregator { + var ret sumAggregator + return &ret + } + + sql.Register("sqlite3_AggregatorRegistration", &SQLiteDriver{ + ConnectHook: func(conn *SQLiteConn) error { + if err := conn.RegisterAggregator("customSum", customSum, true); err != nil { + return err + } + return nil + }, + }) + db, err := sql.Open("sqlite3_AggregatorRegistration", ":memory:") + if err != nil { + t.Fatal("Failed to open database:", err) + } + defer db.Close() + + _, err = db.Exec("create table foo (department integer, profits integer)") + if err != nil { + t.Fatal("Failed to create table:", err) + } + + _, err = db.Exec("insert into foo values (1, 10), (1, 20), (2, 42)") + if err != nil { + t.Fatal("Failed to insert records:", err) + } + + tests := []struct { + dept, sum int64 + }{ + {1, 30}, + {2, 42}, + } + + for _, test := range tests { + var ret int64 + err = db.QueryRow("select customSum(profits) from foo where department = $1 group by department", test.dept).Scan(&ret) + if err != nil { + t.Fatal("Query failed:", err) + } + if ret != test.sum { + t.Fatalf("Custom sum returned wrong value, got %d, want %d", ret, test.sum) + } + } +} + +var customFunctionOnce sync.Once + +func BenchmarkCustomFunctions(b *testing.B) { + customFunctionOnce.Do(func() { + custom_add := func(a, b int64) int64 { + return a + b + } + + sql.Register("sqlite3_BenchmarkCustomFunctions", &SQLiteDriver{ + ConnectHook: func(conn *SQLiteConn) error { + // Impure function to force sqlite to reexecute it each time. + if err := conn.RegisterFunc("custom_add", custom_add, false); err != nil { + return err + } + return nil + }, + }) + }) + + db, err := sql.Open("sqlite3_BenchmarkCustomFunctions", ":memory:") + if err != nil { + b.Fatal("Failed to open database:", err) + } + defer db.Close() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + var i int64 + err = db.QueryRow("SELECT custom_add(1,2)").Scan(&i) + if err != nil { + b.Fatal("Failed to run custom add:", err) + } + } +} |