use super::{ConnectionId, PeerId, TypedEnvelope}; use anyhow::Result; use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage}; use futures::{SinkExt as _, StreamExt as _}; use prost::Message; use std::any::{Any, TypeId}; use std::{ io, time::{Duration, SystemTime, UNIX_EPOCH}, }; include!(concat!(env!("OUT_DIR"), "/zed.messages.rs")); pub trait EnvelopedMessage: Clone + Sized + Send + Sync + 'static { const NAME: &'static str; fn into_envelope( self, id: u32, responding_to: Option, original_sender_id: Option, ) -> Envelope; fn from_envelope(envelope: Envelope) -> Option; } pub trait EntityMessage: EnvelopedMessage { fn remote_entity_id(&self) -> u64; } pub trait RequestMessage: EnvelopedMessage { type Response: EnvelopedMessage; } pub trait AnyTypedEnvelope: 'static + Send + Sync { fn payload_type_id(&self) -> TypeId; fn payload_type_name(&self) -> &'static str; fn as_any(&self) -> &dyn Any; fn into_any(self: Box) -> Box; } impl AnyTypedEnvelope for TypedEnvelope { fn payload_type_id(&self) -> TypeId { TypeId::of::() } fn payload_type_name(&self) -> &'static str { T::NAME } fn as_any(&self) -> &dyn Any { self } fn into_any(self: Box) -> Box { self } } macro_rules! messages { ($($name:ident),* $(,)?) => { pub fn build_typed_envelope(sender_id: ConnectionId, envelope: Envelope) -> Option> { match envelope.payload { $(Some(envelope::Payload::$name(payload)) => { Some(Box::new(TypedEnvelope { sender_id, original_sender_id: envelope.original_sender_id.map(PeerId), message_id: envelope.id, payload, })) }, )* _ => None } } $( impl EnvelopedMessage for $name { const NAME: &'static str = std::stringify!($name); fn into_envelope( self, id: u32, responding_to: Option, original_sender_id: Option, ) -> Envelope { Envelope { id, responding_to, original_sender_id, payload: Some(envelope::Payload::$name(self)), } } fn from_envelope(envelope: Envelope) -> Option { if let Some(envelope::Payload::$name(msg)) = envelope.payload { Some(msg) } else { None } } } )* }; } macro_rules! request_messages { ($(($request_name:ident, $response_name:ident)),* $(,)?) => { $(impl RequestMessage for $request_name { type Response = $response_name; })* }; } macro_rules! entity_messages { ($id_field:ident, $($name:ident),* $(,)?) => { $(impl EntityMessage for $name { fn remote_entity_id(&self) -> u64 { self.$id_field } })* }; } messages!( AddPeer, BufferSaved, ChannelMessageSent, CloseBuffer, CloseWorktree, GetChannels, GetChannelsResponse, GetUsers, GetUsersResponse, JoinChannel, JoinChannelResponse, LeaveChannel, OpenBuffer, OpenBufferResponse, OpenWorktree, OpenWorktreeResponse, Ping, Pong, RemovePeer, SaveBuffer, SendChannelMessage, SendChannelMessageResponse, ShareWorktree, ShareWorktreeResponse, UpdateBuffer, UpdateWorktree, ); request_messages!( (GetChannels, GetChannelsResponse), (GetUsers, GetUsersResponse), (JoinChannel, JoinChannelResponse), (OpenBuffer, OpenBufferResponse), (OpenWorktree, OpenWorktreeResponse), (Ping, Pong), (SaveBuffer, BufferSaved), (ShareWorktree, ShareWorktreeResponse), (SendChannelMessage, SendChannelMessageResponse), ); entity_messages!( worktree_id, AddPeer, BufferSaved, CloseBuffer, CloseWorktree, OpenBuffer, OpenWorktree, RemovePeer, SaveBuffer, UpdateBuffer, UpdateWorktree, ); entity_messages!(channel_id, ChannelMessageSent); /// A stream of protobuf messages. pub struct MessageStream { stream: S, } impl MessageStream { pub fn new(stream: S) -> Self { Self { stream } } pub fn inner_mut(&mut self) -> &mut S { &mut self.stream } } impl MessageStream where S: futures::Sink + Unpin, { /// Write a given protobuf message to the stream. pub async fn write_message(&mut self, message: &Envelope) -> Result<(), WebSocketError> { let mut buffer = Vec::with_capacity(message.encoded_len()); message .encode(&mut buffer) .map_err(|err| io::Error::from(err))?; self.stream.send(WebSocketMessage::Binary(buffer)).await?; Ok(()) } } impl MessageStream where S: futures::Stream> + Unpin, { /// Read a protobuf message of the given type from the stream. pub async fn read_message(&mut self) -> Result { while let Some(bytes) = self.stream.next().await { match bytes? { WebSocketMessage::Binary(bytes) => { let envelope = Envelope::decode(bytes.as_slice()).map_err(io::Error::from)?; return Ok(envelope); } WebSocketMessage::Close(_) => break, _ => {} } } Err(WebSocketError::ConnectionClosed) } } impl Into for Timestamp { fn into(self) -> SystemTime { UNIX_EPOCH .checked_add(Duration::new(self.seconds, self.nanos)) .unwrap() } } impl From for Timestamp { fn from(time: SystemTime) -> Self { let duration = time.duration_since(UNIX_EPOCH).unwrap(); Self { seconds: duration.as_secs(), nanos: duration.subsec_nanos(), } } } #[cfg(test)] mod tests { use super::*; use crate::test; #[test] fn test_round_trip_message() { smol::block_on(async { let stream = test::Channel::new(); let message1 = Ping { id: 5 }.into_envelope(3, None, None); let message2 = OpenBuffer { worktree_id: 0, path: "some/path".to_string(), } .into_envelope(5, None, None); let mut message_stream = MessageStream::new(stream); message_stream.write_message(&message1).await.unwrap(); message_stream.write_message(&message2).await.unwrap(); let decoded_message1 = message_stream.read_message().await.unwrap(); let decoded_message2 = message_stream.read_message().await.unwrap(); assert_eq!(decoded_message1, message1); assert_eq!(decoded_message2, message2); }); } }