diff --git a/vm_memory/src/guest_memory.rs b/vm_memory/src/guest_memory.rs index 049f8efabf..85f765607f 100644 --- a/vm_memory/src/guest_memory.rs +++ b/vm_memory/src/guest_memory.rs @@ -28,6 +28,7 @@ use crate::guest_address::GuestAddress; pub enum Error { InvalidGuestAddress(GuestAddress), InvalidOffset(u64), + InvalidSize(usize), MemoryAccess(GuestAddress, MmapError), MemoryMappingFailed(MmapError), MemoryRegionOverlap, @@ -51,6 +52,7 @@ impl Display for Error { match self { InvalidGuestAddress(addr) => write!(f, "invalid guest address {}", addr), InvalidOffset(addr) => write!(f, "invalid offset {}", addr), + InvalidSize(size) => write!(f, "size {} must not be zero", size), MemoryAccess(addr, e) => { write!(f, "invalid guest memory access at addr={}: {}", addr, e) } @@ -686,6 +688,52 @@ impl GuestMemory { }) } + /// Convert a GuestAddress into a pointer in the address space of this + /// process, and verify that the provided size define a valid range within + /// a single memory region. Similar to get_host_address(), this should only + /// be used for giving addresses to the kernel. + /// + /// # Arguments + /// * `guest_addr` - Guest address to convert. + /// * `size` - Size of the address range to be converted. + /// + /// # Examples + /// + /// ``` + /// # use vm_memory::{GuestAddress, GuestMemory}; + /// # fn test_host_addr() -> Result<(), ()> { + /// let start_addr = GuestAddress(0x1000); + /// let mut gm = GuestMemory::new(&vec![(start_addr, 0x500)]).map_err(|_| ())?; + /// let addr = gm.get_host_address_range(GuestAddress(0x1200), 0x200).unwrap(); + /// println!("Host address is {:p}", addr); + /// Ok(()) + /// # } + /// ``` + pub fn get_host_address_range( + &self, + guest_addr: GuestAddress, + size: usize, + ) -> Result<*const u8> { + if size == 0 { + return Err(Error::InvalidSize(size)); + } + + // Assume no overlap among regions + self.do_in_region(guest_addr, |mapping, offset, _| { + if mapping + .size() + .checked_sub(offset) + .map_or(true, |v| v < size) + { + return Err(Error::InvalidGuestAddress(guest_addr)); + } + + // This is safe; `do_in_region` already checks that offset is in + // bounds. + Ok(unsafe { mapping.as_ptr().add(offset) } as *const u8) + }) + } + /// Returns a reference to the SharedMemory region that backs the given address. pub fn shm_region(&self, guest_addr: GuestAddress) -> Result<&SharedMemory> { self.regions @@ -919,6 +967,31 @@ mod tests { assert!(mem.get_host_address(bad_addr).is_err()); } + #[test] + fn guest_to_host_range() { + let start_addr1 = GuestAddress(0x0); + let start_addr2 = GuestAddress(0x1000); + let mem = GuestMemory::new(&[(start_addr1, 0x1000), (start_addr2, 0x4000)]).unwrap(); + + // Verify the host addresses match what we expect from the mappings. + let addr1_base = get_mapping(&mem, start_addr1).unwrap(); + let addr2_base = get_mapping(&mem, start_addr2).unwrap(); + let host_addr1 = mem.get_host_address_range(start_addr1, 0x1000).unwrap(); + let host_addr2 = mem.get_host_address_range(start_addr2, 0x1000).unwrap(); + assert_eq!(host_addr1, addr1_base); + assert_eq!(host_addr2, addr2_base); + + let host_addr3 = mem.get_host_address_range(start_addr2, 0x2000).unwrap(); + assert_eq!(host_addr3, addr2_base); + + // Check that a valid guest address with an invalid size returns an error. + assert!(mem.get_host_address_range(start_addr1, 0x2000).is_err()); + + // Check that a bad address returns an error. + let bad_addr = GuestAddress(0x123456); + assert!(mem.get_host_address_range(bad_addr, 0x1000).is_err()); + } + #[test] fn shm_offset() { if !kernel_has_memfd() {