diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/q.go | 652 |
1 files changed, 333 insertions, 319 deletions
@@ -1,6 +1,6 @@ package q + import ( - "context" "database/sql" "flag" "fmt" @@ -18,15 +18,21 @@ import ( const ( - defaultPrefix = "q" - reaperSkipCount = 1000 - notOwnerErrorFmt = "%v owns %#v as %#v, not us (%v)" - noLongerOwnerErrorFmt = "we (%v) no longer own %#v as %#v, but %v does" - rollbackErrorFmt = "rollback error: %w; while executing: %w" + defaultPrefix = "q" + reaperSkipCount = 1000 + notOwnerErrorFmt = "%v owns %#v as %#v, not us (%v)" + rollbackErrorFmt = "rollback error: %w; while executing: %w" ) +type dbconfigT struct{ + shared *sql.DB + dbpath string + prefix string + instanceID int +} + type queryT struct{ write string read string @@ -126,7 +132,6 @@ type subscriptionsT struct { } type queueT struct{ - db *sql.DB queries queriesT subscriptions subscriptionsT pinger pingerT[struct{}] @@ -158,12 +163,55 @@ type IQueue interface{ -func closeNoop() error { - return nil +func serialized[A any, B any](callback func(...A) B) (func(...A) B, func()) { + in := make(chan []A) + out := make(chan B) + + closed := false + var ( + closeWg sync.WaitGroup + closeMutex sync.Mutex + ) + closeWg.Add(1) + + go func() { + for input := range in { + out <- callback(input...) + } + close(out) + closeWg.Done() + }() + + fn := func(input ...A) B { + in <- input + return (<- out) + } + + closeFn := func() { + closeMutex.Lock() + defer closeMutex.Unlock() + if closed { + return + } + close(in) + closed = true + closeWg.Wait() + } + + return fn, closeFn +} + +func execSerialized(query string, db *sql.DB) (func(...any) error, func()) { + return serialized(func(args ...any) error { + return inTx(db, func(tx *sql.Tx) error { + _, err := tx.Exec(query, args...) + return err + }) + }) } -func tryRollback(db *sql.DB, ctx context.Context, err error) error { - _, rollbackErr := db.ExecContext(ctx, "ROLLBACK;") +func tryRollback(tx *sql.Tx, err error) error { + rollbackErr := tx.Rollback() if rollbackErr != nil { return fmt.Errorf( rollbackErrorFmt, @@ -175,25 +223,20 @@ func tryRollback(db *sql.DB, ctx context.Context, err error) error { return err } -// FIXME -// See: -// https://sqlite.org/forum/forumpost/2507664507 -func inTx(db *sql.DB, fn func(context.Context) error) error { - ctx := context.Background() - - _, err := db.ExecContext(ctx, "BEGIN IMMEDIATE;") +func inTx(db *sql.DB, fn func(*sql.Tx) error) error { + tx, err := db.Begin() if err != nil { return err } - err = fn(ctx) + err = fn(tx) if err != nil { - return tryRollback(db, ctx, err) + return tryRollback(tx, err) } - _, err = db.ExecContext(ctx, "COMMIT;") + err = tx.Commit() if err != nil { - return tryRollback(db, ctx, err) + return tryRollback(tx, err) } return nil @@ -227,6 +270,7 @@ func createTablesSQL(prefix string) queryT { consumer TEXT NOT NULL, message_id INTEGER NOT NULL REFERENCES "%s_messages"(id), + instance_id INTEGER NOT NULL, UNIQUE (consumer, message_id) ) STRICT; CREATE INDEX IF NOT EXISTS "%s_offsets_consumer" @@ -238,6 +282,7 @@ func createTablesSQL(prefix string) queryT { consumer TEXT NOT NULL, message_id INTEGER NOT NULL REFERENCES "%s_messages"(id), + instance_id INTEGER NOT NULL, UNIQUE (consumer, message_id) ) STRICT; CREATE INDEX IF NOT EXISTS "%s_deadletters_consumer" @@ -258,6 +303,44 @@ func createTablesSQL(prefix string) queryT { owner_id INTEGER NOT NULL, UNIQUE (topic, consumer) ) STRICT; + + CREATE TRIGGER IF NOT EXISTS "%s_check_instance_owns_topic" + BEFORE INSERT ON "%s_offsets" + WHEN NEW.instance_id != ( + SELECT owner_id FROM "%s_owners" + WHERE topic = ( + SELECT "%s_payloads".topic + FROM "%s_payloads" + JOIN "%s_messages" ON "%s_payloads".id = + "%s_messages".payload_id + WHERE "%s_messages".id = NEW.message_id + ) AND consumer = NEW.consumer + ) + BEGIN + SELECT RAISE( + ABORT, + 'instance does not own topic/consumer combo' + ); + END; + + CREATE TRIGGER IF NOT EXISTS "%s_check_can_publish_deadletter" + BEFORE INSERT ON "%s_deadletters" + WHEN NEW.instance_id != ( + SELECT owner_id FROM "%s_owners" + WHERE topic = ( + SELECT "%s_payloads".topic + FROM "%s_payloads" + JOIN "%s_messages" ON "%s_payloads".id = + "%s_messages".payload_id + WHERE "%s_messages".id = NEW.message_id + ) AND consumer = NEW.consumer + ) + BEGIN + SELECT RAISE( + ABORT, + 'Instance does not own topic/consumer combo' + ); + END; ` return queryT{ write: fmt.Sprintf( @@ -284,6 +367,24 @@ func createTablesSQL(prefix string) queryT { prefix, prefix, prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, ), } } @@ -291,8 +392,8 @@ func createTablesSQL(prefix string) queryT { func createTables(db *sql.DB, prefix string) error { q := createTablesSQL(prefix) - return inTx(db, func(ctx context.Context) error { - _, err := db.ExecContext(ctx, q.write) + return inTx(db, func(tx *sql.Tx) error { + _, err := tx.Exec(q.write) return err }) } @@ -310,26 +411,25 @@ func takeSQL(prefix string) queryT { } func takeStmt( - db *sql.DB, - prefix string, - instanceID int, + cfg dbconfigT, ) (func(string, string) error, func() error, error) { - q := takeSQL(prefix) + q := takeSQL(cfg.prefix) + + writeStmt, err := cfg.shared.Prepare(q.write) + if err != nil { + return nil, nil, err + } fn := func(topic string, consumer string) error { - return inTx(db, func(ctx context.Context) error { - _, err := db.ExecContext( - ctx, - q.write, - topic, - consumer, - instanceID, - ) - return err - }) + _, err := writeStmt.Exec( + topic, + consumer, + cfg.instanceID, + ) + return err } - return fn, closeNoop, nil + return fn, writeStmt.Close, nil } func publishSQL(prefix string) queryT { @@ -337,7 +437,6 @@ func publishSQL(prefix string) queryT { INSERT INTO "%s_payloads" (topic, payload) VALUES (?, ?); - -- FIXME: must be inside a trnsaction INSERT INTO "%s_messages" (uuid, flow_id, payload_id) VALUES (?, ?, last_insert_rowid()); ` @@ -352,17 +451,23 @@ func publishSQL(prefix string) queryT { } func publishStmt( - db *sql.DB, - prefix string, - _ int, + cfg dbconfigT, ) (func(UnsentMessage, guuid.UUID) (messageT, error), func() error, error) { - q := publishSQL(prefix) + q := publishSQL(cfg.prefix) + + readStmt, err := cfg.shared.Prepare(q.read) + if err != nil { + return nil, nil, err + } - readStmt, err := db.Prepare(q.read) + privateDB, err := sql.Open(golite.DriverName, cfg.dbpath) if err != nil { + readStmt.Close() return nil, nil, err } + writeFn, writeFnClose := execSerialized(q.write, privateDB) + fn := func( unsentMessage UnsentMessage, messageID guuid.UUID, @@ -376,17 +481,12 @@ func publishStmt( message_id_bytes := messageID[:] flow_id_bytes := unsentMessage.FlowID[:] - err := inTx(db, func(ctx context.Context) error { - _, err := db.ExecContext( - ctx, - q.write, - unsentMessage.Topic, - unsentMessage.Payload, - message_id_bytes, - flow_id_bytes, - ) - return err - }) + err := writeFn( + unsentMessage.Topic, + unsentMessage.Payload, + message_id_bytes, + flow_id_bytes, + ) if err != nil { return messageT{}, err } @@ -408,7 +508,12 @@ func publishStmt( return message, nil } - return fn, readStmt.Close, nil + closeFn := func() error { + writeFnClose() + return g.SomeError(privateDB.Close(), readStmt.Close()) + } + + return fn, closeFn, nil } func findSQL(prefix string) queryT { @@ -446,13 +551,11 @@ func findSQL(prefix string) queryT { } func findStmt( - db *sql.DB, - prefix string, - _ int, + cfg dbconfigT, ) (func(string, guuid.UUID) (messageT, error), func() error, error) { - q := findSQL(prefix) + q := findSQL(cfg.prefix) - readStmt, err := db.Prepare(q.read) + readStmt, err := cfg.shared.Prepare(q.read) if err != nil { return nil, nil, err } @@ -493,6 +596,13 @@ func findStmt( func nextSQL(prefix string) queryT { const tmpl_read = ` SELECT + ( + SELECT owner_id FROM "%s_owners" + WHERE + topic = ? AND + consumer = ? + LIMIT 1 + ) AS owner_id, "%s_messages".id, "%s_messages".timestamp, "%s_messages".uuid, @@ -510,12 +620,6 @@ func nextSQL(prefix string) queryT { ORDER BY "%s_messages".id ASC LIMIT 1; ` - const tmpl_owner = ` - SELECT owner_id FROM "%s_owners" - WHERE - topic = ? AND - consumer = ?; - ` return queryT{ read: fmt.Sprintf( tmpl_read, @@ -532,17 +636,20 @@ func nextSQL(prefix string) queryT { prefix, prefix, prefix, + prefix, ), - owner: fmt.Sprintf(tmpl_owner, prefix), } } func nextStmt( - db *sql.DB, - prefix string, - instanceID int, + cfg dbconfigT, ) (func(string, string) (messageT, error), func() error, error) { - q := nextSQL(prefix) + q := nextSQL(cfg.prefix) + + readStmt, err := cfg.shared.Prepare(q.read) + if err != nil { + return nil, nil, err + } fn := func(topic string, consumer string) (messageT, error) { message := messageT{ @@ -550,44 +657,34 @@ func nextStmt( } var ( - err error ownerID int timestr string message_id_bytes []byte flow_id_bytes []byte ) - tx, err := db.Begin() - if err != nil { - return messageT{}, err - } - defer tx.Rollback() - err = tx.QueryRow(q.owner, topic, consumer).Scan(&ownerID) + err = readStmt.QueryRow(topic, consumer, topic, consumer).Scan( + &ownerID, + &message.id, + ×tr, + &message_id_bytes, + &flow_id_bytes, + &message.payload, + ) if err != nil { return messageT{}, err } - if ownerID != instanceID { + if ownerID != cfg.instanceID { err := fmt.Errorf( notOwnerErrorFmt, ownerID, topic, consumer, - instanceID, + cfg.instanceID, ) return messageT{}, err } - - err = tx.QueryRow(q.read, topic, consumer).Scan( - &message.id, - ×tr, - &message_id_bytes, - &flow_id_bytes, - &message.payload, - ) - if err != nil { - return messageT{}, err - } message.uuid = guuid.UUID(message_id_bytes) message.flowID = guuid.UUID(flow_id_bytes) @@ -599,7 +696,7 @@ func nextStmt( return message, nil } - return fn, closeNoop, nil + return fn, readStmt.Close, nil } func messageEach(rows *sql.Rows, callback func(messageT) error) error { @@ -623,19 +720,22 @@ func messageEach(rows *sql.Rows, callback func(messageT) error) error { &message.payload, ) if err != nil { - return g.WrapErrors(rows.Close(), err) + rows.Close() + return err } message.uuid = guuid.UUID(message_id_bytes) message.flowID = guuid.UUID(flow_id_bytes) message.timestamp, err = time.Parse(time.RFC3339Nano, timestr) if err != nil { - return g.WrapErrors(rows.Close(), err) + rows.Close() + return err } err = callback(message) if err != nil { - return g.WrapErrors(rows.Close(), err) + rows.Close() + return err } } @@ -691,20 +791,19 @@ func pendingSQL(prefix string) queryT { } func pendingStmt( - db *sql.DB, - prefix string, - instanceID int, + cfg dbconfigT, ) (func(string, string) (*sql.Rows, error), func() error, error) { - q := pendingSQL(prefix) + q := pendingSQL(cfg.prefix) - readStmt, err := db.Prepare(q.read) + readStmt, err := cfg.shared.Prepare(q.read) if err != nil { return nil, nil, err } - ownerStmt, err := db.Prepare(q.owner) + ownerStmt, err := cfg.shared.Prepare(q.owner) if err != nil { - return nil, nil, g.WrapErrors(readStmt.Close(), err) + readStmt.Close() + return nil, nil, err } fn := func(topic string, consumer string) (*sql.Rows, error) { @@ -716,7 +815,7 @@ func pendingStmt( // best effort check, the final one is done during // commit within a transaction - if ownerID != instanceID { + if ownerID != cfg.instanceID { return nil, nil } @@ -732,130 +831,67 @@ func pendingStmt( func commitSQL(prefix string) queryT { const tmpl_write = ` - INSERT INTO "%s_offsets" (consumer, message_id) - VALUES (?, (SELECT id FROM "%s_messages" WHERE uuid = ?)); - ` - const tmpl_read = ` - SELECT "%s_payloads".topic from "%s_payloads" - JOIN "%s_messages" ON - "%s_payloads".id = "%s_messages".payload_id - WHERE "%s_messages".uuid = ?; - ` - const tmpl_owner = ` - SELECT owner_id FROM "%s_owners" - WHERE - topic = ? AND - consumer = ?; + INSERT INTO "%s_offsets" (consumer, message_id, instance_id) + VALUES (?, (SELECT id FROM "%s_messages" WHERE uuid = ?), ?); ` return queryT{ write: fmt.Sprintf(tmpl_write, prefix, prefix), - read: fmt.Sprintf( - tmpl_read, - prefix, - prefix, - prefix, - prefix, - prefix, - prefix, - ), - owner: fmt.Sprintf(tmpl_owner, prefix), } } func commitStmt( - db *sql.DB, - prefix string, - instanceID int, + cfg dbconfigT, ) (func(string, guuid.UUID) error, func() error, error) { - q := commitSQL(prefix) + q := commitSQL(cfg.prefix) + + writeStmt, err := cfg.shared.Prepare(q.write) + if err != nil { + return nil, nil, err + } fn := func(consumer string, messageID guuid.UUID) error { message_id_bytes := messageID[:] - return inTx(db, func(ctx context.Context) error { - var topic string - err := db.QueryRowContext( - ctx, - q.read, - message_id_bytes, - ).Scan(&topic) - if err != nil { - return err - } - - var ownerID int - err = db.QueryRowContext( - ctx, - q.owner, - topic, - consumer, - ).Scan(&ownerID) - if err != nil { - return err - } - - if ownerID != instanceID { - return fmt.Errorf( - noLongerOwnerErrorFmt, - instanceID, - topic, - consumer, - ownerID, - ) - } - - _, err = db.ExecContext(ctx, q.write, consumer, message_id_bytes) - return err - }) + _, err = writeStmt.Exec( + consumer, + message_id_bytes, + cfg.instanceID, + ) + return err } - return fn, closeNoop, nil + return fn, writeStmt.Close, nil } func toDeadSQL(prefix string) queryT { const tmpl_write = ` - INSERT INTO "%s_offsets" ( consumer, message_id) - VALUES ( ?, (SELECT id FROM "%s_messages" WHERE uuid = ?)); + INSERT INTO "%s_offsets" + ( consumer, message_id, instance_id) + VALUES ( ?, (SELECT id FROM "%s_messages" WHERE uuid = ?), ?); - INSERT INTO "%s_deadletters" (uuid, consumer, message_id) - VALUES (?, ?, (SELECT id FROM "%s_messages" WHERE uuid = ?)); - ` - const tmpl_read = ` - SELECT "%s_payloads".topic FROM "%s_payloads" - JOIN "%s_messages" ON - "%s_payloads".id = "%s_messages".payload_id - WHERE "%s_messages".uuid = ?; - ` - const tmpl_owner = ` - SELECT owner_id FROM "%s_owners" - WHERE - topic = ? AND - consumer = ?; + INSERT INTO "%s_deadletters" + (uuid, consumer, message_id, instance_id) + VALUES (?, ?, (SELECT id FROM "%s_messages" WHERE uuid = ?), ?); ` return queryT{ write: fmt.Sprintf(tmpl_write, prefix, prefix, prefix, prefix), - read: fmt.Sprintf( - tmpl_read, - prefix, - prefix, - prefix, - prefix, - prefix, - prefix, - ), - owner: fmt.Sprintf(tmpl_owner, prefix), } } func toDeadStmt( - db *sql.DB, - prefix string, - instanceID int, + cfg dbconfigT, ) ( func(string, guuid.UUID, guuid.UUID) error, func() error, error, ) { - q := toDeadSQL(prefix) + q := toDeadSQL(cfg.prefix) + + privateDB, err := sql.Open(golite.DriverName, cfg.dbpath) + if err != nil { + return nil, nil, err + } + + writeFn, writeFnClose := execSerialized(q.write, privateDB) fn := func( consumer string, @@ -864,52 +900,24 @@ func toDeadStmt( ) error { message_id_bytes := messageID[:] deadletter_id_bytes := deadletterID[:] - return inTx(db, func(ctx context.Context) error { - var topic string - err := db.QueryRowContext( - ctx, - q.read, - message_id_bytes, - ).Scan(&topic) - if err != nil { - return err - } - - var ownerID int - err = db.QueryRowContext( - ctx, - q.owner, - topic, - consumer, - ).Scan(&ownerID) - if err != nil { - return err - } - - if ownerID != instanceID { - return fmt.Errorf( - noLongerOwnerErrorFmt, - instanceID, - topic, - consumer, - ownerID, - ) - } + return writeFn( + consumer, + message_id_bytes, + cfg.instanceID, + deadletter_id_bytes, + consumer, + message_id_bytes, + cfg.instanceID, + ) + } - _, err = db.ExecContext( - ctx, - q.write, - consumer, - message_id_bytes, - deadletter_id_bytes, - consumer, - message_id_bytes, - ) - return err - }) + closeFn := func() error { + writeFnClose() + return privateDB.Close() } - return fn, closeNoop, nil + + return fn, closeFn, nil } func replaySQL(prefix string) queryT { @@ -973,33 +981,34 @@ func replaySQL(prefix string) queryT { } func replayStmt( - db *sql.DB, - prefix string, - _ int, + cfg dbconfigT, ) (func(guuid.UUID, guuid.UUID) (messageT, error), func() error, error) { - q := replaySQL(prefix) + q := replaySQL(cfg.prefix) + + readStmt, err := cfg.shared.Prepare(q.read) + if err != nil { + return nil, nil, err + } - readStmt, err := db.Prepare(q.read) + privateDB, err := sql.Open(golite.DriverName, cfg.dbpath) if err != nil { + readStmt.Close() return nil, nil, err } + writeFn, writeFnClose := execSerialized(q.write, privateDB) + fn := func( deadletterID guuid.UUID, messageID guuid.UUID, ) (messageT, error) { deadletter_id_bytes := deadletterID[:] message_id_bytes := messageID[:] - err := inTx(db, func(ctx context.Context) error { - _, err := db.ExecContext( - ctx, - q.write, - message_id_bytes, - deadletter_id_bytes, - deadletter_id_bytes, - ) - return err - }) + err := writeFn( + message_id_bytes, + deadletter_id_bytes, + deadletter_id_bytes, + ) if err != nil { return messageT{}, err } @@ -1032,7 +1041,12 @@ func replayStmt( return message, nil } - return fn, readStmt.Close, nil + closeFn := func() error { + writeFnClose() + return g.SomeError(privateDB.Close(), readStmt.Close()) + } + + return fn, closeFn, nil } func oneDeadSQL(prefix string) queryT { @@ -1085,13 +1099,11 @@ func oneDeadSQL(prefix string) queryT { } func oneDeadStmt( - db *sql.DB, - prefix string, - _ int, + cfg dbconfigT, ) (func(string, string) (deadletterT, error), func() error, error) { - q := oneDeadSQL(prefix) + q := oneDeadSQL(cfg.prefix) - readStmt, err := db.Prepare(q.read) + readStmt, err := cfg.shared.Prepare(q.read) if err != nil { return nil, nil, err } @@ -1157,7 +1169,8 @@ func deadletterEach( &message.payload, ) if err != nil { - return g.WrapErrors(rows.Close(), err) + rows.Close() + return err } deadletter.uuid = guuid.UUID(deadletter_id_bytes) @@ -1170,7 +1183,8 @@ func deadletterEach( messageTimestr, ) if err != nil { - return g.WrapErrors(rows.Close(), err) + rows.Close() + return err } deadletter.timestamp, err = time.Parse( @@ -1178,12 +1192,14 @@ func deadletterEach( deadletterTimestr, ) if err != nil { - return g.WrapErrors(rows.Close(), err) + rows.Close() + return err } err = callback(deadletter, message) if err != nil { - return g.WrapErrors(rows.Close(), err) + rows.Close() + return err } } @@ -1251,13 +1267,11 @@ func allDeadSQL(prefix string) queryT { } func allDeadStmt( - db *sql.DB, - prefix string, - _ int, + cfg dbconfigT, ) (func(string, string) (*sql.Rows, error), func() error, error) { - q := allDeadSQL(prefix) + q := allDeadSQL(cfg.prefix) - readStmt, err := db.Prepare(q.read) + readStmt, err := cfg.shared.Prepare(q.read) if err != nil { return nil, nil, err } @@ -1292,13 +1306,11 @@ func sizeSQL(prefix string) queryT { func sizeStmt( - db *sql.DB, - prefix string, - _ int, + cfg dbconfigT, ) (func(string) (int, error), func() error, error) { - q := sizeSQL(prefix) + q := sizeSQL(cfg.prefix) - readStmt, err := db.Prepare(q.read) + readStmt, err := cfg.shared.Prepare(q.read) if err != nil { return nil, nil, err } @@ -1346,13 +1358,11 @@ func countSQL(prefix string) queryT { } func countStmt( - db *sql.DB, - prefix string, - _ int, + cfg dbconfigT, ) (func(string, string) (int, error), func() error, error) { - q := countSQL(prefix) + q := countSQL(cfg.prefix) - readStmt, err := db.Prepare(q.read) + readStmt, err := cfg.shared.Prepare(q.read) if err != nil { return nil, nil, err } @@ -1399,13 +1409,11 @@ func hasDataSQL(prefix string) queryT { } func hasDataStmt( - db *sql.DB, - prefix string, - _ int, + cfg dbconfigT, ) (func(string, string) (bool, error), func() error, error) { - q := hasDataSQL(prefix) + q := hasDataSQL(cfg.prefix) - readStmt, err := db.Prepare(q.read) + readStmt, err := cfg.shared.Prepare(q.read) if err != nil { return nil, nil, err } @@ -1428,27 +1436,44 @@ func hasDataStmt( } func initDB( - db *sql.DB, + dbpath string, prefix string, notifyFn func(messageT), instanceID int, ) (queriesT, error) { - createTablesErr := createTables(db, prefix) - take, takeClose, takeErr := takeStmt(db, prefix, instanceID) - publish, publishClose, publishErr := publishStmt(db, prefix, instanceID) - find, findClose, findErr := findStmt(db, prefix, instanceID) - next, nextClose, nextErr := nextStmt(db, prefix, instanceID) - pending, pendingClose, pendingErr := pendingStmt(db, prefix, instanceID) - commit, commitClose, commitErr := commitStmt(db, prefix, instanceID) - toDead, toDeadClose, toDeadErr := toDeadStmt(db, prefix, instanceID) - replay, replayClose, replayErr := replayStmt(db, prefix, instanceID) - oneDead, oneDeadClose, oneDeadErr := oneDeadStmt(db, prefix, instanceID) - allDead, allDeadClose, allDeadErr := allDeadStmt(db, prefix, instanceID) - size, sizeClose, sizeErr := sizeStmt(db, prefix, instanceID) - count, countClose, countErr := countStmt(db, prefix, instanceID) - hasData, hasDataClose, hasDataErr := hasDataStmt(db, prefix, instanceID) - - err := g.SomeError( + err := g.ValidateSQLTablePrefix(prefix) + if err != nil { + return queriesT{}, err + } + + shared, err := sql.Open(golite.DriverName, dbpath) + if err != nil { + return queriesT{}, err + } + + cfg := dbconfigT{ + shared: shared, + dbpath: dbpath, + prefix: prefix, + instanceID: instanceID, + } + + createTablesErr := createTables(shared, prefix) + take, takeClose, takeErr := takeStmt(cfg) + publish, publishClose, publishErr := publishStmt(cfg) + find, findClose, findErr := findStmt(cfg) + next, nextClose, nextErr := nextStmt(cfg) + pending, pendingClose, pendingErr := pendingStmt(cfg) + commit, commitClose, commitErr := commitStmt(cfg) + toDead, toDeadClose, toDeadErr := toDeadStmt(cfg) + replay, replayClose, replayErr := replayStmt(cfg) + oneDead, oneDeadClose, oneDeadErr := oneDeadStmt(cfg) + allDead, allDeadClose, allDeadErr := allDeadStmt(cfg) + size, sizeClose, sizeErr := sizeStmt(cfg) + count, countClose, countErr := countStmt(cfg) + hasData, hasDataClose, hasDataErr := hasDataStmt(cfg) + + err = g.SomeError( createTablesErr, takeErr, publishErr, @@ -1468,7 +1493,7 @@ func initDB( return queriesT{}, err } - close := func() error { + closeFn := func() error { return g.SomeFnError( takeClose, publishClose, @@ -1483,6 +1508,7 @@ func initDB( sizeClose, countClose, hasDataClose, + shared.Close, ) } @@ -1614,7 +1640,7 @@ func initDB( close: func() error { connMutex.Lock() defer connMutex.Unlock() - return close() + return closeFn() }, }, nil } @@ -1826,20 +1852,10 @@ func runReaper( } func NewWithPrefix(databasePath string, prefix string) (IQueue, error) { - err := g.ValidateSQLTablePrefix(prefix) - if err != nil { - return queueT{}, err - } - - db, err := sql.Open(golite.DriverName, databasePath) - if err != nil { - return queueT{}, err - } - subscriptions := makeSubscriptionsFuncs() pinger := newPinger[struct{}]() notifyFn := makeNotifyFn(subscriptions.read, pinger) - queries, err := initDB(db, prefix, notifyFn, os.Getpid()) + queries, err := initDB(databasePath, prefix, notifyFn, os.Getpid()) if err != nil { return queueT{}, err } @@ -1847,7 +1863,6 @@ func NewWithPrefix(databasePath string, prefix string) (IQueue, error) { go runReaper(pinger.onPing, subscriptions.read, subscriptions.write) return queueT{ - db: db, queries: queries, subscriptions: subscriptions, pinger: pinger, @@ -2120,7 +2135,6 @@ func cleanSubscriptions(set subscriptionsSetM) error { func (queue queueT) Close() error { queue.pinger.close() return g.WrapErrors( - queue.db.Close(), queue.subscriptions.write(cleanSubscriptions), queue.queries.close(), ) |