Only allow read-write users to update buffers

This commit is contained in:
Conrad Irwin 2024-01-08 15:39:24 -07:00
parent 18b31f1552
commit d7c5d29237
3 changed files with 69 additions and 21 deletions

View file

@ -148,6 +148,14 @@ impl ChannelRole {
Guest | Banned => false, Guest | Banned => false,
} }
} }
pub fn can_read_projects(&self) -> bool {
use ChannelRole::*;
match self {
Admin | Member | Guest => true,
Banned => false,
}
}
} }
impl From<proto::ChannelRole> for ChannelRole { impl From<proto::ChannelRole> for ChannelRole {

View file

@ -805,6 +805,43 @@ impl Database {
.map(|guard| guard.into_inner()) .map(|guard| guard.into_inner())
} }
pub async fn host_for_read_only_project_request(
&self,
project_id: ProjectId,
connection_id: ConnectionId,
) -> Result<ConnectionId> {
let room_id = self.room_id_for_project(project_id).await?;
self.room_transaction(room_id, |tx| async move {
let current_participant = room_participant::Entity::find()
.filter(room_participant::Column::RoomId.eq(room_id))
.filter(room_participant::Column::AnsweringConnectionId.eq(connection_id.id))
.one(&*tx)
.await?
.ok_or_else(|| anyhow!("no such room"))?;
if !current_participant
.role
.map_or(false, |role| role.can_read_projects())
{
Err(anyhow!("not authorized to read projects"))?;
}
let host = project_collaborator::Entity::find()
.filter(
project_collaborator::Column::ProjectId
.eq(project_id)
.and(project_collaborator::Column::IsHost.eq(true)),
)
.one(&*tx)
.await?
.ok_or_else(|| anyhow!("failed to read project host"))?;
Ok(host.connection())
})
.await
.map(|guard| guard.into_inner())
}
pub async fn host_for_mutating_project_request( pub async fn host_for_mutating_project_request(
&self, &self,
project_id: ProjectId, project_id: ProjectId,
@ -821,8 +858,7 @@ impl Database {
if !current_participant if !current_participant
.role .role
.unwrap_or(ChannelRole::Guest) .map_or(false, |role| role.can_edit_projects())
.can_edit_projects()
{ {
Err(anyhow!("not authorized to edit projects"))?; Err(anyhow!("not authorized to edit projects"))?;
} }
@ -843,13 +879,27 @@ impl Database {
.map(|guard| guard.into_inner()) .map(|guard| guard.into_inner())
} }
pub async fn project_collaborators( pub async fn project_collaborators_for_buffer_update(
&self, &self,
project_id: ProjectId, project_id: ProjectId,
connection_id: ConnectionId, connection_id: ConnectionId,
) -> Result<RoomGuard<Vec<ProjectCollaborator>>> { ) -> Result<RoomGuard<Vec<ProjectCollaborator>>> {
let room_id = self.room_id_for_project(project_id).await?; let room_id = self.room_id_for_project(project_id).await?;
self.room_transaction(room_id, |tx| async move { self.room_transaction(room_id, |tx| async move {
let current_participant = room_participant::Entity::find()
.filter(room_participant::Column::RoomId.eq(room_id))
.filter(room_participant::Column::AnsweringConnectionId.eq(connection_id.id))
.one(&*tx)
.await?
.ok_or_else(|| anyhow!("no such room"))?;
if !current_participant
.role
.map_or(false, |role| role.can_edit_projects())
{
Err(anyhow!("not authorized to edit projects"))?;
}
let collaborators = project_collaborator::Entity::find() let collaborators = project_collaborator::Entity::find()
.filter(project_collaborator::Column::ProjectId.eq(project_id)) .filter(project_collaborator::Column::ProjectId.eq(project_id))
.all(&*tx) .all(&*tx)

View file

@ -227,7 +227,7 @@ impl Server {
.add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>) .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
.add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>) .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
.add_request_handler(forward_read_only_project_request::<proto::InlayHints>) .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
.add_request_handler(forward_mutating_project_request::<proto::OpenBufferByPath>) .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
.add_request_handler(forward_mutating_project_request::<proto::GetCompletions>) .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
.add_request_handler( .add_request_handler(
forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>, forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
@ -1750,24 +1750,15 @@ where
T: EntityMessage + RequestMessage, T: EntityMessage + RequestMessage,
{ {
let project_id = ProjectId::from_proto(request.remote_entity_id()); let project_id = ProjectId::from_proto(request.remote_entity_id());
let host_connection_id = { let host_connection_id = session
let collaborators = session .db()
.db() .await
.await .host_for_read_only_project_request(project_id, session.connection_id)
.project_collaborators(project_id, session.connection_id) .await?;
.await?;
collaborators
.iter()
.find(|collaborator| collaborator.is_host)
.ok_or_else(|| anyhow!("host not found"))?
.connection_id
};
let payload = session let payload = session
.peer .peer
.forward_request(session.connection_id, host_connection_id, request) .forward_request(session.connection_id, host_connection_id, request)
.await?; .await?;
response.send(payload)?; response.send(payload)?;
Ok(()) Ok(())
} }
@ -1786,12 +1777,10 @@ where
.await .await
.host_for_mutating_project_request(project_id, session.connection_id) .host_for_mutating_project_request(project_id, session.connection_id)
.await?; .await?;
let payload = session let payload = session
.peer .peer
.forward_request(session.connection_id, host_connection_id, request) .forward_request(session.connection_id, host_connection_id, request)
.await?; .await?;
response.send(payload)?; response.send(payload)?;
Ok(()) Ok(())
} }
@ -1823,11 +1812,12 @@ async fn update_buffer(
let project_id = ProjectId::from_proto(request.project_id); let project_id = ProjectId::from_proto(request.project_id);
let mut guest_connection_ids; let mut guest_connection_ids;
let mut host_connection_id = None; let mut host_connection_id = None;
{ {
let collaborators = session let collaborators = session
.db() .db()
.await .await
.project_collaborators(project_id, session.connection_id) .project_collaborators_for_buffer_update(project_id, session.connection_id)
.await?; .await?;
guest_connection_ids = Vec::with_capacity(collaborators.len() - 1); guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
for collaborator in collaborators.iter() { for collaborator in collaborators.iter() {