Skip to content

Commit

Permalink
Use u32 for Vsock related buffer sizes
Browse files Browse the repository at this point in the history
Signed-off-by: River Phillips <riverphillips1@gmail.com>
  • Loading branch information
RiverPhillips committed Sep 8, 2024
1 parent 87a03c7 commit 71f0957
Show file tree
Hide file tree
Showing 6 changed files with 91 additions and 70 deletions.
81 changes: 51 additions & 30 deletions src/vmm/src/devices/virtio/iovec.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

use std::io;
use std::io::ErrorKind;

use libc::{c_void, iovec, size_t};
Expand Down Expand Up @@ -140,23 +141,26 @@ impl IoVecBuffer {
pub fn read_exact_volatile_at(
&self,
mut buf: &mut [u8],
offset: usize,
offset: u32,
) -> Result<(), VolatileMemoryError> {
if offset < self.len() as usize {
let expected = buf.len();
if offset < self.len() {
let expected: u32 = match buf.len().try_into() {
Ok(len) => len,
Err(_) => return Err(VolatileMemoryError::IOError(io::Error::new(ErrorKind::InvalidInput, "Cannot read more than u32::MAX bytes from a descriptor chain."))),,
};
let bytes_read = self.read_volatile_at(&mut buf, offset, expected)?;

if bytes_read != expected {
return Err(VolatileMemoryError::PartialBuffer {
expected,
completed: bytes_read,
expected: expected.try_into().unwrap(),
completed: bytes_read.try_into().unwrap(),
});
}

Ok(())
} else {
// If `offset` is past size, there's nothing to read.
Err(VolatileMemoryError::OutOfBounds { addr: offset })
Err(VolatileMemoryError::OutOfBounds { addr: offset.try_into().unwrap() })
}
}

Expand All @@ -166,11 +170,12 @@ impl IoVecBuffer {
pub fn read_volatile_at<W: WriteVolatile>(
&self,
dst: &mut W,
mut offset: usize,
mut len: usize,
) -> Result<usize, VolatileMemoryError> {
let mut total_bytes_read = 0;

offset: u32,
mut len: u32,
) -> Result<u32, VolatileMemoryError> {
let mut total_bytes_read: u32 = 0;
//`iov.iov_len` is a `usize` but it gets assigned from `DescriptorChain::len` which is a `u32`, so the guest cannot pass to us something that is bigger than `u32`. As a result we can safely cast it to `u32`.
let mut offset = u32::try_into(offset).unwrap();
for iov in &self.vecs {
if len == 0 {
break;
Expand All @@ -186,25 +191,27 @@ impl IoVecBuffer {
// all iovecs contained point towards valid ranges of guest memory
unsafe { VolatileSlice::new(iov.iov_base.cast(), iov.iov_len).offset(offset)? };
offset = 0;

if slice.len() > len {
slice = slice.subslice(0, len)?;
{
let len = len.try_into().unwrap();
if slice.len() > len {
slice = slice.subslice(0, len)?;
}
}

let bytes_read = loop {
let bytes_read: u32 = loop {
match dst.write_volatile(&slice) {
Err(VolatileMemoryError::IOError(err))
if err.kind() == ErrorKind::Interrupted =>
{
continue
}
Ok(bytes_read) => break bytes_read,
Ok(bytes_read) => break bytes_read.try_into().unwrap(),
Err(volatile_memory_error) => return Err(volatile_memory_error),
}
};
total_bytes_read += bytes_read;

if bytes_read < slice.len() {
if bytes_read < slice.len().try_into().unwrap() {
break;
}
len -= bytes_read;
Expand Down Expand Up @@ -307,27 +314,30 @@ impl IoVecBufferMut {
///
/// `Ok(())` if the entire contents of `buf` could be written to this [`IoVecBufferMut`],
/// `Err(VolatileMemoryError::PartialBuffer)` if only part of `buf` could be transferred, and
/// `Err(VolatileMemoryError::OutOfBounds)` if `offset >= self.len()`.
/// `Err(VolatileMemoryError::OutOfBounds)` if `offset >= self.len()` or offset would overflow a u32.
pub fn write_all_volatile_at(
&mut self,
mut buf: &[u8],
offset: usize,
offset: u32,
) -> Result<(), VolatileMemoryError> {
if offset < self.len() as usize {
let expected = buf.len();
if offset < self.len() {
let expected = match buf.len().try_into() {
Ok(len) => len,
Err(_) => return Err(VolatileMemoryError::IOError(io::Error::new(ErrorKind::InvalidInput, "Cannot read more than u32::MAX bytes from a descriptor chain."))),
};
let bytes_written = self.write_volatile_at(&mut buf, offset, expected)?;

if bytes_written != expected {
return Err(VolatileMemoryError::PartialBuffer {
expected,
completed: bytes_written,
expected: expected.try_into().unwrap(),
completed: bytes_written.try_into().unwrap(),
});
}

Ok(())
} else {
// We cannot write past the end of the `IoVecBufferMut`.
Err(VolatileMemoryError::OutOfBounds { addr: offset })
Err(VolatileMemoryError::OutOfBounds { addr: offset.try_into().unwrap() })
}
}

Expand All @@ -337,10 +347,18 @@ impl IoVecBufferMut {
pub fn write_volatile_at<W: ReadVolatile>(
&mut self,
src: &mut W,
mut offset: usize,
mut len: usize,
) -> Result<usize, VolatileMemoryError> {
offset: u32,
len: u32,
) -> Result<u32, VolatileMemoryError> {
let mut total_bytes_read = 0;
let mut len: usize = match len.try_into() {
Ok(len) => len,
Err(_) => return Err(VolatileMemoryError::IOError(io::Error::new(ErrorKind::InvalidInput, "Cannot read more than u32::MAX bytes from a descriptor chain."))),
};
let mut offset: usize = match offset.try_into() {
Ok(offset) => offset,
Err(_) => return Err(VolatileMemoryError::IOError(io::Error::new(ErrorKind::InvalidInput, "Cannot read more than u32::MAX bytes from a descriptor chain."))),
};

for iov in &self.vecs {
if len == 0 {
Expand All @@ -355,7 +373,7 @@ impl IoVecBufferMut {
let mut slice =
// SAFETY: the constructor IoVecBufferMut::from_descriptor_chain ensures that
// all iovecs contained point towards valid ranges of guest memory
unsafe { VolatileSlice::new(iov.iov_base.cast(), iov.iov_len).offset(offset)? };
unsafe { VolatileSlice::new(iov.iov_base.cast(), iov.iov_len).offset(offset.into())? };
offset = 0;

if slice.len() > len {
Expand All @@ -381,7 +399,10 @@ impl IoVecBufferMut {
len -= bytes_read;
}

Ok(total_bytes_read)
match total_bytes_read.try_into() {
Ok(bytes_read) => Ok(bytes_read),
Err(_) => return Err(VolatileMemoryError::IOError(io::Error::new(ErrorKind::InvalidInput, "Cannot read more than u32::MAX bytes from a descriptor chain."))),
}
}
}

Expand Down Expand Up @@ -658,7 +679,7 @@ mod tests {
// 5 bytes at offset 252 (only 4 bytes left).
test_vec4[60..64].copy_from_slice(&buf[0..4]);
assert_eq!(
iovec.write_volatile_at(&mut &*buf, 252, buf.len()).unwrap(),
iovec.write_volatile_at(&mut &*buf, 252, buf.len() as u32).unwrap(),
4
);
vq.dtable[0].check_data(&test_vec1);
Expand Down
6 changes: 3 additions & 3 deletions src/vmm/src/devices/virtio/net/device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,7 @@ impl Net {
net_metrics: &NetDeviceMetrics,
) -> Result<bool, NetError> {
// Read the frame headers from the IoVecBuffer
let max_header_len = headers.len();
let max_header_len = headers.len().try_into().unwrap();
let header_len = frame_iovec
.read_volatile_at(&mut &mut *headers, 0, max_header_len)
.map_err(|err| {
Expand All @@ -456,7 +456,7 @@ impl Net {
NetError::VnetHeaderMissing
})?;

let headers = frame_bytes_from_buf(&headers[..header_len]).map_err(|e| {
let headers = frame_bytes_from_buf(&headers[..header_len as usize]).map_err(|e| {
error!("VNET headers missing in TX frame");
net_metrics.tx_malformed_frames.inc();
e
Expand All @@ -468,7 +468,7 @@ impl Net {
// Ok to unwrap here, because we are passing a buffer that has the exact size
// of the `IoVecBuffer` minus the VNET headers.
frame_iovec
.read_exact_volatile_at(&mut frame, vnet_hdr_len())
.read_exact_volatile_at(&mut frame, vnet_hdr_len().try_into().unwrap())
.unwrap();
let _ = ns.detour_frame(&frame);
METRICS.mmds.rx_accepted.inc();
Expand Down
16 changes: 8 additions & 8 deletions src/vmm/src/devices/virtio/vsock/csm/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -605,7 +605,7 @@ where
/// Raw data can either be sent straight to the host stream, or to our TX buffer, if the
/// former fails.
fn send_bytes(&mut self, pkt: &VsockPacket) -> Result<(), VsockError> {
let len = pkt.len() as usize;
let len = pkt.len();

// If there is data in the TX buffer, that means we're already registered for EPOLLOUT
// events on the underlying stream. Therefore, there's no point in attempting a write
Expand Down Expand Up @@ -635,7 +635,7 @@ where
};
// Move the "forwarded bytes" counter ahead by how much we were able to send out.
// Safe to unwrap because the maximum value is pkt.len(), which is a u32.
self.fwd_cnt += wrap_usize_to_u32(written);
self.fwd_cnt += wrap_usize_to_u32(written.try_into().unwrap());
METRICS.tx_bytes_count.add(written as u64);

// If we couldn't write the whole slice, we'll need to push the remaining data to our
Expand All @@ -662,8 +662,8 @@ where

/// Get the maximum number of bytes that we can send to our peer, without overflowing its
/// buffer.
fn peer_avail_credit(&self) -> usize {
(Wrapping(self.peer_buf_alloc) - (self.rx_cnt - self.peer_fwd_cnt)).0 as usize
fn peer_avail_credit(&self) -> u32 {
(Wrapping(self.peer_buf_alloc) - (self.rx_cnt - self.peer_fwd_cnt)).0
}

/// Prepare a packet header for transmission to our peer.
Expand Down Expand Up @@ -916,7 +916,7 @@ mod tests {
assert!(credit < self.conn.peer_buf_alloc);
self.conn.peer_fwd_cnt = Wrapping(0);
self.conn.rx_cnt = Wrapping(self.conn.peer_buf_alloc - credit);
assert_eq!(self.conn.peer_avail_credit(), credit as usize);
assert_eq!(self.conn.peer_avail_credit(), credit);
}

fn send(&mut self) {
Expand All @@ -941,11 +941,11 @@ mod tests {
}

// Initialize a data (VSOCK_OP_RW) TX packet carrying `data` as payload.
fn init_data_tx_pkt(&mut self, mut data: &[u8]) -> &VsockPacket {
    // The payload must fit in the packet's data buffer (`buf_size()` is a `u32`).
    assert!(data.len() <= self.tx_pkt.buf_size() as usize);
    self.init_tx_pkt(uapi::VSOCK_OP_RW, u32::try_from(data.len()).unwrap());

    let len = data.len();
    // NOTE(review): the payload is read into `rx_pkt` while `tx_pkt` is the
    // value returned -- confirm this asymmetry is intentional.
    self.rx_pkt
        .read_at_offset_from(&mut data, 0, len.try_into().unwrap())
        .unwrap();
    &self.tx_pkt
}
}
Expand Down Expand Up @@ -1282,7 +1282,7 @@ mod tests {
ctx.set_stream(stream);

// Fill up the TX buffer.
let data = vec![0u8; ctx.tx_pkt.buf_size()];
let data = vec![0u8; ctx.tx_pkt.buf_size() as usize];
ctx.init_data_tx_pkt(data.as_slice());
for _i in 0..(csm_defs::CONN_TX_BUF_SIZE as usize / data.len()) {
ctx.send();
Expand Down
50 changes: 25 additions & 25 deletions src/vmm/src/devices/virtio/vsock/packet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -219,33 +219,33 @@ impl VsockPacket {
///
/// Return value will equal the total length of the underlying descriptor chain's buffers,
/// minus the length of the vsock header.
pub fn buf_size(&self) -> usize {
pub fn buf_size(&self) -> u32 {
let chain_length = match self.buffer {
VsockPacketBuffer::Tx(ref iovec_buf) => iovec_buf.len(),
VsockPacketBuffer::Rx(ref iovec_buf) => iovec_buf.len(),
};
(chain_length - VSOCK_PKT_HDR_SIZE) as usize
chain_length - VSOCK_PKT_HDR_SIZE
}

pub fn read_at_offset_from<T: ReadVolatile + Debug>(
&mut self,
src: &mut T,
offset: usize,
count: usize,
) -> Result<usize, VsockError> {
offset: u32,
count: u32,
) -> Result<u32, VsockError> {
match self.buffer {
VsockPacketBuffer::Tx(_) => Err(VsockError::UnwritableDescriptor),
VsockPacketBuffer::Rx(ref mut buffer) => {
if count
> (buffer.len() as usize)
.saturating_sub(VSOCK_PKT_HDR_SIZE as usize)
> buffer.len()
.saturating_sub(VSOCK_PKT_HDR_SIZE)
.saturating_sub(offset)
{
return Err(VsockError::GuestMemoryBounds);
}

buffer
.write_volatile_at(src, offset + VSOCK_PKT_HDR_SIZE as usize, count)
buffer
.write_volatile_at(src, offset + VSOCK_PKT_HDR_SIZE, count)
.map_err(|err| VsockError::GuestMemoryMmap(GuestMemoryError::from(err)))
}
}
Expand All @@ -254,21 +254,21 @@ impl VsockPacket {
pub fn write_from_offset_to<T: WriteVolatile + Debug>(
&self,
dst: &mut T,
offset: usize,
count: usize,
) -> Result<usize, VsockError> {
offset: u32,
count: u32,
) -> Result<u32, VsockError> {
match self.buffer {
VsockPacketBuffer::Tx(ref buffer) => {
if count
> (buffer.len() as usize)
.saturating_sub(VSOCK_PKT_HDR_SIZE as usize)
> (buffer.len())
.saturating_sub(VSOCK_PKT_HDR_SIZE)
.saturating_sub(offset)
{
return Err(VsockError::GuestMemoryBounds);
}

buffer
.read_volatile_at(dst, offset + VSOCK_PKT_HDR_SIZE as usize, count)
.read_volatile_at(dst, offset + VSOCK_PKT_HDR_SIZE, count)
.map_err(|err| VsockError::GuestMemoryMmap(GuestMemoryError::from(err)))
}
VsockPacketBuffer::Rx(_) => Err(VsockError::UnreadableDescriptor),
Expand Down Expand Up @@ -539,7 +539,7 @@ mod tests {
.unwrap();
assert_eq!(
pkt.buf_size(),
handler_ctx.guest_rxvq.dtable[1].len.get() as usize
handler_ctx.guest_rxvq.dtable[1].len.get()
);
}

Expand Down Expand Up @@ -646,35 +646,35 @@ mod tests {
.unwrap();

let buf_desc = &mut handler_ctx.guest_rxvq.dtable[1];
assert_eq!(pkt.buf_size(), buf_desc.len.get() as usize);
let zeros = vec![0_u8; pkt.buf_size()];
assert_eq!(pkt.buf_size(), buf_desc.len.get());
let zeros = vec![0_u8; pkt.buf_size() as usize];
let data: Vec<u8> = (0..pkt.buf_size())
.map(|i| ((i as u64) & 0xff) as u8)
.collect();
for offset in 0..pkt.buf_size() {
buf_desc.set_data(&zeros);

let mut expected_data = zeros[..offset].to_vec();
expected_data.extend_from_slice(&data[..pkt.buf_size() - offset]);
let mut expected_data = zeros[..offset as usize].to_vec();
expected_data.extend_from_slice(&data[..(pkt.buf_size() - offset) as usize]);

pkt.read_at_offset_from(&mut data.as_slice(), offset, pkt.buf_size() - offset)
.unwrap();

buf_desc.check_data(&expected_data);

let mut buf = vec![0; pkt.buf_size()];
let mut buf = vec![0; pkt.buf_size() as usize];
pkt2.write_from_offset_to(&mut buf.as_mut_slice(), offset, pkt.buf_size() - offset)
.unwrap();
assert_eq!(&buf[..pkt.buf_size() - offset], &expected_data[offset..]);
assert_eq!(&buf[..(pkt.buf_size() - offset) as usize], &expected_data[offset as usize..]);
}

let oob_cases = vec![
(1, pkt.buf_size()),
(pkt.buf_size(), 1),
(usize::MAX, 1),
(1, usize::MAX),
(u32::MAX, 1),
(1, u32::MAX),
];
let mut buf = vec![0; pkt.buf_size()];
let mut buf = vec![0; pkt.buf_size() as usize];
for (offset, count) in oob_cases {
let res = pkt.read_at_offset_from(&mut data.as_slice(), offset, count);
assert!(matches!(res, Err(VsockError::GuestMemoryBounds)));
Expand Down
Loading

0 comments on commit 71f0957

Please sign in to comment.