From c5cec247c46a39e3dbe3d2fe10ac3c812b682bcb Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Wed, 16 Jun 2021 21:18:22 -0700 Subject: [PATCH] Fix termination of peer's incoming future * Re-enable peer tests * Enhance request/response unit test to exercise peers interacting with each other end-to-end --- Cargo.lock | 1 + zed-rpc/Cargo.toml | 1 + zed-rpc/src/peer.rs | 367 +++++++++++++++++++++++++------------------- 3 files changed, 211 insertions(+), 158 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index f548d0e887..20216c7be1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4364,6 +4364,7 @@ dependencies = [ "rsa", "serde 1.0.125", "smol", + "tempdir", ] [[package]] diff --git a/zed-rpc/Cargo.toml b/zed-rpc/Cargo.toml index 7cc83a2af5..76c194c4b9 100644 --- a/zed-rpc/Cargo.toml +++ b/zed-rpc/Cargo.toml @@ -21,3 +21,4 @@ prost-build = { git="https://github.com/sfackler/prost", rev="082f3e65874fe91382 [dev-dependencies] smol = "1.2.5" +tempdir = "0.3.7" diff --git a/zed-rpc/src/peer.rs b/zed-rpc/src/peer.rs index 2638a2827f..333e4b1672 100644 --- a/zed-rpc/src/peer.rs +++ b/zed-rpc/src/peer.rs @@ -29,7 +29,6 @@ struct Connection { writer: Mutex>, response_channels: Mutex>>, next_message_id: AtomicU32, - _close_barrier: barrier::Sender, } type MessageHandler = Box< @@ -53,7 +52,7 @@ impl TypedEnvelope { } pub struct Peer { - connections: RwLock>>, + connections: RwLock, barrier::Sender)>>, message_handlers: RwLock>, handler_types: Mutex>, next_connection_id: AtomicU32, @@ -106,7 +105,7 @@ impl Peer { pub async fn add_connection( self: &Arc, conn: Conn, - ) -> (ConnectionId, impl Future) + ) -> (ConnectionId, impl Future>) where Conn: Clone + AsyncRead + AsyncWrite + Unpin + Send + 'static, { @@ -119,13 +118,12 @@ impl Peer { writer: Mutex::new(MessageStream::new(Box::pin(conn.clone()))), response_channels: Default::default(), next_message_id: Default::default(), - _close_barrier: close_tx, }); self.connections .write() .await - .insert(connection_id, connection.clone()); + .insert(connection_id, (connection.clone(), close_tx)); let this = self.clone(); let handler_future = async move { @@ -178,8 +176,9 @@ impl Peer { } Either::Left((Err(error), _)) => { log::warn!("received invalid RPC message: {}", error); + Err(error)?; } - Either::Right(_) => break, + Either::Right(_) => return Ok(()), } } }; @@ -199,13 +198,7 @@ impl Peer { let this = self.clone(); let (tx, mut rx) = oneshot::channel(); async move { - let connection = this - .connections - .read() - .await - .get(&connection_id) - .ok_or_else(|| anyhow!("unknown connection: {}", connection_id.0))? - .clone(); + let connection = this.connection(connection_id).await?; let message_id = connection .next_message_id .fetch_add(1, atomic::Ordering::SeqCst); @@ -236,13 +229,7 @@ impl Peer { ) -> impl Future> { let this = self.clone(); async move { - let connection = this - .connections - .read() - .await - .get(&connection_id) - .ok_or_else(|| anyhow!("unknown connection: {}", connection_id.0))? - .clone(); + let connection = this.connection(connection_id).await?; let message_id = connection .next_message_id .fetch_add(1, atomic::Ordering::SeqCst); @@ -263,13 +250,7 @@ impl Peer { ) -> impl Future> { let this = self.clone(); async move { - let connection = this - .connections - .read() - .await - .get(&request.connection_id) - .ok_or_else(|| anyhow!("unknown connection: {}", request.connection_id.0))? - .clone(); + let connection = this.connection(request.connection_id).await?; let message_id = connection .next_message_id .fetch_add(1, atomic::Ordering::SeqCst); @@ -282,146 +263,216 @@ impl Peer { Ok(()) } } + + async fn connection(&self, id: ConnectionId) -> Result> { + Ok(self + .connections + .read() + .await + .get(&id) + .ok_or_else(|| anyhow!("unknown connection: {}", id.0))? + .0 + .clone()) + } } -// #[cfg(test)] -// mod tests { -// use super::*; -// use smol::{ -// future::poll_once, -// io::AsyncWriteExt, -// net::unix::{UnixListener, UnixStream}, -// }; -// use std::{future::Future, io}; -// use tempdir::TempDir; +#[cfg(test)] +mod tests { + use super::*; + use smol::{ + io::AsyncWriteExt, + net::unix::{UnixListener, UnixStream}, + }; + use std::io; + use tempdir::TempDir; -// #[gpui::test] -// async fn test_request_response(cx: gpui::TestAppContext) { -// let executor = cx.read(|app| app.background_executor().clone()); -// let socket_dir_path = TempDir::new("request-response").unwrap(); -// let socket_path = socket_dir_path.path().join(".sock"); -// let listener = UnixListener::bind(&socket_path).unwrap(); -// let client_conn = UnixStream::connect(&socket_path).await.unwrap(); -// let (server_conn, _) = listener.accept().await.unwrap(); + #[test] + fn test_request_response() { + smol::block_on(async move { + // create socket + let socket_dir_path = TempDir::new("test-request-response").unwrap(); + let socket_path = socket_dir_path.path().join("test.sock"); + let listener = UnixListener::bind(&socket_path).unwrap(); -// let mut server_stream = MessageStream::new(server_conn); -// let client = Peer::new(); -// let (connection_id, handler) = client.add_connection(client_conn).await; -// executor.spawn(handler).detach(); + // create 2 clients connected to 1 server + let server = Peer::new(); + let client1 = Peer::new(); + let client2 = Peer::new(); + let (client1_conn_id, f1) = client1 + .add_connection(UnixStream::connect(&socket_path).await.unwrap()) + .await; + let (client2_conn_id, f2) = client2 + .add_connection(UnixStream::connect(&socket_path).await.unwrap()) + .await; + let (_, f3) = server + .add_connection(listener.accept().await.unwrap().0) + .await; + let (_, f4) = server + .add_connection(listener.accept().await.unwrap().0) + .await; + smol::spawn(f1).detach(); + smol::spawn(f2).detach(); + smol::spawn(f3).detach(); + smol::spawn(f4).detach(); -// let client_req = client.request( -// connection_id, -// proto::Auth { -// user_id: 42, -// access_token: "token".to_string(), -// }, -// ); -// smol::pin!(client_req); -// let server_req = send_recv(&mut client_req, server_stream.read_message()) -// .await -// .unwrap(); -// assert_eq!( -// server_req.payload, -// Some(proto::envelope::Payload::Auth(proto::Auth { -// user_id: 42, -// access_token: "token".to_string() -// })) -// ); + // define the expected requests and responses + let request1 = proto::OpenWorktree { + worktree_id: 101, + access_token: "first-worktree-access-token".to_string(), + }; + let response1 = proto::OpenWorktreeResponse { + worktree: Some(proto::Worktree { + paths: vec![b"path/one".to_vec()], + }), + }; + let request2 = proto::OpenWorktree { + worktree_id: 102, + access_token: "second-worktree-access-token".to_string(), + }; + let response2 = proto::OpenWorktreeResponse { + worktree: Some(proto::Worktree { + paths: vec![b"path/two".to_vec(), b"path/three".to_vec()], + }), + }; + let request3 = proto::OpenBuffer { + worktree_id: 102, + path: b"path/two".to_vec(), + }; + let response3 = proto::OpenBufferResponse { + buffer: Some(proto::Buffer { + id: 1001, + path: b"path/two".to_vec(), + content: b"path/two content".to_vec(), + history: vec![], + }), + }; + let request4 = proto::OpenBuffer { + worktree_id: 101, + path: b"path/one".to_vec(), + }; + let response4 = proto::OpenBufferResponse { + buffer: Some(proto::Buffer { + id: 1002, + path: b"path/one".to_vec(), + content: b"path/one content".to_vec(), + history: vec![], + }), + }; -// // Respond to another request to ensure requests are properly matched up. -// server_stream -// .write_message( -// &proto::AuthResponse { -// credentials_valid: false, -// } -// .into_envelope(1000, Some(999)), -// ) -// .await -// .unwrap(); -// server_stream -// .write_message( -// &proto::AuthResponse { -// credentials_valid: true, -// } -// .into_envelope(1001, Some(server_req.id)), -// ) -// .await -// .unwrap(); -// assert_eq!( -// client_req.await.unwrap(), -// proto::AuthResponse { -// credentials_valid: true -// } -// ); -// } + // on the server, respond to two requests for each client + let mut open_buffer_rx = server.add_message_handler::().await; + let mut open_worktree_rx = server.add_message_handler::().await; + let (mut server_done_tx, mut server_done_rx) = oneshot::channel::<()>(); + smol::spawn({ + let request1 = request1.clone(); + let request2 = request2.clone(); + let request3 = request3.clone(); + let request4 = request4.clone(); + let response1 = response1.clone(); + let response2 = response2.clone(); + let response3 = response3.clone(); + let response4 = response4.clone(); + async move { + let msg = open_worktree_rx.recv().await.unwrap(); + assert_eq!(msg.payload, request1); + server.respond(msg, response1.clone()).await.unwrap(); -// #[gpui::test] -// async fn test_disconnect(cx: gpui::TestAppContext) { -// let executor = cx.read(|app| app.background_executor().clone()); -// let socket_dir_path = TempDir::new("drop-client").unwrap(); -// let socket_path = socket_dir_path.path().join(".sock"); -// let listener = UnixListener::bind(&socket_path).unwrap(); -// let client_conn = UnixStream::connect(&socket_path).await.unwrap(); -// let (mut server_conn, _) = listener.accept().await.unwrap(); + let msg = open_worktree_rx.recv().await.unwrap(); + assert_eq!(msg.payload, request2.clone()); + server.respond(msg, response2.clone()).await.unwrap(); -// let client = Peer::new(); -// let (connection_id, handler) = client.add_connection(client_conn).await; -// executor.spawn(handler).detach(); -// client.disconnect(connection_id).await; + let msg = open_buffer_rx.recv().await.unwrap(); + assert_eq!(msg.payload, request3.clone()); + server.respond(msg, response3.clone()).await.unwrap(); -// // Try sending an empty payload over and over, until the client is dropped and hangs up. -// loop { -// match server_conn.write(&[]).await { -// Ok(_) => {} -// Err(err) => { -// if err.kind() == io::ErrorKind::BrokenPipe { -// break; -// } -// } -// } -// } -// } + let msg = open_buffer_rx.recv().await.unwrap(); + assert_eq!(msg.payload, request4.clone()); + server.respond(msg, response4.clone()).await.unwrap(); -// #[gpui::test] -// async fn test_io_error(cx: gpui::TestAppContext) { -// let executor = cx.read(|app| app.background_executor().clone()); -// let socket_dir_path = TempDir::new("io-error").unwrap(); -// let socket_path = socket_dir_path.path().join(".sock"); -// let _listener = UnixListener::bind(&socket_path).unwrap(); -// let mut client_conn = UnixStream::connect(&socket_path).await.unwrap(); -// client_conn.close().await.unwrap(); + server_done_tx.send(()).await.unwrap(); + } + }) + .detach(); -// let client = Peer::new(); -// let (connection_id, handler) = client.add_connection(client_conn).await; -// executor.spawn(handler).detach(); -// let err = client -// .request( -// connection_id, -// proto::Auth { -// user_id: 42, -// access_token: "token".to_string(), -// }, -// ) -// .await -// .unwrap_err(); -// assert_eq!( -// err.downcast_ref::().unwrap().kind(), -// io::ErrorKind::BrokenPipe -// ); -// } + assert_eq!( + client1.request(client1_conn_id, request1).await.unwrap(), + response1 + ); + assert_eq!( + client2.request(client2_conn_id, request2).await.unwrap(), + response2 + ); + assert_eq!( + client2.request(client2_conn_id, request3).await.unwrap(), + response3 + ); + assert_eq!( + client1.request(client1_conn_id, request4).await.unwrap(), + response4 + ); -// async fn send_recv(mut sender: S, receiver: R) -> O -// where -// S: Unpin + Future, -// R: Future, -// { -// smol::pin!(receiver); -// loop { -// poll_once(&mut sender).await; -// match poll_once(&mut receiver).await { -// Some(message) => break message, -// None => continue, -// } -// } -// } -// } + client1.disconnect(client1_conn_id).await; + client2.disconnect(client1_conn_id).await; + + server_done_rx.recv().await.unwrap(); + }); + } + + #[test] + fn test_disconnect() { + smol::block_on(async move { + let socket_dir_path = TempDir::new("drop-client").unwrap(); + let socket_path = socket_dir_path.path().join(".sock"); + let listener = UnixListener::bind(&socket_path).unwrap(); + let client_conn = UnixStream::connect(&socket_path).await.unwrap(); + let (mut server_conn, _) = listener.accept().await.unwrap(); + + let client = Peer::new(); + let (connection_id, handler) = client.add_connection(client_conn).await; + smol::spawn(handler).detach(); + client.disconnect(connection_id).await; + + // Try sending an empty payload over and over, until the client is dropped and hangs up. + loop { + match server_conn.write(&[]).await { + Ok(_) => {} + Err(err) => { + if err.kind() == io::ErrorKind::BrokenPipe { + break; + } + } + } + } + }); + } + + #[test] + fn test_io_error() { + smol::block_on(async move { + let socket_dir_path = TempDir::new("io-error").unwrap(); + let socket_path = socket_dir_path.path().join(".sock"); + let _listener = UnixListener::bind(&socket_path).unwrap(); + let mut client_conn = UnixStream::connect(&socket_path).await.unwrap(); + client_conn.close().await.unwrap(); + + let client = Peer::new(); + let (connection_id, handler) = client.add_connection(client_conn).await; + smol::spawn(handler).detach(); + + let err = client + .request( + connection_id, + proto::Auth { + user_id: 42, + access_token: "token".to_string(), + }, + ) + .await + .unwrap_err(); + assert_eq!( + err.downcast_ref::().unwrap().kind(), + io::ErrorKind::BrokenPipe + ); + }); + } +}