aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--example/extension/extension.go4
-rw-r--r--sqlite3.go16
2 files changed, 14 insertions, 6 deletions
diff --git a/example/extension/extension.go b/example/extension/extension.go
index 3ec145a..95f2f70 100644
--- a/example/extension/extension.go
+++ b/example/extension/extension.go
@@ -3,11 +3,13 @@ package main
import (
"database/sql"
"fmt"
- _ "github.com/mattn/go-sqlite3"
+ "github.com/mattn/go-sqlite3"
"log"
)
func main() {
+ sql.Register("sqlite3_with_extensions", &sqlite3.SQLiteDriver{true, nil})
+
db, err := sql.Open("sqlite3_with_extensions", ":memory:")
if err != nil {
log.Fatal(err)
diff --git a/sqlite3.go b/sqlite3.go
index 0aba2db..43e255b 100644
--- a/sqlite3.go
+++ b/sqlite3.go
@@ -72,13 +72,13 @@ var SQLiteTimestampFormats = []string{
}
func init() {
- sql.Register("sqlite3", &SQLiteDriver{false})
- sql.Register("sqlite3_with_extensions", &SQLiteDriver{true})
+ sql.Register("sqlite3", &SQLiteDriver{false, nil})
}
// Driver struct.
type SQLiteDriver struct {
- enableLoadExtentions bool
+ EnableLoadExtentions bool
+ ConnectHook func(*SQLiteConn)
}
// Conn struct.
@@ -179,7 +179,7 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
}
enableLoadExtentions := 0
- if d.enableLoadExtentions {
+ if d.EnableLoadExtentions {
enableLoadExtentions = 1
}
rv = C.sqlite3_enable_load_extension(db, C.int(enableLoadExtentions))
@@ -187,7 +187,13 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
return nil, errors.New(C.GoString(C.sqlite3_errmsg(db)))
}
- return &SQLiteConn{db}, nil
+ conn := &SQLiteConn{db}
+
+ if d.ConnectHook != nil {
+ d.ConnectHook(conn)
+ }
+
+ return conn, nil
}
// Close the connection.