Skip to content

Commit

Permalink
Align with rust naming conventions (#146)
Browse files Browse the repository at this point in the history
See: https://rust-lang.github.io/api-guidelines/naming.html.

Areas of changes:
- removing `get_`, convert `idx` -> `index`
    - `BinarySdd::get_scratch()` -> `BinarySdd::scratch()`
    - `BottomUpBuilder::get_vtree()` -> `BottomUpBuilder::vtree()`
    - `BottomUpBuilder::get_vtree_idx()` -> `BottomUpBuilder::vtree_index()`
    - `BottomUpBuilder::get_vtree_manager()` -> `BottomUpBuilder::vtree_manager()`
    - `Cnf::get_hasher()` -> `Cnf::hasher()`
    - `ImportanceSampler::get_state_index()` -> `ImportanceSampler::state_index()`
    - `RobddBuilder::get_order()` -> `RobddBuilder::order()`
    - `SatSolver::get_cur_hash()` -> `SatSolver::cur_hash()`
    - `SatSolver::get_difference()` -> `SatSolver::difference_iter()`
    - `SddOr::get_scratch()` -> `SddOr::scratch()`
    - `VarLabel::get_label()` -> `VarLabel::label()`
    - `VarLabel::get_polarity()` -> `VarLabel::polarity()`
    - `VTree::get_all_vars()` -> `VTree::all_vars()`
    - `VTree::vtree_idx` -> `VTree::vtree_index`
    - `VTree::get_varlabel_idx()` -> `VTree::var_index()`
    - `WmcParams::get_weight()` -> `WmcParams::assignment_weight()`
- removing dead code
    - `Bdd::Assignment`
    - `VarOrder::get_var_to_pos_vec()`
  • Loading branch information
mattxwang authored Jul 20, 2023
1 parent fdb32ec commit 3399c56
Show file tree
Hide file tree
Showing 25 changed files with 230 additions and 268 deletions.
16 changes: 8 additions & 8 deletions src/builder/bdd/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,12 @@ pub trait BddBuilder<'a>: BottomUpBuilder<'a, BddPtr<'a>> {
for clause in clauses.iter() {
let mut cur_ptr = BddPtr::false_ptr();
for lit in clause.iter() {
match assgn.get(lit.get_label()) {
match assgn.get(lit.label()) {
None => {
let new_v = self.var(lit.get_label(), lit.get_polarity());
let new_v = self.var(lit.label(), lit.polarity());
cur_ptr = self.or(new_v, cur_ptr);
}
Some(v) if v == lit.get_polarity() => {
Some(v) if v == lit.polarity() => {
cur_ptr = BddPtr::true_ptr();
break;
}
Expand Down Expand Up @@ -99,7 +99,7 @@ pub trait BddBuilder<'a>: BottomUpBuilder<'a, BddPtr<'a>> {
let fst1 = c1
.iter()
.max_by(|l1, l2| {
if self.less_than(l1.get_label(), l2.get_label()) {
if self.less_than(l1.label(), l2.label()) {
Ordering::Less
} else {
Ordering::Equal
Expand All @@ -109,25 +109,25 @@ pub trait BddBuilder<'a>: BottomUpBuilder<'a, BddPtr<'a>> {
let fst2 = c2
.iter()
.max_by(|l1, l2| {
if self.less_than(l1.get_label(), l2.get_label()) {
if self.less_than(l1.label(), l2.label()) {
Ordering::Less
} else {
Ordering::Equal
}
})
.unwrap();
if self.less_than(fst1.get_label(), fst2.get_label()) {
if self.less_than(fst1.label(), fst2.label()) {
Ordering::Less
} else {
Ordering::Equal
}
});

for lit_vec in cnf_sorted.iter() {
let (vlabel, val) = (lit_vec[0].get_label(), lit_vec[0].get_polarity());
let (vlabel, val) = (lit_vec[0].label(), lit_vec[0].polarity());
let mut bdd = self.var(vlabel, val);
for lit in lit_vec {
let (vlabel, val) = (lit.get_label(), lit.get_polarity());
let (vlabel, val) = (lit.label(), lit.polarity());
let var = self.var(vlabel, val);
bdd = self.or(bdd, var);
}
Expand Down
17 changes: 1 addition & 16 deletions src/builder/bdd/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::repr::{bdd::BddPtr, var_label::VarLabel};
use crate::repr::bdd::BddPtr;
use std::cmp::Ordering;

mod builder;
Expand Down Expand Up @@ -35,18 +35,3 @@ impl<'a> PartialOrd for CompiledCNF<'a> {
Some(self.cmp(other))
}
}

#[derive(Debug)]
pub struct Assignment {
assignments: Vec<bool>,
}

impl Assignment {
pub fn new(assignments: Vec<bool>) -> Assignment {
Assignment { assignments }
}

pub fn get_assignment(&self, var: VarLabel) -> bool {
self.assignments[var.value() as usize]
}
}
20 changes: 10 additions & 10 deletions src/builder/bdd/robdd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ impl<'a, T: LruTable<'a, BddPtr<'a>>> RobddBuilder<'a, T> {

/// Get the current variable order
#[inline]
pub fn get_order(&self) -> &VarOrder {
pub fn order(&self) -> &VarOrder {
// TODO fix this, it doesn't need to be unsafe
unsafe { &*self.order.as_ptr() }
}
Expand Down Expand Up @@ -198,7 +198,7 @@ impl<'a, T: LruTable<'a, BddPtr<'a>>> RobddBuilder<'a, T> {
}

// check cache
match bdd.get_scratch::<usize>() {
match bdd.scratch::<usize>() {
None => (),
Some(v) => {
return if bdd.is_neg() {
Expand Down Expand Up @@ -252,7 +252,7 @@ impl<'a, T: LruTable<'a, BddPtr<'a>>> RobddBuilder<'a, T> {
// TODO: optimize this
let mut bdd = bdd;
for m in m.assignment_iter() {
bdd = self.condition(bdd, m.get_label(), m.get_polarity());
bdd = self.condition(bdd, m.label(), m.polarity());
}
bdd
}
Expand Down Expand Up @@ -387,7 +387,7 @@ mod tests {
(VarLabel::new(1), (RealSemiring(0.1), RealSemiring(0.9))),
]);
let params = WmcParams::new(weights);
let wmc = r1.wmc(builder.get_order().borrow(), &params);
let wmc = r1.wmc(builder.order().borrow(), &params);
assert!((wmc.0 - (1.0 - 0.2 * 0.1)).abs() < 0.000001);
}

Expand Down Expand Up @@ -605,7 +605,7 @@ mod tests {
let and1 = builder.and(iff1, iff2);
let f = builder.and(and1, obs);
assert_eq!(
f.wmc(builder.get_order().borrow(), &wmc).0,
f.wmc(builder.order().borrow(), &wmc).0,
0.2 * 0.3 + 0.2 * 0.7 + 0.8 * 0.3
);
}
Expand Down Expand Up @@ -655,9 +655,9 @@ mod tests {
(VarLabel::new(2), (FiniteField::new(1), FiniteField::new(1))),
]));

let unsmoothed_model_count = bdd.wmc(builder.get_order(), &weights);
let unsmoothed_model_count = bdd.wmc(builder.order(), &weights);

let smoothed_model_count = smoothed.wmc(builder.get_order(), &weights);
let smoothed_model_count = smoothed.wmc(builder.order(), &weights);

assert_eq!(unsmoothed_model_count.value(), 3);
assert_eq!(smoothed_model_count.value(), 7);
Expand All @@ -681,7 +681,7 @@ mod tests {
let smoothed = builder.smooth(bdd, cnf.num_vars());

let weighted_model_count = smoothed.wmc(
builder.get_order(),
builder.order(),
&WmcParams::<RealSemiring>::new(HashMap::from_iter([
(VarLabel::new(0), (RealSemiring(0.4), RealSemiring(0.6))),
(VarLabel::new(1), (RealSemiring(0.3), RealSemiring(0.7))),
Expand Down Expand Up @@ -709,7 +709,7 @@ mod tests {
let smoothed = builder.smooth(bdd, cnf.num_vars());

let model_count = smoothed.wmc(
builder.get_order(),
builder.order(),
&WmcParams::<FiniteField<1000001>>::new(HashMap::from_iter([
(VarLabel::new(0), (FiniteField::new(1), FiniteField::new(1))),
(VarLabel::new(1), (FiniteField::new(1), FiniteField::new(1))),
Expand All @@ -722,7 +722,7 @@ mod tests {

// TODO: this WMC test is broken. not sure why :(
// let weighted_model_count = smoothed.wmc(
// builder.get_order(),
// builder.order(),
// &WmcParams::new(HashMap::from_iter([
// // (VarLabel::new(0), (RealSemiring(0.10), RealSemiring(0.05))),
// // (VarLabel::new(1), (RealSemiring(0.20), RealSemiring(0.15))),
Expand Down
26 changes: 13 additions & 13 deletions src/builder/decision_nnf/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,10 @@ pub trait DecisionNNFBuilder<'a>: TopDownBuilder<'a, BddPtr<'a>> {
}
let mut sub = nnf;
for l in literals {
let node = if l.get_polarity() {
BddNode::new(l.get_label(), BddPtr::false_ptr(), sub)
let node = if l.polarity() {
BddNode::new(l.label(), BddPtr::false_ptr(), sub)
} else {
BddNode::new(l.get_label(), sub, BddPtr::false_ptr())
BddNode::new(l.label(), sub, BddPtr::false_ptr())
};
sub = self.get_or_insert(node);
}
Expand Down Expand Up @@ -73,7 +73,7 @@ pub trait DecisionNNFBuilder<'a>: TopDownBuilder<'a, BddPtr<'a>> {
}

// check cache
let hashed = sat.get_cur_hash();
let hashed = sat.cur_hash();
match cache.get(&hashed) {
None => (),
Some(v) => {
Expand All @@ -85,14 +85,14 @@ pub trait DecisionNNFBuilder<'a>: TopDownBuilder<'a, BddPtr<'a>> {
let high_bdd = match sat.decide(Literal::new(cur_v, true)) {
DecisionResult::UNSAT => BddPtr::false_ptr(),
DecisionResult::SAT => {
let new_assgn = sat.get_difference().filter(|x| x.get_label() != cur_v);
let new_assgn = sat.difference_iter().filter(|x| x.label() != cur_v);
let r = self.conjoin_implied(new_assgn, BddPtr::true_ptr());
sat.pop();
r
}
DecisionResult::Unknown => {
let sub = self.topdown_h(cnf, sat, level + 1, cache);
let new_assgn = sat.get_difference().filter(|x| x.get_label() != cur_v);
let new_assgn = sat.difference_iter().filter(|x| x.label() != cur_v);
let r = self.conjoin_implied(new_assgn, sub);
sat.pop();
r
Expand All @@ -101,14 +101,14 @@ pub trait DecisionNNFBuilder<'a>: TopDownBuilder<'a, BddPtr<'a>> {
let low_bdd = match sat.decide(Literal::new(cur_v, false)) {
DecisionResult::UNSAT => BddPtr::false_ptr(),
DecisionResult::SAT => {
let new_assgn = sat.get_difference().filter(|x| x.get_label() != cur_v);
let new_assgn = sat.difference_iter().filter(|x| x.label() != cur_v);
let r = self.conjoin_implied(new_assgn, BddPtr::true_ptr());
sat.pop();
r
}
DecisionResult::Unknown => {
let sub = self.topdown_h(cnf, sat, level + 1, cache);
let new_assgn = sat.get_difference().filter(|x| x.get_label() != cur_v);
let new_assgn = sat.difference_iter().filter(|x| x.label() != cur_v);
let r = self.conjoin_implied(new_assgn, sub);
sat.pop();
r
Expand All @@ -135,11 +135,11 @@ pub trait DecisionNNFBuilder<'a>: TopDownBuilder<'a, BddPtr<'a>> {
let mut r = self.topdown_h(cnf, &mut sat, 0, &mut FxHashMap::default());

// conjoin in any initially implied literals
for l in sat.get_difference() {
let node = if l.get_polarity() {
BddNode::new(l.get_label(), BddPtr::false_ptr(), r)
for l in sat.difference_iter() {
let node = if l.polarity() {
BddNode::new(l.label(), BddPtr::false_ptr(), r)
} else {
BddNode::new(l.get_label(), r, BddPtr::false_ptr())
BddNode::new(l.label(), r, BddPtr::false_ptr())
};
r = self.get_or_insert(node);
}
Expand All @@ -159,7 +159,7 @@ pub trait DecisionNNFBuilder<'a>: TopDownBuilder<'a, BddPtr<'a>> {
}
BddPtr::Reg(node) | BddPtr::Compl(node) => {
// check cache
if let Some(v) = bdd.get_scratch::<BddPtr>() {
if let Some(v) = bdd.scratch::<BddPtr>() {
return if bdd.is_neg() { v.neg() } else { v };
}

Expand Down
56 changes: 23 additions & 33 deletions src/builder/sdd/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ pub struct SddBuilderStats {

pub trait SddBuilder<'a>: BottomUpBuilder<'a, SddPtr<'a>> {
// internal data structures
fn get_vtree_manager(&self) -> &VTreeManager;
fn vtree_manager(&self) -> &VTreeManager;

fn app_cache_get(&self, and: &SddAnd<'a>) -> Option<SddPtr<'a>>;
fn app_cache_insert(&self, and: SddAnd<'a>, ptr: SddPtr<'a>);
Expand Down Expand Up @@ -117,7 +117,7 @@ pub trait SddBuilder<'a>: BottomUpBuilder<'a, SddPtr<'a>> {
/// a is prime to b
fn and_indep(&'a self, a: SddPtr<'a>, b: SddPtr<'a>, lca: VTreeIndex) -> SddPtr<'a> {
// check if this is a right-linear fragment and construct the relevant SDD type
if self.get_vtree_manager().get_idx(lca).is_right_linear() {
if self.vtree_manager().vtree(lca).is_right_linear() {
// a is a right-linear decision for b; construct a binary decision
let bdd = match a {
SddPtr::Var(label, true) => BinarySDD::new(label, SddPtr::false_ptr(), b, lca),
Expand Down Expand Up @@ -251,7 +251,7 @@ pub trait SddBuilder<'a>: BottomUpBuilder<'a, SddPtr<'a>> {
// check if a and b are both binary SDDs; if so, we apply BDD conjunction here

if let SddPtr::BDD(or) | SddPtr::ComplBDD(or) = a {
if self.get_vtree_manager().get_idx(lca).is_right_linear() {
if self.vtree_manager().vtree(lca).is_right_linear() {
let l = self.and(a.low(), b.low());
let h = self.and(a.high(), b.high());
return self.unique_bdd(BinarySDD::new(or.label(), l, h, lca));
Expand Down Expand Up @@ -318,28 +318,24 @@ pub trait SddBuilder<'a>: BottomUpBuilder<'a, SddPtr<'a>> {

// helpers

fn get_vtree_root(&self) -> &VTree {
self.get_vtree_manager().vtree_root()
}

fn num_vars(&self) -> usize {
self.get_vtree_manager().num_vars()
self.vtree_manager().num_vars()
}

fn get_vtree(&self, ptr: SddPtr) -> &VTree {
fn vtree(&self, ptr: SddPtr) -> &VTree {
match ptr {
SddPtr::Var(lbl, _) => {
let idx = self.get_vtree_manager().get_varlabel_idx(lbl);
self.get_vtree_manager().get_idx(idx)
let idx = self.vtree_manager().var_index(lbl);
self.vtree_manager().vtree(idx)
}
SddPtr::Compl(_) | SddPtr::Reg(_) => self.get_vtree_manager().get_idx(ptr.vtree()),
SddPtr::Compl(_) | SddPtr::Reg(_) => self.vtree_manager().vtree(ptr.vtree()),
_ => panic!("called vtree on constant"),
}
}

fn get_vtree_idx(&self, ptr: SddPtr) -> VTreeIndex {
fn vtree_index(&self, ptr: SddPtr) -> VTreeIndex {
match ptr {
SddPtr::Var(lbl, _) => self.get_vtree_manager().get_varlabel_idx(lbl),
SddPtr::Var(lbl, _) => self.vtree_manager().var_index(lbl),
SddPtr::BDD(_) | SddPtr::ComplBDD(_) | SddPtr::Compl(_) | SddPtr::Reg(_) => ptr.vtree(),
_ => panic!("called vtree on constant"),
}
Expand All @@ -363,10 +359,7 @@ pub trait SddBuilder<'a>: BottomUpBuilder<'a, SddPtr<'a>> {
let fst1 = c1
.iter()
.max_by(|l1, l2| {
if self
.get_vtree_manager()
.is_prime_var(l1.get_label(), l2.get_label())
{
if self.vtree_manager().is_prime_var(l1.label(), l2.label()) {
Ordering::Less
} else {
Ordering::Equal
Expand All @@ -376,19 +369,16 @@ pub trait SddBuilder<'a>: BottomUpBuilder<'a, SddPtr<'a>> {
let fst2 = c2
.iter()
.max_by(|l1, l2| {
if self
.get_vtree_manager()
.is_prime_var(l1.get_label(), l2.get_label())
{
if self.vtree_manager().is_prime_var(l1.label(), l2.label()) {
Ordering::Less
} else {
Ordering::Equal
}
})
.unwrap();
if self
.get_vtree_manager()
.is_prime_var(fst1.get_label(), fst2.get_label())
.vtree_manager()
.is_prime_var(fst1.label(), fst2.label())
{
Ordering::Less
} else {
Expand All @@ -397,10 +387,10 @@ pub trait SddBuilder<'a>: BottomUpBuilder<'a, SddPtr<'a>> {
});

for lit_vec in cnf_sorted.iter() {
let (vlabel, val) = (lit_vec[0].get_label(), lit_vec[0].get_polarity());
let (vlabel, val) = (lit_vec[0].label(), lit_vec[0].polarity());
let mut bdd = SddPtr::Var(vlabel, val);
for lit in lit_vec {
let (vlabel, val) = (lit.get_label(), lit.get_polarity());
let (vlabel, val) = (lit.label(), lit.polarity());
let var = SddPtr::Var(vlabel, val);
bdd = self.or(bdd, var);
}
Expand Down Expand Up @@ -568,10 +558,10 @@ where
};

// normalize so `a` is always prime if possible
let (a, b) = if self.get_vtree_idx(a) == self.get_vtree_idx(b)
let (a, b) = if self.vtree_index(a) == self.vtree_index(b)
|| self
.get_vtree_manager()
.is_prime_index(self.get_vtree_idx(a), self.get_vtree_idx(b))
.vtree_manager()
.is_prime_index(self.vtree_index(a), self.vtree_index(b))
{
(a, b)
} else {
Expand All @@ -583,9 +573,9 @@ where
return x;
}

let av = self.get_vtree_idx(a);
let bv = self.get_vtree_idx(b);
let lca = self.get_vtree_manager().lca(av, bv);
let av = self.vtree_index(a);
let bv = self.vtree_index(b);
let lca = self.vtree_manager().lca(av, bv);

// now we determine the current iterator for primes and subs
// consider the following example vtree:
Expand Down Expand Up @@ -670,7 +660,7 @@ where

/// Computes the SDD representing the logical function `if f then g else h`
fn ite(&'a self, f: SddPtr<'a>, g: SddPtr<'a>, h: SddPtr<'a>) -> SddPtr<'a> {
let ite = Ite::new(|a, b| self.get_vtree_manager().is_prime(a, b), f, g, h);
let ite = Ite::new(|a, b| self.vtree_manager().is_prime(a, b), f, g, h);
if let Ite::IteConst(f) = ite {
return f;
}
Expand Down
Loading

0 comments on commit 3399c56

Please sign in to comment.