diff --git a/server/src/db.rs b/server/src/db.rs index 9b61018701..300c8de6d5 100644 --- a/server/src/db.rs +++ b/server/src/db.rs @@ -21,6 +21,12 @@ pub struct Signup { pub about: String, } +#[derive(Debug, FromRow, Serialize)] +pub struct Channel { + id: i32, + pub name: String, +} + #[derive(Debug, FromRow)] pub struct ChannelMessage { id: i32, @@ -158,6 +164,7 @@ impl Db { // orgs + #[cfg(test)] pub async fn create_org(&self, name: &str, slug: &str) -> Result { let query = " INSERT INTO orgs (name, slug) @@ -172,6 +179,7 @@ impl Db { .map(OrgId) } + #[cfg(test)] pub async fn add_org_member(&self, org_id: OrgId, user_id: UserId) -> Result<()> { let query = " INSERT INTO org_memberships (org_id, user_id) @@ -187,6 +195,7 @@ impl Db { // channels + #[cfg(test)] pub async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result { let query = " INSERT INTO channels (owner_id, owner_is_user, name) @@ -201,6 +210,42 @@ impl Db { .map(ChannelId) } + pub async fn get_channels_for_user(&self, user_id: UserId) -> Result> { + let query = " + SELECT + channels.id, channels.name + FROM + channel_memberships, channels + WHERE + channel_memberships.user_id = $1 AND + channel_memberships.channel_id = channels.id + "; + sqlx::query_as(query) + .bind(user_id.0) + .fetch_all(&self.0) + .await + } + + pub async fn can_user_access_channel( + &self, + user_id: UserId, + channel_id: ChannelId, + ) -> Result { + let query = " + SELECT id + FROM channel_memberships + WHERE user_id = $1 AND channel_id = $2 + LIMIT 1 + "; + sqlx::query_scalar::<_, i32>(query) + .bind(user_id.0) + .bind(channel_id.0) + .fetch_optional(&self.0) + .await + .map(|e| e.is_some()) + } + + #[cfg(test)] pub async fn add_channel_member( &self, channel_id: ChannelId, @@ -269,6 +314,12 @@ impl std::ops::Deref for Db { } } +impl Channel { + pub fn id(&self) -> ChannelId { + ChannelId(self.id) + } +} + impl User { pub fn id(&self) -> UserId { UserId(self.id) diff --git a/server/src/rpc.rs b/server/src/rpc.rs index cd229a3c67..f1ebf605a2 100644 --- a/server/src/rpc.rs +++ b/server/src/rpc.rs @@ -1,6 +1,6 @@ use super::{ auth::{self, PeerExt as _}, - db::UserId, + db::{ChannelId, UserId}, AppState, }; use anyhow::anyhow; @@ -39,7 +39,7 @@ pub struct State { } struct ConnectionState { - _user_id: UserId, + user_id: UserId, worktrees: HashSet, } @@ -70,11 +70,11 @@ impl WorktreeState { impl State { // Add a new connection associated with a given user. - pub fn add_connection(&mut self, connection_id: ConnectionId, _user_id: UserId) { + pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId) { self.connections.insert( connection_id, ConnectionState { - _user_id, + user_id, worktrees: Default::default(), }, ); @@ -130,6 +130,14 @@ impl State { } } + fn user_id_for_connection(&self, connection_id: ConnectionId) -> tide::Result { + Ok(self + .connections + .get(&connection_id) + .ok_or_else(|| anyhow!("unknown connection"))? + .user_id) + } + fn read_worktree( &self, worktree_id: u64, @@ -254,6 +262,8 @@ pub fn add_rpc_routes(router: &mut Router, state: &Arc, rpc: &Arc>, rpc: &Arc) { @@ -600,6 +610,54 @@ async fn buffer_saved( broadcast_in_worktree(request.payload.worktree_id, request, rpc, state).await } +async fn get_channels( + request: TypedEnvelope, + rpc: &Arc, + state: &Arc, +) -> tide::Result<()> { + let user_id = state + .rpc + .read() + .await + .user_id_for_connection(request.sender_id)?; + let channels = state.db.get_channels_for_user(user_id).await?; + rpc.respond( + request.receipt(), + proto::GetChannelsResponse { + channels: channels + .into_iter() + .map(|chan| proto::Channel { + id: chan.id().0 as u64, + name: chan.name, + }) + .collect(), + }, + ) + .await?; + Ok(()) +} + +async fn join_channel( + request: TypedEnvelope, + rpc: &Arc, + state: &Arc, +) -> tide::Result<()> { + let user_id = state + .rpc + .read() + .await + .user_id_for_connection(request.sender_id)?; + if !state + .db + .can_user_access_channel(user_id, ChannelId(request.payload.channel_id as i32)) + .await? + { + Err(anyhow!("access denied"))?; + } + + Ok(()) +} + async fn broadcast_in_worktree( worktree_id: u64, request: TypedEnvelope, diff --git a/server/src/tests.rs b/server/src/tests.rs index d0257e9f41..653e2ae59a 100644 --- a/server/src/tests.rs +++ b/server/src/tests.rs @@ -1,5 +1,7 @@ use crate::{ - auth, db, github, + auth, + db::{self, UserId}, + github, rpc::{self, add_rpc_routes}, AppState, Config, }; @@ -31,8 +33,8 @@ async fn test_share_worktree(mut cx_a: TestAppContext, mut cx_b: TestAppContext) // Connect to a server as 2 clients. let mut server = TestServer::start().await; - let client_a = server.create_client(&mut cx_a, "user_a").await; - let client_b = server.create_client(&mut cx_b, "user_b").await; + let (_, client_a) = server.create_client(&mut cx_a, "user_a").await; + let (_, client_b) = server.create_client(&mut cx_b, "user_b").await; cx_a.foreground().forbid_parking(); @@ -138,9 +140,9 @@ async fn test_propagate_saves_and_fs_changes_in_shared_worktree( // Connect to a server as 3 clients. let mut server = TestServer::start().await; - let client_a = server.create_client(&mut cx_a, "user_a").await; - let client_b = server.create_client(&mut cx_b, "user_b").await; - let client_c = server.create_client(&mut cx_c, "user_c").await; + let (_, client_a) = server.create_client(&mut cx_a, "user_a").await; + let (_, client_b) = server.create_client(&mut cx_b, "user_b").await; + let (_, client_c) = server.create_client(&mut cx_c, "user_c").await; cx_a.foreground().forbid_parking(); @@ -280,8 +282,8 @@ async fn test_buffer_conflict_after_save(mut cx_a: TestAppContext, mut cx_b: Tes // Connect to a server as 2 clients. let mut server = TestServer::start().await; - let client_a = server.create_client(&mut cx_a, "user_a").await; - let client_b = server.create_client(&mut cx_b, "user_b").await; + let (_, client_a) = server.create_client(&mut cx_a, "user_a").await; + let (_, client_b) = server.create_client(&mut cx_b, "user_b").await; cx_a.foreground().forbid_parking(); @@ -359,8 +361,8 @@ async fn test_editing_while_guest_opens_buffer(mut cx_a: TestAppContext, mut cx_ // Connect to a server as 2 clients. let mut server = TestServer::start().await; - let client_a = server.create_client(&mut cx_a, "user_a").await; - let client_b = server.create_client(&mut cx_b, "user_b").await; + let (_, client_a) = server.create_client(&mut cx_a, "user_a").await; + let (_, client_b) = server.create_client(&mut cx_b, "user_b").await; cx_a.foreground().forbid_parking(); @@ -420,8 +422,8 @@ async fn test_peer_disconnection(mut cx_a: TestAppContext, cx_b: TestAppContext) // Connect to a server as 2 clients. let mut server = TestServer::start().await; - let client_a = server.create_client(&mut cx_a, "user_a").await; - let client_b = server.create_client(&mut cx_a, "user_b").await; + let (_, client_a) = server.create_client(&mut cx_a, "user_a").await; + let (_, client_b) = server.create_client(&mut cx_a, "user_b").await; cx_a.foreground().forbid_parking(); @@ -474,6 +476,36 @@ async fn test_peer_disconnection(mut cx_a: TestAppContext, cx_b: TestAppContext) .await; } +#[gpui::test] +async fn test_basic_chat(mut cx_a: TestAppContext, cx_b: TestAppContext) { + let lang_registry = Arc::new(LanguageRegistry::new()); + + // Connect to a server as 2 clients. + let mut server = TestServer::start().await; + let (user_id_a, client_a) = server.create_client(&mut cx_a, "user_a").await; + let (user_id_b, client_b) = server.create_client(&mut cx_a, "user_b").await; + + // Create a channel that includes these 2 users and 1 other user. + let db = &server.app_state.db; + let user_id_c = db.create_user("user_c", false).await.unwrap(); + let org_id = db.create_org("Test Org", "test-org").await.unwrap(); + let channel_id = db.create_org_channel(org_id, "test-channel").await.unwrap(); + db.add_channel_member(channel_id, user_id_a, false) + .await + .unwrap(); + db.add_channel_member(channel_id, user_id_b, false) + .await + .unwrap(); + db.add_channel_member(channel_id, user_id_c, false) + .await + .unwrap(); + db.create_channel_message(channel_id, user_id_c, "first message!") + .await + .unwrap(); + + // let chatroom_a = ChatRoom:: +} + struct TestServer { peer: Arc, app_state: Arc, @@ -497,7 +529,7 @@ impl TestServer { } } - async fn create_client(&mut self, cx: &mut TestAppContext, name: &str) -> Client { + async fn create_client(&mut self, cx: &mut TestAppContext, name: &str) -> (UserId, Client) { let user_id = self.app_state.db.create_user(name, false).await.unwrap(); let lang_registry = Arc::new(LanguageRegistry::new()); let client = Client::new(lang_registry.clone()); @@ -520,7 +552,7 @@ impl TestServer { .await .unwrap(); - client + (user_id, client) } async fn build_app_state(db_name: &str) -> Arc { diff --git a/zrpc/src/proto.rs b/zrpc/src/proto.rs index bb082c1783..77390cbb17 100644 --- a/zrpc/src/proto.rs +++ b/zrpc/src/proto.rs @@ -79,6 +79,11 @@ message!(UpdateBuffer); request_message!(SaveBuffer, BufferSaved); message!(AddPeer); message!(RemovePeer); +request_message!(GetChannels, GetChannelsResponse); +request_message!(JoinChannel, JoinChannelResponse); +request_message!(GetUsers, GetUsersResponse); +message!(SendChannelMessage); +message!(ChannelMessageSent); /// A stream of protobuf messages. pub struct MessageStream {