package gracha import ( "context" "database/sql" "encoding/hex" "encoding/json" "errors" "fmt" "runtime" "slices" "sync" "time" "guuid" "q" "scrypt" g "gobang" ) const ( defaultPrefix = "gracha" rollbackErrorFmt = "rollback error: %w; while executing: %w" NEW_USER = "new-user" SEND_CONFIRMATION_REQUEST = "send-confirmation-request" FORGOT_PASSWORD_REQUEST = "forgot-password-request" day = 24 * time.Hour ) var ( SessionDuration = 7 * day RegisterTimeout = 15 * time.Second ErrPassMismatch = errors.New("gracha: password confirmation mismatch") ErrTooShort = errors.New("gracha: password too short") ErrTimeout = errors.New("gracha: timeout when creating user") ErrRegistered = errors.New("gracha: user already registered") ErrBadCombo = errors.New("gracha: bad username/passphrase combo") ErrUnconfirmed = errors.New("gracha: email is not confirmed") ErrRevokedSession = errors.New("gracha: this session was revoked") ErrSessionExpired = errors.New("gracha: session expired") ) type queryT struct{ write string read string session string } type queriesT struct{ register func(guuid.UUID, string, []byte, []byte) (userT, error) sendToken func(guuid.UUID, string) error confirm func(string, guuid.UUID) (sessionT, error) byEmail func(string) (userT, error) login func(guuid.UUID, guuid.UUID) (sessionT, error) refresh func(guuid.UUID, guuid.UUID) (sessionT, error) reset func(int64, []byte, guuid.UUID) (sessionT, error) change func(guuid.UUID, []byte) (sessionT, error) byUUID func(guuid.UUID) (sessionT, error) logout func(guuid.UUID) error outOthers func(guuid.UUID) error outAll func(guuid.UUID) error close func() error } type userT struct{ id int64 timestamp time.Time uuid guuid.UUID email string salt []byte pwhash []byte confirmed bool } type User struct{ UUID guuid.UUID Salt []byte Confirmed bool } type sessionT struct{ id int64 timestamp time.Time uuid guuid.UUID userID guuid.UUID // type_ string revoked_at *time.Time } type Session struct{ UUID guuid.UUID } type consumerT struct{ topic string handlerFn func(authT) func(q.Message) error } type authT struct{ queries queriesT queue q.IQueue hasher func(scrypt.HashInput) ([]byte, error) close func() } type IAuth interface{ Register(string, string, string) (User, error) ResendConfirmation(string) error ConfirmEmail(string) (Session, error) LoginEmail(string, string) (Session, error) ForgotPassword(string) error Refresh(Session) (Session, error) ResetPassword(guuid.UUID, string, string) (Session, error) ChangePassword(User, string, string, string) (Session, error) Logout(Session) error LogoutOthers(Session) error LogoutAll(Session) error Close() error } func validateSession(session sessionT) (sessionT, error) { // FIXME: implement return session, nil } func tryRollback(db *sql.DB, ctx context.Context, err error) error { _, rollbackErr := db.ExecContext(ctx, "ROLLBACK;") if rollbackErr != nil { return fmt.Errorf( rollbackErrorFmt, rollbackErr, err, ) } return err } func inTx(db *sql.DB, fn func(context.Context) error) error { ctx := context.Background() _, err := db.ExecContext(ctx, "BEGIN IMMEDIATE;") if err != nil { return err } err = fn(ctx) if err != nil { return tryRollback(db, ctx, err) } _, err = db.ExecContext(ctx, "COMMIT;") if err != nil { return tryRollback(db, ctx, err) } return nil } func createTablesSQL(prefix string) queryT { const tmpl_write = ` CREATE TABLE IF NOT EXISTS "%s_users" ( id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, timestamp TEXT NOT NULL DEFAULT (%s), uuid BLOB NOT NULL UNIQUE, email TEXT NOT NULL UNIQUE, salt BLOB NOT NULL UNIQUE, pwhash BLOB NOT NULL ); CREATE TABLE IF NOT EXISTS "%s_confirmation_attempts" ( id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, timestamp TEXT NOT NULL DEFAULT (%s), -- uuid BLOB NOT NULL UNIQUE, user_id INTEGER NOT NULL REFERENCES "%s_users"(id), token TEXT NOT NULL UNIQUE ); CREATE TABLE IF NOT EXISTS "%s_user_confirmations" ( id INTEGER PRIMARY KEY AUTOINCREMENT, timestamp TEXT NOT NULL DEFAULT (%s), user_id INTEGER NOT NULL REFERENCES "%s_users"(id) UNIQUE, attempt_id INTEGER NOT NULL REFERENCES "%s_confirmation_attempts"(id) UNIQUE ); CREATE TABLE IF NOT EXISTS "%s_user_changes" ( id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, timestamp TEXT NOT NULL DEFAULT (%s), user_id INTEGER NOT NULL REFERENCES "%s_users"(id), attribute TEXT NOT NULL, value TEXT NOT NULL, op BOOLEAN NOT NULL ); -- CREATE TABLE IF NOT EXISTS "%s_tokens" ( -- id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, -- timestamp TEXT NOT NULL DEFAULT (%s), -- uuid BLOB NOT NULL UNIQUE, -- type TEXT NOT NULL -- ); CREATE TABLE IF NOT EXISTS "%s_roles" ( id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, user_id INTEGER NOT NULL REFERENCES "%s_users"(id), role TEXT NOT NULL, UNIQUE (user_id, role) ); CREATE TABLE IF NOT EXISTS "%s_role_changes" ( id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, timestamp TEXT NOT NULL DEFAULT (%s), user_id INTEGER NOT NULL REFERENCES "%s_roles"(id), role TEXT NOT NULL, op BOOLEAN NOT NULL ); CREATE TABLE IF NOT EXISTS "%s_sessions" ( id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, timestamp TEXT NOT NULL DEFAULT (%s), uuid BLOB NOT NULL UNIQUE, user_id INTEGER NOT NULL REFERENCES "%s_users"(id) -- type TEXT NOT NULL, -- revoked_at TEXT, -- revoker_id INTEGER REFERENCES "%s_users"(id), -- FIXME: add provenance: login, refresh, confirmation, etc. ); CREATE TABLE IF NOT EXISTS "%s_attempts" ( id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, timestamp TEXT NOT NULL DEFAULT (%s), user_id INTEGER REFERENCES "%s_users"(id), session_id INTEGER REFERENCES "%s_sessions"(id) ); CREATE TABLE IF NOT EXISTS "%s_audit" ( id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, timestamp TEXT NOT NULL DEFAULT (%s), uuid BLOB NOT NULL UNIQUE, attribute TEXT NOT NULL, value TEXT NOT NULL, op BOOLEAN NOT NULL ); ` return queryT{ write: fmt.Sprintf( tmpl_write, prefix, g.SQLiteNow, prefix, g.SQLiteNow, prefix, prefix, g.SQLiteNow, prefix, prefix, prefix, g.SQLiteNow, prefix, prefix, g.SQLiteNow, prefix, prefix, prefix, g.SQLiteNow, prefix, prefix, g.SQLiteNow, prefix, prefix, prefix, g.SQLiteNow, prefix, prefix, prefix, g.SQLiteNow, ), } } func createTables(db *sql.DB, prefix string) error { q := createTablesSQL(prefix) return inTx(db, func(ctx context.Context) error { _, err := db.ExecContext(ctx, q.write) return err }) } func registerSQL(prefix string) queryT { const tmpl_write = ` INSERT INTO "%s_users" (uuid, email, salt, pwhash) VALUES (?, ?, ?, ?) RETURNING id, timestamp; ` const tmpl_read = ` SELECT id, timestamp from "%s_users" WHERE uuid = ?; ` return queryT{ write: fmt.Sprintf(tmpl_write, prefix), read: fmt.Sprintf(tmpl_read, prefix), } } func registerStmt( db *sql.DB, prefix string, ) ( func(guuid.UUID, string, []byte, []byte) (userT, error), func() error, error, ) { q := registerSQL(prefix) writeStmt, err := db.Prepare(q.write) if err != nil { return nil, nil, err } readStmt, err := db.Prepare(q.read) if err != nil { return nil, nil, err } fn := func( userID guuid.UUID, email string, salt []byte, pwhash []byte, ) (userT, error) { user := userT{ uuid: userID, email: email, salt: salt, pwhash: pwhash, confirmed: false, } var timestr string user_id_bytes := userID[:] err := writeStmt.QueryRow( user_id_bytes, email, salt, pwhash, ).Scan(&user.id, ×tr) if err != nil { return userT{}, err } user.timestamp, err = time.Parse(time.RFC3339Nano, timestr) if err != nil { return userT{}, err } return user, nil } closeFn := func() error { return g.SomeFnError(writeStmt.Close, readStmt.Close) } return fn, closeFn, nil } func sendTokenSQL(prefix string) queryT { const tmpl_write = ` INSERT INTO "%s_confirmation_attempts" (user_id, token) VALUES ( (SELECT id FROM "%s_users" WHERE uuid = ?), ? ) ` return queryT{ write: fmt.Sprintf(tmpl_write, prefix, prefix), } } func sendTokenStmt( db *sql.DB, prefix string, ) (func(guuid.UUID, string) error, func() error, error) { q := sendTokenSQL(prefix) writeStmt, err := db.Prepare(q.write) if err != nil { return nil, nil, err } fn := func(userID guuid.UUID, token string) error { user_id_bytes := userID[:] _, err := writeStmt.Exec(user_id_bytes, token) return err } return fn, writeStmt.Close, nil } func confirmSQL(prefix string) queryT { const tmpl_write = ` INSERT INTO "%s_user_confirmations" (user_id, attempt_id) VALUES (?, ?); ` const tmpl_read = ` SELECT "%s_confirmation_attempts".id, "%s_confirmation_attempts".user_id, "%s_users".uuid FROM "%s_confirmation_attempts" JOIN "%s_users" ON "%s_confirmation_attempts".user_id = "%s_users".id WHERE token = ?; ` const tmpl_session = ` INSERT INTO "%s_sessions" (uuid, user_id) VALUES (?, ?) RETURNING id, timestamp; ` return queryT{ write: fmt.Sprintf(tmpl_write, prefix), read: fmt.Sprintf( tmpl_read, prefix, prefix, prefix, prefix, prefix, prefix, prefix, ), session: fmt.Sprintf(tmpl_session, prefix), } } func confirmStmt( db *sql.DB, prefix string, ) (func(string, guuid.UUID) (sessionT, error), func() error, error) { q := confirmSQL(prefix) writeStmt, err := db.Prepare(q.write) if err != nil { return nil, nil, err } readStmt, err := db.Prepare(q.read) if err != nil { return nil, nil, g.WrapErrors(writeStmt.Close(), err) } sessionStmt, err := db.Prepare(q.session) if err != nil { return nil, nil, g.WrapErrors( writeStmt.Close(), readStmt.Close(), err, ) } fn := func(token string, sessionID guuid.UUID) (sessionT, error) { session := sessionT{ uuid: sessionID, } var ( user_id int64 attempt_id int64 user_id_bytes []byte ) err := readStmt.QueryRow(token).Scan( &attempt_id, &user_id, &user_id_bytes, ) if err != nil { return sessionT{}, err } session.userID = guuid.UUID(user_id_bytes) _, err = writeStmt.Exec(user_id, attempt_id) if err != nil { return sessionT{}, err } session_id_bytes := sessionID[:] var timestr string err = sessionStmt.QueryRow( session_id_bytes, user_id, ).Scan(&session.id, ×tr) if err != nil { return sessionT{}, err } session.timestamp, err = time.Parse(time.RFC3339Nano, timestr) if err != nil { return sessionT{}, err } return validateSession(session) } closeFn := func() error { return g.SomeFnError( writeStmt.Close, readStmt.Close, sessionStmt.Close, ) } return fn, closeFn, nil } func byEmailSQL(prefix string) queryT { // FIXME: rewrite as LEFT JOIN? const tmpl_read = ` SELECT id, timestamp, uuid, salt, pwhash, ( CASE WHEN EXISTS ( SELECT id FROM "%s_user_confirmations" WHERE user_id = ( SELECT id FROM "%s_users" WHERE email = ? ) ) THEN 1 ELSE 0 END ) as confirmed FROM "%s_users" WHERE email = ?; ` return queryT{ read: fmt.Sprintf(tmpl_read, prefix, prefix, prefix), } } func byEmailStmt( db *sql.DB, prefix string, ) (func(string) (userT, error), func() error, error) { q := byEmailSQL(prefix) readStmt, err := db.Prepare(q.read) if err != nil { return nil, nil, err } fn := func(email string) (userT, error) { user := userT{ email: email, } var ( timestr string user_id_bytes []byte ) err := readStmt.QueryRow(email, email).Scan( &user.id, ×tr, &user_id_bytes, &user.salt, &user.pwhash, &user.confirmed, ) if err != nil { return userT{}, err } user.uuid = guuid.UUID(user_id_bytes) user.timestamp, err = time.Parse(time.RFC3339Nano,timestr) if err != nil { return userT{}, err } return user, nil } return fn, readStmt.Close, nil } func loginSQL(prefix string) queryT { const tmpl_session = ` INSERT INTO "%s_sessions" (uuid, user_id) VALUES ( ?, (SELECT id FROM "%s_users" WHERE uuid = ?) ) RETURNING id, timestamp; ` return queryT{ session: fmt.Sprintf(tmpl_session, prefix, prefix), } } func loginStmt( db *sql.DB, prefix string, ) (func(guuid.UUID, guuid.UUID) (sessionT, error), func() error, error) { q := loginSQL(prefix) sessionStmt, err := db.Prepare(q.session) if err != nil { return nil, nil, err } fn := func(userID guuid.UUID, sessionID guuid.UUID) (sessionT, error) { session := sessionT{ uuid: sessionID, userID: userID, } user_id_bytes := userID[:] session_id_bytes := sessionID[:] var timestr string err := sessionStmt.QueryRow( session_id_bytes, user_id_bytes, ).Scan( &session.id, ×tr, ) if err != nil { return sessionT{}, err } session.timestamp, err = time.Parse(time.RFC3339Nano, timestr) if err != nil { return sessionT{}, err } return validateSession(session) } return fn, sessionStmt.Close, nil } func refreshSQL(prefix string) queryT { const tmpl_write = ` INSERT INTO "%s_sessions" (uuid, user_id) VALUES ( ?, (SELECT user_id FROM "%s_sessions" WHERE uuid = ?) ) RETURNING id, timestamp, ( SELECT "%s_users".uuid FROM "%s_users" JOIN "%s_sessions" ON "%s_users".id = "%s_sessions".user_id WHERE "%s_sessions".uuid = ? ) AS userID; ` return queryT{ write: fmt.Sprintf( tmpl_write, prefix, prefix, prefix, prefix, prefix, prefix, prefix, prefix, ), } } func refreshStmt( db *sql.DB, prefix string, ) (func(guuid.UUID, guuid.UUID) (sessionT, error), func() error, error) { q := refreshSQL(prefix) writeStmt, err := db.Prepare(q.write) if err != nil { return nil, nil, err } fn := func( sessionID guuid.UUID, newSessionID guuid.UUID, ) (sessionT, error) { session := sessionT{ uuid: newSessionID, } session_id_bytes := sessionID[:] new_session_id_bytes := newSessionID[:] var ( timestr string user_id_bytes []byte ) err := writeStmt.QueryRow( new_session_id_bytes, session_id_bytes, session_id_bytes, ).Scan(&session.id, ×tr, &user_id_bytes) if err != nil { return sessionT{}, err } session.userID = guuid.UUID(user_id_bytes) session.timestamp, err = time.Parse(time.RFC3339Nano, timestr) if err != nil { return sessionT{}, err } return validateSession(session) } return fn, writeStmt.Close, nil } func resetSQL(prefix string) queryT { const tmpl_write = ` -- INSERT SOMETHING %s ` return queryT{ write: fmt.Sprintf(tmpl_write, prefix), } } func resetStmt( db *sql.DB, prefix string, ) (func(int64, []byte, guuid.UUID) (sessionT, error), func() error, error) { q := resetSQL(prefix) writeStmt, err := db.Prepare(q.write) if err != nil { return nil, nil, err } fn := func(id int64, pwhash []byte, token guuid.UUID) (sessionT, error) { var session sessionT err := writeStmt.QueryRow(id, pwhash, token).Scan(&session) if err != nil { return sessionT{}, err } return validateSession(session) } return fn, writeStmt.Close, nil } func changeSQL(prefix string) queryT { const tmpl_write = ` -- INSERT SOMETHING %s ` return queryT{ write: fmt.Sprintf(tmpl_write, prefix), } } func changeStmt( db *sql.DB, prefix string, ) (func(guuid.UUID, []byte) (sessionT, error), func() error, error) { q := changeSQL(prefix) writeStmt, err := db.Prepare(q.write) if err != nil { return nil, nil, err } fn := func(uuid guuid.UUID, pwhash []byte) (sessionT, error) { var session sessionT err := writeStmt.QueryRow(uuid, pwhash).Scan(&session) if err != nil { return sessionT{}, err } return validateSession(session) } return fn, writeStmt.Close, nil } func byUUIDSQL(prefix string) queryT { const tmpl_read = ` -- INSERT SOMETHING %s ` return queryT{ read: fmt.Sprintf(tmpl_read, prefix), } } func byUUIDStmt( db *sql.DB, prefix string, ) (func(guuid.UUID) (sessionT, error), func() error, error) { q := byUUIDSQL(prefix) readStmt, err := db.Prepare(q.read) if err != nil { return nil, nil, err } fn := func(sessionID guuid.UUID) (sessionT, error) { var session sessionT err := readStmt.QueryRow(sessionID).Scan(&session) if err != nil { return sessionT{}, err } return validateSession(session) } return fn, readStmt.Close, nil } func logoutSQL(prefix string) queryT { const tmpl_write = ` -- INSERT SOMETHING %s ` return queryT{ write: fmt.Sprintf(tmpl_write, prefix), } } func logoutStmt( db *sql.DB, prefix string, ) (func(guuid.UUID) error, func() error, error) { q := logoutSQL(prefix) writeStmt, err := db.Prepare(q.write) if err != nil { return nil, nil, err } fn := func(sessionID guuid.UUID) error { _, err := writeStmt.Exec(sessionID) return err } return fn, writeStmt.Close, nil } func outOthersSQL(prefix string) queryT { const tmpl_write = ` -- INSERT SOMETHING %s ` return queryT{ write: fmt.Sprintf(tmpl_write, prefix), } } func outOthersStmt( db *sql.DB, prefix string, ) (func(guuid.UUID) error, func() error, error) { q := outOthersSQL(prefix) writeStmt, err := db.Prepare(q.write) if err != nil { return nil, nil, err } fn := func(sessionID guuid.UUID) error { _, err := writeStmt.Exec(sessionID) return err } return fn, writeStmt.Close, nil } func outAllSQL(prefix string) queryT { const tmpl_write = ` -- INSERT SOMETHING %s ` return queryT{ write: fmt.Sprintf(tmpl_write, prefix), } } func outAllStmt( db *sql.DB, prefix string, ) (func(guuid.UUID) error, func() error, error) { q := outAllSQL(prefix) writeStmt, err := db.Prepare(q.write) if err != nil { return nil, nil, err } fn := func(sessionID guuid.UUID) error { _, err := writeStmt.Exec(sessionID) return err } return fn, writeStmt.Close, nil } func initDB( db *sql.DB, prefix string, ) (queriesT, error) { createTablesErr := createTables(db, prefix) register, registerClose, registerErr := registerStmt(db, prefix) sendToken, sendTokenClose, sendTokenErr := sendTokenStmt(db, prefix) confirm, confirmClose, confirmErr := confirmStmt(db, prefix) byEmail, byEmailClose, byEmailErr := byEmailStmt(db, prefix) login, loginClose, loginErr := loginStmt(db, prefix) refresh, refreshClose, refreshErr := refreshStmt(db, prefix) reset, resetClose, resetErr := resetStmt(db, prefix) change, changeClose, changeErr := changeStmt(db, prefix) byUUID, byUUIDClose, byUUIDErr := byUUIDStmt(db, prefix) logout, logoutClose, logoutErr := logoutStmt(db, prefix) outOthers, outOthersClose, outOthersErr := outOthersStmt(db, prefix) outAll, outAllClose, outAllErr := outAllStmt(db, prefix) err := g.SomeError( createTablesErr, registerErr, sendTokenErr, confirmErr, byEmailErr, loginErr, refreshErr, resetErr, changeErr, byUUIDErr, logoutErr, outOthersErr, outAllErr, ) if err != nil { return queriesT{}, err } close := func() error { return g.SomeFnError( registerClose, sendTokenClose, confirmClose, byEmailClose, loginClose, refreshClose, resetClose, changeClose, byUUIDClose, logoutClose, outOthersClose, outAllClose, ) } var connMutex sync.RWMutex return queriesT{ register: func( a guuid.UUID, b string, c []byte, d []byte, ) (userT, error) { connMutex.RLock() defer connMutex.RUnlock() return register(a, b, c, d) }, sendToken: func(a guuid.UUID, b string) error { connMutex.RLock() defer connMutex.RUnlock() return sendToken(a, b) }, confirm: func(a string, b guuid.UUID) (sessionT, error) { connMutex.RLock() defer connMutex.RUnlock() return confirm(a, b) }, byEmail: func(a string) (userT, error) { connMutex.RLock() defer connMutex.RUnlock() return byEmail(a) }, login: func(a guuid.UUID, b guuid.UUID) (sessionT, error) { connMutex.RLock() defer connMutex.RUnlock() return login(a, b) }, refresh: func(a guuid.UUID, b guuid.UUID) (sessionT, error) { connMutex.RLock() defer connMutex.RUnlock() return refresh(a, b) }, reset: func(a int64, b []byte, c guuid.UUID) (sessionT, error) { connMutex.RLock() defer connMutex.RUnlock() return reset(a, b, c) }, change: func(a guuid.UUID, b []byte) (sessionT, error) { connMutex.RLock() defer connMutex.RUnlock() return change(a, b) }, byUUID: func(a guuid.UUID) (sessionT, error) { connMutex.RLock() defer connMutex.RUnlock() return byUUID(a) }, logout: func(a guuid.UUID) error { connMutex.RLock() defer connMutex.RUnlock() return logout(a) }, outOthers: func(a guuid.UUID) error { connMutex.RLock() defer connMutex.RUnlock() return outOthers(a) }, outAll: func(a guuid.UUID) error { connMutex.RLock() defer connMutex.RUnlock() return outAll(a) }, close: func() error { connMutex.Lock() defer connMutex.Unlock() return close() }, }, nil } func newUserHandler(auth authT) func(q.Message) error { return func(message q.Message) error { return nil } } func sendConfirmationRequestHandler(auth authT) func(q.Message) error { return func(message q.Message) error { return nil } } func forgotPasswordRequestHandler(auth authT) func(q.Message) error { return func(message q.Message) error { return nil } } var consumers = []consumerT{ consumerT{ topic: NEW_USER, handlerFn: newUserHandler, }, consumerT{ topic: SEND_CONFIRMATION_REQUEST, handlerFn: sendConfirmationRequestHandler, }, consumerT{ topic: FORGOT_PASSWORD_REQUEST, handlerFn: forgotPasswordRequestHandler, }, } func registerConsumers(auth authT, consumers []consumerT, prefix string) error { for _, consumer := range consumers { err := auth.queue.Subscribe( consumer.topic, prefix + "-" + consumer.topic, consumer.handlerFn(auth), ) if err != nil { return err } } return nil } func unregisterConsumers(queue q.IQueue, consumers []consumerT, prefix string) { for _, consumer := range consumers { queue.Unsubscribe(consumer.topic, prefix + "-" + consumer.topic) } } type resultT[T any] struct{ value T err error } type taggedT[T any] struct{ id int value T } func startRunner[A any, B any]( in <-chan taggedT[A], out chan<- taggedT[B], fn func(A) B, done func(), ) { for input := range in { out <- taggedT[B]{ id: input.id, value: fn(input.value), } } done() } func makePoolRunner[A any, B any](count int, fn func(A) B) (func(A) B, func()) { var wg sync.WaitGroup wg.Add(count) in := make(chan taggedT[A]) out := make(chan taggedT[B]) for _ = range count { go startRunner(in, out, fn, wg.Done) } var mutex sync.Mutex m := map[int]chan B{} id := 0 go func() { for output := range out { mutex.Lock() defer mutex.Unlock() m[output.id] <- output.value close(m[output.id]) delete(m, output.id) } }() poolRunFn := func(input A) B { c := make(chan B) { mutex.Lock() defer mutex.Unlock() m[id] = c id++ } in <- taggedT[A]{ id: id, value: input, } return <- c } close := func() { close(in) wg.Wait() close(out) } return poolRunFn, close } func asResult[A any, B any](fn func(A) (B, error)) func(A) resultT[B] { return func(input A) resultT[B] { output, err := fn(input) return resultT[B]{ value: output, err: err, } } } func unwrapResult[A any, B any](fn func(A) resultT[B]) func(A) (B, error) { return func(input A) (B, error) { result := fn(input) return result.value, result.err } } func NewWithPrefix(db *sql.DB, queue q.IQueue, prefix string) (IAuth, error) { queries, err := initDB(db, prefix) if err != nil { return authT{}, err } numCPU := runtime.NumCPU() hasher, closeHasher := makePoolRunner(numCPU, asResult(scrypt.Hash)) closeFn := func() { unregisterConsumers(queue, consumers, prefix) closeHasher() } auth := authT{ queries: queries, queue: queue, hasher: unwrapResult(hasher), close: closeFn, } err = registerConsumers(auth, consumers, prefix) if err != nil { closeFn() return authT{}, err } return auth, nil } func New(db *sql.DB, queue q.IQueue) (IAuth, error) { return NewWithPrefix(db, queue, defaultPrefix) } func newUserPayload(email string, salt []byte, pwhash []byte) ([]byte, error) { data := make(map[string]interface{}) data["email"] = email data["salt"] = hex.EncodeToString(salt) data["pwhash"] = hex.EncodeToString(pwhash) return json.Marshal(data) } func asPublicUser(user userT) (User, error) { return User{}, nil } func (auth authT) Register( email string, password string, confirmPassword string, ) (User, error) { if password != confirmPassword { return User{}, ErrPassMismatch } if len(password) < scrypt.MinimumPasswordLength { return User{}, ErrTooShort } // special check for sql.ErrNoRows to combat enumeration attacks. // FIXME: how so? _, lookupErr := auth.queries.byEmail(email) if lookupErr != nil && lookupErr != sql.ErrNoRows { return User{}, lookupErr } salt, err := scrypt.Salt() if err != nil { return User{}, err } input := scrypt.HashInput{ Password: []byte(password), Salt: salt, } hash, err := auth.hasher(input) if err != nil { return User{}, err } /* We also try to register anyway, to prevent disk IO timing attacks. / FIXME: how so? */ flowID := guuid.New() waiter := auth.queue.WaitFor(NEW_USER, flowID, "register") defer waiter.Close() payload, err := newUserPayload(email, salt, hash) if err != nil { return User{}, err } unsent := q.UnsentMessage{ Topic: NEW_USER, FlowID: flowID, Payload: payload, } _, err = auth.queue.Publish(unsent) if err != nil { return User{}, err } user := userT{} select { case <-time.After(RegisterTimeout): err = ErrTimeout case <-waiter.Channel: user, err = auth.queries.byEmail(email) } return asPublicUser(user) } func sendConfirmationMessage( email string, flowID guuid.UUID, ) (q.UnsentMessage, error) { data := make(map[string]interface{}) data["email"] = email payload, err := json.Marshal(data) if err != nil { return q.UnsentMessage{}, err } return q.UnsentMessage{ Topic: SEND_CONFIRMATION_REQUEST, FlowID: flowID, Payload: payload, }, nil } func (auth authT) ResendConfirmation(email string) error { unsent, err := sendConfirmationMessage(email, guuid.New()) if err != nil { return err } _, err = auth.queue.Publish(unsent) return err } /* func asPublicSession(session sessionT) Session { _, err := validateSession(session) if err != nil { panic(fmt.Errorf( "session must have been validated at this stage: %w", err, )) } return Session{} } */ func asPublicSession(session sessionT) (Session, error) { return Session{}, nil } func (auth authT) ConfirmEmail(token string) (Session, error) { session, err := auth.queries.confirm("token FIXME", guuid.New()) if err != nil { return Session{}, err } return asPublicSession(session) } func (auth authT) LoginEmail( email string, password string, ) (Session, error) { if len(password) < scrypt.MinimumPasswordLength { return Session{}, ErrTooShort } // special check for sql.ErrNoRows to combat enumeration attacks. // FIXME: how so? user, err := auth.queries.byEmail(email) if err != nil && err != sql.ErrNoRows { return Session{}, err } input := scrypt.HashInput{ Password: []byte(password), Salt: user.salt, } hash, err := auth.hasher(input) if err != nil { return Session{}, err } ok := slices.Equal(hash, user.pwhash) if !ok { return Session{}, ErrBadCombo } if !user.confirmed { return Session{}, ErrUnconfirmed } session, err := auth.queries.login(user.uuid, guuid.New()) if err != nil { return Session{}, err } return asPublicSession(session) } func forgotPasswordMessage( email string, flowID guuid.UUID, ) (q.UnsentMessage, error) { data := make(map[string]interface{}) data["email"] = email payload, err := json.Marshal(data) if err != nil { return q.UnsentMessage{}, err } return q.UnsentMessage{ Topic: FORGOT_PASSWORD_REQUEST, FlowID: flowID, Payload: payload, }, nil } func (auth authT) ForgotPassword(email string) error { // special check for sql.ErrNoRows to combat enumeration attacks. user, err := auth.queries.byEmail(email) if err != nil && err != sql.ErrNoRows { return err } unsent, err := forgotPasswordMessage(user.email, guuid.New()) if err != nil { return err } _, err = auth.queue.Publish(unsent) return err } /* func checkSession(session sessionT, now time.Time) error { if session.revoked_at != nil { return ErrRevokedSession } if session.timestamp.Add(SessionDuration).After(now) { return ErrSessionExpired } return nil } func validateSession( lookupFn func(guuid.UUID) (sessionT, error), session sessionT, ) error { dbSession, err := lookupFn(session.uuid) if err != nil { return err } return checkSession(dbSession, time.Now()) } */ func (auth authT) Refresh(session Session) (Session, error) { /* err := validateSession(auth.queries.byUUID, session) if err != nil { return sessionT{}, err } */ // return auth.queries.refresh(session.uuid, guuid.New()) newSession, err := auth.queries.refresh(session.UUID, guuid.New()) if err != nil { return Session{}, err } return asPublicSession(newSession) } func (auth authT) ResetPassword( token guuid.UUID, password string, confirmPassword string, ) (Session, error) { if password != confirmPassword { return Session{}, ErrPassMismatch } if len(password) < scrypt.MinimumPasswordLength { return Session{}, ErrTooShort } user := userT{} input := scrypt.HashInput{ Password: []byte(password), Salt: user.salt, } pwhash, err := auth.hasher(input) if err != nil { return Session{}, err } var nextFn func() (sessionT, error) if user.confirmed { nextFn = func() (sessionT, error) { return auth.queries.reset(user.id, pwhash, token) } } else { nextFn = func() (sessionT, error) { // return auth.queries.confirm(user.id, pwhash, token) return auth.queries.confirm("token FIXME", guuid.New()) } } session, err := nextFn() if err != nil { return Session{}, err } return asPublicSession(session) } func (auth authT) ChangePassword( user User, currentPassword string, newPassword string, confirmNewPassword string, ) (Session, error) { if newPassword != confirmNewPassword { return Session{}, ErrPassMismatch } if len(newPassword) < scrypt.MinimumPasswordLength { return Session{}, ErrTooShort } input := scrypt.HashInput{ Password: []byte(newPassword), Salt: user.Salt, } pwhash, err := auth.hasher(input) if err != nil { return Session{}, err } if !user.Confirmed { return Session{}, ErrUnconfirmed } session, err := auth.queries.change(user.UUID, pwhash) if err != nil { return Session{}, nil } return asPublicSession(session) } func runLogout( lookupFn func(guuid.UUID) (sessionT, error), sessionID guuid.UUID, queryFn func(guuid.UUID) error, ) error { /* err := validateSession(lookupFn, session) if err != nil { return err } */ return queryFn(sessionID) } func (auth authT) Logout(session Session) error { return runLogout( auth.queries.byUUID, session.UUID, auth.queries.logout, ) } func (auth authT) LogoutOthers(session Session) error { return runLogout( auth.queries.byUUID, session.UUID, auth.queries.outOthers, ) } func (auth authT) LogoutAll(session Session) error { return runLogout( auth.queries.byUUID, session.UUID, auth.queries.outAll, ) } func (auth authT) Close() error { return auth.queries.close() } func Main() { // FIXME }