diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/gracha.go | 209 |
1 files changed, 144 insertions, 65 deletions
diff --git a/src/gracha.go b/src/gracha.go index e41c20b..8800ee6 100644 --- a/src/gracha.go +++ b/src/gracha.go @@ -61,7 +61,7 @@ type queriesT struct{ login func(guuid.UUID, guuid.UUID) (sessionT, error) refresh func(guuid.UUID, guuid.UUID) (sessionT, error) reset func(int64, []byte, guuid.UUID) (sessionT, error) - change func(int64, []byte) (sessionT, error) + change func(guuid.UUID, []byte) (sessionT, error) byUUID func(guuid.UUID) (sessionT, error) logout func(guuid.UUID) error outOthers func(guuid.UUID) error @@ -79,6 +79,12 @@ type userT struct{ confirmed bool } +type User struct{ + UUID guuid.UUID + Salt []byte + Confirmed bool +} + type sessionT struct{ id int64 timestamp time.Time @@ -88,6 +94,10 @@ type sessionT struct{ revoked_at *time.Time } +type Session struct{ + UUID guuid.UUID +} + type consumerT struct{ topic string handlerFn func(authT) func(q.Message) error @@ -101,22 +111,27 @@ type authT struct{ } type IAuth interface{ - Register(string, string, string) (userT, error) + Register(string, string, string) (User, error) ResendConfirmation(string) error - ConfirmEmail(string) (sessionT, error) - LoginEmail(string, string) (sessionT, error) + ConfirmEmail(string) (Session, error) + LoginEmail(string, string) (Session, error) ForgotPassword(string) error - Refresh(sessionT) (sessionT, error) - ResetPassword(guuid.UUID, string, string) (sessionT, error) - ChangePassword(userT, string, string, string) (sessionT, error) - Logout(sessionT) error - LogoutOthers(sessionT) error - LogoutAll(sessionT) error + Refresh(Session) (Session, error) + ResetPassword(guuid.UUID, string, string) (Session, error) + ChangePassword(User, string, string, string) (Session, error) + Logout(Session) error + LogoutOthers(Session) error + LogoutAll(Session) error Close() error } +func validateSession(session sessionT) (sessionT, error) { + // FIXME: implement + return session, nil +} + func tryRollback(db *sql.DB, ctx context.Context, err error) error { _, rollbackErr := db.ExecContext(ctx, "ROLLBACK;") if rollbackErr != nil { @@ -482,7 +497,7 @@ func confirmStmt( return sessionT{}, err } - return session, nil + return validateSession(session) } closeFn := func() error { @@ -610,7 +625,7 @@ func loginStmt( return sessionT{}, err } - return session, err + return validateSession(session) } return fn, sessionStmt.Close, nil @@ -684,7 +699,7 @@ func refreshStmt( return sessionT{}, err } - return session, err + return validateSession(session) } return fn, writeStmt.Close, nil @@ -713,7 +728,10 @@ func resetStmt( fn := func(id int64, pwhash []byte, token guuid.UUID) (sessionT, error) { var session sessionT err := writeStmt.QueryRow(id, pwhash, token).Scan(&session) - return session, err + if err != nil { + return sessionT{}, err + } + return validateSession(session) } return fn, writeStmt.Close, nil @@ -731,7 +749,7 @@ func changeSQL(prefix string) queryT { func changeStmt( db *sql.DB, prefix string, -) (func(int64, []byte) (sessionT, error), func() error, error) { +) (func(guuid.UUID, []byte) (sessionT, error), func() error, error) { q := changeSQL(prefix) writeStmt, err := db.Prepare(q.write) @@ -739,10 +757,13 @@ func changeStmt( return nil, nil, err } - fn := func(id int64, pwhash []byte) (sessionT, error) { + fn := func(uuid guuid.UUID, pwhash []byte) (sessionT, error) { var session sessionT - err := writeStmt.QueryRow(id, pwhash).Scan(&session) - return session, err + err := writeStmt.QueryRow(uuid, pwhash).Scan(&session) + if err != nil { + return sessionT{}, err + } + return validateSession(session) } return fn, writeStmt.Close, nil @@ -771,7 +792,10 @@ func byUUIDStmt( fn := func(sessionID guuid.UUID) (sessionT, error) { var session sessionT err := readStmt.QueryRow(sessionID).Scan(&session) - return session, err + if err != nil { + return sessionT{}, err + } + return validateSession(session) } return fn, readStmt.Close, nil @@ -957,7 +981,7 @@ func initDB( defer connMutex.RUnlock() return reset(a, b, c) }, - change: func(a int64, b []byte) (sessionT, error) { + change: func(a guuid.UUID, b []byte) (sessionT, error) { connMutex.RLock() defer connMutex.RUnlock() return change(a, b) @@ -1176,29 +1200,33 @@ func newUserPayload(email string, salt []byte, pwhash []byte) ([]byte, error) { return json.Marshal(data) } +func asPublicUser(user userT) (User, error) { + return User{}, nil +} + func (auth authT) Register( email string, password string, confirmPassword string, -) (userT, error) { +) (User, error) { if password != confirmPassword { - return userT{}, ErrPassMismatch + return User{}, ErrPassMismatch } if len(password) < scrypt.MinimumPasswordLength { - return userT{}, ErrTooShort + return User{}, ErrTooShort } // special check for sql.ErrNoRows to combat enumeration attacks. // FIXME: how so? _, lookupErr := auth.queries.byEmail(email) if lookupErr != nil && lookupErr != sql.ErrNoRows { - return userT{}, lookupErr + return User{}, lookupErr } salt, err := scrypt.Salt() if err != nil { - return userT{}, err + return User{}, err } input := scrypt.HashInput{ @@ -1207,7 +1235,7 @@ func (auth authT) Register( } hash, err := auth.hasher(input) if err != nil { - return userT{}, err + return User{}, err } /* @@ -1221,7 +1249,7 @@ func (auth authT) Register( payload, err := newUserPayload(email, salt, hash) if err != nil { - return userT{}, err + return User{}, err } unsent := q.UnsentMessage{ @@ -1231,7 +1259,7 @@ func (auth authT) Register( } _, err = auth.queue.Publish(unsent) if err != nil { - return userT{}, err + return User{}, err } user := userT{} @@ -1242,7 +1270,7 @@ func (auth authT) Register( user, err = auth.queries.byEmail(email) } - return user, nil + return asPublicUser(user) } func sendConfirmationMessage( @@ -1273,23 +1301,45 @@ func (auth authT) ResendConfirmation(email string) error { return err } -func (auth authT) ConfirmEmail(token string) (sessionT, error) { - return auth.queries.confirm("token FIXME", guuid.New()) +/* +func asPublicSession(session sessionT) Session { + _, err := validateSession(session) + if err != nil { + panic(fmt.Errorf( + "session must have been validated at this stage: %w", + err, + )) + } + return Session{} +} +*/ + +func asPublicSession(session sessionT) (Session, error) { + return Session{}, nil +} + +func (auth authT) ConfirmEmail(token string) (Session, error) { + session, err := auth.queries.confirm("token FIXME", guuid.New()) + if err != nil { + return Session{}, err + } + + return asPublicSession(session) } func (auth authT) LoginEmail( email string, password string, -) (sessionT, error) { +) (Session, error) { if len(password) < scrypt.MinimumPasswordLength { - return sessionT{}, ErrTooShort + return Session{}, ErrTooShort } // special check for sql.ErrNoRows to combat enumeration attacks. // FIXME: how so? user, err := auth.queries.byEmail(email) if err != nil && err != sql.ErrNoRows { - return sessionT{}, err + return Session{}, err } input := scrypt.HashInput{ @@ -1298,24 +1348,24 @@ func (auth authT) LoginEmail( } hash, err := auth.hasher(input) if err != nil { - return sessionT{}, err + return Session{}, err } ok := slices.Equal(hash, user.pwhash) if !ok { - return sessionT{}, ErrBadCombo + return Session{}, ErrBadCombo } if !user.confirmed { - return sessionT{}, ErrUnconfirmed + return Session{}, ErrUnconfirmed } session, err := auth.queries.login(user.uuid, guuid.New()) if err != nil { - return sessionT{}, err + return Session{}, err } - return session, nil + return asPublicSession(session) } func forgotPasswordMessage( @@ -1352,6 +1402,7 @@ func (auth authT) ForgotPassword(email string) error { return err } +/* func checkSession(session sessionT, now time.Time) error { if session.revoked_at != nil { return ErrRevokedSession @@ -1375,27 +1426,36 @@ func validateSession( return checkSession(dbSession, time.Now()) } +*/ -func (auth authT) Refresh(session sessionT) (sessionT, error) { +func (auth authT) Refresh(session Session) (Session, error) { + /* err := validateSession(auth.queries.byUUID, session) if err != nil { return sessionT{}, err } + */ - return auth.queries.refresh(session.uuid, guuid.New()) + // return auth.queries.refresh(session.uuid, guuid.New()) + newSession, err := auth.queries.refresh(session.UUID, guuid.New()) + if err != nil { + return Session{}, err + } + + return asPublicSession(newSession) } func (auth authT) ResetPassword( token guuid.UUID, password string, confirmPassword string, -) (sessionT, error) { +) (Session, error) { if password != confirmPassword { - return sessionT{}, ErrPassMismatch + return Session{}, ErrPassMismatch } if len(password) < scrypt.MinimumPasswordLength { - return sessionT{}, ErrTooShort + return Session{}, ErrTooShort } user := userT{} @@ -1406,80 +1466,99 @@ func (auth authT) ResetPassword( } pwhash, err := auth.hasher(input) if err != nil { - return sessionT{}, err + return Session{}, err } + var nextFn func() (sessionT, error) if user.confirmed { - return auth.queries.reset(user.id, pwhash, token) + nextFn = func() (sessionT, error) { + return auth.queries.reset(user.id, pwhash, token) + } } else { - // return auth.queries.confirm(user.id, pwhash, token) - return auth.queries.confirm("token FIXME", guuid.New()) + nextFn = func() (sessionT, error) { + // return auth.queries.confirm(user.id, pwhash, token) + return auth.queries.confirm("token FIXME", guuid.New()) + } } + + session, err := nextFn() + if err != nil { + return Session{}, err + } + + return asPublicSession(session) } func (auth authT) ChangePassword( - user userT, + user User, currentPassword string, newPassword string, confirmNewPassword string, -) (sessionT, error) { +) (Session, error) { if newPassword != confirmNewPassword { - return sessionT{}, ErrPassMismatch + return Session{}, ErrPassMismatch } if len(newPassword) < scrypt.MinimumPasswordLength { - return sessionT{}, ErrTooShort + return Session{}, ErrTooShort } input := scrypt.HashInput{ Password: []byte(newPassword), - Salt: user.salt, + Salt: user.Salt, } pwhash, err := auth.hasher(input) if err != nil { - return sessionT{}, err + return Session{}, err } - if !user.confirmed { - return sessionT{}, ErrUnconfirmed + if !user.Confirmed { + return Session{}, ErrUnconfirmed } - return auth.queries.change(user.id, pwhash) + session, err := auth.queries.change(user.UUID, pwhash) + if err != nil { + return Session{}, nil + } + + return asPublicSession(session) } func runLogout( lookupFn func(guuid.UUID) (sessionT, error), - session sessionT, + sessionID guuid.UUID, queryFn func(guuid.UUID) error, ) error { + /* err := validateSession(lookupFn, session) if err != nil { return err } + */ - return queryFn(session.uuid) + return queryFn(sessionID) } -func (auth authT) Logout(session sessionT) error { +func (auth authT) Logout(session Session) error { return runLogout( auth.queries.byUUID, - session, + session.UUID, auth.queries.logout, ) } -func (auth authT) LogoutOthers(session sessionT) error { +func (auth authT) LogoutOthers(session Session) error { return runLogout( auth.queries.byUUID, - session, + session.UUID, auth.queries.outOthers, ) } -func (auth authT) LogoutAll(session sessionT) error { +func (auth authT) LogoutAll(session Session) error { return runLogout( auth.queries.byUUID, - session, + session.UUID, auth.queries.outAll, ) } |