aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--_example/hook/hook.go14
-rw-r--r--sqlite3.go19
-rw-r--r--sqlite3_load_extension.go39
-rw-r--r--sqlite3_omit_load_extension.go19
-rw-r--r--sqlite3_test.go4
5 files changed, 70 insertions, 25 deletions
diff --git a/_example/hook/hook.go b/_example/hook/hook.go
index 59f8cd4..3059f9e 100644
--- a/_example/hook/hook.go
+++ b/_example/hook/hook.go
@@ -10,12 +10,12 @@ import (
func main() {
sqlite3conn := []*sqlite3.SQLiteConn{}
sql.Register("sqlite3_with_hook_example",
- &sqlite3.SQLiteDriver{
- ConnectHook: func(conn *sqlite3.SQLiteConn) error {
- sqlite3conn = append(sqlite3conn, conn)
- return nil
- },
- })
+ &sqlite3.SQLiteDriver{
+ ConnectHook: func(conn *sqlite3.SQLiteConn) error {
+ sqlite3conn = append(sqlite3conn, conn)
+ return nil
+ },
+ })
os.Remove("./foo.db")
os.Remove("./bar.db")
@@ -54,7 +54,7 @@ func main() {
log.Fatal(err)
}
- bk.Step(-1)
+ _, err = bk.Step(-1)
if err != nil {
log.Fatal(err)
}
diff --git a/sqlite3.go b/sqlite3.go
index 8d2faca..f524d17 100644
--- a/sqlite3.go
+++ b/sqlite3.go
@@ -672,23 +672,8 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
conn := &SQLiteConn{db: db, loc: loc, txlock: txlock}
if len(d.Extensions) > 0 {
- rv = C.sqlite3_enable_load_extension(db, 1)
- if rv != C.SQLITE_OK {
- return nil, errors.New(C.GoString(C.sqlite3_errmsg(db)))
- }
-
- for _, extension := range d.Extensions {
- cext := C.CString(extension)
- defer C.free(unsafe.Pointer(cext))
- rv = C.sqlite3_load_extension(db, cext, nil, nil)
- if rv != C.SQLITE_OK {
- return nil, errors.New(C.GoString(C.sqlite3_errmsg(db)))
- }
- }
-
- rv = C.sqlite3_enable_load_extension(db, 0)
- if rv != C.SQLITE_OK {
- return nil, errors.New(C.GoString(C.sqlite3_errmsg(db)))
+ if err := conn.loadExtensions(d.Extensions); err != nil {
+ return nil, err
}
}
diff --git a/sqlite3_load_extension.go b/sqlite3_load_extension.go
new file mode 100644
index 0000000..0251016
--- /dev/null
+++ b/sqlite3_load_extension.go
@@ -0,0 +1,39 @@
+// Copyright (C) 2014 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
+//
+// Use of this source code is governed by an MIT-style
+// license that can be found in the LICENSE file.
+// +build !sqlite_omit_load_extension
+
+package sqlite3
+
+/*
+#include <sqlite3-binding.h>
+#include <stdlib.h>
+*/
+import "C"
+import (
+ "errors"
+ "unsafe"
+)
+
+func (c *SQLiteConn) loadExtensions(extensions []string) error {
+ rv := C.sqlite3_enable_load_extension(c.db, 1)
+ if rv != C.SQLITE_OK {
+ return errors.New(C.GoString(C.sqlite3_errmsg(c.db)))
+ }
+
+ for _, extension := range extensions {
+ cext := C.CString(extension)
+ defer C.free(unsafe.Pointer(cext))
+ rv = C.sqlite3_load_extension(c.db, cext, nil, nil)
+ if rv != C.SQLITE_OK {
+ return errors.New(C.GoString(C.sqlite3_errmsg(c.db)))
+ }
+ }
+
+ rv = C.sqlite3_enable_load_extension(c.db, 0)
+ if rv != C.SQLITE_OK {
+ return errors.New(C.GoString(C.sqlite3_errmsg(c.db)))
+ }
+ return nil
+}
diff --git a/sqlite3_omit_load_extension.go b/sqlite3_omit_load_extension.go
new file mode 100644
index 0000000..a80cf87
--- /dev/null
+++ b/sqlite3_omit_load_extension.go
@@ -0,0 +1,19 @@
+// Copyright (C) 2014 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
+//
+// Use of this source code is governed by an MIT-style
+// license that can be found in the LICENSE file.
+// +build sqlite_omit_load_extension
+
+package sqlite3
+
+/*
+#cgo CFLAGS: -DSQLITE_OMIT_LOAD_EXTENSION
+*/
+import "C"
+import (
+ "errors"
+)
+
+func (c *SQLiteConn) loadExtensions(extensions []string) error {
+ return errors.New("Extensions have been disabled for static builds")
+}
diff --git a/sqlite3_test.go b/sqlite3_test.go
index 74d3de1..0239c78 100644
--- a/sqlite3_test.go
+++ b/sqlite3_test.go
@@ -772,10 +772,12 @@ func TestTimezoneConversion(t *testing.T) {
}
func TestSuite(t *testing.T) {
- db, err := sql.Open("sqlite3", ":memory:")
+ tempFilename := TempFilename()
+ db, err := sql.Open("sqlite3", tempFilename+"?_busy_timeout=99999")
if err != nil {
t.Fatal(err)
}
+ defer os.Remove(tempFilename)
defer db.Close()
sqlite3_test.RunTests(t, db, sqlite3_test.SQLITE)