Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Transform async ResumeTy in generator transform #105977

Merged
merged 1 commit into from
Jan 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions compiler/rustc_hir/src/lang_items.rs
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,7 @@ language_item_table! {
IdentityFuture, sym::identity_future, identity_future_fn, Target::Fn, GenericRequirement::None;
GetContext, sym::get_context, get_context_fn, Target::Fn, GenericRequirement::None;

Context, sym::Context, context, Target::Struct, GenericRequirement::None;
FuturePoll, sym::poll, future_poll_fn, Target::Method(MethodKind::Trait { body: false }), GenericRequirement::None;

FromFrom, sym::from, from_fn, Target::Method(MethodKind::Trait { body: false }), GenericRequirement::None;
Expand Down
9 changes: 9 additions & 0 deletions compiler/rustc_middle/src/ty/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1952,6 +1952,15 @@ impl<'tcx> TyCtxt<'tcx> {
self.mk_ty(GeneratorWitness(types))
}

/// Creates a `&mut Context<'_>` [`Ty`] with erased lifetimes.
pub fn mk_task_context(self) -> Ty<'tcx> {
let context_did = self.require_lang_item(LangItem::Context, None);
let context_adt_ref = self.adt_def(context_did);
let context_substs = self.intern_substs(&[self.lifetimes.re_erased.into()]);
let context_ty = self.mk_adt(context_adt_ref, context_substs);
self.mk_mut_ref(self.lifetimes.re_erased, context_ty)
}

#[inline]
pub fn mk_ty_var(self, v: TyVid) -> Ty<'tcx> {
self.mk_ty_infer(TyVar(v))
Expand Down
118 changes: 111 additions & 7 deletions compiler/rustc_mir_transform/src/generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,104 @@ fn replace_local<'tcx>(
new_local
}

/// Transforms the `body` of the generator applying the following transforms:
///
/// - Eliminates all the `get_context` calls that async lowering created.
/// - Replace all `Local` `ResumeTy` types with `&mut Context<'_>` (`context_mut_ref`).
///
/// The `Local`s that have their types replaced are:
/// - The `resume` argument itself.
/// - The argument to `get_context`.
/// - The yielded value of a `yield`.
///
/// The `ResumeTy` hides a `&mut Context<'_>` behind an unsafe raw pointer, and the
/// `get_context` function is being used to convert that back to a `&mut Context<'_>`.
///
/// Ideally the async lowering would not use the `ResumeTy`/`get_context` indirection,
/// but rather directly use `&mut Context<'_>`, however that would currently
/// lead to higher-kinded lifetime errors.
/// See <https://github.com/rust-lang/rust/issues/105501>.
///
/// The async lowering step and the type / lifetime inference / checking are
/// still using the `ResumeTy` indirection for the time being, and that indirection
/// is removed here. After this transform, the generator body only knows about `&mut Context<'_>`.
fn transform_async_context<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
let context_mut_ref = tcx.mk_task_context();

// replace the type of the `resume` argument
replace_resume_ty_local(tcx, body, Local::new(2), context_mut_ref);

let get_context_def_id = tcx.require_lang_item(LangItem::GetContext, None);

for bb in BasicBlock::new(0)..body.basic_blocks.next_index() {
let bb_data = &body[bb];
if bb_data.is_cleanup {
continue;
}

match &bb_data.terminator().kind {
TerminatorKind::Call { func, .. } => {
let func_ty = func.ty(body, tcx);
if let ty::FnDef(def_id, _) = *func_ty.kind() {
if def_id == get_context_def_id {
let local = eliminate_get_context_call(&mut body[bb]);
replace_resume_ty_local(tcx, body, local, context_mut_ref);
}
} else {
continue;
}
}
TerminatorKind::Yield { resume_arg, .. } => {
replace_resume_ty_local(tcx, body, resume_arg.local, context_mut_ref);
}
_ => {}
}
}
}

fn eliminate_get_context_call<'tcx>(bb_data: &mut BasicBlockData<'tcx>) -> Local {
let terminator = bb_data.terminator.take().unwrap();
if let TerminatorKind::Call { mut args, destination, target, .. } = terminator.kind {
let arg = args.pop().unwrap();
let local = arg.place().unwrap().local;

let arg = Rvalue::Use(arg);
let assign = Statement {
source_info: terminator.source_info,
kind: StatementKind::Assign(Box::new((destination, arg))),
};
bb_data.statements.push(assign);
bb_data.terminator = Some(Terminator {
source_info: terminator.source_info,
kind: TerminatorKind::Goto { target: target.unwrap() },
});
local
} else {
bug!();
}
}

#[cfg_attr(not(debug_assertions), allow(unused))]
fn replace_resume_ty_local<'tcx>(
tcx: TyCtxt<'tcx>,
body: &mut Body<'tcx>,
local: Local,
context_mut_ref: Ty<'tcx>,
) {
let local_ty = std::mem::replace(&mut body.local_decls[local].ty, context_mut_ref);
// We have to replace the `ResumeTy` that is used for type and borrow checking
// with `&mut Context<'_>` in MIR.
#[cfg(debug_assertions)]
{
if let ty::Adt(resume_ty_adt, _) = local_ty.kind() {
let expected_adt = tcx.adt_def(tcx.require_lang_item(LangItem::ResumeTy, None));
assert_eq!(*resume_ty_adt, expected_adt);
} else {
panic!("expected `ResumeTy`, found `{:?}`", local_ty);
};
}
}

struct LivenessInfo {
/// Which locals are live across any suspension point.
saved_locals: GeneratorSavedLocals,
Expand Down Expand Up @@ -1283,13 +1381,13 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
}
};

let is_async_kind = body.generator_kind().unwrap() != GeneratorKind::Gen;
let is_async_kind = matches!(body.generator_kind(), Some(GeneratorKind::Async(_)));
let (state_adt_ref, state_substs) = if is_async_kind {
// Compute Poll<return_ty>
let state_did = tcx.require_lang_item(LangItem::Poll, None);
let state_adt_ref = tcx.adt_def(state_did);
let state_substs = tcx.intern_substs(&[body.return_ty().into()]);
(state_adt_ref, state_substs)
let poll_did = tcx.require_lang_item(LangItem::Poll, None);
let poll_adt_ref = tcx.adt_def(poll_did);
let poll_substs = tcx.intern_substs(&[body.return_ty().into()]);
(poll_adt_ref, poll_substs)
} else {
// Compute GeneratorState<yield_ty, return_ty>
let state_did = tcx.require_lang_item(LangItem::GeneratorState, None);
Expand All @@ -1303,13 +1401,19 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
// RETURN_PLACE then is a fresh unused local with type ret_ty.
let new_ret_local = replace_local(RETURN_PLACE, ret_ty, body, tcx);

// Replace all occurrences of `ResumeTy` with `&mut Context<'_>` within async bodies.
if is_async_kind {
transform_async_context(tcx, body);
}

// We also replace the resume argument and insert an `Assign`.
// This is needed because the resume argument `_2` might be live across a `yield`, in which
// case there is no `Assign` to it that the transform can turn into a store to the generator
// state. After the yield the slot in the generator state would then be uninitialized.
let resume_local = Local::new(2);
let new_resume_local =
replace_local(resume_local, body.local_decls[resume_local].ty, body, tcx);
let resume_ty =
if is_async_kind { tcx.mk_task_context() } else { body.local_decls[resume_local].ty };
let new_resume_local = replace_local(resume_local, resume_ty, body, tcx);

// When first entering the generator, move the resume argument into its new local.
let source_info = SourceInfo::outermost(body.span);
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_span/src/symbol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ symbols! {
Capture,
Center,
Clone,
Context,
Continue,
Copy,
Count,
Expand Down
34 changes: 27 additions & 7 deletions compiler/rustc_ty_utils/src/abi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,21 +108,41 @@ fn fn_sig_for_fn_abi<'tcx>(
// `Generator::resume(...) -> GeneratorState` function in case we
// have an ordinary generator, or the `Future::poll(...) -> Poll`
// function in case this is a special generator backing an async construct.
let ret_ty = if tcx.generator_is_async(did) {
let state_did = tcx.require_lang_item(LangItem::Poll, None);
let state_adt_ref = tcx.adt_def(state_did);
let state_substs = tcx.intern_substs(&[sig.return_ty.into()]);
tcx.mk_adt(state_adt_ref, state_substs)
let (resume_ty, ret_ty) = if tcx.generator_is_async(did) {
// The signature should be `Future::poll(_, &mut Context<'_>) -> Poll<Output>`
let poll_did = tcx.require_lang_item(LangItem::Poll, None);
let poll_adt_ref = tcx.adt_def(poll_did);
let poll_substs = tcx.intern_substs(&[sig.return_ty.into()]);
let ret_ty = tcx.mk_adt(poll_adt_ref, poll_substs);

// We have to replace the `ResumeTy` that is used for type and borrow checking
// with `&mut Context<'_>` which is used in codegen.
#[cfg(debug_assertions)]
{
if let ty::Adt(resume_ty_adt, _) = sig.resume_ty.kind() {
let expected_adt =
tcx.adt_def(tcx.require_lang_item(LangItem::ResumeTy, None));
assert_eq!(*resume_ty_adt, expected_adt);
} else {
panic!("expected `ResumeTy`, found `{:?}`", sig.resume_ty);
};
}
let context_mut_ref = tcx.mk_task_context();

(context_mut_ref, ret_ty)
} else {
// The signature should be `Generator::resume(_, Resume) -> GeneratorState<Yield, Return>`
let state_did = tcx.require_lang_item(LangItem::GeneratorState, None);
let state_adt_ref = tcx.adt_def(state_did);
let state_substs = tcx.intern_substs(&[sig.yield_ty.into(), sig.return_ty.into()]);
tcx.mk_adt(state_adt_ref, state_substs)
let ret_ty = tcx.mk_adt(state_adt_ref, state_substs);

(sig.resume_ty, ret_ty)
};

ty::Binder::bind_with_vars(
tcx.mk_fn_sig(
[env_ty, sig.resume_ty].iter(),
[env_ty, resume_ty].iter(),
&ret_ty,
false,
hir::Unsafety::Normal,
Expand Down
4 changes: 4 additions & 0 deletions library/core/src/future/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,10 @@ pub unsafe fn get_context<'a, 'b>(cx: ResumeTy) -> &'a mut Context<'b> {
unsafe { &mut *cx.0.as_ptr().cast() }
}

// FIXME(swatinem): This fn is currently needed to work around shortcomings
// in type and lifetime inference.
// See the comment at the bottom of `LoweringContext::make_async_expr` and
// <https://github.com/rust-lang/rust/issues/104826>.
#[doc(hidden)]
#[unstable(feature = "gen_future", issue = "50547")]
#[inline]
Expand Down
1 change: 1 addition & 0 deletions library/core/src/task/wake.rs
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ impl RawWakerVTable {
/// Currently, `Context` only serves to provide access to a [`&Waker`](Waker)
/// which can be used to wake the current task.
#[stable(feature = "futures_api", since = "1.36.0")]
#[cfg_attr(not(bootstrap), lang = "Context")]
pub struct Context<'a> {
waker: &'a Waker,
// Ensure we future-proof against variance changes by forcing
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// MIR for `a::{closure#0}` 0 generator_resume
/* generator_layout = GeneratorLayout {
field_tys: {},
variant_fields: {
Unresumed(0): [],
Returned (1): [],
Panicked (2): [],
},
storage_conflicts: BitMatrix(0x0) {},
} */

fn a::{closure#0}(_1: Pin<&mut [async fn body@$DIR/async_await.rs:11:14: 11:16]>, _2: &mut Context<'_>) -> Poll<()> {
debug _task_context => _4; // in scope 0 at $DIR/async_await.rs:+0:14: +0:16
let mut _0: std::task::Poll<()>; // return place in scope 0 at $DIR/async_await.rs:+0:14: +0:16
let mut _3: (); // in scope 0 at $DIR/async_await.rs:+0:14: +0:16
let mut _4: &mut std::task::Context<'_>; // in scope 0 at $DIR/async_await.rs:+0:14: +0:16
let mut _5: u32; // in scope 0 at $DIR/async_await.rs:+0:14: +0:16

bb0: {
_5 = discriminant((*(_1.0: &mut [async fn body@$DIR/async_await.rs:11:14: 11:16]))); // scope 0 at $DIR/async_await.rs:+0:14: +0:16
switchInt(move _5) -> [0: bb1, 1: bb2, otherwise: bb3]; // scope 0 at $DIR/async_await.rs:+0:14: +0:16
}

bb1: {
_4 = move _2; // scope 0 at $DIR/async_await.rs:+0:14: +0:16
_3 = const (); // scope 0 at $DIR/async_await.rs:+0:14: +0:16
Deinit(_0); // scope 0 at $DIR/async_await.rs:+0:16: +0:16
((_0 as Ready).0: ()) = move _3; // scope 0 at $DIR/async_await.rs:+0:16: +0:16
discriminant(_0) = 0; // scope 0 at $DIR/async_await.rs:+0:16: +0:16
discriminant((*(_1.0: &mut [async fn body@$DIR/async_await.rs:11:14: 11:16]))) = 1; // scope 0 at $DIR/async_await.rs:+0:16: +0:16
return; // scope 0 at $DIR/async_await.rs:+0:16: +0:16
}

bb2: {
assert(const false, "`async fn` resumed after completion") -> bb2; // scope 0 at $DIR/async_await.rs:+0:14: +0:16
}

bb3: {
unreachable; // scope 0 at $DIR/async_await.rs:+0:14: +0:16
}
}
Loading