diff --git a/crates/collab/src/db/queries/channels.rs b/crates/collab/src/db/queries/channels.rs index d4276603f9..e3a6170452 100644 --- a/crates/collab/src/db/queries/channels.rs +++ b/crates/collab/src/db/queries/channels.rs @@ -88,80 +88,87 @@ impl Database { .await } - pub async fn join_channel_internal( - &self, - channel_id: ChannelId, - user_id: UserId, - connection: ConnectionId, - environment: &str, - tx: &DatabaseTransaction, - ) -> Result<(JoinRoom, bool)> { - let mut joined = false; - - let channel = channel::Entity::find() - .filter(channel::Column::Id.eq(channel_id)) - .one(&*tx) - .await?; - - let mut role = self - .channel_role_for_user(channel_id, user_id, &*tx) - .await?; - - if role.is_none() { - if channel.as_ref().map(|c| c.visibility) == Some(ChannelVisibility::Public) { - channel_member::Entity::insert(channel_member::ActiveModel { - id: ActiveValue::NotSet, - channel_id: ActiveValue::Set(channel_id), - user_id: ActiveValue::Set(user_id), - accepted: ActiveValue::Set(true), - role: ActiveValue::Set(ChannelRole::Guest), - }) - .on_conflict( - OnConflict::columns([ - channel_member::Column::UserId, - channel_member::Column::ChannelId, - ]) - .update_columns([channel_member::Column::Accepted]) - .to_owned(), - ) - .exec(&*tx) - .await?; - - debug_assert!( - self.channel_role_for_user(channel_id, user_id, &*tx) - .await? - == Some(ChannelRole::Guest) - ); - - role = Some(ChannelRole::Guest); - joined = true; - } - } - - if channel.is_none() || role.is_none() || role == Some(ChannelRole::Banned) { - Err(anyhow!("no such channel, or not allowed"))? - } - - let live_kit_room = format!("channel-{}", nanoid::nanoid!(30)); - let room_id = self - .get_or_create_channel_room(channel_id, &live_kit_room, environment, &*tx) - .await?; - - self.join_channel_room_internal(channel_id, room_id, user_id, connection, &*tx) - .await - .map(|jr| (jr, joined)) - } - pub async fn join_channel( &self, channel_id: ChannelId, user_id: UserId, connection: ConnectionId, environment: &str, - ) -> Result<(JoinRoom, bool)> { + ) -> Result<(JoinRoom, Option)> { self.transaction(move |tx| async move { - self.join_channel_internal(channel_id, user_id, connection, environment, &*tx) + let mut joined_channel_id = None; + + let channel = channel::Entity::find() + .filter(channel::Column::Id.eq(channel_id)) + .one(&*tx) + .await?; + + let mut role = self + .channel_role_for_user(channel_id, user_id, &*tx) + .await?; + + if role.is_none() && channel.is_some() { + if let Some(invitation) = self + .pending_invite_for_channel(channel_id, user_id, &*tx) + .await? + { + // note, this may be a parent channel + joined_channel_id = Some(invitation.channel_id); + role = Some(invitation.role); + + channel_member::Entity::update(channel_member::ActiveModel { + accepted: ActiveValue::Set(true), + ..invitation.into_active_model() + }) + .exec(&*tx) + .await?; + + debug_assert!( + self.channel_role_for_user(channel_id, user_id, &*tx) + .await? + == role + ); + } + } + if role.is_none() + && channel.as_ref().map(|c| c.visibility) == Some(ChannelVisibility::Public) + { + let channel_id_to_join = self + .most_public_ancestor_for_channel(channel_id, &*tx) + .await? + .unwrap_or(channel_id); + role = Some(ChannelRole::Guest); + joined_channel_id = Some(channel_id_to_join); + + channel_member::Entity::insert(channel_member::ActiveModel { + id: ActiveValue::NotSet, + channel_id: ActiveValue::Set(channel_id_to_join), + user_id: ActiveValue::Set(user_id), + accepted: ActiveValue::Set(true), + role: ActiveValue::Set(ChannelRole::Guest), + }) + .exec(&*tx) + .await?; + + debug_assert!( + self.channel_role_for_user(channel_id, user_id, &*tx) + .await? + == role + ); + } + + if channel.is_none() || role.is_none() || role == Some(ChannelRole::Banned) { + Err(anyhow!("no such channel, or not allowed"))? + } + + let live_kit_room = format!("channel-{}", nanoid::nanoid!(30)); + let room_id = self + .get_or_create_channel_room(channel_id, &live_kit_room, environment, &*tx) + .await?; + + self.join_channel_room_internal(channel_id, room_id, user_id, connection, &*tx) .await + .map(|jr| (jr, joined_channel_id)) }) .await } @@ -624,29 +631,29 @@ impl Database { admin_id: UserId, for_user: UserId, role: ChannelRole, - ) -> Result<()> { + ) -> Result { self.transaction(|tx| async move { self.check_user_is_channel_admin(channel_id, admin_id, &*tx) .await?; - let result = channel_member::Entity::update_many() + let membership = channel_member::Entity::find() .filter( channel_member::Column::ChannelId .eq(channel_id) .and(channel_member::Column::UserId.eq(for_user)), ) - .set(channel_member::ActiveModel { - role: ActiveValue::set(role), - ..Default::default() - }) - .exec(&*tx) + .one(&*tx) .await?; - if result.rows_affected == 0 { - Err(anyhow!("no such member"))?; - } + let Some(membership) = membership else { + Err(anyhow!("no such member"))? + }; - Ok(()) + let mut update = membership.into_active_model(); + update.role = ActiveValue::Set(role); + let updated = channel_member::Entity::update(update).exec(&*tx).await?; + + Ok(updated) }) .await } @@ -844,6 +851,52 @@ impl Database { } } + pub async fn pending_invite_for_channel( + &self, + channel_id: ChannelId, + user_id: UserId, + tx: &DatabaseTransaction, + ) -> Result> { + let channel_ids = self.get_channel_ancestors(channel_id, tx).await?; + + let row = channel_member::Entity::find() + .filter(channel_member::Column::ChannelId.is_in(channel_ids)) + .filter(channel_member::Column::UserId.eq(user_id)) + .filter(channel_member::Column::Accepted.eq(false)) + .one(&*tx) + .await?; + + Ok(row) + } + + pub async fn most_public_ancestor_for_channel( + &self, + channel_id: ChannelId, + tx: &DatabaseTransaction, + ) -> Result> { + let channel_ids = self.get_channel_ancestors(channel_id, tx).await?; + + let rows = channel::Entity::find() + .filter(channel::Column::Id.is_in(channel_ids.clone())) + .filter(channel::Column::Visibility.eq(ChannelVisibility::Public)) + .all(&*tx) + .await?; + + let mut visible_channels: HashSet = HashSet::default(); + + for row in rows { + visible_channels.insert(row.id); + } + + for ancestor in channel_ids.into_iter().rev() { + if visible_channels.contains(&ancestor) { + return Ok(Some(ancestor)); + } + } + + Ok(None) + } + pub async fn channel_role_for_user( &self, channel_id: ChannelId, @@ -864,7 +917,8 @@ impl Database { .filter( channel_member::Column::ChannelId .is_in(channel_ids) - .and(channel_member::Column::UserId.eq(user_id)), + .and(channel_member::Column::UserId.eq(user_id)) + .and(channel_member::Column::Accepted.eq(true)), ) .select_only() .column(channel_member::Column::ChannelId) @@ -1009,52 +1063,22 @@ impl Database { Ok(results) } - /// Returns the channel with the given ID and: - /// - true if the user is a member - /// - false if the user hasn't accepted the invitation yet - pub async fn get_channel( - &self, - channel_id: ChannelId, - user_id: UserId, - ) -> Result> { + /// Returns the channel with the given ID + pub async fn get_channel(&self, channel_id: ChannelId, user_id: UserId) -> Result { self.transaction(|tx| async move { - let tx = tx; + self.check_user_is_channel_participant(channel_id, user_id, &*tx) + .await?; let channel = channel::Entity::find_by_id(channel_id).one(&*tx).await?; + let Some(channel) = channel else { + Err(anyhow!("no such channel"))? + }; - if let Some(channel) = channel { - if self - .check_user_is_channel_member(channel_id, user_id, &*tx) - .await - .is_err() - { - return Ok(None); - } - - let channel_membership = channel_member::Entity::find() - .filter( - channel_member::Column::ChannelId - .eq(channel_id) - .and(channel_member::Column::UserId.eq(user_id)), - ) - .one(&*tx) - .await?; - - let is_accepted = channel_membership - .map(|membership| membership.accepted) - .unwrap_or(false); - - Ok(Some(( - Channel { - id: channel.id, - visibility: channel.visibility, - name: channel.name, - }, - is_accepted, - ))) - } else { - Ok(None) - } + Ok(Channel { + id: channel.id, + visibility: channel.visibility, + name: channel.name, + }) }) .await } diff --git a/crates/collab/src/db/tests/channel_tests.rs b/crates/collab/src/db/tests/channel_tests.rs index 9b6d8d1525..f08b1554bc 100644 --- a/crates/collab/src/db/tests/channel_tests.rs +++ b/crates/collab/src/db/tests/channel_tests.rs @@ -51,7 +51,7 @@ async fn test_channels(db: &Arc) { let zed_id = db.create_root_channel("zed", a_id).await.unwrap(); // Make sure that people cannot read channels they haven't been invited to - assert!(db.get_channel(zed_id, b_id).await.unwrap().is_none()); + assert!(db.get_channel(zed_id, b_id).await.is_err()); db.invite_channel_member(zed_id, b_id, a_id, ChannelRole::Member) .await @@ -157,7 +157,7 @@ async fn test_channels(db: &Arc) { // Remove a single channel db.delete_channel(crdb_id, a_id).await.unwrap(); - assert!(db.get_channel(crdb_id, a_id).await.unwrap().is_none()); + assert!(db.get_channel(crdb_id, a_id).await.is_err()); // Remove a channel tree let (mut channel_ids, user_ids) = db.delete_channel(rust_id, a_id).await.unwrap(); @@ -165,9 +165,9 @@ async fn test_channels(db: &Arc) { assert_eq!(channel_ids, &[rust_id, cargo_id, cargo_ra_id]); assert_eq!(user_ids, &[a_id]); - assert!(db.get_channel(rust_id, a_id).await.unwrap().is_none()); - assert!(db.get_channel(cargo_id, a_id).await.unwrap().is_none()); - assert!(db.get_channel(cargo_ra_id, a_id).await.unwrap().is_none()); + assert!(db.get_channel(rust_id, a_id).await.is_err()); + assert!(db.get_channel(cargo_id, a_id).await.is_err()); + assert!(db.get_channel(cargo_ra_id, a_id).await.is_err()); } test_both_dbs!( @@ -381,11 +381,7 @@ async fn test_channel_renames(db: &Arc) { let zed_archive_id = zed_id; - let (channel, _) = db - .get_channel(zed_archive_id, user_1) - .await - .unwrap() - .unwrap(); + let channel = db.get_channel(zed_archive_id, user_1).await.unwrap(); assert_eq!(channel.name, "zed-archive"); let non_permissioned_rename = db @@ -860,12 +856,6 @@ async fn test_user_is_channel_participant(db: &Arc) { }) .await .unwrap(); - db.transaction(|tx| async move { - db.check_user_is_channel_participant(vim_channel, guest, &*tx) - .await - }) - .await - .unwrap(); let members = db .get_channel_participant_details(vim_channel, admin) @@ -896,6 +886,13 @@ async fn test_user_is_channel_participant(db: &Arc) { .await .unwrap(); + db.transaction(|tx| async move { + db.check_user_is_channel_participant(vim_channel, guest, &*tx) + .await + }) + .await + .unwrap(); + let channels = db.get_channels_for_user(guest).await.unwrap().channels; assert_dag(channels, &[(vim_channel, None)]); let channels = db.get_channels_for_user(member).await.unwrap().channels; @@ -953,29 +950,7 @@ async fn test_user_is_channel_participant(db: &Arc) { .await .unwrap(); - db.transaction(|tx| async move { - db.check_user_is_channel_participant(zed_channel, guest, &*tx) - .await - }) - .await - .unwrap(); - assert!(db - .transaction(|tx| async move { - db.check_user_is_channel_participant(active_channel, guest, &*tx) - .await - }) - .await - .is_err(),); - - db.transaction(|tx| async move { - db.check_user_is_channel_participant(vim_channel, guest, &*tx) - .await - }) - .await - .unwrap(); - // currently people invited to parent channels are not shown here - // (though they *do* have permissions!) let members = db .get_channel_participant_details(vim_channel, admin) .await @@ -1000,6 +975,27 @@ async fn test_user_is_channel_participant(db: &Arc) { .await .unwrap(); + db.transaction(|tx| async move { + db.check_user_is_channel_participant(zed_channel, guest, &*tx) + .await + }) + .await + .unwrap(); + assert!(db + .transaction(|tx| async move { + db.check_user_is_channel_participant(active_channel, guest, &*tx) + .await + }) + .await + .is_err(),); + + db.transaction(|tx| async move { + db.check_user_is_channel_participant(vim_channel, guest, &*tx) + .await + }) + .await + .unwrap(); + let members = db .get_channel_participant_details(vim_channel, admin) .await diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 26ad2f281a..4b33550c39 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -38,7 +38,7 @@ use lazy_static::lazy_static; use prometheus::{register_int_gauge, IntGauge}; use rpc::{ proto::{ - self, Ack, AnyTypedEnvelope, ChannelEdge, EntityMessage, EnvelopedMessage, JoinRoom, + self, Ack, AnyTypedEnvelope, ChannelEdge, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo, RequestMessage, UpdateChannelBufferCollaborators, }, Connection, ConnectionId, Peer, Receipt, TypedEnvelope, @@ -2289,10 +2289,7 @@ async fn invite_channel_member( ) .await?; - let (channel, _) = db - .get_channel(channel_id, session.user_id) - .await? - .ok_or_else(|| anyhow!("channel not found"))?; + let channel = db.get_channel(channel_id, session.user_id).await?; let mut update = proto::UpdateChannels::default(); update.channel_invitations.push(proto::Channel { @@ -2380,21 +2377,19 @@ async fn set_channel_member_role( let db = session.db().await; let channel_id = ChannelId::from_proto(request.channel_id); let member_id = UserId::from_proto(request.user_id); - db.set_channel_member_role( - channel_id, - session.user_id, - member_id, - request.role().into(), - ) - .await?; + let channel_member = db + .set_channel_member_role( + channel_id, + session.user_id, + member_id, + request.role().into(), + ) + .await?; - let (channel, has_accepted) = db - .get_channel(channel_id, member_id) - .await? - .ok_or_else(|| anyhow!("channel not found"))?; + let channel = db.get_channel(channel_id, session.user_id).await?; let mut update = proto::UpdateChannels::default(); - if has_accepted { + if channel_member.accepted { update.channel_permissions.push(proto::ChannelPermission { channel_id: channel.id.to_proto(), role: request.role, @@ -2724,9 +2719,11 @@ async fn join_channel_internal( channel_id: joined_room.channel_id.map(|id| id.to_proto()), live_kit_connection_info, })?; + dbg!("Joined channel", &joined_channel); - if joined_channel { - channel_membership_updated(db, channel_id, &session).await? + if let Some(joined_channel) = joined_channel { + dbg!("CMU"); + channel_membership_updated(db, joined_channel, &session).await? } room_updated(&joined_room.room, &session.peer); diff --git a/crates/collab/src/tests/channel_tests.rs b/crates/collab/src/tests/channel_tests.rs index 1700dfc5d3..1bb8c92ac8 100644 --- a/crates/collab/src/tests/channel_tests.rs +++ b/crates/collab/src/tests/channel_tests.rs @@ -7,7 +7,7 @@ use channel::{ChannelId, ChannelMembership, ChannelStore}; use client::User; use gpui::{executor::Deterministic, ModelHandle, TestAppContext}; use rpc::{ - proto::{self}, + proto::{self, ChannelRole}, RECEIVE_TIMEOUT, }; use std::sync::Arc; @@ -965,6 +965,67 @@ async fn test_guest_access( }) } +#[gpui::test] +async fn test_invite_access( + deterministic: Arc, + cx_a: &mut TestAppContext, + cx_b: &mut TestAppContext, +) { + deterministic.forbid_parking(); + + let mut server = TestServer::start(&deterministic).await; + let client_a = server.create_client(cx_a, "user_a").await; + let client_b = server.create_client(cx_b, "user_b").await; + + let channels = server + .make_channel_tree( + &[("channel-a", None), ("channel-b", Some("channel-a"))], + (&client_a, cx_a), + ) + .await; + let channel_a_id = channels[0]; + let channel_b_id = channels[0]; + + let active_call_b = cx_b.read(ActiveCall::global); + + // should not be allowed to join + assert!(active_call_b + .update(cx_b, |call, cx| call.join_channel(channel_b_id, cx)) + .await + .is_err()); + + client_a + .channel_store() + .update(cx_a, |channel_store, cx| { + channel_store.invite_member( + channel_a_id, + client_b.user_id().unwrap(), + ChannelRole::Member, + cx, + ) + }) + .await + .unwrap(); + + active_call_b + .update(cx_b, |call, cx| call.join_channel(channel_b_id, cx)) + .await + .unwrap(); + + deterministic.run_until_parked(); + + client_b.channel_store().update(cx_b, |channel_store, _| { + assert!(channel_store.channel_for_id(channel_b_id).is_some()); + assert!(channel_store.channel_for_id(channel_a_id).is_some()); + }); + + client_a.channel_store().update(cx_a, |channel_store, _| { + let participants = channel_store.channel_participants(channel_b_id); + assert_eq!(participants.len(), 1); + assert_eq!(participants[0].id, client_b.user_id().unwrap()); + }) +} + #[gpui::test] async fn test_channel_moving( deterministic: Arc,