use super::{ proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, RequestMessage}, Connection, }; use anyhow::{anyhow, Context, Result}; use collections::HashMap; use futures::{ channel::{mpsc, oneshot}, stream::BoxStream, FutureExt, SinkExt, StreamExt, }; use parking_lot::{Mutex, RwLock}; use serde::{ser::SerializeStruct, Serialize}; use std::sync::atomic::Ordering::SeqCst; use std::{ fmt, future::Future, marker::PhantomData, sync::{ atomic::{self, AtomicU32}, Arc, }, time::Duration, }; use tracing::instrument; #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Debug, Serialize)] pub struct ConnectionId(pub u32); impl fmt::Display for ConnectionId { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { self.0.fmt(f) } } #[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)] pub struct PeerId(pub u32); impl fmt::Display for PeerId { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { self.0.fmt(f) } } pub struct Receipt { pub sender_id: ConnectionId, pub message_id: u32, payload_type: PhantomData, } impl Clone for Receipt { fn clone(&self) -> Self { Self { sender_id: self.sender_id, message_id: self.message_id, payload_type: PhantomData, } } } impl Copy for Receipt {} pub struct TypedEnvelope { pub sender_id: ConnectionId, pub original_sender_id: Option, pub message_id: u32, pub payload: T, } impl TypedEnvelope { pub fn original_sender_id(&self) -> Result { self.original_sender_id .ok_or_else(|| anyhow!("missing original_sender_id")) } } impl TypedEnvelope { pub fn receipt(&self) -> Receipt { Receipt { sender_id: self.sender_id, message_id: self.message_id, payload_type: PhantomData, } } } pub struct Peer { pub connections: RwLock>, next_connection_id: AtomicU32, } #[derive(Clone, Serialize)] pub struct ConnectionState { #[serde(skip)] outgoing_tx: mpsc::UnboundedSender, next_message_id: Arc, #[allow(clippy::type_complexity)] #[serde(skip)] response_channels: Arc)>>>>>, } const KEEPALIVE_INTERVAL: Duration = Duration::from_secs(1); const WRITE_TIMEOUT: Duration = Duration::from_secs(2); pub const RECEIVE_TIMEOUT: Duration = Duration::from_secs(5); impl Peer { pub fn new() -> Arc { Arc::new(Self { connections: Default::default(), next_connection_id: Default::default(), }) } #[instrument(skip_all)] pub async fn add_connection( self: &Arc, connection: Connection, create_timer: F, ) -> ( ConnectionId, impl Future> + Send, BoxStream<'static, Box>, ) where F: Send + Fn(Duration) -> Fut, Fut: Send + Future, Out: Send, { // For outgoing messages, use an unbounded channel so that application code // can always send messages without yielding. For incoming messages, use a // bounded channel so that other peers will receive backpressure if they send // messages faster than this peer can process them. #[cfg(any(test, feature = "test-support"))] const INCOMING_BUFFER_SIZE: usize = 1; #[cfg(not(any(test, feature = "test-support")))] const INCOMING_BUFFER_SIZE: usize = 64; let (mut incoming_tx, incoming_rx) = mpsc::channel(INCOMING_BUFFER_SIZE); let (outgoing_tx, mut outgoing_rx) = mpsc::unbounded(); let connection_id = ConnectionId(self.next_connection_id.fetch_add(1, SeqCst)); let connection_state = ConnectionState { outgoing_tx, next_message_id: Default::default(), response_channels: Arc::new(Mutex::new(Some(Default::default()))), }; let mut writer = MessageStream::new(connection.tx); let mut reader = MessageStream::new(connection.rx); let this = self.clone(); let response_channels = connection_state.response_channels.clone(); let handle_io = async move { tracing::debug!(%connection_id, "handle io future: start"); let _end_connection = util::defer(|| { response_channels.lock().take(); this.connections.write().remove(&connection_id); tracing::debug!(%connection_id, "handle io future: end"); }); // Send messages on this frequency so the connection isn't closed. let keepalive_timer = create_timer(KEEPALIVE_INTERVAL).fuse(); futures::pin_mut!(keepalive_timer); // Disconnect if we don't receive messages at least this frequently. let receive_timeout = create_timer(RECEIVE_TIMEOUT).fuse(); futures::pin_mut!(receive_timeout); loop { tracing::debug!(%connection_id, "outer loop iteration start"); let read_message = reader.read().fuse(); futures::pin_mut!(read_message); loop { tracing::debug!(%connection_id, "inner loop iteration start"); futures::select_biased! { outgoing = outgoing_rx.next().fuse() => match outgoing { Some(outgoing) => { tracing::debug!(%connection_id, "outgoing rpc message: writing"); futures::select_biased! { result = writer.write(outgoing).fuse() => { tracing::debug!(%connection_id, "outgoing rpc message: done writing"); result.context("failed to write RPC message")?; tracing::debug!(%connection_id, "keepalive interval: resetting after sending message"); keepalive_timer.set(create_timer(KEEPALIVE_INTERVAL).fuse()); } _ = create_timer(WRITE_TIMEOUT).fuse() => { tracing::debug!(%connection_id, "outgoing rpc message: writing timed out"); Err(anyhow!("timed out writing message"))?; } } } None => { tracing::debug!(%connection_id, "outgoing rpc message: channel closed"); return Ok(()) }, }, _ = keepalive_timer => { tracing::debug!(%connection_id, "keepalive interval: pinging"); futures::select_biased! { result = writer.write(proto::Message::Ping).fuse() => { tracing::debug!(%connection_id, "keepalive interval: done pinging"); result.context("failed to send keepalive")?; tracing::debug!(%connection_id, "keepalive interval: resetting after pinging"); keepalive_timer.set(create_timer(KEEPALIVE_INTERVAL).fuse()); } _ = create_timer(WRITE_TIMEOUT).fuse() => { tracing::debug!(%connection_id, "keepalive interval: pinging timed out"); Err(anyhow!("timed out sending keepalive"))?; } } } incoming = read_message => { let incoming = incoming.context("error reading rpc message from socket")?; tracing::debug!(%connection_id, "incoming rpc message: received"); tracing::debug!(%connection_id, "receive timeout: resetting"); receive_timeout.set(create_timer(RECEIVE_TIMEOUT).fuse()); if let proto::Message::Envelope(incoming) = incoming { tracing::debug!(%connection_id, "incoming rpc message: processing"); futures::select_biased! { result = incoming_tx.send(incoming).fuse() => match result { Ok(_) => { tracing::debug!(%connection_id, "incoming rpc message: processed"); } Err(_) => { tracing::debug!(%connection_id, "incoming rpc message: channel closed"); return Ok(()) } }, _ = create_timer(WRITE_TIMEOUT).fuse() => { tracing::debug!(%connection_id, "incoming rpc message: processing timed out"); Err(anyhow!("timed out processing incoming message"))? } } } break; }, _ = receive_timeout => { tracing::debug!(%connection_id, "receive timeout: delay between messages too long"); Err(anyhow!("delay between messages too long"))? } } } } }; let response_channels = connection_state.response_channels.clone(); self.connections .write() .insert(connection_id, connection_state); let incoming_rx = incoming_rx.filter_map(move |incoming| { let response_channels = response_channels.clone(); async move { let message_id = incoming.id; tracing::debug!(?incoming, "incoming message future: start"); let _end = util::defer(move || { tracing::debug!( %connection_id, message_id, "incoming message future: end" ); }); if let Some(responding_to) = incoming.responding_to { tracing::debug!( %connection_id, message_id, responding_to, "incoming response: received" ); let channel = response_channels.lock().as_mut()?.remove(&responding_to); if let Some(tx) = channel { let requester_resumed = oneshot::channel(); if let Err(error) = tx.send((incoming, requester_resumed.0)) { tracing::debug!( %connection_id, message_id, responding_to = responding_to, ?error, "incoming response: request future dropped", ); } tracing::debug!( %connection_id, message_id, responding_to, "incoming response: waiting to resume requester" ); let _ = requester_resumed.1.await; tracing::debug!( %connection_id, message_id, responding_to, "incoming response: requester resumed" ); } else { tracing::warn!( %connection_id, message_id, responding_to, "incoming response: unknown request" ); } None } else { tracing::debug!( %connection_id, message_id, "incoming message: received" ); proto::build_typed_envelope(connection_id, incoming).or_else(|| { tracing::error!( %connection_id, message_id, "unable to construct a typed envelope" ); None }) } } }); (connection_id, handle_io, incoming_rx.boxed()) } #[cfg(any(test, feature = "test-support"))] pub async fn add_test_connection( self: &Arc, connection: Connection, executor: Arc, ) -> ( ConnectionId, impl Future> + Send, BoxStream<'static, Box>, ) { let executor = executor.clone(); self.add_connection(connection, move |duration| executor.timer(duration)) .await } pub fn disconnect(&self, connection_id: ConnectionId) { self.connections.write().remove(&connection_id); } pub fn reset(&self) { self.connections.write().clear(); } pub fn request( &self, receiver_id: ConnectionId, request: T, ) -> impl Future> { self.request_internal(None, receiver_id, request) } pub fn forward_request( &self, sender_id: ConnectionId, receiver_id: ConnectionId, request: T, ) -> impl Future> { self.request_internal(Some(sender_id), receiver_id, request) } pub fn request_internal( &self, original_sender_id: Option, receiver_id: ConnectionId, request: T, ) -> impl Future> { let (tx, rx) = oneshot::channel(); let send = self.connection_state(receiver_id).and_then(|connection| { let message_id = connection.next_message_id.fetch_add(1, SeqCst); connection .response_channels .lock() .as_mut() .ok_or_else(|| anyhow!("connection was closed"))? .insert(message_id, tx); connection .outgoing_tx .unbounded_send(proto::Message::Envelope(request.into_envelope( message_id, None, original_sender_id.map(|id| id.0), ))) .map_err(|_| anyhow!("connection was closed"))?; Ok(()) }); async move { send?; let (response, _barrier) = rx.await.map_err(|_| anyhow!("connection was closed"))?; if let Some(proto::envelope::Payload::Error(error)) = &response.payload { Err(anyhow!("RPC request failed - {}", error.message)) } else { T::Response::from_envelope(response) .ok_or_else(|| anyhow!("received response of the wrong type")) } } } pub fn send(&self, receiver_id: ConnectionId, message: T) -> Result<()> { let connection = self.connection_state(receiver_id)?; let message_id = connection .next_message_id .fetch_add(1, atomic::Ordering::SeqCst); connection .outgoing_tx .unbounded_send(proto::Message::Envelope( message.into_envelope(message_id, None, None), ))?; Ok(()) } pub fn forward_send( &self, sender_id: ConnectionId, receiver_id: ConnectionId, message: T, ) -> Result<()> { let connection = self.connection_state(receiver_id)?; let message_id = connection .next_message_id .fetch_add(1, atomic::Ordering::SeqCst); connection .outgoing_tx .unbounded_send(proto::Message::Envelope(message.into_envelope( message_id, None, Some(sender_id.0), )))?; Ok(()) } pub fn respond( &self, receipt: Receipt, response: T::Response, ) -> Result<()> { let connection = self.connection_state(receipt.sender_id)?; let message_id = connection .next_message_id .fetch_add(1, atomic::Ordering::SeqCst); connection .outgoing_tx .unbounded_send(proto::Message::Envelope(response.into_envelope( message_id, Some(receipt.message_id), None, )))?; Ok(()) } pub fn respond_with_error( &self, receipt: Receipt, response: proto::Error, ) -> Result<()> { let connection = self.connection_state(receipt.sender_id)?; let message_id = connection .next_message_id .fetch_add(1, atomic::Ordering::SeqCst); connection .outgoing_tx .unbounded_send(proto::Message::Envelope(response.into_envelope( message_id, Some(receipt.message_id), None, )))?; Ok(()) } fn connection_state(&self, connection_id: ConnectionId) -> Result { let connections = self.connections.read(); let connection = connections .get(&connection_id) .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?; Ok(connection.clone()) } } impl Serialize for Peer { fn serialize(&self, serializer: S) -> Result where S: serde::Serializer, { let mut state = serializer.serialize_struct("Peer", 2)?; state.serialize_field("connections", &*self.connections.read())?; state.end() } } #[cfg(test)] mod tests { use super::*; use crate::TypedEnvelope; use async_tungstenite::tungstenite::Message as WebSocketMessage; use gpui::TestAppContext; #[ctor::ctor] fn init_logger() { if std::env::var("RUST_LOG").is_ok() { env_logger::init(); } } #[gpui::test(iterations = 50)] async fn test_request_response(cx: &mut TestAppContext) { let executor = cx.foreground(); // create 2 clients connected to 1 server let server = Peer::new(); let client1 = Peer::new(); let client2 = Peer::new(); let (client1_to_server_conn, server_to_client_1_conn, _kill) = Connection::in_memory(cx.background()); let (client1_conn_id, io_task1, client1_incoming) = client1 .add_test_connection(client1_to_server_conn, cx.background()) .await; let (_, io_task2, server_incoming1) = server .add_test_connection(server_to_client_1_conn, cx.background()) .await; let (client2_to_server_conn, server_to_client_2_conn, _kill) = Connection::in_memory(cx.background()); let (client2_conn_id, io_task3, client2_incoming) = client2 .add_test_connection(client2_to_server_conn, cx.background()) .await; let (_, io_task4, server_incoming2) = server .add_test_connection(server_to_client_2_conn, cx.background()) .await; executor.spawn(io_task1).detach(); executor.spawn(io_task2).detach(); executor.spawn(io_task3).detach(); executor.spawn(io_task4).detach(); executor .spawn(handle_messages(server_incoming1, server.clone())) .detach(); executor .spawn(handle_messages(client1_incoming, client1.clone())) .detach(); executor .spawn(handle_messages(server_incoming2, server.clone())) .detach(); executor .spawn(handle_messages(client2_incoming, client2.clone())) .detach(); assert_eq!( client1 .request(client1_conn_id, proto::Ping {},) .await .unwrap(), proto::Ack {} ); assert_eq!( client2 .request(client2_conn_id, proto::Ping {},) .await .unwrap(), proto::Ack {} ); assert_eq!( client1 .request(client1_conn_id, proto::Test { id: 1 },) .await .unwrap(), proto::Test { id: 1 } ); assert_eq!( client2 .request(client2_conn_id, proto::Test { id: 2 }) .await .unwrap(), proto::Test { id: 2 } ); client1.disconnect(client1_conn_id); client2.disconnect(client1_conn_id); async fn handle_messages( mut messages: BoxStream<'static, Box>, peer: Arc, ) -> Result<()> { while let Some(envelope) = messages.next().await { let envelope = envelope.into_any(); if let Some(envelope) = envelope.downcast_ref::>() { let receipt = envelope.receipt(); peer.respond(receipt, proto::Ack {})? } else if let Some(envelope) = envelope.downcast_ref::>() { peer.respond(envelope.receipt(), envelope.payload.clone())? } else { panic!("unknown message type"); } } Ok(()) } } #[gpui::test(iterations = 50)] async fn test_order_of_response_and_incoming(cx: &mut TestAppContext) { let executor = cx.foreground(); let server = Peer::new(); let client = Peer::new(); let (client_to_server_conn, server_to_client_conn, _kill) = Connection::in_memory(cx.background()); let (client_to_server_conn_id, io_task1, mut client_incoming) = client .add_test_connection(client_to_server_conn, cx.background()) .await; let (server_to_client_conn_id, io_task2, mut server_incoming) = server .add_test_connection(server_to_client_conn, cx.background()) .await; executor.spawn(io_task1).detach(); executor.spawn(io_task2).detach(); executor .spawn(async move { let request = server_incoming .next() .await .unwrap() .into_any() .downcast::>() .unwrap(); server .send( server_to_client_conn_id, proto::Error { message: "message 1".to_string(), }, ) .unwrap(); server .send( server_to_client_conn_id, proto::Error { message: "message 2".to_string(), }, ) .unwrap(); server.respond(request.receipt(), proto::Ack {}).unwrap(); // Prevent the connection from being dropped server_incoming.next().await; }) .detach(); let events = Arc::new(Mutex::new(Vec::new())); let response = client.request(client_to_server_conn_id, proto::Ping {}); let response_task = executor.spawn({ let events = events.clone(); async move { response.await.unwrap(); events.lock().push("response".to_string()); } }); executor .spawn({ let events = events.clone(); async move { let incoming1 = client_incoming .next() .await .unwrap() .into_any() .downcast::>() .unwrap(); events.lock().push(incoming1.payload.message); let incoming2 = client_incoming .next() .await .unwrap() .into_any() .downcast::>() .unwrap(); events.lock().push(incoming2.payload.message); // Prevent the connection from being dropped client_incoming.next().await; } }) .detach(); response_task.await; assert_eq!( &*events.lock(), &[ "message 1".to_string(), "message 2".to_string(), "response".to_string() ] ); } #[gpui::test(iterations = 50)] async fn test_dropping_request_before_completion(cx: &mut TestAppContext) { let executor = cx.foreground(); let server = Peer::new(); let client = Peer::new(); let (client_to_server_conn, server_to_client_conn, _kill) = Connection::in_memory(cx.background()); let (client_to_server_conn_id, io_task1, mut client_incoming) = client .add_test_connection(client_to_server_conn, cx.background()) .await; let (server_to_client_conn_id, io_task2, mut server_incoming) = server .add_test_connection(server_to_client_conn, cx.background()) .await; executor.spawn(io_task1).detach(); executor.spawn(io_task2).detach(); executor .spawn(async move { let request1 = server_incoming .next() .await .unwrap() .into_any() .downcast::>() .unwrap(); let request2 = server_incoming .next() .await .unwrap() .into_any() .downcast::>() .unwrap(); server .send( server_to_client_conn_id, proto::Error { message: "message 1".to_string(), }, ) .unwrap(); server .send( server_to_client_conn_id, proto::Error { message: "message 2".to_string(), }, ) .unwrap(); server.respond(request1.receipt(), proto::Ack {}).unwrap(); server.respond(request2.receipt(), proto::Ack {}).unwrap(); // Prevent the connection from being dropped server_incoming.next().await; }) .detach(); let events = Arc::new(Mutex::new(Vec::new())); let request1 = client.request(client_to_server_conn_id, proto::Ping {}); let request1_task = executor.spawn(request1); let request2 = client.request(client_to_server_conn_id, proto::Ping {}); let request2_task = executor.spawn({ let events = events.clone(); async move { request2.await.unwrap(); events.lock().push("response 2".to_string()); } }); executor .spawn({ let events = events.clone(); async move { let incoming1 = client_incoming .next() .await .unwrap() .into_any() .downcast::>() .unwrap(); events.lock().push(incoming1.payload.message); let incoming2 = client_incoming .next() .await .unwrap() .into_any() .downcast::>() .unwrap(); events.lock().push(incoming2.payload.message); // Prevent the connection from being dropped client_incoming.next().await; } }) .detach(); // Allow the request to make some progress before dropping it. cx.background().simulate_random_delay().await; drop(request1_task); request2_task.await; assert_eq!( &*events.lock(), &[ "message 1".to_string(), "message 2".to_string(), "response 2".to_string() ] ); } #[gpui::test(iterations = 50)] async fn test_disconnect(cx: &mut TestAppContext) { let executor = cx.foreground(); let (client_conn, mut server_conn, _kill) = Connection::in_memory(cx.background()); let client = Peer::new(); let (connection_id, io_handler, mut incoming) = client .add_test_connection(client_conn, cx.background()) .await; let (io_ended_tx, io_ended_rx) = oneshot::channel(); executor .spawn(async move { io_handler.await.ok(); io_ended_tx.send(()).unwrap(); }) .detach(); let (messages_ended_tx, messages_ended_rx) = oneshot::channel(); executor .spawn(async move { incoming.next().await; messages_ended_tx.send(()).unwrap(); }) .detach(); client.disconnect(connection_id); let _ = io_ended_rx.await; let _ = messages_ended_rx.await; assert!(server_conn .send(WebSocketMessage::Binary(vec![])) .await .is_err()); } #[gpui::test(iterations = 50)] async fn test_io_error(cx: &mut TestAppContext) { let executor = cx.foreground(); let (client_conn, mut server_conn, _kill) = Connection::in_memory(cx.background()); let client = Peer::new(); let (connection_id, io_handler, mut incoming) = client .add_test_connection(client_conn, cx.background()) .await; executor.spawn(io_handler).detach(); executor .spawn(async move { incoming.next().await }) .detach(); let response = executor.spawn(client.request(connection_id, proto::Ping {})); let _request = server_conn.rx.next().await.unwrap().unwrap(); drop(server_conn); assert_eq!( response.await.unwrap_err().to_string(), "connection was closed" ); } }