diff --git a/zrpc/src/peer.rs b/zrpc/src/peer.rs index eeda034e95..251ffb5bb5 100644 --- a/zrpc/src/peer.rs +++ b/zrpc/src/peer.rs @@ -87,7 +87,7 @@ pub struct Peer { struct ConnectionState { outgoing_tx: mpsc::Sender, next_message_id: Arc, - response_channels: Arc>>>, + response_channels: Arc>>>>, } impl Peer { @@ -115,7 +115,7 @@ impl Peer { let connection_state = ConnectionState { outgoing_tx, next_message_id: Default::default(), - response_channels: Default::default(), + response_channels: Arc::new(Mutex::new(Some(Default::default()))), }; let mut writer = MessageStream::new(connection.tx); let mut reader = MessageStream::new(connection.rx); @@ -123,7 +123,7 @@ impl Peer { let this = self.clone(); let response_channels = connection_state.response_channels.clone(); let handle_io = async move { - loop { + let result = 'outer: loop { let read_message = reader.read_message().fuse(); futures::pin_mut!(read_message); loop { @@ -131,7 +131,7 @@ impl Peer { incoming = read_message => match incoming { Ok(incoming) => { if let Some(responding_to) = incoming.responding_to { - let channel = response_channels.lock().await.remove(&responding_to); + let channel = response_channels.lock().await.as_mut().unwrap().remove(&responding_to); if let Some(mut tx) = channel { tx.send(incoming).await.ok(); } else { @@ -140,9 +140,7 @@ impl Peer { } else { if let Some(envelope) = proto::build_typed_envelope(connection_id, incoming) { if incoming_tx.send(envelope).await.is_err() { - response_channels.lock().await.clear(); - this.connections.write().await.remove(&connection_id); - return Ok(()) + break 'outer Ok(()) } } else { log::error!("unable to construct a typed envelope"); @@ -152,28 +150,24 @@ impl Peer { break; } Err(error) => { - response_channels.lock().await.clear(); - this.connections.write().await.remove(&connection_id); - Err(error).context("received invalid RPC message")?; + break 'outer Err(error).context("received invalid RPC message") } }, outgoing = outgoing_rx.recv().fuse() => match outgoing { Some(outgoing) => { if let Err(result) = writer.write_message(&outgoing).await { - response_channels.lock().await.clear(); - this.connections.write().await.remove(&connection_id); - Err(result).context("failed to write RPC message")?; + break 'outer Err(result).context("failed to write RPC message") } } - None => { - response_channels.lock().await.clear(); - this.connections.write().await.remove(&connection_id); - return Ok(()) - } + None => break 'outer Ok(()), } } } - } + }; + + response_channels.lock().await.take(); + this.connections.write().await.remove(&connection_id); + result }; self.connections @@ -226,6 +220,8 @@ impl Peer { .response_channels .lock() .await + .as_mut() + .ok_or_else(|| anyhow!("connection was closed"))? .insert(message_id, tx); connection .outgoing_tx @@ -520,8 +516,7 @@ mod tests { #[test] fn test_io_error() { smol::block_on(async move { - let (client_conn, server_conn, _) = Connection::in_memory(); - drop(server_conn); + let (client_conn, mut server_conn, _) = Connection::in_memory(); let client = Peer::new(); let (connection_id, io_handler, mut incoming) = @@ -529,11 +524,14 @@ mod tests { smol::spawn(io_handler).detach(); smol::spawn(async move { incoming.next().await }).detach(); - let err = client - .request(connection_id, proto::Ping {}) - .await - .unwrap_err(); - assert_eq!(err.to_string(), "connection was closed"); + let response = smol::spawn(client.request(connection_id, proto::Ping {})); + let _request = server_conn.rx.next().await.unwrap().unwrap(); + + drop(server_conn); + assert_eq!( + response.await.unwrap_err().to_string(), + "connection was closed" + ); }); } }