diff --git a/mountpoint-s3/src/checksums.rs b/mountpoint-s3/src/checksums.rs index 087153064..59c10d21f 100644 --- a/mountpoint-s3/src/checksums.rs +++ b/mountpoint-s3/src/checksums.rs @@ -1,4 +1,4 @@ -use std::ops::RangeBounds; +use std::ops::{Bound, Range, RangeBounds}; use bytes::{Bytes, BytesMut}; use mountpoint_s3_crt::checksums::crc32c::{self, Crc32c}; @@ -12,19 +12,20 @@ use thiserror::Error; #[derive(Clone, Debug)] #[must_use] pub struct ChecksummedBytes { - orig_bytes: Bytes, - /// Always a subslice of `orig_bytes` - curr_slice: Bytes, - /// Checksum for `orig_bytes` + /// Underlying buffer + buffer: Bytes, + /// Range over [Self::buffer] + range: Range<usize>, + /// Checksum for [Self::buffer] checksum: Crc32c, } impl ChecksummedBytes { pub fn new(bytes: Bytes, checksum: Crc32c) -> Self { - let curr_slice = bytes.clone(); + let full_range = 0..bytes.len(); Self { - orig_bytes: bytes, - curr_slice, + buffer: bytes, + range: full_range, checksum, } } @@ -40,18 +41,17 @@ impl ChecksummedBytes { /// Return [IntegrityError] on data corruption. pub fn into_bytes(self) -> Result<Bytes, IntegrityError> { self.validate()?; - - Ok(self.curr_slice) + Ok(self.buffer_slice()) } /// Returns the number of bytes contained in this [ChecksummedBytes]. pub fn len(&self) -> usize { - self.curr_slice.len() + self.range.len() } /// Returns true if the [ChecksummedBytes] has a length of 0. pub fn is_empty(&self) -> bool { - self.curr_slice.is_empty() + self.range.is_empty() } /// Split off the checksummed bytes at the given index. @@ -61,10 +61,16 @@ impl ChecksummedBytes { /// This operation just increases the reference count and sets a few indices, /// so there will be no validation and the checksum will not be recomputed.
pub fn split_off(&mut self, at: usize) -> ChecksummedBytes { - let new_bytes = self.curr_slice.split_off(at); + assert!(at < self.len()); + + let start = self.range.start; + let prefix_range = start..(start + at); + let suffix_range = (start + at)..self.range.end; + + self.range = prefix_range; Self { - orig_bytes: self.orig_bytes.clone(), - curr_slice: new_bytes, + buffer: self.buffer.clone(), + range: suffix_range, checksum: self.checksum, } } @@ -74,9 +80,41 @@ impl ChecksummedBytes { /// This operation just increases the reference count and sets a few indices, /// so there will be no validation and the checksum will not be recomputed. pub fn slice(&self, range: impl RangeBounds<usize>) -> Self { + let sliced_range = { + let original_len = self.len(); + let original_start = self.range.start; + + let slice_start_offset = match range.start_bound() { + Bound::Included(&n) => n, + Bound::Excluded(&n) => n.checked_add(1).expect("range start greater than maximum usize"), + Bound::Unbounded => 0, + }; + + let slice_end_offset = match range.end_bound() { + Bound::Included(&n) => n.checked_add(1).expect("range end greater than maximum usize"), + Bound::Excluded(&n) => n, + Bound::Unbounded => original_len, + }; + + assert!( + slice_start_offset <= slice_end_offset, + "range start must not be greater than end: {:?} <= {:?}", + slice_start_offset, + slice_end_offset, + ); + assert!( + slice_end_offset <= original_len, + "range end out of bounds: {:?} <= {:?}", + slice_end_offset, + original_len, + ); + + (original_start + slice_start_offset)..(original_start + slice_end_offset) + }; + Self { - orig_bytes: self.orig_bytes.clone(), - curr_slice: self.curr_slice.slice(range), + buffer: self.buffer.clone(), + range: sliced_range, checksum: self.checksum, } } @@ -86,12 +124,12 @@ impl ChecksummedBytes { /// /// Return [IntegrityError] if data corruption is detected.
pub fn shrink_to_fit(&self) -> Result<Self, IntegrityError> { - if self.curr_slice.len() == self.orig_bytes.len() { + if self.len() == self.buffer.len() { return Ok(self.clone()); } - // Note that no data is copied: `bytes` still points to a subslice of `orig_bytes`. - let bytes = self.curr_slice.clone(); + // Note that no data is copied: `bytes` still points to a subslice of `buffer`. + let bytes = self.buffer_slice(); let checksum = crc32c::checksum(&bytes); let result = Self::new(bytes, checksum); @@ -124,9 +162,9 @@ impl ChecksummedBytes { // rather than the exact one for the slice, we need to first invoke `shrink_to_fit` on each // slice and use the resulting exact checksums. let prefix = self.shrink_to_fit()?; - assert_eq!(prefix.orig_bytes.len(), prefix.curr_slice.len()); + assert_eq!(prefix.buffer.len(), prefix.len()); let suffix = extend.shrink_to_fit()?; - assert_eq!(suffix.orig_bytes.len(), suffix.curr_slice.len()); + assert_eq!(suffix.buffer.len(), suffix.len()); // Combine the checksums. let new_checksum = combine_checksums(prefix.checksum, suffix.checksum, suffix.len()); @@ -134,8 +172,8 @@ impl ChecksummedBytes { // Combine the slices. let new_bytes = { let mut bytes_mut = BytesMut::with_capacity(prefix.len() + suffix.len()); - bytes_mut.extend_from_slice(&prefix.curr_slice); - bytes_mut.extend_from_slice(&suffix.curr_slice); + bytes_mut.extend_from_slice(&prefix.buffer); + bytes_mut.extend_from_slice(&suffix.buffer); bytes_mut.freeze() }; *self = ChecksummedBytes::new(new_bytes, new_checksum); @@ -146,7 +184,7 @@ impl ChecksummedBytes { /// /// Return [IntegrityError] on data corruption.
pub fn validate(&self) -> Result<(), IntegrityError> { - let checksum = crc32c::checksum(&self.orig_bytes); + let checksum = crc32c::checksum(&self.buffer); if self.checksum != checksum { return Err(IntegrityError::ChecksumMismatch(self.checksum, checksum)); } @@ -160,19 +198,23 @@ impl ChecksummedBytes { /// If you are only interested in the underlying bytes, **you should use `into_bytes()`**. pub fn into_inner(self) -> Result<(Bytes, Crc32c), IntegrityError> { let fit = self.shrink_to_fit()?; - Ok((fit.curr_slice, fit.checksum)) + Ok((fit.buffer, fit.checksum)) + } + + /// Return the slice of `buffer` corresponding to `range`. + /// + /// Note that no data is copied: the returned `Bytes` still points to a subslice of `buffer`. + fn buffer_slice(&self) -> Bytes { + self.buffer.slice(self.range.clone()) } } impl Default for ChecksummedBytes { fn default() -> Self { - let orig_bytes = Bytes::new(); - let curr_slice = orig_bytes.clone(); - let checksum = Crc32c::new(0); Self { - orig_bytes, - curr_slice, - checksum, + buffer: Default::default(), + range: Default::default(), + checksum: Crc32c::new(0), } } } @@ -208,11 +250,7 @@ pub enum IntegrityError { #[cfg(test)] impl PartialEq for ChecksummedBytes { fn eq(&self, other: &Self) -> bool { - if self.curr_slice != other.curr_slice { - return false; - } - - let result = self.orig_bytes == other.orig_bytes && self.checksum == other.checksum; + let result = self.buffer_slice() == other.buffer_slice(); self.validate().expect("should be valid"); other.validate().expect("should be valid"); result @@ -221,7 +259,10 @@ impl PartialEq for ChecksummedBytes { #[cfg(test)] mod tests { + use std::ops::{RangeFrom, RangeTo}; + use mountpoint_s3_crt::checksums::crc32c; + use test_case::test_case; use super::*; @@ -257,10 +298,10 @@ mod tests { let expected_part2 = expected_part1.split_off(split_off_at); let new_checksummed_bytes = checksummed_bytes.split_off(split_off_at); - assert_eq!(expected, checksummed_bytes.orig_bytes); - 
assert_eq!(expected, new_checksummed_bytes.orig_bytes); - assert_eq!(expected_part1, checksummed_bytes.curr_slice); - assert_eq!(expected_part2, new_checksummed_bytes.curr_slice); + assert_eq!(expected, checksummed_bytes.buffer); + assert_eq!(expected, new_checksummed_bytes.buffer); + assert_eq!(expected_part1, checksummed_bytes.buffer_slice()); + assert_eq!(expected_part2, new_checksummed_bytes.buffer_slice()); assert_eq!(checksum, checksummed_bytes.checksum); assert_eq!(checksum, new_checksummed_bytes.checksum); } @@ -275,26 +316,72 @@ mod tests { let original = ChecksummedBytes::new(bytes, checksum); let slice = original.slice(range); - assert_eq!(expected, original.orig_bytes); - assert_eq!(expected, original.curr_slice); - assert_eq!(expected, slice.orig_bytes); - assert_eq!(expected_slice, slice.curr_slice); + assert_eq!(expected, original.buffer); + assert_eq!(expected, original.buffer_slice()); + assert_eq!(expected, slice.buffer); + assert_eq!(expected_slice, slice.buffer_slice()); assert_eq!(checksum, original.checksum); assert_eq!(checksum, slice.checksum); } + fn create_checksummed_bytes_with_range(range: Range<usize>) -> ChecksummedBytes { + let buffer = Bytes::copy_from_slice(&vec![0; range.len()]); + let checksum = crc32c::checksum(&buffer); + ChecksummedBytes { + buffer, + range, + checksum, + } + } + + #[test_case(0..10, 0..10, 0..10)] + #[test_case(0..10, 5..6, 5..6)] + #[test_case(5..10, 2..4, 7..9)] + fn test_slice_range(original: Range<usize>, range: Range<usize>, expected: Range<usize>) { + let bytes = create_checksummed_bytes_with_range(original); + let slice = bytes.slice(range); + assert_eq!(slice.range, expected); + } + + #[allow(clippy::reversed_empty_ranges)] + #[should_panic] + #[test_case(5..10, 4..2; "start greater than end")] + #[test_case(5..10, 4..12; "out of bounds")] + fn test_slice_range_fail(original: Range<usize>, range: Range<usize>) { + let bytes = create_checksummed_bytes_with_range(original); + _ = bytes.slice(range); + } + + #[test_case(0..10, ..10, 0..10)] +
#[test_case(0..10, ..6, 0..6)] + #[test_case(5..10, ..4, 5..9)] + fn test_slice_range_to(original: Range<usize>, range: RangeTo<usize>, expected: Range<usize>) { + let bytes = create_checksummed_bytes_with_range(original); + let slice = bytes.slice(range); + assert_eq!(slice.range, expected); + } + + #[test_case(0..10, 0.., 0..10)] + #[test_case(0..10, 4.., 4..10)] + #[test_case(5..10, 2.., 7..10)] + fn test_slice_range_from(original: Range<usize>, range: RangeFrom<usize>, expected: Range<usize>) { + let bytes = create_checksummed_bytes_with_range(original); + let slice = bytes.slice(range); + assert_eq!(slice.range, expected); + } + #[test] fn test_shrink_to_fit() { let original = ChecksummedBytes::from_bytes(Bytes::from_static(b"some bytes")); let unchanged = original.shrink_to_fit().unwrap(); - assert_eq!(original.curr_slice, unchanged.curr_slice); - assert_eq!(original.orig_bytes, unchanged.orig_bytes); + assert_eq!(original.buffer_slice(), unchanged.buffer_slice()); + assert_eq!(original.buffer, unchanged.buffer); assert_eq!(original.checksum, unchanged.checksum); let slice = original.clone().split_off(5); let shrunken = slice.shrink_to_fit().unwrap(); - assert_eq!(slice.curr_slice, shrunken.curr_slice); - assert_ne!(slice.orig_bytes, shrunken.orig_bytes); + assert_eq!(slice.buffer_slice(), shrunken.buffer_slice()); + assert_ne!(slice.buffer, shrunken.buffer); assert_ne!(slice.checksum, shrunken.checksum); } @@ -308,8 +395,8 @@ mod tests { )); let unchanged = original.shrink_to_fit().unwrap(); - assert_eq!(original.curr_slice, unchanged.curr_slice); - assert_eq!(original.orig_bytes, unchanged.orig_bytes); + assert_eq!(original.buffer_slice(), unchanged.buffer_slice()); + assert_eq!(original.buffer, unchanged.buffer); assert_eq!(original.checksum, unchanged.checksum); assert!(matches!( unchanged.validate(), @@ -327,14 +414,14 @@ mod tests { fn test_into_inner() { let original = ChecksummedBytes::from_bytes(Bytes::from_static(b"some bytes")); let (unchanged_bytes, unchanged_checksum) =
original.clone().into_inner().unwrap(); - assert_eq!(original.curr_slice, unchanged_bytes); - assert_eq!(original.orig_bytes, unchanged_bytes); + assert_eq!(original.buffer_slice(), unchanged_bytes); + assert_eq!(original.buffer, unchanged_bytes); assert_eq!(original.checksum, unchanged_checksum); let slice = original.clone().split_off(5); let (shrunken_bytes, shrunken_checksum) = slice.clone().into_inner().unwrap(); - assert_eq!(slice.curr_slice, shrunken_bytes); - assert_ne!(slice.orig_bytes, shrunken_bytes); + assert_eq!(slice.buffer_slice(), shrunken_bytes); + assert_ne!(slice.buffer, shrunken_bytes); assert_ne!(slice.checksum, shrunken_checksum); } @@ -344,7 +431,7 @@ mod tests { let mut checksummed_bytes = ChecksummedBytes::from_bytes(Bytes::from_static(b"some bytes")); let extend_bytes = ChecksummedBytes::from_bytes(Bytes::from_static(b" extended")); checksummed_bytes.extend(extend_bytes).unwrap(); - let actual = checksummed_bytes.curr_slice; + let actual = checksummed_bytes.buffer_slice(); assert_eq!(expected, actual); } @@ -358,7 +445,7 @@ mod tests { _ = checksummed_bytes.split_off(split_off_at); _ = extend.split_off(split_off_at); checksummed_bytes.extend(extend).unwrap(); - let actual = checksummed_bytes.curr_slice; + let actual = checksummed_bytes.buffer_slice(); assert_eq!(expected, actual); }