From 28aa1567ce8d814a9a3ffbcd1b566a1b343907d4 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Mon, 7 Nov 2022 15:40:02 +0100 Subject: [PATCH] Include `sender_user_id` when handling a server message/request --- crates/collab/src/rpc.rs | 465 +++++++++++++++++++++++---------------- 1 file changed, 276 insertions(+), 189 deletions(-) diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 7bc2b43b9b..757c765838 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -68,8 +68,15 @@ lazy_static! { .unwrap(); } -type MessageHandler = - Box, Box) -> BoxFuture<'static, ()>>; +type MessageHandler = Box< + dyn Send + Sync + Fn(Arc, UserId, Box) -> BoxFuture<'static, ()>, +>; + +struct Message { + sender_user_id: UserId, + sender_connection_id: ConnectionId, + payload: T, +} struct Response { server: Arc, @@ -193,15 +200,15 @@ impl Server { Arc::new(server) } - fn add_message_handler(&mut self, handler: F) -> &mut Self + fn add_handler(&mut self, handler: F) -> &mut Self where - F: 'static + Send + Sync + Fn(Arc, TypedEnvelope) -> Fut, + F: 'static + Send + Sync + Fn(Arc, UserId, TypedEnvelope) -> Fut, Fut: 'static + Send + Future>, M: EnvelopedMessage, { let prev_handler = self.handlers.insert( TypeId::of::(), - Box::new(move |server, envelope| { + Box::new(move |server, sender_user_id, envelope| { let envelope = envelope.into_any().downcast::>().unwrap(); let span = info_span!( "handle message", @@ -213,7 +220,7 @@ impl Server { "message received" ); }); - let future = (handler)(server, *envelope); + let future = (handler)(server, sender_user_id, *envelope); async move { if let Err(error) = future.await { tracing::error!(%error, "error handling message"); @@ -229,26 +236,50 @@ impl Server { self } + fn add_message_handler(&mut self, handler: F) -> &mut Self + where + F: 'static + Send + Sync + Fn(Arc, Message) -> Fut, + Fut: 'static + Send + Future>, + M: EnvelopedMessage, + { + self.add_handler(move |server, sender_user_id, envelope| { + handler( + server, + Message { + sender_user_id, + sender_connection_id: envelope.sender_id, + payload: envelope.payload, + }, + ) + }); + self + } + /// Handle a request while holding a lock to the store. This is useful when we're registering /// a connection but we want to respond on the connection before anybody else can send on it. fn add_request_handler(&mut self, handler: F) -> &mut Self where - F: 'static + Send + Sync + Fn(Arc, TypedEnvelope, Response) -> Fut, + F: 'static + Send + Sync + Fn(Arc, Message, Response) -> Fut, Fut: Send + Future>, M: RequestMessage, { let handler = Arc::new(handler); - self.add_message_handler(move |server, envelope| { + self.add_handler(move |server, sender_user_id, envelope| { let receipt = envelope.receipt(); let handler = handler.clone(); async move { + let request = Message { + sender_user_id, + sender_connection_id: envelope.sender_id, + payload: envelope.payload, + }; let responded = Arc::new(AtomicBool::default()); let response = Response { server: server.clone(), responded: responded.clone(), - receipt: envelope.receipt(), + receipt, }; - match (handler)(server.clone(), envelope, response).await { + match (handler)(server.clone(), request, response).await { Ok(()) => { if responded.load(std::sync::atomic::Ordering::SeqCst) { Ok(()) @@ -361,7 +392,7 @@ impl Server { let span_enter = span.enter(); if let Some(handler) = this.handlers.get(&message.payload_type_id()) { let is_background = message.is_background(); - let handle_message = (handler)(this.clone(), message); + let handle_message = (handler)(this.clone(), user_id, message); drop(span_enter); let handle_message = handle_message.instrument(span); @@ -516,7 +547,7 @@ impl Server { async fn ping( self: Arc, - _: TypedEnvelope, + _: Message, response: Response, ) -> Result<()> { response.send(proto::Ack {})?; @@ -525,15 +556,13 @@ impl Server { async fn create_room( self: Arc, - request: TypedEnvelope, + request: Message, response: Response, ) -> Result<()> { - let user_id; let room; { let mut store = self.store().await; - user_id = store.user_id_for_connection(request.sender_id)?; - room = store.create_room(request.sender_id)?.clone(); + room = store.create_room(request.sender_connection_id)?.clone(); } let live_kit_connection_info = @@ -544,7 +573,10 @@ impl Server { .trace_err() { if let Some(token) = live_kit - .room_token(&room.live_kit_room, &request.sender_id.to_string()) + .room_token( + &room.live_kit_room, + &request.sender_connection_id.to_string(), + ) .trace_err() { Some(proto::LiveKitConnectionInfo { @@ -565,21 +597,19 @@ impl Server { room: Some(room), live_kit_connection_info, })?; - self.update_user_contacts(user_id).await?; + self.update_user_contacts(request.sender_user_id).await?; Ok(()) } async fn join_room( self: Arc, - request: TypedEnvelope, + request: Message, response: Response, ) -> Result<()> { - let user_id; { let mut store = self.store().await; - user_id = store.user_id_for_connection(request.sender_id)?; let (room, recipient_connection_ids) = - store.join_room(request.payload.id, request.sender_id)?; + store.join_room(request.payload.id, request.sender_connection_id)?; for recipient_id in recipient_connection_ids { self.peer .send(recipient_id, proto::CallCanceled {}) @@ -589,7 +619,10 @@ impl Server { let live_kit_connection_info = if let Some(live_kit) = self.app_state.live_kit_client.as_ref() { if let Some(token) = live_kit - .room_token(&room.live_kit_room, &request.sender_id.to_string()) + .room_token( + &room.live_kit_room, + &request.sender_connection_id.to_string(), + ) .trace_err() { Some(proto::LiveKitConnectionInfo { @@ -609,18 +642,17 @@ impl Server { })?; self.room_updated(room); } - self.update_user_contacts(user_id).await?; + self.update_user_contacts(request.sender_user_id).await?; Ok(()) } - async fn leave_room(self: Arc, message: TypedEnvelope) -> Result<()> { + async fn leave_room(self: Arc, message: Message) -> Result<()> { let mut contacts_to_update = HashSet::default(); let room_left; { let mut store = self.store().await; - let user_id = store.user_id_for_connection(message.sender_id)?; - let left_room = store.leave_room(message.payload.id, message.sender_id)?; - contacts_to_update.insert(user_id); + let left_room = store.leave_room(message.payload.id, message.sender_connection_id)?; + contacts_to_update.insert(message.sender_user_id); for project in left_room.unshared_projects { for connection_id in project.connection_ids() { @@ -640,13 +672,13 @@ impl Server { connection_id, proto::RemoveProjectCollaborator { project_id: project.id.to_proto(), - peer_id: message.sender_id.0, + peer_id: message.sender_connection_id.0, }, )?; } self.peer.send( - message.sender_id, + message.sender_connection_id, proto::UnshareProject { project_id: project.id.to_proto(), }, @@ -655,7 +687,7 @@ impl Server { } self.room_updated(&left_room.room); - room_left = self.room_left(&left_room.room, message.sender_id); + room_left = self.room_left(&left_room.room, message.sender_connection_id); for connection_id in left_room.canceled_call_connection_ids { self.peer @@ -675,13 +707,10 @@ impl Server { async fn call( self: Arc, - request: TypedEnvelope, + request: Message, response: Response, ) -> Result<()> { - let caller_user_id = self - .store() - .await - .user_id_for_connection(request.sender_id)?; + let caller_user_id = request.sender_user_id; let recipient_user_id = UserId::from_proto(request.payload.recipient_user_id); let initial_project_id = request .payload @@ -703,7 +732,7 @@ impl Server { room_id, recipient_user_id, initial_project_id, - request.sender_id, + request.sender_connection_id, )?; self.room_updated(room); recipient_connection_ids @@ -740,7 +769,7 @@ impl Server { async fn cancel_call( self: Arc, - request: TypedEnvelope, + request: Message, response: Response, ) -> Result<()> { let recipient_user_id = UserId::from_proto(request.payload.recipient_user_id); @@ -749,7 +778,7 @@ impl Server { let (room, recipient_connection_ids) = store.cancel_call( request.payload.room_id, recipient_user_id, - request.sender_id, + request.sender_connection_id, )?; for recipient_id in recipient_connection_ids { self.peer @@ -763,16 +792,12 @@ impl Server { Ok(()) } - async fn decline_call( - self: Arc, - message: TypedEnvelope, - ) -> Result<()> { - let recipient_user_id; + async fn decline_call(self: Arc, message: Message) -> Result<()> { + let recipient_user_id = message.sender_user_id; { let mut store = self.store().await; - recipient_user_id = store.user_id_for_connection(message.sender_id)?; let (room, recipient_connection_ids) = - store.decline_call(message.payload.room_id, message.sender_id)?; + store.decline_call(message.payload.room_id, message.sender_connection_id)?; for recipient_id in recipient_connection_ids { self.peer .send(recipient_id, proto::CallCanceled {}) @@ -786,7 +811,7 @@ impl Server { async fn update_participant_location( self: Arc, - request: TypedEnvelope, + request: Message, response: Response, ) -> Result<()> { let room_id = request.payload.room_id; @@ -795,7 +820,8 @@ impl Server { .location .ok_or_else(|| anyhow!("invalid location"))?; let mut store = self.store().await; - let room = store.update_participant_location(room_id, location, request.sender_id)?; + let room = + store.update_participant_location(room_id, location, request.sender_connection_id)?; self.room_updated(room); response.send(proto::Ack {})?; Ok(()) @@ -839,20 +865,20 @@ impl Server { async fn share_project( self: Arc, - request: TypedEnvelope, + request: Message, response: Response, ) -> Result<()> { - let user_id = self - .store() - .await - .user_id_for_connection(request.sender_id)?; - let project_id = self.app_state.db.register_project(user_id).await?; + let project_id = self + .app_state + .db + .register_project(request.sender_user_id) + .await?; let mut store = self.store().await; let room = store.share_project( request.payload.room_id, project_id, request.payload.worktrees, - request.sender_id, + request.sender_connection_id, )?; response.send(proto::ShareProjectResponse { project_id: project_id.to_proto(), @@ -864,13 +890,13 @@ impl Server { async fn unshare_project( self: Arc, - message: TypedEnvelope, + message: Message, ) -> Result<()> { let project_id = ProjectId::from_proto(message.payload.project_id); let mut store = self.store().await; - let (room, project) = store.unshare_project(project_id, message.sender_id)?; + let (room, project) = store.unshare_project(project_id, message.sender_connection_id)?; broadcast( - message.sender_id, + message.sender_connection_id, project.guest_connection_ids(), |conn_id| self.peer.send(conn_id, message.payload.clone()), ); @@ -911,26 +937,24 @@ impl Server { async fn join_project( self: Arc, - request: TypedEnvelope, + request: Message, response: Response, ) -> Result<()> { let project_id = ProjectId::from_proto(request.payload.project_id); - + let guest_user_id = request.sender_user_id; let host_user_id; - let guest_user_id; let host_connection_id; { let state = self.store().await; let project = state.project(project_id)?; host_user_id = project.host.user_id; host_connection_id = project.host_connection_id; - guest_user_id = state.user_id_for_connection(request.sender_id)?; }; tracing::info!(%project_id, %host_user_id, %host_connection_id, "join project"); let mut store = self.store().await; - let (project, replica_id) = store.join_project(request.sender_id, project_id)?; + let (project, replica_id) = store.join_project(request.sender_connection_id, project_id)?; let peer_count = project.guests.len(); let mut collaborators = Vec::with_capacity(peer_count); collaborators.push(proto::Collaborator { @@ -951,7 +975,7 @@ impl Server { // Add all guests other than the requesting user's own connections as collaborators for (guest_conn_id, guest) in &project.guests { - if request.sender_id != *guest_conn_id { + if request.sender_connection_id != *guest_conn_id { collaborators.push(proto::Collaborator { peer_id: guest_conn_id.0, replica_id: guest.replica_id as u32, @@ -961,14 +985,14 @@ impl Server { } for conn_id in project.connection_ids() { - if conn_id != request.sender_id { + if conn_id != request.sender_connection_id { self.peer .send( conn_id, proto::AddProjectCollaborator { project_id: project_id.to_proto(), collaborator: Some(proto::Collaborator { - peer_id: request.sender_id.0, + peer_id: request.sender_connection_id.0, replica_id: replica_id as u32, user_id: guest_user_id.to_proto(), }), @@ -1004,13 +1028,14 @@ impl Server { is_last_update: worktree.is_complete, }; for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) { - self.peer.send(request.sender_id, update.clone())?; + self.peer + .send(request.sender_connection_id, update.clone())?; } // Stream this worktree's diagnostics. for summary in worktree.diagnostic_summaries.values() { self.peer.send( - request.sender_id, + request.sender_connection_id, proto::UpdateDiagnosticSummary { project_id: project_id.to_proto(), worktree_id: *worktree_id, @@ -1022,7 +1047,7 @@ impl Server { for language_server in &project.language_servers { self.peer.send( - request.sender_id, + request.sender_connection_id, proto::UpdateLanguageServer { project_id: project_id.to_proto(), language_server_id: language_server.id, @@ -1038,11 +1063,8 @@ impl Server { Ok(()) } - async fn leave_project( - self: Arc, - request: TypedEnvelope, - ) -> Result<()> { - let sender_id = request.sender_id; + async fn leave_project(self: Arc, request: Message) -> Result<()> { + let sender_id = request.sender_connection_id; let project_id = ProjectId::from_proto(request.payload.project_id); let project; { @@ -1073,20 +1095,30 @@ impl Server { async fn update_project( self: Arc, - request: TypedEnvelope, + request: Message, ) -> Result<()> { let project_id = ProjectId::from_proto(request.payload.project_id); { let mut state = self.store().await; let guest_connection_ids = state - .read_project(project_id, request.sender_id)? + .read_project(project_id, request.sender_connection_id)? .guest_connection_ids(); - let room = - state.update_project(project_id, &request.payload.worktrees, request.sender_id)?; - broadcast(request.sender_id, guest_connection_ids, |connection_id| { - self.peer - .forward_send(request.sender_id, connection_id, request.payload.clone()) - }); + let room = state.update_project( + project_id, + &request.payload.worktrees, + request.sender_connection_id, + )?; + broadcast( + request.sender_connection_id, + guest_connection_ids, + |connection_id| { + self.peer.forward_send( + request.sender_connection_id, + connection_id, + request.payload.clone(), + ) + }, + ); self.room_updated(room); }; @@ -1095,13 +1127,13 @@ impl Server { async fn update_worktree( self: Arc, - request: TypedEnvelope, + request: Message, response: Response, ) -> Result<()> { let project_id = ProjectId::from_proto(request.payload.project_id); let worktree_id = request.payload.worktree_id; let connection_ids = self.store().await.update_worktree( - request.sender_id, + request.sender_connection_id, project_id, worktree_id, &request.payload.root_name, @@ -1111,17 +1143,24 @@ impl Server { request.payload.is_last_update, )?; - broadcast(request.sender_id, connection_ids, |connection_id| { - self.peer - .forward_send(request.sender_id, connection_id, request.payload.clone()) - }); + broadcast( + request.sender_connection_id, + connection_ids, + |connection_id| { + self.peer.forward_send( + request.sender_connection_id, + connection_id, + request.payload.clone(), + ) + }, + ); response.send(proto::Ack {})?; Ok(()) } async fn update_diagnostic_summary( self: Arc, - request: TypedEnvelope, + request: Message, ) -> Result<()> { let summary = request .payload @@ -1131,55 +1170,76 @@ impl Server { let receiver_ids = self.store().await.update_diagnostic_summary( ProjectId::from_proto(request.payload.project_id), request.payload.worktree_id, - request.sender_id, + request.sender_connection_id, summary, )?; - broadcast(request.sender_id, receiver_ids, |connection_id| { - self.peer - .forward_send(request.sender_id, connection_id, request.payload.clone()) - }); + broadcast( + request.sender_connection_id, + receiver_ids, + |connection_id| { + self.peer.forward_send( + request.sender_connection_id, + connection_id, + request.payload.clone(), + ) + }, + ); Ok(()) } async fn start_language_server( self: Arc, - request: TypedEnvelope, + request: Message, ) -> Result<()> { let receiver_ids = self.store().await.start_language_server( ProjectId::from_proto(request.payload.project_id), - request.sender_id, + request.sender_connection_id, request .payload .server .clone() .ok_or_else(|| anyhow!("invalid language server"))?, )?; - broadcast(request.sender_id, receiver_ids, |connection_id| { - self.peer - .forward_send(request.sender_id, connection_id, request.payload.clone()) - }); + broadcast( + request.sender_connection_id, + receiver_ids, + |connection_id| { + self.peer.forward_send( + request.sender_connection_id, + connection_id, + request.payload.clone(), + ) + }, + ); Ok(()) } async fn update_language_server( self: Arc, - request: TypedEnvelope, + request: Message, ) -> Result<()> { let receiver_ids = self.store().await.project_connection_ids( ProjectId::from_proto(request.payload.project_id), - request.sender_id, + request.sender_connection_id, )?; - broadcast(request.sender_id, receiver_ids, |connection_id| { - self.peer - .forward_send(request.sender_id, connection_id, request.payload.clone()) - }); + broadcast( + request.sender_connection_id, + receiver_ids, + |connection_id| { + self.peer.forward_send( + request.sender_connection_id, + connection_id, + request.payload.clone(), + ) + }, + ); Ok(()) } async fn forward_project_request( self: Arc, - request: TypedEnvelope, + request: Message, response: Response, ) -> Result<()> where @@ -1189,17 +1249,21 @@ impl Server { let host_connection_id = self .store() .await - .read_project(project_id, request.sender_id)? + .read_project(project_id, request.sender_connection_id)? .host_connection_id; let payload = self .peer - .forward_request(request.sender_id, host_connection_id, request.payload) + .forward_request( + request.sender_connection_id, + host_connection_id, + request.payload, + ) .await?; // Ensure project still exists by the time we get the response from the host. self.store() .await - .read_project(project_id, request.sender_id)?; + .read_project(project_id, request.sender_connection_id)?; response.send(payload)?; Ok(()) @@ -1207,26 +1271,26 @@ impl Server { async fn save_buffer( self: Arc, - request: TypedEnvelope, + request: Message, response: Response, ) -> Result<()> { let project_id = ProjectId::from_proto(request.payload.project_id); let host = self .store() .await - .read_project(project_id, request.sender_id)? + .read_project(project_id, request.sender_connection_id)? .host_connection_id; let response_payload = self .peer - .forward_request(request.sender_id, host, request.payload.clone()) + .forward_request(request.sender_connection_id, host, request.payload.clone()) .await?; let mut guests = self .store() .await - .read_project(project_id, request.sender_id)? + .read_project(project_id, request.sender_connection_id)? .connection_ids(); - guests.retain(|guest_connection_id| *guest_connection_id != request.sender_id); + guests.retain(|guest_connection_id| *guest_connection_id != request.sender_connection_id); broadcast(host, guests, |conn_id| { self.peer .forward_send(host, conn_id, response_payload.clone()) @@ -1237,10 +1301,10 @@ impl Server { async fn create_buffer_for_peer( self: Arc, - request: TypedEnvelope, + request: Message, ) -> Result<()> { self.peer.forward_send( - request.sender_id, + request.sender_connection_id, ConnectionId(request.payload.peer_id), request.payload, )?; @@ -1249,76 +1313,101 @@ impl Server { async fn update_buffer( self: Arc, - request: TypedEnvelope, + request: Message, response: Response, ) -> Result<()> { let project_id = ProjectId::from_proto(request.payload.project_id); let receiver_ids = { let store = self.store().await; - store.project_connection_ids(project_id, request.sender_id)? + store.project_connection_ids(project_id, request.sender_connection_id)? }; - broadcast(request.sender_id, receiver_ids, |connection_id| { - self.peer - .forward_send(request.sender_id, connection_id, request.payload.clone()) - }); + broadcast( + request.sender_connection_id, + receiver_ids, + |connection_id| { + self.peer.forward_send( + request.sender_connection_id, + connection_id, + request.payload.clone(), + ) + }, + ); response.send(proto::Ack {})?; Ok(()) } async fn update_buffer_file( self: Arc, - request: TypedEnvelope, + request: Message, ) -> Result<()> { let receiver_ids = self.store().await.project_connection_ids( ProjectId::from_proto(request.payload.project_id), - request.sender_id, + request.sender_connection_id, )?; - broadcast(request.sender_id, receiver_ids, |connection_id| { - self.peer - .forward_send(request.sender_id, connection_id, request.payload.clone()) - }); + broadcast( + request.sender_connection_id, + receiver_ids, + |connection_id| { + self.peer.forward_send( + request.sender_connection_id, + connection_id, + request.payload.clone(), + ) + }, + ); Ok(()) } async fn buffer_reloaded( self: Arc, - request: TypedEnvelope, + request: Message, ) -> Result<()> { let receiver_ids = self.store().await.project_connection_ids( ProjectId::from_proto(request.payload.project_id), - request.sender_id, + request.sender_connection_id, )?; - broadcast(request.sender_id, receiver_ids, |connection_id| { - self.peer - .forward_send(request.sender_id, connection_id, request.payload.clone()) - }); + broadcast( + request.sender_connection_id, + receiver_ids, + |connection_id| { + self.peer.forward_send( + request.sender_connection_id, + connection_id, + request.payload.clone(), + ) + }, + ); Ok(()) } - async fn buffer_saved( - self: Arc, - request: TypedEnvelope, - ) -> Result<()> { + async fn buffer_saved(self: Arc, request: Message) -> Result<()> { let receiver_ids = self.store().await.project_connection_ids( ProjectId::from_proto(request.payload.project_id), - request.sender_id, + request.sender_connection_id, )?; - broadcast(request.sender_id, receiver_ids, |connection_id| { - self.peer - .forward_send(request.sender_id, connection_id, request.payload.clone()) - }); + broadcast( + request.sender_connection_id, + receiver_ids, + |connection_id| { + self.peer.forward_send( + request.sender_connection_id, + connection_id, + request.payload.clone(), + ) + }, + ); Ok(()) } async fn follow( self: Arc, - request: TypedEnvelope, + request: Message, response: Response, ) -> Result<()> { let project_id = ProjectId::from_proto(request.payload.project_id); let leader_id = ConnectionId(request.payload.leader_id); - let follower_id = request.sender_id; + let follower_id = request.sender_connection_id; { let store = self.store().await; if !store @@ -1331,7 +1420,7 @@ impl Server { let mut response_payload = self .peer - .forward_request(request.sender_id, leader_id, request.payload) + .forward_request(request.sender_connection_id, leader_id, request.payload) .await?; response_payload .views @@ -1340,28 +1429,29 @@ impl Server { Ok(()) } - async fn unfollow(self: Arc, request: TypedEnvelope) -> Result<()> { + async fn unfollow(self: Arc, request: Message) -> Result<()> { let project_id = ProjectId::from_proto(request.payload.project_id); let leader_id = ConnectionId(request.payload.leader_id); let store = self.store().await; if !store - .project_connection_ids(project_id, request.sender_id)? + .project_connection_ids(project_id, request.sender_connection_id)? .contains(&leader_id) { Err(anyhow!("no such peer"))?; } self.peer - .forward_send(request.sender_id, leader_id, request.payload)?; + .forward_send(request.sender_connection_id, leader_id, request.payload)?; Ok(()) } async fn update_followers( self: Arc, - request: TypedEnvelope, + request: Message, ) -> Result<()> { let project_id = ProjectId::from_proto(request.payload.project_id); let store = self.store().await; - let connection_ids = store.project_connection_ids(project_id, request.sender_id)?; + let connection_ids = + store.project_connection_ids(project_id, request.sender_connection_id)?; let leader_id = request .payload .variant @@ -1374,8 +1464,11 @@ impl Server { for follower_id in &request.payload.follower_ids { let follower_id = ConnectionId(*follower_id); if connection_ids.contains(&follower_id) && Some(follower_id.0) != leader_id { - self.peer - .forward_send(request.sender_id, follower_id, request.payload.clone())?; + self.peer.forward_send( + request.sender_connection_id, + follower_id, + request.payload.clone(), + )?; } } Ok(()) @@ -1383,7 +1476,7 @@ impl Server { async fn get_users( self: Arc, - request: TypedEnvelope, + request: Message, response: Response, ) -> Result<()> { let user_ids = request @@ -1410,13 +1503,9 @@ impl Server { async fn fuzzy_search_users( self: Arc, - request: TypedEnvelope, + request: Message, response: Response, ) -> Result<()> { - let user_id = self - .store() - .await - .user_id_for_connection(request.sender_id)?; let query = request.payload.query; let db = &self.app_state.db; let users = match query.len() { @@ -1430,7 +1519,7 @@ impl Server { }; let users = users .into_iter() - .filter(|user| user.id != user_id) + .filter(|user| user.id != request.sender_user_id) .map(|user| proto::User { id: user.id.to_proto(), avatar_url: format!("https://github.com/{}.png?size=128", user.github_login), @@ -1443,13 +1532,10 @@ impl Server { async fn request_contact( self: Arc, - request: TypedEnvelope, + request: Message, response: Response, ) -> Result<()> { - let requester_id = self - .store() - .await - .user_id_for_connection(request.sender_id)?; + let requester_id = request.sender_user_id; let responder_id = UserId::from_proto(request.payload.responder_id); if requester_id == responder_id { return Err(anyhow!("cannot add yourself as a contact"))?; @@ -1485,13 +1571,10 @@ impl Server { async fn respond_to_contact_request( self: Arc, - request: TypedEnvelope, + request: Message, response: Response, ) -> Result<()> { - let responder_id = self - .store() - .await - .user_id_for_connection(request.sender_id)?; + let responder_id = request.sender_user_id; let requester_id = UserId::from_proto(request.payload.requester_id); if request.payload.response == proto::ContactRequestResponse::Dismiss as i32 { self.app_state @@ -1541,13 +1624,10 @@ impl Server { async fn remove_contact( self: Arc, - request: TypedEnvelope, + request: Message, response: Response, ) -> Result<()> { - let requester_id = self - .store() - .await - .user_id_for_connection(request.sender_id)?; + let requester_id = request.sender_user_id; let responder_id = UserId::from_proto(request.payload.user_id); self.app_state .db @@ -1578,33 +1658,40 @@ impl Server { async fn update_diff_base( self: Arc, - request: TypedEnvelope, + request: Message, ) -> Result<()> { let receiver_ids = self.store().await.project_connection_ids( ProjectId::from_proto(request.payload.project_id), - request.sender_id, + request.sender_connection_id, )?; - broadcast(request.sender_id, receiver_ids, |connection_id| { - self.peer - .forward_send(request.sender_id, connection_id, request.payload.clone()) - }); + broadcast( + request.sender_connection_id, + receiver_ids, + |connection_id| { + self.peer.forward_send( + request.sender_connection_id, + connection_id, + request.payload.clone(), + ) + }, + ); Ok(()) } async fn get_private_user_info( self: Arc, - request: TypedEnvelope, + request: Message, response: Response, ) -> Result<()> { - let user_id = self - .store() - .await - .user_id_for_connection(request.sender_id)?; - let metrics_id = self.app_state.db.get_user_metrics_id(user_id).await?; + let metrics_id = self + .app_state + .db + .get_user_metrics_id(request.sender_user_id) + .await?; let user = self .app_state .db - .get_user_by_id(user_id) + .get_user_by_id(request.sender_user_id) .await? .ok_or_else(|| anyhow!("user not found"))?; response.send(proto::GetPrivateUserInfoResponse {