// Package cracha implements e-mail/password authentication on top of an
// SQLite database: registration, e-mail confirmation, login sessions,
// password reset/change and logout. Some steps (user creation, confirmation
// and forgot-password requests) are published to a fiinha message queue.
package cracha

import (
	"crypto/subtle"
	"database/sql"
	"encoding/hex"
	"encoding/json"
	"errors"
	"fmt"
	"runtime"
	"sync"
	"time"

	"fiinha"
	g "gobang"
	"golite"
	"scrypt"
	"uuid"
)

const (
	defaultPrefix    = "cracha"
	rollbackErrorFmt = "rollback error: %w; while executing: %w"

	// Queue topics published/consumed by this package. The ALL_CAPS names
	// are kept for backwards compatibility with existing callers.
	NEW_USER                  = "new-user"
	SEND_CONFIRMATION_REQUEST = "send-confirmation-request"
	FORGOT_PASSWORD_REQUEST   = "forgot-password-request"

	day = 24 * time.Hour
)

var (
	// SessionDuration is the intended lifetime of a session. Expiry is not
	// enforced yet (see validateSession).
	SessionDuration = 7 * day

	// RegisterTimeout bounds how long Register waits for the new-user flow
	// to signal completion on the queue.
	RegisterTimeout = 15 * time.Second

	ErrPassMismatch   = errors.New("cracha: password confirmation mismatch")
	ErrTooShort       = errors.New("cracha: password too short")
	ErrTimeout        = errors.New("cracha: timeout when creating user")
	ErrRegistered     = errors.New("cracha: user already registered")
	ErrBadCombo       = errors.New("cracha: bad username/passphrase combo")
	ErrUnconfirmed    = errors.New("cracha: email is not confirmed")
	ErrRevokedSession = errors.New("cracha: this session was revoked")
	ErrSessionExpired = errors.New("cracha: session expired")
)

// dbconfigT is what every *Stmt constructor needs to prepare its statements.
type dbconfigT struct {
	shared *sql.DB
	dbpath string
	prefix string
}

// queryT holds the SQL text of one logical operation, split by role.
type queryT struct {
	write   string
	read    string
	session string
}

// queriesT bundles the prepared database operations used by authT.
type queriesT struct {
	register   func(uuid.UUID, string, []byte, []byte) (userT, error)
	sendToken  func(uuid.UUID, string) error
	confirm    func(string, uuid.UUID) (sessionT, error)
	byEmail    func(string) (userT, error)
	userByUUID func(uuid.UUID) (userT, error)
	login      func(uuid.UUID, uuid.UUID) (sessionT, error)
	refresh    func(uuid.UUID, uuid.UUID) (sessionT, error)
	reset      func(int64, []byte, uuid.UUID) (sessionT, error)
	change     func(uuid.UUID, []byte) (sessionT, error)
	byUUID     func(uuid.UUID) (sessionT, error)
	logout     func(uuid.UUID) error
	outOthers  func(uuid.UUID) error
	outAll     func(uuid.UUID) error
	close      func() error
}

// userT is the internal representation of a row of "<prefix>_users".
type userT struct {
	id        int64
	timestamp time.Time
	uuid      uuid.UUID
	email     string
	salt      []byte
	pwhash    []byte
	confirmed bool
}

// User is the public view of a registered user.
type User struct {
	UUID      uuid.UUID
	Confirmed bool
}

// sessionT is the internal representation of a row of "<prefix>_sessions".
type sessionT struct {
	id         int64
	timestamp  time.Time
	uuid       uuid.UUID
	userID     uuid.UUID
	revoked_at *time.Time
}

// Session is the public handle of a login session.
type Session struct {
	UUID uuid.UUID
}

// consumerT binds a queue topic to the constructor of its message handler.
type consumerT struct {
	topic     string
	handlerFn func(authT) func(fiinha.Message) error
}

// authT is the IAuth implementation returned by New/NewWithPrefix.
type authT struct {
	queue   fiinha.IQueue
	queries queriesT
	hasher  func(scrypt.HashInputT) ([]byte, error)
	// close unsubscribes the consumers, closes the queue and stops the
	// hasher pool. It is safe to call more than once.
	close func()
}

// IAuth is the public authentication API.
type IAuth interface {
	// Register creates a user with the given e-mail and password.
	Register(email, password, confirmPassword string) (User, error)
	// ResendConfirmation requests a new confirmation message for email.
	ResendConfirmation(email string) error
	// ConfirmEmail consumes a confirmation token and opens a session.
	ConfirmEmail(token string) (Session, error)
	// LoginEmail opens a session for a confirmed user.
	LoginEmail(email, password string) (Session, error)
	// ForgotPassword requests a password-reset message for email.
	ForgotPassword(email string) error
	// Refresh replaces session with a new one for the same user.
	Refresh(session Session) (Session, error)
	// ResetPassword sets a new password using a reset token.
	ResetPassword(token uuid.UUID, password, confirmPassword string) (Session, error)
	// ChangePassword sets a new password after verifying the current one.
	ChangePassword(userID uuid.UUID, currentPassword, newPassword, confirmNewPassword string) (Session, error)
	// Logout ends session.
	Logout(session Session) error
	// LogoutOthers ends the user's other sessions.
	LogoutOthers(session Session) error
	// LogoutAll ends all of the user's sessions.
	LogoutAll(session Session) error
	// Close releases the queue, the hasher pool and the database.
	Close() error
}

var _ IAuth = authT{}

// validateSession is the hook every session produced by a query passes
// through before being returned.
//
// FIXME: implement. It currently returns the session unchanged. The intended
// checks are roughly:
//
//	if session.revoked_at != nil {
//		return ErrRevokedSession
//	}
//	if !time.Now().Before(session.timestamp.Add(SessionDuration)) {
//		return ErrSessionExpired
//	}
func validateSession(session sessionT) (sessionT, error) {
	return session, nil
}

// tryRollback rolls tx back and returns err, or an error wrapping both err
// and the rollback failure when the rollback itself fails.
func tryRollback(tx *sql.Tx, err error) error {
	rollbackErr := tx.Rollback()
	if rollbackErr != nil {
		return fmt.Errorf(rollbackErrorFmt, rollbackErr, err)
	}
	return err
}

// inTx runs fn inside a transaction on db, committing when fn succeeds and
// rolling back when it fails.
func inTx(db *sql.DB, fn func(*sql.Tx) error) error {
	tx, err := db.Begin()
	if err != nil {
		return err
	}

	err = fn(tx)
	if err != nil {
		return tryRollback(tx, err)
	}

	// A Commit call ends the transaction whether or not it succeeds, so a
	// Rollback afterwards would only add a spurious sql.ErrTxDone.
	return tx.Commit()
}

// serialized runs callback on a single background goroutine, so concurrent
// calls to the returned function execute one at a time. The second return
// value stops the goroutine; calling the function after that panics.
func serialized[A any, B any](callback func(...A) B) (func(...A) B, func()) {
	in := make(chan []A)
	out := make(chan B)
	closed := false
	var (
		closeWg    sync.WaitGroup
		closeMutex sync.Mutex
	)

	closeWg.Add(1)
	go func() {
		for input := range in {
			out <- callback(input...)
		}
		close(out)
		closeWg.Done()
	}()

	fn := func(input ...A) B {
		in <- input
		return <-out
	}

	closeFn := func() {
		closeMutex.Lock()
		defer closeMutex.Unlock()
		if closed {
			return
		}
		close(in)
		closed = true
		closeWg.Wait()
	}

	return fn, closeFn
}

// execSerialized returns a function that executes query with the given
// arguments in its own transaction, one call at a time.
func execSerialized(query string, db *sql.DB) (func(...any) error, func()) {
	return serialized(func(args ...any) error {
		return inTx(db, func(tx *sql.Tx) error {
			_, err := tx.Exec(query, args...)
			return err
		})
	})
}

// createTablesSQL returns the DDL for every table of the schema. Note that
// the %s inside SQL comments are also formatting verbs and consume arguments.
func createTablesSQL(prefix string) queryT {
	const tmplWrite = `
		CREATE TABLE IF NOT EXISTS "%s_users" (
			id        INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
			timestamp TEXT    NOT NULL DEFAULT (%s),
			uuid      BLOB    NOT NULL UNIQUE,
			email     TEXT    NOT NULL UNIQUE,
			salt      BLOB    NOT NULL UNIQUE,
			pwhash    BLOB    NOT NULL
		) STRICT;
		CREATE TABLE IF NOT EXISTS "%s_confirmation_attempts" (
			id        INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
			timestamp TEXT    NOT NULL DEFAULT (%s),
			-- uuid BLOB NOT NULL UNIQUE,
			user_id   INTEGER NOT NULL REFERENCES "%s_users"(id),
			token     TEXT    NOT NULL UNIQUE
		) STRICT;
		CREATE TABLE IF NOT EXISTS "%s_user_confirmations" (
			id         INTEGER PRIMARY KEY AUTOINCREMENT,
			timestamp  TEXT    NOT NULL DEFAULT (%s),
			user_id    INTEGER NOT NULL REFERENCES "%s_users"(id) UNIQUE,
			attempt_id INTEGER NOT NULL
				REFERENCES "%s_confirmation_attempts"(id) UNIQUE
		) STRICT;
		CREATE TABLE IF NOT EXISTS "%s_user_changes" (
			id        INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
			timestamp TEXT    NOT NULL DEFAULT (%s),
			user_id   INTEGER NOT NULL REFERENCES "%s_users"(id),
			attribute TEXT    NOT NULL,
			value     TEXT    NOT NULL,
			op        INT     NOT NULL CHECK(op IN (0, 1))
		) STRICT;
		-- CREATE TABLE IF NOT EXISTS "%s_tokens" (
		-- 	id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
		-- 	timestamp TEXT NOT NULL DEFAULT (%s),
		-- 	uuid BLOB NOT NULL UNIQUE,
		-- 	type TEXT NOT NULL
		-- ) STRICT;
		CREATE TABLE IF NOT EXISTS "%s_roles" (
			id      INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
			user_id INTEGER NOT NULL REFERENCES "%s_users"(id),
			role    TEXT    NOT NULL,
			UNIQUE (user_id, role)
		) STRICT;
		CREATE TABLE IF NOT EXISTS "%s_role_changes" (
			id        INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
			timestamp TEXT    NOT NULL DEFAULT (%s),
			user_id   INTEGER NOT NULL REFERENCES "%s_roles"(id),
			role      TEXT    NOT NULL,
			op        INT     NOT NULL CHECK(op IN (0, 1))
		) STRICT;
		CREATE TABLE IF NOT EXISTS "%s_sessions" (
			id        INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
			timestamp TEXT    NOT NULL DEFAULT (%s),
			uuid      BLOB    NOT NULL UNIQUE,
			user_id   INTEGER NOT NULL REFERENCES "%s_users"(id)
			-- type TEXT NOT NULL,
			-- revoked_at TEXT,
			-- revoker_id INTEGER REFERENCES "%s_users"(id),
			-- FIXME: add provenance: login, refresh, confirmation, etc.
		) STRICT;
		CREATE TABLE IF NOT EXISTS "%s_attempts" (
			id         INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
			timestamp  TEXT    NOT NULL DEFAULT (%s),
			user_id    INTEGER REFERENCES "%s_users"(id),
			session_id INTEGER REFERENCES "%s_sessions"(id)
		) STRICT;
		CREATE TABLE IF NOT EXISTS "%s_audit" (
			id        INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
			timestamp TEXT    NOT NULL DEFAULT (%s),
			uuid      BLOB    NOT NULL UNIQUE,
			attribute TEXT    NOT NULL,
			value     TEXT    NOT NULL,
			op        INT     NOT NULL CHECK(op IN (0, 1))
		) STRICT;
	`
	return queryT{
		write: fmt.Sprintf(
			tmplWrite,
			prefix, g.SQLiteNow, // users
			prefix, g.SQLiteNow, prefix, // confirmation_attempts
			prefix, g.SQLiteNow, prefix, prefix, // user_confirmations
			prefix, g.SQLiteNow, prefix, // user_changes
			prefix, g.SQLiteNow, // tokens (commented out, still formatted)
			prefix, prefix, // roles
			prefix, g.SQLiteNow, prefix, // role_changes
			prefix, g.SQLiteNow, prefix, prefix, // sessions (incl. comment)
			prefix, g.SQLiteNow, prefix, prefix, // attempts
			prefix, g.SQLiteNow, // audit
		),
	}
}

// createTables creates the schema for prefix inside a single transaction.
func createTables(db *sql.DB, prefix string) error {
	q := createTablesSQL(prefix)
	return inTx(db, func(tx *sql.Tx) error {
		_, err := tx.Exec(q.write)
		return err
	})
}

func registerSQL(prefix string) queryT {
	const tmplWrite = `
		INSERT INTO "%s_users" (uuid, email, salt, pwhash)
		VALUES (?, ?, ?, ?)
		RETURNING id, timestamp;
	`
	const tmplRead = `
		SELECT id, timestamp from "%s_users"
		WHERE uuid = ?;
	`
	return queryT{
		write: fmt.Sprintf(tmplWrite, prefix),
		read:  fmt.Sprintf(tmplRead, prefix),
	}
}

// registerStmt prepares the user-insertion statement. The returned function
// inserts a user and fills in the generated id and timestamp.
func registerStmt(
	cfg dbconfigT,
) (
	func(uuid.UUID, string, []byte, []byte) (userT, error),
	func() error,
	error,
) {
	q := registerSQL(cfg.prefix)

	writeStmt, err := cfg.shared.Prepare(q.write)
	if err != nil {
		return nil, nil, err
	}

	// readStmt is prepared but not used by fn yet.
	readStmt, err := cfg.shared.Prepare(q.read)
	if err != nil {
		// Don't leak writeStmt when the second Prepare fails.
		return nil, nil, g.WrapErrors(writeStmt.Close(), err)
	}

	fn := func(
		userID uuid.UUID,
		email string,
		salt []byte,
		pwhash []byte,
	) (userT, error) {
		user := userT{
			uuid:      userID,
			email:     email,
			salt:      salt,
			pwhash:    pwhash,
			confirmed: false,
		}

		var timestr string
		err := writeStmt.QueryRow(
			userID[:],
			email,
			salt,
			pwhash,
		).Scan(&user.id, &timestr)
		if err != nil {
			return userT{}, err
		}

		user.timestamp, err = time.Parse(time.RFC3339Nano, timestr)
		if err != nil {
			return userT{}, err
		}

		return user, nil
	}

	closeFn := func() error {
		return g.SomeFnError(writeStmt.Close, readStmt.Close)
	}

	return fn, closeFn, nil
}

func sendTokenSQL(prefix string) queryT {
	const tmplWrite = `
		INSERT INTO "%s_confirmation_attempts" (user_id, token)
		VALUES (
			(SELECT id FROM "%s_users" WHERE uuid = ?),
			?
		)
	`
	return queryT{
		write: fmt.Sprintf(tmplWrite, prefix, prefix),
	}
}

// sendTokenStmt prepares the statement that records a confirmation token for
// the user identified by UUID.
func sendTokenStmt(
	cfg dbconfigT,
) (func(uuid.UUID, string) error, func() error, error) {
	q := sendTokenSQL(cfg.prefix)

	writeStmt, err := cfg.shared.Prepare(q.write)
	if err != nil {
		return nil, nil, err
	}

	fn := func(userID uuid.UUID, token string) error {
		_, err := writeStmt.Exec(userID[:], token)
		return err
	}

	return fn, writeStmt.Close, nil
}

func confirmSQL(prefix string) queryT {
	const tmplWrite = `
		INSERT INTO "%s_user_confirmations" (user_id, attempt_id)
		VALUES (?, ?);
	`
	const tmplRead = `
		SELECT
			"%s_confirmation_attempts".id,
			"%s_confirmation_attempts".user_id,
			"%s_users".uuid
		FROM "%s_confirmation_attempts"
		JOIN "%s_users" ON
			"%s_confirmation_attempts".user_id = "%s_users".id
		WHERE token = ?;
	`
	const tmplSession = `
		INSERT INTO "%s_sessions" (uuid, user_id)
		VALUES (?, ?)
		RETURNING id, timestamp;
	`
	return queryT{
		write: fmt.Sprintf(tmplWrite, prefix),
		read: fmt.Sprintf(
			tmplRead,
			prefix,
			prefix,
			prefix,
			prefix,
			prefix,
			prefix,
			prefix,
		),
		session: fmt.Sprintf(tmplSession, prefix),
	}
}

// confirmStmt prepares the confirmation statements. The returned function
// looks up the attempt for token, records the confirmation and opens a
// session with the given UUID for that user.
func confirmStmt(
	cfg dbconfigT,
) (func(string, uuid.UUID) (sessionT, error), func() error, error) {
	q := confirmSQL(cfg.prefix)

	writeStmt, err := cfg.shared.Prepare(q.write)
	if err != nil {
		return nil, nil, err
	}

	readStmt, err := cfg.shared.Prepare(q.read)
	if err != nil {
		return nil, nil, g.WrapErrors(writeStmt.Close(), err)
	}

	sessionStmt, err := cfg.shared.Prepare(q.session)
	if err != nil {
		return nil, nil, g.WrapErrors(
			writeStmt.Close(),
			readStmt.Close(),
			err,
		)
	}

	fn := func(token string, sessionID uuid.UUID) (sessionT, error) {
		session := sessionT{
			uuid: sessionID,
		}

		var (
			userID      int64
			attemptID   int64
			userIDBytes []byte
		)
		err := readStmt.QueryRow(token).Scan(
			&attemptID,
			&userID,
			&userIDBytes,
		)
		if err != nil {
			return sessionT{}, err
		}
		session.userID = uuid.UUID(userIDBytes)

		_, err = writeStmt.Exec(userID, attemptID)
		if err != nil {
			return sessionT{}, err
		}

		var timestr string
		err = sessionStmt.QueryRow(
			sessionID[:],
			userID,
		).Scan(&session.id, &timestr)
		if err != nil {
			return sessionT{}, err
		}

		session.timestamp, err = time.Parse(time.RFC3339Nano, timestr)
		if err != nil {
			return sessionT{}, err
		}

		return validateSession(session)
	}

	closeFn := func() error {
		return g.SomeFnError(
			writeStmt.Close,
			readStmt.Close,
			sessionStmt.Close,
		)
	}

	return fn, closeFn, nil
}

func byEmailSQL(prefix string) queryT {
	const tmplRead = `
		SELECT
			id,
			timestamp,
			uuid,
			salt,
			pwhash,
			(
				CASE WHEN EXISTS (
					SELECT id FROM "%s_user_confirmations"
					WHERE user_id = (
						SELECT id FROM "%s_users" WHERE email = ?
					)
				) THEN 1
				ELSE 0
				END
			) as confirmed
		FROM "%s_users"
		WHERE email = ?;
	`
	return queryT{
		read: fmt.Sprintf(tmplRead, prefix, prefix, prefix),
	}
}

// byEmailStmt prepares the lookup of a user by e-mail, including whether the
// user has a confirmation row. A missing user yields sql.ErrNoRows.
func byEmailStmt(
	cfg dbconfigT,
) (func(string) (userT, error), func() error, error) {
	q := byEmailSQL(cfg.prefix)

	readStmt, err := cfg.shared.Prepare(q.read)
	if err != nil {
		return nil, nil, err
	}

	fn := func(email string) (userT, error) {
		user := userT{
			email: email,
		}

		var (
			timestr     string
			userIDBytes []byte
		)
		err := readStmt.QueryRow(email, email).Scan(
			&user.id,
			&timestr,
			&userIDBytes,
			&user.salt,
			&user.pwhash,
			&user.confirmed,
		)
		if err != nil {
			return userT{}, err
		}
		user.uuid = uuid.UUID(userIDBytes)

		user.timestamp, err = time.Parse(time.RFC3339Nano, timestr)
		if err != nil {
			return userT{}, err
		}

		return user, nil
	}

	return fn, readStmt.Close, nil
}

func userByUUIDSQL(prefix string) queryT {
	const tmplRead = `
		SELECT
			id,
			timestamp,
			email,
			salt,
			pwhash,
			(
				CASE WHEN EXISTS (
					SELECT id FROM "%s_user_confirmations"
					WHERE user_id = (
						SELECT id FROM "%s_users" WHERE uuid = ?
					)
				) THEN 1
				ELSE 0
				END
			) as confirmed
		FROM "%s_users"
		WHERE uuid = ?;
	`
	return queryT{
		read: fmt.Sprintf(tmplRead, prefix, prefix, prefix),
	}
}

// userByUUIDStmt prepares the lookup of a user by UUID, including whether
// the user has a confirmation row.
func userByUUIDStmt(
	cfg dbconfigT,
) (func(uuid.UUID) (userT, error), func() error, error) {
	q := userByUUIDSQL(cfg.prefix)

	readStmt, err := cfg.shared.Prepare(q.read)
	if err != nil {
		return nil, nil, err
	}

	fn := func(userID uuid.UUID) (userT, error) {
		user := userT{
			uuid: userID,
		}

		var timestr string
		err := readStmt.QueryRow(userID[:], userID[:]).Scan(
			&user.id,
			&timestr,
			&user.email,
			&user.salt,
			&user.pwhash,
			&user.confirmed,
		)
		if err != nil {
			return userT{}, err
		}

		user.timestamp, err = time.Parse(time.RFC3339Nano, timestr)
		if err != nil {
			return userT{}, err
		}

		return user, nil
	}

	return fn, readStmt.Close, nil
}

func loginSQL(prefix string) queryT {
	const tmplSession = `
		INSERT INTO "%s_sessions" (uuid, user_id)
		VALUES (
			?,
			(SELECT id FROM "%s_users" WHERE uuid = ?)
		)
		RETURNING id, timestamp;
	`
	return queryT{
		session: fmt.Sprintf(tmplSession, prefix, prefix),
	}
}

// loginStmt prepares the insertion of a new session for a user UUID.
func loginStmt(
	cfg dbconfigT,
) (func(uuid.UUID, uuid.UUID) (sessionT, error), func() error, error) {
	q := loginSQL(cfg.prefix)

	sessionStmt, err := cfg.shared.Prepare(q.session)
	if err != nil {
		return nil, nil, err
	}

	fn := func(userID uuid.UUID, sessionID uuid.UUID) (sessionT, error) {
		session := sessionT{
			uuid:   sessionID,
			userID: userID,
		}

		var timestr string
		err := sessionStmt.QueryRow(
			sessionID[:],
			userID[:],
		).Scan(
			&session.id,
			&timestr,
		)
		if err != nil {
			return sessionT{}, err
		}

		session.timestamp, err = time.Parse(time.RFC3339Nano, timestr)
		if err != nil {
			return sessionT{}, err
		}

		return validateSession(session)
	}

	return fn, sessionStmt.Close, nil
}

func refreshSQL(prefix string) queryT {
	const tmplWrite = `
		INSERT INTO "%s_sessions" (uuid, user_id)
		VALUES (
			?,
			(SELECT user_id FROM "%s_sessions" WHERE uuid = ?)
		)
		RETURNING id, timestamp, (
			SELECT "%s_users".uuid FROM "%s_users"
			JOIN "%s_sessions" ON
				"%s_users".id = "%s_sessions".user_id
			WHERE "%s_sessions".uuid = ?
		) AS userID;
	`
	return queryT{
		write: fmt.Sprintf(
			tmplWrite,
			prefix,
			prefix,
			prefix,
			prefix,
			prefix,
			prefix,
			prefix,
			prefix,
		),
	}
}

// refreshStmt prepares the insertion of a new session for the owner of an
// existing session.
func refreshStmt(
	cfg dbconfigT,
) (func(uuid.UUID, uuid.UUID) (sessionT, error), func() error, error) {
	q := refreshSQL(cfg.prefix)

	writeStmt, err := cfg.shared.Prepare(q.write)
	if err != nil {
		return nil, nil, err
	}

	fn := func(
		sessionID uuid.UUID,
		newSessionID uuid.UUID,
	) (sessionT, error) {
		session := sessionT{
			uuid: newSessionID,
		}

		var (
			timestr     string
			userIDBytes []byte
		)
		err := writeStmt.QueryRow(
			newSessionID[:],
			sessionID[:],
			sessionID[:],
		).Scan(&session.id, &timestr, &userIDBytes)
		if err != nil {
			return sessionT{}, err
		}
		session.userID = uuid.UUID(userIDBytes)

		session.timestamp, err = time.Parse(time.RFC3339Nano, timestr)
		if err != nil {
			return sessionT{}, err
		}

		return validateSession(session)
	}

	return fn, writeStmt.Close, nil
}

// FIXME: the SQL for reset/change/byUUID/logout/outOthers/outAll is still a
// placeholder, and their Scan(&session) calls cannot fill a struct.

func resetSQL(prefix string) queryT {
	const tmplWrite = `
		-- INSERT SOMETHING %s
	`
	return queryT{
		write: fmt.Sprintf(tmplWrite, prefix),
	}
}

func resetStmt(
	cfg dbconfigT,
) (func(int64, []byte, uuid.UUID) (sessionT, error), func() error, error) {
	q := resetSQL(cfg.prefix)

	writeStmt, err := cfg.shared.Prepare(q.write)
	if err != nil {
		return nil, nil, err
	}

	fn := func(id int64, pwhash []byte, token uuid.UUID) (sessionT, error) {
		var session sessionT
		err := writeStmt.QueryRow(id, pwhash, token[:]).Scan(&session)
		if err != nil {
			return sessionT{}, err
		}
		return validateSession(session)
	}

	return fn, writeStmt.Close, nil
}

func changeSQL(prefix string) queryT {
	const tmplWrite = `
		-- INSERT SOMETHING %s
	`
	return queryT{
		write: fmt.Sprintf(tmplWrite, prefix),
	}
}

func changeStmt(
	cfg dbconfigT,
) (func(uuid.UUID, []byte) (sessionT, error), func() error, error) {
	q := changeSQL(cfg.prefix)

	writeStmt, err := cfg.shared.Prepare(q.write)
	if err != nil {
		return nil, nil, err
	}

	fn := func(userID uuid.UUID, pwhash []byte) (sessionT, error) {
		var session sessionT
		err := writeStmt.QueryRow(userID[:], pwhash).Scan(&session)
		if err != nil {
			return sessionT{}, err
		}
		return validateSession(session)
	}

	return fn, writeStmt.Close, nil
}

func byUUIDSQL(prefix string) queryT {
	const tmplRead = `
		-- INSERT SOMETHING %s
	`
	return queryT{
		read: fmt.Sprintf(tmplRead, prefix),
	}
}

func byUUIDStmt(
	cfg dbconfigT,
) (func(uuid.UUID) (sessionT, error), func() error, error) {
	q := byUUIDSQL(cfg.prefix)

	readStmt, err := cfg.shared.Prepare(q.read)
	if err != nil {
		return nil, nil, err
	}

	fn := func(sessionID uuid.UUID) (sessionT, error) {
		var session sessionT
		err := readStmt.QueryRow(sessionID[:]).Scan(&session)
		if err != nil {
			return sessionT{}, err
		}
		return validateSession(session)
	}

	return fn, readStmt.Close, nil
}

func logoutSQL(prefix string) queryT {
	const tmplWrite = `
		-- INSERT SOMETHING %s
	`
	return queryT{
		write: fmt.Sprintf(tmplWrite, prefix),
	}
}

func logoutStmt(
	cfg dbconfigT,
) (func(uuid.UUID) error, func() error, error) {
	q := logoutSQL(cfg.prefix)

	writeStmt, err := cfg.shared.Prepare(q.write)
	if err != nil {
		return nil, nil, err
	}

	fn := func(sessionID uuid.UUID) error {
		_, err := writeStmt.Exec(sessionID[:])
		return err
	}

	return fn, writeStmt.Close, nil
}

func outOthersSQL(prefix string) queryT {
	const tmplWrite = `
		-- INSERT SOMETHING %s
	`
	return queryT{
		write: fmt.Sprintf(tmplWrite, prefix),
	}
}

func outOthersStmt(
	cfg dbconfigT,
) (func(uuid.UUID) error, func() error, error) {
	q := outOthersSQL(cfg.prefix)

	writeStmt, err := cfg.shared.Prepare(q.write)
	if err != nil {
		return nil, nil, err
	}

	fn := func(sessionID uuid.UUID) error {
		_, err := writeStmt.Exec(sessionID[:])
		return err
	}

	return fn, writeStmt.Close, nil
}

func outAllSQL(prefix string) queryT {
	const tmplWrite = `
		-- INSERT SOMETHING %s
	`
	return queryT{
		write: fmt.Sprintf(tmplWrite, prefix),
	}
}

func outAllStmt(
	cfg dbconfigT,
) (func(uuid.UUID) error, func() error, error) {
	q := outAllSQL(cfg.prefix)

	writeStmt, err := cfg.shared.Prepare(q.write)
	if err != nil {
		return nil, nil, err
	}

	fn := func(sessionID uuid.UUID) error {
		_, err := writeStmt.Exec(sessionID[:])
		return err
	}

	return fn, writeStmt.Close, nil
}

// closeAll calls every non-nil closer and then closes db, joining all the
// resulting errors. It returns nil when everything closed cleanly.
func closeAll(db *sql.DB, closers ...func() error) error {
	var errs []error
	for _, closeFn := range closers {
		if closeFn != nil {
			errs = append(errs, closeFn())
		}
	}
	errs = append(errs, db.Close())
	return errors.Join(errs...)
}

// initDB opens the database at dbpath, creates the schema for prefix and
// prepares every statement. On failure, everything acquired so far is
// released. The returned close releases the statements and the database.
func initDB(
	dbpath string,
	prefix string,
) (queriesT, error) {
	err := g.ValidateSQLTablePrefix(prefix)
	if err != nil {
		return queriesT{}, err
	}

	shared, err := sql.Open(golite.DriverName, dbpath)
	if err != nil {
		return queriesT{}, err
	}

	// Statements referencing missing tables cannot be prepared, so stop
	// here if the schema could not be created.
	err = createTables(shared, prefix)
	if err != nil {
		return queriesT{}, g.WrapErrors(shared.Close(), err)
	}

	cfg := dbconfigT{
		shared: shared,
		dbpath: dbpath,
		prefix: prefix,
	}

	register, registerClose, registerErr := registerStmt(cfg)
	sendToken, sendTokenClose, sendTokenErr := sendTokenStmt(cfg)
	confirm, confirmClose, confirmErr := confirmStmt(cfg)
	byEmail, byEmailClose, byEmailErr := byEmailStmt(cfg)
	userByUUID, userByUUIDClose, userByUUIDErr := userByUUIDStmt(cfg)
	login, loginClose, loginErr := loginStmt(cfg)
	refresh, refreshClose, refreshErr := refreshStmt(cfg)
	reset, resetClose, resetErr := resetStmt(cfg)
	change, changeClose, changeErr := changeStmt(cfg)
	byUUID, byUUIDClose, byUUIDErr := byUUIDStmt(cfg)
	logout, logoutClose, logoutErr := logoutStmt(cfg)
	outOthers, outOthersClose, outOthersErr := outOthersStmt(cfg)
	outAll, outAllClose, outAllErr := outAllStmt(cfg)

	// One closer per statement constructor; the ones that failed are nil
	// and are skipped by closeAll.
	closers := []func() error{
		registerClose,
		sendTokenClose,
		confirmClose,
		byEmailClose,
		userByUUIDClose,
		loginClose,
		refreshClose,
		resetClose,
		changeClose,
		byUUIDClose,
		logoutClose,
		outOthersClose,
		outAllClose,
	}

	err = g.SomeError(
		registerErr,
		sendTokenErr,
		confirmErr,
		byEmailErr,
		userByUUIDErr,
		loginErr,
		refreshErr,
		resetErr,
		changeErr,
		byUUIDErr,
		logoutErr,
		outOthersErr,
		outAllErr,
	)
	if err != nil {
		return queriesT{}, g.WrapErrors(closeAll(shared, closers...), err)
	}

	closeQueries := func() error {
		return closeAll(shared, closers...)
	}

	// Queries share a read lock; close takes the write lock so it waits for
	// in-flight queries to finish.
	var connMutex sync.RWMutex
	return queriesT{
		register: func(
			a uuid.UUID,
			b string,
			c []byte,
			d []byte,
		) (userT, error) {
			connMutex.RLock()
			defer connMutex.RUnlock()
			return register(a, b, c, d)
		},
		sendToken: func(a uuid.UUID, b string) error {
			connMutex.RLock()
			defer connMutex.RUnlock()
			return sendToken(a, b)
		},
		confirm: func(a string, b uuid.UUID) (sessionT, error) {
			connMutex.RLock()
			defer connMutex.RUnlock()
			return confirm(a, b)
		},
		byEmail: func(a string) (userT, error) {
			connMutex.RLock()
			defer connMutex.RUnlock()
			return byEmail(a)
		},
		userByUUID: func(a uuid.UUID) (userT, error) {
			connMutex.RLock()
			defer connMutex.RUnlock()
			return userByUUID(a)
		},
		login: func(a uuid.UUID, b uuid.UUID) (sessionT, error) {
			connMutex.RLock()
			defer connMutex.RUnlock()
			return login(a, b)
		},
		refresh: func(a uuid.UUID, b uuid.UUID) (sessionT, error) {
			connMutex.RLock()
			defer connMutex.RUnlock()
			return refresh(a, b)
		},
		reset: func(a int64, b []byte, c uuid.UUID) (sessionT, error) {
			connMutex.RLock()
			defer connMutex.RUnlock()
			return reset(a, b, c)
		},
		change: func(a uuid.UUID, b []byte) (sessionT, error) {
			connMutex.RLock()
			defer connMutex.RUnlock()
			return change(a, b)
		},
		byUUID: func(a uuid.UUID) (sessionT, error) {
			connMutex.RLock()
			defer connMutex.RUnlock()
			return byUUID(a)
		},
		logout: func(a uuid.UUID) error {
			connMutex.RLock()
			defer connMutex.RUnlock()
			return logout(a)
		},
		outOthers: func(a uuid.UUID) error {
			connMutex.RLock()
			defer connMutex.RUnlock()
			return outOthers(a)
		},
		outAll: func(a uuid.UUID) error {
			connMutex.RLock()
			defer connMutex.RUnlock()
			return outAll(a)
		},
		close: func() error {
			connMutex.Lock()
			defer connMutex.Unlock()
			return closeQueries()
		},
	}, nil
}

// FIXME: the queue handlers below are stubs that acknowledge every message.

func newUserHandler(auth authT) func(fiinha.Message) error {
	return func(message fiinha.Message) error {
		return nil
	}
}

func sendConfirmationRequestHandler(auth authT) func(fiinha.Message) error {
	return func(message fiinha.Message) error {
		return nil
	}
}

func forgotPasswordRequestHandler(auth authT) func(fiinha.Message) error {
	return func(message fiinha.Message) error {
		return nil
	}
}

// consumers lists the queue topics this package subscribes to.
var consumers = []consumerT{
	{
		topic:     NEW_USER,
		handlerFn: newUserHandler,
	},
	{
		topic:     SEND_CONFIRMATION_REQUEST,
		handlerFn: sendConfirmationRequestHandler,
	},
	{
		topic:     FORGOT_PASSWORD_REQUEST,
		handlerFn: forgotPasswordRequestHandler,
	},
}

// registerConsumers subscribes every consumer under the name
// "<prefix>-<topic>", stopping at the first failure.
func registerConsumers(auth authT, consumers []consumerT, prefix string) error {
	for _, consumer := range consumers {
		err := auth.queue.Subscribe(
			consumer.topic,
			prefix+"-"+consumer.topic,
			consumer.handlerFn(auth),
		)
		if err != nil {
			return err
		}
	}
	return nil
}

// unregisterConsumers undoes registerConsumers.
func unregisterConsumers(
	queue fiinha.IQueue,
	consumers []consumerT,
	prefix string,
) {
	for _, consumer := range consumers {
		queue.Unsubscribe(consumer.topic, prefix+"-"+consumer.topic)
	}
}

type resultT[T any] struct {
	value T
	err   error
}

type taggedT[T any] struct {
	id    int
	value T
}

// startRunner applies fn to every value received on in, forwarding each
// result with the same id to out, and calls done once in is closed.
func startRunner[A any, B any](
	in <-chan taggedT[A],
	out chan<- taggedT[B],
	fn func(A) B,
	done func(),
) {
	for input := range in {
		out <- taggedT[B]{
			id:    input.id,
			value: fn(input.value),
		}
	}
	done()
}

// makePoolRunner runs fn on a pool of count worker goroutines (at least
// one). The returned function submits one input and blocks until its own
// result is ready; the second return value stops the pool and is safe to
// call more than once. Calling the runner after stopping the pool panics.
func makePoolRunner[A any, B any](
	count int,
	fn func(A) B,
) (func(A) B, func()) {
	count = max(count, 1)

	var wg sync.WaitGroup
	wg.Add(count)
	in := make(chan taggedT[A])
	out := make(chan taggedT[B])
	for range count {
		go startRunner(in, out, fn, wg.Done)
	}

	var (
		mutex   sync.Mutex
		pending = map[int]chan B{}
		nextID  int
	)

	// Route each result to the caller waiting on its id. The mutex only
	// guards the map and is never held across a channel operation.
	go func() {
		for output := range out {
			mutex.Lock()
			reply := pending[output.id]
			delete(pending, output.id)
			mutex.Unlock()
			reply <- output.value
		}
	}()

	run := func(input A) B {
		// Buffered so the router never blocks on a caller.
		reply := make(chan B, 1)

		// Register the reply channel before submitting, so the result
		// always finds it, and tag the job with that same id.
		mutex.Lock()
		id := nextID
		nextID++
		pending[id] = reply
		mutex.Unlock()

		in <- taggedT[A]{id: id, value: input}
		return <-reply
	}

	var closeOnce sync.Once
	closeFn := func() {
		closeOnce.Do(func() {
			close(in)
			wg.Wait()
			close(out)
		})
	}

	return run, closeFn
}

// asResult adapts a (value, error) function to one returning resultT.
func asResult[A any, B any](fn func(A) (B, error)) func(A) resultT[B] {
	return func(input A) resultT[B] {
		output, err := fn(input)
		return resultT[B]{
			value: output,
			err:   err,
		}
	}
}

// unwrapResult is the inverse of asResult.
func unwrapResult[A any, B any](fn func(A) resultT[B]) func(A) (B, error) {
	return func(input A) (B, error) {
		result := fn(input)
		return result.value, result.err
	}
}

// NewWithPrefix opens the queue and database at databasePath, using prefix
// for table and consumer names. On error the returned IAuth is nil and
// everything acquired so far has been released.
func NewWithPrefix(databasePath string, prefix string) (IAuth, error) {
	queue, err := fiinha.New(databasePath)
	if err != nil {
		return nil, err
	}

	queries, err := initDB(databasePath, prefix)
	if err != nil {
		queue.Close()
		return nil, err
	}

	hasher, closeHasher := makePoolRunner(
		runtime.NumCPU(),
		asResult(scrypt.Hash),
	)

	var closeOnce sync.Once
	closeFn := func() {
		closeOnce.Do(func() {
			// Stop deliveries before tearing the queue down, then stop
			// the hasher pool that handlers may depend on.
			unregisterConsumers(queue, consumers, prefix)
			queue.Close()
			closeHasher()
		})
	}

	auth := authT{
		queue:   queue,
		queries: queries,
		hasher:  unwrapResult(hasher),
		close:   closeFn,
	}

	err = registerConsumers(auth, consumers, prefix)
	if err != nil {
		closeFn()
		return nil, g.WrapErrors(queries.close(), err)
	}

	return auth, nil
}

// New is NewWithPrefix with the default "cracha" prefix.
func New(databasePath string) (IAuth, error) {
	return NewWithPrefix(databasePath, defaultPrefix)
}

// newUserPayload encodes the NEW_USER message body as JSON, with salt and
// pwhash hex-encoded.
func newUserPayload(email string, salt []byte, pwhash []byte) ([]byte, error) {
	data := map[string]any{
		"email":  email,
		"salt":   hex.EncodeToString(salt),
		"pwhash": hex.EncodeToString(pwhash),
	}
	return json.Marshal(data)
}

// asPublicUser converts the internal user row to its public view.
func asPublicUser(user userT) (User, error) {
	return User{
		UUID:      user.uuid,
		Confirmed: user.confirmed,
	}, nil
}

// asPublicSession converts the internal session row to its public handle.
// Sessions reaching here have already gone through validateSession.
func asPublicSession(session sessionT) (Session, error) {
	return Session{UUID: session.uuid}, nil
}

// checkPassword hashes password with salt and reports whether the result
// equals pwhash, comparing in constant time. An empty pwhash (e.g. for an
// unknown user) never matches, but the hash is still computed.
func (auth authT) checkPassword(
	password string,
	salt []byte,
	pwhash []byte,
) (bool, error) {
	hash, err := auth.hasher(scrypt.HashInputT{
		Password: []byte(password),
		Salt:     salt,
	})
	if err != nil {
		return false, err
	}
	if len(pwhash) == 0 {
		return false, nil
	}
	return subtle.ConstantTimeCompare(hash, pwhash) == 1, nil
}

// Register validates the password, hashes it, publishes a NEW_USER message
// and waits up to RegisterTimeout for the flow to complete before reading
// the user back by e-mail.
func (auth authT) Register(
	email string,
	password string,
	confirmPassword string,
) (User, error) {
	if password != confirmPassword {
		return User{}, ErrPassMismatch
	}
	if len(password) < scrypt.MinimumPasswordLength {
		return User{}, ErrTooShort
	}

	// Whether or not the e-mail is already taken, registration proceeds
	// the same way; only unexpected lookup errors abort. This is meant to
	// avoid revealing existing accounts.
	// FIXME: confirm this actually defends against enumeration.
	_, lookupErr := auth.queries.byEmail(email)
	if lookupErr != nil && !errors.Is(lookupErr, sql.ErrNoRows) {
		return User{}, lookupErr
	}

	salt, err := scrypt.Salt()
	if err != nil {
		return User{}, err
	}

	hash, err := auth.hasher(scrypt.HashInputT{
		Password: []byte(password),
		Salt:     salt,
	})
	if err != nil {
		return User{}, err
	}

	// The request is published even for an existing e-mail; the original
	// intent was to avoid disk-IO timing differences.
	// FIXME: confirm how this helps.
	flowID := uuid.New()
	waiter := auth.queue.WaitFor(NEW_USER, flowID, "register")
	defer waiter.Close()

	payload, err := newUserPayload(email, salt, hash)
	if err != nil {
		return User{}, err
	}

	unsent := fiinha.UnsentMessage{
		Topic:   NEW_USER,
		FlowID:  flowID,
		Payload: payload,
	}
	_, err = auth.queue.Publish(unsent)
	if err != nil {
		return User{}, err
	}

	select {
	case <-time.After(RegisterTimeout):
		return User{}, ErrTimeout
	case <-waiter.Channel:
	}

	user, err := auth.queries.byEmail(email)
	if err != nil {
		return User{}, err
	}
	return asPublicUser(user)
}

// sendConfirmationMessage builds a SEND_CONFIRMATION_REQUEST message for email.
func sendConfirmationMessage(
	email string,
	flowID uuid.UUID,
) (fiinha.UnsentMessage, error) {
	payload, err := json.Marshal(map[string]any{"email": email})
	if err != nil {
		return fiinha.UnsentMessage{}, err
	}

	return fiinha.UnsentMessage{
		Topic:   SEND_CONFIRMATION_REQUEST,
		FlowID:  flowID,
		Payload: payload,
	}, nil
}

// ResendConfirmation publishes a new confirmation request for email.
func (auth authT) ResendConfirmation(email string) error {
	unsent, err := sendConfirmationMessage(email, uuid.New())
	if err != nil {
		return err
	}

	_, err = auth.queue.Publish(unsent)
	return err
}

// ConfirmEmail marks the user owning token as confirmed and opens a new
// session for them.
func (auth authT) ConfirmEmail(token string) (Session, error) {
	session, err := auth.queries.confirm(token, uuid.New())
	if err != nil {
		return Session{}, err
	}
	return asPublicSession(session)
}

// LoginEmail opens a session for the user with the given e-mail, returning
// ErrBadCombo for an unknown e-mail or wrong password and ErrUnconfirmed
// for an unconfirmed user.
func (auth authT) LoginEmail(
	email string,
	password string,
) (Session, error) {
	if len(password) < scrypt.MinimumPasswordLength {
		return Session{}, ErrTooShort
	}

	// An unknown e-mail is not reported as such: it falls through to the
	// same hashing work and the same ErrBadCombo as a wrong password.
	user, err := auth.queries.byEmail(email)
	found := err == nil
	if err != nil && !errors.Is(err, sql.ErrNoRows) {
		return Session{}, err
	}

	ok, err := auth.checkPassword(password, user.salt, user.pwhash)
	if err != nil {
		return Session{}, err
	}
	if !found || !ok {
		return Session{}, ErrBadCombo
	}

	if !user.confirmed {
		return Session{}, ErrUnconfirmed
	}

	session, err := auth.queries.login(user.uuid, uuid.New())
	if err != nil {
		return Session{}, err
	}
	return asPublicSession(session)
}

// forgotPasswordMessage builds a FORGOT_PASSWORD_REQUEST message for email.
func forgotPasswordMessage(
	email string,
	flowID uuid.UUID,
) (fiinha.UnsentMessage, error) {
	payload, err := json.Marshal(map[string]any{"email": email})
	if err != nil {
		return fiinha.UnsentMessage{}, err
	}

	return fiinha.UnsentMessage{
		Topic:   FORGOT_PASSWORD_REQUEST,
		FlowID:  flowID,
		Payload: payload,
	}, nil
}

// ForgotPassword publishes a password-reset request. For an unknown e-mail
// it still publishes a message (with an empty e-mail), so the caller cannot
// tell whether the account exists.
func (auth authT) ForgotPassword(email string) error {
	user, err := auth.queries.byEmail(email)
	if err != nil && !errors.Is(err, sql.ErrNoRows) {
		return err
	}

	unsent, err := forgotPasswordMessage(user.email, uuid.New())
	if err != nil {
		return err
	}

	_, err = auth.queue.Publish(unsent)
	return err
}

// Refresh creates a new session for the owner of session.
//
// FIXME: the old session is neither validated nor revoked yet.
func (auth authT) Refresh(session Session) (Session, error) {
	newSession, err := auth.queries.refresh(session.UUID, uuid.New())
	if err != nil {
		return Session{}, err
	}
	return asPublicSession(newSession)
}

// ResetPassword sets a new password from a reset token.
//
// FIXME: work in progress. The user is never looked up from token (user is
// always the zero value, so the unconfirmed branch is always taken), and the
// confirm call does not use token yet.
func (auth authT) ResetPassword(
	token uuid.UUID,
	password string,
	confirmPassword string,
) (Session, error) {
	if password != confirmPassword {
		return Session{}, ErrPassMismatch
	}
	if len(password) < scrypt.MinimumPasswordLength {
		return Session{}, ErrTooShort
	}

	user := userT{}
	pwhash, err := auth.hasher(scrypt.HashInputT{
		Password: []byte(password),
		Salt:     user.salt,
	})
	if err != nil {
		return Session{}, err
	}

	var nextFn func() (sessionT, error)
	if user.confirmed {
		nextFn = func() (sessionT, error) {
			return auth.queries.reset(user.id, pwhash, token)
		}
	} else {
		nextFn = func() (sessionT, error) {
			return auth.queries.confirm("token FIXME", uuid.New())
		}
	}

	session, err := nextFn()
	if err != nil {
		return Session{}, err
	}
	return asPublicSession(session)
}

// ChangePassword replaces the password of a confirmed user after verifying
// currentPassword, returning ErrBadCombo when it does not match.
func (auth authT) ChangePassword(
	userID uuid.UUID,
	currentPassword string,
	newPassword string,
	confirmNewPassword string,
) (Session, error) {
	if newPassword != confirmNewPassword {
		return Session{}, ErrPassMismatch
	}
	if len(newPassword) < scrypt.MinimumPasswordLength {
		return Session{}, ErrTooShort
	}

	user, err := auth.queries.userByUUID(userID)
	if err != nil {
		return Session{}, err
	}

	ok, err := auth.checkPassword(currentPassword, user.salt, user.pwhash)
	if err != nil {
		return Session{}, err
	}
	if !ok {
		return Session{}, ErrBadCombo
	}

	if !user.confirmed {
		return Session{}, ErrUnconfirmed
	}

	pwhash, err := auth.hasher(scrypt.HashInputT{
		Password: []byte(newPassword),
		Salt:     user.salt,
	})
	if err != nil {
		return Session{}, err
	}

	session, err := auth.queries.change(user.uuid, pwhash)
	if err != nil {
		return Session{}, err
	}
	return asPublicSession(session)
}

// runLogout applies queryFn to sessionID.
//
// FIXME: lookupFn is meant to validate the session first; it is unused.
func runLogout(
	lookupFn func(uuid.UUID) (sessionT, error),
	sessionID uuid.UUID,
	queryFn func(uuid.UUID) error,
) error {
	return queryFn(sessionID)
}

// Logout ends session.
func (auth authT) Logout(session Session) error {
	return runLogout(
		auth.queries.byUUID,
		session.UUID,
		auth.queries.logout,
	)
}

// LogoutOthers ends the other sessions of session's user.
func (auth authT) LogoutOthers(session Session) error {
	return runLogout(
		auth.queries.byUUID,
		session.UUID,
		auth.queries.outOthers,
	)
}

// LogoutAll ends every session of session's user.
func (auth authT) LogoutAll(session Session) error {
	return runLogout(
		auth.queries.byUUID,
		session.UUID,
		auth.queries.outAll,
	)
}

// Close stops the queue consumers, the queue and the hasher pool, then
// closes the prepared statements and the database.
func (auth authT) Close() error {
	auth.close()
	return auth.queries.close()
}

// Main is a placeholder entry point.
//
// FIXME: implement.
func Main() {
}