summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/gracha.go209
1 files changed, 144 insertions, 65 deletions
diff --git a/src/gracha.go b/src/gracha.go
index e41c20b..8800ee6 100644
--- a/src/gracha.go
+++ b/src/gracha.go
@@ -61,7 +61,7 @@ type queriesT struct{
login func(guuid.UUID, guuid.UUID) (sessionT, error)
refresh func(guuid.UUID, guuid.UUID) (sessionT, error)
reset func(int64, []byte, guuid.UUID) (sessionT, error)
- change func(int64, []byte) (sessionT, error)
+ change func(guuid.UUID, []byte) (sessionT, error)
byUUID func(guuid.UUID) (sessionT, error)
logout func(guuid.UUID) error
outOthers func(guuid.UUID) error
@@ -79,6 +79,12 @@ type userT struct{
confirmed bool
}
+type User struct{
+ UUID guuid.UUID
+ Salt []byte
+ Confirmed bool
+}
+
type sessionT struct{
id int64
timestamp time.Time
@@ -88,6 +94,10 @@ type sessionT struct{
revoked_at *time.Time
}
+type Session struct{
+ UUID guuid.UUID
+}
+
type consumerT struct{
topic string
handlerFn func(authT) func(q.Message) error
@@ -101,22 +111,27 @@ type authT struct{
}
type IAuth interface{
- Register(string, string, string) (userT, error)
+ Register(string, string, string) (User, error)
ResendConfirmation(string) error
- ConfirmEmail(string) (sessionT, error)
- LoginEmail(string, string) (sessionT, error)
+ ConfirmEmail(string) (Session, error)
+ LoginEmail(string, string) (Session, error)
ForgotPassword(string) error
- Refresh(sessionT) (sessionT, error)
- ResetPassword(guuid.UUID, string, string) (sessionT, error)
- ChangePassword(userT, string, string, string) (sessionT, error)
- Logout(sessionT) error
- LogoutOthers(sessionT) error
- LogoutAll(sessionT) error
+ Refresh(Session) (Session, error)
+ ResetPassword(guuid.UUID, string, string) (Session, error)
+ ChangePassword(User, string, string, string) (Session, error)
+ Logout(Session) error
+ LogoutOthers(Session) error
+ LogoutAll(Session) error
Close() error
}
+func validateSession(session sessionT) (sessionT, error) {
+ // FIXME: implement
+ return session, nil
+}
+
func tryRollback(db *sql.DB, ctx context.Context, err error) error {
_, rollbackErr := db.ExecContext(ctx, "ROLLBACK;")
if rollbackErr != nil {
@@ -482,7 +497,7 @@ func confirmStmt(
return sessionT{}, err
}
- return session, nil
+ return validateSession(session)
}
closeFn := func() error {
@@ -610,7 +625,7 @@ func loginStmt(
return sessionT{}, err
}
- return session, err
+ return validateSession(session)
}
return fn, sessionStmt.Close, nil
@@ -684,7 +699,7 @@ func refreshStmt(
return sessionT{}, err
}
- return session, err
+ return validateSession(session)
}
return fn, writeStmt.Close, nil
@@ -713,7 +728,10 @@ func resetStmt(
fn := func(id int64, pwhash []byte, token guuid.UUID) (sessionT, error) {
var session sessionT
err := writeStmt.QueryRow(id, pwhash, token).Scan(&session)
- return session, err
+ if err != nil {
+ return sessionT{}, err
+ }
+ return validateSession(session)
}
return fn, writeStmt.Close, nil
@@ -731,7 +749,7 @@ func changeSQL(prefix string) queryT {
func changeStmt(
db *sql.DB,
prefix string,
-) (func(int64, []byte) (sessionT, error), func() error, error) {
+) (func(guuid.UUID, []byte) (sessionT, error), func() error, error) {
q := changeSQL(prefix)
writeStmt, err := db.Prepare(q.write)
@@ -739,10 +757,13 @@ func changeStmt(
return nil, nil, err
}
- fn := func(id int64, pwhash []byte) (sessionT, error) {
+ fn := func(uuid guuid.UUID, pwhash []byte) (sessionT, error) {
var session sessionT
- err := writeStmt.QueryRow(id, pwhash).Scan(&session)
- return session, err
+ err := writeStmt.QueryRow(uuid, pwhash).Scan(&session)
+ if err != nil {
+ return sessionT{}, err
+ }
+ return validateSession(session)
}
return fn, writeStmt.Close, nil
@@ -771,7 +792,10 @@ func byUUIDStmt(
fn := func(sessionID guuid.UUID) (sessionT, error) {
var session sessionT
err := readStmt.QueryRow(sessionID).Scan(&session)
- return session, err
+ if err != nil {
+ return sessionT{}, err
+ }
+ return validateSession(session)
}
return fn, readStmt.Close, nil
@@ -957,7 +981,7 @@ func initDB(
defer connMutex.RUnlock()
return reset(a, b, c)
},
- change: func(a int64, b []byte) (sessionT, error) {
+ change: func(a guuid.UUID, b []byte) (sessionT, error) {
connMutex.RLock()
defer connMutex.RUnlock()
return change(a, b)
@@ -1176,29 +1200,33 @@ func newUserPayload(email string, salt []byte, pwhash []byte) ([]byte, error) {
return json.Marshal(data)
}
+func asPublicUser(user userT) (User, error) {
+ return User{}, nil
+}
+
func (auth authT) Register(
email string,
password string,
confirmPassword string,
-) (userT, error) {
+) (User, error) {
if password != confirmPassword {
- return userT{}, ErrPassMismatch
+ return User{}, ErrPassMismatch
}
if len(password) < scrypt.MinimumPasswordLength {
- return userT{}, ErrTooShort
+ return User{}, ErrTooShort
}
// special check for sql.ErrNoRows to combat enumeration attacks.
// FIXME: how so?
_, lookupErr := auth.queries.byEmail(email)
if lookupErr != nil && lookupErr != sql.ErrNoRows {
- return userT{}, lookupErr
+ return User{}, lookupErr
}
salt, err := scrypt.Salt()
if err != nil {
- return userT{}, err
+ return User{}, err
}
input := scrypt.HashInput{
@@ -1207,7 +1235,7 @@ func (auth authT) Register(
}
hash, err := auth.hasher(input)
if err != nil {
- return userT{}, err
+ return User{}, err
}
/*
@@ -1221,7 +1249,7 @@ func (auth authT) Register(
payload, err := newUserPayload(email, salt, hash)
if err != nil {
- return userT{}, err
+ return User{}, err
}
unsent := q.UnsentMessage{
@@ -1231,7 +1259,7 @@ func (auth authT) Register(
}
_, err = auth.queue.Publish(unsent)
if err != nil {
- return userT{}, err
+ return User{}, err
}
user := userT{}
@@ -1242,7 +1270,7 @@ func (auth authT) Register(
user, err = auth.queries.byEmail(email)
}
- return user, nil
+ return asPublicUser(user)
}
func sendConfirmationMessage(
@@ -1273,23 +1301,45 @@ func (auth authT) ResendConfirmation(email string) error {
return err
}
-func (auth authT) ConfirmEmail(token string) (sessionT, error) {
- return auth.queries.confirm("token FIXME", guuid.New())
+/*
+func asPublicSession(session sessionT) Session {
+ _, err := validateSession(session)
+ if err != nil {
+ panic(fmt.Errorf(
+ "session must have been validated at this stage: %w",
+ err,
+ ))
+ }
+ return Session{}
+}
+*/
+
+func asPublicSession(session sessionT) (Session, error) {
+ return Session{}, nil
+}
+
+func (auth authT) ConfirmEmail(token string) (Session, error) {
+ session, err := auth.queries.confirm("token FIXME", guuid.New())
+ if err != nil {
+ return Session{}, err
+ }
+
+ return asPublicSession(session)
}
func (auth authT) LoginEmail(
email string,
password string,
-) (sessionT, error) {
+) (Session, error) {
if len(password) < scrypt.MinimumPasswordLength {
- return sessionT{}, ErrTooShort
+ return Session{}, ErrTooShort
}
// special check for sql.ErrNoRows to combat enumeration attacks.
// FIXME: how so?
user, err := auth.queries.byEmail(email)
if err != nil && err != sql.ErrNoRows {
- return sessionT{}, err
+ return Session{}, err
}
input := scrypt.HashInput{
@@ -1298,24 +1348,24 @@ func (auth authT) LoginEmail(
}
hash, err := auth.hasher(input)
if err != nil {
- return sessionT{}, err
+ return Session{}, err
}
ok := slices.Equal(hash, user.pwhash)
if !ok {
- return sessionT{}, ErrBadCombo
+ return Session{}, ErrBadCombo
}
if !user.confirmed {
- return sessionT{}, ErrUnconfirmed
+ return Session{}, ErrUnconfirmed
}
session, err := auth.queries.login(user.uuid, guuid.New())
if err != nil {
- return sessionT{}, err
+ return Session{}, err
}
- return session, nil
+ return asPublicSession(session)
}
func forgotPasswordMessage(
@@ -1352,6 +1402,7 @@ func (auth authT) ForgotPassword(email string) error {
return err
}
+/*
func checkSession(session sessionT, now time.Time) error {
if session.revoked_at != nil {
return ErrRevokedSession
@@ -1375,27 +1426,36 @@ func validateSession(
return checkSession(dbSession, time.Now())
}
+*/
-func (auth authT) Refresh(session sessionT) (sessionT, error) {
+func (auth authT) Refresh(session Session) (Session, error) {
+ /*
err := validateSession(auth.queries.byUUID, session)
if err != nil {
return sessionT{}, err
}
+ */
- return auth.queries.refresh(session.uuid, guuid.New())
+ // return auth.queries.refresh(session.uuid, guuid.New())
+ newSession, err := auth.queries.refresh(session.UUID, guuid.New())
+ if err != nil {
+ return Session{}, err
+ }
+
+ return asPublicSession(newSession)
}
func (auth authT) ResetPassword(
token guuid.UUID,
password string,
confirmPassword string,
-) (sessionT, error) {
+) (Session, error) {
if password != confirmPassword {
- return sessionT{}, ErrPassMismatch
+ return Session{}, ErrPassMismatch
}
if len(password) < scrypt.MinimumPasswordLength {
- return sessionT{}, ErrTooShort
+ return Session{}, ErrTooShort
}
user := userT{}
@@ -1406,80 +1466,99 @@ func (auth authT) ResetPassword(
}
pwhash, err := auth.hasher(input)
if err != nil {
- return sessionT{}, err
+ return Session{}, err
}
+ var nextFn func() (sessionT, error)
if user.confirmed {
- return auth.queries.reset(user.id, pwhash, token)
+ nextFn = func() (sessionT, error) {
+ return auth.queries.reset(user.id, pwhash, token)
+ }
} else {
- // return auth.queries.confirm(user.id, pwhash, token)
- return auth.queries.confirm("token FIXME", guuid.New())
+ nextFn = func() (sessionT, error) {
+ // return auth.queries.confirm(user.id, pwhash, token)
+ return auth.queries.confirm("token FIXME", guuid.New())
+ }
}
+
+ session, err := nextFn()
+ if err != nil {
+ return Session{}, err
+ }
+
+ return asPublicSession(session)
}
func (auth authT) ChangePassword(
- user userT,
+ user User,
currentPassword string,
newPassword string,
confirmNewPassword string,
-) (sessionT, error) {
+) (Session, error) {
if newPassword != confirmNewPassword {
- return sessionT{}, ErrPassMismatch
+ return Session{}, ErrPassMismatch
}
if len(newPassword) < scrypt.MinimumPasswordLength {
- return sessionT{}, ErrTooShort
+ return Session{}, ErrTooShort
}
input := scrypt.HashInput{
Password: []byte(newPassword),
- Salt: user.salt,
+ Salt: user.Salt,
}
pwhash, err := auth.hasher(input)
if err != nil {
- return sessionT{}, err
+ return Session{}, err
}
- if !user.confirmed {
- return sessionT{}, ErrUnconfirmed
+ if !user.Confirmed {
+ return Session{}, ErrUnconfirmed
}
- return auth.queries.change(user.id, pwhash)
+ session, err := auth.queries.change(user.UUID, pwhash)
+ if err != nil {
+ return Session{}, nil
+ }
+
+ return asPublicSession(session)
}
func runLogout(
lookupFn func(guuid.UUID) (sessionT, error),
- session sessionT,
+ sessionID guuid.UUID,
queryFn func(guuid.UUID) error,
) error {
+ /*
err := validateSession(lookupFn, session)
if err != nil {
return err
}
+ */
- return queryFn(session.uuid)
+ return queryFn(sessionID)
}
-func (auth authT) Logout(session sessionT) error {
+func (auth authT) Logout(session Session) error {
return runLogout(
auth.queries.byUUID,
- session,
+ session.UUID,
auth.queries.logout,
)
}
-func (auth authT) LogoutOthers(session sessionT) error {
+func (auth authT) LogoutOthers(session Session) error {
return runLogout(
auth.queries.byUUID,
- session,
+ session.UUID,
auth.queries.outOthers,
)
}
-func (auth authT) LogoutAll(session sessionT) error {
+func (auth authT) LogoutAll(session Session) error {
return runLogout(
auth.queries.byUUID,
- session,
+ session.UUID,
auth.queries.outAll,
)
}