aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJason Abbott <dev@trailimage.com>2017-07-03 12:51:48 -0600
committerJason Abbott <dev@trailimage.com>2017-07-03 12:51:48 -0600
commit59bd281a89883d39ef219699e4a46eab87b3cff9 (patch)
treecc83849ec381c059e37faf0db7c7bb698ffa3fdd
parentMerge pull request #431 from deepilla/issue-430 (diff)
downloadgolite-59bd281a89883d39ef219699e4a46eab87b3cff9.tar.gz
golite-59bd281a89883d39ef219699e4a46eab87b3cff9.tar.xz
Incorporate original PR 271 from https://github.com/brokensandals
-rw-r--r--_example/hook/hook.go6
-rw-r--r--callback.go18
-rw-r--r--sqlite3.go54
-rw-r--r--sqlite3_test.go61
4 files changed, 139 insertions, 0 deletions
diff --git a/_example/hook/hook.go b/_example/hook/hook.go
index 17bddeb..6023181 100644
--- a/_example/hook/hook.go
+++ b/_example/hook/hook.go
@@ -14,6 +14,12 @@ func main() {
&sqlite3.SQLiteDriver{
ConnectHook: func(conn *sqlite3.SQLiteConn) error {
sqlite3conn = append(sqlite3conn, conn)
+ conn.RegisterUpdateHook(func(op int, db string, table string, rowid int64) {
+ switch op {
+ case sqlite3.SQLITE_INSERT:
+ log.Println("Notified of insert on db", db, "table", table, "rowid", rowid)
+ }
+ })
return nil
},
})
diff --git a/callback.go b/callback.go
index 48fc63a..6a55964 100644
--- a/callback.go
+++ b/callback.go
@@ -53,6 +53,24 @@ func doneTrampoline(ctx *C.sqlite3_context) {
ai.Done(ctx)
}
+//export commitHookTrampoline
+func commitHookTrampoline(handle uintptr) int {
+ callback := lookupHandle(handle).(func() int)
+ return callback()
+}
+
+//export rollbackHookTrampoline
+func rollbackHookTrampoline(handle uintptr) {
+ callback := lookupHandle(handle).(func())
+ callback()
+}
+
+//export updateHookTrampoline
+func updateHookTrampoline(handle uintptr, op int, db *C.char, table *C.char, rowid int64) {
+ callback := lookupHandle(handle).(func(int, string, string, int64))
+ callback(op, C.GoString(db), C.GoString(table), rowid)
+}
+
// Use handles to avoid passing Go pointers to C.
type handleVal struct {
diff --git a/sqlite3.go b/sqlite3.go
index d3a6407..0217cce 100644
--- a/sqlite3.go
+++ b/sqlite3.go
@@ -100,6 +100,9 @@ int _sqlite3_create_function(
}
void callbackTrampoline(sqlite3_context*, int, sqlite3_value**);
+int commitHookTrampoline(void*);
+void rollbackHookTrampoline(void*);
+void updateHookTrampoline(void*, int, char*, char*, sqlite3_int64);
*/
import "C"
import (
@@ -150,6 +153,12 @@ func Version() (libVersion string, libVersionNumber int, sourceID string) {
return libVersion, libVersionNumber, sourceID
}
+const (
+ SQLITE_DELETE = C.SQLITE_DELETE
+ SQLITE_INSERT = C.SQLITE_INSERT
+ SQLITE_UPDATE = C.SQLITE_UPDATE
+)
+
// SQLiteDriver implement sql.Driver.
type SQLiteDriver struct {
Extensions []string
@@ -315,6 +324,51 @@ func (tx *SQLiteTx) Rollback() error {
return err
}
+// RegisterCommitHook sets the commit hook for a connection.
+//
+// If the callback returns non-zero the transaction will become a rollback.
+//
+// If there is an existing commit hook for this connection, it will be
+// removed. If callback is nil the existing hook (if any) will be removed
+// without creating a new one.
+func (c *SQLiteConn) RegisterCommitHook(callback func() int) {
+ if callback == nil {
+ C.sqlite3_commit_hook(c.db, nil, nil)
+ } else {
+ C.sqlite3_commit_hook(c.db, (*[0]byte)(unsafe.Pointer(C.commitHookTrampoline)), unsafe.Pointer(newHandle(c, callback)))
+ }
+}
+
+// RegisterRollbackHook sets the rollback hook for a connection.
+//
+// If there is an existing rollback hook for this connection, it will be
+// removed. If callback is nil the existing hook (if any) will be removed
+// without creating a new one.
+func (c *SQLiteConn) RegisterRollbackHook(callback func()) {
+ if callback == nil {
+ C.sqlite3_rollback_hook(c.db, nil, nil)
+ } else {
+ C.sqlite3_rollback_hook(c.db, (*[0]byte)(unsafe.Pointer(C.rollbackHookTrampoline)), unsafe.Pointer(newHandle(c, callback)))
+ }
+}
+
+// RegisterUpdateHook sets the update hook for a connection.
+//
+// The parameters to the callback are the operation (one of the constants
+// SQLITE_INSERT, SQLITE_DELETE, or SQLITE_UPDATE), the database name, the
+// table name, and the rowid.
+//
+// If there is an existing update hook for this connection, it will be
+// removed. If callback is nil the existing hook (if any) will be removed
+// without creating a new one.
+func (c *SQLiteConn) RegisterUpdateHook(callback func(int, string, string, int64)) {
+ if callback == nil {
+ C.sqlite3_update_hook(c.db, nil, nil)
+ } else {
+ C.sqlite3_update_hook(c.db, (*[0]byte)(unsafe.Pointer(C.updateHookTrampoline)), unsafe.Pointer(newHandle(c, callback)))
+ }
+}
+
// RegisterFunc makes a Go function available as a SQLite function.
//
// The Go function can have arguments of the following types: any
diff --git a/sqlite3_test.go b/sqlite3_test.go
index e563479..f11c349 100644
--- a/sqlite3_test.go
+++ b/sqlite3_test.go
@@ -1265,6 +1265,67 @@ func TestPinger(t *testing.T) {
}
}
+func TestUpdateAndTransactionHooks(t *testing.T) {
+ var events []string
+ var commitHookReturn = 0
+
+ sql.Register("sqlite3_UpdateHook", &SQLiteDriver{
+ ConnectHook: func(conn *SQLiteConn) error {
+ conn.RegisterCommitHook(func() int {
+ events = append(events, "commit")
+ return commitHookReturn
+ })
+ conn.RegisterRollbackHook(func() {
+ events = append(events, "rollback")
+ })
+ conn.RegisterUpdateHook(func(op int, db string, table string, rowid int64) {
+ events = append(events, fmt.Sprintf("update(op=%v db=%v table=%v rowid=%v)", op, db, table, rowid))
+ })
+ return nil
+ },
+ })
+ db, err := sql.Open("sqlite3_UpdateHook", ":memory:")
+ if err != nil {
+ t.Fatal("Failed to open database:", err)
+ }
+ defer db.Close()
+
+ statements := []string{
+ "create table foo (id integer primary key)",
+ "insert into foo values (9)",
+ "update foo set id = 99 where id = 9",
+ "delete from foo where id = 99",
+ }
+ for _, statement := range statements {
+ _, err = db.Exec(statement)
+ if err != nil {
+ t.Fatalf("Unable to prepare test data [%v]: %v", statement, err)
+ }
+ }
+
+ commitHookReturn = 1
+ _, err = db.Exec("insert into foo values (5)")
+ if err == nil {
+ t.Error("Commit hook failed to rollback transaction")
+ }
+
+ var expected = []string{
+ "commit",
+ fmt.Sprintf("update(op=%v db=main table=foo rowid=9)", SQLITE_INSERT),
+ "commit",
+ fmt.Sprintf("update(op=%v db=main table=foo rowid=99)", SQLITE_UPDATE),
+ "commit",
+ fmt.Sprintf("update(op=%v db=main table=foo rowid=99)", SQLITE_DELETE),
+ "commit",
+ fmt.Sprintf("update(op=%v db=main table=foo rowid=5)", SQLITE_INSERT),
+ "commit",
+ "rollback",
+ }
+ if !reflect.DeepEqual(events, expected) {
+ t.Errorf("Expected notifications %v but got %v", expected, events)
+ }
+}
+
var customFunctionOnce sync.Once
func BenchmarkCustomFunctions(b *testing.B) {