From 1039388f0baa4532f499c7e809a92dc5552e2edd Mon Sep 17 00:00:00 2001 From: zachs18 <8355914+zachs18@users.noreply.github.com> Date: Fri, 17 Feb 2023 13:24:16 -0600 Subject: [PATCH] Fix soundness issue of TransparentWrapper derive macro. (#173) Uses the compiler to check that all non-wrapped fields are actually 1-ZSTs, and uses Zeroable to check that all non-wrapped fields are "conjurable". Additionally, relaxes the bound of `PhantomData: Zeroable` to all `T: ?Sized`. --- .gitignore | 1 + derive/src/traits.rs | 43 ++++++++++++++++++++++++++++++++----------- derive/tests/basic.rs | 34 ++++++++++++++++++++++++++++++++-- src/transparent.rs | 41 ++++++++++++++++++++++++++++++++++++++++- src/zeroable.rs | 2 +- tests/derive.rs | 34 ++++++++++++++++++++++++++++++---- 6 files changed, 136 insertions(+), 19 deletions(-) diff --git a/.gitignore b/.gitignore index f939153..fb74370 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,4 @@ Cargo.lock **/*.rs.bk /derive/target/ +/derive/.vscode/ diff --git a/derive/src/traits.rs b/derive/src/traits.rs index 1e6920b..a5b6952 100644 --- a/derive/src/traits.rs +++ b/derive/src/traits.rs @@ -218,8 +218,7 @@ impl Derivable for CheckedBitPattern { Ok(assert_fields_are_maybe_pod) } - Data::Enum(_) => Ok(quote!()), /* nothing needed, already guaranteed - * OK by NoUninit */ + Data::Enum(_) => Ok(quote!()), /* nothing needed, already guaranteed OK by NoUninit */ Data::Union(_) => bail!("Internal error in CheckedBitPattern derive"), /* shouldn't be possible since we already error in attribute check for this case */ } } @@ -273,21 +272,43 @@ impl Derivable for TransparentWrapper { } fn asserts(input: &DeriveInput) -> Result { + let (impl_generics, _ty_generics, where_clause) = + input.generics.split_for_impl(); let fields = get_struct_fields(input)?; let wrapped_type = match Self::get_wrapper_type(&input.attrs, &fields) { Some(wrapped_type) => wrapped_type.to_string(), None => unreachable!(), /* other code will already reject this derive */ }; - let mut wrapped_fields = fields - .iter() - .filter(|field| field.ty.to_token_stream().to_string() == wrapped_type); - if let None = wrapped_fields.next() { - bail!("TransparentWrapper must have one field of the wrapped type"); - }; - if let Some(_) = wrapped_fields.next() { - bail!("TransparentWrapper can only have one field of the wrapped type") + let mut wrapped_field_ty = None; + let mut nonwrapped_field_tys = vec![]; + for field in fields.iter() { + let field_ty = &field.ty; + if field_ty.to_token_stream().to_string() == wrapped_type { + if wrapped_field_ty.is_some() { + bail!( + "TransparentWrapper can only have one field of the wrapped type" + ); + } + wrapped_field_ty = Some(field_ty); + } else { + nonwrapped_field_tys.push(field_ty); + } + } + if let Some(wrapped_field_ty) = wrapped_field_ty { + Ok(quote!( + const _: () = { + #[repr(transparent)] + struct AssertWrappedIsWrapped #impl_generics((u8, ::core::marker::PhantomData<#wrapped_field_ty>), #(#nonwrapped_field_tys),*) #where_clause; + fn assert_zeroable() {} + fn check #impl_generics () #where_clause { + #( + assert_zeroable::<#nonwrapped_field_tys>(); + )* + } + }; + )) } else { - Ok(quote!()) + bail!("TransparentWrapper must have one field of the wrapped type") } } diff --git a/derive/tests/basic.rs b/derive/tests/basic.rs index 755a667..e53344e 100644 --- a/derive/tests/basic.rs +++ b/derive/tests/basic.rs @@ -4,7 +4,7 @@ use bytemuck::{ AnyBitPattern, CheckedBitPattern, Contiguous, NoUninit, Pod, TransparentWrapper, Zeroable, }; -use std::marker::PhantomData; +use std::marker::{PhantomData, PhantomPinned}; #[derive(Copy, Clone, Pod, Zeroable)] #[repr(C)] @@ -64,6 +64,14 @@ struct TransparentWithZeroSized { b: PhantomData, } +struct MyZst(PhantomData, [u8; 0], PhantomPinned); +unsafe impl Zeroable for MyZst {} + +#[derive(TransparentWrapper)] +#[repr(transparent)] +#[transparent(u16)] +struct TransparentTupleWithCustomZeroSized(u16, MyZst); + #[repr(u8)] #[derive(Clone, Copy, Contiguous)] enum ContiguousWithValues { @@ -169,6 +177,21 @@ struct AnyBitPatternTest { #[repr(transparent)] struct NewtypeWrapperTest(T); +/// ```compile_fail +/// use bytemuck::TransparentWrapper; +/// +/// struct NonTransparentSafeZST; +/// +/// #[derive(TransparentWrapper)] +/// #[repr(transparent)] +/// struct Wrapper(T, NonTransparentSafeZST); +/// ``` +#[derive( + Debug, Copy, Clone, PartialEq, Eq, Pod, Zeroable, TransparentWrapper, +)] +#[repr(transparent)] +struct TransarentWrapperZstTest(T); + #[test] fn fails_cast_contiguous() { let can_cast = CheckedBitPatternEnumWithValues::is_valid_bit_pattern(&5); @@ -207,7 +230,14 @@ fn fails_cast_bytelit() { fn passes_cast_bytelit() { let res = bytemuck::checked::cast_slice::(b"CAB"); - assert_eq!(res, [CheckedBitPatternEnumByteLit::C, CheckedBitPatternEnumByteLit::A, CheckedBitPatternEnumByteLit::B]); + assert_eq!( + res, + [ + CheckedBitPatternEnumByteLit::C, + CheckedBitPatternEnumByteLit::A, + CheckedBitPatternEnumByteLit::B + ] + ); } #[test] diff --git a/src/transparent.rs b/src/transparent.rs index 0329d1c..5b9fe0e 100644 --- a/src/transparent.rs +++ b/src/transparent.rs @@ -23,7 +23,9 @@ use super::*; /// the only non-ZST field. /// /// 2. Any fields *other* than the `Inner` field must be trivially constructable -/// ZSTs, for example `PhantomData`, `PhantomPinned`, etc. +/// ZSTs, for example `PhantomData`, `PhantomPinned`, etc. (When deriving +/// `TransparentWrapper` on a type with ZST fields, the ZST fields must be +/// [`Zeroable`]). /// /// 3. The `Wrapper` may not impose additional alignment requirements over /// `Inner`. @@ -84,6 +86,43 @@ use super::*; /// let mut buf = [1, 2, 3u8]; /// let sm = Slice::wrap_mut(&mut buf); /// ``` +/// +/// ## Deriving +/// +/// When deriving, the non-wrapped fields must uphold all the normal requirements, +/// and must also be `Zeroable`. +/// +#[cfg_attr(feature = "derive", doc = "```")] +#[cfg_attr( + not(feature = "derive"), + doc = "```ignore +// This example requires the `derive` feature." +)] +/// use bytemuck::TransparentWrapper; +/// use std::marker::PhantomData; +/// +/// #[derive(TransparentWrapper)] +/// #[repr(transparent)] +/// #[transparent(usize)] +/// struct Wrapper(usize, PhantomData); // PhantomData implements Zeroable for all T +/// ``` +/// +/// Here, an error will occur, because `MyZst` does not implement `Zeroable`. +/// +#[cfg_attr(feature = "derive", doc = "```compile_fail")] +#[cfg_attr( + not(feature = "derive"), + doc = "```ignore +// This example requires the `derive` feature." +)] +/// use bytemuck::TransparentWrapper; +/// struct MyZst; +/// +/// #[derive(TransparentWrapper)] +/// #[repr(transparent)] +/// #[transparent(usize)] +/// struct Wrapper(usize, MyZst); // MyZst does not implement Zeroable +/// ``` pub unsafe trait TransparentWrapper { /// Convert the inner type into the wrapper type. #[inline] diff --git a/src/zeroable.rs b/src/zeroable.rs index 687ba0f..d10fb1a 100644 --- a/src/zeroable.rs +++ b/src/zeroable.rs @@ -64,7 +64,7 @@ unsafe impl Zeroable for *const [T] {} unsafe impl Zeroable for *mut str {} unsafe impl Zeroable for *const str {} -unsafe impl Zeroable for PhantomData {} +unsafe impl Zeroable for PhantomData {} unsafe impl Zeroable for PhantomPinned {} unsafe impl Zeroable for ManuallyDrop {} unsafe impl Zeroable for core::cell::UnsafeCell {} diff --git a/tests/derive.rs b/tests/derive.rs index f06ff77..1c6b10e 100644 --- a/tests/derive.rs +++ b/tests/derive.rs @@ -2,6 +2,7 @@ #![allow(dead_code)] use bytemuck::{ByteEq, ByteHash, Pod, TransparentWrapper, Zeroable}; +use std::marker::PhantomData; #[derive(Copy, Clone, Pod, Zeroable, ByteEq, ByteHash)] #[repr(C)] @@ -26,7 +27,7 @@ struct TransparentWithZeroSized { #[derive(TransparentWrapper)] #[repr(transparent)] -struct TransparentWithGeneric { +struct TransparentWithGeneric { a: T, } @@ -39,9 +40,9 @@ fn test_generic(x: T) -> TransparentWithGeneric { #[derive(TransparentWrapper)] #[repr(transparent)] #[transparent(T)] -struct TransparentWithGenericAndZeroSized { - a: T, - b: () +struct TransparentWithGenericAndZeroSized { + a: (), + b: T, } /// Ensuring that no additional bounds are emitted. @@ -49,3 +50,28 @@ struct TransparentWithGenericAndZeroSized { fn test_generic_with_zst(x: T) -> TransparentWithGenericAndZeroSized { TransparentWithGenericAndZeroSized::wrap(x) } + +#[derive(TransparentWrapper)] +#[repr(transparent)] +struct TransparentUnsized { + a: dyn std::fmt::Debug, +} + +type DynDebug = dyn std::fmt::Debug; + +#[derive(TransparentWrapper)] +#[repr(transparent)] +#[transparent(DynDebug)] +struct TransparentUnsizedWithZeroSized { + a: (), + b: DynDebug, +} + +#[derive(TransparentWrapper)] +#[repr(transparent)] +#[transparent(DynDebug)] +struct TransparentUnsizedWithGenericZeroSizeds { + a: PhantomData, + b: PhantomData, + c: DynDebug, +}