diff options
Diffstat (limited to 'src/cracha.go')
-rw-r--r-- | src/cracha.go | 1587 |
1 files changed, 1587 insertions, 0 deletions
diff --git a/src/cracha.go b/src/cracha.go new file mode 100644 index 0000000..5547d6e --- /dev/null +++ b/src/cracha.go @@ -0,0 +1,1587 @@ +package cracha + +import ( + "context" + "database/sql" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "runtime" + "slices" + "sync" + "time" + + "golite" + "guuid" + "q" + "scrypt" + g "gobang" +) + + + +const ( + defaultPrefix = "cracha" + + rollbackErrorFmt = "rollback error: %w; while executing: %w" + NEW_USER = "new-user" + SEND_CONFIRMATION_REQUEST = "send-confirmation-request" + FORGOT_PASSWORD_REQUEST = "forgot-password-request" + + day = 24 * time.Hour +) + +var ( + SessionDuration = 7 * day + RegisterTimeout = 15 * time.Second + + ErrPassMismatch = errors.New("cracha: password confirmation mismatch") + ErrTooShort = errors.New("cracha: password too short") + ErrTimeout = errors.New("cracha: timeout when creating user") + ErrRegistered = errors.New("cracha: user already registered") + ErrBadCombo = errors.New("cracha: bad username/passphrase combo") + ErrUnconfirmed = errors.New("cracha: email is not confirmed") + ErrRevokedSession = errors.New("cracha: this session was revoked") + ErrSessionExpired = errors.New("cracha: session expired") +) + + + +type queryT struct{ + write string + read string + session string +} + +type queriesT struct{ + register func(guuid.UUID, string, []byte, []byte) (userT, error) + sendToken func(guuid.UUID, string) error + confirm func(string, guuid.UUID) (sessionT, error) + byEmail func(string) (userT, error) + login func(guuid.UUID, guuid.UUID) (sessionT, error) + refresh func(guuid.UUID, guuid.UUID) (sessionT, error) + reset func(int64, []byte, guuid.UUID) (sessionT, error) + change func(guuid.UUID, []byte) (sessionT, error) + byUUID func(guuid.UUID) (sessionT, error) + logout func(guuid.UUID) error + outOthers func(guuid.UUID) error + outAll func(guuid.UUID) error + close func() error +} + +type userT struct{ + id int64 + timestamp time.Time + uuid guuid.UUID + email string + salt []byte + pwhash []byte + confirmed bool +} + +type User struct{ + UUID guuid.UUID + Salt []byte + Confirmed bool +} + +type sessionT struct{ + id int64 + timestamp time.Time + uuid guuid.UUID + userID guuid.UUID + // type_ string + revoked_at *time.Time +} + +type Session struct{ + UUID guuid.UUID +} + +type consumerT struct{ + topic string + handlerFn func(authT) func(q.Message) error +} + +type authT struct{ + queue q.IQueue + db *sql.DB + queries queriesT + hasher func(scrypt.HashInput) ([]byte, error) + close func() +} + +type IAuth interface{ + Register(string, string, string) (User, error) + ResendConfirmation(string) error + ConfirmEmail(string) (Session, error) + LoginEmail(string, string) (Session, error) + ForgotPassword(string) error + Refresh(Session) (Session, error) + ResetPassword(guuid.UUID, string, string) (Session, error) + ChangePassword(User, string, string, string) (Session, error) + Logout(Session) error + LogoutOthers(Session) error + LogoutAll(Session) error + Close() error +} + + + +func validateSession(session sessionT) (sessionT, error) { + // FIXME: implement + return session, nil +} + +func tryRollback(db *sql.DB, ctx context.Context, err error) error { + _, rollbackErr := db.ExecContext(ctx, "ROLLBACK;") + if rollbackErr != nil { + return fmt.Errorf( + rollbackErrorFmt, + rollbackErr, + err, + ) + } + + return err +} + +func inTx(db *sql.DB, fn func(context.Context) error) error { + ctx := context.Background() + + _, err := db.ExecContext(ctx, "BEGIN IMMEDIATE;") + if err != nil { + return err + } + + err = fn(ctx) + if err != nil { + return tryRollback(db, ctx, err) + } + + _, err = db.ExecContext(ctx, "COMMIT;") + if err != nil { + return tryRollback(db, ctx, err) + } + + return nil +} + +func createTablesSQL(prefix string) queryT { + const tmpl_write = ` + CREATE TABLE IF NOT EXISTS "%s_users" ( + id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, + timestamp TEXT NOT NULL DEFAULT (%s), + uuid BLOB NOT NULL UNIQUE, + email TEXT NOT NULL UNIQUE, + salt BLOB NOT NULL UNIQUE, + pwhash BLOB NOT NULL + ); + CREATE TABLE IF NOT EXISTS "%s_confirmation_attempts" ( + id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, + timestamp TEXT NOT NULL DEFAULT (%s), + -- uuid BLOB NOT NULL UNIQUE, + user_id INTEGER NOT NULL REFERENCES "%s_users"(id), + token TEXT NOT NULL UNIQUE + ); + CREATE TABLE IF NOT EXISTS "%s_user_confirmations" ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + timestamp TEXT NOT NULL DEFAULT (%s), + user_id INTEGER NOT NULL + REFERENCES "%s_users"(id) UNIQUE, + attempt_id INTEGER NOT NULL + REFERENCES "%s_confirmation_attempts"(id) UNIQUE + ); + CREATE TABLE IF NOT EXISTS "%s_user_changes" ( + id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, + timestamp TEXT NOT NULL DEFAULT (%s), + user_id INTEGER NOT NULL REFERENCES "%s_users"(id), + attribute TEXT NOT NULL, + value TEXT NOT NULL, + op BOOLEAN NOT NULL + ); + -- CREATE TABLE IF NOT EXISTS "%s_tokens" ( + -- id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, + -- timestamp TEXT NOT NULL DEFAULT (%s), + -- uuid BLOB NOT NULL UNIQUE, + -- type TEXT NOT NULL + -- ); + CREATE TABLE IF NOT EXISTS "%s_roles" ( + id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, + user_id INTEGER NOT NULL REFERENCES "%s_users"(id), + role TEXT NOT NULL, + UNIQUE (user_id, role) + ); + CREATE TABLE IF NOT EXISTS "%s_role_changes" ( + id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, + timestamp TEXT NOT NULL DEFAULT (%s), + user_id INTEGER NOT NULL REFERENCES "%s_roles"(id), + role TEXT NOT NULL, + op BOOLEAN NOT NULL + ); + CREATE TABLE IF NOT EXISTS "%s_sessions" ( + id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, + timestamp TEXT NOT NULL DEFAULT (%s), + uuid BLOB NOT NULL UNIQUE, + user_id INTEGER NOT NULL REFERENCES "%s_users"(id) + -- type TEXT NOT NULL, + -- revoked_at TEXT, + -- revoker_id INTEGER REFERENCES "%s_users"(id), + -- FIXME: add provenance: login, refresh, confirmation, etc. + ); + CREATE TABLE IF NOT EXISTS "%s_attempts" ( + id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, + timestamp TEXT NOT NULL DEFAULT (%s), + user_id INTEGER REFERENCES "%s_users"(id), + session_id INTEGER REFERENCES "%s_sessions"(id) + ); + CREATE TABLE IF NOT EXISTS "%s_audit" ( + id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, + timestamp TEXT NOT NULL DEFAULT (%s), + uuid BLOB NOT NULL UNIQUE, + attribute TEXT NOT NULL, + value TEXT NOT NULL, + op BOOLEAN NOT NULL + ); + ` + return queryT{ + write: fmt.Sprintf( + tmpl_write, + prefix, + g.SQLiteNow, + prefix, + g.SQLiteNow, + prefix, + prefix, + g.SQLiteNow, + prefix, + prefix, + prefix, + g.SQLiteNow, + prefix, + prefix, + g.SQLiteNow, + prefix, + prefix, + prefix, + g.SQLiteNow, + prefix, + prefix, + g.SQLiteNow, + prefix, + prefix, + prefix, + g.SQLiteNow, + prefix, + prefix, + prefix, + g.SQLiteNow, + ), + } +} + +func createTables(db *sql.DB, prefix string) error { + q := createTablesSQL(prefix) + + return inTx(db, func(ctx context.Context) error { + _, err := db.ExecContext(ctx, q.write) + return err + }) +} + +func registerSQL(prefix string) queryT { + const tmpl_write = ` + INSERT INTO "%s_users" (uuid, email, salt, pwhash) + VALUES (?, ?, ?, ?) RETURNING id, timestamp; + ` + const tmpl_read = ` + SELECT id, timestamp from "%s_users" + WHERE uuid = ?; + ` + return queryT{ + write: fmt.Sprintf(tmpl_write, prefix), + read: fmt.Sprintf(tmpl_read, prefix), + } +} + +func registerStmt( + db *sql.DB, + prefix string, +) ( + func(guuid.UUID, string, []byte, []byte) (userT, error), + func() error, + error, +) { + q := registerSQL(prefix) + + writeStmt, err := db.Prepare(q.write) + if err != nil { + return nil, nil, err + } + + readStmt, err := db.Prepare(q.read) + if err != nil { + return nil, nil, err + } + + fn := func( + userID guuid.UUID, + email string, + salt []byte, + pwhash []byte, + ) (userT, error) { + user := userT{ + uuid: userID, + email: email, + salt: salt, + pwhash: pwhash, + confirmed: false, + } + + var timestr string + user_id_bytes := userID[:] + err := writeStmt.QueryRow( + user_id_bytes, + email, + salt, + pwhash, + ).Scan(&user.id, ×tr) + if err != nil { + return userT{}, err + } + + user.timestamp, err = time.Parse(time.RFC3339Nano, timestr) + if err != nil { + return userT{}, err + } + + return user, nil + } + + closeFn := func() error { + return g.SomeFnError(writeStmt.Close, readStmt.Close) + } + + return fn, closeFn, nil +} + +func sendTokenSQL(prefix string) queryT { + const tmpl_write = ` + INSERT INTO "%s_confirmation_attempts" (user_id, token) + VALUES ( + (SELECT id FROM "%s_users" WHERE uuid = ?), + ? + ) + ` + return queryT{ + write: fmt.Sprintf(tmpl_write, prefix, prefix), + } +} + +func sendTokenStmt( + db *sql.DB, + prefix string, +) (func(guuid.UUID, string) error, func() error, error) { + q := sendTokenSQL(prefix) + + writeStmt, err := db.Prepare(q.write) + if err != nil { + return nil, nil, err + } + + fn := func(userID guuid.UUID, token string) error { + user_id_bytes := userID[:] + _, err := writeStmt.Exec(user_id_bytes, token) + return err + } + + return fn, writeStmt.Close, nil +} + +func confirmSQL(prefix string) queryT { + const tmpl_write = ` + INSERT INTO "%s_user_confirmations" (user_id, attempt_id) + VALUES (?, ?); + ` + const tmpl_read = ` + SELECT + "%s_confirmation_attempts".id, + "%s_confirmation_attempts".user_id, + "%s_users".uuid + FROM "%s_confirmation_attempts" + JOIN "%s_users" ON + "%s_confirmation_attempts".user_id = "%s_users".id + WHERE token = ?; + ` + const tmpl_session = ` + INSERT INTO "%s_sessions" (uuid, user_id) + VALUES (?, ?) RETURNING id, timestamp; + ` + return queryT{ + write: fmt.Sprintf(tmpl_write, prefix), + read: fmt.Sprintf( + tmpl_read, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + ), + session: fmt.Sprintf(tmpl_session, prefix), + } +} + +func confirmStmt( + db *sql.DB, + prefix string, +) (func(string, guuid.UUID) (sessionT, error), func() error, error) { + q := confirmSQL(prefix) + + writeStmt, err := db.Prepare(q.write) + if err != nil { + return nil, nil, err + } + + readStmt, err := db.Prepare(q.read) + if err != nil { + return nil, nil, g.WrapErrors(writeStmt.Close(), err) + } + + sessionStmt, err := db.Prepare(q.session) + if err != nil { + return nil, nil, g.WrapErrors( + writeStmt.Close(), + readStmt.Close(), + err, + ) + } + + fn := func(token string, sessionID guuid.UUID) (sessionT, error) { + session := sessionT{ + uuid: sessionID, + } + + var ( + user_id int64 + attempt_id int64 + user_id_bytes []byte + ) + err := readStmt.QueryRow(token).Scan( + &attempt_id, + &user_id, + &user_id_bytes, + ) + if err != nil { + return sessionT{}, err + } + session.userID = guuid.UUID(user_id_bytes) + + _, err = writeStmt.Exec(user_id, attempt_id) + if err != nil { + return sessionT{}, err + } + + session_id_bytes := sessionID[:] + var timestr string + err = sessionStmt.QueryRow( + session_id_bytes, + user_id, + ).Scan(&session.id, ×tr) + if err != nil { + return sessionT{}, err + } + + session.timestamp, err = time.Parse(time.RFC3339Nano, timestr) + if err != nil { + return sessionT{}, err + } + + return validateSession(session) + } + + closeFn := func() error { + return g.SomeFnError( + writeStmt.Close, + readStmt.Close, + sessionStmt.Close, + ) + } + + return fn, closeFn, nil +} + +func byEmailSQL(prefix string) queryT { + // FIXME: rewrite as LEFT JOIN? + const tmpl_read = ` + SELECT id, timestamp, uuid, salt, pwhash, ( + CASE WHEN EXISTS ( + SELECT id FROM "%s_user_confirmations" + WHERE user_id = ( + SELECT id FROM "%s_users" + WHERE email = ? + ) + ) THEN 1 + ELSE 0 + END + ) as confirmed + FROM "%s_users" WHERE email = ?; + ` + return queryT{ + read: fmt.Sprintf(tmpl_read, prefix, prefix, prefix), + } +} + +func byEmailStmt( + db *sql.DB, + prefix string, +) (func(string) (userT, error), func() error, error) { + q := byEmailSQL(prefix) + + readStmt, err := db.Prepare(q.read) + if err != nil { + return nil, nil, err + } + + fn := func(email string) (userT, error) { + user := userT{ + email: email, + } + + var ( + timestr string + user_id_bytes []byte + ) + err := readStmt.QueryRow(email, email).Scan( + &user.id, + ×tr, + &user_id_bytes, + &user.salt, + &user.pwhash, + &user.confirmed, + ) + if err != nil { + return userT{}, err + } + user.uuid = guuid.UUID(user_id_bytes) + + user.timestamp, err = time.Parse(time.RFC3339Nano,timestr) + if err != nil { + return userT{}, err + } + + return user, nil + } + + return fn, readStmt.Close, nil +} + +func loginSQL(prefix string) queryT { + const tmpl_session = ` + INSERT INTO "%s_sessions" (uuid, user_id) + VALUES ( + ?, + (SELECT id FROM "%s_users" WHERE uuid = ?) + ) RETURNING id, timestamp; + ` + return queryT{ + session: fmt.Sprintf(tmpl_session, prefix, prefix), + } +} + +func loginStmt( + db *sql.DB, + prefix string, +) (func(guuid.UUID, guuid.UUID) (sessionT, error), func() error, error) { + q := loginSQL(prefix) + + sessionStmt, err := db.Prepare(q.session) + if err != nil { + return nil, nil, err + } + + fn := func(userID guuid.UUID, sessionID guuid.UUID) (sessionT, error) { + session := sessionT{ + uuid: sessionID, + userID: userID, + } + + user_id_bytes := userID[:] + session_id_bytes := sessionID[:] + var timestr string + err := sessionStmt.QueryRow( + session_id_bytes, + user_id_bytes, + ).Scan( + &session.id, + ×tr, + ) + if err != nil { + return sessionT{}, err + } + + session.timestamp, err = time.Parse(time.RFC3339Nano, timestr) + if err != nil { + return sessionT{}, err + } + + return validateSession(session) + } + + return fn, sessionStmt.Close, nil +} + +func refreshSQL(prefix string) queryT { + const tmpl_write = ` + INSERT INTO "%s_sessions" (uuid, user_id) + VALUES ( + ?, + (SELECT user_id FROM "%s_sessions" WHERE uuid = ?) + ) RETURNING id, timestamp, ( + SELECT "%s_users".uuid FROM "%s_users" + JOIN "%s_sessions" ON + "%s_users".id = "%s_sessions".user_id + WHERE "%s_sessions".uuid = ? + ) AS userID; + ` + return queryT{ + write: fmt.Sprintf( + tmpl_write, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + ), + } +} + +func refreshStmt( + db *sql.DB, + prefix string, +) (func(guuid.UUID, guuid.UUID) (sessionT, error), func() error, error) { + q := refreshSQL(prefix) + + writeStmt, err := db.Prepare(q.write) + if err != nil { + return nil, nil, err + } + + fn := func( + sessionID guuid.UUID, + newSessionID guuid.UUID, + ) (sessionT, error) { + session := sessionT{ + uuid: newSessionID, + } + + session_id_bytes := sessionID[:] + new_session_id_bytes := newSessionID[:] + var ( + timestr string + user_id_bytes []byte + ) + err := writeStmt.QueryRow( + new_session_id_bytes, + session_id_bytes, + session_id_bytes, + ).Scan(&session.id, ×tr, &user_id_bytes) + if err != nil { + return sessionT{}, err + } + session.userID = guuid.UUID(user_id_bytes) + + session.timestamp, err = time.Parse(time.RFC3339Nano, timestr) + if err != nil { + return sessionT{}, err + } + + return validateSession(session) + } + + return fn, writeStmt.Close, nil +} + +func resetSQL(prefix string) queryT { + const tmpl_write = ` + -- INSERT SOMETHING %s + ` + return queryT{ + write: fmt.Sprintf(tmpl_write, prefix), + } +} + +func resetStmt( + db *sql.DB, + prefix string, +) (func(int64, []byte, guuid.UUID) (sessionT, error), func() error, error) { + q := resetSQL(prefix) + + writeStmt, err := db.Prepare(q.write) + if err != nil { + return nil, nil, err + } + + fn := func(id int64, pwhash []byte, token guuid.UUID) (sessionT, error) { + var session sessionT + err := writeStmt.QueryRow(id, pwhash, token).Scan(&session) + if err != nil { + return sessionT{}, err + } + return validateSession(session) + } + + return fn, writeStmt.Close, nil +} + +func changeSQL(prefix string) queryT { + const tmpl_write = ` + -- INSERT SOMETHING %s + ` + return queryT{ + write: fmt.Sprintf(tmpl_write, prefix), + } +} + +func changeStmt( + db *sql.DB, + prefix string, +) (func(guuid.UUID, []byte) (sessionT, error), func() error, error) { + q := changeSQL(prefix) + + writeStmt, err := db.Prepare(q.write) + if err != nil { + return nil, nil, err + } + + fn := func(uuid guuid.UUID, pwhash []byte) (sessionT, error) { + var session sessionT + err := writeStmt.QueryRow(uuid, pwhash).Scan(&session) + if err != nil { + return sessionT{}, err + } + return validateSession(session) + } + + return fn, writeStmt.Close, nil +} + +func byUUIDSQL(prefix string) queryT { + const tmpl_read = ` + -- INSERT SOMETHING %s + ` + return queryT{ + read: fmt.Sprintf(tmpl_read, prefix), + } +} + +func byUUIDStmt( + db *sql.DB, + prefix string, +) (func(guuid.UUID) (sessionT, error), func() error, error) { + q := byUUIDSQL(prefix) + + readStmt, err := db.Prepare(q.read) + if err != nil { + return nil, nil, err + } + + fn := func(sessionID guuid.UUID) (sessionT, error) { + var session sessionT + err := readStmt.QueryRow(sessionID).Scan(&session) + if err != nil { + return sessionT{}, err + } + return validateSession(session) + } + + return fn, readStmt.Close, nil +} + +func logoutSQL(prefix string) queryT { + const tmpl_write = ` + -- INSERT SOMETHING %s + ` + return queryT{ + write: fmt.Sprintf(tmpl_write, prefix), + } +} + +func logoutStmt( + db *sql.DB, + prefix string, +) (func(guuid.UUID) error, func() error, error) { + q := logoutSQL(prefix) + + writeStmt, err := db.Prepare(q.write) + if err != nil { + return nil, nil, err + } + + fn := func(sessionID guuid.UUID) error { + _, err := writeStmt.Exec(sessionID) + return err + } + + return fn, writeStmt.Close, nil +} + +func outOthersSQL(prefix string) queryT { + const tmpl_write = ` + -- INSERT SOMETHING %s + ` + return queryT{ + write: fmt.Sprintf(tmpl_write, prefix), + } +} + +func outOthersStmt( + db *sql.DB, + prefix string, +) (func(guuid.UUID) error, func() error, error) { + q := outOthersSQL(prefix) + + writeStmt, err := db.Prepare(q.write) + if err != nil { + return nil, nil, err + } + + fn := func(sessionID guuid.UUID) error { + _, err := writeStmt.Exec(sessionID) + return err + } + + return fn, writeStmt.Close, nil +} + +func outAllSQL(prefix string) queryT { + const tmpl_write = ` + -- INSERT SOMETHING %s + ` + return queryT{ + write: fmt.Sprintf(tmpl_write, prefix), + } +} + +func outAllStmt( + db *sql.DB, + prefix string, +) (func(guuid.UUID) error, func() error, error) { + q := outAllSQL(prefix) + + writeStmt, err := db.Prepare(q.write) + if err != nil { + return nil, nil, err + } + + fn := func(sessionID guuid.UUID) error { + _, err := writeStmt.Exec(sessionID) + return err + } + + return fn, writeStmt.Close, nil +} + +func initDB( + db *sql.DB, + prefix string, +) (queriesT, error) { + createTablesErr := createTables(db, prefix) + register, registerClose, registerErr := registerStmt(db, prefix) + sendToken, sendTokenClose, sendTokenErr := sendTokenStmt(db, prefix) + confirm, confirmClose, confirmErr := confirmStmt(db, prefix) + byEmail, byEmailClose, byEmailErr := byEmailStmt(db, prefix) + login, loginClose, loginErr := loginStmt(db, prefix) + refresh, refreshClose, refreshErr := refreshStmt(db, prefix) + reset, resetClose, resetErr := resetStmt(db, prefix) + change, changeClose, changeErr := changeStmt(db, prefix) + byUUID, byUUIDClose, byUUIDErr := byUUIDStmt(db, prefix) + logout, logoutClose, logoutErr := logoutStmt(db, prefix) + outOthers, outOthersClose, outOthersErr := outOthersStmt(db, prefix) + outAll, outAllClose, outAllErr := outAllStmt(db, prefix) + + err := g.SomeError( + createTablesErr, + registerErr, + sendTokenErr, + confirmErr, + byEmailErr, + loginErr, + refreshErr, + resetErr, + changeErr, + byUUIDErr, + logoutErr, + outOthersErr, + outAllErr, + ) + if err != nil { + return queriesT{}, err + } + + close := func() error { + return g.SomeFnError( + registerClose, + sendTokenClose, + confirmClose, + byEmailClose, + loginClose, + refreshClose, + resetClose, + changeClose, + byUUIDClose, + logoutClose, + outOthersClose, + outAllClose, + ) + } + + var connMutex sync.RWMutex + return queriesT{ + register: func( + a guuid.UUID, + b string, + c []byte, + d []byte, + ) (userT, error) { + connMutex.RLock() + defer connMutex.RUnlock() + return register(a, b, c, d) + }, + sendToken: func(a guuid.UUID, b string) error { + connMutex.RLock() + defer connMutex.RUnlock() + return sendToken(a, b) + }, + confirm: func(a string, b guuid.UUID) (sessionT, error) { + connMutex.RLock() + defer connMutex.RUnlock() + return confirm(a, b) + }, + byEmail: func(a string) (userT, error) { + connMutex.RLock() + defer connMutex.RUnlock() + return byEmail(a) + }, + login: func(a guuid.UUID, b guuid.UUID) (sessionT, error) { + connMutex.RLock() + defer connMutex.RUnlock() + return login(a, b) + }, + refresh: func(a guuid.UUID, b guuid.UUID) (sessionT, error) { + connMutex.RLock() + defer connMutex.RUnlock() + return refresh(a, b) + }, + reset: func(a int64, b []byte, c guuid.UUID) (sessionT, error) { + connMutex.RLock() + defer connMutex.RUnlock() + return reset(a, b, c) + }, + change: func(a guuid.UUID, b []byte) (sessionT, error) { + connMutex.RLock() + defer connMutex.RUnlock() + return change(a, b) + }, + byUUID: func(a guuid.UUID) (sessionT, error) { + connMutex.RLock() + defer connMutex.RUnlock() + return byUUID(a) + }, + logout: func(a guuid.UUID) error { + connMutex.RLock() + defer connMutex.RUnlock() + return logout(a) + }, + outOthers: func(a guuid.UUID) error { + connMutex.RLock() + defer connMutex.RUnlock() + return outOthers(a) + }, + outAll: func(a guuid.UUID) error { + connMutex.RLock() + defer connMutex.RUnlock() + return outAll(a) + }, + close: func() error { + connMutex.Lock() + defer connMutex.Unlock() + return close() + }, + }, nil +} + +func newUserHandler(auth authT) func(q.Message) error { + return func(message q.Message) error { + return nil + } +} + +func sendConfirmationRequestHandler(auth authT) func(q.Message) error { + return func(message q.Message) error { + return nil + } +} + +func forgotPasswordRequestHandler(auth authT) func(q.Message) error { + return func(message q.Message) error { + return nil + } +} + +var consumers = []consumerT{ + consumerT{ + topic: NEW_USER, + handlerFn: newUserHandler, + }, + consumerT{ + topic: SEND_CONFIRMATION_REQUEST, + handlerFn: sendConfirmationRequestHandler, + }, + consumerT{ + topic: FORGOT_PASSWORD_REQUEST, + handlerFn: forgotPasswordRequestHandler, + }, +} +func registerConsumers(auth authT, consumers []consumerT, prefix string) error { + for _, consumer := range consumers { + err := auth.queue.Subscribe( + consumer.topic, + prefix + "-" + consumer.topic, + consumer.handlerFn(auth), + ) + if err != nil { + return err + } + } + + return nil +} + +func unregisterConsumers(queue q.IQueue, consumers []consumerT, prefix string) { + for _, consumer := range consumers { + queue.Unsubscribe(consumer.topic, prefix + "-" + consumer.topic) + } +} + +type resultT[T any] struct{ + value T + err error +} + +type taggedT[T any] struct{ + id int + value T +} + +func startRunner[A any, B any]( + in <-chan taggedT[A], + out chan<- taggedT[B], + fn func(A) B, + done func(), +) { + for input := range in { + out <- taggedT[B]{ + id: input.id, + value: fn(input.value), + } + } + done() +} + +func makePoolRunner[A any, B any](count int, fn func(A) B) (func(A) B, func()) { + var wg sync.WaitGroup + wg.Add(count) + + in := make(chan taggedT[A]) + out := make(chan taggedT[B]) + + for _ = range count { + go startRunner(in, out, fn, wg.Done) + } + + var mutex sync.Mutex + m := map[int]chan B{} + id := 0 + go func() { + for output := range out { + mutex.Lock() + defer mutex.Unlock() + m[output.id] <- output.value + close(m[output.id]) + delete(m, output.id) + } + }() + + poolRunFn := func(input A) B { + c := make(chan B) + { + mutex.Lock() + defer mutex.Unlock() + m[id] = c + id++ + } + + in <- taggedT[A]{ + id: id, + value: input, + } + return <- c + } + + close := func() { + close(in) + wg.Wait() + close(out) + } + + return poolRunFn, close +} + +func asResult[A any, B any](fn func(A) (B, error)) func(A) resultT[B] { + return func(input A) resultT[B] { + output, err := fn(input) + return resultT[B]{ + value: output, + err: err, + } + } +} + +func unwrapResult[A any, B any](fn func(A) resultT[B]) func(A) (B, error) { + return func(input A) (B, error) { + result := fn(input) + return result.value, result.err + } +} + +func NewWithPrefix(databasePath string, prefix string) (IAuth, error) { + queue, err := q.New(databasePath) + if err != nil { + return authT{}, err + } + + db, err := sql.Open(golite.DriverName, databasePath) + if err != nil { + return authT{}, err + } + + queries, err := initDB(db, prefix) + if err != nil { + return authT{}, err + } + + numCPU := runtime.NumCPU() + hasher, closeHasher := makePoolRunner(numCPU, asResult(scrypt.Hash)) + + closeFn := func() { + unregisterConsumers(queue, consumers, prefix) + closeHasher() + } + + auth := authT{ + queue: queue, + db: db, + queries: queries, + hasher: unwrapResult(hasher), + close: closeFn, + } + + err = registerConsumers(auth, consumers, prefix) + if err != nil { + closeFn() + return authT{}, err + } + + return auth, nil +} + +func New(databasePath string) (IAuth, error) { + return NewWithPrefix(databasePath, defaultPrefix) +} + +func newUserPayload(email string, salt []byte, pwhash []byte) ([]byte, error) { + data := make(map[string]interface{}) + data["email"] = email + data["salt"] = hex.EncodeToString(salt) + data["pwhash"] = hex.EncodeToString(pwhash) + return json.Marshal(data) +} + +func asPublicUser(user userT) (User, error) { + return User{}, nil +} + +func (auth authT) Register( + email string, + password string, + confirmPassword string, +) (User, error) { + if password != confirmPassword { + return User{}, ErrPassMismatch + } + + if len(password) < scrypt.MinimumPasswordLength { + return User{}, ErrTooShort + } + + // special check for sql.ErrNoRows to combat enumeration attacks. + // FIXME: how so? + _, lookupErr := auth.queries.byEmail(email) + if lookupErr != nil && lookupErr != sql.ErrNoRows { + return User{}, lookupErr + } + + salt, err := scrypt.Salt() + if err != nil { + return User{}, err + } + + input := scrypt.HashInput{ + Password: []byte(password), + Salt: salt, + } + hash, err := auth.hasher(input) + if err != nil { + return User{}, err + } + + /* + We also try to register anyway, to prevent disk IO timing attacks. + / FIXME: how so? + */ + + flowID := guuid.New() + waiter := auth.queue.WaitFor(NEW_USER, flowID, "register") + defer waiter.Close() + + payload, err := newUserPayload(email, salt, hash) + if err != nil { + return User{}, err + } + + unsent := q.UnsentMessage{ + Topic: NEW_USER, + FlowID: flowID, + Payload: payload, + } + _, err = auth.queue.Publish(unsent) + if err != nil { + return User{}, err + } + + user := userT{} + select { + case <-time.After(RegisterTimeout): + err = ErrTimeout + case <-waiter.Channel: + user, err = auth.queries.byEmail(email) + } + + return asPublicUser(user) +} + +func sendConfirmationMessage( + email string, + flowID guuid.UUID, +) (q.UnsentMessage, error) { + data := make(map[string]interface{}) + data["email"] = email + payload, err := json.Marshal(data) + if err != nil { + return q.UnsentMessage{}, err + } + + return q.UnsentMessage{ + Topic: SEND_CONFIRMATION_REQUEST, + FlowID: flowID, + Payload: payload, + }, nil +} + +func (auth authT) ResendConfirmation(email string) error { + unsent, err := sendConfirmationMessage(email, guuid.New()) + if err != nil { + return err + } + + _, err = auth.queue.Publish(unsent) + return err +} + +/* +func asPublicSession(session sessionT) Session { + _, err := validateSession(session) + if err != nil { + panic(fmt.Errorf( + "session must have been validated at this stage: %w", + err, + )) + } + return Session{} +} +*/ + +func asPublicSession(session sessionT) (Session, error) { + return Session{}, nil +} + +func (auth authT) ConfirmEmail(token string) (Session, error) { + session, err := auth.queries.confirm("token FIXME", guuid.New()) + if err != nil { + return Session{}, err + } + + return asPublicSession(session) +} + +func (auth authT) LoginEmail( + email string, + password string, +) (Session, error) { + if len(password) < scrypt.MinimumPasswordLength { + return Session{}, ErrTooShort + } + + // special check for sql.ErrNoRows to combat enumeration attacks. + // FIXME: how so? + user, err := auth.queries.byEmail(email) + if err != nil && err != sql.ErrNoRows { + return Session{}, err + } + + input := scrypt.HashInput{ + Password: []byte(password), + Salt: user.salt, + } + hash, err := auth.hasher(input) + if err != nil { + return Session{}, err + } + + ok := slices.Equal(hash, user.pwhash) + if !ok { + return Session{}, ErrBadCombo + } + + if !user.confirmed { + return Session{}, ErrUnconfirmed + } + + session, err := auth.queries.login(user.uuid, guuid.New()) + if err != nil { + return Session{}, err + } + + return asPublicSession(session) +} + +func forgotPasswordMessage( + email string, + flowID guuid.UUID, +) (q.UnsentMessage, error) { + data := make(map[string]interface{}) + data["email"] = email + payload, err := json.Marshal(data) + if err != nil { + return q.UnsentMessage{}, err + } + + return q.UnsentMessage{ + Topic: FORGOT_PASSWORD_REQUEST, + FlowID: flowID, + Payload: payload, + }, nil +} + +func (auth authT) ForgotPassword(email string) error { + // special check for sql.ErrNoRows to combat enumeration attacks. + user, err := auth.queries.byEmail(email) + if err != nil && err != sql.ErrNoRows { + return err + } + + unsent, err := forgotPasswordMessage(user.email, guuid.New()) + if err != nil { + return err + } + + _, err = auth.queue.Publish(unsent) + return err +} + +/* +func checkSession(session sessionT, now time.Time) error { + if session.revoked_at != nil { + return ErrRevokedSession + } + + if session.timestamp.Add(SessionDuration).After(now) { + return ErrSessionExpired + } + + return nil +} + +func validateSession( + lookupFn func(guuid.UUID) (sessionT, error), + session sessionT, +) error { + dbSession, err := lookupFn(session.uuid) + if err != nil { + return err + } + + return checkSession(dbSession, time.Now()) +} +*/ + +func (auth authT) Refresh(session Session) (Session, error) { + /* + err := validateSession(auth.queries.byUUID, session) + if err != nil { + return sessionT{}, err + } + */ + + // return auth.queries.refresh(session.uuid, guuid.New()) + newSession, err := auth.queries.refresh(session.UUID, guuid.New()) + if err != nil { + return Session{}, err + } + + return asPublicSession(newSession) +} + +func (auth authT) ResetPassword( + token guuid.UUID, + password string, + confirmPassword string, +) (Session, error) { + if password != confirmPassword { + return Session{}, ErrPassMismatch + } + + if len(password) < scrypt.MinimumPasswordLength { + return Session{}, ErrTooShort + } + + user := userT{} + + input := scrypt.HashInput{ + Password: []byte(password), + Salt: user.salt, + } + pwhash, err := auth.hasher(input) + if err != nil { + return Session{}, err + } + + var nextFn func() (sessionT, error) + if user.confirmed { + nextFn = func() (sessionT, error) { + return auth.queries.reset(user.id, pwhash, token) + } + } else { + nextFn = func() (sessionT, error) { + // return auth.queries.confirm(user.id, pwhash, token) + return auth.queries.confirm("token FIXME", guuid.New()) + } + } + + session, err := nextFn() + if err != nil { + return Session{}, err + } + + return asPublicSession(session) +} + +func (auth authT) ChangePassword( + user User, + currentPassword string, + newPassword string, + confirmNewPassword string, +) (Session, error) { + if newPassword != confirmNewPassword { + return Session{}, ErrPassMismatch + } + + if len(newPassword) < scrypt.MinimumPasswordLength { + return Session{}, ErrTooShort + } + + input := scrypt.HashInput{ + Password: []byte(newPassword), + Salt: user.Salt, + } + pwhash, err := auth.hasher(input) + if err != nil { + return Session{}, err + } + + if !user.Confirmed { + return Session{}, ErrUnconfirmed + } + + session, err := auth.queries.change(user.UUID, pwhash) + if err != nil { + return Session{}, nil + } + + return asPublicSession(session) +} + +func runLogout( + lookupFn func(guuid.UUID) (sessionT, error), + sessionID guuid.UUID, + queryFn func(guuid.UUID) error, +) error { + /* + err := validateSession(lookupFn, session) + if err != nil { + return err + } + */ + + return queryFn(sessionID) +} + +func (auth authT) Logout(session Session) error { + return runLogout( + auth.queries.byUUID, + session.UUID, + auth.queries.logout, + ) +} + +func (auth authT) LogoutOthers(session Session) error { + return runLogout( + auth.queries.byUUID, + session.UUID, + auth.queries.outOthers, + ) +} + +func (auth authT) LogoutAll(session Session) error { + return runLogout( + auth.queries.byUUID, + session.UUID, + auth.queries.outAll, + ) +} + +func (auth authT) Close() error { + return g.WrapErrors(auth.queries.close(), auth.db.Close()) +} + + + +func Main() { + // FIXME +} |