diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index 06bdb0f729..d1a717e66e 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -169,6 +169,30 @@ impl Database { self.run(body).await } + pub async fn weak_transaction(&self, f: F) -> Result + where + F: Send + Fn(TransactionHandle) -> Fut, + Fut: Send + Future>, + { + let body = async { + let (tx, result) = self.with_weak_transaction(&f).await?; + match result { + Ok(result) => match tx.commit().await.map_err(Into::into) { + Ok(()) => return Ok(result), + Err(error) => { + return Err(error); + } + }, + Err(error) => { + tx.rollback().await?; + return Err(error); + } + } + }; + + self.run(body).await + } + /// The same as room_transaction, but if you need to only optionally return a Room. async fn optional_room_transaction(&self, f: F) -> Result>> where @@ -284,6 +308,30 @@ impl Database { Ok((tx, result)) } + async fn with_weak_transaction( + &self, + f: &F, + ) -> Result<(DatabaseTransaction, Result)> + where + F: Send + Fn(TransactionHandle) -> Fut, + Fut: Send + Future>, + { + let tx = self + .pool + .begin_with_config(Some(IsolationLevel::ReadCommitted), None) + .await?; + + let mut tx = Arc::new(Some(tx)); + let result = f(TransactionHandle(tx.clone())).await; + let Some(tx) = Arc::get_mut(&mut tx).and_then(|tx| tx.take()) else { + return Err(anyhow!( + "couldn't complete transaction because it's still in use" + ))?; + }; + + Ok((tx, result)) + } + async fn run(&self, future: F) -> Result where F: Future>, @@ -457,9 +505,8 @@ pub struct NewUserResult { /// The result of moving a channel. #[derive(Debug)] pub struct MoveChannelResult { - pub participants_to_update: HashMap, - pub participants_to_remove: HashSet, - pub moved_channels: HashSet, + pub previous_participants: Vec, + pub descendent_ids: Vec, } /// The result of renaming a channel. diff --git a/crates/collab/src/db/queries/channels.rs b/crates/collab/src/db/queries/channels.rs index ac18907894..95c9716a91 100644 --- a/crates/collab/src/db/queries/channels.rs +++ b/crates/collab/src/db/queries/channels.rs @@ -22,7 +22,6 @@ impl Database { Ok(self .create_channel(name, None, creator_id) .await? - .channel .id) } @@ -36,7 +35,6 @@ impl Database { Ok(self .create_channel(name, Some(parent), creator_id) .await? - .channel .id) } @@ -46,7 +44,7 @@ impl Database { name: &str, parent_channel_id: Option, admin_id: UserId, - ) -> Result { + ) -> Result { let name = Self::sanitize_channel_name(name)?; self.transaction(move |tx| async move { let mut parent = None; @@ -72,14 +70,7 @@ impl Database { .insert(&*tx) .await?; - let participants_to_update; - if let Some(parent) = &parent { - participants_to_update = self - .participants_to_notify_for_channel_change(parent, &*tx) - .await?; - } else { - participants_to_update = vec![]; - + if parent.is_none() { channel_member::ActiveModel { id: ActiveValue::NotSet, channel_id: ActiveValue::Set(channel.id), @@ -89,12 +80,9 @@ impl Database { } .insert(&*tx) .await?; - }; + } - Ok(CreateChannelResult { - channel: Channel::from_model(channel, ChannelRole::Admin), - participants_to_update, - }) + Ok(Channel::from_model(channel, ChannelRole::Admin)) }) .await } @@ -718,6 +706,19 @@ impl Database { }) } + pub async fn new_participants_to_notify( + &self, + parent_channel_id: ChannelId, + ) -> Result> { + self.weak_transaction(|tx| async move { + let parent_channel = self.get_channel_internal(parent_channel_id, &*tx).await?; + self.participants_to_notify_for_channel_change(&parent_channel, &*tx) + .await + }) + .await + } + + // TODO: this is very expensive, and we should rethink async fn participants_to_notify_for_channel_change( &self, new_parent: &channel::Model, @@ -1287,7 +1288,7 @@ impl Database { let mut model = channel.into_active_model(); model.parent_path = ActiveValue::Set(new_parent_path); - let channel = model.update(&*tx).await?; + model.update(&*tx).await?; if new_parent_channel.is_none() { channel_member::ActiveModel { @@ -1314,34 +1315,9 @@ impl Database { .all(&*tx) .await?; - let participants_to_update: HashMap<_, _> = self - .participants_to_notify_for_channel_change( - new_parent_channel.as_ref().unwrap_or(&channel), - &*tx, - ) - .await? - .into_iter() - .collect(); - - let mut moved_channels: HashSet = HashSet::default(); - for id in descendent_ids { - moved_channels.insert(id); - } - moved_channels.insert(channel_id); - - let mut participants_to_remove: HashSet = HashSet::default(); - for participant in previous_participants { - if participant.kind == proto::channel_member::Kind::AncestorMember { - if !participants_to_update.contains_key(&participant.user_id) { - participants_to_remove.insert(participant.user_id); - } - } - } - Ok(Some(MoveChannelResult { - participants_to_remove, - participants_to_update, - moved_channels, + previous_participants, + descendent_ids, })) }) .await diff --git a/crates/collab/src/db/tests/message_tests.rs b/crates/collab/src/db/tests/message_tests.rs index 10d9778612..22319ecc96 100644 --- a/crates/collab/src/db/tests/message_tests.rs +++ b/crates/collab/src/db/tests/message_tests.rs @@ -15,11 +15,11 @@ test_both_dbs!( async fn test_channel_message_retrieval(db: &Arc) { let user = new_test_user(db, "user@example.com").await; - let result = db.create_channel("channel", None, user).await.unwrap(); + let channel = db.create_channel("channel", None, user).await.unwrap(); let owner_id = db.create_server("test").await.unwrap().0 as u32; db.join_channel_chat( - result.channel.id, + channel.id, rpc::ConnectionId { owner_id, id: 0 }, user, ) @@ -30,7 +30,7 @@ async fn test_channel_message_retrieval(db: &Arc) { for i in 0..10 { all_messages.push( db.create_channel_message( - result.channel.id, + channel.id, user, &i.to_string(), &[], @@ -45,7 +45,7 @@ async fn test_channel_message_retrieval(db: &Arc) { } let messages = db - .get_channel_messages(result.channel.id, user, 3, None) + .get_channel_messages(channel.id, user, 3, None) .await .unwrap() .into_iter() @@ -55,7 +55,7 @@ async fn test_channel_message_retrieval(db: &Arc) { let messages = db .get_channel_messages( - result.channel.id, + channel.id, user, 4, Some(MessageId::from_proto(all_messages[6])), @@ -370,7 +370,6 @@ async fn test_channel_message_mentions(db: &Arc) { .create_channel("channel", None, user_a) .await .unwrap() - .channel .id; db.invite_channel_member(channel, user_b, user_a, ChannelRole::Member) .await diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 7170f7f1c5..415119bcd1 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -3,9 +3,9 @@ mod connection_pool; use crate::{ auth::{self, Impersonator}, db::{ - self, BufferId, ChannelId, ChannelRole, ChannelsForUser, CreateChannelResult, + self, BufferId, ChannelId, ChannelRole, ChannelsForUser, CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId, - MoveChannelResult, NotificationId, ProjectId, RemoveChannelMemberResult, + NotificationId, ProjectId, RemoveChannelMemberResult, RenameChannelResult, RespondToChannelInvite, RoomId, ServerId, SetChannelVisibilityResult, User, UserId, }, @@ -2301,10 +2301,7 @@ async fn create_channel( let db = session.db().await; let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id)); - let CreateChannelResult { - channel, - participants_to_update, - } = db + let channel = db .create_channel(&request.name, parent_id, session.user_id) .await?; @@ -2313,6 +2310,13 @@ async fn create_channel( parent_id: request.parent_id, })?; + let participants_to_update; + if let Some(parent) = parent_id { + participants_to_update = db.new_participants_to_notify(parent).await?; + } else { + participants_to_update = vec![]; + } + let connection_pool = session.connection_pool().await; for (user_id, channels) in participants_to_update { let update = build_channels_update(channels, vec![]); @@ -2566,50 +2570,58 @@ async fn move_channel( let channel_id = ChannelId::from_proto(request.channel_id); let to = request.to.map(ChannelId::from_proto); - let result = session - .db() - .await - .move_channel(channel_id, to, session.user_id) - .await?; + let result = session.db().await.move_channel(channel_id, to, session.user_id).await?; - notify_channel_moved(result, session).await?; + if let Some(result) = result { + let participants_to_update: HashMap<_, _> = session.db().await + .new_participants_to_notify( + to.unwrap_or(channel_id) + ) + .await? + .into_iter() + .collect(); + + let mut moved_channels: HashSet = HashSet::default(); + for id in result.descendent_ids { + moved_channels.insert(id); + } + moved_channels.insert(channel_id); + + let mut participants_to_remove: HashSet = HashSet::default(); + for participant in result.previous_participants { + if participant.kind == proto::channel_member::Kind::AncestorMember { + if !participants_to_update.contains_key(&participant.user_id) { + participants_to_remove.insert(participant.user_id); + } + } + } + + let moved_channels: Vec = moved_channels.iter().map(|id| id.to_proto()).collect(); + + let connection_pool = session.connection_pool().await; + for (user_id, channels) in participants_to_update { + let mut update = build_channels_update(channels, vec![]); + update.delete_channels = moved_channels.clone(); + for connection_id in connection_pool.user_connection_ids(user_id) { + session.peer.send(connection_id, update.clone())?; + } + } + + for user_id in participants_to_remove { + let update = proto::UpdateChannels { + delete_channels: moved_channels.clone(), + ..Default::default() + }; + for connection_id in connection_pool.user_connection_ids(user_id) { + session.peer.send(connection_id, update.clone())?; + } + } + } response.send(Ack {})?; Ok(()) } -async fn notify_channel_moved(result: Option, session: Session) -> Result<()> { - let Some(MoveChannelResult { - participants_to_remove, - participants_to_update, - moved_channels, - }) = result - else { - return Ok(()); - }; - let moved_channels: Vec = moved_channels.iter().map(|id| id.to_proto()).collect(); - - let connection_pool = session.connection_pool().await; - for (user_id, channels) in participants_to_update { - let mut update = build_channels_update(channels, vec![]); - update.delete_channels = moved_channels.clone(); - for connection_id in connection_pool.user_connection_ids(user_id) { - session.peer.send(connection_id, update.clone())?; - } - } - - for user_id in participants_to_remove { - let update = proto::UpdateChannels { - delete_channels: moved_channels.clone(), - ..Default::default() - }; - for connection_id in connection_pool.user_connection_ids(user_id) { - session.peer.send(connection_id, update.clone())?; - } - } - Ok(()) -} - /// Get the list of channel members async fn get_channel_members( request: proto::GetChannelMembers,