diff --git a/devices/src/virtio/descriptor_utils.rs b/devices/src/virtio/descriptor_utils.rs index 3c8ff876e7..b4eff2d1f2 100644 --- a/devices/src/virtio/descriptor_utils.rs +++ b/devices/src/virtio/descriptor_utils.rs @@ -297,14 +297,18 @@ impl<'a> Reader<'a> { impl<'a> io::Read for Reader<'a> { fn read(&mut self, buf: &mut [u8]) -> io::Result { self.buffer.consume(buf.len(), |bufs| { - if let Some(vs) = bufs.first() { + let mut rem = buf; + let mut total = 0; + for vs in bufs { // This is guaranteed by the implementation of `consume`. - debug_assert_eq!(vs.size(), cmp::min(buf.len() as u64, vs.size())); - vs.copy_to(buf); - Ok(vs.size() as usize) - } else { - Ok(0) + debug_assert_eq!(vs.size(), cmp::min(rem.len() as u64, vs.size())); + + vs.copy_to(rem); + let copied = vs.size() as usize; + rem = &mut rem[copied..]; + total += copied; } + Ok(total) }) } } @@ -417,14 +421,18 @@ impl<'a> Writer<'a> { impl<'a> io::Write for Writer<'a> { fn write(&mut self, buf: &[u8]) -> io::Result { self.buffer.consume(buf.len(), |bufs| { - if let Some(vs) = bufs.first() { + let mut rem = buf; + let mut total = 0; + for vs in bufs { // This is guaranteed by the implementation of `consume`. - debug_assert_eq!(vs.size(), cmp::min(buf.len() as u64, vs.size())); - vs.copy_from(buf); - Ok(vs.size() as usize) - } else { - Ok(0) + debug_assert_eq!(vs.size(), cmp::min(rem.len() as u64, vs.size())); + + vs.copy_from(rem); + let copied = vs.size() as usize; + rem = &rem[copied..]; + total += copied; } + Ok(total) }) } @@ -1074,4 +1082,52 @@ mod tests { panic!("successfully split Reader with out of bounds offset"); } } + + #[test] + fn read_full() { + use DescriptorType::*; + + let memory_start_addr = GuestAddress(0x0); + let memory = GuestMemory::new(&vec![(memory_start_addr, 0x10000)]).unwrap(); + + let chain = create_descriptor_chain( + &memory, + GuestAddress(0x0), + GuestAddress(0x100), + vec![(Readable, 16), (Readable, 16), (Readable, 16)], + 0, + ) + .expect("create_descriptor_chain failed"); + let mut reader = Reader::new(&memory, chain).expect("failed to create Reader"); + + let mut buf = vec![0u8; 64]; + assert_eq!( + reader.read(&mut buf[..]).expect("failed to read to buffer"), + 48 + ); + } + + #[test] + fn write_full() { + use DescriptorType::*; + + let memory_start_addr = GuestAddress(0x0); + let memory = GuestMemory::new(&vec![(memory_start_addr, 0x10000)]).unwrap(); + + let chain = create_descriptor_chain( + &memory, + GuestAddress(0x0), + GuestAddress(0x100), + vec![(Writable, 16), (Writable, 16), (Writable, 16)], + 0, + ) + .expect("create_descriptor_chain failed"); + let mut writer = Writer::new(&memory, chain).expect("failed to create Writer"); + + let buf = vec![0xdeu8; 64]; + assert_eq!( + writer.write(&buf[..]).expect("failed to write from buffer"), + 48 + ); + } }