Skip to content

Commit

Permalink
derive(SmartPointer): rewrite bounds in where and generic bounds
Browse files Browse the repository at this point in the history
  • Loading branch information
dingxiangfei2009 committed Jul 29, 2024
1 parent a5ee5cb commit 00413c5
Show file tree
Hide file tree
Showing 4 changed files with 335 additions and 11 deletions.
208 changes: 197 additions & 11 deletions compiler/rustc_builtin_macros/src/deriving/smart_ptr.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,30 @@
use std::mem::swap;

use ast::HasAttrs;
use rustc_ast::mut_visit::MutVisitor;
use rustc_ast::visit::BoundKind;
use rustc_ast::{
self as ast, GenericArg, GenericBound, GenericParamKind, ItemKind, MetaItem,
TraitBoundModifiers, VariantData,
};
use rustc_attr as attr;
use rustc_data_structures::flat_map_in_place::FlatMapInPlace;
use rustc_expand::base::{Annotatable, ExtCtxt};
use rustc_span::symbol::{sym, Ident};
use rustc_span::Span;
use rustc_span::{Span, Symbol};
use smallvec::{smallvec, SmallVec};
use thin_vec::{thin_vec, ThinVec};

type AstTy = ast::ptr::P<ast::Ty>;

macro_rules! path {
($span:expr, $($part:ident)::*) => { vec![$(Ident::new(sym::$part, $span),)*] }
}

macro_rules! symbols {
($($part:ident)::*) => { [$(sym::$part),*] }
}

pub fn expand_deriving_smart_ptr(
cx: &ExtCtxt<'_>,
span: Span,
Expand Down Expand Up @@ -143,31 +152,208 @@ pub fn expand_deriving_smart_ptr(

// Find the `#[pointee]` parameter and add an `Unsize<__S>` bound to it.
let mut impl_generics = generics.clone();
let pointee_ty_ident = generics.params[pointee_param_idx].ident;
let mut self_bounds;
{
let p = &mut impl_generics.params[pointee_param_idx];
self_bounds = p.bounds.clone();
let arg = GenericArg::Type(s_ty.clone());
let unsize = cx.path_all(span, true, path!(span, core::marker::Unsize), vec![arg]);
p.bounds.push(cx.trait_bound(unsize, false));
let mut attrs = thin_vec![];
swap(&mut p.attrs, &mut attrs);
p.attrs = attrs.into_iter().filter(|attr| !attr.has_name(sym::pointee)).collect();
}
// We should not set default values to constant generic parameters
// and write out bounds that indirectly involves `#[pointee]`.
for (params, orig_params) in impl_generics.params[pointee_param_idx + 1..]
.iter_mut()
.zip(&generics.params[pointee_param_idx + 1..])
{
if let ast::GenericParamKind::Const { default, .. } = &mut params.kind {
*default = None;
}
for bound in &orig_params.bounds {
let mut bound = bound.clone();
let mut substitution = TypeSubstitution {
from_name: pointee_ty_ident.name,
to_ty: &s_ty,
rewritten: false,
};
substitution.visit_param_bound(&mut bound, BoundKind::Bound);
if substitution.rewritten {
params.bounds.push(bound);
}
}
}

// Add the `__S: ?Sized` extra parameter to the impl block.
// We should also write the bounds from `#[pointee]` to `__S` as required by `Unsize<__S>`.
let sized = cx.path_global(span, path!(span, core::marker::Sized));
let bound = GenericBound::Trait(
cx.poly_trait_ref(span, sized),
TraitBoundModifiers {
polarity: ast::BoundPolarity::Maybe(span),
constness: ast::BoundConstness::Never,
asyncness: ast::BoundAsyncness::Normal,
},
);
let extra_param = cx.typaram(span, Ident::new(sym::__S, span), vec![bound], None);
impl_generics.params.push(extra_param);
if self_bounds.iter().all(|bound| {
if let GenericBound::Trait(
trait_ref,
TraitBoundModifiers { polarity: ast::BoundPolarity::Maybe(_), .. },
) = bound
{
!is_sized_marker(&trait_ref.trait_ref.path)
} else {
false
}
}) {
self_bounds.push(GenericBound::Trait(
cx.poly_trait_ref(span, sized),
TraitBoundModifiers {
polarity: ast::BoundPolarity::Maybe(span),
constness: ast::BoundConstness::Never,
asyncness: ast::BoundAsyncness::Normal,
},
));
}
{
let mut substitution =
TypeSubstitution { from_name: pointee_ty_ident.name, to_ty: &s_ty, rewritten: false };
for bound in &mut self_bounds {
substitution.visit_param_bound(bound, BoundKind::Bound);
}
}

// We should also commute the where bounds from `#[pointee]` to `__S`
// as well as any bound that indirectly involves the `#[pointee]` type.
for bound in &generics.where_clause.predicates {
if let ast::WherePredicate::BoundPredicate(bound) = bound {
let bound_on_pointee = bound
.bounded_ty
.kind
.is_simple_path()
.map_or(false, |name| name == pointee_ty_ident.name);

let bounds: Vec<_> = bound
.bounds
.iter()
.filter(|bound| {
if let GenericBound::Trait(
trait_ref,
TraitBoundModifiers { polarity: ast::BoundPolarity::Maybe(_), .. },
) = bound
{
!bound_on_pointee || !is_sized_marker(&trait_ref.trait_ref.path)
} else {
true
}
})
.cloned()
.collect();
let mut substitution = TypeSubstitution {
from_name: pointee_ty_ident.name,
to_ty: &s_ty,
rewritten: bounds.len() != bound.bounds.len(),
};
let mut predicate = ast::WherePredicate::BoundPredicate(ast::WhereBoundPredicate {
span: bound.span,
bound_generic_params: bound.bound_generic_params.clone(),
bounded_ty: bound.bounded_ty.clone(),
bounds,
});
substitution.visit_where_predicate(&mut predicate);
if substitution.rewritten {
impl_generics.where_clause.predicates.push(predicate);
}
}
}

let extra_param = cx.typaram(span, Ident::new(sym::__S, span), self_bounds, None);
impl_generics.params.insert(pointee_param_idx + 1, extra_param);

// Add the impl blocks for `DispatchFromDyn` and `CoerceUnsized`.
let gen_args = vec![GenericArg::Type(alt_self_type.clone())];
add_impl_block(impl_generics.clone(), sym::DispatchFromDyn, gen_args.clone());
add_impl_block(impl_generics.clone(), sym::CoerceUnsized, gen_args.clone());
}

fn is_sized_marker(path: &ast::Path) -> bool {
const CORE_UNSIZE: [Symbol; 3] = symbols!(core::marker::Sized);
const STD_UNSIZE: [Symbol; 3] = symbols!(std::marker::Sized);
if path.segments.len() == 3 {
path.segments.iter().zip(CORE_UNSIZE).all(|(segment, symbol)| segment.ident.name == symbol)
|| path
.segments
.iter()
.zip(STD_UNSIZE)
.all(|(segment, symbol)| segment.ident.name == symbol)
} else {
*path == sym::Sized
}
}

struct TypeSubstitution<'a> {
from_name: Symbol,
to_ty: &'a AstTy,
rewritten: bool,
}

impl<'a> ast::mut_visit::MutVisitor for TypeSubstitution<'a> {
fn visit_ty(&mut self, ty: &mut AstTy) {
if let Some(name) = ty.kind.is_simple_path()
&& name == self.from_name
{
*ty = self.to_ty.clone();
self.rewritten = true;
return;
}
match &mut ty.kind {
ast::TyKind::Slice(_)
| ast::TyKind::Array(_, _)
| ast::TyKind::Ptr(_)
| ast::TyKind::Ref(_, _)
| ast::TyKind::BareFn(_)
| ast::TyKind::Never
| ast::TyKind::Tup(_)
| ast::TyKind::AnonStruct(_, _)
| ast::TyKind::AnonUnion(_, _)
| ast::TyKind::Path(_, _)
| ast::TyKind::TraitObject(_, _)
| ast::TyKind::ImplTrait(_, _)
| ast::TyKind::Paren(_)
| ast::TyKind::Typeof(_)
| ast::TyKind::Infer
| ast::TyKind::MacCall(_)
| ast::TyKind::Pat(_, _) => ast::mut_visit::walk_ty(self, ty),
ast::TyKind::ImplicitSelf
| ast::TyKind::CVarArgs
| ast::TyKind::Dummy
| ast::TyKind::Err(_) => {}
}
}

fn visit_param_bound(&mut self, bound: &mut GenericBound, _ctxt: BoundKind) {
match bound {
GenericBound::Trait(trait_ref, _) => {
self.visit_poly_trait_ref(trait_ref);
}

GenericBound::Use(args, _span) => {
for arg in args {
self.visit_precise_capturing_arg(arg);
}
}
GenericBound::Outlives(_) => {}
}
}

fn visit_where_predicate(&mut self, where_predicate: &mut ast::WherePredicate) {
match where_predicate {
rustc_ast::WherePredicate::BoundPredicate(bound) => {
bound
.bound_generic_params
.flat_map_in_place(|param| self.flat_map_generic_param(param));
self.visit_ty(&mut bound.bounded_ty);
for bound in &mut bound.bounds {
self.visit_param_bound(bound, BoundKind::Bound)
}
}
rustc_ast::WherePredicate::RegionPredicate(_)
| rustc_ast::WherePredicate::EqPredicate(_) => {}
}
}
}
19 changes: 19 additions & 0 deletions tests/ui/deriving/deriving-smart-pointer-expanded.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
//@ check-pass
//@ compile-flags: -Zunpretty=expanded
#![feature(derive_smart_pointer)]
use std::marker::SmartPointer;

pub trait MyTrait<T: ?Sized> {}

#[derive(SmartPointer)]
#[repr(transparent)]
struct MyPointer<'a, #[pointee] T: ?Sized> {
ptr: &'a T,
}

#[derive(core::marker::SmartPointer)]
#[repr(transparent)]
pub struct MyPointer2<'a, #[pointee] T: ?Sized, X: MyTrait<T>> {
data: &'a mut T,
x: core::marker::PhantomData<X>,
}
41 changes: 41 additions & 0 deletions tests/ui/deriving/deriving-smart-pointer-expanded.stdout
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
#![feature(prelude_import)]
#![no_std]
//@ check-pass
//@ compile-flags: -Zunpretty=expanded
#![feature(derive_smart_pointer)]
#[prelude_import]
use ::std::prelude::rust_2015::*;
#[macro_use]
extern crate std;
use std::marker::SmartPointer;

pub trait MyTrait<T: ?Sized> {}

#[repr(transparent)]
struct MyPointer<'a, #[pointee] T: ?Sized> {
ptr: &'a T,
}
#[automatically_derived]
impl<'a, T: ?Sized + ::core::marker::Unsize<__S>, __S: ?Sized>
::core::ops::DispatchFromDyn<MyPointer<'a, __S>> for MyPointer<'a, T> {
}
#[automatically_derived]
impl<'a, T: ?Sized + ::core::marker::Unsize<__S>, __S: ?Sized>
::core::ops::CoerceUnsized<MyPointer<'a, __S>> for MyPointer<'a, T> {
}

#[repr(transparent)]
pub struct MyPointer2<'a, #[pointee] T: ?Sized, X: MyTrait<T>> {
data: &'a mut T,
x: core::marker::PhantomData<X>,
}
#[automatically_derived]
impl<'a, T: ?Sized + ::core::marker::Unsize<__S>, __S: ?Sized, X: MyTrait<T> +
MyTrait<__S>> ::core::ops::DispatchFromDyn<MyPointer2<'a, __S, X>> for
MyPointer2<'a, T, X> {
}
#[automatically_derived]
impl<'a, T: ?Sized + ::core::marker::Unsize<__S>, __S: ?Sized, X: MyTrait<T> +
MyTrait<__S>> ::core::ops::CoerceUnsized<MyPointer2<'a, __S, X>> for
MyPointer2<'a, T, X> {
}
78 changes: 78 additions & 0 deletions tests/ui/deriving/smart-pointer-bounds-issue-127647.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
//@ check-pass

#![feature(derive_smart_pointer)]

#[derive(core::marker::SmartPointer)]
#[repr(transparent)]
pub struct Ptr<'a, #[pointee] T: OnDrop + ?Sized, X> {
data: &'a mut T,
x: core::marker::PhantomData<X>,
}

pub trait OnDrop {
fn on_drop(&mut self);
}

#[derive(core::marker::SmartPointer)]
#[repr(transparent)]
pub struct Ptr2<'a, #[pointee] T: ?Sized, X>
where
T: OnDrop,
{
data: &'a mut T,
x: core::marker::PhantomData<X>,
}

pub trait MyTrait<T: ?Sized> {}

#[derive(core::marker::SmartPointer)]
#[repr(transparent)]
pub struct Ptr3<'a, #[pointee] T: ?Sized, X>
where
T: MyTrait<T>,
{
data: &'a mut T,
x: core::marker::PhantomData<X>,
}

#[derive(core::marker::SmartPointer)]
#[repr(transparent)]
pub struct Ptr4<'a, #[pointee] T: MyTrait<T> + ?Sized, X> {
data: &'a mut T,
x: core::marker::PhantomData<X>,
}

#[derive(core::marker::SmartPointer)]
#[repr(transparent)]
pub struct Ptr5<'a, #[pointee] T: ?Sized, X>
where
Ptr5Companion<T>: MyTrait<T>,
Ptr5Companion2: MyTrait<T>,
{
data: &'a mut T,
x: core::marker::PhantomData<X>,
}

pub struct Ptr5Companion<T: ?Sized>(core::marker::PhantomData<T>);
pub struct Ptr5Companion2;

#[derive(core::marker::SmartPointer)]
#[repr(transparent)]
pub struct Ptr6<'a, #[pointee] T: ?Sized, X: MyTrait<T>> {
data: &'a mut T,
x: core::marker::PhantomData<X>,
}

// a reduced example from https://lore.kernel.org/all/20240402-linked-list-v1-1-b1c59ba7ae3b@google.com/
#[repr(transparent)]
#[derive(core::marker::SmartPointer)]
pub struct ListArc<#[pointee] T, const ID: u64 = 0>
where
T: ListArcSafe<ID> + ?Sized,
{
arc: *const T,
}

pub trait ListArcSafe<const ID: u64> {}

fn main() {}

0 comments on commit 00413c5

Please sign in to comment.