author     EuAndreh <eu@euandre.org>  2024-09-17 08:01:05 -0300
committer  EuAndreh <eu@euandre.org>  2024-10-20 07:39:33 -0300
commit     ab1795aeb8f00b61c331ac77fdc1011ec14c5253 (patch)
tree       507b72b45f23f8a1bf1a1684a842fef51f1139a8 /src
parent     Init Go project skeleton with golite init (diff)
download   fiinha-ab1795aeb8f00b61c331ac77fdc1011ec14c5253.tar.gz
           fiinha-ab1795aeb8f00b61c331ac77fdc1011ec14c5253.tar.xz
Initial version: first implementation
Diffstat (limited to 'src')
-rw-r--r--  src/liteq.go   117
-rw-r--r--  src/main.go      4
-rw-r--r--  src/q.go      2495
3 files changed, 2497 insertions, 119 deletions
diff --git a/src/liteq.go b/src/liteq.go
deleted file mode 100644
index 2eeff34..0000000
--- a/src/liteq.go
+++ /dev/null
@@ -1,117 +0,0 @@
-package liteq
-
-import (
- "database/sql"
- "flag"
- "io/ioutil"
- "log/slog"
- "os"
- "sort"
-
- g "gobang"
- "golite"
-)
-
-
-
-func InitMigrations(db *sql.DB) {
- _, err := db.Exec(`
- CREATE TABLE IF NOT EXISTS migrations (
- filename TEXT PRIMARY KEY
- );
- `)
- g.FatalIf(err)
-}
-
-const MIGRATIONS_DIR = "src/sql/migrations/"
-func PendingMigrations(db *sql.DB) []string {
- files, err := ioutil.ReadDir(MIGRATIONS_DIR)
- g.FatalIf(err)
-
- set := make(map[string]bool)
- for _, file := range files {
- set[file.Name()] = true
- }
-
- rows, err := db.Query(`SELECT filename FROM migrations;`)
- g.FatalIf(err)
- defer rows.Close()
-
- for rows.Next() {
- var filename string
- err := rows.Scan(&filename)
- g.FatalIf(err)
- delete(set, filename)
- }
- g.FatalIf(rows.Err())
-
- difference := make([]string, 0)
- for filename := range set {
- difference = append(difference, filename)
- }
-
- sort.Sort(sort.StringSlice(difference))
- return difference
-}
-
-func RunMigrations(db *sql.DB) {
- InitMigrations(db)
-
- stmt, err := db.Prepare(`INSERT INTO migrations (filename) VALUES (?);`)
- g.FatalIf(err)
- defer stmt.Close()
-
- for _, filename := range PendingMigrations(db) {
- g.Info("Running migration file", "exec-migration-file",
- "filename", filename,
- )
-
- tx, err := db.Begin()
- g.FatalIf(err)
-
- sql, err := os.ReadFile(MIGRATIONS_DIR + filename)
- g.FatalIf(err)
-
- _, err = tx.Exec(string(sql))
- g.FatalIf(err)
-
- _, err = tx.Stmt(stmt).Exec(filename)
- g.FatalIf(err)
-
- err = tx.Commit()
- g.FatalIf(err)
- }
-}
-
-func initDB(databasePath string) *sql.DB {
- db, err := sql.Open("sqlite3", databasePath)
- g.FatalIf(err)
- RunMigrations(db)
- return db
-}
-
-func run(db *sql.DB) {
-}
-
-
-
-var (
- databasePath = flag.String(
- "f",
- "q.db",
- "The path to the database file",
- )
-)
-
-
-func Main() {
- g.Init(slog.Group(
- "versions",
- "gobang", g.Version,
- "golite", golite.Version,
- "this", version,
- ))
- flag.Parse()
- db := initDB(*databasePath)
- run(db)
-}
diff --git a/src/main.go b/src/main.go
index 8d9a05e..51faffa 100644
--- a/src/main.go
+++ b/src/main.go
@@ -1,7 +1,7 @@
package main
-import "liteq"
+import "q"
func main() {
- liteq.Main()
+ q.Main()
}
diff --git a/src/q.go b/src/q.go
new file mode 100644
index 0000000..6eeefe6
--- /dev/null
+++ b/src/q.go
@@ -0,0 +1,2495 @@
+package q
+import (
+ "context"
+ "database/sql"
+ "flag"
+ "fmt"
+ "io"
+ "log/slog"
+ "os"
+ "sync"
+ "time"
+
+ _ "acudego"
+ "guuid"
+ g "gobang"
+)
+
+
+
+const (
+ defaultPrefix = "q"
+ reaperSkipCount = 1000
+ notOwnerErrorFmt = "%v owns %#v as %#v, not us (%v)"
+ noLongerOwnerErrorFmt = "we (%v) no longer own %#v as %#v, but %v does"
+ rollbackErrorFmt = "rollback error: %w; while executing: %w"
+)
+
+
+
+type dbI interface{
+ findOne(q string, args []any, bindings []any) error
+ exec(q string)
+}
+
+type queryT struct{
+ write string
+ read string
+ owner string
+}
+
+type queriesT struct{
+ take func(string, string) error
+ publish func(UnsentMessage, guuid.UUID) (messageT, error)
+ find func(string, guuid.UUID) (messageT, error)
+ next func(string, string) (messageT, error)
+ pending func(string, string, func(messageT) error) error
+ commit func(string, guuid.UUID) error
+ toDead func(string, guuid.UUID, guuid.UUID) error
+ replay func(guuid.UUID, guuid.UUID) (messageT, error)
+ oneDead func(string, string) (deadletterT, error)
+ allDead func(string, string, func(deadletterT, messageT) error) error
+ size func(string) (int, error)
+ count func(string, string) (int, error)
+ hasData func(string, string) (bool, error)
+ close func() error
+}
+
+type messageT struct{
+ id int64
+ timestamp time.Time
+ uuid guuid.UUID
+ topic string
+ flowID guuid.UUID
+ payload []byte
+}
+
+type UnsentMessage struct{
+ Topic string
+ FlowID guuid.UUID
+ Payload []byte
+}
+
+type Message struct{
+ ID guuid.UUID
+ Timestamp time.Time
+ Topic string
+ FlowID guuid.UUID
+ Payload []byte
+}
+
+type deadletterT struct{
+ uuid guuid.UUID
+ timestamp time.Time
+ consumer string
+ messageID guuid.UUID
+}
+
+type pingerT[T any] struct{
+ tryPing func(T)
+ onPing func(func(T))
+ closed func() bool
+ close func()
+}
+
+type consumerDataT struct{
+ topic string
+ name string
+}
+
+type waiterDataT struct{
+ topic string
+ flowID guuid.UUID
+ name string
+
+}
+
+type consumerT struct{
+ data consumerDataT
+ callback func(Message) error
+ pinger pingerT[struct{}]
+ close *func()
+}
+
+type waiterT struct{
+ data waiterDataT
+ pinger pingerT[[]byte]
+ closed *func() bool
+ close *func()
+}
+
+type topicSubscriptionT struct{
+ consumers map[string]consumerT
+ waiters map[guuid.UUID]map[string]waiterT
+}
+
+type subscriptionsSetM map[string]topicSubscriptionT
+
+type subscriptionsT struct {
+ read func(func(subscriptionsSetM) error) error
+ write func(func(subscriptionsSetM) error) error
+}
+
+type queueT struct{
+ queries queriesT
+ subscriptions subscriptionsT
+ pinger pingerT[struct{}]
+}
+
+type argsT struct{
+ databasePath string
+ prefix string
+ command string
+ allArgs []string
+ args []string
+ topic string
+ consumer string
+}
+
+type commandT struct{
+ name string
+ getopt func(argsT, io.Writer) (argsT, bool)
+ exec func(argsT, queriesT, io.Reader, io.Writer) (int, error)
+}
+
+type IQueue interface{
+ Publish(UnsentMessage) (Message, error)
+ Subscribe( string, string, func(Message) error) error
+ Unsubscribe(string, string)
+ WaitFor(string, guuid.UUID, string) Waiter
+ Close() error
+}
+
+
+
+func closeNoop() error {
+ return nil
+}
+
+func tryRollback(db *sql.DB, ctx context.Context, err error) error {
+ _, rollbackErr := db.ExecContext(ctx, "ROLLBACK;")
+ if rollbackErr != nil {
+ return fmt.Errorf(
+ rollbackErrorFmt,
+ rollbackErr,
+ err,
+ )
+ }
+
+ return err
+}
+
+func inTx(db *sql.DB, fn func(context.Context) error) error {
+ ctx := context.Background()
+
+ _, err := db.ExecContext(ctx, "BEGIN IMMEDIATE;")
+ if err != nil {
+ return err
+ }
+
+ err = fn(ctx)
+ if err != nil {
+ return tryRollback(db, ctx, err)
+ }
+
+ _, err = db.ExecContext(ctx, "COMMIT;")
+ if err != nil {
+ return tryRollback(db, ctx, err)
+ }
+
+ return nil
+}
+
+func createTablesSQL(prefix string) queryT {
+ const tmpl_write = `
+ CREATE TABLE IF NOT EXISTS "%s_payloads" (
+ id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
+ timestamp TEXT NOT NULL DEFAULT (%s),
+ topic TEXT NOT NULL,
+ payload BLOB NOT NULL
+ ) STRICT;
+ CREATE INDEX IF NOT EXISTS "%s_payloads_topic"
+ ON "%s_payloads"(topic);
+
+ CREATE TABLE IF NOT EXISTS "%s_messages" (
+ id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
+ timestamp TEXT NOT NULL DEFAULT (%s),
+ uuid BLOB NOT NULL UNIQUE,
+ flow_id BLOB NOT NULL,
+ payload_id INTEGER NOT NULL
+ REFERENCES "%s_payloads"(id)
+ ) STRICT;
+ CREATE INDEX IF NOT EXISTS "%s_messages_flow_id"
+ ON "%s_messages"(flow_id);
+
+ CREATE TABLE IF NOT EXISTS "%s_offsets" (
+ id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
+ timestamp TEXT NOT NULL DEFAULT (%s),
+ consumer TEXT NOT NULL,
+ message_id INTEGER NOT NULL
+ REFERENCES "%s_messages"(id),
+ UNIQUE (consumer, message_id)
+ ) STRICT;
+ CREATE INDEX IF NOT EXISTS "%s_offsets_consumer"
+ ON "%s_offsets"(consumer);
+
+ CREATE TABLE IF NOT EXISTS "%s_deadletters" (
+ id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
+ uuid BLOB NOT NULL UNIQUE,
+ consumer TEXT NOT NULL,
+ message_id INTEGER NOT NULL
+ REFERENCES "%s_messages"(id),
+ UNIQUE (consumer, message_id)
+ ) STRICT;
+ CREATE INDEX IF NOT EXISTS "%s_deadletters_consumer"
+ ON "%s_deadletters"(consumer);
+
+ CREATE TABLE IF NOT EXISTS "%s_replays" (
+ id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
+ deadletter_id INTEGER NOT NULL UNIQUE
+ REFERENCES "%s_deadletters"(id) ,
+ message_id INTEGER NOT NULL UNIQUE
+ REFERENCES "%s_messages"(id)
+ ) STRICT;
+
+ CREATE TABLE IF NOT EXISTS "%s_owners" (
+ id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
+ topic TEXT NOT NULL,
+ consumer TEXT NOT NULL,
+ owner_id INTEGER NOT NULL,
+ UNIQUE (topic, consumer)
+ ) STRICT;
+ `
+ return queryT{
+ write: fmt.Sprintf(
+ tmpl_write,
+ prefix,
+ g.SQLiteNow,
+ prefix,
+ prefix,
+ prefix,
+ g.SQLiteNow,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ g.SQLiteNow,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ ),
+ }
+}
+
+func createTables(db *sql.DB, prefix string) error {
+ q := createTablesSQL(prefix)
+
+ return inTx(db, func(ctx context.Context) error {
+ _, err := db.ExecContext(ctx, q.write)
+ return err
+ })
+}
+
+func takeSQL(prefix string) queryT {
+ const tmpl_write = `
+ INSERT INTO "%s_owners" (topic, consumer, owner_id)
+ VALUES (?, ?, ?)
+ ON CONFLICT (topic, consumer) DO
+ UPDATE SET owner_id=excluded.owner_id;
+ `
+ return queryT{
+ write: fmt.Sprintf(tmpl_write, prefix),
+ }
+}
+
+func takeStmt(
+ db *sql.DB,
+ prefix string,
+ instanceID int,
+) (func(string, string) error, func() error, error) {
+ q := takeSQL(prefix)
+
+ fn := func(topic string, consumer string) error {
+ return inTx(db, func(ctx context.Context) error {
+ _, err := db.ExecContext(
+ ctx,
+ q.write,
+ topic,
+ consumer,
+ instanceID,
+ )
+ return err
+ })
+ }
+
+ return fn, closeNoop, nil
+}
+
+func publishSQL(prefix string) queryT {
+ const tmpl_write = `
+ INSERT INTO "%s_payloads" (topic, payload)
+ VALUES (?, ?);
+
+ INSERT INTO "%s_messages" (uuid, flow_id, payload_id)
+ VALUES (?, ?, last_insert_rowid());
+ `
+ const tmpl_read = `
+ SELECT id, timestamp FROM "%s_messages"
+ WHERE uuid = ?;
+ `
+ return queryT{
+ write: fmt.Sprintf(tmpl_write, prefix, prefix),
+ read: fmt.Sprintf(tmpl_read, prefix),
+ }
+}
+
+func publishStmt(
+ db *sql.DB,
+ prefix string,
+ _ int,
+) (func(UnsentMessage, guuid.UUID) (messageT, error), func() error, error) {
+ q := publishSQL(prefix)
+
+ readStmt, err := db.Prepare(q.read)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ fn := func(
+ unsentMessage UnsentMessage,
+ messageID guuid.UUID,
+ ) (messageT, error) {
+ message := messageT{
+ uuid: messageID,
+ topic: unsentMessage.Topic,
+ flowID: unsentMessage.FlowID,
+ payload: unsentMessage.Payload,
+ }
+
+ message_id_bytes := messageID[:]
+ flow_id_bytes := unsentMessage.FlowID[:]
+ _, err := db.Exec(
+ q.write,
+ unsentMessage.Topic,
+ unsentMessage.Payload,
+ message_id_bytes,
+ flow_id_bytes,
+ )
+ if err != nil {
+ return messageT{}, err
+ }
+
+ var timestr string
+ err = readStmt.QueryRow(message_id_bytes).Scan(
+ &message.id,
+ &timestr,
+ )
+ if err != nil {
+ return messageT{}, err
+ }
+
+ message.timestamp, err = time.Parse(time.RFC3339Nano, timestr)
+ if err != nil {
+ return messageT{}, err
+ }
+
+ return message, nil
+ }
+
+ return fn, readStmt.Close, nil
+}
+
+func findSQL(prefix string) queryT {
+ const tmpl_read = `
+ SELECT
+ "%s_messages".id,
+ "%s_messages".timestamp,
+ "%s_messages".uuid,
+ "%s_payloads".payload
+ FROM "%s_messages"
+ JOIN "%s_payloads" ON
+ "%s_payloads".id = "%s_messages".payload_id
+ WHERE
+ "%s_payloads".topic = ? AND
+ "%s_messages".flow_id = ?
+ ORDER BY "%s_messages".id DESC
+ LIMIT 1;
+ `
+ return queryT{
+ read: fmt.Sprintf(
+ tmpl_read,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ ),
+ }
+}
+
+func findStmt(
+ db *sql.DB,
+ prefix string,
+ _ int,
+) (func(string, guuid.UUID) (messageT, error), func() error, error) {
+ q := findSQL(prefix)
+
+ readStmt, err := db.Prepare(q.read)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ fn := func(topic string, flowID guuid.UUID) (messageT, error) {
+ message := messageT{
+ topic: topic,
+ flowID: flowID,
+ }
+
+ var (
+ timestr string
+ message_id_bytes []byte
+ )
+ flow_id_bytes := flowID[:]
+ err = readStmt.QueryRow(topic, flow_id_bytes).Scan(
+ &message.id,
+ &timestr,
+ &message_id_bytes,
+ &message.payload,
+ )
+ if err != nil {
+ return messageT{}, err
+ }
+ message.uuid = guuid.UUID(message_id_bytes)
+
+ message.timestamp, err = time.Parse(time.RFC3339Nano, timestr)
+ if err != nil {
+ return messageT{}, err
+ }
+
+ return message, nil
+ }
+
+ return fn, readStmt.Close, nil
+}
+
+func nextSQL(prefix string) queryT {
+ const tmpl_read = `
+ SELECT
+ "%s_messages".id,
+ "%s_messages".timestamp,
+ "%s_messages".uuid,
+ "%s_messages".flow_id,
+ "%s_payloads".payload
+ FROM "%s_messages"
+ JOIN "%s_payloads" ON
+ "%s_payloads".id = "%s_messages".payload_id
+ WHERE
+ "%s_payloads".topic = ? AND
+ "%s_messages".id NOT IN (
+ SELECT message_id FROM "%s_offsets"
+ WHERE consumer = ?
+ )
+ ORDER BY "%s_messages".id ASC
+ LIMIT 1;
+ `
+ const tmpl_owner = `
+ SELECT owner_id FROM "%s_owners"
+ WHERE
+ topic = ? AND
+ consumer = ?;
+ `
+ return queryT{
+ read: fmt.Sprintf(
+ tmpl_read,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ ),
+ owner: fmt.Sprintf(tmpl_owner, prefix),
+ }
+}
+
+func nextStmt(
+ db *sql.DB,
+ prefix string,
+ instanceID int,
+) (func(string, string) (messageT, error), func() error, error) {
+ q := nextSQL(prefix)
+
+ fn := func(topic string, consumer string) (messageT, error) {
+ message := messageT{
+ topic: topic,
+ }
+
+ var (
+ err error
+ ownerID int
+ timestr string
+ message_id_bytes []byte
+ flow_id_bytes []byte
+ )
+ tx, err := db.Begin()
+ if err != nil {
+ return messageT{}, err
+ }
+ defer tx.Rollback()
+
+ err = tx.QueryRow(q.owner, topic, consumer).Scan(&ownerID)
+ if err != nil {
+ return messageT{}, err
+ }
+
+ if ownerID != instanceID {
+ err := fmt.Errorf(
+ notOwnerErrorFmt,
+ ownerID,
+ topic,
+ consumer,
+ instanceID,
+ )
+ return messageT{}, err
+ }
+
+ err = tx.QueryRow(q.read, topic, consumer).Scan(
+ &message.id,
+ &timestr,
+ &message_id_bytes,
+ &flow_id_bytes,
+ &message.payload,
+ )
+ if err != nil {
+ return messageT{}, err
+ }
+ message.uuid = guuid.UUID(message_id_bytes)
+ message.flowID = guuid.UUID(flow_id_bytes)
+
+ message.timestamp, err = time.Parse(time.RFC3339Nano, timestr)
+ if err != nil {
+ return messageT{}, err
+ }
+
+ return message, nil
+ }
+
+ return fn, closeNoop, nil
+}
+
+func messageEach(rows *sql.Rows, callback func(messageT) error) error {
+ if rows == nil {
+ return nil
+ }
+
+ for rows.Next() {
+ var (
+ message messageT
+ timestr string
+ message_id_bytes []byte
+ flow_id_bytes []byte
+ )
+ err := rows.Scan(
+ &message.id,
+ &timestr,
+ &message_id_bytes,
+ &flow_id_bytes,
+ &message.topic,
+ &message.payload,
+ )
+ if err != nil {
+ return g.WrapErrors(rows.Close(), err)
+ }
+ message.uuid = guuid.UUID(message_id_bytes)
+ message.flowID = guuid.UUID(flow_id_bytes)
+
+ message.timestamp, err = time.Parse(time.RFC3339Nano, timestr)
+ if err != nil {
+ return g.WrapErrors(rows.Close(), err)
+ }
+
+ err = callback(message)
+ if err != nil {
+ return g.WrapErrors(rows.Close(), err)
+ }
+ }
+
+ return g.WrapErrors(rows.Err(), rows.Close())
+}
+
+func pendingSQL(prefix string) queryT {
+ const tmpl_read = `
+ SELECT
+ "%s_messages".id,
+ "%s_messages".timestamp,
+ "%s_messages".uuid,
+ "%s_messages".flow_id,
+ "%s_payloads".topic,
+ "%s_payloads".payload
+ FROM "%s_messages"
+ JOIN "%s_payloads" ON
+ "%s_payloads".id = "%s_messages".payload_id
+ WHERE
+ "%s_payloads".topic = ? AND
+ "%s_messages".id NOT IN (
+ SELECT message_id FROM "%s_offsets"
+ WHERE consumer = ?
+ )
+ ORDER BY "%s_messages".id ASC;
+ `
+ const tmpl_owner = `
+ SELECT owner_id FROM "%s_owners"
+ WHERE
+ topic = ? AND
+ consumer = ?;
+ `
+ return queryT{
+ read: fmt.Sprintf(
+ tmpl_read,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ ),
+ owner: fmt.Sprintf(tmpl_owner, prefix),
+ }
+}
+
+func pendingStmt(
+ db *sql.DB,
+ prefix string,
+ instanceID int,
+) (func(string, string) (*sql.Rows, error), func() error, error) {
+ q := pendingSQL(prefix)
+
+ readStmt, err := db.Prepare(q.read)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ ownerStmt, err := db.Prepare(q.owner)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ fn := func(topic string, consumer string) (*sql.Rows, error) {
+ var ownerID int
+ err := ownerStmt.QueryRow(topic, consumer).Scan(&ownerID)
+ if err != nil {
+ return nil, err
+ }
+
+ // best effort check, the final one is done during
+ // commit within a transaction
+ if ownerID != instanceID {
+ return nil, nil
+ }
+
+ return readStmt.Query(topic, consumer)
+ }
+
+ closeFn := func() error {
+ return g.SomeFnError(readStmt.Close, ownerStmt.Close)
+ }
+
+ return fn, closeFn, nil
+}
+
+func commitSQL(prefix string) queryT {
+ const tmpl_write = `
+ INSERT INTO "%s_offsets" (consumer, message_id)
+ VALUES (?, (SELECT id FROM "%s_messages" WHERE uuid = ?));
+ `
+ const tmpl_read = `
+ SELECT "%s_payloads".topic from "%s_payloads"
+ JOIN "%s_messages" ON
+ "%s_payloads".id = "%s_messages".payload_id
+ WHERE "%s_messages".uuid = ?;
+ `
+ const tmpl_owner = `
+ SELECT owner_id FROM "%s_owners"
+ WHERE
+ topic = ? AND
+ consumer = ?;
+ `
+ return queryT{
+ write: fmt.Sprintf(tmpl_write, prefix, prefix),
+ read: fmt.Sprintf(
+ tmpl_read,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ ),
+ owner: fmt.Sprintf(tmpl_owner, prefix),
+ }
+}
+
+func commitStmt(
+ db *sql.DB,
+ prefix string,
+ instanceID int,
+) (func(string, guuid.UUID) error, func() error, error) {
+ q := commitSQL(prefix)
+
+ fn := func(consumer string, messageID guuid.UUID) error {
+ message_id_bytes := messageID[:]
+ return inTx(db, func(ctx context.Context) error {
+ var topic string
+ err := db.QueryRowContext(
+ ctx,
+ q.read,
+ message_id_bytes,
+ ).Scan(&topic)
+ if err != nil {
+ return err
+ }
+
+ var ownerID int
+ err = db.QueryRowContext(
+ ctx,
+ q.owner,
+ topic,
+ consumer,
+ ).Scan(&ownerID)
+ if err != nil {
+ return err
+ }
+
+ if ownerID != instanceID {
+ return fmt.Errorf(
+ noLongerOwnerErrorFmt,
+ instanceID,
+ topic,
+ consumer,
+ ownerID,
+ )
+ }
+
+ _, err = db.ExecContext(ctx, q.write, consumer, message_id_bytes)
+ return err
+ })
+ }
+
+ return fn, closeNoop, nil
+}
+
+func toDeadSQL(prefix string) queryT {
+ const tmpl_write = `
+ INSERT INTO "%s_offsets" ( consumer, message_id)
+ VALUES ( ?, (SELECT id FROM "%s_messages" WHERE uuid = ?));
+
+ INSERT INTO "%s_deadletters" (uuid, consumer, message_id)
+ VALUES (?, ?, (SELECT id FROM "%s_messages" WHERE uuid = ?));
+ `
+ const tmpl_read = `
+ SELECT "%s_payloads".topic FROM "%s_payloads"
+ JOIN "%s_messages" ON
+ "%s_payloads".id = "%s_messages".payload_id
+ WHERE "%s_messages".uuid = ?;
+ `
+ const tmpl_owner = `
+ SELECT owner_id FROM "%s_owners"
+ WHERE
+ topic = ? AND
+ consumer = ?;
+ `
+ return queryT{
+ write: fmt.Sprintf(tmpl_write, prefix, prefix, prefix, prefix),
+ read: fmt.Sprintf(
+ tmpl_read,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ ),
+ owner: fmt.Sprintf(tmpl_owner, prefix),
+ }
+}
+
+func toDeadStmt(
+ db *sql.DB,
+ prefix string,
+ instanceID int,
+) (
+ func(string, guuid.UUID, guuid.UUID) error,
+ func() error,
+ error,
+) {
+ q := toDeadSQL(prefix)
+
+ fn := func(
+ consumer string,
+ messageID guuid.UUID,
+ deadletterID guuid.UUID,
+ ) error {
+ message_id_bytes := messageID[:]
+ deadletter_id_bytes := deadletterID[:]
+ return inTx(db, func(ctx context.Context) error {
+ var topic string
+ err := db.QueryRowContext(
+ ctx,
+ q.read,
+ message_id_bytes,
+ ).Scan(&topic)
+ if err != nil {
+ return err
+ }
+
+ var ownerID int
+ err = db.QueryRowContext(
+ ctx,
+ q.owner,
+ topic,
+ consumer,
+ ).Scan(&ownerID)
+ if err != nil {
+ return err
+ }
+
+ if ownerID != instanceID {
+ return fmt.Errorf(
+ noLongerOwnerErrorFmt,
+ instanceID,
+ topic,
+ consumer,
+ ownerID,
+ )
+ }
+
+ _, err = db.ExecContext(
+ ctx,
+ q.write,
+ consumer,
+ message_id_bytes,
+ deadletter_id_bytes,
+ consumer,
+ message_id_bytes,
+ )
+ return err
+ })
+ }
+
+ return fn, closeNoop, nil
+}
+
+func replaySQL(prefix string) queryT {
+ const tmpl_write = `
+ INSERT INTO "%s_messages" (uuid, flow_id, payload_id)
+ SELECT
+ ?,
+ "%s_messages".flow_id,
+ "%s_messages".payload_id
+ FROM "%s_messages"
+ JOIN "%s_deadletters" ON
+ "%s_messages".id = "%s_deadletters".message_id
+ WHERE "%s_deadletters".uuid = ?;
+
+ INSERT INTO "%s_replays" (deadletter_id, message_id)
+ VALUES (
+ (SELECT id FROM "%s_deadletters" WHERE uuid = ?),
+ last_insert_rowid()
+ );
+ `
+ const tmpl_read = `
+ SELECT
+ "%s_messages".id,
+ "%s_messages".timestamp,
+ "%s_messages".flow_id,
+ "%s_payloads".topic,
+ "%s_payloads".payload
+ FROM "%s_messages"
+ JOIN "%s_payloads" ON
+ "%s_payloads".id = "%s_messages".payload_id
+ WHERE "%s_messages".uuid = ?;
+ `
+ return queryT{
+ write: fmt.Sprintf(
+ tmpl_write,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ ),
+ read: fmt.Sprintf(
+ tmpl_read,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ ),
+ }
+}
+
+func replayStmt(
+ db *sql.DB,
+ prefix string,
+ _ int,
+) (func(guuid.UUID, guuid.UUID) (messageT, error), func() error, error) {
+ q := replaySQL(prefix)
+
+ readStmt, err := db.Prepare(q.read)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ fn := func(
+ deadletterID guuid.UUID,
+ messageID guuid.UUID,
+ ) (messageT, error) {
+ deadletter_id_bytes := deadletterID[:]
+ message_id_bytes := messageID[:]
+ err := inTx(db, func(ctx context.Context) error {
+ _, err := db.ExecContext(
+ ctx,
+ q.write,
+ message_id_bytes,
+ deadletter_id_bytes,
+ deadletter_id_bytes,
+ )
+ return err
+ })
+ if err != nil {
+ return messageT{}, err
+ }
+
+ message := messageT{
+ uuid: messageID,
+ }
+
+ var (
+ timestr string
+ flow_id_bytes []byte
+ )
+ err = readStmt.QueryRow(message_id_bytes).Scan(
+ &message.id,
+ &timestr,
+ &flow_id_bytes,
+ &message.topic,
+ &message.payload,
+ )
+ if err != nil {
+ return messageT{}, err
+ }
+ message.flowID = guuid.UUID(flow_id_bytes)
+
+ message.timestamp, err = time.Parse(time.RFC3339Nano, timestr)
+ if err != nil {
+ return messageT{}, err
+ }
+
+ return message, nil
+ }
+
+ return fn, readStmt.Close, nil
+}
+
+func oneDeadSQL(prefix string) queryT {
+ const tmpl_read = `
+ SELECT
+ "%s_deadletters".uuid,
+ "%s_offsets".timestamp,
+ "%s_messages".uuid
+ FROM "%s_deadletters"
+ JOIN "%s_offsets" ON
+ "%s_deadletters".message_id = "%s_offsets".message_id
+ JOIN "%s_messages" ON
+ "%s_deadletters".message_id = "%s_messages".id
+ JOIN "%s_payloads" ON
+ "%s_messages".payload_id = "%s_payloads".id
+ WHERE
+ "%s_payloads".topic = ? AND
+ "%s_deadletters".consumer = ? AND
+ "%s_offsets".consumer = ? AND
+ "%s_deadletters".id NOT IN (
+ SELECT deadletter_id FROM "%s_replays"
+ )
+ ORDER BY "%s_deadletters".id ASC
+ LIMIT 1;
+ `
+ return queryT{
+ read: fmt.Sprintf(
+ tmpl_read,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ ),
+ }
+}
+
+func oneDeadStmt(
+ db *sql.DB,
+ prefix string,
+ _ int,
+) (func(string, string) (deadletterT, error), func() error, error) {
+ q := oneDeadSQL(prefix)
+
+ readStmt, err := db.Prepare(q.read)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ fn := func(topic string, consumer string) (deadletterT, error) {
+ deadletter := deadletterT{
+ consumer: consumer,
+ }
+
+ var (
+ deadletter_id_bytes []byte
+ timestr string
+ message_id_bytes []byte
+ )
+ err := readStmt.QueryRow(topic, consumer, consumer).Scan(
+ &deadletter_id_bytes,
+ &timestr,
+ &message_id_bytes,
+ )
+ if err != nil {
+ return deadletterT{}, err
+ }
+ deadletter.uuid = guuid.UUID(deadletter_id_bytes)
+ deadletter.messageID = guuid.UUID(message_id_bytes)
+
+ deadletter.timestamp, err = time.Parse(
+ time.RFC3339Nano,
+ timestr,
+ )
+ if err != nil {
+ return deadletterT{}, err
+ }
+
+ return deadletter, nil
+ }
+
+ return fn, readStmt.Close, nil
+}
+
+func deadletterEach(
+ rows *sql.Rows,
+ callback func(deadletterT, messageT) error,
+) error {
+ for rows.Next() {
+ var (
+ deadletter deadletterT
+ deadletter_id_bytes []byte
+ deadletterTimestr string
+ message messageT
+ messageTimestr string
+ message_id_bytes []byte
+ flow_id_bytes []byte
+ )
+ err := rows.Scan(
+ &deadletter_id_bytes,
+ &message.id,
+ &deadletterTimestr,
+ &deadletter.consumer,
+ &messageTimestr,
+ &message_id_bytes,
+ &flow_id_bytes,
+ &message.topic,
+ &message.payload,
+ )
+ if err != nil {
+ return g.WrapErrors(rows.Close(), err)
+ }
+
+ deadletter.uuid = guuid.UUID(deadletter_id_bytes)
+ deadletter.messageID = guuid.UUID(message_id_bytes)
+ message.uuid = guuid.UUID(message_id_bytes)
+ message.flowID = guuid.UUID(flow_id_bytes)
+
+ message.timestamp, err = time.Parse(
+ time.RFC3339Nano,
+ messageTimestr,
+ )
+ if err != nil {
+ return g.WrapErrors(rows.Close(), err)
+ }
+
+ deadletter.timestamp, err = time.Parse(
+ time.RFC3339Nano,
+ deadletterTimestr,
+ )
+ if err != nil {
+ return g.WrapErrors(rows.Close(), err)
+ }
+
+ err = callback(deadletter, message)
+ if err != nil {
+ return g.WrapErrors(rows.Close(), err)
+ }
+ }
+
+ return g.WrapErrors(rows.Err(), rows.Close())
+}
+
+func allDeadSQL(prefix string) queryT {
+ const tmpl_read = `
+ SELECT
+ "%s_deadletters".uuid,
+ "%s_deadletters".message_id,
+ "%s_offsets".timestamp,
+ "%s_offsets".consumer,
+ "%s_messages".timestamp,
+ "%s_messages".uuid,
+ "%s_messages".flow_id,
+ "%s_payloads".topic,
+ "%s_payloads".payload
+ FROM "%s_deadletters"
+ JOIN "%s_offsets" ON
+ "%s_deadletters".message_id = "%s_offsets".message_id
+ JOIN "%s_messages" ON
+ "%s_deadletters".message_id = "%s_messages".id
+ JOIN "%s_payloads" ON
+ "%s_messages".payload_id = "%s_payloads".id
+ WHERE
+ "%s_payloads".topic = ? AND
+ "%s_deadletters".consumer = ? AND
+ "%s_offsets".consumer = ? AND
+ "%s_deadletters".id NOT IN (
+ SELECT deadletter_id FROM "%s_replays"
+ )
+ ORDER BY "%s_deadletters".id ASC;
+ `
+ return queryT{
+ read: fmt.Sprintf(
+ tmpl_read,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ ),
+ }
+}
+
+func allDeadStmt(
+ db *sql.DB,
+ prefix string,
+ _ int,
+) (func(string, string) (*sql.Rows, error), func() error, error) {
+ q := allDeadSQL(prefix)
+
+ readStmt, err := db.Prepare(q.read)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ fn := func(topic string, consumer string) (*sql.Rows, error) {
+ return readStmt.Query(topic, consumer, consumer)
+ }
+
+ return fn, readStmt.Close, nil
+}
+
+func sizeSQL(prefix string) queryT {
+ const tmpl_read = `
+ SELECT
+ COUNT(1) as size
+ FROM "%s_messages"
+ JOIN "%s_payloads" ON
+ "%s_messages".payload_id = "%s_payloads".id
+ WHERE "%s_payloads".topic = ?;
+ `
+ return queryT{
+ read: fmt.Sprintf(
+ tmpl_read,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ ),
+ }
+}
+
+
+func sizeStmt(
+ db *sql.DB,
+ prefix string,
+ _ int,
+) (func(string) (int, error), func() error, error) {
+ q := sizeSQL(prefix)
+
+ readStmt, err := db.Prepare(q.read)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ fn := func(topic string) (int, error) {
+ var size int
+ err := readStmt.QueryRow(topic).Scan(&size)
+ if err != nil {
+ return -1, err
+ }
+
+ return size, nil
+ }
+
+ return fn, readStmt.Close, nil
+}
+
+func countSQL(prefix string) queryT {
+ const tmpl_read = `
+ SELECT
+ COUNT(1) as count
+ FROM "%s_messages"
+ JOIN "%s_offsets" ON
+ "%s_messages".id = "%s_offsets".message_id
+ JOIN "%s_payloads" ON
+ "%s_messages".payload_id = "%s_payloads".id
+ WHERE
+ "%s_payloads".topic = ? AND
+ "%s_offsets".consumer = ?;
+ `
+ return queryT{
+ read: fmt.Sprintf(
+ tmpl_read,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ ),
+ }
+}
+
+func countStmt(
+ db *sql.DB,
+ prefix string,
+ _ int,
+) (func(string, string) (int, error), func() error, error) {
+ q := countSQL(prefix)
+
+ readStmt, err := db.Prepare(q.read)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ fn := func(topic string, consumer string) (int, error) {
+ var count int
+ err := readStmt.QueryRow(topic, consumer).Scan(&count)
+ if err != nil {
+ return -1, err
+ }
+
+ return count, nil
+ }
+
+ return fn, readStmt.Close, nil
+}
+
+func hasDataSQL(prefix string) queryT {
+ const tmpl_read = `
+ SELECT 1 as data
+ FROM "%s_messages"
+ JOIN "%s_payloads" ON
+ "%s_payloads".id = "%s_messages".payload_id
+ WHERE
+ "%s_payloads".topic = ? AND
+ "%s_messages".id NOT IN (
+ SELECT message_id FROM "%s_offsets"
+ WHERE consumer = ?
+ )
+ LIMIT 1;
+ `
+ return queryT{
+ read: fmt.Sprintf(
+ tmpl_read,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ ),
+ }
+}
+
+func hasDataStmt(
+ db *sql.DB,
+ prefix string,
+ _ int,
+) (func(string, string) (bool, error), func() error, error) {
+ q := hasDataSQL(prefix)
+
+ readStmt, err := db.Prepare(q.read)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ fn := func(topic string, consumer string) (bool, error) {
+ var _x int
+ err := readStmt.QueryRow(topic, consumer).Scan(&_x)
+ if err == sql.ErrNoRows {
+ return false, nil
+ }
+
+ if err != nil {
+ return false, err
+ }
+
+ return true, nil
+ }
+
+ return fn, readStmt.Close, nil
+}
+
+func initDB(
+ db *sql.DB,
+ prefix string,
+ notifyFn func(messageT),
+ instanceID int,
+) (queriesT, error) {
+ createTablesErr := createTables(db, prefix)
+ take, takeClose, takeErr := takeStmt(db, prefix, instanceID)
+ publish, publishClose, publishErr := publishStmt(db, prefix, instanceID)
+ find, findClose, findErr := findStmt(db, prefix, instanceID)
+ next, nextClose, nextErr := nextStmt(db, prefix, instanceID)
+ pending, pendingClose, pendingErr := pendingStmt(db, prefix, instanceID)
+ commit, commitClose, commitErr := commitStmt(db, prefix, instanceID)
+ toDead, toDeadClose, toDeadErr := toDeadStmt(db, prefix, instanceID)
+ replay, replayClose, replayErr := replayStmt(db, prefix, instanceID)
+ oneDead, oneDeadClose, oneDeadErr := oneDeadStmt(db, prefix, instanceID)
+ allDead, allDeadClose, allDeadErr := allDeadStmt(db, prefix, instanceID)
+ size, sizeClose, sizeErr := sizeStmt(db, prefix, instanceID)
+ count, countClose, countErr := countStmt(db, prefix, instanceID)
+ hasData, hasDataClose, hasDataErr := hasDataStmt(db, prefix, instanceID)
+
+ err := g.SomeError(
+ createTablesErr,
+ takeErr,
+ publishErr,
+ findErr,
+ nextErr,
+ pendingErr,
+ commitErr,
+ toDeadErr,
+ replayErr,
+ oneDeadErr,
+ allDeadErr,
+ sizeErr,
+ countErr,
+ hasDataErr,
+ )
+ if err != nil {
+ return queriesT{}, err
+ }
+
+ close := func() error {
+ return g.SomeFnError(
+ takeClose,
+ publishClose,
+ findClose,
+ nextClose,
+ pendingClose,
+ commitClose,
+ toDeadClose,
+ replayClose,
+ oneDeadClose,
+ allDeadClose,
+ sizeClose,
+ countClose,
+ hasDataClose,
+ )
+ }
+
+ var connMutex sync.RWMutex
+ return queriesT{
+ take: func(a string, b string) error {
+ connMutex.RLock()
+ defer connMutex.RUnlock()
+ return take(a, b)
+ },
+ publish: func(a UnsentMessage, b guuid.UUID) (messageT, error) {
+ var (
+ err error
+ message messageT
+ )
+ {
+ connMutex.RLock()
+ defer connMutex.RUnlock()
+ message, err = publish(a, b)
+ }
+ if err != nil {
+ return messageT{}, err
+ }
+
+ go notifyFn(message)
+ return message, nil
+ },
+ find: func(a string, b guuid.UUID) (messageT, error) {
+ connMutex.RLock()
+ defer connMutex.RUnlock()
+ return find(a, b)
+ },
+ next: func(a string, b string) (messageT, error) {
+ connMutex.RLock()
+ defer connMutex.RUnlock()
+ return next(a, b)
+ },
+ pending: func(
+ a string,
+ b string,
+ callback func(messageT) error,
+ ) error {
+ var (
+ err error
+ rows *sql.Rows
+ )
+ {
+ connMutex.RLock()
+ defer connMutex.RUnlock()
+ rows, err = pending(a, b)
+ }
+ if err != nil {
+ return err
+ }
+
+ return messageEach(rows, callback)
+ },
+ commit: func(a string, b guuid.UUID) error {
+ connMutex.RLock()
+ defer connMutex.RUnlock()
+ return commit(a, b)
+ },
+ toDead: func(
+ a string,
+ b guuid.UUID,
+ c guuid.UUID,
+ ) error {
+ connMutex.RLock()
+ defer connMutex.RUnlock()
+ return toDead(a, b, c)
+ },
+ replay: func(a guuid.UUID, b guuid.UUID) (messageT, error) {
+ var (
+ err error
+ message messageT
+ )
+ {
+ connMutex.RLock()
+ defer connMutex.RUnlock()
+ message, err = replay(a, b)
+ }
+ if err != nil {
+ return messageT{}, err
+ }
+
+ go notifyFn(message)
+ return message, nil
+ },
+ oneDead: func(a string, b string) (deadletterT, error) {
+ connMutex.RLock()
+ defer connMutex.RUnlock()
+ return oneDead(a, b)
+ },
+ allDead: func(
+ a string,
+ b string,
+ callback func(deadletterT, messageT) error,
+ ) error {
+ var (
+ err error
+ rows *sql.Rows
+ )
+ {
+ connMutex.RLock()
+ defer connMutex.RUnlock()
+ rows, err = allDead(a, b)
+ }
+ if err != nil {
+ return err
+ }
+
+ return deadletterEach(rows, callback)
+ },
+ size: func(a string) (int, error) {
+ connMutex.RLock()
+ defer connMutex.RUnlock()
+ return size(a)
+ },
+ count: func(a string, b string) (int, error) {
+ connMutex.RLock()
+ defer connMutex.RUnlock()
+ return count(a, b)
+ },
+ hasData: func(a string, b string) (bool, error) {
+ connMutex.RLock()
+ defer connMutex.RUnlock()
+ return hasData(a, b)
+ },
+ close: func() error {
+ connMutex.Lock()
+ defer connMutex.Unlock()
+ return close()
+ },
+ }, nil
+}
+
+
+func newPinger[T any]() pingerT[T] {
+ channel := make(chan T, 1)
+ closed := false
+ var rwmutex sync.RWMutex
+ return pingerT[T]{
+ tryPing: func(x T) {
+ rwmutex.RLock()
+ defer rwmutex.RUnlock()
+ if closed {
+ return
+ }
+ select {
+ case channel <- x:
+ default:
+ }
+ },
+ onPing: func(cb func(T)) {
+ for x := range channel {
+ cb(x)
+ }
+ },
+ closed: func() bool {
+ rwmutex.RLock()
+ defer rwmutex.RUnlock()
+ return closed
+ },
+ close: func() {
+ rwmutex.Lock()
+ defer rwmutex.Unlock()
+ if closed {
+ return
+ }
+ close(channel)
+ closed = true
+ },
+ }
+}
+
+func makeSubscriptionsFuncs() subscriptionsT {
+ var rwmutex sync.RWMutex
+ subscriptions := subscriptionsSetM{}
+ return subscriptionsT{
+ read: func(callback func(subscriptionsSetM) error) error {
+ rwmutex.RLock()
+ defer rwmutex.RUnlock()
+ return callback(subscriptions)
+ },
+ write: func(callback func(subscriptionsSetM) error) error {
+ rwmutex.Lock()
+ defer rwmutex.Unlock()
+ return callback(subscriptions)
+ },
+ }
+}
+
+// Try notifying the consumer that they have data to work with. If their
+// channel is already full, we simply drop the notification, as on each ping
+// they'll look for all pending messages and process them all. So dropping the
+// event here doesn't mean not notifying the consumer, but simply acknowledging
+// that the existing notifications are enough for them to work with, without
+// letting any message slip through.
+func makeNotifyFn(
+ readFn func(func(subscriptionsSetM) error) error,
+ pinger pingerT[struct{}],
+) func(messageT) {
+ return func(message messageT) {
+ readFn(func(set subscriptionsSetM) error {
+ topicSub, ok := set[message.topic]
+ if !ok {
+ return nil
+ }
+
+ for _, consumer := range topicSub.consumers {
+ consumer.pinger.tryPing(struct{}{})
+ }
+ waiters := topicSub.waiters[message.flowID]
+ for _, waiter := range waiters {
+ waiter.pinger.tryPing(message.payload)
+ }
+ return nil
+ })
+ pinger.tryPing(struct{}{})
+ }
+}
+
+func collectClosedWaiters(
+ set subscriptionsSetM,
+) map[string]map[guuid.UUID][]string {
+ waiters := map[string]map[guuid.UUID][]string{}
+ for topic, topicSub := range set {
+ waiters[topic] = map[guuid.UUID][]string{}
+ for flowID, waitersByName := range topicSub.waiters {
+ names := []string{}
+ for name, waiter := range waitersByName {
+ if (*waiter.closed)() {
+ names = append(names, name)
+ }
+ }
+ waiters[topic][flowID] = names
+ }
+ }
+
+ return waiters
+}
+
+func trimEmptyLeaves(closedWaiters map[string]map[guuid.UUID][]string) {
+ for topic, waiters := range closedWaiters {
+ for flowID, names := range waiters {
+ if len(names) == 0 {
+ delete(closedWaiters[topic], flowID)
+ }
+ }
+ if len(waiters) == 0 {
+ delete(closedWaiters, topic)
+ }
+ }
+}
+
+func deleteIfEmpty(set subscriptionsSetM, topic string) {
+ topicSub, ok := set[topic]
+ if !ok {
+ return
+ }
+
+ emptyConsumers := len(topicSub.consumers) == 0
+ emptyWaiters := len(topicSub.waiters) == 0
+ if emptyConsumers && emptyWaiters {
+ delete(set, topic)
+ }
+}
+
+func deleteEmptyTopics(set subscriptionsSetM) {
+ for topic, _ := range set {
+ deleteIfEmpty(set, topic)
+ }
+}
+
+func removeClosedWaiters(
+ set subscriptionsSetM,
+ closedWaiters map[string]map[guuid.UUID][]string,
+) {
+ for topic, waiters := range closedWaiters {
+ _, ok := set[topic]
+ if !ok {
+ continue
+ }
+
+ for flowID, names := range waiters {
+ if set[topic].waiters[flowID] == nil {
+ continue
+ }
+ for _, name := range names {
+ delete(set[topic].waiters[flowID], name)
+ }
+ if len(set[topic].waiters[flowID]) == 0 {
+ delete(set[topic].waiters, flowID)
+ }
+ }
+ }
+
+ deleteEmptyTopics(set)
+}
+
+func reapClosedWaiters(
+ readFn func(func(subscriptionsSetM) error) error,
+ writeFn func(func(subscriptionsSetM) error) error,
+) {
+ var closedWaiters map[string]map[guuid.UUID][]string
+ readFn(func(set subscriptionsSetM) error {
+ closedWaiters = collectClosedWaiters(set)
+ return nil
+ })
+
+ trimEmptyLeaves(closedWaiters)
+ if len(closedWaiters) == 0 {
+ return
+ }
+
+ writeFn(func(set subscriptionsSetM) error {
+ removeClosedWaiters(set, closedWaiters)
+ return nil
+ })
+}
+
+func everyNthCall[T any](n int, fn func(T)) func(T) {
+ i := 0
+ return func(x T) {
+ i++
+ if i == n {
+ i = 0
+ fn(x)
+ }
+ }
+}
+
+func runReaper(
+ onPing func(func(struct{})),
+ readFn func(func(subscriptionsSetM) error) error,
+ writeFn func(func(subscriptionsSetM) error) error,
+) {
+ onPing(everyNthCall(reaperSkipCount, func(struct{}) {
+ reapClosedWaiters(readFn, writeFn)
+ }))
+}
+
+func NewWithPrefix(db *sql.DB, prefix string) (IQueue, error) {
+ err := g.ValidateSQLTablePrefix(prefix)
+ if err != nil {
+ return queueT{}, err
+ }
+
+ subscriptions := makeSubscriptionsFuncs()
+ pinger := newPinger[struct{}]()
+ notifyFn := makeNotifyFn(subscriptions.read, pinger)
+ queries, err := initDB(db, prefix, notifyFn, os.Getpid())
+ if err != nil {
+ return queueT{}, err
+ }
+
+ go runReaper(pinger.onPing, subscriptions.read, subscriptions.write)
+
+ return queueT{
+ queries: queries,
+ subscriptions: subscriptions,
+ pinger: pinger,
+ }, nil
+}
+
+func New(db *sql.DB) (IQueue, error) {
+ return NewWithPrefix(db, defaultPrefix)
+}
+
+func asPublicMessage(message messageT) Message {
+ return Message{
+ ID: message.uuid,
+ Timestamp: message.timestamp,
+ Topic: message.topic,
+ FlowID: message.flowID,
+ Payload: message.payload,
+ }
+}
+
+func (queue queueT) Publish(unsent UnsentMessage) (Message, error) {
+ message, err := queue.queries.publish(unsent, guuid.New())
+ return asPublicMessage(message), err
+}
+
+func registerConsumerFn(consumer consumerT) func(subscriptionsSetM) error {
+ topicSub := topicSubscriptionT{
+ consumers: map[string]consumerT{},
+ waiters: map[guuid.UUID]map[string]waiterT{},
+ }
+
+ return func(set subscriptionsSetM) error {
+ topic := consumer.data.topic
+ _, ok := set[topic]
+ if !ok {
+ set[topic] = topicSub
+ }
+ set[topic].consumers[consumer.data.name] = consumer
+
+ return nil
+ }
+}
+
+func registerWaiterFn(waiter waiterT) func(subscriptionsSetM) error {
+ topicSub := topicSubscriptionT{
+ consumers: map[string]consumerT{},
+ waiters: map[guuid.UUID]map[string]waiterT{},
+ }
+ waiters := map[string]waiterT{}
+
+ return func(set subscriptionsSetM) error {
+ var (
+ topic = waiter.data.topic
+ flowID = waiter.data.flowID
+ )
+ _, ok := set[topic]
+ if !ok {
+ set[topic] = topicSub
+ }
+ if set[topic].waiters[flowID] == nil {
+ set[topic].waiters[flowID] = waiters
+ }
+ set[topic].waiters[flowID][waiter.data.name] = waiter
+ return nil
+ }
+}
+
+func makeConsumeOneFn(
+ data consumerDataT,
+ callback func(Message) error,
+ successFn func(string, guuid.UUID) error,
+ errorFn func(string, guuid.UUID, guuid.UUID) error,
+) func(messageT) error {
+ return func(message messageT) error {
+ err := callback(asPublicMessage(message))
+ if err != nil {
+ g.Info(
+ "consumer failed", "q-consumer",
+ "topic", data.topic,
+ "consumer", data.name,
+ "error", err,
+ slog.Group(
+ "message",
+ "id", message.id,
+ "flow-id", message.flowID.String(),
+ ),
+ )
+
+ return errorFn(data.name, message.uuid, guuid.New())
+ }
+
+ return successFn(data.name, message.uuid)
+ }
+}
+
+func makeConsumeAllFn(
+ data consumerDataT,
+ consumeOneFn func(messageT) error,
+ eachFn func(string, string, func(messageT) error) error,
+) func(struct{}) {
+ return func(struct{}) {
+ err := eachFn(data.topic, data.name, consumeOneFn)
+ if err != nil {
+ g.Warning(
+ "eachFn failed", "q-consume-all",
+ "topic", data.topic,
+ "consumer", data.name,
+ "error", err,
+ "circuit-breaker-enabled?", false,
+ )
+ }
+ }
+}
+
+func makeWaitFn(channel chan []byte, closeFn func()) func([]byte) {
+ closed := false
+ var mutex sync.Mutex
+ return func(payload []byte) {
+ mutex.Lock()
+ defer mutex.Unlock()
+ if closed {
+ return
+ }
+
+ closeFn()
+ channel <- payload
+ close(channel)
+ closed = true
+ }
+}
+
+func runConsumer(onPing func(func(struct{})), consumeAllFn func(struct{})) {
+ consumeAllFn(struct{}{})
+ onPing(consumeAllFn)
+}
+
+func tryFinding(
+ findFn func(string, guuid.UUID) (messageT, error),
+ topic string,
+ flowID guuid.UUID,
+ waitFn func([]byte),
+) {
+ message, err := findFn(topic, flowID)
+ if err != nil {
+ return
+ }
+
+ waitFn(message.payload)
+}
+
+func (queue queueT) Subscribe(
+ topic string,
+ name string,
+ callback func(Message) error,
+) error {
+ data := consumerDataT{
+ topic: topic,
+ name: name,
+ }
+ pinger := newPinger[struct{}]()
+ consumer := consumerT{
+ data: data,
+ callback: callback,
+ pinger: pinger,
+ close: &pinger.close,
+ }
+ consumeOneFn := makeConsumeOneFn(
+ consumer.data,
+ consumer.callback,
+ queue.queries.commit,
+ queue.queries.toDead,
+ )
+ consumeAllFn := makeConsumeAllFn(
+ consumer.data,
+ consumeOneFn,
+ queue.queries.pending,
+ )
+
+ err := queue.queries.take(topic, name)
+ if err != nil {
+ return err
+ }
+
+ queue.subscriptions.write(registerConsumerFn(consumer))
+ go runConsumer(pinger.onPing, consumeAllFn)
+ return nil
+}
+
+type Waiter struct{
+ Channel <-chan []byte
+ Close func()
+}
+
+func (queue queueT) WaitFor(
+ topic string,
+ flowID guuid.UUID,
+ name string,
+) Waiter {
+ data := waiterDataT{
+ topic: topic,
+ flowID: flowID,
+ name: name,
+ }
+ pinger := newPinger[[]byte]()
+ waiter := waiterT{
+ data: data,
+ pinger: pinger,
+ closed: &pinger.closed,
+ close: &pinger.close,
+ }
+ channel := make(chan []byte, 1)
+ waitFn := makeWaitFn(channel, (*waiter.close))
+ closeFn := func() {
+ queue.subscriptions.read(func(set subscriptionsSetM) error {
+ (*set[topic].waiters[flowID][name].close)()
+ return nil
+ })
+ }
+
+ queue.subscriptions.write(registerWaiterFn(waiter))
+ tryFinding(queue.queries.find, topic, flowID, waitFn)
+ go pinger.onPing(waitFn)
+ return Waiter{channel, closeFn}
+}
+
+func unsubscribeIfExistsFn(
+ topic string,
+ name string,
+) func(subscriptionsSetM) error {
+ return func(set subscriptionsSetM) error {
+ topicSub, ok := set[topic]
+ if !ok {
+ return nil
+ }
+
+ consumer, ok := topicSub.consumers[name]
+ if !ok {
+ return nil
+ }
+
+ (*consumer.close)()
+ delete(set[topic].consumers, name)
+ deleteIfEmpty(set, topic)
+ return nil
+ }
+}
+
+func (queue queueT) Unsubscribe(topic string, name string) {
+ queue.subscriptions.write(unsubscribeIfExistsFn(topic, name))
+}
+
+func cleanSubscriptions(set subscriptionsSetM) error {
+ for _, topicSub := range set {
+ for _, consumer := range topicSub.consumers {
+ (*consumer.close)()
+ }
+ for _, waiters := range topicSub.waiters {
+ for _, waiter := range waiters {
+ (*waiter.close)()
+ }
+ }
+ }
+ return nil
+}
+
+func (queue queueT) Close() error {
+ queue.pinger.close()
+ return g.WrapErrors(
+ queue.subscriptions.write(cleanSubscriptions),
+ queue.queries.close(),
+ )
+}
+
+
+func topicGetopt(args argsT, w io.Writer) (argsT, bool) {
+ if len(args.args) == 0 {
+ fmt.Fprintf(w, "Missing TOPIC.\n")
+ return args, false
+ }
+
+ args.topic = args.args[0]
+ return args, true
+}
+
+func topicConsumerGetopt(args argsT, w io.Writer) (argsT, bool) {
+ fs := flag.NewFlagSet("", flag.ContinueOnError)
+ fs.Usage = func() {}
+ fs.SetOutput(w)
+
+ consumer := fs.String(
+ "C",
+ "default-consumer",
+ "The name of the consumer to be used",
+ )
+
+ if fs.Parse(args.args) != nil {
+ return args, false
+ }
+
+ subArgs := fs.Args()
+ if len(subArgs) == 0 {
+ fmt.Fprintf(w, "Missing TOPIC.\n")
+ return args, false
+ }
+
+ args.consumer = *consumer
+ args.topic = subArgs[0]
+ return args, true
+}
+
+func inExec(
+ args argsT,
+ queries queriesT,
+ r io.Reader,
+ w io.Writer,
+) (int, error) {
+ payload, err := io.ReadAll(r)
+ if err != nil {
+ return 1, err
+ }
+
+ unsent := UnsentMessage{
+ Topic: args.topic,
+ FlowID: guuid.New(),
+ Payload: payload,
+ }
+ message, err := queries.publish(unsent, guuid.New())
+ if err != nil {
+ return 1, err
+ }
+
+ fmt.Fprintf(w, "%s\n", message.uuid.String())
+
+ return 0, nil
+}
+
+func outExec(
+ args argsT,
+ queries queriesT,
+ _ io.Reader,
+ w io.Writer,
+) (int, error) {
+ err := queries.take(args.topic, args.consumer)
+ if err != nil {
+ return 1, err
+ }
+
+ message, err := queries.next(args.topic, args.consumer)
+
+ if err == sql.ErrNoRows {
+ return 3, nil
+ }
+
+ if err != nil {
+ return 1, err
+ }
+
+ fmt.Fprintln(w, string(message.payload))
+
+ return 0, nil
+}
+
+func commitExec(
+ args argsT,
+ queries queriesT,
+ r io.Reader,
+ w io.Writer,
+) (int, error) {
+ err := queries.take(args.topic, args.consumer)
+ if err != nil {
+ return 1, err
+ }
+
+ message, err := queries.next(args.topic, args.consumer)
+ if err != nil {
+ return 1, err
+ }
+
+ err = queries.commit(args.consumer, message.uuid)
+ if err != nil {
+ return 1, err
+ }
+
+ return 0, nil
+}
+
+func deadExec(
+ args argsT,
+ queries queriesT,
+ r io.Reader,
+ w io.Writer,
+) (int, error) {
+ err := queries.take(args.topic, args.consumer)
+ if err != nil {
+ return 1, err
+ }
+
+ message, err := queries.next(args.topic, args.consumer)
+ if err != nil {
+ return 1, err
+ }
+
+ err = queries.toDead(args.consumer, message.uuid, guuid.New())
+ if err != nil {
+ return 1, err
+ }
+
+ return 0, nil
+}
+
+func listDeadExec(
+ args argsT,
+ queries queriesT,
+ r io.Reader,
+ w io.Writer,
+) (int, error) {
+ eachFn := func(deadletter deadletterT, _ messageT) error {
+ fmt.Fprintf(
+ w,
+ "%s\t%s\t%s\n",
+ deadletter.uuid.String(),
+ deadletter.timestamp.Format(time.RFC3339),
+ deadletter.consumer,
+ )
+ return nil
+ }
+
+ err := queries.allDead(args.topic, args.consumer, eachFn)
+ if err != nil {
+ return 1, err
+ }
+
+ return 0, nil
+}
+
+func replayExec(
+ args argsT,
+ queries queriesT,
+ r io.Reader,
+ w io.Writer,
+) (int, error) {
+ deadletter, err := queries.oneDead(args.topic, args.consumer)
+ if err != nil {
+ return 1, err
+ }
+
+ _, err = queries.replay(deadletter.uuid, guuid.New())
+ if err != nil {
+ return 1, err
+ }
+
+ return 0, nil
+}
+
+func sizeExec(
+ args argsT,
+ queries queriesT,
+ r io.Reader,
+ w io.Writer,
+) (int, error) {
+ size, err := queries.size(args.topic)
+ if err != nil {
+ return 1, err
+ }
+
+ fmt.Fprintln(w, size)
+
+ return 0, nil
+}
+
+func countExec(
+ args argsT,
+ queries queriesT,
+ r io.Reader,
+ w io.Writer,
+) (int, error) {
+ count, err := queries.count(args.topic, args.consumer)
+ if err != nil {
+ return 1, err
+ }
+
+ fmt.Fprintln(w, count)
+
+ return 0, nil
+}
+
+func hasDataExec(
+ args argsT,
+ queries queriesT,
+ r io.Reader,
+ w io.Writer,
+) (int, error) {
+ hasData, err := queries.hasData(args.topic, args.consumer)
+ if err != nil {
+ return 1, err
+ }
+
+ if hasData {
+ return 0, nil
+ } else {
+ return 1, nil
+ }
+}
+
+func usage(argv0 string, w io.Writer) {
+ fmt.Fprintf(
+ w,
+ "Usage: %s [-f FILE] [-p PREFIX] COMMAND [OPTIONS]\n",
+ argv0,
+ )
+}
+
+func getopt(
+ allArgs []string,
+ commandsMap map[string]commandT,
+ w io.Writer,
+) (argsT, commandT, int) {
+ argv0 := allArgs[0]
+ argv := allArgs[1:]
+ fs := flag.NewFlagSet("", flag.ContinueOnError)
+ fs.Usage = func() {}
+ fs.SetOutput(w)
+ databasePath := fs.String(
+ "f",
+ "q.db",
+ "The path to the file where the queue is kept",
+ )
+ prefix := fs.String(
+ "p",
+ defaultPrefix,
+ "The q prefix of the table names",
+ )
+ if fs.Parse(argv) != nil {
+ usage(argv0, w)
+ return argsT{}, commandT{}, 2
+ }
+
+ subArgs := fs.Args()
+ if len(subArgs) == 0 {
+ fmt.Fprintf(w, "Missing COMMAND.\n")
+ usage(argv0, w)
+ return argsT{}, commandT{}, 2
+ }
+
+ args := argsT{
+ databasePath: *databasePath,
+ prefix: *prefix,
+ command: subArgs[0],
+ allArgs: allArgs,
+ args: subArgs[1:],
+ }
+
+ command := commandsMap[args.command]
+ if command.name == "" {
+ fmt.Fprintf(w, "Bad COMMAND: \"%s\".\n", args.command)
+ usage(allArgs[0], w)
+ return argsT{}, commandT{}, 2
+ }
+
+ args, ok := command.getopt(args, w)
+ if !ok {
+ usage(allArgs[0], w)
+ return argsT{}, commandT{}, 2
+ }
+
+ return args, command, 0
+}
+
+func runCommand(
+ args argsT,
+ command commandT,
+ stdin io.Reader,
+ stdout io.Writer,
+ stderr io.Writer,
+) int {
+ db, err := sql.Open("acude", args.databasePath)
+ if err != nil {
+ fmt.Fprintln(stderr, err)
+ return 1
+ }
+ defer db.Close()
+
+ iqueue, err := NewWithPrefix(db, args.prefix)
+ if err != nil {
+ fmt.Fprintln(stderr, err)
+ return 1
+ }
+ defer iqueue.Close()
+
+ rc, err := command.exec(args, iqueue.(queueT).queries, stdin, stdout)
+ if err != nil {
+ fmt.Fprintln(stderr, err)
+ }
+
+ return rc
+}
+
+var commands = map[string]commandT {
+ "in": commandT{
+ name: "in",
+ getopt: topicGetopt,
+ exec: inExec,
+ },
+ "out": commandT{
+ name: "out",
+ getopt: topicConsumerGetopt,
+ exec: outExec,
+ },
+ "commit": commandT{
+ name: "commit",
+ getopt: topicConsumerGetopt,
+ exec: commitExec,
+ },
+ "dead": commandT{
+ name: "dead",
+ getopt: topicConsumerGetopt,
+ exec: deadExec,
+ },
+ "ls-dead": commandT{
+ name: "ls-dead",
+ getopt: topicConsumerGetopt,
+ exec: listDeadExec,
+ },
+ "replay": commandT{
+ name: "replay",
+ getopt: topicConsumerGetopt,
+ exec: replayExec,
+ },
+ "size": commandT{
+ name: "size",
+ getopt: topicGetopt,
+ exec: sizeExec,
+ },
+ "count": commandT{
+ name: "count",
+ getopt: topicConsumerGetopt,
+ exec: countExec,
+ },
+ "has-data": commandT{
+ name: "has-data",
+ getopt: topicConsumerGetopt,
+ exec: hasDataExec,
+ },
+}
+
+
+
+func Main() {
+ g.Init()
+ args, command, rc := getopt(os.Args, commands, os.Stderr)
+ if rc != 0 {
+ os.Exit(2)
+ }
+ os.Exit(runCommand(args, command, os.Stdin, os.Stdout, os.Stderr))
+}
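For reference, a minimal usage sketch of the public API introduced in src/q.go (New, Publish, Subscribe, WaitFor and Close). This is not part of the commit: it assumes the module layout above, where the library is importable as "q" and the "acudego" import registers the "acude" database/sql driver used by runCommand; error handling is reduced to bare log.Fatal calls for brevity.

package main

import (
	"database/sql"
	"fmt"
	"log"

	_ "acudego"
	"guuid"

	"q"
)

func main() {
	// Open the database with the "acude" driver, as runCommand does.
	db, err := sql.Open("acude", "q.db")
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()

	// New uses the default "q" table prefix; NewWithPrefix accepts a custom one.
	queue, err := q.New(db)
	if err != nil {
		log.Fatal(err)
	}
	defer queue.Close()

	// A consumer named "printer" processes every message on the "emails" topic.
	// Returning an error from the callback sends the message to the deadletters.
	err = queue.Subscribe("emails", "printer", func(m q.Message) error {
		fmt.Printf("%s: %s\n", m.ID.String(), m.Payload)
		return nil
	})
	if err != nil {
		log.Fatal(err)
	}

	// Publish a message; subscribers are notified asynchronously.
	flowID := guuid.New()
	_, err = queue.Publish(q.UnsentMessage{
		Topic:   "emails",
		FlowID:  flowID,
		Payload: []byte("hello"),
	})
	if err != nil {
		log.Fatal(err)
	}

	// WaitFor blocks, via the returned channel, until a message with the
	// given flow ID shows up on the topic.
	waiter := queue.WaitFor("emails", flowID, "waiter-1")
	payload := <-waiter.Channel
	fmt.Println(string(payload))
}

The same queriesT functions also back the CLI commands (in, out, commit, dead, ls-dead, replay, size, count and has-data) wired up in Main.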