diff --git a/Cargo.lock b/Cargo.lock index e9da7f99a8..41843ee9d1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -447,6 +447,7 @@ dependencies = [ "futures 0.3.1 (registry+https://github.com/rust-lang/crates.io-index)", "libc 0.2.44 (registry+https://github.com/rust-lang/crates.io-index)", "msg_on_socket_derive 0.1.0", + "sync 0.1.0", "sys_util 0.1.0", ] diff --git a/msg_socket/Cargo.toml b/msg_socket/Cargo.toml index c803bed487..80eba0b493 100644 --- a/msg_socket/Cargo.toml +++ b/msg_socket/Cargo.toml @@ -11,3 +11,4 @@ futures = "*" libc = "*" msg_on_socket_derive = { path = "msg_on_socket_derive" } sys_util = { path = "../sys_util" } +sync = { path = "../sync" } diff --git a/msg_socket/src/msg_on_socket.rs b/msg_socket/src/msg_on_socket.rs index b925630695..097bc48f63 100644 --- a/msg_socket/src/msg_on_socket.rs +++ b/msg_socket/src/msg_on_socket.rs @@ -10,8 +10,10 @@ use std::os::unix::io::{AsRawFd, FromRawFd, RawFd}; use std::os::unix::net::{UnixDatagram, UnixListener, UnixStream}; use std::ptr::drop_in_place; use std::result; +use std::sync::Arc; use data_model::*; +use sync::Mutex; use sys_util::{Error as SysError, EventFd}; #[derive(Debug, PartialEq)] @@ -209,6 +211,50 @@ impl MsgOnSocket for Option { } } +impl MsgOnSocket for Mutex { + fn uses_fd() -> bool { + T::uses_fd() + } + + fn msg_size(&self) -> usize { + self.lock().msg_size() + } + + fn fd_count(&self) -> usize { + self.lock().fd_count() + } + + unsafe fn read_from_buffer(buffer: &[u8], fds: &[RawFd]) -> MsgResult<(Self, usize)> { + T::read_from_buffer(buffer, fds).map(|(v, count)| (Mutex::new(v), count)) + } + + fn write_to_buffer(&self, buffer: &mut [u8], fds: &mut [RawFd]) -> MsgResult { + self.lock().write_to_buffer(buffer, fds) + } +} + +impl MsgOnSocket for Arc { + fn uses_fd() -> bool { + T::uses_fd() + } + + fn msg_size(&self) -> usize { + (**self).msg_size() + } + + fn fd_count(&self) -> usize { + (**self).fd_count() + } + + unsafe fn read_from_buffer(buffer: &[u8], fds: &[RawFd]) -> MsgResult<(Self, usize)> { + T::read_from_buffer(buffer, fds).map(|(v, count)| (Arc::new(v), count)) + } + + fn write_to_buffer(&self, buffer: &mut [u8], fds: &mut [RawFd]) -> MsgResult { + (**self).write_to_buffer(buffer, fds) + } +} + impl MsgOnSocket for () { fn fixed_size() -> Option { Some(0)