summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/gracha.go720
1 files changed, 498 insertions, 222 deletions
diff --git a/src/gracha.go b/src/gracha.go
index e300f66..c240dd3 100644
--- a/src/gracha.go
+++ b/src/gracha.go
@@ -8,6 +8,7 @@ import (
"errors"
"fmt"
"runtime"
+ "slices"
"sync"
"time"
@@ -27,7 +28,6 @@ const (
SEND_CONFIRMATION_REQUEST = "send-confirmation-request"
FORGOT_PASSWORD_REQUEST = "forgot-password-request"
-
day = 24 * time.Hour
)
@@ -48,17 +48,18 @@ var (
type queryT struct{
- write string
- read string
+ write string
+ read string
+ session string
}
type queriesT struct{
+ register func(guuid.UUID, string, []byte, []byte) (userT, error)
+ sendToken func(guuid.UUID, string) error
+ confirm func(string, guuid.UUID) (sessionT, error)
byEmail func(string) (userT, error)
- byToken func(guuid.UUID) (userT, error)
- register func(string, []byte, []byte) (userT, error)
- confirm func(guuid.UUID) (sessionT, error)
- login func(string, string) (sessionT, error)
- refresh func(guuid.UUID) (sessionT, error)
+ login func(guuid.UUID, guuid.UUID) (sessionT, error)
+ refresh func(guuid.UUID, guuid.UUID) (sessionT, error)
reset func(int64, []byte, guuid.UUID) (sessionT, error)
change func(int64, []byte) (sessionT, error)
byUUID func(guuid.UUID) (sessionT, error)
@@ -68,28 +69,23 @@ type queriesT struct{
close func() error
}
-type confirmationT struct{
-}
-
type userT struct{
id int64
timestamp time.Time
uuid guuid.UUID
email string
- username *string
salt []byte
pwhash []byte
- confirmed_at *time.Time
metadata map[string]interface{}
+ confirmed bool
}
type sessionT struct{
id int64
- timestr string
timestamp time.Time
uuid guuid.UUID
- user_id int64
- type_ string
+ userID guuid.UUID
+ // type_ string
revoked_at *time.Time
metadata map[string]interface{}
}
@@ -102,15 +98,14 @@ type consumerT struct{
type authT struct{
queries queriesT
queue q.IQueue
- hasher func(scrypt.HashInput) resultT[[]byte]
- checker func(scrypt.CheckInput) resultT[bool]
+ hasher func(scrypt.HashInput) ([]byte, error)
close func()
}
type IAuth interface{
Register(string, string, string) (userT, error)
ResendConfirmation(string) error
- ConfirmEmail(guuid.UUID) (sessionT, error)
+ ConfirmEmail(string) (sessionT, error)
LoginEmail(string, string) (sessionT, error)
ForgotPassword(string) error
Refresh(sessionT) (sessionT, error)
@@ -165,13 +160,25 @@ func createTablesSQL(prefix string) queryT {
timestamp TEXT NOT NULL DEFAULT (%s),
uuid BLOB NOT NULL UNIQUE,
email TEXT NOT NULL UNIQUE,
- username TEXT UNIQUE,
salt BLOB NOT NULL UNIQUE,
pwhash BLOB NOT NULL,
- confirmed_at TEXT,
- confirmer_id INTEGER REFERENCES "%s"(id),
metadata TEXT
);
+ CREATE TABLE IF NOT EXISTS "%s_confirmation_attempts" (
+ id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
+ timestamp TEXT NOT NULL DEFAULT (%s),
+ -- uuid BLOB NOT NULL UNIQUE,
+ user_id INTEGER NOT NULL REFERENCES "%s_users"(id),
+ token TEXT NOT NULL UNIQUE
+ );
+ CREATE TABLE IF NOT EXISTS "%s_user_confirmations" (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ timestamp TEXT NOT NULL DEFAULT (%s),
+ user_id INTEGER NOT NULL
+ REFERENCES "%s_users"(id) UNIQUE,
+ attempt_id INTEGER NOT NULL
+ REFERENCES "%s_confirmation_attempts"(id) UNIQUE
+ );
CREATE TABLE IF NOT EXISTS "%s_user_changes" (
id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
timestamp TEXT NOT NULL DEFAULT (%s),
@@ -180,12 +187,12 @@ func createTablesSQL(prefix string) queryT {
value TEXT NOT NULL,
op BOOLEAN NOT NULL
);
- CREATE TABLE IF NOT EXISTS "%s_tokens" (
- id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
- timestamp TEXT NOT NULL DEFAULT (%s),
- uuid BLOB NOT NULL UNIQUE,
- type TEXT NOT NULL
- );
+ -- CREATE TABLE IF NOT EXISTS "%s_tokens" (
+ -- id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
+ -- timestamp TEXT NOT NULL DEFAULT (%s),
+ -- uuid BLOB NOT NULL UNIQUE,
+ -- type TEXT NOT NULL
+ -- );
CREATE TABLE IF NOT EXISTS "%s_roles" (
id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
user_id INTEGER NOT NULL REFERENCES "%s_users"(id),
@@ -205,9 +212,10 @@ func createTablesSQL(prefix string) queryT {
timestamp TEXT NOT NULL DEFAULT (%s),
uuid BLOB NOT NULL UNIQUE,
user_id INTEGER NOT NULL REFERENCES "%s_users"(id),
- type TEXT NOT NULL,
- revoked_at TEXT,
- revoker_id INTEGER REFERENCES "%s_users"(id),
+ -- type TEXT NOT NULL,
+ -- revoked_at TEXT,
+ -- revoker_id INTEGER REFERENCES "%s_users"(id),
+ -- FIXME: add provenance: login, refresh, confirmation, etc.
metadata TEXT
);
CREATE TABLE IF NOT EXISTS "%s_attempts" (
@@ -233,6 +241,12 @@ func createTablesSQL(prefix string) queryT {
prefix,
g.SQLiteNow,
prefix,
+ g.SQLiteNow,
+ prefix,
+ prefix,
+ g.SQLiteNow,
+ prefix,
+ prefix,
prefix,
g.SQLiteNow,
prefix,
@@ -266,50 +280,68 @@ func createTables(db *sql.DB, prefix string) error {
})
}
-func byEmailSQL(prefix string) queryT {
+func registerSQL(prefix string) queryT {
+ const tmpl_write = `
+ INSERT INTO "%s_users" (uuid, email, salt, pwhash)
+ VALUES (?, ?, ?, ?) RETURNING id, timestamp;
+ `
const tmpl_read = `
- SELECT id, timestamp, uuid, email, username, pwhash, metadata
- FROM "%s_users" WHERE email = ?;
+ SELECT id, timestamp from "%s_users"
+ WHERE uuid = ?;
`
return queryT{
- read: fmt.Sprintf(tmpl_read, prefix),
+ write: fmt.Sprintf(tmpl_write, prefix),
+ read: fmt.Sprintf(tmpl_read, prefix),
}
}
-func byEmailStmt(
+func registerStmt(
db *sql.DB,
prefix string,
-) (func(string) (userT, error), func() error, error) {
- q := byEmailSQL(prefix)
+) (
+ func(guuid.UUID, string, []byte, []byte) (userT, error),
+ func() error,
+ error,
+) {
+ q := registerSQL(prefix)
+
+ writeStmt, err := db.Prepare(q.write)
+ if err != nil {
+ return nil, nil, err
+ }
readStmt, err := db.Prepare(q.read)
if err != nil {
return nil, nil, err
}
- fn := func(email string) (userT, error) {
+ fn := func(
+ userID guuid.UUID,
+ email string,
+ salt []byte,
+ pwhash []byte,
+ ) (userT, error) {
user := userT{
- email: email,
+ uuid: userID,
+ email: email,
+ salt: salt,
+ pwhash: pwhash,
+ confirmed: false,
}
- var (
- timestr string
- uuid_bytes []byte
- )
- err := readStmt.QueryRow(email).Scan(
- &user.id,
- &timestr,
- &uuid_bytes,
- &user.username,
- &user.pwhash,
- &user.metadata,
- )
+ var timestr string
+ user_id_bytes := userID[:]
+ err := writeStmt.QueryRow(
+ user_id_bytes,
+ email,
+ salt,
+ pwhash,
+ ).Scan(&user.id, &timestr)
if err != nil {
return userT{}, err
}
- user.uuid = guuid.UUID(uuid_bytes)
- user.timestamp, err = time.Parse(time.RFC3339Nano,timestr)
+ user.timestamp, err = time.Parse(time.RFC3339Nano, timestr)
if err != nil {
return userT{}, err
}
@@ -317,165 +349,329 @@ func byEmailStmt(
return user, nil
}
- return fn, readStmt.Close, nil
+ closeFn := func() error {
+ return g.SomeFnError(writeStmt.Close, readStmt.Close)
+ }
+
+ return fn, closeFn, nil
}
-func byTokenSQL(prefix string) queryT{
- const tmpl_read = `
- SELECT id, timestamp, uuid, email, username, pwhash, metadata
- FROM "%s" WHERE email = ?;
+func sendTokenSQL(prefix string) queryT {
+ const tmpl_write = `
+ INSERT INTO "%s_confirmation_attempts" (user_id, token)
+ VALUES (
+ (SELECT id FROM "%s_users" WHERE uuid = ?),
+ ?
+ )
`
return queryT{
- read: fmt.Sprintf(tmpl_read, prefix),
+ write: fmt.Sprintf(tmpl_write, prefix, prefix),
}
}
-func byTokenStmt(
+func sendTokenStmt(
db *sql.DB,
prefix string,
-) (func(guuid.UUID) (userT, error), func() error, error) {
- q := byTokenSQL(prefix)
+) (func(guuid.UUID, string) error, func() error, error) {
+ q := sendTokenSQL(prefix)
- readStmt, err := db.Prepare(q.read)
+ writeStmt, err := db.Prepare(q.write)
if err != nil {
return nil, nil, err
}
- fn := func(token guuid.UUID) (userT, error) {
- var user userT
- // FIXME: build user
- err := readStmt.QueryRow(token).Scan(&user.id)
- return user, err
+ fn := func(userID guuid.UUID, token string) error {
+ user_id_bytes := userID[:]
+ _, err := writeStmt.Exec(user_id_bytes, token)
+ return err
}
- return fn, readStmt.Close, nil
+ return fn, writeStmt.Close, nil
}
-func registerSQL(prefix string) queryT {
+func confirmSQL(prefix string) queryT {
const tmpl_write = `
- INSERT INTO "%s" (uuid, email, username, salt, pwhash, metadata)
- VALUES (?, ?, ?, ?, ?, ?);
+ INSERT INTO "%s_user_confirmations" (user_id, attempt_id)
+ VALUES (?, ?);
+ `
+ const tmpl_read = `
+ SELECT
+ "%s_confirmation_attempts".id,
+ "%s_confirmation_attempts".user_id,
+ "%s_users".uuid
+ FROM "%s_confirmation_attempts"
+ JOIN "%s_users" ON
+ "%s_confirmation_attempts".user_id = "%s_users".id
+ WHERE token = ?;
+ `
+ const tmpl_session = `
+ INSERT INTO "%s_sessions" (uuid, user_id)
+ VALUES (?, ?) RETURNING id, timestamp;
`
return queryT{
- write: fmt.Sprintf(tmpl_write, prefix),
+ write: fmt.Sprintf(tmpl_write, prefix),
+ read: fmt.Sprintf(
+ tmpl_read,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ ),
+ session: fmt.Sprintf(tmpl_session, prefix),
}
}
-func registerStmt(
+func confirmStmt(
db *sql.DB,
prefix string,
-) (func(string, []byte, []byte) (userT, error), func() error, error) {
- q := registerSQL(prefix)
+) (func(string, guuid.UUID) (sessionT, error), func() error, error) {
+ q := confirmSQL(prefix)
writeStmt, err := db.Prepare(q.write)
if err != nil {
return nil, nil, err
}
- fn := func(email string, salt []byte, pwhash []byte) (userT, error) {
- /*
- timestamp TEXT NOT NULL DEFAULT (%s),
- uuid BLOB NOT NULL UNIQUE,
- email TEXT NOT NULL UNIQUE,
- username TEXT UNIQUE,
- pwhash TEXT NOT NULL,
- metadata TEXT
- */
+ readStmt, err := db.Prepare(q.read)
+ if err != nil {
+ return nil, nil, g.WrapErrors(writeStmt.Close(), err)
+ }
- var user userT
- // err := stmt.QueryRow(
- ret, err := writeStmt.Exec(
- guuid.New(),
- email,
- "credentials.username",
- salt,
- pwhash,
- "credentials.metadata",
- // ).Scan(&user.email)
- // FIXME: finish
+ sessionStmt, err := db.Prepare(q.session)
+ if err != nil {
+ return nil, nil, g.WrapErrors(
+ writeStmt.Close(),
+ readStmt.Close(),
+ err,
)
- if false {
- fmt.Printf("ret: %#v\n", ret)
- fmt.Printf("user: %#v\n", user)
+ }
+
+ fn := func(token string, sessionID guuid.UUID) (sessionT, error) {
+ session := sessionT{
+ uuid: sessionID,
}
- return user, err
+
+ var (
+ user_id int64
+ attempt_id int64
+ user_id_bytes []byte
+ )
+ err := readStmt.QueryRow(token).Scan(
+ &attempt_id,
+ &user_id,
+ &user_id_bytes,
+ )
+ if err != nil {
+ return sessionT{}, err
+ }
+ session.userID = guuid.UUID(user_id_bytes)
+
+ _, err = writeStmt.Exec(user_id, attempt_id)
+ if err != nil {
+ return sessionT{}, err
+ }
+
+ session_id_bytes := sessionID[:]
+ var timestr string
+ err = sessionStmt.QueryRow(
+ session_id_bytes,
+ user_id,
+ ).Scan(&session.id, &timestr)
+ if err != nil {
+ return sessionT{}, err
+ }
+
+ session.timestamp, err = time.Parse(time.RFC3339Nano, timestr)
+ if err != nil {
+ return sessionT{}, err
+ }
+
+ return session, nil
}
- return fn, writeStmt.Close, nil
+ closeFn := func() error {
+ return g.SomeFnError(
+ writeStmt.Close,
+ readStmt.Close,
+ sessionStmt.Close,
+ )
+ }
+
+ return fn, closeFn, nil
}
-func confirmSQL(prefix string) queryT {
- const tmpl_write = `
- -- INSERT SOMETHING %s
+func byEmailSQL(prefix string) queryT {
+ // FIXME: rewrite as LEFT JOIN?
+ const tmpl_read = `
+ SELECT id, timestamp, uuid, salt, pwhash, metadata, (
+ CASE WHEN EXISTS (
+ SELECT id FROM "%s_user_confirmations"
+ WHERE user_id = (
+ SELECT id FROM "%s_users"
+ WHERE email = ?
+ )
+ ) THEN 1
+ ELSE 0
+ END
+ ) as confirmed
+ FROM "%s_users" WHERE email = ?;
`
return queryT{
- write: fmt.Sprintf(tmpl_write, prefix),
+ read: fmt.Sprintf(tmpl_read, prefix, prefix, prefix),
}
}
-func confirmStmt(
+func byEmailStmt(
db *sql.DB,
prefix string,
-) (func(guuid.UUID) (sessionT, error), func() error, error) {
- q := confirmSQL(prefix)
+) (func(string) (userT, error), func() error, error) {
+ q := byEmailSQL(prefix)
- writeStmt, err := db.Prepare(q.write)
+ readStmt, err := db.Prepare(q.read)
if err != nil {
return nil, nil, err
}
- fn := func(token guuid.UUID) (sessionT, error) {
- var session sessionT
- err := writeStmt.QueryRow(token).Scan(&session)
- return session, err
+ fn := func(email string) (userT, error) {
+ user := userT{
+ email: email,
+ }
+
+ var (
+ timestr string
+ user_id_bytes []byte
+ metadatastr sql.NullString
+ )
+ err := readStmt.QueryRow(email, email).Scan(
+ &user.id,
+ &timestr,
+ &user_id_bytes,
+ &user.salt,
+ &user.pwhash,
+ &metadatastr,
+ &user.confirmed,
+ )
+ if err != nil {
+ return userT{}, err
+ }
+ user.uuid = guuid.UUID(user_id_bytes)
+
+ user.timestamp, err = time.Parse(time.RFC3339Nano,timestr)
+ if err != nil {
+ return userT{}, err
+ }
+
+ if metadatastr.Valid {
+ err := json.Unmarshal(
+ []byte(metadatastr.String),
+ &user.metadata,
+ )
+ if err != nil {
+ g.Warning(
+ "failed to parse metadata field",
+ "sqlite-json-unmarshal-error",
+ "userID", user.uuid.String(),
+ "error", err,
+ )
+ }
+ }
+
+ return user, nil
}
- return fn, writeStmt.Close, nil
+ return fn, readStmt.Close, nil
}
func loginSQL(prefix string) queryT {
- const tmpl_write = `
- -- INSERT INTO "%s" (t3, t4) VALUES (?, ?);
+ const tmpl_session = `
+ INSERT INTO "%s_sessions" (uuid, user_id)
+ VALUES (
+ ?,
+ (SELECT id FROM "%s_users" WHERE uuid = ?)
+ ) RETURNING id, timestamp;
`
return queryT{
- write: fmt.Sprintf(tmpl_write, prefix),
+ session: fmt.Sprintf(tmpl_session, prefix, prefix),
}
}
func loginStmt(
db *sql.DB,
prefix string,
-) (func(string, string) (sessionT, error), func() error, error) {
+) (func(guuid.UUID, guuid.UUID) (sessionT, error), func() error, error) {
q := loginSQL(prefix)
- writeStmt, err := db.Prepare(q.write)
+ sessionStmt, err := db.Prepare(q.session)
if err != nil {
return nil, nil, err
}
- fn := func(email string, pwhash string) (sessionT, error) {
- var session sessionT
- err := writeStmt.QueryRow(email, pwhash).Scan(session)
- // FIXME: finish
+ fn := func(userID guuid.UUID, sessionID guuid.UUID) (sessionT, error) {
+ session := sessionT{
+ uuid: sessionID,
+ userID: userID,
+ }
+
+ user_id_bytes := userID[:]
+ session_id_bytes := sessionID[:]
+ var timestr string
+ err := sessionStmt.QueryRow(
+ session_id_bytes,
+ user_id_bytes,
+ ).Scan(
+ &session.id,
+ &timestr,
+ )
+ if err != nil {
+ return sessionT{}, err
+ }
+
+ session.timestamp, err = time.Parse(time.RFC3339Nano, timestr)
+ if err != nil {
+ return sessionT{}, err
+ }
+
return session, err
}
- return fn, writeStmt.Close, nil
+ return fn, sessionStmt.Close, nil
}
func refreshSQL(prefix string) queryT {
const tmpl_write = `
- -- INSERT SOMETHING %s
+ INSERT INTO "%s_sessions" (uuid, user_id)
+ VALUES (
+ ?,
+ (SELECT user_id FROM "%s_sessions" WHERE uuid = ?)
+ ) RETURNING id, timestamp, (
+ SELECT "%s_users".uuid FROM "%s_users"
+ JOIN "%s_sessions" ON
+ "%s_users".id = "%s_sessions".user_id
+ WHERE "%s_sessions".uuid = ?
+ ) AS userID;
`
return queryT{
- write: fmt.Sprintf(tmpl_write, prefix),
+ write: fmt.Sprintf(
+ tmpl_write,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ ),
}
}
func refreshStmt(
db *sql.DB,
prefix string,
-) (func(guuid.UUID) (sessionT, error), func() error, error) {
+) (func(guuid.UUID, guuid.UUID) (sessionT, error), func() error, error) {
q := refreshSQL(prefix)
writeStmt, err := db.Prepare(q.write)
@@ -483,9 +679,35 @@ func refreshStmt(
return nil, nil, err
}
- fn := func(uuid guuid.UUID) (sessionT, error) {
- var session sessionT
- err := writeStmt.QueryRow(uuid).Scan(&session)
+ fn := func(
+ sessionID guuid.UUID,
+ newSessionID guuid.UUID,
+ ) (sessionT, error) {
+ session := sessionT{
+ uuid: newSessionID,
+ }
+
+ session_id_bytes := sessionID[:]
+ new_session_id_bytes := newSessionID[:]
+ var (
+ timestr string
+ user_id_bytes []byte
+ )
+ err := writeStmt.QueryRow(
+ new_session_id_bytes,
+ session_id_bytes,
+ session_id_bytes,
+ ).Scan(&session.id, &timestr, &user_id_bytes)
+ if err != nil {
+ return sessionT{}, err
+ }
+ session.userID = guuid.UUID(user_id_bytes)
+
+ session.timestamp, err = time.Parse(time.RFC3339Nano, timestr)
+ if err != nil {
+ return sessionT{}, err
+ }
+
return session, err
}
@@ -570,9 +792,9 @@ func byUUIDStmt(
return nil, nil, err
}
- fn := func(uuid guuid.UUID) (sessionT, error) {
+ fn := func(sessionID guuid.UUID) (sessionT, error) {
var session sessionT
- err := readStmt.QueryRow(uuid).Scan(&session)
+ err := readStmt.QueryRow(sessionID).Scan(&session)
return session, err
}
@@ -599,8 +821,8 @@ func logoutStmt(
return nil, nil, err
}
- fn := func(uuid guuid.UUID) error {
- _, err := writeStmt.Exec(uuid)
+ fn := func(sessionID guuid.UUID) error {
+ _, err := writeStmt.Exec(sessionID)
return err
}
@@ -627,8 +849,8 @@ func outOthersStmt(
return nil, nil, err
}
- fn := func(uuid guuid.UUID) error {
- _, err := writeStmt.Exec(uuid)
+ fn := func(sessionID guuid.UUID) error {
+ _, err := writeStmt.Exec(sessionID)
return err
}
@@ -655,8 +877,8 @@ func outAllStmt(
return nil, nil, err
}
- fn := func(uuid guuid.UUID) error {
- _, err := writeStmt.Exec(uuid)
+ fn := func(sessionID guuid.UUID) error {
+ _, err := writeStmt.Exec(sessionID)
return err
}
@@ -668,10 +890,10 @@ func initDB(
prefix string,
) (queriesT, error) {
createTablesErr := createTables(db, prefix)
- byEmail, byEmailClose, byEmailErr := byEmailStmt(db, prefix)
- byToken, byTokenClose, byTokenErr := byTokenStmt(db, prefix)
register, registerClose, registerErr := registerStmt(db, prefix)
+ sendToken, sendTokenClose, sendTokenErr := sendTokenStmt(db, prefix)
confirm, confirmClose, confirmErr := confirmStmt(db, prefix)
+ byEmail, byEmailClose, byEmailErr := byEmailStmt(db, prefix)
login, loginClose, loginErr := loginStmt(db, prefix)
refresh, refreshClose, refreshErr := refreshStmt(db, prefix)
reset, resetClose, resetErr := resetStmt(db, prefix)
@@ -683,10 +905,10 @@ func initDB(
err := g.SomeError(
createTablesErr,
- byEmailErr,
- byTokenErr,
registerErr,
+ sendTokenErr,
confirmErr,
+ byEmailErr,
loginErr,
refreshErr,
resetErr,
@@ -702,10 +924,10 @@ func initDB(
close := func() error {
return g.SomeFnError(
- byEmailClose,
- byTokenClose,
registerClose,
+ sendTokenClose,
confirmClose,
+ byEmailClose,
loginClose,
refreshClose,
resetClose,
@@ -717,21 +939,78 @@ func initDB(
)
}
- // FIXME: lock
+ var connMutex sync.RWMutex
return queriesT{
- byEmail: byEmail,
- byToken: byToken,
- register: register,
- confirm: confirm,
- login: login,
- refresh: refresh,
- reset: reset,
- change: change,
- byUUID: byUUID,
- logout: logout,
- outOthers: outOthers,
- outAll: outAll,
- close: close,
+ register: func(
+ a guuid.UUID,
+ b string,
+ c []byte,
+ d []byte,
+ ) (userT, error) {
+ connMutex.RLock()
+ defer connMutex.RUnlock()
+ return register(a, b, c, d)
+ },
+ sendToken: func(a guuid.UUID, b string) error {
+ connMutex.RLock()
+ defer connMutex.RUnlock()
+ return sendToken(a, b)
+ },
+ confirm: func(a string, b guuid.UUID) (sessionT, error) {
+ connMutex.RLock()
+ defer connMutex.RUnlock()
+ return confirm(a, b)
+ },
+ byEmail: func(a string) (userT, error) {
+ connMutex.RLock()
+ defer connMutex.RUnlock()
+ return byEmail(a)
+ },
+ login: func(a guuid.UUID, b guuid.UUID) (sessionT, error) {
+ connMutex.RLock()
+ defer connMutex.RUnlock()
+ return login(a, b)
+ },
+ refresh: func(a guuid.UUID, b guuid.UUID) (sessionT, error) {
+ connMutex.RLock()
+ defer connMutex.RUnlock()
+ return refresh(a, b)
+ },
+ reset: func(a int64, b []byte, c guuid.UUID) (sessionT, error) {
+ connMutex.RLock()
+ defer connMutex.RUnlock()
+ return reset(a, b, c)
+ },
+ change: func(a int64, b []byte) (sessionT, error) {
+ connMutex.RLock()
+ defer connMutex.RUnlock()
+ return change(a, b)
+ },
+ byUUID: func(a guuid.UUID) (sessionT, error) {
+ connMutex.RLock()
+ defer connMutex.RUnlock()
+ return byUUID(a)
+ },
+ logout: func(a guuid.UUID) error {
+ connMutex.RLock()
+ defer connMutex.RUnlock()
+ return logout(a)
+ },
+ outOthers: func(a guuid.UUID) error {
+ connMutex.RLock()
+ defer connMutex.RUnlock()
+ return outOthers(a)
+ },
+ outAll: func(a guuid.UUID) error {
+ connMutex.RLock()
+ defer connMutex.RUnlock()
+ return outAll(a)
+ },
+ close: func() error {
+ connMutex.Lock()
+ defer connMutex.Unlock()
+ return close()
+ },
}, nil
}
@@ -767,13 +1046,24 @@ var consumers = []consumerT{
handlerFn: forgotPasswordRequestHandler,
},
}
-func registerConsumers(auth authT, consumers []consumerT) {
+func registerConsumers(auth authT, consumers []consumerT, prefix string) error {
for _, consumer := range consumers {
- auth.queue.Subscribe(
+ err := auth.queue.Subscribe(
consumer.topic,
- defaultPrefix + "-" + consumer.topic,
+ prefix + "-" + consumer.topic,
consumer.handlerFn(auth),
)
+ if err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+func unregisterConsumers(queue q.IQueue, consumers []consumerT, prefix string) {
+ for _, consumer := range consumers {
+ queue.Unsubscribe(consumer.topic, prefix + "-" + consumer.topic)
}
}
@@ -861,28 +1151,41 @@ func asResult[A any, B any](fn func(A) (B, error)) func(A) resultT[B] {
}
}
+func unwrapResult[A any, B any](fn func(A) resultT[B]) func(A) (B, error) {
+ return func(input A) (B, error) {
+ result := fn(input)
+ return result.value, result.err
+ }
+}
+
func NewWithPrefix(db *sql.DB, queue q.IQueue, prefix string) (IAuth, error) {
queries, err := initDB(db, prefix)
if err != nil {
return authT{}, err
}
- numCPU := (runtime.NumCPU() / 2) + 1
- hasher, closeHasher := makePoolRunner(numCPU, asResult(scrypt.Hash))
- checker, closeChecker := makePoolRunner(numCPU, asResult(scrypt.Check))
+ numCPU := runtime.NumCPU()
+ hasher, closeHasher := makePoolRunner(numCPU, asResult(scrypt.Hash))
- close := func() {
+ closeFn := func() {
+ unregisterConsumers(queue, consumers, prefix)
closeHasher()
- closeChecker()
}
- return authT{
+ auth := authT{
queries: queries,
queue: queue,
- hasher: hasher,
- checker: checker,
- close: close,
- }, nil
+ hasher: unwrapResult(hasher),
+ close: closeFn,
+ }
+
+ err = registerConsumers(auth, consumers, prefix)
+ if err != nil {
+ closeFn()
+ return authT{}, err
+ }
+
+ return auth, nil
}
func New(db *sql.DB, queue q.IQueue) (IAuth, error) {
@@ -926,7 +1229,10 @@ func (auth authT) Register(
Password: []byte(password),
Salt: salt,
}
- result := auth.hasher(input)
+ hash, err := auth.hasher(input)
+ if err != nil {
+ return userT{}, err
+ }
/*
We also try to register anyway, to prevent disk IO timing attacks.
@@ -937,7 +1243,7 @@ func (auth authT) Register(
waiter := auth.queue.WaitFor(NEW_USER, flowID, "register")
defer waiter.Close()
- payload, err := newUserPayload(email, salt, result.value)
+ payload, err := newUserPayload(email, salt, hash)
if err != nil {
return userT{}, err
}
@@ -987,13 +1293,12 @@ func (auth authT) ResendConfirmation(email string) error {
return err
}
-
_, err = auth.queue.Publish(unsent)
return err
}
-func (auth authT) ConfirmEmail(token guuid.UUID) (sessionT, error) {
- return auth.queries.confirm(token)
+func (auth authT) ConfirmEmail(token string) (sessionT, error) {
+ return auth.queries.confirm("token FIXME", guuid.New())
}
func (auth authT) LoginEmail(
@@ -1005,34 +1310,36 @@ func (auth authT) LoginEmail(
}
// special check for sql.ErrNoRows to combat enumeration attacks.
+ // FIXME: how so?
user, err := auth.queries.byEmail(email)
if err != nil && err != sql.ErrNoRows {
return sessionT{}, err
}
- input := scrypt.CheckInput{
+ input := scrypt.HashInput{
Password: []byte(password),
Salt: user.salt,
- Hash: user.pwhash,
}
- ok, err := scrypt.Check(input)
+ hash, err := auth.hasher(input)
if err != nil {
return sessionT{}, err
}
+
+ ok := slices.Equal(hash, user.pwhash)
if !ok {
return sessionT{}, ErrBadCombo
}
- if user.confirmed_at == nil {
+ if !user.confirmed {
return sessionT{}, ErrUnconfirmed
}
- dbSession, err := auth.queries.login(email, password)
+ session, err := auth.queries.login(user.uuid, guuid.New())
if err != nil {
return sessionT{}, err
}
- return dbSession, nil
+ return session, nil
}
func forgotPasswordMessage(
@@ -1099,7 +1406,7 @@ func (auth authT) Refresh(session sessionT) (sessionT, error) {
return sessionT{}, err
}
- return auth.queries.refresh(session.uuid)
+ return auth.queries.refresh(session.uuid, guuid.New())
}
func (auth authT) ResetPassword(
@@ -1115,25 +1422,22 @@ func (auth authT) ResetPassword(
return sessionT{}, ErrTooShort
}
- user, err := auth.queries.byToken(token)
- if err != nil {
- return sessionT{}, err
- }
+ user := userT{}
input := scrypt.HashInput{
Password: []byte(password),
Salt: user.salt,
}
- pwhash, err := scrypt.Hash(input)
+ pwhash, err := auth.hasher(input)
if err != nil {
return sessionT{}, err
}
- if user.confirmed_at != nil {
+ if user.confirmed {
return auth.queries.reset(user.id, pwhash, token)
} else {
// return auth.queries.confirm(user.id, pwhash, token)
- return auth.queries.confirm(token)
+ return auth.queries.confirm("token FIXME", guuid.New())
}
}
@@ -1151,17 +1455,16 @@ func (auth authT) ChangePassword(
return sessionT{}, ErrTooShort
}
- // FIXME
input := scrypt.HashInput{
Password: []byte(newPassword),
Salt: user.salt,
}
- pwhash, err := scrypt.Hash(input)
+ pwhash, err := auth.hasher(input)
if err != nil {
return sessionT{}, err
}
- if user.confirmed_at == nil {
+ if !user.confirmed {
return sessionT{}, ErrUnconfirmed
}
@@ -1212,32 +1515,5 @@ func (auth authT) Close() error {
func Main() {
- g.Init()
- db, err := sql.Open("acude", "file:gracha.db?mode=memory&cache=shared")
- if err != nil {
- panic(err)
- }
- defer db.Close()
-
- queue, err := q.New(db)
- if err != nil {
- panic(err)
- }
- defer queue.Close()
-
- auth, err := New(db, queue)
- if err != nil {
- fmt.Println(err)
- panic(err)
- }
-
- user, err := auth.Register("contact@example.com", "password", "password")
- if false {
- fmt.Printf("user: %#v\n", user)
- fmt.Printf("err: %#v\n", err)
- }
-
- return
- fmt.Printf("q: %#v\n", queue)
- fmt.Printf("auth: %#v\n", auth)
+ // FIXME
}