From dcd536348e52686758ec76f68dc0157804f763f9 Mon Sep 17 00:00:00 2001 From: EuAndreh Date: Sun, 20 Oct 2024 20:30:37 -0300 Subject: Revert c3a3cf9d7aed9b3c48acbe31fd48f2c27549a570 --- Makefile | 2 +- README.md | 4 +- deps.mk | 38 +- src/acudego.go | 2616 -------------------------- src/golite.go | 2616 ++++++++++++++++++++++++++ tests/acudego.go | 3652 ------------------------------------- tests/benchmarks/exec/acudego.go | 31 - tests/benchmarks/exec/golite.go | 31 + tests/benchmarks/query/acudego.go | 42 - tests/benchmarks/query/golite.go | 42 + tests/functional/json/acudego.go | 98 - tests/functional/json/golite.go | 98 + tests/functional/limit/acudego.go | 104 -- tests/functional/limit/golite.go | 104 ++ tests/fuzz/api/acudego.go | 34 - tests/fuzz/api/golite.go | 34 + tests/golite.go | 3652 +++++++++++++++++++++++++++++++++++++ tests/main.go | 4 +- 18 files changed, 6601 insertions(+), 6601 deletions(-) delete mode 100644 src/acudego.go create mode 100644 src/golite.go delete mode 100644 tests/acudego.go delete mode 100644 tests/benchmarks/exec/acudego.go create mode 100644 tests/benchmarks/exec/golite.go delete mode 100644 tests/benchmarks/query/acudego.go create mode 100644 tests/benchmarks/query/golite.go delete mode 100644 tests/functional/json/acudego.go create mode 100644 tests/functional/json/golite.go delete mode 100644 tests/functional/limit/acudego.go create mode 100644 tests/functional/limit/golite.go delete mode 100644 tests/fuzz/api/acudego.go create mode 100644 tests/fuzz/api/golite.go create mode 100644 tests/golite.go diff --git a/Makefile b/Makefile index 4267fa3..ab3dbe0 100644 --- a/Makefile +++ b/Makefile @@ -1,7 +1,7 @@ .POSIX: DATE = 1970-01-01 VERSION = 0.1.0 -NAME = acudego +NAME = golite NAME_UC = $(NAME) LANGUAGES = en ## Installation prefix. Defaults to "/usr". diff --git a/README.md b/README.md index 7f63746..5209ee7 100644 --- a/README.md +++ b/README.md @@ -506,7 +506,7 @@ For an example, see [dinedal/go-sqlite3-extension-functions](https://github.com/ - Why I'm getting `no such table` error? - Why is it racy if I use a `sql.Open("acude", ":memory:")` database? + Why is it racy if I use a `sql.Open("golite", ":memory:")` database? Each connection to `":memory:"` opens a brand new in-memory sql database, so if the stdlib's sql engine happens to open another connection and you've only @@ -545,7 +545,7 @@ For an example, see [dinedal/go-sqlite3-extension-functions](https://github.com/ Example: ```go - db, err := sql.Open("acude", "file:locked.sqlite?cache=shared") + db, err := sql.Open("golite", "file:locked.sqlite?cache=shared") ``` Next, please set the database connections of the SQL package to 1: diff --git a/deps.mk b/deps.mk index 5be3ee3..4b2c048 100644 --- a/deps.mk +++ b/deps.mk @@ -1,11 +1,11 @@ libs.go = \ - src/acudego.go \ - tests/acudego.go \ - tests/benchmarks/exec/acudego.go \ - tests/benchmarks/query/acudego.go \ - tests/functional/json/acudego.go \ - tests/functional/limit/acudego.go \ - tests/fuzz/api/acudego.go \ + src/golite.go \ + tests/benchmarks/exec/golite.go \ + tests/benchmarks/query/golite.go \ + tests/functional/json/golite.go \ + tests/functional/limit/golite.go \ + tests/fuzz/api/golite.go \ + tests/golite.go \ mains.go = \ tests/benchmarks/exec/main.go \ @@ -16,39 +16,39 @@ mains.go = \ tests/main.go \ functional-tests/lib.go = \ - tests/functional/json/acudego.go \ - tests/functional/limit/acudego.go \ + tests/functional/json/golite.go \ + tests/functional/limit/golite.go \ functional-tests/main.go = \ tests/functional/json/main.go \ tests/functional/limit/main.go \ fuzz-targets/lib.go = \ - tests/fuzz/api/acudego.go \ + tests/fuzz/api/golite.go \ fuzz-targets/main.go = \ tests/fuzz/api/main.go \ benchmarks/lib.go = \ - tests/benchmarks/exec/acudego.go \ - tests/benchmarks/query/acudego.go \ + tests/benchmarks/exec/golite.go \ + tests/benchmarks/query/golite.go \ benchmarks/main.go = \ tests/benchmarks/exec/main.go \ tests/benchmarks/query/main.go \ -src/acudego.a: src/acudego.go -tests/acudego.a: tests/acudego.go -tests/benchmarks/exec/acudego.a: tests/benchmarks/exec/acudego.go +src/golite.a: src/golite.go +tests/benchmarks/exec/golite.a: tests/benchmarks/exec/golite.go tests/benchmarks/exec/main.a: tests/benchmarks/exec/main.go -tests/benchmarks/query/acudego.a: tests/benchmarks/query/acudego.go +tests/benchmarks/query/golite.a: tests/benchmarks/query/golite.go tests/benchmarks/query/main.a: tests/benchmarks/query/main.go -tests/functional/json/acudego.a: tests/functional/json/acudego.go +tests/functional/json/golite.a: tests/functional/json/golite.go tests/functional/json/main.a: tests/functional/json/main.go -tests/functional/limit/acudego.a: tests/functional/limit/acudego.go +tests/functional/limit/golite.a: tests/functional/limit/golite.go tests/functional/limit/main.a: tests/functional/limit/main.go -tests/fuzz/api/acudego.a: tests/fuzz/api/acudego.go +tests/fuzz/api/golite.a: tests/fuzz/api/golite.go tests/fuzz/api/main.a: tests/fuzz/api/main.go +tests/golite.a: tests/golite.go tests/main.a: tests/main.go tests/benchmarks/exec/main.bin: tests/benchmarks/exec/main.a tests/benchmarks/query/main.bin: tests/benchmarks/query/main.a diff --git a/src/acudego.go b/src/acudego.go deleted file mode 100644 index c9acbcd..0000000 --- a/src/acudego.go +++ /dev/null @@ -1,2616 +0,0 @@ -package acudego - -import ( - "context" - "database/sql" - "database/sql/driver" - "errors" - "fmt" - "io" - "math" - "net/url" - "reflect" - "runtime" - "strconv" - "strings" - "sync" - "time" - "unsafe" - "syscall" -) - - - -/* -#include -#include - -void stepTrampoline(sqlite3_context *, int, sqlite3_value **); -void doneTrampoline(sqlite3_context *); -int compareTrampoline(void *, int, char *, int, char *); -int commitHookTrampoline(void *); -void rollbackHookTrampoline(void *); -void updateHookTrampoline(void *, int, char *, char *, sqlite3_int64); -*/ -import "C" - - - -// SQLiteBackup implement interface of Backup. -type SQLiteBackup struct { - b *C.sqlite3_backup -} - -// Backup make backup from src to dest. -func (destConn *SQLiteConn) Backup(dest string, srcConn *SQLiteConn, src string) (*SQLiteBackup, error) { - destptr := C.CString(dest) - defer C.free(unsafe.Pointer(destptr)) - srcptr := C.CString(src) - defer C.free(unsafe.Pointer(srcptr)) - - if b := C.sqlite3_backup_init(destConn.db, destptr, srcConn.db, srcptr); b != nil { - bb := &SQLiteBackup{b: b} - runtime.SetFinalizer(bb, (*SQLiteBackup).Finish) - return bb, nil - } - return nil, destConn.lastError() -} - -// Step to backs up for one step. Calls the underlying `sqlite3_backup_step` -// function. This function returns a boolean indicating if the backup is done -// and an error signalling any other error. Done is returned if the underlying -// C function returns SQLITE_DONE (Code 101) -func (b *SQLiteBackup) Step(p int) (bool, error) { - ret := C.sqlite3_backup_step(b.b, C.int(p)) - if ret == C.SQLITE_DONE { - return true, nil - } else if ret != 0 && ret != C.SQLITE_LOCKED && ret != C.SQLITE_BUSY { - return false, Error{Code: ErrNo(ret)} - } - return false, nil -} - -// Remaining return whether have the rest for backup. -func (b *SQLiteBackup) Remaining() int { - return int(C.sqlite3_backup_remaining(b.b)) -} - -// PageCount return count of pages. -func (b *SQLiteBackup) PageCount() int { - return int(C.sqlite3_backup_pagecount(b.b)) -} - -// Finish close backup. -func (b *SQLiteBackup) Finish() error { - return b.Close() -} - -// Close close backup. -func (b *SQLiteBackup) Close() error { - ret := C.sqlite3_backup_finish(b.b) - - // sqlite3_backup_finish() never fails, it just returns the - // error code from previous operations, so clean up before - // checking and returning an error - b.b = nil - runtime.SetFinalizer(b, nil) - - if ret != 0 { - return Error{Code: ErrNo(ret)} - } - return nil -} - -//export stepTrampoline -func stepTrampoline( - ctx *C.sqlite3_context, - argc C.int, - argv **C.sqlite3_value, -) { - const size = (math.MaxInt32 - 1) / - unsafe.Sizeof((*C.sqlite3_value)(nil)) - args := (*[size]*C.sqlite3_value)(unsafe.Pointer(argv)) - slice := args[:int(argc):int(argc)] - lookupHandle(C.sqlite3_user_data(ctx)).(*aggInfo).Step(ctx, slice) -} - -//export doneTrampoline -func doneTrampoline(ctx *C.sqlite3_context) { - lookupHandle(C.sqlite3_user_data(ctx)).(*aggInfo).Done(ctx) -} - -//export compareTrampoline -func compareTrampoline( - handlePtr unsafe.Pointer, - la C.int, - a *C.char, - lb C.int, - b *C.char, -) C.int { - cmpFn := lookupHandle(handlePtr).(func(string, string) int) - return C.int(cmpFn(C.GoStringN(a, la), C.GoStringN(b, lb))) -} - -//export commitHookTrampoline -func commitHookTrampoline(handle unsafe.Pointer) C.int { - return C.int(lookupHandle(handle).(func() int)()) -} - -//export rollbackHookTrampoline -func rollbackHookTrampoline(handle unsafe.Pointer) { - lookupHandle(handle).(func())() -} - -//export updateHookTrampoline -func updateHookTrampoline( - handle unsafe.Pointer, - op C.int, - db *C.char, - table *C.char, - rowid int64, -) { - lookupHandle(handle).(func(int, string, string, int64))( - int(op), - C.GoString(db), - C.GoString(table), - int64(rowid), - ) -} - -// Use handles to avoid passing Go pointers to C. -type handleVal struct { - db *SQLiteConn - val any -} - -var handleLock sync.Mutex -var handleVals = make(map[unsafe.Pointer]handleVal) - -func newHandle(db *SQLiteConn, v any) unsafe.Pointer { - val := handleVal{db: db, val: v} - p := C.malloc(C.size_t(1)) - if p == nil { - panic("can't allocate 'cgo-pointer hack index pointer': ptr == nil") - } - { - handleLock.Lock() - defer handleLock.Unlock() - handleVals[p] = val - } - return p -} - -func lookupHandleVal(handle unsafe.Pointer) handleVal { - handleLock.Lock() - defer handleLock.Unlock() - return handleVals[handle] -} - -func lookupHandle(handle unsafe.Pointer) any { - return lookupHandleVal(handle).val -} - -func deleteHandles(db *SQLiteConn) { - handleLock.Lock() - defer handleLock.Unlock() - for handle, val := range handleVals { - if val.db == db { - delete(handleVals, handle) - C.free(handle) - } - } -} - -type callbackArgConverter func(*C.sqlite3_value) (reflect.Value, error) - -type callbackArgCast struct { - f callbackArgConverter - typ reflect.Type -} - -func (c callbackArgCast) Run(v *C.sqlite3_value) (reflect.Value, error) { - val, err := c.f(v) - if err != nil { - return reflect.Value{}, err - } - if !val.Type().ConvertibleTo(c.typ) { - return reflect.Value{}, fmt.Errorf("cannot convert %s to %s", val.Type(), c.typ) - } - return val.Convert(c.typ), nil -} - -func callbackArgInt64(v *C.sqlite3_value) (reflect.Value, error) { - if C.sqlite3_value_type(v) != C.SQLITE_INTEGER { - return reflect.Value{}, fmt.Errorf("argument must be an INTEGER") - } - return reflect.ValueOf(int64(C.sqlite3_value_int64(v))), nil -} - -func callbackArgBool(v *C.sqlite3_value) (reflect.Value, error) { - if C.sqlite3_value_type(v) != C.SQLITE_INTEGER { - return reflect.Value{}, fmt.Errorf("argument must be an INTEGER") - } - i := int64(C.sqlite3_value_int64(v)) - val := false - if i != 0 { - val = true - } - return reflect.ValueOf(val), nil -} - -func callbackArgFloat64(v *C.sqlite3_value) (reflect.Value, error) { - if C.sqlite3_value_type(v) != C.SQLITE_FLOAT { - return reflect.Value{}, fmt.Errorf("argument must be a FLOAT") - } - return reflect.ValueOf(float64(C.sqlite3_value_double(v))), nil -} - -func callbackArgBytes(v *C.sqlite3_value) (reflect.Value, error) { - switch C.sqlite3_value_type(v) { - case C.SQLITE_BLOB: - l := C.sqlite3_value_bytes(v) - p := C.sqlite3_value_blob(v) - return reflect.ValueOf(C.GoBytes(p, l)), nil - case C.SQLITE_TEXT: - l := C.sqlite3_value_bytes(v) - c := unsafe.Pointer(C.sqlite3_value_text(v)) - return reflect.ValueOf(C.GoBytes(c, l)), nil - default: - return reflect.Value{}, fmt.Errorf("argument must be BLOB or TEXT") - } -} - -func callbackArgString(v *C.sqlite3_value) (reflect.Value, error) { - switch C.sqlite3_value_type(v) { - case C.SQLITE_BLOB: - l := C.sqlite3_value_bytes(v) - p := (*C.char)(C.sqlite3_value_blob(v)) - return reflect.ValueOf(C.GoStringN(p, l)), nil - case C.SQLITE_TEXT: - c := (*C.char)(unsafe.Pointer(C.sqlite3_value_text(v))) - return reflect.ValueOf(C.GoString(c)), nil - default: - return reflect.Value{}, fmt.Errorf("argument must be BLOB or TEXT") - } -} - -func callbackArgGeneric(v *C.sqlite3_value) (reflect.Value, error) { - switch C.sqlite3_value_type(v) { - case C.SQLITE_INTEGER: - return callbackArgInt64(v) - case C.SQLITE_FLOAT: - return callbackArgFloat64(v) - case C.SQLITE_TEXT: - return callbackArgString(v) - case C.SQLITE_BLOB: - return callbackArgBytes(v) - case C.SQLITE_NULL: - // Interpret NULL as a nil byte slice. - var ret []byte - return reflect.ValueOf(ret), nil - default: - panic("unreachable") - } -} - -func callbackArg(typ reflect.Type) (callbackArgConverter, error) { - switch typ.Kind() { - case reflect.Interface: - if typ.NumMethod() != 0 { - return nil, errors.New("the only supported interface type is any") - } - return callbackArgGeneric, nil - case reflect.Slice: - if typ.Elem().Kind() != reflect.Uint8 { - return nil, errors.New("the only supported slice type is []byte") - } - return callbackArgBytes, nil - case reflect.String: - return callbackArgString, nil - case reflect.Bool: - return callbackArgBool, nil - case reflect.Int64: - return callbackArgInt64, nil - case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Int, reflect.Uint: - c := callbackArgCast{callbackArgInt64, typ} - return c.Run, nil - case reflect.Float64: - return callbackArgFloat64, nil - case reflect.Float32: - c := callbackArgCast{callbackArgFloat64, typ} - return c.Run, nil - default: - return nil, fmt.Errorf("don't know how to convert to %s", typ) - } -} - -func callbackConvertArgs(argv []*C.sqlite3_value, converters []callbackArgConverter, variadic callbackArgConverter) ([]reflect.Value, error) { - var args []reflect.Value - - if len(argv) < len(converters) { - return nil, fmt.Errorf("function requires at least %d arguments", len(converters)) - } - - for i, arg := range argv[:len(converters)] { - v, err := converters[i](arg) - if err != nil { - return nil, err - } - args = append(args, v) - } - - if variadic != nil { - for _, arg := range argv[len(converters):] { - v, err := variadic(arg) - if err != nil { - return nil, err - } - args = append(args, v) - } - } - return args, nil -} - -type callbackRetConverter func(*C.sqlite3_context, reflect.Value) error - -func callbackRetInteger(ctx *C.sqlite3_context, v reflect.Value) error { - switch v.Type().Kind() { - case reflect.Int64: - case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Int, reflect.Uint: - v = v.Convert(reflect.TypeOf(int64(0))) - case reflect.Bool: - b := v.Interface().(bool) - if b { - v = reflect.ValueOf(int64(1)) - } else { - v = reflect.ValueOf(int64(0)) - } - default: - return fmt.Errorf("cannot convert %s to INTEGER", v.Type()) - } - - C.sqlite3_result_int64(ctx, C.sqlite3_int64(v.Interface().(int64))) - return nil -} - -func callbackRetFloat(ctx *C.sqlite3_context, v reflect.Value) error { - switch v.Type().Kind() { - case reflect.Float64: - case reflect.Float32: - v = v.Convert(reflect.TypeOf(float64(0))) - default: - return fmt.Errorf("cannot convert %s to FLOAT", v.Type()) - } - - C.sqlite3_result_double(ctx, C.double(v.Interface().(float64))) - return nil -} - -func callbackRetBlob(ctx *C.sqlite3_context, v reflect.Value) error { - if v.Type().Kind() != reflect.Slice || v.Type().Elem().Kind() != reflect.Uint8 { - return fmt.Errorf("cannot convert %s to BLOB", v.Type()) - } - i := v.Interface() - if i == nil || len(i.([]byte)) == 0 { - C.sqlite3_result_null(ctx) - } else { - bs := i.([]byte) - C.sqlite3_result_blob( - ctx, - unsafe.Pointer(&bs[0]), - C.int(len(bs)), - C.SQLITE_TRANSIENT, - ) - } - return nil -} - -func callbackRetText(ctx *C.sqlite3_context, v reflect.Value) error { - if v.Type().Kind() != reflect.String { - return fmt.Errorf("cannot convert %s to TEXT", v.Type()) - } - C.sqlite3_result_text( - ctx, - C.CString(v.Interface().(string)), - -1, - (*[0]byte)(unsafe.Pointer(C.free)), - ) - return nil -} - -func callbackRetNil(ctx *C.sqlite3_context, v reflect.Value) error { - return nil -} - -func callbackRetGeneric(ctx *C.sqlite3_context, v reflect.Value) error { - if v.IsNil() { - C.sqlite3_result_null(ctx) - return nil - } - - cb, err := callbackRet(v.Elem().Type()) - if err != nil { - return err - } - - return cb(ctx, v.Elem()) -} - -func callbackRet(typ reflect.Type) (callbackRetConverter, error) { - switch typ.Kind() { - case reflect.Interface: - errorInterface := reflect.TypeOf((*error)(nil)).Elem() - if typ.Implements(errorInterface) { - return callbackRetNil, nil - } - - if typ.NumMethod() == 0 { - return callbackRetGeneric, nil - } - - fallthrough - case reflect.Slice: - if typ.Elem().Kind() != reflect.Uint8 { - return nil, errors.New("the only supported slice type is []byte") - } - return callbackRetBlob, nil - case reflect.String: - return callbackRetText, nil - case reflect.Bool, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Int, reflect.Uint: - return callbackRetInteger, nil - case reflect.Float32, reflect.Float64: - return callbackRetFloat, nil - default: - return nil, fmt.Errorf("don't know how to convert to %s", typ) - } -} - -func callbackError(ctx *C.sqlite3_context, err error) { - cstr := C.CString(err.Error()) - defer C.free(unsafe.Pointer(cstr)) - C.sqlite3_result_error(ctx, cstr, C.int(-1)) -} - -// FIXME: remove this -// Test support code. Tests are not allowed to import "C", so we can't -// declare any functions that use C.sqlite3_value. -func callbackSyntheticForTests(v reflect.Value, err error) callbackArgConverter { - return func(*C.sqlite3_value) (reflect.Value, error) { - return v, err - } -} - -// Extracted from Go database/sql source code -// Type conversions for Scan. - -var errNilPtr = errors.New("destination pointer is nil") // embedded in descriptive error - -// convertAssign copies to dest the value in src, converting it if possible. -// An error is returned if the copy would result in loss of information. -// dest should be a pointer type. -func convertAssign(dest, src any) error { - // Common cases, without reflect. - switch s := src.(type) { - case string: - switch d := dest.(type) { - case *string: - if d == nil { - return errNilPtr - } - *d = s - return nil - case *[]byte: - if d == nil { - return errNilPtr - } - *d = []byte(s) - return nil - case *sql.RawBytes: - if d == nil { - return errNilPtr - } - *d = append((*d)[:0], s...) - return nil - } - case []byte: - switch d := dest.(type) { - case *string: - if d == nil { - return errNilPtr - } - *d = string(s) - return nil - case *any: - if d == nil { - return errNilPtr - } - *d = cloneBytes(s) - return nil - case *[]byte: - if d == nil { - return errNilPtr - } - *d = cloneBytes(s) - return nil - case *sql.RawBytes: - if d == nil { - return errNilPtr - } - *d = s - return nil - } - case time.Time: - switch d := dest.(type) { - case *time.Time: - *d = s - return nil - case *string: - *d = s.Format(time.RFC3339Nano) - return nil - case *[]byte: - if d == nil { - return errNilPtr - } - *d = []byte(s.Format(time.RFC3339Nano)) - return nil - case *sql.RawBytes: - if d == nil { - return errNilPtr - } - *d = s.AppendFormat((*d)[:0], time.RFC3339Nano) - return nil - } - case nil: - switch d := dest.(type) { - case *any: - if d == nil { - return errNilPtr - } - *d = nil - return nil - case *[]byte: - if d == nil { - return errNilPtr - } - *d = nil - return nil - case *sql.RawBytes: - if d == nil { - return errNilPtr - } - *d = nil - return nil - } - } - - var sv reflect.Value - - switch d := dest.(type) { - case *string: - sv = reflect.ValueOf(src) - switch sv.Kind() { - case reflect.Bool, - reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, - reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, - reflect.Float32, reflect.Float64: - *d = asString(src) - return nil - } - case *[]byte: - sv = reflect.ValueOf(src) - if b, ok := asBytes(nil, sv); ok { - *d = b - return nil - } - case *sql.RawBytes: - sv = reflect.ValueOf(src) - if b, ok := asBytes([]byte(*d)[:0], sv); ok { - *d = sql.RawBytes(b) - return nil - } - case *bool: - bv, err := driver.Bool.ConvertValue(src) - if err == nil { - *d = bv.(bool) - } - return err - case *any: - *d = src - return nil - } - - if scanner, ok := dest.(sql.Scanner); ok { - return scanner.Scan(src) - } - - dpv := reflect.ValueOf(dest) - if dpv.Kind() != reflect.Ptr { - return errors.New("destination not a pointer") - } - if dpv.IsNil() { - return errNilPtr - } - - if !sv.IsValid() { - sv = reflect.ValueOf(src) - } - - dv := reflect.Indirect(dpv) - if sv.IsValid() && sv.Type().AssignableTo(dv.Type()) { - switch b := src.(type) { - case []byte: - dv.Set(reflect.ValueOf(cloneBytes(b))) - default: - dv.Set(sv) - } - return nil - } - - if dv.Kind() == sv.Kind() && sv.Type().ConvertibleTo(dv.Type()) { - dv.Set(sv.Convert(dv.Type())) - return nil - } - - // The following conversions use a string value as an intermediate representation - // to convert between various numeric types. - // - // This also allows scanning into user defined types such as "type Int int64". - // For symmetry, also check for string destination types. - switch dv.Kind() { - case reflect.Ptr: - if src == nil { - dv.Set(reflect.Zero(dv.Type())) - return nil - } - dv.Set(reflect.New(dv.Type().Elem())) - return convertAssign(dv.Interface(), src) - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - s := asString(src) - i64, err := strconv.ParseInt(s, 10, dv.Type().Bits()) - if err != nil { - err = strconvErr(err) - return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err) - } - dv.SetInt(i64) - return nil - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - s := asString(src) - u64, err := strconv.ParseUint(s, 10, dv.Type().Bits()) - if err != nil { - err = strconvErr(err) - return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err) - } - dv.SetUint(u64) - return nil - case reflect.Float32, reflect.Float64: - s := asString(src) - f64, err := strconv.ParseFloat(s, dv.Type().Bits()) - if err != nil { - err = strconvErr(err) - return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err) - } - dv.SetFloat(f64) - return nil - case reflect.String: - switch v := src.(type) { - case string: - dv.SetString(v) - return nil - case []byte: - dv.SetString(string(v)) - return nil - } - } - - return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, dest) -} - -func strconvErr(err error) error { - if ne, ok := err.(*strconv.NumError); ok { - return ne.Err - } - return err -} - -func cloneBytes(b []byte) []byte { - if b == nil { - return nil - } - c := make([]byte, len(b)) - copy(c, b) - return c -} - -func asString(src any) string { - switch v := src.(type) { - case string: - return v - case []byte: - return string(v) - } - rv := reflect.ValueOf(src) - switch rv.Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - return strconv.FormatInt(rv.Int(), 10) - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - return strconv.FormatUint(rv.Uint(), 10) - case reflect.Float64: - return strconv.FormatFloat(rv.Float(), 'g', -1, 64) - case reflect.Float32: - return strconv.FormatFloat(rv.Float(), 'g', -1, 32) - case reflect.Bool: - return strconv.FormatBool(rv.Bool()) - } - return fmt.Sprintf("%v", src) -} - -func asBytes(buf []byte, rv reflect.Value) (b []byte, ok bool) { - switch rv.Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - return strconv.AppendInt(buf, rv.Int(), 10), true - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - return strconv.AppendUint(buf, rv.Uint(), 10), true - case reflect.Float32: - return strconv.AppendFloat(buf, rv.Float(), 'g', -1, 32), true - case reflect.Float64: - return strconv.AppendFloat(buf, rv.Float(), 'g', -1, 64), true - case reflect.Bool: - return strconv.AppendBool(buf, rv.Bool()), true - case reflect.String: - s := rv.String() - return append(buf, s...), true - } - return -} - -/* -Package sqlite3 provides interface to SQLite3 databases. - -This works as a driver for database/sql. - -Installation - - go get github.com/mattn/go-sqlite3 - -# Supported Types - -Currently, go-sqlite3 supports the following data types. - - +------------------------------+ - |go | sqlite3 | - |----------|-------------------| - |nil | null | - |int | integer | - |int64 | integer | - |float64 | float | - |bool | integer | - |[]byte | blob | - |string | text | - |time.Time | timestamp/datetime| - +------------------------------+ - -# SQLite3 Extension - -You can write your own extension module for sqlite3. For example, below is an -extension for a Regexp matcher operation. - - #include - #include - #include - #include - - SQLITE_EXTENSION_INIT1 - static void regexp_func(sqlite3_context *context, int argc, sqlite3_value **argv) { - if (argc >= 2) { - const char *target = (const char *)sqlite3_value_text(argv[1]); - const char *pattern = (const char *)sqlite3_value_text(argv[0]); - const char* errstr = NULL; - int erroff = 0; - int vec[500]; - int n, rc; - pcre* re = pcre_compile(pattern, 0, &errstr, &erroff, NULL); - rc = pcre_exec(re, NULL, target, strlen(target), 0, 0, vec, 500); - if (rc <= 0) { - sqlite3_result_error(context, errstr, 0); - return; - } - sqlite3_result_int(context, 1); - } - } - - int sqlite3_extension_init(sqlite3 *db, char **errmsg, - const sqlite3_api_routines *api) { - SQLITE_EXTENSION_INIT2(api); - return sqlite3_create_function(db, "regexp", 2, SQLITE_UTF8, - (void*)db, regexp_func, NULL, NULL); - } - -It needs to be built as a so/dll shared library. And you need to register -the extension module like below. - - sql.Register("sqlite3_with_extensions", - &sqlite3.SQLiteDriver{ - Extensions: []string{ - "sqlite3_mod_regexp", - }, - }) - -Then, you can use this extension. - - rows, err := db.Query("select text from mytable where name regexp '^golang'") - -# Connection Hook - -You can hook and inject your code when the connection is established by setting -ConnectHook to get the SQLiteConn. - - sql.Register("sqlite3_with_hook_example", - &sqlite3.SQLiteDriver{ - ConnectHook: func(conn *sqlite3.SQLiteConn) error { - sqlite3conn = append(sqlite3conn, conn) - return nil - }, - }) - -You can also use database/sql.Conn.Raw (Go >= 1.13): - - conn, err := db.Conn(context.Background()) - // if err != nil { ... } - defer conn.Close() - err = conn.Raw(func (driverConn any) error { - sqliteConn := driverConn.(*sqlite3.SQLiteConn) - // ... use sqliteConn - }) - // if err != nil { ... } - -# Go SQlite3 Extensions - -If you want to register Go functions as SQLite extension functions -you can make a custom driver by calling RegisterFunction from -ConnectHook. - - regex = func(re, s string) (bool, error) { - return regexp.MatchString(re, s) - } - sql.Register("sqlite3_extended", - &sqlite3.SQLiteDriver{ - ConnectHook: func(conn *sqlite3.SQLiteConn) error { - return conn.RegisterFunc("regexp", regex, true) - }, - }) - -You can then use the custom driver by passing its name to sql.Open. - - var i int - conn, err := sql.Open("sqlite3_extended", "./foo.db") - if err != nil { - panic(err) - } - err = db.QueryRow(`SELECT regexp("foo.*", "seafood")`).Scan(&i) - if err != nil { - panic(err) - } - -See the documentation of RegisterFunc for more details. -*/ - -// ErrNo inherit errno. -type ErrNo int - -// ErrNoMask is mask code. -const ErrNoMask C.int = 0xff - -// ErrNoExtended is extended errno. -type ErrNoExtended int - -// Error implement sqlite error code. -type Error struct { - Code ErrNo /* The error code returned by SQLite */ - ExtendedCode ErrNoExtended /* The extended error code returned by SQLite */ - SystemErrno syscall.Errno /* The system errno returned by the OS through SQLite, if applicable */ - err string /* The error string returned by sqlite3_errmsg(), - this usually contains more specific details. */ -} - -// result codes from http://www.sqlite.org/c3ref/c_abort.html -var ( - ErrError = ErrNo(1) /* SQL error or missing database */ - ErrInternal = ErrNo(2) /* Internal logic error in SQLite */ - ErrPerm = ErrNo(3) /* Access permission denied */ - ErrAbort = ErrNo(4) /* Callback routine requested an abort */ - ErrBusy = ErrNo(5) /* The database file is locked */ - ErrLocked = ErrNo(6) /* A table in the database is locked */ - ErrNomem = ErrNo(7) /* A malloc() failed */ - ErrReadonly = ErrNo(8) /* Attempt to write a readonly database */ - ErrInterrupt = ErrNo(9) /* Operation terminated by sqlite3_interrupt() */ - ErrIoErr = ErrNo(10) /* Some kind of disk I/O error occurred */ - ErrCorrupt = ErrNo(11) /* The database disk image is malformed */ - ErrNotFound = ErrNo(12) /* Unknown opcode in sqlite3_file_control() */ - ErrFull = ErrNo(13) /* Insertion failed because database is full */ - ErrCantOpen = ErrNo(14) /* Unable to open the database file */ - ErrProtocol = ErrNo(15) /* Database lock protocol error */ - ErrEmpty = ErrNo(16) /* Database is empty */ - ErrSchema = ErrNo(17) /* The database schema changed */ - ErrTooBig = ErrNo(18) /* String or BLOB exceeds size limit */ - ErrConstraint = ErrNo(19) /* Abort due to constraint violation */ - ErrMismatch = ErrNo(20) /* Data type mismatch */ - ErrMisuse = ErrNo(21) /* Library used incorrectly */ - ErrNoLFS = ErrNo(22) /* Uses OS features not supported on host */ - ErrAuth = ErrNo(23) /* Authorization denied */ - ErrFormat = ErrNo(24) /* Auxiliary database format error */ - ErrRange = ErrNo(25) /* 2nd parameter to sqlite3_bind out of range */ - ErrNotADB = ErrNo(26) /* File opened that is not a database file */ - ErrNotice = ErrNo(27) /* Notifications from sqlite3_log() */ - ErrWarning = ErrNo(28) /* Warnings from sqlite3_log() */ -) - -// Error return error message from errno. -func (err ErrNo) Error() string { - return Error{Code: err}.Error() -} - -// Extend return extended errno. -func (err ErrNo) Extend(by int) ErrNoExtended { - return ErrNoExtended(int(err) | (by << 8)) -} - -// Error return error message that is extended code. -func (err ErrNoExtended) Error() string { - return Error{Code: ErrNo(C.int(err) & ErrNoMask), ExtendedCode: err}.Error() -} - -func (err Error) Error() string { - var str string - if err.err != "" { - str = err.err - } else { - str = C.GoString(C.sqlite3_errstr(C.int(err.Code))) - } - if err.SystemErrno != 0 { - str += ": " + err.SystemErrno.Error() - } - return str -} - -// result codes from http://www.sqlite.org/c3ref/c_abort_rollback.html -var ( - ErrIoErrRead = ErrIoErr.Extend(1) - ErrIoErrShortRead = ErrIoErr.Extend(2) - ErrIoErrWrite = ErrIoErr.Extend(3) - ErrIoErrFsync = ErrIoErr.Extend(4) - ErrIoErrDirFsync = ErrIoErr.Extend(5) - ErrIoErrTruncate = ErrIoErr.Extend(6) - ErrIoErrFstat = ErrIoErr.Extend(7) - ErrIoErrUnlock = ErrIoErr.Extend(8) - ErrIoErrRDlock = ErrIoErr.Extend(9) - ErrIoErrDelete = ErrIoErr.Extend(10) - ErrIoErrBlocked = ErrIoErr.Extend(11) - ErrIoErrNoMem = ErrIoErr.Extend(12) - ErrIoErrAccess = ErrIoErr.Extend(13) - ErrIoErrCheckReservedLock = ErrIoErr.Extend(14) - ErrIoErrLock = ErrIoErr.Extend(15) - ErrIoErrClose = ErrIoErr.Extend(16) - ErrIoErrDirClose = ErrIoErr.Extend(17) - ErrIoErrSHMOpen = ErrIoErr.Extend(18) - ErrIoErrSHMSize = ErrIoErr.Extend(19) - ErrIoErrSHMLock = ErrIoErr.Extend(20) - ErrIoErrSHMMap = ErrIoErr.Extend(21) - ErrIoErrSeek = ErrIoErr.Extend(22) - ErrIoErrDeleteNoent = ErrIoErr.Extend(23) - ErrIoErrMMap = ErrIoErr.Extend(24) - ErrIoErrGetTempPath = ErrIoErr.Extend(25) - ErrIoErrConvPath = ErrIoErr.Extend(26) - ErrLockedSharedCache = ErrLocked.Extend(1) - ErrBusyRecovery = ErrBusy.Extend(1) - ErrBusySnapshot = ErrBusy.Extend(2) - ErrCantOpenNoTempDir = ErrCantOpen.Extend(1) - ErrCantOpenIsDir = ErrCantOpen.Extend(2) - ErrCantOpenFullPath = ErrCantOpen.Extend(3) - ErrCantOpenConvPath = ErrCantOpen.Extend(4) - ErrCorruptVTab = ErrCorrupt.Extend(1) - ErrReadonlyRecovery = ErrReadonly.Extend(1) - ErrReadonlyCantLock = ErrReadonly.Extend(2) - ErrReadonlyRollback = ErrReadonly.Extend(3) - ErrReadonlyDbMoved = ErrReadonly.Extend(4) - ErrAbortRollback = ErrAbort.Extend(2) - ErrConstraintCheck = ErrConstraint.Extend(1) - ErrConstraintCommitHook = ErrConstraint.Extend(2) - ErrConstraintForeignKey = ErrConstraint.Extend(3) - ErrConstraintFunction = ErrConstraint.Extend(4) - ErrConstraintNotNull = ErrConstraint.Extend(5) - ErrConstraintPrimaryKey = ErrConstraint.Extend(6) - ErrConstraintTrigger = ErrConstraint.Extend(7) - ErrConstraintUnique = ErrConstraint.Extend(8) - ErrConstraintVTab = ErrConstraint.Extend(9) - ErrConstraintRowID = ErrConstraint.Extend(10) - ErrNoticeRecoverWAL = ErrNotice.Extend(1) - ErrNoticeRecoverRollback = ErrNotice.Extend(2) - ErrWarningAutoIndex = ErrWarning.Extend(1) -) - -// FIXME: remove this -// SQLiteTimestampFormats is timestamp formats understood by both this module -// and SQLite. The first format in the slice will be used when saving time -// values into the database. When parsing a string from a timestamp or datetime -// column, the formats are tried in order. -var SQLiteTimestampFormats = []string{ - // By default, store timestamps with whatever timezone they come with. - // When parsed, they will be returned with the same timezone. - "2006-01-02 15:04:05.999999999-07:00", - "2006-01-02T15:04:05.999999999-07:00", - "2006-01-02 15:04:05.999999999", - "2006-01-02T15:04:05.999999999", - "2006-01-02 15:04:05", - "2006-01-02T15:04:05", - "2006-01-02 15:04", - "2006-01-02T15:04", - "2006-01-02", -} - -const ( - columnDate string = "date" - columnDatetime string = "datetime" - columnTimestamp string = "timestamp" -) - -const DriverName = "acude" -func init() { - sql.Register(DriverName, &SQLiteDriver{}) -} - -// Version returns SQLite library version information. -func LibVersion() (libVersion string, libVersionNumber int, sourceID string) { - libVersion = C.GoString(C.sqlite3_libversion()) - libVersionNumber = int(C.sqlite3_libversion_number()) - sourceID = C.GoString(C.sqlite3_sourceid()) - return libVersion, libVersionNumber, sourceID -} - -const ( - // used by update hook - SQLITE_DELETE = C.SQLITE_DELETE - SQLITE_INSERT = C.SQLITE_INSERT - SQLITE_UPDATE = C.SQLITE_UPDATE -) - -// Standard File Control Opcodes -// See: https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html -const ( - SQLITE_FCNTL_LOCKSTATE = int(1) - SQLITE_FCNTL_GET_LOCKPROXYFILE = int(2) - SQLITE_FCNTL_SET_LOCKPROXYFILE = int(3) - SQLITE_FCNTL_LAST_ERRNO = int(4) - SQLITE_FCNTL_SIZE_HINT = int(5) - SQLITE_FCNTL_CHUNK_SIZE = int(6) - SQLITE_FCNTL_FILE_POINTER = int(7) - SQLITE_FCNTL_SYNC_OMITTED = int(8) - SQLITE_FCNTL_WIN32_AV_RETRY = int(9) - SQLITE_FCNTL_PERSIST_WAL = int(10) - SQLITE_FCNTL_OVERWRITE = int(11) - SQLITE_FCNTL_VFSNAME = int(12) - SQLITE_FCNTL_POWERSAFE_OVERWRITE = int(13) - SQLITE_FCNTL_PRAGMA = int(14) - SQLITE_FCNTL_BUSYHANDLER = int(15) - SQLITE_FCNTL_TEMPFILENAME = int(16) - SQLITE_FCNTL_MMAP_SIZE = int(18) - SQLITE_FCNTL_TRACE = int(19) - SQLITE_FCNTL_HAS_MOVED = int(20) - SQLITE_FCNTL_SYNC = int(21) - SQLITE_FCNTL_COMMIT_PHASETWO = int(22) - SQLITE_FCNTL_WIN32_SET_HANDLE = int(23) - SQLITE_FCNTL_WAL_BLOCK = int(24) - SQLITE_FCNTL_ZIPVFS = int(25) - SQLITE_FCNTL_RBU = int(26) - SQLITE_FCNTL_VFS_POINTER = int(27) - SQLITE_FCNTL_JOURNAL_POINTER = int(28) - SQLITE_FCNTL_WIN32_GET_HANDLE = int(29) - SQLITE_FCNTL_PDB = int(30) - SQLITE_FCNTL_BEGIN_ATOMIC_WRITE = int(31) - SQLITE_FCNTL_COMMIT_ATOMIC_WRITE = int(32) - SQLITE_FCNTL_ROLLBACK_ATOMIC_WRITE = int(33) - SQLITE_FCNTL_LOCK_TIMEOUT = int(34) - SQLITE_FCNTL_DATA_VERSION = int(35) - SQLITE_FCNTL_SIZE_LIMIT = int(36) - SQLITE_FCNTL_CKPT_DONE = int(37) - SQLITE_FCNTL_RESERVE_BYTES = int(38) - SQLITE_FCNTL_CKPT_START = int(39) - SQLITE_FCNTL_EXTERNAL_READER = int(40) - SQLITE_FCNTL_CKSM_FILE = int(41) -) - -// SQLiteDriver implements driver.Driver. -type SQLiteDriver struct { - ConnectHook func(*SQLiteConn) error -} - -// SQLiteConn implements driver.Conn. -type SQLiteConn struct { - mu sync.Mutex - db *C.sqlite3 - loc *time.Location - funcs []*functionInfo - aggregators []*aggInfo -} - -// SQLiteTx implements driver.Tx. -type SQLiteTx struct { - c *SQLiteConn -} - -// SQLiteStmt implements driver.Stmt. -type SQLiteStmt struct { - mu sync.Mutex - c *SQLiteConn - s *C.sqlite3_stmt - t string - closed bool - cls bool -} - -// SQLiteResult implements sql.Result. -type SQLiteResult struct { - id int64 - changes int64 -} - -// SQLiteRows implements driver.Rows. -type SQLiteRows struct { - s *SQLiteStmt - nc int - cols []string - decltype []string - cls bool - closed bool - ctx context.Context // no better alternative to pass context into Next() method -} - -type functionInfo struct { - f reflect.Value - argConverters []callbackArgConverter - variadicConverter callbackArgConverter - retConverter callbackRetConverter -} - -func (fi *functionInfo) Call(ctx *C.sqlite3_context, argv []*C.sqlite3_value) { - args, err := callbackConvertArgs(argv, fi.argConverters, fi.variadicConverter) - if err != nil { - callbackError(ctx, err) - return - } - - ret := fi.f.Call(args) - - if len(ret) == 2 && ret[1].Interface() != nil { - callbackError(ctx, ret[1].Interface().(error)) - return - } - - err = fi.retConverter(ctx, ret[0]) - if err != nil { - callbackError(ctx, err) - return - } -} - -type aggInfo struct { - constructor reflect.Value - - // Active aggregator objects for aggregations in flight. The - // aggregators are indexed by a counter stored in the aggregation - // user data space provided by sqlite. - active map[int64]reflect.Value - next int64 - - stepArgConverters []callbackArgConverter - stepVariadicConverter callbackArgConverter - - doneRetConverter callbackRetConverter -} - -func (ai *aggInfo) agg(ctx *C.sqlite3_context) (int64, reflect.Value, error) { - aggIdx := (*int64)(C.sqlite3_aggregate_context(ctx, C.int(8))) - if *aggIdx == 0 { - *aggIdx = ai.next - ret := ai.constructor.Call(nil) - if len(ret) == 2 && ret[1].Interface() != nil { - return 0, reflect.Value{}, ret[1].Interface().(error) - } - if ret[0].IsNil() { - return 0, reflect.Value{}, errors.New("aggregator constructor returned nil state") - } - ai.next++ - ai.active[*aggIdx] = ret[0] - } - return *aggIdx, ai.active[*aggIdx], nil -} - -func (ai *aggInfo) Step(ctx *C.sqlite3_context, argv []*C.sqlite3_value) { - _, agg, err := ai.agg(ctx) - if err != nil { - callbackError(ctx, err) - return - } - - args, err := callbackConvertArgs(argv, ai.stepArgConverters, ai.stepVariadicConverter) - if err != nil { - callbackError(ctx, err) - return - } - - ret := agg.MethodByName("Step").Call(args) - if len(ret) == 1 && ret[0].Interface() != nil { - callbackError(ctx, ret[0].Interface().(error)) - return - } -} - -func (ai *aggInfo) Done(ctx *C.sqlite3_context) { - idx, agg, err := ai.agg(ctx) - if err != nil { - callbackError(ctx, err) - return - } - defer func() { delete(ai.active, idx) }() - - ret := agg.MethodByName("Done").Call(nil) - if len(ret) == 2 && ret[1].Interface() != nil { - callbackError(ctx, ret[1].Interface().(error)) - return - } - - err = ai.doneRetConverter(ctx, ret[0]) - if err != nil { - callbackError(ctx, err) - return - } -} - -// Commit transaction. -func (tx *SQLiteTx) Commit() error { - _, err := tx.c.exec(context.Background(), "COMMIT", nil) - if err != nil { - // sqlite3 may leave the transaction open in this scenario. - // However, database/sql considers the transaction complete once we - // return from Commit() - we must clean up to honour its semantics. - // We don't know if the ROLLBACK is strictly necessary, but according - // to sqlite's docs, there is no harm in calling ROLLBACK unnecessarily. - tx.c.exec(context.Background(), "ROLLBACK", nil) - } - return err -} - -// Rollback transaction. -func (tx *SQLiteTx) Rollback() error { - _, err := tx.c.exec(context.Background(), "ROLLBACK", nil) - return err -} - -// RegisterCollation makes a Go function available as a collation. -// -// cmp receives two UTF-8 strings, a and b. The result should be 0 if -// a==b, -1 if a < b, and +1 if a > b. -// -// cmp must always return the same result given the same -// inputs. Additionally, it must have the following properties for all -// strings A, B and C: if A==B then B==A; if A==B and B==C then A==C; -// if AA; if A 0 { - stmtArgs = append(stmtArgs, args[start:start+na]...) - for i := range args { - if (i < start || i >= na) && args[i].Name != "" { - stmtArgs = append(stmtArgs, args[i]) - } - } - for i := range stmtArgs { - stmtArgs[i].Ordinal = i + 1 - } - } - res, err = s.(*SQLiteStmt).exec(ctx, stmtArgs) - if err != nil && err != driver.ErrSkip { - s.Close() - return nil, err - } - start += na - } - tail := s.(*SQLiteStmt).t - s.Close() - if tail == "" { - if res == nil { - // https://github.com/mattn/go-sqlite3/issues/963 - res = &SQLiteResult{0, 0} - } - return res, nil - } - query = tail - } -} - -// Query implements Queryer. -func (c *SQLiteConn) Query(query string, args []driver.Value) (driver.Rows, error) { - list := make([]driver.NamedValue, len(args)) - for i, v := range args { - list[i] = driver.NamedValue{ - Ordinal: i + 1, - Value: v, - } - } - return c.query(context.Background(), query, list) -} - -func (c *SQLiteConn) query(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { - start := 0 - for { - stmtArgs := make([]driver.NamedValue, 0, len(args)) - s, err := c.prepare(ctx, query) - if err != nil { - return nil, err - } - s.(*SQLiteStmt).cls = true - na := s.NumInput() - if len(args)-start < na { - return nil, fmt.Errorf("not enough args to execute query: want %d got %d", na, len(args)-start) - } - // consume the number of arguments used in the current - // statement and append all named arguments not contained - // therein - stmtArgs = append(stmtArgs, args[start:start+na]...) - for i := range args { - if (i < start || i >= na) && args[i].Name != "" { - stmtArgs = append(stmtArgs, args[i]) - } - } - for i := range stmtArgs { - stmtArgs[i].Ordinal = i + 1 - } - rows, err := s.(*SQLiteStmt).query(ctx, stmtArgs) - if err != nil && err != driver.ErrSkip { - s.Close() - return rows, err - } - start += na - tail := s.(*SQLiteStmt).t - if tail == "" { - return rows, nil - } - rows.Close() - s.Close() - query = tail - } -} - -// Begin transaction. -func (c *SQLiteConn) Begin() (driver.Tx, error) { - return c.begin(context.Background()) -} - -func (c *SQLiteConn) begin(ctx context.Context) (driver.Tx, error) { - if _, err := c.exec(ctx, "BEGIN", nil); err != nil { - return nil, err - } - return &SQLiteTx{c}, nil -} - -// Open database and return a new connection. -// -// A pragma can take either zero or one argument. -// The argument is may be either in parentheses or it may be separated from -// the pragma name by an equal sign. The two syntaxes yield identical results. -// In many pragmas, the argument is a boolean. The boolean can be one of: -// -// 1 yes true on -// 0 no false off -// -// You can specify a DSN string using a URI as the filename. -// -// test.db -// file:test.db?cache=shared&mode=memory -// :memory: -// file::memory: -// -// cache -// SQLite Shared-Cache Mode -// https://www.sqlite.org/sharedcache.html -// Values: -// - shared -// - private -// -// go-sqlite3 adds the following query parameters to those used by SQLite: -// -// _loc=XXX -// Specify location of time format. It's possible to specify "auto". -func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) { - // Options - var loc *time.Location - - pos := strings.IndexRune(dsn, '?') - if pos >= 1 { - params, err := url.ParseQuery(dsn[pos+1:]) - if err != nil { - return nil, err - } - - // _loc - if val := params.Get("_loc"); val != "" { - switch strings.ToLower(val) { - case "auto": - loc = time.Local - default: - loc, err = time.LoadLocation(val) - if err != nil { - return nil, fmt.Errorf("Invalid _loc: %v: %v", val, err) - } - } - } - - if !strings.HasPrefix(dsn, "file:") { - dsn = dsn[:pos] - } - } - - var db *C.sqlite3 - name := C.CString(dsn) - defer C.free(unsafe.Pointer(name)) - - var openFlags C.int = - C.SQLITE_OPEN_READWRITE | // FIXME: fails if RO FS? - C.SQLITE_OPEN_CREATE | - C.SQLITE_OPEN_URI | - C.SQLITE_OPEN_FULLMUTEX - rv := C.sqlite3_open_v2(name, &db, openFlags, nil) - if rv != 0 { - // Save off the error _before_ closing the database. - // This is safe even if db is nil. - err := lastError(db) - if db != nil { - C.sqlite3_close_v2(db) - } - return nil, err - } - if db == nil { - return nil, errors.New("sqlite succeeded without returning a database") - } - - const setup = ` - PRAGMA journal_mode = WAL; - PRAGMA busy_timeout = 10000; - ` - setupCStr := C.CString(setup) - defer C.free(unsafe.Pointer(setupCStr)) - rv = C.sqlite3_exec(db, setupCStr, nil, nil, nil) - if rv != C.SQLITE_OK { - err := lastError(db) - C.sqlite3_close_v2(db) - return nil, err - } - - // Create connection to SQLite - conn := &SQLiteConn{db: db, loc: loc} - - if d.ConnectHook != nil { - if err := d.ConnectHook(conn); err != nil { - conn.Close() - return nil, err - } - } - - runtime.SetFinalizer(conn, (*SQLiteConn).Close) - return conn, nil -} - -// Close the connection. -func (c *SQLiteConn) Close() error { - rv := C.sqlite3_close_v2(c.db) - if rv != C.SQLITE_OK { - return c.lastError() - } - deleteHandles(c) - c.mu.Lock() - c.db = nil - c.mu.Unlock() - runtime.SetFinalizer(c, nil) - return nil -} - -func (c *SQLiteConn) dbConnOpen() bool { - if c == nil { - return false - } - c.mu.Lock() - defer c.mu.Unlock() - return c.db != nil -} - -// Prepare the query string. Return a new statement. -func (c *SQLiteConn) Prepare(query string) (driver.Stmt, error) { - return c.prepare(context.Background(), query) -} - -func (c *SQLiteConn) prepare(ctx context.Context, query string) (driver.Stmt, error) { - pquery := C.CString(query) - defer C.free(unsafe.Pointer(pquery)) - var s *C.sqlite3_stmt - var tail *C.char - rv := C.sqlite3_prepare_v2(c.db, pquery, C.int(-1), &s, &tail) - if rv != C.SQLITE_OK { - return nil, c.lastError() - } - var t string - if tail != nil && *tail != '\000' { - t = strings.TrimSpace(C.GoString(tail)) - } - ss := &SQLiteStmt{c: c, s: s, t: t} - runtime.SetFinalizer(ss, (*SQLiteStmt).Close) - return ss, nil -} - -// Run-Time Limit Categories. -// See: http://www.sqlite.org/c3ref/c_limit_attached.html -const ( - SQLITE_LIMIT_LENGTH = C.SQLITE_LIMIT_LENGTH - SQLITE_LIMIT_SQL_LENGTH = C.SQLITE_LIMIT_SQL_LENGTH - SQLITE_LIMIT_COLUMN = C.SQLITE_LIMIT_COLUMN - SQLITE_LIMIT_EXPR_DEPTH = C.SQLITE_LIMIT_EXPR_DEPTH - SQLITE_LIMIT_COMPOUND_SELECT = C.SQLITE_LIMIT_COMPOUND_SELECT - SQLITE_LIMIT_VDBE_OP = C.SQLITE_LIMIT_VDBE_OP - SQLITE_LIMIT_FUNCTION_ARG = C.SQLITE_LIMIT_FUNCTION_ARG - SQLITE_LIMIT_ATTACHED = C.SQLITE_LIMIT_ATTACHED - SQLITE_LIMIT_LIKE_PATTERN_LENGTH = C.SQLITE_LIMIT_LIKE_PATTERN_LENGTH - SQLITE_LIMIT_VARIABLE_NUMBER = C.SQLITE_LIMIT_VARIABLE_NUMBER - SQLITE_LIMIT_TRIGGER_DEPTH = C.SQLITE_LIMIT_TRIGGER_DEPTH - SQLITE_LIMIT_WORKER_THREADS = C.SQLITE_LIMIT_WORKER_THREADS -) - -// GetFilename returns the absolute path to the file containing -// the requested schema. When passed an empty string, it will -// instead use the database's default schema: "main". -// See: sqlite3_db_filename, https://www.sqlite.org/c3ref/db_filename.html -func (c *SQLiteConn) GetFilename(schemaName string) string { - if schemaName == "" { - schemaName = "main" - } - return C.GoString(C.sqlite3_db_filename(c.db, C.CString(schemaName))) -} - -func (c *SQLiteConn) GetLimit(id int) int { - return int(C.sqlite3_limit(c.db, C.int(id), C.int(-1))) -} - -func (c *SQLiteConn) SetLimit(id int, newVal int) int { - return int(C.sqlite3_limit(c.db, C.int(id), C.int(newVal))) -} - -// SetFileControlInt invokes the xFileControl method on a given database. The -// dbName is the name of the database. It will default to "main" if left blank. -// The op is one of the opcodes prefixed by "SQLITE_FCNTL_". The arg argument -// and return code are both opcode-specific. Please see the SQLite documentation. -// -// This method is not thread-safe as the returned error code can be changed by -// another call if invoked concurrently. -// -// See: sqlite3_file_control, https://www.sqlite.org/c3ref/file_control.html -func (c *SQLiteConn) SetFileControlInt(dbName string, op int, arg int) error { - if dbName == "" { - dbName = "main" - } - - cDBName := C.CString(dbName) - defer C.free(unsafe.Pointer(cDBName)) - - cArg := C.int(arg) - rv := C.sqlite3_file_control(c.db, cDBName, C.int(op), unsafe.Pointer(&cArg)) - if rv != C.SQLITE_OK { - return c.lastError() - } - return nil -} - -func (s *SQLiteStmt) Close() error { - s.mu.Lock() - defer s.mu.Unlock() - if s.closed { - return nil - } - s.closed = true - if !s.c.dbConnOpen() { - return errors.New("sqlite statement with already closed database connection") - } - rv := C.sqlite3_finalize(s.s) - s.s = nil - if rv != C.SQLITE_OK { - return s.c.lastError() - } - s.c = nil - runtime.SetFinalizer(s, nil) - return nil -} - -func (s *SQLiteStmt) NumInput() int { - return int(C.sqlite3_bind_parameter_count(s.s)) -} - -var placeHolder = []byte{0} - -func (s *SQLiteStmt) bind(args []driver.NamedValue) error { - rv := C.sqlite3_reset(s.s) - if rv != C.SQLITE_ROW && rv != C.SQLITE_OK && rv != C.SQLITE_DONE { - return s.c.lastError() - } - - bindIndices := make([][3]int, len(args)) - prefixes := []string{":", "@", "$"} - for i, v := range args { - bindIndices[i][0] = args[i].Ordinal - if v.Name != "" { - for j := range prefixes { - cname := C.CString(prefixes[j] + v.Name) - bindIndices[i][j] = int(C.sqlite3_bind_parameter_index(s.s, cname)) - C.free(unsafe.Pointer(cname)) - } - args[i].Ordinal = bindIndices[i][0] - } - } - - for i, arg := range args { - for j := range bindIndices[i] { - if bindIndices[i][j] == 0 { - continue - } - n := C.int(bindIndices[i][j]) - switch v := arg.Value.(type) { - case nil: - rv = C.sqlite3_bind_null(s.s, n) - case string: - if len(v) == 0 { - rv = C.sqlite3_bind_text( - s.s, - n, - (*C.char)(unsafe.Pointer(&placeHolder[0])), - C.int(0), - C.SQLITE_TRANSIENT, - ) - } else { - b := []byte(v) - rv = C.sqlite3_bind_text( - s.s, - n, - (*C.char)(unsafe.Pointer(&b[0])), - C.int(len(b)), - C.SQLITE_TRANSIENT, - ) - } - case int64: - rv = C.sqlite3_bind_int64(s.s, n, C.sqlite3_int64(v)) - case bool: - if v { - rv = C.sqlite3_bind_int(s.s, n, 1) - } else { - rv = C.sqlite3_bind_int(s.s, n, 0) - } - case float64: - rv = C.sqlite3_bind_double(s.s, n, C.double(v)) - case []byte: - if v == nil { - rv = C.sqlite3_bind_null(s.s, n) - } else { - ln := len(v) - if ln == 0 { - v = placeHolder - } - rv = C.sqlite3_bind_blob( - s.s, - n, - unsafe.Pointer(&v[0]), - C.int(ln), - C.SQLITE_TRANSIENT, - ) - } - case time.Time: - b := []byte(v.Format(SQLiteTimestampFormats[0])) - rv = C.sqlite3_bind_text( - s.s, - n, - (*C.char)(unsafe.Pointer(&b[0])), - C.int(len(b)), - C.SQLITE_TRANSIENT, - ) - } - if rv != C.SQLITE_OK { - return s.c.lastError() - } - } - } - return nil -} - -// Query the statement with arguments. Return records. -func (s *SQLiteStmt) Query(args []driver.Value) (driver.Rows, error) { - list := make([]driver.NamedValue, len(args)) - for i, v := range args { - list[i] = driver.NamedValue{ - Ordinal: i + 1, - Value: v, - } - } - return s.query(context.Background(), list) -} - -func (s *SQLiteStmt) query(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { - if err := s.bind(args); err != nil { - return nil, err - } - - rows := &SQLiteRows{ - s: s, - nc: int(C.sqlite3_column_count(s.s)), - cols: nil, - decltype: nil, - cls: s.cls, - closed: false, - ctx: ctx, - } - runtime.SetFinalizer(rows, (*SQLiteRows).Close) - - return rows, nil -} - -// LastInsertId return last inserted ID. -func (r *SQLiteResult) LastInsertId() (int64, error) { - return r.id, nil -} - -// RowsAffected return how many rows affected. -func (r *SQLiteResult) RowsAffected() (int64, error) { - return r.changes, nil -} - -// Exec execute the statement with arguments. Return result object. -func (s *SQLiteStmt) Exec(args []driver.Value) (driver.Result, error) { - list := make([]driver.NamedValue, len(args)) - for i, v := range args { - list[i] = driver.NamedValue{ - Ordinal: i + 1, - Value: v, - } - } - return s.exec(context.Background(), list) -} - -func isInterruptErr(err error) bool { - sqliteErr, ok := err.(Error) - if ok { - return sqliteErr.Code == ErrInterrupt - } - return false -} - -// exec executes a query that doesn't return rows. Attempts to honor context timeout. -func (s *SQLiteStmt) exec(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { - if ctx.Done() == nil { - return s.execSync(args) - } - - type result struct { - r driver.Result - err error - } - resultCh := make(chan result) - defer close(resultCh) - go func() { - r, err := s.execSync(args) - resultCh <- result{r, err} - }() - var rv result - select { - case rv = <-resultCh: - case <-ctx.Done(): - select { - case rv = <-resultCh: // no need to interrupt, operation completed in db - default: - // this is still racy and can be no-op if executed between sqlite3_* calls in execSync. - C.sqlite3_interrupt(s.c.db) - rv = <-resultCh // wait for goroutine completed - if isInterruptErr(rv.err) { - return nil, ctx.Err() - } - } - } - return rv.r, rv.err -} - -func (s *SQLiteStmt) execSync(args []driver.NamedValue) (driver.Result, error) { - if err := s.bind(args); err != nil { - C.sqlite3_reset(s.s) - C.sqlite3_clear_bindings(s.s) - return nil, err - } - - rv := C.sqlite3_step(s.s) - if rv != C.SQLITE_ROW && rv != C.SQLITE_OK && rv != C.SQLITE_DONE { - err := s.c.lastError() - C.sqlite3_reset(s.s) - C.sqlite3_clear_bindings(s.s) - return nil, err - } - - db := C.sqlite3_db_handle(s.s) - id := int64(C.sqlite3_last_insert_rowid(db)) - changes := int64(C.sqlite3_changes(db)) - - return &SQLiteResult{ - id: id, - changes: changes, - }, nil -} - -// Readonly reports if this statement is considered readonly by SQLite. -// -// See: https://sqlite.org/c3ref/stmt_readonly.html -func (s *SQLiteStmt) Readonly() bool { - return C.sqlite3_stmt_readonly(s.s) == 1 -} - -func (rc *SQLiteRows) Close() error { - rc.s.mu.Lock() - if rc.s.closed || rc.closed { - rc.s.mu.Unlock() - return nil - } - rc.closed = true - if rc.cls { - rc.s.mu.Unlock() - return rc.s.Close() - } - rv := C.sqlite3_reset(rc.s.s) - if rv != C.SQLITE_OK { - rc.s.mu.Unlock() - return rc.s.c.lastError() - } - rc.s.mu.Unlock() - rc.s = nil - runtime.SetFinalizer(rc, nil) - return nil -} - -func (rc *SQLiteRows) Columns() []string { - rc.s.mu.Lock() - defer rc.s.mu.Unlock() - if rc.s.s != nil && rc.nc != len(rc.cols) { - rc.cols = make([]string, rc.nc) - for i := 0; i < rc.nc; i++ { - rc.cols[i] = C.GoString(C.sqlite3_column_name(rc.s.s, C.int(i))) - } - } - return rc.cols -} - -func (rc *SQLiteRows) declTypes() []string { - if rc.s.s != nil && rc.decltype == nil { - rc.decltype = make([]string, rc.nc) - for i := 0; i < rc.nc; i++ { - rc.decltype[i] = strings.ToLower(C.GoString(C.sqlite3_column_decltype(rc.s.s, C.int(i)))) - } - } - return rc.decltype -} - -// DeclTypes return column types. -func (rc *SQLiteRows) DeclTypes() []string { - rc.s.mu.Lock() - defer rc.s.mu.Unlock() - return rc.declTypes() -} - -// Next move cursor to next. Attempts to honor context timeout from QueryContext call. -func (rc *SQLiteRows) Next(dest []driver.Value) error { - rc.s.mu.Lock() - defer rc.s.mu.Unlock() - - if rc.s.closed { - return io.EOF - } - - if rc.ctx.Done() == nil { - return rc.nextSyncLocked(dest) - } - resultCh := make(chan error) - defer close(resultCh) - go func() { - resultCh <- rc.nextSyncLocked(dest) - }() - select { - case err := <-resultCh: - return err - case <-rc.ctx.Done(): - select { - case <-resultCh: // no need to interrupt - default: - // this is still racy and can be no-op if executed between sqlite3_* calls in nextSyncLocked. - C.sqlite3_interrupt(rc.s.c.db) - <-resultCh // ensure goroutine completed - } - return rc.ctx.Err() - } -} - -// nextSyncLocked moves cursor to next; must be called with locked mutex. -func (rc *SQLiteRows) nextSyncLocked(dest []driver.Value) error { - rv := C.sqlite3_step(rc.s.s) - if rv == C.SQLITE_DONE { - return io.EOF - } - if rv != C.SQLITE_ROW { - rv = C.sqlite3_reset(rc.s.s) - if rv != C.SQLITE_OK { - return rc.s.c.lastError() - } - return nil - } - - rc.declTypes() - - for i := range dest { - switch C.sqlite3_column_type(rc.s.s, C.int(i)) { - case C.SQLITE_INTEGER: - val := int64(C.sqlite3_column_int64(rc.s.s, C.int(i))) - switch rc.decltype[i] { - case columnTimestamp, columnDatetime, columnDate: - var t time.Time - // Assume a millisecond unix timestamp if it's 13 digits -- too - // large to be a reasonable timestamp in seconds. - if val > 1e12 || val < -1e12 { - val *= int64(time.Millisecond) // convert ms to nsec - t = time.Unix(0, val) - } else { - t = time.Unix(val, 0) - } - t = t.UTC() - if rc.s.c.loc != nil { - t = t.In(rc.s.c.loc) - } - dest[i] = t - case "boolean": - dest[i] = val > 0 - default: - dest[i] = val - } - case C.SQLITE_FLOAT: - dest[i] = float64(C.sqlite3_column_double(rc.s.s, C.int(i))) - case C.SQLITE_BLOB: - p := C.sqlite3_column_blob(rc.s.s, C.int(i)) - if p == nil { - dest[i] = []byte{} - continue - } - n := C.sqlite3_column_bytes(rc.s.s, C.int(i)) - dest[i] = C.GoBytes(p, n) - case C.SQLITE_NULL: - dest[i] = nil - case C.SQLITE_TEXT: - var err error - var timeVal time.Time - - n := int(C.sqlite3_column_bytes(rc.s.s, C.int(i))) - s := C.GoStringN((*C.char)(unsafe.Pointer(C.sqlite3_column_text(rc.s.s, C.int(i)))), C.int(n)) - - switch rc.decltype[i] { - case columnTimestamp, columnDatetime, columnDate: - var t time.Time - s = strings.TrimSuffix(s, "Z") - for _, format := range SQLiteTimestampFormats { - if timeVal, err = time.ParseInLocation(format, s, time.UTC); err == nil { - t = timeVal - break - } - } - if err != nil { - // The column is a time value, so return the zero time on parse failure. - t = time.Time{} - } - if rc.s.c.loc != nil { - t = t.In(rc.s.c.loc) - } - dest[i] = t - default: - dest[i] = s - } - } - } - return nil -} - -const i64 = unsafe.Sizeof(int(0)) > 4 - -// SQLiteContext behave sqlite3_context -type SQLiteContext C.sqlite3_context - -// ResultBool sets the result of an SQL function. -func (c *SQLiteContext) ResultBool(b bool) { - if b { - c.ResultInt(1) - } else { - c.ResultInt(0) - } -} - -// ResultBlob sets the result of an SQL function. -// See: sqlite3_result_blob, http://sqlite.org/c3ref/result_blob.html -func (c *SQLiteContext) ResultBlob(b []byte) { - if i64 && len(b) > math.MaxInt32 { - C.sqlite3_result_error_toobig((*C.sqlite3_context)(c)) - return - } - var p *byte - if len(b) > 0 { - p = &b[0] - } - C.sqlite3_result_blob( - (*C.sqlite3_context)(c), - unsafe.Pointer(p), - C.int(len(b)), - C.SQLITE_TRANSIENT, - ) -} - -// ResultDouble sets the result of an SQL function. -// See: sqlite3_result_double, http://sqlite.org/c3ref/result_blob.html -func (c *SQLiteContext) ResultDouble(d float64) { - C.sqlite3_result_double((*C.sqlite3_context)(c), C.double(d)) -} - -// ResultInt sets the result of an SQL function. -// See: sqlite3_result_int, http://sqlite.org/c3ref/result_blob.html -func (c *SQLiteContext) ResultInt(i int) { - if i64 && (i > math.MaxInt32 || i < math.MinInt32) { - C.sqlite3_result_int64((*C.sqlite3_context)(c), C.sqlite3_int64(i)) - } else { - C.sqlite3_result_int((*C.sqlite3_context)(c), C.int(i)) - } -} - -// ResultInt64 sets the result of an SQL function. -// See: sqlite3_result_int64, http://sqlite.org/c3ref/result_blob.html -func (c *SQLiteContext) ResultInt64(i int64) { - C.sqlite3_result_int64((*C.sqlite3_context)(c), C.sqlite3_int64(i)) -} - -// ResultNull sets the result of an SQL function. -// See: sqlite3_result_null, http://sqlite.org/c3ref/result_blob.html -func (c *SQLiteContext) ResultNull() { - C.sqlite3_result_null((*C.sqlite3_context)(c)) -} - -// ResultText sets the result of an SQL function. -// See: sqlite3_result_text, http://sqlite.org/c3ref/result_blob.html -func (c *SQLiteContext) ResultText(s string) { - h := (*reflect.StringHeader)(unsafe.Pointer(&s)) - cs, l := (*C.char)(unsafe.Pointer(h.Data)), C.int(h.Len) - C.sqlite3_result_text( - (*C.sqlite3_context)(c), - cs, - l, - C.SQLITE_TRANSIENT, - ) -} - -// ResultZeroblob sets the result of an SQL function. -// See: sqlite3_result_zeroblob, http://sqlite.org/c3ref/result_blob.html -func (c *SQLiteContext) ResultZeroblob(n int) { - C.sqlite3_result_zeroblob((*C.sqlite3_context)(c), C.int(n)) -} - -// Ping implement Pinger. -func (c *SQLiteConn) Ping(ctx context.Context) error { - if c.db == nil { - // must be ErrBadConn for sql to close the database - return driver.ErrBadConn - } - return nil -} - -// QueryContext implement QueryerContext. -func (c *SQLiteConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { - return c.query(ctx, query, args) -} - -// ExecContext implement ExecerContext. -func (c *SQLiteConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { - return c.exec(ctx, query, args) -} - -// PrepareContext implement ConnPrepareContext. -func (c *SQLiteConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { - return c.prepare(ctx, query) -} - -// BeginTx implement ConnBeginTx. -func (c *SQLiteConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { - return c.begin(ctx) -} - -// QueryContext implement QueryerContext. -func (s *SQLiteStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { - return s.query(ctx, args) -} - -// ExecContext implement ExecerContext. -func (s *SQLiteStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { - return s.exec(ctx, args) -} - -// ColumnTableName returns the table that is the origin of a particular result -// column in a SELECT statement. -// -// See https://www.sqlite.org/c3ref/column_database_name.html -func (s *SQLiteStmt) ColumnTableName(n int) string { - return C.GoString(C.sqlite3_column_table_name(s.s, C.int(n))) -} - -// Serialize returns a byte slice that is a serialization of the database. -// -// See https://www.sqlite.org/c3ref/serialize.html -func (c *SQLiteConn) Serialize(schema string) ([]byte, error) { - if schema == "" { - schema = "main" - } - var zSchema *C.char - zSchema = C.CString(schema) - defer C.free(unsafe.Pointer(zSchema)) - - var sz C.sqlite3_int64 - ptr := C.sqlite3_serialize(c.db, zSchema, &sz, 0) - if ptr == nil { - return nil, fmt.Errorf("serialize failed") - } - defer C.sqlite3_free(unsafe.Pointer(ptr)) - - if sz > C.sqlite3_int64(math.MaxInt) { - return nil, fmt.Errorf("serialized database is too large (%d bytes)", sz) - } - - cBuf := *(*[]byte)(unsafe.Pointer(&reflect.SliceHeader{ - Data: uintptr(unsafe.Pointer(ptr)), - Len: int(sz), - Cap: int(sz), - })) - - res := make([]byte, int(sz)) - copy(res, cBuf) - return res, nil -} - -// Deserialize causes the connection to disconnect from the current database and -// then re-open as an in-memory database based on the contents of the byte slice. -// -// See https://www.sqlite.org/c3ref/deserialize.html -func (c *SQLiteConn) Deserialize(b []byte, schema string) error { - if schema == "" { - schema = "main" - } - var zSchema *C.char - zSchema = C.CString(schema) - defer C.free(unsafe.Pointer(zSchema)) - - tmpBuf := (*C.uchar)(C.sqlite3_malloc64(C.sqlite3_uint64(len(b)))) - cBuf := *(*[]byte)(unsafe.Pointer(&reflect.SliceHeader{ - Data: uintptr(unsafe.Pointer(tmpBuf)), - Len: len(b), - Cap: len(b), - })) - copy(cBuf, b) - - rc := C.sqlite3_deserialize(c.db, zSchema, tmpBuf, C.sqlite3_int64(len(b)), - C.sqlite3_int64(len(b)), C.SQLITE_DESERIALIZE_FREEONCLOSE) - if rc != C.SQLITE_OK { - return fmt.Errorf("deserialize failed with return %v", rc) - } - return nil -} - -// Op is type of operations. -type Op uint8 - -// Op mean identity of operations. -const ( - OpEQ Op = 2 - OpGT = 4 - OpLE = 8 - OpLT = 16 - OpGE = 32 - OpMATCH = 64 - OpLIKE = 65 /* 3.10.0 and later only */ - OpGLOB = 66 /* 3.10.0 and later only */ - OpREGEXP = 67 /* 3.10.0 and later only */ - OpScanUnique = 1 /* Scan visits at most 1 row */ -) - -// InfoConstraint give information of constraint. -type InfoConstraint struct { - Column int - Op Op - Usable bool -} - -// InfoOrderBy give information of order-by. -type InfoOrderBy struct { - Column int - Desc bool -} - -func constraints(info *C.sqlite3_index_info) []InfoConstraint { - slice := *(*[]C.struct_sqlite3_index_constraint)(unsafe.Pointer(&reflect.SliceHeader{ - Data: uintptr(unsafe.Pointer(info.aConstraint)), - Len: int(info.nConstraint), - Cap: int(info.nConstraint), - })) - - cst := make([]InfoConstraint, 0, len(slice)) - for _, c := range slice { - var usable bool - if c.usable > 0 { - usable = true - } - cst = append(cst, InfoConstraint{ - Column: int(c.iColumn), - Op: Op(c.op), - Usable: usable, - }) - } - return cst -} - -func orderBys(info *C.sqlite3_index_info) []InfoOrderBy { - slice := *(*[]C.struct_sqlite3_index_orderby)(unsafe.Pointer(&reflect.SliceHeader{ - Data: uintptr(unsafe.Pointer(info.aOrderBy)), - Len: int(info.nOrderBy), - Cap: int(info.nOrderBy), - })) - - ob := make([]InfoOrderBy, 0, len(slice)) - for _, c := range slice { - var desc bool - if c.desc > 0 { - desc = true - } - ob = append(ob, InfoOrderBy{ - Column: int(c.iColumn), - Desc: desc, - }) - } - return ob -} - -// IndexResult is a Go struct representation of what eventually ends up in the -// output fields for `sqlite3_index_info` -// See: https://www.sqlite.org/c3ref/index_info.html -type IndexResult struct { - Used []bool // aConstraintUsage - IdxNum int - IdxStr string - AlreadyOrdered bool // orderByConsumed - EstimatedCost float64 - EstimatedRows float64 -} - -func fillDBError(dbErr *Error, db *C.sqlite3) { - // See SQLiteConn.lastError(), in file 'sqlite3.go' at the time of writing (Sept 5, 2016) - dbErr.Code = ErrNo(C.sqlite3_errcode(db)) - dbErr.ExtendedCode = ErrNoExtended(C.sqlite3_extended_errcode(db)) - dbErr.err = C.GoString(C.sqlite3_errmsg(db)) -} - -// ColumnTypeDatabaseTypeName implement RowsColumnTypeDatabaseTypeName. -func (rc *SQLiteRows) ColumnTypeDatabaseTypeName(i int) string { - return C.GoString(C.sqlite3_column_decltype(rc.s.s, C.int(i))) -} - -// ColumnTypeNullable implement RowsColumnTypeNullable. -func (rc *SQLiteRows) ColumnTypeNullable(i int) (nullable, ok bool) { - return true, true -} - -// ColumnTypeScanType implement RowsColumnTypeScanType. -func (rc *SQLiteRows) ColumnTypeScanType(i int) reflect.Type { - //ct := C.sqlite3_column_type(rc.s.s, C.int(i)) // Always returns 5 - return scanType(C.GoString(C.sqlite3_column_decltype(rc.s.s, C.int(i)))) -} - -const ( - SQLITE_INTEGER = iota - SQLITE_TEXT - SQLITE_BLOB - SQLITE_REAL - SQLITE_NUMERIC - SQLITE_TIME - SQLITE_BOOL - SQLITE_NULL -) - -func scanType(cdt string) reflect.Type { - t := strings.ToUpper(cdt) - i := databaseTypeConvSqlite(t) - switch i { - case SQLITE_INTEGER: - return reflect.TypeOf(sql.NullInt64{}) - case SQLITE_TEXT: - return reflect.TypeOf(sql.NullString{}) - case SQLITE_BLOB: - return reflect.TypeOf(sql.RawBytes{}) - case SQLITE_REAL: - return reflect.TypeOf(sql.NullFloat64{}) - case SQLITE_NUMERIC: - return reflect.TypeOf(sql.NullFloat64{}) - case SQLITE_BOOL: - return reflect.TypeOf(sql.NullBool{}) - case SQLITE_TIME: - return reflect.TypeOf(sql.NullTime{}) - } - return reflect.TypeOf(new(any)) -} - -func databaseTypeConvSqlite(t string) int { - if strings.Contains(t, "INT") { - return SQLITE_INTEGER - } - if t == "CLOB" || t == "TEXT" || - strings.Contains(t, "CHAR") { - return SQLITE_TEXT - } - if t == "BLOB" { - return SQLITE_BLOB - } - if t == "REAL" || t == "FLOAT" || - strings.Contains(t, "DOUBLE") { - return SQLITE_REAL - } - if t == "DATE" || t == "DATETIME" || - t == "TIMESTAMP" { - return SQLITE_TIME - } - if t == "NUMERIC" || - strings.Contains(t, "DECIMAL") { - return SQLITE_NUMERIC - } - if t == "BOOLEAN" { - return SQLITE_BOOL - } - - return SQLITE_NULL -} diff --git a/src/golite.go b/src/golite.go new file mode 100644 index 0000000..ab3e596 --- /dev/null +++ b/src/golite.go @@ -0,0 +1,2616 @@ +package golite + +import ( + "context" + "database/sql" + "database/sql/driver" + "errors" + "fmt" + "io" + "math" + "net/url" + "reflect" + "runtime" + "strconv" + "strings" + "sync" + "time" + "unsafe" + "syscall" +) + + + +/* +#include +#include + +void stepTrampoline(sqlite3_context *, int, sqlite3_value **); +void doneTrampoline(sqlite3_context *); +int compareTrampoline(void *, int, char *, int, char *); +int commitHookTrampoline(void *); +void rollbackHookTrampoline(void *); +void updateHookTrampoline(void *, int, char *, char *, sqlite3_int64); +*/ +import "C" + + + +// SQLiteBackup implement interface of Backup. +type SQLiteBackup struct { + b *C.sqlite3_backup +} + +// Backup make backup from src to dest. +func (destConn *SQLiteConn) Backup(dest string, srcConn *SQLiteConn, src string) (*SQLiteBackup, error) { + destptr := C.CString(dest) + defer C.free(unsafe.Pointer(destptr)) + srcptr := C.CString(src) + defer C.free(unsafe.Pointer(srcptr)) + + if b := C.sqlite3_backup_init(destConn.db, destptr, srcConn.db, srcptr); b != nil { + bb := &SQLiteBackup{b: b} + runtime.SetFinalizer(bb, (*SQLiteBackup).Finish) + return bb, nil + } + return nil, destConn.lastError() +} + +// Step to backs up for one step. Calls the underlying `sqlite3_backup_step` +// function. This function returns a boolean indicating if the backup is done +// and an error signalling any other error. Done is returned if the underlying +// C function returns SQLITE_DONE (Code 101) +func (b *SQLiteBackup) Step(p int) (bool, error) { + ret := C.sqlite3_backup_step(b.b, C.int(p)) + if ret == C.SQLITE_DONE { + return true, nil + } else if ret != 0 && ret != C.SQLITE_LOCKED && ret != C.SQLITE_BUSY { + return false, Error{Code: ErrNo(ret)} + } + return false, nil +} + +// Remaining return whether have the rest for backup. +func (b *SQLiteBackup) Remaining() int { + return int(C.sqlite3_backup_remaining(b.b)) +} + +// PageCount return count of pages. +func (b *SQLiteBackup) PageCount() int { + return int(C.sqlite3_backup_pagecount(b.b)) +} + +// Finish close backup. +func (b *SQLiteBackup) Finish() error { + return b.Close() +} + +// Close close backup. +func (b *SQLiteBackup) Close() error { + ret := C.sqlite3_backup_finish(b.b) + + // sqlite3_backup_finish() never fails, it just returns the + // error code from previous operations, so clean up before + // checking and returning an error + b.b = nil + runtime.SetFinalizer(b, nil) + + if ret != 0 { + return Error{Code: ErrNo(ret)} + } + return nil +} + +//export stepTrampoline +func stepTrampoline( + ctx *C.sqlite3_context, + argc C.int, + argv **C.sqlite3_value, +) { + const size = (math.MaxInt32 - 1) / + unsafe.Sizeof((*C.sqlite3_value)(nil)) + args := (*[size]*C.sqlite3_value)(unsafe.Pointer(argv)) + slice := args[:int(argc):int(argc)] + lookupHandle(C.sqlite3_user_data(ctx)).(*aggInfo).Step(ctx, slice) +} + +//export doneTrampoline +func doneTrampoline(ctx *C.sqlite3_context) { + lookupHandle(C.sqlite3_user_data(ctx)).(*aggInfo).Done(ctx) +} + +//export compareTrampoline +func compareTrampoline( + handlePtr unsafe.Pointer, + la C.int, + a *C.char, + lb C.int, + b *C.char, +) C.int { + cmpFn := lookupHandle(handlePtr).(func(string, string) int) + return C.int(cmpFn(C.GoStringN(a, la), C.GoStringN(b, lb))) +} + +//export commitHookTrampoline +func commitHookTrampoline(handle unsafe.Pointer) C.int { + return C.int(lookupHandle(handle).(func() int)()) +} + +//export rollbackHookTrampoline +func rollbackHookTrampoline(handle unsafe.Pointer) { + lookupHandle(handle).(func())() +} + +//export updateHookTrampoline +func updateHookTrampoline( + handle unsafe.Pointer, + op C.int, + db *C.char, + table *C.char, + rowid int64, +) { + lookupHandle(handle).(func(int, string, string, int64))( + int(op), + C.GoString(db), + C.GoString(table), + int64(rowid), + ) +} + +// Use handles to avoid passing Go pointers to C. +type handleVal struct { + db *SQLiteConn + val any +} + +var handleLock sync.Mutex +var handleVals = make(map[unsafe.Pointer]handleVal) + +func newHandle(db *SQLiteConn, v any) unsafe.Pointer { + val := handleVal{db: db, val: v} + p := C.malloc(C.size_t(1)) + if p == nil { + panic("can't allocate 'cgo-pointer hack index pointer': ptr == nil") + } + { + handleLock.Lock() + defer handleLock.Unlock() + handleVals[p] = val + } + return p +} + +func lookupHandleVal(handle unsafe.Pointer) handleVal { + handleLock.Lock() + defer handleLock.Unlock() + return handleVals[handle] +} + +func lookupHandle(handle unsafe.Pointer) any { + return lookupHandleVal(handle).val +} + +func deleteHandles(db *SQLiteConn) { + handleLock.Lock() + defer handleLock.Unlock() + for handle, val := range handleVals { + if val.db == db { + delete(handleVals, handle) + C.free(handle) + } + } +} + +type callbackArgConverter func(*C.sqlite3_value) (reflect.Value, error) + +type callbackArgCast struct { + f callbackArgConverter + typ reflect.Type +} + +func (c callbackArgCast) Run(v *C.sqlite3_value) (reflect.Value, error) { + val, err := c.f(v) + if err != nil { + return reflect.Value{}, err + } + if !val.Type().ConvertibleTo(c.typ) { + return reflect.Value{}, fmt.Errorf("cannot convert %s to %s", val.Type(), c.typ) + } + return val.Convert(c.typ), nil +} + +func callbackArgInt64(v *C.sqlite3_value) (reflect.Value, error) { + if C.sqlite3_value_type(v) != C.SQLITE_INTEGER { + return reflect.Value{}, fmt.Errorf("argument must be an INTEGER") + } + return reflect.ValueOf(int64(C.sqlite3_value_int64(v))), nil +} + +func callbackArgBool(v *C.sqlite3_value) (reflect.Value, error) { + if C.sqlite3_value_type(v) != C.SQLITE_INTEGER { + return reflect.Value{}, fmt.Errorf("argument must be an INTEGER") + } + i := int64(C.sqlite3_value_int64(v)) + val := false + if i != 0 { + val = true + } + return reflect.ValueOf(val), nil +} + +func callbackArgFloat64(v *C.sqlite3_value) (reflect.Value, error) { + if C.sqlite3_value_type(v) != C.SQLITE_FLOAT { + return reflect.Value{}, fmt.Errorf("argument must be a FLOAT") + } + return reflect.ValueOf(float64(C.sqlite3_value_double(v))), nil +} + +func callbackArgBytes(v *C.sqlite3_value) (reflect.Value, error) { + switch C.sqlite3_value_type(v) { + case C.SQLITE_BLOB: + l := C.sqlite3_value_bytes(v) + p := C.sqlite3_value_blob(v) + return reflect.ValueOf(C.GoBytes(p, l)), nil + case C.SQLITE_TEXT: + l := C.sqlite3_value_bytes(v) + c := unsafe.Pointer(C.sqlite3_value_text(v)) + return reflect.ValueOf(C.GoBytes(c, l)), nil + default: + return reflect.Value{}, fmt.Errorf("argument must be BLOB or TEXT") + } +} + +func callbackArgString(v *C.sqlite3_value) (reflect.Value, error) { + switch C.sqlite3_value_type(v) { + case C.SQLITE_BLOB: + l := C.sqlite3_value_bytes(v) + p := (*C.char)(C.sqlite3_value_blob(v)) + return reflect.ValueOf(C.GoStringN(p, l)), nil + case C.SQLITE_TEXT: + c := (*C.char)(unsafe.Pointer(C.sqlite3_value_text(v))) + return reflect.ValueOf(C.GoString(c)), nil + default: + return reflect.Value{}, fmt.Errorf("argument must be BLOB or TEXT") + } +} + +func callbackArgGeneric(v *C.sqlite3_value) (reflect.Value, error) { + switch C.sqlite3_value_type(v) { + case C.SQLITE_INTEGER: + return callbackArgInt64(v) + case C.SQLITE_FLOAT: + return callbackArgFloat64(v) + case C.SQLITE_TEXT: + return callbackArgString(v) + case C.SQLITE_BLOB: + return callbackArgBytes(v) + case C.SQLITE_NULL: + // Interpret NULL as a nil byte slice. + var ret []byte + return reflect.ValueOf(ret), nil + default: + panic("unreachable") + } +} + +func callbackArg(typ reflect.Type) (callbackArgConverter, error) { + switch typ.Kind() { + case reflect.Interface: + if typ.NumMethod() != 0 { + return nil, errors.New("the only supported interface type is any") + } + return callbackArgGeneric, nil + case reflect.Slice: + if typ.Elem().Kind() != reflect.Uint8 { + return nil, errors.New("the only supported slice type is []byte") + } + return callbackArgBytes, nil + case reflect.String: + return callbackArgString, nil + case reflect.Bool: + return callbackArgBool, nil + case reflect.Int64: + return callbackArgInt64, nil + case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Int, reflect.Uint: + c := callbackArgCast{callbackArgInt64, typ} + return c.Run, nil + case reflect.Float64: + return callbackArgFloat64, nil + case reflect.Float32: + c := callbackArgCast{callbackArgFloat64, typ} + return c.Run, nil + default: + return nil, fmt.Errorf("don't know how to convert to %s", typ) + } +} + +func callbackConvertArgs(argv []*C.sqlite3_value, converters []callbackArgConverter, variadic callbackArgConverter) ([]reflect.Value, error) { + var args []reflect.Value + + if len(argv) < len(converters) { + return nil, fmt.Errorf("function requires at least %d arguments", len(converters)) + } + + for i, arg := range argv[:len(converters)] { + v, err := converters[i](arg) + if err != nil { + return nil, err + } + args = append(args, v) + } + + if variadic != nil { + for _, arg := range argv[len(converters):] { + v, err := variadic(arg) + if err != nil { + return nil, err + } + args = append(args, v) + } + } + return args, nil +} + +type callbackRetConverter func(*C.sqlite3_context, reflect.Value) error + +func callbackRetInteger(ctx *C.sqlite3_context, v reflect.Value) error { + switch v.Type().Kind() { + case reflect.Int64: + case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Int, reflect.Uint: + v = v.Convert(reflect.TypeOf(int64(0))) + case reflect.Bool: + b := v.Interface().(bool) + if b { + v = reflect.ValueOf(int64(1)) + } else { + v = reflect.ValueOf(int64(0)) + } + default: + return fmt.Errorf("cannot convert %s to INTEGER", v.Type()) + } + + C.sqlite3_result_int64(ctx, C.sqlite3_int64(v.Interface().(int64))) + return nil +} + +func callbackRetFloat(ctx *C.sqlite3_context, v reflect.Value) error { + switch v.Type().Kind() { + case reflect.Float64: + case reflect.Float32: + v = v.Convert(reflect.TypeOf(float64(0))) + default: + return fmt.Errorf("cannot convert %s to FLOAT", v.Type()) + } + + C.sqlite3_result_double(ctx, C.double(v.Interface().(float64))) + return nil +} + +func callbackRetBlob(ctx *C.sqlite3_context, v reflect.Value) error { + if v.Type().Kind() != reflect.Slice || v.Type().Elem().Kind() != reflect.Uint8 { + return fmt.Errorf("cannot convert %s to BLOB", v.Type()) + } + i := v.Interface() + if i == nil || len(i.([]byte)) == 0 { + C.sqlite3_result_null(ctx) + } else { + bs := i.([]byte) + C.sqlite3_result_blob( + ctx, + unsafe.Pointer(&bs[0]), + C.int(len(bs)), + C.SQLITE_TRANSIENT, + ) + } + return nil +} + +func callbackRetText(ctx *C.sqlite3_context, v reflect.Value) error { + if v.Type().Kind() != reflect.String { + return fmt.Errorf("cannot convert %s to TEXT", v.Type()) + } + C.sqlite3_result_text( + ctx, + C.CString(v.Interface().(string)), + -1, + (*[0]byte)(unsafe.Pointer(C.free)), + ) + return nil +} + +func callbackRetNil(ctx *C.sqlite3_context, v reflect.Value) error { + return nil +} + +func callbackRetGeneric(ctx *C.sqlite3_context, v reflect.Value) error { + if v.IsNil() { + C.sqlite3_result_null(ctx) + return nil + } + + cb, err := callbackRet(v.Elem().Type()) + if err != nil { + return err + } + + return cb(ctx, v.Elem()) +} + +func callbackRet(typ reflect.Type) (callbackRetConverter, error) { + switch typ.Kind() { + case reflect.Interface: + errorInterface := reflect.TypeOf((*error)(nil)).Elem() + if typ.Implements(errorInterface) { + return callbackRetNil, nil + } + + if typ.NumMethod() == 0 { + return callbackRetGeneric, nil + } + + fallthrough + case reflect.Slice: + if typ.Elem().Kind() != reflect.Uint8 { + return nil, errors.New("the only supported slice type is []byte") + } + return callbackRetBlob, nil + case reflect.String: + return callbackRetText, nil + case reflect.Bool, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Int, reflect.Uint: + return callbackRetInteger, nil + case reflect.Float32, reflect.Float64: + return callbackRetFloat, nil + default: + return nil, fmt.Errorf("don't know how to convert to %s", typ) + } +} + +func callbackError(ctx *C.sqlite3_context, err error) { + cstr := C.CString(err.Error()) + defer C.free(unsafe.Pointer(cstr)) + C.sqlite3_result_error(ctx, cstr, C.int(-1)) +} + +// FIXME: remove this +// Test support code. Tests are not allowed to import "C", so we can't +// declare any functions that use C.sqlite3_value. +func callbackSyntheticForTests(v reflect.Value, err error) callbackArgConverter { + return func(*C.sqlite3_value) (reflect.Value, error) { + return v, err + } +} + +// Extracted from Go database/sql source code +// Type conversions for Scan. + +var errNilPtr = errors.New("destination pointer is nil") // embedded in descriptive error + +// convertAssign copies to dest the value in src, converting it if possible. +// An error is returned if the copy would result in loss of information. +// dest should be a pointer type. +func convertAssign(dest, src any) error { + // Common cases, without reflect. + switch s := src.(type) { + case string: + switch d := dest.(type) { + case *string: + if d == nil { + return errNilPtr + } + *d = s + return nil + case *[]byte: + if d == nil { + return errNilPtr + } + *d = []byte(s) + return nil + case *sql.RawBytes: + if d == nil { + return errNilPtr + } + *d = append((*d)[:0], s...) + return nil + } + case []byte: + switch d := dest.(type) { + case *string: + if d == nil { + return errNilPtr + } + *d = string(s) + return nil + case *any: + if d == nil { + return errNilPtr + } + *d = cloneBytes(s) + return nil + case *[]byte: + if d == nil { + return errNilPtr + } + *d = cloneBytes(s) + return nil + case *sql.RawBytes: + if d == nil { + return errNilPtr + } + *d = s + return nil + } + case time.Time: + switch d := dest.(type) { + case *time.Time: + *d = s + return nil + case *string: + *d = s.Format(time.RFC3339Nano) + return nil + case *[]byte: + if d == nil { + return errNilPtr + } + *d = []byte(s.Format(time.RFC3339Nano)) + return nil + case *sql.RawBytes: + if d == nil { + return errNilPtr + } + *d = s.AppendFormat((*d)[:0], time.RFC3339Nano) + return nil + } + case nil: + switch d := dest.(type) { + case *any: + if d == nil { + return errNilPtr + } + *d = nil + return nil + case *[]byte: + if d == nil { + return errNilPtr + } + *d = nil + return nil + case *sql.RawBytes: + if d == nil { + return errNilPtr + } + *d = nil + return nil + } + } + + var sv reflect.Value + + switch d := dest.(type) { + case *string: + sv = reflect.ValueOf(src) + switch sv.Kind() { + case reflect.Bool, + reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, + reflect.Float32, reflect.Float64: + *d = asString(src) + return nil + } + case *[]byte: + sv = reflect.ValueOf(src) + if b, ok := asBytes(nil, sv); ok { + *d = b + return nil + } + case *sql.RawBytes: + sv = reflect.ValueOf(src) + if b, ok := asBytes([]byte(*d)[:0], sv); ok { + *d = sql.RawBytes(b) + return nil + } + case *bool: + bv, err := driver.Bool.ConvertValue(src) + if err == nil { + *d = bv.(bool) + } + return err + case *any: + *d = src + return nil + } + + if scanner, ok := dest.(sql.Scanner); ok { + return scanner.Scan(src) + } + + dpv := reflect.ValueOf(dest) + if dpv.Kind() != reflect.Ptr { + return errors.New("destination not a pointer") + } + if dpv.IsNil() { + return errNilPtr + } + + if !sv.IsValid() { + sv = reflect.ValueOf(src) + } + + dv := reflect.Indirect(dpv) + if sv.IsValid() && sv.Type().AssignableTo(dv.Type()) { + switch b := src.(type) { + case []byte: + dv.Set(reflect.ValueOf(cloneBytes(b))) + default: + dv.Set(sv) + } + return nil + } + + if dv.Kind() == sv.Kind() && sv.Type().ConvertibleTo(dv.Type()) { + dv.Set(sv.Convert(dv.Type())) + return nil + } + + // The following conversions use a string value as an intermediate representation + // to convert between various numeric types. + // + // This also allows scanning into user defined types such as "type Int int64". + // For symmetry, also check for string destination types. + switch dv.Kind() { + case reflect.Ptr: + if src == nil { + dv.Set(reflect.Zero(dv.Type())) + return nil + } + dv.Set(reflect.New(dv.Type().Elem())) + return convertAssign(dv.Interface(), src) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + s := asString(src) + i64, err := strconv.ParseInt(s, 10, dv.Type().Bits()) + if err != nil { + err = strconvErr(err) + return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err) + } + dv.SetInt(i64) + return nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + s := asString(src) + u64, err := strconv.ParseUint(s, 10, dv.Type().Bits()) + if err != nil { + err = strconvErr(err) + return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err) + } + dv.SetUint(u64) + return nil + case reflect.Float32, reflect.Float64: + s := asString(src) + f64, err := strconv.ParseFloat(s, dv.Type().Bits()) + if err != nil { + err = strconvErr(err) + return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err) + } + dv.SetFloat(f64) + return nil + case reflect.String: + switch v := src.(type) { + case string: + dv.SetString(v) + return nil + case []byte: + dv.SetString(string(v)) + return nil + } + } + + return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, dest) +} + +func strconvErr(err error) error { + if ne, ok := err.(*strconv.NumError); ok { + return ne.Err + } + return err +} + +func cloneBytes(b []byte) []byte { + if b == nil { + return nil + } + c := make([]byte, len(b)) + copy(c, b) + return c +} + +func asString(src any) string { + switch v := src.(type) { + case string: + return v + case []byte: + return string(v) + } + rv := reflect.ValueOf(src) + switch rv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return strconv.FormatInt(rv.Int(), 10) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return strconv.FormatUint(rv.Uint(), 10) + case reflect.Float64: + return strconv.FormatFloat(rv.Float(), 'g', -1, 64) + case reflect.Float32: + return strconv.FormatFloat(rv.Float(), 'g', -1, 32) + case reflect.Bool: + return strconv.FormatBool(rv.Bool()) + } + return fmt.Sprintf("%v", src) +} + +func asBytes(buf []byte, rv reflect.Value) (b []byte, ok bool) { + switch rv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return strconv.AppendInt(buf, rv.Int(), 10), true + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return strconv.AppendUint(buf, rv.Uint(), 10), true + case reflect.Float32: + return strconv.AppendFloat(buf, rv.Float(), 'g', -1, 32), true + case reflect.Float64: + return strconv.AppendFloat(buf, rv.Float(), 'g', -1, 64), true + case reflect.Bool: + return strconv.AppendBool(buf, rv.Bool()), true + case reflect.String: + s := rv.String() + return append(buf, s...), true + } + return +} + +/* +Package sqlite3 provides interface to SQLite3 databases. + +This works as a driver for database/sql. + +Installation + + go get github.com/mattn/go-sqlite3 + +# Supported Types + +Currently, go-sqlite3 supports the following data types. + + +------------------------------+ + |go | sqlite3 | + |----------|-------------------| + |nil | null | + |int | integer | + |int64 | integer | + |float64 | float | + |bool | integer | + |[]byte | blob | + |string | text | + |time.Time | timestamp/datetime| + +------------------------------+ + +# SQLite3 Extension + +You can write your own extension module for sqlite3. For example, below is an +extension for a Regexp matcher operation. + + #include + #include + #include + #include + + SQLITE_EXTENSION_INIT1 + static void regexp_func(sqlite3_context *context, int argc, sqlite3_value **argv) { + if (argc >= 2) { + const char *target = (const char *)sqlite3_value_text(argv[1]); + const char *pattern = (const char *)sqlite3_value_text(argv[0]); + const char* errstr = NULL; + int erroff = 0; + int vec[500]; + int n, rc; + pcre* re = pcre_compile(pattern, 0, &errstr, &erroff, NULL); + rc = pcre_exec(re, NULL, target, strlen(target), 0, 0, vec, 500); + if (rc <= 0) { + sqlite3_result_error(context, errstr, 0); + return; + } + sqlite3_result_int(context, 1); + } + } + + int sqlite3_extension_init(sqlite3 *db, char **errmsg, + const sqlite3_api_routines *api) { + SQLITE_EXTENSION_INIT2(api); + return sqlite3_create_function(db, "regexp", 2, SQLITE_UTF8, + (void*)db, regexp_func, NULL, NULL); + } + +It needs to be built as a so/dll shared library. And you need to register +the extension module like below. + + sql.Register("sqlite3_with_extensions", + &sqlite3.SQLiteDriver{ + Extensions: []string{ + "sqlite3_mod_regexp", + }, + }) + +Then, you can use this extension. + + rows, err := db.Query("select text from mytable where name regexp '^golang'") + +# Connection Hook + +You can hook and inject your code when the connection is established by setting +ConnectHook to get the SQLiteConn. + + sql.Register("sqlite3_with_hook_example", + &sqlite3.SQLiteDriver{ + ConnectHook: func(conn *sqlite3.SQLiteConn) error { + sqlite3conn = append(sqlite3conn, conn) + return nil + }, + }) + +You can also use database/sql.Conn.Raw (Go >= 1.13): + + conn, err := db.Conn(context.Background()) + // if err != nil { ... } + defer conn.Close() + err = conn.Raw(func (driverConn any) error { + sqliteConn := driverConn.(*sqlite3.SQLiteConn) + // ... use sqliteConn + }) + // if err != nil { ... } + +# Go SQlite3 Extensions + +If you want to register Go functions as SQLite extension functions +you can make a custom driver by calling RegisterFunction from +ConnectHook. + + regex = func(re, s string) (bool, error) { + return regexp.MatchString(re, s) + } + sql.Register("sqlite3_extended", + &sqlite3.SQLiteDriver{ + ConnectHook: func(conn *sqlite3.SQLiteConn) error { + return conn.RegisterFunc("regexp", regex, true) + }, + }) + +You can then use the custom driver by passing its name to sql.Open. + + var i int + conn, err := sql.Open("sqlite3_extended", "./foo.db") + if err != nil { + panic(err) + } + err = db.QueryRow(`SELECT regexp("foo.*", "seafood")`).Scan(&i) + if err != nil { + panic(err) + } + +See the documentation of RegisterFunc for more details. +*/ + +// ErrNo inherit errno. +type ErrNo int + +// ErrNoMask is mask code. +const ErrNoMask C.int = 0xff + +// ErrNoExtended is extended errno. +type ErrNoExtended int + +// Error implement sqlite error code. +type Error struct { + Code ErrNo /* The error code returned by SQLite */ + ExtendedCode ErrNoExtended /* The extended error code returned by SQLite */ + SystemErrno syscall.Errno /* The system errno returned by the OS through SQLite, if applicable */ + err string /* The error string returned by sqlite3_errmsg(), + this usually contains more specific details. */ +} + +// result codes from http://www.sqlite.org/c3ref/c_abort.html +var ( + ErrError = ErrNo(1) /* SQL error or missing database */ + ErrInternal = ErrNo(2) /* Internal logic error in SQLite */ + ErrPerm = ErrNo(3) /* Access permission denied */ + ErrAbort = ErrNo(4) /* Callback routine requested an abort */ + ErrBusy = ErrNo(5) /* The database file is locked */ + ErrLocked = ErrNo(6) /* A table in the database is locked */ + ErrNomem = ErrNo(7) /* A malloc() failed */ + ErrReadonly = ErrNo(8) /* Attempt to write a readonly database */ + ErrInterrupt = ErrNo(9) /* Operation terminated by sqlite3_interrupt() */ + ErrIoErr = ErrNo(10) /* Some kind of disk I/O error occurred */ + ErrCorrupt = ErrNo(11) /* The database disk image is malformed */ + ErrNotFound = ErrNo(12) /* Unknown opcode in sqlite3_file_control() */ + ErrFull = ErrNo(13) /* Insertion failed because database is full */ + ErrCantOpen = ErrNo(14) /* Unable to open the database file */ + ErrProtocol = ErrNo(15) /* Database lock protocol error */ + ErrEmpty = ErrNo(16) /* Database is empty */ + ErrSchema = ErrNo(17) /* The database schema changed */ + ErrTooBig = ErrNo(18) /* String or BLOB exceeds size limit */ + ErrConstraint = ErrNo(19) /* Abort due to constraint violation */ + ErrMismatch = ErrNo(20) /* Data type mismatch */ + ErrMisuse = ErrNo(21) /* Library used incorrectly */ + ErrNoLFS = ErrNo(22) /* Uses OS features not supported on host */ + ErrAuth = ErrNo(23) /* Authorization denied */ + ErrFormat = ErrNo(24) /* Auxiliary database format error */ + ErrRange = ErrNo(25) /* 2nd parameter to sqlite3_bind out of range */ + ErrNotADB = ErrNo(26) /* File opened that is not a database file */ + ErrNotice = ErrNo(27) /* Notifications from sqlite3_log() */ + ErrWarning = ErrNo(28) /* Warnings from sqlite3_log() */ +) + +// Error return error message from errno. +func (err ErrNo) Error() string { + return Error{Code: err}.Error() +} + +// Extend return extended errno. +func (err ErrNo) Extend(by int) ErrNoExtended { + return ErrNoExtended(int(err) | (by << 8)) +} + +// Error return error message that is extended code. +func (err ErrNoExtended) Error() string { + return Error{Code: ErrNo(C.int(err) & ErrNoMask), ExtendedCode: err}.Error() +} + +func (err Error) Error() string { + var str string + if err.err != "" { + str = err.err + } else { + str = C.GoString(C.sqlite3_errstr(C.int(err.Code))) + } + if err.SystemErrno != 0 { + str += ": " + err.SystemErrno.Error() + } + return str +} + +// result codes from http://www.sqlite.org/c3ref/c_abort_rollback.html +var ( + ErrIoErrRead = ErrIoErr.Extend(1) + ErrIoErrShortRead = ErrIoErr.Extend(2) + ErrIoErrWrite = ErrIoErr.Extend(3) + ErrIoErrFsync = ErrIoErr.Extend(4) + ErrIoErrDirFsync = ErrIoErr.Extend(5) + ErrIoErrTruncate = ErrIoErr.Extend(6) + ErrIoErrFstat = ErrIoErr.Extend(7) + ErrIoErrUnlock = ErrIoErr.Extend(8) + ErrIoErrRDlock = ErrIoErr.Extend(9) + ErrIoErrDelete = ErrIoErr.Extend(10) + ErrIoErrBlocked = ErrIoErr.Extend(11) + ErrIoErrNoMem = ErrIoErr.Extend(12) + ErrIoErrAccess = ErrIoErr.Extend(13) + ErrIoErrCheckReservedLock = ErrIoErr.Extend(14) + ErrIoErrLock = ErrIoErr.Extend(15) + ErrIoErrClose = ErrIoErr.Extend(16) + ErrIoErrDirClose = ErrIoErr.Extend(17) + ErrIoErrSHMOpen = ErrIoErr.Extend(18) + ErrIoErrSHMSize = ErrIoErr.Extend(19) + ErrIoErrSHMLock = ErrIoErr.Extend(20) + ErrIoErrSHMMap = ErrIoErr.Extend(21) + ErrIoErrSeek = ErrIoErr.Extend(22) + ErrIoErrDeleteNoent = ErrIoErr.Extend(23) + ErrIoErrMMap = ErrIoErr.Extend(24) + ErrIoErrGetTempPath = ErrIoErr.Extend(25) + ErrIoErrConvPath = ErrIoErr.Extend(26) + ErrLockedSharedCache = ErrLocked.Extend(1) + ErrBusyRecovery = ErrBusy.Extend(1) + ErrBusySnapshot = ErrBusy.Extend(2) + ErrCantOpenNoTempDir = ErrCantOpen.Extend(1) + ErrCantOpenIsDir = ErrCantOpen.Extend(2) + ErrCantOpenFullPath = ErrCantOpen.Extend(3) + ErrCantOpenConvPath = ErrCantOpen.Extend(4) + ErrCorruptVTab = ErrCorrupt.Extend(1) + ErrReadonlyRecovery = ErrReadonly.Extend(1) + ErrReadonlyCantLock = ErrReadonly.Extend(2) + ErrReadonlyRollback = ErrReadonly.Extend(3) + ErrReadonlyDbMoved = ErrReadonly.Extend(4) + ErrAbortRollback = ErrAbort.Extend(2) + ErrConstraintCheck = ErrConstraint.Extend(1) + ErrConstraintCommitHook = ErrConstraint.Extend(2) + ErrConstraintForeignKey = ErrConstraint.Extend(3) + ErrConstraintFunction = ErrConstraint.Extend(4) + ErrConstraintNotNull = ErrConstraint.Extend(5) + ErrConstraintPrimaryKey = ErrConstraint.Extend(6) + ErrConstraintTrigger = ErrConstraint.Extend(7) + ErrConstraintUnique = ErrConstraint.Extend(8) + ErrConstraintVTab = ErrConstraint.Extend(9) + ErrConstraintRowID = ErrConstraint.Extend(10) + ErrNoticeRecoverWAL = ErrNotice.Extend(1) + ErrNoticeRecoverRollback = ErrNotice.Extend(2) + ErrWarningAutoIndex = ErrWarning.Extend(1) +) + +// FIXME: remove this +// SQLiteTimestampFormats is timestamp formats understood by both this module +// and SQLite. The first format in the slice will be used when saving time +// values into the database. When parsing a string from a timestamp or datetime +// column, the formats are tried in order. +var SQLiteTimestampFormats = []string{ + // By default, store timestamps with whatever timezone they come with. + // When parsed, they will be returned with the same timezone. + "2006-01-02 15:04:05.999999999-07:00", + "2006-01-02T15:04:05.999999999-07:00", + "2006-01-02 15:04:05.999999999", + "2006-01-02T15:04:05.999999999", + "2006-01-02 15:04:05", + "2006-01-02T15:04:05", + "2006-01-02 15:04", + "2006-01-02T15:04", + "2006-01-02", +} + +const ( + columnDate string = "date" + columnDatetime string = "datetime" + columnTimestamp string = "timestamp" +) + +const DriverName = "acude" +func init() { + sql.Register(DriverName, &SQLiteDriver{}) +} + +// Version returns SQLite library version information. +func LibVersion() (libVersion string, libVersionNumber int, sourceID string) { + libVersion = C.GoString(C.sqlite3_libversion()) + libVersionNumber = int(C.sqlite3_libversion_number()) + sourceID = C.GoString(C.sqlite3_sourceid()) + return libVersion, libVersionNumber, sourceID +} + +const ( + // used by update hook + SQLITE_DELETE = C.SQLITE_DELETE + SQLITE_INSERT = C.SQLITE_INSERT + SQLITE_UPDATE = C.SQLITE_UPDATE +) + +// Standard File Control Opcodes +// See: https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html +const ( + SQLITE_FCNTL_LOCKSTATE = int(1) + SQLITE_FCNTL_GET_LOCKPROXYFILE = int(2) + SQLITE_FCNTL_SET_LOCKPROXYFILE = int(3) + SQLITE_FCNTL_LAST_ERRNO = int(4) + SQLITE_FCNTL_SIZE_HINT = int(5) + SQLITE_FCNTL_CHUNK_SIZE = int(6) + SQLITE_FCNTL_FILE_POINTER = int(7) + SQLITE_FCNTL_SYNC_OMITTED = int(8) + SQLITE_FCNTL_WIN32_AV_RETRY = int(9) + SQLITE_FCNTL_PERSIST_WAL = int(10) + SQLITE_FCNTL_OVERWRITE = int(11) + SQLITE_FCNTL_VFSNAME = int(12) + SQLITE_FCNTL_POWERSAFE_OVERWRITE = int(13) + SQLITE_FCNTL_PRAGMA = int(14) + SQLITE_FCNTL_BUSYHANDLER = int(15) + SQLITE_FCNTL_TEMPFILENAME = int(16) + SQLITE_FCNTL_MMAP_SIZE = int(18) + SQLITE_FCNTL_TRACE = int(19) + SQLITE_FCNTL_HAS_MOVED = int(20) + SQLITE_FCNTL_SYNC = int(21) + SQLITE_FCNTL_COMMIT_PHASETWO = int(22) + SQLITE_FCNTL_WIN32_SET_HANDLE = int(23) + SQLITE_FCNTL_WAL_BLOCK = int(24) + SQLITE_FCNTL_ZIPVFS = int(25) + SQLITE_FCNTL_RBU = int(26) + SQLITE_FCNTL_VFS_POINTER = int(27) + SQLITE_FCNTL_JOURNAL_POINTER = int(28) + SQLITE_FCNTL_WIN32_GET_HANDLE = int(29) + SQLITE_FCNTL_PDB = int(30) + SQLITE_FCNTL_BEGIN_ATOMIC_WRITE = int(31) + SQLITE_FCNTL_COMMIT_ATOMIC_WRITE = int(32) + SQLITE_FCNTL_ROLLBACK_ATOMIC_WRITE = int(33) + SQLITE_FCNTL_LOCK_TIMEOUT = int(34) + SQLITE_FCNTL_DATA_VERSION = int(35) + SQLITE_FCNTL_SIZE_LIMIT = int(36) + SQLITE_FCNTL_CKPT_DONE = int(37) + SQLITE_FCNTL_RESERVE_BYTES = int(38) + SQLITE_FCNTL_CKPT_START = int(39) + SQLITE_FCNTL_EXTERNAL_READER = int(40) + SQLITE_FCNTL_CKSM_FILE = int(41) +) + +// SQLiteDriver implements driver.Driver. +type SQLiteDriver struct { + ConnectHook func(*SQLiteConn) error +} + +// SQLiteConn implements driver.Conn. +type SQLiteConn struct { + mu sync.Mutex + db *C.sqlite3 + loc *time.Location + funcs []*functionInfo + aggregators []*aggInfo +} + +// SQLiteTx implements driver.Tx. +type SQLiteTx struct { + c *SQLiteConn +} + +// SQLiteStmt implements driver.Stmt. +type SQLiteStmt struct { + mu sync.Mutex + c *SQLiteConn + s *C.sqlite3_stmt + t string + closed bool + cls bool +} + +// SQLiteResult implements sql.Result. +type SQLiteResult struct { + id int64 + changes int64 +} + +// SQLiteRows implements driver.Rows. +type SQLiteRows struct { + s *SQLiteStmt + nc int + cols []string + decltype []string + cls bool + closed bool + ctx context.Context // no better alternative to pass context into Next() method +} + +type functionInfo struct { + f reflect.Value + argConverters []callbackArgConverter + variadicConverter callbackArgConverter + retConverter callbackRetConverter +} + +func (fi *functionInfo) Call(ctx *C.sqlite3_context, argv []*C.sqlite3_value) { + args, err := callbackConvertArgs(argv, fi.argConverters, fi.variadicConverter) + if err != nil { + callbackError(ctx, err) + return + } + + ret := fi.f.Call(args) + + if len(ret) == 2 && ret[1].Interface() != nil { + callbackError(ctx, ret[1].Interface().(error)) + return + } + + err = fi.retConverter(ctx, ret[0]) + if err != nil { + callbackError(ctx, err) + return + } +} + +type aggInfo struct { + constructor reflect.Value + + // Active aggregator objects for aggregations in flight. The + // aggregators are indexed by a counter stored in the aggregation + // user data space provided by sqlite. + active map[int64]reflect.Value + next int64 + + stepArgConverters []callbackArgConverter + stepVariadicConverter callbackArgConverter + + doneRetConverter callbackRetConverter +} + +func (ai *aggInfo) agg(ctx *C.sqlite3_context) (int64, reflect.Value, error) { + aggIdx := (*int64)(C.sqlite3_aggregate_context(ctx, C.int(8))) + if *aggIdx == 0 { + *aggIdx = ai.next + ret := ai.constructor.Call(nil) + if len(ret) == 2 && ret[1].Interface() != nil { + return 0, reflect.Value{}, ret[1].Interface().(error) + } + if ret[0].IsNil() { + return 0, reflect.Value{}, errors.New("aggregator constructor returned nil state") + } + ai.next++ + ai.active[*aggIdx] = ret[0] + } + return *aggIdx, ai.active[*aggIdx], nil +} + +func (ai *aggInfo) Step(ctx *C.sqlite3_context, argv []*C.sqlite3_value) { + _, agg, err := ai.agg(ctx) + if err != nil { + callbackError(ctx, err) + return + } + + args, err := callbackConvertArgs(argv, ai.stepArgConverters, ai.stepVariadicConverter) + if err != nil { + callbackError(ctx, err) + return + } + + ret := agg.MethodByName("Step").Call(args) + if len(ret) == 1 && ret[0].Interface() != nil { + callbackError(ctx, ret[0].Interface().(error)) + return + } +} + +func (ai *aggInfo) Done(ctx *C.sqlite3_context) { + idx, agg, err := ai.agg(ctx) + if err != nil { + callbackError(ctx, err) + return + } + defer func() { delete(ai.active, idx) }() + + ret := agg.MethodByName("Done").Call(nil) + if len(ret) == 2 && ret[1].Interface() != nil { + callbackError(ctx, ret[1].Interface().(error)) + return + } + + err = ai.doneRetConverter(ctx, ret[0]) + if err != nil { + callbackError(ctx, err) + return + } +} + +// Commit transaction. +func (tx *SQLiteTx) Commit() error { + _, err := tx.c.exec(context.Background(), "COMMIT", nil) + if err != nil { + // sqlite3 may leave the transaction open in this scenario. + // However, database/sql considers the transaction complete once we + // return from Commit() - we must clean up to honour its semantics. + // We don't know if the ROLLBACK is strictly necessary, but according + // to sqlite's docs, there is no harm in calling ROLLBACK unnecessarily. + tx.c.exec(context.Background(), "ROLLBACK", nil) + } + return err +} + +// Rollback transaction. +func (tx *SQLiteTx) Rollback() error { + _, err := tx.c.exec(context.Background(), "ROLLBACK", nil) + return err +} + +// RegisterCollation makes a Go function available as a collation. +// +// cmp receives two UTF-8 strings, a and b. The result should be 0 if +// a==b, -1 if a < b, and +1 if a > b. +// +// cmp must always return the same result given the same +// inputs. Additionally, it must have the following properties for all +// strings A, B and C: if A==B then B==A; if A==B and B==C then A==C; +// if AA; if A 0 { + stmtArgs = append(stmtArgs, args[start:start+na]...) + for i := range args { + if (i < start || i >= na) && args[i].Name != "" { + stmtArgs = append(stmtArgs, args[i]) + } + } + for i := range stmtArgs { + stmtArgs[i].Ordinal = i + 1 + } + } + res, err = s.(*SQLiteStmt).exec(ctx, stmtArgs) + if err != nil && err != driver.ErrSkip { + s.Close() + return nil, err + } + start += na + } + tail := s.(*SQLiteStmt).t + s.Close() + if tail == "" { + if res == nil { + // https://github.com/mattn/go-sqlite3/issues/963 + res = &SQLiteResult{0, 0} + } + return res, nil + } + query = tail + } +} + +// Query implements Queryer. +func (c *SQLiteConn) Query(query string, args []driver.Value) (driver.Rows, error) { + list := make([]driver.NamedValue, len(args)) + for i, v := range args { + list[i] = driver.NamedValue{ + Ordinal: i + 1, + Value: v, + } + } + return c.query(context.Background(), query, list) +} + +func (c *SQLiteConn) query(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { + start := 0 + for { + stmtArgs := make([]driver.NamedValue, 0, len(args)) + s, err := c.prepare(ctx, query) + if err != nil { + return nil, err + } + s.(*SQLiteStmt).cls = true + na := s.NumInput() + if len(args)-start < na { + return nil, fmt.Errorf("not enough args to execute query: want %d got %d", na, len(args)-start) + } + // consume the number of arguments used in the current + // statement and append all named arguments not contained + // therein + stmtArgs = append(stmtArgs, args[start:start+na]...) + for i := range args { + if (i < start || i >= na) && args[i].Name != "" { + stmtArgs = append(stmtArgs, args[i]) + } + } + for i := range stmtArgs { + stmtArgs[i].Ordinal = i + 1 + } + rows, err := s.(*SQLiteStmt).query(ctx, stmtArgs) + if err != nil && err != driver.ErrSkip { + s.Close() + return rows, err + } + start += na + tail := s.(*SQLiteStmt).t + if tail == "" { + return rows, nil + } + rows.Close() + s.Close() + query = tail + } +} + +// Begin transaction. +func (c *SQLiteConn) Begin() (driver.Tx, error) { + return c.begin(context.Background()) +} + +func (c *SQLiteConn) begin(ctx context.Context) (driver.Tx, error) { + if _, err := c.exec(ctx, "BEGIN", nil); err != nil { + return nil, err + } + return &SQLiteTx{c}, nil +} + +// Open database and return a new connection. +// +// A pragma can take either zero or one argument. +// The argument is may be either in parentheses or it may be separated from +// the pragma name by an equal sign. The two syntaxes yield identical results. +// In many pragmas, the argument is a boolean. The boolean can be one of: +// +// 1 yes true on +// 0 no false off +// +// You can specify a DSN string using a URI as the filename. +// +// test.db +// file:test.db?cache=shared&mode=memory +// :memory: +// file::memory: +// +// cache +// SQLite Shared-Cache Mode +// https://www.sqlite.org/sharedcache.html +// Values: +// - shared +// - private +// +// go-sqlite3 adds the following query parameters to those used by SQLite: +// +// _loc=XXX +// Specify location of time format. It's possible to specify "auto". +func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) { + // Options + var loc *time.Location + + pos := strings.IndexRune(dsn, '?') + if pos >= 1 { + params, err := url.ParseQuery(dsn[pos+1:]) + if err != nil { + return nil, err + } + + // _loc + if val := params.Get("_loc"); val != "" { + switch strings.ToLower(val) { + case "auto": + loc = time.Local + default: + loc, err = time.LoadLocation(val) + if err != nil { + return nil, fmt.Errorf("Invalid _loc: %v: %v", val, err) + } + } + } + + if !strings.HasPrefix(dsn, "file:") { + dsn = dsn[:pos] + } + } + + var db *C.sqlite3 + name := C.CString(dsn) + defer C.free(unsafe.Pointer(name)) + + var openFlags C.int = + C.SQLITE_OPEN_READWRITE | // FIXME: fails if RO FS? + C.SQLITE_OPEN_CREATE | + C.SQLITE_OPEN_URI | + C.SQLITE_OPEN_FULLMUTEX + rv := C.sqlite3_open_v2(name, &db, openFlags, nil) + if rv != 0 { + // Save off the error _before_ closing the database. + // This is safe even if db is nil. + err := lastError(db) + if db != nil { + C.sqlite3_close_v2(db) + } + return nil, err + } + if db == nil { + return nil, errors.New("sqlite succeeded without returning a database") + } + + const setup = ` + PRAGMA journal_mode = WAL; + PRAGMA busy_timeout = 10000; + ` + setupCStr := C.CString(setup) + defer C.free(unsafe.Pointer(setupCStr)) + rv = C.sqlite3_exec(db, setupCStr, nil, nil, nil) + if rv != C.SQLITE_OK { + err := lastError(db) + C.sqlite3_close_v2(db) + return nil, err + } + + // Create connection to SQLite + conn := &SQLiteConn{db: db, loc: loc} + + if d.ConnectHook != nil { + if err := d.ConnectHook(conn); err != nil { + conn.Close() + return nil, err + } + } + + runtime.SetFinalizer(conn, (*SQLiteConn).Close) + return conn, nil +} + +// Close the connection. +func (c *SQLiteConn) Close() error { + rv := C.sqlite3_close_v2(c.db) + if rv != C.SQLITE_OK { + return c.lastError() + } + deleteHandles(c) + c.mu.Lock() + c.db = nil + c.mu.Unlock() + runtime.SetFinalizer(c, nil) + return nil +} + +func (c *SQLiteConn) dbConnOpen() bool { + if c == nil { + return false + } + c.mu.Lock() + defer c.mu.Unlock() + return c.db != nil +} + +// Prepare the query string. Return a new statement. +func (c *SQLiteConn) Prepare(query string) (driver.Stmt, error) { + return c.prepare(context.Background(), query) +} + +func (c *SQLiteConn) prepare(ctx context.Context, query string) (driver.Stmt, error) { + pquery := C.CString(query) + defer C.free(unsafe.Pointer(pquery)) + var s *C.sqlite3_stmt + var tail *C.char + rv := C.sqlite3_prepare_v2(c.db, pquery, C.int(-1), &s, &tail) + if rv != C.SQLITE_OK { + return nil, c.lastError() + } + var t string + if tail != nil && *tail != '\000' { + t = strings.TrimSpace(C.GoString(tail)) + } + ss := &SQLiteStmt{c: c, s: s, t: t} + runtime.SetFinalizer(ss, (*SQLiteStmt).Close) + return ss, nil +} + +// Run-Time Limit Categories. +// See: http://www.sqlite.org/c3ref/c_limit_attached.html +const ( + SQLITE_LIMIT_LENGTH = C.SQLITE_LIMIT_LENGTH + SQLITE_LIMIT_SQL_LENGTH = C.SQLITE_LIMIT_SQL_LENGTH + SQLITE_LIMIT_COLUMN = C.SQLITE_LIMIT_COLUMN + SQLITE_LIMIT_EXPR_DEPTH = C.SQLITE_LIMIT_EXPR_DEPTH + SQLITE_LIMIT_COMPOUND_SELECT = C.SQLITE_LIMIT_COMPOUND_SELECT + SQLITE_LIMIT_VDBE_OP = C.SQLITE_LIMIT_VDBE_OP + SQLITE_LIMIT_FUNCTION_ARG = C.SQLITE_LIMIT_FUNCTION_ARG + SQLITE_LIMIT_ATTACHED = C.SQLITE_LIMIT_ATTACHED + SQLITE_LIMIT_LIKE_PATTERN_LENGTH = C.SQLITE_LIMIT_LIKE_PATTERN_LENGTH + SQLITE_LIMIT_VARIABLE_NUMBER = C.SQLITE_LIMIT_VARIABLE_NUMBER + SQLITE_LIMIT_TRIGGER_DEPTH = C.SQLITE_LIMIT_TRIGGER_DEPTH + SQLITE_LIMIT_WORKER_THREADS = C.SQLITE_LIMIT_WORKER_THREADS +) + +// GetFilename returns the absolute path to the file containing +// the requested schema. When passed an empty string, it will +// instead use the database's default schema: "main". +// See: sqlite3_db_filename, https://www.sqlite.org/c3ref/db_filename.html +func (c *SQLiteConn) GetFilename(schemaName string) string { + if schemaName == "" { + schemaName = "main" + } + return C.GoString(C.sqlite3_db_filename(c.db, C.CString(schemaName))) +} + +func (c *SQLiteConn) GetLimit(id int) int { + return int(C.sqlite3_limit(c.db, C.int(id), C.int(-1))) +} + +func (c *SQLiteConn) SetLimit(id int, newVal int) int { + return int(C.sqlite3_limit(c.db, C.int(id), C.int(newVal))) +} + +// SetFileControlInt invokes the xFileControl method on a given database. The +// dbName is the name of the database. It will default to "main" if left blank. +// The op is one of the opcodes prefixed by "SQLITE_FCNTL_". The arg argument +// and return code are both opcode-specific. Please see the SQLite documentation. +// +// This method is not thread-safe as the returned error code can be changed by +// another call if invoked concurrently. +// +// See: sqlite3_file_control, https://www.sqlite.org/c3ref/file_control.html +func (c *SQLiteConn) SetFileControlInt(dbName string, op int, arg int) error { + if dbName == "" { + dbName = "main" + } + + cDBName := C.CString(dbName) + defer C.free(unsafe.Pointer(cDBName)) + + cArg := C.int(arg) + rv := C.sqlite3_file_control(c.db, cDBName, C.int(op), unsafe.Pointer(&cArg)) + if rv != C.SQLITE_OK { + return c.lastError() + } + return nil +} + +func (s *SQLiteStmt) Close() error { + s.mu.Lock() + defer s.mu.Unlock() + if s.closed { + return nil + } + s.closed = true + if !s.c.dbConnOpen() { + return errors.New("sqlite statement with already closed database connection") + } + rv := C.sqlite3_finalize(s.s) + s.s = nil + if rv != C.SQLITE_OK { + return s.c.lastError() + } + s.c = nil + runtime.SetFinalizer(s, nil) + return nil +} + +func (s *SQLiteStmt) NumInput() int { + return int(C.sqlite3_bind_parameter_count(s.s)) +} + +var placeHolder = []byte{0} + +func (s *SQLiteStmt) bind(args []driver.NamedValue) error { + rv := C.sqlite3_reset(s.s) + if rv != C.SQLITE_ROW && rv != C.SQLITE_OK && rv != C.SQLITE_DONE { + return s.c.lastError() + } + + bindIndices := make([][3]int, len(args)) + prefixes := []string{":", "@", "$"} + for i, v := range args { + bindIndices[i][0] = args[i].Ordinal + if v.Name != "" { + for j := range prefixes { + cname := C.CString(prefixes[j] + v.Name) + bindIndices[i][j] = int(C.sqlite3_bind_parameter_index(s.s, cname)) + C.free(unsafe.Pointer(cname)) + } + args[i].Ordinal = bindIndices[i][0] + } + } + + for i, arg := range args { + for j := range bindIndices[i] { + if bindIndices[i][j] == 0 { + continue + } + n := C.int(bindIndices[i][j]) + switch v := arg.Value.(type) { + case nil: + rv = C.sqlite3_bind_null(s.s, n) + case string: + if len(v) == 0 { + rv = C.sqlite3_bind_text( + s.s, + n, + (*C.char)(unsafe.Pointer(&placeHolder[0])), + C.int(0), + C.SQLITE_TRANSIENT, + ) + } else { + b := []byte(v) + rv = C.sqlite3_bind_text( + s.s, + n, + (*C.char)(unsafe.Pointer(&b[0])), + C.int(len(b)), + C.SQLITE_TRANSIENT, + ) + } + case int64: + rv = C.sqlite3_bind_int64(s.s, n, C.sqlite3_int64(v)) + case bool: + if v { + rv = C.sqlite3_bind_int(s.s, n, 1) + } else { + rv = C.sqlite3_bind_int(s.s, n, 0) + } + case float64: + rv = C.sqlite3_bind_double(s.s, n, C.double(v)) + case []byte: + if v == nil { + rv = C.sqlite3_bind_null(s.s, n) + } else { + ln := len(v) + if ln == 0 { + v = placeHolder + } + rv = C.sqlite3_bind_blob( + s.s, + n, + unsafe.Pointer(&v[0]), + C.int(ln), + C.SQLITE_TRANSIENT, + ) + } + case time.Time: + b := []byte(v.Format(SQLiteTimestampFormats[0])) + rv = C.sqlite3_bind_text( + s.s, + n, + (*C.char)(unsafe.Pointer(&b[0])), + C.int(len(b)), + C.SQLITE_TRANSIENT, + ) + } + if rv != C.SQLITE_OK { + return s.c.lastError() + } + } + } + return nil +} + +// Query the statement with arguments. Return records. +func (s *SQLiteStmt) Query(args []driver.Value) (driver.Rows, error) { + list := make([]driver.NamedValue, len(args)) + for i, v := range args { + list[i] = driver.NamedValue{ + Ordinal: i + 1, + Value: v, + } + } + return s.query(context.Background(), list) +} + +func (s *SQLiteStmt) query(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { + if err := s.bind(args); err != nil { + return nil, err + } + + rows := &SQLiteRows{ + s: s, + nc: int(C.sqlite3_column_count(s.s)), + cols: nil, + decltype: nil, + cls: s.cls, + closed: false, + ctx: ctx, + } + runtime.SetFinalizer(rows, (*SQLiteRows).Close) + + return rows, nil +} + +// LastInsertId return last inserted ID. +func (r *SQLiteResult) LastInsertId() (int64, error) { + return r.id, nil +} + +// RowsAffected return how many rows affected. +func (r *SQLiteResult) RowsAffected() (int64, error) { + return r.changes, nil +} + +// Exec execute the statement with arguments. Return result object. +func (s *SQLiteStmt) Exec(args []driver.Value) (driver.Result, error) { + list := make([]driver.NamedValue, len(args)) + for i, v := range args { + list[i] = driver.NamedValue{ + Ordinal: i + 1, + Value: v, + } + } + return s.exec(context.Background(), list) +} + +func isInterruptErr(err error) bool { + sqliteErr, ok := err.(Error) + if ok { + return sqliteErr.Code == ErrInterrupt + } + return false +} + +// exec executes a query that doesn't return rows. Attempts to honor context timeout. +func (s *SQLiteStmt) exec(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { + if ctx.Done() == nil { + return s.execSync(args) + } + + type result struct { + r driver.Result + err error + } + resultCh := make(chan result) + defer close(resultCh) + go func() { + r, err := s.execSync(args) + resultCh <- result{r, err} + }() + var rv result + select { + case rv = <-resultCh: + case <-ctx.Done(): + select { + case rv = <-resultCh: // no need to interrupt, operation completed in db + default: + // this is still racy and can be no-op if executed between sqlite3_* calls in execSync. + C.sqlite3_interrupt(s.c.db) + rv = <-resultCh // wait for goroutine completed + if isInterruptErr(rv.err) { + return nil, ctx.Err() + } + } + } + return rv.r, rv.err +} + +func (s *SQLiteStmt) execSync(args []driver.NamedValue) (driver.Result, error) { + if err := s.bind(args); err != nil { + C.sqlite3_reset(s.s) + C.sqlite3_clear_bindings(s.s) + return nil, err + } + + rv := C.sqlite3_step(s.s) + if rv != C.SQLITE_ROW && rv != C.SQLITE_OK && rv != C.SQLITE_DONE { + err := s.c.lastError() + C.sqlite3_reset(s.s) + C.sqlite3_clear_bindings(s.s) + return nil, err + } + + db := C.sqlite3_db_handle(s.s) + id := int64(C.sqlite3_last_insert_rowid(db)) + changes := int64(C.sqlite3_changes(db)) + + return &SQLiteResult{ + id: id, + changes: changes, + }, nil +} + +// Readonly reports if this statement is considered readonly by SQLite. +// +// See: https://sqlite.org/c3ref/stmt_readonly.html +func (s *SQLiteStmt) Readonly() bool { + return C.sqlite3_stmt_readonly(s.s) == 1 +} + +func (rc *SQLiteRows) Close() error { + rc.s.mu.Lock() + if rc.s.closed || rc.closed { + rc.s.mu.Unlock() + return nil + } + rc.closed = true + if rc.cls { + rc.s.mu.Unlock() + return rc.s.Close() + } + rv := C.sqlite3_reset(rc.s.s) + if rv != C.SQLITE_OK { + rc.s.mu.Unlock() + return rc.s.c.lastError() + } + rc.s.mu.Unlock() + rc.s = nil + runtime.SetFinalizer(rc, nil) + return nil +} + +func (rc *SQLiteRows) Columns() []string { + rc.s.mu.Lock() + defer rc.s.mu.Unlock() + if rc.s.s != nil && rc.nc != len(rc.cols) { + rc.cols = make([]string, rc.nc) + for i := 0; i < rc.nc; i++ { + rc.cols[i] = C.GoString(C.sqlite3_column_name(rc.s.s, C.int(i))) + } + } + return rc.cols +} + +func (rc *SQLiteRows) declTypes() []string { + if rc.s.s != nil && rc.decltype == nil { + rc.decltype = make([]string, rc.nc) + for i := 0; i < rc.nc; i++ { + rc.decltype[i] = strings.ToLower(C.GoString(C.sqlite3_column_decltype(rc.s.s, C.int(i)))) + } + } + return rc.decltype +} + +// DeclTypes return column types. +func (rc *SQLiteRows) DeclTypes() []string { + rc.s.mu.Lock() + defer rc.s.mu.Unlock() + return rc.declTypes() +} + +// Next move cursor to next. Attempts to honor context timeout from QueryContext call. +func (rc *SQLiteRows) Next(dest []driver.Value) error { + rc.s.mu.Lock() + defer rc.s.mu.Unlock() + + if rc.s.closed { + return io.EOF + } + + if rc.ctx.Done() == nil { + return rc.nextSyncLocked(dest) + } + resultCh := make(chan error) + defer close(resultCh) + go func() { + resultCh <- rc.nextSyncLocked(dest) + }() + select { + case err := <-resultCh: + return err + case <-rc.ctx.Done(): + select { + case <-resultCh: // no need to interrupt + default: + // this is still racy and can be no-op if executed between sqlite3_* calls in nextSyncLocked. + C.sqlite3_interrupt(rc.s.c.db) + <-resultCh // ensure goroutine completed + } + return rc.ctx.Err() + } +} + +// nextSyncLocked moves cursor to next; must be called with locked mutex. +func (rc *SQLiteRows) nextSyncLocked(dest []driver.Value) error { + rv := C.sqlite3_step(rc.s.s) + if rv == C.SQLITE_DONE { + return io.EOF + } + if rv != C.SQLITE_ROW { + rv = C.sqlite3_reset(rc.s.s) + if rv != C.SQLITE_OK { + return rc.s.c.lastError() + } + return nil + } + + rc.declTypes() + + for i := range dest { + switch C.sqlite3_column_type(rc.s.s, C.int(i)) { + case C.SQLITE_INTEGER: + val := int64(C.sqlite3_column_int64(rc.s.s, C.int(i))) + switch rc.decltype[i] { + case columnTimestamp, columnDatetime, columnDate: + var t time.Time + // Assume a millisecond unix timestamp if it's 13 digits -- too + // large to be a reasonable timestamp in seconds. + if val > 1e12 || val < -1e12 { + val *= int64(time.Millisecond) // convert ms to nsec + t = time.Unix(0, val) + } else { + t = time.Unix(val, 0) + } + t = t.UTC() + if rc.s.c.loc != nil { + t = t.In(rc.s.c.loc) + } + dest[i] = t + case "boolean": + dest[i] = val > 0 + default: + dest[i] = val + } + case C.SQLITE_FLOAT: + dest[i] = float64(C.sqlite3_column_double(rc.s.s, C.int(i))) + case C.SQLITE_BLOB: + p := C.sqlite3_column_blob(rc.s.s, C.int(i)) + if p == nil { + dest[i] = []byte{} + continue + } + n := C.sqlite3_column_bytes(rc.s.s, C.int(i)) + dest[i] = C.GoBytes(p, n) + case C.SQLITE_NULL: + dest[i] = nil + case C.SQLITE_TEXT: + var err error + var timeVal time.Time + + n := int(C.sqlite3_column_bytes(rc.s.s, C.int(i))) + s := C.GoStringN((*C.char)(unsafe.Pointer(C.sqlite3_column_text(rc.s.s, C.int(i)))), C.int(n)) + + switch rc.decltype[i] { + case columnTimestamp, columnDatetime, columnDate: + var t time.Time + s = strings.TrimSuffix(s, "Z") + for _, format := range SQLiteTimestampFormats { + if timeVal, err = time.ParseInLocation(format, s, time.UTC); err == nil { + t = timeVal + break + } + } + if err != nil { + // The column is a time value, so return the zero time on parse failure. + t = time.Time{} + } + if rc.s.c.loc != nil { + t = t.In(rc.s.c.loc) + } + dest[i] = t + default: + dest[i] = s + } + } + } + return nil +} + +const i64 = unsafe.Sizeof(int(0)) > 4 + +// SQLiteContext behave sqlite3_context +type SQLiteContext C.sqlite3_context + +// ResultBool sets the result of an SQL function. +func (c *SQLiteContext) ResultBool(b bool) { + if b { + c.ResultInt(1) + } else { + c.ResultInt(0) + } +} + +// ResultBlob sets the result of an SQL function. +// See: sqlite3_result_blob, http://sqlite.org/c3ref/result_blob.html +func (c *SQLiteContext) ResultBlob(b []byte) { + if i64 && len(b) > math.MaxInt32 { + C.sqlite3_result_error_toobig((*C.sqlite3_context)(c)) + return + } + var p *byte + if len(b) > 0 { + p = &b[0] + } + C.sqlite3_result_blob( + (*C.sqlite3_context)(c), + unsafe.Pointer(p), + C.int(len(b)), + C.SQLITE_TRANSIENT, + ) +} + +// ResultDouble sets the result of an SQL function. +// See: sqlite3_result_double, http://sqlite.org/c3ref/result_blob.html +func (c *SQLiteContext) ResultDouble(d float64) { + C.sqlite3_result_double((*C.sqlite3_context)(c), C.double(d)) +} + +// ResultInt sets the result of an SQL function. +// See: sqlite3_result_int, http://sqlite.org/c3ref/result_blob.html +func (c *SQLiteContext) ResultInt(i int) { + if i64 && (i > math.MaxInt32 || i < math.MinInt32) { + C.sqlite3_result_int64((*C.sqlite3_context)(c), C.sqlite3_int64(i)) + } else { + C.sqlite3_result_int((*C.sqlite3_context)(c), C.int(i)) + } +} + +// ResultInt64 sets the result of an SQL function. +// See: sqlite3_result_int64, http://sqlite.org/c3ref/result_blob.html +func (c *SQLiteContext) ResultInt64(i int64) { + C.sqlite3_result_int64((*C.sqlite3_context)(c), C.sqlite3_int64(i)) +} + +// ResultNull sets the result of an SQL function. +// See: sqlite3_result_null, http://sqlite.org/c3ref/result_blob.html +func (c *SQLiteContext) ResultNull() { + C.sqlite3_result_null((*C.sqlite3_context)(c)) +} + +// ResultText sets the result of an SQL function. +// See: sqlite3_result_text, http://sqlite.org/c3ref/result_blob.html +func (c *SQLiteContext) ResultText(s string) { + h := (*reflect.StringHeader)(unsafe.Pointer(&s)) + cs, l := (*C.char)(unsafe.Pointer(h.Data)), C.int(h.Len) + C.sqlite3_result_text( + (*C.sqlite3_context)(c), + cs, + l, + C.SQLITE_TRANSIENT, + ) +} + +// ResultZeroblob sets the result of an SQL function. +// See: sqlite3_result_zeroblob, http://sqlite.org/c3ref/result_blob.html +func (c *SQLiteContext) ResultZeroblob(n int) { + C.sqlite3_result_zeroblob((*C.sqlite3_context)(c), C.int(n)) +} + +// Ping implement Pinger. +func (c *SQLiteConn) Ping(ctx context.Context) error { + if c.db == nil { + // must be ErrBadConn for sql to close the database + return driver.ErrBadConn + } + return nil +} + +// QueryContext implement QueryerContext. +func (c *SQLiteConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { + return c.query(ctx, query, args) +} + +// ExecContext implement ExecerContext. +func (c *SQLiteConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { + return c.exec(ctx, query, args) +} + +// PrepareContext implement ConnPrepareContext. +func (c *SQLiteConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { + return c.prepare(ctx, query) +} + +// BeginTx implement ConnBeginTx. +func (c *SQLiteConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { + return c.begin(ctx) +} + +// QueryContext implement QueryerContext. +func (s *SQLiteStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { + return s.query(ctx, args) +} + +// ExecContext implement ExecerContext. +func (s *SQLiteStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { + return s.exec(ctx, args) +} + +// ColumnTableName returns the table that is the origin of a particular result +// column in a SELECT statement. +// +// See https://www.sqlite.org/c3ref/column_database_name.html +func (s *SQLiteStmt) ColumnTableName(n int) string { + return C.GoString(C.sqlite3_column_table_name(s.s, C.int(n))) +} + +// Serialize returns a byte slice that is a serialization of the database. +// +// See https://www.sqlite.org/c3ref/serialize.html +func (c *SQLiteConn) Serialize(schema string) ([]byte, error) { + if schema == "" { + schema = "main" + } + var zSchema *C.char + zSchema = C.CString(schema) + defer C.free(unsafe.Pointer(zSchema)) + + var sz C.sqlite3_int64 + ptr := C.sqlite3_serialize(c.db, zSchema, &sz, 0) + if ptr == nil { + return nil, fmt.Errorf("serialize failed") + } + defer C.sqlite3_free(unsafe.Pointer(ptr)) + + if sz > C.sqlite3_int64(math.MaxInt) { + return nil, fmt.Errorf("serialized database is too large (%d bytes)", sz) + } + + cBuf := *(*[]byte)(unsafe.Pointer(&reflect.SliceHeader{ + Data: uintptr(unsafe.Pointer(ptr)), + Len: int(sz), + Cap: int(sz), + })) + + res := make([]byte, int(sz)) + copy(res, cBuf) + return res, nil +} + +// Deserialize causes the connection to disconnect from the current database and +// then re-open as an in-memory database based on the contents of the byte slice. +// +// See https://www.sqlite.org/c3ref/deserialize.html +func (c *SQLiteConn) Deserialize(b []byte, schema string) error { + if schema == "" { + schema = "main" + } + var zSchema *C.char + zSchema = C.CString(schema) + defer C.free(unsafe.Pointer(zSchema)) + + tmpBuf := (*C.uchar)(C.sqlite3_malloc64(C.sqlite3_uint64(len(b)))) + cBuf := *(*[]byte)(unsafe.Pointer(&reflect.SliceHeader{ + Data: uintptr(unsafe.Pointer(tmpBuf)), + Len: len(b), + Cap: len(b), + })) + copy(cBuf, b) + + rc := C.sqlite3_deserialize(c.db, zSchema, tmpBuf, C.sqlite3_int64(len(b)), + C.sqlite3_int64(len(b)), C.SQLITE_DESERIALIZE_FREEONCLOSE) + if rc != C.SQLITE_OK { + return fmt.Errorf("deserialize failed with return %v", rc) + } + return nil +} + +// Op is type of operations. +type Op uint8 + +// Op mean identity of operations. +const ( + OpEQ Op = 2 + OpGT = 4 + OpLE = 8 + OpLT = 16 + OpGE = 32 + OpMATCH = 64 + OpLIKE = 65 /* 3.10.0 and later only */ + OpGLOB = 66 /* 3.10.0 and later only */ + OpREGEXP = 67 /* 3.10.0 and later only */ + OpScanUnique = 1 /* Scan visits at most 1 row */ +) + +// InfoConstraint give information of constraint. +type InfoConstraint struct { + Column int + Op Op + Usable bool +} + +// InfoOrderBy give information of order-by. +type InfoOrderBy struct { + Column int + Desc bool +} + +func constraints(info *C.sqlite3_index_info) []InfoConstraint { + slice := *(*[]C.struct_sqlite3_index_constraint)(unsafe.Pointer(&reflect.SliceHeader{ + Data: uintptr(unsafe.Pointer(info.aConstraint)), + Len: int(info.nConstraint), + Cap: int(info.nConstraint), + })) + + cst := make([]InfoConstraint, 0, len(slice)) + for _, c := range slice { + var usable bool + if c.usable > 0 { + usable = true + } + cst = append(cst, InfoConstraint{ + Column: int(c.iColumn), + Op: Op(c.op), + Usable: usable, + }) + } + return cst +} + +func orderBys(info *C.sqlite3_index_info) []InfoOrderBy { + slice := *(*[]C.struct_sqlite3_index_orderby)(unsafe.Pointer(&reflect.SliceHeader{ + Data: uintptr(unsafe.Pointer(info.aOrderBy)), + Len: int(info.nOrderBy), + Cap: int(info.nOrderBy), + })) + + ob := make([]InfoOrderBy, 0, len(slice)) + for _, c := range slice { + var desc bool + if c.desc > 0 { + desc = true + } + ob = append(ob, InfoOrderBy{ + Column: int(c.iColumn), + Desc: desc, + }) + } + return ob +} + +// IndexResult is a Go struct representation of what eventually ends up in the +// output fields for `sqlite3_index_info` +// See: https://www.sqlite.org/c3ref/index_info.html +type IndexResult struct { + Used []bool // aConstraintUsage + IdxNum int + IdxStr string + AlreadyOrdered bool // orderByConsumed + EstimatedCost float64 + EstimatedRows float64 +} + +func fillDBError(dbErr *Error, db *C.sqlite3) { + // See SQLiteConn.lastError(), in file 'sqlite3.go' at the time of writing (Sept 5, 2016) + dbErr.Code = ErrNo(C.sqlite3_errcode(db)) + dbErr.ExtendedCode = ErrNoExtended(C.sqlite3_extended_errcode(db)) + dbErr.err = C.GoString(C.sqlite3_errmsg(db)) +} + +// ColumnTypeDatabaseTypeName implement RowsColumnTypeDatabaseTypeName. +func (rc *SQLiteRows) ColumnTypeDatabaseTypeName(i int) string { + return C.GoString(C.sqlite3_column_decltype(rc.s.s, C.int(i))) +} + +// ColumnTypeNullable implement RowsColumnTypeNullable. +func (rc *SQLiteRows) ColumnTypeNullable(i int) (nullable, ok bool) { + return true, true +} + +// ColumnTypeScanType implement RowsColumnTypeScanType. +func (rc *SQLiteRows) ColumnTypeScanType(i int) reflect.Type { + //ct := C.sqlite3_column_type(rc.s.s, C.int(i)) // Always returns 5 + return scanType(C.GoString(C.sqlite3_column_decltype(rc.s.s, C.int(i)))) +} + +const ( + SQLITE_INTEGER = iota + SQLITE_TEXT + SQLITE_BLOB + SQLITE_REAL + SQLITE_NUMERIC + SQLITE_TIME + SQLITE_BOOL + SQLITE_NULL +) + +func scanType(cdt string) reflect.Type { + t := strings.ToUpper(cdt) + i := databaseTypeConvSqlite(t) + switch i { + case SQLITE_INTEGER: + return reflect.TypeOf(sql.NullInt64{}) + case SQLITE_TEXT: + return reflect.TypeOf(sql.NullString{}) + case SQLITE_BLOB: + return reflect.TypeOf(sql.RawBytes{}) + case SQLITE_REAL: + return reflect.TypeOf(sql.NullFloat64{}) + case SQLITE_NUMERIC: + return reflect.TypeOf(sql.NullFloat64{}) + case SQLITE_BOOL: + return reflect.TypeOf(sql.NullBool{}) + case SQLITE_TIME: + return reflect.TypeOf(sql.NullTime{}) + } + return reflect.TypeOf(new(any)) +} + +func databaseTypeConvSqlite(t string) int { + if strings.Contains(t, "INT") { + return SQLITE_INTEGER + } + if t == "CLOB" || t == "TEXT" || + strings.Contains(t, "CHAR") { + return SQLITE_TEXT + } + if t == "BLOB" { + return SQLITE_BLOB + } + if t == "REAL" || t == "FLOAT" || + strings.Contains(t, "DOUBLE") { + return SQLITE_REAL + } + if t == "DATE" || t == "DATETIME" || + t == "TIMESTAMP" { + return SQLITE_TIME + } + if t == "NUMERIC" || + strings.Contains(t, "DECIMAL") { + return SQLITE_NUMERIC + } + if t == "BOOLEAN" { + return SQLITE_BOOL + } + + return SQLITE_NULL +} diff --git a/tests/acudego.go b/tests/acudego.go deleted file mode 100644 index de41ca9..0000000 --- a/tests/acudego.go +++ /dev/null @@ -1,3652 +0,0 @@ -package acudego - -import ( - "bytes" - "context" - "database/sql" - "database/sql/driver" - "errors" - "fmt" - "io/ioutil" - "math" - "math/rand" - "net/url" - "os" - "path" - "reflect" - "strings" - "sync" - "testing" - "testing/internal/testdeps" - "time" -) - - - -// The number of rows of test data to create in the source database. -// Can be used to control how many pages are available to be backed up. -const testRowCount = 100 - -// The maximum number of seconds after which the page-by-page backup is considered to have taken too long. -const usePagePerStepsTimeoutSeconds = 30 - -// Test the backup functionality. -func testBackup(t *testing.T, testRowCount int, usePerPageSteps bool) { - // This function will be called multiple times. - // It uses sql.Register(), which requires the name parameter value to be unique. - // There does not currently appear to be a way to unregister a registered driver, however. - // So generate a database driver name that will likely be unique. - var driverName = fmt.Sprintf("sqlite3_testBackup_%v_%v_%v", testRowCount, usePerPageSteps, time.Now().UnixNano()) - - // The driver's connection will be needed in order to perform the backup. - driverConns := []*SQLiteConn{} - sql.Register(driverName, &SQLiteDriver{ - ConnectHook: func(conn *SQLiteConn) error { - driverConns = append(driverConns, conn) - return nil - }, - }) - - // Connect to the source database. - srcTempFilename := "file:src?mode=memory&cache=shared" - srcDb, err := sql.Open(driverName, srcTempFilename) - if err != nil { - t.Fatal("Failed to open the source database:", err) - } - defer srcDb.Close() - err = srcDb.Ping() - if err != nil { - t.Fatal("Failed to connect to the source database:", err) - } - - // Connect to the destination database. - destTempFilename := "file:dst?mode=memory&cache=shared" - destDb, err := sql.Open(driverName, destTempFilename) - if err != nil { - t.Fatal("Failed to open the destination database:", err) - } - defer destDb.Close() - err = destDb.Ping() - if err != nil { - t.Fatal("Failed to connect to the destination database:", err) - } - - // Check the driver connections. - if len(driverConns) != 2 { - t.Fatalf("Expected 2 driver connections, but found %v.", len(driverConns)) - } - srcDbDriverConn := driverConns[0] - if srcDbDriverConn == nil { - t.Fatal("The source database driver connection is nil.") - } - destDbDriverConn := driverConns[1] - if destDbDriverConn == nil { - t.Fatal("The destination database driver connection is nil.") - } - - // Generate some test data for the given ID. - var generateTestData = func(id int) string { - return fmt.Sprintf("test-%v", id) - } - - // Populate the source database with a test table containing some test data. - tx, err := srcDb.Begin() - if err != nil { - t.Fatal("Failed to begin a transaction when populating the source database:", err) - } - _, err = srcDb.Exec("CREATE TABLE test (id INTEGER PRIMARY KEY, value TEXT)") - if err != nil { - tx.Rollback() - t.Fatal("Failed to create the source database \"test\" table:", err) - } - for id := 0; id < testRowCount; id++ { - _, err = srcDb.Exec("INSERT INTO test (id, value) VALUES (?, ?)", id, generateTestData(id)) - if err != nil { - tx.Rollback() - t.Fatal("Failed to insert a row into the source database \"test\" table:", err) - } - } - err = tx.Commit() - if err != nil { - t.Fatal("Failed to populate the source database:", err) - } - - // Confirm that the destination database is initially empty. - var destTableCount int - err = destDb.QueryRow("SELECT COUNT(*) FROM sqlite_master WHERE type = 'table'").Scan(&destTableCount) - if err != nil { - t.Fatal("Failed to check the destination table count:", err) - } - if destTableCount != 0 { - t.Fatalf("The destination database is not empty; %v table(s) found.", destTableCount) - } - - // Prepare to perform the backup. - backup, err := destDbDriverConn.Backup("main", srcDbDriverConn, "main") - if err != nil { - t.Fatal("Failed to initialize the backup:", err) - } - - // Allow the initial page count and remaining values to be retrieved. - // According to , the page count and remaining values are "... only updated by sqlite3_backup_step()." - isDone, err := backup.Step(0) - if err != nil { - t.Fatal("Unable to perform an initial 0-page backup step:", err) - } - if isDone { - t.Fatal("Backup is unexpectedly done.") - } - - // Check that the page count and remaining values are reasonable. - initialPageCount := backup.PageCount() - if initialPageCount <= 0 { - t.Fatalf("Unexpected initial page count value: %v", initialPageCount) - } - initialRemaining := backup.Remaining() - if initialRemaining <= 0 { - t.Fatalf("Unexpected initial remaining value: %v", initialRemaining) - } - if initialRemaining != initialPageCount { - t.Fatalf("Initial remaining value differs from the initial page count value; remaining: %v; page count: %v", initialRemaining, initialPageCount) - } - - // Perform the backup. - if usePerPageSteps { - var startTime = time.Now().Unix() - - // Test backing-up using a page-by-page approach. - var latestRemaining = initialRemaining - for { - // Perform the backup step. - isDone, err = backup.Step(1) - if err != nil { - t.Fatal("Failed to perform a backup step:", err) - } - - // The page count should remain unchanged from its initial value. - currentPageCount := backup.PageCount() - if currentPageCount != initialPageCount { - t.Fatalf("Current page count differs from the initial page count; initial page count: %v; current page count: %v", initialPageCount, currentPageCount) - } - - // There should now be one less page remaining. - currentRemaining := backup.Remaining() - expectedRemaining := latestRemaining - 1 - if currentRemaining != expectedRemaining { - t.Fatalf("Unexpected remaining value; expected remaining value: %v; actual remaining value: %v", expectedRemaining, currentRemaining) - } - latestRemaining = currentRemaining - - if isDone { - break - } - - // Limit the runtime of the backup attempt. - if (time.Now().Unix() - startTime) > usePagePerStepsTimeoutSeconds { - t.Fatal("Backup is taking longer than expected.") - } - } - } else { - // Test the copying of all remaining pages. - isDone, err = backup.Step(-1) - if err != nil { - t.Fatal("Failed to perform a backup step:", err) - } - if !isDone { - t.Fatal("Backup is unexpectedly not done.") - } - } - - // Check that the page count and remaining values are reasonable. - finalPageCount := backup.PageCount() - if finalPageCount != initialPageCount { - t.Fatalf("Final page count differs from the initial page count; initial page count: %v; final page count: %v", initialPageCount, finalPageCount) - } - finalRemaining := backup.Remaining() - if finalRemaining != 0 { - t.Fatalf("Unexpected remaining value: %v", finalRemaining) - } - - // Finish the backup. - err = backup.Finish() - if err != nil { - t.Fatal("Failed to finish backup:", err) - } - - // Confirm that the "test" table now exists in the destination database. - var doesTestTableExist bool - err = destDb.QueryRow("SELECT EXISTS (SELECT 1 FROM sqlite_master WHERE type = 'table' AND name = 'test' LIMIT 1) AS test_table_exists").Scan(&doesTestTableExist) - if err != nil { - t.Fatal("Failed to check if the \"test\" table exists in the destination database:", err) - } - if !doesTestTableExist { - t.Fatal("The \"test\" table could not be found in the destination database.") - } - - // Confirm that the number of rows in the destination database's "test" table matches that of the source table. - var actualTestTableRowCount int - err = destDb.QueryRow("SELECT COUNT(*) FROM test").Scan(&actualTestTableRowCount) - if err != nil { - t.Fatal("Failed to determine the rowcount of the \"test\" table in the destination database:", err) - } - if testRowCount != actualTestTableRowCount { - t.Fatalf("Unexpected destination \"test\" table row count; expected: %v; found: %v", testRowCount, actualTestTableRowCount) - } - - // Check each of the rows in the destination database. - for id := 0; id < testRowCount; id++ { - var checkedValue string - err = destDb.QueryRow("SELECT value FROM test WHERE id = ?", id).Scan(&checkedValue) - if err != nil { - t.Fatal("Failed to query the \"test\" table in the destination database:", err) - } - - var expectedValue = generateTestData(id) - if checkedValue != expectedValue { - t.Fatalf("Unexpected value in the \"test\" table in the destination database; expected value: %v; actual value: %v", expectedValue, checkedValue) - } - } -} - -func TestBackupStepByStep(t *testing.T) { - testBackup(t, testRowCount, true) -} - -func TestBackupAllRemainingPages(t *testing.T) { - testBackup(t, testRowCount, false) -} - -// Test the error reporting when preparing to perform a backup. -func TestBackupError(t *testing.T) { - const driverName = "sqlite3_TestBackupError" - - // The driver's connection will be needed in order to perform the backup. - var dbDriverConn *SQLiteConn - sql.Register(driverName, &SQLiteDriver{ - ConnectHook: func(conn *SQLiteConn) error { - dbDriverConn = conn - return nil - }, - }) - - // Connect to the database. - db, err := sql.Open(driverName, ":memory:") - if err != nil { - t.Fatal("Failed to open the database:", err) - } - defer db.Close() - db.Ping() - - // Need the driver connection in order to perform the backup. - if dbDriverConn == nil { - t.Fatal("Failed to get the driver connection.") - } - - // Prepare to perform the backup. - // Intentionally using the same connection for both the source and destination databases, to trigger an error result. - backup, err := dbDriverConn.Backup("main", dbDriverConn, "main") - if err == nil { - t.Fatal("Failed to get the expected error result.") - } - const expectedError = "source and destination must be distinct" - if err.Error() != expectedError { - t.Fatalf("Unexpected error message; expected value: \"%v\"; actual value: \"%v\"", expectedError, err.Error()) - } - if backup != nil { - t.Fatal("Failed to get the expected nil backup result.") - } -} - -func TestCallbackArgCast(t *testing.T) { - intConv := callbackSyntheticForTests(reflect.ValueOf(int64(math.MaxInt64)), nil) - floatConv := callbackSyntheticForTests(reflect.ValueOf(float64(math.MaxFloat64)), nil) - errConv := callbackSyntheticForTests(reflect.Value{}, errors.New("test")) - - tests := []struct { - f callbackArgConverter - o reflect.Value - }{ - {intConv, reflect.ValueOf(int8(-1))}, - {intConv, reflect.ValueOf(int16(-1))}, - {intConv, reflect.ValueOf(int32(-1))}, - {intConv, reflect.ValueOf(uint8(math.MaxUint8))}, - {intConv, reflect.ValueOf(uint16(math.MaxUint16))}, - {intConv, reflect.ValueOf(uint32(math.MaxUint32))}, - // Special case, int64->uint64 is only 1<<63 - 1, not 1<<64 - 1 - {intConv, reflect.ValueOf(uint64(math.MaxInt64))}, - {floatConv, reflect.ValueOf(float32(math.Inf(1)))}, - } - - for _, test := range tests { - conv := callbackArgCast{test.f, test.o.Type()} - val, err := conv.Run(nil) - if err != nil { - t.Errorf("Couldn't convert to %s: %s", test.o.Type(), err) - } else if !reflect.DeepEqual(val.Interface(), test.o.Interface()) { - t.Errorf("Unexpected result from converting to %s: got %v, want %v", test.o.Type(), val.Interface(), test.o.Interface()) - } - } - - conv := callbackArgCast{errConv, reflect.TypeOf(int8(0))} - _, err := conv.Run(nil) - if err == nil { - t.Errorf("Expected error during callbackArgCast, but got none") - } -} - -func TestCallbackConverters(t *testing.T) { - tests := []struct { - v any - err bool - }{ - // Unfortunately, we can't tell which converter was returned, - // but we can at least check which types can be converted. - {[]byte{0}, false}, - {"text", false}, - {true, false}, - {int8(0), false}, - {int16(0), false}, - {int32(0), false}, - {int64(0), false}, - {uint8(0), false}, - {uint16(0), false}, - {uint32(0), false}, - {uint64(0), false}, - {int(0), false}, - {uint(0), false}, - {float64(0), false}, - {float32(0), false}, - - {func() {}, true}, - {complex64(complex(0, 0)), true}, - {complex128(complex(0, 0)), true}, - {struct{}{}, true}, - {map[string]string{}, true}, - {[]string{}, true}, - {(*int8)(nil), true}, - {make(chan int), true}, - } - - for _, test := range tests { - _, err := callbackArg(reflect.TypeOf(test.v)) - if test.err && err == nil { - t.Errorf("Expected an error when converting %s, got no error", reflect.TypeOf(test.v)) - } else if !test.err && err != nil { - t.Errorf("Expected converter when converting %s, got error: %s", reflect.TypeOf(test.v), err) - } - } - - for _, test := range tests { - _, err := callbackRet(reflect.TypeOf(test.v)) - if test.err && err == nil { - t.Errorf("Expected an error when converting %s, got no error", reflect.TypeOf(test.v)) - } else if !test.err && err != nil { - t.Errorf("Expected converter when converting %s, got error: %s", reflect.TypeOf(test.v), err) - } - } -} - -func TestCallbackReturnAny(t *testing.T) { - udf := func() any { - return 1 - } - - typ := reflect.TypeOf(udf) - _, err := callbackRet(typ.Out(0)) - if err != nil { - t.Errorf("Expected valid callback for any return type, got: %s", err) - } -} - -func TestSimpleError(t *testing.T) { - e := ErrError.Error() - if e != "SQL logic error or missing database" && e != "SQL logic error" { - t.Error("wrong error code: " + e) - } -} - -func TestCorruptDbErrors(t *testing.T) { - dirName, err := ioutil.TempDir("", "FIXME") - if err != nil { - t.Fatal(err) - } - defer os.RemoveAll(dirName) - - dbFileName := path.Join(dirName, "test.db") - f, err := os.Create(dbFileName) - if err != nil { - t.Error(err) - } - f.Write([]byte{1, 2, 3, 4, 5}) - f.Close() - - db, err := sql.Open(DriverName, dbFileName) - if err == nil { - _, err = db.Exec("drop table foo") - } - - sqliteErr := err.(Error) - if sqliteErr.Code != ErrNotADB { - t.Error("wrong error code for corrupted DB") - } - if err.Error() == "" { - t.Error("wrong error string for corrupted DB") - } - db.Close() -} - -func TestSqlLogicErrors(t *testing.T) { - dirName, err := ioutil.TempDir("", "FIXME") - if err != nil { - t.Fatal(err) - } - defer os.RemoveAll(dirName) - - db, err := sql.Open(DriverName, ":memory:") - if err != nil { - t.Error(err) - } - defer db.Close() - - _, err = db.Exec("CREATE TABLE Foo (id INTEGER PRIMARY KEY)") - if err != nil { - t.Error(err) - } - - const expectedErr = "table Foo already exists" - _, err = db.Exec("CREATE TABLE Foo (id INTEGER PRIMARY KEY)") - if err.Error() != expectedErr { - t.Errorf("Unexpected error: %s, expected %s", err.Error(), expectedErr) - } - -} - -func TestExtendedErrorCodes_ForeignKey(t *testing.T) { - dirName, err := ioutil.TempDir("", "sqlite3-err") - if err != nil { - t.Fatal(err) - } - defer os.RemoveAll(dirName) - - db, err := sql.Open(DriverName, ":memory:") - if err != nil { - t.Error(err) - } - defer db.Close() - - _, err = db.Exec(`CREATE TABLE Foo ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - value INTEGER NOT NULL, - ref INTEGER NULL REFERENCES Foo (id), - UNIQUE(value) - );`) - if err != nil { - t.Error(err) - } - - _, err = db.Exec("INSERT INTO Foo (ref, value) VALUES (100, 100);") - if err == nil { - t.Error("No error!") - } else { - sqliteErr := err.(Error) - if sqliteErr.Code != ErrConstraint { - t.Errorf("Wrong basic error code: %d != %d", - sqliteErr.Code, ErrConstraint) - } - if sqliteErr.ExtendedCode != ErrConstraintForeignKey { - t.Errorf("Wrong extended error code: %d != %d", - sqliteErr.ExtendedCode, ErrConstraintForeignKey) - } - } - -} - -func TestExtendedErrorCodes_NotNull(t *testing.T) { - dirName, err := ioutil.TempDir("", "sqlite3-err") - if err != nil { - t.Fatal(err) - } - defer os.RemoveAll(dirName) - - db, err := sql.Open(DriverName, ":memory:") - if err != nil { - t.Error(err) - } - defer db.Close() - - _, err = db.Exec(`CREATE TABLE Foo ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - value INTEGER NOT NULL, - ref INTEGER NULL REFERENCES Foo (id), - UNIQUE(value) - );`) - if err != nil { - t.Error(err) - } - - res, err := db.Exec("INSERT INTO Foo (value) VALUES (100);") - if err != nil { - t.Fatalf("Creating first row: %v", err) - } - - id, err := res.LastInsertId() - if err != nil { - t.Fatalf("Retrieving last insert id: %v", err) - } - - _, err = db.Exec("INSERT INTO Foo (ref) VALUES (?);", id) - if err == nil { - t.Error("No error!") - } else { - sqliteErr := err.(Error) - if sqliteErr.Code != ErrConstraint { - t.Errorf("Wrong basic error code: %d != %d", - sqliteErr.Code, ErrConstraint) - } - if sqliteErr.ExtendedCode != ErrConstraintNotNull { - t.Errorf("Wrong extended error code: %d != %d", - sqliteErr.ExtendedCode, ErrConstraintNotNull) - } - } - -} - -func TestExtendedErrorCodes_Unique(t *testing.T) { - dirName, err := ioutil.TempDir("", "sqlite3-err") - if err != nil { - t.Fatal(err) - } - defer os.RemoveAll(dirName) - - db, err := sql.Open(DriverName, ":memory:") - if err != nil { - t.Error(err) - } - defer db.Close() - - _, err = db.Exec(`CREATE TABLE Foo ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - value INTEGER NOT NULL, - ref INTEGER NULL REFERENCES Foo (id), - UNIQUE(value) - );`) - if err != nil { - t.Error(err) - } - - res, err := db.Exec("INSERT INTO Foo (value) VALUES (100);") - if err != nil { - t.Fatalf("Creating first row: %v", err) - } - - id, err := res.LastInsertId() - if err != nil { - t.Fatalf("Retrieving last insert id: %v", err) - } - - _, err = db.Exec("INSERT INTO Foo (ref, value) VALUES (?, 100);", id) - if err == nil { - t.Error("No error!") - } else { - sqliteErr := err.(Error) - if sqliteErr.Code != ErrConstraint { - t.Errorf("Wrong basic error code: %d != %d", - sqliteErr.Code, ErrConstraint) - } - if sqliteErr.ExtendedCode != ErrConstraintUnique { - t.Errorf("Wrong extended error code: %d != %d", - sqliteErr.ExtendedCode, ErrConstraintUnique) - } - extended := sqliteErr.Code.Extend(3).Error() - expected := "constraint failed" - if extended != expected { - t.Errorf("Wrong basic error code: %q != %q", - extended, expected) - } - } -} - -func TestError_SystemErrno(t *testing.T) { - _, n, _ := LibVersion() - if n < 3012000 { - t.Skip("sqlite3_system_errno requires sqlite3 >= 3.12.0") - } - - // open a non-existent database in read-only mode so we get an IO error. - db, err := sql.Open(DriverName, "file:nonexistent.db?mode=ro") - if err != nil { - t.Fatal(err) - } - defer db.Close() - - err = db.Ping() - if err == nil { - t.Fatal("expected error pinging read-only non-existent database, but got nil") - } - - serr, ok := err.(Error) - if !ok { - t.Fatalf("expected error to be of type Error, but got %[1]T %[1]v", err) - } - - if serr.SystemErrno == 0 { - t.Fatal("expected SystemErrno to be set") - } - - if !os.IsNotExist(serr.SystemErrno) { - t.Errorf("expected SystemErrno to be a not exists error, but got %v", serr.SystemErrno) - } -} - -func TestBeginTxCancel(t *testing.T) { - db, err := sql.Open(DriverName, ":memory:") - if err != nil { - t.Fatal(err) - } - - db.SetMaxOpenConns(10) - db.SetMaxIdleConns(5) - - defer db.Close() - initDatabase(t, db, 100) - - // create several go-routines to expose racy issue - for i := 0; i < 1000; i++ { - func() { - ctx, cancel := context.WithCancel(context.Background()) - conn, err := db.Conn(ctx) - if err != nil { - t.Fatal(err) - } - defer func() { - if err := conn.Close(); err != nil { - t.Error(err) - } - }() - - err = conn.Raw(func(driverConn any) error { - d, ok := driverConn.(driver.ConnBeginTx) - if !ok { - t.Fatal("unexpected: wrong type") - } - // checks that conn.Raw can be used to get *SQLiteConn - if _, ok = driverConn.(*SQLiteConn); !ok { - t.Fatalf("conn.Raw() driverConn type=%T, expected *SQLiteConn", driverConn) - } - - go cancel() // make it cancel concurrently with exec("BEGIN"); - tx, err := d.BeginTx(ctx, driver.TxOptions{}) - switch err { - case nil: - switch err := tx.Rollback(); err { - case nil, sql.ErrTxDone: - default: - return err - } - case context.Canceled: - default: - // must not fail with "cannot start a transaction within a transaction" - return err - } - return nil - }) - if err != nil { - t.Fatal(err) - } - }() - } -} - -func TestStmtReadonly(t *testing.T) { - db, err := sql.Open(DriverName, ":memory:") - if err != nil { - t.Fatal(err) - } - defer db.Close() - - _, err = db.Exec("CREATE TABLE t (count INT)") - if err != nil { - t.Fatal(err) - } - - isRO := func(query string) bool { - c, err := db.Conn(context.Background()) - if err != nil { - return false - } - - var ro bool - c.Raw(func(dc any) error { - stmt, err := dc.(*SQLiteConn).Prepare(query) - if err != nil { - return err - } - if stmt == nil { - return errors.New("stmt is nil") - } - ro = stmt.(*SQLiteStmt).Readonly() - return nil - }) - return ro // On errors ro will remain false. - } - - if !isRO(`select * from t`) { - t.Error("select not seen as read-only") - } - if isRO(`insert into t values (1), (2)`) { - t.Error("insert seen as read-only") - } -} - -func TestNamedParams(t *testing.T) { - db, err := sql.Open(DriverName, ":memory:") - if err != nil { - t.Fatal("Failed to open database:", err) - } - defer db.Close() - - _, err = db.Exec(` - create table foo (id integer, name text, extra text); - `) - if err != nil { - t.Error("Failed to call db.Query:", err) - } - - _, err = db.Exec(`insert into foo(id, name, extra) values(:id, :name, :name)`, sql.Named("name", "foo"), sql.Named("id", 1)) - if err != nil { - t.Error("Failed to call db.Exec:", err) - } - - row := db.QueryRow(`select id, extra from foo where id = :id and extra = :extra`, sql.Named("id", 1), sql.Named("extra", "foo")) - if row == nil { - t.Error("Failed to call db.QueryRow") - } - var id int - var extra string - err = row.Scan(&id, &extra) - if err != nil { - t.Error("Failed to db.Scan:", err) - } - if id != 1 || extra != "foo" { - t.Error("Failed to db.QueryRow: not matched results") - } -} - -var ( - testTableStatements = []string{ - `DROP TABLE IF EXISTS test_table`, - ` -CREATE TABLE IF NOT EXISTS test_table ( - key1 VARCHAR(64) PRIMARY KEY, - key_id VARCHAR(64) NOT NULL, - key2 VARCHAR(64) NOT NULL, - key3 VARCHAR(64) NOT NULL, - key4 VARCHAR(64) NOT NULL, - key5 VARCHAR(64) NOT NULL, - key6 VARCHAR(64) NOT NULL, - data BLOB NOT NULL -);`, - } - letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" -) - -func randStringBytes(n int) string { - b := make([]byte, n) - for i := range b { - b[i] = letterBytes[rand.Intn(len(letterBytes))] - } - return string(b) -} - -func initDatabase(t *testing.T, db *sql.DB, rowCount int64) { - for _, query := range testTableStatements { - _, err := db.Exec(query) - if err != nil { - t.Fatal(err) - } - } - for i := int64(0); i < rowCount; i++ { - query := `INSERT INTO test_table - (key1, key_id, key2, key3, key4, key5, key6, data) - VALUES - (?, ?, ?, ?, ?, ?, ?, ?);` - args := []interface{}{ - randStringBytes(50), - fmt.Sprint(i), - randStringBytes(50), - randStringBytes(50), - randStringBytes(50), - randStringBytes(50), - randStringBytes(50), - randStringBytes(50), - randStringBytes(2048), - } - _, err := db.Exec(query, args...) - if err != nil { - t.Fatal(err) - } - } -} - -func TestShortTimeout(t *testing.T) { - db, err := sql.Open(DriverName, ":memory:") - if err != nil { - t.Fatal(err) - } - defer db.Close() - initDatabase(t, db, 100) - - ctx, cancel := context.WithTimeout(context.Background(), 1*time.Microsecond) - defer cancel() - query := `SELECT key1, key_id, key2, key3, key4, key5, key6, data - FROM test_table - ORDER BY key2 ASC` - _, err = db.QueryContext(ctx, query) - if err != nil && err != context.DeadlineExceeded { - t.Fatal(err) - } - if ctx.Err() != nil && ctx.Err() != context.DeadlineExceeded { - t.Fatal(ctx.Err()) - } -} - -func TestExecContextCancel(t *testing.T) { - db, err := sql.Open(DriverName, "file:exec?mode=memory&cache=shared") - if err != nil { - t.Fatal(err) - } - defer db.Close() - - ts := time.Now() - initDatabase(t, db, 1000) - spent := time.Since(ts) - const minTestTime = 100 * time.Millisecond - if spent < minTestTime && false { - t.Skipf("test will be too racy (spent=%s < min=%s) as ExecContext below will be too fast.", - spent.String(), minTestTime.String(), - ) - } - - // expected to be extremely slow query - q := ` -INSERT INTO test_table (key1, key_id, key2, key3, key4, key5, key6, data) -SELECT t1.key1 || t2.key1, t1.key_id || t2.key_id, t1.key2 || t2.key2, t1.key3 || t2.key3, t1.key4 || t2.key4, t1.key5 || t2.key5, t1.key6 || t2.key6, t1.data || t2.data -FROM test_table t1 LEFT OUTER JOIN test_table t2` - // expect query above take ~ same time as setup above - // This is racy: the context must be valid so sql/db.ExecContext calls the sqlite3 driver. - // It starts the query, the context expires, then calls sqlite3_interrupt - ctx, cancel := context.WithTimeout(context.Background(), minTestTime/2) - defer cancel() - ts = time.Now() - r, err := db.ExecContext(ctx, q) - // racy check - if r != nil { - n, err := r.RowsAffected() - t.Logf("query should not have succeeded: rows=%d; err=%v; duration=%s", - n, err, time.Since(ts).String()) - } - if err != context.DeadlineExceeded { - t.Fatal(err, ctx.Err()) - } -} - -func TestQueryRowContextCancel(t *testing.T) { - // FIXME: too slow - db, err := sql.Open(DriverName, "file:query?mode=memory&cache=shared") - if err != nil { - t.Fatal(err) - } - defer db.Close() - initDatabase(t, db, 100) - - const query = `SELECT key_id FROM test_table ORDER BY key2 ASC` - var keyID string - unexpectedErrors := make(map[string]int) - for i := 0; i < 10000; i++ { - ctx, cancel := context.WithCancel(context.Background()) - row := db.QueryRowContext(ctx, query) - - cancel() - // it is fine to get "nil" as context cancellation can be handled with delay - if err := row.Scan(&keyID); err != nil && err != context.Canceled { - if err.Error() == "sql: Rows are closed" { - // see https://github.com/golang/go/issues/24431 - // fixed in 1.11.1 to properly return context error - continue - } - unexpectedErrors[err.Error()]++ - } - } - for errText, count := range unexpectedErrors { - t.Error(errText, count) - } -} - -func TestQueryRowContextCancelParallel(t *testing.T) { - // FIXME: too slow - db, err := sql.Open(DriverName, "file:parallel?mode=memory&cache=shared") - if err != nil { - t.Fatal(err) - } - db.SetMaxOpenConns(10) - db.SetMaxIdleConns(5) - - defer db.Close() - initDatabase(t, db, 100) - - const query = `SELECT key_id FROM test_table ORDER BY key2 ASC` - wg := sync.WaitGroup{} - defer wg.Wait() - - testCtx, cancel := context.WithCancel(context.Background()) - defer cancel() - - for i := 0; i < 10; i++ { - wg.Add(1) - go func() { - defer wg.Done() - - var keyID string - for { - select { - case <-testCtx.Done(): - return - default: - } - ctx, cancel := context.WithCancel(context.Background()) - row := db.QueryRowContext(ctx, query) - - cancel() - _ = row.Scan(&keyID) // see TestQueryRowContextCancel - } - }() - } - - var keyID string - for i := 0; i < 10000; i++ { - // note that testCtx is not cancelled during query execution - row := db.QueryRowContext(testCtx, query) - - if err := row.Scan(&keyID); err != nil { - t.Fatal(i, err) - } - } -} - -func TestExecCancel(t *testing.T) { - db, err := sql.Open(DriverName, ":memory:") - if err != nil { - t.Fatal(err) - } - defer db.Close() - - if _, err = db.Exec("create table foo (id integer primary key)"); err != nil { - t.Fatal(err) - } - - for n := 0; n < 100; n++ { - ctx, cancel := context.WithCancel(context.Background()) - _, err = db.ExecContext(ctx, "insert into foo (id) values (?)", n) - cancel() - if err != nil { - t.Fatal(err) - } - } -} - -func doTestOpenContext(t *testing.T, url string) (string, error) { - db, err := sql.Open(DriverName, url) - if err != nil { - return "Failed to open database:", err - } - - defer func() { - err = db.Close() - if err != nil { - t.Error("db close error:", err) - } - }() - - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - err = db.PingContext(ctx) - cancel() - if err != nil { - return "ping error:", err - } - - ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second) - _, err = db.ExecContext(ctx, "drop table foo") - cancel() - ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second) - _, err = db.ExecContext(ctx, "create table foo (id integer)") - cancel() - if err != nil { - return "Failed to create table:", err - } - - return "", nil -} - -func TestFileCopyTruncate(t *testing.T) { - var err error - tempFilename := TempFilename(t) - defer os.Remove(tempFilename) - - db, err := sql.Open(DriverName, tempFilename) - if err != nil { - t.Fatal("open error:", err) - } - defer db.Close() - - if true { - _, err = db.Exec("PRAGMA journal_mode = delete;") - if err != nil { - t.Fatal("journal_mode delete:", err) - } - } - - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - err = db.PingContext(ctx) - cancel() - if err != nil { - t.Fatal("ping error:", err) - } - - ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second) - _, err = db.ExecContext(ctx, "drop table foo") - cancel() - ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second) - _, err = db.ExecContext(ctx, "create table foo (id integer)") - cancel() - if err != nil { - t.Fatal("create table error:", err) - } - - // copy db to new file - var data []byte - data, err = ioutil.ReadFile(tempFilename) - if err != nil { - t.Fatal("read file error:", err) - } - - var f *os.File - copyFilename := tempFilename + "-db-copy" - f, err = os.Create(copyFilename) - if err != nil { - t.Fatal("create file error:", err) - } - defer os.Remove(copyFilename) - - _, err = f.Write(data) - if err != nil { - f.Close() - t.Fatal("write file error:", err) - } - err = f.Close() - if err != nil { - t.Fatal("close file error:", err) - } - - // truncate current db file - f, err = os.OpenFile(tempFilename, os.O_WRONLY|os.O_TRUNC, 0666) - if err != nil { - t.Fatal("open file error:", err) - } - err = f.Close() - if err != nil { - t.Fatal("close file error:", err) - } - - // test db after file truncate - ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second) - err = db.PingContext(ctx) - cancel() - if err != nil { - t.Fatal("ping error:", err) - } - - ctx, cancel = context.WithTimeout(context.Background(), 5*time.Millisecond) - _, err = db.ExecContext(ctx, "drop table foo") - cancel() - if err == nil { - t.Fatal("drop table no error") - } - - ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second) - _, err = db.ExecContext(ctx, "create table foo (id integer)") - cancel() - if err != nil { - t.Fatal("create table error:", err) - } - - err = db.Close() - if err != nil { - t.Error("db close error:", err) - } - - // test copied file - db, err = sql.Open(DriverName, copyFilename) - if err != nil { - t.Fatal("open error:", err) - } - defer db.Close() - - ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second) - err = db.PingContext(ctx) - cancel() - if err != nil { - t.Fatal("ping error:", err) - } - - ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second) - _, err = db.ExecContext(ctx, "drop table foo") - cancel() - if err != nil { - t.Fatal("drop table error:", err) - } - - ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second) - _, err = db.ExecContext(ctx, "create table foo (id integer)") - cancel() - if err != nil { - t.Fatal("create table error:", err) - } -} - -func TestColumnTableName(t *testing.T) { - d := SQLiteDriver{} - conn, err := d.Open(":memory:") - if err != nil { - t.Fatal("failed to get database connection:", err) - } - defer conn.Close() - sqlite3conn := conn.(*SQLiteConn) - - _, err = sqlite3conn.Exec(`CREATE TABLE foo (name string)`, nil) - if err != nil { - t.Fatal("Failed to create table:", err) - } - _, err = sqlite3conn.Exec(`CREATE TABLE bar (name string)`, nil) - if err != nil { - t.Fatal("Failed to create table:", err) - } - - stmt, err := sqlite3conn.Prepare(`SELECT * FROM foo JOIN bar ON foo.name = bar.name`) - if err != nil { - t.Fatal(err) - } - - if exp, got := "foo", stmt.(*SQLiteStmt).ColumnTableName(0); exp != got { - t.Fatalf("Incorrect table name returned expected: %s, got: %s", exp, got) - } - if exp, got := "bar", stmt.(*SQLiteStmt).ColumnTableName(1); exp != got { - t.Fatalf("Incorrect table name returned expected: %s, got: %s", exp, got) - } - if exp, got := "", stmt.(*SQLiteStmt).ColumnTableName(2); exp != got { - t.Fatalf("Incorrect table name returned expected: %s, got: %s", exp, got) - } -} - -func TestFTS5(t *testing.T) { - db, err := sql.Open(DriverName, ":memory:") - if err != nil { - t.Fatal("Failed to open database:", err) - } - defer db.Close() - - _, err = db.Exec("DROP TABLE foo") - _, err = db.Exec("CREATE VIRTUAL TABLE foo USING fts5(id, value)") - if err != nil { - t.Fatal("Failed to create table:", err) - } - - _, err = db.Exec("INSERT INTO foo(id, value) VALUES(?, ?)", 1, `今日の 晩御飯は 天麩羅よ`) - if err != nil { - t.Fatal("Failed to insert value:", err) - } - - _, err = db.Exec("INSERT INTO foo(id, value) VALUES(?, ?)", 2, `今日は いい 天気だ`) - if err != nil { - t.Fatal("Failed to insert value:", err) - } - - rows, err := db.Query("SELECT id, value FROM foo WHERE value MATCH '今日* 天*'") - if err != nil { - t.Fatal("Unable to query foo table:", err) - } - defer rows.Close() - - for rows.Next() { - var id int - var value string - - if err := rows.Scan(&id, &value); err != nil { - t.Error("Unable to scan results:", err) - continue - } - - if id == 1 && value != `今日の 晩御飯は 天麩羅よ` { - t.Error("Value for id 1 should be `今日の 晩御飯は 天麩羅よ`, but:", value) - } else if id == 2 && value != `今日は いい 天気だ` { - t.Error("Value for id 2 should be `今日は いい 天気だ`, but:", value) - } - } - - rows, err = db.Query("SELECT value FROM foo WHERE value MATCH '今日* 天麩羅*'") - if err != nil { - t.Fatal("Unable to query foo table:", err) - } - defer rows.Close() - - var value string - if !rows.Next() { - t.Fatal("Result should be only one") - } - - if err := rows.Scan(&value); err != nil { - t.Fatal("Unable to scan results:", err) - } - - if value != `今日の 晩御飯は 天麩羅よ` { - t.Fatal("Value should be `今日の 晩御飯は 天麩羅よ`, but:", value) - } - - if rows.Next() { - t.Fatal("Result should be only one") - } - - _, err = db.Exec("DROP TABLE foo") - _, err = db.Exec("CREATE VIRTUAL TABLE foo USING fts5(tokenize=unicode61, id, value)") - switch { - case err != nil && err.Error() == "unknown tokenizer: unicode61": - t.Skip("FTS4 not supported") - case err != nil: - t.Fatal("Failed to create table:", err) - } - - _, err = db.Exec("INSERT INTO foo(id, value) VALUES(?, ?)", 1, `février`) - if err != nil { - t.Fatal("Failed to insert value:", err) - } - - rows, err = db.Query("SELECT value FROM foo WHERE value MATCH 'fevrier'") - if err != nil { - t.Fatal("Unable to query foo table:", err) - } - defer rows.Close() - - if !rows.Next() { - t.Fatal("Result should be only one") - } - - if err := rows.Scan(&value); err != nil { - t.Fatal("Unable to scan results:", err) - } - - if value != `février` { - t.Fatal("Value should be `février`, but:", value) - } - - if rows.Next() { - t.Fatal("Result should be only one") - } -} - -type preUpdateHookDataForTest struct { - databaseName string - tableName string - count int - op int - oldRow []any - newRow []any -} - -func TestSerializeDeserialize(t *testing.T) { - // Connect to the source database. - srcDb, err := sql.Open(DriverName, "file:src?mode=memory&cache=shared") - if err != nil { - t.Fatal("Failed to open the source database:", err) - } - defer srcDb.Close() - err = srcDb.Ping() - if err != nil { - t.Fatal("Failed to connect to the source database:", err) - } - - // Connect to the destination database. - destDb, err := sql.Open(DriverName, "file:dst?mode=memory&cache=shared") - if err != nil { - t.Fatal("Failed to open the destination database:", err) - } - defer destDb.Close() - err = destDb.Ping() - if err != nil { - t.Fatal("Failed to connect to the destination database:", err) - } - - // Write data to source database. - _, err = srcDb.Exec(`CREATE TABLE foo (name string)`) - if err != nil { - t.Fatal("Failed to create table in source database:", err) - } - _, err = srcDb.Exec(`INSERT INTO foo(name) VALUES('alice')`) - if err != nil { - t.Fatal("Failed to insert data into source database", err) - } - - // Serialize the source database - srcConn, err := srcDb.Conn(context.Background()) - if err != nil { - t.Fatal("Failed to get connection to source database:", err) - } - defer srcConn.Close() - - var serialized []byte - if err := srcConn.Raw(func(raw any) error { - var err error - serialized, err = raw.(*SQLiteConn).Serialize("") - return err - }); err != nil { - t.Fatal("Failed to serialize source database:", err) - } - srcConn.Close() - - // Confirm that the destination database is initially empty. - var destTableCount int - err = destDb.QueryRow("SELECT COUNT(*) FROM sqlite_master WHERE type = 'table'").Scan(&destTableCount) - if err != nil { - t.Fatal("Failed to check the destination table count:", err) - } - if destTableCount != 0 { - t.Fatalf("The destination database is not empty; %v table(s) found.", destTableCount) - } - - // Deserialize to destination database - destConn, err := destDb.Conn(context.Background()) - if err != nil { - t.Fatal("Failed to get connection to destination database:", err) - } - defer destConn.Close() - - if err := destConn.Raw(func(raw any) error { - return raw.(*SQLiteConn).Deserialize(serialized, "") - }); err != nil { - t.Fatal("Failed to deserialize source database:", err) - } - destConn.Close() - - // Confirm that destination database has been loaded correctly. - var destRowCount int - err = destDb.QueryRow(`SELECT COUNT(*) FROM foo`).Scan(&destRowCount) - if err != nil { - t.Fatal("Failed to count rows in destination database table", err) - } - if destRowCount != 1 { - t.Fatalf("Destination table does not have the expected records") - } -} - -func TestUnlockNotify(t *testing.T) { - tempFilename := TempFilename(t) - defer os.Remove(tempFilename) - dsn := fmt.Sprintf("file:%s?cache=shared&mode=memory", tempFilename) - db, err := sql.Open(DriverName, dsn) - if err != nil { - t.Fatal("Failed to open database:", err) - } - defer db.Close() - - _, err = db.Exec("CREATE TABLE foo(id INTEGER, status INTEGER)") - if err != nil { - t.Fatal("Failed to create table:", err) - } - - tx, err := db.Begin() - if err != nil { - t.Fatal("Failed to begin transaction:", err) - } - - _, err = tx.Exec("INSERT INTO foo(id, status) VALUES(1, 100)") - if err != nil { - t.Fatal("Failed to insert null:", err) - } - - _, err = tx.Exec("UPDATE foo SET status = 200 WHERE id = 1") - if err != nil { - t.Fatal("Failed to update table:", err) - } - - wg := sync.WaitGroup{} - wg.Add(1) - timer := time.NewTimer(500 * time.Millisecond) - go func() { - <-timer.C - err := tx.Commit() - if err != nil { - t.Fatal("Failed to commit transaction:", err) - } - wg.Done() - }() - - rows, err := db.Query("SELECT count(*) from foo") - if err != nil { - t.Fatal("Unable to query foo table:", err) - } - - if rows.Next() { - var count int - if err := rows.Scan(&count); err != nil { - t.Fatal("Failed to Scan rows", err) - } - } - if err := rows.Err(); err != nil { - t.Fatal("Failed at the call to Next:", err) - } - wg.Wait() - -} - -func TestUnlockNotifyMany(t *testing.T) { - tempFilename := TempFilename(t) - defer os.Remove(tempFilename) - dsn := fmt.Sprintf("file:%s?cache=shared&mode=memory", tempFilename) - db, err := sql.Open(DriverName, dsn) - if err != nil { - t.Fatal("Failed to open database:", err) - } - defer db.Close() - - _, err = db.Exec("CREATE TABLE foo(id INTEGER, status INTEGER)") - if err != nil { - t.Fatal("Failed to create table:", err) - } - - tx, err := db.Begin() - if err != nil { - t.Fatal("Failed to begin transaction:", err) - } - - _, err = tx.Exec("INSERT INTO foo(id, status) VALUES(1, 100)") - if err != nil { - t.Fatal("Failed to insert null:", err) - } - - _, err = tx.Exec("UPDATE foo SET status = 200 WHERE id = 1") - if err != nil { - t.Fatal("Failed to update table:", err) - } - - wg := sync.WaitGroup{} - wg.Add(1) - timer := time.NewTimer(500 * time.Millisecond) - go func() { - <-timer.C - err := tx.Commit() - if err != nil { - t.Fatal("Failed to commit transaction:", err) - } - wg.Done() - }() - - const concurrentQueries = 1000 - wg.Add(concurrentQueries) - for i := 0; i < concurrentQueries; i++ { - go func() { - rows, err := db.Query("SELECT count(*) from foo") - if err != nil { - t.Fatal("Unable to query foo table:", err) - } - - if rows.Next() { - var count int - if err := rows.Scan(&count); err != nil { - t.Fatal("Failed to Scan rows", err) - } - } - if err := rows.Err(); err != nil { - t.Fatal("Failed at the call to Next:", err) - } - wg.Done() - }() - } - wg.Wait() -} - -func TestUnlockNotifyDeadlock(t *testing.T) { - tempFilename := TempFilename(t) - defer os.Remove(tempFilename) - dsn := fmt.Sprintf("file:%s?cache=shared&mode=memory", tempFilename) - db, err := sql.Open(DriverName, dsn) - if err != nil { - t.Fatal("Failed to open database:", err) - } - defer db.Close() - - _, err = db.Exec("CREATE TABLE foo(id INTEGER, status INTEGER)") - if err != nil { - t.Fatal("Failed to create table:", err) - } - - tx, err := db.Begin() - if err != nil { - t.Fatal("Failed to begin transaction:", err) - } - - _, err = tx.Exec("INSERT INTO foo(id, status) VALUES(1, 100)") - if err != nil { - t.Fatal("Failed to insert null:", err) - } - - _, err = tx.Exec("UPDATE foo SET status = 200 WHERE id = 1") - if err != nil { - t.Fatal("Failed to update table:", err) - } - - wg := sync.WaitGroup{} - wg.Add(1) - timer := time.NewTimer(500 * time.Millisecond) - go func() { - <-timer.C - err := tx.Commit() - if err != nil { - t.Fatal("Failed to commit transaction:", err) - } - wg.Done() - }() - - wg.Add(1) - go func() { - tx2, err := db.Begin() - if err != nil { - t.Fatal("Failed to begin transaction:", err) - } - defer tx2.Rollback() - - _, err = tx2.Exec("DELETE FROM foo") - if err != nil { - t.Fatal("Failed to delete table:", err) - } - err = tx2.Commit() - if err != nil { - t.Fatal("Failed to commit transaction:", err) - } - wg.Done() - }() - - rows, err := tx.Query("SELECT count(*) from foo") - if err != nil { - t.Fatal("Unable to query foo table:", err) - } - - if rows.Next() { - var count int - if err := rows.Scan(&count); err != nil { - t.Fatal("Failed to Scan rows", err) - } - } - if err := rows.Err(); err != nil { - t.Fatal("Failed at the call to Next:", err) - } - - wg.Wait() -} - -func getRowCount(rows *sql.Rows) (int, error) { - var i int - for rows.Next() { - i++ - } - return i, nil -} - -func TempFilename(t testing.TB) string { - f, err := ioutil.TempFile("", "go-sqlite3-test-") - if err != nil { - t.Fatal(err) - } - f.Close() - return f.Name() -} - -func doTestOpen(t *testing.T, url string) (string, error) { - db, err := sql.Open(DriverName, url) - if err != nil { - return "Failed to open database:", err - } - - defer func() { - err = db.Close() - if err != nil { - t.Error("db close error:", err) - } - }() - - err = db.Ping() - if err != nil { - return "ping error:", err - } - - _, err = db.Exec("drop table foo") - _, err = db.Exec("create table foo (id integer)") - if err != nil { - return "Failed to create table:", err - } - - return "", nil -} - -func TestOpenWithVFS(t *testing.T) { - { - uri := fmt.Sprintf("file:%s?mode=memory&vfs=hello", t.Name()) - db, err := sql.Open(DriverName, uri) - if err != nil { - t.Fatal("Failed to open", err) - } - err = db.Ping() - if err == nil { - t.Fatal("Failed to open", err) - } - db.Close() - } - - { - uri := fmt.Sprintf("file:%s?mode=memory&vfs=unix-none", t.Name()) - db, err := sql.Open(DriverName, uri) - if err != nil { - t.Fatal("Failed to open", err) - } - err = db.Ping() - if err != nil { - t.Fatal("Failed to ping", err) - } - db.Close() - } -} - -func TestOpenNoCreate(t *testing.T) { - filename := t.Name() + ".sqlite" - - if err := os.Remove(filename); err != nil && !os.IsNotExist(err) { - t.Fatal(err) - } - defer os.Remove(filename) - - // https://golang.org/pkg/database/sql/#Open - // "Open may just validate its arguments without creating a connection - // to the database. To verify that the data source name is valid, call Ping." - db, err := sql.Open(DriverName, fmt.Sprintf("file:%s?mode=rw", filename)) - if err == nil { - defer db.Close() - - err = db.Ping() - if err == nil { - t.Fatal("expected error from Open or Ping") - } - } - - sqlErr, ok := err.(Error) - if !ok { - t.Fatalf("expected sqlite3.Error, but got %T", err) - } - - if sqlErr.Code != ErrCantOpen { - t.Fatalf("expected SQLITE_CANTOPEN, but got %v", sqlErr) - } - - // make sure database file truly was not created - if _, err := os.Stat(filename); !os.IsNotExist(err) { - if err != nil { - t.Fatal(err) - } - t.Fatal("expected database file to not exist") - } - - // verify that it works if the mode is "rwc" instead - db, err = sql.Open(DriverName, fmt.Sprintf("file:%s?mode=rwc", filename)) - if err != nil { - t.Fatal(err) - } - defer db.Close() - - if err := db.Ping(); err != nil { - t.Fatal(err) - } - - // make sure database file truly was created - if _, err := os.Stat(filename); err != nil { - if !os.IsNotExist(err) { - t.Fatal(err) - } - t.Fatal("expected database file to exist") - } -} - -func TestReadonly(t *testing.T) { - tempFilename := TempFilename(t) - defer os.Remove(tempFilename) - - db1, err := sql.Open(DriverName, "file:"+tempFilename) - if err != nil { - t.Fatal(err) - } - defer db1.Close() - db1.Exec("CREATE TABLE test (x int, y float)") - - db2, err := sql.Open(DriverName, "file:"+tempFilename+"?mode=ro") - if err != nil { - t.Fatal(err) - } - defer db2.Close() - _ = db2 - _, err = db2.Exec("INSERT INTO test VALUES (1, 3.14)") - if err == nil { - t.Fatal("didn't expect INSERT into read-only database to work") - } -} - -func TestDeferredForeignKey(t *testing.T) { - fname := TempFilename(t) - uri := "file:" + fname + "?_foreign_keys=1&mode=memory" - db, err := sql.Open(DriverName, uri) - if err != nil { - os.Remove(fname) - t.Errorf("sql.Open(\"sqlite3\", %q): %v", uri, err) - } - _, err = db.Exec("CREATE TABLE bar (id INTEGER PRIMARY KEY)") - if err != nil { - t.Errorf("failed creating tables: %v", err) - } - _, err = db.Exec("CREATE TABLE foo (bar_id INTEGER, FOREIGN KEY(bar_id) REFERENCES bar(id) DEFERRABLE INITIALLY DEFERRED)") - if err != nil { - t.Errorf("failed creating tables: %v", err) - } - tx, err := db.Begin() - if err != nil { - t.Errorf("Failed to begin transaction: %v", err) - } - _, err = tx.Exec("INSERT INTO foo (bar_id) VALUES (123)") - if err != nil { - t.Errorf("Failed to insert row: %v", err) - } - err = tx.Commit() - if err == nil { - t.Errorf("Expected an error: %v", err) - } - _, err = db.Begin() - if err != nil { - t.Errorf("Failed to begin transaction: %v", err) - } - - db.Close() - os.Remove(fname) -} - -func TestClose(t *testing.T) { - db, err := sql.Open(DriverName, ":memory:") - if err != nil { - t.Fatal("Failed to open database:", err) - } - - _, err = db.Exec("drop table foo") - _, err = db.Exec("create table foo (id integer)") - if err != nil { - t.Fatal("Failed to create table:", err) - } - - stmt, err := db.Prepare("select id from foo where id = ?") - if err != nil { - t.Fatal("Failed to select records:", err) - } - - db.Close() - _, err = stmt.Exec(1) - if err == nil { - t.Fatal("Failed to operate closed statement") - } -} - -func TestInsert(t *testing.T) { - db, err := sql.Open(DriverName, ":memory:") - if err != nil { - t.Fatal("Failed to open database:", err) - } - defer db.Close() - - _, err = db.Exec("drop table foo") - _, err = db.Exec("create table foo (id integer)") - if err != nil { - t.Fatal("Failed to create table:", err) - } - - res, err := db.Exec("insert into foo(id) values(123)") - if err != nil { - t.Fatal("Failed to insert record:", err) - } - affected, _ := res.RowsAffected() - if affected != 1 { - t.Fatalf("Expected %d for affected rows, but %d:", 1, affected) - } - - rows, err := db.Query("select id from foo") - if err != nil { - t.Fatal("Failed to select records:", err) - } - defer rows.Close() - - rows.Next() - - var result int - rows.Scan(&result) - if result != 123 { - t.Errorf("Expected %d for fetched result, but %d:", 123, result) - } -} - -func TestUpsert(t *testing.T) { - _, n, _ := LibVersion() - if n < 3024000 { - t.Skip("UPSERT requires sqlite3 >= 3.24.0") - } - db, err := sql.Open(DriverName, ":memory:") - if err != nil { - t.Fatal("Failed to open database:", err) - } - defer db.Close() - - _, err = db.Exec("drop table foo") - _, err = db.Exec("create table foo (name string primary key, counter integer)") - if err != nil { - t.Fatal("Failed to create table:", err) - } - - for i := 0; i < 10; i++ { - res, err := db.Exec("insert into foo(name, counter) values('key', 1) on conflict (name) do update set counter=counter+1") - if err != nil { - t.Fatal("Failed to upsert record:", err) - } - affected, _ := res.RowsAffected() - if affected != 1 { - t.Fatalf("Expected %d for affected rows, but %d:", 1, affected) - } - } - rows, err := db.Query("select name, counter from foo") - if err != nil { - t.Fatal("Failed to select records:", err) - } - defer rows.Close() - - rows.Next() - - var resultName string - var resultCounter int - rows.Scan(&resultName, &resultCounter) - if resultName != "key" { - t.Errorf("Expected %s for fetched result, but %s:", "key", resultName) - } - if resultCounter != 10 { - t.Errorf("Expected %d for fetched result, but %d:", 10, resultCounter) - } - -} - -func TestUpdate(t *testing.T) { - db, err := sql.Open(DriverName, ":memory:") - if err != nil { - t.Fatal("Failed to open database:", err) - } - defer db.Close() - - _, err = db.Exec("drop table foo") - _, err = db.Exec("create table foo (id integer)") - if err != nil { - t.Fatal("Failed to create table:", err) - } - - res, err := db.Exec("insert into foo(id) values(123)") - if err != nil { - t.Fatal("Failed to insert record:", err) - } - expected, err := res.LastInsertId() - if err != nil { - t.Fatal("Failed to get LastInsertId:", err) - } - affected, _ := res.RowsAffected() - if err != nil { - t.Fatal("Failed to get RowsAffected:", err) - } - if affected != 1 { - t.Fatalf("Expected %d for affected rows, but %d:", 1, affected) - } - - res, err = db.Exec("update foo set id = 234") - if err != nil { - t.Fatal("Failed to update record:", err) - } - lastID, err := res.LastInsertId() - if err != nil { - t.Fatal("Failed to get LastInsertId:", err) - } - if expected != lastID { - t.Errorf("Expected %q for last Id, but %q:", expected, lastID) - } - affected, _ = res.RowsAffected() - if err != nil { - t.Fatal("Failed to get RowsAffected:", err) - } - if affected != 1 { - t.Fatalf("Expected %d for affected rows, but %d:", 1, affected) - } - - rows, err := db.Query("select id from foo") - if err != nil { - t.Fatal("Failed to select records:", err) - } - defer rows.Close() - - rows.Next() - - var result int - rows.Scan(&result) - if result != 234 { - t.Errorf("Expected %d for fetched result, but %d:", 234, result) - } -} - -func TestDelete(t *testing.T) { - db, err := sql.Open(DriverName, ":memory:") - if err != nil { - t.Fatal("Failed to open database:", err) - } - defer db.Close() - - _, err = db.Exec("drop table foo") - _, err = db.Exec("create table foo (id integer)") - if err != nil { - t.Fatal("Failed to create table:", err) - } - - res, err := db.Exec("insert into foo(id) values(123)") - if err != nil { - t.Fatal("Failed to insert record:", err) - } - expected, err := res.LastInsertId() - if err != nil { - t.Fatal("Failed to get LastInsertId:", err) - } - affected, err := res.RowsAffected() - if err != nil { - t.Fatal("Failed to get RowsAffected:", err) - } - if affected != 1 { - t.Errorf("Expected %d for cout of affected rows, but %q:", 1, affected) - } - - res, err = db.Exec("delete from foo where id = 123") - if err != nil { - t.Fatal("Failed to delete record:", err) - } - lastID, err := res.LastInsertId() - if err != nil { - t.Fatal("Failed to get LastInsertId:", err) - } - if expected != lastID { - t.Errorf("Expected %q for last Id, but %q:", expected, lastID) - } - affected, err = res.RowsAffected() - if err != nil { - t.Fatal("Failed to get RowsAffected:", err) - } - if affected != 1 { - t.Errorf("Expected %d for cout of affected rows, but %q:", 1, affected) - } - - rows, err := db.Query("select id from foo") - if err != nil { - t.Fatal("Failed to select records:", err) - } - defer rows.Close() - - if rows.Next() { - t.Error("Fetched row but expected not rows") - } -} - -func TestBooleanRoundtrip(t *testing.T) { - db, err := sql.Open(DriverName, ":memory:") - if err != nil { - t.Fatal("Failed to open database:", err) - } - defer db.Close() - - _, err = db.Exec("DROP TABLE foo") - _, err = db.Exec("CREATE TABLE foo(id INTEGER, value BOOL)") - if err != nil { - t.Fatal("Failed to create table:", err) - } - - _, err = db.Exec("INSERT INTO foo(id, value) VALUES(1, ?)", true) - if err != nil { - t.Fatal("Failed to insert true value:", err) - } - - _, err = db.Exec("INSERT INTO foo(id, value) VALUES(2, ?)", false) - if err != nil { - t.Fatal("Failed to insert false value:", err) - } - - rows, err := db.Query("SELECT id, value FROM foo") - if err != nil { - t.Fatal("Unable to query foo table:", err) - } - defer rows.Close() - - for rows.Next() { - var id int - var value bool - - if err := rows.Scan(&id, &value); err != nil { - t.Error("Unable to scan results:", err) - continue - } - - if id == 1 && !value { - t.Error("Value for id 1 should be true, not false") - - } else if id == 2 && value { - t.Error("Value for id 2 should be false, not true") - } - } -} - -func timezone(t time.Time) string { return t.Format("-07:00") } - -func TestTimestamp(t *testing.T) { - db, err := sql.Open(DriverName, ":memory:") - if err != nil { - t.Fatal("Failed to open database:", err) - } - defer db.Close() - - _, err = db.Exec("DROP TABLE foo") - _, err = db.Exec("CREATE TABLE foo(id INTEGER, ts timeSTAMP, dt DATETIME)") - if err != nil { - t.Fatal("Failed to create table:", err) - } - - timestamp1 := time.Date(2012, time.April, 6, 22, 50, 0, 0, time.UTC) - timestamp2 := time.Date(2006, time.January, 2, 15, 4, 5, 123456789, time.UTC) - timestamp3 := time.Date(2012, time.November, 4, 0, 0, 0, 0, time.UTC) - tzTest := time.FixedZone("TEST", -9*3600-13*60) - tests := []struct { - value any - expected time.Time - }{ - {"nonsense", time.Time{}}, - {"0000-00-00 00:00:00", time.Time{}}, - {time.Time{}.Unix(), time.Time{}}, - {timestamp1, timestamp1}, - {timestamp2.Unix(), timestamp2.Truncate(time.Second)}, - {timestamp2.UnixNano() / int64(time.Millisecond), timestamp2.Truncate(time.Millisecond)}, - {timestamp1.In(tzTest), timestamp1.In(tzTest)}, - {timestamp1.Format("2006-01-02 15:04:05.000"), timestamp1}, - {timestamp1.Format("2006-01-02T15:04:05.000"), timestamp1}, - {timestamp1.Format("2006-01-02 15:04:05"), timestamp1}, - {timestamp1.Format("2006-01-02T15:04:05"), timestamp1}, - {timestamp2, timestamp2}, - {"2006-01-02 15:04:05.123456789", timestamp2}, - {"2006-01-02T15:04:05.123456789", timestamp2}, - {"2006-01-02T05:51:05.123456789-09:13", timestamp2.In(tzTest)}, - {"2012-11-04", timestamp3}, - {"2012-11-04 00:00", timestamp3}, - {"2012-11-04 00:00:00", timestamp3}, - {"2012-11-04 00:00:00.000", timestamp3}, - {"2012-11-04T00:00", timestamp3}, - {"2012-11-04T00:00:00", timestamp3}, - {"2012-11-04T00:00:00.000", timestamp3}, - {"2006-01-02T15:04:05.123456789Z", timestamp2}, - {"2012-11-04Z", timestamp3}, - {"2012-11-04 00:00Z", timestamp3}, - {"2012-11-04 00:00:00Z", timestamp3}, - {"2012-11-04 00:00:00.000Z", timestamp3}, - {"2012-11-04T00:00Z", timestamp3}, - {"2012-11-04T00:00:00Z", timestamp3}, - {"2012-11-04T00:00:00.000Z", timestamp3}, - } - for i := range tests { - _, err = db.Exec("INSERT INTO foo(id, ts, dt) VALUES(?, ?, ?)", i, tests[i].value, tests[i].value) - if err != nil { - t.Fatal("Failed to insert timestamp:", err) - } - } - - rows, err := db.Query("SELECT id, ts, dt FROM foo ORDER BY id ASC") - if err != nil { - t.Fatal("Unable to query foo table:", err) - } - defer rows.Close() - - seen := 0 - for rows.Next() { - var id int - var ts, dt time.Time - - if err := rows.Scan(&id, &ts, &dt); err != nil { - t.Error("Unable to scan results:", err) - continue - } - if id < 0 || id >= len(tests) { - t.Error("Bad row id: ", id) - continue - } - seen++ - if !tests[id].expected.Equal(ts) { - t.Errorf("Timestamp value for id %v (%v) should be %v, not %v", id, tests[id].value, tests[id].expected, dt) - } - if !tests[id].expected.Equal(dt) { - t.Errorf("Datetime value for id %v (%v) should be %v, not %v", id, tests[id].value, tests[id].expected, dt) - } - if timezone(tests[id].expected) != timezone(ts) { - t.Errorf("Timezone for id %v (%v) should be %v, not %v", id, tests[id].value, - timezone(tests[id].expected), timezone(ts)) - } - if timezone(tests[id].expected) != timezone(dt) { - t.Errorf("Timezone for id %v (%v) should be %v, not %v", id, tests[id].value, - timezone(tests[id].expected), timezone(dt)) - } - } - - if seen != len(tests) { - t.Errorf("Expected to see %d rows", len(tests)) - } -} - -func TestBoolean(t *testing.T) { - db, err := sql.Open(DriverName, ":memory:") - if err != nil { - t.Fatal("Failed to open database:", err) - } - - defer db.Close() - - _, err = db.Exec("CREATE TABLE foo(id INTEGER, fbool BOOLEAN)") - if err != nil { - t.Fatal("Failed to create table:", err) - } - - bool1 := true - _, err = db.Exec("INSERT INTO foo(id, fbool) VALUES(1, ?)", bool1) - if err != nil { - t.Fatal("Failed to insert boolean:", err) - } - - bool2 := false - _, err = db.Exec("INSERT INTO foo(id, fbool) VALUES(2, ?)", bool2) - if err != nil { - t.Fatal("Failed to insert boolean:", err) - } - - bool3 := "nonsense" - _, err = db.Exec("INSERT INTO foo(id, fbool) VALUES(3, ?)", bool3) - if err != nil { - t.Fatal("Failed to insert nonsense:", err) - } - - rows, err := db.Query("SELECT id, fbool FROM foo where fbool = ?", bool1) - if err != nil { - t.Fatal("Unable to query foo table:", err) - } - counter := 0 - - var id int - var fbool bool - - for rows.Next() { - if err := rows.Scan(&id, &fbool); err != nil { - t.Fatal("Unable to scan results:", err) - } - counter++ - } - - if counter != 1 { - t.Fatalf("Expected 1 row but %v", counter) - } - - if id != 1 && !fbool { - t.Fatalf("Value for id 1 should be %v, not %v", bool1, fbool) - } - - rows, err = db.Query("SELECT id, fbool FROM foo where fbool = ?", bool2) - if err != nil { - t.Fatal("Unable to query foo table:", err) - } - - counter = 0 - - for rows.Next() { - if err := rows.Scan(&id, &fbool); err != nil { - t.Fatal("Unable to scan results:", err) - } - counter++ - } - - if counter != 1 { - t.Fatalf("Expected 1 row but %v", counter) - } - - if id != 2 && fbool { - t.Fatalf("Value for id 2 should be %v, not %v", bool2, fbool) - } - - // make sure "nonsense" triggered an error - rows, err = db.Query("SELECT id, fbool FROM foo where id=?;", 3) - if err != nil { - t.Fatal("Unable to query foo table:", err) - } - - rows.Next() - err = rows.Scan(&id, &fbool) - if err == nil { - t.Error("Expected error from \"nonsense\" bool") - } -} - -func TestFloat32(t *testing.T) { - db, err := sql.Open(DriverName, ":memory:") - if err != nil { - t.Fatal("Failed to open database:", err) - } - defer db.Close() - - _, err = db.Exec("CREATE TABLE foo(id INTEGER)") - if err != nil { - t.Fatal("Failed to create table:", err) - } - - _, err = db.Exec("INSERT INTO foo(id) VALUES(null)") - if err != nil { - t.Fatal("Failed to insert null:", err) - } - - rows, err := db.Query("SELECT id FROM foo") - if err != nil { - t.Fatal("Unable to query foo table:", err) - } - - if !rows.Next() { - t.Fatal("Unable to query results:", err) - } - - var id any - if err := rows.Scan(&id); err != nil { - t.Fatal("Unable to scan results:", err) - } - if id != nil { - t.Error("Expected nil but not") - } -} - -func TestNull(t *testing.T) { - db, err := sql.Open(DriverName, ":memory:") - if err != nil { - t.Fatal("Failed to open database:", err) - } - defer db.Close() - - rows, err := db.Query("SELECT 3.141592") - if err != nil { - t.Fatal("Unable to query foo table:", err) - } - - if !rows.Next() { - t.Fatal("Unable to query results:", err) - } - - var v any - if err := rows.Scan(&v); err != nil { - t.Fatal("Unable to scan results:", err) - } - f, ok := v.(float64) - if !ok { - t.Error("Expected float but not") - } - if f != 3.141592 { - t.Error("Expected 3.141592 but not") - } -} - -func TestTransaction(t *testing.T) { - db, err := sql.Open(DriverName, ":memory:") - if err != nil { - t.Fatal("Failed to open database:", err) - } - defer db.Close() - - _, err = db.Exec("CREATE TABLE foo(id INTEGER)") - if err != nil { - t.Fatal("Failed to create table:", err) - } - - tx, err := db.Begin() - if err != nil { - t.Fatal("Failed to begin transaction:", err) - } - - _, err = tx.Exec("INSERT INTO foo(id) VALUES(1)") - if err != nil { - t.Fatal("Failed to insert null:", err) - } - - rows, err := tx.Query("SELECT id from foo") - if err != nil { - t.Fatal("Unable to query foo table:", err) - } - - err = tx.Rollback() - if err != nil { - t.Fatal("Failed to rollback transaction:", err) - } - - if rows.Next() { - t.Fatal("Unable to query results:", err) - } - - tx, err = db.Begin() - if err != nil { - t.Fatal("Failed to begin transaction:", err) - } - - _, err = tx.Exec("INSERT INTO foo(id) VALUES(1)") - if err != nil { - t.Fatal("Failed to insert null:", err) - } - - err = tx.Commit() - if err != nil { - t.Fatal("Failed to commit transaction:", err) - } - - rows, err = tx.Query("SELECT id from foo") - if err == nil { - t.Fatal("Expected failure to query") - } -} - -func TestWAL(t *testing.T) { - db, err := sql.Open(DriverName, ":memory:") - if err != nil { - t.Fatal("Failed to open database:", err) - } - defer db.Close() - - if _, err = db.Exec("CREATE TABLE test (id SERIAL, user TEXT NOT NULL, name TEXT NOT NULL);"); err != nil { - t.Fatal("Failed to Exec CREATE TABLE:", err) - } - if _, err = db.Exec("INSERT INTO test (user, name) VALUES ('user','name');"); err != nil { - t.Fatal("Failed to Exec INSERT:", err) - } - - trans, err := db.Begin() - if err != nil { - t.Fatal("Failed to Begin:", err) - } - s, err := trans.Prepare("INSERT INTO test (user, name) VALUES (?, ?);") - if err != nil { - t.Fatal("Failed to Prepare:", err) - } - - var count int - if err = trans.QueryRow("SELECT count(user) FROM test;").Scan(&count); err != nil { - t.Fatal("Failed to QueryRow:", err) - } - if _, err = s.Exec("bbbb", "aaaa"); err != nil { - t.Fatal("Failed to Exec prepared statement:", err) - } - if err = s.Close(); err != nil { - t.Fatal("Failed to Close prepared statement:", err) - } - if err = trans.Commit(); err != nil { - t.Fatal("Failed to Commit:", err) - } -} - -func TestTimezoneConversion(t *testing.T) { - zones := []string{"UTC", "US/Central", "US/Pacific", "Local"} - for _, tz := range zones { - db, err := sql.Open(DriverName, "file:tz?mode=memory&_loc="+url.QueryEscape(tz)) - if err != nil { - t.Fatal("Failed to open database:", err) - } - defer db.Close() - - _, err = db.Exec("DROP TABLE foo") - _, err = db.Exec("CREATE TABLE foo(id INTEGER, ts TIMESTAMP, dt DATETIME)") - if err != nil { - t.Fatal("Failed to create table:", err) - } - - loc, err := time.LoadLocation(tz) - if err != nil { - t.Fatal("Failed to load location:", err) - } - - timestamp1 := time.Date(2012, time.April, 6, 22, 50, 0, 0, time.UTC) - timestamp2 := time.Date(2006, time.January, 2, 15, 4, 5, 123456789, time.UTC) - timestamp3 := time.Date(2012, time.November, 4, 0, 0, 0, 0, time.UTC) - tests := []struct { - value any - expected time.Time - }{ - {"nonsense", time.Time{}.In(loc)}, - {"0000-00-00 00:00:00", time.Time{}.In(loc)}, - {timestamp1, timestamp1.In(loc)}, - {timestamp1.Unix(), timestamp1.In(loc)}, - {timestamp1.In(time.FixedZone("TEST", -7*3600)), timestamp1.In(loc)}, - {timestamp1.Format("2006-01-02 15:04:05.000"), timestamp1.In(loc)}, - {timestamp1.Format("2006-01-02T15:04:05.000"), timestamp1.In(loc)}, - {timestamp1.Format("2006-01-02 15:04:05"), timestamp1.In(loc)}, - {timestamp1.Format("2006-01-02T15:04:05"), timestamp1.In(loc)}, - {timestamp2, timestamp2.In(loc)}, - {"2006-01-02 15:04:05.123456789", timestamp2.In(loc)}, - {"2006-01-02T15:04:05.123456789", timestamp2.In(loc)}, - {"2012-11-04", timestamp3.In(loc)}, - {"2012-11-04 00:00", timestamp3.In(loc)}, - {"2012-11-04 00:00:00", timestamp3.In(loc)}, - {"2012-11-04 00:00:00.000", timestamp3.In(loc)}, - {"2012-11-04T00:00", timestamp3.In(loc)}, - {"2012-11-04T00:00:00", timestamp3.In(loc)}, - {"2012-11-04T00:00:00.000", timestamp3.In(loc)}, - } - for i := range tests { - _, err = db.Exec("INSERT INTO foo(id, ts, dt) VALUES(?, ?, ?)", i, tests[i].value, tests[i].value) - if err != nil { - t.Fatal("Failed to insert timestamp:", err) - } - } - - rows, err := db.Query("SELECT id, ts, dt FROM foo ORDER BY id ASC") - if err != nil { - t.Fatal("Unable to query foo table:", err) - } - defer rows.Close() - - seen := 0 - for rows.Next() { - var id int - var ts, dt time.Time - - if err := rows.Scan(&id, &ts, &dt); err != nil { - t.Error("Unable to scan results:", err) - continue - } - if id < 0 || id >= len(tests) { - t.Error("Bad row id: ", id) - continue - } - seen++ - if !tests[id].expected.Equal(ts) { - t.Errorf("Timestamp value for id %v (%v) should be %v, not %v", id, tests[id].value, tests[id].expected, ts) - } - if !tests[id].expected.Equal(dt) { - t.Errorf("Datetime value for id %v (%v) should be %v, not %v", id, tests[id].value, tests[id].expected, dt) - } - if tests[id].expected.Location().String() != ts.Location().String() { - t.Errorf("Location for id %v (%v) should be %v, not %v", id, tests[id].value, tests[id].expected.Location().String(), ts.Location().String()) - } - if tests[id].expected.Location().String() != dt.Location().String() { - t.Errorf("Location for id %v (%v) should be %v, not %v", id, tests[id].value, tests[id].expected.Location().String(), dt.Location().String()) - } - } - - if seen != len(tests) { - t.Errorf("Expected to see %d rows", len(tests)) - } - } -} - -// TODO: Execer & Queryer currently disabled -// https://github.com/mattn/go-sqlite3/issues/82 -func TestExecer(t *testing.T) { - db, err := sql.Open(DriverName, ":memory:") - if err != nil { - t.Fatal("Failed to open database:", err) - } - defer db.Close() - - _, err = db.Exec(` - create table foo (id integer); -- one comment - insert into foo(id) values(?); - insert into foo(id) values(?); - insert into foo(id) values(?); -- another comment - `, 1, 2, 3) - if err != nil { - t.Error("Failed to call db.Exec:", err) - } -} - -func TestQueryer(t *testing.T) { - db, err := sql.Open(DriverName, ":memory:") - if err != nil { - t.Fatal("Failed to open database:", err) - } - defer db.Close() - - _, err = db.Exec(` - create table foo (id integer); - `) - if err != nil { - t.Error("Failed to call db.Query:", err) - } - - _, err = db.Exec(` - insert into foo(id) values(?); - insert into foo(id) values(?); - insert into foo(id) values(?); - `, 3, 2, 1) - if err != nil { - t.Error("Failed to call db.Exec:", err) - } - rows, err := db.Query(` - select id from foo order by id; - `) - if err != nil { - t.Error("Failed to call db.Query:", err) - } - defer rows.Close() - n := 0 - for rows.Next() { - var id int - err = rows.Scan(&id) - if err != nil { - t.Error("Failed to db.Query:", err) - } - if id != n+1 { - t.Error("Failed to db.Query: not matched results") - } - n = n + 1 - } - if err := rows.Err(); err != nil { - t.Errorf("Post-scan failed: %v\n", err) - } - if n != 3 { - t.Errorf("Expected 3 rows but retrieved %v", n) - } -} - -func TestStress(t *testing.T) { - db, err := sql.Open(DriverName, ":memory:") - if err != nil { - t.Fatal("Failed to open database:", err) - } - defer db.Close() - db.Exec("CREATE TABLE foo (id int);") - db.Exec("INSERT INTO foo VALUES(1);") - db.Exec("INSERT INTO foo VALUES(2);") - - for i := 0; i < 10000; i++ { - for j := 0; j < 3; j++ { - rows, err := db.Query("select * from foo where id=1;") - if err != nil { - t.Error("Failed to call db.Query:", err) - } - for rows.Next() { - var i int - if err := rows.Scan(&i); err != nil { - t.Errorf("Scan failed: %v\n", err) - } - } - if err := rows.Err(); err != nil { - t.Errorf("Post-scan failed: %v\n", err) - } - rows.Close() - } - } -} - -func TestDateTimeLocal(t *testing.T) { - const zone = "Asia/Tokyo" - tempFilename := TempFilename(t) - defer os.Remove(tempFilename) - filename1 := tempFilename + "?mode=memory&cache=shared" - filename2 := filename1 + "&_loc=" + zone - db1, err := sql.Open(DriverName, filename2) - if err != nil { - t.Fatal("Failed to open database:", err) - } - defer db1.Close() - db1.Exec("CREATE TABLE foo (dt datetime);") - db1.Exec("INSERT INTO foo VALUES('2015-03-05 15:16:17');") - - row := db1.QueryRow("select * from foo") - var d time.Time - err = row.Scan(&d) - if err != nil { - t.Fatal("Failed to scan datetime:", err) - } - if d.Hour() == 15 || !strings.Contains(d.String(), "JST") { - t.Fatal("Result should have timezone", d) - } - - db2, err := sql.Open(DriverName, filename1) - if err != nil { - t.Fatal("Failed to open database:", err) - } - defer db2.Close() - - row = db2.QueryRow("select * from foo") - err = row.Scan(&d) - if err != nil { - t.Fatal("Failed to scan datetime:", err) - } - if d.UTC().Hour() != 15 || !strings.Contains(d.String(), "UTC") { - t.Fatalf("Result should not have timezone %v %v", zone, d.String()) - } - - _, err = db2.Exec("DELETE FROM foo") - if err != nil { - t.Fatal("Failed to delete table:", err) - } - dt, err := time.Parse("2006/1/2 15/4/5 -0700 MST", "2015/3/5 15/16/17 +0900 JST") - if err != nil { - t.Fatal("Failed to parse datetime:", err) - } - db2.Exec("INSERT INTO foo VALUES(?);", dt) - - db3, err := sql.Open(DriverName, filename2) - if err != nil { - t.Fatal("Failed to open database:", err) - } - defer db3.Close() - - row = db3.QueryRow("select * from foo") - err = row.Scan(&d) - if err != nil { - t.Fatal("Failed to scan datetime:", err) - } - if d.Hour() != 15 || !strings.Contains(d.String(), "JST") { - t.Fatalf("Result should have timezone %v %v", zone, d.String()) - } -} - -func TestVersion(t *testing.T) { - s, n, id := LibVersion() - if s == "" || n == 0 || id == "" { - t.Errorf("Version failed %q, %d, %q\n", s, n, id) - } -} - -func TestStringContainingZero(t *testing.T) { - db, err := sql.Open(DriverName, ":memory:") - if err != nil { - t.Fatal("Failed to open database:", err) - } - defer db.Close() - - _, err = db.Exec(` - create table foo (id integer, name, extra text); - `) - if err != nil { - t.Error("Failed to call db.Query:", err) - } - - const text = "foo\x00bar" - - _, err = db.Exec(`insert into foo(id, name, extra) values($1, $2, $2)`, 1, text) - if err != nil { - t.Error("Failed to call db.Exec:", err) - } - - row := db.QueryRow(`select id, extra from foo where id = $1 and extra = $2`, 1, text) - if row == nil { - t.Error("Failed to call db.QueryRow") - } - - var id int - var extra string - err = row.Scan(&id, &extra) - if err != nil { - t.Error("Failed to db.Scan:", err) - } - if id != 1 || extra != text { - t.Error("Failed to db.QueryRow: not matched results") - } -} - -const CurrentTimeStamp = "2006-01-02 15:04:05" - -type TimeStamp struct{ *time.Time } - -func (t TimeStamp) Scan(value any) error { - var err error - switch v := value.(type) { - case string: - *t.Time, err = time.Parse(CurrentTimeStamp, v) - case []byte: - *t.Time, err = time.Parse(CurrentTimeStamp, string(v)) - default: - err = errors.New("invalid type for current_timestamp") - } - return err -} - -func (t TimeStamp) Value() (driver.Value, error) { - return t.Time.Format(CurrentTimeStamp), nil -} - -func TestDateTimeNow(t *testing.T) { - db, err := sql.Open(DriverName, ":memory:") - if err != nil { - t.Fatal("Failed to open database:", err) - } - defer db.Close() - - var d time.Time - err = db.QueryRow("SELECT datetime('now')").Scan(TimeStamp{&d}) - if err != nil { - t.Fatal("Failed to scan datetime:", err) - } -} - -type sumAggregator int64 - -func (s *sumAggregator) Step(x int64) { - *s += sumAggregator(x) -} - -func (s *sumAggregator) Done() int64 { - return int64(*s) -} - -func TestAggregatorRegistration(t *testing.T) { - customSum := func() *sumAggregator { - var ret sumAggregator - return &ret - } - - sql.Register("sqlite3_AggregatorRegistration", &SQLiteDriver{ - ConnectHook: func(conn *SQLiteConn) error { - return conn.RegisterAggregator("customSum", customSum, true) - }, - }) - db, err := sql.Open("sqlite3_AggregatorRegistration", ":memory:") - if err != nil { - t.Fatal("Failed to open database:", err) - } - defer db.Close() - - _, err = db.Exec("create table foo (department integer, profits integer)") - if err != nil { - // trace feature is not implemented - t.Skip("Failed to create table:", err) - } - - _, err = db.Exec("insert into foo values (1, 10), (1, 20), (2, 42)") - if err != nil { - t.Fatal("Failed to insert records:", err) - } - - tests := []struct { - dept, sum int64 - }{ - {1, 30}, - {2, 42}, - } - - for _, test := range tests { - var ret int64 - err = db.QueryRow("select customSum(profits) from foo where department = $1 group by department", test.dept).Scan(&ret) - if err != nil { - t.Fatal("Query failed:", err) - } - if ret != test.sum { - t.Fatalf("Custom sum returned wrong value, got %d, want %d", ret, test.sum) - } - } -} - -type mode struct { - counts map[any]int - top any - topCount int -} - -func newMode() *mode { - return &mode{ - counts: map[any]int{}, - } -} - -func (m *mode) Step(x any) { - m.counts[x]++ - c := m.counts[x] - if c > m.topCount { - m.top = x - m.topCount = c - } -} - -func (m *mode) Done() any { - return m.top -} - -func TestAggregatorRegistration_GenericReturn(t *testing.T) { - sql.Register("sqlite3_AggregatorRegistration_GenericReturn", &SQLiteDriver{ - ConnectHook: func(conn *SQLiteConn) error { - return conn.RegisterAggregator("mode", newMode, true) - }, - }) - db, err := sql.Open("sqlite3_AggregatorRegistration_GenericReturn", ":memory:") - if err != nil { - t.Fatal("Failed to open database:", err) - } - defer db.Close() - - _, err = db.Exec("create table foo (department integer, profits integer)") - if err != nil { - t.Fatal("Failed to create table:", err) - } - _, err = db.Exec("insert into foo values (1, 10), (1, 20), (1, 45), (2, 42), (2, 115), (2, 20)") - if err != nil { - t.Fatal("Failed to insert records:", err) - } - - var mode int - err = db.QueryRow("select mode(profits) from foo").Scan(&mode) - if err != nil { - t.Fatal("MODE query error:", err) - } - - if mode != 20 { - t.Fatal("Got incorrect mode. Wanted 20, got: ", mode) - } -} - -func rot13(r rune) rune { - switch { - case r >= 'A' && r <= 'Z': - return 'A' + (r-'A'+13)%26 - case r >= 'a' && r <= 'z': - return 'a' + (r-'a'+13)%26 - } - return r -} - -func TestCollationRegistration(t *testing.T) { - collateRot13 := func(a, b string) int { - ra, rb := strings.Map(rot13, a), strings.Map(rot13, b) - return strings.Compare(ra, rb) - } - collateRot13Reverse := func(a, b string) int { - return collateRot13(b, a) - } - - sql.Register("sqlite3_CollationRegistration", &SQLiteDriver{ - ConnectHook: func(conn *SQLiteConn) error { - if err := conn.RegisterCollation("rot13", collateRot13); err != nil { - return err - } - if err := conn.RegisterCollation("rot13reverse", collateRot13Reverse); err != nil { - return err - } - return nil - }, - }) - - db, err := sql.Open("sqlite3_CollationRegistration", ":memory:") - if err != nil { - t.Fatal("Failed to open database:", err) - } - defer db.Close() - - populate := []string{ - `CREATE TABLE test (s TEXT)`, - `INSERT INTO test VALUES ('aaaa')`, - `INSERT INTO test VALUES ('ffff')`, - `INSERT INTO test VALUES ('qqqq')`, - `INSERT INTO test VALUES ('tttt')`, - `INSERT INTO test VALUES ('zzzz')`, - } - for _, stmt := range populate { - if _, err := db.Exec(stmt); err != nil { - t.Fatal("Failed to populate test DB:", err) - } - } - - ops := []struct { - query string - want []string - }{ - { - "SELECT * FROM test ORDER BY s COLLATE rot13 ASC", - []string{ - "qqqq", - "tttt", - "zzzz", - "aaaa", - "ffff", - }, - }, - { - "SELECT * FROM test ORDER BY s COLLATE rot13 DESC", - []string{ - "ffff", - "aaaa", - "zzzz", - "tttt", - "qqqq", - }, - }, - { - "SELECT * FROM test ORDER BY s COLLATE rot13reverse ASC", - []string{ - "ffff", - "aaaa", - "zzzz", - "tttt", - "qqqq", - }, - }, - { - "SELECT * FROM test ORDER BY s COLLATE rot13reverse DESC", - []string{ - "qqqq", - "tttt", - "zzzz", - "aaaa", - "ffff", - }, - }, - } - - for _, op := range ops { - rows, err := db.Query(op.query) - if err != nil { - t.Fatalf("Query %q failed: %s", op.query, err) - } - got := []string{} - defer rows.Close() - for rows.Next() { - var s string - if err = rows.Scan(&s); err != nil { - t.Fatalf("Reading row for %q: %s", op.query, err) - } - got = append(got, s) - } - if err = rows.Err(); err != nil { - t.Fatalf("Reading rows for %q: %s", op.query, err) - } - - if !reflect.DeepEqual(got, op.want) { - t.Fatalf("Unexpected output from %q\ngot:\n%s\n\nwant:\n%s", op.query, strings.Join(got, "\n"), strings.Join(op.want, "\n")) - } - } -} - -func TestDeclTypes(t *testing.T) { - - d := SQLiteDriver{} - - conn, err := d.Open(":memory:") - if err != nil { - t.Fatal("Failed to begin transaction:", err) - } - defer conn.Close() - - sqlite3conn := conn.(*SQLiteConn) - - _, err = sqlite3conn.Exec("create table foo (id integer not null primary key, name text)", nil) - if err != nil { - t.Fatal("Failed to create table:", err) - } - - _, err = sqlite3conn.Exec("insert into foo(name) values('bar')", nil) - if err != nil { - t.Fatal("Failed to insert:", err) - } - - rs, err := sqlite3conn.Query("select * from foo", nil) - if err != nil { - t.Fatal("Failed to select:", err) - } - defer rs.Close() - - declTypes := rs.(*SQLiteRows).DeclTypes() - - if !reflect.DeepEqual(declTypes, []string{"integer", "text"}) { - t.Fatal("Unexpected declTypes:", declTypes) - } -} - -func TestPinger(t *testing.T) { - db, err := sql.Open(DriverName, ":memory:") - if err != nil { - t.Fatal(err) - } - err = db.Ping() - if err != nil { - t.Fatal(err) - } - db.Close() - err = db.Ping() - if err == nil { - t.Fatal("Should be closed") - } -} - -func TestUpdateAndTransactionHooks(t *testing.T) { - var events []string - var commitHookReturn = 0 - - sql.Register("sqlite3_UpdateHook", &SQLiteDriver{ - ConnectHook: func(conn *SQLiteConn) error { - conn.RegisterCommitHook(func() int { - events = append(events, "commit") - return commitHookReturn - }) - conn.RegisterRollbackHook(func() { - events = append(events, "rollback") - }) - conn.RegisterUpdateHook(func(op int, db string, table string, rowid int64) { - events = append(events, fmt.Sprintf("update(op=%v db=%v table=%v rowid=%v)", op, db, table, rowid)) - }) - return nil - }, - }) - db, err := sql.Open("sqlite3_UpdateHook", ":memory:") - if err != nil { - t.Fatal("Failed to open database:", err) - } - defer db.Close() - - statements := []string{ - "create table foo (id integer primary key)", - "insert into foo values (9)", - "update foo set id = 99 where id = 9", - "delete from foo where id = 99", - } - for _, statement := range statements { - _, err = db.Exec(statement) - if err != nil { - t.Fatalf("Unable to prepare test data [%v]: %v", statement, err) - } - } - - commitHookReturn = 1 - _, err = db.Exec("insert into foo values (5)") - if err == nil { - t.Error("Commit hook failed to rollback transaction") - } - - var expected = []string{ - "commit", - fmt.Sprintf("update(op=%v db=main table=foo rowid=9)", SQLITE_INSERT), - "commit", - fmt.Sprintf("update(op=%v db=main table=foo rowid=99)", SQLITE_UPDATE), - "commit", - fmt.Sprintf("update(op=%v db=main table=foo rowid=99)", SQLITE_DELETE), - "commit", - fmt.Sprintf("update(op=%v db=main table=foo rowid=5)", SQLITE_INSERT), - "commit", - "rollback", - } - if !reflect.DeepEqual(events, expected) { - t.Errorf("Expected notifications: %#v\nbut got: %#v\n", expected, events) - } -} - -func TestSetFileControlInt(t *testing.T) { - t.Run("PERSIST_WAL", func(t *testing.T) { - tempFilename := TempFilename(t) - defer os.Remove(tempFilename) - - sql.Register("sqlite3_FCNTL_PERSIST_WAL", &SQLiteDriver{ - ConnectHook: func(conn *SQLiteConn) error { - if err := conn.SetFileControlInt("", SQLITE_FCNTL_PERSIST_WAL, 1); err != nil { - return fmt.Errorf("Unexpected error from SetFileControlInt(): %w", err) - } - return nil - }, - }) - - db, err := sql.Open("sqlite3_FCNTL_PERSIST_WAL", tempFilename) - if err != nil { - t.Fatal("Failed to open database:", err) - } - defer db.Close() - - if _, err := db.Exec(`CREATE TABLE t (x)`); err != nil { - t.Fatal("Failed to create table:", err) - } - if err := db.Close(); err != nil { - t.Fatal("Failed to close database", err) - } - - // Ensure WAL file persists after close. - if _, err := os.Stat(tempFilename + "-wal"); err != nil { - t.Fatal("Expected WAL file to be persisted after close", err) - } - }) -} - -func TestNonColumnString(t *testing.T) { - db, err := sql.Open(DriverName, ":memory:") - if err != nil { - t.Fatal(err) - } - defer db.Close() - - var x any - if err := db.QueryRow("SELECT 'hello'").Scan(&x); err != nil { - t.Fatal(err) - } - s, ok := x.(string) - if !ok { - t.Fatalf("non-column string must return string but got %T", x) - } - if s != "hello" { - t.Fatalf("non-column string must return %q but got %q", "hello", s) - } -} - -func TestNilAndEmptyBytes(t *testing.T) { - db, err := sql.Open(DriverName, ":memory:") - if err != nil { - t.Fatal(err) - } - defer db.Close() - actualNil := []byte("use this to use an actual nil not a reference to nil") - emptyBytes := []byte{} - for tsti, tst := range []struct { - name string - columnType string - insertBytes []byte - expectedBytes []byte - }{ - {"actual nil blob", "blob", actualNil, nil}, - {"referenced nil blob", "blob", nil, nil}, - {"empty blob", "blob", emptyBytes, emptyBytes}, - {"actual nil text", "text", actualNil, nil}, - {"referenced nil text", "text", nil, nil}, - {"empty text", "text", emptyBytes, emptyBytes}, - } { - if _, err = db.Exec(fmt.Sprintf("create table tbl%d (txt %s)", tsti, tst.columnType)); err != nil { - t.Fatal(tst.name, err) - } - if bytes.Equal(tst.insertBytes, actualNil) { - if _, err = db.Exec(fmt.Sprintf("insert into tbl%d (txt) values (?)", tsti), nil); err != nil { - t.Fatal(tst.name, err) - } - } else { - if _, err = db.Exec(fmt.Sprintf("insert into tbl%d (txt) values (?)", tsti), &tst.insertBytes); err != nil { - t.Fatal(tst.name, err) - } - } - rows, err := db.Query(fmt.Sprintf("select txt from tbl%d", tsti)) - if err != nil { - t.Fatal(tst.name, err) - } - if !rows.Next() { - t.Fatal(tst.name, "no rows") - } - var scanBytes []byte - if err = rows.Scan(&scanBytes); err != nil { - t.Fatal(tst.name, err) - } - if err = rows.Err(); err != nil { - t.Fatal(tst.name, err) - } - if tst.expectedBytes == nil && scanBytes != nil { - t.Errorf("%s: %#v != %#v", tst.name, scanBytes, tst.expectedBytes) - } else if !bytes.Equal(scanBytes, tst.expectedBytes) { - t.Errorf("%s: %#v != %#v", tst.name, scanBytes, tst.expectedBytes) - } - } -} - -func TestInsertNilByteSlice(t *testing.T) { - db, err := sql.Open(DriverName, ":memory:") - if err != nil { - t.Fatal(err) - } - defer db.Close() - if _, err := db.Exec("create table blob_not_null (b blob not null)"); err != nil { - t.Fatal(err) - } - var nilSlice []byte - if _, err := db.Exec("insert into blob_not_null (b) values (?)", nilSlice); err == nil { - t.Fatal("didn't expect INSERT to 'not null' column with a nil []byte slice to work") - } - zeroLenSlice := []byte{} - if _, err := db.Exec("insert into blob_not_null (b) values (?)", zeroLenSlice); err != nil { - t.Fatal("failed to insert zero-length slice") - } -} - -func TestNamedParam(t *testing.T) { - tempFilename := TempFilename(t) - defer os.Remove(tempFilename) - db, err := sql.Open(DriverName, ":memory:") - if err != nil { - t.Fatal("Failed to open database:", err) - } - defer db.Close() - - _, err = db.Exec("drop table foo") - _, err = db.Exec("create table foo (id integer, name text, amount integer)") - if err != nil { - t.Fatal("Failed to create table:", err) - } - - _, err = db.Exec("insert into foo(id, name, amount) values(:id, @name, $amount)", - sql.Named("bar", 42), sql.Named("baz", "quux"), - sql.Named("amount", 123), sql.Named("corge", "waldo"), - sql.Named("id", 2), sql.Named("name", "grault")) - if err != nil { - t.Fatal("Failed to insert record with named parameters:", err) - } - - rows, err := db.Query("select id, name, amount from foo") - if err != nil { - t.Fatal("Failed to select records:", err) - } - defer rows.Close() - - rows.Next() - - var id, amount int - var name string - rows.Scan(&id, &name, &amount) - if id != 2 || name != "grault" || amount != 123 { - t.Errorf("Expected %d, %q, %d for fetched result, but got %d, %q, %d:", 2, "grault", 123, id, name, amount) - } -} - -var customFunctionOnce sync.Once - -func TestSuite(t *testing.T) { - initializeTestDB(t) - defer freeTestDB() - - for _, test := range tests { - t.Run(test.Name, test.F) - } -} - -// DB provide context for the tests -type TestDB struct { - testing.TB - *sql.DB - once sync.Once - tempFilename string -} - -var tdb *TestDB - -func initializeTestDB(t testing.TB) { - tempFilename := TempFilename(t) - d, err := sql.Open(DriverName, tempFilename+"?mode=memory&cache=shared") - if err != nil { - os.Remove(tempFilename) - t.Fatal(err) - } - - tdb = &TestDB{t, d, sync.Once{}, tempFilename} -} - -func freeTestDB() { - err := tdb.DB.Close() - if err != nil { - panic(err) - } - err = os.Remove(tdb.tempFilename) - if err != nil { - panic(err) - } -} - -// the following tables will be created and dropped during the test -var testTables = []string{"foo"} - -var tests = []testing.InternalTest{ - {Name: "TestResult", F: testResult}, - {Name: "TestBlobs", F: testBlobs}, - {Name: "TestMultiBlobs", F: testMultiBlobs}, - {Name: "TestNullZeroLengthBlobs", F: testNullZeroLengthBlobs}, - {Name: "TestManyQueryRow", F: testManyQueryRow}, - {Name: "TestTxQuery", F: testTxQuery}, - {Name: "TestPreparedStmt", F: testPreparedStmt}, - {Name: "TestExecEmptyQuery", F: testExecEmptyQuery}, -} - -func (db *TestDB) mustExec(sql string, args ...any) sql.Result { - res, err := db.Exec(sql, args...) - if err != nil { - db.Fatalf("Error running %q: %v", sql, err) - } - return res -} - -func (db *TestDB) tearDown() { - for _, tbl := range testTables { - db.mustExec("drop table if exists " + tbl) - } -} - -// testResult is test for result -func testResult(t *testing.T) { - tdb.tearDown() - tdb.mustExec("create temporary table test (id integer primary key autoincrement, name varchar(10))") - - for i := 1; i < 3; i++ { - r := tdb.mustExec("insert into test (name) values (?)", fmt.Sprintf("row %d", i)) - n, err := r.RowsAffected() - if err != nil { - t.Fatal(err) - } - if n != 1 { - t.Errorf("got %v, want %v", n, 1) - } - n, err = r.LastInsertId() - if err != nil { - t.Fatal(err) - } - if n != int64(i) { - t.Errorf("got %v, want %v", n, i) - } - } - if _, err := tdb.Exec("error!"); err == nil { - t.Fatalf("expected error") - } -} - -// testBlobs is test for blobs -func testBlobs(t *testing.T) { - tdb.tearDown() - var blob = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} - tdb.mustExec("create table foo (id integer primary key, bar blob[16])") - tdb.mustExec("insert into foo (id, bar) values(?,?)", 0, blob) - - want := fmt.Sprintf("%x", blob) - - b := make([]byte, 16) - err := tdb.QueryRow("select bar from foo where id = ?", 0).Scan(&b) - got := fmt.Sprintf("%x", b) - if err != nil { - t.Errorf("[]byte scan: %v", err) - } else if got != want { - t.Errorf("for []byte, got %q; want %q", got, want) - } - - err = tdb.QueryRow("select bar from foo where id = ?", 0).Scan(&got) - want = string(blob) - if err != nil { - t.Errorf("string scan: %v", err) - } else if got != want { - t.Errorf("for string, got %q; want %q", got, want) - } -} - -func testMultiBlobs(t *testing.T) { - tdb.tearDown() - tdb.mustExec("create table foo (id integer primary key, bar blob[16])") - var blob0 = []byte{0, 1, 2, 3, 4, 5, 6, 7} - tdb.mustExec("insert into foo (id, bar) values(?,?)", 0, blob0) - var blob1 = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} - tdb.mustExec("insert into foo (id, bar) values(?,?)", 1, blob1) - - r, err := tdb.Query("select bar from foo order by id") - if err != nil { - t.Fatal(err) - } - defer r.Close() - if !r.Next() { - if r.Err() != nil { - t.Fatal(err) - } - t.Fatal("expected one rows") - } - - want0 := fmt.Sprintf("%x", blob0) - b0 := make([]byte, 8) - err = r.Scan(&b0) - if err != nil { - t.Fatal(err) - } - got0 := fmt.Sprintf("%x", b0) - - if !r.Next() { - if r.Err() != nil { - t.Fatal(err) - } - t.Fatal("expected one rows") - } - - want1 := fmt.Sprintf("%x", blob1) - b1 := make([]byte, 16) - err = r.Scan(&b1) - if err != nil { - t.Fatal(err) - } - got1 := fmt.Sprintf("%x", b1) - if got0 != want0 { - t.Errorf("for []byte, got %q; want %q", got0, want0) - } - if got1 != want1 { - t.Errorf("for []byte, got %q; want %q", got1, want1) - } -} - -// testBlobs tests that we distinguish between null and zero-length blobs -func testNullZeroLengthBlobs(t *testing.T) { - tdb.tearDown() - tdb.mustExec("create table foo (id integer primary key, bar blob[16])") - tdb.mustExec("insert into foo (id, bar) values(?,?)", 0, nil) - tdb.mustExec("insert into foo (id, bar) values(?,?)", 1, []byte{}) - - r0 := tdb.QueryRow("select bar from foo where id=0") - var b0 []byte - err := r0.Scan(&b0) - if err != nil { - t.Fatal(err) - } - if b0 != nil { - t.Errorf("for id=0, got %x; want nil", b0) - } - - r1 := tdb.QueryRow("select bar from foo where id=1") - var b1 []byte - err = r1.Scan(&b1) - if err != nil { - t.Fatal(err) - } - if b1 == nil { - t.Error("for id=1, got nil; want zero-length slice") - } else if len(b1) > 0 { - t.Errorf("for id=1, got %x; want zero-length slice", b1) - } -} - -func testManyQueryRow(t *testing.T) { - // FIXME: too slow - tdb.tearDown() - tdb.mustExec("create table foo (id integer primary key, name varchar(50))") - tdb.mustExec("insert into foo (id, name) values(?,?)", 1, "bob") - var name string - for i := 0; i < 10000; i++ { - err := tdb.QueryRow("select name from foo where id = ?", 1).Scan(&name) - if err != nil || name != "bob" { - t.Fatalf("on query %d: err=%v, name=%q", i, err, name) - } - } -} - -func testTxQuery(t *testing.T) { - tdb.tearDown() - tx, err := tdb.Begin() - if err != nil { - t.Fatal(err) - } - defer tx.Rollback() - - _, err = tx.Exec("create table foo (id integer primary key, name varchar(50))") - if err != nil { - t.Fatal(err) - } - - _, err = tx.Exec("insert into foo (id, name) values(?,?)", 1, "bob") - if err != nil { - t.Fatal(err) - } - - r, err := tx.Query("select name from foo where id = ?", 1) - if err != nil { - t.Fatal(err) - } - defer r.Close() - - if !r.Next() { - if r.Err() != nil { - t.Fatal(err) - } - t.Fatal("expected one rows") - } - - var name string - err = r.Scan(&name) - if err != nil { - t.Fatal(err) - } -} - -func testPreparedStmt(t *testing.T) { - tdb.tearDown() - tdb.mustExec("CREATE TABLE t (count INT)") - sel, err := tdb.Prepare("SELECT count FROM t ORDER BY count DESC") - if err != nil { - t.Fatalf("prepare 1: %v", err) - } - ins, err := tdb.Prepare("INSERT INTO t (count) VALUES (?)") - if err != nil { - t.Fatalf("prepare 2: %v", err) - } - - for n := 1; n <= 3; n++ { - if _, err := ins.Exec(n); err != nil { - t.Fatalf("insert(%d) = %v", n, err) - } - } - - const nRuns = 10 - var wg sync.WaitGroup - for i := 0; i < nRuns; i++ { - wg.Add(1) - go func() { - defer wg.Done() - for j := 0; j < 10; j++ { - count := 0 - if err := sel.QueryRow().Scan(&count); err != nil && err != sql.ErrNoRows { - t.Errorf("Query: %v", err) - return - } - if _, err := ins.Exec(rand.Intn(100)); err != nil { - t.Errorf("Insert: %v", err) - return - } - } - }() - } - wg.Wait() -} - -// testEmptyQuery is test for validating the API in case of empty query -func testExecEmptyQuery(t *testing.T) { - tdb.tearDown() - res, err := tdb.Exec(" -- this is just a comment ") - if err != nil { - t.Fatalf("empty query err: %v", err) - } - - _, err = res.LastInsertId() - if err != nil { - t.Fatalf("LastInsertId returned an error: %v", err) - } - - _, err = res.RowsAffected() - if err != nil { - t.Fatalf("RowsAffected returned an error: %v", err) - } -} - - - -func MainTest() { - tests := []testing.InternalTest { - { "TestBackupStepByStep", TestBackupStepByStep }, - { "TestBackupAllRemainingPages", TestBackupAllRemainingPages }, - { "TestBackupError", TestBackupError }, - { "TestCallbackArgCast", TestCallbackArgCast }, - { "TestCallbackConverters", TestCallbackConverters }, - { "TestCallbackReturnAny", TestCallbackReturnAny }, - { "TestSimpleError", TestSimpleError }, - { "TestCorruptDbErrors", TestCorruptDbErrors }, - { "TestSqlLogicErrors", TestSqlLogicErrors }, - { "TestExtendedErrorCodes_ForeignKey", TestExtendedErrorCodes_ForeignKey }, - { "TestExtendedErrorCodes_NotNull", TestExtendedErrorCodes_NotNull }, - { "TestExtendedErrorCodes_Unique", TestExtendedErrorCodes_Unique }, - { "TestError_SystemErrno", TestError_SystemErrno }, - { "TestBeginTxCancel", TestBeginTxCancel }, - { "TestStmtReadonly", TestStmtReadonly }, - { "TestNamedParams", TestNamedParams }, - { "TestShortTimeout", TestShortTimeout }, - { "TestExecContextCancel", TestExecContextCancel }, - { "TestQueryRowContextCancel", TestQueryRowContextCancel }, - { "TestQueryRowContextCancelParallel", TestQueryRowContextCancelParallel }, - { "TestExecCancel", TestExecCancel }, - { "TestFileCopyTruncate", TestFileCopyTruncate }, - { "TestColumnTableName", TestColumnTableName }, - { "TestFTS5", TestFTS5 }, - { "TestSerializeDeserialize", TestSerializeDeserialize }, - { "TestOpenWithVFS", TestOpenWithVFS }, - { "TestOpenNoCreate", TestOpenNoCreate }, - { "TestReadonly", TestReadonly }, - { "TestDeferredForeignKey", TestDeferredForeignKey }, - { "TestClose", TestClose }, - { "TestInsert", TestInsert }, - { "TestUpsert", TestUpsert }, - { "TestUpdate", TestUpdate }, - { "TestDelete", TestDelete }, - { "TestBooleanRoundtrip", TestBooleanRoundtrip }, - { "TestTimestamp", TestTimestamp }, - { "TestBoolean", TestBoolean }, - { "TestFloat32", TestFloat32 }, - { "TestNull", TestNull }, - { "TestTransaction", TestTransaction }, - { "TestWAL", TestWAL }, - { "TestTimezoneConversion", TestTimezoneConversion }, - { "TestExecer", TestExecer }, - { "TestQueryer", TestQueryer }, - { "TestStress", TestStress }, - { "TestDateTimeLocal", TestDateTimeLocal }, - { "TestVersion", TestVersion }, - { "TestStringContainingZero", TestStringContainingZero }, - { "TestDateTimeNow", TestDateTimeNow }, - { "TestAggregatorRegistration", TestAggregatorRegistration }, - { "TestAggregatorRegistration_GenericReturn", TestAggregatorRegistration_GenericReturn }, - { "TestCollationRegistration", TestCollationRegistration }, - { "TestDeclTypes", TestDeclTypes }, - { "TestPinger", TestPinger }, - { "TestUpdateAndTransactionHooks", TestUpdateAndTransactionHooks }, - { "TestSetFileControlInt", TestSetFileControlInt }, - { "TestNonColumnString", TestNonColumnString }, - { "TestNilAndEmptyBytes", TestNilAndEmptyBytes }, - { "TestInsertNilByteSlice", TestInsertNilByteSlice }, - { "TestNamedParam", TestNamedParam }, - { "TestSuite", TestSuite }, // FIXME: too slow - } - - deps := testdeps.TestDeps{} - benchmarks := []testing.InternalBenchmark {} - fuzzTargets := []testing.InternalFuzzTarget{} - examples := []testing.InternalExample {} - m := testing.MainStart(deps, tests, benchmarks, fuzzTargets, examples) - os.Exit(m.Run()) -} diff --git a/tests/benchmarks/exec/acudego.go b/tests/benchmarks/exec/acudego.go deleted file mode 100644 index 79e7992..0000000 --- a/tests/benchmarks/exec/acudego.go +++ /dev/null @@ -1,31 +0,0 @@ -package acudego - -import ( - "database/sql" - "flag" -) - - - -var nFlag = flag.Int( - "n", - 1_000_000, - "The number of iterations to execute", -) - -func MainTest() { - flag.Parse() - n := *nFlag - - db, err := sql.Open(DriverName, "file:bench.db?mode=memory&cache=shared") - if err != nil { - panic(err) - } - - for i := 0; i < n; i++ { - _, err = db.Exec("SELECT 1;") - if err != nil { - panic(err) - } - } -} diff --git a/tests/benchmarks/exec/golite.go b/tests/benchmarks/exec/golite.go new file mode 100644 index 0000000..f861c3c --- /dev/null +++ b/tests/benchmarks/exec/golite.go @@ -0,0 +1,31 @@ +package golite + +import ( + "database/sql" + "flag" +) + + + +var nFlag = flag.Int( + "n", + 1_000_000, + "The number of iterations to execute", +) + +func MainTest() { + flag.Parse() + n := *nFlag + + db, err := sql.Open(DriverName, "file:bench.db?mode=memory&cache=shared") + if err != nil { + panic(err) + } + + for i := 0; i < n; i++ { + _, err = db.Exec("SELECT 1;") + if err != nil { + panic(err) + } + } +} diff --git a/tests/benchmarks/query/acudego.go b/tests/benchmarks/query/acudego.go deleted file mode 100644 index 3386576..0000000 --- a/tests/benchmarks/query/acudego.go +++ /dev/null @@ -1,42 +0,0 @@ -package acudego - -import ( - "database/sql" - "flag" -) - - - -var nFlag = flag.Int( - "n", - 100_000, - "The number of iterations to execute", -) - -func MainTest() { - flag.Parse() - n := *nFlag - - db, err := sql.Open(DriverName, "file:benchdb?mode=memory&cache=shared") - if err != nil { - panic(err) - } - - var ( - S sql.NullString - I int - f float64 - s string - ) - for i := 0; i < n; i++ { - err = db.QueryRow("SELECT NULL, 1, 1.1, 'string';").Scan( - &S, - &I, - &f, - &s, - ) - if err != nil { - panic(err) - } - } -} diff --git a/tests/benchmarks/query/golite.go b/tests/benchmarks/query/golite.go new file mode 100644 index 0000000..a900356 --- /dev/null +++ b/tests/benchmarks/query/golite.go @@ -0,0 +1,42 @@ +package golite + +import ( + "database/sql" + "flag" +) + + + +var nFlag = flag.Int( + "n", + 100_000, + "The number of iterations to execute", +) + +func MainTest() { + flag.Parse() + n := *nFlag + + db, err := sql.Open(DriverName, "file:benchdb?mode=memory&cache=shared") + if err != nil { + panic(err) + } + + var ( + S sql.NullString + I int + f float64 + s string + ) + for i := 0; i < n; i++ { + err = db.QueryRow("SELECT NULL, 1, 1.1, 'string';").Scan( + &S, + &I, + &f, + &s, + ) + if err != nil { + panic(err) + } + } +} diff --git a/tests/functional/json/acudego.go b/tests/functional/json/acudego.go deleted file mode 100644 index 8f5e923..0000000 --- a/tests/functional/json/acudego.go +++ /dev/null @@ -1,98 +0,0 @@ -package acudego - -import ( - "database/sql" - "database/sql/driver" - "encoding/json" - "errors" - "log" - "os" -) - - - -type Tag struct { - Name string `json:"name"` - Place string `json:"place"` -} - -func (t *Tag) Scan(value interface{}) error { - return json.Unmarshal([]byte(value.(string)), t) -} - -func (t *Tag) Value() (driver.Value, error) { - b, err := json.Marshal(t) - return string(b), err -} - - - -func MainTest() { - os.Remove("json.db") - defer os.Remove("json.db") - - db, err := sql.Open(DriverName, "json.db") - if err != nil { - log.Fatal(err) - } - defer db.Close() - - _, err = db.Exec(`create table myjsontable (tag jsonb)`) - if err != nil { - log.Fatal(err) - } - - stmt, err := db.Prepare("insert into myjsontable(tag) values(?)") - if err != nil { - log.Fatal(err) - } - defer stmt.Close() - - _, err = stmt.Exec(`{"name": "name1", "place": "right-here"}`) - if err != nil { - log.Fatal(err) - } - - _, err = stmt.Exec(`{"name": "michael", "place": "usa"}`) - if err != nil { - log.Fatal(err) - } - - var place string - err = db.QueryRow("select tag->>'place' from myjsontable where tag->>'name' = 'name1'").Scan(&place) - if err != nil { - log.Fatal(err) - } - - if place != "right-here" { - log.Fatal(errors.New("expected right-here, got: " + place)) - } - - var tag Tag - err = db.QueryRow("select tag from myjsontable where tag->>'name' = 'name1'").Scan(&tag) - if err != nil { - log.Fatal(err) - } - - if tag.Name != "name1" { - log.Fatal(errors.New("expected name1, got: " + tag.Name)) - } - if tag.Place != "right-here" { - log.Fatal(errors.New("expected right-here, got: " + tag.Place)) - } - - tag.Place = "日本" - _, err = db.Exec(`update myjsontable set tag = ? where tag->>'name' == 'name1'`, &tag) - if err != nil { - log.Fatal(err) - } - - err = db.QueryRow("select tag->>'place' from myjsontable where tag->>'name' = 'name1'").Scan(&place) - if err != nil { - log.Fatal(err) - } - - if place != "日本" { - log.Fatal(errors.New("expected 日本, got: " + place)) - } -} diff --git a/tests/functional/json/golite.go b/tests/functional/json/golite.go new file mode 100644 index 0000000..0fa696e --- /dev/null +++ b/tests/functional/json/golite.go @@ -0,0 +1,98 @@ +package golite + +import ( + "database/sql" + "database/sql/driver" + "encoding/json" + "errors" + "log" + "os" +) + + + +type Tag struct { + Name string `json:"name"` + Place string `json:"place"` +} + +func (t *Tag) Scan(value interface{}) error { + return json.Unmarshal([]byte(value.(string)), t) +} + +func (t *Tag) Value() (driver.Value, error) { + b, err := json.Marshal(t) + return string(b), err +} + + + +func MainTest() { + os.Remove("json.db") + defer os.Remove("json.db") + + db, err := sql.Open(DriverName, "json.db") + if err != nil { + log.Fatal(err) + } + defer db.Close() + + _, err = db.Exec(`create table myjsontable (tag jsonb)`) + if err != nil { + log.Fatal(err) + } + + stmt, err := db.Prepare("insert into myjsontable(tag) values(?)") + if err != nil { + log.Fatal(err) + } + defer stmt.Close() + + _, err = stmt.Exec(`{"name": "name1", "place": "right-here"}`) + if err != nil { + log.Fatal(err) + } + + _, err = stmt.Exec(`{"name": "michael", "place": "usa"}`) + if err != nil { + log.Fatal(err) + } + + var place string + err = db.QueryRow("select tag->>'place' from myjsontable where tag->>'name' = 'name1'").Scan(&place) + if err != nil { + log.Fatal(err) + } + + if place != "right-here" { + log.Fatal(errors.New("expected right-here, got: " + place)) + } + + var tag Tag + err = db.QueryRow("select tag from myjsontable where tag->>'name' = 'name1'").Scan(&tag) + if err != nil { + log.Fatal(err) + } + + if tag.Name != "name1" { + log.Fatal(errors.New("expected name1, got: " + tag.Name)) + } + if tag.Place != "right-here" { + log.Fatal(errors.New("expected right-here, got: " + tag.Place)) + } + + tag.Place = "日本" + _, err = db.Exec(`update myjsontable set tag = ? where tag->>'name' == 'name1'`, &tag) + if err != nil { + log.Fatal(err) + } + + err = db.QueryRow("select tag->>'place' from myjsontable where tag->>'name' = 'name1'").Scan(&place) + if err != nil { + log.Fatal(err) + } + + if place != "日本" { + log.Fatal(errors.New("expected 日本, got: " + place)) + } +} diff --git a/tests/functional/limit/acudego.go b/tests/functional/limit/acudego.go deleted file mode 100644 index d47dc27..0000000 --- a/tests/functional/limit/acudego.go +++ /dev/null @@ -1,104 +0,0 @@ -package acudego - -import ( - "database/sql" - "fmt" - "log" - "os" - "strings" -) - - - -func createBulkInsertQuery(n int, start int) (string, []any) { - values := make([]string, n) - args := make([]any, n*2) - pos := 0 - for i := 0; i < n; i++ { - values[i] = "(?, ?)" - args[pos] = start + i - args[pos+1] = fmt.Sprintf("こんにちは世界%03d", i) - pos += 2 - } - query := fmt.Sprintf( - "insert into mylimittable(id, name) values %s", - strings.Join(values, ", "), - ) - return query, args -} - -func bulkInsert(db *sql.DB, query string, args []any) error { - stmt, err := db.Prepare(query) - if err != nil { - return err - } - - _, err = stmt.Exec(args...) - return err -} - - - -func MainTest() { - const ( - num = 400 - smallLimit = 100 - bigLimit = 999999 - ) - - const SQL = ` - create table mylimittable (id integer not null primary key, name text); - delete from mylimittable; - ` - - var conn *SQLiteConn - sql.Register("sqlite3_with_limit", &SQLiteDriver{ - ConnectHook: func(c *SQLiteConn) error { - conn = c - return nil - }, - }) - - os.Remove("limit.db") - defer os.Remove("limit.db") - db, err := sql.Open("sqlite3_with_limit", "limit.db") - if err != nil { - log.Fatal(err) - } - defer db.Close() - - _, err = db.Exec(SQL) - if err != nil { - log.Fatal(err) - } - - if conn == nil { - log.Fatal("not set sqlite3 connection") - } - - { - query, args := createBulkInsertQuery(num, 0) - err := bulkInsert(db, query, args) - if err != nil { - log.Fatal(err) - } - } - - conn.SetLimit(SQLITE_LIMIT_VARIABLE_NUMBER, smallLimit) - { - query, args := createBulkInsertQuery(num, num) - err := bulkInsert(db, query, args) - if err == nil { - log.Fatal("expected failure didn't happen") - } - } - - conn.SetLimit(SQLITE_LIMIT_VARIABLE_NUMBER, bigLimit) - { - query, args := createBulkInsertQuery(500, num+num) - err := bulkInsert(db, query, args) - if err != nil { - log.Fatal(err) - } - } -} diff --git a/tests/functional/limit/golite.go b/tests/functional/limit/golite.go new file mode 100644 index 0000000..231e8c2 --- /dev/null +++ b/tests/functional/limit/golite.go @@ -0,0 +1,104 @@ +package golite + +import ( + "database/sql" + "fmt" + "log" + "os" + "strings" +) + + + +func createBulkInsertQuery(n int, start int) (string, []any) { + values := make([]string, n) + args := make([]any, n*2) + pos := 0 + for i := 0; i < n; i++ { + values[i] = "(?, ?)" + args[pos] = start + i + args[pos+1] = fmt.Sprintf("こんにちは世界%03d", i) + pos += 2 + } + query := fmt.Sprintf( + "insert into mylimittable(id, name) values %s", + strings.Join(values, ", "), + ) + return query, args +} + +func bulkInsert(db *sql.DB, query string, args []any) error { + stmt, err := db.Prepare(query) + if err != nil { + return err + } + + _, err = stmt.Exec(args...) + return err +} + + + +func MainTest() { + const ( + num = 400 + smallLimit = 100 + bigLimit = 999999 + ) + + const SQL = ` + create table mylimittable (id integer not null primary key, name text); + delete from mylimittable; + ` + + var conn *SQLiteConn + sql.Register("sqlite3_with_limit", &SQLiteDriver{ + ConnectHook: func(c *SQLiteConn) error { + conn = c + return nil + }, + }) + + os.Remove("limit.db") + defer os.Remove("limit.db") + db, err := sql.Open("sqlite3_with_limit", "limit.db") + if err != nil { + log.Fatal(err) + } + defer db.Close() + + _, err = db.Exec(SQL) + if err != nil { + log.Fatal(err) + } + + if conn == nil { + log.Fatal("not set sqlite3 connection") + } + + { + query, args := createBulkInsertQuery(num, 0) + err := bulkInsert(db, query, args) + if err != nil { + log.Fatal(err) + } + } + + conn.SetLimit(SQLITE_LIMIT_VARIABLE_NUMBER, smallLimit) + { + query, args := createBulkInsertQuery(num, num) + err := bulkInsert(db, query, args) + if err == nil { + log.Fatal("expected failure didn't happen") + } + } + + conn.SetLimit(SQLITE_LIMIT_VARIABLE_NUMBER, bigLimit) + { + query, args := createBulkInsertQuery(500, num+num) + err := bulkInsert(db, query, args) + if err != nil { + log.Fatal(err) + } + } +} diff --git a/tests/fuzz/api/acudego.go b/tests/fuzz/api/acudego.go deleted file mode 100644 index 7f8e3a0..0000000 --- a/tests/fuzz/api/acudego.go +++ /dev/null @@ -1,34 +0,0 @@ -package acudego - -import ( - "os" - "testing" - "testing/internal/testdeps" -) - - - -func FuzzAPI(f *testing.F) { - f.Add(123) - f.Fuzz(func(t *testing.T, n int) { - // FIXME - if n == 1234 { - t.Errorf("Failed n: %q\n", n) - } - }) -} - - - -func MainTest() { - fuzzTargets := []testing.InternalFuzzTarget{ - { "FuzzAPI", FuzzAPI }, - } - - deps := testdeps.TestDeps{} - tests := []testing.InternalTest {} - benchmarks := []testing.InternalBenchmark{} - examples := []testing.InternalExample {} - m := testing.MainStart(deps, tests, benchmarks, fuzzTargets, examples) - os.Exit(m.Run()) -} diff --git a/tests/fuzz/api/golite.go b/tests/fuzz/api/golite.go new file mode 100644 index 0000000..1c86d04 --- /dev/null +++ b/tests/fuzz/api/golite.go @@ -0,0 +1,34 @@ +package golite + +import ( + "os" + "testing" + "testing/internal/testdeps" +) + + + +func FuzzAPI(f *testing.F) { + f.Add(123) + f.Fuzz(func(t *testing.T, n int) { + // FIXME + if n == 1234 { + t.Errorf("Failed n: %q\n", n) + } + }) +} + + + +func MainTest() { + fuzzTargets := []testing.InternalFuzzTarget{ + { "FuzzAPI", FuzzAPI }, + } + + deps := testdeps.TestDeps{} + tests := []testing.InternalTest {} + benchmarks := []testing.InternalBenchmark{} + examples := []testing.InternalExample {} + m := testing.MainStart(deps, tests, benchmarks, fuzzTargets, examples) + os.Exit(m.Run()) +} diff --git a/tests/golite.go b/tests/golite.go new file mode 100644 index 0000000..2459993 --- /dev/null +++ b/tests/golite.go @@ -0,0 +1,3652 @@ +package golite + +import ( + "bytes" + "context" + "database/sql" + "database/sql/driver" + "errors" + "fmt" + "io/ioutil" + "math" + "math/rand" + "net/url" + "os" + "path" + "reflect" + "strings" + "sync" + "testing" + "testing/internal/testdeps" + "time" +) + + + +// The number of rows of test data to create in the source database. +// Can be used to control how many pages are available to be backed up. +const testRowCount = 100 + +// The maximum number of seconds after which the page-by-page backup is considered to have taken too long. +const usePagePerStepsTimeoutSeconds = 30 + +// Test the backup functionality. +func testBackup(t *testing.T, testRowCount int, usePerPageSteps bool) { + // This function will be called multiple times. + // It uses sql.Register(), which requires the name parameter value to be unique. + // There does not currently appear to be a way to unregister a registered driver, however. + // So generate a database driver name that will likely be unique. + var driverName = fmt.Sprintf("sqlite3_testBackup_%v_%v_%v", testRowCount, usePerPageSteps, time.Now().UnixNano()) + + // The driver's connection will be needed in order to perform the backup. + driverConns := []*SQLiteConn{} + sql.Register(driverName, &SQLiteDriver{ + ConnectHook: func(conn *SQLiteConn) error { + driverConns = append(driverConns, conn) + return nil + }, + }) + + // Connect to the source database. + srcTempFilename := "file:src?mode=memory&cache=shared" + srcDb, err := sql.Open(driverName, srcTempFilename) + if err != nil { + t.Fatal("Failed to open the source database:", err) + } + defer srcDb.Close() + err = srcDb.Ping() + if err != nil { + t.Fatal("Failed to connect to the source database:", err) + } + + // Connect to the destination database. + destTempFilename := "file:dst?mode=memory&cache=shared" + destDb, err := sql.Open(driverName, destTempFilename) + if err != nil { + t.Fatal("Failed to open the destination database:", err) + } + defer destDb.Close() + err = destDb.Ping() + if err != nil { + t.Fatal("Failed to connect to the destination database:", err) + } + + // Check the driver connections. + if len(driverConns) != 2 { + t.Fatalf("Expected 2 driver connections, but found %v.", len(driverConns)) + } + srcDbDriverConn := driverConns[0] + if srcDbDriverConn == nil { + t.Fatal("The source database driver connection is nil.") + } + destDbDriverConn := driverConns[1] + if destDbDriverConn == nil { + t.Fatal("The destination database driver connection is nil.") + } + + // Generate some test data for the given ID. + var generateTestData = func(id int) string { + return fmt.Sprintf("test-%v", id) + } + + // Populate the source database with a test table containing some test data. + tx, err := srcDb.Begin() + if err != nil { + t.Fatal("Failed to begin a transaction when populating the source database:", err) + } + _, err = srcDb.Exec("CREATE TABLE test (id INTEGER PRIMARY KEY, value TEXT)") + if err != nil { + tx.Rollback() + t.Fatal("Failed to create the source database \"test\" table:", err) + } + for id := 0; id < testRowCount; id++ { + _, err = srcDb.Exec("INSERT INTO test (id, value) VALUES (?, ?)", id, generateTestData(id)) + if err != nil { + tx.Rollback() + t.Fatal("Failed to insert a row into the source database \"test\" table:", err) + } + } + err = tx.Commit() + if err != nil { + t.Fatal("Failed to populate the source database:", err) + } + + // Confirm that the destination database is initially empty. + var destTableCount int + err = destDb.QueryRow("SELECT COUNT(*) FROM sqlite_master WHERE type = 'table'").Scan(&destTableCount) + if err != nil { + t.Fatal("Failed to check the destination table count:", err) + } + if destTableCount != 0 { + t.Fatalf("The destination database is not empty; %v table(s) found.", destTableCount) + } + + // Prepare to perform the backup. + backup, err := destDbDriverConn.Backup("main", srcDbDriverConn, "main") + if err != nil { + t.Fatal("Failed to initialize the backup:", err) + } + + // Allow the initial page count and remaining values to be retrieved. + // According to , the page count and remaining values are "... only updated by sqlite3_backup_step()." + isDone, err := backup.Step(0) + if err != nil { + t.Fatal("Unable to perform an initial 0-page backup step:", err) + } + if isDone { + t.Fatal("Backup is unexpectedly done.") + } + + // Check that the page count and remaining values are reasonable. + initialPageCount := backup.PageCount() + if initialPageCount <= 0 { + t.Fatalf("Unexpected initial page count value: %v", initialPageCount) + } + initialRemaining := backup.Remaining() + if initialRemaining <= 0 { + t.Fatalf("Unexpected initial remaining value: %v", initialRemaining) + } + if initialRemaining != initialPageCount { + t.Fatalf("Initial remaining value differs from the initial page count value; remaining: %v; page count: %v", initialRemaining, initialPageCount) + } + + // Perform the backup. + if usePerPageSteps { + var startTime = time.Now().Unix() + + // Test backing-up using a page-by-page approach. + var latestRemaining = initialRemaining + for { + // Perform the backup step. + isDone, err = backup.Step(1) + if err != nil { + t.Fatal("Failed to perform a backup step:", err) + } + + // The page count should remain unchanged from its initial value. + currentPageCount := backup.PageCount() + if currentPageCount != initialPageCount { + t.Fatalf("Current page count differs from the initial page count; initial page count: %v; current page count: %v", initialPageCount, currentPageCount) + } + + // There should now be one less page remaining. + currentRemaining := backup.Remaining() + expectedRemaining := latestRemaining - 1 + if currentRemaining != expectedRemaining { + t.Fatalf("Unexpected remaining value; expected remaining value: %v; actual remaining value: %v", expectedRemaining, currentRemaining) + } + latestRemaining = currentRemaining + + if isDone { + break + } + + // Limit the runtime of the backup attempt. + if (time.Now().Unix() - startTime) > usePagePerStepsTimeoutSeconds { + t.Fatal("Backup is taking longer than expected.") + } + } + } else { + // Test the copying of all remaining pages. + isDone, err = backup.Step(-1) + if err != nil { + t.Fatal("Failed to perform a backup step:", err) + } + if !isDone { + t.Fatal("Backup is unexpectedly not done.") + } + } + + // Check that the page count and remaining values are reasonable. + finalPageCount := backup.PageCount() + if finalPageCount != initialPageCount { + t.Fatalf("Final page count differs from the initial page count; initial page count: %v; final page count: %v", initialPageCount, finalPageCount) + } + finalRemaining := backup.Remaining() + if finalRemaining != 0 { + t.Fatalf("Unexpected remaining value: %v", finalRemaining) + } + + // Finish the backup. + err = backup.Finish() + if err != nil { + t.Fatal("Failed to finish backup:", err) + } + + // Confirm that the "test" table now exists in the destination database. + var doesTestTableExist bool + err = destDb.QueryRow("SELECT EXISTS (SELECT 1 FROM sqlite_master WHERE type = 'table' AND name = 'test' LIMIT 1) AS test_table_exists").Scan(&doesTestTableExist) + if err != nil { + t.Fatal("Failed to check if the \"test\" table exists in the destination database:", err) + } + if !doesTestTableExist { + t.Fatal("The \"test\" table could not be found in the destination database.") + } + + // Confirm that the number of rows in the destination database's "test" table matches that of the source table. + var actualTestTableRowCount int + err = destDb.QueryRow("SELECT COUNT(*) FROM test").Scan(&actualTestTableRowCount) + if err != nil { + t.Fatal("Failed to determine the rowcount of the \"test\" table in the destination database:", err) + } + if testRowCount != actualTestTableRowCount { + t.Fatalf("Unexpected destination \"test\" table row count; expected: %v; found: %v", testRowCount, actualTestTableRowCount) + } + + // Check each of the rows in the destination database. + for id := 0; id < testRowCount; id++ { + var checkedValue string + err = destDb.QueryRow("SELECT value FROM test WHERE id = ?", id).Scan(&checkedValue) + if err != nil { + t.Fatal("Failed to query the \"test\" table in the destination database:", err) + } + + var expectedValue = generateTestData(id) + if checkedValue != expectedValue { + t.Fatalf("Unexpected value in the \"test\" table in the destination database; expected value: %v; actual value: %v", expectedValue, checkedValue) + } + } +} + +func TestBackupStepByStep(t *testing.T) { + testBackup(t, testRowCount, true) +} + +func TestBackupAllRemainingPages(t *testing.T) { + testBackup(t, testRowCount, false) +} + +// Test the error reporting when preparing to perform a backup. +func TestBackupError(t *testing.T) { + const driverName = "sqlite3_TestBackupError" + + // The driver's connection will be needed in order to perform the backup. + var dbDriverConn *SQLiteConn + sql.Register(driverName, &SQLiteDriver{ + ConnectHook: func(conn *SQLiteConn) error { + dbDriverConn = conn + return nil + }, + }) + + // Connect to the database. + db, err := sql.Open(driverName, ":memory:") + if err != nil { + t.Fatal("Failed to open the database:", err) + } + defer db.Close() + db.Ping() + + // Need the driver connection in order to perform the backup. + if dbDriverConn == nil { + t.Fatal("Failed to get the driver connection.") + } + + // Prepare to perform the backup. + // Intentionally using the same connection for both the source and destination databases, to trigger an error result. + backup, err := dbDriverConn.Backup("main", dbDriverConn, "main") + if err == nil { + t.Fatal("Failed to get the expected error result.") + } + const expectedError = "source and destination must be distinct" + if err.Error() != expectedError { + t.Fatalf("Unexpected error message; expected value: \"%v\"; actual value: \"%v\"", expectedError, err.Error()) + } + if backup != nil { + t.Fatal("Failed to get the expected nil backup result.") + } +} + +func TestCallbackArgCast(t *testing.T) { + intConv := callbackSyntheticForTests(reflect.ValueOf(int64(math.MaxInt64)), nil) + floatConv := callbackSyntheticForTests(reflect.ValueOf(float64(math.MaxFloat64)), nil) + errConv := callbackSyntheticForTests(reflect.Value{}, errors.New("test")) + + tests := []struct { + f callbackArgConverter + o reflect.Value + }{ + {intConv, reflect.ValueOf(int8(-1))}, + {intConv, reflect.ValueOf(int16(-1))}, + {intConv, reflect.ValueOf(int32(-1))}, + {intConv, reflect.ValueOf(uint8(math.MaxUint8))}, + {intConv, reflect.ValueOf(uint16(math.MaxUint16))}, + {intConv, reflect.ValueOf(uint32(math.MaxUint32))}, + // Special case, int64->uint64 is only 1<<63 - 1, not 1<<64 - 1 + {intConv, reflect.ValueOf(uint64(math.MaxInt64))}, + {floatConv, reflect.ValueOf(float32(math.Inf(1)))}, + } + + for _, test := range tests { + conv := callbackArgCast{test.f, test.o.Type()} + val, err := conv.Run(nil) + if err != nil { + t.Errorf("Couldn't convert to %s: %s", test.o.Type(), err) + } else if !reflect.DeepEqual(val.Interface(), test.o.Interface()) { + t.Errorf("Unexpected result from converting to %s: got %v, want %v", test.o.Type(), val.Interface(), test.o.Interface()) + } + } + + conv := callbackArgCast{errConv, reflect.TypeOf(int8(0))} + _, err := conv.Run(nil) + if err == nil { + t.Errorf("Expected error during callbackArgCast, but got none") + } +} + +func TestCallbackConverters(t *testing.T) { + tests := []struct { + v any + err bool + }{ + // Unfortunately, we can't tell which converter was returned, + // but we can at least check which types can be converted. + {[]byte{0}, false}, + {"text", false}, + {true, false}, + {int8(0), false}, + {int16(0), false}, + {int32(0), false}, + {int64(0), false}, + {uint8(0), false}, + {uint16(0), false}, + {uint32(0), false}, + {uint64(0), false}, + {int(0), false}, + {uint(0), false}, + {float64(0), false}, + {float32(0), false}, + + {func() {}, true}, + {complex64(complex(0, 0)), true}, + {complex128(complex(0, 0)), true}, + {struct{}{}, true}, + {map[string]string{}, true}, + {[]string{}, true}, + {(*int8)(nil), true}, + {make(chan int), true}, + } + + for _, test := range tests { + _, err := callbackArg(reflect.TypeOf(test.v)) + if test.err && err == nil { + t.Errorf("Expected an error when converting %s, got no error", reflect.TypeOf(test.v)) + } else if !test.err && err != nil { + t.Errorf("Expected converter when converting %s, got error: %s", reflect.TypeOf(test.v), err) + } + } + + for _, test := range tests { + _, err := callbackRet(reflect.TypeOf(test.v)) + if test.err && err == nil { + t.Errorf("Expected an error when converting %s, got no error", reflect.TypeOf(test.v)) + } else if !test.err && err != nil { + t.Errorf("Expected converter when converting %s, got error: %s", reflect.TypeOf(test.v), err) + } + } +} + +func TestCallbackReturnAny(t *testing.T) { + udf := func() any { + return 1 + } + + typ := reflect.TypeOf(udf) + _, err := callbackRet(typ.Out(0)) + if err != nil { + t.Errorf("Expected valid callback for any return type, got: %s", err) + } +} + +func TestSimpleError(t *testing.T) { + e := ErrError.Error() + if e != "SQL logic error or missing database" && e != "SQL logic error" { + t.Error("wrong error code: " + e) + } +} + +func TestCorruptDbErrors(t *testing.T) { + dirName, err := ioutil.TempDir("", "FIXME") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(dirName) + + dbFileName := path.Join(dirName, "test.db") + f, err := os.Create(dbFileName) + if err != nil { + t.Error(err) + } + f.Write([]byte{1, 2, 3, 4, 5}) + f.Close() + + db, err := sql.Open(DriverName, dbFileName) + if err == nil { + _, err = db.Exec("drop table foo") + } + + sqliteErr := err.(Error) + if sqliteErr.Code != ErrNotADB { + t.Error("wrong error code for corrupted DB") + } + if err.Error() == "" { + t.Error("wrong error string for corrupted DB") + } + db.Close() +} + +func TestSqlLogicErrors(t *testing.T) { + dirName, err := ioutil.TempDir("", "FIXME") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(dirName) + + db, err := sql.Open(DriverName, ":memory:") + if err != nil { + t.Error(err) + } + defer db.Close() + + _, err = db.Exec("CREATE TABLE Foo (id INTEGER PRIMARY KEY)") + if err != nil { + t.Error(err) + } + + const expectedErr = "table Foo already exists" + _, err = db.Exec("CREATE TABLE Foo (id INTEGER PRIMARY KEY)") + if err.Error() != expectedErr { + t.Errorf("Unexpected error: %s, expected %s", err.Error(), expectedErr) + } + +} + +func TestExtendedErrorCodes_ForeignKey(t *testing.T) { + dirName, err := ioutil.TempDir("", "sqlite3-err") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(dirName) + + db, err := sql.Open(DriverName, ":memory:") + if err != nil { + t.Error(err) + } + defer db.Close() + + _, err = db.Exec(`CREATE TABLE Foo ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + value INTEGER NOT NULL, + ref INTEGER NULL REFERENCES Foo (id), + UNIQUE(value) + );`) + if err != nil { + t.Error(err) + } + + _, err = db.Exec("INSERT INTO Foo (ref, value) VALUES (100, 100);") + if err == nil { + t.Error("No error!") + } else { + sqliteErr := err.(Error) + if sqliteErr.Code != ErrConstraint { + t.Errorf("Wrong basic error code: %d != %d", + sqliteErr.Code, ErrConstraint) + } + if sqliteErr.ExtendedCode != ErrConstraintForeignKey { + t.Errorf("Wrong extended error code: %d != %d", + sqliteErr.ExtendedCode, ErrConstraintForeignKey) + } + } + +} + +func TestExtendedErrorCodes_NotNull(t *testing.T) { + dirName, err := ioutil.TempDir("", "sqlite3-err") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(dirName) + + db, err := sql.Open(DriverName, ":memory:") + if err != nil { + t.Error(err) + } + defer db.Close() + + _, err = db.Exec(`CREATE TABLE Foo ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + value INTEGER NOT NULL, + ref INTEGER NULL REFERENCES Foo (id), + UNIQUE(value) + );`) + if err != nil { + t.Error(err) + } + + res, err := db.Exec("INSERT INTO Foo (value) VALUES (100);") + if err != nil { + t.Fatalf("Creating first row: %v", err) + } + + id, err := res.LastInsertId() + if err != nil { + t.Fatalf("Retrieving last insert id: %v", err) + } + + _, err = db.Exec("INSERT INTO Foo (ref) VALUES (?);", id) + if err == nil { + t.Error("No error!") + } else { + sqliteErr := err.(Error) + if sqliteErr.Code != ErrConstraint { + t.Errorf("Wrong basic error code: %d != %d", + sqliteErr.Code, ErrConstraint) + } + if sqliteErr.ExtendedCode != ErrConstraintNotNull { + t.Errorf("Wrong extended error code: %d != %d", + sqliteErr.ExtendedCode, ErrConstraintNotNull) + } + } + +} + +func TestExtendedErrorCodes_Unique(t *testing.T) { + dirName, err := ioutil.TempDir("", "sqlite3-err") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(dirName) + + db, err := sql.Open(DriverName, ":memory:") + if err != nil { + t.Error(err) + } + defer db.Close() + + _, err = db.Exec(`CREATE TABLE Foo ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + value INTEGER NOT NULL, + ref INTEGER NULL REFERENCES Foo (id), + UNIQUE(value) + );`) + if err != nil { + t.Error(err) + } + + res, err := db.Exec("INSERT INTO Foo (value) VALUES (100);") + if err != nil { + t.Fatalf("Creating first row: %v", err) + } + + id, err := res.LastInsertId() + if err != nil { + t.Fatalf("Retrieving last insert id: %v", err) + } + + _, err = db.Exec("INSERT INTO Foo (ref, value) VALUES (?, 100);", id) + if err == nil { + t.Error("No error!") + } else { + sqliteErr := err.(Error) + if sqliteErr.Code != ErrConstraint { + t.Errorf("Wrong basic error code: %d != %d", + sqliteErr.Code, ErrConstraint) + } + if sqliteErr.ExtendedCode != ErrConstraintUnique { + t.Errorf("Wrong extended error code: %d != %d", + sqliteErr.ExtendedCode, ErrConstraintUnique) + } + extended := sqliteErr.Code.Extend(3).Error() + expected := "constraint failed" + if extended != expected { + t.Errorf("Wrong basic error code: %q != %q", + extended, expected) + } + } +} + +func TestError_SystemErrno(t *testing.T) { + _, n, _ := LibVersion() + if n < 3012000 { + t.Skip("sqlite3_system_errno requires sqlite3 >= 3.12.0") + } + + // open a non-existent database in read-only mode so we get an IO error. + db, err := sql.Open(DriverName, "file:nonexistent.db?mode=ro") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + err = db.Ping() + if err == nil { + t.Fatal("expected error pinging read-only non-existent database, but got nil") + } + + serr, ok := err.(Error) + if !ok { + t.Fatalf("expected error to be of type Error, but got %[1]T %[1]v", err) + } + + if serr.SystemErrno == 0 { + t.Fatal("expected SystemErrno to be set") + } + + if !os.IsNotExist(serr.SystemErrno) { + t.Errorf("expected SystemErrno to be a not exists error, but got %v", serr.SystemErrno) + } +} + +func TestBeginTxCancel(t *testing.T) { + db, err := sql.Open(DriverName, ":memory:") + if err != nil { + t.Fatal(err) + } + + db.SetMaxOpenConns(10) + db.SetMaxIdleConns(5) + + defer db.Close() + initDatabase(t, db, 100) + + // create several go-routines to expose racy issue + for i := 0; i < 1000; i++ { + func() { + ctx, cancel := context.WithCancel(context.Background()) + conn, err := db.Conn(ctx) + if err != nil { + t.Fatal(err) + } + defer func() { + if err := conn.Close(); err != nil { + t.Error(err) + } + }() + + err = conn.Raw(func(driverConn any) error { + d, ok := driverConn.(driver.ConnBeginTx) + if !ok { + t.Fatal("unexpected: wrong type") + } + // checks that conn.Raw can be used to get *SQLiteConn + if _, ok = driverConn.(*SQLiteConn); !ok { + t.Fatalf("conn.Raw() driverConn type=%T, expected *SQLiteConn", driverConn) + } + + go cancel() // make it cancel concurrently with exec("BEGIN"); + tx, err := d.BeginTx(ctx, driver.TxOptions{}) + switch err { + case nil: + switch err := tx.Rollback(); err { + case nil, sql.ErrTxDone: + default: + return err + } + case context.Canceled: + default: + // must not fail with "cannot start a transaction within a transaction" + return err + } + return nil + }) + if err != nil { + t.Fatal(err) + } + }() + } +} + +func TestStmtReadonly(t *testing.T) { + db, err := sql.Open(DriverName, ":memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + _, err = db.Exec("CREATE TABLE t (count INT)") + if err != nil { + t.Fatal(err) + } + + isRO := func(query string) bool { + c, err := db.Conn(context.Background()) + if err != nil { + return false + } + + var ro bool + c.Raw(func(dc any) error { + stmt, err := dc.(*SQLiteConn).Prepare(query) + if err != nil { + return err + } + if stmt == nil { + return errors.New("stmt is nil") + } + ro = stmt.(*SQLiteStmt).Readonly() + return nil + }) + return ro // On errors ro will remain false. + } + + if !isRO(`select * from t`) { + t.Error("select not seen as read-only") + } + if isRO(`insert into t values (1), (2)`) { + t.Error("insert seen as read-only") + } +} + +func TestNamedParams(t *testing.T) { + db, err := sql.Open(DriverName, ":memory:") + if err != nil { + t.Fatal("Failed to open database:", err) + } + defer db.Close() + + _, err = db.Exec(` + create table foo (id integer, name text, extra text); + `) + if err != nil { + t.Error("Failed to call db.Query:", err) + } + + _, err = db.Exec(`insert into foo(id, name, extra) values(:id, :name, :name)`, sql.Named("name", "foo"), sql.Named("id", 1)) + if err != nil { + t.Error("Failed to call db.Exec:", err) + } + + row := db.QueryRow(`select id, extra from foo where id = :id and extra = :extra`, sql.Named("id", 1), sql.Named("extra", "foo")) + if row == nil { + t.Error("Failed to call db.QueryRow") + } + var id int + var extra string + err = row.Scan(&id, &extra) + if err != nil { + t.Error("Failed to db.Scan:", err) + } + if id != 1 || extra != "foo" { + t.Error("Failed to db.QueryRow: not matched results") + } +} + +var ( + testTableStatements = []string{ + `DROP TABLE IF EXISTS test_table`, + ` +CREATE TABLE IF NOT EXISTS test_table ( + key1 VARCHAR(64) PRIMARY KEY, + key_id VARCHAR(64) NOT NULL, + key2 VARCHAR(64) NOT NULL, + key3 VARCHAR(64) NOT NULL, + key4 VARCHAR(64) NOT NULL, + key5 VARCHAR(64) NOT NULL, + key6 VARCHAR(64) NOT NULL, + data BLOB NOT NULL +);`, + } + letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" +) + +func randStringBytes(n int) string { + b := make([]byte, n) + for i := range b { + b[i] = letterBytes[rand.Intn(len(letterBytes))] + } + return string(b) +} + +func initDatabase(t *testing.T, db *sql.DB, rowCount int64) { + for _, query := range testTableStatements { + _, err := db.Exec(query) + if err != nil { + t.Fatal(err) + } + } + for i := int64(0); i < rowCount; i++ { + query := `INSERT INTO test_table + (key1, key_id, key2, key3, key4, key5, key6, data) + VALUES + (?, ?, ?, ?, ?, ?, ?, ?);` + args := []interface{}{ + randStringBytes(50), + fmt.Sprint(i), + randStringBytes(50), + randStringBytes(50), + randStringBytes(50), + randStringBytes(50), + randStringBytes(50), + randStringBytes(50), + randStringBytes(2048), + } + _, err := db.Exec(query, args...) + if err != nil { + t.Fatal(err) + } + } +} + +func TestShortTimeout(t *testing.T) { + db, err := sql.Open(DriverName, ":memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + initDatabase(t, db, 100) + + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Microsecond) + defer cancel() + query := `SELECT key1, key_id, key2, key3, key4, key5, key6, data + FROM test_table + ORDER BY key2 ASC` + _, err = db.QueryContext(ctx, query) + if err != nil && err != context.DeadlineExceeded { + t.Fatal(err) + } + if ctx.Err() != nil && ctx.Err() != context.DeadlineExceeded { + t.Fatal(ctx.Err()) + } +} + +func TestExecContextCancel(t *testing.T) { + db, err := sql.Open(DriverName, "file:exec?mode=memory&cache=shared") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + ts := time.Now() + initDatabase(t, db, 1000) + spent := time.Since(ts) + const minTestTime = 100 * time.Millisecond + if spent < minTestTime && false { + t.Skipf("test will be too racy (spent=%s < min=%s) as ExecContext below will be too fast.", + spent.String(), minTestTime.String(), + ) + } + + // expected to be extremely slow query + q := ` +INSERT INTO test_table (key1, key_id, key2, key3, key4, key5, key6, data) +SELECT t1.key1 || t2.key1, t1.key_id || t2.key_id, t1.key2 || t2.key2, t1.key3 || t2.key3, t1.key4 || t2.key4, t1.key5 || t2.key5, t1.key6 || t2.key6, t1.data || t2.data +FROM test_table t1 LEFT OUTER JOIN test_table t2` + // expect query above take ~ same time as setup above + // This is racy: the context must be valid so sql/db.ExecContext calls the sqlite3 driver. + // It starts the query, the context expires, then calls sqlite3_interrupt + ctx, cancel := context.WithTimeout(context.Background(), minTestTime/2) + defer cancel() + ts = time.Now() + r, err := db.ExecContext(ctx, q) + // racy check + if r != nil { + n, err := r.RowsAffected() + t.Logf("query should not have succeeded: rows=%d; err=%v; duration=%s", + n, err, time.Since(ts).String()) + } + if err != context.DeadlineExceeded { + t.Fatal(err, ctx.Err()) + } +} + +func TestQueryRowContextCancel(t *testing.T) { + // FIXME: too slow + db, err := sql.Open(DriverName, "file:query?mode=memory&cache=shared") + if err != nil { + t.Fatal(err) + } + defer db.Close() + initDatabase(t, db, 100) + + const query = `SELECT key_id FROM test_table ORDER BY key2 ASC` + var keyID string + unexpectedErrors := make(map[string]int) + for i := 0; i < 10000; i++ { + ctx, cancel := context.WithCancel(context.Background()) + row := db.QueryRowContext(ctx, query) + + cancel() + // it is fine to get "nil" as context cancellation can be handled with delay + if err := row.Scan(&keyID); err != nil && err != context.Canceled { + if err.Error() == "sql: Rows are closed" { + // see https://github.com/golang/go/issues/24431 + // fixed in 1.11.1 to properly return context error + continue + } + unexpectedErrors[err.Error()]++ + } + } + for errText, count := range unexpectedErrors { + t.Error(errText, count) + } +} + +func TestQueryRowContextCancelParallel(t *testing.T) { + // FIXME: too slow + db, err := sql.Open(DriverName, "file:parallel?mode=memory&cache=shared") + if err != nil { + t.Fatal(err) + } + db.SetMaxOpenConns(10) + db.SetMaxIdleConns(5) + + defer db.Close() + initDatabase(t, db, 100) + + const query = `SELECT key_id FROM test_table ORDER BY key2 ASC` + wg := sync.WaitGroup{} + defer wg.Wait() + + testCtx, cancel := context.WithCancel(context.Background()) + defer cancel() + + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + + var keyID string + for { + select { + case <-testCtx.Done(): + return + default: + } + ctx, cancel := context.WithCancel(context.Background()) + row := db.QueryRowContext(ctx, query) + + cancel() + _ = row.Scan(&keyID) // see TestQueryRowContextCancel + } + }() + } + + var keyID string + for i := 0; i < 10000; i++ { + // note that testCtx is not cancelled during query execution + row := db.QueryRowContext(testCtx, query) + + if err := row.Scan(&keyID); err != nil { + t.Fatal(i, err) + } + } +} + +func TestExecCancel(t *testing.T) { + db, err := sql.Open(DriverName, ":memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + if _, err = db.Exec("create table foo (id integer primary key)"); err != nil { + t.Fatal(err) + } + + for n := 0; n < 100; n++ { + ctx, cancel := context.WithCancel(context.Background()) + _, err = db.ExecContext(ctx, "insert into foo (id) values (?)", n) + cancel() + if err != nil { + t.Fatal(err) + } + } +} + +func doTestOpenContext(t *testing.T, url string) (string, error) { + db, err := sql.Open(DriverName, url) + if err != nil { + return "Failed to open database:", err + } + + defer func() { + err = db.Close() + if err != nil { + t.Error("db close error:", err) + } + }() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + err = db.PingContext(ctx) + cancel() + if err != nil { + return "ping error:", err + } + + ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second) + _, err = db.ExecContext(ctx, "drop table foo") + cancel() + ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second) + _, err = db.ExecContext(ctx, "create table foo (id integer)") + cancel() + if err != nil { + return "Failed to create table:", err + } + + return "", nil +} + +func TestFileCopyTruncate(t *testing.T) { + var err error + tempFilename := TempFilename(t) + defer os.Remove(tempFilename) + + db, err := sql.Open(DriverName, tempFilename) + if err != nil { + t.Fatal("open error:", err) + } + defer db.Close() + + if true { + _, err = db.Exec("PRAGMA journal_mode = delete;") + if err != nil { + t.Fatal("journal_mode delete:", err) + } + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + err = db.PingContext(ctx) + cancel() + if err != nil { + t.Fatal("ping error:", err) + } + + ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second) + _, err = db.ExecContext(ctx, "drop table foo") + cancel() + ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second) + _, err = db.ExecContext(ctx, "create table foo (id integer)") + cancel() + if err != nil { + t.Fatal("create table error:", err) + } + + // copy db to new file + var data []byte + data, err = ioutil.ReadFile(tempFilename) + if err != nil { + t.Fatal("read file error:", err) + } + + var f *os.File + copyFilename := tempFilename + "-db-copy" + f, err = os.Create(copyFilename) + if err != nil { + t.Fatal("create file error:", err) + } + defer os.Remove(copyFilename) + + _, err = f.Write(data) + if err != nil { + f.Close() + t.Fatal("write file error:", err) + } + err = f.Close() + if err != nil { + t.Fatal("close file error:", err) + } + + // truncate current db file + f, err = os.OpenFile(tempFilename, os.O_WRONLY|os.O_TRUNC, 0666) + if err != nil { + t.Fatal("open file error:", err) + } + err = f.Close() + if err != nil { + t.Fatal("close file error:", err) + } + + // test db after file truncate + ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second) + err = db.PingContext(ctx) + cancel() + if err != nil { + t.Fatal("ping error:", err) + } + + ctx, cancel = context.WithTimeout(context.Background(), 5*time.Millisecond) + _, err = db.ExecContext(ctx, "drop table foo") + cancel() + if err == nil { + t.Fatal("drop table no error") + } + + ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second) + _, err = db.ExecContext(ctx, "create table foo (id integer)") + cancel() + if err != nil { + t.Fatal("create table error:", err) + } + + err = db.Close() + if err != nil { + t.Error("db close error:", err) + } + + // test copied file + db, err = sql.Open(DriverName, copyFilename) + if err != nil { + t.Fatal("open error:", err) + } + defer db.Close() + + ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second) + err = db.PingContext(ctx) + cancel() + if err != nil { + t.Fatal("ping error:", err) + } + + ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second) + _, err = db.ExecContext(ctx, "drop table foo") + cancel() + if err != nil { + t.Fatal("drop table error:", err) + } + + ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second) + _, err = db.ExecContext(ctx, "create table foo (id integer)") + cancel() + if err != nil { + t.Fatal("create table error:", err) + } +} + +func TestColumnTableName(t *testing.T) { + d := SQLiteDriver{} + conn, err := d.Open(":memory:") + if err != nil { + t.Fatal("failed to get database connection:", err) + } + defer conn.Close() + sqlite3conn := conn.(*SQLiteConn) + + _, err = sqlite3conn.Exec(`CREATE TABLE foo (name string)`, nil) + if err != nil { + t.Fatal("Failed to create table:", err) + } + _, err = sqlite3conn.Exec(`CREATE TABLE bar (name string)`, nil) + if err != nil { + t.Fatal("Failed to create table:", err) + } + + stmt, err := sqlite3conn.Prepare(`SELECT * FROM foo JOIN bar ON foo.name = bar.name`) + if err != nil { + t.Fatal(err) + } + + if exp, got := "foo", stmt.(*SQLiteStmt).ColumnTableName(0); exp != got { + t.Fatalf("Incorrect table name returned expected: %s, got: %s", exp, got) + } + if exp, got := "bar", stmt.(*SQLiteStmt).ColumnTableName(1); exp != got { + t.Fatalf("Incorrect table name returned expected: %s, got: %s", exp, got) + } + if exp, got := "", stmt.(*SQLiteStmt).ColumnTableName(2); exp != got { + t.Fatalf("Incorrect table name returned expected: %s, got: %s", exp, got) + } +} + +func TestFTS5(t *testing.T) { + db, err := sql.Open(DriverName, ":memory:") + if err != nil { + t.Fatal("Failed to open database:", err) + } + defer db.Close() + + _, err = db.Exec("DROP TABLE foo") + _, err = db.Exec("CREATE VIRTUAL TABLE foo USING fts5(id, value)") + if err != nil { + t.Fatal("Failed to create table:", err) + } + + _, err = db.Exec("INSERT INTO foo(id, value) VALUES(?, ?)", 1, `今日の 晩御飯は 天麩羅よ`) + if err != nil { + t.Fatal("Failed to insert value:", err) + } + + _, err = db.Exec("INSERT INTO foo(id, value) VALUES(?, ?)", 2, `今日は いい 天気だ`) + if err != nil { + t.Fatal("Failed to insert value:", err) + } + + rows, err := db.Query("SELECT id, value FROM foo WHERE value MATCH '今日* 天*'") + if err != nil { + t.Fatal("Unable to query foo table:", err) + } + defer rows.Close() + + for rows.Next() { + var id int + var value string + + if err := rows.Scan(&id, &value); err != nil { + t.Error("Unable to scan results:", err) + continue + } + + if id == 1 && value != `今日の 晩御飯は 天麩羅よ` { + t.Error("Value for id 1 should be `今日の 晩御飯は 天麩羅よ`, but:", value) + } else if id == 2 && value != `今日は いい 天気だ` { + t.Error("Value for id 2 should be `今日は いい 天気だ`, but:", value) + } + } + + rows, err = db.Query("SELECT value FROM foo WHERE value MATCH '今日* 天麩羅*'") + if err != nil { + t.Fatal("Unable to query foo table:", err) + } + defer rows.Close() + + var value string + if !rows.Next() { + t.Fatal("Result should be only one") + } + + if err := rows.Scan(&value); err != nil { + t.Fatal("Unable to scan results:", err) + } + + if value != `今日の 晩御飯は 天麩羅よ` { + t.Fatal("Value should be `今日の 晩御飯は 天麩羅よ`, but:", value) + } + + if rows.Next() { + t.Fatal("Result should be only one") + } + + _, err = db.Exec("DROP TABLE foo") + _, err = db.Exec("CREATE VIRTUAL TABLE foo USING fts5(tokenize=unicode61, id, value)") + switch { + case err != nil && err.Error() == "unknown tokenizer: unicode61": + t.Skip("FTS4 not supported") + case err != nil: + t.Fatal("Failed to create table:", err) + } + + _, err = db.Exec("INSERT INTO foo(id, value) VALUES(?, ?)", 1, `février`) + if err != nil { + t.Fatal("Failed to insert value:", err) + } + + rows, err = db.Query("SELECT value FROM foo WHERE value MATCH 'fevrier'") + if err != nil { + t.Fatal("Unable to query foo table:", err) + } + defer rows.Close() + + if !rows.Next() { + t.Fatal("Result should be only one") + } + + if err := rows.Scan(&value); err != nil { + t.Fatal("Unable to scan results:", err) + } + + if value != `février` { + t.Fatal("Value should be `février`, but:", value) + } + + if rows.Next() { + t.Fatal("Result should be only one") + } +} + +type preUpdateHookDataForTest struct { + databaseName string + tableName string + count int + op int + oldRow []any + newRow []any +} + +func TestSerializeDeserialize(t *testing.T) { + // Connect to the source database. + srcDb, err := sql.Open(DriverName, "file:src?mode=memory&cache=shared") + if err != nil { + t.Fatal("Failed to open the source database:", err) + } + defer srcDb.Close() + err = srcDb.Ping() + if err != nil { + t.Fatal("Failed to connect to the source database:", err) + } + + // Connect to the destination database. + destDb, err := sql.Open(DriverName, "file:dst?mode=memory&cache=shared") + if err != nil { + t.Fatal("Failed to open the destination database:", err) + } + defer destDb.Close() + err = destDb.Ping() + if err != nil { + t.Fatal("Failed to connect to the destination database:", err) + } + + // Write data to source database. + _, err = srcDb.Exec(`CREATE TABLE foo (name string)`) + if err != nil { + t.Fatal("Failed to create table in source database:", err) + } + _, err = srcDb.Exec(`INSERT INTO foo(name) VALUES('alice')`) + if err != nil { + t.Fatal("Failed to insert data into source database", err) + } + + // Serialize the source database + srcConn, err := srcDb.Conn(context.Background()) + if err != nil { + t.Fatal("Failed to get connection to source database:", err) + } + defer srcConn.Close() + + var serialized []byte + if err := srcConn.Raw(func(raw any) error { + var err error + serialized, err = raw.(*SQLiteConn).Serialize("") + return err + }); err != nil { + t.Fatal("Failed to serialize source database:", err) + } + srcConn.Close() + + // Confirm that the destination database is initially empty. + var destTableCount int + err = destDb.QueryRow("SELECT COUNT(*) FROM sqlite_master WHERE type = 'table'").Scan(&destTableCount) + if err != nil { + t.Fatal("Failed to check the destination table count:", err) + } + if destTableCount != 0 { + t.Fatalf("The destination database is not empty; %v table(s) found.", destTableCount) + } + + // Deserialize to destination database + destConn, err := destDb.Conn(context.Background()) + if err != nil { + t.Fatal("Failed to get connection to destination database:", err) + } + defer destConn.Close() + + if err := destConn.Raw(func(raw any) error { + return raw.(*SQLiteConn).Deserialize(serialized, "") + }); err != nil { + t.Fatal("Failed to deserialize source database:", err) + } + destConn.Close() + + // Confirm that destination database has been loaded correctly. + var destRowCount int + err = destDb.QueryRow(`SELECT COUNT(*) FROM foo`).Scan(&destRowCount) + if err != nil { + t.Fatal("Failed to count rows in destination database table", err) + } + if destRowCount != 1 { + t.Fatalf("Destination table does not have the expected records") + } +} + +func TestUnlockNotify(t *testing.T) { + tempFilename := TempFilename(t) + defer os.Remove(tempFilename) + dsn := fmt.Sprintf("file:%s?cache=shared&mode=memory", tempFilename) + db, err := sql.Open(DriverName, dsn) + if err != nil { + t.Fatal("Failed to open database:", err) + } + defer db.Close() + + _, err = db.Exec("CREATE TABLE foo(id INTEGER, status INTEGER)") + if err != nil { + t.Fatal("Failed to create table:", err) + } + + tx, err := db.Begin() + if err != nil { + t.Fatal("Failed to begin transaction:", err) + } + + _, err = tx.Exec("INSERT INTO foo(id, status) VALUES(1, 100)") + if err != nil { + t.Fatal("Failed to insert null:", err) + } + + _, err = tx.Exec("UPDATE foo SET status = 200 WHERE id = 1") + if err != nil { + t.Fatal("Failed to update table:", err) + } + + wg := sync.WaitGroup{} + wg.Add(1) + timer := time.NewTimer(500 * time.Millisecond) + go func() { + <-timer.C + err := tx.Commit() + if err != nil { + t.Fatal("Failed to commit transaction:", err) + } + wg.Done() + }() + + rows, err := db.Query("SELECT count(*) from foo") + if err != nil { + t.Fatal("Unable to query foo table:", err) + } + + if rows.Next() { + var count int + if err := rows.Scan(&count); err != nil { + t.Fatal("Failed to Scan rows", err) + } + } + if err := rows.Err(); err != nil { + t.Fatal("Failed at the call to Next:", err) + } + wg.Wait() + +} + +func TestUnlockNotifyMany(t *testing.T) { + tempFilename := TempFilename(t) + defer os.Remove(tempFilename) + dsn := fmt.Sprintf("file:%s?cache=shared&mode=memory", tempFilename) + db, err := sql.Open(DriverName, dsn) + if err != nil { + t.Fatal("Failed to open database:", err) + } + defer db.Close() + + _, err = db.Exec("CREATE TABLE foo(id INTEGER, status INTEGER)") + if err != nil { + t.Fatal("Failed to create table:", err) + } + + tx, err := db.Begin() + if err != nil { + t.Fatal("Failed to begin transaction:", err) + } + + _, err = tx.Exec("INSERT INTO foo(id, status) VALUES(1, 100)") + if err != nil { + t.Fatal("Failed to insert null:", err) + } + + _, err = tx.Exec("UPDATE foo SET status = 200 WHERE id = 1") + if err != nil { + t.Fatal("Failed to update table:", err) + } + + wg := sync.WaitGroup{} + wg.Add(1) + timer := time.NewTimer(500 * time.Millisecond) + go func() { + <-timer.C + err := tx.Commit() + if err != nil { + t.Fatal("Failed to commit transaction:", err) + } + wg.Done() + }() + + const concurrentQueries = 1000 + wg.Add(concurrentQueries) + for i := 0; i < concurrentQueries; i++ { + go func() { + rows, err := db.Query("SELECT count(*) from foo") + if err != nil { + t.Fatal("Unable to query foo table:", err) + } + + if rows.Next() { + var count int + if err := rows.Scan(&count); err != nil { + t.Fatal("Failed to Scan rows", err) + } + } + if err := rows.Err(); err != nil { + t.Fatal("Failed at the call to Next:", err) + } + wg.Done() + }() + } + wg.Wait() +} + +func TestUnlockNotifyDeadlock(t *testing.T) { + tempFilename := TempFilename(t) + defer os.Remove(tempFilename) + dsn := fmt.Sprintf("file:%s?cache=shared&mode=memory", tempFilename) + db, err := sql.Open(DriverName, dsn) + if err != nil { + t.Fatal("Failed to open database:", err) + } + defer db.Close() + + _, err = db.Exec("CREATE TABLE foo(id INTEGER, status INTEGER)") + if err != nil { + t.Fatal("Failed to create table:", err) + } + + tx, err := db.Begin() + if err != nil { + t.Fatal("Failed to begin transaction:", err) + } + + _, err = tx.Exec("INSERT INTO foo(id, status) VALUES(1, 100)") + if err != nil { + t.Fatal("Failed to insert null:", err) + } + + _, err = tx.Exec("UPDATE foo SET status = 200 WHERE id = 1") + if err != nil { + t.Fatal("Failed to update table:", err) + } + + wg := sync.WaitGroup{} + wg.Add(1) + timer := time.NewTimer(500 * time.Millisecond) + go func() { + <-timer.C + err := tx.Commit() + if err != nil { + t.Fatal("Failed to commit transaction:", err) + } + wg.Done() + }() + + wg.Add(1) + go func() { + tx2, err := db.Begin() + if err != nil { + t.Fatal("Failed to begin transaction:", err) + } + defer tx2.Rollback() + + _, err = tx2.Exec("DELETE FROM foo") + if err != nil { + t.Fatal("Failed to delete table:", err) + } + err = tx2.Commit() + if err != nil { + t.Fatal("Failed to commit transaction:", err) + } + wg.Done() + }() + + rows, err := tx.Query("SELECT count(*) from foo") + if err != nil { + t.Fatal("Unable to query foo table:", err) + } + + if rows.Next() { + var count int + if err := rows.Scan(&count); err != nil { + t.Fatal("Failed to Scan rows", err) + } + } + if err := rows.Err(); err != nil { + t.Fatal("Failed at the call to Next:", err) + } + + wg.Wait() +} + +func getRowCount(rows *sql.Rows) (int, error) { + var i int + for rows.Next() { + i++ + } + return i, nil +} + +func TempFilename(t testing.TB) string { + f, err := ioutil.TempFile("", "go-sqlite3-test-") + if err != nil { + t.Fatal(err) + } + f.Close() + return f.Name() +} + +func doTestOpen(t *testing.T, url string) (string, error) { + db, err := sql.Open(DriverName, url) + if err != nil { + return "Failed to open database:", err + } + + defer func() { + err = db.Close() + if err != nil { + t.Error("db close error:", err) + } + }() + + err = db.Ping() + if err != nil { + return "ping error:", err + } + + _, err = db.Exec("drop table foo") + _, err = db.Exec("create table foo (id integer)") + if err != nil { + return "Failed to create table:", err + } + + return "", nil +} + +func TestOpenWithVFS(t *testing.T) { + { + uri := fmt.Sprintf("file:%s?mode=memory&vfs=hello", t.Name()) + db, err := sql.Open(DriverName, uri) + if err != nil { + t.Fatal("Failed to open", err) + } + err = db.Ping() + if err == nil { + t.Fatal("Failed to open", err) + } + db.Close() + } + + { + uri := fmt.Sprintf("file:%s?mode=memory&vfs=unix-none", t.Name()) + db, err := sql.Open(DriverName, uri) + if err != nil { + t.Fatal("Failed to open", err) + } + err = db.Ping() + if err != nil { + t.Fatal("Failed to ping", err) + } + db.Close() + } +} + +func TestOpenNoCreate(t *testing.T) { + filename := t.Name() + ".sqlite" + + if err := os.Remove(filename); err != nil && !os.IsNotExist(err) { + t.Fatal(err) + } + defer os.Remove(filename) + + // https://golang.org/pkg/database/sql/#Open + // "Open may just validate its arguments without creating a connection + // to the database. To verify that the data source name is valid, call Ping." + db, err := sql.Open(DriverName, fmt.Sprintf("file:%s?mode=rw", filename)) + if err == nil { + defer db.Close() + + err = db.Ping() + if err == nil { + t.Fatal("expected error from Open or Ping") + } + } + + sqlErr, ok := err.(Error) + if !ok { + t.Fatalf("expected sqlite3.Error, but got %T", err) + } + + if sqlErr.Code != ErrCantOpen { + t.Fatalf("expected SQLITE_CANTOPEN, but got %v", sqlErr) + } + + // make sure database file truly was not created + if _, err := os.Stat(filename); !os.IsNotExist(err) { + if err != nil { + t.Fatal(err) + } + t.Fatal("expected database file to not exist") + } + + // verify that it works if the mode is "rwc" instead + db, err = sql.Open(DriverName, fmt.Sprintf("file:%s?mode=rwc", filename)) + if err != nil { + t.Fatal(err) + } + defer db.Close() + + if err := db.Ping(); err != nil { + t.Fatal(err) + } + + // make sure database file truly was created + if _, err := os.Stat(filename); err != nil { + if !os.IsNotExist(err) { + t.Fatal(err) + } + t.Fatal("expected database file to exist") + } +} + +func TestReadonly(t *testing.T) { + tempFilename := TempFilename(t) + defer os.Remove(tempFilename) + + db1, err := sql.Open(DriverName, "file:"+tempFilename) + if err != nil { + t.Fatal(err) + } + defer db1.Close() + db1.Exec("CREATE TABLE test (x int, y float)") + + db2, err := sql.Open(DriverName, "file:"+tempFilename+"?mode=ro") + if err != nil { + t.Fatal(err) + } + defer db2.Close() + _ = db2 + _, err = db2.Exec("INSERT INTO test VALUES (1, 3.14)") + if err == nil { + t.Fatal("didn't expect INSERT into read-only database to work") + } +} + +func TestDeferredForeignKey(t *testing.T) { + fname := TempFilename(t) + uri := "file:" + fname + "?_foreign_keys=1&mode=memory" + db, err := sql.Open(DriverName, uri) + if err != nil { + os.Remove(fname) + t.Errorf("sql.Open(\"sqlite3\", %q): %v", uri, err) + } + _, err = db.Exec("CREATE TABLE bar (id INTEGER PRIMARY KEY)") + if err != nil { + t.Errorf("failed creating tables: %v", err) + } + _, err = db.Exec("CREATE TABLE foo (bar_id INTEGER, FOREIGN KEY(bar_id) REFERENCES bar(id) DEFERRABLE INITIALLY DEFERRED)") + if err != nil { + t.Errorf("failed creating tables: %v", err) + } + tx, err := db.Begin() + if err != nil { + t.Errorf("Failed to begin transaction: %v", err) + } + _, err = tx.Exec("INSERT INTO foo (bar_id) VALUES (123)") + if err != nil { + t.Errorf("Failed to insert row: %v", err) + } + err = tx.Commit() + if err == nil { + t.Errorf("Expected an error: %v", err) + } + _, err = db.Begin() + if err != nil { + t.Errorf("Failed to begin transaction: %v", err) + } + + db.Close() + os.Remove(fname) +} + +func TestClose(t *testing.T) { + db, err := sql.Open(DriverName, ":memory:") + if err != nil { + t.Fatal("Failed to open database:", err) + } + + _, err = db.Exec("drop table foo") + _, err = db.Exec("create table foo (id integer)") + if err != nil { + t.Fatal("Failed to create table:", err) + } + + stmt, err := db.Prepare("select id from foo where id = ?") + if err != nil { + t.Fatal("Failed to select records:", err) + } + + db.Close() + _, err = stmt.Exec(1) + if err == nil { + t.Fatal("Failed to operate closed statement") + } +} + +func TestInsert(t *testing.T) { + db, err := sql.Open(DriverName, ":memory:") + if err != nil { + t.Fatal("Failed to open database:", err) + } + defer db.Close() + + _, err = db.Exec("drop table foo") + _, err = db.Exec("create table foo (id integer)") + if err != nil { + t.Fatal("Failed to create table:", err) + } + + res, err := db.Exec("insert into foo(id) values(123)") + if err != nil { + t.Fatal("Failed to insert record:", err) + } + affected, _ := res.RowsAffected() + if affected != 1 { + t.Fatalf("Expected %d for affected rows, but %d:", 1, affected) + } + + rows, err := db.Query("select id from foo") + if err != nil { + t.Fatal("Failed to select records:", err) + } + defer rows.Close() + + rows.Next() + + var result int + rows.Scan(&result) + if result != 123 { + t.Errorf("Expected %d for fetched result, but %d:", 123, result) + } +} + +func TestUpsert(t *testing.T) { + _, n, _ := LibVersion() + if n < 3024000 { + t.Skip("UPSERT requires sqlite3 >= 3.24.0") + } + db, err := sql.Open(DriverName, ":memory:") + if err != nil { + t.Fatal("Failed to open database:", err) + } + defer db.Close() + + _, err = db.Exec("drop table foo") + _, err = db.Exec("create table foo (name string primary key, counter integer)") + if err != nil { + t.Fatal("Failed to create table:", err) + } + + for i := 0; i < 10; i++ { + res, err := db.Exec("insert into foo(name, counter) values('key', 1) on conflict (name) do update set counter=counter+1") + if err != nil { + t.Fatal("Failed to upsert record:", err) + } + affected, _ := res.RowsAffected() + if affected != 1 { + t.Fatalf("Expected %d for affected rows, but %d:", 1, affected) + } + } + rows, err := db.Query("select name, counter from foo") + if err != nil { + t.Fatal("Failed to select records:", err) + } + defer rows.Close() + + rows.Next() + + var resultName string + var resultCounter int + rows.Scan(&resultName, &resultCounter) + if resultName != "key" { + t.Errorf("Expected %s for fetched result, but %s:", "key", resultName) + } + if resultCounter != 10 { + t.Errorf("Expected %d for fetched result, but %d:", 10, resultCounter) + } + +} + +func TestUpdate(t *testing.T) { + db, err := sql.Open(DriverName, ":memory:") + if err != nil { + t.Fatal("Failed to open database:", err) + } + defer db.Close() + + _, err = db.Exec("drop table foo") + _, err = db.Exec("create table foo (id integer)") + if err != nil { + t.Fatal("Failed to create table:", err) + } + + res, err := db.Exec("insert into foo(id) values(123)") + if err != nil { + t.Fatal("Failed to insert record:", err) + } + expected, err := res.LastInsertId() + if err != nil { + t.Fatal("Failed to get LastInsertId:", err) + } + affected, _ := res.RowsAffected() + if err != nil { + t.Fatal("Failed to get RowsAffected:", err) + } + if affected != 1 { + t.Fatalf("Expected %d for affected rows, but %d:", 1, affected) + } + + res, err = db.Exec("update foo set id = 234") + if err != nil { + t.Fatal("Failed to update record:", err) + } + lastID, err := res.LastInsertId() + if err != nil { + t.Fatal("Failed to get LastInsertId:", err) + } + if expected != lastID { + t.Errorf("Expected %q for last Id, but %q:", expected, lastID) + } + affected, _ = res.RowsAffected() + if err != nil { + t.Fatal("Failed to get RowsAffected:", err) + } + if affected != 1 { + t.Fatalf("Expected %d for affected rows, but %d:", 1, affected) + } + + rows, err := db.Query("select id from foo") + if err != nil { + t.Fatal("Failed to select records:", err) + } + defer rows.Close() + + rows.Next() + + var result int + rows.Scan(&result) + if result != 234 { + t.Errorf("Expected %d for fetched result, but %d:", 234, result) + } +} + +func TestDelete(t *testing.T) { + db, err := sql.Open(DriverName, ":memory:") + if err != nil { + t.Fatal("Failed to open database:", err) + } + defer db.Close() + + _, err = db.Exec("drop table foo") + _, err = db.Exec("create table foo (id integer)") + if err != nil { + t.Fatal("Failed to create table:", err) + } + + res, err := db.Exec("insert into foo(id) values(123)") + if err != nil { + t.Fatal("Failed to insert record:", err) + } + expected, err := res.LastInsertId() + if err != nil { + t.Fatal("Failed to get LastInsertId:", err) + } + affected, err := res.RowsAffected() + if err != nil { + t.Fatal("Failed to get RowsAffected:", err) + } + if affected != 1 { + t.Errorf("Expected %d for cout of affected rows, but %q:", 1, affected) + } + + res, err = db.Exec("delete from foo where id = 123") + if err != nil { + t.Fatal("Failed to delete record:", err) + } + lastID, err := res.LastInsertId() + if err != nil { + t.Fatal("Failed to get LastInsertId:", err) + } + if expected != lastID { + t.Errorf("Expected %q for last Id, but %q:", expected, lastID) + } + affected, err = res.RowsAffected() + if err != nil { + t.Fatal("Failed to get RowsAffected:", err) + } + if affected != 1 { + t.Errorf("Expected %d for cout of affected rows, but %q:", 1, affected) + } + + rows, err := db.Query("select id from foo") + if err != nil { + t.Fatal("Failed to select records:", err) + } + defer rows.Close() + + if rows.Next() { + t.Error("Fetched row but expected not rows") + } +} + +func TestBooleanRoundtrip(t *testing.T) { + db, err := sql.Open(DriverName, ":memory:") + if err != nil { + t.Fatal("Failed to open database:", err) + } + defer db.Close() + + _, err = db.Exec("DROP TABLE foo") + _, err = db.Exec("CREATE TABLE foo(id INTEGER, value BOOL)") + if err != nil { + t.Fatal("Failed to create table:", err) + } + + _, err = db.Exec("INSERT INTO foo(id, value) VALUES(1, ?)", true) + if err != nil { + t.Fatal("Failed to insert true value:", err) + } + + _, err = db.Exec("INSERT INTO foo(id, value) VALUES(2, ?)", false) + if err != nil { + t.Fatal("Failed to insert false value:", err) + } + + rows, err := db.Query("SELECT id, value FROM foo") + if err != nil { + t.Fatal("Unable to query foo table:", err) + } + defer rows.Close() + + for rows.Next() { + var id int + var value bool + + if err := rows.Scan(&id, &value); err != nil { + t.Error("Unable to scan results:", err) + continue + } + + if id == 1 && !value { + t.Error("Value for id 1 should be true, not false") + + } else if id == 2 && value { + t.Error("Value for id 2 should be false, not true") + } + } +} + +func timezone(t time.Time) string { return t.Format("-07:00") } + +func TestTimestamp(t *testing.T) { + db, err := sql.Open(DriverName, ":memory:") + if err != nil { + t.Fatal("Failed to open database:", err) + } + defer db.Close() + + _, err = db.Exec("DROP TABLE foo") + _, err = db.Exec("CREATE TABLE foo(id INTEGER, ts timeSTAMP, dt DATETIME)") + if err != nil { + t.Fatal("Failed to create table:", err) + } + + timestamp1 := time.Date(2012, time.April, 6, 22, 50, 0, 0, time.UTC) + timestamp2 := time.Date(2006, time.January, 2, 15, 4, 5, 123456789, time.UTC) + timestamp3 := time.Date(2012, time.November, 4, 0, 0, 0, 0, time.UTC) + tzTest := time.FixedZone("TEST", -9*3600-13*60) + tests := []struct { + value any + expected time.Time + }{ + {"nonsense", time.Time{}}, + {"0000-00-00 00:00:00", time.Time{}}, + {time.Time{}.Unix(), time.Time{}}, + {timestamp1, timestamp1}, + {timestamp2.Unix(), timestamp2.Truncate(time.Second)}, + {timestamp2.UnixNano() / int64(time.Millisecond), timestamp2.Truncate(time.Millisecond)}, + {timestamp1.In(tzTest), timestamp1.In(tzTest)}, + {timestamp1.Format("2006-01-02 15:04:05.000"), timestamp1}, + {timestamp1.Format("2006-01-02T15:04:05.000"), timestamp1}, + {timestamp1.Format("2006-01-02 15:04:05"), timestamp1}, + {timestamp1.Format("2006-01-02T15:04:05"), timestamp1}, + {timestamp2, timestamp2}, + {"2006-01-02 15:04:05.123456789", timestamp2}, + {"2006-01-02T15:04:05.123456789", timestamp2}, + {"2006-01-02T05:51:05.123456789-09:13", timestamp2.In(tzTest)}, + {"2012-11-04", timestamp3}, + {"2012-11-04 00:00", timestamp3}, + {"2012-11-04 00:00:00", timestamp3}, + {"2012-11-04 00:00:00.000", timestamp3}, + {"2012-11-04T00:00", timestamp3}, + {"2012-11-04T00:00:00", timestamp3}, + {"2012-11-04T00:00:00.000", timestamp3}, + {"2006-01-02T15:04:05.123456789Z", timestamp2}, + {"2012-11-04Z", timestamp3}, + {"2012-11-04 00:00Z", timestamp3}, + {"2012-11-04 00:00:00Z", timestamp3}, + {"2012-11-04 00:00:00.000Z", timestamp3}, + {"2012-11-04T00:00Z", timestamp3}, + {"2012-11-04T00:00:00Z", timestamp3}, + {"2012-11-04T00:00:00.000Z", timestamp3}, + } + for i := range tests { + _, err = db.Exec("INSERT INTO foo(id, ts, dt) VALUES(?, ?, ?)", i, tests[i].value, tests[i].value) + if err != nil { + t.Fatal("Failed to insert timestamp:", err) + } + } + + rows, err := db.Query("SELECT id, ts, dt FROM foo ORDER BY id ASC") + if err != nil { + t.Fatal("Unable to query foo table:", err) + } + defer rows.Close() + + seen := 0 + for rows.Next() { + var id int + var ts, dt time.Time + + if err := rows.Scan(&id, &ts, &dt); err != nil { + t.Error("Unable to scan results:", err) + continue + } + if id < 0 || id >= len(tests) { + t.Error("Bad row id: ", id) + continue + } + seen++ + if !tests[id].expected.Equal(ts) { + t.Errorf("Timestamp value for id %v (%v) should be %v, not %v", id, tests[id].value, tests[id].expected, dt) + } + if !tests[id].expected.Equal(dt) { + t.Errorf("Datetime value for id %v (%v) should be %v, not %v", id, tests[id].value, tests[id].expected, dt) + } + if timezone(tests[id].expected) != timezone(ts) { + t.Errorf("Timezone for id %v (%v) should be %v, not %v", id, tests[id].value, + timezone(tests[id].expected), timezone(ts)) + } + if timezone(tests[id].expected) != timezone(dt) { + t.Errorf("Timezone for id %v (%v) should be %v, not %v", id, tests[id].value, + timezone(tests[id].expected), timezone(dt)) + } + } + + if seen != len(tests) { + t.Errorf("Expected to see %d rows", len(tests)) + } +} + +func TestBoolean(t *testing.T) { + db, err := sql.Open(DriverName, ":memory:") + if err != nil { + t.Fatal("Failed to open database:", err) + } + + defer db.Close() + + _, err = db.Exec("CREATE TABLE foo(id INTEGER, fbool BOOLEAN)") + if err != nil { + t.Fatal("Failed to create table:", err) + } + + bool1 := true + _, err = db.Exec("INSERT INTO foo(id, fbool) VALUES(1, ?)", bool1) + if err != nil { + t.Fatal("Failed to insert boolean:", err) + } + + bool2 := false + _, err = db.Exec("INSERT INTO foo(id, fbool) VALUES(2, ?)", bool2) + if err != nil { + t.Fatal("Failed to insert boolean:", err) + } + + bool3 := "nonsense" + _, err = db.Exec("INSERT INTO foo(id, fbool) VALUES(3, ?)", bool3) + if err != nil { + t.Fatal("Failed to insert nonsense:", err) + } + + rows, err := db.Query("SELECT id, fbool FROM foo where fbool = ?", bool1) + if err != nil { + t.Fatal("Unable to query foo table:", err) + } + counter := 0 + + var id int + var fbool bool + + for rows.Next() { + if err := rows.Scan(&id, &fbool); err != nil { + t.Fatal("Unable to scan results:", err) + } + counter++ + } + + if counter != 1 { + t.Fatalf("Expected 1 row but %v", counter) + } + + if id != 1 && !fbool { + t.Fatalf("Value for id 1 should be %v, not %v", bool1, fbool) + } + + rows, err = db.Query("SELECT id, fbool FROM foo where fbool = ?", bool2) + if err != nil { + t.Fatal("Unable to query foo table:", err) + } + + counter = 0 + + for rows.Next() { + if err := rows.Scan(&id, &fbool); err != nil { + t.Fatal("Unable to scan results:", err) + } + counter++ + } + + if counter != 1 { + t.Fatalf("Expected 1 row but %v", counter) + } + + if id != 2 && fbool { + t.Fatalf("Value for id 2 should be %v, not %v", bool2, fbool) + } + + // make sure "nonsense" triggered an error + rows, err = db.Query("SELECT id, fbool FROM foo where id=?;", 3) + if err != nil { + t.Fatal("Unable to query foo table:", err) + } + + rows.Next() + err = rows.Scan(&id, &fbool) + if err == nil { + t.Error("Expected error from \"nonsense\" bool") + } +} + +func TestFloat32(t *testing.T) { + db, err := sql.Open(DriverName, ":memory:") + if err != nil { + t.Fatal("Failed to open database:", err) + } + defer db.Close() + + _, err = db.Exec("CREATE TABLE foo(id INTEGER)") + if err != nil { + t.Fatal("Failed to create table:", err) + } + + _, err = db.Exec("INSERT INTO foo(id) VALUES(null)") + if err != nil { + t.Fatal("Failed to insert null:", err) + } + + rows, err := db.Query("SELECT id FROM foo") + if err != nil { + t.Fatal("Unable to query foo table:", err) + } + + if !rows.Next() { + t.Fatal("Unable to query results:", err) + } + + var id any + if err := rows.Scan(&id); err != nil { + t.Fatal("Unable to scan results:", err) + } + if id != nil { + t.Error("Expected nil but not") + } +} + +func TestNull(t *testing.T) { + db, err := sql.Open(DriverName, ":memory:") + if err != nil { + t.Fatal("Failed to open database:", err) + } + defer db.Close() + + rows, err := db.Query("SELECT 3.141592") + if err != nil { + t.Fatal("Unable to query foo table:", err) + } + + if !rows.Next() { + t.Fatal("Unable to query results:", err) + } + + var v any + if err := rows.Scan(&v); err != nil { + t.Fatal("Unable to scan results:", err) + } + f, ok := v.(float64) + if !ok { + t.Error("Expected float but not") + } + if f != 3.141592 { + t.Error("Expected 3.141592 but not") + } +} + +func TestTransaction(t *testing.T) { + db, err := sql.Open(DriverName, ":memory:") + if err != nil { + t.Fatal("Failed to open database:", err) + } + defer db.Close() + + _, err = db.Exec("CREATE TABLE foo(id INTEGER)") + if err != nil { + t.Fatal("Failed to create table:", err) + } + + tx, err := db.Begin() + if err != nil { + t.Fatal("Failed to begin transaction:", err) + } + + _, err = tx.Exec("INSERT INTO foo(id) VALUES(1)") + if err != nil { + t.Fatal("Failed to insert null:", err) + } + + rows, err := tx.Query("SELECT id from foo") + if err != nil { + t.Fatal("Unable to query foo table:", err) + } + + err = tx.Rollback() + if err != nil { + t.Fatal("Failed to rollback transaction:", err) + } + + if rows.Next() { + t.Fatal("Unable to query results:", err) + } + + tx, err = db.Begin() + if err != nil { + t.Fatal("Failed to begin transaction:", err) + } + + _, err = tx.Exec("INSERT INTO foo(id) VALUES(1)") + if err != nil { + t.Fatal("Failed to insert null:", err) + } + + err = tx.Commit() + if err != nil { + t.Fatal("Failed to commit transaction:", err) + } + + rows, err = tx.Query("SELECT id from foo") + if err == nil { + t.Fatal("Expected failure to query") + } +} + +func TestWAL(t *testing.T) { + db, err := sql.Open(DriverName, ":memory:") + if err != nil { + t.Fatal("Failed to open database:", err) + } + defer db.Close() + + if _, err = db.Exec("CREATE TABLE test (id SERIAL, user TEXT NOT NULL, name TEXT NOT NULL);"); err != nil { + t.Fatal("Failed to Exec CREATE TABLE:", err) + } + if _, err = db.Exec("INSERT INTO test (user, name) VALUES ('user','name');"); err != nil { + t.Fatal("Failed to Exec INSERT:", err) + } + + trans, err := db.Begin() + if err != nil { + t.Fatal("Failed to Begin:", err) + } + s, err := trans.Prepare("INSERT INTO test (user, name) VALUES (?, ?);") + if err != nil { + t.Fatal("Failed to Prepare:", err) + } + + var count int + if err = trans.QueryRow("SELECT count(user) FROM test;").Scan(&count); err != nil { + t.Fatal("Failed to QueryRow:", err) + } + if _, err = s.Exec("bbbb", "aaaa"); err != nil { + t.Fatal("Failed to Exec prepared statement:", err) + } + if err = s.Close(); err != nil { + t.Fatal("Failed to Close prepared statement:", err) + } + if err = trans.Commit(); err != nil { + t.Fatal("Failed to Commit:", err) + } +} + +func TestTimezoneConversion(t *testing.T) { + zones := []string{"UTC", "US/Central", "US/Pacific", "Local"} + for _, tz := range zones { + db, err := sql.Open(DriverName, "file:tz?mode=memory&_loc="+url.QueryEscape(tz)) + if err != nil { + t.Fatal("Failed to open database:", err) + } + defer db.Close() + + _, err = db.Exec("DROP TABLE foo") + _, err = db.Exec("CREATE TABLE foo(id INTEGER, ts TIMESTAMP, dt DATETIME)") + if err != nil { + t.Fatal("Failed to create table:", err) + } + + loc, err := time.LoadLocation(tz) + if err != nil { + t.Fatal("Failed to load location:", err) + } + + timestamp1 := time.Date(2012, time.April, 6, 22, 50, 0, 0, time.UTC) + timestamp2 := time.Date(2006, time.January, 2, 15, 4, 5, 123456789, time.UTC) + timestamp3 := time.Date(2012, time.November, 4, 0, 0, 0, 0, time.UTC) + tests := []struct { + value any + expected time.Time + }{ + {"nonsense", time.Time{}.In(loc)}, + {"0000-00-00 00:00:00", time.Time{}.In(loc)}, + {timestamp1, timestamp1.In(loc)}, + {timestamp1.Unix(), timestamp1.In(loc)}, + {timestamp1.In(time.FixedZone("TEST", -7*3600)), timestamp1.In(loc)}, + {timestamp1.Format("2006-01-02 15:04:05.000"), timestamp1.In(loc)}, + {timestamp1.Format("2006-01-02T15:04:05.000"), timestamp1.In(loc)}, + {timestamp1.Format("2006-01-02 15:04:05"), timestamp1.In(loc)}, + {timestamp1.Format("2006-01-02T15:04:05"), timestamp1.In(loc)}, + {timestamp2, timestamp2.In(loc)}, + {"2006-01-02 15:04:05.123456789", timestamp2.In(loc)}, + {"2006-01-02T15:04:05.123456789", timestamp2.In(loc)}, + {"2012-11-04", timestamp3.In(loc)}, + {"2012-11-04 00:00", timestamp3.In(loc)}, + {"2012-11-04 00:00:00", timestamp3.In(loc)}, + {"2012-11-04 00:00:00.000", timestamp3.In(loc)}, + {"2012-11-04T00:00", timestamp3.In(loc)}, + {"2012-11-04T00:00:00", timestamp3.In(loc)}, + {"2012-11-04T00:00:00.000", timestamp3.In(loc)}, + } + for i := range tests { + _, err = db.Exec("INSERT INTO foo(id, ts, dt) VALUES(?, ?, ?)", i, tests[i].value, tests[i].value) + if err != nil { + t.Fatal("Failed to insert timestamp:", err) + } + } + + rows, err := db.Query("SELECT id, ts, dt FROM foo ORDER BY id ASC") + if err != nil { + t.Fatal("Unable to query foo table:", err) + } + defer rows.Close() + + seen := 0 + for rows.Next() { + var id int + var ts, dt time.Time + + if err := rows.Scan(&id, &ts, &dt); err != nil { + t.Error("Unable to scan results:", err) + continue + } + if id < 0 || id >= len(tests) { + t.Error("Bad row id: ", id) + continue + } + seen++ + if !tests[id].expected.Equal(ts) { + t.Errorf("Timestamp value for id %v (%v) should be %v, not %v", id, tests[id].value, tests[id].expected, ts) + } + if !tests[id].expected.Equal(dt) { + t.Errorf("Datetime value for id %v (%v) should be %v, not %v", id, tests[id].value, tests[id].expected, dt) + } + if tests[id].expected.Location().String() != ts.Location().String() { + t.Errorf("Location for id %v (%v) should be %v, not %v", id, tests[id].value, tests[id].expected.Location().String(), ts.Location().String()) + } + if tests[id].expected.Location().String() != dt.Location().String() { + t.Errorf("Location for id %v (%v) should be %v, not %v", id, tests[id].value, tests[id].expected.Location().String(), dt.Location().String()) + } + } + + if seen != len(tests) { + t.Errorf("Expected to see %d rows", len(tests)) + } + } +} + +// TODO: Execer & Queryer currently disabled +// https://github.com/mattn/go-sqlite3/issues/82 +func TestExecer(t *testing.T) { + db, err := sql.Open(DriverName, ":memory:") + if err != nil { + t.Fatal("Failed to open database:", err) + } + defer db.Close() + + _, err = db.Exec(` + create table foo (id integer); -- one comment + insert into foo(id) values(?); + insert into foo(id) values(?); + insert into foo(id) values(?); -- another comment + `, 1, 2, 3) + if err != nil { + t.Error("Failed to call db.Exec:", err) + } +} + +func TestQueryer(t *testing.T) { + db, err := sql.Open(DriverName, ":memory:") + if err != nil { + t.Fatal("Failed to open database:", err) + } + defer db.Close() + + _, err = db.Exec(` + create table foo (id integer); + `) + if err != nil { + t.Error("Failed to call db.Query:", err) + } + + _, err = db.Exec(` + insert into foo(id) values(?); + insert into foo(id) values(?); + insert into foo(id) values(?); + `, 3, 2, 1) + if err != nil { + t.Error("Failed to call db.Exec:", err) + } + rows, err := db.Query(` + select id from foo order by id; + `) + if err != nil { + t.Error("Failed to call db.Query:", err) + } + defer rows.Close() + n := 0 + for rows.Next() { + var id int + err = rows.Scan(&id) + if err != nil { + t.Error("Failed to db.Query:", err) + } + if id != n+1 { + t.Error("Failed to db.Query: not matched results") + } + n = n + 1 + } + if err := rows.Err(); err != nil { + t.Errorf("Post-scan failed: %v\n", err) + } + if n != 3 { + t.Errorf("Expected 3 rows but retrieved %v", n) + } +} + +func TestStress(t *testing.T) { + db, err := sql.Open(DriverName, ":memory:") + if err != nil { + t.Fatal("Failed to open database:", err) + } + defer db.Close() + db.Exec("CREATE TABLE foo (id int);") + db.Exec("INSERT INTO foo VALUES(1);") + db.Exec("INSERT INTO foo VALUES(2);") + + for i := 0; i < 10000; i++ { + for j := 0; j < 3; j++ { + rows, err := db.Query("select * from foo where id=1;") + if err != nil { + t.Error("Failed to call db.Query:", err) + } + for rows.Next() { + var i int + if err := rows.Scan(&i); err != nil { + t.Errorf("Scan failed: %v\n", err) + } + } + if err := rows.Err(); err != nil { + t.Errorf("Post-scan failed: %v\n", err) + } + rows.Close() + } + } +} + +func TestDateTimeLocal(t *testing.T) { + const zone = "Asia/Tokyo" + tempFilename := TempFilename(t) + defer os.Remove(tempFilename) + filename1 := tempFilename + "?mode=memory&cache=shared" + filename2 := filename1 + "&_loc=" + zone + db1, err := sql.Open(DriverName, filename2) + if err != nil { + t.Fatal("Failed to open database:", err) + } + defer db1.Close() + db1.Exec("CREATE TABLE foo (dt datetime);") + db1.Exec("INSERT INTO foo VALUES('2015-03-05 15:16:17');") + + row := db1.QueryRow("select * from foo") + var d time.Time + err = row.Scan(&d) + if err != nil { + t.Fatal("Failed to scan datetime:", err) + } + if d.Hour() == 15 || !strings.Contains(d.String(), "JST") { + t.Fatal("Result should have timezone", d) + } + + db2, err := sql.Open(DriverName, filename1) + if err != nil { + t.Fatal("Failed to open database:", err) + } + defer db2.Close() + + row = db2.QueryRow("select * from foo") + err = row.Scan(&d) + if err != nil { + t.Fatal("Failed to scan datetime:", err) + } + if d.UTC().Hour() != 15 || !strings.Contains(d.String(), "UTC") { + t.Fatalf("Result should not have timezone %v %v", zone, d.String()) + } + + _, err = db2.Exec("DELETE FROM foo") + if err != nil { + t.Fatal("Failed to delete table:", err) + } + dt, err := time.Parse("2006/1/2 15/4/5 -0700 MST", "2015/3/5 15/16/17 +0900 JST") + if err != nil { + t.Fatal("Failed to parse datetime:", err) + } + db2.Exec("INSERT INTO foo VALUES(?);", dt) + + db3, err := sql.Open(DriverName, filename2) + if err != nil { + t.Fatal("Failed to open database:", err) + } + defer db3.Close() + + row = db3.QueryRow("select * from foo") + err = row.Scan(&d) + if err != nil { + t.Fatal("Failed to scan datetime:", err) + } + if d.Hour() != 15 || !strings.Contains(d.String(), "JST") { + t.Fatalf("Result should have timezone %v %v", zone, d.String()) + } +} + +func TestVersion(t *testing.T) { + s, n, id := LibVersion() + if s == "" || n == 0 || id == "" { + t.Errorf("Version failed %q, %d, %q\n", s, n, id) + } +} + +func TestStringContainingZero(t *testing.T) { + db, err := sql.Open(DriverName, ":memory:") + if err != nil { + t.Fatal("Failed to open database:", err) + } + defer db.Close() + + _, err = db.Exec(` + create table foo (id integer, name, extra text); + `) + if err != nil { + t.Error("Failed to call db.Query:", err) + } + + const text = "foo\x00bar" + + _, err = db.Exec(`insert into foo(id, name, extra) values($1, $2, $2)`, 1, text) + if err != nil { + t.Error("Failed to call db.Exec:", err) + } + + row := db.QueryRow(`select id, extra from foo where id = $1 and extra = $2`, 1, text) + if row == nil { + t.Error("Failed to call db.QueryRow") + } + + var id int + var extra string + err = row.Scan(&id, &extra) + if err != nil { + t.Error("Failed to db.Scan:", err) + } + if id != 1 || extra != text { + t.Error("Failed to db.QueryRow: not matched results") + } +} + +const CurrentTimeStamp = "2006-01-02 15:04:05" + +type TimeStamp struct{ *time.Time } + +func (t TimeStamp) Scan(value any) error { + var err error + switch v := value.(type) { + case string: + *t.Time, err = time.Parse(CurrentTimeStamp, v) + case []byte: + *t.Time, err = time.Parse(CurrentTimeStamp, string(v)) + default: + err = errors.New("invalid type for current_timestamp") + } + return err +} + +func (t TimeStamp) Value() (driver.Value, error) { + return t.Time.Format(CurrentTimeStamp), nil +} + +func TestDateTimeNow(t *testing.T) { + db, err := sql.Open(DriverName, ":memory:") + if err != nil { + t.Fatal("Failed to open database:", err) + } + defer db.Close() + + var d time.Time + err = db.QueryRow("SELECT datetime('now')").Scan(TimeStamp{&d}) + if err != nil { + t.Fatal("Failed to scan datetime:", err) + } +} + +type sumAggregator int64 + +func (s *sumAggregator) Step(x int64) { + *s += sumAggregator(x) +} + +func (s *sumAggregator) Done() int64 { + return int64(*s) +} + +func TestAggregatorRegistration(t *testing.T) { + customSum := func() *sumAggregator { + var ret sumAggregator + return &ret + } + + sql.Register("sqlite3_AggregatorRegistration", &SQLiteDriver{ + ConnectHook: func(conn *SQLiteConn) error { + return conn.RegisterAggregator("customSum", customSum, true) + }, + }) + db, err := sql.Open("sqlite3_AggregatorRegistration", ":memory:") + if err != nil { + t.Fatal("Failed to open database:", err) + } + defer db.Close() + + _, err = db.Exec("create table foo (department integer, profits integer)") + if err != nil { + // trace feature is not implemented + t.Skip("Failed to create table:", err) + } + + _, err = db.Exec("insert into foo values (1, 10), (1, 20), (2, 42)") + if err != nil { + t.Fatal("Failed to insert records:", err) + } + + tests := []struct { + dept, sum int64 + }{ + {1, 30}, + {2, 42}, + } + + for _, test := range tests { + var ret int64 + err = db.QueryRow("select customSum(profits) from foo where department = $1 group by department", test.dept).Scan(&ret) + if err != nil { + t.Fatal("Query failed:", err) + } + if ret != test.sum { + t.Fatalf("Custom sum returned wrong value, got %d, want %d", ret, test.sum) + } + } +} + +type mode struct { + counts map[any]int + top any + topCount int +} + +func newMode() *mode { + return &mode{ + counts: map[any]int{}, + } +} + +func (m *mode) Step(x any) { + m.counts[x]++ + c := m.counts[x] + if c > m.topCount { + m.top = x + m.topCount = c + } +} + +func (m *mode) Done() any { + return m.top +} + +func TestAggregatorRegistration_GenericReturn(t *testing.T) { + sql.Register("sqlite3_AggregatorRegistration_GenericReturn", &SQLiteDriver{ + ConnectHook: func(conn *SQLiteConn) error { + return conn.RegisterAggregator("mode", newMode, true) + }, + }) + db, err := sql.Open("sqlite3_AggregatorRegistration_GenericReturn", ":memory:") + if err != nil { + t.Fatal("Failed to open database:", err) + } + defer db.Close() + + _, err = db.Exec("create table foo (department integer, profits integer)") + if err != nil { + t.Fatal("Failed to create table:", err) + } + _, err = db.Exec("insert into foo values (1, 10), (1, 20), (1, 45), (2, 42), (2, 115), (2, 20)") + if err != nil { + t.Fatal("Failed to insert records:", err) + } + + var mode int + err = db.QueryRow("select mode(profits) from foo").Scan(&mode) + if err != nil { + t.Fatal("MODE query error:", err) + } + + if mode != 20 { + t.Fatal("Got incorrect mode. Wanted 20, got: ", mode) + } +} + +func rot13(r rune) rune { + switch { + case r >= 'A' && r <= 'Z': + return 'A' + (r-'A'+13)%26 + case r >= 'a' && r <= 'z': + return 'a' + (r-'a'+13)%26 + } + return r +} + +func TestCollationRegistration(t *testing.T) { + collateRot13 := func(a, b string) int { + ra, rb := strings.Map(rot13, a), strings.Map(rot13, b) + return strings.Compare(ra, rb) + } + collateRot13Reverse := func(a, b string) int { + return collateRot13(b, a) + } + + sql.Register("sqlite3_CollationRegistration", &SQLiteDriver{ + ConnectHook: func(conn *SQLiteConn) error { + if err := conn.RegisterCollation("rot13", collateRot13); err != nil { + return err + } + if err := conn.RegisterCollation("rot13reverse", collateRot13Reverse); err != nil { + return err + } + return nil + }, + }) + + db, err := sql.Open("sqlite3_CollationRegistration", ":memory:") + if err != nil { + t.Fatal("Failed to open database:", err) + } + defer db.Close() + + populate := []string{ + `CREATE TABLE test (s TEXT)`, + `INSERT INTO test VALUES ('aaaa')`, + `INSERT INTO test VALUES ('ffff')`, + `INSERT INTO test VALUES ('qqqq')`, + `INSERT INTO test VALUES ('tttt')`, + `INSERT INTO test VALUES ('zzzz')`, + } + for _, stmt := range populate { + if _, err := db.Exec(stmt); err != nil { + t.Fatal("Failed to populate test DB:", err) + } + } + + ops := []struct { + query string + want []string + }{ + { + "SELECT * FROM test ORDER BY s COLLATE rot13 ASC", + []string{ + "qqqq", + "tttt", + "zzzz", + "aaaa", + "ffff", + }, + }, + { + "SELECT * FROM test ORDER BY s COLLATE rot13 DESC", + []string{ + "ffff", + "aaaa", + "zzzz", + "tttt", + "qqqq", + }, + }, + { + "SELECT * FROM test ORDER BY s COLLATE rot13reverse ASC", + []string{ + "ffff", + "aaaa", + "zzzz", + "tttt", + "qqqq", + }, + }, + { + "SELECT * FROM test ORDER BY s COLLATE rot13reverse DESC", + []string{ + "qqqq", + "tttt", + "zzzz", + "aaaa", + "ffff", + }, + }, + } + + for _, op := range ops { + rows, err := db.Query(op.query) + if err != nil { + t.Fatalf("Query %q failed: %s", op.query, err) + } + got := []string{} + defer rows.Close() + for rows.Next() { + var s string + if err = rows.Scan(&s); err != nil { + t.Fatalf("Reading row for %q: %s", op.query, err) + } + got = append(got, s) + } + if err = rows.Err(); err != nil { + t.Fatalf("Reading rows for %q: %s", op.query, err) + } + + if !reflect.DeepEqual(got, op.want) { + t.Fatalf("Unexpected output from %q\ngot:\n%s\n\nwant:\n%s", op.query, strings.Join(got, "\n"), strings.Join(op.want, "\n")) + } + } +} + +func TestDeclTypes(t *testing.T) { + + d := SQLiteDriver{} + + conn, err := d.Open(":memory:") + if err != nil { + t.Fatal("Failed to begin transaction:", err) + } + defer conn.Close() + + sqlite3conn := conn.(*SQLiteConn) + + _, err = sqlite3conn.Exec("create table foo (id integer not null primary key, name text)", nil) + if err != nil { + t.Fatal("Failed to create table:", err) + } + + _, err = sqlite3conn.Exec("insert into foo(name) values('bar')", nil) + if err != nil { + t.Fatal("Failed to insert:", err) + } + + rs, err := sqlite3conn.Query("select * from foo", nil) + if err != nil { + t.Fatal("Failed to select:", err) + } + defer rs.Close() + + declTypes := rs.(*SQLiteRows).DeclTypes() + + if !reflect.DeepEqual(declTypes, []string{"integer", "text"}) { + t.Fatal("Unexpected declTypes:", declTypes) + } +} + +func TestPinger(t *testing.T) { + db, err := sql.Open(DriverName, ":memory:") + if err != nil { + t.Fatal(err) + } + err = db.Ping() + if err != nil { + t.Fatal(err) + } + db.Close() + err = db.Ping() + if err == nil { + t.Fatal("Should be closed") + } +} + +func TestUpdateAndTransactionHooks(t *testing.T) { + var events []string + var commitHookReturn = 0 + + sql.Register("sqlite3_UpdateHook", &SQLiteDriver{ + ConnectHook: func(conn *SQLiteConn) error { + conn.RegisterCommitHook(func() int { + events = append(events, "commit") + return commitHookReturn + }) + conn.RegisterRollbackHook(func() { + events = append(events, "rollback") + }) + conn.RegisterUpdateHook(func(op int, db string, table string, rowid int64) { + events = append(events, fmt.Sprintf("update(op=%v db=%v table=%v rowid=%v)", op, db, table, rowid)) + }) + return nil + }, + }) + db, err := sql.Open("sqlite3_UpdateHook", ":memory:") + if err != nil { + t.Fatal("Failed to open database:", err) + } + defer db.Close() + + statements := []string{ + "create table foo (id integer primary key)", + "insert into foo values (9)", + "update foo set id = 99 where id = 9", + "delete from foo where id = 99", + } + for _, statement := range statements { + _, err = db.Exec(statement) + if err != nil { + t.Fatalf("Unable to prepare test data [%v]: %v", statement, err) + } + } + + commitHookReturn = 1 + _, err = db.Exec("insert into foo values (5)") + if err == nil { + t.Error("Commit hook failed to rollback transaction") + } + + var expected = []string{ + "commit", + fmt.Sprintf("update(op=%v db=main table=foo rowid=9)", SQLITE_INSERT), + "commit", + fmt.Sprintf("update(op=%v db=main table=foo rowid=99)", SQLITE_UPDATE), + "commit", + fmt.Sprintf("update(op=%v db=main table=foo rowid=99)", SQLITE_DELETE), + "commit", + fmt.Sprintf("update(op=%v db=main table=foo rowid=5)", SQLITE_INSERT), + "commit", + "rollback", + } + if !reflect.DeepEqual(events, expected) { + t.Errorf("Expected notifications: %#v\nbut got: %#v\n", expected, events) + } +} + +func TestSetFileControlInt(t *testing.T) { + t.Run("PERSIST_WAL", func(t *testing.T) { + tempFilename := TempFilename(t) + defer os.Remove(tempFilename) + + sql.Register("sqlite3_FCNTL_PERSIST_WAL", &SQLiteDriver{ + ConnectHook: func(conn *SQLiteConn) error { + if err := conn.SetFileControlInt("", SQLITE_FCNTL_PERSIST_WAL, 1); err != nil { + return fmt.Errorf("Unexpected error from SetFileControlInt(): %w", err) + } + return nil + }, + }) + + db, err := sql.Open("sqlite3_FCNTL_PERSIST_WAL", tempFilename) + if err != nil { + t.Fatal("Failed to open database:", err) + } + defer db.Close() + + if _, err := db.Exec(`CREATE TABLE t (x)`); err != nil { + t.Fatal("Failed to create table:", err) + } + if err := db.Close(); err != nil { + t.Fatal("Failed to close database", err) + } + + // Ensure WAL file persists after close. + if _, err := os.Stat(tempFilename + "-wal"); err != nil { + t.Fatal("Expected WAL file to be persisted after close", err) + } + }) +} + +func TestNonColumnString(t *testing.T) { + db, err := sql.Open(DriverName, ":memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + var x any + if err := db.QueryRow("SELECT 'hello'").Scan(&x); err != nil { + t.Fatal(err) + } + s, ok := x.(string) + if !ok { + t.Fatalf("non-column string must return string but got %T", x) + } + if s != "hello" { + t.Fatalf("non-column string must return %q but got %q", "hello", s) + } +} + +func TestNilAndEmptyBytes(t *testing.T) { + db, err := sql.Open(DriverName, ":memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + actualNil := []byte("use this to use an actual nil not a reference to nil") + emptyBytes := []byte{} + for tsti, tst := range []struct { + name string + columnType string + insertBytes []byte + expectedBytes []byte + }{ + {"actual nil blob", "blob", actualNil, nil}, + {"referenced nil blob", "blob", nil, nil}, + {"empty blob", "blob", emptyBytes, emptyBytes}, + {"actual nil text", "text", actualNil, nil}, + {"referenced nil text", "text", nil, nil}, + {"empty text", "text", emptyBytes, emptyBytes}, + } { + if _, err = db.Exec(fmt.Sprintf("create table tbl%d (txt %s)", tsti, tst.columnType)); err != nil { + t.Fatal(tst.name, err) + } + if bytes.Equal(tst.insertBytes, actualNil) { + if _, err = db.Exec(fmt.Sprintf("insert into tbl%d (txt) values (?)", tsti), nil); err != nil { + t.Fatal(tst.name, err) + } + } else { + if _, err = db.Exec(fmt.Sprintf("insert into tbl%d (txt) values (?)", tsti), &tst.insertBytes); err != nil { + t.Fatal(tst.name, err) + } + } + rows, err := db.Query(fmt.Sprintf("select txt from tbl%d", tsti)) + if err != nil { + t.Fatal(tst.name, err) + } + if !rows.Next() { + t.Fatal(tst.name, "no rows") + } + var scanBytes []byte + if err = rows.Scan(&scanBytes); err != nil { + t.Fatal(tst.name, err) + } + if err = rows.Err(); err != nil { + t.Fatal(tst.name, err) + } + if tst.expectedBytes == nil && scanBytes != nil { + t.Errorf("%s: %#v != %#v", tst.name, scanBytes, tst.expectedBytes) + } else if !bytes.Equal(scanBytes, tst.expectedBytes) { + t.Errorf("%s: %#v != %#v", tst.name, scanBytes, tst.expectedBytes) + } + } +} + +func TestInsertNilByteSlice(t *testing.T) { + db, err := sql.Open(DriverName, ":memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + if _, err := db.Exec("create table blob_not_null (b blob not null)"); err != nil { + t.Fatal(err) + } + var nilSlice []byte + if _, err := db.Exec("insert into blob_not_null (b) values (?)", nilSlice); err == nil { + t.Fatal("didn't expect INSERT to 'not null' column with a nil []byte slice to work") + } + zeroLenSlice := []byte{} + if _, err := db.Exec("insert into blob_not_null (b) values (?)", zeroLenSlice); err != nil { + t.Fatal("failed to insert zero-length slice") + } +} + +func TestNamedParam(t *testing.T) { + tempFilename := TempFilename(t) + defer os.Remove(tempFilename) + db, err := sql.Open(DriverName, ":memory:") + if err != nil { + t.Fatal("Failed to open database:", err) + } + defer db.Close() + + _, err = db.Exec("drop table foo") + _, err = db.Exec("create table foo (id integer, name text, amount integer)") + if err != nil { + t.Fatal("Failed to create table:", err) + } + + _, err = db.Exec("insert into foo(id, name, amount) values(:id, @name, $amount)", + sql.Named("bar", 42), sql.Named("baz", "quux"), + sql.Named("amount", 123), sql.Named("corge", "waldo"), + sql.Named("id", 2), sql.Named("name", "grault")) + if err != nil { + t.Fatal("Failed to insert record with named parameters:", err) + } + + rows, err := db.Query("select id, name, amount from foo") + if err != nil { + t.Fatal("Failed to select records:", err) + } + defer rows.Close() + + rows.Next() + + var id, amount int + var name string + rows.Scan(&id, &name, &amount) + if id != 2 || name != "grault" || amount != 123 { + t.Errorf("Expected %d, %q, %d for fetched result, but got %d, %q, %d:", 2, "grault", 123, id, name, amount) + } +} + +var customFunctionOnce sync.Once + +func TestSuite(t *testing.T) { + initializeTestDB(t) + defer freeTestDB() + + for _, test := range tests { + t.Run(test.Name, test.F) + } +} + +// DB provide context for the tests +type TestDB struct { + testing.TB + *sql.DB + once sync.Once + tempFilename string +} + +var tdb *TestDB + +func initializeTestDB(t testing.TB) { + tempFilename := TempFilename(t) + d, err := sql.Open(DriverName, tempFilename+"?mode=memory&cache=shared") + if err != nil { + os.Remove(tempFilename) + t.Fatal(err) + } + + tdb = &TestDB{t, d, sync.Once{}, tempFilename} +} + +func freeTestDB() { + err := tdb.DB.Close() + if err != nil { + panic(err) + } + err = os.Remove(tdb.tempFilename) + if err != nil { + panic(err) + } +} + +// the following tables will be created and dropped during the test +var testTables = []string{"foo"} + +var tests = []testing.InternalTest{ + {Name: "TestResult", F: testResult}, + {Name: "TestBlobs", F: testBlobs}, + {Name: "TestMultiBlobs", F: testMultiBlobs}, + {Name: "TestNullZeroLengthBlobs", F: testNullZeroLengthBlobs}, + {Name: "TestManyQueryRow", F: testManyQueryRow}, + {Name: "TestTxQuery", F: testTxQuery}, + {Name: "TestPreparedStmt", F: testPreparedStmt}, + {Name: "TestExecEmptyQuery", F: testExecEmptyQuery}, +} + +func (db *TestDB) mustExec(sql string, args ...any) sql.Result { + res, err := db.Exec(sql, args...) + if err != nil { + db.Fatalf("Error running %q: %v", sql, err) + } + return res +} + +func (db *TestDB) tearDown() { + for _, tbl := range testTables { + db.mustExec("drop table if exists " + tbl) + } +} + +// testResult is test for result +func testResult(t *testing.T) { + tdb.tearDown() + tdb.mustExec("create temporary table test (id integer primary key autoincrement, name varchar(10))") + + for i := 1; i < 3; i++ { + r := tdb.mustExec("insert into test (name) values (?)", fmt.Sprintf("row %d", i)) + n, err := r.RowsAffected() + if err != nil { + t.Fatal(err) + } + if n != 1 { + t.Errorf("got %v, want %v", n, 1) + } + n, err = r.LastInsertId() + if err != nil { + t.Fatal(err) + } + if n != int64(i) { + t.Errorf("got %v, want %v", n, i) + } + } + if _, err := tdb.Exec("error!"); err == nil { + t.Fatalf("expected error") + } +} + +// testBlobs is test for blobs +func testBlobs(t *testing.T) { + tdb.tearDown() + var blob = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} + tdb.mustExec("create table foo (id integer primary key, bar blob[16])") + tdb.mustExec("insert into foo (id, bar) values(?,?)", 0, blob) + + want := fmt.Sprintf("%x", blob) + + b := make([]byte, 16) + err := tdb.QueryRow("select bar from foo where id = ?", 0).Scan(&b) + got := fmt.Sprintf("%x", b) + if err != nil { + t.Errorf("[]byte scan: %v", err) + } else if got != want { + t.Errorf("for []byte, got %q; want %q", got, want) + } + + err = tdb.QueryRow("select bar from foo where id = ?", 0).Scan(&got) + want = string(blob) + if err != nil { + t.Errorf("string scan: %v", err) + } else if got != want { + t.Errorf("for string, got %q; want %q", got, want) + } +} + +func testMultiBlobs(t *testing.T) { + tdb.tearDown() + tdb.mustExec("create table foo (id integer primary key, bar blob[16])") + var blob0 = []byte{0, 1, 2, 3, 4, 5, 6, 7} + tdb.mustExec("insert into foo (id, bar) values(?,?)", 0, blob0) + var blob1 = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} + tdb.mustExec("insert into foo (id, bar) values(?,?)", 1, blob1) + + r, err := tdb.Query("select bar from foo order by id") + if err != nil { + t.Fatal(err) + } + defer r.Close() + if !r.Next() { + if r.Err() != nil { + t.Fatal(err) + } + t.Fatal("expected one rows") + } + + want0 := fmt.Sprintf("%x", blob0) + b0 := make([]byte, 8) + err = r.Scan(&b0) + if err != nil { + t.Fatal(err) + } + got0 := fmt.Sprintf("%x", b0) + + if !r.Next() { + if r.Err() != nil { + t.Fatal(err) + } + t.Fatal("expected one rows") + } + + want1 := fmt.Sprintf("%x", blob1) + b1 := make([]byte, 16) + err = r.Scan(&b1) + if err != nil { + t.Fatal(err) + } + got1 := fmt.Sprintf("%x", b1) + if got0 != want0 { + t.Errorf("for []byte, got %q; want %q", got0, want0) + } + if got1 != want1 { + t.Errorf("for []byte, got %q; want %q", got1, want1) + } +} + +// testBlobs tests that we distinguish between null and zero-length blobs +func testNullZeroLengthBlobs(t *testing.T) { + tdb.tearDown() + tdb.mustExec("create table foo (id integer primary key, bar blob[16])") + tdb.mustExec("insert into foo (id, bar) values(?,?)", 0, nil) + tdb.mustExec("insert into foo (id, bar) values(?,?)", 1, []byte{}) + + r0 := tdb.QueryRow("select bar from foo where id=0") + var b0 []byte + err := r0.Scan(&b0) + if err != nil { + t.Fatal(err) + } + if b0 != nil { + t.Errorf("for id=0, got %x; want nil", b0) + } + + r1 := tdb.QueryRow("select bar from foo where id=1") + var b1 []byte + err = r1.Scan(&b1) + if err != nil { + t.Fatal(err) + } + if b1 == nil { + t.Error("for id=1, got nil; want zero-length slice") + } else if len(b1) > 0 { + t.Errorf("for id=1, got %x; want zero-length slice", b1) + } +} + +func testManyQueryRow(t *testing.T) { + // FIXME: too slow + tdb.tearDown() + tdb.mustExec("create table foo (id integer primary key, name varchar(50))") + tdb.mustExec("insert into foo (id, name) values(?,?)", 1, "bob") + var name string + for i := 0; i < 10000; i++ { + err := tdb.QueryRow("select name from foo where id = ?", 1).Scan(&name) + if err != nil || name != "bob" { + t.Fatalf("on query %d: err=%v, name=%q", i, err, name) + } + } +} + +func testTxQuery(t *testing.T) { + tdb.tearDown() + tx, err := tdb.Begin() + if err != nil { + t.Fatal(err) + } + defer tx.Rollback() + + _, err = tx.Exec("create table foo (id integer primary key, name varchar(50))") + if err != nil { + t.Fatal(err) + } + + _, err = tx.Exec("insert into foo (id, name) values(?,?)", 1, "bob") + if err != nil { + t.Fatal(err) + } + + r, err := tx.Query("select name from foo where id = ?", 1) + if err != nil { + t.Fatal(err) + } + defer r.Close() + + if !r.Next() { + if r.Err() != nil { + t.Fatal(err) + } + t.Fatal("expected one rows") + } + + var name string + err = r.Scan(&name) + if err != nil { + t.Fatal(err) + } +} + +func testPreparedStmt(t *testing.T) { + tdb.tearDown() + tdb.mustExec("CREATE TABLE t (count INT)") + sel, err := tdb.Prepare("SELECT count FROM t ORDER BY count DESC") + if err != nil { + t.Fatalf("prepare 1: %v", err) + } + ins, err := tdb.Prepare("INSERT INTO t (count) VALUES (?)") + if err != nil { + t.Fatalf("prepare 2: %v", err) + } + + for n := 1; n <= 3; n++ { + if _, err := ins.Exec(n); err != nil { + t.Fatalf("insert(%d) = %v", n, err) + } + } + + const nRuns = 10 + var wg sync.WaitGroup + for i := 0; i < nRuns; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for j := 0; j < 10; j++ { + count := 0 + if err := sel.QueryRow().Scan(&count); err != nil && err != sql.ErrNoRows { + t.Errorf("Query: %v", err) + return + } + if _, err := ins.Exec(rand.Intn(100)); err != nil { + t.Errorf("Insert: %v", err) + return + } + } + }() + } + wg.Wait() +} + +// testEmptyQuery is test for validating the API in case of empty query +func testExecEmptyQuery(t *testing.T) { + tdb.tearDown() + res, err := tdb.Exec(" -- this is just a comment ") + if err != nil { + t.Fatalf("empty query err: %v", err) + } + + _, err = res.LastInsertId() + if err != nil { + t.Fatalf("LastInsertId returned an error: %v", err) + } + + _, err = res.RowsAffected() + if err != nil { + t.Fatalf("RowsAffected returned an error: %v", err) + } +} + + + +func MainTest() { + tests := []testing.InternalTest { + { "TestBackupStepByStep", TestBackupStepByStep }, + { "TestBackupAllRemainingPages", TestBackupAllRemainingPages }, + { "TestBackupError", TestBackupError }, + { "TestCallbackArgCast", TestCallbackArgCast }, + { "TestCallbackConverters", TestCallbackConverters }, + { "TestCallbackReturnAny", TestCallbackReturnAny }, + { "TestSimpleError", TestSimpleError }, + { "TestCorruptDbErrors", TestCorruptDbErrors }, + { "TestSqlLogicErrors", TestSqlLogicErrors }, + { "TestExtendedErrorCodes_ForeignKey", TestExtendedErrorCodes_ForeignKey }, + { "TestExtendedErrorCodes_NotNull", TestExtendedErrorCodes_NotNull }, + { "TestExtendedErrorCodes_Unique", TestExtendedErrorCodes_Unique }, + { "TestError_SystemErrno", TestError_SystemErrno }, + { "TestBeginTxCancel", TestBeginTxCancel }, + { "TestStmtReadonly", TestStmtReadonly }, + { "TestNamedParams", TestNamedParams }, + { "TestShortTimeout", TestShortTimeout }, + { "TestExecContextCancel", TestExecContextCancel }, + { "TestQueryRowContextCancel", TestQueryRowContextCancel }, + { "TestQueryRowContextCancelParallel", TestQueryRowContextCancelParallel }, + { "TestExecCancel", TestExecCancel }, + { "TestFileCopyTruncate", TestFileCopyTruncate }, + { "TestColumnTableName", TestColumnTableName }, + { "TestFTS5", TestFTS5 }, + { "TestSerializeDeserialize", TestSerializeDeserialize }, + { "TestOpenWithVFS", TestOpenWithVFS }, + { "TestOpenNoCreate", TestOpenNoCreate }, + { "TestReadonly", TestReadonly }, + { "TestDeferredForeignKey", TestDeferredForeignKey }, + { "TestClose", TestClose }, + { "TestInsert", TestInsert }, + { "TestUpsert", TestUpsert }, + { "TestUpdate", TestUpdate }, + { "TestDelete", TestDelete }, + { "TestBooleanRoundtrip", TestBooleanRoundtrip }, + { "TestTimestamp", TestTimestamp }, + { "TestBoolean", TestBoolean }, + { "TestFloat32", TestFloat32 }, + { "TestNull", TestNull }, + { "TestTransaction", TestTransaction }, + { "TestWAL", TestWAL }, + { "TestTimezoneConversion", TestTimezoneConversion }, + { "TestExecer", TestExecer }, + { "TestQueryer", TestQueryer }, + { "TestStress", TestStress }, + { "TestDateTimeLocal", TestDateTimeLocal }, + { "TestVersion", TestVersion }, + { "TestStringContainingZero", TestStringContainingZero }, + { "TestDateTimeNow", TestDateTimeNow }, + { "TestAggregatorRegistration", TestAggregatorRegistration }, + { "TestAggregatorRegistration_GenericReturn", TestAggregatorRegistration_GenericReturn }, + { "TestCollationRegistration", TestCollationRegistration }, + { "TestDeclTypes", TestDeclTypes }, + { "TestPinger", TestPinger }, + { "TestUpdateAndTransactionHooks", TestUpdateAndTransactionHooks }, + { "TestSetFileControlInt", TestSetFileControlInt }, + { "TestNonColumnString", TestNonColumnString }, + { "TestNilAndEmptyBytes", TestNilAndEmptyBytes }, + { "TestInsertNilByteSlice", TestInsertNilByteSlice }, + { "TestNamedParam", TestNamedParam }, + { "TestSuite", TestSuite }, // FIXME: too slow + } + + deps := testdeps.TestDeps{} + benchmarks := []testing.InternalBenchmark {} + fuzzTargets := []testing.InternalFuzzTarget{} + examples := []testing.InternalExample {} + m := testing.MainStart(deps, tests, benchmarks, fuzzTargets, examples) + os.Exit(m.Run()) +} diff --git a/tests/main.go b/tests/main.go index 7e0b774..f8a5b6f 100644 --- a/tests/main.go +++ b/tests/main.go @@ -1,7 +1,7 @@ package main -import "acudego" +import "golite" func main() { - acudego.MainTest() + golite.MainTest() } -- cgit v1.2.3