diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 175e3604c0..ba97b09acd 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -39,6 +39,7 @@ use rpc::{ use serde::{Serialize, Serializer}; use std::{ any::TypeId, + fmt, future::Future, marker::PhantomData, net::SocketAddr, @@ -67,20 +68,63 @@ lazy_static! { .unwrap(); } -type MessageHandler = Box< - dyn Send + Sync + Fn(Arc, Box, Session) -> BoxFuture<'static, ()>, ->; +type MessageHandler = + Box, Session) -> BoxFuture<'static, ()>>; struct Response { - server: Arc, + peer: Arc, receipt: Receipt, responded: Arc, } +impl Response { + fn send(self, payload: R::Response) -> Result<()> { + self.responded.store(true, SeqCst); + self.peer.respond(self.receipt, payload)?; + Ok(()) + } +} + +#[derive(Clone)] struct Session { user_id: UserId, connection_id: ConnectionId, db: Arc>, + peer: Arc, + connection_pool: Arc>, + live_kit_client: Option>, +} + +impl Session { + async fn db(&self) -> MutexGuard { + #[cfg(test)] + tokio::task::yield_now().await; + let guard = self.db.lock().await; + #[cfg(test)] + tokio::task::yield_now().await; + guard + } + + async fn connection_pool(&self) -> ConnectionPoolGuard<'_> { + #[cfg(test)] + tokio::task::yield_now().await; + let guard = self.connection_pool.lock().await; + #[cfg(test)] + tokio::task::yield_now().await; + ConnectionPoolGuard { + guard, + _not_send: PhantomData, + } + } +} + +impl fmt::Debug for Session { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Session") + .field("user_id", &self.user_id) + .field("connection_id", &self.connection_id) + .finish() + } } struct DbHandle(Arc); @@ -93,17 +137,9 @@ impl Deref for DbHandle { } } -impl Response { - fn send(self, payload: R::Response) -> Result<()> { - self.responded.store(true, SeqCst); - self.server.peer.respond(self.receipt, payload)?; - Ok(()) - } -} - pub struct Server { peer: Arc, - pub(crate) connection_pool: Mutex, + pub(crate) connection_pool: Arc>, app_state: Arc, handlers: HashMap, } @@ -148,76 +184,74 @@ impl Server { }; server - .add_request_handler(Server::ping) - .add_request_handler(Server::create_room) - .add_request_handler(Server::join_room) - .add_message_handler(Server::leave_room) - .add_request_handler(Server::call) - .add_request_handler(Server::cancel_call) - .add_message_handler(Server::decline_call) - .add_request_handler(Server::update_participant_location) - .add_request_handler(Server::share_project) - .add_message_handler(Server::unshare_project) - .add_request_handler(Server::join_project) - .add_message_handler(Server::leave_project) - .add_request_handler(Server::update_project) - .add_request_handler(Server::update_worktree) - .add_message_handler(Server::start_language_server) - .add_message_handler(Server::update_language_server) - .add_request_handler(Server::update_diagnostic_summary) - .add_request_handler(Server::forward_project_request::) - .add_request_handler(Server::forward_project_request::) - .add_request_handler(Server::forward_project_request::) - .add_request_handler(Server::forward_project_request::) - .add_request_handler(Server::forward_project_request::) - .add_request_handler(Server::forward_project_request::) - .add_request_handler(Server::forward_project_request::) - .add_request_handler(Server::forward_project_request::) - .add_request_handler(Server::forward_project_request::) - .add_request_handler(Server::forward_project_request::) - .add_request_handler(Server::forward_project_request::) - .add_request_handler( - Server::forward_project_request::, - ) - .add_request_handler(Server::forward_project_request::) - .add_request_handler(Server::forward_project_request::) - .add_request_handler(Server::forward_project_request::) - .add_request_handler(Server::forward_project_request::) - .add_request_handler(Server::forward_project_request::) - .add_request_handler(Server::forward_project_request::) - .add_request_handler(Server::forward_project_request::) - .add_request_handler(Server::forward_project_request::) - .add_request_handler(Server::forward_project_request::) - .add_request_handler(Server::forward_project_request::) - .add_message_handler(Server::create_buffer_for_peer) - .add_request_handler(Server::update_buffer) - .add_message_handler(Server::update_buffer_file) - .add_message_handler(Server::buffer_reloaded) - .add_message_handler(Server::buffer_saved) - .add_request_handler(Server::save_buffer) - .add_request_handler(Server::get_users) - .add_request_handler(Server::fuzzy_search_users) - .add_request_handler(Server::request_contact) - .add_request_handler(Server::remove_contact) - .add_request_handler(Server::respond_to_contact_request) - .add_request_handler(Server::follow) - .add_message_handler(Server::unfollow) - .add_message_handler(Server::update_followers) - .add_message_handler(Server::update_diff_base) - .add_request_handler(Server::get_private_user_info); + .add_request_handler(ping) + .add_request_handler(create_room) + .add_request_handler(join_room) + .add_message_handler(leave_room) + .add_request_handler(call) + .add_request_handler(cancel_call) + .add_message_handler(decline_call) + .add_request_handler(update_participant_location) + .add_request_handler(share_project) + .add_message_handler(unshare_project) + .add_request_handler(join_project) + .add_message_handler(leave_project) + .add_request_handler(update_project) + .add_request_handler(update_worktree) + .add_message_handler(start_language_server) + .add_message_handler(update_language_server) + .add_request_handler(update_diagnostic_summary) + .add_request_handler(forward_project_request::) + .add_request_handler(forward_project_request::) + .add_request_handler(forward_project_request::) + .add_request_handler(forward_project_request::) + .add_request_handler(forward_project_request::) + .add_request_handler(forward_project_request::) + .add_request_handler(forward_project_request::) + .add_request_handler(forward_project_request::) + .add_request_handler(forward_project_request::) + .add_request_handler(forward_project_request::) + .add_request_handler(forward_project_request::) + .add_request_handler(forward_project_request::) + .add_request_handler(forward_project_request::) + .add_request_handler(forward_project_request::) + .add_request_handler(forward_project_request::) + .add_request_handler(forward_project_request::) + .add_request_handler(forward_project_request::) + .add_request_handler(forward_project_request::) + .add_request_handler(forward_project_request::) + .add_request_handler(forward_project_request::) + .add_request_handler(forward_project_request::) + .add_request_handler(forward_project_request::) + .add_message_handler(create_buffer_for_peer) + .add_request_handler(update_buffer) + .add_message_handler(update_buffer_file) + .add_message_handler(buffer_reloaded) + .add_message_handler(buffer_saved) + .add_request_handler(save_buffer) + .add_request_handler(get_users) + .add_request_handler(fuzzy_search_users) + .add_request_handler(request_contact) + .add_request_handler(remove_contact) + .add_request_handler(respond_to_contact_request) + .add_request_handler(follow) + .add_message_handler(unfollow) + .add_message_handler(update_followers) + .add_message_handler(update_diff_base) + .add_request_handler(get_private_user_info); Arc::new(server) } fn add_handler(&mut self, handler: F) -> &mut Self where - F: 'static + Send + Sync + Fn(Arc, TypedEnvelope, Session) -> Fut, + F: 'static + Send + Sync + Fn(TypedEnvelope, Session) -> Fut, Fut: 'static + Send + Future>, M: EnvelopedMessage, { let prev_handler = self.handlers.insert( TypeId::of::(), - Box::new(move |server, envelope, session| { + Box::new(move |envelope, session| { let envelope = envelope.into_any().downcast::>().unwrap(); let span = info_span!( "handle message", @@ -229,7 +263,7 @@ impl Server { "message received" ); }); - let future = (handler)(server, *envelope, session); + let future = (handler)(*envelope, session); async move { if let Err(error) = future.await { tracing::error!(%error, "error handling message"); @@ -247,34 +281,33 @@ impl Server { fn add_message_handler(&mut self, handler: F) -> &mut Self where - F: 'static + Send + Sync + Fn(Arc, M, Session) -> Fut, + F: 'static + Send + Sync + Fn(M, Session) -> Fut, Fut: 'static + Send + Future>, M: EnvelopedMessage, { - self.add_handler(move |server, envelope, session| { - handler(server, envelope.payload, session) - }); + self.add_handler(move |envelope, session| handler(envelope.payload, session)); self } fn add_request_handler(&mut self, handler: F) -> &mut Self where - F: 'static + Send + Sync + Fn(Arc, M, Response, Session) -> Fut, + F: 'static + Send + Sync + Fn(M, Response, Session) -> Fut, Fut: Send + Future>, M: RequestMessage, { let handler = Arc::new(handler); - self.add_handler(move |server, envelope, session| { + self.add_handler(move |envelope, session| { let receipt = envelope.receipt(); let handler = handler.clone(); async move { + let peer = session.peer.clone(); let responded = Arc::new(AtomicBool::default()); let response = Response { - server: server.clone(), + peer: peer.clone(), responded: responded.clone(), receipt, }; - match (handler)(server.clone(), envelope.payload, response, session).await { + match (handler)(envelope.payload, response, session).await { Ok(()) => { if responded.load(std::sync::atomic::Ordering::SeqCst) { Ok(()) @@ -283,7 +316,7 @@ impl Server { } } Err(error) => { - server.peer.respond_with_error( + peer.respond_with_error( receipt, proto::Error { message: error.to_string(), @@ -304,7 +337,7 @@ impl Server { mut send_connection_id: Option>, executor: E, ) -> impl Future> { - let mut this = self.clone(); + let this = self.clone(); let user_id = user.id; let login = user.github_login; let span = info_span!("handle connection", %user_id, %login, %address); @@ -340,7 +373,7 @@ impl Server { ).await?; { - let mut pool = this.connection_pool().await; + let mut pool = this.connection_pool.lock().await; pool.add_connection(connection_id, user_id, user.admin); this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?; @@ -356,13 +389,19 @@ impl Server { this.peer.send(connection_id, incoming_call)?; } - this.update_user_contacts(user_id).await?; + let session = Session { + user_id, + connection_id, + db: Arc::new(Mutex::new(DbHandle(this.app_state.db.clone()))), + peer: this.peer.clone(), + connection_pool: this.connection_pool.clone(), + live_kit_client: this.app_state.live_kit_client.clone() + }; + update_user_contacts(user_id, &session).await?; let handle_io = handle_io.fuse(); futures::pin_mut!(handle_io); - let db = Arc::new(Mutex::new(DbHandle(this.app_state.db.clone()))); - // Handlers for foreground messages are pushed into the following `FuturesUnordered`. // This prevents deadlocks when e.g., client A performs a request to client B and // client B performs a request to client A. If both clients stop processing further @@ -390,12 +429,7 @@ impl Server { let span_enter = span.enter(); if let Some(handler) = this.handlers.get(&message.payload_type_id()) { let is_background = message.is_background(); - let session = Session { - user_id, - connection_id, - db: db.clone(), - }; - let handle_message = (handler)(this.clone(), message, session); + let handle_message = (handler)(message, session.clone()); drop(span_enter); let handle_message = handle_message.instrument(span); @@ -417,7 +451,7 @@ impl Server { drop(foreground_message_handlers); tracing::info!(%user_id, %login, %connection_id, %address, "signing out"); - if let Err(error) = this.sign_out(connection_id, user_id).await { + if let Err(error) = sign_out(session).await { tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out"); } @@ -425,40 +459,6 @@ impl Server { }.instrument(span) } - #[instrument(skip(self), err)] - async fn sign_out( - self: &mut Arc, - connection_id: ConnectionId, - user_id: UserId, - ) -> Result<()> { - self.peer.disconnect(connection_id); - let decline_calls = { - let mut pool = self.connection_pool().await; - pool.remove_connection(connection_id)?; - let mut connections = pool.user_connection_ids(user_id); - connections.next().is_none() - }; - - self.leave_room_for_connection(connection_id, user_id) - .await - .trace_err(); - if decline_calls { - if let Some(room) = self - .app_state - .db - .decline_call(None, user_id) - .await - .trace_err() - { - self.room_updated(&room); - } - } - - self.update_user_contacts(user_id).await?; - - Ok(()) - } - pub async fn invite_code_redeemed( self: &Arc, inviter_id: UserId, @@ -466,7 +466,7 @@ impl Server { ) -> Result<()> { if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? { if let Some(code) = &user.invite_code { - let pool = self.connection_pool().await; + let pool = self.connection_pool.lock().await; let invitee_contact = contact_for_user(invitee_id, true, false, &pool); for connection_id in pool.user_connection_ids(inviter_id) { self.peer.send( @@ -492,7 +492,7 @@ impl Server { pub async fn invite_count_updated(self: &Arc, user_id: UserId) -> Result<()> { if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? { if let Some(invite_code) = &user.invite_code { - let pool = self.connection_pool().await; + let pool = self.connection_pool.lock().await; for connection_id in pool.user_connection_ids(user_id) { self.peer.send( connection_id, @@ -510,1215 +510,12 @@ impl Server { Ok(()) } - async fn ping( - self: Arc, - _: proto::Ping, - response: Response, - _session: Session, - ) -> Result<()> { - response.send(proto::Ack {})?; - Ok(()) - } - - async fn create_room( - self: Arc, - _request: proto::CreateRoom, - response: Response, - session: Session, - ) -> Result<()> { - let room = self - .app_state - .db - .create_room(session.user_id, session.connection_id) - .await?; - - let live_kit_connection_info = - if let Some(live_kit) = self.app_state.live_kit_client.as_ref() { - if let Some(_) = live_kit - .create_room(room.live_kit_room.clone()) - .await - .trace_err() - { - if let Some(token) = live_kit - .room_token(&room.live_kit_room, &session.connection_id.to_string()) - .trace_err() - { - Some(proto::LiveKitConnectionInfo { - server_url: live_kit.url().into(), - token, - }) - } else { - None - } - } else { - None - } - } else { - None - }; - - response.send(proto::CreateRoomResponse { - room: Some(room), - live_kit_connection_info, - })?; - self.update_user_contacts(session.user_id).await?; - Ok(()) - } - - async fn join_room( - self: Arc, - request: proto::JoinRoom, - response: Response, - session: Session, - ) -> Result<()> { - let room = self - .app_state - .db - .join_room( - RoomId::from_proto(request.id), - session.user_id, - session.connection_id, - ) - .await?; - for connection_id in self - .connection_pool() - .await - .user_connection_ids(session.user_id) - { - self.peer - .send(connection_id, proto::CallCanceled {}) - .trace_err(); - } - - let live_kit_connection_info = - if let Some(live_kit) = self.app_state.live_kit_client.as_ref() { - if let Some(token) = live_kit - .room_token(&room.live_kit_room, &session.connection_id.to_string()) - .trace_err() - { - Some(proto::LiveKitConnectionInfo { - server_url: live_kit.url().into(), - token, - }) - } else { - None - } - } else { - None - }; - - self.room_updated(&room); - response.send(proto::JoinRoomResponse { - room: Some(room), - live_kit_connection_info, - })?; - - self.update_user_contacts(session.user_id).await?; - Ok(()) - } - - async fn leave_room( - self: Arc, - _message: proto::LeaveRoom, - session: Session, - ) -> Result<()> { - self.leave_room_for_connection(session.connection_id, session.user_id) - .await - } - - async fn leave_room_for_connection( - self: &Arc, - leaving_connection_id: ConnectionId, - leaving_user_id: UserId, - ) -> Result<()> { - let mut contacts_to_update = HashSet::default(); - - let Some(left_room) = self.app_state.db.leave_room(leaving_connection_id).await? else { - return Err(anyhow!("no room to leave"))?; - }; - contacts_to_update.insert(leaving_user_id); - - for project in left_room.left_projects.into_values() { - for connection_id in project.connection_ids { - if project.host_user_id == leaving_user_id { - self.peer - .send( - connection_id, - proto::UnshareProject { - project_id: project.id.to_proto(), - }, - ) - .trace_err(); - } else { - self.peer - .send( - connection_id, - proto::RemoveProjectCollaborator { - project_id: project.id.to_proto(), - peer_id: leaving_connection_id.0, - }, - ) - .trace_err(); - } - } - - self.peer - .send( - leaving_connection_id, - proto::UnshareProject { - project_id: project.id.to_proto(), - }, - ) - .trace_err(); - } - - self.room_updated(&left_room.room); - { - let pool = self.connection_pool().await; - for canceled_user_id in left_room.canceled_calls_to_user_ids { - for connection_id in pool.user_connection_ids(canceled_user_id) { - self.peer - .send(connection_id, proto::CallCanceled {}) - .trace_err(); - } - contacts_to_update.insert(canceled_user_id); - } - } - - for contact_user_id in contacts_to_update { - self.update_user_contacts(contact_user_id).await?; - } - - if let Some(live_kit) = self.app_state.live_kit_client.as_ref() { - live_kit - .remove_participant( - left_room.room.live_kit_room.clone(), - leaving_connection_id.to_string(), - ) - .await - .trace_err(); - - if left_room.room.participants.is_empty() { - live_kit - .delete_room(left_room.room.live_kit_room) - .await - .trace_err(); - } - } - - Ok(()) - } - - async fn call( - self: Arc, - request: proto::Call, - response: Response, - session: Session, - ) -> Result<()> { - let room_id = RoomId::from_proto(request.room_id); - let calling_user_id = session.user_id; - let calling_connection_id = session.connection_id; - let called_user_id = UserId::from_proto(request.called_user_id); - let initial_project_id = request.initial_project_id.map(ProjectId::from_proto); - if !self - .app_state - .db - .has_contact(calling_user_id, called_user_id) - .await? - { - return Err(anyhow!("cannot call a user who isn't a contact"))?; - } - - let (room, incoming_call) = self - .app_state - .db - .call( - room_id, - calling_user_id, - calling_connection_id, - called_user_id, - initial_project_id, - ) - .await?; - self.room_updated(&room); - self.update_user_contacts(called_user_id).await?; - - let mut calls = self - .connection_pool() - .await - .user_connection_ids(called_user_id) - .map(|connection_id| self.peer.request(connection_id, incoming_call.clone())) - .collect::>(); - - while let Some(call_response) = calls.next().await { - match call_response.as_ref() { - Ok(_) => { - response.send(proto::Ack {})?; - return Ok(()); - } - Err(_) => { - call_response.trace_err(); - } - } - } - - let room = self - .app_state - .db - .call_failed(room_id, called_user_id) - .await?; - self.room_updated(&room); - self.update_user_contacts(called_user_id).await?; - - Err(anyhow!("failed to ring user"))? - } - - async fn cancel_call( - self: Arc, - request: proto::CancelCall, - response: Response, - session: Session, - ) -> Result<()> { - let called_user_id = UserId::from_proto(request.called_user_id); - let room_id = RoomId::from_proto(request.room_id); - let room = self - .app_state - .db - .cancel_call(Some(room_id), session.connection_id, called_user_id) - .await?; - for connection_id in self - .connection_pool() - .await - .user_connection_ids(called_user_id) - { - self.peer - .send(connection_id, proto::CallCanceled {}) - .trace_err(); - } - self.room_updated(&room); - response.send(proto::Ack {})?; - - self.update_user_contacts(called_user_id).await?; - Ok(()) - } - - async fn decline_call( - self: Arc, - message: proto::DeclineCall, - session: Session, - ) -> Result<()> { - let room_id = RoomId::from_proto(message.room_id); - let room = self - .app_state - .db - .decline_call(Some(room_id), session.user_id) - .await?; - for connection_id in self - .connection_pool() - .await - .user_connection_ids(session.user_id) - { - self.peer - .send(connection_id, proto::CallCanceled {}) - .trace_err(); - } - self.room_updated(&room); - self.update_user_contacts(session.user_id).await?; - Ok(()) - } - - async fn update_participant_location( - self: Arc, - request: proto::UpdateParticipantLocation, - response: Response, - session: Session, - ) -> Result<()> { - let room_id = RoomId::from_proto(request.room_id); - let location = request - .location - .ok_or_else(|| anyhow!("invalid location"))?; - let room = self - .app_state - .db - .update_room_participant_location(room_id, session.connection_id, location) - .await?; - self.room_updated(&room); - response.send(proto::Ack {})?; - Ok(()) - } - - fn room_updated(&self, room: &proto::Room) { - for participant in &room.participants { - self.peer - .send( - ConnectionId(participant.peer_id), - proto::RoomUpdated { - room: Some(room.clone()), - }, - ) - .trace_err(); - } - } - - async fn share_project( - self: Arc, - request: proto::ShareProject, - response: Response, - session: Session, - ) -> Result<()> { - let (project_id, room) = self - .app_state - .db - .share_project( - RoomId::from_proto(request.room_id), - session.connection_id, - &request.worktrees, - ) - .await?; - response.send(proto::ShareProjectResponse { - project_id: project_id.to_proto(), - })?; - self.room_updated(&room); - - Ok(()) - } - - async fn unshare_project( - self: Arc, - message: proto::UnshareProject, - session: Session, - ) -> Result<()> { - let project_id = ProjectId::from_proto(message.project_id); - - let (room, guest_connection_ids) = self - .app_state - .db - .unshare_project(project_id, session.connection_id) - .await?; - - broadcast(session.connection_id, guest_connection_ids, |conn_id| { - self.peer.send(conn_id, message.clone()) - }); - self.room_updated(&room); - - Ok(()) - } - - async fn update_user_contacts(self: &Arc, user_id: UserId) -> Result<()> { - let contacts = self.app_state.db.get_contacts(user_id).await?; - let busy = self.app_state.db.is_user_busy(user_id).await?; - let pool = self.connection_pool().await; - let updated_contact = contact_for_user(user_id, false, busy, &pool); - for contact in contacts { - if let db::Contact::Accepted { - user_id: contact_user_id, - .. - } = contact - { - for contact_conn_id in pool.user_connection_ids(contact_user_id) { - self.peer - .send( - contact_conn_id, - proto::UpdateContacts { - contacts: vec![updated_contact.clone()], - remove_contacts: Default::default(), - incoming_requests: Default::default(), - remove_incoming_requests: Default::default(), - outgoing_requests: Default::default(), - remove_outgoing_requests: Default::default(), - }, - ) - .trace_err(); - } - } - } - Ok(()) - } - - async fn join_project( - self: Arc, - request: proto::JoinProject, - response: Response, - session: Session, - ) -> Result<()> { - let project_id = ProjectId::from_proto(request.project_id); - let guest_user_id = session.user_id; - - tracing::info!(%project_id, "join project"); - - let (project, replica_id) = self - .app_state - .db - .join_project(project_id, session.connection_id) - .await?; - - let collaborators = project - .collaborators - .iter() - .filter(|collaborator| collaborator.connection_id != session.connection_id.0 as i32) - .map(|collaborator| proto::Collaborator { - peer_id: collaborator.connection_id as u32, - replica_id: collaborator.replica_id.0 as u32, - user_id: collaborator.user_id.to_proto(), - }) - .collect::>(); - let worktrees = project - .worktrees - .iter() - .map(|(id, worktree)| proto::WorktreeMetadata { - id: id.to_proto(), - root_name: worktree.root_name.clone(), - visible: worktree.visible, - abs_path: worktree.abs_path.clone(), - }) - .collect::>(); - - for collaborator in &collaborators { - self.peer - .send( - ConnectionId(collaborator.peer_id), - proto::AddProjectCollaborator { - project_id: project_id.to_proto(), - collaborator: Some(proto::Collaborator { - peer_id: session.connection_id.0, - replica_id: replica_id.0 as u32, - user_id: guest_user_id.to_proto(), - }), - }, - ) - .trace_err(); - } - - // First, we send the metadata associated with each worktree. - response.send(proto::JoinProjectResponse { - worktrees: worktrees.clone(), - replica_id: replica_id.0 as u32, - collaborators: collaborators.clone(), - language_servers: project.language_servers.clone(), - })?; - - for (worktree_id, worktree) in project.worktrees { - #[cfg(any(test, feature = "test-support"))] - const MAX_CHUNK_SIZE: usize = 2; - #[cfg(not(any(test, feature = "test-support")))] - const MAX_CHUNK_SIZE: usize = 256; - - // Stream this worktree's entries. - let message = proto::UpdateWorktree { - project_id: project_id.to_proto(), - worktree_id: worktree_id.to_proto(), - abs_path: worktree.abs_path.clone(), - root_name: worktree.root_name, - updated_entries: worktree.entries, - removed_entries: Default::default(), - scan_id: worktree.scan_id, - is_last_update: worktree.is_complete, - }; - for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) { - self.peer.send(session.connection_id, update.clone())?; - } - - // Stream this worktree's diagnostics. - for summary in worktree.diagnostic_summaries { - self.peer.send( - session.connection_id, - proto::UpdateDiagnosticSummary { - project_id: project_id.to_proto(), - worktree_id: worktree.id.to_proto(), - summary: Some(summary), - }, - )?; - } - } - - for language_server in &project.language_servers { - self.peer.send( - session.connection_id, - proto::UpdateLanguageServer { - project_id: project_id.to_proto(), - language_server_id: language_server.id, - variant: Some( - proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated( - proto::LspDiskBasedDiagnosticsUpdated {}, - ), - ), - }, - )?; - } - - Ok(()) - } - - async fn leave_project( - self: Arc, - request: proto::LeaveProject, - session: Session, - ) -> Result<()> { - let sender_id = session.connection_id; - let project_id = ProjectId::from_proto(request.project_id); - let project; - { - project = self - .app_state - .db - .leave_project(project_id, sender_id) - .await?; - tracing::info!( - %project_id, - host_user_id = %project.host_user_id, - host_connection_id = %project.host_connection_id, - "leave project" - ); - - broadcast(sender_id, project.connection_ids, |conn_id| { - self.peer.send( - conn_id, - proto::RemoveProjectCollaborator { - project_id: project_id.to_proto(), - peer_id: sender_id.0, - }, - ) - }); - } - - Ok(()) - } - - async fn update_project( - self: Arc, - request: proto::UpdateProject, - response: Response, - session: Session, - ) -> Result<()> { - let project_id = ProjectId::from_proto(request.project_id); - let (room, guest_connection_ids) = self - .app_state - .db - .update_project(project_id, session.connection_id, &request.worktrees) - .await?; - broadcast( - session.connection_id, - guest_connection_ids, - |connection_id| { - self.peer - .forward_send(session.connection_id, connection_id, request.clone()) - }, - ); - self.room_updated(&room); - response.send(proto::Ack {})?; - - Ok(()) - } - - async fn update_worktree( - self: Arc, - request: proto::UpdateWorktree, - response: Response, - session: Session, - ) -> Result<()> { - let guest_connection_ids = self - .app_state - .db - .update_worktree(&request, session.connection_id) - .await?; - - broadcast( - session.connection_id, - guest_connection_ids, - |connection_id| { - self.peer - .forward_send(session.connection_id, connection_id, request.clone()) - }, - ); - response.send(proto::Ack {})?; - Ok(()) - } - - async fn update_diagnostic_summary( - self: Arc, - request: proto::UpdateDiagnosticSummary, - response: Response, - session: Session, - ) -> Result<()> { - let guest_connection_ids = self - .app_state - .db - .update_diagnostic_summary(&request, session.connection_id) - .await?; - - broadcast( - session.connection_id, - guest_connection_ids, - |connection_id| { - self.peer - .forward_send(session.connection_id, connection_id, request.clone()) - }, - ); - - response.send(proto::Ack {})?; - Ok(()) - } - - async fn start_language_server( - self: Arc, - request: proto::StartLanguageServer, - session: Session, - ) -> Result<()> { - let guest_connection_ids = self - .app_state - .db - .start_language_server(&request, session.connection_id) - .await?; - - broadcast( - session.connection_id, - guest_connection_ids, - |connection_id| { - self.peer - .forward_send(session.connection_id, connection_id, request.clone()) - }, - ); - Ok(()) - } - - async fn update_language_server( - self: Arc, - request: proto::UpdateLanguageServer, - session: Session, - ) -> Result<()> { - let project_id = ProjectId::from_proto(request.project_id); - let project_connection_ids = self - .app_state - .db - .project_connection_ids(project_id, session.connection_id) - .await?; - broadcast( - session.connection_id, - project_connection_ids, - |connection_id| { - self.peer - .forward_send(session.connection_id, connection_id, request.clone()) - }, - ); - Ok(()) - } - - async fn forward_project_request( - self: Arc, - request: T, - response: Response, - session: Session, - ) -> Result<()> - where - T: EntityMessage + RequestMessage, - { - let project_id = ProjectId::from_proto(request.remote_entity_id()); - let collaborators = self - .app_state - .db - .project_collaborators(project_id, session.connection_id) - .await?; - let host = collaborators - .iter() - .find(|collaborator| collaborator.is_host) - .ok_or_else(|| anyhow!("host not found"))?; - - let payload = self - .peer - .forward_request( - session.connection_id, - ConnectionId(host.connection_id as u32), - request, - ) - .await?; - - response.send(payload)?; - Ok(()) - } - - async fn save_buffer( - self: Arc, - request: proto::SaveBuffer, - response: Response, - session: Session, - ) -> Result<()> { - let project_id = ProjectId::from_proto(request.project_id); - let collaborators = self - .app_state - .db - .project_collaborators(project_id, session.connection_id) - .await?; - let host = collaborators - .into_iter() - .find(|collaborator| collaborator.is_host) - .ok_or_else(|| anyhow!("host not found"))?; - let host_connection_id = ConnectionId(host.connection_id as u32); - let response_payload = self - .peer - .forward_request(session.connection_id, host_connection_id, request.clone()) - .await?; - - let mut collaborators = self - .app_state - .db - .project_collaborators(project_id, session.connection_id) - .await?; - collaborators - .retain(|collaborator| collaborator.connection_id != session.connection_id.0 as i32); - let project_connection_ids = collaborators - .into_iter() - .map(|collaborator| ConnectionId(collaborator.connection_id as u32)); - broadcast(host_connection_id, project_connection_ids, |conn_id| { - self.peer - .forward_send(host_connection_id, conn_id, response_payload.clone()) - }); - response.send(response_payload)?; - Ok(()) - } - - async fn create_buffer_for_peer( - self: Arc, - request: proto::CreateBufferForPeer, - session: Session, - ) -> Result<()> { - self.peer.forward_send( - session.connection_id, - ConnectionId(request.peer_id), - request, - )?; - Ok(()) - } - - async fn update_buffer( - self: Arc, - request: proto::UpdateBuffer, - response: Response, - session: Session, - ) -> Result<()> { - let project_id = ProjectId::from_proto(request.project_id); - let project_connection_ids = self - .app_state - .db - .project_connection_ids(project_id, session.connection_id) - .await?; - - broadcast( - session.connection_id, - project_connection_ids, - |connection_id| { - self.peer - .forward_send(session.connection_id, connection_id, request.clone()) - }, - ); - response.send(proto::Ack {})?; - Ok(()) - } - - async fn update_buffer_file( - self: Arc, - request: proto::UpdateBufferFile, - session: Session, - ) -> Result<()> { - let project_id = ProjectId::from_proto(request.project_id); - let project_connection_ids = self - .app_state - .db - .project_connection_ids(project_id, session.connection_id) - .await?; - - broadcast( - session.connection_id, - project_connection_ids, - |connection_id| { - self.peer - .forward_send(session.connection_id, connection_id, request.clone()) - }, - ); - Ok(()) - } - - async fn buffer_reloaded( - self: Arc, - request: proto::BufferReloaded, - session: Session, - ) -> Result<()> { - let project_id = ProjectId::from_proto(request.project_id); - let project_connection_ids = self - .app_state - .db - .project_connection_ids(project_id, session.connection_id) - .await?; - broadcast( - session.connection_id, - project_connection_ids, - |connection_id| { - self.peer - .forward_send(session.connection_id, connection_id, request.clone()) - }, - ); - Ok(()) - } - - async fn buffer_saved( - self: Arc, - request: proto::BufferSaved, - session: Session, - ) -> Result<()> { - let project_id = ProjectId::from_proto(request.project_id); - let project_connection_ids = self - .app_state - .db - .project_connection_ids(project_id, session.connection_id) - .await?; - broadcast( - session.connection_id, - project_connection_ids, - |connection_id| { - self.peer - .forward_send(session.connection_id, connection_id, request.clone()) - }, - ); - Ok(()) - } - - async fn follow( - self: Arc, - request: proto::Follow, - response: Response, - session: Session, - ) -> Result<()> { - let project_id = ProjectId::from_proto(request.project_id); - let leader_id = ConnectionId(request.leader_id); - let follower_id = session.connection_id; - let project_connection_ids = self - .app_state - .db - .project_connection_ids(project_id, session.connection_id) - .await?; - - if !project_connection_ids.contains(&leader_id) { - Err(anyhow!("no such peer"))?; - } - - let mut response_payload = self - .peer - .forward_request(session.connection_id, leader_id, request) - .await?; - response_payload - .views - .retain(|view| view.leader_id != Some(follower_id.0)); - response.send(response_payload)?; - Ok(()) - } - - async fn unfollow(self: Arc, request: proto::Unfollow, session: Session) -> Result<()> { - let project_id = ProjectId::from_proto(request.project_id); - let leader_id = ConnectionId(request.leader_id); - let project_connection_ids = self - .app_state - .db - .project_connection_ids(project_id, session.connection_id) - .await?; - if !project_connection_ids.contains(&leader_id) { - Err(anyhow!("no such peer"))?; - } - self.peer - .forward_send(session.connection_id, leader_id, request)?; - Ok(()) - } - - async fn update_followers( - self: Arc, - request: proto::UpdateFollowers, - session: Session, - ) -> Result<()> { - let project_id = ProjectId::from_proto(request.project_id); - let project_connection_ids = session - .db - .lock() - .await - .project_connection_ids(project_id, session.connection_id) - .await?; - - let leader_id = request.variant.as_ref().and_then(|variant| match variant { - proto::update_followers::Variant::CreateView(payload) => payload.leader_id, - proto::update_followers::Variant::UpdateView(payload) => payload.leader_id, - proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id, - }); - for follower_id in &request.follower_ids { - let follower_id = ConnectionId(*follower_id); - if project_connection_ids.contains(&follower_id) && Some(follower_id.0) != leader_id { - self.peer - .forward_send(session.connection_id, follower_id, request.clone())?; - } - } - Ok(()) - } - - async fn get_users( - self: Arc, - request: proto::GetUsers, - response: Response, - _session: Session, - ) -> Result<()> { - let user_ids = request - .user_ids - .into_iter() - .map(UserId::from_proto) - .collect(); - let users = self - .app_state - .db - .get_users_by_ids(user_ids) - .await? - .into_iter() - .map(|user| proto::User { - id: user.id.to_proto(), - avatar_url: format!("https://github.com/{}.png?size=128", user.github_login), - github_login: user.github_login, - }) - .collect(); - response.send(proto::UsersResponse { users })?; - Ok(()) - } - - async fn fuzzy_search_users( - self: Arc, - request: proto::FuzzySearchUsers, - response: Response, - session: Session, - ) -> Result<()> { - let query = request.query; - let db = &self.app_state.db; - let users = match query.len() { - 0 => vec![], - 1 | 2 => db - .get_user_by_github_account(&query, None) - .await? - .into_iter() - .collect(), - _ => db.fuzzy_search_users(&query, 10).await?, - }; - let users = users - .into_iter() - .filter(|user| user.id != session.user_id) - .map(|user| proto::User { - id: user.id.to_proto(), - avatar_url: format!("https://github.com/{}.png?size=128", user.github_login), - github_login: user.github_login, - }) - .collect(); - response.send(proto::UsersResponse { users })?; - Ok(()) - } - - async fn request_contact( - self: Arc, - request: proto::RequestContact, - response: Response, - session: Session, - ) -> Result<()> { - let requester_id = session.user_id; - let responder_id = UserId::from_proto(request.responder_id); - if requester_id == responder_id { - return Err(anyhow!("cannot add yourself as a contact"))?; - } - - self.app_state - .db - .send_contact_request(requester_id, responder_id) - .await?; - - // Update outgoing contact requests of requester - let mut update = proto::UpdateContacts::default(); - update.outgoing_requests.push(responder_id.to_proto()); - for connection_id in self - .connection_pool() - .await - .user_connection_ids(requester_id) - { - self.peer.send(connection_id, update.clone())?; - } - - // Update incoming contact requests of responder - let mut update = proto::UpdateContacts::default(); - update - .incoming_requests - .push(proto::IncomingContactRequest { - requester_id: requester_id.to_proto(), - should_notify: true, - }); - for connection_id in self - .connection_pool() - .await - .user_connection_ids(responder_id) - { - self.peer.send(connection_id, update.clone())?; - } - - response.send(proto::Ack {})?; - Ok(()) - } - - async fn respond_to_contact_request( - self: Arc, - request: proto::RespondToContactRequest, - response: Response, - session: Session, - ) -> Result<()> { - let responder_id = session.user_id; - let requester_id = UserId::from_proto(request.requester_id); - if request.response == proto::ContactRequestResponse::Dismiss as i32 { - self.app_state - .db - .dismiss_contact_notification(responder_id, requester_id) - .await?; - } else { - let accept = request.response == proto::ContactRequestResponse::Accept as i32; - self.app_state - .db - .respond_to_contact_request(responder_id, requester_id, accept) - .await?; - let busy = self.app_state.db.is_user_busy(requester_id).await?; - - let pool = self.connection_pool().await; - // Update responder with new contact - let mut update = proto::UpdateContacts::default(); - if accept { - update - .contacts - .push(contact_for_user(requester_id, false, busy, &pool)); - } - update - .remove_incoming_requests - .push(requester_id.to_proto()); - for connection_id in pool.user_connection_ids(responder_id) { - self.peer.send(connection_id, update.clone())?; - } - - // Update requester with new contact - let mut update = proto::UpdateContacts::default(); - if accept { - update - .contacts - .push(contact_for_user(responder_id, true, busy, &pool)); - } - update - .remove_outgoing_requests - .push(responder_id.to_proto()); - for connection_id in pool.user_connection_ids(requester_id) { - self.peer.send(connection_id, update.clone())?; - } - } - - response.send(proto::Ack {})?; - Ok(()) - } - - async fn remove_contact( - self: Arc, - request: proto::RemoveContact, - response: Response, - session: Session, - ) -> Result<()> { - let requester_id = session.user_id; - let responder_id = UserId::from_proto(request.user_id); - self.app_state - .db - .remove_contact(requester_id, responder_id) - .await?; - - // Update outgoing contact requests of requester - let mut update = proto::UpdateContacts::default(); - update - .remove_outgoing_requests - .push(responder_id.to_proto()); - for connection_id in self - .connection_pool() - .await - .user_connection_ids(requester_id) - { - self.peer.send(connection_id, update.clone())?; - } - - // Update incoming contact requests of responder - let mut update = proto::UpdateContacts::default(); - update - .remove_incoming_requests - .push(requester_id.to_proto()); - for connection_id in self - .connection_pool() - .await - .user_connection_ids(responder_id) - { - self.peer.send(connection_id, update.clone())?; - } - - response.send(proto::Ack {})?; - Ok(()) - } - - async fn update_diff_base( - self: Arc, - request: proto::UpdateDiffBase, - session: Session, - ) -> Result<()> { - let project_id = ProjectId::from_proto(request.project_id); - let project_connection_ids = self - .app_state - .db - .project_connection_ids(project_id, session.connection_id) - .await?; - broadcast( - session.connection_id, - project_connection_ids, - |connection_id| { - self.peer - .forward_send(session.connection_id, connection_id, request.clone()) - }, - ); - Ok(()) - } - - async fn get_private_user_info( - self: Arc, - _request: proto::GetPrivateUserInfo, - response: Response, - session: Session, - ) -> Result<()> { - let metrics_id = self - .app_state - .db - .get_user_metrics_id(session.user_id) - .await?; - let user = self - .app_state - .db - .get_user_by_id(session.user_id) - .await? - .ok_or_else(|| anyhow!("user not found"))?; - response.send(proto::GetPrivateUserInfoResponse { - metrics_id, - staff: user.admin, - })?; - Ok(()) - } - - pub(crate) async fn connection_pool(&self) -> ConnectionPoolGuard<'_> { - #[cfg(test)] - tokio::task::yield_now().await; - let guard = self.connection_pool.lock().await; - #[cfg(test)] - tokio::task::yield_now().await; - ConnectionPoolGuard { - guard, - _not_send: PhantomData, - } - } - pub async fn snapshot<'a>(self: &'a Arc) -> ServerSnapshot<'a> { ServerSnapshot { - connection_pool: self.connection_pool().await, + connection_pool: ConnectionPoolGuard { + guard: self.connection_pool.lock().await, + _not_send: PhantomData, + }, peer: &self.peer, } } @@ -1847,7 +644,8 @@ pub async fn handle_websocket_request( pub async fn handle_metrics(Extension(server): Extension>) -> Result { let connections = server - .connection_pool() + .connection_pool + .lock() .await .connections() .filter(|connection| !connection.admin) @@ -1866,6 +664,1042 @@ pub async fn handle_metrics(Extension(server): Extension>) -> Result Ok(encoded_metrics) } +#[instrument(err)] +async fn sign_out(session: Session) -> Result<()> { + session.peer.disconnect(session.connection_id); + let decline_calls = { + let mut pool = session.connection_pool().await; + pool.remove_connection(session.connection_id)?; + let mut connections = pool.user_connection_ids(session.user_id); + connections.next().is_none() + }; + + leave_room_for_session(&session).await.trace_err(); + if decline_calls { + if let Some(room) = session + .db() + .await + .decline_call(None, session.user_id) + .await + .trace_err() + { + room_updated(&room, &session); + } + } + + update_user_contacts(session.user_id, &session).await?; + + Ok(()) +} + +async fn ping(_: proto::Ping, response: Response, _session: Session) -> Result<()> { + response.send(proto::Ack {})?; + Ok(()) +} + +async fn create_room( + _request: proto::CreateRoom, + response: Response, + session: Session, +) -> Result<()> { + let room = session + .db() + .await + .create_room(session.user_id, session.connection_id) + .await?; + + let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() { + if let Some(_) = live_kit + .create_room(room.live_kit_room.clone()) + .await + .trace_err() + { + if let Some(token) = live_kit + .room_token(&room.live_kit_room, &session.connection_id.to_string()) + .trace_err() + { + Some(proto::LiveKitConnectionInfo { + server_url: live_kit.url().into(), + token, + }) + } else { + None + } + } else { + None + } + } else { + None + }; + + response.send(proto::CreateRoomResponse { + room: Some(room), + live_kit_connection_info, + })?; + update_user_contacts(session.user_id, &session).await?; + Ok(()) +} + +async fn join_room( + request: proto::JoinRoom, + response: Response, + session: Session, +) -> Result<()> { + let room = session + .db() + .await + .join_room( + RoomId::from_proto(request.id), + session.user_id, + session.connection_id, + ) + .await?; + for connection_id in session + .connection_pool() + .await + .user_connection_ids(session.user_id) + { + session + .peer + .send(connection_id, proto::CallCanceled {}) + .trace_err(); + } + + let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() { + if let Some(token) = live_kit + .room_token(&room.live_kit_room, &session.connection_id.to_string()) + .trace_err() + { + Some(proto::LiveKitConnectionInfo { + server_url: live_kit.url().into(), + token, + }) + } else { + None + } + } else { + None + }; + + room_updated(&room, &session); + response.send(proto::JoinRoomResponse { + room: Some(room), + live_kit_connection_info, + })?; + + update_user_contacts(session.user_id, &session).await?; + Ok(()) +} + +async fn leave_room(_message: proto::LeaveRoom, session: Session) -> Result<()> { + leave_room_for_session(&session).await +} + +async fn call( + request: proto::Call, + response: Response, + session: Session, +) -> Result<()> { + let room_id = RoomId::from_proto(request.room_id); + let calling_user_id = session.user_id; + let calling_connection_id = session.connection_id; + let called_user_id = UserId::from_proto(request.called_user_id); + let initial_project_id = request.initial_project_id.map(ProjectId::from_proto); + if !session + .db() + .await + .has_contact(calling_user_id, called_user_id) + .await? + { + return Err(anyhow!("cannot call a user who isn't a contact"))?; + } + + let (room, incoming_call) = session + .db() + .await + .call( + room_id, + calling_user_id, + calling_connection_id, + called_user_id, + initial_project_id, + ) + .await?; + room_updated(&room, &session); + update_user_contacts(called_user_id, &session).await?; + + let mut calls = session + .connection_pool() + .await + .user_connection_ids(called_user_id) + .map(|connection_id| session.peer.request(connection_id, incoming_call.clone())) + .collect::>(); + + while let Some(call_response) = calls.next().await { + match call_response.as_ref() { + Ok(_) => { + response.send(proto::Ack {})?; + return Ok(()); + } + Err(_) => { + call_response.trace_err(); + } + } + } + + let room = session + .db() + .await + .call_failed(room_id, called_user_id) + .await?; + room_updated(&room, &session); + update_user_contacts(called_user_id, &session).await?; + + Err(anyhow!("failed to ring user"))? +} + +async fn cancel_call( + request: proto::CancelCall, + response: Response, + session: Session, +) -> Result<()> { + let called_user_id = UserId::from_proto(request.called_user_id); + let room_id = RoomId::from_proto(request.room_id); + let room = session + .db() + .await + .cancel_call(Some(room_id), session.connection_id, called_user_id) + .await?; + for connection_id in session + .connection_pool() + .await + .user_connection_ids(called_user_id) + { + session + .peer + .send(connection_id, proto::CallCanceled {}) + .trace_err(); + } + room_updated(&room, &session); + response.send(proto::Ack {})?; + + update_user_contacts(called_user_id, &session).await?; + Ok(()) +} + +async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> { + let room_id = RoomId::from_proto(message.room_id); + let room = session + .db() + .await + .decline_call(Some(room_id), session.user_id) + .await?; + for connection_id in session + .connection_pool() + .await + .user_connection_ids(session.user_id) + { + session + .peer + .send(connection_id, proto::CallCanceled {}) + .trace_err(); + } + room_updated(&room, &session); + update_user_contacts(session.user_id, &session).await?; + Ok(()) +} + +async fn update_participant_location( + request: proto::UpdateParticipantLocation, + response: Response, + session: Session, +) -> Result<()> { + let room_id = RoomId::from_proto(request.room_id); + let location = request + .location + .ok_or_else(|| anyhow!("invalid location"))?; + let room = session + .db() + .await + .update_room_participant_location(room_id, session.connection_id, location) + .await?; + room_updated(&room, &session); + response.send(proto::Ack {})?; + Ok(()) +} + +async fn share_project( + request: proto::ShareProject, + response: Response, + session: Session, +) -> Result<()> { + let (project_id, room) = session + .db() + .await + .share_project( + RoomId::from_proto(request.room_id), + session.connection_id, + &request.worktrees, + ) + .await?; + response.send(proto::ShareProjectResponse { + project_id: project_id.to_proto(), + })?; + room_updated(&room, &session); + + Ok(()) +} + +async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> { + let project_id = ProjectId::from_proto(message.project_id); + + let (room, guest_connection_ids) = session + .db() + .await + .unshare_project(project_id, session.connection_id) + .await?; + + broadcast(session.connection_id, guest_connection_ids, |conn_id| { + session.peer.send(conn_id, message.clone()) + }); + room_updated(&room, &session); + + Ok(()) +} + +async fn join_project( + request: proto::JoinProject, + response: Response, + session: Session, +) -> Result<()> { + let project_id = ProjectId::from_proto(request.project_id); + let guest_user_id = session.user_id; + + tracing::info!(%project_id, "join project"); + + let (project, replica_id) = session + .db() + .await + .join_project(project_id, session.connection_id) + .await?; + + let collaborators = project + .collaborators + .iter() + .filter(|collaborator| collaborator.connection_id != session.connection_id.0 as i32) + .map(|collaborator| proto::Collaborator { + peer_id: collaborator.connection_id as u32, + replica_id: collaborator.replica_id.0 as u32, + user_id: collaborator.user_id.to_proto(), + }) + .collect::>(); + let worktrees = project + .worktrees + .iter() + .map(|(id, worktree)| proto::WorktreeMetadata { + id: id.to_proto(), + root_name: worktree.root_name.clone(), + visible: worktree.visible, + abs_path: worktree.abs_path.clone(), + }) + .collect::>(); + + for collaborator in &collaborators { + session + .peer + .send( + ConnectionId(collaborator.peer_id), + proto::AddProjectCollaborator { + project_id: project_id.to_proto(), + collaborator: Some(proto::Collaborator { + peer_id: session.connection_id.0, + replica_id: replica_id.0 as u32, + user_id: guest_user_id.to_proto(), + }), + }, + ) + .trace_err(); + } + + // First, we send the metadata associated with each worktree. + response.send(proto::JoinProjectResponse { + worktrees: worktrees.clone(), + replica_id: replica_id.0 as u32, + collaborators: collaborators.clone(), + language_servers: project.language_servers.clone(), + })?; + + for (worktree_id, worktree) in project.worktrees { + #[cfg(any(test, feature = "test-support"))] + const MAX_CHUNK_SIZE: usize = 2; + #[cfg(not(any(test, feature = "test-support")))] + const MAX_CHUNK_SIZE: usize = 256; + + // Stream this worktree's entries. + let message = proto::UpdateWorktree { + project_id: project_id.to_proto(), + worktree_id: worktree_id.to_proto(), + abs_path: worktree.abs_path.clone(), + root_name: worktree.root_name, + updated_entries: worktree.entries, + removed_entries: Default::default(), + scan_id: worktree.scan_id, + is_last_update: worktree.is_complete, + }; + for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) { + session.peer.send(session.connection_id, update.clone())?; + } + + // Stream this worktree's diagnostics. + for summary in worktree.diagnostic_summaries { + session.peer.send( + session.connection_id, + proto::UpdateDiagnosticSummary { + project_id: project_id.to_proto(), + worktree_id: worktree.id.to_proto(), + summary: Some(summary), + }, + )?; + } + } + + for language_server in &project.language_servers { + session.peer.send( + session.connection_id, + proto::UpdateLanguageServer { + project_id: project_id.to_proto(), + language_server_id: language_server.id, + variant: Some( + proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated( + proto::LspDiskBasedDiagnosticsUpdated {}, + ), + ), + }, + )?; + } + + Ok(()) +} + +async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> { + let sender_id = session.connection_id; + let project_id = ProjectId::from_proto(request.project_id); + let project; + { + project = session + .db() + .await + .leave_project(project_id, sender_id) + .await?; + tracing::info!( + %project_id, + host_user_id = %project.host_user_id, + host_connection_id = %project.host_connection_id, + "leave project" + ); + + broadcast(sender_id, project.connection_ids, |conn_id| { + session.peer.send( + conn_id, + proto::RemoveProjectCollaborator { + project_id: project_id.to_proto(), + peer_id: sender_id.0, + }, + ) + }); + } + + Ok(()) +} + +async fn update_project( + request: proto::UpdateProject, + response: Response, + session: Session, +) -> Result<()> { + let project_id = ProjectId::from_proto(request.project_id); + let (room, guest_connection_ids) = session + .db() + .await + .update_project(project_id, session.connection_id, &request.worktrees) + .await?; + broadcast( + session.connection_id, + guest_connection_ids, + |connection_id| { + session + .peer + .forward_send(session.connection_id, connection_id, request.clone()) + }, + ); + room_updated(&room, &session); + response.send(proto::Ack {})?; + + Ok(()) +} + +async fn update_worktree( + request: proto::UpdateWorktree, + response: Response, + session: Session, +) -> Result<()> { + let guest_connection_ids = session + .db() + .await + .update_worktree(&request, session.connection_id) + .await?; + + broadcast( + session.connection_id, + guest_connection_ids, + |connection_id| { + session + .peer + .forward_send(session.connection_id, connection_id, request.clone()) + }, + ); + response.send(proto::Ack {})?; + Ok(()) +} + +async fn update_diagnostic_summary( + request: proto::UpdateDiagnosticSummary, + response: Response, + session: Session, +) -> Result<()> { + let guest_connection_ids = session + .db() + .await + .update_diagnostic_summary(&request, session.connection_id) + .await?; + + broadcast( + session.connection_id, + guest_connection_ids, + |connection_id| { + session + .peer + .forward_send(session.connection_id, connection_id, request.clone()) + }, + ); + + response.send(proto::Ack {})?; + Ok(()) +} + +async fn start_language_server( + request: proto::StartLanguageServer, + session: Session, +) -> Result<()> { + let guest_connection_ids = session + .db() + .await + .start_language_server(&request, session.connection_id) + .await?; + + broadcast( + session.connection_id, + guest_connection_ids, + |connection_id| { + session + .peer + .forward_send(session.connection_id, connection_id, request.clone()) + }, + ); + Ok(()) +} + +async fn update_language_server( + request: proto::UpdateLanguageServer, + session: Session, +) -> Result<()> { + let project_id = ProjectId::from_proto(request.project_id); + let project_connection_ids = session + .db() + .await + .project_connection_ids(project_id, session.connection_id) + .await?; + broadcast( + session.connection_id, + project_connection_ids, + |connection_id| { + session + .peer + .forward_send(session.connection_id, connection_id, request.clone()) + }, + ); + Ok(()) +} + +async fn forward_project_request( + request: T, + response: Response, + session: Session, +) -> Result<()> +where + T: EntityMessage + RequestMessage, +{ + let project_id = ProjectId::from_proto(request.remote_entity_id()); + let collaborators = session + .db() + .await + .project_collaborators(project_id, session.connection_id) + .await?; + let host = collaborators + .iter() + .find(|collaborator| collaborator.is_host) + .ok_or_else(|| anyhow!("host not found"))?; + + let payload = session + .peer + .forward_request( + session.connection_id, + ConnectionId(host.connection_id as u32), + request, + ) + .await?; + + response.send(payload)?; + Ok(()) +} + +async fn save_buffer( + request: proto::SaveBuffer, + response: Response, + session: Session, +) -> Result<()> { + let project_id = ProjectId::from_proto(request.project_id); + let collaborators = session + .db() + .await + .project_collaborators(project_id, session.connection_id) + .await?; + let host = collaborators + .into_iter() + .find(|collaborator| collaborator.is_host) + .ok_or_else(|| anyhow!("host not found"))?; + let host_connection_id = ConnectionId(host.connection_id as u32); + let response_payload = session + .peer + .forward_request(session.connection_id, host_connection_id, request.clone()) + .await?; + + let mut collaborators = session + .db() + .await + .project_collaborators(project_id, session.connection_id) + .await?; + collaborators + .retain(|collaborator| collaborator.connection_id != session.connection_id.0 as i32); + let project_connection_ids = collaborators + .into_iter() + .map(|collaborator| ConnectionId(collaborator.connection_id as u32)); + broadcast(host_connection_id, project_connection_ids, |conn_id| { + session + .peer + .forward_send(host_connection_id, conn_id, response_payload.clone()) + }); + response.send(response_payload)?; + Ok(()) +} + +async fn create_buffer_for_peer( + request: proto::CreateBufferForPeer, + session: Session, +) -> Result<()> { + session.peer.forward_send( + session.connection_id, + ConnectionId(request.peer_id), + request, + )?; + Ok(()) +} + +async fn update_buffer( + request: proto::UpdateBuffer, + response: Response, + session: Session, +) -> Result<()> { + let project_id = ProjectId::from_proto(request.project_id); + let project_connection_ids = session + .db() + .await + .project_connection_ids(project_id, session.connection_id) + .await?; + + broadcast( + session.connection_id, + project_connection_ids, + |connection_id| { + session + .peer + .forward_send(session.connection_id, connection_id, request.clone()) + }, + ); + response.send(proto::Ack {})?; + Ok(()) +} + +async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> { + let project_id = ProjectId::from_proto(request.project_id); + let project_connection_ids = session + .db() + .await + .project_connection_ids(project_id, session.connection_id) + .await?; + + broadcast( + session.connection_id, + project_connection_ids, + |connection_id| { + session + .peer + .forward_send(session.connection_id, connection_id, request.clone()) + }, + ); + Ok(()) +} + +async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> { + let project_id = ProjectId::from_proto(request.project_id); + let project_connection_ids = session + .db() + .await + .project_connection_ids(project_id, session.connection_id) + .await?; + broadcast( + session.connection_id, + project_connection_ids, + |connection_id| { + session + .peer + .forward_send(session.connection_id, connection_id, request.clone()) + }, + ); + Ok(()) +} + +async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> { + let project_id = ProjectId::from_proto(request.project_id); + let project_connection_ids = session + .db() + .await + .project_connection_ids(project_id, session.connection_id) + .await?; + broadcast( + session.connection_id, + project_connection_ids, + |connection_id| { + session + .peer + .forward_send(session.connection_id, connection_id, request.clone()) + }, + ); + Ok(()) +} + +async fn follow( + request: proto::Follow, + response: Response, + session: Session, +) -> Result<()> { + let project_id = ProjectId::from_proto(request.project_id); + let leader_id = ConnectionId(request.leader_id); + let follower_id = session.connection_id; + let project_connection_ids = session + .db() + .await + .project_connection_ids(project_id, session.connection_id) + .await?; + + if !project_connection_ids.contains(&leader_id) { + Err(anyhow!("no such peer"))?; + } + + let mut response_payload = session + .peer + .forward_request(session.connection_id, leader_id, request) + .await?; + response_payload + .views + .retain(|view| view.leader_id != Some(follower_id.0)); + response.send(response_payload)?; + Ok(()) +} + +async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> { + let project_id = ProjectId::from_proto(request.project_id); + let leader_id = ConnectionId(request.leader_id); + let project_connection_ids = session + .db() + .await + .project_connection_ids(project_id, session.connection_id) + .await?; + if !project_connection_ids.contains(&leader_id) { + Err(anyhow!("no such peer"))?; + } + session + .peer + .forward_send(session.connection_id, leader_id, request)?; + Ok(()) +} + +async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> { + let project_id = ProjectId::from_proto(request.project_id); + let project_connection_ids = session + .db + .lock() + .await + .project_connection_ids(project_id, session.connection_id) + .await?; + + let leader_id = request.variant.as_ref().and_then(|variant| match variant { + proto::update_followers::Variant::CreateView(payload) => payload.leader_id, + proto::update_followers::Variant::UpdateView(payload) => payload.leader_id, + proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id, + }); + for follower_id in &request.follower_ids { + let follower_id = ConnectionId(*follower_id); + if project_connection_ids.contains(&follower_id) && Some(follower_id.0) != leader_id { + session + .peer + .forward_send(session.connection_id, follower_id, request.clone())?; + } + } + Ok(()) +} + +async fn get_users( + request: proto::GetUsers, + response: Response, + session: Session, +) -> Result<()> { + let user_ids = request + .user_ids + .into_iter() + .map(UserId::from_proto) + .collect(); + let users = session + .db() + .await + .get_users_by_ids(user_ids) + .await? + .into_iter() + .map(|user| proto::User { + id: user.id.to_proto(), + avatar_url: format!("https://github.com/{}.png?size=128", user.github_login), + github_login: user.github_login, + }) + .collect(); + response.send(proto::UsersResponse { users })?; + Ok(()) +} + +async fn fuzzy_search_users( + request: proto::FuzzySearchUsers, + response: Response, + session: Session, +) -> Result<()> { + let query = request.query; + let users = match query.len() { + 0 => vec![], + 1 | 2 => session + .db() + .await + .get_user_by_github_account(&query, None) + .await? + .into_iter() + .collect(), + _ => session.db().await.fuzzy_search_users(&query, 10).await?, + }; + let users = users + .into_iter() + .filter(|user| user.id != session.user_id) + .map(|user| proto::User { + id: user.id.to_proto(), + avatar_url: format!("https://github.com/{}.png?size=128", user.github_login), + github_login: user.github_login, + }) + .collect(); + response.send(proto::UsersResponse { users })?; + Ok(()) +} + +async fn request_contact( + request: proto::RequestContact, + response: Response, + session: Session, +) -> Result<()> { + let requester_id = session.user_id; + let responder_id = UserId::from_proto(request.responder_id); + if requester_id == responder_id { + return Err(anyhow!("cannot add yourself as a contact"))?; + } + + session + .db() + .await + .send_contact_request(requester_id, responder_id) + .await?; + + // Update outgoing contact requests of requester + let mut update = proto::UpdateContacts::default(); + update.outgoing_requests.push(responder_id.to_proto()); + for connection_id in session + .connection_pool() + .await + .user_connection_ids(requester_id) + { + session.peer.send(connection_id, update.clone())?; + } + + // Update incoming contact requests of responder + let mut update = proto::UpdateContacts::default(); + update + .incoming_requests + .push(proto::IncomingContactRequest { + requester_id: requester_id.to_proto(), + should_notify: true, + }); + for connection_id in session + .connection_pool() + .await + .user_connection_ids(responder_id) + { + session.peer.send(connection_id, update.clone())?; + } + + response.send(proto::Ack {})?; + Ok(()) +} + +async fn respond_to_contact_request( + request: proto::RespondToContactRequest, + response: Response, + session: Session, +) -> Result<()> { + let responder_id = session.user_id; + let requester_id = UserId::from_proto(request.requester_id); + let db = session.db().await; + if request.response == proto::ContactRequestResponse::Dismiss as i32 { + db.dismiss_contact_notification(responder_id, requester_id) + .await?; + } else { + let accept = request.response == proto::ContactRequestResponse::Accept as i32; + + db.respond_to_contact_request(responder_id, requester_id, accept) + .await?; + let busy = db.is_user_busy(requester_id).await?; + + let pool = session.connection_pool().await; + // Update responder with new contact + let mut update = proto::UpdateContacts::default(); + if accept { + update + .contacts + .push(contact_for_user(requester_id, false, busy, &pool)); + } + update + .remove_incoming_requests + .push(requester_id.to_proto()); + for connection_id in pool.user_connection_ids(responder_id) { + session.peer.send(connection_id, update.clone())?; + } + + // Update requester with new contact + let mut update = proto::UpdateContacts::default(); + if accept { + update + .contacts + .push(contact_for_user(responder_id, true, busy, &pool)); + } + update + .remove_outgoing_requests + .push(responder_id.to_proto()); + for connection_id in pool.user_connection_ids(requester_id) { + session.peer.send(connection_id, update.clone())?; + } + } + + response.send(proto::Ack {})?; + Ok(()) +} + +async fn remove_contact( + request: proto::RemoveContact, + response: Response, + session: Session, +) -> Result<()> { + let requester_id = session.user_id; + let responder_id = UserId::from_proto(request.user_id); + let db = session.db().await; + db.remove_contact(requester_id, responder_id).await?; + + let pool = session.connection_pool().await; + // Update outgoing contact requests of requester + let mut update = proto::UpdateContacts::default(); + update + .remove_outgoing_requests + .push(responder_id.to_proto()); + for connection_id in pool.user_connection_ids(requester_id) { + session.peer.send(connection_id, update.clone())?; + } + + // Update incoming contact requests of responder + let mut update = proto::UpdateContacts::default(); + update + .remove_incoming_requests + .push(requester_id.to_proto()); + for connection_id in pool.user_connection_ids(responder_id) { + session.peer.send(connection_id, update.clone())?; + } + + response.send(proto::Ack {})?; + Ok(()) +} + +async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> { + let project_id = ProjectId::from_proto(request.project_id); + let project_connection_ids = session + .db() + .await + .project_connection_ids(project_id, session.connection_id) + .await?; + broadcast( + session.connection_id, + project_connection_ids, + |connection_id| { + session + .peer + .forward_send(session.connection_id, connection_id, request.clone()) + }, + ); + Ok(()) +} + +async fn get_private_user_info( + _request: proto::GetPrivateUserInfo, + response: Response, + session: Session, +) -> Result<()> { + let metrics_id = session + .db() + .await + .get_user_metrics_id(session.user_id) + .await?; + let user = session + .db() + .await + .get_user_by_id(session.user_id) + .await? + .ok_or_else(|| anyhow!("user not found"))?; + response.send(proto::GetPrivateUserInfoResponse { + metrics_id, + staff: user.admin, + })?; + Ok(()) +} + fn to_axum_message(message: TungsteniteMessage) -> AxumMessage { match message { TungsteniteMessage::Text(payload) => AxumMessage::Text(payload), @@ -1941,6 +1775,137 @@ fn contact_for_user( } } +fn room_updated(room: &proto::Room, session: &Session) { + for participant in &room.participants { + session + .peer + .send( + ConnectionId(participant.peer_id), + proto::RoomUpdated { + room: Some(room.clone()), + }, + ) + .trace_err(); + } +} + +async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> { + let db = session.db().await; + let contacts = db.get_contacts(user_id).await?; + let busy = db.is_user_busy(user_id).await?; + + let pool = session.connection_pool().await; + let updated_contact = contact_for_user(user_id, false, busy, &pool); + for contact in contacts { + if let db::Contact::Accepted { + user_id: contact_user_id, + .. + } = contact + { + for contact_conn_id in pool.user_connection_ids(contact_user_id) { + session + .peer + .send( + contact_conn_id, + proto::UpdateContacts { + contacts: vec![updated_contact.clone()], + remove_contacts: Default::default(), + incoming_requests: Default::default(), + remove_incoming_requests: Default::default(), + outgoing_requests: Default::default(), + remove_outgoing_requests: Default::default(), + }, + ) + .trace_err(); + } + } + } + Ok(()) +} + +async fn leave_room_for_session(session: &Session) -> Result<()> { + let mut contacts_to_update = HashSet::default(); + + let Some(left_room) = session.db().await.leave_room(session.connection_id).await? else { + return Err(anyhow!("no room to leave"))?; + }; + contacts_to_update.insert(session.user_id); + + for project in left_room.left_projects.into_values() { + for connection_id in project.connection_ids { + if project.host_user_id == session.user_id { + session + .peer + .send( + connection_id, + proto::UnshareProject { + project_id: project.id.to_proto(), + }, + ) + .trace_err(); + } else { + session + .peer + .send( + connection_id, + proto::RemoveProjectCollaborator { + project_id: project.id.to_proto(), + peer_id: session.connection_id.0, + }, + ) + .trace_err(); + } + } + + session + .peer + .send( + session.connection_id, + proto::UnshareProject { + project_id: project.id.to_proto(), + }, + ) + .trace_err(); + } + + room_updated(&left_room.room, &session); + { + let pool = session.connection_pool().await; + for canceled_user_id in left_room.canceled_calls_to_user_ids { + for connection_id in pool.user_connection_ids(canceled_user_id) { + session + .peer + .send(connection_id, proto::CallCanceled {}) + .trace_err(); + } + contacts_to_update.insert(canceled_user_id); + } + } + + for contact_user_id in contacts_to_update { + update_user_contacts(contact_user_id, &session).await?; + } + + if let Some(live_kit) = session.live_kit_client.as_ref() { + live_kit + .remove_participant( + left_room.room.live_kit_room.clone(), + session.connection_id.to_string(), + ) + .await + .trace_err(); + + if left_room.room.participants.is_empty() { + live_kit + .delete_room(left_room.room.live_kit_room) + .await + .trace_err(); + } + } + + Ok(()) +} + pub trait ResultExt { type Ok;