aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--README.md2
-rw-r--r--_example/hook/hook.go6
-rw-r--r--callback.go18
-rw-r--r--sqlite3.go152
-rw-r--r--sqlite3_go18_test.go92
-rw-r--r--sqlite3_libsqlite3.go1
-rw-r--r--sqlite3_other.go1
-rw-r--r--sqlite3_test.go582
-rw-r--r--sqlite3_test/sqltest.go423
9 files changed, 823 insertions, 454 deletions
diff --git a/README.md b/README.md
index 01d28a2..ad00f10 100644
--- a/README.md
+++ b/README.md
@@ -65,7 +65,7 @@ FAQ
* Want to get time.Time with current locale
- Use `loc=auto` in SQLite3 filename schema like `file:foo.db?loc=auto`.
+ Use `_loc=auto` in SQLite3 filename schema like `file:foo.db?_loc=auto`.
* Can I use this in multiple routines concurrently?
diff --git a/_example/hook/hook.go b/_example/hook/hook.go
index 17bddeb..6023181 100644
--- a/_example/hook/hook.go
+++ b/_example/hook/hook.go
@@ -14,6 +14,12 @@ func main() {
&sqlite3.SQLiteDriver{
ConnectHook: func(conn *sqlite3.SQLiteConn) error {
sqlite3conn = append(sqlite3conn, conn)
+ conn.RegisterUpdateHook(func(op int, db string, table string, rowid int64) {
+ switch op {
+ case sqlite3.SQLITE_INSERT:
+ log.Println("Notified of insert on db", db, "table", table, "rowid", rowid)
+ }
+ })
return nil
},
})
diff --git a/callback.go b/callback.go
index 244e739..29ece3d 100644
--- a/callback.go
+++ b/callback.go
@@ -59,6 +59,24 @@ func compareTrampoline(handlePtr uintptr, la C.int, a *C.char, lb C.int, b *C.ch
return C.int(cmp(C.GoStringN(a, la), C.GoStringN(b, lb)))
}
+//export commitHookTrampoline
+func commitHookTrampoline(handle uintptr) int {
+ callback := lookupHandle(handle).(func() int)
+ return callback()
+}
+
+//export rollbackHookTrampoline
+func rollbackHookTrampoline(handle uintptr) {
+ callback := lookupHandle(handle).(func())
+ callback()
+}
+
+//export updateHookTrampoline
+func updateHookTrampoline(handle uintptr, op int, db *C.char, table *C.char, rowid int64) {
+ callback := lookupHandle(handle).(func(int, string, string, int64))
+ callback(op, C.GoString(db), C.GoString(table), rowid)
+}
+
// Use handles to avoid passing Go pointers to C.
type handleVal struct {
diff --git a/sqlite3.go b/sqlite3.go
index c16204c..1ff58c3 100644
--- a/sqlite3.go
+++ b/sqlite3.go
@@ -7,7 +7,7 @@ package sqlite3
/*
#cgo CFLAGS: -std=gnu99
-#cgo CFLAGS: -DSQLITE_ENABLE_RTREE -DSQLITE_THREADSAFE
+#cgo CFLAGS: -DSQLITE_ENABLE_RTREE -DSQLITE_THREADSAFE=1
#cgo CFLAGS: -DSQLITE_ENABLE_FTS3 -DSQLITE_ENABLE_FTS3_PARENTHESIS -DSQLITE_ENABLE_FTS4_UNICODE61
#cgo CFLAGS: -DSQLITE_TRACE_SIZE_LIMIT=15
#cgo CFLAGS: -DSQLITE_DISABLE_INTRINSIC
@@ -102,6 +102,9 @@ int _sqlite3_create_function(
void callbackTrampoline(sqlite3_context*, int, sqlite3_value**);
int compareTrampoline(void*, int, char*, int, char*);
+int commitHookTrampoline(void*);
+void rollbackHookTrampoline(void*);
+void updateHookTrampoline(void*, int, char*, char*, sqlite3_int64);
*/
import "C"
import (
@@ -115,6 +118,7 @@ import (
"runtime"
"strconv"
"strings"
+ "sync"
"time"
"unsafe"
@@ -151,6 +155,12 @@ func Version() (libVersion string, libVersionNumber int, sourceID string) {
return libVersion, libVersionNumber, sourceID
}
+const (
+ SQLITE_DELETE = C.SQLITE_DELETE
+ SQLITE_INSERT = C.SQLITE_INSERT
+ SQLITE_UPDATE = C.SQLITE_UPDATE
+)
+
// SQLiteDriver implement sql.Driver.
type SQLiteDriver struct {
Extensions []string
@@ -159,6 +169,7 @@ type SQLiteDriver struct {
// SQLiteConn implement sql.Conn.
type SQLiteConn struct {
+ mu sync.Mutex
db *C.sqlite3
loc *time.Location
txlock string
@@ -173,6 +184,7 @@ type SQLiteTx struct {
// SQLiteStmt implement sql.Stmt.
type SQLiteStmt struct {
+ mu sync.Mutex
c *SQLiteConn
s *C.sqlite3_stmt
t string
@@ -193,6 +205,7 @@ type SQLiteRows struct {
cols []string
decltype []string
cls bool
+ closed bool
done chan struct{}
}
@@ -338,6 +351,51 @@ func (c *SQLiteConn) RegisterCollation(name string, cmp func(string, string) int
return nil
}
+// RegisterCommitHook sets the commit hook for a connection.
+//
+// If the callback returns non-zero the transaction will become a rollback.
+//
+// If there is an existing commit hook for this connection, it will be
+// removed. If callback is nil the existing hook (if any) will be removed
+// without creating a new one.
+func (c *SQLiteConn) RegisterCommitHook(callback func() int) {
+ if callback == nil {
+ C.sqlite3_commit_hook(c.db, nil, nil)
+ } else {
+ C.sqlite3_commit_hook(c.db, (*[0]byte)(unsafe.Pointer(C.commitHookTrampoline)), unsafe.Pointer(newHandle(c, callback)))
+ }
+}
+
+// RegisterRollbackHook sets the rollback hook for a connection.
+//
+// If there is an existing rollback hook for this connection, it will be
+// removed. If callback is nil the existing hook (if any) will be removed
+// without creating a new one.
+func (c *SQLiteConn) RegisterRollbackHook(callback func()) {
+ if callback == nil {
+ C.sqlite3_rollback_hook(c.db, nil, nil)
+ } else {
+ C.sqlite3_rollback_hook(c.db, (*[0]byte)(unsafe.Pointer(C.rollbackHookTrampoline)), unsafe.Pointer(newHandle(c, callback)))
+ }
+}
+
+// RegisterUpdateHook sets the update hook for a connection.
+//
+// The parameters to the callback are the operation (one of the constants
+// SQLITE_INSERT, SQLITE_DELETE, or SQLITE_UPDATE), the database name, the
+// table name, and the rowid.
+//
+// If there is an existing update hook for this connection, it will be
+// removed. If callback is nil the existing hook (if any) will be removed
+// without creating a new one.
+func (c *SQLiteConn) RegisterUpdateHook(callback func(int, string, string, int64)) {
+ if callback == nil {
+ C.sqlite3_update_hook(c.db, nil, nil)
+ } else {
+ C.sqlite3_update_hook(c.db, (*[0]byte)(unsafe.Pointer(C.updateHookTrampoline)), unsafe.Pointer(newHandle(c, callback)))
+ }
+}
+
// RegisterFunc makes a Go function available as a SQLite function.
//
// The Go function can have arguments of the following types: any
@@ -568,6 +626,8 @@ func errorString(err Error) string {
// "deferred", "exclusive".
// _foreign_keys=X
// Enable or disable enforcement of foreign keys. X can be 1 or 0.
+// _recursive_triggers=X
+// Enable or disable recursive triggers. X can be 1 or 0.
func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
if C.sqlite3_threadsafe() == 0 {
return nil, errors.New("sqlite library was not compiled for thread-safe operation")
@@ -577,6 +637,7 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
txlock := "BEGIN"
busyTimeout := 5000
foreignKeys := -1
+ recursiveTriggers := -1
pos := strings.IndexRune(dsn, '?')
if pos >= 1 {
params, err := url.ParseQuery(dsn[pos+1:])
@@ -631,6 +692,18 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
}
}
+ // _recursive_triggers
+ if val := params.Get("_recursive_triggers"); val != "" {
+ switch val {
+ case "1":
+ recursiveTriggers = 1
+ case "0":
+ recursiveTriggers = 0
+ default:
+ return nil, fmt.Errorf("Invalid _recursive_triggers: %v", val)
+ }
+ }
+
if !strings.HasPrefix(dsn, "file:") {
dsn = dsn[:pos]
}
@@ -677,6 +750,17 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
return nil, err
}
}
+ if recursiveTriggers == 0 {
+ if err := exec("PRAGMA recursive_triggers = OFF;"); err != nil {
+ C.sqlite3_close_v2(db)
+ return nil, err
+ }
+ } else if recursiveTriggers == 1 {
+ if err := exec("PRAGMA recursive_triggers = ON;"); err != nil {
+ C.sqlite3_close_v2(db)
+ return nil, err
+ }
+ }
conn := &SQLiteConn{db: db, loc: loc, txlock: txlock}
@@ -704,11 +788,22 @@ func (c *SQLiteConn) Close() error {
return c.lastError()
}
deleteHandles(c)
+ c.mu.Lock()
c.db = nil
+ c.mu.Unlock()
runtime.SetFinalizer(c, nil)
return nil
}
+func (c *SQLiteConn) dbConnOpen() bool {
+ if c == nil {
+ return false
+ }
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ return c.db != nil
+}
+
// Prepare the query string. Return a new statement.
func (c *SQLiteConn) Prepare(query string) (driver.Stmt, error) {
return c.prepare(context.Background(), query)
@@ -734,14 +829,17 @@ func (c *SQLiteConn) prepare(ctx context.Context, query string) (driver.Stmt, er
// Close the statement.
func (s *SQLiteStmt) Close() error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
if s.closed {
return nil
}
s.closed = true
- if s.c == nil || s.c.db == nil {
+ if !s.c.dbConnOpen() {
return errors.New("sqlite statement with already closed database connection")
}
rv := C.sqlite3_finalize(s.s)
+ s.s = nil
if rv != C.SQLITE_OK {
return s.c.lastError()
}
@@ -759,6 +857,8 @@ type bindArg struct {
v driver.Value
}
+var placeHolder = []byte{0}
+
func (s *SQLiteStmt) bind(args []namedValue) error {
rv := C.sqlite3_reset(s.s)
if rv != C.SQLITE_ROW && rv != C.SQLITE_OK && rv != C.SQLITE_DONE {
@@ -780,8 +880,7 @@ func (s *SQLiteStmt) bind(args []namedValue) error {
rv = C.sqlite3_bind_null(s.s, n)
case string:
if len(v) == 0 {
- b := []byte{0}
- rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(0))
+ rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&placeHolder[0])), C.int(0))
} else {
b := []byte(v)
rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b)))
@@ -797,11 +896,11 @@ func (s *SQLiteStmt) bind(args []namedValue) error {
case float64:
rv = C.sqlite3_bind_double(s.s, n, C.double(v))
case []byte:
- if len(v) == 0 {
- rv = C._sqlite3_bind_blob(s.s, n, nil, 0)
- } else {
- rv = C._sqlite3_bind_blob(s.s, n, unsafe.Pointer(&v[0]), C.int(len(v)))
+ ln := len(v)
+ if ln == 0 {
+ v = placeHolder
}
+ rv = C._sqlite3_bind_blob(s.s, n, unsafe.Pointer(&v[0]), C.int(ln))
case time.Time:
b := []byte(v.Format(SQLiteTimestampFormats[0]))
rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b)))
@@ -836,6 +935,7 @@ func (s *SQLiteStmt) query(ctx context.Context, args []namedValue) (driver.Rows,
cols: nil,
decltype: nil,
cls: s.cls,
+ closed: false,
done: make(chan struct{}),
}
@@ -908,25 +1008,33 @@ func (s *SQLiteStmt) exec(ctx context.Context, args []namedValue) (driver.Result
// Close the rows.
func (rc *SQLiteRows) Close() error {
- if rc.s.closed {
+ rc.s.mu.Lock()
+ if rc.s.closed || rc.closed {
+ rc.s.mu.Unlock()
return nil
}
+ rc.closed = true
if rc.done != nil {
close(rc.done)
}
if rc.cls {
+ rc.s.mu.Unlock()
return rc.s.Close()
}
rv := C.sqlite3_reset(rc.s.s)
if rv != C.SQLITE_OK {
+ rc.s.mu.Unlock()
return rc.s.c.lastError()
}
+ rc.s.mu.Unlock()
return nil
}
// Columns return column names.
func (rc *SQLiteRows) Columns() []string {
- if rc.nc != len(rc.cols) {
+ rc.s.mu.Lock()
+ defer rc.s.mu.Unlock()
+ if rc.s.s != nil && rc.nc != len(rc.cols) {
rc.cols = make([]string, rc.nc)
for i := 0; i < rc.nc; i++ {
rc.cols[i] = C.GoString(C.sqlite3_column_name(rc.s.s, C.int(i)))
@@ -935,9 +1043,8 @@ func (rc *SQLiteRows) Columns() []string {
return rc.cols
}
-// DeclTypes return column types.
-func (rc *SQLiteRows) DeclTypes() []string {
- if rc.decltype == nil {
+func (rc *SQLiteRows) declTypes() []string {
+ if rc.s.s != nil && rc.decltype == nil {
rc.decltype = make([]string, rc.nc)
for i := 0; i < rc.nc; i++ {
rc.decltype[i] = strings.ToLower(C.GoString(C.sqlite3_column_decltype(rc.s.s, C.int(i))))
@@ -946,8 +1053,20 @@ func (rc *SQLiteRows) DeclTypes() []string {
return rc.decltype
}
+// DeclTypes return column types.
+func (rc *SQLiteRows) DeclTypes() []string {
+ rc.s.mu.Lock()
+ defer rc.s.mu.Unlock()
+ return rc.declTypes()
+}
+
// Next move cursor to next.
func (rc *SQLiteRows) Next(dest []driver.Value) error {
+ if rc.s.closed {
+ return io.EOF
+ }
+ rc.s.mu.Lock()
+ defer rc.s.mu.Unlock()
rv := C.sqlite3_step(rc.s.s)
if rv == C.SQLITE_DONE {
return io.EOF
@@ -960,7 +1079,7 @@ func (rc *SQLiteRows) Next(dest []driver.Value) error {
return nil
}
- rc.DeclTypes()
+ rc.declTypes()
for i := range dest {
switch C.sqlite3_column_type(rc.s.s, C.int(i)) {
@@ -973,10 +1092,11 @@ func (rc *SQLiteRows) Next(dest []driver.Value) error {
// large to be a reasonable timestamp in seconds.
if val > 1e12 || val < -1e12 {
val *= int64(time.Millisecond) // convert ms to nsec
+ t = time.Unix(0, val)
} else {
- val *= int64(time.Second) // convert sec to nsec
+ t = time.Unix(val, 0)
}
- t = time.Unix(0, val).UTC()
+ t = t.UTC()
if rc.s.c.loc != nil {
t = t.In(rc.s.c.loc)
}
diff --git a/sqlite3_go18_test.go b/sqlite3_go18_test.go
index f076b81..a5f4aae 100644
--- a/sqlite3_go18_test.go
+++ b/sqlite3_go18_test.go
@@ -8,9 +8,13 @@
package sqlite3
import (
+ "context"
"database/sql"
+ "fmt"
+ "math/rand"
"os"
"testing"
+ "time"
)
func TestNamedParams(t *testing.T) {
@@ -48,3 +52,91 @@ func TestNamedParams(t *testing.T) {
t.Error("Failed to db.QueryRow: not matched results")
}
}
+
+var (
+ testTableStatements = []string{
+ `DROP TABLE IF EXISTS test_table`,
+ `
+CREATE TABLE IF NOT EXISTS test_table (
+ key1 VARCHAR(64) PRIMARY KEY,
+ key_id VARCHAR(64) NOT NULL,
+ key2 VARCHAR(64) NOT NULL,
+ key3 VARCHAR(64) NOT NULL,
+ key4 VARCHAR(64) NOT NULL,
+ key5 VARCHAR(64) NOT NULL,
+ key6 VARCHAR(64) NOT NULL,
+ data BLOB NOT NULL
+);`,
+ }
+ letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
+)
+
+func randStringBytes(n int) string {
+ b := make([]byte, n)
+ for i := range b {
+ b[i] = letterBytes[rand.Intn(len(letterBytes))]
+ }
+ return string(b)
+}
+
+func initDatabase(t *testing.T, db *sql.DB, rowCount int64) {
+ t.Logf("Executing db initializing statements")
+ for _, query := range testTableStatements {
+ _, err := db.Exec(query)
+ if err != nil {
+ t.Fatal(err)
+ }
+ }
+ for i := int64(0); i < rowCount; i++ {
+ query := `INSERT INTO test_table
+ (key1, key_id, key2, key3, key4, key5, key6, data)
+ VALUES
+ (?, ?, ?, ?, ?, ?, ?, ?);`
+ args := []interface{}{
+ randStringBytes(50),
+ fmt.Sprint(i),
+ randStringBytes(50),
+ randStringBytes(50),
+ randStringBytes(50),
+ randStringBytes(50),
+ randStringBytes(50),
+ randStringBytes(50),
+ randStringBytes(2048),
+ }
+ _, err := db.Exec(query, args...)
+ if err != nil {
+ t.Fatal(err)
+ }
+ }
+}
+
+func TestShortTimeout(t *testing.T) {
+ db, err := sql.Open("sqlite3", "file::memory:?mode=memory&cache=shared")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer db.Close()
+ initDatabase(t, db, 10000)
+
+ ctx, cancel := context.WithTimeout(context.Background(), 10*time.Microsecond)
+ defer cancel()
+ query := `SELECT key1, key_id, key2, key3, key4, key5, key6, data
+ FROM test_table
+ ORDER BY key2 ASC`
+ rows, err := db.QueryContext(ctx, query)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer rows.Close()
+ for rows.Next() {
+ var key1, keyid, key2, key3, key4, key5, key6 string
+ var data []byte
+ err = rows.Scan(&key1, &keyid, &key2, &key3, &key4, &key5, &key6, &data)
+ if err != nil {
+ break
+ }
+ }
+ if context.DeadlineExceeded != ctx.Err() {
+ t.Fatal(ctx.Err())
+ }
+}
diff --git a/sqlite3_libsqlite3.go b/sqlite3_libsqlite3.go
index 135863e..e4557e6 100644
--- a/sqlite3_libsqlite3.go
+++ b/sqlite3_libsqlite3.go
@@ -10,5 +10,6 @@ package sqlite3
#cgo CFLAGS: -DUSE_LIBSQLITE3
#cgo linux LDFLAGS: -lsqlite3
#cgo darwin LDFLAGS: -L/usr/local/opt/sqlite/lib -lsqlite3
+#cgo solaris LDFLAGS: -lsqlite3
*/
import "C"
diff --git a/sqlite3_other.go b/sqlite3_other.go
index a20d02c..f721b5e 100644
--- a/sqlite3_other.go
+++ b/sqlite3_other.go
@@ -9,5 +9,6 @@ package sqlite3
/*
#cgo CFLAGS: -I.
#cgo linux LDFLAGS: -ldl
+#cgo solaris LDFLAGS: -lc
*/
import "C"
diff --git a/sqlite3_test.go b/sqlite3_test.go
index 842f5d7..9d4b373 100644
--- a/sqlite3_test.go
+++ b/sqlite3_test.go
@@ -6,21 +6,22 @@
package sqlite3
import (
+ "bytes"
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"io/ioutil"
+ "math/rand"
"net/url"
"os"
"reflect"
"regexp"
+ "strconv"
"strings"
"sync"
"testing"
"time"
-
- "github.com/mattn/go-sqlite3/sqlite3_test"
)
func TempFilename(t *testing.T) string {
@@ -136,6 +137,35 @@ func TestForeignKeys(t *testing.T) {
}
}
+func TestRecursiveTriggers(t *testing.T) {
+ cases := map[string]bool{
+ "?_recursive_triggers=1": true,
+ "?_recursive_triggers=0": false,
+ }
+ for option, want := range cases {
+ fname := TempFilename(t)
+ uri := "file:" + fname + option
+ db, err := sql.Open("sqlite3", uri)
+ if err != nil {
+ os.Remove(fname)
+ t.Errorf("sql.Open(\"sqlite3\", %q): %v", uri, err)
+ continue
+ }
+ var enabled bool
+ err = db.QueryRow("PRAGMA recursive_triggers;").Scan(&enabled)
+ db.Close()
+ os.Remove(fname)
+ if err != nil {
+ t.Errorf("query recursive_triggers for %s: %v", uri, err)
+ continue
+ }
+ if enabled != want {
+ t.Errorf("\"PRAGMA recursive_triggers;\" for %q = %t; want %t", uri, enabled, want)
+ continue
+ }
+ }
+}
+
func TestClose(t *testing.T) {
tempFilename := TempFilename(t)
defer os.Remove(tempFilename)
@@ -403,6 +433,7 @@ func TestTimestamp(t *testing.T) {
}{
{"nonsense", time.Time{}},
{"0000-00-00 00:00:00", time.Time{}},
+ {time.Time{}.Unix(), time.Time{}},
{timestamp1, timestamp1},
{timestamp2.Unix(), timestamp2.Truncate(time.Second)},
{timestamp2.UnixNano() / int64(time.Millisecond), timestamp2.Truncate(time.Millisecond)},
@@ -840,18 +871,6 @@ func TestTimezoneConversion(t *testing.T) {
}
}
-func TestSuite(t *testing.T) {
- tempFilename := TempFilename(t)
- defer os.Remove(tempFilename)
- db, err := sql.Open("sqlite3", tempFilename+"?_busy_timeout=99999")
- if err != nil {
- t.Fatal(err)
- }
- defer db.Close()
-
- sqlite3_test.RunTests(t, db, sqlite3_test.SQLITE)
-}
-
// TODO: Execer & Queryer currently disabled
// https://github.com/mattn/go-sqlite3/issues/82
func TestExecer(t *testing.T) {
@@ -1385,6 +1404,122 @@ func TestPinger(t *testing.T) {
}
}
+func TestUpdateAndTransactionHooks(t *testing.T) {
+ var events []string
+ var commitHookReturn = 0
+
+ sql.Register("sqlite3_UpdateHook", &SQLiteDriver{
+ ConnectHook: func(conn *SQLiteConn) error {
+ conn.RegisterCommitHook(func() int {
+ events = append(events, "commit")
+ return commitHookReturn
+ })
+ conn.RegisterRollbackHook(func() {
+ events = append(events, "rollback")
+ })
+ conn.RegisterUpdateHook(func(op int, db string, table string, rowid int64) {
+ events = append(events, fmt.Sprintf("update(op=%v db=%v table=%v rowid=%v)", op, db, table, rowid))
+ })
+ return nil
+ },
+ })
+ db, err := sql.Open("sqlite3_UpdateHook", ":memory:")
+ if err != nil {
+ t.Fatal("Failed to open database:", err)
+ }
+ defer db.Close()
+
+ statements := []string{
+ "create table foo (id integer primary key)",
+ "insert into foo values (9)",
+ "update foo set id = 99 where id = 9",
+ "delete from foo where id = 99",
+ }
+ for _, statement := range statements {
+ _, err = db.Exec(statement)
+ if err != nil {
+ t.Fatalf("Unable to prepare test data [%v]: %v", statement, err)
+ }
+ }
+
+ commitHookReturn = 1
+ _, err = db.Exec("insert into foo values (5)")
+ if err == nil {
+ t.Error("Commit hook failed to rollback transaction")
+ }
+
+ var expected = []string{
+ "commit",
+ fmt.Sprintf("update(op=%v db=main table=foo rowid=9)", SQLITE_INSERT),
+ "commit",
+ fmt.Sprintf("update(op=%v db=main table=foo rowid=99)", SQLITE_UPDATE),
+ "commit",
+ fmt.Sprintf("update(op=%v db=main table=foo rowid=99)", SQLITE_DELETE),
+ "commit",
+ fmt.Sprintf("update(op=%v db=main table=foo rowid=5)", SQLITE_INSERT),
+ "commit",
+ "rollback",
+ }
+ if !reflect.DeepEqual(events, expected) {
+ t.Errorf("Expected notifications %v but got %v", expected, events)
+ }
+}
+
+func TestNilAndEmptyBytes(t *testing.T) {
+ db, err := sql.Open("sqlite3", ":memory:")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer db.Close()
+ actualNil := []byte("use this to use an actual nil not a reference to nil")
+ emptyBytes := []byte{}
+ for tsti, tst := range []struct {
+ name string
+ columnType string
+ insertBytes []byte
+ expectedBytes []byte
+ }{
+ {"actual nil blob", "blob", actualNil, nil},
+ {"referenced nil blob", "blob", nil, nil},
+ {"empty blob", "blob", emptyBytes, emptyBytes},
+ {"actual nil text", "text", actualNil, nil},
+ {"referenced nil text", "text", nil, nil},
+ {"empty text", "text", emptyBytes, emptyBytes},
+ } {
+ if _, err = db.Exec(fmt.Sprintf("create table tbl%d (txt %s)", tsti, tst.columnType)); err != nil {
+ t.Fatal(tst.name, err)
+ }
+ if bytes.Equal(tst.insertBytes, actualNil) {
+ if _, err = db.Exec(fmt.Sprintf("insert into tbl%d (txt) values (?)", tsti), nil); err != nil {
+ t.Fatal(tst.name, err)
+ }
+ } else {
+ if _, err = db.Exec(fmt.Sprintf("insert into tbl%d (txt) values (?)", tsti), &tst.insertBytes); err != nil {
+ t.Fatal(tst.name, err)
+ }
+ }
+ rows, err := db.Query(fmt.Sprintf("select txt from tbl%d", tsti))
+ if err != nil {
+ t.Fatal(tst.name, err)
+ }
+ if !rows.Next() {
+ t.Fatal(tst.name, "no rows")
+ }
+ var scanBytes []byte
+ if err = rows.Scan(&scanBytes); err != nil {
+ t.Fatal(tst.name, err)
+ }
+ if err = rows.Err(); err != nil {
+ t.Fatal(tst.name, err)
+ }
+ if tst.expectedBytes == nil && scanBytes != nil {
+ t.Errorf("%s: %#v != %#v", tst.name, scanBytes, tst.expectedBytes)
+ } else if !bytes.Equal(scanBytes, tst.expectedBytes) {
+ t.Errorf("%s: %#v != %#v", tst.name, scanBytes, tst.expectedBytes)
+ }
+ }
+}
+
var customFunctionOnce sync.Once
func BenchmarkCustomFunctions(b *testing.B) {
@@ -1419,3 +1554,422 @@ func BenchmarkCustomFunctions(b *testing.B) {
}
}
}
+
+func TestSuite(t *testing.T) {
+ tempFilename := TempFilename(t)
+ defer os.Remove(tempFilename)
+ d, err := sql.Open("sqlite3", tempFilename+"?_busy_timeout=99999")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer d.Close()
+
+ db = &TestDB{t, d, SQLITE, sync.Once{}}
+ testing.RunTests(func(string, string) (bool, error) { return true, nil }, tests)
+
+ if !testing.Short() {
+ for _, b := range benchmarks {
+ fmt.Printf("%-20s", b.Name)
+ r := testing.Benchmark(b.F)
+ fmt.Printf("%10d %10.0f req/s\n", r.N, float64(r.N)/r.T.Seconds())
+ }
+ }
+ db.tearDown()
+}
+
+// Dialect is a type of dialect of databases.
+type Dialect int
+
+// Dialects for databases.
+const (
+ SQLITE Dialect = iota // SQLITE mean SQLite3 dialect
+ POSTGRESQL // POSTGRESQL mean PostgreSQL dialect
+ MYSQL // MYSQL mean MySQL dialect
+)
+
+// DB provide context for the tests
+type TestDB struct {
+ *testing.T
+ *sql.DB
+ dialect Dialect
+ once sync.Once
+}
+
+var db *TestDB
+
+// the following tables will be created and dropped during the test
+var testTables = []string{"foo", "bar", "t", "bench"}
+
+var tests = []testing.InternalTest{
+ {Name: "TestResult", F: testResult},
+ {Name: "TestBlobs", F: testBlobs},
+ {Name: "TestManyQueryRow", F: testManyQueryRow},
+ {Name: "TestTxQuery", F: testTxQuery},
+ {Name: "TestPreparedStmt", F: testPreparedStmt},
+}
+
+var benchmarks = []testing.InternalBenchmark{
+ {Name: "BenchmarkExec", F: benchmarkExec},
+ {Name: "BenchmarkQuery", F: benchmarkQuery},
+ {Name: "BenchmarkParams", F: benchmarkParams},
+ {Name: "BenchmarkStmt", F: benchmarkStmt},
+ {Name: "BenchmarkRows", F: benchmarkRows},
+ {Name: "BenchmarkStmtRows", F: benchmarkStmtRows},
+}
+
+func (db *TestDB) mustExec(sql string, args ...interface{}) sql.Result {
+ res, err := db.Exec(sql, args...)
+ if err != nil {
+ db.Fatalf("Error running %q: %v", sql, err)
+ }
+ return res
+}
+
+func (db *TestDB) tearDown() {
+ for _, tbl := range testTables {
+ switch db.dialect {
+ case SQLITE:
+ db.mustExec("drop table if exists " + tbl)
+ case MYSQL, POSTGRESQL:
+ db.mustExec("drop table if exists " + tbl)
+ default:
+ db.Fatal("unknown dialect")
+ }
+ }
+}
+
+// q replaces ? parameters if needed
+func (db *TestDB) q(sql string) string {
+ switch db.dialect {
+ case POSTGRESQL: // repace with $1, $2, ..
+ qrx := regexp.MustCompile(`\?`)
+ n := 0
+ return qrx.ReplaceAllStringFunc(sql, func(string) string {
+ n++
+ return "$" + strconv.Itoa(n)
+ })
+ }
+ return sql
+}
+
+func (db *TestDB) blobType(size int) string {
+ switch db.dialect {
+ case SQLITE:
+ return fmt.Sprintf("blob[%d]", size)
+ case POSTGRESQL:
+ return "bytea"
+ case MYSQL:
+ return fmt.Sprintf("VARBINARY(%d)", size)
+ }
+ panic("unknown dialect")
+}
+
+func (db *TestDB) serialPK() string {
+ switch db.dialect {
+ case SQLITE:
+ return "integer primary key autoincrement"
+ case POSTGRESQL:
+ return "serial primary key"
+ case MYSQL:
+ return "integer primary key auto_increment"
+ }
+ panic("unknown dialect")
+}
+
+func (db *TestDB) now() string {
+ switch db.dialect {
+ case SQLITE:
+ return "datetime('now')"
+ case POSTGRESQL:
+ return "now()"
+ case MYSQL:
+ return "now()"
+ }
+ panic("unknown dialect")
+}
+
+func makeBench() {
+ if _, err := db.Exec("create table bench (n varchar(32), i integer, d double, s varchar(32), t datetime)"); err != nil {
+ panic(err)
+ }
+ st, err := db.Prepare("insert into bench values (?, ?, ?, ?, ?)")
+ if err != nil {
+ panic(err)
+ }
+ defer st.Close()
+ for i := 0; i < 100; i++ {
+ if _, err = st.Exec(nil, i, float64(i), fmt.Sprintf("%d", i), time.Now()); err != nil {
+ panic(err)
+ }
+ }
+}
+
+// testResult is test for result
+func testResult(t *testing.T) {
+ db.tearDown()
+ db.mustExec("create temporary table test (id " + db.serialPK() + ", name varchar(10))")
+
+ for i := 1; i < 3; i++ {
+ r := db.mustExec(db.q("insert into test (name) values (?)"), fmt.Sprintf("row %d", i))
+ n, err := r.RowsAffected()
+ if err != nil {
+ t.Fatal(err)
+ }
+ if n != 1 {
+ t.Errorf("got %v, want %v", n, 1)
+ }
+ n, err = r.LastInsertId()
+ if err != nil {
+ t.Fatal(err)
+ }
+ if n != int64(i) {
+ t.Errorf("got %v, want %v", n, i)
+ }
+ }
+ if _, err := db.Exec("error!"); err == nil {
+ t.Fatalf("expected error")
+ }
+}
+
+// testBlobs is test for blobs
+func testBlobs(t *testing.T) {
+ db.tearDown()
+ var blob = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}
+ db.mustExec("create table foo (id integer primary key, bar " + db.blobType(16) + ")")
+ db.mustExec(db.q("insert into foo (id, bar) values(?,?)"), 0, blob)
+
+ want := fmt.Sprintf("%x", blob)
+
+ b := make([]byte, 16)
+ err := db.QueryRow(db.q("select bar from foo where id = ?"), 0).Scan(&b)
+ got := fmt.Sprintf("%x", b)
+ if err != nil {
+ t.Errorf("[]byte scan: %v", err)
+ } else if got != want {
+ t.Errorf("for []byte, got %q; want %q", got, want)
+ }
+
+ err = db.QueryRow(db.q("select bar from foo where id = ?"), 0).Scan(&got)
+ want = string(blob)
+ if err != nil {
+ t.Errorf("string scan: %v", err)
+ } else if got != want {
+ t.Errorf("for string, got %q; want %q", got, want)
+ }
+}
+
+// testManyQueryRow is test for many query row
+func testManyQueryRow(t *testing.T) {
+ if testing.Short() {
+ t.Log("skipping in short mode")
+ return
+ }
+ db.tearDown()
+ db.mustExec("create table foo (id integer primary key, name varchar(50))")
+ db.mustExec(db.q("insert into foo (id, name) values(?,?)"), 1, "bob")
+ var name string
+ for i := 0; i < 10000; i++ {
+ err := db.QueryRow(db.q("select name from foo where id = ?"), 1).Scan(&name)
+ if err != nil || name != "bob" {
+ t.Fatalf("on query %d: err=%v, name=%q", i, err, name)
+ }
+ }
+}
+
+// testTxQuery is test for transactional query
+func testTxQuery(t *testing.T) {
+ db.tearDown()
+ tx, err := db.Begin()
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer tx.Rollback()
+
+ _, err = tx.Exec("create table foo (id integer primary key, name varchar(50))")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ _, err = tx.Exec(db.q("insert into foo (id, name) values(?,?)"), 1, "bob")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ r, err := tx.Query(db.q("select name from foo where id = ?"), 1)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer r.Close()
+
+ if !r.Next() {
+ if r.Err() != nil {
+ t.Fatal(err)
+ }
+ t.Fatal("expected one rows")
+ }
+
+ var name string
+ err = r.Scan(&name)
+ if err != nil {
+ t.Fatal(err)
+ }
+}
+
+// testPreparedStmt is test for prepared statement
+func testPreparedStmt(t *testing.T) {
+ db.tearDown()
+ db.mustExec("CREATE TABLE t (count INT)")
+ sel, err := db.Prepare("SELECT count FROM t ORDER BY count DESC")
+ if err != nil {
+ t.Fatalf("prepare 1: %v", err)
+ }
+ ins, err := db.Prepare(db.q("INSERT INTO t (count) VALUES (?)"))
+ if err != nil {
+ t.Fatalf("prepare 2: %v", err)
+ }
+
+ for n := 1; n <= 3; n++ {
+ if _, err := ins.Exec(n); err != nil {
+ t.Fatalf("insert(%d) = %v", n, err)
+ }
+ }
+
+ const nRuns = 10
+ var wg sync.WaitGroup
+ for i := 0; i < nRuns; i++ {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ for j := 0; j < 10; j++ {
+ count := 0
+ if err := sel.QueryRow().Scan(&count); err != nil && err != sql.ErrNoRows {
+ t.Errorf("Query: %v", err)
+ return
+ }
+ if _, err := ins.Exec(rand.Intn(100)); err != nil {
+ t.Errorf("Insert: %v", err)
+ return
+ }
+ }
+ }()
+ }
+ wg.Wait()
+}
+
+// Benchmarks need to use panic() since b.Error errors are lost when
+// running via testing.Benchmark() I would like to run these via go
+// test -bench but calling Benchmark() from a benchmark test
+// currently hangs go.
+
+// benchmarkExec is benchmark for exec
+func benchmarkExec(b *testing.B) {
+ for i := 0; i < b.N; i++ {
+ if _, err := db.Exec("select 1"); err != nil {
+ panic(err)
+ }
+ }
+}
+
+// benchmarkQuery is benchmark for query
+func benchmarkQuery(b *testing.B) {
+ for i := 0; i < b.N; i++ {
+ var n sql.NullString
+ var i int
+ var f float64
+ var s string
+ // var t time.Time
+ if err := db.QueryRow("select null, 1, 1.1, 'foo'").Scan(&n, &i, &f, &s); err != nil {
+ panic(err)
+ }
+ }
+}
+
+// benchmarkParams is benchmark for params
+func benchmarkParams(b *testing.B) {
+ for i := 0; i < b.N; i++ {
+ var n sql.NullString
+ var i int
+ var f float64
+ var s string
+ // var t time.Time
+ if err := db.QueryRow("select ?, ?, ?, ?", nil, 1, 1.1, "foo").Scan(&n, &i, &f, &s); err != nil {
+ panic(err)
+ }
+ }
+}
+
+// benchmarkStmt is benchmark for statement
+func benchmarkStmt(b *testing.B) {
+ st, err := db.Prepare("select ?, ?, ?, ?")
+ if err != nil {
+ panic(err)
+ }
+ defer st.Close()
+
+ for n := 0; n < b.N; n++ {
+ var n sql.NullString
+ var i int
+ var f float64
+ var s string
+ // var t time.Time
+ if err := st.QueryRow(nil, 1, 1.1, "foo").Scan(&n, &i, &f, &s); err != nil {
+ panic(err)
+ }
+ }
+}
+
+// benchmarkRows is benchmark for rows
+func benchmarkRows(b *testing.B) {
+ db.once.Do(makeBench)
+
+ for n := 0; n < b.N; n++ {
+ var n sql.NullString
+ var i int
+ var f float64
+ var s string
+ var t time.Time
+ r, err := db.Query("select * from bench")
+ if err != nil {
+ panic(err)
+ }
+ for r.Next() {
+ if err = r.Scan(&n, &i, &f, &s, &t); err != nil {
+ panic(err)
+ }
+ }
+ if err = r.Err(); err != nil {
+ panic(err)
+ }
+ }
+}
+
+// benchmarkStmtRows is benchmark for statement rows
+func benchmarkStmtRows(b *testing.B) {
+ db.once.Do(makeBench)
+
+ st, err := db.Prepare("select * from bench")
+ if err != nil {
+ panic(err)
+ }
+ defer st.Close()
+
+ for n := 0; n < b.N; n++ {
+ var n sql.NullString
+ var i int
+ var f float64
+ var s string
+ var t time.Time
+ r, err := st.Query()
+ if err != nil {
+ panic(err)
+ }
+ for r.Next() {
+ if err = r.Scan(&n, &i, &f, &s, &t); err != nil {
+ panic(err)
+ }
+ }
+ if err = r.Err(); err != nil {
+ panic(err)
+ }
+ }
+}
diff --git a/sqlite3_test/sqltest.go b/sqlite3_test/sqltest.go
deleted file mode 100644
index 0ad9c3a..0000000
--- a/sqlite3_test/sqltest.go
+++ /dev/null
@@ -1,423 +0,0 @@
-package sqlite3_test
-
-import (
- "database/sql"
- "fmt"
- "math/rand"
- "regexp"
- "strconv"
- "sync"
- "testing"
- "time"
-)
-
-// Dialect is a type of dialect of databases.
-type Dialect int
-
-// Dialects for databases.
-const (
- SQLITE Dialect = iota // SQLITE mean SQLite3 dialect
- POSTGRESQL // POSTGRESQL mean PostgreSQL dialect
- MYSQL // MYSQL mean MySQL dialect
-)
-
-// DB provide context for the tests
-type DB struct {
- *testing.T
- *sql.DB
- dialect Dialect
- once sync.Once
-}
-
-var db *DB
-
-// the following tables will be created and dropped during the test
-var testTables = []string{"foo", "bar", "t", "bench"}
-
-var tests = []testing.InternalTest{
- {Name: "TestBlobs", F: TestBlobs},
- {Name: "TestManyQueryRow", F: TestManyQueryRow},
- {Name: "TestTxQuery", F: TestTxQuery},
- {Name: "TestPreparedStmt", F: TestPreparedStmt},
-}
-
-var benchmarks = []testing.InternalBenchmark{
- {Name: "BenchmarkExec", F: BenchmarkExec},
- {Name: "BenchmarkQuery", F: BenchmarkQuery},
- {Name: "BenchmarkParams", F: BenchmarkParams},
- {Name: "BenchmarkStmt", F: BenchmarkStmt},
- {Name: "BenchmarkRows", F: BenchmarkRows},
- {Name: "BenchmarkStmtRows", F: BenchmarkStmtRows},
-}
-
-// RunTests runs the SQL test suite
-func RunTests(t *testing.T, d *sql.DB, dialect Dialect) {
- db = &DB{t, d, dialect, sync.Once{}}
- testing.RunTests(func(string, string) (bool, error) { return true, nil }, tests)
-
- if !testing.Short() {
- for _, b := range benchmarks {
- fmt.Printf("%-20s", b.Name)
- r := testing.Benchmark(b.F)
- fmt.Printf("%10d %10.0f req/s\n", r.N, float64(r.N)/r.T.Seconds())
- }
- }
- db.tearDown()
-}
-
-func (db *DB) mustExec(sql string, args ...interface{}) sql.Result {
- res, err := db.Exec(sql, args...)
- if err != nil {
- db.Fatalf("Error running %q: %v", sql, err)
- }
- return res
-}
-
-func (db *DB) tearDown() {
- for _, tbl := range testTables {
- switch db.dialect {
- case SQLITE:
- db.mustExec("drop table if exists " + tbl)
- case MYSQL, POSTGRESQL:
- db.mustExec("drop table if exists " + tbl)
- default:
- db.Fatal("unknown dialect")
- }
- }
-}
-
-// q replaces ? parameters if needed
-func (db *DB) q(sql string) string {
- switch db.dialect {
- case POSTGRESQL: // repace with $1, $2, ..
- qrx := regexp.MustCompile(`\?`)
- n := 0
- return qrx.ReplaceAllStringFunc(sql, func(string) string {
- n++
- return "$" + strconv.Itoa(n)
- })
- }
- return sql
-}
-
-func (db *DB) blobType(size int) string {
- switch db.dialect {
- case SQLITE:
- return fmt.Sprintf("blob[%d]", size)
- case POSTGRESQL:
- return "bytea"
- case MYSQL:
- return fmt.Sprintf("VARBINARY(%d)", size)
- }
- panic("unknown dialect")
-}
-
-func (db *DB) serialPK() string {
- switch db.dialect {
- case SQLITE:
- return "integer primary key autoincrement"
- case POSTGRESQL:
- return "serial primary key"
- case MYSQL:
- return "integer primary key auto_increment"
- }
- panic("unknown dialect")
-}
-
-func (db *DB) now() string {
- switch db.dialect {
- case SQLITE:
- return "datetime('now')"
- case POSTGRESQL:
- return "now()"
- case MYSQL:
- return "now()"
- }
- panic("unknown dialect")
-}
-
-func makeBench() {
- if _, err := db.Exec("create table bench (n varchar(32), i integer, d double, s varchar(32), t datetime)"); err != nil {
- panic(err)
- }
- st, err := db.Prepare("insert into bench values (?, ?, ?, ?, ?)")
- if err != nil {
- panic(err)
- }
- defer st.Close()
- for i := 0; i < 100; i++ {
- if _, err = st.Exec(nil, i, float64(i), fmt.Sprintf("%d", i), time.Now()); err != nil {
- panic(err)
- }
- }
-}
-
-// TestResult is test for result
-func TestResult(t *testing.T) {
- db.tearDown()
- db.mustExec("create temporary table test (id " + db.serialPK() + ", name varchar(10))")
-
- for i := 1; i < 3; i++ {
- r := db.mustExec(db.q("insert into test (name) values (?)"), fmt.Sprintf("row %d", i))
- n, err := r.RowsAffected()
- if err != nil {
- t.Fatal(err)
- }
- if n != 1 {
- t.Errorf("got %v, want %v", n, 1)
- }
- n, err = r.LastInsertId()
- if err != nil {
- t.Fatal(err)
- }
- if n != int64(i) {
- t.Errorf("got %v, want %v", n, i)
- }
- }
- if _, err := db.Exec("error!"); err == nil {
- t.Fatalf("expected error")
- }
-}
-
-// TestBlobs is test for blobs
-func TestBlobs(t *testing.T) {
- db.tearDown()
- var blob = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}
- db.mustExec("create table foo (id integer primary key, bar " + db.blobType(16) + ")")
- db.mustExec(db.q("insert into foo (id, bar) values(?,?)"), 0, blob)
-
- want := fmt.Sprintf("%x", blob)
-
- b := make([]byte, 16)
- err := db.QueryRow(db.q("select bar from foo where id = ?"), 0).Scan(&b)
- got := fmt.Sprintf("%x", b)
- if err != nil {
- t.Errorf("[]byte scan: %v", err)
- } else if got != want {
- t.Errorf("for []byte, got %q; want %q", got, want)
- }
-
- err = db.QueryRow(db.q("select bar from foo where id = ?"), 0).Scan(&got)
- want = string(blob)
- if err != nil {
- t.Errorf("string scan: %v", err)
- } else if got != want {
- t.Errorf("for string, got %q; want %q", got, want)
- }
-}
-
-// TestManyQueryRow is test for many query row
-func TestManyQueryRow(t *testing.T) {
- if testing.Short() {
- t.Log("skipping in short mode")
- return
- }
- db.tearDown()
- db.mustExec("create table foo (id integer primary key, name varchar(50))")
- db.mustExec(db.q("insert into foo (id, name) values(?,?)"), 1, "bob")
- var name string
- for i := 0; i < 10000; i++ {
- err := db.QueryRow(db.q("select name from foo where id = ?"), 1).Scan(&name)
- if err != nil || name != "bob" {
- t.Fatalf("on query %d: err=%v, name=%q", i, err, name)
- }
- }
-}
-
-// TestTxQuery is test for transactional query
-func TestTxQuery(t *testing.T) {
- db.tearDown()
- tx, err := db.Begin()
- if err != nil {
- t.Fatal(err)
- }
- defer tx.Rollback()
-
- _, err = tx.Exec("create table foo (id integer primary key, name varchar(50))")
- if err != nil {
- t.Fatal(err)
- }
-
- _, err = tx.Exec(db.q("insert into foo (id, name) values(?,?)"), 1, "bob")
- if err != nil {
- t.Fatal(err)
- }
-
- r, err := tx.Query(db.q("select name from foo where id = ?"), 1)
- if err != nil {
- t.Fatal(err)
- }
- defer r.Close()
-
- if !r.Next() {
- if r.Err() != nil {
- t.Fatal(err)
- }
- t.Fatal("expected one rows")
- }
-
- var name string
- err = r.Scan(&name)
- if err != nil {
- t.Fatal(err)
- }
-}
-
-// TestPreparedStmt is test for prepared statement
-func TestPreparedStmt(t *testing.T) {
- db.tearDown()
- db.mustExec("CREATE TABLE t (count INT)")
- sel, err := db.Prepare("SELECT count FROM t ORDER BY count DESC")
- if err != nil {
- t.Fatalf("prepare 1: %v", err)
- }
- ins, err := db.Prepare(db.q("INSERT INTO t (count) VALUES (?)"))
- if err != nil {
- t.Fatalf("prepare 2: %v", err)
- }
-
- for n := 1; n <= 3; n++ {
- if _, err := ins.Exec(n); err != nil {
- t.Fatalf("insert(%d) = %v", n, err)
- }
- }
-
- const nRuns = 10
- var wg sync.WaitGroup
- for i := 0; i < nRuns; i++ {
- wg.Add(1)
- go func() {
- defer wg.Done()
- for j := 0; j < 10; j++ {
- count := 0
- if err := sel.QueryRow().Scan(&count); err != nil && err != sql.ErrNoRows {
- t.Errorf("Query: %v", err)
- return
- }
- if _, err := ins.Exec(rand.Intn(100)); err != nil {
- t.Errorf("Insert: %v", err)
- return
- }
- }
- }()
- }
- wg.Wait()
-}
-
-// Benchmarks need to use panic() since b.Error errors are lost when
-// running via testing.Benchmark() I would like to run these via go
-// test -bench but calling Benchmark() from a benchmark test
-// currently hangs go.
-
-// BenchmarkExec is benchmark for exec
-func BenchmarkExec(b *testing.B) {
- for i := 0; i < b.N; i++ {
- if _, err := db.Exec("select 1"); err != nil {
- panic(err)
- }
- }
-}
-
-// BenchmarkQuery is benchmark for query
-func BenchmarkQuery(b *testing.B) {
- for i := 0; i < b.N; i++ {
- var n sql.NullString
- var i int
- var f float64
- var s string
- // var t time.Time
- if err := db.QueryRow("select null, 1, 1.1, 'foo'").Scan(&n, &i, &f, &s); err != nil {
- panic(err)
- }
- }
-}
-
-// BenchmarkParams is benchmark for params
-func BenchmarkParams(b *testing.B) {
- for i := 0; i < b.N; i++ {
- var n sql.NullString
- var i int
- var f float64
- var s string
- // var t time.Time
- if err := db.QueryRow("select ?, ?, ?, ?", nil, 1, 1.1, "foo").Scan(&n, &i, &f, &s); err != nil {
- panic(err)
- }
- }
-}
-
-// BenchmarkStmt is benchmark for statement
-func BenchmarkStmt(b *testing.B) {
- st, err := db.Prepare("select ?, ?, ?, ?")
- if err != nil {
- panic(err)
- }
- defer st.Close()
-
- for n := 0; n < b.N; n++ {
- var n sql.NullString
- var i int
- var f float64
- var s string
- // var t time.Time
- if err := st.QueryRow(nil, 1, 1.1, "foo").Scan(&n, &i, &f, &s); err != nil {
- panic(err)
- }
- }
-}
-
-// BenchmarkRows is benchmark for rows
-func BenchmarkRows(b *testing.B) {
- db.once.Do(makeBench)
-
- for n := 0; n < b.N; n++ {
- var n sql.NullString
- var i int
- var f float64
- var s string
- var t time.Time
- r, err := db.Query("select * from bench")
- if err != nil {
- panic(err)
- }
- for r.Next() {
- if err = r.Scan(&n, &i, &f, &s, &t); err != nil {
- panic(err)
- }
- }
- if err = r.Err(); err != nil {
- panic(err)
- }
- }
-}
-
-// BenchmarkStmtRows is benchmark for statement rows
-func BenchmarkStmtRows(b *testing.B) {
- db.once.Do(makeBench)
-
- st, err := db.Prepare("select * from bench")
- if err != nil {
- panic(err)
- }
- defer st.Close()
-
- for n := 0; n < b.N; n++ {
- var n sql.NullString
- var i int
- var f float64
- var s string
- var t time.Time
- r, err := st.Query()
- if err != nil {
- panic(err)
- }
- for r.Next() {
- if err = r.Scan(&n, &i, &f, &s, &t); err != nil {
- panic(err)
- }
- }
- if err = r.Err(); err != nil {
- panic(err)
- }
- }
-}