From 0090567d248657f7be04204f60e6c12cce18dc0f Mon Sep 17 00:00:00 2001 From: Daniel Verkamp Date: Thu, 10 Oct 2024 11:18:43 -0700 Subject: [PATCH] vhost: improve set_vring_addr() validation Use the GuestMemory get_slice_at_addr() function to ensure the memory regions corresponding to the descriptor table and used/avail rings are actually contiguous, rather than just validating the ending address of each region exists in guest memory. The log region is currently not validated, which matches previous behavior. BUG=None TEST=tools/dev_container tools/presubmit Change-Id: I1e80326dfae085380fbbf4dd1c7960dd72485793 Reviewed-on: https://chromium-review.googlesource.com/c/crosvm/crosvm/+/5922041 Commit-Queue: Daniel Verkamp Reviewed-by: Keiichi Watanabe --- vhost/src/lib.rs | 73 ++++++++++++------------------------------------ 1 file changed, 18 insertions(+), 55 deletions(-) diff --git a/vhost/src/lib.rs b/vhost/src/lib.rs index b474bf2f7c..cb57e662eb 100644 --- a/vhost/src/lib.rs +++ b/vhost/src/lib.rs @@ -202,42 +202,6 @@ pub trait Vhost: AsRawDescriptor + std::marker::Sized { Ok(()) } - // TODO(smbarber): This is copypasta. Eliminate the copypasta. - #[allow(clippy::if_same_then_else)] - fn is_valid( - &self, - mem: &GuestMemory, - queue_max_size: u16, - queue_size: u16, - desc_addr: GuestAddress, - avail_addr: GuestAddress, - used_addr: GuestAddress, - ) -> bool { - let desc_table_size = 16 * queue_size as usize; - let avail_ring_size = 6 + 2 * queue_size as usize; - let used_ring_size = 6 + 8 * queue_size as usize; - if queue_size > queue_max_size || queue_size == 0 || (queue_size & (queue_size - 1)) != 0 { - false - } else if desc_addr - .checked_add(desc_table_size as u64) - .map_or(true, |v| !mem.address_in_range(v)) - { - false - } else if avail_addr - .checked_add(avail_ring_size as u64) - .map_or(true, |v| !mem.address_in_range(v)) - { - false - } else if used_addr - .checked_add(used_ring_size as u64) - .map_or(true, |v| !mem.address_in_range(v)) - { - false - } else { - true - } - } - /// Set the addresses for a given vring. /// /// # Arguments @@ -261,28 +225,27 @@ pub trait Vhost: AsRawDescriptor + std::marker::Sized { avail_addr: GuestAddress, log_addr: Option, ) -> Result<()> { - // TODO(smbarber): Refactor out virtio from crosvm so we can - // validate a Queue struct directly. - if !self.is_valid( - mem, - queue_max_size, - queue_size, - desc_addr, - used_addr, - avail_addr, - ) { + if queue_size > queue_max_size || queue_size == 0 || !queue_size.is_power_of_two() { return Err(Error::InvalidQueue); } - let desc_addr = mem - .get_host_address(desc_addr) + let queue_size = usize::from(queue_size); + + let desc_table_size = 16 * queue_size; + let desc_table = mem + .get_slice_at_addr(desc_addr, desc_table_size) .map_err(Error::DescriptorTableAddress)?; - let used_addr = mem - .get_host_address(used_addr) + + let used_ring_size = 6 + 8 * queue_size; + let used_ring = mem + .get_slice_at_addr(used_addr, used_ring_size) .map_err(Error::UsedAddress)?; - let avail_addr = mem - .get_host_address(avail_addr) + + let avail_ring_size = 6 + 2 * queue_size; + let avail_ring = mem + .get_slice_at_addr(avail_addr, avail_ring_size) .map_err(Error::AvailAddress)?; + let log_addr = match log_addr { None => null(), Some(a) => mem.get_host_address(a).map_err(Error::LogAddress)?, @@ -291,9 +254,9 @@ pub trait Vhost: AsRawDescriptor + std::marker::Sized { let vring_addr = virtio_sys::vhost::vhost_vring_addr { index: queue_index as u32, flags, - desc_user_addr: desc_addr as u64, - used_user_addr: used_addr as u64, - avail_user_addr: avail_addr as u64, + desc_user_addr: desc_table.as_ptr() as u64, + used_user_addr: used_ring.as_ptr() as u64, + avail_user_addr: avail_ring.as_ptr() as u64, log_guest_addr: log_addr as u64, };