diff options
-rw-r--r-- | .gitignore | 10 | ||||
-rw-r--r-- | Makefile | 83 | ||||
-rw-r--r-- | deps.mk | 57 | ||||
-rwxr-xr-x | mkdeps.sh | 25 | ||||
-rw-r--r-- | src/gracha.go | 1316 | ||||
-rw-r--r-- | tests/benchmarks/register-login/gracha.go | 9 | ||||
l--------- | tests/benchmarks/register-login/main.go | 1 | ||||
-rw-r--r-- | tests/functional/register-twice/gracha.go | 9 | ||||
l--------- | tests/functional/register-twice/main.go | 1 | ||||
-rw-r--r-- | tests/fuzz/api/gracha.go | 35 | ||||
l--------- | tests/fuzz/api/main.go | 1 | ||||
-rw-r--r-- | tests/gracha.go | 56 | ||||
-rw-r--r-- | tests/queries.sql | 0 |
13 files changed, 992 insertions, 611 deletions
@@ -1,7 +1,15 @@ /src/version.go /*.bin -/*.db +/*.db* /src/*.a /src/*.bin /tests/*.a /tests/*.bin +/tests/functional/*/*.a +/tests/functional/*/*.bin +/tests/fuzz/*/*.a +/tests/fuzz/*/*.bin +/tests/benchmarks/*/*.a +/tests/benchmarks/*/*.bin +/tests/benchmarks/*/*.txt +/tests/fuzz/corpus/ @@ -17,7 +17,7 @@ MANDIR = $(SHAREDIR)/man EXEC = ./ ## Where to store the installation. Empty by default. DESTDIR = -LDLIBS = -lsqlite3 +LDLIBS = --static -lscrypt-kdf -lsqlite3 -lm GOCFLAGS = -I $(GOLIBDIR) GOLDFLAGS = -L $(GOLIBDIR) @@ -26,17 +26,26 @@ GOLDFLAGS = -L $(GOLIBDIR) .SUFFIXES: .SUFFIXES: .go .a .bin .bin-check +.go.a: + go tool compile $(GOCFLAGS) -I $(@D) -o $@ -p $(*F) \ + `find $< $$(if [ $(*F) != main ]; then \ + echo src/$(NAME).go src/version.go; fi) | uniq` + +.a.bin: + go tool link $(GOLDFLAGS) -L $(@D) -o $@ --extldflags '$(LDLIBS)' $< + all: include deps.mk -objects = \ - src/$(NAME).a \ - src/main.a \ - tests/$(NAME).a \ - tests/main.a \ +libs.a = $(libs.go:.go=.a) +mains.a = $(mains.go:.go=.a) +mains.bin = $(mains.go:.go=.bin) +functional-tests/lib.a = $(functional-tests/lib.go:.go=.a) +fuzz-targets/lib.a = $(fuzz-targets/lib.go:.go=.a) +benchmarks/lib.a = $(benchmarks/lib.go:.go=.a) sources = \ src/$(NAME).go \ @@ -46,13 +55,16 @@ sources = \ derived-assets = \ src/version.go \ - $(objects) \ - src/main.bin \ - tests/main.bin \ + $(libs.a) \ + $(mains.a) \ + $(mains.bin) \ $(NAME).bin \ side-assets = \ $(NAME).db* \ + tests/functional/*/*.go.db* \ + tests/fuzz/corpus/ \ + tests/benchmarks/*/main.txt \ @@ -61,40 +73,35 @@ side-assets = \ all: $(derived-assets) -$(objects): Makefile +$(libs.a): Makefile deps.mk +$(libs.a): src/$(NAME).go src/version.go -src/$(NAME).a: src/$(NAME).go src/version.go - go tool compile $(GOCFLAGS) -o $@ -p $(*F) -I $(@D) $*.go src/version.go -src/main.a: src/main.go src/$(NAME).a -tests/main.a: tests/main.go tests/$(NAME).a -src/main.a tests/main.a: - go tool compile $(GOCFLAGS) -o $@ -p $(*F) -I $(@D) $*.go +$(fuzz-targets/lib.a): + go tool compile $(GOCFLAGS) -o $@ -p $(NAME) -d=libfuzzer \ + src/version.go $*.go src/$(NAME).go -tests/$(NAME).a: tests/$(NAME).go src/$(NAME).go src/version.go - go tool compile $(GOCFLAGS) -o $@ -p $(*F) $*.go src/$(*F).go src/version.go - -src/main.bin: src/main.a -tests/main.bin: tests/main.a -src/main.bin tests/main.bin: - go tool link $(GOLDFLAGS) -o $@ -L $(@D) --extldflags '$(LDLIBS)' $*.a +src/version.go: Makefile + echo 'package $(NAME); const Version = "$(VERSION)"' > $@ $(NAME).bin: src/main.bin ln -fs $? $@ -src/version.go: Makefile - echo 'package $(NAME); const Version = "$(VERSION)"' > $@ +.PRECIOUS: tests/queries.sql +tests/queries.sql: tests/main.bin ALWAYS + env TESTING_DUMP_SQL_QUERIES=1 $(EXEC)tests/main.bin | diff -U10 $@ - tests.bin-check = \ - tests/main.bin-check \ + tests/main.bin-check \ + $(functional-tests/main.go:.go=.bin-check) \ -tests/main.bin-check: tests/main.bin $(tests.bin-check): $(EXEC)$*.bin check-unit: $(tests.bin-check) +check-unit: tests/queries.sql integration-tests = \ @@ -107,6 +114,7 @@ $(integration-tests): ALWAYS sh $@ check-integration: $(integration-tests) +check-integration: fuzz ## Run all tests. Each test suite is isolated, so that a parallel @@ -116,6 +124,27 @@ check: check-unit check-integration +FUZZSEC=1 +fuzz-targets/main.bin-check = $(fuzz-targets/main.go:.go=.bin-check) +$(fuzz-targets/main.bin-check): + $(EXEC)$*.bin --test.fuzztime=$(FUZZSEC)s \ + --test.fuzz='.*' --test.fuzzcachedir=tests/fuzz/corpus + +fuzz: $(fuzz-targets/main.bin-check) + + + +benchmarks/main.bin-check = $(benchmarks/main.go:.go=.bin-check) +$(benchmarks/main.bin-check): + rm -f $*.txt + printf '%s\n' '$(EXEC)$*.bin' >> $*.txt + LANG=POSIX.UTF-8 time -p $(EXEC)$*.bin 2>> $*.txt + printf '%s\n' '$*.txt' + +bench: $(benchmarks/main.bin-check) + + + ## Remove *all* derived artifacts produced during the build. ## A dedicated test asserts that this is always true. clean: @@ -0,0 +1,57 @@ +libs.go = \ + src/gracha.go \ + tests/benchmarks/register-login/gracha.go \ + tests/functional/register-twice/gracha.go \ + tests/fuzz/api/gracha.go \ + tests/gracha.go \ + +mains.go = \ + src/main.go \ + tests/benchmarks/register-login/main.go \ + tests/functional/register-twice/main.go \ + tests/fuzz/api/main.go \ + tests/main.go \ + +functional-tests/libs.go = \ + tests/functional/register-twice/gracha.go \ + +functional-tests/main.go = \ + tests/functional/register-twice/main.go \ + +fuzz-targets/lib.go = \ + tests/fuzz/api/gracha.go \ + +fuzz-targets/main.go = \ + tests/fuzz/api/main.go \ + +benchmarks/lib.go = \ + tests/benchmarks/register-login/gracha.go \ + +benchmarks/main.go = \ + tests/benchmarks/register-login/main.go \ + +src/gracha.a: src/gracha.go +src/main.a: src/main.go +tests/benchmarks/register-login/gracha.a: tests/benchmarks/register-login/gracha.go +tests/benchmarks/register-login/main.a: tests/benchmarks/register-login/main.go +tests/functional/register-twice/gracha.a: tests/functional/register-twice/gracha.go +tests/functional/register-twice/main.a: tests/functional/register-twice/main.go +tests/fuzz/api/gracha.a: tests/fuzz/api/gracha.go +tests/fuzz/api/main.a: tests/fuzz/api/main.go +tests/gracha.a: tests/gracha.go +tests/main.a: tests/main.go +src/main.bin: src/main.a +tests/benchmarks/register-login/main.bin: tests/benchmarks/register-login/main.a +tests/functional/register-twice/main.bin: tests/functional/register-twice/main.a +tests/fuzz/api/main.bin: tests/fuzz/api/main.a +tests/main.bin: tests/main.a +src/main.bin-check: src/main.bin +tests/benchmarks/register-login/main.bin-check: tests/benchmarks/register-login/main.bin +tests/functional/register-twice/main.bin-check: tests/functional/register-twice/main.bin +tests/fuzz/api/main.bin-check: tests/fuzz/api/main.bin +tests/main.bin-check: tests/main.bin +src/main.a: src/$(NAME).a +tests/benchmarks/register-login/main.a: tests/benchmarks/register-login/$(NAME).a +tests/functional/register-twice/main.a: tests/functional/register-twice/$(NAME).a +tests/fuzz/api/main.a: tests/fuzz/api/$(NAME).a +tests/main.a: tests/$(NAME).a @@ -2,3 +2,28 @@ set -eu export LANG=POSIX.UTF-8 + + +libs() { + find src tests -name '*.go' | grep -v '/main\.go$' | + grep -v '/version\.go$' +} + +mains() { + find src tests -name '*.go' | grep '/main\.go$' +} + +libs | varlist 'libs.go' +mains | varlist 'mains.go' + +find tests/functional/*/*.go -not -name main.go | varlist 'functional-tests/libs.go' +find tests/functional/*/main.go | varlist 'functional-tests/main.go' +find tests/fuzz/*/*.go -not -name main.go | varlist 'fuzz-targets/lib.go' +find tests/fuzz/*/main.go | varlist 'fuzz-targets/main.go' +find tests/benchmarks/*/*.go -not -name main.go | varlist 'benchmarks/lib.go' +find tests/benchmarks/*/main.go | varlist 'benchmarks/main.go' + +{ libs; mains; } | sort | sed 's/^\(.*\)\.go$/\1.a:\t\1.go/' +mains | sort | sed 's/^\(.*\)\.go$/\1.bin:\t\1.a/' +mains | sort | sed 's/^\(.*\)\.go$/\1.bin-check:\t\1.bin/' +mains | sort | sed 's|^\(.*\)/main\.go$|\1/main.a:\t\1/$(NAME).a|' diff --git a/src/gracha.go b/src/gracha.go index dcff98d..e300f66 100644 --- a/src/gracha.go +++ b/src/gracha.go @@ -1,48 +1,71 @@ package gracha import ( - "crypto/rand" + "context" "database/sql" "encoding/hex" "encoding/json" "errors" "fmt" + "runtime" + "sync" "time" "guuid" - "liteq" + "q" "scrypt" g "gobang" ) -type tablesT struct{ - users string - userChanges string - tokens string - roles string - roleChanges string - sessions string - attempts string - audit string +const ( + defaultPrefix = "gracha" + + rollbackErrorFmt = "rollback error: %w; while executing: %w" + NEW_USER = "new-user" + SEND_CONFIRMATION_REQUEST = "send-confirmation-request" + FORGOT_PASSWORD_REQUEST = "forgot-password-request" + + + day = 24 * time.Hour +) + +var ( + SessionDuration = 7 * day + RegisterTimeout = 15 * time.Second + + ErrPassMismatch = errors.New("gracha: password confirmation mismatch") + ErrTooShort = errors.New("gracha: password too short") + ErrTimeout = errors.New("gracha: timeout when creating user") + ErrRegistered = errors.New("gracha: user already registered") + ErrBadCombo = errors.New("gracha: bad username/passphrase combo") + ErrUnconfirmed = errors.New("gracha: email is not confirmed") + ErrRevokedSession = errors.New("gracha: this session was revoked") + ErrSessionExpired = errors.New("gracha: session expired") +) + + + +type queryT struct{ + write string + read string } type queriesT struct{ - userByEmail func(string) (userT, error) - userByToken func([]byte) (userT, error) - register func(string, []byte, []byte) (userT, error) - confirm func([]byte) (sessionT, error) - login func(string, string) (sessionT, error) - refresh func([]byte) (sessionT, error) - resetPassword func(int64, []byte, []byte) (sessionT, error) - resetAndConfirm func(int64, []byte, []byte) (sessionT, error) - changePassword func(int64, []byte) (sessionT, error) - sessionByUUID func([]byte) (sessionT, error) - logout func([]byte) error - logoutOthers func([]byte) error - logoutAll func([]byte) error - close func() error + byEmail func(string) (userT, error) + byToken func(guuid.UUID) (userT, error) + register func(string, []byte, []byte) (userT, error) + confirm func(guuid.UUID) (sessionT, error) + login func(string, string) (sessionT, error) + refresh func(guuid.UUID) (sessionT, error) + reset func(int64, []byte, guuid.UUID) (sessionT, error) + change func(int64, []byte) (sessionT, error) + byUUID func(guuid.UUID) (sessionT, error) + logout func(guuid.UUID) error + outOthers func(guuid.UUID) error + outAll func(guuid.UUID) error + close func() error } type confirmationT struct{ @@ -50,17 +73,14 @@ type confirmationT struct{ type userT struct{ id int64 - timestr string timestamp time.Time - uuid []byte + uuid guuid.UUID email string username *string salt []byte pwhash []byte - // confirmation confirmed_at *time.Time - metadatastr *string - metadata *map[string]interface{} + metadata map[string]interface{} } type sessionT struct{ @@ -70,240 +90,284 @@ type sessionT struct{ uuid guuid.UUID user_id int64 type_ string - revokedstr *string revoked_at *time.Time - metadatastr *string - metadata *map[string]interface{} -} - -type Auth struct{ - tables tablesT - queries queriesT - q liteq.Queue + metadata map[string]interface{} } type consumerT struct{ topic string - handlerFn func(Auth) func([]byte) error + handlerFn func(authT) func(q.Message) error } +type authT struct{ + queries queriesT + queue q.IQueue + hasher func(scrypt.HashInput) resultT[[]byte] + checker func(scrypt.CheckInput) resultT[bool] + close func() +} +type IAuth interface{ + Register(string, string, string) (userT, error) + ResendConfirmation(string) error + ConfirmEmail(guuid.UUID) (sessionT, error) + LoginEmail(string, string) (sessionT, error) + ForgotPassword(string) error + Refresh(sessionT) (sessionT, error) + ResetPassword(guuid.UUID, string, string) (sessionT, error) + ChangePassword(userT, string, string, string) (sessionT, error) + Logout(sessionT) error + LogoutOthers(sessionT) error + LogoutAll(sessionT) error + Close() error +} -const ( - NEW_USER = "new-user" - SEND_CONFIRMATION_REQUEST = "send-confirmation-request" - FORGOT_PASSWORD_REQUEST = "forgot-password-request" - defaultPrefix = "gracha" - day = 24 * time.Hour -) +func tryRollback(db *sql.DB, ctx context.Context, err error) error { + _, rollbackErr := db.ExecContext(ctx, "ROLLBACK;") + if rollbackErr != nil { + return fmt.Errorf( + rollbackErrorFmt, + rollbackErr, + err, + ) + } -var ( - SessionDuration = 7 * day - RegisterTimeout = 15 * time.Second + return err +} - ErrPasswordMismatch = errors.New("gracha: password and its confirmation don't match") - ErrPasswordTooShort = errors.New("gracha: bad username/passphrase combo") - ErrRegisterTimeout = errors.New("gracha: timeout when creating user") - ErrAlreadyRegistered = errors.New("gracha: user already registered") - ErrEmailPasswordCombo = errors.New("gracha: bad username/passphrase combo") - ErrUnconfirmedUser = errors.New("gracha: user email is not confirmed") - ErrRevokedSession = errors.New("gracha: this session was revoked") - ErrSessionExpired = errors.New("gracha: session expired") -) +func inTx(db *sql.DB, fn func(context.Context) error) error { + ctx := context.Background() + + _, err := db.ExecContext(ctx, "BEGIN IMMEDIATE;") + if err != nil { + return err + } + err = fn(ctx) + if err != nil { + return tryRollback(db, ctx, err) + } + _, err = db.ExecContext(ctx, "COMMIT;") + if err != nil { + return tryRollback(db, ctx, err) + } -func tablesFrom(prefix string) (tablesT, error) { - if !g.ValidSQLTablePrefix(prefix) { - return tablesT{}, g.ErrBadSQLTablePrefix - } - - users := prefix + "-users" - userChanges := prefix + "-user-changes" - tokens := prefix + "-tokens" - roles := prefix + "-roles" - roleChanges := prefix + "-role-changes" - sessions := prefix + "-sessions" - attempts := prefix + "-attempts" - audit := prefix + "-audit" - return tablesT{ - users: users, - userChanges: userChanges, - tokens: tokens, - roles: roles, - roleChanges: roleChanges, - sessions: sessions, - attempts: attempts, - audit: audit, - }, nil + return nil } -func createTables(db *sql.DB, tables tablesT) error { - const tmpl = ` - BEGIN TRANSACTION; - CREATE TABLE IF NOT EXISTS "%s" ( - id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, - timestamp TEXT NOT NULL DEFAULT (%s), - uuid BLOB NOT NULL UNIQUE, - email TEXT NOT NULL UNIQUE, - username TEXT UNIQUE, - salt BLOB NOT NULL UNIQUE, - pwhash BLOB NOT NULL, - confirmed_at TEXT, - confirmer_id INTEGER REFERENCES "%s"(id), - metadata TEXT +func createTablesSQL(prefix string) queryT { + const tmpl_write = ` + CREATE TABLE IF NOT EXISTS "%s_users" ( + id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, + timestamp TEXT NOT NULL DEFAULT (%s), + uuid BLOB NOT NULL UNIQUE, + email TEXT NOT NULL UNIQUE, + username TEXT UNIQUE, + salt BLOB NOT NULL UNIQUE, + pwhash BLOB NOT NULL, + confirmed_at TEXT, + confirmer_id INTEGER REFERENCES "%s"(id), + metadata TEXT ); - CREATE TABLE IF NOT EXISTS "%s" ( - id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, - timestamp TEXT NOT NULL DEFAULT (%s), - user_id INTEGER NOT NULL REFERENCES "%s"(id), - attribute TEXT NOT NULL, - value TEXT NOT NULL, - op BOOLEAN NOT NULL + CREATE TABLE IF NOT EXISTS "%s_user_changes" ( + id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, + timestamp TEXT NOT NULL DEFAULT (%s), + user_id INTEGER NOT NULL REFERENCES "%s_users"(id), + attribute TEXT NOT NULL, + value TEXT NOT NULL, + op BOOLEAN NOT NULL ); - CREATE TABLE IF NOT EXISTS "%s" ( - id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, - timestamp TEXT NOT NULL DEFAULT (%s), - uuid BLOB NOT NULL UNIQUE, - type TEXT NOT NULL + CREATE TABLE IF NOT EXISTS "%s_tokens" ( + id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, + timestamp TEXT NOT NULL DEFAULT (%s), + uuid BLOB NOT NULL UNIQUE, + type TEXT NOT NULL ); - CREATE TABLE IF NOT EXISTS "%s" ( - id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, - user_id INTEGER NOT NULL REFERENCES "%s"(id), - role TEXT NOT NULL, + CREATE TABLE IF NOT EXISTS "%s_roles" ( + id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, + user_id INTEGER NOT NULL REFERENCES "%s_users"(id), + role TEXT NOT NULL, + metadata TEXT, UNIQUE (user_id, role) ); - CREATE TABLE IF NOT EXISTS "%s" ( - id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, - timestamp TEXT NOT NULL DEFAULT (%s), - user_id INTEGER NOT NULL REFERENCES "%s"(id), - role TEXT NOT NULL, - op BOOLEAN NOT NULL + CREATE TABLE IF NOT EXISTS "%s_role_changes" ( + id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, + timestamp TEXT NOT NULL DEFAULT (%s), + user_id INTEGER NOT NULL REFERENCES "%s_roles"(id), + role TEXT NOT NULL, + op BOOLEAN NOT NULL ); - CREATE TABLE IF NOT EXISTS "%s" ( - id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, - timestamp TEXT NOT NULL DEFAULT (%s), - uuid BLOB NOT NULL UNIQUE, - user_id INTEGER NOT NULL REFERENCES "%s"(id), - type TEXT NOT NULL, - revoked_at TEXT, - revoker_id INTEGER REFERENCES "%s"(id), - metadata TEXT + CREATE TABLE IF NOT EXISTS "%s_sessions" ( + id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, + timestamp TEXT NOT NULL DEFAULT (%s), + uuid BLOB NOT NULL UNIQUE, + user_id INTEGER NOT NULL REFERENCES "%s_users"(id), + type TEXT NOT NULL, + revoked_at TEXT, + revoker_id INTEGER REFERENCES "%s_users"(id), + metadata TEXT ); - CREATE TABLE IF NOT EXISTS "%s" ( - id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, - timestamp TEXT NOT NULL DEFAULT (%s), - user_id INTEGER REFERENCES "%s"(id), - session_id INTEGER REFERENCES "%s"(id), - metadata TEXT + CREATE TABLE IF NOT EXISTS "%s_attempts" ( + id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, + timestamp TEXT NOT NULL DEFAULT (%s), + user_id INTEGER REFERENCES "%s_users"(id), + session_id INTEGER REFERENCES "%s_sessions"(id), + metadata TEXT ); - CREATE TABLE IF NOT EXISTS "%s" ( - id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, - timestamp TEXT NOT NULL DEFAULT (%s), - uuid BLOB NOT NULL UNIQUE, - attribute TEXT NOT NULL, - value TEXT NOT NULL, - op BOOLEAN NOT NULL, - metadata TEXT + CREATE TABLE IF NOT EXISTS "%s_audit" ( + id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, + timestamp TEXT NOT NULL DEFAULT (%s), + uuid BLOB NOT NULL UNIQUE, + attribute TEXT NOT NULL, + value TEXT NOT NULL, + op BOOLEAN NOT NULL, + metadata TEXT ); - COMMIT TRANSACTION; ` - sql := fmt.Sprintf( - tmpl, - tables.users, - g.SQLiteNow, - tables.tokens, - tables.userChanges, - g.SQLiteNow, - tables.users, - tables.tokens, - g.SQLiteNow, - tables.roles, - tables.users, - tables.roleChanges, - g.SQLiteNow, - tables.users, - tables.sessions, - g.SQLiteNow, - tables.users, - tables.sessions, - tables.attempts, - g.SQLiteNow, - tables.users, - tables.sessions, - tables.audit, - g.SQLiteNow, - ) - /// fmt.Println(sql) /// + return queryT{ + write: fmt.Sprintf( + tmpl_write, + prefix, + g.SQLiteNow, + prefix, + prefix, + g.SQLiteNow, + prefix, + prefix, + g.SQLiteNow, + prefix, + prefix, + prefix, + g.SQLiteNow, + prefix, + prefix, + g.SQLiteNow, + prefix, + prefix, + prefix, + g.SQLiteNow, + prefix, + prefix, + prefix, + g.SQLiteNow, + ), + } +} - _, err := db.Exec(sql) - return err +func createTables(db *sql.DB, prefix string) error { + q := createTablesSQL(prefix) + + return inTx(db, func(ctx context.Context) error { + _, err := db.ExecContext(ctx, q.write) + return err + }) } -func userByEmailQuery( - db *sql.DB, - tables tablesT, -) (func(string) (userT, error), func() error, error) { - const tmpl = ` +func byEmailSQL(prefix string) queryT { + const tmpl_read = ` SELECT id, timestamp, uuid, email, username, pwhash, metadata - FROM "%s" WHERE email = ?; + FROM "%s_users" WHERE email = ?; ` - sql := fmt.Sprintf(tmpl, tables.users) - /// fmt.Println(sql) /// + return queryT{ + read: fmt.Sprintf(tmpl_read, prefix), + } +} + +func byEmailStmt( + db *sql.DB, + prefix string, +) (func(string) (userT, error), func() error, error) { + q := byEmailSQL(prefix) - stmt, err := db.Prepare(sql) + readStmt, err := db.Prepare(q.read) if err != nil { return nil, nil, err } fn := func(email string) (userT, error) { - var user userT - err := stmt.QueryRow(email).Scan(&user.id) // FIXME: build user - return user, err + user := userT{ + email: email, + } + + var ( + timestr string + uuid_bytes []byte + ) + err := readStmt.QueryRow(email).Scan( + &user.id, + ×tr, + &uuid_bytes, + &user.username, + &user.pwhash, + &user.metadata, + ) + if err != nil { + return userT{}, err + } + user.uuid = guuid.UUID(uuid_bytes) + + user.timestamp, err = time.Parse(time.RFC3339Nano,timestr) + if err != nil { + return userT{}, err + } + + return user, nil } - return fn, stmt.Close, nil + return fn, readStmt.Close, nil } -func userByTokenQuery( - db *sql.DB, - tables tablesT, -) (func([]byte) (userT, error), func() error, error) { - const tmpl = ` +func byTokenSQL(prefix string) queryT{ + const tmpl_read = ` SELECT id, timestamp, uuid, email, username, pwhash, metadata FROM "%s" WHERE email = ?; ` - sql := fmt.Sprintf(tmpl, tables.users) - /// fmt.Println(sql) /// + return queryT{ + read: fmt.Sprintf(tmpl_read, prefix), + } +} - stmt, err := db.Prepare(sql) +func byTokenStmt( + db *sql.DB, + prefix string, +) (func(guuid.UUID) (userT, error), func() error, error) { + q := byTokenSQL(prefix) + + readStmt, err := db.Prepare(q.read) if err != nil { return nil, nil, err } - fn := func(token []byte) (userT, error) { + fn := func(token guuid.UUID) (userT, error) { var user userT - err := stmt.QueryRow(token).Scan(&user.id) // FIXME: build user + // FIXME: build user + err := readStmt.QueryRow(token).Scan(&user.id) return user, err } - return fn, stmt.Close, nil + return fn, readStmt.Close, nil } -func registerQuery( - db *sql.DB, - tables tablesT, -) (func(string, []byte, []byte) (userT, error), func() error, error) { - const tmpl = ` +func registerSQL(prefix string) queryT { + const tmpl_write = ` INSERT INTO "%s" (uuid, email, username, salt, pwhash, metadata) VALUES (?, ?, ?, ?, ?, ?); ` - sql := fmt.Sprintf(tmpl, tables.users) - /// fmt.Println(sql) /// + return queryT{ + write: fmt.Sprintf(tmpl_write, prefix), + } +} - stmt, err := db.Prepare(sql) +func registerStmt( + db *sql.DB, + prefix string, +) (func(string, []byte, []byte) (userT, error), func() error, error) { + q := registerSQL(prefix) + + writeStmt, err := db.Prepare(q.write) if err != nil { return nil, nil, err } @@ -320,8 +384,8 @@ func registerQuery( var user userT // err := stmt.QueryRow( - ret, err := stmt.Exec( - guuid.NewBytes(), + ret, err := writeStmt.Exec( + guuid.New(), email, "credentials.username", salt, @@ -337,429 +401,630 @@ func registerQuery( return user, err } - return fn, stmt.Close, nil + return fn, writeStmt.Close, nil } -func loginQuery( - db *sql.DB, - tables tablesT, -) (func(string, string) (sessionT, error), func() error, error) { - const tmpl = ` - -- INSERT INTO "%s" (t3, t4) VALUES (?, ?); +func confirmSQL(prefix string) queryT { + const tmpl_write = ` + -- INSERT SOMETHING %s ` - sql := fmt.Sprintf(tmpl, tables.users) - /// fmt.Println(sql) /// + return queryT{ + write: fmt.Sprintf(tmpl_write, prefix), + } +} + +func confirmStmt( + db *sql.DB, + prefix string, +) (func(guuid.UUID) (sessionT, error), func() error, error) { + q := confirmSQL(prefix) - stmt, err := db.Prepare(sql) + writeStmt, err := db.Prepare(q.write) if err != nil { return nil, nil, err } - fn := func(email string, pwhash string) (sessionT, error) { + fn := func(token guuid.UUID) (sessionT, error) { var session sessionT - err := stmt.QueryRow(email, pwhash).Scan(session) - // FIXME: finish + err := writeStmt.QueryRow(token).Scan(&session) return session, err } - return fn, stmt.Close, nil + return fn, writeStmt.Close, nil } -func refreshQuery( - db *sql.DB, - tables tablesT, -) (func([]byte) (sessionT, error), func() error, error) { - const tmpl = ` - -- INSERT SOMETHING %s +func loginSQL(prefix string) queryT { + const tmpl_write = ` + -- INSERT INTO "%s" (t3, t4) VALUES (?, ?); ` - sql := fmt.Sprintf(tmpl, tables.users) - /// fmt.Println(sql) /// + return queryT{ + write: fmt.Sprintf(tmpl_write, prefix), + } +} + +func loginStmt( + db *sql.DB, + prefix string, +) (func(string, string) (sessionT, error), func() error, error) { + q := loginSQL(prefix) - stmt, err := db.Prepare(sql) + writeStmt, err := db.Prepare(q.write) if err != nil { return nil, nil, err } - fn := func(uuid []byte) (sessionT, error) { + fn := func(email string, pwhash string) (sessionT, error) { var session sessionT - err := stmt.QueryRow(uuid).Scan(&session) + err := writeStmt.QueryRow(email, pwhash).Scan(session) + // FIXME: finish return session, err } - return fn, stmt.Close, nil + return fn, writeStmt.Close, nil } -func resetPasswordQuery( - db *sql.DB, - tables tablesT, -) (func(int64, []byte, []byte) (sessionT, error), func() error, error) { - const tmpl = ` +func refreshSQL(prefix string) queryT { + const tmpl_write = ` -- INSERT SOMETHING %s ` - sql := fmt.Sprintf(tmpl, tables.users) - /// fmt.Println(sql) /// + return queryT{ + write: fmt.Sprintf(tmpl_write, prefix), + } +} - stmt, err := db.Prepare(sql) +func refreshStmt( + db *sql.DB, + prefix string, +) (func(guuid.UUID) (sessionT, error), func() error, error) { + q := refreshSQL(prefix) + + writeStmt, err := db.Prepare(q.write) if err != nil { return nil, nil, err } - fn := func(id int64, pwhash []byte, token []byte) (sessionT, error) { + fn := func(uuid guuid.UUID) (sessionT, error) { var session sessionT - err := stmt.QueryRow(id, pwhash, token).Scan(&session) + err := writeStmt.QueryRow(uuid).Scan(&session) return session, err } - return fn, stmt.Close, nil + return fn, writeStmt.Close, nil } -func resetAndConfirmQuery( - db *sql.DB, - tables tablesT, -) (func(int64, []byte, []byte) (sessionT, error), func() error, error) { - const tmpl = ` +func resetSQL(prefix string) queryT { + const tmpl_write = ` -- INSERT SOMETHING %s ` - sql := fmt.Sprintf(tmpl, tables.users) - /// fmt.Println(sql) /// + return queryT{ + write: fmt.Sprintf(tmpl_write, prefix), + } +} - stmt, err := db.Prepare(sql) +func resetStmt( + db *sql.DB, + prefix string, +) (func(int64, []byte, guuid.UUID) (sessionT, error), func() error, error) { + q := resetSQL(prefix) + + writeStmt, err := db.Prepare(q.write) if err != nil { return nil, nil, err } - fn := func(id int64, pwhash []byte, token []byte) (sessionT, error) { + fn := func(id int64, pwhash []byte, token guuid.UUID) (sessionT, error) { var session sessionT - err := stmt.QueryRow(id, pwhash, token).Scan(&session) + err := writeStmt.QueryRow(id, pwhash, token).Scan(&session) return session, err } - return fn, stmt.Close, nil + return fn, writeStmt.Close, nil } -func changePasswordQuery( - db *sql.DB, - tables tablesT, -) (func(int64, []byte) (sessionT, error), func() error, error) { - const tmpl = ` +func changeSQL(prefix string) queryT { + const tmpl_write = ` -- INSERT SOMETHING %s ` - sql := fmt.Sprintf(tmpl, tables.users) - /// fmt.Println(sql) /// + return queryT{ + write: fmt.Sprintf(tmpl_write, prefix), + } +} - stmt, err := db.Prepare(sql) +func changeStmt( + db *sql.DB, + prefix string, +) (func(int64, []byte) (sessionT, error), func() error, error) { + q := changeSQL(prefix) + + writeStmt, err := db.Prepare(q.write) if err != nil { return nil, nil, err } fn := func(id int64, pwhash []byte) (sessionT, error) { var session sessionT - err := stmt.QueryRow(id, pwhash).Scan(&session) + err := writeStmt.QueryRow(id, pwhash).Scan(&session) return session, err } - return fn, stmt.Close, nil + return fn, writeStmt.Close, nil } -func sessionByUUIDQuery( - db *sql.DB, - tables tablesT, -) (func([]byte) (sessionT, error), func() error, error) { - const tmpl = ` +func byUUIDSQL(prefix string) queryT { + const tmpl_read = ` -- INSERT SOMETHING %s ` - sql := fmt.Sprintf(tmpl, tables.users) - /// fmt.Println(sql) /// + return queryT{ + read: fmt.Sprintf(tmpl_read, prefix), + } +} + +func byUUIDStmt( + db *sql.DB, + prefix string, +) (func(guuid.UUID) (sessionT, error), func() error, error) { + q := byUUIDSQL(prefix) - stmt, err := db.Prepare(sql) + readStmt, err := db.Prepare(q.read) if err != nil { return nil, nil, err } - fn := func(uuid []byte) (sessionT, error) { + fn := func(uuid guuid.UUID) (sessionT, error) { var session sessionT - err := stmt.QueryRow(uuid).Scan(&session) + err := readStmt.QueryRow(uuid).Scan(&session) return session, err } - return fn, stmt.Close, nil + return fn, readStmt.Close, nil } -func logoutQuery( - db *sql.DB, - tables tablesT, -) (func([]byte) error, func() error, error) { - const tmpl = ` +func logoutSQL(prefix string) queryT { + const tmpl_write = ` -- INSERT SOMETHING %s ` - sql := fmt.Sprintf(tmpl, tables.users) - /// fmt.Println(sql) /// + return queryT{ + write: fmt.Sprintf(tmpl_write, prefix), + } +} + +func logoutStmt( + db *sql.DB, + prefix string, +) (func(guuid.UUID) error, func() error, error) { + q := logoutSQL(prefix) - stmt, err := db.Prepare(sql) + writeStmt, err := db.Prepare(q.write) if err != nil { return nil, nil, err } - fn := func(uuid []byte) error { - _, err := stmt.Exec(uuid) + fn := func(uuid guuid.UUID) error { + _, err := writeStmt.Exec(uuid) return err } - return fn, stmt.Close, nil + return fn, writeStmt.Close, nil } -func logoutOthersQuery( - db *sql.DB, - tables tablesT, -) (func([]byte) error, func() error, error) { - const tmpl = ` +func outOthersSQL(prefix string) queryT { + const tmpl_write = ` -- INSERT SOMETHING %s ` - sql := fmt.Sprintf(tmpl, tables.users) - /// fmt.Println(sql) /// + return queryT{ + write: fmt.Sprintf(tmpl_write, prefix), + } +} - stmt, err := db.Prepare(sql) +func outOthersStmt( + db *sql.DB, + prefix string, +) (func(guuid.UUID) error, func() error, error) { + q := outOthersSQL(prefix) + + writeStmt, err := db.Prepare(q.write) if err != nil { return nil, nil, err } - fn := func(uuid []byte) error { - _, err := stmt.Exec(uuid) + fn := func(uuid guuid.UUID) error { + _, err := writeStmt.Exec(uuid) return err } - return fn, stmt.Close, nil + return fn, writeStmt.Close, nil } -func logoutAllQuery( - db *sql.DB, - tables tablesT, -) (func([]byte) error, func() error, error) { - const tmpl = ` +func outAllSQL(prefix string) queryT { + const tmpl_write = ` -- INSERT SOMETHING %s ` - sql := fmt.Sprintf(tmpl, tables.users) - /// fmt.Println(sql) /// + return queryT{ + write: fmt.Sprintf(tmpl_write, prefix), + } +} - stmt, err := db.Prepare(sql) +func outAllStmt( + db *sql.DB, + prefix string, +) (func(guuid.UUID) error, func() error, error) { + q := outAllSQL(prefix) + + writeStmt, err := db.Prepare(q.write) if err != nil { return nil, nil, err } - fn := func(uuid []byte) error { - _, err := stmt.Exec(uuid) + fn := func(uuid guuid.UUID) error { + _, err := writeStmt.Exec(uuid) return err } - return fn, stmt.Close, nil -} - -func initDB(db *sql.DB, tables tablesT) (queriesT, error) { - createTablesErr := createTables(db, tables) - userByEmail, userByEmailClose, userByEmailErr := userByEmailQuery(db, tables) - userByToken, userByTokenClose, userByTokenErr := userByTokenQuery(db, tables) - register, registerClose, registerErr := registerQuery(db, tables) - login, loginClose, loginErr := loginQuery(db, tables) - refresh, refreshClose, refreshErr := refreshQuery(db, tables) - resetPassword, resetPasswordClose, resetPasswordErr := resetPasswordQuery(db, tables) - resetAndConfirm, resetAndConfirmClose, resetAndConfirmErr := resetAndConfirmQuery(db, tables) - changePassword, changePasswordClose, changePasswordErr := changePasswordQuery(db, tables) - sessionByUUID, sessionByUUIDClose, sessionByUUIDErr := sessionByUUIDQuery(db, tables) - logout, logoutClose, logoutErr := logoutQuery(db, tables) - logoutOthers, logoutOthersClose, logoutOthersErr := logoutOthersQuery(db, tables) - logoutAll, logoutAllClose, logoutAllErr := logoutAllQuery(db, tables) - - errs := []error { - createTablesErr, - userByEmailErr, - userByTokenErr, - registerErr, - loginErr, - refreshErr, - resetPasswordErr, - resetAndConfirmErr, - changePasswordErr, - sessionByUUIDErr, - logoutErr, - logoutOthersErr, - logoutAllErr, - } - err := g.SomeError(errs) + return fn, writeStmt.Close, nil +} + +func initDB( + db *sql.DB, + prefix string, +) (queriesT, error) { + createTablesErr := createTables(db, prefix) + byEmail, byEmailClose, byEmailErr := byEmailStmt(db, prefix) + byToken, byTokenClose, byTokenErr := byTokenStmt(db, prefix) + register, registerClose, registerErr := registerStmt(db, prefix) + confirm, confirmClose, confirmErr := confirmStmt(db, prefix) + login, loginClose, loginErr := loginStmt(db, prefix) + refresh, refreshClose, refreshErr := refreshStmt(db, prefix) + reset, resetClose, resetErr := resetStmt(db, prefix) + change, changeClose, changeErr := changeStmt(db, prefix) + byUUID, byUUIDClose, byUUIDErr := byUUIDStmt(db, prefix) + logout, logoutClose, logoutErr := logoutStmt(db, prefix) + outOthers, outOthersClose, outOthersErr := outOthersStmt(db, prefix) + outAll, outAllClose, outAllErr := outAllStmt(db, prefix) + + err := g.SomeError( + createTablesErr, + byEmailErr, + byTokenErr, + registerErr, + confirmErr, + loginErr, + refreshErr, + resetErr, + changeErr, + byUUIDErr, + logoutErr, + outOthersErr, + outAllErr, + ) if err != nil { return queriesT{}, err } close := func() error { - fns := [](func() error){ - userByEmailClose, - userByTokenClose, - registerClose, - loginClose, - refreshClose, - resetPasswordClose, - resetAndConfirmClose, - changePasswordClose, - sessionByUUIDClose, - logoutClose, - logoutOthersClose, - logoutAllClose, - } - return g.SomeFnError(fns) + return g.SomeFnError( + byEmailClose, + byTokenClose, + registerClose, + confirmClose, + loginClose, + refreshClose, + resetClose, + changeClose, + byUUIDClose, + logoutClose, + outOthersClose, + outAllClose, + ) } + // FIXME: lock return queriesT{ - userByEmail: userByEmail, - userByToken: userByToken, - register: register, - login: login, - close: close, - refresh: refresh, - resetPassword: resetPassword, - resetAndConfirm: resetAndConfirm, - changePassword: changePassword, - sessionByUUID: sessionByUUID, - logout: logout, - logoutOthers: logoutOthers, - logoutAll: logoutAll, + byEmail: byEmail, + byToken: byToken, + register: register, + confirm: confirm, + login: login, + refresh: refresh, + reset: reset, + change: change, + byUUID: byUUID, + logout: logout, + outOthers: outOthers, + outAll: outAll, + close: close, }, nil } -func newUserPayload(email string, salt []byte, pwhash []byte) ([]byte, error) { - data := make(map[string]interface{}) - data["email"] = email - data["salt"] = hex.EncodeToString(salt) - data["pwhash"] = hex.EncodeToString(pwhash) - return json.Marshal(data) +func newUserHandler(auth authT) func(q.Message) error { + return func(message q.Message) error { + return nil + } } -func publishRegister( - q liteq.Queue, - email string, - salt []byte, - pwhash []byte, - flowId []byte, -) error { - payload, err := newUserPayload(email, salt, pwhash) - if err != nil { - return err +func sendConfirmationRequestHandler(auth authT) func(q.Message) error { + return func(message q.Message) error { + return nil } +} - return q.Publish(NEW_USER, payload, flowId) +func forgotPasswordRequestHandler(auth authT) func(q.Message) error { + return func(message q.Message) error { + return nil + } } -func register( - auth Auth, - email string, - salt []byte, - pwhash []byte, -) (userT, error) { - flowId := guuid.NewBytes() - waiter := auth.q.WaitFor(NEW_USER, flowId) - defer auth.q.Unwait(waiter) +var consumers = []consumerT{ + consumerT{ + topic: NEW_USER, + handlerFn: newUserHandler, + }, + consumerT{ + topic: SEND_CONFIRMATION_REQUEST, + handlerFn: sendConfirmationRequestHandler, + }, + consumerT{ + topic: FORGOT_PASSWORD_REQUEST, + handlerFn: forgotPasswordRequestHandler, + }, +} +func registerConsumers(auth authT, consumers []consumerT) { + for _, consumer := range consumers { + auth.queue.Subscribe( + consumer.topic, + defaultPrefix + "-" + consumer.topic, + consumer.handlerFn(auth), + ) + } +} + +type resultT[T any] struct{ + value T + err error +} + +type taggedT[T any] struct{ + id int + value T +} + +func startRunner[A any, B any]( + in <-chan taggedT[A], + out chan<- taggedT[B], + fn func(A) B, + done func(), +) { + for input := range in { + out <- taggedT[B]{ + id: input.id, + value: fn(input.value), + } + } + done() +} + +func makePoolRunner[A any, B any](count int, fn func(A) B) (func(A) B, func()) { + var wg sync.WaitGroup + wg.Add(count) + + in := make(chan taggedT[A]) + out := make(chan taggedT[B]) + + for _ = range count { + go startRunner(in, out, fn, wg.Done) + } + + var mutex sync.Mutex + m := map[int]chan B{} + id := 0 + go func() { + for output := range out { + mutex.Lock() + defer mutex.Unlock() + m[output.id] <- output.value + close(m[output.id]) + delete(m, output.id) + } + }() + + poolRunFn := func(input A) B { + c := make(chan B) + { + mutex.Lock() + defer mutex.Unlock() + m[id] = c + id++ + } + + in <- taggedT[A]{ + id: id, + value: input, + } + return <- c + } + + close := func() { + close(in) + wg.Wait() + close(out) + } + + return poolRunFn, close +} + +func asResult[A any, B any](fn func(A) (B, error)) func(A) resultT[B] { + return func(input A) resultT[B] { + output, err := fn(input) + return resultT[B]{ + value: output, + err: err, + } + } +} - err := publishRegister(auth.q, email, salt, pwhash, flowId) +func NewWithPrefix(db *sql.DB, queue q.IQueue, prefix string) (IAuth, error) { + queries, err := initDB(db, prefix) if err != nil { - return userT{}, err + return authT{}, err } - select { - case <-time.After(RegisterTimeout): - return userT{}, ErrRegisterTimeout - case <-waiter: - return auth.queries.userByEmail(email) + numCPU := (runtime.NumCPU() / 2) + 1 + hasher, closeHasher := makePoolRunner(numCPU, asResult(scrypt.Hash)) + checker, closeChecker := makePoolRunner(numCPU, asResult(scrypt.Check)) + + close := func() { + closeHasher() + closeChecker() } + + return authT{ + queries: queries, + queue: queue, + hasher: hasher, + checker: checker, + close: close, + }, nil +} + +func New(db *sql.DB, queue q.IQueue) (IAuth, error) { + return NewWithPrefix(db, queue, defaultPrefix) } -func (auth Auth) Register( +func newUserPayload(email string, salt []byte, pwhash []byte) ([]byte, error) { + data := make(map[string]interface{}) + data["email"] = email + data["salt"] = hex.EncodeToString(salt) + data["pwhash"] = hex.EncodeToString(pwhash) + return json.Marshal(data) +} + +func (auth authT) Register( email string, password string, confirmPassword string, ) (userT, error) { if password != confirmPassword { - return userT{}, ErrPasswordMismatch + return userT{}, ErrPassMismatch } if len(password) < scrypt.MinimumPasswordLength { - return userT{}, ErrPasswordTooShort + return userT{}, ErrTooShort } // special check for sql.ErrNoRows to combat enumeration attacks. - _, lookupErr := auth.queries.userByEmail(email) + // FIXME: how so? + _, lookupErr := auth.queries.byEmail(email) if lookupErr != nil && lookupErr != sql.ErrNoRows { return userT{}, lookupErr } - salt, err := scrypt.SaltFrom(rand.Reader) + salt, err := scrypt.Salt() if err != nil { return userT{}, err } - pwhash, err := scrypt.HashFrom([]byte(password), salt) - if err != nil { - return userT{}, err + input := scrypt.HashInput{ + Password: []byte(password), + Salt: salt, } + result := auth.hasher(input) /* We also try to register anyway, to prevent disk IO timing attacks. + / FIXME: how so? */ - user, err := register(auth, email, salt, pwhash) + + flowID := guuid.New() + waiter := auth.queue.WaitFor(NEW_USER, flowID, "register") + defer waiter.Close() + + payload, err := newUserPayload(email, salt, result.value) + if err != nil { + return userT{}, err + } + + unsent := q.UnsentMessage{ + Topic: NEW_USER, + FlowID: flowID, + Payload: payload, + } + _, err = auth.queue.Publish(unsent) if err != nil { - if lookupErr != nil { - return userT{}, ErrAlreadyRegistered - } return userT{}, err } + user := userT{} + select { + case <-time.After(RegisterTimeout): + err = ErrTimeout + case <-waiter.Channel: + user, err = auth.queries.byEmail(email) + } + return user, nil } -func sendConfirmationPayload(email string) ([]byte, error) { +func sendConfirmationMessage( + email string, + flowID guuid.UUID, +) (q.UnsentMessage, error) { data := make(map[string]interface{}) data["email"] = email - return json.Marshal(data) + payload, err := json.Marshal(data) + if err != nil { + return q.UnsentMessage{}, err + } + + return q.UnsentMessage{ + Topic: SEND_CONFIRMATION_REQUEST, + FlowID: flowID, + Payload: payload, + }, nil } -func publishSendConfirmation(q liteq.Queue, email string, flowId []byte) error { - payload, err := sendConfirmationPayload(email) +func (auth authT) ResendConfirmation(email string) error { + unsent, err := sendConfirmationMessage(email, guuid.New()) if err != nil { return err } - return q.Publish(SEND_CONFIRMATION_REQUEST, payload, flowId) -} -func (auth Auth) ResendConfirmation(email string) error { - return publishSendConfirmation(auth.q, email, guuid.NewBytes()) + _, err = auth.queue.Publish(unsent) + return err } -func (auth Auth) ConfirmEmail(token guuid.UUID) (sessionT, error) { - return auth.queries.confirm(token[:]) +func (auth authT) ConfirmEmail(token guuid.UUID) (sessionT, error) { + return auth.queries.confirm(token) } -func (auth Auth) LoginEmail( +func (auth authT) LoginEmail( email string, password string, ) (sessionT, error) { if len(password) < scrypt.MinimumPasswordLength { - return sessionT{}, ErrPasswordTooShort + return sessionT{}, ErrTooShort } // special check for sql.ErrNoRows to combat enumeration attacks. - user, err := auth.queries.userByEmail(email) + user, err := auth.queries.byEmail(email) if err != nil && err != sql.ErrNoRows { return sessionT{}, err } - ok, err := scrypt.CheckFrom([]byte(password), user.salt, user.pwhash) + input := scrypt.CheckInput{ + Password: []byte(password), + Salt: user.salt, + Hash: user.pwhash, + } + ok, err := scrypt.Check(input) if err != nil { return sessionT{}, err } if !ok { - return sessionT{}, ErrEmailPasswordCombo + return sessionT{}, ErrBadCombo } if user.confirmed_at == nil { - return sessionT{}, ErrUnconfirmedUser + return sessionT{}, ErrUnconfirmed } dbSession, err := auth.queries.login(email, password) @@ -770,29 +1035,38 @@ func (auth Auth) LoginEmail( return dbSession, nil } -func forgotPasswordPayload(email string) ([]byte, error) { +func forgotPasswordMessage( + email string, + flowID guuid.UUID, +) (q.UnsentMessage, error) { data := make(map[string]interface{}) data["email"] = email - return json.Marshal(data) -} - -func publishForgotPassword(q liteq.Queue, email string, flowId []byte) error { - payload, err := forgotPasswordPayload(email) + payload, err := json.Marshal(data) if err != nil { - return err + return q.UnsentMessage{}, err } - return q.Publish(FORGOT_PASSWORD_REQUEST, payload, flowId) + return q.UnsentMessage{ + Topic: FORGOT_PASSWORD_REQUEST, + FlowID: flowID, + Payload: payload, + }, nil } -func (auth Auth) ForgotPassword(email string) error { +func (auth authT) ForgotPassword(email string) error { // special check for sql.ErrNoRows to combat enumeration attacks. - user, err := auth.queries.userByEmail(email) + user, err := auth.queries.byEmail(email) if err != nil && err != sql.ErrNoRows { return err } - return publishForgotPassword(auth.q, user.email, guuid.NewBytes()) + unsent, err := forgotPasswordMessage(user.email, guuid.New()) + if err != nil { + return err + } + + _, err = auth.queue.Publish(unsent) + return err } func checkSession(session sessionT, now time.Time) error { @@ -807,8 +1081,11 @@ func checkSession(session sessionT, now time.Time) error { return nil } -func validateSession(lookupFn func([]byte) (sessionT, error), session sessionT) error { - dbSession, err := lookupFn(session.uuid[:]) +func validateSession( + lookupFn func(guuid.UUID) (sessionT, error), + session sessionT, +) error { + dbSession, err := lookupFn(session.uuid) if err != nil { return err } @@ -816,184 +1093,139 @@ func validateSession(lookupFn func([]byte) (sessionT, error), session sessionT) return checkSession(dbSession, time.Now()) } -func (auth Auth) Refresh(session sessionT) (sessionT, error) { - err := validateSession(auth.queries.sessionByUUID, session) +func (auth authT) Refresh(session sessionT) (sessionT, error) { + err := validateSession(auth.queries.byUUID, session) if err != nil { return sessionT{}, err } - return auth.queries.refresh(session.uuid[:]) + return auth.queries.refresh(session.uuid) } -func (auth Auth) ResetPassword( - token []byte, +func (auth authT) ResetPassword( + token guuid.UUID, password string, confirmPassword string, ) (sessionT, error) { if password != confirmPassword { - return sessionT{}, ErrPasswordMismatch + return sessionT{}, ErrPassMismatch } if len(password) < scrypt.MinimumPasswordLength { - return sessionT{}, ErrPasswordTooShort + return sessionT{}, ErrTooShort } - user, err := auth.queries.userByToken(token) + user, err := auth.queries.byToken(token) if err != nil { return sessionT{}, err } - pwhash, err := scrypt.HashFrom([]byte(password), user.salt) + input := scrypt.HashInput{ + Password: []byte(password), + Salt: user.salt, + } + pwhash, err := scrypt.Hash(input) if err != nil { return sessionT{}, err } if user.confirmed_at != nil { - return auth.queries.resetPassword(user.id, pwhash, token) + return auth.queries.reset(user.id, pwhash, token) } else { - return auth.queries.resetAndConfirm(user.id, pwhash, token) + // return auth.queries.confirm(user.id, pwhash, token) + return auth.queries.confirm(token) } } -func (auth Auth) ChangePassword( +func (auth authT) ChangePassword( user userT, currentPassword string, newPassword string, confirmNewPassword string, ) (sessionT, error) { if newPassword != confirmNewPassword { - return sessionT{}, ErrPasswordMismatch + return sessionT{}, ErrPassMismatch } if len(newPassword) < scrypt.MinimumPasswordLength { - return sessionT{}, ErrPasswordTooShort + return sessionT{}, ErrTooShort } - pwhash, err := scrypt.HashFrom([]byte(newPassword), user.salt) + // FIXME + input := scrypt.HashInput{ + Password: []byte(newPassword), + Salt: user.salt, + } + pwhash, err := scrypt.Hash(input) if err != nil { return sessionT{}, err } if user.confirmed_at == nil { - return sessionT{}, ErrUnconfirmedUser + return sessionT{}, ErrUnconfirmed } - return auth.queries.changePassword(user.id, pwhash) + return auth.queries.change(user.id, pwhash) } func runLogout( - lookupFn func([]byte) (sessionT, error), + lookupFn func(guuid.UUID) (sessionT, error), session sessionT, - queryFn func([]byte) error, + queryFn func(guuid.UUID) error, ) error { err := validateSession(lookupFn, session) if err != nil { return err } - return queryFn(session.uuid[:]) + return queryFn(session.uuid) } -func (auth Auth) Logout(session sessionT) error { - return runLogout(auth.queries.sessionByUUID, session, auth.queries.logout) +func (auth authT) Logout(session sessionT) error { + return runLogout( + auth.queries.byUUID, + session, + auth.queries.logout, + ) } -func (auth Auth) LogoutOthers(session sessionT) error { - return runLogout(auth.queries.sessionByUUID, session, auth.queries.logoutOthers) +func (auth authT) LogoutOthers(session sessionT) error { + return runLogout( + auth.queries.byUUID, + session, + auth.queries.outOthers, + ) } -func (auth Auth) LogoutAll(session sessionT) error { - return runLogout(auth.queries.sessionByUUID, session, auth.queries.logoutAll) +func (auth authT) LogoutAll(session sessionT) error { + return runLogout( + auth.queries.byUUID, + session, + auth.queries.outAll, + ) } -func (auth Auth) Close() error { +func (auth authT) Close() error { return auth.queries.close() } -func newUserHandler(auth Auth) func([]byte) error { - return func(payload []byte) error { - return nil - } -} - -func sendConfirmationRequestHandler(auth Auth) func([]byte) error { - return func(payload []byte) error { - return nil - } -} - -func forgotPasswordRequestHandler(auth Auth) func([]byte) error { - return func(payload []byte) error { - return nil - } -} - -var consumers = []consumerT{ - consumerT{ - topic: NEW_USER, - handlerFn: newUserHandler, - }, - consumerT{ - topic: SEND_CONFIRMATION_REQUEST, - handlerFn: sendConfirmationRequestHandler, - }, - consumerT{ - topic: FORGOT_PASSWORD_REQUEST, - handlerFn: forgotPasswordRequestHandler, - }, -} -func registerConsumers(auth Auth, consumers []consumerT) { - for _, consumer := range consumers { - auth.q.Subscribe( - consumer.topic, - defaultPrefix + "-" + consumer.topic, - consumer.handlerFn(auth), - ) - } -} - -func NewWithPrefix(db *sql.DB, q liteq.Queue, prefix string) (Auth, error) { - tables, err := tablesFrom(prefix) - if err != nil { - return Auth{}, err - } - - queries, err := initDB(db, tables) - if err != nil { - return Auth{}, err - } - - return Auth{ - tables: tables, - queries: queries, - q: q, - }, nil -} - -func New(db *sql.DB, q liteq.Queue) (Auth, error) { - return NewWithPrefix(db, q, defaultPrefix) -} - func Main() { g.Init() - q := new(liteq.Queue) - sql.Register("sqlite-liteq", liteq.MakeDriver(q)) - - db, err := sql.Open("sqlite-liteq", "file:gracha.db?mode=memory&cache=shared") + db, err := sql.Open("acude", "file:gracha.db?mode=memory&cache=shared") if err != nil { panic(err) } - // defer db.Close() + defer db.Close() - *q, err = liteq.New(db) + queue, err := q.New(db) if err != nil { panic(err) } - // defer q.Close() + defer queue.Close() - auth, err := New(db, *q) + auth, err := New(db, queue) if err != nil { fmt.Println(err) panic(err) @@ -1006,6 +1238,6 @@ func Main() { } return - fmt.Printf("q: %#v\n", q) + fmt.Printf("q: %#v\n", queue) fmt.Printf("auth: %#v\n", auth) } diff --git a/tests/benchmarks/register-login/gracha.go b/tests/benchmarks/register-login/gracha.go new file mode 100644 index 0000000..f363b6b --- /dev/null +++ b/tests/benchmarks/register-login/gracha.go @@ -0,0 +1,9 @@ +package gracha + +import ( +) + + + +func MainTest() { +} diff --git a/tests/benchmarks/register-login/main.go b/tests/benchmarks/register-login/main.go new file mode 120000 index 0000000..f67563d --- /dev/null +++ b/tests/benchmarks/register-login/main.go @@ -0,0 +1 @@ +../../main.go
\ No newline at end of file diff --git a/tests/functional/register-twice/gracha.go b/tests/functional/register-twice/gracha.go new file mode 100644 index 0000000..f363b6b --- /dev/null +++ b/tests/functional/register-twice/gracha.go @@ -0,0 +1,9 @@ +package gracha + +import ( +) + + + +func MainTest() { +} diff --git a/tests/functional/register-twice/main.go b/tests/functional/register-twice/main.go new file mode 120000 index 0000000..f67563d --- /dev/null +++ b/tests/functional/register-twice/main.go @@ -0,0 +1 @@ +../../main.go
\ No newline at end of file diff --git a/tests/fuzz/api/gracha.go b/tests/fuzz/api/gracha.go new file mode 100644 index 0000000..6f2981a --- /dev/null +++ b/tests/fuzz/api/gracha.go @@ -0,0 +1,35 @@ +package gracha + +import ( + "os" + "testing" + "testing/internal/testdeps" +) + + + +func api(f *testing.F) { + f.Fuzz(func(t *testing.T, n int) { + // FIXME + if n > 1 { + if n < 2 { + t.Errorf("Failed n: %v\n", n) + } + } + }) +} + + + +func MainTest() { + fuzzTargets := []testing.InternalFuzzTarget{ + { "api", api }, + } + + deps := testdeps.TestDeps{} + tests := []testing.InternalTest {} + benchmarks := []testing.InternalBenchmark{} + examples := []testing.InternalExample {} + m := testing.MainStart(deps, tests, benchmarks, fuzzTargets, examples) + os.Exit(m.Run()) +} diff --git a/tests/fuzz/api/main.go b/tests/fuzz/api/main.go new file mode 120000 index 0000000..f67563d --- /dev/null +++ b/tests/fuzz/api/main.go @@ -0,0 +1 @@ +../../main.go
\ No newline at end of file diff --git a/tests/gracha.go b/tests/gracha.go index 18b54d6..0c11fac 100644 --- a/tests/gracha.go +++ b/tests/gracha.go @@ -1,15 +1,15 @@ package gracha import ( - "database/sql" + // "database/sql" - "liteq" + // "q" g "gobang" ) type testAuth struct{ - auth Auth + auth authT // registerEmail func(credentials) close func() error } @@ -18,55 +18,28 @@ func test_defaultPrefix() { g.TestStart("defaultPrefix") g.Testing("the defaultPrefix is valid", func() { - g.TAssertEqual(g.ValidSQLTablePrefix(defaultPrefix), true) - }) -} - -func test_tablesFrom() { - g.TestStart("tablesFrom()") - - g.Testing("prefix needs to be valid", func() { - _, err := tablesFrom("invalid-prefix") - g.TAssertEqual(err, g.ErrBadSQLTablePrefix) - }) - - g.Testing("the struct adds suffixes", func() { - t, err := tablesFrom(defaultPrefix) - g.TAssertEqual(err, nil) - g.TAssertEqual(t, tablesT{ - users: "gracha-users", - userChanges: "gracha-user-changes", - tokens: "gracha-tokens", - roles: "gracha-roles", - roleChanges: "gracha-role-changes", - sessions: "gracha-sessions", - attempts: "gracha-attempts", - audit: "gracha-audit", - }) + g.TErrorIf(g.ValidateSQLTablePrefix(defaultPrefix)) }) } +/* func mkauth() testAuth { - q := new(liteq.Queue) - sql.Register("sqlite-liteq", liteq.MakeDriver(q)) - - db, err := sql.Open("sqlite-liteq", "file:db?mode=memory&cache=shared") + db, err := sql.Open("acude", "file:db?mode=memory&cache=shared") g.TAssertEqual(err, nil) - *q, err = liteq.New(db) + queue, err := q.New(db) g.TAssertEqual(err, nil) - auth, err := New(db, *q) + auth, err := New(db, queue) g.TAssertEqual(err, nil) - fns := [](func() error){ - db.Close, - q.Close, - } return testAuth{ auth: auth, close: func() error { - return g.SomeFnError(fns) + return g.SomeFnError( + db.Close, + queue.Close, + ) }, } } @@ -97,11 +70,12 @@ func test_Register() { g.Testing("we can't register duplicate emails", func() { }) } +*/ func MainTest() { g.Init() test_defaultPrefix() - test_tablesFrom() - test_Register() + // test_tablesFrom() + // test_Register() } diff --git a/tests/queries.sql b/tests/queries.sql new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/tests/queries.sql |