Skip to content

Commit

Permalink
Add command to set initial value hints for solver in various components
Browse files Browse the repository at this point in the history
  • Loading branch information
NikolajBjorner committed Sep 18, 2024
1 parent 1c163db commit a3f35b6
Show file tree
Hide file tree
Showing 17 changed files with 82 additions and 8 deletions.
20 changes: 20 additions & 0 deletions src/cmd_context/basic_cmds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,25 @@ UNARY_CMD(echo_cmd, "echo", "<string>", "display the given string", CPK_STRING,
else
ctx.regular_stream() << arg << std::endl;);

class set_initial_value_cmd : public cmd {
expr* m_var = nullptr, *m_value = nullptr;
public:
set_initial_value_cmd(): cmd("set-initial-value") {}
char const* get_usage() const override { return "<var> <value>"; }
char const* get_descr(cmd_context& ctx) const { return "set an initial value for search as a hint to the solver"; }
unsigned get_arity() const { return 2; }
void prepare(cmd_context& ctx) { m_var = m_value = nullptr; }
cmd_arg_kind next_arg_kind(cmd_context& ctx) const { return CPK_EXPR; }
void set_next_arg(cmd_context& ctx, expr* e) { if (m_var) m_value = e; else m_var = e; }
void execute(cmd_context& ctx) {
SASSERT(m_var && m_value);
if (ctx.get_opt())
ctx.get_opt()->initialize_value(m_var, m_value);
else if (ctx.get_solver())
ctx.get_solver()->user_propagate_initialize_value(m_var, m_value);
}
};

class set_get_option_cmd : public cmd {
protected:
symbol m_true;
Expand Down Expand Up @@ -893,6 +912,7 @@ void install_basic_cmds(cmd_context & ctx) {
ctx.insert(alloc(get_option_cmd));
ctx.insert(alloc(get_info_cmd));
ctx.insert(alloc(set_info_cmd));
ctx.insert(alloc(set_initial_value_cmd));
ctx.insert(alloc(get_consequences_cmd));
ctx.insert(alloc(builtin_cmd, "assert", "<term>", "assert term."));
ctx.insert(alloc(builtin_cmd, "check-sat", "<boolean-constants>*", "check if the current context is satisfiable. If a list of boolean constants B is provided, then check if the current context is consistent with assigning every constant in B to true."));
Expand Down
2 changes: 2 additions & 0 deletions src/cmd_context/cmd_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,8 @@ class opt_wrapper : public check_sat_result {
virtual void set_logic(symbol const& s) = 0;
virtual void get_box_model(model_ref& mdl, unsigned index) = 0;
virtual void updt_params(params_ref const& p) = 0;
virtual void initialize_value(expr* var, expr* value) = 0;

};

class ast_context_params : public context_params {
Expand Down
3 changes: 3 additions & 0 deletions src/nlsat/tactic/nlsat_tactic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,9 @@ class nlsat_tactic : public tactic {
void reset_statistics() override {
m_stats.reset();
}

void user_propagate_initialize_value(expr* var, expr* value) override { }

};

tactic * mk_nlsat_tactic(ast_manager & m, params_ref const & p) {
Expand Down
4 changes: 2 additions & 2 deletions src/opt/opt_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -277,8 +277,8 @@ namespace opt {

void add_offset(unsigned id, rational const& o) override;

void initialize_value(expr* var, expr* value);

void initialize_value(expr* var, expr* value) override;
void register_on_model(on_model_t& ctx, std::function<void(on_model_t&, model_ref&)>& on_model) {
m_on_model_ctx = ctx;
m_on_model_eh = on_model;
Expand Down
2 changes: 2 additions & 0 deletions src/qe/nlqsat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -848,6 +848,8 @@ namespace qe {
void collect_param_descrs(param_descrs & r) override {
}

void user_propagate_initialize_value(expr* var, expr* value) override { }


void operator()(/* in */ goal_ref const & in,
/* out */ goal_ref_buffer & result) override {
Expand Down
4 changes: 3 additions & 1 deletion src/qe/qsat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1377,7 +1377,9 @@ namespace qe {

tactic * translate(ast_manager & m) override {
return alloc(qsat, m, m_params, m_mode);
}
}

void user_propagate_initialize_value(expr* var, expr* value) override { }

lbool maximize(expr_ref_vector const& fmls, app* t, model_ref& mdl, opt::inf_eps& value) {
expr_ref_vector defs(m);
Expand Down
5 changes: 5 additions & 0 deletions src/sat/tactic/sat_tactic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,11 @@ class sat_tactic : public tactic {
m_stats.reset();
}

void user_propagate_initialize_value(expr* var, expr* value) override {

}


protected:

};
Expand Down
20 changes: 20 additions & 0 deletions src/smt/theory_bv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1789,6 +1789,26 @@ namespace smt {
return false;
}

void theory_bv::initialize_value(expr* var, expr* value) {
rational val;
unsigned sz;
if (!m_util.is_numeral(value, val, sz)) {
IF_VERBOSE(5, verbose_stream() << "value should be a bit-vector " << mk_pp(value, m) << "\n");
return;
}
if (!is_app(var))
return;
enode* n = mk_enode(to_app(var));
auto v = get_var(n);
unsigned idx = 0;
for (auto lit : m_bits[v]) {
auto & b = ctx.get_bdata(lit.var());
b.m_phase_available = true;
b.m_phase = val.get_bit(idx);
++idx;
}
}

void theory_bv::init_model(model_generator & mg) {
m_factory = alloc(bv_factory, m);
mg.register_factory(m_factory);
Expand Down
1 change: 1 addition & 0 deletions src/smt/theory_bv.h
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,7 @@ namespace smt {
bool merge_zero_one_bits(theory_var r1, theory_var r2);
bool can_propagate() override { return m_prop_diseqs_qhead < m_prop_diseqs.size(); }
void propagate() override;
void initialize_value(expr* var, expr* value) override;

// -----------------------------------
//
Expand Down
6 changes: 1 addition & 5 deletions src/smt/theory_lra.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,6 @@ class theory_lra::imp {
svector<delayed_atom> m_asserted_atoms;
ptr_vector<expr> m_not_handled;
ptr_vector<app> m_underspecified;
vector<std::pair<lpvar, rational>> m_values;
vector<ptr_vector<api_bound> > m_use_list; // bounds where variables are used.

// attributes for incremental version:
Expand Down Expand Up @@ -998,8 +997,7 @@ class theory_lra::imp {
IF_VERBOSE(5, verbose_stream() << "numeric constant expected in initialization " << mk_pp(var, m) << " := " << mk_pp(value, m) << "\n");
return;
}
ctx().push_trail(push_back_vector(m_values));
m_values.push_back({get_lpvar(var), r});
lp().move_lpvar_to_value(get_lpvar(var), r);
}

void new_eq_eh(theory_var v1, theory_var v2) {
Expand Down Expand Up @@ -1420,8 +1418,6 @@ class theory_lra::imp {
void init_search_eh() {
m_arith_eq_adapter.init_search_eh();
m_num_conflicts = 0;
for (auto const& [v, r] : m_values)
lp().move_lpvar_to_value(v, r);
}

bool can_get_value(theory_var v) const {
Expand Down
1 change: 1 addition & 0 deletions src/solver/combined_solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,7 @@ class combined_solver : public solver {
}

void user_propagate_initialize_value(expr* var, expr* value) override {
m_solver1->user_propagate_initialize_value(var, value);
m_solver2->user_propagate_initialize_value(var, value);
}

Expand Down
5 changes: 5 additions & 0 deletions src/solver/solver2tactic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,11 @@ class solver2tactic : public tactic {
}

char const* name() const override { return "solver2tactic"; }


void user_propagate_initialize_value(expr* var, expr* value) override {
m_solver->user_propagate_initialize_value(var, value);
}
};

tactic* mk_solver2tactic(solver* s) { return alloc(solver2tactic, s); }
1 change: 1 addition & 0 deletions src/tactic/.#tactic.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
nbjorner@LAPTOP-04AEAFKH.32880:1726092166
3 changes: 3 additions & 0 deletions src/tactic/fd_solver/bounded_int2bv_solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,9 @@ class bounded_int2bv_solver : public solver_na2as {
return m_assertions.get(idx);
}
}

void user_propagate_initialize_value(expr* var, expr* value) override {
}
};

solver * mk_bounded_int2bv_solver(ast_manager & m, params_ref const & p, solver* s) {
Expand Down
1 change: 1 addition & 0 deletions src/tactic/tactic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ class fail_if_undecided_tactic : public skip_tactic {
throw tactic_exception("undecided");
skip_tactic::operator()(in, result);
}
void user_propagate_initialize_value(expr* var, expr* value) override { }
};

tactic * mk_fail_if_undecided_tactic() {
Expand Down
1 change: 1 addition & 0 deletions src/tactic/tactic.h
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ class skip_tactic : public tactic {
tactic * translate(ast_manager & m) override { return this; }
char const* name() const override { return "skip"; }
void collect_statistics(statistics& st) const override {}
void user_propagate_initialize_value(expr* var, expr* value) override { }
};

tactic * mk_skip_tactic();
Expand Down
11 changes: 11 additions & 0 deletions src/tactic/tactical.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,12 @@ class or_else_tactical : public nary_tactical {
}

tactic * translate(ast_manager & m) override { return translate_core<or_else_tactical>(m); }

void user_propagate_initialize_value(expr* var, expr* value) override {
for (auto t : m_ts)
t->user_propagate_initialize_value(var, value);
}

};

tactic * or_else(unsigned num, tactic * const * ts) {
Expand Down Expand Up @@ -1163,6 +1169,11 @@ class cond_tactical : public binary_tactical {
tactic * new_t2 = m_t2->translate(m);
return alloc(cond_tactical, m_p.get(), new_t1, new_t2);
}

void user_propagate_initialize_value(expr* var, expr* value) override {
m_t1->user_propagate_initialize_value(var, value);
m_t2->user_propagate_initialize_value(var, value);
}
};

tactic * cond(probe * p, tactic * t1, tactic * t2) {
Expand Down

0 comments on commit a3f35b6

Please sign in to comment.