diff --git a/src/ast/euf/euf_egraph.cpp b/src/ast/euf/euf_egraph.cpp index 691c42981e0..4e8e7327f43 100644 --- a/src/ast/euf/euf_egraph.cpp +++ b/src/ast/euf/euf_egraph.cpp @@ -129,7 +129,7 @@ namespace euf { return n; } - egraph::egraph(ast_manager& m) : m(m), m_table(m), m_tmp_app(2), m_exprs(m) { + egraph::egraph(ast_manager& m) : m(m), m_table(m), m_tmp_app(2), m_exprs(m), m_eq_decls(m) { m_tmp_eq = enode::mk_tmp(m_region, 2); } @@ -592,7 +592,7 @@ namespace euf { SASSERT(!n1->get_root()->m_target); } - bool egraph::are_diseq(enode* a, enode* b) const { + bool egraph::are_diseq(enode* a, enode* b) { enode* ra = a->get_root(), * rb = b->get_root(); if (ra == rb) return false; @@ -600,12 +600,16 @@ namespace euf { return true; if (ra->get_sort() != rb->get_sort()) return true; - expr_ref eq(m.mk_eq(a->get_expr(), b->get_expr()), m); - m_tmp_eq->m_args[0] = a; - m_tmp_eq->m_args[1] = b; - m_tmp_eq->m_expr = eq; - SASSERT(m_tmp_eq->num_args() == 2); - enode* r = m_table.find(m_tmp_eq); + if (ra->num_parents() > rb->num_parents()) + std::swap(ra, rb); + if (ra->num_parents() <= 3) { + for (enode* p : enode_parents(ra)) + if (p->is_equality() && p->get_root()->value() == l_false && + (rb == p->get_arg(0)->get_root() || rb == p->get_arg(1)->get_root())) + return true; + return false; + } + enode* r = tmp_eq(ra, rb); if (r && r->get_root()->value() == l_false) return true; return false; @@ -617,6 +621,24 @@ namespace euf { return find(m_tmp_app.get_app(), num_args, args); } + enode* egraph::tmp_eq(enode* a, enode* b) { + func_decl* f = nullptr; + for (unsigned i = m_eq_decls.size(); i-- > 0; ) { + auto e = m_eq_decls.get(i); + if (e->get_domain(0) == a->get_sort()) { + f = e; + break; + } + } + if (!f) { + app_ref eq(m.mk_eq(a->get_expr(), b->get_expr()), m); + m_eq_decls.push_back(eq->get_decl()); + f = eq->get_decl(); + } + enode* args[2] = { a, b }; + return get_enode_eq_to(f, 2, args); + } + /** \brief generate an explanation for a congruence. Each pair of children under a congruence have the same roots @@ -714,12 +736,7 @@ namespace euf { explain_eq(justifications, b, rb); return sat::null_bool_var; } - expr_ref eq(m.mk_eq(a->get_expr(), b->get_expr()), m); - m_tmp_eq->m_args[0] = a; - m_tmp_eq->m_args[1] = b; - m_tmp_eq->m_expr = eq; - SASSERT(m_tmp_eq->num_args() == 2); - enode* r = m_table.find(m_tmp_eq); + enode* r = tmp_eq(a, b); SASSERT(r && r->get_root()->value() == l_false); explain_eq(justifications, r, r->get_root()); return r->get_root()->bool_var(); diff --git a/src/ast/euf/euf_egraph.h b/src/ast/euf/euf_egraph.h index 5c828678e8a..7c1f9e5666b 100644 --- a/src/ast/euf/euf_egraph.h +++ b/src/ast/euf/euf_egraph.h @@ -165,6 +165,7 @@ namespace euf { tmp_app m_tmp_app; enode_vector m_nodes; expr_ref_vector m_exprs; + func_decl_ref_vector m_eq_decls; vector m_decl2enodes; enode_vector m_empty_enodes; unsigned m_num_scopes = 0; @@ -263,10 +264,12 @@ namespace euf { /** * \brief check if two nodes are known to be disequal. */ - bool are_diseq(enode* a, enode* b) const; + bool are_diseq(enode* a, enode* b); enode* get_enode_eq_to(func_decl* f, unsigned num_args, enode* const* args); + enode* tmp_eq(enode* a, enode* b); + /** \brief Maintain and update cursor into propagated consequences. The result of get_literal() is a pair (n, is_eq) diff --git a/src/sat/smt/q_eval.cpp b/src/sat/smt/q_eval.cpp index 200b05ec944..c506a52f7e2 100644 --- a/src/sat/smt/q_eval.cpp +++ b/src/sat/smt/q_eval.cpp @@ -25,7 +25,10 @@ namespace q { struct eval::scoped_mark_reset { eval& e; scoped_mark_reset(eval& e): e(e) {} - ~scoped_mark_reset() { e.m_mark.reset(); } + ~scoped_mark_reset() { + e.m_mark.reset(); + e.m_diseq_undef = euf::enode_pair(); + } }; eval::eval(euf::solver& ctx): @@ -97,12 +100,18 @@ namespace q { if (sn && sn == tn) return l_true; + if (sn && sn == m_diseq_undef.first && tn == m_diseq_undef.second) + return l_undef; + if (sn && tn && ctx.get_egraph().are_diseq(sn, tn)) { evidence.push_back(euf::enode_pair(sn, tn)); return l_false; } - if (sn && tn) + if (sn && tn) { + m_diseq_undef = euf::enode_pair(sn, tn); return l_undef; + } + if (!sn && !tn) return compare_rec(n, binding, s, t, evidence); @@ -115,7 +124,11 @@ namespace q { std::swap(t, s); } unsigned sz = evidence.size(); - for (euf::enode* t1 : euf::enode_class(tn)) { + unsigned count = 0; + for (euf::enode* t1 : euf::enode_class(tn)) { + if (!t1->is_cgr()) + continue; + ++count; expr* t2 = t1->get_expr(); if ((c = compare_rec(n, binding, s, t2, evidence), c != l_undef)) { evidence.push_back(euf::enode_pair(t1, tn)); diff --git a/src/sat/smt/q_eval.h b/src/sat/smt/q_eval.h index 0ead01061fa..76c21934305 100644 --- a/src/sat/smt/q_eval.h +++ b/src/sat/smt/q_eval.h @@ -31,6 +31,7 @@ namespace q { euf::enode_vector m_eval; euf::enode_vector m_indirect_nodes; bool m_freeze_swap = false; + euf::enode_pair m_diseq_undef; struct scoped_mark_reset;