Change RpcClient methods to take shared references

This will make it easier to spawn a future on gpui's executors
when calling `RpcClient` methods.

Co-Authored-By: Max Brunsfeld <max@zed.dev>
This commit is contained in:
Antonio Scandurra 2021-06-14 19:59:46 +02:00
parent e551894189
commit a87d4db155
5 changed files with 115 additions and 96 deletions

15
Cargo.lock generated
View file

@ -1350,6 +1350,7 @@ checksum = "da9052a1a50244d8d5aa9bf55cbc2fb6f357c86cc52e46c62ed390a7180cf150"
dependencies = [
"futures-channel",
"futures-core",
"futures-executor",
"futures-io",
"futures-sink",
"futures-task",
@ -1372,6 +1373,17 @@ version = "0.3.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "79e5145dde8da7d1b3892dad07a9c98fc04bc39892b1ecc9692cf53e2b780a65"
[[package]]
name = "futures-executor"
version = "0.3.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e9e59fdc009a4b3096bf94f740a0f2424c082521f20a9b08c5c07c48d90fd9b9"
dependencies = [
"futures-core",
"futures-task",
"futures-util",
]
[[package]]
name = "futures-io"
version = "0.3.12"
@ -1423,6 +1435,7 @@ version = "0.3.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "632a8cd0f2a4b3fdea1657f08bde063848c3bd00f9bbf6e256b8be78802e624b"
dependencies = [
"futures-channel",
"futures-core",
"futures-io",
"futures-macro",
@ -4304,7 +4317,7 @@ dependencies = [
"easy-parallel",
"env_logger",
"fsevent",
"futures-core",
"futures",
"gpui",
"http-auth-basic",
"ignore",

View file

@ -21,7 +21,7 @@ ctor = "0.1.20"
dirs = "3.0"
easy-parallel = "3.1.0"
fsevent = { path = "../fsevent" }
futures-core = "0.3"
futures = "0.3"
gpui = { path = "../gpui" }
http-auth-basic = "0.1.3"
ignore = "0.4"

View file

@ -1,106 +1,112 @@
use anyhow::{anyhow, Result};
use futures::FutureExt;
use gpui::executor::Background;
use parking_lot::Mutex;
use postage::{
mpsc, oneshot,
mpsc,
prelude::{Sink, Stream},
};
use smol::{
future::FutureExt,
io::WriteHalf,
prelude::{AsyncRead, AsyncWrite},
use smol::prelude::{AsyncRead, AsyncWrite};
use std::{
collections::HashMap,
io,
sync::{
atomic::{self, AtomicI32},
Arc,
},
};
use std::{collections::HashMap, sync::Arc};
use zed_rpc::proto::{
self, MessageStream, RequestMessage, SendMessage, ServerMessage, SubscribeMessage,
};
pub struct RpcClient<Conn> {
stream: MessageStream<WriteHalf<Conn>>,
pub struct RpcClient {
response_channels: Arc<Mutex<HashMap<i32, (mpsc::Sender<proto::from_server::Variant>, bool)>>>,
next_message_id: i32,
_drop_tx: oneshot::Sender<()>,
outgoing_tx: mpsc::Sender<proto::FromClient>,
next_message_id: AtomicI32,
}
impl<Conn> RpcClient<Conn>
where
Conn: Clone + AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
pub fn new(conn: Conn, executor: Arc<Background>) -> Self {
let (conn_rx, conn_tx) = smol::io::split(conn);
let (drop_tx, mut drop_rx) = oneshot::channel();
impl RpcClient {
pub fn new<Conn>(conn: Conn, executor: Arc<Background>) -> Self
where
Conn: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
let response_channels = Arc::new(Mutex::new(HashMap::new()));
let client = Self {
next_message_id: 0,
stream: MessageStream::new(conn_tx),
response_channels: response_channels.clone(),
_drop_tx: drop_tx,
};
let (outgoing_tx, mut outgoing_rx) = mpsc::channel(32);
executor
.spawn::<Result<()>, _>(async move {
enum Message {
Message(proto::FromServer),
ClientDropped,
}
let mut stream = MessageStream::new(conn_rx);
let client_dropped = async move {
assert!(drop_rx.recv().await.is_none());
Ok(Message::ClientDropped) as Result<_>
};
smol::pin!(client_dropped);
loop {
let message = async {
Ok(Message::Message(
stream.read_message::<proto::FromServer>().await?,
))
};
match message.race(&mut client_dropped).await? {
Message::Message(message) => {
if let Some(variant) = message.variant {
if let Some(request_id) = message.request_id {
let channel = response_channels.lock().remove(&request_id);
if let Some((mut tx, oneshot)) = channel {
if tx.send(variant).await.is_ok() {
if !oneshot {
response_channels
.lock()
.insert(request_id, (tx, false));
}
}
} else {
log::warn!(
"received RPC response to unknown request id {}",
request_id
);
}
{
let response_channels = response_channels.clone();
executor
.spawn(async move {
let (conn_rx, conn_tx) = smol::io::split(conn);
let mut stream_tx = MessageStream::new(conn_tx);
let mut stream_rx = MessageStream::new(conn_rx);
loop {
futures::select! {
incoming = stream_rx.read_message::<proto::FromServer>().fuse() => {
Self::handle_incoming(incoming, &response_channels).await;
}
outgoing = outgoing_rx.recv().fuse() => {
if let Some(outgoing) = outgoing {
stream_tx.write_message(&outgoing).await;
} else {
break;
}
} else {
log::warn!("received RPC message with no content");
}
}
Message::ClientDropped => break Ok(()),
}
}
})
.detach();
})
.detach();
}
client
Self {
response_channels,
outgoing_tx,
next_message_id: AtomicI32::new(0),
}
}
pub async fn request<T: RequestMessage>(&mut self, req: T) -> Result<T::Response> {
let message_id = self.next_message_id;
self.next_message_id += 1;
async fn handle_incoming(
incoming: io::Result<proto::FromServer>,
response_channels: &Mutex<HashMap<i32, (mpsc::Sender<proto::from_server::Variant>, bool)>>,
) {
match incoming {
Ok(incoming) => {
if let Some(variant) = incoming.variant {
if let Some(request_id) = incoming.request_id {
let channel = response_channels.lock().remove(&request_id);
if let Some((mut tx, oneshot)) = channel {
if tx.send(variant).await.is_ok() {
if !oneshot {
response_channels.lock().insert(request_id, (tx, false));
}
}
} else {
log::warn!(
"received RPC response to unknown request id {}",
request_id
);
}
}
} else {
log::warn!("received RPC message with no content");
}
}
Err(error) => log::warn!("invalid incoming RPC message {:?}", error),
}
}
pub async fn request<T: RequestMessage>(&self, req: T) -> Result<T::Response> {
let message_id = self.next_message_id.fetch_add(1, atomic::Ordering::SeqCst);
let (tx, mut rx) = mpsc::channel(1);
self.response_channels.lock().insert(message_id, (tx, true));
self.stream
.write_message(&proto::FromClient {
self.outgoing_tx
.clone()
.send(proto::FromClient {
id: message_id,
variant: Some(req.to_variant()),
})
.await?;
.await
.unwrap();
let response = rx
.recv()
.await
@ -109,15 +115,16 @@ where
.ok_or_else(|| anyhow!("received response of the wrong t"))
}
pub async fn send<T: SendMessage>(&mut self, message: T) -> Result<()> {
let message_id = self.next_message_id;
self.next_message_id += 1;
self.stream
.write_message(&proto::FromClient {
pub async fn send<T: SendMessage>(&self, message: T) -> Result<()> {
let message_id = self.next_message_id.fetch_add(1, atomic::Ordering::SeqCst);
self.outgoing_tx
.clone()
.send(proto::FromClient {
id: message_id,
variant: Some(message.to_variant()),
})
.await?;
.await
.unwrap();
Ok(())
}
@ -125,19 +132,19 @@ where
&mut self,
subscription: T,
) -> Result<impl Stream<Item = Result<T::Event>>> {
let message_id = self.next_message_id;
self.next_message_id += 1;
let message_id = self.next_message_id.fetch_add(1, atomic::Ordering::SeqCst);
let (tx, rx) = mpsc::channel(256);
self.response_channels
.lock()
.insert(message_id, (tx, false));
self.stream
.write_message(&proto::FromClient {
self.outgoing_tx
.clone()
.send(proto::FromClient {
id: message_id,
variant: Some(subscription.to_variant()),
})
.await?;
.await
.unwrap();
Ok(rx.map(|event| {
T::Event::from_variant(event).ok_or_else(|| anyhow!("invalid event {:?}"))
}))
@ -165,7 +172,7 @@ mod tests {
let (server_conn, _) = listener.accept().await.unwrap();
let mut server_stream = MessageStream::new(server_conn);
let mut client = RpcClient::new(client_conn, executor.clone());
let client = RpcClient::new(client_conn, executor.clone());
let client_req = client.request(proto::from_client::Auth {
user_id: 42,

View file

@ -8,7 +8,6 @@ use crate::{
worktree::{FileHandle, Worktree, WorktreeHandle},
AppState,
};
use futures_core::Future;
use gpui::{
color::rgbu, elements::*, json::to_string_pretty, keymap::Binding, AnyViewHandle, AppContext,
ClipboardItem, Entity, ModelHandle, MutableAppContext, PathPromptOptions, PromptLevel, Task,
@ -19,10 +18,10 @@ pub use pane::*;
pub use pane_group::*;
use postage::watch;
use smol::prelude::*;
use std::{collections::HashMap, path::PathBuf};
use std::{
collections::{hash_map::Entry, HashSet},
path::Path,
collections::{hash_map::Entry, HashMap, HashSet},
future::Future,
path::{Path, PathBuf},
sync::Arc,
};

View file

@ -1207,7 +1207,7 @@ pub trait WorktreeHandle {
fn flush_fs_events<'a>(
&self,
cx: &'a gpui::TestAppContext,
) -> futures_core::future::LocalBoxFuture<'a, ()>;
) -> futures::future::LocalBoxFuture<'a, ()>;
}
impl WorktreeHandle for ModelHandle<Worktree> {
@ -1268,7 +1268,7 @@ impl WorktreeHandle for ModelHandle<Worktree> {
fn flush_fs_events<'a>(
&self,
cx: &'a gpui::TestAppContext,
) -> futures_core::future::LocalBoxFuture<'a, ()> {
) -> futures::future::LocalBoxFuture<'a, ()> {
use smol::future::FutureExt;
let filename = "fs-event-sentinel";