From f5d565d6935251851c03ca30490de6f33cd58ad4 Mon Sep 17 00:00:00 2001 From: Jingkui Wang Date: Thu, 27 Sep 2018 10:41:11 -0700 Subject: [PATCH] crosvm: add msg_sock. MsgSock wraps UnixDatagram and provides simple macro to define Messages that could be send through sock easily. TEST=cargo test BUG=None Change-Id: I296fabc41893ad6a3ec42ef82dd29c3b752be8b8 Reviewed-on: https://chromium-review.googlesource.com/1255548 Commit-Ready: ChromeOS CL Exonerator Bot Tested-by: Jingkui Wang Reviewed-by: Zach Reizner --- Cargo.lock | 19 + Cargo.toml | 1 + msg_socket/Cargo.toml | 9 + msg_socket/msg_on_socket_derive/Cargo.toml | 13 + .../msg_on_socket_derive.rs | 719 ++++++++++++++++++ msg_socket/src/lib.rs | 288 +++++++ msg_socket/src/msg_on_socket.rs | 278 +++++++ 7 files changed, 1327 insertions(+) create mode 100644 msg_socket/Cargo.toml create mode 100644 msg_socket/msg_on_socket_derive/Cargo.toml create mode 100644 msg_socket/msg_on_socket_derive/msg_on_socket_derive.rs create mode 100644 msg_socket/src/lib.rs create mode 100644 msg_socket/src/msg_on_socket.rs diff --git a/Cargo.lock b/Cargo.lock index 8cd8374987..b65fb3acca 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -91,6 +91,7 @@ dependencies = [ "kvm 0.1.0", "kvm_sys 0.1.0", "libc 0.2.40 (registry+https://github.com/rust-lang/crates.io-index)", + "msg_socket 0.1.0", "net_util 0.1.0", "p9 0.1.0", "plugin_proto 0.16.0", @@ -240,6 +241,24 @@ dependencies = [ "cfg-if 0.1.2 (registry+https://github.com/rust-lang/crates.io-index)", ] +[[package]] +name = "msg_on_socket_derive" +version = "0.1.0" +dependencies = [ + "proc-macro2 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)", + "quote 0.4.2 (registry+https://github.com/rust-lang/crates.io-index)", + "syn 0.12.15 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "msg_socket" +version = "0.1.0" +dependencies = [ + "data_model 0.1.0", + "msg_on_socket_derive 0.1.0", + "sys_util 0.1.0", +] + [[package]] name = "net_sys" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index edeb1bbe25..7d1736442c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,6 +24,7 @@ gpu_buffer = { path = "gpu_buffer", optional = true } io_jail = { path = "io_jail" } kvm = { path = "kvm" } kvm_sys = { path = "kvm_sys" } +msg_socket = { path = "msg_socket" } sys_util = { path = "sys_util" } kernel_cmdline = { path = "kernel_cmdline" } kernel_loader = { path = "kernel_loader" } diff --git a/msg_socket/Cargo.toml b/msg_socket/Cargo.toml new file mode 100644 index 0000000000..0ce17b6365 --- /dev/null +++ b/msg_socket/Cargo.toml @@ -0,0 +1,9 @@ +[package] +name = "msg_socket" +version = "0.1.0" +authors = ["The Chromium OS Authors"] + +[dependencies] +msg_on_socket_derive = { path = "msg_on_socket_derive" } +sys_util = { path = "../sys_util" } +data_model = { path = "../data_model" } diff --git a/msg_socket/msg_on_socket_derive/Cargo.toml b/msg_socket/msg_on_socket_derive/Cargo.toml new file mode 100644 index 0000000000..39316ce0fd --- /dev/null +++ b/msg_socket/msg_on_socket_derive/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "msg_on_socket_derive" +version = "0.1.0" +authors = ["The Chromium OS Authors"] + +[dependencies] +syn = "=0.12" +quote = "=0.4" +proc-macro2 = "=0.2" + +[lib] +proc-macro = true +path = "msg_on_socket_derive.rs" diff --git a/msg_socket/msg_on_socket_derive/msg_on_socket_derive.rs b/msg_socket/msg_on_socket_derive/msg_on_socket_derive.rs new file mode 100644 index 0000000000..5576ea66d0 --- /dev/null +++ b/msg_socket/msg_on_socket_derive/msg_on_socket_derive.rs @@ -0,0 +1,719 @@ +// Copyright 2018 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#![recursion_limit = "256"] +extern crate proc_macro; +extern crate proc_macro2; + +#[macro_use] +extern crate quote; + +#[cfg_attr(test, macro_use)] +extern crate syn; + +use std::vec::Vec; + +use proc_macro::TokenStream; +use proc_macro2::Span; +use quote::Tokens; +use syn::{Data, DataEnum, DataStruct, DeriveInput, Fields, Ident}; + +/// The function that derives the recursive implementation for struct or enum. +#[proc_macro_derive(MsgOnSocket)] +pub fn msg_on_socket_derive(input: TokenStream) -> TokenStream { + socket_msg_impl(syn::parse(input).unwrap()).into() +} + +fn socket_msg_impl(ast: DeriveInput) -> Tokens { + if !ast.generics.params.is_empty() { + return quote! { + compile_error!("derive(SocketMsg) does not support generic parameters"); + }; + } + match ast.data { + Data::Struct(ds) => { + if is_named_struct(&ds) { + impl_for_named_struct(ast.ident, ds) + } else { + impl_for_tuple_struct(ast.ident, ds) + } + } + Data::Enum(de) => impl_for_enum(ast.ident, de), + _ => quote! { + compile_error!("derive(SocketMsg) only support struct and enum"); + }, + } +} + +fn is_named_struct(ds: &DataStruct) -> bool { + match &ds.fields { + &Fields::Named(ref _f) => true, + _ => false, + } +} + +/************************** Named Struct Impls ********************************************/ +fn impl_for_named_struct(name: Ident, ds: DataStruct) -> Tokens { + let fields = get_struct_fields(ds); + let fields_types = get_types_from_fields_vec(&fields); + let buffer_sizes_impls = define_buffer_size_for_struct(&fields_types); + + let read_buffer = define_read_buffer_for_struct(&name, &fields); + let write_buffer = define_write_buffer_for_struct(&name, &fields); + quote!( + impl MsgOnSocket for #name { + #buffer_sizes_impls + #read_buffer + #write_buffer + } + ) +} + +fn get_types_from_fields_vec(v: &[(Ident, syn::Type)]) -> Vec { + let mut fields_types = Vec::new(); + for (_i, t) in v { + fields_types.push(t.clone()); + } + fields_types +} + +// Flatten struct fields. +// "myfield : Type" -> \(ident\("myfield"\), Token\(Type\)\) +fn get_struct_fields(ds: DataStruct) -> Vec<(Ident, syn::Type)> { + let fields = match ds.fields { + Fields::Named(fields_named) => fields_named.named, + _ => { + panic!("Struct must have named fields"); + } + }; + let mut vec = Vec::new(); + for field in fields { + let ident = match field.ident { + Some(ident) => ident, + None => panic!("Unknown Error."), + }; + let ty = field.ty; + vec.push((ident, ty)); + } + vec +} + +fn define_buffer_size_for_struct(field_types: &[syn::Type]) -> Tokens { + let (msg_size, max_fd_count) = get_fields_buffer_size_sum(field_types); + quote! { + fn msg_size() -> usize { + #msg_size + } + fn max_fd_count() -> usize { + #max_fd_count + } + } +} + +fn define_read_buffer_for_struct(_name: &Ident, fields: &[(Ident, syn::Type)]) -> Tokens { + let mut read_fields = Vec::new(); + let mut init_fields = Vec::new(); + for f in fields { + let read_field = read_from_buffer_and_move_offset(&f.0, &f.1); + read_fields.push(read_field); + let name = f.0.clone(); + init_fields.push(quote!( #name )); + } + quote!{ + unsafe fn read_from_buffer(buffer: &[u8], fds: &[RawFd]) + -> MsgResult<(Self, usize)> { + let mut __offset = 0usize; + let mut __fd_offset = 0usize; + #(#read_fields)* + Ok(( + Self { + #(#init_fields),* + }, + __fd_offset + )) + } + } +} + +fn define_write_buffer_for_struct(_name: &Ident, fields: &[(Ident, syn::Type)]) -> Tokens { + let mut write_fields = Vec::new(); + for f in fields { + let write_field = write_to_buffer_and_move_offset(&f.0, &f.1); + write_fields.push(write_field); + } + quote!{ + fn write_to_buffer(&self, buffer: &mut [u8], fds: &mut [RawFd]) + -> MsgResult { + let mut __offset = 0usize; + let mut __fd_offset = 0usize; + #(#write_fields)* + Ok(__fd_offset) + } + } +} + +/************************** Enum Impls ********************************************/ +fn impl_for_enum(name: Ident, de: DataEnum) -> Tokens { + let variants = get_enum_variant_types(&de); + let buffer_sizes_impls = define_buffer_size_for_enum(&variants); + + let read_buffer = define_read_buffer_for_enum(&name, &de); + let write_buffer = define_write_buffer_for_enum(&name, &de); + quote!( + impl MsgOnSocket for #name { + #buffer_sizes_impls + #read_buffer + #write_buffer + } + ) +} + +fn define_buffer_size_for_enum(variants: &[(Ident, Vec)]) -> Tokens { + let mut variant_buffer_sizes = Vec::new(); + let mut variant_fd_sizes = Vec::new(); + for v in variants { + let (msg_size_impl, fd_count_impl) = get_fields_buffer_size_sum(&v.1); + variant_buffer_sizes.push(msg_size_impl); + variant_fd_sizes.push(fd_count_impl); + } + quote! { + fn msg_size() -> usize { + // First byte is used for variant. + [#(#variant_buffer_sizes,)*].iter().max().unwrap().clone() as usize + 1 + } + fn max_fd_count() -> usize { + [#(#variant_fd_sizes,)*].iter().max().unwrap().clone() as usize + } + } +} + +// Flatten enum variants. Return value = \[variant_name, \[types_of_this_variant\]\] +fn get_enum_variant_types(de: &DataEnum) -> Vec<(Ident, Vec)> { + let mut variants = Vec::new(); + let de = de.clone(); + for v in de.variants { + let name = v.ident; + match v.fields { + Fields::Unnamed(fields) => { + let mut vec = Vec::new(); + for field in fields.unnamed { + let ty = field.ty; + vec.push(ty); + } + variants.push((name, vec)); + } + Fields::Unit => { + variants.push((name, Vec::new())); + continue; + } + Fields::Named(fields) => { + let mut vec = Vec::new(); + for field in fields.named { + let ty = field.ty; + vec.push(ty); + } + variants.push((name, vec)); + } + }; + } + variants +} + +fn define_read_buffer_for_enum(name: &Ident, de: &DataEnum) -> Tokens { + let mut match_variants = Vec::new(); + let de = de.clone(); + let mut i = 0u8; + for v in de.variants { + let variant_name = v.ident; + match v.fields { + Fields::Named(fields) => { + let mut tmp_names = Vec::new(); + let mut read_tmps = Vec::new(); + for f in fields.named { + tmp_names.push(f.ident.clone()); + let read_tmp = read_from_buffer_and_move_offset(&f.ident.unwrap(), &f.ty); + read_tmps.push(read_tmp); + } + let v = quote!( + #i => { + let mut __offset = 1usize; + let mut __fd_offset = 0usize; + #(#read_tmps)* + Ok((#name::#variant_name{ #(#tmp_names),*}, __fd_offset)) + } + ); + match_variants.push(v); + } + Fields::Unnamed(fields) => { + let mut tmp_names = Vec::new(); + let mut read_tmps = Vec::new(); + let mut j = 0usize; + for f in fields.unnamed { + let tmp_name = format!("enum_variant_tmp{}", j); + let tmp_name = Ident::new(&tmp_name, Span::call_site()); + tmp_names.push(tmp_name.clone()); + let read_tmp = read_from_buffer_and_move_offset(&tmp_name, &f.ty); + read_tmps.push(read_tmp); + j += 1; + } + + let v = quote!( + #i => { + let mut __offset = 1usize; + let mut __fd_offset = 0usize; + #(#read_tmps)* + Ok((#name::#variant_name( #(#tmp_names),*), __fd_offset)) + } + ); + match_variants.push(v); + } + Fields::Unit => { + let v = quote!( + #i => Ok((#name::#variant_name, 0)), + ); + match_variants.push(v); + } + } + i += 1; + } + quote!( + unsafe fn read_from_buffer(buffer: &[u8], fds: &[RawFd]) + -> MsgResult<(Self, usize)> { + let v = buffer[0]; + match v { + #(#match_variants)* + _ => { + Err(MsgError::InvalidType) + } + } + } + ) +} + +fn define_write_buffer_for_enum(name: &Ident, de: &DataEnum) -> Tokens { + let mut match_variants = Vec::new(); + let mut i = 0u8; + let de = de.clone(); + for v in de.variants { + let variant_name = v.ident; + match v.fields { + Fields::Named(fields) => { + let mut tmp_names = Vec::new(); + let mut write_tmps = Vec::new(); + for f in fields.named { + tmp_names.push(f.ident.unwrap().clone()); + let write_tmp = enum_write_to_buffer_and_move_offset(&f.ident.unwrap(), &f.ty); + write_tmps.push(write_tmp); + } + + let v = quote!( + #name::#variant_name{#(#tmp_names),*} => { + buffer[0] = #i; + let mut __offset = 1usize; + let mut __fd_offset = 0usize; + #(#write_tmps)* + Ok(__fd_offset) + } + ); + match_variants.push(v); + } + Fields::Unnamed(fields) => { + let mut tmp_names = Vec::new(); + let mut write_tmps = Vec::new(); + let mut j = 0usize; + for f in fields.unnamed { + let tmp_name = format!("enum_variant_tmp{}", j); + let tmp_name = Ident::new(&tmp_name, Span::call_site()); + tmp_names.push(tmp_name.clone()); + let write_tmp = enum_write_to_buffer_and_move_offset(&tmp_name, &f.ty); + write_tmps.push(write_tmp); + j += 1; + } + + let v = quote!( + #name::#variant_name(#(#tmp_names),*) => { + buffer[0] = #i; + let mut __offset = 1usize; + let mut __fd_offset = 0usize; + #(#write_tmps)* + Ok(__fd_offset) + } + ); + match_variants.push(v); + } + Fields::Unit => { + let v = quote!( + #name::#variant_name => { + buffer[0] = #i; + Ok(0) + } + ); + match_variants.push(v); + } + } + i += 1; + } + + quote!( + fn write_to_buffer(&self, buffer: &mut [u8], fds: &mut [RawFd]) -> MsgResult { + match self { + #(#match_variants)* + } + } + ) +} + +fn enum_write_to_buffer_and_move_offset(name: &Ident, ty: &syn::Type) -> Tokens { + quote!{ + let o = #name.write_to_buffer(&mut buffer[__offset..], &mut fds[__fd_offset..])?; + __offset += <#ty>::msg_size(); + __fd_offset += o; + } +} + +/************************** Tuple Impls ********************************************/ +fn impl_for_tuple_struct(name: Ident, ds: DataStruct) -> Tokens { + let types = get_tuple_types(ds); + + let buffer_sizes_impls = define_buffer_size_for_struct(&types); + + let read_buffer = define_read_buffer_for_tuples(&name, &types); + let write_buffer = define_write_buffer_for_tuples(&name, &types); + quote!( + impl MsgOnSocket for #name { + #buffer_sizes_impls + #read_buffer + #write_buffer + } + ) +} + +fn get_tuple_types(ds: DataStruct) -> Vec { + let mut types = Vec::new(); + let fields = match ds.fields { + Fields::Unnamed(fields_unnamed) => fields_unnamed.unnamed, + _ => { + panic!("Tuple struct must have unnamed fields."); + } + }; + for field in fields { + let ty = field.ty; + types.push(ty); + } + types +} + +fn define_read_buffer_for_tuples(name: &Ident, fields: &[syn::Type]) -> Tokens { + let mut read_fields = Vec::new(); + let mut init_fields = Vec::new(); + for i in 0..fields.len() { + let tmp_name = format!("tuple_tmp{}", i); + let tmp_name = Ident::new(&tmp_name, Span::call_site()); + let read_field = read_from_buffer_and_move_offset(&tmp_name, &fields[i]); + read_fields.push(read_field); + init_fields.push(quote!( #tmp_name )); + } + + quote!{ + unsafe fn read_from_buffer(buffer: &[u8], fds: &[RawFd]) + -> MsgResult<(Self, usize)> { + let mut __offset = 0usize; + let mut __fd_offset = 0usize; + #(#read_fields)* + Ok(( + #name ( + #(#init_fields),* + ), + __fd_offset + )) + } + } +} + +fn define_write_buffer_for_tuples(name: &Ident, fields: &[syn::Type]) -> Tokens { + let mut write_fields = Vec::new(); + let mut tmp_names = Vec::new(); + for i in 0..fields.len() { + let tmp_name = format!("tuple_tmp{}", i); + let tmp_name = Ident::new(&tmp_name, Span::call_site()); + let write_field = enum_write_to_buffer_and_move_offset(&tmp_name, &fields[i]); + write_fields.push(write_field); + tmp_names.push(tmp_name); + } + quote!{ + fn write_to_buffer(&self, buffer: &mut [u8], fds: &mut [RawFd]) + -> MsgResult { + let mut __offset = 0usize; + let mut __fd_offset = 0usize; + let #name( #(#tmp_names),* ) = self; + #(#write_fields)* + Ok(__fd_offset) + } + } +} +/************************** Helpers ********************************************/ +fn get_fields_buffer_size_sum(field_types: &[syn::Type]) -> (Tokens, Tokens) { + if field_types.len() > 0 { + ( + quote!( + #( <#field_types>::msg_size() as usize )+* + ), + quote!( + #( <#field_types>::max_fd_count() as usize )+* + ), + ) + } else { + (quote!(0), quote!(0)) + } +} + +fn read_from_buffer_and_move_offset(name: &Ident, ty: &syn::Type) -> Tokens { + quote!{ + let t = <#ty>::read_from_buffer(&buffer[__offset..], &fds[__fd_offset..])?; + __offset += <#ty>::msg_size(); + __fd_offset += t.1; + let #name = t.0; + } +} + +fn write_to_buffer_and_move_offset(name: &Ident, ty: &syn::Type) -> Tokens { + quote!{ + let o = self.#name.write_to_buffer(&mut buffer[__offset..], &mut fds[__fd_offset..])?; + __offset += <#ty>::msg_size(); + __fd_offset += o; + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn end_to_end_struct_test() { + let input: DeriveInput = parse_quote! { + struct MyMsg { + a: u8, + b: RawFd, + c: u32, + } + }; + + let expected = quote! { + impl MsgOnSocket for MyMsg { + fn msg_size() -> usize { + ::msg_size() as usize + + ::msg_size() as usize + + ::msg_size() as usize + } + fn max_fd_count() -> usize { + ::max_fd_count() as usize + + ::max_fd_count() as usize + + ::max_fd_count() as usize + } + unsafe fn read_from_buffer(buffer: &[u8], fds: &[RawFd]) + -> MsgResult<(Self, usize)> { + let mut __offset = 0usize; + let mut __fd_offset = 0usize; + let t = ::read_from_buffer(&buffer[__offset..], &fds[__fd_offset..])?; + __offset += ::msg_size(); + __fd_offset += t.1; + let a = t.0; + let t = ::read_from_buffer(&buffer[__offset..], &fds[__fd_offset..])?; + __offset += ::msg_size(); + __fd_offset += t.1; + let b = t.0; + let t = ::read_from_buffer(&buffer[__offset..], &fds[__fd_offset..])?; + __offset += ::msg_size(); + __fd_offset += t.1; + let c = t.0; + Ok((Self { a, b, c }, __fd_offset)) + } + fn write_to_buffer(&self, buffer: &mut [u8], fds: &mut [RawFd]) + -> MsgResult { + let mut __offset = 0usize; + let mut __fd_offset = 0usize; + let o = self.a.write_to_buffer(&mut buffer[__offset..], + &mut fds[__fd_offset..])?; + __offset += ::msg_size(); + __fd_offset += o; + let o = self.b.write_to_buffer(&mut buffer[__offset..], + &mut fds[__fd_offset..])?; + __offset += ::msg_size(); + __fd_offset += o; + let o = self.c.write_to_buffer(&mut buffer[__offset..], + &mut fds[__fd_offset..])?; + __offset += ::msg_size(); + __fd_offset += o; + Ok(__fd_offset) + } + } + }; + + assert_eq!(socket_msg_impl(input), expected); + } + + #[test] + fn end_to_end_tuple_struct_test() { + let input: DeriveInput = parse_quote! { + struct MyMsg(u8, u32, File); + }; + + let expected = quote! { + impl MsgOnSocket for MyMsg { + fn msg_size() -> usize { + ::msg_size() as usize + + ::msg_size() as usize + ::msg_size() as usize + } + fn max_fd_count() -> usize { + ::max_fd_count() as usize + + ::max_fd_count() as usize + + ::max_fd_count() as usize + } + unsafe fn read_from_buffer(buffer: &[u8], fds: &[RawFd]) + -> MsgResult<(Self, usize)> { + let mut __offset = 0usize; + let mut __fd_offset = 0usize; + let t = ::read_from_buffer(&buffer[__offset..], &fds[__fd_offset..])?; + __offset += ::msg_size(); + __fd_offset += t.1; + let tuple_tmp0 = t.0; + let t = ::read_from_buffer(&buffer[__offset..], &fds[__fd_offset..])?; + __offset += ::msg_size(); + __fd_offset += t.1; + let tuple_tmp1 = t.0; + let t = ::read_from_buffer(&buffer[__offset..], &fds[__fd_offset..])?; + __offset += ::msg_size(); + __fd_offset += t.1; + let tuple_tmp2 = t.0; + Ok((MyMsg(tuple_tmp0, tuple_tmp1, tuple_tmp2), __fd_offset)) + } + fn write_to_buffer(&self, buffer: &mut [u8], fds: &mut [RawFd]) + -> MsgResult { + let mut __offset = 0usize; + let mut __fd_offset = 0usize; + let MyMsg(tuple_tmp0, tuple_tmp1, tuple_tmp2) = self; + let o = tuple_tmp0. + write_to_buffer(&mut buffer[__offset..], &mut fds[__fd_offset..])?; + __offset += ::msg_size(); + __fd_offset += o; + let o = tuple_tmp1. + write_to_buffer(&mut buffer[__offset..], &mut fds[__fd_offset..])?; + __offset += ::msg_size(); + __fd_offset += o; + let o = tuple_tmp2. + write_to_buffer(&mut buffer[__offset..], &mut fds[__fd_offset..])?; + __offset += ::msg_size(); + __fd_offset += o; + Ok(__fd_offset) + } + } + }; + + assert_eq!(socket_msg_impl(input), expected); + } + + #[test] + fn end_to_end_enum_test() { + let input: DeriveInput = parse_quote! { + enum MyMsg { + A(u8), + B, + C{f0: u8, f1: RawFd}, + } + }; + + let expected = quote! { + impl MsgOnSocket for MyMsg { + fn msg_size() -> usize { + [ + ::msg_size() as usize, + 0, + ::msg_size() as usize + ::msg_size() as usize, + ].iter() + .max().unwrap().clone() as usize+ 1 + } + fn max_fd_count() -> usize { + [ + ::max_fd_count() as usize, + 0, + ::max_fd_count() as usize + ::max_fd_count() as usize, + ].iter() + .max().unwrap().clone() as usize + } + unsafe fn read_from_buffer(buffer: &[u8], fds: &[RawFd]) -> + MsgResult<(Self, usize)> { + let v = buffer[0]; + match v { + 0u8 => { + let mut __offset = 1usize; + let mut __fd_offset = 0usize; + let t = ::read_from_buffer(&buffer[__offset..], + &fds[__fd_offset..])?; + __offset += ::msg_size(); + __fd_offset += t.1; + let enum_variant_tmp0 = t.0; + Ok((MyMsg::A(enum_variant_tmp0), __fd_offset)) + } + 1u8 => Ok((MyMsg::B, 0)), + 2u8 => { + let mut __offset = 1usize; + let mut __fd_offset = 0usize; + let t = ::read_from_buffer(&buffer[__offset..], + &fds[__fd_offset..])?; + __offset += ::msg_size(); + __fd_offset += t.1; + let f0 = t.0; + let t = ::read_from_buffer(&buffer[__offset..], + &fds[__fd_offset..])?; + __offset += ::msg_size(); + __fd_offset += t.1; + let f1 = t.0; + Ok((MyMsg::C{f0, f1}, __fd_offset)) + } + _ => { + Err(MsgError::InvalidType) + } + } + } + fn write_to_buffer(&self, + buffer: &mut [u8], + fds: &mut [RawFd]) -> MsgResult { + match self { + MyMsg::A(enum_variant_tmp0) => { + buffer[0] = 0u8; + let mut __offset = 1usize; + let mut __fd_offset = 0usize; + let o = enum_variant_tmp0. + write_to_buffer(&mut buffer[__offset..], &mut fds[__fd_offset..])?; + __offset += ::msg_size(); + __fd_offset += o; + Ok(__fd_offset) + } + MyMsg::B => { + buffer[0] = 1u8; + Ok(0) + } + MyMsg::C{f0, f1} => { + buffer[0] = 2u8; + let mut __offset = 1usize; + let mut __fd_offset = 0usize; + let o = f0. + write_to_buffer(&mut buffer[__offset..], &mut fds[__fd_offset..])?; + __offset += ::msg_size(); + __fd_offset += o; + let o = f1.write_to_buffer( + &mut buffer[__offset..], &mut fds[__fd_offset..])?; + __offset += ::msg_size(); + __fd_offset += o; + Ok(__fd_offset) + } + } + } + } + + }; + + assert_eq!(socket_msg_impl(input), expected); + } +} diff --git a/msg_socket/src/lib.rs b/msg_socket/src/lib.rs new file mode 100644 index 0000000000..f868b19dc3 --- /dev/null +++ b/msg_socket/src/lib.rs @@ -0,0 +1,288 @@ +// Copyright 2018 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#[allow(unused_imports)] +#[macro_use] +extern crate msg_on_socket_derive; +extern crate data_model; +extern crate sys_util; + +mod msg_on_socket; + +use std::marker::PhantomData; +use std::os::unix::io::RawFd; +use std::os::unix::net::UnixDatagram; +use sys_util::{ScmSocket, UnlinkUnixDatagram}; + +pub use msg_on_socket::*; +pub use msg_on_socket_derive::*; + +/// Create a pair of socket. Request is send in one direction while response is in the other +/// direction. +pub fn pair( +) -> Option<(MsgSocket, MsgSocket)> { + let (sock1, sock2) = match UnixDatagram::pair() { + Ok((sock1, sock2)) => (sock1, sock2), + _ => { + return None; + } + }; + let requester = MsgSocket { + sock: sock1, + _i: PhantomData, + _o: PhantomData, + }; + let responder = MsgSocket { + sock: sock2, + _i: PhantomData, + _o: PhantomData, + }; + Some((requester, responder)) +} + +/// Bidirection sock that support both send and recv. +pub struct MsgSocket { + sock: UnixDatagram, + _i: PhantomData, + _o: PhantomData, +} + +impl MsgSocket { + // Create a new MsgSocket. + pub fn new(s: UnixDatagram) -> MsgSocket { + MsgSocket { + sock: s, + _i: PhantomData, + _o: PhantomData, + } + } +} + +/// Bidirection sock that support both send and recv. +pub struct UnlinkMsgSocket { + sock: UnlinkUnixDatagram, + _i: PhantomData, + _o: PhantomData, +} + +impl UnlinkMsgSocket { + // Create a new MsgSocket. + pub fn new(s: UnlinkUnixDatagram) -> UnlinkMsgSocket { + UnlinkMsgSocket { + sock: s, + _i: PhantomData, + _o: PhantomData, + } + } +} + +/// One direction socket that only supports sending. +pub struct Sender { + sock: UnixDatagram, + _m: PhantomData, +} + +impl Sender { + /// Create a new sender sock. + pub fn new(s: UnixDatagram) -> Sender { + Sender { + sock: s, + _m: PhantomData, + } + } +} + +/// One direction socket that only supports receiving. +pub struct Receiver { + sock: UnixDatagram, + _m: PhantomData, +} + +impl Receiver { + /// Create a new receiver sock. + pub fn new(s: UnixDatagram) -> Receiver { + Receiver { + sock: s, + _m: PhantomData, + } + } +} + +impl AsRef for MsgSocket { + fn as_ref(&self) -> &UnixDatagram { + &self.sock + } +} + +impl AsRef for UnlinkMsgSocket { + fn as_ref(&self) -> &UnixDatagram { + self.sock.as_ref() + } +} + +impl AsRef for Sender { + fn as_ref(&self) -> &UnixDatagram { + &self.sock + } +} + +impl AsRef for Receiver { + fn as_ref(&self) -> &UnixDatagram { + &self.sock + } +} + +/// Types that could send a message. +pub trait MsgSender: AsRef { + fn send(&self, msg: &M) -> MsgResult<()> { + let msg_size = M::msg_size(); + let fd_size = M::max_fd_count(); + let mut msg_buffer: Vec = vec![0; msg_size]; + let mut fd_buffer: Vec = vec![0; fd_size]; + + let fd_size = msg.write_to_buffer(&mut msg_buffer, &mut fd_buffer)?; + let sock: &UnixDatagram = self.as_ref(); + sock.send_with_fds(&msg_buffer[..], &fd_buffer[0..fd_size]) + .map_err(|e| MsgError::Send(e))?; + Ok(()) + } +} + +/// Types that could receive a message. +pub trait MsgReceiver: AsRef { + fn recv(&self) -> MsgResult { + let msg_size = M::msg_size(); + let fd_size = M::max_fd_count(); + let mut msg_buffer: Vec = vec![0; msg_size]; + let mut fd_buffer: Vec = vec![0; fd_size]; + + let sock: &UnixDatagram = self.as_ref(); + let (recv_msg_size, recv_fd_size) = sock + .recv_with_fds(&mut msg_buffer, &mut fd_buffer) + .map_err(|e| MsgError::Recv(e))?; + if msg_size != recv_msg_size { + return Err(MsgError::BadRecvSize(msg_size)); + } + // Safe because fd buffer is read from socket. + let (v, read_fd_size) = unsafe { + M::read_from_buffer(&msg_buffer[0..recv_msg_size], &fd_buffer[0..recv_fd_size])? + }; + if recv_fd_size != read_fd_size { + return Err(MsgError::NotExpectFd); + } + Ok(v) + } +} + +impl MsgSender for MsgSocket {} +impl MsgReceiver for MsgSocket {} + +impl MsgSender for UnlinkMsgSocket {} +impl MsgReceiver for UnlinkMsgSocket {} + +impl MsgSender for Sender {} +impl MsgReceiver for Receiver {} + +#[cfg(test)] +mod tests { + use super::*; + use sys_util::EventFd; + + #[derive(MsgOnSocket)] + struct Request { + field0: u8, + field1: EventFd, + field2: u32, + } + + #[derive(MsgOnSocket)] + enum Response { + A(u8), + B, + C(u32, EventFd), + D([u8; 4]), + E { f0: u8, f1: u32 }, + } + + #[derive(MsgOnSocket)] + struct Message(u8, u16, EventFd); + + #[test] + fn sock_send_recv_struct() { + let (req, res) = pair::().unwrap(); + let e0 = EventFd::new().unwrap(); + let e1 = e0.try_clone().unwrap(); + req.send(&Request { + field0: 2, + field1: e0, + field2: 0xf0f0, + }).unwrap(); + let r = res.recv().unwrap(); + assert_eq!(r.field0, 2); + assert_eq!(r.field2, 0xf0f0); + r.field1.write(0x0f0f).unwrap(); + assert_eq!(e1.read().unwrap(), 0x0f0f); + } + + #[test] + fn sock_send_recv_enum() { + let (req, res) = pair::().unwrap(); + let e0 = EventFd::new().unwrap(); + let e1 = e0.try_clone().unwrap(); + res.send(&Response::C(0xf0f0, e0)).unwrap(); + let r = req.recv().unwrap(); + match r { + Response::C(v, efd) => { + assert_eq!(v, 0xf0f0); + efd.write(0x0f0f).unwrap(); + } + _ => panic!("wrong type"), + }; + assert_eq!(e1.read().unwrap(), 0x0f0f); + + res.send(&Response::B).unwrap(); + match req.recv().unwrap() { + Response::B => {} + _ => panic!("Wrong enum type"), + }; + + res.send(&Response::A(0x3)).unwrap(); + match req.recv().unwrap() { + Response::A(v) => assert_eq!(v, 0x3), + _ => panic!("Wrong enum type"), + }; + + res.send(&Response::D([0, 1, 2, 3])).unwrap(); + match req.recv().unwrap() { + Response::D(v) => assert_eq!(v, [0, 1, 2, 3]), + _ => panic!("Wrong enum type"), + }; + + res.send(&Response::E { + f0: 0x12, + f1: 0x0f0f, + }).unwrap(); + match req.recv().unwrap() { + Response::E { f0, f1 } => { + assert_eq!(f0, 0x12); + assert_eq!(f1, 0x0f0f); + } + _ => panic!("Wrong enum type"), + }; + } + + #[test] + fn sock_send_recv_tuple() { + let (req, res) = pair::().unwrap(); + let e0 = EventFd::new().unwrap(); + let e1 = e0.try_clone().unwrap(); + req.send(&Message(1, 0x12, e0)).unwrap(); + let r = res.recv().unwrap(); + assert_eq!(r.0, 1); + assert_eq!(r.1, 0x12); + r.2.write(0x0f0f).unwrap(); + assert_eq!(e1.read().unwrap(), 0x0f0f); + } + +} diff --git a/msg_socket/src/msg_on_socket.rs b/msg_socket/src/msg_on_socket.rs new file mode 100644 index 0000000000..aca57e23a4 --- /dev/null +++ b/msg_socket/src/msg_on_socket.rs @@ -0,0 +1,278 @@ +// Copyright 2018 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +use data_model::*; +use std; +use std::os::unix::io::{AsRawFd, FromRawFd, RawFd}; +use std::result; +use sys_util::{Error as SysError, EventFd}; + +use std::fs::File; +use std::net::{TcpListener, TcpStream, UdpSocket}; +use std::os::unix::net::{UnixDatagram, UnixListener, UnixStream}; + +#[derive(Debug, PartialEq)] +/// An error during transaction or serialization/deserialization. +pub enum MsgError { + /// Error while sending a request or response. + Send(SysError), + /// Error while receiving a request or response. + Recv(SysError), + /// The type of a received request or response is unknown. + InvalidType, + /// There was not the expected amount of data when receiving a message. The inner + /// value is how much data is needed. + BadRecvSize(usize), + /// There was no associated file descriptor received for a request that expected it. + ExpectFd, + /// There was some associated file descriptor received but not used when deserialize. + NotExpectFd, + /// Trying to serialize/deserialize, but fd buffer size is too small. This typically happens + /// when max_fd_count() returns a value that is too small. + WrongFdBufferSize, + /// Trying to serialize/deserialize, but msg buffer size is too small. This typically happens + /// when msg_size() returns a value that is too small. + WrongMsgBufferSize, +} + +pub type MsgResult = result::Result; + +/// A msg that could be serialized to and deserialize from array in little endian. +/// +/// For structs, we always have fixed size of bytes and fixed count of fds. +/// For enums, the size needed might be different for each variant. +/// +/// e.g. +/// ``` +/// use std::os::unix::io::RawFd; +/// enum Message { +/// VariantA(u8), +/// VariantB(u32, RawFd), +/// VariantC, +/// } +/// ``` +/// +/// For variant A, we need 1 byte to store its inner value. +/// For variant B, we need 4 bytes and 1 RawFd to store its inner value. +/// For variant C, we need 0 bytes to store its inner value. +/// When we serialize Message to (buffer, fd_buffer), we always use fixed number of bytes in +/// the buffer. Unused buffer bytes will be padded with zero. +/// However, for fd_buffer, we could not do the same thing. Otherwise, we are essentially sending +/// fd 0 through the socket. +/// Thus, read/write functions always the return correct count of fds in this variant. There will be +/// no padding in fd_buffer. +pub trait MsgOnSocket: Sized { + /// Size of message in bytes. + fn msg_size() -> usize; + /// Max possible fd count in this type. + fn max_fd_count() -> usize { + 0 + } + /// Returns (self, fd read count). + /// This function is safe only when: + /// 0. fds contains valid fds, received from socket, serialized by Self::write_to_buffer. + /// 1. For enum, fds contains correct fd layout of the particular variant. + /// 2. write_to_buffer is implemented correctly(put valid fds into the buffer, has no padding, + /// return correct count). + unsafe fn read_from_buffer(buffer: &[u8], fds: &[RawFd]) -> MsgResult<(Self, usize)>; + /// Serialize self to buffers. + fn write_to_buffer(&self, buffer: &mut [u8], fds: &mut [RawFd]) -> MsgResult; +} + +impl MsgOnSocket for SysError { + fn msg_size() -> usize { + u32::msg_size() + } + unsafe fn read_from_buffer(buffer: &[u8], fds: &[RawFd]) -> MsgResult<(Self, usize)> { + let (v, size) = u32::read_from_buffer(buffer, fds)?; + Ok((SysError::new(v as i32), size)) + } + fn write_to_buffer(&self, buffer: &mut [u8], fds: &mut [RawFd]) -> MsgResult { + let v = self.errno() as u32; + v.write_to_buffer(buffer, fds) + } +} + +impl MsgOnSocket for RawFd { + fn msg_size() -> usize { + 0 + } + fn max_fd_count() -> usize { + 1 + } + unsafe fn read_from_buffer(_buffer: &[u8], fds: &[RawFd]) -> MsgResult<(Self, usize)> { + if fds.len() < 1 { + return Err(MsgError::ExpectFd); + } + Ok((fds[0], 1)) + } + fn write_to_buffer(&self, _buffer: &mut [u8], fds: &mut [RawFd]) -> MsgResult { + if fds.len() < 1 { + return Err(MsgError::WrongFdBufferSize); + } + fds[0] = self.clone(); + Ok(1) + } +} + +macro_rules! rawfd_impl { + ($type:ident) => { + impl MsgOnSocket for $type { + fn msg_size() -> usize { + 0 + } + fn max_fd_count() -> usize { + 1 + } + unsafe fn read_from_buffer(_buffer: &[u8], fds: &[RawFd]) -> MsgResult<(Self, usize)> { + if fds.len() < 1 { + return Err(MsgError::ExpectFd); + } + Ok(($type::from_raw_fd(fds[0].clone()), 1)) + } + fn write_to_buffer(&self, _buffer: &mut [u8], fds: &mut [RawFd]) -> MsgResult { + if fds.len() < 1 { + return Err(MsgError::WrongFdBufferSize); + } + fds[0] = self.as_raw_fd(); + Ok(1) + } + } + }; +} + +rawfd_impl!(EventFd); +rawfd_impl!(File); +rawfd_impl!(UnixStream); +rawfd_impl!(TcpStream); +rawfd_impl!(TcpListener); +rawfd_impl!(UdpSocket); +rawfd_impl!(UnixListener); +rawfd_impl!(UnixDatagram); + +// usize could be different sizes on different targets. We always use u64. +impl MsgOnSocket for usize { + fn msg_size() -> usize { + std::mem::size_of::() + } + unsafe fn read_from_buffer(buffer: &[u8], _fds: &[RawFd]) -> MsgResult<(Self, usize)> { + if buffer.len() < std::mem::size_of::() { + return Err(MsgError::WrongMsgBufferSize); + } + let t: u64 = Le64::from_slice(&buffer[0..Self::msg_size()]) + .unwrap() + .clone() + .into(); + Ok((t as usize, 0)) + } + + fn write_to_buffer(&self, buffer: &mut [u8], _fds: &mut [RawFd]) -> MsgResult { + if buffer.len() < std::mem::size_of::() { + return Err(MsgError::WrongMsgBufferSize); + } + let t: Le64 = (*self as u64).into(); + buffer[0..Self::msg_size()].copy_from_slice(t.as_slice()); + Ok(0) + } +} + +macro_rules! le_impl { + ($type:ident, $le_type:ident) => { + impl MsgOnSocket for $type { + fn msg_size() -> usize { + std::mem::size_of::<$le_type>() + } + unsafe fn read_from_buffer(buffer: &[u8], _fds: &[RawFd]) -> MsgResult<(Self, usize)> { + if buffer.len() < std::mem::size_of::<$le_type>() { + return Err(MsgError::WrongMsgBufferSize); + } + let t = $le_type::from_slice(&buffer[0..Self::msg_size()]) + .unwrap() + .clone(); + Ok((t.into(), 0)) + } + + fn write_to_buffer(&self, buffer: &mut [u8], _fds: &mut [RawFd]) -> MsgResult { + if buffer.len() < std::mem::size_of::<$le_type>() { + return Err(MsgError::WrongMsgBufferSize); + } + let t: $le_type = self.clone().into(); + buffer[0..Self::msg_size()].copy_from_slice(t.as_slice()); + Ok(0) + } + } + }; +} + +le_impl!(u8, u8); +le_impl!(u16, Le16); +le_impl!(u32, Le32); +le_impl!(u64, Le64); + +le_impl!(Le16, Le16); +le_impl!(Le32, Le32); +le_impl!(Le64, Le64); + +macro_rules! array_impls { + ($N:expr, $t: ident $($ts:ident)*) + => { + impl MsgOnSocket for [T; $N] { + fn msg_size() -> usize { + T::msg_size() * $N + } + fn max_fd_count() -> usize { + T::max_fd_count() * $N + } + unsafe fn read_from_buffer(buffer: &[u8], fds: &[RawFd]) -> MsgResult<(Self, usize)> { + if buffer.len() < Self::msg_size() { + return Err(MsgError::WrongMsgBufferSize); + } + let mut offset = 0usize; + let mut fd_offset = 0usize; + let ($t, fd_size) = + T::read_from_buffer(&buffer[offset..], &fds[fd_offset..])?; + offset += T::msg_size(); + fd_offset += fd_size; + $( + let ($ts, fd_size) = + T::read_from_buffer(&buffer[offset..], &fds[fd_offset..])?; + offset += T::msg_size(); + fd_offset += fd_size; + )* + assert_eq!(offset, Self::msg_size()); + Ok(([$t, $($ts),*], fd_offset)) + } + + fn write_to_buffer( + &self, + buffer: &mut [u8], + fds: &mut [RawFd], + ) -> MsgResult { + if buffer.len() < Self::msg_size() { + return Err(MsgError::WrongMsgBufferSize); + } + let mut offset = 0usize; + let mut fd_offset = 0usize; + for idx in 0..$N { + let fd_size = self[idx].clone().write_to_buffer(&mut buffer[offset..], + &mut fds[fd_offset..])?; + offset += T::msg_size(); + fd_offset += fd_size; + } + + Ok(fd_offset) + } + } + array_impls!(($N - 1), $($ts)*); + }; + {$N:expr, } => {}; +} + +array_impls! { + 32, tmp1 tmp2 tmp3 tmp4 tmp5 tmp6 tmp7 tmp8 tmp9 tmp10 tmp11 tmp12 tmp13 tmp14 tmp15 tmp16 + tmp17 tmp18 tmp19 tmp20 tmp21 tmp22 tmp23 tmp24 tmp25 tmp26 tmp27 tmp28 tmp29 tmp30 tmp31 + tmp32 +} + +// TODO(jkwang) Define MsgOnSocket for tuple?