diff --git a/src/api/api_quant.cpp b/src/api/api_quant.cpp index 25d31229f3b..1e764301b75 100644 --- a/src/api/api_quant.cpp +++ b/src/api/api_quant.cpp @@ -72,14 +72,11 @@ extern "C" { expr * const* ps = reinterpret_cast(patterns); expr * const* no_ps = reinterpret_cast(no_patterns); symbol qid = to_symbol(quantifier_id); - bool is_rec = mk_c(c)->m().rec_fun_qid() == qid; - if (!is_rec) { - pattern_validator v(mk_c(c)->m()); - for (unsigned i = 0; i < num_patterns; i++) { - if (!v(num_decls, ps[i], 0, 0)) { - SET_ERROR_CODE(Z3_INVALID_PATTERN, nullptr); - return nullptr; - } + pattern_validator v(mk_c(c)->m()); + for (unsigned i = 0; i < num_patterns; i++) { + if (!v(num_decls, ps[i], 0, 0)) { + SET_ERROR_CODE(Z3_INVALID_PATTERN, nullptr); + return nullptr; } } sort* const* ts = reinterpret_cast(sorts); diff --git a/src/ast/ast.cpp b/src/ast/ast.cpp index 454f142613d..56b5d76a8f5 100644 --- a/src/ast/ast.cpp +++ b/src/ast/ast.cpp @@ -1379,7 +1379,6 @@ ast_manager::ast_manager(proof_gen_mode m, char const * trace_file, bool is_form m_proof_mode(m), m_trace_stream(nullptr), m_trace_stream_owner(false), - m_rec_fun(":rec-fun"), m_lambda_def(":lambda-def") { if (trace_file) { @@ -1403,7 +1402,6 @@ ast_manager::ast_manager(proof_gen_mode m, std::fstream * trace_stream, bool is_ m_proof_mode(m), m_trace_stream(trace_stream), m_trace_stream_owner(false), - m_rec_fun(":rec-fun"), m_lambda_def(":lambda-def") { if (!is_format_manager) @@ -1421,7 +1419,6 @@ ast_manager::ast_manager(ast_manager const & src, bool disable_proofs): m_proof_mode(disable_proofs ? PGM_DISABLED : src.m_proof_mode), m_trace_stream(src.m_trace_stream), m_trace_stream_owner(false), - m_rec_fun(":rec-fun"), m_lambda_def(":lambda-def") { SASSERT(!src.is_format_manager()); m_format_manager = alloc(ast_manager, PGM_DISABLED, m_trace_stream, true); @@ -1756,13 +1753,6 @@ quantifier* ast_manager::is_lambda_def(func_decl* f) { return nullptr; } - -func_decl* ast_manager::get_rec_fun_decl(quantifier* q) const { - SASSERT(is_rec_fun_def(q)); - return to_app(to_app(q->get_pattern(0))->get_arg(0))->get_decl(); -} - - void ast_manager::register_plugin(family_id id, decl_plugin * plugin) { SASSERT(m_plugins.get(id, 0) == 0); m_plugins.setx(id, plugin, 0); diff --git a/src/ast/ast.h b/src/ast/ast.h index ece6d0005c0..2b24d55584a 100644 --- a/src/ast/ast.h +++ b/src/ast/ast.h @@ -1553,7 +1553,6 @@ class ast_manager { bool slow_not_contains(ast const * n); #endif ast_manager * m_format_manager; // hack for isolating format objects in a different manager. - symbol m_rec_fun; symbol m_lambda_def; void init(); @@ -1666,13 +1665,10 @@ class ast_manager { bool contains(ast * a) const { return m_ast_table.contains(a); } - bool is_rec_fun_def(quantifier* q) const { return q->get_qid() == m_rec_fun; } bool is_lambda_def(quantifier* q) const { return q->get_qid() == m_lambda_def; } void add_lambda_def(func_decl* f, quantifier* q); quantifier* is_lambda_def(func_decl* f); - func_decl* get_rec_fun_decl(quantifier* q) const; - symbol const& rec_fun_qid() const { return m_rec_fun; } symbol const& lambda_def_qid() const { return m_lambda_def; } diff --git a/src/ast/recfun_decl_plugin.h b/src/ast/recfun_decl_plugin.h index 2b8b427444c..4536f5f99e8 100644 --- a/src/ast/recfun_decl_plugin.h +++ b/src/ast/recfun_decl_plugin.h @@ -220,6 +220,7 @@ namespace recfun { ~util(); ast_manager & m() { return m_manager; } + family_id get_family_id() const { return m_fid; } decl::plugin& get_plugin() { return *m_plugin; } bool is_case_pred(expr * e) const { return is_app_of(e, m_fid, OP_FUN_CASE_PRED); } diff --git a/src/ast/rewriter/CMakeLists.txt b/src/ast/rewriter/CMakeLists.txt index ba57f757d43..cb51085a16d 100644 --- a/src/ast/rewriter/CMakeLists.txt +++ b/src/ast/rewriter/CMakeLists.txt @@ -29,6 +29,7 @@ z3_add_component(rewriter pb2bv_rewriter.cpp push_app_ite.cpp quant_hoist.cpp + recfun_rewriter.cpp rewriter.cpp seq_rewriter.cpp th_rewriter.cpp diff --git a/src/ast/rewriter/recfun_rewriter.cpp b/src/ast/rewriter/recfun_rewriter.cpp new file mode 100644 index 00000000000..24318dc2cd2 --- /dev/null +++ b/src/ast/rewriter/recfun_rewriter.cpp @@ -0,0 +1,39 @@ +/*++ +Copyright (c) 2018 Microsoft Corporation + +Module Name: + + recfun_rewriter.cpp + +Abstract: + + Rewriter recursive function applications to values + +Author: + + Nikolaj Bjorner (nbjorner) 2020-04-26 + + +--*/ + + +#include "ast/rewriter/recfun_rewriter.h" +#include "ast/rewriter/var_subst.h" + +br_status recfun_rewriter::mk_app_core(func_decl * f, unsigned num_args, expr * const * args, expr_ref & result) { + if (m_rec.is_defined(f)) { + for (unsigned i = 0; i < num_args; ++i) { + if (!m.is_value(args[i])) + return BR_FAILED; + } + recfun::def const& d = m_rec.get_def(f); + var_subst sub(m); + result = sub(d.get_rhs(), num_args, args); + return BR_REWRITE_FULL; + } + else { + return BR_FAILED; + } +} + + diff --git a/src/ast/rewriter/recfun_rewriter.h b/src/ast/rewriter/recfun_rewriter.h new file mode 100644 index 00000000000..a08a75783b1 --- /dev/null +++ b/src/ast/rewriter/recfun_rewriter.h @@ -0,0 +1,36 @@ +/*++ +Copyright (c) 2018 Microsoft Corporation + +Module Name: + + recfun_rewriter.h + +Abstract: + + Rewriter recursive function applications to values + +Author: + + Nikolaj Bjorner (nbjorner) 2020-04-26 + + +--*/ + +#pragma once + +#include "ast/recfun_decl_plugin.h" +#include "ast/rewriter/rewriter.h" + +class recfun_rewriter { + ast_manager& m; + recfun::util m_rec; +public: + recfun_rewriter(ast_manager& m): m(m), m_rec(m) {} + ~recfun_rewriter() {} + + br_status mk_app_core(func_decl * f, unsigned num_args, expr * const * args, expr_ref & result); + + family_id get_fid() const { return m_rec.get_family_id(); } + +}; + diff --git a/src/ast/rewriter/th_rewriter.cpp b/src/ast/rewriter/th_rewriter.cpp index a19c3abf84d..94b661e46be 100644 --- a/src/ast/rewriter/th_rewriter.cpp +++ b/src/ast/rewriter/th_rewriter.cpp @@ -26,6 +26,7 @@ Module Name: #include "ast/rewriter/fpa_rewriter.h" #include "ast/rewriter/dl_rewriter.h" #include "ast/rewriter/pb_rewriter.h" +#include "ast/rewriter/recfun_rewriter.h" #include "ast/rewriter/seq_rewriter.h" #include "ast/rewriter/rewriter_def.h" #include "ast/rewriter/var_subst.h" @@ -46,6 +47,7 @@ struct th_rewriter_cfg : public default_rewriter_cfg { dl_rewriter m_dl_rw; pb_rewriter m_pb_rw; seq_rewriter m_seq_rw; + recfun_rewriter m_rec_rw; arith_util m_a_util; bv_util m_bv_util; unsigned long long m_max_memory; // in bytes @@ -219,6 +221,8 @@ struct th_rewriter_cfg : public default_rewriter_cfg { return m_pb_rw.mk_app_core(f, num, args, result); if (fid == m_seq_rw.get_fid()) return m_seq_rw.mk_app_core(f, num, args, result); + if (fid == m_rec_rw.get_fid()) + return m_rec_rw.mk_app_core(f, num, args, result); return BR_FAILED; } @@ -747,6 +751,7 @@ struct th_rewriter_cfg : public default_rewriter_cfg { m_dl_rw(m), m_pb_rw(m), m_seq_rw(m), + m_rec_rw(m), m_a_util(m), m_bv_util(m), m_used_dependencies(m), diff --git a/src/smt/smt_context.cpp b/src/smt/smt_context.cpp index 9a1900355df..9d8a81a8442 100644 --- a/src/smt/smt_context.cpp +++ b/src/smt/smt_context.cpp @@ -4507,32 +4507,6 @@ namespace smt { void context::add_rec_funs_to_model() { if (!m_model) return; - for (unsigned i = 0; !get_cancel_flag() && i < m_asserted_formulas.get_num_formulas(); ++i) { - expr* e = m_asserted_formulas.get_formula(i); - if (is_quantifier(e)) { - quantifier* q = to_quantifier(e); - if (!m.is_rec_fun_def(q)) continue; - TRACE("context", tout << mk_pp(e, m) << "\n";); - SASSERT(q->get_num_patterns() == 2); - expr* fn = to_app(q->get_pattern(0))->get_arg(0); - expr* body = to_app(q->get_pattern(1))->get_arg(0); - SASSERT(is_app(fn)); - // reverse argument order so that variable 0 starts at the beginning. - expr_ref_vector subst(m); - unsigned idx = 0; - for (expr* arg : *to_app(fn)) { - subst.push_back(m.mk_var(idx++, m.get_sort(arg))); - } - expr_ref bodyr(m); - var_subst sub(m, true); - TRACE("context", tout << expr_ref(q, m) << " " << subst << "\n";); - bodyr = sub(body, subst.size(), subst.c_ptr()); - func_decl* f = to_app(fn)->get_decl(); - func_interp* fi = alloc(func_interp, m, f->get_arity()); - fi->set_else(bodyr); - m_model->register_decl(f, fi); - } - } recfun::util u(m); func_decl_ref_vector recfuns = u.get_rec_funs(); for (func_decl* f : recfuns) { diff --git a/src/smt/smt_context_inv.cpp b/src/smt/smt_context_inv.cpp index 1caf0fbb326..3938ead08cc 100644 --- a/src/smt/smt_context_inv.cpp +++ b/src/smt/smt_context_inv.cpp @@ -367,9 +367,6 @@ namespace smt { if (!is_ground(n)) { continue; } - if (is_quantifier(n) && m.is_rec_fun_def(to_quantifier(n))) { - continue; - } switch (get_assignment(lit)) { case l_undef: break; diff --git a/src/smt/smt_model_checker.cpp b/src/smt/smt_model_checker.cpp index d1030efdbb0..efaea40e3ee 100644 --- a/src/smt/smt_model_checker.cpp +++ b/src/smt/smt_model_checker.cpp @@ -47,7 +47,6 @@ namespace smt { m_model_finder(mf), m_max_cexs(1), m_iteration_idx(0), - m_has_rec_fun(false), m_curr_model(nullptr), m_fresh_exprs(m), m_pinned_exprs(m) { @@ -380,36 +379,6 @@ namespace smt { return false; } - bool model_checker::check_rec_fun(quantifier* q, bool strict_rec_fun) { - TRACE("model_checker", tout << mk_pp(q, m) << "\n";); - SASSERT(q->get_num_patterns() == 2); // first pattern is the function, second is the body. - func_decl* f = m.get_rec_fun_decl(q); - - expr_ref_vector args(m); - unsigned num_decls = q->get_num_decls(); - args.resize(num_decls, nullptr); - var_subst sub(m); - expr_ref tmp(m), result(m); - for (enode* n : m_context->enodes_of(f)) { - if (m_context->is_relevant(n)) { - app* e = n->get_owner(); - SASSERT(e->get_num_args() == num_decls); - for (unsigned i = 0; i < num_decls; ++i) { - args[i] = e->get_arg(i); - } - tmp = sub(q->get_expr(), num_decls, args.c_ptr()); - TRACE("model_checker", tout << "curr_model:\n"; model_pp(tout, *m_curr_model);); - m_curr_model->eval(tmp, result, true); - if (strict_rec_fun ? !m.is_true(result) : m.is_false(result)) { - add_instance(q, args, 0, nullptr); - return false; - } - TRACE("model_checker", tout << tmp << "\nevaluates to:\n" << result << "\n";); - } - } - return true; - } - void model_checker::init_aux_context() { if (!m_fparams) { m_fparams = alloc(smt_params, m_context->get_fparams()); @@ -458,7 +427,7 @@ namespace smt { bool found_relevant = false; unsigned num_failures = 0; - check_quantifiers(false, found_relevant, num_failures); + check_quantifiers(found_relevant, num_failures); if (found_relevant) m_iteration_idx++; @@ -467,11 +436,11 @@ namespace smt { TRACE("model_checker", tout << "model checker result: " << (num_failures == 0) << "\n";); m_max_cexs += m_params.m_mbqi_max_cexs; - if (num_failures == 0 && (!m_context->validate_model() || has_rec_under_quantifiers())) { + if (num_failures == 0 && (!m_context->validate_model())) { num_failures = 1; // this time force expanding recursive function definitions // that are not forced true in the current model. - check_quantifiers(true, found_relevant, num_failures); + check_quantifiers(found_relevant, num_failures); } if (num_failures == 0) m_curr_model->cleanup(); @@ -484,43 +453,6 @@ namespace smt { return num_failures == 0; } - struct has_rec_fun_proc { - obj_hashtable& m_rec_funs; - bool m_has_rec_fun; - - bool has_rec_fun() const { return m_has_rec_fun; } - - has_rec_fun_proc(obj_hashtable& rec_funs): - m_rec_funs(rec_funs), - m_has_rec_fun(false) {} - - void operator()(app* fn) { - m_has_rec_fun |= m_rec_funs.contains(fn->get_decl()); - } - void operator()(expr*) {} - }; - - bool model_checker::has_rec_under_quantifiers() { - if (!m_has_rec_fun) { - return false; - } - obj_hashtable rec_funs; - for (quantifier * q : *m_qm) { - if (m.is_rec_fun_def(q)) { - rec_funs.insert(m.get_rec_fun_decl(q)); - } - } - expr_fast_mark1 visited; - has_rec_fun_proc proc(rec_funs); - for (quantifier * q : *m_qm) { - if (!m.is_rec_fun_def(q)) { - quick_for_each_expr(proc, visited, q); - if (proc.has_rec_fun()) return true; - } - } - return false; - } - // // (repeated from defined_names.cpp) // NB. The pattern for lambdas is incomplete. @@ -532,7 +464,7 @@ namespace smt { // using multi-patterns. // - void model_checker::check_quantifiers(bool strict_rec_fun, bool& found_relevant, unsigned& num_failures) { + void model_checker::check_quantifiers(bool& found_relevant, unsigned& num_failures) { for (quantifier * q : *m_qm) { if (!(m_qm->mbqi_enabled(q) && m_context->is_relevant(q) && @@ -549,14 +481,7 @@ namespace smt { verbose_stream() << "(smt.mbqi :checking " << q->get_qid() << ")\n"; } found_relevant = true; - if (m.is_rec_fun_def(q)) { - m_has_rec_fun = true; - if (!check_rec_fun(q, strict_rec_fun)) { - TRACE("model_checker", tout << "checking recursive function failed\n";); - num_failures++; - } - } - else if (!check(q)) { + if (!check(q)) { if (m_params.m_mbqi_trace || get_verbosity_level() >= 5) { IF_VERBOSE(0, verbose_stream() << "(smt.mbqi :failed " << q->get_qid() << ")\n"); } diff --git a/src/smt/smt_model_checker.h b/src/smt/smt_model_checker.h index e8676cbaffb..b437115351a 100644 --- a/src/smt/smt_model_checker.h +++ b/src/smt/smt_model_checker.h @@ -51,7 +51,6 @@ namespace smt { scoped_ptr m_aux_context; // Auxiliary context used for model checking quantifiers. unsigned m_max_cexs; unsigned m_iteration_idx; - bool m_has_rec_fun; proto_model * m_curr_model; obj_map m_value2expr; expr_ref_vector m_fresh_exprs; @@ -67,9 +66,7 @@ namespace smt { void assert_neg_q_m(quantifier * q, expr_ref_vector & sks); bool add_blocking_clause(model * cex, expr_ref_vector & sks); bool check(quantifier * q); - bool check_rec_fun(quantifier* q, bool strict_rec_fun); - bool has_rec_under_quantifiers(); - void check_quantifiers(bool strict_rec_fun, bool& found_relevant, unsigned& num_failures); + void check_quantifiers(bool& found_relevant, unsigned& num_failures); struct instance { quantifier * m_q; diff --git a/src/smt/smt_quantifier.cpp b/src/smt/smt_quantifier.cpp index 14792bf761f..97482850146 100644 --- a/src/smt/smt_quantifier.cpp +++ b/src/smt/smt_quantifier.cpp @@ -625,9 +625,6 @@ namespace smt { if (!m_fparams->m_ematching) { return; } - if (false && m.is_rec_fun_def(q) && mbqi_enabled(q)) { - return; - } bool has_unary_pattern = false; unsigned num_patterns = q->get_num_patterns(); for (unsigned i = 0; i < num_patterns; i++) { @@ -644,11 +641,7 @@ namespace smt { app * mp = to_app(q->get_pattern(i)); SASSERT(m.is_pattern(mp)); bool unary = (mp->get_num_args() == 1); - if (m.is_rec_fun_def(q) && i > 0) { - // add only the first pattern - TRACE("quantifier", tout << "skip recursive function body " << mk_ismt2_pp(mp, m) << "\n";); - } - else if (!unary && j >= num_eager_multi_patterns) { + if (!unary && j >= num_eager_multi_patterns) { TRACE("quantifier", tout << "delaying (too many multipatterns):\n" << mk_ismt2_pp(mp, m) << "\n" << "j: " << j << " unary: " << unary << " m_params.m_qi_max_eager_multipatterns: " << m_fparams->m_qi_max_eager_multipatterns << " num_eager_multi_patterns: " << num_eager_multi_patterns << "\n";);