diff options
Diffstat (limited to '')
-rw-r--r-- | sqlite3.go | 63 | ||||
-rw-r--r-- | sqlite3_test.go | 42 |
2 files changed, 93 insertions, 12 deletions
@@ -130,6 +130,8 @@ type SQLiteTx struct { type SQLiteStmt struct { c *SQLiteConn s *C.sqlite3_stmt + nv int + nn []string t string closed bool cls bool @@ -267,14 +269,17 @@ func errorString(err Error) string { // :memory: // file::memory: // go-sqlite handle especially query parameters. -// loc=XXX +// _loc=XXX // Specify location of time format. It's possible to specify "auto". +// _busy_timeout=XXX +// Specify value for sqlite3_busy_timeout. func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) { if C.sqlite3_threadsafe() == 0 { return nil, errors.New("sqlite library was not compiled for thread-safe operation") } var loc *time.Location + busy_timeout := 5000 pos := strings.IndexRune(dsn, '?') if pos >= 1 { params, err := url.ParseQuery(dsn[pos+1:]) @@ -282,18 +287,27 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) { return nil, err } - // loc - if val := params.Get("loc"); val != "" { + // _loc + if val := params.Get("_loc"); val != "" { if val == "auto" { loc = time.Local } else { loc, err = time.LoadLocation(val) if err != nil { - return nil, fmt.Errorf("Invalid loc: %v: %v", val, err) + return nil, fmt.Errorf("Invalid _loc: %v: %v", val, err) } } } + // _busy_timeout + if val := params.Get("_busy_timeout"); val != "" { + iv, err := strconv.ParseInt(val, 10, 64) + if err != nil { + return nil, fmt.Errorf("Invalid _busy_timeout: %v: %v", val, err) + } + busy_timeout = int(iv) + } + if !strings.HasPrefix(dsn, "file:") { dsn = dsn[:pos] } @@ -314,7 +328,7 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) { return nil, errors.New("sqlite succeeded without returning a database") } - rv = C.sqlite3_busy_timeout(db, 5000) + rv = C.sqlite3_busy_timeout(db, C.int(busy_timeout)) if rv != C.SQLITE_OK { return nil, Error{Code: ErrNo(rv)} } @@ -376,7 +390,15 @@ func (c *SQLiteConn) Prepare(query string) (driver.Stmt, error) { if tail != nil && *tail != '\000' { t = strings.TrimSpace(C.GoString(tail)) } - ss := &SQLiteStmt{c: c, s: s, t: t} + nv := int(C.sqlite3_bind_parameter_count(s)) + var nn []string + for i := 0; i < nv; i++ { + pn := C.GoString(C.sqlite3_bind_parameter_name(s, C.int(i+1))) + if len(pn) > 1 && pn[0] == '$' && 48 <= pn[1] && pn[1] <= 57 { + nn = append(nn, C.GoString(C.sqlite3_bind_parameter_name(s, C.int(i+1)))) + } + } + ss := &SQLiteStmt{c: c, s: s, nv: nv, nn: nn, t: t} runtime.SetFinalizer(ss, (*SQLiteStmt).Close) return ss, nil } @@ -400,7 +422,12 @@ func (s *SQLiteStmt) Close() error { // Return a number of parameters. func (s *SQLiteStmt) NumInput() int { - return int(C.sqlite3_bind_parameter_count(s.s)) + return s.nv +} + +type bindArg struct { + n int + v driver.Value } func (s *SQLiteStmt) bind(args []driver.Value) error { @@ -409,8 +436,24 @@ func (s *SQLiteStmt) bind(args []driver.Value) error { return s.c.lastError() } - for i, v := range args { - n := C.int(i + 1) + var vargs []bindArg + narg := len(args) + vargs = make([]bindArg, narg) + if len(s.nn) > 0 { + for i, v := range s.nn { + if pi, err := strconv.Atoi(v[1:]); err == nil { + vargs[i] = bindArg{pi, args[i]} + } + } + } else { + for i, v := range args { + vargs[i] = bindArg{i + 1, v} + } + } + + for _, varg := range vargs { + n := C.int(varg.n) + v := varg.v switch v := v.(type) { case nil: rv = C.sqlite3_bind_null(s.s, n) @@ -471,6 +514,7 @@ func (r *SQLiteResult) RowsAffected() (int64, error) { func (s *SQLiteStmt) Exec(args []driver.Value) (driver.Result, error) { if err := s.bind(args); err != nil { C.sqlite3_reset(s.s) + C.sqlite3_clear_bindings(s.s) return nil, err } var rowid, changes C.long @@ -478,6 +522,7 @@ func (s *SQLiteStmt) Exec(args []driver.Value) (driver.Result, error) { if rv != C.SQLITE_ROW && rv != C.SQLITE_OK && rv != C.SQLITE_DONE { err := s.c.lastError() C.sqlite3_reset(s.s) + C.sqlite3_clear_bindings(s.s) return nil, err } return &SQLiteResult{int64(rowid), int64(changes)}, nil diff --git a/sqlite3_test.go b/sqlite3_test.go index 6570b52..aa86011 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -640,7 +640,7 @@ func TestTimezoneConversion(t *testing.T) { zones := []string{"UTC", "US/Central", "US/Pacific", "Local"} for _, tz := range zones { tempFilename := TempFilename() - db, err := sql.Open("sqlite3", tempFilename+"?loc="+url.QueryEscape(tz)) + db, err := sql.Open("sqlite3", tempFilename+"?_loc="+url.QueryEscape(tz)) if err != nil { t.Fatal("Failed to open database:", err) } @@ -845,7 +845,7 @@ func TestStress(t *testing.T) { func TestDateTimeLocal(t *testing.T) { zone := "Asia/Tokyo" tempFilename := TempFilename() - db, err := sql.Open("sqlite3", tempFilename+"?loc="+zone) + db, err := sql.Open("sqlite3", tempFilename+"?_loc="+zone) if err != nil { t.Fatal("Failed to open database:", err) } @@ -888,7 +888,7 @@ func TestDateTimeLocal(t *testing.T) { db.Exec("INSERT INTO foo VALUES(?);", dt) db.Close() - db, err = sql.Open("sqlite3", tempFilename+"?loc="+zone) + db, err = sql.Open("sqlite3", tempFilename+"?_loc="+zone) if err != nil { t.Fatal("Failed to open database:", err) } @@ -909,3 +909,39 @@ func TestVersion(t *testing.T) { t.Errorf("Version failed %q, %d, %q\n", s, n, id) } } + +func TestNumberNamedParams(t *testing.T) { + tempFilename := TempFilename() + db, err := sql.Open("sqlite3", tempFilename) + if err != nil { + t.Fatal("Failed to open database:", err) + } + defer os.Remove(tempFilename) + defer db.Close() + + _, err = db.Exec(` + create table foo (id integer, name text, extra text); + `) + if err != nil { + t.Error("Failed to call db.Query:", err) + } + + _, err = db.Exec(`insert into foo(id, name, extra) values($1, $2, $2)`, 1, "foo") + if err != nil { + t.Error("Failed to call db.Exec:", err) + } + + row := db.QueryRow(`select id, extra from foo where id = $1 and extra = $2`, 1, "foo") + if row == nil { + t.Error("Failed to call db.QueryRow") + } + var id int + var extra string + err = row.Scan(&id, &extra) + if err != nil { + t.Error("Failed to db.Scan:", err) + } + if id != 1 || extra != "foo" { + t.Error("Failed to db.QueryRow: not matched results") + } +} |