aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--example/extension/extension.go11
-rw-r--r--sqlite3.go46
2 files changed, 41 insertions, 16 deletions
diff --git a/example/extension/extension.go b/example/extension/extension.go
index d4b8fdb..f58ea3a 100644
--- a/example/extension/extension.go
+++ b/example/extension/extension.go
@@ -10,8 +10,9 @@ import (
func main() {
sql.Register("sqlite3_with_extensions",
&sqlite3.SQLiteDriver{
- EnableLoadExtension: true,
- ConnectHook: nil,
+ Extensions: []string{
+ "sqlite3_mod_regexp.dll",
+ },
})
db, err := sql.Open("sqlite3_with_extensions", ":memory:")
@@ -20,11 +21,15 @@ func main() {
}
defer db.Close()
- _, err = db.Exec("select load_extension('sqlite3_mod_regexp.dll')")
+ // Force db to make a new connection in pool
+ // by putting the original in a transaction
+ tx, err := db.Begin()
if err != nil {
log.Fatal(err)
}
+ defer tx.Commit()
+ // New connection works (hopefully!)
rows, err := db.Query("select 'hello world' where 'hello world' regexp '^hello.*d$'")
if err != nil {
log.Fatal(err)
diff --git a/sqlite3.go b/sqlite3.go
index 692306d..e7417ec 100644
--- a/sqlite3.go
+++ b/sqlite3.go
@@ -72,13 +72,13 @@ var SQLiteTimestampFormats = []string{
}
func init() {
- sql.Register("sqlite3", &SQLiteDriver{false, nil})
+ sql.Register("sqlite3", &SQLiteDriver{})
}
// Driver struct.
type SQLiteDriver struct {
- EnableLoadExtension bool
- ConnectHook func(*SQLiteConn)
+ Extensions []string
+ ConnectHook func(*SQLiteConn) error
}
// Conn struct.
@@ -182,19 +182,39 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
return nil, errors.New(C.GoString(C.sqlite3_errmsg(db)))
}
- enableLoadExtension := 0
- if d.EnableLoadExtension {
- enableLoadExtension = 1
- }
- rv = C.sqlite3_enable_load_extension(db, C.int(enableLoadExtension))
- if rv != C.SQLITE_OK {
- return nil, errors.New(C.GoString(C.sqlite3_errmsg(db)))
- }
-
conn := &SQLiteConn{db}
+ if len(d.Extensions) > 0 {
+ rv = C.sqlite3_enable_load_extension(db, 1)
+ if rv != C.SQLITE_OK {
+ return nil, errors.New(C.GoString(C.sqlite3_errmsg(db)))
+ }
+
+ stmt, err := conn.Prepare("SELECT load_extension(?);")
+ if err != nil {
+ return nil, err
+ }
+
+ for _, extension := range d.Extensions {
+ if _, err = stmt.Exec([]driver.Value{extension}); err != nil {
+ return nil, err
+ }
+ }
+
+ if err = stmt.Close(); err != nil {
+ return nil, err
+ }
+
+ rv = C.sqlite3_enable_load_extension(db, 0)
+ if rv != C.SQLITE_OK {
+ return nil, errors.New(C.GoString(C.sqlite3_errmsg(db)))
+ }
+ }
+
if d.ConnectHook != nil {
- d.ConnectHook(conn)
+ if err := d.ConnectHook(conn); err != nil {
+ return nil, err
+ }
}
return conn, nil