summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/q.go652
-rw-r--r--tests/functional/consumer-with-deadletter/q.go6
-rw-r--r--tests/functional/new-instance-takeover/q.go20
-rw-r--r--tests/functional/wait-after-publish/q.go5
-rw-r--r--tests/q.go453
-rw-r--r--tests/queries.sql107
6 files changed, 704 insertions, 539 deletions
diff --git a/src/q.go b/src/q.go
index 79570df..598cd75 100644
--- a/src/q.go
+++ b/src/q.go
@@ -1,6 +1,6 @@
package q
+
import (
- "context"
"database/sql"
"flag"
"fmt"
@@ -18,15 +18,21 @@ import (
const (
- defaultPrefix = "q"
- reaperSkipCount = 1000
- notOwnerErrorFmt = "%v owns %#v as %#v, not us (%v)"
- noLongerOwnerErrorFmt = "we (%v) no longer own %#v as %#v, but %v does"
- rollbackErrorFmt = "rollback error: %w; while executing: %w"
+ defaultPrefix = "q"
+ reaperSkipCount = 1000
+ notOwnerErrorFmt = "%v owns %#v as %#v, not us (%v)"
+ rollbackErrorFmt = "rollback error: %w; while executing: %w"
)
+type dbconfigT struct{
+ shared *sql.DB
+ dbpath string
+ prefix string
+ instanceID int
+}
+
type queryT struct{
write string
read string
@@ -126,7 +132,6 @@ type subscriptionsT struct {
}
type queueT struct{
- db *sql.DB
queries queriesT
subscriptions subscriptionsT
pinger pingerT[struct{}]
@@ -158,12 +163,55 @@ type IQueue interface{
-func closeNoop() error {
- return nil
+func serialized[A any, B any](callback func(...A) B) (func(...A) B, func()) {
+ in := make(chan []A)
+ out := make(chan B)
+
+ closed := false
+ var (
+ closeWg sync.WaitGroup
+ closeMutex sync.Mutex
+ )
+ closeWg.Add(1)
+
+ go func() {
+ for input := range in {
+ out <- callback(input...)
+ }
+ close(out)
+ closeWg.Done()
+ }()
+
+ fn := func(input ...A) B {
+ in <- input
+ return (<- out)
+ }
+
+ closeFn := func() {
+ closeMutex.Lock()
+ defer closeMutex.Unlock()
+ if closed {
+ return
+ }
+ close(in)
+ closed = true
+ closeWg.Wait()
+ }
+
+ return fn, closeFn
+}
+
+func execSerialized(query string, db *sql.DB) (func(...any) error, func()) {
+ return serialized(func(args ...any) error {
+ return inTx(db, func(tx *sql.Tx) error {
+ _, err := tx.Exec(query, args...)
+ return err
+ })
+ })
}
-func tryRollback(db *sql.DB, ctx context.Context, err error) error {
- _, rollbackErr := db.ExecContext(ctx, "ROLLBACK;")
+func tryRollback(tx *sql.Tx, err error) error {
+ rollbackErr := tx.Rollback()
if rollbackErr != nil {
return fmt.Errorf(
rollbackErrorFmt,
@@ -175,25 +223,20 @@ func tryRollback(db *sql.DB, ctx context.Context, err error) error {
return err
}
-// FIXME
-// See:
-// https://sqlite.org/forum/forumpost/2507664507
-func inTx(db *sql.DB, fn func(context.Context) error) error {
- ctx := context.Background()
-
- _, err := db.ExecContext(ctx, "BEGIN IMMEDIATE;")
+func inTx(db *sql.DB, fn func(*sql.Tx) error) error {
+ tx, err := db.Begin()
if err != nil {
return err
}
- err = fn(ctx)
+ err = fn(tx)
if err != nil {
- return tryRollback(db, ctx, err)
+ return tryRollback(tx, err)
}
- _, err = db.ExecContext(ctx, "COMMIT;")
+ err = tx.Commit()
if err != nil {
- return tryRollback(db, ctx, err)
+ return tryRollback(tx, err)
}
return nil
@@ -227,6 +270,7 @@ func createTablesSQL(prefix string) queryT {
consumer TEXT NOT NULL,
message_id INTEGER NOT NULL
REFERENCES "%s_messages"(id),
+ instance_id INTEGER NOT NULL,
UNIQUE (consumer, message_id)
) STRICT;
CREATE INDEX IF NOT EXISTS "%s_offsets_consumer"
@@ -238,6 +282,7 @@ func createTablesSQL(prefix string) queryT {
consumer TEXT NOT NULL,
message_id INTEGER NOT NULL
REFERENCES "%s_messages"(id),
+ instance_id INTEGER NOT NULL,
UNIQUE (consumer, message_id)
) STRICT;
CREATE INDEX IF NOT EXISTS "%s_deadletters_consumer"
@@ -258,6 +303,44 @@ func createTablesSQL(prefix string) queryT {
owner_id INTEGER NOT NULL,
UNIQUE (topic, consumer)
) STRICT;
+
+ CREATE TRIGGER IF NOT EXISTS "%s_check_instance_owns_topic"
+ BEFORE INSERT ON "%s_offsets"
+ WHEN NEW.instance_id != (
+ SELECT owner_id FROM "%s_owners"
+ WHERE topic = (
+ SELECT "%s_payloads".topic
+ FROM "%s_payloads"
+ JOIN "%s_messages" ON "%s_payloads".id =
+ "%s_messages".payload_id
+ WHERE "%s_messages".id = NEW.message_id
+ ) AND consumer = NEW.consumer
+ )
+ BEGIN
+ SELECT RAISE(
+ ABORT,
+ 'instance does not own topic/consumer combo'
+ );
+ END;
+
+ CREATE TRIGGER IF NOT EXISTS "%s_check_can_publish_deadletter"
+ BEFORE INSERT ON "%s_deadletters"
+ WHEN NEW.instance_id != (
+ SELECT owner_id FROM "%s_owners"
+ WHERE topic = (
+ SELECT "%s_payloads".topic
+ FROM "%s_payloads"
+ JOIN "%s_messages" ON "%s_payloads".id =
+ "%s_messages".payload_id
+ WHERE "%s_messages".id = NEW.message_id
+ ) AND consumer = NEW.consumer
+ )
+ BEGIN
+ SELECT RAISE(
+ ABORT,
+ 'Instance does not own topic/consumer combo'
+ );
+ END;
`
return queryT{
write: fmt.Sprintf(
@@ -284,6 +367,24 @@ func createTablesSQL(prefix string) queryT {
prefix,
prefix,
prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
),
}
}
@@ -291,8 +392,8 @@ func createTablesSQL(prefix string) queryT {
func createTables(db *sql.DB, prefix string) error {
q := createTablesSQL(prefix)
- return inTx(db, func(ctx context.Context) error {
- _, err := db.ExecContext(ctx, q.write)
+ return inTx(db, func(tx *sql.Tx) error {
+ _, err := tx.Exec(q.write)
return err
})
}
@@ -310,26 +411,25 @@ func takeSQL(prefix string) queryT {
}
func takeStmt(
- db *sql.DB,
- prefix string,
- instanceID int,
+ cfg dbconfigT,
) (func(string, string) error, func() error, error) {
- q := takeSQL(prefix)
+ q := takeSQL(cfg.prefix)
+
+ writeStmt, err := cfg.shared.Prepare(q.write)
+ if err != nil {
+ return nil, nil, err
+ }
fn := func(topic string, consumer string) error {
- return inTx(db, func(ctx context.Context) error {
- _, err := db.ExecContext(
- ctx,
- q.write,
- topic,
- consumer,
- instanceID,
- )
- return err
- })
+ _, err := writeStmt.Exec(
+ topic,
+ consumer,
+ cfg.instanceID,
+ )
+ return err
}
- return fn, closeNoop, nil
+ return fn, writeStmt.Close, nil
}
func publishSQL(prefix string) queryT {
@@ -337,7 +437,6 @@ func publishSQL(prefix string) queryT {
INSERT INTO "%s_payloads" (topic, payload)
VALUES (?, ?);
- -- FIXME: must be inside a trnsaction
INSERT INTO "%s_messages" (uuid, flow_id, payload_id)
VALUES (?, ?, last_insert_rowid());
`
@@ -352,17 +451,23 @@ func publishSQL(prefix string) queryT {
}
func publishStmt(
- db *sql.DB,
- prefix string,
- _ int,
+ cfg dbconfigT,
) (func(UnsentMessage, guuid.UUID) (messageT, error), func() error, error) {
- q := publishSQL(prefix)
+ q := publishSQL(cfg.prefix)
+
+ readStmt, err := cfg.shared.Prepare(q.read)
+ if err != nil {
+ return nil, nil, err
+ }
- readStmt, err := db.Prepare(q.read)
+ privateDB, err := sql.Open(golite.DriverName, cfg.dbpath)
if err != nil {
+ readStmt.Close()
return nil, nil, err
}
+ writeFn, writeFnClose := execSerialized(q.write, privateDB)
+
fn := func(
unsentMessage UnsentMessage,
messageID guuid.UUID,
@@ -376,17 +481,12 @@ func publishStmt(
message_id_bytes := messageID[:]
flow_id_bytes := unsentMessage.FlowID[:]
- err := inTx(db, func(ctx context.Context) error {
- _, err := db.ExecContext(
- ctx,
- q.write,
- unsentMessage.Topic,
- unsentMessage.Payload,
- message_id_bytes,
- flow_id_bytes,
- )
- return err
- })
+ err := writeFn(
+ unsentMessage.Topic,
+ unsentMessage.Payload,
+ message_id_bytes,
+ flow_id_bytes,
+ )
if err != nil {
return messageT{}, err
}
@@ -408,7 +508,12 @@ func publishStmt(
return message, nil
}
- return fn, readStmt.Close, nil
+ closeFn := func() error {
+ writeFnClose()
+ return g.SomeError(privateDB.Close(), readStmt.Close())
+ }
+
+ return fn, closeFn, nil
}
func findSQL(prefix string) queryT {
@@ -446,13 +551,11 @@ func findSQL(prefix string) queryT {
}
func findStmt(
- db *sql.DB,
- prefix string,
- _ int,
+ cfg dbconfigT,
) (func(string, guuid.UUID) (messageT, error), func() error, error) {
- q := findSQL(prefix)
+ q := findSQL(cfg.prefix)
- readStmt, err := db.Prepare(q.read)
+ readStmt, err := cfg.shared.Prepare(q.read)
if err != nil {
return nil, nil, err
}
@@ -493,6 +596,13 @@ func findStmt(
func nextSQL(prefix string) queryT {
const tmpl_read = `
SELECT
+ (
+ SELECT owner_id FROM "%s_owners"
+ WHERE
+ topic = ? AND
+ consumer = ?
+ LIMIT 1
+ ) AS owner_id,
"%s_messages".id,
"%s_messages".timestamp,
"%s_messages".uuid,
@@ -510,12 +620,6 @@ func nextSQL(prefix string) queryT {
ORDER BY "%s_messages".id ASC
LIMIT 1;
`
- const tmpl_owner = `
- SELECT owner_id FROM "%s_owners"
- WHERE
- topic = ? AND
- consumer = ?;
- `
return queryT{
read: fmt.Sprintf(
tmpl_read,
@@ -532,17 +636,20 @@ func nextSQL(prefix string) queryT {
prefix,
prefix,
prefix,
+ prefix,
),
- owner: fmt.Sprintf(tmpl_owner, prefix),
}
}
func nextStmt(
- db *sql.DB,
- prefix string,
- instanceID int,
+ cfg dbconfigT,
) (func(string, string) (messageT, error), func() error, error) {
- q := nextSQL(prefix)
+ q := nextSQL(cfg.prefix)
+
+ readStmt, err := cfg.shared.Prepare(q.read)
+ if err != nil {
+ return nil, nil, err
+ }
fn := func(topic string, consumer string) (messageT, error) {
message := messageT{
@@ -550,44 +657,34 @@ func nextStmt(
}
var (
- err error
ownerID int
timestr string
message_id_bytes []byte
flow_id_bytes []byte
)
- tx, err := db.Begin()
- if err != nil {
- return messageT{}, err
- }
- defer tx.Rollback()
- err = tx.QueryRow(q.owner, topic, consumer).Scan(&ownerID)
+ err = readStmt.QueryRow(topic, consumer, topic, consumer).Scan(
+ &ownerID,
+ &message.id,
+ &timestr,
+ &message_id_bytes,
+ &flow_id_bytes,
+ &message.payload,
+ )
if err != nil {
return messageT{}, err
}
- if ownerID != instanceID {
+ if ownerID != cfg.instanceID {
err := fmt.Errorf(
notOwnerErrorFmt,
ownerID,
topic,
consumer,
- instanceID,
+ cfg.instanceID,
)
return messageT{}, err
}
-
- err = tx.QueryRow(q.read, topic, consumer).Scan(
- &message.id,
- &timestr,
- &message_id_bytes,
- &flow_id_bytes,
- &message.payload,
- )
- if err != nil {
- return messageT{}, err
- }
message.uuid = guuid.UUID(message_id_bytes)
message.flowID = guuid.UUID(flow_id_bytes)
@@ -599,7 +696,7 @@ func nextStmt(
return message, nil
}
- return fn, closeNoop, nil
+ return fn, readStmt.Close, nil
}
func messageEach(rows *sql.Rows, callback func(messageT) error) error {
@@ -623,19 +720,22 @@ func messageEach(rows *sql.Rows, callback func(messageT) error) error {
&message.payload,
)
if err != nil {
- return g.WrapErrors(rows.Close(), err)
+ rows.Close()
+ return err
}
message.uuid = guuid.UUID(message_id_bytes)
message.flowID = guuid.UUID(flow_id_bytes)
message.timestamp, err = time.Parse(time.RFC3339Nano, timestr)
if err != nil {
- return g.WrapErrors(rows.Close(), err)
+ rows.Close()
+ return err
}
err = callback(message)
if err != nil {
- return g.WrapErrors(rows.Close(), err)
+ rows.Close()
+ return err
}
}
@@ -691,20 +791,19 @@ func pendingSQL(prefix string) queryT {
}
func pendingStmt(
- db *sql.DB,
- prefix string,
- instanceID int,
+ cfg dbconfigT,
) (func(string, string) (*sql.Rows, error), func() error, error) {
- q := pendingSQL(prefix)
+ q := pendingSQL(cfg.prefix)
- readStmt, err := db.Prepare(q.read)
+ readStmt, err := cfg.shared.Prepare(q.read)
if err != nil {
return nil, nil, err
}
- ownerStmt, err := db.Prepare(q.owner)
+ ownerStmt, err := cfg.shared.Prepare(q.owner)
if err != nil {
- return nil, nil, g.WrapErrors(readStmt.Close(), err)
+ readStmt.Close()
+ return nil, nil, err
}
fn := func(topic string, consumer string) (*sql.Rows, error) {
@@ -716,7 +815,7 @@ func pendingStmt(
// best effort check, the final one is done during
// commit within a transaction
- if ownerID != instanceID {
+ if ownerID != cfg.instanceID {
return nil, nil
}
@@ -732,130 +831,67 @@ func pendingStmt(
func commitSQL(prefix string) queryT {
const tmpl_write = `
- INSERT INTO "%s_offsets" (consumer, message_id)
- VALUES (?, (SELECT id FROM "%s_messages" WHERE uuid = ?));
- `
- const tmpl_read = `
- SELECT "%s_payloads".topic from "%s_payloads"
- JOIN "%s_messages" ON
- "%s_payloads".id = "%s_messages".payload_id
- WHERE "%s_messages".uuid = ?;
- `
- const tmpl_owner = `
- SELECT owner_id FROM "%s_owners"
- WHERE
- topic = ? AND
- consumer = ?;
+ INSERT INTO "%s_offsets" (consumer, message_id, instance_id)
+ VALUES (?, (SELECT id FROM "%s_messages" WHERE uuid = ?), ?);
`
return queryT{
write: fmt.Sprintf(tmpl_write, prefix, prefix),
- read: fmt.Sprintf(
- tmpl_read,
- prefix,
- prefix,
- prefix,
- prefix,
- prefix,
- prefix,
- ),
- owner: fmt.Sprintf(tmpl_owner, prefix),
}
}
func commitStmt(
- db *sql.DB,
- prefix string,
- instanceID int,
+ cfg dbconfigT,
) (func(string, guuid.UUID) error, func() error, error) {
- q := commitSQL(prefix)
+ q := commitSQL(cfg.prefix)
+
+ writeStmt, err := cfg.shared.Prepare(q.write)
+ if err != nil {
+ return nil, nil, err
+ }
fn := func(consumer string, messageID guuid.UUID) error {
message_id_bytes := messageID[:]
- return inTx(db, func(ctx context.Context) error {
- var topic string
- err := db.QueryRowContext(
- ctx,
- q.read,
- message_id_bytes,
- ).Scan(&topic)
- if err != nil {
- return err
- }
-
- var ownerID int
- err = db.QueryRowContext(
- ctx,
- q.owner,
- topic,
- consumer,
- ).Scan(&ownerID)
- if err != nil {
- return err
- }
-
- if ownerID != instanceID {
- return fmt.Errorf(
- noLongerOwnerErrorFmt,
- instanceID,
- topic,
- consumer,
- ownerID,
- )
- }
-
- _, err = db.ExecContext(ctx, q.write, consumer, message_id_bytes)
- return err
- })
+ _, err = writeStmt.Exec(
+ consumer,
+ message_id_bytes,
+ cfg.instanceID,
+ )
+ return err
}
- return fn, closeNoop, nil
+ return fn, writeStmt.Close, nil
}
func toDeadSQL(prefix string) queryT {
const tmpl_write = `
- INSERT INTO "%s_offsets" ( consumer, message_id)
- VALUES ( ?, (SELECT id FROM "%s_messages" WHERE uuid = ?));
+ INSERT INTO "%s_offsets"
+ ( consumer, message_id, instance_id)
+ VALUES ( ?, (SELECT id FROM "%s_messages" WHERE uuid = ?), ?);
- INSERT INTO "%s_deadletters" (uuid, consumer, message_id)
- VALUES (?, ?, (SELECT id FROM "%s_messages" WHERE uuid = ?));
- `
- const tmpl_read = `
- SELECT "%s_payloads".topic FROM "%s_payloads"
- JOIN "%s_messages" ON
- "%s_payloads".id = "%s_messages".payload_id
- WHERE "%s_messages".uuid = ?;
- `
- const tmpl_owner = `
- SELECT owner_id FROM "%s_owners"
- WHERE
- topic = ? AND
- consumer = ?;
+ INSERT INTO "%s_deadletters"
+ (uuid, consumer, message_id, instance_id)
+ VALUES (?, ?, (SELECT id FROM "%s_messages" WHERE uuid = ?), ?);
`
return queryT{
write: fmt.Sprintf(tmpl_write, prefix, prefix, prefix, prefix),
- read: fmt.Sprintf(
- tmpl_read,
- prefix,
- prefix,
- prefix,
- prefix,
- prefix,
- prefix,
- ),
- owner: fmt.Sprintf(tmpl_owner, prefix),
}
}
func toDeadStmt(
- db *sql.DB,
- prefix string,
- instanceID int,
+ cfg dbconfigT,
) (
func(string, guuid.UUID, guuid.UUID) error,
func() error,
error,
) {
- q := toDeadSQL(prefix)
+ q := toDeadSQL(cfg.prefix)
+
+ privateDB, err := sql.Open(golite.DriverName, cfg.dbpath)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ writeFn, writeFnClose := execSerialized(q.write, privateDB)
fn := func(
consumer string,
@@ -864,52 +900,24 @@ func toDeadStmt(
) error {
message_id_bytes := messageID[:]
deadletter_id_bytes := deadletterID[:]
- return inTx(db, func(ctx context.Context) error {
- var topic string
- err := db.QueryRowContext(
- ctx,
- q.read,
- message_id_bytes,
- ).Scan(&topic)
- if err != nil {
- return err
- }
-
- var ownerID int
- err = db.QueryRowContext(
- ctx,
- q.owner,
- topic,
- consumer,
- ).Scan(&ownerID)
- if err != nil {
- return err
- }
-
- if ownerID != instanceID {
- return fmt.Errorf(
- noLongerOwnerErrorFmt,
- instanceID,
- topic,
- consumer,
- ownerID,
- )
- }
+ return writeFn(
+ consumer,
+ message_id_bytes,
+ cfg.instanceID,
+ deadletter_id_bytes,
+ consumer,
+ message_id_bytes,
+ cfg.instanceID,
+ )
+ }
- _, err = db.ExecContext(
- ctx,
- q.write,
- consumer,
- message_id_bytes,
- deadletter_id_bytes,
- consumer,
- message_id_bytes,
- )
- return err
- })
+ closeFn := func() error {
+ writeFnClose()
+ return privateDB.Close()
}
- return fn, closeNoop, nil
+
+ return fn, closeFn, nil
}
func replaySQL(prefix string) queryT {
@@ -973,33 +981,34 @@ func replaySQL(prefix string) queryT {
}
func replayStmt(
- db *sql.DB,
- prefix string,
- _ int,
+ cfg dbconfigT,
) (func(guuid.UUID, guuid.UUID) (messageT, error), func() error, error) {
- q := replaySQL(prefix)
+ q := replaySQL(cfg.prefix)
+
+ readStmt, err := cfg.shared.Prepare(q.read)
+ if err != nil {
+ return nil, nil, err
+ }
- readStmt, err := db.Prepare(q.read)
+ privateDB, err := sql.Open(golite.DriverName, cfg.dbpath)
if err != nil {
+ readStmt.Close()
return nil, nil, err
}
+ writeFn, writeFnClose := execSerialized(q.write, privateDB)
+
fn := func(
deadletterID guuid.UUID,
messageID guuid.UUID,
) (messageT, error) {
deadletter_id_bytes := deadletterID[:]
message_id_bytes := messageID[:]
- err := inTx(db, func(ctx context.Context) error {
- _, err := db.ExecContext(
- ctx,
- q.write,
- message_id_bytes,
- deadletter_id_bytes,
- deadletter_id_bytes,
- )
- return err
- })
+ err := writeFn(
+ message_id_bytes,
+ deadletter_id_bytes,
+ deadletter_id_bytes,
+ )
if err != nil {
return messageT{}, err
}
@@ -1032,7 +1041,12 @@ func replayStmt(
return message, nil
}
- return fn, readStmt.Close, nil
+ closeFn := func() error {
+ writeFnClose()
+ return g.SomeError(privateDB.Close(), readStmt.Close())
+ }
+
+ return fn, closeFn, nil
}
func oneDeadSQL(prefix string) queryT {
@@ -1085,13 +1099,11 @@ func oneDeadSQL(prefix string) queryT {
}
func oneDeadStmt(
- db *sql.DB,
- prefix string,
- _ int,
+ cfg dbconfigT,
) (func(string, string) (deadletterT, error), func() error, error) {
- q := oneDeadSQL(prefix)
+ q := oneDeadSQL(cfg.prefix)
- readStmt, err := db.Prepare(q.read)
+ readStmt, err := cfg.shared.Prepare(q.read)
if err != nil {
return nil, nil, err
}
@@ -1157,7 +1169,8 @@ func deadletterEach(
&message.payload,
)
if err != nil {
- return g.WrapErrors(rows.Close(), err)
+ rows.Close()
+ return err
}
deadletter.uuid = guuid.UUID(deadletter_id_bytes)
@@ -1170,7 +1183,8 @@ func deadletterEach(
messageTimestr,
)
if err != nil {
- return g.WrapErrors(rows.Close(), err)
+ rows.Close()
+ return err
}
deadletter.timestamp, err = time.Parse(
@@ -1178,12 +1192,14 @@ func deadletterEach(
deadletterTimestr,
)
if err != nil {
- return g.WrapErrors(rows.Close(), err)
+ rows.Close()
+ return err
}
err = callback(deadletter, message)
if err != nil {
- return g.WrapErrors(rows.Close(), err)
+ rows.Close()
+ return err
}
}
@@ -1251,13 +1267,11 @@ func allDeadSQL(prefix string) queryT {
}
func allDeadStmt(
- db *sql.DB,
- prefix string,
- _ int,
+ cfg dbconfigT,
) (func(string, string) (*sql.Rows, error), func() error, error) {
- q := allDeadSQL(prefix)
+ q := allDeadSQL(cfg.prefix)
- readStmt, err := db.Prepare(q.read)
+ readStmt, err := cfg.shared.Prepare(q.read)
if err != nil {
return nil, nil, err
}
@@ -1292,13 +1306,11 @@ func sizeSQL(prefix string) queryT {
func sizeStmt(
- db *sql.DB,
- prefix string,
- _ int,
+ cfg dbconfigT,
) (func(string) (int, error), func() error, error) {
- q := sizeSQL(prefix)
+ q := sizeSQL(cfg.prefix)
- readStmt, err := db.Prepare(q.read)
+ readStmt, err := cfg.shared.Prepare(q.read)
if err != nil {
return nil, nil, err
}
@@ -1346,13 +1358,11 @@ func countSQL(prefix string) queryT {
}
func countStmt(
- db *sql.DB,
- prefix string,
- _ int,
+ cfg dbconfigT,
) (func(string, string) (int, error), func() error, error) {
- q := countSQL(prefix)
+ q := countSQL(cfg.prefix)
- readStmt, err := db.Prepare(q.read)
+ readStmt, err := cfg.shared.Prepare(q.read)
if err != nil {
return nil, nil, err
}
@@ -1399,13 +1409,11 @@ func hasDataSQL(prefix string) queryT {
}
func hasDataStmt(
- db *sql.DB,
- prefix string,
- _ int,
+ cfg dbconfigT,
) (func(string, string) (bool, error), func() error, error) {
- q := hasDataSQL(prefix)
+ q := hasDataSQL(cfg.prefix)
- readStmt, err := db.Prepare(q.read)
+ readStmt, err := cfg.shared.Prepare(q.read)
if err != nil {
return nil, nil, err
}
@@ -1428,27 +1436,44 @@ func hasDataStmt(
}
func initDB(
- db *sql.DB,
+ dbpath string,
prefix string,
notifyFn func(messageT),
instanceID int,
) (queriesT, error) {
- createTablesErr := createTables(db, prefix)
- take, takeClose, takeErr := takeStmt(db, prefix, instanceID)
- publish, publishClose, publishErr := publishStmt(db, prefix, instanceID)
- find, findClose, findErr := findStmt(db, prefix, instanceID)
- next, nextClose, nextErr := nextStmt(db, prefix, instanceID)
- pending, pendingClose, pendingErr := pendingStmt(db, prefix, instanceID)
- commit, commitClose, commitErr := commitStmt(db, prefix, instanceID)
- toDead, toDeadClose, toDeadErr := toDeadStmt(db, prefix, instanceID)
- replay, replayClose, replayErr := replayStmt(db, prefix, instanceID)
- oneDead, oneDeadClose, oneDeadErr := oneDeadStmt(db, prefix, instanceID)
- allDead, allDeadClose, allDeadErr := allDeadStmt(db, prefix, instanceID)
- size, sizeClose, sizeErr := sizeStmt(db, prefix, instanceID)
- count, countClose, countErr := countStmt(db, prefix, instanceID)
- hasData, hasDataClose, hasDataErr := hasDataStmt(db, prefix, instanceID)
-
- err := g.SomeError(
+ err := g.ValidateSQLTablePrefix(prefix)
+ if err != nil {
+ return queriesT{}, err
+ }
+
+ shared, err := sql.Open(golite.DriverName, dbpath)
+ if err != nil {
+ return queriesT{}, err
+ }
+
+ cfg := dbconfigT{
+ shared: shared,
+ dbpath: dbpath,
+ prefix: prefix,
+ instanceID: instanceID,
+ }
+
+ createTablesErr := createTables(shared, prefix)
+ take, takeClose, takeErr := takeStmt(cfg)
+ publish, publishClose, publishErr := publishStmt(cfg)
+ find, findClose, findErr := findStmt(cfg)
+ next, nextClose, nextErr := nextStmt(cfg)
+ pending, pendingClose, pendingErr := pendingStmt(cfg)
+ commit, commitClose, commitErr := commitStmt(cfg)
+ toDead, toDeadClose, toDeadErr := toDeadStmt(cfg)
+ replay, replayClose, replayErr := replayStmt(cfg)
+ oneDead, oneDeadClose, oneDeadErr := oneDeadStmt(cfg)
+ allDead, allDeadClose, allDeadErr := allDeadStmt(cfg)
+ size, sizeClose, sizeErr := sizeStmt(cfg)
+ count, countClose, countErr := countStmt(cfg)
+ hasData, hasDataClose, hasDataErr := hasDataStmt(cfg)
+
+ err = g.SomeError(
createTablesErr,
takeErr,
publishErr,
@@ -1468,7 +1493,7 @@ func initDB(
return queriesT{}, err
}
- close := func() error {
+ closeFn := func() error {
return g.SomeFnError(
takeClose,
publishClose,
@@ -1483,6 +1508,7 @@ func initDB(
sizeClose,
countClose,
hasDataClose,
+ shared.Close,
)
}
@@ -1614,7 +1640,7 @@ func initDB(
close: func() error {
connMutex.Lock()
defer connMutex.Unlock()
- return close()
+ return closeFn()
},
}, nil
}
@@ -1826,20 +1852,10 @@ func runReaper(
}
func NewWithPrefix(databasePath string, prefix string) (IQueue, error) {
- err := g.ValidateSQLTablePrefix(prefix)
- if err != nil {
- return queueT{}, err
- }
-
- db, err := sql.Open(golite.DriverName, databasePath)
- if err != nil {
- return queueT{}, err
- }
-
subscriptions := makeSubscriptionsFuncs()
pinger := newPinger[struct{}]()
notifyFn := makeNotifyFn(subscriptions.read, pinger)
- queries, err := initDB(db, prefix, notifyFn, os.Getpid())
+ queries, err := initDB(databasePath, prefix, notifyFn, os.Getpid())
if err != nil {
return queueT{}, err
}
@@ -1847,7 +1863,6 @@ func NewWithPrefix(databasePath string, prefix string) (IQueue, error) {
go runReaper(pinger.onPing, subscriptions.read, subscriptions.write)
return queueT{
- db: db,
queries: queries,
subscriptions: subscriptions,
pinger: pinger,
@@ -2120,7 +2135,6 @@ func cleanSubscriptions(set subscriptionsSetM) error {
func (queue queueT) Close() error {
queue.pinger.close()
return g.WrapErrors(
- queue.db.Close(),
queue.subscriptions.write(cleanSubscriptions),
queue.queries.close(),
)
diff --git a/tests/functional/consumer-with-deadletter/q.go b/tests/functional/consumer-with-deadletter/q.go
index 25391c5..a79ad5b 100644
--- a/tests/functional/consumer-with-deadletter/q.go
+++ b/tests/functional/consumer-with-deadletter/q.go
@@ -2,7 +2,6 @@ package q
import (
"errors"
- "os"
"runtime"
"guuid"
@@ -44,15 +43,10 @@ func MainTest() {
g.TAssertEqualS(ok, true, "can't get filename")
databasePath := file + ".db"
- os.Remove(databasePath)
- os.Remove(databasePath + "-shm")
- os.Remove(databasePath + "-wal")
-
queue, err := New(databasePath)
g.TErrorIf(err)
defer queue.Close()
-
pub := func(payload []byte, flowID guuid.UUID) {
unsent := UnsentMessage{
Topic: topicX,
diff --git a/tests/functional/new-instance-takeover/q.go b/tests/functional/new-instance-takeover/q.go
index 6e04e5f..76ed59f 100644
--- a/tests/functional/new-instance-takeover/q.go
+++ b/tests/functional/new-instance-takeover/q.go
@@ -1,7 +1,6 @@
package q
import (
- "fmt"
"runtime"
"os"
@@ -33,16 +32,16 @@ func handlerFn(publish func(guuid.UUID)) func(Message) error {
}
func startInstance(
- databasePath string,
+ dbpath string,
instanceID int,
name string,
) (IQueue, error) {
- iqueue, err := New(databasePath)
+ iqueue, err := New(dbpath)
g.TErrorIf(err)
queue := iqueue.(queueT)
notifyFn := makeNotifyFn(queue.subscriptions.read, queue.pinger)
- queries, err := initDB(queue.db, defaultPrefix, notifyFn, instanceID)
+ queries, err := initDB(dbpath, defaultPrefix, notifyFn, instanceID)
g.TErrorIf(err)
err = queue.queries.close()
@@ -68,20 +67,12 @@ func startInstance(
func MainTest() {
- // https://sqlite.org/forum/forumpost/2507664507
g.Init()
_, file, _, ok := runtime.Caller(0)
g.TAssertEqualS(ok, true, "can't get filename")
dbpath := file + ".db"
- dbpath = "/mnt/dois/andreh/t.db"
- os.Remove(dbpath)
- os.Remove(dbpath + "-shm")
- os.Remove(dbpath + "-wal")
-
- // FIXME
- return
instanceID1 := os.Getpid()
instanceID2 := instanceID1 + 1
@@ -89,10 +80,6 @@ func MainTest() {
flowID2 := guuid.New()
g.Testing("new instances take ownership of topic+name combo", func() {
- if false {
- fmt.Fprintf(os.Stderr, "(PID %d + 1) ", instanceID1)
- }
-
q1, err := startInstance(dbpath, instanceID1, "first")
g.TErrorIf(err)
defer q1.Close()
@@ -103,7 +90,6 @@ func MainTest() {
<- q1.WaitFor("individual-first", flowID1, "w").Channel
<- q1.WaitFor( "shared-first", flowID1, "w").Channel
- // println("waited 1")
q2, err := startInstance(dbpath, instanceID2, "second")
g.TErrorIf(err)
diff --git a/tests/functional/wait-after-publish/q.go b/tests/functional/wait-after-publish/q.go
index 15a532d..b70d27e 100644
--- a/tests/functional/wait-after-publish/q.go
+++ b/tests/functional/wait-after-publish/q.go
@@ -1,7 +1,6 @@
package q
import (
- "os"
"runtime"
"guuid"
@@ -19,10 +18,6 @@ func MainTest() {
g.TAssertEqualS(ok, true, "can't get filename")
databasePath := file + ".db"
- os.Remove(databasePath)
- os.Remove(databasePath + "-shm")
- os.Remove(databasePath + "-wal")
-
queue, err := New(databasePath)
g.TErrorIf(err)
defer queue.Close()
diff --git a/tests/q.go b/tests/q.go
index 844e722..565ca25 100644
--- a/tests/q.go
+++ b/tests/q.go
@@ -21,6 +21,10 @@ import (
+var instanceID = os.Getpid()
+
+
+
func test_defaultPrefix() {
g.TestStart("defaultPrefix")
@@ -29,13 +33,46 @@ func test_defaultPrefix() {
})
}
-func test_tryRollback() {
+func test_serialized() {
// FIXME
}
-func test_inTx() {
- /*
+func test_execSerialized() {
// FIXME
+}
+
+func test_tryRollback() {
+ g.TestStart("tryRollback()")
+
+ myErr := errors.New("bottom error")
+
+ db, err := sql.Open(golite.DriverName, golite.InMemory)
+ g.TErrorIf(err)
+ defer db.Close()
+
+
+ g.Testing("the error is propagated if rollback doesn't fail", func() {
+ tx, err := db.Begin()
+ g.TErrorIf(err)
+
+ err = tryRollback(tx, myErr)
+ g.TAssertEqual(err, myErr)
+ })
+
+ g.Testing("a wrapped error when rollback fails", func() {
+ tx, err := db.Begin()
+ g.TErrorIf(err)
+
+ err = tx.Commit()
+ g.TErrorIf(err)
+
+ err = tryRollback(tx, myErr)
+ g.TAssertEqual(reflect.DeepEqual(err, myErr), false)
+ g.TAssertEqual(errors.Is(err, myErr), true)
+ })
+}
+
+func test_inTx() {
g.TestStart("inTx()")
db, err := sql.Open(golite.DriverName, golite.InMemory)
@@ -44,11 +81,11 @@ func test_inTx() {
g.Testing("when fn() errors, we propagate it", func() {
- myError := errors.New("to be propagated")
+ myErr := errors.New("to be propagated")
err := inTx(db, func(tx *sql.Tx) error {
- return myError
+ return myErr
})
- g.TAssertEqual(err, myError)
+ g.TAssertEqual(err, myErr)
})
g.Testing("on nil error we get nil", func() {
@@ -57,13 +94,17 @@ func test_inTx() {
})
g.TErrorIf(err)
})
- */
}
func test_createTables() {
g.TestStart("createTables()")
- db, err := sql.Open(golite.DriverName, golite.InMemory)
+ const (
+ dbpath = golite.InMemory
+ prefix = defaultPrefix
+ )
+
+ db, err := sql.Open(golite.DriverName, dbpath)
g.TErrorIf(err)
defer db.Close()
@@ -72,12 +113,12 @@ func test_createTables() {
const tmpl_read = `
SELECT id FROM "%s_messages" LIMIT 1;
`
- qRead := fmt.Sprintf(tmpl_read, defaultPrefix)
+ qRead := fmt.Sprintf(tmpl_read, prefix)
_, err := db.Exec(qRead)
g.TErrorNil(err)
- err = createTables(db, defaultPrefix)
+ err = createTables(db, prefix)
g.TErrorIf(err)
_, err = db.Exec(qRead)
@@ -86,9 +127,9 @@ func test_createTables() {
g.Testing("we can do it multiple times", func() {
g.TErrorIf(g.SomeError(
- createTables(db, defaultPrefix),
- createTables(db, defaultPrefix),
- createTables(db, defaultPrefix),
+ createTables(db, prefix),
+ createTables(db, prefix),
+ createTables(db, prefix),
))
})
}
@@ -99,15 +140,21 @@ func test_takeStmt() {
const (
topic = "take() topic"
consumer = "take() consumer"
+ dbpath = golite.InMemory
prefix = defaultPrefix
)
- db, err := sql.Open(golite.DriverName, golite.InMemory)
+ db, err := sql.Open(golite.DriverName, dbpath)
g.TErrorIf(err)
g.TErrorIf(createTables(db, prefix))
- instanceID := os.Getpid()
- take, takeClose, takeErr := takeStmt(db, prefix, instanceID)
+ cfg := dbconfigT{
+ shared: db,
+ dbpath: dbpath,
+ prefix: prefix,
+ instanceID: instanceID,
+ }
+ take, takeClose, takeErr := takeStmt(cfg)
g.TErrorIf(takeErr)
defer g.SomeFnError(
takeClose,
@@ -137,9 +184,10 @@ func test_takeStmt() {
})
g.Testing("if there is already an owner, we overtake it", func() {
- otherID := instanceID + 1
+ otherCfg := cfg
+ otherCfg.instanceID = instanceID + 1
- take, takeClose, takeErr := takeStmt(db, prefix, otherID)
+ take, takeClose, takeErr := takeStmt(otherCfg)
g.TErrorIf(takeErr)
defer takeClose()
@@ -153,7 +201,7 @@ func test_takeStmt() {
err = db.QueryRow(sqlOwner, topic, consumer).Scan(&ownerID)
g.TErrorIf(err)
- g.TAssertEqual(ownerID, otherID)
+ g.TAssertEqual(ownerID, otherCfg.instanceID)
})
g.Testing("no error if closed more than once", func() {
g.TErrorIf(g.SomeError(
@@ -170,6 +218,7 @@ func test_publishStmt() {
const (
topic = "publish() topic"
payloadStr = "publish() payload"
+ dbpath = golite.InMemory
prefix = defaultPrefix
)
var (
@@ -182,12 +231,17 @@ func test_publishStmt() {
}
)
- db, err := sql.Open(golite.DriverName, golite.InMemory)
+ db, err := sql.Open(golite.DriverName, dbpath)
g.TErrorIf(err)
g.TErrorIf(createTables(db, prefix))
- instanceID := os.Getpid()
- publish, publishClose, publishErr := publishStmt(db, prefix, instanceID)
+ cfg := dbconfigT{
+ shared: db,
+ dbpath: dbpath,
+ prefix: prefix,
+ instanceID: instanceID,
+ }
+ publish, publishClose, publishErr := publishStmt(cfg)
g.TErrorIf(publishErr)
defer g.SomeFnError(
publishClose,
@@ -255,6 +309,7 @@ func test_findStmt() {
const (
topic = "find() topic"
payloadStr = "find() payload"
+ dbpath = golite.InMemory
prefix = defaultPrefix
)
var (
@@ -267,13 +322,18 @@ func test_findStmt() {
}
)
- db, err := sql.Open(golite.DriverName, golite.InMemory)
+ db, err := sql.Open(golite.DriverName, dbpath)
g.TErrorIf(err)
g.TErrorIf(createTables(db, prefix))
- instanceID := os.Getpid()
- publish, publishClose, publishErr := publishStmt(db, prefix, instanceID)
- find, findClose, findErr := findStmt(db, prefix, instanceID)
+ cfg := dbconfigT{
+ shared: db,
+ dbpath: dbpath,
+ prefix: prefix,
+ instanceID: instanceID,
+ }
+ publish, publishClose, publishErr := publishStmt(cfg)
+ find, findClose, findErr := findStmt(cfg)
g.TErrorIf(g.SomeError(
publishErr,
findErr,
@@ -354,6 +414,7 @@ func test_nextStmt() {
topic = "next() topic"
payloadStr = "next() payload"
consumer = "next() consumer"
+ dbpath = golite.InMemory
prefix = defaultPrefix
)
var (
@@ -366,15 +427,24 @@ func test_nextStmt() {
}
)
- db, err := sql.Open(golite.DriverName, golite.InMemory)
+ db, err := sql.Open(golite.DriverName, dbpath)
g.TErrorIf(err)
g.TErrorIf(createTables(db, prefix))
- instanceID := os.Getpid()
- take, takeClose, takeErr := takeStmt(db, prefix, instanceID)
- publish, publishClose, publishErr := publishStmt(db, prefix, instanceID)
- next, nextClose, nextErr := nextStmt(db, prefix, instanceID)
- commit, commitClose, commitErr := commitStmt(db, prefix, instanceID)
+ cfg := dbconfigT{
+ shared: db,
+ dbpath: dbpath,
+ prefix: prefix,
+ instanceID: instanceID,
+ }
+ take, takeClose, takeErr := takeStmt(cfg)
+ publish, publishClose, publishErr := publishStmt(cfg)
+ next, nextClose, nextErr := nextStmt(cfg)
+ commit, commitClose, commitErr := commitStmt(cfg)
+ g.TErrorIf(takeErr)
+ g.TErrorIf(publishErr)
+ g.TErrorIf(nextErr)
+ g.TErrorIf(commitErr)
g.TErrorIf(g.SomeError(
takeErr,
publishErr,
@@ -440,9 +510,10 @@ func test_nextStmt() {
})
g.Testing("error when we're not the owner", func() {
- otherID := instanceID + 1
+ otherCfg := cfg
+ otherCfg.instanceID = instanceID + 1
- take, takeClose, takeErr := takeStmt(db, prefix, otherID)
+ take, takeClose, takeErr := takeStmt(otherCfg)
g.TErrorIf(takeErr)
defer takeClose()
@@ -455,7 +526,7 @@ func test_nextStmt() {
_, err = next(topic, consumer)
g.TAssertEqual(err, fmt.Errorf(
notOwnerErrorFmt,
- otherID,
+ otherCfg.instanceID,
topic,
consumer,
instanceID,
@@ -478,6 +549,7 @@ func test_messageEach() {
topic = "messageEach() topic"
payloadStr = "messageEach() payload"
consumer = "messageEach() consumer"
+ dbpath = golite.InMemory
prefix = defaultPrefix
)
var (
@@ -490,14 +562,19 @@ func test_messageEach() {
}
)
- db, err := sql.Open(golite.DriverName, golite.InMemory)
+ db, err := sql.Open(golite.DriverName, dbpath)
g.TErrorIf(err)
g.TErrorIf(createTables(db, prefix))
- instanceID := os.Getpid()
- take, takeClose, takeErr := takeStmt(db, prefix, instanceID)
- publish, publishClose, publishErr := publishStmt(db, prefix, instanceID)
- pending, pendingClose, pendingErr := pendingStmt(db, prefix, instanceID)
+ cfg := dbconfigT{
+ shared: db,
+ dbpath: dbpath,
+ prefix: prefix,
+ instanceID: instanceID,
+ }
+ take, takeClose, takeErr := takeStmt(cfg)
+ publish, publishClose, publishErr := publishStmt(cfg)
+ pending, pendingClose, pendingErr := pendingStmt(cfg)
g.TErrorIf(g.SomeError(
takeErr,
publishErr,
@@ -640,6 +717,7 @@ func test_pendingStmt() {
topic = "pending() topic"
payloadStr = "pending() payload"
consumer = "pending() consumer"
+ dbpath = golite.InMemory
prefix = defaultPrefix
)
var (
@@ -652,16 +730,21 @@ func test_pendingStmt() {
}
)
- db, err := sql.Open(golite.DriverName, golite.InMemory)
+ db, err := sql.Open(golite.DriverName, dbpath)
g.TErrorIf(err)
g.TErrorIf(createTables(db, prefix))
- instanceID := os.Getpid()
- take, takeClose, takeErr := takeStmt(db, prefix, instanceID)
- publish, publishClose, publishErr := publishStmt(db, prefix, instanceID)
- pending, pendingClose, pendingErr := pendingStmt(db, prefix, instanceID)
- commit, commitClose, commitErr := commitStmt(db, prefix, instanceID)
- toDead, toDeadClose, toDeadErr := toDeadStmt(db, prefix, instanceID)
+ cfg := dbconfigT{
+ shared: db,
+ dbpath: dbpath,
+ prefix: prefix,
+ instanceID: instanceID,
+ }
+ take, takeClose, takeErr := takeStmt(cfg)
+ publish, publishClose, publishErr := publishStmt(cfg)
+ pending, pendingClose, pendingErr := pendingStmt(cfg)
+ commit, commitClose, commitErr := commitStmt(cfg)
+ toDead, toDeadClose, toDeadErr := toDeadStmt(cfg)
g.TErrorIf(g.SomeError(
takeErr,
publishErr,
@@ -820,9 +903,10 @@ func test_pendingStmt() {
})
g.Testing("when we're not the owners we get nothing", func() {
- otherID := instanceID + 1
+ otherCfg := cfg
+ otherCfg.instanceID = instanceID + 1
- take, takeClose, takeErr := takeStmt(db, prefix, otherID)
+ take, takeClose, takeErr := takeStmt(otherCfg)
g.TErrorIf(takeErr)
defer takeClose()
@@ -871,6 +955,7 @@ func test_commitStmt() {
topic = "commit() topic"
payloadStr = "commit() payload"
consumer = "commit() consumer"
+ dbpath = golite.InMemory
prefix = defaultPrefix
)
var (
@@ -883,15 +968,20 @@ func test_commitStmt() {
}
)
- db, err := sql.Open(golite.DriverName, golite.InMemory)
+ db, err := sql.Open(golite.DriverName, dbpath)
g.TErrorIf(err)
g.TErrorIf(createTables(db, prefix))
- instanceID := os.Getpid()
- take, takeClose, takeErr := takeStmt(db, prefix, instanceID)
- publish, publishClose, publishErr := publishStmt(db, prefix, instanceID)
- commit, commitClose, commitErr := commitStmt(db, prefix, instanceID)
- toDead, toDeadClose, toDeadErr := toDeadStmt(db, prefix, instanceID)
+ cfg := dbconfigT{
+ shared: db,
+ dbpath: dbpath,
+ prefix: prefix,
+ instanceID: instanceID,
+ }
+ take, takeClose, takeErr := takeStmt(cfg)
+ publish, publishClose, publishErr := publishStmt(cfg)
+ commit, commitClose, commitErr := commitStmt(cfg)
+ toDead, toDeadClose, toDeadErr := toDeadStmt(cfg)
g.TErrorIf(g.SomeError(
takeErr,
publishErr,
@@ -937,7 +1027,10 @@ func test_commitStmt() {
g.Testing("we can't commit non-existent messages", func() {
err := cmt(consumer, guuid.New())
- g.TAssertEqual(err, sql.ErrNoRows)
+ g.TAssertEqual(
+ err.(golite.Error).ExtendedCode,
+ golite.ErrConstraintNotNull,
+ )
})
g.Testing("multiple consumers may commit a message", func() {
@@ -987,8 +1080,10 @@ func test_commitStmt() {
})
g.Testing("error if we don't own the topic/consumer", func() {
- otherID := instanceID + 1
- take, takeClose, takeErr := takeStmt(db, prefix, otherID)
+ otherCfg := cfg
+ otherCfg.instanceID = instanceID + 1
+
+ take, takeClose, takeErr := takeStmt(otherCfg)
g.TErrorIf(takeErr)
defer takeClose()
@@ -998,13 +1093,10 @@ func test_commitStmt() {
g.TErrorIf(err)
err = commit(consumer, messageID)
- g.TAssertEqual(err, fmt.Errorf(
- noLongerOwnerErrorFmt,
- instanceID,
- topic,
- consumer,
- otherID,
- ))
+ g.TAssertEqual(
+ err.(golite.Error).ExtendedCode,
+ golite.ErrConstraintTrigger,
+ )
})
g.Testing("no actual closing occurs", func() {
@@ -1023,6 +1115,7 @@ func test_toDeadStmt() {
topic = "toDead() topic"
payloadStr = "toDead() payload"
consumer = "toDead() consumer"
+ dbpath = golite.InMemory
prefix = defaultPrefix
)
var (
@@ -1035,15 +1128,20 @@ func test_toDeadStmt() {
}
)
- db, err := sql.Open(golite.DriverName, golite.InMemory)
+ db, err := sql.Open(golite.DriverName, dbpath)
g.TErrorIf(err)
g.TErrorIf(createTables(db, prefix))
- instanceID := os.Getpid()
- take, takeClose, takeErr := takeStmt(db, prefix, instanceID)
- publish, publishClose, publishErr := publishStmt(db, prefix, instanceID)
- commit, commitClose, commitErr := commitStmt(db, prefix, instanceID)
- toDead, toDeadClose, toDeadErr := toDeadStmt(db, prefix, instanceID)
+ cfg := dbconfigT{
+ shared: db,
+ dbpath: dbpath,
+ prefix: prefix,
+ instanceID: instanceID,
+ }
+ take, takeClose, takeErr := takeStmt(cfg)
+ publish, publishClose, publishErr := publishStmt(cfg)
+ commit, commitClose, commitErr := commitStmt(cfg)
+ toDead, toDeadClose, toDeadErr := toDeadStmt(cfg)
g.TErrorIf(g.SomeError(
takeErr,
publishErr,
@@ -1107,7 +1205,10 @@ func test_toDeadStmt() {
g.Testing("we can't mark as dead non-existent messages", func() {
err := asDead(consumer, guuid.New(), guuid.New())
- g.TAssertEqual(err, sql.ErrNoRows)
+ g.TAssertEqual(
+ err.(golite.Error).ExtendedCode,
+ golite.ErrConstraintNotNull,
+ )
})
g.Testing("multiple consumers may mark a message as dead", func() {
@@ -1173,11 +1274,13 @@ func test_toDeadStmt() {
})
g.Testing("error if we don't own the message's consumer/topic", func() {
- otherID := instanceID + 1
+ otherCfg := cfg
+ otherCfg.instanceID = instanceID + 1
+
messageID1 := pub(topic)
messageID2 := pub(topic)
- take, takeClose, takeErr := takeStmt(db, prefix, otherID)
+ take, takeClose, takeErr := takeStmt(otherCfg)
g.TErrorIf(takeErr)
defer takeClose()
@@ -1188,13 +1291,10 @@ func test_toDeadStmt() {
g.TErrorIf(err)
err = toDead(consumer, messageID2, guuid.New())
- g.TAssertEqual(err, fmt.Errorf(
- noLongerOwnerErrorFmt,
- instanceID,
- topic,
- consumer,
- otherID,
- ))
+ g.TAssertEqual(
+ err.(golite.Error).ExtendedCode,
+ golite.ErrConstraintTrigger,
+ )
})
g.Testing("no actual closing occurs", func() {
@@ -1213,6 +1313,7 @@ func test_replayStmt() {
topic = "replay() topic"
payloadStr = "replay() payload"
consumer = "replay() consumer"
+ dbpath = golite.InMemory
prefix = defaultPrefix
)
var (
@@ -1225,15 +1326,20 @@ func test_replayStmt() {
}
)
- db, err := sql.Open(golite.DriverName, golite.InMemory)
+ db, err := sql.Open(golite.DriverName, dbpath)
g.TErrorIf(err)
g.TErrorIf(createTables(db, prefix))
- instanceID := os.Getpid()
- take, takeClose, takeErr := takeStmt(db, prefix, instanceID)
- publish, publishClose, publishErr := publishStmt(db, prefix, instanceID)
- toDead, toDeadClose, toDeadErr := toDeadStmt(db, prefix, instanceID)
- replay, replayClose, replayErr := replayStmt(db, prefix, instanceID)
+ cfg := dbconfigT{
+ shared: db,
+ dbpath: dbpath,
+ prefix: prefix,
+ instanceID: instanceID,
+ }
+ take, takeClose, takeErr := takeStmt(cfg)
+ publish, publishClose, publishErr := publishStmt(cfg)
+ toDead, toDeadClose, toDeadErr := toDeadStmt(cfg)
+ replay, replayClose, replayErr := replayStmt(cfg)
g.TErrorIf(g.SomeError(
takeErr,
publishErr,
@@ -1343,6 +1449,7 @@ func test_oneDeadStmt() {
topic = "oneDead() topic"
payloadStr = "oneDead() payload"
consumer = "oneDead() consumer"
+ dbpath = golite.InMemory
prefix = defaultPrefix
)
var (
@@ -1355,16 +1462,21 @@ func test_oneDeadStmt() {
}
)
- db, err := sql.Open(golite.DriverName, golite.InMemory)
+ db, err := sql.Open(golite.DriverName, dbpath)
g.TErrorIf(err)
g.TErrorIf(createTables(db, prefix))
- instanceID := os.Getpid()
- take, takeClose, takeErr := takeStmt(db, prefix, instanceID)
- publish, publishClose, publishErr := publishStmt(db, prefix, instanceID)
- toDead, toDeadClose, toDeadErr := toDeadStmt(db, prefix, instanceID)
- replay, replayClose, replayErr := replayStmt(db, prefix, instanceID)
- oneDead, oneDeadClose, oneDeadErr := oneDeadStmt(db, prefix, instanceID)
+ cfg := dbconfigT{
+ shared: db,
+ dbpath: dbpath,
+ prefix: prefix,
+ instanceID: instanceID,
+ }
+ take, takeClose, takeErr := takeStmt(cfg)
+ publish, publishClose, publishErr := publishStmt(cfg)
+ toDead, toDeadClose, toDeadErr := toDeadStmt(cfg)
+ replay, replayClose, replayErr := replayStmt(cfg)
+ oneDead, oneDeadClose, oneDeadErr := oneDeadStmt(cfg)
g.TErrorIf(g.SomeError(
takeErr,
publishErr,
@@ -1457,6 +1569,7 @@ func test_deadletterEach() {
topic = "deadletterEach() topic"
payloadStr = "deadletterEach() payload"
consumer = "deadletterEach() consumer"
+ dbpath = golite.InMemory
prefix = defaultPrefix
)
var (
@@ -1469,15 +1582,20 @@ func test_deadletterEach() {
}
)
- db, err := sql.Open(golite.DriverName, golite.InMemory)
+ db, err := sql.Open(golite.DriverName, dbpath)
g.TErrorIf(err)
g.TErrorIf(createTables(db, prefix))
- instanceID := os.Getpid()
- take, takeClose, takeErr := takeStmt(db, prefix, instanceID)
- publish, publishClose, publishErr := publishStmt(db, prefix, instanceID)
- toDead, toDeadClose, toDeadErr := toDeadStmt(db, prefix, instanceID)
- allDead, allDeadClose, allDeadErr := allDeadStmt(db, prefix, instanceID)
+ cfg := dbconfigT{
+ shared: db,
+ dbpath: dbpath,
+ prefix: prefix,
+ instanceID: instanceID,
+ }
+ take, takeClose, takeErr := takeStmt(cfg)
+ publish, publishClose, publishErr := publishStmt(cfg)
+ toDead, toDeadClose, toDeadErr := toDeadStmt(cfg)
+ allDead, allDeadClose, allDeadErr := allDeadStmt(cfg)
g.TErrorIf(g.SomeError(
takeErr,
publishErr,
@@ -1631,6 +1749,7 @@ func test_allDeadStmt() {
topic = "allDead() topic"
payloadStr = "allDead() payload"
consumer = "allDead() consumer"
+ dbpath = golite.InMemory
prefix = defaultPrefix
)
var (
@@ -1643,16 +1762,21 @@ func test_allDeadStmt() {
}
)
- db, err := sql.Open(golite.DriverName, golite.InMemory)
+ db, err := sql.Open(golite.DriverName, dbpath)
g.TErrorIf(err)
g.TErrorIf(createTables(db, prefix))
- instanceID := os.Getpid()
- take, takeClose, takeErr := takeStmt(db, prefix, instanceID)
- publish, publishClose, publishErr := publishStmt(db, prefix, instanceID)
- toDead, toDeadClose, toDeadErr := toDeadStmt(db, prefix, instanceID)
- replay, replayClose, replayErr := replayStmt(db, prefix, instanceID)
- allDead, allDeadClose, allDeadErr := allDeadStmt(db, prefix, instanceID)
+ cfg := dbconfigT{
+ shared: db,
+ dbpath: dbpath,
+ prefix: prefix,
+ instanceID: instanceID,
+ }
+ take, takeClose, takeErr := takeStmt(cfg)
+ publish, publishClose, publishErr := publishStmt(cfg)
+ toDead, toDeadClose, toDeadErr := toDeadStmt(cfg)
+ replay, replayClose, replayErr := replayStmt(cfg)
+ allDead, allDeadClose, allDeadErr := allDeadStmt(cfg)
g.TErrorIf(g.SomeError(
takeErr,
publishErr,
@@ -1784,6 +1908,7 @@ func test_sizeStmt() {
topic = "size() topic"
payloadStr = "size() payload"
consumer = "size() consumer"
+ dbpath = golite.InMemory
prefix = defaultPrefix
)
var (
@@ -1796,17 +1921,22 @@ func test_sizeStmt() {
}
)
- db, err := sql.Open(golite.DriverName, golite.InMemory)
+ db, err := sql.Open(golite.DriverName, dbpath)
g.TErrorIf(err)
g.TErrorIf(createTables(db, prefix))
- instanceID := os.Getpid()
- take, takeClose, takeErr := takeStmt(db, prefix, instanceID)
- publish, publishClose, publishErr := publishStmt(db, prefix, instanceID)
- toDead, toDeadClose, toDeadErr := toDeadStmt(db, prefix, instanceID)
- replay, replayClose, replayErr := replayStmt(db, prefix, instanceID)
- oneDead, oneDeadClose, oneDeadErr := oneDeadStmt(db, prefix, instanceID)
- size, sizeClose, sizeErr := sizeStmt(db, prefix, instanceID)
+ cfg := dbconfigT{
+ shared: db,
+ dbpath: dbpath,
+ prefix: prefix,
+ instanceID: instanceID,
+ }
+ take, takeClose, takeErr := takeStmt(cfg)
+ publish, publishClose, publishErr := publishStmt(cfg)
+ toDead, toDeadClose, toDeadErr := toDeadStmt(cfg)
+ replay, replayClose, replayErr := replayStmt(cfg)
+ oneDead, oneDeadClose, oneDeadErr := oneDeadStmt(cfg)
+ size, sizeClose, sizeErr := sizeStmt(cfg)
g.TErrorIf(g.SomeError(
takeErr,
publishErr,
@@ -1892,6 +2022,7 @@ func test_countStmt() {
topic = "count() topic"
payloadStr = "count() payload"
consumer = "count() consumer"
+ dbpath = golite.InMemory
prefix = defaultPrefix
)
var (
@@ -1904,17 +2035,22 @@ func test_countStmt() {
}
)
- db, err := sql.Open(golite.DriverName, golite.InMemory)
+ db, err := sql.Open(golite.DriverName, dbpath)
g.TErrorIf(err)
g.TErrorIf(createTables(db, prefix))
- instanceID := os.Getpid()
- take, takeClose, takeErr := takeStmt(db, prefix, instanceID)
- publish, publishClose, publishErr := publishStmt(db, prefix, instanceID)
- next, nextClose, nextErr := nextStmt(db, prefix, instanceID)
- commit, commitClose, commitErr := commitStmt(db, prefix, instanceID)
- toDead, toDeadClose, toDeadErr := toDeadStmt(db, prefix, instanceID)
- count, countClose, countErr := countStmt(db, prefix, instanceID)
+ cfg := dbconfigT{
+ shared: db,
+ dbpath: dbpath,
+ prefix: prefix,
+ instanceID: instanceID,
+ }
+ take, takeClose, takeErr := takeStmt(cfg)
+ publish, publishClose, publishErr := publishStmt(cfg)
+ next, nextClose, nextErr := nextStmt(cfg)
+ commit, commitClose, commitErr := commitStmt(cfg)
+ toDead, toDeadClose, toDeadErr := toDeadStmt(cfg)
+ count, countClose, countErr := countStmt(cfg)
g.TErrorIf(g.SomeError(
takeErr,
publishErr,
@@ -2020,6 +2156,7 @@ func test_hasDataStmt() {
topic = "hasData() topic"
payloadStr = "hasData() payload"
consumer = "hasData() consumer"
+ dbpath = golite.InMemory
prefix = defaultPrefix
)
var (
@@ -2032,17 +2169,22 @@ func test_hasDataStmt() {
}
)
- db, err := sql.Open(golite.DriverName, golite.InMemory)
+ db, err := sql.Open(golite.DriverName, dbpath)
g.TErrorIf(err)
g.TErrorIf(createTables(db, prefix))
- instanceID := os.Getpid()
- take, takeClose, takeErr := takeStmt(db, prefix, instanceID)
- publish, publishClose, publishErr := publishStmt(db, prefix, instanceID)
- next, nextClose, nextErr := nextStmt(db, prefix, instanceID)
- commit, commitClose, commitErr := commitStmt(db, prefix, instanceID)
- toDead, toDeadClose, toDeadErr := toDeadStmt(db, prefix, instanceID)
- hasData, hasDataClose, hasDataErr := hasDataStmt(db, prefix, instanceID)
+ cfg := dbconfigT{
+ shared: db,
+ dbpath: dbpath,
+ prefix: prefix,
+ instanceID: instanceID,
+ }
+ take, takeClose, takeErr := takeStmt(cfg)
+ publish, publishClose, publishErr := publishStmt(cfg)
+ next, nextClose, nextErr := nextStmt(cfg)
+ commit, commitClose, commitErr := commitStmt(cfg)
+ toDead, toDeadClose, toDeadErr := toDeadStmt(cfg)
+ hasData, hasDataClose, hasDataErr := hasDataStmt(cfg)
g.TErrorIf(g.SomeError(
takeErr,
publishErr,
@@ -2148,6 +2290,8 @@ func test_initDB() {
topic = "initDB() topic"
payloadStr = "initDB() payload"
consumer = "initDB() consumer"
+ dbpath = golite.InMemory
+ prefix = defaultPrefix
)
var (
flowID = guuid.New()
@@ -2159,17 +2303,12 @@ func test_initDB() {
}
)
- db, err := sql.Open(golite.DriverName, golite.InMemory)
- g.TErrorIf(err)
- defer db.Close()
-
var messages []messageT
notifyFn := func(message messageT) {
messages = append(messages, message)
}
- instanceID := os.Getpid()
- queries, err := initDB(db, defaultPrefix, notifyFn, instanceID)
+ queries, err := initDB(dbpath, prefix, notifyFn, instanceID)
g.TErrorIf(err)
defer queries.close()
@@ -2256,31 +2395,16 @@ func test_initDB() {
func test_queriesTclose() {
g.TestStart("queriesT.close()")
- db, err := sql.Open(golite.DriverName, golite.InMemory)
- g.TErrorIf(err)
- defer db.Close()
+ const (
+ dbpath = golite.InMemory
+ prefix = defaultPrefix
+ )
- instanceID := os.Getpid()
- queries, err := initDB(db, defaultPrefix, func(messageT) {}, instanceID)
+ notifyFn := func(messageT) {}
+ queries, err := initDB(dbpath, prefix, notifyFn, instanceID)
g.TErrorIf(err)
- g.Testing("after closing, we can't run queries", func() {
- unsent := UnsentMessage{ Payload: []byte{}, }
- _, err := queries.publish(unsent, guuid.New())
- g.TErrorIf(err)
- g.TErrorIf(db.Close())
-
- err = queries.close()
- g.TErrorIf(err)
-
- _, err = queries.publish(unsent, guuid.New())
- g.TAssertEqual(
- err.Error(),
- "sql: database is closed",
- )
- })
-
g.Testing("closing mre than once does not error", func() {
g.TErrorIf(g.SomeError(
queries.close(),
@@ -2289,7 +2413,6 @@ func test_queriesTclose() {
})
}
-
func test_newPinger() {
g.TestStart("newPinger()")
@@ -3287,6 +3410,7 @@ func test_queueT_Publish() {
const (
topic = "queueT.Publish() topic"
payloadStr = "queueT.Publish() payload"
+ dbpath = golite.InMemory
)
var (
flowID = guuid.New()
@@ -3298,7 +3422,7 @@ func test_queueT_Publish() {
}
)
- queue, err := New(golite.InMemory)
+ queue, err := New(dbpath)
g.TErrorIf(err)
defer queue.Close()
@@ -4350,11 +4474,7 @@ func test_queueT_Close() {
queriesErr = errors.New("queriesT{} error")
)
- db, err := sql.Open(golite.DriverName, golite.InMemory)
- g.TErrorIf(err)
-
queue := queueT{
- db: db,
queries: queriesT{
close: func() error{
queriesCount++
@@ -4376,7 +4496,7 @@ func test_queueT_Close() {
},
}
- err = queue.Close()
+ err := queue.Close()
g.TAssertEqual(err, g.WrapErrors(subscriptionsErr, queriesErr))
g.TAssertEqual(pingerCount, 1)
g.TAssertEqual(subscriptionsCount, 1)
@@ -5674,6 +5794,7 @@ func dumpQueries() {
{ "take", takeSQL },
{ "publish", publishSQL },
{ "find", findSQL },
+ { "next", nextSQL },
{ "pending", pendingSQL },
{ "commit", commitSQL },
{ "toDead", toDeadSQL },
@@ -5703,6 +5824,8 @@ func MainTest() {
g.Init()
test_defaultPrefix()
+ test_serialized()
+ test_execSerialized()
test_tryRollback()
test_inTx()
test_createTables()
diff --git a/tests/queries.sql b/tests/queries.sql
index c821e25..e790d41 100644
--- a/tests/queries.sql
+++ b/tests/queries.sql
@@ -27,6 +27,7 @@
consumer TEXT NOT NULL,
message_id INTEGER NOT NULL
REFERENCES "q_messages"(id),
+ instance_id INTEGER NOT NULL,
UNIQUE (consumer, message_id)
) STRICT;
CREATE INDEX IF NOT EXISTS "q_offsets_consumer"
@@ -38,6 +39,7 @@
consumer TEXT NOT NULL,
message_id INTEGER NOT NULL
REFERENCES "q_messages"(id),
+ instance_id INTEGER NOT NULL,
UNIQUE (consumer, message_id)
) STRICT;
CREATE INDEX IF NOT EXISTS "q_deadletters_consumer"
@@ -58,6 +60,44 @@
owner_id INTEGER NOT NULL,
UNIQUE (topic, consumer)
) STRICT;
+
+ CREATE TRIGGER IF NOT EXISTS "q_check_instance_owns_topic"
+ BEFORE INSERT ON "q_offsets"
+ WHEN NEW.instance_id != (
+ SELECT owner_id FROM "q_owners"
+ WHERE topic = (
+ SELECT "q_payloads".topic
+ FROM "q_payloads"
+ JOIN "q_messages" ON "q_payloads".id =
+ "q_messages".payload_id
+ WHERE "q_messages".id = NEW.message_id
+ ) AND consumer = NEW.consumer
+ )
+ BEGIN
+ SELECT RAISE(
+ ABORT,
+ 'instance does not own topic/consumer combo'
+ );
+ END;
+
+ CREATE TRIGGER IF NOT EXISTS "q_check_can_publish_deadletter"
+ BEFORE INSERT ON "q_deadletters"
+ WHEN NEW.instance_id != (
+ SELECT owner_id FROM "q_owners"
+ WHERE topic = (
+ SELECT "q_payloads".topic
+ FROM "q_payloads"
+ JOIN "q_messages" ON "q_payloads".id =
+ "q_messages".payload_id
+ WHERE "q_messages".id = NEW.message_id
+ ) AND consumer = NEW.consumer
+ )
+ BEGIN
+ SELECT RAISE(
+ ABORT,
+ 'Instance does not own topic/consumer combo'
+ );
+ END;
-- read:
@@ -81,7 +121,6 @@
INSERT INTO "q_payloads" (topic, payload)
VALUES (?, ?);
- -- FIXME: must be inside a trnsaction
INSERT INTO "q_messages" (uuid, flow_id, payload_id)
VALUES (?, ?, last_insert_rowid());
@@ -114,6 +153,38 @@
-- owner:
+-- next.sql:
+-- write:
+
+-- read:
+ SELECT
+ (
+ SELECT owner_id FROM "q_owners"
+ WHERE
+ topic = ? AND
+ consumer = ?
+ LIMIT 1
+ ) AS owner_id,
+ "q_messages".id,
+ "q_messages".timestamp,
+ "q_messages".uuid,
+ "q_messages".flow_id,
+ "q_payloads".payload
+ FROM "q_messages"
+ JOIN "q_payloads" ON
+ "q_payloads".id = "q_messages".payload_id
+ WHERE
+ "q_payloads".topic = ? AND
+ "q_messages".id NOT IN (
+ SELECT message_id FROM "q_offsets"
+ WHERE consumer = ?
+ )
+ ORDER BY "q_messages".id ASC
+ LIMIT 1;
+
+
+-- owner:
+
-- pending.sql:
-- write:
@@ -146,46 +217,28 @@
-- commit.sql:
-- write:
- INSERT INTO "q_offsets" (consumer, message_id)
- VALUES (?, (SELECT id FROM "q_messages" WHERE uuid = ?));
+ INSERT INTO "q_offsets" (consumer, message_id, instance_id)
+ VALUES (?, (SELECT id FROM "q_messages" WHERE uuid = ?), ?);
-- read:
- SELECT "q_payloads".topic from "q_payloads"
- JOIN "q_messages" ON
- "q_payloads".id = "q_messages".payload_id
- WHERE "q_messages".uuid = ?;
-
-- owner:
- SELECT owner_id FROM "q_owners"
- WHERE
- topic = ? AND
- consumer = ?;
-
-- toDead.sql:
-- write:
- INSERT INTO "q_offsets" ( consumer, message_id)
- VALUES ( ?, (SELECT id FROM "q_messages" WHERE uuid = ?));
+ INSERT INTO "q_offsets"
+ ( consumer, message_id, instance_id)
+ VALUES ( ?, (SELECT id FROM "q_messages" WHERE uuid = ?), ?);
- INSERT INTO "q_deadletters" (uuid, consumer, message_id)
- VALUES (?, ?, (SELECT id FROM "q_messages" WHERE uuid = ?));
+ INSERT INTO "q_deadletters"
+ (uuid, consumer, message_id, instance_id)
+ VALUES (?, ?, (SELECT id FROM "q_messages" WHERE uuid = ?), ?);
-- read:
- SELECT "q_payloads".topic FROM "q_payloads"
- JOIN "q_messages" ON
- "q_payloads".id = "q_messages".payload_id
- WHERE "q_messages".uuid = ?;
-
-- owner:
- SELECT owner_id FROM "q_owners"
- WHERE
- topic = ? AND
- consumer = ?;
-
-- replay.sql:
-- write: