diff --git a/devices/src/virtio/vhost/user/device/gpu.rs b/devices/src/virtio/vhost/user/device/gpu.rs index 0860ddbd9e..c2bd1bbc6b 100644 --- a/devices/src/virtio/vhost/user/device/gpu.rs +++ b/devices/src/virtio/vhost/user/device/gpu.rs @@ -35,6 +35,8 @@ use vmm_vhost::message::VhostUserVirtioFeatures; use crate::virtio; use crate::virtio::gpu; use crate::virtio::vhost::user::device::handler::sys::Doorbell; +use crate::virtio::vhost::user::device::handler::VhostBackendReqConnection; +use crate::virtio::vhost::user::device::handler::VhostBackendReqConnectionState; use crate::virtio::vhost::user::device::handler::VhostUserBackend; use crate::virtio::vhost::user::device::listener::sys::VhostUserListener; use crate::virtio::vhost::user::device::listener::VhostUserListenerTrait; @@ -45,7 +47,6 @@ use crate::virtio::GpuDisplayParameters; use crate::virtio::GpuParameters; use crate::virtio::Queue; use crate::virtio::QueueReader; -use crate::virtio::SharedMemoryMapper; use crate::virtio::SharedMemoryRegion; use crate::virtio::VirtioDevice; @@ -142,7 +143,7 @@ struct GpuBackend { fence_state: Arc>, display_worker: Option>, workers: [Option>; MAX_QUEUE_NUM], - mapper: Option>, + backend_req_conn: VhostBackendReqConnectionState, } impl VhostUserBackend for GpuBackend { @@ -229,7 +230,18 @@ impl VhostUserBackend for GpuBackend { } else { let fence_handler = gpu::create_fence_handler(mem.clone(), reader.clone(), self.fence_state.clone()); - let mapper = self.mapper.take().context("missing mapper")?; + + let mapper = { + match &mut self.backend_req_conn { + VhostBackendReqConnectionState::Connected(request) => { + request.take_shmem_mapper()? + } + VhostBackendReqConnectionState::NoConnection => { + bail!("No backend request connection found") + } + } + }; + let state = Rc::new(RefCell::new( self.gpu .borrow_mut() @@ -297,8 +309,12 @@ impl VhostUserBackend for GpuBackend { self.gpu.borrow().get_shared_memory_region() } - fn set_shared_memory_mapper(&mut self, mapper: Box) { - self.mapper = Some(mapper); + fn set_backend_req_connection(&mut self, conn: VhostBackendReqConnection) { + if let VhostBackendReqConnectionState::Connected(_) = &self.backend_req_conn { + warn!("connection already established. overwriting"); + } + + self.backend_req_conn = VhostBackendReqConnectionState::Connected(conn); } } @@ -442,7 +458,7 @@ pub fn run_gpu_device(opts: Options) -> anyhow::Result<()> { fence_state: Default::default(), display_worker: None, workers: Default::default(), - mapper: None, + backend_req_conn: VhostBackendReqConnectionState::NoConnection, }); ex.run_until(listener.run_backend(backend, &ex))? } diff --git a/devices/src/virtio/vhost/user/device/handler.rs b/devices/src/virtio/vhost/user/device/handler.rs index 310a47ce36..79a753b108 100644 --- a/devices/src/virtio/vhost/user/device/handler.rs +++ b/devices/src/virtio/vhost/user/device/handler.rs @@ -210,12 +210,14 @@ pub trait VhostUserBackend { None } - /// Accepts the trait object used to map files into the device's shared - /// memory region. + /// Accepts `VhostBackendReqConnection` to conduct Vhost backend to frontend message + /// handling. /// - /// If `get_shared_memory_region` returns `Some`, then this will be called - /// before `start_queue`. - fn set_shared_memory_mapper(&mut self, _mapper: Box) {} + /// This method will be called when `VhostUserProtocolFeatures::SLAVE_REQ` is + /// negotiated. + fn set_backend_req_connection(&mut self, _conn: VhostBackendReqConnection) { + error!("set_backend_req_connection is not implemented"); + } } /// A virtio ring entry. @@ -645,14 +647,8 @@ impl VhostUserSlaveReqHandlerMut for DeviceRequestHandl } fn set_slave_req_fd(&mut self, ep: Box>) { - let shmid = self.shmid.expect("unexpected slave_req_fd"); - let frontend = Slave::new(ep); - self.backend - .set_shared_memory_mapper(Box::new(VhostShmemMapper { - frontend, - shmid, - mapped_regions: BTreeMap::new(), - })); + let conn = VhostBackendReqConnection::new(Slave::new(ep), self.shmid); + self.backend.set_backend_req_connection(conn); } fn get_inflight_fd( @@ -695,12 +691,62 @@ impl VhostUserSlaveReqHandlerMut for DeviceRequestHandl } } -struct VhostShmemMapper { - frontend: Slave, +/// Indicates the state of backend request connection +pub enum VhostBackendReqConnectionState { + /// A backend request connection (`VhostBackendReqConnection`) is established + Connected(VhostBackendReqConnection), + /// No backend request connection has been established yet + NoConnection, +} + +/// Keeps track of Vhost user backend request connection. +pub struct VhostBackendReqConnection { + conn: Slave, + shmem_info: Option, +} + +#[derive(Clone)] +struct ShmemInfo { shmid: u8, mapped_regions: BTreeMap, } +impl VhostBackendReqConnection { + pub fn new(conn: Slave, shmid: Option) -> Self { + let shmem_info = shmid.map(|shmid| ShmemInfo { + shmid, + mapped_regions: BTreeMap::new(), + }); + Self { conn, shmem_info } + } + + /// Send `VHOST_USER_CONFIG_CHANGE_MSG` to the frontend + pub fn send_config_changed(&self) -> anyhow::Result<()> { + self.conn + .handle_config_change() + .context("Could not send config change message")?; + Ok(()) + } + + /// Create a SharedMemoryMapper trait object from the ShmemInfo. + pub fn take_shmem_mapper(&mut self) -> anyhow::Result> { + let shmem_info = self + .shmem_info + .take() + .context("could not take shared memory mapper information")?; + + Ok(Box::new(VhostShmemMapper { + conn: self.conn.clone(), + shmem_info, + })) + } +} + +struct VhostShmemMapper { + conn: Slave, + shmem_info: ShmemInfo, +} + impl SharedMemoryMapper for VhostShmemMapper { fn add_mapping( &mut self, @@ -724,21 +770,22 @@ impl SharedMemoryMapper for VhostShmemMapper { _ => bail!("unsupported source"), }; let flags = VhostUserShmemMapMsgFlags::from(prot); - let msg = VhostUserShmemMapMsg::new(self.shmid, offset, fd_offset, size, flags); - self.frontend + let msg = VhostUserShmemMapMsg::new(self.shmem_info.shmid, offset, fd_offset, size, flags); + self.conn .shmem_map(&msg, &descriptor) .context("failed to map memory")?; - self.mapped_regions.insert(offset, size); + self.shmem_info.mapped_regions.insert(offset, size); Ok(()) } fn remove_mapping(&mut self, offset: u64) -> anyhow::Result<()> { let size = self + .shmem_info .mapped_regions .remove(&offset) .context("unknown offset")?; - let msg = VhostUserShmemUnmapMsg::new(self.shmid, offset, size); - self.frontend + let msg = VhostUserShmemUnmapMsg::new(self.shmem_info.shmid, offset, size); + self.conn .shmem_unmap(&msg) .context("failed to map memory") .map(|_| ()) diff --git a/devices/src/virtio/vhost/user/device/wl.rs b/devices/src/virtio/vhost/user/device/wl.rs index e88c577649..b0386786c0 100644 --- a/devices/src/virtio/vhost/user/device/wl.rs +++ b/devices/src/virtio/vhost/user/device/wl.rs @@ -37,12 +37,13 @@ use vmm_vhost::message::VhostUserVirtioFeatures; use crate::virtio::base_features; use crate::virtio::vhost::user::device::handler::sys::Doorbell; +use crate::virtio::vhost::user::device::handler::VhostBackendReqConnection; +use crate::virtio::vhost::user::device::handler::VhostBackendReqConnectionState; use crate::virtio::vhost::user::device::handler::VhostUserBackend; use crate::virtio::vhost::user::device::listener::sys::VhostUserListener; use crate::virtio::vhost::user::device::listener::VhostUserListenerTrait; use crate::virtio::wl; use crate::virtio::Queue; -use crate::virtio::SharedMemoryMapper; use crate::virtio::SharedMemoryRegion; const MAX_QUEUE_NUM: usize = wl::QUEUE_SIZES.len(); @@ -96,7 +97,6 @@ async fn run_in_queue( struct WlBackend { ex: Executor, wayland_paths: Option>, - mapper: Option>, resource_bridge: Option, use_transition_flags: bool, use_send_vfd_v2: bool, @@ -105,6 +105,7 @@ struct WlBackend { acked_features: u64, wlstate: Option>>, workers: [Option; MAX_QUEUE_NUM], + backend_req_conn: VhostBackendReqConnectionState, } impl WlBackend { @@ -121,7 +122,6 @@ impl WlBackend { WlBackend { ex: ex.clone(), wayland_paths: Some(wayland_paths), - mapper: None, resource_bridge, use_transition_flags: false, use_send_vfd_v2: false, @@ -130,6 +130,7 @@ impl WlBackend { acked_features: 0, wlstate: None, workers: Default::default(), + backend_req_conn: VhostBackendReqConnectionState::NoConnection, } } } @@ -217,12 +218,22 @@ impl VhostUserBackend for WlBackend { // think we're borrowing all of `self` in the closure below. let WlBackend { ref mut wayland_paths, - ref mut mapper, ref mut resource_bridge, ref use_transition_flags, ref use_send_vfd_v2, .. } = self; + + let mapper = { + match &mut self.backend_req_conn { + VhostBackendReqConnectionState::Connected(request) => { + request.take_shmem_mapper()? + } + VhostBackendReqConnectionState::NoConnection => { + bail!("No backend request connection found") + } + } + }; #[cfg(feature = "minigbm")] let gralloc = RutabagaGralloc::new().context("Failed to initailize gralloc")?; let wlstate = self @@ -230,7 +241,7 @@ impl VhostUserBackend for WlBackend { .get_or_insert_with(|| { Rc::new(RefCell::new(wl::WlState::new( wayland_paths.take().expect("WlState already initialized"), - mapper.take().expect("WlState already initialized"), + mapper, *use_transition_flags, *use_send_vfd_v2, resource_bridge.take(), @@ -294,8 +305,12 @@ impl VhostUserBackend for WlBackend { }) } - fn set_shared_memory_mapper(&mut self, mapper: Box) { - self.mapper = Some(mapper); + fn set_backend_req_connection(&mut self, conn: VhostBackendReqConnection) { + if let VhostBackendReqConnectionState::Connected(_) = &self.backend_req_conn { + warn!("connection already established. Overwriting"); + } + + self.backend_req_conn = VhostBackendReqConnectionState::Connected(conn); } } diff --git a/devices/src/virtio/vhost/user/vmm/handler.rs b/devices/src/virtio/vhost/user/vmm/handler.rs index 33669ec066..fa8fc38ec3 100644 --- a/devices/src/virtio/vhost/user/vmm/handler.rs +++ b/devices/src/virtio/vhost/user/vmm/handler.rs @@ -28,6 +28,7 @@ use vmm_vhost::VhostUserMasterReqHandlerMut; use vmm_vhost::VhostUserMemoryRegionInfo; use vmm_vhost::VringConfigData; +use crate::virtio::vhost::user::vmm::handler::sys::create_backend_req_handler; use crate::virtio::vhost::user::vmm::handler::sys::SocketMaster; use crate::virtio::vhost::user::vmm::Error; use crate::virtio::vhost::user::vmm::Result; @@ -81,12 +82,30 @@ impl VhostUserHandler { .map_err(Error::SetProtocolFeatures)?; } + // Create backend request handler and send slave request fd to backend + // if protocol feature `VhostUserProtocolFeatures::SLAVE_REQ` is negotiated. + let backend_req_handler = + if protocol_features.contains(VhostUserProtocolFeatures::SLAVE_REQ) { + let mut handler = create_backend_req_handler( + BackendReqHandlerImpl { + shared_mapper_state: None, + }, + #[cfg(windows)] + backend_pid, + )?; + vu.set_slave_request_fd(&handler.take_tx_descriptor()) + .map_err(Error::SetDeviceRequestChannel)?; + Some(handler) + } else { + None + }; + Ok(VhostUserHandler { vu, avail_features, acked_features, protocol_features, - backend_req_handler: None, + backend_req_handler, shmem_region: None, #[cfg(windows)] backend_pid, @@ -307,6 +326,15 @@ impl VhostUserHandler { } pub fn set_shared_memory_mapper(&mut self, mapper: Box) -> Result<()> { + // Return error if backend request handler is not available. This indicates + // that `VhostUserProtocolFeatures::SLAVE_REQ` is not negotiated. + let backend_req_handler = + self.backend_req_handler + .as_mut() + .ok_or(Error::ProtocolFeatureNotNegoiated( + VhostUserProtocolFeatures::SLAVE_REQ, + ))?; + // The virtio framework will only call this if get_shared_memory_region returned a region let shmid = self .shmem_region @@ -314,26 +342,49 @@ impl VhostUserHandler { .flatten() .expect("missing shmid") .id; - self.initialize_backend_req_handler(BackendReqHandlerImpl { mapper, shmid }) + + backend_req_handler + .backend() + .lock() + .unwrap() + .set_shared_mapper_state(SharedMapperState { mapper, shmid }); + Ok(()) } } -pub struct BackendReqHandlerImpl { +struct SharedMapperState { mapper: Box, shmid: u8, } +pub struct BackendReqHandlerImpl { + shared_mapper_state: Option, +} + +impl BackendReqHandlerImpl { + fn set_shared_mapper_state(&mut self, shared_mapper_state: SharedMapperState) { + self.shared_mapper_state = Some(shared_mapper_state); + } +} + impl VhostUserMasterReqHandlerMut for BackendReqHandlerImpl { fn shmem_map( &mut self, req: &VhostUserShmemMapMsg, fd: &dyn AsRawDescriptor, ) -> HandlerResult { - if req.shmid != self.shmid { - error!("bad shmid {}, expected {}", req.shmid, self.shmid); + let shared_mapper_state = self + .shared_mapper_state + .as_mut() + .ok_or_else(|| std::io::Error::from_raw_os_error(libc::EINVAL))?; + if req.shmid != shared_mapper_state.shmid { + error!( + "bad shmid {}, expected {}", + req.shmid, shared_mapper_state.shmid + ); return Err(std::io::Error::from_raw_os_error(libc::EINVAL)); } - match self.mapper.add_mapping( + match shared_mapper_state.mapper.add_mapping( VmMemorySource::Descriptor { descriptor: SafeDescriptor::try_from(fd) .map_err(|_| std::io::Error::from_raw_os_error(libc::EIO))?, @@ -353,11 +404,18 @@ impl VhostUserMasterReqHandlerMut for BackendReqHandlerImpl { } fn shmem_unmap(&mut self, req: &VhostUserShmemUnmapMsg) -> HandlerResult { - if req.shmid != self.shmid { - error!("bad shmid {}, expected {}", req.shmid, self.shmid); + let shared_mapper_state = self + .shared_mapper_state + .as_mut() + .ok_or_else(|| std::io::Error::from_raw_os_error(libc::EINVAL))?; + if req.shmid != shared_mapper_state.shmid { + error!( + "bad shmid {}, expected {}", + req.shmid, shared_mapper_state.shmid + ); return Err(std::io::Error::from_raw_os_error(libc::EINVAL)); } - match self.mapper.remove_mapping(req.shm_offset) { + match shared_mapper_state.mapper.remove_mapping(req.shm_offset) { Ok(()) => Ok(0), Err(e) => { error!("failed to remove mapping {:?}", e); diff --git a/devices/src/virtio/vhost/user/vmm/handler/sys/unix.rs b/devices/src/virtio/vhost/user/vmm/handler/sys/unix.rs index 1dce1a6f13..9ed60c4698 100644 --- a/devices/src/virtio/vhost/user/vmm/handler/sys/unix.rs +++ b/devices/src/virtio/vhost/user/vmm/handler/sys/unix.rs @@ -19,7 +19,6 @@ use vmm_vhost::message::VhostUserProtocolFeatures; use vmm_vhost::Error as VhostError; use vmm_vhost::Master; use vmm_vhost::MasterReqHandler; -use vmm_vhost::VhostUserMaster; use crate::virtio::vhost::user::vmm::handler::BackendReqHandler; use crate::virtio::vhost::user::vmm::handler::BackendReqHandlerImpl; @@ -48,16 +47,12 @@ impl VhostUserHandler { allow_protocol_features, ) } +} - pub fn initialize_backend_req_handler(&mut self, h: BackendReqHandlerImpl) -> VhostResult<()> { - let mut handler = MasterReqHandler::with_stream(Arc::new(Mutex::new(h))) - .map_err(Error::CreateShmemMapperError)?; - self.vu - .set_slave_request_fd(&handler.take_tx_descriptor()) - .map_err(Error::SetDeviceRequestChannel)?; - self.backend_req_handler = Some(handler); - Ok(()) - } +pub fn create_backend_req_handler(h: BackendReqHandlerImpl) -> VhostResult { + let handler = MasterReqHandler::with_stream(Arc::new(Mutex::new(h))) + .map_err(Error::CreateBackendReqHandler)?; + Ok(handler) } pub async fn run_backend_request_handler( diff --git a/devices/src/virtio/vhost/user/vmm/handler/sys/windows.rs b/devices/src/virtio/vhost/user/vmm/handler/sys/windows.rs index fdb293c658..a511071bf8 100644 --- a/devices/src/virtio/vhost/user/vmm/handler/sys/windows.rs +++ b/devices/src/virtio/vhost/user/vmm/handler/sys/windows.rs @@ -54,19 +54,16 @@ impl VhostUserHandler { backend_pid, ) } +} - pub fn initialize_backend_req_handler(&mut self, h: BackendReqHandlerImpl) -> VhostResult<()> { - let backend_pid = self - .backend_pid - .expect("tube needs target pid for backend requests"); - let mut handler = MasterReqHandler::with_tube(Arc::new(Mutex::new(h)), backend_pid) - .map_err(Error::CreateShmemMapperError)?; - self.vu - .set_slave_request_fd(&handler.take_tx_descriptor()) - .map_err(Error::SetDeviceRequestChannel)?; - self.backend_req_handler = Some(handler); - Ok(()) - } +pub fn create_backend_req_handler( + h: BackendReqHandlerImpl, + backend_pid: Option, +) -> VhostResult { + let backend_pid = backend_pid.expect("tube needs target pid for backend requests"); + let mut handler = MasterReqHandler::with_tube(Arc::new(Mutex::new(h)), backend_pid) + .map_err(Error::CreateBackendReqHandler)?; + Ok(handler) } pub async fn run_backend_request_handler( diff --git a/devices/src/virtio/vhost/user/vmm/mod.rs b/devices/src/virtio/vhost/user/vmm/mod.rs index d381854c85..f65e503e09 100644 --- a/devices/src/virtio/vhost/user/vmm/mod.rs +++ b/devices/src/virtio/vhost/user/vmm/mod.rs @@ -8,6 +8,7 @@ mod handler; use remain::sorted; use thiserror::Error as ThisError; use vm_memory::GuestMemoryError; +use vmm_vhost::message::VhostUserProtocolFeatures; use vmm_vhost::Error as VhostError; pub use self::block::*; @@ -50,12 +51,15 @@ cfg_if::cfg_if! { #[sorted] #[derive(ThisError, Debug)] pub enum Error { + /// Failed to copy config to a buffer. + #[error("failed to copy config to a buffer: {0}")] + CopyConfig(std::io::Error), + /// Failed to create backend request handler + #[error("could not create backend req handler: {0}")] + CreateBackendReqHandler(VhostError), /// Failed to create `base::Event`. #[error("failed to create Event: {0}")] CreateEvent(base::Error), - /// Unsupported shared memory mapper - #[error("unsupported shared memory mapper: {0}")] - CreateShmemMapperError(VhostError), /// Failed to get config. #[error("failed to get config: {0}")] GetConfig(VhostError), @@ -86,6 +90,8 @@ pub enum Error { /// MSI-X irqfd is unavailable. #[error("MSI-X irqfd is unavailable")] MsixIrqfdUnavailable, + #[error("protocol feature is not negotiated: {0:?}")] + ProtocolFeatureNotNegoiated(VhostUserProtocolFeatures), /// Failed to reset owner. #[error("failed to reset owner: {0}")] ResetOwner(VhostError), diff --git a/third_party/vmm_vhost/src/master_req_handler.rs b/third_party/vmm_vhost/src/master_req_handler.rs index c63fd3da74..ca86ed9575 100644 --- a/third_party/vmm_vhost/src/master_req_handler.rs +++ b/third_party/vmm_vhost/src/master_req_handler.rs @@ -206,8 +206,10 @@ pub struct MasterReqHandler { serialize_tx: Box SafeDescriptor + Send>, // Protocol feature VHOST_USER_PROTOCOL_F_REPLY_ACK has been negotiated. reply_ack_negotiated: bool, - // the VirtIO backend device object + + /// the VirtIO backend device object backend: Arc, + // whether the endpoint has encountered any failure error: Option, } @@ -265,6 +267,11 @@ impl MasterReqHandler { } } + /// Get the underlying backend device + pub fn backend(&self) -> Arc { + Arc::clone(&self.backend) + } + /// Main entrance to server slave request from the slave communication channel. /// /// The caller needs to: