Restructure Peer to handle connections' messages in order

This commit is contained in:
Max Brunsfeld 2021-07-09 16:27:33 -07:00
parent b7fae693f9
commit eeebc761b6
14 changed files with 479 additions and 274 deletions

1
Cargo.lock generated
View file

@ -4507,6 +4507,7 @@ dependencies = [
"base64 0.13.0",
"futures",
"log",
"parking_lot",
"postage",
"prost 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)",
"prost-build",

View file

@ -104,8 +104,11 @@ pub enum MenuItem<'a> {
#[derive(Clone)]
pub struct App(Rc<RefCell<MutableAppContext>>);
#[derive(Clone)]
pub struct AsyncAppContext(Rc<RefCell<MutableAppContext>>);
pub struct BackgroundAppContext(*const RefCell<MutableAppContext>);
#[derive(Clone)]
pub struct TestAppContext {
cx: Rc<RefCell<MutableAppContext>>,
@ -409,6 +412,15 @@ impl TestAppContext {
}
impl AsyncAppContext {
pub fn spawn<F, Fut, T>(&self, f: F) -> Task<T>
where
F: FnOnce(AsyncAppContext) -> Fut,
Fut: 'static + Future<Output = T>,
T: 'static,
{
self.0.borrow().foreground.spawn(f(self.clone()))
}
pub fn read<T, F: FnOnce(&AppContext) -> T>(&mut self, callback: F) -> T {
callback(self.0.borrow().as_ref())
}
@ -433,6 +445,10 @@ impl AsyncAppContext {
self.0.borrow().platform()
}
pub fn foreground(&self) -> Rc<executor::Foreground> {
self.0.borrow().foreground.clone()
}
pub fn background(&self) -> Arc<executor::Background> {
self.0.borrow().cx.background.clone()
}

View file

@ -14,6 +14,7 @@ async-tungstenite = "0.14"
base64 = "0.13"
futures = "0.3"
log = "0.4"
parking_lot = "0.11.1"
postage = {version = "0.4.1", features = ["futures-traits"]}
prost = "0.7"
rand = "0.8"

View file

@ -3,7 +3,7 @@ use anyhow::{anyhow, Context, Result};
use async_lock::{Mutex, RwLock};
use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
use futures::{
future::BoxFuture,
future::{BoxFuture, LocalBoxFuture},
stream::{SplitSink, SplitStream},
FutureExt, StreamExt,
};
@ -30,9 +30,14 @@ pub struct ConnectionId(pub u32);
pub struct PeerId(pub u32);
type MessageHandler = Box<
dyn Send + Sync + Fn(&mut Option<proto::Envelope>, ConnectionId) -> Option<BoxFuture<bool>>,
dyn Send
+ Sync
+ Fn(&mut Option<proto::Envelope>, ConnectionId) -> Option<BoxFuture<'static, ()>>,
>;
type ForegroundMessageHandler =
Box<dyn Fn(&mut Option<proto::Envelope>, ConnectionId) -> Option<LocalBoxFuture<'static, ()>>>;
pub struct Receipt<T> {
sender_id: ConnectionId,
message_id: u32,
@ -63,10 +68,15 @@ impl<T: RequestMessage> TypedEnvelope<T> {
}
}
pub type Router = RouterInternal<MessageHandler>;
pub type ForegroundRouter = RouterInternal<ForegroundMessageHandler>;
pub struct RouterInternal<H> {
message_handlers: Vec<H>,
handler_types: HashSet<TypeId>,
}
pub struct Peer {
connections: RwLock<HashMap<ConnectionId, Connection>>,
message_handlers: RwLock<Vec<MessageHandler>>,
handler_types: Mutex<HashSet<TypeId>>,
next_connection_id: AtomicU32,
}
@ -74,73 +84,37 @@ pub struct Peer {
struct Connection {
outgoing_tx: mpsc::Sender<proto::Envelope>,
next_message_id: Arc<AtomicU32>,
response_channels: ResponseChannels,
response_channels: Arc<Mutex<HashMap<u32, mpsc::Sender<proto::Envelope>>>>,
}
pub struct ConnectionHandler<W, R> {
peer: Arc<Peer>,
pub struct IOHandler<W, R> {
connection_id: ConnectionId,
response_channels: ResponseChannels,
incoming_tx: mpsc::Sender<proto::Envelope>,
outgoing_rx: mpsc::Receiver<proto::Envelope>,
writer: MessageStream<W>,
reader: MessageStream<R>,
}
type ResponseChannels = Arc<Mutex<HashMap<u32, mpsc::Sender<proto::Envelope>>>>;
impl Peer {
pub fn new() -> Arc<Self> {
Arc::new(Self {
connections: Default::default(),
message_handlers: Default::default(),
handler_types: Default::default(),
next_connection_id: Default::default(),
})
}
pub async fn add_message_handler<T: EnvelopedMessage>(
&self,
) -> mpsc::Receiver<TypedEnvelope<T>> {
if !self.handler_types.lock().await.insert(TypeId::of::<T>()) {
panic!("duplicate handler type");
}
let (tx, rx) = mpsc::channel(256);
self.message_handlers
.write()
.await
.push(Box::new(move |envelope, connection_id| {
if envelope.as_ref().map_or(false, T::matches_envelope) {
let envelope = Option::take(envelope).unwrap();
let mut tx = tx.clone();
Some(
async move {
tx.send(TypedEnvelope {
sender_id: connection_id,
original_sender_id: envelope.original_sender_id.map(PeerId),
message_id: envelope.id,
payload: T::from_envelope(envelope).unwrap(),
})
.await
.is_err()
}
.boxed(),
)
} else {
None
}
}));
rx
}
pub async fn add_connection<Conn>(
pub async fn add_connection<Conn, H, Fut>(
self: &Arc<Self>,
conn: Conn,
router: Arc<RouterInternal<H>>,
) -> (
ConnectionId,
ConnectionHandler<SplitSink<Conn, WebSocketMessage>, SplitStream<Conn>>,
IOHandler<SplitSink<Conn, WebSocketMessage>, SplitStream<Conn>>,
impl Future<Output = anyhow::Result<()>>,
)
where
H: Fn(&mut Option<proto::Envelope>, ConnectionId) -> Option<Fut>,
Fut: Future<Output = ()>,
Conn: futures::Sink<WebSocketMessage, Error = WebSocketError>
+ futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>
+ Unpin,
@ -150,25 +124,45 @@ impl Peer {
self.next_connection_id
.fetch_add(1, atomic::Ordering::SeqCst),
);
let (incoming_tx, mut incoming_rx) = mpsc::channel(64);
let (outgoing_tx, outgoing_rx) = mpsc::channel(64);
let connection = Connection {
outgoing_tx,
next_message_id: Default::default(),
response_channels: Default::default(),
};
let handler = ConnectionHandler {
peer: self.clone(),
let handle_io = IOHandler {
connection_id,
response_channels: connection.response_channels.clone(),
outgoing_rx,
incoming_tx,
writer: MessageStream::new(tx),
reader: MessageStream::new(rx),
};
let response_channels = connection.response_channels.clone();
let handle_messages = async move {
while let Some(message) = incoming_rx.recv().await {
if let Some(responding_to) = message.responding_to {
let channel = response_channels.lock().await.remove(&responding_to);
if let Some(mut tx) = channel {
tx.send(message).await.ok();
} else {
log::warn!("received RPC response to unknown request {}", responding_to);
}
} else {
router.handle(connection_id, message).await;
}
}
response_channels.lock().await.clear();
Ok(())
};
self.connections
.write()
.await
.insert(connection_id, connection);
(connection_id, handler)
(connection_id, handle_io, handle_messages)
}
pub async fn disconnect(&self, connection_id: ConnectionId) {
@ -177,8 +171,6 @@ impl Peer {
pub async fn reset(&self) {
self.connections.write().await.clear();
self.handler_types.lock().await.clear();
self.message_handlers.write().await.clear();
}
pub fn request<T: RequestMessage>(
@ -302,7 +294,115 @@ impl Peer {
}
}
impl<W, R> ConnectionHandler<W, R>
impl<H, Fut> RouterInternal<H>
where
H: Fn(&mut Option<proto::Envelope>, ConnectionId) -> Option<Fut>,
Fut: Future<Output = ()>,
{
pub fn new() -> Self {
Self {
message_handlers: Default::default(),
handler_types: Default::default(),
}
}
async fn handle(&self, connection_id: ConnectionId, message: proto::Envelope) {
let mut envelope = Some(message);
for handler in self.message_handlers.iter() {
if let Some(future) = handler(&mut envelope, connection_id) {
future.await;
return;
}
}
log::warn!("unhandled message: {:?}", envelope.unwrap().payload);
}
}
impl Router {
pub fn add_message_handler<T, Fut, F>(&mut self, handler: F)
where
T: EnvelopedMessage,
Fut: 'static + Send + Future<Output = Result<()>>,
F: 'static + Send + Sync + Fn(TypedEnvelope<T>) -> Fut,
{
if !self.handler_types.insert(TypeId::of::<T>()) {
panic!("duplicate handler type");
}
self.message_handlers
.push(Box::new(move |envelope, connection_id| {
if envelope.as_ref().map_or(false, T::matches_envelope) {
let envelope = Option::take(envelope).unwrap();
let message_id = envelope.id;
let future = handler(TypedEnvelope {
sender_id: connection_id,
original_sender_id: envelope.original_sender_id.map(PeerId),
message_id,
payload: T::from_envelope(envelope).unwrap(),
});
Some(
async move {
if let Err(error) = future.await {
log::error!(
"error handling message {} {}: {:?}",
T::NAME,
message_id,
error
);
}
}
.boxed(),
)
} else {
None
}
}));
}
}
impl ForegroundRouter {
pub fn add_message_handler<T, Fut, F>(&mut self, handler: F)
where
T: EnvelopedMessage,
Fut: 'static + Future<Output = Result<()>>,
F: 'static + Fn(TypedEnvelope<T>) -> Fut,
{
if !self.handler_types.insert(TypeId::of::<T>()) {
panic!("duplicate handler type");
}
self.message_handlers
.push(Box::new(move |envelope, connection_id| {
if envelope.as_ref().map_or(false, T::matches_envelope) {
let envelope = Option::take(envelope).unwrap();
let message_id = envelope.id;
let future = handler(TypedEnvelope {
sender_id: connection_id,
original_sender_id: envelope.original_sender_id.map(PeerId),
message_id,
payload: T::from_envelope(envelope).unwrap(),
});
Some(
async move {
if let Err(error) = future.await {
log::error!(
"error handling message {} {}: {:?}",
T::NAME,
message_id,
error
);
}
}
.boxed_local(),
)
} else {
None
}
}));
}
}
impl<W, R> IOHandler<W, R>
where
W: futures::Sink<WebSocketMessage, Error = WebSocketError> + Unpin,
R: futures::Stream<Item = Result<WebSocketMessage, WebSocketError>> + Unpin,
@ -315,18 +415,18 @@ where
futures::select_biased! {
incoming = read_message => match incoming {
Ok(incoming) => {
Self::handle_incoming_message(incoming, &self.peer, self.connection_id, &self.response_channels).await;
if self.incoming_tx.send(incoming).await.is_err() {
return Ok(());
}
break;
}
Err(error) => {
self.response_channels.lock().await.clear();
Err(error).context("received invalid RPC message")?;
}
},
outgoing = self.outgoing_rx.recv().fuse() => match outgoing {
Some(outgoing) => {
if let Err(result) = self.writer.write_message(&outgoing).await {
self.response_channels.lock().await.clear();
Err(result).context("failed to write RPC message")?;
}
}
@ -350,41 +450,6 @@ where
payload,
})
}
async fn handle_incoming_message(
message: proto::Envelope,
peer: &Arc<Peer>,
connection_id: ConnectionId,
response_channels: &ResponseChannels,
) {
if let Some(responding_to) = message.responding_to {
let channel = response_channels.lock().await.remove(&responding_to);
if let Some(mut tx) = channel {
tx.send(message).await.ok();
} else {
log::warn!("received RPC response to unknown request {}", responding_to);
}
} else {
let mut envelope = Some(message);
let mut handler_index = None;
let mut handler_was_dropped = false;
for (i, handler) in peer.message_handlers.read().await.iter().enumerate() {
if let Some(future) = handler(&mut envelope, connection_id) {
handler_was_dropped = future.await;
handler_index = Some(i);
break;
}
}
if let Some(handler_index) = handler_index {
if handler_was_dropped {
drop(peer.message_handlers.write().await.remove(handler_index));
}
} else {
log::warn!("unhandled message: {:?}", envelope.unwrap().payload);
}
}
}
}
impl<T> Clone for Receipt<T> {
@ -415,7 +480,6 @@ impl fmt::Display for PeerId {
mod tests {
use super::*;
use crate::test;
use postage::oneshot;
#[test]
fn test_request_response() {
@ -425,127 +489,185 @@ mod tests {
let client1 = Peer::new();
let client2 = Peer::new();
let mut router = Router::new();
router.add_message_handler({
let server = server.clone();
move |envelope: TypedEnvelope<proto::Auth>| {
let server = server.clone();
async move {
let receipt = envelope.receipt();
let message = envelope.payload;
server
.respond(
receipt,
match message.user_id {
1 => {
assert_eq!(message.access_token, "access-token-1");
proto::AuthResponse {
credentials_valid: true,
}
}
2 => {
assert_eq!(message.access_token, "access-token-2");
proto::AuthResponse {
credentials_valid: false,
}
}
_ => {
panic!("unexpected user id {}", message.user_id);
}
},
)
.await
}
}
});
router.add_message_handler({
let server = server.clone();
move |envelope: TypedEnvelope<proto::OpenBuffer>| {
let server = server.clone();
async move {
let receipt = envelope.receipt();
let message = envelope.payload;
server
.respond(
receipt,
match message.path.as_str() {
"path/one" => {
assert_eq!(message.worktree_id, 1);
proto::OpenBufferResponse {
buffer: Some(proto::Buffer {
id: 101,
content: "path/one content".to_string(),
history: vec![],
selections: vec![],
}),
}
}
"path/two" => {
assert_eq!(message.worktree_id, 2);
proto::OpenBufferResponse {
buffer: Some(proto::Buffer {
id: 102,
content: "path/two content".to_string(),
history: vec![],
selections: vec![],
}),
}
}
_ => {
panic!("unexpected path {}", message.path);
}
},
)
.await
}
}
});
let router = Arc::new(router);
let (client1_to_server_conn, server_to_client_1_conn) = test::Channel::bidirectional();
let (client1_conn_id, task1) = client1.add_connection(client1_to_server_conn).await;
let (_, task2) = server.add_connection(server_to_client_1_conn).await;
let (client1_conn_id, io_task1, msg_task1) = client1
.add_connection(client1_to_server_conn, router.clone())
.await;
let (_, io_task2, msg_task2) = server
.add_connection(server_to_client_1_conn, router.clone())
.await;
let (client2_to_server_conn, server_to_client_2_conn) = test::Channel::bidirectional();
let (client2_conn_id, task3) = client2.add_connection(client2_to_server_conn).await;
let (_, task4) = server.add_connection(server_to_client_2_conn).await;
let (client2_conn_id, io_task3, msg_task3) = client2
.add_connection(client2_to_server_conn, router.clone())
.await;
let (_, io_task4, msg_task4) = server
.add_connection(server_to_client_2_conn, router.clone())
.await;
smol::spawn(task1.run()).detach();
smol::spawn(task2.run()).detach();
smol::spawn(task3.run()).detach();
smol::spawn(task4.run()).detach();
smol::spawn(io_task1.run()).detach();
smol::spawn(io_task2.run()).detach();
smol::spawn(io_task3.run()).detach();
smol::spawn(io_task4.run()).detach();
smol::spawn(msg_task1).detach();
smol::spawn(msg_task2).detach();
smol::spawn(msg_task3).detach();
smol::spawn(msg_task4).detach();
// define the expected requests and responses
let request1 = proto::Auth {
user_id: 1,
access_token: "token-1".to_string(),
};
let response1 = proto::AuthResponse {
credentials_valid: true,
};
let request2 = proto::Auth {
user_id: 2,
access_token: "token-2".to_string(),
};
let response2 = proto::AuthResponse {
credentials_valid: false,
};
let request3 = proto::OpenBuffer {
worktree_id: 1,
path: "path/two".to_string(),
};
let response3 = proto::OpenBufferResponse {
buffer: Some(proto::Buffer {
id: 2,
content: "path/two content".to_string(),
history: vec![],
selections: vec![],
}),
};
let request4 = proto::OpenBuffer {
worktree_id: 2,
path: "path/one".to_string(),
};
let response4 = proto::OpenBufferResponse {
buffer: Some(proto::Buffer {
id: 1,
content: "path/one content".to_string(),
history: vec![],
selections: vec![],
}),
};
// on the server, respond to two requests for each client
let mut open_buffer_rx = server.add_message_handler::<proto::OpenBuffer>().await;
let mut auth_rx = server.add_message_handler::<proto::Auth>().await;
let (mut server_done_tx, mut server_done_rx) = oneshot::channel::<()>();
smol::spawn({
let request1 = request1.clone();
let request2 = request2.clone();
let request3 = request3.clone();
let request4 = request4.clone();
let response1 = response1.clone();
let response2 = response2.clone();
let response3 = response3.clone();
let response4 = response4.clone();
async move {
let msg = auth_rx.recv().await.unwrap();
assert_eq!(msg.payload, request1);
server
.respond(msg.receipt(), response1.clone())
.await
.unwrap();
let msg = auth_rx.recv().await.unwrap();
assert_eq!(msg.payload, request2.clone());
server
.respond(msg.receipt(), response2.clone())
.await
.unwrap();
let msg = open_buffer_rx.recv().await.unwrap();
assert_eq!(msg.payload, request3.clone());
server
.respond(msg.receipt(), response3.clone())
.await
.unwrap();
let msg = open_buffer_rx.recv().await.unwrap();
assert_eq!(msg.payload, request4.clone());
server
.respond(msg.receipt(), response4.clone())
.await
.unwrap();
server_done_tx.send(()).await.unwrap();
assert_eq!(
client1
.request(
client1_conn_id,
proto::Auth {
user_id: 1,
access_token: "access-token-1".to_string(),
},
)
.await
.unwrap(),
proto::AuthResponse {
credentials_valid: true,
}
})
.detach();
);
assert_eq!(
client1.request(client1_conn_id, request1).await.unwrap(),
response1
client2
.request(
client2_conn_id,
proto::Auth {
user_id: 2,
access_token: "access-token-2".to_string(),
},
)
.await
.unwrap(),
proto::AuthResponse {
credentials_valid: false,
}
);
assert_eq!(
client2.request(client2_conn_id, request2).await.unwrap(),
response2
client1
.request(
client1_conn_id,
proto::OpenBuffer {
worktree_id: 1,
path: "path/one".to_string(),
},
)
.await
.unwrap(),
proto::OpenBufferResponse {
buffer: Some(proto::Buffer {
id: 101,
content: "path/one content".to_string(),
history: vec![],
selections: vec![],
}),
}
);
assert_eq!(
client2.request(client2_conn_id, request3).await.unwrap(),
response3
);
assert_eq!(
client1.request(client1_conn_id, request4).await.unwrap(),
response4
client2
.request(
client2_conn_id,
proto::OpenBuffer {
worktree_id: 2,
path: "path/two".to_string(),
},
)
.await
.unwrap(),
proto::OpenBufferResponse {
buffer: Some(proto::Buffer {
id: 102,
content: "path/two content".to_string(),
history: vec![],
selections: vec![],
}),
}
);
client1.disconnect(client1_conn_id).await;
client2.disconnect(client1_conn_id).await;
server_done_rx.recv().await.unwrap();
});
}
@ -555,17 +677,28 @@ mod tests {
let (client_conn, mut server_conn) = test::Channel::bidirectional();
let client = Peer::new();
let (connection_id, handler) = client.add_connection(client_conn).await;
let (mut incoming_messages_ended_tx, mut incoming_messages_ended_rx) =
postage::barrier::channel();
let router = Arc::new(Router::new());
let (connection_id, io_handler, message_handler) =
client.add_connection(client_conn, router).await;
let (mut io_ended_tx, mut io_ended_rx) = postage::barrier::channel();
smol::spawn(async move {
handler.run().await.ok();
incoming_messages_ended_tx.send(()).await.unwrap();
io_handler.run().await.ok();
io_ended_tx.send(()).await.unwrap();
})
.detach();
let (mut messages_ended_tx, mut messages_ended_rx) = postage::barrier::channel();
smol::spawn(async move {
message_handler.await.ok();
messages_ended_tx.send(()).await.unwrap();
})
.detach();
client.disconnect(connection_id).await;
incoming_messages_ended_rx.recv().await;
io_ended_rx.recv().await;
messages_ended_rx.recv().await;
assert!(
futures::SinkExt::send(&mut server_conn, WebSocketMessage::Binary(vec![]))
.await
@ -581,8 +714,11 @@ mod tests {
drop(server_conn);
let client = Peer::new();
let (connection_id, handler) = client.add_connection(client_conn).await;
smol::spawn(handler.run()).detach();
let router = Arc::new(Router::new());
let (connection_id, io_handler, message_handler) =
client.add_connection(client_conn, router).await;
smol::spawn(io_handler.run()).detach();
smol::spawn(message_handler).detach();
let err = client
.request(

View file

@ -4015,7 +4015,8 @@ mod tests {
let history = History::new(text.into());
Buffer::from_history(0, history, None, lang.cloned(), cx)
});
let (_, view) = cx.add_window(|cx| Editor::for_buffer(buffer, app_state.settings, cx));
let (_, view) =
cx.add_window(|cx| Editor::for_buffer(buffer, app_state.settings.clone(), cx));
view.condition(&cx, |view, cx| !view.buffer.read(cx).is_parsing())
.await;

View file

@ -719,6 +719,7 @@ impl Buffer {
mtime: SystemTime,
cx: &mut ModelContext<Self>,
) {
eprintln!("{} did_save {:?}", self.replica_id, version);
self.saved_mtime = mtime;
self.saved_version = version;
cx.emit(Event::Saved);

View file

@ -479,8 +479,12 @@ mod tests {
let app_state = cx.read(build_app_state);
let (window_id, workspace) = cx.add_window(|cx| {
let mut workspace =
Workspace::new(app_state.settings, app_state.languages, app_state.rpc, cx);
let mut workspace = Workspace::new(
app_state.settings.clone(),
app_state.languages.clone(),
app_state.rpc.clone(),
cx,
);
workspace.add_worktree(tmp_dir.path(), cx);
workspace
});
@ -559,7 +563,7 @@ mod tests {
cx.read(|cx| workspace.read(cx).worktree_scans_complete(cx))
.await;
let (_, finder) =
cx.add_window(|cx| FileFinder::new(app_state.settings, workspace.clone(), cx));
cx.add_window(|cx| FileFinder::new(app_state.settings.clone(), workspace.clone(), cx));
let query = "hi".to_string();
finder
@ -622,7 +626,7 @@ mod tests {
cx.read(|cx| workspace.read(cx).worktree_scans_complete(cx))
.await;
let (_, finder) =
cx.add_window(|cx| FileFinder::new(app_state.settings, workspace.clone(), cx));
cx.add_window(|cx| FileFinder::new(app_state.settings.clone(), workspace.clone(), cx));
// Even though there is only one worktree, that worktree's filename
// is included in the matching, because the worktree is a single file.
@ -681,7 +685,7 @@ mod tests {
.await;
let (_, finder) =
cx.add_window(|cx| FileFinder::new(app_state.settings, workspace.clone(), cx));
cx.add_window(|cx| FileFinder::new(app_state.settings.clone(), workspace.clone(), cx));
// Run a search that matches two files with the same relative path.
finder

View file

@ -1,3 +1,5 @@
use zed_rpc::ForegroundRouter;
pub mod assets;
pub mod editor;
pub mod file_finder;
@ -14,10 +16,10 @@ mod util;
pub mod workspace;
pub mod worktree;
#[derive(Clone)]
pub struct AppState {
pub settings: postage::watch::Receiver<settings::Settings>,
pub languages: std::sync::Arc<language::LanguageRegistry>,
pub rpc_router: std::sync::Arc<ForegroundRouter>,
pub rpc: rpc::Client,
}

View file

@ -10,6 +10,7 @@ use zed::{
workspace::{self, OpenParams},
worktree, AppState,
};
use zed_rpc::ForegroundRouter;
fn main() {
init_logger();
@ -20,20 +21,27 @@ fn main() {
let languages = Arc::new(language::LanguageRegistry::new());
languages.set_theme(&settings.borrow().theme);
let app_state = AppState {
let mut app_state = AppState {
languages: languages.clone(),
settings,
rpc_router: Arc::new(ForegroundRouter::new()),
rpc: rpc::Client::new(languages),
};
app.run(move |cx| {
cx.set_menus(menus::menus(app_state.clone()));
worktree::init(
cx,
&app_state.rpc,
Arc::get_mut(&mut app_state.rpc_router).unwrap(),
);
zed::init(cx);
workspace::init(cx);
worktree::init(cx, app_state.rpc.clone());
editor::init(cx);
file_finder::init(cx);
let app_state = Arc::new(app_state);
cx.set_menus(menus::menus(&app_state.clone()));
if stdout_is_a_pty() {
cx.platform().activate(true);
}

View file

@ -1,8 +1,9 @@
use crate::AppState;
use gpui::{Menu, MenuItem};
use std::sync::Arc;
#[cfg(target_os = "macos")]
pub fn menus(state: AppState) -> Vec<Menu<'static>> {
pub fn menus(state: &Arc<AppState>) -> Vec<Menu<'static>> {
vec![
Menu {
name: "Zed",
@ -48,7 +49,7 @@ pub fn menus(state: AppState) -> Vec<Menu<'static>> {
name: "Open…",
keystroke: Some("cmd-o"),
action: "workspace:open",
arg: Some(Box::new(state)),
arg: Some(Box::new(state.clone())),
},
],
},

View file

@ -1,10 +1,8 @@
use crate::{language::LanguageRegistry, worktree::Worktree};
use anyhow::{anyhow, Context, Result};
use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
use gpui::executor::Background;
use gpui::{AsyncAppContext, ModelHandle, Task, WeakModelHandle};
use lazy_static::lazy_static;
use postage::prelude::Stream;
use smol::lock::RwLock;
use std::collections::HashMap;
use std::time::Duration;
@ -13,7 +11,7 @@ use surf::Url;
pub use zed_rpc::{proto, ConnectionId, PeerId, TypedEnvelope};
use zed_rpc::{
proto::{EnvelopedMessage, RequestMessage},
Peer, Receipt,
ForegroundRouter, Peer, Receipt,
};
lazy_static! {
@ -63,24 +61,30 @@ impl Client {
}
}
pub fn on_message<H, M>(&self, handler: H, cx: &mut gpui::MutableAppContext)
where
H: 'static + for<'a> MessageHandler<'a, M>,
pub fn on_message<H, M>(
&self,
router: &mut ForegroundRouter,
handler: H,
cx: &mut gpui::MutableAppContext,
) where
H: 'static + Clone + for<'a> MessageHandler<'a, M>,
M: proto::EnvelopedMessage,
{
let this = self.clone();
let mut messages = smol::block_on(this.peer.add_message_handler::<M>());
cx.spawn(|mut cx| async move {
while let Some(message) = messages.recv().await {
if let Err(err) = handler.handle(message, &this, &mut cx).await {
log::error!("error handling message: {:?}", err);
}
}
})
.detach();
let cx = cx.to_async();
router.add_message_handler(move |message| {
let this = this.clone();
let mut cx = cx.clone();
let handler = handler.clone();
async move { handler.handle(message, &this, &mut cx).await }
});
}
pub async fn log_in_and_connect(&self, cx: &AsyncAppContext) -> surf::Result<()> {
pub async fn log_in_and_connect(
&self,
router: Arc<ForegroundRouter>,
cx: AsyncAppContext,
) -> surf::Result<()> {
if self.state.read().await.connection_id.is_some() {
return Ok(());
}
@ -96,14 +100,14 @@ impl Client {
.await
.context("websocket handshake")?;
log::info!("connected to rpc address {}", *ZED_SERVER_URL);
self.add_connection(stream, user_id, access_token, &cx.background())
self.add_connection(stream, user_id, access_token, router, cx)
.await?;
} else if let Some(host) = ZED_SERVER_URL.strip_prefix("http://") {
let stream = smol::net::TcpStream::connect(host).await?;
let (stream, _) =
async_tungstenite::client_async(format!("ws://{}/rpc", host), stream).await?;
log::info!("connected to rpc address {}", *ZED_SERVER_URL);
self.add_connection(stream, user_id, access_token, &cx.background())
self.add_connection(stream, user_id, access_token, router, cx)
.await?;
} else {
return Err(anyhow!("invalid server url: {}", *ZED_SERVER_URL))?;
@ -117,7 +121,8 @@ impl Client {
conn: Conn,
user_id: i32,
access_token: String,
executor: &Arc<Background>,
router: Arc<ForegroundRouter>,
cx: AsyncAppContext,
) -> surf::Result<()>
where
Conn: 'static
@ -126,10 +131,12 @@ impl Client {
+ Unpin
+ Send,
{
let (connection_id, handler) = self.peer.add_connection(conn).await;
executor
let (connection_id, handle_io, handle_messages) =
self.peer.add_connection(conn, router).await;
cx.foreground().spawn(handle_messages).detach();
cx.background()
.spawn(async move {
if let Err(error) = handler.run().await {
if let Err(error) = handle_io.run().await {
log::error!("connection error: {:?}", error);
}
})
@ -263,7 +270,7 @@ impl Client {
}
}
pub trait MessageHandler<'a, M: proto::EnvelopedMessage> {
pub trait MessageHandler<'a, M: proto::EnvelopedMessage>: Clone {
type Output: 'a + Future<Output = anyhow::Result<()>>;
fn handle(
@ -277,7 +284,7 @@ pub trait MessageHandler<'a, M: proto::EnvelopedMessage> {
impl<'a, M, F, Fut> MessageHandler<'a, M> for F
where
M: proto::EnvelopedMessage,
F: Fn(TypedEnvelope<M>, &'a Client, &'a mut gpui::AsyncAppContext) -> Fut,
F: Clone + Fn(TypedEnvelope<M>, &'a Client, &'a mut gpui::AsyncAppContext) -> Fut,
Fut: 'a + Future<Output = anyhow::Result<()>>,
{
type Output = Fut;

View file

@ -5,6 +5,7 @@ use std::{
sync::Arc,
};
use tempdir::TempDir;
use zed_rpc::ForegroundRouter;
#[cfg(feature = "test-support")]
pub use zed_rpc::test::Channel;
@ -143,12 +144,13 @@ fn write_tree(path: &Path, tree: serde_json::Value) {
}
}
pub fn build_app_state(cx: &AppContext) -> AppState {
pub fn build_app_state(cx: &AppContext) -> Arc<AppState> {
let settings = settings::channel(&cx.font_cache()).unwrap().1;
let languages = Arc::new(LanguageRegistry::new());
AppState {
Arc::new(AppState {
settings,
languages: languages.clone(),
rpc_router: Arc::new(ForegroundRouter::new()),
rpc: rpc::Client::new(languages),
}
})
}

View file

@ -45,10 +45,10 @@ pub fn init(cx: &mut MutableAppContext) {
pub struct OpenParams {
pub paths: Vec<PathBuf>,
pub app_state: AppState,
pub app_state: Arc<AppState>,
}
fn open(app_state: &AppState, cx: &mut MutableAppContext) {
fn open(app_state: &Arc<AppState>, cx: &mut MutableAppContext) {
let app_state = app_state.clone();
cx.prompt_for_paths(
PathPromptOptions {
@ -101,7 +101,7 @@ fn open_paths(params: &OpenParams, cx: &mut MutableAppContext) {
});
}
fn open_new(app_state: &AppState, cx: &mut MutableAppContext) {
fn open_new(app_state: &Arc<AppState>, cx: &mut MutableAppContext) {
cx.add_window(|cx| {
let mut view = Workspace::new(
app_state.settings.clone(),
@ -700,12 +700,13 @@ impl Workspace {
};
}
fn share_worktree(&mut self, _: &(), cx: &mut ViewContext<Self>) {
fn share_worktree(&mut self, app_state: &Arc<AppState>, cx: &mut ViewContext<Self>) {
let rpc = self.rpc.clone();
let platform = cx.platform();
let router = app_state.rpc_router.clone();
let task = cx.spawn(|this, mut cx| async move {
rpc.log_in_and_connect(&cx).await?;
rpc.log_in_and_connect(router, cx.clone()).await?;
let share_task = this.update(&mut cx, |this, cx| {
let worktree = this.worktrees.iter().next()?;
@ -732,12 +733,13 @@ impl Workspace {
.detach();
}
fn join_worktree(&mut self, _: &(), cx: &mut ViewContext<Self>) {
fn join_worktree(&mut self, app_state: &Arc<AppState>, cx: &mut ViewContext<Self>) {
let rpc = self.rpc.clone();
let languages = self.languages.clone();
let router = app_state.rpc_router.clone();
let task = cx.spawn(|this, mut cx| async move {
rpc.log_in_and_connect(&cx).await?;
rpc.log_in_and_connect(router, cx.clone()).await?;
let worktree_url = cx
.platform()
@ -974,8 +976,12 @@ mod tests {
let app_state = cx.read(build_app_state);
let (_, workspace) = cx.add_window(|cx| {
let mut workspace =
Workspace::new(app_state.settings, app_state.languages, app_state.rpc, cx);
let mut workspace = Workspace::new(
app_state.settings.clone(),
app_state.languages.clone(),
app_state.rpc.clone(),
cx,
);
workspace.add_worktree(dir.path(), cx);
workspace
});
@ -1077,8 +1083,12 @@ mod tests {
let app_state = cx.read(build_app_state);
let (_, workspace) = cx.add_window(|cx| {
let mut workspace =
Workspace::new(app_state.settings, app_state.languages, app_state.rpc, cx);
let mut workspace = Workspace::new(
app_state.settings.clone(),
app_state.languages.clone(),
app_state.rpc.clone(),
cx,
);
workspace.add_worktree(dir1.path(), cx);
workspace
});
@ -1146,8 +1156,12 @@ mod tests {
let app_state = cx.read(build_app_state);
let (window_id, workspace) = cx.add_window(|cx| {
let mut workspace =
Workspace::new(app_state.settings, app_state.languages, app_state.rpc, cx);
let mut workspace = Workspace::new(
app_state.settings.clone(),
app_state.languages.clone(),
app_state.rpc.clone(),
cx,
);
workspace.add_worktree(dir.path(), cx);
workspace
});
@ -1315,8 +1329,12 @@ mod tests {
let app_state = cx.read(build_app_state);
let (window_id, workspace) = cx.add_window(|cx| {
let mut workspace =
Workspace::new(app_state.settings, app_state.languages, app_state.rpc, cx);
let mut workspace = Workspace::new(
app_state.settings.clone(),
app_state.languages.clone(),
app_state.rpc.clone(),
cx,
);
workspace.add_worktree(dir.path(), cx);
workspace
});

View file

@ -48,21 +48,21 @@ use std::{
},
time::{Duration, SystemTime},
};
use zed_rpc::{PeerId, TypedEnvelope};
use zed_rpc::{ForegroundRouter, PeerId, TypedEnvelope};
lazy_static! {
static ref GITIGNORE: &'static OsStr = OsStr::new(".gitignore");
}
pub fn init(cx: &mut MutableAppContext, rpc: rpc::Client) {
rpc.on_message(remote::add_peer, cx);
rpc.on_message(remote::remove_peer, cx);
rpc.on_message(remote::update_worktree, cx);
rpc.on_message(remote::open_buffer, cx);
rpc.on_message(remote::close_buffer, cx);
rpc.on_message(remote::update_buffer, cx);
rpc.on_message(remote::buffer_saved, cx);
rpc.on_message(remote::save_buffer, cx);
pub fn init(cx: &mut MutableAppContext, rpc: &rpc::Client, router: &mut ForegroundRouter) {
rpc.on_message(router, remote::add_peer, cx);
rpc.on_message(router, remote::remove_peer, cx);
rpc.on_message(router, remote::update_worktree, cx);
rpc.on_message(router, remote::open_buffer, cx);
rpc.on_message(router, remote::close_buffer, cx);
rpc.on_message(router, remote::update_buffer, cx);
rpc.on_message(router, remote::buffer_saved, cx);
rpc.on_message(router, remote::save_buffer, cx);
}
#[async_trait::async_trait]
@ -2861,6 +2861,8 @@ mod remote {
rpc: &rpc::Client,
cx: &mut AsyncAppContext,
) -> anyhow::Result<()> {
eprintln!("got update buffer message {:?}", envelope.payload);
let message = envelope.payload;
rpc.state
.read()
@ -2875,6 +2877,8 @@ mod remote {
rpc: &rpc::Client,
cx: &mut AsyncAppContext,
) -> anyhow::Result<()> {
eprintln!("got save buffer message {:?}", envelope.payload);
let state = rpc.state.read().await;
let worktree = state.shared_worktree(envelope.payload.worktree_id, cx)?;
let sender_id = envelope.original_sender_id()?;
@ -2905,6 +2909,8 @@ mod remote {
rpc: &rpc::Client,
cx: &mut AsyncAppContext,
) -> anyhow::Result<()> {
eprintln!("got buffer_saved {:?}", envelope.payload);
rpc.state
.read()
.await
@ -2993,7 +2999,7 @@ mod tests {
let dir = temp_tree(json!({
"file1": "the old contents",
}));
let tree = cx.add_model(|cx| Worktree::local(dir.path(), app_state.languages, cx));
let tree = cx.add_model(|cx| Worktree::local(dir.path(), app_state.languages.clone(), cx));
let buffer = tree
.update(&mut cx, |tree, cx| tree.open_buffer("file1", cx))
.await
@ -3016,7 +3022,8 @@ mod tests {
}));
let file_path = dir.path().join("file1");
let tree = cx.add_model(|cx| Worktree::local(file_path.clone(), app_state.languages, cx));
let tree =
cx.add_model(|cx| Worktree::local(file_path.clone(), app_state.languages.clone(), cx));
cx.read(|cx| tree.read(cx).as_local().unwrap().scan_complete())
.await;
cx.read(|cx| assert_eq!(tree.read(cx).file_count(), 1));