Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[derive] Derive TryFromBytes on unions #800

Merged
merged 1 commit into from
Jan 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 39 additions & 29 deletions src/pointer/ptr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -285,37 +285,37 @@ pub mod invariant {
/// The referent is not necessarily initialized.
AnyValidity,

/// The byte ranges initialized in `T` are also initialized in the
/// referent.
/// The byte ranges initialized in `T` are also initialized in
/// the referent.
///
/// Formally: uninitialized bytes may only be present in `Ptr<T>`'s
/// referent where it is possible for them to be present in `T`.
/// This is a dynamic property: if, at a particular byte offset, a
/// valid enum discriminant is set, the subsequent bytes may only
/// have uninitialized bytes as specificed by the corresponding
/// enum.
/// Formally: uninitialized bytes may only be present in
/// `Ptr<T>`'s referent where they are guaranteed to be present
/// in `T`. This is a dynamic property: if, at a particular byte
/// offset, a valid enum discriminant is set, the subsequent
/// bytes may only have uninitialized bytes as specificed by the
/// corresponding enum.
///
/// Formally, given `len = size_of_val_raw(ptr)`, at every byte
/// offset, `b`, in the range `[0, len)`:
/// - If, in all instances `t: T` of length `len`, the byte at
/// offset `b` in `t` is initialized, then the byte at offset `b`
/// within `*ptr` must be initialized.
/// - Let `c` be the contents of the byte range `[0, b)` in `*ptr`.
/// Let `S` be the subset of valid instances of `T` of length
/// `len` which contain `c` in the offset range `[0, b)`. If, for
/// all instances of `t: T` in `S`, the byte at offset `b` in `t`
/// is initialized, then the byte at offset `b` in `*ptr` must be
/// initialized.
/// - If, in any instance `t: T` of length `len`, the byte at
/// offset `b` in `t` is initialized, then the byte at offset
/// `b` within `*ptr` must be initialized.
/// - Let `c` be the contents of the byte range `[0, b)` in
/// `*ptr`. Let `S` be the subset of valid instances of `T` of
/// length `len` which contain `c` in the offset range `[0,
/// b)`. If, in any instance of `t: T` in `S`, the byte at
/// offset `b` in `t` is initialized, then the byte at offset
/// `b` in `*ptr` must be initialized.
///
/// Pragmatically, this means that if `*ptr` is guaranteed to
/// contain an enum type at a particular offset, and the enum
/// discriminant stored in `*ptr` corresponds to a valid variant
/// of that enum type, then it is guaranteed that the appropriate
/// bytes of `*ptr` are initialized as defined by that variant's
/// bit validity (although note that the variant may contain
/// another enum type, in which case the same rules apply
/// depending on the state of its discriminant, and so on
/// recursively).
/// discriminant stored in `*ptr` corresponds to a valid
/// variant of that enum type, then it is guaranteed that the
/// appropriate bytes of `*ptr` are initialized as defined by
/// that variant's bit validity (although note that the
/// variant may contain another enum type, in which case the
/// same rules apply depending on the state of its
/// discriminant, and so on recursively).
AsInitialized,

/// The referent is bit-valid for `T`.
Expand Down Expand Up @@ -785,6 +785,7 @@ mod _project {
where
T: 'a + ?Sized,
I: Invariants,
I::Validity: invariant::at_least::AsInitialized,
{
/// Projects a field from `self`.
///
Expand All @@ -796,6 +797,8 @@ mod _project {
/// argument. Its argument will be `self` casted to a raw pointer. The
/// pointer it returns must reference only a subset of `self`'s bytes.
///
/// The caller also promises that `T` is a struct or union type.
///
/// ## Postconditions
///
/// If the preconditions of this function are met, this function will
Expand All @@ -805,7 +808,7 @@ mod _project {
pub unsafe fn project<U: 'a + ?Sized>(
self,
projector: impl FnOnce(*mut T) -> *mut U,
) -> Ptr<'a, U, (I::Aliasing, invariant::AnyAlignment, I::Validity)> {
) -> Ptr<'a, U, (I::Aliasing, invariant::AnyAlignment, invariant::AsInitialized)> {
// SAFETY: `projector` is provided with `self` casted to a raw
// pointer.
let field = projector(self.as_non_null().as_ptr());
Expand Down Expand Up @@ -849,10 +852,17 @@ mod _project {
// `ALIASING_INVARIANT` because projection does not impact the
// aliasing invariant.
// 7. `field`, trivially, conforms to the alignment invariant of
// `AnyAlignment`.
// 8. `field`, conditionally, conforms to the validity invariant of
// `VALIDITY_INVARIANT`. If `field` is projected from data valid
// for `T`, `field` will be valid for `U`.
// `AnyAlignment`.
// 8. By type bound on `I::Validity`, `self` satisfies the
// "as-initialized" property relative to `T`. The returned `Ptr`
// has the validity `AsInitialized`. The caller promises that `T`
// is either a struct type or a union type. Returning a `Ptr`
// with the validity `AsInitialized` is valid in both cases. The
// struct case is self-explanatory, but the union case bears
// explanation. The "as-initialized" property says that a byte
// must be initialized if it is initialized in *any* instance of
// the type. Thus, if `self`'s referent is as-initialized as `T`,
// then it is at least as-initialized as each of its fields.
unsafe { Ptr::new(field) }
}
}
Expand Down
62 changes: 36 additions & 26 deletions zerocopy-derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -271,9 +271,7 @@ pub fn derive_try_from_bytes(ts: proc_macro::TokenStream) -> proc_macro::TokenSt
Data::Enum(_) => {
Error::new_spanned(&ast, "TryFromBytes not supported on enum types").to_compile_error()
}
Data::Union(_) => {
Error::new_spanned(&ast, "TryFromBytes not supported on union types").to_compile_error()
}
Data::Union(unn) => derive_try_from_bytes_union(&ast, unn),
}
.into()
}
Expand Down Expand Up @@ -346,36 +344,15 @@ fn derive_try_from_bytes_struct(ast: &DeriveInput, strct: &DataStruct) -> proc_m
// `is_bit_valid`.
fn is_bit_valid(candidate: zerocopy::Maybe<Self>) -> bool {
true #(&& {
// SAFETY: `project` is a field projection of `candidate`.
// The projected field will be well-aligned because this
// derive rejects packed types.
// SAFETY: `project` is a field projection of `candidate`,
// and `Self` is a struct type.
let field_candidate = unsafe {
let project = |slf: *mut Self|
::core::ptr::addr_of_mut!((*slf).#field_names);

candidate.project(project)
};

// SAFETY: The below invocation of `is_bit_valid` satisfies
// the safety preconditions of `is_bit_valid`:
// - The memory referenced by `field_candidate` is only
// accessed via reads for the duration of this method
// call. This is ensured by contract on the caller of the
// surrounding `is_bit_valid`.
// - `field_candidate` may not refer to a valid instance of
// its corresponding field type, but it will only have
// `UnsafeCell`s at the offsets at which they may occur in
// that field type. This is ensured both by contract on
// the caller of the surrounding `is_bit_valid`, and by
// the construction of `field_candidiate`, i.e., via
// projection through `candidate`.
//
// Note that it's possible that this call will panic -
// `is_bit_valid` does not promise that it doesn't panic,
// and in practice, we support user-defined validators,
// which could panic. This is sound because we haven't
// violated any safety invariants which we would need to fix
// before returning.
<#field_tys as zerocopy::TryFromBytes>::is_bit_valid(field_candidate)
})*
}
Expand All @@ -384,6 +361,39 @@ fn derive_try_from_bytes_struct(ast: &DeriveInput, strct: &DataStruct) -> proc_m
impl_block(ast, strct, Trait::TryFromBytes, RequireBoundedFields::Yes, false, None, extras)
}

// A union is `TryFromBytes` if:
// - any of its fields are `TryFromBytes`

fn derive_try_from_bytes_union(ast: &DeriveInput, unn: &DataUnion) -> proc_macro2::TokenStream {
let extras = Some({
let fields = unn.fields();
let field_names = fields.iter().map(|(name, _ty)| name);
let field_tys = fields.iter().map(|(_name, ty)| ty);
quote!(
// SAFETY: We use `is_bit_valid` to validate that any field is
// bit-valid; we only return `true` if at least one of them is. The
// bit validity of a union is not yet well defined in Rust, but it
// is guaranteed to be no more strict than this definition. See #696
// for a more in-depth discussion.
fn is_bit_valid(candidate: zerocopy::Maybe<Self>) -> bool {
false #(|| {
// SAFETY: `project` is a field projection of `candidate`,
// and `Self` is a union type.
let field_candidate = unsafe {
let project = |slf: *mut Self|
::core::ptr::addr_of_mut!((*slf).#field_names);

candidate.project(project)
};

<#field_tys as zerocopy::TryFromBytes>::is_bit_valid(field_candidate)
})*
}
)
});
impl_block(ast, unn, Trait::TryFromBytes, RequireBoundedFields::Yes, false, None, extras)
}

const STRUCT_UNION_ALLOWED_REPR_COMBINATIONS: &[&[StructRepr]] = &[
&[StructRepr::C],
&[StructRepr::Transparent],
Expand Down
18 changes: 8 additions & 10 deletions zerocopy-derive/tests/struct_try_from_bytes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,17 @@ use crate::util::AU16;
// A struct is `TryFromBytes` if:
// - all fields are `TryFromBytes`

#[derive(TryFromBytes, FromZeros, FromBytes)]
struct Zst;

assert_impl_all!(Zst: TryFromBytes);

#[test]
fn zst() {
// TODO(#5): Use `try_transmute` in this test once it's available.
let candidate = zerocopy::Ptr::from_ref(&Zst);
let candidate = zerocopy::Ptr::from_ref(&());
let candidate = candidate.forget_aligned().forget_valid();
let is_bit_valid = Zst::is_bit_valid(candidate);
let is_bit_valid = <()>::is_bit_valid(candidate);
assert!(is_bit_valid);
}

#[derive(TryFromBytes, FromZeros, FromBytes)]
#[repr(C)]
struct One {
a: u8,
}
Expand All @@ -53,17 +49,18 @@ fn one() {
}

#[derive(TryFromBytes, FromZeros)]
#[repr(C)]
struct Two {
a: bool,
b: Zst,
b: (),
}

assert_impl_all!(Two: TryFromBytes);

#[test]
fn two() {
// TODO(#5): Use `try_transmute` in this test once it's available.
let candidate = zerocopy::Ptr::from_ref(&Two { a: false, b: Zst });
let candidate = zerocopy::Ptr::from_ref(&Two { a: false, b: () });
let candidate = candidate.forget_aligned().forget_valid();
let is_bit_valid = Two::is_bit_valid(candidate);
assert!(is_bit_valid);
Expand All @@ -80,7 +77,6 @@ fn two_bad() {
// *mut U`.
// - The size of the object referenced by the resulting pointer is equal to
// the size of the object referenced by `self`.
// - The alignment of `Unsized` is equal to the alignment of `[u8]`.
let candidate = unsafe { candidate.cast_unsized(|p| p as *mut Two) };

// SAFETY: `candidate`'s referent is as-initialized as `Two`.
Expand All @@ -91,6 +87,7 @@ fn two_bad() {
}

#[derive(TryFromBytes, FromZeros, FromBytes)]
#[repr(C)]
struct Unsized {
a: [u8],
}
Expand Down Expand Up @@ -118,6 +115,7 @@ fn un_sized() {
}

#[derive(TryFromBytes, FromZeros, FromBytes)]
#[repr(C)]
struct TypeParams<'a, T: ?Sized, I: Iterator> {
a: I::Item,
b: u8,
Expand Down
Loading