vmm_vhost: vhost_user: Simplify send_iovec_all algorithm

BUG=b:204720423
TEST=cargo test --all-features

Change-Id: I554526fb39fb5f2aad14189d4825033290d1d6d4
Reviewed-on: https://chromium-review.googlesource.com/c/chromiumos/platform/crosvm/+/3268264
Tested-by: kokoro <noreply+kokoro@google.com>
Commit-Queue: Keiichi Watanabe <keiichiw@chromium.org>
Reviewed-by: Chirantan Ekbote <chirantan@chromium.org>
This commit is contained in:
Keiichi Watanabe 2021-10-30 02:02:19 +09:00 committed by Commit Bot
parent a205dc949a
commit a939c0c77f
2 changed files with 56 additions and 25 deletions

@ -1 +1 @@
Subproject commit 77383c711a1a7c183becb6bd38dbd8040c4ffd60
Subproject commit 5cc0b4179f673dfd01803c26c8e803e6ced07e48

View file

@ -137,25 +137,25 @@ impl<R: Req> Endpoint<R> {
/// * - number of bytes sent on success
/// * - SocketBroken: the underline socket is broken.
/// * - SocketError: other socket related errors.
pub fn send_iovec_all(&mut self, iovs: &[&[u8]], fds: Option<&[RawFd]>) -> Result<usize> {
pub fn send_iovec_all(
&mut self,
mut iovs: &mut [&[u8]],
mut fds: Option<&[RawFd]>,
) -> Result<usize> {
// Guarantee that `iovs` becomes empty if it doesn't contain any data.
advance_slices(&mut iovs, 0);
let mut data_sent = 0;
let mut data_total = 0;
let iov_lens: Vec<usize> = iovs.iter().map(|iov| iov.len()).collect();
for len in &iov_lens {
data_total += len;
}
while (data_total - data_sent) > 0 {
let (nr_skip, offset) = get_sub_iovs_offset(&iov_lens, data_sent);
let iov = &iovs[nr_skip][offset..];
let data = &[&[iov], &iovs[(nr_skip + 1)..]].concat();
let sfds = if data_sent == 0 { fds } else { None };
let sent = self.send_iovec(data, sfds);
match sent {
Ok(0) => return Ok(data_sent),
Ok(n) => data_sent += n,
while !iovs.is_empty() {
match self.send_iovec(iovs, fds) {
Ok(0) => {
break;
}
Ok(n) => {
data_sent += n;
fds = None;
advance_slices(&mut iovs, n);
}
Err(e) => match e {
Error::SocketRetry(_) => {}
_ => return Err(e),
@ -190,13 +190,13 @@ impl<R: Req> Endpoint<R> {
fds: Option<&[RawFd]>,
) -> Result<()> {
// Safe because there can't be other mutable referance to hdr.
let iovs = unsafe {
let mut iovs = unsafe {
[slice::from_raw_parts(
hdr as *const VhostUserMsgHeader<R> as *const u8,
mem::size_of::<VhostUserMsgHeader<R>>(),
)]
};
let bytes = self.send_iovec_all(&iovs[..], fds)?;
let bytes = self.send_iovec_all(&mut iovs[..], fds)?;
if bytes != mem::size_of::<VhostUserMsgHeader<R>>() {
return Err(Error::PartialMessage);
}
@ -222,7 +222,7 @@ impl<R: Req> Endpoint<R> {
return Err(Error::OversizedMsg);
}
// Safe because there can't be other mutable referance to hdr and body.
let iovs = unsafe {
let mut iovs = unsafe {
[
slice::from_raw_parts(
hdr as *const VhostUserMsgHeader<R> as *const u8,
@ -231,7 +231,7 @@ impl<R: Req> Endpoint<R> {
slice::from_raw_parts(body as *const T as *const u8, mem::size_of::<T>()),
]
};
let bytes = self.send_iovec_all(&iovs[..], fds)?;
let bytes = self.send_iovec_all(&mut iovs[..], fds)?;
if bytes != mem::size_of::<VhostUserMsgHeader<R>>() + mem::size_of::<T>() {
return Err(Error::PartialMessage);
}
@ -270,7 +270,7 @@ impl<R: Req> Endpoint<R> {
}
// Safe because there can't be other mutable reference to hdr, body and payload.
let iovs = unsafe {
let mut iovs = unsafe {
[
slice::from_raw_parts(
hdr as *const VhostUserMsgHeader<R> as *const u8,
@ -281,7 +281,7 @@ impl<R: Req> Endpoint<R> {
]
};
let total = mem::size_of::<VhostUserMsgHeader<R>>() + mem::size_of::<T>() + len;
let len = self.send_iovec_all(&iovs, fds)?;
let len = self.send_iovec_all(&mut iovs, fds)?;
if len != total {
return Err(Error::PartialMessage);
}
@ -601,6 +601,25 @@ fn get_sub_iovs_offset(iov_lens: &[usize], skip_size: usize) -> (usize, usize) {
(nr_skip, size)
}
// Advance the internal cursor of the slices.
// This is same with a nightly API `IoSlice::advance_slices` but for `[&[u8]]`.
fn advance_slices(bufs: &mut &mut [&[u8]], mut count: usize) {
use std::mem::replace;
let mut idx = 0;
for b in bufs.iter() {
if count < b.len() {
break;
}
count -= b.len();
idx += 1;
}
*bufs = &mut replace(bufs, &mut [])[idx..];
if !bufs.is_empty() {
bufs[0] = &bufs[0][count..];
}
}
#[cfg(test)]
mod tests {
use super::*;
@ -611,6 +630,18 @@ mod tests {
Builder::new().prefix("/tmp/vhost_test").tempdir().unwrap()
}
#[test]
fn test_advance_slices() {
// Test case from https://doc.rust-lang.org/std/io/struct.IoSlice.html#method.advance_slices
let buf1 = [1; 8];
let buf2 = [2; 16];
let buf3 = [3; 8];
let mut bufs = &mut [&buf1[..], &buf2[..], &buf3[..]][..];
advance_slices(&mut bufs, 10);
assert_eq!(bufs[0], [2; 14].as_ref());
assert_eq!(bufs[1], [3; 8].as_ref());
}
#[test]
fn create_listener() {
let dir = temp_dir();