diff --git a/devices/src/virtio/vhost/user/device/handler.rs b/devices/src/virtio/vhost/user/device/handler.rs index 56957b1b9b..f3aa238f29 100644 --- a/devices/src/virtio/vhost/user/device/handler.rs +++ b/devices/src/virtio/vhost/user/device/handler.rs @@ -704,18 +704,13 @@ impl VhostUserSlaveReqHandlerMut for DeviceRequestHandl Ok(()) } - #[allow(unused, dead_code, clippy::diverging_sub_expression)] - fn set_slave_req_fd(&mut self, file: File) { + fn set_slave_req_fd(&mut self, ep: Box>) { let shmid = match self.shmid { Some(shmid) => shmid, None => { - if let Err(e) = self.backend.set_device_request_channel(file) { - error!("failed to set device request channel: {}", e); - } - return; + unimplemented!("set_device_request_channel no longer supported"); } }; - let ep: Box> = todo!(); let frontend = Slave::new(ep); self.backend .set_shared_memory_mapper(Box::new(VhostShmemMapper { diff --git a/devices/src/virtio/vhost/user/device/vsock.rs b/devices/src/virtio/vhost/user/device/vsock.rs index 0d7345a78f..62b46c003c 100644 --- a/devices/src/virtio/vhost/user/device/vsock.rs +++ b/devices/src/virtio/vhost/user/device/vsock.rs @@ -33,6 +33,8 @@ use vhost::Vsock; use vhost::{self}; use vm_memory::GuestMemory; use vmm_vhost::connection::vfio::Listener as VfioListener; +use vmm_vhost::connection::Endpoint; +use vmm_vhost::message::SlaveReq; use vmm_vhost::message::VhostSharedMemoryRegion; use vmm_vhost::message::VhostUserConfigFlags; use vmm_vhost::message::VhostUserInflight; @@ -52,7 +54,8 @@ use vmm_vhost::VhostUserSlaveReqHandlerMut; use crate::virtio::base_features; use crate::virtio::vhost::user::device::handler::run_handler; // TODO(acourbot) try to remove the system dependencies and make the device usable on all platforms. -use crate::virtio::vhost::user::device::handler::sys::unix::{Doorbell, VvuOps}; +use crate::virtio::vhost::user::device::handler::sys::unix::Doorbell; +use crate::virtio::vhost::user::device::handler::sys::unix::VvuOps; use crate::virtio::vhost::user::device::handler::vmm_va_to_gpa; use crate::virtio::vhost::user::device::handler::MappingInfo; use crate::virtio::vhost::user::device::handler::VhostUserPlatformOps; @@ -430,7 +433,10 @@ impl VhostUserSlaveReqHandlerMut for VsockBackend { Err(Error::InvalidOperation) } - fn set_slave_req_fd(&mut self, _vu_req: File) {} + fn set_slave_req_fd(&mut self, _vu_req: Box>) { + // We didn't set VhostUserProtocolFeatures::SLAVE_REQ + unreachable!("unexpected set_slave_req_fd"); + } fn get_inflight_fd( &mut self, diff --git a/devices/src/virtio/vhost/user/device/vvu/device.rs b/devices/src/virtio/vhost/user/device/vvu/device.rs index d871cdf1ad..fa7c7718e7 100644 --- a/devices/src/virtio/vhost/user/device/vvu/device.rs +++ b/devices/src/virtio/vhost/user/device/vvu/device.rs @@ -8,6 +8,7 @@ use std::cmp::Ordering; use std::io::IoSlice; use std::io::IoSliceMut; use std::mem; +use std::os::unix::prelude::RawFd; use std::sync::mpsc::channel; use std::sync::mpsc::Receiver; use std::sync::mpsc::Sender; @@ -29,8 +30,10 @@ use futures::select; use futures::FutureExt; use sync::Mutex; use vmm_vhost::connection::vfio::Device as VfioDeviceTrait; +use vmm_vhost::connection::vfio::Endpoint as VfioEndpoint; use vmm_vhost::connection::vfio::RecvIntoBufsError; -use vmm_vhost::message::MasterReq; +use vmm_vhost::connection::Endpoint; +use vmm_vhost::message::*; use crate::virtio::vhost::user::device::vvu::pci::QueueNotifier; use crate::virtio::vhost::user::device::vvu::pci::VvuPciDevice; @@ -141,17 +144,26 @@ impl VfioReceiver { } } +// Data queued 
to send on an endpoint. +#[derive(Default)] +struct EndpointTxBuffer { + bytes: Vec, +} + // Utility class for writing an input vhost-user byte stream to the vvu // tx virtqueue as discrete vhost-user messages. struct Queue { txq: UserQueue, txq_notifier: QueueNotifier, - - bytes: Vec, } impl Queue { - fn send_bufs(&mut self, iovs: &[IoSlice], fds: Option<&[RawDescriptor]>) -> Result { + fn send_bufs( + &mut self, + iovs: &[IoSlice], + fds: Option<&[RawDescriptor]>, + tx_state: &mut EndpointTxBuffer, + ) -> Result { if fds.is_some() { bail!("cannot send FDs"); } @@ -160,15 +172,15 @@ impl Queue { for iov in iovs { let mut vec = iov.to_vec(); size += iov.len(); - self.bytes.append(&mut vec); + tx_state.bytes.append(&mut vec); } - if let Some(hdr) = vhost_header_from_bytes::(&self.bytes) { + if let Some(hdr) = vhost_header_from_bytes::(&tx_state.bytes) { let bytes_needed = hdr.get_size() as usize + HEADER_LEN; - match bytes_needed.cmp(&self.bytes.len()) { + match bytes_needed.cmp(&tx_state.bytes.len()) { Ordering::Greater => (), Ordering::Equal => { - let msg = mem::take(&mut self.bytes); + let msg = mem::take(&mut tx_state.bytes); self.txq.write(&msg).context("Failed to send data")?; } Ordering::Less => bail!("sent bytes larger than message size"), @@ -184,7 +196,8 @@ async fn process_rxq( evt: EventAsync, mut rxq: UserQueue, rxq_notifier: QueueNotifier, - sender: VfioSender, + frontend_sender: VfioSender, + backend_sender: VfioSender, ) -> Result<()> { loop { if let Err(e) = evt.next_val().await { @@ -200,6 +213,9 @@ async fn process_rxq( let mut buf = vec![0_u8; slice.size()]; slice.copy_to(&mut buf); + // The inbound message may be a SlaveReq message. However, the values + // of all SlaveReq enum values can be safely interpreted as MasterReq + // enum values. 
let hdr = vhost_header_from_bytes::(&buf).context("rxq message too short")?; if HEADER_LEN + hdr.get_size() as usize != slice.size() { @@ -210,7 +226,13 @@ async fn process_rxq( ); } - sender.send(buf).context("send failed")?; + if hdr.is_reply() { + &backend_sender + } else { + &frontend_sender + } + .send(buf) + .context("send failed")?; } rxq_notifier.notify(); } @@ -232,12 +254,19 @@ fn run_worker( rx_queue: UserQueue, rx_irq: Event, rx_notifier: QueueNotifier, - sender: VfioSender, + frontend_sender: VfioSender, + backend_sender: VfioSender, tx_queue: Arc>, tx_irq: Event, ) -> Result<()> { let rx_irq = EventAsync::new(rx_irq, &ex).context("failed to create async event")?; - let rxq = process_rxq(rx_irq, rx_queue, rx_notifier, sender); + let rxq = process_rxq( + rx_irq, + rx_queue, + rx_notifier, + frontend_sender, + backend_sender, + ); pin_mut!(rxq); let tx_irq = EventAsync::new(tx_irq, &ex).context("failed to create async event")?; @@ -267,6 +296,7 @@ enum DeviceState { }, Running { rxq_receiver: VfioReceiver, + tx_state: EndpointTxBuffer, txq: Arc>, }, @@ -274,21 +304,24 @@ enum DeviceState { pub struct VvuDevice { state: DeviceState, - rxq_evt: Event, + frontend_rxq_evt: Event, + + backend_channel: Option>, } impl VvuDevice { pub fn new(device: VvuPciDevice) -> Self { Self { state: DeviceState::Initialized { device }, - rxq_evt: Event::new().expect("failed to create VvuDevice's rxq_evt"), + frontend_rxq_evt: Event::new().expect("failed to create VvuDevice's rxq_evt"), + backend_channel: None, } } } impl VfioDeviceTrait for VvuDevice { fn event(&self) -> &Event { - &self.rxq_evt + &self.frontend_rxq_evt } fn start(&mut self) -> Result<()> { @@ -309,23 +342,34 @@ impl VfioDeviceTrait for VvuDevice { let rxq_notifier = queue_notifiers.remove(0); // TODO: Can we use async channel instead so we don't need `rxq_evt`? 
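Note on the hunks above: both the `EndpointTxBuffer`/`send_bufs` buffering and the `process_rxq` routing rely on the same property of the vhost-user wire format: every message begins with a fixed 12-byte header carrying the payload size and a REPLY flag, so a raw byte stream can be cut into whole messages and each one routed to the frontend (`MasterReq`) or backend (`SlaveReq`) channel without inspecting the request code. The sketch below is illustrative only; the header layout and REPLY flag value follow the vhost-user spec, and none of the names are crosvm identifiers.

```rust
// Illustrative sketch (not crosvm code) of the framing/demux rule used above:
// a vhost-user message is a 12-byte header (request code, flags, payload size)
// followed by the payload; flags bit 2 (REPLY) decides which channel gets it.

const HEADER_LEN: usize = 12;
const VHOST_USER_REPLY_FLAG: u32 = 0x4; // bit 2 of the flags field

struct Header {
    request: u32,
    flags: u32,
    size: u32,
}

fn parse_header(buf: &[u8]) -> Option<Header> {
    if buf.len() < HEADER_LEN {
        return None;
    }
    let le = |b: &[u8]| u32::from_le_bytes(b.try_into().unwrap());
    Some(Header {
        request: le(&buf[0..4]),
        flags: le(&buf[4..8]),
        size: le(&buf[8..12]),
    })
}

enum Route {
    Frontend, // request coming from the sibling VMM (MasterReq traffic)
    Backend,  // reply to a request the device sent on its slave channel (SlaveReq)
}

/// Pop one complete message off `pending`, if enough bytes have accumulated,
/// and say where it should be delivered.
fn take_message(pending: &mut Vec<u8>) -> Option<(Route, Vec<u8>)> {
    let hdr = parse_header(pending)?;
    let total = HEADER_LEN + hdr.size as usize;
    if pending.len() < total {
        return None; // wait for the rest of the payload
    }
    let msg: Vec<u8> = pending.drain(..total).collect();
    let route = if (hdr.flags & VHOST_USER_REPLY_FLAG) != 0 {
        Route::Backend
    } else {
        Route::Frontend
    };
    Some((route, msg))
}

fn main() {
    // One request-style message: code 1, flags 0x1 (version), empty payload.
    let mut stream = vec![1, 0, 0, 0, 0x1, 0, 0, 0, 0, 0, 0, 0];
    let (route, msg) = take_message(&mut stream).expect("complete message");
    assert!(matches!(route, Route::Frontend));
    assert_eq!(parse_header(&msg).unwrap().request, 1);
}
```

The real code gets the same information from `vmm_vhost`'s message header via `hdr.get_size()` and `hdr.is_reply()`, as seen in the hunk above.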
let (rxq_sender, rxq_receiver) = channel(); - let rxq_evt = self.rxq_evt.try_clone().expect("rxq_evt clone"); + let rxq_evt = self.frontend_rxq_evt.try_clone().expect("rxq_evt clone"); let txq = Arc::new(Mutex::new(Queue { txq: queues.remove(0), txq_notifier: queue_notifiers.remove(0), - bytes: Vec::new(), })); let txq_cloned = Arc::clone(&txq); let txq_irq = irqs.remove(0); + let (backend_rxq_sender, backend_rxq_receiver) = channel(); + let backend_rxq_evt = Event::new().expect("failed to create VvuDevice's rxq_evt"); + let backend_rxq_evt2 = backend_rxq_evt.try_clone().expect("rxq_evt clone"); + self.backend_channel = Some(VfioEndpoint::from(BackendChannel { + receiver: VfioReceiver::new(backend_rxq_receiver, backend_rxq_evt), + queue: txq.clone(), + tx_state: EndpointTxBuffer::default(), + })); + let old_state = std::mem::replace( &mut self.state, DeviceState::Running { rxq_receiver: VfioReceiver::new( rxq_receiver, - self.rxq_evt.try_clone().expect("rxq_evt clone"), + self.frontend_rxq_evt + .try_clone() + .expect("frontend_rxq_evt clone"), ), + tx_state: EndpointTxBuffer::default(), txq, }, ); @@ -335,14 +379,22 @@ impl VfioDeviceTrait for VvuDevice { _ => unreachable!(), }; - let sender = VfioSender::new(rxq_sender, rxq_evt); + let frontend_sender = VfioSender::new(rxq_sender, rxq_evt); + let backend_sender = VfioSender::new(backend_rxq_sender, backend_rxq_evt2); thread::Builder::new() .name("virtio-vhost-user driver".to_string()) .spawn(move || { device.start().expect("failed to start device"); - if let Err(e) = - run_worker(ex, rxq, rxq_irq, rxq_notifier, sender, txq_cloned, txq_irq) - { + if let Err(e) = run_worker( + ex, + rxq, + rxq_irq, + rxq_notifier, + frontend_sender, + backend_sender, + txq_cloned, + txq_irq, + ) { error!("worker thread exited with error: {}", e); } })?; @@ -351,14 +403,15 @@ impl VfioDeviceTrait for VvuDevice { } fn send_bufs(&mut self, iovs: &[IoSlice], fds: Option<&[RawDescriptor]>) -> Result { - let txq = match &mut self.state { + match &mut self.state { DeviceState::Initialized { .. } => { bail!("VvuDevice hasn't started yet"); } - DeviceState::Running { txq, .. } => txq, - }; - - txq.lock().send_bufs(iovs, fds) + DeviceState::Running { txq, tx_state, .. } => { + let mut queue = txq.lock(); + queue.send_bufs(iovs, fds, tx_state) + } + } } fn recv_into_bufs(&mut self, bufs: &mut [IoSliceMut]) -> Result { @@ -369,4 +422,43 @@ impl VfioDeviceTrait for VvuDevice { DeviceState::Running { rxq_receiver, .. } => rxq_receiver.recv_into_bufs(bufs), } } + + fn create_slave_request_endpoint(&mut self) -> Result>> { + self.backend_channel + .take() + .map_or(Err(anyhow!("missing backend endpoint")), |c| { + Ok(Box::new(c)) + }) + } +} + +// Struct which implements the Endpoint for backend messages. 
+struct BackendChannel { + receiver: VfioReceiver, + queue: Arc>, + tx_state: EndpointTxBuffer, +} + +impl VfioDeviceTrait for BackendChannel { + fn event(&self) -> &Event { + &self.receiver.evt + } + + fn start(&mut self) -> Result<()> { + Ok(()) + } + + fn send_bufs(&mut self, iovs: &[IoSlice], fds: Option<&[RawFd]>) -> Result { + self.queue.lock().send_bufs(iovs, fds, &mut self.tx_state) + } + + fn recv_into_bufs(&mut self, bufs: &mut [IoSliceMut]) -> Result { + self.receiver.recv_into_bufs(bufs) + } + + fn create_slave_request_endpoint(&mut self) -> Result>> { + Err(anyhow!( + "can't construct backend endpoint from backend endpoint" + )) + } } diff --git a/devices/src/virtio/vhost/user/proxy.rs b/devices/src/virtio/vhost/user/proxy.rs index b64e007f81..fbc3e02daa 100644 --- a/devices/src/virtio/vhost/user/proxy.rs +++ b/devices/src/virtio/vhost/user/proxy.rs @@ -12,8 +12,11 @@ use std::fmt; use std::fs::File; +use std::io::IoSlice; +use std::io::Read; use std::io::Write; use std::os::unix::net::UnixListener; +use std::os::unix::net::UnixStream; use std::sync::Arc; use std::thread; @@ -31,6 +34,7 @@ use base::IntoRawDescriptor; use base::Protection; use base::RawDescriptor; use base::SafeDescriptor; +use base::ScmSocket; use base::Tube; use base::WaitContext; use data_model::DataInit; @@ -51,6 +55,7 @@ use vmm_vhost::connection::socket::Endpoint as SocketEndpoint; use vmm_vhost::connection::EndpointExt; use vmm_vhost::message::MasterReq; use vmm_vhost::message::Req; +use vmm_vhost::message::SlaveReq; use vmm_vhost::message::VhostUserMemory; use vmm_vhost::message::VhostUserMemoryRegion; use vmm_vhost::message::VhostUserMsgHeader; @@ -66,6 +71,7 @@ use crate::pci::PciBarRegionType; use crate::pci::PciCapability; use crate::pci::PciCapabilityID; use crate::virtio::copy_config; +use crate::virtio::vhost::vhost_header_from_bytes; use crate::virtio::DescriptorChain; use crate::virtio::DeviceType; use crate::virtio::Interrupt; @@ -207,14 +213,6 @@ fn check_attached_files( } } -// Check if `hdr` is valid. -fn is_header_valid(hdr: &VhostUserMsgHeader) -> bool { - if hdr.is_reply() || hdr.get_version() != 0x1 { - return false; - } - true -} - // Payload sent by the sibling in a |SET_VRING_KICK| message. #[derive(Default)] struct KickData { @@ -260,12 +258,17 @@ struct Worker { // Stores memory regions that the worker has asked the main thread to register. registered_memory: Vec, + + // Channel for backend mesages. + slave_req_fd: Option>, } -#[derive(EventToken, Debug, Clone)] +#[derive(EventToken, Debug, Clone, PartialEq)] enum Token { // Data is available on the Vhost-user sibling socket. SiblingSocket, + // Data is available on the vhost-user backend socket. + BackendSocket, // The device backend has made a read buffer available. RxQueue, // The device backend has sent a buffer to the |Worker::tx_queue|. 
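Aside on `BackendChannel` in vvu/device.rs above: the frontend endpoint and the backend channel share one tx virtqueue (the `Arc<Mutex<Queue>>`), and the per-endpoint `EndpointTxBuffer` is what keeps their byte streams from interleaving: each side accumulates bytes privately and only writes to the shared queue once it holds a complete message. A minimal self-contained sketch of that invariant follows; the names and the toy header handling are illustrative, not crosvm code.

```rust
// Sketch of the shared-queue invariant: two logical endpoints, one shared
// sink, and a private reassembly buffer per endpoint so only whole messages
// ever reach the shared queue.

use std::sync::{Arc, Mutex};

const HEADER_LEN: usize = 12;

/// Stands in for the shared tx virtqueue; each push must be one whole message.
#[derive(Default)]
struct SharedQueue {
    sent: Vec<Vec<u8>>,
}

#[derive(Default)]
struct TxBuffer {
    bytes: Vec<u8>,
}

fn payload_len(hdr: &[u8]) -> usize {
    u32::from_le_bytes(hdr[8..12].try_into().unwrap()) as usize
}

/// Append `chunk` to this endpoint's buffer; flush to the shared queue only
/// once a full message (header + payload) is available.
fn send_chunk(queue: &Arc<Mutex<SharedQueue>>, buf: &mut TxBuffer, chunk: &[u8]) {
    buf.bytes.extend_from_slice(chunk);
    if buf.bytes.len() >= HEADER_LEN {
        let needed = HEADER_LEN + payload_len(&buf.bytes);
        if buf.bytes.len() == needed {
            queue.lock().unwrap().sent.push(std::mem::take(&mut buf.bytes));
        }
    }
}

fn main() {
    let queue = Arc::new(Mutex::new(SharedQueue::default()));
    let (mut frontend, mut backend) = (TxBuffer::default(), TxBuffer::default());

    // Frontend starts a message (4-byte payload) but hasn't finished it...
    send_chunk(&queue, &mut frontend, &[7, 0, 0, 0, 0x1, 0, 0, 0, 4, 0, 0, 0]);
    // ...while the backend sends a complete zero-payload message.
    send_chunk(&queue, &mut backend, &[21, 0, 0, 0, 0x5, 0, 0, 0, 0, 0, 0, 0]);
    // Frontend completes its payload afterwards.
    send_chunk(&queue, &mut frontend, &[1, 2, 3, 4]);

    let q = queue.lock().unwrap();
    assert_eq!(q.sent.len(), 2); // two whole messages, never interleaved
}
```

The real `Queue::send_bufs` additionally rejects the case where the buffered bytes exceed the size announced in the header.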
@@ -302,7 +305,10 @@ enum ExitReason { } // Trait used to process an incoming vhost-user message -trait Action: Req { +trait RxAction: Req { + // Checks whether the header is valid + fn is_header_valid(hdr: &VhostUserMsgHeader) -> bool; + // Process a message before forwarding it on to the virtqueue fn process_message( worker: &mut Worker, @@ -319,7 +325,14 @@ trait Action: Req { fn handle_failure(worker: &mut Worker, hdr: &VhostUserMsgHeader) -> Result<()>; } -impl Action for MasterReq { +impl RxAction for MasterReq { + fn is_header_valid(hdr: &VhostUserMsgHeader) -> bool { + if hdr.is_reply() || hdr.get_version() != 0x1 { + return false; + } + true + } + fn process_message( worker: &mut Worker, wait_ctx: &mut WaitContext, @@ -332,9 +345,10 @@ impl Action for MasterReq { return Ok(()); } match hdr.get_code() { - MasterReq::SET_MEM_TABLE => worker.set_mem_table(hdr, payload, files), - MasterReq::SET_VRING_CALL => worker.set_vring_call(hdr, payload, files), - MasterReq::SET_VRING_KICK => worker.set_vring_kick(wait_ctx, hdr, payload, files), + MasterReq::SET_MEM_TABLE => worker.set_mem_table(payload, files), + MasterReq::SET_VRING_CALL => worker.set_vring_call(payload, files), + MasterReq::SET_VRING_KICK => worker.set_vring_kick(wait_ctx, payload, files), + MasterReq::SET_SLAVE_REQ_FD => worker.set_slave_req_fd(wait_ctx, files), _ => unimplemented!("unimplemented action message: {:?}", hdr.get_code()), } } @@ -351,6 +365,36 @@ impl Action for MasterReq { } } +impl RxAction for SlaveReq { + fn is_header_valid(hdr: &VhostUserMsgHeader) -> bool { + if !hdr.is_reply() || hdr.get_version() != 0x1 { + return false; + } + true + } + + fn process_message( + _worker: &mut Worker, + _wait_ctx: &mut WaitContext, + _hdr: &VhostUserMsgHeader, + _payload: &[u8], + _files: Option>, + ) -> Result<()> { + Ok(()) + } + + fn get_ep(worker: &mut Worker) -> &mut SocketEndpoint { + // We can only be here if we slave_req_fd became readable, so it must exist. + worker.slave_req_fd.as_mut().unwrap() + } + + fn handle_failure(_worker: &mut Worker, hdr: &VhostUserMsgHeader) -> Result<()> { + // There's nothing we can do to directly handle this failure here. + error!("failed to process reply to backend {:?}", hdr.get_code()); + Ok(()) + } +} + impl Worker { // The entry point into `Worker`. // - At this point the connection with the sibling is already established. @@ -381,20 +425,19 @@ impl Worker { let events = wait_ctx.wait().context("failed to wait for events")?; for event in events.iter().filter(|e| e.is_readable) { match event.token { - Token::SiblingSocket => { - match self.process_rx::(&mut wait_ctx) { + Token::SiblingSocket | Token::BackendSocket => { + let res = if event.token == Token::SiblingSocket { + self.process_rx::(&mut wait_ctx) + } else { + self.process_rx::(&mut wait_ctx) + }; + match res { Ok(RxqStatus::Processed) => (), Ok(RxqStatus::DescriptorsExhausted) => { // If the driver has no Rx buffers left, then no // point monitoring the Vhost-user sibling for data. There // would be no way to send it to the device backend. - wait_ctx - .modify( - &self.slave_req_helper, - EventType::None, - Token::SiblingSocket, - ) - .context("failed to disable EPOLLIN on sibling VM socket fd")?; + self.set_rx_polling_state(&mut wait_ctx, EventType::None)?; sibling_socket_polling_enabled = false; } Ok(RxqStatus::Disconnected) => { @@ -411,13 +454,7 @@ impl Worker { // Rx buffers are available, now we should monitor the // Vhost-user sibling connection for data. 
if !sibling_socket_polling_enabled { - wait_ctx - .modify( - &self.slave_req_helper, - EventType::Read, - Token::SiblingSocket, - ) - .context("failed to add kick event to the epoll set")?; + self.set_rx_polling_state(&mut wait_ctx, EventType::Read)?; sibling_socket_polling_enabled = true; } } @@ -425,7 +462,8 @@ impl Worker { if let Err(e) = tx_queue_evt.read() { bail!("error reading tx queue event: {}", e); } - self.process_tx(); + self.process_tx() + .context("error processing tx queue event")?; } Token::SiblingKick { index } => { if let Err(e) = self.process_sibling_kick(index) { @@ -450,8 +488,32 @@ impl Worker { } } + // Set the target event to poll for on rx descriptors. + fn set_rx_polling_state( + &mut self, + wait_ctx: &mut WaitContext, + target_event: EventType, + ) -> Result<()> { + let fds = std::iter::once(( + &self.slave_req_helper as &dyn AsRawDescriptor, + Token::SiblingSocket, + )) + .chain( + self.slave_req_fd + .as_ref() + .map(|fd| (fd as &dyn AsRawDescriptor, Token::BackendSocket)) + .into_iter(), + ); + for (fd, token) in fds { + wait_ctx + .modify(fd, target_event, token) + .context("failed to set EPOLLIN on socket fd")?; + } + Ok(()) + } + // Processes data from the Vhost-user sibling and forwards to the driver via Rx buffers. - fn process_rx(&mut self, wait_ctx: &mut WaitContext) -> Result { + fn process_rx(&mut self, wait_ctx: &mut WaitContext) -> Result { // Keep looping until - // - No more Rx buffers are available on the Rx queue. OR // - No more data is available on the Vhost-user sibling socket (checked via a @@ -507,7 +569,11 @@ impl Worker { let index = desc.index; let bytes_written = { - let res = R::process_message(self, wait_ctx, &hdr, &buf, files); + let res = if !R::is_header_valid(&hdr) { + Err(anyhow!("invalid header for {:?}", hdr.get_code())) + } else { + R::process_message(self, wait_ctx, &hdr, &buf, files) + }; // If the "action" in response to the action messages // failed then no bytes have been written to the virt // queue. Else, the action is done. Now forward the @@ -543,7 +609,7 @@ impl Worker { } // Returns the sibling connection status. - fn check_sibling_connection(&mut self) -> ConnStatus { + fn check_sibling_connection(&mut self) -> ConnStatus { // Peek if any data is left on the Vhost-user sibling socket. If no, then // nothing to forwad to the device backend. let mut peek_buf = [0; 1]; @@ -571,7 +637,10 @@ impl Worker { } // Returns any data attached to a Vhost-user sibling message. - fn get_sibling_msg_data(&mut self, hdr: &VhostUserMsgHeader) -> Result> { + fn get_sibling_msg_data( + &mut self, + hdr: &VhostUserMsgHeader, + ) -> Result> { let buf = match hdr.get_size() { 0 => vec![0u8; 0], len => { @@ -628,16 +697,7 @@ impl Worker { // this function both this VMM and the sibling have two regions of // virtual memory pointing to the same physical page. These regions will be // accessed by the device VM and the silbing VM. - fn set_mem_table( - &mut self, - hdr: &VhostUserMsgHeader, - payload: &[u8], - files: Option>, - ) -> Result<()> { - if !is_header_valid(hdr) { - bail!("invalid header for SET_MEM_TABLE"); - } - + fn set_mem_table(&mut self, payload: &[u8], files: Option>) -> Result<()> { // `hdr` is followed by a `payload`. `payload` consists of metadata about the number of // memory regions and then memory regions themeselves. The memory regions structs consist of // metadata about actual device related memory passed from the sibling. Ensure that the size @@ -752,16 +812,7 @@ impl Worker { } // Handles |SET_VRING_CALL|. 
- fn set_vring_call( - &mut self, - hdr: &VhostUserMsgHeader, - payload: &[u8], - files: Option>, - ) -> Result<()> { - if !is_header_valid(hdr) { - bail!("invalid header for SET_VRING_CALL"); - } - + fn set_vring_call(&mut self, payload: &[u8], files: Option>) -> Result<()> { let payload_size = payload.len(); if payload_size != std::mem::size_of::() { bail!("wrong payload size {} for SET_VRING_CALL", payload_size); @@ -790,14 +841,9 @@ impl Worker { fn set_vring_kick( &mut self, wait_ctx: &mut WaitContext, - hdr: &VhostUserMsgHeader, payload: &[u8], files: Option>, ) -> Result<()> { - if !is_header_valid(hdr) { - bail!("invalid header for SET_VRING_KICK"); - } - let payload_size = payload.len(); if payload_size != std::mem::size_of::() { bail!("wrong payload size {} for SET_VRING_KICK", payload_size); @@ -831,23 +877,77 @@ impl Worker { Ok(()) } + // Handles |SET_SLAVE_REQ_FD|. Prepares the proxy to handle backend messages by + // proxying messages/replies to/from the slave_req_fd. + fn set_slave_req_fd( + &mut self, + wait_ctx: &mut WaitContext, + files: Option>, + ) -> Result<()> { + // Validated by check_attached_files + let mut files = files.expect("missing files"); + let file = files.pop().context("missing file for set_slave_req_fd")?; + if !files.is_empty() { + bail!("invalid file count for SET_SLAVE_REQ_FD {}", files.len()); + } + + // Safe because we own the file. + let socket = unsafe { UnixStream::from_raw_descriptor(file.into_raw_descriptor()) }; + + wait_ctx + .add(&socket, Token::BackendSocket) + .context("failed to set EPOLLIN on socket fd")?; + + self.slave_req_fd = Some(SocketEndpoint::from(socket)); + Ok(()) + } + + fn process_message_from_backend( + &mut self, + msg: Vec, + ) -> Result<(Vec, Option>)> { + Ok((msg, None)) + } + // Processes data from the device backend (via virtio Tx queue) and forward it to // the Vhost-user sibling over its socket connection. - fn process_tx(&mut self) { + fn process_tx(&mut self) -> Result<()> { while let Some(desc_chain) = self.tx_queue.pop(&self.mem) { let index = desc_chain.index; match Reader::new(self.mem.clone(), desc_chain) { Ok(mut reader) => { let expected_count = reader.available_bytes(); - match reader.read_to(self.slave_req_helper.as_mut().as_mut(), expected_count) { - Ok(count) => { - // The |reader| guarantees that all the available data is read. - if count != expected_count { - error!("wrote only {} bytes of {}", count, expected_count); - } - } - Err(e) => error!("failed to write message to vhost-vmm: {}", e), + let mut msg = vec![0; expected_count]; + reader + .read_exact(&mut msg) + .context("virtqueue read failed")?; + + // This may be a SlaveReq, but the bytes of any valid SlaveReq + // are also a valid MasterReq. + let hdr = + vhost_header_from_bytes::(&msg).context("message too short")?; + let (dest, (msg, fd)) = if hdr.is_reply() { + (self.slave_req_helper.as_mut().as_mut(), (msg, None)) + } else { + let processed_msg = self.process_message_from_backend(msg)?; + ( + self.slave_req_fd + .as_mut() + .context("missing slave_req_fd")? 
+ .as_mut(), + processed_msg, + ) + }; + + if let Some(fd) = fd { + let written = dest + .send_with_fd(&[IoSlice::new(msg.as_slice())], fd.as_raw_descriptor()) + .context("failed to foward message")?; + dest.write_all(&msg[written..]) + } else { + dest.write_all(msg.as_slice()) } + .context("failed to foward message")?; } Err(e) => error!("failed to create Reader: {}", e), } @@ -856,6 +956,7 @@ impl Worker { panic!("failed inject tx queue interrupt"); } } + Ok(()) } // Processes a sibling kick for the |index|-th vring and injects the corresponding interrupt @@ -1284,6 +1385,7 @@ impl VirtioVhostUser { vrings, slave_req_helper, registered_memory: Vec::new(), + slave_req_fd: None, }; match worker.run( rx_queue_evt.try_clone().unwrap(), diff --git a/third_party/vmm_vhost/src/connection.rs b/third_party/vmm_vhost/src/connection.rs index 2362cef314..b559b6d003 100644 --- a/third_party/vmm_vhost/src/connection.rs +++ b/third_party/vmm_vhost/src/connection.rs @@ -69,6 +69,15 @@ pub trait Endpoint: Send { bufs: &mut [IoSliceMut], allow_fd: bool, ) -> Result<(usize, Option>)>; + + /// Constructs the slave request endpoint for self. + /// + /// # Arguments + /// * `files` - Files from which to create the endpoint + fn create_slave_request_endpoint( + &mut self, + files: Option>, + ) -> Result>>; } // Advance the internal cursor of the slices. diff --git a/third_party/vmm_vhost/src/connection/socket.rs b/third_party/vmm_vhost/src/connection/socket.rs index 424f2180af..7abf673ad6 100644 --- a/third_party/vmm_vhost/src/connection/socket.rs +++ b/third_party/vmm_vhost/src/connection/socket.rs @@ -9,12 +9,12 @@ use std::io::{ErrorKind, IoSlice, IoSliceMut}; use std::marker::PhantomData; use std::path::{Path, PathBuf}; -use base::{AsRawDescriptor, FromRawDescriptor, RawDescriptor, ScmSocket}; +use base::{AsRawDescriptor, FromRawDescriptor, IntoRawDescriptor, RawDescriptor, ScmSocket}; use super::{Error, Result}; use crate::connection::{Endpoint as EndpointTrait, Listener as ListenerTrait, Req}; use crate::message::*; -use crate::{SystemListener, SystemStream}; +use crate::{take_single_file, SystemListener, SystemStream}; /// Unix domain socket listener for accepting incoming connections. 
pub struct Listener { @@ -205,6 +205,16 @@ impl EndpointTrait for Endpoint { Ok((bytes, files)) } + + fn create_slave_request_endpoint( + &mut self, + files: Option>, + ) -> Result>> { + let file = take_single_file(files).ok_or(Error::InvalidMessage)?; + // Safe because we own the file + let tube = unsafe { SystemStream::from_raw_descriptor(file.into_raw_descriptor()) }; + Ok(Box::new(Endpoint::from(tube))) + } } impl AsRawDescriptor for Endpoint { diff --git a/third_party/vmm_vhost/src/connection/tube.rs b/third_party/vmm_vhost/src/connection/tube.rs index 8204293b3a..e9b7197bed 100644 --- a/third_party/vmm_vhost/src/connection/tube.rs +++ b/third_party/vmm_vhost/src/connection/tube.rs @@ -13,6 +13,7 @@ use serde::{Deserialize, Serialize}; use super::{Error, Result}; use crate::connection::{Endpoint, Req}; +use crate::message::SlaveReq; use std::cmp::min; use std::fs::File; use std::marker::PhantomData; @@ -139,6 +140,13 @@ impl Endpoint for TubeEndpoint { Ok((bytes_read, files)) } + + fn create_slave_request_endpoint( + &mut self, + files: Option>, + ) -> Result>> { + unimplemented!("SET_SLAVE_REQ_FD not supported"); + } } impl AsRawDescriptor for TubeEndpoint { diff --git a/third_party/vmm_vhost/src/connection/vfio.rs b/third_party/vmm_vhost/src/connection/vfio.rs index be8a1756fd..3537fcbd83 100644 --- a/third_party/vmm_vhost/src/connection/vfio.rs +++ b/third_party/vmm_vhost/src/connection/vfio.rs @@ -16,7 +16,7 @@ use thiserror::Error as ThisError; use super::{Error, Result}; use crate::connection::{Endpoint as EndpointTrait, Listener as ListenerTrait, Req}; -use crate::message::MasterReq; +use crate::message::{MasterReq, SlaveReq}; /// Errors for `Device::recv_into_bufs()`. #[sorted] @@ -59,6 +59,11 @@ pub trait Device: Send { &mut self, iovs: &mut [IoSliceMut], ) -> std::result::Result; + + /// Constructs the slave request endpoint for the endpoint backed by this device. + fn create_slave_request_endpoint( + &mut self, + ) -> std::result::Result>, anyhow::Error>; } /// Listener for accepting incoming connections from virtio-vhost-user device through VFIO. @@ -111,6 +116,15 @@ pub struct Endpoint { _r: PhantomData, } +impl From for Endpoint { + fn from(device: D) -> Self { + Self { + device, + _r: PhantomData, + } + } +} + impl EndpointTrait for Endpoint { fn connect>(_path: P) -> Result { // TODO: remove this method from Endpoint trait? @@ -136,6 +150,19 @@ impl EndpointTrait for Endpoint { // VFIO backend doesn't receive any files. 
Ok((size, None)) } + + fn create_slave_request_endpoint( + &mut self, + files: Option>, + ) -> Result>> { + if files.is_some() { + return Err(Error::InvalidMessage); + } + + self.device + .create_slave_request_endpoint() + .map_err(Error::VfioDeviceError) + } } impl AsRawDescriptor for Endpoint { diff --git a/third_party/vmm_vhost/src/slave_req_handler.rs b/third_party/vmm_vhost/src/slave_req_handler.rs index 8fdf111cb8..d90507a208 100644 --- a/third_party/vmm_vhost/src/slave_req_handler.rs +++ b/third_party/vmm_vhost/src/slave_req_handler.rs @@ -76,7 +76,7 @@ pub trait VhostUserSlaveReqHandler { fn set_vring_enable(&self, index: u32, enable: bool) -> Result<()>; fn get_config(&self, offset: u32, size: u32, flags: VhostUserConfigFlags) -> Result>; fn set_config(&self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()>; - fn set_slave_req_fd(&self, _vu_req: File) {} + fn set_slave_req_fd(&self, _vu_req: Box>) {} fn get_inflight_fd(&self, inflight: &VhostUserInflight) -> Result<(VhostUserInflight, File)>; fn set_inflight_fd(&self, inflight: &VhostUserInflight, file: File) -> Result<()>; fn get_max_mem_slots(&self) -> Result; @@ -125,7 +125,7 @@ pub trait VhostUserSlaveReqHandlerMut { flags: VhostUserConfigFlags, ) -> Result>; fn set_config(&mut self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()>; - fn set_slave_req_fd(&mut self, _vu_req: File) {} + fn set_slave_req_fd(&mut self, _vu_req: Box>) {} fn get_inflight_fd( &mut self, inflight: &VhostUserInflight, @@ -224,7 +224,7 @@ impl VhostUserSlaveReqHandler for Mutex { self.lock().unwrap().set_config(offset, buf, flags) } - fn set_slave_req_fd(&self, vu_req: File) { + fn set_slave_req_fd(&self, vu_req: Box>) { self.lock().unwrap().set_slave_req_fd(vu_req) } @@ -860,13 +860,12 @@ impl> SlaveReqHandler } fn set_slave_req_fd(&mut self, files: Option>) -> Result<()> { - if cfg!(windows) { - unimplemented!(); - } else { - let file = take_single_file(files).ok_or(Error::InvalidMessage)?; - self.backend.set_slave_req_fd(file); - Ok(()) - } + let ep = self + .slave_req_helper + .endpoint + .create_slave_request_endpoint(files)?; + self.backend.set_slave_req_fd(ep); + Ok(()) } fn handle_vring_fd_request(
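For context on the `vmm_vhost` API change at the heart of this diff: `set_slave_req_fd` now receives a typed `Box<dyn Endpoint<SlaveReq>>`, and each transport decides how to build it via the new `Endpoint::create_slave_request_endpoint`. The socket transport wraps the fd received with SET_SLAVE_REQ_FD, the tube transport leaves it unimplemented, and the VFIO/vvu transport returns the backend channel it prepared at start-up while rejecting any attached file. The sketch below shows only that shape; all of the types are simplified stand-ins, not the real vmm_vhost traits.

```rust
use std::fs::File;

// Simplified stand-ins for the vmm_vhost pieces touched above; none of these
// definitions are the real crate types.
#[derive(Debug)]
enum Error {
    InvalidMessage,
    Unsupported,
}
type Result<T> = std::result::Result<T, Error>;

/// Stand-in for an Endpoint<SlaveReq> backchannel.
trait SlaveEndpoint {}

trait Transport {
    /// Mirrors the new Endpoint::create_slave_request_endpoint: turn whatever
    /// accompanied SET_SLAVE_REQ_FD into a typed backchannel endpoint.
    fn create_slave_request_endpoint(
        &mut self,
        files: Option<Vec<File>>,
    ) -> Result<Box<dyn SlaveEndpoint>>;
}

/// Socket transport: the master sends the backchannel fd with the message.
struct SocketBackchannel {
    _stream: File,
}
impl SlaveEndpoint for SocketBackchannel {}

struct SocketTransport;
impl Transport for SocketTransport {
    fn create_slave_request_endpoint(
        &mut self,
        files: Option<Vec<File>>,
    ) -> Result<Box<dyn SlaveEndpoint>> {
        // Exactly one fd must be attached.
        let mut files = files.ok_or(Error::InvalidMessage)?;
        if files.len() != 1 {
            return Err(Error::InvalidMessage);
        }
        Ok(Box::new(SocketBackchannel {
            _stream: files.remove(0),
        }))
    }
}

/// VFIO/vvu transport: everything is tunnelled over the virtqueues, so no fd
/// may be attached; the channel was prepared when the device started.
struct VvuBackchannel;
impl SlaveEndpoint for VvuBackchannel {}

struct VfioTransport {
    backend_channel: Option<VvuBackchannel>,
}
impl Transport for VfioTransport {
    fn create_slave_request_endpoint(
        &mut self,
        files: Option<Vec<File>>,
    ) -> Result<Box<dyn SlaveEndpoint>> {
        if files.is_some() {
            return Err(Error::InvalidMessage);
        }
        let ch = self.backend_channel.take().ok_or(Error::Unsupported)?;
        Ok(Box::new(ch))
    }
}

fn main() {
    let mut vfio = VfioTransport {
        backend_channel: Some(VvuBackchannel),
    };
    assert!(vfio.create_slave_request_endpoint(None).is_ok());
    // A second SET_SLAVE_REQ_FD fails: the channel was already handed out.
    assert!(vfio.create_slave_request_endpoint(None).is_err());

    let mut sock = SocketTransport;
    // No fd attached -> invalid message.
    assert!(sock.create_slave_request_endpoint(None).is_err());
}
```

Handing the channel out of an `Option` mirrors `VvuDevice::backend_channel.take()` in the diff, so a repeated request fails instead of aliasing the shared queue.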