Skip to content

Commit

Permalink
[CP-SAT] improve cores when at_most_one are present, cleanup internal…
Browse files Browse the repository at this point in the history
… lns helpers
  • Loading branch information
lperron committed Jul 20, 2023
1 parent 7b49228 commit fd4a173
Show file tree
Hide file tree
Showing 10 changed files with 303 additions and 187 deletions.
1 change: 1 addition & 0 deletions ortools/sat/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -1478,6 +1478,7 @@ cc_library(
":sat_parameters_cc_proto",
":sat_solver",
"//ortools/base",
"//ortools/base:stl_util",
"//ortools/util:strong_integers",
],
)
Expand Down
4 changes: 3 additions & 1 deletion ortools/sat/cp_model_lns.h
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,9 @@ class NeighborhoodGeneratorHelper : public SubSolver {
}

// Constraints <-> Variables graph.
// Note that only non-constant variable are listed here.
// Important:
// - The constraint index is NOT related to the one in the cp_model.
// - Only non-constant variables are listed in ConstraintToVar().
const std::vector<std::vector<int>>& ConstraintToVar() const
ABSL_SHARED_LOCKS_REQUIRED(graph_mutex_) {
return constraint_to_var_;
Expand Down
2 changes: 1 addition & 1 deletion ortools/sat/cp_model_presolve.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8501,7 +8501,7 @@ void CpModelPresolver::DetectDominatedLinearConstraints() {
++num_inclusions;

// Store the coeff of the subset linear constraint in a map.
const ConstraintProto subset_ct =
const ConstraintProto& subset_ct =
context_->working_model->constraints(subset_c);
const LinearConstraintProto& subset_lin = subset_ct.linear();
coeff_map.clear();
Expand Down
197 changes: 107 additions & 90 deletions ortools/sat/cp_model_solver.cc

Large diffs are not rendered by default.

92 changes: 49 additions & 43 deletions ortools/sat/cp_model_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,108 +55,116 @@ void SetToNegatedLinearExpression(const LinearExpressionProto& input_expr,

IndexReferences GetReferencesUsedByConstraint(const ConstraintProto& ct) {
IndexReferences output;
GetReferencesUsedByConstraint(ct, &output.variables, &output.literals);
return output;
}

void GetReferencesUsedByConstraint(const ConstraintProto& ct,
std::vector<int>* variables,
std::vector<int>* literals) {
variables->clear();
literals->clear();
switch (ct.constraint_case()) {
case ConstraintProto::ConstraintCase::kBoolOr:
AddIndices(ct.bool_or().literals(), &output.literals);
AddIndices(ct.bool_or().literals(), literals);
break;
case ConstraintProto::ConstraintCase::kBoolAnd:
AddIndices(ct.bool_and().literals(), &output.literals);
AddIndices(ct.bool_and().literals(), literals);
break;
case ConstraintProto::ConstraintCase::kAtMostOne:
AddIndices(ct.at_most_one().literals(), &output.literals);
AddIndices(ct.at_most_one().literals(), literals);
break;
case ConstraintProto::ConstraintCase::kExactlyOne:
AddIndices(ct.exactly_one().literals(), &output.literals);
AddIndices(ct.exactly_one().literals(), literals);
break;
case ConstraintProto::ConstraintCase::kBoolXor:
AddIndices(ct.bool_xor().literals(), &output.literals);
AddIndices(ct.bool_xor().literals(), literals);
break;
case ConstraintProto::ConstraintCase::kIntDiv:
AddIndices(ct.int_div().target().vars(), &output.variables);
AddIndices(ct.int_div().target().vars(), variables);
for (const LinearExpressionProto& expr : ct.int_div().exprs()) {
AddIndices(expr.vars(), &output.variables);
AddIndices(expr.vars(), variables);
}
break;
case ConstraintProto::ConstraintCase::kIntMod:
AddIndices(ct.int_mod().target().vars(), &output.variables);
AddIndices(ct.int_mod().target().vars(), variables);
for (const LinearExpressionProto& expr : ct.int_mod().exprs()) {
AddIndices(expr.vars(), &output.variables);
AddIndices(expr.vars(), variables);
}
break;
case ConstraintProto::ConstraintCase::kLinMax: {
AddIndices(ct.lin_max().target().vars(), &output.variables);
AddIndices(ct.lin_max().target().vars(), variables);
for (const LinearExpressionProto& expr : ct.lin_max().exprs()) {
AddIndices(expr.vars(), &output.variables);
AddIndices(expr.vars(), variables);
}
break;
}
case ConstraintProto::ConstraintCase::kIntProd:
AddIndices(ct.int_prod().target().vars(), &output.variables);
AddIndices(ct.int_prod().target().vars(), variables);
for (const LinearExpressionProto& expr : ct.int_prod().exprs()) {
AddIndices(expr.vars(), &output.variables);
AddIndices(expr.vars(), variables);
}
break;
case ConstraintProto::ConstraintCase::kLinear:
AddIndices(ct.linear().vars(), &output.variables);
AddIndices(ct.linear().vars(), variables);
break;
case ConstraintProto::ConstraintCase::kAllDiff:
for (const LinearExpressionProto& expr : ct.all_diff().exprs()) {
AddIndices(expr.vars(), &output.variables);
AddIndices(expr.vars(), variables);
}
break;
case ConstraintProto::ConstraintCase::kDummyConstraint:
AddIndices(ct.dummy_constraint().vars(), &output.variables);
AddIndices(ct.dummy_constraint().vars(), variables);
break;
case ConstraintProto::ConstraintCase::kElement:
output.variables.push_back(ct.element().index());
output.variables.push_back(ct.element().target());
AddIndices(ct.element().vars(), &output.variables);
variables->push_back(ct.element().index());
variables->push_back(ct.element().target());
AddIndices(ct.element().vars(), variables);
break;
case ConstraintProto::ConstraintCase::kCircuit:
AddIndices(ct.circuit().literals(), &output.literals);
AddIndices(ct.circuit().literals(), literals);
break;
case ConstraintProto::ConstraintCase::kRoutes:
AddIndices(ct.routes().literals(), &output.literals);
AddIndices(ct.routes().literals(), literals);
break;
case ConstraintProto::ConstraintCase::kInverse:
AddIndices(ct.inverse().f_direct(), &output.variables);
AddIndices(ct.inverse().f_inverse(), &output.variables);
AddIndices(ct.inverse().f_direct(), variables);
AddIndices(ct.inverse().f_inverse(), variables);
break;
case ConstraintProto::ConstraintCase::kReservoir:
for (const LinearExpressionProto& time : ct.reservoir().time_exprs()) {
AddIndices(time.vars(), &output.variables);
AddIndices(time.vars(), variables);
}
for (const LinearExpressionProto& level :
ct.reservoir().level_changes()) {
AddIndices(level.vars(), &output.variables);
AddIndices(level.vars(), variables);
}
AddIndices(ct.reservoir().active_literals(), &output.literals);
AddIndices(ct.reservoir().active_literals(), literals);
break;
case ConstraintProto::ConstraintCase::kTable:
AddIndices(ct.table().vars(), &output.variables);
AddIndices(ct.table().vars(), variables);
break;
case ConstraintProto::ConstraintCase::kAutomaton:
AddIndices(ct.automaton().vars(), &output.variables);
AddIndices(ct.automaton().vars(), variables);
break;
case ConstraintProto::ConstraintCase::kInterval:
AddIndices(ct.interval().start().vars(), &output.variables);
AddIndices(ct.interval().size().vars(), &output.variables);
AddIndices(ct.interval().end().vars(), &output.variables);
AddIndices(ct.interval().start().vars(), variables);
AddIndices(ct.interval().size().vars(), variables);
AddIndices(ct.interval().end().vars(), variables);
break;
case ConstraintProto::ConstraintCase::kNoOverlap:
break;
case ConstraintProto::ConstraintCase::kNoOverlap2D:
break;
case ConstraintProto::ConstraintCase::kCumulative:
AddIndices(ct.cumulative().capacity().vars(), &output.variables);
AddIndices(ct.cumulative().capacity().vars(), variables);
for (const LinearExpressionProto& demand : ct.cumulative().demands()) {
AddIndices(demand.vars(), &output.variables);
AddIndices(demand.vars(), variables);
}
break;
case ConstraintProto::ConstraintCase::CONSTRAINT_NOT_SET:
break;
}
return output;
}

#define APPLY_TO_SINGULAR_FIELD(ct_name, field_name) \
Expand Down Expand Up @@ -393,7 +401,7 @@ void ApplyToAllIntervalIndices(const std::function<void(int*)>& f,
#undef APPLY_TO_SINGULAR_FIELD
#undef APPLY_TO_REPEATED_FIELD

std::string ConstraintCaseName(
absl::string_view ConstraintCaseName(
ConstraintProto::ConstraintCase constraint_case) {
switch (constraint_case) {
case ConstraintProto::ConstraintCase::kBoolOr:
Expand Down Expand Up @@ -448,18 +456,16 @@ std::string ConstraintCaseName(
}

std::vector<int> UsedVariables(const ConstraintProto& ct) {
IndexReferences references = GetReferencesUsedByConstraint(ct);
for (int& ref : references.variables) {
std::vector<int> result;
GetReferencesUsedByConstraint(ct, &result, &result);
for (int& ref : result) {
ref = PositiveRef(ref);
}
for (const int lit : references.literals) {
references.variables.push_back(PositiveRef(lit));
}
for (const int lit : ct.enforcement_literal()) {
references.variables.push_back(PositiveRef(lit));
result.push_back(PositiveRef(lit));
}
gtl::STLSortAndRemoveDuplicates(&references.variables);
return references.variables;
gtl::STLSortAndRemoveDuplicates(&result);
return result;
}

std::vector<int> UsedIntervals(const ConstraintProto& ct) {
Expand Down
6 changes: 5 additions & 1 deletion ortools/sat/cp_model_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ struct IndexReferences {
std::vector<int> literals;
};
IndexReferences GetReferencesUsedByConstraint(const ConstraintProto& ct);
void GetReferencesUsedByConstraint(const ConstraintProto& ct,
std::vector<int>* variables,
std::vector<int>* literals);

// Applies the given function to all variables/literals/intervals indices of the
// constraint. This function is used in a few places to have a "generic" code
Expand All @@ -78,7 +81,8 @@ void ApplyToAllIntervalIndices(const std::function<void(int*)>& function,

// Returns the name of the ConstraintProto::ConstraintCase oneof enum.
// Note(user): There is no such function in the proto API as of 16/01/2017.
std::string ConstraintCaseName(ConstraintProto::ConstraintCase constraint_case);
absl::string_view ConstraintCaseName(
ConstraintProto::ConstraintCase constraint_case);

// Returns the sorted list of variables used by a constraint.
// Note that this includes variables used as literals.
Expand Down
114 changes: 67 additions & 47 deletions ortools/sat/encoding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

#include "absl/log/check.h"
#include "absl/strings/str_cat.h"
#include "ortools/base/stl_util.h"
#include "ortools/sat/boolean_problem.pb.h"
#include "ortools/sat/pb_constraint.h"
#include "ortools/sat/sat_base.h"
Expand Down Expand Up @@ -625,43 +626,6 @@ bool ProcessCore(const std::vector<Literal>& core, Coefficient min_weight,
return solver->AddUnitClause(core[0].Negated());
}

// Are the literal in amo relationship?
// - If min_weight is large enough, we can infer that.
// - If the size is small we can infer this manually.
//
// TODO(user): improve the HeuristicAmoPartition() for small case, i.e. use
// full propagation. Maybe we can merge with MinimizeCoreWithPropagation().
bool in_exactly_one = (2 * min_weight) > gap;
if (!in_exactly_one && core.size() == 2) {
const Literal cost0 = core[0].Negated();
const Literal cost1 = core[1].Negated();
const bool ok = solver->EnqueueDecisionIfNotConflicting(cost0);
if (!ok) {
// cost0 cannot be true! so cost1 must be true.
return solver->AddUnitClause(cost1);
}

if (solver->Assignment().LiteralIsFalse(cost1)) {
// If the literal are in AMO, we make sure the solver will not forget that
// by adding the exactly one below.
in_exactly_one = true;

// Tricky: we need the two nodes to be "Boolean nodes", otherwise while
// we have an "at most one" between the assumption we don't necessarily
// have an at most one between the nodes.
if (!solver->ResetToLevelZero()) return false;
for (const EncodingNode* node : *nodes) {
if (node->AssumptionIs(core[0]) || node->AssumptionIs(core[1])) {
if (node->current_ub() != node->ub() || node->size() != 1) {
// TODO(user): We could still improve the encoding in this case.
in_exactly_one = false;
}
}
}
}
if (!solver->ResetToLevelZero()) return false;
}

// Remove from nodes the EncodingNode in the core and put them in to_merge.
std::vector<EncodingNode*> to_merge;
{
Expand Down Expand Up @@ -697,32 +661,88 @@ bool ProcessCore(const std::vector<Literal>& core, Coefficient min_weight,
nodes->resize(new_size);
}

// Amongst the node to merge, if many are leaf nodes in an "at most one"
// relationship, it is super advantageous to exploit it during merging as
// we can regroup all nodes from an at most one in a single new node with a
// depth of 1.
// Are the literals in an at-most-one (AMO) relationship?
// - If min_weight is large enough, we can infer that.
// - If the size is small, we can infer this via propagation.
bool in_exactly_one = (2 * min_weight) > gap;

// Amongst the nodes to merge, if many are Boolean nodes in an "at most one"
// relationship, it is very advantageous to exploit this during merging, as we
// can regroup all the nodes of an at most one into a single new node with a
// depth of 1.
if (!in_exactly_one && implications != nullptr) {
std::vector<Literal> leaves;
// Collect "boolean nodes".
std::vector<Literal> bool_nodes;
absl::flat_hash_map<LiteralIndex, int> node_indices;
for (int i = 0; i < to_merge.size(); ++i) {
const EncodingNode& node = *to_merge[i];

// TODO(user): Why is there an issue if we consider higher levels?
if (node.depth() != 0) continue;
if (node.size() != 1) continue;
if (node.ub() != node.lb() + 1) continue;
if (node_indices.contains(node.literal(0).Index())) continue;
node_indices[node.literal(0).Index()] = i;
leaves.push_back(node.literal(0));
}
bool_nodes.push_back(node.literal(0));
}

// For "small" core, with O(n) full propagation, we can discover possible
// at most ones. This is a bit costly but can significantly reduce the
// number of Booleans needed and has a good positive impact.
std::vector<int> buffer;
std::vector<absl::Span<const Literal>> decomposition;
if (bool_nodes.size() < 100 && bool_nodes.size() > 1) {
const auto& assignment = solver->Assignment();
const int size = bool_nodes.size();
std::vector<std::vector<int>> graph(size);
for (int i = 0; i < size; ++i) {
if (!solver->ResetToLevelZero()) return false;
if (!solver->EnqueueDecisionIfNotConflicting(bool_nodes[i])) {
// TODO(user): this node is closed and can be removed from the core.
continue;
}
for (int j = 0; j < size; ++j) {
if (i == j) continue;
if (assignment.LiteralIsFalse(bool_nodes[j])) {
graph[i].push_back(j);

int num_in_decompo = 0;
const std::vector<absl::Span<const Literal>> decomposition =
implications->HeuristicAmoPartition(&leaves);
// Unit propagation is not always symmetric.
graph[j].push_back(i);
}

// TODO(user): If assignment.LiteralIsTrue(bool_nodes[j]), we can
// minimize the core here by removing bool_nodes[i] from it. Note,
// however, that since we already minimized the core, this is
// unlikely to happen.
}
}
if (!solver->ResetToLevelZero()) return false;

for (std::vector<int>& adj : graph) {
gtl::STLSortAndRemoveDuplicates(&adj);
}
const std::vector<absl::Span<int>> index_decompo =
AtMostOneDecomposition(graph, &buffer);

// Convert.
std::vector<Literal> new_order;
for (const int i : buffer) new_order.push_back(bool_nodes[i]);
bool_nodes = new_order;
for (const auto span : index_decompo) {
if (span.size() == 1) continue;
decomposition.push_back(absl::MakeSpan(
bool_nodes.data() + (span.data() - buffer.data()), span.size()));
}
} else {
decomposition = implications->HeuristicAmoPartition(&bool_nodes);
}

// Same case as above: all the nodes in the core are in an exactly_one.
if (decomposition.size() == 1 && decomposition[0].size() == core.size()) {
in_exactly_one = true;
}

int num_in_decompo = 0;
if (!in_exactly_one) {
for (const auto amo : decomposition) {
num_in_decompo += amo.size();
Expand Down
Loading

0 comments on commit fd4a173

Please sign in to comment.