package gracha import ( "crypto/rand" "database/sql" "encoding/hex" "encoding/json" "errors" "fmt" "time" "guuid" "liteq" "scrypt" g "gobang" ) type tablesT struct{ users string userChanges string tokens string roles string roleChanges string sessions string attempts string audit string } type queriesT struct{ userByEmail func(string) (userT, error) userByToken func([]byte) (userT, error) register func(string, []byte, []byte) (userT, error) confirm func([]byte) (sessionT, error) login func(string, string) (sessionT, error) refresh func([]byte) (sessionT, error) resetPassword func(int64, []byte, []byte) (sessionT, error) resetAndConfirm func(int64, []byte, []byte) (sessionT, error) changePassword func(int64, []byte) (sessionT, error) sessionByUUID func([]byte) (sessionT, error) logout func([]byte) error logoutOthers func([]byte) error logoutAll func([]byte) error close func() error } type confirmationT struct{ } type userT struct{ id int64 timestr string timestamp time.Time uuid []byte email string username *string salt []byte pwhash []byte // confirmation confirmed_at *time.Time metadatastr *string metadata *map[string]interface{} } type sessionT struct{ id int64 timestr string timestamp time.Time uuid guuid.UUID user_id int64 type_ string revokedstr *string revoked_at *time.Time metadatastr *string metadata *map[string]interface{} } type Auth struct{ tables tablesT queries queriesT q liteq.Queue } type consumerT struct{ topic string handlerFn func(Auth) func([]byte) error } const ( NEW_USER = "new-user" SEND_CONFIRMATION_REQUEST = "send-confirmation-request" FORGOT_PASSWORD_REQUEST = "forgot-password-request" defaultPrefix = "gracha" day = 24 * time.Hour ) var ( SessionDuration = 7 * day RegisterTimeout = 15 * time.Second ErrPasswordMismatch = errors.New("gracha: password and its confirmation don't match") ErrPasswordTooShort = errors.New("gracha: bad username/passphrase combo") ErrRegisterTimeout = errors.New("gracha: timeout when creating user") ErrAlreadyRegistered = errors.New("gracha: user already registered") ErrEmailPasswordCombo = errors.New("gracha: bad username/passphrase combo") ErrUnconfirmedUser = errors.New("gracha: user email is not confirmed") ErrRevokedSession = errors.New("gracha: this session was revoked") ErrSessionExpired = errors.New("gracha: session expired") ) func tablesFrom(prefix string) (tablesT, error) { if !g.ValidSQLTablePrefix(prefix) { return tablesT{}, g.ErrBadSQLTablePrefix } users := prefix + "-users" userChanges := prefix + "-user-changes" tokens := prefix + "-tokens" roles := prefix + "-roles" roleChanges := prefix + "-role-changes" sessions := prefix + "-sessions" attempts := prefix + "-attempts" audit := prefix + "-audit" return tablesT{ users: users, userChanges: userChanges, tokens: tokens, roles: roles, roleChanges: roleChanges, sessions: sessions, attempts: attempts, audit: audit, }, nil } func createTables(db *sql.DB, tables tablesT) error { const tmpl = ` BEGIN TRANSACTION; CREATE TABLE IF NOT EXISTS "%s" ( id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, timestamp TEXT NOT NULL DEFAULT (%s), uuid BLOB NOT NULL UNIQUE, email TEXT NOT NULL UNIQUE, username TEXT UNIQUE, salt BLOB NOT NULL UNIQUE, pwhash BLOB NOT NULL, confirmed_at TEXT, confirmer_id INTEGER REFERENCES "%s"(id), metadata TEXT ); CREATE TABLE IF NOT EXISTS "%s" ( id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, timestamp TEXT NOT NULL DEFAULT (%s), user_id INTEGER NOT NULL REFERENCES "%s"(id), attribute TEXT NOT NULL, value TEXT NOT NULL, op BOOLEAN NOT NULL ); CREATE TABLE IF NOT EXISTS "%s" ( id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, timestamp TEXT NOT NULL DEFAULT (%s), uuid BLOB NOT NULL UNIQUE, type TEXT NOT NULL ); CREATE TABLE IF NOT EXISTS "%s" ( id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, user_id INTEGER NOT NULL REFERENCES "%s"(id), role TEXT NOT NULL, UNIQUE (user_id, role) ); CREATE TABLE IF NOT EXISTS "%s" ( id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, timestamp TEXT NOT NULL DEFAULT (%s), user_id INTEGER NOT NULL REFERENCES "%s"(id), role TEXT NOT NULL, op BOOLEAN NOT NULL ); CREATE TABLE IF NOT EXISTS "%s" ( id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, timestamp TEXT NOT NULL DEFAULT (%s), uuid BLOB NOT NULL UNIQUE, user_id INTEGER NOT NULL REFERENCES "%s"(id), type TEXT NOT NULL, revoked_at TEXT, revoker_id INTEGER REFERENCES "%s"(id), metadata TEXT ); CREATE TABLE IF NOT EXISTS "%s" ( id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, timestamp TEXT NOT NULL DEFAULT (%s), user_id INTEGER REFERENCES "%s"(id), session_id INTEGER REFERENCES "%s"(id), metadata TEXT ); CREATE TABLE IF NOT EXISTS "%s" ( id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, timestamp TEXT NOT NULL DEFAULT (%s), uuid BLOB NOT NULL UNIQUE, attribute TEXT NOT NULL, value TEXT NOT NULL, op BOOLEAN NOT NULL, metadata TEXT ); COMMIT TRANSACTION; ` sql := fmt.Sprintf( tmpl, tables.users, g.SQLiteNow, tables.tokens, tables.userChanges, g.SQLiteNow, tables.users, tables.tokens, g.SQLiteNow, tables.roles, tables.users, tables.roleChanges, g.SQLiteNow, tables.users, tables.sessions, g.SQLiteNow, tables.users, tables.sessions, tables.attempts, g.SQLiteNow, tables.users, tables.sessions, tables.audit, g.SQLiteNow, ) /// fmt.Println(sql) /// _, err := db.Exec(sql) return err } func userByEmailQuery( db *sql.DB, tables tablesT, ) (func(string) (userT, error), func() error, error) { const tmpl = ` SELECT id, timestamp, uuid, email, username, pwhash, metadata FROM "%s" WHERE email = ?; ` sql := fmt.Sprintf(tmpl, tables.users) /// fmt.Println(sql) /// stmt, err := db.Prepare(sql) if err != nil { return nil, nil, err } fn := func(email string) (userT, error) { var user userT err := stmt.QueryRow(email).Scan(&user.id) // FIXME: build user return user, err } return fn, stmt.Close, nil } func userByTokenQuery( db *sql.DB, tables tablesT, ) (func([]byte) (userT, error), func() error, error) { const tmpl = ` SELECT id, timestamp, uuid, email, username, pwhash, metadata FROM "%s" WHERE email = ?; ` sql := fmt.Sprintf(tmpl, tables.users) /// fmt.Println(sql) /// stmt, err := db.Prepare(sql) if err != nil { return nil, nil, err } fn := func(token []byte) (userT, error) { var user userT err := stmt.QueryRow(token).Scan(&user.id) // FIXME: build user return user, err } return fn, stmt.Close, nil } func registerQuery( db *sql.DB, tables tablesT, ) (func(string, []byte, []byte) (userT, error), func() error, error) { const tmpl = ` INSERT INTO "%s" (uuid, email, username, salt, pwhash, metadata) VALUES (?, ?, ?, ?, ?, ?); ` sql := fmt.Sprintf(tmpl, tables.users) /// fmt.Println(sql) /// stmt, err := db.Prepare(sql) if err != nil { return nil, nil, err } fn := func(email string, salt []byte, pwhash []byte) (userT, error) { /* timestamp TEXT NOT NULL DEFAULT (%s), uuid BLOB NOT NULL UNIQUE, email TEXT NOT NULL UNIQUE, username TEXT UNIQUE, pwhash TEXT NOT NULL, metadata TEXT */ var user userT // err := stmt.QueryRow( ret, err := stmt.Exec( guuid.NewBytes(), email, "credentials.username", salt, pwhash, "credentials.metadata", // ).Scan(&user.email) // FIXME: finish ) if false { fmt.Printf("ret: %#v\n", ret) fmt.Printf("user: %#v\n", user) } return user, err } return fn, stmt.Close, nil } func loginQuery( db *sql.DB, tables tablesT, ) (func(string, string) (sessionT, error), func() error, error) { const tmpl = ` -- INSERT INTO "%s" (t3, t4) VALUES (?, ?); ` sql := fmt.Sprintf(tmpl, tables.users) /// fmt.Println(sql) /// stmt, err := db.Prepare(sql) if err != nil { return nil, nil, err } fn := func(email string, pwhash string) (sessionT, error) { var session sessionT err := stmt.QueryRow(email, pwhash).Scan(session) // FIXME: finish return session, err } return fn, stmt.Close, nil } func refreshQuery( db *sql.DB, tables tablesT, ) (func([]byte) (sessionT, error), func() error, error) { const tmpl = ` -- INSERT SOMETHING %s ` sql := fmt.Sprintf(tmpl, tables.users) /// fmt.Println(sql) /// stmt, err := db.Prepare(sql) if err != nil { return nil, nil, err } fn := func(uuid []byte) (sessionT, error) { var session sessionT err := stmt.QueryRow(uuid).Scan(&session) return session, err } return fn, stmt.Close, nil } func resetPasswordQuery( db *sql.DB, tables tablesT, ) (func(int64, []byte, []byte) (sessionT, error), func() error, error) { const tmpl = ` -- INSERT SOMETHING %s ` sql := fmt.Sprintf(tmpl, tables.users) /// fmt.Println(sql) /// stmt, err := db.Prepare(sql) if err != nil { return nil, nil, err } fn := func(id int64, pwhash []byte, token []byte) (sessionT, error) { var session sessionT err := stmt.QueryRow(id, pwhash, token).Scan(&session) return session, err } return fn, stmt.Close, nil } func resetAndConfirmQuery( db *sql.DB, tables tablesT, ) (func(int64, []byte, []byte) (sessionT, error), func() error, error) { const tmpl = ` -- INSERT SOMETHING %s ` sql := fmt.Sprintf(tmpl, tables.users) /// fmt.Println(sql) /// stmt, err := db.Prepare(sql) if err != nil { return nil, nil, err } fn := func(id int64, pwhash []byte, token []byte) (sessionT, error) { var session sessionT err := stmt.QueryRow(id, pwhash, token).Scan(&session) return session, err } return fn, stmt.Close, nil } func changePasswordQuery( db *sql.DB, tables tablesT, ) (func(int64, []byte) (sessionT, error), func() error, error) { const tmpl = ` -- INSERT SOMETHING %s ` sql := fmt.Sprintf(tmpl, tables.users) /// fmt.Println(sql) /// stmt, err := db.Prepare(sql) if err != nil { return nil, nil, err } fn := func(id int64, pwhash []byte) (sessionT, error) { var session sessionT err := stmt.QueryRow(id, pwhash).Scan(&session) return session, err } return fn, stmt.Close, nil } func sessionByUUIDQuery( db *sql.DB, tables tablesT, ) (func([]byte) (sessionT, error), func() error, error) { const tmpl = ` -- INSERT SOMETHING %s ` sql := fmt.Sprintf(tmpl, tables.users) /// fmt.Println(sql) /// stmt, err := db.Prepare(sql) if err != nil { return nil, nil, err } fn := func(uuid []byte) (sessionT, error) { var session sessionT err := stmt.QueryRow(uuid).Scan(&session) return session, err } return fn, stmt.Close, nil } func logoutQuery( db *sql.DB, tables tablesT, ) (func([]byte) error, func() error, error) { const tmpl = ` -- INSERT SOMETHING %s ` sql := fmt.Sprintf(tmpl, tables.users) /// fmt.Println(sql) /// stmt, err := db.Prepare(sql) if err != nil { return nil, nil, err } fn := func(uuid []byte) error { _, err := stmt.Exec(uuid) return err } return fn, stmt.Close, nil } func logoutOthersQuery( db *sql.DB, tables tablesT, ) (func([]byte) error, func() error, error) { const tmpl = ` -- INSERT SOMETHING %s ` sql := fmt.Sprintf(tmpl, tables.users) /// fmt.Println(sql) /// stmt, err := db.Prepare(sql) if err != nil { return nil, nil, err } fn := func(uuid []byte) error { _, err := stmt.Exec(uuid) return err } return fn, stmt.Close, nil } func logoutAllQuery( db *sql.DB, tables tablesT, ) (func([]byte) error, func() error, error) { const tmpl = ` -- INSERT SOMETHING %s ` sql := fmt.Sprintf(tmpl, tables.users) /// fmt.Println(sql) /// stmt, err := db.Prepare(sql) if err != nil { return nil, nil, err } fn := func(uuid []byte) error { _, err := stmt.Exec(uuid) return err } return fn, stmt.Close, nil } func initDB(db *sql.DB, tables tablesT) (queriesT, error) { createTablesErr := createTables(db, tables) userByEmail, userByEmailClose, userByEmailErr := userByEmailQuery(db, tables) userByToken, userByTokenClose, userByTokenErr := userByTokenQuery(db, tables) register, registerClose, registerErr := registerQuery(db, tables) login, loginClose, loginErr := loginQuery(db, tables) refresh, refreshClose, refreshErr := refreshQuery(db, tables) resetPassword, resetPasswordClose, resetPasswordErr := resetPasswordQuery(db, tables) resetAndConfirm, resetAndConfirmClose, resetAndConfirmErr := resetAndConfirmQuery(db, tables) changePassword, changePasswordClose, changePasswordErr := changePasswordQuery(db, tables) sessionByUUID, sessionByUUIDClose, sessionByUUIDErr := sessionByUUIDQuery(db, tables) logout, logoutClose, logoutErr := logoutQuery(db, tables) logoutOthers, logoutOthersClose, logoutOthersErr := logoutOthersQuery(db, tables) logoutAll, logoutAllClose, logoutAllErr := logoutAllQuery(db, tables) errs := []error { createTablesErr, userByEmailErr, userByTokenErr, registerErr, loginErr, refreshErr, resetPasswordErr, resetAndConfirmErr, changePasswordErr, sessionByUUIDErr, logoutErr, logoutOthersErr, logoutAllErr, } err := g.SomeError(errs) if err != nil { return queriesT{}, err } close := func() error { fns := [](func() error){ userByEmailClose, userByTokenClose, registerClose, loginClose, refreshClose, resetPasswordClose, resetAndConfirmClose, changePasswordClose, sessionByUUIDClose, logoutClose, logoutOthersClose, logoutAllClose, } return g.SomeFnError(fns) } return queriesT{ userByEmail: userByEmail, userByToken: userByToken, register: register, login: login, close: close, refresh: refresh, resetPassword: resetPassword, resetAndConfirm: resetAndConfirm, changePassword: changePassword, sessionByUUID: sessionByUUID, logout: logout, logoutOthers: logoutOthers, logoutAll: logoutAll, }, nil } func newUserPayload(email string, salt []byte, pwhash []byte) ([]byte, error) { data := make(map[string]interface{}) data["email"] = email data["salt"] = hex.EncodeToString(salt) data["pwhash"] = hex.EncodeToString(pwhash) return json.Marshal(data) } func publishRegister( q liteq.Queue, email string, salt []byte, pwhash []byte, flowId []byte, ) error { payload, err := newUserPayload(email, salt, pwhash) if err != nil { return err } return q.Publish(NEW_USER, payload, flowId) } func register( auth Auth, email string, salt []byte, pwhash []byte, ) (userT, error) { flowId := guuid.NewBytes() waiter := auth.q.WaitFor(NEW_USER, flowId) defer auth.q.Unwait(waiter) err := publishRegister(auth.q, email, salt, pwhash, flowId) if err != nil { return userT{}, err } select { case <-time.After(RegisterTimeout): return userT{}, ErrRegisterTimeout case <-waiter: return auth.queries.userByEmail(email) } } func (auth Auth) Register( email string, password string, confirmPassword string, ) (userT, error) { if password != confirmPassword { return userT{}, ErrPasswordMismatch } if len(password) < scrypt.MinimumPasswordLength { return userT{}, ErrPasswordTooShort } // special check for sql.ErrNoRows to combat enumeration attacks. _, lookupErr := auth.queries.userByEmail(email) if lookupErr != nil && lookupErr != sql.ErrNoRows { return userT{}, lookupErr } salt, err := scrypt.SaltFrom(rand.Reader) if err != nil { return userT{}, err } pwhash, err := scrypt.HashFrom([]byte(password), salt) if err != nil { return userT{}, err } /* We also try to register anyway, to prevent disk IO timing attacks. */ user, err := register(auth, email, salt, pwhash) if err != nil { if lookupErr != nil { return userT{}, ErrAlreadyRegistered } return userT{}, err } return user, nil } func sendConfirmationPayload(email string) ([]byte, error) { data := make(map[string]interface{}) data["email"] = email return json.Marshal(data) } func publishSendConfirmation(q liteq.Queue, email string, flowId []byte) error { payload, err := sendConfirmationPayload(email) if err != nil { return err } return q.Publish(SEND_CONFIRMATION_REQUEST, payload, flowId) } func (auth Auth) ResendConfirmation(email string) error { return publishSendConfirmation(auth.q, email, guuid.NewBytes()) } func (auth Auth) ConfirmEmail(token guuid.UUID) (sessionT, error) { return auth.queries.confirm(token[:]) } func (auth Auth) LoginEmail( email string, password string, ) (sessionT, error) { if len(password) < scrypt.MinimumPasswordLength { return sessionT{}, ErrPasswordTooShort } // special check for sql.ErrNoRows to combat enumeration attacks. user, err := auth.queries.userByEmail(email) if err != nil && err != sql.ErrNoRows { return sessionT{}, err } ok, err := scrypt.CheckFrom([]byte(password), user.salt, user.pwhash) if err != nil { return sessionT{}, err } if !ok { return sessionT{}, ErrEmailPasswordCombo } if user.confirmed_at == nil { return sessionT{}, ErrUnconfirmedUser } dbSession, err := auth.queries.login(email, password) if err != nil { return sessionT{}, err } return dbSession, nil } func forgotPasswordPayload(email string) ([]byte, error) { data := make(map[string]interface{}) data["email"] = email return json.Marshal(data) } func publishForgotPassword(q liteq.Queue, email string, flowId []byte) error { payload, err := forgotPasswordPayload(email) if err != nil { return err } return q.Publish(FORGOT_PASSWORD_REQUEST, payload, flowId) } func (auth Auth) ForgotPassword(email string) error { // special check for sql.ErrNoRows to combat enumeration attacks. user, err := auth.queries.userByEmail(email) if err != nil && err != sql.ErrNoRows { return err } return publishForgotPassword(auth.q, user.email, guuid.NewBytes()) } func checkSession(session sessionT, now time.Time) error { if session.revoked_at != nil { return ErrRevokedSession } if session.timestamp.Add(SessionDuration).After(now) { return ErrSessionExpired } return nil } func validateSession(lookupFn func([]byte) (sessionT, error), session sessionT) error { dbSession, err := lookupFn(session.uuid[:]) if err != nil { return err } return checkSession(dbSession, time.Now()) } func (auth Auth) Refresh(session sessionT) (sessionT, error) { err := validateSession(auth.queries.sessionByUUID, session) if err != nil { return sessionT{}, err } return auth.queries.refresh(session.uuid[:]) } func (auth Auth) ResetPassword( token []byte, password string, confirmPassword string, ) (sessionT, error) { if password != confirmPassword { return sessionT{}, ErrPasswordMismatch } if len(password) < scrypt.MinimumPasswordLength { return sessionT{}, ErrPasswordTooShort } user, err := auth.queries.userByToken(token) if err != nil { return sessionT{}, err } pwhash, err := scrypt.HashFrom([]byte(password), user.salt) if err != nil { return sessionT{}, err } if user.confirmed_at != nil { return auth.queries.resetPassword(user.id, pwhash, token) } else { return auth.queries.resetAndConfirm(user.id, pwhash, token) } } func (auth Auth) ChangePassword( user userT, currentPassword string, newPassword string, confirmNewPassword string, ) (sessionT, error) { if newPassword != confirmNewPassword { return sessionT{}, ErrPasswordMismatch } if len(newPassword) < scrypt.MinimumPasswordLength { return sessionT{}, ErrPasswordTooShort } pwhash, err := scrypt.HashFrom([]byte(newPassword), user.salt) if err != nil { return sessionT{}, err } if user.confirmed_at == nil { return sessionT{}, ErrUnconfirmedUser } return auth.queries.changePassword(user.id, pwhash) } func runLogout( lookupFn func([]byte) (sessionT, error), session sessionT, queryFn func([]byte) error, ) error { err := validateSession(lookupFn, session) if err != nil { return err } return queryFn(session.uuid[:]) } func (auth Auth) Logout(session sessionT) error { return runLogout(auth.queries.sessionByUUID, session, auth.queries.logout) } func (auth Auth) LogoutOthers(session sessionT) error { return runLogout(auth.queries.sessionByUUID, session, auth.queries.logoutOthers) } func (auth Auth) LogoutAll(session sessionT) error { return runLogout(auth.queries.sessionByUUID, session, auth.queries.logoutAll) } func (auth Auth) Close() error { return auth.queries.close() } func newUserHandler(auth Auth) func([]byte) error { return func(payload []byte) error { return nil } } func sendConfirmationRequestHandler(auth Auth) func([]byte) error { return func(payload []byte) error { return nil } } func forgotPasswordRequestHandler(auth Auth) func([]byte) error { return func(payload []byte) error { return nil } } var consumers = []consumerT{ consumerT{ topic: NEW_USER, handlerFn: newUserHandler, }, consumerT{ topic: SEND_CONFIRMATION_REQUEST, handlerFn: sendConfirmationRequestHandler, }, consumerT{ topic: FORGOT_PASSWORD_REQUEST, handlerFn: forgotPasswordRequestHandler, }, } func registerConsumers(auth Auth, consumers []consumerT) { for _, consumer := range consumers { auth.q.Subscribe( consumer.topic, defaultPrefix + "-" + consumer.topic, consumer.handlerFn(auth), ) } } func NewWithPrefix(db *sql.DB, q liteq.Queue, prefix string) (Auth, error) { tables, err := tablesFrom(prefix) if err != nil { return Auth{}, err } queries, err := initDB(db, tables) if err != nil { return Auth{}, err } return Auth{ tables: tables, queries: queries, q: q, }, nil } func New(db *sql.DB, q liteq.Queue) (Auth, error) { return NewWithPrefix(db, q, defaultPrefix) } func Main() { g.Init() q := new(liteq.Queue) sql.Register("sqlite-liteq", liteq.MakeDriver(q)) db, err := sql.Open("sqlite-liteq", "file:gracha.db?mode=memory&cache=shared") if err != nil { panic(err) } // defer db.Close() *q, err = liteq.New(db) if err != nil { panic(err) } // defer q.Close() auth, err := New(db, *q) if err != nil { fmt.Println(err) panic(err) } user, err := auth.Register("contact@example.com", "password", "password") if false { fmt.Printf("user: %#v\n", user) fmt.Printf("err: %#v\n", err) } return fmt.Printf("q: %#v\n", q) fmt.Printf("auth: %#v\n", auth) }