aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--example/extension/extension.go29
-rw-r--r--sqlite3.go42
2 files changed, 32 insertions, 39 deletions
diff --git a/example/extension/extension.go b/example/extension/extension.go
index 49eacf1..f58ea3a 100644
--- a/example/extension/extension.go
+++ b/example/extension/extension.go
@@ -8,29 +8,10 @@ import (
)
func main() {
- const (
- use_hook = true
- load_query = "SELECT load_extension('sqlite3_mod_regexp.dll')"
- )
-
sql.Register("sqlite3_with_extensions",
&sqlite3.SQLiteDriver{
- EnableLoadExtension: true,
- ConnectHook: func(c *sqlite3.SQLiteConn) error {
- if use_hook {
- stmt, err := c.Prepare(load_query)
- if err != nil {
- return err
- }
-
- _, err = stmt.Exec(nil)
- if err != nil {
- return err
- }
-
- return stmt.Close()
- }
- return nil
+ Extensions: []string{
+ "sqlite3_mod_regexp.dll",
},
})
@@ -40,12 +21,6 @@ func main() {
}
defer db.Close()
- if !use_hook {
- if _, err = db.Exec(load_query); err != nil {
- log.Fatal(err)
- }
- }
-
// Force db to make a new connection in pool
// by putting the original in a transaction
tx, err := db.Begin()
diff --git a/sqlite3.go b/sqlite3.go
index cc42c13..e7417ec 100644
--- a/sqlite3.go
+++ b/sqlite3.go
@@ -72,13 +72,13 @@ var SQLiteTimestampFormats = []string{
}
func init() {
- sql.Register("sqlite3", &SQLiteDriver{false, nil})
+ sql.Register("sqlite3", &SQLiteDriver{})
}
// Driver struct.
type SQLiteDriver struct {
- EnableLoadExtension bool
- ConnectHook func(*SQLiteConn) error
+ Extensions []string
+ ConnectHook func(*SQLiteConn) error
}
// Conn struct.
@@ -182,17 +182,35 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
return nil, errors.New(C.GoString(C.sqlite3_errmsg(db)))
}
- enableLoadExtension := 0
- if d.EnableLoadExtension {
- enableLoadExtension = 1
- }
- rv = C.sqlite3_enable_load_extension(db, C.int(enableLoadExtension))
- if rv != C.SQLITE_OK {
- return nil, errors.New(C.GoString(C.sqlite3_errmsg(db)))
- }
-
conn := &SQLiteConn{db}
+ if len(d.Extensions) > 0 {
+ rv = C.sqlite3_enable_load_extension(db, 1)
+ if rv != C.SQLITE_OK {
+ return nil, errors.New(C.GoString(C.sqlite3_errmsg(db)))
+ }
+
+ stmt, err := conn.Prepare("SELECT load_extension(?);")
+ if err != nil {
+ return nil, err
+ }
+
+ for _, extension := range d.Extensions {
+ if _, err = stmt.Exec([]driver.Value{extension}); err != nil {
+ return nil, err
+ }
+ }
+
+ if err = stmt.Close(); err != nil {
+ return nil, err
+ }
+
+ rv = C.sqlite3_enable_load_extension(db, 0)
+ if rv != C.SQLITE_OK {
+ return nil, errors.New(C.GoString(C.sqlite3_errmsg(db)))
+ }
+ }
+
if d.ConnectHook != nil {
if err := d.ConnectHook(conn); err != nil {
return nil, err