Skip to content

Commit

Permalink
Rework ChecksummedBytes internals to use a Range instead of a Bytes slice (#687)
Browse files Browse the repository at this point in the history

* Rework ChecksummedBytes internals to use a Range instead of a Bytes slice

Preliminary refactor to prepare for adding integrity checks on the range itself. No changes in behavior.

Signed-off-by: Alessandro Passaro <alexpax@amazon.co.uk>

* Fix rustdoc

Signed-off-by: Alessandro Passaro <alexpax@amazon.co.uk>

* Improve setup of slice tests

Signed-off-by: Alessandro Passaro <alexpax@amazon.co.uk>

---------

Signed-off-by: Alessandro Passaro <alexpax@amazon.co.uk>
  • Loading branch information
passaro authored Jan 3, 2024
1 parent 374a0f2 commit 45414a2
Showing 1 changed file with 144 additions and 57 deletions.
201 changes: 144 additions & 57 deletions mountpoint-s3/src/checksums.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::ops::RangeBounds;
use std::ops::{Bound, Range, RangeBounds};

use bytes::{Bytes, BytesMut};
use mountpoint_s3_crt::checksums::crc32c::{self, Crc32c};
Expand All @@ -12,19 +12,20 @@ use thiserror::Error;
#[derive(Clone, Debug)]
#[must_use]
pub struct ChecksummedBytes {
orig_bytes: Bytes,
/// Always a subslice of `orig_bytes`
curr_slice: Bytes,
/// Checksum for `orig_bytes`
/// Underlying buffer
buffer: Bytes,
/// Range over [Self::buffer]
range: Range<usize>,
/// Checksum for [Self::buffer]
checksum: Crc32c,
}

impl ChecksummedBytes {
pub fn new(bytes: Bytes, checksum: Crc32c) -> Self {
let curr_slice = bytes.clone();
let full_range = 0..bytes.len();
Self {
orig_bytes: bytes,
curr_slice,
buffer: bytes,
range: full_range,
checksum,
}
}
Expand All @@ -40,18 +41,17 @@ impl ChecksummedBytes {
/// Return [IntegrityError] on data corruption.
pub fn into_bytes(self) -> Result<Bytes, IntegrityError> {
self.validate()?;

Ok(self.curr_slice)
Ok(self.buffer_slice())
}

/// Returns the number of bytes contained in this [ChecksummedBytes].
pub fn len(&self) -> usize {
self.curr_slice.len()
self.range.len()
}

/// Returns true if the [ChecksummedBytes] has a length of 0.
pub fn is_empty(&self) -> bool {
self.curr_slice.is_empty()
self.range.is_empty()
}

/// Split off the checksummed bytes at the given index.
Expand All @@ -61,10 +61,16 @@ impl ChecksummedBytes {
/// This operation just increases the reference count and sets a few indices,
/// so there will be no validation and the checksum will not be recomputed.
pub fn split_off(&mut self, at: usize) -> ChecksummedBytes {
let new_bytes = self.curr_slice.split_off(at);
// Panics unless `at` is strictly smaller than the current length.
// NOTE(review): `Bytes::split_off` presumably also accepts `at == len` (yielding an empty
// suffix); this assert rejects that case — confirm the stricter bound is intended.
assert!(at < self.len());

// `at` is relative to the current range, so translate it into offsets over `buffer`.
let start = self.range.start;
let prefix_range = start..(start + at);
let suffix_range = (start + at)..self.range.end;

// `self` keeps the prefix `[0, at)`; the returned value covers the suffix `[at, len)`.
self.range = prefix_range;
Self {
orig_bytes: self.orig_bytes.clone(),
curr_slice: new_bytes,
buffer: self.buffer.clone(),
range: suffix_range,
// The checksum is carried over unchanged: it is not recomputed for the suffix.
checksum: self.checksum,
}
}
Expand All @@ -74,9 +80,41 @@ impl ChecksummedBytes {
/// This operation just increases the reference count and sets a few indices,
/// so there will be no validation and the checksum will not be recomputed.
///
/// # Panics
///
/// Panics if the start of `range` is greater than its end, if the end is greater than
/// [Self::len], or if converting an excluded start / included end bound overflows `usize`.
pub fn slice(&self, range: impl RangeBounds<usize>) -> Self {
let sliced_range = {
let original_len = self.len();
let original_start = self.range.start;

// Normalise the start bound to an inclusive offset relative to the current range.
let slice_start_offset = match range.start_bound() {
Bound::Included(&n) => n,
Bound::Excluded(&n) => n.checked_add(1).expect("range start greater than maximum usize"),
Bound::Unbounded => 0,
};

// Normalise the end bound to an exclusive offset relative to the current range.
let slice_end_offset = match range.end_bound() {
Bound::Included(&n) => n.checked_add(1).expect("range end greater than maximum usize"),
Bound::Excluded(&n) => n,
Bound::Unbounded => original_len,
};

// Together these two checks guarantee `start <= end <= len`.
assert!(
slice_start_offset <= slice_end_offset,
"range start must not be greater than end: {:?} <= {:?}",
slice_start_offset,
slice_end_offset,
);
assert!(
slice_end_offset <= original_len,
"range end out of bounds: {:?} <= {:?}",
slice_end_offset,
original_len,
);

// Shift the relative offsets by the start of the current range to index into `buffer`.
(original_start + slice_start_offset)..(original_start + slice_end_offset)
};

Self {
orig_bytes: self.orig_bytes.clone(),
curr_slice: self.curr_slice.slice(range),
buffer: self.buffer.clone(),
range: sliced_range,
// The checksum is carried over unchanged: it is not recomputed for the slice.
checksum: self.checksum,
}
}
Expand All @@ -86,12 +124,12 @@ impl ChecksummedBytes {
///
/// Return [IntegrityError] if data corruption is detected.
pub fn shrink_to_fit(&self) -> Result<Self, IntegrityError> {
if self.curr_slice.len() == self.orig_bytes.len() {
if self.len() == self.buffer.len() {
return Ok(self.clone());
}

// Note that no data is copied: `bytes` still points to a subslice of `orig_bytes`.
let bytes = self.curr_slice.clone();
// Note that no data is copied: `bytes` still points to a subslice of `buffer`.
let bytes = self.buffer_slice();
let checksum = crc32c::checksum(&bytes);
let result = Self::new(bytes, checksum);

Expand Down Expand Up @@ -124,18 +162,18 @@ impl ChecksummedBytes {
// rather than the exact one for the slice, we need to first invoke `shrink_to_fit` on each
// slice and use the resulting exact checksums.
let prefix = self.shrink_to_fit()?;
assert_eq!(prefix.orig_bytes.len(), prefix.curr_slice.len());
assert_eq!(prefix.buffer.len(), prefix.len());
let suffix = extend.shrink_to_fit()?;
assert_eq!(suffix.orig_bytes.len(), suffix.curr_slice.len());
assert_eq!(suffix.buffer.len(), suffix.len());

// Combine the checksums.
let new_checksum = combine_checksums(prefix.checksum, suffix.checksum, suffix.len());

// Combine the slices.
let new_bytes = {
let mut bytes_mut = BytesMut::with_capacity(prefix.len() + suffix.len());
bytes_mut.extend_from_slice(&prefix.curr_slice);
bytes_mut.extend_from_slice(&suffix.curr_slice);
bytes_mut.extend_from_slice(&prefix.buffer);
bytes_mut.extend_from_slice(&suffix.buffer);
bytes_mut.freeze()
};
*self = ChecksummedBytes::new(new_bytes, new_checksum);
Expand All @@ -146,7 +184,7 @@ impl ChecksummedBytes {
///
/// Return [IntegrityError] on data corruption.
pub fn validate(&self) -> Result<(), IntegrityError> {
let checksum = crc32c::checksum(&self.orig_bytes);
let checksum = crc32c::checksum(&self.buffer);
if self.checksum != checksum {
return Err(IntegrityError::ChecksumMismatch(self.checksum, checksum));
}
Expand All @@ -160,19 +198,23 @@ impl ChecksummedBytes {
/// If you are only interested in the underlying bytes, **you should use `into_bytes()`**.
pub fn into_inner(self) -> Result<(Bytes, Crc32c), IntegrityError> {
let fit = self.shrink_to_fit()?;
Ok((fit.curr_slice, fit.checksum))
Ok((fit.buffer, fit.checksum))
}

/// Produce the portion of `buffer` that is selected by `range`.
///
/// No data is copied: the resulting `Bytes` still points to a subslice of `buffer`.
fn buffer_slice(&self) -> Bytes {
    let (start, end) = (self.range.start, self.range.end);
    self.buffer.slice(start..end)
}
}

impl Default for ChecksummedBytes {
fn default() -> Self {
let orig_bytes = Bytes::new();
let curr_slice = orig_bytes.clone();
let checksum = Crc32c::new(0);
Self {
orig_bytes,
curr_slice,
checksum,
buffer: Default::default(),
range: Default::default(),
checksum: Crc32c::new(0),
}
}
}
Expand Down Expand Up @@ -208,11 +250,7 @@ pub enum IntegrityError {
#[cfg(test)]
impl PartialEq for ChecksummedBytes {
fn eq(&self, other: &Self) -> bool {
if self.curr_slice != other.curr_slice {
return false;
}

let result = self.orig_bytes == other.orig_bytes && self.checksum == other.checksum;
let result = self.buffer_slice() == other.buffer_slice();
self.validate().expect("should be valid");
other.validate().expect("should be valid");
result
Expand All @@ -221,7 +259,10 @@ impl PartialEq for ChecksummedBytes {

#[cfg(test)]
mod tests {
use std::ops::{RangeFrom, RangeTo};

use mountpoint_s3_crt::checksums::crc32c;
use test_case::test_case;

use super::*;

Expand Down Expand Up @@ -257,10 +298,10 @@ mod tests {
let expected_part2 = expected_part1.split_off(split_off_at);
let new_checksummed_bytes = checksummed_bytes.split_off(split_off_at);

assert_eq!(expected, checksummed_bytes.orig_bytes);
assert_eq!(expected, new_checksummed_bytes.orig_bytes);
assert_eq!(expected_part1, checksummed_bytes.curr_slice);
assert_eq!(expected_part2, new_checksummed_bytes.curr_slice);
assert_eq!(expected, checksummed_bytes.buffer);
assert_eq!(expected, new_checksummed_bytes.buffer);
assert_eq!(expected_part1, checksummed_bytes.buffer_slice());
assert_eq!(expected_part2, new_checksummed_bytes.buffer_slice());
assert_eq!(checksum, checksummed_bytes.checksum);
assert_eq!(checksum, new_checksummed_bytes.checksum);
}
Expand All @@ -275,26 +316,72 @@ mod tests {
let original = ChecksummedBytes::new(bytes, checksum);
let slice = original.slice(range);

assert_eq!(expected, original.orig_bytes);
assert_eq!(expected, original.curr_slice);
assert_eq!(expected, slice.orig_bytes);
assert_eq!(expected_slice, slice.curr_slice);
assert_eq!(expected, original.buffer);
assert_eq!(expected, original.buffer_slice());
assert_eq!(expected, slice.buffer);
assert_eq!(expected_slice, slice.buffer_slice());
assert_eq!(checksum, original.checksum);
assert_eq!(checksum, slice.checksum);
}

/// Test helper: build a [ChecksummedBytes] whose `range` is set to the given value.
///
/// The backing buffer is zero-filled and `range.end` bytes long, so that `range` always lies
/// within it. (Sizing the buffer to `range.len()` would leave any range with a non-zero start,
/// such as `5..10`, pointing past the end of the buffer.) The checksum is computed over the
/// whole buffer.
fn create_checksummed_bytes_with_range(range: Range<usize>) -> ChecksummedBytes {
    let buffer = Bytes::from(vec![0u8; range.end]);
    let checksum = crc32c::checksum(&buffer);
    ChecksummedBytes {
        buffer,
        range,
        checksum,
    }
}

#[test_case(0..10, 0..10, 0..10)]
#[test_case(0..10, 5..6, 5..6)]
#[test_case(5..10, 2..4, 7..9)]
fn test_slice_range(original: Range<usize>, range: Range<usize>, expected: Range<usize>) {
    // Slicing with a bounded range must offset both ends by the start of the original range.
    let sliced = create_checksummed_bytes_with_range(original).slice(range);
    assert_eq!(sliced.range, expected);
}

#[allow(clippy::reversed_empty_ranges)]
#[should_panic]
#[test_case(5..10, 4..2; "start greater than end")]
#[test_case(5..10, 4..12; "out of bounds")]
fn test_slice_range_fail(original: Range<usize>, range: Range<usize>) {
    // Each invalid range is expected to trip one of the assertions inside `slice`.
    let source = create_checksummed_bytes_with_range(original);
    let _ = source.slice(range);
}

#[test_case(0..10, ..10, 0..10)]
#[test_case(0..10, ..6, 0..6)]
#[test_case(5..10, ..4, 5..9)]
fn test_slice_range_to(original: Range<usize>, range: RangeTo<usize>, expected: Range<usize>) {
    // An unbounded start must resolve to the start of the original range.
    let sliced = create_checksummed_bytes_with_range(original).slice(range);
    assert_eq!(sliced.range, expected);
}

#[test_case(0..10, 0.., 0..10)]
#[test_case(0..10, 4.., 4..10)]
#[test_case(5..10, 2.., 7..10)]
fn test_slice_range_from(original: Range<usize>, range: RangeFrom<usize>, expected: Range<usize>) {
    // An unbounded end must resolve to the end of the original range.
    let sliced = create_checksummed_bytes_with_range(original).slice(range);
    assert_eq!(sliced.range, expected);
}

#[test]
fn test_shrink_to_fit() {
let original = ChecksummedBytes::from_bytes(Bytes::from_static(b"some bytes"));
let unchanged = original.shrink_to_fit().unwrap();
assert_eq!(original.curr_slice, unchanged.curr_slice);
assert_eq!(original.orig_bytes, unchanged.orig_bytes);
assert_eq!(original.buffer_slice(), unchanged.buffer_slice());
assert_eq!(original.buffer, unchanged.buffer);
assert_eq!(original.checksum, unchanged.checksum);

let slice = original.clone().split_off(5);
let shrunken = slice.shrink_to_fit().unwrap();
assert_eq!(slice.curr_slice, shrunken.curr_slice);
assert_ne!(slice.orig_bytes, shrunken.orig_bytes);
assert_eq!(slice.buffer_slice(), shrunken.buffer_slice());
assert_ne!(slice.buffer, shrunken.buffer);
assert_ne!(slice.checksum, shrunken.checksum);
}

Expand All @@ -308,8 +395,8 @@ mod tests {
));

let unchanged = original.shrink_to_fit().unwrap();
assert_eq!(original.curr_slice, unchanged.curr_slice);
assert_eq!(original.orig_bytes, unchanged.orig_bytes);
assert_eq!(original.buffer_slice(), unchanged.buffer_slice());
assert_eq!(original.buffer, unchanged.buffer);
assert_eq!(original.checksum, unchanged.checksum);
assert!(matches!(
unchanged.validate(),
Expand All @@ -327,14 +414,14 @@ mod tests {
fn test_into_inner() {
let original = ChecksummedBytes::from_bytes(Bytes::from_static(b"some bytes"));
let (unchanged_bytes, unchanged_checksum) = original.clone().into_inner().unwrap();
assert_eq!(original.curr_slice, unchanged_bytes);
assert_eq!(original.orig_bytes, unchanged_bytes);
assert_eq!(original.buffer_slice(), unchanged_bytes);
assert_eq!(original.buffer, unchanged_bytes);
assert_eq!(original.checksum, unchanged_checksum);

let slice = original.clone().split_off(5);
let (shrunken_bytes, shrunken_checksum) = slice.clone().into_inner().unwrap();
assert_eq!(slice.curr_slice, shrunken_bytes);
assert_ne!(slice.orig_bytes, shrunken_bytes);
assert_eq!(slice.buffer_slice(), shrunken_bytes);
assert_ne!(slice.buffer, shrunken_bytes);
assert_ne!(slice.checksum, shrunken_checksum);
}

Expand All @@ -344,7 +431,7 @@ mod tests {
let mut checksummed_bytes = ChecksummedBytes::from_bytes(Bytes::from_static(b"some bytes"));
let extend_bytes = ChecksummedBytes::from_bytes(Bytes::from_static(b" extended"));
checksummed_bytes.extend(extend_bytes).unwrap();
let actual = checksummed_bytes.curr_slice;
let actual = checksummed_bytes.buffer_slice();
assert_eq!(expected, actual);
}

Expand All @@ -358,7 +445,7 @@ mod tests {
_ = checksummed_bytes.split_off(split_off_at);
_ = extend.split_off(split_off_at);
checksummed_bytes.extend(extend).unwrap();
let actual = checksummed_bytes.curr_slice;
let actual = checksummed_bytes.buffer_slice();
assert_eq!(expected, actual);
}

Expand Down

0 comments on commit 45414a2

Please sign in to comment.