sys_util: recv entire UnixSeqpacket packets into Vec

This change adds the `recv_*_vec` suite of methods for getting an entire
packet into a `Vec` without needing to know the packet size through some
other means.

TEST=cargo test -p sys_util -p msg_socket
BUG=None

Change-Id: Ia4f931ccb91f6de6ee2103387fd95dfad3d3d38b
Reviewed-on: https://chromium-review.googlesource.com/c/chromiumos/platform/crosvm/+/2034025
Commit-Queue: Zach Reizner <zachr@chromium.org>
Tested-by: Zach Reizner <zachr@chromium.org>
Tested-by: kokoro <noreply+kokoro@google.com>
Auto-Submit: Zach Reizner <zachr@chromium.org>
Reviewed-by: Daniel Verkamp <dverkamp@chromium.org>
Reviewed-by: Stephen Barber <smbarber@chromium.org>
This commit is contained in:
Zach Reizner 2020-01-31 17:17:32 -08:00 committed by Commit Bot
parent 4441c01124
commit 787c84b51b
3 changed files with 133 additions and 15 deletions

View file

@ -145,33 +145,33 @@ pub trait MsgReceiver: AsRef<UnixSeqpacket> {
fn recv(&self) -> MsgResult<Self::M> {
let msg_size = Self::M::msg_size();
let fd_size = Self::M::max_fd_count();
let mut msg_buffer: Vec<u8> = vec![0; msg_size];
let mut fd_buffer: Vec<RawFd> = vec![0; fd_size];
let sock: &UnixSeqpacket = self.as_ref();
let (recv_msg_size, recv_fd_size) = {
let (msg_buffer, fd_buffer) = {
if fd_size == 0 {
let size = sock
.recv(&mut msg_buffer)
.map_err(|e| MsgError::Recv(SysError::new(e.raw_os_error().unwrap_or(0))))?;
(size, 0)
(
sock.recv_as_vec().map_err(|e| {
MsgError::Recv(SysError::new(e.raw_os_error().unwrap_or(0)))
})?,
vec![],
)
} else {
sock.recv_with_fds(&mut msg_buffer, &mut fd_buffer)
.map_err(MsgError::Recv)?
sock.recv_as_vec_with_fds()
.map_err(|e| MsgError::Recv(SysError::new(e.raw_os_error().unwrap_or(0))))?
}
};
if msg_size != recv_msg_size {
if msg_size != msg_buffer.len() {
return Err(MsgError::BadRecvSize {
expected: msg_size,
actual: recv_msg_size,
actual: msg_buffer.len(),
});
}
// Safe because fd buffer is read from socket.
let (v, read_fd_size) = unsafe {
Self::M::read_from_buffer(&msg_buffer[0..recv_msg_size], &fd_buffer[0..recv_fd_size])?
};
if recv_fd_size != read_fd_size {
let (v, read_fd_size) =
unsafe { Self::M::read_from_buffer(&msg_buffer[..], &fd_buffer[..])? };
if fd_buffer.len() != read_fd_size {
return Err(MsgError::NotExpectFd);
}
Ok(v)

View file

@ -16,6 +16,10 @@ use std::path::PathBuf;
use std::ptr::null_mut;
use std::time::Duration;
use libc::{recvfrom, MSG_PEEK, MSG_TRUNC};
use crate::sock_ctrl_msg::{ScmSocket, SCM_SOCKET_MAX_FD_COUNT};
// Offset of sun_path in structure sockaddr_un.
fn sun_path_offset() -> usize {
// Prefer 0 to null() so that we do not need to subtract from the `sub_path` pointer.
@ -149,6 +153,28 @@ impl UnixSeqpacket {
}
}
/// Gets the number of bytes in the next packet. This blocks as if `recv` were called,
/// respecting the blocking and timeout settings of the underlying socket.
pub fn next_packet_size(&self) -> io::Result<usize> {
// This form of recvfrom doesn't modify any data because all null pointers are used. We only
// use the return value and check for errors on an FD owned by this structure.
let ret = unsafe {
recvfrom(
self.fd,
null_mut(),
0,
MSG_TRUNC | MSG_PEEK,
null_mut(),
null_mut(),
)
};
if ret < 0 {
Err(io::Error::last_os_error())
} else {
Ok(ret as usize)
}
}
/// Write data from a given buffer to the socket fd
///
/// # Arguments
@ -193,6 +219,52 @@ impl UnixSeqpacket {
}
}
/// Read data from the socket fd to a given `Vec`, resizing it to the received packet's size.
///
/// # Arguments
/// * `buf` - A mut reference to a `Vec` to resize and read into.
///
/// # Errors
/// Returns error when `libc::read` or `get_readable_bytes` failed.
pub fn recv_to_vec(&self, buf: &mut Vec<u8>) -> io::Result<()> {
let packet_size = self.next_packet_size()?;
buf.resize(packet_size, 0);
let read_bytes = self.recv(buf)?;
buf.resize(read_bytes, 0);
Ok(())
}
/// Read data from the socket fd to a new `Vec`.
///
/// # Returns
/// * `vec` - A new `Vec` with the entire received packet.
///
/// # Errors
/// Returns error when `libc::read` or `get_readable_bytes` failed.
pub fn recv_as_vec(&self) -> io::Result<Vec<u8>> {
let mut buf = Vec::new();
self.recv_to_vec(&mut buf)?;
Ok(buf)
}
/// Read data and fds from the socket fd to a new pair of `Vec`.
///
/// # Returns
/// * `Vec<u8>` - A new `Vec` with the entire received packet's bytes.
/// * `Vec<RawFd>` - A new `Vec` with the entire received packet's fds.
///
/// # Errors
/// Returns error when `recv_with_fds` or `get_readable_bytes` failed.
pub fn recv_as_vec_with_fds(&self) -> io::Result<(Vec<u8>, Vec<RawFd>)> {
let packet_size = self.next_packet_size()?;
let mut buf = vec![0; packet_size];
let mut fd_buf = vec![-1; SCM_SOCKET_MAX_FD_COUNT];
let (read_bytes, read_fds) = self.recv_with_fds(&mut buf, &mut fd_buf)?;
buf.resize(read_bytes, 0);
fd_buf.resize(read_fds, -1);
Ok((buf, fd_buf))
}
fn set_timeout(&self, timeout: Option<Duration>, kind: libc::c_int) -> io::Result<()> {
let timeval = match timeout {
Some(t) => {
@ -412,6 +484,7 @@ impl Drop for UnlinkUnixSeqpacketListener {
mod tests {
use super::*;
use std::env;
use std::io::ErrorKind;
use std::path::PathBuf;
fn tmpdir() -> PathBuf {
@ -584,4 +657,45 @@ mod tests {
assert_eq!(s1.get_readable_bytes().unwrap(), 0);
assert_eq!(s2.get_readable_bytes().unwrap(), 0);
}
#[test]
fn unix_seqpacket_next_packet_size() {
let (s1, s2) = UnixSeqpacket::pair().expect("failed to create socket pair");
let data1 = &[0, 1, 2, 3, 4];
s1.send(data1).expect("failed to send data");
assert_eq!(s2.next_packet_size().unwrap(), 5);
s1.set_read_timeout(Some(Duration::from_micros(1)))
.expect("failed to set read timeout");
assert_eq!(
s1.next_packet_size().unwrap_err().kind(),
ErrorKind::WouldBlock
);
drop(s2);
assert_eq!(
s1.next_packet_size().unwrap_err().kind(),
ErrorKind::ConnectionReset
);
}
#[test]
fn unix_seqpacket_recv_to_vec() {
let (s1, s2) = UnixSeqpacket::pair().expect("failed to create socket pair");
let data1 = &[0, 1, 2, 3, 4];
s1.send(data1).expect("failed to send data");
let recv_data = &mut vec![];
s2.recv_to_vec(recv_data).expect("failed to recv data");
assert_eq!(recv_data, &mut vec![0, 1, 2, 3, 4]);
}
#[test]
fn unix_seqpacket_recv_as_vec() {
let (s1, s2) = UnixSeqpacket::pair().expect("failed to create socket pair");
let data1 = &[0, 1, 2, 3, 4];
s1.send(data1).expect("failed to send data");
let recv_data = s2.recv_as_vec().expect("failed to recv data");
assert_eq!(recv_data, vec![0, 1, 2, 3, 4]);
}
}

View file

@ -213,6 +213,9 @@ fn raw_recvmsg(fd: RawFd, in_data: &mut [u8], in_fds: &mut [RawFd]) -> Result<(u
Ok((total_read as usize, in_fds_count))
}
/// The maximum number of FDs that can be sent in a single send.
pub const SCM_SOCKET_MAX_FD_COUNT: usize = 253;
/// Trait for file descriptors can send and receive socket control messages via `sendmsg` and
/// `recvmsg`.
pub trait ScmSocket {
@ -292,6 +295,7 @@ impl ScmSocket for UnixStream {
self.as_raw_fd()
}
}
impl ScmSocket for UnixSeqpacket {
fn socket_fd(&self) -> RawFd {
self.as_raw_fd()