diff --git a/crates/channel/src/channel_buffer.rs b/crates/channel/src/channel_buffer.rs index c19899501a..29f4d3493c 100644 --- a/crates/channel/src/channel_buffer.rs +++ b/crates/channel/src/channel_buffer.rs @@ -1,4 +1,4 @@ -use crate::{Channel, ChannelId, ChannelStore}; +use crate::Channel; use anyhow::Result; use client::Client; use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle}; @@ -13,39 +13,43 @@ pub(crate) fn init(client: &Arc) { } pub struct ChannelBuffer { - channel_id: ChannelId, + pub(crate) channel: Arc, + connected: bool, collaborators: Vec, buffer: ModelHandle, - channel_store: ModelHandle, client: Arc, - _subscription: client::Subscription, + subscription: Option, } pub enum Event { CollaboratorsChanged, + Disconnected, } impl Entity for ChannelBuffer { type Event = Event; fn release(&mut self, _: &mut AppContext) { - self.client - .send(proto::LeaveChannelBuffer { - channel_id: self.channel_id, - }) - .log_err(); + if self.connected { + self.client + .send(proto::LeaveChannelBuffer { + channel_id: self.channel.id, + }) + .log_err(); + } } } impl ChannelBuffer { pub(crate) async fn new( - channel_store: ModelHandle, - channel_id: ChannelId, + channel: Arc, client: Arc, mut cx: AsyncAppContext, ) -> Result> { let response = client - .request(proto::JoinChannelBuffer { channel_id }) + .request(proto::JoinChannelBuffer { + channel_id: channel.id, + }) .await?; let base_text = response.base_text; @@ -62,7 +66,7 @@ impl ChannelBuffer { }); buffer.update(&mut cx, |buffer, cx| buffer.apply_ops(operations, cx))?; - let subscription = client.subscribe_to_entity(channel_id)?; + let subscription = client.subscribe_to_entity(channel.id)?; anyhow::Ok(cx.add_model(|cx| { cx.subscribe(&buffer, Self::on_buffer_update).detach(); @@ -70,10 +74,10 @@ impl ChannelBuffer { Self { buffer, client, - channel_id, - channel_store, + connected: true, collaborators, - _subscription: subscription.set_model(&cx.handle(), &mut cx.to_async()), + channel, + subscription: Some(subscription.set_model(&cx.handle(), &mut cx.to_async())), } })) } @@ -155,7 +159,7 @@ impl ChannelBuffer { let operation = language::proto::serialize_operation(operation); self.client .send(proto::UpdateChannelBuffer { - channel_id: self.channel_id, + channel_id: self.channel.id, operations: vec![operation], }) .log_err(); @@ -170,11 +174,21 @@ impl ChannelBuffer { &self.collaborators } - pub fn channel(&self, cx: &AppContext) -> Option> { - self.channel_store - .read(cx) - .channel_for_id(self.channel_id) - .cloned() + pub fn channel(&self) -> Arc { + self.channel.clone() + } + + pub(crate) fn disconnect(&mut self, cx: &mut ModelContext) { + if self.connected { + self.connected = false; + self.subscription.take(); + cx.emit(Event::Disconnected); + cx.notify() + } + } + + pub fn is_connected(&self) -> bool { + self.connected } pub fn replica_id(&self, cx: &AppContext) -> u16 { diff --git a/crates/channel/src/channel_store.rs b/crates/channel/src/channel_store.rs index 1d83bd1d7f..861f731331 100644 --- a/crates/channel/src/channel_store.rs +++ b/crates/channel/src/channel_store.rs @@ -2,7 +2,7 @@ use crate::channel_buffer::ChannelBuffer; use anyhow::{anyhow, Result}; use client::{Client, Status, Subscription, User, UserId, UserStore}; use collections::{hash_map, HashMap, HashSet}; -use futures::{channel::mpsc, future::Shared, Future, FutureExt, StreamExt, TryFutureExt}; +use futures::{channel::mpsc, future::Shared, Future, FutureExt, StreamExt}; use gpui::{AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle}; use rpc::{proto, TypedEnvelope}; use std::sync::Arc; @@ -71,16 +71,14 @@ impl ChannelStore { let mut connection_status = client.status(); let watch_connection_status = cx.spawn_weak(|this, mut cx| async move { while let Some(status) = connection_status.next().await { - if matches!(status, Status::ConnectionLost | Status::SignedOut) { + if !status.is_connected() { if let Some(this) = this.upgrade(&cx) { this.update(&mut cx, |this, cx| { - this.channels_by_id.clear(); - this.channel_invitations.clear(); - this.channel_participants.clear(); - this.channels_with_admin_privileges.clear(); - this.channel_paths.clear(); - this.outgoing_invites.clear(); - cx.notify(); + if matches!(status, Status::ConnectionLost | Status::SignedOut) { + this.handle_disconnect(cx); + } else { + this.disconnect_buffers(cx); + } }); } else { break; @@ -176,9 +174,17 @@ impl ChannelStore { OpenedChannelBuffer::Loading(task) => break task.clone(), }, hash_map::Entry::Vacant(e) => { + let client = self.client.clone(); let task = cx - .spawn(|this, cx| { - ChannelBuffer::new(this, channel_id, self.client.clone(), cx) + .spawn(|this, cx| async move { + let channel = this.read_with(&cx, |this, _| { + this.channel_for_id(channel_id).cloned().ok_or_else(|| { + Arc::new(anyhow!("no channel for id: {}", channel_id)) + }) + })?; + + ChannelBuffer::new(channel, client, cx) + .await .map_err(Arc::new) }) .shared(); @@ -187,8 +193,8 @@ impl ChannelStore { let task = task.clone(); |this, mut cx| async move { let result = task.await; - this.update(&mut cx, |this, cx| { - if let Ok(buffer) = result { + this.update(&mut cx, |this, cx| match result { + Ok(buffer) => { cx.observe_release(&buffer, move |this, _, _| { this.opened_buffers.remove(&channel_id); }) @@ -197,7 +203,9 @@ impl ChannelStore { channel_id, OpenedChannelBuffer::Open(buffer.downgrade()), ); - } else { + } + Err(error) => { + log::error!("failed to open channel buffer {error:?}"); this.opened_buffers.remove(&channel_id); } }); @@ -474,6 +482,27 @@ impl ChannelStore { Ok(()) } + fn handle_disconnect(&mut self, cx: &mut ModelContext<'_, ChannelStore>) { + self.disconnect_buffers(cx); + self.channels_by_id.clear(); + self.channel_invitations.clear(); + self.channel_participants.clear(); + self.channels_with_admin_privileges.clear(); + self.channel_paths.clear(); + self.outgoing_invites.clear(); + cx.notify(); + } + + fn disconnect_buffers(&mut self, cx: &mut ModelContext) { + for (_, buffer) in self.opened_buffers.drain() { + if let OpenedChannelBuffer::Open(buffer) = buffer { + if let Some(buffer) = buffer.upgrade(cx) { + buffer.update(cx, |buffer, cx| buffer.disconnect(cx)); + } + } + } + } + pub(crate) fn update_channels( &mut self, payload: proto::UpdateChannels, @@ -508,38 +537,44 @@ impl ChannelStore { .retain(|channel_id, _| !payload.remove_channels.contains(channel_id)); self.channels_with_admin_privileges .retain(|channel_id| !payload.remove_channels.contains(channel_id)); + + for channel_id in &payload.remove_channels { + let channel_id = *channel_id; + if let Some(OpenedChannelBuffer::Open(buffer)) = + self.opened_buffers.remove(&channel_id) + { + if let Some(buffer) = buffer.upgrade(cx) { + buffer.update(cx, ChannelBuffer::disconnect); + } + } + } } - for channel in payload.channels { - if let Some(existing_channel) = self.channels_by_id.get_mut(&channel.id) { - // FIXME: We may be missing a path for this existing channel in certain cases - let existing_channel = Arc::make_mut(existing_channel); - existing_channel.name = channel.name; - continue; - } + for channel_proto in payload.channels { + if let Some(existing_channel) = self.channels_by_id.get_mut(&channel_proto.id) { + Arc::make_mut(existing_channel).name = channel_proto.name; + } else { + let channel = Arc::new(Channel { + id: channel_proto.id, + name: channel_proto.name, + }); + self.channels_by_id.insert(channel.id, channel.clone()); - self.channels_by_id.insert( - channel.id, - Arc::new(Channel { - id: channel.id, - name: channel.name, - }), - ); - - if let Some(parent_id) = channel.parent_id { - let mut ix = 0; - while ix < self.channel_paths.len() { - let path = &self.channel_paths[ix]; - if path.ends_with(&[parent_id]) { - let mut new_path = path.clone(); - new_path.push(channel.id); - self.channel_paths.insert(ix + 1, new_path); + if let Some(parent_id) = channel_proto.parent_id { + let mut ix = 0; + while ix < self.channel_paths.len() { + let path = &self.channel_paths[ix]; + if path.ends_with(&[parent_id]) { + let mut new_path = path.clone(); + new_path.push(channel.id); + self.channel_paths.insert(ix + 1, new_path); + ix += 1; + } ix += 1; } - ix += 1; + } else { + self.channel_paths.push(vec![channel.id]); } - } else { - self.channel_paths.push(vec![channel.id]); } } diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 2bd39c861d..18587c2ba8 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -854,10 +854,13 @@ async fn connection_lost( .await .trace_err(); + leave_channel_buffers_for_session(&session) + .await + .trace_err(); + futures::select_biased! { _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => { leave_room_for_session(&session).await.trace_err(); - leave_channel_buffers_for_session(&session).await.trace_err(); if !session .connection_pool() diff --git a/crates/collab/src/tests/channel_buffer_tests.rs b/crates/collab/src/tests/channel_buffer_tests.rs index 0ecd4588c5..8ac4dbbd3f 100644 --- a/crates/collab/src/tests/channel_buffer_tests.rs +++ b/crates/collab/src/tests/channel_buffer_tests.rs @@ -1,5 +1,6 @@ use crate::{rpc::RECONNECT_TIMEOUT, tests::TestServer}; use call::ActiveCall; +use channel::Channel; use client::UserId; use collab_ui::channel_view::ChannelView; use collections::HashMap; @@ -334,6 +335,81 @@ async fn test_reopen_channel_buffer(deterministic: Arc, cx_a: &mu }); } +#[gpui::test] +async fn test_channel_buffer_disconnect( + deterministic: Arc, + cx_a: &mut TestAppContext, + cx_b: &mut TestAppContext, +) { + deterministic.forbid_parking(); + let mut server = TestServer::start(&deterministic).await; + let client_a = server.create_client(cx_a, "user_a").await; + let client_b = server.create_client(cx_b, "user_b").await; + + let channel_id = server + .make_channel("zed", (&client_a, cx_a), &mut [(&client_b, cx_b)]) + .await; + + let channel_buffer_a = client_a + .channel_store() + .update(cx_a, |channel, cx| { + channel.open_channel_buffer(channel_id, cx) + }) + .await + .unwrap(); + + let channel_buffer_b = client_b + .channel_store() + .update(cx_b, |channel, cx| { + channel.open_channel_buffer(channel_id, cx) + }) + .await + .unwrap(); + + server.forbid_connections(); + server.disconnect_client(client_a.peer_id().unwrap()); + deterministic.advance_clock(RECEIVE_TIMEOUT + RECONNECT_TIMEOUT); + + channel_buffer_a.update(cx_a, |buffer, _| { + assert_eq!( + buffer.channel().as_ref(), + &Channel { + id: channel_id, + name: "zed".to_string() + } + ); + assert!(!buffer.is_connected()); + }); + + deterministic.run_until_parked(); + + server.allow_connections(); + deterministic.advance_clock(RECEIVE_TIMEOUT + RECONNECT_TIMEOUT); + + deterministic.run_until_parked(); + + client_a + .channel_store() + .update(cx_a, |channel_store, _| { + channel_store.remove_channel(channel_id) + }) + .await + .unwrap(); + deterministic.run_until_parked(); + + // Channel buffer observed the deletion + channel_buffer_b.update(cx_b, |buffer, _| { + assert_eq!( + buffer.channel().as_ref(), + &Channel { + id: channel_id, + name: "zed".to_string() + } + ); + assert!(!buffer.is_connected()); + }); +} + #[track_caller] fn assert_collaborators(collaborators: &[proto::Collaborator], ids: &[Option]) { assert_eq!( diff --git a/crates/collab/src/tests/channel_tests.rs b/crates/collab/src/tests/channel_tests.rs index 41d2286772..b54b4d349b 100644 --- a/crates/collab/src/tests/channel_tests.rs +++ b/crates/collab/src/tests/channel_tests.rs @@ -799,7 +799,7 @@ async fn test_lost_channel_creation( deterministic.run_until_parked(); - // Sanity check + // Sanity check, B has the invitation assert_channel_invitations( client_b.channel_store(), cx_b, @@ -811,6 +811,7 @@ async fn test_lost_channel_creation( }], ); + // A creates a subchannel while the invite is still pending. let subchannel_id = client_a .channel_store() .update(cx_a, |channel_store, cx| { @@ -841,7 +842,7 @@ async fn test_lost_channel_creation( ], ); - // Accept the invite + // Client B accepts the invite client_b .channel_store() .update(cx_b, |channel_store, _| { @@ -852,7 +853,7 @@ async fn test_lost_channel_creation( deterministic.run_until_parked(); - // B should now see the channel + // Client B should now see the channel assert_channels( client_b.channel_store(), cx_b, diff --git a/crates/collab_ui/src/channel_view.rs b/crates/collab_ui/src/channel_view.rs index 0e2d3636aa..9c125117e1 100644 --- a/crates/collab_ui/src/channel_view.rs +++ b/crates/collab_ui/src/channel_view.rs @@ -114,10 +114,18 @@ impl ChannelView { fn handle_channel_buffer_event( &mut self, _: ModelHandle, - _: &channel_buffer::Event, + event: &channel_buffer::Event, cx: &mut ViewContext, ) { - self.refresh_replica_id_map(cx); + match event { + channel_buffer::Event::CollaboratorsChanged => { + self.refresh_replica_id_map(cx); + } + channel_buffer::Event::Disconnected => self.editor.update(cx, |editor, cx| { + editor.set_read_only(true); + cx.notify(); + }), + } } /// Build a mapping of channel buffer replica ids to the corresponding @@ -183,14 +191,13 @@ impl Item for ChannelView { style: &theme::Tab, cx: &gpui::AppContext, ) -> AnyElement { - let channel_name = self - .channel_buffer - .read(cx) - .channel(cx) - .map_or("[Deleted channel]".to_string(), |channel| { - format!("#{}", channel.name) - }); - Label::new(channel_name, style.label.to_owned()).into_any() + let channel_name = &self.channel_buffer.read(cx).channel().name; + let label = if self.channel_buffer.read(cx).is_connected() { + format!("#{}", channel_name) + } else { + format!("#{} (disconnected)", channel_name) + }; + Label::new(label, style.label.to_owned()).into_any() } fn clone_on_split(&self, _: WorkspaceId, cx: &mut ViewContext) -> Option { @@ -208,8 +215,9 @@ impl FollowableItem for ChannelView { } fn to_state_proto(&self, cx: &AppContext) -> Option { - self.channel_buffer.read(cx).channel(cx).map(|channel| { - proto::view::Variant::ChannelView(proto::view::ChannelView { + let channel = self.channel_buffer.read(cx).channel(); + Some(proto::view::Variant::ChannelView( + proto::view::ChannelView { channel_id: channel.id, editor: if let Some(proto::view::Variant::Editor(proto)) = self.editor.read(cx).to_state_proto(cx) @@ -218,8 +226,8 @@ impl FollowableItem for ChannelView { } else { None }, - }) - }) + }, + )) } fn from_state_proto(