Serialize RPC sends and responses using a channel

This commit is contained in:
Max Brunsfeld 2021-07-01 18:12:46 -07:00
parent 42f7867f6e
commit 9d51fe88e9
3 changed files with 190 additions and 185 deletions

View file

@ -1,12 +1,9 @@
use crate::proto::{self, EnvelopedMessage, MessageStream, RequestMessage};
use anyhow::{anyhow, Result};
use anyhow::{anyhow, Context, Result};
use async_lock::{Mutex, RwLock};
use futures::{
future::{BoxFuture, Either},
AsyncRead, AsyncWrite, FutureExt,
};
use futures::{future::BoxFuture, AsyncRead, AsyncWrite, FutureExt};
use postage::{
barrier, mpsc, oneshot,
mpsc,
prelude::{Sink, Stream},
};
use std::{
@ -15,29 +12,18 @@ use std::{
fmt,
future::Future,
marker::PhantomData,
pin::Pin,
sync::{
atomic::{self, AtomicU32},
Arc,
},
};
type BoxedWriter = Pin<Box<dyn AsyncWrite + 'static + Send>>;
type BoxedReader = Pin<Box<dyn AsyncRead + 'static + Send>>;
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct ConnectionId(pub u32);
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct PeerId(pub u32);
struct Connection {
writer: Mutex<MessageStream<BoxedWriter>>,
reader: Mutex<MessageStream<BoxedReader>>,
response_channels: Mutex<HashMap<u32, oneshot::Sender<proto::Envelope>>>,
next_message_id: AtomicU32,
}
type MessageHandler = Box<
dyn Send + Sync + Fn(&mut Option<proto::Envelope>, ConnectionId) -> Option<BoxFuture<bool>>,
>;
@ -74,18 +60,34 @@ impl<T: RequestMessage> TypedEnvelope<T> {
}
pub struct Peer {
connections: RwLock<HashMap<ConnectionId, Arc<Connection>>>,
connection_close_barriers: RwLock<HashMap<ConnectionId, barrier::Sender>>,
connections: RwLock<HashMap<ConnectionId, Connection>>,
message_handlers: RwLock<Vec<MessageHandler>>,
handler_types: Mutex<HashSet<TypeId>>,
next_connection_id: AtomicU32,
}
#[derive(Clone)]
struct Connection {
outgoing_tx: mpsc::Sender<proto::Envelope>,
next_message_id: Arc<AtomicU32>,
response_channels: ResponseChannels,
}
pub struct ConnectionHandler<Conn> {
peer: Arc<Peer>,
connection_id: ConnectionId,
response_channels: ResponseChannels,
outgoing_rx: mpsc::Receiver<proto::Envelope>,
reader: MessageStream<Conn>,
writer: MessageStream<Conn>,
}
type ResponseChannels = Arc<Mutex<HashMap<u32, mpsc::Sender<proto::Envelope>>>>;
impl Peer {
pub fn new() -> Arc<Self> {
Arc::new(Self {
connections: Default::default(),
connection_close_barriers: Default::default(),
message_handlers: Default::default(),
handler_types: Default::default(),
next_connection_id: Default::default(),
@ -127,7 +129,10 @@ impl Peer {
rx
}
pub async fn add_connection<Conn>(self: &Arc<Self>, conn: Conn) -> ConnectionId
pub async fn add_connection<Conn>(
self: &Arc<Self>,
conn: Conn,
) -> (ConnectionId, ConnectionHandler<Conn>)
where
Conn: Clone + AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
@ -135,120 +140,37 @@ impl Peer {
self.next_connection_id
.fetch_add(1, atomic::Ordering::SeqCst),
);
self.connections.write().await.insert(
let (outgoing_tx, outgoing_rx) = mpsc::channel(64);
let connection = Connection {
outgoing_tx,
next_message_id: Default::default(),
response_channels: Default::default(),
};
let handler = ConnectionHandler {
peer: self.clone(),
connection_id,
Arc::new(Connection {
reader: Mutex::new(MessageStream::new(Box::pin(conn.clone()))),
writer: Mutex::new(MessageStream::new(Box::pin(conn.clone()))),
response_channels: Default::default(),
next_message_id: Default::default(),
}),
);
connection_id
response_channels: connection.response_channels.clone(),
outgoing_rx,
reader: MessageStream::new(conn.clone()),
writer: MessageStream::new(conn),
};
self.connections
.write()
.await
.insert(connection_id, connection);
(connection_id, handler)
}
pub async fn disconnect(&self, connection_id: ConnectionId) {
self.connections.write().await.remove(&connection_id);
self.connection_close_barriers
.write()
.await
.remove(&connection_id);
}
pub async fn reset(&self) {
self.connections.write().await.clear();
self.connection_close_barriers.write().await.clear();
self.handler_types.lock().await.clear();
self.message_handlers.write().await.clear();
}
pub fn handle_messages(
self: &Arc<Self>,
connection_id: ConnectionId,
) -> impl Future<Output = Result<()>> + 'static {
let (close_tx, mut close_rx) = barrier::channel();
let this = self.clone();
async move {
this.connection_close_barriers
.write()
.await
.insert(connection_id, close_tx);
let connection = this.connection(connection_id).await?;
let closed = close_rx.recv();
futures::pin_mut!(closed);
loop {
let mut reader = connection.reader.lock().await;
let read_message = reader.read_message();
futures::pin_mut!(read_message);
match futures::future::select(read_message, &mut closed).await {
Either::Left((Ok(incoming), _)) => {
if let Some(responding_to) = incoming.responding_to {
let channel = connection
.response_channels
.lock()
.await
.remove(&responding_to);
if let Some(mut tx) = channel {
tx.send(incoming).await.ok();
} else {
log::warn!(
"received RPC response to unknown request {}",
responding_to
);
}
} else {
let mut envelope = Some(incoming);
let mut handler_index = None;
let mut handler_was_dropped = false;
for (i, handler) in
this.message_handlers.read().await.iter().enumerate()
{
if let Some(future) = handler(&mut envelope, connection_id) {
handler_was_dropped = future.await;
handler_index = Some(i);
break;
}
}
if let Some(handler_index) = handler_index {
if handler_was_dropped {
drop(this.message_handlers.write().await.remove(handler_index));
}
} else {
log::warn!("unhandled message: {:?}", envelope.unwrap().payload);
}
}
}
Either::Left((Err(error), _)) => {
log::warn!("received invalid RPC message: {}", error);
Err(error)?;
}
Either::Right(_) => return Ok(()),
}
}
}
}
pub async fn receive<M: EnvelopedMessage>(
self: &Arc<Self>,
connection_id: ConnectionId,
) -> Result<TypedEnvelope<M>> {
let connection = self.connection(connection_id).await?;
let envelope = connection.reader.lock().await.read_message().await?;
let original_sender_id = envelope.original_sender_id;
let message_id = envelope.id;
let payload =
M::from_envelope(envelope).ok_or_else(|| anyhow!("unexpected message type"))?;
Ok(TypedEnvelope {
sender_id: connection_id,
original_sender_id: original_sender_id.map(PeerId),
message_id,
payload,
})
}
pub fn request<T: RequestMessage>(
self: &Arc<Self>,
receiver_id: ConnectionId,
@ -273,9 +195,9 @@ impl Peer {
request: T,
) -> impl Future<Output = Result<T::Response>> {
let this = self.clone();
let (tx, mut rx) = oneshot::channel();
let (tx, mut rx) = mpsc::channel(1);
async move {
let connection = this.connection(receiver_id).await?;
let mut connection = this.connection(receiver_id).await?;
let message_id = connection
.next_message_id
.fetch_add(1, atomic::Ordering::SeqCst);
@ -285,19 +207,13 @@ impl Peer {
.await
.insert(message_id, tx);
connection
.writer
.lock()
.await
.write_message(&request.into_envelope(
message_id,
None,
original_sender_id.map(|id| id.0),
))
.outgoing_tx
.send(request.into_envelope(message_id, None, original_sender_id.map(|id| id.0)))
.await?;
let response = rx
.recv()
.await
.expect("response channel was unexpectedly dropped");
.ok_or_else(|| anyhow!("connection was closed"))?;
T::Response::from_envelope(response)
.ok_or_else(|| anyhow!("received response of the wrong type"))
}
@ -305,20 +221,18 @@ impl Peer {
pub fn send<T: EnvelopedMessage>(
self: &Arc<Self>,
connection_id: ConnectionId,
receiver_id: ConnectionId,
message: T,
) -> impl Future<Output = Result<()>> {
let this = self.clone();
async move {
let connection = this.connection(connection_id).await?;
let mut connection = this.connection(receiver_id).await?;
let message_id = connection
.next_message_id
.fetch_add(1, atomic::Ordering::SeqCst);
connection
.writer
.lock()
.await
.write_message(&message.into_envelope(message_id, None, None))
.outgoing_tx
.send(message.into_envelope(message_id, None, None))
.await?;
Ok(())
}
@ -332,15 +246,13 @@ impl Peer {
) -> impl Future<Output = Result<()>> {
let this = self.clone();
async move {
let connection = this.connection(receiver_id).await?;
let mut connection = this.connection(receiver_id).await?;
let message_id = connection
.next_message_id
.fetch_add(1, atomic::Ordering::SeqCst);
connection
.writer
.lock()
.await
.write_message(&message.into_envelope(message_id, None, Some(sender_id.0)))
.outgoing_tx
.send(message.into_envelope(message_id, None, Some(sender_id.0)))
.await?;
Ok(())
}
@ -353,28 +265,114 @@ impl Peer {
) -> impl Future<Output = Result<()>> {
let this = self.clone();
async move {
let connection = this.connection(receipt.sender_id).await?;
let mut connection = this.connection(receipt.sender_id).await?;
let message_id = connection
.next_message_id
.fetch_add(1, atomic::Ordering::SeqCst);
connection
.writer
.lock()
.await
.write_message(&response.into_envelope(message_id, Some(receipt.message_id), None))
.outgoing_tx
.send(response.into_envelope(message_id, Some(receipt.message_id), None))
.await?;
Ok(())
}
}
async fn connection(&self, id: ConnectionId) -> Result<Arc<Connection>> {
Ok(self
.connections
.read()
.await
.get(&id)
.ok_or_else(|| anyhow!("unknown connection: {}", id.0))?
.clone())
fn connection(
self: &Arc<Self>,
connection_id: ConnectionId,
) -> impl Future<Output = Result<Connection>> {
let this = self.clone();
async move {
let connections = this.connections.read().await;
let connection = connections
.get(&connection_id)
.ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
Ok(connection.clone())
}
}
}
impl<Conn> ConnectionHandler<Conn>
where
Conn: Clone + AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
pub async fn run(mut self) -> Result<()> {
loop {
let read_message = self.reader.read_message().fuse();
futures::pin_mut!(read_message);
loop {
futures::select! {
incoming = read_message => match incoming {
Ok(incoming) => {
Self::handle_incoming_message(incoming, &self.peer, self.connection_id, &self.response_channels).await;
break;
}
Err(error) => {
self.response_channels.lock().await.clear();
Err(error).context("received invalid RPC message")?;
}
},
outgoing = self.outgoing_rx.recv().fuse() => match outgoing {
Some(outgoing) => {
if let Err(result) = self.writer.write_message(&outgoing).await {
self.response_channels.lock().await.clear();
Err(result).context("failed to write RPC message")?;
}
}
None => return Ok(()),
}
}
}
}
}
pub async fn receive<M: EnvelopedMessage>(&mut self) -> Result<TypedEnvelope<M>> {
let envelope = self.reader.read_message().await?;
let original_sender_id = envelope.original_sender_id;
let message_id = envelope.id;
let payload =
M::from_envelope(envelope).ok_or_else(|| anyhow!("unexpected message type"))?;
Ok(TypedEnvelope {
sender_id: self.connection_id,
original_sender_id: original_sender_id.map(PeerId),
message_id,
payload,
})
}
async fn handle_incoming_message(
message: proto::Envelope,
peer: &Arc<Peer>,
connection_id: ConnectionId,
response_channels: &ResponseChannels,
) {
if let Some(responding_to) = message.responding_to {
let channel = response_channels.lock().await.remove(&responding_to);
if let Some(mut tx) = channel {
tx.send(message).await.ok();
} else {
log::warn!("received RPC response to unknown request {}", responding_to);
}
} else {
let mut envelope = Some(message);
let mut handler_index = None;
let mut handler_was_dropped = false;
for (i, handler) in peer.message_handlers.read().await.iter().enumerate() {
if let Some(future) = handler(&mut envelope, connection_id) {
handler_was_dropped = future.await;
handler_index = Some(i);
break;
}
}
if let Some(handler_index) = handler_index {
if handler_was_dropped {
drop(peer.message_handlers.write().await.remove(handler_index));
}
} else {
log::warn!("unhandled message: {:?}", envelope.unwrap().payload);
}
}
}
}
@ -412,22 +410,22 @@ mod tests {
let server = Peer::new();
let client1 = Peer::new();
let client2 = Peer::new();
let client1_conn_id = client1
let (client1_conn_id, task1) = client1
.add_connection(UnixStream::connect(&socket_path).await.unwrap())
.await;
let client2_conn_id = client2
let (client2_conn_id, task2) = client2
.add_connection(UnixStream::connect(&socket_path).await.unwrap())
.await;
let server_conn_id1 = server
let (_, task3) = server
.add_connection(listener.accept().await.unwrap().0)
.await;
let server_conn_id2 = server
let (_, task4) = server
.add_connection(listener.accept().await.unwrap().0)
.await;
smol::spawn(client1.handle_messages(client1_conn_id)).detach();
smol::spawn(client2.handle_messages(client2_conn_id)).detach();
smol::spawn(server.handle_messages(server_conn_id1)).detach();
smol::spawn(server.handle_messages(server_conn_id2)).detach();
smol::spawn(task1.run()).detach();
smol::spawn(task2.run()).detach();
smol::spawn(task3.run()).detach();
smol::spawn(task4.run()).detach();
// define the expected requests and responses
let request1 = proto::Auth {
@ -548,12 +546,11 @@ mod tests {
let (mut server_conn, _) = listener.accept().await.unwrap();
let client = Peer::new();
let connection_id = client.add_connection(client_conn).await;
let (connection_id, handler) = client.add_connection(client_conn).await;
let (mut incoming_messages_ended_tx, mut incoming_messages_ended_rx) =
barrier::channel();
let handle_messages = client.handle_messages(connection_id);
postage::barrier::channel();
smol::spawn(async move {
handle_messages.await.ok();
handler.run().await.ok();
incoming_messages_ended_tx.send(()).await.unwrap();
})
.detach();
@ -576,8 +573,8 @@ mod tests {
client_conn.close().await.unwrap();
let client = Peer::new();
let connection_id = client.add_connection(client_conn).await;
smol::spawn(client.handle_messages(connection_id)).detach();
let (connection_id, handler) = client.add_connection(client_conn).await;
smol::spawn(handler.run()).detach();
let err = client
.request(
@ -589,10 +586,7 @@ mod tests {
)
.await
.unwrap_err();
assert_eq!(
err.downcast_ref::<io::Error>().unwrap().kind(),
io::ErrorKind::BrokenPipe
);
assert_eq!(err.to_string(), "connection was closed");
});
}
}

View file

@ -82,6 +82,7 @@ message!(RemoveGuest);
pub struct MessageStream<T> {
byte_stream: T,
buffer: Vec<u8>,
upcoming_message_len: Option<usize>,
}
impl<T> MessageStream<T> {
@ -89,6 +90,7 @@ impl<T> MessageStream<T> {
Self {
byte_stream,
buffer: Default::default(),
upcoming_message_len: None,
}
}
@ -120,12 +122,23 @@ where
{
/// Read a protobuf message of the given type from the stream.
pub async fn read_message(&mut self) -> io::Result<Envelope> {
let mut delimiter_buf = [0; 4];
self.byte_stream.read_exact(&mut delimiter_buf).await?;
let message_len = u32::from_be_bytes(delimiter_buf) as usize;
self.buffer.resize(message_len, 0);
self.byte_stream.read_exact(&mut self.buffer).await?;
Ok(Envelope::decode(self.buffer.as_slice())?)
loop {
if let Some(upcoming_message_len) = self.upcoming_message_len {
self.buffer.resize(upcoming_message_len, 0);
self.byte_stream.read_exact(&mut self.buffer).await?;
self.upcoming_message_len = None;
return Ok(Envelope::decode(self.buffer.as_slice())?);
} else {
self.buffer.resize(4, 0);
self.byte_stream.read_exact(&mut self.buffer).await?;
self.upcoming_message_len = Some(u32::from_be_bytes([
self.buffer[0],
self.buffer[1],
self.buffer[2],
self.buffer[3],
]) as usize);
}
}
}
}

View file

@ -121,10 +121,8 @@ impl Client {
let stream = smol::net::TcpStream::connect(&address).await?;
log::info!("connected to rpc address {}", address);
let connection_id = self.peer.add_connection(stream).await;
executor
.spawn(self.peer.handle_messages(connection_id))
.detach();
let (connection_id, handler) = self.peer.add_connection(stream).await;
executor.spawn(handler.run()).detach();
let auth_response = self
.peer