Skip to content

Commit

Permalink
Use an interpreter in jump threading.
Browse files Browse the repository at this point in the history
  • Loading branch information
cjgillot committed Dec 23, 2023
1 parent 396af32 commit c3b8306
Show file tree
Hide file tree
Showing 3 changed files with 165 additions and 127 deletions.
98 changes: 74 additions & 24 deletions compiler/rustc_mir_transform/src/jump_threading.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,19 @@
//! cost by `MAX_COST`.

use rustc_arena::DroplessArena;
use rustc_const_eval::interpret::{ImmTy, Immediate, InterpCx, OpTy, Projectable};
use rustc_data_structures::fx::FxHashSet;
use rustc_index::bit_set::BitSet;
use rustc_index::IndexVec;
use rustc_middle::mir::interpret::Scalar;
use rustc_middle::mir::visit::Visitor;
use rustc_middle::mir::*;
use rustc_middle::ty::{self, ScalarInt, Ty, TyCtxt};
use rustc_middle::ty::{self, ScalarInt, TyCtxt};
use rustc_mir_dataflow::value_analysis::{Map, PlaceIndex, State, TrackElem};
use rustc_span::DUMMY_SP;

use crate::cost_checker::CostChecker;
use crate::dataflow_const_prop::DummyMachine;

pub struct JumpThreading;

Expand All @@ -70,6 +74,7 @@ impl<'tcx> MirPass<'tcx> for JumpThreading {
let mut finder = TOFinder {
tcx,
param_env,
ecx: InterpCx::new(tcx, DUMMY_SP, param_env, DummyMachine),
body,
arena: &arena,
map: &map,
Expand Down Expand Up @@ -141,6 +146,7 @@ struct ThreadingOpportunity {
struct TOFinder<'tcx, 'a> {
tcx: TyCtxt<'tcx>,
param_env: ty::ParamEnv<'tcx>,
ecx: InterpCx<'tcx, 'tcx, DummyMachine>,
body: &'a Body<'tcx>,
map: &'a Map,
loop_headers: &'a BitSet<BasicBlock>,
Expand Down Expand Up @@ -328,25 +334,75 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> {
}

#[instrument(level = "trace", skip(self))]
fn process_operand(
fn process_immediate(
&mut self,
bb: BasicBlock,
lhs: PlaceIndex,
rhs: &Operand<'tcx>,
rhs: ImmTy<'tcx>,
state: &mut State<ConditionSet<'a>>,
) -> Option<!> {
let register_opportunity = |c: Condition| {
debug!(?bb, ?c.target, "register");
self.opportunities.push(ThreadingOpportunity { chain: vec![bb], target: c.target })
};

let conditions = state.try_get_idx(lhs, self.map)?;
if let Immediate::Scalar(Scalar::Int(int)) = *rhs {
conditions.iter_matches(int).for_each(register_opportunity);
}

None
}

#[instrument(level = "trace", skip(self))]
fn process_operand(
&mut self,
bb: BasicBlock,
lhs: PlaceIndex,
rhs: &Operand<'tcx>,
state: &mut State<ConditionSet<'a>>,
) -> Option<!> {
match rhs {
// If we expect `lhs ?= A`, we have an opportunity if we assume `constant == A`.
Operand::Constant(constant) => {
let conditions = state.try_get_idx(lhs, self.map)?;
let constant =
constant.const_.normalize(self.tcx, self.param_env).try_to_scalar_int()?;
conditions.iter_matches(constant).for_each(register_opportunity);
let constant = self.ecx.eval_mir_constant(&constant.const_, None, None).ok()?;
self.map.for_each_projection_value(
lhs,
constant,
&mut |elem, op| match elem {
TrackElem::Field(idx) => self.ecx.project_field(op, idx.as_usize()).ok(),
TrackElem::Variant(idx) => self.ecx.project_downcast(op, idx).ok(),
TrackElem::Discriminant => {
let variant = self.ecx.read_discriminant(op).ok()?;
let discr_value =
self.ecx.discriminant_for_variant(op.layout.ty, variant).ok()?;
Some(discr_value.into())
}
TrackElem::DerefLen => {
let op: OpTy<'_> = self.ecx.deref_pointer(op).ok()?.into();
let len_usize = op.len(&self.ecx).ok()?;
let layout = self
.tcx
.layout_of(self.param_env.and(self.tcx.types.usize))
.unwrap();
Some(ImmTy::from_uint(len_usize, layout).into())
}
},
&mut |place, op| {
if let Some(conditions) = state.try_get_idx(place, self.map)
&& let Ok(imm) = self.ecx.read_immediate_raw(op)
&& let Some(imm) = imm.right()
&& let Immediate::Scalar(Scalar::Int(int)) = *imm
{
conditions.iter_matches(int).for_each(|c: Condition| {
self.opportunities.push(ThreadingOpportunity {
chain: vec![bb],
target: c.target,
})
})
}
},
);
}
// Transfer the conditions on the copied rhs.
Operand::Move(rhs) | Operand::Copy(rhs) => {
Expand All @@ -373,26 +429,14 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> {
// Below, `lhs` is the return value of `mutated_statement`,
// the place to which `conditions` apply.

let discriminant_for_variant = |enum_ty: Ty<'tcx>, variant_index| {
let discr = enum_ty.discriminant_for_variant(self.tcx, variant_index)?;
let discr_layout = self.tcx.layout_of(self.param_env.and(discr.ty)).ok()?;
let scalar = ScalarInt::try_from_uint(discr.val, discr_layout.size)?;
Some(Operand::const_from_scalar(
self.tcx,
discr.ty,
scalar.into(),
rustc_span::DUMMY_SP,
))
};

match &stmt.kind {
// If we expect `discriminant(place) ?= A`,
// we have an opportunity if `variant_index ?= A`.
StatementKind::SetDiscriminant { box place, variant_index } => {
let discr_target = self.map.find_discr(place.as_ref())?;
let enum_ty = place.ty(self.body, self.tcx).ty;
let discr = discriminant_for_variant(enum_ty, *variant_index)?;
self.process_operand(bb, discr_target, &discr, state)?;
let discr = self.ecx.discriminant_for_variant(enum_ty, *variant_index).ok()?;
self.process_immediate(bb, discr_target, discr, state)?;
}
// If we expect `lhs ?= true`, we have an opportunity if we assume `lhs == true`.
StatementKind::Intrinsic(box NonDivergingIntrinsic::Assume(
Expand Down Expand Up @@ -422,10 +466,16 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> {
AggregateKind::Adt(_, variant_index, ..) if agg_ty.is_enum() => {
if let Some(discr_target) =
self.map.apply(lhs, TrackElem::Discriminant)
&& let Some(discr_value) =
discriminant_for_variant(agg_ty, *variant_index)
&& let Ok(discr_value) = self
.ecx
.discriminant_for_variant(agg_ty, *variant_index)
{
self.process_operand(bb, discr_target, &discr_value, state);
self.process_immediate(
bb,
discr_target,
discr_value,
state,
);
}
self.map.apply(lhs, TrackElem::Variant(*variant_index))?
}
Expand Down
8 changes: 4 additions & 4 deletions tests/coverage/partial_eq.cov-map
Original file line number Diff line number Diff line change
Expand Up @@ -25,18 +25,18 @@ Number of file 0 mappings: 2
- Code(Zero) at (prev + 0, 32) to (start + 0, 33)

Function name: <partial_eq::Version as core::cmp::PartialOrd>::partial_cmp
Raw bytes (22): 0x[01, 01, 04, 07, 0b, 00, 09, 0f, 15, 00, 11, 02, 01, 04, 27, 00, 28, 03, 00, 30, 00, 31]
Raw bytes (22): 0x[01, 01, 04, 07, 0b, 05, 09, 0f, 15, 0d, 11, 02, 01, 04, 27, 00, 28, 03, 00, 30, 00, 31]
Number of files: 1
- file 0 => global file 1
Number of expressions: 4
- expression 0 operands: lhs = Expression(1, Add), rhs = Expression(2, Add)
- expression 1 operands: lhs = Zero, rhs = Counter(2)
- expression 1 operands: lhs = Counter(1), rhs = Counter(2)
- expression 2 operands: lhs = Expression(3, Add), rhs = Counter(5)
- expression 3 operands: lhs = Zero, rhs = Counter(4)
- expression 3 operands: lhs = Counter(3), rhs = Counter(4)
Number of file 0 mappings: 2
- Code(Counter(0)) at (prev + 4, 39) to (start + 0, 40)
- Code(Expression(0, Add)) at (prev + 0, 48) to (start + 0, 49)
= ((Zero + c2) + ((Zero + c4) + c5))
= ((c1 + c2) + ((c3 + c4) + c5))

Function name: <partial_eq::Version as core::fmt::Debug>::fmt
Raw bytes (9): 0x[01, 01, 00, 01, 01, 04, 11, 00, 16]
Expand Down
Loading

0 comments on commit c3b8306

Please sign in to comment.