diff options
author | Carlos Castillo <cookieo9@gmail.com> | 2013-08-24 20:36:35 -0700 |
---|---|---|
committer | Carlos Castillo <cookieo9@gmail.com> | 2013-08-24 20:36:35 -0700 |
commit | 0dd71564e26e5fbe2177d420e2bf82a889568d64 (patch) | |
tree | 0ab26d44a06baa3b794821078f7c9ba40f28aa75 /sqlite3.go | |
parent | Added error return to ConnectHook and fixed extension example (diff) | |
download | golite-0dd71564e26e5fbe2177d420e2bf82a889568d64.tar.gz golite-0dd71564e26e5fbe2177d420e2bf82a889568d64.tar.xz |
Changed extension support to load from a string list of extensions
By loading extensions this way, it's not possible to later load
extensions using db.Exec, which improves security, and makes it much
easier to load extensions correctly. The zero value for the slice
(the empty slice) loads no extensions by default.
The extension example has been updated to use this much simpler system.
The ConnectHook field is still in SQLiteDriver in case it's needed for
other driver-wide initialization.
Updates #71 of mattn/go-sqlite3.
Diffstat (limited to 'sqlite3.go')
-rw-r--r-- | sqlite3.go | 42 |
1 files changed, 30 insertions, 12 deletions
@@ -72,13 +72,13 @@ var SQLiteTimestampFormats = []string{ } func init() { - sql.Register("sqlite3", &SQLiteDriver{false, nil}) + sql.Register("sqlite3", &SQLiteDriver{}) } // Driver struct. type SQLiteDriver struct { - EnableLoadExtension bool - ConnectHook func(*SQLiteConn) error + Extensions []string + ConnectHook func(*SQLiteConn) error } // Conn struct. @@ -182,17 +182,35 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) { return nil, errors.New(C.GoString(C.sqlite3_errmsg(db))) } - enableLoadExtension := 0 - if d.EnableLoadExtension { - enableLoadExtension = 1 - } - rv = C.sqlite3_enable_load_extension(db, C.int(enableLoadExtension)) - if rv != C.SQLITE_OK { - return nil, errors.New(C.GoString(C.sqlite3_errmsg(db))) - } - conn := &SQLiteConn{db} + if len(d.Extensions) > 0 { + rv = C.sqlite3_enable_load_extension(db, 1) + if rv != C.SQLITE_OK { + return nil, errors.New(C.GoString(C.sqlite3_errmsg(db))) + } + + stmt, err := conn.Prepare("SELECT load_extension(?);") + if err != nil { + return nil, err + } + + for _, extension := range d.Extensions { + if _, err = stmt.Exec([]driver.Value{extension}); err != nil { + return nil, err + } + } + + if err = stmt.Close(); err != nil { + return nil, err + } + + rv = C.sqlite3_enable_load_extension(db, 0) + if rv != C.SQLITE_OK { + return nil, errors.New(C.GoString(C.sqlite3_errmsg(db))) + } + } + if d.ConnectHook != nil { if err := d.ConnectHook(conn); err != nil { return nil, err |