diff --git a/devices/src/virtio/vhost_user_frontend/mod.rs b/devices/src/virtio/vhost_user_frontend/mod.rs index 26cbdd8e29..6a28097cb6 100644 --- a/devices/src/virtio/vhost_user_frontend/mod.rs +++ b/devices/src/virtio/vhost_user_frontend/mod.rs @@ -14,6 +14,7 @@ use std::cell::RefCell; use std::collections::BTreeMap; use std::io::Read; use std::io::Write; +use std::sync::Arc; use anyhow::bail; use anyhow::Context; @@ -24,6 +25,7 @@ use base::Event; use base::RawDescriptor; use base::WorkerThread; use serde_json::Value; +use sync::Mutex; use vm_memory::GuestMemory; use vmm_vhost::message::VhostUserConfigFlags; use vmm_vhost::message::VhostUserMigrationPhase; @@ -54,7 +56,7 @@ pub struct VhostUserFrontend { device_type: DeviceType, worker_thread: Option>>, - backend_client: BackendClient, + backend_client: Arc>, avail_features: u64, acked_features: u64, protocol_features: VhostUserProtocolFeatures, @@ -238,7 +240,7 @@ impl VhostUserFrontend { Ok(VhostUserFrontend { device_type, worker_thread: None, - backend_client, + backend_client: Arc::new(Mutex::new(backend_client)), avail_features, acked_features, protocol_features, @@ -265,6 +267,7 @@ impl VhostUserFrontend { .collect(); self.backend_client + .lock() .set_mem_table(regions.as_slice()) .map_err(Error::SetMemTable)?; @@ -279,7 +282,8 @@ impl VhostUserFrontend { queue: &Queue, irqfd: &Event, ) -> Result<()> { - self.backend_client + let backend_client = self.backend_client.lock(); + backend_client .set_vring_num(queue_index, queue.size()) .map_err(Error::SetVringNum)?; @@ -297,25 +301,25 @@ impl VhostUserFrontend { .map_err(Error::GetHostAddress)? 
as u64, log_addr: None, }; - self.backend_client + backend_client .set_vring_addr(queue_index, &config_data) .map_err(Error::SetVringAddr)?; - self.backend_client + backend_client .set_vring_base(queue_index, queue.next_avail_to_process()) .map_err(Error::SetVringBase)?; - self.backend_client + backend_client .set_vring_call(queue_index, irqfd) .map_err(Error::SetVringCall)?; - self.backend_client + backend_client .set_vring_kick(queue_index, queue.event()) .map_err(Error::SetVringKick)?; // Per protocol documentation, `VHOST_USER_SET_VRING_ENABLE` should be sent only when // `VHOST_USER_F_PROTOCOL_FEATURES` has been negotiated. if self.acked_features & 1 << VHOST_USER_F_PROTOCOL_FEATURES != 0 { - self.backend_client + backend_client .set_vring_enable(queue_index, true) .map_err(Error::SetVringEnable)?; } @@ -325,14 +329,15 @@ impl VhostUserFrontend { /// Stops the vring for the given `queue`, returning its base index. fn deactivate_vring(&self, queue_index: usize) -> Result { + let backend_client = self.backend_client.lock(); + if self.acked_features & 1 << VHOST_USER_F_PROTOCOL_FEATURES != 0 { - self.backend_client + backend_client .set_vring_enable(queue_index, false) .map_err(Error::SetVringEnable)?; } - let vring_base = self - .backend_client + let vring_base = backend_client .get_vring_base(queue_index) .map_err(Error::GetVringBase)?; @@ -357,11 +362,16 @@ impl VhostUserFrontend { handler.frontend_mut().set_interrupt(interrupt.clone()); } + #[cfg(any(target_os = "android", target_os = "linux"))] + let backend_client = self.backend_client.clone(); + self.worker_thread = Some(WorkerThread::start(label.clone(), move |kill_evt| { let mut worker = Worker { kill_evt, non_msix_evt, backend_req_handler, + #[cfg(any(target_os = "android", target_os = "linux"))] + backend_client, }; if let Err(e) = worker.run(interrupt) { error!("failed to run {} worker: {:#}", label, e); @@ -397,6 +407,7 @@ impl VirtioDevice for VhostUserFrontend { let features = (features & 
self.avail_features) | self.acked_features; if let Err(e) = self .backend_client + .lock() .set_features(features) .map_err(Error::SetFeatures) { @@ -423,7 +434,7 @@ impl VirtioDevice for VhostUserFrontend { ); return; }; - let (_, config) = match self.backend_client.get_config( + let (_, config) = match self.backend_client.lock().get_config( offset, data_len, VhostUserConfigFlags::WRITABLE, @@ -445,6 +456,7 @@ impl VirtioDevice for VhostUserFrontend { }; if let Err(e) = self .backend_client + .lock() .set_config(offset, VhostUserConfigFlags::empty(), data) .map_err(Error::SetConfig) { @@ -514,6 +526,7 @@ impl VirtioDevice for VhostUserFrontend { } let regions = match self .backend_client + .lock() .get_shared_memory_regions() .map_err(Error::ShmemRegions) { @@ -609,11 +622,11 @@ impl VirtioDevice for VhostUserFrontend { { bail!("snapshot requires VHOST_USER_PROTOCOL_F_DEVICE_STATE"); } + let backend_client = self.backend_client.lock(); // Send the backend an FD to write the device state to. If it gives us an FD back, then // we need to read from that instead. let (mut r, w) = new_pipe_pair()?; - let backend_r = self - .backend_client + let backend_r = backend_client .set_device_state_fd( VhostUserTransferDirection::Save, VhostUserMigrationPhase::Stopped, @@ -632,7 +645,7 @@ impl VirtioDevice for VhostUserFrontend { } .context("failed to read device state")?; // Call `check_device_state` to ensure the data transfer was successful. - self.backend_client + backend_client .check_device_state() .context("failed to transfer device state")?; Ok(serde_json::to_value(snapshot_bytes).map_err(Error::SliceToSerdeValue)?) @@ -646,12 +659,12 @@ impl VirtioDevice for VhostUserFrontend { bail!("restore requires VHOST_USER_PROTOCOL_F_DEVICE_STATE"); } + let backend_client = self.backend_client.lock(); let data_bytes: Vec = serde_json::from_value(data).map_err(Error::SerdeValueToSlice)?; // Send the backend an FD to read the device state from. 
If it gives us an FD back, // then we need to write to that instead. let (r, w) = new_pipe_pair()?; - let backend_w = self - .backend_client + let backend_w = backend_client .set_device_state_fd( VhostUserTransferDirection::Load, VhostUserMigrationPhase::Stopped, @@ -673,7 +686,7 @@ impl VirtioDevice for VhostUserFrontend { .context("failed to write device state")?; } // Call `check_device_state` to ensure the data transfer was successful. - self.backend_client + backend_client .check_device_state() .context("failed to transfer device state")?; Ok(()) diff --git a/devices/src/virtio/vhost_user_frontend/worker.rs b/devices/src/virtio/vhost_user_frontend/worker.rs index 369ebe0503..567ddfa02a 100644 --- a/devices/src/virtio/vhost_user_frontend/worker.rs +++ b/devices/src/virtio/vhost_user_frontend/worker.rs @@ -2,9 +2,12 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. +use std::sync::Arc; + use anyhow::bail; use anyhow::Context; use base::info; +use base::warn; #[cfg(windows)] use base::CloseNotifier; use base::Event; @@ -12,6 +15,8 @@ use base::EventToken; use base::EventType; use base::ReadNotifier; use base::WaitContext; +use sync::Mutex; +use vmm_vhost::BackendClient; use vmm_vhost::Error as VhostError; use crate::virtio::vhost_user_frontend::handler::BackendReqHandler; @@ -22,6 +27,8 @@ pub struct Worker { pub kill_evt: Event, pub non_msix_evt: Event, pub backend_req_handler: Option, + #[cfg(any(target_os = "android", target_os = "linux"))] + pub backend_client: Arc>, } impl Worker { @@ -33,8 +40,10 @@ impl Worker { Resample, ReqHandlerRead, ReqHandlerClose, + // monitor whether backend_client_fd is broken + #[cfg(any(target_os = "android", target_os = "linux"))] + BackendCloseNotify, } - let wait_ctx = WaitContext::build_with(&[ (&self.non_msix_evt, Token::NonMsixEvt), (&self.kill_evt, Token::Kill), @@ -72,6 +81,15 @@ impl Worker { .context("failed to add backend req handler close notifier to 
WaitContext")?; } + #[cfg(any(target_os = "android", target_os = "linux"))] + wait_ctx + .add_for_event( + self.backend_client.lock().get_read_notifier(), + EventType::None, + Token::BackendCloseNotify, + ) + .context("failed to add backend client close notifier to WaitContext")?; + 'wait: loop { let events = wait_ctx.wait().context("WaitContext::wait() failed")?; for event in events { @@ -125,6 +143,16 @@ let _ = wait_ctx.delete(backend_req_handler.get_close_notifier()); self.backend_req_handler = None; } + #[cfg(any(target_os = "android", target_os = "linux"))] + Token::BackendCloseNotify => { + // For a Linux domain socket, the close notifier fd is the same as the read/write + // notifier, so we need to check whether the event was caused by the socket breaking. + if !event.is_hungup { + warn!("event besides hungup should not be notified"); + continue; + } + panic!("Backend device disconnected"); + } } } } diff --git a/third_party/vmm_vhost/src/backend_client.rs b/third_party/vmm_vhost/src/backend_client.rs index eee561faaf..f091ca9a28 100644 --- a/third_party/vmm_vhost/src/backend_client.rs +++ b/third_party/vmm_vhost/src/backend_client.rs @@ -5,8 +5,11 @@ use std::fs::File; use std::mem; use base::AsRawDescriptor; +#[cfg(windows)] +use base::CloseNotifier; use base::Event; use base::RawDescriptor; +use base::ReadNotifier; use base::INVALID_DESCRIPTOR; use zerocopy::AsBytes; use zerocopy::FromBytes; @@ -643,6 +646,19 @@ impl BackendClient { } } +#[cfg(windows)] +impl CloseNotifier for BackendClient { + fn get_close_notifier(&self) -> &dyn AsRawDescriptor { + self.connection.0.get_close_notifier() + } +} + +impl ReadNotifier for BackendClient { + fn get_read_notifier(&self) -> &dyn AsRawDescriptor { + self.connection.0.get_read_notifier() + } +} + // TODO(b/221882601): likely need pairs of RDs and/or SharedMemory to represent mmaps on Windows. /// Context object to pass guest memory configuration to BackendClient::set_mem_table(). 
struct VhostUserMemoryContext { diff --git a/third_party/vmm_vhost/src/sys/windows.rs b/third_party/vmm_vhost/src/sys/windows.rs index 01ba3f0b3f..464fb62a1a 100644 --- a/third_party/vmm_vhost/src/sys/windows.rs +++ b/third_party/vmm_vhost/src/sys/windows.rs @@ -205,6 +205,19 @@ impl AsRawDescriptor for TubePlatformConnection { } } +impl CloseNotifier for TubePlatformConnection { + /// Used for closing. + fn get_close_notifier(&self) -> &dyn AsRawDescriptor { + self.tube.get_close_notifier() + } +} + +impl ReadNotifier for TubePlatformConnection { + fn get_read_notifier(&self) -> &dyn AsRawDescriptor { + self.tube.get_read_notifier() + } +} + impl FrontendServer { /// Create a `FrontendServer` that uses a Tube internally. Must specify the backend process /// which will receive the Tube.