From 8e7bc966166626d7ec2176b47e5bc190597b604a Mon Sep 17 00:00:00 2001 From: Keiichi Watanabe Date: Tue, 20 Apr 2021 13:14:04 +0900 Subject: [PATCH] Add vhost_user_devices crate Add `vhost_user_devices` crate which will be used to create a vhost-user device executables. BUG=b:185089400 TEST=cargo test in /vhost_user_devices Change-Id: I7256d68316f7763d3ceaa65abc97663975e7608f Reviewed-on: https://chromium-review.googlesource.com/c/chromiumos/platform/crosvm/+/2822169 Tested-by: kokoro Reviewed-by: Noah Gold Reviewed-by: Chirantan Ekbote Commit-Queue: Keiichi Watanabe --- Cargo.lock | 16 + Cargo.toml | 3 +- devices/src/virtio/queue.rs | 4 +- devices/src/virtio/vhost/user/handler.rs | 29 +- devices/src/virtio/vhost/user/mod.rs | 1 + vhost_user_devices/Cargo.toml | 19 + vhost_user_devices/src/lib.rs | 874 +++++++++++++++++++++++ vm_memory/src/guest_memory.rs | 72 +- 8 files changed, 992 insertions(+), 26 deletions(-) create mode 100644 vhost_user_devices/Cargo.toml create mode 100644 vhost_user_devices/src/lib.rs diff --git a/Cargo.lock b/Cargo.lock index 2205ab0c7d..cc02f55af6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -231,6 +231,7 @@ dependencies = [ "tempfile", "thiserror", "vhost", + "vhost_user_devices", "vm_control", "vm_memory", "x86_64", @@ -1143,6 +1144,21 @@ dependencies = [ "vm_memory", ] +[[package]] +name = "vhost_user_devices" +version = "0.1.0" +dependencies = [ + "base", + "devices", + "libc", + "remain", + "sync", + "tempfile", + "thiserror", + "vm_memory", + "vmm_vhost", +] + [[package]] name = "virtio_sys" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index fdc20bf37d..7198f711c9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -92,6 +92,7 @@ sync = { path = "sync" } tempfile = "*" thiserror = { version = "1.0.20", optional = true } vhost = { path = "vhost" } +vhost_user_devices = { path = "vhost_user_devices" } vm_control = { path = "vm_control" } acpi_tables = { path = "acpi_tables" } vm_memory = { path = "vm_memory" } @@ -118,5 +119,5 @@ p9 
= { path = "../../platform2/vm_tools/p9" } # ignored by ebuild sync = { path = "sync" } sys_util = { path = "sys_util" } tempfile = { path = "tempfile" } -vmm_vhost = { path = "../../third_party/rust-vmm/vhost", features = ["vhost-user-master"] } # ignored by ebuild +vmm_vhost = { path = "../../third_party/rust-vmm/vhost", features = ["vhost-user-master", "vhost-user-slave"] } # ignored by ebuild wire_format_derive = { path = "../../platform2/vm_tools/p9/wire_format_derive" } # ignored by ebuild diff --git a/devices/src/virtio/queue.rs b/devices/src/virtio/queue.rs index 537abaeeef..1fd614b513 100644 --- a/devices/src/virtio/queue.rs +++ b/devices/src/virtio/queue.rs @@ -221,8 +221,8 @@ pub struct Queue { /// Guest physical address of the used ring pub used_ring: GuestAddress, - next_avail: Wrapping, - next_used: Wrapping, + pub next_avail: Wrapping, + pub next_used: Wrapping, // Device feature bits accepted by the driver features: u64, diff --git a/devices/src/virtio/vhost/user/handler.rs b/devices/src/virtio/vhost/user/handler.rs index 25a6403bc6..b3faacd1ee 100644 --- a/devices/src/virtio/vhost/user/handler.rs +++ b/devices/src/virtio/vhost/user/handler.rs @@ -108,7 +108,8 @@ impl VhostUserHandler { .map_err(Error::CopyConfig) } - fn set_mem_table(&mut self, mem: &GuestMemory) -> Result<()> { + /// Sets the memory map regions so it can translate the vring addresses. + pub fn set_mem_table(&mut self, mem: &GuestMemory) -> Result<()> { let mut regions: Vec = Vec::new(); mem.with_regions::<_, ()>( |_idx, guest_phys_addr, memory_size, userspace_addr, mmap, mmap_offset| { @@ -132,13 +133,14 @@ impl VhostUserHandler { Ok(()) } - fn activate_vring( + /// Activates a vring for the given `queue`. 
+ pub fn activate_vring( &mut self, mem: &GuestMemory, queue_index: usize, queue: &Queue, queue_evt: &Event, - interrupt: &Interrupt, + irqfd: &Event, ) -> Result<()> { self.vu .set_vring_num(queue_index, queue.actual_size()) @@ -167,18 +169,9 @@ impl VhostUserHandler { .set_vring_base(queue_index, 0) .map_err(Error::SetVringBase)?; - let msix_config_opt = interrupt - .get_msix_config() - .as_ref() - .ok_or(Error::MsixConfigUnavailable)?; - let msix_config = msix_config_opt.lock(); - let irqfd = msix_config - .get_irqfd(queue.vector as usize) - .ok_or(Error::MsixIrqfdUnavailable)?; self.vu .set_vring_call(queue_index, &irqfd.0) .map_err(Error::SetVringCall)?; - self.vu .set_vring_kick(queue_index, &queue_evt.0) .map_err(Error::SetVringKick)?; @@ -199,9 +192,19 @@ impl VhostUserHandler { ) -> Result<()> { self.set_mem_table(&mem)?; + let msix_config_opt = interrupt + .get_msix_config() + .as_ref() + .ok_or(Error::MsixConfigUnavailable)?; + let msix_config = msix_config_opt.lock(); + for (queue_index, queue) in queues.iter().enumerate() { let queue_evt = &queue_evts[queue_index]; - self.activate_vring(&mem, queue_index, queue, queue_evt, &interrupt)?; + let irqfd = msix_config + .get_irqfd(queue.vector as usize) + .ok_or(Error::MsixIrqfdUnavailable)?; + + self.activate_vring(&mem, queue_index, queue, queue_evt, &irqfd)?; } Ok(()) diff --git a/devices/src/virtio/vhost/user/mod.rs b/devices/src/virtio/vhost/user/mod.rs index b408e1c18f..fab3de0abd 100644 --- a/devices/src/virtio/vhost/user/mod.rs +++ b/devices/src/virtio/vhost/user/mod.rs @@ -10,6 +10,7 @@ mod worker; pub use self::block::*; pub use self::fs::*; +pub use self::handler::VhostUserHandler; pub use self::net::*; use remain::sorted; diff --git a/vhost_user_devices/Cargo.toml b/vhost_user_devices/Cargo.toml new file mode 100644 index 0000000000..5921f688a9 --- /dev/null +++ b/vhost_user_devices/Cargo.toml @@ -0,0 +1,19 @@ +[package] +name = "vhost_user_devices" +version = "0.1.0" +authors = ["The Chromium OS 
Authors"] +edition = "2018" + +[dependencies] +base = { path = "../base" } +devices = { path = "../devices" } +libc = "*" +remain = "*" +sync = { path = "../sync" } +thiserror = "*" +vm_memory = { path = "../vm_memory" } +vmm_vhost = { version = "*", features = ["vhost-user-slave"] } + +[dev-dependencies] +data_model = { path = "../data_model" } +tempfile = { path = "../tempfile" } diff --git a/vhost_user_devices/src/lib.rs b/vhost_user_devices/src/lib.rs new file mode 100644 index 0000000000..1ee873f63a --- /dev/null +++ b/vhost_user_devices/src/lib.rs @@ -0,0 +1,874 @@ +// Copyright 2021 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +//! Library for implementing vhost-user device executables. +//! +//! This crate provides +//! * `VhostUserBackend` trait, which is a collection of methods to handle vhost-user requests, and +//! * `DeviceRequestHandler` struct, which makes a connection to a VMM and starts an event loop. +//! +//! They are expected to be used as follows: +//! 1. Define a struct which `VhostUserBackend` is implemented for. +//! 2. Create an instance of `DeviceRequestHandler` with the backend and call its `start()` method +//! to start an event loop. +//! +//! ```ignore +//! struct MyBackend { +//! /* fields */ +//! } +//! +//! impl VhostUserBackend for MyBackend { +//! /* implement methods */ +//! } +//! +//! fn main() { +//! let backend = MyBackend { /* initialize fields */ }; +//! let handler = DeviceRequestHandler::new(backend).unwrap(); +//! let socket = std::path::Path::new("/path/to/socket"); +//! +//! if let Err(e) = handler.start(socket) { +//! eprintln!("error happened: {}", e); +//! } +//! } +//! ``` +//!
+ +use std::cell::RefCell; +use std::convert::TryFrom; +use std::num::Wrapping; +use std::os::unix::io::{AsRawFd, RawFd}; +use std::path::Path; +use std::rc::Rc; +use std::sync::Arc; + +use base::{ + error, AsRawDescriptor, Event, EventType, FromRawDescriptor, PollToken, SafeDescriptor, + SharedMemory, SharedMemoryUnix, WaitContext, +}; +use devices::virtio::{Queue, SignalableInterrupt}; +use remain::sorted; +use thiserror::Error as ThisError; +use vm_memory::{GuestAddress, GuestMemory, MemoryRegion}; +use vmm_vhost::vhost_user::message::{ + VhostUserConfigFlags, VhostUserMemoryRegion, VhostUserProtocolFeatures, + VhostUserSingleMemoryRegion, VhostUserVirtioFeatures, VhostUserVringAddrFlags, + VhostUserVringState, +}; +use vmm_vhost::vhost_user::{ + Error as VhostError, Listener, Result as VhostResult, SlaveFsCacheReq, SlaveListener, + VhostUserSlaveReqHandlerMut, +}; + +/// An event to deliver an interrupt to the guest. +/// +/// Unlike `devices::Interrupt`, this doesn't support interrupt status and signal resampling. +// TODO(b/187487351): To avoid sending unnecessary events, we might want to support interrupt +// status. For this purpose, we need a mechanism to share interrupt status between the vmm and the +// device process. +pub struct CallEvent(Event); + +impl SignalableInterrupt for CallEvent { + fn signal(&self, _vector: u16, _interrupt_status_mask: u32) { + self.0.write(1).unwrap(); + } + + fn signal_config_changed(&self) {} // TODO(dgreid) + + fn get_resample_evt(&self) -> Option<&Event> { + None + } + + fn do_interrupt_resample(&self) {} +} + +/// Keeps a mapping from the vmm's virtual addresses to guest addresses. +/// Used to translate messages from the vmm to guest offsets.
+#[derive(Default)] +struct MappingInfo { + vmm_addr: u64, + guest_phys: u64, + size: u64, +} + +fn vmm_va_to_gpa(maps: &[MappingInfo], vmm_va: u64) -> VhostResult { + for map in maps { + if vmm_va >= map.vmm_addr && vmm_va < map.vmm_addr + map.size { + return Ok(GuestAddress(vmm_va - map.vmm_addr + map.guest_phys)); + } + } + Err(VhostError::InvalidMessage) +} + +/// Trait for vhost-user backend. +pub trait VhostUserBackend +where + Self: Sized, + Self::EventToken: PollToken + std::fmt::Debug, + Self::Error: std::error::Error + std::fmt::Debug, +{ + const MAX_QUEUE_NUM: usize; + const MAX_VRING_NUM: usize; + + /// Types of tokens that can be associated with polling events. + type EventToken; + + /// Error type specific to this backend. + type Error; + + /// Translates a queue's index into `EventToken`. + fn index_to_event_type(queue_index: usize) -> Option; + + /// The set of feature bits that this backend supports. + fn features(&self) -> u64; + + /// Acknowledges that this set of features should be enabled. + fn ack_features(&mut self, value: u64) -> std::result::Result<(), Self::Error>; + + /// Returns the set of enabled features. + fn acked_features(&self) -> u64; + + /// The set of protocol feature bits that this backend supports. + fn protocol_features(&self) -> VhostUserProtocolFeatures; + + /// Acknowledges that this set of protocol features should be enabled. + fn ack_protocol_features(&mut self, _value: u64) -> std::result::Result<(), Self::Error>; + + /// Returns the set of enabled protocol features. + fn acked_protocol_features(&self) -> u64; + + /// Reads this device configuration space at `offset`. + fn read_config(&self, offset: u64, dst: &mut [u8]); + + /// Sets guest memory regions. + fn set_guest_mem(&mut self, mem: GuestMemory); + + /// Returns a backend event to be waited for. + fn backend_event(&self) -> Option<(&dyn AsRawDescriptor, EventType, Self::EventToken)>; + + /// Processes a given event. 
+ fn handle_event( + &mut self, + wait_ctx: &Rc>>, + event: &Self::EventToken, + vrings: &[Rc>], + ) -> std::result::Result<(), Self::Error>; + + /// Resets the vhost-user backend. + fn reset(&mut self); +} + +/// A virtio ring entry. +pub struct Vring { + pub queue: Queue, + pub call_evt: Option>, + pub kick_evt: Option, + pub enabled: bool, +} + +impl Vring { + fn new(max_size: u16) -> Self { + Self { + queue: Queue::new(max_size), + call_evt: None, + kick_evt: None, + enabled: false, + } + } + + fn reset(&mut self) { + self.queue.reset(); + self.call_evt = None; + self.kick_evt = None; + self.enabled = false; + } +} + +#[sorted] +#[derive(ThisError, Debug)] +pub enum HandlerError { + /// Failed to accept an incoming connection. + #[error("failed to accept an incoming connection: {0}")] + AcceptConnection(VhostError), + /// Failed to create a connection listener. + #[error("failed to create a connection listener: {0}")] + CreateConnectionListener(VhostError), + /// Failed to create a UNIX domain socket listener. + #[error("failed to create a UNIX domain socket listener: {0}")] + CreateSocketListener(VhostError), + /// Failed to handle a backend event. + #[error("failed to handle a backend event: {0}")] + HandleBackendEvent(BackendError), + /// Failed to handle a vhost-user request. + #[error("failed to handle a vhost-user request: {0}")] + HandleVhostUserRequest(VhostError), + /// Invalid queue index is given. + #[error("invalid queue index is given: {index}")] + InvalidQueueIndex { index: usize }, + /// Failed to add new FD(s) to wait context. + #[error("failed to add new FD(s) to wait context: {0}")] + WaitContextAdd(base::Error), + /// Failed to create a wait context. + #[error("failed to create a wait context: {0}")] + WaitContextCreate(base::Error), + /// Failed to delete a FD from wait context. + #[error("failed to delete a FD from wait context: {0}")] + WaitContextDel(base::Error), + /// Failed to wait for event. 
+ #[error("failed to wait for an event triggered: {0}")] + WaitContextWait(base::Error), +} + +type HandlerResult = std::result::Result::Error>>; + +#[derive(Debug)] +pub enum HandlerPollToken { + BackendToken(B::EventToken), + VhostUserRequest, +} + +impl PollToken for HandlerPollToken { + fn as_raw_token(&self) -> u64 { + match self { + Self::BackendToken(t) => t.as_raw_token(), + Self::VhostUserRequest => u64::MAX, + } + } + + fn from_raw_token(data: u64) -> Self { + match data { + u64::MAX => Self::VhostUserRequest, + _ => Self::BackendToken(B::EventToken::from_raw_token(data)), + } + } +} + +/// Structure to have an event loop for interaction between a VMM and `VhostUserBackend`. +pub struct DeviceRequestHandler +where + B: 'static + VhostUserBackend, +{ + owned: bool, + vrings: Vec>>, + vmm_maps: Option>, + backend: Rc>, + wait_ctx: Rc>>, +} + +impl DeviceRequestHandler +where + B: 'static + VhostUserBackend, +{ + /// Creates the handler instance for `backend`. + pub fn new(backend: B) -> HandlerResult { + let mut vrings = Vec::with_capacity(B::MAX_QUEUE_NUM as usize); + for _ in 0..B::MAX_QUEUE_NUM { + vrings.push(Rc::new(RefCell::new(Vring::new(B::MAX_VRING_NUM as u16)))); + } + + let wait_ctx: WaitContext> = + WaitContext::new().map_err(HandlerError::WaitContextCreate)?; + + if let Some((evt, typ, token)) = backend.backend_event() { + wait_ctx + .add_for_event(evt, typ, HandlerPollToken::BackendToken(token)) + .map_err(HandlerError::WaitContextAdd)?; + } + + Ok(DeviceRequestHandler { + owned: false, + vmm_maps: None, + vrings, + backend: Rc::new(RefCell::new(backend)), + wait_ctx: Rc::new(wait_ctx), + }) + } + + /// Connects to `socket` and starts an event loop which handles incoming vhost-user requests from + /// the VMM and events from the backend. + // TODO(keiichiw): Remove the clippy annotation once we uprev clippy to 1.52.0 or later. + // cf. 
https://github.com/rust-lang/rust-clippy/issues/6546 + #[allow(clippy::clippy::result_unit_err)] + pub fn start>(self, socket: P) -> HandlerResult { + let vrings = self.vrings.clone(); + let backend = self.backend.clone(); + let wait_ctx = self.wait_ctx.clone(); + + let listener = Listener::new(socket, true).map_err(HandlerError::CreateSocketListener)?; + let mut s_listener = SlaveListener::new(listener, Arc::new(std::sync::Mutex::new(self))) + .map_err(HandlerError::CreateConnectionListener)?; + + let mut req_handler = s_listener + .accept() + .map_err(HandlerError::AcceptConnection)? + .expect("no incoming connection was detected"); + + let sd = SafeDescriptor::try_from(&req_handler as &dyn AsRawFd) + .expect("failed to get safe descriptor for handler"); + wait_ctx + .add(&sd, HandlerPollToken::VhostUserRequest) + .map_err(HandlerError::WaitContextAdd)?; + + loop { + let events = wait_ctx.wait().map_err(HandlerError::WaitContextWait)?; + for event in events.iter() { + match &event.token { + HandlerPollToken::BackendToken(token) => { + backend + .borrow_mut() + .handle_event(&wait_ctx, &token, &vrings) + .map_err(HandlerError::HandleBackendEvent)?; + } + HandlerPollToken::VhostUserRequest => { + req_handler + .handle_request() + .map_err(HandlerError::HandleVhostUserRequest)?; + } + } + } + } + } + + fn register_kickfd(&self, index: usize, event: &Event) -> HandlerResult { + let token = + B::index_to_event_type(index).ok_or(HandlerError::InvalidQueueIndex { index })?; + self.wait_ctx + .add(&event.0, HandlerPollToken::BackendToken(token)) + .map_err(HandlerError::WaitContextAdd) + } + + fn unregister_kickfd(&self, event: &Event) -> HandlerResult { + self.wait_ctx + .delete(&event.0) + .map_err(HandlerError::WaitContextDel) + } +} + +impl VhostUserSlaveReqHandlerMut for DeviceRequestHandler { + fn set_owner(&mut self) -> VhostResult<()> { + if self.owned { + return Err(VhostError::InvalidOperation); + } + self.owned = true; + Ok(()) + } + + fn reset_owner(&mut 
self) -> VhostResult<()> { + self.owned = false; + self.backend.borrow_mut().reset(); + Ok(()) + } + + fn get_features(&mut self) -> VhostResult { + let features = self.backend.borrow().features(); + Ok(features) + } + + fn set_features(&mut self, features: u64) -> VhostResult<()> { + if !self.owned { + return Err(VhostError::InvalidOperation); + } + + if (features & !(self.backend.borrow().features())) != 0 { + return Err(VhostError::InvalidParam); + } + + if let Err(e) = self.backend.borrow_mut().ack_features(features) { + error!("failed to acknowledge features 0x{:x}: {}", features, e); + return Err(VhostError::InvalidOperation); + } + + // If VHOST_USER_F_PROTOCOL_FEATURES has not been negotiated, the ring is initialized in an + // enabled state. + // If VHOST_USER_F_PROTOCOL_FEATURES has been negotiated, the ring is initialized in a + // disabled state. + // Client must not pass data to/from the backend until ring is enabled by + // VHOST_USER_SET_VRING_ENABLE with parameter 1, or after it has been disabled by + // VHOST_USER_SET_VRING_ENABLE with parameter 0. 
+ let acked_features = self.backend.borrow().acked_features(); + let vring_enabled = VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() & acked_features != 0; + for v in &mut self.vrings { + let mut vring = v.borrow_mut(); + vring.enabled = vring_enabled; + } + + Ok(()) + } + + fn get_protocol_features(&mut self) -> VhostResult { + Ok(self.backend.borrow().protocol_features()) + } + + fn set_protocol_features(&mut self, features: u64) -> VhostResult<()> { + if let Err(e) = self.backend.borrow_mut().ack_protocol_features(features) { + error!("failed to set protocol features 0x{:x}: {}", features, e); + return Err(VhostError::InvalidOperation); + } + Ok(()) + } + + fn set_mem_table( + &mut self, + contexts: &[VhostUserMemoryRegion], + fds: &[RawFd], + ) -> VhostResult<()> { + if fds.len() != contexts.len() { + return Err(VhostError::InvalidParam); + } + + let mut regions = Vec::with_capacity(fds.len()); + for (region, &fd) in contexts.iter().zip(fds.iter()) { + let rd = base::validate_raw_descriptor(fd).map_err(|e| { + error!("invalid fd is given: {}", e); + VhostError::InvalidParam + })?; + // Safe because we verified that we are the unique owner of `rd`. 
+ let sd = unsafe { SafeDescriptor::from_raw_descriptor(rd) }; + + let region = MemoryRegion::new( + region.memory_size, + GuestAddress(region.guest_phys_addr), + region.mmap_offset, + Arc::new(SharedMemory::from_safe_descriptor(sd).unwrap()), + ) + .map_err(|e| { + error!("failed to create a memory region: {}", e); + VhostError::InvalidOperation + })?; + regions.push(region); + } + let guest_mem = GuestMemory::from_regions(regions).map_err(|e| { + error!("failed to create guest memory: {}", e); + VhostError::InvalidOperation + })?; + + let vmm_maps = contexts + .iter() + .map(|region| MappingInfo { + vmm_addr: region.user_addr, + guest_phys: region.guest_phys_addr, + size: region.memory_size, + }) + .collect(); + + self.backend.borrow_mut().set_guest_mem(guest_mem); + + self.vmm_maps = Some(vmm_maps); + Ok(()) + } + + fn get_queue_num(&mut self) -> VhostResult { + Ok(self.vrings.len() as u64) + } + + fn set_vring_num(&mut self, index: u32, num: u32) -> VhostResult<()> { + if index as usize >= self.vrings.len() || num == 0 || num as usize > B::MAX_VRING_NUM { + return Err(VhostError::InvalidParam); + } + let mut vring = self.vrings[index as usize].borrow_mut(); + vring.queue.size = num as u16; + + Ok(()) + } + + fn set_vring_addr( + &mut self, + index: u32, + _flags: VhostUserVringAddrFlags, + descriptor: u64, + used: u64, + available: u64, + _log: u64, + ) -> VhostResult<()> { + if index as usize >= self.vrings.len() { + return Err(VhostError::InvalidParam); + } + + let vmm_maps = self.vmm_maps.as_ref().ok_or(VhostError::InvalidParam)?; + let mut vring = self.vrings[index as usize].borrow_mut(); + vring.queue.desc_table = vmm_va_to_gpa(&vmm_maps, descriptor)?; + vring.queue.avail_ring = vmm_va_to_gpa(&vmm_maps, available)?; + vring.queue.used_ring = vmm_va_to_gpa(&vmm_maps, used)?; + + Ok(()) + } + + fn set_vring_base(&mut self, index: u32, base: u32) -> VhostResult<()> { + if index as usize >= self.vrings.len() || base as usize >= B::MAX_VRING_NUM { + return 
Err(VhostError::InvalidParam); + } + + let mut vring = self.vrings[index as usize].borrow_mut(); + vring.queue.next_avail = Wrapping(base as u16); + vring.queue.next_used = Wrapping(base as u16); + + Ok(()) + } + + fn get_vring_base(&mut self, index: u32) -> VhostResult { + if index as usize >= self.vrings.len() { + return Err(VhostError::InvalidParam); + } + + // Quotation from vhost-user spec: + // Client must start ring upon receiving a kick (that is, detecting + // that file descriptor is readable) on the descriptor specified by + // VHOST_USER_SET_VRING_KICK, and stop ring upon receiving + // VHOST_USER_GET_VRING_BASE. + let mut vring = self.vrings[index as usize].borrow_mut(); + vring.reset(); + if let Some(kick) = &vring.kick_evt { + self.unregister_kickfd(kick).expect("unregister_kickfd"); + } + + Ok(VhostUserVringState::new( + index, + vring.queue.next_avail.0 as u32, + )) + } + + fn set_vring_kick(&mut self, index: u8, fd: Option) -> VhostResult<()> { + if index as usize >= self.vrings.len() { + return Err(VhostError::InvalidParam); + } + + if let Some(fd) = fd { + // TODO(b/186625058): The current code returns an error when `FD_CLOEXEC` is already + // set, which is not harmful. Once we update the vhost crate's API to pass around `File` + // instead of `RawFd`, we won't need this validation. + let rd = base::validate_raw_descriptor(fd).map_err(|e| { + error!("invalid fd is given: {}", e); + VhostError::InvalidParam + })?; + // Safe because the FD is now owned. 
+ let kick = unsafe { Event::from_raw_descriptor(rd) }; + + self.register_kickfd(index as usize, &kick) + .expect("register_kickfd"); + + let mut vring = self.vrings[index as usize].borrow_mut(); + vring.kick_evt = Some(kick); + vring.queue.ready = true; + } + Ok(()) + } + + fn set_vring_call(&mut self, index: u8, fd: Option) -> VhostResult<()> { + if index as usize >= self.vrings.len() { + return Err(VhostError::InvalidParam); + } + + if let Some(fd) = fd { + let rd = base::validate_raw_descriptor(fd).map_err(|e| { + error!("invalid fd is given: {}", e); + VhostError::InvalidParam + })?; + // Safe because the FD is now owned. + let call = unsafe { Event::from_raw_descriptor(rd) }; + self.vrings[index as usize].borrow_mut().call_evt = Some(Arc::new(CallEvent(call))); + } + + Ok(()) + } + + fn set_vring_err(&mut self, _index: u8, _fd: Option) -> VhostResult<()> { + // TODO + Ok(()) + } + + fn set_vring_enable(&mut self, index: u32, enable: bool) -> VhostResult<()> { + if index as usize >= self.vrings.len() { + return Err(VhostError::InvalidParam); + } + + // This request should be handled only when VHOST_USER_F_PROTOCOL_FEATURES + // has been negotiated. + if self.backend.borrow().acked_features() + & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() + == 0 + { + return Err(VhostError::InvalidOperation); + } + + // Slave must not pass data to/from the backend until ring is + // enabled by VHOST_USER_SET_VRING_ENABLE with parameter 1, + // or after it has been disabled by VHOST_USER_SET_VRING_ENABLE + // with parameter 0. 
+ let mut vring = self.vrings[index as usize].borrow_mut(); + vring.enabled = enable; + + Ok(()) + } + + fn get_config( + &mut self, + offset: u32, + size: u32, + _flags: VhostUserConfigFlags, + ) -> VhostResult> { + if offset >= size { + return Err(VhostError::InvalidParam); + } + + let mut data = vec![0; size as usize]; + self.backend + .borrow() + .read_config(u64::from(offset), &mut data); + Ok(data) + } + + fn set_config( + &mut self, + _offset: u32, + _buf: &[u8], + _flags: VhostUserConfigFlags, + ) -> VhostResult<()> { + // TODO + Ok(()) + } + + fn set_slave_req_fd(&mut self, _vu_req: SlaveFsCacheReq) { + // TODO + } + + fn get_max_mem_slots(&mut self) -> VhostResult { + //TODO + Ok(0) + } + + fn add_mem_region( + &mut self, + _region: &VhostUserSingleMemoryRegion, + _fd: RawFd, + ) -> VhostResult<()> { + //TODO + Ok(()) + } + + fn remove_mem_region(&mut self, _region: &VhostUserSingleMemoryRegion) -> VhostResult<()> { + //TODO + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use std::sync::mpsc::channel; + use std::sync::Barrier; + + use data_model::DataInit; + use devices::virtio::vhost::user::VhostUserHandler; + use tempfile::{Builder, TempDir}; + use vmm_vhost::vhost_user::Master; + + #[derive(PollToken, Debug)] + enum FakeToken { + Queue0, + } + + #[derive(ThisError, Debug)] + enum FakeError { + #[error("invalid features are given: 0x{features:x}")] + InvalidFeatures { features: u64 }, + #[error("invalid protocol features are given: 0x{features:x}")] + InvalidProtocolFeatures { features: u64 }, + } + + #[derive(Clone, Copy, Debug, PartialEq, Eq)] + #[repr(C)] + struct FakeConfig { + x: u32, + y: u64, + } + + unsafe impl DataInit for FakeConfig {} + + const FAKE_CONFIG_DATA: FakeConfig = FakeConfig { x: 1, y: 2 }; + + struct FakeBackend { + mem: Option, + avail_features: u64, + acked_features: u64, + acked_protocol_features: VhostUserProtocolFeatures, + } + + impl FakeBackend { + fn new() -> Self { + Self { + mem: None, + avail_features: 
VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits(), + acked_features: 0, + acked_protocol_features: VhostUserProtocolFeatures::empty(), + } + } + } + + impl VhostUserBackend for FakeBackend { + const MAX_QUEUE_NUM: usize = 16; + const MAX_VRING_NUM: usize = 256; + + type EventToken = FakeToken; + type Error = FakeError; + + fn index_to_event_type(queue_index: usize) -> Option { + match queue_index { + 0 => Some(FakeToken::Queue0), + _ => None, + } + } + + fn features(&self) -> u64 { + self.avail_features + } + + fn ack_features(&mut self, value: u64) -> std::result::Result<(), Self::Error> { + let unrequested_features = value & !self.avail_features; + if unrequested_features != 0 { + return Err(FakeError::InvalidFeatures { + features: unrequested_features, + }); + } + self.acked_features |= value; + Ok(()) + } + + fn acked_features(&self) -> u64 { + self.acked_features + } + + fn set_guest_mem(&mut self, mem: GuestMemory) { + self.mem = Some(mem); + } + + fn protocol_features(&self) -> VhostUserProtocolFeatures { + VhostUserProtocolFeatures::CONFIG + } + + fn ack_protocol_features(&mut self, features: u64) -> std::result::Result<(), Self::Error> { + let features = VhostUserProtocolFeatures::from_bits(features) + .ok_or(FakeError::InvalidProtocolFeatures { features })?; + let supported = self.protocol_features(); + self.acked_protocol_features = features & supported; + Ok(()) + } + + fn acked_protocol_features(&self) -> u64 { + self.acked_protocol_features.bits() + } + + fn backend_event(&self) -> Option<(&dyn AsRawDescriptor, EventType, Self::EventToken)> { + None + } + + fn handle_event( + &mut self, + _wait_ctx: &Rc>>, + _event: &Self::EventToken, + _vrings: &[Rc>], + ) -> std::result::Result<(), Self::Error> { + Ok(()) + } + + fn read_config(&self, offset: u64, dst: &mut [u8]) { + dst.copy_from_slice(&FAKE_CONFIG_DATA.as_slice()[offset as usize..]); + } + + fn reset(&mut self) {} + } + + fn temp_dir() -> TempDir { + 
Builder::new().prefix("/tmp/vhost_test").tempdir().unwrap() + } + + #[test] + fn test_vhost_user_activate() { + use vmm_vhost::vhost_user::{Listener, SlaveListener}; + + const QUEUES_NUM: usize = 2; + + let dir = temp_dir(); + let mut path = dir.path().to_owned(); + path.push("sock"); + let listener = Listener::new(&path, true).unwrap(); + + let vmm_bar = Arc::new(Barrier::new(2)); + let dev_bar = vmm_bar.clone(); + + let (tx, rx) = channel(); + + std::thread::spawn(move || { + // VMM side + rx.recv().unwrap(); // Ensure the device is ready. + + let vu = Master::connect(&path, 1).unwrap(); + let allow_features = VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits(); + let init_features = VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits(); + let allow_protocol_features = VhostUserProtocolFeatures::CONFIG; + let mut vmm_handler = + VhostUserHandler::new(vu, allow_features, init_features, allow_protocol_features) + .unwrap(); + + println!("read_config"); + let mut buf = vec![0; std::mem::size_of::()]; + vmm_handler.read_config::(0, &mut buf).unwrap(); + // Check if the obtained config data is correct. + let config = FakeConfig::from_slice(&buf).unwrap(); + assert_eq!(*config, FAKE_CONFIG_DATA); + + println!("set_mem_table"); + let mem = GuestMemory::new(&vec![(GuestAddress(0x0), 0x10000)]).unwrap(); + vmm_handler.set_mem_table(&mem).unwrap(); + + for idx in 0..QUEUES_NUM { + println!("activate_mem_table: queue_index={}", idx); + let queue = Queue::new(0x10); + let queue_evt = Event::new().unwrap(); + let irqfd = Event::new().unwrap(); + + vmm_handler + .activate_vring(&mem, 0, &queue, &queue_evt, &irqfd) + .unwrap(); + } + + vmm_bar.wait(); + }); + + // Device side + let handler = Arc::new(std::sync::Mutex::new( + DeviceRequestHandler::new(FakeBackend::new()).unwrap(), + )); + let mut listener = SlaveListener::new(listener, handler).unwrap(); + + // Notify listener is ready. 
+ tx.send(()).unwrap(); + + let mut listener = listener.accept().unwrap().unwrap(); + + // VhostUserHandler::new() + listener.handle_request().expect("set_owner"); + listener.handle_request().expect("get_features"); + listener.handle_request().expect("set_features"); + listener.handle_request().expect("get_protocol_features"); + listener.handle_request().expect("set_protocol_features"); + + // VhostUserHandler::read_config() + listener.handle_request().expect("get_config"); + + // VhostUserHandler::set_mem_table() + listener.handle_request().expect("set_mem_table"); + + for _ in 0..QUEUES_NUM { + // VhostUserHandler::activate_vring() + listener.handle_request().expect("set_vring_num"); + listener.handle_request().expect("set_vring_addr"); + listener.handle_request().expect("set_vring_base"); + listener.handle_request().expect("set_vring_call"); + listener.handle_request().expect("set_vring_kick"); + listener.handle_request().expect("set_vring_enable"); + } + + dev_bar.wait(); + } +} diff --git a/vm_memory/src/guest_memory.rs b/vm_memory/src/guest_memory.rs index 62ba134e7a..803c167508 100644 --- a/vm_memory/src/guest_memory.rs +++ b/vm_memory/src/guest_memory.rs @@ -11,18 +11,18 @@ use std::mem::size_of; use std::result; use std::sync::Arc; -use crate::guest_address::GuestAddress; use base::{pagesize, Error as SysError}; use base::{ AsRawDescriptor, AsRawDescriptors, MappedRegion, MemfdSeals, MemoryMapping, - MemoryMappingBuilder, MemoryMappingUnix, MmapError, RawDescriptor, SharedMemory, - SharedMemoryUnix, + MemoryMappingBuilder, MemoryMappingBuilderUnix, MemoryMappingUnix, MmapError, RawDescriptor, + SharedMemory, SharedMemoryUnix, }; +use bitflags::bitflags; use cros_async::{mem, BackingMemory}; use data_model::volatile_memory::*; use data_model::DataInit; -use bitflags::bitflags; +use crate::guest_address::GuestAddress; #[derive(Debug)] pub enum Error { @@ -32,7 +32,7 @@ pub enum Error { MemoryAccess(GuestAddress, MmapError), MemoryMappingFailed(MmapError), 
MemoryRegionOverlap, - MemoryRegionTooLarge(u64), + MemoryRegionTooLarge(u128), MemoryNotAligned, MemoryCreationFailed(SysError), MemoryAddSealsFailed(SysError), @@ -93,7 +93,10 @@ bitflags! { } } -struct MemoryRegion { +/// A region of memory-mapped memory. +/// Holds the memory mapping with its offset in guest memory. +/// Also holds the backing fd for the mapping and the offset in that fd of the mapping. +pub struct MemoryRegion { mapping: MemoryMapping, guest_base: GuestAddress, shm_offset: u64, @@ -101,6 +104,27 @@ struct MemoryRegion { } impl MemoryRegion { + /// Creates a new MemoryRegion using the given SharedMemory object to later be attached to a VM + /// at `guest_base` address in the guest. + pub fn new( + size: u64, + guest_base: GuestAddress, + shm_offset: u64, + shm: Arc, + ) -> Result { + let mapping = MemoryMappingBuilder::new(size as usize) + .from_descriptor(shm.as_ref()) + .offset(shm_offset) + .build() + .map_err(Error::MemoryMappingFailed)?; + Ok(MemoryRegion { + mapping, + guest_base, + shm_offset, + shm, + }) + } + fn start(&self) -> GuestAddress { self.guest_base } @@ -115,8 +139,8 @@ impl MemoryRegion { } } -/// Tracks a memory region and where it is mapped in the guest, along with a shm -/// fd of the underlying memory regions. +/// Tracks memory regions and where they are mapped in the guest, along with shm +/// fds of the underlying memory regions. #[derive(Clone)] pub struct GuestMemory { regions: Arc<[MemoryRegion]>, @@ -178,8 +202,8 @@ impl GuestMemory { } } - let size = - usize::try_from(range.1).map_err(|_| Error::MemoryRegionTooLarge(range.1))?; + let size = usize::try_from(range.1) + .map_err(|_| Error::MemoryRegionTooLarge(range.1 as u128))?; let mapping = MemoryMappingBuilder::new(size) .from_shared_memory(shm.as_ref()) .offset(offset) @@ -200,6 +224,34 @@ impl GuestMemory { }) } + /// Creates a `GuestMemory` from a collection of MemoryRegions.
+ pub fn from_regions(mut regions: Vec) -> Result { + // Sort the regions and ensure they do not overlap. + regions.sort_by(|a, b| a.guest_base.cmp(&b.guest_base)); + + if regions.len() > 1 { + let mut prev_end = regions[0] + .guest_base + .checked_add(regions[0].mapping.size() as u64) + .ok_or(Error::MemoryRegionOverlap)?; + for region in &regions[1..] { + if prev_end > region.guest_base { + return Err(Error::MemoryRegionOverlap); + } + prev_end = region + .guest_base + .checked_add(region.mapping.size() as u64) + .ok_or(Error::MemoryRegionTooLarge( + region.guest_base.0 as u128 + region.mapping.size() as u128, + ))?; + } + } + + Ok(GuestMemory { + regions: Arc::from(regions), + }) + } + /// Returns the end address of memory. /// /// # Examples