diff --git a/server/src/rpc.rs b/server/src/rpc.rs index d107e3606b..e5086d870c 100644 --- a/server/src/rpc.rs +++ b/server/src/rpc.rs @@ -619,12 +619,14 @@ impl Server { .app_state .db .create_channel_message(channel_id, user_id, &request.payload.body, timestamp) - .await?; + .await? + .to_proto(); + let receipt = request.receipt(); let message = proto::ChannelMessageSent { channel_id: channel_id.to_proto(), message: Some(proto::ChannelMessage { sender_id: user_id.to_proto(), - id: message_id.to_proto(), + id: message_id, body: request.payload.body, timestamp: timestamp.unix_timestamp() as u64, }), @@ -633,7 +635,15 @@ impl Server { self.peer.send(conn_id, message.clone()) }) .await?; - + self.peer + .respond( + receipt, + proto::SendChannelMessageResponse { + message_id, + timestamp: timestamp.unix_timestamp() as u64, + }, + ) + .await?; Ok(()) } diff --git a/server/src/tests.rs b/server/src/tests.rs index dd899e9fed..b301d7b2d6 100644 --- a/server/src/tests.rs +++ b/server/src/tests.rs @@ -485,13 +485,11 @@ async fn test_basic_chat(mut cx_a: TestAppContext, cx_b: TestAppContext) { let (user_id_a, client_a) = server.create_client(&mut cx_a, "user_a").await; let (user_id_b, client_b) = server.create_client(&mut cx_a, "user_b").await; - // Create an org that includes these 2 users and 1 other user. + // Create an org that includes these 2 users. let db = &server.app_state.db; - let user_id_c = db.create_user("user_c", false).await.unwrap(); let org_id = db.create_org("Test Org", "test-org").await.unwrap(); db.add_org_member(org_id, user_id_a, false).await.unwrap(); db.add_org_member(org_id, user_id_b, false).await.unwrap(); - db.add_org_member(org_id, user_id_c, false).await.unwrap(); // Create a channel that includes all the users. let channel_id = db.create_org_channel(org_id, "test-channel").await.unwrap(); @@ -501,13 +499,10 @@ async fn test_basic_chat(mut cx_a: TestAppContext, cx_b: TestAppContext) { db.add_channel_member(channel_id, user_id_b, false) .await .unwrap(); - db.add_channel_member(channel_id, user_id_c, false) - .await - .unwrap(); db.create_channel_message( channel_id, - user_id_c, - "first message!", + user_id_b, + "hello A, it's B.", OffsetDateTime::now_utc(), ) .await @@ -516,9 +511,6 @@ async fn test_basic_chat(mut cx_a: TestAppContext, cx_b: TestAppContext) { let channels_a = ChannelList::new(client_a, &mut cx_a.to_async()) .await .unwrap(); - let channels_b = ChannelList::new(client_b, &mut cx_b.to_async()) - .await - .unwrap(); channels_a.read_with(&cx_a, |list, _| { assert_eq!( list.available_channels(), @@ -532,12 +524,33 @@ async fn test_basic_chat(mut cx_a: TestAppContext, cx_b: TestAppContext) { let channel_a = channels_a.update(&mut cx_a, |this, cx| { this.get_channel(channel_id.to_proto(), cx).unwrap() }); - - channel_a.read_with(&cx_a, |channel, _| assert!(channel.messages().is_none())); + channel_a.read_with(&cx_a, |channel, _| assert!(channel.messages().is_empty())); channel_a.next_notification(&cx_a).await; channel_a.read_with(&cx_a, |channel, _| { - assert_eq!(channel.messages().unwrap().len(), 1); + assert_eq!( + channel + .messages() + .iter() + .map(|m| (m.sender_id, m.body.as_ref())) + .collect::>(), + &[(user_id_b.to_proto(), "hello A, it's B.")] + ); }); + + channel_a.update(&mut cx_a, |channel, cx| { + channel.send_message("oh, hi B.".to_string(), cx).unwrap(); + channel.send_message("sup".to_string(), cx).unwrap(); + assert_eq!( + channel + .pending_messages() + .iter() + .map(|m| &m.body) + .collect::>(), + &["oh, hi B.", "sup"] + ) + }); + + channel_a.next_notification(&cx_a).await; } struct TestServer { @@ -577,10 +590,9 @@ impl TestServer { ) .detach(); client - .add_connection(client_conn, cx.to_async()) + .add_connection(user_id.to_proto(), client_conn, cx.to_async()) .await .unwrap(); - (user_id, client) } diff --git a/zed/src/channel.rs b/zed/src/channel.rs index d3fc64a505..cbbdec6015 100644 --- a/zed/src/channel.rs +++ b/zed/src/channel.rs @@ -1,5 +1,8 @@ -use crate::rpc::{self, Client}; -use anyhow::{Context, Result}; +use crate::{ + rpc::{self, Client}, + util::log_async_errors, +}; +use anyhow::{anyhow, Context, Result}; use gpui::{ AsyncAppContext, Entity, ModelContext, ModelHandle, MutableAppContext, WeakModelHandle, }; @@ -27,14 +30,24 @@ pub struct ChannelDetails { pub struct Channel { details: ChannelDetails, first_message_id: Option, - messages: Option>, + messages: VecDeque, + pending_messages: Vec, + next_local_message_id: u64, rpc: Arc, _subscription: rpc::Subscription, } pub struct ChannelMessage { - id: u64, + pub id: u64, + pub sender_id: u64, + pub body: String, } + +pub struct PendingChannelMessage { + pub body: String, + local_id: u64, +} + pub enum Event {} impl Entity for ChannelList { @@ -110,13 +123,10 @@ impl Channel { let channel_id = details.id; cx.spawn(|channel, mut cx| async move { match rpc.request(proto::JoinChannel { channel_id }).await { - Ok(response) => { - let messages = response.messages.into_iter().map(Into::into).collect(); - channel.update(&mut cx, |channel, cx| { - channel.messages = Some(messages); - cx.notify(); - }) - } + Ok(response) => channel.update(&mut cx, |channel, cx| { + channel.messages = response.messages.into_iter().map(Into::into).collect(); + cx.notify(); + }), Err(error) => log::error!("error joining channel: {}", error), } }) @@ -127,14 +137,54 @@ impl Channel { details, rpc, first_message_id: None, - messages: None, + messages: Default::default(), + pending_messages: Default::default(), + next_local_message_id: 0, _subscription, } } + pub fn send_message(&mut self, body: String, cx: &mut ModelContext) -> Result<()> { + let channel_id = self.details.id; + let current_user_id = self.rpc.user_id().ok_or_else(|| anyhow!("not logged in"))?; + let local_id = self.next_local_message_id; + self.next_local_message_id += 1; + self.pending_messages.push(PendingChannelMessage { + local_id, + body: body.clone(), + }); + let rpc = self.rpc.clone(); + cx.spawn(|this, mut cx| { + log_async_errors(async move { + let request = rpc.request(proto::SendChannelMessage { channel_id, body }); + let response = request.await?; + this.update(&mut cx, |this, cx| { + if let Ok(i) = this + .pending_messages + .binary_search_by_key(&local_id, |msg| msg.local_id) + { + let body = this.pending_messages.remove(i).body; + this.messages.push_back(ChannelMessage { + id: response.message_id, + sender_id: current_user_id, + body, + }); + cx.notify(); + } + }); + Ok(()) + }) + }) + .detach(); + Ok(()) + } - pub fn messages(&self) -> Option<&VecDeque> { - self.messages.as_ref() + pub fn messages(&self) -> &VecDeque { + &self.messages + } + + pub fn pending_messages(&self) -> &[PendingChannelMessage] { + &self.pending_messages } fn handle_message_sent( @@ -158,6 +208,10 @@ impl From for ChannelDetails { impl From for ChannelMessage { fn from(message: proto::ChannelMessage) -> Self { - ChannelMessage { id: message.id } + ChannelMessage { + id: message.id, + sender_id: message.sender_id, + body: message.body, + } } } diff --git a/zed/src/rpc.rs b/zed/src/rpc.rs index a65551a7ff..b8909a988a 100644 --- a/zed/src/rpc.rs +++ b/zed/src/rpc.rs @@ -31,6 +31,7 @@ pub struct Client { #[derive(Default)] struct ClientState { connection_id: Option, + user_id: Option, entity_id_extractors: HashMap u64>>, model_handlers: HashMap< (TypeId, u64), @@ -66,6 +67,10 @@ impl Client { }) } + pub fn user_id(&self) -> Option { + self.state.read().user_id + } + pub fn subscribe_from_model( self: &Arc, remote_id: u64, @@ -125,7 +130,7 @@ impl Client { } let (user_id, access_token) = Self::login(cx.platform(), &cx.background()).await?; - let user_id: i32 = user_id.parse()?; + let user_id = user_id.parse::()?; let request = Request::builder().header("Authorization", format!("{} {}", user_id, access_token)); @@ -135,23 +140,25 @@ impl Client { let (stream, _) = async_tungstenite::async_tls::client_async_tls(request, stream) .await .context("websocket handshake")?; - log::info!("connected to rpc address {}", *ZED_SERVER_URL); - self.add_connection(stream, cx).await?; + self.add_connection(user_id, stream, cx).await?; } else if let Some(host) = ZED_SERVER_URL.strip_prefix("http://") { let stream = smol::net::TcpStream::connect(host).await?; let request = request.uri(format!("ws://{}/rpc", host)).body(())?; - let (stream, _) = async_tungstenite::client_async(request, stream).await?; - log::info!("connected to rpc address {}", *ZED_SERVER_URL); - self.add_connection(stream, cx).await?; + let (stream, _) = async_tungstenite::client_async(request, stream) + .await + .context("websocket handshake")?; + self.add_connection(user_id, stream, cx).await?; } else { return Err(anyhow!("invalid server url: {}", *ZED_SERVER_URL))?; }; + log::info!("connected to rpc address {}", *ZED_SERVER_URL); Ok(()) } pub async fn add_connection( self: &Arc, + user_id: u64, conn: Conn, cx: AsyncAppContext, ) -> surf::Result<()> @@ -202,7 +209,9 @@ impl Client { } }) .detach(); - self.state.write().connection_id = Some(connection_id); + let mut state = self.state.write(); + state.connection_id = Some(connection_id); + state.user_id = Some(user_id); Ok(()) } diff --git a/zrpc/proto/zed.proto b/zrpc/proto/zed.proto index 48b481c5f4..f368cb2d47 100644 --- a/zrpc/proto/zed.proto +++ b/zrpc/proto/zed.proto @@ -30,7 +30,8 @@ message Envelope { JoinChannelResponse join_channel_response = 25; LeaveChannel leave_channel = 26; SendChannelMessage send_channel_message = 27; - ChannelMessageSent channel_message_sent = 28; + SendChannelMessageResponse send_channel_message_response = 28; + ChannelMessageSent channel_message_sent = 29; } } @@ -148,6 +149,11 @@ message SendChannelMessage { string body = 2; } +message SendChannelMessageResponse { + uint64 message_id = 1; + uint64 timestamp = 2; +} + message ChannelMessageSent { uint64 channel_id = 1; ChannelMessage message = 2;