diff options
Diffstat (limited to 'src/cracha.go')
-rw-r--r-- | src/cracha.go | 360 |
1 files changed, 240 insertions, 120 deletions
diff --git a/src/cracha.go b/src/cracha.go index dc4193c..f669198 100644 --- a/src/cracha.go +++ b/src/cracha.go @@ -1,7 +1,6 @@ package cracha import ( - "context" "database/sql" "encoding/hex" "encoding/json" @@ -48,6 +47,12 @@ var ( +type dbconfigT struct{ + shared *sql.DB + dbpath string + prefix string +} + type queryT struct{ write string read string @@ -55,19 +60,20 @@ type queryT struct{ } type queriesT struct{ - register func(guuid.UUID, string, []byte, []byte) (userT, error) - sendToken func(guuid.UUID, string) error - confirm func(string, guuid.UUID) (sessionT, error) - byEmail func(string) (userT, error) - login func(guuid.UUID, guuid.UUID) (sessionT, error) - refresh func(guuid.UUID, guuid.UUID) (sessionT, error) - reset func(int64, []byte, guuid.UUID) (sessionT, error) - change func(guuid.UUID, []byte) (sessionT, error) - byUUID func(guuid.UUID) (sessionT, error) - logout func(guuid.UUID) error - outOthers func(guuid.UUID) error - outAll func(guuid.UUID) error - close func() error + register func(guuid.UUID, string, []byte, []byte) (userT, error) + sendToken func(guuid.UUID, string) error + confirm func(string, guuid.UUID) (sessionT, error) + byEmail func(string) (userT, error) + userByUUID func(guuid.UUID) (userT, error) + login func(guuid.UUID, guuid.UUID) (sessionT, error) + refresh func(guuid.UUID, guuid.UUID) (sessionT, error) + reset func(int64, []byte, guuid.UUID) (sessionT, error) + change func(guuid.UUID, []byte) (sessionT, error) + byUUID func(guuid.UUID) (sessionT, error) + logout func(guuid.UUID) error + outOthers func(guuid.UUID) error + outAll func(guuid.UUID) error + close func() error } type userT struct{ @@ -82,7 +88,6 @@ type userT struct{ type User struct{ UUID guuid.UUID - Salt []byte Confirmed bool } @@ -91,7 +96,6 @@ type sessionT struct{ timestamp time.Time uuid guuid.UUID userID guuid.UUID - // type_ string revoked_at *time.Time } @@ -120,7 +124,7 @@ type IAuth interface{ ForgotPassword(string) error Refresh(Session) (Session, error) ResetPassword(guuid.UUID, string, string) (Session, error) - ChangePassword(User, string, string, string) (Session, error) + ChangePassword(guuid.UUID, string, string, string) (Session, error) Logout(Session) error LogoutOthers(Session) error LogoutAll(Session) error @@ -134,8 +138,8 @@ func validateSession(session sessionT) (sessionT, error) { return session, nil } -func tryRollback(db *sql.DB, ctx context.Context, err error) error { - _, rollbackErr := db.ExecContext(ctx, "ROLLBACK;") +func tryRollback(tx *sql.Tx, err error) error { + rollbackErr := tx.Rollback() if rollbackErr != nil { return fmt.Errorf( rollbackErrorFmt, @@ -147,27 +151,72 @@ func tryRollback(db *sql.DB, ctx context.Context, err error) error { return err } -func inTx(db *sql.DB, fn func(context.Context) error) error { - ctx := context.Background() - - _, err := db.ExecContext(ctx, "BEGIN IMMEDIATE;") +func inTx(db *sql.DB, fn func(*sql.Tx) error) error { + tx, err := db.Begin() if err != nil { return err } - err = fn(ctx) + err = fn(tx) if err != nil { - return tryRollback(db, ctx, err) + return tryRollback(tx, err) } - _, err = db.ExecContext(ctx, "COMMIT;") + err = tx.Commit() if err != nil { - return tryRollback(db, ctx, err) + return tryRollback(tx, err) } return nil } +func serialized[A any, B any](callback func(...A) B) (func(...A) B, func()) { + in := make(chan []A) + out := make(chan B) + + closed := false + var ( + closeWg sync.WaitGroup + closeMutex sync.Mutex + ) + closeWg.Add(1) + + go func() { + for input := range in { + out <- callback(input...) + } + close(out) + closeWg.Done() + }() + + fn := func(input ...A) B { + in <- input + return (<- out) + } + + closeFn := func() { + closeMutex.Lock() + defer closeMutex.Unlock() + if closed { + return + } + close(in) + closed = true + closeWg.Wait() + } + + return fn, closeFn +} + +func execSerialized(query string, db *sql.DB) (func(...any) error, func()) { + return serialized(func(args ...any) error { + return inTx(db, func(tx *sql.Tx) error { + _, err := tx.Exec(query, args...) + return err + }) + }) +} + func createTablesSQL(prefix string) queryT { const tmpl_write = ` CREATE TABLE IF NOT EXISTS "%s_users" ( @@ -293,8 +342,8 @@ func createTablesSQL(prefix string) queryT { func createTables(db *sql.DB, prefix string) error { q := createTablesSQL(prefix) - return inTx(db, func(ctx context.Context) error { - _, err := db.ExecContext(ctx, q.write) + return inTx(db, func(tx *sql.Tx) error { + _, err := tx.Exec(q.write) return err }) } @@ -315,21 +364,20 @@ func registerSQL(prefix string) queryT { } func registerStmt( - db *sql.DB, - prefix string, + cfg dbconfigT, ) ( func(guuid.UUID, string, []byte, []byte) (userT, error), func() error, error, ) { - q := registerSQL(prefix) + q := registerSQL(cfg.prefix) - writeStmt, err := db.Prepare(q.write) + writeStmt, err := cfg.shared.Prepare(q.write) if err != nil { return nil, nil, err } - readStmt, err := db.Prepare(q.read) + readStmt, err := cfg.shared.Prepare(q.read) if err != nil { return nil, nil, err } @@ -389,12 +437,11 @@ func sendTokenSQL(prefix string) queryT { } func sendTokenStmt( - db *sql.DB, - prefix string, + cfg dbconfigT, ) (func(guuid.UUID, string) error, func() error, error) { - q := sendTokenSQL(prefix) + q := sendTokenSQL(cfg.prefix) - writeStmt, err := db.Prepare(q.write) + writeStmt, err := cfg.shared.Prepare(q.write) if err != nil { return nil, nil, err } @@ -444,22 +491,21 @@ func confirmSQL(prefix string) queryT { } func confirmStmt( - db *sql.DB, - prefix string, + cfg dbconfigT, ) (func(string, guuid.UUID) (sessionT, error), func() error, error) { - q := confirmSQL(prefix) + q := confirmSQL(cfg.prefix) - writeStmt, err := db.Prepare(q.write) + writeStmt, err := cfg.shared.Prepare(q.write) if err != nil { return nil, nil, err } - readStmt, err := db.Prepare(q.read) + readStmt, err := cfg.shared.Prepare(q.read) if err != nil { return nil, nil, g.WrapErrors(writeStmt.Close(), err) } - sessionStmt, err := db.Prepare(q.session) + sessionStmt, err := cfg.shared.Prepare(q.session) if err != nil { return nil, nil, g.WrapErrors( writeStmt.Close(), @@ -523,7 +569,6 @@ func confirmStmt( } func byEmailSQL(prefix string) queryT { - // FIXME: rewrite as LEFT JOIN? const tmpl_read = ` SELECT id, timestamp, uuid, salt, pwhash, ( CASE WHEN EXISTS ( @@ -544,12 +589,11 @@ func byEmailSQL(prefix string) queryT { } func byEmailStmt( - db *sql.DB, - prefix string, + cfg dbconfigT, ) (func(string) (userT, error), func() error, error) { - q := byEmailSQL(prefix) + q := byEmailSQL(cfg.prefix) - readStmt, err := db.Prepare(q.read) + readStmt, err := cfg.shared.Prepare(q.read) if err != nil { return nil, nil, err } @@ -587,6 +631,65 @@ func byEmailStmt( return fn, readStmt.Close, nil } +func userByUUIDSQL(prefix string) queryT { + const tmpl_read = ` + SELECT id, timestamp, email, salt, pwhash, ( + CASE WHEN EXISTS ( + SELECT id FROM "%s_user_confirmations" + WHERE user_id = ( + SELECT id FROM "%s_users" + WHERE uuid = ? + ) + ) THEN 1 + ELSE 0 + END + ) as confirmed + FROM "%s_users" WHERE uuid = ?; + ` + return queryT{ + read: fmt.Sprintf(tmpl_read, prefix, prefix, prefix), + } +} + +func userByUUIDStmt( + cfg dbconfigT, +) (func(guuid.UUID) (userT, error), func() error, error) { + q := userByUUIDSQL(cfg.prefix) + + readStmt, err := cfg.shared.Prepare(q.read) + if err != nil { + return nil, nil, err + } + + fn := func(userID guuid.UUID) (userT, error) { + user := userT{ + uuid: userID, + } + + var timestr string + err := readStmt.QueryRow(userID[:], userID[:]).Scan( + &user.id, + ×tr, + &user.email, + &user.salt, + &user.pwhash, + &user.confirmed, + ) + if err != nil { + return userT{}, err + } + + user.timestamp, err = time.Parse(time.RFC3339Nano, timestr) + if err != nil { + return userT{}, err + } + + return user, nil + } + + return fn, readStmt.Close, nil +} + func loginSQL(prefix string) queryT { const tmpl_session = ` INSERT INTO "%s_sessions" (uuid, user_id) @@ -601,12 +704,11 @@ func loginSQL(prefix string) queryT { } func loginStmt( - db *sql.DB, - prefix string, + cfg dbconfigT, ) (func(guuid.UUID, guuid.UUID) (sessionT, error), func() error, error) { - q := loginSQL(prefix) + q := loginSQL(cfg.prefix) - sessionStmt, err := db.Prepare(q.session) + sessionStmt, err := cfg.shared.Prepare(q.session) if err != nil { return nil, nil, err } @@ -671,12 +773,11 @@ func refreshSQL(prefix string) queryT { } func refreshStmt( - db *sql.DB, - prefix string, + cfg dbconfigT, ) (func(guuid.UUID, guuid.UUID) (sessionT, error), func() error, error) { - q := refreshSQL(prefix) + q := refreshSQL(cfg.prefix) - writeStmt, err := db.Prepare(q.write) + writeStmt, err := cfg.shared.Prepare(q.write) if err != nil { return nil, nil, err } @@ -726,12 +827,11 @@ func resetSQL(prefix string) queryT { } func resetStmt( - db *sql.DB, - prefix string, + cfg dbconfigT, ) (func(int64, []byte, guuid.UUID) (sessionT, error), func() error, error) { - q := resetSQL(prefix) + q := resetSQL(cfg.prefix) - writeStmt, err := db.Prepare(q.write) + writeStmt, err := cfg.shared.Prepare(q.write) if err != nil { return nil, nil, err } @@ -758,12 +858,11 @@ func changeSQL(prefix string) queryT { } func changeStmt( - db *sql.DB, - prefix string, + cfg dbconfigT, ) (func(guuid.UUID, []byte) (sessionT, error), func() error, error) { - q := changeSQL(prefix) + q := changeSQL(cfg.prefix) - writeStmt, err := db.Prepare(q.write) + writeStmt, err := cfg.shared.Prepare(q.write) if err != nil { return nil, nil, err } @@ -790,12 +889,11 @@ func byUUIDSQL(prefix string) queryT { } func byUUIDStmt( - db *sql.DB, - prefix string, + cfg dbconfigT, ) (func(guuid.UUID) (sessionT, error), func() error, error) { - q := byUUIDSQL(prefix) + q := byUUIDSQL(cfg.prefix) - readStmt, err := db.Prepare(q.read) + readStmt, err := cfg.shared.Prepare(q.read) if err != nil { return nil, nil, err } @@ -822,12 +920,11 @@ func logoutSQL(prefix string) queryT { } func logoutStmt( - db *sql.DB, - prefix string, + cfg dbconfigT, ) (func(guuid.UUID) error, func() error, error) { - q := logoutSQL(prefix) + q := logoutSQL(cfg.prefix) - writeStmt, err := db.Prepare(q.write) + writeStmt, err := cfg.shared.Prepare(q.write) if err != nil { return nil, nil, err } @@ -850,12 +947,11 @@ func outOthersSQL(prefix string) queryT { } func outOthersStmt( - db *sql.DB, - prefix string, + cfg dbconfigT, ) (func(guuid.UUID) error, func() error, error) { - q := outOthersSQL(prefix) + q := outOthersSQL(cfg.prefix) - writeStmt, err := db.Prepare(q.write) + writeStmt, err := cfg.shared.Prepare(q.write) if err != nil { return nil, nil, err } @@ -878,12 +974,11 @@ func outAllSQL(prefix string) queryT { } func outAllStmt( - db *sql.DB, - prefix string, + cfg dbconfigT, ) (func(guuid.UUID) error, func() error, error) { - q := outAllSQL(prefix) + q := outAllSQL(cfg.prefix) - writeStmt, err := db.Prepare(q.write) + writeStmt, err := cfg.shared.Prepare(q.write) if err != nil { return nil, nil, err } @@ -897,29 +992,47 @@ func outAllStmt( } func initDB( - db *sql.DB, + dbpath string, prefix string, ) (queriesT, error) { - createTablesErr := createTables(db, prefix) - register, registerClose, registerErr := registerStmt(db, prefix) - sendToken, sendTokenClose, sendTokenErr := sendTokenStmt(db, prefix) - confirm, confirmClose, confirmErr := confirmStmt(db, prefix) - byEmail, byEmailClose, byEmailErr := byEmailStmt(db, prefix) - login, loginClose, loginErr := loginStmt(db, prefix) - refresh, refreshClose, refreshErr := refreshStmt(db, prefix) - reset, resetClose, resetErr := resetStmt(db, prefix) - change, changeClose, changeErr := changeStmt(db, prefix) - byUUID, byUUIDClose, byUUIDErr := byUUIDStmt(db, prefix) - logout, logoutClose, logoutErr := logoutStmt(db, prefix) - outOthers, outOthersClose, outOthersErr := outOthersStmt(db, prefix) - outAll, outAllClose, outAllErr := outAllStmt(db, prefix) - - err := g.SomeError( + err := g.ValidateSQLTablePrefix(prefix) + if err != nil { + return queriesT{}, err + } + + shared, err := sql.Open(golite.DriverName, dbpath) + if err != nil { + return queriesT{}, err + } + + cfg := dbconfigT{ + shared: shared, + dbpath: dbpath, + prefix: prefix, + } + + createTablesErr := createTables(shared, prefix) + register, registerClose, registerErr := registerStmt(cfg) + sendToken, sendTokenClose, sendTokenErr := sendTokenStmt(cfg) + confirm, confirmClose, confirmErr := confirmStmt(cfg) + byEmail, byEmailClose, byEmailErr := byEmailStmt(cfg) + userByUUID, userByUUIDClose, userByUUIDErr := userByUUIDStmt(cfg) + login, loginClose, loginErr := loginStmt(cfg) + refresh, refreshClose, refreshErr := refreshStmt(cfg) + reset, resetClose, resetErr := resetStmt(cfg) + change, changeClose, changeErr := changeStmt(cfg) + byUUID, byUUIDClose, byUUIDErr := byUUIDStmt(cfg) + logout, logoutClose, logoutErr := logoutStmt(cfg) + outOthers, outOthersClose, outOthersErr := outOthersStmt(cfg) + outAll, outAllClose, outAllErr := outAllStmt(cfg) + + err = g.SomeError( createTablesErr, registerErr, sendTokenErr, confirmErr, byEmailErr, + userByUUIDErr, loginErr, refreshErr, resetErr, @@ -935,18 +1048,19 @@ func initDB( close := func() error { return g.SomeFnError( - registerClose, - sendTokenClose, - confirmClose, - byEmailClose, - loginClose, - refreshClose, - resetClose, - changeClose, - byUUIDClose, - logoutClose, - outOthersClose, - outAllClose, + registerClose, + sendTokenClose, + confirmClose, + byEmailClose, + userByUUIDClose, + loginClose, + refreshClose, + resetClose, + changeClose, + byUUIDClose, + logoutClose, + outOthersClose, + outAllClose, ) } @@ -977,6 +1091,11 @@ func initDB( defer connMutex.RUnlock() return byEmail(a) }, + userByUUID: func(a guuid.UUID) (userT, error) { + connMutex.RLock() + defer connMutex.RUnlock() + return userByUUID(a) + }, login: func(a guuid.UUID, b guuid.UUID) (sessionT, error) { connMutex.RLock() defer connMutex.RUnlock() @@ -1179,13 +1298,9 @@ func NewWithPrefix(databasePath string, prefix string) (IAuth, error) { return authT{}, err } - db, err := sql.Open(golite.DriverName, databasePath) - if err != nil { - return authT{}, err - } - - queries, err := initDB(db, prefix) + queries, err := initDB(databasePath, prefix) if err != nil { + queue.Close() return authT{}, err } @@ -1193,13 +1308,13 @@ func NewWithPrefix(databasePath string, prefix string) (IAuth, error) { hasher, closeHasher := makePoolRunner(numCPU, asResult(scrypt.Hash)) closeFn := func() { + queue.Close() unregisterConsumers(queue, consumers, prefix) closeHasher() } auth := authT{ queue: queue, - db: db, queries: queries, hasher: unwrapResult(hasher), close: closeFn, @@ -1516,7 +1631,7 @@ func (auth authT) ResetPassword( } func (auth authT) ChangePassword( - user User, + userID guuid.UUID, currentPassword string, newPassword string, confirmNewPassword string, @@ -1529,20 +1644,25 @@ func (auth authT) ChangePassword( return Session{}, ErrTooShort } + user, err := auth.queries.userByUUID(userID) + if err != nil { + return Session{}, err + } + input := scrypt.HashInput{ Password: []byte(newPassword), - Salt: user.Salt, + Salt: user.salt, } pwhash, err := auth.hasher(input) if err != nil { return Session{}, err } - if !user.Confirmed { + if !user.confirmed { return Session{}, ErrUnconfirmed } - session, err := auth.queries.change(user.UUID, pwhash) + session, err := auth.queries.change(user.uuid, pwhash) if err != nil { return Session{}, nil } |