diff --git a/devices/src/lib.rs b/devices/src/lib.rs index 865cdeabba..3fd6795729 100644 --- a/devices/src/lib.rs +++ b/devices/src/lib.rs @@ -69,7 +69,7 @@ pub use self::usb::host_backend::host_backend_device_provider::HostBackendDevice #[cfg(feature = "usb")] pub use self::usb::xhci::xhci_controller::XhciController; pub use self::vfio::{VfioContainer, VfioDevice}; -pub use self::virtio::VirtioPciDevice; +pub use self::virtio::{vfio_wrapper, VirtioPciDevice}; /// Request CoIOMMU to unpin a specific range. use serde::{Deserialize, Serialize}; diff --git a/devices/src/vfio.rs b/devices/src/vfio.rs index 91a69f508e..620ead9747 100644 --- a/devices/src/vfio.rs +++ b/devices/src/vfio.rs @@ -34,6 +34,8 @@ use vm_memory::GuestMemory; pub enum VfioError { #[error("failed to borrow global vfio container")] BorrowVfioContainer, + #[error("failed to duplicate VfioContainer")] + ContainerDupError, #[error("failed to set container's IOMMU driver type as VfioType1V2: {0}")] ContainerSetIOMMU(Error), #[error("failed to create KVM vfio device: {0}")] @@ -140,6 +142,21 @@ impl VfioContainer { Self::new_inner(false /* host_iommu */) } + // Construct a VfioContainer from an exist container file. + pub fn new_from_container(container: File) -> Result { + // Safe as file is vfio container descriptor and ioctl is defined by kernel. + let version = unsafe { ioctl(&container, VFIO_GET_API_VERSION()) }; + if version as u8 != VFIO_API_VERSION { + return Err(VfioError::VfioApiVersion); + } + + Ok(VfioContainer { + container, + groups: HashMap::new(), + host_iommu: true, + }) + } + fn is_group_set(&self, group_id: u32) -> bool { self.groups.get(&group_id).is_some() } @@ -310,6 +327,15 @@ impl VfioContainer { self.groups.remove(&id); } } + + pub fn into_raw_descriptor(&self) -> Result { + let raw_descriptor = unsafe { libc::dup(self.container.as_raw_descriptor()) }; + if raw_descriptor < 0 { + Err(VfioError::ContainerDupError) + } else { + Ok(raw_descriptor) + } + } } impl AsRawDescriptor for VfioContainer { diff --git a/devices/src/virtio/iommu.rs b/devices/src/virtio/iommu.rs index 6e993e4603..c73e37193f 100644 --- a/devices/src/virtio/iommu.rs +++ b/devices/src/virtio/iommu.rs @@ -4,6 +4,7 @@ use std::cell::RefCell; use std::collections::BTreeMap; +use std::fs::File; use std::io::{self, Write}; use std::mem::size_of; use std::ops::RangeInclusive; @@ -23,7 +24,9 @@ use futures::{select, FutureExt}; use remain::sorted; use sync::Mutex; use thiserror::Error; -use vm_control::VirtioIOMMURequest; +use vm_control::{ + VirtioIOMMURequest, VirtioIOMMUResponse, VirtioIOMMUVfioCommand, VirtioIOMMUVfioResult, +}; use vm_memory::{GuestAddress, GuestMemory, GuestMemoryError}; use crate::pci::PciAddress; @@ -31,6 +34,7 @@ use crate::virtio::{ async_utils, copy_config, DescriptorChain, DescriptorError, Interrupt, Queue, Reader, SignalableInterrupt, VirtioDevice, Writer, TYPE_IOMMU, }; +use crate::VfioContainer; pub mod protocol; use crate::virtio::iommu::protocol::*; @@ -41,6 +45,8 @@ pub mod memory_util; pub mod vfio_wrapper; use crate::virtio::iommu::memory_mapper::{Error as MemoryMapperError, *}; +use self::vfio_wrapper::VfioWrapper; + const QUEUE_SIZE: u16 = 256; const NUM_QUEUES: usize = 2; const QUEUE_SIZES: &[u16] = &[QUEUE_SIZE; NUM_QUEUES]; @@ -138,6 +144,9 @@ pub enum IommuError { struct Worker { mem: GuestMemory, page_mask: u64, + // Hot-pluggable PCI endpoints ranges + // RangeInclusive: (start endpoint PCI address .. =end endpoint PCI address) + hp_endpoints_ranges: Vec>, // All PCI endpoints that attach to certain IOMMU domain // key: endpoint PCI address // value: attached domain ID @@ -477,12 +486,97 @@ impl Worker { } } + fn handle_add_vfio_device( + mem: &GuestMemory, + endpoint_addr: u32, + container_fd: File, + endpoints: &Rc>>>>>, + hp_endpoints_ranges: &Rc>>, + ) -> VirtioIOMMUVfioResult { + let exists = |endpoint_addr: u32| -> bool { + for endpoints_range in hp_endpoints_ranges.iter() { + if endpoints_range.contains(&endpoint_addr) { + return true; + } + } + false + }; + + if !exists(endpoint_addr) { + return VirtioIOMMUVfioResult::NotInPCIRanges; + } + + let vfio_container = match VfioContainer::new_from_container(container_fd) { + Ok(vfio_container) => vfio_container, + Err(e) => { + error!("failed to verify the new container: {}", e); + return VirtioIOMMUVfioResult::NoAvailableContainer; + } + }; + endpoints.borrow_mut().insert( + endpoint_addr, + Arc::new(Mutex::new(Box::new(VfioWrapper::new( + Arc::new(Mutex::new(vfio_container)), + mem.clone(), + )))), + ); + VirtioIOMMUVfioResult::Ok + } + + fn handle_del_vfio_device( + pci_address: u32, + endpoints: &Rc>>>>>, + ) -> VirtioIOMMUVfioResult { + if endpoints.borrow_mut().remove(&pci_address).is_none() { + error!("There is no vfio container of {}", pci_address); + return VirtioIOMMUVfioResult::NoSuchDevice; + } + VirtioIOMMUVfioResult::Ok + } + + fn handle_vfio( + mem: &GuestMemory, + vfio_cmd: VirtioIOMMUVfioCommand, + endpoints: &Rc>>>>>, + hp_endpoints_ranges: &Rc>>, + ) -> VirtioIOMMUResponse { + use VirtioIOMMUVfioCommand::*; + let vfio_result = match vfio_cmd { + VfioDeviceAdd { + endpoint_addr, + container, + } => Self::handle_add_vfio_device( + mem, + endpoint_addr, + container, + endpoints, + hp_endpoints_ranges, + ), + VfioDeviceDel { endpoint_addr } => { + Self::handle_del_vfio_device(endpoint_addr, endpoints) + } + }; + VirtioIOMMUResponse::VfioResponse(vfio_result) + } + // Async task that handles messages from the host - pub async fn handle_command_tube(command_tube: &AsyncTube) -> Result<()> { + async fn handle_command_tube( + mem: &GuestMemory, + command_tube: AsyncTube, + endpoints: &Rc>>>>>, + hp_endpoints_ranges: &Rc>>, + ) -> Result<()> { loop { match command_tube.next::().await { - Ok(_) => { - // To-Do: handle the requests from virtio-iommu tube + Ok(command) => { + let response: VirtioIOMMUResponse = match command { + VirtioIOMMURequest::VfioCommand(vfio_cmd) => { + Self::handle_vfio(mem, vfio_cmd, endpoints, hp_endpoints_ranges) + } + }; + if let Err(e) = command_tube.send(&response) { + error!("{}", IommuError::VirtioIOMMUResponseError(e)); + } } Err(e) => { return Err(IommuError::VirtioIOMMUReqError(e)); @@ -513,6 +607,11 @@ impl Worker { let (req_queue, req_evt) = (queues.remove(0), evts_async.remove(0)); + let hp_endpoints_ranges = Rc::new(self.hp_endpoints_ranges.clone()); + let mem = Rc::new(self.mem.clone()); + // contains all pass-through endpoints that attach to this IOMMU device + // key: endpoint PCI address + // value: reference counter and MemoryMapperTrait let endpoints: Rc>>>>> = Rc::new(RefCell::new(endpoints)); @@ -541,7 +640,7 @@ impl Worker { let command_tube = iommu_device_tube.into_async_tube(&ex).unwrap(); // Future to handle command messages from host, such as passing vfio containers. - let f_cmd = Self::handle_command_tube(&command_tube); + let f_cmd = Self::handle_command_tube(&mem, command_tube, &endpoints, &hp_endpoints_ranges); let done = async { select! { @@ -608,6 +707,9 @@ pub struct Iommu { worker_thread: Option>, config: virtio_iommu_config, avail_features: u64, + // Attached endpoints + // key: endpoint PCI address + // value: reference counter and MemoryMapperTrait endpoints: BTreeMap>>>, // Hot-pluggable PCI endpoints ranges // RangeInclusive: (start endpoint PCI address .. =end endpoint PCI address) @@ -704,6 +806,10 @@ impl VirtioDevice for Iommu { rds.push(rx.as_raw_descriptor()); } + if let Some(iommu_device_tube) = &self.iommu_device_tube { + rds.push(iommu_device_tube.as_raw_descriptor()); + } + rds } @@ -749,6 +855,7 @@ impl VirtioDevice for Iommu { // granularity of IOMMU mappings let page_mask = (1u64 << u64::from(self.config.page_size_mask).trailing_zeros()) - 1; let eps = self.endpoints.clone(); + let hp_endpoints_ranges = self.hp_endpoints_ranges.to_owned(); let translate_response_senders = self.translate_response_senders.take(); let translate_request_rx = self.translate_request_rx.take(); @@ -761,6 +868,7 @@ impl VirtioDevice for Iommu { let mut worker = Worker { mem, page_mask, + hp_endpoints_ranges, endpoint_map: BTreeMap::new(), domain_map: BTreeMap::new(), }; diff --git a/devices/src/virtio/iommu/memory_mapper.rs b/devices/src/virtio/iommu/memory_mapper.rs index 7a851010fb..411dba9ed4 100644 --- a/devices/src/virtio/iommu/memory_mapper.rs +++ b/devices/src/virtio/iommu/memory_mapper.rs @@ -16,6 +16,7 @@ use thiserror::Error; use vm_memory::{GuestAddress, GuestMemoryError}; use crate::vfio::VfioError; +use crate::vfio_wrapper::VfioWrapper; #[repr(u8)] #[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)] @@ -100,6 +101,19 @@ pub trait MemoryMapper: Send { fn add_map(&mut self, new_map: MappingInfo) -> Result<()>; fn remove_map(&mut self, iova_start: u64, size: u64) -> Result<()>; fn get_mask(&self) -> Result; + + /// Trait for generic MemoryMapper abstraction, that is, all reside on MemoryMapper and want to + /// be converted back to its original type. Each must provide as_XXX_wrapper() + + /// as_XXX_wrapper_mut() + into_XXX_wrapper(), default impl methods return None. + fn as_vfio_wrapper(&self) -> Option<&VfioWrapper> { + None + } + fn as_vfio_wrapper_mut(&mut self) -> Option<&mut VfioWrapper> { + None + } + fn into_vfio_wrapper(self: Box) -> Option> { + None + } } pub trait Translate { diff --git a/devices/src/virtio/iommu/vfio_wrapper.rs b/devices/src/virtio/iommu/vfio_wrapper.rs index 83aa256c66..a25ab66f8b 100644 --- a/devices/src/virtio/iommu/vfio_wrapper.rs +++ b/devices/src/virtio/iommu/vfio_wrapper.rs @@ -24,6 +24,10 @@ impl VfioWrapper { pub fn new(container: Arc>, mem: GuestMemory) -> Self { Self { container, mem } } + + pub fn as_vfio_container(&self) -> Arc> { + self.container.clone() + } } impl MemoryMapper for VfioWrapper { @@ -70,6 +74,16 @@ impl MemoryMapper for VfioWrapper { .vfio_get_iommu_page_size_mask() .map_err(MemoryMapperError::Vfio) } + + fn as_vfio_wrapper(&self) -> Option<&VfioWrapper> { + Some(self) + } + fn as_vfio_wrapper_mut(&mut self) -> Option<&mut VfioWrapper> { + Some(self) + } + fn into_vfio_wrapper(self: Box) -> Option> { + Some(self) + } } impl Translate for VfioWrapper { diff --git a/src/linux/device_helpers.rs b/src/linux/device_helpers.rs index 2ce04f4016..53d58e4000 100644 --- a/src/linux/device_helpers.rs +++ b/src/linux/device_helpers.rs @@ -13,7 +13,12 @@ use std::path::{Path, PathBuf}; use std::str; use std::sync::Arc; +use crate::{ + Config, DiskOption, TouchDeviceOption, VhostUserFsOption, VhostUserOption, VhostUserWlOption, + VhostVsockDeviceParameter, VvuOption, +}; use anyhow::{anyhow, bail, Context, Result}; +use arch::{self, VirtioDeviceStub}; use base::*; use devices::serial_device::SerialParameters; use devices::vfio::{VfioCommonSetup, VfioCommonTrait}; @@ -42,12 +47,6 @@ use resources::{Alloc, MmioType, SystemAllocator}; use sync::Mutex; use vm_memory::GuestAddress; -use crate::{ - Config, DiskOption, TouchDeviceOption, VhostUserFsOption, VhostUserOption, VhostUserWlOption, - VhostVsockDeviceParameter, VvuOption, -}; -use arch::{self, VirtioDeviceStub}; - use super::jail_helpers::*; pub enum TaggedControlTube { diff --git a/src/linux/mod.rs b/src/linux/mod.rs index 8dddbac86d..8ea04f379b 100644 --- a/src/linux/mod.rs +++ b/src/linux/mod.rs @@ -1186,9 +1186,8 @@ where &mut devices, )?; - if !iommu_attached_endpoints.is_empty() { - let (_iommu_host_tube, iommu_device_tube) = - Tube::pair().context("failed to create tube")?; + let iommu_host_tube = if !iommu_attached_endpoints.is_empty() { + let (iommu_host_tube, iommu_device_tube) = Tube::pair().context("failed to create tube")?; let iommu_dev = create_iommu_device( &cfg, (1u64 << vm.get_guest_phys_addr_bits()) - 1, @@ -1208,7 +1207,10 @@ where .context("failed to allocate resources early for virtio pci dev")?; let dev = Box::new(dev); devices.push((dev, iommu_dev.jail)); - } + Some(iommu_host_tube) + } else { + None + }; #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] for device in devices @@ -1339,6 +1341,7 @@ where Arc::clone(&map_request), gralloc, kvm_vcpu_ids, + iommu_host_tube, ) } @@ -1359,6 +1362,7 @@ fn add_vfio_device( sys_allocator: &mut SystemAllocator, cfg: &Config, control_tubes: &mut Vec, + iommu_host_tube: &Option, vfio_path: &Path, ) -> Result<()> { let host_os_str = vfio_path @@ -1381,12 +1385,39 @@ fn add_vfio_device( Some(bus_num), &mut endpoints, None, - IommuDevType::NoIommu, + if iommu_host_tube.is_some() { + IommuDevType::VirtioIommu + } else { + IommuDevType::NoIommu + }, )?; let pci_address = Arch::register_pci_device(linux, vfio_pci_device, jail, sys_allocator) .context("Failed to configure pci hotplug device")?; + if let Some(iommu_host_tube) = iommu_host_tube { + let &endpoint_addr = endpoints.iter().next().unwrap().0; + let mapper = endpoints.remove(&endpoint_addr).unwrap(); + if let Some(vfio_wrapper) = mapper.lock().as_vfio_wrapper() { + let vfio_container = vfio_wrapper.as_vfio_container(); + let descriptor = vfio_container.lock().into_raw_descriptor()?; + let request = VirtioIOMMURequest::VfioCommand(VirtioIOMMUVfioCommand::VfioDeviceAdd { + endpoint_addr, + container: { + // Safe because the descriptor is uniquely owned by `descriptor`. + unsafe { File::from_raw_descriptor(descriptor) } + }, + }); + + match virtio_iommu_request(iommu_host_tube, &request) + .map_err(|_| VirtioIOMMUVfioError::SocketFailed)? + { + VirtioIOMMUResponse::VfioResponse(VirtioIOMMUVfioResult::Ok) => (), + resp => bail!("Unexpected message response: {:?}", resp), + } + }; + } + let host_os_str = vfio_path .file_name() .ok_or_else(|| anyhow!("failed to parse or find vfio path"))?; @@ -1404,6 +1435,7 @@ fn add_vfio_device( fn remove_vfio_device( linux: &RunnableLinuxVm, sys_allocator: &mut SystemAllocator, + iommu_host_tube: &Option, vfio_path: &Path, ) -> Result<()> { let host_os_str = vfio_path @@ -1417,6 +1449,19 @@ fn remove_vfio_device( for hp_bus in linux.hotplug_bus.iter() { let mut hp_bus_lock = hp_bus.lock(); if let Some(pci_addr) = hp_bus_lock.get_hotplug_device(host_key) { + if let Some(iommu_host_tube) = iommu_host_tube { + let request = + VirtioIOMMURequest::VfioCommand(VirtioIOMMUVfioCommand::VfioDeviceDel { + endpoint_addr: pci_addr.to_u32(), + }); + match virtio_iommu_request(iommu_host_tube, &request) + .map_err(|_| VirtioIOMMUVfioError::SocketFailed)? + { + VirtioIOMMUResponse::VfioResponse(VirtioIOMMUVfioResult::Ok) => (), + resp => bail!("Unexpected message response: {:?}", resp), + } + } + hp_bus_lock.hot_unplug(pci_addr); sys_allocator.release_pci(pci_addr.bus, pci_addr.dev, pci_addr.func); return Ok(()); @@ -1431,13 +1476,21 @@ fn handle_vfio_command( sys_allocator: &mut SystemAllocator, cfg: &Config, add_tubes: &mut Vec, + iommu_host_tube: &Option, vfio_path: &Path, add: bool, ) -> VmResponse { let ret = if add { - add_vfio_device(linux, sys_allocator, cfg, add_tubes, vfio_path) + add_vfio_device( + linux, + sys_allocator, + cfg, + add_tubes, + iommu_host_tube, + vfio_path, + ) } else { - remove_vfio_device(linux, sys_allocator, vfio_path) + remove_vfio_device(linux, sys_allocator, iommu_host_tube, vfio_path) }; match ret { @@ -1467,6 +1520,7 @@ fn run_control( map_request: Arc>>, mut gralloc: RutabagaGralloc, kvm_vcpu_ids: Vec, + iommu_host_tube: Option, ) -> Result { #[derive(PollToken)] enum Token { @@ -1746,6 +1800,7 @@ fn run_control( &mut sys_allocator, &cfg, &mut add_tubes, + &iommu_host_tube, &vfio_path, add, ) diff --git a/vm_control/src/lib.rs b/vm_control/src/lib.rs index 3e4d3a533a..86b9493053 100644 --- a/vm_control/src/lib.rs +++ b/vm_control/src/lib.rs @@ -25,6 +25,9 @@ use std::sync::{mpsc, Arc}; use std::thread::JoinHandle; +use remain::sorted; +use thiserror::Error; + use libc::{EINVAL, EIO, ENODEV, ENOTSUP}; use serde::{Deserialize, Serialize}; @@ -1191,5 +1194,123 @@ impl Display for VmResponse { } } +#[sorted] +#[derive(Error, Debug)] +pub enum VirtioIOMMUVfioError { + #[error("socket failed")] + SocketFailed, + #[error("unexpected response: {0}")] + UnexpectedResponse(VirtioIOMMUResponse), + #[error("unknown command: `{0}`")] + UnknownCommand(String), + #[error("{0}")] + VfioControl(VirtioIOMMUVfioResult), +} + #[derive(Serialize, Deserialize, Debug)] -pub enum VirtioIOMMURequest {} +pub enum VirtioIOMMUVfioCommand { + // Add the vfio device attached to virtio-iommu. + VfioDeviceAdd { + endpoint_addr: u32, + #[serde(with = "with_as_descriptor")] + container: File, + }, + // Delete the vfio device attached to virtio-iommu. + VfioDeviceDel { + endpoint_addr: u32, + }, +} + +#[derive(Serialize, Deserialize, Debug)] +pub enum VirtioIOMMUVfioResult { + Ok, + NotInPCIRanges, + NoAvailableContainer, + NoSuchDevice, +} + +impl Display for VirtioIOMMUVfioResult { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + use self::VirtioIOMMUVfioResult::*; + + match self { + Ok => write!(f, "successfully"), + NotInPCIRanges => write!(f, "not in the pci ranges of virtio-iommu"), + NoAvailableContainer => write!(f, "no available vfio container"), + NoSuchDevice => write!(f, "no such a vfio device"), + } + } +} + +/// A request to the virtio-iommu process to perform some operations. +/// +/// Unless otherwise noted, each request should expect a `VirtioIOMMUResponse::Ok` to be received on +/// success. +#[derive(Serialize, Deserialize, Debug)] +pub enum VirtioIOMMURequest { + /// Command for vfio related operations. + VfioCommand(VirtioIOMMUVfioCommand), +} + +/// Indication of success or failure of a `VirtioIOMMURequest`. +/// +/// Success is usually indicated `VirtioIOMMUResponse::Ok` unless there is data associated with the +/// response. +#[derive(Serialize, Deserialize, Debug)] +pub enum VirtioIOMMUResponse { + /// Indicates the request was executed successfully. + Ok, + /// Indicates the request encountered some error during execution. + Err(SysError), + /// Results for Vfio commands. + VfioResponse(VirtioIOMMUVfioResult), +} + +impl Display for VirtioIOMMUResponse { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + use self::VirtioIOMMUResponse::*; + match self { + Ok => write!(f, "ok"), + Err(e) => write!(f, "error: {}", e), + VfioResponse(result) => write!( + f, + "The vfio-related virtio-iommu request got result: {:?}", + result + ), + } + } +} + +/// Send VirtioIOMMURequest without waiting for the response +pub fn virtio_iommu_request_async( + iommu_control_tube: &Tube, + req: &VirtioIOMMURequest, +) -> VirtioIOMMUResponse { + match iommu_control_tube.send(&req) { + Ok(_) => VirtioIOMMUResponse::Ok, + Err(e) => { + error!("virtio-iommu socket send failed: {:?}", e); + VirtioIOMMUResponse::Err(SysError::last()) + } + } +} + +pub type VirtioIOMMURequestResult = std::result::Result; + +/// Send VirtioIOMMURequest and wait to get the response +pub fn virtio_iommu_request( + iommu_control_tube: &Tube, + req: &VirtioIOMMURequest, +) -> VirtioIOMMURequestResult { + let response = match virtio_iommu_request_async(iommu_control_tube, req) { + VirtioIOMMUResponse::Ok => match iommu_control_tube.recv() { + Ok(response) => response, + Err(e) => { + error!("virtio-iommu socket recv failed: {:?}", e); + VirtioIOMMUResponse::Err(SysError::last()) + } + }, + resp => resp, + }; + Ok(response) +}