Skip to content

Commit

Permalink
Refactor away inferred_obligations from the trait selector
Browse files Browse the repository at this point in the history
  • Loading branch information
aravind-pg committed Mar 4, 2018
1 parent 3b8bd53 commit 81f0b96
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 85 deletions.
13 changes: 0 additions & 13 deletions src/librustc/traits/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -854,19 +854,6 @@ impl<'tcx, N> Vtable<'tcx, N> {
}
}

fn nested_obligations_mut(&mut self) -> &mut Vec<N> {
match self {
&mut VtableImpl(ref mut i) => &mut i.nested,
&mut VtableParam(ref mut n) => n,
&mut VtableBuiltin(ref mut i) => &mut i.nested,
&mut VtableAutoImpl(ref mut d) => &mut d.nested,
&mut VtableGenerator(ref mut c) => &mut c.nested,
&mut VtableClosure(ref mut c) => &mut c.nested,
&mut VtableObject(ref mut d) => &mut d.nested,
&mut VtableFnPointer(ref mut d) => &mut d.nested,
}
}

pub fn map<M, F>(self, f: F) -> Vtable<'tcx, M> where F: FnMut(N) -> M {
match self {
VtableImpl(i) => VtableImpl(VtableImplData {
Expand Down
112 changes: 40 additions & 72 deletions src/librustc/traits/select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,27 +44,17 @@ use ty::relate::TypeRelation;
use middle::lang_items;

use rustc_data_structures::bitvec::BitVector;
use rustc_data_structures::snapshot_vec::{SnapshotVecDelegate, SnapshotVec};
use std::iter;
use std::cell::RefCell;
use std::cmp;
use std::fmt;
use std::marker::PhantomData;
use std::mem;
use std::rc::Rc;
use syntax::abi::Abi;
use hir;
use lint;
use util::nodemap::{FxHashMap, FxHashSet};

struct InferredObligationsSnapshotVecDelegate<'tcx> {
phantom: PhantomData<&'tcx i32>,
}
impl<'tcx> SnapshotVecDelegate for InferredObligationsSnapshotVecDelegate<'tcx> {
type Value = PredicateObligation<'tcx>;
type Undo = ();
fn reverse(_: &mut Vec<Self::Value>, _: Self::Undo) {}
}

pub struct SelectionContext<'cx, 'gcx: 'cx+'tcx, 'tcx: 'cx> {
infcx: &'cx InferCtxt<'cx, 'gcx, 'tcx>,
Expand Down Expand Up @@ -92,8 +82,6 @@ pub struct SelectionContext<'cx, 'gcx: 'cx+'tcx, 'tcx: 'cx> {
/// would satisfy it. This avoids crippling inference, basically.
intercrate: Option<IntercrateMode>,

inferred_obligations: SnapshotVec<InferredObligationsSnapshotVecDelegate<'tcx>>,

intercrate_ambiguity_causes: Option<Vec<IntercrateAmbiguityCause>>,

/// Controls whether or not to filter out negative impls when selecting.
Expand Down Expand Up @@ -429,7 +417,6 @@ impl<'cx, 'gcx, 'tcx> SelectionContext<'cx, 'gcx, 'tcx> {
infcx,
freshener: infcx.freshener(),
intercrate: None,
inferred_obligations: SnapshotVec::new(),
intercrate_ambiguity_causes: None,
allow_negative_impls: false,
}
Expand All @@ -442,7 +429,6 @@ impl<'cx, 'gcx, 'tcx> SelectionContext<'cx, 'gcx, 'tcx> {
infcx,
freshener: infcx.freshener(),
intercrate: Some(mode),
inferred_obligations: SnapshotVec::new(),
intercrate_ambiguity_causes: None,
allow_negative_impls: false,
}
Expand All @@ -455,7 +441,6 @@ impl<'cx, 'gcx, 'tcx> SelectionContext<'cx, 'gcx, 'tcx> {
infcx,
freshener: infcx.freshener(),
intercrate: None,
inferred_obligations: SnapshotVec::new(),
intercrate_ambiguity_causes: None,
allow_negative_impls,
}
Expand Down Expand Up @@ -498,8 +483,6 @@ impl<'cx, 'gcx, 'tcx> SelectionContext<'cx, 'gcx, 'tcx> {
fn in_snapshot<R, F>(&mut self, f: F) -> R
where F: FnOnce(&mut Self, &infer::CombinedSnapshot<'cx, 'tcx>) -> R
{
// The irrefutable nature of the operation means we don't need to snapshot the
// inferred_obligations vector.
self.infcx.in_snapshot(|snapshot| f(self, snapshot))
}

Expand All @@ -508,28 +491,15 @@ impl<'cx, 'gcx, 'tcx> SelectionContext<'cx, 'gcx, 'tcx> {
fn probe<R, F>(&mut self, f: F) -> R
where F: FnOnce(&mut Self, &infer::CombinedSnapshot<'cx, 'tcx>) -> R
{
let inferred_obligations_snapshot = self.inferred_obligations.start_snapshot();
let result = self.infcx.probe(|snapshot| f(self, snapshot));
self.inferred_obligations.rollback_to(inferred_obligations_snapshot);
result
self.infcx.probe(|snapshot| f(self, snapshot))
}

/// Wraps a commit_if_ok s.t. obligations collected during it are not returned in selection if
/// the transaction fails and s.t. old obligations are retained.
fn commit_if_ok<T, E, F>(&mut self, f: F) -> Result<T, E> where
F: FnOnce(&mut Self, &infer::CombinedSnapshot) -> Result<T, E>
{
let inferred_obligations_snapshot = self.inferred_obligations.start_snapshot();
match self.infcx.commit_if_ok(|snapshot| f(self, snapshot)) {
Ok(ok) => {
self.inferred_obligations.commit(inferred_obligations_snapshot);
Ok(ok)
},
Err(err) => {
self.inferred_obligations.rollback_to(inferred_obligations_snapshot);
Err(err)
}
}
self.infcx.commit_if_ok(|snapshot| f(self, snapshot))
}


Expand Down Expand Up @@ -560,12 +530,7 @@ impl<'cx, 'gcx, 'tcx> SelectionContext<'cx, 'gcx, 'tcx> {
let stack = self.push_stack(TraitObligationStackList::empty(), obligation);
let ret = match self.candidate_from_obligation(&stack)? {
None => None,
Some(candidate) => {
let mut candidate = self.confirm_candidate(obligation, candidate)?;
let inferred_obligations = (*self.inferred_obligations).into_iter().cloned();
candidate.nested_obligations_mut().extend(inferred_obligations);
Some(candidate)
},
Some(candidate) => Some(self.confirm_candidate(obligation, candidate)?)
};

// Test whether this is a `()` which was produced by defaulting a
Expand Down Expand Up @@ -658,7 +623,7 @@ impl<'cx, 'gcx, 'tcx> SelectionContext<'cx, 'gcx, 'tcx> {
stack: TraitObligationStackList<'o, 'tcx>,
predicates: I)
-> EvaluationResult
where I : Iterator<Item=&'a PredicateObligation<'tcx>>, 'tcx:'a
where I : IntoIterator<Item=&'a PredicateObligation<'tcx>>, 'tcx:'a
{
let mut result = EvaluatedToOk;
for obligation in predicates {
Expand Down Expand Up @@ -695,7 +660,7 @@ impl<'cx, 'gcx, 'tcx> SelectionContext<'cx, 'gcx, 'tcx> {
// does this code ever run?
match self.infcx.equality_predicate(&obligation.cause, obligation.param_env, p) {
Ok(InferOk { obligations, .. }) => {
self.inferred_obligations.extend(obligations);
self.evaluate_predicates_recursively(previous_stack, &obligations);
EvaluatedToOk
},
Err(_) => EvaluatedToErr
Expand All @@ -706,7 +671,7 @@ impl<'cx, 'gcx, 'tcx> SelectionContext<'cx, 'gcx, 'tcx> {
// does this code ever run?
match self.infcx.subtype_predicate(&obligation.cause, obligation.param_env, p) {
Some(Ok(InferOk { obligations, .. })) => {
self.inferred_obligations.extend(obligations);
self.evaluate_predicates_recursively(previous_stack, &obligations);
EvaluatedToOk
},
Some(Err(_)) => EvaluatedToErr,
Expand Down Expand Up @@ -1553,12 +1518,9 @@ impl<'cx, 'gcx, 'tcx> SelectionContext<'cx, 'gcx, 'tcx> {
-> bool
{
assert!(!skol_trait_ref.has_escaping_regions());
match self.infcx.at(&obligation.cause, obligation.param_env)
.sup(ty::Binder(skol_trait_ref), trait_bound) {
Ok(InferOk { obligations, .. }) => {
self.inferred_obligations.extend(obligations);
}
Err(_) => { return false; }
if let Err(_) = self.infcx.at(&obligation.cause, obligation.param_env)
.sup(ty::Binder(skol_trait_ref), trait_bound) {
return false;
}

self.infcx.leak_check(false, obligation.cause.span, skol_map, snapshot).is_ok()
Expand Down Expand Up @@ -2644,6 +2606,7 @@ impl<'cx, 'gcx, 'tcx> SelectionContext<'cx, 'gcx, 'tcx> {
};

let mut upcast_trait_ref = None;
let mut nested = vec![];
let vtable_base;

{
Expand All @@ -2662,7 +2625,11 @@ impl<'cx, 'gcx, 'tcx> SelectionContext<'cx, 'gcx, 'tcx> {
self.commit_if_ok(
|this, _| this.match_poly_trait_ref(obligation, t))
{
Ok(_) => { upcast_trait_ref = Some(t); false }
Ok(obligations) => {
upcast_trait_ref = Some(t);
nested.extend(obligations);
false
}
Err(_) => { true }
}
});
Expand All @@ -2680,7 +2647,7 @@ impl<'cx, 'gcx, 'tcx> SelectionContext<'cx, 'gcx, 'tcx> {
VtableObjectData {
upcast_trait_ref: upcast_trait_ref.unwrap(),
vtable_base,
nested: vec![]
nested,
}
}

Expand Down Expand Up @@ -2737,7 +2704,7 @@ impl<'cx, 'gcx, 'tcx> SelectionContext<'cx, 'gcx, 'tcx> {
self.generator_trait_ref_unnormalized(obligation, closure_def_id, substs);
let Normalized {
value: trait_ref,
obligations
mut obligations
} = normalize_with_depth(self,
obligation.param_env,
obligation.cause.clone(),
Expand All @@ -2749,10 +2716,11 @@ impl<'cx, 'gcx, 'tcx> SelectionContext<'cx, 'gcx, 'tcx> {
trait_ref,
obligations);

self.confirm_poly_trait_refs(obligation.cause.clone(),
obligation.param_env,
obligation.predicate.to_poly_trait_ref(),
trait_ref)?;
obligations.extend(
self.confirm_poly_trait_refs(obligation.cause.clone(),
obligation.param_env,
obligation.predicate.to_poly_trait_ref(),
trait_ref)?);

Ok(VtableGeneratorData {
closure_def_id: closure_def_id,
Expand Down Expand Up @@ -2798,10 +2766,11 @@ impl<'cx, 'gcx, 'tcx> SelectionContext<'cx, 'gcx, 'tcx> {
trait_ref,
obligations);

self.confirm_poly_trait_refs(obligation.cause.clone(),
obligation.param_env,
obligation.predicate.to_poly_trait_ref(),
trait_ref)?;
obligations.extend(
self.confirm_poly_trait_refs(obligation.cause.clone(),
obligation.param_env,
obligation.predicate.to_poly_trait_ref(),
trait_ref)?);

obligations.push(Obligation::new(
obligation.cause.clone(),
Expand Down Expand Up @@ -2845,13 +2814,13 @@ impl<'cx, 'gcx, 'tcx> SelectionContext<'cx, 'gcx, 'tcx> {
obligation_param_env: ty::ParamEnv<'tcx>,
obligation_trait_ref: ty::PolyTraitRef<'tcx>,
expected_trait_ref: ty::PolyTraitRef<'tcx>)
-> Result<(), SelectionError<'tcx>>
-> Result<Vec<PredicateObligation<'tcx>>, SelectionError<'tcx>>
{
let obligation_trait_ref = obligation_trait_ref.clone();
self.infcx
.at(&obligation_cause, obligation_param_env)
.sup(obligation_trait_ref, expected_trait_ref)
.map(|InferOk { obligations, .. }| self.inferred_obligations.extend(obligations))
.map(|InferOk { obligations, .. }| obligations)
.map_err(|e| OutputTypeParameterMismatch(expected_trait_ref, obligation_trait_ref, e))
}

Expand Down Expand Up @@ -2888,7 +2857,7 @@ impl<'cx, 'gcx, 'tcx> SelectionContext<'cx, 'gcx, 'tcx> {
self.infcx.at(&obligation.cause, obligation.param_env)
.eq(target, new_trait)
.map_err(|_| Unimplemented)?;
self.inferred_obligations.extend(obligations);
nested.extend(obligations);

// Register one obligation for 'a: 'b.
let cause = ObligationCause::new(obligation.cause.span,
Expand Down Expand Up @@ -2950,7 +2919,7 @@ impl<'cx, 'gcx, 'tcx> SelectionContext<'cx, 'gcx, 'tcx> {
self.infcx.at(&obligation.cause, obligation.param_env)
.eq(b, a)
.map_err(|_| Unimplemented)?;
self.inferred_obligations.extend(obligations);
nested.extend(obligations);
}

// Struct<T> -> Struct<U>.
Expand Down Expand Up @@ -3014,7 +2983,7 @@ impl<'cx, 'gcx, 'tcx> SelectionContext<'cx, 'gcx, 'tcx> {
self.infcx.at(&obligation.cause, obligation.param_env)
.eq(target, new_struct)
.map_err(|_| Unimplemented)?;
self.inferred_obligations.extend(obligations);
nested.extend(obligations);

// Construct the nested Field<T>: Unsize<Field<U>> predicate.
nested.push(tcx.predicate_for_trait_def(
Expand Down Expand Up @@ -3045,7 +3014,7 @@ impl<'cx, 'gcx, 'tcx> SelectionContext<'cx, 'gcx, 'tcx> {
self.infcx.at(&obligation.cause, obligation.param_env)
.eq(target, new_tuple)
.map_err(|_| Unimplemented)?;
self.inferred_obligations.extend(obligations);
nested.extend(obligations);

// Construct the nested T: Unsize<U> predicate.
nested.push(tcx.predicate_for_trait_def(
Expand Down Expand Up @@ -3118,7 +3087,7 @@ impl<'cx, 'gcx, 'tcx> SelectionContext<'cx, 'gcx, 'tcx> {
let impl_trait_ref = impl_trait_ref.subst(self.tcx(),
impl_substs);

let impl_trait_ref =
let Normalized { value: impl_trait_ref, obligations: mut nested_obligations } =
project::normalize_with_depth(self,
obligation.param_env,
obligation.cause.clone(),
Expand All @@ -3134,12 +3103,12 @@ impl<'cx, 'gcx, 'tcx> SelectionContext<'cx, 'gcx, 'tcx> {

let InferOk { obligations, .. } =
self.infcx.at(&obligation.cause, obligation.param_env)
.eq(skol_obligation_trait_ref, impl_trait_ref.value)
.eq(skol_obligation_trait_ref, impl_trait_ref)
.map_err(|e| {
debug!("match_impl: failed eq_trait_refs due to `{}`", e);
()
})?;
self.inferred_obligations.extend(obligations);
nested_obligations.extend(obligations);

if let Err(e) = self.infcx.leak_check(false,
obligation.cause.span,
Expand All @@ -3152,7 +3121,7 @@ impl<'cx, 'gcx, 'tcx> SelectionContext<'cx, 'gcx, 'tcx> {
debug!("match_impl: success impl_substs={:?}", impl_substs);
Ok((Normalized {
value: impl_substs,
obligations: impl_trait_ref.obligations
obligations: nested_obligations
}, skol_map))
}

Expand Down Expand Up @@ -3189,24 +3158,23 @@ impl<'cx, 'gcx, 'tcx> SelectionContext<'cx, 'gcx, 'tcx> {
where_clause_trait_ref: ty::PolyTraitRef<'tcx>)
-> Result<Vec<PredicateObligation<'tcx>>,()>
{
self.match_poly_trait_ref(obligation, where_clause_trait_ref)?;
Ok(Vec::new())
self.match_poly_trait_ref(obligation, where_clause_trait_ref)
}

/// Returns `Ok` if `poly_trait_ref` being true implies that the
/// obligation is satisfied.
fn match_poly_trait_ref(&mut self,
obligation: &TraitObligation<'tcx>,
poly_trait_ref: ty::PolyTraitRef<'tcx>)
-> Result<(),()>
-> Result<Vec<PredicateObligation<'tcx>>,()>
{
debug!("match_poly_trait_ref: obligation={:?} poly_trait_ref={:?}",
obligation,
poly_trait_ref);

self.infcx.at(&obligation.cause, obligation.param_env)
.sup(obligation.predicate.to_poly_trait_ref(), poly_trait_ref)
.map(|InferOk { obligations, .. }| self.inferred_obligations.extend(obligations))
.map(|InferOk { obligations, .. }| obligations)
.map_err(|_| ())
}

Expand Down

0 comments on commit 81f0b96

Please sign in to comment.