Skip to content

Commit

Permalink
Auto merge of #24209 - nikomatsakis:refactor-unification, r=nrc
Browse files Browse the repository at this point in the history
I'm on a quest to slowly refactor a lot of the inference code. A first step for that is moving the "pure data structures" out so as to simplify what's left. This PR moves `snapshot_vec`, `graph`, and `unify` into their own crate (`librustc_data_structures`). They can then be unit-tested, benchmarked, etc more easily. As a benefit, I improved the performance of unification slightly on the benchmark I added vs the original code.

r? @nrc
  • Loading branch information
bors committed Apr 18, 2015
2 parents efa6a46 + e47fb48 commit 77213d1
Show file tree
Hide file tree
Showing 18 changed files with 782 additions and 382 deletions.
6 changes: 4 additions & 2 deletions mk/crates.mk
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ TARGET_CRATES := libc std flate arena term \
log graphviz core rbml alloc \
unicode rustc_bitflags
RUSTC_CRATES := rustc rustc_typeck rustc_borrowck rustc_resolve rustc_driver \
rustc_trans rustc_back rustc_llvm rustc_privacy rustc_lint
rustc_trans rustc_back rustc_llvm rustc_privacy rustc_lint \
rustc_data_structures
HOST_CRATES := syntax $(RUSTC_CRATES) rustdoc fmt_macros
CRATES := $(TARGET_CRATES) $(HOST_CRATES)
TOOLS := compiletest rustdoc rustc rustbook
Expand All @@ -80,9 +81,10 @@ DEPS_rustc_resolve := rustc log syntax
DEPS_rustc_privacy := rustc log syntax
DEPS_rustc_lint := rustc log syntax
DEPS_rustc := syntax flate arena serialize getopts rbml \
log graphviz rustc_llvm rustc_back
log graphviz rustc_llvm rustc_back rustc_data_structures
DEPS_rustc_llvm := native:rustllvm libc std
DEPS_rustc_back := std syntax rustc_llvm flate log libc
DEPS_rustc_data_structures := std log serialize
DEPS_rustdoc := rustc rustc_driver native:hoedown serialize getopts \
test rustc_lint
DEPS_rustc_bitflags := core
Expand Down
3 changes: 1 addition & 2 deletions src/librustc/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ extern crate graphviz;
extern crate libc;
extern crate rustc_llvm;
extern crate rustc_back;
extern crate rustc_data_structures;
extern crate serialize;
extern crate rbml;
extern crate collections;
Expand Down Expand Up @@ -103,7 +104,6 @@ pub mod middle {
pub mod entry;
pub mod expr_use_visitor;
pub mod fast_reject;
pub mod graph;
pub mod intrinsicck;
pub mod infer;
pub mod lang_items;
Expand Down Expand Up @@ -141,7 +141,6 @@ pub mod util {
pub mod common;
pub mod ppaux;
pub mod nodemap;
pub mod snapshot_vec;
pub mod lev_distance;
}

Expand Down
2 changes: 1 addition & 1 deletion src/librustc/middle/cfg/construct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
// option. This file may not be copied, modified, or distributed
// except according to those terms.

use rustc_data_structures::graph;
use middle::cfg::*;
use middle::def;
use middle::graph;
use middle::pat_util;
use middle::region::CodeExtent;
use middle::ty;
Expand Down
5 changes: 3 additions & 2 deletions src/librustc/middle/cfg/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
//! Module that constructs a control-flow graph representing an item.
//! Uses `Graph` as the underlying representation.
use middle::graph;
use rustc_data_structures::graph;
use middle::ty;
use syntax::ast;

Expand All @@ -24,7 +24,7 @@ pub struct CFG {
pub exit: CFGIndex,
}

#[derive(Copy, Clone, PartialEq)]
#[derive(Copy, Clone, Debug, PartialEq)]
pub enum CFGNodeData {
AST(ast::NodeId),
Entry,
Expand All @@ -43,6 +43,7 @@ impl CFGNodeData {
}
}

#[derive(Debug)]
pub struct CFGEdgeData {
pub exiting_scopes: Vec<ast::NodeId>
}
Expand Down
5 changes: 2 additions & 3 deletions src/librustc/middle/dataflow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -576,10 +576,9 @@ impl<'a, 'b, 'tcx, O:DataFlowOperator> PropagationContext<'a, 'b, 'tcx, O> {
pred_bits: &[usize],
cfg: &cfg::CFG,
cfgidx: CFGIndex) {
cfg.graph.each_outgoing_edge(cfgidx, |_e_idx, edge| {
for (_, edge) in cfg.graph.outgoing_edges(cfgidx) {
self.propagate_bits_into_entry_set_for(pred_bits, edge);
true
});
}
}

fn propagate_bits_into_entry_set_for(&mut self,
Expand Down
2 changes: 1 addition & 1 deletion src/librustc/middle/infer/freshen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ use middle::ty_fold::TypeFolder;
use std::collections::hash_map::{self, Entry};

use super::InferCtxt;
use super::unify::ToType;
use super::unify_key::ToType;

pub struct TypeFreshener<'a, 'tcx:'a> {
infcx: &'a InferCtxt<'a, 'tcx>,
Expand Down
5 changes: 3 additions & 2 deletions src/librustc/middle/infer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ use middle::ty::replace_late_bound_regions;
use middle::ty::{self, Ty};
use middle::ty_fold::{TypeFolder, TypeFoldable};
use middle::ty_relate::{Relate, RelateResult, TypeRelation};
use rustc_data_structures::unify::{self, UnificationTable};
use std::cell::{RefCell};
use std::fmt;
use std::rc::Rc;
Expand All @@ -41,8 +42,8 @@ use util::ppaux::{Repr, UserString};

use self::combine::CombineFields;
use self::region_inference::{RegionVarBindings, RegionSnapshot};
use self::unify::{ToType, UnificationTable};
use self::error_reporting::ErrorReporting;
use self::unify_key::ToType;

pub mod bivariate;
pub mod combine;
Expand All @@ -57,7 +58,7 @@ pub mod resolve;
mod freshen;
pub mod sub;
pub mod type_variable;
pub mod unify;
pub mod unify_key;

pub type Bound<T> = Option<T>;
pub type UnitResult<'tcx> = RelateResult<'tcx, ()>; // "unify result"
Expand Down
30 changes: 13 additions & 17 deletions src/librustc/middle/infer/region_inference/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,13 @@ use self::Classification::*;

use super::{RegionVariableOrigin, SubregionOrigin, TypeTrace, MiscVariable};

use rustc_data_structures::graph::{self, Direction, NodeIndex};
use middle::region;
use middle::ty::{self, Ty};
use middle::ty::{BoundRegion, FreeRegion, Region, RegionVid};
use middle::ty::{ReEmpty, ReStatic, ReInfer, ReFree, ReEarlyBound};
use middle::ty::{ReLateBound, ReScope, ReVar, ReSkolemized, BrFresh};
use middle::ty_relate::RelateResult;
use middle::graph;
use middle::graph::{Direction, NodeIndex};
use util::common::indenter;
use util::nodemap::{FnvHashMap, FnvHashSet};
use util::ppaux::{Repr, UserString};
Expand Down Expand Up @@ -1325,10 +1324,8 @@ impl<'a, 'tcx> RegionVarBindings<'a, 'tcx> {
let num_vars = self.num_vars();

let constraints = self.constraints.borrow();
let num_edges = constraints.len();

let mut graph = graph::Graph::with_capacity(num_vars as usize + 1,
num_edges);
let mut graph = graph::Graph::new();

for _ in 0..num_vars {
graph.add_node(());
Expand Down Expand Up @@ -1370,10 +1367,10 @@ impl<'a, 'tcx> RegionVarBindings<'a, 'tcx> {
// not contained by an upper-bound.
let (mut lower_bounds, lower_dup) =
self.collect_concrete_regions(graph, var_data, node_idx,
graph::Incoming, dup_vec);
graph::INCOMING, dup_vec);
let (mut upper_bounds, upper_dup) =
self.collect_concrete_regions(graph, var_data, node_idx,
graph::Outgoing, dup_vec);
graph::OUTGOING, dup_vec);

if lower_dup || upper_dup {
return;
Expand Down Expand Up @@ -1433,7 +1430,7 @@ impl<'a, 'tcx> RegionVarBindings<'a, 'tcx> {
// that have no intersection.
let (upper_bounds, dup_found) =
self.collect_concrete_regions(graph, var_data, node_idx,
graph::Outgoing, dup_vec);
graph::OUTGOING, dup_vec);

if dup_found {
return;
Expand Down Expand Up @@ -1508,8 +1505,8 @@ impl<'a, 'tcx> RegionVarBindings<'a, 'tcx> {
// figure out the direction from which this node takes its
// values, and search for concrete regions etc in that direction
let dir = match classification {
Expanding => graph::Incoming,
Contracting => graph::Outgoing,
Expanding => graph::INCOMING,
Contracting => graph::OUTGOING,
};

process_edges(self, &mut state, graph, node_idx, dir);
Expand All @@ -1519,14 +1516,14 @@ impl<'a, 'tcx> RegionVarBindings<'a, 'tcx> {
return (result, dup_found);

fn process_edges<'a, 'tcx>(this: &RegionVarBindings<'a, 'tcx>,
state: &mut WalkState<'tcx>,
graph: &RegionGraph,
source_vid: RegionVid,
dir: Direction) {
state: &mut WalkState<'tcx>,
graph: &RegionGraph,
source_vid: RegionVid,
dir: Direction) {
debug!("process_edges(source_vid={:?}, dir={:?})", source_vid, dir);

let source_node_index = NodeIndex(source_vid.index as usize);
graph.each_adjacent_edge(source_node_index, dir, |_, edge| {
for (_, edge) in graph.adjacent_edges(source_node_index, dir) {
match edge.data {
ConstrainVarSubVar(from_vid, to_vid) => {
let opp_vid =
Expand All @@ -1544,8 +1541,7 @@ impl<'a, 'tcx> RegionVarBindings<'a, 'tcx> {
});
}
}
true
});
}
}
}

Expand Down
8 changes: 3 additions & 5 deletions src/librustc/middle/infer/type_variable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use std::cmp::min;
use std::marker::PhantomData;
use std::mem;
use std::u32;
use util::snapshot_vec as sv;
use rustc_data_structures::snapshot_vec as sv;

pub struct TypeVariableTable<'tcx> {
values: sv::SnapshotVec<Delegate<'tcx>>,
Expand Down Expand Up @@ -65,7 +65,7 @@ impl RelationDir {

impl<'tcx> TypeVariableTable<'tcx> {
pub fn new() -> TypeVariableTable<'tcx> {
TypeVariableTable { values: sv::SnapshotVec::new(Delegate(PhantomData)) }
TypeVariableTable { values: sv::SnapshotVec::new() }
}

fn relations<'a>(&'a mut self, a: ty::TyVid) -> &'a mut Vec<Relation> {
Expand Down Expand Up @@ -201,9 +201,7 @@ impl<'tcx> sv::SnapshotVecDelegate for Delegate<'tcx> {
type Value = TypeVariableData<'tcx>;
type Undo = UndoEntry;

fn reverse(&mut self,
values: &mut Vec<TypeVariableData<'tcx>>,
action: UndoEntry) {
fn reverse(values: &mut Vec<TypeVariableData<'tcx>>, action: UndoEntry) {
match action {
SpecifyVar(vid, relations) => {
values[vid.index as usize].value = Bounded(relations);
Expand Down
48 changes: 48 additions & 0 deletions src/librustc/middle/infer/unify_key.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
// Copyright 2012-2014 The Rust Project Developers. See the COPYRIGHT
// file at the top-level directory of this distribution and at
// http://rust-lang.org/COPYRIGHT.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

use middle::ty::{self, IntVarValue, Ty};
use rustc_data_structures::unify::UnifyKey;
use syntax::ast;

pub trait ToType<'tcx> {
fn to_type(&self, tcx: &ty::ctxt<'tcx>) -> Ty<'tcx>;
}

impl UnifyKey for ty::IntVid {
type Value = Option<IntVarValue>;
fn index(&self) -> u32 { self.index }
fn from_index(i: u32) -> ty::IntVid { ty::IntVid { index: i } }
fn tag(_: Option<ty::IntVid>) -> &'static str { "IntVid" }
}

impl<'tcx> ToType<'tcx> for IntVarValue {
fn to_type(&self, tcx: &ty::ctxt<'tcx>) -> Ty<'tcx> {
match *self {
ty::IntType(i) => ty::mk_mach_int(tcx, i),
ty::UintType(i) => ty::mk_mach_uint(tcx, i),
}
}
}

// Floating point type keys

impl UnifyKey for ty::FloatVid {
type Value = Option<ast::FloatTy>;
fn index(&self) -> u32 { self.index }
fn from_index(i: u32) -> ty::FloatVid { ty::FloatVid { index: i } }
fn tag(_: Option<ty::FloatVid>) -> &'static str { "FloatVid" }
}

impl<'tcx> ToType<'tcx> for ast::FloatTy {
fn to_type(&self, tcx: &ty::ctxt<'tcx>) -> Ty<'tcx> {
ty::mk_mach_float(tcx, *self)
}
}
42 changes: 42 additions & 0 deletions src/librustc_data_structures/bitvec.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// Copyright 2015 The Rust Project Developers. See the COPYRIGHT
// file at the top-level directory of this distribution and at
// http://rust-lang.org/COPYRIGHT.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

use std::iter;

/// A very simple BitVector type.
pub struct BitVector {
data: Vec<u64>
}

impl BitVector {
pub fn new(num_bits: usize) -> BitVector {
let num_words = (num_bits + 63) / 64;
BitVector { data: iter::repeat(0).take(num_words).collect() }
}

fn word_mask(&self, bit: usize) -> (usize, u64) {
let word = bit / 64;
let mask = 1 << (bit % 64);
(word, mask)
}

pub fn contains(&self, bit: usize) -> bool {
let (word, mask) = self.word_mask(bit);
(self.data[word] & mask) != 0
}

pub fn insert(&mut self, bit: usize) -> bool {
let (word, mask) = self.word_mask(bit);
let data = &mut self.data[word];
let value = *data;
*data = value | mask;
(value | mask) != value
}
}
Loading

0 comments on commit 77213d1

Please sign in to comment.