From 9336c02867335a0dfc16650708da2b5dfbf87a4e Mon Sep 17 00:00:00 2001 From: Nathan Sobo Date: Wed, 18 Aug 2021 21:59:13 -0600 Subject: [PATCH] Start on a peer2 module with an alternative implementation --- zrpc/proto/zed.proto | 58 +++--- zrpc/src/lib.rs | 1 + zrpc/src/peer.rs | 6 +- zrpc/src/peer2.rs | 470 +++++++++++++++++++++++++++++++++++++++++++ zrpc/src/proto.rs | 20 +- 5 files changed, 520 insertions(+), 35 deletions(-) create mode 100644 zrpc/src/peer2.rs diff --git a/zrpc/proto/zed.proto b/zrpc/proto/zed.proto index 7d8e3ff742..3a0b7aabb6 100644 --- a/zrpc/proto/zed.proto +++ b/zrpc/proto/zed.proto @@ -6,35 +6,45 @@ message Envelope { optional uint32 responding_to = 2; optional uint32 original_sender_id = 3; oneof payload { - Auth auth = 4; - AuthResponse auth_response = 5; - ShareWorktree share_worktree = 6; - ShareWorktreeResponse share_worktree_response = 7; - OpenWorktree open_worktree = 8; - OpenWorktreeResponse open_worktree_response = 9; - UpdateWorktree update_worktree = 10; - CloseWorktree close_worktree = 11; - OpenBuffer open_buffer = 12; - OpenBufferResponse open_buffer_response = 13; - CloseBuffer close_buffer = 14; - UpdateBuffer update_buffer = 15; - SaveBuffer save_buffer = 16; - BufferSaved buffer_saved = 17; - AddPeer add_peer = 18; - RemovePeer remove_peer = 19; - GetChannels get_channels = 20; - GetChannelsResponse get_channels_response = 21; - GetUsers get_users = 22; - GetUsersResponse get_users_response = 23; - JoinChannel join_channel = 24; - JoinChannelResponse join_channel_response = 25; - SendChannelMessage send_channel_message = 26; - ChannelMessageSent channel_message_sent = 27; + Ping ping = 4; + Pong pong = 5; + Auth auth = 6; + AuthResponse auth_response = 7; + ShareWorktree share_worktree = 8; + ShareWorktreeResponse share_worktree_response = 9; + OpenWorktree open_worktree = 10; + OpenWorktreeResponse open_worktree_response = 11; + UpdateWorktree update_worktree = 12; + CloseWorktree close_worktree = 13; + OpenBuffer open_buffer = 14; + OpenBufferResponse open_buffer_response = 15; + CloseBuffer close_buffer = 16; + UpdateBuffer update_buffer = 17; + SaveBuffer save_buffer = 18; + BufferSaved buffer_saved = 19; + AddPeer add_peer = 20; + RemovePeer remove_peer = 21; + GetChannels get_channels = 22; + GetChannelsResponse get_channels_response = 23; + GetUsers get_users = 24; + GetUsersResponse get_users_response = 25; + JoinChannel join_channel = 26; + JoinChannelResponse join_channel_response = 27; + SendChannelMessage send_channel_message = 28; + ChannelMessageSent channel_message_sent = 29; } } // Messages +message Ping { + int32 id = 1; +} + +message Pong { + int32 id = 2; +} + message Auth { int32 user_id = 1; string access_token = 2; diff --git a/zrpc/src/lib.rs b/zrpc/src/lib.rs index 8cafad9f1f..be3625e51f 100644 --- a/zrpc/src/lib.rs +++ b/zrpc/src/lib.rs @@ -1,5 +1,6 @@ pub mod auth; mod peer; +mod peer2; pub mod proto; #[cfg(any(test, feature = "test-support"))] pub mod test; diff --git a/zrpc/src/peer.rs b/zrpc/src/peer.rs index c377ad1309..d0dcf836a5 100644 --- a/zrpc/src/peer.rs +++ b/zrpc/src/peer.rs @@ -38,8 +38,8 @@ type ForegroundMessageHandler = Box, ConnectionId) -> Option>>; pub struct Receipt { - sender_id: ConnectionId, - message_id: u32, + pub sender_id: ConnectionId, + pub message_id: u32, payload_type: PhantomData, } @@ -172,7 +172,7 @@ impl Peer { } else { router.handle(connection_id, envelope.clone()).await; if let Some(envelope) = proto::build_typed_envelope(connection_id, envelope) { - broadcast_incoming_messages.send(envelope).await.ok(); + broadcast_incoming_messages.send(Arc::from(envelope)).await.ok(); } else { log::error!("unable to construct a typed envelope"); } diff --git a/zrpc/src/peer2.rs b/zrpc/src/peer2.rs new file mode 100644 index 0000000000..7ead744bdd --- /dev/null +++ b/zrpc/src/peer2.rs @@ -0,0 +1,470 @@ +use crate::{ + proto::{self, EnvelopedMessage, MessageStream, RequestMessage}, + ConnectionId, PeerId, Receipt, +}; +use anyhow::{anyhow, Context, Result}; +use async_lock::{Mutex, RwLock}; +use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage}; +use futures::{FutureExt, StreamExt}; +use postage::{ + mpsc, + prelude::{Sink as _, Stream as _}, +}; +use std::{ + any::Any, + collections::HashMap, + future::Future, + sync::{ + atomic::{self, AtomicU32}, + Arc, + }, +}; + +pub struct Peer { + connections: RwLock>, + next_connection_id: AtomicU32, +} + +#[derive(Clone)] +struct Connection { + outgoing_tx: mpsc::Sender, + next_message_id: Arc, + response_channels: Arc>>>, +} + +impl Peer { + pub fn new() -> Arc { + Arc::new(Self { + connections: Default::default(), + next_connection_id: Default::default(), + }) + } + + pub async fn add_connection( + self: &Arc, + conn: Conn, + ) -> ( + ConnectionId, + impl Future> + Send, + mpsc::Receiver>, + ) + where + Conn: futures::Sink + + futures::Stream> + + Send + + Unpin, + { + let (tx, rx) = conn.split(); + let connection_id = ConnectionId( + self.next_connection_id + .fetch_add(1, atomic::Ordering::SeqCst), + ); + let (mut incoming_tx, incoming_rx) = mpsc::channel(64); + let (outgoing_tx, mut outgoing_rx) = mpsc::channel(64); + let connection = Connection { + outgoing_tx, + next_message_id: Default::default(), + response_channels: Default::default(), + }; + let mut writer = MessageStream::new(tx); + let mut reader = MessageStream::new(rx); + + let response_channels = connection.response_channels.clone(); + let handle_io = async move { + loop { + let read_message = reader.read_message().fuse(); + futures::pin_mut!(read_message); + loop { + futures::select_biased! { + incoming = read_message => match incoming { + Ok(incoming) => { + if let Some(responding_to) = incoming.responding_to { + let channel = response_channels.lock().await.remove(&responding_to); + if let Some(mut tx) = channel { + tx.send(incoming).await.ok(); + } else { + log::warn!("received RPC response to unknown request {}", responding_to); + } + } else { + if let Some(envelope) = proto::build_typed_envelope(connection_id, incoming) { + if incoming_tx.send(envelope).await.is_err() { + response_channels.lock().await.clear(); + return Ok(()) + } + } else { + log::error!("unable to construct a typed envelope"); + } + } + + break; + } + Err(error) => { + response_channels.lock().await.clear(); + Err(error).context("received invalid RPC message")?; + } + }, + outgoing = outgoing_rx.recv().fuse() => match outgoing { + Some(outgoing) => { + if let Err(result) = writer.write_message(&outgoing).await { + response_channels.lock().await.clear(); + Err(result).context("failed to write RPC message")?; + } + } + None => { + response_channels.lock().await.clear(); + return Ok(()) + } + } + } + } + } + }; + + self.connections + .write() + .await + .insert(connection_id, connection); + + (connection_id, handle_io, incoming_rx) + } + + pub async fn disconnect(&self, connection_id: ConnectionId) { + self.connections.write().await.remove(&connection_id); + } + + pub async fn reset(&self) { + self.connections.write().await.clear(); + } + + pub fn request( + self: &Arc, + receiver_id: ConnectionId, + request: T, + ) -> impl Future> { + self.request_internal(None, receiver_id, request) + } + + pub fn forward_request( + self: &Arc, + sender_id: ConnectionId, + receiver_id: ConnectionId, + request: T, + ) -> impl Future> { + self.request_internal(Some(sender_id), receiver_id, request) + } + + pub fn request_internal( + self: &Arc, + original_sender_id: Option, + receiver_id: ConnectionId, + request: T, + ) -> impl Future> { + let this = self.clone(); + let (tx, mut rx) = mpsc::channel(1); + async move { + let mut connection = this.connection(receiver_id).await?; + let message_id = connection + .next_message_id + .fetch_add(1, atomic::Ordering::SeqCst); + connection + .response_channels + .lock() + .await + .insert(message_id, tx); + connection + .outgoing_tx + .send(request.into_envelope(message_id, None, original_sender_id.map(|id| id.0))) + .await + .map_err(|_| anyhow!("connection was closed"))?; + let response = rx + .recv() + .await + .ok_or_else(|| anyhow!("connection was closed"))?; + T::Response::from_envelope(response) + .ok_or_else(|| anyhow!("received response of the wrong type")) + } + } + + pub fn send( + self: &Arc, + receiver_id: ConnectionId, + message: T, + ) -> impl Future> { + let this = self.clone(); + async move { + let mut connection = this.connection(receiver_id).await?; + let message_id = connection + .next_message_id + .fetch_add(1, atomic::Ordering::SeqCst); + connection + .outgoing_tx + .send(message.into_envelope(message_id, None, None)) + .await?; + Ok(()) + } + } + + pub fn forward_send( + self: &Arc, + sender_id: ConnectionId, + receiver_id: ConnectionId, + message: T, + ) -> impl Future> { + let this = self.clone(); + async move { + let mut connection = this.connection(receiver_id).await?; + let message_id = connection + .next_message_id + .fetch_add(1, atomic::Ordering::SeqCst); + connection + .outgoing_tx + .send(message.into_envelope(message_id, None, Some(sender_id.0))) + .await?; + Ok(()) + } + } + + pub fn respond( + self: &Arc, + receipt: Receipt, + response: T::Response, + ) -> impl Future> { + let this = self.clone(); + async move { + let mut connection = this.connection(receipt.sender_id).await?; + let message_id = connection + .next_message_id + .fetch_add(1, atomic::Ordering::SeqCst); + connection + .outgoing_tx + .send(response.into_envelope(message_id, Some(receipt.message_id), None)) + .await?; + Ok(()) + } + } + + fn connection( + self: &Arc, + connection_id: ConnectionId, + ) -> impl Future> { + let this = self.clone(); + async move { + let connections = this.connections.read().await; + let connection = connections + .get(&connection_id) + .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?; + Ok(connection.clone()) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{test, TypedEnvelope}; + + #[test] + fn test_request_response() { + smol::block_on(async move { + // create 2 clients connected to 1 server + let server = Peer::new(); + let client1 = Peer::new(); + let client2 = Peer::new(); + + let (client1_to_server_conn, server_to_client_1_conn) = test::Channel::bidirectional(); + let (client1_conn_id, io_task1, _) = + client1.add_connection(client1_to_server_conn).await; + let (_, io_task2, incoming1) = server.add_connection(server_to_client_1_conn).await; + + let (client2_to_server_conn, server_to_client_2_conn) = test::Channel::bidirectional(); + let (client2_conn_id, io_task3, _) = + client2.add_connection(client2_to_server_conn).await; + let (_, io_task4, incoming2) = server.add_connection(server_to_client_2_conn).await; + + smol::spawn(io_task1).detach(); + smol::spawn(io_task2).detach(); + smol::spawn(io_task3).detach(); + smol::spawn(io_task4).detach(); + smol::spawn(handle_messages(incoming1, server.clone())).detach(); + smol::spawn(handle_messages(incoming2, server.clone())).detach(); + + assert_eq!( + client1 + .request(client1_conn_id, proto::Ping { id: 1 },) + .await + .unwrap(), + proto::Pong { id: 1 } + ); + + assert_eq!( + client2 + .request(client2_conn_id, proto::Ping { id: 2 },) + .await + .unwrap(), + proto::Pong { id: 2 } + ); + + assert_eq!( + client1 + .request( + client1_conn_id, + proto::OpenBuffer { + worktree_id: 1, + path: "path/one".to_string(), + }, + ) + .await + .unwrap(), + proto::OpenBufferResponse { + buffer: Some(proto::Buffer { + id: 101, + content: "path/one content".to_string(), + history: vec![], + selections: vec![], + }), + } + ); + + assert_eq!( + client2 + .request( + client2_conn_id, + proto::OpenBuffer { + worktree_id: 2, + path: "path/two".to_string(), + }, + ) + .await + .unwrap(), + proto::OpenBufferResponse { + buffer: Some(proto::Buffer { + id: 102, + content: "path/two content".to_string(), + history: vec![], + selections: vec![], + }), + } + ); + + client1.disconnect(client1_conn_id).await; + client2.disconnect(client1_conn_id).await; + + async fn handle_messages( + mut messages: mpsc::Receiver>, + peer: Arc, + ) -> Result<()> { + while let Some(envelope) = messages.next().await { + if let Some(envelope) = envelope.downcast_ref::>() { + let receipt = envelope.receipt(); + peer.respond( + receipt, + proto::Pong { + id: envelope.payload.id, + }, + ) + .await? + } else if let Some(envelope) = + envelope.downcast_ref::>() + { + let message = &envelope.payload; + let receipt = envelope.receipt(); + let response = match message.path.as_str() { + "path/one" => { + assert_eq!(message.worktree_id, 1); + proto::OpenBufferResponse { + buffer: Some(proto::Buffer { + id: 101, + content: "path/one content".to_string(), + history: vec![], + selections: vec![], + }), + } + } + "path/two" => { + assert_eq!(message.worktree_id, 2); + proto::OpenBufferResponse { + buffer: Some(proto::Buffer { + id: 102, + content: "path/two content".to_string(), + history: vec![], + selections: vec![], + }), + } + } + _ => { + panic!("unexpected path {}", message.path); + } + }; + + peer.respond(receipt, response).await? + } else { + panic!("unknown message type"); + } + } + + Ok(()) + } + }); + } + + #[test] + fn test_disconnect() { + smol::block_on(async move { + let (client_conn, mut server_conn) = test::Channel::bidirectional(); + + let client = Peer::new(); + let (connection_id, io_handler, mut incoming) = + client.add_connection(client_conn).await; + + let (mut io_ended_tx, mut io_ended_rx) = postage::barrier::channel(); + smol::spawn(async move { + io_handler.await.ok(); + io_ended_tx.send(()).await.unwrap(); + }) + .detach(); + + let (mut messages_ended_tx, mut messages_ended_rx) = postage::barrier::channel(); + smol::spawn(async move { + incoming.next().await; + messages_ended_tx.send(()).await.unwrap(); + }) + .detach(); + + client.disconnect(connection_id).await; + + io_ended_rx.recv().await; + messages_ended_rx.recv().await; + assert!( + futures::SinkExt::send(&mut server_conn, WebSocketMessage::Binary(vec![])) + .await + .is_err() + ); + }); + } + + #[test] + fn test_io_error() { + smol::block_on(async move { + let (client_conn, server_conn) = test::Channel::bidirectional(); + drop(server_conn); + + let client = Peer::new(); + let (connection_id, io_handler, mut incoming) = + client.add_connection(client_conn).await; + smol::spawn(io_handler).detach(); + smol::spawn(async move { incoming.next().await }).detach(); + + let err = client + .request( + connection_id, + proto::Auth { + user_id: 42, + access_token: "token".to_string(), + }, + ) + .await + .unwrap_err(); + assert_eq!(err.to_string(), "connection was closed"); + }); + } +} diff --git a/zrpc/src/proto.rs b/zrpc/src/proto.rs index 1c799ebe21..d8c794fd63 100644 --- a/zrpc/src/proto.rs +++ b/zrpc/src/proto.rs @@ -4,7 +4,6 @@ use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSock use futures::{SinkExt as _, StreamExt as _}; use prost::Message; use std::any::Any; -use std::sync::Arc; use std::{ io, time::{Duration, SystemTime, UNIX_EPOCH}, @@ -34,14 +33,16 @@ pub trait RequestMessage: EnvelopedMessage { macro_rules! messages { ($($name:ident),* $(,)?) => { - pub fn build_typed_envelope(sender_id: ConnectionId, envelope: Envelope) -> Option> { + pub fn build_typed_envelope(sender_id: ConnectionId, envelope: Envelope) -> Option> { match envelope.payload { - $(Some(envelope::Payload::$name(payload)) => Some(Arc::new(TypedEnvelope { - sender_id, - original_sender_id: envelope.original_sender_id.map(PeerId), - message_id: envelope.id, - payload, - })), )* + $(Some(envelope::Payload::$name(payload)) => { + Some(Box::new(TypedEnvelope { + sender_id, + original_sender_id: envelope.original_sender_id.map(PeerId), + message_id: envelope.id, + payload, + })) + }, )* _ => None } } @@ -116,6 +117,8 @@ messages!( OpenBufferResponse, OpenWorktree, OpenWorktreeResponse, + Ping, + Pong, RemovePeer, SaveBuffer, SendChannelMessage, @@ -132,6 +135,7 @@ request_messages!( (JoinChannel, JoinChannelResponse), (OpenBuffer, OpenBufferResponse), (OpenWorktree, OpenWorktreeResponse), + (Ping, Pong), (SaveBuffer, BufferSaved), (ShareWorktree, ShareWorktreeResponse), );