diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index d61cdd334d..e503188e1d 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -1886,6 +1886,64 @@ where .await } + pub async fn project_collaborators( + &self, + project_id: ProjectId, + connection_id: ConnectionId, + ) -> Result> { + self.transact(|mut tx| async move { + let collaborators = sqlx::query_as::<_, ProjectCollaborator>( + " + SELECT * + FROM project_collaborators + WHERE project_id = $1 + ", + ) + .bind(project_id) + .fetch_all(&mut tx) + .await?; + + if collaborators + .iter() + .any(|collaborator| collaborator.connection_id == connection_id.0 as i32) + { + Ok(collaborators) + } else { + Err(anyhow!("no such project"))? + } + }) + .await + } + + pub async fn project_connection_ids( + &self, + project_id: ProjectId, + connection_id: ConnectionId, + ) -> Result> { + self.transact(|mut tx| async move { + let connection_ids = sqlx::query_scalar::<_, i32>( + " + SELECT connection_id + FROM project_collaborators + WHERE project_id = $1 + ", + ) + .bind(project_id) + .fetch_all(&mut tx) + .await?; + + if connection_ids.contains(&(connection_id.0 as i32)) { + Ok(connection_ids + .into_iter() + .map(|connection_id| ConnectionId(connection_id as u32)) + .collect()) + } else { + Err(anyhow!("no such project"))? + } + }) + .await + } + pub async fn unshare_project(&self, project_id: ProjectId) -> Result<()> { todo!() // test_support!(self, { diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 1943f18ceb..f0116f04f9 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -1187,13 +1187,15 @@ impl Server { self: Arc, request: Message, ) -> Result<()> { - let receiver_ids = self.store().await.project_connection_ids( - ProjectId::from_proto(request.payload.project_id), - request.sender_connection_id, - )?; + let project_id = ProjectId::from_proto(request.payload.project_id); + let project_connection_ids = self + .app_state + .db + .project_connection_ids(project_id, request.sender_connection_id) + .await?; broadcast( request.sender_connection_id, - receiver_ids, + project_connection_ids, |connection_id| { self.peer.forward_send( request.sender_connection_id, @@ -1214,25 +1216,25 @@ impl Server { T: EntityMessage + RequestMessage, { let project_id = ProjectId::from_proto(request.payload.remote_entity_id()); - let host_connection_id = self - .store() - .await - .read_project(project_id, request.sender_connection_id)? - .host_connection_id; + let collaborators = self + .app_state + .db + .project_collaborators(project_id, request.sender_connection_id) + .await?; + let host = collaborators + .iter() + .find(|collaborator| collaborator.is_host) + .ok_or_else(|| anyhow!("host not found"))?; + let payload = self .peer .forward_request( request.sender_connection_id, - host_connection_id, + ConnectionId(host.connection_id as u32), request.payload, ) .await?; - // Ensure project still exists by the time we get the response from the host. - self.store() - .await - .read_project(project_id, request.sender_connection_id)?; - response.send(payload)?; Ok(()) } @@ -1243,25 +1245,39 @@ impl Server { response: Response, ) -> Result<()> { let project_id = ProjectId::from_proto(request.payload.project_id); - let host = self - .store() - .await - .read_project(project_id, request.sender_connection_id)? - .host_connection_id; + let collaborators = self + .app_state + .db + .project_collaborators(project_id, request.sender_connection_id) + .await?; + let host = collaborators + .into_iter() + .find(|collaborator| collaborator.is_host) + .ok_or_else(|| anyhow!("host not found"))?; + let host_connection_id = ConnectionId(host.connection_id as u32); let response_payload = self .peer - .forward_request(request.sender_connection_id, host, request.payload.clone()) + .forward_request( + request.sender_connection_id, + host_connection_id, + request.payload.clone(), + ) .await?; - let mut guests = self - .store() - .await - .read_project(project_id, request.sender_connection_id)? - .connection_ids(); - guests.retain(|guest_connection_id| *guest_connection_id != request.sender_connection_id); - broadcast(host, guests, |conn_id| { + let mut collaborators = self + .app_state + .db + .project_collaborators(project_id, request.sender_connection_id) + .await?; + collaborators.retain(|collaborator| { + collaborator.connection_id != request.sender_connection_id.0 as i32 + }); + let project_connection_ids = collaborators + .into_iter() + .map(|collaborator| ConnectionId(collaborator.connection_id as u32)); + broadcast(host_connection_id, project_connection_ids, |conn_id| { self.peer - .forward_send(host, conn_id, response_payload.clone()) + .forward_send(host_connection_id, conn_id, response_payload.clone()) }); response.send(response_payload)?; Ok(()) @@ -1285,14 +1301,15 @@ impl Server { response: Response, ) -> Result<()> { let project_id = ProjectId::from_proto(request.payload.project_id); - let receiver_ids = { - let store = self.store().await; - store.project_connection_ids(project_id, request.sender_connection_id)? - }; + let project_connection_ids = self + .app_state + .db + .project_connection_ids(project_id, request.sender_connection_id) + .await?; broadcast( request.sender_connection_id, - receiver_ids, + project_connection_ids, |connection_id| { self.peer.forward_send( request.sender_connection_id, @@ -1309,13 +1326,16 @@ impl Server { self: Arc, request: Message, ) -> Result<()> { - let receiver_ids = self.store().await.project_connection_ids( - ProjectId::from_proto(request.payload.project_id), - request.sender_connection_id, - )?; + let project_id = ProjectId::from_proto(request.payload.project_id); + let project_connection_ids = self + .app_state + .db + .project_connection_ids(project_id, request.sender_connection_id) + .await?; + broadcast( request.sender_connection_id, - receiver_ids, + project_connection_ids, |connection_id| { self.peer.forward_send( request.sender_connection_id, @@ -1331,13 +1351,15 @@ impl Server { self: Arc, request: Message, ) -> Result<()> { - let receiver_ids = self.store().await.project_connection_ids( - ProjectId::from_proto(request.payload.project_id), - request.sender_connection_id, - )?; + let project_id = ProjectId::from_proto(request.payload.project_id); + let project_connection_ids = self + .app_state + .db + .project_connection_ids(project_id, request.sender_connection_id) + .await?; broadcast( request.sender_connection_id, - receiver_ids, + project_connection_ids, |connection_id| { self.peer.forward_send( request.sender_connection_id, @@ -1350,13 +1372,15 @@ impl Server { } async fn buffer_saved(self: Arc, request: Message) -> Result<()> { - let receiver_ids = self.store().await.project_connection_ids( - ProjectId::from_proto(request.payload.project_id), - request.sender_connection_id, - )?; + let project_id = ProjectId::from_proto(request.payload.project_id); + let project_connection_ids = self + .app_state + .db + .project_connection_ids(project_id, request.sender_connection_id) + .await?; broadcast( request.sender_connection_id, - receiver_ids, + project_connection_ids, |connection_id| { self.peer.forward_send( request.sender_connection_id, @@ -1376,14 +1400,14 @@ impl Server { let project_id = ProjectId::from_proto(request.payload.project_id); let leader_id = ConnectionId(request.payload.leader_id); let follower_id = request.sender_connection_id; - { - let store = self.store().await; - if !store - .project_connection_ids(project_id, follower_id)? - .contains(&leader_id) - { - Err(anyhow!("no such peer"))?; - } + let project_connection_ids = self + .app_state + .db + .project_connection_ids(project_id, request.sender_connection_id) + .await?; + + if !project_connection_ids.contains(&leader_id) { + Err(anyhow!("no such peer"))?; } let mut response_payload = self @@ -1400,11 +1424,12 @@ impl Server { async fn unfollow(self: Arc, request: Message) -> Result<()> { let project_id = ProjectId::from_proto(request.payload.project_id); let leader_id = ConnectionId(request.payload.leader_id); - let store = self.store().await; - if !store - .project_connection_ids(project_id, request.sender_connection_id)? - .contains(&leader_id) - { + let project_connection_ids = self + .app_state + .db + .project_connection_ids(project_id, request.sender_connection_id) + .await?; + if !project_connection_ids.contains(&leader_id) { Err(anyhow!("no such peer"))?; } self.peer @@ -1417,9 +1442,12 @@ impl Server { request: Message, ) -> Result<()> { let project_id = ProjectId::from_proto(request.payload.project_id); - let store = self.store().await; - let connection_ids = - store.project_connection_ids(project_id, request.sender_connection_id)?; + let project_connection_ids = self + .app_state + .db + .project_connection_ids(project_id, request.sender_connection_id) + .await?; + let leader_id = request .payload .variant @@ -1431,7 +1459,7 @@ impl Server { }); for follower_id in &request.payload.follower_ids { let follower_id = ConnectionId(*follower_id); - if connection_ids.contains(&follower_id) && Some(follower_id.0) != leader_id { + if project_connection_ids.contains(&follower_id) && Some(follower_id.0) != leader_id { self.peer.forward_send( request.sender_connection_id, follower_id, @@ -1629,13 +1657,15 @@ impl Server { self: Arc, request: Message, ) -> Result<()> { - let receiver_ids = self.store().await.project_connection_ids( - ProjectId::from_proto(request.payload.project_id), - request.sender_connection_id, - )?; + let project_id = ProjectId::from_proto(request.payload.project_id); + let project_connection_ids = self + .app_state + .db + .project_connection_ids(project_id, request.sender_connection_id) + .await?; broadcast( request.sender_connection_id, - receiver_ids, + project_connection_ids, |connection_id| { self.peer.forward_send( request.sender_connection_id, diff --git a/crates/collab/src/rpc/store.rs b/crates/collab/src/rpc/store.rs index e3abc8dd3c..f694440a50 100644 --- a/crates/collab/src/rpc/store.rs +++ b/crates/collab/src/rpc/store.rs @@ -325,34 +325,6 @@ impl Store { }) } - pub fn project_connection_ids( - &self, - project_id: ProjectId, - acting_connection_id: ConnectionId, - ) -> Result> { - Ok(self - .read_project(project_id, acting_connection_id)? - .connection_ids()) - } - - pub fn read_project( - &self, - project_id: ProjectId, - connection_id: ConnectionId, - ) -> Result<&Project> { - let project = self - .projects - .get(&project_id) - .ok_or_else(|| anyhow!("no such project"))?; - if project.host_connection_id == connection_id - || project.guests.contains_key(&connection_id) - { - Ok(project) - } else { - Err(anyhow!("no such project"))? - } - } - #[cfg(test)] pub fn check_invariants(&self) { for (connection_id, connection) in &self.connections {