diff --git a/crates/collab/src/db/queries/buffers.rs b/crates/collab/src/db/queries/buffers.rs index b22bfc80cf..6a74ae4d44 100644 --- a/crates/collab/src/db/queries/buffers.rs +++ b/crates/collab/src/db/queries/buffers.rs @@ -529,7 +529,7 @@ impl Database { .on_conflict( OnConflict::columns([Column::UserId, Column::BufferId]) .update_columns([Column::Epoch, Column::LamportTimestamp, Column::ReplicaId]) - .target_cond_where( + .action_cond_where( Condition::any() .add(Column::Epoch.lt(*max_operation.epoch.as_ref())) .add( @@ -702,7 +702,7 @@ impl Database { pub async fn channels_with_changed_notes( &self, user_id: UserId, - channel_ids: impl IntoIterator, + channel_ids: &[ChannelId], tx: &DatabaseTransaction, ) -> Result> { #[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)] @@ -713,7 +713,7 @@ impl Database { let mut channel_ids_by_buffer_id = HashMap::default(); let mut rows = buffer::Entity::find() - .filter(buffer::Column::ChannelId.is_in(channel_ids)) + .filter(buffer::Column::ChannelId.is_in(channel_ids.iter().copied())) .stream(&*tx) .await?; while let Some(row) = rows.next().await { diff --git a/crates/collab/src/db/queries/channels.rs b/crates/collab/src/db/queries/channels.rs index 6d976b310e..8292e9dbcb 100644 --- a/crates/collab/src/db/queries/channels.rs +++ b/crates/collab/src/db/queries/channels.rs @@ -463,20 +463,14 @@ impl Database { } } + let channel_ids = graph.channels.iter().map(|c| c.id).collect::>(); let channels_with_changed_notes = self - .channels_with_changed_notes( - user_id, - graph.channels.iter().map(|channel| channel.id), - &*tx, - ) + .channels_with_changed_notes(user_id, &channel_ids, &*tx) .await?; - let mut channels_with_new_messages = HashSet::default(); - for channel in graph.channels.iter() { - if self.has_new_message(channel.id, user_id, tx).await? { - channels_with_new_messages.insert(channel.id); - } - } + let channels_with_new_messages = self + .channels_with_new_messages(user_id, &channel_ids, &*tx) + .await?; Ok(ChannelsForUser { channels: graph, diff --git a/crates/collab/src/db/queries/messages.rs b/crates/collab/src/db/queries/messages.rs index 8e3c92d916..484509f685 100644 --- a/crates/collab/src/db/queries/messages.rs +++ b/crates/collab/src/db/queries/messages.rs @@ -217,14 +217,12 @@ impl Database { ConnectionId, } - // Observe this message for all participants - observed_channel_messages::Entity::insert_many(participant_user_ids.iter().map( - |pariticpant_id| observed_channel_messages::ActiveModel { - user_id: ActiveValue::Set(*pariticpant_id), - channel_id: ActiveValue::Set(channel_id), - channel_message_id: ActiveValue::Set(message.last_insert_id), - }, - )) + // Observe this message for the sender + observed_channel_messages::Entity::insert(observed_channel_messages::ActiveModel { + user_id: ActiveValue::Set(user_id), + channel_id: ActiveValue::Set(channel_id), + channel_message_id: ActiveValue::Set(message.last_insert_id), + }) .on_conflict( OnConflict::columns([ observed_channel_messages::Column::ChannelId, @@ -248,51 +246,74 @@ impl Database { .await } - #[cfg(test)] - pub async fn has_new_message_tx(&self, channel_id: ChannelId, user_id: UserId) -> Result { - self.transaction(|tx| async move { self.has_new_message(channel_id, user_id, &*tx).await }) - .await - } - - #[cfg(test)] - pub async fn dbg_print_messages(&self) -> Result<()> { - self.transaction(|tx| async move { - dbg!(observed_channel_messages::Entity::find() - .all(&*tx) - .await - .unwrap()); - dbg!(channel_message::Entity::find().all(&*tx).await.unwrap()); - - Ok(()) - }) - .await - } - - pub async fn has_new_message( + pub async fn channels_with_new_messages( &self, - channel_id: ChannelId, user_id: UserId, + channel_ids: &[ChannelId], tx: &DatabaseTransaction, - ) -> Result { - self.check_user_is_channel_member(channel_id, user_id, &*tx) + ) -> Result> { + let mut observed_messages_by_channel_id = HashMap::default(); + let mut rows = observed_channel_messages::Entity::find() + .filter(observed_channel_messages::Column::UserId.eq(user_id)) + .filter(observed_channel_messages::Column::ChannelId.is_in(channel_ids.iter().copied())) + .stream(&*tx) .await?; - let latest_message_id = channel_message::Entity::find() - .filter(Condition::all().add(channel_message::Column::ChannelId.eq(channel_id))) - .order_by(channel_message::Column::SentAt, sea_query::Order::Desc) - .limit(1 as u64) - .one(&*tx) - .await? - .map(|model| model.id); + while let Some(row) = rows.next().await { + let row = row?; + observed_messages_by_channel_id.insert(row.channel_id, row); + } + drop(rows); + let mut values = String::new(); + for id in channel_ids { + if !values.is_empty() { + values.push_str(", "); + } + write!(&mut values, "({})", id).unwrap(); + } - let last_message_read = observed_channel_messages::Entity::find() - .filter(observed_channel_messages::Column::ChannelId.eq(channel_id)) - .filter(observed_channel_messages::Column::UserId.eq(user_id)) - .one(&*tx) - .await? - .map(|model| model.channel_message_id); + if values.is_empty() { + return Ok(Default::default()); + } - Ok(last_message_read != latest_message_id) + let sql = format!( + r#" + SELECT + * + FROM ( + SELECT + *, + row_number() OVER ( + PARTITION BY channel_id + ORDER BY id DESC + ) as row_number + FROM channel_messages + WHERE + channel_id in ({values}) + ) AS messages + WHERE + row_number = 1 + "#, + ); + + let stmt = Statement::from_string(self.pool.get_database_backend(), sql); + let last_messages = channel_message::Model::find_by_statement(stmt) + .all(&*tx) + .await?; + + let mut channels_with_new_changes = HashSet::default(); + for last_message in last_messages { + if let Some(observed_message) = + observed_messages_by_channel_id.get(&last_message.channel_id) + { + if observed_message.channel_message_id == last_message.id { + continue; + } + } + channels_with_new_changes.insert(last_message.channel_id); + } + + Ok(channels_with_new_changes) } pub async fn remove_channel_message( diff --git a/crates/collab/src/db/tests/buffer_tests.rs b/crates/collab/src/db/tests/buffer_tests.rs index 5a5fe6a812..d8edef963a 100644 --- a/crates/collab/src/db/tests/buffer_tests.rs +++ b/crates/collab/src/db/tests/buffer_tests.rs @@ -171,6 +171,8 @@ test_both_dbs!( ); async fn test_channel_buffers_diffs(db: &Database) { + panic!("Rewriting the way this works"); + let a_id = db .create_user( "user_a@example.com", diff --git a/crates/collab/src/db/tests/message_tests.rs b/crates/collab/src/db/tests/message_tests.rs index 98b8cc6037..e212c36466 100644 --- a/crates/collab/src/db/tests/message_tests.rs +++ b/crates/collab/src/db/tests/message_tests.rs @@ -65,6 +65,8 @@ test_both_dbs!( ); async fn test_channel_message_new_notification(db: &Arc) { + panic!("Rewriting the way this works"); + let user_a = db .create_user( "user_a@example.com", @@ -108,7 +110,7 @@ async fn test_channel_message_new_notification(db: &Arc) { let owner_id = db.create_server("test").await.unwrap().0 as u32; // Zero case: no messages at all - assert!(!db.has_new_message_tx(channel, user_b).await.unwrap()); + // assert!(!db.has_new_message_tx(channel, user_b).await.unwrap()); let a_connection_id = rpc::ConnectionId { owner_id, id: 0 }; db.join_channel_chat(channel, a_connection_id, user_a) @@ -131,7 +133,7 @@ async fn test_channel_message_new_notification(db: &Arc) { .unwrap(); // Smoke test: can we detect a new message? - assert!(db.has_new_message_tx(channel, user_b).await.unwrap()); + // assert!(db.has_new_message_tx(channel, user_b).await.unwrap()); let b_connection_id = rpc::ConnectionId { owner_id, id: 1 }; db.join_channel_chat(channel, b_connection_id, user_b) @@ -139,7 +141,7 @@ async fn test_channel_message_new_notification(db: &Arc) { .unwrap(); // Joining the channel should _not_ update us to the latest message - assert!(db.has_new_message_tx(channel, user_b).await.unwrap()); + // assert!(db.has_new_message_tx(channel, user_b).await.unwrap()); // Reading the earlier messages should not change that we have new messages let _ = db @@ -147,7 +149,7 @@ async fn test_channel_message_new_notification(db: &Arc) { .await .unwrap(); - assert!(db.has_new_message_tx(channel, user_b).await.unwrap()); + // assert!(db.has_new_message_tx(channel, user_b).await.unwrap()); // This constraint is currently inexpressible, creating a message implicitly broadcasts // it to all participants @@ -165,7 +167,7 @@ async fn test_channel_message_new_notification(db: &Arc) { .await .unwrap(); - assert!(!db.has_new_message_tx(channel, user_b).await.unwrap()); + // assert!(!db.has_new_message_tx(channel, user_b).await.unwrap()); // And future messages should not reset the flag let _ = db @@ -173,26 +175,26 @@ async fn test_channel_message_new_notification(db: &Arc) { .await .unwrap(); - assert!(!db.has_new_message_tx(channel, user_b).await.unwrap()); + // assert!(!db.has_new_message_tx(channel, user_b).await.unwrap()); let _ = db .create_channel_message(channel, user_b, "6", OffsetDateTime::now_utc(), 6) .await .unwrap(); - assert!(!db.has_new_message_tx(channel, user_b).await.unwrap()); + // assert!(!db.has_new_message_tx(channel, user_b).await.unwrap()); // And we should start seeing the flag again after we've left the channel db.leave_channel_chat(channel, b_connection_id, user_b) .await .unwrap(); - assert!(!db.has_new_message_tx(channel, user_b).await.unwrap()); + // assert!(!db.has_new_message_tx(channel, user_b).await.unwrap()); let _ = db .create_channel_message(channel, user_a, "7", OffsetDateTime::now_utc(), 7) .await .unwrap(); - assert!(db.has_new_message_tx(channel, user_b).await.unwrap()); + // assert!(db.has_new_message_tx(channel, user_b).await.unwrap()); }