aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--sqlite3.go134
-rw-r--r--sqlite3_test.go39
2 files changed, 118 insertions, 55 deletions
diff --git a/sqlite3.go b/sqlite3.go
index 59e3670..86c0f64 100644
--- a/sqlite3.go
+++ b/sqlite3.go
@@ -802,20 +802,29 @@ func (c *SQLiteConn) exec(ctx context.Context, query string, args []namedValue)
}
var res driver.Result
if s.(*SQLiteStmt).s != nil {
+ stmtArgs := make([]namedValue, 0, len(args))
na := s.NumInput()
- if len(args) < na {
+ if len(args) - start < na {
s.Close()
return nil, fmt.Errorf("not enough args to execute query: want %d got %d", na, len(args))
}
- for i := 0; i < na; i++ {
- args[i].Ordinal -= start
+ // consume the number of arguments used in the current
+ // statement and append all named arguments not
+ // contained therein
+ stmtArgs = append(stmtArgs, args[start:start+na]...)
+ for i := range args {
+ if (i < start || i >= na) && args[i].Name != "" {
+ stmtArgs = append(stmtArgs, args[i])
+ }
+ }
+ for i := range stmtArgs {
+ stmtArgs[i].Ordinal = i + 1
}
- res, err = s.(*SQLiteStmt).exec(ctx, args[:na])
+ res, err = s.(*SQLiteStmt).exec(ctx, stmtArgs)
if err != nil && err != driver.ErrSkip {
s.Close()
return nil, err
}
- args = args[na:]
start += na
}
tail := s.(*SQLiteStmt).t
@@ -848,24 +857,33 @@ func (c *SQLiteConn) Query(query string, args []driver.Value) (driver.Rows, erro
func (c *SQLiteConn) query(ctx context.Context, query string, args []namedValue) (driver.Rows, error) {
start := 0
for {
+ stmtArgs := make([]namedValue, 0, len(args))
s, err := c.prepare(ctx, query)
if err != nil {
return nil, err
}
s.(*SQLiteStmt).cls = true
na := s.NumInput()
- if len(args) < na {
- return nil, fmt.Errorf("not enough args to execute query: want %d got %d", na, len(args))
+ if len(args) - start < na {
+ return nil, fmt.Errorf("not enough args to execute query: want %d got %d", na, len(args) - start)
+ }
+ // consume the number of arguments used in the current
+ // statement and append all named arguments not contained
+ // therein
+ stmtArgs = append(stmtArgs, args[start:start+na]...)
+ for i := range args {
+ if (i < start || i >= na) && args[i].Name != "" {
+ stmtArgs = append(stmtArgs, args[i])
+ }
}
- for i := 0; i < na; i++ {
- args[i].Ordinal -= start
+ for i := range stmtArgs {
+ stmtArgs[i].Ordinal = i + 1
}
- rows, err := s.(*SQLiteStmt).query(ctx, args[:na])
+ rows, err := s.(*SQLiteStmt).query(ctx, stmtArgs)
if err != nil && err != driver.ErrSkip {
s.Close()
return rows, err
}
- args = args[na:]
start += na
tail := s.(*SQLiteStmt).t
if tail == "" {
@@ -1778,11 +1796,6 @@ func (s *SQLiteStmt) NumInput() int {
return int(C.sqlite3_bind_parameter_count(s.s))
}
-type bindArg struct {
- n int
- v driver.Value
-}
-
var placeHolder = []byte{0}
func (s *SQLiteStmt) bind(args []namedValue) error {
@@ -1791,52 +1804,63 @@ func (s *SQLiteStmt) bind(args []namedValue) error {
return s.c.lastError()
}
+ bindIndices := make([][3]int, len(args))
+ prefixes := []string{":", "@", "$"}
for i, v := range args {
+ bindIndices[i][0] = args[i].Ordinal
if v.Name != "" {
- cname := C.CString(":" + v.Name)
- args[i].Ordinal = int(C.sqlite3_bind_parameter_index(s.s, cname))
- C.free(unsafe.Pointer(cname))
- }
- }
-
- for _, arg := range args {
- n := C.int(arg.Ordinal)
- switch v := arg.Value.(type) {
- case nil:
- rv = C.sqlite3_bind_null(s.s, n)
- case string:
- if len(v) == 0 {
- rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&placeHolder[0])), C.int(0))
- } else {
- b := []byte(v)
- rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b)))
+ for j := range prefixes {
+ cname := C.CString(prefixes[j] + v.Name)
+ bindIndices[i][j] = int(C.sqlite3_bind_parameter_index(s.s, cname))
+ C.free(unsafe.Pointer(cname))
}
- case int64:
- rv = C.sqlite3_bind_int64(s.s, n, C.sqlite3_int64(v))
- case bool:
- if v {
- rv = C.sqlite3_bind_int(s.s, n, 1)
- } else {
- rv = C.sqlite3_bind_int(s.s, n, 0)
+ args[i].Ordinal = bindIndices[i][0]
+ }
+ }
+
+ for i, arg := range args {
+ for j := range bindIndices[i] {
+ if bindIndices[i][j] == 0 {
+ continue
}
- case float64:
- rv = C.sqlite3_bind_double(s.s, n, C.double(v))
- case []byte:
- if v == nil {
+ n := C.int(bindIndices[i][j])
+ switch v := arg.Value.(type) {
+ case nil:
rv = C.sqlite3_bind_null(s.s, n)
- } else {
- ln := len(v)
- if ln == 0 {
- v = placeHolder
+ case string:
+ if len(v) == 0 {
+ rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&placeHolder[0])), C.int(0))
+ } else {
+ b := []byte(v)
+ rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b)))
+ }
+ case int64:
+ rv = C.sqlite3_bind_int64(s.s, n, C.sqlite3_int64(v))
+ case bool:
+ if v {
+ rv = C.sqlite3_bind_int(s.s, n, 1)
+ } else {
+ rv = C.sqlite3_bind_int(s.s, n, 0)
+ }
+ case float64:
+ rv = C.sqlite3_bind_double(s.s, n, C.double(v))
+ case []byte:
+ if v == nil {
+ rv = C.sqlite3_bind_null(s.s, n)
+ } else {
+ ln := len(v)
+ if ln == 0 {
+ v = placeHolder
+ }
+ rv = C._sqlite3_bind_blob(s.s, n, unsafe.Pointer(&v[0]), C.int(ln))
}
- rv = C._sqlite3_bind_blob(s.s, n, unsafe.Pointer(&v[0]), C.int(ln))
+ case time.Time:
+ b := []byte(v.Format(SQLiteTimestampFormats[0]))
+ rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b)))
+ }
+ if rv != C.SQLITE_OK {
+ return s.c.lastError()
}
- case time.Time:
- b := []byte(v.Format(SQLiteTimestampFormats[0]))
- rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b)))
- }
- if rv != C.SQLITE_OK {
- return s.c.lastError()
}
}
return nil
diff --git a/sqlite3_test.go b/sqlite3_test.go
index 4b8fe01..d5b0cea 100644
--- a/sqlite3_test.go
+++ b/sqlite3_test.go
@@ -1778,6 +1778,45 @@ func TestInsertNilByteSlice(t *testing.T) {
}
}
+func TestNamedParam(t *testing.T) {
+ tempFilename := TempFilename(t)
+ defer os.Remove(tempFilename)
+ db, err := sql.Open("sqlite3", tempFilename)
+ if err != nil {
+ t.Fatal("Failed to open database:", err)
+ }
+ defer db.Close()
+
+ _, err = db.Exec("drop table foo")
+ _, err = db.Exec("create table foo (id integer, name text, amount integer)")
+ if err != nil {
+ t.Fatal("Failed to create table:", err)
+ }
+
+ _, err = db.Exec("insert into foo(id, name, amount) values(:id, @name, $amount)",
+ sql.Named("bar", 42), sql.Named("baz", "quux"),
+ sql.Named("amount", 123), sql.Named("corge", "waldo"),
+ sql.Named("id", 2), sql.Named("name", "grault"))
+ if err != nil {
+ t.Fatal("Failed to insert record with named parameters:", err)
+ }
+
+ rows, err := db.Query("select id, name, amount from foo")
+ if err != nil {
+ t.Fatal("Failed to select records:", err)
+ }
+ defer rows.Close()
+
+ rows.Next()
+
+ var id, amount int
+ var name string
+ rows.Scan(&id, &name, &amount)
+ if id != 2 || name != "grault" || amount != 123 {
+ t.Errorf("Expected %d, %q, %d for fetched result, but got %d, %q, %d:", 2, "grault", 123, id, name, amount)
+ }
+}
+
var customFunctionOnce sync.Once
func BenchmarkCustomFunctions(b *testing.B) {