From 3631fbd874a0bf3b788d8b5f7581fa7fb0e51171 Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Thu, 19 Aug 2021 12:17:52 -0700 Subject: [PATCH] Consolidate server's rpc state into the rpc::Server struct Co-Authored-By: Nathan Sobo --- server/src/auth.rs | 41 +- server/src/main.rs | 4 +- server/src/rpc.rs | 1130 +++++++++++++++++++++---------------------- server/src/tests.rs | 3 +- 4 files changed, 569 insertions(+), 609 deletions(-) diff --git a/server/src/auth.rs b/server/src/auth.rs index d61428fa37..5a3e301d27 100644 --- a/server/src/auth.rs +++ b/server/src/auth.rs @@ -2,7 +2,7 @@ use super::{ db::{self, UserId}, errors::TideResultExt, }; -use crate::{github, rpc, AppState, Request, RequestExt as _}; +use crate::{github, AppState, Request, RequestExt as _}; use anyhow::{anyhow, Context}; use async_trait::async_trait; pub use oauth2::basic::BasicClient as Client; @@ -19,7 +19,7 @@ use serde::{Deserialize, Serialize}; use std::{borrow::Cow, convert::TryFrom, sync::Arc}; use surf::Url; use tide::Server; -use zrpc::{auth as zed_auth, proto, Peer}; +use zrpc::auth as zed_auth; static CURRENT_GITHUB_USER: &'static str = "current_github_user"; static GITHUB_AUTH_URL: &'static str = "https://github.com/login/oauth/authorize"; @@ -100,43 +100,6 @@ impl RequestExt for Request { } } -#[async_trait] -pub trait PeerExt { - async fn sign_out( - self: &Arc, - connection_id: zrpc::ConnectionId, - state: &AppState, - ) -> tide::Result<()>; -} - -#[async_trait] -impl PeerExt for Peer { - async fn sign_out( - self: &Arc, - connection_id: zrpc::ConnectionId, - state: &AppState, - ) -> tide::Result<()> { - self.disconnect(connection_id).await; - let worktree_ids = state.rpc.write().await.remove_connection(connection_id); - for worktree_id in worktree_ids { - let state = state.rpc.read().await; - if let Some(worktree) = state.worktrees.get(&worktree_id) { - rpc::broadcast(connection_id, worktree.connection_ids(), |conn_id| { - self.send( - conn_id, - proto::RemovePeer { - worktree_id, - peer_id: connection_id.0, - }, - ) - }) - .await?; - } - } - Ok(()) - } -} - pub fn build_client(client_id: &str, client_secret: &str) -> Client { Client::new( ClientId::new(client_id.to_string()), diff --git a/server/src/main.rs b/server/src/main.rs index ec153bea8f..b98c8b0d04 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -14,7 +14,7 @@ mod tests; use self::errors::TideResultExt as _; use anyhow::{Context, Result}; -use async_std::{net::TcpListener, sync::RwLock as AsyncRwLock}; +use async_std::net::TcpListener; use async_trait::async_trait; use auth::RequestExt as _; use db::{Db, DbOptions}; @@ -51,7 +51,6 @@ pub struct AppState { auth_client: auth::Client, github_client: Arc, repo_client: github::RepoClient, - rpc: AsyncRwLock, config: Config, } @@ -76,7 +75,6 @@ impl AppState { auth_client: auth::build_client(&config.github_client_id, &config.github_client_secret), github_client, repo_client, - rpc: Default::default(), config, }; this.register_partials(); diff --git a/server/src/rpc.rs b/server/src/rpc.rs index 77a50aceac..d107e3606b 100644 --- a/server/src/rpc.rs +++ b/server/src/rpc.rs @@ -1,10 +1,10 @@ use super::{ - auth::{self, PeerExt as _}, + auth, db::{ChannelId, UserId}, AppState, }; use anyhow::anyhow; -use async_std::task; +use async_std::{sync::RwLock, task}; use async_tungstenite::{ tungstenite::{protocol::Role, Error as WebSocketError, Message as WebSocketMessage}, WebSocketStream, @@ -13,7 +13,7 @@ use futures::{future::BoxFuture, FutureExt}; use postage::prelude::Stream as _; use 
sha1::{Digest as _, Sha1}; use std::{ - any::{Any, TypeId}, + any::TypeId, collections::{HashMap, HashSet}, future::Future, mem, @@ -38,51 +38,90 @@ type ReplicaId = u16; type MessageHandler = Box< dyn Send + Sync - + Fn(Box, Arc) -> BoxFuture<'static, tide::Result<()>>, + + Fn(Arc, Box) -> BoxFuture<'static, tide::Result<()>>, >; -#[derive(Default)] -struct ServerBuilder { +pub struct Server { + peer: Arc, + state: RwLock, + app_state: Arc, handlers: HashMap, } -impl ServerBuilder { - pub fn on_message(mut self, handler: F) -> Self +#[derive(Default)] +struct ServerState { + connections: HashMap, + pub worktrees: HashMap, + channels: HashMap, + next_worktree_id: u64, +} + +struct Connection { + user_id: UserId, + worktrees: HashSet, + channels: HashSet, +} + +struct Worktree { + host_connection_id: Option, + guest_connection_ids: HashMap, + active_replica_ids: HashSet, + access_token: String, + root_name: String, + entries: HashMap, +} + +#[derive(Default)] +struct Channel { + connection_ids: HashSet, +} + +impl Server { + pub fn new(app_state: Arc, peer: Arc) -> Arc { + let mut server = Server { + peer, + app_state, + state: Default::default(), + handlers: Default::default(), + }; + + server + .add_handler(Server::share_worktree) + .add_handler(Server::join_worktree) + .add_handler(Server::update_worktree) + .add_handler(Server::close_worktree) + .add_handler(Server::open_buffer) + .add_handler(Server::close_buffer) + .add_handler(Server::update_buffer) + .add_handler(Server::buffer_saved) + .add_handler(Server::save_buffer) + .add_handler(Server::get_channels) + .add_handler(Server::get_users) + .add_handler(Server::join_channel) + .add_handler(Server::send_channel_message); + + Arc::new(server) + } + + fn add_handler(&mut self, handler: F) -> &mut Self where - F: 'static + Send + Sync + Fn(Box>, Arc) -> Fut, + F: 'static + Send + Sync + Fn(Arc, TypedEnvelope) -> Fut, Fut: 'static + Send + Future>, M: EnvelopedMessage, { let prev_handler = self.handlers.insert( TypeId::of::(), - Box::new(move |envelope, server| { + Box::new(move |server, envelope| { let envelope = envelope.into_any().downcast::>().unwrap(); - (handler)(envelope, server).boxed() + (handler)(server, *envelope).boxed() }), ); if prev_handler.is_some() { panic!("registered a handler for the same message twice"); } - self } - pub fn build(self, rpc: &Arc, state: &Arc) -> Arc { - Arc::new(Server { - rpc: rpc.clone(), - state: state.clone(), - handlers: self.handlers, - }) - } -} - -pub struct Server { - rpc: Arc, - state: Arc, - handlers: HashMap, -} - -impl Server { pub fn handle_connection( self: &Arc, connection: Conn, @@ -99,12 +138,8 @@ impl Server { let this = self.clone(); async move { let (connection_id, handle_io, mut incoming_rx) = - this.rpc.add_connection(connection).await; - this.state - .rpc - .write() - .await - .add_connection(connection_id, user_id); + this.peer.add_connection(connection).await; + this.add_connection(connection_id, user_id).await; let handle_io = handle_io.fuse(); futures::pin_mut!(handle_io); @@ -117,7 +152,7 @@ impl Server { let start_time = Instant::now(); log::info!("RPC message received: {}", message.payload_type_name()); if let Some(handler) = this.handlers.get(&message.payload_type_id()) { - if let Err(err) = (handler)(message, this.clone()).await { + if let Err(err) = (handler)(this.clone(), message).await { log::error!("error handling message: {:?}", err); } else { log::info!("RPC message handled. 
duration:{:?}", start_time.elapsed()); @@ -139,67 +174,36 @@ impl Server { } } - if let Err(err) = this.rpc.sign_out(connection_id, &this.state).await { + if let Err(err) = this.sign_out(connection_id).await { log::error!("error signing out connection {:?} - {:?}", addr, err); } } } -} -#[derive(Default)] -pub struct State { - connections: HashMap, - pub worktrees: HashMap, - channels: HashMap, - next_worktree_id: u64, -} - -struct Connection { - user_id: UserId, - worktrees: HashSet, - channels: HashSet, -} - -pub struct Worktree { - host_connection_id: Option, - guest_connection_ids: HashMap, - active_replica_ids: HashSet, - access_token: String, - root_name: String, - entries: HashMap, -} - -#[derive(Default)] -struct Channel { - connection_ids: HashSet, -} - -impl Worktree { - pub fn connection_ids(&self) -> Vec { - self.guest_connection_ids - .keys() - .copied() - .chain(self.host_connection_id) - .collect() + async fn sign_out(self: &Arc, connection_id: zrpc::ConnectionId) -> tide::Result<()> { + self.peer.disconnect(connection_id).await; + let worktree_ids = self.remove_connection(connection_id).await; + for worktree_id in worktree_ids { + let state = self.state.read().await; + if let Some(worktree) = state.worktrees.get(&worktree_id) { + broadcast(connection_id, worktree.connection_ids(), |conn_id| { + self.peer.send( + conn_id, + proto::RemovePeer { + worktree_id, + peer_id: connection_id.0, + }, + ) + }) + .await?; + } + } + Ok(()) } - fn host_connection_id(&self) -> tide::Result { - Ok(self - .host_connection_id - .ok_or_else(|| anyhow!("host disconnected from worktree"))?) - } -} - -impl Channel { - fn connection_ids(&self) -> Vec { - self.connection_ids.iter().copied().collect() - } -} - -impl State { // Add a new connection associated with a given user. - pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId) { - self.connections.insert( + async fn add_connection(&self, connection_id: ConnectionId, user_id: UserId) { + self.state.write().await.connections.insert( connection_id, Connection { user_id, @@ -210,16 +214,17 @@ impl State { } // Remove the given connection and its association with any worktrees. 
- pub fn remove_connection(&mut self, connection_id: ConnectionId) -> Vec { + async fn remove_connection(&self, connection_id: ConnectionId) -> Vec { let mut worktree_ids = Vec::new(); - if let Some(connection) = self.connections.remove(&connection_id) { + let mut state = self.state.write().await; + if let Some(connection) = state.connections.remove(&connection_id) { for channel_id in connection.channels { - if let Some(channel) = self.channels.get_mut(&channel_id) { + if let Some(channel) = state.channels.get_mut(&channel_id) { channel.connection_ids.remove(&connection_id); } } for worktree_id in connection.worktrees { - if let Some(worktree) = self.worktrees.get_mut(&worktree_id) { + if let Some(worktree) = state.worktrees.get_mut(&worktree_id) { if worktree.host_connection_id == Some(connection_id) { worktree_ids.push(worktree_id); } else if let Some(replica_id) = @@ -234,6 +239,444 @@ impl State { worktree_ids } + async fn share_worktree( + self: Arc, + mut request: TypedEnvelope, + ) -> tide::Result<()> { + let mut state = self.state.write().await; + let worktree_id = state.next_worktree_id; + state.next_worktree_id += 1; + let access_token = random_token(); + let worktree = request + .payload + .worktree + .as_mut() + .ok_or_else(|| anyhow!("missing worktree"))?; + let entries = mem::take(&mut worktree.entries) + .into_iter() + .map(|entry| (entry.id, entry)) + .collect(); + state.worktrees.insert( + worktree_id, + Worktree { + host_connection_id: Some(request.sender_id), + guest_connection_ids: Default::default(), + active_replica_ids: Default::default(), + access_token: access_token.clone(), + root_name: mem::take(&mut worktree.root_name), + entries, + }, + ); + + self.peer + .respond( + request.receipt(), + proto::ShareWorktreeResponse { + worktree_id, + access_token, + }, + ) + .await?; + Ok(()) + } + + async fn join_worktree( + self: Arc, + request: TypedEnvelope, + ) -> tide::Result<()> { + let worktree_id = request.payload.worktree_id; + let access_token = &request.payload.access_token; + + let mut state = self.state.write().await; + if let Some((peer_replica_id, worktree)) = + state.join_worktree(request.sender_id, worktree_id, access_token) + { + let mut peers = Vec::new(); + if let Some(host_connection_id) = worktree.host_connection_id { + peers.push(proto::Peer { + peer_id: host_connection_id.0, + replica_id: 0, + }); + } + for (peer_conn_id, peer_replica_id) in &worktree.guest_connection_ids { + if *peer_conn_id != request.sender_id { + peers.push(proto::Peer { + peer_id: peer_conn_id.0, + replica_id: *peer_replica_id as u32, + }); + } + } + + broadcast(request.sender_id, worktree.connection_ids(), |conn_id| { + self.peer.send( + conn_id, + proto::AddPeer { + worktree_id, + peer: Some(proto::Peer { + peer_id: request.sender_id.0, + replica_id: peer_replica_id as u32, + }), + }, + ) + }) + .await?; + self.peer + .respond( + request.receipt(), + proto::OpenWorktreeResponse { + worktree_id, + worktree: Some(proto::Worktree { + root_name: worktree.root_name.clone(), + entries: worktree.entries.values().cloned().collect(), + }), + replica_id: peer_replica_id as u32, + peers, + }, + ) + .await?; + } else { + self.peer + .respond( + request.receipt(), + proto::OpenWorktreeResponse { + worktree_id, + worktree: None, + replica_id: 0, + peers: Vec::new(), + }, + ) + .await?; + } + + Ok(()) + } + + async fn update_worktree( + self: Arc, + request: TypedEnvelope, + ) -> tide::Result<()> { + { + let mut state = self.state.write().await; + let worktree = 
state.write_worktree(request.payload.worktree_id, request.sender_id)?; + for entry_id in &request.payload.removed_entries { + worktree.entries.remove(&entry_id); + } + + for entry in &request.payload.updated_entries { + worktree.entries.insert(entry.id, entry.clone()); + } + } + + self.broadcast_in_worktree(request.payload.worktree_id, &request) + .await?; + Ok(()) + } + + async fn close_worktree( + self: Arc, + request: TypedEnvelope, + ) -> tide::Result<()> { + let connection_ids; + { + let mut state = self.state.write().await; + let worktree = state.write_worktree(request.payload.worktree_id, request.sender_id)?; + connection_ids = worktree.connection_ids(); + if worktree.host_connection_id == Some(request.sender_id) { + worktree.host_connection_id = None; + } else if let Some(replica_id) = + worktree.guest_connection_ids.remove(&request.sender_id) + { + worktree.active_replica_ids.remove(&replica_id); + } + } + + broadcast(request.sender_id, connection_ids, |conn_id| { + self.peer.send( + conn_id, + proto::RemovePeer { + worktree_id: request.payload.worktree_id, + peer_id: request.sender_id.0, + }, + ) + }) + .await?; + + Ok(()) + } + + async fn open_buffer( + self: Arc, + request: TypedEnvelope, + ) -> tide::Result<()> { + let receipt = request.receipt(); + let worktree_id = request.payload.worktree_id; + let host_connection_id = self + .state + .read() + .await + .read_worktree(worktree_id, request.sender_id)? + .host_connection_id()?; + + let response = self + .peer + .forward_request(request.sender_id, host_connection_id, request.payload) + .await?; + self.peer.respond(receipt, response).await?; + Ok(()) + } + + async fn close_buffer( + self: Arc, + request: TypedEnvelope, + ) -> tide::Result<()> { + let host_connection_id = self + .state + .read() + .await + .read_worktree(request.payload.worktree_id, request.sender_id)? 
+ .host_connection_id()?; + + self.peer + .forward_send(request.sender_id, host_connection_id, request.payload) + .await?; + + Ok(()) + } + + async fn save_buffer( + self: Arc, + request: TypedEnvelope, + ) -> tide::Result<()> { + let host; + let guests; + { + let state = self.state.read().await; + let worktree = state.read_worktree(request.payload.worktree_id, request.sender_id)?; + host = worktree.host_connection_id()?; + guests = worktree + .guest_connection_ids + .keys() + .copied() + .collect::>(); + } + + let sender = request.sender_id; + let receipt = request.receipt(); + let response = self + .peer + .forward_request(sender, host, request.payload.clone()) + .await?; + + broadcast(host, guests, |conn_id| { + let response = response.clone(); + let peer = &self.peer; + async move { + if conn_id == sender { + peer.respond(receipt, response).await + } else { + peer.forward_send(host, conn_id, response).await + } + } + }) + .await?; + + Ok(()) + } + + async fn update_buffer( + self: Arc, + request: TypedEnvelope, + ) -> tide::Result<()> { + self.broadcast_in_worktree(request.payload.worktree_id, &request) + .await + } + + async fn buffer_saved( + self: Arc, + request: TypedEnvelope, + ) -> tide::Result<()> { + self.broadcast_in_worktree(request.payload.worktree_id, &request) + .await + } + + async fn get_channels( + self: Arc, + request: TypedEnvelope, + ) -> tide::Result<()> { + let user_id = self + .state + .read() + .await + .user_id_for_connection(request.sender_id)?; + let channels = self.app_state.db.get_channels_for_user(user_id).await?; + self.peer + .respond( + request.receipt(), + proto::GetChannelsResponse { + channels: channels + .into_iter() + .map(|chan| proto::Channel { + id: chan.id.to_proto(), + name: chan.name, + }) + .collect(), + }, + ) + .await?; + Ok(()) + } + + async fn get_users( + self: Arc, + request: TypedEnvelope, + ) -> tide::Result<()> { + let user_id = self + .state + .read() + .await + .user_id_for_connection(request.sender_id)?; + let receipt = request.receipt(); + let user_ids = request.payload.user_ids.into_iter().map(UserId::from_proto); + let users = self + .app_state + .db + .get_users_by_ids(user_id, user_ids) + .await? + .into_iter() + .map(|user| proto::User { + id: user.id.to_proto(), + github_login: user.github_login, + avatar_url: String::new(), + }) + .collect(); + self.peer + .respond(receipt, proto::GetUsersResponse { users }) + .await?; + Ok(()) + } + + async fn join_channel( + self: Arc, + request: TypedEnvelope, + ) -> tide::Result<()> { + let user_id = self + .state + .read() + .await + .user_id_for_connection(request.sender_id)?; + let channel_id = ChannelId::from_proto(request.payload.channel_id); + if !self + .app_state + .db + .can_user_access_channel(user_id, channel_id) + .await? + { + Err(anyhow!("access denied"))?; + } + + self.state + .write() + .await + .join_channel(request.sender_id, channel_id); + let messages = self + .app_state + .db + .get_recent_channel_messages(channel_id, 50) + .await? 
+ .into_iter() + .map(|msg| proto::ChannelMessage { + id: msg.id.to_proto(), + body: msg.body, + timestamp: msg.sent_at.unix_timestamp() as u64, + sender_id: msg.sender_id.to_proto(), + }) + .collect(); + self.peer + .respond(request.receipt(), proto::JoinChannelResponse { messages }) + .await?; + Ok(()) + } + + async fn send_channel_message( + self: Arc, + request: TypedEnvelope, + ) -> tide::Result<()> { + let channel_id = ChannelId::from_proto(request.payload.channel_id); + let user_id; + let connection_ids; + { + let state = self.state.read().await; + user_id = state.user_id_for_connection(request.sender_id)?; + if let Some(channel) = state.channels.get(&channel_id) { + connection_ids = channel.connection_ids(); + } else { + return Ok(()); + } + } + + let timestamp = OffsetDateTime::now_utc(); + let message_id = self + .app_state + .db + .create_channel_message(channel_id, user_id, &request.payload.body, timestamp) + .await?; + let message = proto::ChannelMessageSent { + channel_id: channel_id.to_proto(), + message: Some(proto::ChannelMessage { + sender_id: user_id.to_proto(), + id: message_id.to_proto(), + body: request.payload.body, + timestamp: timestamp.unix_timestamp() as u64, + }), + }; + broadcast(request.sender_id, connection_ids, |conn_id| { + self.peer.send(conn_id, message.clone()) + }) + .await?; + + Ok(()) + } + + async fn broadcast_in_worktree( + &self, + worktree_id: u64, + message: &TypedEnvelope, + ) -> tide::Result<()> { + let connection_ids = self + .state + .read() + .await + .read_worktree(worktree_id, message.sender_id)? + .connection_ids(); + + broadcast(message.sender_id, connection_ids, |conn_id| { + self.peer + .forward_send(message.sender_id, conn_id, message.payload.clone()) + }) + .await?; + + Ok(()) + } +} + +pub async fn broadcast( + sender_id: ConnectionId, + receiver_ids: Vec, + mut f: F, +) -> anyhow::Result<()> +where + F: FnMut(ConnectionId) -> T, + T: Future>, +{ + let futures = receiver_ids + .into_iter() + .filter(|id| *id != sender_id) + .map(|id| f(id)); + futures::future::try_join_all(futures).await?; + Ok(()) +} + +impl ServerState { fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) { if let Some(connection) = self.connections.get_mut(&connection_id) { connection.channels.insert(channel_id); @@ -245,8 +688,16 @@ impl State { } } + fn user_id_for_connection(&self, connection_id: ConnectionId) -> tide::Result { + Ok(self + .connections + .get(&connection_id) + .ok_or_else(|| anyhow!("unknown connection"))? + .user_id) + } + // Add the given connection as a guest of the given worktree - pub fn join_worktree( + fn join_worktree( &mut self, connection_id: ConnectionId, worktree_id: u64, @@ -275,14 +726,6 @@ impl State { } } - fn user_id_for_connection(&self, connection_id: ConnectionId) -> tide::Result { - Ok(self - .connections - .get(&connection_id) - .ok_or_else(|| anyhow!("unknown connection"))? 
- .user_id) - } - fn read_worktree( &self, worktree_id: u64, @@ -330,26 +773,30 @@ impl State { } } -pub fn build_server(state: &Arc, rpc: &Arc) -> Arc { - ServerBuilder::default() - .on_message(share_worktree) - .on_message(join_worktree) - .on_message(update_worktree) - .on_message(close_worktree) - .on_message(open_buffer) - .on_message(close_buffer) - .on_message(update_buffer) - .on_message(buffer_saved) - .on_message(save_buffer) - .on_message(get_channels) - .on_message(get_users) - .on_message(join_channel) - .on_message(send_channel_message) - .build(rpc, state) +impl Worktree { + pub fn connection_ids(&self) -> Vec { + self.guest_connection_ids + .keys() + .copied() + .chain(self.host_connection_id) + .collect() + } + + fn host_connection_id(&self) -> tide::Result { + Ok(self + .host_connection_id + .ok_or_else(|| anyhow!("host disconnected from worktree"))?) + } +} + +impl Channel { + fn connection_ids(&self) -> Vec { + self.connection_ids.iter().copied().collect() + } } pub fn add_routes(app: &mut tide::Server>, rpc: &Arc) { - let server = build_server(app.state(), rpc); + let server = Server::new(app.state().clone(), rpc.clone()); app.at("/rpc").with(auth::VerifyToken).get(move |request: Request>| { let user_id = request.ext::().copied(); let server = server.clone(); @@ -392,453 +839,6 @@ pub fn add_routes(app: &mut tide::Server>, rpc: &Arc) { }); } -async fn share_worktree( - mut request: Box>, - server: Arc, -) -> tide::Result<()> { - let mut state = server.state.rpc.write().await; - let worktree_id = state.next_worktree_id; - state.next_worktree_id += 1; - let access_token = random_token(); - let worktree = request - .payload - .worktree - .as_mut() - .ok_or_else(|| anyhow!("missing worktree"))?; - let entries = mem::take(&mut worktree.entries) - .into_iter() - .map(|entry| (entry.id, entry)) - .collect(); - state.worktrees.insert( - worktree_id, - Worktree { - host_connection_id: Some(request.sender_id), - guest_connection_ids: Default::default(), - active_replica_ids: Default::default(), - access_token: access_token.clone(), - root_name: mem::take(&mut worktree.root_name), - entries, - }, - ); - - server - .rpc - .respond( - request.receipt(), - proto::ShareWorktreeResponse { - worktree_id, - access_token, - }, - ) - .await?; - Ok(()) -} - -async fn join_worktree( - request: Box>, - server: Arc, -) -> tide::Result<()> { - let worktree_id = request.payload.worktree_id; - let access_token = &request.payload.access_token; - - let mut state = server.state.rpc.write().await; - if let Some((peer_replica_id, worktree)) = - state.join_worktree(request.sender_id, worktree_id, access_token) - { - let mut peers = Vec::new(); - if let Some(host_connection_id) = worktree.host_connection_id { - peers.push(proto::Peer { - peer_id: host_connection_id.0, - replica_id: 0, - }); - } - for (peer_conn_id, peer_replica_id) in &worktree.guest_connection_ids { - if *peer_conn_id != request.sender_id { - peers.push(proto::Peer { - peer_id: peer_conn_id.0, - replica_id: *peer_replica_id as u32, - }); - } - } - - broadcast(request.sender_id, worktree.connection_ids(), |conn_id| { - server.rpc.send( - conn_id, - proto::AddPeer { - worktree_id, - peer: Some(proto::Peer { - peer_id: request.sender_id.0, - replica_id: peer_replica_id as u32, - }), - }, - ) - }) - .await?; - server - .rpc - .respond( - request.receipt(), - proto::OpenWorktreeResponse { - worktree_id, - worktree: Some(proto::Worktree { - root_name: worktree.root_name.clone(), - entries: worktree.entries.values().cloned().collect(), - }), 
- replica_id: peer_replica_id as u32, - peers, - }, - ) - .await?; - } else { - server - .rpc - .respond( - request.receipt(), - proto::OpenWorktreeResponse { - worktree_id, - worktree: None, - replica_id: 0, - peers: Vec::new(), - }, - ) - .await?; - } - - Ok(()) -} - -async fn update_worktree( - request: Box>, - server: Arc, -) -> tide::Result<()> { - { - let mut state = server.state.rpc.write().await; - let worktree = state.write_worktree(request.payload.worktree_id, request.sender_id)?; - for entry_id in &request.payload.removed_entries { - worktree.entries.remove(&entry_id); - } - - for entry in &request.payload.updated_entries { - worktree.entries.insert(entry.id, entry.clone()); - } - } - - broadcast_in_worktree(request.payload.worktree_id, &request, &server).await?; - Ok(()) -} - -async fn close_worktree( - request: Box>, - server: Arc, -) -> tide::Result<()> { - let connection_ids; - { - let mut state = server.state.rpc.write().await; - let worktree = state.write_worktree(request.payload.worktree_id, request.sender_id)?; - connection_ids = worktree.connection_ids(); - if worktree.host_connection_id == Some(request.sender_id) { - worktree.host_connection_id = None; - } else if let Some(replica_id) = worktree.guest_connection_ids.remove(&request.sender_id) { - worktree.active_replica_ids.remove(&replica_id); - } - } - - broadcast(request.sender_id, connection_ids, |conn_id| { - server.rpc.send( - conn_id, - proto::RemovePeer { - worktree_id: request.payload.worktree_id, - peer_id: request.sender_id.0, - }, - ) - }) - .await?; - - Ok(()) -} - -async fn open_buffer( - request: Box>, - server: Arc, -) -> tide::Result<()> { - let receipt = request.receipt(); - let worktree_id = request.payload.worktree_id; - let host_connection_id = server - .state - .rpc - .read() - .await - .read_worktree(worktree_id, request.sender_id)? - .host_connection_id()?; - - let response = server - .rpc - .forward_request(request.sender_id, host_connection_id, request.payload) - .await?; - server.rpc.respond(receipt, response).await?; - Ok(()) -} - -async fn close_buffer( - request: Box>, - server: Arc, -) -> tide::Result<()> { - let host_connection_id = server - .state - .rpc - .read() - .await - .read_worktree(request.payload.worktree_id, request.sender_id)? 
- .host_connection_id()?; - - server - .rpc - .forward_send(request.sender_id, host_connection_id, request.payload) - .await?; - - Ok(()) -} - -async fn save_buffer( - request: Box>, - server: Arc, -) -> tide::Result<()> { - let host; - let guests; - { - let state = server.state.rpc.read().await; - let worktree = state.read_worktree(request.payload.worktree_id, request.sender_id)?; - host = worktree.host_connection_id()?; - guests = worktree - .guest_connection_ids - .keys() - .copied() - .collect::>(); - } - - let sender = request.sender_id; - let receipt = request.receipt(); - let response = server - .rpc - .forward_request(sender, host, request.payload.clone()) - .await?; - - broadcast(host, guests, |conn_id| { - let response = response.clone(); - let server = &server; - async move { - if conn_id == sender { - server.rpc.respond(receipt, response).await - } else { - server.rpc.forward_send(host, conn_id, response).await - } - } - }) - .await?; - - Ok(()) -} - -async fn update_buffer( - request: Box>, - server: Arc, -) -> tide::Result<()> { - broadcast_in_worktree(request.payload.worktree_id, &request, &server).await -} - -async fn buffer_saved( - request: Box>, - server: Arc, -) -> tide::Result<()> { - broadcast_in_worktree(request.payload.worktree_id, &request, &server).await -} - -async fn get_channels( - request: Box>, - server: Arc, -) -> tide::Result<()> { - let user_id = server - .state - .rpc - .read() - .await - .user_id_for_connection(request.sender_id)?; - let channels = server.state.db.get_channels_for_user(user_id).await?; - server - .rpc - .respond( - request.receipt(), - proto::GetChannelsResponse { - channels: channels - .into_iter() - .map(|chan| proto::Channel { - id: chan.id.to_proto(), - name: chan.name, - }) - .collect(), - }, - ) - .await?; - Ok(()) -} - -async fn get_users( - request: Box>, - server: Arc, -) -> tide::Result<()> { - let user_id = server - .state - .rpc - .read() - .await - .user_id_for_connection(request.sender_id)?; - let receipt = request.receipt(); - let user_ids = request.payload.user_ids.into_iter().map(UserId::from_proto); - let users = server - .state - .db - .get_users_by_ids(user_id, user_ids) - .await? - .into_iter() - .map(|user| proto::User { - id: user.id.to_proto(), - github_login: user.github_login, - avatar_url: String::new(), - }) - .collect(); - server - .rpc - .respond(receipt, proto::GetUsersResponse { users }) - .await?; - Ok(()) -} - -async fn join_channel( - request: Box>, - server: Arc, -) -> tide::Result<()> { - let user_id = server - .state - .rpc - .read() - .await - .user_id_for_connection(request.sender_id)?; - let channel_id = ChannelId::from_proto(request.payload.channel_id); - if !server - .state - .db - .can_user_access_channel(user_id, channel_id) - .await? - { - Err(anyhow!("access denied"))?; - } - - server - .state - .rpc - .write() - .await - .join_channel(request.sender_id, channel_id); - let messages = server - .state - .db - .get_recent_channel_messages(channel_id, 50) - .await? 
- .into_iter() - .map(|msg| proto::ChannelMessage { - id: msg.id.to_proto(), - body: msg.body, - timestamp: msg.sent_at.unix_timestamp() as u64, - sender_id: msg.sender_id.to_proto(), - }) - .collect(); - server - .rpc - .respond(request.receipt(), proto::JoinChannelResponse { messages }) - .await?; - Ok(()) -} - -async fn send_channel_message( - request: Box>, - server: Arc, -) -> tide::Result<()> { - let channel_id = ChannelId::from_proto(request.payload.channel_id); - let user_id; - let connection_ids; - { - let state = server.state.rpc.read().await; - user_id = state.user_id_for_connection(request.sender_id)?; - if let Some(channel) = state.channels.get(&channel_id) { - connection_ids = channel.connection_ids(); - } else { - return Ok(()); - } - } - - let timestamp = OffsetDateTime::now_utc(); - let message_id = server - .state - .db - .create_channel_message(channel_id, user_id, &request.payload.body, timestamp) - .await?; - let message = proto::ChannelMessageSent { - channel_id: channel_id.to_proto(), - message: Some(proto::ChannelMessage { - sender_id: user_id.to_proto(), - id: message_id.to_proto(), - body: request.payload.body, - timestamp: timestamp.unix_timestamp() as u64, - }), - }; - broadcast(request.sender_id, connection_ids, |conn_id| { - server.rpc.send(conn_id, message.clone()) - }) - .await?; - - Ok(()) -} - -async fn broadcast_in_worktree( - worktree_id: u64, - request: &TypedEnvelope, - server: &Arc, -) -> tide::Result<()> { - let connection_ids = server - .state - .rpc - .read() - .await - .read_worktree(worktree_id, request.sender_id)? - .connection_ids(); - - broadcast(request.sender_id, connection_ids, |conn_id| { - server - .rpc - .forward_send(request.sender_id, conn_id, request.payload.clone()) - }) - .await?; - - Ok(()) -} - -pub async fn broadcast( - sender_id: ConnectionId, - receiver_ids: Vec, - mut f: F, -) -> anyhow::Result<()> -where - F: FnMut(ConnectionId) -> T, - T: Future>, -{ - let futures = receiver_ids - .into_iter() - .filter(|id| *id != sender_id) - .map(|id| f(id)); - futures::future::try_join_all(futures).await?; - Ok(()) -} - fn header_contains_ignore_case( request: &tide::Request, header_name: HeaderName, diff --git a/server/src/tests.rs b/server/src/tests.rs index 5df19aa530..1cc7bf18ba 100644 --- a/server/src/tests.rs +++ b/server/src/tests.rs @@ -540,7 +540,7 @@ impl TestServer { let db_name = format!("zed-test-{}", rng.gen::()); let app_state = Self::build_app_state(&db_name).await; let peer = Peer::new(); - let server = rpc::build_server(&app_state, &peer); + let server = rpc::Server::new(app_state.clone(), peer.clone()); Self { peer, app_state, @@ -595,7 +595,6 @@ impl TestServer { auth_client: auth::build_client("", ""), repo_client: github::RepoClient::test(&github_client), github_client, - rpc: Default::default(), config, }) }
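
For readers skimming the patch, a minimal, self-contained sketch of the dispatch technique that `Server::add_handler` relies on may help: handlers are keyed by the message's `TypeId`, stored type-erased, and downcast back to their concrete type when a message arrives. This is only an illustration of the pattern, not the Zed code itself; the `Server`, `Handler`, and `dispatch` names below are placeholders, and the envelope, receipt, and `tide::Result` plumbing of the real server is elided.

```rust
use futures::{future::BoxFuture, FutureExt};
use std::{
    any::{Any, TypeId},
    collections::HashMap,
    future::Future,
    sync::Arc,
};

// Type-erased handler: takes the server plus a boxed message of unknown type.
type Handler = Box<
    dyn Send
        + Sync
        + Fn(Arc<Server>, Box<dyn Any + Send>) -> BoxFuture<'static, anyhow::Result<()>>,
>;

#[derive(Default)]
pub struct Server {
    handlers: HashMap<TypeId, Handler>,
}

impl Server {
    // Register a typed handler under the message's TypeId, mirroring the
    // `add_handler` introduced by this patch.
    pub fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(Arc<Server>, M) -> Fut,
        Fut: 'static + Send + Future<Output = anyhow::Result<()>>,
        M: 'static + Send,
    {
        let prev = self.handlers.insert(
            TypeId::of::<M>(),
            Box::new(move |server, message| {
                // Safe to unwrap: a message is only ever stored under its own TypeId.
                let message = message.downcast::<M>().unwrap();
                handler(server, *message).boxed()
            }),
        );
        assert!(prev.is_none(), "registered two handlers for one message type");
        self
    }

    // Look up the handler for an incoming message and await it.
    pub async fn dispatch(
        self: &Arc<Self>,
        type_id: TypeId,
        message: Box<dyn Any + Send>,
    ) -> anyhow::Result<()> {
        match self.handlers.get(&type_id) {
            Some(handler) => handler(self.clone(), message).await,
            None => Err(anyhow::anyhow!("no handler registered for message")),
        }
    }
}
```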
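The other half of the consolidation is where the mutable bookkeeping lives: the `AsyncRwLock<State>` that used to hang off `AppState` becomes a private `RwLock<ServerState>` owned by the `Server`, and each handler holds the guard only for the in-memory mutation before doing any network I/O. Below is a minimal sketch of that pattern, assuming async-std (which the server already uses); the `allocate_worktree_id` helper is hypothetical and just mirrors the id-allocation step at the top of `share_worktree`.

```rust
use async_std::sync::RwLock;

#[derive(Default)]
struct ServerState {
    next_worktree_id: u64,
}

struct Server {
    // Connection/worktree/channel bookkeeping now lives behind the Server's own lock.
    state: RwLock<ServerState>,
}

impl Server {
    async fn allocate_worktree_id(&self) -> u64 {
        // Write guard held only long enough to take the id; it is released when
        // the function returns, before the caller does any network sends.
        let mut state = self.state.write().await;
        let id = state.next_worktree_id;
        state.next_worktree_id += 1;
        id
    }
}
```

Wiring stays simple for callers, as the tests.rs hunk shows: build a `Peer`, hand it together with the shared `Arc<AppState>` to `rpc::Server::new`, and drop the old `rpc` field from `AppState`.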