Skip to content

Commit

Permalink
[derive] TryFromBytes on repr(C) enums
Browse files Browse the repository at this point in the history
TODO:
- More tests?
  • Loading branch information
joshlf committed Jan 23, 2024
1 parent c50c835 commit 6bf5c90
Show file tree
Hide file tree
Showing 6 changed files with 278 additions and 376 deletions.
129 changes: 36 additions & 93 deletions zerocopy-derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ use {
quote::quote,
syn::{
parse_quote, Data, DataEnum, DataStruct, DataUnion, DeriveInput, Error, Expr, ExprLit,
GenericParam, Ident, Index, Lit,
GenericParam, Ident, Lit,
},
};

Expand Down Expand Up @@ -405,29 +405,13 @@ fn derive_try_from_bytes_enum(ast: &DeriveInput, enm: &DataEnum) -> proc_macro2:
.to_compile_error();
}

let reprs = try_or_print!(ENUM_TRY_FROM_BYTES_CFG.validate_reprs(ast));
let discriminant_type = match reprs.as_slice() {
[EnumRepr::U8] => quote!(u8),
[EnumRepr::U16] => quote!(u16),
[EnumRepr::U32] => quote!(u32),
[EnumRepr::U64] => quote!(u64),
[EnumRepr::Usize] => quote!(usize),
[EnumRepr::I8] => quote!(i8),
[EnumRepr::I16] => quote!(i16),
[EnumRepr::I32] => quote!(i32),
[EnumRepr::I64] => quote!(i64),
[EnumRepr::Isize] => quote!(isize),
// `validate_reprs` has already validated that it's one of the preceding
// patterns.
_ => unreachable!(),
};

let discriminant_exprs = enm.variants.iter().scan(Discriminant::default(), |disc, var| {
Some(disc.update_and_generate_expr(&var.discriminant))
});
// We don't actually care what the repr is; we just care that it's one of
// the allowed ones.
try_or_print!(ENUM_TRY_FROM_BYTES_CFG.validate_reprs(ast));
let variant_names = enm.variants.iter().map(|v| &v.ident);
let extras = Some(quote!(
// SAFETY: We use `is_bit_valid` to validate that the bit pattern
// corresponds to one of the C-like enum's variant discriminants.
// corresponds to one of the field-less enum's variant discriminants.
// Thus, this is a sound implementation of `is_bit_valid`.
fn is_bit_valid(
candidate: ::zerocopy::Ptr<
Expand All @@ -439,90 +423,51 @@ fn derive_try_from_bytes_enum(ast: &DeriveInput, enm: &DataEnum) -> proc_macro2:
::zerocopy::pointer::invariant::AsInitialized,
),
>,
) -> bool {
) -> ::zerocopy::macro_util::core_reexport::primitive::bool {
use ::zerocopy::macro_util::core_reexport;
// SAFETY:
// - `cast` is implemented as required.
// - Since we cast to the type specified by `Self`'s repr, `p`'s
// referent and the referent of the returned pointer have the
// same size.
let discriminant = unsafe { candidate.cast_unsized(|p: *mut Self| p as *mut ::zerocopy::macro_util::core_reexport::primitive::#discriminant_type) };
// SAFETY: Since `candidate` has the invariant `AsInitialized`,
// we know that `candidate`'s referent (and thus
// `discriminant`'s referent) is as-initialized as `Self`. Since
// `Self`'s repr is the same type as `discriminant`, we know
// that `discriminant`'s referent satisfies the as-initialized
// property.
// - By definition, `*mut Self` and `*mut [u8; size_of::<Self>()]`
// are types of the same size.
let discriminant = unsafe { candidate.cast_unsized(|p: *mut Self| p as *mut [core_reexport::primitive::u8; core_reexport::mem::size_of::<Self>()]) };
// SAFETY: Since `candidate` has the invariant `AsInitialized`, we
// know that `candidate`'s referent (and thus `discriminant`'s
// referent) is as-initialized as `Self`. Since all of the allowed
// `repr`s are types for which all bytes are always initialized, we
// know that `discriminant`'s referent has all of its bytes
// initialized. Since `[u8; N]`'s validity invariant is just that
// all of its bytes are initialized, we know that `discriminant`'s
// referent is bit-valid.
let discriminant = unsafe { discriminant.assume_valid() };
let discriminant = discriminant.read_unaligned();

false #(|| (discriminant == (#discriminant_exprs)))*
false #(|| {
let v = Self::#variant_names{};
// SAFETY: All of the allowed `repr`s for `Self` guarantee that
// `Self`'s discriminant bytes are all initialized. Since we
// validate that `Self` has no fields, it has no bytes other
// than the discriminant. Thus, it is sound to transmute any
// instance of `Self` to `[u8; size_of::<Self>()]`.
let d: [core_reexport::primitive::u8; core_reexport::mem::size_of::<Self>()] = unsafe { core_reexport::mem::transmute(v) };
// SAFETY: Here we check that the bits of the argument
// `candidate` are equal to the bits of a `Self` constructed
// using safe code. If this condition passes, then we know that
// `candidate` refers to a bit-valid `Self`.
discriminant == d
})*
}
));
impl_block(ast, enm, Trait::TryFromBytes, RequireBoundedFields::Yes, false, None, extras)
}

// Enum variant discriminants can be manually set not only as literal values,
// but as arbitrary const expressions. In order to handle this, we keep track of
// the most-recently-seen expression and a count of how many variants have been
// encountered since then.
//
// #[repr(u8)]
// enum Foo {
// A, // 0
// B = 5, // 5
// C, // 6
// D = 1 + 1, // 2
// E, // 3
// }
//
// Note: Default::default does the right thing (initializes to { None, 0 }).
#[derive(Default, Copy, Clone)]
struct Discriminant<'a> {
// The most-recently-set explicit discriminant.
previous: Option<&'a Expr>,
// When the next variant is encountered, what offset should be used compared
// to `previous` to determine the variant's discriminant?
next_offset: usize,
}

impl<'a> Discriminant<'a> {
/// Called when encountering a variant with discriminant set to `ast`.
/// Updates `self` in preparation for the next variant and generates an
/// expression which will evaluate to the numeric value this variant's
/// discriminant.
fn update_and_generate_expr(
&mut self,
ast: &'a Option<(syn::token::Eq, Expr)>,
) -> proc_macro2::TokenStream {
match ast.as_ref().map(|(_eq, expr)| expr) {
Some(expr) => {
self.previous = Some(expr);
self.next_offset = 1;
quote!(#expr)
}
None => {
let previous = self.previous.iter();
// Use `Index` instead of `usize` so that the number is
// formatted just as `0` rather than as `0usize`; the latter
// syntax is only valid if the repr is `usize`; otherwise,
// comparison will result in a type mismatch.
let offset = Index::from(self.next_offset);
let tokens = quote!(#(#previous +)* #offset);

self.next_offset += 1;
tokens
}
}
}
}

#[rustfmt::skip]
const ENUM_TRY_FROM_BYTES_CFG: Config<EnumRepr> = {
use EnumRepr::*;
Config {
allowed_combinations_message: r#"TryFromBytes requires repr of "u8", "u16", "u32", "u64", "usize", "i8", or "i16", "i32", "i64", or "isize""#,
allowed_combinations_message: r#"TryFromBytes requires repr of "C", "u8", "u16", "u32", "u64", "usize", "i8", or "i16", "i32", "i64", or "isize""#,
derive_unaligned: false,
allowed_combinations: &[
&[C],
&[U8],
&[U16],
&[U32],
Expand All @@ -534,9 +479,7 @@ const ENUM_TRY_FROM_BYTES_CFG: Config<EnumRepr> = {
&[I64],
&[Isize],
],
disallowed_but_legal_combinations: &[
&[C],
],
disallowed_but_legal_combinations: &[],
}
};

Expand Down
73 changes: 37 additions & 36 deletions zerocopy-derive/tests/enum_try_from_bytes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

use std::convert::TryFrom;

use syn::Field;

mod util;

use {
Expand All @@ -27,10 +29,10 @@ assert_impl_all!(Foo: TryFromBytes);

#[test]
fn test_foo() {
assert_eq!(Foo::try_from_ref(&[0]), Some(&Foo::A));
assert_eq!(Foo::try_from_ref(&[]), None);
assert_eq!(Foo::try_from_ref(&[1]), None);
assert_eq!(Foo::try_from_ref(&[0, 0]), None);
assert_eq!(Foo::try_read_from(&[0]), Some(Foo::A));
assert_eq!(Foo::try_read_from(&[]), None);
assert_eq!(Foo::try_read_from(&[1]), None);
assert_eq!(Foo::try_read_from(&[0, 0]), None);
}

#[derive(Eq, PartialEq, Debug, KnownLayout, TryFromBytes)]
Expand All @@ -43,11 +45,11 @@ assert_impl_all!(Bar: TryFromBytes);

#[test]
fn test_bar() {
assert_eq!(Bar::try_from_ref(&[0, 0]), Some(&Bar::A));
assert_eq!(Bar::try_from_ref(&[]), None);
assert_eq!(Bar::try_from_ref(&[0]), None);
assert_eq!(Bar::try_from_ref(&[0, 1]), None);
assert_eq!(Bar::try_from_ref(&[0, 0, 0]), None);
assert_eq!(Bar::try_read_from(&[0, 0]), Some(Bar::A));
assert_eq!(Bar::try_read_from(&[]), None);
assert_eq!(Bar::try_read_from(&[0]), None);
assert_eq!(Bar::try_read_from(&[0, 1]), None);
assert_eq!(Bar::try_read_from(&[0, 0, 0]), None);
}

#[derive(Eq, PartialEq, Debug, KnownLayout, TryFromBytes)]
Expand All @@ -61,13 +63,13 @@ assert_impl_all!(Baz: TryFromBytes);

#[test]
fn test_baz() {
assert_eq!(Baz::try_from_ref(1u32.as_bytes()), Some(&Baz::A));
assert_eq!(Baz::try_from_ref(0u32.as_bytes()), Some(&Baz::B));
assert_eq!(Baz::try_from_ref(&[]), None);
assert_eq!(Baz::try_from_ref(&[0]), None);
assert_eq!(Baz::try_from_ref(&[0, 0]), None);
assert_eq!(Baz::try_from_ref(&[0, 0, 0]), None);
assert_eq!(Baz::try_from_ref(&[0, 0, 0, 0, 0]), None);
assert_eq!(Baz::try_read_from(1u32.as_bytes()), Some(Baz::A));
assert_eq!(Baz::try_read_from(0u32.as_bytes()), Some(Baz::B));
assert_eq!(Baz::try_read_from(&[]), None);
assert_eq!(Baz::try_read_from(&[0]), None);
assert_eq!(Baz::try_read_from(&[0, 0]), None);
assert_eq!(Baz::try_read_from(&[0, 0, 0]), None);
assert_eq!(Baz::try_read_from(&[0, 0, 0, 0, 0]), None);
}

// Test hygiene - make sure that `i8` being shadowed doesn't cause problems for
Expand All @@ -89,17 +91,17 @@ assert_impl_all!(Blah: TryFromBytes);

#[test]
fn test_blah() {
assert_eq!(Blah::try_from_ref(1i8.as_bytes()), Some(&Blah::A));
assert_eq!(Blah::try_from_ref(0i8.as_bytes()), Some(&Blah::B));
assert_eq!(Blah::try_from_ref(3i8.as_bytes()), Some(&Blah::C));
assert_eq!(Blah::try_from_ref(6i8.as_bytes()), Some(&Blah::D));
assert_eq!(Blah::try_from_ref(&[]), None);
assert_eq!(Blah::try_from_ref(&[4]), None);
assert_eq!(Blah::try_from_ref(&[0, 0]), None);
assert_eq!(Blah::try_read_from(1i8.as_bytes()), Some(Blah::A));
assert_eq!(Blah::try_read_from(0i8.as_bytes()), Some(Blah::B));
assert_eq!(Blah::try_read_from(3i8.as_bytes()), Some(Blah::C));
assert_eq!(Blah::try_read_from(6i8.as_bytes()), Some(Blah::D));
assert_eq!(Blah::try_read_from(&[]), None);
assert_eq!(Blah::try_read_from(&[4]), None);
assert_eq!(Blah::try_read_from(&[0, 0]), None);
}

#[derive(Eq, PartialEq, Debug, KnownLayout, TryFromBytes)]
#[repr(u8)]
#[derive(Eq, PartialEq, Debug, KnownLayout, TryFromBytes, IntoBytes)]
#[repr(C)]
enum FieldlessButNotUnitOnly {
A,
B(),
Expand All @@ -108,19 +110,18 @@ enum FieldlessButNotUnitOnly {

#[test]
fn test_fieldless_but_not_unit_only() {
const SIZE: usize = core::mem::size_of::<FieldlessButNotUnitOnly>();
let disc: [u8; SIZE] = zerocopy::transmute!(FieldlessButNotUnitOnly::A);
assert_eq!(FieldlessButNotUnitOnly::try_read_from(&disc[..]), Some(FieldlessButNotUnitOnly::A));
let disc: [u8; SIZE] = zerocopy::transmute!(FieldlessButNotUnitOnly::B());
assert_eq!(
FieldlessButNotUnitOnly::try_from_ref(0u8.as_bytes()),
Some(&FieldlessButNotUnitOnly::A)
);
assert_eq!(
FieldlessButNotUnitOnly::try_from_ref(1u8.as_bytes()),
Some(&FieldlessButNotUnitOnly::B())
FieldlessButNotUnitOnly::try_read_from(&disc[..]),
Some(FieldlessButNotUnitOnly::B())
);
let disc: [u8; SIZE] = zerocopy::transmute!(FieldlessButNotUnitOnly::C {});
assert_eq!(
FieldlessButNotUnitOnly::try_from_ref(2u8.as_bytes()),
Some(&FieldlessButNotUnitOnly::C {})
FieldlessButNotUnitOnly::try_read_from(&disc[..]),
Some(FieldlessButNotUnitOnly::C {})
);
assert_eq!(FieldlessButNotUnitOnly::try_from_ref(&[]), None);
assert_eq!(FieldlessButNotUnitOnly::try_from_ref(&[3]), None);
assert_eq!(FieldlessButNotUnitOnly::try_from_ref(&[0, 0]), None);
assert_eq!(FieldlessButNotUnitOnly::try_read_from(&[0xFF; SIZE][..]), None);
}
Loading

0 comments on commit 6bf5c90

Please sign in to comment.