diff --git a/devices/src/vfio.rs b/devices/src/vfio.rs index 9f30a5c03d..c0b8a7356a 100644 --- a/devices/src/vfio.rs +++ b/devices/src/vfio.rs @@ -88,7 +88,7 @@ static KVM_VFIO_FILE: OnceCell = OnceCell::new(); /// VfioContainer contain multi VfioGroup, and delegate an IOMMU domain table pub struct VfioContainer { container: File, - groups: HashMap>, + groups: HashMap>>, } const VFIO_API_VERSION: u8 = 0; @@ -219,11 +219,16 @@ impl VfioContainer { Ok(()) } - fn get_group(&mut self, id: u32, vm: &impl Vm, iommu_enabled: bool) -> Result> { + fn get_group( + &mut self, + id: u32, + vm: &impl Vm, + iommu_enabled: bool, + ) -> Result>> { match self.groups.get(&id) { Some(group) => Ok(group.clone()), None => { - let group = Arc::new(VfioGroup::new(self, id)?); + let group = Arc::new(Mutex::new(VfioGroup::new(self, id)?)); if self.groups.is_empty() { // Before the first group is added into container, do once cotainer @@ -234,7 +239,7 @@ impl VfioContainer { let kvm_vfio_file = KVM_VFIO_FILE .get_or_try_init(|| vm.create_device(DeviceKind::Vfio)) .map_err(VfioError::CreateVfioKvmDevice)?; - group.kvm_device_add_group(kvm_vfio_file)?; + group.lock().kvm_device_add_group(kvm_vfio_file)?; self.groups.insert(id, group.clone()); @@ -242,6 +247,26 @@ impl VfioContainer { } } } + + fn remove_group(&mut self, id: u32, reduce: bool) { + let remove = match self.groups.get(&id) { + Some(group) => { + if reduce { + group.lock().reduce_device_num(); + } + if group.lock().device_num() == 0 { + true + } else { + false + } + } + None => false, + }; + + if remove { + self.groups.remove(&id); + } + } } impl AsRawDescriptor for VfioContainer { @@ -252,6 +277,7 @@ impl AsRawDescriptor for VfioContainer { struct VfioGroup { group: File, + device_num: u32, } impl VfioGroup { @@ -295,7 +321,10 @@ impl VfioGroup { return Err(VfioError::GroupSetContainer(get_error())); } - Ok(VfioGroup { group: group_file }) + Ok(VfioGroup { + group: group_file, + device_num: 0, + }) } fn get_group_id(sysfspath: &Path) -> Result { @@ -350,6 +379,18 @@ impl VfioGroup { // Safe as ret is valid FD Ok(unsafe { File::from_raw_descriptor(ret) }) } + + fn add_device_num(&mut self) { + self.device_num += 1; + } + + fn reduce_device_num(&mut self) { + self.device_num -= 1; + } + + fn device_num(&self) -> u32 { + self.device_num + } } impl AsRawDescriptor for VfioGroup { @@ -469,6 +510,7 @@ pub struct VfioDevice { name: String, container: Arc>, group_descriptor: RawDescriptor, + group_id: u32, // vec for vfio device's regions regions: Vec, } @@ -488,14 +530,30 @@ impl VfioDevice { let name_osstr = sysfspath.file_name().ok_or(VfioError::InvalidPath)?; let name_str = name_osstr.to_str().ok_or(VfioError::InvalidPath)?; let name = String::from(name_str); - let dev = group.get_device(&name)?; - let regions = Self::get_regions(&dev)?; + + let dev = match group.lock().get_device(&name) { + Ok(dev) => dev, + Err(e) => { + container.lock().remove_group(group_id, false); + return Err(e); + } + }; + let regions = match Self::get_regions(&dev) { + Ok(regions) => regions, + Err(e) => { + container.lock().remove_group(group_id, false); + return Err(e); + } + }; + group.lock().add_device_num(); + let group_descriptor = group.lock().as_raw_descriptor(); Ok(VfioDevice { dev, name, container, - group_descriptor: group.as_raw_descriptor(), + group_descriptor, + group_id, regions, }) } @@ -1007,6 +1065,11 @@ impl VfioDevice { pub fn device_file(&self) -> &File { &self.dev } + + /// close vfio device + pub fn close(&self) { + self.container.lock().remove_group(self.group_id, true); + } } pub struct VfioPciConfig {