mirror of
https://github.com/zed-industries/zed.git
synced 2025-01-12 05:15:00 +00:00
Use a fixed-length delimiter for encoding/decoding messages in RPC
Co-Authored-By: Max Brunsfeld <max@zed.dev>
This commit is contained in:
parent
0ddbe0c757
commit
cd1a4c49cf
2 changed files with 36 additions and 97 deletions
|
@ -1,7 +1,7 @@
|
||||||
use futures_io::{AsyncRead, AsyncWrite};
|
use futures_io::{AsyncRead, AsyncWrite};
|
||||||
use futures_lite::{AsyncReadExt, AsyncWriteExt as _};
|
use futures_lite::{AsyncReadExt, AsyncWriteExt as _};
|
||||||
use prost::Message;
|
use prost::Message;
|
||||||
use std::io;
|
use std::{convert::TryInto, io};
|
||||||
|
|
||||||
include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
|
include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
|
||||||
|
|
||||||
|
@ -96,9 +96,14 @@ where
|
||||||
T: AsyncWrite + Unpin,
|
T: AsyncWrite + Unpin,
|
||||||
{
|
{
|
||||||
/// Write a given protobuf message to the stream.
|
/// Write a given protobuf message to the stream.
|
||||||
pub async fn write_message(&mut self, message: &impl Message) -> futures_io::Result<()> {
|
pub async fn write_message(&mut self, message: &impl Message) -> io::Result<()> {
|
||||||
|
let message_len: u32 = message
|
||||||
|
.encoded_len()
|
||||||
|
.try_into()
|
||||||
|
.map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "message is too large"))?;
|
||||||
self.buffer.clear();
|
self.buffer.clear();
|
||||||
message.encode_length_delimited(&mut self.buffer).unwrap();
|
self.buffer.extend_from_slice(&message_len.to_be_bytes());
|
||||||
|
message.encode(&mut self.buffer)?;
|
||||||
self.byte_stream.write_all(&self.buffer).await
|
self.byte_stream.write_all(&self.buffer).await
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -109,44 +114,12 @@ where
|
||||||
{
|
{
|
||||||
/// Read a protobuf message of the given type from the stream.
|
/// Read a protobuf message of the given type from the stream.
|
||||||
pub async fn read_message<M: Message + Default>(&mut self) -> futures_io::Result<M> {
|
pub async fn read_message<M: Message + Default>(&mut self) -> futures_io::Result<M> {
|
||||||
// Ensure the buffer is large enough to hold the maximum delimiter length
|
let mut delimiter_buf = [0; 4];
|
||||||
const MAX_DELIMITER_LEN: usize = 10;
|
self.byte_stream.read_exact(&mut delimiter_buf).await?;
|
||||||
self.buffer.resize(MAX_DELIMITER_LEN, 0);
|
let message_len = u32::from_be_bytes(delimiter_buf) as usize;
|
||||||
|
self.buffer.resize(message_len, 0);
|
||||||
// Read until a complete length delimiter can be decoded.
|
self.byte_stream.read_exact(&mut self.buffer).await?;
|
||||||
let mut read_start_offset = 0;
|
Ok(M::decode(self.buffer.as_slice())?)
|
||||||
let (encoded_len, delimiter_len) = loop {
|
|
||||||
let bytes_read = self
|
|
||||||
.byte_stream
|
|
||||||
.read(&mut self.buffer[read_start_offset..])
|
|
||||||
.await?;
|
|
||||||
read_start_offset += bytes_read;
|
|
||||||
|
|
||||||
let mut buffer = &self.buffer[0..read_start_offset];
|
|
||||||
match prost::decode_length_delimiter(&mut buffer) {
|
|
||||||
Err(_) => {
|
|
||||||
if read_start_offset >= MAX_DELIMITER_LEN {
|
|
||||||
return Err(io::Error::new(
|
|
||||||
io::ErrorKind::InvalidData,
|
|
||||||
"invalid message length delimiter",
|
|
||||||
));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Ok(encoded_len) => {
|
|
||||||
let delimiter_len = read_start_offset - buffer.len();
|
|
||||||
break (encoded_len, delimiter_len);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// Read the message itself.
|
|
||||||
self.buffer.resize(delimiter_len + encoded_len, 0);
|
|
||||||
self.byte_stream
|
|
||||||
.read_exact(&mut self.buffer[read_start_offset..])
|
|
||||||
.await?;
|
|
||||||
let message = M::decode(&self.buffer[delimiter_len..])?;
|
|
||||||
|
|
||||||
Ok(message)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -196,60 +169,6 @@ mod tests {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_read_message_when_length_delimiter_is_not_complete_in_first_read() {
|
|
||||||
smol::block_on(async {
|
|
||||||
let byte_stream = ChunkedStream {
|
|
||||||
bytes: Vec::new(),
|
|
||||||
read_offset: 0,
|
|
||||||
chunk_size: 2,
|
|
||||||
};
|
|
||||||
|
|
||||||
// This message is so long that its length delimiter requires three bytes,
|
|
||||||
// so it won't be delivered in a single read from the chunked byte stream.
|
|
||||||
let message = FromClient {
|
|
||||||
id: 4,
|
|
||||||
variant: Some(from_client::Variant::UploadFile(from_client::UploadFile {
|
|
||||||
path: Vec::new(),
|
|
||||||
content: "long ".repeat(256 * 256).into(),
|
|
||||||
})),
|
|
||||||
};
|
|
||||||
assert!(prost::length_delimiter_len(message.encoded_len()) > byte_stream.chunk_size);
|
|
||||||
|
|
||||||
let mut message_stream = MessageStream::new(byte_stream);
|
|
||||||
message_stream.write_message(&message).await.unwrap();
|
|
||||||
let decoded_message = message_stream.read_message::<FromClient>().await.unwrap();
|
|
||||||
assert_eq!(decoded_message, message);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_protobuf_parse_error() {
|
|
||||||
smol::block_on(async {
|
|
||||||
let mut byte_stream = ChunkedStream {
|
|
||||||
bytes: Vec::new(),
|
|
||||||
read_offset: 0,
|
|
||||||
chunk_size: 2,
|
|
||||||
};
|
|
||||||
|
|
||||||
let message = FromClient {
|
|
||||||
id: 3,
|
|
||||||
variant: Some(from_client::Variant::Auth(from_client::Auth {
|
|
||||||
user_id: 5,
|
|
||||||
access_token: "the-access-token".into(),
|
|
||||||
})),
|
|
||||||
};
|
|
||||||
|
|
||||||
byte_stream.write_all(b"omg").await.unwrap();
|
|
||||||
let mut message_stream = MessageStream::new(byte_stream);
|
|
||||||
message_stream.write_message(&message).await.unwrap();
|
|
||||||
|
|
||||||
// Read the wrong type of message from the stream.
|
|
||||||
let result = message_stream.read_message::<FromServer>().await;
|
|
||||||
assert!(result.is_err());
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
struct ChunkedStream {
|
struct ChunkedStream {
|
||||||
bytes: Vec<u8>,
|
bytes: Vec<u8>,
|
||||||
read_offset: usize,
|
read_offset: usize,
|
||||||
|
|
|
@ -103,8 +103,16 @@ where
|
||||||
.ok_or_else(|| anyhow!("received response of the wrong t"))
|
.ok_or_else(|| anyhow!("received response of the wrong t"))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn send<T: SendMessage>(_: T) -> Result<()> {
|
pub async fn send<T: SendMessage>(&mut self, message: T) -> Result<()> {
|
||||||
todo!()
|
let message_id = self.next_message_id;
|
||||||
|
self.next_message_id += 1;
|
||||||
|
self.stream
|
||||||
|
.write_message(&proto::FromClient {
|
||||||
|
id: message_id,
|
||||||
|
variant: Some(message.to_variant()),
|
||||||
|
})
|
||||||
|
.await?;
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -152,6 +160,18 @@ mod tests {
|
||||||
))
|
))
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// Respond to another request to ensure requests are properly matched up.
|
||||||
|
server_stream
|
||||||
|
.write_message(&proto::FromServer {
|
||||||
|
request_id: Some(999),
|
||||||
|
variant: Some(proto::from_server::Variant::AuthResponse(
|
||||||
|
proto::from_server::AuthResponse {
|
||||||
|
credentials_valid: false,
|
||||||
|
},
|
||||||
|
)),
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
server_stream
|
server_stream
|
||||||
.write_message(&proto::FromServer {
|
.write_message(&proto::FromServer {
|
||||||
request_id: Some(server_req.id),
|
request_id: Some(server_req.id),
|
||||||
|
|
Loading…
Reference in a new issue