Skip to content

Commit

Permalink
Avoid lowering code under dead SwitchInt targets
Browse files Browse the repository at this point in the history
  • Loading branch information
saethlin committed Feb 22, 2024
1 parent d7bd9cd commit 6b3d152
Show file tree
Hide file tree
Showing 13 changed files with 512 additions and 334 deletions.
10 changes: 10 additions & 0 deletions compiler/rustc_codegen_ssa/src/mir/block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1179,6 +1179,16 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
}
}

pub fn codegen_block_as_unreachable(&mut self, bb: mir::BasicBlock) {
let llbb = match self.try_llbb(bb) {
Some(llbb) => llbb,
None => return,
};
let bx = &mut Bx::build(self.cx, llbb);
debug!("codegen_block_as_unreachable({:?})", bb);
bx.unreachable();
}

fn codegen_terminator(
&mut self,
bx: &mut Bx,
Expand Down
11 changes: 10 additions & 1 deletion compiler/rustc_codegen_ssa/src/mir/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -256,13 +256,22 @@ pub fn codegen_mir<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>(
// Apply debuginfo to the newly allocated locals.
fx.debug_introduce_locals(&mut start_bx);

let reachable_blocks = mir.reachable_blocks_in_mono(cx.tcx(), instance);

// The builders will be created separately for each basic block at `codegen_block`.
// So drop the builder of `start_llbb` to avoid having two at the same time.
drop(start_bx);

// Codegen the body of each block using reverse postorder
for (bb, _) in traversal::reverse_postorder(mir) {
fx.codegen_block(bb);
if reachable_blocks.contains(bb) {
fx.codegen_block(bb);
} else {
// This may have references to things we didn't monomorphize, so we
// don't actually codegen the body. We still create the block so
// terminators in other blocks can reference it without worry.
fx.codegen_block_as_unreachable(bb);
}
}
}

Expand Down
70 changes: 69 additions & 1 deletion compiler/rustc_middle/src/mir/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use crate::ty::print::{pretty_print_const, with_no_trimmed_paths};
use crate::ty::print::{FmtPrinter, Printer};
use crate::ty::visit::TypeVisitableExt;
use crate::ty::{self, List, Ty, TyCtxt};
use crate::ty::{AdtDef, InstanceDef, UserTypeAnnotationIndex};
use crate::ty::{AdtDef, Instance, InstanceDef, UserTypeAnnotationIndex};
use crate::ty::{GenericArg, GenericArgsRef};

use rustc_data_structures::captures::Captures;
Expand All @@ -29,6 +29,7 @@ pub use rustc_ast::Mutability;
use rustc_data_structures::fx::FxHashMap;
use rustc_data_structures::fx::FxHashSet;
use rustc_data_structures::graph::dominators::Dominators;
use rustc_index::bit_set::BitSet;
use rustc_index::{Idx, IndexSlice, IndexVec};
use rustc_serialize::{Decodable, Encodable};
use rustc_span::symbol::Symbol;
Expand Down Expand Up @@ -642,6 +643,73 @@ impl<'tcx> Body<'tcx> {
self.injection_phase.is_some()
}

/// Finds which basic blocks are actually reachable for a specific
/// monomorphization of this body.
///
/// This is allowed to have false positives; just because this says a block
/// is reachable doesn't mean that's necessarily true. It's thus always
/// legal for this to return a filled set.
///
/// Regardless, the [`BitSet::domain_size`] of the returned set will always
/// exactly match the number of blocks in the body so that `contains`
/// checks can be done without worrying about panicking.
///
/// The main case this supports is filtering out `if <T as Trait>::CONST`
/// bodies that can't be removed in generic MIR, but *can* be removed once
/// the specific `T` is known.
///
/// This is used in the monomorphization collector as well as in codegen.
pub fn reachable_blocks_in_mono(
&self,
tcx: TyCtxt<'tcx>,
instance: Instance<'tcx>,
) -> BitSet<BasicBlock> {
if instance.args.non_erasable_generics(tcx, instance.def_id()).next().is_none() {
// If it's non-generic, then mir-opt const prop has already run, meaning it's
// probably not worth doing any further filtering. So call everything reachable.
return BitSet::new_filled(self.basic_blocks.len());
}

let mut set = BitSet::new_empty(self.basic_blocks.len());
self.reachable_blocks_in_mono_from(tcx, instance, &mut set, START_BLOCK);
set
}

fn reachable_blocks_in_mono_from(
&self,
tcx: TyCtxt<'tcx>,
instance: Instance<'tcx>,
set: &mut BitSet<BasicBlock>,
bb: BasicBlock,
) {
if !set.insert(bb) {
return;
}

let data = &self.basic_blocks[bb];

if let TerminatorKind::SwitchInt { discr: Operand::Constant(constant), targets } =
&data.terminator().kind
{
let env = ty::ParamEnv::reveal_all();
let mono_literal = instance.instantiate_mir_and_normalize_erasing_regions(
tcx,
env,
crate::ty::EarlyBinder::bind(constant.const_),
);
if let Some(bits) = mono_literal.try_eval_bits(tcx, env) {
let target = targets.target_for_value(bits);
return self.reachable_blocks_in_mono_from(tcx, instance, set, target);
} else {
bug!("Couldn't evaluate constant {:?} in mono {:?}", constant, instance);
}
}

for target in data.terminator().successors() {
self.reachable_blocks_in_mono_from(tcx, instance, set, target);
}
}

/// For a `Location` in this scope, determine what the "caller location" at that point is. This
/// is interesting because of inlining: the `#[track_caller]` attribute of inlined functions
/// must be honored. Falls back to the `tracked_caller` value for `#[track_caller]` functions,
Expand Down
2 changes: 0 additions & 2 deletions compiler/rustc_middle/src/mir/traversal.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
use rustc_index::bit_set::BitSet;

use super::*;

/// Preorder traversal of a graph.
Expand Down
2 changes: 2 additions & 0 deletions compiler/rustc_mir_transform/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ mod check_alignment;
pub mod simplify;
mod simplify_branches;
mod simplify_comparison_integral;
mod simplify_if_const;
mod sroa;
mod uninhabited_enum_branching;
mod unreachable_prop;
Expand Down Expand Up @@ -616,6 +617,7 @@ fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
&large_enums::EnumSizeOpt { discrepancy: 128 },
// Some cleanup necessary at least for LLVM and potentially other codegen backends.
&add_call_guards::CriticalCallEdges,
&simplify_if_const::SimplifyIfConst,
// Cleanup for human readability, off by default.
&prettify::ReorderBasicBlocks,
&prettify::ReorderLocals,
Expand Down
76 changes: 76 additions & 0 deletions compiler/rustc_mir_transform/src/simplify_if_const.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
//! A pass that simplifies branches when their condition is known.

use crate::MirPass;
use rustc_middle::mir::*;
use rustc_middle::ty::TyCtxt;

/// The lowering for `if CONST` produces
/// ```
/// _1 = Const(...);
/// switchInt (move _1)
/// ```
/// so this pass replaces that with
/// ```
/// switchInt (Const(...))
/// ```
/// so that further MIR consumers can special-case it more easily.
///
/// Unlike ConstProp, this supports generic constants too, not just concrete ones.
pub struct SimplifyIfConst;

impl<'tcx> MirPass<'tcx> for SimplifyIfConst {
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
for block in body.basic_blocks_mut() {
simplify_assign_move_switch(tcx, block);
}
}
}

fn simplify_assign_move_switch(tcx: TyCtxt<'_>, block: &mut BasicBlockData<'_>) {
let Some(Terminator { kind: TerminatorKind::SwitchInt { discr: switch_desc, .. }, .. }) =
&mut block.terminator
else {
return;
};

let &mut Operand::Move(switch_place) = &mut *switch_desc else { return };

let Some(switch_local) = switch_place.as_local() else { return };

let Some(last_statement) = block.statements.last_mut() else { return };

let StatementKind::Assign(boxed_place_rvalue) = &last_statement.kind else { return };

let Some(assigned_local) = boxed_place_rvalue.0.as_local() else { return };

if switch_local != assigned_local {
return;
}

if !matches!(boxed_place_rvalue.1, Rvalue::Use(Operand::Constant(_))) {
return;
}

let should_optimize = tcx.consider_optimizing(|| {
format!(
"SimplifyBranches - Assignment: {:?} SourceInfo: {:?}",
boxed_place_rvalue, last_statement.source_info
)
});

if should_optimize {
let Some(last_statement) = block.statements.pop() else {
bug!("Somehow the statement disappeared?");
};

let StatementKind::Assign(boxed_place_rvalue) = last_statement.kind else {
bug!("Somehow it's not an assignment any more?");
};

let Rvalue::Use(assigned_constant @ Operand::Constant(_)) = boxed_place_rvalue.1 else {
bug!("Somehow it's not a use of a constant any more?");
};

*switch_desc = assigned_constant;
}
}
39 changes: 39 additions & 0 deletions tests/codegen/skip-mono-inside-if-false.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// compile-flags: -O -C no-prepopulate-passes

#![crate_type = "lib"]

#[no_mangle]
pub fn demo_for_i32() {
generic_impl::<i32>();
}

// Two important things here:
// - We replace the "then" block with `unreachable` to avoid linking problems
// - We neither declare nor define the `big_impl` that said block "calls".

// CHECK-LABEL: ; skip_mono_inside_if_false::generic_impl
// CHECK: start:
// CHECK-NEXT: br i1 false, label %[[THEN_BRANCH:bb[0-9]+]], label %[[ELSE_BRANCH:bb[0-9]+]]
// CHECK: [[ELSE_BRANCH]]:
// CHECK-NEXT: call skip_mono_inside_if_false::small_impl
// CHECK: [[THEN_BRANCH]]:
// CHECK-NEXT: unreachable

fn generic_impl<T>() {
trait MagicTrait {
const IS_BIG: bool;
}
impl<T> MagicTrait for T {
const IS_BIG: bool = std::mem::size_of::<T>() > 10;
}
if T::IS_BIG {
big_impl::<T>();
} else {
small_impl::<T>();
}
}

#[inline(never)]
fn small_impl<T>() {}
#[inline(never)]
fn big_impl<T>() {}
Loading

0 comments on commit 6b3d152

Please sign in to comment.