Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make TyCtxt::coroutine_layout take coroutine's kind parameter #123021

Merged
merged 3 commits into from
Mar 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -683,7 +683,8 @@ fn build_union_fields_for_direct_tag_coroutine<'ll, 'tcx>(
_ => unreachable!(),
};

let coroutine_layout = cx.tcx.optimized_mir(coroutine_def_id).coroutine_layout().unwrap();
let coroutine_layout =
cx.tcx.coroutine_layout(coroutine_def_id, coroutine_args.kind_ty()).unwrap();

let common_upvar_names = cx.tcx.closure_saved_names_of_captured_variables(coroutine_def_id);
let variant_range = coroutine_args.variant_range(coroutine_def_id, cx.tcx);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ pub(super) fn build_coroutine_di_node<'ll, 'tcx>(
unique_type_id: UniqueTypeId<'tcx>,
) -> DINodeCreationResult<'ll> {
let coroutine_type = unique_type_id.expect_ty();
let &ty::Coroutine(coroutine_def_id, _) = coroutine_type.kind() else {
let &ty::Coroutine(coroutine_def_id, coroutine_args) = coroutine_type.kind() else {
bug!("build_coroutine_di_node() called with non-coroutine type: `{:?}`", coroutine_type)
};

Expand All @@ -158,8 +158,10 @@ pub(super) fn build_coroutine_di_node<'ll, 'tcx>(
DIFlags::FlagZero,
),
|cx, coroutine_type_di_node| {
let coroutine_layout =
cx.tcx.optimized_mir(coroutine_def_id).coroutine_layout().unwrap();
let coroutine_layout = cx
.tcx
.coroutine_layout(coroutine_def_id, coroutine_args.as_coroutine().kind_ty())
.unwrap();

let Variants::Multiple { tag_encoding: TagEncoding::Direct, ref variants, .. } =
coroutine_type_and_layout.variants
Expand Down
22 changes: 11 additions & 11 deletions compiler/rustc_const_eval/src/transform/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,18 +101,17 @@ impl<'tcx> MirPass<'tcx> for Validator {
}

// Enforce that coroutine-closure layouts are identical.
if let Some(layout) = body.coroutine_layout()
if let Some(layout) = body.coroutine_layout_raw()
&& let Some(by_move_body) = body.coroutine_by_move_body()
&& let Some(by_move_layout) = by_move_body.coroutine_layout()
&& let Some(by_move_layout) = by_move_body.coroutine_layout_raw()
{
if layout != by_move_layout {
// If this turns out not to be true, please let compiler-errors know.
// It is possible to support, but requires some changes to the layout
// computation code.
// FIXME(async_closures): We could do other validation here?
if layout.variant_fields.len() != by_move_layout.variant_fields.len() {
cfg_checker.fail(
Location::START,
format!(
"Coroutine layout differs from by-move coroutine layout:\n\
"Coroutine layout has different number of variant fields from \
by-move coroutine layout:\n\
layout: {layout:#?}\n\
by_move_layout: {by_move_layout:#?}",
),
Expand Down Expand Up @@ -715,13 +714,14 @@ impl<'a, 'tcx> Visitor<'tcx> for TypeChecker<'a, 'tcx> {
// args of the coroutine. Otherwise, we prefer to use this body
// since we may be in the process of computing this MIR in the
// first place.
let gen_body = if def_id == self.caller_body.source.def_id() {
self.caller_body
let layout = if def_id == self.caller_body.source.def_id() {
// FIXME: This is not right for async closures.
self.caller_body.coroutine_layout_raw()
} else {
self.tcx.optimized_mir(def_id)
self.tcx.coroutine_layout(def_id, args.as_coroutine().kind_ty())
};

let Some(layout) = gen_body.coroutine_layout() else {
let Some(layout) = layout else {
self.fail(
location,
format!("No coroutine layout for {parent_ty:?}"),
Expand Down
3 changes: 2 additions & 1 deletion compiler/rustc_middle/src/mir/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -652,8 +652,9 @@ impl<'tcx> Body<'tcx> {
self.coroutine.as_ref().and_then(|coroutine| coroutine.resume_ty)
}

/// Prefer going through [`TyCtxt::coroutine_layout`] rather than using this directly.
#[inline]
pub fn coroutine_layout(&self) -> Option<&CoroutineLayout<'tcx>> {
pub fn coroutine_layout_raw(&self) -> Option<&CoroutineLayout<'tcx>> {
self.coroutine.as_ref().and_then(|coroutine| coroutine.coroutine_layout.as_ref())
}

Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_middle/src/mir/pretty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ fn dump_matched_mir_node<'tcx, F>(
Some(promoted) => write!(file, "::{promoted:?}`")?,
}
writeln!(file, " {disambiguator} {pass_name}")?;
if let Some(ref layout) = body.coroutine_layout() {
if let Some(ref layout) = body.coroutine_layout_raw() {
writeln!(file, "/* coroutine_layout = {layout:#?} */")?;
}
writeln!(file)?;
Expand Down
37 changes: 35 additions & 2 deletions compiler/rustc_middle/src/ty/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ pub use rustc_target::abi::{ReprFlags, ReprOptions};
pub use rustc_type_ir::{DebugWithInfcx, InferCtxtLike, WithInfcx};
pub use vtable::*;

use std::assert_matches::assert_matches;
use std::fmt::Debug;
use std::hash::{Hash, Hasher};
use std::marker::PhantomData;
Expand Down Expand Up @@ -1826,8 +1827,40 @@ impl<'tcx> TyCtxt<'tcx> {

/// Returns layout of a coroutine. Layout might be unavailable if the
/// coroutine is tainted by errors.
pub fn coroutine_layout(self, def_id: DefId) -> Option<&'tcx CoroutineLayout<'tcx>> {
self.optimized_mir(def_id).coroutine_layout()
///
/// Takes `coroutine_kind` which can be acquired from the `CoroutineArgs::kind_ty`,
/// e.g. `args.as_coroutine().kind_ty()`.
pub fn coroutine_layout(
self,
def_id: DefId,
coroutine_kind_ty: Ty<'tcx>,
) -> Option<&'tcx CoroutineLayout<'tcx>> {
let mir = self.optimized_mir(def_id);
// Regular coroutine
if coroutine_kind_ty.is_unit() {
mir.coroutine_layout_raw()
} else {
// If we have a `Coroutine` that comes from an coroutine-closure,
// then it may be a by-move or by-ref body.
let ty::Coroutine(_, identity_args) =
*self.type_of(def_id).instantiate_identity().kind()
else {
unreachable!();
};
let identity_kind_ty = identity_args.as_coroutine().kind_ty();
// If the types differ, then we must be getting the by-move body of
// a by-ref coroutine.
if identity_kind_ty == coroutine_kind_ty {
mir.coroutine_layout_raw()
} else {
assert_matches!(coroutine_kind_ty.to_opt_closure_kind(), Some(ClosureKind::FnOnce));
assert_matches!(
identity_kind_ty.to_opt_closure_kind(),
Some(ClosureKind::Fn | ClosureKind::FnMut)
);
mir.coroutine_by_move_body().unwrap().coroutine_layout_raw()
}
}
}

/// Given the `DefId` of an impl, returns the `DefId` of the trait it implements.
Expand Down
5 changes: 3 additions & 2 deletions compiler/rustc_middle/src/ty/sty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -694,7 +694,8 @@ impl<'tcx> CoroutineArgs<'tcx> {
#[inline]
pub fn variant_range(&self, def_id: DefId, tcx: TyCtxt<'tcx>) -> Range<VariantIdx> {
// FIXME requires optimized MIR
FIRST_VARIANT..tcx.coroutine_layout(def_id).unwrap().variant_fields.next_index()
FIRST_VARIANT
..tcx.coroutine_layout(def_id, tcx.types.unit).unwrap().variant_fields.next_index()
}

/// The discriminant for the given variant. Panics if the `variant_index` is
Expand Down Expand Up @@ -754,7 +755,7 @@ impl<'tcx> CoroutineArgs<'tcx> {
def_id: DefId,
tcx: TyCtxt<'tcx>,
) -> impl Iterator<Item: Iterator<Item = Ty<'tcx>> + Captures<'tcx>> {
let layout = tcx.coroutine_layout(def_id).unwrap();
let layout = tcx.coroutine_layout(def_id, self.kind_ty()).unwrap();
layout.variant_fields.iter().map(move |variant| {
variant.iter().map(move |field| {
ty::EarlyBinder::bind(layout.field_tys[*field].ty).instantiate(tcx, self.args)
Expand Down
4 changes: 2 additions & 2 deletions compiler/rustc_ty_utils/src/layout.rs
Original file line number Diff line number Diff line change
Expand Up @@ -745,7 +745,7 @@ fn coroutine_layout<'tcx>(
let tcx = cx.tcx;
let instantiate_field = |ty: Ty<'tcx>| EarlyBinder::bind(ty).instantiate(tcx, args);

let Some(info) = tcx.coroutine_layout(def_id) else {
let Some(info) = tcx.coroutine_layout(def_id, args.as_coroutine().kind_ty()) else {
return Err(error(cx, LayoutError::Unknown(ty)));
};
let (ineligible_locals, assignments) = coroutine_saved_local_eligibility(info);
Expand Down Expand Up @@ -1072,7 +1072,7 @@ fn variant_info_for_coroutine<'tcx>(
return (vec![], None);
};

let coroutine = cx.tcx.optimized_mir(def_id).coroutine_layout().unwrap();
let coroutine = cx.tcx.coroutine_layout(def_id, args.as_coroutine().kind_ty()).unwrap();
let upvar_names = cx.tcx.closure_saved_names_of_captured_variables(def_id);

let mut upvars_size = Size::ZERO;
Expand Down
Loading