Always use the database to retrieve collaborators for a project

This commit is contained in:
Antonio Scandurra 2022-11-15 17:49:37 +01:00
parent e9eadcaa6a
commit ad67f5e4de
3 changed files with 160 additions and 100 deletions

View file

@ -1886,6 +1886,64 @@ where
.await .await
} }
pub async fn project_collaborators(
&self,
project_id: ProjectId,
connection_id: ConnectionId,
) -> Result<Vec<ProjectCollaborator>> {
self.transact(|mut tx| async move {
let collaborators = sqlx::query_as::<_, ProjectCollaborator>(
"
SELECT *
FROM project_collaborators
WHERE project_id = $1
",
)
.bind(project_id)
.fetch_all(&mut tx)
.await?;
if collaborators
.iter()
.any(|collaborator| collaborator.connection_id == connection_id.0 as i32)
{
Ok(collaborators)
} else {
Err(anyhow!("no such project"))?
}
})
.await
}
pub async fn project_connection_ids(
&self,
project_id: ProjectId,
connection_id: ConnectionId,
) -> Result<HashSet<ConnectionId>> {
self.transact(|mut tx| async move {
let connection_ids = sqlx::query_scalar::<_, i32>(
"
SELECT connection_id
FROM project_collaborators
WHERE project_id = $1
",
)
.bind(project_id)
.fetch_all(&mut tx)
.await?;
if connection_ids.contains(&(connection_id.0 as i32)) {
Ok(connection_ids
.into_iter()
.map(|connection_id| ConnectionId(connection_id as u32))
.collect())
} else {
Err(anyhow!("no such project"))?
}
})
.await
}
pub async fn unshare_project(&self, project_id: ProjectId) -> Result<()> { pub async fn unshare_project(&self, project_id: ProjectId) -> Result<()> {
todo!() todo!()
// test_support!(self, { // test_support!(self, {

View file

@ -1187,13 +1187,15 @@ impl Server {
self: Arc<Server>, self: Arc<Server>,
request: Message<proto::UpdateLanguageServer>, request: Message<proto::UpdateLanguageServer>,
) -> Result<()> { ) -> Result<()> {
let receiver_ids = self.store().await.project_connection_ids( let project_id = ProjectId::from_proto(request.payload.project_id);
ProjectId::from_proto(request.payload.project_id), let project_connection_ids = self
request.sender_connection_id, .app_state
)?; .db
.project_connection_ids(project_id, request.sender_connection_id)
.await?;
broadcast( broadcast(
request.sender_connection_id, request.sender_connection_id,
receiver_ids, project_connection_ids,
|connection_id| { |connection_id| {
self.peer.forward_send( self.peer.forward_send(
request.sender_connection_id, request.sender_connection_id,
@ -1214,25 +1216,25 @@ impl Server {
T: EntityMessage + RequestMessage, T: EntityMessage + RequestMessage,
{ {
let project_id = ProjectId::from_proto(request.payload.remote_entity_id()); let project_id = ProjectId::from_proto(request.payload.remote_entity_id());
let host_connection_id = self let collaborators = self
.store() .app_state
.await .db
.read_project(project_id, request.sender_connection_id)? .project_collaborators(project_id, request.sender_connection_id)
.host_connection_id; .await?;
let host = collaborators
.iter()
.find(|collaborator| collaborator.is_host)
.ok_or_else(|| anyhow!("host not found"))?;
let payload = self let payload = self
.peer .peer
.forward_request( .forward_request(
request.sender_connection_id, request.sender_connection_id,
host_connection_id, ConnectionId(host.connection_id as u32),
request.payload, request.payload,
) )
.await?; .await?;
// Ensure project still exists by the time we get the response from the host.
self.store()
.await
.read_project(project_id, request.sender_connection_id)?;
response.send(payload)?; response.send(payload)?;
Ok(()) Ok(())
} }
@ -1243,25 +1245,39 @@ impl Server {
response: Response<proto::SaveBuffer>, response: Response<proto::SaveBuffer>,
) -> Result<()> { ) -> Result<()> {
let project_id = ProjectId::from_proto(request.payload.project_id); let project_id = ProjectId::from_proto(request.payload.project_id);
let host = self let collaborators = self
.store() .app_state
.await .db
.read_project(project_id, request.sender_connection_id)? .project_collaborators(project_id, request.sender_connection_id)
.host_connection_id; .await?;
let host = collaborators
.into_iter()
.find(|collaborator| collaborator.is_host)
.ok_or_else(|| anyhow!("host not found"))?;
let host_connection_id = ConnectionId(host.connection_id as u32);
let response_payload = self let response_payload = self
.peer .peer
.forward_request(request.sender_connection_id, host, request.payload.clone()) .forward_request(
request.sender_connection_id,
host_connection_id,
request.payload.clone(),
)
.await?; .await?;
let mut guests = self let mut collaborators = self
.store() .app_state
.await .db
.read_project(project_id, request.sender_connection_id)? .project_collaborators(project_id, request.sender_connection_id)
.connection_ids(); .await?;
guests.retain(|guest_connection_id| *guest_connection_id != request.sender_connection_id); collaborators.retain(|collaborator| {
broadcast(host, guests, |conn_id| { collaborator.connection_id != request.sender_connection_id.0 as i32
});
let project_connection_ids = collaborators
.into_iter()
.map(|collaborator| ConnectionId(collaborator.connection_id as u32));
broadcast(host_connection_id, project_connection_ids, |conn_id| {
self.peer self.peer
.forward_send(host, conn_id, response_payload.clone()) .forward_send(host_connection_id, conn_id, response_payload.clone())
}); });
response.send(response_payload)?; response.send(response_payload)?;
Ok(()) Ok(())
@ -1285,14 +1301,15 @@ impl Server {
response: Response<proto::UpdateBuffer>, response: Response<proto::UpdateBuffer>,
) -> Result<()> { ) -> Result<()> {
let project_id = ProjectId::from_proto(request.payload.project_id); let project_id = ProjectId::from_proto(request.payload.project_id);
let receiver_ids = { let project_connection_ids = self
let store = self.store().await; .app_state
store.project_connection_ids(project_id, request.sender_connection_id)? .db
}; .project_connection_ids(project_id, request.sender_connection_id)
.await?;
broadcast( broadcast(
request.sender_connection_id, request.sender_connection_id,
receiver_ids, project_connection_ids,
|connection_id| { |connection_id| {
self.peer.forward_send( self.peer.forward_send(
request.sender_connection_id, request.sender_connection_id,
@ -1309,13 +1326,16 @@ impl Server {
self: Arc<Server>, self: Arc<Server>,
request: Message<proto::UpdateBufferFile>, request: Message<proto::UpdateBufferFile>,
) -> Result<()> { ) -> Result<()> {
let receiver_ids = self.store().await.project_connection_ids( let project_id = ProjectId::from_proto(request.payload.project_id);
ProjectId::from_proto(request.payload.project_id), let project_connection_ids = self
request.sender_connection_id, .app_state
)?; .db
.project_connection_ids(project_id, request.sender_connection_id)
.await?;
broadcast( broadcast(
request.sender_connection_id, request.sender_connection_id,
receiver_ids, project_connection_ids,
|connection_id| { |connection_id| {
self.peer.forward_send( self.peer.forward_send(
request.sender_connection_id, request.sender_connection_id,
@ -1331,13 +1351,15 @@ impl Server {
self: Arc<Server>, self: Arc<Server>,
request: Message<proto::BufferReloaded>, request: Message<proto::BufferReloaded>,
) -> Result<()> { ) -> Result<()> {
let receiver_ids = self.store().await.project_connection_ids( let project_id = ProjectId::from_proto(request.payload.project_id);
ProjectId::from_proto(request.payload.project_id), let project_connection_ids = self
request.sender_connection_id, .app_state
)?; .db
.project_connection_ids(project_id, request.sender_connection_id)
.await?;
broadcast( broadcast(
request.sender_connection_id, request.sender_connection_id,
receiver_ids, project_connection_ids,
|connection_id| { |connection_id| {
self.peer.forward_send( self.peer.forward_send(
request.sender_connection_id, request.sender_connection_id,
@ -1350,13 +1372,15 @@ impl Server {
} }
async fn buffer_saved(self: Arc<Server>, request: Message<proto::BufferSaved>) -> Result<()> { async fn buffer_saved(self: Arc<Server>, request: Message<proto::BufferSaved>) -> Result<()> {
let receiver_ids = self.store().await.project_connection_ids( let project_id = ProjectId::from_proto(request.payload.project_id);
ProjectId::from_proto(request.payload.project_id), let project_connection_ids = self
request.sender_connection_id, .app_state
)?; .db
.project_connection_ids(project_id, request.sender_connection_id)
.await?;
broadcast( broadcast(
request.sender_connection_id, request.sender_connection_id,
receiver_ids, project_connection_ids,
|connection_id| { |connection_id| {
self.peer.forward_send( self.peer.forward_send(
request.sender_connection_id, request.sender_connection_id,
@ -1376,14 +1400,14 @@ impl Server {
let project_id = ProjectId::from_proto(request.payload.project_id); let project_id = ProjectId::from_proto(request.payload.project_id);
let leader_id = ConnectionId(request.payload.leader_id); let leader_id = ConnectionId(request.payload.leader_id);
let follower_id = request.sender_connection_id; let follower_id = request.sender_connection_id;
{ let project_connection_ids = self
let store = self.store().await; .app_state
if !store .db
.project_connection_ids(project_id, follower_id)? .project_connection_ids(project_id, request.sender_connection_id)
.contains(&leader_id) .await?;
{
Err(anyhow!("no such peer"))?; if !project_connection_ids.contains(&leader_id) {
} Err(anyhow!("no such peer"))?;
} }
let mut response_payload = self let mut response_payload = self
@ -1400,11 +1424,12 @@ impl Server {
async fn unfollow(self: Arc<Self>, request: Message<proto::Unfollow>) -> Result<()> { async fn unfollow(self: Arc<Self>, request: Message<proto::Unfollow>) -> Result<()> {
let project_id = ProjectId::from_proto(request.payload.project_id); let project_id = ProjectId::from_proto(request.payload.project_id);
let leader_id = ConnectionId(request.payload.leader_id); let leader_id = ConnectionId(request.payload.leader_id);
let store = self.store().await; let project_connection_ids = self
if !store .app_state
.project_connection_ids(project_id, request.sender_connection_id)? .db
.contains(&leader_id) .project_connection_ids(project_id, request.sender_connection_id)
{ .await?;
if !project_connection_ids.contains(&leader_id) {
Err(anyhow!("no such peer"))?; Err(anyhow!("no such peer"))?;
} }
self.peer self.peer
@ -1417,9 +1442,12 @@ impl Server {
request: Message<proto::UpdateFollowers>, request: Message<proto::UpdateFollowers>,
) -> Result<()> { ) -> Result<()> {
let project_id = ProjectId::from_proto(request.payload.project_id); let project_id = ProjectId::from_proto(request.payload.project_id);
let store = self.store().await; let project_connection_ids = self
let connection_ids = .app_state
store.project_connection_ids(project_id, request.sender_connection_id)?; .db
.project_connection_ids(project_id, request.sender_connection_id)
.await?;
let leader_id = request let leader_id = request
.payload .payload
.variant .variant
@ -1431,7 +1459,7 @@ impl Server {
}); });
for follower_id in &request.payload.follower_ids { for follower_id in &request.payload.follower_ids {
let follower_id = ConnectionId(*follower_id); let follower_id = ConnectionId(*follower_id);
if connection_ids.contains(&follower_id) && Some(follower_id.0) != leader_id { if project_connection_ids.contains(&follower_id) && Some(follower_id.0) != leader_id {
self.peer.forward_send( self.peer.forward_send(
request.sender_connection_id, request.sender_connection_id,
follower_id, follower_id,
@ -1629,13 +1657,15 @@ impl Server {
self: Arc<Server>, self: Arc<Server>,
request: Message<proto::UpdateDiffBase>, request: Message<proto::UpdateDiffBase>,
) -> Result<()> { ) -> Result<()> {
let receiver_ids = self.store().await.project_connection_ids( let project_id = ProjectId::from_proto(request.payload.project_id);
ProjectId::from_proto(request.payload.project_id), let project_connection_ids = self
request.sender_connection_id, .app_state
)?; .db
.project_connection_ids(project_id, request.sender_connection_id)
.await?;
broadcast( broadcast(
request.sender_connection_id, request.sender_connection_id,
receiver_ids, project_connection_ids,
|connection_id| { |connection_id| {
self.peer.forward_send( self.peer.forward_send(
request.sender_connection_id, request.sender_connection_id,

View file

@ -325,34 +325,6 @@ impl Store {
}) })
} }
pub fn project_connection_ids(
&self,
project_id: ProjectId,
acting_connection_id: ConnectionId,
) -> Result<Vec<ConnectionId>> {
Ok(self
.read_project(project_id, acting_connection_id)?
.connection_ids())
}
pub fn read_project(
&self,
project_id: ProjectId,
connection_id: ConnectionId,
) -> Result<&Project> {
let project = self
.projects
.get(&project_id)
.ok_or_else(|| anyhow!("no such project"))?;
if project.host_connection_id == connection_id
|| project.guests.contains_key(&connection_id)
{
Ok(project)
} else {
Err(anyhow!("no such project"))?
}
}
#[cfg(test)] #[cfg(test)]
pub fn check_invariants(&self) { pub fn check_invariants(&self) {
for (connection_id, connection) in &self.connections { for (connection_id, connection) in &self.connections {