Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow casting between slices of ZSTs and slices of non-ZSTs in all cases. #256

Merged
merged 1 commit into from
Jul 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 13 additions & 8 deletions src/allocation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ pub fn try_cast_slice_box<A: NoUninit, B: AnyBitPattern>(
{
// If the size in bytes of the underlying buffer does not match an exact
// multiple of the size of B, we cannot cast between them.
Err((PodCastError::SizeMismatch, input))
Err((PodCastError::OutputSliceWouldHaveSlop, input))
} else {
// Because the size is an exact multiple, we can now change the length
// of the slice and recreate the Box
Expand Down Expand Up @@ -239,7 +239,7 @@ pub fn try_cast_vec<A: NoUninit, B: AnyBitPattern>(
// length and capacity are valid under B, as we do not want to
// change which bytes are considered part of the initialized slice
// of the Vec
Err((PodCastError::SizeMismatch, input))
Err((PodCastError::OutputSliceWouldHaveSlop, input))
} else {
// Because the size is an exact multiple, we can now change the length and
// capacity and recreate the Vec
Expand Down Expand Up @@ -431,7 +431,7 @@ pub fn try_cast_slice_rc<
{
// If the size in bytes of the underlying buffer does not match an exact
// multiple of the size of B, we cannot cast between them.
Err((PodCastError::SizeMismatch, input))
Err((PodCastError::OutputSliceWouldHaveSlop, input))
} else {
// Because the size is an exact multiple, we can now change the length
// of the slice and recreate the Rc
Expand Down Expand Up @@ -499,7 +499,7 @@ pub fn try_cast_slice_arc<
{
// If the size in bytes of the underlying buffer does not match an exact
// multiple of the size of B, we cannot cast between them.
Err((PodCastError::SizeMismatch, input))
Err((PodCastError::OutputSliceWouldHaveSlop, input))
} else {
// Because the size is an exact multiple, we can now change the length
// of the slice and recreate the Arc
Expand Down Expand Up @@ -846,13 +846,18 @@ impl<T: AnyBitPattern> sealed::FromBoxBytes for [T] {
let single_layout = Layout::new::<T>();
if bytes.layout.align() != single_layout.align() {
Err((PodCastError::AlignmentMismatch, bytes))
} else if single_layout.size() == 0 {
Err((PodCastError::SizeMismatch, bytes))
} else if bytes.layout.size() % single_layout.size() != 0 {
} else if (single_layout.size() == 0 && bytes.layout.size() != 0)
|| (single_layout.size() != 0
&& bytes.layout.size() % single_layout.size() != 0)
{
Err((PodCastError::OutputSliceWouldHaveSlop, bytes))
} else {
let (ptr, layout) = bytes.into_raw_parts();
let length = layout.size() / single_layout.size();
let length = if single_layout.size() != 0 {
layout.size() / single_layout.size()
} else {
0
};
let ptr =
core::ptr::slice_from_raw_parts_mut(ptr.as_ptr() as *mut T, length);
// SAFETY: See BoxBytes type invariant.
Expand Down
2 changes: 0 additions & 2 deletions src/checked.rs
Original file line number Diff line number Diff line change
Expand Up @@ -368,8 +368,6 @@ pub fn try_cast_mut<
/// type, and the output slice wouldn't be a whole number of elements when
/// accounting for the size change (eg: 3 `u16` values is 1.5 `u32` values, so
/// that's a failure).
/// * Similarly, you can't convert between a [ZST](https://doc.rust-lang.org/nomicon/exotic-sizes.html#zero-sized-types-zsts)
/// and a non-ZST.
/// * If any element of the converted slice would contain an invalid bit pattern
/// for `B` this fails.
#[inline]
Expand Down
42 changes: 18 additions & 24 deletions src/internal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,9 @@ pub(crate) fn something_went_wrong<D>(_src: &str, _err: D) -> ! {
/// empty slice might not match the pointer value of the input reference.
#[inline(always)]
pub(crate) unsafe fn bytes_of<T: Copy>(t: &T) -> &[u8] {
if size_of::<T>() == 0 {
&[]
} else {
match try_cast_slice::<T, u8>(core::slice::from_ref(t)) {
Ok(s) => s,
Err(_) => unreachable!(),
}
match try_cast_slice::<T, u8>(core::slice::from_ref(t)) {
Ok(s) => s,
Err(_) => unreachable!(),
}
}

Expand All @@ -67,13 +63,9 @@ pub(crate) unsafe fn bytes_of<T: Copy>(t: &T) -> &[u8] {
/// empty slice might not match the pointer value of the input reference.
#[inline]
pub(crate) unsafe fn bytes_of_mut<T: Copy>(t: &mut T) -> &mut [u8] {
if size_of::<T>() == 0 {
&mut []
} else {
match try_cast_slice_mut::<T, u8>(core::slice::from_mut(t)) {
Ok(s) => s,
Err(_) => unreachable!(),
}
match try_cast_slice_mut::<T, u8>(core::slice::from_mut(t)) {
Ok(s) => s,
Err(_) => unreachable!(),
}
}

Expand Down Expand Up @@ -347,12 +339,11 @@ pub(crate) unsafe fn try_cast_mut<A: Copy, B: Copy>(
/// type, and the output slice wouldn't be a whole number of elements when
/// accounting for the size change (eg: 3 `u16` values is 1.5 `u32` values, so
/// that's a failure).
/// * Similarly, you can't convert between a [ZST](https://doc.rust-lang.org/nomicon/exotic-sizes.html#zero-sized-types-zsts)
/// and a non-ZST.
#[inline]
pub(crate) unsafe fn try_cast_slice<A: Copy, B: Copy>(
a: &[A],
) -> Result<&[B], PodCastError> {
let input_bytes = core::mem::size_of_val::<[A]>(a);
// Note(Lokathor): everything with `align_of` and `size_of` will optimize away
// after monomorphization.
if align_of::<B>() > align_of::<A>()
Expand All @@ -361,10 +352,11 @@ pub(crate) unsafe fn try_cast_slice<A: Copy, B: Copy>(
Err(PodCastError::TargetAlignmentGreaterAndInputNotAligned)
} else if size_of::<B>() == size_of::<A>() {
Ok(unsafe { core::slice::from_raw_parts(a.as_ptr() as *const B, a.len()) })
} else if size_of::<A>() == 0 || size_of::<B>() == 0 {
Err(PodCastError::SizeMismatch)
} else if core::mem::size_of_val(a) % size_of::<B>() == 0 {
let new_len = core::mem::size_of_val(a) / size_of::<B>();
} else if (size_of::<B>() != 0 && input_bytes % size_of::<B>() == 0)
|| (size_of::<B>() == 0 && input_bytes == 0)
{
let new_len =
if size_of::<B>() != 0 { input_bytes / size_of::<B>() } else { 0 };
Ok(unsafe { core::slice::from_raw_parts(a.as_ptr() as *const B, new_len) })
} else {
Err(PodCastError::OutputSliceWouldHaveSlop)
Expand All @@ -379,6 +371,7 @@ pub(crate) unsafe fn try_cast_slice<A: Copy, B: Copy>(
pub(crate) unsafe fn try_cast_slice_mut<A: Copy, B: Copy>(
a: &mut [A],
) -> Result<&mut [B], PodCastError> {
let input_bytes = core::mem::size_of_val::<[A]>(a);
// Note(Lokathor): everything with `align_of` and `size_of` will optimize away
// after monomorphization.
if align_of::<B>() > align_of::<A>()
Expand All @@ -389,10 +382,11 @@ pub(crate) unsafe fn try_cast_slice_mut<A: Copy, B: Copy>(
Ok(unsafe {
core::slice::from_raw_parts_mut(a.as_mut_ptr() as *mut B, a.len())
})
} else if size_of::<A>() == 0 || size_of::<B>() == 0 {
Err(PodCastError::SizeMismatch)
} else if core::mem::size_of_val(a) % size_of::<B>() == 0 {
let new_len = core::mem::size_of_val(a) / size_of::<B>();
} else if (size_of::<B>() != 0 && input_bytes % size_of::<B>() == 0)
|| (size_of::<B>() == 0 && input_bytes == 0)
{
let new_len =
if size_of::<B>() != 0 { input_bytes / size_of::<B>() } else { 0 };
Ok(unsafe {
core::slice::from_raw_parts_mut(a.as_mut_ptr() as *mut B, new_len)
})
Expand Down
11 changes: 5 additions & 6 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -198,15 +198,14 @@ pub use bytemuck_derive::{
/// The things that can go wrong when casting between [`Pod`] data forms.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PodCastError {
/// You tried to cast a slice to an element type with a higher alignment
/// requirement but the slice wasn't aligned.
/// You tried to cast a reference into a reference to a type with a higher alignment
/// requirement but the input reference wasn't aligned.
TargetAlignmentGreaterAndInputNotAligned,
/// If the element size changes then the output slice changes length
/// accordingly. If the output slice wouldn't be a whole number of elements
/// If the element size of a slice changes, then the output slice changes length
/// accordingly. If the output slice wouldn't be a whole number of elements,
/// then the conversion fails.
OutputSliceWouldHaveSlop,
/// When casting a slice you can't convert between ZST elements and non-ZST
/// elements. When casting an individual `T`, `&T`, or `&mut T` value the
/// When casting an individual `T`, `&T`, or `&mut T` value the
/// source size and destination size must be an exact match.
SizeMismatch,
/// For this type of cast the alignments must be exactly the same and they
Expand Down
34 changes: 27 additions & 7 deletions src/must.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ impl<A, B> Cast<A, B> {
const ASSERT_ALIGN_GREATER_THAN_EQUAL: () =
assert!(align_of::<A>() >= align_of::<B>());
const ASSERT_SIZE_EQUAL: () = assert!(size_of::<A>() == size_of::<B>());
const ASSERT_SIZE_MULTIPLE_OF: () = assert!(
(size_of::<A>() == 0) == (size_of::<B>() == 0)
&& (size_of::<A>() % size_of::<B>() == 0)
const ASSERT_SIZE_MULTIPLE_OF_OR_INPUT_ZST: () = assert!(
(size_of::<A>() == 0)
|| (size_of::<B>() != 0 && size_of::<A>() % size_of::<B>() == 0)
);
}

Expand Down Expand Up @@ -113,15 +113,20 @@ pub fn must_cast_mut<
/// * If the target type has a greater alignment requirement.
/// * If the target element type doesn't evenly fit into the the current element
/// type (eg: 3 `u16` values is 1.5 `u32` values, so that's a failure).
/// * Similarly, you can't convert between a [ZST](https://doc.rust-lang.org/nomicon/exotic-sizes.html#zero-sized-types-zsts)
/// and a non-ZST.
/// * Similarly, you can't convert from a non-[ZST](https://doc.rust-lang.org/nomicon/exotic-sizes.html#zero-sized-types-zsts)
/// to a ZST (e.g. 3 `u8` values is not any number of `()` values).
///
/// ## Examples
/// ```
/// let indicies: &[u16] = &[1, 2, 3];
/// // compiles:
/// let bytes: &[u8] = bytemuck::must_cast_slice(indicies);
/// ```
/// ```
/// let zsts: &[()] = &[(), (), ()];
/// // compiles:
/// let bytes: &[u8] = bytemuck::must_cast_slice(zsts);
/// ```
/// ```compile_fail,E0080
/// # let bytes : &[u8] = &[1, 0, 2, 0, 3, 0];
/// // fails to compile (bytes.len() might not be a multiple of 2):
Expand All @@ -132,9 +137,14 @@ pub fn must_cast_mut<
/// // fails to compile (alignment requirements increased):
/// let indicies : &[u16] = bytemuck::must_cast_slice(byte_pairs);
/// ```
/// ```compile_fail,E0080
/// let bytes: &[u8] = &[];
/// // fails to compile: (bytes.len() might not be 0)
/// let zsts: &[()] = bytemuck::must_cast_slice(bytes);
/// ```
#[inline]
pub fn must_cast_slice<A: NoUninit, B: AnyBitPattern>(a: &[A]) -> &[B] {
let _ = Cast::<A, B>::ASSERT_SIZE_MULTIPLE_OF;
let _ = Cast::<A, B>::ASSERT_SIZE_MULTIPLE_OF_OR_INPUT_ZST;
let _ = Cast::<A, B>::ASSERT_ALIGN_GREATER_THAN_EQUAL;
let new_len = if size_of::<A>() == size_of::<B>() {
a.len()
Expand All @@ -156,6 +166,11 @@ pub fn must_cast_slice<A: NoUninit, B: AnyBitPattern>(a: &[A]) -> &[B] {
/// // compiles:
/// let bytes: &mut [u8] = bytemuck::must_cast_slice_mut(indicies);
/// ```
/// ```
/// let zsts: &mut [()] = &mut [(), (), ()];
/// // compiles:
/// let bytes: &mut [u8] = bytemuck::must_cast_slice_mut(zsts);
/// ```
/// ```compile_fail,E0080
/// # let mut bytes = [1, 0, 2, 0, 3, 0];
/// # let bytes : &mut [u8] = &mut bytes[..];
Expand All @@ -168,14 +183,19 @@ pub fn must_cast_slice<A: NoUninit, B: AnyBitPattern>(a: &[A]) -> &[B] {
/// // fails to compile (alignment requirements increased):
/// let indicies : &mut [u16] = bytemuck::must_cast_slice_mut(byte_pairs);
/// ```
/// ```compile_fail,E0080
/// let bytes: &mut [u8] = &mut [];
/// // fails to compile: (bytes.len() might not be 0)
/// let zsts: &mut [()] = bytemuck::must_cast_slice_mut(bytes);
/// ```
#[inline]
pub fn must_cast_slice_mut<
A: NoUninit + AnyBitPattern,
B: NoUninit + AnyBitPattern,
>(
a: &mut [A],
) -> &mut [B] {
let _ = Cast::<A, B>::ASSERT_SIZE_MULTIPLE_OF;
let _ = Cast::<A, B>::ASSERT_SIZE_MULTIPLE_OF_OR_INPUT_ZST;
let _ = Cast::<A, B>::ASSERT_ALIGN_GREATER_THAN_EQUAL;
let new_len = if size_of::<A>() == size_of::<B>() {
a.len()
Expand Down
39 changes: 30 additions & 9 deletions tests/cast_slice_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,30 @@ fn test_panics() {
should_panic!(from_bytes::<u32>(&aligned_bytes[1..5]));
}

#[test]
fn test_zsts() {
#[derive(Debug, Clone, Copy)]
struct MyZst;
unsafe impl Zeroable for MyZst {}
unsafe impl Pod for MyZst {}
assert_eq!(42, cast_slice::<(), MyZst>(&[(); 42]).len());
assert_eq!(42, cast_slice_mut::<(), MyZst>(&mut [(); 42]).len());
assert_eq!(0, cast_slice::<(), u8>(&[(); 42]).len());
assert_eq!(0, cast_slice_mut::<(), u8>(&mut [(); 42]).len());
assert_eq!(0, cast_slice::<u8, ()>(&[]).len());
assert_eq!(0, cast_slice_mut::<u8, ()>(&mut []).len());

assert_eq!(
PodCastError::OutputSliceWouldHaveSlop,
try_cast_slice::<u8, ()>(&[42]).unwrap_err()
);

assert_eq!(
PodCastError::OutputSliceWouldHaveSlop,
try_cast_slice_mut::<u8, ()>(&mut [42]).unwrap_err()
);
}

#[cfg(feature = "extern_crate_alloc")]
#[test]
fn test_boxed_slices() {
Expand All @@ -209,7 +233,6 @@ fn test_boxed_slices() {
result.expect_err("u16 and i8 have different alignment");
assert_eq!(error, PodCastError::AlignmentMismatch);

// FIXME(#253): Should these next two casts' errors be consistent?
let result: Result<&[[i8; 3]], PodCastError> =
try_cast_slice(&*boxed_i8_slice);
let error =
Expand All @@ -220,7 +243,7 @@ fn test_boxed_slices() {
try_cast_slice_box(boxed_i8_slice);
let (error, boxed_i8_slice) =
result.expect_err("slice of [i8; 3] cannot be made from slice of 4 i8s");
assert_eq!(error, PodCastError::SizeMismatch);
assert_eq!(error, PodCastError::OutputSliceWouldHaveSlop);

let empty: Box<[()]> = cast_slice_box::<u8, ()>(Box::new([]));
assert!(empty.is_empty());
Expand All @@ -229,7 +252,7 @@ fn test_boxed_slices() {
try_cast_slice_box(boxed_i8_slice);
let (error, boxed_i8_slice) =
result.expect_err("slice of ZST cannot be made from slice of 4 u8s");
assert_eq!(error, PodCastError::SizeMismatch);
assert_eq!(error, PodCastError::OutputSliceWouldHaveSlop);

drop(boxed_i8_slice);

Expand All @@ -254,7 +277,6 @@ fn test_rc_slices() {
result.expect_err("u16 and i8 have different alignment");
assert_eq!(error, PodCastError::AlignmentMismatch);

// FIXME(#253): Should these next two casts' errors be consistent?
let result: Result<&[[i8; 3]], PodCastError> = try_cast_slice(&*rc_i8_slice);
let error =
result.expect_err("slice of [i8; 3] cannot be made from slice of 4 i8s");
Expand All @@ -264,7 +286,7 @@ fn test_rc_slices() {
try_cast_slice_rc(rc_i8_slice);
let (error, rc_i8_slice) =
result.expect_err("slice of [i8; 3] cannot be made from slice of 4 i8s");
assert_eq!(error, PodCastError::SizeMismatch);
assert_eq!(error, PodCastError::OutputSliceWouldHaveSlop);

let empty: Rc<[()]> = cast_slice_rc::<u8, ()>(Rc::new([]));
assert!(empty.is_empty());
Expand All @@ -273,7 +295,7 @@ fn test_rc_slices() {
try_cast_slice_rc(rc_i8_slice);
let (error, rc_i8_slice) =
result.expect_err("slice of ZST cannot be made from slice of 4 u8s");
assert_eq!(error, PodCastError::SizeMismatch);
assert_eq!(error, PodCastError::OutputSliceWouldHaveSlop);

drop(rc_i8_slice);

Expand All @@ -299,7 +321,6 @@ fn test_arc_slices() {
result.expect_err("u16 and i8 have different alignment");
assert_eq!(error, PodCastError::AlignmentMismatch);

// FIXME(#253): Should these next two casts' errors be consistent?
let result: Result<&[[i8; 3]], PodCastError> = try_cast_slice(&*arc_i8_slice);
let error =
result.expect_err("slice of [i8; 3] cannot be made from slice of 4 i8s");
Expand All @@ -309,7 +330,7 @@ fn test_arc_slices() {
try_cast_slice_arc(arc_i8_slice);
let (error, arc_i8_slice) =
result.expect_err("slice of [i8; 3] cannot be made from slice of 4 i8s");
assert_eq!(error, PodCastError::SizeMismatch);
assert_eq!(error, PodCastError::OutputSliceWouldHaveSlop);

let empty: Arc<[()]> = cast_slice_arc::<u8, ()>(Arc::new([]));
assert!(empty.is_empty());
Expand All @@ -318,7 +339,7 @@ fn test_arc_slices() {
try_cast_slice_arc(arc_i8_slice);
let (error, arc_i8_slice) =
result.expect_err("slice of ZST cannot be made from slice of 4 u8s");
assert_eq!(error, PodCastError::SizeMismatch);
assert_eq!(error, PodCastError::OutputSliceWouldHaveSlop);

drop(arc_i8_slice);

Expand Down
Loading