msg_socket: impl skip helper attribute

Fields with a default value can be skipped using the
`#[msg_on_socket(skip)]` attribute.

TEST=cargo test -p msg_socket
BUG=None

Change-Id: I9fea33e641a7da62b7864ba1847e884b32502491
Reviewed-on: https://chromium-review.googlesource.com/c/chromiumos/platform/crosvm/+/2168587
Reviewed-by: Dylan Reid <dgreid@chromium.org>
Tested-by: kokoro <noreply+kokoro@google.com>
Tested-by: Zach Reizner <zachr@chromium.org>
Commit-Queue: Zach Reizner <zachr@chromium.org>
This commit is contained in:
Zach Reizner 2020-03-25 01:36:46 -07:00 committed by Commit Bot
parent 8b3ee41b30
commit 882e2cea3b

View file

@ -10,11 +10,12 @@ use std::vec::Vec;
use proc_macro2::{Span, TokenStream}; use proc_macro2::{Span, TokenStream};
use quote::{format_ident, quote}; use quote::{format_ident, quote};
use syn::{ use syn::{
parse_macro_input, Data, DataEnum, DataStruct, DeriveInput, Fields, Ident, Index, Member, Type, parse_macro_input, Data, DataEnum, DataStruct, DeriveInput, Fields, Ident, Index, Member, Meta,
NestedMeta, Type,
}; };
/// The function that derives the recursive implementation for struct or enum. /// The function that derives the recursive implementation for struct or enum.
#[proc_macro_derive(MsgOnSocket)] #[proc_macro_derive(MsgOnSocket, attributes(msg_on_socket))]
pub fn msg_on_socket_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream { pub fn msg_on_socket_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
let input = parse_macro_input!(input as DeriveInput); let input = parse_macro_input!(input as DeriveInput);
let impl_for_input = socket_msg_impl(input); let impl_for_input = socket_msg_impl(input);
@ -50,6 +51,13 @@ fn is_named_struct(ds: &DataStruct) -> bool {
} }
/************************** Named Struct Impls ********************************************/ /************************** Named Struct Impls ********************************************/
struct StructField {
member: Member,
ty: Type,
skipped: bool,
}
fn impl_for_named_struct(name: Ident, ds: DataStruct) -> TokenStream { fn impl_for_named_struct(name: Ident, ds: DataStruct) -> TokenStream {
let fields = get_struct_fields(ds); let fields = get_struct_fields(ds);
let uses_fd_impl = define_uses_fd_for_struct(&fields); let uses_fd_impl = define_uses_fd_for_struct(&fields);
@ -68,7 +76,7 @@ fn impl_for_named_struct(name: Ident, ds: DataStruct) -> TokenStream {
} }
// Flatten struct fields. // Flatten struct fields.
fn get_struct_fields(ds: DataStruct) -> Vec<(Member, Type)> { fn get_struct_fields(ds: DataStruct) -> Vec<StructField> {
let fields = match ds.fields { let fields = match ds.fields {
Fields::Named(fields_named) => fields_named.named, Fields::Named(fields_named) => fields_named.named,
_ => { _ => {
@ -82,17 +90,48 @@ fn get_struct_fields(ds: DataStruct) -> Vec<(Member, Type)> {
None => panic!("Unknown Error."), None => panic!("Unknown Error."),
}; };
let ty = field.ty; let ty = field.ty;
vec.push((member, ty)); let mut skipped = false;
for attr in field
.attrs
.iter()
.filter(|attr| attr.path.is_ident("msg_on_socket"))
{
match attr.parse_meta().unwrap() {
Meta::List(meta) => {
for nested in meta.nested {
match nested {
NestedMeta::Meta(Meta::Path(meta_path))
if meta_path.is_ident("skip") =>
{
skipped = true;
}
_ => panic!("unrecognized attribute meta `{}`", quote! { #nested }),
}
}
}
_ => panic!("unrecognized attribute `{}`", quote! { #attr }),
}
}
vec.push(StructField {
member,
ty,
skipped,
});
} }
vec vec
} }
fn define_uses_fd_for_struct(fields: &[(Member, Type)]) -> TokenStream { fn define_uses_fd_for_struct(fields: &[StructField]) -> TokenStream {
if fields.len() == 0 { let field_types: Vec<_> = fields
.iter()
.filter(|f| !f.skipped)
.map(|f| &f.ty)
.collect();
if field_types.is_empty() {
return quote!(); return quote!();
} }
let field_types = fields.iter().map(|(_, ty)| ty);
quote! { quote! {
fn uses_fd() -> bool { fn uses_fd() -> bool {
#(<#field_types>::uses_fd())||* #(<#field_types>::uses_fd())||*
@ -100,7 +139,7 @@ fn define_uses_fd_for_struct(fields: &[(Member, Type)]) -> TokenStream {
} }
} }
fn define_buffer_size_for_struct(fields: &[(Member, Type)]) -> TokenStream { fn define_buffer_size_for_struct(fields: &[StructField]) -> TokenStream {
let (msg_size, fd_count) = get_fields_buffer_size_sum(fields); let (msg_size, fd_count) = get_fields_buffer_size_sum(fields);
quote! { quote! {
fn msg_size(&self) -> usize { fn msg_size(&self) -> usize {
@ -112,17 +151,24 @@ fn define_buffer_size_for_struct(fields: &[(Member, Type)]) -> TokenStream {
} }
} }
fn define_read_buffer_for_struct(_name: &Ident, fields: &[(Member, Type)]) -> TokenStream { fn define_read_buffer_for_struct(_name: &Ident, fields: &[StructField]) -> TokenStream {
let mut read_fields = Vec::new(); let mut read_fields = Vec::new();
let mut init_fields = Vec::new(); let mut init_fields = Vec::new();
for (field_member, field_ty) in fields { for field in fields {
let ident = match field_member { let ident = match &field.member {
Member::Named(ident) => ident, Member::Named(ident) => ident,
Member::Unnamed(_) => unreachable!(), Member::Unnamed(_) => unreachable!(),
}; };
let read_field = read_from_buffer_and_move_offset(&ident, &field_ty);
read_fields.push(read_field);
let name = ident.clone(); let name = ident.clone();
if field.skipped {
let ty = &field.ty;
init_fields.push(quote! {
#name: <#ty>::default()
});
continue;
}
let read_field = read_from_buffer_and_move_offset(&ident, &field.ty);
read_fields.push(read_field);
init_fields.push(quote!(#name)); init_fields.push(quote!(#name));
} }
quote! { quote! {
@ -143,10 +189,13 @@ fn define_read_buffer_for_struct(_name: &Ident, fields: &[(Member, Type)]) -> To
} }
} }
fn define_write_buffer_for_struct(_name: &Ident, fields: &[(Member, Type)]) -> TokenStream { fn define_write_buffer_for_struct(_name: &Ident, fields: &[StructField]) -> TokenStream {
let mut write_fields = Vec::new(); let mut write_fields = Vec::new();
for (field_member, _) in fields { for field in fields {
let ident = match field_member { if field.skipped {
continue;
}
let ident = match &field.member {
Member::Named(ident) => ident, Member::Named(ident) => ident,
Member::Unnamed(_) => unreachable!(), Member::Unnamed(_) => unreachable!(),
}; };
@ -438,7 +487,7 @@ fn impl_for_tuple_struct(name: Ident, ds: DataStruct) -> TokenStream {
} }
} }
fn get_tuple_fields(ds: DataStruct) -> Vec<(Member, Type)> { fn get_tuple_fields(ds: DataStruct) -> Vec<StructField> {
let mut field_idents = Vec::new(); let mut field_idents = Vec::new();
let fields = match ds.fields { let fields = match ds.fields {
Fields::Unnamed(fields_unnamed) => fields_unnamed.unnamed, Fields::Unnamed(fields_unnamed) => fields_unnamed.unnamed,
@ -449,17 +498,21 @@ fn get_tuple_fields(ds: DataStruct) -> Vec<(Member, Type)> {
for (idx, field) in fields.iter().enumerate() { for (idx, field) in fields.iter().enumerate() {
let member = Member::Unnamed(Index::from(idx)); let member = Member::Unnamed(Index::from(idx));
let ty = field.ty.clone(); let ty = field.ty.clone();
field_idents.push((member, ty)); field_idents.push(StructField {
member,
ty,
skipped: false,
});
} }
field_idents field_idents
} }
fn define_uses_fd_for_tuples(fields: &[(Member, Type)]) -> TokenStream { fn define_uses_fd_for_tuples(fields: &[StructField]) -> TokenStream {
if fields.len() == 0 { if fields.len() == 0 {
return quote!(); return quote!();
} }
let field_types = fields.iter().map(|(_, ty)| ty); let field_types = fields.iter().map(|f| &f.ty);
quote! { quote! {
fn uses_fd() -> bool { fn uses_fd() -> bool {
#(<#field_types>::uses_fd())||* #(<#field_types>::uses_fd())||*
@ -467,13 +520,13 @@ fn define_uses_fd_for_tuples(fields: &[(Member, Type)]) -> TokenStream {
} }
} }
fn define_read_buffer_for_tuples(name: &Ident, fields: &[(Member, Type)]) -> TokenStream { fn define_read_buffer_for_tuples(name: &Ident, fields: &[StructField]) -> TokenStream {
let mut read_fields = Vec::new(); let mut read_fields = Vec::new();
let mut init_fields = Vec::new(); let mut init_fields = Vec::new();
for (idx, (_, field_ty)) in fields.iter().enumerate() { for (idx, field) in fields.iter().enumerate() {
let tmp_name = format!("tuple_tmp{}", idx); let tmp_name = format!("tuple_tmp{}", idx);
let tmp_name = Ident::new(&tmp_name, Span::call_site()); let tmp_name = Ident::new(&tmp_name, Span::call_site());
let read_field = read_from_buffer_and_move_offset(&tmp_name, field_ty); let read_field = read_from_buffer_and_move_offset(&tmp_name, &field.ty);
read_fields.push(read_field); read_fields.push(read_field);
init_fields.push(quote!(#tmp_name)); init_fields.push(quote!(#tmp_name));
} }
@ -496,7 +549,7 @@ fn define_read_buffer_for_tuples(name: &Ident, fields: &[(Member, Type)]) -> Tok
} }
} }
fn define_write_buffer_for_tuples(name: &Ident, fields: &[(Member, Type)]) -> TokenStream { fn define_write_buffer_for_tuples(name: &Ident, fields: &[StructField]) -> TokenStream {
let mut write_fields = Vec::new(); let mut write_fields = Vec::new();
let mut tmp_names = Vec::new(); let mut tmp_names = Vec::new();
for idx in 0..fields.len() { for idx in 0..fields.len() {
@ -520,8 +573,12 @@ fn define_write_buffer_for_tuples(name: &Ident, fields: &[(Member, Type)]) -> To
} }
} }
/************************** Helpers ********************************************/ /************************** Helpers ********************************************/
fn get_fields_buffer_size_sum(fields: &[(Member, Type)]) -> (TokenStream, TokenStream) { fn get_fields_buffer_size_sum(fields: &[StructField]) -> (TokenStream, TokenStream) {
let fields: Vec<_> = fields.iter().map(|(m, _)| m).collect(); let fields: Vec<_> = fields
.iter()
.filter(|f| !f.skipped)
.map(|f| &f.member)
.collect();
if fields.len() > 0 { if fields.len() > 0 {
( (
quote! { quote! {
@ -808,4 +865,45 @@ mod tests {
assert_eq!(socket_msg_impl(input).to_string(), expected.to_string()); assert_eq!(socket_msg_impl(input).to_string(), expected.to_string());
} }
#[test]
fn end_to_end_struct_skip_test() {
let input: DeriveInput = parse_quote! {
struct MyMsg {
#[msg_on_socket(skip)]
a: u8,
}
};
let expected = quote! {
impl msg_socket::MsgOnSocket for MyMsg {
fn msg_size(&self) -> usize {
0
}
fn fd_count(&self) -> usize {
0
}
unsafe fn read_from_buffer(
buffer: &[u8],
fds: &[std::os::unix::io::RawFd],
) -> msg_socket::MsgResult<(Self, usize)> {
let mut __offset = 0usize;
let mut __fd_offset = 0usize;
Ok((Self { a: <u8>::default() }, __fd_offset))
}
fn write_to_buffer(
&self,
buffer: &mut [u8],
fds: &mut [std::os::unix::io::RawFd],
) -> msg_socket::MsgResult<usize> {
let mut __offset = 0usize;
let mut __fd_offset = 0usize;
Ok(__fd_offset)
}
}
};
assert_eq!(socket_msg_impl(input).to_string(), expected.to_string());
}
} }