aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--sqlite3.go96
-rw-r--r--sqlite3_go18_test.go88
2 files changed, 147 insertions, 37 deletions
diff --git a/sqlite3.go b/sqlite3.go
index 4000173..1f0730a 100644
--- a/sqlite3.go
+++ b/sqlite3.go
@@ -328,7 +328,7 @@ type SQLiteRows struct {
decltype []string
cls bool
closed bool
- done chan struct{}
+ ctx context.Context // no better alternative to pass context into Next() method
}
type functionInfo struct {
@@ -1846,22 +1846,7 @@ func (s *SQLiteStmt) query(ctx context.Context, args []namedValue) (driver.Rows,
decltype: nil,
cls: s.cls,
closed: false,
- done: make(chan struct{}),
- }
-
- if ctxdone := ctx.Done(); ctxdone != nil {
- go func(db *C.sqlite3) {
- select {
- case <-ctxdone:
- select {
- case <-rows.done:
- default:
- C.sqlite3_interrupt(db)
- rows.Close()
- }
- case <-rows.done:
- }
- }(s.c.db)
+ ctx: ctx,
}
return rows, nil
@@ -1889,29 +1874,43 @@ func (s *SQLiteStmt) Exec(args []driver.Value) (driver.Result, error) {
return s.exec(context.Background(), list)
}
+// exec executes a query that doesn't return rows. Attempts to honor context timeout.
func (s *SQLiteStmt) exec(ctx context.Context, args []namedValue) (driver.Result, error) {
+ if ctx.Done() == nil {
+ return s.execSync(args)
+ }
+
+ type result struct {
+ r driver.Result
+ err error
+ }
+ resultCh := make(chan result)
+ go func() {
+ r, err := s.execSync(args)
+ resultCh <- result{r, err}
+ }()
+ select {
+ case rv := <- resultCh:
+ return rv.r, rv.err
+ case <-ctx.Done():
+ select {
+ case <-resultCh: // no need to interrupt
+ default:
+ // this is still racy and can be no-op if executed between sqlite3_* calls in execSync.
+ C.sqlite3_interrupt(s.c.db)
+ <-resultCh // ensure goroutine completed
+ }
+ return nil, ctx.Err()
+ }
+}
+
+func (s *SQLiteStmt) execSync(args []namedValue) (driver.Result, error) {
if err := s.bind(args); err != nil {
C.sqlite3_reset(s.s)
C.sqlite3_clear_bindings(s.s)
return nil, err
}
- if ctxdone := ctx.Done(); ctxdone != nil {
- done := make(chan struct{})
- defer close(done)
- go func(db *C.sqlite3) {
- select {
- case <-done:
- case <-ctxdone:
- select {
- case <-done:
- default:
- C.sqlite3_interrupt(db)
- }
- }
- }(s.c.db)
- }
-
var rowid, changes C.longlong
rv := C._sqlite3_step_row_internal(s.s, &rowid, &changes)
if rv != C.SQLITE_ROW && rv != C.SQLITE_OK && rv != C.SQLITE_DONE {
@@ -1932,9 +1931,6 @@ func (rc *SQLiteRows) Close() error {
return nil
}
rc.closed = true
- if rc.done != nil {
- close(rc.done)
- }
if rc.cls {
rc.s.mu.Unlock()
return rc.s.Close()
@@ -1978,13 +1974,39 @@ func (rc *SQLiteRows) DeclTypes() []string {
return rc.declTypes()
}
-// Next move cursor to next.
+// Next move cursor to next. Attempts to honor context timeout from QueryContext call.
func (rc *SQLiteRows) Next(dest []driver.Value) error {
rc.s.mu.Lock()
defer rc.s.mu.Unlock()
+
if rc.s.closed {
return io.EOF
}
+
+ if rc.ctx.Done() == nil {
+ return rc.nextSyncLocked(dest)
+ }
+ resultCh := make(chan error)
+ go func() {
+ resultCh <- rc.nextSyncLocked(dest)
+ }()
+ select {
+ case err := <- resultCh:
+ return err
+ case <-rc.ctx.Done():
+ select {
+ case <-resultCh: // no need to interrupt
+ default:
+ // this is still racy and can be no-op if executed between sqlite3_* calls in nextSyncLocked.
+ C.sqlite3_interrupt(rc.s.c.db)
+ <-resultCh // ensure goroutine completed
+ }
+ return rc.ctx.Err()
+ }
+}
+
+// nextSyncLocked moves cursor to next; must be called with locked mutex.
+func (rc *SQLiteRows) nextSyncLocked(dest []driver.Value) error {
rv := C._sqlite3_step_internal(rc.s.s)
if rv == C.SQLITE_DONE {
return io.EOF
diff --git a/sqlite3_go18_test.go b/sqlite3_go18_test.go
index c9e79e7..37117b0 100644
--- a/sqlite3_go18_test.go
+++ b/sqlite3_go18_test.go
@@ -14,6 +14,7 @@ import (
"io/ioutil"
"math/rand"
"os"
+ "sync"
"testing"
"time"
)
@@ -135,6 +136,93 @@ func TestShortTimeout(t *testing.T) {
}
}
+func TestQueryRowContextCancel(t *testing.T) {
+ srcTempFilename := TempFilename(t)
+ defer os.Remove(srcTempFilename)
+
+ db, err := sql.Open("sqlite3", srcTempFilename)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer db.Close()
+ initDatabase(t, db, 100)
+
+ const query = `SELECT key_id FROM test_table ORDER BY key2 ASC`
+ var keyID string
+ unexpectedErrors := make(map[string]int)
+ for i := 0; i < 10000; i++ {
+ ctx, cancel := context.WithCancel(context.Background())
+ row := db.QueryRowContext(ctx, query)
+
+ cancel()
+ // it is fine to get "nil" as context cancellation can be handled with delay
+ if err := row.Scan(&keyID); err != nil && err != context.Canceled {
+ if err.Error() == "sql: Rows are closed" {
+ // see https://github.com/golang/go/issues/24431
+ // fixed in 1.11.1 to properly return context error
+ continue
+ }
+ unexpectedErrors[err.Error()]++
+ }
+ }
+ for errText, count := range unexpectedErrors {
+ t.Error(errText, count)
+ }
+}
+
+func TestQueryRowContextCancelParallel(t *testing.T) {
+ srcTempFilename := TempFilename(t)
+ defer os.Remove(srcTempFilename)
+
+ db, err := sql.Open("sqlite3", srcTempFilename)
+ if err != nil {
+ t.Fatal(err)
+ }
+ db.SetMaxOpenConns(10)
+ db.SetMaxIdleConns(5)
+
+ defer db.Close()
+ initDatabase(t, db, 100)
+
+ const query = `SELECT key_id FROM test_table ORDER BY key2 ASC`
+ wg := sync.WaitGroup{}
+ defer wg.Wait()
+
+ testCtx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ for i := 0; i < 50; i++ {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+
+ var keyID string
+ for {
+ select {
+ case <-testCtx.Done():
+ return
+ default:
+ }
+ ctx, cancel := context.WithCancel(context.Background())
+ row := db.QueryRowContext(ctx, query)
+
+ cancel()
+ _ = row.Scan(&keyID) // see TestQueryRowContextCancel
+ }
+ }()
+ }
+
+ var keyID string
+ for i := 0; i < 10000; i++ {
+ // note that testCtx is not cancelled during query execution
+ row := db.QueryRowContext(testCtx, query)
+
+ if err := row.Scan(&keyID); err != nil {
+ t.Fatal(i, err)
+ }
+ }
+}
+
func TestExecCancel(t *testing.T) {
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {