aboutsummaryrefslogtreecommitdiff
path: root/sqlite3_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'sqlite3_test.go')
-rw-r--r--sqlite3_test.go221
1 files changed, 218 insertions, 3 deletions
diff --git a/sqlite3_test.go b/sqlite3_test.go
index 43f7b4f..90bee51 100644
--- a/sqlite3_test.go
+++ b/sqlite3_test.go
@@ -13,7 +13,11 @@ import (
"io/ioutil"
"net/url"
"os"
+ "path/filepath"
+ "reflect"
+ "regexp"
"strings"
+ "sync"
"testing"
"time"
@@ -87,13 +91,13 @@ func TestReadonly(t *testing.T) {
tempFilename := TempFilename(t)
defer os.Remove(tempFilename)
- db1, err := sql.Open("sqlite3", "file:" + tempFilename)
+ db1, err := sql.Open("sqlite3", "file:"+tempFilename)
if err != nil {
t.Fatal(err)
}
db1.Exec("CREATE TABLE test (x int, y float)")
- db2, err := sql.Open("sqlite3", "file:" + tempFilename + "?mode=ro")
+ db2, err := sql.Open("sqlite3", "file:"+tempFilename+"?mode=ro")
if err != nil {
t.Fatal(err)
}
@@ -792,10 +796,12 @@ func TestTimezoneConversion(t *testing.T) {
}
func TestSuite(t *testing.T) {
- db, err := sql.Open("sqlite3", ":memory:")
+ tempFilename := TempFilename()
+ db, err := sql.Open("sqlite3", tempFilename+"?_busy_timeout=99999")
if err != nil {
t.Fatal(err)
}
+ defer os.Remove(tempFilename)
defer db.Close()
sqlite3_test.RunTests(t, db, sqlite3_test.SQLITE)
@@ -1079,3 +1085,212 @@ func TestDateTimeNow(t *testing.T) {
t.Fatal("Failed to scan datetime:", err)
}
}
+
+func TestFunctionRegistration(t *testing.T) {
+ addi_8_16_32 := func(a int8, b int16) int32 { return int32(a) + int32(b) }
+ addi_64 := func(a, b int64) int64 { return a + b }
+ addu_8_16_32 := func(a uint8, b uint16) uint32 { return uint32(a) + uint32(b) }
+ addu_64 := func(a, b uint64) uint64 { return a + b }
+ addiu := func(a int, b uint) int64 { return int64(a) + int64(b) }
+ addf_32_64 := func(a float32, b float64) float64 { return float64(a) + b }
+ not := func(a bool) bool { return !a }
+ regex := func(re, s string) (bool, error) {
+ return regexp.MatchString(re, s)
+ }
+ generic := func(a interface{}) int64 {
+ switch a.(type) {
+ case int64:
+ return 1
+ case float64:
+ return 2
+ case []byte:
+ return 3
+ case string:
+ return 4
+ default:
+ panic("unreachable")
+ }
+ }
+ variadic := func(a, b int64, c ...int64) int64 {
+ ret := a + b
+ for _, d := range c {
+ ret += d
+ }
+ return ret
+ }
+ variadicGeneric := func(a ...interface{}) int64 {
+ return int64(len(a))
+ }
+
+ sql.Register("sqlite3_FunctionRegistration", &SQLiteDriver{
+ ConnectHook: func(conn *SQLiteConn) error {
+ if err := conn.RegisterFunc("addi_8_16_32", addi_8_16_32, true); err != nil {
+ return err
+ }
+ if err := conn.RegisterFunc("addi_64", addi_64, true); err != nil {
+ return err
+ }
+ if err := conn.RegisterFunc("addu_8_16_32", addu_8_16_32, true); err != nil {
+ return err
+ }
+ if err := conn.RegisterFunc("addu_64", addu_64, true); err != nil {
+ return err
+ }
+ if err := conn.RegisterFunc("addiu", addiu, true); err != nil {
+ return err
+ }
+ if err := conn.RegisterFunc("addf_32_64", addf_32_64, true); err != nil {
+ return err
+ }
+ if err := conn.RegisterFunc("not", not, true); err != nil {
+ return err
+ }
+ if err := conn.RegisterFunc("regex", regex, true); err != nil {
+ return err
+ }
+ if err := conn.RegisterFunc("generic", generic, true); err != nil {
+ return err
+ }
+ if err := conn.RegisterFunc("variadic", variadic, true); err != nil {
+ return err
+ }
+ if err := conn.RegisterFunc("variadicGeneric", variadicGeneric, true); err != nil {
+ return err
+ }
+ return nil
+ },
+ })
+ db, err := sql.Open("sqlite3_FunctionRegistration", ":memory:")
+ if err != nil {
+ t.Fatal("Failed to open database:", err)
+ }
+ defer db.Close()
+
+ ops := []struct {
+ query string
+ expected interface{}
+ }{
+ {"SELECT addi_8_16_32(1,2)", int32(3)},
+ {"SELECT addi_64(1,2)", int64(3)},
+ {"SELECT addu_8_16_32(1,2)", uint32(3)},
+ {"SELECT addu_64(1,2)", uint64(3)},
+ {"SELECT addiu(1,2)", int64(3)},
+ {"SELECT addf_32_64(1.5,1.5)", float64(3)},
+ {"SELECT not(1)", false},
+ {"SELECT not(0)", true},
+ {`SELECT regex("^foo.*", "foobar")`, true},
+ {`SELECT regex("^foo.*", "barfoobar")`, false},
+ {"SELECT generic(1)", int64(1)},
+ {"SELECT generic(1.1)", int64(2)},
+ {`SELECT generic(NULL)`, int64(3)},
+ {`SELECT generic("foo")`, int64(4)},
+ {"SELECT variadic(1,2)", int64(3)},
+ {"SELECT variadic(1,2,3,4)", int64(10)},
+ {"SELECT variadic(1,1,1,1,1,1,1,1,1,1)", int64(10)},
+ {`SELECT variadicGeneric(1,"foo",2.3, NULL)`, int64(4)},
+ }
+
+ for _, op := range ops {
+ ret := reflect.New(reflect.TypeOf(op.expected))
+ err = db.QueryRow(op.query).Scan(ret.Interface())
+ if err != nil {
+ t.Errorf("Query %q failed: %s", op.query, err)
+ } else if !reflect.DeepEqual(ret.Elem().Interface(), op.expected) {
+ t.Errorf("Query %q returned wrong value: got %v (%T), want %v (%T)", op.query, ret.Elem().Interface(), ret.Elem().Interface(), op.expected, op.expected)
+ }
+ }
+}
+
+type sumAggregator int64
+
+func (s *sumAggregator) Step(x int64) {
+ *s += sumAggregator(x)
+}
+
+func (s *sumAggregator) Done() int64 {
+ return int64(*s)
+}
+
+func TestAggregatorRegistration(t *testing.T) {
+ customSum := func() *sumAggregator {
+ var ret sumAggregator
+ return &ret
+ }
+
+ sql.Register("sqlite3_AggregatorRegistration", &SQLiteDriver{
+ ConnectHook: func(conn *SQLiteConn) error {
+ if err := conn.RegisterAggregator("customSum", customSum, true); err != nil {
+ return err
+ }
+ return nil
+ },
+ })
+ db, err := sql.Open("sqlite3_AggregatorRegistration", ":memory:")
+ if err != nil {
+ t.Fatal("Failed to open database:", err)
+ }
+ defer db.Close()
+
+ _, err = db.Exec("create table foo (department integer, profits integer)")
+ if err != nil {
+ t.Fatal("Failed to create table:", err)
+ }
+
+ _, err = db.Exec("insert into foo values (1, 10), (1, 20), (2, 42)")
+ if err != nil {
+ t.Fatal("Failed to insert records:", err)
+ }
+
+ tests := []struct {
+ dept, sum int64
+ }{
+ {1, 30},
+ {2, 42},
+ }
+
+ for _, test := range tests {
+ var ret int64
+ err = db.QueryRow("select customSum(profits) from foo where department = $1 group by department", test.dept).Scan(&ret)
+ if err != nil {
+ t.Fatal("Query failed:", err)
+ }
+ if ret != test.sum {
+ t.Fatalf("Custom sum returned wrong value, got %d, want %d", ret, test.sum)
+ }
+ }
+}
+
+var customFunctionOnce sync.Once
+
+func BenchmarkCustomFunctions(b *testing.B) {
+ customFunctionOnce.Do(func() {
+ custom_add := func(a, b int64) int64 {
+ return a + b
+ }
+
+ sql.Register("sqlite3_BenchmarkCustomFunctions", &SQLiteDriver{
+ ConnectHook: func(conn *SQLiteConn) error {
+ // Impure function to force sqlite to reexecute it each time.
+ if err := conn.RegisterFunc("custom_add", custom_add, false); err != nil {
+ return err
+ }
+ return nil
+ },
+ })
+ })
+
+ db, err := sql.Open("sqlite3_BenchmarkCustomFunctions", ":memory:")
+ if err != nil {
+ b.Fatal("Failed to open database:", err)
+ }
+ defer db.Close()
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ var i int64
+ err = db.QueryRow("SELECT custom_add(1,2)").Scan(&i)
+ if err != nil {
+ b.Fatal("Failed to run custom add:", err)
+ }
+ }
+}