Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Various fixes and improvements to the MIR Dataflow framework #1

Open
wants to merge 14 commits into
base: mir-dflow
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 47 additions & 86 deletions src/librustc/mir/transform/dataflow.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,10 @@
use mir::repr as mir;
use mir::cfg::CFG;
use mir::repr::BasicBlock;
use mir::repr::{BasicBlock, START_BLOCK};
use rustc_data_structures::bitvec::BitVector;

use mir::transform::lattice::Lattice;


pub trait DataflowPass<'tcx> {
type Lattice: Lattice;
type Rewrite: Rewrite<'tcx, Self::Lattice>;
type Transfer: Transfer<'tcx, Self::Lattice>;
}

pub trait Rewrite<'tcx, L: Lattice> {
/// The rewrite function which given a statement optionally produces an alternative graph to be
/// placed in place of the original statement.
Expand All @@ -23,7 +16,7 @@ pub trait Rewrite<'tcx, L: Lattice> {
/// that is, given some fact `fact` true before both the statement and relacement graph, and
/// a fact `fact2` which is true after the statement, the same `fact2` must be true after the
/// replacement graph too.
fn stmt(&mir::Statement<'tcx>, &L, &mut CFG<'tcx>) -> StatementChange<'tcx>;
fn stmt(&self, &mir::Statement<'tcx>, &L, &mut CFG<'tcx>) -> StatementChange<'tcx>;

/// The rewrite function which given a terminator optionally produces an alternative graph to
/// be placed in place of the original statement.
Expand All @@ -35,49 +28,39 @@ pub trait Rewrite<'tcx, L: Lattice> {
/// that is, given some fact `fact` true before both the terminator and relacement graph, and
/// a fact `fact2` which is true after the statement, the same `fact2` must be true after the
/// replacement graph too.
fn term(&mir::Terminator<'tcx>, &L, &mut CFG<'tcx>) -> TerminatorChange<'tcx>;
fn term(&self, &mir::Terminator<'tcx>, &L, &mut CFG<'tcx>) -> TerminatorChange<'tcx>;

fn and_then<R2>(self, other: R2) -> RewriteAndThen<Self, R2> where Self: Sized {
RewriteAndThen(self, other)
}
}

/// This combinator has the following behaviour:
///
/// * Rewrite the node with the first rewriter.
/// * if the first rewriter replaced the node, 2nd rewriter is used to rewrite the replacement.
/// * otherwise 2nd rewriter is used to rewrite the original node.
pub struct RewriteAndThen<'tcx, R1, R2>(::std::marker::PhantomData<(&'tcx (), R1, R2)>);
impl<'tcx, L, R1, R2> Rewrite<'tcx, L> for RewriteAndThen<'tcx, R1, R2>
pub struct RewriteAndThen<R1, R2>(R1, R2);
impl<'tcx, L, R1, R2> Rewrite<'tcx, L> for RewriteAndThen<R1, R2>
where L: Lattice, R1: Rewrite<'tcx, L>, R2: Rewrite<'tcx, L> {
fn stmt(s: &mir::Statement<'tcx>, l: &L, c: &mut CFG<'tcx>) -> StatementChange<'tcx> {
let rs = <R1 as Rewrite<L>>::stmt(s, l, c);
fn stmt(&self, s: &mir::Statement<'tcx>, l: &L, c: &mut CFG<'tcx>) -> StatementChange<'tcx> {
let rs = self.0.stmt(s, l, c);
match rs {
StatementChange::None => <R2 as Rewrite<L>>::stmt(s, l, c),
StatementChange::None => self.1.stmt(s, l, c),
StatementChange::Remove => StatementChange::Remove,
StatementChange::Statement(ns) =>
match <R2 as Rewrite<L>>::stmt(&ns, l, c) {
match self.1.stmt(&ns, l, c) {
StatementChange::None => StatementChange::Statement(ns),
x => x
},
StatementChange::Statements(nss) => {
// We expect the common case of all statements in this vector being replaced/not
// replaced by other statements 1:1
let mut new_new_stmts = Vec::with_capacity(nss.len());
for s in nss {
match <R2 as Rewrite<L>>::stmt(&s, l, c) {
StatementChange::None => new_new_stmts.push(s),
StatementChange::Remove => {},
StatementChange::Statement(ns) => new_new_stmts.push(ns),
StatementChange::Statements(nss) => new_new_stmts.extend(nss)
}
}
StatementChange::Statements(new_new_stmts)
}
}
}

fn term(t: &mir::Terminator<'tcx>, l: &L, c: &mut CFG<'tcx>) -> TerminatorChange<'tcx> {
let rt = <R1 as Rewrite<L>>::term(t, l, c);
fn term(&self, t: &mir::Terminator<'tcx>, l: &L, c: &mut CFG<'tcx>) -> TerminatorChange<'tcx> {
let rt = self.0.term(t, l, c);
match rt {
TerminatorChange::None => <R2 as Rewrite<L>>::term(t, l, c),
TerminatorChange::Terminator(nt) => match <R2 as Rewrite<L>>::term(&nt, l, c) {
TerminatorChange::None => self.1.term(t, l, c),
TerminatorChange::Terminator(nt) => match self.1.term(&nt, l, c) {
TerminatorChange::None => TerminatorChange::Terminator(nt),
x => x
}
Expand All @@ -99,39 +82,22 @@ pub enum StatementChange<'tcx> {
Remove,
/// Replace with another single statement
Statement(mir::Statement<'tcx>),
/// Replace with a list of statements
Statements(Vec<mir::Statement<'tcx>>),
}

impl<'tcx> StatementChange<'tcx> {
fn normalise(&mut self) {
let old = ::std::mem::replace(self, StatementChange::None);
*self = match old {
StatementChange::Statements(mut stmts) => {
match stmts.len() {
0 => StatementChange::Remove,
1 => StatementChange::Statement(stmts.pop().unwrap()),
_ => StatementChange::Statements(stmts)
}
}
o => o
}
}
}
pub trait Transfer<'tcx> {
type Lattice: Lattice;

pub trait Transfer<'tcx, L: Lattice> {
type TerminatorOut;
/// The transfer function which given a statement and a fact produces a fact which is true
/// after the statement.
fn stmt(&mir::Statement<'tcx>, L) -> L;
fn stmt(&self, &mir::Statement<'tcx>, Self::Lattice) -> Self::Lattice;

/// The transfer function which given a terminator and a fact produces a fact for each
/// successor of the terminator.
///
/// Corectness precondtition:
/// * The list of facts produced should only contain the facts for blocks which are successors
/// of the terminator being transfered.
fn term(&mir::Terminator<'tcx>, L) -> Self::TerminatorOut;
fn term(&self, &mir::Terminator<'tcx>, Self::Lattice) -> Vec<Self::Lattice>;
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change makes the trait not very reusable for backward analysis, where terminators only have a single edge regardless of how many edges the terminator has in the forward direction. That is the primary motivation for TerminatorOut associated type.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

True, but I'm not sure that the same trait should be reused for the forward and backwards cases. Regardless, even if the two directions share a trait, I'd prefer a clearer way of denoting which direction a pass goes than using Vec<Lattice> for forward and Lattice for backwards. Transfer could have another argument (Direction) which itself has an associated type denoting what output term should have.

I got rid of the associated TerminatorOut because it felt clumsy, it cause recursive trait bounds when I moved in the associated lattice type, and because it wasn't necessary yet. When backwards transformation gets implemented, I'd be happy for this to be generalized.

}


Expand Down Expand Up @@ -160,20 +126,22 @@ impl<F: Lattice> ::std::ops::Index<BasicBlock> for Facts<F> {

impl<F: Lattice> ::std::ops::IndexMut<BasicBlock> for Facts<F> {
fn index_mut(&mut self, index: BasicBlock) -> &mut F {
if let None = self.0.get_mut(index.index()) {
if self.0.get(index.index()).is_none() {
self.put(index, <F as Lattice>::bottom());
}
self.0.get_mut(index.index()).unwrap()
}
}

/// Analyse and rewrite using dataflow in the forward direction
pub fn ar_forward<'tcx, T, P>(cfg: &CFG<'tcx>, fs: Facts<P::Lattice>, mut queue: BitVector)
-> (CFG<'tcx>, Facts<P::Lattice>)
// FIXME: shouldn’t need that T generic.
where T: Transfer<'tcx, P::Lattice, TerminatorOut=Vec<P::Lattice>>,
P: DataflowPass<'tcx, Transfer=T>
pub fn ar_forward<'tcx, T, R>(cfg: &CFG<'tcx>, fs: Facts<T::Lattice>, transfer: T, rewrite: R)
-> (CFG<'tcx>, Facts<T::Lattice>)
where T: Transfer<'tcx>,
R: Rewrite<'tcx, T::Lattice>
{
let mut queue = BitVector::new(cfg.len());
queue.insert(START_BLOCK.index());
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similarly here, I do not believe that beginning at the start block is always correct/desired/etc.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you give an example of when another location would be desired? For a forwards pass, the start block seems like a natural place to start, and allows for a number of blocks to be transformed only once - if the CFG is acyclic, fixpoint iteration is unnecessary starting at the start block.


fixpoint(cfg, Direction::Forward, |bb, fact, cfg| {
let new_graph = cfg.start_new_block();
let mut fact = fact.clone();
Expand All @@ -183,33 +151,25 @@ where T: Transfer<'tcx, P::Lattice, TerminatorOut=Vec<P::Lattice>>,
for stmt in &old_statements {
// Given a fact and statement produce a new fact and optionally a replacement
// graph.
let mut new_repl = P::Rewrite::stmt(&stmt, &fact, cfg);
new_repl.normalise();
match new_repl {
match rewrite.stmt(&stmt, &fact, cfg) {
StatementChange::None => {
fact = P::Transfer::stmt(stmt, fact);
fact = transfer.stmt(stmt, fact);
cfg.push(new_graph, stmt.clone());
}
StatementChange::Remove => changed = true,
StatementChange::Statement(stmt) => {
changed = true;
fact = P::Transfer::stmt(&stmt, fact);
fact = transfer.stmt(&stmt, fact);
cfg.push(new_graph, stmt);
}
StatementChange::Statements(stmts) => {
changed = true;
for stmt in &stmts { fact = P::Transfer::stmt(stmt, fact); }
cfg[new_graph].statements.extend(stmts);
}

}
}
// Swap the statements back in.
::std::mem::replace(&mut cfg[bb].statements, old_statements);
cfg[bb].statements = old_statements;

// Handle the terminator replacement and transfer.
let terminator = ::std::mem::replace(&mut cfg[bb].terminator, None).unwrap();
let repl = P::Rewrite::term(&terminator, &fact, cfg);
let terminator = cfg[bb].terminator.take().unwrap();
let repl = rewrite.term(&terminator, &fact, cfg);
match repl {
TerminatorChange::None => {
cfg[new_graph].terminator = Some(terminator.clone());
Expand All @@ -219,8 +179,8 @@ where T: Transfer<'tcx, P::Lattice, TerminatorOut=Vec<P::Lattice>>,
cfg[new_graph].terminator = Some(t);
}
}
let new_facts = P::Transfer::term(cfg[new_graph].terminator(), fact);
::std::mem::replace(&mut cfg[bb].terminator, Some(terminator));
let new_facts = transfer.term(cfg[new_graph].terminator(), fact);
cfg[bb].terminator = Some(terminator);

(if changed { Some(new_graph) } else { None }, new_facts)
}, &mut queue, fs)
Expand Down Expand Up @@ -350,7 +310,7 @@ where T: Transfer<'tcx, P::Lattice, TerminatorOut=Vec<P::Lattice>>,

enum Direction {
Forward,
Backward
// Backward
}

/// The fixpoint function is the engine of this whole thing. Important part of it is the `f: BF`
Expand Down Expand Up @@ -389,16 +349,17 @@ where BF: Fn(BasicBlock, &F, &mut CFG<'tcx>) -> (Option<BasicBlock>, Vec<F>),
}

// Then we record the facts in the correct direction.
if let Direction::Forward = direction {
for (f, &target) in new_facts.into_iter()
.zip(cfg[block].terminator().successors().iter()) {
let facts_changed = Lattice::join(&mut init_facts[target], &f);
if facts_changed {
to_visit.insert(target.index());
match direction {
Direction::Forward => {
for (f, &target) in new_facts.into_iter()
.zip(cfg[block].terminator().successors().iter()) {
let facts_changed = Lattice::join(&mut init_facts[target], &f);
if facts_changed {
to_visit.insert(target.index());
}
}
}
} else {
unimplemented!()
// Direction::Backward => unimplemented!()
// let mut new_facts = new_facts;
// let fact = new_facts.pop().unwrap().1;
// for pred in cfg.predecessors(block) {
Expand Down
2 changes: 1 addition & 1 deletion src/librustc/mir/transform/lattice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ impl<T: Lattice> Lattice for WTop<T> {
/// ⊤ + V = ⊤ (no change)
/// V + ⊤ = ⊤
/// ⊤ + ⊤ = ⊤ (no change)
default fn join(&mut self, other: &Self) -> bool {
fn join(&mut self, other: &Self) -> bool {
match (self, other) {
(&mut WTop::Value(ref mut this), &WTop::Value(ref o)) => <T as Lattice>::join(this, o),
(&mut WTop::Top, _) => false,
Expand Down
4 changes: 2 additions & 2 deletions src/librustc_driver/driver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -939,8 +939,8 @@ pub fn phase_3_run_analysis_passes<'tcx, F, R>(sess: &'tcx Session,
passes.push_pass(box mir::transform::remove_dead_blocks::RemoveDeadBlocks);
passes.push_pass(box mir::transform::qualify_consts::QualifyAndPromoteConstants);
passes.push_pass(box mir::transform::type_check::TypeckMir);
passes.push_pass(box mir::transform::acs_propagate::ACSPropagate);
// passes.push_pass(box mir::transform::simplify_cfg::SimplifyCfg);
passes.push_pass(box mir::transform::acs_propagate::AcsPropagate);
passes.push_pass(box mir::transform::simplify_cfg::SimplifyCfg);
passes.push_pass(box mir::transform::remove_dead_blocks::RemoveDeadBlocks);
// And run everything.
passes.run_passes(tcx, &mut mir_map);
Expand Down
Loading