From cd1a4c49cffebbd8e3983eda4e3e2159f304afdc Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Mon, 14 Jun 2021 18:33:34 +0200 Subject: [PATCH] Use a fixed-length delimiter for encoding/decoding messages in RPC Co-Authored-By: Max Brunsfeld --- zed-rpc/src/proto.rs | 109 ++++++------------------------------------ zed/src/rpc_client.rs | 24 +++++++++- 2 files changed, 36 insertions(+), 97 deletions(-) diff --git a/zed-rpc/src/proto.rs b/zed-rpc/src/proto.rs index 09758fdf74..d68caa4769 100644 --- a/zed-rpc/src/proto.rs +++ b/zed-rpc/src/proto.rs @@ -1,7 +1,7 @@ use futures_io::{AsyncRead, AsyncWrite}; use futures_lite::{AsyncReadExt, AsyncWriteExt as _}; use prost::Message; -use std::io; +use std::{convert::TryInto, io}; include!(concat!(env!("OUT_DIR"), "/zed.messages.rs")); @@ -96,9 +96,14 @@ where T: AsyncWrite + Unpin, { /// Write a given protobuf message to the stream. - pub async fn write_message(&mut self, message: &impl Message) -> futures_io::Result<()> { + pub async fn write_message(&mut self, message: &impl Message) -> io::Result<()> { + let message_len: u32 = message + .encoded_len() + .try_into() + .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "message is too large"))?; self.buffer.clear(); - message.encode_length_delimited(&mut self.buffer).unwrap(); + self.buffer.extend_from_slice(&message_len.to_be_bytes()); + message.encode(&mut self.buffer)?; self.byte_stream.write_all(&self.buffer).await } } @@ -109,44 +114,12 @@ where { /// Read a protobuf message of the given type from the stream. pub async fn read_message(&mut self) -> futures_io::Result { - // Ensure the buffer is large enough to hold the maximum delimiter length - const MAX_DELIMITER_LEN: usize = 10; - self.buffer.resize(MAX_DELIMITER_LEN, 0); - - // Read until a complete length delimiter can be decoded. - let mut read_start_offset = 0; - let (encoded_len, delimiter_len) = loop { - let bytes_read = self - .byte_stream - .read(&mut self.buffer[read_start_offset..]) - .await?; - read_start_offset += bytes_read; - - let mut buffer = &self.buffer[0..read_start_offset]; - match prost::decode_length_delimiter(&mut buffer) { - Err(_) => { - if read_start_offset >= MAX_DELIMITER_LEN { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - "invalid message length delimiter", - )); - } - } - Ok(encoded_len) => { - let delimiter_len = read_start_offset - buffer.len(); - break (encoded_len, delimiter_len); - } - } - }; - - // Read the message itself. - self.buffer.resize(delimiter_len + encoded_len, 0); - self.byte_stream - .read_exact(&mut self.buffer[read_start_offset..]) - .await?; - let message = M::decode(&self.buffer[delimiter_len..])?; - - Ok(message) + let mut delimiter_buf = [0; 4]; + self.byte_stream.read_exact(&mut delimiter_buf).await?; + let message_len = u32::from_be_bytes(delimiter_buf) as usize; + self.buffer.resize(message_len, 0); + self.byte_stream.read_exact(&mut self.buffer).await?; + Ok(M::decode(self.buffer.as_slice())?) } } @@ -196,60 +169,6 @@ mod tests { }); } - #[test] - fn test_read_message_when_length_delimiter_is_not_complete_in_first_read() { - smol::block_on(async { - let byte_stream = ChunkedStream { - bytes: Vec::new(), - read_offset: 0, - chunk_size: 2, - }; - - // This message is so long that its length delimiter requires three bytes, - // so it won't be delivered in a single read from the chunked byte stream. - let message = FromClient { - id: 4, - variant: Some(from_client::Variant::UploadFile(from_client::UploadFile { - path: Vec::new(), - content: "long ".repeat(256 * 256).into(), - })), - }; - assert!(prost::length_delimiter_len(message.encoded_len()) > byte_stream.chunk_size); - - let mut message_stream = MessageStream::new(byte_stream); - message_stream.write_message(&message).await.unwrap(); - let decoded_message = message_stream.read_message::().await.unwrap(); - assert_eq!(decoded_message, message); - }); - } - - #[test] - fn test_protobuf_parse_error() { - smol::block_on(async { - let mut byte_stream = ChunkedStream { - bytes: Vec::new(), - read_offset: 0, - chunk_size: 2, - }; - - let message = FromClient { - id: 3, - variant: Some(from_client::Variant::Auth(from_client::Auth { - user_id: 5, - access_token: "the-access-token".into(), - })), - }; - - byte_stream.write_all(b"omg").await.unwrap(); - let mut message_stream = MessageStream::new(byte_stream); - message_stream.write_message(&message).await.unwrap(); - - // Read the wrong type of message from the stream. - let result = message_stream.read_message::().await; - assert!(result.is_err()); - }); - } - struct ChunkedStream { bytes: Vec, read_offset: usize, diff --git a/zed/src/rpc_client.rs b/zed/src/rpc_client.rs index 9592d33785..2737f524f8 100644 --- a/zed/src/rpc_client.rs +++ b/zed/src/rpc_client.rs @@ -103,8 +103,16 @@ where .ok_or_else(|| anyhow!("received response of the wrong t")) } - pub async fn send(_: T) -> Result<()> { - todo!() + pub async fn send(&mut self, message: T) -> Result<()> { + let message_id = self.next_message_id; + self.next_message_id += 1; + self.stream + .write_message(&proto::FromClient { + id: message_id, + variant: Some(message.to_variant()), + }) + .await?; + Ok(()) } } @@ -152,6 +160,18 @@ mod tests { )) ); + // Respond to another request to ensure requests are properly matched up. + server_stream + .write_message(&proto::FromServer { + request_id: Some(999), + variant: Some(proto::from_server::Variant::AuthResponse( + proto::from_server::AuthResponse { + credentials_valid: false, + }, + )), + }) + .await + .unwrap(); server_stream .write_message(&proto::FromServer { request_id: Some(server_req.id),