Skip to content

Commit

Permalink
Change IoVecBuffer[Mut] len to u32
Browse files Browse the repository at this point in the history
This commit changes the iovec len primitive to match descriptor chain's
(u32). This removes some ugly casting and potential overflow problems,
and allows us to upcast when needed in a non-lossy manor.

Signed-off-by: Brandon Pike <bpike@amazon.com>
  • Loading branch information
brandonpike committed May 20, 2024
1 parent 3853362 commit feae33f
Show file tree
Hide file tree
Showing 6 changed files with 60 additions and 45 deletions.
63 changes: 38 additions & 25 deletions src/vmm/src/devices/virtio/iovec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ pub enum IoVecError {
WriteOnlyDescriptor,
/// Tried to create an 'IoVecMut` from a read-only descriptor chain
ReadOnlyDescriptor,
/// Tried to create an `IoVec` or `IoVecMut` from a descriptor chain that was too large
OverflowedDescriptor,
/// Guest memory error: {0}
GuestMemory(#[from] GuestMemoryError),
}
Expand All @@ -40,14 +42,14 @@ pub struct IoVecBuffer {
// container of the memory regions included in this IO vector
vecs: IoVecVec,
// Total length of the IoVecBuffer
len: usize,
len: u32,
}

impl IoVecBuffer {
/// Create an `IoVecBuffer` from a `DescriptorChain`
pub fn from_descriptor_chain(head: DescriptorChain) -> Result<Self, IoVecError> {
let mut vecs = IoVecVec::new();
let mut len = 0usize;
let mut len = 0u32;

let mut next_descriptor = Some(head);
while let Some(desc) = next_descriptor {
Expand All @@ -68,7 +70,9 @@ impl IoVecBuffer {
iov_base,
iov_len: desc.len as size_t,
});
len += desc.len as usize;
len = len
.checked_add(desc.len)
.ok_or(IoVecError::OverflowedDescriptor)?;

next_descriptor = desc.next_descriptor();
}
Expand All @@ -77,7 +81,7 @@ impl IoVecBuffer {
}

/// Get the total length of the memory regions covered by this `IoVecBuffer`
pub(crate) fn len(&self) -> usize {
pub(crate) fn len(&self) -> u32 {
self.len
}

Expand Down Expand Up @@ -106,7 +110,7 @@ impl IoVecBuffer {
mut buf: &mut [u8],
offset: usize,
) -> Result<(), VolatileMemoryError> {
if offset < self.len() {
if offset < self.len() as usize {
let expected = buf.len();
let bytes_read = self.read_volatile_at(&mut buf, offset, expected)?;

Expand Down Expand Up @@ -188,14 +192,14 @@ pub struct IoVecBufferMut {
// container of the memory regions included in this IO vector
vecs: IoVecVec,
// Total length of the IoVecBufferMut
len: usize,
len: u32,
}

impl IoVecBufferMut {
/// Create an `IoVecBufferMut` from a `DescriptorChain`
pub fn from_descriptor_chain(head: DescriptorChain) -> Result<Self, IoVecError> {
let mut vecs = IoVecVec::new();
let mut len = 0usize;
let mut len = 0u32;

for desc in head {
if !desc.is_write_only() {
Expand All @@ -217,14 +221,16 @@ impl IoVecBufferMut {
iov_base,
iov_len: desc.len as size_t,
});
len += desc.len as usize;
len = len
.checked_add(desc.len)
.ok_or(IoVecError::OverflowedDescriptor)?;
}

Ok(Self { vecs, len })
}

/// Get the total length of the memory regions covered by this `IoVecBuffer`
pub(crate) fn len(&self) -> usize {
pub(crate) fn len(&self) -> u32 {
self.len
}

Expand All @@ -244,7 +250,7 @@ impl IoVecBufferMut {
mut buf: &[u8],
offset: usize,
) -> Result<(), VolatileMemoryError> {
if offset < self.len() {
if offset < self.len() as usize {
let expected = buf.len();
let bytes_written = self.write_volatile_at(&mut buf, offset, expected)?;

Expand Down Expand Up @@ -335,18 +341,18 @@ mod tests {
iov_len: buf.len(),
}]
.into(),
len: buf.len(),
len: buf.len().try_into().unwrap(),
}
}
}

impl<'a> From<Vec<&'a [u8]>> for IoVecBuffer {
fn from(buffer: Vec<&'a [u8]>) -> Self {
let mut len = 0;
let mut len = 0_u32;
let vecs = buffer
.into_iter()
.map(|slice| {
len += slice.len();
len += TryInto::<u32>::try_into(slice.len()).unwrap();
iovec {
iov_base: slice.as_ptr() as *mut c_void,
iov_len: slice.len(),
Expand All @@ -366,7 +372,7 @@ mod tests {
iov_len: buf.len(),
}]
.into(),
len: buf.len(),
len: buf.len().try_into().unwrap(),
}
}
}
Expand Down Expand Up @@ -607,7 +613,6 @@ mod verification {

use libc::{c_void, iovec};
use vm_memory::bitmap::BitmapSlice;
use vm_memory::volatile_memory::Error;
use vm_memory::VolatileSlice;

use super::{IoVecBuffer, IoVecBufferMut, IoVecVec};
Expand All @@ -622,10 +627,10 @@ mod verification {
// >= 1.
const MAX_DESC_LENGTH: usize = 4;

fn create_iovecs(mem: *mut u8, size: usize) -> (IoVecVec, usize) {
fn create_iovecs(mem: *mut u8, size: usize) -> (IoVecVec, u32) {
let nr_descs: usize = kani::any_where(|&n| n <= MAX_DESC_LENGTH);
let mut vecs: Vec<iovec> = Vec::with_capacity(nr_descs);
let mut len = 0usize;
let mut len = 0u32;
for _ in 0..nr_descs {
// The `IoVecBuffer(Mut)` constructors ensure that the memory region described by every
// `Descriptor` in the chain is a valid, i.e. it is memory with then guest's memory
Expand All @@ -637,7 +642,7 @@ mod verification {
let iov_base = unsafe { mem.offset(addr.try_into().unwrap()) } as *mut c_void;

vecs.push(iovec { iov_base, iov_len });
len += iov_len;
len += u32::try_from(iov_len).unwrap();
}

(vecs, len)
Expand Down Expand Up @@ -712,7 +717,7 @@ mod verification {
let iov: IoVecBuffer = kani::any();

let mut buf = vec![0; GUEST_MEMORY_SIZE];
let offset: usize = kani::any();
let offset: u32 = kani::any();

// We can't really check the contents that the operation here writes into `buf`, because
// our `IoVecBuffer` being completely arbitrary can contain overlapping memory regions, so
Expand All @@ -724,9 +729,13 @@ mod verification {
// Furthermore, we know our Read-/WriteVolatile implementation above is infallible, so
// provided that the logic inside read_volatile_at is correct, we should always get Ok(...)
assert_eq!(
iov.read_volatile_at(&mut KaniBuffer(&mut buf), offset, GUEST_MEMORY_SIZE)
.unwrap(),
buf.len().min(iov.len().saturating_sub(offset))
iov.read_volatile_at(
&mut KaniBuffer(&mut buf),
offset as usize,
GUEST_MEMORY_SIZE
)
.unwrap(),
buf.len().min(iov.len().saturating_sub(offset) as usize)
);
}

Expand All @@ -737,7 +746,7 @@ mod verification {
let mut iov_mut: IoVecBufferMut = kani::any();

let mut buf = kani::vec::any_vec::<u8, GUEST_MEMORY_SIZE>();
let offset: usize = kani::any();
let offset: u32 = kani::any();

// We can't really check the contents that the operation here writes into `IoVecBufferMut`,
// because our `IoVecBufferMut` being completely arbitrary can contain overlapping memory
Expand All @@ -750,9 +759,13 @@ mod verification {
// provided that the logic inside write_volatile_at is correct, we should always get Ok(...)
assert_eq!(
iov_mut
.write_volatile_at(&mut KaniBuffer(&mut buf), offset, GUEST_MEMORY_SIZE)
.write_volatile_at(
&mut KaniBuffer(&mut buf),
offset as usize,
GUEST_MEMORY_SIZE
)
.unwrap(),
buf.len().min(iov_mut.len().saturating_sub(offset))
buf.len().min(iov_mut.len().saturating_sub(offset) as usize)
);
}
}
10 changes: 5 additions & 5 deletions src/vmm/src/devices/virtio/net/device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@ impl Net {

if let Some(ns) = mmds_ns {
if ns.is_mmds_frame(headers) {
let mut frame = vec![0u8; frame_iovec.len() - vnet_hdr_len()];
let mut frame = vec![0u8; frame_iovec.len() as usize - vnet_hdr_len()];
// Ok to unwrap here, because we are passing a buffer that has the exact size
// of the `IoVecBuffer` minus the VNET headers.
frame_iovec
Expand All @@ -472,7 +472,7 @@ impl Net {
METRICS.mmds.rx_accepted.inc();

// MMDS frames are not accounted by the rate limiter.
Self::rate_limiter_replenish_op(rate_limiter, frame_iovec.len() as u64);
Self::rate_limiter_replenish_op(rate_limiter, u64::from(frame_iovec.len()));

// MMDS consumed the frame.
return Ok(true);
Expand All @@ -493,7 +493,7 @@ impl Net {
let _metric = net_metrics.tap_write_agg.record_latency_metrics();
match Self::write_tap(tap, frame_iovec) {
Ok(_) => {
let len = frame_iovec.len() as u64;
let len = u64::from(frame_iovec.len());
net_metrics.tx_bytes_count.add(len);
net_metrics.tx_packets_count.inc();
net_metrics.tx_count.inc();
Expand Down Expand Up @@ -609,7 +609,7 @@ impl Net {
};

// We only handle frames that are up to MAX_BUFFER_SIZE
if buffer.len() > MAX_BUFFER_SIZE {
if buffer.len() as usize > MAX_BUFFER_SIZE {
error!("net: received too big frame from driver");
self.metrics.tx_malformed_frames.inc();
tx_queue
Expand All @@ -618,7 +618,7 @@ impl Net {
continue;
}

if !Self::rate_limiter_consume_op(&mut self.tx_rate_limiter, buffer.len() as u64) {
if !Self::rate_limiter_consume_op(&mut self.tx_rate_limiter, u64::from(buffer.len())) {
tx_queue.undo_pop();
self.metrics.tx_rate_limiter_throttled.inc();
break;
Expand Down
2 changes: 1 addition & 1 deletion src/vmm/src/devices/virtio/net/tap.rs
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ pub mod tests {

tap.write_iovec(&scattered).unwrap();

let mut read_buf = vec![0u8; scattered.len()];
let mut read_buf = vec![0u8; scattered.len() as usize];
assert!(tap_traffic_simulator.pop_rx_packet(&mut read_buf));
assert_eq!(
&read_buf[..PAYLOAD_SIZE - VNET_HDR_SIZE],
Expand Down
6 changes: 3 additions & 3 deletions src/vmm/src/devices/virtio/rng/device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,15 +112,15 @@ impl Entropy {
return Ok(0);
}

let mut rand_bytes = vec![0; iovec.len()];
let mut rand_bytes = vec![0; iovec.len() as usize];
rand::fill(&mut rand_bytes).map_err(|err| {
METRICS.host_rng_fails.inc();
err
})?;

// It is ok to unwrap here. We are writing `iovec.len()` bytes at offset 0.
iovec.write_all_volatile_at(&rand_bytes, 0).unwrap();
Ok(iovec.len().try_into().unwrap())
Ok(iovec.len())
}

fn process_entropy_queue(&mut self) {
Expand All @@ -142,7 +142,7 @@ impl Entropy {
// Check for available rate limiting budget.
// If not enough budget is available, leave the request descriptor in the queue
// to handle once we do have budget.
if !Self::rate_limit_request(&mut self.rate_limiter, iovec.len() as u64) {
if !Self::rate_limit_request(&mut self.rate_limiter, u64::from(iovec.len())) {
debug!("entropy: throttling entropy queue");
METRICS.entropy_rate_limiter_throttled.inc();
self.queues[RNG_QUEUE].undo_pop();
Expand Down
5 changes: 4 additions & 1 deletion src/vmm/src/devices/virtio/vsock/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ mod defs {
#[rustfmt::skip]
pub enum VsockError {
/** The total length of the descriptor chain ({0}) is too short to hold a packet of length {1} + header */
DescChainTooShortForPacket(usize, u32),
DescChainTooShortForPacket(u32, u32),
/// Empty queue
EmptyQueue,
/// EventFd error: {0}
Expand All @@ -122,6 +122,8 @@ pub enum VsockError {
/** The total length of the descriptor chain ({0}) is less than the number of bytes required\
to hold a vsock packet header.*/
DescChainTooShortForHeader(usize),
/// The descriptor chain length was greater than the max ([u32::MAX])
DescChainOverflow,
/// The vsock header `len` field holds an invalid value: {0}
InvalidPktLen(u32),
/// A data fetch was attempted when no data was available.
Expand All @@ -144,6 +146,7 @@ impl From<IoVecError> for VsockError {
IoVecError::WriteOnlyDescriptor => VsockError::UnreadableDescriptor,
IoVecError::ReadOnlyDescriptor => VsockError::UnwritableDescriptor,
IoVecError::GuestMemory(err) => VsockError::GuestMemoryMmap(err),
IoVecError::OverflowedDescriptor => VsockError::DescChainOverflow,
}
}
}
Expand Down
19 changes: 9 additions & 10 deletions src/vmm/src/devices/virtio/vsock/packet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ impl VsockPacket {
return Err(VsockError::InvalidPktLen(hdr.len));
}

if (hdr.len as usize) > buffer.len() - VSOCK_PKT_HDR_SIZE as usize {
if hdr.len > buffer.len() - VSOCK_PKT_HDR_SIZE {
return Err(VsockError::DescChainTooShortForPacket(
buffer.len(),
hdr.len,
Expand All @@ -160,8 +160,8 @@ impl VsockPacket {
pub fn from_rx_virtq_head(chain: DescriptorChain) -> Result<Self, VsockError> {
let buffer = IoVecBufferMut::from_descriptor_chain(chain)?;

if buffer.len() < VSOCK_PKT_HDR_SIZE as usize {
return Err(VsockError::DescChainTooShortForHeader(buffer.len()));
if buffer.len() < VSOCK_PKT_HDR_SIZE {
return Err(VsockError::DescChainTooShortForHeader(buffer.len() as usize));
}

Ok(Self {
Expand Down Expand Up @@ -212,7 +212,7 @@ impl VsockPacket {
VsockPacketBuffer::Tx(ref iovec_buf) => iovec_buf.len(),
VsockPacketBuffer::Rx(ref iovec_buf) => iovec_buf.len(),
};
chain_length - VSOCK_PKT_HDR_SIZE as usize
(chain_length - VSOCK_PKT_HDR_SIZE) as usize
}

pub fn read_at_offset_from<T: ReadVolatile + Debug>(
Expand All @@ -225,8 +225,7 @@ impl VsockPacket {
VsockPacketBuffer::Tx(_) => Err(VsockError::UnwritableDescriptor),
VsockPacketBuffer::Rx(ref mut buffer) => {
if count
> buffer
.len()
> (buffer.len() as usize)
.saturating_sub(VSOCK_PKT_HDR_SIZE as usize)
.saturating_sub(offset)
{
Expand All @@ -249,8 +248,7 @@ impl VsockPacket {
match self.buffer {
VsockPacketBuffer::Tx(ref buffer) => {
if count
> buffer
.len()
> (buffer.len() as usize)
.saturating_sub(VSOCK_PKT_HDR_SIZE as usize)
.saturating_sub(offset)
{
Expand Down Expand Up @@ -427,9 +425,10 @@ mod tests {
.unwrap(),
)
.unwrap();

assert_eq!(
pkt.buf_size(),
handler_ctx.guest_txvq.dtable[1].len.get() as usize
TryInto::<u32>::try_into(pkt.buf_size()).unwrap(),
handler_ctx.guest_txvq.dtable[1].len.get()
);
}

Expand Down

0 comments on commit feae33f

Please sign in to comment.