From 9d51fe88e987511f7d52f978a82cfb9ae0e503ec Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Thu, 1 Jul 2021 18:12:46 -0700 Subject: [PATCH] Serialize RPC sends and responses using a channel --- zed-rpc/src/peer.rs | 344 +++++++++++++++++++++---------------------- zed-rpc/src/proto.rs | 25 +++- zed/src/rpc.rs | 6 +- 3 files changed, 190 insertions(+), 185 deletions(-) diff --git a/zed-rpc/src/peer.rs b/zed-rpc/src/peer.rs index 54baafff3e..bdab225c69 100644 --- a/zed-rpc/src/peer.rs +++ b/zed-rpc/src/peer.rs @@ -1,12 +1,9 @@ use crate::proto::{self, EnvelopedMessage, MessageStream, RequestMessage}; -use anyhow::{anyhow, Result}; +use anyhow::{anyhow, Context, Result}; use async_lock::{Mutex, RwLock}; -use futures::{ - future::{BoxFuture, Either}, - AsyncRead, AsyncWrite, FutureExt, -}; +use futures::{future::BoxFuture, AsyncRead, AsyncWrite, FutureExt}; use postage::{ - barrier, mpsc, oneshot, + mpsc, prelude::{Sink, Stream}, }; use std::{ @@ -15,29 +12,18 @@ use std::{ fmt, future::Future, marker::PhantomData, - pin::Pin, sync::{ atomic::{self, AtomicU32}, Arc, }, }; -type BoxedWriter = Pin>; -type BoxedReader = Pin>; - #[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)] pub struct ConnectionId(pub u32); #[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)] pub struct PeerId(pub u32); -struct Connection { - writer: Mutex>, - reader: Mutex>, - response_channels: Mutex>>, - next_message_id: AtomicU32, -} - type MessageHandler = Box< dyn Send + Sync + Fn(&mut Option, ConnectionId) -> Option>, >; @@ -74,18 +60,34 @@ impl TypedEnvelope { } pub struct Peer { - connections: RwLock>>, - connection_close_barriers: RwLock>, + connections: RwLock>, message_handlers: RwLock>, handler_types: Mutex>, next_connection_id: AtomicU32, } +#[derive(Clone)] +struct Connection { + outgoing_tx: mpsc::Sender, + next_message_id: Arc, + response_channels: ResponseChannels, +} + +pub struct ConnectionHandler { + peer: Arc, + connection_id: ConnectionId, + response_channels: ResponseChannels, + outgoing_rx: mpsc::Receiver, + reader: MessageStream, + writer: MessageStream, +} + +type ResponseChannels = Arc>>>; + impl Peer { pub fn new() -> Arc { Arc::new(Self { connections: Default::default(), - connection_close_barriers: Default::default(), message_handlers: Default::default(), handler_types: Default::default(), next_connection_id: Default::default(), @@ -127,7 +129,10 @@ impl Peer { rx } - pub async fn add_connection(self: &Arc, conn: Conn) -> ConnectionId + pub async fn add_connection( + self: &Arc, + conn: Conn, + ) -> (ConnectionId, ConnectionHandler) where Conn: Clone + AsyncRead + AsyncWrite + Unpin + Send + 'static, { @@ -135,120 +140,37 @@ impl Peer { self.next_connection_id .fetch_add(1, atomic::Ordering::SeqCst), ); - self.connections.write().await.insert( + let (outgoing_tx, outgoing_rx) = mpsc::channel(64); + let connection = Connection { + outgoing_tx, + next_message_id: Default::default(), + response_channels: Default::default(), + }; + let handler = ConnectionHandler { + peer: self.clone(), connection_id, - Arc::new(Connection { - reader: Mutex::new(MessageStream::new(Box::pin(conn.clone()))), - writer: Mutex::new(MessageStream::new(Box::pin(conn.clone()))), - response_channels: Default::default(), - next_message_id: Default::default(), - }), - ); - connection_id + response_channels: connection.response_channels.clone(), + outgoing_rx, + reader: MessageStream::new(conn.clone()), + writer: MessageStream::new(conn), + }; + self.connections + .write() + .await + .insert(connection_id, connection); + (connection_id, handler) } pub async fn disconnect(&self, connection_id: ConnectionId) { self.connections.write().await.remove(&connection_id); - self.connection_close_barriers - .write() - .await - .remove(&connection_id); } pub async fn reset(&self) { self.connections.write().await.clear(); - self.connection_close_barriers.write().await.clear(); self.handler_types.lock().await.clear(); self.message_handlers.write().await.clear(); } - pub fn handle_messages( - self: &Arc, - connection_id: ConnectionId, - ) -> impl Future> + 'static { - let (close_tx, mut close_rx) = barrier::channel(); - let this = self.clone(); - async move { - this.connection_close_barriers - .write() - .await - .insert(connection_id, close_tx); - let connection = this.connection(connection_id).await?; - let closed = close_rx.recv(); - futures::pin_mut!(closed); - - loop { - let mut reader = connection.reader.lock().await; - let read_message = reader.read_message(); - futures::pin_mut!(read_message); - - match futures::future::select(read_message, &mut closed).await { - Either::Left((Ok(incoming), _)) => { - if let Some(responding_to) = incoming.responding_to { - let channel = connection - .response_channels - .lock() - .await - .remove(&responding_to); - if let Some(mut tx) = channel { - tx.send(incoming).await.ok(); - } else { - log::warn!( - "received RPC response to unknown request {}", - responding_to - ); - } - } else { - let mut envelope = Some(incoming); - let mut handler_index = None; - let mut handler_was_dropped = false; - for (i, handler) in - this.message_handlers.read().await.iter().enumerate() - { - if let Some(future) = handler(&mut envelope, connection_id) { - handler_was_dropped = future.await; - handler_index = Some(i); - break; - } - } - - if let Some(handler_index) = handler_index { - if handler_was_dropped { - drop(this.message_handlers.write().await.remove(handler_index)); - } - } else { - log::warn!("unhandled message: {:?}", envelope.unwrap().payload); - } - } - } - Either::Left((Err(error), _)) => { - log::warn!("received invalid RPC message: {}", error); - Err(error)?; - } - Either::Right(_) => return Ok(()), - } - } - } - } - - pub async fn receive( - self: &Arc, - connection_id: ConnectionId, - ) -> Result> { - let connection = self.connection(connection_id).await?; - let envelope = connection.reader.lock().await.read_message().await?; - let original_sender_id = envelope.original_sender_id; - let message_id = envelope.id; - let payload = - M::from_envelope(envelope).ok_or_else(|| anyhow!("unexpected message type"))?; - Ok(TypedEnvelope { - sender_id: connection_id, - original_sender_id: original_sender_id.map(PeerId), - message_id, - payload, - }) - } - pub fn request( self: &Arc, receiver_id: ConnectionId, @@ -273,9 +195,9 @@ impl Peer { request: T, ) -> impl Future> { let this = self.clone(); - let (tx, mut rx) = oneshot::channel(); + let (tx, mut rx) = mpsc::channel(1); async move { - let connection = this.connection(receiver_id).await?; + let mut connection = this.connection(receiver_id).await?; let message_id = connection .next_message_id .fetch_add(1, atomic::Ordering::SeqCst); @@ -285,19 +207,13 @@ impl Peer { .await .insert(message_id, tx); connection - .writer - .lock() - .await - .write_message(&request.into_envelope( - message_id, - None, - original_sender_id.map(|id| id.0), - )) + .outgoing_tx + .send(request.into_envelope(message_id, None, original_sender_id.map(|id| id.0))) .await?; let response = rx .recv() .await - .expect("response channel was unexpectedly dropped"); + .ok_or_else(|| anyhow!("connection was closed"))?; T::Response::from_envelope(response) .ok_or_else(|| anyhow!("received response of the wrong type")) } @@ -305,20 +221,18 @@ impl Peer { pub fn send( self: &Arc, - connection_id: ConnectionId, + receiver_id: ConnectionId, message: T, ) -> impl Future> { let this = self.clone(); async move { - let connection = this.connection(connection_id).await?; + let mut connection = this.connection(receiver_id).await?; let message_id = connection .next_message_id .fetch_add(1, atomic::Ordering::SeqCst); connection - .writer - .lock() - .await - .write_message(&message.into_envelope(message_id, None, None)) + .outgoing_tx + .send(message.into_envelope(message_id, None, None)) .await?; Ok(()) } @@ -332,15 +246,13 @@ impl Peer { ) -> impl Future> { let this = self.clone(); async move { - let connection = this.connection(receiver_id).await?; + let mut connection = this.connection(receiver_id).await?; let message_id = connection .next_message_id .fetch_add(1, atomic::Ordering::SeqCst); connection - .writer - .lock() - .await - .write_message(&message.into_envelope(message_id, None, Some(sender_id.0))) + .outgoing_tx + .send(message.into_envelope(message_id, None, Some(sender_id.0))) .await?; Ok(()) } @@ -353,28 +265,114 @@ impl Peer { ) -> impl Future> { let this = self.clone(); async move { - let connection = this.connection(receipt.sender_id).await?; + let mut connection = this.connection(receipt.sender_id).await?; let message_id = connection .next_message_id .fetch_add(1, atomic::Ordering::SeqCst); connection - .writer - .lock() - .await - .write_message(&response.into_envelope(message_id, Some(receipt.message_id), None)) + .outgoing_tx + .send(response.into_envelope(message_id, Some(receipt.message_id), None)) .await?; Ok(()) } } - async fn connection(&self, id: ConnectionId) -> Result> { - Ok(self - .connections - .read() - .await - .get(&id) - .ok_or_else(|| anyhow!("unknown connection: {}", id.0))? - .clone()) + fn connection( + self: &Arc, + connection_id: ConnectionId, + ) -> impl Future> { + let this = self.clone(); + async move { + let connections = this.connections.read().await; + let connection = connections + .get(&connection_id) + .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?; + Ok(connection.clone()) + } + } +} + +impl ConnectionHandler +where + Conn: Clone + AsyncRead + AsyncWrite + Unpin + Send + 'static, +{ + pub async fn run(mut self) -> Result<()> { + loop { + let read_message = self.reader.read_message().fuse(); + futures::pin_mut!(read_message); + loop { + futures::select! { + incoming = read_message => match incoming { + Ok(incoming) => { + Self::handle_incoming_message(incoming, &self.peer, self.connection_id, &self.response_channels).await; + break; + } + Err(error) => { + self.response_channels.lock().await.clear(); + Err(error).context("received invalid RPC message")?; + } + }, + outgoing = self.outgoing_rx.recv().fuse() => match outgoing { + Some(outgoing) => { + if let Err(result) = self.writer.write_message(&outgoing).await { + self.response_channels.lock().await.clear(); + Err(result).context("failed to write RPC message")?; + } + } + None => return Ok(()), + } + } + } + } + } + + pub async fn receive(&mut self) -> Result> { + let envelope = self.reader.read_message().await?; + let original_sender_id = envelope.original_sender_id; + let message_id = envelope.id; + let payload = + M::from_envelope(envelope).ok_or_else(|| anyhow!("unexpected message type"))?; + Ok(TypedEnvelope { + sender_id: self.connection_id, + original_sender_id: original_sender_id.map(PeerId), + message_id, + payload, + }) + } + + async fn handle_incoming_message( + message: proto::Envelope, + peer: &Arc, + connection_id: ConnectionId, + response_channels: &ResponseChannels, + ) { + if let Some(responding_to) = message.responding_to { + let channel = response_channels.lock().await.remove(&responding_to); + if let Some(mut tx) = channel { + tx.send(message).await.ok(); + } else { + log::warn!("received RPC response to unknown request {}", responding_to); + } + } else { + let mut envelope = Some(message); + let mut handler_index = None; + let mut handler_was_dropped = false; + for (i, handler) in peer.message_handlers.read().await.iter().enumerate() { + if let Some(future) = handler(&mut envelope, connection_id) { + handler_was_dropped = future.await; + handler_index = Some(i); + break; + } + } + + if let Some(handler_index) = handler_index { + if handler_was_dropped { + drop(peer.message_handlers.write().await.remove(handler_index)); + } + } else { + log::warn!("unhandled message: {:?}", envelope.unwrap().payload); + } + } } } @@ -412,22 +410,22 @@ mod tests { let server = Peer::new(); let client1 = Peer::new(); let client2 = Peer::new(); - let client1_conn_id = client1 + let (client1_conn_id, task1) = client1 .add_connection(UnixStream::connect(&socket_path).await.unwrap()) .await; - let client2_conn_id = client2 + let (client2_conn_id, task2) = client2 .add_connection(UnixStream::connect(&socket_path).await.unwrap()) .await; - let server_conn_id1 = server + let (_, task3) = server .add_connection(listener.accept().await.unwrap().0) .await; - let server_conn_id2 = server + let (_, task4) = server .add_connection(listener.accept().await.unwrap().0) .await; - smol::spawn(client1.handle_messages(client1_conn_id)).detach(); - smol::spawn(client2.handle_messages(client2_conn_id)).detach(); - smol::spawn(server.handle_messages(server_conn_id1)).detach(); - smol::spawn(server.handle_messages(server_conn_id2)).detach(); + smol::spawn(task1.run()).detach(); + smol::spawn(task2.run()).detach(); + smol::spawn(task3.run()).detach(); + smol::spawn(task4.run()).detach(); // define the expected requests and responses let request1 = proto::Auth { @@ -548,12 +546,11 @@ mod tests { let (mut server_conn, _) = listener.accept().await.unwrap(); let client = Peer::new(); - let connection_id = client.add_connection(client_conn).await; + let (connection_id, handler) = client.add_connection(client_conn).await; let (mut incoming_messages_ended_tx, mut incoming_messages_ended_rx) = - barrier::channel(); - let handle_messages = client.handle_messages(connection_id); + postage::barrier::channel(); smol::spawn(async move { - handle_messages.await.ok(); + handler.run().await.ok(); incoming_messages_ended_tx.send(()).await.unwrap(); }) .detach(); @@ -576,8 +573,8 @@ mod tests { client_conn.close().await.unwrap(); let client = Peer::new(); - let connection_id = client.add_connection(client_conn).await; - smol::spawn(client.handle_messages(connection_id)).detach(); + let (connection_id, handler) = client.add_connection(client_conn).await; + smol::spawn(handler.run()).detach(); let err = client .request( @@ -589,10 +586,7 @@ mod tests { ) .await .unwrap_err(); - assert_eq!( - err.downcast_ref::().unwrap().kind(), - io::ErrorKind::BrokenPipe - ); + assert_eq!(err.to_string(), "connection was closed"); }); } } diff --git a/zed-rpc/src/proto.rs b/zed-rpc/src/proto.rs index bbd18d2353..6a3ee845b4 100644 --- a/zed-rpc/src/proto.rs +++ b/zed-rpc/src/proto.rs @@ -82,6 +82,7 @@ message!(RemoveGuest); pub struct MessageStream { byte_stream: T, buffer: Vec, + upcoming_message_len: Option, } impl MessageStream { @@ -89,6 +90,7 @@ impl MessageStream { Self { byte_stream, buffer: Default::default(), + upcoming_message_len: None, } } @@ -120,12 +122,23 @@ where { /// Read a protobuf message of the given type from the stream. pub async fn read_message(&mut self) -> io::Result { - let mut delimiter_buf = [0; 4]; - self.byte_stream.read_exact(&mut delimiter_buf).await?; - let message_len = u32::from_be_bytes(delimiter_buf) as usize; - self.buffer.resize(message_len, 0); - self.byte_stream.read_exact(&mut self.buffer).await?; - Ok(Envelope::decode(self.buffer.as_slice())?) + loop { + if let Some(upcoming_message_len) = self.upcoming_message_len { + self.buffer.resize(upcoming_message_len, 0); + self.byte_stream.read_exact(&mut self.buffer).await?; + self.upcoming_message_len = None; + return Ok(Envelope::decode(self.buffer.as_slice())?); + } else { + self.buffer.resize(4, 0); + self.byte_stream.read_exact(&mut self.buffer).await?; + self.upcoming_message_len = Some(u32::from_be_bytes([ + self.buffer[0], + self.buffer[1], + self.buffer[2], + self.buffer[3], + ]) as usize); + } + } } } diff --git a/zed/src/rpc.rs b/zed/src/rpc.rs index 480e458956..3d567b516d 100644 --- a/zed/src/rpc.rs +++ b/zed/src/rpc.rs @@ -121,10 +121,8 @@ impl Client { let stream = smol::net::TcpStream::connect(&address).await?; log::info!("connected to rpc address {}", address); - let connection_id = self.peer.add_connection(stream).await; - executor - .spawn(self.peer.handle_messages(connection_id)) - .detach(); + let (connection_id, handler) = self.peer.add_connection(stream).await; + executor.spawn(handler.run()).detach(); let auth_response = self .peer