Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement should_continue in chalk-recursive #774

Merged
merged 5 commits into from
Nov 29, 2022
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions chalk-engine/src/slg/aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ pub trait AggregateOps<I: Interner> {
&self,
root_goal: &UCanonical<InEnvironment<Goal<I>>>,
answers: impl context::AnswerStream<I>,
should_continue: impl std::ops::Fn() -> bool,
should_continue: impl std::ops::Fn() -> bool + Clone,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you need this clone bound? You should be able to pass in &should_continue to any functions to avoid a move

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I tried that in commit 8da2e1d, and that runs into a rustc bug

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll add a comment so the workaround can be removed when that's fixed.

) -> Option<Solution<I>>;
}

Expand All @@ -28,7 +28,7 @@ impl<I: Interner> AggregateOps<I> for SlgContextOps<'_, I> {
&self,
root_goal: &UCanonical<InEnvironment<Goal<I>>>,
mut answers: impl context::AnswerStream<I>,
should_continue: impl std::ops::Fn() -> bool,
should_continue: impl std::ops::Fn() -> bool + Clone,
) -> Option<Solution<I>> {
let interner = self.program.interner();
let CompleteAnswer { subst, ambiguous } = match answers.next_answer(&should_continue) {
Expand Down
20 changes: 15 additions & 5 deletions chalk-recursive/src/fixed_point.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ where
context: &mut RecursiveContext<K, V>,
goal: &K,
minimums: &mut Minimums,
should_continue: impl std::ops::Fn() -> bool + Clone,
) -> V;
fn reached_fixed_point(self, old_value: &V, new_value: &V) -> bool;
fn error_value(self) -> V;
Expand Down Expand Up @@ -104,22 +105,24 @@ where
&mut self,
canonical_goal: &K,
solver_stuff: impl SolverStuff<K, V>,
should_continue: impl std::ops::Fn() -> bool + Clone,
) -> V {
debug!("solve_root_goal(canonical_goal={:?})", canonical_goal);
assert!(self.stack.is_empty());
let minimums = &mut Minimums::new();
self.solve_goal(canonical_goal, minimums, solver_stuff)
self.solve_goal(canonical_goal, minimums, solver_stuff, should_continue)
}

/// Attempt to solve a goal that has been fully broken down into leaf form
/// and canonicalized. This is where the action really happens, and is the
/// place where we would perform caching in rustc (and may eventually do in Chalk).
#[instrument(level = "info", skip(self, minimums, solver_stuff,))]
#[instrument(level = "info", skip(self, minimums, solver_stuff, should_continue))]
pub fn solve_goal(
&mut self,
goal: &K,
minimums: &mut Minimums,
solver_stuff: impl SolverStuff<K, V>,
should_continue: impl std::ops::Fn() -> bool + Clone,
) -> V {
// First check the cache.
if let Some(cache) = &self.cache {
Expand Down Expand Up @@ -159,7 +162,8 @@ where
let depth = self.stack.push(coinductive_goal);
let dfn = self.search_graph.insert(goal, depth, initial_solution);

let subgoal_minimums = self.solve_new_subgoal(goal, depth, dfn, solver_stuff);
let subgoal_minimums =
self.solve_new_subgoal(goal, depth, dfn, solver_stuff, should_continue);

self.search_graph[dfn].links = subgoal_minimums;
self.search_graph[dfn].stack_depth = None;
Expand Down Expand Up @@ -190,13 +194,14 @@ where
}
}

#[instrument(level = "debug", skip(self, solver_stuff))]
#[instrument(level = "debug", skip(self, solver_stuff, should_continue))]
fn solve_new_subgoal(
&mut self,
canonical_goal: &K,
depth: StackDepth,
dfn: DepthFirstNumber,
solver_stuff: impl SolverStuff<K, V>,
should_continue: impl std::ops::Fn() -> bool + Clone,
) -> Minimums {
// We start with `answer = None` and try to solve the goal. At the end of the iteration,
// `answer` will be updated with the result of the solving process. If we detect a cycle
Expand All @@ -209,7 +214,12 @@ where
// so this function will eventually be constant and the loop terminates.
loop {
let minimums = &mut Minimums::new();
let current_answer = solver_stuff.solve_iteration(self, canonical_goal, minimums);
let current_answer = solver_stuff.solve_iteration(
self,
canonical_goal,
minimums,
should_continue.clone(),
);

debug!(
"solve_new_subgoal: loop iteration result = {:?} with minimums {:?}",
Expand Down
38 changes: 28 additions & 10 deletions chalk-recursive/src/fulfill.rs
Original file line number Diff line number Diff line change
Expand Up @@ -342,24 +342,31 @@ impl<'s, I: Interner, Solver: SolveDatabase<I>> Fulfill<'s, I, Solver> {
Ok(())
}

#[instrument(level = "debug", skip(self, minimums))]
#[instrument(level = "debug", skip(self, minimums, should_continue))]
fn prove(
&mut self,
wc: InEnvironment<Goal<I>>,
minimums: &mut Minimums,
should_continue: impl std::ops::Fn() -> bool + Clone,
) -> Fallible<PositiveSolution<I>> {
let interner = self.solver.interner();
let (quantified, free_vars) = canonicalize(&mut self.infer, interner, wc);
let (quantified, universes) = u_canonicalize(&mut self.infer, interner, &quantified);
let result = self.solver.solve_goal(quantified, minimums);
let result = self
.solver
.solve_goal(quantified, minimums, should_continue);
Ok(PositiveSolution {
free_vars,
universes,
solution: result?,
})
}

fn refute(&mut self, goal: InEnvironment<Goal<I>>) -> Fallible<NegativeSolution> {
fn refute(
&mut self,
goal: InEnvironment<Goal<I>>,
should_continue: impl std::ops::Fn() -> bool + Clone,
) -> Fallible<NegativeSolution> {
let canonicalized = match self
.infer
.invert_then_canonicalize(self.solver.interner(), goal)
Expand All @@ -376,7 +383,10 @@ impl<'s, I: Interner, Solver: SolveDatabase<I>> Fulfill<'s, I, Solver> {
let (quantified, _) =
u_canonicalize(&mut self.infer, self.solver.interner(), &canonicalized);
let mut minimums = Minimums::new(); // FIXME -- minimums here seems wrong
if let Ok(solution) = self.solver.solve_goal(quantified, &mut minimums) {
if let Ok(solution) = self
.solver
.solve_goal(quantified, &mut minimums, should_continue)
{
if solution.is_unique() {
Err(NoSolution)
} else {
Expand Down Expand Up @@ -431,7 +441,11 @@ impl<'s, I: Interner, Solver: SolveDatabase<I>> Fulfill<'s, I, Solver> {
}
}

fn fulfill(&mut self, minimums: &mut Minimums) -> Fallible<Outcome> {
fn fulfill(
&mut self,
minimums: &mut Minimums,
should_continue: impl std::ops::Fn() -> bool + Clone,
) -> Fallible<Outcome> {
debug_span!("fulfill", obligations=?self.obligations);

// Try to solve all the obligations. We do this via a fixed-point
Expand Down Expand Up @@ -460,7 +474,7 @@ impl<'s, I: Interner, Solver: SolveDatabase<I>> Fulfill<'s, I, Solver> {
free_vars,
universes,
solution,
} = self.prove(wc.clone(), minimums)?;
} = self.prove(wc.clone(), minimums, should_continue.clone())?;

if let Some(constrained_subst) = solution.definite_subst(self.interner()) {
// If the substitution is trivial, we won't actually make any progress by applying it!
Expand All @@ -484,7 +498,7 @@ impl<'s, I: Interner, Solver: SolveDatabase<I>> Fulfill<'s, I, Solver> {
solution.is_ambig()
}
Obligation::Refute(goal) => {
let answer = self.refute(goal.clone())?;
let answer = self.refute(goal.clone(), should_continue.clone())?;
answer == NegativeSolution::Ambiguous
}
};
Expand Down Expand Up @@ -514,8 +528,12 @@ impl<'s, I: Interner, Solver: SolveDatabase<I>> Fulfill<'s, I, Solver> {
/// Try to fulfill all pending obligations and build the resulting
/// solution. The returned solution will transform `subst` substitution with
/// the outcome of type inference by updating the replacements it provides.
pub(super) fn solve(mut self, minimums: &mut Minimums) -> Fallible<Solution<I>> {
let outcome = match self.fulfill(minimums) {
pub(super) fn solve(
mut self,
minimums: &mut Minimums,
should_continue: impl std::ops::Fn() -> bool + Clone,
) -> Fallible<Solution<I>> {
let outcome = match self.fulfill(minimums, should_continue.clone()) {
Ok(o) => o,
Err(e) => return Err(e),
};
Expand Down Expand Up @@ -567,7 +585,7 @@ impl<'s, I: Interner, Solver: SolveDatabase<I>> Fulfill<'s, I, Solver> {
free_vars,
universes,
solution,
} = self.prove(goal, minimums).unwrap();
} = self.prove(goal, minimums, should_continue.clone()).unwrap();
if let Some(constrained_subst) =
solution.constrained_subst(self.solver.interner())
{
Expand Down
16 changes: 10 additions & 6 deletions chalk-recursive/src/recursive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,9 @@ impl<I: Interner> SolverStuff<UCanonicalGoal<I>, Fallible<Solution<I>>> for &dyn
context: &mut RecursiveContext<UCanonicalGoal<I>, Fallible<Solution<I>>>,
goal: &UCanonicalGoal<I>,
minimums: &mut Minimums,
should_continue: impl std::ops::Fn() -> bool + Clone,
) -> Fallible<Solution<I>> {
Solver::new(context, self).solve_iteration(goal, minimums)
Solver::new(context, self).solve_iteration(goal, minimums, should_continue)
}

fn reached_fixed_point(
Expand Down Expand Up @@ -108,8 +109,10 @@ impl<'me, I: Interner> SolveDatabase<I> for Solver<'me, I> {
&mut self,
goal: UCanonicalGoal<I>,
minimums: &mut Minimums,
should_continue: impl std::ops::Fn() -> bool + Clone,
) -> Fallible<Solution<I>> {
self.context.solve_goal(&goal, minimums, self.program)
self.context
.solve_goal(&goal, minimums, self.program, should_continue)
}

fn interner(&self) -> I {
Expand All @@ -131,17 +134,18 @@ impl<I: Interner> chalk_solve::Solver<I> for RecursiveSolver<I> {
program: &dyn RustIrDatabase<I>,
goal: &UCanonical<InEnvironment<Goal<I>>>,
) -> Option<chalk_solve::Solution<I>> {
self.ctx.solve_root_goal(goal, program).ok()
self.ctx.solve_root_goal(goal, program, || true).ok()
}

fn solve_limited(
&mut self,
program: &dyn RustIrDatabase<I>,
goal: &UCanonical<InEnvironment<Goal<I>>>,
_should_continue: &dyn std::ops::Fn() -> bool,
should_continue: &dyn std::ops::Fn() -> bool,
) -> Option<chalk_solve::Solution<I>> {
// TODO support should_continue in recursive solver
self.ctx.solve_root_goal(goal, program).ok()
self.ctx
.solve_root_goal(goal, program, should_continue)
.ok()
}

fn solve_multiple(
Expand Down
23 changes: 17 additions & 6 deletions chalk-recursive/src/solve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ pub(super) trait SolveDatabase<I: Interner>: Sized {
&mut self,
goal: UCanonical<InEnvironment<Goal<I>>>,
minimums: &mut Minimums,
should_continue: impl std::ops::Fn() -> bool + Clone,
) -> Fallible<Solution<I>>;

fn max_size(&self) -> usize;
Expand All @@ -35,12 +36,17 @@ pub(super) trait SolveIteration<I: Interner>: SolveDatabase<I> {
/// Executes one iteration of the recursive solver, computing the current
/// solution to the given canonical goal. This is used as part of a loop in
/// the case of cyclic goals.
#[instrument(level = "debug", skip(self))]
#[instrument(level = "debug", skip(self, should_continue))]
fn solve_iteration(
&mut self,
canonical_goal: &UCanonicalGoal<I>,
minimums: &mut Minimums,
should_continue: impl std::ops::Fn() -> bool + Clone,
) -> Fallible<Solution<I>> {
if !should_continue() {
return Ok(Solution::Ambig(Guidance::Unknown));
}

let UCanonical {
universes,
canonical:
Expand Down Expand Up @@ -72,7 +78,7 @@ pub(super) trait SolveIteration<I: Interner>: SolveDatabase<I> {
let prog_solution = {
debug_span!("prog_clauses");

self.solve_from_clauses(&canonical_goal, minimums)
self.solve_from_clauses(&canonical_goal, minimums, should_continue)
};
debug!(?prog_solution);

Expand All @@ -88,7 +94,7 @@ pub(super) trait SolveIteration<I: Interner>: SolveDatabase<I> {
},
};

self.solve_via_simplification(&canonical_goal, minimums)
self.solve_via_simplification(&canonical_goal, minimums, should_continue)
}
}
}
Expand All @@ -103,15 +109,16 @@ where

/// Helper methods for `solve_iteration`, private to this module.
trait SolveIterationHelpers<I: Interner>: SolveDatabase<I> {
#[instrument(level = "debug", skip(self, minimums))]
#[instrument(level = "debug", skip(self, minimums, should_continue))]
fn solve_via_simplification(
&mut self,
canonical_goal: &UCanonicalGoal<I>,
minimums: &mut Minimums,
should_continue: impl std::ops::Fn() -> bool + Clone,
) -> Fallible<Solution<I>> {
let (infer, subst, goal) = self.new_inference_table(canonical_goal);
match Fulfill::new_with_simplification(self, infer, subst, goal) {
Ok(fulfill) => fulfill.solve(minimums),
Ok(fulfill) => fulfill.solve(minimums, should_continue),
Err(e) => Err(e),
}
}
Expand All @@ -123,6 +130,7 @@ trait SolveIterationHelpers<I: Interner>: SolveDatabase<I> {
&mut self,
canonical_goal: &UCanonical<InEnvironment<DomainGoal<I>>>,
minimums: &mut Minimums,
should_continue: impl std::ops::Fn() -> bool + Clone,
) -> Fallible<Solution<I>> {
let mut clauses = vec![];

Expand Down Expand Up @@ -159,7 +167,10 @@ trait SolveIterationHelpers<I: Interner>: SolveDatabase<I> {
let subst = subst.clone();
let goal = goal.clone();
let res = match Fulfill::new_with_clause(self, infer, subst, goal, implication) {
Ok(fulfill) => (fulfill.solve(minimums), implication.skip_binders().priority),
Ok(fulfill) => (
fulfill.solve(minimums, should_continue.clone()),
implication.skip_binders().priority,
),
Err(e) => (Err(e), ClausePriority::High),
};

Expand Down