aboutsummaryrefslogtreecommitdiff
path: root/sqlite3.go
diff options
context:
space:
mode:
Diffstat (limited to 'sqlite3.go')
-rw-r--r--sqlite3.go42
1 files changed, 30 insertions, 12 deletions
diff --git a/sqlite3.go b/sqlite3.go
index cc42c13..e7417ec 100644
--- a/sqlite3.go
+++ b/sqlite3.go
@@ -72,13 +72,13 @@ var SQLiteTimestampFormats = []string{
}
func init() {
- sql.Register("sqlite3", &SQLiteDriver{false, nil})
+ sql.Register("sqlite3", &SQLiteDriver{})
}
// Driver struct.
type SQLiteDriver struct {
- EnableLoadExtension bool
- ConnectHook func(*SQLiteConn) error
+ Extensions []string
+ ConnectHook func(*SQLiteConn) error
}
// Conn struct.
@@ -182,17 +182,35 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
return nil, errors.New(C.GoString(C.sqlite3_errmsg(db)))
}
- enableLoadExtension := 0
- if d.EnableLoadExtension {
- enableLoadExtension = 1
- }
- rv = C.sqlite3_enable_load_extension(db, C.int(enableLoadExtension))
- if rv != C.SQLITE_OK {
- return nil, errors.New(C.GoString(C.sqlite3_errmsg(db)))
- }
-
conn := &SQLiteConn{db}
+ if len(d.Extensions) > 0 {
+ rv = C.sqlite3_enable_load_extension(db, 1)
+ if rv != C.SQLITE_OK {
+ return nil, errors.New(C.GoString(C.sqlite3_errmsg(db)))
+ }
+
+ stmt, err := conn.Prepare("SELECT load_extension(?);")
+ if err != nil {
+ return nil, err
+ }
+
+ for _, extension := range d.Extensions {
+ if _, err = stmt.Exec([]driver.Value{extension}); err != nil {
+ return nil, err
+ }
+ }
+
+ if err = stmt.Close(); err != nil {
+ return nil, err
+ }
+
+ rv = C.sqlite3_enable_load_extension(db, 0)
+ if rv != C.SQLITE_OK {
+ return nil, errors.New(C.GoString(C.sqlite3_errmsg(db)))
+ }
+ }
+
if d.ConnectHook != nil {
if err := d.ConnectHook(conn); err != nil {
return nil, err