Ensure that futures returns from RpcClient are 'static

This commit is contained in:
Antonio Scandurra 2021-06-15 11:15:55 +02:00
parent 04bf84af44
commit 7b96888ab1

View file

@ -12,6 +12,7 @@ use smol::{
};
use std::{
collections::HashMap,
future::Future,
sync::{
atomic::{self, AtomicI32},
Arc,
@ -32,7 +33,7 @@ impl<Conn> RpcClient<Conn>
where
Conn: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
pub fn new(conn: Conn, executor: Arc<Background>) -> Self {
pub fn new(conn: Conn, executor: Arc<Background>) -> Arc<Self> {
let response_channels = Arc::new(Mutex::new(HashMap::new()));
let (conn_rx, conn_tx) = smol::io::split(conn);
let (_drop_tx, drop_rx) = barrier::channel();
@ -45,12 +46,12 @@ where
))
.detach();
Self {
Arc::new(Self {
response_channels,
outgoing: Mutex::new(MessageStream::new(conn_tx)),
_drop_tx,
next_message_id: AtomicI32::new(0),
}
})
}
async fn handle_incoming(
@ -101,63 +102,75 @@ where
}
}
pub async fn request<T: RequestMessage>(&self, req: T) -> Result<T::Response> {
let message_id = self.next_message_id.fetch_add(1, atomic::Ordering::SeqCst);
let (tx, mut rx) = mpsc::channel(1);
self.response_channels
.lock()
.await
.insert(message_id, (tx, true));
self.outgoing
.lock()
.await
.write_message(&proto::FromClient {
id: message_id,
variant: Some(req.to_variant()),
})
.await?;
let response = rx
.recv()
.await
.expect("response channel was unexpectedly dropped");
T::Response::from_variant(response)
.ok_or_else(|| anyhow!("received response of the wrong t"))
pub fn request<T: RequestMessage>(
self: &Arc<Self>,
req: T,
) -> impl Future<Output = Result<T::Response>> {
let this = self.clone();
async move {
let message_id = this.next_message_id.fetch_add(1, atomic::Ordering::SeqCst);
let (tx, mut rx) = mpsc::channel(1);
this.response_channels
.lock()
.await
.insert(message_id, (tx, true));
this.outgoing
.lock()
.await
.write_message(&proto::FromClient {
id: message_id,
variant: Some(req.to_variant()),
})
.await?;
let response = rx
.recv()
.await
.expect("response channel was unexpectedly dropped");
T::Response::from_variant(response)
.ok_or_else(|| anyhow!("received response of the wrong t"))
}
}
pub async fn send<T: SendMessage>(&self, message: T) -> Result<()> {
let message_id = self.next_message_id.fetch_add(1, atomic::Ordering::SeqCst);
self.outgoing
.lock()
.await
.write_message(&proto::FromClient {
id: message_id,
variant: Some(message.to_variant()),
})
.await?;
Ok(())
pub fn send<T: SendMessage>(self: &Arc<Self>, message: T) -> impl Future<Output = Result<()>> {
let this = self.clone();
async move {
let message_id = this.next_message_id.fetch_add(1, atomic::Ordering::SeqCst);
this.outgoing
.lock()
.await
.write_message(&proto::FromClient {
id: message_id,
variant: Some(message.to_variant()),
})
.await?;
Ok(())
}
}
pub async fn subscribe<T: SubscribeMessage>(
&self,
pub fn subscribe<T: SubscribeMessage>(
self: &Arc<Self>,
subscription: T,
) -> Result<impl Stream<Item = Result<T::Event>>> {
let message_id = self.next_message_id.fetch_add(1, atomic::Ordering::SeqCst);
let (tx, rx) = mpsc::channel(256);
self.response_channels
.lock()
.await
.insert(message_id, (tx, false));
self.outgoing
.lock()
.await
.write_message(&proto::FromClient {
id: message_id,
variant: Some(subscription.to_variant()),
})
.await?;
Ok(rx.map(|event| {
T::Event::from_variant(event).ok_or_else(|| anyhow!("invalid event {:?}"))
}))
) -> impl Future<Output = Result<impl Stream<Item = Result<T::Event>>>> {
let this = self.clone();
async move {
let message_id = this.next_message_id.fetch_add(1, atomic::Ordering::SeqCst);
let (tx, rx) = mpsc::channel(256);
this.response_channels
.lock()
.await
.insert(message_id, (tx, false));
this.outgoing
.lock()
.await
.write_message(&proto::FromClient {
id: message_id,
variant: Some(subscription.to_variant()),
})
.await?;
Ok(rx.map(|event| {
T::Event::from_variant(event).ok_or_else(|| anyhow!("invalid event {:?}"))
}))
}
}
}