diff options
Diffstat (limited to 'sqlite3.go')
-rw-r--r-- | sqlite3.go | 42 |
1 files changed, 30 insertions, 12 deletions
@@ -72,13 +72,13 @@ var SQLiteTimestampFormats = []string{ } func init() { - sql.Register("sqlite3", &SQLiteDriver{false, nil}) + sql.Register("sqlite3", &SQLiteDriver{}) } // Driver struct. type SQLiteDriver struct { - EnableLoadExtension bool - ConnectHook func(*SQLiteConn) error + Extensions []string + ConnectHook func(*SQLiteConn) error } // Conn struct. @@ -182,17 +182,35 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) { return nil, errors.New(C.GoString(C.sqlite3_errmsg(db))) } - enableLoadExtension := 0 - if d.EnableLoadExtension { - enableLoadExtension = 1 - } - rv = C.sqlite3_enable_load_extension(db, C.int(enableLoadExtension)) - if rv != C.SQLITE_OK { - return nil, errors.New(C.GoString(C.sqlite3_errmsg(db))) - } - conn := &SQLiteConn{db} + if len(d.Extensions) > 0 { + rv = C.sqlite3_enable_load_extension(db, 1) + if rv != C.SQLITE_OK { + return nil, errors.New(C.GoString(C.sqlite3_errmsg(db))) + } + + stmt, err := conn.Prepare("SELECT load_extension(?);") + if err != nil { + return nil, err + } + + for _, extension := range d.Extensions { + if _, err = stmt.Exec([]driver.Value{extension}); err != nil { + return nil, err + } + } + + if err = stmt.Close(); err != nil { + return nil, err + } + + rv = C.sqlite3_enable_load_extension(db, 0) + if rv != C.SQLITE_OK { + return nil, errors.New(C.GoString(C.sqlite3_errmsg(db))) + } + } + if d.ConnectHook != nil { if err := d.ConnectHook(conn); err != nil { return nil, err |