aboutsummaryrefslogtreecommitdiff
path: root/sqlite3.go
diff options
context:
space:
mode:
authorDavid Anderson <dave@natulte.net>2015-08-21 16:34:55 -0700
committerDavid Anderson <dave@natulte.net>2015-08-21 16:38:23 -0700
commit566f63a43a314f8dcd758dba8c40dc11edc27a5e (patch)
tree51068f3eb538f6eb5ace5c326626ca2c322f7c63 /sqlite3.go
parentMove argument converters to callback.go, and optimize return value handling. (diff)
downloadgolite-566f63a43a314f8dcd758dba8c40dc11edc27a5e.tar.gz
golite-566f63a43a314f8dcd758dba8c40dc11edc27a5e.tar.xz
Implement support for variadic functions.
Currently, the variadic part must all be the same type, because there's no "generic" arg converter.
Diffstat (limited to 'sqlite3.go')
-rw-r--r--sqlite3.go52
1 files changed, 42 insertions, 10 deletions
diff --git a/sqlite3.go b/sqlite3.go
index 174a3ee..8bb9826 100644
--- a/sqlite3.go
+++ b/sqlite3.go
@@ -165,9 +165,10 @@ type SQLiteRows struct {
}
type functionInfo struct {
- f reflect.Value
- argConverters []callbackArgConverter
- retConverter callbackRetConverter
+ f reflect.Value
+ argConverters []callbackArgConverter
+ variadicConverter callbackArgConverter
+ retConverter callbackRetConverter
}
func (fi *functionInfo) error(ctx *C.sqlite3_context, err error) {
@@ -178,7 +179,12 @@ func (fi *functionInfo) error(ctx *C.sqlite3_context, err error) {
func (fi *functionInfo) Call(ctx *C.sqlite3_context, argv []*C.sqlite3_value) {
var args []reflect.Value
- for i, arg := range argv {
+
+ if len(argv) < len(fi.argConverters) {
+ fi.error(ctx, fmt.Errorf("function requires at least %d arguments", len(fi.argConverters)))
+ }
+
+ for i, arg := range argv[:len(fi.argConverters)] {
v, err := fi.argConverters[i](arg)
if err != nil {
fi.error(ctx, err)
@@ -187,6 +193,17 @@ func (fi *functionInfo) Call(ctx *C.sqlite3_context, argv []*C.sqlite3_value) {
args = append(args, v)
}
+ if fi.variadicConverter != nil {
+ for _, arg := range argv[len(fi.argConverters):] {
+ v, err := fi.variadicConverter(arg)
+ if err != nil {
+ fi.error(ctx, err)
+ return
+ }
+ args = append(args, v)
+ }
+ }
+
ret := fi.f.Call(args)
if len(ret) == 2 && ret[1].Interface() != nil {
@@ -218,7 +235,8 @@ func (tx *SQLiteTx) Rollback() error {
// The function can accept arguments of any real numeric type
// (i.e. not complex), as well as []byte and string. It must return a
// value of one of those types, and optionally an error as a second
-// value.
+// value. Variadic functions are allowed, if the variadic argument is
+// one of the allowed types.
//
// If pure is true. SQLite will assume that the function's return
// value depends only on its inputs, and make more aggressive
@@ -230,9 +248,6 @@ func (c *SQLiteConn) RegisterFunc(name string, impl interface{}, pure bool) erro
if t.Kind() != reflect.Func {
return errors.New("Non-function passed to RegisterFunc")
}
- if t.IsVariadic() {
- return errors.New("Variadic SQLite functions are not supported")
- }
if t.NumOut() != 1 && t.NumOut() != 2 {
return errors.New("SQLite functions must return 1 or 2 values")
}
@@ -240,7 +255,12 @@ func (c *SQLiteConn) RegisterFunc(name string, impl interface{}, pure bool) erro
return errors.New("Second return value of SQLite function must be error")
}
- for i := 0; i < t.NumIn(); i++ {
+ numArgs := t.NumIn()
+ if t.IsVariadic() {
+ numArgs--
+ }
+
+ for i := 0; i < numArgs; i++ {
conv, err := callbackArg(t.In(i))
if err != nil {
return err
@@ -248,6 +268,18 @@ func (c *SQLiteConn) RegisterFunc(name string, impl interface{}, pure bool) erro
fi.argConverters = append(fi.argConverters, conv)
}
+ if t.IsVariadic() {
+ conv, err := callbackArg(t.In(numArgs).Elem())
+ if err != nil {
+ return err
+ }
+ fi.variadicConverter = conv
+ // Pass -1 to sqlite so that it allows any number of
+ // arguments. The call helper verifies that the minimum number
+ // of arguments is present for variadic functions.
+ numArgs = -1
+ }
+
conv, err := callbackRet(t.Out(0))
if err != nil {
return err
@@ -263,7 +295,7 @@ func (c *SQLiteConn) RegisterFunc(name string, impl interface{}, pure bool) erro
if pure {
opts |= C.SQLITE_DETERMINISTIC
}
- rv := C.sqlite3_create_function_v2(c.db, cname, C.int(t.NumIn()), C.int(opts), unsafe.Pointer(&fi), (*[0]byte)(unsafe.Pointer(C.callbackTrampoline)), nil, nil, nil)
+ rv := C.sqlite3_create_function_v2(c.db, cname, C.int(numArgs), C.int(opts), unsafe.Pointer(&fi), (*[0]byte)(unsafe.Pointer(C.callbackTrampoline)), nil, nil, nil)
if rv != C.SQLITE_OK {
return c.lastError()
}