diff options
Diffstat (limited to 'src/lib.go')
-rw-r--r-- | src/lib.go | 216 |
1 files changed, 192 insertions, 24 deletions
@@ -7,11 +7,13 @@ import ( "errors" "flag" "fmt" + "io/ioutil" "log/slog" "net" "os" "regexp" "runtime/debug" + "sort" "strings" "sync" "time" @@ -42,7 +44,7 @@ type Channel struct { } type Context struct { - dbConn *sql.DB + db *sql.DB tx chan int } @@ -66,7 +68,9 @@ type Message struct { } var ( - CmdUser = Message { Command: "USER" } + CmdUSER = Message { Command: "USER" } + CmdPRIVMSG = Message { Command: "PRIVMSG" } + CmdJOIN = Message { Command: "JOIN" } ) func SplitOnCRLF(data []byte, _atEOF bool) (int, []byte, error) { @@ -114,17 +118,16 @@ func ParseMessageParams(params string) MessageParams { } var MessageRegex = regexp.MustCompilePOSIX( - // <prefix> <command> <params> - //1 2 3 4 - `^(:([^ ]+) +)?([a-zA-Z]+|[0-9]{3}) *( .*)$`, - // ^^^^ FIXME: test these spaces + // <prefix> <command> <params> + //1 2 3 4 + `^(:([^ ]+) +)?([a-zA-Z]+) *( .*)$`, ) func ParseMessage(rawMessage string) (Message, error) { var msg Message components := MessageRegex.FindStringSubmatch(rawMessage) if components == nil { - return msg, nil + return msg, errors.New("Can't parse message") } msg = Message { @@ -136,8 +139,58 @@ func ParseMessage(rawMessage string) (Message, error) { return msg, nil } -func HandleMessage(msg Message) { - fmt.Printf("msg: %#v\n", msg) +func HandleUnknown(ctx *Context, msg Message) { + g.Warning( + "Unsupported command", "unsupported-command", + "command", msg.Command, + ) + var r Reply = ReplyUnknown + r.Prefix = "dunno" + // return []Action { r } +} + +func HandleUSER(ctx *Context, msg Message) { + fmt.Printf("USER: %#v\n", msg) +} + +func HandlePRIVMSG(ctx *Context, msg Message) { + // . assert no missing params + // . write to DB: (after auth) + // . channel timeline: message from $USER + // . reply to $USER + // . broadcast new timeline event to members of the channel + + stmt, err := ctx.db.Prepare(` + INSERT INTO messages + (id, sender_id, body, timestamp) + VALUES + (?, ?, ?, ? ); + `) + if err != nil { + // FIXME: reply error + fmt.Println("can't prepare: ", err) + return + } + defer stmt.Close() + + ret, err := stmt.Exec(g.NewUUID().ToString(), "FIXME", "FIXME", time.Now()) + if err != nil { + // FIXME: reply error + fmt.Println("xablau can't prepare: ", err) + return + } + + fmt.Println("ret: ", ret) +} + +func HandleJOIN(ctx *Context, msg Message) { + fmt.Printf("JOIN: %#v\n", msg) + + // . write to DB: (after auth) + // . $USER now in channel + // . channel timeline: $USER joined + // . reply to $USER + // . broadcast new timeline event to members of the channel } func ReplyAnonymous() { @@ -146,21 +199,68 @@ func ReplyAnonymous() { func PersistMessage(msg Message) { } -func ActionsFor(msg Message) []int { - return []int { } +type ActionType int +const ( + ActionReply = iota +) + +type Action interface { + Type() ActionType +} + +type Reply struct { + Prefix string + Command int + Params MessageParams +} + +func (reply Reply) Type() ActionType { + return ActionReply +} + +var ( + ReplyUnknown = Reply { + Command: 421, + Params: MessageParams { + Middle: []string { }, + Trailing: "Unknown command", + }, + } +) + +var Commands = map[string]func(*Context, Message) { + CmdUSER.Command: HandleUSER, + CmdPRIVMSG.Command: HandlePRIVMSG, + CmdJOIN.Command: HandleJOIN, } -func RunAction(action int) { +func ActionFnFor(command string) func(*Context, Message) { + fn := Commands[command] + if fn != nil { + return fn + } + + return HandleUnknown } func ProcessMessage(ctx *Context, connection *Connection, rawMessage string) { msg, err := ParseMessage(rawMessage) if err != nil { + g.Info( + "Error processing message", + "process-message", + "err", err, + ) return } - if msg.Command == CmdUser.Command { - connection.id = msg.Params.Middle[0] + if msg.Command == CmdUSER.Command { + args := msg.Params.Middle + if len(args) == 0 { + go ReplyAnonymous() + return + } + connection.id = args[0] connection.isAuthenticated = true } @@ -169,9 +269,7 @@ func ProcessMessage(ctx *Context, connection *Connection, rawMessage string) { return } - for _, action := range ActionsFor(msg) { - RunAction(action) - } + ActionFnFor(msg.Command)(ctx, msg) } func ReadLoop(ctx *Context, connection *Connection) { @@ -183,11 +281,11 @@ func ReadLoop(ctx *Context, connection *Connection) { } func WriteLoop(ctx *Context, connection *Connection) { - fmt.Println("WriteLoop") + // fmt.Println("WriteLoop") } func PingLoop(ctx *Context, connection *Connection) { - fmt.Println("PingLoop") + // fmt.Println("PingLoop") } func HandleConnection(ctx *Context, conn net.Conn) { @@ -270,10 +368,80 @@ func SetEnvironmentVariables() { } } +func InitMigrations(db *sql.DB) { + _, err := db.Exec(` + CREATE TABLE IF NOT EXISTS migrations ( + filename TEXT PRIMARY KEY + ); + `) + g.FatalIf(err) +} + +const MIGRATIONS_DIR = "src/sql/migrations/" +func PendingMigrations(db *sql.DB) []string { + files, err := ioutil.ReadDir(MIGRATIONS_DIR) + g.FatalIf(err) + + set := make(map[string]bool) + for _, file := range files { + set[file.Name()] = true + } + + rows, err := db.Query(`SELECT filename FROM migrations;`) + g.FatalIf(err) + defer rows.Close() + + for rows.Next() { + var filename string + err := rows.Scan(&filename) + g.FatalIf(err) + delete(set, filename) + } + g.FatalIf(rows.Err()) + + difference := make([]string, 0) + for filename := range set { + difference = append(difference, filename) + } + + sort.Sort(sort.StringSlice(difference)) + return difference +} + +func RunMigrations(db *sql.DB) { + InitMigrations(db) + + stmt, err := db.Prepare(`INSERT INTO migrations (filename) VALUES (?);`) + g.FatalIf(err) + defer stmt.Close() + + for _, filename := range PendingMigrations(db) { + g.Info("Running migration file", "exec-migration-file", + "filename", filename, + ) + + tx, err := db.Begin() + g.FatalIf(err) + + sql, err := os.ReadFile(MIGRATIONS_DIR + filename) + g.FatalIf(err) + + _, err = tx.Exec(string(sql)) + g.FatalIf(err) + + _, err = tx.Stmt(stmt).Exec(filename) + g.FatalIf(err) + + err = tx.Commit() + g.FatalIf(err) + } +} + func InitDB(databasePath string) *sql.DB { - DB, err := sql.Open("sqlite3", databasePath) + db, err := sql.Open("sqlite3", databasePath) g.FatalIf(err) - return DB + RunMigrations(db) + return db } func Init() { @@ -313,11 +481,11 @@ func Start(ctx *Context, publicSocketPath string, commandSocketPath string) { } func BuildContext(databasePath string) *Context { - dbConn := InitDB(databasePath) + db := InitDB(databasePath) tx := make(chan int, 100) return &Context { - dbConn, - tx, + db: db, + tx: tx, } } |