Add check_is_channel_participant

Refactor permission checks to load ancestor permissions into memory
for all checks to make the different logics more explicit.
This commit is contained in:
Conrad Irwin 2023-10-12 19:59:50 -06:00
parent 78432d08ca
commit a7db2aa39d
7 changed files with 292 additions and 38 deletions

View file

@ -192,7 +192,8 @@ CREATE INDEX "index_followers_on_room_id" ON "followers" ("room_id");
CREATE TABLE "channels" (
"id" INTEGER PRIMARY KEY AUTOINCREMENT,
"name" VARCHAR NOT NULL,
"created_at" TIMESTAMP NOT NULL DEFAULT now
"created_at" TIMESTAMP NOT NULL DEFAULT now,
"visibility" VARCHAR NOT NULL
);
CREATE TABLE IF NOT EXISTS "channel_chat_participants" (

View file

@ -91,6 +91,8 @@ pub enum ChannelRole {
Member,
#[sea_orm(string_value = "guest")]
Guest,
#[sea_orm(string_value = "banned")]
Banned,
}
impl From<proto::ChannelRole> for ChannelRole {
@ -99,6 +101,7 @@ impl From<proto::ChannelRole> for ChannelRole {
proto::ChannelRole::Admin => ChannelRole::Admin,
proto::ChannelRole::Member => ChannelRole::Member,
proto::ChannelRole::Guest => ChannelRole::Guest,
proto::ChannelRole::Banned => ChannelRole::Banned,
}
}
}
@ -109,6 +112,7 @@ impl Into<proto::ChannelRole> for ChannelRole {
ChannelRole::Admin => proto::ChannelRole::Admin,
ChannelRole::Member => proto::ChannelRole::Member,
ChannelRole::Guest => proto::ChannelRole::Guest,
ChannelRole::Banned => proto::ChannelRole::Banned,
}
}
}

View file

@ -37,8 +37,9 @@ impl Database {
}
let channel = channel::ActiveModel {
id: ActiveValue::NotSet,
name: ActiveValue::Set(name.to_string()),
..Default::default()
visibility: ActiveValue::Set(ChannelVisibility::ChannelMembers),
}
.insert(&*tx)
.await?;
@ -89,6 +90,29 @@ impl Database {
.await
}
pub async fn set_channel_visibility(
&self,
channel_id: ChannelId,
visibility: ChannelVisibility,
user_id: UserId,
) -> Result<()> {
self.transaction(move |tx| async move {
self.check_user_is_channel_admin(channel_id, user_id, &*tx)
.await?;
channel::ActiveModel {
id: ActiveValue::Unchanged(channel_id),
visibility: ActiveValue::Set(visibility),
..Default::default()
}
.update(&*tx)
.await?;
Ok(())
})
.await
}
pub async fn delete_channel(
&self,
channel_id: ChannelId,
@ -160,11 +184,11 @@ impl Database {
&self,
channel_id: ChannelId,
invitee_id: UserId,
inviter_id: UserId,
admin_id: UserId,
role: ChannelRole,
) -> Result<()> {
self.transaction(move |tx| async move {
self.check_user_is_channel_admin(channel_id, inviter_id, &*tx)
self.check_user_is_channel_admin(channel_id, admin_id, &*tx)
.await?;
channel_member::ActiveModel {
@ -262,10 +286,10 @@ impl Database {
&self,
channel_id: ChannelId,
member_id: UserId,
remover_id: UserId,
admin_id: UserId,
) -> Result<()> {
self.transaction(|tx| async move {
self.check_user_is_channel_admin(channel_id, remover_id, &*tx)
self.check_user_is_channel_admin(channel_id, admin_id, &*tx)
.await?;
let result = channel_member::Entity::delete_many()
@ -481,12 +505,12 @@ impl Database {
pub async fn set_channel_member_role(
&self,
channel_id: ChannelId,
from: UserId,
admin_id: UserId,
for_user: UserId,
role: ChannelRole,
) -> Result<()> {
self.transaction(|tx| async move {
self.check_user_is_channel_admin(channel_id, from, &*tx)
self.check_user_is_channel_admin(channel_id, admin_id, &*tx)
.await?;
let result = channel_member::Entity::update_many()
@ -613,43 +637,147 @@ impl Database {
Ok(user_ids)
}
pub async fn check_user_is_channel_member(
&self,
channel_id: ChannelId,
user_id: UserId,
tx: &DatabaseTransaction,
) -> Result<()> {
let channel_ids = self.get_channel_ancestors(channel_id, tx).await?;
channel_member::Entity::find()
.filter(
channel_member::Column::ChannelId
.is_in(channel_ids)
.and(channel_member::Column::UserId.eq(user_id)),
)
.one(&*tx)
.await?
.ok_or_else(|| anyhow!("user is not a channel member or channel does not exist"))?;
Ok(())
}
pub async fn check_user_is_channel_admin(
&self,
channel_id: ChannelId,
user_id: UserId,
tx: &DatabaseTransaction,
) -> Result<()> {
match self.channel_role_for_user(channel_id, user_id, tx).await? {
Some(ChannelRole::Admin) => Ok(()),
Some(ChannelRole::Member)
| Some(ChannelRole::Banned)
| Some(ChannelRole::Guest)
| None => Err(anyhow!(
"user is not a channel admin or channel does not exist"
))?,
}
}
pub async fn check_user_is_channel_member(
&self,
channel_id: ChannelId,
user_id: UserId,
tx: &DatabaseTransaction,
) -> Result<()> {
match self.channel_role_for_user(channel_id, user_id, tx).await? {
Some(ChannelRole::Admin) | Some(ChannelRole::Member) => Ok(()),
Some(ChannelRole::Banned) | Some(ChannelRole::Guest) | None => Err(anyhow!(
"user is not a channel member or channel does not exist"
))?,
}
}
pub async fn check_user_is_channel_participant(
&self,
channel_id: ChannelId,
user_id: UserId,
tx: &DatabaseTransaction,
) -> Result<()> {
match self.channel_role_for_user(channel_id, user_id, tx).await? {
Some(ChannelRole::Admin) | Some(ChannelRole::Member) | Some(ChannelRole::Guest) => {
Ok(())
}
Some(ChannelRole::Banned) | None => Err(anyhow!(
"user is not a channel participant or channel does not exist"
))?,
}
}
pub async fn channel_role_for_user(
&self,
channel_id: ChannelId,
user_id: UserId,
tx: &DatabaseTransaction,
) -> Result<Option<ChannelRole>> {
let channel_ids = self.get_channel_ancestors(channel_id, tx).await?;
channel_member::Entity::find()
#[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
enum QueryChannelMembership {
ChannelId,
Role,
Admin,
Visibility,
}
let mut rows = channel_member::Entity::find()
.left_join(channel::Entity)
.filter(
channel_member::Column::ChannelId
.is_in(channel_ids)
.and(channel_member::Column::UserId.eq(user_id))
.and(channel_member::Column::Admin.eq(true)),
.and(channel_member::Column::UserId.eq(user_id)),
)
.one(&*tx)
.await?
.ok_or_else(|| anyhow!("user is not a channel admin or channel does not exist"))?;
Ok(())
.select_only()
.column(channel_member::Column::ChannelId)
.column(channel_member::Column::Role)
.column(channel_member::Column::Admin)
.column(channel::Column::Visibility)
.into_values::<_, QueryChannelMembership>()
.stream(&*tx)
.await?;
let mut is_admin = false;
let mut is_member = false;
let mut is_participant = false;
let mut is_banned = false;
let mut current_channel_visibility = None;
// note these channels are not iterated in any particular order,
// our current logic takes the highest permission available.
while let Some(row) = rows.next().await {
let (ch_id, role, admin, visibility): (
ChannelId,
Option<ChannelRole>,
bool,
ChannelVisibility,
) = row?;
match role {
Some(ChannelRole::Admin) => is_admin = true,
Some(ChannelRole::Member) => is_member = true,
Some(ChannelRole::Guest) => {
if visibility == ChannelVisibility::Public {
is_participant = true
}
}
Some(ChannelRole::Banned) => is_banned = true,
None => {
// rows created from pre-role collab server.
if admin {
is_admin = true
} else {
is_member = true
}
}
}
if channel_id == ch_id {
current_channel_visibility = Some(visibility);
}
}
// free up database connection
drop(rows);
Ok(if is_admin {
Some(ChannelRole::Admin)
} else if is_member {
Some(ChannelRole::Member)
} else if is_banned {
Some(ChannelRole::Banned)
} else if is_participant {
if current_channel_visibility.is_none() {
current_channel_visibility = channel::Entity::find()
.filter(channel::Column::Id.eq(channel_id))
.one(&*tx)
.await?
.map(|channel| channel.visibility);
}
if current_channel_visibility == Some(ChannelVisibility::Public) {
Some(ChannelRole::Guest)
} else {
None
}
} else {
None
})
}
/// Returns the channel ancestors, deepest first

View file

@ -7,7 +7,7 @@ pub struct Model {
#[sea_orm(primary_key)]
pub id: ChannelId,
pub name: String,
pub visbility: ChannelVisibility,
pub visibility: ChannelVisibility,
}
impl ActiveModelBehavior for ActiveModel {}

View file

@ -8,11 +8,14 @@ use crate::{
db::{
queries::channels::ChannelGraph,
tests::{graph, TEST_RELEASE_CHANNEL},
ChannelId, ChannelRole, Database, NewUserParams,
ChannelId, ChannelRole, Database, NewUserParams, UserId,
},
test_both_dbs,
};
use std::sync::Arc;
use std::sync::{
atomic::{AtomicI32, Ordering},
Arc,
};
test_both_dbs!(test_channels, test_channels_postgres, test_channels_sqlite);
@ -850,6 +853,101 @@ async fn test_db_channel_moving_bugs(db: &Arc<Database>) {
);
}
test_both_dbs!(
test_user_is_channel_participant,
test_user_is_channel_participant_postgres,
test_user_is_channel_participant_sqlite
);
async fn test_user_is_channel_participant(db: &Arc<Database>) {
let admin_id = new_test_user(db, "admin@example.com").await;
let member_id = new_test_user(db, "member@example.com").await;
let guest_id = new_test_user(db, "guest@example.com").await;
let zed_id = db.create_root_channel("zed", admin_id).await.unwrap();
let intermediate_id = db
.create_channel("active", Some(zed_id), admin_id)
.await
.unwrap();
let public_id = db
.create_channel("active", Some(intermediate_id), admin_id)
.await
.unwrap();
db.set_channel_visibility(public_id, crate::db::ChannelVisibility::Public, admin_id)
.await
.unwrap();
db.invite_channel_member(intermediate_id, member_id, admin_id, ChannelRole::Member)
.await
.unwrap();
db.invite_channel_member(public_id, guest_id, admin_id, ChannelRole::Guest)
.await
.unwrap();
db.transaction(|tx| async move {
db.check_user_is_channel_participant(public_id, admin_id, &*tx)
.await
})
.await
.unwrap();
db.transaction(|tx| async move {
db.check_user_is_channel_participant(public_id, member_id, &*tx)
.await
})
.await
.unwrap();
db.transaction(|tx| async move {
db.check_user_is_channel_participant(public_id, guest_id, &*tx)
.await
})
.await
.unwrap();
db.set_channel_member_role(public_id, admin_id, guest_id, ChannelRole::Banned)
.await
.unwrap();
assert!(db
.transaction(|tx| async move {
db.check_user_is_channel_participant(public_id, guest_id, &*tx)
.await
})
.await
.is_err());
db.remove_channel_member(public_id, guest_id, admin_id)
.await
.unwrap();
db.set_channel_visibility(zed_id, crate::db::ChannelVisibility::Public, admin_id)
.await
.unwrap();
db.invite_channel_member(zed_id, guest_id, admin_id, ChannelRole::Guest)
.await
.unwrap();
db.transaction(|tx| async move {
db.check_user_is_channel_participant(zed_id, guest_id, &*tx)
.await
})
.await
.unwrap();
assert!(db
.transaction(|tx| async move {
db.check_user_is_channel_participant(intermediate_id, guest_id, &*tx)
.await
})
.await
.is_err(),);
db.transaction(|tx| async move {
db.check_user_is_channel_participant(public_id, guest_id, &*tx)
.await
})
.await
.unwrap();
}
#[track_caller]
fn assert_dag(actual: ChannelGraph, expected: &[(ChannelId, Option<ChannelId>)]) {
let mut actual_map: HashMap<ChannelId, HashSet<ChannelId>> = HashMap::default();
@ -874,3 +972,22 @@ fn assert_dag(actual: ChannelGraph, expected: &[(ChannelId, Option<ChannelId>)])
pretty_assertions::assert_eq!(actual_map, expected_map)
}
static GITHUB_USER_ID: AtomicI32 = AtomicI32::new(5);
async fn new_test_user(db: &Arc<Database>, email: &str) -> UserId {
let gid = GITHUB_USER_ID.fetch_add(1, Ordering::SeqCst);
db.create_user(
email,
false,
NewUserParams {
github_login: email[0..email.find("@").unwrap()].to_string(),
github_user_id: GITHUB_USER_ID.fetch_add(1, Ordering::SeqCst),
invite_count: 0,
},
)
.await
.unwrap()
.user_id
}

View file

@ -6,7 +6,10 @@ use call::ActiveCall;
use channel::{ChannelId, ChannelMembership, ChannelStore};
use client::User;
use gpui::{executor::Deterministic, ModelHandle, TestAppContext};
use rpc::{proto, RECEIVE_TIMEOUT};
use rpc::{
proto::{self},
RECEIVE_TIMEOUT,
};
use std::sync::Arc;
#[gpui::test]

View file

@ -1040,6 +1040,7 @@ enum ChannelRole {
Admin = 0;
Member = 1;
Guest = 2;
Banned = 3;
}
message SetChannelMemberRole {