aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--sqlite3.go26
-rw-r--r--sqlite3_test.go29
2 files changed, 55 insertions, 0 deletions
diff --git a/sqlite3.go b/sqlite3.go
index 2b7b8df..2ebf7e7 100644
--- a/sqlite3.go
+++ b/sqlite3.go
@@ -599,6 +599,8 @@ func errorString(err Error) string {
// "deferred", "exclusive".
// _foreign_keys=X
// Enable or disable enforcement of foreign keys. X can be 1 or 0.
+// _recursive_triggers=X
+// Enable or disable recursive triggers. X can be 1 or 0.
func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
if C.sqlite3_threadsafe() == 0 {
return nil, errors.New("sqlite library was not compiled for thread-safe operation")
@@ -608,6 +610,7 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
txlock := "BEGIN"
busyTimeout := 5000
foreignKeys := -1
+ recursiveTriggers := -1
pos := strings.IndexRune(dsn, '?')
if pos >= 1 {
params, err := url.ParseQuery(dsn[pos+1:])
@@ -662,6 +665,18 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
}
}
+ // _recursive_triggers
+ if val := params.Get("_recursive_triggers"); val != "" {
+ switch val {
+ case "1":
+ recursiveTriggers = 1
+ case "0":
+ recursiveTriggers = 0
+ default:
+ return nil, fmt.Errorf("Invalid _recursive_triggers: %v", val)
+ }
+ }
+
if !strings.HasPrefix(dsn, "file:") {
dsn = dsn[:pos]
}
@@ -708,6 +723,17 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
return nil, err
}
}
+ if recursiveTriggers == 0 {
+ if err := exec("PRAGMA recursive_triggers = OFF;"); err != nil {
+ C.sqlite3_close_v2(db)
+ return nil, err
+ }
+ } else if recursiveTriggers == 1 {
+ if err := exec("PRAGMA recursive_triggers = ON;"); err != nil {
+ C.sqlite3_close_v2(db)
+ return nil, err
+ }
+ }
conn := &SQLiteConn{db: db, loc: loc, txlock: txlock}
diff --git a/sqlite3_test.go b/sqlite3_test.go
index f11c349..a00e622 100644
--- a/sqlite3_test.go
+++ b/sqlite3_test.go
@@ -136,6 +136,35 @@ func TestForeignKeys(t *testing.T) {
}
}
+func TestRecursiveTriggers(t *testing.T) {
+ cases := map[string]bool{
+ "?_recursive_triggers=1": true,
+ "?_recursive_triggers=0": false,
+ }
+ for option, want := range cases {
+ fname := TempFilename(t)
+ uri := "file:" + fname + option
+ db, err := sql.Open("sqlite3", uri)
+ if err != nil {
+ os.Remove(fname)
+ t.Errorf("sql.Open(\"sqlite3\", %q): %v", uri, err)
+ continue
+ }
+ var enabled bool
+ err = db.QueryRow("PRAGMA recursive_triggers;").Scan(&enabled)
+ db.Close()
+ os.Remove(fname)
+ if err != nil {
+ t.Errorf("query recursive_triggers for %s: %v", uri, err)
+ continue
+ }
+ if enabled != want {
+ t.Errorf("\"PRAGMA recursive_triggers;\" for %q = %t; want %t", uri, enabled, want)
+ continue
+ }
+ }
+}
+
func TestClose(t *testing.T) {
tempFilename := TempFilename(t)
defer os.Remove(tempFilename)