From b8aba0972d19b09b24e5ee2e209649df5c533151 Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Mon, 9 May 2022 17:23:39 -0700 Subject: [PATCH] Wait until contacts have been cleared when disconnecting Also, use an mpsc for UpdateContacts messages, not a watch, since the messages now represent changes instead of snapshots. Co-authored-by: Nathan Sobo --- crates/client/src/user.rs | 237 +++++++++++++++++++++----------------- crates/collab/src/db.rs | 104 +++++++++-------- crates/collab/src/rpc.rs | 10 +- 3 files changed, 193 insertions(+), 158 deletions(-) diff --git a/crates/client/src/user.rs b/crates/client/src/user.rs index b4743d3567..a8de7a082d 100644 --- a/crates/client/src/user.rs +++ b/crates/client/src/user.rs @@ -1,6 +1,6 @@ use super::{http::HttpClient, proto, Client, Status, TypedEnvelope}; use anyhow::{anyhow, Context, Result}; -use futures::{future, AsyncReadExt}; +use futures::{channel::mpsc, future, AsyncReadExt, Future, StreamExt}; use gpui::{AsyncAppContext, Entity, ImageData, ModelContext, ModelHandle, Task}; use postage::{prelude::Stream, sink::Sink, watch}; use rpc::proto::{RequestMessage, UsersResponse}; @@ -42,7 +42,7 @@ pub enum ContactRequestStatus { pub struct UserStore { users: HashMap>, - update_contacts_tx: watch::Sender>, + update_contacts_tx: mpsc::UnboundedSender, current_user: watch::Receiver>>, contacts: Vec>, incoming_contact_requests: Vec>, @@ -60,6 +60,11 @@ impl Entity for UserStore { type Event = Event; } +enum UpdateContacts { + Update(proto::UpdateContacts), + Clear(postage::barrier::Sender), +} + impl UserStore { pub fn new( client: Arc, @@ -67,8 +72,7 @@ impl UserStore { cx: &mut ModelContext, ) -> Self { let (mut current_user_tx, current_user_rx) = watch::channel(); - let (update_contacts_tx, mut update_contacts_rx) = - watch::channel::>(); + let (update_contacts_tx, mut update_contacts_rx) = mpsc::unbounded(); let rpc_subscription = client.add_message_handler(cx.handle(), Self::handle_update_contacts); Self { @@ -82,8 +86,8 @@ impl UserStore { http, _maintain_contacts: cx.spawn_weak(|this, mut cx| async move { let _subscription = rpc_subscription; - while let Some(message) = update_contacts_rx.recv().await { - if let Some((message, this)) = message.zip(this.upgrade(&cx)) { + while let Some(message) = update_contacts_rx.next().await { + if let Some(this) = this.upgrade(&cx) { this.update(&mut cx, |this, cx| this.update_contacts(message, cx)) .log_err() .await; @@ -121,114 +125,130 @@ impl UserStore { mut cx: AsyncAppContext, ) -> Result<()> { this.update(&mut cx, |this, _| { - *this.update_contacts_tx.borrow_mut() = Some(msg.payload); + this.update_contacts_tx + .unbounded_send(UpdateContacts::Update(msg.payload)) + .unwrap(); }); Ok(()) } fn update_contacts( &mut self, - message: proto::UpdateContacts, + message: UpdateContacts, cx: &mut ModelContext, ) -> Task> { - log::info!("update contacts on client {:?}", message); - let mut user_ids = HashSet::new(); - for contact in &message.contacts { - user_ids.insert(contact.user_id); - user_ids.extend(contact.projects.iter().flat_map(|w| &w.guests).copied()); - } - user_ids.extend(message.incoming_requests.iter().map(|req| req.requester_id)); - user_ids.extend(message.outgoing_requests.iter()); - - let load_users = self.get_users(user_ids.into_iter().collect(), cx); - cx.spawn(|this, mut cx| async move { - load_users.await?; - - // Users are fetched in parallel above and cached in call to get_users - // No need to paralellize here - let mut updated_contacts = Vec::new(); - for contact in message.contacts { - updated_contacts.push(Arc::new( - Contact::from_proto(contact, &this, &mut cx).await?, - )); + match message { + UpdateContacts::Clear(barrier) => { + self.contacts.clear(); + self.incoming_contact_requests.clear(); + self.outgoing_contact_requests.clear(); + drop(barrier); + Task::ready(Ok(())) } + UpdateContacts::Update(message) => { + log::info!( + "update contacts on client {}: {:?}", + self.client.upgrade().unwrap().id, + message + ); + let mut user_ids = HashSet::new(); + for contact in &message.contacts { + user_ids.insert(contact.user_id); + user_ids.extend(contact.projects.iter().flat_map(|w| &w.guests).copied()); + } + user_ids.extend(message.incoming_requests.iter().map(|req| req.requester_id)); + user_ids.extend(message.outgoing_requests.iter()); + + let load_users = self.get_users(user_ids.into_iter().collect(), cx); + cx.spawn(|this, mut cx| async move { + load_users.await?; + + // Users are fetched in parallel above and cached in call to get_users + // No need to paralellize here + let mut updated_contacts = Vec::new(); + for contact in message.contacts { + updated_contacts.push(Arc::new( + Contact::from_proto(contact, &this, &mut cx).await?, + )); + } + + let mut incoming_requests = Vec::new(); + for request in message.incoming_requests { + incoming_requests.push( + this.update(&mut cx, |this, cx| { + this.fetch_user(request.requester_id, cx) + }) + .await?, + ); + } + + let mut outgoing_requests = Vec::new(); + for requested_user_id in message.outgoing_requests { + outgoing_requests.push( + this.update(&mut cx, |this, cx| this.fetch_user(requested_user_id, cx)) + .await?, + ); + } + + let removed_contacts = + HashSet::::from_iter(message.remove_contacts.iter().copied()); + let removed_incoming_requests = + HashSet::::from_iter(message.remove_incoming_requests.iter().copied()); + let removed_outgoing_requests = + HashSet::::from_iter(message.remove_outgoing_requests.iter().copied()); - let mut incoming_requests = Vec::new(); - for request in message.incoming_requests { - incoming_requests.push( this.update(&mut cx, |this, cx| { - this.fetch_user(request.requester_id, cx) - }) - .await?, - ); + // Remove contacts + this.contacts + .retain(|contact| !removed_contacts.contains(&contact.user.id)); + // Update existing contacts and insert new ones + for updated_contact in updated_contacts { + match this.contacts.binary_search_by_key( + &&updated_contact.user.github_login, + |contact| &contact.user.github_login, + ) { + Ok(ix) => this.contacts[ix] = updated_contact, + Err(ix) => this.contacts.insert(ix, updated_contact), + } + } + + // Remove incoming contact requests + this.incoming_contact_requests + .retain(|user| !removed_incoming_requests.contains(&user.id)); + // Update existing incoming requests and insert new ones + for request in incoming_requests { + match this + .incoming_contact_requests + .binary_search_by_key(&&request.github_login, |contact| { + &contact.github_login + }) { + Ok(ix) => this.incoming_contact_requests[ix] = request, + Err(ix) => this.incoming_contact_requests.insert(ix, request), + } + } + + // Remove outgoing contact requests + this.outgoing_contact_requests + .retain(|user| !removed_outgoing_requests.contains(&user.id)); + // Update existing incoming requests and insert new ones + for request in outgoing_requests { + match this + .outgoing_contact_requests + .binary_search_by_key(&&request.github_login, |contact| { + &contact.github_login + }) { + Ok(ix) => this.outgoing_contact_requests[ix] = request, + Err(ix) => this.outgoing_contact_requests.insert(ix, request), + } + } + + cx.notify(); + }); + + Ok(()) + }) } - - let mut outgoing_requests = Vec::new(); - for requested_user_id in message.outgoing_requests { - outgoing_requests.push( - this.update(&mut cx, |this, cx| this.fetch_user(requested_user_id, cx)) - .await?, - ); - } - - let removed_contacts = - HashSet::::from_iter(message.remove_contacts.iter().copied()); - let removed_incoming_requests = - HashSet::::from_iter(message.remove_incoming_requests.iter().copied()); - let removed_outgoing_requests = - HashSet::::from_iter(message.remove_outgoing_requests.iter().copied()); - - this.update(&mut cx, |this, cx| { - // Remove contacts - this.contacts - .retain(|contact| !removed_contacts.contains(&contact.user.id)); - // Update existing contacts and insert new ones - for updated_contact in updated_contacts { - match this - .contacts - .binary_search_by_key(&&updated_contact.user.github_login, |contact| { - &contact.user.github_login - }) { - Ok(ix) => this.contacts[ix] = updated_contact, - Err(ix) => this.contacts.insert(ix, updated_contact), - } - } - - // Remove incoming contact requests - this.incoming_contact_requests - .retain(|user| !removed_incoming_requests.contains(&user.id)); - // Update existing incoming requests and insert new ones - for request in incoming_requests { - match this - .incoming_contact_requests - .binary_search_by_key(&&request.github_login, |contact| { - &contact.github_login - }) { - Ok(ix) => this.incoming_contact_requests[ix] = request, - Err(ix) => this.incoming_contact_requests.insert(ix, request), - } - } - - // Remove outgoing contact requests - this.outgoing_contact_requests - .retain(|user| !removed_outgoing_requests.contains(&user.id)); - // Update existing incoming requests and insert new ones - for request in outgoing_requests { - match this - .outgoing_contact_requests - .binary_search_by_key(&&request.github_login, |contact| { - &contact.github_login - }) { - Ok(ix) => this.outgoing_contact_requests[ix] = request, - Err(ix) => this.outgoing_contact_requests.insert(ix, request), - } - } - - cx.notify(); - }); - - Ok(()) - }) + } } pub fn contacts(&self) -> &[Arc] { @@ -342,11 +362,14 @@ impl UserStore { }) } - #[cfg(any(test, feature = "test-support"))] - pub fn clear_contacts(&mut self) { - self.contacts.clear(); - self.incoming_contact_requests.clear(); - self.outgoing_contact_requests.clear(); + pub fn clear_contacts(&mut self) -> impl Future { + let (tx, mut rx) = postage::barrier::channel(); + self.update_contacts_tx + .unbounded_send(UpdateContacts::Clear(tx)) + .unwrap(); + async move { + rx.recv().await; + } } pub fn get_users( diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index 57619941a0..3e8ec6b322 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -77,6 +77,8 @@ pub trait Db: Send + Sync { ) -> Result>; #[cfg(test)] async fn teardown(&self, url: &str); + #[cfg(test)] + fn as_fake<'a>(&'a self) -> Option<&'a tests::FakeDb>; } pub struct PostgresDb { @@ -291,6 +293,37 @@ impl Db for PostgresDb { } } + async fn dismiss_contact_request( + &self, + responder_id: UserId, + requester_id: UserId, + ) -> Result<()> { + let (id_a, id_b, a_to_b) = if responder_id < requester_id { + (responder_id, requester_id, false) + } else { + (requester_id, responder_id, true) + }; + + let query = " + UPDATE contacts + SET should_notify = 'f' + WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3; + "; + + let result = sqlx::query(query) + .bind(id_a.0) + .bind(id_b.0) + .bind(a_to_b) + .execute(&self.pool) + .await?; + + if result.rows_affected() == 0 { + Err(anyhow!("no such contact request"))?; + } + + Ok(()) + } + async fn respond_to_contact_request( &self, responder_id: UserId, @@ -333,37 +366,6 @@ impl Db for PostgresDb { } } - async fn dismiss_contact_request( - &self, - responder_id: UserId, - requester_id: UserId, - ) -> Result<()> { - let (id_a, id_b, a_to_b) = if responder_id < requester_id { - (responder_id, requester_id, false) - } else { - (requester_id, responder_id, true) - }; - - let query = " - UPDATE contacts - SET should_notify = 'f' - WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3; - "; - - let result = sqlx::query(query) - .bind(id_a.0) - .bind(id_b.0) - .bind(a_to_b) - .execute(&self.pool) - .await?; - - if result.rows_affected() == 0 { - Err(anyhow!("no such contact request"))?; - } - - Ok(()) - } - // access tokens async fn create_access_token_hash( @@ -620,6 +622,11 @@ impl Db for PostgresDb { .await .log_err(); } + + #[cfg(test)] + fn as_fake(&self) -> Option<&tests::FakeDb> { + None + } } macro_rules! id_type { @@ -1108,25 +1115,25 @@ pub mod tests { pub struct FakeDb { background: Arc, - users: Mutex>, - next_user_id: Mutex, - orgs: Mutex>, - next_org_id: Mutex, - org_memberships: Mutex>, - channels: Mutex>, - next_channel_id: Mutex, - channel_memberships: Mutex>, - channel_messages: Mutex>, + pub users: Mutex>, + pub orgs: Mutex>, + pub org_memberships: Mutex>, + pub channels: Mutex>, + pub channel_memberships: Mutex>, + pub channel_messages: Mutex>, + pub contacts: Mutex>, next_channel_message_id: Mutex, - contacts: Mutex>, + next_user_id: Mutex, + next_org_id: Mutex, + next_channel_id: Mutex, } #[derive(Debug)] - struct FakeContact { - requester_id: UserId, - responder_id: UserId, - accepted: bool, - should_notify: bool, + pub struct FakeContact { + pub requester_id: UserId, + pub responder_id: UserId, + pub accepted: bool, + pub should_notify: bool, } impl FakeDb { @@ -1514,5 +1521,10 @@ pub mod tests { } async fn teardown(&self, _: &str) {} + + #[cfg(test)] + fn as_fake(&self) -> Option<&FakeDb> { + Some(self) + } } } diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 98342f1be3..ea71ab5f00 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -5321,7 +5321,7 @@ mod tests { async fn disconnect_and_reconnect(client: &TestClient, cx: &mut TestAppContext) { client.disconnect(&cx.to_async()).unwrap(); - client.clear_contacts(cx); + client.clear_contacts(cx).await; client .authenticate_and_connect(false, &cx.to_async()) .await @@ -6584,10 +6584,10 @@ mod tests { while authed_user.next().await.unwrap().is_none() {} } - fn clear_contacts(&self, cx: &mut TestAppContext) { - self.user_store.update(cx, |store, _| { - store.clear_contacts(); - }); + async fn clear_contacts(&self, cx: &mut TestAppContext) { + self.user_store + .update(cx, |store, _| store.clear_contacts()) + .await; } fn summarize_contacts(&self, cx: &TestAppContext) -> ContactsSummary {