From 33f4e65fa919349501e7511669131f6742fc6b1b Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Tue, 5 Oct 2021 10:15:56 -0700 Subject: [PATCH] redo bindings/fingerprints Signed-off-by: Nikolaj Bjorner --- src/sat/smt/q_clause.h | 81 +++++++++++++++++++++------ src/sat/smt/q_ematch.cpp | 107 ++++++++++++++++++------------------ src/sat/smt/q_ematch.h | 14 ++--- src/sat/smt/q_fingerprint.h | 77 -------------------------- src/sat/smt/q_queue.cpp | 22 ++++---- src/sat/smt/q_queue.h | 18 +++--- 6 files changed, 142 insertions(+), 177 deletions(-) delete mode 100644 src/sat/smt/q_fingerprint.h diff --git a/src/sat/smt/q_clause.h b/src/sat/smt/q_clause.h index 66daf07ea8b..08a6f615a1a 100644 --- a/src/sat/smt/q_clause.h +++ b/src/sat/smt/q_clause.h @@ -22,6 +22,7 @@ Module Name: #include "ast/euf/euf_enode.h" #include "sat/smt/euf_solver.h" + namespace q { struct lit { @@ -35,14 +36,40 @@ namespace q { std::ostream& display(std::ostream& out) const; }; + struct binding; + + struct clause { + unsigned m_index; + vector m_lits; + quantifier_ref m_q; + unsigned m_watch = 0; + sat::literal m_literal = sat::null_literal; + q::quantifier_stat* m_stat = nullptr; + binding* m_bindings = nullptr; + + + clause(ast_manager& m, unsigned idx) : m_index(idx), m_q(m) {} + + std::ostream& display(euf::solver& ctx, std::ostream& out) const; + lit const& operator[](unsigned i) const { return m_lits[i]; } + lit& operator[](unsigned i) { return m_lits[i]; } + unsigned size() const { return m_lits.size(); } + unsigned num_decls() const { return m_q->get_num_decls(); } + unsigned index() const { return m_index; } + quantifier* q() const { return m_q; } + }; + + struct binding : public dll_base { + clause* c; app* m_pattern; unsigned m_max_generation; unsigned m_min_top_generation; unsigned m_max_top_generation; euf::enode* m_nodes[0]; - binding(app* pat, unsigned max_generation, unsigned min_top, unsigned max_top): + binding(clause& c, app* pat, unsigned max_generation, unsigned min_top, unsigned max_top): + c(&c), m_pattern(pat), m_max_generation(max_generation), m_min_top_generation(min_top), @@ -53,29 +80,49 @@ namespace q { euf::enode* operator[](unsigned i) const { return m_nodes[i]; } std::ostream& display(euf::solver& ctx, unsigned num_nodes, std::ostream& out) const; + + unsigned size() const { return c->num_decls(); } + + quantifier* q() const { return c->m_q; } + + bool eq(binding const& other) const { + if (q() != other.q()) + return false; + for (unsigned i = size(); i-- > 0; ) + if ((*this)[i] != other[i]) + return false; + return true; + } }; - struct clause { - unsigned m_index; - vector m_lits; - quantifier_ref m_q; - unsigned m_watch = 0; - sat::literal m_literal = sat::null_literal; - q::quantifier_stat* m_stat = nullptr; - binding* m_bindings = nullptr; + struct binding_khasher { + unsigned operator()(binding const* f) const { return f->q()->get_id(); } + }; + struct binding_chasher { + unsigned operator()(binding const* f, unsigned idx) const { return f->m_nodes[idx]->hash(); } + }; - clause(ast_manager& m, unsigned idx): m_index(idx), m_q(m) {} + struct binding_hash_proc { + unsigned operator()(binding const* f) const { + return get_composite_hash(const_cast(f), f->size()); + } + }; - std::ostream& display(euf::solver& ctx, std::ostream& out) const; - lit const& operator[](unsigned i) const { return m_lits[i]; } - lit& operator[](unsigned i) { return m_lits[i]; } - unsigned size() const { return m_lits.size(); } - unsigned num_decls() const { return m_q->get_num_decls(); } - unsigned index() const { return m_index; } - quantifier* q() const { return m_q; } + struct binding_eq_proc { + bool operator()(binding const* a, binding const* b) const { return a->eq(*b); } }; + typedef ptr_hashtable bindings; + + inline std::ostream& operator<<(std::ostream& out, binding const& f) { + out << "[fp " << f.q()->get_id() << ":"; + for (unsigned i = 0; i < f.size(); ++i) + out << " " << f[i]->get_expr_id(); + return out << "]"; + } + + struct justification { expr* m_lhs, *m_rhs; bool m_sign; diff --git a/src/sat/smt/q_ematch.cpp b/src/sat/smt/q_ematch.cpp index 56ab260932f..1cfc1f67852 100644 --- a/src/sat/smt/q_ematch.cpp +++ b/src/sat/smt/q_ematch.cpp @@ -219,18 +219,48 @@ namespace q { }; + + binding* ematch::tmp_binding(clause& c, app* pat, euf::enode* const* b) { + if (c.num_decls() > m_tmp_binding_capacity) { + void* mem = memory::allocate(sizeof(binding) + c.num_decls() * sizeof(euf::enode*)); + m_tmp_binding = new (mem) binding(c, pat, 0, 0, 0); + m_tmp_binding_capacity = c.num_decls(); + } + + for (unsigned i = c.num_decls(); i-- > 0; ) + m_tmp_binding->m_nodes[i] = b[i]; + m_tmp_binding->m_pattern = pat; + m_tmp_binding->c = &c; + + return m_tmp_binding.get(); + } + binding* ematch::alloc_binding(clause& c, app* pat, euf::enode* const* _binding, unsigned max_generation, unsigned min_top, unsigned max_top) { + binding* b = tmp_binding(c, pat, _binding); + + if (m_bindings.contains(b)) + return nullptr; + + for (unsigned i = c.num_decls(); i-- > 0; ) + b->m_nodes[i] = b->m_nodes[i]->get_root(); + + if (m_bindings.contains(b)) + return nullptr; + unsigned n = c.num_decls(); unsigned sz = sizeof(binding) + sizeof(euf::enode* const*) * n; void* mem = ctx.get_region().allocate(sz); - binding* b = new (mem) binding(pat, max_generation, min_top, max_top); + b = new (mem) binding(c, pat, max_generation, min_top, max_top); b->init(b); for (unsigned i = 0; i < n; ++i) b->m_nodes[i] = _binding[i]; + + m_bindings.insert(b); + ctx.push(insert_map(m_bindings, b)); return b; } - euf::enode* const* ematch::alloc_binding(clause& c, euf::enode* const* _binding) { + euf::enode* const* ematch::alloc_nodes(clause& c, euf::enode* const* _binding) { unsigned sz = sizeof(euf::enode* const*) * c.num_decls(); euf::enode** binding = (euf::enode**)ctx.get_region().allocate(sz); for (unsigned i = 0; i < c.num_decls(); ++i) @@ -244,8 +274,7 @@ namespace q { clause& c = *m_clauses[idx]; bool new_propagation = false; binding* b = alloc_binding(c, pat, _binding, max_generation, min_gen, max_gen); - fingerprint* f = add_fingerprint(c, *b, max_generation); - if (!f) + if (!b) return; if (propagate(false, _binding, max_generation, c, new_propagation)) @@ -276,7 +305,7 @@ namespace q { if (ev == l_undef && max_generation > m_generation_propagation_threshold) return false; if (!is_owned) - binding = alloc_binding(c, binding); + binding = alloc_nodes(c, binding); auto j_idx = mk_justification(idx, c, binding); @@ -312,17 +341,14 @@ namespace q { return true; } - void ematch::instantiate(binding& b, clause& c) { + void ematch::instantiate(binding& b) { if (m_stats.m_num_instantiations > ctx.get_config().m_qi_max_instances) return; unsigned max_generation = b.m_max_generation; - max_generation = std::max(max_generation, c.m_stat->get_generation()); - c.m_stat->update_max_generation(max_generation); - fingerprint * f = add_fingerprint(c, b, max_generation); - if (!f) - return; - m_inst_queue.insert(f); - m_stats.m_num_instantiations++; + max_generation = std::max(max_generation, b.c->m_stat->get_generation()); + b.c->m_stat->update_max_generation(max_generation); + m_stats.m_num_instantiations++; + m_inst_queue.insert(&b); } void ematch::add_instantiation(clause& c, binding& b, sat::literal lit) { @@ -330,35 +356,6 @@ namespace q { ctx.propagate(lit, mk_justification(UINT_MAX, c, b.nodes())); } - void ematch::set_tmp_binding(fingerprint& fp) { - binding& b = *fp.b; - clause& c = *fp.c; - if (c.num_decls() > m_tmp_binding_capacity) { - void* mem = memory::allocate(sizeof(binding) + c.num_decls()*sizeof(euf::enode*)); - m_tmp_binding = new (mem) binding(b.m_pattern, 0, 0, 0); - m_tmp_binding_capacity = c.num_decls(); - } - - fp.b = m_tmp_binding.get(); - for (unsigned i = c.num_decls(); i-- > 0; ) - fp.b->m_nodes[i] = b[i]; - } - - fingerprint* ematch::add_fingerprint(clause& c, binding& b, unsigned max_generation) { - fingerprint fp(c, b, max_generation); - if (m_fingerprints.contains(&fp)) - return nullptr; - set_tmp_binding(fp); - for (unsigned i = c.num_decls(); i-- > 0; ) - fp.b->m_nodes[i] = fp.b->m_nodes[i]->get_root(); - if (m_fingerprints.contains(&fp)) - return nullptr; - fingerprint* f = new (ctx.get_region()) fingerprint(c, b, max_generation); - m_fingerprints.insert(f); - ctx.push(insert_map(m_fingerprints, f)); - return f; - } - sat::literal ematch::instantiate(clause& c, euf::enode* const* binding, lit const& l) { expr_ref_vector _binding(m); for (unsigned i = 0; i < c.num_decls(); ++i) @@ -552,6 +549,7 @@ namespace q { bool ematch::unit_propagate() { + return false; return ctx.get_config().m_ematching && propagate(false); } @@ -569,12 +567,13 @@ namespace q { if (!b) continue; - do { + do { if (propagate(true, b->m_nodes, b->m_max_generation, c, propagated)) to_remove.push_back(b); else if (flush) { - instantiate(*b, c); + instantiate(*b); to_remove.push_back(b); + propagated = true; } b = b->next(); } @@ -600,21 +599,21 @@ namespace q { TRACE("q", m_mam->display(tout);); if (propagate(false)) return true; - if (m_lazy_mam) { + if (m_lazy_mam) m_lazy_mam->propagate(); - if (propagate(false)) - return true; - } - unsigned idx = 0; - for (clause* c : m_clauses) { - if (c->m_bindings) - insert_clause_in_queue(idx); - idx++; - } + if (propagate(false)) + return true; + for (unsigned i = 0; i < m_clauses.size(); ++i) + if (m_clauses[i]->m_bindings) + insert_clause_in_queue(i); if (propagate(true)) return true; if (m_inst_queue.lazy_propagate()) return true; + for (unsigned i = 0; i < m_clauses.size(); ++i) + if (m_clauses[i]->m_bindings) + std::cout << "missed propagation " << i << "\n"; + TRACE("q", tout << "no more propagation\n";); return false; } diff --git a/src/sat/smt/q_ematch.h b/src/sat/smt/q_ematch.h index fbedbd65a53..bd79511a840 100644 --- a/src/sat/smt/q_ematch.h +++ b/src/sat/smt/q_ematch.h @@ -23,7 +23,6 @@ Module Name: #include "sat/smt/sat_th.h" #include "sat/smt/q_mam.h" #include "sat/smt/q_clause.h" -#include "sat/smt/q_fingerprint.h" #include "sat/smt/q_queue.h" #include "sat/smt/q_eval.h" @@ -69,7 +68,7 @@ namespace q { ast_manager& m; eval m_eval; quantifier_stat_gen m_qstat_gen; - fingerprints m_fingerprints; + bindings m_bindings; scoped_ptr m_tmp_binding; unsigned m_tmp_binding_capacity = 0; queue m_inst_queue; @@ -90,16 +89,16 @@ namespace q { unsigned_vector m_clause_queue; euf::enode_pair_vector m_evidence; - euf::enode* const* alloc_binding(clause& c, euf::enode* const* _binding); - binding* alloc_binding(clause& c, app* pat, euf::enode* const* _bidning, unsigned max_generation, unsigned min_top, unsigned max_top); - void add_binding(clause& c, app* pat, euf::enode* const* _binding, unsigned max_generation, unsigned min_top, unsigned max_top); + euf::enode* const* alloc_nodes(clause& c, euf::enode* const* _binding); + binding* tmp_binding(clause& c, app* pat, euf::enode* const* _binding); + binding* alloc_binding(clause& c, app* pat, euf::enode* const* _binding, unsigned max_generation, unsigned min_top, unsigned max_top); sat::ext_justification_idx mk_justification(unsigned idx, clause& c, euf::enode* const* b); void ensure_ground_enodes(expr* e); void ensure_ground_enodes(clause const& c); - void instantiate(binding& b, clause& c); + void instantiate(binding& b); sat::literal instantiate(clause& c, euf::enode* const* binding, lit const& l); // register as callback into egraph. @@ -115,9 +114,6 @@ namespace q { clause* clausify(quantifier* q); lit clausify_literal(expr* arg); - fingerprint* add_fingerprint(clause& c, binding& b, unsigned max_generation); - void set_tmp_binding(fingerprint& fp); - bool flush_prop_queue(); void propagate(bool is_conflict, unsigned idx, sat::ext_justification_idx j_idx); diff --git a/src/sat/smt/q_fingerprint.h b/src/sat/smt/q_fingerprint.h deleted file mode 100644 index 99ad602b9ce..00000000000 --- a/src/sat/smt/q_fingerprint.h +++ /dev/null @@ -1,77 +0,0 @@ -/*++ -Copyright (c) 2020 Microsoft Corporation - -Module Name: - - q_fingerprint.h - -Abstract: - - Fingerprint summary of a quantifier instantiation - -Author: - - Nikolaj Bjorner (nbjorner) 2021-01-24 - ---*/ -#pragma once - -#include "util/hashtable.h" -#include "ast/ast.h" -#include "ast/quantifier_stat.h" -#include "ast/euf/euf_enode.h" -#include "sat/smt/q_clause.h" - - -namespace q { - - struct fingerprint { - clause* c; - binding* b; - unsigned m_max_generation; - - unsigned size() const { return c->num_decls(); } - euf::enode* const* nodes() const { return b->nodes(); } - quantifier* q() const { return c->m_q; } - - fingerprint(clause& _c, binding& _b, unsigned mg) : - c(&_c), b(&_b), m_max_generation(mg) {} - - bool eq(fingerprint const& other) const { - if (c->m_q != other.c->m_q) - return false; - for (unsigned i = size(); i--> 0; ) - if ((*b)[i] != (*other.b)[i]) - return false; - return true; - } - }; - - struct fingerprint_khasher { - unsigned operator()(fingerprint const * f) const { return f->c->m_q->get_id(); } - }; - - struct fingerprint_chasher { - unsigned operator()(fingerprint const * f, unsigned idx) const { return f->b->m_nodes[idx]->hash(); } - }; - - struct fingerprint_hash_proc { - unsigned operator()(fingerprint const * f) const { - return get_composite_hash(const_cast(f), f->size()); - } - }; - - struct fingerprint_eq_proc { - bool operator()(fingerprint const* a, fingerprint const* b) const { return a->eq(*b); } - }; - - typedef ptr_hashtable fingerprints; - - inline std::ostream& operator<<(std::ostream& out, fingerprint const& f) { - out << "[fp " << f.q()->get_id() << ":"; - for (unsigned i = 0; i < f.size(); ++i) - out << " " << (*f.b)[i]->get_expr_id(); - return out << "]"; - } - -} diff --git a/src/sat/smt/q_queue.cpp b/src/sat/smt/q_queue.cpp index 247451fb4d7..2e8db482f90 100644 --- a/src/sat/smt/q_queue.cpp +++ b/src/sat/smt/q_queue.cpp @@ -86,13 +86,13 @@ namespace q { m_parser.add_var("cs_factor"); } - void queue::set_values(fingerprint& f, float cost) { + void queue::set_values(binding& f, float cost) { quantifier_stat * stat = f.c->m_stat; quantifier* q = f.q(); - app* pat = f.b->m_pattern; + app* pat = f.m_pattern; m_vals[COST] = cost; - m_vals[MIN_TOP_GENERATION] = static_cast(f.b->m_min_top_generation); - m_vals[MAX_TOP_GENERATION] = static_cast(f.b->m_max_top_generation); + m_vals[MIN_TOP_GENERATION] = static_cast(f.m_min_top_generation); + m_vals[MAX_TOP_GENERATION] = static_cast(f.m_max_top_generation); m_vals[INSTANCES] = static_cast(stat->get_num_instances_curr_branch()); m_vals[SIZE] = static_cast(stat->get_size()); m_vals[DEPTH] = static_cast(stat->get_depth()); @@ -108,14 +108,14 @@ namespace q { TRACE("q_detail", for (unsigned i = 0; i < m_vals.size(); i++) { tout << m_vals[i] << " "; } tout << "\n";); } - float queue::get_cost(fingerprint& f) { + float queue::get_cost(binding& f) { set_values(f, 0); float r = m_evaluator(m_cost_function, m_vals.size(), m_vals.data()); f.c->m_stat->update_max_cost(r); return r; } - unsigned queue::get_new_gen(fingerprint& f, float cost) { + unsigned queue::get_new_gen(binding& f, float cost) { set_values(f, cost); float r = m_evaluator(m_new_gen_function, m_vals.size(), m_vals.data()); return std::max(f.m_max_generation + 1, static_cast(r)); @@ -129,7 +129,7 @@ namespace q { } }; - void queue::insert(fingerprint* f) { + void queue::insert(binding* f) { float cost = get_cost(*f); if (m_new_entries.empty()) ctx.push(reset_new_entries(m_new_entries)); @@ -137,7 +137,7 @@ namespace q { } void queue::instantiate(entry& ent) { - fingerprint & f = *ent.m_qb; + binding& f = *ent.m_qb; quantifier * q = f.q(); unsigned num_bindings = f.size(); quantifier_stat * stat = f.c->m_stat; @@ -151,7 +151,7 @@ namespace q { auto* ebindings = m_subst(q, num_bindings); for (unsigned i = 0; i < num_bindings; ++i) - ebindings[i] = f.nodes()[i]->get_expr(); + ebindings[i] = f[i]->get_expr(); expr_ref instance = m_subst(); ctx.get_rewriter()(instance); if (m.is_true(instance)) { @@ -164,7 +164,7 @@ namespace q { euf::solver::scoped_generation _sg(ctx, gen); sat::literal result_l = ctx.mk_literal(instance); - em.add_instantiation(*f.c, *f.b, result_l); + em.add_instantiation(*f.c, f, result_l); } bool queue::propagate() { @@ -178,7 +178,7 @@ namespace q { if (0 == since_last_check && ctx.resource_limits_exceeded()) break; - fingerprint& f = *curr.m_qb; + binding& f = *curr.m_qb; if (curr.m_cost <= m_eager_cost_threshold) instantiate(curr); diff --git a/src/sat/smt/q_queue.h b/src/sat/smt/q_queue.h index c23cb0377f5..3750ee31baf 100644 --- a/src/sat/smt/q_queue.h +++ b/src/sat/smt/q_queue.h @@ -20,7 +20,7 @@ Module Name: #include "ast/cost_evaluator.h" #include "ast/rewriter/cached_var_subst.h" #include "parsers/util/cost_parser.h" -#include "sat/smt/q_fingerprint.h" +#include "sat/smt/q_clause.h" @@ -51,12 +51,12 @@ namespace q { cost_evaluator m_evaluator; cached_var_subst m_subst; svector m_vals; - double m_eager_cost_threshold { 0 }; + double m_eager_cost_threshold = 0; struct entry { - fingerprint * m_qb; + binding * m_qb; float m_cost; - bool m_instantiated{ false }; - entry(fingerprint * f, float c):m_qb(f), m_cost(c) {} + bool m_instantiated = false; + entry(binding * f, float c):m_qb(f), m_cost(c) {} }; struct reset_new_entries; struct reset_instantiated; @@ -64,18 +64,18 @@ namespace q { svector m_new_entries; svector m_delayed_entries; - float get_cost(fingerprint& f); - void set_values(fingerprint& f, float cost); + float get_cost(binding& f); + void set_values(binding& f, float cost); void init_parser_vars(); void setup(); - unsigned get_new_gen(fingerprint& f, float cost); + unsigned get_new_gen(binding& f, float cost); void instantiate(entry& e); public: queue(ematch& em, euf::solver& ctx); - void insert(fingerprint* f); + void insert(binding* f); bool propagate();