aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorYasuhiro Matsumoto <mattn.jp@gmail.com>2016-11-08 12:19:13 +0900
committerYasuhiro Matsumoto <mattn.jp@gmail.com>2016-11-08 12:19:51 +0900
commitdd2c82226baa4c72a4b99efb07d048e9d2e92796 (patch)
treeff713fe3531eea1db29723c33430d40ae37d5e39
parentMerge pull request #351 from andrefsp/data-race-fix (diff)
downloadgolite-dd2c82226baa4c72a4b99efb07d048e9d2e92796.tar.gz
golite-dd2c82226baa4c72a4b99efb07d048e9d2e92796.tar.xz
fix trace callback.
Close #352
-rw-r--r--callback.go4
-rw-r--r--sqlite3.go56
-rw-r--r--sqlite3_go18_test.go56
-rw-r--r--tracecallback.go12
4 files changed, 97 insertions, 31 deletions
diff --git a/callback.go b/callback.go
index 190b695..48fc63a 100644
--- a/callback.go
+++ b/callback.go
@@ -40,8 +40,8 @@ func callbackTrampoline(ctx *C.sqlite3_context, argc int, argv **C.sqlite3_value
}
//export stepTrampoline
-func stepTrampoline(ctx *C.sqlite3_context, argc int, argv **C.sqlite3_value) {
- args := (*[(math.MaxInt32 - 1) / unsafe.Sizeof((*C.sqlite3_value)(nil))]*C.sqlite3_value)(unsafe.Pointer(argv))[:argc:argc]
+func stepTrampoline(ctx *C.sqlite3_context, argc C.int, argv **C.sqlite3_value) {
+ args := (*[(math.MaxInt32 - 1) / unsafe.Sizeof((*C.sqlite3_value)(nil))]*C.sqlite3_value)(unsafe.Pointer(argv))[:int(argc):int(argc)]
ai := lookupHandle(uintptr(C.sqlite3_user_data(ctx))).(*aggInfo)
ai.Step(ctx, args)
}
diff --git a/sqlite3.go b/sqlite3.go
index 9b55ef1..71a791f 100644
--- a/sqlite3.go
+++ b/sqlite3.go
@@ -191,6 +191,7 @@ type SQLiteRows struct {
decltype []string
cls bool
done chan struct{}
+ next *SQLiteRows
}
type functionInfo struct {
@@ -296,19 +297,19 @@ func (ai *aggInfo) Done(ctx *C.sqlite3_context) {
// Commit transaction.
func (tx *SQLiteTx) Commit() error {
- _, err := tx.c.execQuery("COMMIT")
+ _, err := tx.c.exec(context.Background(), "COMMIT", nil)
if err != nil && err.(Error).Code == C.SQLITE_BUSY {
// sqlite3 will leave the transaction open in this scenario.
// However, database/sql considers the transaction complete once we
// return from Commit() - we must clean up to honour its semantics.
- tx.c.execQuery("ROLLBACK")
+ tx.c.exec(context.Background(), "ROLLBACK", nil)
}
return err
}
// Rollback transaction.
func (tx *SQLiteTx) Rollback() error {
- _, err := tx.c.execQuery("ROLLBACK")
+ _, err := tx.c.exec(context.Background(), "ROLLBACK", nil)
return err
}
@@ -382,13 +383,17 @@ func (c *SQLiteConn) RegisterFunc(name string, impl interface{}, pure bool) erro
if pure {
opts |= C.SQLITE_DETERMINISTIC
}
- rv := C._sqlite3_create_function(c.db, cname, C.int(numArgs), C.int(opts), C.uintptr_t(newHandle(c, &fi)), (*[0]byte)(unsafe.Pointer(C.callbackTrampoline)), nil, nil)
+ rv := sqlite3_create_function(c.db, cname, numArgs, opts, newHandle(c, &fi), C.callbackTrampoline, nil, nil)
if rv != C.SQLITE_OK {
return c.lastError()
}
return nil
}
+func sqlite3_create_function(db *C.sqlite3, zFunctionName *C.char, nArg C.int, eTextRep C.int, pApp C.uintptr_t, xFunc unsafe.Pointer, xStep unsafe.Pointer, xFinal unsafe.Pointer) C.int {
+ return C._sqlite3_create_function(db, zFunctionName, nArg, eTextRep, pApp, (*[0]byte)(unsafe.Pointer(xFunc)), (*[0]byte)(unsafe.Pointer(xStep)), (*[0]byte)(unsafe.Pointer(xFinal)))
+}
+
// AutoCommit return which currently auto commit or not.
func (c *SQLiteConn) AutoCommit() bool {
return int(C.sqlite3_get_autocommit(c.db)) != 0
@@ -404,10 +409,6 @@ func (c *SQLiteConn) lastError() Error {
// Exec implements Execer.
func (c *SQLiteConn) Exec(query string, args []driver.Value) (driver.Result, error) {
- if len(args) == 0 {
- return c.execQuery(query)
- }
-
list := make([]namedValue, len(args))
for i, v := range args {
list[i] = namedValue{
@@ -470,6 +471,7 @@ func (c *SQLiteConn) Query(query string, args []driver.Value) (driver.Rows, erro
}
func (c *SQLiteConn) query(ctx context.Context, query string, args []namedValue) (driver.Rows, error) {
+ var top, cur *SQLiteRows
start := 0
for {
s, err := c.Prepare(query)
@@ -487,7 +489,14 @@ func (c *SQLiteConn) query(ctx context.Context, query string, args []namedValue)
rows, err := s.(*SQLiteStmt).query(ctx, args[:na])
if err != nil && err != driver.ErrSkip {
s.Close()
- return nil, err
+ return top, err
+ }
+ if top == nil {
+ top = rows.(*SQLiteRows)
+ cur = top
+ } else {
+ cur.next = rows.(*SQLiteRows)
+ cur = cur.next
}
args = args[na:]
start += na
@@ -501,25 +510,13 @@ func (c *SQLiteConn) query(ctx context.Context, query string, args []namedValue)
}
}
-func (c *SQLiteConn) execQuery(cmd string) (driver.Result, error) {
- pcmd := C.CString(cmd)
- defer C.free(unsafe.Pointer(pcmd))
-
- var rowid, changes C.longlong
- rv := C._sqlite3_exec(c.db, pcmd, &rowid, &changes)
- if rv != C.SQLITE_OK {
- return nil, c.lastError()
- }
- return &SQLiteResult{int64(rowid), int64(changes)}, nil
-}
-
// Begin transaction.
func (c *SQLiteConn) Begin() (driver.Tx, error) {
return c.begin(context.Background())
}
func (c *SQLiteConn) begin(ctx context.Context) (driver.Tx, error) {
- if _, err := c.execQuery(c.txlock); err != nil {
+ if _, err := c.exec(ctx, c.txlock, nil); err != nil {
return nil, err
}
return &SQLiteTx{c}, nil
@@ -775,6 +772,7 @@ func (s *SQLiteStmt) query(ctx context.Context, args []namedValue) (driver.Rows,
decltype: nil,
cls: s.cls,
done: make(chan struct{}),
+ next: nil,
}
go func() {
@@ -837,7 +835,7 @@ func (s *SQLiteStmt) exec(ctx context.Context, args []namedValue) (driver.Result
return nil, err
}
- return &SQLiteResult{int64(rowid), int64(changes)}, nil
+ return &SQLiteResult{id: int64(rowid), changes: int64(changes)}, nil
}
// Close the rows.
@@ -972,3 +970,15 @@ func (rc *SQLiteRows) Next(dest []driver.Value) error {
}
return nil
}
+
+func (rc *SQLiteRows) HasNextResultSet() bool {
+ return rc.next != nil
+}
+
+func (rc *SQLiteRows) NextResultSet() error {
+ if rc.next == nil {
+ return io.EOF
+ }
+ *rc = *rc.next
+ return nil
+}
diff --git a/sqlite3_go18_test.go b/sqlite3_go18_test.go
index 54e6604..0536042 100644
--- a/sqlite3_go18_test.go
+++ b/sqlite3_go18_test.go
@@ -9,6 +9,7 @@ package sqlite3
import (
"database/sql"
+ "fmt"
"os"
"testing"
)
@@ -48,3 +49,58 @@ func TestNamedParams(t *testing.T) {
t.Error("Failed to db.QueryRow: not matched results")
}
}
+
+func TestMultipleResultSet(t *testing.T) {
+ tempFilename := TempFilename(t)
+ defer os.Remove(tempFilename)
+ db, err := sql.Open("sqlite3", tempFilename)
+ if err != nil {
+ t.Fatal("Failed to open database:", err)
+ }
+ defer db.Close()
+
+ _, err = db.Exec(`
+ create table foo (id integer, name text);
+ `)
+ if err != nil {
+ t.Error("Failed to call db.Query:", err)
+ }
+
+ for i := 0; i < 100; i++ {
+ _, err = db.Exec(`insert into foo(id, name) values(?, ?)`, i+1, fmt.Sprintf("foo%03d", i+1))
+ if err != nil {
+ t.Error("Failed to call db.Exec:", err)
+ }
+ }
+
+ rows, err := db.Query(`
+ select id, name from foo where id < :id1;
+ select id, name from foo where id = :id2;
+ select id, name from foo where id > :id3;
+ `,
+ sql.Param(":id1", 3),
+ sql.Param(":id2", 50),
+ sql.Param(":id3", 98),
+ )
+ if err != nil {
+ t.Error("Failed to call db.Query:", err)
+ }
+
+ var id int
+ var extra string
+
+ for {
+ for rows.Next() {
+ err = rows.Scan(&id, &extra)
+ if err != nil {
+ t.Error("Failed to db.Scan:", err)
+ }
+ if id != 1 || extra != "foo" {
+ t.Error("Failed to db.QueryRow: not matched results")
+ }
+ }
+ if !rows.NextResultSet() {
+ break
+ }
+ }
+}
diff --git a/tracecallback.go b/tracecallback.go
index bf222b5..9c42791 100644
--- a/tracecallback.go
+++ b/tracecallback.go
@@ -17,7 +17,7 @@ package sqlite3
void stepTrampoline(sqlite3_context*, int, sqlite3_value**);
void doneTrampoline(sqlite3_context*);
-void traceCallbackTrampoline(unsigned traceEventCode, void *ctx, void *p, void *x);
+int traceCallbackTrampoline(unsigned int traceEventCode, void *ctx, void *p, void *x);
*/
import "C"
@@ -76,7 +76,7 @@ type TraceUserCallback func(TraceInfo) int
type TraceConfig struct {
Callback TraceUserCallback
- EventMask uint
+ EventMask C.uint
WantExpandedSQL bool
}
@@ -102,13 +102,13 @@ func fillExpandedSQL(info *TraceInfo, db *C.sqlite3, pStmt unsafe.Pointer) {
//export traceCallbackTrampoline
func traceCallbackTrampoline(
- traceEventCode uint,
+ traceEventCode C.uint,
// Parameter named 'C' in SQLite docs = Context given at registration:
ctx unsafe.Pointer,
// Parameter named 'P' in SQLite docs (Primary event data?):
p unsafe.Pointer,
// Parameter named 'X' in SQLite docs (eXtra event data?):
- xValue unsafe.Pointer) int {
+ xValue unsafe.Pointer) C.int {
if ctx == nil {
panic(fmt.Sprintf("No context (ev 0x%x)", traceEventCode))
@@ -196,7 +196,7 @@ func traceCallbackTrampoline(
if traceConf.Callback != nil {
r = traceConf.Callback(info)
}
- return r
+ return C.int(r)
}
type traceMapEntry struct {
@@ -358,7 +358,7 @@ func (c *SQLiteConn) RegisterAggregator(name string, impl interface{}, pure bool
if pure {
opts |= C.SQLITE_DETERMINISTIC
}
- rv := C._sqlite3_create_function(c.db, cname, C.int(stepNArgs), C.int(opts), C.uintptr_t(newHandle(c, &ai)), nil, (*[0]byte)(unsafe.Pointer(C.stepTrampoline)), (*[0]byte)(unsafe.Pointer(C.doneTrampoline)))
+ rv := sqlite3_create_function(c.db, cname, C.int(stepNArgs), C.int(opts), C.uintptr_t(newHandle(c, &ai)), nil, C.stepTrampoline, C.doneTrampoline)
if rv != C.SQLITE_OK {
return c.lastError()
}