Close connection when RpcClient is dropped and add unit tests

This commit is contained in:
Antonio Scandurra 2021-06-14 14:32:49 +02:00
parent b2b1ce5e81
commit 8e3f40bfdd
2 changed files with 149 additions and 38 deletions

View file

@ -1,7 +1,7 @@
use anyhow::{anyhow, Context, Result};
use gpui::{AsyncAppContext, MutableAppContext, Task};
use rpc_client::RpcClient;
use std::{convert::TryFrom, net::Shutdown, time::Duration};
use std::{convert::TryFrom, time::Duration};
use tiny_http::{Header, Response, Server};
use url::Url;
use util::SurfResultExt;
@ -60,9 +60,7 @@ fn share_worktree(_: &(), cx: &mut MutableAppContext) {
// a TLS stream using `native-tls`.
let stream = smol::net::TcpStream::connect(rpc_address).await?;
let mut rpc_client = RpcClient::new(stream, executor, |stream| {
stream.shutdown(Shutdown::Read).ok();
});
let mut rpc_client = RpcClient::new(stream, executor);
let auth_response = rpc_client
.request(proto::from_client::Auth {

View file

@ -5,60 +5,81 @@ use postage::{
oneshot,
prelude::{Sink, Stream},
};
use smol::prelude::{AsyncRead, AsyncWrite};
use smol::{
future::FutureExt,
io::WriteHalf,
prelude::{AsyncRead, AsyncWrite},
};
use std::{collections::HashMap, sync::Arc};
use zed_rpc::proto::{self, MessageStream, RequestMessage, SendMessage, ServerMessage};
pub struct RpcClient<Conn, ShutdownFn>
where
ShutdownFn: FnMut(&mut Conn),
{
stream: MessageStream<Conn>,
pub struct RpcClient<Conn> {
stream: MessageStream<WriteHalf<Conn>>,
response_channels: Arc<Mutex<HashMap<i32, oneshot::Sender<proto::from_server::Variant>>>>,
next_message_id: i32,
shutdown_fn: ShutdownFn,
_drop_tx: oneshot::Sender<()>,
}
impl<Conn, ShutdownFn> RpcClient<Conn, ShutdownFn>
impl<Conn> RpcClient<Conn>
where
Conn: Clone + AsyncRead + AsyncWrite + Unpin + Send + 'static,
ShutdownFn: FnMut(&mut Conn),
{
pub fn new(conn: Conn, executor: Arc<Background>, shutdown_fn: ShutdownFn) -> Self {
pub fn new(conn: Conn, executor: Arc<Background>) -> Self {
let (conn_rx, conn_tx) = smol::io::split(conn);
let (drop_tx, mut drop_rx) = oneshot::channel();
let response_channels = Arc::new(Mutex::new(HashMap::new()));
let result = Self {
let client = Self {
next_message_id: 0,
stream: MessageStream::new(conn.clone()),
stream: MessageStream::new(conn_tx),
response_channels: response_channels.clone(),
shutdown_fn,
_drop_tx: drop_tx,
};
executor
.spawn::<Result<()>, _>(async move {
let mut stream = MessageStream::new(conn);
enum Message {
Message(proto::FromServer),
ClientDropped,
}
let mut stream = MessageStream::new(conn_rx);
let client_dropped = async move {
assert!(drop_rx.recv().await.is_none());
Ok(Message::ClientDropped) as Result<_>
};
smol::pin!(client_dropped);
loop {
let message = stream.read_message::<proto::FromServer>().await?;
if let Some(variant) = message.variant {
if let Some(request_id) = message.request_id {
let tx = response_channels.lock().remove(&request_id);
if let Some(mut tx) = tx {
tx.send(variant).await?;
let message = async {
Ok(Message::Message(
stream.read_message::<proto::FromServer>().await?,
))
};
match message.race(&mut client_dropped).await? {
Message::Message(message) => {
if let Some(variant) = message.variant {
if let Some(request_id) = message.request_id {
let tx = response_channels.lock().remove(&request_id);
if let Some(mut tx) = tx {
tx.send(variant).await?;
} else {
log::warn!(
"received RPC response to unknown request id {}",
request_id
);
}
}
} else {
log::warn!(
"received RPC response to unknown request id {}",
request_id
);
log::warn!("received RPC message with no content");
}
}
} else {
log::warn!("received RPC message with no content");
Message::ClientDropped => break Ok(()),
}
}
})
.detach();
result
client
}
pub async fn request<T: RequestMessage>(&mut self, req: T) -> Result<T::Response> {
@ -87,11 +108,103 @@ where
}
}
impl<Conn, ShutdownFn> Drop for RpcClient<Conn, ShutdownFn>
where
ShutdownFn: FnMut(&mut Conn),
{
fn drop(&mut self) {
(self.shutdown_fn)(self.stream.inner_mut())
#[cfg(test)]
mod tests {
use super::*;
use smol::{
future::poll_once,
io::AsyncWriteExt,
net::unix::{UnixListener, UnixStream},
};
use std::{future::Future, io};
use tempdir::TempDir;
#[gpui::test]
async fn test_request_response(cx: gpui::TestAppContext) {
let executor = cx.read(|app| app.background_executor().clone());
let socket_dir_path = TempDir::new("request-response-socket").unwrap();
let socket_path = socket_dir_path.path().join(".sock");
let listener = UnixListener::bind(&socket_path).unwrap();
let client_conn = UnixStream::connect(&socket_path).await.unwrap();
let (server_conn, _) = listener.accept().await.unwrap();
let mut server_stream = MessageStream::new(server_conn);
let mut client = RpcClient::new(client_conn, executor.clone());
let client_req = client.request(proto::from_client::Auth {
user_id: 42,
access_token: "token".to_string(),
});
smol::pin!(client_req);
let server_req = send_recv(
&mut client_req,
server_stream.read_message::<proto::FromClient>(),
)
.await
.unwrap();
assert_eq!(
server_req.variant,
Some(proto::from_client::Variant::Auth(
proto::from_client::Auth {
user_id: 42,
access_token: "token".to_string()
}
))
);
server_stream
.write_message(&proto::FromServer {
request_id: Some(server_req.id),
variant: Some(proto::from_server::Variant::AuthResponse(
proto::from_server::AuthResponse {
credentials_valid: true,
},
)),
})
.await
.unwrap();
assert_eq!(
client_req.await.unwrap(),
proto::from_server::AuthResponse {
credentials_valid: true
}
);
}
#[gpui::test]
async fn test_drop_client(cx: gpui::TestAppContext) {
let executor = cx.read(|app| app.background_executor().clone());
let socket_dir_path = TempDir::new("request-response-socket").unwrap();
let socket_path = socket_dir_path.path().join(".sock");
let listener = UnixListener::bind(&socket_path).unwrap();
let client_conn = UnixStream::connect(&socket_path).await.unwrap();
let (mut server_conn, _) = listener.accept().await.unwrap();
let client = RpcClient::new(client_conn, executor.clone());
drop(client);
// Try sending an empty payload over and over, until the client is dropped and hangs up.
let error = loop {
match server_conn.write(&[0]).await {
Ok(_) => continue,
Err(err) => break err,
}
};
assert_eq!(error.kind(), io::ErrorKind::BrokenPipe);
}
async fn send_recv<S, R, O>(mut sender: S, receiver: R) -> O
where
S: Unpin + Future,
R: Future<Output = O>,
{
smol::pin!(receiver);
loop {
poll_once(&mut sender).await;
match poll_once(&mut receiver).await {
Some(message) => break message,
None => continue,
}
}
}
}