From 5e5d088a3662a12104630827ac6926f7626d23cf Mon Sep 17 00:00:00 2001 From: mattn Date: Mon, 18 Aug 2014 16:56:31 +0900 Subject: Add license header --- sqlite3_test.go | 4 ++++ 1 file changed, 4 insertions(+) (limited to 'sqlite3_test.go') diff --git a/sqlite3_test.go b/sqlite3_test.go index 4f20026..0f10780 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -1,3 +1,7 @@ +// Copyright (C) 2014 Yasuhiro Matsumoto . +// +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. package sqlite3 import ( -- cgit v1.2.3 From 6535341da906b68a836904b742145b409e961a56 Mon Sep 17 00:00:00 2001 From: mattn Date: Mon, 18 Aug 2014 17:00:59 +0900 Subject: Add one blank line for godoc --- backup.go | 1 + error.go | 1 + error_test.go | 1 + sqlite3.go | 1 + sqlite3_test.go | 1 + sqlite3_windows.go | 1 + 6 files changed, 6 insertions(+) (limited to 'sqlite3_test.go') diff --git a/backup.go b/backup.go index 9e79c60..4b2a72d 100644 --- a/backup.go +++ b/backup.go @@ -2,6 +2,7 @@ // // Use of this source code is governed by an MIT-style // license that can be found in the LICENSE file. + package sqlite3 /* diff --git a/error.go b/error.go index 59b5051..b910108 100644 --- a/error.go +++ b/error.go @@ -2,6 +2,7 @@ // // Use of this source code is governed by an MIT-style // license that can be found in the LICENSE file. + package sqlite3 import "C" diff --git a/error_test.go b/error_test.go index 6a8660d..a006188 100644 --- a/error_test.go +++ b/error_test.go @@ -2,6 +2,7 @@ // // Use of this source code is governed by an MIT-style // license that can be found in the LICENSE file. + package sqlite3 import ( diff --git a/sqlite3.go b/sqlite3.go index 1683a48..aac71ad 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -2,6 +2,7 @@ // // Use of this source code is governed by an MIT-style // license that can be found in the LICENSE file. + package sqlite3 /* diff --git a/sqlite3_test.go b/sqlite3_test.go index 0f10780..581d289 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -2,6 +2,7 @@ // // Use of this source code is governed by an MIT-style // license that can be found in the LICENSE file. + package sqlite3 import ( diff --git a/sqlite3_windows.go b/sqlite3_windows.go index c5565e9..269c05a 100644 --- a/sqlite3_windows.go +++ b/sqlite3_windows.go @@ -2,6 +2,7 @@ // // Use of this source code is governed by an MIT-style // license that can be found in the LICENSE file. + package sqlite3 /* -- cgit v1.2.3 From abf79dbdd53a86fc1f0d81ffa8c9970c31159d81 Mon Sep 17 00:00:00 2001 From: Paweł Błaszczyk Date: Wed, 20 Aug 2014 16:13:15 +0200 Subject: Fix for sqlite3_test import. --- sqlite3_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'sqlite3_test.go') diff --git a/sqlite3_test.go b/sqlite3_test.go index 581d289..9cc5a0e 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -14,7 +14,7 @@ import ( "testing" "time" - "./sqlite3_test" + "github.com/mattn/go-sqlite3/sqlite3_test" ) func TempFilename() string { -- cgit v1.2.3 From 0b05acc293a940d3bfedf2df5fba12c49ac9ec35 Mon Sep 17 00:00:00 2001 From: Ian Bishop Date: Fri, 2 Jan 2015 16:31:46 +1000 Subject: Handle 13 digit datetime values --- sqlite3.go | 14 +++++++++++++- sqlite3_test.go | 1 + 2 files changed, 14 insertions(+), 1 deletion(-) (limited to 'sqlite3_test.go') diff --git a/sqlite3.go b/sqlite3.go index bc532cd..04f0edb 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -64,6 +64,7 @@ import ( "fmt" "io" "runtime" + "strconv" "strings" "time" "unsafe" @@ -504,7 +505,18 @@ func (rc *SQLiteRows) Next(dest []driver.Value) error { val := int64(C.sqlite3_column_int64(rc.s.s, C.int(i))) switch rc.decltype[i] { case "timestamp", "datetime", "date": - dest[i] = time.Unix(val, 0).Local() + unixTimestamp := strconv.FormatInt(val, 10) + if len(unixTimestamp) == 13 { + duration, err := time.ParseDuration(unixTimestamp + "ms") + if err != nil { + return fmt.Errorf("error parsing %s value %d, %s", rc.decltype[i], val, err) + } + epoch := time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC) + dest[i] = epoch.Add(duration) + } else { + dest[i] = time.Unix(val, 0).Local() + } + case "boolean": dest[i] = val > 0 default: diff --git a/sqlite3_test.go b/sqlite3_test.go index 9cc5a0e..3e50258 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -309,6 +309,7 @@ func TestTimestamp(t *testing.T) { {"0000-00-00 00:00:00", time.Time{}}, {timestamp1, timestamp1}, {timestamp1.Unix(), timestamp1}, + {timestamp1.UnixNano() / int64(time.Millisecond), timestamp1}, {timestamp1.In(time.FixedZone("TEST", -7*3600)), timestamp1}, {timestamp1.Format("2006-01-02 15:04:05.000"), timestamp1}, {timestamp1.Format("2006-01-02T15:04:05.000"), timestamp1}, -- cgit v1.2.3 From 5e6658a5c802ef478bcaaccad2103ab2a0cda562 Mon Sep 17 00:00:00 2001 From: mattn Date: Mon, 26 Jan 2015 18:43:28 +0900 Subject: Add test for Version --- sqlite3_test.go | 7 +++++++ 1 file changed, 7 insertions(+) (limited to 'sqlite3_test.go') diff --git a/sqlite3_test.go b/sqlite3_test.go index 3e50258..7df2169 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -743,3 +743,10 @@ func TestStress(t *testing.T) { db.Close() } } + +func TestVersion(t *testing.T) { + s, n, id, := Version() + if s == "" || n == 0 || id == 0 { + t.Errorf("Version failed %q, %d, %q\n", s, n, id) + } +} -- cgit v1.2.3 From 6717138923265f43d673bf92ece46f707370a94b Mon Sep 17 00:00:00 2001 From: mattn Date: Mon, 26 Jan 2015 18:55:41 +0900 Subject: Fix test --- sqlite3_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'sqlite3_test.go') diff --git a/sqlite3_test.go b/sqlite3_test.go index 7df2169..672128e 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -745,7 +745,7 @@ func TestStress(t *testing.T) { } func TestVersion(t *testing.T) { - s, n, id, := Version() + s, n, id := Version() if s == "" || n == 0 || id == 0 { t.Errorf("Version failed %q, %d, %q\n", s, n, id) } -- cgit v1.2.3 From a141177ca6dad6563ee78bb15495cea8679730fd Mon Sep 17 00:00:00 2001 From: mattn Date: Mon, 26 Jan 2015 18:58:58 +0900 Subject: Fix test --- sqlite3_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'sqlite3_test.go') diff --git a/sqlite3_test.go b/sqlite3_test.go index 672128e..a0adf30 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -746,7 +746,7 @@ func TestStress(t *testing.T) { func TestVersion(t *testing.T) { s, n, id := Version() - if s == "" || n == 0 || id == 0 { + if s == "" || n == 0 || id == "" { t.Errorf("Version failed %q, %d, %q\n", s, n, id) } } -- cgit v1.2.3 From 4c5c4e526100f23b17fda84d39a8109cfec00118 Mon Sep 17 00:00:00 2001 From: Yasuhiro Matsumoto Date: Wed, 4 Mar 2015 22:49:17 +0900 Subject: Add loc=XXX parameters to handle timezone --- sqlite3.go | 52 ++++++++++++++++++++++++++++++++++++++------- sqlite3_test.go | 65 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 109 insertions(+), 8 deletions(-) (limited to 'sqlite3_test.go') diff --git a/sqlite3.go b/sqlite3.go index 605474c..4457798 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -65,6 +65,7 @@ import ( "errors" "fmt" "io" + "net/url" "runtime" "strconv" "strings" @@ -107,7 +108,8 @@ type SQLiteDriver struct { // Conn struct. type SQLiteConn struct { - db *C.sqlite3 + db *C.sqlite3 + loc *time.Location } // Tx struct. @@ -256,11 +258,31 @@ func errorString(err Error) string { // file:test.db?cache=shared&mode=memory // :memory: // file::memory: +// go-sqlite handle especially query parameters. +// loc=XXX +// Specify location of time format. It's possible to specify "auto". func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) { if C.sqlite3_threadsafe() == 0 { return nil, errors.New("sqlite library was not compiled for thread-safe operation") } + var loc *time.Location + if u, err := url.Parse(dsn); err == nil { + for k, v := range u.Query() { + switch k { + case "loc": + if len(v) > 0 { + if v[0] == "auto" { + v[0] = time.Local.String() + } + if loc, err = time.LoadLocation(v[0]); err != nil { + return nil, fmt.Errorf("Invalid loc: %v: %v", v[0], err) + } + } + } + } + } + var db *C.sqlite3 name := C.CString(dsn) defer C.free(unsafe.Pointer(name)) @@ -281,7 +303,7 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) { return nil, Error{Code: ErrNo(rv)} } - conn := &SQLiteConn{db} + conn := &SQLiteConn{db: db, loc: loc} if len(d.Extensions) > 0 { rv = C.sqlite3_enable_load_extension(db, 1) @@ -401,8 +423,13 @@ func (s *SQLiteStmt) bind(args []driver.Value) error { } rv = C._sqlite3_bind_blob(s.s, n, unsafe.Pointer(p), C.int(len(v))) case time.Time: - b := []byte(v.UTC().Format(SQLiteTimestampFormats[0])) - rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b))) + if s.c.loc != nil { + b := []byte(v.In(s.c.loc).Format(SQLiteTimestampFormats[0])) + rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b))) + } else { + b := []byte(v.UTC().Format(SQLiteTimestampFormats[0])) + rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b))) + } } if rv != C.SQLITE_OK { return s.c.lastError() @@ -545,10 +572,19 @@ func (rc *SQLiteRows) Next(dest []driver.Value) error { switch rc.decltype[i] { case "timestamp", "datetime", "date": - for _, format := range SQLiteTimestampFormats { - if timeVal, err = time.ParseInLocation(format, s, time.UTC); err == nil { - dest[i] = timeVal.Local() - break + if rc.s.c.loc != nil { + for _, format := range SQLiteTimestampFormats { + if timeVal, err = time.ParseInLocation(format, s, rc.s.c.loc); err == nil { + dest[i] = timeVal + break + } + } + } else { + for _, format := range SQLiteTimestampFormats { + if timeVal, err = time.ParseInLocation(format, s, time.UTC); err == nil { + dest[i] = timeVal + break + } } } if err != nil { diff --git a/sqlite3_test.go b/sqlite3_test.go index a0adf30..325ba8e 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -744,6 +744,71 @@ func TestStress(t *testing.T) { } } +func TestDateTimeLocal(t *testing.T) { + zone := "Asia/Tokyo" + z, err := time.LoadLocation(zone) + if err != nil { + t.Skip("Failed to load timezon:", err) + } + tempFilename := TempFilename() + db, err := sql.Open("sqlite3", "file:///"+tempFilename+"?loc="+zone) + if err != nil { + t.Fatal("Failed to open database:", err) + } + db.Exec("CREATE TABLE foo (id datetime);") + db.Exec("INSERT INTO foo VALUES('2015-03-05 15:16:17');") + + row := db.QueryRow("select * from foo") + var d time.Time + err = row.Scan(&d) + if err != nil { + t.Fatal("Failed to scan datetime:", err) + } + if d.Local().Hour() != 15 { + t.Fatal("Result should have timezone", d) + } + db.Close() + + db, err = sql.Open("sqlite3", "file:///"+tempFilename) + if err != nil { + t.Fatal("Failed to open database:", err) + } + + row = db.QueryRow("select * from foo") + err = row.Scan(&d) + if err != nil { + t.Fatal("Failed to scan datetime:", err) + } + if d.In(z).Hour() == 15 { + t.Fatalf("Result should not have timezone %v", zone) + } + + _, err = db.Exec("DELETE FROM foo") + if err != nil { + t.Fatal("Failed to delete table:", err) + } + dt, err := time.Parse("2006/1/2 15/4/5 -0700 MST", "2015/3/5 15/16/17 +0900 JST") + if err != nil { + t.Fatal("Failed to parse datetime:", err) + } + db.Exec("INSERT INTO foo VALUES(?);", dt) + + db.Close() + db, err = sql.Open("sqlite3", "file:///"+tempFilename+"?loc="+zone) + if err != nil { + t.Fatal("Failed to open database:", err) + } + + row = db.QueryRow("select * from foo") + err = row.Scan(&d) + if err != nil { + t.Fatal("Failed to scan datetime:", err) + } + if d.Hour() == 15 { + t.Fatalf("Result should have timezone %v", zone) + } +} + func TestVersion(t *testing.T) { s, n, id := Version() if s == "" || n == 0 || id == "" { -- cgit v1.2.3 From e273a1552e39b67e23de088b5ee663e6c16baccd Mon Sep 17 00:00:00 2001 From: mattn Date: Thu, 5 Mar 2015 01:17:38 +0900 Subject: Fixed bug for loc parameter --- sqlite3.go | 34 ++++++++++++++++++---------------- sqlite3_test.go | 17 +++++++---------- 2 files changed, 25 insertions(+), 26 deletions(-) (limited to 'sqlite3_test.go') diff --git a/sqlite3.go b/sqlite3.go index e6584a4..b51ad9a 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -423,13 +423,8 @@ func (s *SQLiteStmt) bind(args []driver.Value) error { } rv = C._sqlite3_bind_blob(s.s, n, unsafe.Pointer(p), C.int(len(v))) case time.Time: - if s.c.loc != nil { - b := []byte(v.In(s.c.loc).Format(SQLiteTimestampFormats[0])) - rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b))) - } else { - b := []byte(v.UTC().Format(SQLiteTimestampFormats[0])) - rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b))) - } + b := []byte(v.UTC().Format(SQLiteTimestampFormats[0])) + rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b))) } if rv != C.SQLITE_OK { return s.c.lastError() @@ -536,11 +531,18 @@ func (rc *SQLiteRows) Next(dest []driver.Value) error { return fmt.Errorf("error parsing %s value %d, %s", rc.decltype[i], val, err) } epoch := time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC) - dest[i] = epoch.Add(duration) + if rc.s.c.loc != nil { + dest[i] = epoch.Add(duration).In(rc.s.c.loc) + } else { + dest[i] = epoch.Add(duration) + } } else { - dest[i] = time.Unix(val, 0).Local() + if rc.s.c.loc != nil { + dest[i] = time.Unix(val, 0).In(rc.s.c.loc) + } else { + dest[i] = time.Unix(val, 0) + } } - case "boolean": dest[i] = val > 0 default: @@ -572,13 +574,13 @@ func (rc *SQLiteRows) Next(dest []driver.Value) error { switch rc.decltype[i] { case "timestamp", "datetime", "date": - zone := rc.s.c.loc - if zone == nil { - zone = time.UTC - } for _, format := range SQLiteTimestampFormats { - if timeVal, err = time.ParseInLocation(format, s, zone); err == nil { - dest[i] = timeVal + if timeVal, err = time.ParseInLocation(format, s, time.UTC); err == nil { + if rc.s.c.loc != nil { + dest[i] = timeVal.In(rc.s.c.loc) + } else { + dest[i] = timeVal + } break } } diff --git a/sqlite3_test.go b/sqlite3_test.go index 325ba8e..1036525 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -11,6 +11,7 @@ import ( "encoding/hex" "os" "path/filepath" + "strings" "testing" "time" @@ -746,16 +747,12 @@ func TestStress(t *testing.T) { func TestDateTimeLocal(t *testing.T) { zone := "Asia/Tokyo" - z, err := time.LoadLocation(zone) - if err != nil { - t.Skip("Failed to load timezon:", err) - } tempFilename := TempFilename() db, err := sql.Open("sqlite3", "file:///"+tempFilename+"?loc="+zone) if err != nil { t.Fatal("Failed to open database:", err) } - db.Exec("CREATE TABLE foo (id datetime);") + db.Exec("CREATE TABLE foo (dt datetime);") db.Exec("INSERT INTO foo VALUES('2015-03-05 15:16:17');") row := db.QueryRow("select * from foo") @@ -764,7 +761,7 @@ func TestDateTimeLocal(t *testing.T) { if err != nil { t.Fatal("Failed to scan datetime:", err) } - if d.Local().Hour() != 15 { + if d.Hour() == 15 || !strings.Contains(d.String(), "JST") { t.Fatal("Result should have timezone", d) } db.Close() @@ -779,8 +776,8 @@ func TestDateTimeLocal(t *testing.T) { if err != nil { t.Fatal("Failed to scan datetime:", err) } - if d.In(z).Hour() == 15 { - t.Fatalf("Result should not have timezone %v", zone) + if d.UTC().Hour() != 15 || !strings.Contains(d.String(), "UTC") { + t.Fatalf("Result should not have timezone %v %v", zone, d.String()) } _, err = db.Exec("DELETE FROM foo") @@ -804,8 +801,8 @@ func TestDateTimeLocal(t *testing.T) { if err != nil { t.Fatal("Failed to scan datetime:", err) } - if d.Hour() == 15 { - t.Fatalf("Result should have timezone %v", zone) + if d.Hour() != 15 || !strings.Contains(d.String(), "JST") { + t.Fatalf("Result should have timezone %v %v", zone, d.String()) } } -- cgit v1.2.3 From 02f54e026317040a528e0b37207026dd2f885aa9 Mon Sep 17 00:00:00 2001 From: Yasuhiro Matsumoto Date: Thu, 5 Mar 2015 10:34:31 +0900 Subject: Add test --- sqlite3_test.go | 97 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 97 insertions(+) (limited to 'sqlite3_test.go') diff --git a/sqlite3_test.go b/sqlite3_test.go index 1036525..f66c64a 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -9,6 +9,7 @@ import ( "crypto/rand" "database/sql" "encoding/hex" + "net/url" "os" "path/filepath" "strings" @@ -635,6 +636,102 @@ func TestWAL(t *testing.T) { } } +func TestTimezoneConversion(t *testing.T) { + zones := []string{"UTC", "US/Central", "US/Pacific", "Local"} + for _, tz := range zones { + tempFilename := TempFilename() + db, err := sql.Open("sqlite3", tempFilename+"?loc="+url.QueryEscape(tz)) + if err != nil { + t.Fatal("Failed to open database:", err) + } + defer os.Remove(tempFilename) + defer db.Close() + + _, err = db.Exec("DROP TABLE foo") + _, err = db.Exec("CREATE TABLE foo(id INTEGER, ts TIMESTAMP, dt DATETIME)") + if err != nil { + t.Fatal("Failed to create table:", err) + } + + loc, err := time.LoadLocation(tz) + if err != nil { + t.Fatal("Failed to load location:", err) + } + + timestamp1 := time.Date(2012, time.April, 6, 22, 50, 0, 0, time.UTC) + timestamp2 := time.Date(2006, time.January, 2, 15, 4, 5, 123456789, time.UTC) + timestamp3 := time.Date(2012, time.November, 4, 0, 0, 0, 0, time.UTC) + tests := []struct { + value interface{} + expected time.Time + }{ + {"nonsense", time.Time{}.In(loc)}, + {"0000-00-00 00:00:00", time.Time{}.In(loc)}, + {timestamp1, timestamp1.In(loc)}, + {timestamp1.Unix(), timestamp1.In(loc)}, + {timestamp1.In(time.FixedZone("TEST", -7*3600)), timestamp1.In(loc)}, + {timestamp1.Format("2006-01-02 15:04:05.000"), timestamp1.In(loc)}, + {timestamp1.Format("2006-01-02T15:04:05.000"), timestamp1.In(loc)}, + {timestamp1.Format("2006-01-02 15:04:05"), timestamp1.In(loc)}, + {timestamp1.Format("2006-01-02T15:04:05"), timestamp1.In(loc)}, + {timestamp2, timestamp2.In(loc)}, + {"2006-01-02 15:04:05.123456789", timestamp2.In(loc)}, + {"2006-01-02T15:04:05.123456789", timestamp2.In(loc)}, + {"2012-11-04", timestamp3.In(loc)}, + {"2012-11-04 00:00", timestamp3.In(loc)}, + {"2012-11-04 00:00:00", timestamp3.In(loc)}, + {"2012-11-04 00:00:00.000", timestamp3.In(loc)}, + {"2012-11-04T00:00", timestamp3.In(loc)}, + {"2012-11-04T00:00:00", timestamp3.In(loc)}, + {"2012-11-04T00:00:00.000", timestamp3.In(loc)}, + } + for i := range tests { + _, err = db.Exec("INSERT INTO foo(id, ts, dt) VALUES(?, ?, ?)", i, tests[i].value, tests[i].value) + if err != nil { + t.Fatal("Failed to insert timestamp:", err) + } + } + + rows, err := db.Query("SELECT id, ts, dt FROM foo ORDER BY id ASC") + if err != nil { + t.Fatal("Unable to query foo table:", err) + } + defer rows.Close() + + seen := 0 + for rows.Next() { + var id int + var ts, dt time.Time + + if err := rows.Scan(&id, &ts, &dt); err != nil { + t.Error("Unable to scan results:", err) + continue + } + if id < 0 || id >= len(tests) { + t.Error("Bad row id: ", id) + continue + } + seen++ + if !tests[id].expected.Equal(ts) { + t.Errorf("Timestamp value for id %v (%v) should be %v, not %v", id, tests[id].value, tests[id].expected, ts) + } + if !tests[id].expected.Equal(dt) { + t.Errorf("Datetime value for id %v (%v) should be %v, not %v", id, tests[id].value, tests[id].expected, dt) + } + if tests[id].expected.Location().String() != ts.Location().String() { + t.Errorf("Location for id %v (%v) should be %v, not %v", id, tests[id].expected.Location().String(), ts.Location().String()) + } + if tests[id].expected.Location().String() != dt.Location().String() { + t.Errorf("Location for id %v (%v) should be %v, not %v", id, tests[id].expected.Location().String(), dt.Location().String()) + } + } + + if seen != len(tests) { + t.Errorf("Expected to see %d rows", len(tests)) + } + } +} + func TestSuite(t *testing.T) { db, err := sql.Open("sqlite3", ":memory:") if err != nil { -- cgit v1.2.3 From 71712f0ba9e5a262fcd808cfabedd0c759fa214b Mon Sep 17 00:00:00 2001 From: mix3 Date: Thu, 5 Mar 2015 10:34:31 +0900 Subject: Add test --- sqlite3_test.go | 96 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 96 insertions(+) (limited to 'sqlite3_test.go') diff --git a/sqlite3_test.go b/sqlite3_test.go index f66c64a..490f852 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -732,6 +732,102 @@ func TestTimezoneConversion(t *testing.T) { } } +func TestTimezoneConversion(t *testing.T) { + zones := []string{"UTC", "US/Central", "US/Pacific", "Local"} + for _, tz := range zones { + tempFilename := TempFilename() + db, err := sql.Open("sqlite3", tempFilename+"?loc="+url.QueryEscape(tz)) + if err != nil { + t.Fatal("Failed to open database:", err) + } + defer os.Remove(tempFilename) + defer db.Close() + + _, err = db.Exec("DROP TABLE foo") + _, err = db.Exec("CREATE TABLE foo(id INTEGER, ts TIMESTAMP, dt DATETIME)") + if err != nil { + t.Fatal("Failed to create table:", err) + } + + loc, err := time.LoadLocation(tz) + if err != nil { + t.Fatal("Failed to load location:", err) + } + + timestamp1 := time.Date(2012, time.April, 6, 22, 50, 0, 0, time.UTC) + timestamp2 := time.Date(2006, time.January, 2, 15, 4, 5, 123456789, time.UTC) + timestamp3 := time.Date(2012, time.November, 4, 0, 0, 0, 0, time.UTC) + tests := []struct { + value interface{} + expected time.Time + }{ + {"nonsense", time.Time{}.In(loc)}, + {"0000-00-00 00:00:00", time.Time{}.In(loc)}, + {timestamp1, timestamp1.In(loc)}, + {timestamp1.Unix(), timestamp1.In(loc)}, + {timestamp1.In(time.FixedZone("TEST", -7*3600)), timestamp1.In(loc)}, + {timestamp1.Format("2006-01-02 15:04:05.000"), timestamp1.In(loc)}, + {timestamp1.Format("2006-01-02T15:04:05.000"), timestamp1.In(loc)}, + {timestamp1.Format("2006-01-02 15:04:05"), timestamp1.In(loc)}, + {timestamp1.Format("2006-01-02T15:04:05"), timestamp1.In(loc)}, + {timestamp2, timestamp2.In(loc)}, + {"2006-01-02 15:04:05.123456789", timestamp2.In(loc)}, + {"2006-01-02T15:04:05.123456789", timestamp2.In(loc)}, + {"2012-11-04", timestamp3.In(loc)}, + {"2012-11-04 00:00", timestamp3.In(loc)}, + {"2012-11-04 00:00:00", timestamp3.In(loc)}, + {"2012-11-04 00:00:00.000", timestamp3.In(loc)}, + {"2012-11-04T00:00", timestamp3.In(loc)}, + {"2012-11-04T00:00:00", timestamp3.In(loc)}, + {"2012-11-04T00:00:00.000", timestamp3.In(loc)}, + } + for i := range tests { + _, err = db.Exec("INSERT INTO foo(id, ts, dt) VALUES(?, ?, ?)", i, tests[i].value, tests[i].value) + if err != nil { + t.Fatal("Failed to insert timestamp:", err) + } + } + + rows, err := db.Query("SELECT id, ts, dt FROM foo ORDER BY id ASC") + if err != nil { + t.Fatal("Unable to query foo table:", err) + } + defer rows.Close() + + seen := 0 + for rows.Next() { + var id int + var ts, dt time.Time + + if err := rows.Scan(&id, &ts, &dt); err != nil { + t.Error("Unable to scan results:", err) + continue + } + if id < 0 || id >= len(tests) { + t.Error("Bad row id: ", id) + continue + } + seen++ + if !tests[id].expected.Equal(ts) { + t.Errorf("Timestamp value for id %v (%v) should be %v, not %v", id, tests[id].value, tests[id].expected, ts) + } + if !tests[id].expected.Equal(dt) { + t.Errorf("Datetime value for id %v (%v) should be %v, not %v", id, tests[id].value, tests[id].expected, dt) + } + if tests[id].expected.Location().String() != ts.Location().String() { + t.Errorf("Location for id %v (%v) should be %v, not %v", id, tests[id].expected.Location().String(), ts.Location().String()) + } + if tests[id].expected.Location().String() != dt.Location().String() { + t.Errorf("Location for id %v (%v) should be %v, not %v", id, tests[id].expected.Location().String(), dt.Location().String()) + } + } + + if seen != len(tests) { + t.Errorf("Expected to see %d rows", len(tests)) + } + } +} + func TestSuite(t *testing.T) { db, err := sql.Open("sqlite3", ":memory:") if err != nil { -- cgit v1.2.3 From e48e0597ab41ba5aac618e241d9551a314838cda Mon Sep 17 00:00:00 2001 From: mix3 Date: Thu, 5 Mar 2015 11:05:58 +0900 Subject: Fix loc parsing --- sqlite3.go | 58 ++++++++++++++++++++++++++++++--------------------------- sqlite3_test.go | 4 ++-- 2 files changed, 33 insertions(+), 29 deletions(-) (limited to 'sqlite3_test.go') diff --git a/sqlite3.go b/sqlite3.go index b51ad9a..454e42d 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -267,20 +267,26 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) { } var loc *time.Location - if u, err := url.Parse(dsn); err == nil { - for k, v := range u.Query() { - switch k { - case "loc": - if len(v) > 0 { - if v[0] == "auto" { - v[0] = time.Local.String() - } - if loc, err = time.LoadLocation(v[0]); err != nil { - return nil, fmt.Errorf("Invalid loc: %v: %v", v[0], err) - } + pos := strings.IndexRune(dsn, '?') + if pos >= 1 { + params, err := url.ParseQuery(dsn[pos+1:]) + if err != nil { + return nil, err + } + + // loc + if val := params.Get("loc"); val != "" { + if val == "auto" { + loc = time.Local + } else { + loc, err = time.LoadLocation(val) + if err != nil { + return nil, fmt.Errorf("Invalid loc: %v: %v", val, err) } } } + + dsn = dsn[:pos-1] } var db *C.sqlite3 @@ -525,24 +531,21 @@ func (rc *SQLiteRows) Next(dest []driver.Value) error { switch rc.decltype[i] { case "timestamp", "datetime", "date": unixTimestamp := strconv.FormatInt(val, 10) + var t time.Time if len(unixTimestamp) == 13 { duration, err := time.ParseDuration(unixTimestamp + "ms") if err != nil { return fmt.Errorf("error parsing %s value %d, %s", rc.decltype[i], val, err) } epoch := time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC) - if rc.s.c.loc != nil { - dest[i] = epoch.Add(duration).In(rc.s.c.loc) - } else { - dest[i] = epoch.Add(duration) - } + t = epoch.Add(duration) } else { - if rc.s.c.loc != nil { - dest[i] = time.Unix(val, 0).In(rc.s.c.loc) - } else { - dest[i] = time.Unix(val, 0) - } + t = time.Unix(val, 0) + } + if rc.s.c.loc != nil { + t = t.In(rc.s.c.loc) } + dest[i] = t case "boolean": dest[i] = val > 0 default: @@ -574,20 +577,21 @@ func (rc *SQLiteRows) Next(dest []driver.Value) error { switch rc.decltype[i] { case "timestamp", "datetime", "date": + var t time.Time for _, format := range SQLiteTimestampFormats { if timeVal, err = time.ParseInLocation(format, s, time.UTC); err == nil { - if rc.s.c.loc != nil { - dest[i] = timeVal.In(rc.s.c.loc) - } else { - dest[i] = timeVal - } + t = timeVal break } } if err != nil { // The column is a time value, so return the zero time on parse failure. - dest[i] = time.Time{} + t = time.Time{} + } + if rc.s.c.loc != nil { + t = t.In(rc.s.c.loc) } + dest[i] = t default: dest[i] = []byte(s) } diff --git a/sqlite3_test.go b/sqlite3_test.go index 490f852..9c573ce 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -719,10 +719,10 @@ func TestTimezoneConversion(t *testing.T) { t.Errorf("Datetime value for id %v (%v) should be %v, not %v", id, tests[id].value, tests[id].expected, dt) } if tests[id].expected.Location().String() != ts.Location().String() { - t.Errorf("Location for id %v (%v) should be %v, not %v", id, tests[id].expected.Location().String(), ts.Location().String()) + t.Errorf("Location for id %v (%v) should be %v, not %v", id, tests[id].value, tests[id].expected.Location().String(), ts.Location().String()) } if tests[id].expected.Location().String() != dt.Location().String() { - t.Errorf("Location for id %v (%v) should be %v, not %v", id, tests[id].expected.Location().String(), dt.Location().String()) + t.Errorf("Location for id %v (%v) should be %v, not %v", id, tests[id].value, tests[id].expected.Location().String(), dt.Location().String()) } } -- cgit v1.2.3 From d463e8f1f9a1cd0f320006c22643dbc583d079ce Mon Sep 17 00:00:00 2001 From: Yasuhiro Matsumoto Date: Thu, 5 Mar 2015 12:32:06 +0900 Subject: Remove test dup --- sqlite3_test.go | 96 --------------------------------------------------------- 1 file changed, 96 deletions(-) (limited to 'sqlite3_test.go') diff --git a/sqlite3_test.go b/sqlite3_test.go index 9c573ce..4b0fe01 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -732,102 +732,6 @@ func TestTimezoneConversion(t *testing.T) { } } -func TestTimezoneConversion(t *testing.T) { - zones := []string{"UTC", "US/Central", "US/Pacific", "Local"} - for _, tz := range zones { - tempFilename := TempFilename() - db, err := sql.Open("sqlite3", tempFilename+"?loc="+url.QueryEscape(tz)) - if err != nil { - t.Fatal("Failed to open database:", err) - } - defer os.Remove(tempFilename) - defer db.Close() - - _, err = db.Exec("DROP TABLE foo") - _, err = db.Exec("CREATE TABLE foo(id INTEGER, ts TIMESTAMP, dt DATETIME)") - if err != nil { - t.Fatal("Failed to create table:", err) - } - - loc, err := time.LoadLocation(tz) - if err != nil { - t.Fatal("Failed to load location:", err) - } - - timestamp1 := time.Date(2012, time.April, 6, 22, 50, 0, 0, time.UTC) - timestamp2 := time.Date(2006, time.January, 2, 15, 4, 5, 123456789, time.UTC) - timestamp3 := time.Date(2012, time.November, 4, 0, 0, 0, 0, time.UTC) - tests := []struct { - value interface{} - expected time.Time - }{ - {"nonsense", time.Time{}.In(loc)}, - {"0000-00-00 00:00:00", time.Time{}.In(loc)}, - {timestamp1, timestamp1.In(loc)}, - {timestamp1.Unix(), timestamp1.In(loc)}, - {timestamp1.In(time.FixedZone("TEST", -7*3600)), timestamp1.In(loc)}, - {timestamp1.Format("2006-01-02 15:04:05.000"), timestamp1.In(loc)}, - {timestamp1.Format("2006-01-02T15:04:05.000"), timestamp1.In(loc)}, - {timestamp1.Format("2006-01-02 15:04:05"), timestamp1.In(loc)}, - {timestamp1.Format("2006-01-02T15:04:05"), timestamp1.In(loc)}, - {timestamp2, timestamp2.In(loc)}, - {"2006-01-02 15:04:05.123456789", timestamp2.In(loc)}, - {"2006-01-02T15:04:05.123456789", timestamp2.In(loc)}, - {"2012-11-04", timestamp3.In(loc)}, - {"2012-11-04 00:00", timestamp3.In(loc)}, - {"2012-11-04 00:00:00", timestamp3.In(loc)}, - {"2012-11-04 00:00:00.000", timestamp3.In(loc)}, - {"2012-11-04T00:00", timestamp3.In(loc)}, - {"2012-11-04T00:00:00", timestamp3.In(loc)}, - {"2012-11-04T00:00:00.000", timestamp3.In(loc)}, - } - for i := range tests { - _, err = db.Exec("INSERT INTO foo(id, ts, dt) VALUES(?, ?, ?)", i, tests[i].value, tests[i].value) - if err != nil { - t.Fatal("Failed to insert timestamp:", err) - } - } - - rows, err := db.Query("SELECT id, ts, dt FROM foo ORDER BY id ASC") - if err != nil { - t.Fatal("Unable to query foo table:", err) - } - defer rows.Close() - - seen := 0 - for rows.Next() { - var id int - var ts, dt time.Time - - if err := rows.Scan(&id, &ts, &dt); err != nil { - t.Error("Unable to scan results:", err) - continue - } - if id < 0 || id >= len(tests) { - t.Error("Bad row id: ", id) - continue - } - seen++ - if !tests[id].expected.Equal(ts) { - t.Errorf("Timestamp value for id %v (%v) should be %v, not %v", id, tests[id].value, tests[id].expected, ts) - } - if !tests[id].expected.Equal(dt) { - t.Errorf("Datetime value for id %v (%v) should be %v, not %v", id, tests[id].value, tests[id].expected, dt) - } - if tests[id].expected.Location().String() != ts.Location().String() { - t.Errorf("Location for id %v (%v) should be %v, not %v", id, tests[id].expected.Location().String(), ts.Location().String()) - } - if tests[id].expected.Location().String() != dt.Location().String() { - t.Errorf("Location for id %v (%v) should be %v, not %v", id, tests[id].expected.Location().String(), dt.Location().String()) - } - } - - if seen != len(tests) { - t.Errorf("Expected to see %d rows", len(tests)) - } - } -} - func TestSuite(t *testing.T) { db, err := sql.Open("sqlite3", ":memory:") if err != nil { -- cgit v1.2.3 From f40baee643e0e2b8e7d08d8d6fd62127a02ab132 Mon Sep 17 00:00:00 2001 From: Yasuhiro Matsumoto Date: Thu, 5 Mar 2015 12:39:44 +0900 Subject: Fix test --- sqlite3_test.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'sqlite3_test.go') diff --git a/sqlite3_test.go b/sqlite3_test.go index 4b0fe01..6570b52 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -845,7 +845,7 @@ func TestStress(t *testing.T) { func TestDateTimeLocal(t *testing.T) { zone := "Asia/Tokyo" tempFilename := TempFilename() - db, err := sql.Open("sqlite3", "file:///"+tempFilename+"?loc="+zone) + db, err := sql.Open("sqlite3", tempFilename+"?loc="+zone) if err != nil { t.Fatal("Failed to open database:", err) } @@ -863,7 +863,7 @@ func TestDateTimeLocal(t *testing.T) { } db.Close() - db, err = sql.Open("sqlite3", "file:///"+tempFilename) + db, err = sql.Open("sqlite3", tempFilename) if err != nil { t.Fatal("Failed to open database:", err) } @@ -888,7 +888,7 @@ func TestDateTimeLocal(t *testing.T) { db.Exec("INSERT INTO foo VALUES(?);", dt) db.Close() - db, err = sql.Open("sqlite3", "file:///"+tempFilename+"?loc="+zone) + db, err = sql.Open("sqlite3", tempFilename+"?loc="+zone) if err != nil { t.Fatal("Failed to open database:", err) } -- cgit v1.2.3 From a6c208564eccf3c6743f608ef88398a4ca84c5eb Mon Sep 17 00:00:00 2001 From: mattn Date: Sun, 22 Mar 2015 02:08:47 +0900 Subject: Support $NNN-style named parameter. Close #187 --- sqlite3.go | 61 +++++++++++++++++++++++++++++++++++++++++++++++++++++---- sqlite3_test.go | 36 ++++++++++++++++++++++++++++++++++ 2 files changed, 93 insertions(+), 4 deletions(-) (limited to 'sqlite3_test.go') diff --git a/sqlite3.go b/sqlite3.go index 91aea8f..4e68e96 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -121,6 +121,7 @@ type SQLiteTx struct { type SQLiteStmt struct { c *SQLiteConn s *C.sqlite3_stmt + nv int t string closed bool cls bool @@ -368,7 +369,19 @@ func (c *SQLiteConn) Prepare(query string) (driver.Stmt, error) { if tail != nil && C.strlen(tail) > 0 { t = strings.TrimSpace(C.GoString(tail)) } - ss := &SQLiteStmt{c: c, s: s, t: t} + nv := int(C.sqlite3_bind_parameter_count(s)) + if nv > 0 { + pn := C.GoString(C.sqlite3_bind_parameter_name(s, 1)) + /* TODO: map argument for named parameters + if len(pn) > 0 && pn[0] == '$' && pn[1] != '1' { + nv = -1 + } + */ + if len(pn) > 0 && pn[0] != '?' { + nv = -1 + } + } + ss := &SQLiteStmt{c: c, s: s, nv: nv, t: t} runtime.SetFinalizer(ss, (*SQLiteStmt).Close) return ss, nil } @@ -392,7 +405,12 @@ func (s *SQLiteStmt) Close() error { // Return a number of parameters. func (s *SQLiteStmt) NumInput() int { - return int(C.sqlite3_bind_parameter_count(s.s)) + return s.nv +} + +type bindArg struct { + n int + v driver.Value } func (s *SQLiteStmt) bind(args []driver.Value) error { @@ -401,8 +419,43 @@ func (s *SQLiteStmt) bind(args []driver.Value) error { return s.c.lastError() } - for i, v := range args { - n := C.int(i + 1) + var vargs []bindArg + narg := len(args) + if s.nv == -1 { + /* TODO: map argument for named parameters + if narg == 1 { + if m, ok := args[0].(map[string]driver.Value); ok { + for k, v := range m { + pn := C.CString(k) + if pi := int(C.sqlite3_bind_parameter_index(s.s, pn)); pi > 0 { + println(pi) + vargs = append(vargs, bindArg{pi, v}) + } + C.free(unsafe.Pointer(pn)) + } + } + narg = 0 + } + */ + if narg > 0 { + for i := 0; i < narg; i++ { + pn := C.CString(fmt.Sprint(i + 1)) + if pi := int(C.sqlite3_bind_parameter_index(s.s, pn)); pi > 0 { + vargs = append(vargs, bindArg{pi, args[i]}) + } + C.free(unsafe.Pointer(pn)) + } + } + } else { + vargs = make([]bindArg, narg) + for i, v := range args { + vargs[i] = bindArg{i + 1, v} + } + } + + for _, varg := range vargs { + n := C.int(varg.n) + v := varg.v switch v := v.(type) { case nil: rv = C.sqlite3_bind_null(s.s, n) diff --git a/sqlite3_test.go b/sqlite3_test.go index 6570b52..245f363 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -909,3 +909,39 @@ func TestVersion(t *testing.T) { t.Errorf("Version failed %q, %d, %q\n", s, n, id) } } + +func TestNumberNamedParams(t *testing.T) { + tempFilename := TempFilename() + db, err := sql.Open("sqlite3", tempFilename) + if err != nil { + t.Fatal("Failed to open database:", err) + } + defer os.Remove(tempFilename) + defer db.Close() + + _, err = db.Exec(` + create table foo (id integer, name text, extra text); + `) + if err != nil { + t.Error("Failed to call db.Query:", err) + } + + _, err = db.Exec(`insert into foo(id, name, extra)) values($1, $2, $2)`, 1, "foo") + if err != nil { + t.Error("Failed to call db.Exec:", err) + } + + row := db.QueryRow(`select id, name, extra where id = $1 and extra = $2`, 1, "foo") + if row == nil { + t.Error("Failed to call db.QueryRow") + } + var id int + var extra string + err = row.Scan(&id, &extra) + if err != nil { + t.Error("Failed to db.Scan:", err) + } + if id != 1 || extra != "foo" { + t.Error("Failed to db.QueryRow: not matched results") + } +} -- cgit v1.2.3 From d7dbb909ec3384421c36ef3ab3f10a87ba1b01bf Mon Sep 17 00:00:00 2001 From: mattn Date: Sun, 22 Mar 2015 02:39:28 +0900 Subject: Fix test --- sqlite3_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'sqlite3_test.go') diff --git a/sqlite3_test.go b/sqlite3_test.go index 245f363..3e89746 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -926,7 +926,7 @@ func TestNumberNamedParams(t *testing.T) { t.Error("Failed to call db.Query:", err) } - _, err = db.Exec(`insert into foo(id, name, extra)) values($1, $2, $2)`, 1, "foo") + _, err = db.Exec(`insert into foo(id, name, extra) values($1, $2, $2)`, 1, "foo") if err != nil { t.Error("Failed to call db.Exec:", err) } -- cgit v1.2.3 From c1abf95b381746cfae7fcb2796d72356f38dce63 Mon Sep 17 00:00:00 2001 From: mattn Date: Sun, 22 Mar 2015 03:16:35 +0900 Subject: Fix build --- sqlite3.go | 3 ++- sqlite3_test.go | 6 +++--- 2 files changed, 5 insertions(+), 4 deletions(-) (limited to 'sqlite3_test.go') diff --git a/sqlite3.go b/sqlite3.go index efb7cca..758cfdf 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -292,10 +292,11 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) { // _busy_timeout if val := params.Get("_busy_timeout"); val != "" { - busy_timeout = int(strconv.ParseInt(val, 10, 64)) + iv, err := strconv.ParseInt(val, 10, 64) if err != nil { return nil, fmt.Errorf("Invalid _busy_timeout: %v: %v", val, err) } + busy_timeout = int(iv) } if !strings.HasPrefix(dsn, "file:") { diff --git a/sqlite3_test.go b/sqlite3_test.go index 3e89746..1d06fd4 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -640,7 +640,7 @@ func TestTimezoneConversion(t *testing.T) { zones := []string{"UTC", "US/Central", "US/Pacific", "Local"} for _, tz := range zones { tempFilename := TempFilename() - db, err := sql.Open("sqlite3", tempFilename+"?loc="+url.QueryEscape(tz)) + db, err := sql.Open("sqlite3", tempFilename+"?_loc="+url.QueryEscape(tz)) if err != nil { t.Fatal("Failed to open database:", err) } @@ -845,7 +845,7 @@ func TestStress(t *testing.T) { func TestDateTimeLocal(t *testing.T) { zone := "Asia/Tokyo" tempFilename := TempFilename() - db, err := sql.Open("sqlite3", tempFilename+"?loc="+zone) + db, err := sql.Open("sqlite3", tempFilename+"?_loc="+zone) if err != nil { t.Fatal("Failed to open database:", err) } @@ -888,7 +888,7 @@ func TestDateTimeLocal(t *testing.T) { db.Exec("INSERT INTO foo VALUES(?);", dt) db.Close() - db, err = sql.Open("sqlite3", tempFilename+"?loc="+zone) + db, err = sql.Open("sqlite3", tempFilename+"?_loc="+zone) if err != nil { t.Fatal("Failed to open database:", err) } -- cgit v1.2.3 From ff38c8ec02230a8eb61cb8b93cad1dc2f5411e74 Mon Sep 17 00:00:00 2001 From: mattn Date: Sun, 22 Mar 2015 04:29:14 +0900 Subject: Revert a6c208564eccf3c6743f608ef88398a4ca84c5eb --- sqlite3.go | 44 +++----------------------------------------- sqlite3_test.go | 36 ------------------------------------ 2 files changed, 3 insertions(+), 77 deletions(-) (limited to 'sqlite3_test.go') diff --git a/sqlite3.go b/sqlite3.go index ca0180b..d384202 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -382,18 +382,6 @@ func (c *SQLiteConn) Prepare(query string) (driver.Stmt, error) { t = strings.TrimSpace(C.GoString(tail)) } nv := int(C.sqlite3_bind_parameter_count(s)) - /* - if nv > 0 { - pn := C.GoString(C.sqlite3_bind_parameter_name(s, 1)) - // TODO: map argument for named parameters - if len(pn) > 0 && pn[0] == '$' && pn[1] != '1' { - nv = -1 - } - if len(pn) > 0 && pn[0] != '?' { - nv = -1 - } - } - */ ss := &SQLiteStmt{c: c, s: s, nv: nv, t: t} runtime.SetFinalizer(ss, (*SQLiteStmt).Close) return ss, nil @@ -434,35 +422,9 @@ func (s *SQLiteStmt) bind(args []driver.Value) error { var vargs []bindArg narg := len(args) - if s.nv == -1 { - /* TODO: map argument for named parameters - if narg == 1 { - if m, ok := args[0].(map[string]driver.Value); ok { - for k, v := range m { - pn := C.CString(k) - if pi := int(C.sqlite3_bind_parameter_index(s.s, pn)); pi > 0 { - vargs = append(vargs, bindArg{pi, v}) - } - C.free(unsafe.Pointer(pn)) - } - } - narg = 0 - } - */ - if narg > 0 { - for i := 0; i < narg; i++ { - pn := C.CString(fmt.Sprint(i + 1)) - if pi := int(C.sqlite3_bind_parameter_index(s.s, pn)); pi > 0 { - vargs = append(vargs, bindArg{pi, args[i]}) - } - C.free(unsafe.Pointer(pn)) - } - } - } else { - vargs = make([]bindArg, narg) - for i, v := range args { - vargs[i] = bindArg{i + 1, v} - } + vargs = make([]bindArg, narg) + for i, v := range args { + vargs[i] = bindArg{i + 1, v} } for _, varg := range vargs { diff --git a/sqlite3_test.go b/sqlite3_test.go index 1d06fd4..81113fc 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -909,39 +909,3 @@ func TestVersion(t *testing.T) { t.Errorf("Version failed %q, %d, %q\n", s, n, id) } } - -func TestNumberNamedParams(t *testing.T) { - tempFilename := TempFilename() - db, err := sql.Open("sqlite3", tempFilename) - if err != nil { - t.Fatal("Failed to open database:", err) - } - defer os.Remove(tempFilename) - defer db.Close() - - _, err = db.Exec(` - create table foo (id integer, name text, extra text); - `) - if err != nil { - t.Error("Failed to call db.Query:", err) - } - - _, err = db.Exec(`insert into foo(id, name, extra) values($1, $2, $2)`, 1, "foo") - if err != nil { - t.Error("Failed to call db.Exec:", err) - } - - row := db.QueryRow(`select id, name, extra where id = $1 and extra = $2`, 1, "foo") - if row == nil { - t.Error("Failed to call db.QueryRow") - } - var id int - var extra string - err = row.Scan(&id, &extra) - if err != nil { - t.Error("Failed to db.Scan:", err) - } - if id != 1 || extra != "foo" { - t.Error("Failed to db.QueryRow: not matched results") - } -} -- cgit v1.2.3 From 07f9c9c30fbc180b71842b9282da7b5fa3d40fc9 Mon Sep 17 00:00:00 2001 From: mattn Date: Tue, 24 Mar 2015 00:46:49 +0900 Subject: Implement number-named parameters. Close #187 --- sqlite3.go | 22 +++++++++++++++++++--- sqlite3_test.go | 36 ++++++++++++++++++++++++++++++++++++ 2 files changed, 55 insertions(+), 3 deletions(-) (limited to 'sqlite3_test.go') diff --git a/sqlite3.go b/sqlite3.go index d384202..65865f3 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -122,6 +122,7 @@ type SQLiteStmt struct { c *SQLiteConn s *C.sqlite3_stmt nv int + nn []string t string closed bool cls bool @@ -382,7 +383,14 @@ func (c *SQLiteConn) Prepare(query string) (driver.Stmt, error) { t = strings.TrimSpace(C.GoString(tail)) } nv := int(C.sqlite3_bind_parameter_count(s)) - ss := &SQLiteStmt{c: c, s: s, nv: nv, t: t} + var nn []string + for i := 0; i < nv; i++ { + pn := C.GoString(C.sqlite3_bind_parameter_name(s, C.int(i+1))) + if len(pn) > 1 && pn[0] == '$' && 48 <= pn[1] && pn[1] <= 57 { + nn = append(nn, C.GoString(C.sqlite3_bind_parameter_name(s, C.int(i+1)))) + } + } + ss := &SQLiteStmt{c: c, s: s, nv: nv, nn: nn, t: t} runtime.SetFinalizer(ss, (*SQLiteStmt).Close) return ss, nil } @@ -423,8 +431,16 @@ func (s *SQLiteStmt) bind(args []driver.Value) error { var vargs []bindArg narg := len(args) vargs = make([]bindArg, narg) - for i, v := range args { - vargs[i] = bindArg{i + 1, v} + if len(s.nn) > 0 { + for i, v := range s.nn { + if pi, err := strconv.Atoi(v[1:]); err == nil { + vargs[i] = bindArg{pi, args[i]} + } + } + } else { + for i, v := range args { + vargs[i] = bindArg{i + 1, v} + } } for _, varg := range vargs { diff --git a/sqlite3_test.go b/sqlite3_test.go index 81113fc..aa86011 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -909,3 +909,39 @@ func TestVersion(t *testing.T) { t.Errorf("Version failed %q, %d, %q\n", s, n, id) } } + +func TestNumberNamedParams(t *testing.T) { + tempFilename := TempFilename() + db, err := sql.Open("sqlite3", tempFilename) + if err != nil { + t.Fatal("Failed to open database:", err) + } + defer os.Remove(tempFilename) + defer db.Close() + + _, err = db.Exec(` + create table foo (id integer, name text, extra text); + `) + if err != nil { + t.Error("Failed to call db.Query:", err) + } + + _, err = db.Exec(`insert into foo(id, name, extra) values($1, $2, $2)`, 1, "foo") + if err != nil { + t.Error("Failed to call db.Exec:", err) + } + + row := db.QueryRow(`select id, extra from foo where id = $1 and extra = $2`, 1, "foo") + if row == nil { + t.Error("Failed to call db.QueryRow") + } + var id int + var extra string + err = row.Scan(&id, &extra) + if err != nil { + t.Error("Failed to db.Scan:", err) + } + if id != 1 || extra != "foo" { + t.Error("Failed to db.QueryRow: not matched results") + } +} -- cgit v1.2.3 From ac0129617fd4293e7c341a48bba2adfc85f5afe1 Mon Sep 17 00:00:00 2001 From: Egon Elbre Date: Sun, 12 Apr 2015 14:59:29 +0300 Subject: Fix NULs in text. NUL character is a valid symbols in UTF8. Fixes #195 --- sqlite3.go | 4 +++- sqlite3_test.go | 39 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 42 insertions(+), 1 deletion(-) (limited to 'sqlite3_test.go') diff --git a/sqlite3.go b/sqlite3.go index f4de3fd..99924bf 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -624,7 +624,9 @@ func (rc *SQLiteRows) Next(dest []driver.Value) error { case C.SQLITE_TEXT: var err error var timeVal time.Time - s := C.GoString((*C.char)(unsafe.Pointer(C.sqlite3_column_text(rc.s.s, C.int(i))))) + + n := int(C.sqlite3_column_bytes(rc.s.s, C.int(i))) + s := C.GoStringN((*C.char)(unsafe.Pointer(C.sqlite3_column_text(rc.s.s, C.int(i)))), C.int(n)) switch rc.decltype[i] { case "timestamp", "datetime", "date": diff --git a/sqlite3_test.go b/sqlite3_test.go index aa86011..44e00f1 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -945,3 +945,42 @@ func TestNumberNamedParams(t *testing.T) { t.Error("Failed to db.QueryRow: not matched results") } } + +func TestStringContainingZero(t *testing.T) { + tempFilename := TempFilename() + db, err := sql.Open("sqlite3", tempFilename) + if err != nil { + t.Fatal("Failed to open database:", err) + } + defer os.Remove(tempFilename) + defer db.Close() + + _, err = db.Exec(` + create table foo (id integer, name, extra text); + `) + if err != nil { + t.Error("Failed to call db.Query:", err) + } + + const text = "foo\x00bar" + + _, err = db.Exec(`insert into foo(id, name, extra) values($1, $2, $2)`, 1, text) + if err != nil { + t.Error("Failed to call db.Exec:", err) + } + + row := db.QueryRow(`select id, extra from foo where id = $1 and extra = $2`, 1, text) + if row == nil { + t.Error("Failed to call db.QueryRow") + } + + var id int + var extra string + err = row.Scan(&id, &extra) + if err != nil { + t.Error("Failed to db.Scan:", err) + } + if id != 1 || extra != text { + t.Error("Failed to db.QueryRow: not matched results") + } +} -- cgit v1.2.3 From f91a09fb506e7543841b4a9725e16da84fc82b72 Mon Sep 17 00:00:00 2001 From: Serge Hallyn Date: Fri, 10 Apr 2015 11:32:18 -0500 Subject: Add a txlock option when opening databases (v2) When specified, changes the default locking at a tx.Begin. Changelog (v2): Add a testcase to ensure _txlock is properly handled. Closes #189 Signed-off-by: Serge Hallyn --- sqlite3.go | 27 +++++++++++++++++++++++---- sqlite3_test.go | 44 +++++++++++++++++++++++++++++++++++++++----- 2 files changed, 62 insertions(+), 9 deletions(-) (limited to 'sqlite3_test.go') diff --git a/sqlite3.go b/sqlite3.go index f4de3fd..c7a87b7 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -117,8 +117,9 @@ type SQLiteDriver struct { // Conn struct. type SQLiteConn struct { - db *C.sqlite3 - loc *time.Location + db *C.sqlite3 + loc *time.Location + txlock string } // Tx struct. @@ -252,7 +253,7 @@ func (c *SQLiteConn) exec(cmd string) (driver.Result, error) { // Begin transaction. func (c *SQLiteConn) Begin() (driver.Tx, error) { - if _, err := c.exec("BEGIN"); err != nil { + if _, err := c.exec(c.txlock); err != nil { return nil, err } return &SQLiteTx{c}, nil @@ -273,12 +274,16 @@ func errorString(err Error) string { // Specify location of time format. It's possible to specify "auto". // _busy_timeout=XXX // Specify value for sqlite3_busy_timeout. +// _txlock=XXX +// Specify locking behavior for transactions. XXX can be "immediate", +// "deferred", "exclusive". func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) { if C.sqlite3_threadsafe() == 0 { return nil, errors.New("sqlite library was not compiled for thread-safe operation") } var loc *time.Location + txlock := "BEGIN" busy_timeout := 5000 pos := strings.IndexRune(dsn, '?') if pos >= 1 { @@ -308,6 +313,20 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) { busy_timeout = int(iv) } + // _txlock + if val := params.Get("_txlock"); val != "" { + switch val { + case "immediate": + txlock = "BEGIN IMMEDIATE" + case "exclusive": + txlock = "BEGIN EXCLUSIVE" + case "deferred": + txlock = "BEGIN" + default: + return nil, fmt.Errorf("Invalid _txlock: %v", val) + } + } + if !strings.HasPrefix(dsn, "file:") { dsn = dsn[:pos] } @@ -333,7 +352,7 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) { return nil, Error{Code: ErrNo(rv)} } - conn := &SQLiteConn{db: db, loc: loc} + conn := &SQLiteConn{db: db, loc: loc, txlock: txlock} if len(d.Extensions) > 0 { rv = C.sqlite3_enable_load_extension(db, 1) diff --git a/sqlite3_test.go b/sqlite3_test.go index aa86011..deb9706 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -9,6 +9,7 @@ import ( "crypto/rand" "database/sql" "encoding/hex" + "fmt" "net/url" "os" "path/filepath" @@ -25,11 +26,17 @@ func TempFilename() string { return filepath.Join(os.TempDir(), "foo"+hex.EncodeToString(randBytes)+".db") } -func TestOpen(t *testing.T) { +func doTestOpen(t *testing.T, option string) (string, error) { + var url string tempFilename := TempFilename() - db, err := sql.Open("sqlite3", tempFilename) + if option != "" { + url = tempFilename + option + } else { + url = tempFilename + } + db, err := sql.Open("sqlite3", url) if err != nil { - t.Fatal("Failed to open database:", err) + return "Failed to open database:", err } defer os.Remove(tempFilename) defer db.Close() @@ -37,11 +44,38 @@ func TestOpen(t *testing.T) { _, err = db.Exec("drop table foo") _, err = db.Exec("create table foo (id integer)") if err != nil { - t.Fatal("Failed to create table:", err) + return "Failed to create table:", err } if stat, err := os.Stat(tempFilename); err != nil || stat.IsDir() { - t.Error("Failed to create ./foo.db") + return "Failed to create ./foo.db", nil + } + + return "", nil +} + +func TestOpen(t *testing.T) { + cases := map[string]bool{ + "": true, + "?_txlock=immediate": true, + "?_txlock=deferred": true, + "?_txlock=exclusive": true, + "?_txlock=bogus": false, + } + for option, expectedPass := range cases { + result, err := doTestOpen(t, option) + if result == "" { + if ! expectedPass { + errmsg := fmt.Sprintf("_txlock error not caught at dbOpen with option: %s", option) + t.Fatal(errmsg) + } + } else if expectedPass { + if err == nil { + t.Fatal(result) + } else { + t.Fatal(result, err) + } + } } } -- cgit v1.2.3 From dee1a37fe11067e97971fac64fa5bff4ef0934f4 Mon Sep 17 00:00:00 2001 From: Yasuhiro Matsumoto Date: Wed, 15 Apr 2015 16:26:27 +0900 Subject: Z suffix should be no-op --- sqlite3.go | 1 + sqlite3_test.go | 42 +++++++++++++++++++++++++++++++++++++++++- 2 files changed, 42 insertions(+), 1 deletion(-) (limited to 'sqlite3_test.go') diff --git a/sqlite3.go b/sqlite3.go index 5620b88..233e7e9 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -650,6 +650,7 @@ func (rc *SQLiteRows) Next(dest []driver.Value) error { switch rc.decltype[i] { case "timestamp", "datetime", "date": var t time.Time + s = strings.TrimSuffix(s, "Z") for _, format := range SQLiteTimestampFormats { if timeVal, err = time.ParseInLocation(format, s, time.UTC); err == nil { t = timeVal diff --git a/sqlite3_test.go b/sqlite3_test.go index 74cafa8..3a7d162 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -8,7 +8,9 @@ package sqlite3 import ( "crypto/rand" "database/sql" + "database/sql/driver" "encoding/hex" + "errors" "fmt" "net/url" "os" @@ -65,7 +67,7 @@ func TestOpen(t *testing.T) { for option, expectedPass := range cases { result, err := doTestOpen(t, option) if result == "" { - if ! expectedPass { + if !expectedPass { errmsg := fmt.Sprintf("_txlock error not caught at dbOpen with option: %s", option) t.Fatal(errmsg) } @@ -1018,3 +1020,41 @@ func TestStringContainingZero(t *testing.T) { t.Error("Failed to db.QueryRow: not matched results") } } + +const CurrentTimeStamp = "2006-01-02 15:04:05" + +type TimeStamp struct{ *time.Time } + +func (t TimeStamp) Scan(value interface{}) error { + fmt.Printf("%T\n", value) + + var err error + switch v := value.(type) { + case string: + *t.Time, err = time.Parse(CurrentTimeStamp, v) + case []byte: + *t.Time, err = time.Parse(CurrentTimeStamp, string(v)) + default: + err = errors.New("invalid type for current_timestamp") + } + return err +} + +func (t TimeStamp) Value() (driver.Value, error) { + return t.Time.Format(CurrentTimeStamp), nil +} + +func TestDateTimeNow(t *testing.T) { + tempFilename := TempFilename() + db, err := sql.Open("sqlite3", tempFilename) + if err != nil { + t.Fatal("Failed to open database:", err) + } + defer db.Close() + + var d time.Time + err = db.QueryRow("SELECT datetime('now')").Scan(TimeStamp{&d}) + if err != nil { + t.Fatal("Failed to scan datetime:", err) + } +} -- cgit v1.2.3 From f136f0c8dcab5adb705d1d356685c33983b5210b Mon Sep 17 00:00:00 2001 From: Yasuhiro Matsumoto Date: Wed, 15 Apr 2015 16:27:00 +0900 Subject: Remove debug code --- sqlite3_test.go | 2 -- 1 file changed, 2 deletions(-) (limited to 'sqlite3_test.go') diff --git a/sqlite3_test.go b/sqlite3_test.go index 3a7d162..423f30e 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -1026,8 +1026,6 @@ const CurrentTimeStamp = "2006-01-02 15:04:05" type TimeStamp struct{ *time.Time } func (t TimeStamp) Scan(value interface{}) error { - fmt.Printf("%T\n", value) - var err error switch v := value.(type) { case string: -- cgit v1.2.3 From cebbf42ff60a87f038ef03b3c5734817d41ef0d6 Mon Sep 17 00:00:00 2001 From: Lars Buitinck Date: Fri, 5 Jun 2015 16:02:14 +0200 Subject: Get reliable tempfile names from ioutil.TempFile Also makes them easier to spot (the tests tend to litter /tmp). --- sqlite3_fts3_test.go | 2 +- sqlite3_test.go | 55 ++++++++++++++++++++++++++-------------------------- 2 files changed, 29 insertions(+), 28 deletions(-) (limited to 'sqlite3_test.go') diff --git a/sqlite3_fts3_test.go b/sqlite3_fts3_test.go index a1cd217..716a106 100644 --- a/sqlite3_fts3_test.go +++ b/sqlite3_fts3_test.go @@ -12,7 +12,7 @@ import ( ) func TestFTS3(t *testing.T) { - tempFilename := TempFilename() + tempFilename := TempFilename(t) db, err := sql.Open("sqlite3", tempFilename) if err != nil { t.Fatal("Failed to open database:", err) diff --git a/sqlite3_test.go b/sqlite3_test.go index 423f30e..ee1ba0c 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -6,15 +6,13 @@ package sqlite3 import ( - "crypto/rand" "database/sql" "database/sql/driver" - "encoding/hex" "errors" "fmt" + "io/ioutil" "net/url" "os" - "path/filepath" "strings" "testing" "time" @@ -22,15 +20,18 @@ import ( "github.com/mattn/go-sqlite3/sqlite3_test" ) -func TempFilename() string { - randBytes := make([]byte, 16) - rand.Read(randBytes) - return filepath.Join(os.TempDir(), "foo"+hex.EncodeToString(randBytes)+".db") +func TempFilename(t *testing.T) string { + f, err := ioutil.TempFile("", "go-sqlite3-test-") + if err != nil { + t.Fatal(err) + } + f.Close() + return f.Name() } func doTestOpen(t *testing.T, option string) (string, error) { var url string - tempFilename := TempFilename() + tempFilename := TempFilename(t) if option != "" { url = tempFilename + option } else { @@ -82,7 +83,7 @@ func TestOpen(t *testing.T) { } func TestClose(t *testing.T) { - tempFilename := TempFilename() + tempFilename := TempFilename(t) db, err := sql.Open("sqlite3", tempFilename) if err != nil { t.Fatal("Failed to open database:", err) @@ -108,7 +109,7 @@ func TestClose(t *testing.T) { } func TestInsert(t *testing.T) { - tempFilename := TempFilename() + tempFilename := TempFilename(t) db, err := sql.Open("sqlite3", tempFilename) if err != nil { t.Fatal("Failed to open database:", err) @@ -147,7 +148,7 @@ func TestInsert(t *testing.T) { } func TestUpdate(t *testing.T) { - tempFilename := TempFilename() + tempFilename := TempFilename(t) db, err := sql.Open("sqlite3", tempFilename) if err != nil { t.Fatal("Failed to open database:", err) @@ -212,7 +213,7 @@ func TestUpdate(t *testing.T) { } func TestDelete(t *testing.T) { - tempFilename := TempFilename() + tempFilename := TempFilename(t) db, err := sql.Open("sqlite3", tempFilename) if err != nil { t.Fatal("Failed to open database:", err) @@ -273,7 +274,7 @@ func TestDelete(t *testing.T) { } func TestBooleanRoundtrip(t *testing.T) { - tempFilename := TempFilename() + tempFilename := TempFilename(t) db, err := sql.Open("sqlite3", tempFilename) if err != nil { t.Fatal("Failed to open database:", err) @@ -322,7 +323,7 @@ func TestBooleanRoundtrip(t *testing.T) { } func TestTimestamp(t *testing.T) { - tempFilename := TempFilename() + tempFilename := TempFilename(t) db, err := sql.Open("sqlite3", tempFilename) if err != nil { t.Fatal("Failed to open database:", err) @@ -405,7 +406,7 @@ func TestTimestamp(t *testing.T) { } func TestBoolean(t *testing.T) { - tempFilename := TempFilename() + tempFilename := TempFilename(t) db, err := sql.Open("sqlite3", tempFilename) if err != nil { t.Fatal("Failed to open database:", err) @@ -497,7 +498,7 @@ func TestBoolean(t *testing.T) { } func TestFloat32(t *testing.T) { - tempFilename := TempFilename() + tempFilename := TempFilename(t) db, err := sql.Open("sqlite3", tempFilename) if err != nil { t.Fatal("Failed to open database:", err) @@ -535,7 +536,7 @@ func TestFloat32(t *testing.T) { } func TestNull(t *testing.T) { - tempFilename := TempFilename() + tempFilename := TempFilename(t) db, err := sql.Open("sqlite3", tempFilename) if err != nil { t.Fatal("Failed to open database:", err) @@ -567,7 +568,7 @@ func TestNull(t *testing.T) { } func TestTransaction(t *testing.T) { - tempFilename := TempFilename() + tempFilename := TempFilename(t) db, err := sql.Open("sqlite3", tempFilename) if err != nil { t.Fatal("Failed to open database:", err) @@ -627,7 +628,7 @@ func TestTransaction(t *testing.T) { } func TestWAL(t *testing.T) { - tempFilename := TempFilename() + tempFilename := TempFilename(t) db, err := sql.Open("sqlite3", tempFilename) if err != nil { t.Fatal("Failed to open database:", err) @@ -675,7 +676,7 @@ func TestWAL(t *testing.T) { func TestTimezoneConversion(t *testing.T) { zones := []string{"UTC", "US/Central", "US/Pacific", "Local"} for _, tz := range zones { - tempFilename := TempFilename() + tempFilename := TempFilename(t) db, err := sql.Open("sqlite3", tempFilename+"?_loc="+url.QueryEscape(tz)) if err != nil { t.Fatal("Failed to open database:", err) @@ -781,7 +782,7 @@ func TestSuite(t *testing.T) { // TODO: Execer & Queryer currently disabled // https://github.com/mattn/go-sqlite3/issues/82 func TestExecer(t *testing.T) { - tempFilename := TempFilename() + tempFilename := TempFilename(t) db, err := sql.Open("sqlite3", tempFilename) if err != nil { t.Fatal("Failed to open database:", err) @@ -801,7 +802,7 @@ func TestExecer(t *testing.T) { } func TestQueryer(t *testing.T) { - tempFilename := TempFilename() + tempFilename := TempFilename(t) db, err := sql.Open("sqlite3", tempFilename) if err != nil { t.Fatal("Failed to open database:", err) @@ -842,7 +843,7 @@ func TestQueryer(t *testing.T) { } func TestStress(t *testing.T) { - tempFilename := TempFilename() + tempFilename := TempFilename(t) db, err := sql.Open("sqlite3", tempFilename) if err != nil { t.Fatal("Failed to open database:", err) @@ -880,7 +881,7 @@ func TestStress(t *testing.T) { func TestDateTimeLocal(t *testing.T) { zone := "Asia/Tokyo" - tempFilename := TempFilename() + tempFilename := TempFilename(t) db, err := sql.Open("sqlite3", tempFilename+"?_loc="+zone) if err != nil { t.Fatal("Failed to open database:", err) @@ -947,7 +948,7 @@ func TestVersion(t *testing.T) { } func TestNumberNamedParams(t *testing.T) { - tempFilename := TempFilename() + tempFilename := TempFilename(t) db, err := sql.Open("sqlite3", tempFilename) if err != nil { t.Fatal("Failed to open database:", err) @@ -983,7 +984,7 @@ func TestNumberNamedParams(t *testing.T) { } func TestStringContainingZero(t *testing.T) { - tempFilename := TempFilename() + tempFilename := TempFilename(t) db, err := sql.Open("sqlite3", tempFilename) if err != nil { t.Fatal("Failed to open database:", err) @@ -1043,7 +1044,7 @@ func (t TimeStamp) Value() (driver.Value, error) { } func TestDateTimeNow(t *testing.T) { - tempFilename := TempFilename() + tempFilename := TempFilename(t) db, err := sql.Open("sqlite3", tempFilename) if err != nil { t.Fatal("Failed to open database:", err) -- cgit v1.2.3 From a3efcea001ff5a2fb956efc4272fea14e422086b Mon Sep 17 00:00:00 2001 From: Lars Buitinck Date: Fri, 5 Jun 2015 16:32:51 +0200 Subject: Clean up more tempfiles --- sqlite3_test.go | 1 + 1 file changed, 1 insertion(+) (limited to 'sqlite3_test.go') diff --git a/sqlite3_test.go b/sqlite3_test.go index ee1ba0c..f2d461b 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -32,6 +32,7 @@ func TempFilename(t *testing.T) string { func doTestOpen(t *testing.T, option string) (string, error) { var url string tempFilename := TempFilename(t) + defer os.Remove(tempFilename) if option != "" { url = tempFilename + option } else { -- cgit v1.2.3 From 5674e19d0587e09d0649f650671cb23627d20ba8 Mon Sep 17 00:00:00 2001 From: Lars Buitinck Date: Fri, 5 Jun 2015 16:33:25 +0200 Subject: Test read-only databases --- sqlite3_test.go | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) (limited to 'sqlite3_test.go') diff --git a/sqlite3_test.go b/sqlite3_test.go index f2d461b..43f7b4f 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -83,6 +83,27 @@ func TestOpen(t *testing.T) { } } +func TestReadonly(t *testing.T) { + tempFilename := TempFilename(t) + defer os.Remove(tempFilename) + + db1, err := sql.Open("sqlite3", "file:" + tempFilename) + if err != nil { + t.Fatal(err) + } + db1.Exec("CREATE TABLE test (x int, y float)") + + db2, err := sql.Open("sqlite3", "file:" + tempFilename + "?mode=ro") + if err != nil { + t.Fatal(err) + } + _ = db2 + _, err = db2.Exec("INSERT INTO test VALUES (1, 3.14)") + if err == nil { + t.Fatal("didn't expect INSERT into read-only database to work") + } +} + func TestClose(t *testing.T) { tempFilename := TempFilename(t) db, err := sql.Open("sqlite3", tempFilename) -- cgit v1.2.3 From cf8fa0af80e0d227c79ef2b4635e8d0d77432275 Mon Sep 17 00:00:00 2001 From: David Anderson Date: Thu, 20 Aug 2015 23:08:48 -0700 Subject: Implement support for passing Go functions as custom functions to SQLite. Fixes #226. --- callback.go | 20 +++++ doc.go | 23 +++++- sqlite3.go | 191 ++++++++++++++++++++++++++++++++++++++++++++++++ sqlite3_test.go | 108 +++++++++++++++++++++++++++ sqlite3_test/sqltest.go | 6 +- 5 files changed, 342 insertions(+), 6 deletions(-) create mode 100644 callback.go (limited to 'sqlite3_test.go') diff --git a/callback.go b/callback.go new file mode 100644 index 0000000..938d7fe --- /dev/null +++ b/callback.go @@ -0,0 +1,20 @@ +// Copyright (C) 2014 Yasuhiro Matsumoto . +// +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package sqlite3 + +/* +#include +*/ +import "C" + +import "unsafe" + +//export callbackTrampoline +func callbackTrampoline(ctx *C.sqlite3_context, argc int, argv **C.sqlite3_value) { + args := (*[1 << 30]*C.sqlite3_value)(unsafe.Pointer(argv))[:argc:argc] + fi := (*functionInfo)(unsafe.Pointer(C.sqlite3_user_data(ctx))) + fi.Call(ctx, args) +} diff --git a/doc.go b/doc.go index 51364c3..a45d852 100644 --- a/doc.go +++ b/doc.go @@ -33,7 +33,7 @@ extension for Regexp matcher operation. #include #include #include - + SQLITE_EXTENSION_INIT1 static void regexp_func(sqlite3_context *context, int argc, sqlite3_value **argv) { if (argc >= 2) { @@ -44,7 +44,7 @@ extension for Regexp matcher operation. int vec[500]; int n, rc; pcre* re = pcre_compile(pattern, 0, &errstr, &erroff, NULL); - rc = pcre_exec(re, NULL, target, strlen(target), 0, 0, vec, 500); + rc = pcre_exec(re, NULL, target, strlen(target), 0, 0, vec, 500); if (rc <= 0) { sqlite3_result_error(context, errstr, 0); return; @@ -52,7 +52,7 @@ extension for Regexp matcher operation. sqlite3_result_int(context, 1); } } - + #ifdef _WIN32 __declspec(dllexport) #endif @@ -91,5 +91,22 @@ you need to hook ConnectHook and get the SQLiteConn. }, }) +Go SQlite3 Extensions + +If you want to register Go functions as SQLite extension functions, +call RegisterFunction from ConnectHook. + + regex = func(re, s string) (bool, error) { + return regexp.MatchString(re, s) + } + sql.Register("sqlite3_with_go_func", + &sqlite3.SQLiteDriver{ + ConnectHook: func(conn *sqlite3.SQLiteConn) error { + return conn.RegisterFunc("regex", regex, true) + }, + }) + +See the documentation of RegisterFunc for more details. + */ package sqlite3 diff --git a/sqlite3.go b/sqlite3.go index d57d9fb..f995589 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -66,6 +66,15 @@ _sqlite3_step(sqlite3_stmt* stmt, long long* rowid, long long* changes) return rv; } +void _sqlite3_result_text(sqlite3_context* ctx, const char* s) { + sqlite3_result_text(ctx, s, -1, &free); +} + +void _sqlite3_result_blob(sqlite3_context* ctx, const void* b, int l) { + sqlite3_result_blob(ctx, b, l, SQLITE_TRANSIENT); +} + +void callbackTrampoline(sqlite3_context*, int, sqlite3_value**); */ import "C" import ( @@ -75,6 +84,7 @@ import ( "fmt" "io" "net/url" + "reflect" "runtime" "strconv" "strings" @@ -120,6 +130,7 @@ type SQLiteConn struct { db *C.sqlite3 loc *time.Location txlock string + funcs []*functionInfo } // Tx struct. @@ -153,6 +164,89 @@ type SQLiteRows struct { cls bool } +type functionInfo struct { + f reflect.Value + argConverters []func(*C.sqlite3_value) (reflect.Value, error) +} + +func (fi *functionInfo) error(ctx *C.sqlite3_context, err error) { + cstr := C.CString(err.Error()) + defer C.free(unsafe.Pointer(cstr)) + C.sqlite3_result_error(ctx, cstr, -1) +} + +func (fi *functionInfo) Call(ctx *C.sqlite3_context, argv []*C.sqlite3_value) { + var args []reflect.Value + for i, arg := range argv { + v, err := fi.argConverters[i](arg) + if err != nil { + fi.error(ctx, err) + return + } + args = append(args, v) + } + + ret := fi.f.Call(args) + + if len(ret) == 2 && ret[1].Interface() != nil { + fi.error(ctx, ret[1].Interface().(error)) + return + } + + res := ret[0].Interface() + // Normalize ret to one of the types sqlite knows. + switch r := res.(type) { + case int64, float64, []byte, string: + // Already the right type + case bool: + if r { + res = int64(1) + } else { + res = int64(0) + } + case int: + res = int64(r) + case uint: + res = int64(r) + case uint8: + res = int64(r) + case uint16: + res = int64(r) + case uint32: + res = int64(r) + case uint64: + res = int64(r) + case int8: + res = int64(r) + case int16: + res = int64(r) + case int32: + res = int64(r) + case float32: + res = float64(r) + default: + fi.error(ctx, errors.New("cannot convert returned type to sqlite type")) + return + } + + switch r := res.(type) { + case int64: + C.sqlite3_result_int64(ctx, C.sqlite3_int64(r)) + case float64: + C.sqlite3_result_double(ctx, C.double(r)) + case []byte: + if len(r) == 0 { + C.sqlite3_result_null(ctx) + } else { + C._sqlite3_result_blob(ctx, unsafe.Pointer(&r[0]), C.int(len(r))) + } + case string: + C._sqlite3_result_text(ctx, C.CString(r)) + default: + panic("unreachable") + } +} + // Commit transaction. func (tx *SQLiteTx) Commit() error { _, err := tx.c.exec("COMMIT") @@ -165,6 +259,103 @@ func (tx *SQLiteTx) Rollback() error { return err } +// RegisterFunc makes a Go function available as a SQLite function. +// +// The function must accept only arguments of type int64, float64, +// []byte or string, and return one value of any numeric type except +// complex, bool, []byte or string. Optionally, an error can be +// provided as a second return value. +// +// If pure is true. SQLite will assume that the function's return +// value depends only on its inputs, and make more aggressive +// optimizations in its queries. +func (c *SQLiteConn) RegisterFunc(name string, impl interface{}, pure bool) error { + var fi functionInfo + fi.f = reflect.ValueOf(impl) + t := fi.f.Type() + if t.Kind() != reflect.Func { + return errors.New("Non-function passed to RegisterFunc") + } + if t.IsVariadic() { + return errors.New("Variadic SQLite functions are not supported") + } + if t.NumOut() != 1 && t.NumOut() != 2 { + return errors.New("SQLite functions must return 1 or 2 values") + } + if t.NumOut() == 2 && !t.Out(1).Implements(reflect.TypeOf((*error)(nil)).Elem()) { + return errors.New("Second return value of SQLite function must be error") + } + + for i := 0; i < t.NumIn(); i++ { + arg := t.In(i) + var conv func(*C.sqlite3_value) (reflect.Value, error) + switch arg.Kind() { + case reflect.Int64: + conv = func(v *C.sqlite3_value) (reflect.Value, error) { + if C.sqlite3_value_type(v) != C.SQLITE_INTEGER { + return reflect.Value{}, fmt.Errorf("Argument %d to %s must be an INTEGER", i+1, name) + } + return reflect.ValueOf(int64(C.sqlite3_value_int64(v))), nil + } + case reflect.Float64: + conv = func(v *C.sqlite3_value) (reflect.Value, error) { + if C.sqlite3_value_type(v) != C.SQLITE_FLOAT { + return reflect.Value{}, fmt.Errorf("Argument %d to %s must be a FLOAT", i+1, name) + } + return reflect.ValueOf(float64(C.sqlite3_value_double(v))), nil + } + case reflect.Slice: + if arg.Elem().Kind() != reflect.Uint8 { + return errors.New("The only supported slice type is []byte") + } + conv = func(v *C.sqlite3_value) (reflect.Value, error) { + switch C.sqlite3_value_type(v) { + case C.SQLITE_BLOB: + l := C.sqlite3_value_bytes(v) + p := C.sqlite3_value_blob(v) + return reflect.ValueOf(C.GoBytes(p, l)), nil + case C.SQLITE_TEXT: + l := C.sqlite3_value_bytes(v) + c := unsafe.Pointer(C.sqlite3_value_text(v)) + return reflect.ValueOf(C.GoBytes(c, l)), nil + default: + return reflect.Value{}, fmt.Errorf("Argument %d to %s must be BLOB or TEXT", i+1, name) + } + } + case reflect.String: + conv = func(v *C.sqlite3_value) (reflect.Value, error) { + switch C.sqlite3_value_type(v) { + case C.SQLITE_BLOB: + l := C.sqlite3_value_bytes(v) + p := (*C.char)(C.sqlite3_value_blob(v)) + return reflect.ValueOf(C.GoStringN(p, l)), nil + case C.SQLITE_TEXT: + c := (*C.char)(unsafe.Pointer(C.sqlite3_value_text(v))) + return reflect.ValueOf(C.GoString(c)), nil + default: + return reflect.Value{}, fmt.Errorf("Argument %d to %s must be BLOB or TEXT", i+1, name) + } + } + } + fi.argConverters = append(fi.argConverters, conv) + } + + // fi must outlast the database connection, or we'll have dangling pointers. + c.funcs = append(c.funcs, &fi) + + cname := C.CString(name) + defer C.free(unsafe.Pointer(cname)) + opts := C.SQLITE_UTF8 + if pure { + opts |= C.SQLITE_DETERMINISTIC + } + rv := C.sqlite3_create_function_v2(c.db, cname, C.int(t.NumIn()), C.int(opts), unsafe.Pointer(&fi), (*[0]byte)(unsafe.Pointer(C.callbackTrampoline)), nil, nil, nil) + if rv != C.SQLITE_OK { + return c.lastError() + } + return nil +} + // AutoCommit return which currently auto commit or not. func (c *SQLiteConn) AutoCommit() bool { return int(C.sqlite3_get_autocommit(c.db)) != 0 diff --git a/sqlite3_test.go b/sqlite3_test.go index 423f30e..a58e373 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -15,7 +15,9 @@ import ( "net/url" "os" "path/filepath" + "regexp" "strings" + "sync" "testing" "time" @@ -1056,3 +1058,109 @@ func TestDateTimeNow(t *testing.T) { t.Fatal("Failed to scan datetime:", err) } } + +func TestFunctionRegistration(t *testing.T) { + custom_add := func(a, b int64) (int64, error) { + return a + b, nil + } + custom_regex := func(s, re string) bool { + matched, err := regexp.MatchString(re, s) + if err != nil { + // We should really return the error here, but this + // function is also testing single return value functions. + panic("Bad regexp") + } + return matched + } + + sql.Register("sqlite3_FunctionRegistration", &SQLiteDriver{ + ConnectHook: func(conn *SQLiteConn) error { + if err := conn.RegisterFunc("custom_add", custom_add, true); err != nil { + return err + } + if err := conn.RegisterFunc("regexp", custom_regex, true); err != nil { + return err + } + return nil + }, + }) + db, err := sql.Open("sqlite3_FunctionRegistration", ":memory:") + if err != nil { + t.Fatal("Failed to open database:", err) + } + defer db.Close() + + additions := []struct { + a, b, c int64 + }{ + {1, 1, 2}, + {1, 3, 4}, + {1, -1, 0}, + } + + for _, add := range additions { + var i int64 + err = db.QueryRow("SELECT custom_add($1, $2)", add.a, add.b).Scan(&i) + if err != nil { + t.Fatal("Failed to call custom_add:", err) + } + if i != add.c { + t.Fatalf("custom_add returned the wrong value, got %d, want %d", i, add.c) + } + } + + regexes := []struct { + re, in string + out bool + }{ + {".*", "foo", true}, + {"^foo.*", "foobar", true}, + {"^foo.*", "barfoo", false}, + } + + for _, re := range regexes { + var b bool + err = db.QueryRow("SELECT regexp($1, $2)", re.in, re.re).Scan(&b) + if err != nil { + t.Fatal("Failed to call regexp:", err) + } + if b != re.out { + t.Fatalf("regexp returned the wrong value, got %v, want %v", b, re.out) + } + } +} + +var customFunctionOnce sync.Once + +func BenchmarkCustomFunctions(b *testing.B) { + customFunctionOnce.Do(func() { + custom_add := func(a, b int64) (int64, error) { + return a + b, nil + } + + sql.Register("sqlite3_BenchmarkCustomFunctions", &SQLiteDriver{ + ConnectHook: func(conn *SQLiteConn) error { + // Impure function to force sqlite to reexecute it each time. + if err := conn.RegisterFunc("custom_add", custom_add, false); err != nil { + return err + } + return nil + }, + }) + }) + + db, err := sql.Open("sqlite3_BenchmarkCustomFunctions", ":memory:") + if err != nil { + b.Fatal("Failed to open database:", err) + } + defer db.Close() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + var i int64 + err = db.QueryRow("SELECT custom_add(1,2)").Scan(&i) + if err != nil { + b.Fatal("Failed to run custom add:", err) + } + } +} diff --git a/sqlite3_test/sqltest.go b/sqlite3_test/sqltest.go index fc82782..782e15f 100644 --- a/sqlite3_test/sqltest.go +++ b/sqlite3_test/sqltest.go @@ -318,7 +318,7 @@ func BenchmarkQuery(b *testing.B) { var i int var f float64 var s string -// var t time.Time + // var t time.Time if err := db.QueryRow("select null, 1, 1.1, 'foo'").Scan(&n, &i, &f, &s); err != nil { panic(err) } @@ -331,7 +331,7 @@ func BenchmarkParams(b *testing.B) { var i int var f float64 var s string -// var t time.Time + // var t time.Time if err := db.QueryRow("select ?, ?, ?, ?", nil, 1, 1.1, "foo").Scan(&n, &i, &f, &s); err != nil { panic(err) } @@ -350,7 +350,7 @@ func BenchmarkStmt(b *testing.B) { var i int var f float64 var s string -// var t time.Time + // var t time.Time if err := st.QueryRow(nil, 1, 1.1, "foo").Scan(&n, &i, &f, &s); err != nil { panic(err) } -- cgit v1.2.3 From 122ddb16de825ed3d989d25d4d7b2d2e278abdf6 Mon Sep 17 00:00:00 2001 From: David Anderson Date: Fri, 21 Aug 2015 13:38:22 -0700 Subject: Move argument converters to callback.go, and optimize return value handling. A call now doesn't have to do any reflection, it just blindly invokes a bunch of argument and return value handlers to execute the translation, and the safety of the translation is determined at registration time. --- callback.go | 200 ++++++++++++++++++++++++++++++++++++++++++++++++++++++- callback_test.go | 97 +++++++++++++++++++++++++++ sqlite3.go | 122 +++++---------------------------- sqlite3_test.go | 102 ++++++++++++++-------------- 4 files changed, 367 insertions(+), 154 deletions(-) create mode 100644 callback_test.go (limited to 'sqlite3_test.go') diff --git a/callback.go b/callback.go index 938d7fe..1692106 100644 --- a/callback.go +++ b/callback.go @@ -5,12 +5,25 @@ package sqlite3 +// You can't export a Go function to C and have definitions in the C +// preamble in the same file, so we have to have callbackTrampoline in +// its own file. Because we need a separate file anyway, the support +// code for SQLite custom functions is in here. + /* #include + +void _sqlite3_result_text(sqlite3_context* ctx, const char* s); +void _sqlite3_result_blob(sqlite3_context* ctx, const void* b, int l); */ import "C" -import "unsafe" +import ( + "errors" + "fmt" + "reflect" + "unsafe" +) //export callbackTrampoline func callbackTrampoline(ctx *C.sqlite3_context, argc int, argv **C.sqlite3_value) { @@ -18,3 +31,188 @@ func callbackTrampoline(ctx *C.sqlite3_context, argc int, argv **C.sqlite3_value fi := (*functionInfo)(unsafe.Pointer(C.sqlite3_user_data(ctx))) fi.Call(ctx, args) } + +// This is only here so that tests can refer to it. +type callbackArgRaw C.sqlite3_value + +type callbackArgConverter func(*C.sqlite3_value) (reflect.Value, error) + +type callbackArgCast struct { + f callbackArgConverter + typ reflect.Type +} + +func (c callbackArgCast) Run(v *C.sqlite3_value) (reflect.Value, error) { + val, err := c.f(v) + if err != nil { + return reflect.Value{}, err + } + if !val.Type().ConvertibleTo(c.typ) { + return reflect.Value{}, fmt.Errorf("cannot convert %s to %s", val.Type(), c.typ) + } + return val.Convert(c.typ), nil +} + +func callbackArgInt64(v *C.sqlite3_value) (reflect.Value, error) { + if C.sqlite3_value_type(v) != C.SQLITE_INTEGER { + return reflect.Value{}, fmt.Errorf("argument must be an INTEGER") + } + return reflect.ValueOf(int64(C.sqlite3_value_int64(v))), nil +} + +func callbackArgBool(v *C.sqlite3_value) (reflect.Value, error) { + if C.sqlite3_value_type(v) != C.SQLITE_INTEGER { + return reflect.Value{}, fmt.Errorf("argument must be an INTEGER") + } + i := int64(C.sqlite3_value_int64(v)) + val := false + if i != 0 { + val = true + } + return reflect.ValueOf(val), nil +} + +func callbackArgFloat64(v *C.sqlite3_value) (reflect.Value, error) { + if C.sqlite3_value_type(v) != C.SQLITE_FLOAT { + return reflect.Value{}, fmt.Errorf("argument must be a FLOAT") + } + return reflect.ValueOf(float64(C.sqlite3_value_double(v))), nil +} + +func callbackArgBytes(v *C.sqlite3_value) (reflect.Value, error) { + switch C.sqlite3_value_type(v) { + case C.SQLITE_BLOB: + l := C.sqlite3_value_bytes(v) + p := C.sqlite3_value_blob(v) + return reflect.ValueOf(C.GoBytes(p, l)), nil + case C.SQLITE_TEXT: + l := C.sqlite3_value_bytes(v) + c := unsafe.Pointer(C.sqlite3_value_text(v)) + return reflect.ValueOf(C.GoBytes(c, l)), nil + default: + return reflect.Value{}, fmt.Errorf("argument must be BLOB or TEXT") + } +} + +func callbackArgString(v *C.sqlite3_value) (reflect.Value, error) { + switch C.sqlite3_value_type(v) { + case C.SQLITE_BLOB: + l := C.sqlite3_value_bytes(v) + p := (*C.char)(C.sqlite3_value_blob(v)) + return reflect.ValueOf(C.GoStringN(p, l)), nil + case C.SQLITE_TEXT: + c := (*C.char)(unsafe.Pointer(C.sqlite3_value_text(v))) + return reflect.ValueOf(C.GoString(c)), nil + default: + return reflect.Value{}, fmt.Errorf("argument must be BLOB or TEXT") + } +} + +func callbackArg(typ reflect.Type) (callbackArgConverter, error) { + switch typ.Kind() { + case reflect.Slice: + if typ.Elem().Kind() != reflect.Uint8 { + return nil, errors.New("the only supported slice type is []byte") + } + return callbackArgBytes, nil + case reflect.String: + return callbackArgString, nil + case reflect.Bool: + return callbackArgBool, nil + case reflect.Int64: + return callbackArgInt64, nil + case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Int, reflect.Uint: + c := callbackArgCast{callbackArgInt64, typ} + return c.Run, nil + case reflect.Float64: + return callbackArgFloat64, nil + case reflect.Float32: + c := callbackArgCast{callbackArgFloat64, typ} + return c.Run, nil + default: + return nil, fmt.Errorf("don't know how to convert to %s", typ) + } +} + +type callbackRetConverter func(*C.sqlite3_context, reflect.Value) error + +func callbackRetInteger(ctx *C.sqlite3_context, v reflect.Value) error { + switch v.Type().Kind() { + case reflect.Int64: + case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Int, reflect.Uint: + v = v.Convert(reflect.TypeOf(int64(0))) + case reflect.Bool: + b := v.Interface().(bool) + if b { + v = reflect.ValueOf(int64(1)) + } else { + v = reflect.ValueOf(int64(0)) + } + default: + return fmt.Errorf("cannot convert %s to INTEGER", v.Type()) + } + + C.sqlite3_result_int64(ctx, C.sqlite3_int64(v.Interface().(int64))) + return nil +} + +func callbackRetFloat(ctx *C.sqlite3_context, v reflect.Value) error { + switch v.Type().Kind() { + case reflect.Float64: + case reflect.Float32: + v = v.Convert(reflect.TypeOf(float64(0))) + default: + return fmt.Errorf("cannot convert %s to FLOAT", v.Type()) + } + + C.sqlite3_result_double(ctx, C.double(v.Interface().(float64))) + return nil +} + +func callbackRetBlob(ctx *C.sqlite3_context, v reflect.Value) error { + if v.Type().Kind() != reflect.Slice || v.Type().Elem().Kind() != reflect.Uint8 { + return fmt.Errorf("cannot convert %s to BLOB", v.Type()) + } + i := v.Interface() + if i == nil || len(i.([]byte)) == 0 { + C.sqlite3_result_null(ctx) + } else { + bs := i.([]byte) + C._sqlite3_result_blob(ctx, unsafe.Pointer(&bs[0]), C.int(len(bs))) + } + return nil +} + +func callbackRetText(ctx *C.sqlite3_context, v reflect.Value) error { + if v.Type().Kind() != reflect.String { + return fmt.Errorf("cannot convert %s to TEXT", v.Type()) + } + C._sqlite3_result_text(ctx, C.CString(v.Interface().(string))) + return nil +} + +func callbackRet(typ reflect.Type) (callbackRetConverter, error) { + switch typ.Kind() { + case reflect.Slice: + if typ.Elem().Kind() != reflect.Uint8 { + return nil, errors.New("the only supported slice type is []byte") + } + return callbackRetBlob, nil + case reflect.String: + return callbackRetText, nil + case reflect.Bool, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Int, reflect.Uint: + return callbackRetInteger, nil + case reflect.Float32, reflect.Float64: + return callbackRetFloat, nil + default: + return nil, fmt.Errorf("don't know how to convert to %s", typ) + } +} + +// Test support code. Tests are not allowed to import "C", so we can't +// declare any functions that use C.sqlite3_value. +func callbackSyntheticForTests(v reflect.Value, err error) callbackArgConverter { + return func(*C.sqlite3_value) (reflect.Value, error) { + return v, err + } +} diff --git a/callback_test.go b/callback_test.go new file mode 100644 index 0000000..5c61f44 --- /dev/null +++ b/callback_test.go @@ -0,0 +1,97 @@ +package sqlite3 + +import ( + "errors" + "math" + "reflect" + "testing" +) + +func TestCallbackArgCast(t *testing.T) { + intConv := callbackSyntheticForTests(reflect.ValueOf(int64(math.MaxInt64)), nil) + floatConv := callbackSyntheticForTests(reflect.ValueOf(float64(math.MaxFloat64)), nil) + errConv := callbackSyntheticForTests(reflect.Value{}, errors.New("test")) + + tests := []struct { + f callbackArgConverter + o reflect.Value + }{ + {intConv, reflect.ValueOf(int8(-1))}, + {intConv, reflect.ValueOf(int16(-1))}, + {intConv, reflect.ValueOf(int32(-1))}, + {intConv, reflect.ValueOf(uint8(math.MaxUint8))}, + {intConv, reflect.ValueOf(uint16(math.MaxUint16))}, + {intConv, reflect.ValueOf(uint32(math.MaxUint32))}, + // Special case, int64->uint64 is only 1<<63 - 1, not 1<<64 - 1 + {intConv, reflect.ValueOf(uint64(math.MaxInt64))}, + {floatConv, reflect.ValueOf(float32(math.Inf(1)))}, + } + + for _, test := range tests { + conv := callbackArgCast{test.f, test.o.Type()} + val, err := conv.Run(nil) + if err != nil { + t.Errorf("Couldn't convert to %s: %s", test.o.Type(), err) + } else if !reflect.DeepEqual(val.Interface(), test.o.Interface()) { + t.Errorf("Unexpected result from converting to %s: got %v, want %v", test.o.Type(), val.Interface(), test.o.Interface()) + } + } + + conv := callbackArgCast{errConv, reflect.TypeOf(int8(0))} + _, err := conv.Run(nil) + if err == nil { + t.Errorf("Expected error during callbackArgCast, but got none") + } +} + +func TestCallbackConverters(t *testing.T) { + tests := []struct { + v interface{} + err bool + }{ + // Unfortunately, we can't tell which converter was returned, + // but we can at least check which types can be converted. + {[]byte{0}, false}, + {"text", false}, + {true, false}, + {int8(0), false}, + {int16(0), false}, + {int32(0), false}, + {int64(0), false}, + {uint8(0), false}, + {uint16(0), false}, + {uint32(0), false}, + {uint64(0), false}, + {int(0), false}, + {uint(0), false}, + {float64(0), false}, + {float32(0), false}, + + {func() {}, true}, + {complex64(complex(0, 0)), true}, + {complex128(complex(0, 0)), true}, + {struct{}{}, true}, + {map[string]string{}, true}, + {[]string{}, true}, + {(*int8)(nil), true}, + {make(chan int), true}, + } + + for _, test := range tests { + _, err := callbackArg(reflect.TypeOf(test.v)) + if test.err && err == nil { + t.Errorf("Expected an error when converting %s, got no error", reflect.TypeOf(test.v)) + } else if !test.err && err != nil { + t.Errorf("Expected converter when converting %s, got error: %s", reflect.TypeOf(test.v), err) + } + } + + for _, test := range tests { + _, err := callbackRet(reflect.TypeOf(test.v)) + if test.err && err == nil { + t.Errorf("Expected an error when converting %s, got no error", reflect.TypeOf(test.v)) + } else if !test.err && err != nil { + t.Errorf("Expected converter when converting %s, got error: %s", reflect.TypeOf(test.v), err) + } + } +} diff --git a/sqlite3.go b/sqlite3.go index f995589..174a3ee 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -166,7 +166,8 @@ type SQLiteRows struct { type functionInfo struct { f reflect.Value - argConverters []func(*C.sqlite3_value) (reflect.Value, error) + argConverters []callbackArgConverter + retConverter callbackRetConverter } func (fi *functionInfo) error(ctx *C.sqlite3_context, err error) { @@ -193,58 +194,11 @@ func (fi *functionInfo) Call(ctx *C.sqlite3_context, argv []*C.sqlite3_value) { return } - res := ret[0].Interface() - // Normalize ret to one of the types sqlite knows. - switch r := res.(type) { - case int64, float64, []byte, string: - // Already the right type - case bool: - if r { - res = int64(1) - } else { - res = int64(0) - } - case int: - res = int64(r) - case uint: - res = int64(r) - case uint8: - res = int64(r) - case uint16: - res = int64(r) - case uint32: - res = int64(r) - case uint64: - res = int64(r) - case int8: - res = int64(r) - case int16: - res = int64(r) - case int32: - res = int64(r) - case float32: - res = float64(r) - default: - fi.error(ctx, errors.New("cannot convert returned type to sqlite type")) + err := fi.retConverter(ctx, ret[0]) + if err != nil { + fi.error(ctx, err) return } - - switch r := res.(type) { - case int64: - C.sqlite3_result_int64(ctx, C.sqlite3_int64(r)) - case float64: - C.sqlite3_result_double(ctx, C.double(r)) - case []byte: - if len(r) == 0 { - C.sqlite3_result_null(ctx) - } else { - C._sqlite3_result_blob(ctx, unsafe.Pointer(&r[0]), C.int(len(r))) - } - case string: - C._sqlite3_result_text(ctx, C.CString(r)) - default: - panic("unreachable") - } } // Commit transaction. @@ -261,10 +215,10 @@ func (tx *SQLiteTx) Rollback() error { // RegisterFunc makes a Go function available as a SQLite function. // -// The function must accept only arguments of type int64, float64, -// []byte or string, and return one value of any numeric type except -// complex, bool, []byte or string. Optionally, an error can be -// provided as a second return value. +// The function can accept arguments of any real numeric type +// (i.e. not complex), as well as []byte and string. It must return a +// value of one of those types, and optionally an error as a second +// value. // // If pure is true. SQLite will assume that the function's return // value depends only on its inputs, and make more aggressive @@ -287,59 +241,19 @@ func (c *SQLiteConn) RegisterFunc(name string, impl interface{}, pure bool) erro } for i := 0; i < t.NumIn(); i++ { - arg := t.In(i) - var conv func(*C.sqlite3_value) (reflect.Value, error) - switch arg.Kind() { - case reflect.Int64: - conv = func(v *C.sqlite3_value) (reflect.Value, error) { - if C.sqlite3_value_type(v) != C.SQLITE_INTEGER { - return reflect.Value{}, fmt.Errorf("Argument %d to %s must be an INTEGER", i+1, name) - } - return reflect.ValueOf(int64(C.sqlite3_value_int64(v))), nil - } - case reflect.Float64: - conv = func(v *C.sqlite3_value) (reflect.Value, error) { - if C.sqlite3_value_type(v) != C.SQLITE_FLOAT { - return reflect.Value{}, fmt.Errorf("Argument %d to %s must be a FLOAT", i+1, name) - } - return reflect.ValueOf(float64(C.sqlite3_value_double(v))), nil - } - case reflect.Slice: - if arg.Elem().Kind() != reflect.Uint8 { - return errors.New("The only supported slice type is []byte") - } - conv = func(v *C.sqlite3_value) (reflect.Value, error) { - switch C.sqlite3_value_type(v) { - case C.SQLITE_BLOB: - l := C.sqlite3_value_bytes(v) - p := C.sqlite3_value_blob(v) - return reflect.ValueOf(C.GoBytes(p, l)), nil - case C.SQLITE_TEXT: - l := C.sqlite3_value_bytes(v) - c := unsafe.Pointer(C.sqlite3_value_text(v)) - return reflect.ValueOf(C.GoBytes(c, l)), nil - default: - return reflect.Value{}, fmt.Errorf("Argument %d to %s must be BLOB or TEXT", i+1, name) - } - } - case reflect.String: - conv = func(v *C.sqlite3_value) (reflect.Value, error) { - switch C.sqlite3_value_type(v) { - case C.SQLITE_BLOB: - l := C.sqlite3_value_bytes(v) - p := (*C.char)(C.sqlite3_value_blob(v)) - return reflect.ValueOf(C.GoStringN(p, l)), nil - case C.SQLITE_TEXT: - c := (*C.char)(unsafe.Pointer(C.sqlite3_value_text(v))) - return reflect.ValueOf(C.GoString(c)), nil - default: - return reflect.Value{}, fmt.Errorf("Argument %d to %s must be BLOB or TEXT", i+1, name) - } - } + conv, err := callbackArg(t.In(i)) + if err != nil { + return err } fi.argConverters = append(fi.argConverters, conv) } + conv, err := callbackRet(t.Out(0)) + if err != nil { + return err + } + fi.retConverter = conv + // fi must outlast the database connection, or we'll have dangling pointers. c.funcs = append(c.funcs, &fi) diff --git a/sqlite3_test.go b/sqlite3_test.go index a58e373..e8dfe5c 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -15,6 +15,7 @@ import ( "net/url" "os" "path/filepath" + "reflect" "regexp" "strings" "sync" @@ -1060,25 +1061,41 @@ func TestDateTimeNow(t *testing.T) { } func TestFunctionRegistration(t *testing.T) { - custom_add := func(a, b int64) (int64, error) { - return a + b, nil - } - custom_regex := func(s, re string) bool { - matched, err := regexp.MatchString(re, s) - if err != nil { - // We should really return the error here, but this - // function is also testing single return value functions. - panic("Bad regexp") - } - return matched + addi_8_16_32 := func(a int8, b int16) int32 { return int32(a) + int32(b) } + addi_64 := func(a, b int64) int64 { return a + b } + addu_8_16_32 := func(a uint8, b uint16) uint32 { return uint32(a) + uint32(b) } + addu_64 := func(a, b uint64) uint64 { return a + b } + addiu := func(a int, b uint) int64 { return int64(a) + int64(b) } + addf_32_64 := func(a float32, b float64) float64 { return float64(a) + b } + not := func(a bool) bool { return !a } + regex := func(re, s string) (bool, error) { + return regexp.MatchString(re, s) } sql.Register("sqlite3_FunctionRegistration", &SQLiteDriver{ ConnectHook: func(conn *SQLiteConn) error { - if err := conn.RegisterFunc("custom_add", custom_add, true); err != nil { + if err := conn.RegisterFunc("addi_8_16_32", addi_8_16_32, true); err != nil { + return err + } + if err := conn.RegisterFunc("addi_64", addi_64, true); err != nil { + return err + } + if err := conn.RegisterFunc("addu_8_16_32", addu_8_16_32, true); err != nil { return err } - if err := conn.RegisterFunc("regexp", custom_regex, true); err != nil { + if err := conn.RegisterFunc("addu_64", addu_64, true); err != nil { + return err + } + if err := conn.RegisterFunc("addiu", addiu, true); err != nil { + return err + } + if err := conn.RegisterFunc("addf_32_64", addf_32_64, true); err != nil { + return err + } + if err := conn.RegisterFunc("not", not, true); err != nil { + return err + } + if err := conn.RegisterFunc("regex", regex, true); err != nil { return err } return nil @@ -1090,42 +1107,29 @@ func TestFunctionRegistration(t *testing.T) { } defer db.Close() - additions := []struct { - a, b, c int64 + ops := []struct { + query string + expected interface{} }{ - {1, 1, 2}, - {1, 3, 4}, - {1, -1, 0}, - } - - for _, add := range additions { - var i int64 - err = db.QueryRow("SELECT custom_add($1, $2)", add.a, add.b).Scan(&i) + {"SELECT addi_8_16_32(1,2)", int32(3)}, + {"SELECT addi_64(1,2)", int64(3)}, + {"SELECT addu_8_16_32(1,2)", uint32(3)}, + {"SELECT addu_64(1,2)", uint64(3)}, + {"SELECT addiu(1,2)", int64(3)}, + {"SELECT addf_32_64(1.5,1.5)", float64(3)}, + {"SELECT not(1)", false}, + {"SELECT not(0)", true}, + {`SELECT regex("^foo.*", "foobar")`, true}, + {`SELECT regex("^foo.*", "barfoobar")`, false}, + } + + for _, op := range ops { + ret := reflect.New(reflect.TypeOf(op.expected)) + err = db.QueryRow(op.query).Scan(ret.Interface()) if err != nil { - t.Fatal("Failed to call custom_add:", err) - } - if i != add.c { - t.Fatalf("custom_add returned the wrong value, got %d, want %d", i, add.c) - } - } - - regexes := []struct { - re, in string - out bool - }{ - {".*", "foo", true}, - {"^foo.*", "foobar", true}, - {"^foo.*", "barfoo", false}, - } - - for _, re := range regexes { - var b bool - err = db.QueryRow("SELECT regexp($1, $2)", re.in, re.re).Scan(&b) - if err != nil { - t.Fatal("Failed to call regexp:", err) - } - if b != re.out { - t.Fatalf("regexp returned the wrong value, got %v, want %v", b, re.out) + t.Errorf("Query %q failed: %s", op.query, err) + } else if !reflect.DeepEqual(ret.Elem().Interface(), op.expected) { + t.Errorf("Query %q returned wrong value: got %v (%T), want %v (%T)", op.query, ret.Elem().Interface(), ret.Elem().Interface(), op.expected, op.expected) } } } @@ -1134,8 +1138,8 @@ var customFunctionOnce sync.Once func BenchmarkCustomFunctions(b *testing.B) { customFunctionOnce.Do(func() { - custom_add := func(a, b int64) (int64, error) { - return a + b, nil + custom_add := func(a, b int64) int64 { + return a + b } sql.Register("sqlite3_BenchmarkCustomFunctions", &SQLiteDriver{ -- cgit v1.2.3 From 566f63a43a314f8dcd758dba8c40dc11edc27a5e Mon Sep 17 00:00:00 2001 From: David Anderson Date: Fri, 21 Aug 2015 16:34:55 -0700 Subject: Implement support for variadic functions. Currently, the variadic part must all be the same type, because there's no "generic" arg converter. --- sqlite3.go | 52 ++++++++++++++++++++++++++++++++++++++++++---------- sqlite3_test.go | 13 +++++++++++++ 2 files changed, 55 insertions(+), 10 deletions(-) (limited to 'sqlite3_test.go') diff --git a/sqlite3.go b/sqlite3.go index 174a3ee..8bb9826 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -165,9 +165,10 @@ type SQLiteRows struct { } type functionInfo struct { - f reflect.Value - argConverters []callbackArgConverter - retConverter callbackRetConverter + f reflect.Value + argConverters []callbackArgConverter + variadicConverter callbackArgConverter + retConverter callbackRetConverter } func (fi *functionInfo) error(ctx *C.sqlite3_context, err error) { @@ -178,7 +179,12 @@ func (fi *functionInfo) error(ctx *C.sqlite3_context, err error) { func (fi *functionInfo) Call(ctx *C.sqlite3_context, argv []*C.sqlite3_value) { var args []reflect.Value - for i, arg := range argv { + + if len(argv) < len(fi.argConverters) { + fi.error(ctx, fmt.Errorf("function requires at least %d arguments", len(fi.argConverters))) + } + + for i, arg := range argv[:len(fi.argConverters)] { v, err := fi.argConverters[i](arg) if err != nil { fi.error(ctx, err) @@ -187,6 +193,17 @@ func (fi *functionInfo) Call(ctx *C.sqlite3_context, argv []*C.sqlite3_value) { args = append(args, v) } + if fi.variadicConverter != nil { + for _, arg := range argv[len(fi.argConverters):] { + v, err := fi.variadicConverter(arg) + if err != nil { + fi.error(ctx, err) + return + } + args = append(args, v) + } + } + ret := fi.f.Call(args) if len(ret) == 2 && ret[1].Interface() != nil { @@ -218,7 +235,8 @@ func (tx *SQLiteTx) Rollback() error { // The function can accept arguments of any real numeric type // (i.e. not complex), as well as []byte and string. It must return a // value of one of those types, and optionally an error as a second -// value. +// value. Variadic functions are allowed, if the variadic argument is +// one of the allowed types. // // If pure is true. SQLite will assume that the function's return // value depends only on its inputs, and make more aggressive @@ -230,9 +248,6 @@ func (c *SQLiteConn) RegisterFunc(name string, impl interface{}, pure bool) erro if t.Kind() != reflect.Func { return errors.New("Non-function passed to RegisterFunc") } - if t.IsVariadic() { - return errors.New("Variadic SQLite functions are not supported") - } if t.NumOut() != 1 && t.NumOut() != 2 { return errors.New("SQLite functions must return 1 or 2 values") } @@ -240,7 +255,12 @@ func (c *SQLiteConn) RegisterFunc(name string, impl interface{}, pure bool) erro return errors.New("Second return value of SQLite function must be error") } - for i := 0; i < t.NumIn(); i++ { + numArgs := t.NumIn() + if t.IsVariadic() { + numArgs-- + } + + for i := 0; i < numArgs; i++ { conv, err := callbackArg(t.In(i)) if err != nil { return err @@ -248,6 +268,18 @@ func (c *SQLiteConn) RegisterFunc(name string, impl interface{}, pure bool) erro fi.argConverters = append(fi.argConverters, conv) } + if t.IsVariadic() { + conv, err := callbackArg(t.In(numArgs).Elem()) + if err != nil { + return err + } + fi.variadicConverter = conv + // Pass -1 to sqlite so that it allows any number of + // arguments. The call helper verifies that the minimum number + // of arguments is present for variadic functions. + numArgs = -1 + } + conv, err := callbackRet(t.Out(0)) if err != nil { return err @@ -263,7 +295,7 @@ func (c *SQLiteConn) RegisterFunc(name string, impl interface{}, pure bool) erro if pure { opts |= C.SQLITE_DETERMINISTIC } - rv := C.sqlite3_create_function_v2(c.db, cname, C.int(t.NumIn()), C.int(opts), unsafe.Pointer(&fi), (*[0]byte)(unsafe.Pointer(C.callbackTrampoline)), nil, nil, nil) + rv := C.sqlite3_create_function_v2(c.db, cname, C.int(numArgs), C.int(opts), unsafe.Pointer(&fi), (*[0]byte)(unsafe.Pointer(C.callbackTrampoline)), nil, nil, nil) if rv != C.SQLITE_OK { return c.lastError() } diff --git a/sqlite3_test.go b/sqlite3_test.go index e8dfe5c..a563c08 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -1071,6 +1071,13 @@ func TestFunctionRegistration(t *testing.T) { regex := func(re, s string) (bool, error) { return regexp.MatchString(re, s) } + variadic := func(a, b int64, c ...int64) int64 { + ret := a + b + for _, d := range c { + ret += d + } + return ret + } sql.Register("sqlite3_FunctionRegistration", &SQLiteDriver{ ConnectHook: func(conn *SQLiteConn) error { @@ -1098,6 +1105,9 @@ func TestFunctionRegistration(t *testing.T) { if err := conn.RegisterFunc("regex", regex, true); err != nil { return err } + if err := conn.RegisterFunc("variadic", variadic, true); err != nil { + return err + } return nil }, }) @@ -1121,6 +1131,9 @@ func TestFunctionRegistration(t *testing.T) { {"SELECT not(0)", true}, {`SELECT regex("^foo.*", "foobar")`, true}, {`SELECT regex("^foo.*", "barfoobar")`, false}, + {"SELECT variadic(1,2)", int64(3)}, + {"SELECT variadic(1,2,3,4)", int64(10)}, + {"SELECT variadic(1,1,1,1,1,1,1,1,1,1)", int64(10)}, } for _, op := range ops { -- cgit v1.2.3 From b037a616903746de8e647f53503d4edca29192ec Mon Sep 17 00:00:00 2001 From: David Anderson Date: Fri, 21 Aug 2015 17:12:18 -0700 Subject: Add support for interface{} arguments in Go SQLite functions. This enabled support for functions like Foo(a interface{}) and Bar(a ...interface{}). --- callback.go | 24 ++++++++++++++++++++++++ sqlite3.go | 13 ++++++++----- sqlite3_test.go | 28 ++++++++++++++++++++++++++++ 3 files changed, 60 insertions(+), 5 deletions(-) (limited to 'sqlite3_test.go') diff --git a/callback.go b/callback.go index 1692106..b1704fe 100644 --- a/callback.go +++ b/callback.go @@ -108,8 +108,32 @@ func callbackArgString(v *C.sqlite3_value) (reflect.Value, error) { } } +func callbackArgGeneric(v *C.sqlite3_value) (reflect.Value, error) { + switch C.sqlite3_value_type(v) { + case C.SQLITE_INTEGER: + return callbackArgInt64(v) + case C.SQLITE_FLOAT: + return callbackArgFloat64(v) + case C.SQLITE_TEXT: + return callbackArgString(v) + case C.SQLITE_BLOB: + return callbackArgBytes(v) + case C.SQLITE_NULL: + // Interpret NULL as a nil byte slice. + var ret []byte + return reflect.ValueOf(ret), nil + default: + panic("unreachable") + } +} + func callbackArg(typ reflect.Type) (callbackArgConverter, error) { switch typ.Kind() { + case reflect.Interface: + if typ.NumMethod() != 0 { + return nil, errors.New("the only supported interface type is interface{}") + } + return callbackArgGeneric, nil case reflect.Slice: if typ.Elem().Kind() != reflect.Uint8 { return nil, errors.New("the only supported slice type is []byte") diff --git a/sqlite3.go b/sqlite3.go index 8bb9826..73e67e3 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -232,11 +232,14 @@ func (tx *SQLiteTx) Rollback() error { // RegisterFunc makes a Go function available as a SQLite function. // -// The function can accept arguments of any real numeric type -// (i.e. not complex), as well as []byte and string. It must return a -// value of one of those types, and optionally an error as a second -// value. Variadic functions are allowed, if the variadic argument is -// one of the allowed types. +// The Go function can have arguments of the following types: any +// numeric type except complex, bool, []byte, string and +// interface{}. interface{} arguments are given the direct translation +// of the SQLite data type: int64 for INTEGER, float64 for FLOAT, +// []byte for BLOB, string for TEXT. +// +// The function can additionally be variadic, as long as the type of +// the variadic argument is one of the above. // // If pure is true. SQLite will assume that the function's return // value depends only on its inputs, and make more aggressive diff --git a/sqlite3_test.go b/sqlite3_test.go index a563c08..62db05b 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -1071,6 +1071,20 @@ func TestFunctionRegistration(t *testing.T) { regex := func(re, s string) (bool, error) { return regexp.MatchString(re, s) } + generic := func(a interface{}) int64 { + switch a.(type) { + case int64: + return 1 + case float64: + return 2 + case []byte: + return 3 + case string: + return 4 + default: + panic("unreachable") + } + } variadic := func(a, b int64, c ...int64) int64 { ret := a + b for _, d := range c { @@ -1078,6 +1092,9 @@ func TestFunctionRegistration(t *testing.T) { } return ret } + variadicGeneric := func(a ...interface{}) int64 { + return int64(len(a)) + } sql.Register("sqlite3_FunctionRegistration", &SQLiteDriver{ ConnectHook: func(conn *SQLiteConn) error { @@ -1105,9 +1122,15 @@ func TestFunctionRegistration(t *testing.T) { if err := conn.RegisterFunc("regex", regex, true); err != nil { return err } + if err := conn.RegisterFunc("generic", generic, true); err != nil { + return err + } if err := conn.RegisterFunc("variadic", variadic, true); err != nil { return err } + if err := conn.RegisterFunc("variadicGeneric", variadicGeneric, true); err != nil { + return err + } return nil }, }) @@ -1131,9 +1154,14 @@ func TestFunctionRegistration(t *testing.T) { {"SELECT not(0)", true}, {`SELECT regex("^foo.*", "foobar")`, true}, {`SELECT regex("^foo.*", "barfoobar")`, false}, + {"SELECT generic(1)", int64(1)}, + {"SELECT generic(1.1)", int64(2)}, + {`SELECT generic(NULL)`, int64(3)}, + {`SELECT generic("foo")`, int64(4)}, {"SELECT variadic(1,2)", int64(3)}, {"SELECT variadic(1,2,3,4)", int64(10)}, {"SELECT variadic(1,1,1,1,1,1,1,1,1,1)", int64(10)}, + {`SELECT variadicGeneric(1,"foo",2.3, NULL)`, int64(4)}, } for _, op := range ops { -- cgit v1.2.3 From 296ddf7cd7934b7929ec495757f11a65eafdb215 Mon Sep 17 00:00:00 2001 From: mattn Date: Tue, 25 Aug 2015 23:40:01 +0900 Subject: Fix test. Close #216 When one goroutine close db that opended as :memory:, session will be lost. So another goroutine can't refer the last session. goroutine . --- sqlite3_test.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) (limited to 'sqlite3_test.go') diff --git a/sqlite3_test.go b/sqlite3_test.go index 423f30e..9760f01 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -769,10 +769,12 @@ func TestTimezoneConversion(t *testing.T) { } func TestSuite(t *testing.T) { - db, err := sql.Open("sqlite3", ":memory:") + tempFilename := TempFilename() + db, err := sql.Open("sqlite3", tempFilename+"?_busy_timeout=99999") if err != nil { t.Fatal(err) } + defer os.Remove(tempFilename) defer db.Close() sqlite3_test.RunTests(t, db, sqlite3_test.SQLITE) -- cgit v1.2.3 From 26917df7a6010a157123c4bf60e3d57eff2948e4 Mon Sep 17 00:00:00 2001 From: David Anderson Date: Fri, 21 Aug 2015 20:31:41 -0700 Subject: Implement support for aggregation functions implemented in Go. --- _example/go_custom_funcs/go_custom_funcs | Bin 0 -> 6601208 bytes _example/go_custom_funcs/main.go | 133 +++++++++++++++++ callback.go | 47 ++++++ sqlite3.go | 243 ++++++++++++++++++++++++++----- sqlite3_test.go | 59 ++++++++ 5 files changed, 449 insertions(+), 33 deletions(-) create mode 100755 _example/go_custom_funcs/go_custom_funcs create mode 100644 _example/go_custom_funcs/main.go (limited to 'sqlite3_test.go') diff --git a/_example/go_custom_funcs/go_custom_funcs b/_example/go_custom_funcs/go_custom_funcs new file mode 100755 index 0000000..b6be764 Binary files /dev/null and b/_example/go_custom_funcs/go_custom_funcs differ diff --git a/_example/go_custom_funcs/main.go b/_example/go_custom_funcs/main.go new file mode 100644 index 0000000..85657e6 --- /dev/null +++ b/_example/go_custom_funcs/main.go @@ -0,0 +1,133 @@ +package main + +import ( + "database/sql" + "fmt" + "log" + "math" + "math/rand" + + sqlite "github.com/mattn/go-sqlite3" +) + +// Computes x^y +func pow(x, y int64) int64 { + return int64(math.Pow(float64(x), float64(y))) +} + +// Computes the bitwise exclusive-or of all its arguments +func xor(xs ...int64) int64 { + var ret int64 + for _, x := range xs { + ret ^= x + } + return ret +} + +// Returns a random number. It's actually deterministic here because +// we don't seed the RNG, but it's an example of a non-pure function +// from SQLite's POV. +func getrand() int64 { + return rand.Int63() +} + +// Computes the standard deviation of a GROUPed BY set of values +type stddev struct { + xs []int64 + // Running average calculation + sum int64 + n int64 +} + +func newStddev() *stddev { return &stddev{} } + +func (s *stddev) Step(x int64) { + s.xs = append(s.xs, x) + s.sum += x + s.n++ +} + +func (s *stddev) Done() float64 { + mean := float64(s.sum) / float64(s.n) + var sqDiff []float64 + for _, x := range s.xs { + sqDiff = append(sqDiff, math.Pow(float64(x)-mean, 2)) + } + var dev float64 + for _, x := range sqDiff { + dev += x + } + dev /= float64(len(sqDiff)) + return math.Sqrt(dev) +} + +func main() { + sql.Register("sqlite3_custom", &sqlite.SQLiteDriver{ + ConnectHook: func(conn *sqlite.SQLiteConn) error { + if err := conn.RegisterFunc("pow", pow, true); err != nil { + return err + } + if err := conn.RegisterFunc("xor", xor, true); err != nil { + return err + } + if err := conn.RegisterFunc("rand", getrand, false); err != nil { + return err + } + if err := conn.RegisterAggregator("stddev", newStddev, true); err != nil { + return err + } + return nil + }, + }) + + db, err := sql.Open("sqlite3_custom", ":memory:") + if err != nil { + log.Fatal("Failed to open database:", err) + } + defer db.Close() + + var i int64 + err = db.QueryRow("SELECT pow(2,3)").Scan(&i) + if err != nil { + log.Fatal("POW query error:", err) + } + fmt.Println("pow(2,3) =", i) // 8 + + err = db.QueryRow("SELECT xor(1,2,3,4,5,6)").Scan(&i) + if err != nil { + log.Fatal("XOR query error:", err) + } + fmt.Println("xor(1,2,3,4,5) =", i) // 7 + + err = db.QueryRow("SELECT rand()").Scan(&i) + if err != nil { + log.Fatal("RAND query error:", err) + } + fmt.Println("rand() =", i) // pseudorandom + + _, err = db.Exec("create table foo (department integer, profits integer)") + if err != nil { + log.Fatal("Failed to create table:", err) + } + _, err = db.Exec("insert into foo values (1, 10), (1, 20), (1, 45), (2, 42), (2, 115)") + if err != nil { + log.Fatal("Failed to insert records:", err) + } + + rows, err := db.Query("select department, stddev(profits) from foo group by department") + if err != nil { + log.Fatal("STDDEV query error:", err) + } + defer rows.Close() + for rows.Next() { + var dept int64 + var dev float64 + if err := rows.Scan(&dept, &dev); err != nil { + log.Fatal(err) + } + fmt.Printf("dept=%d stddev=%f\n", dept, dev) + } + if err := rows.Err(); err != nil { + log.Fatal(err) + } +} diff --git a/callback.go b/callback.go index b1704fe..61fc8d1 100644 --- a/callback.go +++ b/callback.go @@ -12,6 +12,7 @@ package sqlite3 /* #include +#include void _sqlite3_result_text(sqlite3_context* ctx, const char* s); void _sqlite3_result_blob(sqlite3_context* ctx, const void* b, int l); @@ -32,6 +33,19 @@ func callbackTrampoline(ctx *C.sqlite3_context, argc int, argv **C.sqlite3_value fi.Call(ctx, args) } +//export stepTrampoline +func stepTrampoline(ctx *C.sqlite3_context, argc int, argv **C.sqlite3_value) { + args := (*[1 << 30]*C.sqlite3_value)(unsafe.Pointer(argv))[:argc:argc] + ai := (*aggInfo)(unsafe.Pointer(C.sqlite3_user_data(ctx))) + ai.Step(ctx, args) +} + +//export doneTrampoline +func doneTrampoline(ctx *C.sqlite3_context) { + ai := (*aggInfo)(unsafe.Pointer(C.sqlite3_user_data(ctx))) + ai.Done(ctx) +} + // This is only here so that tests can refer to it. type callbackArgRaw C.sqlite3_value @@ -158,6 +172,33 @@ func callbackArg(typ reflect.Type) (callbackArgConverter, error) { } } +func callbackConvertArgs(argv []*C.sqlite3_value, converters []callbackArgConverter, variadic callbackArgConverter) ([]reflect.Value, error) { + var args []reflect.Value + + if len(argv) < len(converters) { + return nil, fmt.Errorf("function requires at least %d arguments", len(converters)) + } + + for i, arg := range argv[:len(converters)] { + v, err := converters[i](arg) + if err != nil { + return nil, err + } + args = append(args, v) + } + + if variadic != nil { + for _, arg := range argv[len(converters):] { + v, err := variadic(arg) + if err != nil { + return nil, err + } + args = append(args, v) + } + } + return args, nil +} + type callbackRetConverter func(*C.sqlite3_context, reflect.Value) error func callbackRetInteger(ctx *C.sqlite3_context, v reflect.Value) error { @@ -233,6 +274,12 @@ func callbackRet(typ reflect.Type) (callbackRetConverter, error) { } } +func callbackError(ctx *C.sqlite3_context, err error) { + cstr := C.CString(err.Error()) + defer C.free(unsafe.Pointer(cstr)) + C.sqlite3_result_error(ctx, cstr, -1) +} + // Test support code. Tests are not allowed to import "C", so we can't // declare any functions that use C.sqlite3_value. func callbackSyntheticForTests(v reflect.Value, err error) callbackArgConverter { diff --git a/sqlite3.go b/sqlite3.go index 73e67e3..8d2faca 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -75,6 +75,8 @@ void _sqlite3_result_blob(sqlite3_context* ctx, const void* b, int l) { } void callbackTrampoline(sqlite3_context*, int, sqlite3_value**); +void stepTrampoline(sqlite3_context*, int, sqlite3_value**); +void doneTrampoline(sqlite3_context*); */ import "C" import ( @@ -127,10 +129,11 @@ type SQLiteDriver struct { // Conn struct. type SQLiteConn struct { - db *C.sqlite3 - loc *time.Location - txlock string - funcs []*functionInfo + db *C.sqlite3 + loc *time.Location + txlock string + funcs []*functionInfo + aggregators []*aggInfo } // Tx struct. @@ -171,49 +174,96 @@ type functionInfo struct { retConverter callbackRetConverter } -func (fi *functionInfo) error(ctx *C.sqlite3_context, err error) { - cstr := C.CString(err.Error()) - defer C.free(unsafe.Pointer(cstr)) - C.sqlite3_result_error(ctx, cstr, -1) -} - func (fi *functionInfo) Call(ctx *C.sqlite3_context, argv []*C.sqlite3_value) { - var args []reflect.Value + args, err := callbackConvertArgs(argv, fi.argConverters, fi.variadicConverter) + if err != nil { + callbackError(ctx, err) + return + } + + ret := fi.f.Call(args) - if len(argv) < len(fi.argConverters) { - fi.error(ctx, fmt.Errorf("function requires at least %d arguments", len(fi.argConverters))) + if len(ret) == 2 && ret[1].Interface() != nil { + callbackError(ctx, ret[1].Interface().(error)) + return } - for i, arg := range argv[:len(fi.argConverters)] { - v, err := fi.argConverters[i](arg) - if err != nil { - fi.error(ctx, err) - return - } - args = append(args, v) + err = fi.retConverter(ctx, ret[0]) + if err != nil { + callbackError(ctx, err) + return } +} - if fi.variadicConverter != nil { - for _, arg := range argv[len(fi.argConverters):] { - v, err := fi.variadicConverter(arg) - if err != nil { - fi.error(ctx, err) - return - } - args = append(args, v) +type aggInfo struct { + constructor reflect.Value + + // Active aggregator objects for aggregations in flight. The + // aggregators are indexed by a counter stored in the aggregation + // user data space provided by sqlite. + active map[int64]reflect.Value + next int64 + + stepArgConverters []callbackArgConverter + stepVariadicConverter callbackArgConverter + + doneRetConverter callbackRetConverter +} + +func (ai *aggInfo) agg(ctx *C.sqlite3_context) (int64, reflect.Value, error) { + aggIdx := (*int64)(C.sqlite3_aggregate_context(ctx, C.int(8))) + if *aggIdx == 0 { + *aggIdx = ai.next + ret := ai.constructor.Call(nil) + if len(ret) == 2 && ret[1].Interface() != nil { + return 0, reflect.Value{}, ret[1].Interface().(error) } + if ret[0].IsNil() { + return 0, reflect.Value{}, errors.New("aggregator constructor returned nil state") + } + ai.next++ + ai.active[*aggIdx] = ret[0] } + return *aggIdx, ai.active[*aggIdx], nil +} - ret := fi.f.Call(args) +func (ai *aggInfo) Step(ctx *C.sqlite3_context, argv []*C.sqlite3_value) { + _, agg, err := ai.agg(ctx) + if err != nil { + callbackError(ctx, err) + return + } + + args, err := callbackConvertArgs(argv, ai.stepArgConverters, ai.stepVariadicConverter) + if err != nil { + callbackError(ctx, err) + return + } + + ret := agg.MethodByName("Step").Call(args) + if len(ret) == 1 && ret[0].Interface() != nil { + callbackError(ctx, ret[0].Interface().(error)) + return + } +} + +func (ai *aggInfo) Done(ctx *C.sqlite3_context) { + idx, agg, err := ai.agg(ctx) + if err != nil { + callbackError(ctx, err) + return + } + defer func() { delete(ai.active, idx) }() + ret := agg.MethodByName("Done").Call(nil) if len(ret) == 2 && ret[1].Interface() != nil { - fi.error(ctx, ret[1].Interface().(error)) + callbackError(ctx, ret[1].Interface().(error)) return } - err := fi.retConverter(ctx, ret[0]) + err = ai.doneRetConverter(ctx, ret[0]) if err != nil { - fi.error(ctx, err) + callbackError(ctx, err) return } } @@ -244,6 +294,8 @@ func (tx *SQLiteTx) Rollback() error { // If pure is true. SQLite will assume that the function's return // value depends only on its inputs, and make more aggressive // optimizations in its queries. +// +// See _example/go_custom_funcs for a detailed example. func (c *SQLiteConn) RegisterFunc(name string, impl interface{}, pure bool) error { var fi functionInfo fi.f = reflect.ValueOf(impl) @@ -298,7 +350,132 @@ func (c *SQLiteConn) RegisterFunc(name string, impl interface{}, pure bool) erro if pure { opts |= C.SQLITE_DETERMINISTIC } - rv := C.sqlite3_create_function_v2(c.db, cname, C.int(numArgs), C.int(opts), unsafe.Pointer(&fi), (*[0]byte)(unsafe.Pointer(C.callbackTrampoline)), nil, nil, nil) + rv := C.sqlite3_create_function(c.db, cname, C.int(numArgs), C.int(opts), unsafe.Pointer(&fi), (*[0]byte)(unsafe.Pointer(C.callbackTrampoline)), nil, nil) + if rv != C.SQLITE_OK { + return c.lastError() + } + return nil +} + +// RegisterAggregator makes a Go type available as a SQLite aggregation function. +// +// Because aggregation is incremental, it's implemented in Go with a +// type that has 2 methods: func Step(values) accumulates one row of +// data into the accumulator, and func Done() ret finalizes and +// returns the aggregate value. "values" and "ret" may be any type +// supported by RegisterFunc. +// +// RegisterAggregator takes as implementation a constructor function +// that constructs an instance of the aggregator type each time an +// aggregation begins. The constructor must return a pointer to a +// type, or an interface that implements Step() and Done(). +// +// The constructor function and the Step/Done methods may optionally +// return an error in addition to their other return values. +// +// See _example/go_custom_funcs for a detailed example. +func (c *SQLiteConn) RegisterAggregator(name string, impl interface{}, pure bool) error { + var ai aggInfo + ai.constructor = reflect.ValueOf(impl) + t := ai.constructor.Type() + if t.Kind() != reflect.Func { + return errors.New("non-function passed to RegisterAggregator") + } + if t.NumOut() != 1 && t.NumOut() != 2 { + return errors.New("SQLite aggregator constructors must return 1 or 2 values") + } + if t.NumOut() == 2 && !t.Out(1).Implements(reflect.TypeOf((*error)(nil)).Elem()) { + return errors.New("Second return value of SQLite function must be error") + } + if t.NumIn() != 0 { + return errors.New("SQLite aggregator constructors must not have arguments") + } + + agg := t.Out(0) + switch agg.Kind() { + case reflect.Ptr, reflect.Interface: + default: + return errors.New("SQlite aggregator constructor must return a pointer object") + } + stepFn, found := agg.MethodByName("Step") + if !found { + return errors.New("SQlite aggregator doesn't have a Step() function") + } + step := stepFn.Type + if step.NumOut() != 0 && step.NumOut() != 1 { + return errors.New("SQlite aggregator Step() function must return 0 or 1 values") + } + if step.NumOut() == 1 && !step.Out(0).Implements(reflect.TypeOf((*error)(nil)).Elem()) { + return errors.New("type of SQlite aggregator Step() return value must be error") + } + + stepNArgs := step.NumIn() + start := 0 + if agg.Kind() == reflect.Ptr { + // Skip over the method receiver + stepNArgs-- + start++ + } + if step.IsVariadic() { + stepNArgs-- + } + for i := start; i < start+stepNArgs; i++ { + conv, err := callbackArg(step.In(i)) + if err != nil { + return err + } + ai.stepArgConverters = append(ai.stepArgConverters, conv) + } + if step.IsVariadic() { + conv, err := callbackArg(t.In(start + stepNArgs).Elem()) + if err != nil { + return err + } + ai.stepVariadicConverter = conv + // Pass -1 to sqlite so that it allows any number of + // arguments. The call helper verifies that the minimum number + // of arguments is present for variadic functions. + stepNArgs = -1 + } + + doneFn, found := agg.MethodByName("Done") + if !found { + return errors.New("SQlite aggregator doesn't have a Done() function") + } + done := doneFn.Type + doneNArgs := done.NumIn() + if agg.Kind() == reflect.Ptr { + // Skip over the method receiver + doneNArgs-- + } + if doneNArgs != 0 { + return errors.New("SQlite aggregator Done() function must have no arguments") + } + if done.NumOut() != 1 && done.NumOut() != 2 { + return errors.New("SQLite aggregator Done() function must return 1 or 2 values") + } + if done.NumOut() == 2 && !done.Out(1).Implements(reflect.TypeOf((*error)(nil)).Elem()) { + return errors.New("second return value of SQLite aggregator Done() function must be error") + } + + conv, err := callbackRet(done.Out(0)) + if err != nil { + return err + } + ai.doneRetConverter = conv + ai.active = make(map[int64]reflect.Value) + ai.next = 1 + + // ai must outlast the database connection, or we'll have dangling pointers. + c.aggregators = append(c.aggregators, &ai) + + cname := C.CString(name) + defer C.free(unsafe.Pointer(cname)) + opts := C.SQLITE_UTF8 + if pure { + opts |= C.SQLITE_DETERMINISTIC + } + rv := C.sqlite3_create_function(c.db, cname, C.int(stepNArgs), C.int(opts), unsafe.Pointer(&ai), nil, (*[0]byte)(unsafe.Pointer(C.stepTrampoline)), (*[0]byte)(unsafe.Pointer(C.doneTrampoline))) if rv != C.SQLITE_OK { return c.lastError() } diff --git a/sqlite3_test.go b/sqlite3_test.go index 62db05b..74d3de1 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -1175,6 +1175,65 @@ func TestFunctionRegistration(t *testing.T) { } } +type sumAggregator int64 + +func (s *sumAggregator) Step(x int64) { + *s += sumAggregator(x) +} + +func (s *sumAggregator) Done() int64 { + return int64(*s) +} + +func TestAggregatorRegistration(t *testing.T) { + customSum := func() *sumAggregator { + var ret sumAggregator + return &ret + } + + sql.Register("sqlite3_AggregatorRegistration", &SQLiteDriver{ + ConnectHook: func(conn *SQLiteConn) error { + if err := conn.RegisterAggregator("customSum", customSum, true); err != nil { + return err + } + return nil + }, + }) + db, err := sql.Open("sqlite3_AggregatorRegistration", ":memory:") + if err != nil { + t.Fatal("Failed to open database:", err) + } + defer db.Close() + + _, err = db.Exec("create table foo (department integer, profits integer)") + if err != nil { + t.Fatal("Failed to create table:", err) + } + + _, err = db.Exec("insert into foo values (1, 10), (1, 20), (2, 42)") + if err != nil { + t.Fatal("Failed to insert records:", err) + } + + tests := []struct { + dept, sum int64 + }{ + {1, 30}, + {2, 42}, + } + + for _, test := range tests { + var ret int64 + err = db.QueryRow("select customSum(profits) from foo where department = $1 group by department", test.dept).Scan(&ret) + if err != nil { + t.Fatal("Query failed:", err) + } + if ret != test.sum { + t.Fatalf("Custom sum returned wrong value, got %d, want %d", ret, test.sum) + } + } +} + var customFunctionOnce sync.Once func BenchmarkCustomFunctions(b *testing.B) { -- cgit v1.2.3 From 7b0d180ce9aa2473631426dbc011e34b435d5c65 Mon Sep 17 00:00:00 2001 From: Augusto Roman Date: Fri, 9 Oct 2015 22:59:25 -0700 Subject: Store/retrieve timezones for time.Time values. Previously, the timezone information for a provided value was discarded and the value always stored as in UTC. However, sqlite allows specifying the timezone offsets and handles those values appropriately. This change stores the timezone information and parses it out if present, otherwise it defaults to UTC as before. One additional bugfix: Previously, a unix timestamp in seconds was parsed in the local timezone (rather than UTC), in contrast to a unix timestamp in milliseconds that was parsed in UTC. While fixing that extra bug, I cleaned up the parsing code -- no need to convert to a string and then parse it back again and risk a parse error, just to check the number of digits. The tests were extended to cover non-UTC timezones storage & retrieval, meaningful unix timestamps, and correct handling of a trailing Z. --- sqlite3.go | 22 +++++++++++----------- sqlite3_test.go | 26 +++++++++++++++++++++++--- 2 files changed, 34 insertions(+), 14 deletions(-) (limited to 'sqlite3_test.go') diff --git a/sqlite3.go b/sqlite3.go index d56bed3..fb5e99b 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -99,6 +99,10 @@ import ( // into the database. When parsing a string from a timestamp or // datetime column, the formats are tried in order. var SQLiteTimestampFormats = []string{ + // By default, store timestamps with whatever timezone they come with. + // When parsed, they will be returned with the same timezone. + "2006-01-02 15:04:05.999999999-07:00", + "2006-01-02T15:04:05.999999999-07:00", "2006-01-02 15:04:05.999999999", "2006-01-02T15:04:05.999999999", "2006-01-02 15:04:05", @@ -106,7 +110,6 @@ var SQLiteTimestampFormats = []string{ "2006-01-02 15:04", "2006-01-02T15:04", "2006-01-02", - "2006-01-02 15:04:05-07:00", } func init() { @@ -803,7 +806,7 @@ func (s *SQLiteStmt) bind(args []driver.Value) error { } rv = C._sqlite3_bind_blob(s.s, n, unsafe.Pointer(p), C.int(len(v))) case time.Time: - b := []byte(v.UTC().Format(SQLiteTimestampFormats[0])) + b := []byte(v.Format(SQLiteTimestampFormats[0])) rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b))) } if rv != C.SQLITE_OK { @@ -902,18 +905,15 @@ func (rc *SQLiteRows) Next(dest []driver.Value) error { val := int64(C.sqlite3_column_int64(rc.s.s, C.int(i))) switch rc.decltype[i] { case "timestamp", "datetime", "date": - unixTimestamp := strconv.FormatInt(val, 10) var t time.Time - if len(unixTimestamp) == 13 { - duration, err := time.ParseDuration(unixTimestamp + "ms") - if err != nil { - return fmt.Errorf("error parsing %s value %d, %s", rc.decltype[i], val, err) - } - epoch := time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC) - t = epoch.Add(duration) + // Assume a millisecond unix timestamp if it's 13 digits -- too + // large to be a reasonable timestamp in seconds. + if val > 1e12 || val < -1e12 { + val *= int64(time.Millisecond) // convert ms to nsec } else { - t = time.Unix(val, 0) + val *= int64(time.Second) // convert sec to nsec } + t = time.Unix(0, val).UTC() if rc.s.c.loc != nil { t = t.In(rc.s.c.loc) } diff --git a/sqlite3_test.go b/sqlite3_test.go index 0239c78..4239bd6 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -324,6 +324,8 @@ func TestBooleanRoundtrip(t *testing.T) { } } +func timezone(t time.Time) string { return t.Format("-07:00") } + func TestTimestamp(t *testing.T) { tempFilename := TempFilename() db, err := sql.Open("sqlite3", tempFilename) @@ -342,6 +344,7 @@ func TestTimestamp(t *testing.T) { timestamp1 := time.Date(2012, time.April, 6, 22, 50, 0, 0, time.UTC) timestamp2 := time.Date(2006, time.January, 2, 15, 4, 5, 123456789, time.UTC) timestamp3 := time.Date(2012, time.November, 4, 0, 0, 0, 0, time.UTC) + tzTest := time.FixedZone("TEST", -9*3600-13*60) tests := []struct { value interface{} expected time.Time @@ -349,9 +352,9 @@ func TestTimestamp(t *testing.T) { {"nonsense", time.Time{}}, {"0000-00-00 00:00:00", time.Time{}}, {timestamp1, timestamp1}, - {timestamp1.Unix(), timestamp1}, - {timestamp1.UnixNano() / int64(time.Millisecond), timestamp1}, - {timestamp1.In(time.FixedZone("TEST", -7*3600)), timestamp1}, + {timestamp2.Unix(), timestamp2.Truncate(time.Second)}, + {timestamp2.UnixNano() / int64(time.Millisecond), timestamp2.Truncate(time.Millisecond)}, + {timestamp1.In(tzTest), timestamp1.In(tzTest)}, {timestamp1.Format("2006-01-02 15:04:05.000"), timestamp1}, {timestamp1.Format("2006-01-02T15:04:05.000"), timestamp1}, {timestamp1.Format("2006-01-02 15:04:05"), timestamp1}, @@ -359,6 +362,7 @@ func TestTimestamp(t *testing.T) { {timestamp2, timestamp2}, {"2006-01-02 15:04:05.123456789", timestamp2}, {"2006-01-02T15:04:05.123456789", timestamp2}, + {"2006-01-02T05:51:05.123456789-09:13", timestamp2.In(tzTest)}, {"2012-11-04", timestamp3}, {"2012-11-04 00:00", timestamp3}, {"2012-11-04 00:00:00", timestamp3}, @@ -366,6 +370,14 @@ func TestTimestamp(t *testing.T) { {"2012-11-04T00:00", timestamp3}, {"2012-11-04T00:00:00", timestamp3}, {"2012-11-04T00:00:00.000", timestamp3}, + {"2006-01-02T15:04:05.123456789Z", timestamp2}, + {"2012-11-04Z", timestamp3}, + {"2012-11-04 00:00Z", timestamp3}, + {"2012-11-04 00:00:00Z", timestamp3}, + {"2012-11-04 00:00:00.000Z", timestamp3}, + {"2012-11-04T00:00Z", timestamp3}, + {"2012-11-04T00:00:00Z", timestamp3}, + {"2012-11-04T00:00:00.000Z", timestamp3}, } for i := range tests { _, err = db.Exec("INSERT INTO foo(id, ts, dt) VALUES(?, ?, ?)", i, tests[i].value, tests[i].value) @@ -400,6 +412,14 @@ func TestTimestamp(t *testing.T) { if !tests[id].expected.Equal(dt) { t.Errorf("Datetime value for id %v (%v) should be %v, not %v", id, tests[id].value, tests[id].expected, dt) } + if timezone(tests[id].expected) != timezone(ts) { + t.Errorf("Timezone for id %v (%v) should be %v, not %v", id, tests[id].value, + timezone(tests[id].expected), timezone(ts)) + } + if timezone(tests[id].expected) != timezone(dt) { + t.Errorf("Timezone for id %v (%v) should be %v, not %v", id, tests[id].value, + timezone(tests[id].expected), timezone(dt)) + } } if seen != len(tests) { -- cgit v1.2.3 From 59f20de728c91fc2a6593419074299398d8358b4 Mon Sep 17 00:00:00 2001 From: Yasuhiro Matsumoto Date: Mon, 2 Nov 2015 11:56:49 +0900 Subject: fix tests --- sqlite3_fts3_test.go | 2 +- sqlite3_test.go | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) (limited to 'sqlite3_test.go') diff --git a/sqlite3_fts3_test.go b/sqlite3_fts3_test.go index beb5ada..87b4db8 100644 --- a/sqlite3_fts3_test.go +++ b/sqlite3_fts3_test.go @@ -83,7 +83,7 @@ func TestFTS3(t *testing.T) { } func TestFTS4(t *testing.T) { - tempFilename := TempFilename() + tempFilename := TempFilename(t) db, err := sql.Open("sqlite3", tempFilename) if err != nil { t.Fatal("Failed to open database:", err) diff --git a/sqlite3_test.go b/sqlite3_test.go index 90bee51..22b9c21 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -13,7 +13,6 @@ import ( "io/ioutil" "net/url" "os" - "path/filepath" "reflect" "regexp" "strings" @@ -796,7 +795,7 @@ func TestTimezoneConversion(t *testing.T) { } func TestSuite(t *testing.T) { - tempFilename := TempFilename() + tempFilename := TempFilename(t) db, err := sql.Open("sqlite3", tempFilename+"?_busy_timeout=99999") if err != nil { t.Fatal(err) -- cgit v1.2.3 From 21637a6531000298cf4a2d622e7d049a289df9ef Mon Sep 17 00:00:00 2001 From: Lars Buitinck Date: Tue, 3 Nov 2015 13:52:28 +0100 Subject: Clean up tempfiles in tests "go test" leaves no more clutter in /tmp. --- sqlite3_fts3_test.go | 4 ++-- sqlite3_test.go | 42 +++++++++++++++++++++--------------------- 2 files changed, 23 insertions(+), 23 deletions(-) (limited to 'sqlite3_test.go') diff --git a/sqlite3_fts3_test.go b/sqlite3_fts3_test.go index 87b4db8..803afbd 100644 --- a/sqlite3_fts3_test.go +++ b/sqlite3_fts3_test.go @@ -13,11 +13,11 @@ import ( func TestFTS3(t *testing.T) { tempFilename := TempFilename(t) + defer os.Remove(tempFilename) db, err := sql.Open("sqlite3", tempFilename) if err != nil { t.Fatal("Failed to open database:", err) } - defer os.Remove(tempFilename) defer db.Close() _, err = db.Exec("DROP TABLE foo") @@ -84,11 +84,11 @@ func TestFTS3(t *testing.T) { func TestFTS4(t *testing.T) { tempFilename := TempFilename(t) + defer os.Remove(tempFilename) db, err := sql.Open("sqlite3", tempFilename) if err != nil { t.Fatal("Failed to open database:", err) } - defer os.Remove(tempFilename) defer db.Close() _, err = db.Exec("DROP TABLE foo") diff --git a/sqlite3_test.go b/sqlite3_test.go index dfe1c43..9efd313 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -109,11 +109,11 @@ func TestReadonly(t *testing.T) { func TestClose(t *testing.T) { tempFilename := TempFilename(t) + defer os.Remove(tempFilename) db, err := sql.Open("sqlite3", tempFilename) if err != nil { t.Fatal("Failed to open database:", err) } - defer os.Remove(tempFilename) _, err = db.Exec("drop table foo") _, err = db.Exec("create table foo (id integer)") @@ -135,11 +135,11 @@ func TestClose(t *testing.T) { func TestInsert(t *testing.T) { tempFilename := TempFilename(t) + defer os.Remove(tempFilename) db, err := sql.Open("sqlite3", tempFilename) if err != nil { t.Fatal("Failed to open database:", err) } - defer os.Remove(tempFilename) defer db.Close() _, err = db.Exec("drop table foo") @@ -174,11 +174,11 @@ func TestInsert(t *testing.T) { func TestUpdate(t *testing.T) { tempFilename := TempFilename(t) + defer os.Remove(tempFilename) db, err := sql.Open("sqlite3", tempFilename) if err != nil { t.Fatal("Failed to open database:", err) } - defer os.Remove(tempFilename) defer db.Close() _, err = db.Exec("drop table foo") @@ -239,11 +239,11 @@ func TestUpdate(t *testing.T) { func TestDelete(t *testing.T) { tempFilename := TempFilename(t) + defer os.Remove(tempFilename) db, err := sql.Open("sqlite3", tempFilename) if err != nil { t.Fatal("Failed to open database:", err) } - defer os.Remove(tempFilename) defer db.Close() _, err = db.Exec("drop table foo") @@ -300,11 +300,11 @@ func TestDelete(t *testing.T) { func TestBooleanRoundtrip(t *testing.T) { tempFilename := TempFilename(t) + defer os.Remove(tempFilename) db, err := sql.Open("sqlite3", tempFilename) if err != nil { t.Fatal("Failed to open database:", err) } - defer os.Remove(tempFilename) defer db.Close() _, err = db.Exec("DROP TABLE foo") @@ -351,11 +351,11 @@ func timezone(t time.Time) string { return t.Format("-07:00") } func TestTimestamp(t *testing.T) { tempFilename := TempFilename(t) + defer os.Remove(tempFilename) db, err := sql.Open("sqlite3", tempFilename) if err != nil { t.Fatal("Failed to open database:", err) } - defer os.Remove(tempFilename) defer db.Close() _, err = db.Exec("DROP TABLE foo") @@ -452,12 +452,12 @@ func TestTimestamp(t *testing.T) { func TestBoolean(t *testing.T) { tempFilename := TempFilename(t) + defer os.Remove(tempFilename) db, err := sql.Open("sqlite3", tempFilename) if err != nil { t.Fatal("Failed to open database:", err) } - defer os.Remove(tempFilename) defer db.Close() _, err = db.Exec("CREATE TABLE foo(id INTEGER, fbool BOOLEAN)") @@ -544,12 +544,11 @@ func TestBoolean(t *testing.T) { func TestFloat32(t *testing.T) { tempFilename := TempFilename(t) + defer os.Remove(tempFilename) db, err := sql.Open("sqlite3", tempFilename) if err != nil { t.Fatal("Failed to open database:", err) } - - defer os.Remove(tempFilename) defer db.Close() _, err = db.Exec("CREATE TABLE foo(id INTEGER)") @@ -582,12 +581,11 @@ func TestFloat32(t *testing.T) { func TestNull(t *testing.T) { tempFilename := TempFilename(t) + defer os.Remove(tempFilename) db, err := sql.Open("sqlite3", tempFilename) if err != nil { t.Fatal("Failed to open database:", err) } - - defer os.Remove(tempFilename) defer db.Close() rows, err := db.Query("SELECT 3.141592") @@ -614,12 +612,11 @@ func TestNull(t *testing.T) { func TestTransaction(t *testing.T) { tempFilename := TempFilename(t) + defer os.Remove(tempFilename) db, err := sql.Open("sqlite3", tempFilename) if err != nil { t.Fatal("Failed to open database:", err) } - - defer os.Remove(tempFilename) defer db.Close() _, err = db.Exec("CREATE TABLE foo(id INTEGER)") @@ -674,13 +671,13 @@ func TestTransaction(t *testing.T) { func TestWAL(t *testing.T) { tempFilename := TempFilename(t) + defer os.Remove(tempFilename) db, err := sql.Open("sqlite3", tempFilename) if err != nil { t.Fatal("Failed to open database:", err) } - - defer os.Remove(tempFilename) defer db.Close() + if _, err = db.Exec("PRAGMA journal_mode=WAL;"); err != nil { t.Fatal("Failed to Exec PRAGMA journal_mode:", err) } @@ -722,11 +719,11 @@ func TestTimezoneConversion(t *testing.T) { zones := []string{"UTC", "US/Central", "US/Pacific", "Local"} for _, tz := range zones { tempFilename := TempFilename(t) + defer os.Remove(tempFilename) db, err := sql.Open("sqlite3", tempFilename+"?_loc="+url.QueryEscape(tz)) if err != nil { t.Fatal("Failed to open database:", err) } - defer os.Remove(tempFilename) defer db.Close() _, err = db.Exec("DROP TABLE foo") @@ -816,11 +813,11 @@ func TestTimezoneConversion(t *testing.T) { func TestSuite(t *testing.T) { tempFilename := TempFilename(t) + defer os.Remove(tempFilename) db, err := sql.Open("sqlite3", tempFilename+"?_busy_timeout=99999") if err != nil { t.Fatal(err) } - defer os.Remove(tempFilename) defer db.Close() sqlite3_test.RunTests(t, db, sqlite3_test.SQLITE) @@ -830,11 +827,11 @@ func TestSuite(t *testing.T) { // https://github.com/mattn/go-sqlite3/issues/82 func TestExecer(t *testing.T) { tempFilename := TempFilename(t) + defer os.Remove(tempFilename) db, err := sql.Open("sqlite3", tempFilename) if err != nil { t.Fatal("Failed to open database:", err) } - defer os.Remove(tempFilename) defer db.Close() _, err = db.Exec(` @@ -850,11 +847,11 @@ func TestExecer(t *testing.T) { func TestQueryer(t *testing.T) { tempFilename := TempFilename(t) + defer os.Remove(tempFilename) db, err := sql.Open("sqlite3", tempFilename) if err != nil { t.Fatal("Failed to open database:", err) } - defer os.Remove(tempFilename) defer db.Close() _, err = db.Exec(` @@ -891,6 +888,7 @@ func TestQueryer(t *testing.T) { func TestStress(t *testing.T) { tempFilename := TempFilename(t) + defer os.Remove(tempFilename) db, err := sql.Open("sqlite3", tempFilename) if err != nil { t.Fatal("Failed to open database:", err) @@ -929,6 +927,7 @@ func TestStress(t *testing.T) { func TestDateTimeLocal(t *testing.T) { zone := "Asia/Tokyo" tempFilename := TempFilename(t) + defer os.Remove(tempFilename) db, err := sql.Open("sqlite3", tempFilename+"?_loc="+zone) if err != nil { t.Fatal("Failed to open database:", err) @@ -996,11 +995,11 @@ func TestVersion(t *testing.T) { func TestNumberNamedParams(t *testing.T) { tempFilename := TempFilename(t) + defer os.Remove(tempFilename) db, err := sql.Open("sqlite3", tempFilename) if err != nil { t.Fatal("Failed to open database:", err) } - defer os.Remove(tempFilename) defer db.Close() _, err = db.Exec(` @@ -1032,11 +1031,11 @@ func TestNumberNamedParams(t *testing.T) { func TestStringContainingZero(t *testing.T) { tempFilename := TempFilename(t) + defer os.Remove(tempFilename) db, err := sql.Open("sqlite3", tempFilename) if err != nil { t.Fatal("Failed to open database:", err) } - defer os.Remove(tempFilename) defer db.Close() _, err = db.Exec(` @@ -1092,6 +1091,7 @@ func (t TimeStamp) Value() (driver.Value, error) { func TestDateTimeNow(t *testing.T) { tempFilename := TempFilename(t) + defer os.Remove(tempFilename) db, err := sql.Open("sqlite3", tempFilename) if err != nil { t.Fatal("Failed to open database:", err) -- cgit v1.2.3