From 7174000f779db8d51c898c082640cb2f475cd8d3 Mon Sep 17 00:00:00 2001 From: Kenneth Shaw Date: Sun, 5 Nov 2017 09:18:06 +0700 Subject: Move RegisterAggregator implementation The SQLiteConn.RegisterAggregator implementation was defined in sqlite3_trace.go file, which is guarded with a build constraint. This change simply moves RegisterAggregator to the main sqlite3.go file, and moves accompanying unit tests. The rationale for this move is that it was not possible for downstream using packages to use RegisterAggregator without also specifying (and notifying the user) the 'trace' build tag. --- sqlite3_test.go | 60 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) (limited to 'sqlite3_test.go') diff --git a/sqlite3_test.go b/sqlite3_test.go index 9d4b373..84ecb5a 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -1232,6 +1232,66 @@ func TestFunctionRegistration(t *testing.T) { } } +type sumAggregator int64 + +func (s *sumAggregator) Step(x int64) { + *s += sumAggregator(x) +} + +func (s *sumAggregator) Done() int64 { + return int64(*s) +} + +func TestAggregatorRegistration(t *testing.T) { + customSum := func() *sumAggregator { + var ret sumAggregator + return &ret + } + + sql.Register("sqlite3_AggregatorRegistration", &SQLiteDriver{ + ConnectHook: func(conn *SQLiteConn) error { + if err := conn.RegisterAggregator("customSum", customSum, true); err != nil { + return err + } + return nil + }, + }) + db, err := sql.Open("sqlite3_AggregatorRegistration", ":memory:") + if err != nil { + t.Fatal("Failed to open database:", err) + } + defer db.Close() + + _, err = db.Exec("create table foo (department integer, profits integer)") + if err != nil { + // trace feature is not implemented + t.Skip("Failed to create table:", err) + } + + _, err = db.Exec("insert into foo values (1, 10), (1, 20), (2, 42)") + if err != nil { + t.Fatal("Failed to insert records:", err) + } + + tests := []struct { + dept, sum int64 + }{ + {1, 30}, + {2, 42}, + } + + for _, test := range tests { + var ret int64 + err = db.QueryRow("select customSum(profits) from foo where department = $1 group by department", test.dept).Scan(&ret) + if err != nil { + t.Fatal("Query failed:", err) + } + if ret != test.sum { + t.Fatalf("Custom sum returned wrong value, got %d, want %d", ret, test.sum) + } + } +} + func rot13(r rune) rune { switch { case r >= 'A' && r <= 'Z': -- cgit v1.2.3