diff --git a/crates/client/src/client.rs b/crates/client/src/client.rs index 47f9aeb8e2..760037c8eb 100644 --- a/crates/client/src/client.rs +++ b/crates/client/src/client.rs @@ -124,7 +124,7 @@ struct ClientState { status: (watch::Sender, watch::Receiver), entity_id_extractors: HashMap u64>>, model_handlers: HashMap< - (TypeId, u64), + (TypeId, Option), Option, &mut AsyncAppContext)>>, >, _maintain_connection: Option>, @@ -152,14 +152,13 @@ impl Default for ClientState { pub struct Subscription { client: Weak, - id: (TypeId, u64), + id: (TypeId, Option), } impl Drop for Subscription { fn drop(&mut self) { if let Some(client) = self.client.upgrade() { let mut state = client.state.write(); - let _ = state.entity_id_extractors.remove(&self.id.0).unwrap(); let _ = state.model_handlers.remove(&self.id).unwrap(); } } @@ -267,18 +266,11 @@ impl Client { + Sync + FnMut(&mut M, TypedEnvelope, Arc, &mut ModelContext) -> Result<()>, { - let subscription_id = (TypeId::of::(), Default::default()); + let subscription_id = (TypeId::of::(), None); let client = self.clone(); let mut state = self.state.write(); let model = cx.weak_handle(); - let prev_extractor = state - .entity_id_extractors - .insert(subscription_id.0, Box::new(|_| Default::default())); - if prev_extractor.is_some() { - panic!("registered a handler for the same entity twice") - } - - state.model_handlers.insert( + let prev_handler = state.model_handlers.insert( subscription_id, Some(Box::new(move |envelope, cx| { if let Some(model) = model.upgrade(cx) { @@ -291,6 +283,9 @@ impl Client { } })), ); + if prev_handler.is_some() { + panic!("registered handler for the same message twice"); + } Subscription { client: Arc::downgrade(self), @@ -312,7 +307,7 @@ impl Client { + Sync + FnMut(&mut M, TypedEnvelope, Arc, &mut ModelContext) -> Result<()>, { - let subscription_id = (TypeId::of::(), remote_id); + let subscription_id = (TypeId::of::(), Some(remote_id)); let client = self.clone(); let mut state = self.state.write(); let model = cx.weak_handle(); @@ -439,29 +434,27 @@ impl Client { async move { while let Some(message) = incoming.recv().await { let mut state = this.state.write(); - if let Some(extract_entity_id) = + let payload_type_id = message.payload_type_id(); + let entity_id = if let Some(extract_entity_id) = state.entity_id_extractors.get(&message.payload_type_id()) { - let payload_type_id = message.payload_type_id(); - let entity_id = (extract_entity_id)(message.as_ref()); - let handler_key = (payload_type_id, entity_id); - if let Some(handler) = state.model_handlers.get_mut(&handler_key) { - let mut handler = handler.take().unwrap(); - drop(state); // Avoid deadlocks if the handler interacts with rpc::Client - let start_time = Instant::now(); - log::info!("RPC client message {}", message.payload_type_name()); - (handler)(message, &mut cx); - log::info!( - "RPC message handled. duration:{:?}", - start_time.elapsed() - ); + Some((extract_entity_id)(message.as_ref())) + } else { + None + }; - let mut state = this.state.write(); - if state.model_handlers.contains_key(&handler_key) { - state.model_handlers.insert(handler_key, Some(handler)); - } - } else { - log::info!("unhandled message {}", message.payload_type_name()); + let handler_key = (payload_type_id, entity_id); + if let Some(handler) = state.model_handlers.get_mut(&handler_key) { + let mut handler = handler.take().unwrap(); + drop(state); // Avoid deadlocks if the handler interacts with rpc::Client + let start_time = Instant::now(); + log::info!("RPC client message {}", message.payload_type_name()); + (handler)(message, &mut cx); + log::info!("RPC message handled. duration:{:?}", start_time.elapsed()); + + let mut state = this.state.write(); + if state.model_handlers.contains_key(&handler_key) { + state.model_handlers.insert(handler_key, Some(handler)); } } else { log::info!("unhandled message {}", message.payload_type_name()); @@ -811,6 +804,55 @@ mod tests { assert_eq!(decode_worktree_url("not://the-right-format"), None); } + #[gpui::test] + async fn test_subscribing_to_entity(mut cx: TestAppContext) { + cx.foreground().forbid_parking(); + + let user_id = 5; + let mut client = Client::new(FakeHttpClient::with_404_response()); + let server = FakeServer::for_client(user_id, &mut client, &cx).await; + + let model = cx.add_model(|_| Model { subscription: None }); + let (mut done_tx1, mut done_rx1) = postage::oneshot::channel(); + let (mut done_tx2, mut done_rx2) = postage::oneshot::channel(); + let _subscription1 = model.update(&mut cx, |_, cx| { + client.subscribe_to_entity( + 1, + cx, + move |_, _: TypedEnvelope, _, _| { + postage::sink::Sink::try_send(&mut done_tx1, ()).unwrap(); + Ok(()) + }, + ) + }); + let _subscription2 = model.update(&mut cx, |_, cx| { + client.subscribe_to_entity( + 2, + cx, + move |_, _: TypedEnvelope, _, _| { + postage::sink::Sink::try_send(&mut done_tx2, ()).unwrap(); + Ok(()) + }, + ) + }); + + // Ensure dropping a subscription for the same entity type still allows receiving of + // messages for other entity IDs of the same type. + let subscription3 = model.update(&mut cx, |_, cx| { + client.subscribe_to_entity( + 3, + cx, + move |_, _: TypedEnvelope, _, _| Ok(()), + ) + }); + drop(subscription3); + + server.send(proto::UnshareProject { project_id: 1 }).await; + server.send(proto::UnshareProject { project_id: 2 }).await; + done_rx1.recv().await.unwrap(); + done_rx2.recv().await.unwrap(); + } + #[gpui::test] async fn test_subscribing_after_dropping_subscription(mut cx: TestAppContext) { cx.foreground().forbid_parking();