diff --git a/devices/src/vfio.rs b/devices/src/vfio.rs index 670b0daa34..b291a7fae6 100644 --- a/devices/src/vfio.rs +++ b/devices/src/vfio.rs @@ -625,6 +625,19 @@ impl VfioDevice { } } + fn validate_dev_info(dev_info: &mut vfio_device_info) -> Result<(), VfioError> { + if (dev_info.flags & VFIO_DEVICE_FLAGS_PCI) != 0 { + if dev_info.num_regions < VFIO_PCI_CONFIG_REGION_INDEX + 1 + || dev_info.num_irqs < VFIO_PCI_MSIX_IRQ_INDEX + 1 + { + return Err(VfioError::VfioDeviceGetInfo(get_error())); + } + return Ok(()); + } + + Err(VfioError::VfioDeviceGetInfo(get_error())) + } + #[allow(clippy::cast_ptr_alignment)] fn get_regions(dev: &File) -> Result, VfioError> { let mut regions: Vec = Vec::new(); @@ -637,15 +650,12 @@ impl VfioDevice { // Safe as we are the owner of dev and dev_info which are valid value, // and we verify the return value. let mut ret = unsafe { ioctl_with_mut_ref(dev, VFIO_DEVICE_GET_INFO(), &mut dev_info) }; - if ret < 0 - || (dev_info.flags & VFIO_DEVICE_FLAGS_PCI) == 0 - || dev_info.num_regions < VFIO_PCI_CONFIG_REGION_INDEX + 1 - || dev_info.num_irqs < VFIO_PCI_MSIX_IRQ_INDEX + 1 - { + if ret < 0 { return Err(VfioError::VfioDeviceGetInfo(get_error())); } - for i in VFIO_PCI_BAR0_REGION_INDEX..dev_info.num_regions { + Self::validate_dev_info(&mut dev_info)?; + for i in 0..dev_info.num_regions { let argsz = mem::size_of::() as u32; let mut reg_info = vfio_region_info { argsz,