From dff812b38ee97490c0eeea4178f964f99d69cf58 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Mon, 10 Jan 2022 16:04:49 +0100 Subject: [PATCH] Don't panic when dropping a subscription in a subscription handler --- crates/client/src/client.rs | 55 ++++++++++++++++++++++++++++++------- 1 file changed, 45 insertions(+), 10 deletions(-) diff --git a/crates/client/src/client.rs b/crates/client/src/client.rs index a8356bcea0..93cce9294a 100644 --- a/crates/client/src/client.rs +++ b/crates/client/src/client.rs @@ -125,7 +125,7 @@ struct ClientState { entity_id_extractors: HashMap u64>>, model_handlers: HashMap< (TypeId, u64), - Box, &mut AsyncAppContext)>, + Option, &mut AsyncAppContext)>>, >, _maintain_connection: Option>, heartbeat_interval: Duration, @@ -285,7 +285,7 @@ impl Client { state.model_handlers.insert( subscription_id, - Box::new(move |envelope, cx| { + Some(Box::new(move |envelope, cx| { if let Some(model) = model.upgrade(cx) { let envelope = envelope.into_any().downcast::>().unwrap(); model.update(cx, |model, cx| { @@ -294,7 +294,7 @@ impl Client { } }); } - }), + })), ); Subscription { @@ -335,7 +335,7 @@ impl Client { }); let prev_handler = state.model_handlers.insert( subscription_id, - Box::new(move |envelope, cx| { + Some(Box::new(move |envelope, cx| { if let Some(model) = model.upgrade(cx) { let envelope = envelope.into_any().downcast::>().unwrap(); model.update(cx, |model, cx| { @@ -344,7 +344,7 @@ impl Client { } }); } - }), + })), ); if prev_handler.is_some() { panic!("registered a handler for the same entity twice") @@ -450,7 +450,8 @@ impl Client { let payload_type_id = message.payload_type_id(); let entity_id = (extract_entity_id)(message.as_ref()); let handler_key = (payload_type_id, entity_id); - if let Some(mut handler) = state.model_handlers.remove(&handler_key) { + if let Some(handler) = state.model_handlers.get_mut(&handler_key) { + let mut handler = handler.take().unwrap(); drop(state); // Avoid deadlocks if the handler interacts with rpc::Client let start_time = Instant::now(); log::info!("RPC client message {}", message.payload_type_name()); @@ -459,10 +460,11 @@ impl Client { "RPC message handled. duration:{:?}", start_time.elapsed() ); - this.state - .write() - .model_handlers - .insert(handler_key, handler); + + let mut state = this.state.write(); + if state.model_handlers.contains_key(&handler_key) { + state.model_handlers.insert(handler_key, Some(handler)); + } } else { log::info!("unhandled message {}", message.payload_type_name()); } @@ -813,4 +815,37 @@ mod tests { ); assert_eq!(decode_worktree_url("not://the-right-format"), None); } + + #[gpui::test] + async fn test_dropping_subscription_in_handler(mut cx: TestAppContext) { + cx.foreground().forbid_parking(); + + let user_id = 5; + let mut client = Client::new(FakeHttpClient::with_404_response()); + let server = FakeServer::for_client(user_id, &mut client, &cx).await; + + let model = cx.add_model(|_| Model { subscription: None }); + let (done_tx, mut done_rx) = postage::oneshot::channel(); + let mut done_tx = Some(done_tx); + model.update(&mut cx, |model, cx| { + model.subscription = Some(client.subscribe( + cx, + move |model, _: TypedEnvelope, _, _| { + model.subscription.take(); + postage::sink::Sink::try_send(&mut done_tx.take().unwrap(), ()).unwrap(); + Ok(()) + }, + )); + }); + server.send(proto::Ping {}).await; + done_rx.recv().await.unwrap(); + } + + struct Model { + subscription: Option, + } + + impl Entity for Model { + type Event = (); + } }