summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.gitignore10
-rw-r--r--Makefile83
-rw-r--r--deps.mk57
-rwxr-xr-xmkdeps.sh25
-rw-r--r--src/gracha.go1316
-rw-r--r--tests/benchmarks/register-login/gracha.go9
l---------tests/benchmarks/register-login/main.go1
-rw-r--r--tests/functional/register-twice/gracha.go9
l---------tests/functional/register-twice/main.go1
-rw-r--r--tests/fuzz/api/gracha.go35
l---------tests/fuzz/api/main.go1
-rw-r--r--tests/gracha.go56
-rw-r--r--tests/queries.sql0
13 files changed, 992 insertions, 611 deletions
diff --git a/.gitignore b/.gitignore
index c096254..1b9f827 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,7 +1,15 @@
/src/version.go
/*.bin
-/*.db
+/*.db*
/src/*.a
/src/*.bin
/tests/*.a
/tests/*.bin
+/tests/functional/*/*.a
+/tests/functional/*/*.bin
+/tests/fuzz/*/*.a
+/tests/fuzz/*/*.bin
+/tests/benchmarks/*/*.a
+/tests/benchmarks/*/*.bin
+/tests/benchmarks/*/*.txt
+/tests/fuzz/corpus/
diff --git a/Makefile b/Makefile
index 7721f5b..2ea93b9 100644
--- a/Makefile
+++ b/Makefile
@@ -17,7 +17,7 @@ MANDIR = $(SHAREDIR)/man
EXEC = ./
## Where to store the installation. Empty by default.
DESTDIR =
-LDLIBS = -lsqlite3
+LDLIBS = --static -lscrypt-kdf -lsqlite3 -lm
GOCFLAGS = -I $(GOLIBDIR)
GOLDFLAGS = -L $(GOLIBDIR)
@@ -26,17 +26,26 @@ GOLDFLAGS = -L $(GOLIBDIR)
.SUFFIXES:
.SUFFIXES: .go .a .bin .bin-check
+.go.a:
+ go tool compile $(GOCFLAGS) -I $(@D) -o $@ -p $(*F) \
+ `find $< $$(if [ $(*F) != main ]; then \
+ echo src/$(NAME).go src/version.go; fi) | uniq`
+
+.a.bin:
+ go tool link $(GOLDFLAGS) -L $(@D) -o $@ --extldflags '$(LDLIBS)' $<
+
all:
include deps.mk
-objects = \
- src/$(NAME).a \
- src/main.a \
- tests/$(NAME).a \
- tests/main.a \
+libs.a = $(libs.go:.go=.a)
+mains.a = $(mains.go:.go=.a)
+mains.bin = $(mains.go:.go=.bin)
+functional-tests/lib.a = $(functional-tests/lib.go:.go=.a)
+fuzz-targets/lib.a = $(fuzz-targets/lib.go:.go=.a)
+benchmarks/lib.a = $(benchmarks/lib.go:.go=.a)
sources = \
src/$(NAME).go \
@@ -46,13 +55,16 @@ sources = \
derived-assets = \
src/version.go \
- $(objects) \
- src/main.bin \
- tests/main.bin \
+ $(libs.a) \
+ $(mains.a) \
+ $(mains.bin) \
$(NAME).bin \
side-assets = \
$(NAME).db* \
+ tests/functional/*/*.go.db* \
+ tests/fuzz/corpus/ \
+ tests/benchmarks/*/main.txt \
@@ -61,40 +73,35 @@ side-assets = \
all: $(derived-assets)
-$(objects): Makefile
+$(libs.a): Makefile deps.mk
+$(libs.a): src/$(NAME).go src/version.go
-src/$(NAME).a: src/$(NAME).go src/version.go
- go tool compile $(GOCFLAGS) -o $@ -p $(*F) -I $(@D) $*.go src/version.go
-src/main.a: src/main.go src/$(NAME).a
-tests/main.a: tests/main.go tests/$(NAME).a
-src/main.a tests/main.a:
- go tool compile $(GOCFLAGS) -o $@ -p $(*F) -I $(@D) $*.go
+$(fuzz-targets/lib.a):
+ go tool compile $(GOCFLAGS) -o $@ -p $(NAME) -d=libfuzzer \
+ src/version.go $*.go src/$(NAME).go
-tests/$(NAME).a: tests/$(NAME).go src/$(NAME).go src/version.go
- go tool compile $(GOCFLAGS) -o $@ -p $(*F) $*.go src/$(*F).go src/version.go
-
-src/main.bin: src/main.a
-tests/main.bin: tests/main.a
-src/main.bin tests/main.bin:
- go tool link $(GOLDFLAGS) -o $@ -L $(@D) --extldflags '$(LDLIBS)' $*.a
+src/version.go: Makefile
+ echo 'package $(NAME); const Version = "$(VERSION)"' > $@
$(NAME).bin: src/main.bin
ln -fs $? $@
-src/version.go: Makefile
- echo 'package $(NAME); const Version = "$(VERSION)"' > $@
+.PRECIOUS: tests/queries.sql
+tests/queries.sql: tests/main.bin ALWAYS
+ env TESTING_DUMP_SQL_QUERIES=1 $(EXEC)tests/main.bin | diff -U10 $@ -
tests.bin-check = \
- tests/main.bin-check \
+ tests/main.bin-check \
+ $(functional-tests/main.go:.go=.bin-check) \
-tests/main.bin-check: tests/main.bin
$(tests.bin-check):
$(EXEC)$*.bin
check-unit: $(tests.bin-check)
+check-unit: tests/queries.sql
integration-tests = \
@@ -107,6 +114,7 @@ $(integration-tests): ALWAYS
sh $@
check-integration: $(integration-tests)
+check-integration: fuzz
## Run all tests. Each test suite is isolated, so that a parallel
@@ -116,6 +124,27 @@ check: check-unit check-integration
+FUZZSEC=1
+fuzz-targets/main.bin-check = $(fuzz-targets/main.go:.go=.bin-check)
+$(fuzz-targets/main.bin-check):
+ $(EXEC)$*.bin --test.fuzztime=$(FUZZSEC)s \
+ --test.fuzz='.*' --test.fuzzcachedir=tests/fuzz/corpus
+
+fuzz: $(fuzz-targets/main.bin-check)
+
+
+
+benchmarks/main.bin-check = $(benchmarks/main.go:.go=.bin-check)
+$(benchmarks/main.bin-check):
+ rm -f $*.txt
+ printf '%s\n' '$(EXEC)$*.bin' >> $*.txt
+ LANG=POSIX.UTF-8 time -p $(EXEC)$*.bin 2>> $*.txt
+ printf '%s\n' '$*.txt'
+
+bench: $(benchmarks/main.bin-check)
+
+
+
## Remove *all* derived artifacts produced during the build.
## A dedicated test asserts that this is always true.
clean:
diff --git a/deps.mk b/deps.mk
index e69de29..a9d0fa7 100644
--- a/deps.mk
+++ b/deps.mk
@@ -0,0 +1,57 @@
+libs.go = \
+ src/gracha.go \
+ tests/benchmarks/register-login/gracha.go \
+ tests/functional/register-twice/gracha.go \
+ tests/fuzz/api/gracha.go \
+ tests/gracha.go \
+
+mains.go = \
+ src/main.go \
+ tests/benchmarks/register-login/main.go \
+ tests/functional/register-twice/main.go \
+ tests/fuzz/api/main.go \
+ tests/main.go \
+
+functional-tests/libs.go = \
+ tests/functional/register-twice/gracha.go \
+
+functional-tests/main.go = \
+ tests/functional/register-twice/main.go \
+
+fuzz-targets/lib.go = \
+ tests/fuzz/api/gracha.go \
+
+fuzz-targets/main.go = \
+ tests/fuzz/api/main.go \
+
+benchmarks/lib.go = \
+ tests/benchmarks/register-login/gracha.go \
+
+benchmarks/main.go = \
+ tests/benchmarks/register-login/main.go \
+
+src/gracha.a: src/gracha.go
+src/main.a: src/main.go
+tests/benchmarks/register-login/gracha.a: tests/benchmarks/register-login/gracha.go
+tests/benchmarks/register-login/main.a: tests/benchmarks/register-login/main.go
+tests/functional/register-twice/gracha.a: tests/functional/register-twice/gracha.go
+tests/functional/register-twice/main.a: tests/functional/register-twice/main.go
+tests/fuzz/api/gracha.a: tests/fuzz/api/gracha.go
+tests/fuzz/api/main.a: tests/fuzz/api/main.go
+tests/gracha.a: tests/gracha.go
+tests/main.a: tests/main.go
+src/main.bin: src/main.a
+tests/benchmarks/register-login/main.bin: tests/benchmarks/register-login/main.a
+tests/functional/register-twice/main.bin: tests/functional/register-twice/main.a
+tests/fuzz/api/main.bin: tests/fuzz/api/main.a
+tests/main.bin: tests/main.a
+src/main.bin-check: src/main.bin
+tests/benchmarks/register-login/main.bin-check: tests/benchmarks/register-login/main.bin
+tests/functional/register-twice/main.bin-check: tests/functional/register-twice/main.bin
+tests/fuzz/api/main.bin-check: tests/fuzz/api/main.bin
+tests/main.bin-check: tests/main.bin
+src/main.a: src/$(NAME).a
+tests/benchmarks/register-login/main.a: tests/benchmarks/register-login/$(NAME).a
+tests/functional/register-twice/main.a: tests/functional/register-twice/$(NAME).a
+tests/fuzz/api/main.a: tests/fuzz/api/$(NAME).a
+tests/main.a: tests/$(NAME).a
diff --git a/mkdeps.sh b/mkdeps.sh
index e5606ff..b1c61b3 100755
--- a/mkdeps.sh
+++ b/mkdeps.sh
@@ -2,3 +2,28 @@
set -eu
export LANG=POSIX.UTF-8
+
+
+libs() {
+ find src tests -name '*.go' | grep -v '/main\.go$' |
+ grep -v '/version\.go$'
+}
+
+mains() {
+ find src tests -name '*.go' | grep '/main\.go$'
+}
+
+libs | varlist 'libs.go'
+mains | varlist 'mains.go'
+
+find tests/functional/*/*.go -not -name main.go | varlist 'functional-tests/libs.go'
+find tests/functional/*/main.go | varlist 'functional-tests/main.go'
+find tests/fuzz/*/*.go -not -name main.go | varlist 'fuzz-targets/lib.go'
+find tests/fuzz/*/main.go | varlist 'fuzz-targets/main.go'
+find tests/benchmarks/*/*.go -not -name main.go | varlist 'benchmarks/lib.go'
+find tests/benchmarks/*/main.go | varlist 'benchmarks/main.go'
+
+{ libs; mains; } | sort | sed 's/^\(.*\)\.go$/\1.a:\t\1.go/'
+mains | sort | sed 's/^\(.*\)\.go$/\1.bin:\t\1.a/'
+mains | sort | sed 's/^\(.*\)\.go$/\1.bin-check:\t\1.bin/'
+mains | sort | sed 's|^\(.*\)/main\.go$|\1/main.a:\t\1/$(NAME).a|'
diff --git a/src/gracha.go b/src/gracha.go
index dcff98d..e300f66 100644
--- a/src/gracha.go
+++ b/src/gracha.go
@@ -1,48 +1,71 @@
package gracha
import (
- "crypto/rand"
+ "context"
"database/sql"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
+ "runtime"
+ "sync"
"time"
"guuid"
- "liteq"
+ "q"
"scrypt"
g "gobang"
)
-type tablesT struct{
- users string
- userChanges string
- tokens string
- roles string
- roleChanges string
- sessions string
- attempts string
- audit string
+const (
+ defaultPrefix = "gracha"
+
+ rollbackErrorFmt = "rollback error: %w; while executing: %w"
+ NEW_USER = "new-user"
+ SEND_CONFIRMATION_REQUEST = "send-confirmation-request"
+ FORGOT_PASSWORD_REQUEST = "forgot-password-request"
+
+
+ day = 24 * time.Hour
+)
+
+var (
+ SessionDuration = 7 * day
+ RegisterTimeout = 15 * time.Second
+
+ ErrPassMismatch = errors.New("gracha: password confirmation mismatch")
+ ErrTooShort = errors.New("gracha: password too short")
+ ErrTimeout = errors.New("gracha: timeout when creating user")
+ ErrRegistered = errors.New("gracha: user already registered")
+ ErrBadCombo = errors.New("gracha: bad username/passphrase combo")
+ ErrUnconfirmed = errors.New("gracha: email is not confirmed")
+ ErrRevokedSession = errors.New("gracha: this session was revoked")
+ ErrSessionExpired = errors.New("gracha: session expired")
+)
+
+
+
+type queryT struct{
+ write string
+ read string
}
type queriesT struct{
- userByEmail func(string) (userT, error)
- userByToken func([]byte) (userT, error)
- register func(string, []byte, []byte) (userT, error)
- confirm func([]byte) (sessionT, error)
- login func(string, string) (sessionT, error)
- refresh func([]byte) (sessionT, error)
- resetPassword func(int64, []byte, []byte) (sessionT, error)
- resetAndConfirm func(int64, []byte, []byte) (sessionT, error)
- changePassword func(int64, []byte) (sessionT, error)
- sessionByUUID func([]byte) (sessionT, error)
- logout func([]byte) error
- logoutOthers func([]byte) error
- logoutAll func([]byte) error
- close func() error
+ byEmail func(string) (userT, error)
+ byToken func(guuid.UUID) (userT, error)
+ register func(string, []byte, []byte) (userT, error)
+ confirm func(guuid.UUID) (sessionT, error)
+ login func(string, string) (sessionT, error)
+ refresh func(guuid.UUID) (sessionT, error)
+ reset func(int64, []byte, guuid.UUID) (sessionT, error)
+ change func(int64, []byte) (sessionT, error)
+ byUUID func(guuid.UUID) (sessionT, error)
+ logout func(guuid.UUID) error
+ outOthers func(guuid.UUID) error
+ outAll func(guuid.UUID) error
+ close func() error
}
type confirmationT struct{
@@ -50,17 +73,14 @@ type confirmationT struct{
type userT struct{
id int64
- timestr string
timestamp time.Time
- uuid []byte
+ uuid guuid.UUID
email string
username *string
salt []byte
pwhash []byte
- // confirmation
confirmed_at *time.Time
- metadatastr *string
- metadata *map[string]interface{}
+ metadata map[string]interface{}
}
type sessionT struct{
@@ -70,240 +90,284 @@ type sessionT struct{
uuid guuid.UUID
user_id int64
type_ string
- revokedstr *string
revoked_at *time.Time
- metadatastr *string
- metadata *map[string]interface{}
-}
-
-type Auth struct{
- tables tablesT
- queries queriesT
- q liteq.Queue
+ metadata map[string]interface{}
}
type consumerT struct{
topic string
- handlerFn func(Auth) func([]byte) error
+ handlerFn func(authT) func(q.Message) error
}
+type authT struct{
+ queries queriesT
+ queue q.IQueue
+ hasher func(scrypt.HashInput) resultT[[]byte]
+ checker func(scrypt.CheckInput) resultT[bool]
+ close func()
+}
+type IAuth interface{
+ Register(string, string, string) (userT, error)
+ ResendConfirmation(string) error
+ ConfirmEmail(guuid.UUID) (sessionT, error)
+ LoginEmail(string, string) (sessionT, error)
+ ForgotPassword(string) error
+ Refresh(sessionT) (sessionT, error)
+ ResetPassword(guuid.UUID, string, string) (sessionT, error)
+ ChangePassword(userT, string, string, string) (sessionT, error)
+ Logout(sessionT) error
+ LogoutOthers(sessionT) error
+ LogoutAll(sessionT) error
+ Close() error
+}
-const (
- NEW_USER = "new-user"
- SEND_CONFIRMATION_REQUEST = "send-confirmation-request"
- FORGOT_PASSWORD_REQUEST = "forgot-password-request"
- defaultPrefix = "gracha"
- day = 24 * time.Hour
-)
+func tryRollback(db *sql.DB, ctx context.Context, err error) error {
+ _, rollbackErr := db.ExecContext(ctx, "ROLLBACK;")
+ if rollbackErr != nil {
+ return fmt.Errorf(
+ rollbackErrorFmt,
+ rollbackErr,
+ err,
+ )
+ }
-var (
- SessionDuration = 7 * day
- RegisterTimeout = 15 * time.Second
+ return err
+}
- ErrPasswordMismatch = errors.New("gracha: password and its confirmation don't match")
- ErrPasswordTooShort = errors.New("gracha: bad username/passphrase combo")
- ErrRegisterTimeout = errors.New("gracha: timeout when creating user")
- ErrAlreadyRegistered = errors.New("gracha: user already registered")
- ErrEmailPasswordCombo = errors.New("gracha: bad username/passphrase combo")
- ErrUnconfirmedUser = errors.New("gracha: user email is not confirmed")
- ErrRevokedSession = errors.New("gracha: this session was revoked")
- ErrSessionExpired = errors.New("gracha: session expired")
-)
+func inTx(db *sql.DB, fn func(context.Context) error) error {
+ ctx := context.Background()
+
+ _, err := db.ExecContext(ctx, "BEGIN IMMEDIATE;")
+ if err != nil {
+ return err
+ }
+ err = fn(ctx)
+ if err != nil {
+ return tryRollback(db, ctx, err)
+ }
+ _, err = db.ExecContext(ctx, "COMMIT;")
+ if err != nil {
+ return tryRollback(db, ctx, err)
+ }
-func tablesFrom(prefix string) (tablesT, error) {
- if !g.ValidSQLTablePrefix(prefix) {
- return tablesT{}, g.ErrBadSQLTablePrefix
- }
-
- users := prefix + "-users"
- userChanges := prefix + "-user-changes"
- tokens := prefix + "-tokens"
- roles := prefix + "-roles"
- roleChanges := prefix + "-role-changes"
- sessions := prefix + "-sessions"
- attempts := prefix + "-attempts"
- audit := prefix + "-audit"
- return tablesT{
- users: users,
- userChanges: userChanges,
- tokens: tokens,
- roles: roles,
- roleChanges: roleChanges,
- sessions: sessions,
- attempts: attempts,
- audit: audit,
- }, nil
+ return nil
}
-func createTables(db *sql.DB, tables tablesT) error {
- const tmpl = `
- BEGIN TRANSACTION;
- CREATE TABLE IF NOT EXISTS "%s" (
- id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
- timestamp TEXT NOT NULL DEFAULT (%s),
- uuid BLOB NOT NULL UNIQUE,
- email TEXT NOT NULL UNIQUE,
- username TEXT UNIQUE,
- salt BLOB NOT NULL UNIQUE,
- pwhash BLOB NOT NULL,
- confirmed_at TEXT,
- confirmer_id INTEGER REFERENCES "%s"(id),
- metadata TEXT
+func createTablesSQL(prefix string) queryT {
+ const tmpl_write = `
+ CREATE TABLE IF NOT EXISTS "%s_users" (
+ id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
+ timestamp TEXT NOT NULL DEFAULT (%s),
+ uuid BLOB NOT NULL UNIQUE,
+ email TEXT NOT NULL UNIQUE,
+ username TEXT UNIQUE,
+ salt BLOB NOT NULL UNIQUE,
+ pwhash BLOB NOT NULL,
+ confirmed_at TEXT,
+ confirmer_id INTEGER REFERENCES "%s"(id),
+ metadata TEXT
);
- CREATE TABLE IF NOT EXISTS "%s" (
- id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
- timestamp TEXT NOT NULL DEFAULT (%s),
- user_id INTEGER NOT NULL REFERENCES "%s"(id),
- attribute TEXT NOT NULL,
- value TEXT NOT NULL,
- op BOOLEAN NOT NULL
+ CREATE TABLE IF NOT EXISTS "%s_user_changes" (
+ id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
+ timestamp TEXT NOT NULL DEFAULT (%s),
+ user_id INTEGER NOT NULL REFERENCES "%s_users"(id),
+ attribute TEXT NOT NULL,
+ value TEXT NOT NULL,
+ op BOOLEAN NOT NULL
);
- CREATE TABLE IF NOT EXISTS "%s" (
- id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
- timestamp TEXT NOT NULL DEFAULT (%s),
- uuid BLOB NOT NULL UNIQUE,
- type TEXT NOT NULL
+ CREATE TABLE IF NOT EXISTS "%s_tokens" (
+ id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
+ timestamp TEXT NOT NULL DEFAULT (%s),
+ uuid BLOB NOT NULL UNIQUE,
+ type TEXT NOT NULL
);
- CREATE TABLE IF NOT EXISTS "%s" (
- id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
- user_id INTEGER NOT NULL REFERENCES "%s"(id),
- role TEXT NOT NULL,
+ CREATE TABLE IF NOT EXISTS "%s_roles" (
+ id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
+ user_id INTEGER NOT NULL REFERENCES "%s_users"(id),
+ role TEXT NOT NULL,
+ metadata TEXT,
UNIQUE (user_id, role)
);
- CREATE TABLE IF NOT EXISTS "%s" (
- id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
- timestamp TEXT NOT NULL DEFAULT (%s),
- user_id INTEGER NOT NULL REFERENCES "%s"(id),
- role TEXT NOT NULL,
- op BOOLEAN NOT NULL
+ CREATE TABLE IF NOT EXISTS "%s_role_changes" (
+ id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
+ timestamp TEXT NOT NULL DEFAULT (%s),
+ user_id INTEGER NOT NULL REFERENCES "%s_roles"(id),
+ role TEXT NOT NULL,
+ op BOOLEAN NOT NULL
);
- CREATE TABLE IF NOT EXISTS "%s" (
- id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
- timestamp TEXT NOT NULL DEFAULT (%s),
- uuid BLOB NOT NULL UNIQUE,
- user_id INTEGER NOT NULL REFERENCES "%s"(id),
- type TEXT NOT NULL,
- revoked_at TEXT,
- revoker_id INTEGER REFERENCES "%s"(id),
- metadata TEXT
+ CREATE TABLE IF NOT EXISTS "%s_sessions" (
+ id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
+ timestamp TEXT NOT NULL DEFAULT (%s),
+ uuid BLOB NOT NULL UNIQUE,
+ user_id INTEGER NOT NULL REFERENCES "%s_users"(id),
+ type TEXT NOT NULL,
+ revoked_at TEXT,
+ revoker_id INTEGER REFERENCES "%s_users"(id),
+ metadata TEXT
);
- CREATE TABLE IF NOT EXISTS "%s" (
- id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
- timestamp TEXT NOT NULL DEFAULT (%s),
- user_id INTEGER REFERENCES "%s"(id),
- session_id INTEGER REFERENCES "%s"(id),
- metadata TEXT
+ CREATE TABLE IF NOT EXISTS "%s_attempts" (
+ id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
+ timestamp TEXT NOT NULL DEFAULT (%s),
+ user_id INTEGER REFERENCES "%s_users"(id),
+ session_id INTEGER REFERENCES "%s_sessions"(id),
+ metadata TEXT
);
- CREATE TABLE IF NOT EXISTS "%s" (
- id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
- timestamp TEXT NOT NULL DEFAULT (%s),
- uuid BLOB NOT NULL UNIQUE,
- attribute TEXT NOT NULL,
- value TEXT NOT NULL,
- op BOOLEAN NOT NULL,
- metadata TEXT
+ CREATE TABLE IF NOT EXISTS "%s_audit" (
+ id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
+ timestamp TEXT NOT NULL DEFAULT (%s),
+ uuid BLOB NOT NULL UNIQUE,
+ attribute TEXT NOT NULL,
+ value TEXT NOT NULL,
+ op BOOLEAN NOT NULL,
+ metadata TEXT
);
- COMMIT TRANSACTION;
`
- sql := fmt.Sprintf(
- tmpl,
- tables.users,
- g.SQLiteNow,
- tables.tokens,
- tables.userChanges,
- g.SQLiteNow,
- tables.users,
- tables.tokens,
- g.SQLiteNow,
- tables.roles,
- tables.users,
- tables.roleChanges,
- g.SQLiteNow,
- tables.users,
- tables.sessions,
- g.SQLiteNow,
- tables.users,
- tables.sessions,
- tables.attempts,
- g.SQLiteNow,
- tables.users,
- tables.sessions,
- tables.audit,
- g.SQLiteNow,
- )
- /// fmt.Println(sql) ///
+ return queryT{
+ write: fmt.Sprintf(
+ tmpl_write,
+ prefix,
+ g.SQLiteNow,
+ prefix,
+ prefix,
+ g.SQLiteNow,
+ prefix,
+ prefix,
+ g.SQLiteNow,
+ prefix,
+ prefix,
+ prefix,
+ g.SQLiteNow,
+ prefix,
+ prefix,
+ g.SQLiteNow,
+ prefix,
+ prefix,
+ prefix,
+ g.SQLiteNow,
+ prefix,
+ prefix,
+ prefix,
+ g.SQLiteNow,
+ ),
+ }
+}
- _, err := db.Exec(sql)
- return err
+func createTables(db *sql.DB, prefix string) error {
+ q := createTablesSQL(prefix)
+
+ return inTx(db, func(ctx context.Context) error {
+ _, err := db.ExecContext(ctx, q.write)
+ return err
+ })
}
-func userByEmailQuery(
- db *sql.DB,
- tables tablesT,
-) (func(string) (userT, error), func() error, error) {
- const tmpl = `
+func byEmailSQL(prefix string) queryT {
+ const tmpl_read = `
SELECT id, timestamp, uuid, email, username, pwhash, metadata
- FROM "%s" WHERE email = ?;
+ FROM "%s_users" WHERE email = ?;
`
- sql := fmt.Sprintf(tmpl, tables.users)
- /// fmt.Println(sql) ///
+ return queryT{
+ read: fmt.Sprintf(tmpl_read, prefix),
+ }
+}
+
+func byEmailStmt(
+ db *sql.DB,
+ prefix string,
+) (func(string) (userT, error), func() error, error) {
+ q := byEmailSQL(prefix)
- stmt, err := db.Prepare(sql)
+ readStmt, err := db.Prepare(q.read)
if err != nil {
return nil, nil, err
}
fn := func(email string) (userT, error) {
- var user userT
- err := stmt.QueryRow(email).Scan(&user.id) // FIXME: build user
- return user, err
+ user := userT{
+ email: email,
+ }
+
+ var (
+ timestr string
+ uuid_bytes []byte
+ )
+ err := readStmt.QueryRow(email).Scan(
+ &user.id,
+ &timestr,
+ &uuid_bytes,
+ &user.username,
+ &user.pwhash,
+ &user.metadata,
+ )
+ if err != nil {
+ return userT{}, err
+ }
+ user.uuid = guuid.UUID(uuid_bytes)
+
+ user.timestamp, err = time.Parse(time.RFC3339Nano,timestr)
+ if err != nil {
+ return userT{}, err
+ }
+
+ return user, nil
}
- return fn, stmt.Close, nil
+ return fn, readStmt.Close, nil
}
-func userByTokenQuery(
- db *sql.DB,
- tables tablesT,
-) (func([]byte) (userT, error), func() error, error) {
- const tmpl = `
+func byTokenSQL(prefix string) queryT{
+ const tmpl_read = `
SELECT id, timestamp, uuid, email, username, pwhash, metadata
FROM "%s" WHERE email = ?;
`
- sql := fmt.Sprintf(tmpl, tables.users)
- /// fmt.Println(sql) ///
+ return queryT{
+ read: fmt.Sprintf(tmpl_read, prefix),
+ }
+}
- stmt, err := db.Prepare(sql)
+func byTokenStmt(
+ db *sql.DB,
+ prefix string,
+) (func(guuid.UUID) (userT, error), func() error, error) {
+ q := byTokenSQL(prefix)
+
+ readStmt, err := db.Prepare(q.read)
if err != nil {
return nil, nil, err
}
- fn := func(token []byte) (userT, error) {
+ fn := func(token guuid.UUID) (userT, error) {
var user userT
- err := stmt.QueryRow(token).Scan(&user.id) // FIXME: build user
+ // FIXME: build user
+ err := readStmt.QueryRow(token).Scan(&user.id)
return user, err
}
- return fn, stmt.Close, nil
+ return fn, readStmt.Close, nil
}
-func registerQuery(
- db *sql.DB,
- tables tablesT,
-) (func(string, []byte, []byte) (userT, error), func() error, error) {
- const tmpl = `
+func registerSQL(prefix string) queryT {
+ const tmpl_write = `
INSERT INTO "%s" (uuid, email, username, salt, pwhash, metadata)
VALUES (?, ?, ?, ?, ?, ?);
`
- sql := fmt.Sprintf(tmpl, tables.users)
- /// fmt.Println(sql) ///
+ return queryT{
+ write: fmt.Sprintf(tmpl_write, prefix),
+ }
+}
- stmt, err := db.Prepare(sql)
+func registerStmt(
+ db *sql.DB,
+ prefix string,
+) (func(string, []byte, []byte) (userT, error), func() error, error) {
+ q := registerSQL(prefix)
+
+ writeStmt, err := db.Prepare(q.write)
if err != nil {
return nil, nil, err
}
@@ -320,8 +384,8 @@ func registerQuery(
var user userT
// err := stmt.QueryRow(
- ret, err := stmt.Exec(
- guuid.NewBytes(),
+ ret, err := writeStmt.Exec(
+ guuid.New(),
email,
"credentials.username",
salt,
@@ -337,429 +401,630 @@ func registerQuery(
return user, err
}
- return fn, stmt.Close, nil
+ return fn, writeStmt.Close, nil
}
-func loginQuery(
- db *sql.DB,
- tables tablesT,
-) (func(string, string) (sessionT, error), func() error, error) {
- const tmpl = `
- -- INSERT INTO "%s" (t3, t4) VALUES (?, ?);
+func confirmSQL(prefix string) queryT {
+ const tmpl_write = `
+ -- INSERT SOMETHING %s
`
- sql := fmt.Sprintf(tmpl, tables.users)
- /// fmt.Println(sql) ///
+ return queryT{
+ write: fmt.Sprintf(tmpl_write, prefix),
+ }
+}
+
+func confirmStmt(
+ db *sql.DB,
+ prefix string,
+) (func(guuid.UUID) (sessionT, error), func() error, error) {
+ q := confirmSQL(prefix)
- stmt, err := db.Prepare(sql)
+ writeStmt, err := db.Prepare(q.write)
if err != nil {
return nil, nil, err
}
- fn := func(email string, pwhash string) (sessionT, error) {
+ fn := func(token guuid.UUID) (sessionT, error) {
var session sessionT
- err := stmt.QueryRow(email, pwhash).Scan(session)
- // FIXME: finish
+ err := writeStmt.QueryRow(token).Scan(&session)
return session, err
}
- return fn, stmt.Close, nil
+ return fn, writeStmt.Close, nil
}
-func refreshQuery(
- db *sql.DB,
- tables tablesT,
-) (func([]byte) (sessionT, error), func() error, error) {
- const tmpl = `
- -- INSERT SOMETHING %s
+func loginSQL(prefix string) queryT {
+ const tmpl_write = `
+ -- INSERT INTO "%s" (t3, t4) VALUES (?, ?);
`
- sql := fmt.Sprintf(tmpl, tables.users)
- /// fmt.Println(sql) ///
+ return queryT{
+ write: fmt.Sprintf(tmpl_write, prefix),
+ }
+}
+
+func loginStmt(
+ db *sql.DB,
+ prefix string,
+) (func(string, string) (sessionT, error), func() error, error) {
+ q := loginSQL(prefix)
- stmt, err := db.Prepare(sql)
+ writeStmt, err := db.Prepare(q.write)
if err != nil {
return nil, nil, err
}
- fn := func(uuid []byte) (sessionT, error) {
+ fn := func(email string, pwhash string) (sessionT, error) {
var session sessionT
- err := stmt.QueryRow(uuid).Scan(&session)
+ err := writeStmt.QueryRow(email, pwhash).Scan(session)
+ // FIXME: finish
return session, err
}
- return fn, stmt.Close, nil
+ return fn, writeStmt.Close, nil
}
-func resetPasswordQuery(
- db *sql.DB,
- tables tablesT,
-) (func(int64, []byte, []byte) (sessionT, error), func() error, error) {
- const tmpl = `
+func refreshSQL(prefix string) queryT {
+ const tmpl_write = `
-- INSERT SOMETHING %s
`
- sql := fmt.Sprintf(tmpl, tables.users)
- /// fmt.Println(sql) ///
+ return queryT{
+ write: fmt.Sprintf(tmpl_write, prefix),
+ }
+}
- stmt, err := db.Prepare(sql)
+func refreshStmt(
+ db *sql.DB,
+ prefix string,
+) (func(guuid.UUID) (sessionT, error), func() error, error) {
+ q := refreshSQL(prefix)
+
+ writeStmt, err := db.Prepare(q.write)
if err != nil {
return nil, nil, err
}
- fn := func(id int64, pwhash []byte, token []byte) (sessionT, error) {
+ fn := func(uuid guuid.UUID) (sessionT, error) {
var session sessionT
- err := stmt.QueryRow(id, pwhash, token).Scan(&session)
+ err := writeStmt.QueryRow(uuid).Scan(&session)
return session, err
}
- return fn, stmt.Close, nil
+ return fn, writeStmt.Close, nil
}
-func resetAndConfirmQuery(
- db *sql.DB,
- tables tablesT,
-) (func(int64, []byte, []byte) (sessionT, error), func() error, error) {
- const tmpl = `
+func resetSQL(prefix string) queryT {
+ const tmpl_write = `
-- INSERT SOMETHING %s
`
- sql := fmt.Sprintf(tmpl, tables.users)
- /// fmt.Println(sql) ///
+ return queryT{
+ write: fmt.Sprintf(tmpl_write, prefix),
+ }
+}
- stmt, err := db.Prepare(sql)
+func resetStmt(
+ db *sql.DB,
+ prefix string,
+) (func(int64, []byte, guuid.UUID) (sessionT, error), func() error, error) {
+ q := resetSQL(prefix)
+
+ writeStmt, err := db.Prepare(q.write)
if err != nil {
return nil, nil, err
}
- fn := func(id int64, pwhash []byte, token []byte) (sessionT, error) {
+ fn := func(id int64, pwhash []byte, token guuid.UUID) (sessionT, error) {
var session sessionT
- err := stmt.QueryRow(id, pwhash, token).Scan(&session)
+ err := writeStmt.QueryRow(id, pwhash, token).Scan(&session)
return session, err
}
- return fn, stmt.Close, nil
+ return fn, writeStmt.Close, nil
}
-func changePasswordQuery(
- db *sql.DB,
- tables tablesT,
-) (func(int64, []byte) (sessionT, error), func() error, error) {
- const tmpl = `
+func changeSQL(prefix string) queryT {
+ const tmpl_write = `
-- INSERT SOMETHING %s
`
- sql := fmt.Sprintf(tmpl, tables.users)
- /// fmt.Println(sql) ///
+ return queryT{
+ write: fmt.Sprintf(tmpl_write, prefix),
+ }
+}
- stmt, err := db.Prepare(sql)
+func changeStmt(
+ db *sql.DB,
+ prefix string,
+) (func(int64, []byte) (sessionT, error), func() error, error) {
+ q := changeSQL(prefix)
+
+ writeStmt, err := db.Prepare(q.write)
if err != nil {
return nil, nil, err
}
fn := func(id int64, pwhash []byte) (sessionT, error) {
var session sessionT
- err := stmt.QueryRow(id, pwhash).Scan(&session)
+ err := writeStmt.QueryRow(id, pwhash).Scan(&session)
return session, err
}
- return fn, stmt.Close, nil
+ return fn, writeStmt.Close, nil
}
-func sessionByUUIDQuery(
- db *sql.DB,
- tables tablesT,
-) (func([]byte) (sessionT, error), func() error, error) {
- const tmpl = `
+func byUUIDSQL(prefix string) queryT {
+ const tmpl_read = `
-- INSERT SOMETHING %s
`
- sql := fmt.Sprintf(tmpl, tables.users)
- /// fmt.Println(sql) ///
+ return queryT{
+ read: fmt.Sprintf(tmpl_read, prefix),
+ }
+}
+
+func byUUIDStmt(
+ db *sql.DB,
+ prefix string,
+) (func(guuid.UUID) (sessionT, error), func() error, error) {
+ q := byUUIDSQL(prefix)
- stmt, err := db.Prepare(sql)
+ readStmt, err := db.Prepare(q.read)
if err != nil {
return nil, nil, err
}
- fn := func(uuid []byte) (sessionT, error) {
+ fn := func(uuid guuid.UUID) (sessionT, error) {
var session sessionT
- err := stmt.QueryRow(uuid).Scan(&session)
+ err := readStmt.QueryRow(uuid).Scan(&session)
return session, err
}
- return fn, stmt.Close, nil
+ return fn, readStmt.Close, nil
}
-func logoutQuery(
- db *sql.DB,
- tables tablesT,
-) (func([]byte) error, func() error, error) {
- const tmpl = `
+func logoutSQL(prefix string) queryT {
+ const tmpl_write = `
-- INSERT SOMETHING %s
`
- sql := fmt.Sprintf(tmpl, tables.users)
- /// fmt.Println(sql) ///
+ return queryT{
+ write: fmt.Sprintf(tmpl_write, prefix),
+ }
+}
+
+func logoutStmt(
+ db *sql.DB,
+ prefix string,
+) (func(guuid.UUID) error, func() error, error) {
+ q := logoutSQL(prefix)
- stmt, err := db.Prepare(sql)
+ writeStmt, err := db.Prepare(q.write)
if err != nil {
return nil, nil, err
}
- fn := func(uuid []byte) error {
- _, err := stmt.Exec(uuid)
+ fn := func(uuid guuid.UUID) error {
+ _, err := writeStmt.Exec(uuid)
return err
}
- return fn, stmt.Close, nil
+ return fn, writeStmt.Close, nil
}
-func logoutOthersQuery(
- db *sql.DB,
- tables tablesT,
-) (func([]byte) error, func() error, error) {
- const tmpl = `
+func outOthersSQL(prefix string) queryT {
+ const tmpl_write = `
-- INSERT SOMETHING %s
`
- sql := fmt.Sprintf(tmpl, tables.users)
- /// fmt.Println(sql) ///
+ return queryT{
+ write: fmt.Sprintf(tmpl_write, prefix),
+ }
+}
- stmt, err := db.Prepare(sql)
+func outOthersStmt(
+ db *sql.DB,
+ prefix string,
+) (func(guuid.UUID) error, func() error, error) {
+ q := outOthersSQL(prefix)
+
+ writeStmt, err := db.Prepare(q.write)
if err != nil {
return nil, nil, err
}
- fn := func(uuid []byte) error {
- _, err := stmt.Exec(uuid)
+ fn := func(uuid guuid.UUID) error {
+ _, err := writeStmt.Exec(uuid)
return err
}
- return fn, stmt.Close, nil
+ return fn, writeStmt.Close, nil
}
-func logoutAllQuery(
- db *sql.DB,
- tables tablesT,
-) (func([]byte) error, func() error, error) {
- const tmpl = `
+func outAllSQL(prefix string) queryT {
+ const tmpl_write = `
-- INSERT SOMETHING %s
`
- sql := fmt.Sprintf(tmpl, tables.users)
- /// fmt.Println(sql) ///
+ return queryT{
+ write: fmt.Sprintf(tmpl_write, prefix),
+ }
+}
- stmt, err := db.Prepare(sql)
+func outAllStmt(
+ db *sql.DB,
+ prefix string,
+) (func(guuid.UUID) error, func() error, error) {
+ q := outAllSQL(prefix)
+
+ writeStmt, err := db.Prepare(q.write)
if err != nil {
return nil, nil, err
}
- fn := func(uuid []byte) error {
- _, err := stmt.Exec(uuid)
+ fn := func(uuid guuid.UUID) error {
+ _, err := writeStmt.Exec(uuid)
return err
}
- return fn, stmt.Close, nil
-}
-
-func initDB(db *sql.DB, tables tablesT) (queriesT, error) {
- createTablesErr := createTables(db, tables)
- userByEmail, userByEmailClose, userByEmailErr := userByEmailQuery(db, tables)
- userByToken, userByTokenClose, userByTokenErr := userByTokenQuery(db, tables)
- register, registerClose, registerErr := registerQuery(db, tables)
- login, loginClose, loginErr := loginQuery(db, tables)
- refresh, refreshClose, refreshErr := refreshQuery(db, tables)
- resetPassword, resetPasswordClose, resetPasswordErr := resetPasswordQuery(db, tables)
- resetAndConfirm, resetAndConfirmClose, resetAndConfirmErr := resetAndConfirmQuery(db, tables)
- changePassword, changePasswordClose, changePasswordErr := changePasswordQuery(db, tables)
- sessionByUUID, sessionByUUIDClose, sessionByUUIDErr := sessionByUUIDQuery(db, tables)
- logout, logoutClose, logoutErr := logoutQuery(db, tables)
- logoutOthers, logoutOthersClose, logoutOthersErr := logoutOthersQuery(db, tables)
- logoutAll, logoutAllClose, logoutAllErr := logoutAllQuery(db, tables)
-
- errs := []error {
- createTablesErr,
- userByEmailErr,
- userByTokenErr,
- registerErr,
- loginErr,
- refreshErr,
- resetPasswordErr,
- resetAndConfirmErr,
- changePasswordErr,
- sessionByUUIDErr,
- logoutErr,
- logoutOthersErr,
- logoutAllErr,
- }
- err := g.SomeError(errs)
+ return fn, writeStmt.Close, nil
+}
+
+func initDB(
+ db *sql.DB,
+ prefix string,
+) (queriesT, error) {
+ createTablesErr := createTables(db, prefix)
+ byEmail, byEmailClose, byEmailErr := byEmailStmt(db, prefix)
+ byToken, byTokenClose, byTokenErr := byTokenStmt(db, prefix)
+ register, registerClose, registerErr := registerStmt(db, prefix)
+ confirm, confirmClose, confirmErr := confirmStmt(db, prefix)
+ login, loginClose, loginErr := loginStmt(db, prefix)
+ refresh, refreshClose, refreshErr := refreshStmt(db, prefix)
+ reset, resetClose, resetErr := resetStmt(db, prefix)
+ change, changeClose, changeErr := changeStmt(db, prefix)
+ byUUID, byUUIDClose, byUUIDErr := byUUIDStmt(db, prefix)
+ logout, logoutClose, logoutErr := logoutStmt(db, prefix)
+ outOthers, outOthersClose, outOthersErr := outOthersStmt(db, prefix)
+ outAll, outAllClose, outAllErr := outAllStmt(db, prefix)
+
+ err := g.SomeError(
+ createTablesErr,
+ byEmailErr,
+ byTokenErr,
+ registerErr,
+ confirmErr,
+ loginErr,
+ refreshErr,
+ resetErr,
+ changeErr,
+ byUUIDErr,
+ logoutErr,
+ outOthersErr,
+ outAllErr,
+ )
if err != nil {
return queriesT{}, err
}
close := func() error {
- fns := [](func() error){
- userByEmailClose,
- userByTokenClose,
- registerClose,
- loginClose,
- refreshClose,
- resetPasswordClose,
- resetAndConfirmClose,
- changePasswordClose,
- sessionByUUIDClose,
- logoutClose,
- logoutOthersClose,
- logoutAllClose,
- }
- return g.SomeFnError(fns)
+ return g.SomeFnError(
+ byEmailClose,
+ byTokenClose,
+ registerClose,
+ confirmClose,
+ loginClose,
+ refreshClose,
+ resetClose,
+ changeClose,
+ byUUIDClose,
+ logoutClose,
+ outOthersClose,
+ outAllClose,
+ )
}
+ // FIXME: lock
return queriesT{
- userByEmail: userByEmail,
- userByToken: userByToken,
- register: register,
- login: login,
- close: close,
- refresh: refresh,
- resetPassword: resetPassword,
- resetAndConfirm: resetAndConfirm,
- changePassword: changePassword,
- sessionByUUID: sessionByUUID,
- logout: logout,
- logoutOthers: logoutOthers,
- logoutAll: logoutAll,
+ byEmail: byEmail,
+ byToken: byToken,
+ register: register,
+ confirm: confirm,
+ login: login,
+ refresh: refresh,
+ reset: reset,
+ change: change,
+ byUUID: byUUID,
+ logout: logout,
+ outOthers: outOthers,
+ outAll: outAll,
+ close: close,
}, nil
}
-func newUserPayload(email string, salt []byte, pwhash []byte) ([]byte, error) {
- data := make(map[string]interface{})
- data["email"] = email
- data["salt"] = hex.EncodeToString(salt)
- data["pwhash"] = hex.EncodeToString(pwhash)
- return json.Marshal(data)
+func newUserHandler(auth authT) func(q.Message) error {
+ return func(message q.Message) error {
+ return nil
+ }
}
-func publishRegister(
- q liteq.Queue,
- email string,
- salt []byte,
- pwhash []byte,
- flowId []byte,
-) error {
- payload, err := newUserPayload(email, salt, pwhash)
- if err != nil {
- return err
+func sendConfirmationRequestHandler(auth authT) func(q.Message) error {
+ return func(message q.Message) error {
+ return nil
}
+}
- return q.Publish(NEW_USER, payload, flowId)
+func forgotPasswordRequestHandler(auth authT) func(q.Message) error {
+ return func(message q.Message) error {
+ return nil
+ }
}
-func register(
- auth Auth,
- email string,
- salt []byte,
- pwhash []byte,
-) (userT, error) {
- flowId := guuid.NewBytes()
- waiter := auth.q.WaitFor(NEW_USER, flowId)
- defer auth.q.Unwait(waiter)
+var consumers = []consumerT{
+ consumerT{
+ topic: NEW_USER,
+ handlerFn: newUserHandler,
+ },
+ consumerT{
+ topic: SEND_CONFIRMATION_REQUEST,
+ handlerFn: sendConfirmationRequestHandler,
+ },
+ consumerT{
+ topic: FORGOT_PASSWORD_REQUEST,
+ handlerFn: forgotPasswordRequestHandler,
+ },
+}
+func registerConsumers(auth authT, consumers []consumerT) {
+ for _, consumer := range consumers {
+ auth.queue.Subscribe(
+ consumer.topic,
+ defaultPrefix + "-" + consumer.topic,
+ consumer.handlerFn(auth),
+ )
+ }
+}
+
+type resultT[T any] struct{
+ value T
+ err error
+}
+
+type taggedT[T any] struct{
+ id int
+ value T
+}
+
+func startRunner[A any, B any](
+ in <-chan taggedT[A],
+ out chan<- taggedT[B],
+ fn func(A) B,
+ done func(),
+) {
+ for input := range in {
+ out <- taggedT[B]{
+ id: input.id,
+ value: fn(input.value),
+ }
+ }
+ done()
+}
+
+func makePoolRunner[A any, B any](count int, fn func(A) B) (func(A) B, func()) {
+ var wg sync.WaitGroup
+ wg.Add(count)
+
+ in := make(chan taggedT[A])
+ out := make(chan taggedT[B])
+
+ for _ = range count {
+ go startRunner(in, out, fn, wg.Done)
+ }
+
+ var mutex sync.Mutex
+ m := map[int]chan B{}
+ id := 0
+ go func() {
+ for output := range out {
+ mutex.Lock()
+ defer mutex.Unlock()
+ m[output.id] <- output.value
+ close(m[output.id])
+ delete(m, output.id)
+ }
+ }()
+
+ poolRunFn := func(input A) B {
+ c := make(chan B)
+ {
+ mutex.Lock()
+ defer mutex.Unlock()
+ m[id] = c
+ id++
+ }
+
+ in <- taggedT[A]{
+ id: id,
+ value: input,
+ }
+ return <- c
+ }
+
+ close := func() {
+ close(in)
+ wg.Wait()
+ close(out)
+ }
+
+ return poolRunFn, close
+}
+
+func asResult[A any, B any](fn func(A) (B, error)) func(A) resultT[B] {
+ return func(input A) resultT[B] {
+ output, err := fn(input)
+ return resultT[B]{
+ value: output,
+ err: err,
+ }
+ }
+}
- err := publishRegister(auth.q, email, salt, pwhash, flowId)
+func NewWithPrefix(db *sql.DB, queue q.IQueue, prefix string) (IAuth, error) {
+ queries, err := initDB(db, prefix)
if err != nil {
- return userT{}, err
+ return authT{}, err
}
- select {
- case <-time.After(RegisterTimeout):
- return userT{}, ErrRegisterTimeout
- case <-waiter:
- return auth.queries.userByEmail(email)
+ numCPU := (runtime.NumCPU() / 2) + 1
+ hasher, closeHasher := makePoolRunner(numCPU, asResult(scrypt.Hash))
+ checker, closeChecker := makePoolRunner(numCPU, asResult(scrypt.Check))
+
+ close := func() {
+ closeHasher()
+ closeChecker()
}
+
+ return authT{
+ queries: queries,
+ queue: queue,
+ hasher: hasher,
+ checker: checker,
+ close: close,
+ }, nil
+}
+
+func New(db *sql.DB, queue q.IQueue) (IAuth, error) {
+ return NewWithPrefix(db, queue, defaultPrefix)
}
-func (auth Auth) Register(
+func newUserPayload(email string, salt []byte, pwhash []byte) ([]byte, error) {
+ data := make(map[string]interface{})
+ data["email"] = email
+ data["salt"] = hex.EncodeToString(salt)
+ data["pwhash"] = hex.EncodeToString(pwhash)
+ return json.Marshal(data)
+}
+
+func (auth authT) Register(
email string,
password string,
confirmPassword string,
) (userT, error) {
if password != confirmPassword {
- return userT{}, ErrPasswordMismatch
+ return userT{}, ErrPassMismatch
}
if len(password) < scrypt.MinimumPasswordLength {
- return userT{}, ErrPasswordTooShort
+ return userT{}, ErrTooShort
}
// special check for sql.ErrNoRows to combat enumeration attacks.
- _, lookupErr := auth.queries.userByEmail(email)
+ // FIXME: how so?
+ _, lookupErr := auth.queries.byEmail(email)
if lookupErr != nil && lookupErr != sql.ErrNoRows {
return userT{}, lookupErr
}
- salt, err := scrypt.SaltFrom(rand.Reader)
+ salt, err := scrypt.Salt()
if err != nil {
return userT{}, err
}
- pwhash, err := scrypt.HashFrom([]byte(password), salt)
- if err != nil {
- return userT{}, err
+ input := scrypt.HashInput{
+ Password: []byte(password),
+ Salt: salt,
}
+ result := auth.hasher(input)
/*
We also try to register anyway, to prevent disk IO timing attacks.
+ / FIXME: how so?
*/
- user, err := register(auth, email, salt, pwhash)
+
+ flowID := guuid.New()
+ waiter := auth.queue.WaitFor(NEW_USER, flowID, "register")
+ defer waiter.Close()
+
+ payload, err := newUserPayload(email, salt, result.value)
+ if err != nil {
+ return userT{}, err
+ }
+
+ unsent := q.UnsentMessage{
+ Topic: NEW_USER,
+ FlowID: flowID,
+ Payload: payload,
+ }
+ _, err = auth.queue.Publish(unsent)
if err != nil {
- if lookupErr != nil {
- return userT{}, ErrAlreadyRegistered
- }
return userT{}, err
}
+ user := userT{}
+ select {
+ case <-time.After(RegisterTimeout):
+ err = ErrTimeout
+ case <-waiter.Channel:
+ user, err = auth.queries.byEmail(email)
+ }
+
return user, nil
}
-func sendConfirmationPayload(email string) ([]byte, error) {
+func sendConfirmationMessage(
+ email string,
+ flowID guuid.UUID,
+) (q.UnsentMessage, error) {
data := make(map[string]interface{})
data["email"] = email
- return json.Marshal(data)
+ payload, err := json.Marshal(data)
+ if err != nil {
+ return q.UnsentMessage{}, err
+ }
+
+ return q.UnsentMessage{
+ Topic: SEND_CONFIRMATION_REQUEST,
+ FlowID: flowID,
+ Payload: payload,
+ }, nil
}
-func publishSendConfirmation(q liteq.Queue, email string, flowId []byte) error {
- payload, err := sendConfirmationPayload(email)
+func (auth authT) ResendConfirmation(email string) error {
+ unsent, err := sendConfirmationMessage(email, guuid.New())
if err != nil {
return err
}
- return q.Publish(SEND_CONFIRMATION_REQUEST, payload, flowId)
-}
-func (auth Auth) ResendConfirmation(email string) error {
- return publishSendConfirmation(auth.q, email, guuid.NewBytes())
+ _, err = auth.queue.Publish(unsent)
+ return err
}
-func (auth Auth) ConfirmEmail(token guuid.UUID) (sessionT, error) {
- return auth.queries.confirm(token[:])
+func (auth authT) ConfirmEmail(token guuid.UUID) (sessionT, error) {
+ return auth.queries.confirm(token)
}
-func (auth Auth) LoginEmail(
+func (auth authT) LoginEmail(
email string,
password string,
) (sessionT, error) {
if len(password) < scrypt.MinimumPasswordLength {
- return sessionT{}, ErrPasswordTooShort
+ return sessionT{}, ErrTooShort
}
// special check for sql.ErrNoRows to combat enumeration attacks.
- user, err := auth.queries.userByEmail(email)
+ user, err := auth.queries.byEmail(email)
if err != nil && err != sql.ErrNoRows {
return sessionT{}, err
}
- ok, err := scrypt.CheckFrom([]byte(password), user.salt, user.pwhash)
+ input := scrypt.CheckInput{
+ Password: []byte(password),
+ Salt: user.salt,
+ Hash: user.pwhash,
+ }
+ ok, err := scrypt.Check(input)
if err != nil {
return sessionT{}, err
}
if !ok {
- return sessionT{}, ErrEmailPasswordCombo
+ return sessionT{}, ErrBadCombo
}
if user.confirmed_at == nil {
- return sessionT{}, ErrUnconfirmedUser
+ return sessionT{}, ErrUnconfirmed
}
dbSession, err := auth.queries.login(email, password)
@@ -770,29 +1035,38 @@ func (auth Auth) LoginEmail(
return dbSession, nil
}
-func forgotPasswordPayload(email string) ([]byte, error) {
+func forgotPasswordMessage(
+ email string,
+ flowID guuid.UUID,
+) (q.UnsentMessage, error) {
data := make(map[string]interface{})
data["email"] = email
- return json.Marshal(data)
-}
-
-func publishForgotPassword(q liteq.Queue, email string, flowId []byte) error {
- payload, err := forgotPasswordPayload(email)
+ payload, err := json.Marshal(data)
if err != nil {
- return err
+ return q.UnsentMessage{}, err
}
- return q.Publish(FORGOT_PASSWORD_REQUEST, payload, flowId)
+ return q.UnsentMessage{
+ Topic: FORGOT_PASSWORD_REQUEST,
+ FlowID: flowID,
+ Payload: payload,
+ }, nil
}
-func (auth Auth) ForgotPassword(email string) error {
+func (auth authT) ForgotPassword(email string) error {
// special check for sql.ErrNoRows to combat enumeration attacks.
- user, err := auth.queries.userByEmail(email)
+ user, err := auth.queries.byEmail(email)
if err != nil && err != sql.ErrNoRows {
return err
}
- return publishForgotPassword(auth.q, user.email, guuid.NewBytes())
+ unsent, err := forgotPasswordMessage(user.email, guuid.New())
+ if err != nil {
+ return err
+ }
+
+ _, err = auth.queue.Publish(unsent)
+ return err
}
func checkSession(session sessionT, now time.Time) error {
@@ -807,8 +1081,11 @@ func checkSession(session sessionT, now time.Time) error {
return nil
}
-func validateSession(lookupFn func([]byte) (sessionT, error), session sessionT) error {
- dbSession, err := lookupFn(session.uuid[:])
+func validateSession(
+ lookupFn func(guuid.UUID) (sessionT, error),
+ session sessionT,
+) error {
+ dbSession, err := lookupFn(session.uuid)
if err != nil {
return err
}
@@ -816,184 +1093,139 @@ func validateSession(lookupFn func([]byte) (sessionT, error), session sessionT)
return checkSession(dbSession, time.Now())
}
-func (auth Auth) Refresh(session sessionT) (sessionT, error) {
- err := validateSession(auth.queries.sessionByUUID, session)
+func (auth authT) Refresh(session sessionT) (sessionT, error) {
+ err := validateSession(auth.queries.byUUID, session)
if err != nil {
return sessionT{}, err
}
- return auth.queries.refresh(session.uuid[:])
+ return auth.queries.refresh(session.uuid)
}
-func (auth Auth) ResetPassword(
- token []byte,
+func (auth authT) ResetPassword(
+ token guuid.UUID,
password string,
confirmPassword string,
) (sessionT, error) {
if password != confirmPassword {
- return sessionT{}, ErrPasswordMismatch
+ return sessionT{}, ErrPassMismatch
}
if len(password) < scrypt.MinimumPasswordLength {
- return sessionT{}, ErrPasswordTooShort
+ return sessionT{}, ErrTooShort
}
- user, err := auth.queries.userByToken(token)
+ user, err := auth.queries.byToken(token)
if err != nil {
return sessionT{}, err
}
- pwhash, err := scrypt.HashFrom([]byte(password), user.salt)
+ input := scrypt.HashInput{
+ Password: []byte(password),
+ Salt: user.salt,
+ }
+ pwhash, err := scrypt.Hash(input)
if err != nil {
return sessionT{}, err
}
if user.confirmed_at != nil {
- return auth.queries.resetPassword(user.id, pwhash, token)
+ return auth.queries.reset(user.id, pwhash, token)
} else {
- return auth.queries.resetAndConfirm(user.id, pwhash, token)
+ // return auth.queries.confirm(user.id, pwhash, token)
+ return auth.queries.confirm(token)
}
}
-func (auth Auth) ChangePassword(
+func (auth authT) ChangePassword(
user userT,
currentPassword string,
newPassword string,
confirmNewPassword string,
) (sessionT, error) {
if newPassword != confirmNewPassword {
- return sessionT{}, ErrPasswordMismatch
+ return sessionT{}, ErrPassMismatch
}
if len(newPassword) < scrypt.MinimumPasswordLength {
- return sessionT{}, ErrPasswordTooShort
+ return sessionT{}, ErrTooShort
}
- pwhash, err := scrypt.HashFrom([]byte(newPassword), user.salt)
+ // FIXME
+ input := scrypt.HashInput{
+ Password: []byte(newPassword),
+ Salt: user.salt,
+ }
+ pwhash, err := scrypt.Hash(input)
if err != nil {
return sessionT{}, err
}
if user.confirmed_at == nil {
- return sessionT{}, ErrUnconfirmedUser
+ return sessionT{}, ErrUnconfirmed
}
- return auth.queries.changePassword(user.id, pwhash)
+ return auth.queries.change(user.id, pwhash)
}
func runLogout(
- lookupFn func([]byte) (sessionT, error),
+ lookupFn func(guuid.UUID) (sessionT, error),
session sessionT,
- queryFn func([]byte) error,
+ queryFn func(guuid.UUID) error,
) error {
err := validateSession(lookupFn, session)
if err != nil {
return err
}
- return queryFn(session.uuid[:])
+ return queryFn(session.uuid)
}
-func (auth Auth) Logout(session sessionT) error {
- return runLogout(auth.queries.sessionByUUID, session, auth.queries.logout)
+func (auth authT) Logout(session sessionT) error {
+ return runLogout(
+ auth.queries.byUUID,
+ session,
+ auth.queries.logout,
+ )
}
-func (auth Auth) LogoutOthers(session sessionT) error {
- return runLogout(auth.queries.sessionByUUID, session, auth.queries.logoutOthers)
+func (auth authT) LogoutOthers(session sessionT) error {
+ return runLogout(
+ auth.queries.byUUID,
+ session,
+ auth.queries.outOthers,
+ )
}
-func (auth Auth) LogoutAll(session sessionT) error {
- return runLogout(auth.queries.sessionByUUID, session, auth.queries.logoutAll)
+func (auth authT) LogoutAll(session sessionT) error {
+ return runLogout(
+ auth.queries.byUUID,
+ session,
+ auth.queries.outAll,
+ )
}
-func (auth Auth) Close() error {
+func (auth authT) Close() error {
return auth.queries.close()
}
-func newUserHandler(auth Auth) func([]byte) error {
- return func(payload []byte) error {
- return nil
- }
-}
-
-func sendConfirmationRequestHandler(auth Auth) func([]byte) error {
- return func(payload []byte) error {
- return nil
- }
-}
-
-func forgotPasswordRequestHandler(auth Auth) func([]byte) error {
- return func(payload []byte) error {
- return nil
- }
-}
-
-var consumers = []consumerT{
- consumerT{
- topic: NEW_USER,
- handlerFn: newUserHandler,
- },
- consumerT{
- topic: SEND_CONFIRMATION_REQUEST,
- handlerFn: sendConfirmationRequestHandler,
- },
- consumerT{
- topic: FORGOT_PASSWORD_REQUEST,
- handlerFn: forgotPasswordRequestHandler,
- },
-}
-func registerConsumers(auth Auth, consumers []consumerT) {
- for _, consumer := range consumers {
- auth.q.Subscribe(
- consumer.topic,
- defaultPrefix + "-" + consumer.topic,
- consumer.handlerFn(auth),
- )
- }
-}
-
-func NewWithPrefix(db *sql.DB, q liteq.Queue, prefix string) (Auth, error) {
- tables, err := tablesFrom(prefix)
- if err != nil {
- return Auth{}, err
- }
-
- queries, err := initDB(db, tables)
- if err != nil {
- return Auth{}, err
- }
-
- return Auth{
- tables: tables,
- queries: queries,
- q: q,
- }, nil
-}
-
-func New(db *sql.DB, q liteq.Queue) (Auth, error) {
- return NewWithPrefix(db, q, defaultPrefix)
-}
-
func Main() {
g.Init()
- q := new(liteq.Queue)
- sql.Register("sqlite-liteq", liteq.MakeDriver(q))
-
- db, err := sql.Open("sqlite-liteq", "file:gracha.db?mode=memory&cache=shared")
+ db, err := sql.Open("acude", "file:gracha.db?mode=memory&cache=shared")
if err != nil {
panic(err)
}
- // defer db.Close()
+ defer db.Close()
- *q, err = liteq.New(db)
+ queue, err := q.New(db)
if err != nil {
panic(err)
}
- // defer q.Close()
+ defer queue.Close()
- auth, err := New(db, *q)
+ auth, err := New(db, queue)
if err != nil {
fmt.Println(err)
panic(err)
@@ -1006,6 +1238,6 @@ func Main() {
}
return
- fmt.Printf("q: %#v\n", q)
+ fmt.Printf("q: %#v\n", queue)
fmt.Printf("auth: %#v\n", auth)
}
diff --git a/tests/benchmarks/register-login/gracha.go b/tests/benchmarks/register-login/gracha.go
new file mode 100644
index 0000000..f363b6b
--- /dev/null
+++ b/tests/benchmarks/register-login/gracha.go
@@ -0,0 +1,9 @@
+package gracha
+
+import (
+)
+
+
+
+func MainTest() {
+}
diff --git a/tests/benchmarks/register-login/main.go b/tests/benchmarks/register-login/main.go
new file mode 120000
index 0000000..f67563d
--- /dev/null
+++ b/tests/benchmarks/register-login/main.go
@@ -0,0 +1 @@
+../../main.go \ No newline at end of file
diff --git a/tests/functional/register-twice/gracha.go b/tests/functional/register-twice/gracha.go
new file mode 100644
index 0000000..f363b6b
--- /dev/null
+++ b/tests/functional/register-twice/gracha.go
@@ -0,0 +1,9 @@
+package gracha
+
+import (
+)
+
+
+
+func MainTest() {
+}
diff --git a/tests/functional/register-twice/main.go b/tests/functional/register-twice/main.go
new file mode 120000
index 0000000..f67563d
--- /dev/null
+++ b/tests/functional/register-twice/main.go
@@ -0,0 +1 @@
+../../main.go \ No newline at end of file
diff --git a/tests/fuzz/api/gracha.go b/tests/fuzz/api/gracha.go
new file mode 100644
index 0000000..6f2981a
--- /dev/null
+++ b/tests/fuzz/api/gracha.go
@@ -0,0 +1,35 @@
+package gracha
+
+import (
+ "os"
+ "testing"
+ "testing/internal/testdeps"
+)
+
+
+
+func api(f *testing.F) {
+ f.Fuzz(func(t *testing.T, n int) {
+ // FIXME
+ if n > 1 {
+ if n < 2 {
+ t.Errorf("Failed n: %v\n", n)
+ }
+ }
+ })
+}
+
+
+
+func MainTest() {
+ fuzzTargets := []testing.InternalFuzzTarget{
+ { "api", api },
+ }
+
+ deps := testdeps.TestDeps{}
+ tests := []testing.InternalTest {}
+ benchmarks := []testing.InternalBenchmark{}
+ examples := []testing.InternalExample {}
+ m := testing.MainStart(deps, tests, benchmarks, fuzzTargets, examples)
+ os.Exit(m.Run())
+}
diff --git a/tests/fuzz/api/main.go b/tests/fuzz/api/main.go
new file mode 120000
index 0000000..f67563d
--- /dev/null
+++ b/tests/fuzz/api/main.go
@@ -0,0 +1 @@
+../../main.go \ No newline at end of file
diff --git a/tests/gracha.go b/tests/gracha.go
index 18b54d6..0c11fac 100644
--- a/tests/gracha.go
+++ b/tests/gracha.go
@@ -1,15 +1,15 @@
package gracha
import (
- "database/sql"
+ // "database/sql"
- "liteq"
+ // "q"
g "gobang"
)
type testAuth struct{
- auth Auth
+ auth authT
// registerEmail func(credentials)
close func() error
}
@@ -18,55 +18,28 @@ func test_defaultPrefix() {
g.TestStart("defaultPrefix")
g.Testing("the defaultPrefix is valid", func() {
- g.TAssertEqual(g.ValidSQLTablePrefix(defaultPrefix), true)
- })
-}
-
-func test_tablesFrom() {
- g.TestStart("tablesFrom()")
-
- g.Testing("prefix needs to be valid", func() {
- _, err := tablesFrom("invalid-prefix")
- g.TAssertEqual(err, g.ErrBadSQLTablePrefix)
- })
-
- g.Testing("the struct adds suffixes", func() {
- t, err := tablesFrom(defaultPrefix)
- g.TAssertEqual(err, nil)
- g.TAssertEqual(t, tablesT{
- users: "gracha-users",
- userChanges: "gracha-user-changes",
- tokens: "gracha-tokens",
- roles: "gracha-roles",
- roleChanges: "gracha-role-changes",
- sessions: "gracha-sessions",
- attempts: "gracha-attempts",
- audit: "gracha-audit",
- })
+ g.TErrorIf(g.ValidateSQLTablePrefix(defaultPrefix))
})
}
+/*
func mkauth() testAuth {
- q := new(liteq.Queue)
- sql.Register("sqlite-liteq", liteq.MakeDriver(q))
-
- db, err := sql.Open("sqlite-liteq", "file:db?mode=memory&cache=shared")
+ db, err := sql.Open("acude", "file:db?mode=memory&cache=shared")
g.TAssertEqual(err, nil)
- *q, err = liteq.New(db)
+ queue, err := q.New(db)
g.TAssertEqual(err, nil)
- auth, err := New(db, *q)
+ auth, err := New(db, queue)
g.TAssertEqual(err, nil)
- fns := [](func() error){
- db.Close,
- q.Close,
- }
return testAuth{
auth: auth,
close: func() error {
- return g.SomeFnError(fns)
+ return g.SomeFnError(
+ db.Close,
+ queue.Close,
+ )
},
}
}
@@ -97,11 +70,12 @@ func test_Register() {
g.Testing("we can't register duplicate emails", func() {
})
}
+*/
func MainTest() {
g.Init()
test_defaultPrefix()
- test_tablesFrom()
- test_Register()
+ // test_tablesFrom()
+ // test_Register()
}
diff --git a/tests/queries.sql b/tests/queries.sql
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/tests/queries.sql