aboutsummaryrefslogtreecommitdiff
path: root/sqlite3_test.go
diff options
context:
space:
mode:
authorPhil Eaton <phil@eatonphil.com>2022-05-29 21:06:43 -0400
committerGitHub <noreply@github.com>2022-05-29 21:06:43 -0400
commit3ccccfb4c9c683a80e2cce810ac652616579e51c (patch)
treef7c0fa7538e2267a9f3008646ed737678b0f0c1c /sqlite3_test.go
parentAdd error checking in simple example for tx.Commit (diff)
downloadgolite-3ccccfb4c9c683a80e2cce810ac652616579e51c.tar.gz
golite-3ccccfb4c9c683a80e2cce810ac652616579e51c.tar.xz
Support returning any from callbacks (#1046)
Support returning any from callbacks
Diffstat (limited to 'sqlite3_test.go')
-rw-r--r--sqlite3_test.go57
1 files changed, 57 insertions, 0 deletions
diff --git a/sqlite3_test.go b/sqlite3_test.go
index c86aba4..9ee87e7 100644
--- a/sqlite3_test.go
+++ b/sqlite3_test.go
@@ -1449,6 +1449,63 @@ func TestAggregatorRegistration(t *testing.T) {
}
}
+type mode struct {
+ counts map[interface{}]int
+ top interface{}
+ topCount int
+}
+
+func newMode() *mode {
+ return &mode{
+ counts: map[interface{}]int{},
+ }
+}
+
+func (m *mode) Step(x interface{}) {
+ m.counts[x]++
+ c := m.counts[x]
+ if c > m.topCount {
+ m.top = x
+ m.topCount = c
+ }
+}
+
+func (m *mode) Done() interface{} {
+ return m.top
+}
+
+func TestAggregatorRegistration_GenericReturn(t *testing.T) {
+ sql.Register("sqlite3_AggregatorRegistration_GenericReturn", &SQLiteDriver{
+ ConnectHook: func(conn *SQLiteConn) error {
+ return conn.RegisterAggregator("mode", newMode, true)
+ },
+ })
+ db, err := sql.Open("sqlite3_AggregatorRegistration_GenericReturn", ":memory:")
+ if err != nil {
+ t.Fatal("Failed to open database:", err)
+ }
+ defer db.Close()
+
+ _, err = db.Exec("create table foo (department integer, profits integer)")
+ if err != nil {
+ t.Fatal("Failed to create table:", err)
+ }
+ _, err = db.Exec("insert into foo values (1, 10), (1, 20), (1, 45), (2, 42), (2, 115), (2, 20)")
+ if err != nil {
+ t.Fatal("Failed to insert records:", err)
+ }
+
+ var mode int
+ err = db.QueryRow("select mode(profits) from foo").Scan(&mode)
+ if err != nil {
+ t.Fatal("MODE query error:", err)
+ }
+
+ if mode != 20 {
+ t.Fatal("Got incorrect mode. Wanted 20, got: ", mode)
+ }
+}
+
func rot13(r rune) rune {
switch {
case r >= 'A' && r <= 'Z':