From 5888e7b214685aa1d8dd24e657d84c7b015aa08f Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Thu, 24 Aug 2023 13:40:44 -0700 Subject: [PATCH] Dedup channel buffers --- crates/channel/src/channel_buffer.rs | 62 +++++++------- crates/channel/src/channel_store.rs | 81 +++++++++++++++---- .../collab/src/tests/channel_buffer_tests.rs | 56 +++++++++++++ 3 files changed, 152 insertions(+), 47 deletions(-) diff --git a/crates/channel/src/channel_buffer.rs b/crates/channel/src/channel_buffer.rs index cad3c4f58f..c19899501a 100644 --- a/crates/channel/src/channel_buffer.rs +++ b/crates/channel/src/channel_buffer.rs @@ -1,7 +1,7 @@ use crate::{Channel, ChannelId, ChannelStore}; use anyhow::Result; use client::Client; -use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task}; +use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle}; use rpc::{proto, TypedEnvelope}; use std::sync::Arc; use util::ResultExt; @@ -38,46 +38,44 @@ impl Entity for ChannelBuffer { } impl ChannelBuffer { - pub(crate) fn new( + pub(crate) async fn new( channel_store: ModelHandle, channel_id: ChannelId, client: Arc, - cx: &mut AppContext, - ) -> Task>> { - cx.spawn(|mut cx| async move { - let response = client - .request(proto::JoinChannelBuffer { channel_id }) - .await?; + mut cx: AsyncAppContext, + ) -> Result> { + let response = client + .request(proto::JoinChannelBuffer { channel_id }) + .await?; - let base_text = response.base_text; - let operations = response - .operations - .into_iter() - .map(language::proto::deserialize_operation) - .collect::, _>>()?; + let base_text = response.base_text; + let operations = response + .operations + .into_iter() + .map(language::proto::deserialize_operation) + .collect::, _>>()?; - let collaborators = response.collaborators; + let collaborators = response.collaborators; - let buffer = cx.add_model(|_| { - language::Buffer::remote(response.buffer_id, response.replica_id as u16, base_text) - }); - buffer.update(&mut cx, |buffer, cx| buffer.apply_ops(operations, cx))?; + let buffer = cx.add_model(|_| { + language::Buffer::remote(response.buffer_id, response.replica_id as u16, base_text) + }); + buffer.update(&mut cx, |buffer, cx| buffer.apply_ops(operations, cx))?; - let subscription = client.subscribe_to_entity(channel_id)?; + let subscription = client.subscribe_to_entity(channel_id)?; - anyhow::Ok(cx.add_model(|cx| { - cx.subscribe(&buffer, Self::on_buffer_update).detach(); + anyhow::Ok(cx.add_model(|cx| { + cx.subscribe(&buffer, Self::on_buffer_update).detach(); - Self { - buffer, - client, - channel_id, - channel_store, - collaborators, - _subscription: subscription.set_model(&cx.handle(), &mut cx.to_async()), - } - })) - }) + Self { + buffer, + client, + channel_id, + channel_store, + collaborators, + _subscription: subscription.set_model(&cx.handle(), &mut cx.to_async()), + } + })) } async fn handle_update_channel_buffer( diff --git a/crates/channel/src/channel_store.rs b/crates/channel/src/channel_store.rs index a6aad19d03..1d83bd1d7f 100644 --- a/crates/channel/src/channel_store.rs +++ b/crates/channel/src/channel_store.rs @@ -1,20 +1,13 @@ -use anyhow::anyhow; -use anyhow::Result; -use client::Status; -use client::UserId; -use client::{Client, Subscription, User, UserStore}; -use collections::HashMap; -use collections::HashSet; -use futures::channel::mpsc; -use futures::Future; -use futures::StreamExt; -use gpui::{AsyncAppContext, Entity, ModelContext, ModelHandle, Task}; +use crate::channel_buffer::ChannelBuffer; +use anyhow::{anyhow, Result}; +use client::{Client, Status, Subscription, User, UserId, UserStore}; +use collections::{hash_map, HashMap, HashSet}; +use futures::{channel::mpsc, future::Shared, Future, FutureExt, StreamExt, TryFutureExt}; +use gpui::{AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle}; use rpc::{proto, TypedEnvelope}; use std::sync::Arc; use util::ResultExt; -use crate::channel_buffer::ChannelBuffer; - pub type ChannelId = u64; pub struct ChannelStore { @@ -25,6 +18,7 @@ pub struct ChannelStore { channels_with_admin_privileges: HashSet, outgoing_invites: HashSet<(ChannelId, UserId)>, update_channels_tx: mpsc::UnboundedSender, + opened_buffers: HashMap, client: Arc, user_store: ModelHandle, _rpc_subscription: Subscription, @@ -59,6 +53,11 @@ pub enum ChannelMemberStatus { NotMember, } +enum OpenedChannelBuffer { + Open(WeakModelHandle), + Loading(Shared, Arc>>>), +} + impl ChannelStore { pub fn new( client: Arc, @@ -89,6 +88,7 @@ impl ChannelStore { } } }); + Self { channels_by_id: HashMap::default(), channel_invitations: Vec::default(), @@ -96,6 +96,7 @@ impl ChannelStore { channel_participants: Default::default(), channels_with_admin_privileges: Default::default(), outgoing_invites: Default::default(), + opened_buffers: Default::default(), update_channels_tx, client, user_store, @@ -154,11 +155,61 @@ impl ChannelStore { } pub fn open_channel_buffer( - &self, + &mut self, channel_id: ChannelId, cx: &mut ModelContext, ) -> Task>> { - ChannelBuffer::new(cx.handle(), channel_id, self.client.clone(), cx) + // Make sure that a given channel buffer is only opened once per + // app instance, even if this method is called multiple times + // with the same channel id while the first task is still running. + let task = loop { + match self.opened_buffers.entry(channel_id) { + hash_map::Entry::Occupied(e) => match e.get() { + OpenedChannelBuffer::Open(buffer) => { + if let Some(buffer) = buffer.upgrade(cx) { + break Task::ready(Ok(buffer)).shared(); + } else { + self.opened_buffers.remove(&channel_id); + continue; + } + } + OpenedChannelBuffer::Loading(task) => break task.clone(), + }, + hash_map::Entry::Vacant(e) => { + let task = cx + .spawn(|this, cx| { + ChannelBuffer::new(this, channel_id, self.client.clone(), cx) + .map_err(Arc::new) + }) + .shared(); + e.insert(OpenedChannelBuffer::Loading(task.clone())); + cx.spawn({ + let task = task.clone(); + |this, mut cx| async move { + let result = task.await; + this.update(&mut cx, |this, cx| { + if let Ok(buffer) = result { + cx.observe_release(&buffer, move |this, _, _| { + this.opened_buffers.remove(&channel_id); + }) + .detach(); + this.opened_buffers.insert( + channel_id, + OpenedChannelBuffer::Open(buffer.downgrade()), + ); + } else { + this.opened_buffers.remove(&channel_id); + } + }); + } + }) + .detach(); + break task; + } + } + }; + cx.foreground() + .spawn(async move { task.await.map_err(|error| anyhow!("{}", error)) }) } pub fn is_user_admin(&self, channel_id: ChannelId) -> bool { diff --git a/crates/collab/src/tests/channel_buffer_tests.rs b/crates/collab/src/tests/channel_buffer_tests.rs index 6a9ef3fc13..f7e5751a37 100644 --- a/crates/collab/src/tests/channel_buffer_tests.rs +++ b/crates/collab/src/tests/channel_buffer_tests.rs @@ -3,6 +3,7 @@ use call::ActiveCall; use client::UserId; use collab_ui::channel_view::ChannelView; use collections::HashMap; +use futures::future; use gpui::{executor::Deterministic, ModelHandle, TestAppContext}; use rpc::{proto, RECEIVE_TIMEOUT}; use serde_json::json; @@ -283,6 +284,61 @@ async fn test_channel_buffer_replica_ids( }); } +#[gpui::test] +async fn test_reopen_channel_buffer(deterministic: Arc, cx_a: &mut TestAppContext) { + deterministic.forbid_parking(); + let mut server = TestServer::start(&deterministic).await; + let client_a = server.create_client(cx_a, "user_a").await; + + let zed_id = server.make_channel("zed", (&client_a, cx_a), &mut []).await; + + let channel_buffer_1 = client_a + .channel_store() + .update(cx_a, |channel, cx| channel.open_channel_buffer(zed_id, cx)); + let channel_buffer_2 = client_a + .channel_store() + .update(cx_a, |channel, cx| channel.open_channel_buffer(zed_id, cx)); + let channel_buffer_3 = client_a + .channel_store() + .update(cx_a, |channel, cx| channel.open_channel_buffer(zed_id, cx)); + + // All concurrent tasks for opening a channel buffer return the same model handle. + let (channel_buffer_1, channel_buffer_2, channel_buffer_3) = + future::try_join3(channel_buffer_1, channel_buffer_2, channel_buffer_3) + .await + .unwrap(); + let model_id = channel_buffer_1.id(); + assert_eq!(channel_buffer_1, channel_buffer_2); + assert_eq!(channel_buffer_1, channel_buffer_3); + + channel_buffer_1.update(cx_a, |buffer, cx| { + buffer.buffer().update(cx, |buffer, cx| { + buffer.edit([(0..0, "hello")], None, cx); + }) + }); + deterministic.run_until_parked(); + + cx_a.update(|_| { + drop(channel_buffer_1); + drop(channel_buffer_2); + drop(channel_buffer_3); + }); + deterministic.run_until_parked(); + + // The channel buffer can be reopened after dropping it. + let channel_buffer = client_a + .channel_store() + .update(cx_a, |channel, cx| channel.open_channel_buffer(zed_id, cx)) + .await + .unwrap(); + assert_ne!(channel_buffer.id(), model_id); + channel_buffer.update(cx_a, |buffer, cx| { + buffer.buffer().update(cx, |buffer, _| { + assert_eq!(buffer.text(), "hello"); + }) + }); +} + #[track_caller] fn assert_collaborators(collaborators: &[proto::Collaborator], ids: &[Option]) { assert_eq!(