mirror of
https://github.com/zed-industries/zed.git
synced 2025-01-12 21:32:40 +00:00
Serialize RPC sends and responses using a channel
This commit is contained in:
parent
42f7867f6e
commit
9d51fe88e9
3 changed files with 190 additions and 185 deletions
|
@ -1,12 +1,9 @@
|
|||
use crate::proto::{self, EnvelopedMessage, MessageStream, RequestMessage};
|
||||
use anyhow::{anyhow, Result};
|
||||
use anyhow::{anyhow, Context, Result};
|
||||
use async_lock::{Mutex, RwLock};
|
||||
use futures::{
|
||||
future::{BoxFuture, Either},
|
||||
AsyncRead, AsyncWrite, FutureExt,
|
||||
};
|
||||
use futures::{future::BoxFuture, AsyncRead, AsyncWrite, FutureExt};
|
||||
use postage::{
|
||||
barrier, mpsc, oneshot,
|
||||
mpsc,
|
||||
prelude::{Sink, Stream},
|
||||
};
|
||||
use std::{
|
||||
|
@ -15,29 +12,18 @@ use std::{
|
|||
fmt,
|
||||
future::Future,
|
||||
marker::PhantomData,
|
||||
pin::Pin,
|
||||
sync::{
|
||||
atomic::{self, AtomicU32},
|
||||
Arc,
|
||||
},
|
||||
};
|
||||
|
||||
type BoxedWriter = Pin<Box<dyn AsyncWrite + 'static + Send>>;
|
||||
type BoxedReader = Pin<Box<dyn AsyncRead + 'static + Send>>;
|
||||
|
||||
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
|
||||
pub struct ConnectionId(pub u32);
|
||||
|
||||
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
|
||||
pub struct PeerId(pub u32);
|
||||
|
||||
struct Connection {
|
||||
writer: Mutex<MessageStream<BoxedWriter>>,
|
||||
reader: Mutex<MessageStream<BoxedReader>>,
|
||||
response_channels: Mutex<HashMap<u32, oneshot::Sender<proto::Envelope>>>,
|
||||
next_message_id: AtomicU32,
|
||||
}
|
||||
|
||||
type MessageHandler = Box<
|
||||
dyn Send + Sync + Fn(&mut Option<proto::Envelope>, ConnectionId) -> Option<BoxFuture<bool>>,
|
||||
>;
|
||||
|
@ -74,18 +60,34 @@ impl<T: RequestMessage> TypedEnvelope<T> {
|
|||
}
|
||||
|
||||
pub struct Peer {
|
||||
connections: RwLock<HashMap<ConnectionId, Arc<Connection>>>,
|
||||
connection_close_barriers: RwLock<HashMap<ConnectionId, barrier::Sender>>,
|
||||
connections: RwLock<HashMap<ConnectionId, Connection>>,
|
||||
message_handlers: RwLock<Vec<MessageHandler>>,
|
||||
handler_types: Mutex<HashSet<TypeId>>,
|
||||
next_connection_id: AtomicU32,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct Connection {
|
||||
outgoing_tx: mpsc::Sender<proto::Envelope>,
|
||||
next_message_id: Arc<AtomicU32>,
|
||||
response_channels: ResponseChannels,
|
||||
}
|
||||
|
||||
pub struct ConnectionHandler<Conn> {
|
||||
peer: Arc<Peer>,
|
||||
connection_id: ConnectionId,
|
||||
response_channels: ResponseChannels,
|
||||
outgoing_rx: mpsc::Receiver<proto::Envelope>,
|
||||
reader: MessageStream<Conn>,
|
||||
writer: MessageStream<Conn>,
|
||||
}
|
||||
|
||||
type ResponseChannels = Arc<Mutex<HashMap<u32, mpsc::Sender<proto::Envelope>>>>;
|
||||
|
||||
impl Peer {
|
||||
pub fn new() -> Arc<Self> {
|
||||
Arc::new(Self {
|
||||
connections: Default::default(),
|
||||
connection_close_barriers: Default::default(),
|
||||
message_handlers: Default::default(),
|
||||
handler_types: Default::default(),
|
||||
next_connection_id: Default::default(),
|
||||
|
@ -127,7 +129,10 @@ impl Peer {
|
|||
rx
|
||||
}
|
||||
|
||||
pub async fn add_connection<Conn>(self: &Arc<Self>, conn: Conn) -> ConnectionId
|
||||
pub async fn add_connection<Conn>(
|
||||
self: &Arc<Self>,
|
||||
conn: Conn,
|
||||
) -> (ConnectionId, ConnectionHandler<Conn>)
|
||||
where
|
||||
Conn: Clone + AsyncRead + AsyncWrite + Unpin + Send + 'static,
|
||||
{
|
||||
|
@ -135,120 +140,37 @@ impl Peer {
|
|||
self.next_connection_id
|
||||
.fetch_add(1, atomic::Ordering::SeqCst),
|
||||
);
|
||||
self.connections.write().await.insert(
|
||||
let (outgoing_tx, outgoing_rx) = mpsc::channel(64);
|
||||
let connection = Connection {
|
||||
outgoing_tx,
|
||||
next_message_id: Default::default(),
|
||||
response_channels: Default::default(),
|
||||
};
|
||||
let handler = ConnectionHandler {
|
||||
peer: self.clone(),
|
||||
connection_id,
|
||||
Arc::new(Connection {
|
||||
reader: Mutex::new(MessageStream::new(Box::pin(conn.clone()))),
|
||||
writer: Mutex::new(MessageStream::new(Box::pin(conn.clone()))),
|
||||
response_channels: Default::default(),
|
||||
next_message_id: Default::default(),
|
||||
}),
|
||||
);
|
||||
connection_id
|
||||
response_channels: connection.response_channels.clone(),
|
||||
outgoing_rx,
|
||||
reader: MessageStream::new(conn.clone()),
|
||||
writer: MessageStream::new(conn),
|
||||
};
|
||||
self.connections
|
||||
.write()
|
||||
.await
|
||||
.insert(connection_id, connection);
|
||||
(connection_id, handler)
|
||||
}
|
||||
|
||||
pub async fn disconnect(&self, connection_id: ConnectionId) {
|
||||
self.connections.write().await.remove(&connection_id);
|
||||
self.connection_close_barriers
|
||||
.write()
|
||||
.await
|
||||
.remove(&connection_id);
|
||||
}
|
||||
|
||||
pub async fn reset(&self) {
|
||||
self.connections.write().await.clear();
|
||||
self.connection_close_barriers.write().await.clear();
|
||||
self.handler_types.lock().await.clear();
|
||||
self.message_handlers.write().await.clear();
|
||||
}
|
||||
|
||||
pub fn handle_messages(
|
||||
self: &Arc<Self>,
|
||||
connection_id: ConnectionId,
|
||||
) -> impl Future<Output = Result<()>> + 'static {
|
||||
let (close_tx, mut close_rx) = barrier::channel();
|
||||
let this = self.clone();
|
||||
async move {
|
||||
this.connection_close_barriers
|
||||
.write()
|
||||
.await
|
||||
.insert(connection_id, close_tx);
|
||||
let connection = this.connection(connection_id).await?;
|
||||
let closed = close_rx.recv();
|
||||
futures::pin_mut!(closed);
|
||||
|
||||
loop {
|
||||
let mut reader = connection.reader.lock().await;
|
||||
let read_message = reader.read_message();
|
||||
futures::pin_mut!(read_message);
|
||||
|
||||
match futures::future::select(read_message, &mut closed).await {
|
||||
Either::Left((Ok(incoming), _)) => {
|
||||
if let Some(responding_to) = incoming.responding_to {
|
||||
let channel = connection
|
||||
.response_channels
|
||||
.lock()
|
||||
.await
|
||||
.remove(&responding_to);
|
||||
if let Some(mut tx) = channel {
|
||||
tx.send(incoming).await.ok();
|
||||
} else {
|
||||
log::warn!(
|
||||
"received RPC response to unknown request {}",
|
||||
responding_to
|
||||
);
|
||||
}
|
||||
} else {
|
||||
let mut envelope = Some(incoming);
|
||||
let mut handler_index = None;
|
||||
let mut handler_was_dropped = false;
|
||||
for (i, handler) in
|
||||
this.message_handlers.read().await.iter().enumerate()
|
||||
{
|
||||
if let Some(future) = handler(&mut envelope, connection_id) {
|
||||
handler_was_dropped = future.await;
|
||||
handler_index = Some(i);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(handler_index) = handler_index {
|
||||
if handler_was_dropped {
|
||||
drop(this.message_handlers.write().await.remove(handler_index));
|
||||
}
|
||||
} else {
|
||||
log::warn!("unhandled message: {:?}", envelope.unwrap().payload);
|
||||
}
|
||||
}
|
||||
}
|
||||
Either::Left((Err(error), _)) => {
|
||||
log::warn!("received invalid RPC message: {}", error);
|
||||
Err(error)?;
|
||||
}
|
||||
Either::Right(_) => return Ok(()),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn receive<M: EnvelopedMessage>(
|
||||
self: &Arc<Self>,
|
||||
connection_id: ConnectionId,
|
||||
) -> Result<TypedEnvelope<M>> {
|
||||
let connection = self.connection(connection_id).await?;
|
||||
let envelope = connection.reader.lock().await.read_message().await?;
|
||||
let original_sender_id = envelope.original_sender_id;
|
||||
let message_id = envelope.id;
|
||||
let payload =
|
||||
M::from_envelope(envelope).ok_or_else(|| anyhow!("unexpected message type"))?;
|
||||
Ok(TypedEnvelope {
|
||||
sender_id: connection_id,
|
||||
original_sender_id: original_sender_id.map(PeerId),
|
||||
message_id,
|
||||
payload,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn request<T: RequestMessage>(
|
||||
self: &Arc<Self>,
|
||||
receiver_id: ConnectionId,
|
||||
|
@ -273,9 +195,9 @@ impl Peer {
|
|||
request: T,
|
||||
) -> impl Future<Output = Result<T::Response>> {
|
||||
let this = self.clone();
|
||||
let (tx, mut rx) = oneshot::channel();
|
||||
let (tx, mut rx) = mpsc::channel(1);
|
||||
async move {
|
||||
let connection = this.connection(receiver_id).await?;
|
||||
let mut connection = this.connection(receiver_id).await?;
|
||||
let message_id = connection
|
||||
.next_message_id
|
||||
.fetch_add(1, atomic::Ordering::SeqCst);
|
||||
|
@ -285,19 +207,13 @@ impl Peer {
|
|||
.await
|
||||
.insert(message_id, tx);
|
||||
connection
|
||||
.writer
|
||||
.lock()
|
||||
.await
|
||||
.write_message(&request.into_envelope(
|
||||
message_id,
|
||||
None,
|
||||
original_sender_id.map(|id| id.0),
|
||||
))
|
||||
.outgoing_tx
|
||||
.send(request.into_envelope(message_id, None, original_sender_id.map(|id| id.0)))
|
||||
.await?;
|
||||
let response = rx
|
||||
.recv()
|
||||
.await
|
||||
.expect("response channel was unexpectedly dropped");
|
||||
.ok_or_else(|| anyhow!("connection was closed"))?;
|
||||
T::Response::from_envelope(response)
|
||||
.ok_or_else(|| anyhow!("received response of the wrong type"))
|
||||
}
|
||||
|
@ -305,20 +221,18 @@ impl Peer {
|
|||
|
||||
pub fn send<T: EnvelopedMessage>(
|
||||
self: &Arc<Self>,
|
||||
connection_id: ConnectionId,
|
||||
receiver_id: ConnectionId,
|
||||
message: T,
|
||||
) -> impl Future<Output = Result<()>> {
|
||||
let this = self.clone();
|
||||
async move {
|
||||
let connection = this.connection(connection_id).await?;
|
||||
let mut connection = this.connection(receiver_id).await?;
|
||||
let message_id = connection
|
||||
.next_message_id
|
||||
.fetch_add(1, atomic::Ordering::SeqCst);
|
||||
connection
|
||||
.writer
|
||||
.lock()
|
||||
.await
|
||||
.write_message(&message.into_envelope(message_id, None, None))
|
||||
.outgoing_tx
|
||||
.send(message.into_envelope(message_id, None, None))
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
@ -332,15 +246,13 @@ impl Peer {
|
|||
) -> impl Future<Output = Result<()>> {
|
||||
let this = self.clone();
|
||||
async move {
|
||||
let connection = this.connection(receiver_id).await?;
|
||||
let mut connection = this.connection(receiver_id).await?;
|
||||
let message_id = connection
|
||||
.next_message_id
|
||||
.fetch_add(1, atomic::Ordering::SeqCst);
|
||||
connection
|
||||
.writer
|
||||
.lock()
|
||||
.await
|
||||
.write_message(&message.into_envelope(message_id, None, Some(sender_id.0)))
|
||||
.outgoing_tx
|
||||
.send(message.into_envelope(message_id, None, Some(sender_id.0)))
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
@ -353,28 +265,114 @@ impl Peer {
|
|||
) -> impl Future<Output = Result<()>> {
|
||||
let this = self.clone();
|
||||
async move {
|
||||
let connection = this.connection(receipt.sender_id).await?;
|
||||
let mut connection = this.connection(receipt.sender_id).await?;
|
||||
let message_id = connection
|
||||
.next_message_id
|
||||
.fetch_add(1, atomic::Ordering::SeqCst);
|
||||
connection
|
||||
.writer
|
||||
.lock()
|
||||
.await
|
||||
.write_message(&response.into_envelope(message_id, Some(receipt.message_id), None))
|
||||
.outgoing_tx
|
||||
.send(response.into_envelope(message_id, Some(receipt.message_id), None))
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
async fn connection(&self, id: ConnectionId) -> Result<Arc<Connection>> {
|
||||
Ok(self
|
||||
.connections
|
||||
.read()
|
||||
.await
|
||||
.get(&id)
|
||||
.ok_or_else(|| anyhow!("unknown connection: {}", id.0))?
|
||||
.clone())
|
||||
fn connection(
|
||||
self: &Arc<Self>,
|
||||
connection_id: ConnectionId,
|
||||
) -> impl Future<Output = Result<Connection>> {
|
||||
let this = self.clone();
|
||||
async move {
|
||||
let connections = this.connections.read().await;
|
||||
let connection = connections
|
||||
.get(&connection_id)
|
||||
.ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
|
||||
Ok(connection.clone())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<Conn> ConnectionHandler<Conn>
|
||||
where
|
||||
Conn: Clone + AsyncRead + AsyncWrite + Unpin + Send + 'static,
|
||||
{
|
||||
pub async fn run(mut self) -> Result<()> {
|
||||
loop {
|
||||
let read_message = self.reader.read_message().fuse();
|
||||
futures::pin_mut!(read_message);
|
||||
loop {
|
||||
futures::select! {
|
||||
incoming = read_message => match incoming {
|
||||
Ok(incoming) => {
|
||||
Self::handle_incoming_message(incoming, &self.peer, self.connection_id, &self.response_channels).await;
|
||||
break;
|
||||
}
|
||||
Err(error) => {
|
||||
self.response_channels.lock().await.clear();
|
||||
Err(error).context("received invalid RPC message")?;
|
||||
}
|
||||
},
|
||||
outgoing = self.outgoing_rx.recv().fuse() => match outgoing {
|
||||
Some(outgoing) => {
|
||||
if let Err(result) = self.writer.write_message(&outgoing).await {
|
||||
self.response_channels.lock().await.clear();
|
||||
Err(result).context("failed to write RPC message")?;
|
||||
}
|
||||
}
|
||||
None => return Ok(()),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn receive<M: EnvelopedMessage>(&mut self) -> Result<TypedEnvelope<M>> {
|
||||
let envelope = self.reader.read_message().await?;
|
||||
let original_sender_id = envelope.original_sender_id;
|
||||
let message_id = envelope.id;
|
||||
let payload =
|
||||
M::from_envelope(envelope).ok_or_else(|| anyhow!("unexpected message type"))?;
|
||||
Ok(TypedEnvelope {
|
||||
sender_id: self.connection_id,
|
||||
original_sender_id: original_sender_id.map(PeerId),
|
||||
message_id,
|
||||
payload,
|
||||
})
|
||||
}
|
||||
|
||||
async fn handle_incoming_message(
|
||||
message: proto::Envelope,
|
||||
peer: &Arc<Peer>,
|
||||
connection_id: ConnectionId,
|
||||
response_channels: &ResponseChannels,
|
||||
) {
|
||||
if let Some(responding_to) = message.responding_to {
|
||||
let channel = response_channels.lock().await.remove(&responding_to);
|
||||
if let Some(mut tx) = channel {
|
||||
tx.send(message).await.ok();
|
||||
} else {
|
||||
log::warn!("received RPC response to unknown request {}", responding_to);
|
||||
}
|
||||
} else {
|
||||
let mut envelope = Some(message);
|
||||
let mut handler_index = None;
|
||||
let mut handler_was_dropped = false;
|
||||
for (i, handler) in peer.message_handlers.read().await.iter().enumerate() {
|
||||
if let Some(future) = handler(&mut envelope, connection_id) {
|
||||
handler_was_dropped = future.await;
|
||||
handler_index = Some(i);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(handler_index) = handler_index {
|
||||
if handler_was_dropped {
|
||||
drop(peer.message_handlers.write().await.remove(handler_index));
|
||||
}
|
||||
} else {
|
||||
log::warn!("unhandled message: {:?}", envelope.unwrap().payload);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -412,22 +410,22 @@ mod tests {
|
|||
let server = Peer::new();
|
||||
let client1 = Peer::new();
|
||||
let client2 = Peer::new();
|
||||
let client1_conn_id = client1
|
||||
let (client1_conn_id, task1) = client1
|
||||
.add_connection(UnixStream::connect(&socket_path).await.unwrap())
|
||||
.await;
|
||||
let client2_conn_id = client2
|
||||
let (client2_conn_id, task2) = client2
|
||||
.add_connection(UnixStream::connect(&socket_path).await.unwrap())
|
||||
.await;
|
||||
let server_conn_id1 = server
|
||||
let (_, task3) = server
|
||||
.add_connection(listener.accept().await.unwrap().0)
|
||||
.await;
|
||||
let server_conn_id2 = server
|
||||
let (_, task4) = server
|
||||
.add_connection(listener.accept().await.unwrap().0)
|
||||
.await;
|
||||
smol::spawn(client1.handle_messages(client1_conn_id)).detach();
|
||||
smol::spawn(client2.handle_messages(client2_conn_id)).detach();
|
||||
smol::spawn(server.handle_messages(server_conn_id1)).detach();
|
||||
smol::spawn(server.handle_messages(server_conn_id2)).detach();
|
||||
smol::spawn(task1.run()).detach();
|
||||
smol::spawn(task2.run()).detach();
|
||||
smol::spawn(task3.run()).detach();
|
||||
smol::spawn(task4.run()).detach();
|
||||
|
||||
// define the expected requests and responses
|
||||
let request1 = proto::Auth {
|
||||
|
@ -548,12 +546,11 @@ mod tests {
|
|||
let (mut server_conn, _) = listener.accept().await.unwrap();
|
||||
|
||||
let client = Peer::new();
|
||||
let connection_id = client.add_connection(client_conn).await;
|
||||
let (connection_id, handler) = client.add_connection(client_conn).await;
|
||||
let (mut incoming_messages_ended_tx, mut incoming_messages_ended_rx) =
|
||||
barrier::channel();
|
||||
let handle_messages = client.handle_messages(connection_id);
|
||||
postage::barrier::channel();
|
||||
smol::spawn(async move {
|
||||
handle_messages.await.ok();
|
||||
handler.run().await.ok();
|
||||
incoming_messages_ended_tx.send(()).await.unwrap();
|
||||
})
|
||||
.detach();
|
||||
|
@ -576,8 +573,8 @@ mod tests {
|
|||
client_conn.close().await.unwrap();
|
||||
|
||||
let client = Peer::new();
|
||||
let connection_id = client.add_connection(client_conn).await;
|
||||
smol::spawn(client.handle_messages(connection_id)).detach();
|
||||
let (connection_id, handler) = client.add_connection(client_conn).await;
|
||||
smol::spawn(handler.run()).detach();
|
||||
|
||||
let err = client
|
||||
.request(
|
||||
|
@ -589,10 +586,7 @@ mod tests {
|
|||
)
|
||||
.await
|
||||
.unwrap_err();
|
||||
assert_eq!(
|
||||
err.downcast_ref::<io::Error>().unwrap().kind(),
|
||||
io::ErrorKind::BrokenPipe
|
||||
);
|
||||
assert_eq!(err.to_string(), "connection was closed");
|
||||
});
|
||||
}
|
||||
}
|
||||
|
|
|
@ -82,6 +82,7 @@ message!(RemoveGuest);
|
|||
pub struct MessageStream<T> {
|
||||
byte_stream: T,
|
||||
buffer: Vec<u8>,
|
||||
upcoming_message_len: Option<usize>,
|
||||
}
|
||||
|
||||
impl<T> MessageStream<T> {
|
||||
|
@ -89,6 +90,7 @@ impl<T> MessageStream<T> {
|
|||
Self {
|
||||
byte_stream,
|
||||
buffer: Default::default(),
|
||||
upcoming_message_len: None,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -120,12 +122,23 @@ where
|
|||
{
|
||||
/// Read a protobuf message of the given type from the stream.
|
||||
pub async fn read_message(&mut self) -> io::Result<Envelope> {
|
||||
let mut delimiter_buf = [0; 4];
|
||||
self.byte_stream.read_exact(&mut delimiter_buf).await?;
|
||||
let message_len = u32::from_be_bytes(delimiter_buf) as usize;
|
||||
self.buffer.resize(message_len, 0);
|
||||
self.byte_stream.read_exact(&mut self.buffer).await?;
|
||||
Ok(Envelope::decode(self.buffer.as_slice())?)
|
||||
loop {
|
||||
if let Some(upcoming_message_len) = self.upcoming_message_len {
|
||||
self.buffer.resize(upcoming_message_len, 0);
|
||||
self.byte_stream.read_exact(&mut self.buffer).await?;
|
||||
self.upcoming_message_len = None;
|
||||
return Ok(Envelope::decode(self.buffer.as_slice())?);
|
||||
} else {
|
||||
self.buffer.resize(4, 0);
|
||||
self.byte_stream.read_exact(&mut self.buffer).await?;
|
||||
self.upcoming_message_len = Some(u32::from_be_bytes([
|
||||
self.buffer[0],
|
||||
self.buffer[1],
|
||||
self.buffer[2],
|
||||
self.buffer[3],
|
||||
]) as usize);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -121,10 +121,8 @@ impl Client {
|
|||
let stream = smol::net::TcpStream::connect(&address).await?;
|
||||
log::info!("connected to rpc address {}", address);
|
||||
|
||||
let connection_id = self.peer.add_connection(stream).await;
|
||||
executor
|
||||
.spawn(self.peer.handle_messages(connection_id))
|
||||
.detach();
|
||||
let (connection_id, handler) = self.peer.add_connection(stream).await;
|
||||
executor.spawn(handler.run()).detach();
|
||||
|
||||
let auth_response = self
|
||||
.peer
|
||||
|
|
Loading…
Reference in a new issue