From 49c4ebcc409b1537fb0cf99134f5166481096c5f Mon Sep 17 00:00:00 2001 From: Michael Goulet Date: Fri, 5 Apr 2024 16:48:12 -0400 Subject: [PATCH] Check the base of the place too! --- .../src/coroutine/by_move_body.rs | 26 +++++++++++++----- .../async-closures/overlapping-projs.rs | 27 +++++++++++++++++++ .../overlapping-projs.run.stdout | 1 + 3 files changed, 48 insertions(+), 6 deletions(-) create mode 100644 tests/ui/async-await/async-closures/overlapping-projs.rs create mode 100644 tests/ui/async-await/async-closures/overlapping-projs.run.stdout diff --git a/compiler/rustc_mir_transform/src/coroutine/by_move_body.rs b/compiler/rustc_mir_transform/src/coroutine/by_move_body.rs index a62fe4af81045..d94441b1413c7 100644 --- a/compiler/rustc_mir_transform/src/coroutine/by_move_body.rs +++ b/compiler/rustc_mir_transform/src/coroutine/by_move_body.rs @@ -71,7 +71,7 @@ use rustc_data_structures::unord::UnordMap; use rustc_hir as hir; -use rustc_middle::hir::place::{Projection, ProjectionKind}; +use rustc_middle::hir::place::{PlaceBase, Projection, ProjectionKind}; use rustc_middle::mir::visit::MutVisitor; use rustc_middle::mir::{self, dump_mir, MirPass}; use rustc_middle::ty::{self, InstanceDef, Ty, TyCtxt, TypeVisitableExt}; @@ -149,17 +149,25 @@ impl<'tcx> MirPass<'tcx> for ByMoveBody { bug!("we ran out of parent captures!") }; + let PlaceBase::Upvar(parent_base) = parent_capture.place.base else { + bug!("expected capture to be an upvar"); + }; + let PlaceBase::Upvar(child_base) = child_capture.place.base else { + bug!("expected capture to be an upvar"); + }; + assert!( child_capture.place.projections.len() >= parent_capture.place.projections.len() ); // A parent matches a child they share the same prefix of projections. // The child may have more, if it is capturing sub-fields out of // something that is captured by-move in the parent closure. - if !std::iter::zip( - &child_capture.place.projections, - &parent_capture.place.projections, - ) - .all(|(child, parent)| child.kind == parent.kind) + if parent_base.var_path.hir_id != child_base.var_path.hir_id + || !std::iter::zip( + &child_capture.place.projections, + &parent_capture.place.projections, + ) + .all(|(child, parent)| child.kind == parent.kind) { // Make sure the field was used at least once. assert!( @@ -217,6 +225,12 @@ impl<'tcx> MirPass<'tcx> for ByMoveBody { } } + // Pop the last parent capture + if field_used_at_least_once { + let _ = parent_captures.next().unwrap(); + } + assert_eq!(parent_captures.next(), None, "leftover parent captures?"); + if coroutine_kind == ty::ClosureKind::FnOnce { assert_eq!(field_remapping.len(), tcx.closure_captures(parent_def_id).len()); return; diff --git a/tests/ui/async-await/async-closures/overlapping-projs.rs b/tests/ui/async-await/async-closures/overlapping-projs.rs new file mode 100644 index 0000000000000..6dd00b16103f7 --- /dev/null +++ b/tests/ui/async-await/async-closures/overlapping-projs.rs @@ -0,0 +1,27 @@ +//@ aux-build:block-on.rs +//@ edition:2021 +//@ run-pass +//@ check-run-results + +#![feature(async_closure)] + +extern crate block_on; + +async fn call_once(f: impl async FnOnce()) { + f().await; +} + +async fn async_main() { + let x = &mut 0; + let y = &mut 0; + let c = async || { + *x = 1; + *y = 2; + }; + call_once(c).await; + println!("{x} {y}"); +} + +fn main() { + block_on::block_on(async_main()); +} diff --git a/tests/ui/async-await/async-closures/overlapping-projs.run.stdout b/tests/ui/async-await/async-closures/overlapping-projs.run.stdout new file mode 100644 index 0000000000000..8d04f961a0371 --- /dev/null +++ b/tests/ui/async-await/async-closures/overlapping-projs.run.stdout @@ -0,0 +1 @@ +1 2