summaryrefslogtreecommitdiff
path: root/src/cracha.go
diff options
context:
space:
mode:
Diffstat (limited to 'src/cracha.go')
-rw-r--r--src/cracha.go1587
1 files changed, 1587 insertions, 0 deletions
diff --git a/src/cracha.go b/src/cracha.go
new file mode 100644
index 0000000..5547d6e
--- /dev/null
+++ b/src/cracha.go
@@ -0,0 +1,1587 @@
+package cracha
+
+import (
+ "context"
+ "database/sql"
+ "encoding/hex"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "runtime"
+ "slices"
+ "sync"
+ "time"
+
+ "golite"
+ "guuid"
+ "q"
+ "scrypt"
+ g "gobang"
+)
+
+
+
+const (
+ defaultPrefix = "cracha"
+
+ rollbackErrorFmt = "rollback error: %w; while executing: %w"
+ NEW_USER = "new-user"
+ SEND_CONFIRMATION_REQUEST = "send-confirmation-request"
+ FORGOT_PASSWORD_REQUEST = "forgot-password-request"
+
+ day = 24 * time.Hour
+)
+
+var (
+ SessionDuration = 7 * day
+ RegisterTimeout = 15 * time.Second
+
+ ErrPassMismatch = errors.New("cracha: password confirmation mismatch")
+ ErrTooShort = errors.New("cracha: password too short")
+ ErrTimeout = errors.New("cracha: timeout when creating user")
+ ErrRegistered = errors.New("cracha: user already registered")
+ ErrBadCombo = errors.New("cracha: bad username/passphrase combo")
+ ErrUnconfirmed = errors.New("cracha: email is not confirmed")
+ ErrRevokedSession = errors.New("cracha: this session was revoked")
+ ErrSessionExpired = errors.New("cracha: session expired")
+)
+
+
+
+type queryT struct{
+ write string
+ read string
+ session string
+}
+
+type queriesT struct{
+ register func(guuid.UUID, string, []byte, []byte) (userT, error)
+ sendToken func(guuid.UUID, string) error
+ confirm func(string, guuid.UUID) (sessionT, error)
+ byEmail func(string) (userT, error)
+ login func(guuid.UUID, guuid.UUID) (sessionT, error)
+ refresh func(guuid.UUID, guuid.UUID) (sessionT, error)
+ reset func(int64, []byte, guuid.UUID) (sessionT, error)
+ change func(guuid.UUID, []byte) (sessionT, error)
+ byUUID func(guuid.UUID) (sessionT, error)
+ logout func(guuid.UUID) error
+ outOthers func(guuid.UUID) error
+ outAll func(guuid.UUID) error
+ close func() error
+}
+
+type userT struct{
+ id int64
+ timestamp time.Time
+ uuid guuid.UUID
+ email string
+ salt []byte
+ pwhash []byte
+ confirmed bool
+}
+
+type User struct{
+ UUID guuid.UUID
+ Salt []byte
+ Confirmed bool
+}
+
+type sessionT struct{
+ id int64
+ timestamp time.Time
+ uuid guuid.UUID
+ userID guuid.UUID
+ // type_ string
+ revoked_at *time.Time
+}
+
+type Session struct{
+ UUID guuid.UUID
+}
+
+type consumerT struct{
+ topic string
+ handlerFn func(authT) func(q.Message) error
+}
+
+type authT struct{
+ queue q.IQueue
+ db *sql.DB
+ queries queriesT
+ hasher func(scrypt.HashInput) ([]byte, error)
+ close func()
+}
+
+type IAuth interface{
+ Register(string, string, string) (User, error)
+ ResendConfirmation(string) error
+ ConfirmEmail(string) (Session, error)
+ LoginEmail(string, string) (Session, error)
+ ForgotPassword(string) error
+ Refresh(Session) (Session, error)
+ ResetPassword(guuid.UUID, string, string) (Session, error)
+ ChangePassword(User, string, string, string) (Session, error)
+ Logout(Session) error
+ LogoutOthers(Session) error
+ LogoutAll(Session) error
+ Close() error
+}
+
+
+
+func validateSession(session sessionT) (sessionT, error) {
+ // FIXME: implement
+ return session, nil
+}
+
+func tryRollback(db *sql.DB, ctx context.Context, err error) error {
+ _, rollbackErr := db.ExecContext(ctx, "ROLLBACK;")
+ if rollbackErr != nil {
+ return fmt.Errorf(
+ rollbackErrorFmt,
+ rollbackErr,
+ err,
+ )
+ }
+
+ return err
+}
+
+func inTx(db *sql.DB, fn func(context.Context) error) error {
+ ctx := context.Background()
+
+ _, err := db.ExecContext(ctx, "BEGIN IMMEDIATE;")
+ if err != nil {
+ return err
+ }
+
+ err = fn(ctx)
+ if err != nil {
+ return tryRollback(db, ctx, err)
+ }
+
+ _, err = db.ExecContext(ctx, "COMMIT;")
+ if err != nil {
+ return tryRollback(db, ctx, err)
+ }
+
+ return nil
+}
+
+func createTablesSQL(prefix string) queryT {
+ const tmpl_write = `
+ CREATE TABLE IF NOT EXISTS "%s_users" (
+ id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
+ timestamp TEXT NOT NULL DEFAULT (%s),
+ uuid BLOB NOT NULL UNIQUE,
+ email TEXT NOT NULL UNIQUE,
+ salt BLOB NOT NULL UNIQUE,
+ pwhash BLOB NOT NULL
+ );
+ CREATE TABLE IF NOT EXISTS "%s_confirmation_attempts" (
+ id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
+ timestamp TEXT NOT NULL DEFAULT (%s),
+ -- uuid BLOB NOT NULL UNIQUE,
+ user_id INTEGER NOT NULL REFERENCES "%s_users"(id),
+ token TEXT NOT NULL UNIQUE
+ );
+ CREATE TABLE IF NOT EXISTS "%s_user_confirmations" (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ timestamp TEXT NOT NULL DEFAULT (%s),
+ user_id INTEGER NOT NULL
+ REFERENCES "%s_users"(id) UNIQUE,
+ attempt_id INTEGER NOT NULL
+ REFERENCES "%s_confirmation_attempts"(id) UNIQUE
+ );
+ CREATE TABLE IF NOT EXISTS "%s_user_changes" (
+ id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
+ timestamp TEXT NOT NULL DEFAULT (%s),
+ user_id INTEGER NOT NULL REFERENCES "%s_users"(id),
+ attribute TEXT NOT NULL,
+ value TEXT NOT NULL,
+ op BOOLEAN NOT NULL
+ );
+ -- CREATE TABLE IF NOT EXISTS "%s_tokens" (
+ -- id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
+ -- timestamp TEXT NOT NULL DEFAULT (%s),
+ -- uuid BLOB NOT NULL UNIQUE,
+ -- type TEXT NOT NULL
+ -- );
+ CREATE TABLE IF NOT EXISTS "%s_roles" (
+ id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
+ user_id INTEGER NOT NULL REFERENCES "%s_users"(id),
+ role TEXT NOT NULL,
+ UNIQUE (user_id, role)
+ );
+ CREATE TABLE IF NOT EXISTS "%s_role_changes" (
+ id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
+ timestamp TEXT NOT NULL DEFAULT (%s),
+ user_id INTEGER NOT NULL REFERENCES "%s_roles"(id),
+ role TEXT NOT NULL,
+ op BOOLEAN NOT NULL
+ );
+ CREATE TABLE IF NOT EXISTS "%s_sessions" (
+ id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
+ timestamp TEXT NOT NULL DEFAULT (%s),
+ uuid BLOB NOT NULL UNIQUE,
+ user_id INTEGER NOT NULL REFERENCES "%s_users"(id)
+ -- type TEXT NOT NULL,
+ -- revoked_at TEXT,
+ -- revoker_id INTEGER REFERENCES "%s_users"(id),
+ -- FIXME: add provenance: login, refresh, confirmation, etc.
+ );
+ CREATE TABLE IF NOT EXISTS "%s_attempts" (
+ id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
+ timestamp TEXT NOT NULL DEFAULT (%s),
+ user_id INTEGER REFERENCES "%s_users"(id),
+ session_id INTEGER REFERENCES "%s_sessions"(id)
+ );
+ CREATE TABLE IF NOT EXISTS "%s_audit" (
+ id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
+ timestamp TEXT NOT NULL DEFAULT (%s),
+ uuid BLOB NOT NULL UNIQUE,
+ attribute TEXT NOT NULL,
+ value TEXT NOT NULL,
+ op BOOLEAN NOT NULL
+ );
+ `
+ return queryT{
+ write: fmt.Sprintf(
+ tmpl_write,
+ prefix,
+ g.SQLiteNow,
+ prefix,
+ g.SQLiteNow,
+ prefix,
+ prefix,
+ g.SQLiteNow,
+ prefix,
+ prefix,
+ prefix,
+ g.SQLiteNow,
+ prefix,
+ prefix,
+ g.SQLiteNow,
+ prefix,
+ prefix,
+ prefix,
+ g.SQLiteNow,
+ prefix,
+ prefix,
+ g.SQLiteNow,
+ prefix,
+ prefix,
+ prefix,
+ g.SQLiteNow,
+ prefix,
+ prefix,
+ prefix,
+ g.SQLiteNow,
+ ),
+ }
+}
+
+func createTables(db *sql.DB, prefix string) error {
+ q := createTablesSQL(prefix)
+
+ return inTx(db, func(ctx context.Context) error {
+ _, err := db.ExecContext(ctx, q.write)
+ return err
+ })
+}
+
+func registerSQL(prefix string) queryT {
+ const tmpl_write = `
+ INSERT INTO "%s_users" (uuid, email, salt, pwhash)
+ VALUES (?, ?, ?, ?) RETURNING id, timestamp;
+ `
+ const tmpl_read = `
+ SELECT id, timestamp from "%s_users"
+ WHERE uuid = ?;
+ `
+ return queryT{
+ write: fmt.Sprintf(tmpl_write, prefix),
+ read: fmt.Sprintf(tmpl_read, prefix),
+ }
+}
+
+func registerStmt(
+ db *sql.DB,
+ prefix string,
+) (
+ func(guuid.UUID, string, []byte, []byte) (userT, error),
+ func() error,
+ error,
+) {
+ q := registerSQL(prefix)
+
+ writeStmt, err := db.Prepare(q.write)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ readStmt, err := db.Prepare(q.read)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ fn := func(
+ userID guuid.UUID,
+ email string,
+ salt []byte,
+ pwhash []byte,
+ ) (userT, error) {
+ user := userT{
+ uuid: userID,
+ email: email,
+ salt: salt,
+ pwhash: pwhash,
+ confirmed: false,
+ }
+
+ var timestr string
+ user_id_bytes := userID[:]
+ err := writeStmt.QueryRow(
+ user_id_bytes,
+ email,
+ salt,
+ pwhash,
+ ).Scan(&user.id, &timestr)
+ if err != nil {
+ return userT{}, err
+ }
+
+ user.timestamp, err = time.Parse(time.RFC3339Nano, timestr)
+ if err != nil {
+ return userT{}, err
+ }
+
+ return user, nil
+ }
+
+ closeFn := func() error {
+ return g.SomeFnError(writeStmt.Close, readStmt.Close)
+ }
+
+ return fn, closeFn, nil
+}
+
+func sendTokenSQL(prefix string) queryT {
+ const tmpl_write = `
+ INSERT INTO "%s_confirmation_attempts" (user_id, token)
+ VALUES (
+ (SELECT id FROM "%s_users" WHERE uuid = ?),
+ ?
+ )
+ `
+ return queryT{
+ write: fmt.Sprintf(tmpl_write, prefix, prefix),
+ }
+}
+
+func sendTokenStmt(
+ db *sql.DB,
+ prefix string,
+) (func(guuid.UUID, string) error, func() error, error) {
+ q := sendTokenSQL(prefix)
+
+ writeStmt, err := db.Prepare(q.write)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ fn := func(userID guuid.UUID, token string) error {
+ user_id_bytes := userID[:]
+ _, err := writeStmt.Exec(user_id_bytes, token)
+ return err
+ }
+
+ return fn, writeStmt.Close, nil
+}
+
+func confirmSQL(prefix string) queryT {
+ const tmpl_write = `
+ INSERT INTO "%s_user_confirmations" (user_id, attempt_id)
+ VALUES (?, ?);
+ `
+ const tmpl_read = `
+ SELECT
+ "%s_confirmation_attempts".id,
+ "%s_confirmation_attempts".user_id,
+ "%s_users".uuid
+ FROM "%s_confirmation_attempts"
+ JOIN "%s_users" ON
+ "%s_confirmation_attempts".user_id = "%s_users".id
+ WHERE token = ?;
+ `
+ const tmpl_session = `
+ INSERT INTO "%s_sessions" (uuid, user_id)
+ VALUES (?, ?) RETURNING id, timestamp;
+ `
+ return queryT{
+ write: fmt.Sprintf(tmpl_write, prefix),
+ read: fmt.Sprintf(
+ tmpl_read,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ ),
+ session: fmt.Sprintf(tmpl_session, prefix),
+ }
+}
+
+func confirmStmt(
+ db *sql.DB,
+ prefix string,
+) (func(string, guuid.UUID) (sessionT, error), func() error, error) {
+ q := confirmSQL(prefix)
+
+ writeStmt, err := db.Prepare(q.write)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ readStmt, err := db.Prepare(q.read)
+ if err != nil {
+ return nil, nil, g.WrapErrors(writeStmt.Close(), err)
+ }
+
+ sessionStmt, err := db.Prepare(q.session)
+ if err != nil {
+ return nil, nil, g.WrapErrors(
+ writeStmt.Close(),
+ readStmt.Close(),
+ err,
+ )
+ }
+
+ fn := func(token string, sessionID guuid.UUID) (sessionT, error) {
+ session := sessionT{
+ uuid: sessionID,
+ }
+
+ var (
+ user_id int64
+ attempt_id int64
+ user_id_bytes []byte
+ )
+ err := readStmt.QueryRow(token).Scan(
+ &attempt_id,
+ &user_id,
+ &user_id_bytes,
+ )
+ if err != nil {
+ return sessionT{}, err
+ }
+ session.userID = guuid.UUID(user_id_bytes)
+
+ _, err = writeStmt.Exec(user_id, attempt_id)
+ if err != nil {
+ return sessionT{}, err
+ }
+
+ session_id_bytes := sessionID[:]
+ var timestr string
+ err = sessionStmt.QueryRow(
+ session_id_bytes,
+ user_id,
+ ).Scan(&session.id, &timestr)
+ if err != nil {
+ return sessionT{}, err
+ }
+
+ session.timestamp, err = time.Parse(time.RFC3339Nano, timestr)
+ if err != nil {
+ return sessionT{}, err
+ }
+
+ return validateSession(session)
+ }
+
+ closeFn := func() error {
+ return g.SomeFnError(
+ writeStmt.Close,
+ readStmt.Close,
+ sessionStmt.Close,
+ )
+ }
+
+ return fn, closeFn, nil
+}
+
+func byEmailSQL(prefix string) queryT {
+ // FIXME: rewrite as LEFT JOIN?
+ const tmpl_read = `
+ SELECT id, timestamp, uuid, salt, pwhash, (
+ CASE WHEN EXISTS (
+ SELECT id FROM "%s_user_confirmations"
+ WHERE user_id = (
+ SELECT id FROM "%s_users"
+ WHERE email = ?
+ )
+ ) THEN 1
+ ELSE 0
+ END
+ ) as confirmed
+ FROM "%s_users" WHERE email = ?;
+ `
+ return queryT{
+ read: fmt.Sprintf(tmpl_read, prefix, prefix, prefix),
+ }
+}
+
+func byEmailStmt(
+ db *sql.DB,
+ prefix string,
+) (func(string) (userT, error), func() error, error) {
+ q := byEmailSQL(prefix)
+
+ readStmt, err := db.Prepare(q.read)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ fn := func(email string) (userT, error) {
+ user := userT{
+ email: email,
+ }
+
+ var (
+ timestr string
+ user_id_bytes []byte
+ )
+ err := readStmt.QueryRow(email, email).Scan(
+ &user.id,
+ &timestr,
+ &user_id_bytes,
+ &user.salt,
+ &user.pwhash,
+ &user.confirmed,
+ )
+ if err != nil {
+ return userT{}, err
+ }
+ user.uuid = guuid.UUID(user_id_bytes)
+
+ user.timestamp, err = time.Parse(time.RFC3339Nano,timestr)
+ if err != nil {
+ return userT{}, err
+ }
+
+ return user, nil
+ }
+
+ return fn, readStmt.Close, nil
+}
+
+func loginSQL(prefix string) queryT {
+ const tmpl_session = `
+ INSERT INTO "%s_sessions" (uuid, user_id)
+ VALUES (
+ ?,
+ (SELECT id FROM "%s_users" WHERE uuid = ?)
+ ) RETURNING id, timestamp;
+ `
+ return queryT{
+ session: fmt.Sprintf(tmpl_session, prefix, prefix),
+ }
+}
+
+func loginStmt(
+ db *sql.DB,
+ prefix string,
+) (func(guuid.UUID, guuid.UUID) (sessionT, error), func() error, error) {
+ q := loginSQL(prefix)
+
+ sessionStmt, err := db.Prepare(q.session)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ fn := func(userID guuid.UUID, sessionID guuid.UUID) (sessionT, error) {
+ session := sessionT{
+ uuid: sessionID,
+ userID: userID,
+ }
+
+ user_id_bytes := userID[:]
+ session_id_bytes := sessionID[:]
+ var timestr string
+ err := sessionStmt.QueryRow(
+ session_id_bytes,
+ user_id_bytes,
+ ).Scan(
+ &session.id,
+ &timestr,
+ )
+ if err != nil {
+ return sessionT{}, err
+ }
+
+ session.timestamp, err = time.Parse(time.RFC3339Nano, timestr)
+ if err != nil {
+ return sessionT{}, err
+ }
+
+ return validateSession(session)
+ }
+
+ return fn, sessionStmt.Close, nil
+}
+
+func refreshSQL(prefix string) queryT {
+ const tmpl_write = `
+ INSERT INTO "%s_sessions" (uuid, user_id)
+ VALUES (
+ ?,
+ (SELECT user_id FROM "%s_sessions" WHERE uuid = ?)
+ ) RETURNING id, timestamp, (
+ SELECT "%s_users".uuid FROM "%s_users"
+ JOIN "%s_sessions" ON
+ "%s_users".id = "%s_sessions".user_id
+ WHERE "%s_sessions".uuid = ?
+ ) AS userID;
+ `
+ return queryT{
+ write: fmt.Sprintf(
+ tmpl_write,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ ),
+ }
+}
+
+func refreshStmt(
+ db *sql.DB,
+ prefix string,
+) (func(guuid.UUID, guuid.UUID) (sessionT, error), func() error, error) {
+ q := refreshSQL(prefix)
+
+ writeStmt, err := db.Prepare(q.write)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ fn := func(
+ sessionID guuid.UUID,
+ newSessionID guuid.UUID,
+ ) (sessionT, error) {
+ session := sessionT{
+ uuid: newSessionID,
+ }
+
+ session_id_bytes := sessionID[:]
+ new_session_id_bytes := newSessionID[:]
+ var (
+ timestr string
+ user_id_bytes []byte
+ )
+ err := writeStmt.QueryRow(
+ new_session_id_bytes,
+ session_id_bytes,
+ session_id_bytes,
+ ).Scan(&session.id, &timestr, &user_id_bytes)
+ if err != nil {
+ return sessionT{}, err
+ }
+ session.userID = guuid.UUID(user_id_bytes)
+
+ session.timestamp, err = time.Parse(time.RFC3339Nano, timestr)
+ if err != nil {
+ return sessionT{}, err
+ }
+
+ return validateSession(session)
+ }
+
+ return fn, writeStmt.Close, nil
+}
+
+func resetSQL(prefix string) queryT {
+ const tmpl_write = `
+ -- INSERT SOMETHING %s
+ `
+ return queryT{
+ write: fmt.Sprintf(tmpl_write, prefix),
+ }
+}
+
+func resetStmt(
+ db *sql.DB,
+ prefix string,
+) (func(int64, []byte, guuid.UUID) (sessionT, error), func() error, error) {
+ q := resetSQL(prefix)
+
+ writeStmt, err := db.Prepare(q.write)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ fn := func(id int64, pwhash []byte, token guuid.UUID) (sessionT, error) {
+ var session sessionT
+ err := writeStmt.QueryRow(id, pwhash, token).Scan(&session)
+ if err != nil {
+ return sessionT{}, err
+ }
+ return validateSession(session)
+ }
+
+ return fn, writeStmt.Close, nil
+}
+
+func changeSQL(prefix string) queryT {
+ const tmpl_write = `
+ -- INSERT SOMETHING %s
+ `
+ return queryT{
+ write: fmt.Sprintf(tmpl_write, prefix),
+ }
+}
+
+func changeStmt(
+ db *sql.DB,
+ prefix string,
+) (func(guuid.UUID, []byte) (sessionT, error), func() error, error) {
+ q := changeSQL(prefix)
+
+ writeStmt, err := db.Prepare(q.write)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ fn := func(uuid guuid.UUID, pwhash []byte) (sessionT, error) {
+ var session sessionT
+ err := writeStmt.QueryRow(uuid, pwhash).Scan(&session)
+ if err != nil {
+ return sessionT{}, err
+ }
+ return validateSession(session)
+ }
+
+ return fn, writeStmt.Close, nil
+}
+
+func byUUIDSQL(prefix string) queryT {
+ const tmpl_read = `
+ -- INSERT SOMETHING %s
+ `
+ return queryT{
+ read: fmt.Sprintf(tmpl_read, prefix),
+ }
+}
+
+func byUUIDStmt(
+ db *sql.DB,
+ prefix string,
+) (func(guuid.UUID) (sessionT, error), func() error, error) {
+ q := byUUIDSQL(prefix)
+
+ readStmt, err := db.Prepare(q.read)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ fn := func(sessionID guuid.UUID) (sessionT, error) {
+ var session sessionT
+ err := readStmt.QueryRow(sessionID).Scan(&session)
+ if err != nil {
+ return sessionT{}, err
+ }
+ return validateSession(session)
+ }
+
+ return fn, readStmt.Close, nil
+}
+
+func logoutSQL(prefix string) queryT {
+ const tmpl_write = `
+ -- INSERT SOMETHING %s
+ `
+ return queryT{
+ write: fmt.Sprintf(tmpl_write, prefix),
+ }
+}
+
+func logoutStmt(
+ db *sql.DB,
+ prefix string,
+) (func(guuid.UUID) error, func() error, error) {
+ q := logoutSQL(prefix)
+
+ writeStmt, err := db.Prepare(q.write)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ fn := func(sessionID guuid.UUID) error {
+ _, err := writeStmt.Exec(sessionID)
+ return err
+ }
+
+ return fn, writeStmt.Close, nil
+}
+
+func outOthersSQL(prefix string) queryT {
+ const tmpl_write = `
+ -- INSERT SOMETHING %s
+ `
+ return queryT{
+ write: fmt.Sprintf(tmpl_write, prefix),
+ }
+}
+
+func outOthersStmt(
+ db *sql.DB,
+ prefix string,
+) (func(guuid.UUID) error, func() error, error) {
+ q := outOthersSQL(prefix)
+
+ writeStmt, err := db.Prepare(q.write)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ fn := func(sessionID guuid.UUID) error {
+ _, err := writeStmt.Exec(sessionID)
+ return err
+ }
+
+ return fn, writeStmt.Close, nil
+}
+
+func outAllSQL(prefix string) queryT {
+ const tmpl_write = `
+ -- INSERT SOMETHING %s
+ `
+ return queryT{
+ write: fmt.Sprintf(tmpl_write, prefix),
+ }
+}
+
+func outAllStmt(
+ db *sql.DB,
+ prefix string,
+) (func(guuid.UUID) error, func() error, error) {
+ q := outAllSQL(prefix)
+
+ writeStmt, err := db.Prepare(q.write)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ fn := func(sessionID guuid.UUID) error {
+ _, err := writeStmt.Exec(sessionID)
+ return err
+ }
+
+ return fn, writeStmt.Close, nil
+}
+
+func initDB(
+ db *sql.DB,
+ prefix string,
+) (queriesT, error) {
+ createTablesErr := createTables(db, prefix)
+ register, registerClose, registerErr := registerStmt(db, prefix)
+ sendToken, sendTokenClose, sendTokenErr := sendTokenStmt(db, prefix)
+ confirm, confirmClose, confirmErr := confirmStmt(db, prefix)
+ byEmail, byEmailClose, byEmailErr := byEmailStmt(db, prefix)
+ login, loginClose, loginErr := loginStmt(db, prefix)
+ refresh, refreshClose, refreshErr := refreshStmt(db, prefix)
+ reset, resetClose, resetErr := resetStmt(db, prefix)
+ change, changeClose, changeErr := changeStmt(db, prefix)
+ byUUID, byUUIDClose, byUUIDErr := byUUIDStmt(db, prefix)
+ logout, logoutClose, logoutErr := logoutStmt(db, prefix)
+ outOthers, outOthersClose, outOthersErr := outOthersStmt(db, prefix)
+ outAll, outAllClose, outAllErr := outAllStmt(db, prefix)
+
+ err := g.SomeError(
+ createTablesErr,
+ registerErr,
+ sendTokenErr,
+ confirmErr,
+ byEmailErr,
+ loginErr,
+ refreshErr,
+ resetErr,
+ changeErr,
+ byUUIDErr,
+ logoutErr,
+ outOthersErr,
+ outAllErr,
+ )
+ if err != nil {
+ return queriesT{}, err
+ }
+
+ close := func() error {
+ return g.SomeFnError(
+ registerClose,
+ sendTokenClose,
+ confirmClose,
+ byEmailClose,
+ loginClose,
+ refreshClose,
+ resetClose,
+ changeClose,
+ byUUIDClose,
+ logoutClose,
+ outOthersClose,
+ outAllClose,
+ )
+ }
+
+ var connMutex sync.RWMutex
+ return queriesT{
+ register: func(
+ a guuid.UUID,
+ b string,
+ c []byte,
+ d []byte,
+ ) (userT, error) {
+ connMutex.RLock()
+ defer connMutex.RUnlock()
+ return register(a, b, c, d)
+ },
+ sendToken: func(a guuid.UUID, b string) error {
+ connMutex.RLock()
+ defer connMutex.RUnlock()
+ return sendToken(a, b)
+ },
+ confirm: func(a string, b guuid.UUID) (sessionT, error) {
+ connMutex.RLock()
+ defer connMutex.RUnlock()
+ return confirm(a, b)
+ },
+ byEmail: func(a string) (userT, error) {
+ connMutex.RLock()
+ defer connMutex.RUnlock()
+ return byEmail(a)
+ },
+ login: func(a guuid.UUID, b guuid.UUID) (sessionT, error) {
+ connMutex.RLock()
+ defer connMutex.RUnlock()
+ return login(a, b)
+ },
+ refresh: func(a guuid.UUID, b guuid.UUID) (sessionT, error) {
+ connMutex.RLock()
+ defer connMutex.RUnlock()
+ return refresh(a, b)
+ },
+ reset: func(a int64, b []byte, c guuid.UUID) (sessionT, error) {
+ connMutex.RLock()
+ defer connMutex.RUnlock()
+ return reset(a, b, c)
+ },
+ change: func(a guuid.UUID, b []byte) (sessionT, error) {
+ connMutex.RLock()
+ defer connMutex.RUnlock()
+ return change(a, b)
+ },
+ byUUID: func(a guuid.UUID) (sessionT, error) {
+ connMutex.RLock()
+ defer connMutex.RUnlock()
+ return byUUID(a)
+ },
+ logout: func(a guuid.UUID) error {
+ connMutex.RLock()
+ defer connMutex.RUnlock()
+ return logout(a)
+ },
+ outOthers: func(a guuid.UUID) error {
+ connMutex.RLock()
+ defer connMutex.RUnlock()
+ return outOthers(a)
+ },
+ outAll: func(a guuid.UUID) error {
+ connMutex.RLock()
+ defer connMutex.RUnlock()
+ return outAll(a)
+ },
+ close: func() error {
+ connMutex.Lock()
+ defer connMutex.Unlock()
+ return close()
+ },
+ }, nil
+}
+
+func newUserHandler(auth authT) func(q.Message) error {
+ return func(message q.Message) error {
+ return nil
+ }
+}
+
+func sendConfirmationRequestHandler(auth authT) func(q.Message) error {
+ return func(message q.Message) error {
+ return nil
+ }
+}
+
+func forgotPasswordRequestHandler(auth authT) func(q.Message) error {
+ return func(message q.Message) error {
+ return nil
+ }
+}
+
+var consumers = []consumerT{
+ consumerT{
+ topic: NEW_USER,
+ handlerFn: newUserHandler,
+ },
+ consumerT{
+ topic: SEND_CONFIRMATION_REQUEST,
+ handlerFn: sendConfirmationRequestHandler,
+ },
+ consumerT{
+ topic: FORGOT_PASSWORD_REQUEST,
+ handlerFn: forgotPasswordRequestHandler,
+ },
+}
+func registerConsumers(auth authT, consumers []consumerT, prefix string) error {
+ for _, consumer := range consumers {
+ err := auth.queue.Subscribe(
+ consumer.topic,
+ prefix + "-" + consumer.topic,
+ consumer.handlerFn(auth),
+ )
+ if err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+func unregisterConsumers(queue q.IQueue, consumers []consumerT, prefix string) {
+ for _, consumer := range consumers {
+ queue.Unsubscribe(consumer.topic, prefix + "-" + consumer.topic)
+ }
+}
+
+type resultT[T any] struct{
+ value T
+ err error
+}
+
+type taggedT[T any] struct{
+ id int
+ value T
+}
+
+func startRunner[A any, B any](
+ in <-chan taggedT[A],
+ out chan<- taggedT[B],
+ fn func(A) B,
+ done func(),
+) {
+ for input := range in {
+ out <- taggedT[B]{
+ id: input.id,
+ value: fn(input.value),
+ }
+ }
+ done()
+}
+
+func makePoolRunner[A any, B any](count int, fn func(A) B) (func(A) B, func()) {
+ var wg sync.WaitGroup
+ wg.Add(count)
+
+ in := make(chan taggedT[A])
+ out := make(chan taggedT[B])
+
+ for _ = range count {
+ go startRunner(in, out, fn, wg.Done)
+ }
+
+ var mutex sync.Mutex
+ m := map[int]chan B{}
+ id := 0
+ go func() {
+ for output := range out {
+ mutex.Lock()
+ defer mutex.Unlock()
+ m[output.id] <- output.value
+ close(m[output.id])
+ delete(m, output.id)
+ }
+ }()
+
+ poolRunFn := func(input A) B {
+ c := make(chan B)
+ {
+ mutex.Lock()
+ defer mutex.Unlock()
+ m[id] = c
+ id++
+ }
+
+ in <- taggedT[A]{
+ id: id,
+ value: input,
+ }
+ return <- c
+ }
+
+ close := func() {
+ close(in)
+ wg.Wait()
+ close(out)
+ }
+
+ return poolRunFn, close
+}
+
+func asResult[A any, B any](fn func(A) (B, error)) func(A) resultT[B] {
+ return func(input A) resultT[B] {
+ output, err := fn(input)
+ return resultT[B]{
+ value: output,
+ err: err,
+ }
+ }
+}
+
+func unwrapResult[A any, B any](fn func(A) resultT[B]) func(A) (B, error) {
+ return func(input A) (B, error) {
+ result := fn(input)
+ return result.value, result.err
+ }
+}
+
+func NewWithPrefix(databasePath string, prefix string) (IAuth, error) {
+ queue, err := q.New(databasePath)
+ if err != nil {
+ return authT{}, err
+ }
+
+ db, err := sql.Open(golite.DriverName, databasePath)
+ if err != nil {
+ return authT{}, err
+ }
+
+ queries, err := initDB(db, prefix)
+ if err != nil {
+ return authT{}, err
+ }
+
+ numCPU := runtime.NumCPU()
+ hasher, closeHasher := makePoolRunner(numCPU, asResult(scrypt.Hash))
+
+ closeFn := func() {
+ unregisterConsumers(queue, consumers, prefix)
+ closeHasher()
+ }
+
+ auth := authT{
+ queue: queue,
+ db: db,
+ queries: queries,
+ hasher: unwrapResult(hasher),
+ close: closeFn,
+ }
+
+ err = registerConsumers(auth, consumers, prefix)
+ if err != nil {
+ closeFn()
+ return authT{}, err
+ }
+
+ return auth, nil
+}
+
+func New(databasePath string) (IAuth, error) {
+ return NewWithPrefix(databasePath, defaultPrefix)
+}
+
+func newUserPayload(email string, salt []byte, pwhash []byte) ([]byte, error) {
+ data := make(map[string]interface{})
+ data["email"] = email
+ data["salt"] = hex.EncodeToString(salt)
+ data["pwhash"] = hex.EncodeToString(pwhash)
+ return json.Marshal(data)
+}
+
+func asPublicUser(user userT) (User, error) {
+ return User{}, nil
+}
+
+func (auth authT) Register(
+ email string,
+ password string,
+ confirmPassword string,
+) (User, error) {
+ if password != confirmPassword {
+ return User{}, ErrPassMismatch
+ }
+
+ if len(password) < scrypt.MinimumPasswordLength {
+ return User{}, ErrTooShort
+ }
+
+ // special check for sql.ErrNoRows to combat enumeration attacks.
+ // FIXME: how so?
+ _, lookupErr := auth.queries.byEmail(email)
+ if lookupErr != nil && lookupErr != sql.ErrNoRows {
+ return User{}, lookupErr
+ }
+
+ salt, err := scrypt.Salt()
+ if err != nil {
+ return User{}, err
+ }
+
+ input := scrypt.HashInput{
+ Password: []byte(password),
+ Salt: salt,
+ }
+ hash, err := auth.hasher(input)
+ if err != nil {
+ return User{}, err
+ }
+
+ /*
+ We also try to register anyway, to prevent disk IO timing attacks.
+ / FIXME: how so?
+ */
+
+ flowID := guuid.New()
+ waiter := auth.queue.WaitFor(NEW_USER, flowID, "register")
+ defer waiter.Close()
+
+ payload, err := newUserPayload(email, salt, hash)
+ if err != nil {
+ return User{}, err
+ }
+
+ unsent := q.UnsentMessage{
+ Topic: NEW_USER,
+ FlowID: flowID,
+ Payload: payload,
+ }
+ _, err = auth.queue.Publish(unsent)
+ if err != nil {
+ return User{}, err
+ }
+
+ user := userT{}
+ select {
+ case <-time.After(RegisterTimeout):
+ err = ErrTimeout
+ case <-waiter.Channel:
+ user, err = auth.queries.byEmail(email)
+ }
+
+ return asPublicUser(user)
+}
+
+func sendConfirmationMessage(
+ email string,
+ flowID guuid.UUID,
+) (q.UnsentMessage, error) {
+ data := make(map[string]interface{})
+ data["email"] = email
+ payload, err := json.Marshal(data)
+ if err != nil {
+ return q.UnsentMessage{}, err
+ }
+
+ return q.UnsentMessage{
+ Topic: SEND_CONFIRMATION_REQUEST,
+ FlowID: flowID,
+ Payload: payload,
+ }, nil
+}
+
+func (auth authT) ResendConfirmation(email string) error {
+ unsent, err := sendConfirmationMessage(email, guuid.New())
+ if err != nil {
+ return err
+ }
+
+ _, err = auth.queue.Publish(unsent)
+ return err
+}
+
+/*
+func asPublicSession(session sessionT) Session {
+ _, err := validateSession(session)
+ if err != nil {
+ panic(fmt.Errorf(
+ "session must have been validated at this stage: %w",
+ err,
+ ))
+ }
+ return Session{}
+}
+*/
+
+func asPublicSession(session sessionT) (Session, error) {
+ return Session{}, nil
+}
+
+func (auth authT) ConfirmEmail(token string) (Session, error) {
+ session, err := auth.queries.confirm("token FIXME", guuid.New())
+ if err != nil {
+ return Session{}, err
+ }
+
+ return asPublicSession(session)
+}
+
+func (auth authT) LoginEmail(
+ email string,
+ password string,
+) (Session, error) {
+ if len(password) < scrypt.MinimumPasswordLength {
+ return Session{}, ErrTooShort
+ }
+
+ // special check for sql.ErrNoRows to combat enumeration attacks.
+ // FIXME: how so?
+ user, err := auth.queries.byEmail(email)
+ if err != nil && err != sql.ErrNoRows {
+ return Session{}, err
+ }
+
+ input := scrypt.HashInput{
+ Password: []byte(password),
+ Salt: user.salt,
+ }
+ hash, err := auth.hasher(input)
+ if err != nil {
+ return Session{}, err
+ }
+
+ ok := slices.Equal(hash, user.pwhash)
+ if !ok {
+ return Session{}, ErrBadCombo
+ }
+
+ if !user.confirmed {
+ return Session{}, ErrUnconfirmed
+ }
+
+ session, err := auth.queries.login(user.uuid, guuid.New())
+ if err != nil {
+ return Session{}, err
+ }
+
+ return asPublicSession(session)
+}
+
+func forgotPasswordMessage(
+ email string,
+ flowID guuid.UUID,
+) (q.UnsentMessage, error) {
+ data := make(map[string]interface{})
+ data["email"] = email
+ payload, err := json.Marshal(data)
+ if err != nil {
+ return q.UnsentMessage{}, err
+ }
+
+ return q.UnsentMessage{
+ Topic: FORGOT_PASSWORD_REQUEST,
+ FlowID: flowID,
+ Payload: payload,
+ }, nil
+}
+
+func (auth authT) ForgotPassword(email string) error {
+ // special check for sql.ErrNoRows to combat enumeration attacks.
+ user, err := auth.queries.byEmail(email)
+ if err != nil && err != sql.ErrNoRows {
+ return err
+ }
+
+ unsent, err := forgotPasswordMessage(user.email, guuid.New())
+ if err != nil {
+ return err
+ }
+
+ _, err = auth.queue.Publish(unsent)
+ return err
+}
+
+/*
+func checkSession(session sessionT, now time.Time) error {
+ if session.revoked_at != nil {
+ return ErrRevokedSession
+ }
+
+ if session.timestamp.Add(SessionDuration).After(now) {
+ return ErrSessionExpired
+ }
+
+ return nil
+}
+
+func validateSession(
+ lookupFn func(guuid.UUID) (sessionT, error),
+ session sessionT,
+) error {
+ dbSession, err := lookupFn(session.uuid)
+ if err != nil {
+ return err
+ }
+
+ return checkSession(dbSession, time.Now())
+}
+*/
+
+func (auth authT) Refresh(session Session) (Session, error) {
+ /*
+ err := validateSession(auth.queries.byUUID, session)
+ if err != nil {
+ return sessionT{}, err
+ }
+ */
+
+ // return auth.queries.refresh(session.uuid, guuid.New())
+ newSession, err := auth.queries.refresh(session.UUID, guuid.New())
+ if err != nil {
+ return Session{}, err
+ }
+
+ return asPublicSession(newSession)
+}
+
+func (auth authT) ResetPassword(
+ token guuid.UUID,
+ password string,
+ confirmPassword string,
+) (Session, error) {
+ if password != confirmPassword {
+ return Session{}, ErrPassMismatch
+ }
+
+ if len(password) < scrypt.MinimumPasswordLength {
+ return Session{}, ErrTooShort
+ }
+
+ user := userT{}
+
+ input := scrypt.HashInput{
+ Password: []byte(password),
+ Salt: user.salt,
+ }
+ pwhash, err := auth.hasher(input)
+ if err != nil {
+ return Session{}, err
+ }
+
+ var nextFn func() (sessionT, error)
+ if user.confirmed {
+ nextFn = func() (sessionT, error) {
+ return auth.queries.reset(user.id, pwhash, token)
+ }
+ } else {
+ nextFn = func() (sessionT, error) {
+ // return auth.queries.confirm(user.id, pwhash, token)
+ return auth.queries.confirm("token FIXME", guuid.New())
+ }
+ }
+
+ session, err := nextFn()
+ if err != nil {
+ return Session{}, err
+ }
+
+ return asPublicSession(session)
+}
+
+func (auth authT) ChangePassword(
+ user User,
+ currentPassword string,
+ newPassword string,
+ confirmNewPassword string,
+) (Session, error) {
+ if newPassword != confirmNewPassword {
+ return Session{}, ErrPassMismatch
+ }
+
+ if len(newPassword) < scrypt.MinimumPasswordLength {
+ return Session{}, ErrTooShort
+ }
+
+ input := scrypt.HashInput{
+ Password: []byte(newPassword),
+ Salt: user.Salt,
+ }
+ pwhash, err := auth.hasher(input)
+ if err != nil {
+ return Session{}, err
+ }
+
+ if !user.Confirmed {
+ return Session{}, ErrUnconfirmed
+ }
+
+ session, err := auth.queries.change(user.UUID, pwhash)
+ if err != nil {
+ return Session{}, nil
+ }
+
+ return asPublicSession(session)
+}
+
+func runLogout(
+ lookupFn func(guuid.UUID) (sessionT, error),
+ sessionID guuid.UUID,
+ queryFn func(guuid.UUID) error,
+) error {
+ /*
+ err := validateSession(lookupFn, session)
+ if err != nil {
+ return err
+ }
+ */
+
+ return queryFn(sessionID)
+}
+
+func (auth authT) Logout(session Session) error {
+ return runLogout(
+ auth.queries.byUUID,
+ session.UUID,
+ auth.queries.logout,
+ )
+}
+
+func (auth authT) LogoutOthers(session Session) error {
+ return runLogout(
+ auth.queries.byUUID,
+ session.UUID,
+ auth.queries.outOthers,
+ )
+}
+
+func (auth authT) LogoutAll(session Session) error {
+ return runLogout(
+ auth.queries.byUUID,
+ session.UUID,
+ auth.queries.outAll,
+ )
+}
+
+func (auth authT) Close() error {
+ return g.WrapErrors(auth.queries.close(), auth.db.Close())
+}
+
+
+
+func Main() {
+ // FIXME
+}