diff --git a/msg_socket/msg_on_socket_derive/msg_on_socket_derive.rs b/msg_socket/msg_on_socket_derive/msg_on_socket_derive.rs index 5576ea66d0..a589a5a94e 100644 --- a/msg_socket/msg_on_socket_derive/msg_on_socket_derive.rs +++ b/msg_socket/msg_on_socket_derive/msg_on_socket_derive.rs @@ -22,7 +22,23 @@ use syn::{Data, DataEnum, DataStruct, DeriveInput, Fields, Ident}; /// The function that derives the recursive implementation for struct or enum. #[proc_macro_derive(MsgOnSocket)] pub fn msg_on_socket_derive(input: TokenStream) -> TokenStream { - socket_msg_impl(syn::parse(input).unwrap()).into() + let ast: DeriveInput = syn::parse(input).unwrap(); + + let const_namespace = Ident::new( + &format!("__MSG_ON_SOCKET_IMPL_{}", ast.ident), + Span::call_site(), + ); + + let impl_for_input = socket_msg_impl(ast); + + let wrapped_impl = quote! { + const #const_namespace: () = { + extern crate msg_socket as _msg_socket; + #impl_for_input + }; + }; + + wrapped_impl.into() } fn socket_msg_impl(ast: DeriveInput) -> Tokens { @@ -62,7 +78,7 @@ fn impl_for_named_struct(name: Ident, ds: DataStruct) -> Tokens { let read_buffer = define_read_buffer_for_struct(&name, &fields); let write_buffer = define_write_buffer_for_struct(&name, &fields); quote!( - impl MsgOnSocket for #name { + impl _msg_socket::MsgOnSocket for #name { #buffer_sizes_impls #read_buffer #write_buffer @@ -121,8 +137,8 @@ fn define_read_buffer_for_struct(_name: &Ident, fields: &[(Ident, syn::Type)]) - init_fields.push(quote!( #name )); } quote!{ - unsafe fn read_from_buffer(buffer: &[u8], fds: &[RawFd]) - -> MsgResult<(Self, usize)> { + unsafe fn read_from_buffer(buffer: &[u8], fds: &[std::os::unix::io::RawFd]) + -> _msg_socket::MsgResult<(Self, usize)> { let mut __offset = 0usize; let mut __fd_offset = 0usize; #(#read_fields)* @@ -143,8 +159,8 @@ fn define_write_buffer_for_struct(_name: &Ident, fields: &[(Ident, syn::Type)]) write_fields.push(write_field); } quote!{ - fn write_to_buffer(&self, buffer: &mut [u8], fds: &mut [RawFd]) - -> MsgResult { + fn write_to_buffer(&self, buffer: &mut [u8], fds: &mut [std::os::unix::io::RawFd]) + -> _msg_socket::MsgResult { let mut __offset = 0usize; let mut __fd_offset = 0usize; #(#write_fields)* @@ -161,7 +177,7 @@ fn impl_for_enum(name: Ident, de: DataEnum) -> Tokens { let read_buffer = define_read_buffer_for_enum(&name, &de); let write_buffer = define_write_buffer_for_enum(&name, &de); quote!( - impl MsgOnSocket for #name { + impl _msg_socket::MsgOnSocket for #name { #buffer_sizes_impls #read_buffer #write_buffer @@ -278,13 +294,13 @@ fn define_read_buffer_for_enum(name: &Ident, de: &DataEnum) -> Tokens { i += 1; } quote!( - unsafe fn read_from_buffer(buffer: &[u8], fds: &[RawFd]) - -> MsgResult<(Self, usize)> { + unsafe fn read_from_buffer(buffer: &[u8], fds: &[std::os::unix::io::RawFd]) + -> _msg_socket::MsgResult<(Self, usize)> { let v = buffer[0]; match v { #(#match_variants)* _ => { - Err(MsgError::InvalidType) + Err(_msg_socket::MsgError::InvalidType) } } } @@ -356,7 +372,8 @@ fn define_write_buffer_for_enum(name: &Ident, de: &DataEnum) -> Tokens { } quote!( - fn write_to_buffer(&self, buffer: &mut [u8], fds: &mut [RawFd]) -> MsgResult { + fn write_to_buffer(&self, buffer: &mut [u8], fds: &mut [std::os::unix::io::RawFd]) + -> _msg_socket::MsgResult { match self { #(#match_variants)* } @@ -381,7 +398,7 @@ fn impl_for_tuple_struct(name: Ident, ds: DataStruct) -> Tokens { let read_buffer = define_read_buffer_for_tuples(&name, &types); let write_buffer = define_write_buffer_for_tuples(&name, &types); quote!( - impl MsgOnSocket for #name { + impl _msg_socket::MsgOnSocket for #name { #buffer_sizes_impls #read_buffer #write_buffer @@ -416,8 +433,8 @@ fn define_read_buffer_for_tuples(name: &Ident, fields: &[syn::Type]) -> Tokens { } quote!{ - unsafe fn read_from_buffer(buffer: &[u8], fds: &[RawFd]) - -> MsgResult<(Self, usize)> { + unsafe fn read_from_buffer(buffer: &[u8], fds: &[std::os::unix::io::RawFd]) + -> _msg_socket::MsgResult<(Self, usize)> { let mut __offset = 0usize; let mut __fd_offset = 0usize; #(#read_fields)* @@ -442,8 +459,8 @@ fn define_write_buffer_for_tuples(name: &Ident, fields: &[syn::Type]) -> Tokens tmp_names.push(tmp_name); } quote!{ - fn write_to_buffer(&self, buffer: &mut [u8], fds: &mut [RawFd]) - -> MsgResult { + fn write_to_buffer(&self, buffer: &mut [u8], fds: &mut [std::os::unix::io::RawFd]) + -> _msg_socket::MsgResult { let mut __offset = 0usize; let mut __fd_offset = 0usize; let #name( #(#tmp_names),* ) = self; @@ -487,7 +504,8 @@ fn write_to_buffer_and_move_offset(name: &Ident, ty: &syn::Type) -> Tokens { #[cfg(test)] mod tests { - use super::*; + use syn::DeriveInput; + use socket_msg_impl; #[test] fn end_to_end_struct_test() { @@ -500,7 +518,7 @@ mod tests { }; let expected = quote! { - impl MsgOnSocket for MyMsg { + impl _msg_socket::MsgOnSocket for MyMsg { fn msg_size() -> usize { ::msg_size() as usize + ::msg_size() as usize @@ -511,8 +529,8 @@ mod tests { + ::max_fd_count() as usize + ::max_fd_count() as usize } - unsafe fn read_from_buffer(buffer: &[u8], fds: &[RawFd]) - -> MsgResult<(Self, usize)> { + unsafe fn read_from_buffer(buffer: &[u8], fds: &[std::os::unix::io::RawFd]) + -> _msg_socket::MsgResult<(Self, usize)> { let mut __offset = 0usize; let mut __fd_offset = 0usize; let t = ::read_from_buffer(&buffer[__offset..], &fds[__fd_offset..])?; @@ -529,8 +547,8 @@ mod tests { let c = t.0; Ok((Self { a, b, c }, __fd_offset)) } - fn write_to_buffer(&self, buffer: &mut [u8], fds: &mut [RawFd]) - -> MsgResult { + fn write_to_buffer(&self, buffer: &mut [u8], fds: &mut [std::os::unix::io::RawFd]) + -> _msg_socket::MsgResult { let mut __offset = 0usize; let mut __fd_offset = 0usize; let o = self.a.write_to_buffer(&mut buffer[__offset..], @@ -560,7 +578,7 @@ mod tests { }; let expected = quote! { - impl MsgOnSocket for MyMsg { + impl _msg_socket::MsgOnSocket for MyMsg { fn msg_size() -> usize { ::msg_size() as usize + ::msg_size() as usize + ::msg_size() as usize @@ -570,8 +588,8 @@ mod tests { + ::max_fd_count() as usize + ::max_fd_count() as usize } - unsafe fn read_from_buffer(buffer: &[u8], fds: &[RawFd]) - -> MsgResult<(Self, usize)> { + unsafe fn read_from_buffer(buffer: &[u8], fds: &[std::os::unix::io::RawFd]) + -> _msg_socket::MsgResult<(Self, usize)> { let mut __offset = 0usize; let mut __fd_offset = 0usize; let t = ::read_from_buffer(&buffer[__offset..], &fds[__fd_offset..])?; @@ -588,8 +606,8 @@ mod tests { let tuple_tmp2 = t.0; Ok((MyMsg(tuple_tmp0, tuple_tmp1, tuple_tmp2), __fd_offset)) } - fn write_to_buffer(&self, buffer: &mut [u8], fds: &mut [RawFd]) - -> MsgResult { + fn write_to_buffer(&self, buffer: &mut [u8], fds: &mut [std::os::unix::io::RawFd]) + -> _msg_socket::MsgResult { let mut __offset = 0usize; let mut __fd_offset = 0usize; let MyMsg(tuple_tmp0, tuple_tmp1, tuple_tmp2) = self; @@ -624,7 +642,7 @@ mod tests { }; let expected = quote! { - impl MsgOnSocket for MyMsg { + impl _msg_socket::MsgOnSocket for MyMsg { fn msg_size() -> usize { [ ::msg_size() as usize, @@ -641,8 +659,8 @@ mod tests { ].iter() .max().unwrap().clone() as usize } - unsafe fn read_from_buffer(buffer: &[u8], fds: &[RawFd]) -> - MsgResult<(Self, usize)> { + unsafe fn read_from_buffer(buffer: &[u8], fds: &[std::os::unix::io::RawFd]) -> + _msg_socket::MsgResult<(Self, usize)> { let v = buffer[0]; match v { 0u8 => { @@ -672,13 +690,14 @@ mod tests { Ok((MyMsg::C{f0, f1}, __fd_offset)) } _ => { - Err(MsgError::InvalidType) + Err(_msg_socket::MsgError::InvalidType) } } } fn write_to_buffer(&self, buffer: &mut [u8], - fds: &mut [RawFd]) -> MsgResult { + fds: &mut [std::os::unix::io::RawFd]) + -> _msg_socket::MsgResult { match self { MyMsg::A(enum_variant_tmp0) => { buffer[0] = 0u8; diff --git a/msg_socket/src/lib.rs b/msg_socket/src/lib.rs index c4d60df9dc..e1a1b484e1 100644 --- a/msg_socket/src/lib.rs +++ b/msg_socket/src/lib.rs @@ -199,106 +199,3 @@ impl MsgReceiver for UnlinkMsgSocket {} impl MsgSender for Sender {} impl MsgReceiver for Receiver {} - -#[cfg(test)] -mod tests { - use super::*; - use sys_util::EventFd; - - #[derive(MsgOnSocket)] - struct Request { - field0: u8, - field1: EventFd, - field2: u32, - } - - #[derive(MsgOnSocket)] - enum Response { - A(u8), - B, - C(u32, EventFd), - D([u8; 4]), - E { f0: u8, f1: u32 }, - } - - #[derive(MsgOnSocket)] - struct Message(u8, u16, EventFd); - - #[test] - fn sock_send_recv_struct() { - let (req, res) = pair::().unwrap(); - let e0 = EventFd::new().unwrap(); - let e1 = e0.try_clone().unwrap(); - req.send(&Request { - field0: 2, - field1: e0, - field2: 0xf0f0, - }).unwrap(); - let r = res.recv().unwrap(); - assert_eq!(r.field0, 2); - assert_eq!(r.field2, 0xf0f0); - r.field1.write(0x0f0f).unwrap(); - assert_eq!(e1.read().unwrap(), 0x0f0f); - } - - #[test] - fn sock_send_recv_enum() { - let (req, res) = pair::().unwrap(); - let e0 = EventFd::new().unwrap(); - let e1 = e0.try_clone().unwrap(); - res.send(&Response::C(0xf0f0, e0)).unwrap(); - let r = req.recv().unwrap(); - match r { - Response::C(v, efd) => { - assert_eq!(v, 0xf0f0); - efd.write(0x0f0f).unwrap(); - } - _ => panic!("wrong type"), - }; - assert_eq!(e1.read().unwrap(), 0x0f0f); - - res.send(&Response::B).unwrap(); - match req.recv().unwrap() { - Response::B => {} - _ => panic!("Wrong enum type"), - }; - - res.send(&Response::A(0x3)).unwrap(); - match req.recv().unwrap() { - Response::A(v) => assert_eq!(v, 0x3), - _ => panic!("Wrong enum type"), - }; - - res.send(&Response::D([0, 1, 2, 3])).unwrap(); - match req.recv().unwrap() { - Response::D(v) => assert_eq!(v, [0, 1, 2, 3]), - _ => panic!("Wrong enum type"), - }; - - res.send(&Response::E { - f0: 0x12, - f1: 0x0f0f, - }).unwrap(); - match req.recv().unwrap() { - Response::E { f0, f1 } => { - assert_eq!(f0, 0x12); - assert_eq!(f1, 0x0f0f); - } - _ => panic!("Wrong enum type"), - }; - } - - #[test] - fn sock_send_recv_tuple() { - let (req, res) = pair::().unwrap(); - let e0 = EventFd::new().unwrap(); - let e1 = e0.try_clone().unwrap(); - req.send(&Message(1, 0x12, e0)).unwrap(); - let r = res.recv().unwrap(); - assert_eq!(r.0, 1); - assert_eq!(r.1, 0x12); - r.2.write(0x0f0f).unwrap(); - assert_eq!(e1.read().unwrap(), 0x0f0f); - } - -} diff --git a/msg_socket/tests/enum.rs b/msg_socket/tests/enum.rs new file mode 100644 index 0000000000..d63cfc15f9 --- /dev/null +++ b/msg_socket/tests/enum.rs @@ -0,0 +1,66 @@ +extern crate msg_on_socket_derive; +extern crate msg_socket; +extern crate sys_util; + +use sys_util::EventFd; + +use msg_socket::*; + +#[derive(MsgOnSocket)] +struct DummyRequest {} + +#[derive(MsgOnSocket)] +enum Response { + A(u8), + B, + C(u32, EventFd), + D([u8; 4]), + E { f0: u8, f1: u32 }, +} + +#[test] +fn sock_send_recv_enum() { + let (req, res) = pair::().unwrap(); + let e0 = EventFd::new().unwrap(); + let e1 = e0.try_clone().unwrap(); + res.send(&Response::C(0xf0f0, e0)).unwrap(); + let r = req.recv().unwrap(); + match r { + Response::C(v, efd) => { + assert_eq!(v, 0xf0f0); + efd.write(0x0f0f).unwrap(); + } + _ => panic!("wrong type"), + }; + assert_eq!(e1.read().unwrap(), 0x0f0f); + + res.send(&Response::B).unwrap(); + match req.recv().unwrap() { + Response::B => {} + _ => panic!("Wrong enum type"), + }; + + res.send(&Response::A(0x3)).unwrap(); + match req.recv().unwrap() { + Response::A(v) => assert_eq!(v, 0x3), + _ => panic!("Wrong enum type"), + }; + + res.send(&Response::D([0, 1, 2, 3])).unwrap(); + match req.recv().unwrap() { + Response::D(v) => assert_eq!(v, [0, 1, 2, 3]), + _ => panic!("Wrong enum type"), + }; + + res.send(&Response::E { + f0: 0x12, + f1: 0x0f0f, + }).unwrap(); + match req.recv().unwrap() { + Response::E { f0, f1 } => { + assert_eq!(f0, 0x12); + assert_eq!(f1, 0x0f0f); + } + _ => panic!("Wrong enum type"), + }; +} diff --git a/msg_socket/tests/struct.rs b/msg_socket/tests/struct.rs new file mode 100644 index 0000000000..ab97a6f6bc --- /dev/null +++ b/msg_socket/tests/struct.rs @@ -0,0 +1,34 @@ +extern crate msg_on_socket_derive; +extern crate msg_socket; +extern crate sys_util; + +use sys_util::EventFd; + +use msg_socket::*; + +#[derive(MsgOnSocket)] +struct Request { + field0: u8, + field1: EventFd, + field2: u32, +} + +#[derive(MsgOnSocket)] +struct DummyResponse {} + +#[test] +fn sock_send_recv_struct() { + let (req, res) = pair::().unwrap(); + let e0 = EventFd::new().unwrap(); + let e1 = e0.try_clone().unwrap(); + req.send(&Request { + field0: 2, + field1: e0, + field2: 0xf0f0, + }).unwrap(); + let r = res.recv().unwrap(); + assert_eq!(r.field0, 2); + assert_eq!(r.field2, 0xf0f0); + r.field1.write(0x0f0f).unwrap(); + assert_eq!(e1.read().unwrap(), 0x0f0f); +} diff --git a/msg_socket/tests/tuple.rs b/msg_socket/tests/tuple.rs new file mode 100644 index 0000000000..717cd91ab4 --- /dev/null +++ b/msg_socket/tests/tuple.rs @@ -0,0 +1,23 @@ +extern crate msg_on_socket_derive; +extern crate msg_socket; +extern crate sys_util; + +use sys_util::EventFd; + +use msg_socket::*; + +#[derive(MsgOnSocket)] +struct Message(u8, u16, EventFd); + +#[test] +fn sock_send_recv_tuple() { + let (req, res) = pair::().unwrap(); + let e0 = EventFd::new().unwrap(); + let e1 = e0.try_clone().unwrap(); + req.send(&Message(1, 0x12, e0)).unwrap(); + let r = res.recv().unwrap(); + assert_eq!(r.0, 1); + assert_eq!(r.1, 0x12); + r.2.write(0x0f0f).unwrap(); + assert_eq!(e1.read().unwrap(), 0x0f0f); +}