Skip to content

Commit

Permalink
redo bindings/fingerprints
Browse files Browse the repository at this point in the history
Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>
  • Loading branch information
NikolajBjorner committed Oct 5, 2021
1 parent 281fb67 commit 33f4e65
Show file tree
Hide file tree
Showing 6 changed files with 142 additions and 177 deletions.
81 changes: 64 additions & 17 deletions src/sat/smt/q_clause.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ Module Name:
#include "ast/euf/euf_enode.h"
#include "sat/smt/euf_solver.h"


namespace q {

struct lit {
Expand All @@ -35,14 +36,40 @@ namespace q {
std::ostream& display(std::ostream& out) const;
};

struct binding;

struct clause {
unsigned m_index;
vector<lit> m_lits;
quantifier_ref m_q;
unsigned m_watch = 0;
sat::literal m_literal = sat::null_literal;
q::quantifier_stat* m_stat = nullptr;
binding* m_bindings = nullptr;


clause(ast_manager& m, unsigned idx) : m_index(idx), m_q(m) {}

std::ostream& display(euf::solver& ctx, std::ostream& out) const;
lit const& operator[](unsigned i) const { return m_lits[i]; }
lit& operator[](unsigned i) { return m_lits[i]; }
unsigned size() const { return m_lits.size(); }
unsigned num_decls() const { return m_q->get_num_decls(); }
unsigned index() const { return m_index; }
quantifier* q() const { return m_q; }
};


struct binding : public dll_base<binding> {
clause* c;
app* m_pattern;
unsigned m_max_generation;
unsigned m_min_top_generation;
unsigned m_max_top_generation;
euf::enode* m_nodes[0];

binding(app* pat, unsigned max_generation, unsigned min_top, unsigned max_top):
binding(clause& c, app* pat, unsigned max_generation, unsigned min_top, unsigned max_top):
c(&c),
m_pattern(pat),
m_max_generation(max_generation),
m_min_top_generation(min_top),
Expand All @@ -53,29 +80,49 @@ namespace q {
euf::enode* operator[](unsigned i) const { return m_nodes[i]; }

std::ostream& display(euf::solver& ctx, unsigned num_nodes, std::ostream& out) const;

unsigned size() const { return c->num_decls(); }

quantifier* q() const { return c->m_q; }

bool eq(binding const& other) const {
if (q() != other.q())
return false;
for (unsigned i = size(); i-- > 0; )
if ((*this)[i] != other[i])
return false;
return true;
}
};

struct clause {
unsigned m_index;
vector<lit> m_lits;
quantifier_ref m_q;
unsigned m_watch = 0;
sat::literal m_literal = sat::null_literal;
q::quantifier_stat* m_stat = nullptr;
binding* m_bindings = nullptr;
struct binding_khasher {
unsigned operator()(binding const* f) const { return f->q()->get_id(); }
};

struct binding_chasher {
unsigned operator()(binding const* f, unsigned idx) const { return f->m_nodes[idx]->hash(); }
};

clause(ast_manager& m, unsigned idx): m_index(idx), m_q(m) {}
struct binding_hash_proc {
unsigned operator()(binding const* f) const {
return get_composite_hash<binding*, binding_khasher, binding_chasher>(const_cast<binding*>(f), f->size());
}
};

std::ostream& display(euf::solver& ctx, std::ostream& out) const;
lit const& operator[](unsigned i) const { return m_lits[i]; }
lit& operator[](unsigned i) { return m_lits[i]; }
unsigned size() const { return m_lits.size(); }
unsigned num_decls() const { return m_q->get_num_decls(); }
unsigned index() const { return m_index; }
quantifier* q() const { return m_q; }
struct binding_eq_proc {
bool operator()(binding const* a, binding const* b) const { return a->eq(*b); }
};

typedef ptr_hashtable<binding, binding_hash_proc, binding_eq_proc> bindings;

inline std::ostream& operator<<(std::ostream& out, binding const& f) {
out << "[fp " << f.q()->get_id() << ":";
for (unsigned i = 0; i < f.size(); ++i)
out << " " << f[i]->get_expr_id();
return out << "]";
}


struct justification {
expr* m_lhs, *m_rhs;
bool m_sign;
Expand Down
107 changes: 53 additions & 54 deletions src/sat/smt/q_ematch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -219,18 +219,48 @@ namespace q {
};



binding* ematch::tmp_binding(clause& c, app* pat, euf::enode* const* b) {
if (c.num_decls() > m_tmp_binding_capacity) {
void* mem = memory::allocate(sizeof(binding) + c.num_decls() * sizeof(euf::enode*));
m_tmp_binding = new (mem) binding(c, pat, 0, 0, 0);
m_tmp_binding_capacity = c.num_decls();
}

for (unsigned i = c.num_decls(); i-- > 0; )
m_tmp_binding->m_nodes[i] = b[i];
m_tmp_binding->m_pattern = pat;
m_tmp_binding->c = &c;

return m_tmp_binding.get();
}

binding* ematch::alloc_binding(clause& c, app* pat, euf::enode* const* _binding, unsigned max_generation, unsigned min_top, unsigned max_top) {
binding* b = tmp_binding(c, pat, _binding);

if (m_bindings.contains(b))
return nullptr;

for (unsigned i = c.num_decls(); i-- > 0; )
b->m_nodes[i] = b->m_nodes[i]->get_root();

if (m_bindings.contains(b))
return nullptr;

unsigned n = c.num_decls();
unsigned sz = sizeof(binding) + sizeof(euf::enode* const*) * n;
void* mem = ctx.get_region().allocate(sz);
binding* b = new (mem) binding(pat, max_generation, min_top, max_top);
b = new (mem) binding(c, pat, max_generation, min_top, max_top);
b->init(b);
for (unsigned i = 0; i < n; ++i)
b->m_nodes[i] = _binding[i];

m_bindings.insert(b);
ctx.push(insert_map<bindings, binding*>(m_bindings, b));
return b;
}

euf::enode* const* ematch::alloc_binding(clause& c, euf::enode* const* _binding) {
euf::enode* const* ematch::alloc_nodes(clause& c, euf::enode* const* _binding) {
unsigned sz = sizeof(euf::enode* const*) * c.num_decls();
euf::enode** binding = (euf::enode**)ctx.get_region().allocate(sz);
for (unsigned i = 0; i < c.num_decls(); ++i)
Expand All @@ -244,8 +274,7 @@ namespace q {
clause& c = *m_clauses[idx];
bool new_propagation = false;
binding* b = alloc_binding(c, pat, _binding, max_generation, min_gen, max_gen);
fingerprint* f = add_fingerprint(c, *b, max_generation);
if (!f)
if (!b)
return;

if (propagate(false, _binding, max_generation, c, new_propagation))
Expand Down Expand Up @@ -276,7 +305,7 @@ namespace q {
if (ev == l_undef && max_generation > m_generation_propagation_threshold)
return false;
if (!is_owned)
binding = alloc_binding(c, binding);
binding = alloc_nodes(c, binding);

auto j_idx = mk_justification(idx, c, binding);

Expand Down Expand Up @@ -312,53 +341,21 @@ namespace q {
return true;
}

void ematch::instantiate(binding& b, clause& c) {
void ematch::instantiate(binding& b) {
if (m_stats.m_num_instantiations > ctx.get_config().m_qi_max_instances)
return;
unsigned max_generation = b.m_max_generation;
max_generation = std::max(max_generation, c.m_stat->get_generation());
c.m_stat->update_max_generation(max_generation);
fingerprint * f = add_fingerprint(c, b, max_generation);
if (!f)
return;
m_inst_queue.insert(f);
m_stats.m_num_instantiations++;
max_generation = std::max(max_generation, b.c->m_stat->get_generation());
b.c->m_stat->update_max_generation(max_generation);
m_stats.m_num_instantiations++;
m_inst_queue.insert(&b);
}

void ematch::add_instantiation(clause& c, binding& b, sat::literal lit) {
m_evidence.reset();
ctx.propagate(lit, mk_justification(UINT_MAX, c, b.nodes()));
}

void ematch::set_tmp_binding(fingerprint& fp) {
binding& b = *fp.b;
clause& c = *fp.c;
if (c.num_decls() > m_tmp_binding_capacity) {
void* mem = memory::allocate(sizeof(binding) + c.num_decls()*sizeof(euf::enode*));
m_tmp_binding = new (mem) binding(b.m_pattern, 0, 0, 0);
m_tmp_binding_capacity = c.num_decls();
}

fp.b = m_tmp_binding.get();
for (unsigned i = c.num_decls(); i-- > 0; )
fp.b->m_nodes[i] = b[i];
}

fingerprint* ematch::add_fingerprint(clause& c, binding& b, unsigned max_generation) {
fingerprint fp(c, b, max_generation);
if (m_fingerprints.contains(&fp))
return nullptr;
set_tmp_binding(fp);
for (unsigned i = c.num_decls(); i-- > 0; )
fp.b->m_nodes[i] = fp.b->m_nodes[i]->get_root();
if (m_fingerprints.contains(&fp))
return nullptr;
fingerprint* f = new (ctx.get_region()) fingerprint(c, b, max_generation);
m_fingerprints.insert(f);
ctx.push(insert_map<fingerprints, fingerprint*>(m_fingerprints, f));
return f;
}

sat::literal ematch::instantiate(clause& c, euf::enode* const* binding, lit const& l) {
expr_ref_vector _binding(m);
for (unsigned i = 0; i < c.num_decls(); ++i)
Expand Down Expand Up @@ -552,6 +549,7 @@ namespace q {


bool ematch::unit_propagate() {
return false;
return ctx.get_config().m_ematching && propagate(false);
}

Expand All @@ -569,12 +567,13 @@ namespace q {
if (!b)
continue;

do {
do {
if (propagate(true, b->m_nodes, b->m_max_generation, c, propagated))
to_remove.push_back(b);
else if (flush) {
instantiate(*b, c);
instantiate(*b);
to_remove.push_back(b);
propagated = true;
}
b = b->next();
}
Expand All @@ -600,21 +599,21 @@ namespace q {
TRACE("q", m_mam->display(tout););
if (propagate(false))
return true;
if (m_lazy_mam) {
if (m_lazy_mam)
m_lazy_mam->propagate();
if (propagate(false))
return true;
}
unsigned idx = 0;
for (clause* c : m_clauses) {
if (c->m_bindings)
insert_clause_in_queue(idx);
idx++;
}
if (propagate(false))
return true;
for (unsigned i = 0; i < m_clauses.size(); ++i)
if (m_clauses[i]->m_bindings)
insert_clause_in_queue(i);
if (propagate(true))
return true;
if (m_inst_queue.lazy_propagate())
return true;
for (unsigned i = 0; i < m_clauses.size(); ++i)
if (m_clauses[i]->m_bindings)
std::cout << "missed propagation " << i << "\n";
TRACE("q", tout << "no more propagation\n";);
return false;
}

Expand Down
14 changes: 5 additions & 9 deletions src/sat/smt/q_ematch.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ Module Name:
#include "sat/smt/sat_th.h"
#include "sat/smt/q_mam.h"
#include "sat/smt/q_clause.h"
#include "sat/smt/q_fingerprint.h"
#include "sat/smt/q_queue.h"
#include "sat/smt/q_eval.h"

Expand Down Expand Up @@ -69,7 +68,7 @@ namespace q {
ast_manager& m;
eval m_eval;
quantifier_stat_gen m_qstat_gen;
fingerprints m_fingerprints;
bindings m_bindings;
scoped_ptr<binding> m_tmp_binding;
unsigned m_tmp_binding_capacity = 0;
queue m_inst_queue;
Expand All @@ -90,16 +89,16 @@ namespace q {
unsigned_vector m_clause_queue;
euf::enode_pair_vector m_evidence;

euf::enode* const* alloc_binding(clause& c, euf::enode* const* _binding);
binding* alloc_binding(clause& c, app* pat, euf::enode* const* _bidning, unsigned max_generation, unsigned min_top, unsigned max_top);
void add_binding(clause& c, app* pat, euf::enode* const* _binding, unsigned max_generation, unsigned min_top, unsigned max_top);
euf::enode* const* alloc_nodes(clause& c, euf::enode* const* _binding);
binding* tmp_binding(clause& c, app* pat, euf::enode* const* _binding);
binding* alloc_binding(clause& c, app* pat, euf::enode* const* _binding, unsigned max_generation, unsigned min_top, unsigned max_top);

sat::ext_justification_idx mk_justification(unsigned idx, clause& c, euf::enode* const* b);

void ensure_ground_enodes(expr* e);
void ensure_ground_enodes(clause const& c);

void instantiate(binding& b, clause& c);
void instantiate(binding& b);
sat::literal instantiate(clause& c, euf::enode* const* binding, lit const& l);

// register as callback into egraph.
Expand All @@ -115,9 +114,6 @@ namespace q {
clause* clausify(quantifier* q);
lit clausify_literal(expr* arg);

fingerprint* add_fingerprint(clause& c, binding& b, unsigned max_generation);
void set_tmp_binding(fingerprint& fp);

bool flush_prop_queue();
void propagate(bool is_conflict, unsigned idx, sat::ext_justification_idx j_idx);

Expand Down
Loading

0 comments on commit 33f4e65

Please sign in to comment.