From 847fd88df724280880c705848ba1a120ce15e020 Mon Sep 17 00:00:00 2001
From: Michael Goulet <michael@errs.io>
Date: Sun, 24 Mar 2024 20:06:05 -0400
Subject: [PATCH 1/3] Always use tcx.coroutine_layout over calling
 optimized_mir directly

---
 .../src/debuginfo/metadata/enums/cpp_like.rs                   | 2 +-
 .../rustc_codegen_llvm/src/debuginfo/metadata/enums/native.rs  | 3 +--
 compiler/rustc_ty_utils/src/layout.rs                          | 2 +-
 3 files changed, 3 insertions(+), 4 deletions(-)

diff --git a/compiler/rustc_codegen_llvm/src/debuginfo/metadata/enums/cpp_like.rs b/compiler/rustc_codegen_llvm/src/debuginfo/metadata/enums/cpp_like.rs
index 4792b0798dfb8..9a10e1bc1299c 100644
--- a/compiler/rustc_codegen_llvm/src/debuginfo/metadata/enums/cpp_like.rs
+++ b/compiler/rustc_codegen_llvm/src/debuginfo/metadata/enums/cpp_like.rs
@@ -683,7 +683,7 @@ fn build_union_fields_for_direct_tag_coroutine<'ll, 'tcx>(
         _ => unreachable!(),
     };
 
-    let coroutine_layout = cx.tcx.optimized_mir(coroutine_def_id).coroutine_layout().unwrap();
+    let coroutine_layout = cx.tcx.coroutine_layout(coroutine_def_id).unwrap();
 
     let common_upvar_names = cx.tcx.closure_saved_names_of_captured_variables(coroutine_def_id);
     let variant_range = coroutine_args.variant_range(coroutine_def_id, cx.tcx);
diff --git a/compiler/rustc_codegen_llvm/src/debuginfo/metadata/enums/native.rs b/compiler/rustc_codegen_llvm/src/debuginfo/metadata/enums/native.rs
index 3dbe820b8ff9b..f0f55981760cb 100644
--- a/compiler/rustc_codegen_llvm/src/debuginfo/metadata/enums/native.rs
+++ b/compiler/rustc_codegen_llvm/src/debuginfo/metadata/enums/native.rs
@@ -158,8 +158,7 @@ pub(super) fn build_coroutine_di_node<'ll, 'tcx>(
             DIFlags::FlagZero,
         ),
         |cx, coroutine_type_di_node| {
-            let coroutine_layout =
-                cx.tcx.optimized_mir(coroutine_def_id).coroutine_layout().unwrap();
+            let coroutine_layout = cx.tcx.coroutine_layout(coroutine_def_id).unwrap();
 
             let Variants::Multiple { tag_encoding: TagEncoding::Direct, ref variants, .. } =
                 coroutine_type_and_layout.variants
diff --git a/compiler/rustc_ty_utils/src/layout.rs b/compiler/rustc_ty_utils/src/layout.rs
index 9c3d39307b26f..85ac6071a08cc 100644
--- a/compiler/rustc_ty_utils/src/layout.rs
+++ b/compiler/rustc_ty_utils/src/layout.rs
@@ -1072,7 +1072,7 @@ fn variant_info_for_coroutine<'tcx>(
         return (vec![], None);
     };
 
-    let coroutine = cx.tcx.optimized_mir(def_id).coroutine_layout().unwrap();
+    let coroutine = cx.tcx.coroutine_layout(def_id).unwrap();
     let upvar_names = cx.tcx.closure_saved_names_of_captured_variables(def_id);
 
     let mut upvars_size = Size::ZERO;

From b7d67eace78d5e660df93b513326650fe8226a96 Mon Sep 17 00:00:00 2001
From: Michael Goulet <michael@errs.io>
Date: Sun, 24 Mar 2024 21:12:49 -0400
Subject: [PATCH 2/3] Require coroutine kind type to be passed to
 TyCtxt::coroutine_layout

---
 .../src/debuginfo/metadata/enums/cpp_like.rs  |  3 +-
 .../src/debuginfo/metadata/enums/native.rs    |  7 +++-
 .../src/transform/validate.rs                 | 13 ++++---
 compiler/rustc_middle/src/mir/mod.rs          |  3 +-
 compiler/rustc_middle/src/mir/pretty.rs       |  2 +-
 compiler/rustc_middle/src/ty/mod.rs           | 37 ++++++++++++++++++-
 compiler/rustc_middle/src/ty/sty.rs           |  7 +++-
 compiler/rustc_ty_utils/src/layout.rs         |  4 +-
 8 files changed, 59 insertions(+), 17 deletions(-)

diff --git a/compiler/rustc_codegen_llvm/src/debuginfo/metadata/enums/cpp_like.rs b/compiler/rustc_codegen_llvm/src/debuginfo/metadata/enums/cpp_like.rs
index 9a10e1bc1299c..4edef14422e5f 100644
--- a/compiler/rustc_codegen_llvm/src/debuginfo/metadata/enums/cpp_like.rs
+++ b/compiler/rustc_codegen_llvm/src/debuginfo/metadata/enums/cpp_like.rs
@@ -683,7 +683,8 @@ fn build_union_fields_for_direct_tag_coroutine<'ll, 'tcx>(
         _ => unreachable!(),
     };
 
-    let coroutine_layout = cx.tcx.coroutine_layout(coroutine_def_id).unwrap();
+    let coroutine_layout =
+        cx.tcx.coroutine_layout(coroutine_def_id, coroutine_args.kind_ty()).unwrap();
 
     let common_upvar_names = cx.tcx.closure_saved_names_of_captured_variables(coroutine_def_id);
     let variant_range = coroutine_args.variant_range(coroutine_def_id, cx.tcx);
diff --git a/compiler/rustc_codegen_llvm/src/debuginfo/metadata/enums/native.rs b/compiler/rustc_codegen_llvm/src/debuginfo/metadata/enums/native.rs
index f0f55981760cb..115d5187eafa8 100644
--- a/compiler/rustc_codegen_llvm/src/debuginfo/metadata/enums/native.rs
+++ b/compiler/rustc_codegen_llvm/src/debuginfo/metadata/enums/native.rs
@@ -135,7 +135,7 @@ pub(super) fn build_coroutine_di_node<'ll, 'tcx>(
     unique_type_id: UniqueTypeId<'tcx>,
 ) -> DINodeCreationResult<'ll> {
     let coroutine_type = unique_type_id.expect_ty();
-    let &ty::Coroutine(coroutine_def_id, _) = coroutine_type.kind() else {
+    let &ty::Coroutine(coroutine_def_id, coroutine_args) = coroutine_type.kind() else {
         bug!("build_coroutine_di_node() called with non-coroutine type: `{:?}`", coroutine_type)
     };
 
@@ -158,7 +158,10 @@ pub(super) fn build_coroutine_di_node<'ll, 'tcx>(
             DIFlags::FlagZero,
         ),
         |cx, coroutine_type_di_node| {
-            let coroutine_layout = cx.tcx.coroutine_layout(coroutine_def_id).unwrap();
+            let coroutine_layout = cx
+                .tcx
+                .coroutine_layout(coroutine_def_id, coroutine_args.as_coroutine().kind_ty())
+                .unwrap();
 
             let Variants::Multiple { tag_encoding: TagEncoding::Direct, ref variants, .. } =
                 coroutine_type_and_layout.variants
diff --git a/compiler/rustc_const_eval/src/transform/validate.rs b/compiler/rustc_const_eval/src/transform/validate.rs
index 08e3e42a82e27..b085e4e76a11e 100644
--- a/compiler/rustc_const_eval/src/transform/validate.rs
+++ b/compiler/rustc_const_eval/src/transform/validate.rs
@@ -101,9 +101,9 @@ impl<'tcx> MirPass<'tcx> for Validator {
         }
 
         // Enforce that coroutine-closure layouts are identical.
-        if let Some(layout) = body.coroutine_layout()
+        if let Some(layout) = body.coroutine_layout_raw()
             && let Some(by_move_body) = body.coroutine_by_move_body()
-            && let Some(by_move_layout) = by_move_body.coroutine_layout()
+            && let Some(by_move_layout) = by_move_body.coroutine_layout_raw()
         {
             if layout != by_move_layout {
                 // If this turns out not to be true, please let compiler-errors know.
@@ -715,13 +715,14 @@ impl<'a, 'tcx> Visitor<'tcx> for TypeChecker<'a, 'tcx> {
                             // args of the coroutine. Otherwise, we prefer to use this body
                             // since we may be in the process of computing this MIR in the
                             // first place.
-                            let gen_body = if def_id == self.caller_body.source.def_id() {
-                                self.caller_body
+                            let layout = if def_id == self.caller_body.source.def_id() {
+                                // FIXME: This is not right for async closures.
+                                self.caller_body.coroutine_layout_raw()
                             } else {
-                                self.tcx.optimized_mir(def_id)
+                                self.tcx.coroutine_layout(def_id, args.as_coroutine().kind_ty())
                             };
 
-                            let Some(layout) = gen_body.coroutine_layout() else {
+                            let Some(layout) = layout else {
                                 self.fail(
                                     location,
                                     format!("No coroutine layout for {parent_ty:?}"),
diff --git a/compiler/rustc_middle/src/mir/mod.rs b/compiler/rustc_middle/src/mir/mod.rs
index e4dce2bdc9e80..02af55fbf0e4f 100644
--- a/compiler/rustc_middle/src/mir/mod.rs
+++ b/compiler/rustc_middle/src/mir/mod.rs
@@ -652,8 +652,9 @@ impl<'tcx> Body<'tcx> {
         self.coroutine.as_ref().and_then(|coroutine| coroutine.resume_ty)
     }
 
+    /// Prefer going through [`TyCtxt::coroutine_layout`] rather than using this directly.
     #[inline]
-    pub fn coroutine_layout(&self) -> Option<&CoroutineLayout<'tcx>> {
+    pub fn coroutine_layout_raw(&self) -> Option<&CoroutineLayout<'tcx>> {
         self.coroutine.as_ref().and_then(|coroutine| coroutine.coroutine_layout.as_ref())
     }
 
diff --git a/compiler/rustc_middle/src/mir/pretty.rs b/compiler/rustc_middle/src/mir/pretty.rs
index f0499cf344fca..41df2e3b5875a 100644
--- a/compiler/rustc_middle/src/mir/pretty.rs
+++ b/compiler/rustc_middle/src/mir/pretty.rs
@@ -126,7 +126,7 @@ fn dump_matched_mir_node<'tcx, F>(
             Some(promoted) => write!(file, "::{promoted:?}`")?,
         }
         writeln!(file, " {disambiguator} {pass_name}")?;
-        if let Some(ref layout) = body.coroutine_layout() {
+        if let Some(ref layout) = body.coroutine_layout_raw() {
             writeln!(file, "/* coroutine_layout = {layout:#?} */")?;
         }
         writeln!(file)?;
diff --git a/compiler/rustc_middle/src/ty/mod.rs b/compiler/rustc_middle/src/ty/mod.rs
index 6ce53ccc8cd7a..aad2f6a4cf8ac 100644
--- a/compiler/rustc_middle/src/ty/mod.rs
+++ b/compiler/rustc_middle/src/ty/mod.rs
@@ -60,6 +60,7 @@ pub use rustc_target::abi::{ReprFlags, ReprOptions};
 pub use rustc_type_ir::{DebugWithInfcx, InferCtxtLike, WithInfcx};
 pub use vtable::*;
 
+use std::assert_matches::assert_matches;
 use std::fmt::Debug;
 use std::hash::{Hash, Hasher};
 use std::marker::PhantomData;
@@ -1826,8 +1827,40 @@ impl<'tcx> TyCtxt<'tcx> {
 
     /// Returns layout of a coroutine. Layout might be unavailable if the
     /// coroutine is tainted by errors.
-    pub fn coroutine_layout(self, def_id: DefId) -> Option<&'tcx CoroutineLayout<'tcx>> {
-        self.optimized_mir(def_id).coroutine_layout()
+    ///
+    /// Takes `coroutine_kind` which can be acquired from the `CoroutineArgs::kind_ty`,
+    /// e.g. `args.as_coroutine().kind_ty()`.
+    pub fn coroutine_layout(
+        self,
+        def_id: DefId,
+        coroutine_kind_ty: Ty<'tcx>,
+    ) -> Option<&'tcx CoroutineLayout<'tcx>> {
+        let mir = self.optimized_mir(def_id);
+        // Regular coroutine
+        if coroutine_kind_ty.is_unit() {
+            mir.coroutine_layout_raw()
+        } else {
+            // If we have a `Coroutine` that comes from an coroutine-closure,
+            // then it may be a by-move or by-ref body.
+            let ty::Coroutine(_, identity_args) =
+                *self.type_of(def_id).instantiate_identity().kind()
+            else {
+                unreachable!();
+            };
+            let identity_kind_ty = identity_args.as_coroutine().kind_ty();
+            // If the types differ, then we must be getting the by-move body of
+            // a by-ref coroutine.
+            if identity_kind_ty == coroutine_kind_ty {
+                mir.coroutine_layout_raw()
+            } else {
+                assert_matches!(coroutine_kind_ty.to_opt_closure_kind(), Some(ClosureKind::FnOnce));
+                assert_matches!(
+                    identity_kind_ty.to_opt_closure_kind(),
+                    Some(ClosureKind::Fn | ClosureKind::FnMut)
+                );
+                mir.coroutine_by_move_body().unwrap().coroutine_layout_raw()
+            }
+        }
     }
 
     /// Given the `DefId` of an impl, returns the `DefId` of the trait it implements.
diff --git a/compiler/rustc_middle/src/ty/sty.rs b/compiler/rustc_middle/src/ty/sty.rs
index c85ee140fa4ec..82a2423cbc6b0 100644
--- a/compiler/rustc_middle/src/ty/sty.rs
+++ b/compiler/rustc_middle/src/ty/sty.rs
@@ -694,7 +694,10 @@ impl<'tcx> CoroutineArgs<'tcx> {
     #[inline]
     pub fn variant_range(&self, def_id: DefId, tcx: TyCtxt<'tcx>) -> Range<VariantIdx> {
         // FIXME requires optimized MIR
-        FIRST_VARIANT..tcx.coroutine_layout(def_id).unwrap().variant_fields.next_index()
+        // FIXME(async_closures): We should assert all coroutine layouts have
+        // the same number of variants.
+        FIRST_VARIANT
+            ..tcx.coroutine_layout(def_id, tcx.types.unit).unwrap().variant_fields.next_index()
     }
 
     /// The discriminant for the given variant. Panics if the `variant_index` is
@@ -754,7 +757,7 @@ impl<'tcx> CoroutineArgs<'tcx> {
         def_id: DefId,
         tcx: TyCtxt<'tcx>,
     ) -> impl Iterator<Item: Iterator<Item = Ty<'tcx>> + Captures<'tcx>> {
-        let layout = tcx.coroutine_layout(def_id).unwrap();
+        let layout = tcx.coroutine_layout(def_id, self.kind_ty()).unwrap();
         layout.variant_fields.iter().map(move |variant| {
             variant.iter().map(move |field| {
                 ty::EarlyBinder::bind(layout.field_tys[*field].ty).instantiate(tcx, self.args)
diff --git a/compiler/rustc_ty_utils/src/layout.rs b/compiler/rustc_ty_utils/src/layout.rs
index 85ac6071a08cc..331970ac36233 100644
--- a/compiler/rustc_ty_utils/src/layout.rs
+++ b/compiler/rustc_ty_utils/src/layout.rs
@@ -745,7 +745,7 @@ fn coroutine_layout<'tcx>(
     let tcx = cx.tcx;
     let instantiate_field = |ty: Ty<'tcx>| EarlyBinder::bind(ty).instantiate(tcx, args);
 
-    let Some(info) = tcx.coroutine_layout(def_id) else {
+    let Some(info) = tcx.coroutine_layout(def_id, args.as_coroutine().kind_ty()) else {
         return Err(error(cx, LayoutError::Unknown(ty)));
     };
     let (ineligible_locals, assignments) = coroutine_saved_local_eligibility(info);
@@ -1072,7 +1072,7 @@ fn variant_info_for_coroutine<'tcx>(
         return (vec![], None);
     };
 
-    let coroutine = cx.tcx.coroutine_layout(def_id).unwrap();
+    let coroutine = cx.tcx.coroutine_layout(def_id, args.as_coroutine().kind_ty()).unwrap();
     let upvar_names = cx.tcx.closure_saved_names_of_captured_variables(def_id);
 
     let mut upvars_size = Size::ZERO;

From 9bda9ac76e8d232c6cf0efde55dace718c1d428c Mon Sep 17 00:00:00 2001
From: Michael Goulet <michael@errs.io>
Date: Sun, 24 Mar 2024 21:14:49 -0400
Subject: [PATCH 3/3] Relax validation now

---
 compiler/rustc_const_eval/src/transform/validate.rs | 9 ++++-----
 compiler/rustc_middle/src/ty/sty.rs                 | 2 --
 2 files changed, 4 insertions(+), 7 deletions(-)

diff --git a/compiler/rustc_const_eval/src/transform/validate.rs b/compiler/rustc_const_eval/src/transform/validate.rs
index b085e4e76a11e..378b168a50c33 100644
--- a/compiler/rustc_const_eval/src/transform/validate.rs
+++ b/compiler/rustc_const_eval/src/transform/validate.rs
@@ -105,14 +105,13 @@ impl<'tcx> MirPass<'tcx> for Validator {
             && let Some(by_move_body) = body.coroutine_by_move_body()
             && let Some(by_move_layout) = by_move_body.coroutine_layout_raw()
         {
-            if layout != by_move_layout {
-                // If this turns out not to be true, please let compiler-errors know.
-                // It is possible to support, but requires some changes to the layout
-                // computation code.
+            // FIXME(async_closures): We could do other validation here?
+            if layout.variant_fields.len() != by_move_layout.variant_fields.len() {
                 cfg_checker.fail(
                     Location::START,
                     format!(
-                        "Coroutine layout differs from by-move coroutine layout:\n\
+                        "Coroutine layout has different number of variant fields from \
+                        by-move coroutine layout:\n\
                         layout: {layout:#?}\n\
                         by_move_layout: {by_move_layout:#?}",
                     ),
diff --git a/compiler/rustc_middle/src/ty/sty.rs b/compiler/rustc_middle/src/ty/sty.rs
index 82a2423cbc6b0..510a4b59520c4 100644
--- a/compiler/rustc_middle/src/ty/sty.rs
+++ b/compiler/rustc_middle/src/ty/sty.rs
@@ -694,8 +694,6 @@ impl<'tcx> CoroutineArgs<'tcx> {
     #[inline]
     pub fn variant_range(&self, def_id: DefId, tcx: TyCtxt<'tcx>) -> Range<VariantIdx> {
         // FIXME requires optimized MIR
-        // FIXME(async_closures): We should assert all coroutine layouts have
-        // the same number of variants.
         FIRST_VARIANT
             ..tcx.coroutine_layout(def_id, tcx.types.unit).unwrap().variant_fields.next_index()
     }