diff --git a/compiler/rustc_hir_typeck/src/check.rs b/compiler/rustc_hir_typeck/src/check.rs index 871fde81cce8c..b40e495b1698c 100644 --- a/compiler/rustc_hir_typeck/src/check.rs +++ b/compiler/rustc_hir_typeck/src/check.rs @@ -94,7 +94,7 @@ pub(super) fn check_fn<'a, 'tcx>( // Resume type defaults to `()` if the coroutine has no argument. let resume_ty = fn_sig.inputs().get(0).copied().unwrap_or_else(|| Ty::new_unit(tcx)); - fcx.resume_yield_tys = Some((resume_ty, yield_ty)); + fcx.coroutine_types = Some(CoroutineTypes { resume_ty, yield_ty }); } GatherLocalsVisitor::new(fcx).visit_body(body); @@ -146,20 +146,6 @@ pub(super) fn check_fn<'a, 'tcx>( fcx.require_type_is_sized(declared_ret_ty, return_or_body_span, traits::SizedReturnType); fcx.check_return_expr(body.value, false); - // We insert the deferred_coroutine_interiors entry after visiting the body. - // This ensures that all nested coroutines appear before the entry of this coroutine. - // resolve_coroutine_interiors relies on this property. - let coroutine_ty = if let Some(hir::ClosureKind::Coroutine(_)) = closure_kind { - let interior = fcx - .next_ty_var(TypeVariableOrigin { kind: TypeVariableOriginKind::MiscVariable, span }); - fcx.deferred_coroutine_interiors.borrow_mut().push((fn_def_id, body.id(), interior)); - - let (resume_ty, yield_ty) = fcx.resume_yield_tys.unwrap(); - Some(CoroutineTypes { resume_ty, yield_ty, interior }) - } else { - None - }; - // Finalize the return check by taking the LUB of the return types // we saw and assigning it to the expected return type. This isn't // really expected to fail, since the coercions would have failed @@ -195,7 +181,7 @@ pub(super) fn check_fn<'a, 'tcx>( check_lang_start_fn(tcx, fn_sig, fn_def_id); } - coroutine_ty + fcx.coroutine_types } fn check_panic_info_fn(tcx: TyCtxt<'_>, fn_id: LocalDefId, fn_sig: ty::FnSig<'_>) { diff --git a/compiler/rustc_hir_typeck/src/closure.rs b/compiler/rustc_hir_typeck/src/closure.rs index bf6fda20df83e..8ebda50d334c3 100644 --- a/compiler/rustc_hir_typeck/src/closure.rs +++ b/compiler/rustc_hir_typeck/src/closure.rs @@ -105,59 +105,76 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { span: self.tcx.def_span(expr_def_id), }); - if let Some(CoroutineTypes { resume_ty, yield_ty, interior }) = coroutine_types { - let coroutine_args = ty::CoroutineArgs::new( - self.tcx, - ty::CoroutineArgsParts { - parent_args, - resume_ty, - yield_ty, - return_ty: liberated_sig.output(), - witness: interior, - tupled_upvars_ty, - }, - ); - - return Ty::new_coroutine(self.tcx, expr_def_id.to_def_id(), coroutine_args.args); - } - - // Tuple up the arguments and insert the resulting function type into - // the `closures` table. - let sig = bound_sig.map_bound(|sig| { - self.tcx.mk_fn_sig( - [Ty::new_tup(self.tcx, sig.inputs())], - sig.output(), - sig.c_variadic, - sig.unsafety, - sig.abi, - ) - }); - - debug!(?sig, ?opt_kind); - - let closure_kind_ty = match opt_kind { - Some(kind) => Ty::from_closure_kind(self.tcx, kind), + match closure.kind { + hir::ClosureKind::Closure => { + assert_eq!(coroutine_types, None); + // Tuple up the arguments and insert the resulting function type into + // the `closures` table. + let sig = bound_sig.map_bound(|sig| { + self.tcx.mk_fn_sig( + [Ty::new_tup(self.tcx, sig.inputs())], + sig.output(), + sig.c_variadic, + sig.unsafety, + sig.abi, + ) + }); - // Create a type variable (for now) to represent the closure kind. - // It will be unified during the upvar inference phase (`upvar.rs`) - None => self.next_root_ty_var(TypeVariableOrigin { - // FIXME(eddyb) distinguish closure kind inference variables from the rest. - kind: TypeVariableOriginKind::ClosureSynthetic, - span: expr_span, - }), - }; + debug!(?sig, ?opt_kind); + + let closure_kind_ty = match opt_kind { + Some(kind) => Ty::from_closure_kind(self.tcx, kind), + + // Create a type variable (for now) to represent the closure kind. + // It will be unified during the upvar inference phase (`upvar.rs`) + None => self.next_root_ty_var(TypeVariableOrigin { + // FIXME(eddyb) distinguish closure kind inference variables from the rest. + kind: TypeVariableOriginKind::ClosureSynthetic, + span: expr_span, + }), + }; + + let closure_args = ty::ClosureArgs::new( + self.tcx, + ty::ClosureArgsParts { + parent_args, + closure_kind_ty, + closure_sig_as_fn_ptr_ty: Ty::new_fn_ptr(self.tcx, sig), + tupled_upvars_ty, + }, + ); - let closure_args = ty::ClosureArgs::new( - self.tcx, - ty::ClosureArgsParts { - parent_args, - closure_kind_ty, - closure_sig_as_fn_ptr_ty: Ty::new_fn_ptr(self.tcx, sig), - tupled_upvars_ty, - }, - ); + Ty::new_closure(self.tcx, expr_def_id.to_def_id(), closure_args.args) + } + hir::ClosureKind::Coroutine(_) => { + let Some(CoroutineTypes { resume_ty, yield_ty }) = coroutine_types else { + bug!("expected coroutine to have yield/resume types"); + }; + let interior = fcx.next_ty_var(TypeVariableOrigin { + kind: TypeVariableOriginKind::MiscVariable, + span: body.value.span, + }); + fcx.deferred_coroutine_interiors.borrow_mut().push(( + expr_def_id, + body.id(), + interior, + )); + + let coroutine_args = ty::CoroutineArgs::new( + self.tcx, + ty::CoroutineArgsParts { + parent_args, + resume_ty, + yield_ty, + return_ty: liberated_sig.output(), + witness: interior, + tupled_upvars_ty, + }, + ); - Ty::new_closure(self.tcx, expr_def_id.to_def_id(), closure_args.args) + Ty::new_coroutine(self.tcx, expr_def_id.to_def_id(), coroutine_args.args) + } + } } /// Given the expected type, figures out what it can about this closure we diff --git a/compiler/rustc_hir_typeck/src/expr.rs b/compiler/rustc_hir_typeck/src/expr.rs index 7d753216534b1..d1d0c070c7c9f 100644 --- a/compiler/rustc_hir_typeck/src/expr.rs +++ b/compiler/rustc_hir_typeck/src/expr.rs @@ -15,6 +15,7 @@ use crate::errors::{ use crate::fatally_break_rust; use crate::method::SelfSource; use crate::type_error_struct; +use crate::CoroutineTypes; use crate::Expectation::{self, ExpectCastableToType, ExpectHasType, NoExpectation}; use crate::{ report_unexpected_variant_res, BreakableCtxt, Diverges, FnCtxt, Needs, @@ -3164,8 +3165,8 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { expr: &'tcx hir::Expr<'tcx>, src: &'tcx hir::YieldSource, ) -> Ty<'tcx> { - match self.resume_yield_tys { - Some((resume_ty, yield_ty)) => { + match self.coroutine_types { + Some(CoroutineTypes { resume_ty, yield_ty }) => { self.check_expr_coercible_to_type(value, yield_ty, None); resume_ty diff --git a/compiler/rustc_hir_typeck/src/fn_ctxt/mod.rs b/compiler/rustc_hir_typeck/src/fn_ctxt/mod.rs index 635284c5f7323..fde3d41faecf0 100644 --- a/compiler/rustc_hir_typeck/src/fn_ctxt/mod.rs +++ b/compiler/rustc_hir_typeck/src/fn_ctxt/mod.rs @@ -5,7 +5,7 @@ mod checks; mod suggestions; use crate::coercion::DynamicCoerceMany; -use crate::{Diverges, EnclosingBreakables, Inherited}; +use crate::{CoroutineTypes, Diverges, EnclosingBreakables, Inherited}; use rustc_errors::{DiagCtxt, ErrorGuaranteed}; use rustc_hir as hir; use rustc_hir::def_id::{DefId, LocalDefId}; @@ -68,7 +68,7 @@ pub struct FnCtxt<'a, 'tcx> { /// First span of a return site that we find. Used in error messages. pub(super) ret_coercion_span: Cell>, - pub(super) resume_yield_tys: Option<(Ty<'tcx>, Ty<'tcx>)>, + pub(super) coroutine_types: Option>, /// Whether the last checked node generates a divergence (e.g., /// `return` will set this to `Always`). In general, when entering @@ -122,7 +122,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { err_count_on_creation: inh.tcx.dcx().err_count(), ret_coercion: None, ret_coercion_span: Cell::new(None), - resume_yield_tys: None, + coroutine_types: None, diverges: Cell::new(Diverges::Maybe), enclosing_breakables: RefCell::new(EnclosingBreakables { stack: Vec::new(), diff --git a/compiler/rustc_hir_typeck/src/lib.rs b/compiler/rustc_hir_typeck/src/lib.rs index da9a2bde783b8..20ef0fab43b07 100644 --- a/compiler/rustc_hir_typeck/src/lib.rs +++ b/compiler/rustc_hir_typeck/src/lib.rs @@ -295,15 +295,13 @@ fn typeck_with_fallback<'tcx>( /// When `check_fn` is invoked on a coroutine (i.e., a body that /// includes yield), it returns back some information about the yield /// points. +#[derive(Debug, PartialEq, Copy, Clone)] struct CoroutineTypes<'tcx> { /// Type of coroutine argument / values returned by `yield`. resume_ty: Ty<'tcx>, /// Type of value that is yielded. yield_ty: Ty<'tcx>, - - /// Types that are captured (see `CoroutineInterior` for more). - interior: Ty<'tcx>, } #[derive(Copy, Clone, Debug, PartialEq, Eq)]