diff --git a/zed/src/rpc_client.rs b/zed/src/rpc_client.rs index 2737f524f8..2358c791a0 100644 --- a/zed/src/rpc_client.rs +++ b/zed/src/rpc_client.rs @@ -2,7 +2,7 @@ use anyhow::{anyhow, Result}; use gpui::executor::Background; use parking_lot::Mutex; use postage::{ - oneshot, + mpsc, oneshot, prelude::{Sink, Stream}, }; use smol::{ @@ -11,11 +11,13 @@ use smol::{ prelude::{AsyncRead, AsyncWrite}, }; use std::{collections::HashMap, sync::Arc}; -use zed_rpc::proto::{self, MessageStream, RequestMessage, SendMessage, ServerMessage}; +use zed_rpc::proto::{ + self, MessageStream, RequestMessage, SendMessage, ServerMessage, SubscribeMessage, +}; pub struct RpcClient { stream: MessageStream>, - response_channels: Arc>>>, + response_channels: Arc, bool)>>>, next_message_id: i32, _drop_tx: oneshot::Sender<()>, } @@ -59,9 +61,15 @@ where Message::Message(message) => { if let Some(variant) = message.variant { if let Some(request_id) = message.request_id { - let tx = response_channels.lock().remove(&request_id); - if let Some(mut tx) = tx { - tx.send(variant).await?; + let channel = response_channels.lock().remove(&request_id); + if let Some((mut tx, oneshot)) = channel { + if tx.send(variant).await.is_ok() { + if !oneshot { + response_channels + .lock() + .insert(request_id, (tx, false)); + } + } } else { log::warn!( "received RPC response to unknown request id {}", @@ -85,10 +93,8 @@ where pub async fn request(&mut self, req: T) -> Result { let message_id = self.next_message_id; self.next_message_id += 1; - - let (tx, mut rx) = oneshot::channel(); - self.response_channels.lock().insert(message_id, tx); - + let (tx, mut rx) = mpsc::channel(1); + self.response_channels.lock().insert(message_id, (tx, true)); self.stream .write_message(&proto::FromClient { id: message_id, @@ -114,6 +120,28 @@ where .await?; Ok(()) } + + pub async fn subscribe( + &mut self, + subscription: T, + ) -> Result>> { + let message_id = self.next_message_id; + self.next_message_id += 1; + let (tx, rx) = mpsc::channel(256); + self.response_channels + .lock() + .insert(message_id, (tx, false)); + self.stream + .write_message(&proto::FromClient { + id: message_id, + variant: Some(subscription.to_variant()), + }) + .await?; + + Ok(rx.map(|event| { + T::Event::from_variant(event).ok_or_else(|| anyhow!("invalid event {:?}")) + })) + } } #[cfg(test)]