Skip to content

Commit

Permalink
derive(SmartPointer): rewrite bounds in where and generic bounds
Browse files Browse the repository at this point in the history
  • Loading branch information
dingxiangfei2009 committed Jul 29, 2024
1 parent a5ee5cb commit b06c007
Show file tree
Hide file tree
Showing 7 changed files with 573 additions and 11 deletions.
171 changes: 160 additions & 11 deletions compiler/rustc_builtin_macros/src/deriving/smart_ptr.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,30 @@
use std::mem::swap;

use ast::HasAttrs;
use rustc_ast::mut_visit::MutVisitor;
use rustc_ast::visit::BoundKind;
use rustc_ast::{
self as ast, GenericArg, GenericBound, GenericParamKind, ItemKind, MetaItem,
TraitBoundModifiers, VariantData,
};
use rustc_attr as attr;
use rustc_data_structures::flat_map_in_place::FlatMapInPlace;
use rustc_expand::base::{Annotatable, ExtCtxt};
use rustc_span::symbol::{sym, Ident};
use rustc_span::Span;
use rustc_span::{Span, Symbol};
use smallvec::{smallvec, SmallVec};
use thin_vec::{thin_vec, ThinVec};

type AstTy = ast::ptr::P<ast::Ty>;

macro_rules! path {
($span:expr, $($part:ident)::*) => { vec![$(Ident::new(sym::$part, $span),)*] }
}

macro_rules! symbols {
($($part:ident)::*) => { [$(sym::$part),*] }
}

pub fn expand_deriving_smart_ptr(
cx: &ExtCtxt<'_>,
span: Span,
Expand Down Expand Up @@ -143,31 +152,171 @@ pub fn expand_deriving_smart_ptr(

// Find the `#[pointee]` parameter and add an `Unsize<__S>` bound to it.
let mut impl_generics = generics.clone();
let pointee_ty_ident = generics.params[pointee_param_idx].ident;
let mut self_bounds;
{
let p = &mut impl_generics.params[pointee_param_idx];
self_bounds = p.bounds.clone();
let arg = GenericArg::Type(s_ty.clone());
let unsize = cx.path_all(span, true, path!(span, core::marker::Unsize), vec![arg]);
p.bounds.push(cx.trait_bound(unsize, false));
let mut attrs = thin_vec![];
swap(&mut p.attrs, &mut attrs);
p.attrs = attrs.into_iter().filter(|attr| !attr.has_name(sym::pointee)).collect();
}
// We should not set default values to constant generic parameters
// and write out bounds that indirectly involves `#[pointee]`.
for (params, orig_params) in impl_generics.params[pointee_param_idx + 1..]
.iter_mut()
.zip(&generics.params[pointee_param_idx + 1..])
{
if let ast::GenericParamKind::Const { default, .. } = &mut params.kind {
*default = None;
}
for bound in &orig_params.bounds {
let mut bound = bound.clone();
let mut substitution = TypeSubstitution {
from_name: pointee_ty_ident.name,
to_ty: &s_ty,
rewritten: false,
};
substitution.visit_param_bound(&mut bound, BoundKind::Bound);
if substitution.rewritten {
params.bounds.push(bound);
}
}
}

// Add the `__S: ?Sized` extra parameter to the impl block.
// We should also write the bounds from `#[pointee]` to `__S` as required by `Unsize<__S>`.
let sized = cx.path_global(span, path!(span, core::marker::Sized));
let bound = GenericBound::Trait(
cx.poly_trait_ref(span, sized),
TraitBoundModifiers {
polarity: ast::BoundPolarity::Maybe(span),
constness: ast::BoundConstness::Never,
asyncness: ast::BoundAsyncness::Normal,
},
);
let extra_param = cx.typaram(span, Ident::new(sym::__S, span), vec![bound], None);
impl_generics.params.push(extra_param);
if self_bounds.iter().all(|bound| {
if let GenericBound::Trait(
trait_ref,
TraitBoundModifiers { polarity: ast::BoundPolarity::Maybe(_), .. },
) = bound
{
!is_sized_marker(&trait_ref.trait_ref.path)
} else {
false
}
}) {
self_bounds.push(GenericBound::Trait(
cx.poly_trait_ref(span, sized),
TraitBoundModifiers {
polarity: ast::BoundPolarity::Maybe(span),
constness: ast::BoundConstness::Never,
asyncness: ast::BoundAsyncness::Normal,
},
));
}
{
let mut substitution =
TypeSubstitution { from_name: pointee_ty_ident.name, to_ty: &s_ty, rewritten: false };
for bound in &mut self_bounds {
substitution.visit_param_bound(bound, BoundKind::Bound);
}
}

// We should also commute the where bounds from `#[pointee]` to `__S`
// as well as any bound that indirectly involves the `#[pointee]` type.
for bound in &generics.where_clause.predicates {
if let ast::WherePredicate::BoundPredicate(bound) = bound {
let bound_on_pointee = bound
.bounded_ty
.kind
.is_simple_path()
.map_or(false, |name| name == pointee_ty_ident.name);

let bounds: Vec<_> = bound
.bounds
.iter()
.filter(|bound| {
if let GenericBound::Trait(
trait_ref,
TraitBoundModifiers { polarity: ast::BoundPolarity::Maybe(_), .. },
) = bound
{
!bound_on_pointee || !is_sized_marker(&trait_ref.trait_ref.path)
} else {
true
}
})
.cloned()
.collect();
let mut substitution = TypeSubstitution {
from_name: pointee_ty_ident.name,
to_ty: &s_ty,
rewritten: bounds.len() != bound.bounds.len(),
};
let mut predicate = ast::WherePredicate::BoundPredicate(ast::WhereBoundPredicate {
span: bound.span,
bound_generic_params: bound.bound_generic_params.clone(),
bounded_ty: bound.bounded_ty.clone(),
bounds,
});
substitution.visit_where_predicate(&mut predicate);
if substitution.rewritten {
impl_generics.where_clause.predicates.push(predicate);
}
}
}

let extra_param = cx.typaram(span, Ident::new(sym::__S, span), self_bounds, None);
impl_generics.params.insert(pointee_param_idx + 1, extra_param);

// Add the impl blocks for `DispatchFromDyn` and `CoerceUnsized`.
let gen_args = vec![GenericArg::Type(alt_self_type.clone())];
add_impl_block(impl_generics.clone(), sym::DispatchFromDyn, gen_args.clone());
add_impl_block(impl_generics.clone(), sym::CoerceUnsized, gen_args.clone());
}

fn is_sized_marker(path: &ast::Path) -> bool {
const CORE_UNSIZE: [Symbol; 3] = symbols!(core::marker::Sized);
const STD_UNSIZE: [Symbol; 3] = symbols!(std::marker::Sized);
if path.segments.len() == 3 {
path.segments.iter().zip(CORE_UNSIZE).all(|(segment, symbol)| segment.ident.name == symbol)
|| path
.segments
.iter()
.zip(STD_UNSIZE)
.all(|(segment, symbol)| segment.ident.name == symbol)
} else {
*path == sym::Sized
}
}

struct TypeSubstitution<'a> {
from_name: Symbol,
to_ty: &'a AstTy,
rewritten: bool,
}

impl<'a> ast::mut_visit::MutVisitor for TypeSubstitution<'a> {
fn visit_ty(&mut self, ty: &mut AstTy) {
if let Some(name) = ty.kind.is_simple_path()
&& name == self.from_name
{
*ty = self.to_ty.clone();
self.rewritten = true;
} else {
ast::mut_visit::walk_ty(self, ty);
}
}

fn visit_where_predicate(&mut self, where_predicate: &mut ast::WherePredicate) {
match where_predicate {
rustc_ast::WherePredicate::BoundPredicate(bound) => {
bound
.bound_generic_params
.flat_map_in_place(|param| self.flat_map_generic_param(param));
self.visit_ty(&mut bound.bounded_ty);
for bound in &mut bound.bounds {
self.visit_param_bound(bound, BoundKind::Bound)
}
}
rustc_ast::WherePredicate::RegionPredicate(_)
| rustc_ast::WherePredicate::EqPredicate(_) => {}
}
}
}
Loading

0 comments on commit b06c007

Please sign in to comment.