diff --git a/zerocopy-derive/src/lib.rs b/zerocopy-derive/src/lib.rs index f1a7f389a5..74775e0468 100644 --- a/zerocopy-derive/src/lib.rs +++ b/zerocopy-derive/src/lib.rs @@ -79,22 +79,26 @@ macro_rules! try_or_print { /// are currently required to live at the crate root, and so the caller must /// specify the name in order to avoid name collisions. macro_rules! derive { - ($trait:ident => $outer:ident => $inner:ident) => { - #[proc_macro_derive($trait)] - pub fn $outer(ts: proc_macro::TokenStream) -> proc_macro::TokenStream { - let ast = syn::parse_macro_input!(ts as DeriveInput); - $inner(&ast).into() - } + ($($trait:ident => $outer:ident => $inner:ident,)*) => { + $( + #[proc_macro_derive($trait)] + pub fn $outer(ts: proc_macro::TokenStream) -> proc_macro::TokenStream { + let ast = syn::parse_macro_input!(ts as DeriveInput); + $inner(&ast, Trait::$trait).into() + } + )* }; } -derive!(KnownLayout => derive_known_layout => derive_known_layout_inner); -derive!(Immutable => derive_no_cell => derive_no_cell_inner); -derive!(TryFromBytes => derive_try_from_bytes => derive_try_from_bytes_inner); -derive!(FromZeros => derive_from_zeros => derive_from_zeros_inner); -derive!(FromBytes => derive_from_bytes => derive_from_bytes_inner); -derive!(IntoBytes => derive_into_bytes => derive_into_bytes_inner); -derive!(Unaligned => derive_unaligned => derive_unaligned_inner); +derive!( + KnownLayout => derive_known_layout => derive_known_layout_inner, + Immutable => derive_no_cell => derive_no_cell_inner, + TryFromBytes => derive_try_from_bytes => derive_try_from_bytes_inner, + FromZeros => derive_from_zeros => derive_from_zeros_inner, + FromBytes => derive_from_bytes => derive_from_bytes_inner, + IntoBytes => derive_into_bytes => derive_into_bytes_inner, + Unaligned => derive_unaligned => derive_unaligned_inner, +); /// Deprecated: prefer [`FromZeros`] instead. #[deprecated(since = "0.8.0", note = "`FromZeroes` was renamed to `FromZeros`")] @@ -112,7 +116,7 @@ pub fn derive_as_bytes(ts: proc_macro::TokenStream) -> proc_macro::TokenStream { derive_into_bytes(ts) } -fn derive_known_layout_inner(ast: &DeriveInput) -> proc_macro2::TokenStream { +fn derive_known_layout_inner(ast: &DeriveInput, _top_level: Trait) -> proc_macro2::TokenStream { let is_repr_c_struct = match &ast.data { Data::Struct(..) => { let reprs = try_or_print!(repr::reprs::(&ast.attrs)); @@ -325,7 +329,7 @@ fn derive_known_layout_inner(ast: &DeriveInput) -> proc_macro2::TokenStream { } } -fn derive_no_cell_inner(ast: &DeriveInput) -> proc_macro2::TokenStream { +fn derive_no_cell_inner(ast: &DeriveInput, _top_level: Trait) -> proc_macro2::TokenStream { match &ast.data { Data::Struct(strct) => impl_block( ast, @@ -357,16 +361,16 @@ fn derive_no_cell_inner(ast: &DeriveInput) -> proc_macro2::TokenStream { } } -fn derive_try_from_bytes_inner(ast: &DeriveInput) -> proc_macro2::TokenStream { +fn derive_try_from_bytes_inner(ast: &DeriveInput, top_level: Trait) -> proc_macro2::TokenStream { match &ast.data { - Data::Struct(strct) => derive_try_from_bytes_struct(ast, strct), - Data::Enum(enm) => derive_try_from_bytes_enum(ast, enm), - Data::Union(unn) => derive_try_from_bytes_union(ast, unn), + Data::Struct(strct) => derive_try_from_bytes_struct(ast, strct, top_level), + Data::Enum(enm) => derive_try_from_bytes_enum(ast, enm, top_level), + Data::Union(unn) => derive_try_from_bytes_union(ast, unn, top_level), } } -fn derive_from_zeros_inner(ast: &DeriveInput) -> proc_macro2::TokenStream { - let try_from_bytes = derive_try_from_bytes_inner(ast); +fn derive_from_zeros_inner(ast: &DeriveInput, top_level: Trait) -> proc_macro2::TokenStream { + let try_from_bytes = derive_try_from_bytes_inner(ast, top_level); let from_zeros = match &ast.data { Data::Struct(strct) => derive_from_zeros_struct(ast, strct), Data::Enum(enm) => derive_from_zeros_enum(ast, enm), @@ -375,8 +379,8 @@ fn derive_from_zeros_inner(ast: &DeriveInput) -> proc_macro2::TokenStream { IntoIterator::into_iter([try_from_bytes, from_zeros]).collect() } -fn derive_from_bytes_inner(ast: &DeriveInput) -> proc_macro2::TokenStream { - let from_zeros = derive_from_zeros_inner(ast); +fn derive_from_bytes_inner(ast: &DeriveInput, top_level: Trait) -> proc_macro2::TokenStream { + let from_zeros = derive_from_zeros_inner(ast, top_level); let from_bytes = match &ast.data { Data::Struct(strct) => derive_from_bytes_struct(ast, strct), Data::Enum(enm) => derive_from_bytes_enum(ast, enm), @@ -386,7 +390,7 @@ fn derive_from_bytes_inner(ast: &DeriveInput) -> proc_macro2::TokenStream { IntoIterator::into_iter([from_zeros, from_bytes]).collect() } -fn derive_into_bytes_inner(ast: &DeriveInput) -> proc_macro2::TokenStream { +fn derive_into_bytes_inner(ast: &DeriveInput, _top_level: Trait) -> proc_macro2::TokenStream { match &ast.data { Data::Struct(strct) => derive_into_bytes_struct(ast, strct), Data::Enum(enm) => derive_into_bytes_enum(ast, enm), @@ -394,7 +398,7 @@ fn derive_into_bytes_inner(ast: &DeriveInput) -> proc_macro2::TokenStream { } } -fn derive_unaligned_inner(ast: &DeriveInput) -> proc_macro2::TokenStream { +fn derive_unaligned_inner(ast: &DeriveInput, _top_level: Trait) -> proc_macro2::TokenStream { match &ast.data { Data::Struct(strct) => derive_unaligned_struct(ast, strct), Data::Enum(enm) => derive_unaligned_enum(ast, enm), @@ -405,40 +409,51 @@ fn derive_unaligned_inner(ast: &DeriveInput) -> proc_macro2::TokenStream { // A struct is `TryFromBytes` if: // - all fields are `TryFromBytes` -fn derive_try_from_bytes_struct(ast: &DeriveInput, strct: &DataStruct) -> proc_macro2::TokenStream { - let extras = Some({ - let fields = strct.fields(); - let field_names = fields.iter().map(|(name, _ty)| name); - let field_tys = fields.iter().map(|(_name, ty)| ty); - quote!( - // SAFETY: We use `is_bit_valid` to validate that each field is - // bit-valid, and only return `true` if all of them are. The bit - // validity of a struct is just the composition of the bit - // validities of its fields, so this is a sound implementation of - // `is_bit_valid`. - fn is_bit_valid>( - mut candidate: ::zerocopy::Maybe - ) -> bool { - true #(&& { - // SAFETY: - // - `project` is a field projection, and so it addresses a - // subset of the bytes addressed by `slf` - // - ..., and so it preserves provenance - // - ..., and `*slf` is a struct, so `UnsafeCell`s exist at - // the same byte ranges in the returned pointer's referent - // as they do in `*slf` - let field_candidate = unsafe { - let project = |slf: *mut Self| - ::zerocopy::macro_util::core_reexport::ptr::addr_of_mut!((*slf).#field_names); - - candidate.reborrow().project(project) - }; - - <#field_tys as ::zerocopy::TryFromBytes>::is_bit_valid(field_candidate) - })* - } - ) - }); +fn derive_try_from_bytes_struct( + ast: &DeriveInput, + strct: &DataStruct, + top_level: Trait, +) -> proc_macro2::TokenStream { + let extras = if top_level == Trait::FromBytes { + // Since the top-level trait is `FromBytes`, we know that the + // compilation will only succeed if `Self` is soundly `FromBytes`, and + // so it's sound to use a trivial `is_bit_valid` impl. + Some(gen_is_bit_valid_for_from_bytes_type()) + } else { + Some({ + let fields = strct.fields(); + let field_names = fields.iter().map(|(name, _ty)| name); + let field_tys = fields.iter().map(|(_name, ty)| ty); + quote!( + // SAFETY: We use `is_bit_valid` to validate that each field is + // bit-valid, and only return `true` if all of them are. The bit + // validity of a struct is just the composition of the bit + // validities of its fields, so this is a sound implementation of + // `is_bit_valid`. + fn is_bit_valid>( + mut candidate: ::zerocopy::Maybe + ) -> bool { + true #(&& { + // SAFETY: + // - `project` is a field projection, and so it addresses a + // subset of the bytes addressed by `slf` + // - ..., and so it preserves provenance + // - ..., and `*slf` is a struct, so `UnsafeCell`s exist at + // the same byte ranges in the returned pointer's referent + // as they do in `*slf` + let field_candidate = unsafe { + let project = |slf: *mut Self| + ::zerocopy::macro_util::core_reexport::ptr::addr_of_mut!((*slf).#field_names); + + candidate.reborrow().project(project) + }; + + <#field_tys as ::zerocopy::TryFromBytes>::is_bit_valid(field_candidate) + })* + } + ) + }) + }; impl_block( ast, strct, @@ -453,43 +468,54 @@ fn derive_try_from_bytes_struct(ast: &DeriveInput, strct: &DataStruct) -> proc_m // A union is `TryFromBytes` if: // - all of its fields are `TryFromBytes` and `Immutable` -fn derive_try_from_bytes_union(ast: &DeriveInput, unn: &DataUnion) -> proc_macro2::TokenStream { +fn derive_try_from_bytes_union( + ast: &DeriveInput, + unn: &DataUnion, + top_level: Trait, +) -> proc_macro2::TokenStream { // TODO(#5): Remove the `Immutable` bound. let field_type_trait_bounds = FieldBounds::All(&[TraitBound::Slf, TraitBound::Other(Trait::Immutable)]); - let extras = Some({ - let fields = unn.fields(); - let field_names = fields.iter().map(|(name, _ty)| name); - let field_tys = fields.iter().map(|(_name, ty)| ty); - quote!( - // SAFETY: We use `is_bit_valid` to validate that any field is - // bit-valid; we only return `true` if at least one of them is. The - // bit validity of a union is not yet well defined in Rust, but it - // is guaranteed to be no more strict than this definition. See #696 - // for a more in-depth discussion. - fn is_bit_valid>( - mut candidate: ::zerocopy::Maybe - ) -> bool { - false #(|| { - // SAFETY: - // - `project` is a field projection, and so it addresses a - // subset of the bytes addressed by `slf` - // - ..., and so it preserves provenance - // - Since `Self: Immutable` is enforced by - // `self_type_trait_bounds`, neither `*slf` nor the - // returned pointer's referent contain any `UnsafeCell`s - let field_candidate = unsafe { - let project = |slf: *mut Self| - ::zerocopy::macro_util::core_reexport::ptr::addr_of_mut!((*slf).#field_names); - - candidate.reborrow().project(project) - }; - - <#field_tys as ::zerocopy::TryFromBytes>::is_bit_valid(field_candidate) - })* - } - ) - }); + let extras = if top_level == Trait::FromBytes { + // Since the top-level trait is `FromBytes`, we know that the + // compilation will only succeed if `Self` is soundly `FromBytes`, and + // so it's sound to use a trivial `is_bit_valid` impl. + Some(gen_is_bit_valid_for_from_bytes_type()) + } else { + Some({ + let fields = unn.fields(); + let field_names = fields.iter().map(|(name, _ty)| name); + let field_tys = fields.iter().map(|(_name, ty)| ty); + quote!( + // SAFETY: We use `is_bit_valid` to validate that any field is + // bit-valid; we only return `true` if at least one of them is. The + // bit validity of a union is not yet well defined in Rust, but it + // is guaranteed to be no more strict than this definition. See #696 + // for a more in-depth discussion. + fn is_bit_valid>( + mut candidate: ::zerocopy::Maybe + ) -> bool { + false #(|| { + // SAFETY: + // - `project` is a field projection, and so it addresses a + // subset of the bytes addressed by `slf` + // - ..., and so it preserves provenance + // - Since `Self: Immutable` is enforced by + // `self_type_trait_bounds`, neither `*slf` nor the + // returned pointer's referent contain any `UnsafeCell`s + let field_candidate = unsafe { + let project = |slf: *mut Self| + ::zerocopy::macro_util::core_reexport::ptr::addr_of_mut!((*slf).#field_names); + + candidate.reborrow().project(project) + }; + + <#field_tys as ::zerocopy::TryFromBytes>::is_bit_valid(field_candidate) + })* + } + ) + }) + }; impl_block( ast, unn, @@ -508,7 +534,11 @@ const STRUCT_UNION_ALLOWED_REPR_COMBINATIONS: &[&[StructRepr]] = &[ &[StructRepr::C, StructRepr::Packed], ]; -fn derive_try_from_bytes_enum(ast: &DeriveInput, enm: &DataEnum) -> proc_macro2::TokenStream { +fn derive_try_from_bytes_enum( + ast: &DeriveInput, + enm: &DataEnum, + top_level: Trait, +) -> proc_macro2::TokenStream { if !enm.is_fieldless() { return Error::new_spanned(ast, "only field-less enums can implement TryFromBytes") .to_compile_error(); @@ -526,80 +556,99 @@ fn derive_try_from_bytes_enum(ast: &DeriveInput, enm: &DataEnum) -> proc_macro2: enm.is_fieldless() && enm.variants.len() == 1usize << size }) .unwrap_or(false); - - let variant_names = enm.variants.iter().map(|v| &v.ident); - let is_bit_valid_body = if from_bytes { - // If the enum could implement `FromBytes`, we can avoid emitting a - // match statement. This is faster to compile, and generates code which - // performs better. - quote!({ - // Prevent an "unused" warning. - let _ = candidate; - // SAFETY: If the enum could implement `FromBytes`, then all bit - // patterns are valid. Thus, this is a sound implementation. - true - }) + let extras = if from_bytes || top_level == Trait::FromBytes { + // At least one of the following conditions holds: + // - The top-level trait is `FromBytes`, so we know that the + // compilation will only succeed if `Self` is soundly `FromBytes` + // - It would be sound for the enum to implement `FromBytes` + // + // Thus, it's sound to use a trivial `is_bit_valid` impl. Using a + // trivial impl is faster to codegen, faster to compile, and is + // friendlier on the optimizer. + Some(gen_is_bit_valid_for_from_bytes_type()) } else { - quote!( - use ::zerocopy::macro_util::core_reexport; - // SAFETY: - // - The closure is a pointer cast, and `Self` and `[u8; - // size_of::()]` have the same size, so the returned pointer - // addresses the same bytes as `p` subset of the bytes addressed - // by `slf` - // - ..., and so it preserves provenance - // - Since we validate that this type is a field-less enum, it - // cannot contain any `UnsafeCell`s. Neither does `[u8; N]`. - let discriminant = unsafe { candidate.cast_unsized(|p: *mut Self| p as *mut [core_reexport::primitive::u8; core_reexport::mem::size_of::()]) }; - // SAFETY: Since `candidate` has the invariant `Initialized`, we - // know that `candidate`'s referent (and thus `discriminant`'s - // referent) are fully initialized. Since all of the allowed `repr`s - // are types for which all bytes are always initialized, we know - // that `discriminant`'s referent has all of its bytes initialized. - // Since `[u8; N]`'s validity invariant is just that all of its - // bytes are initialized, we know that `discriminant`'s referent is - // bit-valid. - let discriminant = unsafe { discriminant.assume_valid() }; - let discriminant = discriminant.read_unaligned(); - - false #(|| { - let v = Self::#variant_names{}; - // SAFETY: All of the allowed `repr`s for `Self` guarantee that - // `Self`'s discriminant bytes are all initialized. Since we - // validate that `Self` has no fields, it has no bytes other - // than the discriminant. Thus, it is sound to transmute any - // instance of `Self` to `[u8; size_of::()]`. - let d: [core_reexport::primitive::u8; core_reexport::mem::size_of::()] = unsafe { core_reexport::mem::transmute(v) }; - // SAFETY: Here we check that the bits of the argument - // `candidate` are equal to the bits of a `Self` constructed - // using safe code. If this condition passes, then we know that - // `candidate` refers to a bit-valid `Self`. - discriminant == d - })* - ) + let variant_names = enm.variants.iter().map(|v| &v.ident); + Some(quote!( + // SAFETY: We use `is_bit_valid` to validate that the bit pattern + // corresponds to one of the field-less enum's variant + // discriminants. Thus, this is a sound implementation of + // `is_bit_valid`. + fn is_bit_valid< + A: ::zerocopy::pointer::invariant::Aliasing + + ::zerocopy::pointer::invariant::AtLeast<::zerocopy::pointer::invariant::Shared>, + >( + candidate: ::zerocopy::Ptr< + '_, + Self, + ( + A, + ::zerocopy::pointer::invariant::Any, + ::zerocopy::pointer::invariant::Initialized, + ), + >, + ) -> ::zerocopy::macro_util::core_reexport::primitive::bool { + use ::zerocopy::macro_util::core_reexport; + // SAFETY: + // - The closure is a pointer cast, and `Self` and `[u8; + // size_of::()]` have the same size, so the returned + // pointer addresses the same bytes as `p` subset of the bytes + // addressed by `slf` + // - ..., and so it preserves provenance + // - Since we validate that this type is a field-less enum, it + // cannot contain any `UnsafeCell`s. Neither does `[u8; N]`. + let discriminant = unsafe { candidate.cast_unsized(|p: *mut Self| p as *mut [core_reexport::primitive::u8; core_reexport::mem::size_of::()]) }; + // SAFETY: Since `candidate` has the invariant `Initialized`, we + // know that `candidate`'s referent (and thus `discriminant`'s + // referent) are fully initialized. Since all of the allowed + // `repr`s are types for which all bytes are always initialized, + // we know that `discriminant`'s referent has all of its bytes + // initialized. Since `[u8; N]`'s validity invariant is just + // that all of its bytes are initialized, we know that + // `discriminant`'s referent is bit-valid. + let discriminant = unsafe { discriminant.assume_valid() }; + let discriminant = discriminant.read_unaligned(); + + false #(|| { + let v = Self::#variant_names{}; + // SAFETY: All of the allowed `repr`s for `Self` guarantee + // that `Self`'s discriminant bytes are all initialized. + // Since we validate that `Self` has no fields, it has no + // bytes other than the discriminant. Thus, it is sound to + // transmute any instance of `Self` to `[u8; + // size_of::()]`. + let d: [core_reexport::primitive::u8; core_reexport::mem::size_of::()] = unsafe { core_reexport::mem::transmute(v) }; + // SAFETY: Here we check that the bits of the argument + // `candidate` are equal to the bits of a `Self` constructed + // using safe code. If this condition passes, then we know + // that `candidate` refers to a bit-valid `Self`. + discriminant == d + })* + } + )) }; - let extras = Some(quote!( - // SAFETY: We use `is_bit_valid` to validate that the bit pattern - // corresponds to one of the field-less enum's variant discriminants. - // Thus, this is a sound implementation of `is_bit_valid`. - fn is_bit_valid>( - candidate: ::zerocopy::Ptr< - '_, - Self, - ( - A, - ::zerocopy::pointer::invariant::Any, - ::zerocopy::pointer::invariant::Initialized, - ), - >, - ) -> ::zerocopy::macro_util::core_reexport::primitive::bool { - #is_bit_valid_body - } - )); impl_block(ast, enm, Trait::TryFromBytes, FieldBounds::ALL_SELF, SelfBounds::None, None, extras) } +// Generates a `TryFromBytes::is_bit_valid` instance for a `FromBytes` type - +// ie, one that unconditionally returns `true`. +// +// This should be used where possible. Using this impl is faster to codegen, +// faster to compile, and is friendlier on the optimizer. +fn gen_is_bit_valid_for_from_bytes_type() -> proc_macro2::TokenStream { + quote!( + // SAFETY: `Self: FromBytes`, so all initialized byte sequences + // represent valid instances of `Self`. + fn is_bit_valid(candidate: ::zerocopy::Maybe) -> bool + where + A: ::zerocopy::pointer::invariant::Aliasing + + ::zerocopy::pointer::invariant::AtLeast<::zerocopy::pointer::invariant::Shared>, + { + true + } + ) +} + #[rustfmt::skip] const ENUM_TRY_FROM_BYTES_CFG: Config = { use EnumRepr::*;