aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--sqlite3.go21
-rw-r--r--sqlite3_go113_test.go74
-rw-r--r--sqlite3_go18_test.go40
3 files changed, 129 insertions, 6 deletions
diff --git a/sqlite3.go b/sqlite3.go
index ababcfd..63e1c4f 100644
--- a/sqlite3.go
+++ b/sqlite3.go
@@ -1918,6 +1918,14 @@ func (s *SQLiteStmt) Exec(args []driver.Value) (driver.Result, error) {
return s.exec(context.Background(), list)
}
+func isInterruptErr(err error) bool {
+ sqliteErr, ok := err.(Error)
+ if ok {
+ return sqliteErr.Code == ErrInterrupt
+ }
+ return false
+}
+
// exec executes a query that doesn't return rows. Attempts to honor context timeout.
func (s *SQLiteStmt) exec(ctx context.Context, args []namedValue) (driver.Result, error) {
if ctx.Done() == nil {
@@ -1933,19 +1941,22 @@ func (s *SQLiteStmt) exec(ctx context.Context, args []namedValue) (driver.Result
r, err := s.execSync(args)
resultCh <- result{r, err}
}()
+ var rv result
select {
- case rv := <-resultCh:
- return rv.r, rv.err
+ case rv = <-resultCh:
case <-ctx.Done():
select {
- case <-resultCh: // no need to interrupt
+ case rv = <-resultCh: // no need to interrupt, operation completed in db
default:
// this is still racy and can be no-op if executed between sqlite3_* calls in execSync.
C.sqlite3_interrupt(s.c.db)
- <-resultCh // ensure goroutine completed
+ rv = <-resultCh // wait for goroutine completed
+ if isInterruptErr(rv.err) {
+ return nil, ctx.Err()
+ }
}
- return nil, ctx.Err()
}
+ return rv.r, rv.err
}
func (s *SQLiteStmt) execSync(args []namedValue) (driver.Result, error) {
diff --git a/sqlite3_go113_test.go b/sqlite3_go113_test.go
new file mode 100644
index 0000000..74036f8
--- /dev/null
+++ b/sqlite3_go113_test.go
@@ -0,0 +1,74 @@
+// Copyright (C) 2019 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
+//
+// Use of this source code is governed by an MIT-style
+// license that can be found in the LICENSE file.
+
+// +build go1.13,cgo
+
+package sqlite3
+
+import (
+ "context"
+ "database/sql"
+ "database/sql/driver"
+ "os"
+ "testing"
+)
+
+func TestBeginTxCancel(t *testing.T) {
+ srcTempFilename := TempFilename(t)
+ defer os.Remove(srcTempFilename)
+
+ db, err := sql.Open("sqlite3", srcTempFilename)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ db.SetMaxOpenConns(10)
+ db.SetMaxIdleConns(5)
+
+ defer db.Close()
+ initDatabase(t, db, 100)
+
+ // create several go-routines to expose racy issue
+ for i := 0; i < 1000; i++ {
+ func() {
+ ctx, cancel := context.WithCancel(context.Background())
+ conn, err := db.Conn(ctx)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer func() {
+ if err := conn.Close(); err != nil {
+ t.Error(err)
+ }
+ }()
+
+ err = conn.Raw(func(driverConn interface{}) error {
+ d, ok := driverConn.(driver.ConnBeginTx)
+ if !ok {
+ t.Fatal("unexpected: wrong type")
+ }
+
+ go cancel() // make it cancel concurrently with exec("BEGIN");
+ tx, err := d.BeginTx(ctx, driver.TxOptions{})
+ switch err {
+ case nil:
+ switch err := tx.Rollback(); err {
+ case nil, sql.ErrTxDone:
+ default:
+ return err
+ }
+ case context.Canceled:
+ default:
+ // must not fail with "cannot start a transaction within a transaction"
+ return err
+ }
+ return nil
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+ }()
+ }
+}
diff --git a/sqlite3_go18_test.go b/sqlite3_go18_test.go
index cfc89b0..5ee3d81 100644
--- a/sqlite3_go18_test.go
+++ b/sqlite3_go18_test.go
@@ -136,6 +136,44 @@ func TestShortTimeout(t *testing.T) {
}
}
+func TestExecContextCancel(t *testing.T) {
+ srcTempFilename := TempFilename(t)
+ defer os.Remove(srcTempFilename)
+
+ db, err := sql.Open("sqlite3", srcTempFilename)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ defer db.Close()
+
+ ts := time.Now()
+ initDatabase(t, db, 1000)
+ spent := time.Since(ts)
+ if spent < 100*time.Millisecond {
+ t.Skip("test will be too racy, as ExecContext below will be too fast.")
+ }
+
+ // expected to be extremely slow query
+ q := `
+INSERT INTO test_table (key1, key_id, key2, key3, key4, key5, key6, data)
+SELECT t1.key1 || t2.key1, t1.key_id || t2.key_id, t1.key2 || t2.key2, t1.key3 || t2.key3, t1.key4 || t2.key4, t1.key5 || t2.key5, t1.key6 || t2.key6, t1.data || t2.data
+FROM test_table t1 LEFT OUTER JOIN test_table t2`
+ // expect query above take ~ same time as setup above
+ ctx, cancel := context.WithTimeout(context.Background(), spent/2)
+ defer cancel()
+ ts = time.Now()
+ r, err := db.ExecContext(ctx, q)
+ // racy check
+ if r != nil {
+ n, err := r.RowsAffected()
+ t.Log(n, err, time.Since(ts))
+ }
+ if err != context.DeadlineExceeded {
+ t.Fatal(err, ctx.Err())
+ }
+}
+
func TestQueryRowContextCancel(t *testing.T) {
srcTempFilename := TempFilename(t)
defer os.Remove(srcTempFilename)
@@ -191,7 +229,7 @@ func TestQueryRowContextCancelParallel(t *testing.T) {
testCtx, cancel := context.WithCancel(context.Background())
defer cancel()
- for i := 0; i < 50; i++ {
+ for i := 0; i < 10; i++ {
wg.Add(1)
go func() {
defer wg.Done()