diff --git a/zed/src/lib.rs b/zed/src/lib.rs index cb4f3cb8e0..56b696fb4b 100644 --- a/zed/src/lib.rs +++ b/zed/src/lib.rs @@ -60,7 +60,7 @@ fn share_worktree(_: &(), cx: &mut MutableAppContext) { // a TLS stream using `native-tls`. let stream = smol::net::TcpStream::connect(rpc_address).await?; - let mut rpc_client = RpcClient::new(stream, executor); + let rpc_client = RpcClient::new(stream, executor); let auth_response = rpc_client .request(proto::from_client::Auth { diff --git a/zed/src/rpc_client.rs b/zed/src/rpc_client.rs index 44e7c28407..6953b8e760 100644 --- a/zed/src/rpc_client.rs +++ b/zed/src/rpc_client.rs @@ -1,12 +1,15 @@ use anyhow::{anyhow, Result}; -use futures::FutureExt; +use futures::future::Either; use gpui::executor::Background; use parking_lot::Mutex; use postage::{ - mpsc, + barrier, mpsc, oneshot, prelude::{Sink, Stream}, }; -use smol::prelude::{AsyncRead, AsyncWrite}; +use smol::{ + io::{ReadHalf, WriteHalf}, + prelude::{AsyncRead, AsyncWrite}, +}; use std::{ collections::HashMap, io, @@ -21,8 +24,9 @@ use zed_rpc::proto::{ pub struct RpcClient { response_channels: Arc, bool)>>>, - outgoing_tx: mpsc::Sender, + outgoing_tx: mpsc::Sender<(proto::FromClient, oneshot::Sender>)>, next_message_id: AtomicI32, + _drop_tx: barrier::Sender, } impl RpcClient { @@ -31,82 +35,109 @@ impl RpcClient { Conn: AsyncRead + AsyncWrite + Unpin + Send + 'static, { let response_channels = Arc::new(Mutex::new(HashMap::new())); - let (outgoing_tx, mut outgoing_rx) = mpsc::channel(32); + let (conn_rx, conn_tx) = smol::io::split(conn); + let (outgoing_tx, outgoing_rx) = mpsc::channel(32); + let (_drop_tx, drop_rx) = barrier::channel(); - { - let response_channels = response_channels.clone(); - executor - .spawn(async move { - let (conn_rx, conn_tx) = smol::io::split(conn); - let mut stream_tx = MessageStream::new(conn_tx); - let mut stream_rx = MessageStream::new(conn_rx); - loop { - futures::select! { - incoming = stream_rx.read_message::().fuse() => { - Self::handle_incoming(incoming, &response_channels).await; - } - outgoing = outgoing_rx.recv().fuse() => { - if let Some(outgoing) = outgoing { - stream_tx.write_message(&outgoing).await; - } else { - break; - } - } - } - } - }) - .detach(); - } + executor + .spawn(Self::handle_incoming( + conn_rx, + drop_rx, + response_channels.clone(), + )) + .detach(); + + executor + .spawn(Self::handle_outgoing(conn_tx, outgoing_rx)) + .detach(); Self { response_channels, outgoing_tx, + _drop_tx, next_message_id: AtomicI32::new(0), } } - async fn handle_incoming( - incoming: io::Result, - response_channels: &Mutex, bool)>>, - ) { - match incoming { - Ok(incoming) => { - if let Some(variant) = incoming.variant { - if let Some(request_id) = incoming.request_id { - let channel = response_channels.lock().remove(&request_id); - if let Some((mut tx, oneshot)) = channel { - if tx.send(variant).await.is_ok() { - if !oneshot { - response_channels.lock().insert(request_id, (tx, false)); + async fn handle_incoming( + conn: ReadHalf, + mut drop_rx: barrier::Receiver, + response_channels: Arc< + Mutex, bool)>>, + >, + ) where + Conn: AsyncRead + Unpin, + { + let mut stream = MessageStream::new(conn); + loop { + let read_message = stream.read_message::(); + let dropped = drop_rx.recv(); + smol::pin!(read_message); + smol::pin!(dropped); + let result = futures::future::select(&mut read_message, &mut dropped).await; + match result { + Either::Left((Ok(incoming), _)) => { + if let Some(variant) = incoming.variant { + if let Some(request_id) = incoming.request_id { + let channel = response_channels.lock().remove(&request_id); + if let Some((mut tx, oneshot)) = channel { + if tx.send(variant).await.is_ok() { + if !oneshot { + response_channels.lock().insert(request_id, (tx, false)); + } } + } else { + log::warn!( + "received RPC response to unknown request id {}", + request_id + ); } - } else { - log::warn!( - "received RPC response to unknown request id {}", - request_id - ); } + } else { + log::warn!("received RPC message with no content"); } - } else { - log::warn!("received RPC message with no content"); + } + Either::Left((Err(error), _)) => { + log::warn!("invalid incoming RPC message {:?}", error) + } + Either::Right(_) => { + eprintln!("done with incoming loop"); + break; } } - Err(error) => log::warn!("invalid incoming RPC message {:?}", error), + } + } + + async fn handle_outgoing( + conn: WriteHalf, + mut outgoing_rx: mpsc::Receiver<(proto::FromClient, oneshot::Sender>)>, + ) where + Conn: AsyncWrite + Unpin, + { + let mut stream = MessageStream::new(conn); + while let Some((message, mut result_tx)) = outgoing_rx.recv().await { + let result = stream.write_message(&message).await; + result_tx.send(result).await.unwrap(); } } pub async fn request(&self, req: T) -> Result { let message_id = self.next_message_id.fetch_add(1, atomic::Ordering::SeqCst); + let (result_tx, mut result_rx) = oneshot::channel(); let (tx, mut rx) = mpsc::channel(1); self.response_channels.lock().insert(message_id, (tx, true)); self.outgoing_tx .clone() - .send(proto::FromClient { - id: message_id, - variant: Some(req.to_variant()), - }) + .send(( + proto::FromClient { + id: message_id, + variant: Some(req.to_variant()), + }, + result_tx, + )) .await - .unwrap(); + .ok(); + result_rx.recv().await.unwrap()?; let response = rx .recv() .await @@ -117,14 +148,19 @@ impl RpcClient { pub async fn send(&self, message: T) -> Result<()> { let message_id = self.next_message_id.fetch_add(1, atomic::Ordering::SeqCst); + let (result_tx, mut result_rx) = oneshot::channel(); self.outgoing_tx .clone() - .send(proto::FromClient { - id: message_id, - variant: Some(message.to_variant()), - }) + .send(( + proto::FromClient { + id: message_id, + variant: Some(message.to_variant()), + }, + result_tx, + )) .await - .unwrap(); + .ok(); + result_rx.recv().await.unwrap()?; Ok(()) } @@ -133,18 +169,23 @@ impl RpcClient { subscription: T, ) -> Result>> { let message_id = self.next_message_id.fetch_add(1, atomic::Ordering::SeqCst); + let (result_tx, mut result_rx) = oneshot::channel(); let (tx, rx) = mpsc::channel(256); self.response_channels .lock() .insert(message_id, (tx, false)); self.outgoing_tx .clone() - .send(proto::FromClient { - id: message_id, - variant: Some(subscription.to_variant()), - }) + .send(( + proto::FromClient { + id: message_id, + variant: Some(subscription.to_variant()), + }, + result_tx, + )) .await - .unwrap(); + .ok(); + result_rx.recv().await.unwrap()?; Ok(rx.map(|event| { T::Event::from_variant(event).ok_or_else(|| anyhow!("invalid event {:?}")) })) @@ -251,6 +292,29 @@ mod tests { } } + #[gpui::test] + async fn test_io_error(cx: gpui::TestAppContext) { + let executor = cx.read(|app| app.background_executor().clone()); + let socket_dir_path = TempDir::new("request-response-socket").unwrap(); + let socket_path = socket_dir_path.path().join(".sock"); + let _listener = UnixListener::bind(&socket_path).unwrap(); + let mut client_conn = UnixStream::connect(&socket_path).await.unwrap(); + client_conn.close().await.unwrap(); + + let client = RpcClient::new(client_conn, executor.clone()); + let err = client + .request(proto::from_client::Auth { + user_id: 42, + access_token: "token".to_string(), + }) + .await + .unwrap_err(); + assert_eq!( + err.downcast_ref::().unwrap().kind(), + io::ErrorKind::BrokenPipe + ); + } + async fn send_recv(mut sender: S, receiver: R) -> O where S: Unpin + Future,