diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/liteq.go | 117 | ||||
-rw-r--r-- | src/main.go | 4 | ||||
-rw-r--r-- | src/q.go | 2495 |
3 files changed, 2497 insertions, 119 deletions
diff --git a/src/liteq.go b/src/liteq.go deleted file mode 100644 index 2eeff34..0000000 --- a/src/liteq.go +++ /dev/null @@ -1,117 +0,0 @@ -package liteq - -import ( - "database/sql" - "flag" - "io/ioutil" - "log/slog" - "os" - "sort" - - g "gobang" - "golite" -) - - - -func InitMigrations(db *sql.DB) { - _, err := db.Exec(` - CREATE TABLE IF NOT EXISTS migrations ( - filename TEXT PRIMARY KEY - ); - `) - g.FatalIf(err) -} - -const MIGRATIONS_DIR = "src/sql/migrations/" -func PendingMigrations(db *sql.DB) []string { - files, err := ioutil.ReadDir(MIGRATIONS_DIR) - g.FatalIf(err) - - set := make(map[string]bool) - for _, file := range files { - set[file.Name()] = true - } - - rows, err := db.Query(`SELECT filename FROM migrations;`) - g.FatalIf(err) - defer rows.Close() - - for rows.Next() { - var filename string - err := rows.Scan(&filename) - g.FatalIf(err) - delete(set, filename) - } - g.FatalIf(rows.Err()) - - difference := make([]string, 0) - for filename := range set { - difference = append(difference, filename) - } - - sort.Sort(sort.StringSlice(difference)) - return difference -} - -func RunMigrations(db *sql.DB) { - InitMigrations(db) - - stmt, err := db.Prepare(`INSERT INTO migrations (filename) VALUES (?);`) - g.FatalIf(err) - defer stmt.Close() - - for _, filename := range PendingMigrations(db) { - g.Info("Running migration file", "exec-migration-file", - "filename", filename, - ) - - tx, err := db.Begin() - g.FatalIf(err) - - sql, err := os.ReadFile(MIGRATIONS_DIR + filename) - g.FatalIf(err) - - _, err = tx.Exec(string(sql)) - g.FatalIf(err) - - _, err = tx.Stmt(stmt).Exec(filename) - g.FatalIf(err) - - err = tx.Commit() - g.FatalIf(err) - } -} - -func initDB(databasePath string) *sql.DB { - db, err := sql.Open("sqlite3", databasePath) - g.FatalIf(err) - RunMigrations(db) - return db -} - -func run(db *sql.DB) { -} - - - -var ( - databasePath = flag.String( - "f", - "q.db", - "The path to the database file", - ) -) - - -func Main() { - g.Init(slog.Group( - "versions", - "gobang", g.Version, - "golite", golite.Version, - "this", version, - )) - flag.Parse() - db := initDB(*databasePath) - run(db) -} diff --git a/src/main.go b/src/main.go index 8d9a05e..51faffa 100644 --- a/src/main.go +++ b/src/main.go @@ -1,7 +1,7 @@ package main -import "liteq" +import "q" func main() { - liteq.Main() + q.Main() } diff --git a/src/q.go b/src/q.go new file mode 100644 index 0000000..6eeefe6 --- /dev/null +++ b/src/q.go @@ -0,0 +1,2495 @@ +package q +import ( + "context" + "database/sql" + "flag" + "fmt" + "io" + "log/slog" + "os" + "sync" + "time" + + _ "acudego" + "guuid" + g "gobang" +) + + + +const ( + defaultPrefix = "q" + reaperSkipCount = 1000 + notOwnerErrorFmt = "%v owns %#v as %#v, not us (%v)" + noLongerOwnerErrorFmt = "we (%v) no longer own %#v as %#v, but %v does" + rollbackErrorFmt = "rollback error: %w; while executing: %w" +) + + + +type dbI interface{ + findOne(q string, args []any, bindings []any) error + exec(q string) +} + +type queryT struct{ + write string + read string + owner string +} + +type queriesT struct{ + take func(string, string) error + publish func(UnsentMessage, guuid.UUID) (messageT, error) + find func(string, guuid.UUID) (messageT, error) + next func(string, string) (messageT, error) + pending func(string, string, func(messageT) error) error + commit func(string, guuid.UUID) error + toDead func(string, guuid.UUID, guuid.UUID) error + replay func(guuid.UUID, guuid.UUID) (messageT, error) + oneDead func(string, string) (deadletterT, error) + allDead func(string, string, func(deadletterT, messageT) error) error + size func(string) (int, error) + count func(string, string) (int, error) + hasData func(string, string) (bool, error) + close func() error +} + +type messageT struct{ + id int64 + timestamp time.Time + uuid guuid.UUID + topic string + flowID guuid.UUID + payload []byte +} + +type UnsentMessage struct{ + Topic string + FlowID guuid.UUID + Payload []byte +} + +type Message struct{ + ID guuid.UUID + Timestamp time.Time + Topic string + FlowID guuid.UUID + Payload []byte +} + +type deadletterT struct{ + uuid guuid.UUID + timestamp time.Time + consumer string + messageID guuid.UUID +} + +type pingerT[T any] struct{ + tryPing func(T) + onPing func(func(T)) + closed func() bool + close func() +} + +type consumerDataT struct{ + topic string + name string +} + +type waiterDataT struct{ + topic string + flowID guuid.UUID + name string + +} + +type consumerT struct{ + data consumerDataT + callback func(Message) error + pinger pingerT[struct{}] + close *func() +} + +type waiterT struct{ + data waiterDataT + pinger pingerT[[]byte] + closed *func() bool + close *func() +} + +type topicSubscriptionT struct{ + consumers map[string]consumerT + waiters map[guuid.UUID]map[string]waiterT +} + +type subscriptionsSetM map[string]topicSubscriptionT + +type subscriptionsT struct { + read func(func(subscriptionsSetM) error) error + write func(func(subscriptionsSetM) error) error +} + +type queueT struct{ + queries queriesT + subscriptions subscriptionsT + pinger pingerT[struct{}] +} + +type argsT struct{ + databasePath string + prefix string + command string + allArgs []string + args []string + topic string + consumer string +} + +type commandT struct{ + name string + getopt func(argsT, io.Writer) (argsT, bool) + exec func(argsT, queriesT, io.Reader, io.Writer) (int, error) +} + +type IQueue interface{ + Publish(UnsentMessage) (Message, error) + Subscribe( string, string, func(Message) error) error + Unsubscribe(string, string) + WaitFor(string, guuid.UUID, string) Waiter + Close() error +} + + + +func closeNoop() error { + return nil +} + +func tryRollback(db *sql.DB, ctx context.Context, err error) error { + _, rollbackErr := db.ExecContext(ctx, "ROLLBACK;") + if rollbackErr != nil { + return fmt.Errorf( + rollbackErrorFmt, + rollbackErr, + err, + ) + } + + return err +} + +func inTx(db *sql.DB, fn func(context.Context) error) error { + ctx := context.Background() + + _, err := db.ExecContext(ctx, "BEGIN IMMEDIATE;") + if err != nil { + return err + } + + err = fn(ctx) + if err != nil { + return tryRollback(db, ctx, err) + } + + _, err = db.ExecContext(ctx, "COMMIT;") + if err != nil { + return tryRollback(db, ctx, err) + } + + return nil +} + +func createTablesSQL(prefix string) queryT { + const tmpl_write = ` + CREATE TABLE IF NOT EXISTS "%s_payloads" ( + id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, + timestamp TEXT NOT NULL DEFAULT (%s), + topic TEXT NOT NULL, + payload BLOB NOT NULL + ) STRICT; + CREATE INDEX IF NOT EXISTS "%s_payloads_topic" + ON "%s_payloads"(topic); + + CREATE TABLE IF NOT EXISTS "%s_messages" ( + id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, + timestamp TEXT NOT NULL DEFAULT (%s), + uuid BLOB NOT NULL UNIQUE, + flow_id BLOB NOT NULL, + payload_id INTEGER NOT NULL + REFERENCES "%s_payloads"(id) + ) STRICT; + CREATE INDEX IF NOT EXISTS "%s_messages_flow_id" + ON "%s_messages"(flow_id); + + CREATE TABLE IF NOT EXISTS "%s_offsets" ( + id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, + timestamp TEXT NOT NULL DEFAULT (%s), + consumer TEXT NOT NULL, + message_id INTEGER NOT NULL + REFERENCES "%s_messages"(id), + UNIQUE (consumer, message_id) + ) STRICT; + CREATE INDEX IF NOT EXISTS "%s_offsets_consumer" + ON "%s_offsets"(consumer); + + CREATE TABLE IF NOT EXISTS "%s_deadletters" ( + id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, + uuid BLOB NOT NULL UNIQUE, + consumer TEXT NOT NULL, + message_id INTEGER NOT NULL + REFERENCES "%s_messages"(id), + UNIQUE (consumer, message_id) + ) STRICT; + CREATE INDEX IF NOT EXISTS "%s_deadletters_consumer" + ON "%s_deadletters"(consumer); + + CREATE TABLE IF NOT EXISTS "%s_replays" ( + id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, + deadletter_id INTEGER NOT NULL UNIQUE + REFERENCES "%s_deadletters"(id) , + message_id INTEGER NOT NULL UNIQUE + REFERENCES "%s_messages"(id) + ) STRICT; + + CREATE TABLE IF NOT EXISTS "%s_owners" ( + id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, + topic TEXT NOT NULL, + consumer TEXT NOT NULL, + owner_id INTEGER NOT NULL, + UNIQUE (topic, consumer) + ) STRICT; + ` + return queryT{ + write: fmt.Sprintf( + tmpl_write, + prefix, + g.SQLiteNow, + prefix, + prefix, + prefix, + g.SQLiteNow, + prefix, + prefix, + prefix, + prefix, + g.SQLiteNow, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + ), + } +} + +func createTables(db *sql.DB, prefix string) error { + q := createTablesSQL(prefix) + + return inTx(db, func(ctx context.Context) error { + _, err := db.ExecContext(ctx, q.write) + return err + }) +} + +func takeSQL(prefix string) queryT { + const tmpl_write = ` + INSERT INTO "%s_owners" (topic, consumer, owner_id) + VALUES (?, ?, ?) + ON CONFLICT (topic, consumer) DO + UPDATE SET owner_id=excluded.owner_id; + ` + return queryT{ + write: fmt.Sprintf(tmpl_write, prefix), + } +} + +func takeStmt( + db *sql.DB, + prefix string, + instanceID int, +) (func(string, string) error, func() error, error) { + q := takeSQL(prefix) + + fn := func(topic string, consumer string) error { + return inTx(db, func(ctx context.Context) error { + _, err := db.ExecContext( + ctx, + q.write, + topic, + consumer, + instanceID, + ) + return err + }) + } + + return fn, closeNoop, nil +} + +func publishSQL(prefix string) queryT { + const tmpl_write = ` + INSERT INTO "%s_payloads" (topic, payload) + VALUES (?, ?); + + INSERT INTO "%s_messages" (uuid, flow_id, payload_id) + VALUES (?, ?, last_insert_rowid()); + ` + const tmpl_read = ` + SELECT id, timestamp FROM "%s_messages" + WHERE uuid = ?; + ` + return queryT{ + write: fmt.Sprintf(tmpl_write, prefix, prefix), + read: fmt.Sprintf(tmpl_read, prefix), + } +} + +func publishStmt( + db *sql.DB, + prefix string, + _ int, +) (func(UnsentMessage, guuid.UUID) (messageT, error), func() error, error) { + q := publishSQL(prefix) + + readStmt, err := db.Prepare(q.read) + if err != nil { + return nil, nil, err + } + + fn := func( + unsentMessage UnsentMessage, + messageID guuid.UUID, + ) (messageT, error) { + message := messageT{ + uuid: messageID, + topic: unsentMessage.Topic, + flowID: unsentMessage.FlowID, + payload: unsentMessage.Payload, + } + + message_id_bytes := messageID[:] + flow_id_bytes := unsentMessage.FlowID[:] + _, err := db.Exec( + q.write, + unsentMessage.Topic, + unsentMessage.Payload, + message_id_bytes, + flow_id_bytes, + ) + if err != nil { + return messageT{}, err + } + + var timestr string + err = readStmt.QueryRow(message_id_bytes).Scan( + &message.id, + ×tr, + ) + if err != nil { + return messageT{}, err + } + + message.timestamp, err = time.Parse(time.RFC3339Nano, timestr) + if err != nil { + return messageT{}, err + } + + return message, nil + } + + return fn, readStmt.Close, nil +} + +func findSQL(prefix string) queryT { + const tmpl_read = ` + SELECT + "%s_messages".id, + "%s_messages".timestamp, + "%s_messages".uuid, + "%s_payloads".payload + FROM "%s_messages" + JOIN "%s_payloads" ON + "%s_payloads".id = "%s_messages".payload_id + WHERE + "%s_payloads".topic = ? AND + "%s_messages".flow_id = ? + ORDER BY "%s_messages".id DESC + LIMIT 1; + ` + return queryT{ + read: fmt.Sprintf( + tmpl_read, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + ), + } +} + +func findStmt( + db *sql.DB, + prefix string, + _ int, +) (func(string, guuid.UUID) (messageT, error), func() error, error) { + q := findSQL(prefix) + + readStmt, err := db.Prepare(q.read) + if err != nil { + return nil, nil, err + } + + fn := func(topic string, flowID guuid.UUID) (messageT, error) { + message := messageT{ + topic: topic, + flowID: flowID, + } + + var ( + timestr string + message_id_bytes []byte + ) + flow_id_bytes := flowID[:] + err = readStmt.QueryRow(topic, flow_id_bytes).Scan( + &message.id, + ×tr, + &message_id_bytes, + &message.payload, + ) + if err != nil { + return messageT{}, err + } + message.uuid = guuid.UUID(message_id_bytes) + + message.timestamp, err = time.Parse(time.RFC3339Nano, timestr) + if err != nil { + return messageT{}, err + } + + return message, nil + } + + return fn, readStmt.Close, nil +} + +func nextSQL(prefix string) queryT { + const tmpl_read = ` + SELECT + "%s_messages".id, + "%s_messages".timestamp, + "%s_messages".uuid, + "%s_messages".flow_id, + "%s_payloads".payload + FROM "%s_messages" + JOIN "%s_payloads" ON + "%s_payloads".id = "%s_messages".payload_id + WHERE + "%s_payloads".topic = ? AND + "%s_messages".id NOT IN ( + SELECT message_id FROM "%s_offsets" + WHERE consumer = ? + ) + ORDER BY "%s_messages".id ASC + LIMIT 1; + ` + const tmpl_owner = ` + SELECT owner_id FROM "%s_owners" + WHERE + topic = ? AND + consumer = ?; + ` + return queryT{ + read: fmt.Sprintf( + tmpl_read, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + ), + owner: fmt.Sprintf(tmpl_owner, prefix), + } +} + +func nextStmt( + db *sql.DB, + prefix string, + instanceID int, +) (func(string, string) (messageT, error), func() error, error) { + q := nextSQL(prefix) + + fn := func(topic string, consumer string) (messageT, error) { + message := messageT{ + topic: topic, + } + + var ( + err error + ownerID int + timestr string + message_id_bytes []byte + flow_id_bytes []byte + ) + tx, err := db.Begin() + if err != nil { + return messageT{}, err + } + defer tx.Rollback() + + err = tx.QueryRow(q.owner, topic, consumer).Scan(&ownerID) + if err != nil { + return messageT{}, err + } + + if ownerID != instanceID { + err := fmt.Errorf( + notOwnerErrorFmt, + ownerID, + topic, + consumer, + instanceID, + ) + return messageT{}, err + } + + err = tx.QueryRow(q.read, topic, consumer).Scan( + &message.id, + ×tr, + &message_id_bytes, + &flow_id_bytes, + &message.payload, + ) + if err != nil { + return messageT{}, err + } + message.uuid = guuid.UUID(message_id_bytes) + message.flowID = guuid.UUID(flow_id_bytes) + + message.timestamp, err = time.Parse(time.RFC3339Nano, timestr) + if err != nil { + return messageT{}, err + } + + return message, nil + } + + return fn, closeNoop, nil +} + +func messageEach(rows *sql.Rows, callback func(messageT) error) error { + if rows == nil { + return nil + } + + for rows.Next() { + var ( + message messageT + timestr string + message_id_bytes []byte + flow_id_bytes []byte + ) + err := rows.Scan( + &message.id, + ×tr, + &message_id_bytes, + &flow_id_bytes, + &message.topic, + &message.payload, + ) + if err != nil { + return g.WrapErrors(rows.Close(), err) + } + message.uuid = guuid.UUID(message_id_bytes) + message.flowID = guuid.UUID(flow_id_bytes) + + message.timestamp, err = time.Parse(time.RFC3339Nano, timestr) + if err != nil { + return g.WrapErrors(rows.Close(), err) + } + + err = callback(message) + if err != nil { + return g.WrapErrors(rows.Close(), err) + } + } + + return g.WrapErrors(rows.Err(), rows.Close()) +} + +func pendingSQL(prefix string) queryT { + const tmpl_read = ` + SELECT + "%s_messages".id, + "%s_messages".timestamp, + "%s_messages".uuid, + "%s_messages".flow_id, + "%s_payloads".topic, + "%s_payloads".payload + FROM "%s_messages" + JOIN "%s_payloads" ON + "%s_payloads".id = "%s_messages".payload_id + WHERE + "%s_payloads".topic = ? AND + "%s_messages".id NOT IN ( + SELECT message_id FROM "%s_offsets" + WHERE consumer = ? + ) + ORDER BY "%s_messages".id ASC; + ` + const tmpl_owner = ` + SELECT owner_id FROM "%s_owners" + WHERE + topic = ? AND + consumer = ?; + ` + return queryT{ + read: fmt.Sprintf( + tmpl_read, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + ), + owner: fmt.Sprintf(tmpl_owner, prefix), + } +} + +func pendingStmt( + db *sql.DB, + prefix string, + instanceID int, +) (func(string, string) (*sql.Rows, error), func() error, error) { + q := pendingSQL(prefix) + + readStmt, err := db.Prepare(q.read) + if err != nil { + return nil, nil, err + } + + ownerStmt, err := db.Prepare(q.owner) + if err != nil { + return nil, nil, err + } + + fn := func(topic string, consumer string) (*sql.Rows, error) { + var ownerID int + err := ownerStmt.QueryRow(topic, consumer).Scan(&ownerID) + if err != nil { + return nil, err + } + + // best effort check, the final one is done during + // commit within a transaction + if ownerID != instanceID { + return nil, nil + } + + return readStmt.Query(topic, consumer) + } + + closeFn := func() error { + return g.SomeFnError(readStmt.Close, ownerStmt.Close) + } + + return fn, closeFn, nil +} + +func commitSQL(prefix string) queryT { + const tmpl_write = ` + INSERT INTO "%s_offsets" (consumer, message_id) + VALUES (?, (SELECT id FROM "%s_messages" WHERE uuid = ?)); + ` + const tmpl_read = ` + SELECT "%s_payloads".topic from "%s_payloads" + JOIN "%s_messages" ON + "%s_payloads".id = "%s_messages".payload_id + WHERE "%s_messages".uuid = ?; + ` + const tmpl_owner = ` + SELECT owner_id FROM "%s_owners" + WHERE + topic = ? AND + consumer = ?; + ` + return queryT{ + write: fmt.Sprintf(tmpl_write, prefix, prefix), + read: fmt.Sprintf( + tmpl_read, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + ), + owner: fmt.Sprintf(tmpl_owner, prefix), + } +} + +func commitStmt( + db *sql.DB, + prefix string, + instanceID int, +) (func(string, guuid.UUID) error, func() error, error) { + q := commitSQL(prefix) + + fn := func(consumer string, messageID guuid.UUID) error { + message_id_bytes := messageID[:] + return inTx(db, func(ctx context.Context) error { + var topic string + err := db.QueryRowContext( + ctx, + q.read, + message_id_bytes, + ).Scan(&topic) + if err != nil { + return err + } + + var ownerID int + err = db.QueryRowContext( + ctx, + q.owner, + topic, + consumer, + ).Scan(&ownerID) + if err != nil { + return err + } + + if ownerID != instanceID { + return fmt.Errorf( + noLongerOwnerErrorFmt, + instanceID, + topic, + consumer, + ownerID, + ) + } + + _, err = db.ExecContext(ctx, q.write, consumer, message_id_bytes) + return err + }) + } + + return fn, closeNoop, nil +} + +func toDeadSQL(prefix string) queryT { + const tmpl_write = ` + INSERT INTO "%s_offsets" ( consumer, message_id) + VALUES ( ?, (SELECT id FROM "%s_messages" WHERE uuid = ?)); + + INSERT INTO "%s_deadletters" (uuid, consumer, message_id) + VALUES (?, ?, (SELECT id FROM "%s_messages" WHERE uuid = ?)); + ` + const tmpl_read = ` + SELECT "%s_payloads".topic FROM "%s_payloads" + JOIN "%s_messages" ON + "%s_payloads".id = "%s_messages".payload_id + WHERE "%s_messages".uuid = ?; + ` + const tmpl_owner = ` + SELECT owner_id FROM "%s_owners" + WHERE + topic = ? AND + consumer = ?; + ` + return queryT{ + write: fmt.Sprintf(tmpl_write, prefix, prefix, prefix, prefix), + read: fmt.Sprintf( + tmpl_read, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + ), + owner: fmt.Sprintf(tmpl_owner, prefix), + } +} + +func toDeadStmt( + db *sql.DB, + prefix string, + instanceID int, +) ( + func(string, guuid.UUID, guuid.UUID) error, + func() error, + error, +) { + q := toDeadSQL(prefix) + + fn := func( + consumer string, + messageID guuid.UUID, + deadletterID guuid.UUID, + ) error { + message_id_bytes := messageID[:] + deadletter_id_bytes := deadletterID[:] + return inTx(db, func(ctx context.Context) error { + var topic string + err := db.QueryRowContext( + ctx, + q.read, + message_id_bytes, + ).Scan(&topic) + if err != nil { + return err + } + + var ownerID int + err = db.QueryRowContext( + ctx, + q.owner, + topic, + consumer, + ).Scan(&ownerID) + if err != nil { + return err + } + + if ownerID != instanceID { + return fmt.Errorf( + noLongerOwnerErrorFmt, + instanceID, + topic, + consumer, + ownerID, + ) + } + + _, err = db.ExecContext( + ctx, + q.write, + consumer, + message_id_bytes, + deadletter_id_bytes, + consumer, + message_id_bytes, + ) + return err + }) + } + + return fn, closeNoop, nil +} + +func replaySQL(prefix string) queryT { + const tmpl_write = ` + INSERT INTO "%s_messages" (uuid, flow_id, payload_id) + SELECT + ?, + "%s_messages".flow_id, + "%s_messages".payload_id + FROM "%s_messages" + JOIN "%s_deadletters" ON + "%s_messages".id = "%s_deadletters".message_id + WHERE "%s_deadletters".uuid = ?; + + INSERT INTO "%s_replays" (deadletter_id, message_id) + VALUES ( + (SELECT id FROM "%s_deadletters" WHERE uuid = ?), + last_insert_rowid() + ); + ` + const tmpl_read = ` + SELECT + "%s_messages".id, + "%s_messages".timestamp, + "%s_messages".flow_id, + "%s_payloads".topic, + "%s_payloads".payload + FROM "%s_messages" + JOIN "%s_payloads" ON + "%s_payloads".id = "%s_messages".payload_id + WHERE "%s_messages".uuid = ?; + ` + return queryT{ + write: fmt.Sprintf( + tmpl_write, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + ), + read: fmt.Sprintf( + tmpl_read, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + ), + } +} + +func replayStmt( + db *sql.DB, + prefix string, + _ int, +) (func(guuid.UUID, guuid.UUID) (messageT, error), func() error, error) { + q := replaySQL(prefix) + + readStmt, err := db.Prepare(q.read) + if err != nil { + return nil, nil, err + } + + fn := func( + deadletterID guuid.UUID, + messageID guuid.UUID, + ) (messageT, error) { + deadletter_id_bytes := deadletterID[:] + message_id_bytes := messageID[:] + err := inTx(db, func(ctx context.Context) error { + _, err := db.ExecContext( + ctx, + q.write, + message_id_bytes, + deadletter_id_bytes, + deadletter_id_bytes, + ) + return err + }) + if err != nil { + return messageT{}, err + } + + message := messageT{ + uuid: messageID, + } + + var ( + timestr string + flow_id_bytes []byte + ) + err = readStmt.QueryRow(message_id_bytes).Scan( + &message.id, + ×tr, + &flow_id_bytes, + &message.topic, + &message.payload, + ) + if err != nil { + return messageT{}, err + } + message.flowID = guuid.UUID(flow_id_bytes) + + message.timestamp, err = time.Parse(time.RFC3339Nano, timestr) + if err != nil { + return messageT{}, err + } + + return message, nil + } + + return fn, readStmt.Close, nil +} + +func oneDeadSQL(prefix string) queryT { + const tmpl_read = ` + SELECT + "%s_deadletters".uuid, + "%s_offsets".timestamp, + "%s_messages".uuid + FROM "%s_deadletters" + JOIN "%s_offsets" ON + "%s_deadletters".message_id = "%s_offsets".message_id + JOIN "%s_messages" ON + "%s_deadletters".message_id = "%s_messages".id + JOIN "%s_payloads" ON + "%s_messages".payload_id = "%s_payloads".id + WHERE + "%s_payloads".topic = ? AND + "%s_deadletters".consumer = ? AND + "%s_offsets".consumer = ? AND + "%s_deadletters".id NOT IN ( + SELECT deadletter_id FROM "%s_replays" + ) + ORDER BY "%s_deadletters".id ASC + LIMIT 1; + ` + return queryT{ + read: fmt.Sprintf( + tmpl_read, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + ), + } +} + +func oneDeadStmt( + db *sql.DB, + prefix string, + _ int, +) (func(string, string) (deadletterT, error), func() error, error) { + q := oneDeadSQL(prefix) + + readStmt, err := db.Prepare(q.read) + if err != nil { + return nil, nil, err + } + + fn := func(topic string, consumer string) (deadletterT, error) { + deadletter := deadletterT{ + consumer: consumer, + } + + var ( + deadletter_id_bytes []byte + timestr string + message_id_bytes []byte + ) + err := readStmt.QueryRow(topic, consumer, consumer).Scan( + &deadletter_id_bytes, + ×tr, + &message_id_bytes, + ) + if err != nil { + return deadletterT{}, err + } + deadletter.uuid = guuid.UUID(deadletter_id_bytes) + deadletter.messageID = guuid.UUID(message_id_bytes) + + deadletter.timestamp, err = time.Parse( + time.RFC3339Nano, + timestr, + ) + if err != nil { + return deadletterT{}, err + } + + return deadletter, nil + } + + return fn, readStmt.Close, nil +} + +func deadletterEach( + rows *sql.Rows, + callback func(deadletterT, messageT) error, +) error { + for rows.Next() { + var ( + deadletter deadletterT + deadletter_id_bytes []byte + deadletterTimestr string + message messageT + messageTimestr string + message_id_bytes []byte + flow_id_bytes []byte + ) + err := rows.Scan( + &deadletter_id_bytes, + &message.id, + &deadletterTimestr, + &deadletter.consumer, + &messageTimestr, + &message_id_bytes, + &flow_id_bytes, + &message.topic, + &message.payload, + ) + if err != nil { + return g.WrapErrors(rows.Close(), err) + } + + deadletter.uuid = guuid.UUID(deadletter_id_bytes) + deadletter.messageID = guuid.UUID(message_id_bytes) + message.uuid = guuid.UUID(message_id_bytes) + message.flowID = guuid.UUID(flow_id_bytes) + + message.timestamp, err = time.Parse( + time.RFC3339Nano, + messageTimestr, + ) + if err != nil { + return g.WrapErrors(rows.Close(), err) + } + + deadletter.timestamp, err = time.Parse( + time.RFC3339Nano, + deadletterTimestr, + ) + if err != nil { + return g.WrapErrors(rows.Close(), err) + } + + err = callback(deadletter, message) + if err != nil { + return g.WrapErrors(rows.Close(), err) + } + } + + return g.WrapErrors(rows.Err(), rows.Close()) +} + +func allDeadSQL(prefix string) queryT { + const tmpl_read = ` + SELECT + "%s_deadletters".uuid, + "%s_deadletters".message_id, + "%s_offsets".timestamp, + "%s_offsets".consumer, + "%s_messages".timestamp, + "%s_messages".uuid, + "%s_messages".flow_id, + "%s_payloads".topic, + "%s_payloads".payload + FROM "%s_deadletters" + JOIN "%s_offsets" ON + "%s_deadletters".message_id = "%s_offsets".message_id + JOIN "%s_messages" ON + "%s_deadletters".message_id = "%s_messages".id + JOIN "%s_payloads" ON + "%s_messages".payload_id = "%s_payloads".id + WHERE + "%s_payloads".topic = ? AND + "%s_deadletters".consumer = ? AND + "%s_offsets".consumer = ? AND + "%s_deadletters".id NOT IN ( + SELECT deadletter_id FROM "%s_replays" + ) + ORDER BY "%s_deadletters".id ASC; + ` + return queryT{ + read: fmt.Sprintf( + tmpl_read, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + ), + } +} + +func allDeadStmt( + db *sql.DB, + prefix string, + _ int, +) (func(string, string) (*sql.Rows, error), func() error, error) { + q := allDeadSQL(prefix) + + readStmt, err := db.Prepare(q.read) + if err != nil { + return nil, nil, err + } + + fn := func(topic string, consumer string) (*sql.Rows, error) { + return readStmt.Query(topic, consumer, consumer) + } + + return fn, readStmt.Close, nil +} + +func sizeSQL(prefix string) queryT { + const tmpl_read = ` + SELECT + COUNT(1) as size + FROM "%s_messages" + JOIN "%s_payloads" ON + "%s_messages".payload_id = "%s_payloads".id + WHERE "%s_payloads".topic = ?; + ` + return queryT{ + read: fmt.Sprintf( + tmpl_read, + prefix, + prefix, + prefix, + prefix, + prefix, + ), + } +} + + +func sizeStmt( + db *sql.DB, + prefix string, + _ int, +) (func(string) (int, error), func() error, error) { + q := sizeSQL(prefix) + + readStmt, err := db.Prepare(q.read) + if err != nil { + return nil, nil, err + } + + fn := func(topic string) (int, error) { + var size int + err := readStmt.QueryRow(topic).Scan(&size) + if err != nil { + return -1, err + } + + return size, nil + } + + return fn, readStmt.Close, nil +} + +func countSQL(prefix string) queryT { + const tmpl_read = ` + SELECT + COUNT(1) as count + FROM "%s_messages" + JOIN "%s_offsets" ON + "%s_messages".id = "%s_offsets".message_id + JOIN "%s_payloads" ON + "%s_messages".payload_id = "%s_payloads".id + WHERE + "%s_payloads".topic = ? AND + "%s_offsets".consumer = ?; + ` + return queryT{ + read: fmt.Sprintf( + tmpl_read, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + ), + } +} + +func countStmt( + db *sql.DB, + prefix string, + _ int, +) (func(string, string) (int, error), func() error, error) { + q := countSQL(prefix) + + readStmt, err := db.Prepare(q.read) + if err != nil { + return nil, nil, err + } + + fn := func(topic string, consumer string) (int, error) { + var count int + err := readStmt.QueryRow(topic, consumer).Scan(&count) + if err != nil { + return -1, err + } + + return count, nil + } + + return fn, readStmt.Close, nil +} + +func hasDataSQL(prefix string) queryT { + const tmpl_read = ` + SELECT 1 as data + FROM "%s_messages" + JOIN "%s_payloads" ON + "%s_payloads".id = "%s_messages".payload_id + WHERE + "%s_payloads".topic = ? AND + "%s_messages".id NOT IN ( + SELECT message_id FROM "%s_offsets" + WHERE consumer = ? + ) + LIMIT 1; + ` + return queryT{ + read: fmt.Sprintf( + tmpl_read, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + ), + } +} + +func hasDataStmt( + db *sql.DB, + prefix string, + _ int, +) (func(string, string) (bool, error), func() error, error) { + q := hasDataSQL(prefix) + + readStmt, err := db.Prepare(q.read) + if err != nil { + return nil, nil, err + } + + fn := func(topic string, consumer string) (bool, error) { + var _x int + err := readStmt.QueryRow(topic, consumer).Scan(&_x) + if err == sql.ErrNoRows { + return false, nil + } + + if err != nil { + return false, err + } + + return true, nil + } + + return fn, readStmt.Close, nil +} + +func initDB( + db *sql.DB, + prefix string, + notifyFn func(messageT), + instanceID int, +) (queriesT, error) { + createTablesErr := createTables(db, prefix) + take, takeClose, takeErr := takeStmt(db, prefix, instanceID) + publish, publishClose, publishErr := publishStmt(db, prefix, instanceID) + find, findClose, findErr := findStmt(db, prefix, instanceID) + next, nextClose, nextErr := nextStmt(db, prefix, instanceID) + pending, pendingClose, pendingErr := pendingStmt(db, prefix, instanceID) + commit, commitClose, commitErr := commitStmt(db, prefix, instanceID) + toDead, toDeadClose, toDeadErr := toDeadStmt(db, prefix, instanceID) + replay, replayClose, replayErr := replayStmt(db, prefix, instanceID) + oneDead, oneDeadClose, oneDeadErr := oneDeadStmt(db, prefix, instanceID) + allDead, allDeadClose, allDeadErr := allDeadStmt(db, prefix, instanceID) + size, sizeClose, sizeErr := sizeStmt(db, prefix, instanceID) + count, countClose, countErr := countStmt(db, prefix, instanceID) + hasData, hasDataClose, hasDataErr := hasDataStmt(db, prefix, instanceID) + + err := g.SomeError( + createTablesErr, + takeErr, + publishErr, + findErr, + nextErr, + pendingErr, + commitErr, + toDeadErr, + replayErr, + oneDeadErr, + allDeadErr, + sizeErr, + countErr, + hasDataErr, + ) + if err != nil { + return queriesT{}, err + } + + close := func() error { + return g.SomeFnError( + takeClose, + publishClose, + findClose, + nextClose, + pendingClose, + commitClose, + toDeadClose, + replayClose, + oneDeadClose, + allDeadClose, + sizeClose, + countClose, + hasDataClose, + ) + } + + var connMutex sync.RWMutex + return queriesT{ + take: func(a string, b string) error { + connMutex.RLock() + defer connMutex.RUnlock() + return take(a, b) + }, + publish: func(a UnsentMessage, b guuid.UUID) (messageT, error) { + var ( + err error + message messageT + ) + { + connMutex.RLock() + defer connMutex.RUnlock() + message, err = publish(a, b) + } + if err != nil { + return messageT{}, err + } + + go notifyFn(message) + return message, nil + }, + find: func(a string, b guuid.UUID) (messageT, error) { + connMutex.RLock() + defer connMutex.RUnlock() + return find(a, b) + }, + next: func(a string, b string) (messageT, error) { + connMutex.RLock() + defer connMutex.RUnlock() + return next(a, b) + }, + pending: func( + a string, + b string, + callback func(messageT) error, + ) error { + var ( + err error + rows *sql.Rows + ) + { + connMutex.RLock() + defer connMutex.RUnlock() + rows, err = pending(a, b) + } + if err != nil { + return err + } + + return messageEach(rows, callback) + }, + commit: func(a string, b guuid.UUID) error { + connMutex.RLock() + defer connMutex.RUnlock() + return commit(a, b) + }, + toDead: func( + a string, + b guuid.UUID, + c guuid.UUID, + ) error { + connMutex.RLock() + defer connMutex.RUnlock() + return toDead(a, b, c) + }, + replay: func(a guuid.UUID, b guuid.UUID) (messageT, error) { + var ( + err error + message messageT + ) + { + connMutex.RLock() + defer connMutex.RUnlock() + message, err = replay(a, b) + } + if err != nil { + return messageT{}, err + } + + go notifyFn(message) + return message, nil + }, + oneDead: func(a string, b string) (deadletterT, error) { + connMutex.RLock() + defer connMutex.RUnlock() + return oneDead(a, b) + }, + allDead: func( + a string, + b string, + callback func(deadletterT, messageT) error, + ) error { + var ( + err error + rows *sql.Rows + ) + { + connMutex.RLock() + defer connMutex.RUnlock() + rows, err = allDead(a, b) + } + if err != nil { + return err + } + + return deadletterEach(rows, callback) + }, + size: func(a string) (int, error) { + connMutex.RLock() + defer connMutex.RUnlock() + return size(a) + }, + count: func(a string, b string) (int, error) { + connMutex.RLock() + defer connMutex.RUnlock() + return count(a, b) + }, + hasData: func(a string, b string) (bool, error) { + connMutex.RLock() + defer connMutex.RUnlock() + return hasData(a, b) + }, + close: func() error { + connMutex.Lock() + defer connMutex.Unlock() + return close() + }, + }, nil +} + + +func newPinger[T any]() pingerT[T] { + channel := make(chan T, 1) + closed := false + var rwmutex sync.RWMutex + return pingerT[T]{ + tryPing: func(x T) { + rwmutex.RLock() + defer rwmutex.RUnlock() + if closed { + return + } + select { + case channel <- x: + default: + } + }, + onPing: func(cb func(T)) { + for x := range channel { + cb(x) + } + }, + closed: func() bool { + rwmutex.RLock() + defer rwmutex.RUnlock() + return closed + }, + close: func() { + rwmutex.Lock() + defer rwmutex.Unlock() + if closed { + return + } + close(channel) + closed = true + }, + } +} + +func makeSubscriptionsFuncs() subscriptionsT { + var rwmutex sync.RWMutex + subscriptions := subscriptionsSetM{} + return subscriptionsT{ + read: func(callback func(subscriptionsSetM) error) error { + rwmutex.RLock() + defer rwmutex.RUnlock() + return callback(subscriptions) + }, + write: func(callback func(subscriptionsSetM) error) error { + rwmutex.Lock() + defer rwmutex.Unlock() + return callback(subscriptions) + }, + } +} + +// Try notifying the consumer that they have data to work with. If they're +// already full, we simply drop the notification, as on each they'll look for +// all pending messages and process them all. So dropping the event here +// doesn't mean not notifying the consumer, but simply acknoledging that the +// existing notifications are enough for them to work with, without letting any +// message slip through. +func makeNotifyFn( + readFn func(func(subscriptionsSetM) error) error, + pinger pingerT[struct{}], +) func(messageT) { + return func(message messageT) { + readFn(func(set subscriptionsSetM) error { + topicSub, ok := set[message.topic] + if !ok { + return nil + } + + for _, consumer := range topicSub.consumers { + consumer.pinger.tryPing(struct{}{}) + } + waiters := topicSub.waiters[message.flowID] + for _, waiter := range waiters { + waiter.pinger.tryPing(message.payload) + } + return nil + }) + pinger.tryPing(struct{}{}) + } +} + +func collectClosedWaiters( + set subscriptionsSetM, +) map[string]map[guuid.UUID][]string { + waiters := map[string]map[guuid.UUID][]string{} + for topic, topicSub := range set { + waiters[topic] = map[guuid.UUID][]string{} + for flowID, waitersByName := range topicSub.waiters { + names := []string{} + for name, waiter := range waitersByName { + if (*waiter.closed)() { + names = append(names, name) + } + } + waiters[topic][flowID] = names + } + } + + return waiters +} + +func trimEmptyLeaves(closedWaiters map[string]map[guuid.UUID][]string) { + for topic, waiters := range closedWaiters { + for flowID, names := range waiters { + if len(names) == 0 { + delete(closedWaiters[topic], flowID) + } + } + if len(waiters) == 0 { + delete(closedWaiters, topic) + } + } +} + +func deleteIfEmpty(set subscriptionsSetM, topic string) { + topicSub, ok := set[topic] + if !ok { + return + } + + emptyConsumers := len(topicSub.consumers) == 0 + emptyWaiters := len(topicSub.waiters) == 0 + if emptyConsumers && emptyWaiters { + delete(set, topic) + } +} + +func deleteEmptyTopics(set subscriptionsSetM) { + for topic, _ := range set { + deleteIfEmpty(set, topic) + } +} + +func removeClosedWaiters( + set subscriptionsSetM, + closedWaiters map[string]map[guuid.UUID][]string, +) { + for topic, waiters := range closedWaiters { + _, ok := set[topic] + if !ok { + continue + } + + for flowID, names := range waiters { + if set[topic].waiters[flowID] == nil { + continue + } + for _, name := range names { + delete(set[topic].waiters[flowID], name) + } + if len(set[topic].waiters[flowID]) == 0 { + delete(set[topic].waiters, flowID) + } + } + } + + deleteEmptyTopics(set) +} + +func reapClosedWaiters( + readFn func(func(subscriptionsSetM) error) error, + writeFn func(func(subscriptionsSetM) error) error, +) { + var closedWaiters map[string]map[guuid.UUID][]string + readFn(func(set subscriptionsSetM) error { + closedWaiters = collectClosedWaiters(set) + return nil + }) + + trimEmptyLeaves(closedWaiters) + if len(closedWaiters) == 0 { + return + } + + writeFn(func(set subscriptionsSetM) error { + removeClosedWaiters(set, closedWaiters) + return nil + }) +} + +func everyNthCall[T any](n int, fn func(T)) func(T) { + i := 0 + return func(x T) { + i++ + if i == n { + i = 0 + fn(x) + } + } +} + +func runReaper( + onPing func(func(struct{})), + readFn func(func(subscriptionsSetM) error) error, + writeFn func(func(subscriptionsSetM) error) error, +) { + onPing(everyNthCall(reaperSkipCount, func(struct{}) { + reapClosedWaiters(readFn, writeFn) + })) +} + +func NewWithPrefix(db *sql.DB, prefix string) (IQueue, error) { + err := g.ValidateSQLTablePrefix(prefix) + if err != nil { + return queueT{}, err + } + + subscriptions := makeSubscriptionsFuncs() + pinger := newPinger[struct{}]() + notifyFn := makeNotifyFn(subscriptions.read, pinger) + queries, err := initDB(db, prefix, notifyFn, os.Getpid()) + if err != nil { + return queueT{}, err + } + + go runReaper(pinger.onPing, subscriptions.read, subscriptions.write) + + return queueT{ + queries: queries, + subscriptions: subscriptions, + pinger: pinger, + }, nil +} + +func New(db *sql.DB) (IQueue, error) { + return NewWithPrefix(db, defaultPrefix) +} + +func asPublicMessage(message messageT) Message { + return Message{ + ID: message.uuid, + Timestamp: message.timestamp, + Topic: message.topic, + FlowID: message.flowID, + Payload: message.payload, + } +} + +func (queue queueT) Publish(unsent UnsentMessage) (Message, error) { + message, err := queue.queries.publish(unsent, guuid.New()) + return asPublicMessage(message), err +} + +func registerConsumerFn(consumer consumerT) func(subscriptionsSetM) error { + topicSub := topicSubscriptionT{ + consumers: map[string]consumerT{}, + waiters: map[guuid.UUID]map[string]waiterT{}, + } + + return func(set subscriptionsSetM) error { + topic := consumer.data.topic + _, ok := set[topic] + if !ok { + set[topic] = topicSub + } + set[topic].consumers[consumer.data.name] = consumer + + return nil + } +} + +func registerWaiterFn(waiter waiterT) func(subscriptionsSetM) error { + topicSub := topicSubscriptionT{ + consumers: map[string]consumerT{}, + waiters: map[guuid.UUID]map[string]waiterT{}, + } + waiters := map[string]waiterT{} + + return func(set subscriptionsSetM) error { + var ( + topic = waiter.data.topic + flowID = waiter.data.flowID + ) + _, ok := set[topic] + if !ok { + set[topic] = topicSub + } + if set[topic].waiters[flowID] == nil { + set[topic].waiters[flowID] = waiters + } + set[topic].waiters[flowID][waiter.data.name] = waiter + return nil + } +} + +func makeConsumeOneFn( + data consumerDataT, + callback func(Message) error, + successFn func(string, guuid.UUID) error, + errorFn func(string, guuid.UUID, guuid.UUID) error, +) func(messageT) error { + return func(message messageT) error { + err := callback(asPublicMessage(message)) + if err != nil { + g.Info( + "consumer failed", "q-consumer", + "topic", data.topic, + "consumer", data.name, + "error", err, + slog.Group( + "message", + "id", message.id, + "flow-id", message.flowID.String(), + ), + ) + + return errorFn(data.name, message.uuid, guuid.New()) + } + + return successFn(data.name, message.uuid) + } +} + +func makeConsumeAllFn( + data consumerDataT, + consumeOneFn func(messageT) error, + eachFn func(string, string, func(messageT) error) error, +) func(struct{}) { + return func(struct{}) { + err := eachFn(data.topic, data.name, consumeOneFn) + if err != nil { + g.Warning( + "eachFn failed", "q-consume-all", + "topic", data.topic, + "consumer", data.name, + "error", err, + "circuit-breaker-enabled?", false, + ) + } + } +} + +func makeWaitFn(channel chan []byte, closeFn func()) func([]byte) { + closed := false + var mutex sync.Mutex + return func(payload []byte) { + mutex.Lock() + defer mutex.Unlock() + if closed { + return + } + + closeFn() + channel <- payload + close(channel) + closed = true + } +} + +func runConsumer(onPing func(func(struct{})), consumeAllFn func(struct{})) { + consumeAllFn(struct{}{}) + onPing(consumeAllFn) +} + +func tryFinding( + findFn func(string, guuid.UUID) (messageT, error), + topic string, + flowID guuid.UUID, + waitFn func([]byte), +) { + message, err := findFn(topic, flowID) + if err != nil { + return + } + + waitFn(message.payload) +} + +func (queue queueT) Subscribe( + topic string, + name string, + callback func(Message) error, +) error { + data := consumerDataT{ + topic: topic, + name: name, + } + pinger := newPinger[struct{}]() + consumer := consumerT{ + data: data, + callback: callback, + pinger: pinger, + close: &pinger.close, + } + consumeOneFn := makeConsumeOneFn( + consumer.data, + consumer.callback, + queue.queries.commit, + queue.queries.toDead, + ) + consumeAllFn := makeConsumeAllFn( + consumer.data, + consumeOneFn, + queue.queries.pending, + ) + + err := queue.queries.take(topic, name) + if err != nil { + return err + } + + queue.subscriptions.write(registerConsumerFn(consumer)) + go runConsumer(pinger.onPing, consumeAllFn) + return nil +} + +type Waiter struct{ + Channel <-chan []byte + Close func() +} + +func (queue queueT) WaitFor( + topic string, + flowID guuid.UUID, + name string, +) Waiter { + data := waiterDataT{ + topic: topic, + flowID: flowID, + name: name, + } + pinger := newPinger[[]byte]() + waiter := waiterT{ + data: data, + pinger: pinger, + closed: &pinger.closed, + close: &pinger.close, + } + channel := make(chan []byte, 1) + waitFn := makeWaitFn(channel, (*waiter.close)) + closeFn := func() { + queue.subscriptions.read(func(set subscriptionsSetM) error { + (*set[topic].waiters[flowID][name].close)() + return nil + }) + } + + queue.subscriptions.write(registerWaiterFn(waiter)) + tryFinding(queue.queries.find, topic, flowID, waitFn) + go pinger.onPing(waitFn) + return Waiter{channel, closeFn} +} + +func unsubscribeIfExistsFn( + topic string, + name string, +) func(subscriptionsSetM) error { + return func(set subscriptionsSetM) error { + topicSub, ok := set[topic] + if !ok { + return nil + } + + consumer, ok := topicSub.consumers[name] + if !ok { + return nil + } + + (*consumer.close)() + delete(set[topic].consumers, name) + deleteIfEmpty(set, topic) + return nil + } +} + +func (queue queueT) Unsubscribe(topic string, name string) { + queue.subscriptions.write(unsubscribeIfExistsFn(topic, name)) +} + +func cleanSubscriptions(set subscriptionsSetM) error { + for _, topicSub := range set { + for _, consumer := range topicSub.consumers { + (*consumer.close)() + } + for _, waiters := range topicSub.waiters { + for _, waiter := range waiters { + (*waiter.close)() + } + } + } + return nil +} + +func (queue queueT) Close() error { + queue.pinger.close() + return g.WrapErrors( + queue.subscriptions.write(cleanSubscriptions), + queue.queries.close(), + ) +} + + +func topicGetopt(args argsT, w io.Writer) (argsT, bool) { + if len(args.args) == 0 { + fmt.Fprintf(w, "Missing TOPIC.\n") + return args, false + } + + args.topic = args.args[0] + return args, true +} + +func topicConsumerGetopt(args argsT, w io.Writer) (argsT, bool) { + fs := flag.NewFlagSet("", flag.ContinueOnError) + fs.Usage = func() {} + fs.SetOutput(w) + + consumer := fs.String( + "C", + "default-consumer", + "The name of the consumer to be used", + ) + + if fs.Parse(args.args) != nil { + return args, false + } + + subArgs := fs.Args() + if len(subArgs) == 0 { + fmt.Fprintf(w, "Missing TOPIC.\n") + return args, false + } + + args.consumer = *consumer + args.topic = subArgs[0] + return args, true +} + +func inExec( + args argsT, + queries queriesT, + r io.Reader, + w io.Writer, +) (int, error) { + payload, err := io.ReadAll(r) + if err != nil { + return 1, err + } + + unsent := UnsentMessage{ + Topic: args.topic, + FlowID: guuid.New(), + Payload: payload, + } + message, err := queries.publish(unsent, guuid.New()) + if err != nil { + return 1, err + } + + fmt.Fprintf(w, "%s\n", message.uuid.String()) + + return 0, nil +} + +func outExec( + args argsT, + queries queriesT, + _ io.Reader, + w io.Writer, +) (int, error) { + err := queries.take(args.topic, args.consumer) + if err != nil { + return 1, err + } + + message, err := queries.next(args.topic, args.consumer) + + if err == sql.ErrNoRows { + return 3, nil + } + + if err != nil { + return 1, err + } + + fmt.Fprintln(w, string(message.payload)) + + return 0, nil +} + +func commitExec( + args argsT, + queries queriesT, + r io.Reader, + w io.Writer, +) (int, error) { + err := queries.take(args.topic, args.consumer) + if err != nil { + return 1, err + } + + message, err := queries.next(args.topic, args.consumer) + if err != nil { + return 1, err + } + + err = queries.commit(args.consumer, message.uuid) + if err != nil { + return 1, err + } + + return 0, nil +} + +func deadExec( + args argsT, + queries queriesT, + r io.Reader, + w io.Writer, +) (int, error) { + err := queries.take(args.topic, args.consumer) + if err != nil { + return 1, err + } + + message, err := queries.next(args.topic, args.consumer) + if err != nil { + return 1, err + } + + err = queries.toDead(args.consumer, message.uuid, guuid.New()) + if err != nil { + return 1, err + } + + return 0, nil +} + +func listDeadExec( + args argsT, + queries queriesT, + r io.Reader, + w io.Writer, +) (int, error) { + eachFn := func(deadletter deadletterT, _ messageT) error { + fmt.Fprintf( + w, + "%s\t%s\t%s\n", + deadletter.uuid.String(), + deadletter.timestamp.Format(time.RFC3339), + deadletter.consumer, + ) + return nil + } + + err := queries.allDead(args.topic, args.consumer, eachFn) + if err != nil { + return 1, err + } + + return 0, nil +} + +func replayExec( + args argsT, + queries queriesT, + r io.Reader, + w io.Writer, +) (int, error) { + deadletter, err := queries.oneDead(args.topic, args.consumer) + if err != nil { + return 1, err + } + + _, err = queries.replay(deadletter.uuid, guuid.New()) + if err != nil { + return 1, err + } + + return 0, nil +} + +func sizeExec( + args argsT, + queries queriesT, + r io.Reader, + w io.Writer, +) (int, error) { + size, err := queries.size(args.topic) + if err != nil { + return 1, err + } + + fmt.Fprintln(w, size) + + return 0, nil +} + +func countExec( + args argsT, + queries queriesT, + r io.Reader, + w io.Writer, +) (int, error) { + count, err := queries.count(args.topic, args.consumer) + if err != nil { + return 1, err + } + + fmt.Fprintln(w, count) + + return 0, nil +} + +func hasDataExec( + args argsT, + queries queriesT, + r io.Reader, + w io.Writer, +) (int, error) { + hasData, err := queries.hasData(args.topic, args.consumer) + if err != nil { + return 1, err + } + + if hasData { + return 0, nil + } else { + return 1, nil + } +} + +func usage(argv0 string, w io.Writer) { + fmt.Fprintf( + w, + "Usage: %s [-f FILE] [-p PREFIX] COMMAND [OPTIONS]\n", + argv0, + ) +} + +func getopt( + allArgs []string, + commandsMap map[string]commandT, + w io.Writer, +) (argsT, commandT, int) { + argv0 := allArgs[0] + argv := allArgs[1:] + fs := flag.NewFlagSet("", flag.ContinueOnError) + fs.Usage = func() {} + fs.SetOutput(w) + databasePath := fs.String( + "f", + "q.db", + "The path to the file where the queue is kept", + ) + prefix := fs.String( + "p", + defaultPrefix, + "The q prefix of the table names", + ) + if fs.Parse(argv) != nil { + usage(argv0, w) + return argsT{}, commandT{}, 2 + } + + subArgs := fs.Args() + if len(subArgs) == 0 { + fmt.Fprintf(w, "Missing COMMAND.\n") + usage(argv0, w) + return argsT{}, commandT{}, 2 + } + + args := argsT{ + databasePath: *databasePath, + prefix: *prefix, + command: subArgs[0], + allArgs: allArgs, + args: subArgs[1:], + } + + command := commandsMap[args.command] + if command.name == "" { + fmt.Fprintf(w, "Bad COMMAND: \"%s\".\n", args.command) + usage(allArgs[0], w) + return argsT{}, commandT{}, 2 + } + + args, ok := command.getopt(args, w) + if !ok { + usage(allArgs[0], w) + return argsT{}, commandT{}, 2 + } + + return args, command, 0 +} + +func runCommand( + args argsT, + command commandT, + stdin io.Reader, + stdout io.Writer, + stderr io.Writer, +) int { + db, err := sql.Open("acude", args.databasePath) + if err != nil { + fmt.Fprintln(stderr, err) + return 1 + } + defer db.Close() + + iqueue, err := NewWithPrefix(db, args.prefix) + if err != nil { + fmt.Fprintln(stderr, err) + return 1 + } + defer iqueue.Close() + + rc, err := command.exec(args, iqueue.(queueT).queries, stdin, stdout) + if err != nil { + fmt.Fprintln(stderr, err) + } + + return rc +} + +var commands = map[string]commandT { + "in": commandT{ + name: "in", + getopt: topicGetopt, + exec: inExec, + }, + "out": commandT{ + name: "out", + getopt: topicConsumerGetopt, + exec: outExec, + }, + "commit": commandT{ + name: "commit", + getopt: topicConsumerGetopt, + exec: commitExec, + }, + "dead": commandT{ + name: "dead", + getopt: topicConsumerGetopt, + exec: deadExec, + }, + "ls-dead": commandT{ + name: "ls-dead", + getopt: topicConsumerGetopt, + exec: listDeadExec, + }, + "replay": commandT{ + name: "replay", + getopt: topicConsumerGetopt, + exec: replayExec, + }, + "size": commandT{ + name: "size", + getopt: topicGetopt, + exec: sizeExec, + }, + "count": commandT{ + name: "count", + getopt: topicConsumerGetopt, + exec: countExec, + }, + "has-data": commandT{ + name: "has-data", + getopt: topicConsumerGetopt, + exec: hasDataExec, + }, +} + + + +func Main() { + g.Init() + args, command, rc := getopt(os.Args, commands, os.Stderr) + if rc != 0 { + os.Exit(2) + } + os.Exit(runCommand(args, command, os.Stdin, os.Stdout, os.Stderr)) +} |