Skip to content

Commit

Permalink
Auto merge of #123259 - scottmcm:tweak-if-const, r=<try>
Browse files Browse the repository at this point in the history
Fixup `if T::CONST` in MIR

r? ghost
  • Loading branch information
bors committed Mar 31, 2024
2 parents 5f358a8 + 2fb9d27 commit 3d3fbc3
Show file tree
Hide file tree
Showing 13 changed files with 467 additions and 350 deletions.
19 changes: 14 additions & 5 deletions compiler/rustc_codegen_ssa/src/mir/block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -361,15 +361,24 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
discr: &mir::Operand<'tcx>,
targets: &SwitchTargets,
) {
let discr = self.codegen_operand(bx, discr);
let discr_value = discr.immediate();
let switch_ty = discr.layout.ty;
// If our discriminant is a constant we can branch directly
if let Some(const_discr) = bx.const_to_opt_u128(discr_value, false) {
if let Some(const_op) = discr.constant() {
let const_value = self.eval_mir_constant(const_op);
let Some(const_discr) = const_value.try_to_bits_for_ty(
self.cx.tcx(),
ty::ParamEnv::reveal_all(),
const_op.ty(),
) else {
bug!("Failed to evaluate constant {discr:?} for SwitchInt terminator")
};
let target = targets.target_for_value(const_discr);
bx.br(helper.llbb_with_cleanup(self, target));
return;
};
}

let discr = self.codegen_operand(bx, discr);
let discr_value = discr.immediate();
let switch_ty = discr.layout.ty;

let mut target_iter = targets.iter();
if target_iter.len() == 1 {
Expand Down
21 changes: 6 additions & 15 deletions compiler/rustc_middle/src/mir/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -765,6 +765,7 @@ impl<'tcx> Body<'tcx> {
};

// If this is a SwitchInt(const _), then we can just evaluate the constant and return.
// (The `SwitchConst` transform pass tries to ensure this.)
let discr = match discr {
Operand::Constant(constant) => {
let bits = eval_mono_const(constant);
Expand All @@ -773,24 +774,18 @@ impl<'tcx> Body<'tcx> {
Operand::Move(place) | Operand::Copy(place) => place,
};

// MIR for `if false` actually looks like this:
// _1 = const _
// SwitchInt(_1)
//
// And MIR for if intrinsics::debug_assertions() looks like this:
// _1 = cfg!(debug_assertions)
// SwitchInt(_1)
//
// So we're going to try to recognize this pattern.
//
// If we have a SwitchInt on a non-const place, we find the most recent statement that
// isn't a storage marker. If that statement is an assignment of a const to our
// discriminant place, we evaluate and return the const, as if we've const-propagated it
// into the SwitchInt.
// If we have a SwitchInt on a non-const place, we look at the last statement
// in the block. If that statement is an assignment of UbChecks to our
// discriminant place, we evaluate its value, as if we've
// const-propagated it into the SwitchInt.

let last_stmt = block.statements.iter().rev().find(|stmt| {
!matches!(stmt.kind, StatementKind::StorageDead(_) | StatementKind::StorageLive(_))
})?;
let last_stmt = block.statements.last()?;

let (place, rvalue) = last_stmt.kind.as_assign()?;

Expand All @@ -802,10 +797,6 @@ impl<'tcx> Body<'tcx> {
Rvalue::NullaryOp(NullOp::UbChecks, _) => {
Some((tcx.sess.opts.debug_assertions as u128, targets))
}
Rvalue::Use(Operand::Constant(constant)) => {
let bits = eval_mono_const(constant);
Some((bits, targets))
}
_ => None,
}
}
Expand Down
3 changes: 3 additions & 0 deletions compiler/rustc_mir_transform/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ pub mod simplify;
mod simplify_branches;
mod simplify_comparison_integral;
mod sroa;
mod switch_const;
mod uninhabited_enum_branching;
mod unreachable_prop;

Expand Down Expand Up @@ -600,6 +601,8 @@ fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
&simplify::SimplifyLocals::AfterGVN,
&dataflow_const_prop::DataflowConstProp,
&const_debuginfo::ConstDebugInfo,
// GVN & ConstProp often don't fixup unevaluatable constants
&switch_const::SwitchConst,
&o1(simplify_branches::SimplifyConstCondition::AfterConstProp),
&jump_threading::JumpThreading,
&early_otherwise_branch::EarlyOtherwiseBranch,
Expand Down
55 changes: 55 additions & 0 deletions compiler/rustc_mir_transform/src/switch_const.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
//! A pass that makes `SwitchInt`-on-`const` more obvious to later code.
use rustc_middle::mir::*;
use rustc_middle::ty::TyCtxt;

/// A `MirPass` for simplifying `if T::CONST`.
///
/// Today, MIR building for things like `if T::IS_ZST` introduce a constant
/// for the copy of the bool, so it ends up in MIR as
/// `_1 = CONST; switchInt (move _1)` or `_2 = CONST; switchInt (_2)`.
///
/// This pass is very specifically targeted at *exactly* those patterns.
/// It can absolutely be replaced with a more general pass should we get one that
/// we can run in low optimization levels, but at the time of writing even in
/// optimized builds this wasn't simplified.
#[derive(Default)]
pub struct SwitchConst;

impl<'tcx> MirPass<'tcx> for SwitchConst {
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
for block in body.basic_blocks.as_mut_preserves_cfg() {
let switch_local = if let TerminatorKind::SwitchInt { discr, .. } =
&block.terminator().kind
&& let Some(place) = discr.place()
&& let Some(local) = place.as_local()
{
local
} else {
continue;
};

let new_operand = if let Some(statement) = block.statements.last()
&& let StatementKind::Assign(place_and_rvalue) = &statement.kind
&& let Some(local) = place_and_rvalue.0.as_local()
&& local == switch_local
&& let Rvalue::Use(operand) = &place_and_rvalue.1
&& let Operand::Constant(_) = operand
{
operand.clone()
} else {
continue;
};

if !tcx.consider_optimizing(|| format!("SwitchConst: switchInt(move {switch_local:?}"))
{
break;
}

let TerminatorKind::SwitchInt { discr, .. } = &mut block.terminator_mut().kind else {
bug!("Somehow wasn't a switchInt any more?")
};
*discr = new_operand;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// MIR for `check_bool` after PreCodegen

fn check_bool() -> u32 {
let mut _0: u32;

bb0: {
switchInt(const <T as TraitWithBool>::FLAG) -> [0: bb1, otherwise: bb2];
}

bb1: {
_0 = const 456_u32;
goto -> bb3;
}

bb2: {
_0 = const 123_u32;
goto -> bb3;
}

bb3: {
return;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// MIR for `check_int` after PreCodegen

fn check_int() -> u32 {
let mut _0: u32;

bb0: {
switchInt(const <T as TraitWithInt>::VALUE) -> [1: bb1, 2: bb2, 3: bb3, otherwise: bb4];
}

bb1: {
_0 = const 123_u32;
goto -> bb5;
}

bb2: {
_0 = const 456_u32;
goto -> bb5;
}

bb3: {
_0 = const 789_u32;
goto -> bb5;
}

bb4: {
_0 = const 0_u32;
goto -> bb5;
}

bb5: {
return;
}
}
27 changes: 27 additions & 0 deletions tests/mir-opt/pre-codegen/if_associated_const.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// skip-filecheck
//@ compile-flags: -O -Zmir-opt-level=2 -Cdebuginfo=2

#![crate_type = "lib"]

pub trait TraitWithBool {
const FLAG: bool;
}

// EMIT_MIR if_associated_const.check_bool.PreCodegen.after.mir
pub fn check_bool<T: TraitWithBool>() -> u32 {
if T::FLAG { 123 } else { 456 }
}

pub trait TraitWithInt {
const VALUE: i32;
}

// EMIT_MIR if_associated_const.check_int.PreCodegen.after.mir
pub fn check_int<T: TraitWithInt>() -> u32 {
match T::VALUE {
1 => 123,
2 => 456,
3 => 789,
_ => 0,
}
}
Loading

0 comments on commit 3d3fbc3

Please sign in to comment.