Fix termination of peer's incoming future

* Re-enable peer tests
* Enhance request/response unit test to exercise
  peers interacting with each other end-to-end
This commit is contained in:
Max Brunsfeld 2021-06-16 21:18:22 -07:00
parent fb736d5e78
commit c5cec247c4
3 changed files with 211 additions and 158 deletions

1
Cargo.lock generated
View file

@ -4364,6 +4364,7 @@ dependencies = [
"rsa",
"serde 1.0.125",
"smol",
"tempdir",
]
[[package]]

View file

@ -21,3 +21,4 @@ prost-build = { git="https://github.com/sfackler/prost", rev="082f3e65874fe91382
[dev-dependencies]
smol = "1.2.5"
tempdir = "0.3.7"

View file

@ -29,7 +29,6 @@ struct Connection {
writer: Mutex<MessageStream<BoxedWriter>>,
response_channels: Mutex<HashMap<u32, oneshot::Sender<proto::Envelope>>>,
next_message_id: AtomicU32,
_close_barrier: barrier::Sender,
}
type MessageHandler = Box<
@ -53,7 +52,7 @@ impl<T> TypedEnvelope<T> {
}
pub struct Peer {
connections: RwLock<HashMap<ConnectionId, Arc<Connection>>>,
connections: RwLock<HashMap<ConnectionId, (Arc<Connection>, barrier::Sender)>>,
message_handlers: RwLock<Vec<MessageHandler>>,
handler_types: Mutex<HashSet<TypeId>>,
next_connection_id: AtomicU32,
@ -106,7 +105,7 @@ impl Peer {
pub async fn add_connection<Conn>(
self: &Arc<Self>,
conn: Conn,
) -> (ConnectionId, impl Future<Output = ()>)
) -> (ConnectionId, impl Future<Output = Result<()>>)
where
Conn: Clone + AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
@ -119,13 +118,12 @@ impl Peer {
writer: Mutex::new(MessageStream::new(Box::pin(conn.clone()))),
response_channels: Default::default(),
next_message_id: Default::default(),
_close_barrier: close_tx,
});
self.connections
.write()
.await
.insert(connection_id, connection.clone());
.insert(connection_id, (connection.clone(), close_tx));
let this = self.clone();
let handler_future = async move {
@ -178,8 +176,9 @@ impl Peer {
}
Either::Left((Err(error), _)) => {
log::warn!("received invalid RPC message: {}", error);
Err(error)?;
}
Either::Right(_) => break,
Either::Right(_) => return Ok(()),
}
}
};
@ -199,13 +198,7 @@ impl Peer {
let this = self.clone();
let (tx, mut rx) = oneshot::channel();
async move {
let connection = this
.connections
.read()
.await
.get(&connection_id)
.ok_or_else(|| anyhow!("unknown connection: {}", connection_id.0))?
.clone();
let connection = this.connection(connection_id).await?;
let message_id = connection
.next_message_id
.fetch_add(1, atomic::Ordering::SeqCst);
@ -236,13 +229,7 @@ impl Peer {
) -> impl Future<Output = Result<()>> {
let this = self.clone();
async move {
let connection = this
.connections
.read()
.await
.get(&connection_id)
.ok_or_else(|| anyhow!("unknown connection: {}", connection_id.0))?
.clone();
let connection = this.connection(connection_id).await?;
let message_id = connection
.next_message_id
.fetch_add(1, atomic::Ordering::SeqCst);
@ -263,13 +250,7 @@ impl Peer {
) -> impl Future<Output = Result<()>> {
let this = self.clone();
async move {
let connection = this
.connections
.read()
.await
.get(&request.connection_id)
.ok_or_else(|| anyhow!("unknown connection: {}", request.connection_id.0))?
.clone();
let connection = this.connection(request.connection_id).await?;
let message_id = connection
.next_message_id
.fetch_add(1, atomic::Ordering::SeqCst);
@ -282,146 +263,216 @@ impl Peer {
Ok(())
}
}
async fn connection(&self, id: ConnectionId) -> Result<Arc<Connection>> {
Ok(self
.connections
.read()
.await
.get(&id)
.ok_or_else(|| anyhow!("unknown connection: {}", id.0))?
.0
.clone())
}
}
// #[cfg(test)]
// mod tests {
// use super::*;
// use smol::{
// future::poll_once,
// io::AsyncWriteExt,
// net::unix::{UnixListener, UnixStream},
// };
// use std::{future::Future, io};
// use tempdir::TempDir;
#[cfg(test)]
mod tests {
use super::*;
use smol::{
io::AsyncWriteExt,
net::unix::{UnixListener, UnixStream},
};
use std::io;
use tempdir::TempDir;
// #[gpui::test]
// async fn test_request_response(cx: gpui::TestAppContext) {
// let executor = cx.read(|app| app.background_executor().clone());
// let socket_dir_path = TempDir::new("request-response").unwrap();
// let socket_path = socket_dir_path.path().join(".sock");
// let listener = UnixListener::bind(&socket_path).unwrap();
// let client_conn = UnixStream::connect(&socket_path).await.unwrap();
// let (server_conn, _) = listener.accept().await.unwrap();
#[test]
fn test_request_response() {
smol::block_on(async move {
// create socket
let socket_dir_path = TempDir::new("test-request-response").unwrap();
let socket_path = socket_dir_path.path().join("test.sock");
let listener = UnixListener::bind(&socket_path).unwrap();
// let mut server_stream = MessageStream::new(server_conn);
// let client = Peer::new();
// let (connection_id, handler) = client.add_connection(client_conn).await;
// executor.spawn(handler).detach();
// create 2 clients connected to 1 server
let server = Peer::new();
let client1 = Peer::new();
let client2 = Peer::new();
let (client1_conn_id, f1) = client1
.add_connection(UnixStream::connect(&socket_path).await.unwrap())
.await;
let (client2_conn_id, f2) = client2
.add_connection(UnixStream::connect(&socket_path).await.unwrap())
.await;
let (_, f3) = server
.add_connection(listener.accept().await.unwrap().0)
.await;
let (_, f4) = server
.add_connection(listener.accept().await.unwrap().0)
.await;
smol::spawn(f1).detach();
smol::spawn(f2).detach();
smol::spawn(f3).detach();
smol::spawn(f4).detach();
// let client_req = client.request(
// connection_id,
// proto::Auth {
// user_id: 42,
// access_token: "token".to_string(),
// },
// );
// smol::pin!(client_req);
// let server_req = send_recv(&mut client_req, server_stream.read_message())
// .await
// .unwrap();
// assert_eq!(
// server_req.payload,
// Some(proto::envelope::Payload::Auth(proto::Auth {
// user_id: 42,
// access_token: "token".to_string()
// }))
// );
// define the expected requests and responses
let request1 = proto::OpenWorktree {
worktree_id: 101,
access_token: "first-worktree-access-token".to_string(),
};
let response1 = proto::OpenWorktreeResponse {
worktree: Some(proto::Worktree {
paths: vec![b"path/one".to_vec()],
}),
};
let request2 = proto::OpenWorktree {
worktree_id: 102,
access_token: "second-worktree-access-token".to_string(),
};
let response2 = proto::OpenWorktreeResponse {
worktree: Some(proto::Worktree {
paths: vec![b"path/two".to_vec(), b"path/three".to_vec()],
}),
};
let request3 = proto::OpenBuffer {
worktree_id: 102,
path: b"path/two".to_vec(),
};
let response3 = proto::OpenBufferResponse {
buffer: Some(proto::Buffer {
id: 1001,
path: b"path/two".to_vec(),
content: b"path/two content".to_vec(),
history: vec![],
}),
};
let request4 = proto::OpenBuffer {
worktree_id: 101,
path: b"path/one".to_vec(),
};
let response4 = proto::OpenBufferResponse {
buffer: Some(proto::Buffer {
id: 1002,
path: b"path/one".to_vec(),
content: b"path/one content".to_vec(),
history: vec![],
}),
};
// // Respond to another request to ensure requests are properly matched up.
// server_stream
// .write_message(
// &proto::AuthResponse {
// credentials_valid: false,
// }
// .into_envelope(1000, Some(999)),
// )
// .await
// .unwrap();
// server_stream
// .write_message(
// &proto::AuthResponse {
// credentials_valid: true,
// }
// .into_envelope(1001, Some(server_req.id)),
// )
// .await
// .unwrap();
// assert_eq!(
// client_req.await.unwrap(),
// proto::AuthResponse {
// credentials_valid: true
// }
// );
// }
// on the server, respond to two requests for each client
let mut open_buffer_rx = server.add_message_handler::<proto::OpenBuffer>().await;
let mut open_worktree_rx = server.add_message_handler::<proto::OpenWorktree>().await;
let (mut server_done_tx, mut server_done_rx) = oneshot::channel::<()>();
smol::spawn({
let request1 = request1.clone();
let request2 = request2.clone();
let request3 = request3.clone();
let request4 = request4.clone();
let response1 = response1.clone();
let response2 = response2.clone();
let response3 = response3.clone();
let response4 = response4.clone();
async move {
let msg = open_worktree_rx.recv().await.unwrap();
assert_eq!(msg.payload, request1);
server.respond(msg, response1.clone()).await.unwrap();
// #[gpui::test]
// async fn test_disconnect(cx: gpui::TestAppContext) {
// let executor = cx.read(|app| app.background_executor().clone());
// let socket_dir_path = TempDir::new("drop-client").unwrap();
// let socket_path = socket_dir_path.path().join(".sock");
// let listener = UnixListener::bind(&socket_path).unwrap();
// let client_conn = UnixStream::connect(&socket_path).await.unwrap();
// let (mut server_conn, _) = listener.accept().await.unwrap();
let msg = open_worktree_rx.recv().await.unwrap();
assert_eq!(msg.payload, request2.clone());
server.respond(msg, response2.clone()).await.unwrap();
// let client = Peer::new();
// let (connection_id, handler) = client.add_connection(client_conn).await;
// executor.spawn(handler).detach();
// client.disconnect(connection_id).await;
let msg = open_buffer_rx.recv().await.unwrap();
assert_eq!(msg.payload, request3.clone());
server.respond(msg, response3.clone()).await.unwrap();
// // Try sending an empty payload over and over, until the client is dropped and hangs up.
// loop {
// match server_conn.write(&[]).await {
// Ok(_) => {}
// Err(err) => {
// if err.kind() == io::ErrorKind::BrokenPipe {
// break;
// }
// }
// }
// }
// }
let msg = open_buffer_rx.recv().await.unwrap();
assert_eq!(msg.payload, request4.clone());
server.respond(msg, response4.clone()).await.unwrap();
// #[gpui::test]
// async fn test_io_error(cx: gpui::TestAppContext) {
// let executor = cx.read(|app| app.background_executor().clone());
// let socket_dir_path = TempDir::new("io-error").unwrap();
// let socket_path = socket_dir_path.path().join(".sock");
// let _listener = UnixListener::bind(&socket_path).unwrap();
// let mut client_conn = UnixStream::connect(&socket_path).await.unwrap();
// client_conn.close().await.unwrap();
server_done_tx.send(()).await.unwrap();
}
})
.detach();
// let client = Peer::new();
// let (connection_id, handler) = client.add_connection(client_conn).await;
// executor.spawn(handler).detach();
// let err = client
// .request(
// connection_id,
// proto::Auth {
// user_id: 42,
// access_token: "token".to_string(),
// },
// )
// .await
// .unwrap_err();
// assert_eq!(
// err.downcast_ref::<io::Error>().unwrap().kind(),
// io::ErrorKind::BrokenPipe
// );
// }
assert_eq!(
client1.request(client1_conn_id, request1).await.unwrap(),
response1
);
assert_eq!(
client2.request(client2_conn_id, request2).await.unwrap(),
response2
);
assert_eq!(
client2.request(client2_conn_id, request3).await.unwrap(),
response3
);
assert_eq!(
client1.request(client1_conn_id, request4).await.unwrap(),
response4
);
// async fn send_recv<S, R, O>(mut sender: S, receiver: R) -> O
// where
// S: Unpin + Future,
// R: Future<Output = O>,
// {
// smol::pin!(receiver);
// loop {
// poll_once(&mut sender).await;
// match poll_once(&mut receiver).await {
// Some(message) => break message,
// None => continue,
// }
// }
// }
// }
client1.disconnect(client1_conn_id).await;
client2.disconnect(client1_conn_id).await;
server_done_rx.recv().await.unwrap();
});
}
#[test]
fn test_disconnect() {
smol::block_on(async move {
let socket_dir_path = TempDir::new("drop-client").unwrap();
let socket_path = socket_dir_path.path().join(".sock");
let listener = UnixListener::bind(&socket_path).unwrap();
let client_conn = UnixStream::connect(&socket_path).await.unwrap();
let (mut server_conn, _) = listener.accept().await.unwrap();
let client = Peer::new();
let (connection_id, handler) = client.add_connection(client_conn).await;
smol::spawn(handler).detach();
client.disconnect(connection_id).await;
// Try sending an empty payload over and over, until the client is dropped and hangs up.
loop {
match server_conn.write(&[]).await {
Ok(_) => {}
Err(err) => {
if err.kind() == io::ErrorKind::BrokenPipe {
break;
}
}
}
}
});
}
#[test]
fn test_io_error() {
smol::block_on(async move {
let socket_dir_path = TempDir::new("io-error").unwrap();
let socket_path = socket_dir_path.path().join(".sock");
let _listener = UnixListener::bind(&socket_path).unwrap();
let mut client_conn = UnixStream::connect(&socket_path).await.unwrap();
client_conn.close().await.unwrap();
let client = Peer::new();
let (connection_id, handler) = client.add_connection(client_conn).await;
smol::spawn(handler).detach();
let err = client
.request(
connection_id,
proto::Auth {
user_id: 42,
access_token: "token".to_string(),
},
)
.await
.unwrap_err();
assert_eq!(
err.downcast_ref::<io::Error>().unwrap().kind(),
io::ErrorKind::BrokenPipe
);
});
}
}