diff --git a/src/sat/sat_solver.h b/src/sat/sat_solver.h index 06b843b4442..a4f42fdb626 100644 --- a/src/sat/sat_solver.h +++ b/src/sat/sat_solver.h @@ -454,7 +454,7 @@ namespace sat { void display_lookahead_scores(std::ostream& out); - stats const& stats() const { return m_stats; } + stats const& get_stats() const { return m_stats; } protected: diff --git a/src/sat/sat_solver/inc_sat_solver.cpp b/src/sat/sat_solver/inc_sat_solver.cpp index ef020d3deac..eabb657ddae 100644 --- a/src/sat/sat_solver/inc_sat_solver.cpp +++ b/src/sat/sat_solver/inc_sat_solver.cpp @@ -40,6 +40,7 @@ Module Name: #include "model/model_evaluator.h" #include "sat/sat_solver.h" #include "sat/sat_params.hpp" +#include "sat/smt/euf_solver.h" #include "sat/tactic/goal2sat.h" #include "sat/tactic/sat_tactic.h" #include "sat/sat_simplifier_params.hpp" @@ -80,6 +81,8 @@ class inc_sat_solver : public solver { typedef obj_map dep2asm_t; + dep2asm_t m_dep2asm; + bool is_internalized() const { return m_fmls_head == m_fmls.size(); } public: inc_sat_solver(ast_manager& m, params_ref const& p, bool incremental_mode): @@ -120,7 +123,15 @@ class inc_sat_solver : public solver { ast_translation tr(m, dst_m); m_solver.pop_to_base_level(); inc_sat_solver* result = alloc(inc_sat_solver, dst_m, p, is_incremental()); - result->m_solver.copy(m_solver); + auto* ext = dynamic_cast(m_solver.get_extension()); + if (ext) { + auto& si = result->m_goal2sat.si(dst_m, m_params, result->m_solver, result->m_map, result->m_dep2asm, is_incremental()); + euf::solver::scoped_set_translate st(*ext, tr, result->m_map, si); + result->m_solver.copy(m_solver); + } + else { + result->m_solver.copy(m_solver); + } result->m_fmls_head = m_fmls_head; for (expr* f : m_fmls) result->m_fmls.push_back(tr(f)); for (expr* f : m_asmsf) result->m_asmsf.push_back(tr(f)); @@ -145,7 +156,7 @@ class inc_sat_solver : public solver { } init_preprocess(); m_solver.pop_to_base_level(); - dep2asm_t dep2asm; + m_dep2asm.reset(); expr_ref_vector asms(m); for (unsigned i = 0; i < sz; ++i) { expr_ref a(m.mk_fresh_const("s", m.mk_bool_sort()), m); @@ -154,7 +165,7 @@ class inc_sat_solver : public solver { asms.push_back(a); } VERIFY(l_true == internalize_formulas()); - VERIFY(l_true == internalize_assumptions(sz, asms.c_ptr(), dep2asm)); + VERIFY(l_true == internalize_assumptions(sz, asms.c_ptr())); svector nweights; for (unsigned i = 0; i < m_asms.size(); ++i) { nweights.push_back((unsigned) m_weights[i]); @@ -190,10 +201,10 @@ class inc_sat_solver : public solver { } TRACE("sat", tout << _assumptions << "\n";); - dep2asm_t dep2asm; + m_dep2asm.reset(); lbool r = internalize_formulas(); if (r != l_true) return r; - r = internalize_assumptions(sz, _assumptions.c_ptr(), dep2asm); + r = internalize_assumptions(sz, _assumptions.c_ptr()); if (r != l_true) return r; init_reason_unknown(); @@ -216,13 +227,13 @@ class inc_sat_solver : public solver { r = l_undef; } else if (sz > 0) { - check_assumptions(dep2asm); + check_assumptions(); } break; case l_false: // TBD: expr_dependency core is not accounted for. if (!m_asms.empty()) { - extract_core(dep2asm, asm2fml); + extract_core(asm2fml); } break; default: @@ -438,19 +449,19 @@ class inc_sat_solver : public solver { sat::literal_vector asms; sat::bool_var_vector bvars; vector lconseq; - dep2asm_t dep2asm; + m_dep2asm.reset(); obj_map asm2fml; m_solver.pop_to_base_level(); lbool r = internalize_formulas(); if (r != l_true) return r; r = internalize_vars(vars, bvars); if (r != l_true) return r; - r = internalize_assumptions(assumptions.size(), assumptions.c_ptr(), dep2asm); + r = internalize_assumptions(assumptions.size(), assumptions.c_ptr()); if (r != l_true) return r; r = m_solver.get_consequences(m_asms, bvars, lconseq); if (r == l_false) { if (!m_asms.empty()) { - extract_core(dep2asm, asm2fml); + extract_core(asm2fml); } return r; } @@ -465,10 +476,10 @@ class inc_sat_solver : public solver { // extract original fixed variables u_map asm2dep; - extract_asm2dep(dep2asm, asm2dep); + extract_asm2dep(asm2dep); for (auto v : vars) { expr_ref cons(m); - if (extract_fixed_variable(dep2asm, asm2dep, v, bool_var2conseq, lconseq, cons)) { + if (extract_fixed_variable(asm2dep, v, bool_var2conseq, lconseq, cons)) { conseq.push_back(cons); } } @@ -615,7 +626,7 @@ class inc_sat_solver : public solver { private: - lbool internalize_goal(goal_ref& g, dep2asm_t& dep2asm) { + lbool internalize_goal(goal_ref& g) { m_solver.pop_to_base_level(); if (m_solver.inconsistent()) return l_false; @@ -662,7 +673,7 @@ class inc_sat_solver : public solver { // ensure that if goal is already internalized, then import mc from m_solver. - m_goal2sat(*g, m_params, m_solver, m_map, dep2asm, is_incremental()); + m_goal2sat(*g, m_params, m_solver, m_map, m_dep2asm, is_incremental()); m_goal2sat.get_interpreted_atoms(atoms); if (!m_sat_mc) m_sat_mc = alloc(sat2goal::mc, m); m_sat_mc->flush_smc(m_solver, m_map); @@ -678,7 +689,7 @@ class inc_sat_solver : public solver { return l_true; } - lbool internalize_assumptions(unsigned sz, expr* const* asms, dep2asm_t& dep2asm) { + lbool internalize_assumptions(unsigned sz, expr* const* asms) { if (sz == 0 && get_num_assumptions() == 0) { m_asms.shrink(0); return l_true; @@ -690,9 +701,9 @@ class inc_sat_solver : public solver { for (unsigned i = 0; i < get_num_assumptions(); ++i) { g->assert_expr(get_assumption(i), m.mk_leaf(get_assumption(i))); } - lbool res = internalize_goal(g, dep2asm); + lbool res = internalize_goal(g); if (res == l_true) { - extract_assumptions(sz, asms, dep2asm); + extract_assumptions(sz, asms); } return res; } @@ -741,7 +752,7 @@ class inc_sat_solver : public solver { return internalized; } - bool extract_fixed_variable(dep2asm_t& dep2asm, u_map& asm2dep, expr* v, u_map const& bool_var2conseq, vector const& lconseq, expr_ref& conseq) { + bool extract_fixed_variable(u_map& asm2dep, expr* v, u_map const& bool_var2conseq, vector const& lconseq, expr_ref& conseq) { sat::bool_var_vector bvars; if (!internalize_var(v, bvars)) { @@ -831,7 +842,7 @@ class inc_sat_solver : public solver { expr* fml = m_fmls.get(i); g->assert_expr(fml); } - lbool res = internalize_goal(g, dep2asm); + lbool res = internalize_goal(g); if (res != l_undef) { m_fmls_head = m_fmls.size(); } @@ -839,13 +850,13 @@ class inc_sat_solver : public solver { return res; } - void extract_assumptions(unsigned sz, expr* const* asms, dep2asm_t& dep2asm) { + void extract_assumptions(unsigned sz, expr* const* asms) { m_asms.reset(); unsigned j = 0; sat::literal lit; sat::literal_set seen; for (unsigned i = 0; i < sz; ++i) { - if (dep2asm.find(asms[i], lit)) { + if (m_dep2asm.find(asms[i], lit)) { SASSERT(lit.var() <= m_solver.num_vars()); if (!seen.contains(lit)) { m_asms.push_back(lit); @@ -858,7 +869,7 @@ class inc_sat_solver : public solver { } } for (unsigned i = 0; i < get_num_assumptions(); ++i) { - if (dep2asm.find(get_assumption(i), lit)) { + if (m_dep2asm.find(get_assumption(i), lit)) { SASSERT(lit.var() <= m_solver.num_vars()); if (!seen.contains(lit)) { m_asms.push_back(lit); @@ -875,15 +886,15 @@ class inc_sat_solver : public solver { SASSERT(dep2asm.size() == m_asms.size()); } - void extract_asm2dep(dep2asm_t const& dep2asm, u_map& asm2dep) { - for (auto const& kv : dep2asm) { + void extract_asm2dep(u_map& asm2dep) { + for (auto const& kv : m_dep2asm) { asm2dep.insert(kv.m_value.index(), kv.m_key); } } - void extract_core(dep2asm_t& dep2asm, obj_map const& asm2fml) { + void extract_core(obj_map const& asm2fml) { u_map asm2dep; - extract_asm2dep(dep2asm, asm2dep); + extract_asm2dep(asm2dep); sat::literal_vector const& core = m_solver.get_core(); TRACE("sat", for (auto kv : dep2asm) { @@ -908,9 +919,9 @@ class inc_sat_solver : public solver { } } - void check_assumptions(dep2asm_t& dep2asm) { + void check_assumptions() { sat::model const & ll_m = m_solver.get_model(); - for (auto const& kv : dep2asm) { + for (auto const& kv : m_dep2asm) { sat::literal lit = kv.m_value; if (sat::value_at(lit, ll_m) != l_true) { IF_VERBOSE(0, verbose_stream() << mk_pp(kv.m_key, m) << " does not evaluate to true\n"; diff --git a/src/sat/smt/euf_ackerman.cpp b/src/sat/smt/euf_ackerman.cpp index bc6f36e6d86..0f9b2638ca3 100644 --- a/src/sat/smt/euf_ackerman.cpp +++ b/src/sat/smt/euf_ackerman.cpp @@ -169,7 +169,7 @@ namespace euf { SASSERT(s.s().at_base_lvl()); auto* n = m_queue; inference* k = nullptr; - unsigned num_prop = static_cast(s.s().stats().m_conflict * s.m_config.m_dack_factor); + unsigned num_prop = static_cast(s.s().get_stats().m_conflict * s.m_config.m_dack_factor); num_prop = std::min(num_prop, m_table.size()); for (unsigned i = 0; i < num_prop; ++i, n = k) { k = n->m_next; diff --git a/src/sat/smt/euf_solver.cpp b/src/sat/smt/euf_solver.cpp index 7846a643678..1136e8d45df 100644 --- a/src/sat/smt/euf_solver.cpp +++ b/src/sat/smt/euf_solver.cpp @@ -264,7 +264,8 @@ namespace euf { solver* solver::copy_core() { ast_manager& to = m_translate ? m_translate->to() : m; atom2bool_var& a2b = m_translate_expr2var ? *m_translate_expr2var : m_expr2var; - auto* r = alloc(solver, to, a2b, si); + sat::sat_internalizer& to_si = m_translate_si ? *m_translate_si : si; + auto* r = alloc(solver, to, a2b, to_si); r->m_config = m_config; std::function copy_justification = [&](void* x) { return (void*)(r->base_ptr() + ((unsigned*)x - base_ptr())); }; r->m_egraph.copy_from(m_egraph, copy_justification); diff --git a/src/sat/smt/euf_solver.h b/src/sat/smt/euf_solver.h index 537a0b66563..0ab7b5c7061 100644 --- a/src/sat/smt/euf_solver.h +++ b/src/sat/smt/euf_solver.h @@ -62,6 +62,7 @@ namespace euf { sat::lookahead* m_lookahead { nullptr }; ast_translation* m_translate { nullptr }; atom2bool_var* m_translate_expr2var { nullptr }; + sat::sat_internalizer* m_translate_si{ nullptr }; scoped_ptr m_ackerman; euf::enode* m_true { nullptr }; @@ -137,8 +138,13 @@ namespace euf { void set_lookahead(sat::lookahead* s) override { m_lookahead = s; } struct scoped_set_translate { solver& s; - scoped_set_translate(solver& s, ast_translation& t, atom2bool_var& a2b):s(s) { s.m_translate = &t; s.m_translate_expr2var = &a2b; } - ~scoped_set_translate() { s.m_translate = nullptr; s. m_translate_expr2var = nullptr; } + scoped_set_translate(solver& s, ast_translation& t, atom2bool_var& a2b, sat::sat_internalizer& si) : + s(s) { + s.m_translate = &t; + s.m_translate_expr2var = &a2b; + s.m_translate_si = &si; + } + ~scoped_set_translate() { s.m_translate = nullptr; s.m_translate_expr2var = nullptr; s.m_translate_si = nullptr; } }; double get_reward(literal l, ext_constraint_idx idx, sat::literal_occs_fun& occs) const override { return 0; } bool is_extended_binary(ext_justification_idx idx, literal_vector & r) override { return false; } diff --git a/src/sat/tactic/goal2sat.cpp b/src/sat/tactic/goal2sat.cpp index c1c3343b1ac..2d5d00f5dfc 100644 --- a/src/sat/tactic/goal2sat.cpp +++ b/src/sat/tactic/goal2sat.cpp @@ -770,6 +770,13 @@ bool goal2sat::has_interpreted_atoms() const { } +sat::sat_internalizer& goal2sat::si(ast_manager& m, params_ref const& p, sat::solver_core& t, atom2bool_var& a2b, dep2asm_map& dep2asm, bool default_external) { + if (!m_imp) + m_imp = alloc(imp, m, p, t, a2b, dep2asm, default_external); + return *m_imp; +} + + sat2goal::mc::mc(ast_manager& m): m(m), m_var2expr(m) {} diff --git a/src/sat/tactic/goal2sat.h b/src/sat/tactic/goal2sat.h index 40ccb2e9cab..df585bf3fab 100644 --- a/src/sat/tactic/goal2sat.h +++ b/src/sat/tactic/goal2sat.h @@ -66,6 +66,8 @@ class goal2sat { bool has_interpreted_atoms() const; + sat::sat_internalizer& si(ast_manager& m, params_ref const& p, sat::solver_core& t, atom2bool_var& a2b, dep2asm_map& dep2asm, bool default_external = false); + };