package fiinha

import (
	"database/sql"
	"flag"
	"fmt"
	"io"
	"log/slog"
	"os"
	"sync"
	"time"

	"golite"
	"uuid"
	g "gobang"
)

const (
	defaultPrefix    = "fiinha"
	reaperSkipCount  = 1000
	notOwnerErrorFmt = "%v owns %#v as %#v, not us (%v)"
	rollbackErrorFmt = "rollback error: %w; while executing: %w"
)

// dbconfigT bundles what every statement constructor needs: the shared
// connection pool, the database path (used to open private connections), the
// table-name prefix and the ID of this instance.
type dbconfigT struct {
	shared     *sql.DB
	dbpath     string
	prefix     string
	instanceID int
}

// queryT holds the SQL text of one operation, split by purpose: the
// statement(s) that write, the one that reads, and an ownership check.
type queryT struct {
	write string
	read  string
	owner string
}

// queriesT is the table of database operations assembled by initDB.
type queriesT struct {
	take    func(string, string) error
	publish func(UnsentMessage, uuid.UUID) (messageT, error)
	find    func(string, uuid.UUID) (messageT, error)
	next    func(string, string) (messageT, error)
	pending func(string, string, func(messageT) error) error
	commit  func(string, uuid.UUID) error
	toDead  func(string, uuid.UUID, uuid.UUID) error
	replay  func(uuid.UUID, uuid.UUID) (messageT, error)
	oneDead func(string, string) (deadletterT, error)
	allDead func(string, string, func(deadletterT, messageT) error) error
	size    func(string) (int, error)
	count   func(string, string) (int, error)
	hasData func(string, string) (bool, error)
	close   func() error
}

// messageT is the internal form of a stored message; id is the row ID in the
// "<prefix>_messages" table.
type messageT struct {
	id        int64
	timestamp time.Time
	uuid      uuid.UUID
	topic     string
	flowID    uuid.UUID
	payload   []byte
}

// UnsentMessage is a message to be published: its topic, the flow it
// belongs to, and its payload.
type UnsentMessage struct {
	Topic   string
	FlowID  uuid.UUID
	Payload []byte
}

// Message is a stored message, carrying the ID it was published under and
// the timestamp read back from the database.
type Message struct {
	ID        uuid.UUID
	Timestamp time.Time
	Topic     string
	FlowID    uuid.UUID
	Payload   []byte
}

// deadletterT is a dead letter: its own ID, the timestamp of the consumer's
// offset for the message, the consumer, and the ID of the dead message.
type deadletterT struct {
	uuid      uuid.UUID
	timestamp time.Time
	consumer  string
	messageID uuid.UUID
}

// pingerT is a closable, non-blocking notification channel; see newPinger.
type pingerT[T any] struct {
	tryPing func(T)
	onPing  func(func(T))
	closed  func() bool
	close   func()
}

type consumerDataT struct {
	topic string
	name  string
}

type waiterDataT struct {
	topic  string
	flowID uuid.UUID
	name   string
}

type consumerT struct {
	data     consumerDataT
	callback func(Message) error
	pinger   pingerT[struct{}]
	close    *func()
}

type waiterT struct {
	data   waiterDataT
	pinger pingerT[[]byte]
	closed *func() bool
	close  *func()
}

// topicSubscriptionT holds the listeners of one topic: consumers keyed by
// name, and waiters keyed by flow ID and then by name.
type topicSubscriptionT struct {
	consumers map[string]consumerT
	waiters   map[uuid.UUID]map[string]waiterT
}

// subscriptionsSetM maps each topic to its listeners.
type subscriptionsSetM map[string]topicSubscriptionT

// subscriptionsT gives guarded access to a subscriptionsSetM: read runs its
// callback under a shared lock, write under an exclusive one.
type subscriptionsT struct {
	read  func(func(subscriptionsSetM) error) error
	write func(func(subscriptionsSetM) error) error
}

type queueT struct {
	queries       queriesT
	subscriptions subscriptionsT
	pinger        pingerT[struct{}]
}

type argsT struct {
	databasePath string
	prefix       string
	command      string
	allArgs      []string
	args         []string
	topic        string
	consumer     string
}

type commandT struct {
	name   string
	getopt func(argsT, io.Writer) (argsT, bool)
	exec   func(argsT, queriesT, io.Reader, io.Writer) (int, error)
}

// IQueue is the public interface of the queue.
type IQueue interface {
	Publish(UnsentMessage) (Message, error)
	Subscribe(string, string, func(Message) error) error
	Unsubscribe(string, string)
	WaitFor(string, uuid.UUID, string) Waiter
	Close() error
}

// tryRollback rolls tx back after err happened inside it. When the rollback
// itself fails both errors are wrapped together; otherwise err is returned
// unchanged.
func tryRollback(tx *sql.Tx, err error) error {
	rollbackErr := tx.Rollback()
	if rollbackErr != nil {
		return fmt.Errorf(
			rollbackErrorFmt,
			rollbackErr,
			err,
		)
	}

	return err
}

// inTx runs fn inside a transaction on db: it commits when fn succeeds and
// rolls back when fn fails.
func inTx(db *sql.DB, fn func(*sql.Tx) error) error {
	tx, err := db.Begin()
	if err != nil {
		return err
	}

	err = fn(tx)
	if err != nil {
		return tryRollback(tx, err)
	}

	// database/sql marks the transaction as done as soon as Commit is
	// called, so a Rollback after a failed Commit could only return
	// sql.ErrTxDone and bury the real error.
	return tx.Commit()
}

// serialized wraps callback so that every call of the returned function is
// executed, one at a time, by a single background goroutine. The second
// returned function stops that goroutine and waits for it to exit; calling it
// more than once is safe. The wrapped function must not be called after that,
// as it would send on a closed channel.
func serialized[A any, B any](callback func(...A) B) (func(...A) B, func()) {
	in := make(chan []A)
	out := make(chan B)
	closed := false
	var (
		closeWg    sync.WaitGroup
		closeMutex sync.Mutex
	)

	closeWg.Add(1)
	go func() {
		for input := range in {
			out <- callback(input...)
		}
		close(out)
		closeWg.Done()
	}()

	fn := func(input ...A) B {
		in <- input
		return <-out
	}

	closeFn := func() {
		closeMutex.Lock()
		defer closeMutex.Unlock()
		if closed {
			return
		}
		close(in)
		closed = true
		closeWg.Wait()
	}

	return fn, closeFn
}

// execSerialized returns a function that executes query with the given
// arguments in its own transaction on db, one call at a time (see
// serialized), plus the function that shuts the serializer down.
func execSerialized(query string, db *sql.DB) (func(...any) error, func()) {
	return serialized(func(args ...any) error {
		return inTx(db, func(tx *sql.Tx) error {
			_, err := tx.Exec(query, args...)
			return err
		})
	})
}

// createTablesSQL returns the DDL for every table, index and trigger of the
// queue. %[1]s is the table prefix and %[2]s the SQLite expression for the
// current time, used as the timestamp default.
//
// NOTE(review): when no "_owners" row exists for the topic/consumer pair, the
// triggers' subquery yields NULL, `!=` yields NULL, and the WHEN clause does
// not hold, so the insert is let through — confirm unowned pairs are meant
// to be writable.
func createTablesSQL(prefix string) queryT {
	const tmplWrite = `
		CREATE TABLE IF NOT EXISTS "%[1]s_payloads" (
			id        INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
			timestamp TEXT    NOT NULL DEFAULT (%[2]s),
			topic     TEXT    NOT NULL,
			payload   BLOB    NOT NULL
		) STRICT;
		CREATE INDEX IF NOT EXISTS "%[1]s_payloads_topic"
			ON "%[1]s_payloads"(topic);

		CREATE TABLE IF NOT EXISTS "%[1]s_messages" (
			id         INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
			timestamp  TEXT    NOT NULL DEFAULT (%[2]s),
			uuid       BLOB    NOT NULL UNIQUE,
			flow_id    BLOB    NOT NULL,
			payload_id INTEGER NOT NULL REFERENCES "%[1]s_payloads"(id)
		) STRICT;
		CREATE INDEX IF NOT EXISTS "%[1]s_messages_flow_id"
			ON "%[1]s_messages"(flow_id);

		CREATE TABLE IF NOT EXISTS "%[1]s_offsets" (
			id          INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
			timestamp   TEXT    NOT NULL DEFAULT (%[2]s),
			consumer    TEXT    NOT NULL,
			message_id  INTEGER NOT NULL REFERENCES "%[1]s_messages"(id),
			instance_id INTEGER NOT NULL,
			UNIQUE (consumer, message_id)
		) STRICT;
		CREATE INDEX IF NOT EXISTS "%[1]s_offsets_consumer"
			ON "%[1]s_offsets"(consumer);

		CREATE TABLE IF NOT EXISTS "%[1]s_deadletters" (
			id          INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
			uuid        BLOB    NOT NULL UNIQUE,
			consumer    TEXT    NOT NULL,
			message_id  INTEGER NOT NULL REFERENCES "%[1]s_messages"(id),
			instance_id INTEGER NOT NULL,
			UNIQUE (consumer, message_id)
		) STRICT;
		CREATE INDEX IF NOT EXISTS "%[1]s_deadletters_consumer"
			ON "%[1]s_deadletters"(consumer);

		CREATE TABLE IF NOT EXISTS "%[1]s_replays" (
			id            INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
			deadletter_id INTEGER NOT NULL UNIQUE
				REFERENCES "%[1]s_deadletters"(id) ,
			message_id    INTEGER NOT NULL UNIQUE
				REFERENCES "%[1]s_messages"(id)
		) STRICT;

		CREATE TABLE IF NOT EXISTS "%[1]s_owners" (
			id       INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
			topic    TEXT    NOT NULL,
			consumer TEXT    NOT NULL,
			owner_id INTEGER NOT NULL,
			UNIQUE (topic, consumer)
		) STRICT;

		CREATE TRIGGER IF NOT EXISTS "%[1]s_check_instance_owns_topic"
		BEFORE INSERT ON "%[1]s_offsets"
		WHEN NEW.instance_id != (
			SELECT owner_id FROM "%[1]s_owners"
			WHERE topic = (
				SELECT "%[1]s_payloads".topic
				FROM "%[1]s_payloads"
				JOIN "%[1]s_messages" ON
					"%[1]s_payloads".id = "%[1]s_messages".payload_id
				WHERE "%[1]s_messages".id = NEW.message_id
			) AND consumer = NEW.consumer
		)
		BEGIN
			SELECT RAISE(
				ABORT,
				'instance does not own topic/consumer combo'
			);
		END;

		CREATE TRIGGER IF NOT EXISTS "%[1]s_check_can_publish_deadletter"
		BEFORE INSERT ON "%[1]s_deadletters"
		WHEN NEW.instance_id != (
			SELECT owner_id FROM "%[1]s_owners"
			WHERE topic = (
				SELECT "%[1]s_payloads".topic
				FROM "%[1]s_payloads"
				JOIN "%[1]s_messages" ON
					"%[1]s_payloads".id = "%[1]s_messages".payload_id
				WHERE "%[1]s_messages".id = NEW.message_id
			) AND consumer = NEW.consumer
		)
		BEGIN
			SELECT RAISE(
				ABORT,
				'Instance does not own topic/consumer combo'
			);
		END;
	`
	return queryT{
		write: fmt.Sprintf(tmplWrite, prefix, g.SQLiteNow),
	}
}

// createTables creates the tables, indexes and triggers for prefix that do
// not exist yet, all within one transaction.
func createTables(db *sql.DB, prefix string) error {
	q := createTablesSQL(prefix)
	return inTx(db, func(tx *sql.Tx) error {
		_, err := tx.Exec(q.write)
		return err
	})
}

func takeSQL(prefix string) queryT {
	const tmplWrite = `
		INSERT INTO "%[1]s_owners" (topic, consumer, owner_id)
		VALUES (?, ?, ?)
		ON CONFLICT (topic, consumer) DO UPDATE SET
			owner_id=excluded.owner_id;
	`
	return queryT{
		write: fmt.Sprintf(tmplWrite, prefix),
	}
}

// takeStmt prepares the statement recording this instance as the owner of a
// topic/consumer pair, replacing any previously recorded owner.
func takeStmt(
	cfg dbconfigT,
) (func(string, string) error, func() error, error) {
	q := takeSQL(cfg.prefix)

	writeStmt, err := cfg.shared.Prepare(q.write)
	if err != nil {
		return nil, nil, err
	}

	fn := func(topic string, consumer string) error {
		_, err := writeStmt.Exec(
			topic,
			consumer,
			cfg.instanceID,
		)
		return err
	}

	return fn, writeStmt.Close, nil
}

func publishSQL(prefix string) queryT {
	const tmplWrite = `
		INSERT INTO "%[1]s_payloads" (topic, payload)
		VALUES (?, ?);

		INSERT INTO "%[1]s_messages" (uuid, flow_id, payload_id)
		VALUES (?, ?, last_insert_rowid());
	`
	const tmplRead = `
		SELECT id, timestamp FROM "%[1]s_messages"
		WHERE uuid = ?;
	`
	return queryT{
		write: fmt.Sprintf(tmplWrite, prefix),
		read:  fmt.Sprintf(tmplRead, prefix),
	}
}

// publishStmt prepares publishing. The payload and message INSERTs (the
// second one links to the first via last_insert_rowid()) run together in one
// serialized transaction on a private connection; the message's row ID and
// timestamp are then read back through the shared connection.
func publishStmt(
	cfg dbconfigT,
) (func(UnsentMessage, uuid.UUID) (messageT, error), func() error, error) {
	q := publishSQL(cfg.prefix)

	readStmt, err := cfg.shared.Prepare(q.read)
	if err != nil {
		return nil, nil, err
	}

	privateDB, err := sql.Open(golite.DriverName, cfg.dbpath)
	if err != nil {
		readStmt.Close()
		return nil, nil, err
	}

	writeFn, writeFnClose := execSerialized(q.write, privateDB)

	fn := func(
		unsentMessage UnsentMessage,
		messageID uuid.UUID,
	) (messageT, error) {
		message := messageT{
			uuid:    messageID,
			topic:   unsentMessage.Topic,
			flowID:  unsentMessage.FlowID,
			payload: unsentMessage.Payload,
		}

		messageIDBytes := messageID[:]
		flowIDBytes := unsentMessage.FlowID[:]
		err := writeFn(
			unsentMessage.Topic,
			unsentMessage.Payload,
			messageIDBytes,
			flowIDBytes,
		)
		if err != nil {
			return messageT{}, err
		}

		var timestr string
		err = readStmt.QueryRow(messageIDBytes).Scan(
			&message.id,
			&timestr,
		)
		if err != nil {
			return messageT{}, err
		}

		message.timestamp, err = time.Parse(time.RFC3339Nano, timestr)
		if err != nil {
			return messageT{}, err
		}

		return message, nil
	}

	closeFn := func() error {
		writeFnClose()
		return g.SomeError(privateDB.Close(), readStmt.Close())
	}

	return fn, closeFn, nil
}

func findSQL(prefix string) queryT {
	const tmplRead = `
		SELECT
			"%[1]s_messages".id,
			"%[1]s_messages".timestamp,
			"%[1]s_messages".uuid,
			"%[1]s_payloads".payload
		FROM "%[1]s_messages"
		JOIN "%[1]s_payloads" ON
			"%[1]s_payloads".id = "%[1]s_messages".payload_id
		WHERE
			"%[1]s_payloads".topic = ? AND
			"%[1]s_messages".flow_id = ?
		ORDER BY "%[1]s_messages".id DESC
		LIMIT 1;
	`
	return queryT{
		read: fmt.Sprintf(tmplRead, prefix),
	}
}

// findStmt prepares the lookup of the last stored (highest row ID) message
// of a flow within a topic.
func findStmt(
	cfg dbconfigT,
) (func(string, uuid.UUID) (messageT, error), func() error, error) {
	q := findSQL(cfg.prefix)

	readStmt, err := cfg.shared.Prepare(q.read)
	if err != nil {
		return nil, nil, err
	}

	fn := func(topic string, flowID uuid.UUID) (messageT, error) {
		message := messageT{
			topic:  topic,
			flowID: flowID,
		}
		var (
			timestr        string
			messageIDBytes []byte
		)

		// err is local on purpose: assigning to the one captured from
		// Prepare would be a data race between concurrent callers.
		flowIDBytes := flowID[:]
		err := readStmt.QueryRow(topic, flowIDBytes).Scan(
			&message.id,
			&timestr,
			&messageIDBytes,
			&message.payload,
		)
		if err != nil {
			return messageT{}, err
		}

		message.uuid = uuid.UUID(messageIDBytes)
		message.timestamp, err = time.Parse(time.RFC3339Nano, timestr)
		if err != nil {
			return messageT{}, err
		}

		return message, nil
	}

	return fn, readStmt.Close, nil
}

func nextSQL(prefix string) queryT {
	const tmplRead = `
		SELECT
			(
				SELECT owner_id FROM "%[1]s_owners"
				WHERE
					topic    = ? AND
					consumer = ?
				LIMIT 1
			) AS owner_id,
			"%[1]s_messages".id,
			"%[1]s_messages".timestamp,
			"%[1]s_messages".uuid,
			"%[1]s_messages".flow_id,
			"%[1]s_payloads".payload
		FROM "%[1]s_messages"
		JOIN "%[1]s_payloads" ON
			"%[1]s_payloads".id = "%[1]s_messages".payload_id
		WHERE
			"%[1]s_payloads".topic = ? AND
			"%[1]s_messages".id NOT IN (
				SELECT message_id FROM "%[1]s_offsets"
				WHERE consumer = ?
			)
		ORDER BY "%[1]s_messages".id ASC
		LIMIT 1;
	`
	return queryT{
		read: fmt.Sprintf(tmplRead, prefix),
	}
}

// nextStmt prepares the lookup of the first (lowest row ID) message of topic
// that has no offset recorded for consumer. It returns an error when the
// recorded owner of the topic/consumer pair is not this instance.
func nextStmt(
	cfg dbconfigT,
) (func(string, string) (messageT, error), func() error, error) {
	q := nextSQL(cfg.prefix)

	readStmt, err := cfg.shared.Prepare(q.read)
	if err != nil {
		return nil, nil, err
	}

	fn := func(topic string, consumer string) (messageT, error) {
		message := messageT{
			topic: topic,
		}
		var (
			ownerID        int
			timestr        string
			messageIDBytes []byte
			flowIDBytes    []byte
		)

		// err is local on purpose, so concurrent callers don't race on
		// the one captured from Prepare.
		err := readStmt.QueryRow(topic, consumer, topic, consumer).Scan(
			&ownerID,
			&message.id,
			&timestr,
			&messageIDBytes,
			&flowIDBytes,
			&message.payload,
		)
		if err != nil {
			return messageT{}, err
		}

		if ownerID != cfg.instanceID {
			return messageT{}, fmt.Errorf(
				notOwnerErrorFmt,
				ownerID,
				topic,
				consumer,
				cfg.instanceID,
			)
		}

		message.uuid = uuid.UUID(messageIDBytes)
		message.flowID = uuid.UUID(flowIDBytes)
		message.timestamp, err = time.Parse(time.RFC3339Nano, timestr)
		if err != nil {
			return messageT{}, err
		}

		return message, nil
	}

	return fn, readStmt.Close, nil
}

// messageEach scans each row of rows as a message and hands it to callback,
// stopping at the first error. A nil rows is treated as an empty result.
// rows is closed before returning.
func messageEach(rows *sql.Rows, callback func(messageT) error) error {
	if rows == nil {
		return nil
	}

	for rows.Next() {
		var (
			message        messageT
			timestr        string
			messageIDBytes []byte
			flowIDBytes    []byte
		)
		err := rows.Scan(
			&message.id,
			&timestr,
			&messageIDBytes,
			&flowIDBytes,
			&message.topic,
			&message.payload,
		)
		if err != nil {
			rows.Close()
			return err
		}

		message.uuid = uuid.UUID(messageIDBytes)
		message.flowID = uuid.UUID(flowIDBytes)
		message.timestamp, err = time.Parse(time.RFC3339Nano, timestr)
		if err != nil {
			rows.Close()
			return err
		}

		err = callback(message)
		if err != nil {
			rows.Close()
			return err
		}
	}

	return g.WrapErrors(rows.Err(), rows.Close())
}

func pendingSQL(prefix string) queryT {
	const tmplRead = `
		SELECT
			"%[1]s_messages".id,
			"%[1]s_messages".timestamp,
			"%[1]s_messages".uuid,
			"%[1]s_messages".flow_id,
			"%[1]s_payloads".topic,
			"%[1]s_payloads".payload
		FROM "%[1]s_messages"
		JOIN "%[1]s_payloads" ON
			"%[1]s_payloads".id = "%[1]s_messages".payload_id
		WHERE
			"%[1]s_payloads".topic = ? AND
			"%[1]s_messages".id NOT IN (
				SELECT message_id FROM "%[1]s_offsets"
				WHERE consumer = ?
			)
		ORDER BY "%[1]s_messages".id ASC;
	`
	const tmplOwner = `
		SELECT owner_id FROM "%[1]s_owners"
		WHERE
			topic    = ? AND
			consumer = ?;
	`
	return queryT{
		read:  fmt.Sprintf(tmplRead, prefix),
		owner: fmt.Sprintf(tmplOwner, prefix),
	}
}

// pendingStmt prepares the query for every message of topic that has no
// offset recorded for consumer, in row-ID order. When this instance is not
// the recorded owner of the pair it returns (nil, nil) instead of rows.
func pendingStmt(
	cfg dbconfigT,
) (func(string, string) (*sql.Rows, error), func() error, error) {
	q := pendingSQL(cfg.prefix)

	readStmt, err := cfg.shared.Prepare(q.read)
	if err != nil {
		return nil, nil, err
	}

	ownerStmt, err := cfg.shared.Prepare(q.owner)
	if err != nil {
		readStmt.Close()
		return nil, nil, err
	}

	fn := func(topic string, consumer string) (*sql.Rows, error) {
		var ownerID int
		err := ownerStmt.QueryRow(topic, consumer).Scan(&ownerID)
		if err != nil {
			return nil, err
		}

		// best effort check, the final one is done during
		// commit within a transaction
		if ownerID != cfg.instanceID {
			return nil, nil
		}

		return readStmt.Query(topic, consumer)
	}

	closeFn := func() error {
		return g.SomeFnError(readStmt.Close, ownerStmt.Close)
	}

	return fn, closeFn, nil
}

func commitSQL(prefix string) queryT {
	const tmplWrite = `
		INSERT INTO "%[1]s_offsets" (consumer, message_id, instance_id)
		VALUES (?, (SELECT id FROM "%[1]s_messages" WHERE uuid = ?), ?);
	`
	return queryT{
		write: fmt.Sprintf(tmplWrite, prefix),
	}
}

// commitStmt prepares the recording of an offset marking a message as
// processed by consumer. The "_check_instance_owns_topic" trigger aborts the
// insert when the pair's recorded owner_id differs from this instance.
func commitStmt(
	cfg dbconfigT,
) (func(string, uuid.UUID) error, func() error, error) {
	q := commitSQL(cfg.prefix)

	writeStmt, err := cfg.shared.Prepare(q.write)
	if err != nil {
		return nil, nil, err
	}

	fn := func(consumer string, messageID uuid.UUID) error {
		messageIDBytes := messageID[:]
		// err is local on purpose, so concurrent commits don't race on
		// the one captured from Prepare.
		_, err := writeStmt.Exec(
			consumer,
			messageIDBytes,
			cfg.instanceID,
		)
		return err
	}

	return fn, writeStmt.Close, nil
}

func toDeadSQL(prefix string) queryT {
	const tmplWrite = `
		INSERT INTO "%[1]s_offsets"
			( consumer, message_id, instance_id)
		VALUES (
			?,
			(SELECT id FROM "%[1]s_messages" WHERE uuid = ?),
			?
		);

		INSERT INTO "%[1]s_deadletters"
			(uuid, consumer, message_id, instance_id)
		VALUES (
			?,
			?,
			(SELECT id FROM "%[1]s_messages" WHERE uuid = ?),
			?
		);
	`
	return queryT{
		write: fmt.Sprintf(tmplWrite, prefix),
	}
}

// toDeadStmt prepares moving a message to consumer's dead letters: one
// serialized transaction on a private connection both records the consumer's
// offset for the message and inserts the dead letter.
func toDeadStmt(
	cfg dbconfigT,
) (
	func(string, uuid.UUID, uuid.UUID) error,
	func() error,
	error,
) {
	q := toDeadSQL(cfg.prefix)

	privateDB, err := sql.Open(golite.DriverName, cfg.dbpath)
	if err != nil {
		return nil, nil, err
	}

	writeFn, writeFnClose := execSerialized(q.write, privateDB)

	fn := func(
		consumer string,
		messageID uuid.UUID,
		deadletterID uuid.UUID,
	) error {
		messageIDBytes := messageID[:]
		deadletterIDBytes := deadletterID[:]
		return writeFn(
			consumer,
			messageIDBytes,
			cfg.instanceID,
			deadletterIDBytes,
			consumer,
			messageIDBytes,
			cfg.instanceID,
		)
	}

	closeFn := func() error {
		writeFnClose()
		return privateDB.Close()
	}

	return fn, closeFn, nil
}

func replaySQL(prefix string) queryT {
	const tmplWrite = `
		INSERT INTO "%[1]s_messages" (uuid, flow_id, payload_id)
		SELECT
			?,
			"%[1]s_messages".flow_id,
			"%[1]s_messages".payload_id
		FROM "%[1]s_messages"
		JOIN "%[1]s_deadletters" ON
			"%[1]s_messages".id = "%[1]s_deadletters".message_id
		WHERE "%[1]s_deadletters".uuid = ?;

		INSERT INTO "%[1]s_replays" (deadletter_id, message_id)
		VALUES (
			(SELECT id FROM "%[1]s_deadletters" WHERE uuid = ?),
			last_insert_rowid()
		);
	`
	const tmplRead = `
		SELECT
			"%[1]s_messages".id,
			"%[1]s_messages".timestamp,
			"%[1]s_messages".flow_id,
			"%[1]s_payloads".topic,
			"%[1]s_payloads".payload
		FROM "%[1]s_messages"
		JOIN "%[1]s_payloads" ON
			"%[1]s_payloads".id = "%[1]s_messages".payload_id
		WHERE "%[1]s_messages".uuid = ?;
	`
	return queryT{
		write: fmt.Sprintf(tmplWrite, prefix),
		read:  fmt.Sprintf(tmplRead, prefix),
	}
}

// replayStmt prepares replaying a dead letter. In one serialized transaction
// on a private connection, a new message with the given ID is created that
// reuses the dead message's flow ID and payload, and the replay is recorded.
// The new message is then read back through the shared connection.
func replayStmt(
	cfg dbconfigT,
) (func(uuid.UUID, uuid.UUID) (messageT, error), func() error, error) {
	q := replaySQL(cfg.prefix)

	readStmt, err := cfg.shared.Prepare(q.read)
	if err != nil {
		return nil, nil, err
	}

	privateDB, err := sql.Open(golite.DriverName, cfg.dbpath)
	if err != nil {
		readStmt.Close()
		return nil, nil, err
	}

	writeFn, writeFnClose := execSerialized(q.write, privateDB)

	fn := func(
		deadletterID uuid.UUID,
		messageID uuid.UUID,
	) (messageT, error) {
		deadletterIDBytes := deadletterID[:]
		messageIDBytes := messageID[:]
		err := writeFn(
			messageIDBytes,
			deadletterIDBytes,
			deadletterIDBytes,
		)
		if err != nil {
			return messageT{}, err
		}

		message := messageT{
			uuid: messageID,
		}
		var (
			timestr     string
			flowIDBytes []byte
		)
		err = readStmt.QueryRow(messageIDBytes).Scan(
			&message.id,
			&timestr,
			&flowIDBytes,
			&message.topic,
			&message.payload,
		)
		if err != nil {
			return messageT{}, err
		}

		message.flowID = uuid.UUID(flowIDBytes)
		message.timestamp, err = time.Parse(time.RFC3339Nano, timestr)
		if err != nil {
			return messageT{}, err
		}

		return message, nil
	}

	closeFn := func() error {
		writeFnClose()
		return g.SomeError(privateDB.Close(), readStmt.Close())
	}

	return fn, closeFn, nil
}

func oneDeadSQL(prefix string) queryT {
	const tmplRead = `
		SELECT
			"%[1]s_deadletters".uuid,
			"%[1]s_offsets".timestamp,
			"%[1]s_messages".uuid
		FROM "%[1]s_deadletters"
		JOIN "%[1]s_offsets" ON
			"%[1]s_deadletters".message_id = "%[1]s_offsets".message_id
		JOIN "%[1]s_messages" ON
			"%[1]s_deadletters".message_id = "%[1]s_messages".id
		JOIN "%[1]s_payloads" ON
			"%[1]s_messages".payload_id = "%[1]s_payloads".id
		WHERE
			"%[1]s_payloads".topic = ? AND
			"%[1]s_deadletters".consumer = ? AND
			"%[1]s_offsets".consumer = ? AND
			"%[1]s_deadletters".id NOT IN (
				SELECT deadletter_id FROM "%[1]s_replays"
			)
		ORDER BY "%[1]s_deadletters".id ASC
		LIMIT 1;
	`
	return queryT{
		read: fmt.Sprintf(tmplRead, prefix),
	}
}

// oneDeadStmt prepares the lookup of consumer's first (lowest row ID)
// not-yet-replayed dead letter on topic. The dead letter's timestamp is the
// one of the consumer's offset for the message.
func oneDeadStmt(
	cfg dbconfigT,
) (func(string, string) (deadletterT, error), func() error, error) {
	q := oneDeadSQL(cfg.prefix)

	readStmt, err := cfg.shared.Prepare(q.read)
	if err != nil {
		return nil, nil, err
	}

	fn := func(topic string, consumer string) (deadletterT, error) {
		deadletter := deadletterT{
			consumer: consumer,
		}
		var (
			deadletterIDBytes []byte
			timestr           string
			messageIDBytes    []byte
		)
		err := readStmt.QueryRow(topic, consumer, consumer).Scan(
			&deadletterIDBytes,
			&timestr,
			&messageIDBytes,
		)
		if err != nil {
			return deadletterT{}, err
		}

		deadletter.uuid = uuid.UUID(deadletterIDBytes)
		deadletter.messageID = uuid.UUID(messageIDBytes)
		deadletter.timestamp, err = time.Parse(
			time.RFC3339Nano,
			timestr,
		)
		if err != nil {
			return deadletterT{}, err
		}

		return deadletter, nil
	}

	return fn, readStmt.Close, nil
}

// deadletterEach scans each row of rows as a dead letter plus its message and
// hands them to callback, stopping at the first error. As in messageEach, a
// nil rows is an empty result, and rows is closed before returning.
func deadletterEach(
	rows *sql.Rows,
	callback func(deadletterT, messageT) error,
) error {
	if rows == nil {
		return nil
	}

	for rows.Next() {
		var (
			deadletter        deadletterT
			deadletterIDBytes []byte
			deadletterTimestr string
			message           messageT
			messageTimestr    string
			messageIDBytes    []byte
			flowIDBytes       []byte
		)
		err := rows.Scan(
			&deadletterIDBytes,
			&message.id,
			&deadletterTimestr,
			&deadletter.consumer,
			&messageTimestr,
			&messageIDBytes,
			&flowIDBytes,
			&message.topic,
			&message.payload,
		)
		if err != nil {
			rows.Close()
			return err
		}

		deadletter.uuid = uuid.UUID(deadletterIDBytes)
		deadletter.messageID = uuid.UUID(messageIDBytes)
		message.uuid = uuid.UUID(messageIDBytes)
		message.flowID = uuid.UUID(flowIDBytes)
		message.timestamp, err = time.Parse(
			time.RFC3339Nano,
			messageTimestr,
		)
		if err != nil {
			rows.Close()
			return err
		}

		deadletter.timestamp, err = time.Parse(
			time.RFC3339Nano,
			deadletterTimestr,
		)
		if err != nil {
			rows.Close()
			return err
		}

		err = callback(deadletter, message)
		if err != nil {
			rows.Close()
			return err
		}
	}

	return g.WrapErrors(rows.Err(), rows.Close())
}

func allDeadSQL(prefix string) queryT {
	const tmplRead = `
		SELECT
			"%[1]s_deadletters".uuid,
			"%[1]s_deadletters".message_id,
			"%[1]s_offsets".timestamp,
			"%[1]s_offsets".consumer,
			"%[1]s_messages".timestamp,
			"%[1]s_messages".uuid,
			"%[1]s_messages".flow_id,
			"%[1]s_payloads".topic,
			"%[1]s_payloads".payload
		FROM "%[1]s_deadletters"
		JOIN "%[1]s_offsets" ON
			"%[1]s_deadletters".message_id = "%[1]s_offsets".message_id
		JOIN "%[1]s_messages" ON
			"%[1]s_deadletters".message_id = "%[1]s_messages".id
		JOIN "%[1]s_payloads" ON
			"%[1]s_messages".payload_id = "%[1]s_payloads".id
		WHERE
			"%[1]s_payloads".topic = ? AND
			"%[1]s_deadletters".consumer = ? AND
			"%[1]s_offsets".consumer = ? AND
			"%[1]s_deadletters".id NOT IN (
				SELECT deadletter_id FROM "%[1]s_replays"
			)
		ORDER BY "%[1]s_deadletters".id ASC;
	`
	return queryT{
		read: fmt.Sprintf(tmplRead, prefix),
	}
}

// allDeadStmt prepares the query for all of consumer's not-yet-replayed dead
// letters on topic, together with their messages, in row-ID order.
func allDeadStmt(
	cfg dbconfigT,
) (func(string, string) (*sql.Rows, error), func() error, error) {
	q := allDeadSQL(cfg.prefix)

	readStmt, err := cfg.shared.Prepare(q.read)
	if err != nil {
		return nil, nil, err
	}

	fn := func(topic string, consumer string) (*sql.Rows, error) {
		return readStmt.Query(topic, consumer, consumer)
	}

	return fn, readStmt.Close, nil
}

func sizeSQL(prefix string) queryT {
	const tmplRead = `
		SELECT
			COUNT(1) as size
		FROM "%[1]s_messages"
		JOIN "%[1]s_payloads" ON
			"%[1]s_messages".payload_id = "%[1]s_payloads".id
		WHERE "%[1]s_payloads".topic = ?;
	`
	return queryT{
		read: fmt.Sprintf(tmplRead, prefix),
	}
}

// sizeStmt prepares the count of every stored message of topic. On error the
// count returned is -1.
func sizeStmt(
	cfg dbconfigT,
) (func(string) (int, error), func() error, error) {
	q := sizeSQL(cfg.prefix)

	readStmt, err := cfg.shared.Prepare(q.read)
	if err != nil {
		return nil, nil, err
	}

	fn := func(topic string) (int, error) {
		var size int
		err := readStmt.QueryRow(topic).Scan(&size)
		if err != nil {
			return -1, err
		}

		return size, nil
	}

	return fn, readStmt.Close, nil
}

func countSQL(prefix string) queryT {
	const tmplRead = `
		SELECT
			COUNT(1) as count
		FROM "%[1]s_messages"
		JOIN "%[1]s_offsets" ON
			"%[1]s_messages".id = "%[1]s_offsets".message_id
		JOIN "%[1]s_payloads" ON
			"%[1]s_messages".payload_id = "%[1]s_payloads".id
		WHERE
			"%[1]s_payloads".topic = ? AND
			"%[1]s_offsets".consumer = ?;
	`
	return queryT{
		read: fmt.Sprintf(tmplRead, prefix),
	}
}

// countStmt prepares the count of messages of topic that have an offset
// recorded for consumer (committed or dead-lettered). On error the count
// returned is -1.
func countStmt(
	cfg dbconfigT,
) (func(string, string) (int, error), func() error, error) {
	q := countSQL(cfg.prefix)

	readStmt, err := cfg.shared.Prepare(q.read)
	if err != nil {
		return nil, nil, err
	}

	fn := func(topic string, consumer string) (int, error) {
		var count int
		err := readStmt.QueryRow(topic, consumer).Scan(&count)
		if err != nil {
			return -1, err
		}

		return count, nil
	}

	return fn, readStmt.Close, nil
}

func hasDataSQL(prefix string) queryT {
	const tmplRead = `
		SELECT 1 as data
		FROM "%[1]s_messages"
		JOIN "%[1]s_payloads" ON
			"%[1]s_payloads".id = "%[1]s_messages".payload_id
		WHERE
			"%[1]s_payloads".topic = ? AND
			"%[1]s_messages".id NOT IN (
				SELECT message_id FROM "%[1]s_offsets"
				WHERE consumer = ?
			)
		LIMIT 1;
	`
	return queryT{
		read: fmt.Sprintf(tmplRead, prefix),
	}
}

// hasDataStmt prepares the check for whether topic has any message without an
// offset recorded for consumer.
func hasDataStmt(
	cfg dbconfigT,
) (func(string, string) (bool, error), func() error, error) {
	q := hasDataSQL(cfg.prefix)

	readStmt, err := cfg.shared.Prepare(q.read)
	if err != nil {
		return nil, nil, err
	}

	fn := func(topic string, consumer string) (bool, error) {
		var data int
		err := readStmt.QueryRow(topic, consumer).Scan(&data)
		if err == sql.ErrNoRows {
			return false, nil
		}
		if err != nil {
			return false, err
		}

		return true, nil
	}

	return fn, readStmt.Close, nil
}

// initDB validates prefix, opens the database at dbpath, creates the queue's
// tables and prepares every statement, returning them as a queriesT.
//
// Every query runs its statement while holding a read lock on connMutex and
// close takes the write lock, so close waits for statements in flight.
// pending and allDead release the lock before iterating their rows.
// notifyFn is called in a new goroutine after each successful publish and
// replay. If any step fails, everything already set up is released.
func initDB(
	dbpath string,
	prefix string,
	notifyFn func(messageT),
	instanceID int,
) (queriesT, error) {
	err := g.ValidateSQLTablePrefix(prefix)
	if err != nil {
		return queriesT{}, err
	}

	shared, err := sql.Open(golite.DriverName, dbpath)
	if err != nil {
		return queriesT{}, err
	}

	cfg := dbconfigT{
		shared:     shared,
		dbpath:     dbpath,
		prefix:     prefix,
		instanceID: instanceID,
	}

	createTablesErr := createTables(shared, prefix)
	take, takeClose, takeErr := takeStmt(cfg)
	publish, publishClose, publishErr := publishStmt(cfg)
	find, findClose, findErr := findStmt(cfg)
	next, nextClose, nextErr := nextStmt(cfg)
	pending, pendingClose, pendingErr := pendingStmt(cfg)
	commit, commitClose, commitErr := commitStmt(cfg)
	toDead, toDeadClose, toDeadErr := toDeadStmt(cfg)
	replay, replayClose, replayErr := replayStmt(cfg)
	oneDead, oneDeadClose, oneDeadErr := oneDeadStmt(cfg)
	allDead, allDeadClose, allDeadErr := allDeadStmt(cfg)
	size, sizeClose, sizeErr := sizeStmt(cfg)
	count, countClose, countErr := countStmt(cfg)
	hasData, hasDataClose, hasDataErr := hasDataStmt(cfg)

	err = g.SomeError(
		createTablesErr,
		takeErr,
		publishErr,
		findErr,
		nextErr,
		pendingErr,
		commitErr,
		toDeadErr,
		replayErr,
		oneDeadErr,
		allDeadErr,
		sizeErr,
		countErr,
		hasDataErr,
	)
	if err != nil {
		// Release whatever was set up successfully: prepared
		// statements, the private connections and serializer
		// goroutines some constructors own, and the shared pool.
		// Constructors that failed returned a nil close function.
		closers := []func() error{
			takeClose,
			publishClose,
			findClose,
			nextClose,
			pendingClose,
			commitClose,
			toDeadClose,
			replayClose,
			oneDeadClose,
			allDeadClose,
			sizeClose,
			countClose,
			hasDataClose,
			shared.Close,
		}
		for _, closer := range closers {
			if closer != nil {
				// Best effort: err is the failure to report.
				_ = closer()
			}
		}
		return queriesT{}, err
	}

	closeFn := func() error {
		return g.SomeFnError(
			takeClose,
			publishClose,
			findClose,
			nextClose,
			pendingClose,
			commitClose,
			toDeadClose,
			replayClose,
			oneDeadClose,
			allDeadClose,
			sizeClose,
			countClose,
			hasDataClose,
			shared.Close,
		)
	}

	var connMutex sync.RWMutex
	return queriesT{
		take: func(topic string, consumer string) error {
			connMutex.RLock()
			defer connMutex.RUnlock()
			return take(topic, consumer)
		},
		publish: func(
			unsent UnsentMessage,
			messageID uuid.UUID,
		) (messageT, error) {
			// The lock lives in an inner function: a defer inside a
			// bare block would only run when this function returns.
			message, err := func() (messageT, error) {
				connMutex.RLock()
				defer connMutex.RUnlock()
				return publish(unsent, messageID)
			}()
			if err != nil {
				return messageT{}, err
			}

			go notifyFn(message)
			return message, nil
		},
		find: func(topic string, flowID uuid.UUID) (messageT, error) {
			connMutex.RLock()
			defer connMutex.RUnlock()
			return find(topic, flowID)
		},
		next: func(topic string, consumer string) (messageT, error) {
			connMutex.RLock()
			defer connMutex.RUnlock()
			return next(topic, consumer)
		},
		pending: func(
			topic string,
			consumer string,
			callback func(messageT) error,
		) error {
			// Only the query holds the read lock; the rows are
			// iterated after releasing it. That way a callback which
			// uses these queries does not read-lock recursively,
			// which sync.RWMutex forbids: it deadlocks as soon as
			// close is waiting on the write lock.
			rows, err := func() (*sql.Rows, error) {
				connMutex.RLock()
				defer connMutex.RUnlock()
				return pending(topic, consumer)
			}()
			if err != nil {
				return err
			}

			return messageEach(rows, callback)
		},
		commit: func(consumer string, messageID uuid.UUID) error {
			connMutex.RLock()
			defer connMutex.RUnlock()
			return commit(consumer, messageID)
		},
		toDead: func(
			consumer string,
			messageID uuid.UUID,
			deadletterID uuid.UUID,
		) error {
			connMutex.RLock()
			defer connMutex.RUnlock()
			return toDead(consumer, messageID, deadletterID)
		},
		replay: func(
			deadletterID uuid.UUID,
			messageID uuid.UUID,
		) (messageT, error) {
			message, err := func() (messageT, error) {
				connMutex.RLock()
				defer connMutex.RUnlock()
				return replay(deadletterID, messageID)
			}()
			if err != nil {
				return messageT{}, err
			}

			go notifyFn(message)
			return message, nil
		},
		oneDead: func(topic string, consumer string) (deadletterT, error) {
			connMutex.RLock()
			defer connMutex.RUnlock()
			return oneDead(topic, consumer)
		},
		allDead: func(
			topic string,
			consumer string,
			callback func(deadletterT, messageT) error,
		) error {
			// Same scoping as pending: callbacks run without the lock.
			rows, err := func() (*sql.Rows, error) {
				connMutex.RLock()
				defer connMutex.RUnlock()
				return allDead(topic, consumer)
			}()
			if err != nil {
				return err
			}

			return deadletterEach(rows, callback)
		},
		size: func(topic string) (int, error) {
			connMutex.RLock()
			defer connMutex.RUnlock()
			return size(topic)
		},
		count: func(topic string, consumer string) (int, error) {
			connMutex.RLock()
			defer connMutex.RUnlock()
			return count(topic, consumer)
		},
		hasData: func(topic string, consumer string) (bool, error) {
			connMutex.RLock()
			defer connMutex.RUnlock()
			return hasData(topic, consumer)
		},
		close: func() error {
			connMutex.Lock()
			defer connMutex.Unlock()
			return closeFn()
		},
	}, nil
}

// newPinger creates a pinger backed by a channel with room for one value.
// tryPing never blocks: the value is dropped when a ping is already pending
// or the pinger is closed. onPing calls cb for every ping received, until the
// pinger is closed. close may be called more than once.
func newPinger[T any]() pingerT[T] {
	channel := make(chan T, 1)
	closed := false
	var rwmutex sync.RWMutex
	return pingerT[T]{
		tryPing: func(x T) {
			// The read lock keeps close from closing the channel
			// while a send is attempted.
			rwmutex.RLock()
			defer rwmutex.RUnlock()
			if closed {
				return
			}
			select {
			case channel <- x:
			default:
			}
		},
		onPing: func(cb func(T)) {
			for x := range channel {
				cb(x)
			}
		},
		closed: func() bool {
			rwmutex.RLock()
			defer rwmutex.RUnlock()
			return closed
		},
		close: func() {
			rwmutex.Lock()
			defer rwmutex.Unlock()
			if closed {
				return
			}
			close(channel)
			closed = true
		},
	}
}

// makeSubscriptionsFuncs creates an empty subscriptions set guarded by an
// RWMutex: read runs its callback under the read lock, write under the
// write lock.
func makeSubscriptionsFuncs() subscriptionsT {
	var rwmutex sync.RWMutex
	subscriptions := subscriptionsSetM{}
	return subscriptionsT{
		read: func(callback func(subscriptionsSetM) error) error {
			rwmutex.RLock()
			defer rwmutex.RUnlock()
			return callback(subscriptions)
		},
		write: func(callback func(subscriptionsSetM) error) error {
			rwmutex.Lock()
			defer rwmutex.Unlock()
			return callback(subscriptions)
		},
	}
}

// makeNotifyFn returns the function that announces a newly stored message. It
// pings every consumer subscribed to the message's topic, hands the payload
// to the waiters of the message's flow, and then pings the given pinger.
//
// Try notifying the consumer that they have data to work with. If they're
// already full, we simply drop the notification, as on each they'll look for
// all pending messages and process them all. So dropping the event here
// doesn't mean not notifying the consumer, but simply acknowledging that the
// existing notifications are enough for them to work with, without letting any
// message slip through.
func makeNotifyFn( readFn func(func(subscriptionsSetM) error) error, pinger pingerT[struct{}], ) func(messageT) { return func(message messageT) { readFn(func(set subscriptionsSetM) error { topicSub, ok := set[message.topic] if !ok { return nil } for _, consumer := range topicSub.consumers { consumer.pinger.tryPing(struct{}{}) } waiters := topicSub.waiters[message.flowID] for _, waiter := range waiters { waiter.pinger.tryPing(message.payload) } return nil }) pinger.tryPing(struct{}{}) } } func collectClosedWaiters( set subscriptionsSetM, ) map[string]map[uuid.UUID][]string { waiters := map[string]map[uuid.UUID][]string{} for topic, topicSub := range set { waiters[topic] = map[uuid.UUID][]string{} for flowID, waitersByName := range topicSub.waiters { names := []string{} for name, waiter := range waitersByName { if (*waiter.closed)() { names = append(names, name) } } waiters[topic][flowID] = names } } return waiters } func trimEmptyLeaves(closedWaiters map[string]map[uuid.UUID][]string) { for topic, waiters := range closedWaiters { for flowID, names := range waiters { if len(names) == 0 { delete(closedWaiters[topic], flowID) } } if len(waiters) == 0 { delete(closedWaiters, topic) } } } func deleteIfEmpty(set subscriptionsSetM, topic string) { topicSub, ok := set[topic] if !ok { return } emptyConsumers := len(topicSub.consumers) == 0 emptyWaiters := len(topicSub.waiters) == 0 if emptyConsumers && emptyWaiters { delete(set, topic) } } func deleteEmptyTopics(set subscriptionsSetM) { for topic, _ := range set { deleteIfEmpty(set, topic) } } func removeClosedWaiters( set subscriptionsSetM, closedWaiters map[string]map[uuid.UUID][]string, ) { for topic, waiters := range closedWaiters { _, ok := set[topic] if !ok { continue } for flowID, names := range waiters { if set[topic].waiters[flowID] == nil { continue } for _, name := range names { delete(set[topic].waiters[flowID], name) } if len(set[topic].waiters[flowID]) == 0 { delete(set[topic].waiters, flowID) } } } 
deleteEmptyTopics(set) } func reapClosedWaiters( readFn func(func(subscriptionsSetM) error) error, writeFn func(func(subscriptionsSetM) error) error, ) { var closedWaiters map[string]map[uuid.UUID][]string readFn(func(set subscriptionsSetM) error { closedWaiters = collectClosedWaiters(set) return nil }) trimEmptyLeaves(closedWaiters) if len(closedWaiters) == 0 { return } writeFn(func(set subscriptionsSetM) error { removeClosedWaiters(set, closedWaiters) return nil }) } func everyNthCall[T any](n int, fn func(T)) func(T) { i := 0 return func(x T) { i++ if i == n { i = 0 fn(x) } } } func runReaper( onPing func(func(struct{})), readFn func(func(subscriptionsSetM) error) error, writeFn func(func(subscriptionsSetM) error) error, ) { onPing(everyNthCall(reaperSkipCount, func(struct{}) { reapClosedWaiters(readFn, writeFn) })) } func NewWithPrefix(databasePath string, prefix string) (IQueue, error) { subscriptions := makeSubscriptionsFuncs() pinger := newPinger[struct{}]() notifyFn := makeNotifyFn(subscriptions.read, pinger) queries, err := initDB(databasePath, prefix, notifyFn, os.Getpid()) if err != nil { return queueT{}, err } go runReaper(pinger.onPing, subscriptions.read, subscriptions.write) return queueT{ queries: queries, subscriptions: subscriptions, pinger: pinger, }, nil } func New(databasePath string) (IQueue, error) { return NewWithPrefix(databasePath, defaultPrefix) } func asPublicMessage(message messageT) Message { return Message{ ID: message.uuid, Timestamp: message.timestamp, Topic: message.topic, FlowID: message.flowID, Payload: message.payload, } } func (queue queueT) Publish(unsent UnsentMessage) (Message, error) { message, err := queue.queries.publish(unsent, uuid.New()) if err != nil { return Message{}, err } return asPublicMessage(message), nil } func registerConsumerFn(consumer consumerT) func(subscriptionsSetM) error { topicSub := topicSubscriptionT{ consumers: map[string]consumerT{}, waiters: map[uuid.UUID]map[string]waiterT{}, } return func(set 
subscriptionsSetM) error { topic := consumer.data.topic _, ok := set[topic] if !ok { set[topic] = topicSub } set[topic].consumers[consumer.data.name] = consumer return nil } } func registerWaiterFn(waiter waiterT) func(subscriptionsSetM) error { topicSub := topicSubscriptionT{ consumers: map[string]consumerT{}, waiters: map[uuid.UUID]map[string]waiterT{}, } waiters := map[string]waiterT{} return func(set subscriptionsSetM) error { var ( topic = waiter.data.topic flowID = waiter.data.flowID ) _, ok := set[topic] if !ok { set[topic] = topicSub } if set[topic].waiters[flowID] == nil { set[topic].waiters[flowID] = waiters } set[topic].waiters[flowID][waiter.data.name] = waiter return nil } } func makeConsumeOneFn( data consumerDataT, callback func(Message) error, successFn func(string, uuid.UUID) error, errorFn func(string, uuid.UUID, uuid.UUID) error, ) func(messageT) error { return func(message messageT) error { err := callback(asPublicMessage(message)) if err != nil { g.Info( "consumer failed", "fiinha-consumer", "topic", data.topic, "consumer", data.name, "error", err, slog.Group( "message", "id", message.id, "flow-id", message.flowID.String(), ), ) return errorFn(data.name, message.uuid, uuid.New()) } return successFn(data.name, message.uuid) } } func makeConsumeAllFn( data consumerDataT, consumeOneFn func(messageT) error, eachFn func(string, string, func(messageT) error) error, ) func(struct{}) { return func(struct{}) { err := eachFn(data.topic, data.name, consumeOneFn) if err != nil { g.Warning( "eachFn failed", "fiinha-consume-all", "topic", data.topic, "consumer", data.name, "error", err, "circuit-breaker-enabled?", false, ) } } } func makeWaitFn(channel chan []byte, closeFn func()) func([]byte) { closed := false var mutex sync.Mutex return func(payload []byte) { mutex.Lock() defer mutex.Unlock() if closed { return } closeFn() channel <- payload close(channel) closed = true } } func runConsumer(onPing func(func(struct{})), consumeAllFn func(struct{})) { 
consumeAllFn(struct{}{}) onPing(consumeAllFn) } func tryFinding( findFn func(string, uuid.UUID) (messageT, error), topic string, flowID uuid.UUID, waitFn func([]byte), ) { message, err := findFn(topic, flowID) if err != nil { return } waitFn(message.payload) } func (queue queueT) Subscribe( topic string, name string, callback func(Message) error, ) error { data := consumerDataT{ topic: topic, name: name, } pinger := newPinger[struct{}]() consumer := consumerT{ data: data, callback: callback, pinger: pinger, close: &pinger.close, } consumeOneFn := makeConsumeOneFn( consumer.data, consumer.callback, queue.queries.commit, queue.queries.toDead, ) consumeAllFn := makeConsumeAllFn( consumer.data, consumeOneFn, queue.queries.pending, ) err := queue.queries.take(topic, name) if err != nil { return err } queue.subscriptions.write(registerConsumerFn(consumer)) go runConsumer(pinger.onPing, consumeAllFn) return nil } type Waiter struct{ Channel <-chan []byte Close func() } func (queue queueT) WaitFor( topic string, flowID uuid.UUID, name string, ) Waiter { data := waiterDataT{ topic: topic, flowID: flowID, name: name, } pinger := newPinger[[]byte]() waiter := waiterT{ data: data, pinger: pinger, closed: &pinger.closed, close: &pinger.close, } channel := make(chan []byte, 1) waitFn := makeWaitFn(channel, (*waiter.close)) closeFn := func() { queue.subscriptions.read(func(set subscriptionsSetM) error { (*set[topic].waiters[flowID][name].close)() return nil }) } queue.subscriptions.write(registerWaiterFn(waiter)) tryFinding(queue.queries.find, topic, flowID, waitFn) go pinger.onPing(waitFn) return Waiter{channel, closeFn} } func unsubscribeIfExistsFn( topic string, name string, ) func(subscriptionsSetM) error { return func(set subscriptionsSetM) error { topicSub, ok := set[topic] if !ok { return nil } consumer, ok := topicSub.consumers[name] if !ok { return nil } (*consumer.close)() delete(set[topic].consumers, name) deleteIfEmpty(set, topic) return nil } } func (queue queueT) 
Unsubscribe(topic string, name string) { queue.subscriptions.write(unsubscribeIfExistsFn(topic, name)) } func cleanSubscriptions(set subscriptionsSetM) error { for _, topicSub := range set { for _, consumer := range topicSub.consumers { (*consumer.close)() } for _, waiters := range topicSub.waiters { for _, waiter := range waiters { (*waiter.close)() } } } return nil } func (queue queueT) Close() error { queue.pinger.close() return g.WrapErrors( queue.subscriptions.write(cleanSubscriptions), queue.queries.close(), ) } func topicGetopt(args argsT, w io.Writer) (argsT, bool) { if len(args.args) == 0 { fmt.Fprintf(w, "Missing TOPIC.\n") return args, false } args.topic = args.args[0] return args, true } func topicConsumerGetopt(args argsT, w io.Writer) (argsT, bool) { fs := flag.NewFlagSet("", flag.ContinueOnError) fs.Usage = func() {} fs.SetOutput(w) consumer := fs.String( "C", "default-consumer", "The name of the consumer to be used", ) if fs.Parse(args.args) != nil { return args, false } subArgs := fs.Args() if len(subArgs) == 0 { fmt.Fprintf(w, "Missing TOPIC.\n") return args, false } args.consumer = *consumer args.topic = subArgs[0] return args, true } func inExec( args argsT, queries queriesT, r io.Reader, w io.Writer, ) (int, error) { payload, err := io.ReadAll(r) if err != nil { return 1, err } unsent := UnsentMessage{ Topic: args.topic, FlowID: uuid.New(), Payload: payload, } message, err := queries.publish(unsent, uuid.New()) if err != nil { return 1, err } fmt.Fprintf(w, "%s\n", message.uuid.String()) return 0, nil } func outExec( args argsT, queries queriesT, _ io.Reader, w io.Writer, ) (int, error) { err := queries.take(args.topic, args.consumer) if err != nil { return 1, err } message, err := queries.next(args.topic, args.consumer) if err == sql.ErrNoRows { return 3, nil } if err != nil { return 1, err } fmt.Fprintln(w, string(message.payload)) return 0, nil } func commitExec( args argsT, queries queriesT, r io.Reader, w io.Writer, ) (int, error) { err := 
queries.take(args.topic, args.consumer) if err != nil { return 1, err } message, err := queries.next(args.topic, args.consumer) if err != nil { return 1, err } err = queries.commit(args.consumer, message.uuid) if err != nil { return 1, err } return 0, nil } func deadExec( args argsT, queries queriesT, r io.Reader, w io.Writer, ) (int, error) { err := queries.take(args.topic, args.consumer) if err != nil { return 1, err } message, err := queries.next(args.topic, args.consumer) if err != nil { return 1, err } err = queries.toDead(args.consumer, message.uuid, uuid.New()) if err != nil { return 1, err } return 0, nil } func listDeadExec( args argsT, queries queriesT, r io.Reader, w io.Writer, ) (int, error) { eachFn := func(deadletter deadletterT, _ messageT) error { fmt.Fprintf( w, "%s\t%s\t%s\n", deadletter.uuid.String(), deadletter.timestamp.Format(time.RFC3339), deadletter.consumer, ) return nil } err := queries.allDead(args.topic, args.consumer, eachFn) if err != nil { return 1, err } return 0, nil } func replayExec( args argsT, queries queriesT, r io.Reader, w io.Writer, ) (int, error) { deadletter, err := queries.oneDead(args.topic, args.consumer) if err != nil { return 1, err } _, err = queries.replay(deadletter.uuid, uuid.New()) if err != nil { return 1, err } return 0, nil } func sizeExec( args argsT, queries queriesT, r io.Reader, w io.Writer, ) (int, error) { size, err := queries.size(args.topic) if err != nil { return 1, err } fmt.Fprintln(w, size) return 0, nil } func countExec( args argsT, queries queriesT, r io.Reader, w io.Writer, ) (int, error) { count, err := queries.count(args.topic, args.consumer) if err != nil { return 1, err } fmt.Fprintln(w, count) return 0, nil } func hasDataExec( args argsT, queries queriesT, r io.Reader, w io.Writer, ) (int, error) { hasData, err := queries.hasData(args.topic, args.consumer) if err != nil { return 1, err } if hasData { return 0, nil } else { return 1, nil } } func usage(argv0 string, w io.Writer) { 
fmt.Fprintf( w, "Usage: %s [-f FILE] [-p PREFIX] COMMAND [OPTIONS]\n", argv0, ) } func getopt( allArgs []string, commandsMap map[string]commandT, w io.Writer, ) (argsT, commandT, int) { argv0 := allArgs[0] argv := allArgs[1:] fs := flag.NewFlagSet("", flag.ContinueOnError) fs.Usage = func() {} fs.SetOutput(w) databasePath := fs.String( "f", "fiinha.db", "The path to the file where the queue is kept", ) prefix := fs.String( "p", defaultPrefix, "The fiinha prefix of the table names", ) if fs.Parse(argv) != nil { usage(argv0, w) return argsT{}, commandT{}, 2 } subArgs := fs.Args() if len(subArgs) == 0 { fmt.Fprintf(w, "Missing COMMAND.\n") usage(argv0, w) return argsT{}, commandT{}, 2 } args := argsT{ databasePath: *databasePath, prefix: *prefix, command: subArgs[0], allArgs: allArgs, args: subArgs[1:], } command := commandsMap[args.command] if command.name == "" { fmt.Fprintf(w, "Bad COMMAND: \"%s\".\n", args.command) usage(argv0, w) return argsT{}, commandT{}, 2 } args, ok := command.getopt(args, w) if !ok { usage(argv0, w) return argsT{}, commandT{}, 2 } return args, command, 0 } func runCommand( args argsT, command commandT, stdin io.Reader, stdout io.Writer, stderr io.Writer, ) int { iqueue, err := NewWithPrefix(args.databasePath, args.prefix) if err != nil { fmt.Fprintln(stderr, err) return 1 } defer iqueue.Close() rc, err := command.exec(args, iqueue.(queueT).queries, stdin, stdout) if err != nil { fmt.Fprintln(stderr, err) } return rc } var commands = map[string]commandT{ "in": commandT{ name: "in", getopt: topicGetopt, exec: inExec, }, "out": commandT{ name: "out", getopt: topicConsumerGetopt, exec: outExec, }, "commit": commandT{ name: "commit", getopt: topicConsumerGetopt, exec: commitExec, }, "dead": commandT{ name: "dead", getopt: topicConsumerGetopt, exec: deadExec, }, "ls-dead": commandT{ name: "ls-dead", getopt: topicConsumerGetopt, exec: listDeadExec, }, "replay": commandT{ name: "replay", getopt: topicConsumerGetopt, exec: replayExec, }, "size": 
commandT{ name: "size", getopt: topicGetopt, exec: sizeExec, }, "count": commandT{ name: "count", getopt: topicConsumerGetopt, exec: countExec, }, "has-data": commandT{ name: "has-data", getopt: topicConsumerGetopt, exec: hasDataExec, }, } func Main() { g.Init() args, command, rc := getopt(os.Args, commands, os.Stderr) if rc != 0 { os.Exit(rc) } os.Exit(runCommand(args, command, os.Stdin, os.Stdout, os.Stderr)) }