devices: virtio: vsock: Use async locks in async contexts

In the vsock code, we use a mutex to protect our connections `HashMap`
so that we can have access to them from multiple async tasks. We were
using a regular synchronous mutex for this, which could cause the async
tasks to block on these mutexes, possibly leading to blocking up the
executor and deadlocking vsock.

We haven't observed any bugs or deadlocks that are directly attributable
to this; it's likely that we are managing to avoid this by not holding
the lock over await points. We should still fix this for correctness, as
we have no way to otherwise enforce that future changes should uphold
the current guarantees.

BUG: b:247548758
TEST: Built and ran crosvm downstream.
Change-Id: I8928514be491f111887fbf1adac7a3f8b38219dd
Reviewed-on: https://chromium-review.googlesource.com/c/crosvm/crosvm/+/4062047
Commit-Queue: Richard Otap <rotap@google.com>
Reviewed-by: Noah Gold <nkgold@google.com>
This commit is contained in:
Richard Otap 2022-11-08 12:18:36 -08:00 committed by crosvm LUCI
parent 2594731840
commit c2ff5b359e

View file

@ -2,9 +2,6 @@
// Use of this source code is governed by a BSD-style license that can be // Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file. // found in the LICENSE file.
// TODO(247548758): Remove this once all the warnings are fixed.
#![allow(clippy::await_holding_lock)]
use std::cell::RefCell; use std::cell::RefCell;
use std::collections::HashMap; use std::collections::HashMap;
use std::fmt; use std::fmt;
@ -16,7 +13,6 @@ use std::os::windows::io::RawHandle;
use std::rc::Rc; use std::rc::Rc;
use std::result; use std::result;
use std::sync::Arc; use std::sync::Arc;
use std::sync::RwLock;
use std::thread; use std::thread;
use base::error; use base::error;
@ -33,6 +29,7 @@ use base::Event;
use base::EventExt; use base::EventExt;
use cros_async::select2; use cros_async::select2;
use cros_async::select6; use cros_async::select6;
use cros_async::sync::Mutex;
use cros_async::AsyncError; use cros_async::AsyncError;
use cros_async::EventAsync; use cros_async::EventAsync;
use cros_async::Executor; use cros_async::Executor;
@ -47,7 +44,6 @@ use futures::FutureExt;
use futures::SinkExt; use futures::SinkExt;
use futures::StreamExt; use futures::StreamExt;
use remain::sorted; use remain::sorted;
use sync::Mutex;
use thiserror::Error as ThisError; use thiserror::Error as ThisError;
use vm_memory::GuestMemory; use vm_memory::GuestMemory;
@ -291,7 +287,7 @@ impl PortPair {
} }
// Note: variables herein do not have to be atomic because this struct is guarded // Note: variables herein do not have to be atomic because this struct is guarded
// by a RwLock. // by a Mutex.
struct VsockConnection { struct VsockConnection {
// The guest port. // The guest port.
guest_port: Le32, guest_port: Le32,
@ -339,7 +335,7 @@ struct Worker {
host_guid: Option<String>, host_guid: Option<String>,
guest_cid: u64, guest_cid: u64,
// Map of host port to a VsockConnection. // Map of host port to a VsockConnection.
connections: RwLock<HashMap<PortPair, VsockConnection>>, connections: Mutex<HashMap<PortPair, VsockConnection>>,
connection_event: Event, connection_event: Event,
} }
@ -355,7 +351,7 @@ impl Worker {
interrupt, interrupt,
host_guid, host_guid,
guest_cid, guest_cid,
connections: RwLock::new(HashMap::new()), connections: Mutex::new(HashMap::new()),
connection_event: Event::new().unwrap(), connection_event: Event::new().unwrap(),
} }
} }
@ -372,9 +368,9 @@ impl Worker {
// TODO(b/200810561): Optimize this FuturesUnordered code. // TODO(b/200810561): Optimize this FuturesUnordered code.
// Set up the EventAsyncs to select on // Set up the EventAsyncs to select on
let futures = FuturesUnordered::new(); let futures = FuturesUnordered::new();
// This needs to be its own scope since it holds a RwLock on `self.connections`. // This needs to be its own scope since it holds a Mutex on `self.connections`.
{ {
let connections = self.connections.read().unwrap(); let connections = self.connections.read_lock().await;
for (port, connection) in connections.iter() { for (port, connection) in connections.iter() {
let h_evt = connection let h_evt = connection
.overlapped .overlapped
@ -433,7 +429,7 @@ impl Worker {
} }
continue 'connections_changed; continue 'connections_changed;
} }
let mut connections = self.connections.write().unwrap(); let mut connections = self.connections.lock().await;
let connection = if let Some(conn) = connections.get_mut(&port) { let connection = if let Some(conn) = connections.get_mut(&port) {
conn conn
} else { } else {
@ -520,7 +516,7 @@ impl Worker {
header_and_data[..HEADER_SIZE].copy_from_slice(response_header.as_slice()); header_and_data[..HEADER_SIZE].copy_from_slice(response_header.as_slice());
header_and_data[HEADER_SIZE..].copy_from_slice(data_read); header_and_data[HEADER_SIZE..].copy_from_slice(data_read);
self.write_bytes_to_queue( self.write_bytes_to_queue(
&mut recv_queue.lock(), &mut *recv_queue.lock().await,
&mut rx_queue_evt, &mut rx_queue_evt,
&header_and_data[..], &header_and_data[..],
) )
@ -621,11 +617,11 @@ impl Worker {
/// Processes a connection request and returns whether to return a response (true), or reset /// Processes a connection request and returns whether to return a response (true), or reset
/// (false). /// (false).
fn handle_vsock_connection_request(&self, header: virtio_vsock_hdr) -> bool { async fn handle_vsock_connection_request(&self, header: virtio_vsock_hdr) -> bool {
let port = PortPair::from_tx_header(&header); let port = PortPair::from_tx_header(&header);
info!("vsock: Received connection request for port {}", port); info!("vsock: Received connection request for port {}", port);
if self.connections.read().unwrap().contains_key(&port) { if self.connections.read_lock().await.contains_key(&port) {
// Connection exists, nothing for us to do. // Connection exists, nothing for us to do.
warn!( warn!(
"vsock: accepting connection request on already connected port {}", "vsock: accepting connection request on already connected port {}",
@ -691,7 +687,7 @@ impl Worker {
tx_cnt: 0_usize, tx_cnt: 0_usize,
is_buffer_full: false, is_buffer_full: false,
}; };
self.connections.write().unwrap().insert(port, connection); self.connections.lock().await.insert(port, connection);
self.connection_event.signal().unwrap_or_else(|_| { self.connection_event.signal().unwrap_or_else(|_| {
panic!( panic!(
"Failed to signal new connection event for vsock port {}.", "Failed to signal new connection event for vsock port {}.",
@ -720,7 +716,7 @@ impl Worker {
let port = PortPair::from_tx_header(&header); let port = PortPair::from_tx_header(&header);
let mut overlapped_wrapper = OverlappedWrapper::new(/* include_event= */ true).unwrap(); let mut overlapped_wrapper = OverlappedWrapper::new(/* include_event= */ true).unwrap();
{ {
let mut connections = self.connections.write().unwrap(); let mut connections = self.connections.lock().await;
if let Some(connection) = connections.get_mut(&port) { if let Some(connection) = connections.get_mut(&port) {
// Update peer buffer/recv counters // Update peer buffer/recv counters
connection.peer_recv_cnt = header.fwd_cnt.to_native() as usize; connection.peer_recv_cnt = header.fwd_cnt.to_native() as usize;
@ -766,7 +762,7 @@ impl Worker {
); );
} }
let mut connections = self.connections.write().unwrap(); let mut connections = self.connections.lock().await;
if let Some(connection) = connections.get_mut(&port) { if let Some(connection) = connections.get_mut(&port) {
let pipe = &mut connection.pipe; let pipe = &mut connection.pipe;
match pipe.get_overlapped_result(&mut overlapped_wrapper) { match pipe.get_overlapped_result(&mut overlapped_wrapper) {
@ -911,29 +907,29 @@ impl Worker {
error!("vsock: Invalid Operation requested, dropping packet"); error!("vsock: Invalid Operation requested, dropping packet");
} }
vsock_op::VIRTIO_VSOCK_OP_REQUEST => { vsock_op::VIRTIO_VSOCK_OP_REQUEST => {
let (resp_op, buf_alloc, fwd_cnt) = if self.handle_vsock_connection_request(header) let (resp_op, buf_alloc, fwd_cnt) =
{ if self.handle_vsock_connection_request(header).await {
let connections = self.connections.read().unwrap(); let connections = self.connections.read_lock().await;
let port = PortPair::from_tx_header(&header); let port = PortPair::from_tx_header(&header);
connections.get(&port).map_or_else( connections.get(&port).map_or_else(
|| { || {
warn!("vsock: port: {} connection closed during connect", port); warn!("vsock: port: {} connection closed during connect", port);
is_open = false; is_open = false;
(vsock_op::VIRTIO_VSOCK_OP_RST, 0, 0) (vsock_op::VIRTIO_VSOCK_OP_RST, 0, 0)
}, },
|conn| { |conn| {
( (
vsock_op::VIRTIO_VSOCK_OP_RESPONSE, vsock_op::VIRTIO_VSOCK_OP_RESPONSE,
conn.buf_alloc as u32, conn.buf_alloc as u32,
conn.recv_cnt as u32, conn.recv_cnt as u32,
) )
}, },
) )
} else { } else {
is_open = false; is_open = false;
(vsock_op::VIRTIO_VSOCK_OP_RST, 0, 0) (vsock_op::VIRTIO_VSOCK_OP_RST, 0, 0)
}; };
let response_header = virtio_vsock_hdr { let response_header = virtio_vsock_hdr {
src_cid: { header.dst_cid }, src_cid: { header.dst_cid },
@ -950,7 +946,7 @@ impl Worker {
// Safe because virtio_vsock_hdr is a simple data struct and converts cleanly to // Safe because virtio_vsock_hdr is a simple data struct and converts cleanly to
// bytes. // bytes.
self.write_bytes_to_queue( self.write_bytes_to_queue(
&mut send_queue.lock(), &mut *send_queue.lock().await,
rx_queue_evt, rx_queue_evt,
response_header.as_slice(), response_header.as_slice(),
) )
@ -969,7 +965,7 @@ impl Worker {
// TODO(b/237811512): Provide an optimal way to notify host of shutdowns // TODO(b/237811512): Provide an optimal way to notify host of shutdowns
// while still maintaining easy reconnections. // while still maintaining easy reconnections.
let port = PortPair::from_tx_header(&header); let port = PortPair::from_tx_header(&header);
let mut connections = self.connections.write().unwrap(); let mut connections = self.connections.lock().await;
if connections.remove(&port).is_some() { if connections.remove(&port).is_some() {
let mut response = virtio_vsock_hdr { let mut response = virtio_vsock_hdr {
src_cid: { header.dst_cid }, src_cid: { header.dst_cid },
@ -987,7 +983,7 @@ impl Worker {
}; };
// Safe because virtio_vsock_hdr is a simple data struct and converts cleanly to bytes // Safe because virtio_vsock_hdr is a simple data struct and converts cleanly to bytes
self.write_bytes_to_queue( self.write_bytes_to_queue(
&mut send_queue.lock(), &mut *send_queue.lock().await,
rx_queue_evt, rx_queue_evt,
response.as_mut_slice(), response.as_mut_slice(),
) )
@ -1004,7 +1000,11 @@ impl Worker {
vsock_op::VIRTIO_VSOCK_OP_RW => { vsock_op::VIRTIO_VSOCK_OP_RW => {
match self.handle_vsock_guest_data(header, data, ex).await { match self.handle_vsock_guest_data(header, data, ex).await {
Ok(()) => { Ok(()) => {
if self.check_free_buffer_threshold(header).unwrap_or(false) { if self
.check_free_buffer_threshold(header)
.await
.unwrap_or(false)
{
// Send a credit update if we're below the minimum free // Send a credit update if we're below the minimum free
// buffer size. We skip this if the connection is closed, // buffer size. We skip this if the connection is closed,
// which could've happened if we were closed on the other // which could've happened if we were closed on the other
@ -1026,7 +1026,7 @@ impl Worker {
// (probably) due to a credit request *we* made. // (probably) due to a credit request *we* made.
vsock_op::VIRTIO_VSOCK_OP_CREDIT_UPDATE => { vsock_op::VIRTIO_VSOCK_OP_CREDIT_UPDATE => {
let port = PortPair::from_tx_header(&header); let port = PortPair::from_tx_header(&header);
let mut connections = self.connections.write().unwrap(); let mut connections = self.connections.lock().await;
if let Some(connection) = connections.get_mut(&port) { if let Some(connection) = connections.get_mut(&port) {
connection.peer_recv_cnt = header.fwd_cnt.to_native() as usize; connection.peer_recv_cnt = header.fwd_cnt.to_native() as usize;
connection.peer_buf_alloc = header.buf_alloc.to_native() as usize; connection.peer_buf_alloc = header.buf_alloc.to_native() as usize;
@ -1053,8 +1053,8 @@ impl Worker {
// Checks if how much free buffer our peer thinks that *we* have available // Checks if how much free buffer our peer thinks that *we* have available
// is below our threshold percentage. If the connection is closed, returns `None`. // is below our threshold percentage. If the connection is closed, returns `None`.
fn check_free_buffer_threshold(&self, header: virtio_vsock_hdr) -> Option<bool> { async fn check_free_buffer_threshold(&self, header: virtio_vsock_hdr) -> Option<bool> {
let mut connections = self.connections.write().unwrap(); let mut connections = self.connections.lock().await;
let port = PortPair::from_tx_header(&header); let port = PortPair::from_tx_header(&header);
connections.get_mut(&port).map(|connection| { connections.get_mut(&port).map(|connection| {
let threshold: usize = (MIN_FREE_BUFFER_PCT * connection.buf_alloc as f64) as usize; let threshold: usize = (MIN_FREE_BUFFER_PCT * connection.buf_alloc as f64) as usize;
@ -1068,7 +1068,7 @@ impl Worker {
rx_queue_evt: &mut EventAsync, rx_queue_evt: &mut EventAsync,
header: virtio_vsock_hdr, header: virtio_vsock_hdr,
) { ) {
let mut connections = self.connections.write().unwrap(); let mut connections = self.connections.lock().await;
let port = PortPair::from_tx_header(&header); let port = PortPair::from_tx_header(&header);
if let Some(connection) = connections.get_mut(&port) { if let Some(connection) = connections.get_mut(&port) {
@ -1090,7 +1090,7 @@ impl Worker {
// Safe because virtio_vsock_hdr is a simple data struct and converts cleanly // Safe because virtio_vsock_hdr is a simple data struct and converts cleanly
// to bytes // to bytes
self.write_bytes_to_queue( self.write_bytes_to_queue(
&mut send_queue.lock(), &mut *send_queue.lock().await,
rx_queue_evt, rx_queue_evt,
response.as_mut_slice(), response.as_mut_slice(),
) )
@ -1110,7 +1110,7 @@ impl Worker {
rx_queue_evt: &mut EventAsync, rx_queue_evt: &mut EventAsync,
header: virtio_vsock_hdr, header: virtio_vsock_hdr,
) { ) {
let mut connections = self.connections.write().unwrap(); let mut connections = self.connections.lock().await;
let port = PortPair::from_tx_header(&header); let port = PortPair::from_tx_header(&header);
if let Some(connection) = connections.remove(&port) { if let Some(connection) = connections.remove(&port) {
let mut response = virtio_vsock_hdr { let mut response = virtio_vsock_hdr {
@ -1129,7 +1129,7 @@ impl Worker {
// Safe because virtio_vsock_hdr is a simple data struct and converts cleanly // Safe because virtio_vsock_hdr is a simple data struct and converts cleanly
// to bytes // to bytes
self.write_bytes_to_queue( self.write_bytes_to_queue(
&mut send_queue.lock(), &mut *send_queue.lock().await,
rx_queue_evt, rx_queue_evt,
response.as_mut_slice(), response.as_mut_slice(),
) )