diff --git a/third_party/minijail b/third_party/minijail index 77383c711a..5cc0b4179f 160000 --- a/third_party/minijail +++ b/third_party/minijail @@ -1 +1 @@ -Subproject commit 77383c711a1a7c183becb6bd38dbd8040c4ffd60 +Subproject commit 5cc0b4179f673dfd01803c26c8e803e6ced07e48 diff --git a/third_party/vmm_vhost/src/vhost_user/connection.rs b/third_party/vmm_vhost/src/vhost_user/connection.rs index ca922bd516..32c95254dc 100644 --- a/third_party/vmm_vhost/src/vhost_user/connection.rs +++ b/third_party/vmm_vhost/src/vhost_user/connection.rs @@ -137,25 +137,25 @@ impl Endpoint { /// * - number of bytes sent on success /// * - SocketBroken: the underline socket is broken. /// * - SocketError: other socket related errors. - pub fn send_iovec_all(&mut self, iovs: &[&[u8]], fds: Option<&[RawFd]>) -> Result { + pub fn send_iovec_all( + &mut self, + mut iovs: &mut [&[u8]], + mut fds: Option<&[RawFd]>, + ) -> Result { + // Guarantee that `iovs` becomes empty if it doesn't contain any data. + advance_slices(&mut iovs, 0); + let mut data_sent = 0; - let mut data_total = 0; - let iov_lens: Vec = iovs.iter().map(|iov| iov.len()).collect(); - for len in &iov_lens { - data_total += len; - } - - while (data_total - data_sent) > 0 { - let (nr_skip, offset) = get_sub_iovs_offset(&iov_lens, data_sent); - let iov = &iovs[nr_skip][offset..]; - - let data = &[&[iov], &iovs[(nr_skip + 1)..]].concat(); - let sfds = if data_sent == 0 { fds } else { None }; - - let sent = self.send_iovec(data, sfds); - match sent { - Ok(0) => return Ok(data_sent), - Ok(n) => data_sent += n, + while !iovs.is_empty() { + match self.send_iovec(iovs, fds) { + Ok(0) => { + break; + } + Ok(n) => { + data_sent += n; + fds = None; + advance_slices(&mut iovs, n); + } Err(e) => match e { Error::SocketRetry(_) => {} _ => return Err(e), @@ -190,13 +190,13 @@ impl Endpoint { fds: Option<&[RawFd]>, ) -> Result<()> { // Safe because there can't be other mutable referance to hdr. - let iovs = unsafe { + let mut iovs = unsafe { [slice::from_raw_parts( hdr as *const VhostUserMsgHeader as *const u8, mem::size_of::>(), )] }; - let bytes = self.send_iovec_all(&iovs[..], fds)?; + let bytes = self.send_iovec_all(&mut iovs[..], fds)?; if bytes != mem::size_of::>() { return Err(Error::PartialMessage); } @@ -222,7 +222,7 @@ impl Endpoint { return Err(Error::OversizedMsg); } // Safe because there can't be other mutable referance to hdr and body. - let iovs = unsafe { + let mut iovs = unsafe { [ slice::from_raw_parts( hdr as *const VhostUserMsgHeader as *const u8, @@ -231,7 +231,7 @@ impl Endpoint { slice::from_raw_parts(body as *const T as *const u8, mem::size_of::()), ] }; - let bytes = self.send_iovec_all(&iovs[..], fds)?; + let bytes = self.send_iovec_all(&mut iovs[..], fds)?; if bytes != mem::size_of::>() + mem::size_of::() { return Err(Error::PartialMessage); } @@ -270,7 +270,7 @@ impl Endpoint { } // Safe because there can't be other mutable reference to hdr, body and payload. - let iovs = unsafe { + let mut iovs = unsafe { [ slice::from_raw_parts( hdr as *const VhostUserMsgHeader as *const u8, @@ -281,7 +281,7 @@ impl Endpoint { ] }; let total = mem::size_of::>() + mem::size_of::() + len; - let len = self.send_iovec_all(&iovs, fds)?; + let len = self.send_iovec_all(&mut iovs, fds)?; if len != total { return Err(Error::PartialMessage); } @@ -601,6 +601,25 @@ fn get_sub_iovs_offset(iov_lens: &[usize], skip_size: usize) -> (usize, usize) { (nr_skip, size) } +// Advance the internal cursor of the slices. +// This is same with a nightly API `IoSlice::advance_slices` but for `[&[u8]]`. +fn advance_slices(bufs: &mut &mut [&[u8]], mut count: usize) { + use std::mem::replace; + + let mut idx = 0; + for b in bufs.iter() { + if count < b.len() { + break; + } + count -= b.len(); + idx += 1; + } + *bufs = &mut replace(bufs, &mut [])[idx..]; + if !bufs.is_empty() { + bufs[0] = &bufs[0][count..]; + } +} + #[cfg(test)] mod tests { use super::*; @@ -611,6 +630,18 @@ mod tests { Builder::new().prefix("/tmp/vhost_test").tempdir().unwrap() } + #[test] + fn test_advance_slices() { + // Test case from https://doc.rust-lang.org/std/io/struct.IoSlice.html#method.advance_slices + let buf1 = [1; 8]; + let buf2 = [2; 16]; + let buf3 = [3; 8]; + let mut bufs = &mut [&buf1[..], &buf2[..], &buf3[..]][..]; + advance_slices(&mut bufs, 10); + assert_eq!(bufs[0], [2; 14].as_ref()); + assert_eq!(bufs[1], [3; 8].as_ref()); + } + #[test] fn create_listener() { let dir = temp_dir();