diff options
-rw-r--r-- | src/q.go | 652 | ||||
-rw-r--r-- | tests/functional/consumer-with-deadletter/q.go | 6 | ||||
-rw-r--r-- | tests/functional/new-instance-takeover/q.go | 20 | ||||
-rw-r--r-- | tests/functional/wait-after-publish/q.go | 5 | ||||
-rw-r--r-- | tests/q.go | 453 | ||||
-rw-r--r-- | tests/queries.sql | 107 |
6 files changed, 704 insertions, 539 deletions
@@ -1,6 +1,6 @@ package q + import ( - "context" "database/sql" "flag" "fmt" @@ -18,15 +18,21 @@ import ( const ( - defaultPrefix = "q" - reaperSkipCount = 1000 - notOwnerErrorFmt = "%v owns %#v as %#v, not us (%v)" - noLongerOwnerErrorFmt = "we (%v) no longer own %#v as %#v, but %v does" - rollbackErrorFmt = "rollback error: %w; while executing: %w" + defaultPrefix = "q" + reaperSkipCount = 1000 + notOwnerErrorFmt = "%v owns %#v as %#v, not us (%v)" + rollbackErrorFmt = "rollback error: %w; while executing: %w" ) +type dbconfigT struct{ + shared *sql.DB + dbpath string + prefix string + instanceID int +} + type queryT struct{ write string read string @@ -126,7 +132,6 @@ type subscriptionsT struct { } type queueT struct{ - db *sql.DB queries queriesT subscriptions subscriptionsT pinger pingerT[struct{}] @@ -158,12 +163,55 @@ type IQueue interface{ -func closeNoop() error { - return nil +func serialized[A any, B any](callback func(...A) B) (func(...A) B, func()) { + in := make(chan []A) + out := make(chan B) + + closed := false + var ( + closeWg sync.WaitGroup + closeMutex sync.Mutex + ) + closeWg.Add(1) + + go func() { + for input := range in { + out <- callback(input...) + } + close(out) + closeWg.Done() + }() + + fn := func(input ...A) B { + in <- input + return (<- out) + } + + closeFn := func() { + closeMutex.Lock() + defer closeMutex.Unlock() + if closed { + return + } + close(in) + closed = true + closeWg.Wait() + } + + return fn, closeFn +} + +func execSerialized(query string, db *sql.DB) (func(...any) error, func()) { + return serialized(func(args ...any) error { + return inTx(db, func(tx *sql.Tx) error { + _, err := tx.Exec(query, args...) + return err + }) + }) } -func tryRollback(db *sql.DB, ctx context.Context, err error) error { - _, rollbackErr := db.ExecContext(ctx, "ROLLBACK;") +func tryRollback(tx *sql.Tx, err error) error { + rollbackErr := tx.Rollback() if rollbackErr != nil { return fmt.Errorf( rollbackErrorFmt, @@ -175,25 +223,20 @@ func tryRollback(db *sql.DB, ctx context.Context, err error) error { return err } -// FIXME -// See: -// https://sqlite.org/forum/forumpost/2507664507 -func inTx(db *sql.DB, fn func(context.Context) error) error { - ctx := context.Background() - - _, err := db.ExecContext(ctx, "BEGIN IMMEDIATE;") +func inTx(db *sql.DB, fn func(*sql.Tx) error) error { + tx, err := db.Begin() if err != nil { return err } - err = fn(ctx) + err = fn(tx) if err != nil { - return tryRollback(db, ctx, err) + return tryRollback(tx, err) } - _, err = db.ExecContext(ctx, "COMMIT;") + err = tx.Commit() if err != nil { - return tryRollback(db, ctx, err) + return tryRollback(tx, err) } return nil @@ -227,6 +270,7 @@ func createTablesSQL(prefix string) queryT { consumer TEXT NOT NULL, message_id INTEGER NOT NULL REFERENCES "%s_messages"(id), + instance_id INTEGER NOT NULL, UNIQUE (consumer, message_id) ) STRICT; CREATE INDEX IF NOT EXISTS "%s_offsets_consumer" @@ -238,6 +282,7 @@ func createTablesSQL(prefix string) queryT { consumer TEXT NOT NULL, message_id INTEGER NOT NULL REFERENCES "%s_messages"(id), + instance_id INTEGER NOT NULL, UNIQUE (consumer, message_id) ) STRICT; CREATE INDEX IF NOT EXISTS "%s_deadletters_consumer" @@ -258,6 +303,44 @@ func createTablesSQL(prefix string) queryT { owner_id INTEGER NOT NULL, UNIQUE (topic, consumer) ) STRICT; + + CREATE TRIGGER IF NOT EXISTS "%s_check_instance_owns_topic" + BEFORE INSERT ON "%s_offsets" + WHEN NEW.instance_id != ( + SELECT owner_id FROM "%s_owners" + WHERE topic = ( + SELECT "%s_payloads".topic + FROM "%s_payloads" + JOIN "%s_messages" ON "%s_payloads".id = + "%s_messages".payload_id + WHERE "%s_messages".id = NEW.message_id + ) AND consumer = NEW.consumer + ) + BEGIN + SELECT RAISE( + ABORT, + 'instance does not own topic/consumer combo' + ); + END; + + CREATE TRIGGER IF NOT EXISTS "%s_check_can_publish_deadletter" + BEFORE INSERT ON "%s_deadletters" + WHEN NEW.instance_id != ( + SELECT owner_id FROM "%s_owners" + WHERE topic = ( + SELECT "%s_payloads".topic + FROM "%s_payloads" + JOIN "%s_messages" ON "%s_payloads".id = + "%s_messages".payload_id + WHERE "%s_messages".id = NEW.message_id + ) AND consumer = NEW.consumer + ) + BEGIN + SELECT RAISE( + ABORT, + 'Instance does not own topic/consumer combo' + ); + END; ` return queryT{ write: fmt.Sprintf( @@ -284,6 +367,24 @@ func createTablesSQL(prefix string) queryT { prefix, prefix, prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, ), } } @@ -291,8 +392,8 @@ func createTablesSQL(prefix string) queryT { func createTables(db *sql.DB, prefix string) error { q := createTablesSQL(prefix) - return inTx(db, func(ctx context.Context) error { - _, err := db.ExecContext(ctx, q.write) + return inTx(db, func(tx *sql.Tx) error { + _, err := tx.Exec(q.write) return err }) } @@ -310,26 +411,25 @@ func takeSQL(prefix string) queryT { } func takeStmt( - db *sql.DB, - prefix string, - instanceID int, + cfg dbconfigT, ) (func(string, string) error, func() error, error) { - q := takeSQL(prefix) + q := takeSQL(cfg.prefix) + + writeStmt, err := cfg.shared.Prepare(q.write) + if err != nil { + return nil, nil, err + } fn := func(topic string, consumer string) error { - return inTx(db, func(ctx context.Context) error { - _, err := db.ExecContext( - ctx, - q.write, - topic, - consumer, - instanceID, - ) - return err - }) + _, err := writeStmt.Exec( + topic, + consumer, + cfg.instanceID, + ) + return err } - return fn, closeNoop, nil + return fn, writeStmt.Close, nil } func publishSQL(prefix string) queryT { @@ -337,7 +437,6 @@ func publishSQL(prefix string) queryT { INSERT INTO "%s_payloads" (topic, payload) VALUES (?, ?); - -- FIXME: must be inside a trnsaction INSERT INTO "%s_messages" (uuid, flow_id, payload_id) VALUES (?, ?, last_insert_rowid()); ` @@ -352,17 +451,23 @@ func publishSQL(prefix string) queryT { } func publishStmt( - db *sql.DB, - prefix string, - _ int, + cfg dbconfigT, ) (func(UnsentMessage, guuid.UUID) (messageT, error), func() error, error) { - q := publishSQL(prefix) + q := publishSQL(cfg.prefix) + + readStmt, err := cfg.shared.Prepare(q.read) + if err != nil { + return nil, nil, err + } - readStmt, err := db.Prepare(q.read) + privateDB, err := sql.Open(golite.DriverName, cfg.dbpath) if err != nil { + readStmt.Close() return nil, nil, err } + writeFn, writeFnClose := execSerialized(q.write, privateDB) + fn := func( unsentMessage UnsentMessage, messageID guuid.UUID, @@ -376,17 +481,12 @@ func publishStmt( message_id_bytes := messageID[:] flow_id_bytes := unsentMessage.FlowID[:] - err := inTx(db, func(ctx context.Context) error { - _, err := db.ExecContext( - ctx, - q.write, - unsentMessage.Topic, - unsentMessage.Payload, - message_id_bytes, - flow_id_bytes, - ) - return err - }) + err := writeFn( + unsentMessage.Topic, + unsentMessage.Payload, + message_id_bytes, + flow_id_bytes, + ) if err != nil { return messageT{}, err } @@ -408,7 +508,12 @@ func publishStmt( return message, nil } - return fn, readStmt.Close, nil + closeFn := func() error { + writeFnClose() + return g.SomeError(privateDB.Close(), readStmt.Close()) + } + + return fn, closeFn, nil } func findSQL(prefix string) queryT { @@ -446,13 +551,11 @@ func findSQL(prefix string) queryT { } func findStmt( - db *sql.DB, - prefix string, - _ int, + cfg dbconfigT, ) (func(string, guuid.UUID) (messageT, error), func() error, error) { - q := findSQL(prefix) + q := findSQL(cfg.prefix) - readStmt, err := db.Prepare(q.read) + readStmt, err := cfg.shared.Prepare(q.read) if err != nil { return nil, nil, err } @@ -493,6 +596,13 @@ func findStmt( func nextSQL(prefix string) queryT { const tmpl_read = ` SELECT + ( + SELECT owner_id FROM "%s_owners" + WHERE + topic = ? AND + consumer = ? + LIMIT 1 + ) AS owner_id, "%s_messages".id, "%s_messages".timestamp, "%s_messages".uuid, @@ -510,12 +620,6 @@ func nextSQL(prefix string) queryT { ORDER BY "%s_messages".id ASC LIMIT 1; ` - const tmpl_owner = ` - SELECT owner_id FROM "%s_owners" - WHERE - topic = ? AND - consumer = ?; - ` return queryT{ read: fmt.Sprintf( tmpl_read, @@ -532,17 +636,20 @@ func nextSQL(prefix string) queryT { prefix, prefix, prefix, + prefix, ), - owner: fmt.Sprintf(tmpl_owner, prefix), } } func nextStmt( - db *sql.DB, - prefix string, - instanceID int, + cfg dbconfigT, ) (func(string, string) (messageT, error), func() error, error) { - q := nextSQL(prefix) + q := nextSQL(cfg.prefix) + + readStmt, err := cfg.shared.Prepare(q.read) + if err != nil { + return nil, nil, err + } fn := func(topic string, consumer string) (messageT, error) { message := messageT{ @@ -550,44 +657,34 @@ func nextStmt( } var ( - err error ownerID int timestr string message_id_bytes []byte flow_id_bytes []byte ) - tx, err := db.Begin() - if err != nil { - return messageT{}, err - } - defer tx.Rollback() - err = tx.QueryRow(q.owner, topic, consumer).Scan(&ownerID) + err = readStmt.QueryRow(topic, consumer, topic, consumer).Scan( + &ownerID, + &message.id, + ×tr, + &message_id_bytes, + &flow_id_bytes, + &message.payload, + ) if err != nil { return messageT{}, err } - if ownerID != instanceID { + if ownerID != cfg.instanceID { err := fmt.Errorf( notOwnerErrorFmt, ownerID, topic, consumer, - instanceID, + cfg.instanceID, ) return messageT{}, err } - - err = tx.QueryRow(q.read, topic, consumer).Scan( - &message.id, - ×tr, - &message_id_bytes, - &flow_id_bytes, - &message.payload, - ) - if err != nil { - return messageT{}, err - } message.uuid = guuid.UUID(message_id_bytes) message.flowID = guuid.UUID(flow_id_bytes) @@ -599,7 +696,7 @@ func nextStmt( return message, nil } - return fn, closeNoop, nil + return fn, readStmt.Close, nil } func messageEach(rows *sql.Rows, callback func(messageT) error) error { @@ -623,19 +720,22 @@ func messageEach(rows *sql.Rows, callback func(messageT) error) error { &message.payload, ) if err != nil { - return g.WrapErrors(rows.Close(), err) + rows.Close() + return err } message.uuid = guuid.UUID(message_id_bytes) message.flowID = guuid.UUID(flow_id_bytes) message.timestamp, err = time.Parse(time.RFC3339Nano, timestr) if err != nil { - return g.WrapErrors(rows.Close(), err) + rows.Close() + return err } err = callback(message) if err != nil { - return g.WrapErrors(rows.Close(), err) + rows.Close() + return err } } @@ -691,20 +791,19 @@ func pendingSQL(prefix string) queryT { } func pendingStmt( - db *sql.DB, - prefix string, - instanceID int, + cfg dbconfigT, ) (func(string, string) (*sql.Rows, error), func() error, error) { - q := pendingSQL(prefix) + q := pendingSQL(cfg.prefix) - readStmt, err := db.Prepare(q.read) + readStmt, err := cfg.shared.Prepare(q.read) if err != nil { return nil, nil, err } - ownerStmt, err := db.Prepare(q.owner) + ownerStmt, err := cfg.shared.Prepare(q.owner) if err != nil { - return nil, nil, g.WrapErrors(readStmt.Close(), err) + readStmt.Close() + return nil, nil, err } fn := func(topic string, consumer string) (*sql.Rows, error) { @@ -716,7 +815,7 @@ func pendingStmt( // best effort check, the final one is done during // commit within a transaction - if ownerID != instanceID { + if ownerID != cfg.instanceID { return nil, nil } @@ -732,130 +831,67 @@ func pendingStmt( func commitSQL(prefix string) queryT { const tmpl_write = ` - INSERT INTO "%s_offsets" (consumer, message_id) - VALUES (?, (SELECT id FROM "%s_messages" WHERE uuid = ?)); - ` - const tmpl_read = ` - SELECT "%s_payloads".topic from "%s_payloads" - JOIN "%s_messages" ON - "%s_payloads".id = "%s_messages".payload_id - WHERE "%s_messages".uuid = ?; - ` - const tmpl_owner = ` - SELECT owner_id FROM "%s_owners" - WHERE - topic = ? AND - consumer = ?; + INSERT INTO "%s_offsets" (consumer, message_id, instance_id) + VALUES (?, (SELECT id FROM "%s_messages" WHERE uuid = ?), ?); ` return queryT{ write: fmt.Sprintf(tmpl_write, prefix, prefix), - read: fmt.Sprintf( - tmpl_read, - prefix, - prefix, - prefix, - prefix, - prefix, - prefix, - ), - owner: fmt.Sprintf(tmpl_owner, prefix), } } func commitStmt( - db *sql.DB, - prefix string, - instanceID int, + cfg dbconfigT, ) (func(string, guuid.UUID) error, func() error, error) { - q := commitSQL(prefix) + q := commitSQL(cfg.prefix) + + writeStmt, err := cfg.shared.Prepare(q.write) + if err != nil { + return nil, nil, err + } fn := func(consumer string, messageID guuid.UUID) error { message_id_bytes := messageID[:] - return inTx(db, func(ctx context.Context) error { - var topic string - err := db.QueryRowContext( - ctx, - q.read, - message_id_bytes, - ).Scan(&topic) - if err != nil { - return err - } - - var ownerID int - err = db.QueryRowContext( - ctx, - q.owner, - topic, - consumer, - ).Scan(&ownerID) - if err != nil { - return err - } - - if ownerID != instanceID { - return fmt.Errorf( - noLongerOwnerErrorFmt, - instanceID, - topic, - consumer, - ownerID, - ) - } - - _, err = db.ExecContext(ctx, q.write, consumer, message_id_bytes) - return err - }) + _, err = writeStmt.Exec( + consumer, + message_id_bytes, + cfg.instanceID, + ) + return err } - return fn, closeNoop, nil + return fn, writeStmt.Close, nil } func toDeadSQL(prefix string) queryT { const tmpl_write = ` - INSERT INTO "%s_offsets" ( consumer, message_id) - VALUES ( ?, (SELECT id FROM "%s_messages" WHERE uuid = ?)); + INSERT INTO "%s_offsets" + ( consumer, message_id, instance_id) + VALUES ( ?, (SELECT id FROM "%s_messages" WHERE uuid = ?), ?); - INSERT INTO "%s_deadletters" (uuid, consumer, message_id) - VALUES (?, ?, (SELECT id FROM "%s_messages" WHERE uuid = ?)); - ` - const tmpl_read = ` - SELECT "%s_payloads".topic FROM "%s_payloads" - JOIN "%s_messages" ON - "%s_payloads".id = "%s_messages".payload_id - WHERE "%s_messages".uuid = ?; - ` - const tmpl_owner = ` - SELECT owner_id FROM "%s_owners" - WHERE - topic = ? AND - consumer = ?; + INSERT INTO "%s_deadletters" + (uuid, consumer, message_id, instance_id) + VALUES (?, ?, (SELECT id FROM "%s_messages" WHERE uuid = ?), ?); ` return queryT{ write: fmt.Sprintf(tmpl_write, prefix, prefix, prefix, prefix), - read: fmt.Sprintf( - tmpl_read, - prefix, - prefix, - prefix, - prefix, - prefix, - prefix, - ), - owner: fmt.Sprintf(tmpl_owner, prefix), } } func toDeadStmt( - db *sql.DB, - prefix string, - instanceID int, + cfg dbconfigT, ) ( func(string, guuid.UUID, guuid.UUID) error, func() error, error, ) { - q := toDeadSQL(prefix) + q := toDeadSQL(cfg.prefix) + + privateDB, err := sql.Open(golite.DriverName, cfg.dbpath) + if err != nil { + return nil, nil, err + } + + writeFn, writeFnClose := execSerialized(q.write, privateDB) fn := func( consumer string, @@ -864,52 +900,24 @@ func toDeadStmt( ) error { message_id_bytes := messageID[:] deadletter_id_bytes := deadletterID[:] - return inTx(db, func(ctx context.Context) error { - var topic string - err := db.QueryRowContext( - ctx, - q.read, - message_id_bytes, - ).Scan(&topic) - if err != nil { - return err - } - - var ownerID int - err = db.QueryRowContext( - ctx, - q.owner, - topic, - consumer, - ).Scan(&ownerID) - if err != nil { - return err - } - - if ownerID != instanceID { - return fmt.Errorf( - noLongerOwnerErrorFmt, - instanceID, - topic, - consumer, - ownerID, - ) - } + return writeFn( + consumer, + message_id_bytes, + cfg.instanceID, + deadletter_id_bytes, + consumer, + message_id_bytes, + cfg.instanceID, + ) + } - _, err = db.ExecContext( - ctx, - q.write, - consumer, - message_id_bytes, - deadletter_id_bytes, - consumer, - message_id_bytes, - ) - return err - }) + closeFn := func() error { + writeFnClose() + return privateDB.Close() } - return fn, closeNoop, nil + + return fn, closeFn, nil } func replaySQL(prefix string) queryT { @@ -973,33 +981,34 @@ func replaySQL(prefix string) queryT { } func replayStmt( - db *sql.DB, - prefix string, - _ int, + cfg dbconfigT, ) (func(guuid.UUID, guuid.UUID) (messageT, error), func() error, error) { - q := replaySQL(prefix) + q := replaySQL(cfg.prefix) + + readStmt, err := cfg.shared.Prepare(q.read) + if err != nil { + return nil, nil, err + } - readStmt, err := db.Prepare(q.read) + privateDB, err := sql.Open(golite.DriverName, cfg.dbpath) if err != nil { + readStmt.Close() return nil, nil, err } + writeFn, writeFnClose := execSerialized(q.write, privateDB) + fn := func( deadletterID guuid.UUID, messageID guuid.UUID, ) (messageT, error) { deadletter_id_bytes := deadletterID[:] message_id_bytes := messageID[:] - err := inTx(db, func(ctx context.Context) error { - _, err := db.ExecContext( - ctx, - q.write, - message_id_bytes, - deadletter_id_bytes, - deadletter_id_bytes, - ) - return err - }) + err := writeFn( + message_id_bytes, + deadletter_id_bytes, + deadletter_id_bytes, + ) if err != nil { return messageT{}, err } @@ -1032,7 +1041,12 @@ func replayStmt( return message, nil } - return fn, readStmt.Close, nil + closeFn := func() error { + writeFnClose() + return g.SomeError(privateDB.Close(), readStmt.Close()) + } + + return fn, closeFn, nil } func oneDeadSQL(prefix string) queryT { @@ -1085,13 +1099,11 @@ func oneDeadSQL(prefix string) queryT { } func oneDeadStmt( - db *sql.DB, - prefix string, - _ int, + cfg dbconfigT, ) (func(string, string) (deadletterT, error), func() error, error) { - q := oneDeadSQL(prefix) + q := oneDeadSQL(cfg.prefix) - readStmt, err := db.Prepare(q.read) + readStmt, err := cfg.shared.Prepare(q.read) if err != nil { return nil, nil, err } @@ -1157,7 +1169,8 @@ func deadletterEach( &message.payload, ) if err != nil { - return g.WrapErrors(rows.Close(), err) + rows.Close() + return err } deadletter.uuid = guuid.UUID(deadletter_id_bytes) @@ -1170,7 +1183,8 @@ func deadletterEach( messageTimestr, ) if err != nil { - return g.WrapErrors(rows.Close(), err) + rows.Close() + return err } deadletter.timestamp, err = time.Parse( @@ -1178,12 +1192,14 @@ func deadletterEach( deadletterTimestr, ) if err != nil { - return g.WrapErrors(rows.Close(), err) + rows.Close() + return err } err = callback(deadletter, message) if err != nil { - return g.WrapErrors(rows.Close(), err) + rows.Close() + return err } } @@ -1251,13 +1267,11 @@ func allDeadSQL(prefix string) queryT { } func allDeadStmt( - db *sql.DB, - prefix string, - _ int, + cfg dbconfigT, ) (func(string, string) (*sql.Rows, error), func() error, error) { - q := allDeadSQL(prefix) + q := allDeadSQL(cfg.prefix) - readStmt, err := db.Prepare(q.read) + readStmt, err := cfg.shared.Prepare(q.read) if err != nil { return nil, nil, err } @@ -1292,13 +1306,11 @@ func sizeSQL(prefix string) queryT { func sizeStmt( - db *sql.DB, - prefix string, - _ int, + cfg dbconfigT, ) (func(string) (int, error), func() error, error) { - q := sizeSQL(prefix) + q := sizeSQL(cfg.prefix) - readStmt, err := db.Prepare(q.read) + readStmt, err := cfg.shared.Prepare(q.read) if err != nil { return nil, nil, err } @@ -1346,13 +1358,11 @@ func countSQL(prefix string) queryT { } func countStmt( - db *sql.DB, - prefix string, - _ int, + cfg dbconfigT, ) (func(string, string) (int, error), func() error, error) { - q := countSQL(prefix) + q := countSQL(cfg.prefix) - readStmt, err := db.Prepare(q.read) + readStmt, err := cfg.shared.Prepare(q.read) if err != nil { return nil, nil, err } @@ -1399,13 +1409,11 @@ func hasDataSQL(prefix string) queryT { } func hasDataStmt( - db *sql.DB, - prefix string, - _ int, + cfg dbconfigT, ) (func(string, string) (bool, error), func() error, error) { - q := hasDataSQL(prefix) + q := hasDataSQL(cfg.prefix) - readStmt, err := db.Prepare(q.read) + readStmt, err := cfg.shared.Prepare(q.read) if err != nil { return nil, nil, err } @@ -1428,27 +1436,44 @@ func hasDataStmt( } func initDB( - db *sql.DB, + dbpath string, prefix string, notifyFn func(messageT), instanceID int, ) (queriesT, error) { - createTablesErr := createTables(db, prefix) - take, takeClose, takeErr := takeStmt(db, prefix, instanceID) - publish, publishClose, publishErr := publishStmt(db, prefix, instanceID) - find, findClose, findErr := findStmt(db, prefix, instanceID) - next, nextClose, nextErr := nextStmt(db, prefix, instanceID) - pending, pendingClose, pendingErr := pendingStmt(db, prefix, instanceID) - commit, commitClose, commitErr := commitStmt(db, prefix, instanceID) - toDead, toDeadClose, toDeadErr := toDeadStmt(db, prefix, instanceID) - replay, replayClose, replayErr := replayStmt(db, prefix, instanceID) - oneDead, oneDeadClose, oneDeadErr := oneDeadStmt(db, prefix, instanceID) - allDead, allDeadClose, allDeadErr := allDeadStmt(db, prefix, instanceID) - size, sizeClose, sizeErr := sizeStmt(db, prefix, instanceID) - count, countClose, countErr := countStmt(db, prefix, instanceID) - hasData, hasDataClose, hasDataErr := hasDataStmt(db, prefix, instanceID) - - err := g.SomeError( + err := g.ValidateSQLTablePrefix(prefix) + if err != nil { + return queriesT{}, err + } + + shared, err := sql.Open(golite.DriverName, dbpath) + if err != nil { + return queriesT{}, err + } + + cfg := dbconfigT{ + shared: shared, + dbpath: dbpath, + prefix: prefix, + instanceID: instanceID, + } + + createTablesErr := createTables(shared, prefix) + take, takeClose, takeErr := takeStmt(cfg) + publish, publishClose, publishErr := publishStmt(cfg) + find, findClose, findErr := findStmt(cfg) + next, nextClose, nextErr := nextStmt(cfg) + pending, pendingClose, pendingErr := pendingStmt(cfg) + commit, commitClose, commitErr := commitStmt(cfg) + toDead, toDeadClose, toDeadErr := toDeadStmt(cfg) + replay, replayClose, replayErr := replayStmt(cfg) + oneDead, oneDeadClose, oneDeadErr := oneDeadStmt(cfg) + allDead, allDeadClose, allDeadErr := allDeadStmt(cfg) + size, sizeClose, sizeErr := sizeStmt(cfg) + count, countClose, countErr := countStmt(cfg) + hasData, hasDataClose, hasDataErr := hasDataStmt(cfg) + + err = g.SomeError( createTablesErr, takeErr, publishErr, @@ -1468,7 +1493,7 @@ func initDB( return queriesT{}, err } - close := func() error { + closeFn := func() error { return g.SomeFnError( takeClose, publishClose, @@ -1483,6 +1508,7 @@ func initDB( sizeClose, countClose, hasDataClose, + shared.Close, ) } @@ -1614,7 +1640,7 @@ func initDB( close: func() error { connMutex.Lock() defer connMutex.Unlock() - return close() + return closeFn() }, }, nil } @@ -1826,20 +1852,10 @@ func runReaper( } func NewWithPrefix(databasePath string, prefix string) (IQueue, error) { - err := g.ValidateSQLTablePrefix(prefix) - if err != nil { - return queueT{}, err - } - - db, err := sql.Open(golite.DriverName, databasePath) - if err != nil { - return queueT{}, err - } - subscriptions := makeSubscriptionsFuncs() pinger := newPinger[struct{}]() notifyFn := makeNotifyFn(subscriptions.read, pinger) - queries, err := initDB(db, prefix, notifyFn, os.Getpid()) + queries, err := initDB(databasePath, prefix, notifyFn, os.Getpid()) if err != nil { return queueT{}, err } @@ -1847,7 +1863,6 @@ func NewWithPrefix(databasePath string, prefix string) (IQueue, error) { go runReaper(pinger.onPing, subscriptions.read, subscriptions.write) return queueT{ - db: db, queries: queries, subscriptions: subscriptions, pinger: pinger, @@ -2120,7 +2135,6 @@ func cleanSubscriptions(set subscriptionsSetM) error { func (queue queueT) Close() error { queue.pinger.close() return g.WrapErrors( - queue.db.Close(), queue.subscriptions.write(cleanSubscriptions), queue.queries.close(), ) diff --git a/tests/functional/consumer-with-deadletter/q.go b/tests/functional/consumer-with-deadletter/q.go index 25391c5..a79ad5b 100644 --- a/tests/functional/consumer-with-deadletter/q.go +++ b/tests/functional/consumer-with-deadletter/q.go @@ -2,7 +2,6 @@ package q import ( "errors" - "os" "runtime" "guuid" @@ -44,15 +43,10 @@ func MainTest() { g.TAssertEqualS(ok, true, "can't get filename") databasePath := file + ".db" - os.Remove(databasePath) - os.Remove(databasePath + "-shm") - os.Remove(databasePath + "-wal") - queue, err := New(databasePath) g.TErrorIf(err) defer queue.Close() - pub := func(payload []byte, flowID guuid.UUID) { unsent := UnsentMessage{ Topic: topicX, diff --git a/tests/functional/new-instance-takeover/q.go b/tests/functional/new-instance-takeover/q.go index 6e04e5f..76ed59f 100644 --- a/tests/functional/new-instance-takeover/q.go +++ b/tests/functional/new-instance-takeover/q.go @@ -1,7 +1,6 @@ package q import ( - "fmt" "runtime" "os" @@ -33,16 +32,16 @@ func handlerFn(publish func(guuid.UUID)) func(Message) error { } func startInstance( - databasePath string, + dbpath string, instanceID int, name string, ) (IQueue, error) { - iqueue, err := New(databasePath) + iqueue, err := New(dbpath) g.TErrorIf(err) queue := iqueue.(queueT) notifyFn := makeNotifyFn(queue.subscriptions.read, queue.pinger) - queries, err := initDB(queue.db, defaultPrefix, notifyFn, instanceID) + queries, err := initDB(dbpath, defaultPrefix, notifyFn, instanceID) g.TErrorIf(err) err = queue.queries.close() @@ -68,20 +67,12 @@ func startInstance( func MainTest() { - // https://sqlite.org/forum/forumpost/2507664507 g.Init() _, file, _, ok := runtime.Caller(0) g.TAssertEqualS(ok, true, "can't get filename") dbpath := file + ".db" - dbpath = "/mnt/dois/andreh/t.db" - os.Remove(dbpath) - os.Remove(dbpath + "-shm") - os.Remove(dbpath + "-wal") - - // FIXME - return instanceID1 := os.Getpid() instanceID2 := instanceID1 + 1 @@ -89,10 +80,6 @@ func MainTest() { flowID2 := guuid.New() g.Testing("new instances take ownership of topic+name combo", func() { - if false { - fmt.Fprintf(os.Stderr, "(PID %d + 1) ", instanceID1) - } - q1, err := startInstance(dbpath, instanceID1, "first") g.TErrorIf(err) defer q1.Close() @@ -103,7 +90,6 @@ func MainTest() { <- q1.WaitFor("individual-first", flowID1, "w").Channel <- q1.WaitFor( "shared-first", flowID1, "w").Channel - // println("waited 1") q2, err := startInstance(dbpath, instanceID2, "second") g.TErrorIf(err) diff --git a/tests/functional/wait-after-publish/q.go b/tests/functional/wait-after-publish/q.go index 15a532d..b70d27e 100644 --- a/tests/functional/wait-after-publish/q.go +++ b/tests/functional/wait-after-publish/q.go @@ -1,7 +1,6 @@ package q import ( - "os" "runtime" "guuid" @@ -19,10 +18,6 @@ func MainTest() { g.TAssertEqualS(ok, true, "can't get filename") databasePath := file + ".db" - os.Remove(databasePath) - os.Remove(databasePath + "-shm") - os.Remove(databasePath + "-wal") - queue, err := New(databasePath) g.TErrorIf(err) defer queue.Close() @@ -21,6 +21,10 @@ import ( +var instanceID = os.Getpid() + + + func test_defaultPrefix() { g.TestStart("defaultPrefix") @@ -29,13 +33,46 @@ func test_defaultPrefix() { }) } -func test_tryRollback() { +func test_serialized() { // FIXME } -func test_inTx() { - /* +func test_execSerialized() { // FIXME +} + +func test_tryRollback() { + g.TestStart("tryRollback()") + + myErr := errors.New("bottom error") + + db, err := sql.Open(golite.DriverName, golite.InMemory) + g.TErrorIf(err) + defer db.Close() + + + g.Testing("the error is propagated if rollback doesn't fail", func() { + tx, err := db.Begin() + g.TErrorIf(err) + + err = tryRollback(tx, myErr) + g.TAssertEqual(err, myErr) + }) + + g.Testing("a wrapped error when rollback fails", func() { + tx, err := db.Begin() + g.TErrorIf(err) + + err = tx.Commit() + g.TErrorIf(err) + + err = tryRollback(tx, myErr) + g.TAssertEqual(reflect.DeepEqual(err, myErr), false) + g.TAssertEqual(errors.Is(err, myErr), true) + }) +} + +func test_inTx() { g.TestStart("inTx()") db, err := sql.Open(golite.DriverName, golite.InMemory) @@ -44,11 +81,11 @@ func test_inTx() { g.Testing("when fn() errors, we propagate it", func() { - myError := errors.New("to be propagated") + myErr := errors.New("to be propagated") err := inTx(db, func(tx *sql.Tx) error { - return myError + return myErr }) - g.TAssertEqual(err, myError) + g.TAssertEqual(err, myErr) }) g.Testing("on nil error we get nil", func() { @@ -57,13 +94,17 @@ func test_inTx() { }) g.TErrorIf(err) }) - */ } func test_createTables() { g.TestStart("createTables()") - db, err := sql.Open(golite.DriverName, golite.InMemory) + const ( + dbpath = golite.InMemory + prefix = defaultPrefix + ) + + db, err := sql.Open(golite.DriverName, dbpath) g.TErrorIf(err) defer db.Close() @@ -72,12 +113,12 @@ func test_createTables() { const tmpl_read = ` SELECT id FROM "%s_messages" LIMIT 1; ` - qRead := fmt.Sprintf(tmpl_read, defaultPrefix) + qRead := fmt.Sprintf(tmpl_read, prefix) _, err := db.Exec(qRead) g.TErrorNil(err) - err = createTables(db, defaultPrefix) + err = createTables(db, prefix) g.TErrorIf(err) _, err = db.Exec(qRead) @@ -86,9 +127,9 @@ func test_createTables() { g.Testing("we can do it multiple times", func() { g.TErrorIf(g.SomeError( - createTables(db, defaultPrefix), - createTables(db, defaultPrefix), - createTables(db, defaultPrefix), + createTables(db, prefix), + createTables(db, prefix), + createTables(db, prefix), )) }) } @@ -99,15 +140,21 @@ func test_takeStmt() { const ( topic = "take() topic" consumer = "take() consumer" + dbpath = golite.InMemory prefix = defaultPrefix ) - db, err := sql.Open(golite.DriverName, golite.InMemory) + db, err := sql.Open(golite.DriverName, dbpath) g.TErrorIf(err) g.TErrorIf(createTables(db, prefix)) - instanceID := os.Getpid() - take, takeClose, takeErr := takeStmt(db, prefix, instanceID) + cfg := dbconfigT{ + shared: db, + dbpath: dbpath, + prefix: prefix, + instanceID: instanceID, + } + take, takeClose, takeErr := takeStmt(cfg) g.TErrorIf(takeErr) defer g.SomeFnError( takeClose, @@ -137,9 +184,10 @@ func test_takeStmt() { }) g.Testing("if there is already an owner, we overtake it", func() { - otherID := instanceID + 1 + otherCfg := cfg + otherCfg.instanceID = instanceID + 1 - take, takeClose, takeErr := takeStmt(db, prefix, otherID) + take, takeClose, takeErr := takeStmt(otherCfg) g.TErrorIf(takeErr) defer takeClose() @@ -153,7 +201,7 @@ func test_takeStmt() { err = db.QueryRow(sqlOwner, topic, consumer).Scan(&ownerID) g.TErrorIf(err) - g.TAssertEqual(ownerID, otherID) + g.TAssertEqual(ownerID, otherCfg.instanceID) }) g.Testing("no error if closed more than once", func() { g.TErrorIf(g.SomeError( @@ -170,6 +218,7 @@ func test_publishStmt() { const ( topic = "publish() topic" payloadStr = "publish() payload" + dbpath = golite.InMemory prefix = defaultPrefix ) var ( @@ -182,12 +231,17 @@ func test_publishStmt() { } ) - db, err := sql.Open(golite.DriverName, golite.InMemory) + db, err := sql.Open(golite.DriverName, dbpath) g.TErrorIf(err) g.TErrorIf(createTables(db, prefix)) - instanceID := os.Getpid() - publish, publishClose, publishErr := publishStmt(db, prefix, instanceID) + cfg := dbconfigT{ + shared: db, + dbpath: dbpath, + prefix: prefix, + instanceID: instanceID, + } + publish, publishClose, publishErr := publishStmt(cfg) g.TErrorIf(publishErr) defer g.SomeFnError( publishClose, @@ -255,6 +309,7 @@ func test_findStmt() { const ( topic = "find() topic" payloadStr = "find() payload" + dbpath = golite.InMemory prefix = defaultPrefix ) var ( @@ -267,13 +322,18 @@ func test_findStmt() { } ) - db, err := sql.Open(golite.DriverName, golite.InMemory) + db, err := sql.Open(golite.DriverName, dbpath) g.TErrorIf(err) g.TErrorIf(createTables(db, prefix)) - instanceID := os.Getpid() - publish, publishClose, publishErr := publishStmt(db, prefix, instanceID) - find, findClose, findErr := findStmt(db, prefix, instanceID) + cfg := dbconfigT{ + shared: db, + dbpath: dbpath, + prefix: prefix, + instanceID: instanceID, + } + publish, publishClose, publishErr := publishStmt(cfg) + find, findClose, findErr := findStmt(cfg) g.TErrorIf(g.SomeError( publishErr, findErr, @@ -354,6 +414,7 @@ func test_nextStmt() { topic = "next() topic" payloadStr = "next() payload" consumer = "next() consumer" + dbpath = golite.InMemory prefix = defaultPrefix ) var ( @@ -366,15 +427,24 @@ func test_nextStmt() { } ) - db, err := sql.Open(golite.DriverName, golite.InMemory) + db, err := sql.Open(golite.DriverName, dbpath) g.TErrorIf(err) g.TErrorIf(createTables(db, prefix)) - instanceID := os.Getpid() - take, takeClose, takeErr := takeStmt(db, prefix, instanceID) - publish, publishClose, publishErr := publishStmt(db, prefix, instanceID) - next, nextClose, nextErr := nextStmt(db, prefix, instanceID) - commit, commitClose, commitErr := commitStmt(db, prefix, instanceID) + cfg := dbconfigT{ + shared: db, + dbpath: dbpath, + prefix: prefix, + instanceID: instanceID, + } + take, takeClose, takeErr := takeStmt(cfg) + publish, publishClose, publishErr := publishStmt(cfg) + next, nextClose, nextErr := nextStmt(cfg) + commit, commitClose, commitErr := commitStmt(cfg) + g.TErrorIf(takeErr) + g.TErrorIf(publishErr) + g.TErrorIf(nextErr) + g.TErrorIf(commitErr) g.TErrorIf(g.SomeError( takeErr, publishErr, @@ -440,9 +510,10 @@ func test_nextStmt() { }) g.Testing("error when we're not the owner", func() { - otherID := instanceID + 1 + otherCfg := cfg + otherCfg.instanceID = instanceID + 1 - take, takeClose, takeErr := takeStmt(db, prefix, otherID) + take, takeClose, takeErr := takeStmt(otherCfg) g.TErrorIf(takeErr) defer takeClose() @@ -455,7 +526,7 @@ func test_nextStmt() { _, err = next(topic, consumer) g.TAssertEqual(err, fmt.Errorf( notOwnerErrorFmt, - otherID, + otherCfg.instanceID, topic, consumer, instanceID, @@ -478,6 +549,7 @@ func test_messageEach() { topic = "messageEach() topic" payloadStr = "messageEach() payload" consumer = "messageEach() consumer" + dbpath = golite.InMemory prefix = defaultPrefix ) var ( @@ -490,14 +562,19 @@ func test_messageEach() { } ) - db, err := sql.Open(golite.DriverName, golite.InMemory) + db, err := sql.Open(golite.DriverName, dbpath) g.TErrorIf(err) g.TErrorIf(createTables(db, prefix)) - instanceID := os.Getpid() - take, takeClose, takeErr := takeStmt(db, prefix, instanceID) - publish, publishClose, publishErr := publishStmt(db, prefix, instanceID) - pending, pendingClose, pendingErr := pendingStmt(db, prefix, instanceID) + cfg := dbconfigT{ + shared: db, + dbpath: dbpath, + prefix: prefix, + instanceID: instanceID, + } + take, takeClose, takeErr := takeStmt(cfg) + publish, publishClose, publishErr := publishStmt(cfg) + pending, pendingClose, pendingErr := pendingStmt(cfg) g.TErrorIf(g.SomeError( takeErr, publishErr, @@ -640,6 +717,7 @@ func test_pendingStmt() { topic = "pending() topic" payloadStr = "pending() payload" consumer = "pending() consumer" + dbpath = golite.InMemory prefix = defaultPrefix ) var ( @@ -652,16 +730,21 @@ func test_pendingStmt() { } ) - db, err := sql.Open(golite.DriverName, golite.InMemory) + db, err := sql.Open(golite.DriverName, dbpath) g.TErrorIf(err) g.TErrorIf(createTables(db, prefix)) - instanceID := os.Getpid() - take, takeClose, takeErr := takeStmt(db, prefix, instanceID) - publish, publishClose, publishErr := publishStmt(db, prefix, instanceID) - pending, pendingClose, pendingErr := pendingStmt(db, prefix, instanceID) - commit, commitClose, commitErr := commitStmt(db, prefix, instanceID) - toDead, toDeadClose, toDeadErr := toDeadStmt(db, prefix, instanceID) + cfg := dbconfigT{ + shared: db, + dbpath: dbpath, + prefix: prefix, + instanceID: instanceID, + } + take, takeClose, takeErr := takeStmt(cfg) + publish, publishClose, publishErr := publishStmt(cfg) + pending, pendingClose, pendingErr := pendingStmt(cfg) + commit, commitClose, commitErr := commitStmt(cfg) + toDead, toDeadClose, toDeadErr := toDeadStmt(cfg) g.TErrorIf(g.SomeError( takeErr, publishErr, @@ -820,9 +903,10 @@ func test_pendingStmt() { }) g.Testing("when we're not the owners we get nothing", func() { - otherID := instanceID + 1 + otherCfg := cfg + otherCfg.instanceID = instanceID + 1 - take, takeClose, takeErr := takeStmt(db, prefix, otherID) + take, takeClose, takeErr := takeStmt(otherCfg) g.TErrorIf(takeErr) defer takeClose() @@ -871,6 +955,7 @@ func test_commitStmt() { topic = "commit() topic" payloadStr = "commit() payload" consumer = "commit() consumer" + dbpath = golite.InMemory prefix = defaultPrefix ) var ( @@ -883,15 +968,20 @@ func test_commitStmt() { } ) - db, err := sql.Open(golite.DriverName, golite.InMemory) + db, err := sql.Open(golite.DriverName, dbpath) g.TErrorIf(err) g.TErrorIf(createTables(db, prefix)) - instanceID := os.Getpid() - take, takeClose, takeErr := takeStmt(db, prefix, instanceID) - publish, publishClose, publishErr := publishStmt(db, prefix, instanceID) - commit, commitClose, commitErr := commitStmt(db, prefix, instanceID) - toDead, toDeadClose, toDeadErr := toDeadStmt(db, prefix, instanceID) + cfg := dbconfigT{ + shared: db, + dbpath: dbpath, + prefix: prefix, + instanceID: instanceID, + } + take, takeClose, takeErr := takeStmt(cfg) + publish, publishClose, publishErr := publishStmt(cfg) + commit, commitClose, commitErr := commitStmt(cfg) + toDead, toDeadClose, toDeadErr := toDeadStmt(cfg) g.TErrorIf(g.SomeError( takeErr, publishErr, @@ -937,7 +1027,10 @@ func test_commitStmt() { g.Testing("we can't commit non-existent messages", func() { err := cmt(consumer, guuid.New()) - g.TAssertEqual(err, sql.ErrNoRows) + g.TAssertEqual( + err.(golite.Error).ExtendedCode, + golite.ErrConstraintNotNull, + ) }) g.Testing("multiple consumers may commit a message", func() { @@ -987,8 +1080,10 @@ func test_commitStmt() { }) g.Testing("error if we don't own the topic/consumer", func() { - otherID := instanceID + 1 - take, takeClose, takeErr := takeStmt(db, prefix, otherID) + otherCfg := cfg + otherCfg.instanceID = instanceID + 1 + + take, takeClose, takeErr := takeStmt(otherCfg) g.TErrorIf(takeErr) defer takeClose() @@ -998,13 +1093,10 @@ func test_commitStmt() { g.TErrorIf(err) err = commit(consumer, messageID) - g.TAssertEqual(err, fmt.Errorf( - noLongerOwnerErrorFmt, - instanceID, - topic, - consumer, - otherID, - )) + g.TAssertEqual( + err.(golite.Error).ExtendedCode, + golite.ErrConstraintTrigger, + ) }) g.Testing("no actual closing occurs", func() { @@ -1023,6 +1115,7 @@ func test_toDeadStmt() { topic = "toDead() topic" payloadStr = "toDead() payload" consumer = "toDead() consumer" + dbpath = golite.InMemory prefix = defaultPrefix ) var ( @@ -1035,15 +1128,20 @@ func test_toDeadStmt() { } ) - db, err := sql.Open(golite.DriverName, golite.InMemory) + db, err := sql.Open(golite.DriverName, dbpath) g.TErrorIf(err) g.TErrorIf(createTables(db, prefix)) - instanceID := os.Getpid() - take, takeClose, takeErr := takeStmt(db, prefix, instanceID) - publish, publishClose, publishErr := publishStmt(db, prefix, instanceID) - commit, commitClose, commitErr := commitStmt(db, prefix, instanceID) - toDead, toDeadClose, toDeadErr := toDeadStmt(db, prefix, instanceID) + cfg := dbconfigT{ + shared: db, + dbpath: dbpath, + prefix: prefix, + instanceID: instanceID, + } + take, takeClose, takeErr := takeStmt(cfg) + publish, publishClose, publishErr := publishStmt(cfg) + commit, commitClose, commitErr := commitStmt(cfg) + toDead, toDeadClose, toDeadErr := toDeadStmt(cfg) g.TErrorIf(g.SomeError( takeErr, publishErr, @@ -1107,7 +1205,10 @@ func test_toDeadStmt() { g.Testing("we can't mark as dead non-existent messages", func() { err := asDead(consumer, guuid.New(), guuid.New()) - g.TAssertEqual(err, sql.ErrNoRows) + g.TAssertEqual( + err.(golite.Error).ExtendedCode, + golite.ErrConstraintNotNull, + ) }) g.Testing("multiple consumers may mark a message as dead", func() { @@ -1173,11 +1274,13 @@ func test_toDeadStmt() { }) g.Testing("error if we don't own the message's consumer/topic", func() { - otherID := instanceID + 1 + otherCfg := cfg + otherCfg.instanceID = instanceID + 1 + messageID1 := pub(topic) messageID2 := pub(topic) - take, takeClose, takeErr := takeStmt(db, prefix, otherID) + take, takeClose, takeErr := takeStmt(otherCfg) g.TErrorIf(takeErr) defer takeClose() @@ -1188,13 +1291,10 @@ func test_toDeadStmt() { g.TErrorIf(err) err = toDead(consumer, messageID2, guuid.New()) - g.TAssertEqual(err, fmt.Errorf( - noLongerOwnerErrorFmt, - instanceID, - topic, - consumer, - otherID, - )) + g.TAssertEqual( + err.(golite.Error).ExtendedCode, + golite.ErrConstraintTrigger, + ) }) g.Testing("no actual closing occurs", func() { @@ -1213,6 +1313,7 @@ func test_replayStmt() { topic = "replay() topic" payloadStr = "replay() payload" consumer = "replay() consumer" + dbpath = golite.InMemory prefix = defaultPrefix ) var ( @@ -1225,15 +1326,20 @@ func test_replayStmt() { } ) - db, err := sql.Open(golite.DriverName, golite.InMemory) + db, err := sql.Open(golite.DriverName, dbpath) g.TErrorIf(err) g.TErrorIf(createTables(db, prefix)) - instanceID := os.Getpid() - take, takeClose, takeErr := takeStmt(db, prefix, instanceID) - publish, publishClose, publishErr := publishStmt(db, prefix, instanceID) - toDead, toDeadClose, toDeadErr := toDeadStmt(db, prefix, instanceID) - replay, replayClose, replayErr := replayStmt(db, prefix, instanceID) + cfg := dbconfigT{ + shared: db, + dbpath: dbpath, + prefix: prefix, + instanceID: instanceID, + } + take, takeClose, takeErr := takeStmt(cfg) + publish, publishClose, publishErr := publishStmt(cfg) + toDead, toDeadClose, toDeadErr := toDeadStmt(cfg) + replay, replayClose, replayErr := replayStmt(cfg) g.TErrorIf(g.SomeError( takeErr, publishErr, @@ -1343,6 +1449,7 @@ func test_oneDeadStmt() { topic = "oneDead() topic" payloadStr = "oneDead() payload" consumer = "oneDead() consumer" + dbpath = golite.InMemory prefix = defaultPrefix ) var ( @@ -1355,16 +1462,21 @@ func test_oneDeadStmt() { } ) - db, err := sql.Open(golite.DriverName, golite.InMemory) + db, err := sql.Open(golite.DriverName, dbpath) g.TErrorIf(err) g.TErrorIf(createTables(db, prefix)) - instanceID := os.Getpid() - take, takeClose, takeErr := takeStmt(db, prefix, instanceID) - publish, publishClose, publishErr := publishStmt(db, prefix, instanceID) - toDead, toDeadClose, toDeadErr := toDeadStmt(db, prefix, instanceID) - replay, replayClose, replayErr := replayStmt(db, prefix, instanceID) - oneDead, oneDeadClose, oneDeadErr := oneDeadStmt(db, prefix, instanceID) + cfg := dbconfigT{ + shared: db, + dbpath: dbpath, + prefix: prefix, + instanceID: instanceID, + } + take, takeClose, takeErr := takeStmt(cfg) + publish, publishClose, publishErr := publishStmt(cfg) + toDead, toDeadClose, toDeadErr := toDeadStmt(cfg) + replay, replayClose, replayErr := replayStmt(cfg) + oneDead, oneDeadClose, oneDeadErr := oneDeadStmt(cfg) g.TErrorIf(g.SomeError( takeErr, publishErr, @@ -1457,6 +1569,7 @@ func test_deadletterEach() { topic = "deadletterEach() topic" payloadStr = "deadletterEach() payload" consumer = "deadletterEach() consumer" + dbpath = golite.InMemory prefix = defaultPrefix ) var ( @@ -1469,15 +1582,20 @@ func test_deadletterEach() { } ) - db, err := sql.Open(golite.DriverName, golite.InMemory) + db, err := sql.Open(golite.DriverName, dbpath) g.TErrorIf(err) g.TErrorIf(createTables(db, prefix)) - instanceID := os.Getpid() - take, takeClose, takeErr := takeStmt(db, prefix, instanceID) - publish, publishClose, publishErr := publishStmt(db, prefix, instanceID) - toDead, toDeadClose, toDeadErr := toDeadStmt(db, prefix, instanceID) - allDead, allDeadClose, allDeadErr := allDeadStmt(db, prefix, instanceID) + cfg := dbconfigT{ + shared: db, + dbpath: dbpath, + prefix: prefix, + instanceID: instanceID, + } + take, takeClose, takeErr := takeStmt(cfg) + publish, publishClose, publishErr := publishStmt(cfg) + toDead, toDeadClose, toDeadErr := toDeadStmt(cfg) + allDead, allDeadClose, allDeadErr := allDeadStmt(cfg) g.TErrorIf(g.SomeError( takeErr, publishErr, @@ -1631,6 +1749,7 @@ func test_allDeadStmt() { topic = "allDead() topic" payloadStr = "allDead() payload" consumer = "allDead() consumer" + dbpath = golite.InMemory prefix = defaultPrefix ) var ( @@ -1643,16 +1762,21 @@ func test_allDeadStmt() { } ) - db, err := sql.Open(golite.DriverName, golite.InMemory) + db, err := sql.Open(golite.DriverName, dbpath) g.TErrorIf(err) g.TErrorIf(createTables(db, prefix)) - instanceID := os.Getpid() - take, takeClose, takeErr := takeStmt(db, prefix, instanceID) - publish, publishClose, publishErr := publishStmt(db, prefix, instanceID) - toDead, toDeadClose, toDeadErr := toDeadStmt(db, prefix, instanceID) - replay, replayClose, replayErr := replayStmt(db, prefix, instanceID) - allDead, allDeadClose, allDeadErr := allDeadStmt(db, prefix, instanceID) + cfg := dbconfigT{ + shared: db, + dbpath: dbpath, + prefix: prefix, + instanceID: instanceID, + } + take, takeClose, takeErr := takeStmt(cfg) + publish, publishClose, publishErr := publishStmt(cfg) + toDead, toDeadClose, toDeadErr := toDeadStmt(cfg) + replay, replayClose, replayErr := replayStmt(cfg) + allDead, allDeadClose, allDeadErr := allDeadStmt(cfg) g.TErrorIf(g.SomeError( takeErr, publishErr, @@ -1784,6 +1908,7 @@ func test_sizeStmt() { topic = "size() topic" payloadStr = "size() payload" consumer = "size() consumer" + dbpath = golite.InMemory prefix = defaultPrefix ) var ( @@ -1796,17 +1921,22 @@ func test_sizeStmt() { } ) - db, err := sql.Open(golite.DriverName, golite.InMemory) + db, err := sql.Open(golite.DriverName, dbpath) g.TErrorIf(err) g.TErrorIf(createTables(db, prefix)) - instanceID := os.Getpid() - take, takeClose, takeErr := takeStmt(db, prefix, instanceID) - publish, publishClose, publishErr := publishStmt(db, prefix, instanceID) - toDead, toDeadClose, toDeadErr := toDeadStmt(db, prefix, instanceID) - replay, replayClose, replayErr := replayStmt(db, prefix, instanceID) - oneDead, oneDeadClose, oneDeadErr := oneDeadStmt(db, prefix, instanceID) - size, sizeClose, sizeErr := sizeStmt(db, prefix, instanceID) + cfg := dbconfigT{ + shared: db, + dbpath: dbpath, + prefix: prefix, + instanceID: instanceID, + } + take, takeClose, takeErr := takeStmt(cfg) + publish, publishClose, publishErr := publishStmt(cfg) + toDead, toDeadClose, toDeadErr := toDeadStmt(cfg) + replay, replayClose, replayErr := replayStmt(cfg) + oneDead, oneDeadClose, oneDeadErr := oneDeadStmt(cfg) + size, sizeClose, sizeErr := sizeStmt(cfg) g.TErrorIf(g.SomeError( takeErr, publishErr, @@ -1892,6 +2022,7 @@ func test_countStmt() { topic = "count() topic" payloadStr = "count() payload" consumer = "count() consumer" + dbpath = golite.InMemory prefix = defaultPrefix ) var ( @@ -1904,17 +2035,22 @@ func test_countStmt() { } ) - db, err := sql.Open(golite.DriverName, golite.InMemory) + db, err := sql.Open(golite.DriverName, dbpath) g.TErrorIf(err) g.TErrorIf(createTables(db, prefix)) - instanceID := os.Getpid() - take, takeClose, takeErr := takeStmt(db, prefix, instanceID) - publish, publishClose, publishErr := publishStmt(db, prefix, instanceID) - next, nextClose, nextErr := nextStmt(db, prefix, instanceID) - commit, commitClose, commitErr := commitStmt(db, prefix, instanceID) - toDead, toDeadClose, toDeadErr := toDeadStmt(db, prefix, instanceID) - count, countClose, countErr := countStmt(db, prefix, instanceID) + cfg := dbconfigT{ + shared: db, + dbpath: dbpath, + prefix: prefix, + instanceID: instanceID, + } + take, takeClose, takeErr := takeStmt(cfg) + publish, publishClose, publishErr := publishStmt(cfg) + next, nextClose, nextErr := nextStmt(cfg) + commit, commitClose, commitErr := commitStmt(cfg) + toDead, toDeadClose, toDeadErr := toDeadStmt(cfg) + count, countClose, countErr := countStmt(cfg) g.TErrorIf(g.SomeError( takeErr, publishErr, @@ -2020,6 +2156,7 @@ func test_hasDataStmt() { topic = "hasData() topic" payloadStr = "hasData() payload" consumer = "hasData() consumer" + dbpath = golite.InMemory prefix = defaultPrefix ) var ( @@ -2032,17 +2169,22 @@ func test_hasDataStmt() { } ) - db, err := sql.Open(golite.DriverName, golite.InMemory) + db, err := sql.Open(golite.DriverName, dbpath) g.TErrorIf(err) g.TErrorIf(createTables(db, prefix)) - instanceID := os.Getpid() - take, takeClose, takeErr := takeStmt(db, prefix, instanceID) - publish, publishClose, publishErr := publishStmt(db, prefix, instanceID) - next, nextClose, nextErr := nextStmt(db, prefix, instanceID) - commit, commitClose, commitErr := commitStmt(db, prefix, instanceID) - toDead, toDeadClose, toDeadErr := toDeadStmt(db, prefix, instanceID) - hasData, hasDataClose, hasDataErr := hasDataStmt(db, prefix, instanceID) + cfg := dbconfigT{ + shared: db, + dbpath: dbpath, + prefix: prefix, + instanceID: instanceID, + } + take, takeClose, takeErr := takeStmt(cfg) + publish, publishClose, publishErr := publishStmt(cfg) + next, nextClose, nextErr := nextStmt(cfg) + commit, commitClose, commitErr := commitStmt(cfg) + toDead, toDeadClose, toDeadErr := toDeadStmt(cfg) + hasData, hasDataClose, hasDataErr := hasDataStmt(cfg) g.TErrorIf(g.SomeError( takeErr, publishErr, @@ -2148,6 +2290,8 @@ func test_initDB() { topic = "initDB() topic" payloadStr = "initDB() payload" consumer = "initDB() consumer" + dbpath = golite.InMemory + prefix = defaultPrefix ) var ( flowID = guuid.New() @@ -2159,17 +2303,12 @@ func test_initDB() { } ) - db, err := sql.Open(golite.DriverName, golite.InMemory) - g.TErrorIf(err) - defer db.Close() - var messages []messageT notifyFn := func(message messageT) { messages = append(messages, message) } - instanceID := os.Getpid() - queries, err := initDB(db, defaultPrefix, notifyFn, instanceID) + queries, err := initDB(dbpath, prefix, notifyFn, instanceID) g.TErrorIf(err) defer queries.close() @@ -2256,31 +2395,16 @@ func test_initDB() { func test_queriesTclose() { g.TestStart("queriesT.close()") - db, err := sql.Open(golite.DriverName, golite.InMemory) - g.TErrorIf(err) - defer db.Close() + const ( + dbpath = golite.InMemory + prefix = defaultPrefix + ) - instanceID := os.Getpid() - queries, err := initDB(db, defaultPrefix, func(messageT) {}, instanceID) + notifyFn := func(messageT) {} + queries, err := initDB(dbpath, prefix, notifyFn, instanceID) g.TErrorIf(err) - g.Testing("after closing, we can't run queries", func() { - unsent := UnsentMessage{ Payload: []byte{}, } - _, err := queries.publish(unsent, guuid.New()) - g.TErrorIf(err) - g.TErrorIf(db.Close()) - - err = queries.close() - g.TErrorIf(err) - - _, err = queries.publish(unsent, guuid.New()) - g.TAssertEqual( - err.Error(), - "sql: database is closed", - ) - }) - g.Testing("closing mre than once does not error", func() { g.TErrorIf(g.SomeError( queries.close(), @@ -2289,7 +2413,6 @@ func test_queriesTclose() { }) } - func test_newPinger() { g.TestStart("newPinger()") @@ -3287,6 +3410,7 @@ func test_queueT_Publish() { const ( topic = "queueT.Publish() topic" payloadStr = "queueT.Publish() payload" + dbpath = golite.InMemory ) var ( flowID = guuid.New() @@ -3298,7 +3422,7 @@ func test_queueT_Publish() { } ) - queue, err := New(golite.InMemory) + queue, err := New(dbpath) g.TErrorIf(err) defer queue.Close() @@ -4350,11 +4474,7 @@ func test_queueT_Close() { queriesErr = errors.New("queriesT{} error") ) - db, err := sql.Open(golite.DriverName, golite.InMemory) - g.TErrorIf(err) - queue := queueT{ - db: db, queries: queriesT{ close: func() error{ queriesCount++ @@ -4376,7 +4496,7 @@ func test_queueT_Close() { }, } - err = queue.Close() + err := queue.Close() g.TAssertEqual(err, g.WrapErrors(subscriptionsErr, queriesErr)) g.TAssertEqual(pingerCount, 1) g.TAssertEqual(subscriptionsCount, 1) @@ -5674,6 +5794,7 @@ func dumpQueries() { { "take", takeSQL }, { "publish", publishSQL }, { "find", findSQL }, + { "next", nextSQL }, { "pending", pendingSQL }, { "commit", commitSQL }, { "toDead", toDeadSQL }, @@ -5703,6 +5824,8 @@ func MainTest() { g.Init() test_defaultPrefix() + test_serialized() + test_execSerialized() test_tryRollback() test_inTx() test_createTables() diff --git a/tests/queries.sql b/tests/queries.sql index c821e25..e790d41 100644 --- a/tests/queries.sql +++ b/tests/queries.sql @@ -27,6 +27,7 @@ consumer TEXT NOT NULL, message_id INTEGER NOT NULL REFERENCES "q_messages"(id), + instance_id INTEGER NOT NULL, UNIQUE (consumer, message_id) ) STRICT; CREATE INDEX IF NOT EXISTS "q_offsets_consumer" @@ -38,6 +39,7 @@ consumer TEXT NOT NULL, message_id INTEGER NOT NULL REFERENCES "q_messages"(id), + instance_id INTEGER NOT NULL, UNIQUE (consumer, message_id) ) STRICT; CREATE INDEX IF NOT EXISTS "q_deadletters_consumer" @@ -58,6 +60,44 @@ owner_id INTEGER NOT NULL, UNIQUE (topic, consumer) ) STRICT; + + CREATE TRIGGER IF NOT EXISTS "q_check_instance_owns_topic" + BEFORE INSERT ON "q_offsets" + WHEN NEW.instance_id != ( + SELECT owner_id FROM "q_owners" + WHERE topic = ( + SELECT "q_payloads".topic + FROM "q_payloads" + JOIN "q_messages" ON "q_payloads".id = + "q_messages".payload_id + WHERE "q_messages".id = NEW.message_id + ) AND consumer = NEW.consumer + ) + BEGIN + SELECT RAISE( + ABORT, + 'instance does not own topic/consumer combo' + ); + END; + + CREATE TRIGGER IF NOT EXISTS "q_check_can_publish_deadletter" + BEFORE INSERT ON "q_deadletters" + WHEN NEW.instance_id != ( + SELECT owner_id FROM "q_owners" + WHERE topic = ( + SELECT "q_payloads".topic + FROM "q_payloads" + JOIN "q_messages" ON "q_payloads".id = + "q_messages".payload_id + WHERE "q_messages".id = NEW.message_id + ) AND consumer = NEW.consumer + ) + BEGIN + SELECT RAISE( + ABORT, + 'Instance does not own topic/consumer combo' + ); + END; -- read: @@ -81,7 +121,6 @@ INSERT INTO "q_payloads" (topic, payload) VALUES (?, ?); - -- FIXME: must be inside a trnsaction INSERT INTO "q_messages" (uuid, flow_id, payload_id) VALUES (?, ?, last_insert_rowid()); @@ -114,6 +153,38 @@ -- owner: +-- next.sql: +-- write: + +-- read: + SELECT + ( + SELECT owner_id FROM "q_owners" + WHERE + topic = ? AND + consumer = ? + LIMIT 1 + ) AS owner_id, + "q_messages".id, + "q_messages".timestamp, + "q_messages".uuid, + "q_messages".flow_id, + "q_payloads".payload + FROM "q_messages" + JOIN "q_payloads" ON + "q_payloads".id = "q_messages".payload_id + WHERE + "q_payloads".topic = ? AND + "q_messages".id NOT IN ( + SELECT message_id FROM "q_offsets" + WHERE consumer = ? + ) + ORDER BY "q_messages".id ASC + LIMIT 1; + + +-- owner: + -- pending.sql: -- write: @@ -146,46 +217,28 @@ -- commit.sql: -- write: - INSERT INTO "q_offsets" (consumer, message_id) - VALUES (?, (SELECT id FROM "q_messages" WHERE uuid = ?)); + INSERT INTO "q_offsets" (consumer, message_id, instance_id) + VALUES (?, (SELECT id FROM "q_messages" WHERE uuid = ?), ?); -- read: - SELECT "q_payloads".topic from "q_payloads" - JOIN "q_messages" ON - "q_payloads".id = "q_messages".payload_id - WHERE "q_messages".uuid = ?; - -- owner: - SELECT owner_id FROM "q_owners" - WHERE - topic = ? AND - consumer = ?; - -- toDead.sql: -- write: - INSERT INTO "q_offsets" ( consumer, message_id) - VALUES ( ?, (SELECT id FROM "q_messages" WHERE uuid = ?)); + INSERT INTO "q_offsets" + ( consumer, message_id, instance_id) + VALUES ( ?, (SELECT id FROM "q_messages" WHERE uuid = ?), ?); - INSERT INTO "q_deadletters" (uuid, consumer, message_id) - VALUES (?, ?, (SELECT id FROM "q_messages" WHERE uuid = ?)); + INSERT INTO "q_deadletters" + (uuid, consumer, message_id, instance_id) + VALUES (?, ?, (SELECT id FROM "q_messages" WHERE uuid = ?), ?); -- read: - SELECT "q_payloads".topic FROM "q_payloads" - JOIN "q_messages" ON - "q_payloads".id = "q_messages".payload_id - WHERE "q_messages".uuid = ?; - -- owner: - SELECT owner_id FROM "q_owners" - WHERE - topic = ? AND - consumer = ?; - -- replay.sql: -- write: |