summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorEuAndreh <eu@euandre.org>2024-11-07 11:09:55 -0300
committerEuAndreh <eu@euandre.org>2024-11-07 11:09:55 -0300
commit58734647fef344a8105d7b76306eb010a1d0ca4e (patch)
treed50646085b6a6e0a8715979d209c90caae3764cf
parentAdjust to dependency renaming "q" -> "fiinha" (diff)
downloadcracha-58734647fef344a8105d7b76306eb010a1d0ca4e.tar.gz
cracha-58734647fef344a8105d7b76306eb010a1d0ca4e.tar.xz
src/cracha.go: Add userByUUID() and refactor *Stmt() arguments
-rw-r--r--src/cracha.go360
-rw-r--r--tests/cracha.go237
2 files changed, 441 insertions, 156 deletions
diff --git a/src/cracha.go b/src/cracha.go
index dc4193c..f669198 100644
--- a/src/cracha.go
+++ b/src/cracha.go
@@ -1,7 +1,6 @@
package cracha
import (
- "context"
"database/sql"
"encoding/hex"
"encoding/json"
@@ -48,6 +47,12 @@ var (
+type dbconfigT struct{
+ shared *sql.DB
+ dbpath string
+ prefix string
+}
+
type queryT struct{
write string
read string
@@ -55,19 +60,20 @@ type queryT struct{
}
type queriesT struct{
- register func(guuid.UUID, string, []byte, []byte) (userT, error)
- sendToken func(guuid.UUID, string) error
- confirm func(string, guuid.UUID) (sessionT, error)
- byEmail func(string) (userT, error)
- login func(guuid.UUID, guuid.UUID) (sessionT, error)
- refresh func(guuid.UUID, guuid.UUID) (sessionT, error)
- reset func(int64, []byte, guuid.UUID) (sessionT, error)
- change func(guuid.UUID, []byte) (sessionT, error)
- byUUID func(guuid.UUID) (sessionT, error)
- logout func(guuid.UUID) error
- outOthers func(guuid.UUID) error
- outAll func(guuid.UUID) error
- close func() error
+ register func(guuid.UUID, string, []byte, []byte) (userT, error)
+ sendToken func(guuid.UUID, string) error
+ confirm func(string, guuid.UUID) (sessionT, error)
+ byEmail func(string) (userT, error)
+ userByUUID func(guuid.UUID) (userT, error)
+ login func(guuid.UUID, guuid.UUID) (sessionT, error)
+ refresh func(guuid.UUID, guuid.UUID) (sessionT, error)
+ reset func(int64, []byte, guuid.UUID) (sessionT, error)
+ change func(guuid.UUID, []byte) (sessionT, error)
+ byUUID func(guuid.UUID) (sessionT, error)
+ logout func(guuid.UUID) error
+ outOthers func(guuid.UUID) error
+ outAll func(guuid.UUID) error
+ close func() error
}
type userT struct{
@@ -82,7 +88,6 @@ type userT struct{
type User struct{
UUID guuid.UUID
- Salt []byte
Confirmed bool
}
@@ -91,7 +96,6 @@ type sessionT struct{
timestamp time.Time
uuid guuid.UUID
userID guuid.UUID
- // type_ string
revoked_at *time.Time
}
@@ -120,7 +124,7 @@ type IAuth interface{
ForgotPassword(string) error
Refresh(Session) (Session, error)
ResetPassword(guuid.UUID, string, string) (Session, error)
- ChangePassword(User, string, string, string) (Session, error)
+ ChangePassword(guuid.UUID, string, string, string) (Session, error)
Logout(Session) error
LogoutOthers(Session) error
LogoutAll(Session) error
@@ -134,8 +138,8 @@ func validateSession(session sessionT) (sessionT, error) {
return session, nil
}
-func tryRollback(db *sql.DB, ctx context.Context, err error) error {
- _, rollbackErr := db.ExecContext(ctx, "ROLLBACK;")
+func tryRollback(tx *sql.Tx, err error) error {
+ rollbackErr := tx.Rollback()
if rollbackErr != nil {
return fmt.Errorf(
rollbackErrorFmt,
@@ -147,27 +151,72 @@ func tryRollback(db *sql.DB, ctx context.Context, err error) error {
return err
}
-func inTx(db *sql.DB, fn func(context.Context) error) error {
- ctx := context.Background()
-
- _, err := db.ExecContext(ctx, "BEGIN IMMEDIATE;")
+func inTx(db *sql.DB, fn func(*sql.Tx) error) error {
+ tx, err := db.Begin()
if err != nil {
return err
}
- err = fn(ctx)
+ err = fn(tx)
if err != nil {
- return tryRollback(db, ctx, err)
+ return tryRollback(tx, err)
}
- _, err = db.ExecContext(ctx, "COMMIT;")
+ err = tx.Commit()
if err != nil {
- return tryRollback(db, ctx, err)
+ return tryRollback(tx, err)
}
return nil
}
+func serialized[A any, B any](callback func(...A) B) (func(...A) B, func()) {
+ in := make(chan []A)
+ out := make(chan B)
+
+ closed := false
+ var (
+ closeWg sync.WaitGroup
+ closeMutex sync.Mutex
+ )
+ closeWg.Add(1)
+
+ go func() {
+ for input := range in {
+ out <- callback(input...)
+ }
+ close(out)
+ closeWg.Done()
+ }()
+
+ fn := func(input ...A) B {
+ in <- input
+ return (<- out)
+ }
+
+ closeFn := func() {
+ closeMutex.Lock()
+ defer closeMutex.Unlock()
+ if closed {
+ return
+ }
+ close(in)
+ closed = true
+ closeWg.Wait()
+ }
+
+ return fn, closeFn
+}
+
+func execSerialized(query string, db *sql.DB) (func(...any) error, func()) {
+ return serialized(func(args ...any) error {
+ return inTx(db, func(tx *sql.Tx) error {
+ _, err := tx.Exec(query, args...)
+ return err
+ })
+ })
+}
+
func createTablesSQL(prefix string) queryT {
const tmpl_write = `
CREATE TABLE IF NOT EXISTS "%s_users" (
@@ -293,8 +342,8 @@ func createTablesSQL(prefix string) queryT {
func createTables(db *sql.DB, prefix string) error {
q := createTablesSQL(prefix)
- return inTx(db, func(ctx context.Context) error {
- _, err := db.ExecContext(ctx, q.write)
+ return inTx(db, func(tx *sql.Tx) error {
+ _, err := tx.Exec(q.write)
return err
})
}
@@ -315,21 +364,20 @@ func registerSQL(prefix string) queryT {
}
func registerStmt(
- db *sql.DB,
- prefix string,
+ cfg dbconfigT,
) (
func(guuid.UUID, string, []byte, []byte) (userT, error),
func() error,
error,
) {
- q := registerSQL(prefix)
+ q := registerSQL(cfg.prefix)
- writeStmt, err := db.Prepare(q.write)
+ writeStmt, err := cfg.shared.Prepare(q.write)
if err != nil {
return nil, nil, err
}
- readStmt, err := db.Prepare(q.read)
+ readStmt, err := cfg.shared.Prepare(q.read)
if err != nil {
return nil, nil, err
}
@@ -389,12 +437,11 @@ func sendTokenSQL(prefix string) queryT {
}
func sendTokenStmt(
- db *sql.DB,
- prefix string,
+ cfg dbconfigT,
) (func(guuid.UUID, string) error, func() error, error) {
- q := sendTokenSQL(prefix)
+ q := sendTokenSQL(cfg.prefix)
- writeStmt, err := db.Prepare(q.write)
+ writeStmt, err := cfg.shared.Prepare(q.write)
if err != nil {
return nil, nil, err
}
@@ -444,22 +491,21 @@ func confirmSQL(prefix string) queryT {
}
func confirmStmt(
- db *sql.DB,
- prefix string,
+ cfg dbconfigT,
) (func(string, guuid.UUID) (sessionT, error), func() error, error) {
- q := confirmSQL(prefix)
+ q := confirmSQL(cfg.prefix)
- writeStmt, err := db.Prepare(q.write)
+ writeStmt, err := cfg.shared.Prepare(q.write)
if err != nil {
return nil, nil, err
}
- readStmt, err := db.Prepare(q.read)
+ readStmt, err := cfg.shared.Prepare(q.read)
if err != nil {
return nil, nil, g.WrapErrors(writeStmt.Close(), err)
}
- sessionStmt, err := db.Prepare(q.session)
+ sessionStmt, err := cfg.shared.Prepare(q.session)
if err != nil {
return nil, nil, g.WrapErrors(
writeStmt.Close(),
@@ -523,7 +569,6 @@ func confirmStmt(
}
func byEmailSQL(prefix string) queryT {
- // FIXME: rewrite as LEFT JOIN?
const tmpl_read = `
SELECT id, timestamp, uuid, salt, pwhash, (
CASE WHEN EXISTS (
@@ -544,12 +589,11 @@ func byEmailSQL(prefix string) queryT {
}
func byEmailStmt(
- db *sql.DB,
- prefix string,
+ cfg dbconfigT,
) (func(string) (userT, error), func() error, error) {
- q := byEmailSQL(prefix)
+ q := byEmailSQL(cfg.prefix)
- readStmt, err := db.Prepare(q.read)
+ readStmt, err := cfg.shared.Prepare(q.read)
if err != nil {
return nil, nil, err
}
@@ -587,6 +631,65 @@ func byEmailStmt(
return fn, readStmt.Close, nil
}
+func userByUUIDSQL(prefix string) queryT {
+ const tmpl_read = `
+ SELECT id, timestamp, email, salt, pwhash, (
+ CASE WHEN EXISTS (
+ SELECT id FROM "%s_user_confirmations"
+ WHERE user_id = (
+ SELECT id FROM "%s_users"
+ WHERE uuid = ?
+ )
+ ) THEN 1
+ ELSE 0
+ END
+ ) as confirmed
+ FROM "%s_users" WHERE uuid = ?;
+ `
+ return queryT{
+ read: fmt.Sprintf(tmpl_read, prefix, prefix, prefix),
+ }
+}
+
+func userByUUIDStmt(
+ cfg dbconfigT,
+) (func(guuid.UUID) (userT, error), func() error, error) {
+ q := userByUUIDSQL(cfg.prefix)
+
+ readStmt, err := cfg.shared.Prepare(q.read)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ fn := func(userID guuid.UUID) (userT, error) {
+ user := userT{
+ uuid: userID,
+ }
+
+ var timestr string
+ err := readStmt.QueryRow(userID[:], userID[:]).Scan(
+ &user.id,
+ &timestr,
+ &user.email,
+ &user.salt,
+ &user.pwhash,
+ &user.confirmed,
+ )
+ if err != nil {
+ return userT{}, err
+ }
+
+ user.timestamp, err = time.Parse(time.RFC3339Nano, timestr)
+ if err != nil {
+ return userT{}, err
+ }
+
+ return user, nil
+ }
+
+ return fn, readStmt.Close, nil
+}
+
func loginSQL(prefix string) queryT {
const tmpl_session = `
INSERT INTO "%s_sessions" (uuid, user_id)
@@ -601,12 +704,11 @@ func loginSQL(prefix string) queryT {
}
func loginStmt(
- db *sql.DB,
- prefix string,
+ cfg dbconfigT,
) (func(guuid.UUID, guuid.UUID) (sessionT, error), func() error, error) {
- q := loginSQL(prefix)
+ q := loginSQL(cfg.prefix)
- sessionStmt, err := db.Prepare(q.session)
+ sessionStmt, err := cfg.shared.Prepare(q.session)
if err != nil {
return nil, nil, err
}
@@ -671,12 +773,11 @@ func refreshSQL(prefix string) queryT {
}
func refreshStmt(
- db *sql.DB,
- prefix string,
+ cfg dbconfigT,
) (func(guuid.UUID, guuid.UUID) (sessionT, error), func() error, error) {
- q := refreshSQL(prefix)
+ q := refreshSQL(cfg.prefix)
- writeStmt, err := db.Prepare(q.write)
+ writeStmt, err := cfg.shared.Prepare(q.write)
if err != nil {
return nil, nil, err
}
@@ -726,12 +827,11 @@ func resetSQL(prefix string) queryT {
}
func resetStmt(
- db *sql.DB,
- prefix string,
+ cfg dbconfigT,
) (func(int64, []byte, guuid.UUID) (sessionT, error), func() error, error) {
- q := resetSQL(prefix)
+ q := resetSQL(cfg.prefix)
- writeStmt, err := db.Prepare(q.write)
+ writeStmt, err := cfg.shared.Prepare(q.write)
if err != nil {
return nil, nil, err
}
@@ -758,12 +858,11 @@ func changeSQL(prefix string) queryT {
}
func changeStmt(
- db *sql.DB,
- prefix string,
+ cfg dbconfigT,
) (func(guuid.UUID, []byte) (sessionT, error), func() error, error) {
- q := changeSQL(prefix)
+ q := changeSQL(cfg.prefix)
- writeStmt, err := db.Prepare(q.write)
+ writeStmt, err := cfg.shared.Prepare(q.write)
if err != nil {
return nil, nil, err
}
@@ -790,12 +889,11 @@ func byUUIDSQL(prefix string) queryT {
}
func byUUIDStmt(
- db *sql.DB,
- prefix string,
+ cfg dbconfigT,
) (func(guuid.UUID) (sessionT, error), func() error, error) {
- q := byUUIDSQL(prefix)
+ q := byUUIDSQL(cfg.prefix)
- readStmt, err := db.Prepare(q.read)
+ readStmt, err := cfg.shared.Prepare(q.read)
if err != nil {
return nil, nil, err
}
@@ -822,12 +920,11 @@ func logoutSQL(prefix string) queryT {
}
func logoutStmt(
- db *sql.DB,
- prefix string,
+ cfg dbconfigT,
) (func(guuid.UUID) error, func() error, error) {
- q := logoutSQL(prefix)
+ q := logoutSQL(cfg.prefix)
- writeStmt, err := db.Prepare(q.write)
+ writeStmt, err := cfg.shared.Prepare(q.write)
if err != nil {
return nil, nil, err
}
@@ -850,12 +947,11 @@ func outOthersSQL(prefix string) queryT {
}
func outOthersStmt(
- db *sql.DB,
- prefix string,
+ cfg dbconfigT,
) (func(guuid.UUID) error, func() error, error) {
- q := outOthersSQL(prefix)
+ q := outOthersSQL(cfg.prefix)
- writeStmt, err := db.Prepare(q.write)
+ writeStmt, err := cfg.shared.Prepare(q.write)
if err != nil {
return nil, nil, err
}
@@ -878,12 +974,11 @@ func outAllSQL(prefix string) queryT {
}
func outAllStmt(
- db *sql.DB,
- prefix string,
+ cfg dbconfigT,
) (func(guuid.UUID) error, func() error, error) {
- q := outAllSQL(prefix)
+ q := outAllSQL(cfg.prefix)
- writeStmt, err := db.Prepare(q.write)
+ writeStmt, err := cfg.shared.Prepare(q.write)
if err != nil {
return nil, nil, err
}
@@ -897,29 +992,47 @@ func outAllStmt(
}
func initDB(
- db *sql.DB,
+ dbpath string,
prefix string,
) (queriesT, error) {
- createTablesErr := createTables(db, prefix)
- register, registerClose, registerErr := registerStmt(db, prefix)
- sendToken, sendTokenClose, sendTokenErr := sendTokenStmt(db, prefix)
- confirm, confirmClose, confirmErr := confirmStmt(db, prefix)
- byEmail, byEmailClose, byEmailErr := byEmailStmt(db, prefix)
- login, loginClose, loginErr := loginStmt(db, prefix)
- refresh, refreshClose, refreshErr := refreshStmt(db, prefix)
- reset, resetClose, resetErr := resetStmt(db, prefix)
- change, changeClose, changeErr := changeStmt(db, prefix)
- byUUID, byUUIDClose, byUUIDErr := byUUIDStmt(db, prefix)
- logout, logoutClose, logoutErr := logoutStmt(db, prefix)
- outOthers, outOthersClose, outOthersErr := outOthersStmt(db, prefix)
- outAll, outAllClose, outAllErr := outAllStmt(db, prefix)
-
- err := g.SomeError(
+ err := g.ValidateSQLTablePrefix(prefix)
+ if err != nil {
+ return queriesT{}, err
+ }
+
+ shared, err := sql.Open(golite.DriverName, dbpath)
+ if err != nil {
+ return queriesT{}, err
+ }
+
+ cfg := dbconfigT{
+ shared: shared,
+ dbpath: dbpath,
+ prefix: prefix,
+ }
+
+ createTablesErr := createTables(shared, prefix)
+ register, registerClose, registerErr := registerStmt(cfg)
+ sendToken, sendTokenClose, sendTokenErr := sendTokenStmt(cfg)
+ confirm, confirmClose, confirmErr := confirmStmt(cfg)
+ byEmail, byEmailClose, byEmailErr := byEmailStmt(cfg)
+ userByUUID, userByUUIDClose, userByUUIDErr := userByUUIDStmt(cfg)
+ login, loginClose, loginErr := loginStmt(cfg)
+ refresh, refreshClose, refreshErr := refreshStmt(cfg)
+ reset, resetClose, resetErr := resetStmt(cfg)
+ change, changeClose, changeErr := changeStmt(cfg)
+ byUUID, byUUIDClose, byUUIDErr := byUUIDStmt(cfg)
+ logout, logoutClose, logoutErr := logoutStmt(cfg)
+ outOthers, outOthersClose, outOthersErr := outOthersStmt(cfg)
+ outAll, outAllClose, outAllErr := outAllStmt(cfg)
+
+ err = g.SomeError(
createTablesErr,
registerErr,
sendTokenErr,
confirmErr,
byEmailErr,
+ userByUUIDErr,
loginErr,
refreshErr,
resetErr,
@@ -935,18 +1048,19 @@ func initDB(
close := func() error {
return g.SomeFnError(
- registerClose,
- sendTokenClose,
- confirmClose,
- byEmailClose,
- loginClose,
- refreshClose,
- resetClose,
- changeClose,
- byUUIDClose,
- logoutClose,
- outOthersClose,
- outAllClose,
+ registerClose,
+ sendTokenClose,
+ confirmClose,
+ byEmailClose,
+ userByUUIDClose,
+ loginClose,
+ refreshClose,
+ resetClose,
+ changeClose,
+ byUUIDClose,
+ logoutClose,
+ outOthersClose,
+ outAllClose,
)
}
@@ -977,6 +1091,11 @@ func initDB(
defer connMutex.RUnlock()
return byEmail(a)
},
+ userByUUID: func(a guuid.UUID) (userT, error) {
+ connMutex.RLock()
+ defer connMutex.RUnlock()
+ return userByUUID(a)
+ },
login: func(a guuid.UUID, b guuid.UUID) (sessionT, error) {
connMutex.RLock()
defer connMutex.RUnlock()
@@ -1179,13 +1298,9 @@ func NewWithPrefix(databasePath string, prefix string) (IAuth, error) {
return authT{}, err
}
- db, err := sql.Open(golite.DriverName, databasePath)
- if err != nil {
- return authT{}, err
- }
-
- queries, err := initDB(db, prefix)
+ queries, err := initDB(databasePath, prefix)
if err != nil {
+ queue.Close()
return authT{}, err
}
@@ -1193,13 +1308,13 @@ func NewWithPrefix(databasePath string, prefix string) (IAuth, error) {
hasher, closeHasher := makePoolRunner(numCPU, asResult(scrypt.Hash))
closeFn := func() {
+ queue.Close()
unregisterConsumers(queue, consumers, prefix)
closeHasher()
}
auth := authT{
queue: queue,
- db: db,
queries: queries,
hasher: unwrapResult(hasher),
close: closeFn,
@@ -1516,7 +1631,7 @@ func (auth authT) ResetPassword(
}
func (auth authT) ChangePassword(
- user User,
+ userID guuid.UUID,
currentPassword string,
newPassword string,
confirmNewPassword string,
@@ -1529,20 +1644,25 @@ func (auth authT) ChangePassword(
return Session{}, ErrTooShort
}
+ user, err := auth.queries.userByUUID(userID)
+ if err != nil {
+ return Session{}, err
+ }
+
input := scrypt.HashInput{
Password: []byte(newPassword),
- Salt: user.Salt,
+ Salt: user.salt,
}
pwhash, err := auth.hasher(input)
if err != nil {
return Session{}, err
}
- if !user.Confirmed {
+ if !user.confirmed {
return Session{}, ErrUnconfirmed
}
- session, err := auth.queries.change(user.UUID, pwhash)
+ session, err := auth.queries.change(user.uuid, pwhash)
if err != nil {
return Session{}, nil
}
diff --git a/tests/cracha.go b/tests/cracha.go
index 905733c..eda3fea 100644
--- a/tests/cracha.go
+++ b/tests/cracha.go
@@ -2,8 +2,10 @@ package cracha
import (
"database/sql"
+ "errors"
"fmt"
"os"
+ "reflect"
"time"
// "q"
@@ -51,17 +53,77 @@ func test_defaultPrefix() {
}
func test_tryRollback() {
- // FIXME
+ g.TestStart("tryRollback()")
+
+ myErr := errors.New("bottom error")
+
+ db, err := sql.Open(golite.DriverName, golite.InMemory)
+ g.TErrorIf(err)
+ defer db.Close()
+
+
+ g.Testing("the error is propagated if rollback doesn't fail", func() {
+ tx, err := db.Begin()
+ g.TErrorIf(err)
+
+ err = tryRollback(tx, myErr)
+ g.TAssertEqual(err, myErr)
+ })
+
+ g.Testing("a wrapped error when rollback fails", func() {
+ tx, err := db.Begin()
+ g.TErrorIf(err)
+
+ err = tx.Commit()
+ g.TErrorIf(err)
+
+ err = tryRollback(tx, myErr)
+ g.TAssertEqual(reflect.DeepEqual(err, myErr), false)
+ g.TAssertEqual(errors.Is(err, myErr), true)
+ })
}
func test_inTx() {
+ g.TestStart("inTx()")
+
+ db, err := sql.Open(golite.DriverName, golite.InMemory)
+ g.TErrorIf(err)
+ defer db.Close()
+
+
+ g.Testing("when fn() errors, we propagate it", func() {
+ myErr := errors.New("to be propagated")
+ err := inTx(db, func(tx *sql.Tx) error {
+ return myErr
+ })
+ g.TAssertEqual(err, myErr)
+ })
+
+ g.Testing("no nil error we get nil", func() {
+ err := inTx(db, func(tx *sql.Tx) error {
+ return nil
+ })
+ g.TErrorIf(err)
+ })
+}
+
+func test_serialized() {
+ // FIXME
+}
+
+func test_execSerialized() {
// FIXME
}
func test_createTables() {
g.TestStart("createTables()")
- db, err := sql.Open(golite.DriverName, ":memory:")
+ const (
+ prefix = defaultPrefix
+ dbpath = golite.InMemory
+ )
+
+ db, err := sql.Open(golite.DriverName, dbpath)
g.TErrorIf(err)
defer db.Close()
@@ -70,12 +132,12 @@ func test_createTables() {
const tmpl_read = `
SELECT id FROM "%s_users" LIMIT 1;
`
- qRead := fmt.Sprintf(tmpl_read, defaultPrefix)
+ qRead := fmt.Sprintf(tmpl_read, prefix)
_, err := db.Exec(qRead)
g.TErrorNil(err)
- err = createTables(db, defaultPrefix)
+ err = createTables(db, prefix)
g.TErrorIf(err)
_, err = db.Exec(qRead)
@@ -84,9 +146,9 @@ func test_createTables() {
g.Testing("we can do it multiple times", func() {
g.TErrorIf(g.SomeError(
- createTables(db, defaultPrefix),
- createTables(db, defaultPrefix),
- createTables(db, defaultPrefix),
+ createTables(db, prefix),
+ createTables(db, prefix),
+ createTables(db, prefix),
))
})
}
@@ -96,13 +158,19 @@ func test_registerStmt() {
const (
prefix = defaultPrefix
+ dbpath = golite.InMemory
)
- db, err := sql.Open(golite.DriverName, ":memory:")
+ db, err := sql.Open(golite.DriverName, dbpath)
g.TErrorIf(err)
g.TErrorIf(createTables(db, prefix))
- register, registerClose, registerErr := registerStmt(db, prefix)
+ cfg := dbconfigT{
+ shared: db,
+ dbpath: dbpath,
+ prefix: prefix,
+ }
+ register, registerClose, registerErr := registerStmt(cfg)
g.TErrorIf(registerErr)
defer g.SomeFnError(
registerClose,
@@ -182,14 +250,20 @@ func test_sendTokenStmt() {
const (
prefix = defaultPrefix
+ dbpath = golite.InMemory
)
- db, err := sql.Open(golite.DriverName, ":memory:")
+ db, err := sql.Open(golite.DriverName, dbpath)
g.TErrorIf(err)
g.TErrorIf(createTables(db, prefix))
- register, registerClose, registerErr := registerStmt(db, prefix)
- sendToken, sendTokenClose, sendTokenErr := sendTokenStmt(db, prefix)
+ cfg := dbconfigT{
+ shared: db,
+ dbpath: dbpath,
+ prefix: prefix,
+ }
+ register, registerClose, registerErr := registerStmt(cfg)
+ sendToken, sendTokenClose, sendTokenErr := sendTokenStmt(cfg)
g.TErrorIf(g.SomeError(
registerErr,
sendTokenErr,
@@ -288,15 +362,21 @@ func test_confirmStmt() {
const (
prefix = defaultPrefix
+ dbpath = golite.InMemory
)
- db, err := sql.Open(golite.DriverName, ":memory:")
+ db, err := sql.Open(golite.DriverName, dbpath)
g.TErrorIf(err)
g.TErrorIf(createTables(db, prefix))
- register, registerClose, registerErr := registerStmt(db, prefix)
- sendToken, sendTokenClose, sendTokenErr := sendTokenStmt(db, prefix)
- confirm, confirmClose, confirmErr := confirmStmt(db, prefix)
+ cfg := dbconfigT{
+ shared: db,
+ dbpath: dbpath,
+ prefix: prefix,
+ }
+ register, registerClose, registerErr := registerStmt(cfg)
+ sendToken, sendTokenClose, sendTokenErr := sendTokenStmt(cfg)
+ confirm, confirmClose, confirmErr := confirmStmt(cfg)
g.TErrorIf(g.SomeError(
registerErr,
sendTokenErr,
@@ -377,16 +457,22 @@ func test_byEmailStmt() {
const (
prefix = defaultPrefix
+ dbpath = golite.InMemory
)
- db, err := sql.Open(golite.DriverName, ":memory:")
+ db, err := sql.Open(golite.DriverName, dbpath)
g.TErrorIf(err)
g.TErrorIf(createTables(db, prefix))
- register, registerClose, registerErr := registerStmt(db, prefix)
- sendToken, sendTokenClose, sendTokenErr := sendTokenStmt(db, prefix)
- confirm, confirmClose, confirmErr := confirmStmt(db, prefix)
- byEmail, byEmailClose, byEmailErr := byEmailStmt(db, prefix)
+ cfg := dbconfigT{
+ shared: db,
+ dbpath: dbpath,
+ prefix: prefix,
+ }
+ register, registerClose, registerErr := registerStmt(cfg)
+ sendToken, sendTokenClose, sendTokenErr := sendTokenStmt(cfg)
+ confirm, confirmClose, confirmErr := confirmStmt(cfg)
+ byEmail, byEmailClose, byEmailErr := byEmailStmt(cfg)
g.TErrorIf(g.SomeError(
registerErr,
sendTokenErr,
@@ -403,8 +489,7 @@ func test_byEmailStmt() {
g.Testing("error when not found", func() {
- email := string(mksalt())
- _, err := byEmail(email)
+ _, err := byEmail(string(mksalt()))
g.TAssertEqual(err, sql.ErrNoRows)
})
@@ -434,22 +519,95 @@ func test_byEmailStmt() {
})
}
+func test_userByUUIDStmt() {
+ g.TestStart("userByUUIDStmt()")
+
+ const (
+ prefix = defaultPrefix
+ dbpath = golite.InMemory
+ )
+
+ db, err := sql.Open(golite.DriverName, dbpath)
+ g.TErrorIf(err)
+ g.TErrorIf(createTables(db, prefix))
+
+ cfg := dbconfigT{
+ shared: db,
+ dbpath: dbpath,
+ prefix: prefix,
+ }
+ register, registerClose, registerErr := registerStmt(cfg)
+ sendToken, sendTokenClose, sendTokenErr := sendTokenStmt(cfg)
+ confirm, confirmClose, confirmErr := confirmStmt(cfg)
+ userByUUID, userByUUIDClose, userByUUIDErr := userByUUIDStmt(cfg)
+ g.TErrorIf(g.SomeError(
+ registerErr,
+ sendTokenErr,
+ confirmErr,
+ userByUUIDErr,
+ ))
+ defer g.SomeFnError(
+ registerClose,
+ sendTokenClose,
+ confirmClose,
+ userByUUIDClose,
+ db.Close,
+ )
+
+
+ g.Testing("error when not found", func() {
+ _, err := userByUUID(guuid.New())
+ g.TAssertEqual(err, sql.ErrNoRows)
+ })
+
+ g.Testing("full user otherwise, confirmed or not", func() {
+ u := newUser()
+
+ _, err := register(u.userID, u.email, u.salt, u.pwhash)
+ g.TErrorIf(err)
+
+ user1, err := userByUUID(u.userID)
+ g.TErrorIf(err)
+
+ g.TErrorIf(sendToken(u.userID, u.token))
+
+ user2, err := userByUUID(u.userID)
+ g.TErrorIf(err)
+
+ _, err = confirm(u.token, guuid.New())
+ g.TErrorIf(err)
+
+ user3, err := userByUUID(u.userID)
+ g.TErrorIf(err)
+
+ g.TAssertEqual(user1, user2)
+ user2.confirmed = true
+ g.TAssertEqual(user2, user3)
+ })
+}
+
func test_loginStmt() {
g.TestStart("loginStmt()")
const (
prefix = defaultPrefix
+ dbpath = golite.InMemory
)
- db, err := sql.Open(golite.DriverName, ":memory:")
+ db, err := sql.Open(golite.DriverName, dbpath)
g.TErrorIf(err)
g.TErrorIf(createTables(db, prefix))
- register, registerClose, registerErr := registerStmt(db, prefix)
- sendToken, sendTokenClose, sendTokenErr := sendTokenStmt(db, prefix)
- confirm, confirmClose, confirmErr := confirmStmt(db, prefix)
- byEmail, byEmailClose, byEmailErr := byEmailStmt(db, prefix)
- login, loginClose, loginErr := loginStmt(db, prefix)
+ cfg := dbconfigT{
+ shared: db,
+ dbpath: dbpath,
+ prefix: prefix,
+ }
+ register, registerClose, registerErr := registerStmt(cfg)
+ sendToken, sendTokenClose, sendTokenErr := sendTokenStmt(cfg)
+ confirm, confirmClose, confirmErr := confirmStmt(cfg)
+ byEmail, byEmailClose, byEmailErr := byEmailStmt(cfg)
+ login, loginClose, loginErr := loginStmt(cfg)
g.TErrorIf(g.SomeError(
registerErr,
sendTokenErr,
@@ -577,18 +735,24 @@ func test_refreshStmt() {
const (
prefix = defaultPrefix
+ dbpath = golite.InMemory
)
- db, err := sql.Open(golite.DriverName, ":memory:")
+ db, err := sql.Open(golite.DriverName, dbpath)
g.TErrorIf(err)
g.TErrorIf(createTables(db, prefix))
- register, registerClose, registerErr := registerStmt(db, prefix)
- sendToken, sendTokenClose, sendTokenErr := sendTokenStmt(db, prefix)
- confirm, confirmClose, confirmErr := confirmStmt(db, prefix)
- byEmail, byEmailClose, byEmailErr := byEmailStmt(db, prefix)
- login, loginClose, loginErr := loginStmt(db, prefix)
- refresh, refreshClose, refreshErr := refreshStmt(db, prefix)
+ cfg := dbconfigT{
+ shared: db,
+ dbpath: dbpath,
+ prefix: prefix,
+ }
+ register, registerClose, registerErr := registerStmt(cfg)
+ sendToken, sendTokenClose, sendTokenErr := sendTokenStmt(cfg)
+ confirm, confirmClose, confirmErr := confirmStmt(cfg)
+ byEmail, byEmailClose, byEmailErr := byEmailStmt(cfg)
+ login, loginClose, loginErr := loginStmt(cfg)
+ refresh, refreshClose, refreshErr := refreshStmt(cfg)
g.TErrorIf(g.SomeError(
registerErr,
sendTokenErr,
@@ -890,6 +1054,7 @@ func MainTest() {
test_sendTokenStmt()
test_confirmStmt()
test_byEmailStmt()
+ test_userByUUIDStmt()
test_loginStmt()
test_refreshStmt()
test_resetStmt()