diff options
author | mattn <mattn.jp@gmail.com> | 2017-08-30 19:57:18 +0900 |
---|---|---|
committer | GitHub <noreply@github.com> | 2017-08-30 19:57:18 +0900 |
commit | 132eeedb4ad10f217a357d731f930cbdd1360733 (patch) | |
tree | 45438e41d445e4496b5c80066b74ba760ac07fb1 /sqlite3.go | |
parent | Add support for collation sequences implemented in Go. (diff) | |
parent | Merge pull request #461 from mattn/solaris (diff) | |
download | golite-132eeedb4ad10f217a357d731f930cbdd1360733.tar.gz golite-132eeedb4ad10f217a357d731f930cbdd1360733.tar.xz |
Merge branch 'master' into master
Diffstat (limited to 'sqlite3.go')
-rw-r--r-- | sqlite3.go | 152 |
1 files changed, 136 insertions, 16 deletions
@@ -7,7 +7,7 @@ package sqlite3 /* #cgo CFLAGS: -std=gnu99 -#cgo CFLAGS: -DSQLITE_ENABLE_RTREE -DSQLITE_THREADSAFE +#cgo CFLAGS: -DSQLITE_ENABLE_RTREE -DSQLITE_THREADSAFE=1 #cgo CFLAGS: -DSQLITE_ENABLE_FTS3 -DSQLITE_ENABLE_FTS3_PARENTHESIS -DSQLITE_ENABLE_FTS4_UNICODE61 #cgo CFLAGS: -DSQLITE_TRACE_SIZE_LIMIT=15 #cgo CFLAGS: -DSQLITE_DISABLE_INTRINSIC @@ -102,6 +102,9 @@ int _sqlite3_create_function( void callbackTrampoline(sqlite3_context*, int, sqlite3_value**); int compareTrampoline(void*, int, char*, int, char*); +int commitHookTrampoline(void*); +void rollbackHookTrampoline(void*); +void updateHookTrampoline(void*, int, char*, char*, sqlite3_int64); */ import "C" import ( @@ -115,6 +118,7 @@ import ( "runtime" "strconv" "strings" + "sync" "time" "unsafe" @@ -151,6 +155,12 @@ func Version() (libVersion string, libVersionNumber int, sourceID string) { return libVersion, libVersionNumber, sourceID } +const ( + SQLITE_DELETE = C.SQLITE_DELETE + SQLITE_INSERT = C.SQLITE_INSERT + SQLITE_UPDATE = C.SQLITE_UPDATE +) + // SQLiteDriver implement sql.Driver. type SQLiteDriver struct { Extensions []string @@ -159,6 +169,7 @@ type SQLiteDriver struct { // SQLiteConn implement sql.Conn. type SQLiteConn struct { + mu sync.Mutex db *C.sqlite3 loc *time.Location txlock string @@ -173,6 +184,7 @@ type SQLiteTx struct { // SQLiteStmt implement sql.Stmt. type SQLiteStmt struct { + mu sync.Mutex c *SQLiteConn s *C.sqlite3_stmt t string @@ -193,6 +205,7 @@ type SQLiteRows struct { cols []string decltype []string cls bool + closed bool done chan struct{} } @@ -338,6 +351,51 @@ func (c *SQLiteConn) RegisterCollation(name string, cmp func(string, string) int return nil } +// RegisterCommitHook sets the commit hook for a connection. +// +// If the callback returns non-zero the transaction will become a rollback. +// +// If there is an existing commit hook for this connection, it will be +// removed. If callback is nil the existing hook (if any) will be removed +// without creating a new one. +func (c *SQLiteConn) RegisterCommitHook(callback func() int) { + if callback == nil { + C.sqlite3_commit_hook(c.db, nil, nil) + } else { + C.sqlite3_commit_hook(c.db, (*[0]byte)(unsafe.Pointer(C.commitHookTrampoline)), unsafe.Pointer(newHandle(c, callback))) + } +} + +// RegisterRollbackHook sets the rollback hook for a connection. +// +// If there is an existing rollback hook for this connection, it will be +// removed. If callback is nil the existing hook (if any) will be removed +// without creating a new one. +func (c *SQLiteConn) RegisterRollbackHook(callback func()) { + if callback == nil { + C.sqlite3_rollback_hook(c.db, nil, nil) + } else { + C.sqlite3_rollback_hook(c.db, (*[0]byte)(unsafe.Pointer(C.rollbackHookTrampoline)), unsafe.Pointer(newHandle(c, callback))) + } +} + +// RegisterUpdateHook sets the update hook for a connection. +// +// The parameters to the callback are the operation (one of the constants +// SQLITE_INSERT, SQLITE_DELETE, or SQLITE_UPDATE), the database name, the +// table name, and the rowid. +// +// If there is an existing update hook for this connection, it will be +// removed. If callback is nil the existing hook (if any) will be removed +// without creating a new one. +func (c *SQLiteConn) RegisterUpdateHook(callback func(int, string, string, int64)) { + if callback == nil { + C.sqlite3_update_hook(c.db, nil, nil) + } else { + C.sqlite3_update_hook(c.db, (*[0]byte)(unsafe.Pointer(C.updateHookTrampoline)), unsafe.Pointer(newHandle(c, callback))) + } +} + // RegisterFunc makes a Go function available as a SQLite function. // // The Go function can have arguments of the following types: any @@ -568,6 +626,8 @@ func errorString(err Error) string { // "deferred", "exclusive". // _foreign_keys=X // Enable or disable enforcement of foreign keys. X can be 1 or 0. +// _recursive_triggers=X +// Enable or disable recursive triggers. X can be 1 or 0. func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) { if C.sqlite3_threadsafe() == 0 { return nil, errors.New("sqlite library was not compiled for thread-safe operation") @@ -577,6 +637,7 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) { txlock := "BEGIN" busyTimeout := 5000 foreignKeys := -1 + recursiveTriggers := -1 pos := strings.IndexRune(dsn, '?') if pos >= 1 { params, err := url.ParseQuery(dsn[pos+1:]) @@ -631,6 +692,18 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) { } } + // _recursive_triggers + if val := params.Get("_recursive_triggers"); val != "" { + switch val { + case "1": + recursiveTriggers = 1 + case "0": + recursiveTriggers = 0 + default: + return nil, fmt.Errorf("Invalid _recursive_triggers: %v", val) + } + } + if !strings.HasPrefix(dsn, "file:") { dsn = dsn[:pos] } @@ -677,6 +750,17 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) { return nil, err } } + if recursiveTriggers == 0 { + if err := exec("PRAGMA recursive_triggers = OFF;"); err != nil { + C.sqlite3_close_v2(db) + return nil, err + } + } else if recursiveTriggers == 1 { + if err := exec("PRAGMA recursive_triggers = ON;"); err != nil { + C.sqlite3_close_v2(db) + return nil, err + } + } conn := &SQLiteConn{db: db, loc: loc, txlock: txlock} @@ -704,11 +788,22 @@ func (c *SQLiteConn) Close() error { return c.lastError() } deleteHandles(c) + c.mu.Lock() c.db = nil + c.mu.Unlock() runtime.SetFinalizer(c, nil) return nil } +func (c *SQLiteConn) dbConnOpen() bool { + if c == nil { + return false + } + c.mu.Lock() + defer c.mu.Unlock() + return c.db != nil +} + // Prepare the query string. Return a new statement. func (c *SQLiteConn) Prepare(query string) (driver.Stmt, error) { return c.prepare(context.Background(), query) @@ -734,14 +829,17 @@ func (c *SQLiteConn) prepare(ctx context.Context, query string) (driver.Stmt, er // Close the statement. func (s *SQLiteStmt) Close() error { + s.mu.Lock() + defer s.mu.Unlock() if s.closed { return nil } s.closed = true - if s.c == nil || s.c.db == nil { + if !s.c.dbConnOpen() { return errors.New("sqlite statement with already closed database connection") } rv := C.sqlite3_finalize(s.s) + s.s = nil if rv != C.SQLITE_OK { return s.c.lastError() } @@ -759,6 +857,8 @@ type bindArg struct { v driver.Value } +var placeHolder = []byte{0} + func (s *SQLiteStmt) bind(args []namedValue) error { rv := C.sqlite3_reset(s.s) if rv != C.SQLITE_ROW && rv != C.SQLITE_OK && rv != C.SQLITE_DONE { @@ -780,8 +880,7 @@ func (s *SQLiteStmt) bind(args []namedValue) error { rv = C.sqlite3_bind_null(s.s, n) case string: if len(v) == 0 { - b := []byte{0} - rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(0)) + rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&placeHolder[0])), C.int(0)) } else { b := []byte(v) rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b))) @@ -797,11 +896,11 @@ func (s *SQLiteStmt) bind(args []namedValue) error { case float64: rv = C.sqlite3_bind_double(s.s, n, C.double(v)) case []byte: - if len(v) == 0 { - rv = C._sqlite3_bind_blob(s.s, n, nil, 0) - } else { - rv = C._sqlite3_bind_blob(s.s, n, unsafe.Pointer(&v[0]), C.int(len(v))) + ln := len(v) + if ln == 0 { + v = placeHolder } + rv = C._sqlite3_bind_blob(s.s, n, unsafe.Pointer(&v[0]), C.int(ln)) case time.Time: b := []byte(v.Format(SQLiteTimestampFormats[0])) rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b))) @@ -836,6 +935,7 @@ func (s *SQLiteStmt) query(ctx context.Context, args []namedValue) (driver.Rows, cols: nil, decltype: nil, cls: s.cls, + closed: false, done: make(chan struct{}), } @@ -908,25 +1008,33 @@ func (s *SQLiteStmt) exec(ctx context.Context, args []namedValue) (driver.Result // Close the rows. func (rc *SQLiteRows) Close() error { - if rc.s.closed { + rc.s.mu.Lock() + if rc.s.closed || rc.closed { + rc.s.mu.Unlock() return nil } + rc.closed = true if rc.done != nil { close(rc.done) } if rc.cls { + rc.s.mu.Unlock() return rc.s.Close() } rv := C.sqlite3_reset(rc.s.s) if rv != C.SQLITE_OK { + rc.s.mu.Unlock() return rc.s.c.lastError() } + rc.s.mu.Unlock() return nil } // Columns return column names. func (rc *SQLiteRows) Columns() []string { - if rc.nc != len(rc.cols) { + rc.s.mu.Lock() + defer rc.s.mu.Unlock() + if rc.s.s != nil && rc.nc != len(rc.cols) { rc.cols = make([]string, rc.nc) for i := 0; i < rc.nc; i++ { rc.cols[i] = C.GoString(C.sqlite3_column_name(rc.s.s, C.int(i))) @@ -935,9 +1043,8 @@ func (rc *SQLiteRows) Columns() []string { return rc.cols } -// DeclTypes return column types. -func (rc *SQLiteRows) DeclTypes() []string { - if rc.decltype == nil { +func (rc *SQLiteRows) declTypes() []string { + if rc.s.s != nil && rc.decltype == nil { rc.decltype = make([]string, rc.nc) for i := 0; i < rc.nc; i++ { rc.decltype[i] = strings.ToLower(C.GoString(C.sqlite3_column_decltype(rc.s.s, C.int(i)))) @@ -946,8 +1053,20 @@ func (rc *SQLiteRows) DeclTypes() []string { return rc.decltype } +// DeclTypes return column types. +func (rc *SQLiteRows) DeclTypes() []string { + rc.s.mu.Lock() + defer rc.s.mu.Unlock() + return rc.declTypes() +} + // Next move cursor to next. func (rc *SQLiteRows) Next(dest []driver.Value) error { + if rc.s.closed { + return io.EOF + } + rc.s.mu.Lock() + defer rc.s.mu.Unlock() rv := C.sqlite3_step(rc.s.s) if rv == C.SQLITE_DONE { return io.EOF @@ -960,7 +1079,7 @@ func (rc *SQLiteRows) Next(dest []driver.Value) error { return nil } - rc.DeclTypes() + rc.declTypes() for i := range dest { switch C.sqlite3_column_type(rc.s.s, C.int(i)) { @@ -973,10 +1092,11 @@ func (rc *SQLiteRows) Next(dest []driver.Value) error { // large to be a reasonable timestamp in seconds. if val > 1e12 || val < -1e12 { val *= int64(time.Millisecond) // convert ms to nsec + t = time.Unix(0, val) } else { - val *= int64(time.Second) // convert sec to nsec + t = time.Unix(val, 0) } - t = time.Unix(0, val).UTC() + t = t.UTC() if rc.s.c.loc != nil { t = t.In(rc.s.c.loc) } |