diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/gracha.go | 1316 |
1 files changed, 774 insertions, 542 deletions
diff --git a/src/gracha.go b/src/gracha.go index dcff98d..e300f66 100644 --- a/src/gracha.go +++ b/src/gracha.go @@ -1,48 +1,71 @@ package gracha import ( - "crypto/rand" + "context" "database/sql" "encoding/hex" "encoding/json" "errors" "fmt" + "runtime" + "sync" "time" "guuid" - "liteq" + "q" "scrypt" g "gobang" ) -type tablesT struct{ - users string - userChanges string - tokens string - roles string - roleChanges string - sessions string - attempts string - audit string +const ( + defaultPrefix = "gracha" + + rollbackErrorFmt = "rollback error: %w; while executing: %w" + NEW_USER = "new-user" + SEND_CONFIRMATION_REQUEST = "send-confirmation-request" + FORGOT_PASSWORD_REQUEST = "forgot-password-request" + + + day = 24 * time.Hour +) + +var ( + SessionDuration = 7 * day + RegisterTimeout = 15 * time.Second + + ErrPassMismatch = errors.New("gracha: password confirmation mismatch") + ErrTooShort = errors.New("gracha: password too short") + ErrTimeout = errors.New("gracha: timeout when creating user") + ErrRegistered = errors.New("gracha: user already registered") + ErrBadCombo = errors.New("gracha: bad username/passphrase combo") + ErrUnconfirmed = errors.New("gracha: email is not confirmed") + ErrRevokedSession = errors.New("gracha: this session was revoked") + ErrSessionExpired = errors.New("gracha: session expired") +) + + + +type queryT struct{ + write string + read string } type queriesT struct{ - userByEmail func(string) (userT, error) - userByToken func([]byte) (userT, error) - register func(string, []byte, []byte) (userT, error) - confirm func([]byte) (sessionT, error) - login func(string, string) (sessionT, error) - refresh func([]byte) (sessionT, error) - resetPassword func(int64, []byte, []byte) (sessionT, error) - resetAndConfirm func(int64, []byte, []byte) (sessionT, error) - changePassword func(int64, []byte) (sessionT, error) - sessionByUUID func([]byte) (sessionT, error) - logout func([]byte) error - logoutOthers func([]byte) error - logoutAll func([]byte) error - close func() error + byEmail func(string) (userT, error) + byToken func(guuid.UUID) (userT, error) + register func(string, []byte, []byte) (userT, error) + confirm func(guuid.UUID) (sessionT, error) + login func(string, string) (sessionT, error) + refresh func(guuid.UUID) (sessionT, error) + reset func(int64, []byte, guuid.UUID) (sessionT, error) + change func(int64, []byte) (sessionT, error) + byUUID func(guuid.UUID) (sessionT, error) + logout func(guuid.UUID) error + outOthers func(guuid.UUID) error + outAll func(guuid.UUID) error + close func() error } type confirmationT struct{ @@ -50,17 +73,14 @@ type confirmationT struct{ type userT struct{ id int64 - timestr string timestamp time.Time - uuid []byte + uuid guuid.UUID email string username *string salt []byte pwhash []byte - // confirmation confirmed_at *time.Time - metadatastr *string - metadata *map[string]interface{} + metadata map[string]interface{} } type sessionT struct{ @@ -70,240 +90,284 @@ type sessionT struct{ uuid guuid.UUID user_id int64 type_ string - revokedstr *string revoked_at *time.Time - metadatastr *string - metadata *map[string]interface{} -} - -type Auth struct{ - tables tablesT - queries queriesT - q liteq.Queue + metadata map[string]interface{} } type consumerT struct{ topic string - handlerFn func(Auth) func([]byte) error + handlerFn func(authT) func(q.Message) error } +type authT struct{ + queries queriesT + queue q.IQueue + hasher func(scrypt.HashInput) resultT[[]byte] + checker func(scrypt.CheckInput) resultT[bool] + close func() +} +type IAuth interface{ + Register(string, string, string) (userT, error) + ResendConfirmation(string) error + ConfirmEmail(guuid.UUID) (sessionT, error) + LoginEmail(string, string) (sessionT, error) + ForgotPassword(string) error + Refresh(sessionT) (sessionT, error) + ResetPassword(guuid.UUID, string, string) (sessionT, error) + ChangePassword(userT, string, string, string) (sessionT, error) + Logout(sessionT) error + LogoutOthers(sessionT) error + LogoutAll(sessionT) error + Close() error +} -const ( - NEW_USER = "new-user" - SEND_CONFIRMATION_REQUEST = "send-confirmation-request" - FORGOT_PASSWORD_REQUEST = "forgot-password-request" - defaultPrefix = "gracha" - day = 24 * time.Hour -) +func tryRollback(db *sql.DB, ctx context.Context, err error) error { + _, rollbackErr := db.ExecContext(ctx, "ROLLBACK;") + if rollbackErr != nil { + return fmt.Errorf( + rollbackErrorFmt, + rollbackErr, + err, + ) + } -var ( - SessionDuration = 7 * day - RegisterTimeout = 15 * time.Second + return err +} - ErrPasswordMismatch = errors.New("gracha: password and its confirmation don't match") - ErrPasswordTooShort = errors.New("gracha: bad username/passphrase combo") - ErrRegisterTimeout = errors.New("gracha: timeout when creating user") - ErrAlreadyRegistered = errors.New("gracha: user already registered") - ErrEmailPasswordCombo = errors.New("gracha: bad username/passphrase combo") - ErrUnconfirmedUser = errors.New("gracha: user email is not confirmed") - ErrRevokedSession = errors.New("gracha: this session was revoked") - ErrSessionExpired = errors.New("gracha: session expired") -) +func inTx(db *sql.DB, fn func(context.Context) error) error { + ctx := context.Background() + + _, err := db.ExecContext(ctx, "BEGIN IMMEDIATE;") + if err != nil { + return err + } + err = fn(ctx) + if err != nil { + return tryRollback(db, ctx, err) + } + _, err = db.ExecContext(ctx, "COMMIT;") + if err != nil { + return tryRollback(db, ctx, err) + } -func tablesFrom(prefix string) (tablesT, error) { - if !g.ValidSQLTablePrefix(prefix) { - return tablesT{}, g.ErrBadSQLTablePrefix - } - - users := prefix + "-users" - userChanges := prefix + "-user-changes" - tokens := prefix + "-tokens" - roles := prefix + "-roles" - roleChanges := prefix + "-role-changes" - sessions := prefix + "-sessions" - attempts := prefix + "-attempts" - audit := prefix + "-audit" - return tablesT{ - users: users, - userChanges: userChanges, - tokens: tokens, - roles: roles, - roleChanges: roleChanges, - sessions: sessions, - attempts: attempts, - audit: audit, - }, nil + return nil } -func createTables(db *sql.DB, tables tablesT) error { - const tmpl = ` - BEGIN TRANSACTION; - CREATE TABLE IF NOT EXISTS "%s" ( - id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, - timestamp TEXT NOT NULL DEFAULT (%s), - uuid BLOB NOT NULL UNIQUE, - email TEXT NOT NULL UNIQUE, - username TEXT UNIQUE, - salt BLOB NOT NULL UNIQUE, - pwhash BLOB NOT NULL, - confirmed_at TEXT, - confirmer_id INTEGER REFERENCES "%s"(id), - metadata TEXT +func createTablesSQL(prefix string) queryT { + const tmpl_write = ` + CREATE TABLE IF NOT EXISTS "%s_users" ( + id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, + timestamp TEXT NOT NULL DEFAULT (%s), + uuid BLOB NOT NULL UNIQUE, + email TEXT NOT NULL UNIQUE, + username TEXT UNIQUE, + salt BLOB NOT NULL UNIQUE, + pwhash BLOB NOT NULL, + confirmed_at TEXT, + confirmer_id INTEGER REFERENCES "%s"(id), + metadata TEXT ); - CREATE TABLE IF NOT EXISTS "%s" ( - id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, - timestamp TEXT NOT NULL DEFAULT (%s), - user_id INTEGER NOT NULL REFERENCES "%s"(id), - attribute TEXT NOT NULL, - value TEXT NOT NULL, - op BOOLEAN NOT NULL + CREATE TABLE IF NOT EXISTS "%s_user_changes" ( + id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, + timestamp TEXT NOT NULL DEFAULT (%s), + user_id INTEGER NOT NULL REFERENCES "%s_users"(id), + attribute TEXT NOT NULL, + value TEXT NOT NULL, + op BOOLEAN NOT NULL ); - CREATE TABLE IF NOT EXISTS "%s" ( - id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, - timestamp TEXT NOT NULL DEFAULT (%s), - uuid BLOB NOT NULL UNIQUE, - type TEXT NOT NULL + CREATE TABLE IF NOT EXISTS "%s_tokens" ( + id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, + timestamp TEXT NOT NULL DEFAULT (%s), + uuid BLOB NOT NULL UNIQUE, + type TEXT NOT NULL ); - CREATE TABLE IF NOT EXISTS "%s" ( - id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, - user_id INTEGER NOT NULL REFERENCES "%s"(id), - role TEXT NOT NULL, + CREATE TABLE IF NOT EXISTS "%s_roles" ( + id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, + user_id INTEGER NOT NULL REFERENCES "%s_users"(id), + role TEXT NOT NULL, + metadata TEXT, UNIQUE (user_id, role) ); - CREATE TABLE IF NOT EXISTS "%s" ( - id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, - timestamp TEXT NOT NULL DEFAULT (%s), - user_id INTEGER NOT NULL REFERENCES "%s"(id), - role TEXT NOT NULL, - op BOOLEAN NOT NULL + CREATE TABLE IF NOT EXISTS "%s_role_changes" ( + id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, + timestamp TEXT NOT NULL DEFAULT (%s), + user_id INTEGER NOT NULL REFERENCES "%s_roles"(id), + role TEXT NOT NULL, + op BOOLEAN NOT NULL ); - CREATE TABLE IF NOT EXISTS "%s" ( - id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, - timestamp TEXT NOT NULL DEFAULT (%s), - uuid BLOB NOT NULL UNIQUE, - user_id INTEGER NOT NULL REFERENCES "%s"(id), - type TEXT NOT NULL, - revoked_at TEXT, - revoker_id INTEGER REFERENCES "%s"(id), - metadata TEXT + CREATE TABLE IF NOT EXISTS "%s_sessions" ( + id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, + timestamp TEXT NOT NULL DEFAULT (%s), + uuid BLOB NOT NULL UNIQUE, + user_id INTEGER NOT NULL REFERENCES "%s_users"(id), + type TEXT NOT NULL, + revoked_at TEXT, + revoker_id INTEGER REFERENCES "%s_users"(id), + metadata TEXT ); - CREATE TABLE IF NOT EXISTS "%s" ( - id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, - timestamp TEXT NOT NULL DEFAULT (%s), - user_id INTEGER REFERENCES "%s"(id), - session_id INTEGER REFERENCES "%s"(id), - metadata TEXT + CREATE TABLE IF NOT EXISTS "%s_attempts" ( + id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, + timestamp TEXT NOT NULL DEFAULT (%s), + user_id INTEGER REFERENCES "%s_users"(id), + session_id INTEGER REFERENCES "%s_sessions"(id), + metadata TEXT ); - CREATE TABLE IF NOT EXISTS "%s" ( - id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, - timestamp TEXT NOT NULL DEFAULT (%s), - uuid BLOB NOT NULL UNIQUE, - attribute TEXT NOT NULL, - value TEXT NOT NULL, - op BOOLEAN NOT NULL, - metadata TEXT + CREATE TABLE IF NOT EXISTS "%s_audit" ( + id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, + timestamp TEXT NOT NULL DEFAULT (%s), + uuid BLOB NOT NULL UNIQUE, + attribute TEXT NOT NULL, + value TEXT NOT NULL, + op BOOLEAN NOT NULL, + metadata TEXT ); - COMMIT TRANSACTION; ` - sql := fmt.Sprintf( - tmpl, - tables.users, - g.SQLiteNow, - tables.tokens, - tables.userChanges, - g.SQLiteNow, - tables.users, - tables.tokens, - g.SQLiteNow, - tables.roles, - tables.users, - tables.roleChanges, - g.SQLiteNow, - tables.users, - tables.sessions, - g.SQLiteNow, - tables.users, - tables.sessions, - tables.attempts, - g.SQLiteNow, - tables.users, - tables.sessions, - tables.audit, - g.SQLiteNow, - ) - /// fmt.Println(sql) /// + return queryT{ + write: fmt.Sprintf( + tmpl_write, + prefix, + g.SQLiteNow, + prefix, + prefix, + g.SQLiteNow, + prefix, + prefix, + g.SQLiteNow, + prefix, + prefix, + prefix, + g.SQLiteNow, + prefix, + prefix, + g.SQLiteNow, + prefix, + prefix, + prefix, + g.SQLiteNow, + prefix, + prefix, + prefix, + g.SQLiteNow, + ), + } +} - _, err := db.Exec(sql) - return err +func createTables(db *sql.DB, prefix string) error { + q := createTablesSQL(prefix) + + return inTx(db, func(ctx context.Context) error { + _, err := db.ExecContext(ctx, q.write) + return err + }) } -func userByEmailQuery( - db *sql.DB, - tables tablesT, -) (func(string) (userT, error), func() error, error) { - const tmpl = ` +func byEmailSQL(prefix string) queryT { + const tmpl_read = ` SELECT id, timestamp, uuid, email, username, pwhash, metadata - FROM "%s" WHERE email = ?; + FROM "%s_users" WHERE email = ?; ` - sql := fmt.Sprintf(tmpl, tables.users) - /// fmt.Println(sql) /// + return queryT{ + read: fmt.Sprintf(tmpl_read, prefix), + } +} + +func byEmailStmt( + db *sql.DB, + prefix string, +) (func(string) (userT, error), func() error, error) { + q := byEmailSQL(prefix) - stmt, err := db.Prepare(sql) + readStmt, err := db.Prepare(q.read) if err != nil { return nil, nil, err } fn := func(email string) (userT, error) { - var user userT - err := stmt.QueryRow(email).Scan(&user.id) // FIXME: build user - return user, err + user := userT{ + email: email, + } + + var ( + timestr string + uuid_bytes []byte + ) + err := readStmt.QueryRow(email).Scan( + &user.id, + ×tr, + &uuid_bytes, + &user.username, + &user.pwhash, + &user.metadata, + ) + if err != nil { + return userT{}, err + } + user.uuid = guuid.UUID(uuid_bytes) + + user.timestamp, err = time.Parse(time.RFC3339Nano,timestr) + if err != nil { + return userT{}, err + } + + return user, nil } - return fn, stmt.Close, nil + return fn, readStmt.Close, nil } -func userByTokenQuery( - db *sql.DB, - tables tablesT, -) (func([]byte) (userT, error), func() error, error) { - const tmpl = ` +func byTokenSQL(prefix string) queryT{ + const tmpl_read = ` SELECT id, timestamp, uuid, email, username, pwhash, metadata FROM "%s" WHERE email = ?; ` - sql := fmt.Sprintf(tmpl, tables.users) - /// fmt.Println(sql) /// + return queryT{ + read: fmt.Sprintf(tmpl_read, prefix), + } +} - stmt, err := db.Prepare(sql) +func byTokenStmt( + db *sql.DB, + prefix string, +) (func(guuid.UUID) (userT, error), func() error, error) { + q := byTokenSQL(prefix) + + readStmt, err := db.Prepare(q.read) if err != nil { return nil, nil, err } - fn := func(token []byte) (userT, error) { + fn := func(token guuid.UUID) (userT, error) { var user userT - err := stmt.QueryRow(token).Scan(&user.id) // FIXME: build user + // FIXME: build user + err := readStmt.QueryRow(token).Scan(&user.id) return user, err } - return fn, stmt.Close, nil + return fn, readStmt.Close, nil } -func registerQuery( - db *sql.DB, - tables tablesT, -) (func(string, []byte, []byte) (userT, error), func() error, error) { - const tmpl = ` +func registerSQL(prefix string) queryT { + const tmpl_write = ` INSERT INTO "%s" (uuid, email, username, salt, pwhash, metadata) VALUES (?, ?, ?, ?, ?, ?); ` - sql := fmt.Sprintf(tmpl, tables.users) - /// fmt.Println(sql) /// + return queryT{ + write: fmt.Sprintf(tmpl_write, prefix), + } +} - stmt, err := db.Prepare(sql) +func registerStmt( + db *sql.DB, + prefix string, +) (func(string, []byte, []byte) (userT, error), func() error, error) { + q := registerSQL(prefix) + + writeStmt, err := db.Prepare(q.write) if err != nil { return nil, nil, err } @@ -320,8 +384,8 @@ func registerQuery( var user userT // err := stmt.QueryRow( - ret, err := stmt.Exec( - guuid.NewBytes(), + ret, err := writeStmt.Exec( + guuid.New(), email, "credentials.username", salt, @@ -337,429 +401,630 @@ func registerQuery( return user, err } - return fn, stmt.Close, nil + return fn, writeStmt.Close, nil } -func loginQuery( - db *sql.DB, - tables tablesT, -) (func(string, string) (sessionT, error), func() error, error) { - const tmpl = ` - -- INSERT INTO "%s" (t3, t4) VALUES (?, ?); +func confirmSQL(prefix string) queryT { + const tmpl_write = ` + -- INSERT SOMETHING %s ` - sql := fmt.Sprintf(tmpl, tables.users) - /// fmt.Println(sql) /// + return queryT{ + write: fmt.Sprintf(tmpl_write, prefix), + } +} + +func confirmStmt( + db *sql.DB, + prefix string, +) (func(guuid.UUID) (sessionT, error), func() error, error) { + q := confirmSQL(prefix) - stmt, err := db.Prepare(sql) + writeStmt, err := db.Prepare(q.write) if err != nil { return nil, nil, err } - fn := func(email string, pwhash string) (sessionT, error) { + fn := func(token guuid.UUID) (sessionT, error) { var session sessionT - err := stmt.QueryRow(email, pwhash).Scan(session) - // FIXME: finish + err := writeStmt.QueryRow(token).Scan(&session) return session, err } - return fn, stmt.Close, nil + return fn, writeStmt.Close, nil } -func refreshQuery( - db *sql.DB, - tables tablesT, -) (func([]byte) (sessionT, error), func() error, error) { - const tmpl = ` - -- INSERT SOMETHING %s +func loginSQL(prefix string) queryT { + const tmpl_write = ` + -- INSERT INTO "%s" (t3, t4) VALUES (?, ?); ` - sql := fmt.Sprintf(tmpl, tables.users) - /// fmt.Println(sql) /// + return queryT{ + write: fmt.Sprintf(tmpl_write, prefix), + } +} + +func loginStmt( + db *sql.DB, + prefix string, +) (func(string, string) (sessionT, error), func() error, error) { + q := loginSQL(prefix) - stmt, err := db.Prepare(sql) + writeStmt, err := db.Prepare(q.write) if err != nil { return nil, nil, err } - fn := func(uuid []byte) (sessionT, error) { + fn := func(email string, pwhash string) (sessionT, error) { var session sessionT - err := stmt.QueryRow(uuid).Scan(&session) + err := writeStmt.QueryRow(email, pwhash).Scan(session) + // FIXME: finish return session, err } - return fn, stmt.Close, nil + return fn, writeStmt.Close, nil } -func resetPasswordQuery( - db *sql.DB, - tables tablesT, -) (func(int64, []byte, []byte) (sessionT, error), func() error, error) { - const tmpl = ` +func refreshSQL(prefix string) queryT { + const tmpl_write = ` -- INSERT SOMETHING %s ` - sql := fmt.Sprintf(tmpl, tables.users) - /// fmt.Println(sql) /// + return queryT{ + write: fmt.Sprintf(tmpl_write, prefix), + } +} - stmt, err := db.Prepare(sql) +func refreshStmt( + db *sql.DB, + prefix string, +) (func(guuid.UUID) (sessionT, error), func() error, error) { + q := refreshSQL(prefix) + + writeStmt, err := db.Prepare(q.write) if err != nil { return nil, nil, err } - fn := func(id int64, pwhash []byte, token []byte) (sessionT, error) { + fn := func(uuid guuid.UUID) (sessionT, error) { var session sessionT - err := stmt.QueryRow(id, pwhash, token).Scan(&session) + err := writeStmt.QueryRow(uuid).Scan(&session) return session, err } - return fn, stmt.Close, nil + return fn, writeStmt.Close, nil } -func resetAndConfirmQuery( - db *sql.DB, - tables tablesT, -) (func(int64, []byte, []byte) (sessionT, error), func() error, error) { - const tmpl = ` +func resetSQL(prefix string) queryT { + const tmpl_write = ` -- INSERT SOMETHING %s ` - sql := fmt.Sprintf(tmpl, tables.users) - /// fmt.Println(sql) /// + return queryT{ + write: fmt.Sprintf(tmpl_write, prefix), + } +} - stmt, err := db.Prepare(sql) +func resetStmt( + db *sql.DB, + prefix string, +) (func(int64, []byte, guuid.UUID) (sessionT, error), func() error, error) { + q := resetSQL(prefix) + + writeStmt, err := db.Prepare(q.write) if err != nil { return nil, nil, err } - fn := func(id int64, pwhash []byte, token []byte) (sessionT, error) { + fn := func(id int64, pwhash []byte, token guuid.UUID) (sessionT, error) { var session sessionT - err := stmt.QueryRow(id, pwhash, token).Scan(&session) + err := writeStmt.QueryRow(id, pwhash, token).Scan(&session) return session, err } - return fn, stmt.Close, nil + return fn, writeStmt.Close, nil } -func changePasswordQuery( - db *sql.DB, - tables tablesT, -) (func(int64, []byte) (sessionT, error), func() error, error) { - const tmpl = ` +func changeSQL(prefix string) queryT { + const tmpl_write = ` -- INSERT SOMETHING %s ` - sql := fmt.Sprintf(tmpl, tables.users) - /// fmt.Println(sql) /// + return queryT{ + write: fmt.Sprintf(tmpl_write, prefix), + } +} - stmt, err := db.Prepare(sql) +func changeStmt( + db *sql.DB, + prefix string, +) (func(int64, []byte) (sessionT, error), func() error, error) { + q := changeSQL(prefix) + + writeStmt, err := db.Prepare(q.write) if err != nil { return nil, nil, err } fn := func(id int64, pwhash []byte) (sessionT, error) { var session sessionT - err := stmt.QueryRow(id, pwhash).Scan(&session) + err := writeStmt.QueryRow(id, pwhash).Scan(&session) return session, err } - return fn, stmt.Close, nil + return fn, writeStmt.Close, nil } -func sessionByUUIDQuery( - db *sql.DB, - tables tablesT, -) (func([]byte) (sessionT, error), func() error, error) { - const tmpl = ` +func byUUIDSQL(prefix string) queryT { + const tmpl_read = ` -- INSERT SOMETHING %s ` - sql := fmt.Sprintf(tmpl, tables.users) - /// fmt.Println(sql) /// + return queryT{ + read: fmt.Sprintf(tmpl_read, prefix), + } +} + +func byUUIDStmt( + db *sql.DB, + prefix string, +) (func(guuid.UUID) (sessionT, error), func() error, error) { + q := byUUIDSQL(prefix) - stmt, err := db.Prepare(sql) + readStmt, err := db.Prepare(q.read) if err != nil { return nil, nil, err } - fn := func(uuid []byte) (sessionT, error) { + fn := func(uuid guuid.UUID) (sessionT, error) { var session sessionT - err := stmt.QueryRow(uuid).Scan(&session) + err := readStmt.QueryRow(uuid).Scan(&session) return session, err } - return fn, stmt.Close, nil + return fn, readStmt.Close, nil } -func logoutQuery( - db *sql.DB, - tables tablesT, -) (func([]byte) error, func() error, error) { - const tmpl = ` +func logoutSQL(prefix string) queryT { + const tmpl_write = ` -- INSERT SOMETHING %s ` - sql := fmt.Sprintf(tmpl, tables.users) - /// fmt.Println(sql) /// + return queryT{ + write: fmt.Sprintf(tmpl_write, prefix), + } +} + +func logoutStmt( + db *sql.DB, + prefix string, +) (func(guuid.UUID) error, func() error, error) { + q := logoutSQL(prefix) - stmt, err := db.Prepare(sql) + writeStmt, err := db.Prepare(q.write) if err != nil { return nil, nil, err } - fn := func(uuid []byte) error { - _, err := stmt.Exec(uuid) + fn := func(uuid guuid.UUID) error { + _, err := writeStmt.Exec(uuid) return err } - return fn, stmt.Close, nil + return fn, writeStmt.Close, nil } -func logoutOthersQuery( - db *sql.DB, - tables tablesT, -) (func([]byte) error, func() error, error) { - const tmpl = ` +func outOthersSQL(prefix string) queryT { + const tmpl_write = ` -- INSERT SOMETHING %s ` - sql := fmt.Sprintf(tmpl, tables.users) - /// fmt.Println(sql) /// + return queryT{ + write: fmt.Sprintf(tmpl_write, prefix), + } +} - stmt, err := db.Prepare(sql) +func outOthersStmt( + db *sql.DB, + prefix string, +) (func(guuid.UUID) error, func() error, error) { + q := outOthersSQL(prefix) + + writeStmt, err := db.Prepare(q.write) if err != nil { return nil, nil, err } - fn := func(uuid []byte) error { - _, err := stmt.Exec(uuid) + fn := func(uuid guuid.UUID) error { + _, err := writeStmt.Exec(uuid) return err } - return fn, stmt.Close, nil + return fn, writeStmt.Close, nil } -func logoutAllQuery( - db *sql.DB, - tables tablesT, -) (func([]byte) error, func() error, error) { - const tmpl = ` +func outAllSQL(prefix string) queryT { + const tmpl_write = ` -- INSERT SOMETHING %s ` - sql := fmt.Sprintf(tmpl, tables.users) - /// fmt.Println(sql) /// + return queryT{ + write: fmt.Sprintf(tmpl_write, prefix), + } +} - stmt, err := db.Prepare(sql) +func outAllStmt( + db *sql.DB, + prefix string, +) (func(guuid.UUID) error, func() error, error) { + q := outAllSQL(prefix) + + writeStmt, err := db.Prepare(q.write) if err != nil { return nil, nil, err } - fn := func(uuid []byte) error { - _, err := stmt.Exec(uuid) + fn := func(uuid guuid.UUID) error { + _, err := writeStmt.Exec(uuid) return err } - return fn, stmt.Close, nil -} - -func initDB(db *sql.DB, tables tablesT) (queriesT, error) { - createTablesErr := createTables(db, tables) - userByEmail, userByEmailClose, userByEmailErr := userByEmailQuery(db, tables) - userByToken, userByTokenClose, userByTokenErr := userByTokenQuery(db, tables) - register, registerClose, registerErr := registerQuery(db, tables) - login, loginClose, loginErr := loginQuery(db, tables) - refresh, refreshClose, refreshErr := refreshQuery(db, tables) - resetPassword, resetPasswordClose, resetPasswordErr := resetPasswordQuery(db, tables) - resetAndConfirm, resetAndConfirmClose, resetAndConfirmErr := resetAndConfirmQuery(db, tables) - changePassword, changePasswordClose, changePasswordErr := changePasswordQuery(db, tables) - sessionByUUID, sessionByUUIDClose, sessionByUUIDErr := sessionByUUIDQuery(db, tables) - logout, logoutClose, logoutErr := logoutQuery(db, tables) - logoutOthers, logoutOthersClose, logoutOthersErr := logoutOthersQuery(db, tables) - logoutAll, logoutAllClose, logoutAllErr := logoutAllQuery(db, tables) - - errs := []error { - createTablesErr, - userByEmailErr, - userByTokenErr, - registerErr, - loginErr, - refreshErr, - resetPasswordErr, - resetAndConfirmErr, - changePasswordErr, - sessionByUUIDErr, - logoutErr, - logoutOthersErr, - logoutAllErr, - } - err := g.SomeError(errs) + return fn, writeStmt.Close, nil +} + +func initDB( + db *sql.DB, + prefix string, +) (queriesT, error) { + createTablesErr := createTables(db, prefix) + byEmail, byEmailClose, byEmailErr := byEmailStmt(db, prefix) + byToken, byTokenClose, byTokenErr := byTokenStmt(db, prefix) + register, registerClose, registerErr := registerStmt(db, prefix) + confirm, confirmClose, confirmErr := confirmStmt(db, prefix) + login, loginClose, loginErr := loginStmt(db, prefix) + refresh, refreshClose, refreshErr := refreshStmt(db, prefix) + reset, resetClose, resetErr := resetStmt(db, prefix) + change, changeClose, changeErr := changeStmt(db, prefix) + byUUID, byUUIDClose, byUUIDErr := byUUIDStmt(db, prefix) + logout, logoutClose, logoutErr := logoutStmt(db, prefix) + outOthers, outOthersClose, outOthersErr := outOthersStmt(db, prefix) + outAll, outAllClose, outAllErr := outAllStmt(db, prefix) + + err := g.SomeError( + createTablesErr, + byEmailErr, + byTokenErr, + registerErr, + confirmErr, + loginErr, + refreshErr, + resetErr, + changeErr, + byUUIDErr, + logoutErr, + outOthersErr, + outAllErr, + ) if err != nil { return queriesT{}, err } close := func() error { - fns := [](func() error){ - userByEmailClose, - userByTokenClose, - registerClose, - loginClose, - refreshClose, - resetPasswordClose, - resetAndConfirmClose, - changePasswordClose, - sessionByUUIDClose, - logoutClose, - logoutOthersClose, - logoutAllClose, - } - return g.SomeFnError(fns) + return g.SomeFnError( + byEmailClose, + byTokenClose, + registerClose, + confirmClose, + loginClose, + refreshClose, + resetClose, + changeClose, + byUUIDClose, + logoutClose, + outOthersClose, + outAllClose, + ) } + // FIXME: lock return queriesT{ - userByEmail: userByEmail, - userByToken: userByToken, - register: register, - login: login, - close: close, - refresh: refresh, - resetPassword: resetPassword, - resetAndConfirm: resetAndConfirm, - changePassword: changePassword, - sessionByUUID: sessionByUUID, - logout: logout, - logoutOthers: logoutOthers, - logoutAll: logoutAll, + byEmail: byEmail, + byToken: byToken, + register: register, + confirm: confirm, + login: login, + refresh: refresh, + reset: reset, + change: change, + byUUID: byUUID, + logout: logout, + outOthers: outOthers, + outAll: outAll, + close: close, }, nil } -func newUserPayload(email string, salt []byte, pwhash []byte) ([]byte, error) { - data := make(map[string]interface{}) - data["email"] = email - data["salt"] = hex.EncodeToString(salt) - data["pwhash"] = hex.EncodeToString(pwhash) - return json.Marshal(data) +func newUserHandler(auth authT) func(q.Message) error { + return func(message q.Message) error { + return nil + } } -func publishRegister( - q liteq.Queue, - email string, - salt []byte, - pwhash []byte, - flowId []byte, -) error { - payload, err := newUserPayload(email, salt, pwhash) - if err != nil { - return err +func sendConfirmationRequestHandler(auth authT) func(q.Message) error { + return func(message q.Message) error { + return nil } +} - return q.Publish(NEW_USER, payload, flowId) +func forgotPasswordRequestHandler(auth authT) func(q.Message) error { + return func(message q.Message) error { + return nil + } } -func register( - auth Auth, - email string, - salt []byte, - pwhash []byte, -) (userT, error) { - flowId := guuid.NewBytes() - waiter := auth.q.WaitFor(NEW_USER, flowId) - defer auth.q.Unwait(waiter) +var consumers = []consumerT{ + consumerT{ + topic: NEW_USER, + handlerFn: newUserHandler, + }, + consumerT{ + topic: SEND_CONFIRMATION_REQUEST, + handlerFn: sendConfirmationRequestHandler, + }, + consumerT{ + topic: FORGOT_PASSWORD_REQUEST, + handlerFn: forgotPasswordRequestHandler, + }, +} +func registerConsumers(auth authT, consumers []consumerT) { + for _, consumer := range consumers { + auth.queue.Subscribe( + consumer.topic, + defaultPrefix + "-" + consumer.topic, + consumer.handlerFn(auth), + ) + } +} + +type resultT[T any] struct{ + value T + err error +} + +type taggedT[T any] struct{ + id int + value T +} + +func startRunner[A any, B any]( + in <-chan taggedT[A], + out chan<- taggedT[B], + fn func(A) B, + done func(), +) { + for input := range in { + out <- taggedT[B]{ + id: input.id, + value: fn(input.value), + } + } + done() +} + +func makePoolRunner[A any, B any](count int, fn func(A) B) (func(A) B, func()) { + var wg sync.WaitGroup + wg.Add(count) + + in := make(chan taggedT[A]) + out := make(chan taggedT[B]) + + for _ = range count { + go startRunner(in, out, fn, wg.Done) + } + + var mutex sync.Mutex + m := map[int]chan B{} + id := 0 + go func() { + for output := range out { + mutex.Lock() + defer mutex.Unlock() + m[output.id] <- output.value + close(m[output.id]) + delete(m, output.id) + } + }() + + poolRunFn := func(input A) B { + c := make(chan B) + { + mutex.Lock() + defer mutex.Unlock() + m[id] = c + id++ + } + + in <- taggedT[A]{ + id: id, + value: input, + } + return <- c + } + + close := func() { + close(in) + wg.Wait() + close(out) + } + + return poolRunFn, close +} + +func asResult[A any, B any](fn func(A) (B, error)) func(A) resultT[B] { + return func(input A) resultT[B] { + output, err := fn(input) + return resultT[B]{ + value: output, + err: err, + } + } +} - err := publishRegister(auth.q, email, salt, pwhash, flowId) +func NewWithPrefix(db *sql.DB, queue q.IQueue, prefix string) (IAuth, error) { + queries, err := initDB(db, prefix) if err != nil { - return userT{}, err + return authT{}, err } - select { - case <-time.After(RegisterTimeout): - return userT{}, ErrRegisterTimeout - case <-waiter: - return auth.queries.userByEmail(email) + numCPU := (runtime.NumCPU() / 2) + 1 + hasher, closeHasher := makePoolRunner(numCPU, asResult(scrypt.Hash)) + checker, closeChecker := makePoolRunner(numCPU, asResult(scrypt.Check)) + + close := func() { + closeHasher() + closeChecker() } + + return authT{ + queries: queries, + queue: queue, + hasher: hasher, + checker: checker, + close: close, + }, nil +} + +func New(db *sql.DB, queue q.IQueue) (IAuth, error) { + return NewWithPrefix(db, queue, defaultPrefix) } -func (auth Auth) Register( +func newUserPayload(email string, salt []byte, pwhash []byte) ([]byte, error) { + data := make(map[string]interface{}) + data["email"] = email + data["salt"] = hex.EncodeToString(salt) + data["pwhash"] = hex.EncodeToString(pwhash) + return json.Marshal(data) +} + +func (auth authT) Register( email string, password string, confirmPassword string, ) (userT, error) { if password != confirmPassword { - return userT{}, ErrPasswordMismatch + return userT{}, ErrPassMismatch } if len(password) < scrypt.MinimumPasswordLength { - return userT{}, ErrPasswordTooShort + return userT{}, ErrTooShort } // special check for sql.ErrNoRows to combat enumeration attacks. - _, lookupErr := auth.queries.userByEmail(email) + // FIXME: how so? + _, lookupErr := auth.queries.byEmail(email) if lookupErr != nil && lookupErr != sql.ErrNoRows { return userT{}, lookupErr } - salt, err := scrypt.SaltFrom(rand.Reader) + salt, err := scrypt.Salt() if err != nil { return userT{}, err } - pwhash, err := scrypt.HashFrom([]byte(password), salt) - if err != nil { - return userT{}, err + input := scrypt.HashInput{ + Password: []byte(password), + Salt: salt, } + result := auth.hasher(input) /* We also try to register anyway, to prevent disk IO timing attacks. + / FIXME: how so? */ - user, err := register(auth, email, salt, pwhash) + + flowID := guuid.New() + waiter := auth.queue.WaitFor(NEW_USER, flowID, "register") + defer waiter.Close() + + payload, err := newUserPayload(email, salt, result.value) + if err != nil { + return userT{}, err + } + + unsent := q.UnsentMessage{ + Topic: NEW_USER, + FlowID: flowID, + Payload: payload, + } + _, err = auth.queue.Publish(unsent) if err != nil { - if lookupErr != nil { - return userT{}, ErrAlreadyRegistered - } return userT{}, err } + user := userT{} + select { + case <-time.After(RegisterTimeout): + err = ErrTimeout + case <-waiter.Channel: + user, err = auth.queries.byEmail(email) + } + return user, nil } -func sendConfirmationPayload(email string) ([]byte, error) { +func sendConfirmationMessage( + email string, + flowID guuid.UUID, +) (q.UnsentMessage, error) { data := make(map[string]interface{}) data["email"] = email - return json.Marshal(data) + payload, err := json.Marshal(data) + if err != nil { + return q.UnsentMessage{}, err + } + + return q.UnsentMessage{ + Topic: SEND_CONFIRMATION_REQUEST, + FlowID: flowID, + Payload: payload, + }, nil } -func publishSendConfirmation(q liteq.Queue, email string, flowId []byte) error { - payload, err := sendConfirmationPayload(email) +func (auth authT) ResendConfirmation(email string) error { + unsent, err := sendConfirmationMessage(email, guuid.New()) if err != nil { return err } - return q.Publish(SEND_CONFIRMATION_REQUEST, payload, flowId) -} -func (auth Auth) ResendConfirmation(email string) error { - return publishSendConfirmation(auth.q, email, guuid.NewBytes()) + _, err = auth.queue.Publish(unsent) + return err } -func (auth Auth) ConfirmEmail(token guuid.UUID) (sessionT, error) { - return auth.queries.confirm(token[:]) +func (auth authT) ConfirmEmail(token guuid.UUID) (sessionT, error) { + return auth.queries.confirm(token) } -func (auth Auth) LoginEmail( +func (auth authT) LoginEmail( email string, password string, ) (sessionT, error) { if len(password) < scrypt.MinimumPasswordLength { - return sessionT{}, ErrPasswordTooShort + return sessionT{}, ErrTooShort } // special check for sql.ErrNoRows to combat enumeration attacks. - user, err := auth.queries.userByEmail(email) + user, err := auth.queries.byEmail(email) if err != nil && err != sql.ErrNoRows { return sessionT{}, err } - ok, err := scrypt.CheckFrom([]byte(password), user.salt, user.pwhash) + input := scrypt.CheckInput{ + Password: []byte(password), + Salt: user.salt, + Hash: user.pwhash, + } + ok, err := scrypt.Check(input) if err != nil { return sessionT{}, err } if !ok { - return sessionT{}, ErrEmailPasswordCombo + return sessionT{}, ErrBadCombo } if user.confirmed_at == nil { - return sessionT{}, ErrUnconfirmedUser + return sessionT{}, ErrUnconfirmed } dbSession, err := auth.queries.login(email, password) @@ -770,29 +1035,38 @@ func (auth Auth) LoginEmail( return dbSession, nil } -func forgotPasswordPayload(email string) ([]byte, error) { +func forgotPasswordMessage( + email string, + flowID guuid.UUID, +) (q.UnsentMessage, error) { data := make(map[string]interface{}) data["email"] = email - return json.Marshal(data) -} - -func publishForgotPassword(q liteq.Queue, email string, flowId []byte) error { - payload, err := forgotPasswordPayload(email) + payload, err := json.Marshal(data) if err != nil { - return err + return q.UnsentMessage{}, err } - return q.Publish(FORGOT_PASSWORD_REQUEST, payload, flowId) + return q.UnsentMessage{ + Topic: FORGOT_PASSWORD_REQUEST, + FlowID: flowID, + Payload: payload, + }, nil } -func (auth Auth) ForgotPassword(email string) error { +func (auth authT) ForgotPassword(email string) error { // special check for sql.ErrNoRows to combat enumeration attacks. - user, err := auth.queries.userByEmail(email) + user, err := auth.queries.byEmail(email) if err != nil && err != sql.ErrNoRows { return err } - return publishForgotPassword(auth.q, user.email, guuid.NewBytes()) + unsent, err := forgotPasswordMessage(user.email, guuid.New()) + if err != nil { + return err + } + + _, err = auth.queue.Publish(unsent) + return err } func checkSession(session sessionT, now time.Time) error { @@ -807,8 +1081,11 @@ func checkSession(session sessionT, now time.Time) error { return nil } -func validateSession(lookupFn func([]byte) (sessionT, error), session sessionT) error { - dbSession, err := lookupFn(session.uuid[:]) +func validateSession( + lookupFn func(guuid.UUID) (sessionT, error), + session sessionT, +) error { + dbSession, err := lookupFn(session.uuid) if err != nil { return err } @@ -816,184 +1093,139 @@ func validateSession(lookupFn func([]byte) (sessionT, error), session sessionT) return checkSession(dbSession, time.Now()) } -func (auth Auth) Refresh(session sessionT) (sessionT, error) { - err := validateSession(auth.queries.sessionByUUID, session) +func (auth authT) Refresh(session sessionT) (sessionT, error) { + err := validateSession(auth.queries.byUUID, session) if err != nil { return sessionT{}, err } - return auth.queries.refresh(session.uuid[:]) + return auth.queries.refresh(session.uuid) } -func (auth Auth) ResetPassword( - token []byte, +func (auth authT) ResetPassword( + token guuid.UUID, password string, confirmPassword string, ) (sessionT, error) { if password != confirmPassword { - return sessionT{}, ErrPasswordMismatch + return sessionT{}, ErrPassMismatch } if len(password) < scrypt.MinimumPasswordLength { - return sessionT{}, ErrPasswordTooShort + return sessionT{}, ErrTooShort } - user, err := auth.queries.userByToken(token) + user, err := auth.queries.byToken(token) if err != nil { return sessionT{}, err } - pwhash, err := scrypt.HashFrom([]byte(password), user.salt) + input := scrypt.HashInput{ + Password: []byte(password), + Salt: user.salt, + } + pwhash, err := scrypt.Hash(input) if err != nil { return sessionT{}, err } if user.confirmed_at != nil { - return auth.queries.resetPassword(user.id, pwhash, token) + return auth.queries.reset(user.id, pwhash, token) } else { - return auth.queries.resetAndConfirm(user.id, pwhash, token) + // return auth.queries.confirm(user.id, pwhash, token) + return auth.queries.confirm(token) } } -func (auth Auth) ChangePassword( +func (auth authT) ChangePassword( user userT, currentPassword string, newPassword string, confirmNewPassword string, ) (sessionT, error) { if newPassword != confirmNewPassword { - return sessionT{}, ErrPasswordMismatch + return sessionT{}, ErrPassMismatch } if len(newPassword) < scrypt.MinimumPasswordLength { - return sessionT{}, ErrPasswordTooShort + return sessionT{}, ErrTooShort } - pwhash, err := scrypt.HashFrom([]byte(newPassword), user.salt) + // FIXME + input := scrypt.HashInput{ + Password: []byte(newPassword), + Salt: user.salt, + } + pwhash, err := scrypt.Hash(input) if err != nil { return sessionT{}, err } if user.confirmed_at == nil { - return sessionT{}, ErrUnconfirmedUser + return sessionT{}, ErrUnconfirmed } - return auth.queries.changePassword(user.id, pwhash) + return auth.queries.change(user.id, pwhash) } func runLogout( - lookupFn func([]byte) (sessionT, error), + lookupFn func(guuid.UUID) (sessionT, error), session sessionT, - queryFn func([]byte) error, + queryFn func(guuid.UUID) error, ) error { err := validateSession(lookupFn, session) if err != nil { return err } - return queryFn(session.uuid[:]) + return queryFn(session.uuid) } -func (auth Auth) Logout(session sessionT) error { - return runLogout(auth.queries.sessionByUUID, session, auth.queries.logout) +func (auth authT) Logout(session sessionT) error { + return runLogout( + auth.queries.byUUID, + session, + auth.queries.logout, + ) } -func (auth Auth) LogoutOthers(session sessionT) error { - return runLogout(auth.queries.sessionByUUID, session, auth.queries.logoutOthers) +func (auth authT) LogoutOthers(session sessionT) error { + return runLogout( + auth.queries.byUUID, + session, + auth.queries.outOthers, + ) } -func (auth Auth) LogoutAll(session sessionT) error { - return runLogout(auth.queries.sessionByUUID, session, auth.queries.logoutAll) +func (auth authT) LogoutAll(session sessionT) error { + return runLogout( + auth.queries.byUUID, + session, + auth.queries.outAll, + ) } -func (auth Auth) Close() error { +func (auth authT) Close() error { return auth.queries.close() } -func newUserHandler(auth Auth) func([]byte) error { - return func(payload []byte) error { - return nil - } -} - -func sendConfirmationRequestHandler(auth Auth) func([]byte) error { - return func(payload []byte) error { - return nil - } -} - -func forgotPasswordRequestHandler(auth Auth) func([]byte) error { - return func(payload []byte) error { - return nil - } -} - -var consumers = []consumerT{ - consumerT{ - topic: NEW_USER, - handlerFn: newUserHandler, - }, - consumerT{ - topic: SEND_CONFIRMATION_REQUEST, - handlerFn: sendConfirmationRequestHandler, - }, - consumerT{ - topic: FORGOT_PASSWORD_REQUEST, - handlerFn: forgotPasswordRequestHandler, - }, -} -func registerConsumers(auth Auth, consumers []consumerT) { - for _, consumer := range consumers { - auth.q.Subscribe( - consumer.topic, - defaultPrefix + "-" + consumer.topic, - consumer.handlerFn(auth), - ) - } -} - -func NewWithPrefix(db *sql.DB, q liteq.Queue, prefix string) (Auth, error) { - tables, err := tablesFrom(prefix) - if err != nil { - return Auth{}, err - } - - queries, err := initDB(db, tables) - if err != nil { - return Auth{}, err - } - - return Auth{ - tables: tables, - queries: queries, - q: q, - }, nil -} - -func New(db *sql.DB, q liteq.Queue) (Auth, error) { - return NewWithPrefix(db, q, defaultPrefix) -} - func Main() { g.Init() - q := new(liteq.Queue) - sql.Register("sqlite-liteq", liteq.MakeDriver(q)) - - db, err := sql.Open("sqlite-liteq", "file:gracha.db?mode=memory&cache=shared") + db, err := sql.Open("acude", "file:gracha.db?mode=memory&cache=shared") if err != nil { panic(err) } - // defer db.Close() + defer db.Close() - *q, err = liteq.New(db) + queue, err := q.New(db) if err != nil { panic(err) } - // defer q.Close() + defer queue.Close() - auth, err := New(db, *q) + auth, err := New(db, queue) if err != nil { fmt.Println(err) panic(err) @@ -1006,6 +1238,6 @@ func Main() { } return - fmt.Printf("q: %#v\n", q) + fmt.Printf("q: %#v\n", queue) fmt.Printf("auth: %#v\n", auth) } |