mirror of
https://github.com/zed-industries/zed.git
synced 2025-01-12 21:32:40 +00:00
Allow peers to receive individual messages before starting message loop
Co-Authored-By: Nathan Sobo <nathan@zed.dev>
This commit is contained in:
parent
4d28d03e3f
commit
05a662b35e
2 changed files with 75 additions and 47 deletions
|
@ -21,12 +21,14 @@ use std::{
|
|||
};
|
||||
|
||||
type BoxedWriter = Pin<Box<dyn AsyncWrite + 'static + Send>>;
|
||||
type BoxedReader = Pin<Box<dyn AsyncRead + 'static + Send>>;
|
||||
|
||||
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
|
||||
pub struct ConnectionId(u32);
|
||||
|
||||
struct Connection {
|
||||
writer: Mutex<MessageStream<BoxedWriter>>,
|
||||
reader: Mutex<MessageStream<BoxedReader>>,
|
||||
response_channels: Mutex<HashMap<u32, oneshot::Sender<proto::Envelope>>>,
|
||||
next_message_id: AtomicU32,
|
||||
}
|
||||
|
@ -52,7 +54,8 @@ impl<T> TypedEnvelope<T> {
|
|||
}
|
||||
|
||||
pub struct Peer {
|
||||
connections: RwLock<HashMap<ConnectionId, (Arc<Connection>, barrier::Sender)>>,
|
||||
connections: RwLock<HashMap<ConnectionId, Arc<Connection>>>,
|
||||
connection_close_barriers: RwLock<HashMap<ConnectionId, barrier::Sender>>,
|
||||
message_handlers: RwLock<Vec<MessageHandler>>,
|
||||
handler_types: Mutex<HashSet<TypeId>>,
|
||||
next_connection_id: AtomicU32,
|
||||
|
@ -62,6 +65,7 @@ impl Peer {
|
|||
pub fn new() -> Arc<Self> {
|
||||
Arc::new(Self {
|
||||
connections: Default::default(),
|
||||
connection_close_barriers: Default::default(),
|
||||
message_handlers: Default::default(),
|
||||
handler_types: Default::default(),
|
||||
next_connection_id: Default::default(),
|
||||
|
@ -102,10 +106,7 @@ impl Peer {
|
|||
rx
|
||||
}
|
||||
|
||||
pub async fn add_connection<Conn>(
|
||||
self: &Arc<Self>,
|
||||
conn: Conn,
|
||||
) -> (ConnectionId, impl Future<Output = Result<()>>)
|
||||
pub async fn add_connection<Conn>(self: &Arc<Self>, conn: Conn) -> ConnectionId
|
||||
where
|
||||
Conn: Clone + AsyncRead + AsyncWrite + Unpin + Send + 'static,
|
||||
{
|
||||
|
@ -113,26 +114,44 @@ impl Peer {
|
|||
self.next_connection_id
|
||||
.fetch_add(1, atomic::Ordering::SeqCst),
|
||||
);
|
||||
let (close_tx, mut close_rx) = barrier::channel();
|
||||
let connection = Arc::new(Connection {
|
||||
writer: Mutex::new(MessageStream::new(Box::pin(conn.clone()))),
|
||||
response_channels: Default::default(),
|
||||
next_message_id: Default::default(),
|
||||
});
|
||||
self.connections.write().await.insert(
|
||||
connection_id,
|
||||
Arc::new(Connection {
|
||||
reader: Mutex::new(MessageStream::new(Box::pin(conn.clone()))),
|
||||
writer: Mutex::new(MessageStream::new(Box::pin(conn.clone()))),
|
||||
response_channels: Default::default(),
|
||||
next_message_id: Default::default(),
|
||||
}),
|
||||
);
|
||||
connection_id
|
||||
}
|
||||
|
||||
self.connections
|
||||
pub async fn disconnect(&self, connection_id: ConnectionId) {
|
||||
self.connections.write().await.remove(&connection_id);
|
||||
self.connection_close_barriers
|
||||
.write()
|
||||
.await
|
||||
.insert(connection_id, (connection.clone(), close_tx));
|
||||
.remove(&connection_id);
|
||||
}
|
||||
|
||||
pub fn handle_messages(
|
||||
self: &Arc<Self>,
|
||||
connection_id: ConnectionId,
|
||||
) -> impl Future<Output = Result<()>> + 'static {
|
||||
let (close_tx, mut close_rx) = barrier::channel();
|
||||
let this = self.clone();
|
||||
let handler_future = async move {
|
||||
async move {
|
||||
this.connection_close_barriers
|
||||
.write()
|
||||
.await
|
||||
.insert(connection_id, close_tx);
|
||||
let connection = this.connection(connection_id).await?;
|
||||
let closed = close_rx.recv();
|
||||
futures::pin_mut!(closed);
|
||||
|
||||
let mut stream = MessageStream::new(conn);
|
||||
loop {
|
||||
let read_message = stream.read_message();
|
||||
let mut reader = connection.reader.lock().await;
|
||||
let read_message = reader.read_message();
|
||||
futures::pin_mut!(read_message);
|
||||
|
||||
match futures::future::select(read_message, &mut closed).await {
|
||||
|
@ -181,13 +200,23 @@ impl Peer {
|
|||
Either::Right(_) => return Ok(()),
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
(connection_id, handler_future)
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn disconnect(&self, connection_id: ConnectionId) {
|
||||
self.connections.write().await.remove(&connection_id);
|
||||
pub async fn receive<M: EnvelopedMessage>(
|
||||
self: &Arc<Self>,
|
||||
connection_id: ConnectionId,
|
||||
) -> Result<TypedEnvelope<M>> {
|
||||
let connection = self.connection(connection_id).await?;
|
||||
let envelope = connection.reader.lock().await.read_message().await?;
|
||||
let id = envelope.id;
|
||||
let payload =
|
||||
M::from_envelope(envelope).ok_or_else(|| anyhow!("unexpected message type"))?;
|
||||
Ok(TypedEnvelope {
|
||||
id,
|
||||
connection_id,
|
||||
payload,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn request<T: RequestMessage>(
|
||||
|
@ -271,7 +300,6 @@ impl Peer {
|
|||
.await
|
||||
.get(&id)
|
||||
.ok_or_else(|| anyhow!("unknown connection: {}", id.0))?
|
||||
.0
|
||||
.clone())
|
||||
}
|
||||
}
|
||||
|
@ -298,22 +326,22 @@ mod tests {
|
|||
let server = Peer::new();
|
||||
let client1 = Peer::new();
|
||||
let client2 = Peer::new();
|
||||
let (client1_conn_id, f1) = client1
|
||||
let client1_conn_id = client1
|
||||
.add_connection(UnixStream::connect(&socket_path).await.unwrap())
|
||||
.await;
|
||||
let (client2_conn_id, f2) = client2
|
||||
let client2_conn_id = client2
|
||||
.add_connection(UnixStream::connect(&socket_path).await.unwrap())
|
||||
.await;
|
||||
let (_, f3) = server
|
||||
let server_conn_id1 = server
|
||||
.add_connection(listener.accept().await.unwrap().0)
|
||||
.await;
|
||||
let (_, f4) = server
|
||||
let server_conn_id2 = server
|
||||
.add_connection(listener.accept().await.unwrap().0)
|
||||
.await;
|
||||
smol::spawn(f1).detach();
|
||||
smol::spawn(f2).detach();
|
||||
smol::spawn(f3).detach();
|
||||
smol::spawn(f4).detach();
|
||||
smol::spawn(client1.handle_messages(client1_conn_id)).detach();
|
||||
smol::spawn(client2.handle_messages(client2_conn_id)).detach();
|
||||
smol::spawn(server.handle_messages(server_conn_id1)).detach();
|
||||
smol::spawn(server.handle_messages(server_conn_id2)).detach();
|
||||
|
||||
// define the expected requests and responses
|
||||
let request1 = proto::OpenWorktree {
|
||||
|
@ -428,21 +456,21 @@ mod tests {
|
|||
let (mut server_conn, _) = listener.accept().await.unwrap();
|
||||
|
||||
let client = Peer::new();
|
||||
let (connection_id, handler) = client.add_connection(client_conn).await;
|
||||
smol::spawn(handler).detach();
|
||||
let connection_id = client.add_connection(client_conn).await;
|
||||
let (mut incoming_messages_ended_tx, mut incoming_messages_ended_rx) =
|
||||
barrier::channel();
|
||||
let handle_messages = client.handle_messages(connection_id);
|
||||
smol::spawn(async move {
|
||||
handle_messages.await.unwrap();
|
||||
incoming_messages_ended_tx.send(()).await.unwrap();
|
||||
})
|
||||
.detach();
|
||||
client.disconnect(connection_id).await;
|
||||
|
||||
// Try sending an empty payload over and over, until the client is dropped and hangs up.
|
||||
loop {
|
||||
match server_conn.write(&[]).await {
|
||||
Ok(_) => {}
|
||||
Err(err) => {
|
||||
if err.kind() == io::ErrorKind::BrokenPipe {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
incoming_messages_ended_rx.recv().await;
|
||||
|
||||
let err = server_conn.write(&[]).await.unwrap_err();
|
||||
assert_eq!(err.kind(), io::ErrorKind::BrokenPipe);
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -456,8 +484,8 @@ mod tests {
|
|||
client_conn.close().await.unwrap();
|
||||
|
||||
let client = Peer::new();
|
||||
let (connection_id, handler) = client.add_connection(client_conn).await;
|
||||
smol::spawn(handler).detach();
|
||||
let connection_id = client.add_connection(client_conn).await;
|
||||
smol::spawn(client.handle_messages(connection_id)).detach();
|
||||
|
||||
let err = client
|
||||
.request(
|
||||
|
|
|
@ -691,8 +691,8 @@ impl Workspace {
|
|||
// a TLS stream using `native-tls`.
|
||||
let stream = smol::net::TcpStream::connect(rpc_address).await?;
|
||||
|
||||
let (connection_id, handler) = rpc.add_connection(stream).await;
|
||||
executor.spawn(handler).detach();
|
||||
let connection_id = rpc.add_connection(stream).await;
|
||||
executor.spawn(rpc.handle_messages(connection_id)).detach();
|
||||
|
||||
let auth_response = rpc
|
||||
.request(
|
||||
|
|
Loading…
Reference in a new issue