diff --git a/Cargo.lock b/Cargo.lock index 2cd94e08fe..2a01fdf782 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4507,6 +4507,7 @@ dependencies = [ "base64 0.13.0", "futures", "log", + "parking_lot", "postage", "prost 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)", "prost-build", diff --git a/gpui/src/app.rs b/gpui/src/app.rs index 5bd2a331f6..439e75e973 100644 --- a/gpui/src/app.rs +++ b/gpui/src/app.rs @@ -104,8 +104,11 @@ pub enum MenuItem<'a> { #[derive(Clone)] pub struct App(Rc>); +#[derive(Clone)] pub struct AsyncAppContext(Rc>); +pub struct BackgroundAppContext(*const RefCell); + #[derive(Clone)] pub struct TestAppContext { cx: Rc>, @@ -409,6 +412,15 @@ impl TestAppContext { } impl AsyncAppContext { + pub fn spawn(&self, f: F) -> Task + where + F: FnOnce(AsyncAppContext) -> Fut, + Fut: 'static + Future, + T: 'static, + { + self.0.borrow().foreground.spawn(f(self.clone())) + } + pub fn read T>(&mut self, callback: F) -> T { callback(self.0.borrow().as_ref()) } @@ -433,6 +445,10 @@ impl AsyncAppContext { self.0.borrow().platform() } + pub fn foreground(&self) -> Rc { + self.0.borrow().foreground.clone() + } + pub fn background(&self) -> Arc { self.0.borrow().cx.background.clone() } diff --git a/zed-rpc/Cargo.toml b/zed-rpc/Cargo.toml index f74d4cfc22..c1e3136bd8 100644 --- a/zed-rpc/Cargo.toml +++ b/zed-rpc/Cargo.toml @@ -14,6 +14,7 @@ async-tungstenite = "0.14" base64 = "0.13" futures = "0.3" log = "0.4" +parking_lot = "0.11.1" postage = {version = "0.4.1", features = ["futures-traits"]} prost = "0.7" rand = "0.8" diff --git a/zed-rpc/src/peer.rs b/zed-rpc/src/peer.rs index 8e9a58a14c..d825340db0 100644 --- a/zed-rpc/src/peer.rs +++ b/zed-rpc/src/peer.rs @@ -3,7 +3,7 @@ use anyhow::{anyhow, Context, Result}; use async_lock::{Mutex, RwLock}; use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage}; use futures::{ - future::BoxFuture, + future::{BoxFuture, LocalBoxFuture}, stream::{SplitSink, SplitStream}, FutureExt, StreamExt, }; @@ -30,9 +30,14 @@ pub struct ConnectionId(pub u32); pub struct PeerId(pub u32); type MessageHandler = Box< - dyn Send + Sync + Fn(&mut Option, ConnectionId) -> Option>, + dyn Send + + Sync + + Fn(&mut Option, ConnectionId) -> Option>, >; +type ForegroundMessageHandler = + Box, ConnectionId) -> Option>>; + pub struct Receipt { sender_id: ConnectionId, message_id: u32, @@ -63,10 +68,15 @@ impl TypedEnvelope { } } +pub type Router = RouterInternal; +pub type ForegroundRouter = RouterInternal; +pub struct RouterInternal { + message_handlers: Vec, + handler_types: HashSet, +} + pub struct Peer { connections: RwLock>, - message_handlers: RwLock>, - handler_types: Mutex>, next_connection_id: AtomicU32, } @@ -74,73 +84,37 @@ pub struct Peer { struct Connection { outgoing_tx: mpsc::Sender, next_message_id: Arc, - response_channels: ResponseChannels, + response_channels: Arc>>>, } -pub struct ConnectionHandler { - peer: Arc, +pub struct IOHandler { connection_id: ConnectionId, - response_channels: ResponseChannels, + incoming_tx: mpsc::Sender, outgoing_rx: mpsc::Receiver, writer: MessageStream, reader: MessageStream, } -type ResponseChannels = Arc>>>; - impl Peer { pub fn new() -> Arc { Arc::new(Self { connections: Default::default(), - message_handlers: Default::default(), - handler_types: Default::default(), next_connection_id: Default::default(), }) } - pub async fn add_message_handler( - &self, - ) -> mpsc::Receiver> { - if !self.handler_types.lock().await.insert(TypeId::of::()) { - panic!("duplicate handler type"); - } - - let (tx, rx) = mpsc::channel(256); - self.message_handlers - .write() - .await - .push(Box::new(move |envelope, connection_id| { - if envelope.as_ref().map_or(false, T::matches_envelope) { - let envelope = Option::take(envelope).unwrap(); - let mut tx = tx.clone(); - Some( - async move { - tx.send(TypedEnvelope { - sender_id: connection_id, - original_sender_id: envelope.original_sender_id.map(PeerId), - message_id: envelope.id, - payload: T::from_envelope(envelope).unwrap(), - }) - .await - .is_err() - } - .boxed(), - ) - } else { - None - } - })); - rx - } - - pub async fn add_connection( + pub async fn add_connection( self: &Arc, conn: Conn, + router: Arc>, ) -> ( ConnectionId, - ConnectionHandler, SplitStream>, + IOHandler, SplitStream>, + impl Future>, ) where + H: Fn(&mut Option, ConnectionId) -> Option, + Fut: Future, Conn: futures::Sink + futures::Stream> + Unpin, @@ -150,25 +124,45 @@ impl Peer { self.next_connection_id .fetch_add(1, atomic::Ordering::SeqCst), ); + let (incoming_tx, mut incoming_rx) = mpsc::channel(64); let (outgoing_tx, outgoing_rx) = mpsc::channel(64); let connection = Connection { outgoing_tx, next_message_id: Default::default(), response_channels: Default::default(), }; - let handler = ConnectionHandler { - peer: self.clone(), + let handle_io = IOHandler { connection_id, - response_channels: connection.response_channels.clone(), outgoing_rx, + incoming_tx, writer: MessageStream::new(tx), reader: MessageStream::new(rx), }; + + let response_channels = connection.response_channels.clone(); + let handle_messages = async move { + while let Some(message) = incoming_rx.recv().await { + if let Some(responding_to) = message.responding_to { + let channel = response_channels.lock().await.remove(&responding_to); + if let Some(mut tx) = channel { + tx.send(message).await.ok(); + } else { + log::warn!("received RPC response to unknown request {}", responding_to); + } + } else { + router.handle(connection_id, message).await; + } + } + response_channels.lock().await.clear(); + Ok(()) + }; + self.connections .write() .await .insert(connection_id, connection); - (connection_id, handler) + + (connection_id, handle_io, handle_messages) } pub async fn disconnect(&self, connection_id: ConnectionId) { @@ -177,8 +171,6 @@ impl Peer { pub async fn reset(&self) { self.connections.write().await.clear(); - self.handler_types.lock().await.clear(); - self.message_handlers.write().await.clear(); } pub fn request( @@ -302,7 +294,115 @@ impl Peer { } } -impl ConnectionHandler +impl RouterInternal +where + H: Fn(&mut Option, ConnectionId) -> Option, + Fut: Future, +{ + pub fn new() -> Self { + Self { + message_handlers: Default::default(), + handler_types: Default::default(), + } + } + + async fn handle(&self, connection_id: ConnectionId, message: proto::Envelope) { + let mut envelope = Some(message); + for handler in self.message_handlers.iter() { + if let Some(future) = handler(&mut envelope, connection_id) { + future.await; + return; + } + } + log::warn!("unhandled message: {:?}", envelope.unwrap().payload); + } +} + +impl Router { + pub fn add_message_handler(&mut self, handler: F) + where + T: EnvelopedMessage, + Fut: 'static + Send + Future>, + F: 'static + Send + Sync + Fn(TypedEnvelope) -> Fut, + { + if !self.handler_types.insert(TypeId::of::()) { + panic!("duplicate handler type"); + } + + self.message_handlers + .push(Box::new(move |envelope, connection_id| { + if envelope.as_ref().map_or(false, T::matches_envelope) { + let envelope = Option::take(envelope).unwrap(); + let message_id = envelope.id; + let future = handler(TypedEnvelope { + sender_id: connection_id, + original_sender_id: envelope.original_sender_id.map(PeerId), + message_id, + payload: T::from_envelope(envelope).unwrap(), + }); + Some( + async move { + if let Err(error) = future.await { + log::error!( + "error handling message {} {}: {:?}", + T::NAME, + message_id, + error + ); + } + } + .boxed(), + ) + } else { + None + } + })); + } +} + +impl ForegroundRouter { + pub fn add_message_handler(&mut self, handler: F) + where + T: EnvelopedMessage, + Fut: 'static + Future>, + F: 'static + Fn(TypedEnvelope) -> Fut, + { + if !self.handler_types.insert(TypeId::of::()) { + panic!("duplicate handler type"); + } + + self.message_handlers + .push(Box::new(move |envelope, connection_id| { + if envelope.as_ref().map_or(false, T::matches_envelope) { + let envelope = Option::take(envelope).unwrap(); + let message_id = envelope.id; + let future = handler(TypedEnvelope { + sender_id: connection_id, + original_sender_id: envelope.original_sender_id.map(PeerId), + message_id, + payload: T::from_envelope(envelope).unwrap(), + }); + Some( + async move { + if let Err(error) = future.await { + log::error!( + "error handling message {} {}: {:?}", + T::NAME, + message_id, + error + ); + } + } + .boxed_local(), + ) + } else { + None + } + })); + } +} + +impl IOHandler where W: futures::Sink + Unpin, R: futures::Stream> + Unpin, @@ -315,18 +415,18 @@ where futures::select_biased! { incoming = read_message => match incoming { Ok(incoming) => { - Self::handle_incoming_message(incoming, &self.peer, self.connection_id, &self.response_channels).await; + if self.incoming_tx.send(incoming).await.is_err() { + return Ok(()); + } break; } Err(error) => { - self.response_channels.lock().await.clear(); Err(error).context("received invalid RPC message")?; } }, outgoing = self.outgoing_rx.recv().fuse() => match outgoing { Some(outgoing) => { if let Err(result) = self.writer.write_message(&outgoing).await { - self.response_channels.lock().await.clear(); Err(result).context("failed to write RPC message")?; } } @@ -350,41 +450,6 @@ where payload, }) } - - async fn handle_incoming_message( - message: proto::Envelope, - peer: &Arc, - connection_id: ConnectionId, - response_channels: &ResponseChannels, - ) { - if let Some(responding_to) = message.responding_to { - let channel = response_channels.lock().await.remove(&responding_to); - if let Some(mut tx) = channel { - tx.send(message).await.ok(); - } else { - log::warn!("received RPC response to unknown request {}", responding_to); - } - } else { - let mut envelope = Some(message); - let mut handler_index = None; - let mut handler_was_dropped = false; - for (i, handler) in peer.message_handlers.read().await.iter().enumerate() { - if let Some(future) = handler(&mut envelope, connection_id) { - handler_was_dropped = future.await; - handler_index = Some(i); - break; - } - } - - if let Some(handler_index) = handler_index { - if handler_was_dropped { - drop(peer.message_handlers.write().await.remove(handler_index)); - } - } else { - log::warn!("unhandled message: {:?}", envelope.unwrap().payload); - } - } - } } impl Clone for Receipt { @@ -415,7 +480,6 @@ impl fmt::Display for PeerId { mod tests { use super::*; use crate::test; - use postage::oneshot; #[test] fn test_request_response() { @@ -425,127 +489,185 @@ mod tests { let client1 = Peer::new(); let client2 = Peer::new(); + let mut router = Router::new(); + router.add_message_handler({ + let server = server.clone(); + move |envelope: TypedEnvelope| { + let server = server.clone(); + async move { + let receipt = envelope.receipt(); + let message = envelope.payload; + server + .respond( + receipt, + match message.user_id { + 1 => { + assert_eq!(message.access_token, "access-token-1"); + proto::AuthResponse { + credentials_valid: true, + } + } + 2 => { + assert_eq!(message.access_token, "access-token-2"); + proto::AuthResponse { + credentials_valid: false, + } + } + _ => { + panic!("unexpected user id {}", message.user_id); + } + }, + ) + .await + } + } + }); + + router.add_message_handler({ + let server = server.clone(); + move |envelope: TypedEnvelope| { + let server = server.clone(); + async move { + let receipt = envelope.receipt(); + let message = envelope.payload; + server + .respond( + receipt, + match message.path.as_str() { + "path/one" => { + assert_eq!(message.worktree_id, 1); + proto::OpenBufferResponse { + buffer: Some(proto::Buffer { + id: 101, + content: "path/one content".to_string(), + history: vec![], + selections: vec![], + }), + } + } + "path/two" => { + assert_eq!(message.worktree_id, 2); + proto::OpenBufferResponse { + buffer: Some(proto::Buffer { + id: 102, + content: "path/two content".to_string(), + history: vec![], + selections: vec![], + }), + } + } + _ => { + panic!("unexpected path {}", message.path); + } + }, + ) + .await + } + } + }); + let router = Arc::new(router); + let (client1_to_server_conn, server_to_client_1_conn) = test::Channel::bidirectional(); - let (client1_conn_id, task1) = client1.add_connection(client1_to_server_conn).await; - let (_, task2) = server.add_connection(server_to_client_1_conn).await; + let (client1_conn_id, io_task1, msg_task1) = client1 + .add_connection(client1_to_server_conn, router.clone()) + .await; + let (_, io_task2, msg_task2) = server + .add_connection(server_to_client_1_conn, router.clone()) + .await; let (client2_to_server_conn, server_to_client_2_conn) = test::Channel::bidirectional(); - let (client2_conn_id, task3) = client2.add_connection(client2_to_server_conn).await; - let (_, task4) = server.add_connection(server_to_client_2_conn).await; + let (client2_conn_id, io_task3, msg_task3) = client2 + .add_connection(client2_to_server_conn, router.clone()) + .await; + let (_, io_task4, msg_task4) = server + .add_connection(server_to_client_2_conn, router.clone()) + .await; - smol::spawn(task1.run()).detach(); - smol::spawn(task2.run()).detach(); - smol::spawn(task3.run()).detach(); - smol::spawn(task4.run()).detach(); + smol::spawn(io_task1.run()).detach(); + smol::spawn(io_task2.run()).detach(); + smol::spawn(io_task3.run()).detach(); + smol::spawn(io_task4.run()).detach(); + smol::spawn(msg_task1).detach(); + smol::spawn(msg_task2).detach(); + smol::spawn(msg_task3).detach(); + smol::spawn(msg_task4).detach(); - // define the expected requests and responses - let request1 = proto::Auth { - user_id: 1, - access_token: "token-1".to_string(), - }; - let response1 = proto::AuthResponse { - credentials_valid: true, - }; - let request2 = proto::Auth { - user_id: 2, - access_token: "token-2".to_string(), - }; - let response2 = proto::AuthResponse { - credentials_valid: false, - }; - let request3 = proto::OpenBuffer { - worktree_id: 1, - path: "path/two".to_string(), - }; - let response3 = proto::OpenBufferResponse { - buffer: Some(proto::Buffer { - id: 2, - content: "path/two content".to_string(), - history: vec![], - selections: vec![], - }), - }; - let request4 = proto::OpenBuffer { - worktree_id: 2, - path: "path/one".to_string(), - }; - let response4 = proto::OpenBufferResponse { - buffer: Some(proto::Buffer { - id: 1, - content: "path/one content".to_string(), - history: vec![], - selections: vec![], - }), - }; - - // on the server, respond to two requests for each client - let mut open_buffer_rx = server.add_message_handler::().await; - let mut auth_rx = server.add_message_handler::().await; - let (mut server_done_tx, mut server_done_rx) = oneshot::channel::<()>(); - smol::spawn({ - let request1 = request1.clone(); - let request2 = request2.clone(); - let request3 = request3.clone(); - let request4 = request4.clone(); - let response1 = response1.clone(); - let response2 = response2.clone(); - let response3 = response3.clone(); - let response4 = response4.clone(); - async move { - let msg = auth_rx.recv().await.unwrap(); - assert_eq!(msg.payload, request1); - server - .respond(msg.receipt(), response1.clone()) - .await - .unwrap(); - - let msg = auth_rx.recv().await.unwrap(); - assert_eq!(msg.payload, request2.clone()); - server - .respond(msg.receipt(), response2.clone()) - .await - .unwrap(); - - let msg = open_buffer_rx.recv().await.unwrap(); - assert_eq!(msg.payload, request3.clone()); - server - .respond(msg.receipt(), response3.clone()) - .await - .unwrap(); - - let msg = open_buffer_rx.recv().await.unwrap(); - assert_eq!(msg.payload, request4.clone()); - server - .respond(msg.receipt(), response4.clone()) - .await - .unwrap(); - - server_done_tx.send(()).await.unwrap(); + assert_eq!( + client1 + .request( + client1_conn_id, + proto::Auth { + user_id: 1, + access_token: "access-token-1".to_string(), + }, + ) + .await + .unwrap(), + proto::AuthResponse { + credentials_valid: true, } - }) - .detach(); + ); assert_eq!( - client1.request(client1_conn_id, request1).await.unwrap(), - response1 + client2 + .request( + client2_conn_id, + proto::Auth { + user_id: 2, + access_token: "access-token-2".to_string(), + }, + ) + .await + .unwrap(), + proto::AuthResponse { + credentials_valid: false, + } ); + assert_eq!( - client2.request(client2_conn_id, request2).await.unwrap(), - response2 + client1 + .request( + client1_conn_id, + proto::OpenBuffer { + worktree_id: 1, + path: "path/one".to_string(), + }, + ) + .await + .unwrap(), + proto::OpenBufferResponse { + buffer: Some(proto::Buffer { + id: 101, + content: "path/one content".to_string(), + history: vec![], + selections: vec![], + }), + } ); + assert_eq!( - client2.request(client2_conn_id, request3).await.unwrap(), - response3 - ); - assert_eq!( - client1.request(client1_conn_id, request4).await.unwrap(), - response4 + client2 + .request( + client2_conn_id, + proto::OpenBuffer { + worktree_id: 2, + path: "path/two".to_string(), + }, + ) + .await + .unwrap(), + proto::OpenBufferResponse { + buffer: Some(proto::Buffer { + id: 102, + content: "path/two content".to_string(), + history: vec![], + selections: vec![], + }), + } ); client1.disconnect(client1_conn_id).await; client2.disconnect(client1_conn_id).await; - - server_done_rx.recv().await.unwrap(); }); } @@ -555,17 +677,28 @@ mod tests { let (client_conn, mut server_conn) = test::Channel::bidirectional(); let client = Peer::new(); - let (connection_id, handler) = client.add_connection(client_conn).await; - let (mut incoming_messages_ended_tx, mut incoming_messages_ended_rx) = - postage::barrier::channel(); + let router = Arc::new(Router::new()); + let (connection_id, io_handler, message_handler) = + client.add_connection(client_conn, router).await; + + let (mut io_ended_tx, mut io_ended_rx) = postage::barrier::channel(); smol::spawn(async move { - handler.run().await.ok(); - incoming_messages_ended_tx.send(()).await.unwrap(); + io_handler.run().await.ok(); + io_ended_tx.send(()).await.unwrap(); }) .detach(); + + let (mut messages_ended_tx, mut messages_ended_rx) = postage::barrier::channel(); + smol::spawn(async move { + message_handler.await.ok(); + messages_ended_tx.send(()).await.unwrap(); + }) + .detach(); + client.disconnect(connection_id).await; - incoming_messages_ended_rx.recv().await; + io_ended_rx.recv().await; + messages_ended_rx.recv().await; assert!( futures::SinkExt::send(&mut server_conn, WebSocketMessage::Binary(vec![])) .await @@ -581,8 +714,11 @@ mod tests { drop(server_conn); let client = Peer::new(); - let (connection_id, handler) = client.add_connection(client_conn).await; - smol::spawn(handler.run()).detach(); + let router = Arc::new(Router::new()); + let (connection_id, io_handler, message_handler) = + client.add_connection(client_conn, router).await; + smol::spawn(io_handler.run()).detach(); + smol::spawn(message_handler).detach(); let err = client .request( diff --git a/zed/src/editor.rs b/zed/src/editor.rs index a90d5bd576..621358d334 100644 --- a/zed/src/editor.rs +++ b/zed/src/editor.rs @@ -4015,7 +4015,8 @@ mod tests { let history = History::new(text.into()); Buffer::from_history(0, history, None, lang.cloned(), cx) }); - let (_, view) = cx.add_window(|cx| Editor::for_buffer(buffer, app_state.settings, cx)); + let (_, view) = + cx.add_window(|cx| Editor::for_buffer(buffer, app_state.settings.clone(), cx)); view.condition(&cx, |view, cx| !view.buffer.read(cx).is_parsing()) .await; diff --git a/zed/src/editor/buffer.rs b/zed/src/editor/buffer.rs index 304d9e3f40..265611e1ab 100644 --- a/zed/src/editor/buffer.rs +++ b/zed/src/editor/buffer.rs @@ -719,6 +719,7 @@ impl Buffer { mtime: SystemTime, cx: &mut ModelContext, ) { + eprintln!("{} did_save {:?}", self.replica_id, version); self.saved_mtime = mtime; self.saved_version = version; cx.emit(Event::Saved); diff --git a/zed/src/file_finder.rs b/zed/src/file_finder.rs index ee63f57faf..9ea1a9b1a2 100644 --- a/zed/src/file_finder.rs +++ b/zed/src/file_finder.rs @@ -479,8 +479,12 @@ mod tests { let app_state = cx.read(build_app_state); let (window_id, workspace) = cx.add_window(|cx| { - let mut workspace = - Workspace::new(app_state.settings, app_state.languages, app_state.rpc, cx); + let mut workspace = Workspace::new( + app_state.settings.clone(), + app_state.languages.clone(), + app_state.rpc.clone(), + cx, + ); workspace.add_worktree(tmp_dir.path(), cx); workspace }); @@ -559,7 +563,7 @@ mod tests { cx.read(|cx| workspace.read(cx).worktree_scans_complete(cx)) .await; let (_, finder) = - cx.add_window(|cx| FileFinder::new(app_state.settings, workspace.clone(), cx)); + cx.add_window(|cx| FileFinder::new(app_state.settings.clone(), workspace.clone(), cx)); let query = "hi".to_string(); finder @@ -622,7 +626,7 @@ mod tests { cx.read(|cx| workspace.read(cx).worktree_scans_complete(cx)) .await; let (_, finder) = - cx.add_window(|cx| FileFinder::new(app_state.settings, workspace.clone(), cx)); + cx.add_window(|cx| FileFinder::new(app_state.settings.clone(), workspace.clone(), cx)); // Even though there is only one worktree, that worktree's filename // is included in the matching, because the worktree is a single file. @@ -681,7 +685,7 @@ mod tests { .await; let (_, finder) = - cx.add_window(|cx| FileFinder::new(app_state.settings, workspace.clone(), cx)); + cx.add_window(|cx| FileFinder::new(app_state.settings.clone(), workspace.clone(), cx)); // Run a search that matches two files with the same relative path. finder diff --git a/zed/src/lib.rs b/zed/src/lib.rs index 925c798d4c..4b1c3571d6 100644 --- a/zed/src/lib.rs +++ b/zed/src/lib.rs @@ -1,3 +1,5 @@ +use zed_rpc::ForegroundRouter; + pub mod assets; pub mod editor; pub mod file_finder; @@ -14,10 +16,10 @@ mod util; pub mod workspace; pub mod worktree; -#[derive(Clone)] pub struct AppState { pub settings: postage::watch::Receiver, pub languages: std::sync::Arc, + pub rpc_router: std::sync::Arc, pub rpc: rpc::Client, } diff --git a/zed/src/main.rs b/zed/src/main.rs index d5edf9ba7b..ea11ff0c69 100644 --- a/zed/src/main.rs +++ b/zed/src/main.rs @@ -10,6 +10,7 @@ use zed::{ workspace::{self, OpenParams}, worktree, AppState, }; +use zed_rpc::ForegroundRouter; fn main() { init_logger(); @@ -20,20 +21,27 @@ fn main() { let languages = Arc::new(language::LanguageRegistry::new()); languages.set_theme(&settings.borrow().theme); - let app_state = AppState { + let mut app_state = AppState { languages: languages.clone(), settings, + rpc_router: Arc::new(ForegroundRouter::new()), rpc: rpc::Client::new(languages), }; app.run(move |cx| { - cx.set_menus(menus::menus(app_state.clone())); + worktree::init( + cx, + &app_state.rpc, + Arc::get_mut(&mut app_state.rpc_router).unwrap(), + ); zed::init(cx); workspace::init(cx); - worktree::init(cx, app_state.rpc.clone()); editor::init(cx); file_finder::init(cx); + let app_state = Arc::new(app_state); + cx.set_menus(menus::menus(&app_state.clone())); + if stdout_is_a_pty() { cx.platform().activate(true); } diff --git a/zed/src/menus.rs b/zed/src/menus.rs index 9f22c8ede9..227f0b9efc 100644 --- a/zed/src/menus.rs +++ b/zed/src/menus.rs @@ -1,8 +1,9 @@ use crate::AppState; use gpui::{Menu, MenuItem}; +use std::sync::Arc; #[cfg(target_os = "macos")] -pub fn menus(state: AppState) -> Vec> { +pub fn menus(state: &Arc) -> Vec> { vec![ Menu { name: "Zed", @@ -48,7 +49,7 @@ pub fn menus(state: AppState) -> Vec> { name: "Open…", keystroke: Some("cmd-o"), action: "workspace:open", - arg: Some(Box::new(state)), + arg: Some(Box::new(state.clone())), }, ], }, diff --git a/zed/src/rpc.rs b/zed/src/rpc.rs index fc6bba57f1..0f3087fb36 100644 --- a/zed/src/rpc.rs +++ b/zed/src/rpc.rs @@ -1,10 +1,8 @@ use crate::{language::LanguageRegistry, worktree::Worktree}; use anyhow::{anyhow, Context, Result}; use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage}; -use gpui::executor::Background; use gpui::{AsyncAppContext, ModelHandle, Task, WeakModelHandle}; use lazy_static::lazy_static; -use postage::prelude::Stream; use smol::lock::RwLock; use std::collections::HashMap; use std::time::Duration; @@ -13,7 +11,7 @@ use surf::Url; pub use zed_rpc::{proto, ConnectionId, PeerId, TypedEnvelope}; use zed_rpc::{ proto::{EnvelopedMessage, RequestMessage}, - Peer, Receipt, + ForegroundRouter, Peer, Receipt, }; lazy_static! { @@ -63,24 +61,30 @@ impl Client { } } - pub fn on_message(&self, handler: H, cx: &mut gpui::MutableAppContext) - where - H: 'static + for<'a> MessageHandler<'a, M>, + pub fn on_message( + &self, + router: &mut ForegroundRouter, + handler: H, + cx: &mut gpui::MutableAppContext, + ) where + H: 'static + Clone + for<'a> MessageHandler<'a, M>, M: proto::EnvelopedMessage, { let this = self.clone(); - let mut messages = smol::block_on(this.peer.add_message_handler::()); - cx.spawn(|mut cx| async move { - while let Some(message) = messages.recv().await { - if let Err(err) = handler.handle(message, &this, &mut cx).await { - log::error!("error handling message: {:?}", err); - } - } - }) - .detach(); + let cx = cx.to_async(); + router.add_message_handler(move |message| { + let this = this.clone(); + let mut cx = cx.clone(); + let handler = handler.clone(); + async move { handler.handle(message, &this, &mut cx).await } + }); } - pub async fn log_in_and_connect(&self, cx: &AsyncAppContext) -> surf::Result<()> { + pub async fn log_in_and_connect( + &self, + router: Arc, + cx: AsyncAppContext, + ) -> surf::Result<()> { if self.state.read().await.connection_id.is_some() { return Ok(()); } @@ -96,14 +100,14 @@ impl Client { .await .context("websocket handshake")?; log::info!("connected to rpc address {}", *ZED_SERVER_URL); - self.add_connection(stream, user_id, access_token, &cx.background()) + self.add_connection(stream, user_id, access_token, router, cx) .await?; } else if let Some(host) = ZED_SERVER_URL.strip_prefix("http://") { let stream = smol::net::TcpStream::connect(host).await?; let (stream, _) = async_tungstenite::client_async(format!("ws://{}/rpc", host), stream).await?; log::info!("connected to rpc address {}", *ZED_SERVER_URL); - self.add_connection(stream, user_id, access_token, &cx.background()) + self.add_connection(stream, user_id, access_token, router, cx) .await?; } else { return Err(anyhow!("invalid server url: {}", *ZED_SERVER_URL))?; @@ -117,7 +121,8 @@ impl Client { conn: Conn, user_id: i32, access_token: String, - executor: &Arc, + router: Arc, + cx: AsyncAppContext, ) -> surf::Result<()> where Conn: 'static @@ -126,10 +131,12 @@ impl Client { + Unpin + Send, { - let (connection_id, handler) = self.peer.add_connection(conn).await; - executor + let (connection_id, handle_io, handle_messages) = + self.peer.add_connection(conn, router).await; + cx.foreground().spawn(handle_messages).detach(); + cx.background() .spawn(async move { - if let Err(error) = handler.run().await { + if let Err(error) = handle_io.run().await { log::error!("connection error: {:?}", error); } }) @@ -263,7 +270,7 @@ impl Client { } } -pub trait MessageHandler<'a, M: proto::EnvelopedMessage> { +pub trait MessageHandler<'a, M: proto::EnvelopedMessage>: Clone { type Output: 'a + Future>; fn handle( @@ -277,7 +284,7 @@ pub trait MessageHandler<'a, M: proto::EnvelopedMessage> { impl<'a, M, F, Fut> MessageHandler<'a, M> for F where M: proto::EnvelopedMessage, - F: Fn(TypedEnvelope, &'a Client, &'a mut gpui::AsyncAppContext) -> Fut, + F: Clone + Fn(TypedEnvelope, &'a Client, &'a mut gpui::AsyncAppContext) -> Fut, Fut: 'a + Future>, { type Output = Fut; diff --git a/zed/src/test.rs b/zed/src/test.rs index 8350d4b0d4..2fb5b6dd2e 100644 --- a/zed/src/test.rs +++ b/zed/src/test.rs @@ -5,6 +5,7 @@ use std::{ sync::Arc, }; use tempdir::TempDir; +use zed_rpc::ForegroundRouter; #[cfg(feature = "test-support")] pub use zed_rpc::test::Channel; @@ -143,12 +144,13 @@ fn write_tree(path: &Path, tree: serde_json::Value) { } } -pub fn build_app_state(cx: &AppContext) -> AppState { +pub fn build_app_state(cx: &AppContext) -> Arc { let settings = settings::channel(&cx.font_cache()).unwrap().1; let languages = Arc::new(LanguageRegistry::new()); - AppState { + Arc::new(AppState { settings, languages: languages.clone(), + rpc_router: Arc::new(ForegroundRouter::new()), rpc: rpc::Client::new(languages), - } + }) } diff --git a/zed/src/workspace.rs b/zed/src/workspace.rs index f37a2b2d90..8541fb2377 100644 --- a/zed/src/workspace.rs +++ b/zed/src/workspace.rs @@ -45,10 +45,10 @@ pub fn init(cx: &mut MutableAppContext) { pub struct OpenParams { pub paths: Vec, - pub app_state: AppState, + pub app_state: Arc, } -fn open(app_state: &AppState, cx: &mut MutableAppContext) { +fn open(app_state: &Arc, cx: &mut MutableAppContext) { let app_state = app_state.clone(); cx.prompt_for_paths( PathPromptOptions { @@ -101,7 +101,7 @@ fn open_paths(params: &OpenParams, cx: &mut MutableAppContext) { }); } -fn open_new(app_state: &AppState, cx: &mut MutableAppContext) { +fn open_new(app_state: &Arc, cx: &mut MutableAppContext) { cx.add_window(|cx| { let mut view = Workspace::new( app_state.settings.clone(), @@ -700,12 +700,13 @@ impl Workspace { }; } - fn share_worktree(&mut self, _: &(), cx: &mut ViewContext) { + fn share_worktree(&mut self, app_state: &Arc, cx: &mut ViewContext) { let rpc = self.rpc.clone(); let platform = cx.platform(); + let router = app_state.rpc_router.clone(); let task = cx.spawn(|this, mut cx| async move { - rpc.log_in_and_connect(&cx).await?; + rpc.log_in_and_connect(router, cx.clone()).await?; let share_task = this.update(&mut cx, |this, cx| { let worktree = this.worktrees.iter().next()?; @@ -732,12 +733,13 @@ impl Workspace { .detach(); } - fn join_worktree(&mut self, _: &(), cx: &mut ViewContext) { + fn join_worktree(&mut self, app_state: &Arc, cx: &mut ViewContext) { let rpc = self.rpc.clone(); let languages = self.languages.clone(); + let router = app_state.rpc_router.clone(); let task = cx.spawn(|this, mut cx| async move { - rpc.log_in_and_connect(&cx).await?; + rpc.log_in_and_connect(router, cx.clone()).await?; let worktree_url = cx .platform() @@ -974,8 +976,12 @@ mod tests { let app_state = cx.read(build_app_state); let (_, workspace) = cx.add_window(|cx| { - let mut workspace = - Workspace::new(app_state.settings, app_state.languages, app_state.rpc, cx); + let mut workspace = Workspace::new( + app_state.settings.clone(), + app_state.languages.clone(), + app_state.rpc.clone(), + cx, + ); workspace.add_worktree(dir.path(), cx); workspace }); @@ -1077,8 +1083,12 @@ mod tests { let app_state = cx.read(build_app_state); let (_, workspace) = cx.add_window(|cx| { - let mut workspace = - Workspace::new(app_state.settings, app_state.languages, app_state.rpc, cx); + let mut workspace = Workspace::new( + app_state.settings.clone(), + app_state.languages.clone(), + app_state.rpc.clone(), + cx, + ); workspace.add_worktree(dir1.path(), cx); workspace }); @@ -1146,8 +1156,12 @@ mod tests { let app_state = cx.read(build_app_state); let (window_id, workspace) = cx.add_window(|cx| { - let mut workspace = - Workspace::new(app_state.settings, app_state.languages, app_state.rpc, cx); + let mut workspace = Workspace::new( + app_state.settings.clone(), + app_state.languages.clone(), + app_state.rpc.clone(), + cx, + ); workspace.add_worktree(dir.path(), cx); workspace }); @@ -1315,8 +1329,12 @@ mod tests { let app_state = cx.read(build_app_state); let (window_id, workspace) = cx.add_window(|cx| { - let mut workspace = - Workspace::new(app_state.settings, app_state.languages, app_state.rpc, cx); + let mut workspace = Workspace::new( + app_state.settings.clone(), + app_state.languages.clone(), + app_state.rpc.clone(), + cx, + ); workspace.add_worktree(dir.path(), cx); workspace }); diff --git a/zed/src/worktree.rs b/zed/src/worktree.rs index eaa6c56e97..5eb19873cf 100644 --- a/zed/src/worktree.rs +++ b/zed/src/worktree.rs @@ -48,21 +48,21 @@ use std::{ }, time::{Duration, SystemTime}, }; -use zed_rpc::{PeerId, TypedEnvelope}; +use zed_rpc::{ForegroundRouter, PeerId, TypedEnvelope}; lazy_static! { static ref GITIGNORE: &'static OsStr = OsStr::new(".gitignore"); } -pub fn init(cx: &mut MutableAppContext, rpc: rpc::Client) { - rpc.on_message(remote::add_peer, cx); - rpc.on_message(remote::remove_peer, cx); - rpc.on_message(remote::update_worktree, cx); - rpc.on_message(remote::open_buffer, cx); - rpc.on_message(remote::close_buffer, cx); - rpc.on_message(remote::update_buffer, cx); - rpc.on_message(remote::buffer_saved, cx); - rpc.on_message(remote::save_buffer, cx); +pub fn init(cx: &mut MutableAppContext, rpc: &rpc::Client, router: &mut ForegroundRouter) { + rpc.on_message(router, remote::add_peer, cx); + rpc.on_message(router, remote::remove_peer, cx); + rpc.on_message(router, remote::update_worktree, cx); + rpc.on_message(router, remote::open_buffer, cx); + rpc.on_message(router, remote::close_buffer, cx); + rpc.on_message(router, remote::update_buffer, cx); + rpc.on_message(router, remote::buffer_saved, cx); + rpc.on_message(router, remote::save_buffer, cx); } #[async_trait::async_trait] @@ -2861,6 +2861,8 @@ mod remote { rpc: &rpc::Client, cx: &mut AsyncAppContext, ) -> anyhow::Result<()> { + eprintln!("got update buffer message {:?}", envelope.payload); + let message = envelope.payload; rpc.state .read() @@ -2875,6 +2877,8 @@ mod remote { rpc: &rpc::Client, cx: &mut AsyncAppContext, ) -> anyhow::Result<()> { + eprintln!("got save buffer message {:?}", envelope.payload); + let state = rpc.state.read().await; let worktree = state.shared_worktree(envelope.payload.worktree_id, cx)?; let sender_id = envelope.original_sender_id()?; @@ -2905,6 +2909,8 @@ mod remote { rpc: &rpc::Client, cx: &mut AsyncAppContext, ) -> anyhow::Result<()> { + eprintln!("got buffer_saved {:?}", envelope.payload); + rpc.state .read() .await @@ -2993,7 +2999,7 @@ mod tests { let dir = temp_tree(json!({ "file1": "the old contents", })); - let tree = cx.add_model(|cx| Worktree::local(dir.path(), app_state.languages, cx)); + let tree = cx.add_model(|cx| Worktree::local(dir.path(), app_state.languages.clone(), cx)); let buffer = tree .update(&mut cx, |tree, cx| tree.open_buffer("file1", cx)) .await @@ -3016,7 +3022,8 @@ mod tests { })); let file_path = dir.path().join("file1"); - let tree = cx.add_model(|cx| Worktree::local(file_path.clone(), app_state.languages, cx)); + let tree = + cx.add_model(|cx| Worktree::local(file_path.clone(), app_state.languages.clone(), cx)); cx.read(|cx| tree.read(cx).as_local().unwrap().scan_complete()) .await; cx.read(|cx| assert_eq!(tree.read(cx).file_count(), 1));