// Package gracha implements email/passphrase registration, email
// confirmation and session handling on top of SQLite tables (every table
// name carries a configurable prefix) and a q.IQueue message queue.
package gracha

import (
	"context"
	"database/sql"
	"encoding/hex"
	"encoding/json"
	"errors"
	"fmt"
	"runtime"
	"slices"
	"sync"
	"time"

	"guuid"
	"q"
	"scrypt"
	g "gobang"
)

const (
	// defaultPrefix is the table/consumer-name prefix used by New.
	defaultPrefix = "gracha"
	// rollbackErrorFmt reports both the ROLLBACK failure and the error that
	// caused the rollback.
	rollbackErrorFmt = "rollback error: %w; while executing: %w"

	// NEW_USER is the queue topic Register publishes to.
	NEW_USER = "new-user"
	// SEND_CONFIRMATION_REQUEST is the queue topic ResendConfirmation
	// publishes to.
	SEND_CONFIRMATION_REQUEST = "send-confirmation-request"
	// FORGOT_PASSWORD_REQUEST is the queue topic ForgotPassword publishes to.
	FORGOT_PASSWORD_REQUEST = "forgot-password-request"

	day = 24 * time.Hour
)

var (
	// SessionDuration is how long a session is accepted after its creation
	// timestamp.
	SessionDuration = 7 * day
	// RegisterTimeout bounds how long Register waits for the NEW_USER flow
	// to complete.
	RegisterTimeout = 15 * time.Second

	// Errors returned by the IAuth methods.
	ErrPassMismatch   = errors.New("gracha: password confirmation mismatch")
	ErrTooShort       = errors.New("gracha: password too short")
	ErrTimeout        = errors.New("gracha: timeout when creating user")
	ErrRegistered     = errors.New("gracha: user already registered")
	ErrBadCombo       = errors.New("gracha: bad username/passphrase combo")
	ErrUnconfirmed    = errors.New("gracha: email is not confirmed")
	ErrRevokedSession = errors.New("gracha: this session was revoked")
	ErrSessionExpired = errors.New("gracha: session expired")
)

// queryT holds the SQL text of one logical query; each query only sets the
// fields it needs.
type queryT struct {
	write   string
	read    string
	session string
}

// queriesT bundles the prepared-statement wrappers built by initDB; close
// releases all of them.
type queriesT struct {
	register  func(guuid.UUID, string, []byte, []byte) (userT, error)
	sendToken func(guuid.UUID, string) error
	confirm   func(string, guuid.UUID) (sessionT, error)
	byEmail   func(string) (userT, error)
	login     func(guuid.UUID, guuid.UUID) (sessionT, error)
	refresh   func(guuid.UUID, guuid.UUID) (sessionT, error)
	reset     func(int64, []byte, guuid.UUID) (sessionT, error)
	change    func(int64, []byte) (sessionT, error)
	byUUID    func(guuid.UUID) (sessionT, error)
	logout    func(guuid.UUID) error
	outOthers func(guuid.UUID) error
	outAll    func(guuid.UUID) error
	close     func() error
}

// userT mirrors a "<prefix>_users" row, plus whether a matching
// "<prefix>_user_confirmations" row exists.
type userT struct {
	id        int64
	timestamp time.Time
	uuid      guuid.UUID
	email     string
	salt      []byte
	pwhash    []byte
	confirmed bool
}

// sessionT mirrors a "<prefix>_sessions" row; userID is the owning user's
// UUID rather than the integer foreign key.
type sessionT struct {
	id        int64
	timestamp time.Time
	uuid      guuid.UUID
	userID    guuid.UUID
	// type_ string
	// NOTE(review): the sessions table has no revoked_at column yet (it is
	// commented out in createTablesSQL), so no query here fills this field.
	revoked_at *time.Time
}

// consumerT pairs a queue topic with a factory for its message handler.
type consumerT struct {
	topic     string
	handlerFn func(authT) func(q.Message) error
}

// authT is the IAuth implementation returned by New/NewWithPrefix.
type authT struct {
	queries queriesT
	queue   q.IQueue
	hasher  func(scrypt.HashInput) ([]byte, error)
	close   func()
}

// IAuth is the public authentication API.
type IAuth interface {
	Register(string, string, string) (userT, error)
	ResendConfirmation(string) error
	ConfirmEmail(string) (sessionT, error)
	LoginEmail(string, string) (sessionT, error)
	ForgotPassword(string) error
	Refresh(sessionT) (sessionT, error)
	ResetPassword(guuid.UUID, string, string) (sessionT, error)
	ChangePassword(userT, string, string, string) (sessionT, error)
	Logout(sessionT) error
	LogoutOthers(sessionT) error
	LogoutAll(sessionT) error
	Close() error
}

// tryRollback issues ROLLBACK and returns err. If the rollback itself fails,
// both errors are wrapped together.
func tryRollback(db *sql.DB, ctx context.Context, err error) error {
	_, rollbackErr := db.ExecContext(ctx, "ROLLBACK;")
	if rollbackErr != nil {
		return fmt.Errorf(
			rollbackErrorFmt,
			rollbackErr,
			err,
		)
	}

	return err
}

// inTx runs fn between "BEGIN IMMEDIATE" and "COMMIT", rolling back if fn or
// the commit fails.
//
// NOTE(review): BEGIN, fn's statements and COMMIT are each issued through
// the *sql.DB pool, so they are only guaranteed to share one connection if
// the pool is limited to a single connection. Confirm how db is configured.
func inTx(db *sql.DB, fn func(context.Context) error) error {
	ctx := context.Background()

	_, err := db.ExecContext(ctx, "BEGIN IMMEDIATE;")
	if err != nil {
		return err
	}

	err = fn(ctx)
	if err != nil {
		return tryRollback(db, ctx, err)
	}

	_, err = db.ExecContext(ctx, "COMMIT;")
	if err != nil {
		return tryRollback(db, ctx, err)
	}

	return nil
}

// createTablesSQL returns the DDL for every table used by this package. Table
// names are prefixed with prefix, and timestamp defaults use g.SQLiteNow.
func createTablesSQL(prefix string) queryT {
	const tmpl_write = `
		CREATE TABLE IF NOT EXISTS "%s_users" (
			id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
			timestamp TEXT NOT NULL DEFAULT (%s),
			uuid BLOB NOT NULL UNIQUE,
			email TEXT NOT NULL UNIQUE,
			salt BLOB NOT NULL UNIQUE,
			pwhash BLOB NOT NULL
		);
		CREATE TABLE IF NOT EXISTS "%s_confirmation_attempts" (
			id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
			timestamp TEXT NOT NULL DEFAULT (%s),
			-- uuid BLOB NOT NULL UNIQUE,
			user_id INTEGER NOT NULL REFERENCES "%s_users"(id),
			token TEXT NOT NULL UNIQUE
		);
		CREATE TABLE IF NOT EXISTS "%s_user_confirmations" (
			id INTEGER PRIMARY KEY AUTOINCREMENT,
			timestamp TEXT NOT NULL DEFAULT (%s),
			user_id INTEGER NOT NULL REFERENCES "%s_users"(id) UNIQUE,
			attempt_id INTEGER NOT NULL
				REFERENCES "%s_confirmation_attempts"(id) UNIQUE
		);
		CREATE TABLE IF NOT EXISTS "%s_user_changes" (
			id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
			timestamp TEXT NOT NULL DEFAULT (%s),
			user_id INTEGER NOT NULL REFERENCES "%s_users"(id),
			attribute TEXT NOT NULL,
			value TEXT NOT NULL,
			op BOOLEAN NOT NULL
		);
		-- CREATE TABLE IF NOT EXISTS "%s_tokens" (
		-- 	id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
		-- 	timestamp TEXT NOT NULL DEFAULT (%s),
		-- 	uuid BLOB NOT NULL UNIQUE,
		-- 	type TEXT NOT NULL
		-- );
		CREATE TABLE IF NOT EXISTS "%s_roles" (
			id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
			user_id INTEGER NOT NULL REFERENCES "%s_users"(id),
			role TEXT NOT NULL,
			UNIQUE (user_id, role)
		);
		CREATE TABLE IF NOT EXISTS "%s_role_changes" (
			id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
			timestamp TEXT NOT NULL DEFAULT (%s),
			user_id INTEGER NOT NULL REFERENCES "%s_roles"(id),
			role TEXT NOT NULL,
			op BOOLEAN NOT NULL
		);
		CREATE TABLE IF NOT EXISTS "%s_sessions" (
			id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
			timestamp TEXT NOT NULL DEFAULT (%s),
			uuid BLOB NOT NULL UNIQUE,
			user_id INTEGER NOT NULL REFERENCES "%s_users"(id)
			-- type TEXT NOT NULL,
			-- revoked_at TEXT,
			-- revoker_id INTEGER REFERENCES "%s_users"(id),
			-- FIXME: add provenance: login, refresh, confirmation, etc.
		);
		CREATE TABLE IF NOT EXISTS "%s_attempts" (
			id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
			timestamp TEXT NOT NULL DEFAULT (%s),
			user_id INTEGER REFERENCES "%s_users"(id),
			session_id INTEGER REFERENCES "%s_sessions"(id)
		);
		CREATE TABLE IF NOT EXISTS "%s_audit" (
			id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
			timestamp TEXT NOT NULL DEFAULT (%s),
			uuid BLOB NOT NULL UNIQUE,
			attribute TEXT NOT NULL,
			value TEXT NOT NULL,
			op BOOLEAN NOT NULL
		);
	`
	// Arguments grouped per table, in template order. The placeholders inside
	// the commented-out SQL are still consumed by Sprintf.
	return queryT{
		write: fmt.Sprintf(
			tmpl_write,
			prefix, g.SQLiteNow, // users
			prefix, g.SQLiteNow, prefix, // confirmation_attempts
			prefix, g.SQLiteNow, prefix, prefix, // user_confirmations
			prefix, g.SQLiteNow, prefix, // user_changes
			prefix, g.SQLiteNow, // tokens (commented out)
			prefix, prefix, // roles
			prefix, g.SQLiteNow, prefix, // role_changes
			prefix, g.SQLiteNow, prefix, prefix, // sessions
			prefix, g.SQLiteNow, prefix, prefix, // attempts
			prefix, g.SQLiteNow, // audit
		),
	}
}

// createTables creates all prefixed tables inside a single transaction.
func createTables(db *sql.DB, prefix string) error {
	q := createTablesSQL(prefix)

	return inTx(db, func(ctx context.Context) error {
		_, err := db.ExecContext(ctx, q.write)
		return err
	})
}

// registerSQL returns the INSERT for a new user and a lookup by UUID.
func registerSQL(prefix string) queryT {
	const tmpl_write = `
		INSERT INTO "%s_users" (uuid, email, salt, pwhash)
		VALUES (?, ?, ?, ?)
		RETURNING id, timestamp;
	`
	const tmpl_read = `
		SELECT id, timestamp from "%s_users"
		WHERE uuid = ?;
	`
	return queryT{
		write: fmt.Sprintf(tmpl_write, prefix),
		read:  fmt.Sprintf(tmpl_read, prefix),
	}
}

// registerStmt prepares the user INSERT. It returns a function that inserts
// a user and returns the populated userT, plus a function that closes the
// prepared statements.
func registerStmt(
	db *sql.DB,
	prefix string,
) (
	func(guuid.UUID, string, []byte, []byte) (userT, error),
	func() error,
	error,
) {
	q := registerSQL(prefix)

	writeStmt, err := db.Prepare(q.write)
	if err != nil {
		return nil, nil, err
	}

	readStmt, err := db.Prepare(q.read)
	if err != nil {
		// Don't leak the statement that was already prepared.
		return nil, nil, g.WrapErrors(writeStmt.Close(), err)
	}

	fn := func(
		userID guuid.UUID,
		email string,
		salt []byte,
		pwhash []byte,
	) (userT, error) {
		user := userT{
			uuid:      userID,
			email:     email,
			salt:      salt,
			pwhash:    pwhash,
			confirmed: false,
		}

		var timestr string
		user_id_bytes := userID[:]
		err := writeStmt.QueryRow(
			user_id_bytes,
			email,
			salt,
			pwhash,
		).Scan(&user.id, &timestr)
		if err != nil {
			return userT{}, err
		}

		// Assumes g.SQLiteNow renders RFC 3339 text — TODO confirm.
		user.timestamp, err = time.Parse(time.RFC3339Nano, timestr)
		if err != nil {
			return userT{}, err
		}

		return user, nil
	}

	closeFn := func() error {
		return g.SomeFnError(writeStmt.Close, readStmt.Close)
	}

	return fn, closeFn, nil
}

// sendTokenSQL returns the INSERT recording a confirmation token for the
// user with the given UUID.
func sendTokenSQL(prefix string) queryT {
	const tmpl_write = `
		INSERT INTO "%s_confirmation_attempts" (user_id, token)
		VALUES (
			(SELECT id FROM "%s_users" WHERE uuid = ?),
			?
		)
	`
	return queryT{
		write: fmt.Sprintf(tmpl_write, prefix, prefix),
	}
}

// sendTokenStmt prepares the confirmation-attempt INSERT.
func sendTokenStmt(
	db *sql.DB,
	prefix string,
) (func(guuid.UUID, string) error, func() error, error) {
	q := sendTokenSQL(prefix)

	writeStmt, err := db.Prepare(q.write)
	if err != nil {
		return nil, nil, err
	}

	fn := func(userID guuid.UUID, token string) error {
		user_id_bytes := userID[:]
		_, err := writeStmt.Exec(user_id_bytes, token)
		return err
	}

	return fn, writeStmt.Close, nil
}

// confirmSQL returns the confirmation INSERT, the attempt lookup by token,
// and the session INSERT used after a successful confirmation.
func confirmSQL(prefix string) queryT {
	const tmpl_write = `
		INSERT INTO "%s_user_confirmations" (user_id, attempt_id)
		VALUES (?, ?);
	`
	const tmpl_read = `
		SELECT
			"%s_confirmation_attempts".id,
			"%s_confirmation_attempts".user_id,
			"%s_users".uuid
		FROM "%s_confirmation_attempts"
		JOIN "%s_users" ON
			"%s_confirmation_attempts".user_id = "%s_users".id
		WHERE token = ?;
	`
	const tmpl_session = `
		INSERT INTO "%s_sessions" (uuid, user_id)
		VALUES (?, ?)
		RETURNING id, timestamp;
	`
	return queryT{
		write: fmt.Sprintf(tmpl_write, prefix),
		read: fmt.Sprintf(
			tmpl_read,
			prefix,
			prefix,
			prefix,
			prefix,
			prefix,
			prefix,
			prefix,
		),
		session: fmt.Sprintf(tmpl_session, prefix),
	}
}

// confirmStmt prepares the confirmation statements. The returned function
// resolves token to its attempt and user, records the confirmation, and
// opens a session with the given UUID.
//
// NOTE(review): the three statements run outside any transaction, so a
// failure after the confirmation INSERT leaves the user confirmed without a
// session.
func confirmStmt(
	db *sql.DB,
	prefix string,
) (func(string, guuid.UUID) (sessionT, error), func() error, error) {
	q := confirmSQL(prefix)

	writeStmt, err := db.Prepare(q.write)
	if err != nil {
		return nil, nil, err
	}

	readStmt, err := db.Prepare(q.read)
	if err != nil {
		return nil, nil, g.WrapErrors(writeStmt.Close(), err)
	}

	sessionStmt, err := db.Prepare(q.session)
	if err != nil {
		return nil, nil, g.WrapErrors(
			writeStmt.Close(),
			readStmt.Close(),
			err,
		)
	}

	fn := func(token string, sessionID guuid.UUID) (sessionT, error) {
		session := sessionT{
			uuid: sessionID,
		}

		var (
			user_id       int64
			attempt_id    int64
			user_id_bytes []byte
		)
		err := readStmt.QueryRow(token).Scan(
			&attempt_id,
			&user_id,
			&user_id_bytes,
		)
		if err != nil {
			return sessionT{}, err
		}
		session.userID = guuid.UUID(user_id_bytes)

		_, err = writeStmt.Exec(user_id, attempt_id)
		if err != nil {
			return sessionT{}, err
		}

		session_id_bytes := sessionID[:]
		var timestr string
		err = sessionStmt.QueryRow(
			session_id_bytes,
			user_id,
		).Scan(&session.id, &timestr)
		if err != nil {
			return sessionT{}, err
		}

		session.timestamp, err = time.Parse(time.RFC3339Nano, timestr)
		if err != nil {
			return sessionT{}, err
		}

		return session, nil
	}

	closeFn := func() error {
		return g.SomeFnError(
			writeStmt.Close,
			readStmt.Close,
			sessionStmt.Close,
		)
	}

	return fn, closeFn, nil
}

// byEmailSQL returns the user lookup by email, including a computed
// "confirmed" flag.
func byEmailSQL(prefix string) queryT {
	// FIXME: rewrite as LEFT JOIN?
	const tmpl_read = `
		SELECT
			id,
			timestamp,
			uuid,
			salt,
			pwhash,
			(
				CASE WHEN EXISTS (
					SELECT id FROM "%s_user_confirmations"
					WHERE user_id = (
						SELECT id FROM "%s_users"
						WHERE email = ?
					)
				) THEN 1 ELSE 0 END
			) as confirmed
		FROM "%s_users"
		WHERE email = ?;
	`
	return queryT{
		read: fmt.Sprintf(tmpl_read, prefix, prefix, prefix),
	}
}

// byEmailStmt prepares the user lookup by email. The returned function yields
// sql.ErrNoRows when no user has that email.
func byEmailStmt(
	db *sql.DB,
	prefix string,
) (func(string) (userT, error), func() error, error) {
	q := byEmailSQL(prefix)

	readStmt, err := db.Prepare(q.read)
	if err != nil {
		return nil, nil, err
	}

	fn := func(email string) (userT, error) {
		user := userT{
			email: email,
		}

		var (
			timestr       string
			user_id_bytes []byte
		)
		// The email is bound twice: once for the confirmation subquery and
		// once for the outer WHERE.
		err := readStmt.QueryRow(email, email).Scan(
			&user.id,
			&timestr,
			&user_id_bytes,
			&user.salt,
			&user.pwhash,
			&user.confirmed,
		)
		if err != nil {
			return userT{}, err
		}
		user.uuid = guuid.UUID(user_id_bytes)

		user.timestamp, err = time.Parse(time.RFC3339Nano, timestr)
		if err != nil {
			return userT{}, err
		}

		return user, nil
	}

	return fn, readStmt.Close, nil
}

// loginSQL returns the INSERT that opens a session for the user with the
// given UUID.
func loginSQL(prefix string) queryT {
	const tmpl_session = `
		INSERT INTO "%s_sessions" (uuid, user_id) VALUES (
			?,
			(SELECT id FROM "%s_users" WHERE uuid = ?)
		) RETURNING id, timestamp;
	`
	return queryT{
		session: fmt.Sprintf(tmpl_session, prefix, prefix),
	}
}

// loginStmt prepares the session INSERT. The returned function takes
// (userID, sessionID).
func loginStmt(
	db *sql.DB,
	prefix string,
) (func(guuid.UUID, guuid.UUID) (sessionT, error), func() error, error) {
	q := loginSQL(prefix)

	sessionStmt, err := db.Prepare(q.session)
	if err != nil {
		return nil, nil, err
	}

	fn := func(userID guuid.UUID, sessionID guuid.UUID) (sessionT, error) {
		session := sessionT{
			uuid:   sessionID,
			userID: userID,
		}

		user_id_bytes := userID[:]
		session_id_bytes := sessionID[:]
		var timestr string
		err := sessionStmt.QueryRow(
			session_id_bytes,
			user_id_bytes,
		).Scan(
			&session.id,
			&timestr,
		)
		if err != nil {
			return sessionT{}, err
		}

		session.timestamp, err = time.Parse(time.RFC3339Nano, timestr)
		if err != nil {
			return sessionT{}, err
		}

		return session, nil
	}

	return fn, sessionStmt.Close, nil
}

// refreshSQL returns the INSERT that creates a new session for the owner of
// an existing session. It also returns that owner's UUID.
func refreshSQL(prefix string) queryT {
	const tmpl_write = `
		INSERT INTO "%s_sessions" (uuid, user_id) VALUES (
			?,
			(SELECT user_id FROM "%s_sessions" WHERE uuid = ?)
		) RETURNING id, timestamp, (
			SELECT "%s_users".uuid
			FROM "%s_users"
			JOIN "%s_sessions" ON
				"%s_users".id = "%s_sessions".user_id
			WHERE "%s_sessions".uuid = ?
		) AS userID;
	`
	return queryT{
		write: fmt.Sprintf(
			tmpl_write,
			prefix,
			prefix,
			prefix,
			prefix,
			prefix,
			prefix,
			prefix,
			prefix,
		),
	}
}

// refreshStmt prepares the refresh INSERT. The returned function takes
// (existing session UUID, new session UUID).
func refreshStmt(
	db *sql.DB,
	prefix string,
) (func(guuid.UUID, guuid.UUID) (sessionT, error), func() error, error) {
	q := refreshSQL(prefix)

	writeStmt, err := db.Prepare(q.write)
	if err != nil {
		return nil, nil, err
	}

	fn := func(
		sessionID guuid.UUID,
		newSessionID guuid.UUID,
	) (sessionT, error) {
		session := sessionT{
			uuid: newSessionID,
		}

		session_id_bytes := sessionID[:]
		new_session_id_bytes := newSessionID[:]
		var (
			timestr       string
			user_id_bytes []byte
		)
		// The old session UUID is bound twice: for the user_id subquery
		// and for the RETURNING subquery.
		err := writeStmt.QueryRow(
			new_session_id_bytes,
			session_id_bytes,
			session_id_bytes,
		).Scan(&session.id, &timestr, &user_id_bytes)
		if err != nil {
			return sessionT{}, err
		}
		session.userID = guuid.UUID(user_id_bytes)

		session.timestamp, err = time.Parse(time.RFC3339Nano, timestr)
		if err != nil {
			return sessionT{}, err
		}

		return session, nil
	}

	return fn, writeStmt.Close, nil
}

// resetSQL is a placeholder: the statement is only a SQL comment.
func resetSQL(prefix string) queryT {
	const tmpl_write = `
		-- INSERT SOMETHING %s
	`
	return queryT{
		write: fmt.Sprintf(tmpl_write, prefix),
	}
}

// resetStmt is a placeholder for the password-reset statement.
//
// NOTE(review): database/sql cannot Scan a row into a struct, so
// Scan(&session) will fail once real SQL exists; scan column by column.
func resetStmt(
	db *sql.DB,
	prefix string,
) (func(int64, []byte, guuid.UUID) (sessionT, error), func() error, error) {
	q := resetSQL(prefix)

	writeStmt, err := db.Prepare(q.write)
	if err != nil {
		return nil, nil, err
	}

	fn := func(id int64, pwhash []byte, token guuid.UUID) (sessionT, error) {
		var session sessionT
		err := writeStmt.QueryRow(id, pwhash, token[:]).Scan(&session)
		return session, err
	}

	return fn, writeStmt.Close, nil
}

// changeSQL is a placeholder: the statement is only a SQL comment.
func changeSQL(prefix string) queryT {
	const tmpl_write = `
		-- INSERT SOMETHING %s
	`
	return queryT{
		write: fmt.Sprintf(tmpl_write, prefix),
	}
}

// changeStmt is a placeholder for the password-change statement (see the
// Scan note on resetStmt).
func changeStmt(
	db *sql.DB,
	prefix string,
) (func(int64, []byte) (sessionT, error), func() error, error) {
	q := changeSQL(prefix)

	writeStmt, err := db.Prepare(q.write)
	if err != nil {
		return nil, nil, err
	}

	fn := func(id int64, pwhash []byte) (sessionT, error) {
		var session sessionT
		err := writeStmt.QueryRow(id, pwhash).Scan(&session)
		return session, err
	}

	return fn, writeStmt.Close, nil
}

// byUUIDSQL is a placeholder: the statement is only a SQL comment.
func byUUIDSQL(prefix string) queryT {
	const tmpl_read = `
		-- INSERT SOMETHING %s
	`
	return queryT{
		read: fmt.Sprintf(tmpl_read, prefix),
	}
}

// byUUIDStmt is a placeholder for the session lookup by UUID. Like the
// implemented statements, it binds the UUID as its raw bytes.
func byUUIDStmt(
	db *sql.DB,
	prefix string,
) (func(guuid.UUID) (sessionT, error), func() error, error) {
	q := byUUIDSQL(prefix)

	readStmt, err := db.Prepare(q.read)
	if err != nil {
		return nil, nil, err
	}

	fn := func(sessionID guuid.UUID) (sessionT, error) {
		var session sessionT
		err := readStmt.QueryRow(sessionID[:]).Scan(&session)
		return session, err
	}

	return fn, readStmt.Close, nil
}

// logoutSQL is a placeholder: the statement is only a SQL comment.
func logoutSQL(prefix string) queryT {
	const tmpl_write = `
		-- INSERT SOMETHING %s
	`
	return queryT{
		write: fmt.Sprintf(tmpl_write, prefix),
	}
}

// logoutStmt is a placeholder for revoking one session.
func logoutStmt(
	db *sql.DB,
	prefix string,
) (func(guuid.UUID) error, func() error, error) {
	q := logoutSQL(prefix)

	writeStmt, err := db.Prepare(q.write)
	if err != nil {
		return nil, nil, err
	}

	fn := func(sessionID guuid.UUID) error {
		_, err := writeStmt.Exec(sessionID[:])
		return err
	}

	return fn, writeStmt.Close, nil
}

// outOthersSQL is a placeholder: the statement is only a SQL comment.
func outOthersSQL(prefix string) queryT {
	const tmpl_write = `
		-- INSERT SOMETHING %s
	`
	return queryT{
		write: fmt.Sprintf(tmpl_write, prefix),
	}
}

// outOthersStmt is a placeholder for revoking all but the given session.
func outOthersStmt(
	db *sql.DB,
	prefix string,
) (func(guuid.UUID) error, func() error, error) {
	q := outOthersSQL(prefix)

	writeStmt, err := db.Prepare(q.write)
	if err != nil {
		return nil, nil, err
	}

	fn := func(sessionID guuid.UUID) error {
		_, err := writeStmt.Exec(sessionID[:])
		return err
	}

	return fn, writeStmt.Close, nil
}

// outAllSQL is a placeholder: the statement is only a SQL comment.
func outAllSQL(prefix string) queryT {
	const tmpl_write = `
		-- INSERT SOMETHING %s
	`
	return queryT{
		write: fmt.Sprintf(tmpl_write, prefix),
	}
}

// outAllStmt is a placeholder for revoking every session of the owner of
// the given session.
func outAllStmt(
	db *sql.DB,
	prefix string,
) (func(guuid.UUID) error, func() error, error) {
	q := outAllSQL(prefix)

	writeStmt, err := db.Prepare(q.write)
	if err != nil {
		return nil, nil, err
	}

	fn := func(sessionID guuid.UUID) error {
		_, err := writeStmt.Exec(sessionID[:])
		return err
	}

	return fn, writeStmt.Close, nil
}
func initDB( db *sql.DB, prefix string, ) (queriesT, error) { createTablesErr := createTables(db, prefix) register, registerClose, registerErr := registerStmt(db, prefix) sendToken, sendTokenClose, sendTokenErr := sendTokenStmt(db, prefix) confirm, confirmClose, confirmErr := confirmStmt(db, prefix) byEmail, byEmailClose, byEmailErr := byEmailStmt(db, prefix) login, loginClose, loginErr := loginStmt(db, prefix) refresh, refreshClose, refreshErr := refreshStmt(db, prefix) reset, resetClose, resetErr := resetStmt(db, prefix) change, changeClose, changeErr := changeStmt(db, prefix) byUUID, byUUIDClose, byUUIDErr := byUUIDStmt(db, prefix) logout, logoutClose, logoutErr := logoutStmt(db, prefix) outOthers, outOthersClose, outOthersErr := outOthersStmt(db, prefix) outAll, outAllClose, outAllErr := outAllStmt(db, prefix) err := g.SomeError( createTablesErr, registerErr, sendTokenErr, confirmErr, byEmailErr, loginErr, refreshErr, resetErr, changeErr, byUUIDErr, logoutErr, outOthersErr, outAllErr, ) if err != nil { return queriesT{}, err } close := func() error { return g.SomeFnError( registerClose, sendTokenClose, confirmClose, byEmailClose, loginClose, refreshClose, resetClose, changeClose, byUUIDClose, logoutClose, outOthersClose, outAllClose, ) } var connMutex sync.RWMutex return queriesT{ register: func( a guuid.UUID, b string, c []byte, d []byte, ) (userT, error) { connMutex.RLock() defer connMutex.RUnlock() return register(a, b, c, d) }, sendToken: func(a guuid.UUID, b string) error { connMutex.RLock() defer connMutex.RUnlock() return sendToken(a, b) }, confirm: func(a string, b guuid.UUID) (sessionT, error) { connMutex.RLock() defer connMutex.RUnlock() return confirm(a, b) }, byEmail: func(a string) (userT, error) { connMutex.RLock() defer connMutex.RUnlock() return byEmail(a) }, login: func(a guuid.UUID, b guuid.UUID) (sessionT, error) { connMutex.RLock() defer connMutex.RUnlock() return login(a, b) }, refresh: func(a guuid.UUID, b guuid.UUID) (sessionT, error) { 
connMutex.RLock() defer connMutex.RUnlock() return refresh(a, b) }, reset: func(a int64, b []byte, c guuid.UUID) (sessionT, error) { connMutex.RLock() defer connMutex.RUnlock() return reset(a, b, c) }, change: func(a int64, b []byte) (sessionT, error) { connMutex.RLock() defer connMutex.RUnlock() return change(a, b) }, byUUID: func(a guuid.UUID) (sessionT, error) { connMutex.RLock() defer connMutex.RUnlock() return byUUID(a) }, logout: func(a guuid.UUID) error { connMutex.RLock() defer connMutex.RUnlock() return logout(a) }, outOthers: func(a guuid.UUID) error { connMutex.RLock() defer connMutex.RUnlock() return outOthers(a) }, outAll: func(a guuid.UUID) error { connMutex.RLock() defer connMutex.RUnlock() return outAll(a) }, close: func() error { connMutex.Lock() defer connMutex.Unlock() return close() }, }, nil } func newUserHandler(auth authT) func(q.Message) error { return func(message q.Message) error { return nil } } func sendConfirmationRequestHandler(auth authT) func(q.Message) error { return func(message q.Message) error { return nil } } func forgotPasswordRequestHandler(auth authT) func(q.Message) error { return func(message q.Message) error { return nil } } var consumers = []consumerT{ consumerT{ topic: NEW_USER, handlerFn: newUserHandler, }, consumerT{ topic: SEND_CONFIRMATION_REQUEST, handlerFn: sendConfirmationRequestHandler, }, consumerT{ topic: FORGOT_PASSWORD_REQUEST, handlerFn: forgotPasswordRequestHandler, }, } func registerConsumers(auth authT, consumers []consumerT, prefix string) error { for _, consumer := range consumers { err := auth.queue.Subscribe( consumer.topic, prefix + "-" + consumer.topic, consumer.handlerFn(auth), ) if err != nil { return err } } return nil } func unregisterConsumers(queue q.IQueue, consumers []consumerT, prefix string) { for _, consumer := range consumers { queue.Unsubscribe(consumer.topic, prefix + "-" + consumer.topic) } } type resultT[T any] struct{ value T err error } type taggedT[T any] struct{ id int value T } 
func startRunner[A any, B any]( in <-chan taggedT[A], out chan<- taggedT[B], fn func(A) B, done func(), ) { for input := range in { out <- taggedT[B]{ id: input.id, value: fn(input.value), } } done() } func makePoolRunner[A any, B any](count int, fn func(A) B) (func(A) B, func()) { var wg sync.WaitGroup wg.Add(count) in := make(chan taggedT[A]) out := make(chan taggedT[B]) for _ = range count { go startRunner(in, out, fn, wg.Done) } var mutex sync.Mutex m := map[int]chan B{} id := 0 go func() { for output := range out { mutex.Lock() defer mutex.Unlock() m[output.id] <- output.value close(m[output.id]) delete(m, output.id) } }() poolRunFn := func(input A) B { c := make(chan B) { mutex.Lock() defer mutex.Unlock() m[id] = c id++ } in <- taggedT[A]{ id: id, value: input, } return <- c } close := func() { close(in) wg.Wait() close(out) } return poolRunFn, close } func asResult[A any, B any](fn func(A) (B, error)) func(A) resultT[B] { return func(input A) resultT[B] { output, err := fn(input) return resultT[B]{ value: output, err: err, } } } func unwrapResult[A any, B any](fn func(A) resultT[B]) func(A) (B, error) { return func(input A) (B, error) { result := fn(input) return result.value, result.err } } func NewWithPrefix(db *sql.DB, queue q.IQueue, prefix string) (IAuth, error) { queries, err := initDB(db, prefix) if err != nil { return authT{}, err } numCPU := runtime.NumCPU() hasher, closeHasher := makePoolRunner(numCPU, asResult(scrypt.Hash)) closeFn := func() { unregisterConsumers(queue, consumers, prefix) closeHasher() } auth := authT{ queries: queries, queue: queue, hasher: unwrapResult(hasher), close: closeFn, } err = registerConsumers(auth, consumers, prefix) if err != nil { closeFn() return authT{}, err } return auth, nil } func New(db *sql.DB, queue q.IQueue) (IAuth, error) { return NewWithPrefix(db, queue, defaultPrefix) } func newUserPayload(email string, salt []byte, pwhash []byte) ([]byte, error) { data := make(map[string]interface{}) data["email"] = 
email data["salt"] = hex.EncodeToString(salt) data["pwhash"] = hex.EncodeToString(pwhash) return json.Marshal(data) } func (auth authT) Register( email string, password string, confirmPassword string, ) (userT, error) { if password != confirmPassword { return userT{}, ErrPassMismatch } if len(password) < scrypt.MinimumPasswordLength { return userT{}, ErrTooShort } // special check for sql.ErrNoRows to combat enumeration attacks. // FIXME: how so? _, lookupErr := auth.queries.byEmail(email) if lookupErr != nil && lookupErr != sql.ErrNoRows { return userT{}, lookupErr } salt, err := scrypt.Salt() if err != nil { return userT{}, err } input := scrypt.HashInput{ Password: []byte(password), Salt: salt, } hash, err := auth.hasher(input) if err != nil { return userT{}, err } /* We also try to register anyway, to prevent disk IO timing attacks. / FIXME: how so? */ flowID := guuid.New() waiter := auth.queue.WaitFor(NEW_USER, flowID, "register") defer waiter.Close() payload, err := newUserPayload(email, salt, hash) if err != nil { return userT{}, err } unsent := q.UnsentMessage{ Topic: NEW_USER, FlowID: flowID, Payload: payload, } _, err = auth.queue.Publish(unsent) if err != nil { return userT{}, err } user := userT{} select { case <-time.After(RegisterTimeout): err = ErrTimeout case <-waiter.Channel: user, err = auth.queries.byEmail(email) } return user, nil } func sendConfirmationMessage( email string, flowID guuid.UUID, ) (q.UnsentMessage, error) { data := make(map[string]interface{}) data["email"] = email payload, err := json.Marshal(data) if err != nil { return q.UnsentMessage{}, err } return q.UnsentMessage{ Topic: SEND_CONFIRMATION_REQUEST, FlowID: flowID, Payload: payload, }, nil } func (auth authT) ResendConfirmation(email string) error { unsent, err := sendConfirmationMessage(email, guuid.New()) if err != nil { return err } _, err = auth.queue.Publish(unsent) return err } func (auth authT) ConfirmEmail(token string) (sessionT, error) { return 
auth.queries.confirm("token FIXME", guuid.New()) } func (auth authT) LoginEmail( email string, password string, ) (sessionT, error) { if len(password) < scrypt.MinimumPasswordLength { return sessionT{}, ErrTooShort } // special check for sql.ErrNoRows to combat enumeration attacks. // FIXME: how so? user, err := auth.queries.byEmail(email) if err != nil && err != sql.ErrNoRows { return sessionT{}, err } input := scrypt.HashInput{ Password: []byte(password), Salt: user.salt, } hash, err := auth.hasher(input) if err != nil { return sessionT{}, err } ok := slices.Equal(hash, user.pwhash) if !ok { return sessionT{}, ErrBadCombo } if !user.confirmed { return sessionT{}, ErrUnconfirmed } session, err := auth.queries.login(user.uuid, guuid.New()) if err != nil { return sessionT{}, err } return session, nil } func forgotPasswordMessage( email string, flowID guuid.UUID, ) (q.UnsentMessage, error) { data := make(map[string]interface{}) data["email"] = email payload, err := json.Marshal(data) if err != nil { return q.UnsentMessage{}, err } return q.UnsentMessage{ Topic: FORGOT_PASSWORD_REQUEST, FlowID: flowID, Payload: payload, }, nil } func (auth authT) ForgotPassword(email string) error { // special check for sql.ErrNoRows to combat enumeration attacks. 
user, err := auth.queries.byEmail(email) if err != nil && err != sql.ErrNoRows { return err } unsent, err := forgotPasswordMessage(user.email, guuid.New()) if err != nil { return err } _, err = auth.queue.Publish(unsent) return err } func checkSession(session sessionT, now time.Time) error { if session.revoked_at != nil { return ErrRevokedSession } if session.timestamp.Add(SessionDuration).After(now) { return ErrSessionExpired } return nil } func validateSession( lookupFn func(guuid.UUID) (sessionT, error), session sessionT, ) error { dbSession, err := lookupFn(session.uuid) if err != nil { return err } return checkSession(dbSession, time.Now()) } func (auth authT) Refresh(session sessionT) (sessionT, error) { err := validateSession(auth.queries.byUUID, session) if err != nil { return sessionT{}, err } return auth.queries.refresh(session.uuid, guuid.New()) } func (auth authT) ResetPassword( token guuid.UUID, password string, confirmPassword string, ) (sessionT, error) { if password != confirmPassword { return sessionT{}, ErrPassMismatch } if len(password) < scrypt.MinimumPasswordLength { return sessionT{}, ErrTooShort } user := userT{} input := scrypt.HashInput{ Password: []byte(password), Salt: user.salt, } pwhash, err := auth.hasher(input) if err != nil { return sessionT{}, err } if user.confirmed { return auth.queries.reset(user.id, pwhash, token) } else { // return auth.queries.confirm(user.id, pwhash, token) return auth.queries.confirm("token FIXME", guuid.New()) } } func (auth authT) ChangePassword( user userT, currentPassword string, newPassword string, confirmNewPassword string, ) (sessionT, error) { if newPassword != confirmNewPassword { return sessionT{}, ErrPassMismatch } if len(newPassword) < scrypt.MinimumPasswordLength { return sessionT{}, ErrTooShort } input := scrypt.HashInput{ Password: []byte(newPassword), Salt: user.salt, } pwhash, err := auth.hasher(input) if err != nil { return sessionT{}, err } if !user.confirmed { return sessionT{}, 
ErrUnconfirmed } return auth.queries.change(user.id, pwhash) } func runLogout( lookupFn func(guuid.UUID) (sessionT, error), session sessionT, queryFn func(guuid.UUID) error, ) error { err := validateSession(lookupFn, session) if err != nil { return err } return queryFn(session.uuid) } func (auth authT) Logout(session sessionT) error { return runLogout( auth.queries.byUUID, session, auth.queries.logout, ) } func (auth authT) LogoutOthers(session sessionT) error { return runLogout( auth.queries.byUUID, session, auth.queries.outOthers, ) } func (auth authT) LogoutAll(session sessionT) error { return runLogout( auth.queries.byUUID, session, auth.queries.outAll, ) } func (auth authT) Close() error { return auth.queries.close() } func Main() { // FIXME }