Skip to content

Commit

Permalink
slight changes
Browse files Browse the repository at this point in the history
Signed-off-by: Alexandru Agache <aagch@amazon.com>
  • Loading branch information
alexandruag committed Jan 20, 2021
1 parent ee95182 commit ec17a51
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 53 deletions.
3 changes: 3 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ pub enum Error {
InvalidChain,
/// Invalid descriptor index.
InvalidDescriptorIndex,
/// Invalid read-only descriptor encountered.
InvalidReadOnlyDescriptor,
/// Volatile memory related error.
VolatileMemoryError(VolatileMemoryError),
/// Descriptor chain overflow.
Expand All @@ -48,6 +50,7 @@ impl Display for Error {
InvalidIndirectDescriptor => write!(f, "invalid indirect descriptor"),
InvalidIndirectDescriptorTable => write!(f, "invalid indirect descriptor table"),
InvalidDescriptorIndex => write!(f, "invalid descriptor index"),
InvalidReadOnlyDescriptor => write!(f, "invalid read-only descriptor"),
VolatileMemoryError(e) => write!(f, "volatile memory error: {}", e),
DescriptorChainOverflow => write!(
f,
Expand Down
116 changes: 63 additions & 53 deletions src/queue/descriptor_chain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,65 @@ struct DescriptorChainConsumer<'a> {
}

impl<'a> DescriptorChainConsumer<'a> {
fn new<M: GuestAddressSpace>(
chain: &mut DescriptorChain<M>,
mem: &'a M::M,
readable: bool,
) -> Result<Self> {
Self::collect_desc_chain_buffers(chain, mem, readable).map(|buffers| Self {
buffers,
bytes_consumed: 0,
index: 0,
})
}

fn collect_desc_chain_buffers<M: GuestAddressSpace>(
desc_chain: &mut DescriptorChain<M>,
mem: &'a M::M,
readable: bool,
) -> Result<Vec<VolatileSlice<'a>>> {
let mut total_len: usize = 0;
let mut buffers = Vec::new();

for desc in desc_chain {
if readable && desc.is_write_only() {
break;
}

if !readable && !desc.is_write_only() {
return Err(Error::InvalidReadOnlyDescriptor);
}

// Verify that summing the descriptor sizes does not overflow.
// This can happen if a driver tricks a device into reading more data than
// fits in a `usize`.
total_len = total_len
.checked_add(desc.len() as usize)
.ok_or(Error::DescriptorChainOverflow)?;

let mut len = desc.len() as usize;
let mut addr = desc.addr();
while len > 0 {
let region = mem.find_region(addr).ok_or(Error::FindMemoryRegion)?;
let offset = addr
.checked_sub(region.start_addr().raw_value())
.unwrap()
.raw_value() as usize;
let buf_len = cmp::min(region.len() as usize - offset, len);
let buf = region
.get_slice(MemoryRegionAddress(offset as u64), buf_len as usize)
.map_err(Error::GuestMemory)?;
buffers.push(buf);
len -= buf_len;
addr = addr
.checked_add(buf_len as _)
.ok_or(Error::DescriptorChainOverflow)?;
}
}

Ok(buffers)
}

fn available_bytes(&self) -> usize {
// This is guaranteed not to overflow because the total length of the chain
// is checked during all creations of `DescriptorChainConsumer` (see
Expand Down Expand Up @@ -277,43 +336,6 @@ impl<'a> DescriptorChainConsumer<'a> {
}
}

fn collect_desc_chain_buffers<M: GuestAddressSpace>(
mem: &M::M,
desc_chain_iter: DescriptorChainRwIter<M>,
) -> Result<Vec<VolatileSlice<'_>>> {
let mut total_len: usize = 0;
let mut buffers = Vec::new();
for desc in desc_chain_iter {
// Verify that summing the descriptor sizes does not overflow.
// This can happen if a driver tricks a device into reading more data than
// fits in a `usize`.
total_len = total_len
.checked_add(desc.len() as usize)
.ok_or(Error::DescriptorChainOverflow)?;

let mut len = desc.len() as usize;
let mut addr = desc.addr();
while len > 0 {
let region = mem.find_region(addr).ok_or(Error::FindMemoryRegion)?;
let offset = addr
.checked_sub(region.start_addr().raw_value())
.unwrap()
.raw_value() as usize;
let buf_len = cmp::min(region.len() as usize - offset, len);
let buf = region
.get_slice(MemoryRegionAddress(offset as u64), buf_len as usize)
.map_err(Error::GuestMemory)?;
buffers.push(buf);
len -= buf_len;
addr = addr
.checked_add(buf_len as _)
.ok_or(Error::DescriptorChainOverflow)?;
}
}

Ok(buffers)
}

/// Provides high-level interface over the sequence of memory regions
/// defined by readable descriptors in the descriptor chain.
///
Expand All @@ -329,16 +351,10 @@ pub struct Reader<'a> {
impl<'a> Reader<'a> {
/// Construct a new Reader wrapper over `desc_chain`.
pub fn new<M: GuestAddressSpace>(
desc_chain: &mut DescriptorChain<M>,
mem: &'a M::M,
desc_chain: DescriptorChain<M>,
) -> Result<Reader<'a>> {
Ok(Reader {
buffer: DescriptorChainConsumer {
buffers: collect_desc_chain_buffers(mem, desc_chain.readable())?,
bytes_consumed: 0,
index: 0,
},
})
DescriptorChainConsumer::new(desc_chain, mem, true).map(|c| Reader { buffer: c })
}

/// Reads an object from the descriptor chain buffer.
Expand Down Expand Up @@ -416,16 +432,10 @@ pub struct Writer<'a> {
impl<'a> Writer<'a> {
/// Construct a new Writer wrapper over `desc_chain`.
pub fn new<M: GuestAddressSpace>(
desc_chain: &mut DescriptorChain<M>,
mem: &'a M::M,
desc_chain: DescriptorChain<M>,
) -> Result<Writer<'a>> {
Ok(Writer {
buffer: DescriptorChainConsumer {
buffers: collect_desc_chain_buffers(mem, desc_chain.writable())?,
bytes_consumed: 0,
index: 0,
},
})
DescriptorChainConsumer::new(desc_chain, mem, false).map(|c| Writer { buffer: c })
}

/// Writes an object to the descriptor chain buffer.
Expand Down

0 comments on commit ec17a51

Please sign in to comment.