diff --git a/Cargo.lock b/Cargo.lock index 92867463c4..5e22f78551 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -998,7 +998,6 @@ dependencies = [ name = "clock" version = "0.1.0" dependencies = [ - "rpc", "smallvec", ] @@ -2236,6 +2235,7 @@ dependencies = [ "tiny-skia", "tree-sitter", "usvg", + "util", "waker-fn", ] @@ -3959,6 +3959,7 @@ dependencies = [ "async-lock", "async-tungstenite", "base64 0.13.0", + "clock", "futures", "gpui", "log", @@ -3972,6 +3973,7 @@ dependencies = [ "smol", "smol-timeout", "tempdir", + "util", "zstd", ] @@ -5574,7 +5576,6 @@ name = "util" version = "0.1.0" dependencies = [ "anyhow", - "clock", "futures", "log", "rand 0.8.3", @@ -5959,6 +5960,7 @@ name = "zed-server" version = "0.1.0" dependencies = [ "anyhow", + "async-io", "async-sqlx-session", "async-std", "async-trait", diff --git a/crates/client/src/client.rs b/crates/client/src/client.rs index c40b78987c..0b26743a24 100644 --- a/crates/client/src/client.rs +++ b/crates/client/src/client.rs @@ -137,8 +137,8 @@ struct ClientState { credentials: Option, status: (watch::Sender, watch::Receiver), entity_id_extractors: HashMap u64>>, - _maintain_connection: Option>, - heartbeat_interval: Duration, + _reconnect_task: Option>, + reconnect_interval: Duration, models_by_entity_type_and_remote_id: HashMap<(TypeId, u64), AnyWeakModelHandle>, models_by_message_type: HashMap, model_types_by_message_type: HashMap, @@ -168,8 +168,8 @@ impl Default for ClientState { credentials: None, status: watch::channel_with(Status::SignedOut), entity_id_extractors: Default::default(), - _maintain_connection: None, - heartbeat_interval: Duration::from_secs(5), + _reconnect_task: None, + reconnect_interval: Duration::from_secs(5), models_by_message_type: Default::default(), models_by_entity_type_and_remote_id: Default::default(), model_types_by_message_type: Default::default(), @@ -236,7 +236,7 @@ impl Client { #[cfg(any(test, feature = "test-support"))] pub fn tear_down(&self) { let mut state = self.state.write(); - state._maintain_connection.take(); + state._reconnect_task.take(); state.message_handlers.clear(); state.models_by_message_type.clear(); state.models_by_entity_type_and_remote_id.clear(); @@ -283,21 +283,13 @@ impl Client { match status { Status::Connected { .. } => { - let heartbeat_interval = state.heartbeat_interval; - let this = self.clone(); - let foreground = cx.foreground(); - state._maintain_connection = Some(cx.foreground().spawn(async move { - loop { - foreground.timer(heartbeat_interval).await; - let _ = this.request(proto::Ping {}).await; - } - })); + state._reconnect_task = None; } Status::ConnectionLost => { let this = self.clone(); let foreground = cx.foreground(); - let heartbeat_interval = state.heartbeat_interval; - state._maintain_connection = Some(cx.spawn(|cx| async move { + let reconnect_interval = state.reconnect_interval; + state._reconnect_task = Some(cx.spawn(|cx| async move { let mut rng = StdRng::from_entropy(); let mut delay = Duration::from_millis(100); while let Err(error) = this.authenticate_and_connect(&cx).await { @@ -311,12 +303,12 @@ impl Client { foreground.timer(delay).await; delay = delay .mul_f32(rng.gen_range(1.0..=2.0)) - .min(heartbeat_interval); + .min(reconnect_interval); } })); } Status::SignedOut | Status::UpgradeRequired => { - state._maintain_connection.take(); + state._reconnect_task.take(); } _ => {} } @@ -548,7 +540,11 @@ impl Client { } async fn set_connection(self: &Arc, conn: Connection, cx: &AsyncAppContext) { - let (connection_id, handle_io, mut incoming) = self.peer.add_connection(conn).await; + let executor = cx.background(); + let (connection_id, handle_io, mut incoming) = self + .peer + .add_connection(conn, move |duration| executor.timer(duration)) + .await; cx.foreground() .spawn({ let cx = cx.clone(); @@ -940,26 +936,6 @@ mod tests { use crate::test::{FakeHttpClient, FakeServer}; use gpui::TestAppContext; - #[gpui::test(iterations = 10)] - async fn test_heartbeat(cx: &mut TestAppContext) { - cx.foreground().forbid_parking(); - - let user_id = 5; - let mut client = Client::new(FakeHttpClient::with_404_response()); - let server = FakeServer::for_client(user_id, &mut client, &cx).await; - - cx.foreground().advance_clock(Duration::from_secs(10)); - let ping = server.receive::().await.unwrap(); - server.respond(ping.receipt(), proto::Ack {}).await; - - cx.foreground().advance_clock(Duration::from_secs(10)); - let ping = server.receive::().await.unwrap(); - server.respond(ping.receipt(), proto::Ack {}).await; - - client.disconnect(&cx.to_async()).unwrap(); - assert!(server.receive::().await.is_err()); - } - #[gpui::test(iterations = 10)] async fn test_reconnection(cx: &mut TestAppContext) { cx.foreground().forbid_parking(); diff --git a/crates/client/src/test.rs b/crates/client/src/test.rs index 697bf3860c..f5d88c2d9a 100644 --- a/crates/client/src/test.rs +++ b/crates/client/src/test.rs @@ -75,7 +75,8 @@ impl FakeServer { } let (client_conn, server_conn, _) = Connection::in_memory(cx.background()); - let (connection_id, io, incoming) = peer.add_connection(server_conn).await; + let (connection_id, io, incoming) = + peer.add_test_connection(server_conn, cx.background()).await; cx.background().spawn(io).detach(); let mut state = state.lock(); state.connection_id = Some(connection_id); diff --git a/crates/gpui/Cargo.toml b/crates/gpui/Cargo.toml index 8f884259b7..9973ac6549 100644 --- a/crates/gpui/Cargo.toml +++ b/crates/gpui/Cargo.toml @@ -14,6 +14,7 @@ test-support = ["backtrace", "dhat", "env_logger", "collections/test-support"] [dependencies] collections = { path = "../collections" } gpui_macros = { path = "../gpui_macros" } +util = { path = "../util" } sum_tree = { path = "../sum_tree" } async-task = "4.0.3" backtrace = { version = "0.3", optional = true } diff --git a/crates/rpc/Cargo.toml b/crates/rpc/Cargo.toml index 255906ab85..e773b3f0ba 100644 --- a/crates/rpc/Cargo.toml +++ b/crates/rpc/Cargo.toml @@ -26,7 +26,9 @@ rsa = "0.4" serde = { version = "1", features = ["derive"] } smol-timeout = "0.6" zstd = "0.9" +clock = { path = "../clock" } gpui = { path = "../gpui", optional = true } +util = { path = "../util" } [build-dependencies] prost-build = "0.8" diff --git a/crates/rpc/src/peer.rs b/crates/rpc/src/peer.rs index 0a00f6d801..f9c94cc84d 100644 --- a/crates/rpc/src/peer.rs +++ b/crates/rpc/src/peer.rs @@ -94,6 +94,7 @@ pub struct ConnectionState { Arc>>>>, } +const KEEPALIVE_INTERVAL: Duration = Duration::from_secs(2); const WRITE_TIMEOUT: Duration = Duration::from_secs(10); impl Peer { @@ -104,14 +105,20 @@ impl Peer { }) } - pub async fn add_connection( + pub async fn add_connection( self: &Arc, connection: Connection, + create_timer: F, ) -> ( ConnectionId, impl Future> + Send, BoxStream<'static, Box>, - ) { + ) + where + F: Send + Fn(Duration) -> Fut, + Fut: Send + Future, + Out: Send, + { // For outgoing messages, use an unbounded channel so that application code // can always send messages without yielding. For incoming messages, use a // bounded channel so that other peers will receive backpressure if they send @@ -121,7 +128,7 @@ impl Peer { let connection_id = ConnectionId(self.next_connection_id.fetch_add(1, SeqCst)); let connection_state = ConnectionState { - outgoing_tx, + outgoing_tx: outgoing_tx.clone(), next_message_id: Default::default(), response_channels: Arc::new(Mutex::new(Some(Default::default()))), }; @@ -131,39 +138,43 @@ impl Peer { let this = self.clone(); let response_channels = connection_state.response_channels.clone(); let handle_io = async move { - let result = 'outer: loop { + let _end_connection = util::defer(|| { + response_channels.lock().take(); + this.connections.write().remove(&connection_id); + }); + + loop { let read_message = reader.read_message().fuse(); futures::pin_mut!(read_message); loop { futures::select_biased! { outgoing = outgoing_rx.next().fuse() => match outgoing { Some(outgoing) => { - match writer.write_message(&outgoing).timeout(WRITE_TIMEOUT).await { - None => break 'outer Err(anyhow!("timed out writing RPC message")), - Some(Err(result)) => break 'outer Err(result).context("failed to write RPC message"), - _ => {} + if let Some(result) = writer.write_message(&outgoing).timeout(WRITE_TIMEOUT).await { + result.context("failed to write RPC message")?; + } else { + Err(anyhow!("timed out writing message"))?; } } - None => break 'outer Ok(()), + None => return Ok(()), }, - incoming = read_message => match incoming { - Ok(incoming) => { - if incoming_tx.send(incoming).await.is_err() { - break 'outer Ok(()); - } - break; - } - Err(error) => { - break 'outer Err(error).context("received invalid RPC message") + incoming = read_message => { + let incoming = incoming.context("received invalid rpc message")?; + if incoming_tx.send(incoming).await.is_err() { + return Ok(()); } + break; }, + _ = create_timer(KEEPALIVE_INTERVAL).fuse() => { + if let Some(result) = writer.ping().timeout(WRITE_TIMEOUT).await { + result.context("failed to send websocket ping")?; + } else { + Err(anyhow!("timed out sending websocket ping"))?; + } + } } } - }; - - response_channels.lock().take(); - this.connections.write().remove(&connection_id); - result + } }; let response_channels = connection_state.response_channels.clone(); @@ -191,18 +202,31 @@ impl Peer { None } else { - if let Some(envelope) = proto::build_typed_envelope(connection_id, incoming) { - Some(envelope) - } else { + proto::build_typed_envelope(connection_id, incoming).or_else(|| { log::error!("unable to construct a typed envelope"); None - } + }) } } }); (connection_id, handle_io, incoming_rx.boxed()) } + #[cfg(any(test, feature = "test-support"))] + pub async fn add_test_connection( + self: &Arc, + connection: Connection, + executor: Arc, + ) -> ( + ConnectionId, + impl Future> + Send, + BoxStream<'static, Box>, + ) { + let executor = executor.clone(); + self.add_connection(connection, move |duration| executor.timer(duration)) + .await + } + pub fn disconnect(&self, connection_id: ConnectionId) { self.connections.write().remove(&connection_id); } @@ -349,15 +373,21 @@ mod tests { let (client1_to_server_conn, server_to_client_1_conn, _) = Connection::in_memory(cx.background()); - let (client1_conn_id, io_task1, client1_incoming) = - client1.add_connection(client1_to_server_conn).await; - let (_, io_task2, server_incoming1) = server.add_connection(server_to_client_1_conn).await; + let (client1_conn_id, io_task1, client1_incoming) = client1 + .add_test_connection(client1_to_server_conn, cx.background()) + .await; + let (_, io_task2, server_incoming1) = server + .add_test_connection(server_to_client_1_conn, cx.background()) + .await; let (client2_to_server_conn, server_to_client_2_conn, _) = Connection::in_memory(cx.background()); - let (client2_conn_id, io_task3, client2_incoming) = - client2.add_connection(client2_to_server_conn).await; - let (_, io_task4, server_incoming2) = server.add_connection(server_to_client_2_conn).await; + let (client2_conn_id, io_task3, client2_incoming) = client2 + .add_test_connection(client2_to_server_conn, cx.background()) + .await; + let (_, io_task4, server_incoming2) = server + .add_test_connection(server_to_client_2_conn, cx.background()) + .await; executor.spawn(io_task1).detach(); executor.spawn(io_task2).detach(); @@ -440,10 +470,12 @@ mod tests { let (client_to_server_conn, server_to_client_conn, _) = Connection::in_memory(cx.background()); - let (client_to_server_conn_id, io_task1, mut client_incoming) = - client.add_connection(client_to_server_conn).await; - let (server_to_client_conn_id, io_task2, mut server_incoming) = - server.add_connection(server_to_client_conn).await; + let (client_to_server_conn_id, io_task1, mut client_incoming) = client + .add_test_connection(client_to_server_conn, cx.background()) + .await; + let (server_to_client_conn_id, io_task2, mut server_incoming) = server + .add_test_connection(server_to_client_conn, cx.background()) + .await; executor.spawn(io_task1).detach(); executor.spawn(io_task2).detach(); @@ -538,10 +570,12 @@ mod tests { let (client_to_server_conn, server_to_client_conn, _) = Connection::in_memory(cx.background()); - let (client_to_server_conn_id, io_task1, mut client_incoming) = - client.add_connection(client_to_server_conn).await; - let (server_to_client_conn_id, io_task2, mut server_incoming) = - server.add_connection(server_to_client_conn).await; + let (client_to_server_conn_id, io_task1, mut client_incoming) = client + .add_test_connection(client_to_server_conn, cx.background()) + .await; + let (server_to_client_conn_id, io_task2, mut server_incoming) = server + .add_test_connection(server_to_client_conn, cx.background()) + .await; executor.spawn(io_task1).detach(); executor.spawn(io_task2).detach(); @@ -649,7 +683,9 @@ mod tests { let (client_conn, mut server_conn, _) = Connection::in_memory(cx.background()); let client = Peer::new(); - let (connection_id, io_handler, mut incoming) = client.add_connection(client_conn).await; + let (connection_id, io_handler, mut incoming) = client + .add_test_connection(client_conn, cx.background()) + .await; let (mut io_ended_tx, mut io_ended_rx) = postage::barrier::channel(); executor @@ -683,7 +719,9 @@ mod tests { let (client_conn, mut server_conn, _) = Connection::in_memory(cx.background()); let client = Peer::new(); - let (connection_id, io_handler, mut incoming) = client.add_connection(client_conn).await; + let (connection_id, io_handler, mut incoming) = client + .add_test_connection(client_conn, cx.background()) + .await; executor.spawn(io_handler).detach(); executor .spawn(async move { incoming.next().await }) diff --git a/crates/rpc/src/proto.rs b/crates/rpc/src/proto.rs index 0729dbc76a..3d7557842a 100644 --- a/crates/rpc/src/proto.rs +++ b/crates/rpc/src/proto.rs @@ -318,6 +318,13 @@ where self.stream.send(WebSocketMessage::Binary(buffer)).await?; Ok(()) } + + pub async fn ping(&mut self) -> Result<(), WebSocketError> { + self.stream + .send(WebSocketMessage::Ping(Default::default())) + .await?; + Ok(()) + } } impl MessageStream diff --git a/crates/server/Cargo.toml b/crates/server/Cargo.toml index d7984dad04..c39fb2f10b 100644 --- a/crates/server/Cargo.toml +++ b/crates/server/Cargo.toml @@ -16,6 +16,7 @@ required-features = ["seed-support"] collections = { path = "../collections" } rpc = { path = "../rpc" } anyhow = "1.0.40" +async-io = "1.3" async-std = { version = "1.8.0", features = ["attributes"] } async-trait = "0.1.50" async-tungstenite = "0.16" diff --git a/crates/server/src/rpc.rs b/crates/server/src/rpc.rs index 9a8b4ee161..9f812ba104 100644 --- a/crates/server/src/rpc.rs +++ b/crates/server/src/rpc.rs @@ -6,6 +6,7 @@ use super::{ AppState, }; use anyhow::anyhow; +use async_io::Timer; use async_std::task; use async_tungstenite::{tungstenite::protocol::Role, WebSocketStream}; use collections::{HashMap, HashSet}; @@ -16,7 +17,12 @@ use rpc::{ Connection, ConnectionId, Peer, TypedEnvelope, }; use sha1::{Digest as _, Sha1}; -use std::{any::TypeId, future::Future, sync::Arc, time::Instant}; +use std::{ + any::TypeId, + future::Future, + sync::Arc, + time::{Duration, Instant}, +}; use store::{Store, Worktree}; use surf::StatusCode; use tide::log; @@ -40,10 +46,13 @@ pub struct Server { notifications: Option>, } -pub trait Executor { +pub trait Executor: Send + Clone { + type Timer: Send + Future; fn spawn_detached>(&self, future: F); + fn timer(&self, duration: Duration) -> Self::Timer; } +#[derive(Clone)] pub struct RealExecutor; const MESSAGE_COUNT_PER_PAGE: usize = 100; @@ -167,8 +176,18 @@ impl Server { ) -> impl Future { let mut this = self.clone(); async move { - let (connection_id, handle_io, mut incoming_rx) = - this.peer.add_connection(connection).await; + let (connection_id, handle_io, mut incoming_rx) = this + .peer + .add_connection(connection, { + let executor = executor.clone(); + move |duration| { + let timer = executor.timer(duration); + async move { + timer.await; + } + } + }) + .await; if let Some(send_connection_id) = send_connection_id.as_mut() { let _ = send_connection_id.send(connection_id).await; @@ -883,9 +902,15 @@ impl Server { } impl Executor for RealExecutor { + type Timer = Timer; + fn spawn_detached>(&self, future: F) { task::spawn(future); } + + fn timer(&self, duration: Duration) -> Self::Timer { + Timer::after(duration) + } } fn broadcast( @@ -1759,7 +1784,7 @@ mod tests { } #[gpui::test(iterations = 10)] - async fn test_peer_disconnection(cx_a: &mut TestAppContext, cx_b: &mut TestAppContext) { + async fn test_leaving_project(cx_a: &mut TestAppContext, cx_b: &mut TestAppContext) { cx_a.foreground().forbid_parking(); let lang_registry = Arc::new(LanguageRegistry::new()); let fs = FakeFs::new(cx_a.background()); @@ -1817,16 +1842,39 @@ mod tests { .await .unwrap(); - // See that a guest has joined as client A. + // Client A sees that a guest has joined. project_a .condition(&cx_a, |p, _| p.collaborators().len() == 1) .await; - // Drop client B's connection and ensure client A observes client B leaving the worktree. + // Drop client B's connection and ensure client A observes client B leaving the project. client_b.disconnect(&cx_b.to_async()).unwrap(); project_a .condition(&cx_a, |p, _| p.collaborators().len() == 0) .await; + + // Rejoin the project as client B + let _project_b = Project::remote( + project_id, + client_b.clone(), + client_b.user_store.clone(), + lang_registry.clone(), + fs.clone(), + &mut cx_b.to_async(), + ) + .await + .unwrap(); + + // Client A sees that a guest has re-joined. + project_a + .condition(&cx_a, |p, _| p.collaborators().len() == 1) + .await; + + // Simulate connection loss for client B and ensure client A observes client B leaving the project. + server.disconnect_client(client_b.current_user_id(cx_b)); + project_a + .condition(&cx_a, |p, _| p.collaborators().len() == 0) + .await; } #[gpui::test(iterations = 10)] @@ -5031,9 +5079,15 @@ mod tests { } impl Executor for Arc { + type Timer = BoxFuture<'static, ()>; + fn spawn_detached>(&self, future: F) { self.spawn(future).detach(); } + + fn timer(&self, duration: Duration) -> Self::Timer { + self.as_ref().timer(duration).boxed() + } } fn channel_messages(channel: &Channel) -> Vec<(String, String, bool)> {