diff --git a/compiler/rustc_trait_selection/src/solve/assembly/structural_traits.rs b/compiler/rustc_trait_selection/src/solve/assembly/structural_traits.rs index f65f5f9c07023..af533d8db7149 100644 --- a/compiler/rustc_trait_selection/src/solve/assembly/structural_traits.rs +++ b/compiler/rustc_trait_selection/src/solve/assembly/structural_traits.rs @@ -305,15 +305,14 @@ pub(in crate::solve) fn extract_tupled_inputs_and_output_from_callable<'tcx>( return Err(NoSolution); } - sig.to_coroutine_given_kind_and_upvars( + coroutine_closure_to_certain_coroutine( tcx, - args.parent_args(), - tcx.coroutine_for_closure(def_id), goal_kind, // No captures by ref, so this doesn't matter. tcx.lifetimes.re_static, - args.tupled_upvars_ty(), - args.coroutine_captures_by_ref_ty(), + def_id, + args, + sig, ) } else { // Closure kind is not yet determined, so we return ambiguity unless @@ -322,33 +321,13 @@ pub(in crate::solve) fn extract_tupled_inputs_and_output_from_callable<'tcx>( return Ok(None); } - let async_fn_kind_trait_def_id = - tcx.require_lang_item(LangItem::AsyncFnKindHelper, None); - let upvars_projection_def_id = tcx - .associated_items(async_fn_kind_trait_def_id) - .filter_by_name_unhygienic(sym::Upvars) - .next() - .unwrap() - .def_id; - let tupled_upvars_ty = Ty::new_projection( - tcx, - upvars_projection_def_id, - [ - ty::GenericArg::from(kind_ty), - Ty::from_closure_kind(tcx, goal_kind).into(), - // No captures by ref, so this doesn't matter. - tcx.lifetimes.re_static.into(), - sig.tupled_inputs_ty.into(), - args.tupled_upvars_ty().into(), - args.coroutine_captures_by_ref_ty().into(), - ], - ); - sig.to_coroutine( + coroutine_closure_to_ambiguous_coroutine( tcx, - args.parent_args(), - Ty::from_closure_kind(tcx, goal_kind), - tcx.coroutine_for_closure(def_id), - tupled_upvars_ty, + goal_kind, // No captures by ref, so this doesn't matter. + tcx.lifetimes.re_static, + def_id, + args, + sig, ) }; @@ -385,6 +364,19 @@ pub(in crate::solve) fn extract_tupled_inputs_and_output_from_callable<'tcx>( } } +/// Relevant types for an async callable, including its inputs, output, +/// and the return type you get from awaiting the output. +#[derive(Copy, Clone, Debug, TypeVisitable, TypeFoldable)] +pub(in crate::solve) struct AsyncCallableRelevantTypes<'tcx> { + pub tupled_inputs_ty: Ty<'tcx>, + /// Type returned by calling the closure + /// i.e. `f()`. + pub output_coroutine_ty: Ty<'tcx>, + /// Type returned by `await`ing the output + /// i.e. `f().await`. + pub coroutine_return_ty: Ty<'tcx>, +} + // Returns a binder of the tupled inputs types, output type, and coroutine type // from a builtin coroutine-closure type. If we don't yet know the closure kind of // the coroutine-closure, emit an additional trait predicate for `AsyncFnKindHelper` @@ -395,8 +387,10 @@ pub(in crate::solve) fn extract_tupled_inputs_and_output_from_async_callable<'tc self_ty: Ty<'tcx>, goal_kind: ty::ClosureKind, env_region: ty::Region<'tcx>, -) -> Result<(ty::Binder<'tcx, (Ty<'tcx>, Ty<'tcx>, Ty<'tcx>)>, Vec>), NoSolution> -{ +) -> Result< + (ty::Binder<'tcx, AsyncCallableRelevantTypes<'tcx>>, Vec>), + NoSolution, +> { match *self_ty.kind() { ty::CoroutineClosure(def_id, args) => { let args = args.as_coroutine_closure(); @@ -407,24 +401,11 @@ pub(in crate::solve) fn extract_tupled_inputs_and_output_from_async_callable<'tc if !closure_kind.extends(goal_kind) { return Err(NoSolution); } - sig.to_coroutine_given_kind_and_upvars( - tcx, - args.parent_args(), - tcx.coroutine_for_closure(def_id), - goal_kind, - env_region, - args.tupled_upvars_ty(), - args.coroutine_captures_by_ref_ty(), + + coroutine_closure_to_certain_coroutine( + tcx, goal_kind, env_region, def_id, args, sig, ) } else { - let async_fn_kind_trait_def_id = - tcx.require_lang_item(LangItem::AsyncFnKindHelper, None); - let upvars_projection_def_id = tcx - .associated_items(async_fn_kind_trait_def_id) - .filter_by_name_unhygienic(sym::Upvars) - .next() - .unwrap() - .def_id; // When we don't know the closure kind (and therefore also the closure's upvars, // which are computed at the same time), we must delay the computation of the // generator's upvars. We do this using the `AsyncFnKindHelper`, which as a trait @@ -435,38 +416,23 @@ pub(in crate::solve) fn extract_tupled_inputs_and_output_from_async_callable<'tc nested.push( ty::TraitRef::new( tcx, - async_fn_kind_trait_def_id, + tcx.require_lang_item(LangItem::AsyncFnKindHelper, None), [kind_ty, Ty::from_closure_kind(tcx, goal_kind)], ) .to_predicate(tcx), ); - let tupled_upvars_ty = Ty::new_projection( - tcx, - upvars_projection_def_id, - [ - ty::GenericArg::from(kind_ty), - Ty::from_closure_kind(tcx, goal_kind).into(), - env_region.into(), - sig.tupled_inputs_ty.into(), - args.tupled_upvars_ty().into(), - args.coroutine_captures_by_ref_ty().into(), - ], - ); - sig.to_coroutine( - tcx, - args.parent_args(), - Ty::from_closure_kind(tcx, goal_kind), - tcx.coroutine_for_closure(def_id), - tupled_upvars_ty, + + coroutine_closure_to_ambiguous_coroutine( + tcx, goal_kind, env_region, def_id, args, sig, ) }; Ok(( - args.coroutine_closure_sig().rebind(( - sig.tupled_inputs_ty, - sig.return_ty, - coroutine_ty, - )), + args.coroutine_closure_sig().rebind(AsyncCallableRelevantTypes { + tupled_inputs_ty: sig.tupled_inputs_ty, + output_coroutine_ty: coroutine_ty, + coroutine_return_ty: sig.return_ty, + }), nested, )) } @@ -490,7 +456,11 @@ pub(in crate::solve) fn extract_tupled_inputs_and_output_from_async_callable<'tc .def_id; let future_output_ty = Ty::new_projection(tcx, future_output_def_id, [sig.output()]); Ok(( - bound_sig.rebind((Ty::new_tup(tcx, sig.inputs()), sig.output(), future_output_ty)), + bound_sig.rebind(AsyncCallableRelevantTypes { + tupled_inputs_ty: Ty::new_tup(tcx, sig.inputs()), + output_coroutine_ty: sig.output(), + coroutine_return_ty: future_output_ty, + }), nested, )) } @@ -541,7 +511,14 @@ pub(in crate::solve) fn extract_tupled_inputs_and_output_from_async_callable<'tc .unwrap() .def_id; let future_output_ty = Ty::new_projection(tcx, future_output_def_id, [sig.output()]); - Ok((bound_sig.rebind((sig.inputs()[0], sig.output(), future_output_ty)), nested)) + Ok(( + bound_sig.rebind(AsyncCallableRelevantTypes { + tupled_inputs_ty: sig.inputs()[0], + output_coroutine_ty: sig.output(), + coroutine_return_ty: future_output_ty, + }), + nested, + )) } ty::Bool @@ -574,6 +551,68 @@ pub(in crate::solve) fn extract_tupled_inputs_and_output_from_async_callable<'tc } } +/// Given a coroutine-closure, project to its returned coroutine when we are *certain* +/// that the closure's kind is compatible with the goal. +fn coroutine_closure_to_certain_coroutine<'tcx>( + tcx: TyCtxt<'tcx>, + goal_kind: ty::ClosureKind, + goal_region: ty::Region<'tcx>, + def_id: DefId, + args: ty::CoroutineClosureArgs<'tcx>, + sig: ty::CoroutineClosureSignature<'tcx>, +) -> Ty<'tcx> { + sig.to_coroutine_given_kind_and_upvars( + tcx, + args.parent_args(), + tcx.coroutine_for_closure(def_id), + goal_kind, + goal_region, + args.tupled_upvars_ty(), + args.coroutine_captures_by_ref_ty(), + ) +} + +/// Given a coroutine-closure, project to its returned coroutine when we are *not certain* +/// that the closure's kind is compatible with the goal, and therefore also don't know +/// yet what the closure's upvars are. +/// +/// Note that we do not also push a `AsyncFnKindHelper` goal here. +fn coroutine_closure_to_ambiguous_coroutine<'tcx>( + tcx: TyCtxt<'tcx>, + goal_kind: ty::ClosureKind, + goal_region: ty::Region<'tcx>, + def_id: DefId, + args: ty::CoroutineClosureArgs<'tcx>, + sig: ty::CoroutineClosureSignature<'tcx>, +) -> Ty<'tcx> { + let async_fn_kind_trait_def_id = tcx.require_lang_item(LangItem::AsyncFnKindHelper, None); + let upvars_projection_def_id = tcx + .associated_items(async_fn_kind_trait_def_id) + .filter_by_name_unhygienic(sym::Upvars) + .next() + .unwrap() + .def_id; + let tupled_upvars_ty = Ty::new_projection( + tcx, + upvars_projection_def_id, + [ + ty::GenericArg::from(args.kind_ty()), + Ty::from_closure_kind(tcx, goal_kind).into(), + goal_region.into(), + sig.tupled_inputs_ty.into(), + args.tupled_upvars_ty().into(), + args.coroutine_captures_by_ref_ty().into(), + ], + ); + sig.to_coroutine( + tcx, + args.parent_args(), + Ty::from_closure_kind(tcx, goal_kind), + tcx.coroutine_for_closure(def_id), + tupled_upvars_ty, + ) +} + /// Assemble a list of predicates that would be present on a theoretical /// user impl for an object type. These predicates must be checked any time /// we assemble a built-in object candidate for an object type, since they diff --git a/compiler/rustc_trait_selection/src/solve/normalizes_to/mod.rs b/compiler/rustc_trait_selection/src/solve/normalizes_to/mod.rs index aa8cc3667cd7b..3aba5c85abc3a 100644 --- a/compiler/rustc_trait_selection/src/solve/normalizes_to/mod.rs +++ b/compiler/rustc_trait_selection/src/solve/normalizes_to/mod.rs @@ -1,5 +1,6 @@ use crate::traits::{check_args_compatible, specialization_graph}; +use super::assembly::structural_traits::AsyncCallableRelevantTypes; use super::assembly::{self, structural_traits, Candidate}; use super::{EvalCtxt, GoalSource}; use rustc_hir::def::DefKind; @@ -392,46 +393,56 @@ impl<'tcx> assembly::GoalKind<'tcx> for NormalizesTo<'tcx> { goal_kind, env_region, )?; - let output_is_sized_pred = - tupled_inputs_and_output_and_coroutine.map_bound(|(_, output, _)| { - ty::TraitRef::from_lang_item(tcx, LangItem::Sized, DUMMY_SP, [output]) - }); + let output_is_sized_pred = tupled_inputs_and_output_and_coroutine.map_bound( + |AsyncCallableRelevantTypes { output_coroutine_ty: output_ty, .. }| { + ty::TraitRef::from_lang_item(tcx, LangItem::Sized, DUMMY_SP, [output_ty]) + }, + ); let pred = tupled_inputs_and_output_and_coroutine - .map_bound(|(inputs, output, coroutine)| { - let (projection_ty, term) = match tcx.item_name(goal.predicate.def_id()) { - sym::CallOnceFuture => ( - ty::AliasTy::new( - tcx, - goal.predicate.def_id(), - [goal.predicate.self_ty(), inputs], + .map_bound( + |AsyncCallableRelevantTypes { + tupled_inputs_ty, + output_coroutine_ty, + coroutine_return_ty, + }| { + let (projection_ty, term) = match tcx.item_name(goal.predicate.def_id()) { + sym::CallOnceFuture => ( + ty::AliasTy::new( + tcx, + goal.predicate.def_id(), + [goal.predicate.self_ty(), tupled_inputs_ty], + ), + output_coroutine_ty.into(), ), - coroutine.into(), - ), - sym::CallMutFuture | sym::CallFuture => ( - ty::AliasTy::new( - tcx, - goal.predicate.def_id(), - [ - ty::GenericArg::from(goal.predicate.self_ty()), - inputs.into(), - env_region.into(), - ], + sym::CallMutFuture | sym::CallFuture => ( + ty::AliasTy::new( + tcx, + goal.predicate.def_id(), + [ + ty::GenericArg::from(goal.predicate.self_ty()), + tupled_inputs_ty.into(), + env_region.into(), + ], + ), + output_coroutine_ty.into(), ), - coroutine.into(), - ), - sym::Output => ( - ty::AliasTy::new( - tcx, - goal.predicate.def_id(), - [ty::GenericArg::from(goal.predicate.self_ty()), inputs.into()], + sym::Output => ( + ty::AliasTy::new( + tcx, + goal.predicate.def_id(), + [ + ty::GenericArg::from(goal.predicate.self_ty()), + tupled_inputs_ty.into(), + ], + ), + coroutine_return_ty.into(), ), - output.into(), - ), - name => bug!("no such associated type: {name}"), - }; - ty::ProjectionPredicate { projection_ty, term } - }) + name => bug!("no such associated type: {name}"), + }; + ty::ProjectionPredicate { projection_ty, term } + }, + ) .to_predicate(tcx); // A built-in `AsyncFn` impl only holds if the output is sized. diff --git a/compiler/rustc_trait_selection/src/solve/trait_goals.rs b/compiler/rustc_trait_selection/src/solve/trait_goals.rs index 73bf66f66890f..eba6ba3f7b062 100644 --- a/compiler/rustc_trait_selection/src/solve/trait_goals.rs +++ b/compiler/rustc_trait_selection/src/solve/trait_goals.rs @@ -2,6 +2,7 @@ use crate::traits::supertrait_def_ids; +use super::assembly::structural_traits::AsyncCallableRelevantTypes; use super::assembly::{self, structural_traits, Candidate}; use super::{EvalCtxt, GoalSource, SolverMode}; use rustc_data_structures::fx::FxIndexSet; @@ -327,14 +328,19 @@ impl<'tcx> assembly::GoalKind<'tcx> for TraitPredicate<'tcx> { // This region doesn't matter because we're throwing away the coroutine type tcx.lifetimes.re_static, )?; - let output_is_sized_pred = - tupled_inputs_and_output_and_coroutine.map_bound(|(_, output, _)| { - ty::TraitRef::from_lang_item(tcx, LangItem::Sized, DUMMY_SP, [output]) - }); + let output_is_sized_pred = tupled_inputs_and_output_and_coroutine.map_bound( + |AsyncCallableRelevantTypes { output_coroutine_ty, .. }| { + ty::TraitRef::from_lang_item(tcx, LangItem::Sized, DUMMY_SP, [output_coroutine_ty]) + }, + ); let pred = tupled_inputs_and_output_and_coroutine - .map_bound(|(inputs, _, _)| { - ty::TraitRef::new(tcx, goal.predicate.def_id(), [goal.predicate.self_ty(), inputs]) + .map_bound(|AsyncCallableRelevantTypes { tupled_inputs_ty, .. }| { + ty::TraitRef::new( + tcx, + goal.predicate.def_id(), + [goal.predicate.self_ty(), tupled_inputs_ty], + ) }) .to_predicate(tcx); // A built-in `AsyncFn` impl only holds if the output is sized.