diff options
author | EuAndreh <eu@euandre.org> | 2024-10-30 20:26:41 -0300 |
---|---|---|
committer | EuAndreh <eu@euandre.org> | 2024-10-31 19:57:42 -0300 |
commit | 5d53704f450788452837b18b26f54d235d883490 (patch) | |
tree | 91bb61d6f823523675ae9f93bb8561858bf7be14 /src | |
parent | tests/q.go: Replace ":memory:" with golite.InMemory (diff) | |
download | fiinha-5d53704f450788452837b18b26f54d235d883490.tar.gz fiinha-5d53704f450788452837b18b26f54d235d883490.tar.xz |
src/q.go: Fix SQLite ~broken~ transactions
As per expected by SQLite, create new connections so they have
independent transaction, and serialize its use via a channel with a
single consumer.
== Other changes
=== Remove `db` attribute from `queueT` type
Now that `initDB()` can create many database connections (A.K.A. opaque
pointers to the `sqlite3*` object), it makes more sense to hand to it
the responsability to create and destroy these databases. So now a
`queries.close()` includes what previously was `db.Close()`, and needing
to do it was the only reason that made us keep the `db` attribute.
The arguments to `initDB()` were adjusted to reflect that, as it no
longer is given an initialized database handle, but only the database
path, and does the creation of the handles by itself.
=== Shoehorn multi-statement queries into single-statement ones
In order to avoid creating one `execSerialized()` function per attribute
of the `queryT` queries, we try to leverage SQLite's built-in
transactionality per individual statement. So compound queries that were
previously done in multiple statements wrapped with a `inTx()` that
could be shoe-horned into a single `SELECT` were rewritten to avoid
having more of this application-level serialization.
The topic/consumer owner validation was also changed: now it is a
trigger on the relevant tables that `ABORT` the implicit/invisible
transaction and makes the `INSERT` fail.
In order to make this possible, the `"%s_owners"` table now has an extra
column: `instance_id`. Despite being extra data, this meta-information
isn't duplicated from anywhere else, and it is an actual useful
information for operators to leverage.
So now the trigger is responsible for stopping the transaction from
going forward without adding an explicit `inTx()` around it, and the
writes can mostly be shrunk to single-statement queries.
Fortunately, all this is enough to fix the `new-instance-takeover`
functional test.
Diffstat (limited to 'src')
-rw-r--r-- | src/q.go | 652 |
1 files changed, 333 insertions, 319 deletions
@@ -1,6 +1,6 @@ package q + import ( - "context" "database/sql" "flag" "fmt" @@ -18,15 +18,21 @@ import ( const ( - defaultPrefix = "q" - reaperSkipCount = 1000 - notOwnerErrorFmt = "%v owns %#v as %#v, not us (%v)" - noLongerOwnerErrorFmt = "we (%v) no longer own %#v as %#v, but %v does" - rollbackErrorFmt = "rollback error: %w; while executing: %w" + defaultPrefix = "q" + reaperSkipCount = 1000 + notOwnerErrorFmt = "%v owns %#v as %#v, not us (%v)" + rollbackErrorFmt = "rollback error: %w; while executing: %w" ) +type dbconfigT struct{ + shared *sql.DB + dbpath string + prefix string + instanceID int +} + type queryT struct{ write string read string @@ -126,7 +132,6 @@ type subscriptionsT struct { } type queueT struct{ - db *sql.DB queries queriesT subscriptions subscriptionsT pinger pingerT[struct{}] @@ -158,12 +163,55 @@ type IQueue interface{ -func closeNoop() error { - return nil +func serialized[A any, B any](callback func(...A) B) (func(...A) B, func()) { + in := make(chan []A) + out := make(chan B) + + closed := false + var ( + closeWg sync.WaitGroup + closeMutex sync.Mutex + ) + closeWg.Add(1) + + go func() { + for input := range in { + out <- callback(input...) + } + close(out) + closeWg.Done() + }() + + fn := func(input ...A) B { + in <- input + return (<- out) + } + + closeFn := func() { + closeMutex.Lock() + defer closeMutex.Unlock() + if closed { + return + } + close(in) + closed = true + closeWg.Wait() + } + + return fn, closeFn +} + +func execSerialized(query string, db *sql.DB) (func(...any) error, func()) { + return serialized(func(args ...any) error { + return inTx(db, func(tx *sql.Tx) error { + _, err := tx.Exec(query, args...) + return err + }) + }) } -func tryRollback(db *sql.DB, ctx context.Context, err error) error { - _, rollbackErr := db.ExecContext(ctx, "ROLLBACK;") +func tryRollback(tx *sql.Tx, err error) error { + rollbackErr := tx.Rollback() if rollbackErr != nil { return fmt.Errorf( rollbackErrorFmt, @@ -175,25 +223,20 @@ func tryRollback(db *sql.DB, ctx context.Context, err error) error { return err } -// FIXME -// See: -// https://sqlite.org/forum/forumpost/2507664507 -func inTx(db *sql.DB, fn func(context.Context) error) error { - ctx := context.Background() - - _, err := db.ExecContext(ctx, "BEGIN IMMEDIATE;") +func inTx(db *sql.DB, fn func(*sql.Tx) error) error { + tx, err := db.Begin() if err != nil { return err } - err = fn(ctx) + err = fn(tx) if err != nil { - return tryRollback(db, ctx, err) + return tryRollback(tx, err) } - _, err = db.ExecContext(ctx, "COMMIT;") + err = tx.Commit() if err != nil { - return tryRollback(db, ctx, err) + return tryRollback(tx, err) } return nil @@ -227,6 +270,7 @@ func createTablesSQL(prefix string) queryT { consumer TEXT NOT NULL, message_id INTEGER NOT NULL REFERENCES "%s_messages"(id), + instance_id INTEGER NOT NULL, UNIQUE (consumer, message_id) ) STRICT; CREATE INDEX IF NOT EXISTS "%s_offsets_consumer" @@ -238,6 +282,7 @@ func createTablesSQL(prefix string) queryT { consumer TEXT NOT NULL, message_id INTEGER NOT NULL REFERENCES "%s_messages"(id), + instance_id INTEGER NOT NULL, UNIQUE (consumer, message_id) ) STRICT; CREATE INDEX IF NOT EXISTS "%s_deadletters_consumer" @@ -258,6 +303,44 @@ func createTablesSQL(prefix string) queryT { owner_id INTEGER NOT NULL, UNIQUE (topic, consumer) ) STRICT; + + CREATE TRIGGER IF NOT EXISTS "%s_check_instance_owns_topic" + BEFORE INSERT ON "%s_offsets" + WHEN NEW.instance_id != ( + SELECT owner_id FROM "%s_owners" + WHERE topic = ( + SELECT "%s_payloads".topic + FROM "%s_payloads" + JOIN "%s_messages" ON "%s_payloads".id = + "%s_messages".payload_id + WHERE "%s_messages".id = NEW.message_id + ) AND consumer = NEW.consumer + ) + BEGIN + SELECT RAISE( + ABORT, + 'instance does not own topic/consumer combo' + ); + END; + + CREATE TRIGGER IF NOT EXISTS "%s_check_can_publish_deadletter" + BEFORE INSERT ON "%s_deadletters" + WHEN NEW.instance_id != ( + SELECT owner_id FROM "%s_owners" + WHERE topic = ( + SELECT "%s_payloads".topic + FROM "%s_payloads" + JOIN "%s_messages" ON "%s_payloads".id = + "%s_messages".payload_id + WHERE "%s_messages".id = NEW.message_id + ) AND consumer = NEW.consumer + ) + BEGIN + SELECT RAISE( + ABORT, + 'Instance does not own topic/consumer combo' + ); + END; ` return queryT{ write: fmt.Sprintf( @@ -284,6 +367,24 @@ func createTablesSQL(prefix string) queryT { prefix, prefix, prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, ), } } @@ -291,8 +392,8 @@ func createTablesSQL(prefix string) queryT { func createTables(db *sql.DB, prefix string) error { q := createTablesSQL(prefix) - return inTx(db, func(ctx context.Context) error { - _, err := db.ExecContext(ctx, q.write) + return inTx(db, func(tx *sql.Tx) error { + _, err := tx.Exec(q.write) return err }) } @@ -310,26 +411,25 @@ func takeSQL(prefix string) queryT { } func takeStmt( - db *sql.DB, - prefix string, - instanceID int, + cfg dbconfigT, ) (func(string, string) error, func() error, error) { - q := takeSQL(prefix) + q := takeSQL(cfg.prefix) + + writeStmt, err := cfg.shared.Prepare(q.write) + if err != nil { + return nil, nil, err + } fn := func(topic string, consumer string) error { - return inTx(db, func(ctx context.Context) error { - _, err := db.ExecContext( - ctx, - q.write, - topic, - consumer, - instanceID, - ) - return err - }) + _, err := writeStmt.Exec( + topic, + consumer, + cfg.instanceID, + ) + return err } - return fn, closeNoop, nil + return fn, writeStmt.Close, nil } func publishSQL(prefix string) queryT { @@ -337,7 +437,6 @@ func publishSQL(prefix string) queryT { INSERT INTO "%s_payloads" (topic, payload) VALUES (?, ?); - -- FIXME: must be inside a trnsaction INSERT INTO "%s_messages" (uuid, flow_id, payload_id) VALUES (?, ?, last_insert_rowid()); ` @@ -352,17 +451,23 @@ func publishSQL(prefix string) queryT { } func publishStmt( - db *sql.DB, - prefix string, - _ int, + cfg dbconfigT, ) (func(UnsentMessage, guuid.UUID) (messageT, error), func() error, error) { - q := publishSQL(prefix) + q := publishSQL(cfg.prefix) + + readStmt, err := cfg.shared.Prepare(q.read) + if err != nil { + return nil, nil, err + } - readStmt, err := db.Prepare(q.read) + privateDB, err := sql.Open(golite.DriverName, cfg.dbpath) if err != nil { + readStmt.Close() return nil, nil, err } + writeFn, writeFnClose := execSerialized(q.write, privateDB) + fn := func( unsentMessage UnsentMessage, messageID guuid.UUID, @@ -376,17 +481,12 @@ func publishStmt( message_id_bytes := messageID[:] flow_id_bytes := unsentMessage.FlowID[:] - err := inTx(db, func(ctx context.Context) error { - _, err := db.ExecContext( - ctx, - q.write, - unsentMessage.Topic, - unsentMessage.Payload, - message_id_bytes, - flow_id_bytes, - ) - return err - }) + err := writeFn( + unsentMessage.Topic, + unsentMessage.Payload, + message_id_bytes, + flow_id_bytes, + ) if err != nil { return messageT{}, err } @@ -408,7 +508,12 @@ func publishStmt( return message, nil } - return fn, readStmt.Close, nil + closeFn := func() error { + writeFnClose() + return g.SomeError(privateDB.Close(), readStmt.Close()) + } + + return fn, closeFn, nil } func findSQL(prefix string) queryT { @@ -446,13 +551,11 @@ func findSQL(prefix string) queryT { } func findStmt( - db *sql.DB, - prefix string, - _ int, + cfg dbconfigT, ) (func(string, guuid.UUID) (messageT, error), func() error, error) { - q := findSQL(prefix) + q := findSQL(cfg.prefix) - readStmt, err := db.Prepare(q.read) + readStmt, err := cfg.shared.Prepare(q.read) if err != nil { return nil, nil, err } @@ -493,6 +596,13 @@ func findStmt( func nextSQL(prefix string) queryT { const tmpl_read = ` SELECT + ( + SELECT owner_id FROM "%s_owners" + WHERE + topic = ? AND + consumer = ? + LIMIT 1 + ) AS owner_id, "%s_messages".id, "%s_messages".timestamp, "%s_messages".uuid, @@ -510,12 +620,6 @@ func nextSQL(prefix string) queryT { ORDER BY "%s_messages".id ASC LIMIT 1; ` - const tmpl_owner = ` - SELECT owner_id FROM "%s_owners" - WHERE - topic = ? AND - consumer = ?; - ` return queryT{ read: fmt.Sprintf( tmpl_read, @@ -532,17 +636,20 @@ func nextSQL(prefix string) queryT { prefix, prefix, prefix, + prefix, ), - owner: fmt.Sprintf(tmpl_owner, prefix), } } func nextStmt( - db *sql.DB, - prefix string, - instanceID int, + cfg dbconfigT, ) (func(string, string) (messageT, error), func() error, error) { - q := nextSQL(prefix) + q := nextSQL(cfg.prefix) + + readStmt, err := cfg.shared.Prepare(q.read) + if err != nil { + return nil, nil, err + } fn := func(topic string, consumer string) (messageT, error) { message := messageT{ @@ -550,44 +657,34 @@ func nextStmt( } var ( - err error ownerID int timestr string message_id_bytes []byte flow_id_bytes []byte ) - tx, err := db.Begin() - if err != nil { - return messageT{}, err - } - defer tx.Rollback() - err = tx.QueryRow(q.owner, topic, consumer).Scan(&ownerID) + err = readStmt.QueryRow(topic, consumer, topic, consumer).Scan( + &ownerID, + &message.id, + ×tr, + &message_id_bytes, + &flow_id_bytes, + &message.payload, + ) if err != nil { return messageT{}, err } - if ownerID != instanceID { + if ownerID != cfg.instanceID { err := fmt.Errorf( notOwnerErrorFmt, ownerID, topic, consumer, - instanceID, + cfg.instanceID, ) return messageT{}, err } - - err = tx.QueryRow(q.read, topic, consumer).Scan( - &message.id, - ×tr, - &message_id_bytes, - &flow_id_bytes, - &message.payload, - ) - if err != nil { - return messageT{}, err - } message.uuid = guuid.UUID(message_id_bytes) message.flowID = guuid.UUID(flow_id_bytes) @@ -599,7 +696,7 @@ func nextStmt( return message, nil } - return fn, closeNoop, nil + return fn, readStmt.Close, nil } func messageEach(rows *sql.Rows, callback func(messageT) error) error { @@ -623,19 +720,22 @@ func messageEach(rows *sql.Rows, callback func(messageT) error) error { &message.payload, ) if err != nil { - return g.WrapErrors(rows.Close(), err) + rows.Close() + return err } message.uuid = guuid.UUID(message_id_bytes) message.flowID = guuid.UUID(flow_id_bytes) message.timestamp, err = time.Parse(time.RFC3339Nano, timestr) if err != nil { - return g.WrapErrors(rows.Close(), err) + rows.Close() + return err } err = callback(message) if err != nil { - return g.WrapErrors(rows.Close(), err) + rows.Close() + return err } } @@ -691,20 +791,19 @@ func pendingSQL(prefix string) queryT { } func pendingStmt( - db *sql.DB, - prefix string, - instanceID int, + cfg dbconfigT, ) (func(string, string) (*sql.Rows, error), func() error, error) { - q := pendingSQL(prefix) + q := pendingSQL(cfg.prefix) - readStmt, err := db.Prepare(q.read) + readStmt, err := cfg.shared.Prepare(q.read) if err != nil { return nil, nil, err } - ownerStmt, err := db.Prepare(q.owner) + ownerStmt, err := cfg.shared.Prepare(q.owner) if err != nil { - return nil, nil, g.WrapErrors(readStmt.Close(), err) + readStmt.Close() + return nil, nil, err } fn := func(topic string, consumer string) (*sql.Rows, error) { @@ -716,7 +815,7 @@ func pendingStmt( // best effort check, the final one is done during // commit within a transaction - if ownerID != instanceID { + if ownerID != cfg.instanceID { return nil, nil } @@ -732,130 +831,67 @@ func pendingStmt( func commitSQL(prefix string) queryT { const tmpl_write = ` - INSERT INTO "%s_offsets" (consumer, message_id) - VALUES (?, (SELECT id FROM "%s_messages" WHERE uuid = ?)); - ` - const tmpl_read = ` - SELECT "%s_payloads".topic from "%s_payloads" - JOIN "%s_messages" ON - "%s_payloads".id = "%s_messages".payload_id - WHERE "%s_messages".uuid = ?; - ` - const tmpl_owner = ` - SELECT owner_id FROM "%s_owners" - WHERE - topic = ? AND - consumer = ?; + INSERT INTO "%s_offsets" (consumer, message_id, instance_id) + VALUES (?, (SELECT id FROM "%s_messages" WHERE uuid = ?), ?); ` return queryT{ write: fmt.Sprintf(tmpl_write, prefix, prefix), - read: fmt.Sprintf( - tmpl_read, - prefix, - prefix, - prefix, - prefix, - prefix, - prefix, - ), - owner: fmt.Sprintf(tmpl_owner, prefix), } } func commitStmt( - db *sql.DB, - prefix string, - instanceID int, + cfg dbconfigT, ) (func(string, guuid.UUID) error, func() error, error) { - q := commitSQL(prefix) + q := commitSQL(cfg.prefix) + + writeStmt, err := cfg.shared.Prepare(q.write) + if err != nil { + return nil, nil, err + } fn := func(consumer string, messageID guuid.UUID) error { message_id_bytes := messageID[:] - return inTx(db, func(ctx context.Context) error { - var topic string - err := db.QueryRowContext( - ctx, - q.read, - message_id_bytes, - ).Scan(&topic) - if err != nil { - return err - } - - var ownerID int - err = db.QueryRowContext( - ctx, - q.owner, - topic, - consumer, - ).Scan(&ownerID) - if err != nil { - return err - } - - if ownerID != instanceID { - return fmt.Errorf( - noLongerOwnerErrorFmt, - instanceID, - topic, - consumer, - ownerID, - ) - } - - _, err = db.ExecContext(ctx, q.write, consumer, message_id_bytes) - return err - }) + _, err = writeStmt.Exec( + consumer, + message_id_bytes, + cfg.instanceID, + ) + return err } - return fn, closeNoop, nil + return fn, writeStmt.Close, nil } func toDeadSQL(prefix string) queryT { const tmpl_write = ` - INSERT INTO "%s_offsets" ( consumer, message_id) - VALUES ( ?, (SELECT id FROM "%s_messages" WHERE uuid = ?)); + INSERT INTO "%s_offsets" + ( consumer, message_id, instance_id) + VALUES ( ?, (SELECT id FROM "%s_messages" WHERE uuid = ?), ?); - INSERT INTO "%s_deadletters" (uuid, consumer, message_id) - VALUES (?, ?, (SELECT id FROM "%s_messages" WHERE uuid = ?)); - ` - const tmpl_read = ` - SELECT "%s_payloads".topic FROM "%s_payloads" - JOIN "%s_messages" ON - "%s_payloads".id = "%s_messages".payload_id - WHERE "%s_messages".uuid = ?; - ` - const tmpl_owner = ` - SELECT owner_id FROM "%s_owners" - WHERE - topic = ? AND - consumer = ?; + INSERT INTO "%s_deadletters" + (uuid, consumer, message_id, instance_id) + VALUES (?, ?, (SELECT id FROM "%s_messages" WHERE uuid = ?), ?); ` return queryT{ write: fmt.Sprintf(tmpl_write, prefix, prefix, prefix, prefix), - read: fmt.Sprintf( - tmpl_read, - prefix, - prefix, - prefix, - prefix, - prefix, - prefix, - ), - owner: fmt.Sprintf(tmpl_owner, prefix), } } func toDeadStmt( - db *sql.DB, - prefix string, - instanceID int, + cfg dbconfigT, ) ( func(string, guuid.UUID, guuid.UUID) error, func() error, error, ) { - q := toDeadSQL(prefix) + q := toDeadSQL(cfg.prefix) + + privateDB, err := sql.Open(golite.DriverName, cfg.dbpath) + if err != nil { + return nil, nil, err + } + + writeFn, writeFnClose := execSerialized(q.write, privateDB) fn := func( consumer string, @@ -864,52 +900,24 @@ func toDeadStmt( ) error { message_id_bytes := messageID[:] deadletter_id_bytes := deadletterID[:] - return inTx(db, func(ctx context.Context) error { - var topic string - err := db.QueryRowContext( - ctx, - q.read, - message_id_bytes, - ).Scan(&topic) - if err != nil { - return err - } - - var ownerID int - err = db.QueryRowContext( - ctx, - q.owner, - topic, - consumer, - ).Scan(&ownerID) - if err != nil { - return err - } - - if ownerID != instanceID { - return fmt.Errorf( - noLongerOwnerErrorFmt, - instanceID, - topic, - consumer, - ownerID, - ) - } + return writeFn( + consumer, + message_id_bytes, + cfg.instanceID, + deadletter_id_bytes, + consumer, + message_id_bytes, + cfg.instanceID, + ) + } - _, err = db.ExecContext( - ctx, - q.write, - consumer, - message_id_bytes, - deadletter_id_bytes, - consumer, - message_id_bytes, - ) - return err - }) + closeFn := func() error { + writeFnClose() + return privateDB.Close() } - return fn, closeNoop, nil + + return fn, closeFn, nil } func replaySQL(prefix string) queryT { @@ -973,33 +981,34 @@ func replaySQL(prefix string) queryT { } func replayStmt( - db *sql.DB, - prefix string, - _ int, + cfg dbconfigT, ) (func(guuid.UUID, guuid.UUID) (messageT, error), func() error, error) { - q := replaySQL(prefix) + q := replaySQL(cfg.prefix) + + readStmt, err := cfg.shared.Prepare(q.read) + if err != nil { + return nil, nil, err + } - readStmt, err := db.Prepare(q.read) + privateDB, err := sql.Open(golite.DriverName, cfg.dbpath) if err != nil { + readStmt.Close() return nil, nil, err } + writeFn, writeFnClose := execSerialized(q.write, privateDB) + fn := func( deadletterID guuid.UUID, messageID guuid.UUID, ) (messageT, error) { deadletter_id_bytes := deadletterID[:] message_id_bytes := messageID[:] - err := inTx(db, func(ctx context.Context) error { - _, err := db.ExecContext( - ctx, - q.write, - message_id_bytes, - deadletter_id_bytes, - deadletter_id_bytes, - ) - return err - }) + err := writeFn( + message_id_bytes, + deadletter_id_bytes, + deadletter_id_bytes, + ) if err != nil { return messageT{}, err } @@ -1032,7 +1041,12 @@ func replayStmt( return message, nil } - return fn, readStmt.Close, nil + closeFn := func() error { + writeFnClose() + return g.SomeError(privateDB.Close(), readStmt.Close()) + } + + return fn, closeFn, nil } func oneDeadSQL(prefix string) queryT { @@ -1085,13 +1099,11 @@ func oneDeadSQL(prefix string) queryT { } func oneDeadStmt( - db *sql.DB, - prefix string, - _ int, + cfg dbconfigT, ) (func(string, string) (deadletterT, error), func() error, error) { - q := oneDeadSQL(prefix) + q := oneDeadSQL(cfg.prefix) - readStmt, err := db.Prepare(q.read) + readStmt, err := cfg.shared.Prepare(q.read) if err != nil { return nil, nil, err } @@ -1157,7 +1169,8 @@ func deadletterEach( &message.payload, ) if err != nil { - return g.WrapErrors(rows.Close(), err) + rows.Close() + return err } deadletter.uuid = guuid.UUID(deadletter_id_bytes) @@ -1170,7 +1183,8 @@ func deadletterEach( messageTimestr, ) if err != nil { - return g.WrapErrors(rows.Close(), err) + rows.Close() + return err } deadletter.timestamp, err = time.Parse( @@ -1178,12 +1192,14 @@ func deadletterEach( deadletterTimestr, ) if err != nil { - return g.WrapErrors(rows.Close(), err) + rows.Close() + return err } err = callback(deadletter, message) if err != nil { - return g.WrapErrors(rows.Close(), err) + rows.Close() + return err } } @@ -1251,13 +1267,11 @@ func allDeadSQL(prefix string) queryT { } func allDeadStmt( - db *sql.DB, - prefix string, - _ int, + cfg dbconfigT, ) (func(string, string) (*sql.Rows, error), func() error, error) { - q := allDeadSQL(prefix) + q := allDeadSQL(cfg.prefix) - readStmt, err := db.Prepare(q.read) + readStmt, err := cfg.shared.Prepare(q.read) if err != nil { return nil, nil, err } @@ -1292,13 +1306,11 @@ func sizeSQL(prefix string) queryT { func sizeStmt( - db *sql.DB, - prefix string, - _ int, + cfg dbconfigT, ) (func(string) (int, error), func() error, error) { - q := sizeSQL(prefix) + q := sizeSQL(cfg.prefix) - readStmt, err := db.Prepare(q.read) + readStmt, err := cfg.shared.Prepare(q.read) if err != nil { return nil, nil, err } @@ -1346,13 +1358,11 @@ func countSQL(prefix string) queryT { } func countStmt( - db *sql.DB, - prefix string, - _ int, + cfg dbconfigT, ) (func(string, string) (int, error), func() error, error) { - q := countSQL(prefix) + q := countSQL(cfg.prefix) - readStmt, err := db.Prepare(q.read) + readStmt, err := cfg.shared.Prepare(q.read) if err != nil { return nil, nil, err } @@ -1399,13 +1409,11 @@ func hasDataSQL(prefix string) queryT { } func hasDataStmt( - db *sql.DB, - prefix string, - _ int, + cfg dbconfigT, ) (func(string, string) (bool, error), func() error, error) { - q := hasDataSQL(prefix) + q := hasDataSQL(cfg.prefix) - readStmt, err := db.Prepare(q.read) + readStmt, err := cfg.shared.Prepare(q.read) if err != nil { return nil, nil, err } @@ -1428,27 +1436,44 @@ func hasDataStmt( } func initDB( - db *sql.DB, + dbpath string, prefix string, notifyFn func(messageT), instanceID int, ) (queriesT, error) { - createTablesErr := createTables(db, prefix) - take, takeClose, takeErr := takeStmt(db, prefix, instanceID) - publish, publishClose, publishErr := publishStmt(db, prefix, instanceID) - find, findClose, findErr := findStmt(db, prefix, instanceID) - next, nextClose, nextErr := nextStmt(db, prefix, instanceID) - pending, pendingClose, pendingErr := pendingStmt(db, prefix, instanceID) - commit, commitClose, commitErr := commitStmt(db, prefix, instanceID) - toDead, toDeadClose, toDeadErr := toDeadStmt(db, prefix, instanceID) - replay, replayClose, replayErr := replayStmt(db, prefix, instanceID) - oneDead, oneDeadClose, oneDeadErr := oneDeadStmt(db, prefix, instanceID) - allDead, allDeadClose, allDeadErr := allDeadStmt(db, prefix, instanceID) - size, sizeClose, sizeErr := sizeStmt(db, prefix, instanceID) - count, countClose, countErr := countStmt(db, prefix, instanceID) - hasData, hasDataClose, hasDataErr := hasDataStmt(db, prefix, instanceID) - - err := g.SomeError( + err := g.ValidateSQLTablePrefix(prefix) + if err != nil { + return queriesT{}, err + } + + shared, err := sql.Open(golite.DriverName, dbpath) + if err != nil { + return queriesT{}, err + } + + cfg := dbconfigT{ + shared: shared, + dbpath: dbpath, + prefix: prefix, + instanceID: instanceID, + } + + createTablesErr := createTables(shared, prefix) + take, takeClose, takeErr := takeStmt(cfg) + publish, publishClose, publishErr := publishStmt(cfg) + find, findClose, findErr := findStmt(cfg) + next, nextClose, nextErr := nextStmt(cfg) + pending, pendingClose, pendingErr := pendingStmt(cfg) + commit, commitClose, commitErr := commitStmt(cfg) + toDead, toDeadClose, toDeadErr := toDeadStmt(cfg) + replay, replayClose, replayErr := replayStmt(cfg) + oneDead, oneDeadClose, oneDeadErr := oneDeadStmt(cfg) + allDead, allDeadClose, allDeadErr := allDeadStmt(cfg) + size, sizeClose, sizeErr := sizeStmt(cfg) + count, countClose, countErr := countStmt(cfg) + hasData, hasDataClose, hasDataErr := hasDataStmt(cfg) + + err = g.SomeError( createTablesErr, takeErr, publishErr, @@ -1468,7 +1493,7 @@ func initDB( return queriesT{}, err } - close := func() error { + closeFn := func() error { return g.SomeFnError( takeClose, publishClose, @@ -1483,6 +1508,7 @@ func initDB( sizeClose, countClose, hasDataClose, + shared.Close, ) } @@ -1614,7 +1640,7 @@ func initDB( close: func() error { connMutex.Lock() defer connMutex.Unlock() - return close() + return closeFn() }, }, nil } @@ -1826,20 +1852,10 @@ func runReaper( } func NewWithPrefix(databasePath string, prefix string) (IQueue, error) { - err := g.ValidateSQLTablePrefix(prefix) - if err != nil { - return queueT{}, err - } - - db, err := sql.Open(golite.DriverName, databasePath) - if err != nil { - return queueT{}, err - } - subscriptions := makeSubscriptionsFuncs() pinger := newPinger[struct{}]() notifyFn := makeNotifyFn(subscriptions.read, pinger) - queries, err := initDB(db, prefix, notifyFn, os.Getpid()) + queries, err := initDB(databasePath, prefix, notifyFn, os.Getpid()) if err != nil { return queueT{}, err } @@ -1847,7 +1863,6 @@ func NewWithPrefix(databasePath string, prefix string) (IQueue, error) { go runReaper(pinger.onPing, subscriptions.read, subscriptions.write) return queueT{ - db: db, queries: queries, subscriptions: subscriptions, pinger: pinger, @@ -2120,7 +2135,6 @@ func cleanSubscriptions(set subscriptionsSetM) error { func (queue queueT) Close() error { queue.pinger.close() return g.WrapErrors( - queue.db.Close(), queue.subscriptions.write(cleanSubscriptions), queue.queries.close(), ) |