From 4557915eef3a84f9913dabb0dbcbaac1113d8fe5 Mon Sep 17 00:00:00 2001 From: Richard Date: Mon, 28 Mar 2022 14:25:20 -0700 Subject: [PATCH] descriptor_reflection: Dedupe cross platform descriptor_reflection.rs is the exact same on both windows and unix, so moved it to the outer module in base Test: built and presubmit Bug: 215619368 Change-Id: I346fa58e651953e2a77b806fa7456af2c1b02cb9 Reviewed-on: https://chromium-review.googlesource.com/c/chromiumos/platform/crosvm/+/3555732 Reviewed-by: Dennis Kempin Tested-by: kokoro Commit-Queue: Richard Zhang --- .../{windows => }/descriptor_reflection.rs | 2 +- base/src/lib.rs | 1 + base/src/unix/descriptor_reflection.rs | 543 ------------------ base/src/unix/mod.rs | 9 +- base/src/windows/mod.rs | 9 +- 5 files changed, 10 insertions(+), 554 deletions(-) rename base/src/{windows => }/descriptor_reflection.rs (99%) delete mode 100644 base/src/unix/descriptor_reflection.rs diff --git a/base/src/windows/descriptor_reflection.rs b/base/src/descriptor_reflection.rs similarity index 99% rename from base/src/windows/descriptor_reflection.rs rename to base/src/descriptor_reflection.rs index bf7cff1fc0..d75d98811e 100644 --- a/base/src/windows/descriptor_reflection.rs +++ b/base/src/descriptor_reflection.rs @@ -1,4 +1,4 @@ -// Copyright 2022 The Chromium OS Authors. All rights reserved. +// Copyright 2020 The Chromium OS Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. diff --git a/base/src/lib.rs b/base/src/lib.rs index f25a964eff..45c72d89a6 100644 --- a/base/src/lib.rs +++ b/base/src/lib.rs @@ -3,6 +3,7 @@ // found in the LICENSE file. pub mod common; +pub mod descriptor_reflection; #[cfg(unix)] pub mod unix; diff --git a/base/src/unix/descriptor_reflection.rs b/base/src/unix/descriptor_reflection.rs deleted file mode 100644 index f05f5410c4..0000000000 --- a/base/src/unix/descriptor_reflection.rs +++ /dev/null @@ -1,543 +0,0 @@ -// Copyright 2020 The Chromium OS Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -//! Provides infrastructure for de/serializing descriptors embedded in Rust data structures. -//! -//! # Example -//! -//! ``` -//! use serde_json::to_string; -//! use crate::platform::{ -//! FileSerdeWrapper, FromRawDescriptor, SafeDescriptor, SerializeDescriptors, -//! deserialize_with_descriptors, -//! }; -//! use tempfile::tempfile; -//! -//! let tmp_f = tempfile().unwrap(); -//! -//! // Uses a simple wrapper to serialize a File because we can't implement Serialize for File. -//! let data = FileSerdeWrapper(tmp_f); -//! -//! // Wraps Serialize types to collect side channel descriptors as Serialize is called. -//! let data_wrapper = SerializeDescriptors::new(&data); -//! -//! // Use the wrapper with any serializer to serialize data is normal, grabbing descriptors -//! // as the data structures are serialized by the serializer. -//! let out_json = serde_json::to_string(&data_wrapper).expect("failed to serialize"); -//! -//! // If data_wrapper contains any side channel descriptor refs -//! // (it contains tmp_f in this case), we can retrieve the actual descriptors -//! // from the side channel using into_descriptors(). -//! let out_descriptors = data_wrapper.into_descriptors(); -//! -//! // When sending out_json over some transport, also send out_descriptors. -//! -//! // For this example, we aren't really transporting data across the process, but we do need to -//! // convert the descriptor type. -//! let mut safe_descriptors = out_descriptors -//! .iter() -//! .map(|&v| Some(unsafe { SafeDescriptor::from_raw_descriptor(v) })) -//! .collect(); -//! std::mem::forget(data); // Prevent double drop of tmp_f. -//! -//! // The deserialize_with_descriptors function is used give the descriptor deserializers access -//! // to side channel descriptors. -//! let res: FileSerdeWrapper = -//! deserialize_with_descriptors(|| serde_json::from_str(&out_json), &mut safe_descriptors) -//! .expect("failed to deserialize"); -//! ``` - -use std::{ - cell::{Cell, RefCell}, - convert::TryInto, - fmt, - fs::File, - ops::{Deref, DerefMut}, - panic::{catch_unwind, resume_unwind, AssertUnwindSafe}, -}; - -use serde::{ - de::{ - Error, Visitor, {self}, - }, - ser, Deserialize, Deserializer, Serialize, Serializer, -}; - -use super::{RawDescriptor, SafeDescriptor}; - -thread_local! { - static DESCRIPTOR_DST: RefCell>> = Default::default(); -} - -/// Initializes the thread local storage for descriptor serialization. Fails if it was already -/// initialized without an intervening `take_descriptor_dst` on this thread. -fn init_descriptor_dst() -> Result<(), &'static str> { - DESCRIPTOR_DST.with(|d| { - let mut descriptors = d.borrow_mut(); - if descriptors.is_some() { - return Err( - "attempt to initialize descriptor destination that was already initialized", - ); - } - *descriptors = Some(Default::default()); - Ok(()) - }) -} - -/// Takes the thread local storage for descriptor serialization. Fails if there wasn't a prior call -/// to `init_descriptor_dst` on this thread. -fn take_descriptor_dst() -> Result, &'static str> { - match DESCRIPTOR_DST.with(|d| d.replace(None)) { - Some(d) => Ok(d), - None => Err("attempt to take descriptor destination before it was initialized"), - } -} - -/// Pushes a descriptor on the thread local destination of descriptors, returning the index in which -/// the descriptor was pushed. -// -/// Returns Err if the thread local destination was not already initialized. -fn push_descriptor(rd: RawDescriptor) -> Result { - DESCRIPTOR_DST.with(|d| { - d.borrow_mut() - .as_mut() - .ok_or("attempt to serialize descriptor without descriptor destination") - .map(|descriptors| { - let index = descriptors.len(); - descriptors.push(rd); - index - }) - }) -} - -/// Serializes a descriptor for later retrieval in a parent `SerializeDescriptors` struct. -/// -/// If there is no parent `SerializeDescriptors` being serialized, this will return an error. -/// -/// For convenience, it is recommended to use the `with_raw_descriptor` module in a `#[serde(with = -/// "...")]` attribute which will make use of this function. -pub fn serialize_descriptor( - rd: &RawDescriptor, - se: S, -) -> std::result::Result { - let index = push_descriptor(*rd).map_err(ser::Error::custom)?; - se.serialize_u32( - index - .try_into() - .map_err(|_| ser::Error::custom("attempt to serialize too many descriptors at once"))?, - ) -} - -/// Wrapper for a `Serialize` value which will capture any descriptors exported by the value when -/// given to an ordinary `Serializer`. -/// -/// This is the corresponding type to use for serialization before using -/// `deserialize_with_descriptors`. -/// -/// # Examples -/// -/// ``` -/// use serde_json::to_string; -/// use crate::platform::{FileSerdeWrapper, SerializeDescriptors}; -/// use tempfile::tempfile; -/// -/// let tmp_f = tempfile().unwrap(); -/// let data = FileSerdeWrapper(tmp_f); -/// let data_wrapper = SerializeDescriptors::new(&data); -/// -/// // Serializes `v` as normal... -/// let out_json = serde_json::to_string(&data_wrapper).expect("failed to serialize"); -/// // If `serialize_descriptor` was called, we can capture the descriptors from here. -/// let out_descriptors = data_wrapper.into_descriptors(); -/// ``` -pub struct SerializeDescriptors<'a, T: Serialize>(&'a T, Cell>); - -impl<'a, T: Serialize> SerializeDescriptors<'a, T> { - pub fn new(inner: &'a T) -> Self { - Self(inner, Default::default()) - } - - pub fn into_descriptors(self) -> Vec { - self.1.into_inner() - } -} - -impl<'a, T: Serialize> Serialize for SerializeDescriptors<'a, T> { - fn serialize(&self, serializer: S) -> Result - where - S: Serializer, - { - init_descriptor_dst().map_err(ser::Error::custom)?; - - // catch_unwind is used to ensure that init_descriptor_dst is always balanced with a call to - // take_descriptor_dst afterwards. - let res = catch_unwind(AssertUnwindSafe(|| self.0.serialize(serializer))); - self.1.set(take_descriptor_dst().unwrap()); - match res { - Ok(r) => r, - Err(e) => resume_unwind(e), - } - } -} - -thread_local! { - static DESCRIPTOR_SRC: RefCell>>> = Default::default(); -} - -/// Sets the thread local storage of descriptors for deserialization. Fails if this was already -/// called without a call to `take_descriptor_src` on this thread. -/// -/// This is given as a collection of `Option` so that unused descriptors can be returned. -fn set_descriptor_src(descriptors: Vec>) -> Result<(), &'static str> { - DESCRIPTOR_SRC.with(|d| { - let mut src = d.borrow_mut(); - if src.is_some() { - return Err("attempt to set descriptor source that was already set"); - } - *src = Some(descriptors); - Ok(()) - }) -} - -/// Takes the thread local storage of descriptors for deserialization. Fails if the storage was -/// already taken or never set with `set_descriptor_src`. -/// -/// If deserialization was done, the descriptors will mostly come back as `None` unless some of them -/// were unused. -fn take_descriptor_src() -> Result>, &'static str> { - DESCRIPTOR_SRC.with(|d| { - d.replace(None) - .ok_or("attempt to take descriptor source which was never set") - }) -} - -/// Takes a descriptor at the given index from the thread local source of descriptors. -// -/// Returns None if the thread local source was not already initialized. -fn take_descriptor(index: usize) -> Result { - DESCRIPTOR_SRC.with(|d| { - d.borrow_mut() - .as_mut() - .ok_or("attempt to deserialize descriptor without descriptor source")? - .get_mut(index) - .ok_or("attempt to deserialize out of bounds descriptor")? - .take() - .ok_or("attempt to deserialize descriptor that was already taken") - }) -} - -/// Deserializes a descriptor provided via `deserialize_with_descriptors`. -/// -/// If `deserialize_with_descriptors` is not in the call chain, this will return an error. -/// -/// For convenience, it is recommended to use the `with_raw_descriptor` module in a `#[serde(with = -/// "...")]` attribute which will make use of this function. -pub fn deserialize_descriptor<'de, D>(de: D) -> std::result::Result -where - D: Deserializer<'de>, -{ - struct DescriptorVisitor; - - impl<'de> Visitor<'de> for DescriptorVisitor { - type Value = u32; - - fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { - formatter.write_str("an integer which fits into a u32") - } - - fn visit_u8(self, value: u8) -> Result { - Ok(value as _) - } - - fn visit_u16(self, value: u16) -> Result { - Ok(value as _) - } - - fn visit_u32(self, value: u32) -> Result { - Ok(value) - } - - fn visit_u64(self, value: u64) -> Result { - value.try_into().map_err(E::custom) - } - - fn visit_u128(self, value: u128) -> Result { - value.try_into().map_err(E::custom) - } - - fn visit_i8(self, value: i8) -> Result { - value.try_into().map_err(E::custom) - } - - fn visit_i16(self, value: i16) -> Result { - value.try_into().map_err(E::custom) - } - - fn visit_i32(self, value: i32) -> Result { - value.try_into().map_err(E::custom) - } - - fn visit_i64(self, value: i64) -> Result { - value.try_into().map_err(E::custom) - } - - fn visit_i128(self, value: i128) -> Result { - value.try_into().map_err(E::custom) - } - } - - let index = de.deserialize_u32(DescriptorVisitor)? as usize; - take_descriptor(index).map_err(D::Error::custom) -} - -/// Allows the use of any serde deserializer within a closure while providing access to the a set of -/// descriptors for use in `deserialize_descriptor`. -/// -/// This is the corresponding call to use deserialize after using `SerializeDescriptors`. -/// -/// If `deserialize_with_descriptors` is called anywhere within the given closure, it return an -/// error. -pub fn deserialize_with_descriptors( - f: F, - descriptors: &mut Vec>, -) -> Result -where - F: FnOnce() -> Result, - E: de::Error, -{ - let swap_descriptors = std::mem::take(descriptors); - set_descriptor_src(swap_descriptors).map_err(E::custom)?; - - // catch_unwind is used to ensure that set_descriptor_src is always balanced with a call to - // take_descriptor_src afterwards. - let res = catch_unwind(AssertUnwindSafe(f)); - - // unwrap is used because set_descriptor_src is always called before this, so it should never - // panic. - *descriptors = take_descriptor_src().unwrap(); - - match res { - Ok(r) => r, - Err(e) => resume_unwind(e), - } -} - -/// Module that exports `serialize`/`deserialize` functions for use with `#[serde(with = "...")]` -/// attribute. It only works with fields with `RawDescriptor` type. -/// -/// # Examples -/// -/// ``` -/// use serde::{Deserialize, Serialize}; -/// use crate::platform::RawDescriptor; -/// -/// #[derive(Serialize, Deserialize)] -/// struct RawContainer { -/// #[serde(with = "crate::platform::with_raw_descriptor")] -/// rd: RawDescriptor, -/// } -/// ``` -pub mod with_raw_descriptor { - use super::super::{IntoRawDescriptor, RawDescriptor}; - use serde::Deserializer; - - pub use super::serialize_descriptor as serialize; - - pub fn deserialize<'de, D>(de: D) -> std::result::Result - where - D: Deserializer<'de>, - { - super::deserialize_descriptor(de).map(IntoRawDescriptor::into_raw_descriptor) - } -} - -/// Module that exports `serialize`/`deserialize` functions for use with `#[serde(with = "...")]` -/// attribute. -/// -/// # Examples -/// -/// ``` -/// use std::fs::File; -/// use serde::{Deserialize, Serialize}; -/// use crate::platform::RawDescriptor; -/// -/// #[derive(Serialize, Deserialize)] -/// struct FileContainer { -/// #[serde(with = "crate::platform::with_as_descriptor")] -/// file: File, -/// } -/// ``` -pub mod with_as_descriptor { - use super::super::{AsRawDescriptor, FromRawDescriptor, IntoRawDescriptor}; - use serde::{Deserializer, Serializer}; - - pub fn serialize( - rd: &dyn AsRawDescriptor, - se: S, - ) -> std::result::Result { - super::serialize_descriptor(&rd.as_raw_descriptor(), se) - } - - pub fn deserialize<'de, D, T>(de: D) -> std::result::Result - where - D: Deserializer<'de>, - T: FromRawDescriptor, - { - super::deserialize_descriptor(de) - .map(IntoRawDescriptor::into_raw_descriptor) - .map(|rd| unsafe { T::from_raw_descriptor(rd) }) - } -} - -/// A simple wrapper around `File` that implements `Serialize`/`Deserialize`, which is useful when -/// the `#[serde(with = "with_as_descriptor")]` trait is infeasible, such as for a field with type -/// `Option`. -#[derive(Serialize, Deserialize)] -#[serde(transparent)] -pub struct FileSerdeWrapper(#[serde(with = "with_as_descriptor")] pub File); - -impl fmt::Debug for FileSerdeWrapper { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - self.0.fmt(f) - } -} - -impl From for FileSerdeWrapper { - fn from(file: File) -> Self { - FileSerdeWrapper(file) - } -} - -impl From for File { - fn from(f: FileSerdeWrapper) -> File { - f.0 - } -} - -impl Deref for FileSerdeWrapper { - type Target = File; - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -impl DerefMut for FileSerdeWrapper { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.0 - } -} - -#[cfg(test)] -mod tests { - use super::super::{ - deserialize_with_descriptors, with_as_descriptor, with_raw_descriptor, FileSerdeWrapper, - FromRawDescriptor, RawDescriptor, SafeDescriptor, SerializeDescriptors, - }; - - use std::{collections::HashMap, fs::File, mem::ManuallyDrop, os::unix::io::AsRawFd}; - - use serde::{de::DeserializeOwned, Deserialize, Serialize}; - use tempfile::tempfile; - - fn deserialize(json: &str, descriptors: &[RawDescriptor]) -> T { - let mut safe_descriptors = descriptors - .iter() - .map(|&v| Some(unsafe { SafeDescriptor::from_raw_descriptor(v) })) - .collect(); - - let res = - deserialize_with_descriptors(|| serde_json::from_str(json), &mut safe_descriptors) - .unwrap(); - - assert!(safe_descriptors.iter().all(|v| v.is_none())); - - res - } - - #[test] - fn raw() { - #[derive(Serialize, Deserialize, PartialEq, Debug)] - struct RawContainer { - #[serde(with = "with_raw_descriptor")] - rd: RawDescriptor, - } - // Specifically chosen to not overlap a real descriptor to avoid having to allocate any - // descriptors for this test. - let fake_rd = 5_123_457_i32; - let v = RawContainer { rd: fake_rd }; - let v_serialize = SerializeDescriptors::new(&v); - let json = serde_json::to_string(&v_serialize).unwrap(); - let descriptors = v_serialize.into_descriptors(); - let res = deserialize(&json, &descriptors); - assert_eq!(v, res); - } - - #[test] - fn file() { - #[derive(Serialize, Deserialize)] - struct FileContainer { - #[serde(with = "with_as_descriptor")] - file: File, - } - - let v = FileContainer { - file: tempfile().unwrap(), - }; - let v_serialize = SerializeDescriptors::new(&v); - let json = serde_json::to_string(&v_serialize).unwrap(); - let descriptors = v_serialize.into_descriptors(); - let v = ManuallyDrop::new(v); - let res: FileContainer = deserialize(&json, &descriptors); - assert_eq!(v.file.as_raw_fd(), res.file.as_raw_fd()); - } - - #[test] - fn option() { - #[derive(Serialize, Deserialize)] - struct TestOption { - a: Option, - b: Option, - } - - let v = TestOption { - a: None, - b: Some(tempfile().unwrap().into()), - }; - let v_serialize = SerializeDescriptors::new(&v); - let json = serde_json::to_string(&v_serialize).unwrap(); - let descriptors = v_serialize.into_descriptors(); - let v = ManuallyDrop::new(v); - let res: TestOption = deserialize(&json, &descriptors); - assert!(res.a.is_none()); - assert!(res.b.is_some()); - assert_eq!( - v.b.as_ref().unwrap().as_raw_fd(), - res.b.unwrap().as_raw_fd() - ); - } - - #[test] - fn map() { - let mut v: HashMap = HashMap::new(); - v.insert("a".into(), tempfile().unwrap().into()); - v.insert("b".into(), tempfile().unwrap().into()); - v.insert("c".into(), tempfile().unwrap().into()); - let v_serialize = SerializeDescriptors::new(&v); - let json = serde_json::to_string(&v_serialize).unwrap(); - let descriptors = v_serialize.into_descriptors(); - // Prevent the files in `v` from dropping while allowing the HashMap itself to drop. It is - // done this way to prevent a double close of the files (which should reside in `res`) - // without triggering the leak sanitizer on `v`'s HashMap heap memory. - let v: HashMap<_, _> = v - .into_iter() - .map(|(k, v)| (k, ManuallyDrop::new(v))) - .collect(); - let res: HashMap = deserialize(&json, &descriptors); - - assert_eq!(v.len(), res.len()); - for (k, v) in v.iter() { - assert_eq!(res.get(k).unwrap().as_raw_fd(), v.as_raw_fd()); - } - } -} diff --git a/base/src/unix/mod.rs b/base/src/unix/mod.rs index def9f7d625..f722e0a024 100644 --- a/base/src/unix/mod.rs +++ b/base/src/unix/mod.rs @@ -27,7 +27,6 @@ mod acpi_event; mod capabilities; mod clock; mod descriptor; -mod descriptor_reflection; mod eventfd; mod file_flags; pub mod file_traits; @@ -52,6 +51,10 @@ mod timerfd; pub mod vsock; mod write_zeroes; +pub use crate::descriptor_reflection::{ + deserialize_with_descriptors, with_as_descriptor, with_raw_descriptor, FileSerdeWrapper, + SerializeDescriptors, +}; pub use crate::{ common::{Error, Result, *}, generate_scoped_event, @@ -61,10 +64,6 @@ pub use base_poll_token_derive::*; pub use capabilities::drop_capabilities; pub use clock::{Clock, FakeClock}; pub use descriptor::*; -pub use descriptor_reflection::{ - deserialize_with_descriptors, with_as_descriptor, with_raw_descriptor, FileSerdeWrapper, - SerializeDescriptors, -}; pub use eventfd::*; pub use file_flags::*; pub use fork::*; diff --git a/base/src/windows/mod.rs b/base/src/windows/mod.rs index e3b5beef6f..b0cca80491 100644 --- a/base/src/windows/mod.rs +++ b/base/src/windows/mod.rs @@ -18,7 +18,6 @@ mod clock; #[path = "win/console.rs"] mod console; mod descriptor; -mod descriptor_reflection; #[path = "win/event.rs"] mod event; mod events; @@ -45,14 +44,14 @@ pub mod thread; mod write_zeroes; pub use crate::common::{Error, Result, *}; +pub use crate::descriptor_reflection::{ + deserialize_with_descriptors, with_as_descriptor, with_raw_descriptor, FileSerdeWrapper, + SerializeDescriptors, +}; pub use base_poll_token_derive::*; pub use clock::{Clock, FakeClock}; pub use console::*; pub use descriptor::*; -pub use descriptor_reflection::{ - deserialize_with_descriptors, with_as_descriptor, with_raw_descriptor, FileSerdeWrapper, - SerializeDescriptors, -}; pub use event::*; pub use events::*; pub use get_filesystem_type::*;