aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--sqlite3.go47
-rw-r--r--sqlite3_test.go29
2 files changed, 73 insertions, 3 deletions
diff --git a/sqlite3.go b/sqlite3.go
index cbde900..33b9b9c 100644
--- a/sqlite3.go
+++ b/sqlite3.go
@@ -400,14 +400,18 @@ func (c *SQLiteConn) AutoCommit() bool {
}
func (c *SQLiteConn) lastError() error {
- rv := C.sqlite3_errcode(c.db)
+ return lastError(c.db)
+}
+
+func lastError(db *C.sqlite3) error {
+ rv := C.sqlite3_errcode(db)
if rv == C.SQLITE_OK {
return nil
}
return Error{
Code: ErrNo(rv),
- ExtendedCode: ErrNoExtended(C.sqlite3_extended_errcode(c.db)),
- err: C.GoString(C.sqlite3_errmsg(c.db)),
+ ExtendedCode: ErrNoExtended(C.sqlite3_extended_errcode(db)),
+ err: C.GoString(C.sqlite3_errmsg(db)),
}
}
@@ -537,6 +541,8 @@ func errorString(err Error) string {
// _txlock=XXX
// Specify locking behavior for transactions. XXX can be "immediate",
// "deferred", "exclusive".
+// _foreign_keys=X
+// Enable or disable enforcement of foreign keys. X can be 1 or 0.
func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
if C.sqlite3_threadsafe() == 0 {
return nil, errors.New("sqlite library was not compiled for thread-safe operation")
@@ -545,6 +551,7 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
var loc *time.Location
txlock := "BEGIN"
busyTimeout := 5000
+ foreignKeys := -1
pos := strings.IndexRune(dsn, '?')
if pos >= 1 {
params, err := url.ParseQuery(dsn[pos+1:])
@@ -587,6 +594,18 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
}
}
+ // _foreign_keys
+ if val := params.Get("_foreign_keys"); val != "" {
+ switch val {
+ case "1":
+ foreignKeys = 1
+ case "0":
+ foreignKeys = 0
+ default:
+ return nil, fmt.Errorf("Invalid _foreign_keys: %v", val)
+ }
+ }
+
if !strings.HasPrefix(dsn, "file:") {
dsn = dsn[:pos]
}
@@ -609,9 +628,31 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
rv = C.sqlite3_busy_timeout(db, C.int(busyTimeout))
if rv != C.SQLITE_OK {
+ C.sqlite3_close_v2(db)
return nil, Error{Code: ErrNo(rv)}
}
+ exec := func(s string) error {
+ cs := C.CString(s)
+ rv := C.sqlite3_exec(db, cs, nil, nil, nil)
+ C.free(unsafe.Pointer(cs))
+ if rv != C.SQLITE_OK {
+ return lastError(db)
+ }
+ return nil
+ }
+ if foreignKeys == 0 {
+ if err := exec("PRAGMA foreign_keys = OFF;"); err != nil {
+ C.sqlite3_close_v2(db)
+ return nil, err
+ }
+ } else if foreignKeys == 1 {
+ if err := exec("PRAGMA foreign_keys = ON;"); err != nil {
+ C.sqlite3_close_v2(db)
+ return nil, err
+ }
+ }
+
conn := &SQLiteConn{db: db, loc: loc, txlock: txlock}
if len(d.Extensions) > 0 {
diff --git a/sqlite3_test.go b/sqlite3_test.go
index e844f82..03b678d 100644
--- a/sqlite3_test.go
+++ b/sqlite3_test.go
@@ -107,6 +107,35 @@ func TestReadonly(t *testing.T) {
}
}
+func TestForeignKeys(t *testing.T) {
+ cases := map[string]bool{
+ "?_foreign_keys=1": true,
+ "?_foreign_keys=0": false,
+ }
+ for option, want := range cases {
+ fname := TempFilename(t)
+ uri := "file:" + fname + option
+ db, err := sql.Open("sqlite3", uri)
+ if err != nil {
+ os.Remove(fname)
+ t.Errorf("sql.Open(\"sqlite3\", %q): %v", uri, err)
+ continue
+ }
+ var enabled bool
+ err = db.QueryRow("PRAGMA foreign_keys;").Scan(&enabled)
+ db.Close()
+ os.Remove(fname)
+ if err != nil {
+ t.Errorf("query foreign_keys for %s: %v", uri, err)
+ continue
+ }
+ if enabled != want {
+ t.Errorf("\"PRAGMA foreign_keys;\" for %q = %t; want %t", uri, enabled, want)
+ continue
+ }
+ }
+}
+
func TestClose(t *testing.T) {
tempFilename := TempFilename(t)
defer os.Remove(tempFilename)