diff options
-rw-r--r-- | sqlite3.go | 21 | ||||
-rw-r--r-- | sqlite3_go113_test.go | 74 | ||||
-rw-r--r-- | sqlite3_go18_test.go | 40 |
3 files changed, 129 insertions, 6 deletions
@@ -1918,6 +1918,14 @@ func (s *SQLiteStmt) Exec(args []driver.Value) (driver.Result, error) { return s.exec(context.Background(), list) } +func isInterruptErr(err error) bool { + sqliteErr, ok := err.(Error) + if ok { + return sqliteErr.Code == ErrInterrupt + } + return false +} + // exec executes a query that doesn't return rows. Attempts to honor context timeout. func (s *SQLiteStmt) exec(ctx context.Context, args []namedValue) (driver.Result, error) { if ctx.Done() == nil { @@ -1933,19 +1941,22 @@ func (s *SQLiteStmt) exec(ctx context.Context, args []namedValue) (driver.Result r, err := s.execSync(args) resultCh <- result{r, err} }() + var rv result select { - case rv := <-resultCh: - return rv.r, rv.err + case rv = <-resultCh: case <-ctx.Done(): select { - case <-resultCh: // no need to interrupt + case rv = <-resultCh: // no need to interrupt, operation completed in db default: // this is still racy and can be no-op if executed between sqlite3_* calls in execSync. C.sqlite3_interrupt(s.c.db) - <-resultCh // ensure goroutine completed + rv = <-resultCh // wait for goroutine completed + if isInterruptErr(rv.err) { + return nil, ctx.Err() + } } - return nil, ctx.Err() } + return rv.r, rv.err } func (s *SQLiteStmt) execSync(args []namedValue) (driver.Result, error) { diff --git a/sqlite3_go113_test.go b/sqlite3_go113_test.go new file mode 100644 index 0000000..74036f8 --- /dev/null +++ b/sqlite3_go113_test.go @@ -0,0 +1,74 @@ +// Copyright (C) 2019 Yasuhiro Matsumoto <mattn.jp@gmail.com>. +// +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// +build go1.13,cgo + +package sqlite3 + +import ( + "context" + "database/sql" + "database/sql/driver" + "os" + "testing" +) + +func TestBeginTxCancel(t *testing.T) { + srcTempFilename := TempFilename(t) + defer os.Remove(srcTempFilename) + + db, err := sql.Open("sqlite3", srcTempFilename) + if err != nil { + t.Fatal(err) + } + + db.SetMaxOpenConns(10) + db.SetMaxIdleConns(5) + + defer db.Close() + initDatabase(t, db, 100) + + // create several go-routines to expose racy issue + for i := 0; i < 1000; i++ { + func() { + ctx, cancel := context.WithCancel(context.Background()) + conn, err := db.Conn(ctx) + if err != nil { + t.Fatal(err) + } + defer func() { + if err := conn.Close(); err != nil { + t.Error(err) + } + }() + + err = conn.Raw(func(driverConn interface{}) error { + d, ok := driverConn.(driver.ConnBeginTx) + if !ok { + t.Fatal("unexpected: wrong type") + } + + go cancel() // make it cancel concurrently with exec("BEGIN"); + tx, err := d.BeginTx(ctx, driver.TxOptions{}) + switch err { + case nil: + switch err := tx.Rollback(); err { + case nil, sql.ErrTxDone: + default: + return err + } + case context.Canceled: + default: + // must not fail with "cannot start a transaction within a transaction" + return err + } + return nil + }) + if err != nil { + t.Fatal(err) + } + }() + } +} diff --git a/sqlite3_go18_test.go b/sqlite3_go18_test.go index cfc89b0..5ee3d81 100644 --- a/sqlite3_go18_test.go +++ b/sqlite3_go18_test.go @@ -136,6 +136,44 @@ func TestShortTimeout(t *testing.T) { } } +func TestExecContextCancel(t *testing.T) { + srcTempFilename := TempFilename(t) + defer os.Remove(srcTempFilename) + + db, err := sql.Open("sqlite3", srcTempFilename) + if err != nil { + t.Fatal(err) + } + + defer db.Close() + + ts := time.Now() + initDatabase(t, db, 1000) + spent := time.Since(ts) + if spent < 100*time.Millisecond { + t.Skip("test will be too racy, as ExecContext below will be too fast.") + } + + // expected to be extremely slow query + q := ` +INSERT INTO test_table (key1, key_id, key2, key3, key4, key5, key6, data) +SELECT t1.key1 || t2.key1, t1.key_id || t2.key_id, t1.key2 || t2.key2, t1.key3 || t2.key3, t1.key4 || t2.key4, t1.key5 || t2.key5, t1.key6 || t2.key6, t1.data || t2.data +FROM test_table t1 LEFT OUTER JOIN test_table t2` + // expect query above take ~ same time as setup above + ctx, cancel := context.WithTimeout(context.Background(), spent/2) + defer cancel() + ts = time.Now() + r, err := db.ExecContext(ctx, q) + // racy check + if r != nil { + n, err := r.RowsAffected() + t.Log(n, err, time.Since(ts)) + } + if err != context.DeadlineExceeded { + t.Fatal(err, ctx.Err()) + } +} + func TestQueryRowContextCancel(t *testing.T) { srcTempFilename := TempFilename(t) defer os.Remove(srcTempFilename) @@ -191,7 +229,7 @@ func TestQueryRowContextCancelParallel(t *testing.T) { testCtx, cancel := context.WithCancel(context.Background()) defer cancel() - for i := 0; i < 50; i++ { + for i := 0; i < 10; i++ { wg.Add(1) go func() { defer wg.Done() |