Skip to content

Commit

Permalink
force-push on new_eq, new_diseq in user propagator, other fixes to Py…
Browse files Browse the repository at this point in the history
…thon bindings for user propagator

This update allows the python bindings for user-propagator to handle functions that are declared to be registered with the user propagator plugin. It fixes a bug in UserPropagateBase.add to allow registering terms dynamically during search.
It also fixes a bug in theory_user_propagate as scopes were not fully pushed when the solver gets the callbacks for new equalities and new disequalities.
It also adds equality and disequality interfaces to the sat/smt solver version (which isn't being exercised in earnest yet)
  • Loading branch information
NikolajBjorner committed Jul 25, 2022
1 parent 3e38bbb commit 5c2c0ae
Show file tree
Hide file tree
Showing 7 changed files with 93 additions and 25 deletions.
52 changes: 46 additions & 6 deletions src/api/python/z3/z3.py
Original file line number Diff line number Diff line change
Expand Up @@ -11364,12 +11364,12 @@ def to_ContextObj(ptr,):
return ctx


def user_prop_fresh(ctx, new_ctx):
def user_prop_fresh(ctx, _new_ctx):
_prop_closures.set_threaded()
prop = _prop_closures.get(ctx)
nctx = Context()
Z3_del_context(nctx.ctx)
new_ctx = to_ContextObj(new_ctx)
new_ctx = to_ContextObj(_new_ctx)
nctx.ctx = new_ctx
nctx.eh = Z3_set_error_handler(new_ctx, z3_error_handler)
nctx.owner = False
Expand All @@ -11390,6 +11390,13 @@ def user_prop_fixed(ctx, cb, id, value):
prop.fixed(id, value)
prop.cb = None

def user_prop_created(ctx, cb, id):
prop = _prop_closures.get(ctx)
prop.cb = cb
id = _to_expr_ref(to_Ast(id), prop.ctx())
prop.created(id)
prop.cb = None

def user_prop_final(ctx, cb):
prop = _prop_closures.get(ctx)
prop.cb = cb
Expand Down Expand Up @@ -11417,10 +11424,32 @@ def user_prop_diseq(ctx, cb, x, y):
_user_prop_pop = Z3_pop_eh(user_prop_pop)
_user_prop_fresh = Z3_fresh_eh(user_prop_fresh)
_user_prop_fixed = Z3_fixed_eh(user_prop_fixed)
_user_prop_created = Z3_created_eh(user_prop_created)
_user_prop_final = Z3_final_eh(user_prop_final)
_user_prop_eq = Z3_eq_eh(user_prop_eq)
_user_prop_diseq = Z3_eq_eh(user_prop_diseq)

def PropagateFunction(name, *sig):
"""Create a function that gets tracked by user propagator.
Every term headed by this function symbol is tracked.
If a term is fixed and the fixed callback is registered a
callback is invoked that the term headed by this function is fixed.
"""
sig = _get_args(sig)
if z3_debug():
_z3_assert(len(sig) > 0, "At least two arguments expected")
arity = len(sig) - 1
rng = sig[arity]
if z3_debug():
_z3_assert(is_sort(rng), "Z3 sort expected")
dom = (Sort * arity)()
for i in range(arity):
if z3_debug():
_z3_assert(is_sort(sig[i]), "Z3 sort expected")
dom[i] = sig[i].ast
ctx = rng.ctx
return FuncDeclRef(Z3_solver_propagate_declare(ctx.ref(), to_symbol(name, ctx), arity, dom, rng.ast), ctx)


class UserPropagateBase:

Expand All @@ -11443,6 +11472,7 @@ def __init__(self, s, ctx=None):
self.final = None
self.eq = None
self.diseq = None
self.created = None
if ctx:
self.fresh_ctx = ctx
if s:
Expand Down Expand Up @@ -11473,6 +11503,13 @@ def add_fixed(self, fixed):
Z3_solver_propagate_fixed(self.ctx_ref(), self.solver.solver, _user_prop_fixed)
self.fixed = fixed

def add_created(self, created):
assert not self.created
assert not self._ctx
if self.solver:
Z3_solver_propagate_created(self.ctx_ref(), self.solver.solver, _user_prop_created)
self.created = created

def add_final(self, final):
assert not self.final
assert not self._ctx
Expand Down Expand Up @@ -11504,9 +11541,12 @@ def fresh(self, new_ctx):
raise Z3Exception("fresh needs to be overwritten")

def add(self, e):
assert self.solver
assert not self._ctx
Z3_solver_propagate_register(self.ctx_ref(), self.solver.solver, e.ast)
if self.solver:
Z3_solver_propagate_register(self.ctx_ref(), self.solver.solver, e.ast)
else:
Z3_solver_propagate_register_cb(self.ctx_ref(), ctypes.c_void_p(self.cb), e.ast)


#
# Propagation can only be invoked as during a fixed or final callback.
Expand All @@ -11519,5 +11559,5 @@ def propagate(self, e, ids, eqs=[]):
Z3_solver_propagate_consequence(e.ctx.ref(), ctypes.c_void_p(
self.cb), num_fixed, _ids, num_eqs, _lhs, _rhs, e.ast)

def conflict(self, deps):
self.propagate(BoolVal(False, self.ctx()), deps, eqs=[])
def conflict(self, deps = [], eqs = []):
self.propagate(BoolVal(False, self.ctx()), deps, eqs)
23 changes: 11 additions & 12 deletions src/ast/ast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1429,7 +1429,7 @@ ast_manager::~ast_manager() {
}
m_plugins.reset();
while (!m_ast_table.empty()) {
DEBUG_CODE(IF_VERBOSE(0, verbose_stream() << "ast_manager LEAKED: " << m_ast_table.size() << std::endl););
DEBUG_CODE(IF_VERBOSE(1, verbose_stream() << "ast_manager LEAKED: " << m_ast_table.size() << std::endl););
ptr_vector<ast> roots;
ast_mark mark;
for (ast * n : m_ast_table) {
Expand Down Expand Up @@ -1465,22 +1465,21 @@ ast_manager::~ast_manager() {
break;
}
}
for (ast * n : m_ast_table) {
if (!mark.is_marked(n)) {
for (ast * n : m_ast_table)
if (!mark.is_marked(n))
roots.push_back(n);
}
}

SASSERT(!roots.empty());
for (unsigned i = 0; i < roots.size(); ++i) {
ast* a = roots[i];
DEBUG_CODE(
std::cout << "Leaked: ";
if (is_sort(a)) {
std::cout << to_sort(a)->get_name() << "\n";
}
else {
std::cout << mk_ll_pp(a, *this, false) << "id: " << a->get_id() << "\n";
});
IF_VERBOSE(1,
verbose_stream() << "Leaked: ";
if (is_sort(a))
verbose_stream() << to_sort(a)->get_name() << "\n";
else
verbose_stream() << mk_ll_pp(a, *this, false) << "id: " << a->get_id() << "\n";
););
a->m_ref_count = 0;
delete_node(a);
}
Expand Down
15 changes: 15 additions & 0 deletions src/sat/smt/user_solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,21 @@ namespace user_solver {
m_id2justification.setx(v, lits, sat::literal_vector());
m_fixed_eh(m_user_context, this, var2expr(v), lit.sign() ? m.mk_false() : m.mk_true());
}

void solver::new_eq_eh(euf::th_eq const& eq) {
if (!m_eq_eh)
return;
force_push();
m_eq_eh(m_user_context, this, var2expr(eq.v1()), var2expr(eq.v2()));
}

void solver::new_diseq_eh(euf::th_eq const& de) {
if (!m_diseq_eh)
return;
force_push();
m_diseq_eh(m_user_context, this, var2expr(eq.v1()), var2expr(eq.v2()));
}


void solver::push_core() {
th_euf_solver::push_core();
Expand Down
4 changes: 4 additions & 0 deletions src/sat/smt/user_solver.h
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,10 @@ namespace user_solver {
bool get_case_split(sat::bool_var& var, lbool &phase) override;

void asserted(sat::literal lit) override;
bool use_diseqs() const override { return (bool)m_diseq_eh; }
void new_eq_eh(euf::th_eq const& eq) override;
void new_diseq_eh(euf::th_eq const& de) override;

sat::check_result check() override;
void push_core() override;
void pop_core(unsigned n) override;
Expand Down
3 changes: 3 additions & 0 deletions src/smt/smt_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -821,6 +821,8 @@ namespace smt {
SASSERT(t2 != null_theory_id);
theory_var v1 = m_fparams.m_new_core2th_eq ? get_closest_var(n1, t2) : r1->get_th_var(t2);

TRACE("merge_theory_vars", tout << get_theory(t2)->get_name() << ": " << v2 << " == " << v1 << "\n");

if (v1 != null_theory_var) {
// only send the equality to the theory, if the equality was not propagated by it.
if (t2 != from_th)
Expand All @@ -839,6 +841,7 @@ namespace smt {
SASSERT(v1 != null_theory_var);
SASSERT(t1 != null_theory_id);
theory_var v2 = r2->get_th_var(t1);
TRACE("merge_theory_vars", tout << get_theory(t1)->get_name() << ": " << v2 << " == " << v1 << "\n");
if (v2 == null_theory_var) {
r2->add_th_var(v1, t1, m_region);
push_new_th_diseqs(r2, v1, get_theory(t1));
Expand Down
16 changes: 11 additions & 5 deletions src/smt/theory_user_propagator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,9 @@ final_check_status theory_user_propagator::final_check_eh() {
catch (...) {
throw default_exception("Exception thrown in \"final\"-callback");
}
CTRACE("user_propagate", can_propagate(), tout << "can propagate\n");
propagate();
CTRACE("user_propagate", ctx.inconsistent(), tout << "inconsistent\n");
// check if it became inconsistent or something new was propagated/registered
bool done = (sz1 == m_prop.size()) && (sz2 == m_expr2var.size()) && !ctx.inconsistent();
return done ? FC_DONE : FC_CONTINUE;
Expand Down Expand Up @@ -298,13 +300,17 @@ void theory_user_propagator::propagate_consequence(prop_info const& prop) {
m_eqs.reset();
for (expr* id : prop.m_ids)
m_lits.append(m_id2justification[expr2var(id)]);
for (auto const& p : prop.m_eqs)
m_eqs.push_back(enode_pair(get_enode(expr2var(p.first)), get_enode(expr2var(p.second))));
DEBUG_CODE(for (auto const& p : m_eqs) VERIFY(p.first->get_root() == p.second->get_root()););
for (auto const& [a,b] : prop.m_eqs)
if (a != b)
m_eqs.push_back(enode_pair(get_enode(expr2var(a)), get_enode(expr2var(b))));
DEBUG_CODE(for (auto const& [a, b] : m_eqs) VERIFY(a->get_root() == b->get_root()););
DEBUG_CODE(for (expr* e : prop.m_ids) VERIFY(m_fixed.contains(expr2var(e))););
DEBUG_CODE(for (literal lit : m_lits) VERIFY(ctx.get_assignment(lit) == l_true););

TRACE("user_propagate", tout << "propagating #" << prop.m_conseq->get_id() << ": " << prop.m_conseq << "\n");
TRACE("user_propagate", tout << "propagating #" << prop.m_conseq->get_id() << ": " << prop.m_conseq << "\n";
for (auto const& [a,b] : m_eqs) tout << enode_pp(a, ctx) << " == " << enode_pp(b, ctx) << "\n";
for (expr* e : prop.m_ids) tout << mk_pp(e, m) << "\n";
for (literal lit : m_lits) tout << lit << "\n");

if (m.is_false(prop.m_conseq)) {
js = ctx.mk_justification(
Expand Down Expand Up @@ -341,9 +347,9 @@ void theory_user_propagator::propagate_new_fixed(prop_info const& prop) {


void theory_user_propagator::propagate() {
TRACE("user_propagate", tout << "propagating queue head: " << m_qhead << " prop queue: " << m_prop.size() << "\n");
if (m_qhead == m_prop.size() && m_to_add_qhead == m_to_add.size())
return;
TRACE("user_propagate", tout << "propagating queue head: " << m_qhead << " prop queue: " << m_prop.size() << "\n");
force_push();

unsigned qhead = m_to_add_qhead;
Expand Down
5 changes: 3 additions & 2 deletions src/smt/theory_user_propagator.h
Original file line number Diff line number Diff line change
Expand Up @@ -140,10 +140,11 @@ namespace smt {
bool get_case_split(bool_var& var, bool& is_pos);

theory * mk_fresh(context * new_ctx) override;
char const* get_name() const override { return "user_propagate"; }
bool internalize_atom(app* atom, bool gate_ctx) override;
bool internalize_term(app* term) override;
void new_eq_eh(theory_var v1, theory_var v2) override { if (m_eq_eh) m_eq_eh(m_user_context, this, var2expr(v1), var2expr(v2)); }
void new_diseq_eh(theory_var v1, theory_var v2) override { if (m_diseq_eh) m_diseq_eh(m_user_context, this, var2expr(v1), var2expr(v2)); }
void new_eq_eh(theory_var v1, theory_var v2) override { force_push(); if (m_eq_eh) m_eq_eh(m_user_context, this, var2expr(v1), var2expr(v2)); }
void new_diseq_eh(theory_var v1, theory_var v2) override { force_push(); if (m_diseq_eh) m_diseq_eh(m_user_context, this, var2expr(v1), var2expr(v2)); }
bool use_diseqs() const override { return ((bool)m_diseq_eh); }
bool build_models() const override { return false; }
final_check_status final_check_eh() override;
Expand Down

0 comments on commit 5c2c0ae

Please sign in to comment.