diff --git a/devices/src/virtio/vhost/user/device/handler.rs b/devices/src/virtio/vhost/user/device/handler.rs index 32e4aaf876..bc4a863ecf 100644 --- a/devices/src/virtio/vhost/user/device/handler.rs +++ b/devices/src/virtio/vhost/user/device/handler.rs @@ -778,8 +778,7 @@ impl SharedMemoryMapper for VhostShmemMapper { } _ => bail!("unsupported source"), }; - let flags = VhostUserShmemMapMsgFlags::from_bits(libc::c_int::from(prot) as u8) - .context(format!("unsupported protection flags {:?}", prot))?; + let flags = VhostUserShmemMapMsgFlags::from(prot); let msg = VhostUserShmemMapMsg::new(self.shmid, offset, fd_offset, size, flags); self.frontend .shmem_map(&msg, &descriptor) diff --git a/devices/src/virtio/vhost/user/device/vvu/device.rs b/devices/src/virtio/vhost/user/device/vvu/device.rs index 76e8272edb..917d796c29 100644 --- a/devices/src/virtio/vhost/user/device/vvu/device.rs +++ b/devices/src/virtio/vhost/user/device/vvu/device.rs @@ -28,7 +28,6 @@ use base::Event; use base::MappedRegion; use base::MemoryMappingBuilder; use base::MemoryMappingBuilderUnix; -use base::Protection; use base::RawDescriptor; use base::SafeDescriptor; use cros_async::EventAsync; @@ -546,7 +545,7 @@ impl BackendChannelInner { let mapping = MemoryMappingBuilder::new(msg.len as usize) .from_descriptor(&file) .offset(msg.fd_offset) - .protection(Protection::from(msg.flags.bits() as libc::c_int)) + .protection(msg.flags.into()) .build() .context("failed to map file")?; diff --git a/devices/src/virtio/vhost/user/proxy.rs b/devices/src/virtio/vhost/user/proxy.rs index 9d7fbc93ae..a1a339c9e3 100644 --- a/devices/src/virtio/vhost/user/proxy.rs +++ b/devices/src/virtio/vhost/user/proxy.rs @@ -66,7 +66,6 @@ use vmm_vhost::message::VhostUserMemoryRegion; use vmm_vhost::message::VhostUserMsgHeader; use vmm_vhost::message::VhostUserMsgValidator; use vmm_vhost::message::VhostUserShmemMapMsg; -use vmm_vhost::message::VhostUserShmemMapMsgFlags; use vmm_vhost::message::VhostUserShmemUnmapMsg; use vmm_vhost::message::VhostUserU64; use vmm_vhost::Error as VhostError; @@ -984,15 +983,7 @@ impl Worker { .export(msg.fd_offset, msg.len) .context("failed to export")?; - let prot = match ( - msg.flags.contains(VhostUserShmemMapMsgFlags::MAP_R), - msg.flags.contains(VhostUserShmemMapMsgFlags::MAP_W), - ) { - (true, true) => Protection::read_write(), - (true, false) => Protection::read(), - (false, true) => Protection::write(), - (false, false) => bail!("unsupported protection"), - }; + let prot = Protection::from(msg.flags); let regions = regions .iter() .map(|r| { diff --git a/devices/src/virtio/vhost/user/vmm/handler.rs b/devices/src/virtio/vhost/user/vmm/handler.rs index b68a255ef7..a77426fceb 100644 --- a/devices/src/virtio/vhost/user/vmm/handler.rs +++ b/devices/src/virtio/vhost/user/vmm/handler.rs @@ -345,7 +345,7 @@ impl VhostUserMasterReqHandlerMut for BackendReqHandlerImpl { gpu_blob: false, }, req.shm_offset, - Protection::from(req.flags.bits() as libc::c_int), + Protection::from(req.flags), ) { Ok(()) => Ok(0), Err(e) => { diff --git a/third_party/vmm_vhost/src/message.rs b/third_party/vmm_vhost/src/message.rs index e6a6174348..d1e9f943b2 100644 --- a/third_party/vmm_vhost/src/message.rs +++ b/third_party/vmm_vhost/src/message.rs @@ -13,6 +13,7 @@ use std::convert::TryInto; use std::fmt::Debug; use std::marker::PhantomData; +use base::Protection; use bitflags::bitflags; use data_model::DataInit; @@ -863,6 +864,28 @@ bitflags! { } } +impl From for VhostUserShmemMapMsgFlags { + fn from(prot: Protection) -> Self { + let mut flags = Self::EMPTY; + flags.set(Self::MAP_R, prot.allows(&Protection::read())); + flags.set(Self::MAP_W, prot.allows(&Protection::write())); + flags + } +} + +impl From for Protection { + fn from(flags: VhostUserShmemMapMsgFlags) -> Self { + let mut prot = Protection::from(0); + if flags.contains(VhostUserShmemMapMsgFlags::MAP_R) { + prot = prot.set_read(); + } + if flags.contains(VhostUserShmemMapMsgFlags::MAP_W) { + prot = prot.set_write(); + } + prot + } +} + /// Slave request message to map a file into a shared memory region. #[repr(C, packed)] #[derive(Default, Copy, Clone)]