aboutsummaryrefslogtreecommitdiff
path: root/sqlite3.go
diff options
context:
space:
mode:
authormattn <mattn.jp@gmail.com>2019-11-19 01:19:53 +0900
committerGitHub <noreply@github.com>2019-11-19 01:19:53 +0900
commit590d44c02bca83987d23f6eab75e6d0ddf95f644 (patch)
tree2d1979323f14971f927380c6a687a523aa033486 /sqlite3.go
parentMerge pull request #760 from mattn/sqlite-amalgamation-3300100 (diff)
parentFix context cancellation racy handling (diff)
downloadgolite-590d44c02bca83987d23f6eab75e6d0ddf95f644.tar.gz
golite-590d44c02bca83987d23f6eab75e6d0ddf95f644.tar.xz
Merge pull request #744 from azavorotnii/ctx_cancel
Fix context cancellation racy handling
Diffstat (limited to 'sqlite3.go')
-rw-r--r--sqlite3.go96
1 files changed, 59 insertions, 37 deletions
diff --git a/sqlite3.go b/sqlite3.go
index 053e92d..7f0e7c0 100644
--- a/sqlite3.go
+++ b/sqlite3.go
@@ -328,7 +328,7 @@ type SQLiteRows struct {
decltype []string
cls bool
closed bool
- done chan struct{}
+ ctx context.Context // no better alternative to pass context into Next() method
}
type functionInfo struct {
@@ -1847,22 +1847,7 @@ func (s *SQLiteStmt) query(ctx context.Context, args []namedValue) (driver.Rows,
decltype: nil,
cls: s.cls,
closed: false,
- done: make(chan struct{}),
- }
-
- if ctxdone := ctx.Done(); ctxdone != nil {
- go func(db *C.sqlite3) {
- select {
- case <-ctxdone:
- select {
- case <-rows.done:
- default:
- C.sqlite3_interrupt(db)
- rows.Close()
- }
- case <-rows.done:
- }
- }(s.c.db)
+ ctx: ctx,
}
return rows, nil
@@ -1890,29 +1875,43 @@ func (s *SQLiteStmt) Exec(args []driver.Value) (driver.Result, error) {
return s.exec(context.Background(), list)
}
+// exec executes a query that doesn't return rows. Attempts to honor context timeout.
func (s *SQLiteStmt) exec(ctx context.Context, args []namedValue) (driver.Result, error) {
+ if ctx.Done() == nil {
+ return s.execSync(args)
+ }
+
+ type result struct {
+ r driver.Result
+ err error
+ }
+ resultCh := make(chan result)
+ go func() {
+ r, err := s.execSync(args)
+ resultCh <- result{r, err}
+ }()
+ select {
+ case rv := <- resultCh:
+ return rv.r, rv.err
+ case <-ctx.Done():
+ select {
+ case <-resultCh: // no need to interrupt
+ default:
+ // this is still racy and can be no-op if executed between sqlite3_* calls in execSync.
+ C.sqlite3_interrupt(s.c.db)
+ <-resultCh // ensure goroutine completed
+ }
+ return nil, ctx.Err()
+ }
+}
+
+func (s *SQLiteStmt) execSync(args []namedValue) (driver.Result, error) {
if err := s.bind(args); err != nil {
C.sqlite3_reset(s.s)
C.sqlite3_clear_bindings(s.s)
return nil, err
}
- if ctxdone := ctx.Done(); ctxdone != nil {
- done := make(chan struct{})
- defer close(done)
- go func(db *C.sqlite3) {
- select {
- case <-done:
- case <-ctxdone:
- select {
- case <-done:
- default:
- C.sqlite3_interrupt(db)
- }
- }
- }(s.c.db)
- }
-
var rowid, changes C.longlong
rv := C._sqlite3_step_row_internal(s.s, &rowid, &changes)
if rv != C.SQLITE_ROW && rv != C.SQLITE_OK && rv != C.SQLITE_DONE {
@@ -1933,9 +1932,6 @@ func (rc *SQLiteRows) Close() error {
return nil
}
rc.closed = true
- if rc.done != nil {
- close(rc.done)
- }
if rc.cls {
rc.s.mu.Unlock()
return rc.s.Close()
@@ -1979,13 +1975,39 @@ func (rc *SQLiteRows) DeclTypes() []string {
return rc.declTypes()
}
-// Next move cursor to next.
+// Next move cursor to next. Attempts to honor context timeout from QueryContext call.
func (rc *SQLiteRows) Next(dest []driver.Value) error {
rc.s.mu.Lock()
defer rc.s.mu.Unlock()
+
if rc.s.closed {
return io.EOF
}
+
+ if rc.ctx.Done() == nil {
+ return rc.nextSyncLocked(dest)
+ }
+ resultCh := make(chan error)
+ go func() {
+ resultCh <- rc.nextSyncLocked(dest)
+ }()
+ select {
+ case err := <- resultCh:
+ return err
+ case <-rc.ctx.Done():
+ select {
+ case <-resultCh: // no need to interrupt
+ default:
+ // this is still racy and can be no-op if executed between sqlite3_* calls in nextSyncLocked.
+ C.sqlite3_interrupt(rc.s.c.db)
+ <-resultCh // ensure goroutine completed
+ }
+ return rc.ctx.Err()
+ }
+}
+
+// nextSyncLocked moves cursor to next; must be called with locked mutex.
+func (rc *SQLiteRows) nextSyncLocked(dest []driver.Value) error {
rv := C._sqlite3_step_internal(rc.s.s)
if rv == C.SQLITE_DONE {
return io.EOF