diff --git a/gpui/src/app.rs b/gpui/src/app.rs index a2fa818824..221419341e 100644 --- a/gpui/src/app.rs +++ b/gpui/src/app.rs @@ -14,7 +14,7 @@ use keymap::MatchResult; use parking_lot::{Mutex, RwLock}; use pathfinder_geometry::{rect::RectF, vector::vec2f}; use platform::Event; -use postage::{mpsc, sink::Sink as _, stream::Stream as _}; +use postage::{mpsc, oneshot, sink::Sink as _, stream::Stream as _}; use smol::prelude::*; use std::{ any::{type_name, Any, TypeId}, @@ -2310,6 +2310,24 @@ impl ModelHandle { cx.update_model(self, update) } + pub fn next_notification(&self, cx: &TestAppContext) -> impl Future { + let (tx, mut rx) = oneshot::channel(); + let mut tx = Some(tx); + + let mut cx = cx.cx.borrow_mut(); + self.update(&mut *cx, |_, cx| { + cx.observe(self, move |_, _, _| { + if let Some(mut tx) = tx.take() { + tx.blocking_send(()).ok(); + } + }); + }); + + async move { + rx.recv().await; + } + } + pub fn condition( &self, cx: &TestAppContext, diff --git a/server/src/tests.rs b/server/src/tests.rs index 607929327f..dd899e9fed 100644 --- a/server/src/tests.rs +++ b/server/src/tests.rs @@ -480,8 +480,6 @@ async fn test_peer_disconnection(mut cx_a: TestAppContext, cx_b: TestAppContext) #[gpui::test] async fn test_basic_chat(mut cx_a: TestAppContext, cx_b: TestAppContext) { - let lang_registry = Arc::new(LanguageRegistry::new()); - // Connect to a server as 2 clients. let mut server = TestServer::start().await; let (user_id_a, client_a) = server.create_client(&mut cx_a, "user_a").await; @@ -531,8 +529,14 @@ async fn test_basic_chat(mut cx_a: TestAppContext, cx_b: TestAppContext) { ) }); - let channel_a = channels_a.read_with(&cx_a, |this, cx| { - this.get_channel(channel_id.to_proto(), &cx).unwrap() + let channel_a = channels_a.update(&mut cx_a, |this, cx| { + this.get_channel(channel_id.to_proto(), cx).unwrap() + }); + + channel_a.read_with(&cx_a, |channel, _| assert!(channel.messages().is_none())); + channel_a.next_notification(&cx_a).await; + channel_a.read_with(&cx_a, |channel, _| { + assert_eq!(channel.messages().unwrap().len(), 1); }); } diff --git a/zed/src/channel.rs b/zed/src/channel.rs index 1c7d576c65..3161592f4d 100644 --- a/zed/src/channel.rs +++ b/zed/src/channel.rs @@ -1,8 +1,11 @@ use crate::rpc::{self, Client}; use anyhow::{Context, Result}; -use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, WeakModelHandle}; +use gpui::{ + executor, AsyncAppContext, Entity, ModelContext, ModelHandle, MutableAppContext, + WeakModelHandle, +}; use std::{ - collections::{HashMap, VecDeque}, + collections::{hash_map, HashMap, VecDeque}, sync::Arc, }; use zrpc::{ @@ -16,7 +19,7 @@ pub struct ChannelList { rpc: Arc, } -#[derive(Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq)] pub struct ChannelDetails { pub id: u64, pub name: String, @@ -28,6 +31,7 @@ pub struct Channel { messages: Option>, rpc: Arc, _subscription: rpc::Subscription, + background: Arc, } pub struct ChannelMessage { @@ -57,11 +61,28 @@ impl ChannelList { &self.available_channels } - pub fn get_channel(&self, id: u64, cx: &AppContext) -> Option> { - self.channels - .get(&id) - .cloned() - .and_then(|handle| handle.upgrade(cx)) + pub fn get_channel( + &mut self, + id: u64, + cx: &mut MutableAppContext, + ) -> Option> { + match self.channels.entry(id) { + hash_map::Entry::Occupied(entry) => entry.get().upgrade(cx), + hash_map::Entry::Vacant(entry) => { + if let Some(details) = self + .available_channels + .iter() + .find(|details| details.id == id) + { + let rpc = self.rpc.clone(); + let channel = cx.add_model(|cx| Channel::new(details.clone(), rpc, cx)); + entry.insert(channel.downgrade()); + Some(channel) + } else { + None + } + } + } } } @@ -73,12 +94,31 @@ impl Channel { pub fn new(details: ChannelDetails, rpc: Arc, cx: &mut ModelContext) -> Self { let _subscription = rpc.subscribe_from_model(details.id, cx, Self::handle_message_sent); + { + let rpc = rpc.clone(); + let channel_id = details.id; + cx.spawn(|channel, mut cx| async move { + match rpc.request(proto::JoinChannel { channel_id }).await { + Ok(response) => { + let messages = response.messages.into_iter().map(Into::into).collect(); + channel.update(&mut cx, |channel, cx| { + channel.messages = Some(messages); + cx.notify(); + }) + } + Err(error) => log::error!("error joining channel: {}", error), + } + }) + .detach(); + } + Self { details, rpc, first_message_id: None, messages: None, _subscription, + background: cx.background().clone(), } } @@ -90,6 +130,25 @@ impl Channel { ) -> Result<()> { Ok(()) } + + pub fn messages(&self) -> Option<&VecDeque> { + self.messages.as_ref() + } +} + +// TODO: Implement the server side of leaving a channel +impl Drop for Channel { + fn drop(&mut self) { + let rpc = self.rpc.clone(); + let channel_id = self.details.id; + self.background + .spawn(async move { + if let Err(error) = rpc.send(proto::LeaveChannel { channel_id }).await { + log::error!("error leaving channel: {}", error); + }; + }) + .detach() + } } impl From for ChannelDetails { @@ -100,3 +159,9 @@ impl From for ChannelDetails { } } } + +impl From for ChannelMessage { + fn from(message: proto::ChannelMessage) -> Self { + ChannelMessage { id: message.id } + } +} diff --git a/zrpc/proto/zed.proto b/zrpc/proto/zed.proto index 3a0b7aabb6..cbced82e7b 100644 --- a/zrpc/proto/zed.proto +++ b/zrpc/proto/zed.proto @@ -30,8 +30,9 @@ message Envelope { GetUsersResponse get_users_response = 25; JoinChannel join_channel = 26; JoinChannelResponse join_channel_response = 27; - SendChannelMessage send_channel_message = 28; - ChannelMessageSent channel_message_sent = 29; + LeaveChannel leave_channel = 28; + SendChannelMessage send_channel_message = 29; + ChannelMessageSent channel_message_sent = 30; } } @@ -141,6 +142,10 @@ message JoinChannelResponse { repeated ChannelMessage messages = 1; } +message LeaveChannel { + uint64 channel_id = 1; +} + message GetUsers { repeated uint64 user_ids = 1; } diff --git a/zrpc/src/proto.rs b/zrpc/src/proto.rs index 271fa8e29e..173e802f00 100644 --- a/zrpc/src/proto.rs +++ b/zrpc/src/proto.rs @@ -138,6 +138,7 @@ messages!( GetUsersResponse, JoinChannel, JoinChannelResponse, + LeaveChannel, OpenBuffer, OpenBufferResponse, OpenWorktree,