Skip to content

Commit

Permalink
Auto merge of rust-lang#35348 - scottcarr:discriminant2, r=nikomatsakis
Browse files Browse the repository at this point in the history
[MIR] Add explicit SetDiscriminant StatementKind for deaggregating enums

cc rust-lang#35186

To deaggregate enums, we need to be able to explicitly set the discriminant.  This PR implements a new StatementKind that does that.

I think some of the places that have `panics!` now could maybe do something smarter.
  • Loading branch information
bors authored Aug 13, 2016
2 parents d3c3de8 + d77a136 commit e64f688
Show file tree
Hide file tree
Showing 12 changed files with 156 additions and 18 deletions.
6 changes: 5 additions & 1 deletion src/librustc/mir/repr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -689,13 +689,17 @@ pub struct Statement<'tcx> {
#[derive(Clone, Debug, RustcEncodable, RustcDecodable)]
pub enum StatementKind<'tcx> {
Assign(Lvalue<'tcx>, Rvalue<'tcx>),
SetDiscriminant{ lvalue: Lvalue<'tcx>, variant_index: usize },
}

impl<'tcx> Debug for Statement<'tcx> {
fn fmt(&self, fmt: &mut Formatter) -> fmt::Result {
use self::StatementKind::*;
match self.kind {
Assign(ref lv, ref rv) => write!(fmt, "{:?} = {:?}", lv, rv)
Assign(ref lv, ref rv) => write!(fmt, "{:?} = {:?}", lv, rv),
SetDiscriminant{lvalue: ref lv, variant_index: index} => {
write!(fmt, "discriminant({:?}) = {:?}", lv, index)
}
}
}
}
Expand Down
3 changes: 3 additions & 0 deletions src/librustc/mir/visit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,9 @@ macro_rules! make_mir_visitor {
ref $($mutability)* rvalue) => {
self.visit_assign(block, lvalue, rvalue);
}
StatementKind::SetDiscriminant{ ref $($mutability)* lvalue, .. } => {
self.visit_lvalue(lvalue, LvalueContext::Store);
}
}
}

Expand Down
3 changes: 3 additions & 0 deletions src/librustc_borrowck/borrowck/mir/dataflow/impls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,9 @@ impl<'a, 'tcx> BitDenotation for MovingOutStatements<'a, 'tcx> {
}
let bits_per_block = self.bits_per_block(ctxt);
match stmt.kind {
repr::StatementKind::SetDiscriminant { .. } => {
span_bug!(stmt.source_info.span, "SetDiscriminant should not exist in borrowck");
}
repr::StatementKind::Assign(ref lvalue, _) => {
// assigning into this `lvalue` kills all
// MoveOuts from it, and *also* all MoveOuts
Expand Down
3 changes: 3 additions & 0 deletions src/librustc_borrowck/borrowck/mir/dataflow/sanity_check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,9 @@ fn each_block<'a, 'tcx, O>(tcx: TyCtxt<'a, 'tcx, 'tcx>,
repr::StatementKind::Assign(ref lvalue, ref rvalue) => {
(lvalue, rvalue)
}
repr::StatementKind::SetDiscriminant{ .. } =>
span_bug!(stmt.source_info.span,
"sanity_check should run before Deaggregator inserts SetDiscriminant"),
};

if lvalue == peek_arg_lval {
Expand Down
4 changes: 4 additions & 0 deletions src/librustc_borrowck/borrowck/mir/gather_moves.rs
Original file line number Diff line number Diff line change
Expand Up @@ -616,6 +616,10 @@ fn gather_moves<'a, 'tcx>(mir: &Mir<'tcx>, tcx: TyCtxt<'a, 'tcx, 'tcx>) -> MoveD
Rvalue::InlineAsm { .. } => {}
}
}
StatementKind::SetDiscriminant{ .. } => {
span_bug!(stmt.source_info.span,
"SetDiscriminant should not exist during borrowck");
}
}
}

Expand Down
3 changes: 3 additions & 0 deletions src/librustc_borrowck/borrowck/mir/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,9 @@ fn drop_flag_effects_for_location<'a, 'tcx, F>(
let block = &mir[loc.block];
match block.statements.get(loc.index) {
Some(stmt) => match stmt.kind {
repr::StatementKind::SetDiscriminant{ .. } => {
span_bug!(stmt.source_info.span, "SetDiscrimant should not exist during borrowck");
}
repr::StatementKind::Assign(ref lvalue, _) => {
debug!("drop_flag_effects: assignment {:?}", stmt);
on_all_children_bits(tcx, mir, move_data,
Expand Down
45 changes: 34 additions & 11 deletions src/librustc_mir/transform/deaggregator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ impl<'tcx> MirPass<'tcx> for Deaggregator {

let mut curr: usize = 0;
for bb in mir.basic_blocks_mut() {
let idx = match get_aggregate_statement(curr, &bb.statements) {
let idx = match get_aggregate_statement_index(curr, &bb.statements) {
Some(idx) => idx,
None => continue,
};
Expand All @@ -48,7 +48,11 @@ impl<'tcx> MirPass<'tcx> for Deaggregator {
let src_info = bb.statements[idx].source_info;
let suffix_stmts = bb.statements.split_off(idx+1);
let orig_stmt = bb.statements.pop().unwrap();
let StatementKind::Assign(ref lhs, ref rhs) = orig_stmt.kind;
let (lhs, rhs) = match orig_stmt.kind {
StatementKind::Assign(ref lhs, ref rhs) => (lhs, rhs),
StatementKind::SetDiscriminant{ .. } =>
span_bug!(src_info.span, "expected aggregate, not {:?}", orig_stmt.kind),
};
let (agg_kind, operands) = match rhs {
&Rvalue::Aggregate(ref agg_kind, ref operands) => (agg_kind, operands),
_ => span_bug!(src_info.span, "expected aggregate, not {:?}", rhs),
Expand All @@ -64,10 +68,14 @@ impl<'tcx> MirPass<'tcx> for Deaggregator {
let ty = variant_def.fields[i].ty(tcx, substs);
let rhs = Rvalue::Use(op.clone());

// since we don't handle enums, we don't need a cast
let lhs_cast = lhs.clone();

// FIXME we cannot deaggregate enums issue: #35186
let lhs_cast = if adt_def.variants.len() > 1 {
Lvalue::Projection(Box::new(LvalueProjection {
base: lhs.clone(),
elem: ProjectionElem::Downcast(adt_def, variant),
}))
} else {
lhs.clone()
};

let lhs_proj = Lvalue::Projection(Box::new(LvalueProjection {
base: lhs_cast,
Expand All @@ -80,18 +88,34 @@ impl<'tcx> MirPass<'tcx> for Deaggregator {
debug!("inserting: {:?} @ {:?}", new_statement, idx + i);
bb.statements.push(new_statement);
}

// if the aggregate was an enum, we need to set the discriminant
if adt_def.variants.len() > 1 {
let set_discriminant = Statement {
kind: StatementKind::SetDiscriminant {
lvalue: lhs.clone(),
variant_index: variant,
},
source_info: src_info,
};
bb.statements.push(set_discriminant);
};

curr = bb.statements.len();
bb.statements.extend(suffix_stmts);
}
}
}

fn get_aggregate_statement<'a, 'tcx, 'b>(curr: usize,
fn get_aggregate_statement_index<'a, 'tcx, 'b>(start: usize,
statements: &Vec<Statement<'tcx>>)
-> Option<usize> {
for i in curr..statements.len() {
for i in start..statements.len() {
let ref statement = statements[i];
let StatementKind::Assign(_, ref rhs) = statement.kind;
let rhs = match statement.kind {
StatementKind::Assign(_, ref rhs) => rhs,
StatementKind::SetDiscriminant{ .. } => continue,
};
let (kind, operands) = match rhs {
&Rvalue::Aggregate(ref kind, ref operands) => (kind, operands),
_ => continue,
Expand All @@ -100,9 +124,8 @@ fn get_aggregate_statement<'a, 'tcx, 'b>(curr: usize,
&AggregateKind::Adt(adt_def, variant, _) => (adt_def, variant),
_ => continue,
};
if operands.len() == 0 || adt_def.variants.len() > 1 {
if operands.len() == 0 {
// don't deaggregate ()
// don't deaggregate enums ... for now
continue;
}
debug!("getting variant {:?}", variant);
Expand Down
22 changes: 19 additions & 3 deletions src/librustc_mir/transform/promote_consts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,13 @@ impl<'a, 'tcx> Promoter<'a, 'tcx> {
let (mut rvalue, mut call) = (None, None);
let source_info = if stmt_idx < no_stmts {
let statement = &mut self.source[bb].statements[stmt_idx];
let StatementKind::Assign(_, ref mut rhs) = statement.kind;
let mut rhs = match statement.kind {
StatementKind::Assign(_, ref mut rhs) => rhs,
StatementKind::SetDiscriminant{ .. } =>
span_bug!(statement.source_info.span,
"cannot promote SetDiscriminant {:?}",
statement),
};
if self.keep_original {
rvalue = Some(rhs.clone());
} else {
Expand Down Expand Up @@ -300,10 +306,16 @@ impl<'a, 'tcx> Promoter<'a, 'tcx> {
});
let mut rvalue = match candidate {
Candidate::Ref(Location { block: bb, statement_index: stmt_idx }) => {
match self.source[bb].statements[stmt_idx].kind {
let ref mut statement = self.source[bb].statements[stmt_idx];
match statement.kind {
StatementKind::Assign(_, ref mut rvalue) => {
mem::replace(rvalue, Rvalue::Use(new_operand))
}
StatementKind::SetDiscriminant{ .. } => {
span_bug!(statement.source_info.span,
"cannot promote SetDiscriminant {:?}",
statement);
}
}
}
Candidate::ShuffleIndices(bb) => {
Expand Down Expand Up @@ -340,7 +352,11 @@ pub fn promote_candidates<'a, 'tcx>(mir: &mut Mir<'tcx>,
let (span, ty) = match candidate {
Candidate::Ref(Location { block: bb, statement_index: stmt_idx }) => {
let statement = &mir[bb].statements[stmt_idx];
let StatementKind::Assign(ref dest, _) = statement.kind;
let dest = match statement.kind {
StatementKind::Assign(ref dest, _) => dest,
StatementKind::SetDiscriminant{ .. } =>
panic!("cannot promote SetDiscriminant"),
};
if let Lvalue::Temp(index) = *dest {
if temps[index] == TempState::PromotedOut {
// Already promoted.
Expand Down
23 changes: 20 additions & 3 deletions src/librustc_mir/transform/type_check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
use rustc::infer::{self, InferCtxt, InferOk};
use rustc::traits::{self, Reveal};
use rustc::ty::fold::TypeFoldable;
use rustc::ty::{self, Ty, TyCtxt};
use rustc::ty::{self, Ty, TyCtxt, TypeVariants};
use rustc::mir::repr::*;
use rustc::mir::tcx::LvalueTy;
use rustc::mir::transform::{MirPass, MirSource, Pass};
Expand Down Expand Up @@ -360,10 +360,27 @@ impl<'a, 'gcx, 'tcx> TypeChecker<'a, 'gcx, 'tcx> {
span_mirbug!(self, stmt, "bad assignment ({:?} = {:?}): {:?}",
lv_ty, rv_ty, terr);
}
}

// FIXME: rvalue with undeterminable type - e.g. inline
// asm.
}
}
StatementKind::SetDiscriminant{ ref lvalue, variant_index } => {
let lvalue_type = lvalue.ty(mir, tcx).to_ty(tcx);
let adt = match lvalue_type.sty {
TypeVariants::TyEnum(adt, _) => adt,
_ => {
span_bug!(stmt.source_info.span,
"bad set discriminant ({:?} = {:?}): lhs is not an enum",
lvalue,
variant_index);
}
};
if variant_index >= adt.variants.len() {
span_bug!(stmt.source_info.span,
"bad set discriminant ({:?} = {:?}): value of of range",
lvalue,
variant_index);
};
}
}
}
Expand Down
3 changes: 3 additions & 0 deletions src/librustc_trans/mir/constant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,9 @@ impl<'a, 'tcx> MirConstContext<'a, 'tcx> {
Err(err) => if failure.is_ok() { failure = Err(err); }
}
}
mir::StatementKind::SetDiscriminant{ .. } => {
span_bug!(span, "SetDiscriminant should not appear in constants?");
}
}
}

Expand Down
14 changes: 14 additions & 0 deletions src/librustc_trans/mir/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ use common::{self, BlockAndBuilder};

use super::MirContext;
use super::LocalRef;
use super::super::adt;
use super::super::disr::Disr;

impl<'bcx, 'tcx> MirContext<'bcx, 'tcx> {
pub fn trans_statement(&mut self,
Expand Down Expand Up @@ -57,6 +59,18 @@ impl<'bcx, 'tcx> MirContext<'bcx, 'tcx> {
self.trans_rvalue(bcx, tr_dest, rvalue, debug_loc)
}
}
mir::StatementKind::SetDiscriminant{ref lvalue, variant_index} => {
let ty = self.monomorphized_lvalue_ty(lvalue);
let repr = adt::represent_type(bcx.ccx(), ty);
let lvalue_transed = self.trans_lvalue(&bcx, lvalue);
bcx.with_block(|bcx|
adt::trans_set_discr(bcx,
&repr,
lvalue_transed.llval,
Disr::from(variant_index))
);
bcx
}
}
}
}
45 changes: 45 additions & 0 deletions src/test/mir-opt/deaggregator_test_enum.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
// Copyright 2016 The Rust Project Developers. See the COPYRIGHT
// file at the top-level directory of this distribution and at
// http://rust-lang.org/COPYRIGHT.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

enum Baz {
Empty,
Foo { x: usize },
}

fn bar(a: usize) -> Baz {
Baz::Foo { x: a }
}

fn main() {
let x = bar(10);
match x {
Baz::Empty => println!("empty"),
Baz::Foo { x } => println!("{}", x),
};
}

// END RUST SOURCE
// START rustc.node10.Deaggregator.before.mir
// bb0: {
// var0 = arg0; // scope 0 at main.rs:7:8: 7:9
// tmp0 = var0; // scope 1 at main.rs:8:19: 8:20
// return = Baz::Foo { x: tmp0 }; // scope 1 at main.rs:8:5: 8:21
// goto -> bb1; // scope 1 at main.rs:7:1: 9:2
// }
// END rustc.node10.Deaggregator.before.mir
// START rustc.node10.Deaggregator.after.mir
// bb0: {
// var0 = arg0; // scope 0 at main.rs:7:8: 7:9
// tmp0 = var0; // scope 1 at main.rs:8:19: 8:20
// ((return as Foo).0: usize) = tmp0; // scope 1 at main.rs:8:5: 8:21
// discriminant(return) = 1; // scope 1 at main.rs:8:5: 8:21
// goto -> bb1; // scope 1 at main.rs:7:1: 9:2
// }
// END rustc.node10.Deaggregator.after.mir

0 comments on commit e64f688

Please sign in to comment.