diff options
Diffstat (limited to 'sqlite3_test.go')
-rw-r--r-- | sqlite3_test.go | 82 |
1 files changed, 77 insertions, 5 deletions
diff --git a/sqlite3_test.go b/sqlite3_test.go index 75d8f52..0667893 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -1299,10 +1299,7 @@ func TestAggregatorRegistration(t *testing.T) { sql.Register("sqlite3_AggregatorRegistration", &SQLiteDriver{ ConnectHook: func(conn *SQLiteConn) error { - if err := conn.RegisterAggregator("customSum", customSum, true); err != nil { - return err - } - return nil + return conn.RegisterAggregator("customSum", customSum, true) }, }) db, err := sql.Open("sqlite3_AggregatorRegistration", ":memory:") @@ -1574,6 +1571,47 @@ func TestUpdateAndTransactionHooks(t *testing.T) { } } +func TestAuthorizer(t *testing.T) { + var authorizerReturn = 0 + + sql.Register("sqlite3_Authorizer", &SQLiteDriver{ + ConnectHook: func(conn *SQLiteConn) error { + conn.RegisterAuthorizer(func(op int, arg1, arg2, arg3 string) int { + return authorizerReturn + }) + return nil + }, + }) + db, err := sql.Open("sqlite3_Authorizer", ":memory:") + if err != nil { + t.Fatal("Failed to open database:", err) + } + defer db.Close() + + statements := []string{ + "create table foo (id integer primary key, name varchar)", + "insert into foo values (9, 'test9')", + "update foo set name = 'test99' where id = 9", + "select * from foo", + } + + authorizerReturn = SQLITE_OK + for _, statement := range statements { + _, err = db.Exec(statement) + if err != nil { + t.Fatalf("No error expected [%v]: %v", statement, err) + } + } + + authorizerReturn = SQLITE_DENY + for _, statement := range statements { + _, err = db.Exec(statement) + if err == nil { + t.Fatalf("Authorizer didn't worked - nil received, but error expected: [%v]", statement) + } + } +} + func TestNilAndEmptyBytes(t *testing.T) { db, err := sql.Open("sqlite3", ":memory:") if err != nil { @@ -1690,7 +1728,10 @@ func TestSuite(t *testing.T) { defer d.Close() db = &TestDB{t, d, SQLITE, sync.Once{}} - testing.RunTests(func(string, string) (bool, error) { return true, nil }, tests) + ok := testing.RunTests(func(string, string) (bool, error) { return true, nil }, tests) + if !ok { + t.Fatal("A subtest failed") + } if !testing.Short() { for _, b := range benchmarks { @@ -1729,6 +1770,7 @@ var tests = []testing.InternalTest{ {Name: "TestResult", F: testResult}, {Name: "TestBlobs", F: testBlobs}, {Name: "TestMultiBlobs", F: testMultiBlobs}, + {Name: "TestNullZeroLengthBlobs", F: testNullZeroLengthBlobs}, {Name: "TestManyQueryRow", F: testManyQueryRow}, {Name: "TestTxQuery", F: testTxQuery}, {Name: "TestPreparedStmt", F: testPreparedStmt}, @@ -1934,6 +1976,36 @@ func testMultiBlobs(t *testing.T) { } } +// testBlobs tests that we distinguish between null and zero-length blobs +func testNullZeroLengthBlobs(t *testing.T) { + db.tearDown() + db.mustExec("create table foo (id integer primary key, bar " + db.blobType(16) + ")") + db.mustExec(db.q("insert into foo (id, bar) values(?,?)"), 0, nil) + db.mustExec(db.q("insert into foo (id, bar) values(?,?)"), 1, []byte{}) + + r0 := db.QueryRow(db.q("select bar from foo where id=0")) + var b0 []byte + err := r0.Scan(&b0) + if err != nil { + t.Fatal(err) + } + if b0 != nil { + t.Errorf("for id=0, got %x; want nil", b0) + } + + r1 := db.QueryRow(db.q("select bar from foo where id=1")) + var b1 []byte + err = r1.Scan(&b1) + if err != nil { + t.Fatal(err) + } + if b1 == nil { + t.Error("for id=1, got nil; want zero-length slice") + } else if len(b1) > 0 { + t.Errorf("for id=1, got %x; want zero-length slice", b1) + } +} + // testManyQueryRow is test for many query row func testManyQueryRow(t *testing.T) { if testing.Short() { |