Skip to content

Commit

Permalink
More comments, final tweaks
Browse files Browse the repository at this point in the history
  • Loading branch information
compiler-errors committed Jan 26, 2024
1 parent ca6c57a commit a7e6d88
Show file tree
Hide file tree
Showing 20 changed files with 169 additions and 42 deletions.
3 changes: 3 additions & 0 deletions compiler/rustc_ast_lowering/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1075,6 +1075,9 @@ impl<'hir> LoweringContext<'_, 'hir> {
body,
fn_decl_span: self.lower_span(fn_decl_span),
fn_arg_span: Some(self.lower_span(fn_arg_span)),
// Lower this as a `CoroutineClosure`. That will ensure that HIR typeck
// knows that a `FnDecl` output type like `-> &str` actually means
// "coroutine that returns &str", rather than directly returning a `&str`.
kind: hir::ClosureKind::CoroutineClosure(hir::CoroutineDesugaring::Async),
constness: hir::Constness::NotConst,
});
Expand Down
8 changes: 6 additions & 2 deletions compiler/rustc_borrowck/src/type_check/input_output.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
)),
"this needs to be modified if we're lowering non-async closures"
);
// Make sure to use the args from `DefiningTy` so the right NLL region vids are prepopulated
// into the type.
let args = args.as_coroutine_closure();
let tupled_upvars_ty = ty::CoroutineClosureSignature::tupled_upvars_by_closure_kind(
self.tcx(),
Expand All @@ -87,11 +89,13 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
ty::CoroutineArgsParts {
parent_args: args.parent_args(),
kind_ty: Ty::from_closure_kind(self.tcx(), args.kind()),
return_ty: user_provided_sig.output(),
tupled_upvars_ty,
// For async closures, none of these can be annotated, so just fill
// them with fresh ty vars.
resume_ty: next_ty_var(),
yield_ty: next_ty_var(),
witness: next_ty_var(),
return_ty: user_provided_sig.output(),
tupled_upvars_ty: tupled_upvars_ty,
},
)
.args,
Expand Down
8 changes: 8 additions & 0 deletions compiler/rustc_borrowck/src/universal_regions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -720,6 +720,14 @@ impl<'cx, 'tcx> UniversalRegionsBuilder<'cx, 'tcx> {
ty::Binder::dummy(inputs_and_output)
}

// Construct the signature of the CoroutineClosure for the purposes of borrowck.
// This is pretty straightforward -- we:
// 1. first grab the `coroutine_closure_sig`,
// 2. compute the self type (`&`/`&mut`/no borrow),
// 3. flatten the tupled_input_tys,
// 4. construct the correct generator type to return with
// `CoroutineClosureSignature::to_coroutine_given_kind_and_upvars`.
// Then we wrap it all up into a list of inputs and output.
DefiningTy::CoroutineClosure(def_id, args) => {
assert_eq!(self.mir_def.to_def_id(), def_id);
let closure_sig = args.as_coroutine_closure().coroutine_closure_sig();
Expand Down
11 changes: 2 additions & 9 deletions compiler/rustc_hir_analysis/src/collect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1533,15 +1533,8 @@ fn coroutine_kind(tcx: TyCtxt<'_>, def_id: LocalDefId) -> Option<hir::CoroutineK
}

fn coroutine_for_closure(tcx: TyCtxt<'_>, def_id: LocalDefId) -> DefId {
let Node::Expr(&hir::Expr {
kind:
hir::ExprKind::Closure(&rustc_hir::Closure {
kind: hir::ClosureKind::CoroutineClosure(_),
body,
..
}),
..
}) = tcx.hir_node_by_def_id(def_id)
let &rustc_hir::Closure { kind: hir::ClosureKind::CoroutineClosure(_), body, .. } =
tcx.hir_node_by_def_id(def_id).expect_closure()
else {
bug!()
};
Expand Down
9 changes: 9 additions & 0 deletions compiler/rustc_hir_typeck/src/callee.rs
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,11 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
return Some(CallStep::DeferredClosure(def_id, closure_sig));
}

// When calling a `CoroutineClosure` that is local to the body, we will
// not know what its `closure_kind` is yet. Instead, just fill in the
// signature with an infer var for the `tupled_upvars_ty` of the coroutine,
// and record a deferred call resolution which will constrain that var
// as part of `AsyncFn*` trait confirmation.
ty::CoroutineClosure(def_id, args) if self.closure_kind(adjusted_ty).is_none() => {
let def_id = def_id.expect_local();
let closure_args = args.as_coroutine_closure();
Expand All @@ -182,6 +187,10 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
coroutine_closure_sig.to_coroutine(
self.tcx,
closure_args.parent_args(),
// Inherit the kind ty of the closure, since we're calling this
// coroutine with the most relaxed `AsyncFn*` trait that we can.
// We don't necessarily need to do this here, but it saves us
// computing one more infer var that will get constrained later.
closure_args.kind_ty(),
self.tcx.coroutine_for_closure(def_id),
tupled_upvars_ty,
Expand Down
39 changes: 27 additions & 12 deletions compiler/rustc_hir_typeck/src/closure.rs
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,10 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
interior,
));

// Coroutines that come from coroutine closures have not yet determined
// their kind ty, so make a fresh infer var which will be constrained
// later during upvar analysis. Regular coroutines always have the kind
// ty of `().`
let kind_ty = match kind {
hir::CoroutineKind::Desugared(_, hir::CoroutineSource::Closure) => self
.next_ty_var(TypeVariableOrigin {
Expand Down Expand Up @@ -203,6 +207,8 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
)
}
hir::ClosureKind::CoroutineClosure(kind) => {
// async closures always return the type ascribed after the `->` (if present),
// and yield `()`.
let (bound_return_ty, bound_yield_ty) = match kind {
hir::CoroutineDesugaring::Async => {
(bound_sig.skip_binder().output(), tcx.types.unit)
Expand All @@ -211,6 +217,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
todo!("`gen` and `async gen` closures not supported yet")
}
};
// Compute all of the variables that will be used to populate the coroutine.
let resume_ty = self.next_ty_var(TypeVariableOrigin {
kind: TypeVariableOriginKind::ClosureSynthetic,
span: expr_span,
Expand Down Expand Up @@ -258,20 +265,28 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
kind: TypeVariableOriginKind::ClosureSynthetic,
span: expr_span,
});

// We need to turn the liberated signature that we got from HIR, which
// looks something like `|Args...| -> T`, into a signature that is suitable
// for type checking the inner body of the closure, which always returns a
// coroutine. To do so, we use the `CoroutineClosureSignature` to compute
// the coroutine type, filling in the tupled_upvars_ty and kind_ty with infer
// vars which will get constrained during upvar analysis.
let coroutine_output_ty = tcx.liberate_late_bound_regions(
expr_def_id.to_def_id(),
closure_args.coroutine_closure_sig().map_bound(|sig| {
sig.to_coroutine(
tcx,
parent_args,
closure_kind_ty,
tcx.coroutine_for_closure(expr_def_id),
coroutine_upvars_ty,
)
}),
);
liberated_sig = tcx.mk_fn_sig(
liberated_sig.inputs().iter().copied(),
tcx.liberate_late_bound_regions(
expr_def_id.to_def_id(),
closure_args.coroutine_closure_sig().map_bound(|sig| {
sig.to_coroutine(
tcx,
parent_args,
closure_kind_ty,
tcx.coroutine_for_closure(expr_def_id),
coroutine_upvars_ty,
)
}),
),
coroutine_output_ty,
liberated_sig.c_variadic,
liberated_sig.unsafety,
liberated_sig.abi,
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_hir_typeck/src/fn_ctxt/checks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
continue;
}

// For this check, we do *not* want to treat async coroutine-closures (async blocks)
// For this check, we do *not* want to treat async coroutine closures (async blocks)
// as proper closures. Doing so would regress type inference when feeding
// the return value of an argument-position async block to an argument-position
// closure wrapped in a block.
Expand Down
10 changes: 9 additions & 1 deletion compiler/rustc_hir_typeck/src/upvar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,9 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
}
}

// For coroutine-closures, we additionally must compute the
// `coroutine_captures_by_ref_ty` type, which is used to generate the by-ref
// version of the coroutine-closure's output coroutine.
if let UpvarArgs::CoroutineClosure(args) = args {
let closure_env_region: ty::Region<'_> = ty::Region::new_bound(
self.tcx,
Expand All @@ -353,7 +356,8 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
self.tcx.coroutine_for_closure(closure_def_id).expect_local(),
)
// Skip the captures that are just moving the closure's args
// into the coroutine. These are always by move.
// into the coroutine. These are always by move, and we append
// those later in the `CoroutineClosureSignature` helper functions.
.skip(
args.as_coroutine_closure()
.coroutine_closure_sig()
Expand All @@ -365,6 +369,9 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
.map(|captured_place| {
let upvar_ty = captured_place.place.ty();
let capture = captured_place.info.capture_kind;
// Not all upvars are captured by ref, so use
// `apply_capture_kind_on_capture_ty` to ensure that we
// compute the right captured type.
apply_capture_kind_on_capture_ty(
self.tcx,
upvar_ty,
Expand Down Expand Up @@ -394,6 +401,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
coroutine_captures_by_ref_ty,
);

// Additionally, we can now constrain the coroutine's kind type.
let ty::Coroutine(_, coroutine_args) =
*self.typeck_results.borrow().expr_ty(body.value).kind()
else {
Expand Down
6 changes: 3 additions & 3 deletions compiler/rustc_middle/src/ty/flags.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,11 +145,11 @@ impl FlagComputation {
self.flags -= TypeFlags::STILL_FURTHER_SPECIALIZABLE;
}

self.add_ty(args.signature_parts_ty());
self.add_ty(args.coroutine_witness_ty());
self.add_ty(args.coroutine_captures_by_ref_ty());
self.add_ty(args.kind_ty());
self.add_ty(args.signature_parts_ty());
self.add_ty(args.tupled_upvars_ty());
self.add_ty(args.coroutine_captures_by_ref_ty());
self.add_ty(args.coroutine_witness_ty());
}

&ty::Bound(debruijn, _) => {
Expand Down
7 changes: 7 additions & 0 deletions compiler/rustc_middle/src/ty/instance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -646,9 +646,16 @@ impl<'tcx> Instance<'tcx> {
bug!()
};

// If the closure's kind ty disagrees with the identity closure's kind ty,
// then this must be a coroutine generated by one of the `ConstructCoroutineInClosureShim`s.
if args.as_coroutine().kind_ty() == id_args.as_coroutine().kind_ty() {
Some(Instance { def: ty::InstanceDef::Item(coroutine_def_id), args })
} else {
assert_eq!(
args.as_coroutine().kind_ty().to_opt_closure_kind().unwrap(),
ty::ClosureKind::FnOnce,
"FIXME(async_closures): Generate a by-mut body here."
);
Some(Instance {
def: ty::InstanceDef::CoroutineByMoveShim { coroutine_def_id },
args,
Expand Down
44 changes: 38 additions & 6 deletions compiler/rustc_middle/src/ty/sty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,10 @@ pub struct CoroutineClosureArgs<'tcx> {
pub args: GenericArgsRef<'tcx>,
}

/// See docs for explanation of how each argument is used.
///
/// See [`CoroutineClosureSignature`] for how these arguments are put together
/// to make a callable [`FnSig`] suitable for typeck and borrowck.
pub struct CoroutineClosureArgsParts<'tcx> {
/// This is the args of the typeck root.
pub parent_args: &'tcx [GenericArg<'tcx>],
Expand Down Expand Up @@ -485,25 +489,40 @@ pub struct CoroutineClosureSignature<'tcx> {
pub resume_ty: Ty<'tcx>,
pub yield_ty: Ty<'tcx>,
pub return_ty: Ty<'tcx>,

// Like the `fn_sig_as_fn_ptr_ty` of a regular closure, these types
// never actually differ. But we save them rather than recreating them
// from scratch just for good measure.
/// Always false
pub c_variadic: bool,
/// Always [`hir::Unsafety::Normal`]
pub unsafety: hir::Unsafety,
/// Always [`abi::Abi::RustCall`]
pub abi: abi::Abi,
}

impl<'tcx> CoroutineClosureSignature<'tcx> {
/// Construct a coroutine from the closure signature. Since a coroutine signature
/// is agnostic to the type of generator that is returned (by-ref/by-move),
/// the caller must specify what "flavor" of generator that they'd like to
/// create. Additionally, they must manually compute the upvars of the closure.
///
/// This helper is not really meant to be used directly except for early on
/// during typeck, when we want to put inference vars into the kind and upvars tys.
/// When the kind and upvars are known, use the other helper functions.
pub fn to_coroutine(
self,
tcx: TyCtxt<'tcx>,
parent_args: &'tcx [GenericArg<'tcx>],
kind_ty: Ty<'tcx>,
coroutine_kind_ty: Ty<'tcx>,
coroutine_def_id: DefId,
tupled_upvars_ty: Ty<'tcx>,
) -> Ty<'tcx> {
let coroutine_args = ty::CoroutineArgs::new(
tcx,
ty::CoroutineArgsParts {
parent_args,
kind_ty,
kind_ty: coroutine_kind_ty,
resume_ty: self.resume_ty,
yield_ty: self.yield_ty,
return_ty: self.return_ty,
Expand All @@ -515,19 +534,24 @@ impl<'tcx> CoroutineClosureSignature<'tcx> {
Ty::new_coroutine(tcx, coroutine_def_id, coroutine_args.args)
}

/// Given known upvars and a [`ClosureKind`](ty::ClosureKind), compute the coroutine
/// returned by that corresponding async fn trait.
///
/// This function expects the upvars to have been computed already, and doesn't check
/// that the `ClosureKind` is actually supported by the coroutine-closure.
pub fn to_coroutine_given_kind_and_upvars(
self,
tcx: TyCtxt<'tcx>,
parent_args: &'tcx [GenericArg<'tcx>],
coroutine_def_id: DefId,
closure_kind: ty::ClosureKind,
goal_kind: ty::ClosureKind,
env_region: ty::Region<'tcx>,
closure_tupled_upvars_ty: Ty<'tcx>,
coroutine_captures_by_ref_ty: Ty<'tcx>,
) -> Ty<'tcx> {
let tupled_upvars_ty = Self::tupled_upvars_by_closure_kind(
tcx,
closure_kind,
goal_kind,
self.tupled_inputs_ty,
closure_tupled_upvars_ty,
coroutine_captures_by_ref_ty,
Expand All @@ -537,13 +561,21 @@ impl<'tcx> CoroutineClosureSignature<'tcx> {
self.to_coroutine(
tcx,
parent_args,
Ty::from_closure_kind(tcx, closure_kind),
Ty::from_closure_kind(tcx, goal_kind),
coroutine_def_id,
tupled_upvars_ty,
)
}

/// Given a closure kind, compute the tupled upvars that the given coroutine would return.
/// Compute the tupled upvars that a coroutine-closure's output coroutine
/// would return for the given `ClosureKind`.
///
/// When `ClosureKind` is `FnMut`/`Fn`, then this will use the "captures by ref"
/// to return a set of upvars which are borrowed with the given `env_region`.
///
/// This ensures that the `AsyncFn::call` will return a coroutine whose upvars'
/// lifetimes are related to the lifetime of the borrow on the closure made for
/// the call. This allows borrowck to enforce the self-borrows correctly.
pub fn tupled_upvars_by_closure_kind(
tcx: TyCtxt<'tcx>,
kind: ty::ClosureKind,
Expand Down
3 changes: 3 additions & 0 deletions compiler/rustc_mir_dataflow/src/value_analysis.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1149,6 +1149,9 @@ pub fn iter_fields<'tcx>(
ty::Closure(_, args) => {
iter_fields(args.as_closure().tupled_upvars_ty(), tcx, param_env, f);
}
ty::Coroutine(_, args) => {
iter_fields(args.as_coroutine().tupled_upvars_ty(), tcx, param_env, f);
}
ty::CoroutineClosure(_, args) => {
iter_fields(args.as_coroutine_closure().tupled_upvars_ty(), tcx, param_env, f);
}
Expand Down
13 changes: 12 additions & 1 deletion compiler/rustc_mir_transform/src/coroutine/by_move_body.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
//! A MIR pass which duplicates a coroutine's body and removes any derefs which
//! would be present for upvars that are taken by-ref. The result of which will
//! be a coroutine body that takes all of its upvars by-move, and which we stash
//! into the `CoroutineInfo` for all coroutines returned by coroutine-closures.
use rustc_data_structures::fx::FxIndexSet;
use rustc_hir as hir;
use rustc_middle::mir::visit::MutVisitor;
Expand Down Expand Up @@ -87,11 +92,16 @@ impl<'tcx> MutVisitor<'tcx> for MakeByMoveBody<'tcx> {
&& self.by_ref_fields.contains(&idx)
{
let (begin, end) = place.projection[1..].split_first().unwrap();
// FIXME(async_closures): I'm actually a bit surprised to see that we always
// initially deref the by-ref upvars. If this is not actually true, then we
// will at least get an ICE that explains why this isn't true :^)
assert_eq!(*begin, mir::ProjectionElem::Deref);
// Peel one ref off of the ty.
let peeled_ty = ty.builtin_deref(true).unwrap().ty;
*place = mir::Place {
local: place.local,
projection: self.tcx.mk_place_elems_from_iter(
[mir::ProjectionElem::Field(idx, ty.builtin_deref(true).unwrap().ty)]
[mir::ProjectionElem::Field(idx, peeled_ty)]
.into_iter()
.chain(end.iter().copied()),
),
Expand All @@ -101,6 +111,7 @@ impl<'tcx> MutVisitor<'tcx> for MakeByMoveBody<'tcx> {
}

fn visit_local_decl(&mut self, local: mir::Local, local_decl: &mut mir::LocalDecl<'tcx>) {
// Replace the type of the self arg.
if local == ty::CAPTURE_STRUCT_LOCAL {
local_decl.ty = self.by_move_coroutine_ty;
}
Expand Down
Loading

0 comments on commit a7e6d88

Please sign in to comment.