diff --git a/scripts/mk_project.py b/scripts/mk_project.py index 745d7bf5493..59faf0f0641 100644 --- a/scripts/mk_project.py +++ b/scripts/mk_project.py @@ -39,6 +39,7 @@ def init_project_def(): add_lib('solver', ['model', 'tactic', 'proofs']) add_lib('cmd_context', ['solver', 'rewriter']) add_lib('sat_tactic', ['tactic', 'sat', 'solver'], 'sat/tactic') + add_lib('sat_euf', ['sat_tactic', 'sat', 'euf'], 'sat/euf') add_lib('smt2parser', ['cmd_context', 'parser_util'], 'parsers/smt2') add_lib('pattern', ['normal_forms', 'smt2parser', 'rewriter'], 'ast/pattern') add_lib('core_tactics', ['tactic', 'macros', 'normal_forms', 'rewriter', 'pattern'], 'tactic/core') @@ -80,7 +81,7 @@ def init_project_def(): includes2install=['z3.h', 'z3_v1.h', 'z3_macros.h'] + API_files) add_lib('extra_cmds', ['cmd_context', 'subpaving_tactic', 'qe', 'arith_tactics'], 'cmd_context/extra_cmds') add_exe('shell', ['api', 'sat', 'extra_cmds','opt'], exe_name='z3') - add_exe('test', ['api', 'fuzzing', 'simplex', 'euf'], exe_name='test-z3', install=False) + add_exe('test', ['api', 'fuzzing', 'simplex', 'sat_euf'], exe_name='test-z3', install=False) _libz3Component = add_dll('api_dll', ['api', 'sat', 'extra_cmds'], 'api/dll', reexports=['api'], dll_name='libz3', diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 5e1ea8f5178..e18e38f6939 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -60,6 +60,7 @@ add_subdirectory(math/subpaving/tactic) add_subdirectory(tactic/aig) add_subdirectory(solver) add_subdirectory(sat/tactic) +add_subdirectory(sat/euf) add_subdirectory(tactic/arith) add_subdirectory(nlsat/tactic) add_subdirectory(ackermannization) diff --git a/src/ast/euf/euf_egraph.cpp b/src/ast/euf/euf_egraph.cpp index b05df0cff7d..c68f743a198 100644 --- a/src/ast/euf/euf_egraph.cpp +++ b/src/ast/euf/euf_egraph.cpp @@ -35,8 +35,7 @@ namespace euf { unmerge_justification(n1); } - enode* egraph::mk_enode(expr* f, enode * const* args) { - unsigned num_args = is_app(f) ? to_app(f)->get_num_args() : 0; + enode* egraph::mk_enode(expr* f, unsigned num_args, enode * const* args) { enode* n = enode::mk(m_region, f, num_args, args); m_nodes.push_back(n); m_exprs.push_back(f); @@ -98,15 +97,15 @@ namespace euf { n->set_update_children(); } - enode* egraph::mk(expr* f, enode *const* args) { + enode* egraph::mk(expr* f, unsigned num_args, enode *const* args) { SASSERT(!find(f)); force_push(); - enode *n = mk_enode(f, args); + enode *n = mk_enode(f, num_args, args); SASSERT(n->class_size() == 1); m_expr2enode.setx(f->get_id(), n, nullptr); - if (n->num_args() == 0 && m.is_unique_value(f)) + if (num_args == 0 && m.is_unique_value(f)) n->mark_interpreted(); - if (n->num_args() == 0) + if (num_args == 0) return n; if (is_equality(n)) { update_children(n); @@ -171,6 +170,8 @@ namespace euf { std::swap(r1, r2); std::swap(n1, n2); } + if ((m.is_true(r2->get_owner()) || m.is_false(r2->get_owner())) && j.is_congruence()) + m_new_lits.push_back(n1); for (enode* p : enode_parents(n1)) m_table.erase(p); for (enode* p : enode_parents(n2)) @@ -187,6 +188,7 @@ namespace euf { void egraph::propagate() { m_new_eqs.reset(); + m_new_lits.reset(); SASSERT(m_num_scopes == 0 || m_worklist.empty()); unsigned head = 0, tail = m_worklist.size(); while (head < tail && m.limit().inc() && !inconsistent()) { @@ -239,6 +241,88 @@ namespace euf { SASSERT(n1->get_root()->m_target == nullptr); } + /** + \brief generate an explanation for a congruence. + Each pair of children under a congruence have the same roots + and therefore have a least common ancestor. We only need + explanations up to the least common ancestors. + */ + void egraph::push_congruence(enode* n1, enode* n2, bool comm) { + SASSERT(n1->get_decl() == n2->get_decl()); + if (comm && + n1->get_arg(0)->get_root() == n2->get_arg(1)->get_root() && + n1->get_arg(1)->get_root() == n2->get_arg(0)->get_root()) { + push_lca(n1->get_arg(0), n2->get_arg(1)); + push_lca(n1->get_arg(1), n2->get_arg(0)); + return; + } + + for (unsigned i = 0; i < n1->num_args(); ++i) + push_lca(n1->get_arg(i), n2->get_arg(i)); + } + + void egraph::push_lca(enode* a, enode* b) { + SASSERT(a->get_root() == b->get_root()); + enode* n = a; + while (n) { + n->mark2(); + n = n->m_target; + } + n = b; + while (n) { + if (n->is_marked2()) + n->unmark2(); + else if (!n->is_marked1()) + m_todo.push_back(n); + n = n->m_target; + } + n = a; + while (n->is_marked2()) { + n->unmark2(); + if (!n->is_marked1()) + m_todo.push_back(n); + n = n->m_target; + } + } + + void egraph::push_todo(enode* n) { + while (n) { + m_todo.push_back(n); + n = n->m_target; + } + } + + template + void egraph::explain(ptr_vector& justifications) { + SASSERT(m_inconsistent); + SASSERT(m_todo.empty()); + push_todo(m_n1); + push_todo(m_n2); + explain_eq(justifications, m_n1, m_n2, m_justification); + explain_todo(justifications); + } + + template + void egraph::explain_eq(ptr_vector& justifications, enode* a, enode* b, bool comm) { + SASSERT(m_todo.empty()); + push_congruence(a, b, comm); + explain_todo(justifications); + } + + template + void egraph::explain_todo(ptr_vector& justifications) { + for (unsigned i = 0; i < m_todo.size(); ++i) { + enode* n = m_todo[i]; + if (n->m_target && !n->is_marked1()) { + n->mark1(); + explain_eq(justifications, n, n->m_target, n->m_justification); + } + } + for (enode* n : m_todo) + n->unmark1(); + m_todo.reset(); + } + void egraph::invariant() { for (enode* n : m_nodes) n->invariant(); @@ -267,3 +351,12 @@ namespace euf { return out; } } + +template void euf::egraph::explain(ptr_vector& justifications); +template void euf::egraph::explain_todo(ptr_vector& justifications); +template void euf::egraph::explain_eq(ptr_vector& justifications, enode* a, enode* b, bool comm); + +template void euf::egraph::explain(ptr_vector& justifications); +template void euf::egraph::explain_todo(ptr_vector& justifications); +template void euf::egraph::explain_eq(ptr_vector& justifications, enode* a, enode* b, bool comm); + diff --git a/src/ast/euf/euf_egraph.h b/src/ast/euf/euf_egraph.h index 2f4bfed898b..6844a2dc507 100644 --- a/src/ast/euf/euf_egraph.h +++ b/src/ast/euf/euf_egraph.h @@ -58,13 +58,14 @@ namespace euf { enode *m_n2 { nullptr }; justification m_justification; enode_vector m_new_eqs; + enode_vector m_new_lits; enode_vector m_todo; void push_eq(enode* r1, enode* n1, unsigned r2_num_parents) { m_eqs.push_back(add_eq_record(r1, n1, r2_num_parents)); } void undo_eq(enode* r1, enode* n1, unsigned r2_num_parents); - enode* mk_enode(expr* f, enode * const* args); + enode* mk_enode(expr* f, unsigned num_args, enode * const* args); void reinsert(enode* n); void force_push(); void set_conflict(enode* n1, enode* n2, justification j); @@ -72,15 +73,30 @@ namespace euf { void merge_justification(enode* n1, enode* n2, justification j); void unmerge_justification(enode* n1); void dedup_equalities(); - bool is_equality(enode* n) const; void reinsert_equality(enode* p); void update_children(enode* n); + void push_lca(enode* a, enode* b); + void push_congruence(enode* n1, enode* n2, bool commutative); + void push_todo(enode* n); + template + void explain_eq(ptr_vector& justifications, enode* a, enode* b, justification const& j) { + if (j.is_external()) + justifications.push_back(j.ext()); + else if (j.is_congruence()) + push_congruence(a, b, j.is_commutative()); + } + template + void explain_todo(ptr_vector& justifications); + public: egraph(ast_manager& m): m(m), m_table(m), m_exprs(m) {} enode* find(expr* f) { return m_expr2enode.get(f->get_id(), nullptr); } - enode* mk(expr* f, enode *const* args); + enode* mk(expr* f, unsigned n, enode *const* args); void push() { ++m_num_scopes; } void pop(unsigned num_scopes); + + bool is_equality(enode* n) const; + /** \brief merge nodes, all effects are deferred to the propagation step. */ @@ -98,45 +114,11 @@ namespace euf { void propagate(); bool inconsistent() const { return m_inconsistent; } enode_vector const& new_eqs() const { return m_new_eqs; } + enode_vector const& new_lits() const { return m_new_lits; } template - void explain(ptr_vector& justifications) { - SASSERT(m_inconsistent); - SASSERT(m_todo.empty()); - auto push_congruence = [&](enode* p, enode* q) { - SASSERT(p->get_decl() == q->get_decl()); - for (enode* arg : enode_args(p)) - m_todo.push_back(arg); - for (enode* arg : enode_args(q)) - m_todo.push_back(arg); - }; - auto explain_node = [&](enode* n) { - if (!n->m_target) - return; - if (n->is_marked1()) - return; - n->mark1(); - if (n->m_justification.is_external()) - justifications.push_back(n->m_justification.ext()); - else if (n->m_justification.is_congruence()) - push_congruence(n, n->m_target); - n = n->m_target; - if (!n->is_marked1()) - m_todo.push_back(n); - }; - m_todo.push_back(m_n1); - m_todo.push_back(m_n2); - if (m_justification.is_external()) - justifications.push_back(m_justification.ext()); - else if (m_justification.is_congruence()) - push_congruence(m_n1, m_n2); - for (unsigned i = 0; i < m_todo.size(); ++i) - explain_node(m_todo[i]); - for (enode* n : m_todo) - n->unmark1(); - m_todo.reset(); - } - - + void explain(ptr_vector& justifications); + template + void explain_eq(ptr_vector& justifications, enode* a, enode* b, bool comm); void invariant(); std::ostream& display(std::ostream& out) const; }; diff --git a/src/ast/euf/euf_enode.h b/src/ast/euf/euf_enode.h index 39034bd0edd..776747db270 100644 --- a/src/ast/euf/euf_enode.h +++ b/src/ast/euf/euf_enode.h @@ -43,6 +43,7 @@ namespace euf { enode* m_root; enode* m_target { nullptr }; justification m_justification; + unsigned m_num_args; enode* m_args[0]; friend class enode_args; @@ -56,12 +57,14 @@ namespace euf { } static enode* mk(region& r, expr* f, unsigned num_args, enode* const* args) { + SASSERT(num_args <= (is_app(f) ? to_app(f)->get_num_args() : 0)); void* mem = r.allocate(get_enode_size(num_args)); enode* n = new (mem) enode(); n->m_owner = f; n->m_next = n; n->m_root = n; n->m_commutative = num_args == 2 && is_app(f) && to_app(f)->get_decl()->is_commutative(); + n->m_num_args = num_args; for (unsigned i = 0; i < num_args; ++i) { SASSERT(to_app(f)->get_arg(i) == args[i]->get_owner()); n->m_args[i] = args[i]; @@ -83,9 +86,10 @@ namespace euf { } enode* const* args() const { return m_args; } - unsigned num_args() const { return is_app(m_owner) ? to_app(m_owner)->get_num_args() : 0; } + unsigned num_args() const { return m_num_args; } unsigned num_parents() const { return m_parents.size(); } bool interpreted() const { return m_interpreted; } + bool commutative() const { return m_commutative; } void mark_interpreted() { SASSERT(num_args() == 0); m_interpreted = true; } enode* get_arg(unsigned i) const { SASSERT(i < num_args()); return m_args[i]; } @@ -97,6 +101,9 @@ namespace euf { void mark1() { m_mark1 = true; } void unmark1() { m_mark1 = false; } bool is_marked1() { return m_mark1; } + void mark2() { m_mark2 = true; } + void unmark2() { m_mark2 = false; } + bool is_marked2() { return m_mark2; } void add_parent(enode* p) { m_parents.push_back(p); } unsigned class_size() const { return m_class_size; } enode* get_root() const { return m_root; } diff --git a/src/sat/euf/CMakeLists.txt b/src/sat/euf/CMakeLists.txt new file mode 100644 index 00000000000..0be16aa8b75 --- /dev/null +++ b/src/sat/euf/CMakeLists.txt @@ -0,0 +1,8 @@ +z3_add_component(sat_euf + SOURCES + euf_solver.cpp + COMPONENT_DEPENDENCIES + sat + sat_tactic + euf +) diff --git a/src/sat/euf/euf_solver.cpp b/src/sat/euf/euf_solver.cpp new file mode 100644 index 00000000000..0b5a257dc27 --- /dev/null +++ b/src/sat/euf/euf_solver.cpp @@ -0,0 +1,186 @@ +/*++ +Copyright (c) 2020 Microsoft Corporation + +Module Name: + + euf_solver.cpp + +Abstract: + + Solver plugin for EUF + +Author: + + Nikolaj Bjorner (nbjorner) 2020-08-25 + +--*/ +#include "sat/euf/euf_solver.h" +#include "sat/sat_solver.h" +#include "tactic/tactic_exception.h" + +namespace euf_sat { + + bool solver::propagate(literal l, ext_constraint_idx idx) { + UNREACHABLE(); + return true; + } + + void solver::get_antecedents(literal l, ext_justification_idx idx, literal_vector & r) { + m_explain.reset(); + euf::enode* n = nullptr; + bool sign = false; + if (idx != 0) { + auto p = m_var2node[l.var()]; + n = p.first; + sign = l.sign() != p.second; + } + + switch (idx) { + case 0: + SASSERT(m_egraph.inconsistent()); + m_egraph.explain(m_explain); + break; + case 1: + SASSERT(m_egraph.is_equality(n)); + m_egraph.explain_eq(m_explain, n->get_arg(0), n->get_arg(1), n->commutative()); + break; + case 2: + SASSERT(m.is_bool(n->get_owner())); + m_egraph.explain_eq(m_explain, n, (sign ? m_false : m_true), false); + break; + default: + UNREACHABLE(); + } + for (unsigned* idx : m_explain) + r.push_back(sat::to_literal((unsigned)(idx - base_ptr()))); + } + + void solver::asserted(literal l) { + auto p = m_var2node[l.var()]; + bool sign = p.second != l.sign(); + euf::enode* n = p.first; + expr* e = n->get_owner(); + if (m.is_eq(e) && !sign) { + euf::enode* na = n->get_arg(0); + euf::enode* nb = n->get_arg(1); + m_egraph.merge(na, nb, base_ptr() + l.index()); + } + else { + euf::enode* nb = sign ? m_false : m_true; + m_egraph.merge(n, nb, base_ptr() + l.index()); + } + + // TBD: delay propagation? + m_egraph.propagate(); + if (m_egraph.inconsistent()) { + s().set_conflict(sat::justification::mk_ext_justification(s().scope_lvl(), 0)); + return; + } + for (euf::enode* eq : m_egraph.new_eqs()) { + bool_var v = m_expr2var.to_bool_var(eq->get_owner()); + s().assign(literal(v, false), sat::justification::mk_ext_justification(s().scope_lvl(), 1)); + } + for (euf::enode* p : m_egraph.new_lits()) { + expr* e = p->get_owner(); + bool sign = m.is_false(p->get_root()->get_owner()); + SASSERT(m.is_bool(e)); + SASSERT(m.is_true(p->get_root()->get_owner()) || sign); + bool_var v = m_expr2var.to_bool_var(e); + s().assign(literal(v, sign), sat::justification::mk_ext_justification(s().scope_lvl(), 2)); + } + } + + sat::check_result solver::check() { + return sat::CR_DONE; + } + void solver::push() { + m_egraph.push(); + } + void solver::pop(unsigned n) { + m_egraph.pop(n); + } + void solver::pre_simplify() {} + void solver::simplify() {} + // have a way to replace l by r in all constraints + void solver::clauses_modifed() {} + lbool solver::get_phase(bool_var v) { return l_undef; } + std::ostream& solver::display(std::ostream& out) const { + m_egraph.display(out); + return out; + } + std::ostream& solver::display_justification(std::ostream& out, ext_justification_idx idx) const { return out; } + std::ostream& solver::display_constraint(std::ostream& out, ext_constraint_idx idx) const { return out; } + void solver::collect_statistics(statistics& st) const {} + sat::extension* solver::copy(sat::solver* s) { return nullptr; } + sat::extension* solver::copy(sat::lookahead* s, bool learned) { return nullptr; } + void solver::find_mutexes(literal_vector& lits, vector & mutexes) {} + void solver::gc() {} + void solver::pop_reinit() {} + bool solver::validate() { return true; } + void solver::init_use_list(sat::ext_use_list& ul) {} + bool solver::is_blocked(literal l, ext_constraint_idx) { return false; } + bool solver::check_model(sat::model const& m) const { return true;} + unsigned solver::max_var(unsigned w) const { return w; } + + void solver::internalize(sat_internalizer& si, expr* e) { + SASSERT(!si.is_bool_op(e)); + unsigned sz = m_stack.size(); + euf::enode* n = visit(si, e); + while (m_stack.size() > sz) { + loop: + if (!m.inc()) + throw tactic_exception(m.limit().get_cancel_msg()); + frame & fr = m_stack.back(); + expr* e = fr.m_e; + if (m_egraph.find(e)) { + m_stack.pop_back(); + continue; + } + unsigned num = is_app(e) ? to_app(e)->get_num_args() : 0; + m_args.reset(); + while (fr.m_idx < num) { + expr* arg = to_app(e)->get_arg(fr.m_idx); + fr.m_idx++; + n = visit(si, arg); + if (!n) + goto loop; + m_args.push_back(n); + } + n = m_egraph.mk(e, num, m_args.c_ptr()); + attach_bool_var(si, n); + } + SASSERT(m_egraph.find(e)); + } + + euf::enode* solver::visit(sat_internalizer& si, expr* e) { + euf::enode* n = m_egraph.find(e); + if (n) + return n; + if (si.is_bool_op(e)) { + sat::literal lit = si.internalize(e); + n = m_egraph.mk(e, 0, nullptr); + attach_bool_var(lit.var(), lit.sign(), n); + s().set_external(lit.var()); + return n; + } + if (is_app(e) && to_app(e)->get_num_args() > 0) + return nullptr; + n = m_egraph.mk(e, 0, nullptr); + attach_bool_var(si, n); + return n; + } + + void solver::attach_bool_var(sat_internalizer& si, euf::enode* n) { + expr* e = n->get_owner(); + if (m.is_bool(e)) { + sat::bool_var v = si.add_bool_var(e); + attach_bool_var(v, false, n); + } + } + + void solver::attach_bool_var(sat::bool_var v, bool sign, euf::enode* n) { + m_var2node.reserve(v + 1); + m_var2node[v] = euf::enode_bool_pair(n, sign); + } + +} diff --git a/src/sat/euf/euf_solver.h b/src/sat/euf/euf_solver.h new file mode 100644 index 00000000000..6a98ea93e85 --- /dev/null +++ b/src/sat/euf/euf_solver.h @@ -0,0 +1,98 @@ +/*++ +Copyright (c) 2020 Microsoft Corporation + +Module Name: + + euf_solver.h + +Abstract: + + Solver plugin for EUF + +Author: + + Nikolaj Bjorner (nbjorner) 2020-08-25 + +--*/ +#pragma once + +#include "sat/sat_extension.h" +#include "ast/euf/euf_egraph.h" +#include "sat/tactic/atom2bool_var.h" +#include "sat/tactic/goal2sat.h" + +namespace euf_sat { + typedef sat::literal literal; + typedef sat::ext_constraint_idx ext_constraint_idx; + typedef sat::ext_justification_idx ext_justification_idx; + typedef sat::literal_vector literal_vector; + typedef sat::bool_var bool_var; + + struct frame { + expr* m_e; + unsigned m_idx; + frame(expr* e) : m_e(e), m_idx(0) {} + }; + + class solver : public sat::extension { + ast_manager& m; + atom2bool_var& m_expr2var; + euf::egraph m_egraph; + sat::solver* m_solver; + + euf::enode* m_true; + euf::enode* m_false; + svector m_var2node; + ptr_vector m_explain; + euf::enode_vector m_args; + svector m_stack; + + sat::solver& s() { return *m_solver; } + unsigned * base_ptr() { return reinterpret_cast(this); } + euf::enode* visit(sat_internalizer& si, expr* e); + void attach_bool_var(sat_internalizer& si, euf::enode* n); + void attach_bool_var(sat::bool_var v, bool sign, euf::enode* n); + + public: + solver(ast_manager& m, atom2bool_var& expr2var): + m(m), + m_expr2var(expr2var), + m_egraph(m), + m_solver(nullptr) + {} + + void set_solver(sat::solver* s) override { m_solver = s; } + void set_lookahead(sat::lookahead* s) override { } + double get_reward(literal l, ext_constraint_idx idx, sat::literal_occs_fun& occs) const override { return 0; } + bool is_extended_binary(ext_justification_idx idx, literal_vector & r) override { return false; } + + bool propagate(literal l, ext_constraint_idx idx) override; + void get_antecedents(literal l, ext_justification_idx idx, literal_vector & r) override; + void asserted(literal l) override; + sat::check_result check() override; + void push() override; + void pop(unsigned n) override; + void pre_simplify() override; + void simplify() override; + // have a way to replace l by r in all constraints + void clauses_modifed() override; + lbool get_phase(bool_var v) override; + std::ostream& display(std::ostream& out) const override; + std::ostream& display_justification(std::ostream& out, ext_justification_idx idx) const override; + std::ostream& display_constraint(std::ostream& out, ext_constraint_idx idx) const override; + void collect_statistics(statistics& st) const override; + extension* copy(sat::solver* s) override; + extension* copy(sat::lookahead* s, bool learned) override; + void find_mutexes(literal_vector& lits, vector & mutexes) override; + void gc() override; + void pop_reinit() override; + bool validate() override; + void init_use_list(sat::ext_use_list& ul) override; + bool is_blocked(literal l, ext_constraint_idx) override; + bool check_model(sat::model const& m) const override; + unsigned max_var(unsigned w) const override; + + void internalize(sat_internalizer& si, expr* e); + + }; +}; diff --git a/src/sat/tactic/goal2sat.cpp b/src/sat/tactic/goal2sat.cpp index 664563cc210..e0ede3523a6 100644 --- a/src/sat/tactic/goal2sat.cpp +++ b/src/sat/tactic/goal2sat.cpp @@ -42,7 +42,7 @@ Module Name: #include "tactic/generic_model_converter.h" #include -struct goal2sat::imp { +struct goal2sat::imp : public sat_internalizer { struct frame { app * m_t; unsigned m_root:1; @@ -131,6 +131,18 @@ struct goal2sat::imp { return m_true; } + sat::bool_var add_bool_var(expr* t) override { + sat::bool_var v = m_map.to_bool_var(t); + if (v == sat::null_bool_var) { + v = m_solver.add_var(true); + m_map.insert(t, v); + } + else { + m_solver.set_external(v); + } + return v; + } + void convert_atom(expr * t, bool root, bool sign) { SASSERT(m.is_bool(t)); sat::literal l; @@ -149,7 +161,7 @@ struct goal2sat::imp { } else { bool ext = m_default_external || !is_uninterp_const(t) || m_interface_vars.contains(t); - sat::bool_var v = m_solver.add_var(ext); + v = m_solver.add_var(ext); m_map.insert(t, v); l = sat::literal(v, sign); TRACE("sat", tout << "new_var: " << v << ": " << mk_bounded_pp(t, m, 2) << " " << is_uninterp_const(t) << "\n";); @@ -812,7 +824,7 @@ struct goal2sat::imp { } } - sat::literal internalize(expr* n) { + sat::literal internalize(expr* n) override { SASSERT(m_result_stack.empty()); process(n, false); SASSERT(m_result_stack.size() == 1); @@ -820,6 +832,30 @@ struct goal2sat::imp { m_result_stack.reset(); return result; } + + bool is_bool_op(expr* t) const override { + if (!is_app(t)) + return false; + if (to_app(t)->get_family_id() == m.get_basic_family_id()) { + switch (to_app(t)->get_decl_kind()) { + case OP_OR: + case OP_AND: + case OP_TRUE: + case OP_FALSE: + case OP_NOT: + return true; + case OP_ITE: + case OP_EQ: + return m.is_bool(to_app(t)->get_arg(1)); + default: + return false; + } + } + else if (to_app(t)->get_family_id() == pb.get_family_id()) + return true; + else + return false; + } void process(expr * n) { m_result_stack.reset(); diff --git a/src/sat/tactic/goal2sat.h b/src/sat/tactic/goal2sat.h index a1ea6a78e34..47a12e93fcc 100644 --- a/src/sat/tactic/goal2sat.h +++ b/src/sat/tactic/goal2sat.h @@ -34,6 +34,14 @@ Module Name: #include "tactic/generic_model_converter.h" #include "sat/tactic/atom2bool_var.h" + +class sat_internalizer { +public: + virtual bool is_bool_op(expr* e) const = 0; + virtual sat::literal internalize(expr* e) = 0; + virtual sat::bool_var add_bool_var(expr* e) = 0; +}; + class goal2sat { struct imp; imp * m_imp; diff --git a/src/test/egraph.cpp b/src/test/egraph.cpp index f311b503836..692829ff3e6 100644 --- a/src/test/egraph.cpp +++ b/src/test/egraph.cpp @@ -83,6 +83,8 @@ static void test2() { } } + + static void test3() { ast_manager m; reg_decl_plugins(m);