aboutsummaryrefslogtreecommitdiff
path: root/sqlite3_test.go
diff options
context:
space:
mode:
authormattn <mattn.jp@gmail.com>2017-08-30 19:57:18 +0900
committerGitHub <noreply@github.com>2017-08-30 19:57:18 +0900
commit132eeedb4ad10f217a357d731f930cbdd1360733 (patch)
tree45438e41d445e4496b5c80066b74ba760ac07fb1 /sqlite3_test.go
parentAdd support for collation sequences implemented in Go. (diff)
parentMerge pull request #461 from mattn/solaris (diff)
downloadgolite-132eeedb4ad10f217a357d731f930cbdd1360733.tar.gz
golite-132eeedb4ad10f217a357d731f930cbdd1360733.tar.xz
Merge branch 'master' into master
Diffstat (limited to 'sqlite3_test.go')
-rw-r--r--sqlite3_test.go582
1 files changed, 568 insertions, 14 deletions
diff --git a/sqlite3_test.go b/sqlite3_test.go
index 842f5d7..9d4b373 100644
--- a/sqlite3_test.go
+++ b/sqlite3_test.go
@@ -6,21 +6,22 @@
package sqlite3
import (
+ "bytes"
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"io/ioutil"
+ "math/rand"
"net/url"
"os"
"reflect"
"regexp"
+ "strconv"
"strings"
"sync"
"testing"
"time"
-
- "github.com/mattn/go-sqlite3/sqlite3_test"
)
func TempFilename(t *testing.T) string {
@@ -136,6 +137,35 @@ func TestForeignKeys(t *testing.T) {
}
}
+func TestRecursiveTriggers(t *testing.T) {
+ cases := map[string]bool{
+ "?_recursive_triggers=1": true,
+ "?_recursive_triggers=0": false,
+ }
+ for option, want := range cases {
+ fname := TempFilename(t)
+ uri := "file:" + fname + option
+ db, err := sql.Open("sqlite3", uri)
+ if err != nil {
+ os.Remove(fname)
+ t.Errorf("sql.Open(\"sqlite3\", %q): %v", uri, err)
+ continue
+ }
+ var enabled bool
+ err = db.QueryRow("PRAGMA recursive_triggers;").Scan(&enabled)
+ db.Close()
+ os.Remove(fname)
+ if err != nil {
+ t.Errorf("query recursive_triggers for %s: %v", uri, err)
+ continue
+ }
+ if enabled != want {
+ t.Errorf("\"PRAGMA recursive_triggers;\" for %q = %t; want %t", uri, enabled, want)
+ continue
+ }
+ }
+}
+
func TestClose(t *testing.T) {
tempFilename := TempFilename(t)
defer os.Remove(tempFilename)
@@ -403,6 +433,7 @@ func TestTimestamp(t *testing.T) {
}{
{"nonsense", time.Time{}},
{"0000-00-00 00:00:00", time.Time{}},
+ {time.Time{}.Unix(), time.Time{}},
{timestamp1, timestamp1},
{timestamp2.Unix(), timestamp2.Truncate(time.Second)},
{timestamp2.UnixNano() / int64(time.Millisecond), timestamp2.Truncate(time.Millisecond)},
@@ -840,18 +871,6 @@ func TestTimezoneConversion(t *testing.T) {
}
}
-func TestSuite(t *testing.T) {
- tempFilename := TempFilename(t)
- defer os.Remove(tempFilename)
- db, err := sql.Open("sqlite3", tempFilename+"?_busy_timeout=99999")
- if err != nil {
- t.Fatal(err)
- }
- defer db.Close()
-
- sqlite3_test.RunTests(t, db, sqlite3_test.SQLITE)
-}
-
// TODO: Execer & Queryer currently disabled
// https://github.com/mattn/go-sqlite3/issues/82
func TestExecer(t *testing.T) {
@@ -1385,6 +1404,122 @@ func TestPinger(t *testing.T) {
}
}
+func TestUpdateAndTransactionHooks(t *testing.T) {
+ var events []string
+ var commitHookReturn = 0
+
+ sql.Register("sqlite3_UpdateHook", &SQLiteDriver{
+ ConnectHook: func(conn *SQLiteConn) error {
+ conn.RegisterCommitHook(func() int {
+ events = append(events, "commit")
+ return commitHookReturn
+ })
+ conn.RegisterRollbackHook(func() {
+ events = append(events, "rollback")
+ })
+ conn.RegisterUpdateHook(func(op int, db string, table string, rowid int64) {
+ events = append(events, fmt.Sprintf("update(op=%v db=%v table=%v rowid=%v)", op, db, table, rowid))
+ })
+ return nil
+ },
+ })
+ db, err := sql.Open("sqlite3_UpdateHook", ":memory:")
+ if err != nil {
+ t.Fatal("Failed to open database:", err)
+ }
+ defer db.Close()
+
+ statements := []string{
+ "create table foo (id integer primary key)",
+ "insert into foo values (9)",
+ "update foo set id = 99 where id = 9",
+ "delete from foo where id = 99",
+ }
+ for _, statement := range statements {
+ _, err = db.Exec(statement)
+ if err != nil {
+ t.Fatalf("Unable to prepare test data [%v]: %v", statement, err)
+ }
+ }
+
+ commitHookReturn = 1
+ _, err = db.Exec("insert into foo values (5)")
+ if err == nil {
+ t.Error("Commit hook failed to rollback transaction")
+ }
+
+ var expected = []string{
+ "commit",
+ fmt.Sprintf("update(op=%v db=main table=foo rowid=9)", SQLITE_INSERT),
+ "commit",
+ fmt.Sprintf("update(op=%v db=main table=foo rowid=99)", SQLITE_UPDATE),
+ "commit",
+ fmt.Sprintf("update(op=%v db=main table=foo rowid=99)", SQLITE_DELETE),
+ "commit",
+ fmt.Sprintf("update(op=%v db=main table=foo rowid=5)", SQLITE_INSERT),
+ "commit",
+ "rollback",
+ }
+ if !reflect.DeepEqual(events, expected) {
+ t.Errorf("Expected notifications %v but got %v", expected, events)
+ }
+}
+
+func TestNilAndEmptyBytes(t *testing.T) {
+ db, err := sql.Open("sqlite3", ":memory:")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer db.Close()
+ actualNil := []byte("use this to use an actual nil not a reference to nil")
+ emptyBytes := []byte{}
+ for tsti, tst := range []struct {
+ name string
+ columnType string
+ insertBytes []byte
+ expectedBytes []byte
+ }{
+ {"actual nil blob", "blob", actualNil, nil},
+ {"referenced nil blob", "blob", nil, nil},
+ {"empty blob", "blob", emptyBytes, emptyBytes},
+ {"actual nil text", "text", actualNil, nil},
+ {"referenced nil text", "text", nil, nil},
+ {"empty text", "text", emptyBytes, emptyBytes},
+ } {
+ if _, err = db.Exec(fmt.Sprintf("create table tbl%d (txt %s)", tsti, tst.columnType)); err != nil {
+ t.Fatal(tst.name, err)
+ }
+ if bytes.Equal(tst.insertBytes, actualNil) {
+ if _, err = db.Exec(fmt.Sprintf("insert into tbl%d (txt) values (?)", tsti), nil); err != nil {
+ t.Fatal(tst.name, err)
+ }
+ } else {
+ if _, err = db.Exec(fmt.Sprintf("insert into tbl%d (txt) values (?)", tsti), &tst.insertBytes); err != nil {
+ t.Fatal(tst.name, err)
+ }
+ }
+ rows, err := db.Query(fmt.Sprintf("select txt from tbl%d", tsti))
+ if err != nil {
+ t.Fatal(tst.name, err)
+ }
+ if !rows.Next() {
+ t.Fatal(tst.name, "no rows")
+ }
+ var scanBytes []byte
+ if err = rows.Scan(&scanBytes); err != nil {
+ t.Fatal(tst.name, err)
+ }
+ if err = rows.Err(); err != nil {
+ t.Fatal(tst.name, err)
+ }
+ if tst.expectedBytes == nil && scanBytes != nil {
+ t.Errorf("%s: %#v != %#v", tst.name, scanBytes, tst.expectedBytes)
+ } else if !bytes.Equal(scanBytes, tst.expectedBytes) {
+ t.Errorf("%s: %#v != %#v", tst.name, scanBytes, tst.expectedBytes)
+ }
+ }
+}
+
var customFunctionOnce sync.Once
func BenchmarkCustomFunctions(b *testing.B) {
@@ -1419,3 +1554,422 @@ func BenchmarkCustomFunctions(b *testing.B) {
}
}
}
+
+func TestSuite(t *testing.T) {
+ tempFilename := TempFilename(t)
+ defer os.Remove(tempFilename)
+ d, err := sql.Open("sqlite3", tempFilename+"?_busy_timeout=99999")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer d.Close()
+
+ db = &TestDB{t, d, SQLITE, sync.Once{}}
+ testing.RunTests(func(string, string) (bool, error) { return true, nil }, tests)
+
+ if !testing.Short() {
+ for _, b := range benchmarks {
+ fmt.Printf("%-20s", b.Name)
+ r := testing.Benchmark(b.F)
+ fmt.Printf("%10d %10.0f req/s\n", r.N, float64(r.N)/r.T.Seconds())
+ }
+ }
+ db.tearDown()
+}
+
+// Dialect is a type of dialect of databases.
+type Dialect int
+
+// Dialects for databases.
+const (
+ SQLITE Dialect = iota // SQLITE mean SQLite3 dialect
+ POSTGRESQL // POSTGRESQL mean PostgreSQL dialect
+ MYSQL // MYSQL mean MySQL dialect
+)
+
+// DB provide context for the tests
+type TestDB struct {
+ *testing.T
+ *sql.DB
+ dialect Dialect
+ once sync.Once
+}
+
+var db *TestDB
+
+// the following tables will be created and dropped during the test
+var testTables = []string{"foo", "bar", "t", "bench"}
+
+var tests = []testing.InternalTest{
+ {Name: "TestResult", F: testResult},
+ {Name: "TestBlobs", F: testBlobs},
+ {Name: "TestManyQueryRow", F: testManyQueryRow},
+ {Name: "TestTxQuery", F: testTxQuery},
+ {Name: "TestPreparedStmt", F: testPreparedStmt},
+}
+
+var benchmarks = []testing.InternalBenchmark{
+ {Name: "BenchmarkExec", F: benchmarkExec},
+ {Name: "BenchmarkQuery", F: benchmarkQuery},
+ {Name: "BenchmarkParams", F: benchmarkParams},
+ {Name: "BenchmarkStmt", F: benchmarkStmt},
+ {Name: "BenchmarkRows", F: benchmarkRows},
+ {Name: "BenchmarkStmtRows", F: benchmarkStmtRows},
+}
+
+func (db *TestDB) mustExec(sql string, args ...interface{}) sql.Result {
+ res, err := db.Exec(sql, args...)
+ if err != nil {
+ db.Fatalf("Error running %q: %v", sql, err)
+ }
+ return res
+}
+
+func (db *TestDB) tearDown() {
+ for _, tbl := range testTables {
+ switch db.dialect {
+ case SQLITE:
+ db.mustExec("drop table if exists " + tbl)
+ case MYSQL, POSTGRESQL:
+ db.mustExec("drop table if exists " + tbl)
+ default:
+ db.Fatal("unknown dialect")
+ }
+ }
+}
+
+// q replaces ? parameters if needed
+func (db *TestDB) q(sql string) string {
+ switch db.dialect {
+ case POSTGRESQL: // repace with $1, $2, ..
+ qrx := regexp.MustCompile(`\?`)
+ n := 0
+ return qrx.ReplaceAllStringFunc(sql, func(string) string {
+ n++
+ return "$" + strconv.Itoa(n)
+ })
+ }
+ return sql
+}
+
+func (db *TestDB) blobType(size int) string {
+ switch db.dialect {
+ case SQLITE:
+ return fmt.Sprintf("blob[%d]", size)
+ case POSTGRESQL:
+ return "bytea"
+ case MYSQL:
+ return fmt.Sprintf("VARBINARY(%d)", size)
+ }
+ panic("unknown dialect")
+}
+
+func (db *TestDB) serialPK() string {
+ switch db.dialect {
+ case SQLITE:
+ return "integer primary key autoincrement"
+ case POSTGRESQL:
+ return "serial primary key"
+ case MYSQL:
+ return "integer primary key auto_increment"
+ }
+ panic("unknown dialect")
+}
+
+func (db *TestDB) now() string {
+ switch db.dialect {
+ case SQLITE:
+ return "datetime('now')"
+ case POSTGRESQL:
+ return "now()"
+ case MYSQL:
+ return "now()"
+ }
+ panic("unknown dialect")
+}
+
+func makeBench() {
+ if _, err := db.Exec("create table bench (n varchar(32), i integer, d double, s varchar(32), t datetime)"); err != nil {
+ panic(err)
+ }
+ st, err := db.Prepare("insert into bench values (?, ?, ?, ?, ?)")
+ if err != nil {
+ panic(err)
+ }
+ defer st.Close()
+ for i := 0; i < 100; i++ {
+ if _, err = st.Exec(nil, i, float64(i), fmt.Sprintf("%d", i), time.Now()); err != nil {
+ panic(err)
+ }
+ }
+}
+
+// testResult is test for result
+func testResult(t *testing.T) {
+ db.tearDown()
+ db.mustExec("create temporary table test (id " + db.serialPK() + ", name varchar(10))")
+
+ for i := 1; i < 3; i++ {
+ r := db.mustExec(db.q("insert into test (name) values (?)"), fmt.Sprintf("row %d", i))
+ n, err := r.RowsAffected()
+ if err != nil {
+ t.Fatal(err)
+ }
+ if n != 1 {
+ t.Errorf("got %v, want %v", n, 1)
+ }
+ n, err = r.LastInsertId()
+ if err != nil {
+ t.Fatal(err)
+ }
+ if n != int64(i) {
+ t.Errorf("got %v, want %v", n, i)
+ }
+ }
+ if _, err := db.Exec("error!"); err == nil {
+ t.Fatalf("expected error")
+ }
+}
+
+// testBlobs is test for blobs
+func testBlobs(t *testing.T) {
+ db.tearDown()
+ var blob = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}
+ db.mustExec("create table foo (id integer primary key, bar " + db.blobType(16) + ")")
+ db.mustExec(db.q("insert into foo (id, bar) values(?,?)"), 0, blob)
+
+ want := fmt.Sprintf("%x", blob)
+
+ b := make([]byte, 16)
+ err := db.QueryRow(db.q("select bar from foo where id = ?"), 0).Scan(&b)
+ got := fmt.Sprintf("%x", b)
+ if err != nil {
+ t.Errorf("[]byte scan: %v", err)
+ } else if got != want {
+ t.Errorf("for []byte, got %q; want %q", got, want)
+ }
+
+ err = db.QueryRow(db.q("select bar from foo where id = ?"), 0).Scan(&got)
+ want = string(blob)
+ if err != nil {
+ t.Errorf("string scan: %v", err)
+ } else if got != want {
+ t.Errorf("for string, got %q; want %q", got, want)
+ }
+}
+
+// testManyQueryRow is test for many query row
+func testManyQueryRow(t *testing.T) {
+ if testing.Short() {
+ t.Log("skipping in short mode")
+ return
+ }
+ db.tearDown()
+ db.mustExec("create table foo (id integer primary key, name varchar(50))")
+ db.mustExec(db.q("insert into foo (id, name) values(?,?)"), 1, "bob")
+ var name string
+ for i := 0; i < 10000; i++ {
+ err := db.QueryRow(db.q("select name from foo where id = ?"), 1).Scan(&name)
+ if err != nil || name != "bob" {
+ t.Fatalf("on query %d: err=%v, name=%q", i, err, name)
+ }
+ }
+}
+
+// testTxQuery is test for transactional query
+func testTxQuery(t *testing.T) {
+ db.tearDown()
+ tx, err := db.Begin()
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer tx.Rollback()
+
+ _, err = tx.Exec("create table foo (id integer primary key, name varchar(50))")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ _, err = tx.Exec(db.q("insert into foo (id, name) values(?,?)"), 1, "bob")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ r, err := tx.Query(db.q("select name from foo where id = ?"), 1)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer r.Close()
+
+ if !r.Next() {
+ if r.Err() != nil {
+ t.Fatal(err)
+ }
+ t.Fatal("expected one rows")
+ }
+
+ var name string
+ err = r.Scan(&name)
+ if err != nil {
+ t.Fatal(err)
+ }
+}
+
+// testPreparedStmt is test for prepared statement
+func testPreparedStmt(t *testing.T) {
+ db.tearDown()
+ db.mustExec("CREATE TABLE t (count INT)")
+ sel, err := db.Prepare("SELECT count FROM t ORDER BY count DESC")
+ if err != nil {
+ t.Fatalf("prepare 1: %v", err)
+ }
+ ins, err := db.Prepare(db.q("INSERT INTO t (count) VALUES (?)"))
+ if err != nil {
+ t.Fatalf("prepare 2: %v", err)
+ }
+
+ for n := 1; n <= 3; n++ {
+ if _, err := ins.Exec(n); err != nil {
+ t.Fatalf("insert(%d) = %v", n, err)
+ }
+ }
+
+ const nRuns = 10
+ var wg sync.WaitGroup
+ for i := 0; i < nRuns; i++ {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ for j := 0; j < 10; j++ {
+ count := 0
+ if err := sel.QueryRow().Scan(&count); err != nil && err != sql.ErrNoRows {
+ t.Errorf("Query: %v", err)
+ return
+ }
+ if _, err := ins.Exec(rand.Intn(100)); err != nil {
+ t.Errorf("Insert: %v", err)
+ return
+ }
+ }
+ }()
+ }
+ wg.Wait()
+}
+
+// Benchmarks need to use panic() since b.Error errors are lost when
+// running via testing.Benchmark() I would like to run these via go
+// test -bench but calling Benchmark() from a benchmark test
+// currently hangs go.
+
+// benchmarkExec is benchmark for exec
+func benchmarkExec(b *testing.B) {
+ for i := 0; i < b.N; i++ {
+ if _, err := db.Exec("select 1"); err != nil {
+ panic(err)
+ }
+ }
+}
+
+// benchmarkQuery is benchmark for query
+func benchmarkQuery(b *testing.B) {
+ for i := 0; i < b.N; i++ {
+ var n sql.NullString
+ var i int
+ var f float64
+ var s string
+ // var t time.Time
+ if err := db.QueryRow("select null, 1, 1.1, 'foo'").Scan(&n, &i, &f, &s); err != nil {
+ panic(err)
+ }
+ }
+}
+
+// benchmarkParams is benchmark for params
+func benchmarkParams(b *testing.B) {
+ for i := 0; i < b.N; i++ {
+ var n sql.NullString
+ var i int
+ var f float64
+ var s string
+ // var t time.Time
+ if err := db.QueryRow("select ?, ?, ?, ?", nil, 1, 1.1, "foo").Scan(&n, &i, &f, &s); err != nil {
+ panic(err)
+ }
+ }
+}
+
+// benchmarkStmt is benchmark for statement
+func benchmarkStmt(b *testing.B) {
+ st, err := db.Prepare("select ?, ?, ?, ?")
+ if err != nil {
+ panic(err)
+ }
+ defer st.Close()
+
+ for n := 0; n < b.N; n++ {
+ var n sql.NullString
+ var i int
+ var f float64
+ var s string
+ // var t time.Time
+ if err := st.QueryRow(nil, 1, 1.1, "foo").Scan(&n, &i, &f, &s); err != nil {
+ panic(err)
+ }
+ }
+}
+
+// benchmarkRows is benchmark for rows
+func benchmarkRows(b *testing.B) {
+ db.once.Do(makeBench)
+
+ for n := 0; n < b.N; n++ {
+ var n sql.NullString
+ var i int
+ var f float64
+ var s string
+ var t time.Time
+ r, err := db.Query("select * from bench")
+ if err != nil {
+ panic(err)
+ }
+ for r.Next() {
+ if err = r.Scan(&n, &i, &f, &s, &t); err != nil {
+ panic(err)
+ }
+ }
+ if err = r.Err(); err != nil {
+ panic(err)
+ }
+ }
+}
+
+// benchmarkStmtRows is benchmark for statement rows
+func benchmarkStmtRows(b *testing.B) {
+ db.once.Do(makeBench)
+
+ st, err := db.Prepare("select * from bench")
+ if err != nil {
+ panic(err)
+ }
+ defer st.Close()
+
+ for n := 0; n < b.N; n++ {
+ var n sql.NullString
+ var i int
+ var f float64
+ var s string
+ var t time.Time
+ r, err := st.Query()
+ if err != nil {
+ panic(err)
+ }
+ for r.Next() {
+ if err = r.Scan(&n, &i, &f, &s, &t); err != nil {
+ panic(err)
+ }
+ }
+ if err = r.Err(); err != nil {
+ panic(err)
+ }
+ }
+}