From 81d630453b08253dc1d6bcc4139dde16530887d9 Mon Sep 17 00:00:00 2001
From: Ben Kimock
Date: Wed, 21 Feb 2024 18:58:12 -0500
Subject: [PATCH] Avoid lowering code under dead SwitchInt targets

---
 compiler/rustc_codegen_ssa/src/mir/block.rs |  10 ++
 compiler/rustc_codegen_ssa/src/mir/mod.rs   |  11 +-
 compiler/rustc_middle/src/mir/mod.rs        | 127 +++++++++++++++++++-
 compiler/rustc_middle/src/mir/traversal.rs  |   2 -
 tests/codegen/precondition-checks.rs        |  27 +++++
 tests/codegen/skip-mono-inside-if-false.rs  |  41 +++++++
 6 files changed, 214 insertions(+), 4 deletions(-)
 create mode 100644 tests/codegen/precondition-checks.rs
 create mode 100644 tests/codegen/skip-mono-inside-if-false.rs

diff --git a/compiler/rustc_codegen_ssa/src/mir/block.rs b/compiler/rustc_codegen_ssa/src/mir/block.rs
index d3f5de25d9a23..9bb2a52826585 100644
--- a/compiler/rustc_codegen_ssa/src/mir/block.rs
+++ b/compiler/rustc_codegen_ssa/src/mir/block.rs
@@ -1237,6 +1237,16 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
         }
     }

+    pub fn codegen_block_as_unreachable(&mut self, bb: mir::BasicBlock) {
+        let llbb = match self.try_llbb(bb) {
+            Some(llbb) => llbb,
+            None => return,
+        };
+        let bx = &mut Bx::build(self.cx, llbb);
+        debug!("codegen_block_as_unreachable({:?})", bb);
+        bx.unreachable();
+    }
+
     fn codegen_terminator(
         &mut self,
         bx: &mut Bx,
diff --git a/compiler/rustc_codegen_ssa/src/mir/mod.rs b/compiler/rustc_codegen_ssa/src/mir/mod.rs
index a6fcf1fd38c1f..bac10f313366a 100644
--- a/compiler/rustc_codegen_ssa/src/mir/mod.rs
+++ b/compiler/rustc_codegen_ssa/src/mir/mod.rs
@@ -256,13 +256,22 @@ pub fn codegen_mir<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>(
     // Apply debuginfo to the newly allocated locals.
     fx.debug_introduce_locals(&mut start_bx);

+    let reachable_blocks = mir.reachable_blocks_in_mono(cx.tcx(), instance);
+
     // The builders will be created separately for each basic block at `codegen_block`.
     // So drop the builder of `start_llbb` to avoid having two at the same time.
     drop(start_bx);

     // Codegen the body of each block using reverse postorder
     for (bb, _) in traversal::reverse_postorder(mir) {
-        fx.codegen_block(bb);
+        if reachable_blocks.contains(bb) {
+            fx.codegen_block(bb);
+        } else {
+            // This may have references to things we didn't monomorphize, so we
+            // don't actually codegen the body. We still create the block so
+            // terminators in other blocks can reference it without worry.
+            fx.codegen_block_as_unreachable(bb);
+        }
     }
 }
diff --git a/compiler/rustc_middle/src/mir/mod.rs b/compiler/rustc_middle/src/mir/mod.rs
index 4b5a08d6af3ce..b71c614dc4fe9 100644
--- a/compiler/rustc_middle/src/mir/mod.rs
+++ b/compiler/rustc_middle/src/mir/mod.rs
@@ -10,7 +10,7 @@ use crate::ty::print::{pretty_print_const, with_no_trimmed_paths};
 use crate::ty::print::{FmtPrinter, Printer};
 use crate::ty::visit::TypeVisitableExt;
 use crate::ty::{self, List, Ty, TyCtxt};
-use crate::ty::{AdtDef, InstanceDef, UserTypeAnnotationIndex};
+use crate::ty::{AdtDef, Instance, InstanceDef, UserTypeAnnotationIndex};
 use crate::ty::{GenericArg, GenericArgsRef};

 use rustc_data_structures::captures::Captures;
@@ -27,6 +27,8 @@ pub use rustc_ast::Mutability;
 use rustc_data_structures::fx::FxHashMap;
 use rustc_data_structures::fx::FxHashSet;
 use rustc_data_structures::graph::dominators::Dominators;
+use rustc_data_structures::stack::ensure_sufficient_stack;
+use rustc_index::bit_set::BitSet;
 use rustc_index::{Idx, IndexSlice, IndexVec};
 use rustc_serialize::{Decodable, Encodable};
 use rustc_span::symbol::Symbol;
@@ -640,6 +642,129 @@ impl<'tcx> Body<'tcx> {
         self.injection_phase.is_some()
     }

+    /// Finds which basic blocks are actually reachable for a specific
+    /// monomorphization of this body.
+    ///
+    /// This is allowed to have false positives; just because this says a block
+    /// is reachable doesn't mean that's necessarily true. It's thus always
+    /// legal for this to return a filled set.
+    ///
+    /// Regardless, the [`BitSet::domain_size`] of the returned set will always
+    /// exactly match the number of blocks in the body so that `contains`
+    /// checks can be done without worrying about panicking.
+    ///
+    /// This is mostly useful because it lets us skip lowering the `false` side
+    /// of `if <T as Trait>::CONST`, as well as `intrinsics::debug_assertions`.
+    pub fn reachable_blocks_in_mono(
+        &self,
+        tcx: TyCtxt<'tcx>,
+        instance: Instance<'tcx>,
+    ) -> BitSet<BasicBlock> {
+        let mut set = BitSet::new_empty(self.basic_blocks.len());
+        self.reachable_blocks_in_mono_from(tcx, instance, &mut set, START_BLOCK);
+        set
+    }
+
+    fn reachable_blocks_in_mono_from(
+        &self,
+        tcx: TyCtxt<'tcx>,
+        instance: Instance<'tcx>,
+        set: &mut BitSet<BasicBlock>,
+        bb: BasicBlock,
+    ) {
+        if !set.insert(bb) {
+            return;
+        }
+
+        let data = &self.basic_blocks[bb];
+
+        if let Some((bits, targets)) = Self::try_const_mono_switchint(tcx, instance, data) {
+            let target = targets.target_for_value(bits);
+            ensure_sufficient_stack(|| {
+                self.reachable_blocks_in_mono_from(tcx, instance, set, target)
+            });
+            return;
+        }
+
+        for target in data.terminator().successors() {
+            ensure_sufficient_stack(|| {
+                self.reachable_blocks_in_mono_from(tcx, instance, set, target)
+            });
+        }
+    }
+
+    /// If this basic block ends with a [`TerminatorKind::SwitchInt`] for which we can evaluate the
+    /// discriminant in monomorphization, we return the discriminant bits and the
+    /// [`SwitchTargets`], just so the caller doesn't also have to match on the terminator.
+    fn try_const_mono_switchint<'a>(
+        tcx: TyCtxt<'tcx>,
+        instance: Instance<'tcx>,
+        block: &'a BasicBlockData<'tcx>,
+    ) -> Option<(u128, &'a SwitchTargets)> {
+        // There are two places here we need to evaluate a constant.
+        let eval_mono_const = |constant: &ConstOperand<'tcx>| {
+            let env = ty::ParamEnv::reveal_all();
+            let mono_literal = instance.instantiate_mir_and_normalize_erasing_regions(
+                tcx,
+                env,
+                crate::ty::EarlyBinder::bind(constant.const_),
+            );
+            let Some(bits) = mono_literal.try_eval_bits(tcx, env) else {
+                bug!("Couldn't evaluate constant {:?} in mono {:?}", constant, instance);
+            };
+            bits
+        };
+
+        let TerminatorKind::SwitchInt { discr, targets } = &block.terminator().kind else {
+            return None;
+        };
+
+        // If this is a SwitchInt(const _), then we can just evaluate the constant and return.
+        let discr = match discr {
+            Operand::Constant(constant) => {
+                let bits = eval_mono_const(constant);
+                return Some((bits, targets));
+            }
+            Operand::Move(place) | Operand::Copy(place) => place,
+        };
+
+        // MIR for `if false` actually looks like this:
+        //     _1 = const _
+        //     SwitchInt(_1)
+        //
+        // And MIR for `if intrinsics::debug_assertions()` looks like this:
+        //     _1 = cfg!(debug_assertions)
+        //     SwitchInt(_1)
+        //
+        // So we're going to try to recognize this pattern.
+        //
+        // If we have a SwitchInt on a non-const place, we find the most recent statement that
+        // isn't a storage marker. If that statement is an assignment of a const to our
+        // discriminant place, we evaluate and return the const, as if we've const-propagated it
+        // into the SwitchInt.
+
+        let last_stmt = block.statements.iter().rev().find(|stmt| {
+            !matches!(stmt.kind, StatementKind::StorageDead(_) | StatementKind::StorageLive(_))
+        })?;
+
+        let (place, rvalue) = last_stmt.kind.as_assign()?;
+
+        if discr != place {
+            return None;
+        }
+
+        match rvalue {
+            Rvalue::NullaryOp(NullOp::UbCheck(_), _) => {
+                Some((tcx.sess.opts.debug_assertions as u128, targets))
+            }
+            Rvalue::Use(Operand::Constant(constant)) => {
+                let bits = eval_mono_const(constant);
+                Some((bits, targets))
+            }
+            _ => None,
+        }
+    }
+
     /// For a `Location` in this scope, determine what the "caller location" at that point is. This
     /// is interesting because of inlining: the `#[track_caller]` attribute of inlined functions
     /// must be honored. Falls back to the `tracked_caller` value for `#[track_caller]` functions,
diff --git a/compiler/rustc_middle/src/mir/traversal.rs b/compiler/rustc_middle/src/mir/traversal.rs
index a1ff8410eac4a..0a938bcd31562 100644
--- a/compiler/rustc_middle/src/mir/traversal.rs
+++ b/compiler/rustc_middle/src/mir/traversal.rs
@@ -1,5 +1,3 @@
-use rustc_index::bit_set::BitSet;
-
 use super::*;

 /// Preorder traversal of a graph.
diff --git a/tests/codegen/precondition-checks.rs b/tests/codegen/precondition-checks.rs
new file mode 100644
index 0000000000000..1914944500374
--- /dev/null
+++ b/tests/codegen/precondition-checks.rs
@@ -0,0 +1,27 @@
+//@ compile-flags: -Cno-prepopulate-passes -Copt-level=0 -Cdebug-assertions=no
+
+// This test ensures that in a debug build which turns off debug assertions, we do not monomorphize
+// any of the standard library's unsafe precondition checks.
+// The naive codegen of those checks contains the actual check underneath an `if false`, which
+// could be optimized out if optimizations are enabled. But if we rely on optimizations to remove
+// panic branches, then we can't link compiler_builtins without optimizing it, which means that
+// -Zbuild-std doesn't work with -Copt-level=0.
+//
+// In other words, this tests for a mandatory optimization.
+
+#![crate_type = "lib"]
+
+use std::ptr::NonNull;
+
+// CHECK-LABEL: ; core::ptr::non_null::NonNull<T>::new_unchecked
+// CHECK-NOT: call
+// CHECK: }
+
+// CHECK-LABEL: @nonnull_new
+#[no_mangle]
+pub unsafe fn nonnull_new(ptr: *mut u8) -> NonNull<u8> {
+    // CHECK: ; call core::ptr::non_null::NonNull<T>::new_unchecked
+    unsafe {
+        NonNull::new_unchecked(ptr)
+    }
+}
diff --git a/tests/codegen/skip-mono-inside-if-false.rs b/tests/codegen/skip-mono-inside-if-false.rs
new file mode 100644
index 0000000000000..8b95de99dd3bd
--- /dev/null
+++ b/tests/codegen/skip-mono-inside-if-false.rs
@@ -0,0 +1,41 @@
+//@ compile-flags: -Cno-prepopulate-passes -Copt-level=0
+
+#![crate_type = "lib"]
+
+#[no_mangle]
+pub fn demo_for_i32() {
+    generic_impl::<i32>();
+}
+
+// Two important things here:
+// - We replace the "then" block with `unreachable` to avoid linking problems
+// - We neither declare nor define the `big_impl` that said block "calls".
+
+// CHECK-LABEL: ; skip_mono_inside_if_false::generic_impl
+// CHECK: start:
+// CHECK-NEXT: br label %[[ELSE_BRANCH:bb[0-9]+]]
+// CHECK: [[ELSE_BRANCH]]:
+// CHECK-NEXT: call skip_mono_inside_if_false::small_impl
+// CHECK: bb{{[0-9]+}}:
+// CHECK-NEXT: ret void
+// CHECK: bb{{[0-9]+}}:
+// CHECK-NEXT: unreachable
+
+fn generic_impl<T>() {
+    trait MagicTrait {
+        const IS_BIG: bool;
+    }
+    impl<T> MagicTrait for T {
+        const IS_BIG: bool = std::mem::size_of::<T>() > 10;
+    }
+    if T::IS_BIG {
+        big_impl::<T>();
+    } else {
+        small_impl::<T>();
+    }
+}
+
+#[inline(never)]
+fn small_impl<T>() {}
+#[inline(never)]
+fn big_impl<T>() {}
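
The heart of the change is `reachable_blocks_in_mono`: a walk from `START_BLOCK` that, whenever a `SwitchInt` discriminant is a monomorphization-time constant, follows only the target selected by `SwitchTargets::target_for_value` instead of every successor. Below is a minimal, self-contained sketch of that walk over a toy CFG. The `Terminator` enum, the `switch_target_for` helper, and the plain `usize` block indices are inventions of this sketch, not rustc types, and an explicit worklist stands in for the patch's recursion under `ensure_sufficient_stack`.

// Toy model of the reachability walk above; none of these types are rustc APIs.
use std::collections::HashSet;

/// Toy terminator: an unconditional jump, a switch whose discriminant may or
/// may not be a known constant, or a return.
enum Terminator {
    Goto(usize),
    /// (known constant discriminant if any, (value, target) arms, otherwise-target)
    SwitchInt(Option<u128>, Vec<(u128, usize)>, usize),
    Return,
}

/// Analogue of `SwitchTargets::target_for_value` for the toy types.
fn switch_target_for(value: u128, arms: &[(u128, usize)], otherwise: usize) -> usize {
    arms.iter().find(|(v, _)| *v == value).map(|(_, t)| *t).unwrap_or(otherwise)
}

fn reachable_blocks(blocks: &[Terminator]) -> HashSet<usize> {
    let mut reachable = HashSet::new();
    let mut stack = vec![0]; // START_BLOCK
    while let Some(bb) = stack.pop() {
        if !reachable.insert(bb) {
            continue;
        }
        match &blocks[bb] {
            Terminator::Goto(t) => stack.push(*t),
            // A constant discriminant means exactly one target is live.
            Terminator::SwitchInt(Some(bits), arms, otherwise) => {
                stack.push(switch_target_for(*bits, arms, *otherwise));
            }
            // Unknown discriminant: conservatively keep every successor.
            Terminator::SwitchInt(None, arms, otherwise) => {
                stack.extend(arms.iter().map(|(_, t)| *t));
                stack.push(*otherwise);
            }
            Terminator::Return => {}
        }
    }
    reachable
}

fn main() {
    // bb0 switches on const 0 -> bb2; bb1 is the dead "then" side.
    let blocks = vec![
        Terminator::SwitchInt(Some(0), vec![(0, 2)], 1),
        Terminator::Goto(3), // dead: may reference un-monomorphized code
        Terminator::Goto(3),
        Terminator::Return,
    ];
    let live = reachable_blocks(&blocks);
    assert!(live.contains(&0) && live.contains(&2) && live.contains(&3));
    assert!(!live.contains(&1)); // the `if false` side is never lowered
}

Because an unknown discriminant conservatively keeps every successor, the result can only over-approximate reachability, which is exactly the false-positive guarantee documented on `reachable_blocks_in_mono`.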
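`try_const_mono_switchint` also handles the common lowering where the discriminant is not itself a constant operand but was assigned one just beforehand (`_1 = const _; SwitchInt(_1)`), skipping `StorageLive`/`StorageDead` markers in between. Here is a sketch of just that statement scan, again over stand-in types: `Statement`, `Local`, and `const_feeding_switch` are hypothetical names for this illustration, not rustc APIs.

// Toy model of the "last non-storage statement" scan; not rustc types.
#[derive(Clone, Copy, PartialEq)]
struct Local(u32);

enum Statement {
    StorageLive(Local),
    StorageDead(Local),
    AssignConst(Local, u128),
    AssignOther(Local),
}

/// Returns the constant feeding `discr`, mimicking the "as if we've
/// const-propagated it into the SwitchInt" behavior described in the patch.
fn const_feeding_switch(statements: &[Statement], discr: Local) -> Option<u128> {
    // Find the most recent statement that isn't a storage marker...
    let last = statements.iter().rev().find(|s| {
        !matches!(s, Statement::StorageLive(_) | Statement::StorageDead(_))
    })?;
    // ...and only accept a constant assignment to the discriminant place.
    match last {
        Statement::AssignConst(place, bits) if *place == discr => Some(*bits),
        _ => None,
    }
}

fn main() {
    let stmts = [
        Statement::StorageLive(Local(1)),
        Statement::AssignConst(Local(1), 0), // _1 = const false
        Statement::StorageDead(Local(2)),    // storage markers are skipped
    ];
    assert_eq!(const_feeding_switch(&stmts, Local(1)), Some(0));
    assert_eq!(const_feeding_switch(&stmts, Local(3)), None);

    let stmts = [Statement::AssignOther(Local(1))];
    assert_eq!(const_feeding_switch(&stmts, Local(1)), None);
}

Anything more complicated than this two-shape pattern (a constant operand, or a constant assigned by the immediately preceding non-storage statement) falls through to `None`, and the walk then conservatively treats all switch targets as reachable.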