diff options
Diffstat (limited to 'src/gracha.go')
-rw-r--r-- | src/gracha.go | 720 |
1 files changed, 498 insertions, 222 deletions
diff --git a/src/gracha.go b/src/gracha.go index e300f66..c240dd3 100644 --- a/src/gracha.go +++ b/src/gracha.go @@ -8,6 +8,7 @@ import ( "errors" "fmt" "runtime" + "slices" "sync" "time" @@ -27,7 +28,6 @@ const ( SEND_CONFIRMATION_REQUEST = "send-confirmation-request" FORGOT_PASSWORD_REQUEST = "forgot-password-request" - day = 24 * time.Hour ) @@ -48,17 +48,18 @@ var ( type queryT struct{ - write string - read string + write string + read string + session string } type queriesT struct{ + register func(guuid.UUID, string, []byte, []byte) (userT, error) + sendToken func(guuid.UUID, string) error + confirm func(string, guuid.UUID) (sessionT, error) byEmail func(string) (userT, error) - byToken func(guuid.UUID) (userT, error) - register func(string, []byte, []byte) (userT, error) - confirm func(guuid.UUID) (sessionT, error) - login func(string, string) (sessionT, error) - refresh func(guuid.UUID) (sessionT, error) + login func(guuid.UUID, guuid.UUID) (sessionT, error) + refresh func(guuid.UUID, guuid.UUID) (sessionT, error) reset func(int64, []byte, guuid.UUID) (sessionT, error) change func(int64, []byte) (sessionT, error) byUUID func(guuid.UUID) (sessionT, error) @@ -68,28 +69,23 @@ type queriesT struct{ close func() error } -type confirmationT struct{ -} - type userT struct{ id int64 timestamp time.Time uuid guuid.UUID email string - username *string salt []byte pwhash []byte - confirmed_at *time.Time metadata map[string]interface{} + confirmed bool } type sessionT struct{ id int64 - timestr string timestamp time.Time uuid guuid.UUID - user_id int64 - type_ string + userID guuid.UUID + // type_ string revoked_at *time.Time metadata map[string]interface{} } @@ -102,15 +98,14 @@ type consumerT struct{ type authT struct{ queries queriesT queue q.IQueue - hasher func(scrypt.HashInput) resultT[[]byte] - checker func(scrypt.CheckInput) resultT[bool] + hasher func(scrypt.HashInput) ([]byte, error) close func() } type IAuth interface{ Register(string, string, string) (userT, error) ResendConfirmation(string) error - ConfirmEmail(guuid.UUID) (sessionT, error) + ConfirmEmail(string) (sessionT, error) LoginEmail(string, string) (sessionT, error) ForgotPassword(string) error Refresh(sessionT) (sessionT, error) @@ -165,13 +160,25 @@ func createTablesSQL(prefix string) queryT { timestamp TEXT NOT NULL DEFAULT (%s), uuid BLOB NOT NULL UNIQUE, email TEXT NOT NULL UNIQUE, - username TEXT UNIQUE, salt BLOB NOT NULL UNIQUE, pwhash BLOB NOT NULL, - confirmed_at TEXT, - confirmer_id INTEGER REFERENCES "%s"(id), metadata TEXT ); + CREATE TABLE IF NOT EXISTS "%s_confirmation_attempts" ( + id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, + timestamp TEXT NOT NULL DEFAULT (%s), + -- uuid BLOB NOT NULL UNIQUE, + user_id INTEGER NOT NULL REFERENCES "%s_users"(id), + token TEXT NOT NULL UNIQUE + ); + CREATE TABLE IF NOT EXISTS "%s_user_confirmations" ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + timestamp TEXT NOT NULL DEFAULT (%s), + user_id INTEGER NOT NULL + REFERENCES "%s_users"(id) UNIQUE, + attempt_id INTEGER NOT NULL + REFERENCES "%s_confirmation_attempts"(id) UNIQUE + ); CREATE TABLE IF NOT EXISTS "%s_user_changes" ( id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, timestamp TEXT NOT NULL DEFAULT (%s), @@ -180,12 +187,12 @@ func createTablesSQL(prefix string) queryT { value TEXT NOT NULL, op BOOLEAN NOT NULL ); - CREATE TABLE IF NOT EXISTS "%s_tokens" ( - id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, - timestamp TEXT NOT NULL DEFAULT (%s), - uuid BLOB NOT NULL UNIQUE, - type TEXT NOT NULL - ); + -- CREATE TABLE IF NOT EXISTS "%s_tokens" ( + -- id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, + -- timestamp TEXT NOT NULL DEFAULT (%s), + -- uuid BLOB NOT NULL UNIQUE, + -- type TEXT NOT NULL + -- ); CREATE TABLE IF NOT EXISTS "%s_roles" ( id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, user_id INTEGER NOT NULL REFERENCES "%s_users"(id), @@ -205,9 +212,10 @@ func createTablesSQL(prefix string) queryT { timestamp TEXT NOT NULL DEFAULT (%s), uuid BLOB NOT NULL UNIQUE, user_id INTEGER NOT NULL REFERENCES "%s_users"(id), - type TEXT NOT NULL, - revoked_at TEXT, - revoker_id INTEGER REFERENCES "%s_users"(id), + -- type TEXT NOT NULL, + -- revoked_at TEXT, + -- revoker_id INTEGER REFERENCES "%s_users"(id), + -- FIXME: add provenance: login, refresh, confirmation, etc. metadata TEXT ); CREATE TABLE IF NOT EXISTS "%s_attempts" ( @@ -233,6 +241,12 @@ func createTablesSQL(prefix string) queryT { prefix, g.SQLiteNow, prefix, + g.SQLiteNow, + prefix, + prefix, + g.SQLiteNow, + prefix, + prefix, prefix, g.SQLiteNow, prefix, @@ -266,50 +280,68 @@ func createTables(db *sql.DB, prefix string) error { }) } -func byEmailSQL(prefix string) queryT { +func registerSQL(prefix string) queryT { + const tmpl_write = ` + INSERT INTO "%s_users" (uuid, email, salt, pwhash) + VALUES (?, ?, ?, ?) RETURNING id, timestamp; + ` const tmpl_read = ` - SELECT id, timestamp, uuid, email, username, pwhash, metadata - FROM "%s_users" WHERE email = ?; + SELECT id, timestamp from "%s_users" + WHERE uuid = ?; ` return queryT{ - read: fmt.Sprintf(tmpl_read, prefix), + write: fmt.Sprintf(tmpl_write, prefix), + read: fmt.Sprintf(tmpl_read, prefix), } } -func byEmailStmt( +func registerStmt( db *sql.DB, prefix string, -) (func(string) (userT, error), func() error, error) { - q := byEmailSQL(prefix) +) ( + func(guuid.UUID, string, []byte, []byte) (userT, error), + func() error, + error, +) { + q := registerSQL(prefix) + + writeStmt, err := db.Prepare(q.write) + if err != nil { + return nil, nil, err + } readStmt, err := db.Prepare(q.read) if err != nil { return nil, nil, err } - fn := func(email string) (userT, error) { + fn := func( + userID guuid.UUID, + email string, + salt []byte, + pwhash []byte, + ) (userT, error) { user := userT{ - email: email, + uuid: userID, + email: email, + salt: salt, + pwhash: pwhash, + confirmed: false, } - var ( - timestr string - uuid_bytes []byte - ) - err := readStmt.QueryRow(email).Scan( - &user.id, - ×tr, - &uuid_bytes, - &user.username, - &user.pwhash, - &user.metadata, - ) + var timestr string + user_id_bytes := userID[:] + err := writeStmt.QueryRow( + user_id_bytes, + email, + salt, + pwhash, + ).Scan(&user.id, ×tr) if err != nil { return userT{}, err } - user.uuid = guuid.UUID(uuid_bytes) - user.timestamp, err = time.Parse(time.RFC3339Nano,timestr) + user.timestamp, err = time.Parse(time.RFC3339Nano, timestr) if err != nil { return userT{}, err } @@ -317,165 +349,329 @@ func byEmailStmt( return user, nil } - return fn, readStmt.Close, nil + closeFn := func() error { + return g.SomeFnError(writeStmt.Close, readStmt.Close) + } + + return fn, closeFn, nil } -func byTokenSQL(prefix string) queryT{ - const tmpl_read = ` - SELECT id, timestamp, uuid, email, username, pwhash, metadata - FROM "%s" WHERE email = ?; +func sendTokenSQL(prefix string) queryT { + const tmpl_write = ` + INSERT INTO "%s_confirmation_attempts" (user_id, token) + VALUES ( + (SELECT id FROM "%s_users" WHERE uuid = ?), + ? + ) ` return queryT{ - read: fmt.Sprintf(tmpl_read, prefix), + write: fmt.Sprintf(tmpl_write, prefix, prefix), } } -func byTokenStmt( +func sendTokenStmt( db *sql.DB, prefix string, -) (func(guuid.UUID) (userT, error), func() error, error) { - q := byTokenSQL(prefix) +) (func(guuid.UUID, string) error, func() error, error) { + q := sendTokenSQL(prefix) - readStmt, err := db.Prepare(q.read) + writeStmt, err := db.Prepare(q.write) if err != nil { return nil, nil, err } - fn := func(token guuid.UUID) (userT, error) { - var user userT - // FIXME: build user - err := readStmt.QueryRow(token).Scan(&user.id) - return user, err + fn := func(userID guuid.UUID, token string) error { + user_id_bytes := userID[:] + _, err := writeStmt.Exec(user_id_bytes, token) + return err } - return fn, readStmt.Close, nil + return fn, writeStmt.Close, nil } -func registerSQL(prefix string) queryT { +func confirmSQL(prefix string) queryT { const tmpl_write = ` - INSERT INTO "%s" (uuid, email, username, salt, pwhash, metadata) - VALUES (?, ?, ?, ?, ?, ?); + INSERT INTO "%s_user_confirmations" (user_id, attempt_id) + VALUES (?, ?); + ` + const tmpl_read = ` + SELECT + "%s_confirmation_attempts".id, + "%s_confirmation_attempts".user_id, + "%s_users".uuid + FROM "%s_confirmation_attempts" + JOIN "%s_users" ON + "%s_confirmation_attempts".user_id = "%s_users".id + WHERE token = ?; + ` + const tmpl_session = ` + INSERT INTO "%s_sessions" (uuid, user_id) + VALUES (?, ?) RETURNING id, timestamp; ` return queryT{ - write: fmt.Sprintf(tmpl_write, prefix), + write: fmt.Sprintf(tmpl_write, prefix), + read: fmt.Sprintf( + tmpl_read, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + ), + session: fmt.Sprintf(tmpl_session, prefix), } } -func registerStmt( +func confirmStmt( db *sql.DB, prefix string, -) (func(string, []byte, []byte) (userT, error), func() error, error) { - q := registerSQL(prefix) +) (func(string, guuid.UUID) (sessionT, error), func() error, error) { + q := confirmSQL(prefix) writeStmt, err := db.Prepare(q.write) if err != nil { return nil, nil, err } - fn := func(email string, salt []byte, pwhash []byte) (userT, error) { - /* - timestamp TEXT NOT NULL DEFAULT (%s), - uuid BLOB NOT NULL UNIQUE, - email TEXT NOT NULL UNIQUE, - username TEXT UNIQUE, - pwhash TEXT NOT NULL, - metadata TEXT - */ + readStmt, err := db.Prepare(q.read) + if err != nil { + return nil, nil, g.WrapErrors(writeStmt.Close(), err) + } - var user userT - // err := stmt.QueryRow( - ret, err := writeStmt.Exec( - guuid.New(), - email, - "credentials.username", - salt, - pwhash, - "credentials.metadata", - // ).Scan(&user.email) - // FIXME: finish + sessionStmt, err := db.Prepare(q.session) + if err != nil { + return nil, nil, g.WrapErrors( + writeStmt.Close(), + readStmt.Close(), + err, ) - if false { - fmt.Printf("ret: %#v\n", ret) - fmt.Printf("user: %#v\n", user) + } + + fn := func(token string, sessionID guuid.UUID) (sessionT, error) { + session := sessionT{ + uuid: sessionID, } - return user, err + + var ( + user_id int64 + attempt_id int64 + user_id_bytes []byte + ) + err := readStmt.QueryRow(token).Scan( + &attempt_id, + &user_id, + &user_id_bytes, + ) + if err != nil { + return sessionT{}, err + } + session.userID = guuid.UUID(user_id_bytes) + + _, err = writeStmt.Exec(user_id, attempt_id) + if err != nil { + return sessionT{}, err + } + + session_id_bytes := sessionID[:] + var timestr string + err = sessionStmt.QueryRow( + session_id_bytes, + user_id, + ).Scan(&session.id, ×tr) + if err != nil { + return sessionT{}, err + } + + session.timestamp, err = time.Parse(time.RFC3339Nano, timestr) + if err != nil { + return sessionT{}, err + } + + return session, nil } - return fn, writeStmt.Close, nil + closeFn := func() error { + return g.SomeFnError( + writeStmt.Close, + readStmt.Close, + sessionStmt.Close, + ) + } + + return fn, closeFn, nil } -func confirmSQL(prefix string) queryT { - const tmpl_write = ` - -- INSERT SOMETHING %s +func byEmailSQL(prefix string) queryT { + // FIXME: rewrite as LEFT JOIN? + const tmpl_read = ` + SELECT id, timestamp, uuid, salt, pwhash, metadata, ( + CASE WHEN EXISTS ( + SELECT id FROM "%s_user_confirmations" + WHERE user_id = ( + SELECT id FROM "%s_users" + WHERE email = ? + ) + ) THEN 1 + ELSE 0 + END + ) as confirmed + FROM "%s_users" WHERE email = ?; ` return queryT{ - write: fmt.Sprintf(tmpl_write, prefix), + read: fmt.Sprintf(tmpl_read, prefix, prefix, prefix), } } -func confirmStmt( +func byEmailStmt( db *sql.DB, prefix string, -) (func(guuid.UUID) (sessionT, error), func() error, error) { - q := confirmSQL(prefix) +) (func(string) (userT, error), func() error, error) { + q := byEmailSQL(prefix) - writeStmt, err := db.Prepare(q.write) + readStmt, err := db.Prepare(q.read) if err != nil { return nil, nil, err } - fn := func(token guuid.UUID) (sessionT, error) { - var session sessionT - err := writeStmt.QueryRow(token).Scan(&session) - return session, err + fn := func(email string) (userT, error) { + user := userT{ + email: email, + } + + var ( + timestr string + user_id_bytes []byte + metadatastr sql.NullString + ) + err := readStmt.QueryRow(email, email).Scan( + &user.id, + ×tr, + &user_id_bytes, + &user.salt, + &user.pwhash, + &metadatastr, + &user.confirmed, + ) + if err != nil { + return userT{}, err + } + user.uuid = guuid.UUID(user_id_bytes) + + user.timestamp, err = time.Parse(time.RFC3339Nano,timestr) + if err != nil { + return userT{}, err + } + + if metadatastr.Valid { + err := json.Unmarshal( + []byte(metadatastr.String), + &user.metadata, + ) + if err != nil { + g.Warning( + "failed to parse metadata field", + "sqlite-json-unmarshal-error", + "userID", user.uuid.String(), + "error", err, + ) + } + } + + return user, nil } - return fn, writeStmt.Close, nil + return fn, readStmt.Close, nil } func loginSQL(prefix string) queryT { - const tmpl_write = ` - -- INSERT INTO "%s" (t3, t4) VALUES (?, ?); + const tmpl_session = ` + INSERT INTO "%s_sessions" (uuid, user_id) + VALUES ( + ?, + (SELECT id FROM "%s_users" WHERE uuid = ?) + ) RETURNING id, timestamp; ` return queryT{ - write: fmt.Sprintf(tmpl_write, prefix), + session: fmt.Sprintf(tmpl_session, prefix, prefix), } } func loginStmt( db *sql.DB, prefix string, -) (func(string, string) (sessionT, error), func() error, error) { +) (func(guuid.UUID, guuid.UUID) (sessionT, error), func() error, error) { q := loginSQL(prefix) - writeStmt, err := db.Prepare(q.write) + sessionStmt, err := db.Prepare(q.session) if err != nil { return nil, nil, err } - fn := func(email string, pwhash string) (sessionT, error) { - var session sessionT - err := writeStmt.QueryRow(email, pwhash).Scan(session) - // FIXME: finish + fn := func(userID guuid.UUID, sessionID guuid.UUID) (sessionT, error) { + session := sessionT{ + uuid: sessionID, + userID: userID, + } + + user_id_bytes := userID[:] + session_id_bytes := sessionID[:] + var timestr string + err := sessionStmt.QueryRow( + session_id_bytes, + user_id_bytes, + ).Scan( + &session.id, + ×tr, + ) + if err != nil { + return sessionT{}, err + } + + session.timestamp, err = time.Parse(time.RFC3339Nano, timestr) + if err != nil { + return sessionT{}, err + } + return session, err } - return fn, writeStmt.Close, nil + return fn, sessionStmt.Close, nil } func refreshSQL(prefix string) queryT { const tmpl_write = ` - -- INSERT SOMETHING %s + INSERT INTO "%s_sessions" (uuid, user_id) + VALUES ( + ?, + (SELECT user_id FROM "%s_sessions" WHERE uuid = ?) + ) RETURNING id, timestamp, ( + SELECT "%s_users".uuid FROM "%s_users" + JOIN "%s_sessions" ON + "%s_users".id = "%s_sessions".user_id + WHERE "%s_sessions".uuid = ? + ) AS userID; ` return queryT{ - write: fmt.Sprintf(tmpl_write, prefix), + write: fmt.Sprintf( + tmpl_write, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + ), } } func refreshStmt( db *sql.DB, prefix string, -) (func(guuid.UUID) (sessionT, error), func() error, error) { +) (func(guuid.UUID, guuid.UUID) (sessionT, error), func() error, error) { q := refreshSQL(prefix) writeStmt, err := db.Prepare(q.write) @@ -483,9 +679,35 @@ func refreshStmt( return nil, nil, err } - fn := func(uuid guuid.UUID) (sessionT, error) { - var session sessionT - err := writeStmt.QueryRow(uuid).Scan(&session) + fn := func( + sessionID guuid.UUID, + newSessionID guuid.UUID, + ) (sessionT, error) { + session := sessionT{ + uuid: newSessionID, + } + + session_id_bytes := sessionID[:] + new_session_id_bytes := newSessionID[:] + var ( + timestr string + user_id_bytes []byte + ) + err := writeStmt.QueryRow( + new_session_id_bytes, + session_id_bytes, + session_id_bytes, + ).Scan(&session.id, ×tr, &user_id_bytes) + if err != nil { + return sessionT{}, err + } + session.userID = guuid.UUID(user_id_bytes) + + session.timestamp, err = time.Parse(time.RFC3339Nano, timestr) + if err != nil { + return sessionT{}, err + } + return session, err } @@ -570,9 +792,9 @@ func byUUIDStmt( return nil, nil, err } - fn := func(uuid guuid.UUID) (sessionT, error) { + fn := func(sessionID guuid.UUID) (sessionT, error) { var session sessionT - err := readStmt.QueryRow(uuid).Scan(&session) + err := readStmt.QueryRow(sessionID).Scan(&session) return session, err } @@ -599,8 +821,8 @@ func logoutStmt( return nil, nil, err } - fn := func(uuid guuid.UUID) error { - _, err := writeStmt.Exec(uuid) + fn := func(sessionID guuid.UUID) error { + _, err := writeStmt.Exec(sessionID) return err } @@ -627,8 +849,8 @@ func outOthersStmt( return nil, nil, err } - fn := func(uuid guuid.UUID) error { - _, err := writeStmt.Exec(uuid) + fn := func(sessionID guuid.UUID) error { + _, err := writeStmt.Exec(sessionID) return err } @@ -655,8 +877,8 @@ func outAllStmt( return nil, nil, err } - fn := func(uuid guuid.UUID) error { - _, err := writeStmt.Exec(uuid) + fn := func(sessionID guuid.UUID) error { + _, err := writeStmt.Exec(sessionID) return err } @@ -668,10 +890,10 @@ func initDB( prefix string, ) (queriesT, error) { createTablesErr := createTables(db, prefix) - byEmail, byEmailClose, byEmailErr := byEmailStmt(db, prefix) - byToken, byTokenClose, byTokenErr := byTokenStmt(db, prefix) register, registerClose, registerErr := registerStmt(db, prefix) + sendToken, sendTokenClose, sendTokenErr := sendTokenStmt(db, prefix) confirm, confirmClose, confirmErr := confirmStmt(db, prefix) + byEmail, byEmailClose, byEmailErr := byEmailStmt(db, prefix) login, loginClose, loginErr := loginStmt(db, prefix) refresh, refreshClose, refreshErr := refreshStmt(db, prefix) reset, resetClose, resetErr := resetStmt(db, prefix) @@ -683,10 +905,10 @@ func initDB( err := g.SomeError( createTablesErr, - byEmailErr, - byTokenErr, registerErr, + sendTokenErr, confirmErr, + byEmailErr, loginErr, refreshErr, resetErr, @@ -702,10 +924,10 @@ func initDB( close := func() error { return g.SomeFnError( - byEmailClose, - byTokenClose, registerClose, + sendTokenClose, confirmClose, + byEmailClose, loginClose, refreshClose, resetClose, @@ -717,21 +939,78 @@ func initDB( ) } - // FIXME: lock + var connMutex sync.RWMutex return queriesT{ - byEmail: byEmail, - byToken: byToken, - register: register, - confirm: confirm, - login: login, - refresh: refresh, - reset: reset, - change: change, - byUUID: byUUID, - logout: logout, - outOthers: outOthers, - outAll: outAll, - close: close, + register: func( + a guuid.UUID, + b string, + c []byte, + d []byte, + ) (userT, error) { + connMutex.RLock() + defer connMutex.RUnlock() + return register(a, b, c, d) + }, + sendToken: func(a guuid.UUID, b string) error { + connMutex.RLock() + defer connMutex.RUnlock() + return sendToken(a, b) + }, + confirm: func(a string, b guuid.UUID) (sessionT, error) { + connMutex.RLock() + defer connMutex.RUnlock() + return confirm(a, b) + }, + byEmail: func(a string) (userT, error) { + connMutex.RLock() + defer connMutex.RUnlock() + return byEmail(a) + }, + login: func(a guuid.UUID, b guuid.UUID) (sessionT, error) { + connMutex.RLock() + defer connMutex.RUnlock() + return login(a, b) + }, + refresh: func(a guuid.UUID, b guuid.UUID) (sessionT, error) { + connMutex.RLock() + defer connMutex.RUnlock() + return refresh(a, b) + }, + reset: func(a int64, b []byte, c guuid.UUID) (sessionT, error) { + connMutex.RLock() + defer connMutex.RUnlock() + return reset(a, b, c) + }, + change: func(a int64, b []byte) (sessionT, error) { + connMutex.RLock() + defer connMutex.RUnlock() + return change(a, b) + }, + byUUID: func(a guuid.UUID) (sessionT, error) { + connMutex.RLock() + defer connMutex.RUnlock() + return byUUID(a) + }, + logout: func(a guuid.UUID) error { + connMutex.RLock() + defer connMutex.RUnlock() + return logout(a) + }, + outOthers: func(a guuid.UUID) error { + connMutex.RLock() + defer connMutex.RUnlock() + return outOthers(a) + }, + outAll: func(a guuid.UUID) error { + connMutex.RLock() + defer connMutex.RUnlock() + return outAll(a) + }, + close: func() error { + connMutex.Lock() + defer connMutex.Unlock() + return close() + }, }, nil } @@ -767,13 +1046,24 @@ var consumers = []consumerT{ handlerFn: forgotPasswordRequestHandler, }, } -func registerConsumers(auth authT, consumers []consumerT) { +func registerConsumers(auth authT, consumers []consumerT, prefix string) error { for _, consumer := range consumers { - auth.queue.Subscribe( + err := auth.queue.Subscribe( consumer.topic, - defaultPrefix + "-" + consumer.topic, + prefix + "-" + consumer.topic, consumer.handlerFn(auth), ) + if err != nil { + return err + } + } + + return nil +} + +func unregisterConsumers(queue q.IQueue, consumers []consumerT, prefix string) { + for _, consumer := range consumers { + queue.Unsubscribe(consumer.topic, prefix + "-" + consumer.topic) } } @@ -861,28 +1151,41 @@ func asResult[A any, B any](fn func(A) (B, error)) func(A) resultT[B] { } } +func unwrapResult[A any, B any](fn func(A) resultT[B]) func(A) (B, error) { + return func(input A) (B, error) { + result := fn(input) + return result.value, result.err + } +} + func NewWithPrefix(db *sql.DB, queue q.IQueue, prefix string) (IAuth, error) { queries, err := initDB(db, prefix) if err != nil { return authT{}, err } - numCPU := (runtime.NumCPU() / 2) + 1 - hasher, closeHasher := makePoolRunner(numCPU, asResult(scrypt.Hash)) - checker, closeChecker := makePoolRunner(numCPU, asResult(scrypt.Check)) + numCPU := runtime.NumCPU() + hasher, closeHasher := makePoolRunner(numCPU, asResult(scrypt.Hash)) - close := func() { + closeFn := func() { + unregisterConsumers(queue, consumers, prefix) closeHasher() - closeChecker() } - return authT{ + auth := authT{ queries: queries, queue: queue, - hasher: hasher, - checker: checker, - close: close, - }, nil + hasher: unwrapResult(hasher), + close: closeFn, + } + + err = registerConsumers(auth, consumers, prefix) + if err != nil { + closeFn() + return authT{}, err + } + + return auth, nil } func New(db *sql.DB, queue q.IQueue) (IAuth, error) { @@ -926,7 +1229,10 @@ func (auth authT) Register( Password: []byte(password), Salt: salt, } - result := auth.hasher(input) + hash, err := auth.hasher(input) + if err != nil { + return userT{}, err + } /* We also try to register anyway, to prevent disk IO timing attacks. @@ -937,7 +1243,7 @@ func (auth authT) Register( waiter := auth.queue.WaitFor(NEW_USER, flowID, "register") defer waiter.Close() - payload, err := newUserPayload(email, salt, result.value) + payload, err := newUserPayload(email, salt, hash) if err != nil { return userT{}, err } @@ -987,13 +1293,12 @@ func (auth authT) ResendConfirmation(email string) error { return err } - _, err = auth.queue.Publish(unsent) return err } -func (auth authT) ConfirmEmail(token guuid.UUID) (sessionT, error) { - return auth.queries.confirm(token) +func (auth authT) ConfirmEmail(token string) (sessionT, error) { + return auth.queries.confirm("token FIXME", guuid.New()) } func (auth authT) LoginEmail( @@ -1005,34 +1310,36 @@ func (auth authT) LoginEmail( } // special check for sql.ErrNoRows to combat enumeration attacks. + // FIXME: how so? user, err := auth.queries.byEmail(email) if err != nil && err != sql.ErrNoRows { return sessionT{}, err } - input := scrypt.CheckInput{ + input := scrypt.HashInput{ Password: []byte(password), Salt: user.salt, - Hash: user.pwhash, } - ok, err := scrypt.Check(input) + hash, err := auth.hasher(input) if err != nil { return sessionT{}, err } + + ok := slices.Equal(hash, user.pwhash) if !ok { return sessionT{}, ErrBadCombo } - if user.confirmed_at == nil { + if !user.confirmed { return sessionT{}, ErrUnconfirmed } - dbSession, err := auth.queries.login(email, password) + session, err := auth.queries.login(user.uuid, guuid.New()) if err != nil { return sessionT{}, err } - return dbSession, nil + return session, nil } func forgotPasswordMessage( @@ -1099,7 +1406,7 @@ func (auth authT) Refresh(session sessionT) (sessionT, error) { return sessionT{}, err } - return auth.queries.refresh(session.uuid) + return auth.queries.refresh(session.uuid, guuid.New()) } func (auth authT) ResetPassword( @@ -1115,25 +1422,22 @@ func (auth authT) ResetPassword( return sessionT{}, ErrTooShort } - user, err := auth.queries.byToken(token) - if err != nil { - return sessionT{}, err - } + user := userT{} input := scrypt.HashInput{ Password: []byte(password), Salt: user.salt, } - pwhash, err := scrypt.Hash(input) + pwhash, err := auth.hasher(input) if err != nil { return sessionT{}, err } - if user.confirmed_at != nil { + if user.confirmed { return auth.queries.reset(user.id, pwhash, token) } else { // return auth.queries.confirm(user.id, pwhash, token) - return auth.queries.confirm(token) + return auth.queries.confirm("token FIXME", guuid.New()) } } @@ -1151,17 +1455,16 @@ func (auth authT) ChangePassword( return sessionT{}, ErrTooShort } - // FIXME input := scrypt.HashInput{ Password: []byte(newPassword), Salt: user.salt, } - pwhash, err := scrypt.Hash(input) + pwhash, err := auth.hasher(input) if err != nil { return sessionT{}, err } - if user.confirmed_at == nil { + if !user.confirmed { return sessionT{}, ErrUnconfirmed } @@ -1212,32 +1515,5 @@ func (auth authT) Close() error { func Main() { - g.Init() - db, err := sql.Open("acude", "file:gracha.db?mode=memory&cache=shared") - if err != nil { - panic(err) - } - defer db.Close() - - queue, err := q.New(db) - if err != nil { - panic(err) - } - defer queue.Close() - - auth, err := New(db, queue) - if err != nil { - fmt.Println(err) - panic(err) - } - - user, err := auth.Register("contact@example.com", "password", "password") - if false { - fmt.Printf("user: %#v\n", user) - fmt.Printf("err: %#v\n", err) - } - - return - fmt.Printf("q: %#v\n", queue) - fmt.Printf("auth: %#v\n", auth) + // FIXME } |