diff options
author | EuAndreh <eu@euandre.org> | 2024-11-21 11:04:08 -0300 |
---|---|---|
committer | EuAndreh <eu@euandre.org> | 2025-01-17 09:51:33 -0300 |
commit | 65de65ce1e34efeb421974bcb5ddd85fb53253bb (patch) | |
tree | 62c37d832885d9a1113369db4e27563c4dbbcdb8 /src | |
parent | src/papod.go: Integrate db layer with network, create command handlers, simpl... (diff) | |
download | papod-65de65ce1e34efeb421974bcb5ddd85fb53253bb.tar.gz papod-65de65ce1e34efeb421974bcb5ddd85fb53253bb.tar.xz |
Implement most of db layer
Many missing implementations or tests are marked with FIXME so I don't
loose track of holes in the code.
Diffstat (limited to 'src')
-rw-r--r-- | src/papod.go | 3066 |
1 files changed, 2312 insertions, 754 deletions
diff --git a/src/papod.go b/src/papod.go index 4bc5b4d..710c87c 100644 --- a/src/papod.go +++ b/src/papod.go @@ -52,25 +52,29 @@ type queriesT struct{ userByUUID func(guuid.UUID) (userT, error) updateUser func(userT) error deleteUser func(guuid.UUID) error - addNetwork func(userT, newNetworkT) (networkT, error) + addNetwork func(userT, newNetworkT, guuid.UUID) (networkT, error) getNetwork func(userT, guuid.UUID) (networkT, error) networks func(userT, func(networkT) error) error - setNetwork func(userT, networkT) error - nipNetwork func(userT, guuid.UUID) error - addMember func(userT, networkT, newMemberT) (memberT, error) - showMember func(userT, guuid.UUID) (memberT, error) - members func(userT, guuid.UUID, func(memberT) error) error - editMember func(userT, memberT) error - dropMember func(userT, guuid.UUID) error - addChannel func(guuid.UUID, newChannelT) (channelT, error) - channels func(guuid.UUID, func(channelT) error) error - topic func(channelT) error - endChannel func(guuid.UUID) error - join func(guuid.UUID, guuid.UUID) error - part func(guuid.UUID, guuid.UUID) error - names func(guuid.UUID, func(memberT) error) error + setNetwork func(memberT, networkT) error + nipNetwork func(memberT) error + membership func(userT, networkT) (memberT, error) + addMember func(memberT, newMemberT) (memberT, error) + addRole func(memberT, string, memberT) error + dropRole func(memberT, string, memberT) error + showMember func(memberT, guuid.UUID) (memberT, error) + members func(memberT, func(memberT) error) error + editMember func(memberT, memberT) error + dropMember func(memberT, guuid.UUID) error + addChannel func(memberT, newChannelT) (channelT, error) + chanByName func(memberT, string) (channelT, error) + channels func(memberT, func(channelT) error) error + setChannel func(memberT, channelT) error + endChannel func(memberT, guuid.UUID) error + join func(memberT, guuid.UUID) error + part func(memberT, channelT) error + names func(memberT, guuid.UUID, func(memberT) error) error addEvent func(newEventT) (eventT, error) - allAfter func(guuid.UUID, func(eventT) error) error + allAfter func(memberT, guuid.UUID, func(eventT) error) error logMessage func(userT, messageT) error close func() error } @@ -97,6 +101,13 @@ const ( NetworkType_Unlisted NetworkType = "unlisted" ) +type MemberStatus string +const ( + MemberStatus_Active MemberStatus = "active" + MemberStatus_Inactive MemberStatus = "inactive" + MemberStatus_Removed MemberStatus = "removed" +) + type newNetworkT struct{ uuid guuid.UUID name string @@ -108,28 +119,33 @@ type networkT struct{ id int64 timestamp time.Time uuid guuid.UUID - createdBy guuid.UUID name string description string type_ NetworkType } type newMemberT struct{ - userID guuid.UUID + userID guuid.UUID + memberID guuid.UUID + username string } type memberT struct{ - id int64 - timestamp time.Time - uuid guuid.UUID + id int64 + timestamp time.Time + uuid guuid.UUID + username string + displayName string + pictureID *guuid.UUID + status MemberStatus + roles []string } type newChannelT struct{ id int64 timestamp time.Time uuid guuid.UUID - // networkID guuid.UUID FIXME - publicName string + publicName *string label string description string virtual bool @@ -139,48 +155,65 @@ type channelT struct{ id int64 timestamp time.Time uuid guuid.UUID - // networkID guuid.UUID FIXME - publicName string + publicName *string label string description string virtual bool } +type SourceType string +const ( + SourceType_Logon SourceType = "logon" +) + +type sourceT struct{ + uuid guuid.UUID + type_ SourceType + metadata *map[string]interface{} +} + +type EventType string +const ( + EventType_UserJoin EventType = "user-join" + EventType_UserMessage EventType = "user-message" +) + type newEventT struct{ - eventID guuid.UUID - channelID guuid.UUID - connectionID guuid.UUID - type_ string - payload string + eventID guuid.UUID + channelID guuid.UUID + source sourceT + type_ EventType + payload string + metadata *map[string]interface{} } type eventT struct{ - id int64 - timestamp time.Time - uuid guuid.UUID - channelID guuid.UUID - connectionID guuid.UUID - type_ string - payload string - previous *eventT - isFist bool + id int64 + timestamp time.Time + uuid guuid.UUID + channelID guuid.UUID + source sourceT + type_ EventType + payload string + metadata *map[string]interface{} } -type messageParamsT struct{ - middle []string - trailing string +type eventEntryT struct{ + event eventT + previous *eventT + isFirst bool } type messageT struct{ prefix string command string - params messageParamsT + params []string raw string } type replyT struct{ command string - params messageParamsT + params []string } type listenersT struct{ @@ -196,22 +229,16 @@ type consumerT struct{ handlerFn func(papodT) func(fiinha.Message) error } +type netConnI interface{ + Write(p []byte) (n int, err error) + Close() error +} + type connectionT struct{ - conn net.Conn uuid guuid.UUID user *userT -} - -type receiverT struct{ - send func(messageT) - close func() -} - -type receiversT struct{ - add func(receiverT) - remove func(receiverT) - get func(guuid.UUID) []receiverT - close func() + conn netConnI + send func(messageT) } type metricsT struct{ @@ -228,7 +255,7 @@ type papodT struct{ queries queriesT listeners listenersT consumers []consumerT - receivers receiversT + state stateT metrics metricsT // logger g.Logger } @@ -329,134 +356,271 @@ func inTx(db *sql.DB, fn func(*sql.Tx) error) error { /// treated only as opaque IDs. func createTablesSQL(prefix string) queryT { const tmpl_write = ` - -- FIXME: unconfirmed premise: statements within a trigger are - -- part of the transaction that caused it, and so are - -- atomic. + -- TODO: unconfirmed premise: statements within a trigger are + -- part of the transaction that caused it, and so are + -- atomic. -- See also: -- https://stackoverflow.com/questions/77441888/ -- https://stackoverflow.com/questions/30511116/ CREATE TABLE IF NOT EXISTS "%s_users" ( id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, - timestamp TEXT NOT NULL DEFAULT (%s), + timestamp TEXT NOT NULL DEFAULT ( + %s + ), -- provided by cracha - uuid BLOB NOT NULL UNIQUE, + user_uuid BLOB NOT NULL UNIQUE, username TEXT NOT NULL, display_name TEXT NOT NULL, picture_uuid BLOB UNIQUE, deleted INT NOT NULL CHECK(deleted IN (0, 1)) ) STRICT; --- CREATE TABLE IF NOT EXISTS "%s_user_changes" ( --- id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, --- timestamp TEXT NOT NULL DEFAULT (%s), --- user_id INTEGER NOT NULL REFERENCES "%s_users"(id), --- attribute TEXT NOT NULL CHECK( --- attribute IN ( --- 'username', --- 'display_name', --- 'picture_uuid', --- 'deleted' --- ) --- ), --- value TEXT NOT NULL, --- op INT NOT NULL CHECK(op IN (0, 1)) --- ) STRICT; --- CREATE TRIGGER IF NOT EXISTS "%s_user_creation" --- AFTER INSERT ON "%s_users" --- BEGIN --- INSERT INTO "%s_user_changes" ( --- user_id, attribute, value, op --- ) VALUES --- (NEW.id, 'username', NEW.username, true), --- (NEW.id, 'display_name', NEW.display_name, true), --- (NEW.id, 'deleted', NEW.deleted, true) --- ; --- END; --- CREATE TRIGGER IF NOT EXISTS "%s_user_creation_picture_uuid" --- AFTER INSERT ON "%s_users" --- WHEN NEW.picture_uuid != NULL --- BEGIN --- INSERT INTO "%s_user_changes" ( --- user_id, attribute, value, op --- ) VALUES --- (NEW.id, 'picture_uuid', NEW.picture_uuid, true) --- ; --- END; --- CREATE TRIGGER IF NOT EXISTS "%s_user_update_username" --- AFTER UPDATE ON "%s_users" --- WHEN OLD.username != NEW.username --- BEGIN --- INSERT INTO "%s_user_changes" ( --- user_id, attribute, value, op --- ) VALUES --- (NEW.id, 'username', OLD.username, false), --- (NEW.id, 'username', NEW.username, true) --- ; --- END; --- CREATE TRIGGER IF NOT EXISTS "%s_user_update_display_name" --- AFTER UPDATE ON "%s_users" --- WHEN OLD.display_name != NEW.display_name --- BEGIN --- INSERT INTO "%s_user_changes" ( --- user_id, attribute, value, op --- ) VALUES --- (NEW.id, 'display_name', OLD.display_name, false), --- (NEW.id, 'display_name', NEW.display_name, true) --- ; --- END; --- CREATE TRIGGER IF NOT EXISTS "%s_user_update_picture_uuid" --- AFTER UPDATE ON "%s_users" --- WHEN OLD.picture_uuid != NEW.picture_uuid --- BEGIN --- INSERT INTO "%s_user_changes" ( --- user_id, attribute, value, op --- ) VALUES --- (NEW.id, 'picture_uuid', OLD.picture_uuid, false), --- (NEW.id, 'picture_uuid', NEW.picture_uuid, true) --- ; --- END; --- CREATE TRIGGER IF NOT EXISTS "%s_user_update_deleted" --- AFTER UPDATE ON "%s_users" --- WHEN OLD.deleted != NEW.deleted --- BEGIN --- INSERT INTO "%s_user_changes" ( --- user_id, attribute, value, op --- ) VALUES --- (NEW.id, 'deleted', OLD.deleted, false), --- (NEW.id, 'deleted', NEW.deleted, true) --- ; --- END; + CREATE TABLE IF NOT EXISTS "%s_user_changes" ( + id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, + timestamp TEXT NOT NULL DEFAULT ( + %s + ), + user_id INTEGER NOT NULL, + attribute TEXT NOT NULL CHECK( + attribute IN ( + 'username', + 'display_name', + 'picture_uuid', + 'deleted' + ) + ), + value_text TEXT, + value_blob BLOB, + value_bool INT CHECK(value_bool IN (0, 1)), + op INT NOT NULL CHECK(op IN (0, 1)) + ) STRICT; + CREATE TRIGGER IF NOT EXISTS "%s_user_new" + AFTER INSERT ON "%s_users" + BEGIN + INSERT INTO "%s_user_changes" ( + user_id, attribute, value_text, op + ) VALUES + (NEW.id, 'username', NEW.username, true), + (NEW.id, 'display_name', NEW.display_name, true) + ; + INSERT INTO "%s_user_changes" ( + user_id, attribute, value_bool, op + ) VALUES + (NEW.id, 'deleted', NEW.deleted, true) + ; + END; + CREATE TRIGGER IF NOT EXISTS "%s_user_new_picture_uuid" + AFTER INSERT ON "%s_users" + WHEN NEW.picture_uuid IS NOT NULL + BEGIN + INSERT INTO "%s_user_changes" ( + user_id, attribute, value_blob, op + ) VALUES + (NEW.id, 'picture_uuid', NEW.picture_uuid, true) + ; + END; + CREATE TRIGGER IF NOT EXISTS "%s_user_update_username" + AFTER UPDATE ON "%s_users" + WHEN OLD.username != NEW.username + BEGIN + INSERT INTO "%s_user_changes" ( + user_id, attribute, value_text, op + ) VALUES + (NEW.id, 'username', OLD.username, false), + (NEW.id, 'username', NEW.username, true) + ; + END; + CREATE TRIGGER IF NOT EXISTS "%s_user_update_display_name" + AFTER UPDATE ON "%s_users" + WHEN OLD.display_name != NEW.display_name + BEGIN + INSERT INTO "%s_user_changes" ( + user_id, attribute, value_text, op + ) VALUES + (NEW.id, 'display_name', OLD.display_name, false), + (NEW.id, 'display_name', NEW.display_name, true) + ; + END; + CREATE TRIGGER IF NOT EXISTS "%s_user_add_picture_uuid" + AFTER UPDATE ON "%s_users" + WHEN ( + OLD.picture_uuid IS NULL AND + NEW.picture_uuid IS NOT NULL + ) + BEGIN + INSERT INTO "%s_user_changes" ( + user_id, attribute, value_blob, op + ) VALUES + (NEW.id, 'picture_uuid', NEW.picture_uuid, true) + ; + END; + CREATE TRIGGER IF NOT EXISTS "%s_user_remove_picture_uuid" + AFTER UPDATE ON "%s_users" + WHEN ( + OLD.picture_uuid IS NOT NULL AND + NEW.picture_uuid IS NULL + ) + BEGIN + INSERT INTO "%s_user_changes" ( + user_id, attribute, value_blob, op + ) VALUES + (NEW.id, 'picture_uuid', OLD.picture_uuid, false) + ; + END; + CREATE TRIGGER IF NOT EXISTS "%s_user_update_picture_uuid" + AFTER UPDATE ON "%s_users" + WHEN ( + OLD.picture_uuid IS NOT NULL AND + NEW.picture_uuid IS NOT NULL AND + OLD.picture_uuid != NEW.picture_uuid + ) + BEGIN + INSERT INTO "%s_user_changes" ( + user_id, attribute, value_blob, op + ) VALUES + (NEW.id, 'picture_uuid', OLD.picture_uuid, false), + (NEW.id, 'picture_uuid', NEW.picture_uuid, true) + ; + END; + CREATE TRIGGER IF NOT EXISTS "%s_user_update_deleted" + AFTER UPDATE ON "%s_users" + WHEN OLD.deleted != NEW.deleted + BEGIN + INSERT INTO "%s_user_changes" ( + user_id, attribute, value_bool, op + ) VALUES + (NEW.id, 'deleted', OLD.deleted, false), + (NEW.id, 'deleted', NEW.deleted, true) + ; + END; + + CREATE TABLE IF NOT EXISTS "%s_sessions" ( + id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, + timestamp TEXT NOT NULL DEFAULT ( + %s + ), + -- provided by cracha + session_uuid BLOB NOT NULL UNIQUE, + user_id INTEGER NOT NULL + REFERENCES "%s_users"(id), + finished_at TEXT + ); + CREATE TABLE IF NOT EXISTS "%s_connections" ( + id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, + timestamp TEXT NOT NULL DEFAULT ( + %s + ), + uuid BLOB NOT NULL UNIQUE, + finished_at TEXT + ); + CREATE TABLE IF NOT EXISTS "%s_logons" ( + id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, + timestamp TEXT NOT NULL DEFAULT ( + %s + ), + session_id INTEGER NOT NULL + REFERENCES "%s_sessions"(id), + connection_id INTEGER NOT NULL + REFERENCES "%s_connections"(id), + UNIQUE (session_id, connection_id) + ) STRICT; CREATE TABLE IF NOT EXISTS "%s_networks" ( id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, - timestamp TEXT NOT NULL DEFAULT (%s), + timestamp TEXT NOT NULL DEFAULT ( + %s + ), uuid BLOB NOT NULL UNIQUE, - creator_id INTEGER NOT NULL REFERENCES "%s_users"(id), name TEXT NOT NULL, description TEXT NOT NULL, type TEXT NOT NULL CHECK( type IN ('public', 'private', 'unlisted') - ) + ), + deleted INT NOT NULL CHECK(deleted IN (0, 1)) ) STRICT; + CREATE INDEX IF NOT EXISTS "%s_networks_type" + ON "%s_networks"(type); CREATE TABLE IF NOT EXISTS "%s_network_changes" ( id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, - timestamp TEXT NOT NULL DEFAULT (%s), - network_id INTEGER NOT NULL - REFERENCES "%s_networks"(id), + timestamp TEXT NOT NULL DEFAULT ( + %s + ), + network_id INTEGER NOT NULL, attribute TEXT NOT NULL CHECK( attribute IN ( 'name', 'description', - 'type' + 'type', + 'deleted', + 'logon_id' -- FIXME ) ), value TEXT NOT NULL, op INT NOT NULL CHECK(op IN (0, 1)) ) STRICT; + CREATE TRIGGER IF NOT EXISTS "%s_network_new" + AFTER INSERT ON "%s_networks" + BEGIN + INSERT INTO "%s_network_changes" ( + network_id, attribute, value, op + ) VALUES + (NEW.id, 'name', NEW.name, true), + (NEW.id, 'description', NEW.description, true), + (NEW.id, 'type', NEW.type, true), + (NEW.id, 'deleted', NEW.deleted, true) + ; + END; + CREATE TRIGGER IF NOT EXISTS "%s_network_update_name" + AFTER UPDATE ON "%s_networks" + WHEN OLD.name != NEW.name + BEGIN + INSERT INTO "%s_network_changes" ( + network_id, attribute, value, op + ) VALUES + (NEW.id, 'name', OLD.name, false), + (NEW.id, 'name', NEW.name, true) + ; + END; + CREATE TRIGGER IF NOT EXISTS "%s_network_update_description" + AFTER UPDATE ON "%s_networks" + WHEN OLD.description != NEW.description + BEGIN + INSERT INTO "%s_network_changes" ( + network_id, attribute, value, op + ) VALUES + (NEW.id, 'description', OLD.description, false), + (NEW.id, 'description', NEW.description, true) + ; + END; + CREATE TRIGGER IF NOT EXISTS "%s_network_update_type" + AFTER UPDATE ON "%s_networks" + WHEN OLD.description != NEW.description + BEGIN + INSERT INTO "%s_network_changes" ( + network_id, attribute, value, op + ) VALUES + (NEW.id, 'type', OLD.type, false), + (NEW.id, 'type', NEW.type, true) + ; + END; + CREATE TRIGGER IF NOT EXISTS "%s_network_update_deleted" + AFTER UPDATE ON "%s_networks" + WHEN OLD.deleted != NEW.deleted + BEGIN + INSERT INTO "%s_network_changes" ( + network_id, attribute, value, op + ) VALUES + (NEW.id, 'deleted', OLD.deleted, false), + (NEW.id, 'deleted', NEW.deleted, true) + ; + END; CREATE TABLE IF NOT EXISTS "%s_members" ( id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, - timestamp TEXT NOT NULL DEFAULT (%s), + timestamp TEXT NOT NULL DEFAULT ( + %s + ), + uuid BLOB NOT NULL UNIQUE, network_id INTEGER NOT NULL REFERENCES "%s_networks"(id), user_id INTEGER NOT NULL, @@ -472,6 +636,120 @@ func createTablesSQL(prefix string) queryT { UNIQUE (network_id, username, active_uniq), UNIQUE (network_id, user_id) ) STRICT; + CREATE TABLE IF NOT EXISTS "%s_member_changes" ( + id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, + timestamp TEXT NOT NULL DEFAULT ( + %s + ), + member_id INTEGER NOT NULL, + attribute TEXT NOT NULL CHECK( + attribute IN ( + 'username', + 'display_name', + 'picture_uuid', + 'status', + 'logon_id' -- FIXME + ) + ), + value_text TEXT, + value_blob BLOB, + op INT NOT NULL CHECK(op IN (0, 1)) + ) STRICT; + CREATE TRIGGER IF NOT EXISTS "%s_member_new" + AFTER INSERT ON "%s_members" + BEGIN + INSERT INTO "%s_member_changes" ( + member_id, attribute, value_text, op + ) VALUES + (NEW.id, 'username', NEW.username, true), + (NEW.id, 'display_name', NEW.display_name, true), + (NEW.id, 'status', NEW.status, true) + ; + END; + CREATE TRIGGER IF NOT EXISTS "%s_member_new_picture_uuid" + AFTER INSERT ON "%s_members" + WHEN NEW.picture_uuid IS NOT NULL + BEGIN + INSERT INTO "%s_member_changes" ( + member_id, attribute, value_blob, op + ) VALUES + (NEW.id, 'picture_uuid', NEW.picture_uuid, true) + ; + END; + CREATE TRIGGER IF NOT EXISTS "%s_member_update_username" + AFTER UPDATE ON "%s_members" + WHEN OLD.username != NEW.username + BEGIN + INSERT INTO "%s_member_changes" ( + member_id, attribute, value_text, op + ) VALUES + (NEW.id, 'username', OLD.username, false), + (NEW.id, 'username', NEW.username, true) + ; + END; + CREATE TRIGGER IF NOT EXISTS "%s_member_update_display_name" + AFTER UPDATE ON "%s_members" + WHEN OLD.display_name != NEW.display_name + BEGIN + INSERT INTO "%s_member_changes" ( + member_id, attribute, value_text, op + ) VALUES + (NEW.id, 'display_name', OLD.display_name, false), + (NEW.id, 'display_name', NEW.display_name, true) + ; + END; + CREATE TRIGGER IF NOT EXISTS "%s_member_update_status" + AFTER UPDATE ON "%s_members" + WHEN OLD.status != NEW.status + BEGIN + INSERT INTO "%s_member_changes" ( + member_id, attribute, value_text, op + ) VALUES + (NEW.id, 'status', OLD.status, false), + (NEW.id, 'status', NEW.status, true) + ; + END; + CREATE TRIGGER IF NOT EXISTS "%s_member_add_picture_uuid" + AFTER UPDATE ON "%s_members" + WHEN ( + OLD.picture_uuid IS NULL AND + NEW.picture_uuid IS NOT NULL + ) + BEGIN + INSERT INTO "%s_member_changes" ( + member_id, attribute, value_blob, op + ) VALUES + (NEW.id, 'picture_uuid', NEW.picture_uuid, true) + ; + END; + CREATE TRIGGER IF NOT EXISTS "%s_member_remove_picture_uuid" + AFTER UPDATE ON "%s_members" + WHEN ( + OLD.picture_uuid IS NOT NULL AND + NEW.picture_uuid IS NULL + ) + BEGIN + INSERT INTO "%s_member_changes" ( + member_id, attribute, value_blob, op + ) VALUES + (NEW.id, 'picture_uuid', OLD.picture_uuid, false) + ; + END; + CREATE TRIGGER IF NOT EXISTS "%s_member_update_picture_uuid" + AFTER UPDATE ON "%s_members" + WHEN ( + OLD.picture_uuid IS NOT NULL AND + NEW.picture_uuid IS NOT NULL AND + OLD.picture_uuid != NEW.picture_uuid + ) + BEGIN + INSERT INTO "%s_member_changes" ( + member_id, attribute, value_blob, op + ) VALUES + (NEW.id, 'picture_uuid', OLD.picture_uuid, false), + (NEW.id, 'picture_uuid', NEW.picture_uuid, true) + ; + END; CREATE TABLE IF NOT EXISTS "%s_member_roles" ( id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, @@ -480,40 +758,157 @@ func createTablesSQL(prefix string) queryT { role TEXT NOT NULL, UNIQUE (member_id, role) ) STRICT; - - -- FIXME: use a trigger - CREATE TABLE IF NOT EXISTS "%s_member_changes" ( + CREATE TABLE IF NOT EXISTS "%s_member_role_changes" ( id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, - timestamp TEXT NOT NULL DEFAULT (%s), - member_id INTEGER NOT NULL - REFERENCES "%s_members"(id), - attribute TEXT NOT NULL, + timestamp TEXT NOT NULL DEFAULT ( + %s + ), + role_id INTEGER NOT NULL, + attribute TEXT NOT NULL CHECK( + attribute IN ( + 'role', + 'logon_id' -- FIXME + ) + ), value TEXT NOT NULL, op INT NOT NULL CHECK(op IN (0, 1)) ) STRICT; + CREATE TRIGGER IF NOT EXISTS "%s_member_role_add" + AFTER INSERT ON "%s_member_roles" + BEGIN + INSERT INTO "%s_member_role_changes" ( + role_id, attribute, value, op + ) VALUES + (NEW.id, 'role', NEW.role, true) + ; + END; + CREATE TRIGGER IF NOT EXISTS "%s_member_role_remove" + AFTER DELETE ON "%s_member_roles" + BEGIN + INSERT INTO "%s_member_role_changes" ( + role_id, attribute, value, op + ) VALUES + (OLD.id, 'role', OLD.role, false) + ; + END; CREATE TABLE IF NOT EXISTS "%s_channels" ( id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, - timestamp TEXT NOT NULL DEFAULT (%s), + timestamp TEXT NOT NULL DEFAULT ( + %s + ), uuid BLOB NOT NULL UNIQUE, - network_id INTEGER -- FIXME NOT NULL + network_id INTEGER NOT NULL REFERENCES "%s_networks"(id), - public_name TEXT UNIQUE, + public_name TEXT, label TEXT NOT NULL, description TEXT NOT NULL, - virtual INT NOT NULL CHECK(virtual IN (0, 1)) + virtual INT NOT NULL CHECK(virtual IN (0, 1)), + UNIQUE (network_id, public_name) ) STRICT; - - -- FIXME: use a trigger CREATE TABLE IF NOT EXISTS "%s_channel_changes" ( id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, - timestamp TEXT NOT NULL DEFAULT (%s), - channel_id INTEGER NOT NULL - REFERENCES "%s_channels"(id), - attribute TEXT NOT NULL, - value TEXT NOT NULL, + timestamp TEXT NOT NULL DEFAULT ( + %s + ), + channel_id INTEGER NOT NULL, + attribute TEXT NOT NULL CHECK( + attribute IN ( + 'public_name', + 'label', + 'description', + 'virtual', + 'logon_id' -- FIXME + ) + ), + value_text TEXT, + value_bool INT CHECK(value_bool IN (0, 1)), op INT NOT NULL CHECK(op IN (0, 1)) ) STRICT; + CREATE TRIGGER IF NOT EXISTS "%s_channel_new" + AFTER INSERT ON "%s_channels" + BEGIN + INSERT INTO "%s_channel_changes" ( + channel_id, attribute, value_text, op + ) VALUES + (NEW.id, 'label', NEW.label, true), + (NEW.id, 'description', NEW.description, true) + ; + INSERT INTO "%s_channel_changes" ( + channel_id, attribute, value_bool, op + ) VALUES + (NEW.id, 'virtual', NEW.virtual, true) + ; + END; + CREATE TRIGGER IF NOT EXISTS "%s_channel_new_public_name" + AFTER INSERT ON "%s_channels" + WHEN NEW.public_name IS NOT NULL + BEGIN + INSERT INTO "%s_channel_changes" ( + channel_id, attribute, value_text, op + ) VALUES + (NEW.id, 'public_name', NEW.public_name, true) + ; + END; + CREATE TRIGGER IF NOT EXISTS "%s_channel_update_label" + AFTER UPDATE ON "%s_channels" + WHEN OLD.label != NEW.label + BEGIN + INSERT INTO "%s_channel_changes" ( + channel_id, attribute, value_text, op + ) VALUES + (NEW.id, 'label', OLD.label, false), + (NEW.id, 'label', NEW.label, true) + ; + END; + CREATE TRIGGER IF NOT EXISTS "%s_channel_update_description" + AFTER UPDATE ON "%s_channels" + WHEN OLD.description != NEW.description + BEGIN + INSERT INTO "%s_channel_changes" ( + channel_id, attribute, value_text, op + ) VALUES + (NEW.id, 'description', OLD.description, false), + (NEW.id, 'description', NEW.description, true) + ; + END; + CREATE TRIGGER IF NOT EXISTS "%s_channel_update_virtual" + AFTER UPDATE ON "%s_channels" + WHEN OLD.virtual != NEW.virtual + BEGIN + INSERT INTO "%s_channel_changes" ( + channel_id, attribute, value_bool, op + ) VALUES + (NEW.id, 'virtual', OLD.virtual, false), + (NEW.id, 'virtual', NEW.virtual, true) + ; + END; + CREATE TRIGGER IF NOT EXISTS "%s_channel_add_public_name" + AFTER UPDATE ON "%s_channels" + WHEN ( + OLD.public_name IS NULL AND + NEW.public_name IS NOT NULL + ) + BEGIN + INSERT INTO "%s_channel_changes" ( + channel_id, attribute, value_text, op + ) VALUES + (NEW.id, 'public_name', NEW.public_name, true) + ; + END; + CREATE TRIGGER IF NOT EXISTS "%s_channel_remove_public_name" + AFTER UPDATE ON "%s_channels" + WHEN ( + OLD.public_name IS NOT NULL AND + NEW.public_name IS NULL + ) + BEGIN + INSERT INTO "%s_channel_changes" ( + channel_id, attribute, value_text, op + ) VALUES + (OLD.id, 'public_name', OLD.public_name, false) + ; + END; CREATE TABLE IF NOT EXISTS "%s_participants" ( id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, @@ -523,29 +918,46 @@ func createTablesSQL(prefix string) queryT { REFERENCES "%s_members"(id), UNIQUE (channel_id, member_id) ) STRICT; + CREATE TABLE IF NOT EXISTS "%s_participant_changes" ( + id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, + timestamp TEXT NOT NULL DEFAULT ( + %s + ), + participant_id INTEGER NOT NULL, + attribute TEXT NOT NULL CHECK( + attribute IN ( + 'connection_id' + ) + ), + value TEXT NOT NULL, + op INT NOT NULL CHECK(op IN (0, 1)) + ) STRICT; - -- FIXME: create database table for connections? - -- A user can have multiple sessions (different browsers, - -- mobile, etc.), and each session has multiple connections, as - -- the user connects and disconnections using the same session - -- id, all while it is valid. - -- FIXME: can a connection have multiple sessions? A long-lived - -- connection that spans multiple sessions would fit into this. CREATE TABLE IF NOT EXISTS "%s_channel_events" ( id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, - timestamp TEXT NOT NULL DEFAULT (%s), + timestamp TEXT NOT NULL DEFAULT ( + %s + ), uuid BLOB NOT NULL UNIQUE, channel_id INTEGER NOT NULL REFERENCES "%s_channels"(id), - connection_uuid BLOB NOT NULL, -- FIXME: join + source_uuid BLOB NOT NULL, + source_type TEXT NOT NULL CHECK( + source_type IN ( + 'logon' + ) + ), + source_metadata TEXT, type TEXT NOT NULL CHECK( type IN ( 'user-join', 'user-message' ) ), - payload TEXT NOT NULL + payload TEXT NOT NULL, + metadata TEXT ) STRICT; + ` return queryT{ write: fmt.Sprintf( @@ -574,29 +986,113 @@ func createTablesSQL(prefix string) queryT { prefix, prefix, prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + g.SQLiteNow, + prefix, + prefix, + g.SQLiteNow, + prefix, g.SQLiteNow, prefix, prefix, + prefix, g.SQLiteNow, prefix, prefix, + prefix, g.SQLiteNow, prefix, prefix, prefix, prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, g.SQLiteNow, prefix, prefix, g.SQLiteNow, prefix, prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, g.SQLiteNow, prefix, prefix, prefix, prefix, prefix, + prefix, + prefix, + g.SQLiteNow, + prefix, + prefix, + g.SQLiteNow, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + g.SQLiteNow, + prefix, g.SQLiteNow, prefix, ), @@ -612,10 +1108,48 @@ func createTables(db *sql.DB, prefix string) error { }) } +func memberRolesSQL(prefix string) queryT { + const tmpl_read = ` + SELECT role FROM "%s_member_roles" + JOIN "%s_members" ON + "%s_member_roles".member_id = "%s_members".id + WHERE "%s_members".uuid = ? + ORDER BY "%s_member_roles".id; + ` + return queryT{ + read: fmt.Sprintf( + tmpl_read, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + ), + } +} + +func collectRoles(rows *sql.Rows) ([]string, error) { + roles := []string{} + + for rows.Next() { + var role string + err := rows.Scan(&role) + if err != nil { + rows.Close() + return nil, err + } + + roles = append(roles, role) + } + + return roles, g.WrapErrors(rows.Err(), rows.Close()) +} + func createUserSQL(prefix string) queryT { const tmpl_write = ` INSERT INTO "%s_users" ( - uuid, username, display_name, picture_uuid, deleted + user_uuid, username, display_name, picture_uuid, deleted ) VALUES ( ?, ?, ?, NULL, false ) RETURNING id, timestamp; @@ -673,8 +1207,8 @@ func userByUUIDSQL(prefix string) queryT { picture_uuid FROM "%s_users" WHERE - uuid = ? AND - deleted = false; + user_uuid = ? AND + deleted = false; ` return queryT{ read: fmt.Sprintf(tmpl_read, prefix), @@ -775,8 +1309,8 @@ func deleteUserSQL(prefix string) queryT { UPDATE "%s_users" SET deleted = true WHERE - uuid = ? AND - deleted = false + user_uuid = ? AND + deleted = false RETURNING id; ` return queryT{ @@ -805,46 +1339,86 @@ func deleteUserStmt( func addNetworkSQL(prefix string) queryT { const tmpl_write = ` INSERT INTO "%s_networks" ( - uuid, name, description, type, creator_id + uuid, name, description, type, deleted ) VALUES ( ?, ?, ?, ?, - ( - SELECT id FROM "%s_users" - WHERE id = ? AND deleted = false - ) - ) RETURNING id, timestamp; - + false + ) RETURNING id; + + WITH creator AS ( + SELECT username, display_name, picture_uuid + FROM "%s_users" + WHERE id = ? AND deleted = false + ), new_network AS ( + SELECT id FROM "%s_networks" WHERE uuid = ? + ) INSERT INTO "%s_members" ( - network_id, user_id, username, display_name, + uuid, network_id, user_id, username, display_name, picture_uuid, status, active_uniq ) VALUES ( - last_insert_rowid(), ?, - ( - SELECT username, display_name, picture_uuid - FROM "%s_users" - WHERE id = ? AND deleted = false - ), + (SELECT id FROM new_network), + ?, + (SELECT username FROM creator), + (SELECT display_name FROM creator), + (SELECT picture_uuid FROM creator), 'active', 'active' - ) RETURNING id, timestamp; + ) RETURNING id; + + WITH new_member AS ( + SELECT id FROM "%s_members" WHERE uuid = ? + ) + INSERT INTO "%s_member_roles" (member_id, role) + VALUES ( + (SELECT id FROM new_member), + 'admin' + ), + ( + (SELECT id FROM new_member), + 'creator' + ) + RETURNING id; + ` + const tmpl_read = ` + SELECT id, timestamp FROM "%s_networks" + WHERE uuid = ? AND deleted = false; ` return queryT{ - write: fmt.Sprintf(tmpl_write, prefix, prefix, prefix, prefix), + write: fmt.Sprintf( + tmpl_write, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + ), + read: fmt.Sprintf(tmpl_read, prefix), } } func addNetworkStmt( cfg dbconfigT, -) (func(userT, newNetworkT) (networkT, error), func() error, error) { +) ( + func(userT, newNetworkT, guuid.UUID) (networkT, error), + func() error, + error, +) { q := addNetworkSQL(cfg.prefix) + readStmt, err := cfg.shared.Prepare(q.read) + if err != nil { + return nil, nil, err + } + privateDB, err := sql.Open(golite.DriverName, cfg.dbpath) if err != nil { + readStmt.Close() return nil, nil, err } @@ -853,10 +1427,10 @@ func addNetworkStmt( fn := func( user userT, newNetwork newNetworkT, + memberID guuid.UUID, ) (networkT, error) { network := networkT{ uuid: newNetwork.uuid, - createdBy: user.uuid, name: newNetwork.name, description: newNetwork.description, type_: newNetwork.type_, @@ -868,87 +1442,35 @@ func addNetworkStmt( newNetwork.description, newNetwork.type_, user.id, + newNetwork.uuid[:], + memberID[:], + user.id, + memberID[:], ) if err != nil { return networkT{}, err } - return network, nil - - /* - member := memberT{ - } - var timestr string - { - FIXME - rows, err := writeStmt.Query( - newNetwork.uuid[:], - newNetwork.name, - newNetwork.description, - newNetwork.type_, - user.id, - ) - if err != nil { - return networkT{}, err - } - defer rows.Close() - - { - if !rows.Next() { - return networkT{}, sql.ErrNoRows - } - - err := rows.Scan(&network.id, ×tr) - if err != nil { - return networkT{}, err - } - - network.timestamp, err = time.Parse( - time.RFC3339Nano, - timestr, - ) - if err != nil { - return networkT{}, err - } - } - - { - if !rows.Next() { - return networkT{}, sql.ErrNoRows - } - - err := rows.Scan(&member.id, ×tr) - if err != nil { - return networkT{}, err - } - - member.timestamp, err = time.Parse( - time.RFC3339Nano, - timestr, - ) - if err != nil { - return networkT{}, err - } - } - - { - if rows.Next() { - return networkT{}, errors.New("FIXME") - } - err := rows.Err() + err = readStmt.QueryRow(network.uuid[:]).Scan( + &network.id, + ×tr, + ) + if err != nil { + return networkT{}, err + } - if err != nil { - return networkT{}, err - } - } + network.timestamp, err = time.Parse(time.RFC3339Nano, timestr) + if err != nil { + return networkT{}, err } - */ + + return network, nil } closeFn := func() error { writeFnClose() - return privateDB.Close() + return g.SomeError(privateDB.Close(), readStmt.Close()) } return fn, closeFn, nil @@ -956,29 +1478,32 @@ func addNetworkStmt( func getNetworkSQL(prefix string) queryT { const tmpl_read = ` + WITH probing_user AS ( + SELECT id FROM "%s_users" + WHERE id = ? AND deleted = false + ), target_network AS ( + SELECT id FROM "%s_networks" + WHERE uuid = ? AND deleted = false + ) SELECT - "%s_networks".id, - "%s_networks".timestamp, - "%s_users".uuid, - "%s_networks".name, - "%s_networks".description, - "%s_networks".type + id, + timestamp, + name, + description, + type FROM "%s_networks" - JOIN "%s_users" ON - "%s_users".id = "%s_networks".creator_id WHERE - "%s_networks".uuid = $networkUUID AND - $userID IN ( - SELECT id FROM "%s_users" - WHERE id = $userID AND deleted = false - ) AND + uuid = ? AND + deleted = false AND + ? IN probing_user AND ( - "%s_networks".type IN ('public', 'unlisted') OR - $userID IN ( + type IN ('public', 'unlisted') OR + ? IN ( SELECT user_id FROM "%s_members" WHERE - user_id = $userID AND - network_id = "%s_networks".id + user_id = ? AND + network_id IN target_network AND + status != 'removed' ) ); ` @@ -1000,6 +1525,13 @@ func getNetworkSQL(prefix string) queryT { prefix, prefix, prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, ), } } @@ -1019,14 +1551,17 @@ func getNetworkStmt( uuid: networkID, } - var ( - timestr string - creator_id_bytes []byte - ) - err := readStmt.QueryRow(networkID[:], user.id).Scan( + var timestr string + err := readStmt.QueryRow( + user.id, + networkID[:], + networkID[:], + user.id, + user.id, + user.id, + ).Scan( &network.id, ×tr, - &creator_id_bytes, &network.name, &network.description, &network.type_, @@ -1034,7 +1569,6 @@ func getNetworkStmt( if err != nil { return networkT{}, err } - network.createdBy = guuid.UUID(creator_id_bytes) network.timestamp, err = time.Parse(time.RFC3339Nano, timestr) if err != nil { @@ -1057,17 +1591,26 @@ func networkEach(rows *sql.Rows, callback func(networkT) error) error { network networkT timestr string network_id_bytes []byte + deleted bool ) err := rows.Scan( &network.id, ×tr, &network_id_bytes, + &network.name, + &network.description, + &network.type_, + &deleted, ) if err != nil { return g.WrapErrors(rows.Close(), err) } network.uuid = guuid.UUID(network_id_bytes) + if deleted { + return sql.ErrNoRows + } + network.timestamp, err = time.Parse(time.RFC3339Nano, timestr) if err != nil { return g.WrapErrors(rows.Close(), err) @@ -1084,10 +1627,49 @@ func networkEach(rows *sql.Rows, callback func(networkT) error) error { func networksSQL(prefix string) queryT { const tmpl_read = ` - -- FIXME %s + WITH current_user AS ( + SELECT id, deleted FROM "%s_users" WHERE id = ? + ) + SELECT + "%s_networks".id, + "%s_networks".timestamp, + "%s_networks".uuid, + "%s_networks".name, + "%s_networks".description, + "%s_networks".type, + (SELECT deleted FROM current_user) + FROM "%s_networks" + JOIN "%s_members" ON + "%s_networks".id = "%s_members".network_id + WHERE ( + "%s_networks".type = 'public' OR + "%s_networks".id IN ( + SELECT network_id FROM "%s_members" + WHERE user_id IN (SELECT id FROM current_user) + ) + ) AND "%s_networks".deleted = false + ORDER BY "%s_networks".id; ` return queryT{ - read: fmt.Sprintf(tmpl_read, prefix), + read: fmt.Sprintf( + tmpl_read, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + ), } } @@ -1102,7 +1684,7 @@ func networksStmt( } fn := func(user userT) (*sql.Rows, error) { - return readStmt.Query(user.uuid[:]) + return readStmt.Query(user.id) } return fn, readStmt.Close, nil @@ -1110,16 +1692,42 @@ func networksStmt( func setNetworkSQL(prefix string) queryT { const tmpl_write = ` - -- FIXME %s + UPDATE "%s_networks" + SET + name = ?, + description = ?, + type = ? + WHERE id = ? AND deleted = false + RETURNING ( + SELECT CASE WHEN EXISTS ( + SELECT role from "%s_member_roles" + WHERE + member_id = ? AND + role IN ( + 'admin', + 'network-settings-update' + ) AND ? IN ( + SELECT network_id + FROM "%s_members" + WHERE + id = ? AND + status = 'active' + ) + ) THEN true ELSE RAISE( + ABORT, + 'member not allowed to update network data' + ) END + ); + ` return queryT{ - write: fmt.Sprintf(tmpl_write, prefix), + write: fmt.Sprintf(tmpl_write, prefix, prefix, prefix), } } func setNetworkStmt( cfg dbconfigT, -) (func(userT, networkT) error, func() error, error) { +) (func(memberT, networkT) error, func() error, error) { q := setNetworkSQL(cfg.prefix) writeStmt, err := cfg.shared.Prepare(q.write) @@ -1127,9 +1735,17 @@ func setNetworkStmt( return nil, nil, err } - fn := func(user userT, network networkT) error { - _, err := writeStmt.Exec(network) - return err + fn := func(actor memberT, network networkT) error { + var _allowed bool + return writeStmt.QueryRow( + network.name, + network.description, + network.type_, + network.id, + actor.id, + network.id, + actor.id, + ).Scan(&_allowed) } return fn, writeStmt.Close, nil @@ -1137,16 +1753,38 @@ func setNetworkStmt( func nipNetworkSQL(prefix string) queryT { const tmpl_write = ` - -- FIXME %s + WITH target_network AS ( + SELECT network_id AS id + FROM "%s_members" + WHERE + id = ? AND + status = 'active' + ) + UPDATE "%s_networks" + SET deleted = true + WHERE id IN target_network AND deleted = false + RETURNING ( + SELECT CASE WHEN EXISTS ( + SELECT role FROM "%s_member_roles" + WHERE + member_id = ? AND + role IN ( + 'admin' + ) + ) THEN true ELSE RAISE( + ABORT, + 'member not allowed to delete network' + ) END + ); ` return queryT{ - write: fmt.Sprintf(tmpl_write, prefix), + write: fmt.Sprintf(tmpl_write, prefix, prefix, prefix), } } func nipNetworkStmt( cfg dbconfigT, -) (func(userT, guuid.UUID) error, func() error, error) { +) (func(memberT) error, func() error, error) { q := nipNetworkSQL(cfg.prefix) writeStmt, err := cfg.shared.Prepare(q.write) @@ -1154,27 +1792,192 @@ func nipNetworkStmt( return nil, nil, err } - fn := func(user userT, networkID guuid.UUID) error { - _, err := writeStmt.Exec(networkID[:]) - return err + fn := func(actor memberT) error { + var _allowed bool + return writeStmt.QueryRow(actor.id, actor.id).Scan(&_allowed) } return fn, writeStmt.Close, nil } +func membershipSQL(prefix string) queryT { + const tmpl_read = ` + SELECT + "%s_members".id, + "%s_members".timestamp, + "%s_members".uuid, + "%s_members".username, + "%s_members".display_name, + "%s_members".picture_uuid, + "%s_members".status + FROM "%s_members" + JOIN "%s_users" ON + "%s_users".id = "%s_members".user_id + JOIN "%s_networks" ON + "%s_networks".id = "%s_members".network_id + WHERE + "%s_members".user_id = ? AND + "%s_members".network_id = ? AND + "%s_members".status = 'active' AND + "%s_users".deleted = false AND + "%s_networks".deleted = false; + ` + return queryT{ + read: fmt.Sprintf( + tmpl_read, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + ), + } +} + +func membershipStmt( + cfg dbconfigT, +) (func(userT, networkT) (memberT, error), func() error, error) { + q := membershipSQL(cfg.prefix) + + readStmt, err := cfg.shared.Prepare(q.read) + if err != nil { + return nil, nil, err + } + + rolesStmt, err := cfg.shared.Prepare(memberRolesSQL(cfg.prefix).read) + if err != nil { + readStmt.Close() + return nil, nil, err + } + + fn := func(actor userT, network networkT) (memberT, error) { + member := memberT{} + + var ( + timestr string + member_id_bytes []byte + picture_id_bytes []byte + ) + err := readStmt.QueryRow(actor.id, network.id).Scan( + &member.id, + ×tr, + &member_id_bytes, + &member.username, + &member.displayName, + &picture_id_bytes, + &member.status, + ) + if err != nil { + return memberT{}, err + } + member.uuid = guuid.UUID(member_id_bytes) + + member.timestamp, err = time.Parse(time.RFC3339Nano, timestr) + if err != nil { + return memberT{}, err + } + + rows, err := rolesStmt.Query(member_id_bytes) + if err != nil { + return memberT{}, err + } + + member.roles, err = collectRoles(rows) + if err != nil { + return memberT{}, err + } + + return member, nil + } + + closeFn := func() error { + return g.SomeError( + readStmt.Close(), + rolesStmt.Close(), + ) + } + + return fn, closeFn, nil +} func addMemberSQL(prefix string) queryT { const tmpl_write = ` - -- FIXME %s + WITH target_user AS ( + SELECT id, username, display_name, picture_uuid + FROM "%s_users" + WHERE user_uuid = ? AND deleted = false + ), target_network AS ( + SELECT "%s_members".network_id AS id + FROM "%s_members" + JOIN "%s_networks" ON + "%s_members".network_id = "%s_networks".id + WHERE + "%s_members".id = ? AND + "%s_members".status = 'active' AND + "%s_networks".deleted = false + ) + INSERT INTO "%s_members" ( + uuid, network_id, user_id, username, display_name, + picture_uuid, status, active_uniq + ) VALUES ( + ?, + (SELECT id FROM target_network), + (SELECT id FROM target_user), + ?, + (SELECT display_name FROM target_user), + (SELECT picture_uuid FROM target_user), + 'active', + 'active' + ) RETURNING id, timestamp, display_name, picture_uuid, status, ( + SELECT CASE WHEN EXISTS ( + SELECT role from "%s_member_roles" + WHERE + member_id = ? AND + role IN ( + 'admin', + 'add-member' + ) + ) THEN true ELSE RAISE( + ABORT, + 'member not allowed to add another member' + ) END + ); ` return queryT{ - write: fmt.Sprintf(tmpl_write, prefix), + write: fmt.Sprintf( + tmpl_write, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + ), } } func addMemberStmt( cfg dbconfigT, -) (func(userT, networkT, newMemberT) (memberT, error), func() error, error) { +) (func(memberT, newMemberT) (memberT, error), func() error, error) { q := addMemberSQL(cfg.prefix) writeStmt, err := cfg.shared.Prepare(q.write) @@ -1182,46 +1985,164 @@ func addMemberStmt( return nil, nil, err } - fn := func( - user userT, - network networkT, - newMember newMemberT, - ) (memberT, error) { + rolesStmt, err := cfg.shared.Prepare(memberRolesSQL(cfg.prefix).read) + if err != nil { + writeStmt.Close() + return nil, nil, err + } + + fn := func(actor memberT, newMember newMemberT) (memberT, error) { member := memberT{ + uuid: newMember.memberID, + username: newMember.username, } - var timestr string - err := writeStmt.QueryRow(network.uuid[:], newMember).Scan( + var ( + timestr string + picture_id_bytes []byte + _allowed bool + ) + err := writeStmt.QueryRow( + newMember.userID[:], + actor.id, + newMember.memberID[:], + newMember.username, + actor.id, + ).Scan( &member.id, ×tr, + &member.displayName, + &picture_id_bytes, + &member.status, + &_allowed, ) if err != nil { return memberT{}, err } + if picture_id_bytes != nil { + pictureID := guuid.UUID(picture_id_bytes) + member.pictureID = &pictureID + } member.timestamp, err = time.Parse(time.RFC3339Nano, timestr) if err != nil { return memberT{}, err } + rows, err := rolesStmt.Query(member.uuid[:]) + if err != nil { + return memberT{}, err + } + + member.roles, err = collectRoles(rows) + if err != nil { + return memberT{}, err + } + return member, nil } + closeFn := func() error { + return g.SomeError( + writeStmt.Close(), + rolesStmt.Close(), + ) + } + + return fn, closeFn, nil +} + +func addRoleSQL(prefix string) queryT { + const tmpl_write = ` + INSERT INTO "%s_member_roles" (member_id, role) + VALUES (?, ?); + ` + return queryT{ + write: fmt.Sprintf(tmpl_write, prefix), + } +} + +func addRoleStmt( + cfg dbconfigT, +) (func(memberT, string, memberT) error, func() error, error) { + q := addRoleSQL(cfg.prefix) + + writeStmt, err := cfg.shared.Prepare(q.write) + if err != nil { + return nil, nil, err + } + + fn := func(actor memberT, role string, member memberT) error { + // FIXME: do authorization + _, err := writeStmt.Exec(member.id, role) + return err + } + return fn, writeStmt.Close, nil } +func dropRoleSQL(prefix string) queryT { + const tmpl_write = ` + DELETE FROM "%s_member_roles" + WHERE + member_id = ? AND + role = ? + RETURNING 1; + ` + return queryT{ + write: fmt.Sprintf(tmpl_write, prefix), + } +} + +func dropRoleStmt( + cfg dbconfigT, +) (func(memberT, string, memberT) error, func() error, error) { + q := dropRoleSQL(cfg.prefix) + + writeStmt, err := cfg.shared.Prepare(q.write) + if err != nil { + return nil, nil, err + } + + fn := func(actor memberT, role string, member memberT) error { + // FIXME: do authorization + // _, err := writeStmt.Exec(member.id, role) + // return err + var _id int64 + return writeStmt.QueryRow(member.id, role).Scan(&_id) + } + + return fn, writeStmt.Close, nil +} + + func showMemberSQL(prefix string) queryT { const tmpl_read = ` - -- FIXME %s + WITH current_network AS ( + SELECT network_id + FROM "%s_members" + WHERE id = ? + ) + SELECT + id, + timestamp, + username, + display_name, + picture_uuid, + status + FROM "%s_members" + WHERE + uuid = ? AND + network_id IN current_network; ` return queryT{ - read: fmt.Sprintf(tmpl_read, prefix), + read: fmt.Sprintf(tmpl_read, prefix, prefix), } } func showMemberStmt( cfg dbconfigT, -) (func(userT, guuid.UUID) (memberT, error), func() error, error) { +) (func(memberT, guuid.UUID) (memberT, error), func() error, error) { q := showMemberSQL(cfg.prefix) readStmt, err := cfg.shared.Prepare(q.read) @@ -1229,29 +2150,64 @@ func showMemberStmt( return nil, nil, err } - fn := func(user userT, memberID guuid.UUID) (memberT, error) { + rolesStmt, err := cfg.shared.Prepare(memberRolesSQL(cfg.prefix).read) + if err != nil { + readStmt.Close() + return nil, nil, err + } + + fn := func(actor memberT, memberID guuid.UUID) (memberT, error) { member := memberT{ uuid: memberID, } - var timestr string - err := readStmt.QueryRow(memberID[:]).Scan( + var ( + timestr string + picture_id_bytes []byte + ) + err := readStmt.QueryRow(actor.id, memberID[:]).Scan( &member.id, ×tr, + &member.username, + &member.displayName, + &picture_id_bytes, + &member.status, ) if err != nil { return memberT{}, err } + if picture_id_bytes != nil { + pictureID := guuid.UUID(picture_id_bytes) + // FIXME: test this + member.pictureID = &pictureID + } member.timestamp, err = time.Parse(time.RFC3339Nano, timestr) if err != nil { return memberT{}, err } - return member, err + rows, err := rolesStmt.Query(memberID[:]) + if err != nil { + return memberT{}, err + } + + member.roles, err = collectRoles(rows) + if err != nil { + return memberT{}, err + } + + return member, nil } - return fn, readStmt.Close, nil + closeFn := func() error { + return g.SomeError( + readStmt.Close(), + rolesStmt.Close(), + ) + } + + return fn, closeFn, nil } func memberEach(rows *sql.Rows, callback func(memberT) error) error { @@ -1261,19 +2217,28 @@ func memberEach(rows *sql.Rows, callback func(memberT) error) error { for rows.Next() { var ( - member memberT - timestr string - member_id_bytes []byte + member memberT + timestr string + member_id_bytes []byte + picture_id_bytes []byte ) err := rows.Scan( &member.id, ×tr, &member_id_bytes, + &member.username, + &member.displayName, + &picture_id_bytes, + &member.status, ) if err != nil { return g.WrapErrors(rows.Close(), err) } - // member.uuid = guuid.UUID(member_id_bytes) FIXME + member.uuid = guuid.UUID(member_id_bytes) + if picture_id_bytes != nil { + pictureID := guuid.UUID(picture_id_bytes) + member.pictureID = &pictureID + } member.timestamp, err = time.Parse(time.RFC3339Nano, timestr) if err != nil { @@ -1291,16 +2256,46 @@ func memberEach(rows *sql.Rows, callback func(memberT) error) error { func membersSQL(prefix string) queryT { const tmpl_read = ` - -- FIXME %s + WITH target_network AS ( + SELECT "%s_members".network_id + FROM "%s_members" + JOIN "%s_networks" ON + "%s_members".network_id = "%s_networks".id + WHERE + "%s_members".id = ? AND + "%s_networks".deleted = false + ) + SELECT + id, + timestamp, + uuid, + username, + display_name, + picture_uuid, + status + FROM "%s_members" + WHERE + network_id IN target_network AND + status = 'active'; ` return queryT{ - read: fmt.Sprintf(tmpl_read, prefix), + read: fmt.Sprintf( + tmpl_read, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + prefix, + ), } } func membersStmt( cfg dbconfigT, -) (func(userT, guuid.UUID) (*sql.Rows, error), func() error, error) { +) (func(memberT) (*sql.Rows, error), func() error, error) { q := membersSQL(cfg.prefix) readStmt, err := cfg.shared.Prepare(q.read) @@ -1308,8 +2303,8 @@ func membersStmt( return nil, nil, err } - fn := func(user userT, networkID guuid.UUID) (*sql.Rows, error) { - return readStmt.Query(networkID[:]) + fn := func(actor memberT) (*sql.Rows, error) { + return readStmt.Query(actor.id) } return fn, readStmt.Close, nil @@ -1317,7 +2312,11 @@ func membersStmt( func editMemberSQL(prefix string) queryT { const tmpl_write = ` - -- FIXME %s + UPDATE "%s_members" + SET + status = ? + WHERE id = ? + RETURNING id; ` return queryT{ write: fmt.Sprintf(tmpl_write, prefix), @@ -1326,7 +2325,7 @@ func editMemberSQL(prefix string) queryT { func editMemberStmt( cfg dbconfigT, -) (func(userT, memberT) error, func() error, error) { +) (func(memberT, memberT) error, func() error, error) { q := editMemberSQL(cfg.prefix) writeStmt, err := cfg.shared.Prepare(q.write) @@ -1334,9 +2333,12 @@ func editMemberStmt( return nil, nil, err } - fn := func(user userT, member memberT) error { - _, err := writeStmt.Exec(member) - return err + fn := func(actor memberT, member memberT) error { + var _id int64 + return writeStmt.QueryRow( + member.status, + member.id, + ).Scan(&_id) } return fn, writeStmt.Close, nil @@ -1344,59 +2346,115 @@ func editMemberStmt( func dropMemberSQL(prefix string) queryT { const tmpl_write = ` - -- FIXME + UPDATE "%s_members" SET status = 'removed' + WHERE uuid = ? RETURNING id; + + DELETE FROM "%s_member_roles" + WHERE + role != 'creator' AND + member_id IN ( + SELECT id FROM "%s_members" + WHERE uuid = ? + ) ` return queryT{ - write: fmt.Sprintf(tmpl_write), + write: fmt.Sprintf(tmpl_write, prefix, prefix, prefix), } } func dropMemberStmt( cfg dbconfigT, -) (func(userT, guuid.UUID) error, func() error, error) { +) (func(memberT, guuid.UUID) error, func() error, error) { q := dropMemberSQL(cfg.prefix) - writeStmt, err := cfg.shared.Prepare(q.write) + privateDB, err := sql.Open(golite.DriverName, cfg.dbpath) if err != nil { return nil, nil, err } - fn := func(user userT, memberID guuid.UUID) error { - _, err := writeStmt.Exec(memberID[:]) - return err + writeFn, writeFnClose := execSerialized(q.write, privateDB) + + fn := func(actor memberT, memberID guuid.UUID) error { + err := writeFn(memberID[:], memberID[:]) + if err != nil { + return err + } + + // if res == 0 { // FIXME } + return nil } - return fn, writeStmt.Close, nil + closeFn := func() error { + writeFnClose() + return privateDB.Close() + } + + return fn, closeFn, nil } func addChannelSQL(prefix string) queryT { const tmpl_write = ` + WITH target_network AS ( + SELECT network_id AS id + FROM "%s_members" + WHERE id = ? + ) INSERT INTO "%s_channels" ( - uuid, public_name, label, description, virtual - ) VALUES (?, ?, ?, ?, ?) RETURNING id, timestamp; + uuid, + network_id, + public_name, + label, + description, + virtual + ) VALUES ( + ?, + (SELECT id FROM target_network), + ?, + ?, + ?, + ? + ) RETURNING id, timestamp; + + WITH new_channel AS ( + SELECT id FROM "%s_channels" WHERE uuid = ? + ) + INSERT INTO "%s_participants" (channel_id, member_id) + VALUES ( + (SELECT id FROM new_channel), + ? + ); + ` + const tmpl_read = ` + SELECT id, timestamp FROM "%s_channels" + WHERE uuid = ?; ` return queryT{ - write: fmt.Sprintf(tmpl_write, prefix), + write: fmt.Sprintf(tmpl_write, prefix, prefix, prefix, prefix), + read: fmt.Sprintf(tmpl_read, prefix), } } func addChannelStmt( cfg dbconfigT, -) (func (guuid.UUID, newChannelT) (channelT, error), func() error, error) { +) (func (memberT, newChannelT) (channelT, error), func() error, error) { q := addChannelSQL(cfg.prefix) - writeStmt, err := cfg.shared.Prepare(q.write) + readStmt, err := cfg.shared.Prepare(q.read) if err != nil { return nil, nil, err } - fn := func( - networkID guuid.UUID, - newChannel newChannelT, - ) (channelT, error) { + privateDB, err := sql.Open(golite.DriverName, cfg.dbpath) + if err != nil { + readStmt.Close() + return nil, nil, err + } + + writeFn, writeFnClose := execSerialized(q.write, privateDB) + + fn := func(actor memberT, newChannel newChannelT) (channelT, error) { channel := channelT{ uuid: newChannel.uuid, - // networkID[:], publicName: newChannel.publicName, label: newChannel.label, description: newChannel.description, @@ -1404,13 +2462,24 @@ func addChannelStmt( } var timestr string - err := writeStmt.QueryRow( + err := writeFn( + actor.id, newChannel.uuid[:], newChannel.publicName, newChannel.label, newChannel.description, newChannel.virtual, - ).Scan(&channel.id, ×tr) + newChannel.uuid[:], + actor.id, + ) + if err != nil { + return channelT{}, err + } + + err = readStmt.QueryRow(newChannel.uuid[:]).Scan( + &channel.id, + ×tr, + ) if err != nil { return channelT{}, err } @@ -1423,9 +2492,35 @@ func addChannelStmt( return channel, nil } - return fn, writeStmt.Close, nil + closeFn := func() error { + writeFnClose() + return readStmt.Close() + } + + return fn, closeFn, nil } +/* +func chanByName(prefix string) queryT { + const tmpl_read = ` + SELECT + id, + timestamp, + uuid, + public, + + ` + return queryT{ + read: fmt.Sprintf(tmpl_read, prefix), + } +} + +func chanByStmt( + cfg dbconfigT, +) ( +FIXME +*/ + func channelEach(rows *sql.Rows, callback func(channelT) error) error { if rows == nil { return nil @@ -1436,16 +2531,24 @@ func channelEach(rows *sql.Rows, callback func(channelT) error) error { channel channelT timestr string channel_id_bytes []byte + publicName sql.NullString ) err := rows.Scan( &channel.id, ×tr, &channel_id_bytes, + &publicName, + &channel.label, + &channel.description, + &channel.virtual, ) if err != nil { return g.WrapErrors(rows.Close(), err) } channel.uuid = guuid.UUID(channel_id_bytes) + if publicName.Valid { + channel.publicName = &publicName.String + } channel.timestamp, err = time.Parse(time.RFC3339Nano, timestr) if err != nil { @@ -1463,16 +2566,40 @@ func channelEach(rows *sql.Rows, callback func(channelT) error) error { func channelsSQL(prefix string) queryT { const tmpl_read = ` - -- FIXME %s + WITH current_network AS ( + SELECT network_id AS id + FROM "%s_members" + WHERE id = ? + ), member_private_channels AS ( + SELECT channel_id AS id + FROM "%s_participants" + WHERE member_id = ? + ) + SELECT + id, + timestamp, + uuid, + public_name, + label, + description, + virtual + FROM "%s_channels" + WHERE + network_id IN current_network AND + ( + public_name IS NOT NULL OR + id IN member_private_channels + ) + ORDER BY id; ` return queryT{ - read: fmt.Sprintf(tmpl_read, prefix), + read: fmt.Sprintf(tmpl_read, prefix, prefix, prefix), } } func channelsStmt( cfg dbconfigT, -) (func(guuid.UUID) (*sql.Rows, error), func() error, error) { +) (func(memberT) (*sql.Rows, error), func() error, error) { q := channelsSQL(cfg.prefix) readStmt, err := cfg.shared.Prepare(q.read) @@ -1480,38 +2607,96 @@ func channelsStmt( return nil, nil, err } - fn := func(networkID guuid.UUID) (*sql.Rows, error) { - return readStmt.Query(networkID[:]) + fn := func(actor memberT) (*sql.Rows, error) { + return readStmt.Query(actor.id, actor.id) } return fn, readStmt.Close, nil } -func topicSQL(prefix string) queryT { +func setChannelSQL(prefix string) queryT { const tmpl_write = ` - -- FIXME %s + WITH participant_channel AS ( + SELECT channel_id AS id + FROM "%s_participants" + WHERE + member_id = ? AND + channel_id = ? + ) + UPDATE "%s_channels" + SET + description = ?, + public_name = ? + WHERE id IN participant_channel + RETURNING id; + ` + const tmpl_read = ` + SELECT ( + SELECT network_id AS id + FROM "%s_channels" + WHERE id = ? + ) AS channel_network_id, ( + SELECT network_id AS id + FROM "%s_members" + WHERE id = ? + ) AS member_network_id; ` return queryT{ - write: fmt.Sprintf(tmpl_write, prefix), + write: fmt.Sprintf(tmpl_write, prefix, prefix), + read: fmt.Sprintf(tmpl_read, prefix, prefix), } } -func topicStmt( +func setChannelStmt( cfg dbconfigT, -) (func(channelT) error, func() error, error) { - q := topicSQL(cfg.prefix) +) (func(memberT, channelT) error, func() error, error) { + q := setChannelSQL(cfg.prefix) + + readStmt, err := cfg.shared.Prepare(q.read) + if err != nil { + return nil, nil, err + } writeStmt, err := cfg.shared.Prepare(q.write) if err != nil { + readStmt.Close() return nil, nil, err } - fn := func(channel channelT) error { - _, err := writeStmt.Exec(channel) - return err + fn := func(actor memberT, channel channelT) error { + var ( + netid1 sql.NullInt64 + netid2 sql.NullInt64 + ) + err := readStmt.QueryRow(channel.id, actor.id).Scan( + &netid1, + &netid2, + ) + if err != nil { + return err + } + if !netid1.Valid || !netid2.Valid || + netid1.Int64 != netid2.Int64 { + return sql.ErrNoRows + } + + var _id int64 + return writeStmt.QueryRow( + actor.id, + channel.id, + channel.description, + channel.publicName, + ).Scan(&_id) } - return fn, writeStmt.Close, nil + closeFn := func() error { + return g.SomeError( + readStmt.Close(), + writeStmt.Close(), + ) + } + + return fn, closeFn, nil } func endChannelSQL(prefix string) queryT { @@ -1525,7 +2710,7 @@ func endChannelSQL(prefix string) queryT { func endChannelStmt( cfg dbconfigT, -) (func(guuid.UUID) error, func() error, error) { +) (func(memberT, guuid.UUID) error, func() error, error) { q := endChannelSQL(cfg.prefix) writeStmt, err := cfg.shared.Prepare(q.write) @@ -1533,7 +2718,7 @@ func endChannelStmt( return nil, nil, err } - fn := func(channelID guuid.UUID) error { + fn := func(actor memberT, channelID guuid.UUID) error { _, err := writeStmt.Exec(channelID[:]) return err } @@ -1543,43 +2728,108 @@ func endChannelStmt( func joinSQL(prefix string) queryT { const tmpl_write = ` - -- FIXME %s + WITH target_channel AS ( + SELECT id + FROM "%s_channels" + WHERE + uuid = ? AND + public_name IS NOT NULL + ) + INSERT INTO "%s_participants" (channel_id, member_id) + VALUES ( + (SELECT id FROM target_channel), + ? + ) RETURNING id; + ` + const tmpl_read = ` + SELECT ( + SELECT network_id AS id + FROM "%s_channels" + WHERE + uuid = ? AND + public_name IS NOT NULL + ) AS channel_network_id, ( + SELECT network_id AS id + FROM "%s_members" WHERE id = ? + ) AS member_network_id; ` return queryT{ - write: fmt.Sprintf(tmpl_write, prefix), + write: fmt.Sprintf(tmpl_write, prefix, prefix), + read: fmt.Sprintf(tmpl_read, prefix, prefix), } } func joinStmt( cfg dbconfigT, -) (func(guuid.UUID, guuid.UUID) error, func() error, error) { +) (func(memberT, guuid.UUID) error, func() error, error) { q := joinSQL(cfg.prefix) + readStmt, err := cfg.shared.Prepare(q.read) + if err != nil { + return nil, nil, err + } + writeStmt, err := cfg.shared.Prepare(q.write) if err != nil { + readStmt.Close() return nil, nil, err } - fn := func(memberID guuid.UUID, channelID guuid.UUID) error { - _, err := writeStmt.Exec(memberID[:], channelID[:]) - return err + fn := func(actor memberT, channelID guuid.UUID) error { + var ( + netid1 sql.NullInt64 + netid2 sql.NullInt64 + ) + err := readStmt.QueryRow(channelID[:], actor.id).Scan( + &netid1, + &netid2, + ) + if err != nil { + return err + } + + if !netid1.Valid || !netid2.Valid || + netid1.Int64 != netid2.Int64 { + return sql.ErrNoRows + } + + var _id int64 + return writeStmt.QueryRow(channelID[:], actor.id).Scan(&_id) } - return fn, writeStmt.Close, nil + closeFn := func() error { + return g.SomeError( + readStmt.Close(), + writeStmt.Close(), + ) + } + + return fn, closeFn, nil } func partSQL(prefix string) queryT { const tmpl_write = ` - -- FIXME %s + WITH target_channel AS ( + SELECT id + FROM "%s_channels" + WHERE + id = ? AND + virtual = false + ) + DELETE FROM "%s_participants" + WHERE + member_id = ? AND + channel_id IN target_channel + RETURNING 1; ` return queryT{ - write: fmt.Sprintf(tmpl_write, prefix), + write: fmt.Sprintf(tmpl_write, prefix, prefix), } } func partStmt( cfg dbconfigT, -) (func(guuid.UUID, guuid.UUID) error, func() error, error) { +) (func(memberT, channelT) error, func() error, error) { q := partSQL(cfg.prefix) writeStmt, err := cfg.shared.Prepare(q.write) @@ -1587,9 +2837,9 @@ func partStmt( return nil, nil, err } - fn := func(memberID guuid.UUID, channelID guuid.UUID) error { - _, err := writeStmt.Exec(memberID[:], channelID[:]) - return err + fn := func(actor memberT, channel channelT) error { + var _id int64 + return writeStmt.QueryRow(channel.id, actor.id).Scan(&_id) } return fn, writeStmt.Close, nil @@ -1641,7 +2891,7 @@ func namesSQL(prefix string) queryT { func namesStmt( cfg dbconfigT, -) (func(guuid.UUID) (*sql.Rows, error), func() error, error) { +) (func(memberT, guuid.UUID) (*sql.Rows, error), func() error, error) { q := namesSQL(cfg.prefix) readStmt, err := cfg.shared.Prepare(q.read) @@ -1649,7 +2899,7 @@ func namesStmt( return nil, nil, err } - fn := func(channelID guuid.UUID) (*sql.Rows, error) { + fn := func(actor memberT, channelID guuid.UUID) (*sql.Rows, error) { return readStmt.Query(channelID[:]) } @@ -1659,12 +2909,16 @@ func namesStmt( func addEventSQL(prefix string) queryT { const tmpl_write = ` INSERT INTO "%s_channel_events" ( - uuid, channel_id, connection_uuid, type, payload + uuid, channel_id, source_uuid, source_type, + source_metadata, type, payload, metadata ) VALUES ( ?, (SELECT id FROM "%s_channels" WHERE uuid = ?), ?, ?, + ?, + ?, + ?, ? ) RETURNING id, timestamp; ` @@ -1685,20 +2939,24 @@ func addEventStmt( fn := func(newEvent newEventT) (eventT, error) { event := eventT{ - uuid: newEvent.eventID, - channelID: newEvent.channelID, - connectionID: newEvent.connectionID, - type_: newEvent.type_, - payload: newEvent.payload, + uuid: newEvent.eventID, + channelID: newEvent.channelID, + source: newEvent.source, + type_: newEvent.type_, + payload: newEvent.payload, + metadata: newEvent.metadata, } var timestr string err := writeStmt.QueryRow( newEvent.eventID[:], newEvent.channelID[:], - newEvent.connectionID[:], + newEvent.source.uuid[:], + newEvent.source.type_, + newEvent.source.metadata, newEvent.type_, newEvent.payload, + newEvent.metadata, ).Scan(&event.id, ×tr) if err != nil { return eventT{}, err @@ -1726,14 +2984,12 @@ func eventEach(rows *sql.Rows, callback func(eventT) error) error { timestr string event_id_bytes []byte channel_id_bytes []byte - connection_id_bytes []byte ) err := rows.Scan( &event.id, ×tr, &event_id_bytes, &channel_id_bytes, - &connection_id_bytes, &event.type_, &event.payload, ) @@ -1743,7 +2999,6 @@ func eventEach(rows *sql.Rows, callback func(eventT) error) error { } event.uuid = guuid.UUID(event_id_bytes) event.channelID = guuid.UUID(channel_id_bytes) - event.connectionID = guuid.UUID(connection_id_bytes) event.timestamp, err = time.Parse(time.RFC3339Nano, timestr) if err != nil { @@ -1773,7 +3028,7 @@ func allAfterSQL(prefix string) queryT { "%s_channel_events".timestamp, "%s_channel_events".uuid, "%s_channels".uuid, - "%s_channel_events".connection_uuid, + -- "%s_channel_events".connection_uuid, "%s_channel_events".type, "%s_channel_events".payload FROM "%s_channel_events" @@ -1808,7 +3063,7 @@ func allAfterSQL(prefix string) queryT { func allAfterStmt( cfg dbconfigT, -) (func (guuid.UUID) (*sql.Rows, error), func() error, error) { +) (func (memberT, guuid.UUID) (*sql.Rows, error), func() error, error) { q := allAfterSQL(cfg.prefix) readStmt, err := cfg.shared.Prepare(q.read) @@ -1816,7 +3071,7 @@ func allAfterStmt( return nil, nil, err } - fn := func(eventID guuid.UUID) (*sql.Rows, error) { + fn := func(actor memberT, eventID guuid.UUID) (*sql.Rows, error) { return readStmt.Query(eventID[:]) } @@ -1842,6 +3097,7 @@ func logMessageStmt( return nil, nil, err } + // FIXME: actor? fn := func(user userT, message messageT) error { return nil // FIXME _, err := writeStmt.Exec(user, message) @@ -1881,14 +3137,17 @@ func initDB( networks, networksClose, networksErr := networksStmt(cfg) setNetwork, setNetworkClose, setNetworkErr := setNetworkStmt(cfg) nipNetwork, nipNetworkClose, nipNetworkErr := nipNetworkStmt(cfg) + membership, membershipClose, membershipErr := membershipStmt(cfg) addMember, addMemberClose, addMemberErr := addMemberStmt(cfg) + addRole, addRoleClose, addRoleErr := addRoleStmt(cfg) + dropRole, dropRoleClose, dropRoleErr := dropRoleStmt(cfg) showMember, showMemberClose, showMemberErr := showMemberStmt(cfg) members, membersClose, membersErr := membersStmt(cfg) editMember, editMemberClose, editMemberErr := editMemberStmt(cfg) dropMember, dropMemberClose, dropMemberErr := dropMemberStmt(cfg) addChannel, addChannelClose, addChannelErr := addChannelStmt(cfg) channels, channelsClose, channelsErr := channelsStmt(cfg) - topic, topicClose, topicErr := topicStmt(cfg) + setChannel, setChannelClose, setChannelErr := setChannelStmt(cfg) endChannel, endChannelClose, endChannelErr := endChannelStmt(cfg) join, joinClose, joinErr := joinStmt(cfg) part, partClose, partErr := partStmt(cfg) @@ -1908,14 +3167,17 @@ func initDB( networksClose, setNetworkClose, nipNetworkClose, + membershipClose, addMemberClose, + addRoleClose, + dropRoleClose, showMemberClose, membersClose, editMemberClose, dropMemberClose, addChannelClose, channelsClose, - topicClose, + setChannelClose, endChannelClose, joinClose, partClose, @@ -1937,14 +3199,17 @@ func initDB( networksErr, setNetworkErr, nipNetworkErr, + membershipErr, addMemberErr, + addRoleErr, + dropRoleErr, showMemberErr, membersErr, editMemberErr, dropMemberErr, addChannelErr, channelsErr, - topicErr, + setChannelErr, endChannelErr, joinErr, partErr, @@ -1954,12 +3219,6 @@ func initDB( logMessageErr, ) if err != nil { - ferr := g.SomeError( - - createUserErr, - - ) - fmt.Printf("ferr: %#v\n", ferr) closeFn() return queriesT{}, err } @@ -1986,10 +3245,14 @@ func initDB( defer connMutex.RUnlock() return deleteUser(a) }, - addNetwork: func(a userT, b newNetworkT) (networkT, error) { + addNetwork: func( + a userT, + b newNetworkT, + c guuid.UUID, + ) (networkT, error) { connMutex.RLock() defer connMutex.RUnlock() - return addNetwork(a, b) + return addNetwork(a, b, c) }, getNetwork: func(a userT, b guuid.UUID) (networkT, error) { connMutex.RLock() @@ -2015,35 +3278,42 @@ func initDB( return networkEach(rows, callback) }, - setNetwork: func(a userT, b networkT) error { + setNetwork: func(a memberT, b networkT) error { connMutex.RLock() defer connMutex.RUnlock() return setNetwork(a, b) }, - nipNetwork: func(a userT, b guuid.UUID) error { + nipNetwork: func(a memberT) error { connMutex.RLock() defer connMutex.RUnlock() - return nipNetwork(a, b) + return nipNetwork(a) }, - addMember: func( - a userT, - b networkT, - c newMemberT, - ) (memberT, error) { + membership: func(a userT, b networkT) (memberT, error) { + connMutex.RLock() + defer connMutex.RUnlock() + return membership(a, b) + }, + addMember: func(a memberT, b newMemberT) (memberT, error) { + connMutex.RLock() + defer connMutex.RUnlock() + return addMember(a, b) + }, + addRole: func(a memberT, b string, c memberT) error { connMutex.RLock() defer connMutex.RUnlock() - return addMember(a, b, c) + return addRole(a, b, c) }, - showMember: func(a userT, b guuid.UUID) (memberT, error) { + dropRole: func(a memberT, b string, c memberT) error { + connMutex.RLock() + defer connMutex.RUnlock() + return dropRole(a, b, c) + }, + showMember: func(a memberT, b guuid.UUID) (memberT, error) { connMutex.RLock() defer connMutex.RUnlock() return showMember(a, b) }, - members: func( - a userT, - b guuid.UUID, - callback func(memberT) error, - ) error { + members: func(a memberT, callback func(memberT) error) error { var ( err error rows *sql.Rows @@ -2051,7 +3321,7 @@ func initDB( { connMutex.RLock() defer connMutex.RUnlock() - rows, err = members(a, b) + rows, err = members(a) } if err != nil { return err @@ -2059,25 +3329,25 @@ func initDB( return memberEach(rows, callback) }, - editMember: func(a userT, b memberT) error { + editMember: func(a memberT, b memberT) error { connMutex.RLock() defer connMutex.RUnlock() return editMember(a, b) }, - dropMember: func(a userT, b guuid.UUID) error { + dropMember: func(a memberT, b guuid.UUID) error { connMutex.RLock() defer connMutex.RUnlock() return dropMember(a, b) }, addChannel: func( - a guuid.UUID, b newChannelT, + a memberT, b newChannelT, ) (channelT, error) { connMutex.RLock() defer connMutex.RUnlock() return addChannel(a, b) }, channels: func( - a guuid.UUID, + a memberT, callback func(channelT) error, ) error { var ( @@ -2095,27 +3365,31 @@ func initDB( return channelEach(rows, callback) }, - topic: func(a channelT) error { + setChannel: func(a memberT, b channelT) error { connMutex.RLock() defer connMutex.RUnlock() - return topic(a) + return setChannel(a, b) }, - endChannel: func(a guuid.UUID) error { + endChannel: func(a memberT, b guuid.UUID) error { connMutex.RLock() defer connMutex.RUnlock() - return endChannel(a) + return endChannel(a, b) }, - join: func(a guuid.UUID, b guuid.UUID) error { + join: func(a memberT, b guuid.UUID) error { connMutex.RLock() defer connMutex.RUnlock() return join(a, b) }, - part: func(a guuid.UUID, b guuid.UUID) error { + part: func(a memberT, b channelT) error { connMutex.RLock() defer connMutex.RUnlock() return part(a, b) }, - names: func(a guuid.UUID, callback func(memberT) error) error { + names: func( + a memberT, + b guuid.UUID, + callback func(memberT) error, + ) error { var ( err error rows *sql.Rows @@ -2123,7 +3397,7 @@ func initDB( { connMutex.RLock() defer connMutex.RUnlock() - rows, err = names(a) + rows, err = names(a, b) } if err != nil { return err @@ -2137,7 +3411,8 @@ func initDB( return addEvent(a) }, allAfter: func( - a guuid.UUID, + a memberT, + b guuid.UUID, callback func(eventT) error, ) error { var ( @@ -2147,7 +3422,7 @@ func initDB( { connMutex.RLock() defer connMutex.RUnlock() - rows, err = allAfter(a) + rows, err = allAfter(a, b) } if err != nil { return err @@ -2236,19 +3511,110 @@ func initListeners( }, nil } -func makeReceivers() receiversT { - var rwmutex sync.Mutex - return receiversT{ - add: func(receiver receiverT) { +type stateT struct{ + connected func(*connectionT) + disconnect func(*connectionT) + authenticated func(*connectionT) + subscribe func(string, []string) + members func(string) []string + connections func(string) []guuid.UUID + connection func(guuid.UUID) *connectionT +} + +// TODO: key for members should be the channelID, not its name +type stateDataT struct{ + connections map[guuid.UUID]*connectionT + users map[string][]guuid.UUID + members map[string]map[string][]guuid.UUID +} + +// TODO: lock is global, should be by network +func newState() stateT { + var rwmutex sync.RWMutex + state := stateDataT{ + connections: map[guuid.UUID]*connectionT{}, + users: map[string][]guuid.UUID{}, + members: map[string]map[string][]guuid.UUID{}, + } + return stateT{ + connected: func(connection *connectionT) { + rwmutex.Lock() + defer rwmutex.Unlock() + state.connections[connection.uuid] = connection }, - remove: func(receiver receiverT) { + disconnect: func(connection *connectionT) { + { + rwmutex.Lock() + defer rwmutex.Unlock() + delete(state.connections, connection.uuid) + delete(state.users, connection.user.username) + delete(state.members, connection.user.username) + } + err := connection.conn.Close() + if err != nil { + g.Warning( + "Failed to close the connection", + "close-error", + "from", "daemon", + "err", err, + ) + } }, - get: func(guuid.UUID) []receiverT{ - return nil + authenticated: func(connection *connectionT) { + username := connection.user.username + rwmutex.Lock() + defer rwmutex.Unlock() + if state.users[username] == nil { + state.users[username] = []guuid.UUID{} + } + state.users[username] = append( + state.users[username], + connection.uuid, + ) }, - close: func() { + subscribe: func( + username string, + channelNames []string, + ) { rwmutex.Lock() defer rwmutex.Unlock() + for _, channelName := range channelNames { + if state.members[channelName] == nil { + state.members[channelName] = + map[string][]guuid.UUID{} + } + state.members[channelName][username] = + state.users[username] + } + }, + members: func(channelName string) []string { + rwmutex.RLock() + defer rwmutex.RUnlock() + usernames := make( + []string, + len(state.members[channelName]), + ) + i := 0 + for username, _ := range state.members[channelName] { + usernames[i] = username + i++ + } + return usernames + }, + connections: func(username string) []guuid.UUID { + rwmutex.RLock() + defer rwmutex.RUnlock() + connections := make( + []guuid.UUID, + len(state.users[username]), + ) + copy(connections, state.users[username]) + return connections + }, + connection: func(connectionID guuid.UUID) *connectionT { + rwmutex.RLock() + defer rwmutex.RUnlock() + return state.connections[connectionID] }, } } @@ -2311,7 +3677,8 @@ func NewWithPrefix( } consumers := buildConsumers(prefix) - receivers := makeReceivers() + state := newState() + // receivers := makeReceivers() metrics := buildMetrics(prefix) // logger := g.NewLogger("prefix", prefix, "program", "papod") @@ -2321,7 +3688,8 @@ func NewWithPrefix( queries: queries, listeners: listeners, consumers: consumers, - receivers: receivers, + state: state, + // receivers: receivers, metrics: metrics, // logger: logger, }, nil @@ -2355,29 +3723,36 @@ func splitOnRawMessage(data []byte, atEOF bool) (int, []byte, error) { return advance, token, error } +func splitCommas(r rune) bool { + return r == ',' +} + func splitSpaces(r rune) bool { return r == ' ' } -func parseMessageParams(params string) messageParamsT { +func parseMessageParams(params string) []string { const sep = " :" - var middle string - var trailing string - idx := strings.Index(params, sep) if idx == -1 { - middle = params - trailing = "" + return strings.FieldsFunc(params, splitSpaces) } else { - middle = params[:idx] - trailing = params[idx + len(sep):] + middle := params[:idx] + trailing := params[idx + len(sep):] + return append( + strings.FieldsFunc(middle, splitSpaces), + trailing, + ) } +} - return messageParamsT{ - middle: strings.FieldsFunc(middle, splitSpaces), - trailing: trailing, +func stripBlankParams(params []string) []string { + if len(params) == 1 && len(params[0]) == 0 { + return []string{} } + + return params } var messageRegex = regexp.MustCompilePOSIX( @@ -2396,7 +3771,7 @@ func parseMessage(rawMessage string) (messageT, error) { msg = messageT{ prefix: components[2], command: components[3], - params: parseMessageParams(components[4]), + params: stripBlankParams(parseMessageParams(components[4])), raw: rawMessage, } return msg, nil @@ -2418,332 +3793,470 @@ func asReply(event eventT) replyT { return replyT{} } -func broadcastEvent(event eventT, receiversFn func(guuid.UUID) []receiverT) { - message := asMessage(event) - for _, receiver := range receiversFn(event.channelID) { - // FIXME: - // is this death by a thousand goroutines? Is the runtime - // able to handle the creation and destruction of hundreds of - // thousands of goroutines per second? - go receiver.send(message) + +/// Is this death by a thousand goroutines? Is the runtime able to handle the +/// creation and destruction of hundreds of thousands of goroutines per second? +/// For now, we'll assume that Go's (gc) runtime, scheduler and garbage +/// collector are capable of working together to make sure this isn't a +/// catastrophe. +func broadcastMessage( + message messageT, + channelName string, + usersFn func(string) []string, + connectionIDsFn func(string) []guuid.UUID, + connectionFn func(guuid.UUID) *connectionT, +) { + for _, username := range usersFn(channelName) { + for _, connectionID := range connectionIDsFn(username) { + connection := connectionFn(connectionID) + if connection == nil { + continue + } + + go connection.send(message) + } } } -var ( - replyErrUnknown = replyT{ - command: "421", - params: messageParamsT{ - middle: []string{}, - trailing: "Unknown command", + +/* +Intentionally not implemented: + +- RPL_BOUNCE + +*/ + +// FIXME: add check for minRPL... +const minRPL_WELCOME = 0 +func _RPL_WELCOME(connection *connectionT, msg messageT) replyT { + return replyT{ + command: "001", + params: []string{ + connection.user.username, + "Welcome to the Internet Relay Network " + + connection.user.username, }, } - replyErrNotRegistered = replyT{ - command: "451", - params: messageParamsT{ - middle: []string{}, - trailing: "You have not registered", +} + +const minRPL_YOURHOST = 0 +func _RPL_YOURHOST(connection *connectionT, msg messageT) replyT { + return replyT{ + command: "002", + params: []string{ + connection.user.username, + "Your host is FIXME, running version " + + Version, + }, + } +} + +const minRPL_CREATED = 0 +func _RPL_CREATED(connection *connectionT, msg messageT) replyT { + return replyT{ + command: "003", + params: []string{ + connection.user.username, + "This server was create FIXME", + }, + } +} + +const minRPL_MYINFO = 0 +func _RPL_MYINFO(connection *connectionT, msg messageT) replyT { + return replyT{ + command: "004", + params: []string{ + connection.user.username, + "FIXME " + Version + " i x", }, } - replyErrFileError = replyT{ +} + +const minRPL_UNAWAY = 0 +func _RPL_UNAWAY(connection *connectionT, msg messageT) replyT { + return replyT{ + command: "305", + params: []string{ + connection.user.username, + "You are no longer marked as away", + }, + } +} + +const minRPL_NOWAWAY = 0 +func _RPL_NOWAWAY(connection *connectionT, msg messageT) replyT { + return replyT{ + command: "306", + params: []string{ + connection.user.username, + "You have been marked as away", + }, + } +} + +const minRPL_WHOISUSER = 1 +func _RPL_WHOISUSER(connection *connectionT, msg messageT) replyT { + user := msg.params[0] + return replyT{ + command: "311", + params: []string{ + connection.user.username, + user, + user, + "samehost", + "*", + "my real name is: " + user, + }, + } +} + +const minRPL_WHOISSERVER = 1 +func _RPL_WHOISSERVER(connection *connectionT, msg messageT) replyT { + user := msg.params[0] + return replyT{ + command: "312", + params: []string{ + connection.user.username, + user, + "stillsamehost", + "some server info", + }, + } +} + +const minRPL_ENDOFWHOIS = 1 +func _RPL_ENDOFWHOIS(connection *connectionT, msg messageT) replyT { + user := msg.params[0] + return replyT{ + command: "318", + params: []string{ + connection.user.username, + user, + "End of WHOIS list", + }, + } +} + +const minRPL_WHOISCHANNELS = 1 +func _RPL_WHOISCHANNELS(connection *connectionT, msg messageT) replyT { + user := msg.params[0] + return replyT{ + command: "319", + params: []string{ + connection.user.username, + user, + "#default", + }, + } +} + +const minRPL_CHANNELMODEIS = 1 +func _RPL_CHANNELMODEIS(connection *connectionT, msg messageT) replyT { + channel := msg.params[0] + return replyT{ + command: "324", + params: []string{ + connection.user.username, + channel, + "+Cnst", + }, + } +} + +const minRPL_NOTOPIC = 1 +func _RPL_NOTOPIC(connection *connectionT, msg messageT) replyT { + channel := msg.params[0] + return replyT{ + command: "331", + params: []string{ + connection.user.username, + channel, + "No topic is set", + }, + } +} + +const minRPL_NAMREPLY = 1 +func _RPL_NAMREPLY(connection *connectionT, msg messageT) replyT { + channel := msg.params[0] + return replyT{ + command: "353", + params: []string{ + connection.user.username, + "=", + channel, + connection.user.username + " virtualuser", + }, + } +} + +const minRPL_ENDOFNAMES = 1 +func _RPL_ENDOFNAMES(connection *connectionT, msg messageT) replyT { + channel := msg.params[0] + return replyT{ + command: "366", + params: []string{ + connection.user.username, + channel, + "End of NAMES list", + }, + } +} + +const minERR_UNKNOWNCOMMAND = 0 +func _ERR_UNKNOWNCOMMAND(connection *connectionT, msg messageT) replyT { + return replyT{ + command: "421", + params: []string{ + connection.user.username, + "Unknown command", + }, + } +} + +const minERR_FILEERROR = 0 +func _ERR_FILEERROR(connection *connectionT, msg messageT) replyT { + return replyT{ command: "424", - params: messageParamsT{ - middle: []string{}, - trailing: "File error doing query on database", + params: []string{ + "File error doing query on database", }, } - RPL_WELCOME = replyT{ - command: "001", - params: messageParamsT{ - middle: []string{}, - trailing: "", +} + +const minERR_NOTREGISTERED = 0 +func _ERR_NOTREGISTERED(connection *connectionT, msg messageT) replyT { + return replyT{ + command: "451", + params: []string{ + "You have not registered", }, } -) +} -func handleUnknown( - papod papodT, - connection *connectionT, - msg messageT, -) ([]replyT, error) { - // FIXME: user doesn't exist when unauthenticated - err := papod.queries.logMessage(userT{ }, msg) - if err != nil { - g.Warning( - "Failed to log message", fmt.Sprintf("%#v", msg), - "group-as", "db-write", - "handler-action", "log-and-ignore", - "connection", connection.uuid.String(), - "err", err, - ) +const minERR_NEEDMOREPARAMS = 0 +func _ERR_NEEDMOREPARAMS(connection *connectionT, msg messageT) replyT { + return replyT{ + command: "461", + params: []string{ + msg.command, + "Not enough parameters", + }, } +} + - return []replyT{ replyErrUnknown }, nil +func _CAP(connection *connectionT, msg messageT) replyT { + return replyT{ + command: "CAP", + params: []string { + "*", + "LS", + }, + } +} + +const minPONG = 0 +func _PONG(connection *connectionT, msg messageT) replyT { + return replyT{ + command: "PONG", + params: msg.params, + } } + +const minUSER = 4 func handleUSER( papod papodT, connection *connectionT, msg messageT, -) ([]replyT, error) { - u := connection.user.username - m := []string{ u } - return []replyT{ replyT{ - command: "001", - params: messageParamsT{ - middle: m, - trailing: "Welcome to the Internet Relay Network " + u, - }, - }, replyT{ - command: "002", - params: messageParamsT{ - middle: m, - trailing: "Your host is FIXME, running version " + - Version, - }, - }, replyT{ - command: "003", - params: messageParamsT{ - middle: m, - trailing: "This server was create FIXME", - }, - }, replyT{ - command: "004", - params: messageParamsT{ - middle: m, - trailing: "FIXME " + Version + " i x", - }, - }, }, nil +) ([]replyT, bool, error) { + return []replyT{ + _RPL_WELCOME (connection, msg), + _RPL_YOURHOST(connection, msg), + _RPL_CREATED (connection, msg), + _RPL_MYINFO (connection, msg), + }, false, nil } +const minNICK = 1 func handleNICK( papod papodT, connection *connectionT, msg messageT, -) ([]replyT, error) { - connection.user.username = msg.params.middle[0] - return []replyT{}, nil +) ([]replyT, bool, error) { + connection.user.username = msg.params[0] + return []replyT{}, false, nil } +const minPRIVMSG = 2 func handlePRIVMSG( papod papodT, connection *connectionT, msg messageT, -) ([]replyT, error) { - // FIXME: check missing params +) ([]replyT, bool, error) { // FIXME: check if user is member of channel, and is authorized to post // FIXME: adapt to handle multiple targets - return []replyT{}, nil - - event, err := papod.queries.addEvent(asNewEvent(msg)) - if err != nil { - // FIXME: not allowed reply per RFC 1459, check other specs - return []replyT{ replyErrFileError }, nil - } - - go broadcastEvent(event, papod.receivers.get) - - reply := asReply(event) - return []replyT{ reply }, nil + go broadcastMessage( + msg, + msg.params[0], + papod.state.members, + papod.state.connections, + papod.state.connection, + ) + return []replyT{}, false, nil } +const minTOPIC = 2 func handleTOPIC( papod papodT, connection *connectionT, msg messageT, -) ([]replyT, error) { - return []replyT{ replyT{ - command: "JOIN", - params: messageParamsT{ - middle: []string{ msg.params.middle[0] }, - trailing: "", - }, - } }, nil +) ([]replyT, bool, error) { + return []replyT{ + _RPL_NOTOPIC(connection, msg), + }, false, nil } +const minJOIN = 1 func handleJOIN( papod papodT, connection *connectionT, msg messageT, -) ([]replyT, error) { - u := connection.user.username - channel := msg.params.middle[0] - return []replyT{ replyT{ - command: "JOIN", - params: messageParamsT{ - middle: []string{ channel }, - trailing: "", - }, - }, replyT{ - command: "331", - params: messageParamsT{ - middle: []string{ u, channel }, - trailing: "No topic is set", - }, - }, replyT{ - command: "353", - params: messageParamsT{ - middle: []string{ u, "=", channel }, - trailing: u + " virtualuser", - }, - }, replyT{ - command: "366", - params: messageParamsT{ - middle: []string{ u, channel }, - trailing: "End of NAMES list", - }, - } }, nil - - - member, err := papod.queries.addMember( - *connection.user, - networkT{}, - newMemberT{}, - ) - if err != nil { - // FIXME: not allowed per RFC 1459 - return []replyT{ replyErrFileError }, nil - } - event := joinEvent(member) - - papod.metrics.nicksInChannel.Inc() - - go broadcastEvent(event, papod.receivers.get) +) ([]replyT, bool, error) { + // FIXME: add to database + channels := strings.FieldsFunc(msg.params[0], splitCommas) + papod.state.subscribe(connection.user.username, channels) - reply := asReply(event) - return []replyT{ reply }, nil + return []replyT{ + _RPL_NOTOPIC (connection, msg), + _RPL_NAMREPLY (connection, msg), + _RPL_ENDOFNAMES(connection, msg), + }, false, nil } +const minMODE = 1 func handleMODE( papod papodT, connection *connectionT, msg messageT, -) ([]replyT, error) { - u := connection.user.username - channel := msg.params.middle[0] - return []replyT{ replyT{ - command: "324", - params: messageParamsT{ - middle: []string{ u, channel, "+Cnst" }, - trailing: "", - }, - } }, nil +) ([]replyT, bool, error) { + return []replyT{ + _RPL_CHANNELMODEIS(connection, msg), + }, false, nil } +const minWHOIS = 1 func handleWHOIS( papod papodT, connection *connectionT, msg messageT, -) ([]replyT, error) { - u := connection.user.username - user := msg.params.middle[0] - return []replyT{ replyT{ - command: "311", - params: messageParamsT{ - middle: []string{ u, user, user, "samehost", "*" }, - trailing: "my real name is: " + user, - }, - }, replyT{ - command: "312", - params: messageParamsT{ - middle: []string{ u, user, "stillsamehost" }, - trailing: "some server info", - }, - }, replyT{ - command: "319", - params: messageParamsT{ - middle: []string{ u, user }, - trailing: "#default", - }, - }, replyT{ - command: "318", - params: messageParamsT{ - middle: []string{ u, user }, - trailing: "End of WHOIS list", - }, - } }, nil +) ([]replyT, bool, error) { + return []replyT{ + _RPL_WHOISUSER (connection, msg), + _RPL_WHOISSERVER (connection, msg), + _RPL_WHOISCHANNELS(connection, msg), + _RPL_ENDOFWHOIS (connection, msg), + }, false, nil } +const minAWAY = 0 func handleAWAY( papod papodT, connection *connectionT, msg messageT, -) ([]replyT, error) { - u := connection.user.username - - if msg.params.trailing == "" { - return []replyT{ replyT{ - command: "305", - params: messageParamsT{ - middle: []string{ u }, - trailing: "You are no longer marked as away", - }, - } }, nil - } else { - return []replyT{ replyT{ - command: "306", - params: messageParamsT{ - middle: []string{ u }, - trailing: "You have been marked as away", - }, - } }, nil +) ([]replyT, bool, error) { + replyFn := _RPL_NOWAWAY + + if len(msg.params) == 0 { + replyFn = _RPL_UNAWAY } + + return []replyT{ + replyFn(connection, msg), + }, false, nil } +const minPING = 0 func handlePING( papod papodT, connection *connectionT, msg messageT, -) ([]replyT, error) { - return []replyT{ { - command: "PONG", - params: messageParamsT{ - middle: []string{}, - trailing: msg.params.middle[0], - }, - } }, nil +) ([]replyT, bool, error) { + return []replyT{ + _PONG(connection, msg), + }, false, nil } +const minQUIT = 0 func handleQUIT( papod papodT, connection *connectionT, msg messageT, -) ([]replyT, error) { - connection.conn.Close() - return []replyT{}, nil +) ([]replyT, bool, error) { + return []replyT{}, true, nil } +const minCAP = 1 func handleCAP( papod papodT, connection *connectionT, msg messageT, -) ([]replyT, error) { - if msg.params.middle[0] == "END" { - return nil, nil +) ([]replyT, bool, error) { + if msg.params[0] == "END" { + return nil, false, nil } - return []replyT{ replyT{ - command: "CAP", - params: messageParamsT{ - middle: []string { "*", "LS" }, - trailing: "", - }, - } }, nil + return []replyT{ + _CAP(connection, msg), + }, false, nil } + func authRequired( - fn func(papodT, *connectionT, messageT) ([]replyT, error), -) func(papodT, *connectionT, messageT) ([]replyT, error) { + fn func( + papodT, + *connectionT, + messageT, + ) ([]replyT, bool, error), +) func(papodT, *connectionT, messageT) ([]replyT, bool, error) { return func( papod papodT, connection *connectionT, - message messageT, - ) ([]replyT, error) { + msg messageT, + ) ([]replyT, bool, error) { if connection.user == nil { - return []replyT{ replyErrNotRegistered }, nil + return []replyT{ + _ERR_NOTREGISTERED(connection, msg), + }, false, nil } - - return fn(papod, connection, message) + + return fn(papod, connection, msg) + } +} + +func minArgs( + count int, + fn func( + papodT, + *connectionT, + messageT, + ) ([]replyT, bool, error), +) func(papodT, *connectionT, messageT) ([]replyT, bool, error) { + return func( + papod papodT, + connection *connectionT, + msg messageT, + ) ([]replyT, bool, error) { + if len(msg.params) < count { + return []replyT{ + _ERR_NEEDMOREPARAMS(connection, msg), + }, false, nil + } + + return fn(papod, connection, msg) } } @@ -2751,23 +4264,45 @@ var commands = map[string]func( papodT, *connectionT, messageT, -) ([]replyT, error) { - "USER": handleUSER, - "NICK": handleNICK, - "QUIT": handleQUIT, - "CAP": handleCAP, - "AWAY": authRequired(handleAWAY), - "PRIVMSG": authRequired(handlePRIVMSG), - "PING": authRequired(handlePING), - "JOIN": authRequired(handleJOIN), - "MODE": authRequired(handleMODE), - "TOPIC": authRequired(handleTOPIC), - "WHOIS": authRequired(handleWHOIS), +) ([]replyT, bool, error) { + "USER": minArgs(minUSER, handleUSER), + "NICK": minArgs(minNICK, handleNICK), + "QUIT": minArgs(minQUIT, handleQUIT), + "CAP": minArgs(minCAP, handleCAP), + "AWAY": authRequired(minArgs(minAWAY, handleAWAY)), + "PRIVMSG": authRequired(minArgs(minPRIVMSG, handlePRIVMSG)), + "PING": authRequired(minArgs(minPING, handlePING)), + "JOIN": authRequired(minArgs(minJOIN, handleJOIN)), + "MODE": authRequired(minArgs(minMODE, handleMODE)), + "TOPIC": authRequired(minArgs(minTOPIC, handleTOPIC)), + "WHOIS": authRequired(minArgs(minWHOIS, handleWHOIS)), +} + +func handleUnknown( + papod papodT, + connection *connectionT, + msg messageT, +) ([]replyT, bool, error) { + // FIXME: user doesn't exist when unauthenticated + err := papod.queries.logMessage(userT{ }, msg) + if err != nil { + g.Warning( + "Failed to log message", fmt.Sprintf("%#v", msg), + "group-as", "db-write", + "handler-action", "log-and-ignore", + "connection", connection.uuid.String(), + "err", err, + ) + } + + return []replyT{ + _ERR_UNKNOWNCOMMAND(connection, msg), + }, false, nil } func actionFnFor( command string, -) func(papodT, *connectionT, messageT) ([]replyT, error) { +) func(papodT, *connectionT, messageT) ([]replyT, bool, error) { fn := commands[command] if fn != nil { return fn @@ -2776,52 +4311,77 @@ func actionFnFor( return handleUnknown } -func replyString(reply replyT) string { - if reply.params.trailing == "" { - return fmt.Sprintf( - "%s %s\r\n", - reply.command, - strings.Join(reply.params.middle, " "), - ) - } else { - return fmt.Sprintf( - "%s %s :%s\r\n", - reply.command, - strings.Join(reply.params.middle, " "), - reply.params.trailing, - ) +func addTrailingSeparator(strs []string) { + if len(strs) == 0 { + return } + + last := strs[len(strs) - 1] + if strings.Contains(last, " ") && last[0] != ':' { + strs[len(strs) - 1] = ":" + last + } +} + +func (r replyT) String() string { + addTrailingSeparator(r.params) + return fmt.Sprintf( + "%s %s\r\n", + r.command, + strings.Join(r.params, " "), + ) +} + +func (m messageT) logAttributes() slog.Attr { + return slog.Group( + "message", + "prefix", m.prefix, + "raw", m.raw, + "params", m.params, + ) } -func processMessage(papod papodT, connection *connectionT, rawMessage string) { +func (r replyT) logAttributes() slog.Attr { + return slog.Group( + "reply", + "command", r.command, + "params", r.params, + ) +} + +func processMessage( + papod papodT, + connection *connectionT, + rawMessage string, +) { msg, err := parseMessage(rawMessage) if err != nil { g.Info( - "Error processing message", - "process-message", - "text", rawMessage, + "Error parsing message", "parse-message-error", + slog.Group( + "message", + "text", rawMessage, + ), "err", err, ) return } - - papod.metrics.receivedMessage( - "message", fmt.Sprintf("%#v", msg), - "text", rawMessage, - ) + papod.metrics.receivedMessage(msg.logAttributes()) var replyErrors []error - replies, actionErr := actionFnFor(msg.command)(papod, connection, msg) + replies, shouldClose, actionErr := actionFnFor(msg.command)( + papod, + connection, + msg, + ) for _, reply := range replies { - text := replyString(reply) - _, err = io.WriteString(connection.conn, text) + _, err = io.WriteString(connection.conn, reply.String()) if err != nil { replyErrors = append(replyErrors, err) } + papod.metrics.sentReply( - "message", rawMessage, - "reply", fmt.Sprintf("%#v", reply), - "text", text, + msg.logAttributes(), + reply.logAttributes(), ) } @@ -2842,16 +4402,20 @@ func processMessage(papod papodT, connection *connectionT, rawMessage string) { ) } - // FIXME: Close the connection + papod.state.disconnect(connection) + return + } + + if shouldClose { + papod.state.disconnect(connection) } } func handleConnection(papod papodT, conn net.Conn) { connection := connectionT{ - conn: conn, uuid: guuid.New(), - // user: nil, // FIXME: SASL shenanigan probably goes here - user: &userT{}, + user: &userT{}, // TODO: SASL shenanigan probably goes here + conn: conn, } scanner := bufio.NewScanner(conn) scanner.Split(splitOnRawMessage) @@ -2860,10 +4424,6 @@ func handleConnection(papod papodT, conn net.Conn) { } } -func handleCommand(papod papodT, conn net.Conn) { - // FIXME -} - func daemonLoop(papod papodT) { for { conn, err := papod.listeners.daemon.Accept() @@ -2880,7 +4440,6 @@ func daemonLoop(papod papodT) { ) continue } - // FIXME: where does it get closed go handleConnection(papod, conn) } } @@ -2901,7 +4460,7 @@ func commanderLoop(papod papodT) { ) continue } - go handleCommand(papod, conn) + go handleConnection(papod, conn) } } @@ -2939,7 +4498,6 @@ func (papod papodT) Start() error { "golite", golite.Version, "guuid", guuid.Version, "papod", Version, - "this", Version, ), ) |