Skip to content

Commit

Permalink
integrate lookahead v1 into repair loop
Browse files Browse the repository at this point in the history
this ports some functionality from lookahead solver for qfbv-sls into sls-smt.
  • Loading branch information
NikolajBjorner committed Dec 27, 2024
1 parent c581714 commit 5eb71c3
Show file tree
Hide file tree
Showing 9 changed files with 213 additions and 103 deletions.
3 changes: 0 additions & 3 deletions src/ast/sls/sls_bv_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ Module Name:
#include "util/stopwatch.h"
#include "util/lbool.h"
#include "ast/converters/model_converter.h"

#include "ast/sls/sls_stats.h"
#include "ast/sls/sls_bv_tracker.h"
#include "ast/sls/sls_bv_evaluator.h"
Expand Down Expand Up @@ -79,8 +78,6 @@ class sls_engine {
void mk_inv(unsigned bv_sz, const mpz & old_value, mpz & inverted);
void mk_flip(sort * s, const mpz & old_value, unsigned bit, mpz & flipped);



lbool search();

lbool search_loop();
Expand Down
41 changes: 7 additions & 34 deletions src/ast/sls/sls_bv_eval.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,6 @@ namespace sls {
return r;
}


void bv_eval::init_eval_bv(app* e) {
if (bv.is_bv(e))
eval(e).commit_eval();
Expand All @@ -99,7 +98,7 @@ namespace sls {
if (m.is_eq(e, x, y))
return bv.is_bv(x);
if (m.is_ite(e))
return bv.is_bv(e->get_arg(0));
return bv.is_bv(e->get_arg(1));
if (e->get_family_id() == bv.get_fid()) {
switch (e->get_decl_kind()) {
case OP_BNEG_OVFL:
Expand Down Expand Up @@ -680,6 +679,8 @@ namespace sls {
expr* arg = e->get_arg(i);
if (m.is_value(arg))
return false;
if (m.is_bool(e) && false && m_rand(10) == 0 && m_lookahead.try_repair_down(e))
return true;
if (e->get_family_id() == bv.get_family_id() && try_repair_bv(e, i)) {
commit_eval(e, to_app(arg));
IF_VERBOSE(11, verbose_stream() << "repair " << mk_bounded_pp(e, m) << " : " << mk_bounded_pp(arg, m) << " := " << wval(arg) << "\n";);
Expand All @@ -692,9 +693,9 @@ namespace sls {
ctx.new_value_eh(arg);
return true;
}
if (m.is_eq(e) && bv.is_bv(arg)) {
return try_repair_eq_lookahead(e);
}
if (m.is_bool(e) && m_lookahead.try_repair_down(e))
return true;

return false;
}

Expand Down Expand Up @@ -882,37 +883,9 @@ namespace sls {
return false;
}

bool bv_eval::try_repair_eq_lookahead(app* e) {
return m_lookahead.try_repair_down(e);

}

bool bv_eval::try_repair_eq(bool is_true, bvval& a, bvval const& b) {
if (is_true) {
#if 0
if (bv.is_bv_add(t)) {
bvval tmp(b);
unsigned start = m_rand();
unsigned sz = to_app(t)->get_num_args();
for (unsigned i = 0; i < sz; ++i) {
unsigned j = (start + i) % sz;
for (unsigned k = 0; k < sz; ++k) {
if (k == j)
continue;
auto& c = wval(to_app(t)->get_arg(k));
set_sub(tmp, tmp, c.bits());
}

auto& c = wval(to_app(t)->get_arg(j));
verbose_stream() << "TRY " << c << " := " << tmp << "\n";


}
}
#endif
if (m_rand(20) != 0 && a.try_set(b.bits()))
return true;
return a.set_random(m_rand);
return (m_rand(20) != 0 && a.try_set(b.bits())) || a.set_random(m_rand);
}
else {
bool try_above = m_rand(2) == 0;
Expand Down
1 change: 0 additions & 1 deletion src/ast/sls/sls_bv_eval.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,6 @@ namespace sls {
bool try_repair_comp(bvect const& e, bvval& a, bvval& b, unsigned i);
bool try_repair_eq(bool is_true, bvval& a, bvval const& b);
bool try_repair_eq(app* e, unsigned i);
bool try_repair_eq_lookahead(app* e);
bool try_repair_int2bv(bvect const& e, expr* arg);
void add_p2_1(bvval const& a, bool use_current, bvect& t) const;

Expand Down
35 changes: 22 additions & 13 deletions src/ast/sls/sls_bv_evaluator.h
Original file line number Diff line number Diff line change
Expand Up @@ -451,25 +451,34 @@ class sls_evaluator {
case OP_BLSHR: {
SASSERT(n_args == 2);
m_mpz_manager.set(result, m_tracker.get_value(args[0]));
mpz shift; m_mpz_manager.set(shift, m_tracker.get_value(args[1]));
while (!m_mpz_manager.is_zero(shift)) {
m_mpz_manager.machine_div(result, m_two, result);
m_mpz_manager.dec(shift);
auto const& shift = m_tracker.get_value(args[1]);
if (m_mpz_manager.is_small(shift)) {
int s = m_mpz_manager.get_int(shift);
SASSERT(s >= 0);
m_mpz_manager.machine_div2k(result, s);
}
m_mpz_manager.del(shift);
else
m_mpz_manager.set(result, m_zero);
break;
}
case OP_BSHL: {
SASSERT(n_args == 2);
m_mpz_manager.set(result, m_tracker.get_value(args[0]));
mpz shift; m_mpz_manager.set(shift, m_tracker.get_value(args[1]));
while (!m_mpz_manager.is_zero(shift)) {
m_mpz_manager.mul(result, m_two, result);
m_mpz_manager.dec(shift);
m_mpz_manager.set(result, m_tracker.get_value(args[0]));
auto const& shift = m_tracker.get_value(args[1]);
if (m_mpz_manager.is_small(shift)) {
int s = m_mpz_manager.get_int(shift);
SASSERT(s >= 0);
int sz = m_bv_util.get_bv_size(n);
if (s >= sz)
m_mpz_manager.set(result, m_zero);
else {
m_mpz_manager.mul2k(result, s);
const mpz& p = m_powers(sz);
m_mpz_manager.rem(result, p, result);
}
}
const mpz & p = m_powers(m_bv_util.get_bv_size(n));
m_mpz_manager.rem(result, p, result);
m_mpz_manager.del(shift);
else
m_mpz_manager.set(result, m_zero);
break;
}
case OP_SIGN_EXT: {
Expand Down
166 changes: 138 additions & 28 deletions src/ast/sls/sls_bv_lookahead.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,45 +18,43 @@ Module Name:
#include "ast/sls/sls_bv_lookahead.h"
#include "ast/sls/sls_bv_eval.h"
#include "ast/sls/sls_bv_terms.h"
#include "ast/ast_pp.h"

namespace sls {

bv_lookahead::bv_lookahead(bv_eval& ev) :
bv(ev.bv),
bv(ev.m),
m_ev(ev),
ctx(ev.ctx),
m(ev.m) {}

bool bv_lookahead::try_repair_down(expr* e) {
return false;
auto is_true = m_ev.bval0(e);
if (!is_true)
bool bv_lookahead::try_repair_down(app* e) {
if (!m.is_bool(e))
return false;
if (m_ev.bval1(e) == m_ev.bval0(e))
return true;
auto const& uninterp = m_ev.terms.uninterp_occurs(e);
if (uninterp.empty())
return false;
// for (auto e : uninterp)
// verbose_stream() << mk_bounded_pp(e, m) << " ";
// verbose_stream() << "\n";

expr* t = uninterp[m_ev.m_rand() % uninterp.size()];
reset_updates();

auto& v = wval(t);
if (v.set_random(m_ev.m_rand)) {
//verbose_stream() << "set random " << mk_bounded_pp(t, m) << "\n";
ctx.new_value_eh(t);
return true;
}
return false;
IF_VERBOSE(4,
verbose_stream() << mk_bounded_pp(e, m) << "\n";
for (auto e : uninterp)
verbose_stream() << mk_bounded_pp(e, m) << " ";
verbose_stream() << "\n");

for (auto e : uninterp)
add_updates(e);

for (auto e : uninterp) {
auto& v = wval(e);
v.get_variant(m_ev.m_tmp, m_ev.m_rand);
auto d = lookahead(e, m_ev.m_tmp);
//verbose_stream() << mk_bounded_pp(e, m) << " " << d << "\n";
#if 0
for (unsigned i = 0; i < m_num_updates; ++i) {
auto const& [e, score, new_value] = m_updates[i];
verbose_stream() << mk_bounded_pp(e, m) << " " << new_value << " score: " << score << "\n";
}
return false;
#endif

return apply_update();
}

double bv_lookahead::lookahead(expr* e, bvect const& new_value) {
Expand All @@ -76,7 +74,7 @@ namespace sls {
unsigned max_depth = get_depth(e);
for (unsigned depth = max_depth; depth <= max_depth; ++depth) {
for (unsigned i = 0; !has_tabu && i < m_update_stack[depth].size(); ++i) {
e = m_update_stack[depth][i];
auto e = m_update_stack[depth][i];
if (bv.is_bv(e)) {
auto& v = m_ev.eval(to_app(e));
if (insert_update(e)) {
Expand All @@ -89,24 +87,133 @@ namespace sls {
has_tabu = true;
}
else if (m.is_bool(e) && m_ev.can_eval1(to_app(e))) {
if (!ctx.is_relevant(e))
continue;
bool is_true = ctx.is_true(e);
bool is_true_new = m_ev.bval1(to_app(e));
bool is_true_old = m_ev.bval1_tmp(to_app(e));
// verbose_stream() << "parent " << mk_bounded_pp(e, m) << " " << is_true << " " << is_true_new << " " << is_true_old << "\n";
if (is_true == is_true_new && is_true_new != is_true_old)
if (is_true_new == is_true_old)
continue;
if (is_true == is_true_new)
++make_count;
if (is_true == is_true_old && is_true_new != is_true_old)
if (is_true == is_true_old)
++break_count;
}
else {
IF_VERBOSE(1, verbose_stream() << "skipping " << mk_bounded_pp(e, m) << "\n");
has_tabu = true;
}
}
m_update_stack[depth].reset();
}
restore_lookahead();
// verbose_stream() << has_tabu << " " << new_value << " " << make_count << " " << break_count << "\n";

if (has_tabu)
return -10000;
return make_count - break_count;
}

void bv_lookahead::try_set(expr* e, bvect const& new_value) {
if (!wval(e).can_set(new_value))
return;
auto d = lookahead(e, new_value);
if (d > 0)
add_update(d, e, new_value);
}

void bv_lookahead::add_updates(expr* e) {
SASSERT(bv.is_bv(e));
auto& v = wval(e);
double d = 0;
while (m_v_saved.size() < v.bits().size()) {
m_v_saved.push_back(0);
m_v_updated.push_back(0);
}
m_v_saved.set_bw(v.bw);
m_v_updated.set_bw(v.bw);
v.bits().copy_to(v.nw, m_v_saved);
m_v_saved.copy_to(v.nw, m_v_updated);

// flip a single bit
for (unsigned i = 0; i < v.bw; ++i) {
m_v_updated.set(i, !m_v_updated.get(i));
try_set(e, m_v_updated);
//verbose_stream() << "flip " << d << " " << m_v_updated << "\n";
m_v_updated.set(i, !m_v_updated.get(i));
}
if (v.bw <= 1)
return;

// invert
for (unsigned i = 0; i < v.nw; ++i)
m_v_updated[i] = ~m_v_updated[i];
v.clear_overflow_bits(m_v_updated);
try_set(e, m_v_updated);

// increment
m_v_saved.copy_to(v.nw, m_v_updated);
v.add1(m_v_updated);
try_set(e, m_v_updated);

// decrement
m_v_saved.copy_to(v.nw, m_v_updated);
v.sub1(m_v_updated);
try_set(e, m_v_updated);

// random
v.get_variant(m_v_updated, m_ev.m_rand);
try_set(e, m_v_updated);
}

bool bv_lookahead::apply_update() {
double sum_score = 0;
for (unsigned i = 0; i < m_num_updates; ++i)
sum_score += m_updates[i].score;
double pos = (sum_score * m_ev.m_rand()) / (double)m_ev.m_rand.max_value();
for (unsigned i = 0; i < m_num_updates; ++i) {
auto const& [e, score, new_value] = m_updates[i];
pos -= score;
if (pos <= 0) {
//verbose_stream() << "apply " << mk_bounded_pp(e, m) << " new value " << new_value << " " << score << "\n";
apply_update(e, new_value);
return true;
}
}
return false;
}

void bv_lookahead::apply_update(expr* e, bvect const& new_value) {
SASSERT(bv.is_bv(e));
SASSERT(is_uninterp(e));
SASSERT(m_restore.empty());
wval(e).eval = new_value;
VERIFY(wval(e).commit_eval());
insert_update_stack(e);
unsigned max_depth = get_depth(e);
for (unsigned depth = max_depth; depth <= max_depth; ++depth) {
for (unsigned i = 0; i < m_update_stack[depth].size(); ++i) {
auto e = m_update_stack[depth][i];
if (bv.is_bv(e)) {
m_ev.eval(to_app(e)); // updates wval(e).eval
VERIFY(wval(e).commit_eval());
for (auto p : ctx.parents(e)) {
insert_update_stack(p);
max_depth = std::max(max_depth, get_depth(p));
}
}
else if (m.is_bool(e) && m_ev.can_eval1(to_app(e))) {
VERIFY(m_ev.repair_up(e));
}
else {
UNREACHABLE();
}
}
m_update_stack[depth].reset();
}
m_in_update_stack.reset();
}

bool bv_lookahead::insert_update(expr* e) {
m_restore.push_back(e);
m_on_restore.mark(e);
Expand All @@ -118,15 +225,18 @@ namespace sls {
void bv_lookahead::insert_update_stack(expr* e) {
unsigned depth = get_depth(e);
m_update_stack.reserve(depth + 1);
if (!m_update_stack[depth].contains(e))
if (!m_in_update_stack.is_marked(e)) {
m_in_update_stack.mark(e);
m_update_stack[depth].push_back(e);
}
}

void bv_lookahead::restore_lookahead() {
for (auto e : m_restore)
wval(e).restore_value();
m_restore.reset();
m_on_restore.reset();
m_in_update_stack.reset();
}

sls::bv_valuation& bv_lookahead::wval(expr* e) const {
Expand Down
Loading

0 comments on commit 5eb71c3

Please sign in to comment.