diff --git a/Cargo.lock b/Cargo.lock index 241d34356d..b0cf20e8ff 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1908,6 +1908,7 @@ dependencies = [ "serde_json", "tempfile", "thiserror", + "tube_transporter", ] [[package]] diff --git a/cros_async/src/event.rs b/cros_async/src/event.rs index 1e091babe7..0c97d1167a 100644 --- a/cros_async/src/event.rs +++ b/cros_async/src/event.rs @@ -2,8 +2,15 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -use base::Event; +use std::mem::ManuallyDrop; +use base::AsRawDescriptor; +use base::Event; +use base::FromRawDescriptor; + +use crate::AsyncError; +use crate::AsyncResult; +use crate::Executor; use crate::IntoAsync; use crate::IoSourceExt; @@ -18,6 +25,25 @@ impl EventAsync { pub fn get_io_source_ref(&self) -> &dyn IoSourceExt { self.io_source.as_ref() } + + /// Given a non-owning raw descriptor to an Event, will make a clone to construct this async + /// Event. Use for cases where you have a valid raw event descriptor, but don't own it. + pub fn clone_raw(descriptor: &dyn AsRawDescriptor, ex: &Executor) -> AsyncResult { + // Safe because: + // a) the underlying Event should be validated by the caller. + // b) we do NOT take ownership of the underlying Event. If we did that would cause an early + // free (and later a double free @ the end of this scope). This is why we have to wrap + // it in ManuallyDrop. + // c) we own the clone that is produced exclusively, so it is safe to take ownership of it. + Self::new( + unsafe { + ManuallyDrop::new(Event::from_raw_descriptor(descriptor.as_raw_descriptor())) + } + .try_clone() + .map_err(AsyncError::EventAsync)?, + ex, + ) + } } impl IntoAsync for Event {} diff --git a/cros_async/src/io_ext.rs b/cros_async/src/io_ext.rs index 6ed6381ecb..d989b2b8c5 100644 --- a/cros_async/src/io_ext.rs +++ b/cros_async/src/io_ext.rs @@ -36,6 +36,9 @@ use super::MemRegion; #[sorted] #[derive(ThisError, Debug)] pub enum Error { + /// An error with EventAsync. + #[error("An error with an EventAsync: {0}")] + EventAsync(base::Error), /// An error with a polled(FD) source. #[error("An error with a poll source: {0}")] Poll(crate::sys::unix::poll_source::Error), @@ -77,6 +80,7 @@ impl From for io::Error { fn from(e: Error) -> Self { use Error::*; match e { + EventAsync(e) => e.into(), Poll(e) => e.into(), Uring(e) => e.into(), } diff --git a/devices/src/virtio/vhost/user/device/handler/sys/windows.rs b/devices/src/virtio/vhost/user/device/handler/sys/windows.rs index d62d8d76f5..9eee630ec1 100644 --- a/devices/src/virtio/vhost/user/device/handler/sys/windows.rs +++ b/devices/src/virtio/vhost/user/device/handler/sys/windows.rs @@ -2,8 +2,6 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -use std::mem::ManuallyDrop; - use anyhow::Context; use anyhow::Result; use base::info; @@ -12,7 +10,6 @@ use base::named_pipes::FramingMode; use base::named_pipes::PipeConnection; use base::CloseNotifier; use base::Event; -use base::FromRawDescriptor; use base::RawDescriptor; use base::ReadNotifier; use base::Tube; @@ -52,36 +49,10 @@ impl DeviceRequestHandler { let read_notifier = vhost_user_tube.get_read_notifier(); let close_notifier = vhost_user_tube.get_close_notifier(); - // Safe because: - // a) the underlying Event is guaranteed valid by the Tube. - // b) we do NOT take ownership of the underlying Event. If we did that would cause an early - // free (and later a double free @ the end of this scope). This is why we have to wrap - // it in ManuallyDrop. - // c) we own the clone that is produced exclusively, so it is safe to take ownership of it. - let read_event = EventAsync::new( - // Safe, see block comment. - unsafe { - ManuallyDrop::new(Event::from_raw_descriptor( - read_notifier.as_raw_descriptor(), - )) - } - .try_clone() - .context("failed to clone event")?, - ex, - ) - .context("failed to create an async event")?; - let close_event = EventAsync::new( - // Safe, see block comment. - unsafe { - ManuallyDrop::new(Event::from_raw_descriptor( - close_notifier.as_raw_descriptor(), - )) - } - .try_clone() - .context("failed to clone event")?, - ex, - ) - .context("failed to create an async event")?; + let read_event = + EventAsync::clone_raw(read_notifier, ex).context("failed to create an async event")?; + let close_event = + EventAsync::clone_raw(close_notifier, ex).context("failed to create an async event")?; let exit_event = EventAsync::new(exit_event, ex).context("failed to create an async event")?; diff --git a/devices/src/virtio/vhost/user/vmm/handler.rs b/devices/src/virtio/vhost/user/vmm/handler.rs index a77426fceb..61273acbdc 100644 --- a/devices/src/virtio/vhost/user/vmm/handler.rs +++ b/devices/src/virtio/vhost/user/vmm/handler.rs @@ -6,6 +6,7 @@ mod sys; mod worker; use std::io::Write; +use std::sync::Mutex; use std::thread; use base::error; @@ -21,13 +22,13 @@ use vmm_vhost::message::VhostUserShmemMapMsg; use vmm_vhost::message::VhostUserShmemUnmapMsg; use vmm_vhost::message::VhostUserVirtioFeatures; use vmm_vhost::HandlerResult; +use vmm_vhost::MasterReqHandler; use vmm_vhost::VhostBackend; use vmm_vhost::VhostUserMaster; use vmm_vhost::VhostUserMasterReqHandlerMut; use vmm_vhost::VhostUserMemoryRegionInfo; use vmm_vhost::VringConfigData; -use crate::virtio::vhost::user::vmm::handler::sys::BackendReqHandler; use crate::virtio::vhost::user::vmm::handler::sys::SocketMaster; use crate::virtio::vhost::user::vmm::Error; use crate::virtio::vhost::user::vmm::Result; @@ -36,6 +37,8 @@ use crate::virtio::Queue; use crate::virtio::SharedMemoryMapper; use crate::virtio::SharedMemoryRegion; +type BackendReqHandler = MasterReqHandler>; + fn set_features(vu: &mut SocketMaster, avail_features: u64, ack_features: u64) -> Result { let features = avail_features & ack_features; vu.set_features(features).map_err(Error::SetFeatures)?; @@ -50,6 +53,9 @@ pub struct VhostUserHandler { backend_req_handler: Option, // Shared memory region info. IPC result from backend is saved with outer Option. shmem_region: Option>, + // On Windows, we need a backend pid to support backend requests. + #[cfg(windows)] + backend_pid: Option, } impl VhostUserHandler { @@ -59,6 +65,7 @@ impl VhostUserHandler { allow_features: u64, init_features: u64, allow_protocol_features: VhostUserProtocolFeatures, + #[cfg(windows)] backend_pid: Option, ) -> Result { vu.set_owner().map_err(Error::SetOwner)?; @@ -82,6 +89,8 @@ impl VhostUserHandler { protocol_features, backend_req_handler: None, shmem_region: None, + #[cfg(windows)] + backend_pid, }) } diff --git a/devices/src/virtio/vhost/user/vmm/handler/sys.rs b/devices/src/virtio/vhost/user/vmm/handler/sys.rs index e74fb8def5..ac7c2bc213 100644 --- a/devices/src/virtio/vhost/user/vmm/handler/sys.rs +++ b/devices/src/virtio/vhost/user/vmm/handler/sys.rs @@ -15,4 +15,3 @@ cfg_if::cfg_if! { } pub(super) use platform::run_backend_request_handler; -pub(super) use platform::BackendReqHandler; diff --git a/devices/src/virtio/vhost/user/vmm/handler/sys/unix.rs b/devices/src/virtio/vhost/user/vmm/handler/sys/unix.rs index 637ea807fd..f69c9a8533 100644 --- a/devices/src/virtio/vhost/user/vmm/handler/sys/unix.rs +++ b/devices/src/virtio/vhost/user/vmm/handler/sys/unix.rs @@ -12,7 +12,6 @@ use anyhow::Context; use anyhow::Result; use base::info; use base::AsRawDescriptor; -use base::Descriptor; use base::SafeDescriptor; use cros_async::AsyncWrapper; use cros_async::Executor; @@ -24,6 +23,7 @@ use vmm_vhost::Master; use vmm_vhost::MasterReqHandler; use vmm_vhost::VhostUserMaster; +use crate::virtio::vhost::user::vmm::handler::BackendReqHandler; use crate::virtio::vhost::user::vmm::handler::BackendReqHandlerImpl; use crate::virtio::vhost::user::vmm::handler::VhostUserHandler; use crate::virtio::vhost::user::vmm::Error; @@ -32,9 +32,6 @@ use crate::virtio::vhost::user::vmm::Result as VhostResult; pub(in crate::virtio::vhost::user::vmm::handler) type SocketMaster = Master>; -pub(in crate::virtio::vhost::user::vmm::handler) type BackendReqHandler = - MasterReqHandler>; - impl VhostUserHandler { /// Creates a `VhostUserHandler` instance attached to the provided UDS path /// with features and protocol features initialized. @@ -72,10 +69,10 @@ impl VhostUserHandler { } pub fn initialize_backend_req_handler(&mut self, h: BackendReqHandlerImpl) -> VhostResult<()> { - let handler = MasterReqHandler::new(Arc::new(Mutex::new(h))) + let mut handler = MasterReqHandler::with_stream(Arc::new(Mutex::new(h))) .map_err(Error::CreateShmemMapperError)?; self.vu - .set_slave_request_fd(&Descriptor(handler.get_tx_raw_fd())) + .set_slave_request_fd(&handler.take_tx_descriptor()) .map_err(Error::SetDeviceRequestChannel)?; self.backend_req_handler = Some(handler); Ok(()) diff --git a/devices/src/virtio/vhost/user/vmm/handler/sys/windows.rs b/devices/src/virtio/vhost/user/vmm/handler/sys/windows.rs index 73865c0160..3a83543022 100644 --- a/devices/src/virtio/vhost/user/vmm/handler/sys/windows.rs +++ b/devices/src/virtio/vhost/user/vmm/handler/sys/windows.rs @@ -2,21 +2,32 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. +use anyhow::Context; +use futures::pin_mut; +use futures::select; +use futures::FutureExt; +use std::sync::Arc; use std::sync::Mutex; +use anyhow::Result; +use base::info; +use base::CloseNotifier; +use base::ReadNotifier; use base::Tube; +use cros_async::EventAsync; use cros_async::Executor; use vmm_vhost::connection::TubeEndpoint; use vmm_vhost::message::MasterReq; use vmm_vhost::message::VhostUserProtocolFeatures; -use vmm_vhost::Error as VhostError; use vmm_vhost::Master; -use vmm_vhost::VhostUserMasterReqHandler; +use vmm_vhost::MasterReqHandler; +use vmm_vhost::VhostUserMaster; +use crate::virtio::vhost::user::vmm::handler::BackendReqHandler; use crate::virtio::vhost::user::vmm::handler::BackendReqHandlerImpl; use crate::virtio::vhost::user::vmm::handler::VhostUserHandler; use crate::virtio::vhost::user::vmm::Error; -use crate::virtio::vhost::user::vmm::Result; +use crate::virtio::vhost::user::vmm::Result as VhostResult; // TODO(rizhang): upstream CL so SocketMaster is renamed to EndpointMaster to make it more cross // platform. @@ -32,33 +43,66 @@ impl VhostUserHandler { allow_features: u64, init_features: u64, allow_protocol_features: VhostUserProtocolFeatures, - ) -> Result { + ) -> VhostResult { + let backend_pid = tube.target_pid(); Self::new( SocketMaster::from_stream(tube, max_queue_num), allow_features, init_features, allow_protocol_features, + backend_pid, ) } - pub fn initialize_backend_req_handler(&mut self, h: BackendReqHandlerImpl) -> Result<()> { - Err(Error::CreateShmemMapperError( - VhostError::MasterInternalError, - )) + pub fn initialize_backend_req_handler(&mut self, h: BackendReqHandlerImpl) -> VhostResult<()> { + let backend_pid = self + .backend_pid + .expect("tube needs target pid for backend requests"); + let mut handler = MasterReqHandler::with_tube(Arc::new(Mutex::new(h)), backend_pid) + .map_err(Error::CreateShmemMapperError)?; + self.vu + .set_slave_request_fd(&handler.take_tx_descriptor()) + .map_err(Error::SetDeviceRequestChannel)?; + self.backend_req_handler = Some(handler); + Ok(()) } } -pub struct BackendReqHandler {} - -impl VhostUserMasterReqHandler for BackendReqHandler {} - pub async fn run_backend_request_handler( handler: Option, - _ex: &Executor, + ex: &Executor, ) -> Result<()> { - match handler { - // We never initialize a BackendReqHandler in |initialize_backend_req_handler|. - Some(_) => unimplemented!("unexpected BackendReqHandler"), + let mut handler = match handler { + Some(h) => h, None => std::future::pending().await, + }; + + let read_notifier = handler.get_read_notifier(); + let close_notifier = handler.get_close_notifier(); + + let read_event = + EventAsync::clone_raw(read_notifier, ex).context("failed to create an async event")?; + let close_event = + EventAsync::clone_raw(close_notifier, ex).context("failed to create an async event")?; + + let read_event_fut = read_event.next_val().fuse(); + let close_event_fut = close_event.next_val().fuse(); + pin_mut!(read_event_fut); + pin_mut!(close_event_fut); + + loop { + select! { + _read_res = read_event_fut => { + handler + .handle_request() + .context("failed to handle a vhost-user request")?; + read_event_fut.set(read_event.next_val().fuse()); + } + // Tube closed event. + _close_res = close_event_fut => { + info!("exit run loop: got close event"); + return Ok(()) + } + } } } diff --git a/devices/src/virtio/vhost/user/vmm/handler/worker.rs b/devices/src/virtio/vhost/user/vmm/handler/worker.rs index c78c7567f4..463e89fe27 100644 --- a/devices/src/virtio/vhost/user/vmm/handler/worker.rs +++ b/devices/src/virtio/vhost/user/vmm/handler/worker.rs @@ -11,7 +11,7 @@ use vm_memory::GuestMemory; use crate::virtio::async_utils; use crate::virtio::vhost::user::vmm::handler::sys::run_backend_request_handler; -use crate::virtio::vhost::user::vmm::handler::sys::BackendReqHandler; +use crate::virtio::vhost::user::vmm::handler::BackendReqHandler; use crate::virtio::Interrupt; use crate::virtio::Queue; diff --git a/third_party/vmm_vhost/Cargo.toml b/third_party/vmm_vhost/Cargo.toml index ff69f92f79..b9f8dfbf48 100644 --- a/third_party/vmm_vhost/Cargo.toml +++ b/third_party/vmm_vhost/Cargo.toml @@ -30,3 +30,4 @@ thiserror = { version = "1.0.20" } [target.'cfg(windows)'.dependencies] serde = { version = "1", features = [ "derive" ] } serde_json = "*" +tube_transporter = { path = "../../tube_transporter" } diff --git a/third_party/vmm_vhost/src/connection/tube.rs b/third_party/vmm_vhost/src/connection/tube.rs index 03d296a4d2..8c26ee3edb 100644 --- a/third_party/vmm_vhost/src/connection/tube.rs +++ b/third_party/vmm_vhost/src/connection/tube.rs @@ -18,10 +18,12 @@ use base::RawDescriptor; use base::Tube; use serde::Deserialize; use serde::Serialize; +use tube_transporter::packed_tube; use crate::connection::Endpoint; use crate::connection::Req; use crate::message::SlaveReq; +use crate::take_single_file; use crate::Error; use crate::Result; @@ -43,6 +45,12 @@ pub struct TubeEndpoint { _r: PhantomData, } +impl TubeEndpoint { + pub(crate) fn get_tube(&self) -> &Tube { + &self.tube + } +} + impl From for TubeEndpoint { fn from(tube: Tube) -> Self { Self { @@ -150,9 +158,12 @@ impl Endpoint for TubeEndpoint { fn create_slave_request_endpoint( &mut self, - _files: Option>, + files: Option>, ) -> Result>> { - unimplemented!("SET_SLAVE_REQ_FD not supported"); + let file = take_single_file(files).ok_or(Error::InvalidMessage)?; + // Safe because the file represents a packed tube. + let tube = unsafe { packed_tube::unpack(file.into()).expect("unpacked Tube") }; + Ok(Box::new(TubeEndpoint::from(tube))) } } diff --git a/third_party/vmm_vhost/src/lib.rs b/third_party/vmm_vhost/src/lib.rs index 34cda433e9..a7d3df21d2 100644 --- a/third_party/vmm_vhost/src/lib.rs +++ b/third_party/vmm_vhost/src/lib.rs @@ -77,7 +77,7 @@ cfg_if::cfg_if! { } } cfg_if::cfg_if! { - if #[cfg(all(feature = "vmm", unix))] { + if #[cfg(feature = "vmm")] { pub use self::master_req_handler::MasterReqHandler; } } @@ -151,7 +151,19 @@ pub enum Error { VfioDeviceError(anyhow::Error), } -impl std::convert::From for Error { +impl From for Error { + fn from(err: base::TubeError) -> Self { + Error::TubeError(err) + } +} + +impl From for Error { + fn from(err: std::io::Error) -> Self { + Error::SocketError(err) + } +} + +impl From for Error { /// Convert raw socket errors into meaningful vhost-user errors. /// /// The base::Error is a simple wrapper over the raw errno, which doesn't means @@ -336,12 +348,8 @@ mod tests { slave.handle_request().unwrap(); slave.handle_request().unwrap(); - // set_slave_request_rd isn't implemented on Windows. - #[cfg(unix)] - { - // set_slave_request_fd - slave.handle_request().unwrap(); - } + // set_slave_request_fd + slave.handle_request().unwrap(); // set_vring_enable slave.handle_request().unwrap(); @@ -419,13 +427,17 @@ mod tests { assert_eq!(offset, 0x100); assert_eq!(reply_payload[0], 0xa5); - // slave request rds are not implemented on Windows. + #[cfg(windows)] + let tubes = base::Tube::pair().unwrap(); + #[cfg(windows)] + // Safe because we will be importing the Tube in the other thread. + let descriptor = + unsafe { tube_transporter::packed_tube::pack(tubes.0, std::process::id()).unwrap() }; + #[cfg(unix)] - { - master - .set_slave_request_fd(&event as &dyn AsRawDescriptor) - .unwrap(); - } + let descriptor = base::Event::new().unwrap(); + + master.set_slave_request_fd(&descriptor).unwrap(); master.set_vring_enable(0, true).unwrap(); // unimplemented yet diff --git a/third_party/vmm_vhost/src/master_req_handler.rs b/third_party/vmm_vhost/src/master_req_handler.rs index 94f157f0f1..c63fd3da74 100644 --- a/third_party/vmm_vhost/src/master_req_handler.rs +++ b/third_party/vmm_vhost/src/master_req_handler.rs @@ -1,30 +1,28 @@ // Copyright (C) 2019-2021 Alibaba Cloud. All rights reserved. // SPDX-License-Identifier: Apache-2.0 -#[cfg(unix)] +cfg_if::cfg_if! { + if #[cfg(unix)] { + mod unix; + } else if #[cfg(windows)] { + mod windows; + } +} + use std::fs::File; -#[cfg(unix)] use std::mem; -#[cfg(unix)] -use std::os::unix::io::AsRawFd; -#[cfg(unix)] use std::sync::Arc; use std::sync::Mutex; use base::AsRawDescriptor; -use base::RawDescriptor; +use base::SafeDescriptor; -#[cfg(unix)] -use crate::connection::socket::Endpoint as SocketEndpoint; -#[cfg(unix)] use crate::connection::EndpointExt; use crate::message::*; -#[cfg(unix)] use crate::Error; use crate::HandlerResult; -#[cfg(unix)] use crate::Result; -#[cfg(unix)] +use crate::SlaveReqEndpoint; use crate::SystemStream; /// Define services provided by masters for the slave communication channel. @@ -199,20 +197,13 @@ impl VhostUserMasterReqHandler for Mutex { /// [MasterReqHandler]: struct.MasterReqHandler.html /// [VhostUserMasterReqHandler]: trait.VhostUserMasterReqHandler.html /// -/// TODO(b/221882601): we can write a version of this for Windows by switching the socket for a Tube. -/// The interfaces would need to change so that we fetch a full Tube (which is 2 rds on Windows) -/// and send it to the device backend (slave) as a message on the master -> slave channel. -/// (Currently the interface only supports sending a single rd.) -/// -/// Note that handling requests from slaves is not needed for the initial devices we plan to -/// support. -/// /// Server to handle service requests from slaves from the slave communication channel. -#[cfg(unix)] pub struct MasterReqHandler { // underlying Unix domain socket for communication - sub_sock: SocketEndpoint, - tx_sock: SystemStream, + sub_sock: SlaveReqEndpoint, + tx_sock: Option, + // Serializes tx_sock for passing to the backend. + serialize_tx: Box SafeDescriptor + Send>, // Protocol feature VHOST_USER_PROTOCOL_F_REPLY_ACK has been negotiated. reply_ack_negotiated: bool, // the VirtIO backend device object @@ -221,35 +212,39 @@ pub struct MasterReqHandler { error: Option, } -#[cfg(unix)] impl MasterReqHandler { /// Create a server to handle service requests from slaves on the slave communication channel. /// /// This opens a pair of connected anonymous sockets to form the slave communication channel. - /// The socket fd returned by [Self::get_tx_raw_fd()] should be sent to the slave by + /// The socket fd returned by [Self::take_tx_descriptor()] should be sent to the slave by /// [VhostUserMaster::set_slave_request_fd()]. /// - /// [Self::get_tx_raw_fd()]: struct.MasterReqHandler.html#method.get_tx_raw_fd + /// [Self::take_tx_descriptor()]: struct.MasterReqHandler.html#method.take_tx_descriptor /// [VhostUserMaster::set_slave_request_fd()]: trait.VhostUserMaster.html#tymethod.set_slave_request_fd - pub fn new(backend: Arc) -> Result { - let (tx, rx) = SystemStream::pair().map_err(Error::SocketError)?; + pub fn new( + backend: Arc, + serialize_tx: Box SafeDescriptor + Send>, + ) -> Result { + let (tx, rx) = SystemStream::pair()?; Ok(MasterReqHandler { - sub_sock: SocketEndpoint::::from(rx), - tx_sock: tx, + sub_sock: SlaveReqEndpoint::from(rx), + tx_sock: Some(tx), + serialize_tx, reply_ack_negotiated: false, backend, error: None, }) } - /// Get the socket fd for the slave to communication with the master. + /// Get the descriptor for the slave to communication with the master. /// - /// The returned fd should be sent to the slave by [VhostUserMaster::set_slave_request_fd()]. + /// The caller owns the descriptor. The returned descriptor should be sent to the slave by + /// [VhostUserMaster::set_slave_request_fd()]. /// /// [VhostUserMaster::set_slave_request_fd()]: trait.VhostUserMaster.html#tymethod.set_slave_request_fd - pub fn get_tx_raw_fd(&self) -> RawDescriptor { - self.tx_sock.as_raw_fd() + pub fn take_tx_descriptor(&mut self) -> SafeDescriptor { + (self.serialize_tx)(self.tx_sock.take().expect("tx_sock should have a value")) } /// Set the negotiation state of the `VHOST_USER_PROTOCOL_F_REPLY_ACK` protocol feature. @@ -453,116 +448,3 @@ impl MasterReqHandler { Ok(()) } } - -#[cfg(unix)] -impl AsRawDescriptor for MasterReqHandler { - fn as_raw_descriptor(&self) -> RawDescriptor { - // TODO(b/221882601): figure out whether this is used for polling. If so, we need theTube's - // read notifier here instead. - self.sub_sock.as_raw_descriptor() - } -} - -#[cfg(unix)] -#[cfg(test)] -mod tests { - use base::AsRawDescriptor; - use base::Descriptor; - use base::FromRawDescriptor; - use base::INVALID_DESCRIPTOR; - - use super::*; - #[cfg(feature = "device")] - use crate::Slave; - - struct MockMasterReqHandler {} - - impl VhostUserMasterReqHandlerMut for MockMasterReqHandler { - /// Handle virtio-fs map file requests from the slave. - fn fs_slave_map( - &mut self, - _fs: &VhostUserFSSlaveMsg, - _fd: &dyn AsRawDescriptor, - ) -> HandlerResult { - Ok(0) - } - - /// Handle virtio-fs unmap file requests from the slave. - fn fs_slave_unmap(&mut self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult { - Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) - } - } - - #[test] - fn test_new_master_req_handler() { - let backend = Arc::new(Mutex::new(MockMasterReqHandler {})); - let mut handler = MasterReqHandler::new(backend).unwrap(); - - assert!(handler.get_tx_raw_fd() >= 0); - assert!(handler.as_raw_descriptor() != INVALID_DESCRIPTOR); - handler.check_state().unwrap(); - - assert_eq!(handler.error, None); - handler.set_failed(libc::EAGAIN); - assert_eq!(handler.error, Some(libc::EAGAIN)); - handler.check_state().unwrap_err(); - } - - #[cfg(feature = "device")] - #[test] - fn test_master_slave_req_handler() { - let backend = Arc::new(Mutex::new(MockMasterReqHandler {})); - let mut handler = MasterReqHandler::new(backend).unwrap(); - - let fd = unsafe { libc::dup(handler.get_tx_raw_fd()) }; - if fd < 0 { - panic!("failed to duplicated tx fd!"); - } - let stream = unsafe { SystemStream::from_raw_descriptor(fd) }; - let fs_cache = Slave::from_stream(stream); - - std::thread::spawn(move || { - let res = handler.handle_request().unwrap(); - assert_eq!(res, 0); - handler.handle_request().unwrap_err(); - }); - - fs_cache - .fs_slave_map(&VhostUserFSSlaveMsg::default(), &Descriptor(fd)) - .unwrap(); - // When REPLY_ACK has not been negotiated, the master has no way to detect failure from - // slave side. - fs_cache - .fs_slave_unmap(&VhostUserFSSlaveMsg::default()) - .unwrap(); - } - - #[cfg(feature = "device")] - #[test] - fn test_master_slave_req_handler_with_ack() { - let backend = Arc::new(Mutex::new(MockMasterReqHandler {})); - let mut handler = MasterReqHandler::new(backend).unwrap(); - handler.set_reply_ack_flag(true); - - let fd = unsafe { libc::dup(handler.get_tx_raw_fd()) }; - if fd < 0 { - panic!("failed to duplicated tx fd!"); - } - let stream = unsafe { SystemStream::from_raw_descriptor(fd) }; - let fs_cache = Slave::from_stream(stream); - - std::thread::spawn(move || { - let res = handler.handle_request().unwrap(); - assert_eq!(res, 0); - handler.handle_request().unwrap_err(); - }); - - fs_cache.set_reply_ack_flag(true); - fs_cache - .fs_slave_map(&VhostUserFSSlaveMsg::default(), &Descriptor(fd)) - .unwrap(); - fs_cache - .fs_slave_unmap(&VhostUserFSSlaveMsg::default()) - .unwrap_err(); - } -} diff --git a/third_party/vmm_vhost/src/master_req_handler/unix.rs b/third_party/vmm_vhost/src/master_req_handler/unix.rs new file mode 100644 index 0000000000..c00939ecaf --- /dev/null +++ b/third_party/vmm_vhost/src/master_req_handler/unix.rs @@ -0,0 +1,148 @@ +// Copyright 2022 The Chromium OS Authors. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Unix specific code that keeps rest of the code in the crate platform independent. + +use std::os::unix::io::IntoRawFd; +use std::sync::Arc; + +use base::AsRawDescriptor; +use base::FromRawDescriptor; +use base::RawDescriptor; +use base::SafeDescriptor; + +use crate::master_req_handler::MasterReqHandler; +use crate::Result; +use crate::VhostUserMasterReqHandler; + +impl AsRawDescriptor for MasterReqHandler { + /// Used for polling. + fn as_raw_descriptor(&self) -> RawDescriptor { + self.sub_sock.as_raw_descriptor() + } +} + +impl MasterReqHandler { + /// Create a `MasterReqHandler` that uses a Unix stream internally. + pub fn with_stream(backend: Arc) -> Result { + Self::new( + backend, + Box::new(|stream| unsafe { + // Safe because we own the raw fd. + SafeDescriptor::from_raw_descriptor(stream.into_raw_fd()) + }), + ) + } +} + +#[cfg(test)] +mod tests { + use std::sync::Mutex; + + use base::AsRawDescriptor; + use base::Descriptor; + use base::FromRawDescriptor; + use base::INVALID_DESCRIPTOR; + + use super::*; + use crate::message::VhostUserFSSlaveMsg; + use crate::HandlerResult; + #[cfg(feature = "device")] + use crate::Slave; + use crate::SystemStream; + use crate::VhostUserMasterReqHandlerMut; + + struct MockMasterReqHandler {} + + impl VhostUserMasterReqHandlerMut for MockMasterReqHandler { + /// Handle virtio-fs map file requests from the slave. + fn fs_slave_map( + &mut self, + _fs: &VhostUserFSSlaveMsg, + _fd: &dyn AsRawDescriptor, + ) -> HandlerResult { + Ok(0) + } + + /// Handle virtio-fs unmap file requests from the slave. + fn fs_slave_unmap(&mut self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult { + Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) + } + } + + #[test] + fn test_new_master_req_handler() { + let backend = Arc::new(Mutex::new(MockMasterReqHandler {})); + let mut handler = MasterReqHandler::with_stream(backend).unwrap(); + + let tx_descriptor = handler.take_tx_descriptor(); + assert!(tx_descriptor.as_raw_descriptor() >= 0); + assert!(handler.as_raw_descriptor() != INVALID_DESCRIPTOR); + handler.check_state().unwrap(); + + assert_eq!(handler.error, None); + handler.set_failed(libc::EAGAIN); + assert_eq!(handler.error, Some(libc::EAGAIN)); + handler.check_state().unwrap_err(); + } + + #[cfg(feature = "device")] + #[test] + fn test_master_slave_req_handler() { + let backend = Arc::new(Mutex::new(MockMasterReqHandler {})); + let mut handler = MasterReqHandler::with_stream(backend).unwrap(); + + let tx_descriptor = handler.take_tx_descriptor(); + let fd = unsafe { libc::dup(tx_descriptor.as_raw_descriptor()) }; + if fd < 0 { + panic!("failed to duplicated tx fd!"); + } + let stream = unsafe { SystemStream::from_raw_descriptor(fd) }; + let fs_cache = Slave::from_stream(stream); + + std::thread::spawn(move || { + let res = handler.handle_request().unwrap(); + assert_eq!(res, 0); + handler.handle_request().unwrap_err(); + }); + + fs_cache + .fs_slave_map(&VhostUserFSSlaveMsg::default(), &Descriptor(fd)) + .unwrap(); + // When REPLY_ACK has not been negotiated, the master has no way to detect failure from + // slave side. + fs_cache + .fs_slave_unmap(&VhostUserFSSlaveMsg::default()) + .unwrap(); + } + + #[cfg(feature = "device")] + #[test] + fn test_master_slave_req_handler_with_ack() { + let backend = Arc::new(Mutex::new(MockMasterReqHandler {})); + let mut handler = MasterReqHandler::with_stream(backend).unwrap(); + handler.set_reply_ack_flag(true); + + let tx_descriptor = handler.take_tx_descriptor(); + let fd = unsafe { libc::dup(tx_descriptor.as_raw_descriptor()) }; + if fd < 0 { + panic!("failed to duplicated tx fd!"); + } + let stream = unsafe { SystemStream::from_raw_descriptor(fd) }; + let fs_cache = Slave::from_stream(stream); + + std::thread::spawn(move || { + let res = handler.handle_request().unwrap(); + assert_eq!(res, 0); + handler.handle_request().unwrap_err(); + }); + + fs_cache.set_reply_ack_flag(true); + fs_cache + .fs_slave_map(&VhostUserFSSlaveMsg::default(), &Descriptor(fd)) + .unwrap(); + fs_cache + .fs_slave_unmap(&VhostUserFSSlaveMsg::default()) + .unwrap_err(); + } +} diff --git a/third_party/vmm_vhost/src/master_req_handler/windows.rs b/third_party/vmm_vhost/src/master_req_handler/windows.rs new file mode 100644 index 0000000000..39c49441be --- /dev/null +++ b/third_party/vmm_vhost/src/master_req_handler/windows.rs @@ -0,0 +1,154 @@ +// Copyright 2022 The Chromium OS Authors. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Windows specific code that keeps rest of the code in the crate platform independent. + +use std::sync::Arc; + +use base::AsRawDescriptor; +use base::CloseNotifier; +use base::ReadNotifier; +use tube_transporter::packed_tube; + +use crate::master_req_handler::MasterReqHandler; +use crate::Result; +use crate::VhostUserMasterReqHandler; + +impl MasterReqHandler { + /// Create a `MasterReqHandler` that uses a Tube internally. Must specify the backend process + /// which will receive the Tube. + pub fn with_tube(backend: Arc, backend_pid: u32) -> Result { + Self::new( + backend, + Box::new(move |tube| unsafe { + // Safe because we expect the tube to be unpacked in the other process. + packed_tube::pack(tube, backend_pid).expect("packed tube") + }), + ) + } +} + +impl ReadNotifier for MasterReqHandler { + /// Used for polling. + fn get_read_notifier(&self) -> &dyn AsRawDescriptor { + self.sub_sock.get_tube().get_read_notifier() + } +} + +impl CloseNotifier for MasterReqHandler { + /// Used for closing. + fn get_close_notifier(&self) -> &dyn AsRawDescriptor { + self.sub_sock.get_tube().get_close_notifier() + } +} + +#[cfg(test)] +mod tests { + use std::sync::Mutex; + + use base::AsRawDescriptor; + use base::Descriptor; + use base::INVALID_DESCRIPTOR; + + use super::*; + use crate::message::VhostUserFSSlaveMsg; + use crate::HandlerResult; + #[cfg(feature = "device")] + use crate::Slave; + use crate::VhostUserMasterReqHandlerMut; + + struct MockMasterReqHandler {} + + impl VhostUserMasterReqHandlerMut for MockMasterReqHandler { + /// Handle virtio-fs map file requests from the slave. + fn fs_slave_map( + &mut self, + _fs: &VhostUserFSSlaveMsg, + _fd: &dyn AsRawDescriptor, + ) -> HandlerResult { + Ok(0) + } + + /// Handle virtio-fs unmap file requests from the slave. + fn fs_slave_unmap(&mut self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult { + Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) + } + } + + #[test] + fn test_new_master_req_handler() { + let backend = Arc::new(Mutex::new(MockMasterReqHandler {})); + let mut handler = MasterReqHandler::with_tube(backend, std::process::id()).unwrap(); + + assert!(handler.get_read_notifier().as_raw_descriptor() != INVALID_DESCRIPTOR); + assert!(handler.get_close_notifier().as_raw_descriptor() != INVALID_DESCRIPTOR); + handler.check_state().unwrap(); + + assert_eq!(handler.error, None); + handler.set_failed(libc::EAGAIN); + assert_eq!(handler.error, Some(libc::EAGAIN)); + handler.check_state().unwrap_err(); + } + + #[cfg(feature = "device")] + #[test] + fn test_master_slave_req_handler() { + let backend = Arc::new(Mutex::new(MockMasterReqHandler {})); + let mut handler = MasterReqHandler::with_tube(backend, std::process::id()).unwrap(); + + let event = base::Event::new().unwrap(); + let tx_descriptor = handler.take_tx_descriptor(); + // Safe because we only do it once. + let stream = unsafe { packed_tube::unpack(tx_descriptor).unwrap() }; + let fs_cache = Slave::from_stream(stream); + + std::thread::spawn(move || { + let res = handler.handle_request().unwrap(); + assert_eq!(res, 0); + handler.handle_request().unwrap_err(); + }); + + fs_cache + .fs_slave_map( + &VhostUserFSSlaveMsg::default(), + &Descriptor(event.as_raw_descriptor()), + ) + .unwrap(); + // When REPLY_ACK has not been negotiated, the master has no way to detect failure from + // slave side. + fs_cache + .fs_slave_unmap(&VhostUserFSSlaveMsg::default()) + .unwrap(); + } + + #[cfg(feature = "device")] + #[test] + fn test_master_slave_req_handler_with_ack() { + let backend = Arc::new(Mutex::new(MockMasterReqHandler {})); + let mut handler = MasterReqHandler::with_tube(backend, std::process::id()).unwrap(); + handler.set_reply_ack_flag(true); + + let event = base::Event::new().unwrap(); + let tx_descriptor = handler.take_tx_descriptor(); + // Safe because we only do it once. + let stream = unsafe { packed_tube::unpack(tx_descriptor).unwrap() }; + let fs_cache = Slave::from_stream(stream); + + std::thread::spawn(move || { + let res = handler.handle_request().unwrap(); + assert_eq!(res, 0); + handler.handle_request().unwrap_err(); + }); + + fs_cache.set_reply_ack_flag(true); + fs_cache + .fs_slave_map( + &VhostUserFSSlaveMsg::default(), + &Descriptor(event.as_raw_descriptor()), + ) + .unwrap(); + fs_cache + .fs_slave_unmap(&VhostUserFSSlaveMsg::default()) + .unwrap_err(); + } +} diff --git a/tube_transporter/src/lib.rs b/tube_transporter/src/lib.rs index e8b3d6f6b6..3ba6df1569 100644 --- a/tube_transporter/src/lib.rs +++ b/tube_transporter/src/lib.rs @@ -27,6 +27,8 @@ use serde::Deserialize; use serde::Serialize; use thiserror::Error as ThisError; +pub mod packed_tube; + pub type TransportTubeResult = std::result::Result; /// Contains information for a child process to set up the Tube for use. diff --git a/tube_transporter/src/packed_tube.rs b/tube_transporter/src/packed_tube.rs new file mode 100644 index 0000000000..84d361a8f0 --- /dev/null +++ b/tube_transporter/src/packed_tube.rs @@ -0,0 +1,115 @@ +// Copyright 2022 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +use base::deserialize_and_recv; +use base::named_pipes; +use base::named_pipes::BlockingMode; +use base::named_pipes::FramingMode; +use base::serialize_and_send; +use base::Error as SysError; +use base::FromRawDescriptor; +use base::IntoRawDescriptor; +use base::PipeConnection; +use base::SafeDescriptor; +use base::Tube; +use base::TubeError; +use serde::Deserialize; +use serde::Serialize; +use thiserror::Error as ThisError; + +pub type PackedTubeResult = Result; + +#[derive(Debug, ThisError)] +pub enum PackedTubeError { + #[error("Serializing and recving failed: {0}")] + DeserializeRecvError(TubeError), + #[error("Named pipe error: {0}")] + PipeError(SysError), + #[error("Serializing and sending failed: {0}")] + SerializeSendError(TubeError), +} + +#[derive(Deserialize, Serialize)] +struct PackedTube { + tube: Tube, + server_pipe: PipeConnection, +} + +/// Sends a [Tube] through a protocol that expects a [RawDescriptor]. +/// +/// A packed tube works by creating a named pipe pair, and serializing both the Tube and the +/// server end of the pipe. Then, it returns the client end of the named pipe pair, which can be +/// used as the desired descriptor to send / duplicate to the target. +/// +/// The receiver will need to use [packed_tube::unpack] to read the message off the pipe, and thus +/// extract a real [Tube]. It will also read the server end of the pipe, and close it. The +/// `receiver_pid` is the pid of the process that will be unpacking the tube. +/// +/// # Safety +/// To prevent dangling handles, the resulting descriptor must be passed to [packed_tube::unpack], +/// in the process which corresponds to `receiver_pid`. +pub unsafe fn pack(tube: Tube, receiver_pid: u32) -> PackedTubeResult { + let (server_pipe, client_pipe) = named_pipes::pair( + &FramingMode::Message, + &BlockingMode::Wait, + /* timeout= */ 0, + ) + .map_err(SysError::from) + .map_err(PackedTubeError::PipeError)?; + + let packed = PackedTube { tube, server_pipe }; + + // Serialize the packed tube, which also duplicates the server end of the pipe into the other + // process. This lets us drop it on our side without destroying the channel. + serialize_and_send( + |buf| packed.server_pipe.write(buf), + &packed, + Some(receiver_pid), + ) + .map_err(PackedTubeError::SerializeSendError)?; + + Ok(SafeDescriptor::from_raw_descriptor( + client_pipe.into_raw_descriptor(), + )) +} + +/// Unpacks a tube from a client descriptor. This must come from a packed tube. +/// +/// # Safety +/// The descriptor passed in must come from [packed_tube::pack]. +pub unsafe fn unpack(descriptor: SafeDescriptor) -> PackedTubeResult { + let pipe = PipeConnection::from_raw_descriptor( + descriptor.into_raw_descriptor(), + FramingMode::Message, + BlockingMode::Wait, + ); + // Safe because we own the descriptor and it came from a PackedTube. + let unpacked: PackedTube = deserialize_and_recv(|buf| pipe.read(buf)) + .map_err(PackedTubeError::DeserializeRecvError)?; + // By dropping `unpacked` we close the server end of the pipe. + Ok(unpacked.tube) +} + +#[cfg(test)] +mod tests { + use crate::packed_tube; + + use base::Tube; + + #[test] + /// Tests packing and unpacking. + fn test_pack_unpack() { + let (tube_server, tube_client) = Tube::pair().unwrap(); + let packed_tube = unsafe { packed_tube::pack(tube_client, std::process::id()).unwrap() }; + + // Safe because get_descriptor clones the underlying pipe. + let recovered_tube = unsafe { packed_tube::unpack(packed_tube).unwrap() }; + + let test_message = "Test message".to_string(); + tube_server.send(&test_message).unwrap(); + let received: String = recovered_tube.recv().unwrap(); + + assert_eq!(test_message, received); + } +}