From 18b31f1552db6f7a25ae8011fe5ffd467095cd79 Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Mon, 8 Jan 2024 12:04:59 -0800 Subject: [PATCH] Check user is host for host-broadcasted project messages --- crates/collab/src/db/queries/projects.rs | 28 ++++++++ crates/collab/src/rpc.rs | 92 +++++------------------- crates/rpc/src/macros.rs | 4 +- crates/rpc/src/proto.rs | 5 +- 4 files changed, 51 insertions(+), 78 deletions(-) diff --git a/crates/collab/src/db/queries/projects.rs b/crates/collab/src/db/queries/projects.rs index ca59c851e7..04c77c8077 100644 --- a/crates/collab/src/db/queries/projects.rs +++ b/crates/collab/src/db/queries/projects.rs @@ -777,6 +777,34 @@ impl Database { .await } + pub async fn check_user_is_project_host( + &self, + project_id: ProjectId, + connection_id: ConnectionId, + ) -> Result<()> { + let room_id = self.room_id_for_project(project_id).await?; + self.room_transaction(room_id, |tx| async move { + project_collaborator::Entity::find() + .filter( + Condition::all() + .add(project_collaborator::Column::ProjectId.eq(project_id)) + .add(project_collaborator::Column::IsHost.eq(true)) + .add(project_collaborator::Column::ConnectionId.eq(connection_id.id)) + .add( + project_collaborator::Column::ConnectionServerId + .eq(connection_id.owner_id), + ), + ) + .one(&*tx) + .await? + .ok_or_else(|| anyhow!("failed to read project host"))?; + + Ok(()) + }) + .await + .map(|guard| guard.into_inner()) + } + pub async fn host_for_mutating_project_request( &self, project_id: ProjectId, diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 68774c22e6..572670d78f 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -42,7 +42,7 @@ use prometheus::{register_int_gauge, IntGauge}; use rpc::{ proto::{ self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo, - RequestMessage, UpdateChannelBufferCollaborators, + RequestMessage, ShareProject, UpdateChannelBufferCollaborators, }, Connection, ConnectionId, Peer, Receipt, TypedEnvelope, }; @@ -216,7 +216,6 @@ impl Server { .add_message_handler(update_language_server) .add_message_handler(update_diagnostic_summary) .add_message_handler(update_worktree_settings) - .add_message_handler(refresh_inlay_hints) .add_request_handler(forward_read_only_project_request::) .add_request_handler(forward_read_only_project_request::) .add_request_handler(forward_read_only_project_request::) @@ -251,9 +250,11 @@ impl Server { .add_request_handler(forward_mutating_project_request::) .add_message_handler(create_buffer_for_peer) .add_request_handler(update_buffer) - .add_message_handler(update_buffer_file) - .add_message_handler(buffer_reloaded) - .add_message_handler(buffer_saved) + .add_message_handler(broadcast_project_message_from_host::) + .add_message_handler(broadcast_project_message_from_host::) + .add_message_handler(broadcast_project_message_from_host::) + .add_message_handler(broadcast_project_message_from_host::) + .add_message_handler(broadcast_project_message_from_host::) .add_request_handler(get_users) .add_request_handler(fuzzy_search_users) .add_request_handler(request_contact) @@ -285,7 +286,6 @@ impl Server { .add_request_handler(follow) .add_message_handler(unfollow) .add_message_handler(update_followers) - .add_message_handler(update_diff_base) .add_request_handler(get_private_user_info) .add_message_handler(acknowledge_channel_message) .add_message_handler(acknowledge_buffer_version); @@ -1697,10 +1697,6 @@ async fn update_worktree_settings( Ok(()) } -async fn refresh_inlay_hints(request: proto::RefreshInlayHints, session: Session) -> Result<()> { - broadcast_project_message(request.project_id, request, session).await -} - async fn start_language_server( request: proto::StartLanguageServer, session: Session, @@ -1804,6 +1800,14 @@ async fn create_buffer_for_peer( request: proto::CreateBufferForPeer, session: Session, ) -> Result<()> { + session + .db() + .await + .check_user_is_project_host( + ProjectId::from_proto(request.project_id), + session.connection_id, + ) + .await?; let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?; session .peer @@ -1856,60 +1860,17 @@ async fn update_buffer( Ok(()) } -async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> { - let project_id = ProjectId::from_proto(request.project_id); - let project_connection_ids = session - .db() - .await - .project_connection_ids(project_id, session.connection_id) - .await?; - - broadcast( - Some(session.connection_id), - project_connection_ids.iter().copied(), - |connection_id| { - session - .peer - .forward_send(session.connection_id, connection_id, request.clone()) - }, - ); - Ok(()) -} - -async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> { - let project_id = ProjectId::from_proto(request.project_id); - let project_connection_ids = session - .db() - .await - .project_connection_ids(project_id, session.connection_id) - .await?; - broadcast( - Some(session.connection_id), - project_connection_ids.iter().copied(), - |connection_id| { - session - .peer - .forward_send(session.connection_id, connection_id, request.clone()) - }, - ); - Ok(()) -} - -async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> { - broadcast_project_message(request.project_id, request, session).await -} - -async fn broadcast_project_message( - project_id: u64, +async fn broadcast_project_message_from_host>( request: T, session: Session, ) -> Result<()> { - let project_id = ProjectId::from_proto(project_id); + let project_id = ProjectId::from_proto(request.remote_entity_id()); let project_connection_ids = session .db() .await .project_connection_ids(project_id, session.connection_id) .await?; + broadcast( Some(session.connection_id), project_connection_ids.iter().copied(), @@ -3138,25 +3099,6 @@ async fn mark_notification_as_read( Ok(()) } -async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> { - let project_id = ProjectId::from_proto(request.project_id); - let project_connection_ids = session - .db() - .await - .project_connection_ids(project_id, session.connection_id) - .await?; - broadcast( - Some(session.connection_id), - project_connection_ids.iter().copied(), - |connection_id| { - session - .peer - .forward_send(session.connection_id, connection_id, request.clone()) - }, - ); - Ok(()) -} - async fn get_private_user_info( _request: proto::GetPrivateUserInfo, response: Response, diff --git a/crates/rpc/src/macros.rs b/crates/rpc/src/macros.rs index 89e605540d..85e2b0cf87 100644 --- a/crates/rpc/src/macros.rs +++ b/crates/rpc/src/macros.rs @@ -60,8 +60,10 @@ macro_rules! request_messages { #[macro_export] macro_rules! entity_messages { - ($id_field:ident, $($name:ident),* $(,)?) => { + ({$id_field:ident, $entity_type:ty}, $($name:ident),* $(,)?) => { $(impl EntityMessage for $name { + type Entity = $entity_type; + fn remote_entity_id(&self) -> u64 { self.$id_field } diff --git a/crates/rpc/src/proto.rs b/crates/rpc/src/proto.rs index 336c252630..25b8b00dae 100644 --- a/crates/rpc/src/proto.rs +++ b/crates/rpc/src/proto.rs @@ -31,6 +31,7 @@ pub trait EnvelopedMessage: Clone + Debug + Serialize + Sized + Send + Sync + 's } pub trait EntityMessage: EnvelopedMessage { + type Entity; fn remote_entity_id(&self) -> u64; } @@ -369,7 +370,7 @@ request_messages!( ); entity_messages!( - project_id, + {project_id, ShareProject}, AddProjectCollaborator, ApplyCodeAction, ApplyCompletionAdditionalEdits, @@ -422,7 +423,7 @@ entity_messages!( ); entity_messages!( - channel_id, + {channel_id, Channel}, ChannelMessageSent, RemoveChannelMessage, UpdateChannelBuffer,