diff --git a/src/librustc_mir/transform/mod.rs b/src/librustc_mir/transform/mod.rs index 57e002bf3f3d6..a793124542c5f 100644 --- a/src/librustc_mir/transform/mod.rs +++ b/src/librustc_mir/transform/mod.rs @@ -311,12 +311,12 @@ fn run_optimization_passes<'tcx>( &const_prop::ConstProp, &simplify_branches::SimplifyBranches::new("after-const-prop"), &deaggregator::Deaggregator, + &simplify_try::SimplifyArmIdentity, + &simplify_try::SimplifyBranchSame, ©_prop::CopyPropagation, &simplify_branches::SimplifyBranches::new("after-copy-prop"), &remove_noop_landing_pads::RemoveNoopLandingPads, &simplify::SimplifyCfg::new("after-remove-noop-landing-pads"), - &simplify_try::SimplifyArmIdentity, - &simplify_try::SimplifyBranchSame, &simplify::SimplifyCfg::new("final"), &simplify::SimplifyLocals, &add_call_guards::CriticalCallEdges, diff --git a/src/librustc_mir/transform/simplify_try.rs b/src/librustc_mir/transform/simplify_try.rs index 3f28f033047a2..32cb2ffb542df 100644 --- a/src/librustc_mir/transform/simplify_try.rs +++ b/src/librustc_mir/transform/simplify_try.rs @@ -13,6 +13,7 @@ use crate::transform::{simplify, MirPass, MirSource}; use itertools::Itertools as _; use rustc::mir::*; use rustc::ty::{Ty, TyCtxt}; +use rustc_index::vec::IndexVec; use rustc_target::abi::VariantIdx; /// Simplifies arms of form `Variant(x) => Variant(x)` to just a move. @@ -21,7 +22,8 @@ use rustc_target::abi::VariantIdx; /// /// ```rust /// _LOCAL_TMP = ((_LOCAL_1 as Variant ).FIELD: TY ); -/// ((_LOCAL_0 as Variant).FIELD: TY) = move _LOCAL_TMP; +/// _TMP_2 = _LOCAL_TMP; +/// ((_LOCAL_0 as Variant).FIELD: TY) = move _TMP_2; /// discriminant(_LOCAL_0) = VAR_IDX; /// ``` /// @@ -32,50 +34,213 @@ use rustc_target::abi::VariantIdx; /// ``` pub struct SimplifyArmIdentity; +#[derive(Debug)] +struct ArmIdentityInfo<'tcx> { + /// Storage location for the variant's field + local_temp_0: Local, + /// Storage location holding the varient being read from + local_1: Local, + /// The varient field being read from + vf_s0: VarField<'tcx>, + + /// Tracks each assignment to a temporary of the varient's field + field_tmp_assignments: Vec<(Local, Local)>, + + /// Storage location holding the variant's field that was read from + local_tmp_s1: Local, + /// Storage location holding the enum that we are writing to + local_0: Local, + /// The varient field being written to + vf_s1: VarField<'tcx>, + + /// Storage location that the discrimentant is being set to + set_discr_local: Local, + /// The variant being written + set_discr_var_idx: VariantIdx, + + /// Index of the statement that should be overwritten as a move + stmt_to_overwrite: usize, + /// SourceInfo for the new move + source_info: SourceInfo, + + /// Indexes of matching Storage{Live,Dead} statements encountered. + /// (StorageLive index,, StorageDead index, Local) + storage_stmts: Vec<(usize, usize, Local)>, + + /// The statements that should be removed (turned into nops) + stmts_to_remove: Vec, +} + +fn get_arm_identity_info(stmts: &[Statement<'tcx>]) -> Option> { + let (mut local_tmp_s0, mut local_1, mut vf_s0) = (None, None, None); + let mut tmp_assigns = Vec::new(); + let (mut local_tmp_s1, mut local_0, mut vf_s1) = (None, None, None); + let (mut set_discr_local, mut set_discr_var_idx) = (None, None); + let mut starting_stmt = None; + let mut discr_stmt = None; + let mut nop_stmts = Vec::new(); + let mut storage_stmts = Vec::new(); + let mut storage_live_stmts = Vec::new(); + let mut storage_dead_stmts = Vec::new(); + + for (stmt_idx, stmt) in stmts.iter().enumerate() { + if let StatementKind::StorageLive(l) = stmt.kind { + storage_live_stmts.push((stmt_idx, l)); + continue; + } else if let StatementKind::StorageDead(l) = stmt.kind { + storage_dead_stmts.push((stmt_idx, l)); + continue; + } + + if local_tmp_s0 == None && local_1 == None && vf_s0 == None { + let result = match_get_variant_field(stmt)?; + local_tmp_s0 = Some(result.0); + local_1 = Some(result.1); + vf_s0 = Some(result.2); + starting_stmt = Some(stmt_idx); + } else if let StatementKind::Assign(box (place, Rvalue::Use(op))) = &stmt.kind { + if let Some(local) = place.as_local() { + if let Operand::Copy(p) | Operand::Move(p) = op { + tmp_assigns.push((local, p.as_local()?)); + nop_stmts.push(stmt_idx); + } else { + return None; + } + } else if local_tmp_s1 == None && local_0 == None && vf_s1 == None { + let result = match_set_variant_field(stmt)?; + local_tmp_s1 = Some(result.0); + local_0 = Some(result.1); + vf_s1 = Some(result.2); + nop_stmts.push(stmt_idx); + } + } else if set_discr_local == None && set_discr_var_idx == None { + let result = match_set_discr(stmt)?; + set_discr_local = Some(result.0); + set_discr_var_idx = Some(result.1); + discr_stmt = Some(stmt); + nop_stmts.push(stmt_idx); + } + } + + for (live_idx, live_local) in storage_live_stmts { + if let Some(i) = storage_dead_stmts.iter().rposition(|(_, l)| *l == live_local) { + let (dead_idx, _) = storage_dead_stmts.swap_remove(i); + storage_stmts.push((live_idx, dead_idx, live_local)); + } + } + + Some(ArmIdentityInfo { + local_temp_0: local_tmp_s0?, + local_1: local_1?, + vf_s0: vf_s0?, + field_tmp_assignments: tmp_assigns, + local_tmp_s1: local_tmp_s1?, + local_0: local_0?, + vf_s1: vf_s1?, + set_discr_local: set_discr_local?, + set_discr_var_idx: set_discr_var_idx?, + stmt_to_overwrite: starting_stmt?, + source_info: discr_stmt?.source_info, + storage_stmts: storage_stmts, + stmts_to_remove: nop_stmts, + }) +} + +fn optimization_applies<'tcx>(opt_info: &ArmIdentityInfo<'tcx>, local_decls: &IndexVec>) -> bool { + trace!("testing if optimization applies..."); + + if opt_info.local_0 == opt_info.local_1 { + trace!("NO: moving into ourselves"); + return false; + } else if opt_info.vf_s0 != opt_info.vf_s1 { + trace!("NO: the field-and-variant information do not match"); + return false; + } else if local_decls[opt_info.local_0].ty != local_decls[opt_info.local_1].ty { + // FIXME(Centril,oli-obk): possibly relax ot same layout? + trace!("NO: source and target locals have different types"); + return false; + } else if (opt_info.local_0, opt_info.vf_s0.var_idx) != (opt_info.set_discr_local, opt_info.set_discr_var_idx) { + trace!("NO: the discriminants do not match"); + return false; + } + + // Verify the assigment chain consists of the form b = a; c = b; d = c; etc... + if opt_info.field_tmp_assignments.len() == 0 { + trace!("NO: no assignments found"); + } + let mut last_assigned_to = opt_info.field_tmp_assignments[0].1; + let source_local = last_assigned_to; + for (l, r) in &opt_info.field_tmp_assignments { + if *r != last_assigned_to { + trace!("NO: found unexpected assignment {:?} = {:?}", l, r); + return false; + } + + last_assigned_to = *l; + } + + if source_local != opt_info.local_temp_0 { + trace!("NO: start of assignment chain does not match enum variant temp: {:?} != {:?}", source_local, opt_info.local_temp_0); + return false; + } else if last_assigned_to != opt_info.local_tmp_s1 { + trace!("NO: end of assignemnt chain does not match written enum temp: {:?} != {:?}", last_assigned_to, opt_info.local_tmp_s1); + return false; + } + + trace!("SUCCESS: optimization applies!"); + return true; +} + impl<'tcx> MirPass<'tcx> for SimplifyArmIdentity { - fn run_pass(&self, _: TyCtxt<'tcx>, _: MirSource<'tcx>, body: &mut BodyAndCache<'tcx>) { + fn run_pass(&self, _: TyCtxt<'tcx>, source: MirSource<'tcx>, body: &mut BodyAndCache<'tcx>) { + trace!("running SimplifyArmIdentity on {:?}", source); let (basic_blocks, local_decls) = body.basic_blocks_and_local_decls_mut(); for bb in basic_blocks { - // Need 3 statements: - let (s0, s1, s2) = match &mut *bb.statements { - [s0, s1, s2] => (s0, s1, s2), - _ => continue, - }; + trace!("bb.len() = {:?}", bb.statements.len()); - // Pattern match on the form we want: - let (local_tmp_s0, local_1, vf_s0) = match match_get_variant_field(s0) { - None => continue, - Some(x) => x, - }; - let (local_tmp_s1, local_0, vf_s1) = match match_set_variant_field(s1) { - None => continue, - Some(x) => x, - }; - if local_tmp_s0 != local_tmp_s1 - // Avoid moving into ourselves. - || local_0 == local_1 - // The field-and-variant information match up. - || vf_s0 != vf_s1 - // Source and target locals have the same type. - // FIXME(Centril | oli-obk): possibly relax to same layout? - || local_decls[local_0].ty != local_decls[local_1].ty - // We're setting the discriminant of `local_0` to this variant. - || Some((local_0, vf_s0.var_idx)) != match_set_discr(s2) - { - continue; - } + if let Some(mut opt_info) = get_arm_identity_info(&bb.statements) { + trace!("got opt_info = {:#?}", opt_info); + if !optimization_applies(&opt_info, local_decls) { + debug!("skipping simplification!!!!!!!!!!!"); + continue; + } + + trace!("proceeding..."); + + //if tcx.sess.opts.debugging_opts.mir_opt_level <= 1 { + // continue; + //} - // Right shape; transform! - s0.source_info = s2.source_info; - match &mut s0.kind { - StatementKind::Assign(box (place, rvalue)) => { - *place = local_0.into(); - *rvalue = Rvalue::Use(Operand::Move(local_1.into())); + // Also remove unused Storage{Live,Dead} statements which correspond + // to temps used previously. + for (left, right) in opt_info.field_tmp_assignments { + for (live_idx, dead_idx, local) in &opt_info.storage_stmts { + if *local == left || *local == right { + opt_info.stmts_to_remove.push(*live_idx); + opt_info.stmts_to_remove.push(*dead_idx); + } + } } - _ => unreachable!(), + + // Right shape; transform! + let stmt = &mut bb.statements[opt_info.stmt_to_overwrite]; + stmt.source_info = opt_info.source_info; + match &mut stmt.kind { + StatementKind::Assign(box (place, rvalue)) => { + *place = opt_info.local_0.into(); + *rvalue = Rvalue::Use(Operand::Move(opt_info.local_1.into())); + } + _ => unreachable!(), + } + + for stmt_idx in opt_info.stmts_to_remove { + bb.statements[stmt_idx].make_nop(); + } + + bb.statements.retain(|stmt| stmt.kind != StatementKind::Nop); + + trace!("block is now {:?}", bb.statements); } - s1.make_nop(); - s2.make_nop(); } } } @@ -129,7 +294,7 @@ fn match_set_discr<'tcx>(stmt: &Statement<'tcx>) -> Option<(Local, VariantIdx)> } } -#[derive(PartialEq)] +#[derive(PartialEq, Debug)] struct VarField<'tcx> { field: Field, field_ty: Ty<'tcx>, diff --git a/src/test/mir-opt/simplify-arm-identity.rs b/src/test/mir-opt/simplify-arm-identity.rs index a8fa64255fb9a..814c1b7a546a7 100644 --- a/src/test/mir-opt/simplify-arm-identity.rs +++ b/src/test/mir-opt/simplify-arm-identity.rs @@ -39,9 +39,14 @@ fn main() { // } // ... // bb3: { +// StorageLive(_4); // _4 = ((_1 as Foo).0: u8); -// ((_2 as Foo).0: u8) = move _4; +// StorageLive(_5); +// _5 = _4; +// ((_2 as Foo).0: u8) = move _5; // discriminant(_2) = 0; +// StorageDead(_5); +// StorageDead(_4); // goto -> bb4; // } // ... @@ -65,9 +70,14 @@ fn main() { // } // ... // bb3: { +// StorageLive(_4); // _4 = ((_1 as Foo).0: u8); -// ((_2 as Foo).0: u8) = move _4; +// StorageLive(_5); +// _5 = _4; +// ((_2 as Foo).0: u8) = move _5; // discriminant(_2) = 0; +// StorageDead(_5); +// StorageDead(_4); // goto -> bb4; // } // ... diff --git a/src/test/mir-opt/simplify_try.rs b/src/test/mir-opt/simplify_try.rs index abac66d95c548..efc9af0d48c2c 100644 --- a/src/test/mir-opt/simplify_try.rs +++ b/src/test/mir-opt/simplify_try.rs @@ -23,16 +23,16 @@ fn main() { // let _10: u32; // let mut _11: u32; // scope 1 { -// debug y => _10; +// debug y => _2; // } // scope 2 { // debug err => _6; // scope 3 { // scope 7 { -// debug t => _6; +// debug t => _9; // } // scope 8 { -// debug v => _6; +// debug v => _8; // let mut _12: i32; // } // } @@ -43,22 +43,49 @@ fn main() { // } // } // scope 6 { -// debug self => _1; +// debug self => _4; // } // bb0: { -// _5 = discriminant(_1); +// StorageLive(_2); +// StorageLive(_3); +// StorageLive(_4); +// _4 = _1; +// _3 = move _4; +// StorageDead(_4); +// _5 = discriminant(_3); // switchInt(move _5) -> [0isize: bb1, otherwise: bb2]; // } // bb1: { -// _10 = ((_1 as Ok).0: u32); -// ((_0 as Ok).0: u32) = move _10; +// StorageLive(_10); +// _10 = ((_3 as Ok).0: u32); +// _2 = _10; +// StorageDead(_10); +// StorageDead(_3); +// StorageLive(_11); +// _11 = _2; +// ((_0 as Ok).0: u32) = move _11; // discriminant(_0) = 0; +// StorageDead(_11); +// StorageDead(_2); // goto -> bb3; // } // bb2: { -// _6 = ((_1 as Err).0: i32); -// ((_0 as Err).0: i32) = move _6; +// StorageLive(_6); +// _6 = ((_3 as Err).0: i32); +// StorageLive(_8); +// StorageLive(_9); +// _9 = _6; +// _8 = move _9; +// StorageDead(_9); +// StorageLive(_12); +// _12 = move _8; +// ((_0 as Err).0: i32) = move _12; // discriminant(_0) = 1; +// StorageDead(_12); +// StorageDead(_8); +// StorageDead(_6); +// StorageDead(_3); +// StorageDead(_2); // goto -> bb3; // } // bb3: { @@ -82,16 +109,16 @@ fn main() { // let _10: u32; // let mut _11: u32; // scope 1 { -// debug y => _10; +// debug y => _2; // } // scope 2 { // debug err => _6; // scope 3 { // scope 7 { -// debug t => _6; +// debug t => _9; // } // scope 8 { -// debug v => _6; +// debug v => _8; // let mut _12: i32; // } // } @@ -102,22 +129,28 @@ fn main() { // } // } // scope 6 { -// debug self => _1; +// debug self => _4; // } // bb0: { -// _5 = discriminant(_1); +// StorageLive(_2); +// StorageLive(_3); +// StorageLive(_4); +// _4 = _1; +// _3 = move _4; +// StorageDead(_4); +// _5 = discriminant(_3); // switchInt(move _5) -> [0isize: bb1, otherwise: bb2]; // } // bb1: { -// _0 = move _1; -// nop; -// nop; +// _0 = move _3; +// StorageDead(_3); +// StorageDead(_2); // goto -> bb3; // } // bb2: { -// _0 = move _1; -// nop; -// nop; +// _0 = move _3; +// StorageDead(_3); +// StorageDead(_2); // goto -> bb3; // } // bb3: { @@ -141,16 +174,16 @@ fn main() { // let _10: u32; // let mut _11: u32; // scope 1 { -// debug y => _10; +// debug y => _2; // } // scope 2 { // debug err => _6; // scope 3 { // scope 7 { -// debug t => _6; +// debug t => _9; // } // scope 8 { -// debug v => _6; +// debug v => _8; // let mut _12: i32; // } // } @@ -161,16 +194,22 @@ fn main() { // } // } // scope 6 { -// debug self => _1; +// debug self => _4; // } // bb0: { -// _5 = discriminant(_1); +// StorageLive(_2); +// StorageLive(_3); +// StorageLive(_4); +// _4 = _1; +// _3 = move _4; +// StorageDead(_4); +// _5 = discriminant(_3); // goto -> bb1; // } // bb1: { -// _0 = move _1; -// nop; -// nop; +// _0 = move _3; +// StorageDead(_3); +// StorageDead(_2); // goto -> bb2; // } // bb2: { @@ -183,25 +222,28 @@ fn main() { // fn try_identity(_1: std::result::Result) -> std::result::Result { // debug x => _1; // let mut _0: std::result::Result; -// let mut _2: isize; -// let _3: i32; -// let _4: u32; +// let _2: u32; +// let mut _3: isize; +// let _4: i32; +// let mut _5: i32; +// let mut _6: i32; +// let _7: u32; // scope 1 { -// debug y => _4; +// debug y => _2; // } // scope 2 { -// debug err => _3; +// debug err => _4; // scope 3 { // scope 7 { -// debug t => _3; +// debug t => _6; // } // scope 8 { -// debug v => _3; +// debug v => _5; // } // } // } // scope 4 { -// debug val => _4; +// debug val => _7; // scope 5 { // } // } @@ -209,8 +251,10 @@ fn main() { // debug self => _1; // } // bb0: { -// _2 = discriminant(_1); +// StorageLive(_2); +// _3 = discriminant(_1); // _0 = move _1; +// StorageDead(_2); // return; // } // }