diff options
author | David Anderson <dave@natulte.net> | 2015-08-20 23:08:48 -0700 |
---|---|---|
committer | David Anderson <dave@natulte.net> | 2015-08-21 13:39:50 -0700 |
commit | cf8fa0af80e0d227c79ef2b4635e8d0d77432275 (patch) | |
tree | aa5d09e0d949847240ef50f3da2ad1b99f5cefe2 /sqlite3.go | |
parent | Merge pull request #228 from whiter4bbit/added_icu_support (diff) | |
download | golite-cf8fa0af80e0d227c79ef2b4635e8d0d77432275.tar.gz golite-cf8fa0af80e0d227c79ef2b4635e8d0d77432275.tar.xz |
Implement support for passing Go functions as custom functions to SQLite.
Fixes #226.
Diffstat (limited to 'sqlite3.go')
-rw-r--r-- | sqlite3.go | 191 |
1 files changed, 191 insertions, 0 deletions
@@ -66,6 +66,15 @@ _sqlite3_step(sqlite3_stmt* stmt, long long* rowid, long long* changes) return rv; } +void _sqlite3_result_text(sqlite3_context* ctx, const char* s) { + sqlite3_result_text(ctx, s, -1, &free); +} + +void _sqlite3_result_blob(sqlite3_context* ctx, const void* b, int l) { + sqlite3_result_blob(ctx, b, l, SQLITE_TRANSIENT); +} + +void callbackTrampoline(sqlite3_context*, int, sqlite3_value**); */ import "C" import ( @@ -75,6 +84,7 @@ import ( "fmt" "io" "net/url" + "reflect" "runtime" "strconv" "strings" @@ -120,6 +130,7 @@ type SQLiteConn struct { db *C.sqlite3 loc *time.Location txlock string + funcs []*functionInfo } // Tx struct. @@ -153,6 +164,89 @@ type SQLiteRows struct { cls bool } +type functionInfo struct { + f reflect.Value + argConverters []func(*C.sqlite3_value) (reflect.Value, error) +} + +func (fi *functionInfo) error(ctx *C.sqlite3_context, err error) { + cstr := C.CString(err.Error()) + defer C.free(unsafe.Pointer(cstr)) + C.sqlite3_result_error(ctx, cstr, -1) +} + +func (fi *functionInfo) Call(ctx *C.sqlite3_context, argv []*C.sqlite3_value) { + var args []reflect.Value + for i, arg := range argv { + v, err := fi.argConverters[i](arg) + if err != nil { + fi.error(ctx, err) + return + } + args = append(args, v) + } + + ret := fi.f.Call(args) + + if len(ret) == 2 && ret[1].Interface() != nil { + fi.error(ctx, ret[1].Interface().(error)) + return + } + + res := ret[0].Interface() + // Normalize ret to one of the types sqlite knows. + switch r := res.(type) { + case int64, float64, []byte, string: + // Already the right type + case bool: + if r { + res = int64(1) + } else { + res = int64(0) + } + case int: + res = int64(r) + case uint: + res = int64(r) + case uint8: + res = int64(r) + case uint16: + res = int64(r) + case uint32: + res = int64(r) + case uint64: + res = int64(r) + case int8: + res = int64(r) + case int16: + res = int64(r) + case int32: + res = int64(r) + case float32: + res = float64(r) + default: + fi.error(ctx, errors.New("cannot convert returned type to sqlite type")) + return + } + + switch r := res.(type) { + case int64: + C.sqlite3_result_int64(ctx, C.sqlite3_int64(r)) + case float64: + C.sqlite3_result_double(ctx, C.double(r)) + case []byte: + if len(r) == 0 { + C.sqlite3_result_null(ctx) + } else { + C._sqlite3_result_blob(ctx, unsafe.Pointer(&r[0]), C.int(len(r))) + } + case string: + C._sqlite3_result_text(ctx, C.CString(r)) + default: + panic("unreachable") + } +} + // Commit transaction. func (tx *SQLiteTx) Commit() error { _, err := tx.c.exec("COMMIT") @@ -165,6 +259,103 @@ func (tx *SQLiteTx) Rollback() error { return err } +// RegisterFunc makes a Go function available as a SQLite function. +// +// The function must accept only arguments of type int64, float64, +// []byte or string, and return one value of any numeric type except +// complex, bool, []byte or string. Optionally, an error can be +// provided as a second return value. +// +// If pure is true. SQLite will assume that the function's return +// value depends only on its inputs, and make more aggressive +// optimizations in its queries. +func (c *SQLiteConn) RegisterFunc(name string, impl interface{}, pure bool) error { + var fi functionInfo + fi.f = reflect.ValueOf(impl) + t := fi.f.Type() + if t.Kind() != reflect.Func { + return errors.New("Non-function passed to RegisterFunc") + } + if t.IsVariadic() { + return errors.New("Variadic SQLite functions are not supported") + } + if t.NumOut() != 1 && t.NumOut() != 2 { + return errors.New("SQLite functions must return 1 or 2 values") + } + if t.NumOut() == 2 && !t.Out(1).Implements(reflect.TypeOf((*error)(nil)).Elem()) { + return errors.New("Second return value of SQLite function must be error") + } + + for i := 0; i < t.NumIn(); i++ { + arg := t.In(i) + var conv func(*C.sqlite3_value) (reflect.Value, error) + switch arg.Kind() { + case reflect.Int64: + conv = func(v *C.sqlite3_value) (reflect.Value, error) { + if C.sqlite3_value_type(v) != C.SQLITE_INTEGER { + return reflect.Value{}, fmt.Errorf("Argument %d to %s must be an INTEGER", i+1, name) + } + return reflect.ValueOf(int64(C.sqlite3_value_int64(v))), nil + } + case reflect.Float64: + conv = func(v *C.sqlite3_value) (reflect.Value, error) { + if C.sqlite3_value_type(v) != C.SQLITE_FLOAT { + return reflect.Value{}, fmt.Errorf("Argument %d to %s must be a FLOAT", i+1, name) + } + return reflect.ValueOf(float64(C.sqlite3_value_double(v))), nil + } + case reflect.Slice: + if arg.Elem().Kind() != reflect.Uint8 { + return errors.New("The only supported slice type is []byte") + } + conv = func(v *C.sqlite3_value) (reflect.Value, error) { + switch C.sqlite3_value_type(v) { + case C.SQLITE_BLOB: + l := C.sqlite3_value_bytes(v) + p := C.sqlite3_value_blob(v) + return reflect.ValueOf(C.GoBytes(p, l)), nil + case C.SQLITE_TEXT: + l := C.sqlite3_value_bytes(v) + c := unsafe.Pointer(C.sqlite3_value_text(v)) + return reflect.ValueOf(C.GoBytes(c, l)), nil + default: + return reflect.Value{}, fmt.Errorf("Argument %d to %s must be BLOB or TEXT", i+1, name) + } + } + case reflect.String: + conv = func(v *C.sqlite3_value) (reflect.Value, error) { + switch C.sqlite3_value_type(v) { + case C.SQLITE_BLOB: + l := C.sqlite3_value_bytes(v) + p := (*C.char)(C.sqlite3_value_blob(v)) + return reflect.ValueOf(C.GoStringN(p, l)), nil + case C.SQLITE_TEXT: + c := (*C.char)(unsafe.Pointer(C.sqlite3_value_text(v))) + return reflect.ValueOf(C.GoString(c)), nil + default: + return reflect.Value{}, fmt.Errorf("Argument %d to %s must be BLOB or TEXT", i+1, name) + } + } + } + fi.argConverters = append(fi.argConverters, conv) + } + + // fi must outlast the database connection, or we'll have dangling pointers. + c.funcs = append(c.funcs, &fi) + + cname := C.CString(name) + defer C.free(unsafe.Pointer(cname)) + opts := C.SQLITE_UTF8 + if pure { + opts |= C.SQLITE_DETERMINISTIC + } + rv := C.sqlite3_create_function_v2(c.db, cname, C.int(t.NumIn()), C.int(opts), unsafe.Pointer(&fi), (*[0]byte)(unsafe.Pointer(C.callbackTrampoline)), nil, nil, nil) + if rv != C.SQLITE_OK { + return c.lastError() + } + return nil +} + // AutoCommit return which currently auto commit or not. func (c *SQLiteConn) AutoCommit() bool { return int(C.sqlite3_get_autocommit(c.db)) != 0 |