mirror of
https://github.com/zed-industries/zed.git
synced 2025-02-03 17:44:30 +00:00
Always use the database to retrieve collaborators for a project
This commit is contained in:
parent
e9eadcaa6a
commit
ad67f5e4de
3 changed files with 160 additions and 100 deletions
|
@ -1886,6 +1886,64 @@ where
|
||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub async fn project_collaborators(
|
||||||
|
&self,
|
||||||
|
project_id: ProjectId,
|
||||||
|
connection_id: ConnectionId,
|
||||||
|
) -> Result<Vec<ProjectCollaborator>> {
|
||||||
|
self.transact(|mut tx| async move {
|
||||||
|
let collaborators = sqlx::query_as::<_, ProjectCollaborator>(
|
||||||
|
"
|
||||||
|
SELECT *
|
||||||
|
FROM project_collaborators
|
||||||
|
WHERE project_id = $1
|
||||||
|
",
|
||||||
|
)
|
||||||
|
.bind(project_id)
|
||||||
|
.fetch_all(&mut tx)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
if collaborators
|
||||||
|
.iter()
|
||||||
|
.any(|collaborator| collaborator.connection_id == connection_id.0 as i32)
|
||||||
|
{
|
||||||
|
Ok(collaborators)
|
||||||
|
} else {
|
||||||
|
Err(anyhow!("no such project"))?
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn project_connection_ids(
|
||||||
|
&self,
|
||||||
|
project_id: ProjectId,
|
||||||
|
connection_id: ConnectionId,
|
||||||
|
) -> Result<HashSet<ConnectionId>> {
|
||||||
|
self.transact(|mut tx| async move {
|
||||||
|
let connection_ids = sqlx::query_scalar::<_, i32>(
|
||||||
|
"
|
||||||
|
SELECT connection_id
|
||||||
|
FROM project_collaborators
|
||||||
|
WHERE project_id = $1
|
||||||
|
",
|
||||||
|
)
|
||||||
|
.bind(project_id)
|
||||||
|
.fetch_all(&mut tx)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
if connection_ids.contains(&(connection_id.0 as i32)) {
|
||||||
|
Ok(connection_ids
|
||||||
|
.into_iter()
|
||||||
|
.map(|connection_id| ConnectionId(connection_id as u32))
|
||||||
|
.collect())
|
||||||
|
} else {
|
||||||
|
Err(anyhow!("no such project"))?
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
|
||||||
pub async fn unshare_project(&self, project_id: ProjectId) -> Result<()> {
|
pub async fn unshare_project(&self, project_id: ProjectId) -> Result<()> {
|
||||||
todo!()
|
todo!()
|
||||||
// test_support!(self, {
|
// test_support!(self, {
|
||||||
|
|
|
@ -1187,13 +1187,15 @@ impl Server {
|
||||||
self: Arc<Server>,
|
self: Arc<Server>,
|
||||||
request: Message<proto::UpdateLanguageServer>,
|
request: Message<proto::UpdateLanguageServer>,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let receiver_ids = self.store().await.project_connection_ids(
|
let project_id = ProjectId::from_proto(request.payload.project_id);
|
||||||
ProjectId::from_proto(request.payload.project_id),
|
let project_connection_ids = self
|
||||||
request.sender_connection_id,
|
.app_state
|
||||||
)?;
|
.db
|
||||||
|
.project_connection_ids(project_id, request.sender_connection_id)
|
||||||
|
.await?;
|
||||||
broadcast(
|
broadcast(
|
||||||
request.sender_connection_id,
|
request.sender_connection_id,
|
||||||
receiver_ids,
|
project_connection_ids,
|
||||||
|connection_id| {
|
|connection_id| {
|
||||||
self.peer.forward_send(
|
self.peer.forward_send(
|
||||||
request.sender_connection_id,
|
request.sender_connection_id,
|
||||||
|
@ -1214,25 +1216,25 @@ impl Server {
|
||||||
T: EntityMessage + RequestMessage,
|
T: EntityMessage + RequestMessage,
|
||||||
{
|
{
|
||||||
let project_id = ProjectId::from_proto(request.payload.remote_entity_id());
|
let project_id = ProjectId::from_proto(request.payload.remote_entity_id());
|
||||||
let host_connection_id = self
|
let collaborators = self
|
||||||
.store()
|
.app_state
|
||||||
.await
|
.db
|
||||||
.read_project(project_id, request.sender_connection_id)?
|
.project_collaborators(project_id, request.sender_connection_id)
|
||||||
.host_connection_id;
|
.await?;
|
||||||
|
let host = collaborators
|
||||||
|
.iter()
|
||||||
|
.find(|collaborator| collaborator.is_host)
|
||||||
|
.ok_or_else(|| anyhow!("host not found"))?;
|
||||||
|
|
||||||
let payload = self
|
let payload = self
|
||||||
.peer
|
.peer
|
||||||
.forward_request(
|
.forward_request(
|
||||||
request.sender_connection_id,
|
request.sender_connection_id,
|
||||||
host_connection_id,
|
ConnectionId(host.connection_id as u32),
|
||||||
request.payload,
|
request.payload,
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
// Ensure project still exists by the time we get the response from the host.
|
|
||||||
self.store()
|
|
||||||
.await
|
|
||||||
.read_project(project_id, request.sender_connection_id)?;
|
|
||||||
|
|
||||||
response.send(payload)?;
|
response.send(payload)?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -1243,25 +1245,39 @@ impl Server {
|
||||||
response: Response<proto::SaveBuffer>,
|
response: Response<proto::SaveBuffer>,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let project_id = ProjectId::from_proto(request.payload.project_id);
|
let project_id = ProjectId::from_proto(request.payload.project_id);
|
||||||
let host = self
|
let collaborators = self
|
||||||
.store()
|
.app_state
|
||||||
.await
|
.db
|
||||||
.read_project(project_id, request.sender_connection_id)?
|
.project_collaborators(project_id, request.sender_connection_id)
|
||||||
.host_connection_id;
|
.await?;
|
||||||
|
let host = collaborators
|
||||||
|
.into_iter()
|
||||||
|
.find(|collaborator| collaborator.is_host)
|
||||||
|
.ok_or_else(|| anyhow!("host not found"))?;
|
||||||
|
let host_connection_id = ConnectionId(host.connection_id as u32);
|
||||||
let response_payload = self
|
let response_payload = self
|
||||||
.peer
|
.peer
|
||||||
.forward_request(request.sender_connection_id, host, request.payload.clone())
|
.forward_request(
|
||||||
|
request.sender_connection_id,
|
||||||
|
host_connection_id,
|
||||||
|
request.payload.clone(),
|
||||||
|
)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let mut guests = self
|
let mut collaborators = self
|
||||||
.store()
|
.app_state
|
||||||
.await
|
.db
|
||||||
.read_project(project_id, request.sender_connection_id)?
|
.project_collaborators(project_id, request.sender_connection_id)
|
||||||
.connection_ids();
|
.await?;
|
||||||
guests.retain(|guest_connection_id| *guest_connection_id != request.sender_connection_id);
|
collaborators.retain(|collaborator| {
|
||||||
broadcast(host, guests, |conn_id| {
|
collaborator.connection_id != request.sender_connection_id.0 as i32
|
||||||
|
});
|
||||||
|
let project_connection_ids = collaborators
|
||||||
|
.into_iter()
|
||||||
|
.map(|collaborator| ConnectionId(collaborator.connection_id as u32));
|
||||||
|
broadcast(host_connection_id, project_connection_ids, |conn_id| {
|
||||||
self.peer
|
self.peer
|
||||||
.forward_send(host, conn_id, response_payload.clone())
|
.forward_send(host_connection_id, conn_id, response_payload.clone())
|
||||||
});
|
});
|
||||||
response.send(response_payload)?;
|
response.send(response_payload)?;
|
||||||
Ok(())
|
Ok(())
|
||||||
|
@ -1285,14 +1301,15 @@ impl Server {
|
||||||
response: Response<proto::UpdateBuffer>,
|
response: Response<proto::UpdateBuffer>,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let project_id = ProjectId::from_proto(request.payload.project_id);
|
let project_id = ProjectId::from_proto(request.payload.project_id);
|
||||||
let receiver_ids = {
|
let project_connection_ids = self
|
||||||
let store = self.store().await;
|
.app_state
|
||||||
store.project_connection_ids(project_id, request.sender_connection_id)?
|
.db
|
||||||
};
|
.project_connection_ids(project_id, request.sender_connection_id)
|
||||||
|
.await?;
|
||||||
|
|
||||||
broadcast(
|
broadcast(
|
||||||
request.sender_connection_id,
|
request.sender_connection_id,
|
||||||
receiver_ids,
|
project_connection_ids,
|
||||||
|connection_id| {
|
|connection_id| {
|
||||||
self.peer.forward_send(
|
self.peer.forward_send(
|
||||||
request.sender_connection_id,
|
request.sender_connection_id,
|
||||||
|
@ -1309,13 +1326,16 @@ impl Server {
|
||||||
self: Arc<Server>,
|
self: Arc<Server>,
|
||||||
request: Message<proto::UpdateBufferFile>,
|
request: Message<proto::UpdateBufferFile>,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let receiver_ids = self.store().await.project_connection_ids(
|
let project_id = ProjectId::from_proto(request.payload.project_id);
|
||||||
ProjectId::from_proto(request.payload.project_id),
|
let project_connection_ids = self
|
||||||
request.sender_connection_id,
|
.app_state
|
||||||
)?;
|
.db
|
||||||
|
.project_connection_ids(project_id, request.sender_connection_id)
|
||||||
|
.await?;
|
||||||
|
|
||||||
broadcast(
|
broadcast(
|
||||||
request.sender_connection_id,
|
request.sender_connection_id,
|
||||||
receiver_ids,
|
project_connection_ids,
|
||||||
|connection_id| {
|
|connection_id| {
|
||||||
self.peer.forward_send(
|
self.peer.forward_send(
|
||||||
request.sender_connection_id,
|
request.sender_connection_id,
|
||||||
|
@ -1331,13 +1351,15 @@ impl Server {
|
||||||
self: Arc<Server>,
|
self: Arc<Server>,
|
||||||
request: Message<proto::BufferReloaded>,
|
request: Message<proto::BufferReloaded>,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let receiver_ids = self.store().await.project_connection_ids(
|
let project_id = ProjectId::from_proto(request.payload.project_id);
|
||||||
ProjectId::from_proto(request.payload.project_id),
|
let project_connection_ids = self
|
||||||
request.sender_connection_id,
|
.app_state
|
||||||
)?;
|
.db
|
||||||
|
.project_connection_ids(project_id, request.sender_connection_id)
|
||||||
|
.await?;
|
||||||
broadcast(
|
broadcast(
|
||||||
request.sender_connection_id,
|
request.sender_connection_id,
|
||||||
receiver_ids,
|
project_connection_ids,
|
||||||
|connection_id| {
|
|connection_id| {
|
||||||
self.peer.forward_send(
|
self.peer.forward_send(
|
||||||
request.sender_connection_id,
|
request.sender_connection_id,
|
||||||
|
@ -1350,13 +1372,15 @@ impl Server {
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn buffer_saved(self: Arc<Server>, request: Message<proto::BufferSaved>) -> Result<()> {
|
async fn buffer_saved(self: Arc<Server>, request: Message<proto::BufferSaved>) -> Result<()> {
|
||||||
let receiver_ids = self.store().await.project_connection_ids(
|
let project_id = ProjectId::from_proto(request.payload.project_id);
|
||||||
ProjectId::from_proto(request.payload.project_id),
|
let project_connection_ids = self
|
||||||
request.sender_connection_id,
|
.app_state
|
||||||
)?;
|
.db
|
||||||
|
.project_connection_ids(project_id, request.sender_connection_id)
|
||||||
|
.await?;
|
||||||
broadcast(
|
broadcast(
|
||||||
request.sender_connection_id,
|
request.sender_connection_id,
|
||||||
receiver_ids,
|
project_connection_ids,
|
||||||
|connection_id| {
|
|connection_id| {
|
||||||
self.peer.forward_send(
|
self.peer.forward_send(
|
||||||
request.sender_connection_id,
|
request.sender_connection_id,
|
||||||
|
@ -1376,14 +1400,14 @@ impl Server {
|
||||||
let project_id = ProjectId::from_proto(request.payload.project_id);
|
let project_id = ProjectId::from_proto(request.payload.project_id);
|
||||||
let leader_id = ConnectionId(request.payload.leader_id);
|
let leader_id = ConnectionId(request.payload.leader_id);
|
||||||
let follower_id = request.sender_connection_id;
|
let follower_id = request.sender_connection_id;
|
||||||
{
|
let project_connection_ids = self
|
||||||
let store = self.store().await;
|
.app_state
|
||||||
if !store
|
.db
|
||||||
.project_connection_ids(project_id, follower_id)?
|
.project_connection_ids(project_id, request.sender_connection_id)
|
||||||
.contains(&leader_id)
|
.await?;
|
||||||
{
|
|
||||||
Err(anyhow!("no such peer"))?;
|
if !project_connection_ids.contains(&leader_id) {
|
||||||
}
|
Err(anyhow!("no such peer"))?;
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut response_payload = self
|
let mut response_payload = self
|
||||||
|
@ -1400,11 +1424,12 @@ impl Server {
|
||||||
async fn unfollow(self: Arc<Self>, request: Message<proto::Unfollow>) -> Result<()> {
|
async fn unfollow(self: Arc<Self>, request: Message<proto::Unfollow>) -> Result<()> {
|
||||||
let project_id = ProjectId::from_proto(request.payload.project_id);
|
let project_id = ProjectId::from_proto(request.payload.project_id);
|
||||||
let leader_id = ConnectionId(request.payload.leader_id);
|
let leader_id = ConnectionId(request.payload.leader_id);
|
||||||
let store = self.store().await;
|
let project_connection_ids = self
|
||||||
if !store
|
.app_state
|
||||||
.project_connection_ids(project_id, request.sender_connection_id)?
|
.db
|
||||||
.contains(&leader_id)
|
.project_connection_ids(project_id, request.sender_connection_id)
|
||||||
{
|
.await?;
|
||||||
|
if !project_connection_ids.contains(&leader_id) {
|
||||||
Err(anyhow!("no such peer"))?;
|
Err(anyhow!("no such peer"))?;
|
||||||
}
|
}
|
||||||
self.peer
|
self.peer
|
||||||
|
@ -1417,9 +1442,12 @@ impl Server {
|
||||||
request: Message<proto::UpdateFollowers>,
|
request: Message<proto::UpdateFollowers>,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let project_id = ProjectId::from_proto(request.payload.project_id);
|
let project_id = ProjectId::from_proto(request.payload.project_id);
|
||||||
let store = self.store().await;
|
let project_connection_ids = self
|
||||||
let connection_ids =
|
.app_state
|
||||||
store.project_connection_ids(project_id, request.sender_connection_id)?;
|
.db
|
||||||
|
.project_connection_ids(project_id, request.sender_connection_id)
|
||||||
|
.await?;
|
||||||
|
|
||||||
let leader_id = request
|
let leader_id = request
|
||||||
.payload
|
.payload
|
||||||
.variant
|
.variant
|
||||||
|
@ -1431,7 +1459,7 @@ impl Server {
|
||||||
});
|
});
|
||||||
for follower_id in &request.payload.follower_ids {
|
for follower_id in &request.payload.follower_ids {
|
||||||
let follower_id = ConnectionId(*follower_id);
|
let follower_id = ConnectionId(*follower_id);
|
||||||
if connection_ids.contains(&follower_id) && Some(follower_id.0) != leader_id {
|
if project_connection_ids.contains(&follower_id) && Some(follower_id.0) != leader_id {
|
||||||
self.peer.forward_send(
|
self.peer.forward_send(
|
||||||
request.sender_connection_id,
|
request.sender_connection_id,
|
||||||
follower_id,
|
follower_id,
|
||||||
|
@ -1629,13 +1657,15 @@ impl Server {
|
||||||
self: Arc<Server>,
|
self: Arc<Server>,
|
||||||
request: Message<proto::UpdateDiffBase>,
|
request: Message<proto::UpdateDiffBase>,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let receiver_ids = self.store().await.project_connection_ids(
|
let project_id = ProjectId::from_proto(request.payload.project_id);
|
||||||
ProjectId::from_proto(request.payload.project_id),
|
let project_connection_ids = self
|
||||||
request.sender_connection_id,
|
.app_state
|
||||||
)?;
|
.db
|
||||||
|
.project_connection_ids(project_id, request.sender_connection_id)
|
||||||
|
.await?;
|
||||||
broadcast(
|
broadcast(
|
||||||
request.sender_connection_id,
|
request.sender_connection_id,
|
||||||
receiver_ids,
|
project_connection_ids,
|
||||||
|connection_id| {
|
|connection_id| {
|
||||||
self.peer.forward_send(
|
self.peer.forward_send(
|
||||||
request.sender_connection_id,
|
request.sender_connection_id,
|
||||||
|
|
|
@ -325,34 +325,6 @@ impl Store {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn project_connection_ids(
|
|
||||||
&self,
|
|
||||||
project_id: ProjectId,
|
|
||||||
acting_connection_id: ConnectionId,
|
|
||||||
) -> Result<Vec<ConnectionId>> {
|
|
||||||
Ok(self
|
|
||||||
.read_project(project_id, acting_connection_id)?
|
|
||||||
.connection_ids())
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn read_project(
|
|
||||||
&self,
|
|
||||||
project_id: ProjectId,
|
|
||||||
connection_id: ConnectionId,
|
|
||||||
) -> Result<&Project> {
|
|
||||||
let project = self
|
|
||||||
.projects
|
|
||||||
.get(&project_id)
|
|
||||||
.ok_or_else(|| anyhow!("no such project"))?;
|
|
||||||
if project.host_connection_id == connection_id
|
|
||||||
|| project.guests.contains_key(&connection_id)
|
|
||||||
{
|
|
||||||
Ok(project)
|
|
||||||
} else {
|
|
||||||
Err(anyhow!("no such project"))?
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
pub fn check_invariants(&self) {
|
pub fn check_invariants(&self) {
|
||||||
for (connection_id, connection) in &self.connections {
|
for (connection_id, connection) in &self.connections {
|
||||||
|
|
Loading…
Reference in a new issue