Use an unbounded channel for peer's outgoing messages

Using a bounded channel may have blocked the collaboration server
from making progress handling RPC traffic.

There's no need to apply backpressure to calling code within the
same process - suspending a task that is attempting to call `send` has
an even greater memory cost than just buffering a protobuf message.

We do still want a bounded channel for incoming messages, so that
we provide backpressure to noisy peers - blocking their writes as opposed
to allowing them to buffer arbitrarily many messages in our server.

Co-Authored-By: Antonio Scandurra <me@as-cii.com>
Co-Authored-By: Nathan Sobo <nathan@zed.dev>
This commit is contained in:
Max Brunsfeld 2022-02-07 12:27:13 -08:00
parent 82afacd33d
commit d4fe1115e7
7 changed files with 341 additions and 472 deletions

View file

@ -17,7 +17,7 @@ use std::{
}; };
use sum_tree::{Bias, SumTree}; use sum_tree::{Bias, SumTree};
use time::OffsetDateTime; use time::OffsetDateTime;
use util::{post_inc, TryFutureExt}; use util::{post_inc, ResultExt as _, TryFutureExt};
pub struct ChannelList { pub struct ChannelList {
available_channels: Option<Vec<ChannelDetails>>, available_channels: Option<Vec<ChannelDetails>>,
@ -168,16 +168,12 @@ impl ChannelList {
impl Entity for Channel { impl Entity for Channel {
type Event = ChannelEvent; type Event = ChannelEvent;
fn release(&mut self, cx: &mut MutableAppContext) { fn release(&mut self, _: &mut MutableAppContext) {
let rpc = self.rpc.clone(); self.rpc
let channel_id = self.details.id; .send(proto::LeaveChannel {
cx.foreground() channel_id: self.details.id,
.spawn(async move {
if let Err(error) = rpc.send(proto::LeaveChannel { channel_id }).await {
log::error!("error leaving channel: {}", error);
};
}) })
.detach() .log_err();
} }
} }
@ -718,18 +714,16 @@ mod tests {
}); });
// Receive a new message. // Receive a new message.
server server.send(proto::ChannelMessageSent {
.send(proto::ChannelMessageSent { channel_id: channel.read_with(&cx, |channel, _| channel.details.id),
channel_id: channel.read_with(&cx, |channel, _| channel.details.id), message: Some(proto::ChannelMessage {
message: Some(proto::ChannelMessage { id: 12,
id: 12, body: "c".into(),
body: "c".into(), timestamp: 1002,
timestamp: 1002, sender_id: 7,
sender_id: 7, nonce: Some(3.into()),
nonce: Some(3.into()), }),
}), });
})
.await;
// Client requests user for message since they haven't seen them yet // Client requests user for message since they haven't seen them yet
let get_users = server.receive::<proto::GetUsers>().await.unwrap(); let get_users = server.receive::<proto::GetUsers>().await.unwrap();

View file

@ -24,7 +24,6 @@ use std::{
collections::HashMap, collections::HashMap,
convert::TryFrom, convert::TryFrom,
fmt::Write as _, fmt::Write as _,
future::Future,
sync::{Arc, Weak}, sync::{Arc, Weak},
time::{Duration, Instant}, time::{Duration, Instant},
}; };
@ -677,8 +676,8 @@ impl Client {
} }
} }
pub async fn send<T: EnvelopedMessage>(&self, message: T) -> Result<()> { pub fn send<T: EnvelopedMessage>(&self, message: T) -> Result<()> {
self.peer.send(self.connection_id()?, message).await self.peer.send(self.connection_id()?, message)
} }
pub async fn request<T: RequestMessage>(&self, request: T) -> Result<T::Response> { pub async fn request<T: RequestMessage>(&self, request: T) -> Result<T::Response> {
@ -689,7 +688,7 @@ impl Client {
&self, &self,
receipt: Receipt<T>, receipt: Receipt<T>,
response: T::Response, response: T::Response,
) -> impl Future<Output = Result<()>> { ) -> Result<()> {
self.peer.respond(receipt, response) self.peer.respond(receipt, response)
} }
@ -697,7 +696,7 @@ impl Client {
&self, &self,
receipt: Receipt<T>, receipt: Receipt<T>,
error: proto::Error, error: proto::Error,
) -> impl Future<Output = Result<()>> { ) -> Result<()> {
self.peer.respond_with_error(receipt, error) self.peer.respond_with_error(receipt, error)
} }
} }
@ -860,8 +859,8 @@ mod tests {
}); });
drop(subscription3); drop(subscription3);
server.send(proto::UnshareProject { project_id: 1 }).await; server.send(proto::UnshareProject { project_id: 1 });
server.send(proto::UnshareProject { project_id: 2 }).await; server.send(proto::UnshareProject { project_id: 2 });
done_rx1.next().await.unwrap(); done_rx1.next().await.unwrap();
done_rx2.next().await.unwrap(); done_rx2.next().await.unwrap();
} }
@ -890,7 +889,7 @@ mod tests {
Ok(()) Ok(())
}) })
}); });
server.send(proto::Ping {}).await; server.send(proto::Ping {});
done_rx2.next().await.unwrap(); done_rx2.next().await.unwrap();
} }
@ -914,7 +913,7 @@ mod tests {
}, },
)); ));
}); });
server.send(proto::Ping {}).await; server.send(proto::Ping {});
done_rx.next().await.unwrap(); done_rx.next().await.unwrap();
} }

View file

@ -118,8 +118,8 @@ impl FakeServer {
self.forbid_connections.store(false, SeqCst); self.forbid_connections.store(false, SeqCst);
} }
pub async fn send<T: proto::EnvelopedMessage>(&self, message: T) { pub fn send<T: proto::EnvelopedMessage>(&self, message: T) {
self.peer.send(self.connection_id(), message).await.unwrap(); self.peer.send(self.connection_id(), message).unwrap();
} }
pub async fn receive<M: proto::EnvelopedMessage>(&self) -> Result<TypedEnvelope<M>> { pub async fn receive<M: proto::EnvelopedMessage>(&self) -> Result<TypedEnvelope<M>> {
@ -148,7 +148,7 @@ impl FakeServer {
receipt: Receipt<T>, receipt: Receipt<T>,
response: T::Response, response: T::Response,
) { ) {
self.peer.respond(receipt, response).await.unwrap() self.peer.respond(receipt, response).unwrap()
} }
fn connection_id(&self) -> ConnectionId { fn connection_id(&self) -> ConnectionId {

View file

@ -460,7 +460,7 @@ impl Project {
} }
})?; })?;
rpc.send(proto::UnshareProject { project_id }).await?; rpc.send(proto::UnshareProject { project_id })?;
this.update(&mut cx, |this, cx| { this.update(&mut cx, |this, cx| {
this.collaborators.clear(); this.collaborators.clear();
this.shared_buffers.clear(); this.shared_buffers.clear();
@ -818,15 +818,13 @@ impl Project {
let this = cx.read(|cx| this.upgrade(cx))?; let this = cx.read(|cx| this.upgrade(cx))?;
match message { match message {
LspEvent::DiagnosticsStart => { LspEvent::DiagnosticsStart => {
let send = this.update(&mut cx, |this, cx| { this.update(&mut cx, |this, cx| {
this.disk_based_diagnostics_started(cx); this.disk_based_diagnostics_started(cx);
this.remote_id().map(|project_id| { if let Some(project_id) = this.remote_id() {
rpc.send(proto::DiskBasedDiagnosticsUpdating { project_id }) rpc.send(proto::DiskBasedDiagnosticsUpdating { project_id })
}) .log_err();
}
}); });
if let Some(send) = send {
send.await.log_err();
}
} }
LspEvent::DiagnosticsUpdate(mut params) => { LspEvent::DiagnosticsUpdate(mut params) => {
language.process_diagnostics(&mut params); language.process_diagnostics(&mut params);
@ -836,15 +834,13 @@ impl Project {
}); });
} }
LspEvent::DiagnosticsFinish => { LspEvent::DiagnosticsFinish => {
let send = this.update(&mut cx, |this, cx| { this.update(&mut cx, |this, cx| {
this.disk_based_diagnostics_finished(cx); this.disk_based_diagnostics_finished(cx);
this.remote_id().map(|project_id| { if let Some(project_id) = this.remote_id() {
rpc.send(proto::DiskBasedDiagnosticsUpdated { project_id }) rpc.send(proto::DiskBasedDiagnosticsUpdated { project_id })
}) .log_err();
}
}); });
if let Some(send) = send {
send.await.log_err();
}
} }
} }
} }
@ -1311,15 +1307,13 @@ impl Project {
}; };
if let Some(project_id) = self.remote_id() { if let Some(project_id) = self.remote_id() {
let client = self.client.clone(); self.client
let message = proto::UpdateBufferFile { .send(proto::UpdateBufferFile {
project_id, project_id,
buffer_id: *buffer_id as u64, buffer_id: *buffer_id as u64,
file: Some(new_file.to_proto()), file: Some(new_file.to_proto()),
}; })
cx.foreground() .log_err();
.spawn(async move { client.send(message).await })
.detach_and_log_err(cx);
} }
buffer.file_updated(Box::new(new_file), cx).detach(); buffer.file_updated(Box::new(new_file), cx).detach();
} }
@ -1639,8 +1633,7 @@ impl Project {
version: (&version).into(), version: (&version).into(),
mtime: Some(mtime.into()), mtime: Some(mtime.into()),
}, },
) )?;
.await?;
Ok(()) Ok(())
} }
@ -1669,16 +1662,13 @@ impl Project {
// associated with formatting. // associated with formatting.
cx.spawn(|_| async move { cx.spawn(|_| async move {
match format { match format {
Ok(()) => rpc.respond(receipt, proto::Ack {}).await?, Ok(()) => rpc.respond(receipt, proto::Ack {})?,
Err(error) => { Err(error) => rpc.respond_with_error(
rpc.respond_with_error( receipt,
receipt, proto::Error {
proto::Error { message: error.to_string(),
message: error.to_string(), },
}, )?,
)
.await?
}
} }
Ok::<_, anyhow::Error>(()) Ok::<_, anyhow::Error>(())
}) })
@ -1712,27 +1702,21 @@ impl Project {
.update(&mut cx, |buffer, cx| buffer.completions(position, cx)) .update(&mut cx, |buffer, cx| buffer.completions(position, cx))
.await .await
{ {
Ok(completions) => { Ok(completions) => rpc.respond(
rpc.respond( receipt,
receipt, proto::GetCompletionsResponse {
proto::GetCompletionsResponse { completions: completions
completions: completions .iter()
.iter() .map(language::proto::serialize_completion)
.map(language::proto::serialize_completion) .collect(),
.collect(), },
}, ),
) Err(error) => rpc.respond_with_error(
.await receipt,
} proto::Error {
Err(error) => { message: error.to_string(),
rpc.respond_with_error( },
receipt, ),
proto::Error {
message: error.to_string(),
},
)
.await
}
} }
}) })
.detach_and_log_err(cx); .detach_and_log_err(cx);
@ -1767,30 +1751,24 @@ impl Project {
}) })
.await .await
{ {
Ok(edit_ids) => { Ok(edit_ids) => rpc.respond(
rpc.respond( receipt,
receipt, proto::ApplyCompletionAdditionalEditsResponse {
proto::ApplyCompletionAdditionalEditsResponse { additional_edits: edit_ids
additional_edits: edit_ids .into_iter()
.into_iter() .map(|edit_id| proto::AdditionalEdit {
.map(|edit_id| proto::AdditionalEdit { replica_id: edit_id.replica_id as u32,
replica_id: edit_id.replica_id as u32, local_timestamp: edit_id.value,
local_timestamp: edit_id.value, })
}) .collect(),
.collect(), },
}, ),
) Err(error) => rpc.respond_with_error(
.await receipt,
} proto::Error {
Err(error) => { message: error.to_string(),
rpc.respond_with_error( },
receipt, ),
proto::Error {
message: error.to_string(),
},
)
.await
}
} }
}) })
.detach_and_log_err(cx); .detach_and_log_err(cx);
@ -1836,7 +1814,7 @@ impl Project {
}); });
} }
}); });
rpc.respond(receipt, response).await?; rpc.respond(receipt, response)?;
Ok::<_, anyhow::Error>(()) Ok::<_, anyhow::Error>(())
}) })
.detach_and_log_err(cx); .detach_and_log_err(cx);
@ -1872,7 +1850,6 @@ impl Project {
buffer: Some(buffer), buffer: Some(buffer),
}, },
) )
.await
} }
.log_err() .log_err()
}) })
@ -2106,28 +2083,21 @@ impl<'a> Iterator for CandidateSetIter<'a> {
impl Entity for Project { impl Entity for Project {
type Event = Event; type Event = Event;
fn release(&mut self, cx: &mut gpui::MutableAppContext) { fn release(&mut self, _: &mut gpui::MutableAppContext) {
match &self.client_state { match &self.client_state {
ProjectClientState::Local { remote_id_rx, .. } => { ProjectClientState::Local { remote_id_rx, .. } => {
if let Some(project_id) = *remote_id_rx.borrow() { if let Some(project_id) = *remote_id_rx.borrow() {
let rpc = self.client.clone(); self.client
cx.spawn(|_| async move { .send(proto::UnregisterProject { project_id })
if let Err(err) = rpc.send(proto::UnregisterProject { project_id }).await { .log_err();
log::error!("error unregistering project: {}", err);
}
})
.detach();
} }
} }
ProjectClientState::Remote { remote_id, .. } => { ProjectClientState::Remote { remote_id, .. } => {
let rpc = self.client.clone(); self.client
let project_id = *remote_id; .send(proto::LeaveProject {
cx.spawn(|_| async move { project_id: *remote_id,
if let Err(err) = rpc.send(proto::LeaveProject { project_id }).await { })
log::error!("error leaving project: {}", err); .log_err();
}
})
.detach();
} }
} }
} }

View file

@ -149,7 +149,7 @@ pub enum Event {
impl Entity for Worktree { impl Entity for Worktree {
type Event = Event; type Event = Event;
fn release(&mut self, cx: &mut MutableAppContext) { fn release(&mut self, _: &mut MutableAppContext) {
if let Some(worktree) = self.as_local_mut() { if let Some(worktree) = self.as_local_mut() {
if let Registration::Done { project_id } = worktree.registration { if let Registration::Done { project_id } = worktree.registration {
let client = worktree.client.clone(); let client = worktree.client.clone();
@ -157,12 +157,7 @@ impl Entity for Worktree {
project_id, project_id,
worktree_id: worktree.id().to_proto(), worktree_id: worktree.id().to_proto(),
}; };
cx.foreground() client.send(unregister_message).log_err();
.spawn(async move {
client.send(unregister_message).await?;
Ok::<_, anyhow::Error>(())
})
.detach_and_log_err(cx);
} }
} }
} }
@ -596,7 +591,7 @@ impl LocalWorktree {
&mut self, &mut self,
worktree_path: Arc<Path>, worktree_path: Arc<Path>,
diagnostics: Vec<DiagnosticEntry<PointUtf16>>, diagnostics: Vec<DiagnosticEntry<PointUtf16>>,
cx: &mut ModelContext<Worktree>, _: &mut ModelContext<Worktree>,
) -> Result<()> { ) -> Result<()> {
let summary = DiagnosticSummary::new(&diagnostics); let summary = DiagnosticSummary::new(&diagnostics);
self.diagnostic_summaries self.diagnostic_summaries
@ -604,30 +599,19 @@ impl LocalWorktree {
self.diagnostics.insert(worktree_path.clone(), diagnostics); self.diagnostics.insert(worktree_path.clone(), diagnostics);
if let Some(share) = self.share.as_ref() { if let Some(share) = self.share.as_ref() {
cx.foreground() self.client
.spawn({ .send(proto::UpdateDiagnosticSummary {
let client = self.client.clone(); project_id: share.project_id,
let project_id = share.project_id; worktree_id: self.id().to_proto(),
let worktree_id = self.id().to_proto(); summary: Some(proto::DiagnosticSummary {
let path = worktree_path.to_string_lossy().to_string(); path: worktree_path.to_string_lossy().to_string(),
async move { error_count: summary.error_count as u32,
client warning_count: summary.warning_count as u32,
.send(proto::UpdateDiagnosticSummary { info_count: summary.info_count as u32,
project_id, hint_count: summary.hint_count as u32,
worktree_id, }),
summary: Some(proto::DiagnosticSummary {
path,
error_count: summary.error_count as u32,
warning_count: summary.warning_count as u32,
info_count: summary.info_count as u32,
hint_count: summary.hint_count as u32,
}),
})
.await
.log_err()
}
}) })
.detach(); .log_err();
} }
Ok(()) Ok(())
@ -787,7 +771,7 @@ impl LocalWorktree {
while let Ok(snapshot) = snapshots_to_send_rx.recv().await { while let Ok(snapshot) = snapshots_to_send_rx.recv().await {
let message = let message =
snapshot.build_update(&prev_snapshot, project_id, worktree_id, false); snapshot.build_update(&prev_snapshot, project_id, worktree_id, false);
match rpc.send(message).await { match rpc.send(message) {
Ok(()) => prev_snapshot = snapshot, Ok(()) => prev_snapshot = snapshot,
Err(err) => log::error!("error sending snapshot diff {}", err), Err(err) => log::error!("error sending snapshot diff {}", err),
} }
@ -1377,8 +1361,7 @@ impl language::File for File {
buffer_id, buffer_id,
version: (&version).into(), version: (&version).into(),
mtime: Some(entry.mtime.into()), mtime: Some(entry.mtime.into()),
}) })?;
.await?;
} }
Ok((version, entry.mtime)) Ok((version, entry.mtime))
}) })
@ -1501,23 +1484,15 @@ impl language::File for File {
} }
fn buffer_removed(&self, buffer_id: u64, cx: &mut MutableAppContext) { fn buffer_removed(&self, buffer_id: u64, cx: &mut MutableAppContext) {
self.worktree.update(cx, |worktree, cx| { self.worktree.update(cx, |worktree, _| {
if let Worktree::Remote(worktree) = worktree { if let Worktree::Remote(worktree) = worktree {
let project_id = worktree.project_id; worktree
let rpc = worktree.client.clone(); .client
cx.background() .send(proto::CloseBuffer {
.spawn(async move { project_id: worktree.project_id,
if let Err(error) = rpc buffer_id,
.send(proto::CloseBuffer {
project_id,
buffer_id,
})
.await
{
log::error!("error closing remote buffer: {}", error);
}
}) })
.detach(); .log_err();
} }
}); });
} }
@ -1563,16 +1538,15 @@ impl language::LocalFile for File {
) { ) {
let worktree = self.worktree.read(cx).as_local().unwrap(); let worktree = self.worktree.read(cx).as_local().unwrap();
if let Some(project_id) = worktree.share.as_ref().map(|share| share.project_id) { if let Some(project_id) = worktree.share.as_ref().map(|share| share.project_id) {
let rpc = worktree.client.clone(); worktree
let message = proto::BufferReloaded { .client
project_id, .send(proto::BufferReloaded {
buffer_id, project_id,
version: version.into(), buffer_id,
mtime: Some(mtime.into()), version: version.into(),
}; mtime: Some(mtime.into()),
cx.background() })
.spawn(async move { rpc.send(message).await }) .log_err();
.detach_and_log_err(cx);
} }
} }
} }

View file

@ -89,7 +89,7 @@ pub struct Peer {
#[derive(Clone)] #[derive(Clone)]
pub struct ConnectionState { pub struct ConnectionState {
outgoing_tx: mpsc::Sender<proto::Envelope>, outgoing_tx: futures::channel::mpsc::UnboundedSender<proto::Envelope>,
next_message_id: Arc<AtomicU32>, next_message_id: Arc<AtomicU32>,
response_channels: Arc<Mutex<Option<HashMap<u32, mpsc::Sender<proto::Envelope>>>>>, response_channels: Arc<Mutex<Option<HashMap<u32, mpsc::Sender<proto::Envelope>>>>>,
} }
@ -112,9 +112,14 @@ impl Peer {
impl Future<Output = anyhow::Result<()>> + Send, impl Future<Output = anyhow::Result<()>> + Send,
BoxStream<'static, Box<dyn AnyTypedEnvelope>>, BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
) { ) {
let connection_id = ConnectionId(self.next_connection_id.fetch_add(1, SeqCst)); // For outgoing messages, use an unbounded channel so that application code
// can always send messages without yielding. For incoming messages, use a
// bounded channel so that other peers will receive backpressure if they send
// messages faster than this peer can process them.
let (mut incoming_tx, incoming_rx) = mpsc::channel(64); let (mut incoming_tx, incoming_rx) = mpsc::channel(64);
let (outgoing_tx, mut outgoing_rx) = mpsc::channel(64); let (outgoing_tx, mut outgoing_rx) = futures::channel::mpsc::unbounded();
let connection_id = ConnectionId(self.next_connection_id.fetch_add(1, SeqCst));
let connection_state = ConnectionState { let connection_state = ConnectionState {
outgoing_tx, outgoing_tx,
next_message_id: Default::default(), next_message_id: Default::default(),
@ -131,6 +136,16 @@ impl Peer {
futures::pin_mut!(read_message); futures::pin_mut!(read_message);
loop { loop {
futures::select_biased! { futures::select_biased! {
outgoing = outgoing_rx.next().fuse() => match outgoing {
Some(outgoing) => {
match writer.write_message(&outgoing).timeout(WRITE_TIMEOUT).await {
None => break 'outer Err(anyhow!("timed out writing RPC message")),
Some(Err(result)) => break 'outer Err(result).context("failed to write RPC message"),
_ => {}
}
}
None => break 'outer Ok(()),
},
incoming = read_message => match incoming { incoming = read_message => match incoming {
Ok(incoming) => { Ok(incoming) => {
if incoming_tx.send(incoming).await.is_err() { if incoming_tx.send(incoming).await.is_err() {
@ -142,16 +157,6 @@ impl Peer {
break 'outer Err(error).context("received invalid RPC message") break 'outer Err(error).context("received invalid RPC message")
} }
}, },
outgoing = outgoing_rx.recv().fuse() => match outgoing {
Some(outgoing) => {
match writer.write_message(&outgoing).timeout(WRITE_TIMEOUT).await {
None => break 'outer Err(anyhow!("timed out writing RPC message")),
Some(Err(result)) => break 'outer Err(result).context("failed to write RPC message"),
_ => {}
}
}
None => break 'outer Ok(()),
}
} }
} }
}; };
@ -223,9 +228,9 @@ impl Peer {
request: T, request: T,
) -> impl Future<Output = Result<T::Response>> { ) -> impl Future<Output = Result<T::Response>> {
let this = self.clone(); let this = self.clone();
let (tx, mut rx) = mpsc::channel(1);
async move { async move {
let mut connection = this.connection_state(receiver_id)?; let (tx, mut rx) = mpsc::channel(1);
let connection = this.connection_state(receiver_id)?;
let message_id = connection.next_message_id.fetch_add(1, SeqCst); let message_id = connection.next_message_id.fetch_add(1, SeqCst);
connection connection
.response_channels .response_channels
@ -235,8 +240,11 @@ impl Peer {
.insert(message_id, tx); .insert(message_id, tx);
connection connection
.outgoing_tx .outgoing_tx
.send(request.into_envelope(message_id, None, original_sender_id.map(|id| id.0))) .unbounded_send(request.into_envelope(
.await message_id,
None,
original_sender_id.map(|id| id.0),
))
.map_err(|_| anyhow!("connection was closed"))?; .map_err(|_| anyhow!("connection was closed"))?;
let response = rx let response = rx
.recv() .recv()
@ -255,19 +263,15 @@ impl Peer {
self: &Arc<Self>, self: &Arc<Self>,
receiver_id: ConnectionId, receiver_id: ConnectionId,
message: T, message: T,
) -> impl Future<Output = Result<()>> { ) -> Result<()> {
let this = self.clone(); let connection = self.connection_state(receiver_id)?;
async move { let message_id = connection
let mut connection = this.connection_state(receiver_id)?; .next_message_id
let message_id = connection .fetch_add(1, atomic::Ordering::SeqCst);
.next_message_id connection
.fetch_add(1, atomic::Ordering::SeqCst); .outgoing_tx
connection .unbounded_send(message.into_envelope(message_id, None, None))?;
.outgoing_tx Ok(())
.send(message.into_envelope(message_id, None, None))
.await?;
Ok(())
}
} }
pub fn forward_send<T: EnvelopedMessage>( pub fn forward_send<T: EnvelopedMessage>(
@ -275,57 +279,45 @@ impl Peer {
sender_id: ConnectionId, sender_id: ConnectionId,
receiver_id: ConnectionId, receiver_id: ConnectionId,
message: T, message: T,
) -> impl Future<Output = Result<()>> { ) -> Result<()> {
let this = self.clone(); let connection = self.connection_state(receiver_id)?;
async move { let message_id = connection
let mut connection = this.connection_state(receiver_id)?; .next_message_id
let message_id = connection .fetch_add(1, atomic::Ordering::SeqCst);
.next_message_id connection
.fetch_add(1, atomic::Ordering::SeqCst); .outgoing_tx
connection .unbounded_send(message.into_envelope(message_id, None, Some(sender_id.0)))?;
.outgoing_tx Ok(())
.send(message.into_envelope(message_id, None, Some(sender_id.0)))
.await?;
Ok(())
}
} }
pub fn respond<T: RequestMessage>( pub fn respond<T: RequestMessage>(
self: &Arc<Self>, self: &Arc<Self>,
receipt: Receipt<T>, receipt: Receipt<T>,
response: T::Response, response: T::Response,
) -> impl Future<Output = Result<()>> { ) -> Result<()> {
let this = self.clone(); let connection = self.connection_state(receipt.sender_id)?;
async move { let message_id = connection
let mut connection = this.connection_state(receipt.sender_id)?; .next_message_id
let message_id = connection .fetch_add(1, atomic::Ordering::SeqCst);
.next_message_id connection
.fetch_add(1, atomic::Ordering::SeqCst); .outgoing_tx
connection .unbounded_send(response.into_envelope(message_id, Some(receipt.message_id), None))?;
.outgoing_tx Ok(())
.send(response.into_envelope(message_id, Some(receipt.message_id), None))
.await?;
Ok(())
}
} }
pub fn respond_with_error<T: RequestMessage>( pub fn respond_with_error<T: RequestMessage>(
self: &Arc<Self>, self: &Arc<Self>,
receipt: Receipt<T>, receipt: Receipt<T>,
response: proto::Error, response: proto::Error,
) -> impl Future<Output = Result<()>> { ) -> Result<()> {
let this = self.clone(); let connection = self.connection_state(receipt.sender_id)?;
async move { let message_id = connection
let mut connection = this.connection_state(receipt.sender_id)?; .next_message_id
let message_id = connection .fetch_add(1, atomic::Ordering::SeqCst);
.next_message_id connection
.fetch_add(1, atomic::Ordering::SeqCst); .outgoing_tx
connection .unbounded_send(response.into_envelope(message_id, Some(receipt.message_id), None))?;
.outgoing_tx Ok(())
.send(response.into_envelope(message_id, Some(receipt.message_id), None))
.await?;
Ok(())
}
} }
fn connection_state(&self, connection_id: ConnectionId) -> Result<ConnectionState> { fn connection_state(&self, connection_id: ConnectionId) -> Result<ConnectionState> {
@ -447,7 +439,7 @@ mod tests {
let envelope = envelope.into_any(); let envelope = envelope.into_any();
if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Ping>>() { if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Ping>>() {
let receipt = envelope.receipt(); let receipt = envelope.receipt();
peer.respond(receipt, proto::Ack {}).await? peer.respond(receipt, proto::Ack {})?
} else if let Some(envelope) = } else if let Some(envelope) =
envelope.downcast_ref::<TypedEnvelope<proto::OpenBuffer>>() envelope.downcast_ref::<TypedEnvelope<proto::OpenBuffer>>()
{ {
@ -475,7 +467,7 @@ mod tests {
} }
}; };
peer.respond(receipt, response).await? peer.respond(receipt, response)?
} else { } else {
panic!("unknown message type"); panic!("unknown message type");
} }
@ -518,7 +510,6 @@ mod tests {
message: "message 1".to_string(), message: "message 1".to_string(),
}, },
) )
.await
.unwrap(); .unwrap();
server server
.send( .send(
@ -527,12 +518,8 @@ mod tests {
message: "message 2".to_string(), message: "message 2".to_string(),
}, },
) )
.await
.unwrap();
server
.respond(request.receipt(), proto::Ack {})
.await
.unwrap(); .unwrap();
server.respond(request.receipt(), proto::Ack {}).unwrap();
// Prevent the connection from being dropped // Prevent the connection from being dropped
server_incoming.next().await; server_incoming.next().await;

View file

@ -131,7 +131,7 @@ impl Server {
} }
this.state_mut().add_connection(connection_id, user_id); this.state_mut().add_connection(connection_id, user_id);
if let Err(err) = this.update_contacts_for_users(&[user_id]).await { if let Err(err) = this.update_contacts_for_users(&[user_id]) {
log::error!("error updating contacts for {:?}: {}", user_id, err); log::error!("error updating contacts for {:?}: {}", user_id, err);
} }
@ -141,6 +141,12 @@ impl Server {
let next_message = incoming_rx.next().fuse(); let next_message = incoming_rx.next().fuse();
futures::pin_mut!(next_message); futures::pin_mut!(next_message);
futures::select_biased! { futures::select_biased! {
result = handle_io => {
if let Err(err) = result {
log::error!("error handling rpc connection {:?} - {:?}", addr, err);
}
break;
}
message = next_message => { message = next_message => {
if let Some(message) = message { if let Some(message) = message {
let start_time = Instant::now(); let start_time = Instant::now();
@ -163,12 +169,6 @@ impl Server {
break; break;
} }
} }
handle_io = handle_io => {
if let Err(err) = handle_io {
log::error!("error handling rpc connection {:?} - {:?}", addr, err);
}
break;
}
} }
} }
@ -191,8 +191,7 @@ impl Server {
self.peer self.peer
.send(conn_id, proto::UnshareProject { project_id }) .send(conn_id, proto::UnshareProject { project_id })
}, },
) )?;
.await?;
} }
} }
@ -205,18 +204,15 @@ impl Server {
peer_id: connection_id.0, peer_id: connection_id.0,
}, },
) )
}) })?;
.await?;
} }
self.update_contacts_for_users(removed_connection.contact_ids.iter()) self.update_contacts_for_users(removed_connection.contact_ids.iter())?;
.await?;
Ok(()) Ok(())
} }
async fn ping(self: Arc<Server>, request: TypedEnvelope<proto::Ping>) -> tide::Result<()> { async fn ping(self: Arc<Server>, request: TypedEnvelope<proto::Ping>) -> tide::Result<()> {
self.peer.respond(request.receipt(), proto::Ack {}).await?; self.peer.respond(request.receipt(), proto::Ack {})?;
Ok(()) Ok(())
} }
@ -229,12 +225,10 @@ impl Server {
let user_id = state.user_id_for_connection(request.sender_id)?; let user_id = state.user_id_for_connection(request.sender_id)?;
state.register_project(request.sender_id, user_id) state.register_project(request.sender_id, user_id)
}; };
self.peer self.peer.respond(
.respond( request.receipt(),
request.receipt(), proto::RegisterProjectResponse { project_id },
proto::RegisterProjectResponse { project_id }, )?;
)
.await?;
Ok(()) Ok(())
} }
@ -246,8 +240,7 @@ impl Server {
.state_mut() .state_mut()
.unregister_project(request.payload.project_id, request.sender_id) .unregister_project(request.payload.project_id, request.sender_id)
.ok_or_else(|| anyhow!("no such project"))?; .ok_or_else(|| anyhow!("no such project"))?;
self.update_contacts_for_users(project.authorized_user_ids().iter()) self.update_contacts_for_users(project.authorized_user_ids().iter())?;
.await?;
Ok(()) Ok(())
} }
@ -257,7 +250,7 @@ impl Server {
) -> tide::Result<()> { ) -> tide::Result<()> {
self.state_mut() self.state_mut()
.share_project(request.payload.project_id, request.sender_id); .share_project(request.payload.project_id, request.sender_id);
self.peer.respond(request.receipt(), proto::Ack {}).await?; self.peer.respond(request.receipt(), proto::Ack {})?;
Ok(()) Ok(())
} }
@ -273,11 +266,8 @@ impl Server {
broadcast(request.sender_id, project.connection_ids, |conn_id| { broadcast(request.sender_id, project.connection_ids, |conn_id| {
self.peer self.peer
.send(conn_id, proto::UnshareProject { project_id }) .send(conn_id, proto::UnshareProject { project_id })
}) })?;
.await?; self.update_contacts_for_users(&project.authorized_user_ids)?;
self.update_contacts_for_users(&project.authorized_user_ids)
.await?;
Ok(()) Ok(())
} }
@ -351,20 +341,17 @@ impl Server {
}), }),
}, },
) )
}) })?;
.await?; self.peer.respond(request.receipt(), response)?;
self.peer.respond(request.receipt(), response).await?; self.update_contacts_for_users(&contact_user_ids)?;
self.update_contacts_for_users(&contact_user_ids).await?;
} }
Err(error) => { Err(error) => {
self.peer self.peer.respond_with_error(
.respond_with_error( request.receipt(),
request.receipt(), proto::Error {
proto::Error { message: error.to_string(),
message: error.to_string(), },
}, )?;
)
.await?;
} }
} }
@ -387,10 +374,8 @@ impl Server {
peer_id: sender_id.0, peer_id: sender_id.0,
}, },
) )
}) })?;
.await?; self.update_contacts_for_users(&worktree.authorized_user_ids)?;
self.update_contacts_for_users(&worktree.authorized_user_ids)
.await?;
} }
Ok(()) Ok(())
} }
@ -412,8 +397,7 @@ impl Server {
Err(err) => { Err(err) => {
let message = err.to_string(); let message = err.to_string();
self.peer self.peer
.respond_with_error(receipt, proto::Error { message }) .respond_with_error(receipt, proto::Error { message })?;
.await?;
return Ok(()); return Ok(());
} }
} }
@ -432,17 +416,15 @@ impl Server {
); );
if ok { if ok {
self.peer.respond(receipt, proto::Ack {}).await?; self.peer.respond(receipt, proto::Ack {})?;
self.update_contacts_for_users(&contact_user_ids).await?; self.update_contacts_for_users(&contact_user_ids)?;
} else { } else {
self.peer self.peer.respond_with_error(
.respond_with_error( receipt,
receipt, proto::Error {
proto::Error { message: NO_SUCH_PROJECT.to_string(),
message: NO_SUCH_PROJECT.to_string(), },
}, )?;
)
.await?;
} }
Ok(()) Ok(())
@ -457,7 +439,6 @@ impl Server {
let (worktree, guest_connection_ids) = let (worktree, guest_connection_ids) =
self.state_mut() self.state_mut()
.unregister_worktree(project_id, worktree_id, request.sender_id)?; .unregister_worktree(project_id, worktree_id, request.sender_id)?;
broadcast(request.sender_id, guest_connection_ids, |conn_id| { broadcast(request.sender_id, guest_connection_ids, |conn_id| {
self.peer.send( self.peer.send(
conn_id, conn_id,
@ -466,10 +447,8 @@ impl Server {
worktree_id, worktree_id,
}, },
) )
}) })?;
.await?; self.update_contacts_for_users(&worktree.authorized_user_ids)?;
self.update_contacts_for_users(&worktree.authorized_user_ids)
.await?;
Ok(()) Ok(())
} }
@ -511,20 +490,16 @@ impl Server {
request.payload.clone(), request.payload.clone(),
) )
}, },
) )?;
.await?; self.peer.respond(request.receipt(), proto::Ack {})?;
self.peer.respond(request.receipt(), proto::Ack {}).await?; self.update_contacts_for_users(&shared_worktree.authorized_user_ids)?;
self.update_contacts_for_users(&shared_worktree.authorized_user_ids)
.await?;
} else { } else {
self.peer self.peer.respond_with_error(
.respond_with_error( request.receipt(),
request.receipt(), proto::Error {
proto::Error { message: "no such worktree".to_string(),
message: "no such worktree".to_string(), },
}, )?;
)
.await?;
} }
Ok(()) Ok(())
} }
@ -547,8 +522,7 @@ impl Server {
broadcast(request.sender_id, connection_ids, |connection_id| { broadcast(request.sender_id, connection_ids, |connection_id| {
self.peer self.peer
.forward_send(request.sender_id, connection_id, request.payload.clone()) .forward_send(request.sender_id, connection_id, request.payload.clone())
}) })?;
.await?;
Ok(()) Ok(())
} }
@ -574,8 +548,7 @@ impl Server {
broadcast(request.sender_id, receiver_ids, |connection_id| { broadcast(request.sender_id, receiver_ids, |connection_id| {
self.peer self.peer
.forward_send(request.sender_id, connection_id, request.payload.clone()) .forward_send(request.sender_id, connection_id, request.payload.clone())
}) })?;
.await?;
Ok(()) Ok(())
} }
@ -590,8 +563,7 @@ impl Server {
broadcast(request.sender_id, receiver_ids, |connection_id| { broadcast(request.sender_id, receiver_ids, |connection_id| {
self.peer self.peer
.forward_send(request.sender_id, connection_id, request.payload.clone()) .forward_send(request.sender_id, connection_id, request.payload.clone())
}) })?;
.await?;
Ok(()) Ok(())
} }
@ -606,8 +578,7 @@ impl Server {
broadcast(request.sender_id, receiver_ids, |connection_id| { broadcast(request.sender_id, receiver_ids, |connection_id| {
self.peer self.peer
.forward_send(request.sender_id, connection_id, request.payload.clone()) .forward_send(request.sender_id, connection_id, request.payload.clone())
}) })?;
.await?;
Ok(()) Ok(())
} }
@ -625,7 +596,7 @@ impl Server {
.peer .peer
.forward_request(request.sender_id, host_connection_id, request.payload) .forward_request(request.sender_id, host_connection_id, request.payload)
.await?; .await?;
self.peer.respond(receipt, response).await?; self.peer.respond(receipt, response)?;
Ok(()) Ok(())
} }
@ -643,7 +614,7 @@ impl Server {
.peer .peer
.forward_request(request.sender_id, host_connection_id, request.payload) .forward_request(request.sender_id, host_connection_id, request.payload)
.await?; .await?;
self.peer.respond(receipt, response).await?; self.peer.respond(receipt, response)?;
Ok(()) Ok(())
} }
@ -657,8 +628,7 @@ impl Server {
.ok_or_else(|| anyhow!(NO_SUCH_PROJECT))? .ok_or_else(|| anyhow!(NO_SUCH_PROJECT))?
.host_connection_id; .host_connection_id;
self.peer self.peer
.forward_send(request.sender_id, host_connection_id, request.payload) .forward_send(request.sender_id, host_connection_id, request.payload)?;
.await?;
Ok(()) Ok(())
} }
@ -686,16 +656,12 @@ impl Server {
broadcast(host, guests, |conn_id| { broadcast(host, guests, |conn_id| {
let response = response.clone(); let response = response.clone();
let peer = &self.peer; if conn_id == sender {
async move { self.peer.respond(receipt, response)
if conn_id == sender { } else {
peer.respond(receipt, response).await self.peer.forward_send(host, conn_id, response)
} else {
peer.forward_send(host, conn_id, response).await
}
} }
}) })?;
.await?;
Ok(()) Ok(())
} }
@ -719,7 +685,7 @@ impl Server {
.peer .peer
.forward_request(sender, host, request.payload.clone()) .forward_request(sender, host, request.payload.clone())
.await?; .await?;
self.peer.respond(receipt, response).await?; self.peer.respond(receipt, response)?;
Ok(()) Ok(())
} }
@ -743,8 +709,7 @@ impl Server {
.peer .peer
.forward_request(sender, host, request.payload.clone()) .forward_request(sender, host, request.payload.clone())
.await?; .await?;
self.peer.respond(receipt, response).await?; self.peer.respond(receipt, response)?;
Ok(()) Ok(())
} }
@ -767,8 +732,7 @@ impl Server {
.peer .peer
.forward_request(sender, host, request.payload.clone()) .forward_request(sender, host, request.payload.clone())
.await?; .await?;
self.peer.respond(receipt, response).await?; self.peer.respond(receipt, response)?;
Ok(()) Ok(())
} }
@ -783,9 +747,8 @@ impl Server {
broadcast(request.sender_id, receiver_ids, |connection_id| { broadcast(request.sender_id, receiver_ids, |connection_id| {
self.peer self.peer
.forward_send(request.sender_id, connection_id, request.payload.clone()) .forward_send(request.sender_id, connection_id, request.payload.clone())
}) })?;
.await?; self.peer.respond(request.receipt(), proto::Ack {})?;
self.peer.respond(request.receipt(), proto::Ack {}).await?;
Ok(()) Ok(())
} }
@ -800,8 +763,7 @@ impl Server {
broadcast(request.sender_id, receiver_ids, |connection_id| { broadcast(request.sender_id, receiver_ids, |connection_id| {
self.peer self.peer
.forward_send(request.sender_id, connection_id, request.payload.clone()) .forward_send(request.sender_id, connection_id, request.payload.clone())
}) })?;
.await?;
Ok(()) Ok(())
} }
@ -816,8 +778,7 @@ impl Server {
broadcast(request.sender_id, receiver_ids, |connection_id| { broadcast(request.sender_id, receiver_ids, |connection_id| {
self.peer self.peer
.forward_send(request.sender_id, connection_id, request.payload.clone()) .forward_send(request.sender_id, connection_id, request.payload.clone())
}) })?;
.await?;
Ok(()) Ok(())
} }
@ -832,8 +793,7 @@ impl Server {
broadcast(request.sender_id, receiver_ids, |connection_id| { broadcast(request.sender_id, receiver_ids, |connection_id| {
self.peer self.peer
.forward_send(request.sender_id, connection_id, request.payload.clone()) .forward_send(request.sender_id, connection_id, request.payload.clone())
}) })?;
.await?;
Ok(()) Ok(())
} }
@ -843,20 +803,18 @@ impl Server {
) -> tide::Result<()> { ) -> tide::Result<()> {
let user_id = self.state().user_id_for_connection(request.sender_id)?; let user_id = self.state().user_id_for_connection(request.sender_id)?;
let channels = self.app_state.db.get_accessible_channels(user_id).await?; let channels = self.app_state.db.get_accessible_channels(user_id).await?;
self.peer self.peer.respond(
.respond( request.receipt(),
request.receipt(), proto::GetChannelsResponse {
proto::GetChannelsResponse { channels: channels
channels: channels .into_iter()
.into_iter() .map(|chan| proto::Channel {
.map(|chan| proto::Channel { id: chan.id.to_proto(),
id: chan.id.to_proto(), name: chan.name,
name: chan.name, })
}) .collect(),
.collect(), },
}, )?;
)
.await?;
Ok(()) Ok(())
} }
@ -879,34 +837,30 @@ impl Server {
}) })
.collect(); .collect();
self.peer self.peer
.respond(receipt, proto::GetUsersResponse { users }) .respond(receipt, proto::GetUsersResponse { users })?;
.await?;
Ok(()) Ok(())
} }
async fn update_contacts_for_users<'a>( fn update_contacts_for_users<'a>(
self: &Arc<Server>, self: &Arc<Server>,
user_ids: impl IntoIterator<Item = &'a UserId>, user_ids: impl IntoIterator<Item = &'a UserId>,
) -> tide::Result<()> { ) -> anyhow::Result<()> {
let mut send_futures = Vec::new(); let mut result = Ok(());
let state = self.state();
{ for user_id in user_ids {
let state = self.state(); let contacts = state.contacts_for_user(*user_id);
for user_id in user_ids { for connection_id in state.connection_ids_for_user(*user_id) {
let contacts = state.contacts_for_user(*user_id); if let Err(error) = self.peer.send(
for connection_id in state.connection_ids_for_user(*user_id) { connection_id,
send_futures.push(self.peer.send( proto::UpdateContacts {
connection_id, contacts: contacts.clone(),
proto::UpdateContacts { },
contacts: contacts.clone(), ) {
}, result = Err(error);
));
} }
} }
} }
futures::future::try_join_all(send_futures).await?; result
Ok(())
} }
async fn join_channel( async fn join_channel(
@ -939,15 +893,13 @@ impl Server {
nonce: Some(msg.nonce.as_u128().into()), nonce: Some(msg.nonce.as_u128().into()),
}) })
.collect::<Vec<_>>(); .collect::<Vec<_>>();
self.peer self.peer.respond(
.respond( request.receipt(),
request.receipt(), proto::JoinChannelResponse {
proto::JoinChannelResponse { done: messages.len() < MESSAGE_COUNT_PER_PAGE,
done: messages.len() < MESSAGE_COUNT_PER_PAGE, messages,
messages, },
}, )?;
)
.await?;
Ok(()) Ok(())
} }
@ -993,25 +945,21 @@ impl Server {
// Validate the message body. // Validate the message body.
let body = request.payload.body.trim().to_string(); let body = request.payload.body.trim().to_string();
if body.len() > MAX_MESSAGE_LEN { if body.len() > MAX_MESSAGE_LEN {
self.peer self.peer.respond_with_error(
.respond_with_error( receipt,
receipt, proto::Error {
proto::Error { message: "message is too long".to_string(),
message: "message is too long".to_string(), },
}, )?;
)
.await?;
return Ok(()); return Ok(());
} }
if body.is_empty() { if body.is_empty() {
self.peer self.peer.respond_with_error(
.respond_with_error( receipt,
receipt, proto::Error {
proto::Error { message: "message can't be blank".to_string(),
message: "message can't be blank".to_string(), },
}, )?;
)
.await?;
return Ok(()); return Ok(());
} }
@ -1019,14 +967,12 @@ impl Server {
let nonce = if let Some(nonce) = request.payload.nonce { let nonce = if let Some(nonce) = request.payload.nonce {
nonce nonce
} else { } else {
self.peer self.peer.respond_with_error(
.respond_with_error( receipt,
receipt, proto::Error {
proto::Error { message: "nonce can't be blank".to_string(),
message: "nonce can't be blank".to_string(), },
}, )?;
)
.await?;
return Ok(()); return Ok(());
}; };
@ -1051,16 +997,13 @@ impl Server {
message: Some(message.clone()), message: Some(message.clone()),
}, },
) )
}) })?;
.await?; self.peer.respond(
self.peer receipt,
.respond( proto::SendChannelMessageResponse {
receipt, message: Some(message),
proto::SendChannelMessageResponse { },
message: Some(message), )?;
},
)
.await?;
Ok(()) Ok(())
} }
@ -1097,15 +1040,13 @@ impl Server {
nonce: Some(msg.nonce.as_u128().into()), nonce: Some(msg.nonce.as_u128().into()),
}) })
.collect::<Vec<_>>(); .collect::<Vec<_>>();
self.peer self.peer.respond(
.respond( request.receipt(),
request.receipt(), proto::GetChannelMessagesResponse {
proto::GetChannelMessagesResponse { done: messages.len() < MESSAGE_COUNT_PER_PAGE,
done: messages.len() < MESSAGE_COUNT_PER_PAGE, messages,
messages, },
}, )?;
)
.await?;
Ok(()) Ok(())
} }
@ -1118,21 +1059,25 @@ impl Server {
} }
} }
pub async fn broadcast<F, T>( fn broadcast<F>(
sender_id: ConnectionId, sender_id: ConnectionId,
receiver_ids: Vec<ConnectionId>, receiver_ids: Vec<ConnectionId>,
mut f: F, mut f: F,
) -> anyhow::Result<()> ) -> anyhow::Result<()>
where where
F: FnMut(ConnectionId) -> T, F: FnMut(ConnectionId) -> anyhow::Result<()>,
T: Future<Output = anyhow::Result<()>>,
{ {
let futures = receiver_ids let mut result = Ok(());
.into_iter() for receiver_id in receiver_ids {
.filter(|id| *id != sender_id) if receiver_id != sender_id {
.map(|id| f(id)); if let Err(error) = f(receiver_id) {
futures::future::try_join_all(futures).await?; if result.is_ok() {
Ok(()) result = Err(error);
}
}
}
}
result
} }
pub fn add_routes(app: &mut tide::Server<Arc<AppState>>, rpc: &Arc<Peer>) { pub fn add_routes(app: &mut tide::Server<Arc<AppState>>, rpc: &Arc<Peer>) {