Skip to content

Commit

Permalink
wmc functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
stites committed Oct 24, 2023
1 parent 12a7a0e commit 858dbfd
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 3 deletions.
75 changes: 72 additions & 3 deletions src/ffi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::os::raw::c_char;
use std::{collections::HashMap, ffi::CStr};

use crate::repr::DDNNFPtr;
use crate::util::semirings::Semiring;
use crate::util::semirings::{RealSemiring, Semiring};
use crate::{
builder::{bdd::RobddBuilder, cache::AllIteTable, BottomUpBuilder},
constants::primes,
Expand Down Expand Up @@ -103,12 +103,20 @@ pub unsafe extern "C" fn mk_bdd_manager_default_order(num_vars: u64) -> *mut Rsd

#[no_mangle]
#[allow(clippy::missing_safety_doc)]
pub unsafe extern "C" fn bdd_new_var(
pub unsafe extern "C" fn bdd_new_label(builder: *mut RsddBddBuilder) -> u64 {
let builder = robdd_builder_from_ptr(builder);
builder.new_label().value()
}

#[no_mangle]
#[allow(clippy::missing_safety_doc)]
pub unsafe extern "C" fn bdd_var(
builder: *mut RsddBddBuilder,
label: u64,
polarity: bool,
) -> *mut BddPtr<'static> {
let builder = robdd_builder_from_ptr(builder);
let (_, ptr) = builder.new_var(polarity);
let ptr = builder.var(VarLabel::new(label), polarity);
Box::into_raw(Box::new(ptr))
}

Expand Down Expand Up @@ -225,3 +233,64 @@ pub unsafe extern "C" fn bdd_low(bdd: *mut BddPtr<'static>) -> *mut BddPtr<'stat
pub unsafe extern "C" fn bdd_high(bdd: *mut BddPtr<'static>) -> *mut BddPtr<'static> {
Box::into_raw(Box::new((*bdd).high()))
}

#[no_mangle]
#[allow(clippy::missing_safety_doc)]
pub unsafe extern "C" fn print_bdd(bdd: *mut BddPtr<'static>) -> *const c_char {
let s = std::ffi::CString::new((*bdd).print_bdd()).unwrap();
let p = s.as_ptr();
std::mem::forget(s);
p
}

#[no_mangle]
#[allow(clippy::missing_safety_doc)]
pub unsafe extern "C" fn bdd_wmc(
bdd: *mut BddPtr<'static>,
wmc: *mut WmcParams<RealSemiring>,
) -> f64 {
DDNNFPtr::unsmoothed_wmc(&(*bdd), &(*wmc)).0
}

#[no_mangle]
#[allow(clippy::missing_safety_doc)]
pub unsafe extern "C" fn new_wmc_params_f64() -> *mut WmcParams<RealSemiring> {
Box::into_raw(Box::new(WmcParams::new(HashMap::from([]))))
}

#[no_mangle]
#[allow(clippy::missing_safety_doc)]
pub unsafe extern "C" fn wmc_param_f64_set_weight(
weights: *mut WmcParams<RealSemiring>,
var: u64,
low: f64,
high: f64,
) {
(*weights).set_weight(VarLabel::new(var), RealSemiring(low), RealSemiring(high))
}

#[repr(C)]
#[derive(Clone, Copy)]
pub struct WeightF64(pub f64, pub f64);

#[no_mangle]
#[allow(clippy::missing_safety_doc)]
pub unsafe extern "C" fn wmc_param_f64_var_weight(
weights: *mut WmcParams<RealSemiring>,
var: u64,
) -> WeightF64 {
let (l, h) = (*weights).var_weight(VarLabel::new(var));
WeightF64(l.0, h.0)
}

#[no_mangle]
#[allow(clippy::missing_safety_doc)]
pub unsafe extern "C" fn weight_f64_lo(w: WeightF64) -> f64 {
w.0
}

#[no_mangle]
#[allow(clippy::missing_safety_doc)]
pub unsafe extern "C" fn weight_f64_hi(w: WeightF64) -> f64 {
w.1
}
1 change: 1 addition & 0 deletions src/repr/wmc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use core::fmt::Debug;
use std::collections::HashMap;
/// Weighted model counting parameters for a BDD. It primarily is a storage for
/// the weight on each variable.
#[repr(C)]
#[derive(Clone)]
pub struct WmcParams<T: Semiring> {
pub zero: T,
Expand Down

0 comments on commit 858dbfd

Please sign in to comment.