mirror of
https://github.com/zed-industries/zed.git
synced 2024-12-24 17:28:40 +00:00
Use a fixed-length delimiter for encoding/decoding messages in RPC
Co-Authored-By: Max Brunsfeld <max@zed.dev>
This commit is contained in:
parent
0ddbe0c757
commit
cd1a4c49cf
2 changed files with 36 additions and 97 deletions
|
@ -1,7 +1,7 @@
|
|||
use futures_io::{AsyncRead, AsyncWrite};
|
||||
use futures_lite::{AsyncReadExt, AsyncWriteExt as _};
|
||||
use prost::Message;
|
||||
use std::io;
|
||||
use std::{convert::TryInto, io};
|
||||
|
||||
include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
|
||||
|
||||
|
@ -96,9 +96,14 @@ where
|
|||
T: AsyncWrite + Unpin,
|
||||
{
|
||||
/// Write a given protobuf message to the stream.
|
||||
pub async fn write_message(&mut self, message: &impl Message) -> futures_io::Result<()> {
|
||||
pub async fn write_message(&mut self, message: &impl Message) -> io::Result<()> {
|
||||
let message_len: u32 = message
|
||||
.encoded_len()
|
||||
.try_into()
|
||||
.map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "message is too large"))?;
|
||||
self.buffer.clear();
|
||||
message.encode_length_delimited(&mut self.buffer).unwrap();
|
||||
self.buffer.extend_from_slice(&message_len.to_be_bytes());
|
||||
message.encode(&mut self.buffer)?;
|
||||
self.byte_stream.write_all(&self.buffer).await
|
||||
}
|
||||
}
|
||||
|
@ -109,44 +114,12 @@ where
|
|||
{
|
||||
/// Read a protobuf message of the given type from the stream.
|
||||
pub async fn read_message<M: Message + Default>(&mut self) -> futures_io::Result<M> {
|
||||
// Ensure the buffer is large enough to hold the maximum delimiter length
|
||||
const MAX_DELIMITER_LEN: usize = 10;
|
||||
self.buffer.resize(MAX_DELIMITER_LEN, 0);
|
||||
|
||||
// Read until a complete length delimiter can be decoded.
|
||||
let mut read_start_offset = 0;
|
||||
let (encoded_len, delimiter_len) = loop {
|
||||
let bytes_read = self
|
||||
.byte_stream
|
||||
.read(&mut self.buffer[read_start_offset..])
|
||||
.await?;
|
||||
read_start_offset += bytes_read;
|
||||
|
||||
let mut buffer = &self.buffer[0..read_start_offset];
|
||||
match prost::decode_length_delimiter(&mut buffer) {
|
||||
Err(_) => {
|
||||
if read_start_offset >= MAX_DELIMITER_LEN {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
"invalid message length delimiter",
|
||||
));
|
||||
}
|
||||
}
|
||||
Ok(encoded_len) => {
|
||||
let delimiter_len = read_start_offset - buffer.len();
|
||||
break (encoded_len, delimiter_len);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Read the message itself.
|
||||
self.buffer.resize(delimiter_len + encoded_len, 0);
|
||||
self.byte_stream
|
||||
.read_exact(&mut self.buffer[read_start_offset..])
|
||||
.await?;
|
||||
let message = M::decode(&self.buffer[delimiter_len..])?;
|
||||
|
||||
Ok(message)
|
||||
let mut delimiter_buf = [0; 4];
|
||||
self.byte_stream.read_exact(&mut delimiter_buf).await?;
|
||||
let message_len = u32::from_be_bytes(delimiter_buf) as usize;
|
||||
self.buffer.resize(message_len, 0);
|
||||
self.byte_stream.read_exact(&mut self.buffer).await?;
|
||||
Ok(M::decode(self.buffer.as_slice())?)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -196,60 +169,6 @@ mod tests {
|
|||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_read_message_when_length_delimiter_is_not_complete_in_first_read() {
|
||||
smol::block_on(async {
|
||||
let byte_stream = ChunkedStream {
|
||||
bytes: Vec::new(),
|
||||
read_offset: 0,
|
||||
chunk_size: 2,
|
||||
};
|
||||
|
||||
// This message is so long that its length delimiter requires three bytes,
|
||||
// so it won't be delivered in a single read from the chunked byte stream.
|
||||
let message = FromClient {
|
||||
id: 4,
|
||||
variant: Some(from_client::Variant::UploadFile(from_client::UploadFile {
|
||||
path: Vec::new(),
|
||||
content: "long ".repeat(256 * 256).into(),
|
||||
})),
|
||||
};
|
||||
assert!(prost::length_delimiter_len(message.encoded_len()) > byte_stream.chunk_size);
|
||||
|
||||
let mut message_stream = MessageStream::new(byte_stream);
|
||||
message_stream.write_message(&message).await.unwrap();
|
||||
let decoded_message = message_stream.read_message::<FromClient>().await.unwrap();
|
||||
assert_eq!(decoded_message, message);
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_protobuf_parse_error() {
|
||||
smol::block_on(async {
|
||||
let mut byte_stream = ChunkedStream {
|
||||
bytes: Vec::new(),
|
||||
read_offset: 0,
|
||||
chunk_size: 2,
|
||||
};
|
||||
|
||||
let message = FromClient {
|
||||
id: 3,
|
||||
variant: Some(from_client::Variant::Auth(from_client::Auth {
|
||||
user_id: 5,
|
||||
access_token: "the-access-token".into(),
|
||||
})),
|
||||
};
|
||||
|
||||
byte_stream.write_all(b"omg").await.unwrap();
|
||||
let mut message_stream = MessageStream::new(byte_stream);
|
||||
message_stream.write_message(&message).await.unwrap();
|
||||
|
||||
// Read the wrong type of message from the stream.
|
||||
let result = message_stream.read_message::<FromServer>().await;
|
||||
assert!(result.is_err());
|
||||
});
|
||||
}
|
||||
|
||||
struct ChunkedStream {
|
||||
bytes: Vec<u8>,
|
||||
read_offset: usize,
|
||||
|
|
|
@ -103,8 +103,16 @@ where
|
|||
.ok_or_else(|| anyhow!("received response of the wrong t"))
|
||||
}
|
||||
|
||||
pub async fn send<T: SendMessage>(_: T) -> Result<()> {
|
||||
todo!()
|
||||
pub async fn send<T: SendMessage>(&mut self, message: T) -> Result<()> {
|
||||
let message_id = self.next_message_id;
|
||||
self.next_message_id += 1;
|
||||
self.stream
|
||||
.write_message(&proto::FromClient {
|
||||
id: message_id,
|
||||
variant: Some(message.to_variant()),
|
||||
})
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -152,6 +160,18 @@ mod tests {
|
|||
))
|
||||
);
|
||||
|
||||
// Respond to another request to ensure requests are properly matched up.
|
||||
server_stream
|
||||
.write_message(&proto::FromServer {
|
||||
request_id: Some(999),
|
||||
variant: Some(proto::from_server::Variant::AuthResponse(
|
||||
proto::from_server::AuthResponse {
|
||||
credentials_valid: false,
|
||||
},
|
||||
)),
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
server_stream
|
||||
.write_message(&proto::FromServer {
|
||||
request_id: Some(server_req.id),
|
||||
|
|
Loading…
Reference in a new issue