diff --git a/Cargo.lock b/Cargo.lock index d744744f62..f92a3fdbcb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -239,7 +239,7 @@ dependencies = [ "regex", "rustc-hash", "shlex", - "which", + "which 3.1.1", ] [[package]] @@ -306,6 +306,12 @@ version = "1.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ae44d1a3d5a19df61dd0c8beb138458ac2a53a7ac09eba97d55592540004306b" +[[package]] +name = "bytes" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b700ce4376041dcd0a327fd0097c41095743c4c8af8887265942faf1100bd040" + [[package]] name = "cab" version = "0.1.0" @@ -893,6 +899,12 @@ dependencies = [ "winapi 0.3.9", ] +[[package]] +name = "fixedbitset" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37ab347416e802de484e4d03c7316c48f1ecb56574dfd4a46a80f173ce1de04d" + [[package]] name = "flate2" version = "1.0.20" @@ -1067,9 +1079,9 @@ checksum = "28be053525281ad8259d47e4de5de657b25e7bac113458555bb4b70bc6870500" [[package]] name = "futures-lite" -version = "1.11.3" +version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b4481d0cd0de1d204a4fa55e7d45f07b1d958abcb06714b3446438e2eff695fb" +checksum = "7694489acd39452c77daa48516b894c153f192c3578d5a839b62c58099fcbf48" dependencies = [ "fastrand", "futures-core", @@ -1249,6 +1261,15 @@ version = "0.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d7afe4a420e3fe79967a00898cc1f4db7c8a49a9333a29f8a4bd76a253d5cd04" +[[package]] +name = "heck" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d621efb26863f0e9924c6ac577e8275e5e6b77455db64ffa6c65c904e9e132c" +dependencies = [ + "unicode-segmentation", +] + [[package]] name = "hermit-abi" version = "0.1.18" @@ -1354,6 +1375,24 @@ dependencies = [ "cfg-if 1.0.0", ] +[[package]] +name = "itertools" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "284f18f85651fe11e8a991b2adb42cb078325c996ed026d994719efcfca1d54b" +dependencies = [ + "either", +] + +[[package]] +name = "itertools" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69ddb889f9d0d08a67338271fa9b62996bc788c7796a5c18cf057420aaed5eaf" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "0.3.4" @@ -1577,6 +1616,12 @@ dependencies = [ "uuid", ] +[[package]] +name = "multimap" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5ce46fe64a9d73be07dcbe690a38ce1b293be448fd8ce1e6c1b8062c9f72c6a" + [[package]] name = "nb-connect" version = "1.0.3" @@ -1808,6 +1853,16 @@ version = "2.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d4fd5641d01c8f18a23da7b6fe29298ff4b55afcccdf78973b24cf3175fee32e" +[[package]] +name = "petgraph" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "467d164a6de56270bd7c4d070df81d07beace25012d5103ced4e9ff08d6afdb7" +dependencies = [ + "fixedbitset", + "indexmap", +] + [[package]] name = "phf" version = "0.7.24" @@ -1976,6 +2031,76 @@ dependencies = [ "unicode-xid", ] +[[package]] +name = "prost" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e6984d2f1a23009bd270b8bb56d0926810a3d483f59c987d77969e9d8e840b2" +dependencies = [ + "bytes", + "prost-derive 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "prost" +version = "0.7.0" +source = "git+https://github.com/sfackler/prost?rev=082f3e65874fe91382e72482863896b7b4db3728#082f3e65874fe91382e72482863896b7b4db3728" +dependencies = [ + "bytes", + "prost-derive 0.7.0 (git+https://github.com/sfackler/prost?rev=082f3e65874fe91382e72482863896b7b4db3728)", +] + +[[package]] +name = "prost-build" +version = "0.7.0" +source = "git+https://github.com/sfackler/prost?rev=082f3e65874fe91382e72482863896b7b4db3728#082f3e65874fe91382e72482863896b7b4db3728" +dependencies = [ + "bytes", + "heck", + "itertools 0.10.1", + "log", + "multimap", + "petgraph", + "prost 0.7.0 (git+https://github.com/sfackler/prost?rev=082f3e65874fe91382e72482863896b7b4db3728)", + "prost-types", + "tempfile", + "which 4.1.0", +] + +[[package]] +name = "prost-derive" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "169a15f3008ecb5160cba7d37bcd690a7601b6d30cfb87a117d45e59d52af5d4" +dependencies = [ + "anyhow", + "itertools 0.9.0", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "prost-derive" +version = "0.7.0" +source = "git+https://github.com/sfackler/prost?rev=082f3e65874fe91382e72482863896b7b4db3728#082f3e65874fe91382e72482863896b7b4db3728" +dependencies = [ + "anyhow", + "itertools 0.10.1", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "prost-types" +version = "0.7.0" +source = "git+https://github.com/sfackler/prost?rev=082f3e65874fe91382e72482863896b7b4db3728#082f3e65874fe91382e72482863896b7b4db3728" +dependencies = [ + "bytes", + "prost 0.7.0 (git+https://github.com/sfackler/prost?rev=082f3e65874fe91382e72482863896b7b4db3728)", +] + [[package]] name = "quote" version = "1.0.9" @@ -2763,6 +2888,20 @@ dependencies = [ "remove_dir_all", ] +[[package]] +name = "tempfile" +version = "3.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dac1c663cfc93810f88aed9b8941d48cabf856a1b111c29a40439018d870eb22" +dependencies = [ + "cfg-if 1.0.0", + "libc", + "rand 0.8.3", + "redox_syscall 0.2.5", + "remove_dir_all", + "winapi 0.3.9", +] + [[package]] name = "term" version = "0.4.6" @@ -2957,6 +3096,12 @@ version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "79bf4d5fc96546fdb73f9827097810bbda93b11a6770ff3a54e1f445d4135787" +[[package]] +name = "unicode-segmentation" +version = "1.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bb0d2e7be6ae3a5fa87eed5fb451aff96f2573d2694942e40543ae0bbe19c796" + [[package]] name = "unicode-vo" version = "0.1.0" @@ -3094,6 +3239,16 @@ dependencies = [ "libc", ] +[[package]] +name = "which" +version = "4.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b55551e42cbdf2ce2bedd2203d0cc08dba002c27510f86dab6d0ce304cba3dfe" +dependencies = [ + "either", + "libc", +] + [[package]] name = "winapi" version = "0.2.8" @@ -3214,8 +3369,13 @@ version = "0.1.0" dependencies = [ "anyhow", "base64", + "futures-io", + "futures-lite", + "prost 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)", + "prost-build", "rand 0.8.3", "rsa", + "smol", ] [[package]] diff --git a/zed-rpc/Cargo.toml b/zed-rpc/Cargo.toml index a07f7af3d9..acf2e0cb02 100644 --- a/zed-rpc/Cargo.toml +++ b/zed-rpc/Cargo.toml @@ -7,5 +7,14 @@ version = "0.1.0" [dependencies] anyhow = "1.0" base64 = "0.13" +futures-io = "0.3" +futures-lite = "1" +prost = "0.7" rsa = "0.4" rand = "0.8" + +[build-dependencies] +prost-build = { git = "https://github.com/sfackler/prost", rev = "082f3e65874fe91382e72482863896b7b4db3728" } + +[dev-dependencies] +smol = "1.2.5" diff --git a/zed-rpc/build.rs b/zed-rpc/build.rs new file mode 100644 index 0000000000..d453616095 --- /dev/null +++ b/zed-rpc/build.rs @@ -0,0 +1,7 @@ +fn main() { + let mut build = prost_build::Config::new(); + // build.protoc_arg("--experimental_allow_proto3_optional"); + build + .compile_protos(&["proto/zed.proto"], &["proto"]) + .unwrap(); +} diff --git a/zed-rpc/proto/zed.proto b/zed-rpc/proto/zed.proto new file mode 100644 index 0000000000..94a4d173b1 --- /dev/null +++ b/zed-rpc/proto/zed.proto @@ -0,0 +1,27 @@ +syntax = "proto3"; +package zed.messages; + +message FromClient { + int32 id = 1; + + oneof variant { + Auth auth = 2; + } + + message Auth { + int32 user_id = 1; + string access_token = 2; + } +} + +message FromServer { + optional int32 request_id = 1; + + oneof variant { + Ack ack = 2; + } + + message Ack { + optional string error_message = 1; + } +} \ No newline at end of file diff --git a/zed-rpc/src/lib.rs b/zed-rpc/src/lib.rs index 0e4a05d597..8006a10b50 100644 --- a/zed-rpc/src/lib.rs +++ b/zed-rpc/src/lib.rs @@ -1 +1,2 @@ pub mod auth; +pub mod proto; diff --git a/zed-rpc/src/proto.rs b/zed-rpc/src/proto.rs new file mode 100644 index 0000000000..177dbd6fa2 --- /dev/null +++ b/zed-rpc/src/proto.rs @@ -0,0 +1,234 @@ +use futures_io::{AsyncRead, AsyncWrite}; +use futures_lite::{AsyncReadExt, AsyncWriteExt as _}; +use prost::Message; +use std::io; + +include!(concat!(env!("OUT_DIR"), "/zed.messages.rs")); + +pub trait Request { + type Response; +} + +impl Request for from_client::Auth { + type Response = from_server::Ack; +} + +/// A stream of protobuf messages. +pub struct MessageStream { + byte_stream: T, + buffer: Vec, +} + +impl MessageStream { + pub fn new(byte_stream: T) -> Self { + Self { + byte_stream, + buffer: Default::default(), + } + } +} + +impl MessageStream +where + T: AsyncWrite + Unpin, +{ + /// Write a given protobuf message to the stream. + pub async fn write_message(&mut self, message: &impl Message) -> futures_io::Result<()> { + self.buffer.clear(); + message.encode_length_delimited(&mut self.buffer).unwrap(); + self.byte_stream.write_all(&self.buffer).await + } +} + +impl MessageStream +where + T: AsyncRead + Unpin, +{ + /// Read a protobuf message of the given type from the stream. + pub async fn read_message(&mut self) -> futures_io::Result { + // Ensure the buffer is large enough to hold the maximum delimiter length + const MAX_DELIMITER_LEN: usize = 10; + self.buffer.clear(); + self.buffer.resize(MAX_DELIMITER_LEN, 0); + + // Read until a complete length delimiter can be decoded. + let mut read_start_offset = 0; + let (encoded_len, delimiter_len) = loop { + let bytes_read = self + .byte_stream + .read(&mut self.buffer[read_start_offset..]) + .await?; + read_start_offset += bytes_read; + + let mut buffer = &self.buffer[0..read_start_offset]; + match prost::decode_length_delimiter(&mut buffer) { + Err(_) => { + if read_start_offset >= MAX_DELIMITER_LEN { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "invalid message length delimiter", + )); + } + } + Ok(encoded_len) => { + let delimiter_len = read_start_offset - buffer.len(); + break (encoded_len, delimiter_len); + } + } + }; + + // Read the message itself. + self.buffer.resize(delimiter_len + encoded_len, 0); + self.byte_stream + .read_exact(&mut self.buffer[read_start_offset..]) + .await?; + let message = M::decode(&self.buffer[delimiter_len..])?; + + Ok(message) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::{ + pin::Pin, + task::{Context, Poll}, + }; + + #[test] + fn test_round_trip_message() { + smol::block_on(async { + let byte_stream = ChunkedStream { + bytes: Vec::new(), + read_offset: 0, + chunk_size: 3, + }; + + // In reality there will never be both `FromClient` and `FromServer` messages + // sent in the same direction on the same stream. + let message1 = FromClient { + id: 3, + variant: Some(from_client::Variant::Auth(from_client::Auth { + user_id: 5, + access_token: "the-access-token".into(), + })), + }; + let message2 = FromServer { + request_id: Some(4), + variant: Some(from_server::Variant::Ack(from_server::Ack { + error_message: Some( + format!( + "a {}long error message that requires a two-byte length delimiter", + "very ".repeat(60) + ) + .into(), + ), + })), + }; + + let mut message_stream = MessageStream::new(byte_stream); + message_stream.write_message(&message1).await.unwrap(); + message_stream.write_message(&message2).await.unwrap(); + let decoded_message1 = message_stream.read_message::().await.unwrap(); + let decoded_message2 = message_stream.read_message::().await.unwrap(); + assert_eq!(decoded_message1, message1); + assert_eq!(decoded_message2, message2); + }); + } + + #[test] + fn test_read_message_when_length_delimiter_is_not_complete_in_first_read() { + smol::block_on(async { + let byte_stream = ChunkedStream { + bytes: Vec::new(), + read_offset: 0, + chunk_size: 2, + }; + + // This message is so long that its length delimiter requires three bytes, + // so it won't be delivered in a single read from the chunked byte stream. + let message = FromServer { + request_id: Some(4), + variant: Some(from_server::Variant::Ack(from_server::Ack { + error_message: Some("long ".repeat(256 * 256).into()), + })), + }; + assert!(prost::length_delimiter_len(message.encoded_len()) > byte_stream.chunk_size); + + let mut message_stream = MessageStream::new(byte_stream); + message_stream.write_message(&message).await.unwrap(); + let decoded_message = message_stream.read_message::().await.unwrap(); + assert_eq!(decoded_message, message); + }); + } + + #[test] + fn test_protobuf_parse_error() { + smol::block_on(async { + let byte_stream = ChunkedStream { + bytes: Vec::new(), + read_offset: 0, + chunk_size: 2, + }; + + let message = FromClient { + id: 3, + variant: Some(from_client::Variant::Auth(from_client::Auth { + user_id: 5, + access_token: "the-access-token".into(), + })), + }; + + let mut message_stream = MessageStream::new(byte_stream); + message_stream.write_message(&message).await.unwrap(); + + // Read the wrong type of message from the stream. + let result = message_stream.read_message::().await; + assert_eq!(result.unwrap_err().kind(), io::ErrorKind::InvalidData); + }); + } + + struct ChunkedStream { + bytes: Vec, + read_offset: usize, + chunk_size: usize, + } + + impl AsyncWrite for ChunkedStream { + fn poll_write( + mut self: Pin<&mut Self>, + _: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + let bytes_written = buf.len().min(self.chunk_size); + self.bytes.extend_from_slice(&buf[0..bytes_written]); + Poll::Ready(Ok(bytes_written)) + } + + fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + } + + impl AsyncRead for ChunkedStream { + fn poll_read( + mut self: Pin<&mut Self>, + _: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + let bytes_read = buf + .len() + .min(self.chunk_size) + .min(self.bytes.len() - self.read_offset); + let end_offset = self.read_offset + bytes_read; + buf[0..bytes_read].copy_from_slice(&self.bytes[self.read_offset..end_offset]); + self.read_offset = end_offset; + Poll::Ready(Ok(bytes_read)) + } + } +}