Use a fixed-length delimiter for encoding/decoding messages in RPC

Co-Authored-By: Max Brunsfeld <max@zed.dev>
This commit is contained in:
Antonio Scandurra 2021-06-14 18:33:34 +02:00
parent 0ddbe0c757
commit cd1a4c49cf
2 changed files with 36 additions and 97 deletions

View file

@ -1,7 +1,7 @@
use futures_io::{AsyncRead, AsyncWrite}; use futures_io::{AsyncRead, AsyncWrite};
use futures_lite::{AsyncReadExt, AsyncWriteExt as _}; use futures_lite::{AsyncReadExt, AsyncWriteExt as _};
use prost::Message; use prost::Message;
use std::io; use std::{convert::TryInto, io};
include!(concat!(env!("OUT_DIR"), "/zed.messages.rs")); include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
@ -96,9 +96,14 @@ where
T: AsyncWrite + Unpin, T: AsyncWrite + Unpin,
{ {
/// Write a given protobuf message to the stream. /// Write a given protobuf message to the stream.
pub async fn write_message(&mut self, message: &impl Message) -> futures_io::Result<()> { pub async fn write_message(&mut self, message: &impl Message) -> io::Result<()> {
let message_len: u32 = message
.encoded_len()
.try_into()
.map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "message is too large"))?;
self.buffer.clear(); self.buffer.clear();
message.encode_length_delimited(&mut self.buffer).unwrap(); self.buffer.extend_from_slice(&message_len.to_be_bytes());
message.encode(&mut self.buffer)?;
self.byte_stream.write_all(&self.buffer).await self.byte_stream.write_all(&self.buffer).await
} }
} }
@ -109,44 +114,12 @@ where
{ {
/// Read a protobuf message of the given type from the stream. /// Read a protobuf message of the given type from the stream.
pub async fn read_message<M: Message + Default>(&mut self) -> futures_io::Result<M> { pub async fn read_message<M: Message + Default>(&mut self) -> futures_io::Result<M> {
// Ensure the buffer is large enough to hold the maximum delimiter length let mut delimiter_buf = [0; 4];
const MAX_DELIMITER_LEN: usize = 10; self.byte_stream.read_exact(&mut delimiter_buf).await?;
self.buffer.resize(MAX_DELIMITER_LEN, 0); let message_len = u32::from_be_bytes(delimiter_buf) as usize;
self.buffer.resize(message_len, 0);
// Read until a complete length delimiter can be decoded. self.byte_stream.read_exact(&mut self.buffer).await?;
let mut read_start_offset = 0; Ok(M::decode(self.buffer.as_slice())?)
let (encoded_len, delimiter_len) = loop {
let bytes_read = self
.byte_stream
.read(&mut self.buffer[read_start_offset..])
.await?;
read_start_offset += bytes_read;
let mut buffer = &self.buffer[0..read_start_offset];
match prost::decode_length_delimiter(&mut buffer) {
Err(_) => {
if read_start_offset >= MAX_DELIMITER_LEN {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"invalid message length delimiter",
));
}
}
Ok(encoded_len) => {
let delimiter_len = read_start_offset - buffer.len();
break (encoded_len, delimiter_len);
}
}
};
// Read the message itself.
self.buffer.resize(delimiter_len + encoded_len, 0);
self.byte_stream
.read_exact(&mut self.buffer[read_start_offset..])
.await?;
let message = M::decode(&self.buffer[delimiter_len..])?;
Ok(message)
} }
} }
@ -196,60 +169,6 @@ mod tests {
}); });
} }
/// Round-trips a large `FromClient` message through a `MessageStream` whose
/// underlying `ChunkedStream` uses a chunk size smaller than the message's
/// length delimiter, and checks that the decoded message equals the original.
#[test]
fn test_read_message_when_length_delimiter_is_not_complete_in_first_read() {
    smol::block_on(async {
        let chunked = ChunkedStream {
            bytes: Vec::new(),
            read_offset: 0,
            chunk_size: 2,
        };

        // The payload is large enough that the length delimiter needs three bytes,
        // which is more than one chunk, so a single read from the chunked stream
        // will not deliver the whole delimiter.
        let upload = from_client::UploadFile {
            path: Vec::new(),
            content: "long ".repeat(256 * 256).into(),
        };
        let sent = FromClient {
            id: 4,
            variant: Some(from_client::Variant::UploadFile(upload)),
        };

        // Guard the premise of the test before exercising the stream.
        let delimiter_len = prost::length_delimiter_len(sent.encoded_len());
        assert!(delimiter_len > chunked.chunk_size);

        let mut messages = MessageStream::new(chunked);
        messages.write_message(&sent).await.unwrap();
        let received = messages.read_message::<FromClient>().await.unwrap();
        assert_eq!(received, sent);
    });
}
/// Writes three stray bytes followed by a `FromClient` message, then asks the
/// `MessageStream` for a `FromServer` and checks that the read yields an error.
#[test]
fn test_protobuf_parse_error() {
    smol::block_on(async {
        let mut raw = ChunkedStream {
            bytes: Vec::new(),
            read_offset: 0,
            chunk_size: 2,
        };

        let auth = from_client::Auth {
            user_id: 5,
            access_token: "the-access-token".into(),
        };
        let sent = FromClient {
            id: 3,
            variant: Some(from_client::Variant::Auth(auth)),
        };

        // Stray bytes go straight onto the byte stream, ahead of the framed message.
        raw.write_all(b"omg").await.unwrap();

        let mut messages = MessageStream::new(raw);
        messages.write_message(&sent).await.unwrap();

        // Read the wrong type of message from the stream.
        let outcome = messages.read_message::<FromServer>().await;
        assert!(outcome.is_err());
    });
}
struct ChunkedStream { struct ChunkedStream {
bytes: Vec<u8>, bytes: Vec<u8>,
read_offset: usize, read_offset: usize,

View file

@ -103,8 +103,16 @@ where
.ok_or_else(|| anyhow!("received response of the wrong t")) .ok_or_else(|| anyhow!("received response of the wrong t"))
} }
pub async fn send<T: SendMessage>(_: T) -> Result<()> { pub async fn send<T: SendMessage>(&mut self, message: T) -> Result<()> {
todo!() let message_id = self.next_message_id;
self.next_message_id += 1;
self.stream
.write_message(&proto::FromClient {
id: message_id,
variant: Some(message.to_variant()),
})
.await?;
Ok(())
} }
} }
@ -152,6 +160,18 @@ mod tests {
)) ))
); );
// Respond to another request to ensure requests are properly matched up.
server_stream
.write_message(&proto::FromServer {
request_id: Some(999),
variant: Some(proto::from_server::Variant::AuthResponse(
proto::from_server::AuthResponse {
credentials_valid: false,
},
)),
})
.await
.unwrap();
server_stream server_stream
.write_message(&proto::FromServer { .write_message(&proto::FromServer {
request_id: Some(server_req.id), request_id: Some(server_req.id),