From 8e3f40bfdd99e9b238aad2c00330cd7496338360 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Mon, 14 Jun 2021 14:32:49 +0200 Subject: [PATCH] Close connection when `RpcClient` is dropped and add unit tests --- zed/src/lib.rs | 6 +- zed/src/rpc_client.rs | 181 ++++++++++++++++++++++++++++++++++-------- 2 files changed, 149 insertions(+), 38 deletions(-) diff --git a/zed/src/lib.rs b/zed/src/lib.rs index 2afc728ab0..cb4f3cb8e0 100644 --- a/zed/src/lib.rs +++ b/zed/src/lib.rs @@ -1,7 +1,7 @@ use anyhow::{anyhow, Context, Result}; use gpui::{AsyncAppContext, MutableAppContext, Task}; use rpc_client::RpcClient; -use std::{convert::TryFrom, net::Shutdown, time::Duration}; +use std::{convert::TryFrom, time::Duration}; use tiny_http::{Header, Response, Server}; use url::Url; use util::SurfResultExt; @@ -60,9 +60,7 @@ fn share_worktree(_: &(), cx: &mut MutableAppContext) { // a TLS stream using `native-tls`. let stream = smol::net::TcpStream::connect(rpc_address).await?; - let mut rpc_client = RpcClient::new(stream, executor, |stream| { - stream.shutdown(Shutdown::Read).ok(); - }); + let mut rpc_client = RpcClient::new(stream, executor); let auth_response = rpc_client .request(proto::from_client::Auth { diff --git a/zed/src/rpc_client.rs b/zed/src/rpc_client.rs index c9ca534a4f..003bf3f3a4 100644 --- a/zed/src/rpc_client.rs +++ b/zed/src/rpc_client.rs @@ -5,60 +5,81 @@ use postage::{ oneshot, prelude::{Sink, Stream}, }; -use smol::prelude::{AsyncRead, AsyncWrite}; +use smol::{ + future::FutureExt, + io::WriteHalf, + prelude::{AsyncRead, AsyncWrite}, +}; use std::{collections::HashMap, sync::Arc}; use zed_rpc::proto::{self, MessageStream, RequestMessage, SendMessage, ServerMessage}; -pub struct RpcClient -where - ShutdownFn: FnMut(&mut Conn), -{ - stream: MessageStream, +pub struct RpcClient { + stream: MessageStream>, response_channels: Arc>>>, next_message_id: i32, - shutdown_fn: ShutdownFn, + _drop_tx: oneshot::Sender<()>, } -impl RpcClient +impl RpcClient where Conn: Clone + AsyncRead + AsyncWrite + Unpin + Send + 'static, - ShutdownFn: FnMut(&mut Conn), { - pub fn new(conn: Conn, executor: Arc, shutdown_fn: ShutdownFn) -> Self { + pub fn new(conn: Conn, executor: Arc) -> Self { + let (conn_rx, conn_tx) = smol::io::split(conn); + let (drop_tx, mut drop_rx) = oneshot::channel(); let response_channels = Arc::new(Mutex::new(HashMap::new())); - - let result = Self { + let client = Self { next_message_id: 0, - stream: MessageStream::new(conn.clone()), + stream: MessageStream::new(conn_tx), response_channels: response_channels.clone(), - shutdown_fn, + _drop_tx: drop_tx, }; executor .spawn::, _>(async move { - let mut stream = MessageStream::new(conn); + enum Message { + Message(proto::FromServer), + ClientDropped, + } + + let mut stream = MessageStream::new(conn_rx); + let client_dropped = async move { + assert!(drop_rx.recv().await.is_none()); + Ok(Message::ClientDropped) as Result<_> + }; + smol::pin!(client_dropped); loop { - let message = stream.read_message::().await?; - if let Some(variant) = message.variant { - if let Some(request_id) = message.request_id { - let tx = response_channels.lock().remove(&request_id); - if let Some(mut tx) = tx { - tx.send(variant).await?; + let message = async { + Ok(Message::Message( + stream.read_message::().await?, + )) + }; + + match message.race(&mut client_dropped).await? { + Message::Message(message) => { + if let Some(variant) = message.variant { + if let Some(request_id) = message.request_id { + let tx = response_channels.lock().remove(&request_id); + if let Some(mut tx) = tx { + tx.send(variant).await?; + } else { + log::warn!( + "received RPC response to unknown request id {}", + request_id + ); + } + } } else { - log::warn!( - "received RPC response to unknown request id {}", - request_id - ); + log::warn!("received RPC message with no content"); } } - } else { - log::warn!("received RPC message with no content"); + Message::ClientDropped => break Ok(()), } } }) .detach(); - result + client } pub async fn request(&mut self, req: T) -> Result { @@ -87,11 +108,103 @@ where } } -impl Drop for RpcClient -where - ShutdownFn: FnMut(&mut Conn), -{ - fn drop(&mut self) { - (self.shutdown_fn)(self.stream.inner_mut()) +#[cfg(test)] +mod tests { + use super::*; + use smol::{ + future::poll_once, + io::AsyncWriteExt, + net::unix::{UnixListener, UnixStream}, + }; + use std::{future::Future, io}; + use tempdir::TempDir; + + #[gpui::test] + async fn test_request_response(cx: gpui::TestAppContext) { + let executor = cx.read(|app| app.background_executor().clone()); + let socket_dir_path = TempDir::new("request-response-socket").unwrap(); + let socket_path = socket_dir_path.path().join(".sock"); + let listener = UnixListener::bind(&socket_path).unwrap(); + let client_conn = UnixStream::connect(&socket_path).await.unwrap(); + let (server_conn, _) = listener.accept().await.unwrap(); + + let mut server_stream = MessageStream::new(server_conn); + let mut client = RpcClient::new(client_conn, executor.clone()); + + let client_req = client.request(proto::from_client::Auth { + user_id: 42, + access_token: "token".to_string(), + }); + smol::pin!(client_req); + let server_req = send_recv( + &mut client_req, + server_stream.read_message::(), + ) + .await + .unwrap(); + assert_eq!( + server_req.variant, + Some(proto::from_client::Variant::Auth( + proto::from_client::Auth { + user_id: 42, + access_token: "token".to_string() + } + )) + ); + + server_stream + .write_message(&proto::FromServer { + request_id: Some(server_req.id), + variant: Some(proto::from_server::Variant::AuthResponse( + proto::from_server::AuthResponse { + credentials_valid: true, + }, + )), + }) + .await + .unwrap(); + assert_eq!( + client_req.await.unwrap(), + proto::from_server::AuthResponse { + credentials_valid: true + } + ); + } + + #[gpui::test] + async fn test_drop_client(cx: gpui::TestAppContext) { + let executor = cx.read(|app| app.background_executor().clone()); + let socket_dir_path = TempDir::new("request-response-socket").unwrap(); + let socket_path = socket_dir_path.path().join(".sock"); + let listener = UnixListener::bind(&socket_path).unwrap(); + let client_conn = UnixStream::connect(&socket_path).await.unwrap(); + let (mut server_conn, _) = listener.accept().await.unwrap(); + + let client = RpcClient::new(client_conn, executor.clone()); + drop(client); + + // Try sending an empty payload over and over, until the client is dropped and hangs up. + let error = loop { + match server_conn.write(&[0]).await { + Ok(_) => continue, + Err(err) => break err, + } + }; + assert_eq!(error.kind(), io::ErrorKind::BrokenPipe); + } + + async fn send_recv(mut sender: S, receiver: R) -> O + where + S: Unpin + Future, + R: Future, + { + smol::pin!(receiver); + loop { + poll_once(&mut sender).await; + match poll_once(&mut receiver).await { + Some(message) => break message, + None => continue, + } + } } }