diff --git a/server/src/auth.rs b/server/src/auth.rs index ac326b15de..d61428fa37 100644 --- a/server/src/auth.rs +++ b/server/src/auth.rs @@ -137,34 +137,6 @@ impl PeerExt for Peer { } } -#[async_trait] -impl PeerExt for zrpc::peer2::Peer { - async fn sign_out( - self: &Arc, - connection_id: zrpc::ConnectionId, - state: &AppState, - ) -> tide::Result<()> { - self.disconnect(connection_id).await; - let worktree_ids = state.rpc.write().await.remove_connection(connection_id); - for worktree_id in worktree_ids { - let state = state.rpc.read().await; - if let Some(worktree) = state.worktrees.get(&worktree_id) { - rpc::broadcast(connection_id, worktree.connection_ids(), |conn_id| { - self.send( - conn_id, - proto::RemovePeer { - worktree_id, - peer_id: connection_id.0, - }, - ) - }) - .await?; - } - } - Ok(()) - } -} - pub fn build_client(client_id: &str, client_secret: &str) -> Client { Client::new( ClientId::new(client_id.to_string()), diff --git a/server/src/rpc.rs b/server/src/rpc.rs index b7be90d348..e628ef2c81 100644 --- a/server/src/rpc.rs +++ b/server/src/rpc.rs @@ -30,13 +30,15 @@ use time::OffsetDateTime; use zrpc::{ auth::random_token, proto::{self, EnvelopedMessage}, - ConnectionId, Peer, Router, TypedEnvelope, + ConnectionId, Peer, TypedEnvelope, }; type ReplicaId = u16; type Handler = Box< - dyn Fn(&mut Option>, Arc) -> Option>, + dyn Send + + Sync + + Fn(&mut Option>, Arc) -> Option>, >; #[derive(Default)] @@ -48,7 +50,7 @@ struct ServerBuilder { impl ServerBuilder { pub fn on_message(&mut self, handler: F) -> &mut Self where - F: 'static + Fn(Box>, Arc) -> Fut, + F: 'static + Send + Sync + Fn(Box>, Arc) -> Fut, Fut: 'static + Send + Future, M: EnvelopedMessage, { @@ -73,23 +75,23 @@ impl ServerBuilder { self } - pub fn build(self, rpc: Arc, state: Arc) -> Arc { + pub fn build(self, rpc: &Arc, state: &Arc) -> Arc { Arc::new(Server { - rpc, - state, + rpc: rpc.clone(), + state: state.clone(), handlers: self.handlers, }) } } -struct Server { - rpc: Arc, +pub struct Server { + rpc: Arc, state: Arc, handlers: Vec, } impl Server { - pub async fn add_connection( + pub async fn handle_connection( self: &Arc, connection: Conn, addr: String, @@ -332,99 +334,31 @@ impl State { } } -trait MessageHandler<'a, M: proto::EnvelopedMessage> { - type Output: 'a + Send + Future>; - - fn handle( - &self, - message: TypedEnvelope, - rpc: &'a Arc, - app_state: &'a Arc, - ) -> Self::Output; -} - -impl<'a, M, F, Fut> MessageHandler<'a, M> for F -where - M: proto::EnvelopedMessage, - F: Fn(TypedEnvelope, &'a Arc, &'a Arc) -> Fut, - Fut: 'a + Send + Future>, -{ - type Output = Fut; - - fn handle( - &self, - message: TypedEnvelope, - rpc: &'a Arc, - app_state: &'a Arc, - ) -> Self::Output { - (self)(message, rpc, app_state) - } -} - -fn on_message(router: &mut Router, rpc: &Arc, app_state: &Arc, handler: H) -where - M: EnvelopedMessage, - H: 'static + Clone + Send + Sync + for<'a> MessageHandler<'a, M>, -{ - let rpc = rpc.clone(); - let handler = handler.clone(); - let app_state = app_state.clone(); - router.add_message_handler(move |message| { - let rpc = rpc.clone(); - let handler = handler.clone(); - let app_state = app_state.clone(); - async move { - let sender_id = message.sender_id; - let message_id = message.message_id; - let start_time = Instant::now(); - log::info!( - "RPC message received. id: {}.{}, type:{}", - sender_id, - message_id, - M::NAME - ); - if let Err(err) = handler.handle(message, &rpc, &app_state).await { - log::error!("error handling message: {:?}", err); - } else { - log::info!( - "RPC message handled. id:{}.{}, duration:{:?}", - sender_id, - message_id, - start_time.elapsed() - ); - } - - Ok(()) - } - }); -} - -pub fn add_rpc_routes(router: &mut Router, state: &Arc, rpc: &Arc) { - on_message(router, rpc, state, share_worktree); - on_message(router, rpc, state, join_worktree); - on_message(router, rpc, state, update_worktree); - on_message(router, rpc, state, close_worktree); - on_message(router, rpc, state, open_buffer); - on_message(router, rpc, state, close_buffer); - on_message(router, rpc, state, update_buffer); - on_message(router, rpc, state, buffer_saved); - on_message(router, rpc, state, save_buffer); - on_message(router, rpc, state, get_channels); - on_message(router, rpc, state, get_users); - on_message(router, rpc, state, join_channel); - on_message(router, rpc, state, send_channel_message); +pub fn build_server(state: &Arc, rpc: &Arc) -> Arc { + ServerBuilder::default() + // .on_message(share_worktree) + // .on_message(join_worktree) + // .on_message(update_worktree) + // .on_message(close_worktree) + // .on_message(open_buffer) + // .on_message(close_buffer) + // .on_message(update_buffer) + // .on_message(buffer_saved) + // .on_message(save_buffer) + // .on_message(get_channels) + // .on_message(get_users) + // .on_message(join_channel) + // .on_message(send_channel_message) + .build(rpc, state) } pub fn add_routes(app: &mut tide::Server>, rpc: &Arc) { - let mut router = Router::new(); - add_rpc_routes(&mut router, app.state(), rpc); - let router = Arc::new(router); + let server = build_server(app.state(), rpc); let rpc = rpc.clone(); app.at("/rpc").with(auth::VerifyToken).get(move |request: Request>| { let user_id = request.ext::().copied(); - let rpc = rpc.clone(); - let router = router.clone(); + let server = server.clone(); async move { const WEBSOCKET_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; @@ -451,12 +385,11 @@ pub fn add_routes(app: &mut tide::Server>, rpc: &Arc) { let http_res: &mut tide::http::Response = response.as_mut(); let upgrade_receiver = http_res.recv_upgrade().await; let addr = request.remote().unwrap_or("unknown").to_string(); - let state = request.state().clone(); let user_id = user_id.ok_or_else(|| anyhow!("user_id is not present on request. ensure auth::VerifyToken middleware is present"))?; task::spawn(async move { if let Some(stream) = upgrade_receiver.await { let stream = WebSocketStream::from_raw_socket(stream, Role::Server, None).await; - handle_connection(rpc, router, state, addr, stream, user_id).await; + server.handle_connection(stream, addr, user_id).await; } }); @@ -465,43 +398,6 @@ pub fn add_routes(app: &mut tide::Server>, rpc: &Arc) { }); } -pub async fn handle_connection( - rpc: Arc, - router: Arc, - state: Arc, - addr: String, - stream: Conn, - user_id: UserId, -) where - Conn: 'static - + futures::Sink - + futures::Stream> - + Send - + Unpin, -{ - log::info!("accepted rpc connection: {:?}", addr); - let (connection_id, handle_io, handle_messages) = rpc.add_connection(stream, router).await; - state - .rpc - .write() - .await - .add_connection(connection_id, user_id); - - let handle_messages = async move { - handle_messages.await; - Ok(()) - }; - - if let Err(e) = futures::try_join!(handle_messages, handle_io) { - log::error!("error handling rpc connection {:?} - {:?}", addr, e); - } - - log::info!("closing connection to {:?}", addr); - if let Err(e) = rpc.sign_out(connection_id, &state).await { - log::error!("error signing out connection {:?} - {:?}", addr, e); - } -} - async fn share_worktree( mut request: TypedEnvelope, rpc: &Arc, diff --git a/server/src/tests.rs b/server/src/tests.rs index cce311c5f3..18a7375f86 100644 --- a/server/src/tests.rs +++ b/server/src/tests.rs @@ -2,7 +2,7 @@ use crate::{ auth, db::{self, UserId}, github, - rpc::{self, add_rpc_routes}, + rpc::{self, build_server}, AppState, Config, }; use async_std::task; @@ -24,7 +24,7 @@ use zed::{ test::Channel, worktree::Worktree, }; -use zrpc::{ForegroundRouter, Peer, Router}; +use zrpc::Peer; #[gpui::test] async fn test_share_worktree(mut cx_a: TestAppContext, mut cx_b: TestAppContext) { @@ -541,7 +541,7 @@ impl TestServer { let app_state = Self::build_app_state(&db_name).await; let peer = Peer::new(); let mut router = Router::new(); - add_rpc_routes(&mut router, &app_state, &peer); + build_server(&mut router, &app_state, &peer); Self { peer, router: Arc::new(router), diff --git a/zed/src/lib.rs b/zed/src/lib.rs index f283584c0e..90e68c698d 100644 --- a/zed/src/lib.rs +++ b/zed/src/lib.rs @@ -24,14 +24,12 @@ pub use settings::Settings; use parking_lot::Mutex; use postage::watch; use std::sync::Arc; -use zrpc::ForegroundRouter; pub struct AppState { pub settings_tx: Arc>>, pub settings: watch::Receiver, pub languages: Arc, pub themes: Arc, - pub rpc_router: Arc, pub rpc: rpc::Client, pub fs: Arc, } diff --git a/zed/src/rpc.rs b/zed/src/rpc.rs index 00f5e2dd74..2bd9fd7f81 100644 --- a/zed/src/rpc.rs +++ b/zed/src/rpc.rs @@ -13,7 +13,7 @@ use zrpc::proto::EntityMessage; pub use zrpc::{proto, ConnectionId, PeerId, TypedEnvelope}; use zrpc::{ proto::{EnvelopedMessage, RequestMessage}, - ForegroundRouter, Peer, Receipt, + Peer, Receipt, }; lazy_static! { @@ -43,25 +43,6 @@ impl Client { } } - pub fn on_message( - &self, - router: &mut ForegroundRouter, - handler: H, - cx: &mut gpui::MutableAppContext, - ) where - H: 'static + Clone + for<'a> MessageHandler<'a, M>, - M: proto::EnvelopedMessage, - { - let this = self.clone(); - let cx = cx.to_async(); - router.add_message_handler(move |message| { - let this = this.clone(); - let mut cx = cx.clone(); - let handler = handler.clone(); - async move { handler.handle(message, &this, &mut cx).await } - }); - } - pub fn subscribe_from_model( &self, remote_id: u64, @@ -90,11 +71,7 @@ impl Client { }) } - pub async fn log_in_and_connect( - &self, - router: Arc, - cx: AsyncAppContext, - ) -> surf::Result<()> { + pub async fn log_in_and_connect(&self, cx: AsyncAppContext) -> surf::Result<()> { if self.state.read().await.connection_id.is_some() { return Ok(()); } @@ -111,13 +88,13 @@ impl Client { .await .context("websocket handshake")?; log::info!("connected to rpc address {}", *ZED_SERVER_URL); - self.add_connection(stream, router, cx).await?; + self.add_connection(stream, cx).await?; } else if let Some(host) = ZED_SERVER_URL.strip_prefix("http://") { let stream = smol::net::TcpStream::connect(host).await?; let request = request.uri(format!("ws://{}/rpc", host)).body(())?; let (stream, _) = async_tungstenite::client_async(request, stream).await?; log::info!("connected to rpc address {}", *ZED_SERVER_URL); - self.add_connection(stream, router, cx).await?; + self.add_connection(stream, cx).await?; } else { return Err(anyhow!("invalid server url: {}", *ZED_SERVER_URL))?; }; @@ -125,12 +102,7 @@ impl Client { Ok(()) } - pub async fn add_connection( - &self, - conn: Conn, - router: Arc, - cx: AsyncAppContext, - ) -> surf::Result<()> + pub async fn add_connection(&self, conn: Conn, cx: AsyncAppContext) -> surf::Result<()> where Conn: 'static + futures::Sink @@ -138,8 +110,7 @@ impl Client { + Unpin + Send, { - let (connection_id, handle_io, handle_messages) = - self.peer.add_connection(conn, router).await; + let (connection_id, handle_io, handle_messages) = self.peer.add_connection(conn).await; cx.foreground().spawn(handle_messages).detach(); cx.background() .spawn(async move { diff --git a/zed/src/test.rs b/zed/src/test.rs index e2ddff054e..ce22369126 100644 --- a/zed/src/test.rs +++ b/zed/src/test.rs @@ -15,7 +15,6 @@ use std::{ sync::Arc, }; use tempdir::TempDir; -use zrpc::ForegroundRouter; #[cfg(feature = "test-support")] pub use zrpc::test::Channel; @@ -163,7 +162,6 @@ pub fn build_app_state(cx: &AppContext) -> Arc { settings, themes, languages: languages.clone(), - rpc_router: Arc::new(ForegroundRouter::new()), rpc: rpc::Client::new(languages), fs: Arc::new(RealFs), }) diff --git a/zed/src/workspace.rs b/zed/src/workspace.rs index cb543522d8..e074d57a64 100644 --- a/zed/src/workspace.rs +++ b/zed/src/workspace.rs @@ -728,10 +728,9 @@ impl Workspace { fn share_worktree(&mut self, app_state: &Arc, cx: &mut ViewContext) { let rpc = self.rpc.clone(); let platform = cx.platform(); - let router = app_state.rpc_router.clone(); let task = cx.spawn(|this, mut cx| async move { - rpc.log_in_and_connect(router, cx.clone()).await?; + rpc.log_in_and_connect(cx.clone()).await?; let share_task = this.update(&mut cx, |this, cx| { let worktree = this.worktrees.iter().next()?; @@ -761,10 +760,9 @@ impl Workspace { fn join_worktree(&mut self, app_state: &Arc, cx: &mut ViewContext) { let rpc = self.rpc.clone(); let languages = self.languages.clone(); - let router = app_state.rpc_router.clone(); let task = cx.spawn(|this, mut cx| async move { - rpc.log_in_and_connect(router, cx.clone()).await?; + rpc.log_in_and_connect(cx.clone()).await?; let worktree_url = cx .platform() diff --git a/zrpc/src/lib.rs b/zrpc/src/lib.rs index 67132cf299..8cafad9f1f 100644 --- a/zrpc/src/lib.rs +++ b/zrpc/src/lib.rs @@ -1,6 +1,5 @@ pub mod auth; mod peer; -pub mod peer2; pub mod proto; #[cfg(any(test, feature = "test-support"))] pub mod test; diff --git a/zrpc/src/peer.rs b/zrpc/src/peer.rs index d0dcf836a5..315f56b316 100644 --- a/zrpc/src/peer.rs +++ b/zrpc/src/peer.rs @@ -2,17 +2,14 @@ use crate::proto::{self, EnvelopedMessage, MessageStream, RequestMessage}; use anyhow::{anyhow, Context, Result}; use async_lock::{Mutex, RwLock}; use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage}; -use futures::{ - future::{self, BoxFuture, LocalBoxFuture}, - FutureExt, Stream, StreamExt, -}; +use futures::{FutureExt, StreamExt}; use postage::{ - broadcast, mpsc, + mpsc, prelude::{Sink as _, Stream as _}, }; use std::{ - any::{Any, TypeId}, - collections::{HashMap, HashSet}, + any::Any, + collections::HashMap, fmt, future::Future, marker::PhantomData, @@ -25,17 +22,20 @@ use std::{ #[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)] pub struct ConnectionId(pub u32); +impl fmt::Display for ConnectionId { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0.fmt(f) + } +} + #[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)] pub struct PeerId(pub u32); -type MessageHandler = Box< - dyn Send - + Sync - + Fn(&mut Option, ConnectionId) -> Option>, ->; - -type ForegroundMessageHandler = - Box, ConnectionId) -> Option>>; +impl fmt::Display for PeerId { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0.fmt(f) + } +} pub struct Receipt { pub sender_id: ConnectionId, @@ -43,6 +43,18 @@ pub struct Receipt { payload_type: PhantomData, } +impl Clone for Receipt { + fn clone(&self) -> Self { + Self { + sender_id: self.sender_id, + message_id: self.message_id, + payload_type: PhantomData, + } + } +} + +impl Copy for Receipt {} + pub struct TypedEnvelope { pub sender_id: ConnectionId, pub original_sender_id: Option, @@ -67,17 +79,9 @@ impl TypedEnvelope { } } -pub type Router = RouterInternal; -pub type ForegroundRouter = RouterInternal; -pub struct RouterInternal { - message_handlers: Vec, - handler_types: HashSet, -} - pub struct Peer { connections: RwLock>, next_connection_id: AtomicU32, - incoming_messages: broadcast::Sender>, } #[derive(Clone)] @@ -92,22 +96,18 @@ impl Peer { Arc::new(Self { connections: Default::default(), next_connection_id: Default::default(), - incoming_messages: broadcast::channel(256).0, }) } - pub async fn add_connection( + pub async fn add_connection( self: &Arc, conn: Conn, - router: Arc>, ) -> ( ConnectionId, impl Future> + Send, - impl Future, + mpsc::Receiver>, ) where - H: Fn(&mut Option, ConnectionId) -> Option, - Fut: Future, Conn: futures::Sink + futures::Stream> + Send @@ -118,7 +118,7 @@ impl Peer { self.next_connection_id .fetch_add(1, atomic::Ordering::SeqCst), ); - let (mut incoming_tx, mut incoming_rx) = mpsc::channel(64); + let (mut incoming_tx, incoming_rx) = mpsc::channel(64); let (outgoing_tx, mut outgoing_rx) = mpsc::channel(64); let connection = Connection { outgoing_tx, @@ -128,6 +128,7 @@ impl Peer { let mut writer = MessageStream::new(tx); let mut reader = MessageStream::new(rx); + let response_channels = connection.response_channels.clone(); let handle_io = async move { loop { let read_message = reader.read_message().fuse(); @@ -136,57 +137,54 @@ impl Peer { futures::select_biased! { incoming = read_message => match incoming { Ok(incoming) => { - if incoming_tx.send(incoming).await.is_err() { - return Ok(()); + if let Some(responding_to) = incoming.responding_to { + let channel = response_channels.lock().await.remove(&responding_to); + if let Some(mut tx) = channel { + tx.send(incoming).await.ok(); + } else { + log::warn!("received RPC response to unknown request {}", responding_to); + } + } else { + if let Some(envelope) = proto::build_typed_envelope(connection_id, incoming) { + if incoming_tx.send(envelope).await.is_err() { + response_channels.lock().await.clear(); + return Ok(()) + } + } else { + log::error!("unable to construct a typed envelope"); + } } + break; } Err(error) => { + response_channels.lock().await.clear(); Err(error).context("received invalid RPC message")?; } }, outgoing = outgoing_rx.recv().fuse() => match outgoing { Some(outgoing) => { if let Err(result) = writer.write_message(&outgoing).await { + response_channels.lock().await.clear(); Err(result).context("failed to write RPC message")?; } } - None => return Ok(()), + None => { + response_channels.lock().await.clear(); + return Ok(()) + } } } } } }; - let mut broadcast_incoming_messages = self.incoming_messages.clone(); - let response_channels = connection.response_channels.clone(); - let handle_messages = async move { - while let Some(envelope) = incoming_rx.recv().await { - if let Some(responding_to) = envelope.responding_to { - let channel = response_channels.lock().await.remove(&responding_to); - if let Some(mut tx) = channel { - tx.send(envelope).await.ok(); - } else { - log::warn!("received RPC response to unknown request {}", responding_to); - } - } else { - router.handle(connection_id, envelope.clone()).await; - if let Some(envelope) = proto::build_typed_envelope(connection_id, envelope) { - broadcast_incoming_messages.send(Arc::from(envelope)).await.ok(); - } else { - log::error!("unable to construct a typed envelope"); - } - } - } - response_channels.lock().await.clear(); - }; - self.connections .write() .await .insert(connection_id, connection); - (connection_id, handle_io, handle_messages) + (connection_id, handle_io, incoming_rx) } pub async fn disconnect(&self, connection_id: ConnectionId) { @@ -197,12 +195,6 @@ impl Peer { self.connections.write().await.clear(); } - pub fn subscribe(&self) -> impl Stream>> { - self.incoming_messages - .subscribe() - .filter_map(|envelope| future::ready(Arc::downcast(envelope).ok())) - } - pub fn request( self: &Arc, receiver_id: ConnectionId, @@ -325,142 +317,10 @@ impl Peer { } } -impl RouterInternal -where - H: Fn(&mut Option, ConnectionId) -> Option, - Fut: Future, -{ - pub fn new() -> Self { - Self { - message_handlers: Default::default(), - handler_types: Default::default(), - } - } - - async fn handle(&self, connection_id: ConnectionId, message: proto::Envelope) { - let mut envelope = Some(message); - for handler in self.message_handlers.iter() { - if let Some(future) = handler(&mut envelope, connection_id) { - future.await; - return; - } - } - log::warn!("unhandled message: {:?}", envelope.unwrap().payload); - } -} - -impl Router { - pub fn add_message_handler(&mut self, handler: F) - where - T: EnvelopedMessage, - Fut: 'static + Send + Future>, - F: 'static + Send + Sync + Fn(TypedEnvelope) -> Fut, - { - if !self.handler_types.insert(TypeId::of::()) { - panic!("duplicate handler type"); - } - - self.message_handlers - .push(Box::new(move |envelope, connection_id| { - if envelope.as_ref().map_or(false, T::matches_envelope) { - let envelope = Option::take(envelope).unwrap(); - let message_id = envelope.id; - let future = handler(TypedEnvelope { - sender_id: connection_id, - original_sender_id: envelope.original_sender_id.map(PeerId), - message_id, - payload: T::from_envelope(envelope).unwrap(), - }); - Some( - async move { - if let Err(error) = future.await { - log::error!( - "error handling message {} {}: {:?}", - T::NAME, - message_id, - error - ); - } - } - .boxed(), - ) - } else { - None - } - })); - } -} - -impl ForegroundRouter { - pub fn add_message_handler(&mut self, handler: F) - where - T: EnvelopedMessage, - Fut: 'static + Future>, - F: 'static + Fn(TypedEnvelope) -> Fut, - { - if !self.handler_types.insert(TypeId::of::()) { - panic!("duplicate handler type"); - } - - self.message_handlers - .push(Box::new(move |envelope, connection_id| { - if envelope.as_ref().map_or(false, T::matches_envelope) { - let envelope = Option::take(envelope).unwrap(); - let message_id = envelope.id; - let future = handler(TypedEnvelope { - sender_id: connection_id, - original_sender_id: envelope.original_sender_id.map(PeerId), - message_id, - payload: T::from_envelope(envelope).unwrap(), - }); - Some( - async move { - if let Err(error) = future.await { - log::error!( - "error handling message {} {}: {:?}", - T::NAME, - message_id, - error - ); - } - } - .boxed_local(), - ) - } else { - None - } - })); - } -} - -impl Clone for Receipt { - fn clone(&self) -> Self { - Self { - sender_id: self.sender_id, - message_id: self.message_id, - payload_type: PhantomData, - } - } -} - -impl Copy for Receipt {} - -impl fmt::Display for ConnectionId { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - self.0.fmt(f) - } -} - -impl fmt::Display for PeerId { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - self.0.fmt(f) - } -} - #[cfg(test)] mod tests { use super::*; - use crate::test; + use crate::{test, TypedEnvelope}; #[test] fn test_request_response() { @@ -470,139 +330,37 @@ mod tests { let client1 = Peer::new(); let client2 = Peer::new(); - let mut router = Router::new(); - router.add_message_handler({ - let server = server.clone(); - move |envelope: TypedEnvelope| { - let server = server.clone(); - async move { - let receipt = envelope.receipt(); - let message = envelope.payload; - server - .respond( - receipt, - match message.user_id { - 1 => { - assert_eq!(message.access_token, "access-token-1"); - proto::AuthResponse { - credentials_valid: true, - } - } - 2 => { - assert_eq!(message.access_token, "access-token-2"); - proto::AuthResponse { - credentials_valid: false, - } - } - _ => { - panic!("unexpected user id {}", message.user_id); - } - }, - ) - .await - } - } - }); - - router.add_message_handler({ - let server = server.clone(); - move |envelope: TypedEnvelope| { - let server = server.clone(); - async move { - let receipt = envelope.receipt(); - let message = envelope.payload; - server - .respond( - receipt, - match message.path.as_str() { - "path/one" => { - assert_eq!(message.worktree_id, 1); - proto::OpenBufferResponse { - buffer: Some(proto::Buffer { - id: 101, - content: "path/one content".to_string(), - history: vec![], - selections: vec![], - }), - } - } - "path/two" => { - assert_eq!(message.worktree_id, 2); - proto::OpenBufferResponse { - buffer: Some(proto::Buffer { - id: 102, - content: "path/two content".to_string(), - history: vec![], - selections: vec![], - }), - } - } - _ => { - panic!("unexpected path {}", message.path); - } - }, - ) - .await - } - } - }); - let router = Arc::new(router); - let (client1_to_server_conn, server_to_client_1_conn) = test::Channel::bidirectional(); - let (client1_conn_id, io_task1, msg_task1) = client1 - .add_connection(client1_to_server_conn, router.clone()) - .await; - let (_, io_task2, msg_task2) = server - .add_connection(server_to_client_1_conn, router.clone()) - .await; + let (client1_conn_id, io_task1, _) = + client1.add_connection(client1_to_server_conn).await; + let (_, io_task2, incoming1) = server.add_connection(server_to_client_1_conn).await; let (client2_to_server_conn, server_to_client_2_conn) = test::Channel::bidirectional(); - let (client2_conn_id, io_task3, msg_task3) = client2 - .add_connection(client2_to_server_conn, router.clone()) - .await; - let (_, io_task4, msg_task4) = server - .add_connection(server_to_client_2_conn, router.clone()) - .await; + let (client2_conn_id, io_task3, _) = + client2.add_connection(client2_to_server_conn).await; + let (_, io_task4, incoming2) = server.add_connection(server_to_client_2_conn).await; smol::spawn(io_task1).detach(); smol::spawn(io_task2).detach(); smol::spawn(io_task3).detach(); smol::spawn(io_task4).detach(); - smol::spawn(msg_task1).detach(); - smol::spawn(msg_task2).detach(); - smol::spawn(msg_task3).detach(); - smol::spawn(msg_task4).detach(); + smol::spawn(handle_messages(incoming1, server.clone())).detach(); + smol::spawn(handle_messages(incoming2, server.clone())).detach(); assert_eq!( client1 - .request( - client1_conn_id, - proto::Auth { - user_id: 1, - access_token: "access-token-1".to_string(), - }, - ) + .request(client1_conn_id, proto::Ping { id: 1 },) .await .unwrap(), - proto::AuthResponse { - credentials_valid: true, - } + proto::Pong { id: 1 } ); assert_eq!( client2 - .request( - client2_conn_id, - proto::Auth { - user_id: 2, - access_token: "access-token-2".to_string(), - }, - ) + .request(client2_conn_id, proto::Ping { id: 2 },) .await .unwrap(), - proto::AuthResponse { - credentials_valid: false, - } + proto::Pong { id: 2 } ); assert_eq!( @@ -649,6 +407,62 @@ mod tests { client1.disconnect(client1_conn_id).await; client2.disconnect(client1_conn_id).await; + + async fn handle_messages( + mut messages: mpsc::Receiver>, + peer: Arc, + ) -> Result<()> { + while let Some(envelope) = messages.next().await { + if let Some(envelope) = envelope.downcast_ref::>() { + let receipt = envelope.receipt(); + peer.respond( + receipt, + proto::Pong { + id: envelope.payload.id, + }, + ) + .await? + } else if let Some(envelope) = + envelope.downcast_ref::>() + { + let message = &envelope.payload; + let receipt = envelope.receipt(); + let response = match message.path.as_str() { + "path/one" => { + assert_eq!(message.worktree_id, 1); + proto::OpenBufferResponse { + buffer: Some(proto::Buffer { + id: 101, + content: "path/one content".to_string(), + history: vec![], + selections: vec![], + }), + } + } + "path/two" => { + assert_eq!(message.worktree_id, 2); + proto::OpenBufferResponse { + buffer: Some(proto::Buffer { + id: 102, + content: "path/two content".to_string(), + history: vec![], + selections: vec![], + }), + } + } + _ => { + panic!("unexpected path {}", message.path); + } + }; + + peer.respond(receipt, response).await? + } else { + panic!("unknown message type"); + } + } + + Ok(()) + } }); } @@ -658,9 +472,8 @@ mod tests { let (client_conn, mut server_conn) = test::Channel::bidirectional(); let client = Peer::new(); - let router = Arc::new(Router::new()); - let (connection_id, io_handler, message_handler) = - client.add_connection(client_conn, router).await; + let (connection_id, io_handler, mut incoming) = + client.add_connection(client_conn).await; let (mut io_ended_tx, mut io_ended_rx) = postage::barrier::channel(); smol::spawn(async move { @@ -671,7 +484,7 @@ mod tests { let (mut messages_ended_tx, mut messages_ended_rx) = postage::barrier::channel(); smol::spawn(async move { - message_handler.await; + incoming.next().await; messages_ended_tx.send(()).await.unwrap(); }) .detach(); @@ -695,11 +508,10 @@ mod tests { drop(server_conn); let client = Peer::new(); - let router = Arc::new(Router::new()); - let (connection_id, io_handler, message_handler) = - client.add_connection(client_conn, router).await; + let (connection_id, io_handler, mut incoming) = + client.add_connection(client_conn).await; smol::spawn(io_handler).detach(); - smol::spawn(message_handler).detach(); + smol::spawn(async move { incoming.next().await }).detach(); let err = client .request( diff --git a/zrpc/src/peer2.rs b/zrpc/src/peer2.rs deleted file mode 100644 index 7ead744bdd..0000000000 --- a/zrpc/src/peer2.rs +++ /dev/null @@ -1,470 +0,0 @@ -use crate::{ - proto::{self, EnvelopedMessage, MessageStream, RequestMessage}, - ConnectionId, PeerId, Receipt, -}; -use anyhow::{anyhow, Context, Result}; -use async_lock::{Mutex, RwLock}; -use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage}; -use futures::{FutureExt, StreamExt}; -use postage::{ - mpsc, - prelude::{Sink as _, Stream as _}, -}; -use std::{ - any::Any, - collections::HashMap, - future::Future, - sync::{ - atomic::{self, AtomicU32}, - Arc, - }, -}; - -pub struct Peer { - connections: RwLock>, - next_connection_id: AtomicU32, -} - -#[derive(Clone)] -struct Connection { - outgoing_tx: mpsc::Sender, - next_message_id: Arc, - response_channels: Arc>>>, -} - -impl Peer { - pub fn new() -> Arc { - Arc::new(Self { - connections: Default::default(), - next_connection_id: Default::default(), - }) - } - - pub async fn add_connection( - self: &Arc, - conn: Conn, - ) -> ( - ConnectionId, - impl Future> + Send, - mpsc::Receiver>, - ) - where - Conn: futures::Sink - + futures::Stream> - + Send - + Unpin, - { - let (tx, rx) = conn.split(); - let connection_id = ConnectionId( - self.next_connection_id - .fetch_add(1, atomic::Ordering::SeqCst), - ); - let (mut incoming_tx, incoming_rx) = mpsc::channel(64); - let (outgoing_tx, mut outgoing_rx) = mpsc::channel(64); - let connection = Connection { - outgoing_tx, - next_message_id: Default::default(), - response_channels: Default::default(), - }; - let mut writer = MessageStream::new(tx); - let mut reader = MessageStream::new(rx); - - let response_channels = connection.response_channels.clone(); - let handle_io = async move { - loop { - let read_message = reader.read_message().fuse(); - futures::pin_mut!(read_message); - loop { - futures::select_biased! { - incoming = read_message => match incoming { - Ok(incoming) => { - if let Some(responding_to) = incoming.responding_to { - let channel = response_channels.lock().await.remove(&responding_to); - if let Some(mut tx) = channel { - tx.send(incoming).await.ok(); - } else { - log::warn!("received RPC response to unknown request {}", responding_to); - } - } else { - if let Some(envelope) = proto::build_typed_envelope(connection_id, incoming) { - if incoming_tx.send(envelope).await.is_err() { - response_channels.lock().await.clear(); - return Ok(()) - } - } else { - log::error!("unable to construct a typed envelope"); - } - } - - break; - } - Err(error) => { - response_channels.lock().await.clear(); - Err(error).context("received invalid RPC message")?; - } - }, - outgoing = outgoing_rx.recv().fuse() => match outgoing { - Some(outgoing) => { - if let Err(result) = writer.write_message(&outgoing).await { - response_channels.lock().await.clear(); - Err(result).context("failed to write RPC message")?; - } - } - None => { - response_channels.lock().await.clear(); - return Ok(()) - } - } - } - } - } - }; - - self.connections - .write() - .await - .insert(connection_id, connection); - - (connection_id, handle_io, incoming_rx) - } - - pub async fn disconnect(&self, connection_id: ConnectionId) { - self.connections.write().await.remove(&connection_id); - } - - pub async fn reset(&self) { - self.connections.write().await.clear(); - } - - pub fn request( - self: &Arc, - receiver_id: ConnectionId, - request: T, - ) -> impl Future> { - self.request_internal(None, receiver_id, request) - } - - pub fn forward_request( - self: &Arc, - sender_id: ConnectionId, - receiver_id: ConnectionId, - request: T, - ) -> impl Future> { - self.request_internal(Some(sender_id), receiver_id, request) - } - - pub fn request_internal( - self: &Arc, - original_sender_id: Option, - receiver_id: ConnectionId, - request: T, - ) -> impl Future> { - let this = self.clone(); - let (tx, mut rx) = mpsc::channel(1); - async move { - let mut connection = this.connection(receiver_id).await?; - let message_id = connection - .next_message_id - .fetch_add(1, atomic::Ordering::SeqCst); - connection - .response_channels - .lock() - .await - .insert(message_id, tx); - connection - .outgoing_tx - .send(request.into_envelope(message_id, None, original_sender_id.map(|id| id.0))) - .await - .map_err(|_| anyhow!("connection was closed"))?; - let response = rx - .recv() - .await - .ok_or_else(|| anyhow!("connection was closed"))?; - T::Response::from_envelope(response) - .ok_or_else(|| anyhow!("received response of the wrong type")) - } - } - - pub fn send( - self: &Arc, - receiver_id: ConnectionId, - message: T, - ) -> impl Future> { - let this = self.clone(); - async move { - let mut connection = this.connection(receiver_id).await?; - let message_id = connection - .next_message_id - .fetch_add(1, atomic::Ordering::SeqCst); - connection - .outgoing_tx - .send(message.into_envelope(message_id, None, None)) - .await?; - Ok(()) - } - } - - pub fn forward_send( - self: &Arc, - sender_id: ConnectionId, - receiver_id: ConnectionId, - message: T, - ) -> impl Future> { - let this = self.clone(); - async move { - let mut connection = this.connection(receiver_id).await?; - let message_id = connection - .next_message_id - .fetch_add(1, atomic::Ordering::SeqCst); - connection - .outgoing_tx - .send(message.into_envelope(message_id, None, Some(sender_id.0))) - .await?; - Ok(()) - } - } - - pub fn respond( - self: &Arc, - receipt: Receipt, - response: T::Response, - ) -> impl Future> { - let this = self.clone(); - async move { - let mut connection = this.connection(receipt.sender_id).await?; - let message_id = connection - .next_message_id - .fetch_add(1, atomic::Ordering::SeqCst); - connection - .outgoing_tx - .send(response.into_envelope(message_id, Some(receipt.message_id), None)) - .await?; - Ok(()) - } - } - - fn connection( - self: &Arc, - connection_id: ConnectionId, - ) -> impl Future> { - let this = self.clone(); - async move { - let connections = this.connections.read().await; - let connection = connections - .get(&connection_id) - .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?; - Ok(connection.clone()) - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::{test, TypedEnvelope}; - - #[test] - fn test_request_response() { - smol::block_on(async move { - // create 2 clients connected to 1 server - let server = Peer::new(); - let client1 = Peer::new(); - let client2 = Peer::new(); - - let (client1_to_server_conn, server_to_client_1_conn) = test::Channel::bidirectional(); - let (client1_conn_id, io_task1, _) = - client1.add_connection(client1_to_server_conn).await; - let (_, io_task2, incoming1) = server.add_connection(server_to_client_1_conn).await; - - let (client2_to_server_conn, server_to_client_2_conn) = test::Channel::bidirectional(); - let (client2_conn_id, io_task3, _) = - client2.add_connection(client2_to_server_conn).await; - let (_, io_task4, incoming2) = server.add_connection(server_to_client_2_conn).await; - - smol::spawn(io_task1).detach(); - smol::spawn(io_task2).detach(); - smol::spawn(io_task3).detach(); - smol::spawn(io_task4).detach(); - smol::spawn(handle_messages(incoming1, server.clone())).detach(); - smol::spawn(handle_messages(incoming2, server.clone())).detach(); - - assert_eq!( - client1 - .request(client1_conn_id, proto::Ping { id: 1 },) - .await - .unwrap(), - proto::Pong { id: 1 } - ); - - assert_eq!( - client2 - .request(client2_conn_id, proto::Ping { id: 2 },) - .await - .unwrap(), - proto::Pong { id: 2 } - ); - - assert_eq!( - client1 - .request( - client1_conn_id, - proto::OpenBuffer { - worktree_id: 1, - path: "path/one".to_string(), - }, - ) - .await - .unwrap(), - proto::OpenBufferResponse { - buffer: Some(proto::Buffer { - id: 101, - content: "path/one content".to_string(), - history: vec![], - selections: vec![], - }), - } - ); - - assert_eq!( - client2 - .request( - client2_conn_id, - proto::OpenBuffer { - worktree_id: 2, - path: "path/two".to_string(), - }, - ) - .await - .unwrap(), - proto::OpenBufferResponse { - buffer: Some(proto::Buffer { - id: 102, - content: "path/two content".to_string(), - history: vec![], - selections: vec![], - }), - } - ); - - client1.disconnect(client1_conn_id).await; - client2.disconnect(client1_conn_id).await; - - async fn handle_messages( - mut messages: mpsc::Receiver>, - peer: Arc, - ) -> Result<()> { - while let Some(envelope) = messages.next().await { - if let Some(envelope) = envelope.downcast_ref::>() { - let receipt = envelope.receipt(); - peer.respond( - receipt, - proto::Pong { - id: envelope.payload.id, - }, - ) - .await? - } else if let Some(envelope) = - envelope.downcast_ref::>() - { - let message = &envelope.payload; - let receipt = envelope.receipt(); - let response = match message.path.as_str() { - "path/one" => { - assert_eq!(message.worktree_id, 1); - proto::OpenBufferResponse { - buffer: Some(proto::Buffer { - id: 101, - content: "path/one content".to_string(), - history: vec![], - selections: vec![], - }), - } - } - "path/two" => { - assert_eq!(message.worktree_id, 2); - proto::OpenBufferResponse { - buffer: Some(proto::Buffer { - id: 102, - content: "path/two content".to_string(), - history: vec![], - selections: vec![], - }), - } - } - _ => { - panic!("unexpected path {}", message.path); - } - }; - - peer.respond(receipt, response).await? - } else { - panic!("unknown message type"); - } - } - - Ok(()) - } - }); - } - - #[test] - fn test_disconnect() { - smol::block_on(async move { - let (client_conn, mut server_conn) = test::Channel::bidirectional(); - - let client = Peer::new(); - let (connection_id, io_handler, mut incoming) = - client.add_connection(client_conn).await; - - let (mut io_ended_tx, mut io_ended_rx) = postage::barrier::channel(); - smol::spawn(async move { - io_handler.await.ok(); - io_ended_tx.send(()).await.unwrap(); - }) - .detach(); - - let (mut messages_ended_tx, mut messages_ended_rx) = postage::barrier::channel(); - smol::spawn(async move { - incoming.next().await; - messages_ended_tx.send(()).await.unwrap(); - }) - .detach(); - - client.disconnect(connection_id).await; - - io_ended_rx.recv().await; - messages_ended_rx.recv().await; - assert!( - futures::SinkExt::send(&mut server_conn, WebSocketMessage::Binary(vec![])) - .await - .is_err() - ); - }); - } - - #[test] - fn test_io_error() { - smol::block_on(async move { - let (client_conn, server_conn) = test::Channel::bidirectional(); - drop(server_conn); - - let client = Peer::new(); - let (connection_id, io_handler, mut incoming) = - client.add_connection(client_conn).await; - smol::spawn(io_handler).detach(); - smol::spawn(async move { incoming.next().await }).detach(); - - let err = client - .request( - connection_id, - proto::Auth { - user_id: 42, - access_token: "token".to_string(), - }, - ) - .await - .unwrap_err(); - assert_eq!(err.to_string(), "connection was closed"); - }); - } -}