Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[derive] TryFromBytes on repr(C) enums #806

Merged
merged 1 commit into from
Jan 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 36 additions & 93 deletions zerocopy-derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ use {
quote::quote,
syn::{
parse_quote, Data, DataEnum, DataStruct, DataUnion, DeriveInput, Error, Expr, ExprLit,
GenericParam, Ident, Index, Lit,
GenericParam, Ident, Lit,
},
};

Expand Down Expand Up @@ -405,29 +405,13 @@ fn derive_try_from_bytes_enum(ast: &DeriveInput, enm: &DataEnum) -> proc_macro2:
.to_compile_error();
}

let reprs = try_or_print!(ENUM_TRY_FROM_BYTES_CFG.validate_reprs(ast));
let discriminant_type = match reprs.as_slice() {
[EnumRepr::U8] => quote!(u8),
[EnumRepr::U16] => quote!(u16),
[EnumRepr::U32] => quote!(u32),
[EnumRepr::U64] => quote!(u64),
[EnumRepr::Usize] => quote!(usize),
[EnumRepr::I8] => quote!(i8),
[EnumRepr::I16] => quote!(i16),
[EnumRepr::I32] => quote!(i32),
[EnumRepr::I64] => quote!(i64),
[EnumRepr::Isize] => quote!(isize),
// `validate_reprs` has already validated that it's one of the preceding
// patterns.
_ => unreachable!(),
};

let discriminant_exprs = enm.variants.iter().scan(Discriminant::default(), |disc, var| {
Some(disc.update_and_generate_expr(&var.discriminant))
});
// We don't actually care what the repr is; we just care that it's one of
// the allowed ones.
try_or_print!(ENUM_TRY_FROM_BYTES_CFG.validate_reprs(ast));
let variant_names = enm.variants.iter().map(|v| &v.ident);
let extras = Some(quote!(
// SAFETY: We use `is_bit_valid` to validate that the bit pattern
// corresponds to one of the C-like enum's variant discriminants.
// corresponds to one of the field-less enum's variant discriminants.
// Thus, this is a sound implementation of `is_bit_valid`.
fn is_bit_valid(
candidate: ::zerocopy::Ptr<
Expand All @@ -439,90 +423,51 @@ fn derive_try_from_bytes_enum(ast: &DeriveInput, enm: &DataEnum) -> proc_macro2:
::zerocopy::pointer::invariant::AsInitialized,
),
>,
) -> bool {
) -> ::zerocopy::macro_util::core_reexport::primitive::bool {
use ::zerocopy::macro_util::core_reexport;
// SAFETY:
// - `cast` is implemented as required.
// - Since we cast to the type specified by `Self`'s repr, `p`'s
// referent and the referent of the returned pointer have the
// same size.
let discriminant = unsafe { candidate.cast_unsized(|p: *mut Self| p as *mut ::zerocopy::macro_util::core_reexport::primitive::#discriminant_type) };
// SAFETY: Since `candidate` has the invariant `AsInitialized`,
// we know that `candidate`'s referent (and thus
// `discriminant`'s referent) is as-initialized as `Self`. Since
// `Self`'s repr is the same type as `discriminant`, we know
// that `discriminant`'s referent satisfies the as-initialized
// property.
// - By definition, `*mut Self` and `*mut [u8; size_of::<Self>()]`
// are types of the same size.
let discriminant = unsafe { candidate.cast_unsized(|p: *mut Self| p as *mut [core_reexport::primitive::u8; core_reexport::mem::size_of::<Self>()]) };
// SAFETY: Since `candidate` has the invariant `AsInitialized`, we
// know that `candidate`'s referent (and thus `discriminant`'s
// referent) is as-initialized as `Self`. Since all of the allowed
// `repr`s are types for which all bytes are always initialized, we
// know that `discriminant`'s referent has all of its bytes
// initialized. Since `[u8; N]`'s validity invariant is just that
// all of its bytes are initialized, we know that `discriminant`'s
// referent is bit-valid.
let discriminant = unsafe { discriminant.assume_valid() };
let discriminant = discriminant.read_unaligned();

false #(|| (discriminant == (#discriminant_exprs)))*
false #(|| {
let v = Self::#variant_names{};
// SAFETY: All of the allowed `repr`s for `Self` guarantee that
// `Self`'s discriminant bytes are all initialized. Since we
// validate that `Self` has no fields, it has no bytes other
// than the discriminant. Thus, it is sound to transmute any
// instance of `Self` to `[u8; size_of::<Self>()]`.
let d: [core_reexport::primitive::u8; core_reexport::mem::size_of::<Self>()] = unsafe { core_reexport::mem::transmute(v) };
// SAFETY: Here we check that the bits of the argument
// `candidate` are equal to the bits of a `Self` constructed
// using safe code. If this condition passes, then we know that
// `candidate` refers to a bit-valid `Self`.
discriminant == d
})*
}
));
impl_block(ast, enm, Trait::TryFromBytes, RequireBoundedFields::Yes, false, None, extras)
}

// Enum variant discriminants can be manually set not only as literal values,
// but as arbitrary const expressions. In order to handle this, we keep track of
// the most-recently-seen expression and a count of how many variants have been
// encountered since then.
//
// #[repr(u8)]
// enum Foo {
// A, // 0
// B = 5, // 5
// C, // 6
// D = 1 + 1, // 2
// E, // 3
// }
//
// Note: Default::default does the right thing (initializes to { None, 0 }).
#[derive(Default, Copy, Clone)]
struct Discriminant<'a> {
// The most-recently-set explicit discriminant.
previous: Option<&'a Expr>,
// When the next variant is encountered, what offset should be used compared
// to `previous` to determine the variant's discriminant?
next_offset: usize,
}

impl<'a> Discriminant<'a> {
/// Called when encountering a variant with discriminant set to `ast`.
/// Updates `self` in preparation for the next variant and generates an
/// expression which will evaluate to the numeric value this variant's
/// discriminant.
fn update_and_generate_expr(
&mut self,
ast: &'a Option<(syn::token::Eq, Expr)>,
) -> proc_macro2::TokenStream {
match ast.as_ref().map(|(_eq, expr)| expr) {
Some(expr) => {
self.previous = Some(expr);
self.next_offset = 1;
quote!(#expr)
}
None => {
let previous = self.previous.iter();
// Use `Index` instead of `usize` so that the number is
// formatted just as `0` rather than as `0usize`; the latter
// syntax is only valid if the repr is `usize`; otherwise,
// comparison will result in a type mismatch.
let offset = Index::from(self.next_offset);
let tokens = quote!(#(#previous +)* #offset);

self.next_offset += 1;
tokens
}
}
}
}

#[rustfmt::skip]
const ENUM_TRY_FROM_BYTES_CFG: Config<EnumRepr> = {
use EnumRepr::*;
Config {
allowed_combinations_message: r#"TryFromBytes requires repr of "u8", "u16", "u32", "u64", "usize", "i8", or "i16", "i32", "i64", or "isize""#,
allowed_combinations_message: r#"TryFromBytes requires repr of "C", "u8", "u16", "u32", "u64", "usize", "i8", or "i16", "i32", "i64", or "isize""#,
derive_unaligned: false,
allowed_combinations: &[
&[C],
&[U8],
&[U16],
&[U32],
Expand All @@ -534,9 +479,7 @@ const ENUM_TRY_FROM_BYTES_CFG: Config<EnumRepr> = {
&[I64],
&[Isize],
],
disallowed_but_legal_combinations: &[
&[C],
],
disallowed_but_legal_combinations: &[],
}
};

Expand Down
93 changes: 57 additions & 36 deletions zerocopy-derive/tests/enum_try_from_bytes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

use std::convert::TryFrom;

use syn::Field;

mod util;

use {
Expand All @@ -27,10 +29,10 @@ assert_impl_all!(Foo: TryFromBytes);

#[test]
fn test_foo() {
assert_eq!(Foo::try_from_ref(&[0]), Some(&Foo::A));
assert_eq!(Foo::try_from_ref(&[]), None);
assert_eq!(Foo::try_from_ref(&[1]), None);
assert_eq!(Foo::try_from_ref(&[0, 0]), None);
assert_eq!(Foo::try_read_from(&[0]), Some(Foo::A));
assert_eq!(Foo::try_read_from(&[]), None);
assert_eq!(Foo::try_read_from(&[1]), None);
assert_eq!(Foo::try_read_from(&[0, 0]), None);
}

#[derive(Eq, PartialEq, Debug, KnownLayout, TryFromBytes)]
Expand All @@ -43,11 +45,11 @@ assert_impl_all!(Bar: TryFromBytes);

#[test]
fn test_bar() {
assert_eq!(Bar::try_from_ref(&[0, 0]), Some(&Bar::A));
assert_eq!(Bar::try_from_ref(&[]), None);
assert_eq!(Bar::try_from_ref(&[0]), None);
assert_eq!(Bar::try_from_ref(&[0, 1]), None);
assert_eq!(Bar::try_from_ref(&[0, 0, 0]), None);
assert_eq!(Bar::try_read_from(&[0, 0]), Some(Bar::A));
assert_eq!(Bar::try_read_from(&[]), None);
assert_eq!(Bar::try_read_from(&[0]), None);
assert_eq!(Bar::try_read_from(&[0, 1]), None);
assert_eq!(Bar::try_read_from(&[0, 0, 0]), None);
}

#[derive(Eq, PartialEq, Debug, KnownLayout, TryFromBytes)]
Expand All @@ -61,13 +63,13 @@ assert_impl_all!(Baz: TryFromBytes);

#[test]
fn test_baz() {
assert_eq!(Baz::try_from_ref(1u32.as_bytes()), Some(&Baz::A));
assert_eq!(Baz::try_from_ref(0u32.as_bytes()), Some(&Baz::B));
assert_eq!(Baz::try_from_ref(&[]), None);
assert_eq!(Baz::try_from_ref(&[0]), None);
assert_eq!(Baz::try_from_ref(&[0, 0]), None);
assert_eq!(Baz::try_from_ref(&[0, 0, 0]), None);
assert_eq!(Baz::try_from_ref(&[0, 0, 0, 0, 0]), None);
assert_eq!(Baz::try_read_from(1u32.as_bytes()), Some(Baz::A));
assert_eq!(Baz::try_read_from(0u32.as_bytes()), Some(Baz::B));
assert_eq!(Baz::try_read_from(&[]), None);
assert_eq!(Baz::try_read_from(&[0]), None);
assert_eq!(Baz::try_read_from(&[0, 0]), None);
assert_eq!(Baz::try_read_from(&[0, 0, 0]), None);
assert_eq!(Baz::try_read_from(&[0, 0, 0, 0, 0]), None);
}

// Test hygiene - make sure that `i8` being shadowed doesn't cause problems for
Expand All @@ -89,17 +91,17 @@ assert_impl_all!(Blah: TryFromBytes);

#[test]
fn test_blah() {
assert_eq!(Blah::try_from_ref(1i8.as_bytes()), Some(&Blah::A));
assert_eq!(Blah::try_from_ref(0i8.as_bytes()), Some(&Blah::B));
assert_eq!(Blah::try_from_ref(3i8.as_bytes()), Some(&Blah::C));
assert_eq!(Blah::try_from_ref(6i8.as_bytes()), Some(&Blah::D));
assert_eq!(Blah::try_from_ref(&[]), None);
assert_eq!(Blah::try_from_ref(&[4]), None);
assert_eq!(Blah::try_from_ref(&[0, 0]), None);
assert_eq!(Blah::try_read_from(1i8.as_bytes()), Some(Blah::A));
assert_eq!(Blah::try_read_from(0i8.as_bytes()), Some(Blah::B));
assert_eq!(Blah::try_read_from(3i8.as_bytes()), Some(Blah::C));
assert_eq!(Blah::try_read_from(6i8.as_bytes()), Some(Blah::D));
assert_eq!(Blah::try_read_from(&[]), None);
assert_eq!(Blah::try_read_from(&[4]), None);
assert_eq!(Blah::try_read_from(&[0, 0]), None);
}

#[derive(Eq, PartialEq, Debug, KnownLayout, TryFromBytes)]
#[repr(u8)]
#[derive(Eq, PartialEq, Debug, KnownLayout, TryFromBytes, IntoBytes)]
#[repr(C)]
enum FieldlessButNotUnitOnly {
A,
B(),
Expand All @@ -108,19 +110,38 @@ enum FieldlessButNotUnitOnly {

#[test]
fn test_fieldless_but_not_unit_only() {
const SIZE: usize = core::mem::size_of::<FieldlessButNotUnitOnly>();
let disc: [u8; SIZE] = zerocopy::transmute!(FieldlessButNotUnitOnly::A);
assert_eq!(FieldlessButNotUnitOnly::try_read_from(&disc[..]), Some(FieldlessButNotUnitOnly::A));
let disc: [u8; SIZE] = zerocopy::transmute!(FieldlessButNotUnitOnly::B());
assert_eq!(
FieldlessButNotUnitOnly::try_from_ref(0u8.as_bytes()),
Some(&FieldlessButNotUnitOnly::A)
);
assert_eq!(
FieldlessButNotUnitOnly::try_from_ref(1u8.as_bytes()),
Some(&FieldlessButNotUnitOnly::B())
FieldlessButNotUnitOnly::try_read_from(&disc[..]),
Some(FieldlessButNotUnitOnly::B())
);
let disc: [u8; SIZE] = zerocopy::transmute!(FieldlessButNotUnitOnly::C {});
assert_eq!(
FieldlessButNotUnitOnly::try_from_ref(2u8.as_bytes()),
Some(&FieldlessButNotUnitOnly::C {})
FieldlessButNotUnitOnly::try_read_from(&disc[..]),
Some(FieldlessButNotUnitOnly::C {})
);
assert_eq!(FieldlessButNotUnitOnly::try_from_ref(&[]), None);
assert_eq!(FieldlessButNotUnitOnly::try_from_ref(&[3]), None);
assert_eq!(FieldlessButNotUnitOnly::try_from_ref(&[0, 0]), None);
assert_eq!(FieldlessButNotUnitOnly::try_read_from(&[0xFF; SIZE][..]), None);
}

#[derive(Eq, PartialEq, Debug, KnownLayout, TryFromBytes, IntoBytes)]
#[repr(C)]
enum WeirdDiscriminants {
A = -7,
B,
C = 33,
}

#[test]
fn test_weird_discriminants() {
const SIZE: usize = core::mem::size_of::<WeirdDiscriminants>();
let disc: [u8; SIZE] = zerocopy::transmute!(WeirdDiscriminants::A);
assert_eq!(WeirdDiscriminants::try_read_from(&disc[..]), Some(WeirdDiscriminants::A));
let disc: [u8; SIZE] = zerocopy::transmute!(WeirdDiscriminants::B);
assert_eq!(WeirdDiscriminants::try_read_from(&disc[..]), Some(WeirdDiscriminants::B));
let disc: [u8; SIZE] = zerocopy::transmute!(WeirdDiscriminants::C);
assert_eq!(WeirdDiscriminants::try_read_from(&disc[..]), Some(WeirdDiscriminants::C));
assert_eq!(WeirdDiscriminants::try_read_from(&[0xFF; SIZE][..]), None);
}
Loading