summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/q.go652
1 files changed, 333 insertions, 319 deletions
diff --git a/src/q.go b/src/q.go
index 79570df..598cd75 100644
--- a/src/q.go
+++ b/src/q.go
@@ -1,6 +1,6 @@
package q
+
import (
- "context"
"database/sql"
"flag"
"fmt"
@@ -18,15 +18,21 @@ import (
const (
- defaultPrefix = "q"
- reaperSkipCount = 1000
- notOwnerErrorFmt = "%v owns %#v as %#v, not us (%v)"
- noLongerOwnerErrorFmt = "we (%v) no longer own %#v as %#v, but %v does"
- rollbackErrorFmt = "rollback error: %w; while executing: %w"
+ defaultPrefix = "q"
+ reaperSkipCount = 1000
+ notOwnerErrorFmt = "%v owns %#v as %#v, not us (%v)"
+ rollbackErrorFmt = "rollback error: %w; while executing: %w"
)
+type dbconfigT struct{
+ shared *sql.DB
+ dbpath string
+ prefix string
+ instanceID int
+}
+
type queryT struct{
write string
read string
@@ -126,7 +132,6 @@ type subscriptionsT struct {
}
type queueT struct{
- db *sql.DB
queries queriesT
subscriptions subscriptionsT
pinger pingerT[struct{}]
@@ -158,12 +163,55 @@ type IQueue interface{
-func closeNoop() error {
- return nil
+func serialized[A any, B any](callback func(...A) B) (func(...A) B, func()) {
+ in := make(chan []A)
+ out := make(chan B)
+
+ closed := false
+ var (
+ closeWg sync.WaitGroup
+ closeMutex sync.Mutex
+ )
+ closeWg.Add(1)
+
+ go func() {
+ for input := range in {
+ out <- callback(input...)
+ }
+ close(out)
+ closeWg.Done()
+ }()
+
+ fn := func(input ...A) B {
+ in <- input
+ return (<- out)
+ }
+
+ closeFn := func() {
+ closeMutex.Lock()
+ defer closeMutex.Unlock()
+ if closed {
+ return
+ }
+ close(in)
+ closed = true
+ closeWg.Wait()
+ }
+
+ return fn, closeFn
+}
+
+func execSerialized(query string, db *sql.DB) (func(...any) error, func()) {
+ return serialized(func(args ...any) error {
+ return inTx(db, func(tx *sql.Tx) error {
+ _, err := tx.Exec(query, args...)
+ return err
+ })
+ })
}
-func tryRollback(db *sql.DB, ctx context.Context, err error) error {
- _, rollbackErr := db.ExecContext(ctx, "ROLLBACK;")
+func tryRollback(tx *sql.Tx, err error) error {
+ rollbackErr := tx.Rollback()
if rollbackErr != nil {
return fmt.Errorf(
rollbackErrorFmt,
@@ -175,25 +223,20 @@ func tryRollback(db *sql.DB, ctx context.Context, err error) error {
return err
}
-// FIXME
-// See:
-// https://sqlite.org/forum/forumpost/2507664507
-func inTx(db *sql.DB, fn func(context.Context) error) error {
- ctx := context.Background()
-
- _, err := db.ExecContext(ctx, "BEGIN IMMEDIATE;")
+func inTx(db *sql.DB, fn func(*sql.Tx) error) error {
+ tx, err := db.Begin()
if err != nil {
return err
}
- err = fn(ctx)
+ err = fn(tx)
if err != nil {
- return tryRollback(db, ctx, err)
+ return tryRollback(tx, err)
}
- _, err = db.ExecContext(ctx, "COMMIT;")
+ err = tx.Commit()
if err != nil {
- return tryRollback(db, ctx, err)
+ return tryRollback(tx, err)
}
return nil
@@ -227,6 +270,7 @@ func createTablesSQL(prefix string) queryT {
consumer TEXT NOT NULL,
message_id INTEGER NOT NULL
REFERENCES "%s_messages"(id),
+ instance_id INTEGER NOT NULL,
UNIQUE (consumer, message_id)
) STRICT;
CREATE INDEX IF NOT EXISTS "%s_offsets_consumer"
@@ -238,6 +282,7 @@ func createTablesSQL(prefix string) queryT {
consumer TEXT NOT NULL,
message_id INTEGER NOT NULL
REFERENCES "%s_messages"(id),
+ instance_id INTEGER NOT NULL,
UNIQUE (consumer, message_id)
) STRICT;
CREATE INDEX IF NOT EXISTS "%s_deadletters_consumer"
@@ -258,6 +303,44 @@ func createTablesSQL(prefix string) queryT {
owner_id INTEGER NOT NULL,
UNIQUE (topic, consumer)
) STRICT;
+
+ CREATE TRIGGER IF NOT EXISTS "%s_check_instance_owns_topic"
+ BEFORE INSERT ON "%s_offsets"
+ WHEN NEW.instance_id != (
+ SELECT owner_id FROM "%s_owners"
+ WHERE topic = (
+ SELECT "%s_payloads".topic
+ FROM "%s_payloads"
+ JOIN "%s_messages" ON "%s_payloads".id =
+ "%s_messages".payload_id
+ WHERE "%s_messages".id = NEW.message_id
+ ) AND consumer = NEW.consumer
+ )
+ BEGIN
+ SELECT RAISE(
+ ABORT,
+ 'instance does not own topic/consumer combo'
+ );
+ END;
+
+ CREATE TRIGGER IF NOT EXISTS "%s_check_can_publish_deadletter"
+ BEFORE INSERT ON "%s_deadletters"
+ WHEN NEW.instance_id != (
+ SELECT owner_id FROM "%s_owners"
+ WHERE topic = (
+ SELECT "%s_payloads".topic
+ FROM "%s_payloads"
+ JOIN "%s_messages" ON "%s_payloads".id =
+ "%s_messages".payload_id
+ WHERE "%s_messages".id = NEW.message_id
+ ) AND consumer = NEW.consumer
+ )
+ BEGIN
+ SELECT RAISE(
+ ABORT,
+ 'Instance does not own topic/consumer combo'
+ );
+ END;
`
return queryT{
write: fmt.Sprintf(
@@ -284,6 +367,24 @@ func createTablesSQL(prefix string) queryT {
prefix,
prefix,
prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
),
}
}
@@ -291,8 +392,8 @@ func createTablesSQL(prefix string) queryT {
func createTables(db *sql.DB, prefix string) error {
q := createTablesSQL(prefix)
- return inTx(db, func(ctx context.Context) error {
- _, err := db.ExecContext(ctx, q.write)
+ return inTx(db, func(tx *sql.Tx) error {
+ _, err := tx.Exec(q.write)
return err
})
}
@@ -310,26 +411,25 @@ func takeSQL(prefix string) queryT {
}
func takeStmt(
- db *sql.DB,
- prefix string,
- instanceID int,
+ cfg dbconfigT,
) (func(string, string) error, func() error, error) {
- q := takeSQL(prefix)
+ q := takeSQL(cfg.prefix)
+
+ writeStmt, err := cfg.shared.Prepare(q.write)
+ if err != nil {
+ return nil, nil, err
+ }
fn := func(topic string, consumer string) error {
- return inTx(db, func(ctx context.Context) error {
- _, err := db.ExecContext(
- ctx,
- q.write,
- topic,
- consumer,
- instanceID,
- )
- return err
- })
+ _, err := writeStmt.Exec(
+ topic,
+ consumer,
+ cfg.instanceID,
+ )
+ return err
}
- return fn, closeNoop, nil
+ return fn, writeStmt.Close, nil
}
func publishSQL(prefix string) queryT {
@@ -337,7 +437,6 @@ func publishSQL(prefix string) queryT {
INSERT INTO "%s_payloads" (topic, payload)
VALUES (?, ?);
- -- FIXME: must be inside a trnsaction
INSERT INTO "%s_messages" (uuid, flow_id, payload_id)
VALUES (?, ?, last_insert_rowid());
`
@@ -352,17 +451,23 @@ func publishSQL(prefix string) queryT {
}
func publishStmt(
- db *sql.DB,
- prefix string,
- _ int,
+ cfg dbconfigT,
) (func(UnsentMessage, guuid.UUID) (messageT, error), func() error, error) {
- q := publishSQL(prefix)
+ q := publishSQL(cfg.prefix)
+
+ readStmt, err := cfg.shared.Prepare(q.read)
+ if err != nil {
+ return nil, nil, err
+ }
- readStmt, err := db.Prepare(q.read)
+ privateDB, err := sql.Open(golite.DriverName, cfg.dbpath)
if err != nil {
+ readStmt.Close()
return nil, nil, err
}
+ writeFn, writeFnClose := execSerialized(q.write, privateDB)
+
fn := func(
unsentMessage UnsentMessage,
messageID guuid.UUID,
@@ -376,17 +481,12 @@ func publishStmt(
message_id_bytes := messageID[:]
flow_id_bytes := unsentMessage.FlowID[:]
- err := inTx(db, func(ctx context.Context) error {
- _, err := db.ExecContext(
- ctx,
- q.write,
- unsentMessage.Topic,
- unsentMessage.Payload,
- message_id_bytes,
- flow_id_bytes,
- )
- return err
- })
+ err := writeFn(
+ unsentMessage.Topic,
+ unsentMessage.Payload,
+ message_id_bytes,
+ flow_id_bytes,
+ )
if err != nil {
return messageT{}, err
}
@@ -408,7 +508,12 @@ func publishStmt(
return message, nil
}
- return fn, readStmt.Close, nil
+ closeFn := func() error {
+ writeFnClose()
+ return g.SomeError(privateDB.Close(), readStmt.Close())
+ }
+
+ return fn, closeFn, nil
}
func findSQL(prefix string) queryT {
@@ -446,13 +551,11 @@ func findSQL(prefix string) queryT {
}
func findStmt(
- db *sql.DB,
- prefix string,
- _ int,
+ cfg dbconfigT,
) (func(string, guuid.UUID) (messageT, error), func() error, error) {
- q := findSQL(prefix)
+ q := findSQL(cfg.prefix)
- readStmt, err := db.Prepare(q.read)
+ readStmt, err := cfg.shared.Prepare(q.read)
if err != nil {
return nil, nil, err
}
@@ -493,6 +596,13 @@ func findStmt(
func nextSQL(prefix string) queryT {
const tmpl_read = `
SELECT
+ (
+ SELECT owner_id FROM "%s_owners"
+ WHERE
+ topic = ? AND
+ consumer = ?
+ LIMIT 1
+ ) AS owner_id,
"%s_messages".id,
"%s_messages".timestamp,
"%s_messages".uuid,
@@ -510,12 +620,6 @@ func nextSQL(prefix string) queryT {
ORDER BY "%s_messages".id ASC
LIMIT 1;
`
- const tmpl_owner = `
- SELECT owner_id FROM "%s_owners"
- WHERE
- topic = ? AND
- consumer = ?;
- `
return queryT{
read: fmt.Sprintf(
tmpl_read,
@@ -532,17 +636,20 @@ func nextSQL(prefix string) queryT {
prefix,
prefix,
prefix,
+ prefix,
),
- owner: fmt.Sprintf(tmpl_owner, prefix),
}
}
func nextStmt(
- db *sql.DB,
- prefix string,
- instanceID int,
+ cfg dbconfigT,
) (func(string, string) (messageT, error), func() error, error) {
- q := nextSQL(prefix)
+ q := nextSQL(cfg.prefix)
+
+ readStmt, err := cfg.shared.Prepare(q.read)
+ if err != nil {
+ return nil, nil, err
+ }
fn := func(topic string, consumer string) (messageT, error) {
message := messageT{
@@ -550,44 +657,34 @@ func nextStmt(
}
var (
- err error
ownerID int
timestr string
message_id_bytes []byte
flow_id_bytes []byte
)
- tx, err := db.Begin()
- if err != nil {
- return messageT{}, err
- }
- defer tx.Rollback()
- err = tx.QueryRow(q.owner, topic, consumer).Scan(&ownerID)
+ err = readStmt.QueryRow(topic, consumer, topic, consumer).Scan(
+ &ownerID,
+ &message.id,
+ &timestr,
+ &message_id_bytes,
+ &flow_id_bytes,
+ &message.payload,
+ )
if err != nil {
return messageT{}, err
}
- if ownerID != instanceID {
+ if ownerID != cfg.instanceID {
err := fmt.Errorf(
notOwnerErrorFmt,
ownerID,
topic,
consumer,
- instanceID,
+ cfg.instanceID,
)
return messageT{}, err
}
-
- err = tx.QueryRow(q.read, topic, consumer).Scan(
- &message.id,
- &timestr,
- &message_id_bytes,
- &flow_id_bytes,
- &message.payload,
- )
- if err != nil {
- return messageT{}, err
- }
message.uuid = guuid.UUID(message_id_bytes)
message.flowID = guuid.UUID(flow_id_bytes)
@@ -599,7 +696,7 @@ func nextStmt(
return message, nil
}
- return fn, closeNoop, nil
+ return fn, readStmt.Close, nil
}
func messageEach(rows *sql.Rows, callback func(messageT) error) error {
@@ -623,19 +720,22 @@ func messageEach(rows *sql.Rows, callback func(messageT) error) error {
&message.payload,
)
if err != nil {
- return g.WrapErrors(rows.Close(), err)
+ rows.Close()
+ return err
}
message.uuid = guuid.UUID(message_id_bytes)
message.flowID = guuid.UUID(flow_id_bytes)
message.timestamp, err = time.Parse(time.RFC3339Nano, timestr)
if err != nil {
- return g.WrapErrors(rows.Close(), err)
+ rows.Close()
+ return err
}
err = callback(message)
if err != nil {
- return g.WrapErrors(rows.Close(), err)
+ rows.Close()
+ return err
}
}
@@ -691,20 +791,19 @@ func pendingSQL(prefix string) queryT {
}
func pendingStmt(
- db *sql.DB,
- prefix string,
- instanceID int,
+ cfg dbconfigT,
) (func(string, string) (*sql.Rows, error), func() error, error) {
- q := pendingSQL(prefix)
+ q := pendingSQL(cfg.prefix)
- readStmt, err := db.Prepare(q.read)
+ readStmt, err := cfg.shared.Prepare(q.read)
if err != nil {
return nil, nil, err
}
- ownerStmt, err := db.Prepare(q.owner)
+ ownerStmt, err := cfg.shared.Prepare(q.owner)
if err != nil {
- return nil, nil, g.WrapErrors(readStmt.Close(), err)
+ readStmt.Close()
+ return nil, nil, err
}
fn := func(topic string, consumer string) (*sql.Rows, error) {
@@ -716,7 +815,7 @@ func pendingStmt(
// best effort check, the final one is done during
// commit within a transaction
- if ownerID != instanceID {
+ if ownerID != cfg.instanceID {
return nil, nil
}
@@ -732,130 +831,67 @@ func pendingStmt(
func commitSQL(prefix string) queryT {
const tmpl_write = `
- INSERT INTO "%s_offsets" (consumer, message_id)
- VALUES (?, (SELECT id FROM "%s_messages" WHERE uuid = ?));
- `
- const tmpl_read = `
- SELECT "%s_payloads".topic from "%s_payloads"
- JOIN "%s_messages" ON
- "%s_payloads".id = "%s_messages".payload_id
- WHERE "%s_messages".uuid = ?;
- `
- const tmpl_owner = `
- SELECT owner_id FROM "%s_owners"
- WHERE
- topic = ? AND
- consumer = ?;
+ INSERT INTO "%s_offsets" (consumer, message_id, instance_id)
+ VALUES (?, (SELECT id FROM "%s_messages" WHERE uuid = ?), ?);
`
return queryT{
write: fmt.Sprintf(tmpl_write, prefix, prefix),
- read: fmt.Sprintf(
- tmpl_read,
- prefix,
- prefix,
- prefix,
- prefix,
- prefix,
- prefix,
- ),
- owner: fmt.Sprintf(tmpl_owner, prefix),
}
}
func commitStmt(
- db *sql.DB,
- prefix string,
- instanceID int,
+ cfg dbconfigT,
) (func(string, guuid.UUID) error, func() error, error) {
- q := commitSQL(prefix)
+ q := commitSQL(cfg.prefix)
+
+ writeStmt, err := cfg.shared.Prepare(q.write)
+ if err != nil {
+ return nil, nil, err
+ }
fn := func(consumer string, messageID guuid.UUID) error {
message_id_bytes := messageID[:]
- return inTx(db, func(ctx context.Context) error {
- var topic string
- err := db.QueryRowContext(
- ctx,
- q.read,
- message_id_bytes,
- ).Scan(&topic)
- if err != nil {
- return err
- }
-
- var ownerID int
- err = db.QueryRowContext(
- ctx,
- q.owner,
- topic,
- consumer,
- ).Scan(&ownerID)
- if err != nil {
- return err
- }
-
- if ownerID != instanceID {
- return fmt.Errorf(
- noLongerOwnerErrorFmt,
- instanceID,
- topic,
- consumer,
- ownerID,
- )
- }
-
- _, err = db.ExecContext(ctx, q.write, consumer, message_id_bytes)
- return err
- })
+ _, err = writeStmt.Exec(
+ consumer,
+ message_id_bytes,
+ cfg.instanceID,
+ )
+ return err
}
- return fn, closeNoop, nil
+ return fn, writeStmt.Close, nil
}
func toDeadSQL(prefix string) queryT {
const tmpl_write = `
- INSERT INTO "%s_offsets" ( consumer, message_id)
- VALUES ( ?, (SELECT id FROM "%s_messages" WHERE uuid = ?));
+ INSERT INTO "%s_offsets"
+ ( consumer, message_id, instance_id)
+ VALUES ( ?, (SELECT id FROM "%s_messages" WHERE uuid = ?), ?);
- INSERT INTO "%s_deadletters" (uuid, consumer, message_id)
- VALUES (?, ?, (SELECT id FROM "%s_messages" WHERE uuid = ?));
- `
- const tmpl_read = `
- SELECT "%s_payloads".topic FROM "%s_payloads"
- JOIN "%s_messages" ON
- "%s_payloads".id = "%s_messages".payload_id
- WHERE "%s_messages".uuid = ?;
- `
- const tmpl_owner = `
- SELECT owner_id FROM "%s_owners"
- WHERE
- topic = ? AND
- consumer = ?;
+ INSERT INTO "%s_deadletters"
+ (uuid, consumer, message_id, instance_id)
+ VALUES (?, ?, (SELECT id FROM "%s_messages" WHERE uuid = ?), ?);
`
return queryT{
write: fmt.Sprintf(tmpl_write, prefix, prefix, prefix, prefix),
- read: fmt.Sprintf(
- tmpl_read,
- prefix,
- prefix,
- prefix,
- prefix,
- prefix,
- prefix,
- ),
- owner: fmt.Sprintf(tmpl_owner, prefix),
}
}
func toDeadStmt(
- db *sql.DB,
- prefix string,
- instanceID int,
+ cfg dbconfigT,
) (
func(string, guuid.UUID, guuid.UUID) error,
func() error,
error,
) {
- q := toDeadSQL(prefix)
+ q := toDeadSQL(cfg.prefix)
+
+ privateDB, err := sql.Open(golite.DriverName, cfg.dbpath)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ writeFn, writeFnClose := execSerialized(q.write, privateDB)
fn := func(
consumer string,
@@ -864,52 +900,24 @@ func toDeadStmt(
) error {
message_id_bytes := messageID[:]
deadletter_id_bytes := deadletterID[:]
- return inTx(db, func(ctx context.Context) error {
- var topic string
- err := db.QueryRowContext(
- ctx,
- q.read,
- message_id_bytes,
- ).Scan(&topic)
- if err != nil {
- return err
- }
-
- var ownerID int
- err = db.QueryRowContext(
- ctx,
- q.owner,
- topic,
- consumer,
- ).Scan(&ownerID)
- if err != nil {
- return err
- }
-
- if ownerID != instanceID {
- return fmt.Errorf(
- noLongerOwnerErrorFmt,
- instanceID,
- topic,
- consumer,
- ownerID,
- )
- }
+ return writeFn(
+ consumer,
+ message_id_bytes,
+ cfg.instanceID,
+ deadletter_id_bytes,
+ consumer,
+ message_id_bytes,
+ cfg.instanceID,
+ )
+ }
- _, err = db.ExecContext(
- ctx,
- q.write,
- consumer,
- message_id_bytes,
- deadletter_id_bytes,
- consumer,
- message_id_bytes,
- )
- return err
- })
+ closeFn := func() error {
+ writeFnClose()
+ return privateDB.Close()
}
- return fn, closeNoop, nil
+
+ return fn, closeFn, nil
}
func replaySQL(prefix string) queryT {
@@ -973,33 +981,34 @@ func replaySQL(prefix string) queryT {
}
func replayStmt(
- db *sql.DB,
- prefix string,
- _ int,
+ cfg dbconfigT,
) (func(guuid.UUID, guuid.UUID) (messageT, error), func() error, error) {
- q := replaySQL(prefix)
+ q := replaySQL(cfg.prefix)
+
+ readStmt, err := cfg.shared.Prepare(q.read)
+ if err != nil {
+ return nil, nil, err
+ }
- readStmt, err := db.Prepare(q.read)
+ privateDB, err := sql.Open(golite.DriverName, cfg.dbpath)
if err != nil {
+ readStmt.Close()
return nil, nil, err
}
+ writeFn, writeFnClose := execSerialized(q.write, privateDB)
+
fn := func(
deadletterID guuid.UUID,
messageID guuid.UUID,
) (messageT, error) {
deadletter_id_bytes := deadletterID[:]
message_id_bytes := messageID[:]
- err := inTx(db, func(ctx context.Context) error {
- _, err := db.ExecContext(
- ctx,
- q.write,
- message_id_bytes,
- deadletter_id_bytes,
- deadletter_id_bytes,
- )
- return err
- })
+ err := writeFn(
+ message_id_bytes,
+ deadletter_id_bytes,
+ deadletter_id_bytes,
+ )
if err != nil {
return messageT{}, err
}
@@ -1032,7 +1041,12 @@ func replayStmt(
return message, nil
}
- return fn, readStmt.Close, nil
+ closeFn := func() error {
+ writeFnClose()
+ return g.SomeError(privateDB.Close(), readStmt.Close())
+ }
+
+ return fn, closeFn, nil
}
func oneDeadSQL(prefix string) queryT {
@@ -1085,13 +1099,11 @@ func oneDeadSQL(prefix string) queryT {
}
func oneDeadStmt(
- db *sql.DB,
- prefix string,
- _ int,
+ cfg dbconfigT,
) (func(string, string) (deadletterT, error), func() error, error) {
- q := oneDeadSQL(prefix)
+ q := oneDeadSQL(cfg.prefix)
- readStmt, err := db.Prepare(q.read)
+ readStmt, err := cfg.shared.Prepare(q.read)
if err != nil {
return nil, nil, err
}
@@ -1157,7 +1169,8 @@ func deadletterEach(
&message.payload,
)
if err != nil {
- return g.WrapErrors(rows.Close(), err)
+ rows.Close()
+ return err
}
deadletter.uuid = guuid.UUID(deadletter_id_bytes)
@@ -1170,7 +1183,8 @@ func deadletterEach(
messageTimestr,
)
if err != nil {
- return g.WrapErrors(rows.Close(), err)
+ rows.Close()
+ return err
}
deadletter.timestamp, err = time.Parse(
@@ -1178,12 +1192,14 @@ func deadletterEach(
deadletterTimestr,
)
if err != nil {
- return g.WrapErrors(rows.Close(), err)
+ rows.Close()
+ return err
}
err = callback(deadletter, message)
if err != nil {
- return g.WrapErrors(rows.Close(), err)
+ rows.Close()
+ return err
}
}
@@ -1251,13 +1267,11 @@ func allDeadSQL(prefix string) queryT {
}
func allDeadStmt(
- db *sql.DB,
- prefix string,
- _ int,
+ cfg dbconfigT,
) (func(string, string) (*sql.Rows, error), func() error, error) {
- q := allDeadSQL(prefix)
+ q := allDeadSQL(cfg.prefix)
- readStmt, err := db.Prepare(q.read)
+ readStmt, err := cfg.shared.Prepare(q.read)
if err != nil {
return nil, nil, err
}
@@ -1292,13 +1306,11 @@ func sizeSQL(prefix string) queryT {
func sizeStmt(
- db *sql.DB,
- prefix string,
- _ int,
+ cfg dbconfigT,
) (func(string) (int, error), func() error, error) {
- q := sizeSQL(prefix)
+ q := sizeSQL(cfg.prefix)
- readStmt, err := db.Prepare(q.read)
+ readStmt, err := cfg.shared.Prepare(q.read)
if err != nil {
return nil, nil, err
}
@@ -1346,13 +1358,11 @@ func countSQL(prefix string) queryT {
}
func countStmt(
- db *sql.DB,
- prefix string,
- _ int,
+ cfg dbconfigT,
) (func(string, string) (int, error), func() error, error) {
- q := countSQL(prefix)
+ q := countSQL(cfg.prefix)
- readStmt, err := db.Prepare(q.read)
+ readStmt, err := cfg.shared.Prepare(q.read)
if err != nil {
return nil, nil, err
}
@@ -1399,13 +1409,11 @@ func hasDataSQL(prefix string) queryT {
}
func hasDataStmt(
- db *sql.DB,
- prefix string,
- _ int,
+ cfg dbconfigT,
) (func(string, string) (bool, error), func() error, error) {
- q := hasDataSQL(prefix)
+ q := hasDataSQL(cfg.prefix)
- readStmt, err := db.Prepare(q.read)
+ readStmt, err := cfg.shared.Prepare(q.read)
if err != nil {
return nil, nil, err
}
@@ -1428,27 +1436,44 @@ func hasDataStmt(
}
func initDB(
- db *sql.DB,
+ dbpath string,
prefix string,
notifyFn func(messageT),
instanceID int,
) (queriesT, error) {
- createTablesErr := createTables(db, prefix)
- take, takeClose, takeErr := takeStmt(db, prefix, instanceID)
- publish, publishClose, publishErr := publishStmt(db, prefix, instanceID)
- find, findClose, findErr := findStmt(db, prefix, instanceID)
- next, nextClose, nextErr := nextStmt(db, prefix, instanceID)
- pending, pendingClose, pendingErr := pendingStmt(db, prefix, instanceID)
- commit, commitClose, commitErr := commitStmt(db, prefix, instanceID)
- toDead, toDeadClose, toDeadErr := toDeadStmt(db, prefix, instanceID)
- replay, replayClose, replayErr := replayStmt(db, prefix, instanceID)
- oneDead, oneDeadClose, oneDeadErr := oneDeadStmt(db, prefix, instanceID)
- allDead, allDeadClose, allDeadErr := allDeadStmt(db, prefix, instanceID)
- size, sizeClose, sizeErr := sizeStmt(db, prefix, instanceID)
- count, countClose, countErr := countStmt(db, prefix, instanceID)
- hasData, hasDataClose, hasDataErr := hasDataStmt(db, prefix, instanceID)
-
- err := g.SomeError(
+ err := g.ValidateSQLTablePrefix(prefix)
+ if err != nil {
+ return queriesT{}, err
+ }
+
+ shared, err := sql.Open(golite.DriverName, dbpath)
+ if err != nil {
+ return queriesT{}, err
+ }
+
+ cfg := dbconfigT{
+ shared: shared,
+ dbpath: dbpath,
+ prefix: prefix,
+ instanceID: instanceID,
+ }
+
+ createTablesErr := createTables(shared, prefix)
+ take, takeClose, takeErr := takeStmt(cfg)
+ publish, publishClose, publishErr := publishStmt(cfg)
+ find, findClose, findErr := findStmt(cfg)
+ next, nextClose, nextErr := nextStmt(cfg)
+ pending, pendingClose, pendingErr := pendingStmt(cfg)
+ commit, commitClose, commitErr := commitStmt(cfg)
+ toDead, toDeadClose, toDeadErr := toDeadStmt(cfg)
+ replay, replayClose, replayErr := replayStmt(cfg)
+ oneDead, oneDeadClose, oneDeadErr := oneDeadStmt(cfg)
+ allDead, allDeadClose, allDeadErr := allDeadStmt(cfg)
+ size, sizeClose, sizeErr := sizeStmt(cfg)
+ count, countClose, countErr := countStmt(cfg)
+ hasData, hasDataClose, hasDataErr := hasDataStmt(cfg)
+
+ err = g.SomeError(
createTablesErr,
takeErr,
publishErr,
@@ -1468,7 +1493,7 @@ func initDB(
return queriesT{}, err
}
- close := func() error {
+ closeFn := func() error {
return g.SomeFnError(
takeClose,
publishClose,
@@ -1483,6 +1508,7 @@ func initDB(
sizeClose,
countClose,
hasDataClose,
+ shared.Close,
)
}
@@ -1614,7 +1640,7 @@ func initDB(
close: func() error {
connMutex.Lock()
defer connMutex.Unlock()
- return close()
+ return closeFn()
},
}, nil
}
@@ -1826,20 +1852,10 @@ func runReaper(
}
func NewWithPrefix(databasePath string, prefix string) (IQueue, error) {
- err := g.ValidateSQLTablePrefix(prefix)
- if err != nil {
- return queueT{}, err
- }
-
- db, err := sql.Open(golite.DriverName, databasePath)
- if err != nil {
- return queueT{}, err
- }
-
subscriptions := makeSubscriptionsFuncs()
pinger := newPinger[struct{}]()
notifyFn := makeNotifyFn(subscriptions.read, pinger)
- queries, err := initDB(db, prefix, notifyFn, os.Getpid())
+ queries, err := initDB(databasePath, prefix, notifyFn, os.Getpid())
if err != nil {
return queueT{}, err
}
@@ -1847,7 +1863,6 @@ func NewWithPrefix(databasePath string, prefix string) (IQueue, error) {
go runReaper(pinger.onPing, subscriptions.read, subscriptions.write)
return queueT{
- db: db,
queries: queries,
subscriptions: subscriptions,
pinger: pinger,
@@ -2120,7 +2135,6 @@ func cleanSubscriptions(set subscriptionsSetM) error {
func (queue queueT) Close() error {
queue.pinger.close()
return g.WrapErrors(
- queue.db.Close(),
queue.subscriptions.write(cleanSubscriptions),
queue.queries.close(),
)