Skip to content

Commit

Permalink
Batch one of removing get_, standardizing idx -> index
Browse files Browse the repository at this point in the history
  • Loading branch information
mattxwang committed Jul 20, 2023
1 parent fdb32ec commit e2186c3
Show file tree
Hide file tree
Showing 17 changed files with 103 additions and 122 deletions.
17 changes: 1 addition & 16 deletions src/builder/bdd/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::repr::{bdd::BddPtr, var_label::VarLabel};
use crate::repr::bdd::BddPtr;
use std::cmp::Ordering;

mod builder;
Expand Down Expand Up @@ -35,18 +35,3 @@ impl<'a> PartialOrd for CompiledCNF<'a> {
Some(self.cmp(other))
}
}

#[derive(Debug)]
pub struct Assignment {
assignments: Vec<bool>,
}

impl Assignment {
pub fn new(assignments: Vec<bool>) -> Assignment {
Assignment { assignments }
}

pub fn get_assignment(&self, var: VarLabel) -> bool {
self.assignments[var.value() as usize]
}
}
16 changes: 8 additions & 8 deletions src/builder/bdd/robdd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ impl<'a, T: LruTable<'a, BddPtr<'a>>> RobddBuilder<'a, T> {

/// Get the current variable order
#[inline]
pub fn get_order(&self) -> &VarOrder {
pub fn order(&self) -> &VarOrder {
// TODO fix this, it doesn't need to be unsafe
unsafe { &*self.order.as_ptr() }
}
Expand Down Expand Up @@ -387,7 +387,7 @@ mod tests {
(VarLabel::new(1), (RealSemiring(0.1), RealSemiring(0.9))),
]);
let params = WmcParams::new(weights);
let wmc = r1.wmc(builder.get_order().borrow(), &params);
let wmc = r1.wmc(builder.order().borrow(), &params);
assert!((wmc.0 - (1.0 - 0.2 * 0.1)).abs() < 0.000001);
}

Expand Down Expand Up @@ -605,7 +605,7 @@ mod tests {
let and1 = builder.and(iff1, iff2);
let f = builder.and(and1, obs);
assert_eq!(
f.wmc(builder.get_order().borrow(), &wmc).0,
f.wmc(builder.order().borrow(), &wmc).0,
0.2 * 0.3 + 0.2 * 0.7 + 0.8 * 0.3
);
}
Expand Down Expand Up @@ -655,9 +655,9 @@ mod tests {
(VarLabel::new(2), (FiniteField::new(1), FiniteField::new(1))),
]));

let unsmoothed_model_count = bdd.wmc(builder.get_order(), &weights);
let unsmoothed_model_count = bdd.wmc(builder.order(), &weights);

let smoothed_model_count = smoothed.wmc(builder.get_order(), &weights);
let smoothed_model_count = smoothed.wmc(builder.order(), &weights);

assert_eq!(unsmoothed_model_count.value(), 3);
assert_eq!(smoothed_model_count.value(), 7);
Expand All @@ -681,7 +681,7 @@ mod tests {
let smoothed = builder.smooth(bdd, cnf.num_vars());

let weighted_model_count = smoothed.wmc(
builder.get_order(),
builder.order(),
&WmcParams::<RealSemiring>::new(HashMap::from_iter([
(VarLabel::new(0), (RealSemiring(0.4), RealSemiring(0.6))),
(VarLabel::new(1), (RealSemiring(0.3), RealSemiring(0.7))),
Expand Down Expand Up @@ -709,7 +709,7 @@ mod tests {
let smoothed = builder.smooth(bdd, cnf.num_vars());

let model_count = smoothed.wmc(
builder.get_order(),
builder.order(),
&WmcParams::<FiniteField<1000001>>::new(HashMap::from_iter([
(VarLabel::new(0), (FiniteField::new(1), FiniteField::new(1))),
(VarLabel::new(1), (FiniteField::new(1), FiniteField::new(1))),
Expand All @@ -722,7 +722,7 @@ mod tests {

// TODO: this WMC test is broken. not sure why :(
// let weighted_model_count = smoothed.wmc(
// builder.get_order(),
// builder.order(),
// &WmcParams::new(HashMap::from_iter([
// // (VarLabel::new(0), (RealSemiring(0.10), RealSemiring(0.05))),
// // (VarLabel::new(1), (RealSemiring(0.20), RealSemiring(0.15))),
Expand Down
10 changes: 5 additions & 5 deletions src/builder/decision_nnf/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,14 +85,14 @@ pub trait DecisionNNFBuilder<'a>: TopDownBuilder<'a, BddPtr<'a>> {
let high_bdd = match sat.decide(Literal::new(cur_v, true)) {
DecisionResult::UNSAT => BddPtr::false_ptr(),
DecisionResult::SAT => {
let new_assgn = sat.get_difference().filter(|x| x.get_label() != cur_v);
let new_assgn = sat.difference_iter().filter(|x| x.get_label() != cur_v);
let r = self.conjoin_implied(new_assgn, BddPtr::true_ptr());
sat.pop();
r
}
DecisionResult::Unknown => {
let sub = self.topdown_h(cnf, sat, level + 1, cache);
let new_assgn = sat.get_difference().filter(|x| x.get_label() != cur_v);
let new_assgn = sat.difference_iter().filter(|x| x.get_label() != cur_v);
let r = self.conjoin_implied(new_assgn, sub);
sat.pop();
r
Expand All @@ -101,14 +101,14 @@ pub trait DecisionNNFBuilder<'a>: TopDownBuilder<'a, BddPtr<'a>> {
let low_bdd = match sat.decide(Literal::new(cur_v, false)) {
DecisionResult::UNSAT => BddPtr::false_ptr(),
DecisionResult::SAT => {
let new_assgn = sat.get_difference().filter(|x| x.get_label() != cur_v);
let new_assgn = sat.difference_iter().filter(|x| x.get_label() != cur_v);
let r = self.conjoin_implied(new_assgn, BddPtr::true_ptr());
sat.pop();
r
}
DecisionResult::Unknown => {
let sub = self.topdown_h(cnf, sat, level + 1, cache);
let new_assgn = sat.get_difference().filter(|x| x.get_label() != cur_v);
let new_assgn = sat.difference_iter().filter(|x| x.get_label() != cur_v);
let r = self.conjoin_implied(new_assgn, sub);
sat.pop();
r
Expand All @@ -135,7 +135,7 @@ pub trait DecisionNNFBuilder<'a>: TopDownBuilder<'a, BddPtr<'a>> {
let mut r = self.topdown_h(cnf, &mut sat, 0, &mut FxHashMap::default());

// conjoin in any initially implied literals
for l in sat.get_difference() {
for l in sat.difference_iter() {
let node = if l.get_polarity() {
BddNode::new(l.get_label(), BddPtr::false_ptr(), r)
} else {
Expand Down
44 changes: 20 additions & 24 deletions src/builder/sdd/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ pub struct SddBuilderStats {

pub trait SddBuilder<'a>: BottomUpBuilder<'a, SddPtr<'a>> {
// internal data structures
fn get_vtree_manager(&self) -> &VTreeManager;
fn vtree_manager(&self) -> &VTreeManager;

fn app_cache_get(&self, and: &SddAnd<'a>) -> Option<SddPtr<'a>>;
fn app_cache_insert(&self, and: SddAnd<'a>, ptr: SddPtr<'a>);
Expand Down Expand Up @@ -117,7 +117,7 @@ pub trait SddBuilder<'a>: BottomUpBuilder<'a, SddPtr<'a>> {
/// a is prime to b
fn and_indep(&'a self, a: SddPtr<'a>, b: SddPtr<'a>, lca: VTreeIndex) -> SddPtr<'a> {
// check if this is a right-linear fragment and construct the relevant SDD type
if self.get_vtree_manager().get_idx(lca).is_right_linear() {
if self.vtree_manager().get_idx(lca).is_right_linear() {
// a is a right-linear decision for b; construct a binary decision
let bdd = match a {
SddPtr::Var(label, true) => BinarySDD::new(label, SddPtr::false_ptr(), b, lca),
Expand Down Expand Up @@ -251,7 +251,7 @@ pub trait SddBuilder<'a>: BottomUpBuilder<'a, SddPtr<'a>> {
// check if a and b are both binary SDDs; if so, we apply BDD conjunction here

if let SddPtr::BDD(or) | SddPtr::ComplBDD(or) = a {
if self.get_vtree_manager().get_idx(lca).is_right_linear() {
if self.vtree_manager().get_idx(lca).is_right_linear() {
let l = self.and(a.low(), b.low());
let h = self.and(a.high(), b.high());
return self.unique_bdd(BinarySDD::new(or.label(), l, h, lca));
Expand Down Expand Up @@ -318,28 +318,24 @@ pub trait SddBuilder<'a>: BottomUpBuilder<'a, SddPtr<'a>> {

// helpers

fn get_vtree_root(&self) -> &VTree {
self.get_vtree_manager().vtree_root()
}

fn num_vars(&self) -> usize {
self.get_vtree_manager().num_vars()
self.vtree_manager().num_vars()
}

fn get_vtree(&self, ptr: SddPtr) -> &VTree {
fn vtree(&self, ptr: SddPtr) -> &VTree {
match ptr {
SddPtr::Var(lbl, _) => {
let idx = self.get_vtree_manager().get_varlabel_idx(lbl);
self.get_vtree_manager().get_idx(idx)
let idx = self.vtree_manager().var_index(lbl);
self.vtree_manager().get_idx(idx)
}
SddPtr::Compl(_) | SddPtr::Reg(_) => self.get_vtree_manager().get_idx(ptr.vtree()),
SddPtr::Compl(_) | SddPtr::Reg(_) => self.vtree_manager().get_idx(ptr.vtree()),
_ => panic!("called vtree on constant"),
}
}

fn get_vtree_idx(&self, ptr: SddPtr) -> VTreeIndex {
fn vtree_index(&self, ptr: SddPtr) -> VTreeIndex {
match ptr {
SddPtr::Var(lbl, _) => self.get_vtree_manager().get_varlabel_idx(lbl),
SddPtr::Var(lbl, _) => self.vtree_manager().var_index(lbl),
SddPtr::BDD(_) | SddPtr::ComplBDD(_) | SddPtr::Compl(_) | SddPtr::Reg(_) => ptr.vtree(),
_ => panic!("called vtree on constant"),
}
Expand All @@ -364,7 +360,7 @@ pub trait SddBuilder<'a>: BottomUpBuilder<'a, SddPtr<'a>> {
.iter()
.max_by(|l1, l2| {
if self
.get_vtree_manager()
.vtree_manager()
.is_prime_var(l1.get_label(), l2.get_label())
{
Ordering::Less
Expand All @@ -377,7 +373,7 @@ pub trait SddBuilder<'a>: BottomUpBuilder<'a, SddPtr<'a>> {
.iter()
.max_by(|l1, l2| {
if self
.get_vtree_manager()
.vtree_manager()
.is_prime_var(l1.get_label(), l2.get_label())
{
Ordering::Less
Expand All @@ -387,7 +383,7 @@ pub trait SddBuilder<'a>: BottomUpBuilder<'a, SddPtr<'a>> {
})
.unwrap();
if self
.get_vtree_manager()
.vtree_manager()
.is_prime_var(fst1.get_label(), fst2.get_label())
{
Ordering::Less
Expand Down Expand Up @@ -568,10 +564,10 @@ where
};

// normalize so `a` is always prime if possible
let (a, b) = if self.get_vtree_idx(a) == self.get_vtree_idx(b)
let (a, b) = if self.vtree_index(a) == self.vtree_index(b)
|| self
.get_vtree_manager()
.is_prime_index(self.get_vtree_idx(a), self.get_vtree_idx(b))
.vtree_manager()
.is_prime_index(self.vtree_index(a), self.vtree_index(b))
{
(a, b)
} else {
Expand All @@ -583,9 +579,9 @@ where
return x;
}

let av = self.get_vtree_idx(a);
let bv = self.get_vtree_idx(b);
let lca = self.get_vtree_manager().lca(av, bv);
let av = self.vtree_index(a);
let bv = self.vtree_index(b);
let lca = self.vtree_manager().lca(av, bv);

// now we determine the current iterator for primes and subs
// consider the following example vtree:
Expand Down Expand Up @@ -670,7 +666,7 @@ where

/// Computes the SDD representing the logical function `if f then g else h`
fn ite(&'a self, f: SddPtr<'a>, g: SddPtr<'a>, h: SddPtr<'a>) -> SddPtr<'a> {
let ite = Ite::new(|a, b| self.get_vtree_manager().is_prime(a, b), f, g, h);
let ite = Ite::new(|a, b| self.vtree_manager().is_prime(a, b), f, g, h);
if let Ite::IteConst(f) = ite {
return f;
}
Expand Down
4 changes: 2 additions & 2 deletions src/builder/sdd/compression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ pub struct CompressionSddBuilder<'a> {

impl<'a> SddBuilder<'a> for CompressionSddBuilder<'a> {
#[inline]
fn get_vtree_manager(&self) -> &VTreeManager {
fn vtree_manager(&self) -> &VTreeManager {
&self.vtree
}

Expand Down Expand Up @@ -467,7 +467,7 @@ fn sdd_wmc1() {
let x_fx = builder.iff(x, fx);
let y_fy = builder.iff(y, fy);
let ptr = builder.and(x_fx, y_fy);
let wmc_res: RealSemiring = ptr.wmc(builder.get_vtree_manager(), &wmc_map);
let wmc_res: RealSemiring = ptr.wmc(builder.vtree_manager(), &wmc_map);
let expected = RealSemiring(1.0);
let diff = (wmc_res - expected).0.abs();
println!("sdd: {}", builder.print_sdd(ptr));
Expand Down
6 changes: 3 additions & 3 deletions src/builder/sdd/semantic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ pub struct SemanticSddBuilder<'a, const P: u128> {

impl<'a, const P: u128> SddBuilder<'a> for SemanticSddBuilder<'a, P> {
#[inline]
fn get_vtree_manager(&self) -> &VTreeManager {
fn vtree_manager(&self) -> &VTreeManager {
&self.vtree
}

Expand Down Expand Up @@ -260,8 +260,8 @@ fn prob_equiv_sdd_demorgan() {
let map: WmcParams<FiniteField<{ primes::U32_SMALL }>> =
create_semantic_hash_map(builder.num_vars());

let sh1 = res.cached_semantic_hash(builder.get_vtree_manager(), &map);
let sh2 = expected.cached_semantic_hash(builder.get_vtree_manager(), &map);
let sh1 = res.cached_semantic_hash(builder.vtree_manager(), &map);
let sh2 = expected.cached_semantic_hash(builder.vtree_manager(), &map);

assert!(sh1 == sh2, "Not eq:\nGot: {:?}\nExpected: {:?}", sh1, sh2);
}
12 changes: 6 additions & 6 deletions src/repr/bdd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -499,7 +499,7 @@ impl<'a> BddPtr<'a> {
) -> RealSemiring {
let mut v = self.bdd_fold(
&|varlabel, low, high| {
let (low_w, high_w) = wmc.get_var_weight(varlabel);
let (low_w, high_w) = wmc.var_weight(varlabel);
match partial_map_assgn.get(varlabel) {
None => {
if map_vars.contains(varlabel.value_usize()) {
Expand All @@ -517,7 +517,7 @@ impl<'a> BddPtr<'a> {
);
// multiply in weights of all variables in the partial assignment
for lit in partial_map_assgn.assignment_iter() {
let (l, h) = wmc.get_var_weight(lit.get_label());
let (l, h) = wmc.var_weight(lit.get_label());
if lit.get_polarity() {
v = v * (*h);
} else {
Expand Down Expand Up @@ -617,7 +617,7 @@ impl<'a> BddPtr<'a> {
self.bdd_fold(
&|varlabel, low: ExpectedUtility, high: ExpectedUtility| {
// get True and False weights for VarLabel
let (false_w, true_w) = wmc.get_var_weight(varlabel);
let (false_w, true_w) = wmc.var_weight(varlabel);
// Check if our partial model has already assigned my variable.
match partial_decisions.get(varlabel) {
// If not...
Expand Down Expand Up @@ -735,7 +735,7 @@ impl<'a> BddPtr<'a> {
{
let mut partial_join_acc = T::one();
for lit in partial_join_assgn.assignment_iter() {
let (l, h) = wmc.get_var_weight(lit.get_label());
let (l, h) = wmc.var_weight(lit.get_label());
if lit.get_polarity() {
partial_join_acc = partial_join_acc * (*h);
} else {
Expand All @@ -746,7 +746,7 @@ impl<'a> BddPtr<'a> {
let v = self.bdd_fold(
&|varlabel, low: T, high: T| {
// get True and False weights for node VarLabel
let (w_l, w_h) = wmc.get_var_weight(varlabel);
let (w_l, w_h) = wmc.var_weight(varlabel);
// Check if our partial model has already assigned the node.
match partial_join_assgn.get(varlabel) {
// If not...
Expand Down Expand Up @@ -1049,7 +1049,7 @@ impl<'a> BddNode<'a> {
order: &VarOrder,
map: &WmcParams<FiniteField<P>>,
) -> FiniteField<P> {
let (low_w, high_w) = map.get_var_weight(self.var);
let (low_w, high_w) = map.var_weight(self.var);
self.low.cached_semantic_hash(order, map) * (*low_w)
+ self.high.cached_semantic_hash(order, map) * (*high_w)
}
Expand Down
2 changes: 1 addition & 1 deletion src/repr/cnf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -471,7 +471,7 @@ impl Cnf {
let mut total: T = T::zero();
let mut weight_vec = Vec::new();
for i in 0..self.num_vars() {
weight_vec.push(weights.get_var_weight(VarLabel::new(i as u64)));
weight_vec.push(weights.var_weight(VarLabel::new(i as u64)));
}
for assgn in AssignmentIter::new(self.num_vars()) {
if assgn.is_empty() {
Expand Down
2 changes: 1 addition & 1 deletion src/repr/ddnnf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ pub trait DDNNFPtr<'a>: Clone + Debug + PartialEq + Eq + Hash + Copy {
True => params.one,
False => params.zero,
Lit(lbl, polarity) => {
let (low_w, high_w) = params.get_var_weight(lbl);
let (low_w, high_w) = params.var_weight(lbl);
if polarity {
*high_w
} else {
Expand Down
6 changes: 3 additions & 3 deletions src/repr/sdd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ impl<'a> SddPtr<'a> {
PtrTrue => FiniteField::new(1),
PtrFalse => FiniteField::new(0),
Var(label, polarity) => {
let (l_w, h_w) = map.get_var_weight(*label);
let (l_w, h_w) = map.var_weight(*label);
if *polarity {
*h_w
} else {
Expand Down Expand Up @@ -433,7 +433,7 @@ fn is_compressed_simple_bdd() {
VarLabel::new(2),
a,
b,
vtree_manager.get_varlabel_idx(VarLabel::new(2)),
vtree_manager.var_index(VarLabel::new(2)),
);
let binary_sdd_ptr = &mut binary_sdd;
let bdd_ptr = SddPtr::BDD(binary_sdd_ptr);
Expand All @@ -453,7 +453,7 @@ fn is_compressed_simple_bdd_duplicate() {
VarLabel::new(2),
a,
a, // duplicate with low - not compressed!
vtree_manager.get_varlabel_idx(VarLabel::new(2)),
vtree_manager.var_index(VarLabel::new(2)),
);
let binary_sdd_ptr = &mut binary_sdd;
let bdd_ptr = SddPtr::BDD(binary_sdd_ptr);
Expand Down
Loading

0 comments on commit e2186c3

Please sign in to comment.