summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--Makefile1
-rw-r--r--src/gracha.go720
-rw-r--r--tests/gracha.go958
-rw-r--r--tests/queries.sql205
4 files changed, 1623 insertions, 261 deletions
diff --git a/Makefile b/Makefile
index 2ea93b9..9d00d8d 100644
--- a/Makefile
+++ b/Makefile
@@ -91,6 +91,7 @@ $(NAME).bin: src/main.bin
.PRECIOUS: tests/queries.sql
tests/queries.sql: tests/main.bin ALWAYS
+ env TESTING_DUMP_SQL_QUERIES=1 $(EXEC)tests/main.bin | ifnew $@
env TESTING_DUMP_SQL_QUERIES=1 $(EXEC)tests/main.bin | diff -U10 $@ -
tests.bin-check = \
diff --git a/src/gracha.go b/src/gracha.go
index e300f66..c240dd3 100644
--- a/src/gracha.go
+++ b/src/gracha.go
@@ -8,6 +8,7 @@ import (
"errors"
"fmt"
"runtime"
+ "slices"
"sync"
"time"
@@ -27,7 +28,6 @@ const (
SEND_CONFIRMATION_REQUEST = "send-confirmation-request"
FORGOT_PASSWORD_REQUEST = "forgot-password-request"
-
day = 24 * time.Hour
)
@@ -48,17 +48,18 @@ var (
type queryT struct{
- write string
- read string
+ write string
+ read string
+ session string
}
type queriesT struct{
+ register func(guuid.UUID, string, []byte, []byte) (userT, error)
+ sendToken func(guuid.UUID, string) error
+ confirm func(string, guuid.UUID) (sessionT, error)
byEmail func(string) (userT, error)
- byToken func(guuid.UUID) (userT, error)
- register func(string, []byte, []byte) (userT, error)
- confirm func(guuid.UUID) (sessionT, error)
- login func(string, string) (sessionT, error)
- refresh func(guuid.UUID) (sessionT, error)
+ login func(guuid.UUID, guuid.UUID) (sessionT, error)
+ refresh func(guuid.UUID, guuid.UUID) (sessionT, error)
reset func(int64, []byte, guuid.UUID) (sessionT, error)
change func(int64, []byte) (sessionT, error)
byUUID func(guuid.UUID) (sessionT, error)
@@ -68,28 +69,23 @@ type queriesT struct{
close func() error
}
-type confirmationT struct{
-}
-
type userT struct{
id int64
timestamp time.Time
uuid guuid.UUID
email string
- username *string
salt []byte
pwhash []byte
- confirmed_at *time.Time
metadata map[string]interface{}
+ confirmed bool
}
type sessionT struct{
id int64
- timestr string
timestamp time.Time
uuid guuid.UUID
- user_id int64
- type_ string
+ userID guuid.UUID
+ // type_ string
revoked_at *time.Time
metadata map[string]interface{}
}
@@ -102,15 +98,14 @@ type consumerT struct{
type authT struct{
queries queriesT
queue q.IQueue
- hasher func(scrypt.HashInput) resultT[[]byte]
- checker func(scrypt.CheckInput) resultT[bool]
+ hasher func(scrypt.HashInput) ([]byte, error)
close func()
}
type IAuth interface{
Register(string, string, string) (userT, error)
ResendConfirmation(string) error
- ConfirmEmail(guuid.UUID) (sessionT, error)
+ ConfirmEmail(string) (sessionT, error)
LoginEmail(string, string) (sessionT, error)
ForgotPassword(string) error
Refresh(sessionT) (sessionT, error)
@@ -165,13 +160,25 @@ func createTablesSQL(prefix string) queryT {
timestamp TEXT NOT NULL DEFAULT (%s),
uuid BLOB NOT NULL UNIQUE,
email TEXT NOT NULL UNIQUE,
- username TEXT UNIQUE,
salt BLOB NOT NULL UNIQUE,
pwhash BLOB NOT NULL,
- confirmed_at TEXT,
- confirmer_id INTEGER REFERENCES "%s"(id),
metadata TEXT
);
+ CREATE TABLE IF NOT EXISTS "%s_confirmation_attempts" (
+ id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
+ timestamp TEXT NOT NULL DEFAULT (%s),
+ -- uuid BLOB NOT NULL UNIQUE,
+ user_id INTEGER NOT NULL REFERENCES "%s_users"(id),
+ token TEXT NOT NULL UNIQUE
+ );
+ CREATE TABLE IF NOT EXISTS "%s_user_confirmations" (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ timestamp TEXT NOT NULL DEFAULT (%s),
+ user_id INTEGER NOT NULL
+ REFERENCES "%s_users"(id) UNIQUE,
+ attempt_id INTEGER NOT NULL
+ REFERENCES "%s_confirmation_attempts"(id) UNIQUE
+ );
CREATE TABLE IF NOT EXISTS "%s_user_changes" (
id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
timestamp TEXT NOT NULL DEFAULT (%s),
@@ -180,12 +187,12 @@ func createTablesSQL(prefix string) queryT {
value TEXT NOT NULL,
op BOOLEAN NOT NULL
);
- CREATE TABLE IF NOT EXISTS "%s_tokens" (
- id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
- timestamp TEXT NOT NULL DEFAULT (%s),
- uuid BLOB NOT NULL UNIQUE,
- type TEXT NOT NULL
- );
+ -- CREATE TABLE IF NOT EXISTS "%s_tokens" (
+ -- id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
+ -- timestamp TEXT NOT NULL DEFAULT (%s),
+ -- uuid BLOB NOT NULL UNIQUE,
+ -- type TEXT NOT NULL
+ -- );
CREATE TABLE IF NOT EXISTS "%s_roles" (
id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
user_id INTEGER NOT NULL REFERENCES "%s_users"(id),
@@ -205,9 +212,10 @@ func createTablesSQL(prefix string) queryT {
timestamp TEXT NOT NULL DEFAULT (%s),
uuid BLOB NOT NULL UNIQUE,
user_id INTEGER NOT NULL REFERENCES "%s_users"(id),
- type TEXT NOT NULL,
- revoked_at TEXT,
- revoker_id INTEGER REFERENCES "%s_users"(id),
+ -- type TEXT NOT NULL,
+ -- revoked_at TEXT,
+ -- revoker_id INTEGER REFERENCES "%s_users"(id),
+ -- FIXME: add provenance: login, refresh, confirmation, etc.
metadata TEXT
);
CREATE TABLE IF NOT EXISTS "%s_attempts" (
@@ -233,6 +241,12 @@ func createTablesSQL(prefix string) queryT {
prefix,
g.SQLiteNow,
prefix,
+ g.SQLiteNow,
+ prefix,
+ prefix,
+ g.SQLiteNow,
+ prefix,
+ prefix,
prefix,
g.SQLiteNow,
prefix,
@@ -266,50 +280,68 @@ func createTables(db *sql.DB, prefix string) error {
})
}
-func byEmailSQL(prefix string) queryT {
+func registerSQL(prefix string) queryT {
+ const tmpl_write = `
+ INSERT INTO "%s_users" (uuid, email, salt, pwhash)
+ VALUES (?, ?, ?, ?) RETURNING id, timestamp;
+ `
const tmpl_read = `
- SELECT id, timestamp, uuid, email, username, pwhash, metadata
- FROM "%s_users" WHERE email = ?;
+ SELECT id, timestamp from "%s_users"
+ WHERE uuid = ?;
`
return queryT{
- read: fmt.Sprintf(tmpl_read, prefix),
+ write: fmt.Sprintf(tmpl_write, prefix),
+ read: fmt.Sprintf(tmpl_read, prefix),
}
}
-func byEmailStmt(
+func registerStmt(
db *sql.DB,
prefix string,
-) (func(string) (userT, error), func() error, error) {
- q := byEmailSQL(prefix)
+) (
+ func(guuid.UUID, string, []byte, []byte) (userT, error),
+ func() error,
+ error,
+) {
+ q := registerSQL(prefix)
+
+ writeStmt, err := db.Prepare(q.write)
+ if err != nil {
+ return nil, nil, err
+ }
readStmt, err := db.Prepare(q.read)
if err != nil {
return nil, nil, err
}
- fn := func(email string) (userT, error) {
+ fn := func(
+ userID guuid.UUID,
+ email string,
+ salt []byte,
+ pwhash []byte,
+ ) (userT, error) {
user := userT{
- email: email,
+ uuid: userID,
+ email: email,
+ salt: salt,
+ pwhash: pwhash,
+ confirmed: false,
}
- var (
- timestr string
- uuid_bytes []byte
- )
- err := readStmt.QueryRow(email).Scan(
- &user.id,
- &timestr,
- &uuid_bytes,
- &user.username,
- &user.pwhash,
- &user.metadata,
- )
+ var timestr string
+ user_id_bytes := userID[:]
+ err := writeStmt.QueryRow(
+ user_id_bytes,
+ email,
+ salt,
+ pwhash,
+ ).Scan(&user.id, &timestr)
if err != nil {
return userT{}, err
}
- user.uuid = guuid.UUID(uuid_bytes)
- user.timestamp, err = time.Parse(time.RFC3339Nano,timestr)
+ user.timestamp, err = time.Parse(time.RFC3339Nano, timestr)
if err != nil {
return userT{}, err
}
@@ -317,165 +349,329 @@ func byEmailStmt(
return user, nil
}
- return fn, readStmt.Close, nil
+ closeFn := func() error {
+ return g.SomeFnError(writeStmt.Close, readStmt.Close)
+ }
+
+ return fn, closeFn, nil
}
-func byTokenSQL(prefix string) queryT{
- const tmpl_read = `
- SELECT id, timestamp, uuid, email, username, pwhash, metadata
- FROM "%s" WHERE email = ?;
+func sendTokenSQL(prefix string) queryT {
+ const tmpl_write = `
+ INSERT INTO "%s_confirmation_attempts" (user_id, token)
+ VALUES (
+ (SELECT id FROM "%s_users" WHERE uuid = ?),
+ ?
+ )
`
return queryT{
- read: fmt.Sprintf(tmpl_read, prefix),
+ write: fmt.Sprintf(tmpl_write, prefix, prefix),
}
}
-func byTokenStmt(
+func sendTokenStmt(
db *sql.DB,
prefix string,
-) (func(guuid.UUID) (userT, error), func() error, error) {
- q := byTokenSQL(prefix)
+) (func(guuid.UUID, string) error, func() error, error) {
+ q := sendTokenSQL(prefix)
- readStmt, err := db.Prepare(q.read)
+ writeStmt, err := db.Prepare(q.write)
if err != nil {
return nil, nil, err
}
- fn := func(token guuid.UUID) (userT, error) {
- var user userT
- // FIXME: build user
- err := readStmt.QueryRow(token).Scan(&user.id)
- return user, err
+ fn := func(userID guuid.UUID, token string) error {
+ user_id_bytes := userID[:]
+ _, err := writeStmt.Exec(user_id_bytes, token)
+ return err
}
- return fn, readStmt.Close, nil
+ return fn, writeStmt.Close, nil
}
-func registerSQL(prefix string) queryT {
+func confirmSQL(prefix string) queryT {
const tmpl_write = `
- INSERT INTO "%s" (uuid, email, username, salt, pwhash, metadata)
- VALUES (?, ?, ?, ?, ?, ?);
+ INSERT INTO "%s_user_confirmations" (user_id, attempt_id)
+ VALUES (?, ?);
+ `
+ const tmpl_read = `
+ SELECT
+ "%s_confirmation_attempts".id,
+ "%s_confirmation_attempts".user_id,
+ "%s_users".uuid
+ FROM "%s_confirmation_attempts"
+ JOIN "%s_users" ON
+ "%s_confirmation_attempts".user_id = "%s_users".id
+ WHERE token = ?;
+ `
+ const tmpl_session = `
+ INSERT INTO "%s_sessions" (uuid, user_id)
+ VALUES (?, ?) RETURNING id, timestamp;
`
return queryT{
- write: fmt.Sprintf(tmpl_write, prefix),
+ write: fmt.Sprintf(tmpl_write, prefix),
+ read: fmt.Sprintf(
+ tmpl_read,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ ),
+ session: fmt.Sprintf(tmpl_session, prefix),
}
}
-func registerStmt(
+func confirmStmt(
db *sql.DB,
prefix string,
-) (func(string, []byte, []byte) (userT, error), func() error, error) {
- q := registerSQL(prefix)
+) (func(string, guuid.UUID) (sessionT, error), func() error, error) {
+ q := confirmSQL(prefix)
writeStmt, err := db.Prepare(q.write)
if err != nil {
return nil, nil, err
}
- fn := func(email string, salt []byte, pwhash []byte) (userT, error) {
- /*
- timestamp TEXT NOT NULL DEFAULT (%s),
- uuid BLOB NOT NULL UNIQUE,
- email TEXT NOT NULL UNIQUE,
- username TEXT UNIQUE,
- pwhash TEXT NOT NULL,
- metadata TEXT
- */
+ readStmt, err := db.Prepare(q.read)
+ if err != nil {
+ return nil, nil, g.WrapErrors(writeStmt.Close(), err)
+ }
- var user userT
- // err := stmt.QueryRow(
- ret, err := writeStmt.Exec(
- guuid.New(),
- email,
- "credentials.username",
- salt,
- pwhash,
- "credentials.metadata",
- // ).Scan(&user.email)
- // FIXME: finish
+ sessionStmt, err := db.Prepare(q.session)
+ if err != nil {
+ return nil, nil, g.WrapErrors(
+ writeStmt.Close(),
+ readStmt.Close(),
+ err,
)
- if false {
- fmt.Printf("ret: %#v\n", ret)
- fmt.Printf("user: %#v\n", user)
+ }
+
+ fn := func(token string, sessionID guuid.UUID) (sessionT, error) {
+ session := sessionT{
+ uuid: sessionID,
}
- return user, err
+
+ var (
+ user_id int64
+ attempt_id int64
+ user_id_bytes []byte
+ )
+ err := readStmt.QueryRow(token).Scan(
+ &attempt_id,
+ &user_id,
+ &user_id_bytes,
+ )
+ if err != nil {
+ return sessionT{}, err
+ }
+ session.userID = guuid.UUID(user_id_bytes)
+
+ _, err = writeStmt.Exec(user_id, attempt_id)
+ if err != nil {
+ return sessionT{}, err
+ }
+
+ session_id_bytes := sessionID[:]
+ var timestr string
+ err = sessionStmt.QueryRow(
+ session_id_bytes,
+ user_id,
+ ).Scan(&session.id, &timestr)
+ if err != nil {
+ return sessionT{}, err
+ }
+
+ session.timestamp, err = time.Parse(time.RFC3339Nano, timestr)
+ if err != nil {
+ return sessionT{}, err
+ }
+
+ return session, nil
}
- return fn, writeStmt.Close, nil
+ closeFn := func() error {
+ return g.SomeFnError(
+ writeStmt.Close,
+ readStmt.Close,
+ sessionStmt.Close,
+ )
+ }
+
+ return fn, closeFn, nil
}
-func confirmSQL(prefix string) queryT {
- const tmpl_write = `
- -- INSERT SOMETHING %s
+func byEmailSQL(prefix string) queryT {
+ // FIXME: rewrite as LEFT JOIN?
+ const tmpl_read = `
+ SELECT id, timestamp, uuid, salt, pwhash, metadata, (
+ CASE WHEN EXISTS (
+ SELECT id FROM "%s_user_confirmations"
+ WHERE user_id = (
+ SELECT id FROM "%s_users"
+ WHERE email = ?
+ )
+ ) THEN 1
+ ELSE 0
+ END
+ ) as confirmed
+ FROM "%s_users" WHERE email = ?;
`
return queryT{
- write: fmt.Sprintf(tmpl_write, prefix),
+ read: fmt.Sprintf(tmpl_read, prefix, prefix, prefix),
}
}
-func confirmStmt(
+func byEmailStmt(
db *sql.DB,
prefix string,
-) (func(guuid.UUID) (sessionT, error), func() error, error) {
- q := confirmSQL(prefix)
+) (func(string) (userT, error), func() error, error) {
+ q := byEmailSQL(prefix)
- writeStmt, err := db.Prepare(q.write)
+ readStmt, err := db.Prepare(q.read)
if err != nil {
return nil, nil, err
}
- fn := func(token guuid.UUID) (sessionT, error) {
- var session sessionT
- err := writeStmt.QueryRow(token).Scan(&session)
- return session, err
+ fn := func(email string) (userT, error) {
+ user := userT{
+ email: email,
+ }
+
+ var (
+ timestr string
+ user_id_bytes []byte
+ metadatastr sql.NullString
+ )
+ err := readStmt.QueryRow(email, email).Scan(
+ &user.id,
+ &timestr,
+ &user_id_bytes,
+ &user.salt,
+ &user.pwhash,
+ &metadatastr,
+ &user.confirmed,
+ )
+ if err != nil {
+ return userT{}, err
+ }
+ user.uuid = guuid.UUID(user_id_bytes)
+
+ user.timestamp, err = time.Parse(time.RFC3339Nano,timestr)
+ if err != nil {
+ return userT{}, err
+ }
+
+ if metadatastr.Valid {
+ err := json.Unmarshal(
+ []byte(metadatastr.String),
+ &user.metadata,
+ )
+ if err != nil {
+ g.Warning(
+ "failed to parse metadata field",
+ "sqlite-json-unmarshal-error",
+ "userID", user.uuid.String(),
+ "error", err,
+ )
+ }
+ }
+
+ return user, nil
}
- return fn, writeStmt.Close, nil
+ return fn, readStmt.Close, nil
}
func loginSQL(prefix string) queryT {
- const tmpl_write = `
- -- INSERT INTO "%s" (t3, t4) VALUES (?, ?);
+ const tmpl_session = `
+ INSERT INTO "%s_sessions" (uuid, user_id)
+ VALUES (
+ ?,
+ (SELECT id FROM "%s_users" WHERE uuid = ?)
+ ) RETURNING id, timestamp;
`
return queryT{
- write: fmt.Sprintf(tmpl_write, prefix),
+ session: fmt.Sprintf(tmpl_session, prefix, prefix),
}
}
func loginStmt(
db *sql.DB,
prefix string,
-) (func(string, string) (sessionT, error), func() error, error) {
+) (func(guuid.UUID, guuid.UUID) (sessionT, error), func() error, error) {
q := loginSQL(prefix)
- writeStmt, err := db.Prepare(q.write)
+ sessionStmt, err := db.Prepare(q.session)
if err != nil {
return nil, nil, err
}
- fn := func(email string, pwhash string) (sessionT, error) {
- var session sessionT
- err := writeStmt.QueryRow(email, pwhash).Scan(session)
- // FIXME: finish
+ fn := func(userID guuid.UUID, sessionID guuid.UUID) (sessionT, error) {
+ session := sessionT{
+ uuid: sessionID,
+ userID: userID,
+ }
+
+ user_id_bytes := userID[:]
+ session_id_bytes := sessionID[:]
+ var timestr string
+ err := sessionStmt.QueryRow(
+ session_id_bytes,
+ user_id_bytes,
+ ).Scan(
+ &session.id,
+ &timestr,
+ )
+ if err != nil {
+ return sessionT{}, err
+ }
+
+ session.timestamp, err = time.Parse(time.RFC3339Nano, timestr)
+ if err != nil {
+ return sessionT{}, err
+ }
+
return session, err
}
- return fn, writeStmt.Close, nil
+ return fn, sessionStmt.Close, nil
}
func refreshSQL(prefix string) queryT {
const tmpl_write = `
- -- INSERT SOMETHING %s
+ INSERT INTO "%s_sessions" (uuid, user_id)
+ VALUES (
+ ?,
+ (SELECT user_id FROM "%s_sessions" WHERE uuid = ?)
+ ) RETURNING id, timestamp, (
+ SELECT "%s_users".uuid FROM "%s_users"
+ JOIN "%s_sessions" ON
+ "%s_users".id = "%s_sessions".user_id
+ WHERE "%s_sessions".uuid = ?
+ ) AS userID;
`
return queryT{
- write: fmt.Sprintf(tmpl_write, prefix),
+ write: fmt.Sprintf(
+ tmpl_write,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ prefix,
+ ),
}
}
func refreshStmt(
db *sql.DB,
prefix string,
-) (func(guuid.UUID) (sessionT, error), func() error, error) {
+) (func(guuid.UUID, guuid.UUID) (sessionT, error), func() error, error) {
q := refreshSQL(prefix)
writeStmt, err := db.Prepare(q.write)
@@ -483,9 +679,35 @@ func refreshStmt(
return nil, nil, err
}
- fn := func(uuid guuid.UUID) (sessionT, error) {
- var session sessionT
- err := writeStmt.QueryRow(uuid).Scan(&session)
+ fn := func(
+ sessionID guuid.UUID,
+ newSessionID guuid.UUID,
+ ) (sessionT, error) {
+ session := sessionT{
+ uuid: newSessionID,
+ }
+
+ session_id_bytes := sessionID[:]
+ new_session_id_bytes := newSessionID[:]
+ var (
+ timestr string
+ user_id_bytes []byte
+ )
+ err := writeStmt.QueryRow(
+ new_session_id_bytes,
+ session_id_bytes,
+ session_id_bytes,
+ ).Scan(&session.id, &timestr, &user_id_bytes)
+ if err != nil {
+ return sessionT{}, err
+ }
+ session.userID = guuid.UUID(user_id_bytes)
+
+ session.timestamp, err = time.Parse(time.RFC3339Nano, timestr)
+ if err != nil {
+ return sessionT{}, err
+ }
+
return session, err
}
@@ -570,9 +792,9 @@ func byUUIDStmt(
return nil, nil, err
}
- fn := func(uuid guuid.UUID) (sessionT, error) {
+ fn := func(sessionID guuid.UUID) (sessionT, error) {
var session sessionT
- err := readStmt.QueryRow(uuid).Scan(&session)
+ err := readStmt.QueryRow(sessionID).Scan(&session)
return session, err
}
@@ -599,8 +821,8 @@ func logoutStmt(
return nil, nil, err
}
- fn := func(uuid guuid.UUID) error {
- _, err := writeStmt.Exec(uuid)
+ fn := func(sessionID guuid.UUID) error {
+ _, err := writeStmt.Exec(sessionID)
return err
}
@@ -627,8 +849,8 @@ func outOthersStmt(
return nil, nil, err
}
- fn := func(uuid guuid.UUID) error {
- _, err := writeStmt.Exec(uuid)
+ fn := func(sessionID guuid.UUID) error {
+ _, err := writeStmt.Exec(sessionID)
return err
}
@@ -655,8 +877,8 @@ func outAllStmt(
return nil, nil, err
}
- fn := func(uuid guuid.UUID) error {
- _, err := writeStmt.Exec(uuid)
+ fn := func(sessionID guuid.UUID) error {
+ _, err := writeStmt.Exec(sessionID)
return err
}
@@ -668,10 +890,10 @@ func initDB(
prefix string,
) (queriesT, error) {
createTablesErr := createTables(db, prefix)
- byEmail, byEmailClose, byEmailErr := byEmailStmt(db, prefix)
- byToken, byTokenClose, byTokenErr := byTokenStmt(db, prefix)
register, registerClose, registerErr := registerStmt(db, prefix)
+ sendToken, sendTokenClose, sendTokenErr := sendTokenStmt(db, prefix)
confirm, confirmClose, confirmErr := confirmStmt(db, prefix)
+ byEmail, byEmailClose, byEmailErr := byEmailStmt(db, prefix)
login, loginClose, loginErr := loginStmt(db, prefix)
refresh, refreshClose, refreshErr := refreshStmt(db, prefix)
reset, resetClose, resetErr := resetStmt(db, prefix)
@@ -683,10 +905,10 @@ func initDB(
err := g.SomeError(
createTablesErr,
- byEmailErr,
- byTokenErr,
registerErr,
+ sendTokenErr,
confirmErr,
+ byEmailErr,
loginErr,
refreshErr,
resetErr,
@@ -702,10 +924,10 @@ func initDB(
close := func() error {
return g.SomeFnError(
- byEmailClose,
- byTokenClose,
registerClose,
+ sendTokenClose,
confirmClose,
+ byEmailClose,
loginClose,
refreshClose,
resetClose,
@@ -717,21 +939,78 @@ func initDB(
)
}
- // FIXME: lock
+ var connMutex sync.RWMutex
return queriesT{
- byEmail: byEmail,
- byToken: byToken,
- register: register,
- confirm: confirm,
- login: login,
- refresh: refresh,
- reset: reset,
- change: change,
- byUUID: byUUID,
- logout: logout,
- outOthers: outOthers,
- outAll: outAll,
- close: close,
+ register: func(
+ a guuid.UUID,
+ b string,
+ c []byte,
+ d []byte,
+ ) (userT, error) {
+ connMutex.RLock()
+ defer connMutex.RUnlock()
+ return register(a, b, c, d)
+ },
+ sendToken: func(a guuid.UUID, b string) error {
+ connMutex.RLock()
+ defer connMutex.RUnlock()
+ return sendToken(a, b)
+ },
+ confirm: func(a string, b guuid.UUID) (sessionT, error) {
+ connMutex.RLock()
+ defer connMutex.RUnlock()
+ return confirm(a, b)
+ },
+ byEmail: func(a string) (userT, error) {
+ connMutex.RLock()
+ defer connMutex.RUnlock()
+ return byEmail(a)
+ },
+ login: func(a guuid.UUID, b guuid.UUID) (sessionT, error) {
+ connMutex.RLock()
+ defer connMutex.RUnlock()
+ return login(a, b)
+ },
+ refresh: func(a guuid.UUID, b guuid.UUID) (sessionT, error) {
+ connMutex.RLock()
+ defer connMutex.RUnlock()
+ return refresh(a, b)
+ },
+ reset: func(a int64, b []byte, c guuid.UUID) (sessionT, error) {
+ connMutex.RLock()
+ defer connMutex.RUnlock()
+ return reset(a, b, c)
+ },
+ change: func(a int64, b []byte) (sessionT, error) {
+ connMutex.RLock()
+ defer connMutex.RUnlock()
+ return change(a, b)
+ },
+ byUUID: func(a guuid.UUID) (sessionT, error) {
+ connMutex.RLock()
+ defer connMutex.RUnlock()
+ return byUUID(a)
+ },
+ logout: func(a guuid.UUID) error {
+ connMutex.RLock()
+ defer connMutex.RUnlock()
+ return logout(a)
+ },
+ outOthers: func(a guuid.UUID) error {
+ connMutex.RLock()
+ defer connMutex.RUnlock()
+ return outOthers(a)
+ },
+ outAll: func(a guuid.UUID) error {
+ connMutex.RLock()
+ defer connMutex.RUnlock()
+ return outAll(a)
+ },
+ close: func() error {
+ connMutex.Lock()
+ defer connMutex.Unlock()
+ return close()
+ },
}, nil
}
@@ -767,13 +1046,24 @@ var consumers = []consumerT{
handlerFn: forgotPasswordRequestHandler,
},
}
-func registerConsumers(auth authT, consumers []consumerT) {
+func registerConsumers(auth authT, consumers []consumerT, prefix string) error {
for _, consumer := range consumers {
- auth.queue.Subscribe(
+ err := auth.queue.Subscribe(
consumer.topic,
- defaultPrefix + "-" + consumer.topic,
+ prefix + "-" + consumer.topic,
consumer.handlerFn(auth),
)
+ if err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+func unregisterConsumers(queue q.IQueue, consumers []consumerT, prefix string) {
+ for _, consumer := range consumers {
+ queue.Unsubscribe(consumer.topic, prefix + "-" + consumer.topic)
}
}
@@ -861,28 +1151,41 @@ func asResult[A any, B any](fn func(A) (B, error)) func(A) resultT[B] {
}
}
+func unwrapResult[A any, B any](fn func(A) resultT[B]) func(A) (B, error) {
+ return func(input A) (B, error) {
+ result := fn(input)
+ return result.value, result.err
+ }
+}
+
func NewWithPrefix(db *sql.DB, queue q.IQueue, prefix string) (IAuth, error) {
queries, err := initDB(db, prefix)
if err != nil {
return authT{}, err
}
- numCPU := (runtime.NumCPU() / 2) + 1
- hasher, closeHasher := makePoolRunner(numCPU, asResult(scrypt.Hash))
- checker, closeChecker := makePoolRunner(numCPU, asResult(scrypt.Check))
+ numCPU := runtime.NumCPU()
+ hasher, closeHasher := makePoolRunner(numCPU, asResult(scrypt.Hash))
- close := func() {
+ closeFn := func() {
+ unregisterConsumers(queue, consumers, prefix)
closeHasher()
- closeChecker()
}
- return authT{
+ auth := authT{
queries: queries,
queue: queue,
- hasher: hasher,
- checker: checker,
- close: close,
- }, nil
+ hasher: unwrapResult(hasher),
+ close: closeFn,
+ }
+
+ err = registerConsumers(auth, consumers, prefix)
+ if err != nil {
+ closeFn()
+ return authT{}, err
+ }
+
+ return auth, nil
}
func New(db *sql.DB, queue q.IQueue) (IAuth, error) {
@@ -926,7 +1229,10 @@ func (auth authT) Register(
Password: []byte(password),
Salt: salt,
}
- result := auth.hasher(input)
+ hash, err := auth.hasher(input)
+ if err != nil {
+ return userT{}, err
+ }
/*
We also try to register anyway, to prevent disk IO timing attacks.
@@ -937,7 +1243,7 @@ func (auth authT) Register(
waiter := auth.queue.WaitFor(NEW_USER, flowID, "register")
defer waiter.Close()
- payload, err := newUserPayload(email, salt, result.value)
+ payload, err := newUserPayload(email, salt, hash)
if err != nil {
return userT{}, err
}
@@ -987,13 +1293,12 @@ func (auth authT) ResendConfirmation(email string) error {
return err
}
-
_, err = auth.queue.Publish(unsent)
return err
}
-func (auth authT) ConfirmEmail(token guuid.UUID) (sessionT, error) {
- return auth.queries.confirm(token)
+func (auth authT) ConfirmEmail(token string) (sessionT, error) {
+ return auth.queries.confirm("token FIXME", guuid.New())
}
func (auth authT) LoginEmail(
@@ -1005,34 +1310,36 @@ func (auth authT) LoginEmail(
}
// special check for sql.ErrNoRows to combat enumeration attacks.
+ // FIXME: how so?
user, err := auth.queries.byEmail(email)
if err != nil && err != sql.ErrNoRows {
return sessionT{}, err
}
- input := scrypt.CheckInput{
+ input := scrypt.HashInput{
Password: []byte(password),
Salt: user.salt,
- Hash: user.pwhash,
}
- ok, err := scrypt.Check(input)
+ hash, err := auth.hasher(input)
if err != nil {
return sessionT{}, err
}
+
+ ok := slices.Equal(hash, user.pwhash)
if !ok {
return sessionT{}, ErrBadCombo
}
- if user.confirmed_at == nil {
+ if !user.confirmed {
return sessionT{}, ErrUnconfirmed
}
- dbSession, err := auth.queries.login(email, password)
+ session, err := auth.queries.login(user.uuid, guuid.New())
if err != nil {
return sessionT{}, err
}
- return dbSession, nil
+ return session, nil
}
func forgotPasswordMessage(
@@ -1099,7 +1406,7 @@ func (auth authT) Refresh(session sessionT) (sessionT, error) {
return sessionT{}, err
}
- return auth.queries.refresh(session.uuid)
+ return auth.queries.refresh(session.uuid, guuid.New())
}
func (auth authT) ResetPassword(
@@ -1115,25 +1422,22 @@ func (auth authT) ResetPassword(
return sessionT{}, ErrTooShort
}
- user, err := auth.queries.byToken(token)
- if err != nil {
- return sessionT{}, err
- }
+ user := userT{}
input := scrypt.HashInput{
Password: []byte(password),
Salt: user.salt,
}
- pwhash, err := scrypt.Hash(input)
+ pwhash, err := auth.hasher(input)
if err != nil {
return sessionT{}, err
}
- if user.confirmed_at != nil {
+ if user.confirmed {
return auth.queries.reset(user.id, pwhash, token)
} else {
// return auth.queries.confirm(user.id, pwhash, token)
- return auth.queries.confirm(token)
+ return auth.queries.confirm("token FIXME", guuid.New())
}
}
@@ -1151,17 +1455,16 @@ func (auth authT) ChangePassword(
return sessionT{}, ErrTooShort
}
- // FIXME
input := scrypt.HashInput{
Password: []byte(newPassword),
Salt: user.salt,
}
- pwhash, err := scrypt.Hash(input)
+ pwhash, err := auth.hasher(input)
if err != nil {
return sessionT{}, err
}
- if user.confirmed_at == nil {
+ if !user.confirmed {
return sessionT{}, ErrUnconfirmed
}
@@ -1212,32 +1515,5 @@ func (auth authT) Close() error {
func Main() {
- g.Init()
- db, err := sql.Open("acude", "file:gracha.db?mode=memory&cache=shared")
- if err != nil {
- panic(err)
- }
- defer db.Close()
-
- queue, err := q.New(db)
- if err != nil {
- panic(err)
- }
- defer queue.Close()
-
- auth, err := New(db, queue)
- if err != nil {
- fmt.Println(err)
- panic(err)
- }
-
- user, err := auth.Register("contact@example.com", "password", "password")
- if false {
- fmt.Printf("user: %#v\n", user)
- fmt.Printf("err: %#v\n", err)
- }
-
- return
- fmt.Printf("q: %#v\n", queue)
- fmt.Printf("auth: %#v\n", auth)
+ // FIXME
}
diff --git a/tests/gracha.go b/tests/gracha.go
index 0c11fac..3c834ef 100644
--- a/tests/gracha.go
+++ b/tests/gracha.go
@@ -1,19 +1,47 @@
package gracha
import (
- // "database/sql"
+ "database/sql"
+ "fmt"
+ "os"
+ "time"
// "q"
+ "golite"
+ "guuid"
+ "scrypt"
g "gobang"
)
-type testAuth struct{
- auth authT
- // registerEmail func(credentials)
- close func() error
+
+type userDataT struct{
+ userID guuid.UUID
+ email string
+ salt []byte
+ pwhash []byte
+ token string
+}
+
+
+
+func mksalt() []byte {
+ salt, err := scrypt.Salt()
+ g.TErrorIf(err)
+ return salt
}
+func newUser() userDataT {
+ return userDataT{
+ userID: guuid.New(),
+ email: string(mksalt()),
+ salt: mksalt(),
+ pwhash: mksalt(),
+ token: string(mksalt()),
+ }
+}
+
+
func test_defaultPrefix() {
g.TestStart("defaultPrefix")
@@ -22,60 +50,912 @@ func test_defaultPrefix() {
})
}
-/*
-func mkauth() testAuth {
- db, err := sql.Open("acude", "file:db?mode=memory&cache=shared")
- g.TAssertEqual(err, nil)
+func test_tryRollback() {
+ // FIXME
+}
+
+func test_inTx() {
+ // FIXME
+}
+
+func test_createTables() {
+ g.TestStart("createTables()")
+
+ db, err := sql.Open(golite.DriverName, ":memory:")
+ g.TErrorIf(err)
+ defer db.Close()
+
+
+ g.Testing("tables exist afterwards", func() {
+ const tmpl_read = `
+ SELECT id FROM "%s_users" LIMIT 1;
+ `
+ qRead := fmt.Sprintf(tmpl_read, defaultPrefix)
+
+ _, err := db.Exec(qRead)
+ g.TErrorNil(err)
+
+ err = createTables(db, defaultPrefix)
+ g.TErrorIf(err)
+
+ _, err = db.Exec(qRead)
+ g.TErrorIf(err)
+ })
+
+ g.Testing("we can do it multiple times", func() {
+ g.TErrorIf(g.SomeError(
+ createTables(db, defaultPrefix),
+ createTables(db, defaultPrefix),
+ createTables(db, defaultPrefix),
+ ))
+ })
+}
+
+func test_registerStmt() {
+ g.TestStart("registerStmt()")
+
+ const (
+ prefix = defaultPrefix
+ )
+
+ db, err := sql.Open(golite.DriverName, ":memory:")
+ g.TErrorIf(err)
+ g.TErrorIf(createTables(db, prefix))
+
+ register, registerClose, registerErr := registerStmt(db, prefix)
+ g.TErrorIf(registerErr)
+ defer g.SomeFnError(
+ registerClose,
+ db.Close,
+ )
+
+
+ g.Testing("we can register a user", func() {
+ u := newUser()
+ user, err := register(u.userID, u.email, u.salt, u.pwhash)
+ g.TErrorIf(err)
+
+ g.TAssertEqual(user.timestamp == time.Time{}, false)
+ g.TAssertEqual(user.id, int64(1))
+ g.TAssertEqual(user.uuid, u.userID)
+ g.TAssertEqual(user.email, u.email)
+ g.TAssertEqual(user.salt, u.salt)
+ g.TAssertEqual(user.pwhash, u.pwhash)
+ })
+
+ g.Testing("users can't have the same uuid", func() {
+ u1 := newUser()
+ u2 := newUser()
+ userID := guuid.New()
+
+ _, err1 := register(userID, u1.email, u1.salt, u1.pwhash)
+ _, err2 := register(userID, u2.email, u2.salt, u2.pwhash)
+
+ g.TErrorIf(err1)
+ g.TAssertEqual(
+ err2.(golite.Error).ExtendedCode,
+ golite.ErrConstraintUnique,
+ )
+ })
+
+ g.Testing("users can't have the same email", func() {
+ u1 := newUser()
+ u2 := newUser()
+ email := string(mksalt())
+
+ _, err1 := register(u1.userID, email, u1.salt, u1.pwhash)
+ _, err2 := register(u2.userID, email, u2.salt, u2.pwhash)
+
+ g.TErrorIf(err1)
+ g.TAssertEqual(
+ err2.(golite.Error).ExtendedCode,
+ golite.ErrConstraintUnique,
+ )
+ })
+
+ g.Testing("users can't have the same salt", func() {
+ u1 := newUser()
+ u2 := newUser()
+ salt := mksalt()
+
+ _, err1 := register(u1.userID, u1.email, salt, u1.pwhash)
+ _, err2 := register(u2.userID, u2.email, salt, u2.pwhash)
+
+ g.TErrorIf(err1)
+ g.TAssertEqual(
+ err2.(golite.Error).ExtendedCode,
+ golite.ErrConstraintUnique,
+ )
+ })
+
+ g.Testing("no error when close()ing more than once", func() {
+ g.TErrorIf(g.SomeError(
+ registerClose(),
+ registerClose(),
+ registerClose(),
+ ))
+ })
+}
+
+func test_sendTokenStmt() {
+ g.TestStart("sendToken()")
+
+ const (
+ prefix = defaultPrefix
+ )
+
+ db, err := sql.Open(golite.DriverName, ":memory:")
+ g.TErrorIf(err)
+ g.TErrorIf(createTables(db, prefix))
+
+ register, registerClose, registerErr := registerStmt(db, prefix)
+ sendToken, sendTokenClose, sendTokenErr := sendTokenStmt(db, prefix)
+ g.TErrorIf(g.SomeError(
+ registerErr,
+ sendTokenErr,
+ ))
+ defer g.SomeFnError(
+ registerClose,
+ sendTokenClose,
+ db.Close,
+ )
+
+
+ g.Testing("can't send a token to a non-existent user", func() {
+ err := sendToken(guuid.New(), "some token")
+ g.TAssertEqual(
+ err.(golite.Error).ExtendedCode,
+ golite.ErrConstraintNotNull,
+ )
+ })
+
+ g.Testing("otherwise creates the confirmation attempt", func() {
+ u := newUser()
+
+ _, err := register(u.userID, u.email, u.salt, u.pwhash)
+ g.TErrorIf(err)
+
+ err = sendToken(u.userID, u.token)
+ g.TErrorIf(err)
+ })
+
+ g.Testing("token has to be unique globally", func() {
+ u1 := newUser()
+ u2 := newUser()
+ token := string(mksalt())
+
+ _, err := register(u1.userID, u1.email, u1.salt, u1.pwhash)
+ g.TErrorIf(err)
+ _, err = register(u2.userID, u2.email, u2.salt, u2.pwhash)
+ g.TErrorIf(err)
+
+ err1 := sendToken(u1.userID, token)
+ err2 := sendToken(u2.userID, token)
+ g.TErrorIf(err1)
+ g.TAssertEqual(
+ err2.(golite.Error).ExtendedCode,
+ golite.ErrConstraintUnique,
+ )
+ })
+
+ g.Testing("a user can have multiple", func() {
+ u := newUser()
+ token1 := string(mksalt())
+ token2 := string(mksalt())
+ token3 := string(mksalt())
- queue, err := q.New(db)
- g.TAssertEqual(err, nil)
+ _, err := register(u.userID, u.email, u.salt, u.pwhash)
+ g.TErrorIf(err)
- auth, err := New(db, queue)
- g.TAssertEqual(err, nil)
+ g.TErrorIf(g.SomeError(
+ sendToken(u.userID, token1),
+ sendToken(u.userID, token2),
+ sendToken(u.userID, token3),
+ ))
- return testAuth{
- auth: auth,
- close: func() error {
- return g.SomeFnError(
- db.Close,
- queue.Close,
+ tokens := []string{}
+ const tmpl_read = `
+ SELECT token from "%s_confirmation_attempts"
+ WHERE user_id = (
+ SELECT id FROM "%s_users"
+ WHERE uuid = ?
)
- },
- }
+ ORDER BY id ASC;
+ `
+ qRead := fmt.Sprintf(tmpl_read, prefix, prefix)
+ rows, err := db.Query(qRead, u.userID[:])
+ g.TErrorIf(err)
+
+ for rows.Next() {
+ var token string
+ err := rows.Scan(&token)
+ g.TErrorIf(err)
+ tokens = append(tokens, token)
+ }
+ g.TErrorIf(g.SomeFnError(rows.Err, rows.Close))
+
+ expected := []string{
+ token1,
+ token2,
+ token3,
+ }
+ g.TAssertEqual(tokens, expected)
+ })
+}
+
+func test_confirmStmt() {
+ g.TestStart("confirmStmt()")
+
+ const (
+ prefix = defaultPrefix
+ )
+
+ db, err := sql.Open(golite.DriverName, ":memory:")
+ g.TErrorIf(err)
+ g.TErrorIf(createTables(db, prefix))
+
+ register, registerClose, registerErr := registerStmt(db, prefix)
+ sendToken, sendTokenClose, sendTokenErr := sendTokenStmt(db, prefix)
+ confirm, confirmClose, confirmErr := confirmStmt(db, prefix)
+ g.TErrorIf(g.SomeError(
+ registerErr,
+ sendTokenErr,
+ confirmErr,
+ ))
+ defer g.SomeFnError(
+ registerClose,
+ sendTokenClose,
+ confirmClose,
+ db.Close,
+ )
+
+
+ g.Testing("can't confirm a token that doesn't exist", func() {
+ _, err := confirm(string(mksalt()), guuid.New())
+ g.TAssertEqual(err, sql.ErrNoRows)
+ })
+
+ g.Testing("otherwise it creates a confirmation and a session", func() {
+ u := newUser()
+ sessionID := guuid.New()
+
+ _, err := register(u.userID, u.email, u.salt, u.pwhash)
+ g.TErrorIf(err)
+
+ g.TErrorIf(sendToken(u.userID, u.token))
+
+ session, err := confirm(u.token, sessionID)
+ g.TErrorIf(err)
+
+ g.TAssertEqual(session.timestamp == time.Time{}, false)
+ g.TAssertEqual(session.uuid, sessionID)
+ g.TAssertEqual(session.userID, u.userID)
+ })
+
+ g.Testing("can't confirm the same token twice", func() {
+ u := newUser()
+
+ _, err := register(u.userID, u.email, u.salt, u.pwhash)
+ g.TErrorIf(err)
+
+ g.TErrorIf(sendToken(u.userID, u.token))
+
+ _, err1 := confirm(u.token, guuid.New())
+ _, err2 := confirm(u.token, guuid.New())
+ g.TErrorIf(err1)
+ g.TAssertEqual(
+ err2.(golite.Error).ExtendedCode,
+ golite.ErrConstraintUnique,
+ )
+ })
+
+ g.Testing("a user can't have 2 confirmations", func() {
+ u := newUser()
+ token1 := string(mksalt())
+ token2 := string(mksalt())
+
+ _, err := register(u.userID, u.email, u.salt, u.pwhash)
+ g.TErrorIf(err)
+
+ g.TErrorIf(g.SomeError(
+ sendToken(u.userID, token1),
+ sendToken(u.userID, token2),
+ ))
+
+ _, err1 := confirm(token1, guuid.New())
+ _, err2 := confirm(token2, guuid.New())
+ g.TErrorIf(err1)
+ g.TAssertEqual(
+ err2.(golite.Error).ExtendedCode,
+ golite.ErrConstraintUnique,
+ )
+ })
+}
+
+func test_byEmailStmt() {
+ g.TestStart("byEmailStmt()")
+
+ const (
+ prefix = defaultPrefix
+ )
+
+ db, err := sql.Open(golite.DriverName, ":memory:")
+ g.TErrorIf(err)
+ g.TErrorIf(createTables(db, prefix))
+
+ register, registerClose, registerErr := registerStmt(db, prefix)
+ sendToken, sendTokenClose, sendTokenErr := sendTokenStmt(db, prefix)
+ confirm, confirmClose, confirmErr := confirmStmt(db, prefix)
+ byEmail, byEmailClose, byEmailErr := byEmailStmt(db, prefix)
+ g.TErrorIf(g.SomeError(
+ registerErr,
+ sendTokenErr,
+ confirmErr,
+ byEmailErr,
+ ))
+ defer g.SomeFnError(
+ registerClose,
+ sendTokenClose,
+ confirmClose,
+ byEmailClose,
+ db.Close,
+ )
+
+
+ g.Testing("error when not found", func() {
+ email := string(mksalt())
+ _, err := byEmail(email)
+ g.TAssertEqual(err, sql.ErrNoRows)
+ })
+
+ g.Testing("full user otherwise, confirmed or not", func() {
+ u := newUser()
+
+ _, err := register(u.userID, u.email, u.salt, u.pwhash)
+ g.TErrorIf(err)
+
+ user1, err := byEmail(u.email)
+ g.TErrorIf(err)
+
+ g.TErrorIf(sendToken(u.userID, u.token))
+
+ user2, err := byEmail(u.email)
+ g.TErrorIf(err)
+
+ _, err = confirm(u.token, guuid.New())
+ g.TErrorIf(err)
+
+ user3, err := byEmail(u.email)
+ g.TErrorIf(err)
+
+ g.TAssertEqual(user1, user2)
+ user2.confirmed = true
+ g.TAssertEqual(user2, user3)
+ })
+
+ g.Testing("if there is metadata content it is included", func() {
+ u := newUser()
+
+ _, err := register(u.userID, u.email, u.salt, u.pwhash)
+ g.TErrorIf(err)
+
+ user1, err := byEmail(u.email)
+ g.TErrorIf(err)
+ g.TAssertEqual(user1.metadata == nil, true)
+
+ const tmpl_write = `
+ UPDATE "%s_users"
+ SET METADATA = '{ "key": "value" }'
+ WHERE uuid = ?;
+ `
+ qWrite := fmt.Sprintf(tmpl_write, prefix)
+ _, err = db.Exec(qWrite, u.userID[:])
+ g.TErrorIf(err)
+
+ expected := map[string]interface{}{"key": "value"}
+
+ user2, err := byEmail(u.email)
+ g.TErrorIf(err)
+ g.TAssertEqual(user2.metadata, expected)
+ })
}
-func test_Register() {
- g.TestStart("Register()")
+func test_loginStmt() {
+ g.TestStart("loginStmt()")
const (
- email = "email@example.com"
- password = "password"
- confirmPassword = "password"
+ prefix = defaultPrefix
+ )
+
+ db, err := sql.Open(golite.DriverName, ":memory:")
+ g.TErrorIf(err)
+ g.TErrorIf(createTables(db, prefix))
+
+ register, registerClose, registerErr := registerStmt(db, prefix)
+ sendToken, sendTokenClose, sendTokenErr := sendTokenStmt(db, prefix)
+ confirm, confirmClose, confirmErr := confirmStmt(db, prefix)
+ byEmail, byEmailClose, byEmailErr := byEmailStmt(db, prefix)
+ login, loginClose, loginErr := loginStmt(db, prefix)
+ g.TErrorIf(g.SomeError(
+ registerErr,
+ sendTokenErr,
+ confirmErr,
+ byEmailErr,
+ loginErr,
+ ))
+ defer g.SomeFnError(
+ registerClose,
+ sendTokenClose,
+ confirmClose,
+ byEmailClose,
+ loginClose,
+ db.Close,
)
- g.Testing("we can register a new email", func() {
- t := mkauth()
- defer t.close()
- user, err := t.auth.Register(
- email,
- password,
- confirmPassword,
+ g.Testing("a user must exist to login", func() {
+ _, err := login(guuid.New(), guuid.New())
+ g.TAssertEqual(
+ err.(golite.Error).ExtendedCode,
+ golite.ErrConstraintNotNull,
)
+ })
+
+ g.Testing("sessionID must be unique globally", func() {
+ u1 := newUser()
+ u2 := newUser()
+ sessionID := guuid.New()
+
+ _, err := register(u1.userID, u1.email, u1.salt, u1.pwhash)
+ g.TErrorIf(err)
+ _, err = register(u2.userID, u2.email, u2.salt, u2.pwhash)
+ g.TErrorIf(err)
+
+ _, err1 := login(u1.userID, sessionID)
+ _, err2 := login(u1.userID, sessionID)
+ _, err3 := login(u2.userID, sessionID)
+
+ g.TErrorIf(err1)
+ g.TAssertEqual(
+ err2.(golite.Error).ExtendedCode,
+ golite.ErrConstraintUnique,
+ )
+ g.TAssertEqual(
+ err3.(golite.Error).ExtendedCode,
+ golite.ErrConstraintUnique,
+ )
+ })
+
+ g.Testing("a user can have multiple active sessions", func() {
+ u := newUser()
+ sessionID1 := guuid.New()
+ sessionID2 := guuid.New()
+
+ _, err := register(u.userID, u.email, u.salt, u.pwhash)
+ g.TErrorIf(err)
+
+ session1, err := login(u.userID, sessionID1)
+ g.TErrorIf(err)
+
+ session2, err := login(u.userID, sessionID2)
+ g.TErrorIf(err)
+
+ g.TAssertEqual(session1.uuid, sessionID1)
+ g.TAssertEqual(session2.uuid, sessionID2)
+ g.TAssertEqual(session1.userID, u.userID)
+ g.TAssertEqual(session2.userID, u.userID)
+ })
+
+ g.Testing("multiple users can be logged in", func() {
+ u1 := newUser()
+ u2 := newUser()
+ sessionID1 := guuid.New()
+ sessionID2 := guuid.New()
+
+ _, err := register(u1.userID, u1.email, u1.salt, u1.pwhash)
+ g.TErrorIf(err)
+ _, err = register(u2.userID, u2.email, u2.salt, u2.pwhash)
+ g.TErrorIf(err)
+
+ session1, err := login(u1.userID, sessionID1)
+ g.TErrorIf(err)
+ session2, err := login(u2.userID, sessionID2)
+ g.TErrorIf(err)
+
+ g.TAssertEqual(session1.uuid, sessionID1)
+ g.TAssertEqual(session2.uuid, sessionID2)
+ g.TAssertEqual(session1.userID, u1.userID)
+ g.TAssertEqual(session2.userID, u2.userID)
+ })
+
+ g.Testing("an unconfirmed user is allowed to login", func() {
+ u := newUser()
+
+ _, err := register(u.userID, u.email, u.salt, u.pwhash)
+ g.TErrorIf(err)
+
+ _, err = login(u.userID, guuid.New())
+ g.TErrorIf(err)
+ })
+
+ g.Testing("a confirmed user is allowed to login, too", func() {
+ u := newUser()
+
+ _, err := register(u.userID, u.email, u.salt, u.pwhash)
+ g.TErrorIf(err)
+
+ g.TErrorIf(sendToken(u.userID, u.token))
+
+ _, err = confirm(u.token, guuid.New())
+ g.TErrorIf(err)
+
+ user, err := byEmail(u.email)
+ g.TErrorIf(err)
+ g.TAssertEqual(user.confirmed, true)
+
+ _, err = login(u.userID, guuid.New())
+ g.TErrorIf(err)
+ })
+}
+
+func test_refreshStmt() {
+ g.TestStart("refreshStmt()")
+
+ const (
+ prefix = defaultPrefix
+ )
+
+ db, err := sql.Open(golite.DriverName, ":memory:")
+ g.TErrorIf(err)
+ g.TErrorIf(createTables(db, prefix))
+
+ register, registerClose, registerErr := registerStmt(db, prefix)
+ sendToken, sendTokenClose, sendTokenErr := sendTokenStmt(db, prefix)
+ confirm, confirmClose, confirmErr := confirmStmt(db, prefix)
+ byEmail, byEmailClose, byEmailErr := byEmailStmt(db, prefix)
+ login, loginClose, loginErr := loginStmt(db, prefix)
+ refresh, refreshClose, refreshErr := refreshStmt(db, prefix)
+ g.TErrorIf(g.SomeError(
+ registerErr,
+ sendTokenErr,
+ confirmErr,
+ byEmailErr,
+ loginErr,
+ refreshErr,
+ ))
+ defer g.SomeFnError(
+ registerClose,
+ sendTokenClose,
+ confirmClose,
+ byEmailClose,
+ loginClose,
+ refreshClose,
+ db.Close,
+ )
+
+ reg := func(u userDataT) userT {
+ user, err := register(u.userID, u.email, u.salt, u.pwhash)
+ g.TErrorIf(err)
+ return user
+ }
+
+ conf := func(u userDataT) sessionT {
+ err := sendToken(u.userID, u.token)
+ g.TErrorIf(err)
+
+ session, err := confirm(u.token, guuid.New())
+ g.TErrorIf(err)
+ return session
+ }
+
+
+ g.Testing("a session needs to exist be be refreshed", func() {
+ _, err := refresh(guuid.New(), guuid.New())
+ g.TAssertEqual(
+ err.(golite.Error).ExtendedCode,
+ golite.ErrConstraintNotNull,
+ )
+ })
+
+ g.Testing("we can refresh the session of an unconfirmed user", func() {
+ u := newUser()
+ sessionID1 := guuid.New()
+ sessionID2 := guuid.New()
+
+ reg(u)
+
+ session1, err := login(u.userID, sessionID1)
+ g.TErrorIf(err)
+
+ user, err := byEmail(u.email)
+ g.TErrorIf(err)
+ g.TAssertEqual(user.confirmed, false)
+
+ session2, err := refresh(sessionID1, sessionID2)
+ g.TErrorIf(err)
+
+ g.TAssertEqual(session1.userID, u.userID)
+ g.TAssertEqual(session1.userID, session2.userID)
+ g.TAssertEqual(session1.uuid, sessionID1)
+ g.TAssertEqual(session2.uuid, sessionID2)
+ g.TAssertEqual(session1.timestamp == time.Time{}, false)
+ g.TAssertEqual(session2.timestamp == time.Time{}, false)
+ })
+
+ g.Testing("we can refresh the session of a confirmed user", func() {
+ u := newUser()
+ sessionID1 := guuid.New()
+ sessionID2 := guuid.New()
+
+ reg(u)
+ session1 := conf(u)
+
+ user, err := byEmail(u.email)
+ g.TErrorIf(err)
+ g.TAssertEqual(user.confirmed, true)
+
+ // FIXME
return
- g.TAssertEqual(err, nil)
- g.TAssertEqual(user, nil)
+
+ session2, err := refresh(sessionID1, sessionID2)
+ g.TErrorIf(err)
+
+ g.TAssertEqual(session1.userID, u.userID)
+ g.TAssertEqual(session1.userID, session2.userID)
+ g.TAssertEqual(session1.uuid, sessionID1)
+ g.TAssertEqual(session2.uuid, sessionID2)
+ })
+
+ g.Testing("we can't refresh an expired session", func() {
})
- g.Testing("we can't register duplicate emails", func() {
+ g.Testing("the sessionID can't be reused, even across users", func() {
})
+ // FIXME
}
-*/
+
+func test_resetStmt() {
+ // FIXME
+}
+
+func test_changeStmt() {
+ // FIXME
+}
+
+func test_byUUIDStmt() {
+ // FIXME
+}
+
+func test_logoutStmt() {
+ // FIXME
+}
+
+func test_outOthersStmt() {
+ // FIXME
+}
+
+func test_outAllStmt() {
+ // FIXME
+}
+
+func test_initDB() {
+ // FIXME
+}
+
+func test_queriesTclose() {
+ // FIXME
+}
+
+func test_newUserHandler() {
+ // FIXME
+}
+
+func test_sendConfirmationRequestHandler() {
+ // FIXME
+}
+
+func test_forgotPasswordRequestHandler() {
+ // FIXME
+}
+
+func test_registerConsumers() {
+ // FIXME
+}
+
+func test_unregisterConsumers() {
+ // FIXME
+}
+
+func test_startRunner() {
+ // FIXME
+}
+
+func test_makePoolRunner() {
+ // FIXME
+}
+
+func test_asResult() {
+ // FIXME
+}
+
+func test_unwrapResult() {
+ // FIXME
+}
+
+func test_NewWithPrefix() {
+ // FIXME
+}
+
+func test_New() {
+ // FIXME
+}
+
+func test_newUserPayload() {
+ // FIXME
+}
+
+func test_authT_Register() {
+ // FIXME
+}
+
+func test_sendConfirmationMessage() {
+ // FIXME
+}
+
+func test_authT_ResentConfirmation() {
+ // FIXME
+}
+
+func test_authT_ConfirmEmail() {
+ // FIXME
+}
+
+func test_authT_LoginEmail() {
+ // FIXME
+}
+
+func test_forgotPasswordMessage() {
+ // FIXME
+}
+
+func test_authT_ForgotPassword() {
+ // FIXME
+}
+
+func test_checkSession() {
+ // FIXME
+}
+
+func test_validateSession() {
+ // FIXME
+}
+
+func test_authT_Refresh() {
+ // FIXME
+}
+
+func test_authT_ResetPassword() {
+ // FIXME
+}
+
+func test_authT_ChangePassword() {
+ // FIXME
+}
+
+func test_runLogout() {
+ // FIXME
+}
+
+func test_authT_Logout() {
+ // FIXME
+}
+
+func test_authT_LogoutOthers() {
+ // FIXME
+}
+
+func test_authT_LogoutAll() {
+ // FIXME
+}
+
+func test_authT_Close() {
+ // FIXME
+}
+
+func test_usage() {
+ // FIXME
+}
+
+func test_getopt() {
+ // FIXME
+}
+
+func test_runCommand() {
+ // FIXME
+}
+
+
+func dumpQueries() {
+ queries := []struct{name string; fn func(string) queryT}{
+ { "createTables", createTablesSQL },
+ { "byEmail", byEmailSQL },
+ { "register", registerSQL },
+ { "sendToken", sendTokenSQL },
+ { "confirm", confirmSQL },
+ { "login", loginSQL },
+ { "refresh", refreshSQL },
+ { "reset", resetSQL },
+ { "change", changeSQL },
+ { "byUUID", byUUIDSQL },
+ { "logout", logoutSQL },
+ { "outOthers", outOthersSQL },
+ { "outAll", outAllSQL },
+ }
+ for _, query := range queries {
+ q := query.fn(defaultPrefix)
+ fmt.Printf("\n-- %s.sql:", query.name)
+ fmt.Printf("\n-- write:%s\n", q.write)
+ fmt.Printf("\n-- read:%s\n", q.read)
+ }
+}
+
func MainTest() {
+ if os.Getenv("TESTING_DUMP_SQL_QUERIES") != "" {
+ dumpQueries()
+ return
+ }
+
g.Init()
test_defaultPrefix()
- // test_tablesFrom()
- // test_Register()
+ test_tryRollback()
+ test_inTx()
+ test_createTables()
+ test_registerStmt()
+ test_sendTokenStmt()
+ test_confirmStmt()
+ test_byEmailStmt()
+ test_loginStmt()
+ test_refreshStmt()
+ test_resetStmt()
+ test_changeStmt()
+ test_byUUIDStmt()
+ test_logoutStmt()
+ test_outOthersStmt()
+ test_outAllStmt()
+ test_initDB()
+ test_queriesTclose()
+ test_newUserHandler()
+ test_sendConfirmationRequestHandler()
+ test_forgotPasswordRequestHandler()
+ test_registerConsumers()
+ test_unregisterConsumers()
+ test_startRunner()
+ test_makePoolRunner()
+ test_asResult()
+ test_unwrapResult()
+ test_NewWithPrefix()
+ test_New()
+ test_newUserPayload()
+ test_authT_Register()
+ test_sendConfirmationMessage()
+ test_authT_ResentConfirmation()
+ test_authT_ConfirmEmail()
+ test_authT_LoginEmail()
+ test_forgotPasswordMessage()
+ test_authT_ForgotPassword()
+ test_checkSession()
+ test_validateSession()
+ test_authT_Refresh()
+ test_authT_ResetPassword()
+ test_authT_ChangePassword()
+ test_runLogout()
+ test_authT_Logout()
+ test_authT_LogoutOthers()
+ test_authT_LogoutAll()
+ test_authT_Close()
+ test_usage()
+ test_getopt()
+ test_runCommand()
}
diff --git a/tests/queries.sql b/tests/queries.sql
index e69de29..c44a6c5 100644
--- a/tests/queries.sql
+++ b/tests/queries.sql
@@ -0,0 +1,205 @@
+
+-- createTables.sql:
+-- write:
+ CREATE TABLE IF NOT EXISTS "gracha_users" (
+ id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
+ timestamp TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%f000000Z', 'now')),
+ uuid BLOB NOT NULL UNIQUE,
+ email TEXT NOT NULL UNIQUE,
+ salt BLOB NOT NULL UNIQUE,
+ pwhash BLOB NOT NULL,
+ metadata TEXT
+ );
+ CREATE TABLE IF NOT EXISTS "gracha_confirmation_attempts" (
+ id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
+ timestamp TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%f000000Z', 'now')),
+ -- uuid BLOB NOT NULL UNIQUE,
+ user_id INTEGER NOT NULL REFERENCES "gracha_users"(id),
+ token TEXT NOT NULL UNIQUE
+ );
+ CREATE TABLE IF NOT EXISTS "gracha_user_confirmations" (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ timestamp TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%f000000Z', 'now')),
+ user_id INTEGER NOT NULL
+ REFERENCES "gracha_users"(id) UNIQUE,
+ attempt_id INTEGER NOT NULL
+ REFERENCES "gracha_confirmation_attempts"(id) UNIQUE
+ );
+ CREATE TABLE IF NOT EXISTS "gracha_user_changes" (
+ id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
+ timestamp TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%f000000Z', 'now')),
+ user_id INTEGER NOT NULL REFERENCES "gracha_users"(id),
+ attribute TEXT NOT NULL,
+ value TEXT NOT NULL,
+ op BOOLEAN NOT NULL
+ );
+ -- CREATE TABLE IF NOT EXISTS "gracha_tokens" (
+ -- id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
+ -- timestamp TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%f000000Z', 'now')),
+ -- uuid BLOB NOT NULL UNIQUE,
+ -- type TEXT NOT NULL
+ -- );
+ CREATE TABLE IF NOT EXISTS "gracha_roles" (
+ id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
+ user_id INTEGER NOT NULL REFERENCES "gracha_users"(id),
+ role TEXT NOT NULL,
+ metadata TEXT,
+ UNIQUE (user_id, role)
+ );
+ CREATE TABLE IF NOT EXISTS "gracha_role_changes" (
+ id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
+ timestamp TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%f000000Z', 'now')),
+ user_id INTEGER NOT NULL REFERENCES "gracha_roles"(id),
+ role TEXT NOT NULL,
+ op BOOLEAN NOT NULL
+ );
+ CREATE TABLE IF NOT EXISTS "gracha_sessions" (
+ id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
+ timestamp TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%f000000Z', 'now')),
+ uuid BLOB NOT NULL UNIQUE,
+ user_id INTEGER NOT NULL REFERENCES "gracha_users"(id),
+ -- type TEXT NOT NULL,
+ -- revoked_at TEXT,
+ -- revoker_id INTEGER REFERENCES "gracha_users"(id),
+ -- FIXME: add provenance: login, refresh, confirmation, etc.
+ metadata TEXT
+ );
+ CREATE TABLE IF NOT EXISTS "gracha_attempts" (
+ id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
+ timestamp TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%f000000Z', 'now')),
+ user_id INTEGER REFERENCES "gracha_users"(id),
+ session_id INTEGER REFERENCES "gracha_sessions"(id),
+ metadata TEXT
+ );
+ CREATE TABLE IF NOT EXISTS "gracha_audit" (
+ id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
+ timestamp TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%f000000Z', 'now')),
+ uuid BLOB NOT NULL UNIQUE,
+ attribute TEXT NOT NULL,
+ value TEXT NOT NULL,
+ op BOOLEAN NOT NULL,
+ metadata TEXT
+ );
+
+
+-- read:
+
+-- byEmail.sql:
+-- write:
+
+-- read:
+ SELECT id, timestamp, uuid, salt, pwhash, metadata, (
+ CASE WHEN EXISTS (
+ SELECT id FROM "gracha_user_confirmations"
+ WHERE user_id = (
+ SELECT id FROM "gracha_users"
+ WHERE email = ?
+ )
+ ) THEN 1
+ ELSE 0
+ END
+ ) as confirmed
+ FROM "gracha_users" WHERE email = ?;
+
+
+-- register.sql:
+-- write:
+ INSERT INTO "gracha_users" (uuid, email, salt, pwhash)
+ VALUES (?, ?, ?, ?) RETURNING id, timestamp;
+
+
+-- read:
+ SELECT id, timestamp from "gracha_users"
+ WHERE uuid = ?;
+
+
+-- sendToken.sql:
+-- write:
+ INSERT INTO "gracha_confirmation_attempts" (user_id, token)
+ VALUES (
+ (SELECT id FROM "gracha_users" WHERE uuid = ?),
+ ?
+ )
+
+
+-- read:
+
+-- confirm.sql:
+-- write:
+ INSERT INTO "gracha_user_confirmations" (user_id, attempt_id)
+ VALUES (?, ?);
+
+
+-- read:
+ SELECT
+ "gracha_confirmation_attempts".id,
+ "gracha_confirmation_attempts".user_id,
+ "gracha_users".uuid
+ FROM "gracha_confirmation_attempts"
+ JOIN "gracha_users" ON
+ "gracha_confirmation_attempts".user_id = "gracha_users".id
+ WHERE token = ?;
+
+
+-- login.sql:
+-- write:
+
+-- read:
+
+-- refresh.sql:
+-- write:
+ INSERT INTO "gracha_sessions" (uuid, user_id)
+ VALUES (
+ ?,
+ (SELECT user_id FROM "gracha_sessions" WHERE uuid = ?)
+ ) RETURNING id, timestamp, (
+ SELECT "gracha_users".uuid FROM "gracha_users"
+ JOIN "gracha_sessions" ON
+ "gracha_users".id = "gracha_sessions".user_id
+ WHERE "gracha_sessions".uuid = ?
+ ) AS userID;
+
+
+-- read:
+
+-- reset.sql:
+-- write:
+ -- INSERT SOMETHING gracha
+
+
+-- read:
+
+-- change.sql:
+-- write:
+ -- INSERT SOMETHING gracha
+
+
+-- read:
+
+-- byUUID.sql:
+-- write:
+
+-- read:
+ -- INSERT SOMETHING gracha
+
+
+-- logout.sql:
+-- write:
+ -- INSERT SOMETHING gracha
+
+
+-- read:
+
+-- outOthers.sql:
+-- write:
+ -- INSERT SOMETHING gracha
+
+
+-- read:
+
+-- outAll.sql:
+-- write:
+ -- INSERT SOMETHING gracha
+
+
+-- read: