mirror of
https://github.com/zed-industries/zed.git
synced 2025-01-12 13:24:19 +00:00
Close connection when RpcClient
is dropped and add unit tests
This commit is contained in:
parent
b2b1ce5e81
commit
8e3f40bfdd
2 changed files with 149 additions and 38 deletions
|
@ -1,7 +1,7 @@
|
|||
use anyhow::{anyhow, Context, Result};
|
||||
use gpui::{AsyncAppContext, MutableAppContext, Task};
|
||||
use rpc_client::RpcClient;
|
||||
use std::{convert::TryFrom, net::Shutdown, time::Duration};
|
||||
use std::{convert::TryFrom, time::Duration};
|
||||
use tiny_http::{Header, Response, Server};
|
||||
use url::Url;
|
||||
use util::SurfResultExt;
|
||||
|
@ -60,9 +60,7 @@ fn share_worktree(_: &(), cx: &mut MutableAppContext) {
|
|||
// a TLS stream using `native-tls`.
|
||||
let stream = smol::net::TcpStream::connect(rpc_address).await?;
|
||||
|
||||
let mut rpc_client = RpcClient::new(stream, executor, |stream| {
|
||||
stream.shutdown(Shutdown::Read).ok();
|
||||
});
|
||||
let mut rpc_client = RpcClient::new(stream, executor);
|
||||
|
||||
let auth_response = rpc_client
|
||||
.request(proto::from_client::Auth {
|
||||
|
|
|
@ -5,60 +5,81 @@ use postage::{
|
|||
oneshot,
|
||||
prelude::{Sink, Stream},
|
||||
};
|
||||
use smol::prelude::{AsyncRead, AsyncWrite};
|
||||
use smol::{
|
||||
future::FutureExt,
|
||||
io::WriteHalf,
|
||||
prelude::{AsyncRead, AsyncWrite},
|
||||
};
|
||||
use std::{collections::HashMap, sync::Arc};
|
||||
use zed_rpc::proto::{self, MessageStream, RequestMessage, SendMessage, ServerMessage};
|
||||
|
||||
pub struct RpcClient<Conn, ShutdownFn>
|
||||
where
|
||||
ShutdownFn: FnMut(&mut Conn),
|
||||
{
|
||||
stream: MessageStream<Conn>,
|
||||
pub struct RpcClient<Conn> {
|
||||
stream: MessageStream<WriteHalf<Conn>>,
|
||||
response_channels: Arc<Mutex<HashMap<i32, oneshot::Sender<proto::from_server::Variant>>>>,
|
||||
next_message_id: i32,
|
||||
shutdown_fn: ShutdownFn,
|
||||
_drop_tx: oneshot::Sender<()>,
|
||||
}
|
||||
|
||||
impl<Conn, ShutdownFn> RpcClient<Conn, ShutdownFn>
|
||||
impl<Conn> RpcClient<Conn>
|
||||
where
|
||||
Conn: Clone + AsyncRead + AsyncWrite + Unpin + Send + 'static,
|
||||
ShutdownFn: FnMut(&mut Conn),
|
||||
{
|
||||
pub fn new(conn: Conn, executor: Arc<Background>, shutdown_fn: ShutdownFn) -> Self {
|
||||
pub fn new(conn: Conn, executor: Arc<Background>) -> Self {
|
||||
let (conn_rx, conn_tx) = smol::io::split(conn);
|
||||
let (drop_tx, mut drop_rx) = oneshot::channel();
|
||||
let response_channels = Arc::new(Mutex::new(HashMap::new()));
|
||||
|
||||
let result = Self {
|
||||
let client = Self {
|
||||
next_message_id: 0,
|
||||
stream: MessageStream::new(conn.clone()),
|
||||
stream: MessageStream::new(conn_tx),
|
||||
response_channels: response_channels.clone(),
|
||||
shutdown_fn,
|
||||
_drop_tx: drop_tx,
|
||||
};
|
||||
|
||||
executor
|
||||
.spawn::<Result<()>, _>(async move {
|
||||
let mut stream = MessageStream::new(conn);
|
||||
enum Message {
|
||||
Message(proto::FromServer),
|
||||
ClientDropped,
|
||||
}
|
||||
|
||||
let mut stream = MessageStream::new(conn_rx);
|
||||
let client_dropped = async move {
|
||||
assert!(drop_rx.recv().await.is_none());
|
||||
Ok(Message::ClientDropped) as Result<_>
|
||||
};
|
||||
smol::pin!(client_dropped);
|
||||
loop {
|
||||
let message = stream.read_message::<proto::FromServer>().await?;
|
||||
if let Some(variant) = message.variant {
|
||||
if let Some(request_id) = message.request_id {
|
||||
let tx = response_channels.lock().remove(&request_id);
|
||||
if let Some(mut tx) = tx {
|
||||
tx.send(variant).await?;
|
||||
let message = async {
|
||||
Ok(Message::Message(
|
||||
stream.read_message::<proto::FromServer>().await?,
|
||||
))
|
||||
};
|
||||
|
||||
match message.race(&mut client_dropped).await? {
|
||||
Message::Message(message) => {
|
||||
if let Some(variant) = message.variant {
|
||||
if let Some(request_id) = message.request_id {
|
||||
let tx = response_channels.lock().remove(&request_id);
|
||||
if let Some(mut tx) = tx {
|
||||
tx.send(variant).await?;
|
||||
} else {
|
||||
log::warn!(
|
||||
"received RPC response to unknown request id {}",
|
||||
request_id
|
||||
);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
log::warn!(
|
||||
"received RPC response to unknown request id {}",
|
||||
request_id
|
||||
);
|
||||
log::warn!("received RPC message with no content");
|
||||
}
|
||||
}
|
||||
} else {
|
||||
log::warn!("received RPC message with no content");
|
||||
Message::ClientDropped => break Ok(()),
|
||||
}
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
|
||||
result
|
||||
client
|
||||
}
|
||||
|
||||
pub async fn request<T: RequestMessage>(&mut self, req: T) -> Result<T::Response> {
|
||||
|
@ -87,11 +108,103 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
impl<Conn, ShutdownFn> Drop for RpcClient<Conn, ShutdownFn>
|
||||
where
|
||||
ShutdownFn: FnMut(&mut Conn),
|
||||
{
|
||||
fn drop(&mut self) {
|
||||
(self.shutdown_fn)(self.stream.inner_mut())
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use smol::{
|
||||
future::poll_once,
|
||||
io::AsyncWriteExt,
|
||||
net::unix::{UnixListener, UnixStream},
|
||||
};
|
||||
use std::{future::Future, io};
|
||||
use tempdir::TempDir;
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_request_response(cx: gpui::TestAppContext) {
|
||||
let executor = cx.read(|app| app.background_executor().clone());
|
||||
let socket_dir_path = TempDir::new("request-response-socket").unwrap();
|
||||
let socket_path = socket_dir_path.path().join(".sock");
|
||||
let listener = UnixListener::bind(&socket_path).unwrap();
|
||||
let client_conn = UnixStream::connect(&socket_path).await.unwrap();
|
||||
let (server_conn, _) = listener.accept().await.unwrap();
|
||||
|
||||
let mut server_stream = MessageStream::new(server_conn);
|
||||
let mut client = RpcClient::new(client_conn, executor.clone());
|
||||
|
||||
let client_req = client.request(proto::from_client::Auth {
|
||||
user_id: 42,
|
||||
access_token: "token".to_string(),
|
||||
});
|
||||
smol::pin!(client_req);
|
||||
let server_req = send_recv(
|
||||
&mut client_req,
|
||||
server_stream.read_message::<proto::FromClient>(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(
|
||||
server_req.variant,
|
||||
Some(proto::from_client::Variant::Auth(
|
||||
proto::from_client::Auth {
|
||||
user_id: 42,
|
||||
access_token: "token".to_string()
|
||||
}
|
||||
))
|
||||
);
|
||||
|
||||
server_stream
|
||||
.write_message(&proto::FromServer {
|
||||
request_id: Some(server_req.id),
|
||||
variant: Some(proto::from_server::Variant::AuthResponse(
|
||||
proto::from_server::AuthResponse {
|
||||
credentials_valid: true,
|
||||
},
|
||||
)),
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(
|
||||
client_req.await.unwrap(),
|
||||
proto::from_server::AuthResponse {
|
||||
credentials_valid: true
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_drop_client(cx: gpui::TestAppContext) {
|
||||
let executor = cx.read(|app| app.background_executor().clone());
|
||||
let socket_dir_path = TempDir::new("request-response-socket").unwrap();
|
||||
let socket_path = socket_dir_path.path().join(".sock");
|
||||
let listener = UnixListener::bind(&socket_path).unwrap();
|
||||
let client_conn = UnixStream::connect(&socket_path).await.unwrap();
|
||||
let (mut server_conn, _) = listener.accept().await.unwrap();
|
||||
|
||||
let client = RpcClient::new(client_conn, executor.clone());
|
||||
drop(client);
|
||||
|
||||
// Try sending an empty payload over and over, until the client is dropped and hangs up.
|
||||
let error = loop {
|
||||
match server_conn.write(&[0]).await {
|
||||
Ok(_) => continue,
|
||||
Err(err) => break err,
|
||||
}
|
||||
};
|
||||
assert_eq!(error.kind(), io::ErrorKind::BrokenPipe);
|
||||
}
|
||||
|
||||
async fn send_recv<S, R, O>(mut sender: S, receiver: R) -> O
|
||||
where
|
||||
S: Unpin + Future,
|
||||
R: Future<Output = O>,
|
||||
{
|
||||
smol::pin!(receiver);
|
||||
loop {
|
||||
poll_once(&mut sender).await;
|
||||
match poll_once(&mut receiver).await {
|
||||
Some(message) => break message,
|
||||
None => continue,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue