Start on a peer2 module with an alternative implementation

This commit is contained in:
Nathan Sobo 2021-08-18 21:59:13 -06:00
parent 3f5db7284d
commit 9336c02867
5 changed files with 520 additions and 35 deletions

View file

@ -6,35 +6,45 @@ message Envelope {
optional uint32 responding_to = 2;
optional uint32 original_sender_id = 3;
oneof payload {
Auth auth = 4;
AuthResponse auth_response = 5;
ShareWorktree share_worktree = 6;
ShareWorktreeResponse share_worktree_response = 7;
OpenWorktree open_worktree = 8;
OpenWorktreeResponse open_worktree_response = 9;
UpdateWorktree update_worktree = 10;
CloseWorktree close_worktree = 11;
OpenBuffer open_buffer = 12;
OpenBufferResponse open_buffer_response = 13;
CloseBuffer close_buffer = 14;
UpdateBuffer update_buffer = 15;
SaveBuffer save_buffer = 16;
BufferSaved buffer_saved = 17;
AddPeer add_peer = 18;
RemovePeer remove_peer = 19;
GetChannels get_channels = 20;
GetChannelsResponse get_channels_response = 21;
GetUsers get_users = 22;
GetUsersResponse get_users_response = 23;
JoinChannel join_channel = 24;
JoinChannelResponse join_channel_response = 25;
SendChannelMessage send_channel_message = 26;
ChannelMessageSent channel_message_sent = 27;
Ping ping = 4;
Pong pong = 5;
Auth auth = 6;
AuthResponse auth_response = 7;
ShareWorktree share_worktree = 8;
ShareWorktreeResponse share_worktree_response = 9;
OpenWorktree open_worktree = 10;
OpenWorktreeResponse open_worktree_response = 11;
UpdateWorktree update_worktree = 12;
CloseWorktree close_worktree = 13;
OpenBuffer open_buffer = 14;
OpenBufferResponse open_buffer_response = 15;
CloseBuffer close_buffer = 16;
UpdateBuffer update_buffer = 17;
SaveBuffer save_buffer = 18;
BufferSaved buffer_saved = 19;
AddPeer add_peer = 20;
RemovePeer remove_peer = 21;
GetChannels get_channels = 22;
GetChannelsResponse get_channels_response = 23;
GetUsers get_users = 24;
GetUsersResponse get_users_response = 25;
JoinChannel join_channel = 26;
JoinChannelResponse join_channel_response = 27;
SendChannelMessage send_channel_message = 28;
ChannelMessageSent channel_message_sent = 29;
}
}
// Messages
message Ping {
int32 id = 1;
}
message Pong {
int32 id = 2;
}
message Auth {
int32 user_id = 1;
string access_token = 2;

View file

@ -1,5 +1,6 @@
pub mod auth;
mod peer;
mod peer2;
pub mod proto;
#[cfg(any(test, feature = "test-support"))]
pub mod test;

View file

@ -38,8 +38,8 @@ type ForegroundMessageHandler =
Box<dyn Fn(&mut Option<proto::Envelope>, ConnectionId) -> Option<LocalBoxFuture<'static, ()>>>;
pub struct Receipt<T> {
sender_id: ConnectionId,
message_id: u32,
pub sender_id: ConnectionId,
pub message_id: u32,
payload_type: PhantomData<T>,
}
@ -172,7 +172,7 @@ impl Peer {
} else {
router.handle(connection_id, envelope.clone()).await;
if let Some(envelope) = proto::build_typed_envelope(connection_id, envelope) {
broadcast_incoming_messages.send(envelope).await.ok();
broadcast_incoming_messages.send(Arc::from(envelope)).await.ok();
} else {
log::error!("unable to construct a typed envelope");
}

470
zrpc/src/peer2.rs Normal file
View file

@ -0,0 +1,470 @@
use crate::{
proto::{self, EnvelopedMessage, MessageStream, RequestMessage},
ConnectionId, PeerId, Receipt,
};
use anyhow::{anyhow, Context, Result};
use async_lock::{Mutex, RwLock};
use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
use futures::{FutureExt, StreamExt};
use postage::{
mpsc,
prelude::{Sink as _, Stream as _},
};
use std::{
any::Any,
collections::HashMap,
future::Future,
sync::{
atomic::{self, AtomicU32},
Arc,
},
};
pub struct Peer {
connections: RwLock<HashMap<ConnectionId, Connection>>,
next_connection_id: AtomicU32,
}
#[derive(Clone)]
struct Connection {
outgoing_tx: mpsc::Sender<proto::Envelope>,
next_message_id: Arc<AtomicU32>,
response_channels: Arc<Mutex<HashMap<u32, mpsc::Sender<proto::Envelope>>>>,
}
impl Peer {
pub fn new() -> Arc<Self> {
Arc::new(Self {
connections: Default::default(),
next_connection_id: Default::default(),
})
}
pub async fn add_connection<Conn>(
self: &Arc<Self>,
conn: Conn,
) -> (
ConnectionId,
impl Future<Output = anyhow::Result<()>> + Send,
mpsc::Receiver<Box<dyn Any + Sync + Send>>,
)
where
Conn: futures::Sink<WebSocketMessage, Error = WebSocketError>
+ futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>
+ Send
+ Unpin,
{
let (tx, rx) = conn.split();
let connection_id = ConnectionId(
self.next_connection_id
.fetch_add(1, atomic::Ordering::SeqCst),
);
let (mut incoming_tx, incoming_rx) = mpsc::channel(64);
let (outgoing_tx, mut outgoing_rx) = mpsc::channel(64);
let connection = Connection {
outgoing_tx,
next_message_id: Default::default(),
response_channels: Default::default(),
};
let mut writer = MessageStream::new(tx);
let mut reader = MessageStream::new(rx);
let response_channels = connection.response_channels.clone();
let handle_io = async move {
loop {
let read_message = reader.read_message().fuse();
futures::pin_mut!(read_message);
loop {
futures::select_biased! {
incoming = read_message => match incoming {
Ok(incoming) => {
if let Some(responding_to) = incoming.responding_to {
let channel = response_channels.lock().await.remove(&responding_to);
if let Some(mut tx) = channel {
tx.send(incoming).await.ok();
} else {
log::warn!("received RPC response to unknown request {}", responding_to);
}
} else {
if let Some(envelope) = proto::build_typed_envelope(connection_id, incoming) {
if incoming_tx.send(envelope).await.is_err() {
response_channels.lock().await.clear();
return Ok(())
}
} else {
log::error!("unable to construct a typed envelope");
}
}
break;
}
Err(error) => {
response_channels.lock().await.clear();
Err(error).context("received invalid RPC message")?;
}
},
outgoing = outgoing_rx.recv().fuse() => match outgoing {
Some(outgoing) => {
if let Err(result) = writer.write_message(&outgoing).await {
response_channels.lock().await.clear();
Err(result).context("failed to write RPC message")?;
}
}
None => {
response_channels.lock().await.clear();
return Ok(())
}
}
}
}
}
};
self.connections
.write()
.await
.insert(connection_id, connection);
(connection_id, handle_io, incoming_rx)
}
pub async fn disconnect(&self, connection_id: ConnectionId) {
self.connections.write().await.remove(&connection_id);
}
pub async fn reset(&self) {
self.connections.write().await.clear();
}
pub fn request<T: RequestMessage>(
self: &Arc<Self>,
receiver_id: ConnectionId,
request: T,
) -> impl Future<Output = Result<T::Response>> {
self.request_internal(None, receiver_id, request)
}
pub fn forward_request<T: RequestMessage>(
self: &Arc<Self>,
sender_id: ConnectionId,
receiver_id: ConnectionId,
request: T,
) -> impl Future<Output = Result<T::Response>> {
self.request_internal(Some(sender_id), receiver_id, request)
}
pub fn request_internal<T: RequestMessage>(
self: &Arc<Self>,
original_sender_id: Option<ConnectionId>,
receiver_id: ConnectionId,
request: T,
) -> impl Future<Output = Result<T::Response>> {
let this = self.clone();
let (tx, mut rx) = mpsc::channel(1);
async move {
let mut connection = this.connection(receiver_id).await?;
let message_id = connection
.next_message_id
.fetch_add(1, atomic::Ordering::SeqCst);
connection
.response_channels
.lock()
.await
.insert(message_id, tx);
connection
.outgoing_tx
.send(request.into_envelope(message_id, None, original_sender_id.map(|id| id.0)))
.await
.map_err(|_| anyhow!("connection was closed"))?;
let response = rx
.recv()
.await
.ok_or_else(|| anyhow!("connection was closed"))?;
T::Response::from_envelope(response)
.ok_or_else(|| anyhow!("received response of the wrong type"))
}
}
pub fn send<T: EnvelopedMessage>(
self: &Arc<Self>,
receiver_id: ConnectionId,
message: T,
) -> impl Future<Output = Result<()>> {
let this = self.clone();
async move {
let mut connection = this.connection(receiver_id).await?;
let message_id = connection
.next_message_id
.fetch_add(1, atomic::Ordering::SeqCst);
connection
.outgoing_tx
.send(message.into_envelope(message_id, None, None))
.await?;
Ok(())
}
}
pub fn forward_send<T: EnvelopedMessage>(
self: &Arc<Self>,
sender_id: ConnectionId,
receiver_id: ConnectionId,
message: T,
) -> impl Future<Output = Result<()>> {
let this = self.clone();
async move {
let mut connection = this.connection(receiver_id).await?;
let message_id = connection
.next_message_id
.fetch_add(1, atomic::Ordering::SeqCst);
connection
.outgoing_tx
.send(message.into_envelope(message_id, None, Some(sender_id.0)))
.await?;
Ok(())
}
}
pub fn respond<T: RequestMessage>(
self: &Arc<Self>,
receipt: Receipt<T>,
response: T::Response,
) -> impl Future<Output = Result<()>> {
let this = self.clone();
async move {
let mut connection = this.connection(receipt.sender_id).await?;
let message_id = connection
.next_message_id
.fetch_add(1, atomic::Ordering::SeqCst);
connection
.outgoing_tx
.send(response.into_envelope(message_id, Some(receipt.message_id), None))
.await?;
Ok(())
}
}
fn connection(
self: &Arc<Self>,
connection_id: ConnectionId,
) -> impl Future<Output = Result<Connection>> {
let this = self.clone();
async move {
let connections = this.connections.read().await;
let connection = connections
.get(&connection_id)
.ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
Ok(connection.clone())
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{test, TypedEnvelope};
#[test]
fn test_request_response() {
smol::block_on(async move {
// create 2 clients connected to 1 server
let server = Peer::new();
let client1 = Peer::new();
let client2 = Peer::new();
let (client1_to_server_conn, server_to_client_1_conn) = test::Channel::bidirectional();
let (client1_conn_id, io_task1, _) =
client1.add_connection(client1_to_server_conn).await;
let (_, io_task2, incoming1) = server.add_connection(server_to_client_1_conn).await;
let (client2_to_server_conn, server_to_client_2_conn) = test::Channel::bidirectional();
let (client2_conn_id, io_task3, _) =
client2.add_connection(client2_to_server_conn).await;
let (_, io_task4, incoming2) = server.add_connection(server_to_client_2_conn).await;
smol::spawn(io_task1).detach();
smol::spawn(io_task2).detach();
smol::spawn(io_task3).detach();
smol::spawn(io_task4).detach();
smol::spawn(handle_messages(incoming1, server.clone())).detach();
smol::spawn(handle_messages(incoming2, server.clone())).detach();
assert_eq!(
client1
.request(client1_conn_id, proto::Ping { id: 1 },)
.await
.unwrap(),
proto::Pong { id: 1 }
);
assert_eq!(
client2
.request(client2_conn_id, proto::Ping { id: 2 },)
.await
.unwrap(),
proto::Pong { id: 2 }
);
assert_eq!(
client1
.request(
client1_conn_id,
proto::OpenBuffer {
worktree_id: 1,
path: "path/one".to_string(),
},
)
.await
.unwrap(),
proto::OpenBufferResponse {
buffer: Some(proto::Buffer {
id: 101,
content: "path/one content".to_string(),
history: vec![],
selections: vec![],
}),
}
);
assert_eq!(
client2
.request(
client2_conn_id,
proto::OpenBuffer {
worktree_id: 2,
path: "path/two".to_string(),
},
)
.await
.unwrap(),
proto::OpenBufferResponse {
buffer: Some(proto::Buffer {
id: 102,
content: "path/two content".to_string(),
history: vec![],
selections: vec![],
}),
}
);
client1.disconnect(client1_conn_id).await;
client2.disconnect(client1_conn_id).await;
async fn handle_messages(
mut messages: mpsc::Receiver<Box<dyn Any + Sync + Send>>,
peer: Arc<Peer>,
) -> Result<()> {
while let Some(envelope) = messages.next().await {
if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Ping>>() {
let receipt = envelope.receipt();
peer.respond(
receipt,
proto::Pong {
id: envelope.payload.id,
},
)
.await?
} else if let Some(envelope) =
envelope.downcast_ref::<TypedEnvelope<proto::OpenBuffer>>()
{
let message = &envelope.payload;
let receipt = envelope.receipt();
let response = match message.path.as_str() {
"path/one" => {
assert_eq!(message.worktree_id, 1);
proto::OpenBufferResponse {
buffer: Some(proto::Buffer {
id: 101,
content: "path/one content".to_string(),
history: vec![],
selections: vec![],
}),
}
}
"path/two" => {
assert_eq!(message.worktree_id, 2);
proto::OpenBufferResponse {
buffer: Some(proto::Buffer {
id: 102,
content: "path/two content".to_string(),
history: vec![],
selections: vec![],
}),
}
}
_ => {
panic!("unexpected path {}", message.path);
}
};
peer.respond(receipt, response).await?
} else {
panic!("unknown message type");
}
}
Ok(())
}
});
}
#[test]
fn test_disconnect() {
smol::block_on(async move {
let (client_conn, mut server_conn) = test::Channel::bidirectional();
let client = Peer::new();
let (connection_id, io_handler, mut incoming) =
client.add_connection(client_conn).await;
let (mut io_ended_tx, mut io_ended_rx) = postage::barrier::channel();
smol::spawn(async move {
io_handler.await.ok();
io_ended_tx.send(()).await.unwrap();
})
.detach();
let (mut messages_ended_tx, mut messages_ended_rx) = postage::barrier::channel();
smol::spawn(async move {
incoming.next().await;
messages_ended_tx.send(()).await.unwrap();
})
.detach();
client.disconnect(connection_id).await;
io_ended_rx.recv().await;
messages_ended_rx.recv().await;
assert!(
futures::SinkExt::send(&mut server_conn, WebSocketMessage::Binary(vec![]))
.await
.is_err()
);
});
}
#[test]
fn test_io_error() {
smol::block_on(async move {
let (client_conn, server_conn) = test::Channel::bidirectional();
drop(server_conn);
let client = Peer::new();
let (connection_id, io_handler, mut incoming) =
client.add_connection(client_conn).await;
smol::spawn(io_handler).detach();
smol::spawn(async move { incoming.next().await }).detach();
let err = client
.request(
connection_id,
proto::Auth {
user_id: 42,
access_token: "token".to_string(),
},
)
.await
.unwrap_err();
assert_eq!(err.to_string(), "connection was closed");
});
}
}

View file

@ -4,7 +4,6 @@ use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSock
use futures::{SinkExt as _, StreamExt as _};
use prost::Message;
use std::any::Any;
use std::sync::Arc;
use std::{
io,
time::{Duration, SystemTime, UNIX_EPOCH},
@ -34,14 +33,16 @@ pub trait RequestMessage: EnvelopedMessage {
macro_rules! messages {
($($name:ident),* $(,)?) => {
pub fn build_typed_envelope(sender_id: ConnectionId, envelope: Envelope) -> Option<Arc<dyn Any + Send + Sync>> {
pub fn build_typed_envelope(sender_id: ConnectionId, envelope: Envelope) -> Option<Box<dyn Any + Send + Sync>> {
match envelope.payload {
$(Some(envelope::Payload::$name(payload)) => Some(Arc::new(TypedEnvelope {
sender_id,
original_sender_id: envelope.original_sender_id.map(PeerId),
message_id: envelope.id,
payload,
})), )*
$(Some(envelope::Payload::$name(payload)) => {
Some(Box::new(TypedEnvelope {
sender_id,
original_sender_id: envelope.original_sender_id.map(PeerId),
message_id: envelope.id,
payload,
}))
}, )*
_ => None
}
}
@ -116,6 +117,8 @@ messages!(
OpenBufferResponse,
OpenWorktree,
OpenWorktreeResponse,
Ping,
Pong,
RemovePeer,
SaveBuffer,
SendChannelMessage,
@ -132,6 +135,7 @@ request_messages!(
(JoinChannel, JoinChannelResponse),
(OpenBuffer, OpenBufferResponse),
(OpenWorktree, OpenWorktreeResponse),
(Ping, Pong),
(SaveBuffer, BufferSaved),
(ShareWorktree, ShareWorktreeResponse),
);