diff --git a/crates/server/src/rpc.rs b/crates/server/src/rpc.rs index a160769d6d..8d406e910c 100644 --- a/crates/server/src/rpc.rs +++ b/crates/server/src/rpc.rs @@ -13,7 +13,7 @@ use futures::{future::BoxFuture, FutureExt, StreamExt}; use parking_lot::{RwLock, RwLockReadGuard, RwLockWriteGuard}; use postage::{mpsc, prelude::Sink as _}; use rpc::{ - proto::{self, AnyTypedEnvelope, EnvelopedMessage}, + proto::{self, AnyTypedEnvelope, EnvelopedMessage, RequestMessage}, Connection, ConnectionId, Peer, TypedEnvelope, }; use sha1::{Digest as _, Sha1}; @@ -43,7 +43,6 @@ pub struct Server { const MESSAGE_COUNT_PER_PAGE: usize = 100; const MAX_MESSAGE_LEN: usize = 1024; -const NO_SUCH_PROJECT: &'static str = "no such project"; impl Server { pub fn new( @@ -60,44 +59,44 @@ impl Server { }; server - .add_handler(Server::ping) - .add_handler(Server::register_project) - .add_handler(Server::unregister_project) - .add_handler(Server::share_project) - .add_handler(Server::unshare_project) - .add_handler(Server::join_project) - .add_handler(Server::leave_project) - .add_handler(Server::register_worktree) - .add_handler(Server::unregister_worktree) - .add_handler(Server::share_worktree) - .add_handler(Server::update_worktree) - .add_handler(Server::update_diagnostic_summary) - .add_handler(Server::disk_based_diagnostics_updating) - .add_handler(Server::disk_based_diagnostics_updated) - .add_handler(Server::get_definition) - .add_handler(Server::open_buffer) - .add_handler(Server::close_buffer) - .add_handler(Server::update_buffer) - .add_handler(Server::update_buffer_file) - .add_handler(Server::buffer_reloaded) - .add_handler(Server::buffer_saved) - .add_handler(Server::save_buffer) - .add_handler(Server::format_buffers) - .add_handler(Server::get_completions) - .add_handler(Server::apply_additional_edits_for_completion) - .add_handler(Server::get_code_actions) - .add_handler(Server::apply_code_action) - .add_handler(Server::get_channels) - .add_handler(Server::get_users) - .add_handler(Server::join_channel) - .add_handler(Server::leave_channel) - .add_handler(Server::send_channel_message) - .add_handler(Server::get_channel_messages); + .add_request_handler(Server::ping) + .add_request_handler(Server::register_project) + .add_message_handler(Server::unregister_project) + .add_request_handler(Server::share_project) + .add_message_handler(Server::unshare_project) + .add_request_handler(Server::join_project) + .add_message_handler(Server::leave_project) + .add_request_handler(Server::register_worktree) + .add_message_handler(Server::unregister_worktree) + .add_request_handler(Server::share_worktree) + .add_message_handler(Server::update_worktree) + .add_message_handler(Server::update_diagnostic_summary) + .add_message_handler(Server::disk_based_diagnostics_updating) + .add_message_handler(Server::disk_based_diagnostics_updated) + .add_request_handler(Server::get_definition) + .add_request_handler(Server::open_buffer) + .add_message_handler(Server::close_buffer) + .add_request_handler(Server::update_buffer) + .add_message_handler(Server::update_buffer_file) + .add_message_handler(Server::buffer_reloaded) + .add_message_handler(Server::buffer_saved) + .add_request_handler(Server::save_buffer) + .add_request_handler(Server::format_buffers) + .add_request_handler(Server::get_completions) + .add_request_handler(Server::apply_additional_edits_for_completion) + .add_request_handler(Server::get_code_actions) + .add_request_handler(Server::apply_code_action) + .add_request_handler(Server::get_channels) + .add_request_handler(Server::get_users) + .add_request_handler(Server::join_channel) + .add_message_handler(Server::leave_channel) + .add_request_handler(Server::send_channel_message) + .add_request_handler(Server::get_channel_messages); Arc::new(server) } - fn add_handler(&mut self, handler: F) -> &mut Self + fn add_message_handler(&mut self, handler: F) -> &mut Self where F: 'static + Send + Sync + Fn(Arc, TypedEnvelope) -> Fut, Fut: 'static + Send + Future>, @@ -116,6 +115,44 @@ impl Server { self } + fn add_request_handler(&mut self, handler: F) -> &mut Self + where + F: 'static + Send + Sync + Fn(Arc, TypedEnvelope) -> Fut, + Fut: 'static + Send + Future>, + M: RequestMessage, + { + let prev_handler = self.handlers.insert( + TypeId::of::(), + Box::new(move |server, envelope| { + let envelope = envelope.into_any().downcast::>().unwrap(); + let receipt = envelope.receipt(); + let response = (handler)(server.clone(), *envelope); + async move { + match response.await { + Ok(response) => { + server.peer.respond(receipt, response)?; + Ok(()) + } + Err(error) => { + server.peer.respond_with_error( + receipt, + proto::Error { + message: error.to_string(), + }, + )?; + Err(error) + } + } + } + .boxed() + }), + ); + if prev_handler.is_some() { + panic!("registered a handler for the same message twice"); + } + self + } + pub fn handle_connection( self: &Arc, connection: Connection, @@ -214,25 +251,20 @@ impl Server { Ok(()) } - async fn ping(self: Arc, request: TypedEnvelope) -> tide::Result<()> { - self.peer.respond(request.receipt(), proto::Ack {})?; - Ok(()) + async fn ping(self: Arc, _: TypedEnvelope) -> tide::Result { + Ok(proto::Ack {}) } async fn register_project( mut self: Arc, request: TypedEnvelope, - ) -> tide::Result<()> { + ) -> tide::Result { let project_id = { let mut state = self.state_mut(); let user_id = state.user_id_for_connection(request.sender_id)?; state.register_project(request.sender_id, user_id) }; - self.peer.respond( - request.receipt(), - proto::RegisterProjectResponse { project_id }, - )?; - Ok(()) + Ok(proto::RegisterProjectResponse { project_id }) } async fn unregister_project( @@ -241,8 +273,7 @@ impl Server { ) -> tide::Result<()> { let project = self .state_mut() - .unregister_project(request.payload.project_id, request.sender_id) - .ok_or_else(|| anyhow!("no such project"))?; + .unregister_project(request.payload.project_id, request.sender_id)?; self.update_contacts_for_users(project.authorized_user_ids().iter())?; Ok(()) } @@ -250,11 +281,10 @@ impl Server { async fn share_project( mut self: Arc, request: TypedEnvelope, - ) -> tide::Result<()> { + ) -> tide::Result { self.state_mut() .share_project(request.payload.project_id, request.sender_id); - self.peer.respond(request.receipt(), proto::Ack {})?; - Ok(()) + Ok(proto::Ack {}) } async fn unshare_project( @@ -277,11 +307,11 @@ impl Server { async fn join_project( mut self: Arc, request: TypedEnvelope, - ) -> tide::Result<()> { + ) -> tide::Result { let project_id = request.payload.project_id; let user_id = self.state().user_id_for_connection(request.sender_id)?; - let response_data = self + let (response, connection_ids, contact_user_ids) = self .state_mut() .join_project(request.sender_id, user_id, project_id) .and_then(|joined| { @@ -328,37 +358,23 @@ impl Server { let connection_ids = joined.project.connection_ids(); let contact_user_ids = joined.project.authorized_user_ids(); Ok((response, connection_ids, contact_user_ids)) - }); + })?; - match response_data { - Ok((response, connection_ids, contact_user_ids)) => { - broadcast(request.sender_id, connection_ids, |conn_id| { - self.peer.send( - conn_id, - proto::AddProjectCollaborator { - project_id, - collaborator: Some(proto::Collaborator { - peer_id: request.sender_id.0, - replica_id: response.replica_id, - user_id: user_id.to_proto(), - }), - }, - ) - })?; - self.peer.respond(request.receipt(), response)?; - self.update_contacts_for_users(&contact_user_ids)?; - } - Err(error) => { - self.peer.respond_with_error( - request.receipt(), - proto::Error { - message: error.to_string(), - }, - )?; - } - } - - Ok(()) + broadcast(request.sender_id, connection_ids, |conn_id| { + self.peer.send( + conn_id, + proto::AddProjectCollaborator { + project_id, + collaborator: Some(proto::Collaborator { + peer_id: request.sender_id.0, + replica_id: response.replica_id, + user_id: user_id.to_proto(), + }), + }, + ) + })?; + self.update_contacts_for_users(&contact_user_ids)?; + Ok(response) } async fn leave_project( @@ -367,70 +383,49 @@ impl Server { ) -> tide::Result<()> { let sender_id = request.sender_id; let project_id = request.payload.project_id; - let worktree = self.state_mut().leave_project(sender_id, project_id); - if let Some(worktree) = worktree { - broadcast(sender_id, worktree.connection_ids, |conn_id| { - self.peer.send( - conn_id, - proto::RemoveProjectCollaborator { - project_id, - peer_id: sender_id.0, - }, - ) - })?; - self.update_contacts_for_users(&worktree.authorized_user_ids)?; - } + let worktree = self.state_mut().leave_project(sender_id, project_id)?; + + broadcast(sender_id, worktree.connection_ids, |conn_id| { + self.peer.send( + conn_id, + proto::RemoveProjectCollaborator { + project_id, + peer_id: sender_id.0, + }, + ) + })?; + self.update_contacts_for_users(&worktree.authorized_user_ids)?; + Ok(()) } async fn register_worktree( mut self: Arc, request: TypedEnvelope, - ) -> tide::Result<()> { - let receipt = request.receipt(); + ) -> tide::Result { let host_user_id = self.state().user_id_for_connection(request.sender_id)?; let mut contact_user_ids = HashSet::default(); contact_user_ids.insert(host_user_id); for github_login in request.payload.authorized_logins { - match self.app_state.db.create_user(&github_login, false).await { - Ok(contact_user_id) => { - contact_user_ids.insert(contact_user_id); - } - Err(err) => { - let message = err.to_string(); - self.peer - .respond_with_error(receipt, proto::Error { message })?; - return Ok(()); - } - } + let contact_user_id = self.app_state.db.create_user(&github_login, false).await?; + contact_user_ids.insert(contact_user_id); } let contact_user_ids = contact_user_ids.into_iter().collect::>(); - let ok = self.state_mut().register_worktree( + self.state_mut().register_worktree( request.payload.project_id, request.payload.worktree_id, + request.sender_id, Worktree { authorized_user_ids: contact_user_ids.clone(), root_name: request.payload.root_name, share: None, weak: false, }, - ); - - if ok { - self.peer.respond(receipt, proto::Ack {})?; - self.update_contacts_for_users(&contact_user_ids)?; - } else { - self.peer.respond_with_error( - receipt, - proto::Error { - message: NO_SUCH_PROJECT.to_string(), - }, - )?; - } - - Ok(()) + )?; + self.update_contacts_for_users(&contact_user_ids)?; + Ok(proto::Ack {}) } async fn unregister_worktree( @@ -458,7 +453,7 @@ impl Server { async fn share_worktree( mut self: Arc, mut request: TypedEnvelope, - ) -> tide::Result<()> { + ) -> tide::Result { let worktree = request .payload .worktree @@ -481,46 +476,32 @@ impl Server { request.sender_id, entries, diagnostic_summaries, - ); - if let Some(shared_worktree) = shared_worktree { - broadcast( - request.sender_id, - shared_worktree.connection_ids, - |connection_id| { - self.peer.forward_send( - request.sender_id, - connection_id, - request.payload.clone(), - ) - }, - )?; - self.peer.respond(request.receipt(), proto::Ack {})?; - self.update_contacts_for_users(&shared_worktree.authorized_user_ids)?; - } else { - self.peer.respond_with_error( - request.receipt(), - proto::Error { - message: "no such worktree".to_string(), - }, - )?; - } - Ok(()) + )?; + + broadcast( + request.sender_id, + shared_worktree.connection_ids, + |connection_id| { + self.peer + .forward_send(request.sender_id, connection_id, request.payload.clone()) + }, + )?; + self.update_contacts_for_users(&shared_worktree.authorized_user_ids)?; + + Ok(proto::Ack {}) } async fn update_worktree( mut self: Arc, request: TypedEnvelope, ) -> tide::Result<()> { - let connection_ids = self - .state_mut() - .update_worktree( - request.sender_id, - request.payload.project_id, - request.payload.worktree_id, - &request.payload.removed_entries, - &request.payload.updated_entries, - ) - .ok_or_else(|| anyhow!("no such worktree"))?; + let connection_ids = self.state_mut().update_worktree( + request.sender_id, + request.payload.project_id, + request.payload.worktree_id, + &request.payload.removed_entries, + &request.payload.updated_entries, + )?; broadcast(request.sender_id, connection_ids, |connection_id| { self.peer @@ -534,19 +515,17 @@ impl Server { mut self: Arc, request: TypedEnvelope, ) -> tide::Result<()> { - let receiver_ids = request + let summary = request .payload .summary .clone() - .and_then(|summary| { - self.state_mut().update_diagnostic_summary( - request.payload.project_id, - request.payload.worktree_id, - request.sender_id, - summary, - ) - }) - .ok_or_else(|| anyhow!(NO_SUCH_PROJECT))?; + .ok_or_else(|| anyhow!("invalid summary"))?; + let receiver_ids = self.state_mut().update_diagnostic_summary( + request.payload.project_id, + request.payload.worktree_id, + request.sender_id, + summary, + )?; broadcast(request.sender_id, receiver_ids, |connection_id| { self.peer @@ -561,8 +540,7 @@ impl Server { ) -> tide::Result<()> { let receiver_ids = self .state() - .project_connection_ids(request.payload.project_id, request.sender_id) - .ok_or_else(|| anyhow!(NO_SUCH_PROJECT))?; + .project_connection_ids(request.payload.project_id, request.sender_id)?; broadcast(request.sender_id, receiver_ids, |connection_id| { self.peer .forward_send(request.sender_id, connection_id, request.payload.clone()) @@ -576,8 +554,7 @@ impl Server { ) -> tide::Result<()> { let receiver_ids = self .state() - .project_connection_ids(request.payload.project_id, request.sender_id) - .ok_or_else(|| anyhow!(NO_SUCH_PROJECT))?; + .project_connection_ids(request.payload.project_id, request.sender_id)?; broadcast(request.sender_id, receiver_ids, |connection_id| { self.peer .forward_send(request.sender_id, connection_id, request.payload.clone()) @@ -588,37 +565,29 @@ impl Server { async fn get_definition( self: Arc, request: TypedEnvelope, - ) -> tide::Result<()> { - let receipt = request.receipt(); + ) -> tide::Result { let host_connection_id = self .state() - .read_project(request.payload.project_id, request.sender_id) - .ok_or_else(|| anyhow!(NO_SUCH_PROJECT))? + .read_project(request.payload.project_id, request.sender_id)? .host_connection_id; - let response = self + Ok(self .peer .forward_request(request.sender_id, host_connection_id, request.payload) - .await?; - self.peer.respond(receipt, response)?; - Ok(()) + .await?) } async fn open_buffer( self: Arc, request: TypedEnvelope, - ) -> tide::Result<()> { - let receipt = request.receipt(); + ) -> tide::Result { let host_connection_id = self .state() - .read_project(request.payload.project_id, request.sender_id) - .ok_or_else(|| anyhow!(NO_SUCH_PROJECT))? + .read_project(request.payload.project_id, request.sender_id)? .host_connection_id; - let response = self + Ok(self .peer .forward_request(request.sender_id, host_connection_id, request.payload) - .await?; - self.peer.respond(receipt, response)?; - Ok(()) + .await?) } async fn close_buffer( @@ -627,8 +596,7 @@ impl Server { ) -> tide::Result<()> { let host_connection_id = self .state() - .read_project(request.payload.project_id, request.sender_id) - .ok_or_else(|| anyhow!(NO_SUCH_PROJECT))? + .read_project(request.payload.project_id, request.sender_id)? .host_connection_id; self.peer .forward_send(request.sender_id, host_connection_id, request.payload)?; @@ -638,207 +606,111 @@ impl Server { async fn save_buffer( self: Arc, request: TypedEnvelope, - ) -> tide::Result<()> { + ) -> tide::Result { let host; - let guests; + let mut guests; { let state = self.state(); - let project = state - .read_project(request.payload.project_id, request.sender_id) - .ok_or_else(|| anyhow!(NO_SUCH_PROJECT))?; + let project = state.read_project(request.payload.project_id, request.sender_id)?; host = project.host_connection_id; guests = project.guest_connection_ids() } - let sender = request.sender_id; - let receipt = request.receipt(); let response = self .peer - .forward_request(sender, host, request.payload.clone()) + .forward_request(request.sender_id, host, request.payload.clone()) .await?; + guests.retain(|guest_connection_id| *guest_connection_id != request.sender_id); broadcast(host, guests, |conn_id| { - let response = response.clone(); - if conn_id == sender { - self.peer.respond(receipt, response) - } else { - self.peer.forward_send(host, conn_id, response) - } + self.peer.forward_send(host, conn_id, response.clone()) })?; - Ok(()) + Ok(response) } async fn format_buffers( self: Arc, request: TypedEnvelope, - ) -> tide::Result<()> { - let host; - { - let state = self.state(); - let project = state - .read_project(request.payload.project_id, request.sender_id) - .ok_or_else(|| anyhow!(NO_SUCH_PROJECT))?; - host = project.host_connection_id; - } - - let sender = request.sender_id; - let receipt = request.receipt(); - match self + ) -> tide::Result { + let host = self + .state() + .read_project(request.payload.project_id, request.sender_id)? + .host_connection_id; + Ok(self .peer - .forward_request(sender, host, request.payload.clone()) - .await - { - Ok(response) => self.peer.respond(receipt, response)?, - Err(error) => self.peer.respond_with_error( - receipt, - proto::Error { - message: error.to_string(), - }, - )?, - } - - Ok(()) + .forward_request(request.sender_id, host, request.payload.clone()) + .await?) } async fn get_completions( self: Arc, request: TypedEnvelope, - ) -> tide::Result<()> { - let host; - { - let state = self.state(); - let project = state - .read_project(request.payload.project_id, request.sender_id) - .ok_or_else(|| anyhow!(NO_SUCH_PROJECT))?; - host = project.host_connection_id; - } - - let sender = request.sender_id; - let receipt = request.receipt(); - match self + ) -> tide::Result { + let host = self + .state() + .read_project(request.payload.project_id, request.sender_id)? + .host_connection_id; + Ok(self .peer - .forward_request(sender, host, request.payload.clone()) - .await - { - Ok(response) => self.peer.respond(receipt, response)?, - Err(error) => self.peer.respond_with_error( - receipt, - proto::Error { - message: error.to_string(), - }, - )?, - } - Ok(()) + .forward_request(request.sender_id, host, request.payload.clone()) + .await?) } async fn apply_additional_edits_for_completion( self: Arc, request: TypedEnvelope, - ) -> tide::Result<()> { - let host; - { - let state = self.state(); - let project = state - .read_project(request.payload.project_id, request.sender_id) - .ok_or_else(|| anyhow!(NO_SUCH_PROJECT))?; - host = project.host_connection_id; - } - - let sender = request.sender_id; - let receipt = request.receipt(); - match self + ) -> tide::Result { + let host = self + .state() + .read_project(request.payload.project_id, request.sender_id)? + .host_connection_id; + Ok(self .peer - .forward_request(sender, host, request.payload.clone()) - .await - { - Ok(response) => self.peer.respond(receipt, response)?, - Err(error) => self.peer.respond_with_error( - receipt, - proto::Error { - message: error.to_string(), - }, - )?, - } - Ok(()) + .forward_request(request.sender_id, host, request.payload.clone()) + .await?) } async fn get_code_actions( self: Arc, request: TypedEnvelope, - ) -> tide::Result<()> { - let host; - { - let state = self.state(); - let project = state - .read_project(request.payload.project_id, request.sender_id) - .ok_or_else(|| anyhow!(NO_SUCH_PROJECT))?; - host = project.host_connection_id; - } - - let sender = request.sender_id; - let receipt = request.receipt(); - match self + ) -> tide::Result { + let host = self + .state() + .read_project(request.payload.project_id, request.sender_id)? + .host_connection_id; + Ok(self .peer - .forward_request(sender, host, request.payload.clone()) - .await - { - Ok(response) => self.peer.respond(receipt, response)?, - Err(error) => self.peer.respond_with_error( - receipt, - proto::Error { - message: error.to_string(), - }, - )?, - } - Ok(()) + .forward_request(request.sender_id, host, request.payload.clone()) + .await?) } async fn apply_code_action( self: Arc, request: TypedEnvelope, - ) -> tide::Result<()> { - let host; - { - let state = self.state(); - let project = state - .read_project(request.payload.project_id, request.sender_id) - .ok_or_else(|| anyhow!(NO_SUCH_PROJECT))?; - host = project.host_connection_id; - } - - let sender = request.sender_id; - let receipt = request.receipt(); - match self + ) -> tide::Result { + let host = self + .state() + .read_project(request.payload.project_id, request.sender_id)? + .host_connection_id; + Ok(self .peer - .forward_request(sender, host, request.payload.clone()) - .await - { - Ok(response) => self.peer.respond(receipt, response)?, - Err(error) => self.peer.respond_with_error( - receipt, - proto::Error { - message: error.to_string(), - }, - )?, - } - Ok(()) + .forward_request(request.sender_id, host, request.payload.clone()) + .await?) } async fn update_buffer( self: Arc, request: TypedEnvelope, - ) -> tide::Result<()> { + ) -> tide::Result { let receiver_ids = self .state() - .project_connection_ids(request.payload.project_id, request.sender_id) - .ok_or_else(|| anyhow!(NO_SUCH_PROJECT))?; + .project_connection_ids(request.payload.project_id, request.sender_id)?; broadcast(request.sender_id, receiver_ids, |connection_id| { self.peer .forward_send(request.sender_id, connection_id, request.payload.clone()) })?; - self.peer.respond(request.receipt(), proto::Ack {})?; - Ok(()) + Ok(proto::Ack {}) } async fn update_buffer_file( @@ -847,8 +719,7 @@ impl Server { ) -> tide::Result<()> { let receiver_ids = self .state() - .project_connection_ids(request.payload.project_id, request.sender_id) - .ok_or_else(|| anyhow!(NO_SUCH_PROJECT))?; + .project_connection_ids(request.payload.project_id, request.sender_id)?; broadcast(request.sender_id, receiver_ids, |connection_id| { self.peer .forward_send(request.sender_id, connection_id, request.payload.clone()) @@ -862,8 +733,7 @@ impl Server { ) -> tide::Result<()> { let receiver_ids = self .state() - .project_connection_ids(request.payload.project_id, request.sender_id) - .ok_or_else(|| anyhow!(NO_SUCH_PROJECT))?; + .project_connection_ids(request.payload.project_id, request.sender_id)?; broadcast(request.sender_id, receiver_ids, |connection_id| { self.peer .forward_send(request.sender_id, connection_id, request.payload.clone()) @@ -877,8 +747,7 @@ impl Server { ) -> tide::Result<()> { let receiver_ids = self .state() - .project_connection_ids(request.payload.project_id, request.sender_id) - .ok_or_else(|| anyhow!(NO_SUCH_PROJECT))?; + .project_connection_ids(request.payload.project_id, request.sender_id)?; broadcast(request.sender_id, receiver_ids, |connection_id| { self.peer .forward_send(request.sender_id, connection_id, request.payload.clone()) @@ -889,29 +758,24 @@ impl Server { async fn get_channels( self: Arc, request: TypedEnvelope, - ) -> tide::Result<()> { + ) -> tide::Result { let user_id = self.state().user_id_for_connection(request.sender_id)?; let channels = self.app_state.db.get_accessible_channels(user_id).await?; - self.peer.respond( - request.receipt(), - proto::GetChannelsResponse { - channels: channels - .into_iter() - .map(|chan| proto::Channel { - id: chan.id.to_proto(), - name: chan.name, - }) - .collect(), - }, - )?; - Ok(()) + Ok(proto::GetChannelsResponse { + channels: channels + .into_iter() + .map(|chan| proto::Channel { + id: chan.id.to_proto(), + name: chan.name, + }) + .collect(), + }) } async fn get_users( self: Arc, request: TypedEnvelope, - ) -> tide::Result<()> { - let receipt = request.receipt(); + ) -> tide::Result { let user_ids = request.payload.user_ids.into_iter().map(UserId::from_proto); let users = self .app_state @@ -925,9 +789,7 @@ impl Server { github_login: user.github_login, }) .collect(); - self.peer - .respond(receipt, proto::GetUsersResponse { users })?; - Ok(()) + Ok(proto::GetUsersResponse { users }) } fn update_contacts_for_users<'a>( @@ -955,7 +817,7 @@ impl Server { async fn join_channel( mut self: Arc, request: TypedEnvelope, - ) -> tide::Result<()> { + ) -> tide::Result { let user_id = self.state().user_id_for_connection(request.sender_id)?; let channel_id = ChannelId::from_proto(request.payload.channel_id); if !self @@ -982,14 +844,10 @@ impl Server { nonce: Some(msg.nonce.as_u128().into()), }) .collect::>(); - self.peer.respond( - request.receipt(), - proto::JoinChannelResponse { - done: messages.len() < MESSAGE_COUNT_PER_PAGE, - messages, - }, - )?; - Ok(()) + Ok(proto::JoinChannelResponse { + done: messages.len() < MESSAGE_COUNT_PER_PAGE, + messages, + }) } async fn leave_channel( @@ -1016,54 +874,30 @@ impl Server { async fn send_channel_message( self: Arc, request: TypedEnvelope, - ) -> tide::Result<()> { - let receipt = request.receipt(); + ) -> tide::Result { let channel_id = ChannelId::from_proto(request.payload.channel_id); let user_id; let connection_ids; { let state = self.state(); user_id = state.user_id_for_connection(request.sender_id)?; - if let Some(ids) = state.channel_connection_ids(channel_id) { - connection_ids = ids; - } else { - return Ok(()); - } + connection_ids = state.channel_connection_ids(channel_id)?; } // Validate the message body. let body = request.payload.body.trim().to_string(); if body.len() > MAX_MESSAGE_LEN { - self.peer.respond_with_error( - receipt, - proto::Error { - message: "message is too long".to_string(), - }, - )?; - return Ok(()); + return Err(anyhow!("message is too long"))?; } if body.is_empty() { - self.peer.respond_with_error( - receipt, - proto::Error { - message: "message can't be blank".to_string(), - }, - )?; - return Ok(()); + return Err(anyhow!("message can't be blank"))?; } let timestamp = OffsetDateTime::now_utc(); - let nonce = if let Some(nonce) = request.payload.nonce { - nonce - } else { - self.peer.respond_with_error( - receipt, - proto::Error { - message: "nonce can't be blank".to_string(), - }, - )?; - return Ok(()); - }; + let nonce = request + .payload + .nonce + .ok_or_else(|| anyhow!("nonce can't be blank"))?; let message_id = self .app_state @@ -1087,19 +921,15 @@ impl Server { }, ) })?; - self.peer.respond( - receipt, - proto::SendChannelMessageResponse { - message: Some(message), - }, - )?; - Ok(()) + Ok(proto::SendChannelMessageResponse { + message: Some(message), + }) } async fn get_channel_messages( self: Arc, request: TypedEnvelope, - ) -> tide::Result<()> { + ) -> tide::Result { let user_id = self.state().user_id_for_connection(request.sender_id)?; let channel_id = ChannelId::from_proto(request.payload.channel_id); if !self @@ -1129,14 +959,11 @@ impl Server { nonce: Some(msg.nonce.as_u128().into()), }) .collect::>(); - self.peer.respond( - request.receipt(), - proto::GetChannelMessagesResponse { - done: messages.len() < MESSAGE_COUNT_PER_PAGE, - messages, - }, - )?; - Ok(()) + + Ok(proto::GetChannelMessagesResponse { + done: messages.len() < MESSAGE_COUNT_PER_PAGE, + messages, + }) } fn state<'a>(self: &'a Arc) -> RwLockReadGuard<'a, Store> { diff --git a/crates/server/src/rpc/store.rs b/crates/server/src/rpc/store.rs index 6e11f431ac..5cb0a0e1db 100644 --- a/crates/server/src/rpc/store.rs +++ b/crates/server/src/rpc/store.rs @@ -122,10 +122,10 @@ impl Store { let mut result = RemovedConnectionState::default(); for project_id in connection.projects.clone() { - if let Some(project) = self.unregister_project(project_id, connection_id) { + if let Ok(project) = self.unregister_project(project_id, connection_id) { result.contact_ids.extend(project.authorized_user_ids()); result.hosted_projects.insert(project_id, project); - } else if let Some(project) = self.leave_project(connection_id, project_id) { + } else if let Ok(project) = self.leave_project(connection_id, project_id) { result .guest_project_ids .insert(project_id, project.connection_ids); @@ -254,9 +254,14 @@ impl Store { &mut self, project_id: u64, worktree_id: u64, + connection_id: ConnectionId, worktree: Worktree, - ) -> bool { - if let Some(project) = self.projects.get_mut(&project_id) { + ) -> tide::Result<()> { + let project = self + .projects + .get_mut(&project_id) + .ok_or_else(|| anyhow!("no such project"))?; + if project.host_connection_id == connection_id { for authorized_user_id in &worktree.authorized_user_ids { self.visible_projects_by_user_id .entry(*authorized_user_id) @@ -270,9 +275,9 @@ impl Store { #[cfg(test)] self.check_invariants(); - true + Ok(()) } else { - false + Err(anyhow!("no such project"))? } } @@ -280,7 +285,7 @@ impl Store { &mut self, project_id: u64, connection_id: ConnectionId, - ) -> Option { + ) -> tide::Result { match self.projects.entry(project_id) { hash_map::Entry::Occupied(e) => { if e.get().host_connection_id == connection_id { @@ -292,12 +297,12 @@ impl Store { } } - Some(e.remove()) + Ok(e.remove()) } else { - None + Err(anyhow!("no such project"))? } } - hash_map::Entry::Vacant(_) => None, + hash_map::Entry::Vacant(_) => Err(anyhow!("no such project"))?, } } @@ -398,20 +403,26 @@ impl Store { connection_id: ConnectionId, entries: HashMap, diagnostic_summaries: BTreeMap, - ) -> Option { - let project = self.projects.get_mut(&project_id)?; - let worktree = project.worktrees.get_mut(&worktree_id)?; + ) -> tide::Result { + let project = self + .projects + .get_mut(&project_id) + .ok_or_else(|| anyhow!("no such project"))?; + let worktree = project + .worktrees + .get_mut(&worktree_id) + .ok_or_else(|| anyhow!("no such worktree"))?; if project.host_connection_id == connection_id && project.share.is_some() { worktree.share = Some(WorktreeShare { entries, diagnostic_summaries, }); - Some(SharedWorktree { + Ok(SharedWorktree { authorized_user_ids: project.authorized_user_ids(), connection_ids: project.guest_connection_ids(), }) } else { - None + Err(anyhow!("no such worktree"))? } } @@ -421,19 +432,25 @@ impl Store { worktree_id: u64, connection_id: ConnectionId, summary: proto::DiagnosticSummary, - ) -> Option> { - let project = self.projects.get_mut(&project_id)?; - let worktree = project.worktrees.get_mut(&worktree_id)?; + ) -> tide::Result> { + let project = self + .projects + .get_mut(&project_id) + .ok_or_else(|| anyhow!("no such project"))?; + let worktree = project + .worktrees + .get_mut(&worktree_id) + .ok_or_else(|| anyhow!("no such worktree"))?; if project.host_connection_id == connection_id { if let Some(share) = worktree.share.as_mut() { share .diagnostic_summaries .insert(summary.path.clone().into(), summary); - return Some(project.connection_ids()); + return Ok(project.connection_ids()); } } - None + Err(anyhow!("no such worktree"))? } pub fn join_project( @@ -481,10 +498,19 @@ impl Store { &mut self, connection_id: ConnectionId, project_id: u64, - ) -> Option { - let project = self.projects.get_mut(&project_id)?; - let share = project.share.as_mut()?; - let (replica_id, _) = share.guests.remove(&connection_id)?; + ) -> tide::Result { + let project = self + .projects + .get_mut(&project_id) + .ok_or_else(|| anyhow!("no such project"))?; + let share = project + .share + .as_mut() + .ok_or_else(|| anyhow!("project is not shared"))?; + let (replica_id, _) = share + .guests + .remove(&connection_id) + .ok_or_else(|| anyhow!("cannot leave a project before joining it"))?; share.active_replica_ids.remove(&replica_id); if let Some(connection) = self.connections.get_mut(&connection_id) { @@ -497,7 +523,7 @@ impl Store { #[cfg(test)] self.check_invariants(); - Some(LeftProject { + Ok(LeftProject { connection_ids, authorized_user_ids, }) @@ -510,31 +536,40 @@ impl Store { worktree_id: u64, removed_entries: &[u64], updated_entries: &[proto::Entry], - ) -> Option> { + ) -> tide::Result> { let project = self.write_project(project_id, connection_id)?; - let share = project.worktrees.get_mut(&worktree_id)?.share.as_mut()?; + let share = project + .worktrees + .get_mut(&worktree_id) + .ok_or_else(|| anyhow!("no such worktree"))? + .share + .as_mut() + .ok_or_else(|| anyhow!("worktree is not shared"))?; for entry_id in removed_entries { share.entries.remove(&entry_id); } for entry in updated_entries { share.entries.insert(entry.id, entry.clone()); } - Some(project.connection_ids()) + Ok(project.connection_ids()) } pub fn project_connection_ids( &self, project_id: u64, acting_connection_id: ConnectionId, - ) -> Option> { - Some( - self.read_project(project_id, acting_connection_id)? - .connection_ids(), - ) + ) -> tide::Result> { + Ok(self + .read_project(project_id, acting_connection_id)? + .connection_ids()) } - pub fn channel_connection_ids(&self, channel_id: ChannelId) -> Option> { - Some(self.channels.get(&channel_id)?.connection_ids()) + pub fn channel_connection_ids(&self, channel_id: ChannelId) -> tide::Result> { + Ok(self + .channels + .get(&channel_id) + .ok_or_else(|| anyhow!("no such channel"))? + .connection_ids()) } #[cfg(test)] @@ -542,14 +577,26 @@ impl Store { self.projects.get(&project_id) } - pub fn read_project(&self, project_id: u64, connection_id: ConnectionId) -> Option<&Project> { - let project = self.projects.get(&project_id)?; + pub fn read_project( + &self, + project_id: u64, + connection_id: ConnectionId, + ) -> tide::Result<&Project> { + let project = self + .projects + .get(&project_id) + .ok_or_else(|| anyhow!("no such project"))?; if project.host_connection_id == connection_id - || project.share.as_ref()?.guests.contains_key(&connection_id) + || project + .share + .as_ref() + .ok_or_else(|| anyhow!("project is not shared"))? + .guests + .contains_key(&connection_id) { - Some(project) + Ok(project) } else { - None + Err(anyhow!("no such project"))? } } @@ -557,14 +604,22 @@ impl Store { &mut self, project_id: u64, connection_id: ConnectionId, - ) -> Option<&mut Project> { - let project = self.projects.get_mut(&project_id)?; + ) -> tide::Result<&mut Project> { + let project = self + .projects + .get_mut(&project_id) + .ok_or_else(|| anyhow!("no such project"))?; if project.host_connection_id == connection_id - || project.share.as_ref()?.guests.contains_key(&connection_id) + || project + .share + .as_ref() + .ok_or_else(|| anyhow!("project is not shared"))? + .guests + .contains_key(&connection_id) { - Some(project) + Ok(project) } else { - None + Err(anyhow!("no such project"))? } }