From 1bdd360e96cd41c8ea1509af8a3e8877c42a8edb Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 29 Jul 2024 14:52:00 +0200 Subject: [PATCH 1/7] chore(ci): bump prefix-dev/rattler-build-action from 0.2.11 to 0.2.12 (#57) Bumps [prefix-dev/rattler-build-action](https://github.com/prefix-dev/rattler-build-action) from 0.2.11 to 0.2.12. - [Release notes](https://github.com/prefix-dev/rattler-build-action/releases) - [Commits](https://github.com/prefix-dev/rattler-build-action/compare/v0.2.11...v0.2.12) --- updated-dependencies: - dependency-name: prefix-dev/rattler-build-action dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/cpp-test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/cpp-test.yml b/.github/workflows/cpp-test.yml index 23e1643..2ac41e8 100644 --- a/.github/workflows/cpp-test.yml +++ b/.github/workflows/cpp-test.yml @@ -68,7 +68,7 @@ jobs: - name: Checkout source code uses: actions/checkout@v4 - name: Build conda package - uses: prefix-dev/rattler-build-action@v0.2.11 + uses: prefix-dev/rattler-build-action@v0.2.12 with: recipe-path: recipe/recipe.yaml # needs to be unique for each matrix entry From 899040ad8c011618a31ae30042a8864a735c853b Mon Sep 17 00:00:00 2001 From: Arav Singhal Date: Mon, 29 Jul 2024 09:41:20 -0500 Subject: [PATCH 2/7] fix: Fix off-by-one error in `Mapping::serialize` (#58) * fix: Fix off-by-one error in `Mapping::serialize` * fix(test): Add serialization tests for `Mapping` --- src/internal/mapping.rs | 42 +++++++++++++++++++++++++++++++---------- 1 file changed, 32 insertions(+), 10 deletions(-) diff --git a/src/internal/mapping.rs b/src/internal/mapping.rs index 34cf4a4..9fc6d93 100644 --- a/src/internal/mapping.rs +++ b/src/internal/mapping.rs @@ -196,7 +196,7 @@ impl serde::Serialize for Mapping { self.chunks .iter() .flatten() - .take(self.max()) + .take(self.max() + 1) .collect::>() .serialize(serializer) } @@ -221,7 +221,7 @@ impl<'de, K: ArenaId, V: serde::Deserialize<'de>> serde::Deserialize<'de> for Ma #[cfg(test)] mod tests { - use crate::internal::arena::ArenaId; + use super::*; struct Id { id: usize, @@ -240,9 +240,9 @@ mod tests { #[test] pub fn test_mapping() { // New mapping should have 128 slots per default - let mut mapping = super::Mapping::::new(); + let mut mapping = Mapping::::new(); assert_eq!(mapping.len(), 0); - assert_eq!(mapping.slots(), super::VALUES_PER_CHUNK); + assert_eq!(mapping.slots(), VALUES_PER_CHUNK); // Inserting a value should increase the length // and the number of slots should stay the same @@ -251,20 +251,42 @@ mod tests { // Should be able to get it assert_eq!(*mapping.get(Id::from_usize(0)).unwrap(), 10usize); - assert_eq!(mapping.slots(), super::VALUES_PER_CHUNK); + assert_eq!(mapping.slots(), VALUES_PER_CHUNK); // Inserting higher than the slot size should trigger a resize - mapping.insert(Id::from_usize(super::VALUES_PER_CHUNK), 20usize); + mapping.insert(Id::from_usize(VALUES_PER_CHUNK), 20usize); assert_eq!( - *mapping - .get(Id::from_usize(super::VALUES_PER_CHUNK)) - .unwrap(), + *mapping.get(Id::from_usize(VALUES_PER_CHUNK)).unwrap(), 20usize ); // Now contains 2 elements assert_eq!(mapping.len(), 2); // And double number of slots due to resize - assert_eq!(mapping.slots(), super::VALUES_PER_CHUNK * 2); + assert_eq!(mapping.slots(), VALUES_PER_CHUNK * 2); + } + + #[cfg(feature = "serde")] + #[test] + pub fn test_serde() { + use serde_json::{from_value, to_value}; + + let values = [1, 3, 6, 9, 2, 4, 6, 1, 2, 3]; + let json = to_value(values).unwrap(); + let mapping = + values + .iter() + .copied() + .enumerate() + .fold(Mapping::new(), |mut mapping, (i, v)| { + mapping.insert(Id::from_usize(i), v); + mapping + }); + + assert_eq!(json, to_value(&mapping).unwrap()); + itertools::assert_equal( + mapping.iter().map(|(_, &v)| v), + from_value::>(json).unwrap(), + ); } } From d5a82f5d92d5b0c8c6b35cfd568dbd454c3b8aa5 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 5 Aug 2024 22:17:19 +0200 Subject: [PATCH 3/7] chore(ci): bump prefix-dev/rattler-build-action from 0.2.12 to 0.2.13 (#59) Bumps [prefix-dev/rattler-build-action](https://github.com/prefix-dev/rattler-build-action) from 0.2.12 to 0.2.13. - [Release notes](https://github.com/prefix-dev/rattler-build-action/releases) - [Commits](https://github.com/prefix-dev/rattler-build-action/compare/v0.2.12...v0.2.13) --- updated-dependencies: - dependency-name: prefix-dev/rattler-build-action dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/cpp-test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/cpp-test.yml b/.github/workflows/cpp-test.yml index 2ac41e8..f50a833 100644 --- a/.github/workflows/cpp-test.yml +++ b/.github/workflows/cpp-test.yml @@ -68,7 +68,7 @@ jobs: - name: Checkout source code uses: actions/checkout@v4 - name: Build conda package - uses: prefix-dev/rattler-build-action@v0.2.12 + uses: prefix-dev/rattler-build-action@v0.2.13 with: recipe-path: recipe/recipe.yaml # needs to be unique for each matrix entry From 31cf85143561e248631aa822ab90fd79c9044279 Mon Sep 17 00:00:00 2001 From: Arav Singhal Date: Mon, 5 Aug 2024 15:23:24 -0500 Subject: [PATCH 4/7] feat!: Version set unions as solvable requirements (#56) This PR introduces the concept of version set unions and allows them to be used to specify solvable requirements that can be fulfilled by more than one version set (belonging to more than one package). The change is facilitated by the introduction of a Requirement type, which is used instead of VersionSetIds to specify a required dependency of a solvable. A requirement can either be comprised of a single version set (a VersionSetId), or multiple version sets (a VersionSetUnionId). This allows existing code bases to continue using a single version set (VersionSetId) as a requirement specification without ever having to touch version set unions (VersionSetUnionId). --- cpp/include/resolvo.h | 21 ++- cpp/include/resolvo_dependency_provider.h | 15 ++ cpp/src/lib.rs | 95 +++++++++++- cpp/tests/solve.cpp | 107 +++++++++++-- src/internal/id.rs | 17 +++ src/lib.rs | 12 +- src/problem.rs | 40 ++--- src/requirement.rs | 91 +++++++++++ src/snapshot.rs | 64 ++++++-- src/solver/cache.rs | 117 +++++++++----- src/solver/clause.rs | 38 ++--- src/solver/mod.rs | 52 ++++--- src/solver/watch_map.rs | 5 +- src/utils/pool.rs | 32 +++- tests/snapshots/solver__snapshot.snap | 1 - .../solver__snapshot_union_requirements.snap | 6 + tests/solver.rs | 143 ++++++++++++++++-- tools/solve-snapshot/src/main.rs | 2 +- 18 files changed, 701 insertions(+), 157 deletions(-) create mode 100644 src/requirement.rs create mode 100644 tests/snapshots/solver__snapshot_union_requirements.snap diff --git a/cpp/include/resolvo.h b/cpp/include/resolvo.h index 6a41aa3..48844d6 100644 --- a/cpp/include/resolvo.h +++ b/cpp/include/resolvo.h @@ -4,6 +4,24 @@ #include "resolvo_internal.h" namespace resolvo { +using cbindgen_private::Requirement; + +/** + * Specifies a requirement (dependency) of a single version set. + */ +inline Requirement requirement_single(VersionSetId id) { + return cbindgen_private::resolvo_requirement_single(id); +} + +/** + * Specifies a requirement (dependency) of the union (logical OR) of multiple version sets. + * A solvable belonging to any of the version sets contained in the union satisfies the + * requirement. This variant is typically used for requirements that can be satisfied by two + * or more version sets belonging to different packages. + */ +inline Requirement requirement_union(VersionSetUnionId id) { + return cbindgen_private::resolvo_requirement_union(id); +} /** * Called to solve a package problem. @@ -12,7 +30,7 @@ namespace resolvo { * stored in `result`. If the solve was unsuccesfull an error describing the reason is returned and * the result vector will be empty. */ -inline String solve(DependencyProvider &provider, Slice requirements, +inline String solve(DependencyProvider &provider, Slice requirements, Slice constraints, Vector &result) { cbindgen_private::DependencyProvider bridge{ static_cast(&provider), @@ -24,6 +42,7 @@ inline String solve(DependencyProvider &provider, Slice requiremen private_api::bridge_display_string, private_api::bridge_version_set_name, private_api::bridge_solvable_name, + private_api::bridge_version_sets_in_union, private_api::bridge_get_candidates, private_api::bridge_sort_candidates, private_api::bridge_filter_candidates, diff --git a/cpp/include/resolvo_dependency_provider.h b/cpp/include/resolvo_dependency_provider.h index 8cd9258..9484079 100644 --- a/cpp/include/resolvo_dependency_provider.h +++ b/cpp/include/resolvo_dependency_provider.h @@ -13,6 +13,7 @@ using cbindgen_private::NameId; using cbindgen_private::SolvableId; using cbindgen_private::StringId; using cbindgen_private::VersionSetId; +using cbindgen_private::VersionSetUnionId; /** * An interface that implements ecosystem specific logic. @@ -75,6 +76,11 @@ struct DependencyProvider { */ virtual NameId solvable_name(SolvableId solvable_id) = 0; + /** + * Returns the version sets comprising the given union. + */ + virtual Slice version_sets_in_union(VersionSetUnionId version_set_union_id) = 0; + /** * Obtains a list of solvables that should be considered when a package * with the given name is requested. @@ -133,6 +139,15 @@ extern "C" inline NameId bridge_solvable_name(void *data, SolvableId solvable_id return reinterpret_cast(data)->solvable_name(solvable_id); } +// HACK(clang): For some reason, clang needs this to know that the return type is complete +static_assert(sizeof(Slice)); + +extern "C" inline Slice bridge_version_sets_in_union( + void *data, VersionSetUnionId version_set_union_id) { + return reinterpret_cast(data)->version_sets_in_union( + version_set_union_id); +} + extern "C" inline void bridge_get_candidates(void *data, NameId package, Candidates *result) { *result = reinterpret_cast(data)->get_candidates(package); } diff --git a/cpp/src/lib.rs b/cpp/src/lib.rs index f8bd5c6..4aa9998 100644 --- a/cpp/src/lib.rs +++ b/cpp/src/lib.rs @@ -31,6 +31,66 @@ impl From for resolvo::SolvableId { } } +/// Specifies the dependency of a solvable on a set of version sets. +/// cbindgen:derive-eq +/// cbindgen:derive-neq +#[repr(C)] +#[derive(Copy, Clone)] +pub enum Requirement { + /// Specifies a dependency on a single version set. + /// cbindgen:derive-eq + /// cbindgen:derive-neq + Single(VersionSetId), + /// Specifies a dependency on the union (logical OR) of multiple version sets. A solvable + /// belonging to ANY of the version sets contained in the union satisfies the requirement. + /// This variant is typically used for requirements that can be satisfied by two or more + /// version sets belonging to different packages. + /// cbindgen:derive-eq + /// cbindgen:derive-neq + Union(VersionSetUnionId), +} + +impl From for crate::Requirement { + fn from(value: resolvo::Requirement) -> Self { + match value { + resolvo::Requirement::Single(id) => Requirement::Single(id.into()), + resolvo::Requirement::Union(id) => Requirement::Union(id.into()), + } + } +} + +impl From for resolvo::Requirement { + fn from(value: crate::Requirement) -> Self { + match value { + Requirement::Single(id) => resolvo::Requirement::Single(id.into()), + Requirement::Union(id) => resolvo::Requirement::Union(id.into()), + } + } +} + +/// A unique identifier for a version set union. A version set union describes +/// the union (logical OR) of a non-empty set of version sets belonging to +/// more than one package. +/// cbindgen:derive-eq +/// cbindgen:derive-neq +#[repr(C)] +#[derive(Copy, Clone)] +pub struct VersionSetUnionId { + id: u32, +} + +impl From for crate::VersionSetUnionId { + fn from(id: resolvo::VersionSetUnionId) -> Self { + Self { id: id.0 } + } +} + +impl From for resolvo::VersionSetUnionId { + fn from(id: crate::VersionSetUnionId) -> Self { + Self(id.id) + } +} + /// A unique identifier for a single version set. A version set describes a /// set of versions. /// cbindgen:derive-eq @@ -102,7 +162,7 @@ pub struct Dependencies { /// A pointer to the first element of a list of requirements. Requirements /// defines which packages should be installed alongside the depending /// package and the constraints applied to the package. - pub requirements: Vector, + pub requirements: Vector, /// Defines additional constraints on packages that may or may not be part /// of the solution. Different from `requirements`, packages in this set @@ -230,6 +290,12 @@ pub struct DependencyProvider { /// Returns the name of the package for the given solvable. pub solvable_name: unsafe extern "C" fn(data: *mut c_void, solvable_id: SolvableId) -> NameId, + /// Returns the version sets comprising the given union. + pub version_sets_in_union: unsafe extern "C" fn( + data: *mut c_void, + version_set_union_id: VersionSetUnionId, + ) -> Slice<'static, VersionSetId>, + /// Obtains a list of solvables that should be considered when a package /// with the given name is requested. pub get_candidates: @@ -314,6 +380,17 @@ impl<'d> resolvo::Interner for &'d DependencyProvider { fn solvable_name(&self, solvable: resolvo::SolvableId) -> resolvo::NameId { unsafe { (self.solvable_name)(self.data, solvable.into()) }.into() } + + fn version_sets_in_union( + &self, + version_set_union: resolvo::VersionSetUnionId, + ) -> impl Iterator { + unsafe { (self.version_sets_in_union)(self.data, version_set_union.into()) } + .as_slice() + .into_iter() + .copied() + .map(Into::into) + } } impl<'d> resolvo::DependencyProvider for &'d DependencyProvider { @@ -400,7 +477,7 @@ impl<'d> resolvo::DependencyProvider for &'d DependencyProvider { #[allow(unused)] pub extern "C" fn resolvo_solve( provider: &DependencyProvider, - requirements: Slice, + requirements: Slice, constraints: Slice, error: &mut String, result: &mut Vector, @@ -433,6 +510,20 @@ pub extern "C" fn resolvo_solve( } } +#[no_mangle] +#[allow(unused)] +pub extern "C" fn resolvo_requirement_single(version_set_id: VersionSetId) -> Requirement { + Requirement::Single(version_set_id) +} + +#[no_mangle] +#[allow(unused)] +pub extern "C" fn resolvo_requirement_union( + version_set_union_id: VersionSetUnionId, +) -> Requirement { + Requirement::Union(version_set_union_id) +} + #[cfg(test)] mod tests { use super::*; diff --git a/cpp/tests/solve.cpp b/cpp/tests/solve.cpp index ffb8c66..13a05fa 100644 --- a/cpp/tests/solve.cpp +++ b/cpp/tests/solve.cpp @@ -16,9 +16,9 @@ struct Candidate { }; /** - * A requirement for a package. + * A version set for a package. */ -struct Requirement { +struct VersionSet { resolvo::NameId name; uint32_t version_start; uint32_t version_end; @@ -31,19 +31,47 @@ struct PackageDatabase : public resolvo::DependencyProvider { resolvo::Pool names; resolvo::Pool strings; std::vector candidates; - std::vector requirements; + std::vector version_sets; + std::vector> version_set_unions; /** - * Allocates a new requirement and return the id of the requirement. + * Allocates a new version set and return the id of the version set. */ - resolvo::VersionSetId alloc_requirement(std::string_view package, uint32_t version_start, + resolvo::VersionSetId alloc_version_set(std::string_view package, uint32_t version_start, uint32_t version_end) { auto name_id = names.alloc(std::move(package)); - auto id = resolvo::VersionSetId{static_cast(requirements.size())}; - requirements.push_back(Requirement{name_id, version_start, version_end}); + auto id = resolvo::VersionSetId{static_cast(version_sets.size())}; + version_sets.push_back(VersionSet{name_id, version_start, version_end}); return id; } + /** + * Allocates a new requirement for a single version set. + */ + resolvo::Requirement alloc_requirement(std::string_view package, uint32_t version_start, + uint32_t version_end) { + auto id = alloc_version_set(package, version_start, version_end); + return resolvo::requirement_single(id); + } + + /** + * Allocates a new requirement for a version set union. + */ + resolvo::Requirement alloc_requirement_union( + std::initializer_list> version_sets) { + std::vector version_set_union{version_sets.size()}; + + auto version_sets_it = version_sets.begin(); + for (size_t i = 0; i < version_sets.size(); ++i, ++version_sets_it) { + auto [package, version_start, version_end] = *version_sets_it; + version_set_union[i] = alloc_version_set(package, version_start, version_end); + } + + auto id = resolvo::VersionSetUnionId{static_cast(version_set_unions.size())}; + version_set_unions.push_back(std::move(version_set_union)); + return resolvo::requirement_union(id); + } + /** * Allocates a new candidate and return the id of the candidate. */ @@ -90,7 +118,7 @@ struct PackageDatabase : public resolvo::DependencyProvider { } resolvo::String display_version_set(resolvo::VersionSetId version_set) override { - const auto& req = requirements[version_set.id]; + const auto& req = version_sets[version_set.id]; std::stringstream ss; ss << req.version_start << ".." << req.version_end; return resolvo::String(ss.str()); @@ -101,13 +129,19 @@ struct PackageDatabase : public resolvo::DependencyProvider { } resolvo::NameId version_set_name(resolvo::VersionSetId version_set_id) override { - return requirements[version_set_id.id].name; + return version_sets[version_set_id.id].name; } resolvo::NameId solvable_name(resolvo::SolvableId solvable_id) override { return candidates[solvable_id.id].name; } + resolvo::Slice version_sets_in_union( + resolvo::VersionSetUnionId version_set_union_id) override { + const auto& version_set_ids = version_set_unions[version_set_union_id.id]; + return {version_set_ids.data(), version_set_ids.size()}; + } + resolvo::Candidates get_candidates(resolvo::NameId package) override { resolvo::Candidates result; @@ -137,11 +171,11 @@ struct PackageDatabase : public resolvo::DependencyProvider { resolvo::Slice solvables, resolvo::VersionSetId version_set_id, bool inverse) override { resolvo::Vector result; - const auto& requirement = requirements[version_set_id.id]; + const auto& version_set = version_sets[version_set_id.id]; for (auto solvable : solvables) { const auto& candidate = candidates[solvable.id]; - bool matches = candidate.version >= requirement.version_start && - candidate.version < requirement.version_end; + bool matches = candidate.version >= version_set.version_start && + candidate.version < version_set.version_end; if (matches != inverse) { result.push_back(solvable); } @@ -183,9 +217,9 @@ SCENARIO("Solve") { auto c_1 = db.alloc_candidate("c", 1, {}); // Construct a problem to be solved by the solver - resolvo::Vector requirements = {db.alloc_requirement("a", 1, 3)}; - resolvo::Vector constraints = {db.alloc_requirement("b", 1, 3), - db.alloc_requirement("c", 1, 3)}; + resolvo::Vector requirements = {db.alloc_requirement("a", 1, 3)}; + resolvo::Vector constraints = {db.alloc_version_set("b", 1, 3), + db.alloc_version_set("c", 1, 3)}; // Solve the problem resolvo::Vector result; @@ -196,3 +230,46 @@ SCENARIO("Solve") { REQUIRE(result[0] == a_2); REQUIRE(result[1] == b_2); } + +SCENARIO("Solve Union") { + /// Construct a database with packages a, b, and c. + PackageDatabase db; + + // Check that PackageDatabase correctly implements the DependencyProvider interface + static_assert(std::has_virtual_destructor_v); + static_assert(std::is_polymorphic_v); + static_assert(std::is_base_of_v); + + auto a_1 = db.alloc_candidate("a", 1, {}); + + auto b_1 = db.alloc_candidate("b", 1, {}); + + auto c_1 = db.alloc_candidate("c", 1, {{db.alloc_requirement("a", 1, 10)}, {}}); + + auto d_1 = db.alloc_candidate("d", 1, {{db.alloc_requirement("b", 1, 10)}, {}}); + + auto e_1 = db.alloc_candidate("e", 1, + {{db.alloc_requirement_union({{"a", 1, 10}, {"b", 1, 10}})}, {}}); + + auto f_1 = db.alloc_candidate( + "f", 1, {{db.alloc_requirement("b", 1, 10)}, {db.alloc_version_set("a", 10, 20)}}); + + // Construct a problem to be solved by the solver + resolvo::Vector requirements = { + db.alloc_requirement_union({{"c", 1, 10}, {"d", 1, 10}}), + db.alloc_requirement("e", 1, 10), + db.alloc_requirement("f", 1, 10), + }; + resolvo::Vector constraints = {}; + + // Solve the problem + resolvo::Vector result; + resolvo::solve(db, requirements, constraints, result); + + // Check the result + REQUIRE(result.size() == 4); + REQUIRE(result[0] == f_1); + REQUIRE(result[1] == e_1); + REQUIRE(result[2] == b_1); + REQUIRE(result[3] == d_1); +} diff --git a/src/internal/id.rs b/src/internal/id.rs index 1ebc041..7f5a44a 100644 --- a/src/internal/id.rs +++ b/src/internal/id.rs @@ -53,6 +53,23 @@ impl ArenaId for VersionSetId { } } +/// The id associated with a union (logical OR) of two or more version sets. +#[repr(transparent)] +#[derive(Clone, Default, Copy, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "serde", serde(transparent))] +pub struct VersionSetUnionId(pub u32); + +impl ArenaId for VersionSetUnionId { + fn from_usize(x: usize) -> Self { + Self(x as u32) + } + + fn to_usize(self) -> usize { + self.0 as usize + } +} + /// The id associated to a solvable #[repr(transparent)] #[derive(Copy, Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)] diff --git a/src/lib.rs b/src/lib.rs index 40bac7d..5ca2ba1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -12,6 +12,7 @@ pub(crate) mod internal; pub mod problem; +mod requirement; pub mod runtime; pub mod snapshot; mod solver; @@ -23,10 +24,11 @@ use std::{ }; pub use internal::{ - id::{NameId, SolvableId, StringId, VersionSetId}, + id::{NameId, SolvableId, StringId, VersionSetId, VersionSetUnionId}, mapping::Mapping, }; use itertools::Itertools; +pub use requirement::Requirement; pub use solver::{Solver, SolverCache, UnsolvableOrCancelled}; /// An object that is used by the solver to query certain properties of @@ -87,6 +89,12 @@ pub trait Interner { /// Returns the name of the package for the given solvable. fn solvable_name(&self, solvable: SolvableId) -> NameId; + + /// Returns the version sets comprising the given union. + fn version_sets_in_union( + &self, + version_set_union: VersionSetUnionId, + ) -> impl Iterator; } /// Defines implementation specific behavior for the solver and a way for the @@ -195,7 +203,7 @@ pub struct KnownDependencies { feature = "serde", serde(default, skip_serializing_if = "Vec::is_empty") )] - pub requirements: Vec, + pub requirements: Vec, /// Defines additional constraints on packages that may or may not be part /// of the solution. Different from `requirements`, packages in this set diff --git a/src/problem.rs b/src/problem.rs index a084453..3033244 100644 --- a/src/problem.rs +++ b/src/problem.rs @@ -15,7 +15,7 @@ use crate::{ internal::id::{ClauseId, InternalSolvableId, SolvableId, StringId, VersionSetId}, runtime::AsyncRuntime, solver::{clause::Clause, Solver}, - DependencyProvider, Interner, + DependencyProvider, Interner, Requirement, }; /// Represents the cause of the solver being unable to find a solution @@ -198,21 +198,21 @@ impl ProblemNode { #[derive(Copy, Clone, Hash, Eq, PartialEq, Ord, PartialOrd)] pub(crate) enum ProblemEdge { /// The target node is a candidate for the dependency specified by the - /// version set - Requires(VersionSetId), + /// [`Requirement`] + Requires(Requirement), /// The target node is involved in a conflict, caused by `ConflictCause` Conflict(ConflictCause), } impl ProblemEdge { - fn try_requires(self) -> Option { + fn try_requires(self) -> Option { match self { ProblemEdge::Requires(match_spec_id) => Some(match_spec_id), ProblemEdge::Conflict(_) => None, } } - fn requires(self) -> VersionSetId { + fn requires(self) -> Requirement { match self { ProblemEdge::Requires(match_spec_id) => match_spec_id, ProblemEdge::Conflict(_) => panic!("expected requires edge, found conflict"), @@ -301,8 +301,8 @@ impl ProblemGraph { }; let label = match edge.weight() { - ProblemEdge::Requires(version_set_id) - | ProblemEdge::Conflict(ConflictCause::Constrains(version_set_id)) => { + ProblemEdge::Requires(requirement) => requirement.display(interner).to_string(), + ProblemEdge::Conflict(ConflictCause::Constrains(version_set_id)) => { interner.display_version_set(*version_set_id).to_string() } ProblemEdge::Conflict(ConflictCause::ForbidMultipleInstances) @@ -651,7 +651,7 @@ impl<'i, I: Interner> DisplayUnsat<'i, I> { top_level_indent: bool, ) -> fmt::Result { pub enum DisplayOp { - Requirement(VersionSetId, Vec), + Requirement(Requirement, Vec), Candidate(NodeIndex), } @@ -693,7 +693,7 @@ impl<'i, I: Interner> DisplayUnsat<'i, I> { let indent = indenter.get_indent(); match node { - DisplayOp::Requirement(version_set_id, edges) => { + DisplayOp::Requirement(requirement, edges) => { debug_assert!(!edges.is_empty()); let installable = edges.iter().any(|&e| { @@ -701,12 +701,7 @@ impl<'i, I: Interner> DisplayUnsat<'i, I> { installable_nodes.contains(&target) }); - let req = self - .interner - .display_version_set(version_set_id) - .to_string(); - let name = self.interner.version_set_name(version_set_id); - let name = self.interner.display_name(name).to_string(); + let req = requirement.display(self.interner).to_string(); let target_nx = graph.edge_endpoints(edges[0]).unwrap().1; let missing = @@ -714,22 +709,19 @@ impl<'i, I: Interner> DisplayUnsat<'i, I> { if missing { // No candidates for requirement if top_level { - writeln!(f, "{indent}No candidates were found for {name} {req}.")?; + writeln!(f, "{indent}No candidates were found for {req}.")?; } else { - writeln!( - f, - "{indent}{name} {req}, for which no candidates were found.", - )?; + writeln!(f, "{indent}{req}, for which no candidates were found.",)?; } } else if installable { // Package can be installed (only mentioned for top-level requirements) if top_level { writeln!( f, - "{indent}{name} {req} can be installed with any of the following options:" + "{indent}{req} can be installed with any of the following options:" )?; } else { - writeln!(f, "{indent}{name} {req}, which can be installed with any of the following options:")?; + writeln!(f, "{indent}{req}, which can be installed with any of the following options:")?; } let children: Vec<_> = edges @@ -780,9 +772,9 @@ impl<'i, I: Interner> DisplayUnsat<'i, I> { // Package cannot be installed (the conflicting requirement is further down // the tree) if top_level { - writeln!(f, "{indent}{name} {req} cannot be installed because there are no viable options:")?; + writeln!(f, "{indent}{req} cannot be installed because there are no viable options:")?; } else { - writeln!(f, "{indent}{name} {req}, which cannot be installed because there are no viable options:")?; + writeln!(f, "{indent}{req}, which cannot be installed because there are no viable options:")?; } let children: Vec<_> = edges diff --git a/src/requirement.rs b/src/requirement.rs new file mode 100644 index 0000000..10b9e46 --- /dev/null +++ b/src/requirement.rs @@ -0,0 +1,91 @@ +use crate::{Interner, VersionSetId, VersionSetUnionId}; +use itertools::Itertools; +use std::fmt::Display; + +/// Specifies the dependency of a solvable on a set of version sets. +#[derive(Clone, Copy, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +pub enum Requirement { + /// Specifies a dependency on a single version set. + Single(VersionSetId), + /// Specifies a dependency on the union (logical OR) of multiple version sets. A solvable + /// belonging to _any_ of the version sets contained in the union satisfies the requirement. + /// This variant is typically used for requirements that can be satisfied by two or more + /// version sets belonging to _different_ packages. + Union(VersionSetUnionId), +} + +impl Default for Requirement { + fn default() -> Self { + Self::Single(Default::default()) + } +} + +impl From for Requirement { + fn from(value: VersionSetId) -> Self { + Requirement::Single(value) + } +} + +impl From for Requirement { + fn from(value: VersionSetUnionId) -> Self { + Requirement::Union(value) + } +} + +impl Requirement { + pub(crate) fn display<'i>(&'i self, interner: &'i impl Interner) -> impl Display + '_ { + DisplayRequirement { + interner, + requirement: self, + } + } + + pub(crate) fn version_sets<'i>( + &'i self, + interner: &'i impl Interner, + ) -> impl Iterator + 'i { + match *self { + Requirement::Single(version_set) => { + itertools::Either::Left(std::iter::once(version_set)) + } + Requirement::Union(version_set_union) => { + itertools::Either::Right(interner.version_sets_in_union(version_set_union)) + } + } + } +} + +pub(crate) struct DisplayRequirement<'i, I: Interner> { + interner: &'i I, + requirement: &'i Requirement, +} + +impl<'i, I: Interner> Display for DisplayRequirement<'i, I> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match *self.requirement { + Requirement::Single(version_set) => write!( + f, + "{} {}", + self.interner + .display_name(self.interner.version_set_name(version_set)), + self.interner.display_version_set(version_set) + ), + Requirement::Union(version_set_union) => { + let formatted_version_sets = self + .interner + .version_sets_in_union(version_set_union) + .format_with(" | ", |version_set, f| { + f(&format_args!( + "{} {}", + self.interner + .display_name(self.interner.version_set_name(version_set)), + self.interner.display_version_set(version_set) + )) + }); + + write!(f, "{}", formatted_version_sets) + } + } + } +} diff --git a/src/snapshot.rs b/src/snapshot.rs index 279759d..08b3fab 100644 --- a/src/snapshot.rs +++ b/src/snapshot.rs @@ -16,7 +16,7 @@ use futures::FutureExt; use crate::{ internal::arena::ArenaId, Candidates, Dependencies, DependencyProvider, Interner, Mapping, - NameId, SolvableId, SolverCache, StringId, VersionSetId, + NameId, Requirement, SolvableId, SolverCache, StringId, VersionSetId, VersionSetUnionId, }; /// A single solvable in a [`DependencySnapshot`]. @@ -84,12 +84,19 @@ pub struct DependencySnapshot { )] pub solvables: Mapping, - /// All the requirements in the snapshot + /// All the version set unions in the snapshot #[cfg_attr( feature = "serde", serde(default, skip_serializing_if = "Mapping::is_empty") )] - pub requirements: Mapping, + pub version_set_unions: Mapping>, + + /// All the version sets in the snapshot + #[cfg_attr( + feature = "serde", + serde(default, skip_serializing_if = "Mapping::is_empty") + )] + pub version_sets: Mapping, /// All the packages in the snapshot #[cfg_attr( @@ -151,7 +158,8 @@ impl DependencySnapshot { let mut result = Self { solvables: Mapping::new(), - requirements: Mapping::new(), + version_set_unions: Mapping::new(), + version_sets: Mapping::new(), packages: Mapping::new(), strings: Mapping::new(), }; @@ -206,11 +214,37 @@ impl DependencySnapshot { } } Dependencies::Known(deps) => { - for &dep in deps.requirements.iter().chain(deps.constrains.iter()) { + for &dep in deps.constrains.iter() { if seen.insert(Element::VersionSet(dep)) { queue.push_back(Element::VersionSet(dep)); } } + + for &requirement in deps.requirements.iter() { + match requirement { + Requirement::Single(version_set) => { + if seen.insert(Element::VersionSet(version_set)) { + queue.push_back(Element::VersionSet(version_set)); + } + } + Requirement::Union(version_set_union_id) => { + let version_sets: HashSet<_> = cache + .provider() + .version_sets_in_union(version_set_union_id) + .collect(); + + for &version_set in version_sets.iter() { + if seen.insert(Element::VersionSet(version_set)) { + queue.push_back(Element::VersionSet(version_set)); + } + } + + result + .version_set_unions + .insert(version_set_union_id, version_sets); + } + } + } } } @@ -255,7 +289,7 @@ impl DependencySnapshot { matching_candidates: matching_candidates.iter().copied().collect(), }; - result.requirements.insert(version_set_id, version_set); + result.version_sets.insert(version_set_id, version_set); } } } @@ -323,7 +357,7 @@ impl<'s> SnapshotProvider<'s> { /// Adds another requirement that matches any version of a package pub fn add_package_requirement(&mut self, name: NameId) -> VersionSetId { - let id = self.snapshot.requirements.max() + self.additional_version_sets.len(); + let id = self.snapshot.version_sets.max() + self.additional_version_sets.len(); let package = self.package(name); @@ -360,12 +394,12 @@ impl<'s> SnapshotProvider<'s> { fn version_set(&self, version_set: VersionSetId) -> &VersionSet { let idx = version_set.to_usize(); - let max_idx = self.snapshot.requirements.max(); + let max_idx = self.snapshot.version_sets.max(); if idx >= max_idx { &self.additional_version_sets[idx - max_idx] } else { self.snapshot - .requirements + .version_sets .get(version_set) .expect("missing version set") } @@ -396,6 +430,18 @@ impl<'s> Interner for SnapshotProvider<'s> { fn solvable_name(&self, solvable: SolvableId) -> NameId { self.solvable(solvable).name } + + fn version_sets_in_union( + &self, + version_set_union_id: VersionSetUnionId, + ) -> impl Iterator { + self.snapshot + .version_set_unions + .get(version_set_union_id) + .expect("missing constraint") + .iter() + .copied() + } } impl<'s> DependencyProvider for SnapshotProvider<'s> { diff --git a/src/solver/cache.rs b/src/solver/cache.rs index cd4cc9b..e66a950 100644 --- a/src/solver/cache.rs +++ b/src/solver/cache.rs @@ -11,7 +11,7 @@ use crate::{ frozen_copy_map::FrozenCopyMap, id::{CandidatesId, DependenciesId}, }, - Candidates, Dependencies, DependencyProvider, NameId, SolvableId, VersionSetId, + Candidates, Dependencies, DependencyProvider, NameId, Requirement, SolvableId, VersionSetId, }; /// Keeps a cache of previously computed and/or requested information about @@ -32,10 +32,10 @@ pub struct SolverCache { /// included). version_set_inverse_candidates: FrozenMap, ahash::RandomState>, - /// A mapping of `VersionSetId` to a sorted list of candidates that match - /// that set. - pub(crate) version_set_to_sorted_candidates: - FrozenMap, ahash::RandomState>, + /// A mapping of [`Requirement`] to a sorted list of candidates that fulfill + /// that requirement. + pub(crate) requirement_to_sorted_candidates: + FrozenMap, ahash::RandomState>, /// A mapping from a solvable to a list of dependencies solvable_dependencies: Arena, @@ -59,7 +59,7 @@ impl SolverCache { package_name_to_candidates_in_flight: Default::default(), version_set_candidates: Default::default(), version_set_inverse_candidates: Default::default(), - version_set_to_sorted_candidates: Default::default(), + requirement_to_sorted_candidates: Default::default(), solvable_dependencies: Default::default(), solvable_to_dependencies: Default::default(), hint_dependencies_available: Default::default(), @@ -242,50 +242,89 @@ impl SolverCache { } } - /// Returns the candidates for the package with the given name similar to - /// [`Self::get_or_cache_candidates`] sorted from highest to lowest. + /// Returns the candidates fulfilling the [`Requirement`] sorted from highest to lowest + /// within each version set comprising the [`Requirement`]. /// /// If the provider has requested the solving process to be cancelled, the /// cancellation value will be returned as an `Err(...)`. pub async fn get_or_cache_sorted_candidates( &self, - version_set_id: VersionSetId, + requirement: Requirement, ) -> Result<&[SolvableId], Box> { - match self.version_set_to_sorted_candidates.get(&version_set_id) { - Some(candidates) => Ok(candidates), - None => { - let package_name_id = self.provider.version_set_name(version_set_id); - tracing::trace!( - "Getting sorted matching candidates for package: {:?}", - self.provider.display_name(package_name_id).to_string() - ); - - let matching_candidates = self - .get_or_cache_matching_candidates(version_set_id) - .await?; - let candidates = self.get_or_cache_candidates(package_name_id).await?; - - // Sort all the candidates in order in which they should be tried by the solver. - let mut sorted_candidates = Vec::new(); - sorted_candidates.extend_from_slice(matching_candidates); - self.provider - .sort_candidates(self, &mut sorted_candidates) - .await; - - // If we have a solvable that we favor, we sort that to the front. This ensures - // that the version that is favored is picked first. - if let Some(favored_id) = candidates.favored { - if let Some(pos) = sorted_candidates.iter().position(|&s| s == favored_id) { - // Move the element at `pos` to the front of the array - sorted_candidates[0..=pos].rotate_right(1); + match requirement { + Requirement::Single(version_set_id) => { + self.get_or_cache_sorted_candidates_for_version_set(version_set_id) + .await + } + Requirement::Union(version_set_union_id) => { + match self.requirement_to_sorted_candidates.get(&requirement) { + Some(candidates) => Ok(candidates), + None => { + let sorted_candidates = futures::future::try_join_all( + self.provider() + .version_sets_in_union(version_set_union_id) + .map(|version_set_id| { + self.get_or_cache_sorted_candidates_for_version_set( + version_set_id, + ) + }), + ) + .await? + .into_iter() + .flatten() + .copied() + .collect(); + + Ok(self + .requirement_to_sorted_candidates + .insert(requirement, sorted_candidates)) } } + } + } + } - Ok(self - .version_set_to_sorted_candidates - .insert(version_set_id, sorted_candidates)) + /// Returns the sorted candidates for a singular version set requirement + /// (akin to a [`Requirement::Single`]). + async fn get_or_cache_sorted_candidates_for_version_set( + &self, + version_set_id: VersionSetId, + ) -> Result<&[SolvableId], Box> { + let requirement = version_set_id.into(); + if let Some(candidates) = self.requirement_to_sorted_candidates.get(&requirement) { + return Ok(candidates); + } + + let package_name_id = self.provider.version_set_name(version_set_id); + tracing::trace!( + "Getting sorted matching candidates for package: {:?}", + self.provider.display_name(package_name_id).to_string() + ); + + let matching_candidates = self + .get_or_cache_matching_candidates(version_set_id) + .await?; + let candidates = self.get_or_cache_candidates(package_name_id).await?; + + // Sort all the candidates in order in which they should be tried by the solver. + let mut sorted_candidates = Vec::with_capacity(matching_candidates.len()); + sorted_candidates.extend_from_slice(matching_candidates); + self.provider + .sort_candidates(self, &mut sorted_candidates) + .await; + + // If we have a solvable that we favor, we sort that to the front. This ensures + // that the version that is favored is picked first. + if let Some(favored_id) = candidates.favored { + if let Some(pos) = sorted_candidates.iter().position(|&s| s == favored_id) { + // Move the element at `pos` to the front of the array + sorted_candidates[0..=pos].rotate_right(1); } } + + Ok(self + .requirement_to_sorted_candidates + .insert(requirement, sorted_candidates)) } /// Returns the dependencies of a solvable. Requests the solvables from the diff --git a/src/solver/clause.rs b/src/solver/clause.rs index 8ea3b99..1476c33 100644 --- a/src/solver/clause.rs +++ b/src/solver/clause.rs @@ -11,7 +11,7 @@ use crate::{ id::{ClauseId, InternalSolvableId, LearntClauseId, StringId, VersionSetId}, }, solver::{decision_map::DecisionMap, decision_tracker::DecisionTracker}, - Interner, NameId, SolvableId, + Interner, NameId, Requirement, SolvableId, }; /// Represents a single clause in the SAT problem @@ -47,11 +47,11 @@ pub(crate) enum Clause { /// /// In SAT terms: (root) InstallRoot, - /// The solvable requires the candidates associated with the version set + /// Makes the solvable require the candidates associated with the [`Requirement`]. /// /// In SAT terms: (¬A ∨ B1 ∨ B2 ∨ ... ∨ B99), where B1 to B99 represent the - /// possible candidates for the provided version set - Requires(InternalSolvableId, VersionSetId), + /// possible candidates for the provided [`Requirement`]. + Requires(InternalSolvableId, Requirement), /// Ensures only a single version of a package is installed /// /// Usage: generate one [`Clause::ForbidMultipleInstances`] clause for each @@ -108,7 +108,7 @@ impl Clause { /// added dynamically. fn requires( parent: InternalSolvableId, - requirement: VersionSetId, + requirement: Requirement, candidates: &[SolvableId], decision_tracker: &DecisionTracker, ) -> (Self, Option<[InternalSolvableId; 2]>, bool) { @@ -228,8 +228,8 @@ impl Clause { pub fn visit_literals( &self, learnt_clauses: &Arena>, - version_set_to_sorted_candidates: &FrozenMap< - VersionSetId, + requirements_to_sorted_candidates: &FrozenMap< + Requirement, Vec, ahash::RandomState, >, @@ -254,7 +254,7 @@ impl Clause { negate: true, }); - for &solvable_id in &version_set_to_sorted_candidates[&match_spec_id] { + for &solvable_id in &requirements_to_sorted_candidates[&match_spec_id] { visit(Literal { solvable_id: solvable_id.into(), negate: false, @@ -320,7 +320,7 @@ impl ClauseState { /// conflict. pub fn requires( candidate: InternalSolvableId, - requirement: VersionSetId, + requirement: Requirement, matching_candidates: &[SolvableId], decision_tracker: &DecisionTracker, ) -> (Self, bool) { @@ -519,8 +519,8 @@ impl ClauseState { pub fn next_unwatched_variable( &self, learnt_clauses: &Arena>, - version_set_to_sorted_candidates: &FrozenMap< - VersionSetId, + requirement_to_sorted_candidates: &FrozenMap< + Requirement, Vec, ahash::RandomState, >, @@ -543,7 +543,7 @@ impl ClauseState { .find(|&l| can_watch(l)) .map(|l| l.solvable_id), Clause::Constrains(..) | Clause::ForbidMultipleInstances(..) | Clause::Lock(..) => None, - Clause::Requires(solvable_id, version_set_id) => { + Clause::Requires(solvable_id, requirement) => { // The solvable that added this clause let solvable_lit = Literal { solvable_id, @@ -554,7 +554,7 @@ impl ClauseState { } // The available candidates - for &candidate in &version_set_to_sorted_candidates[&version_set_id] { + for &candidate in &requirement_to_sorted_candidates[&requirement] { let lit = Literal { solvable_id: candidate.into(), negate: false, @@ -622,13 +622,13 @@ impl<'i, I: Interner> Display for ClauseDisplay<'i, I> { ) } Clause::Learnt(learnt_id) => write!(f, "Learnt({learnt_id:?})"), - Clause::Requires(solvable_id, version_set_id) => { + Clause::Requires(solvable_id, requirement) => { write!( f, "Requires({}({:?}), {})", solvable_id.display(self.interner), solvable_id, - self.interner.display_version_set(version_set_id) + requirement.display(self.interner), ) } Clause::Constrains(s1, s2, version_set_id) => { @@ -850,7 +850,7 @@ mod test { // No conflict, all candidates available let (clause, conflict) = ClauseState::requires( parent, - VersionSetId::from_usize(0), + VersionSetId::from_usize(0).into(), &[candidate1, candidate2], &decisions, ); @@ -864,7 +864,7 @@ mod test { .unwrap(); let (clause, conflict) = ClauseState::requires( parent, - VersionSetId::from_usize(0), + VersionSetId::from_usize(0).into(), &[candidate1, candidate2], &decisions, ); @@ -878,7 +878,7 @@ mod test { .unwrap(); let (clause, conflict) = ClauseState::requires( parent, - VersionSetId::from_usize(0), + VersionSetId::from_usize(0).into(), &[candidate1, candidate2], &decisions, ); @@ -893,7 +893,7 @@ mod test { let panicked = std::panic::catch_unwind(|| { ClauseState::requires( parent, - VersionSetId::from_usize(0), + VersionSetId::from_usize(0).into(), &[candidate1, candidate2], &decisions, ) diff --git a/src/solver/mod.rs b/src/solver/mod.rs index 61fe85a..45205cc 100644 --- a/src/solver/mod.rs +++ b/src/solver/mod.rs @@ -3,7 +3,7 @@ use clause::{Clause, ClauseState, Literal}; use decision::Decision; use decision_tracker::DecisionTracker; use futures::{stream::FuturesUnordered, FutureExt, StreamExt}; -use itertools::{chain, Itertools}; +use itertools::Itertools; use std::fmt::Display; use std::{any::Any, cell::RefCell, collections::HashSet, future::ready, ops::ControlFlow}; use watch_map::WatchMap; @@ -16,7 +16,7 @@ use crate::{ }, problem::Problem, runtime::{AsyncRuntime, NowOrNeverRuntime}, - Candidates, Dependencies, DependencyProvider, KnownDependencies, VersionSetId, + Candidates, Dependencies, DependencyProvider, KnownDependencies, Requirement, VersionSetId, }; mod cache; @@ -28,7 +28,7 @@ mod watch_map; #[derive(Default)] struct AddClauseOutput { - new_requires_clauses: Vec<(InternalSolvableId, VersionSetId, ClauseId)>, + new_requires_clauses: Vec<(InternalSolvableId, Requirement, ClauseId)>, conflicting_clauses: Vec, negative_assertions: Vec<(InternalSolvableId, ClauseId)>, clauses_to_watch: Vec, @@ -40,7 +40,7 @@ pub struct Solver { pub(crate) cache: SolverCache, pub(crate) clauses: RefCell>, - requires_clauses: Vec<(InternalSolvableId, VersionSetId, ClauseId)>, + requires_clauses: Vec<(InternalSolvableId, Requirement, ClauseId)>, watches: WatchMap, negative_assertions: Vec<(InternalSolvableId, ClauseId)>, @@ -54,8 +54,8 @@ pub struct Solver { decision_tracker: DecisionTracker, - /// The version sets that must be installed as part of the solution. - root_requirements: Vec, + /// The [`Requirement`]s that must be installed as part of the solution. + root_requirements: Vec, /// Additional constraints imposed by the root. root_constraints: Vec, @@ -164,7 +164,7 @@ impl Solver { /// inspect the causes and report them to the user. pub fn solve( &mut self, - root_requirements: Vec, + root_requirements: Vec, root_constraints: Vec, ) -> Result, UnsolvableOrCancelled> { // Clear state @@ -231,7 +231,7 @@ impl Solver { }, SortedCandidates { solvable_id: InternalSolvableId, - version_set_id: VersionSetId, + requirement: Requirement, candidates: &'i [SolvableId], }, NonMatchingCandidates { @@ -335,7 +335,11 @@ impl Solver { } }; - for version_set_id in chain(requirements.iter(), constrains.iter()).copied() { + for version_set_id in requirements + .iter() + .flat_map(|requirement| requirement.version_sets(self.provider())) + .chain(constrains.iter().copied()) + { let dependency_name = self.provider().version_set_name(version_set_id); if clauses_added_for_package.insert(dependency_name) { tracing::trace!( @@ -357,17 +361,17 @@ impl Solver { } } - for version_set_id in requirements { + for requirement in requirements { // Find all the solvable that match for the given version set pending_futures.push( async move { let candidates = self .cache - .get_or_cache_sorted_candidates(version_set_id) + .get_or_cache_sorted_candidates(requirement) .await?; Ok(TaskResult::SortedCandidates { solvable_id, - version_set_id, + requirement, candidates, }) } @@ -467,14 +471,12 @@ impl Solver { } TaskResult::SortedCandidates { solvable_id, - version_set_id, + requirement, candidates, } => { tracing::trace!( - "Sorted candidates available for {} {}", - self.provider() - .display_name(self.provider().version_set_name(version_set_id)), - self.provider().display_version_set(version_set_id), + "Sorted candidates available for {}", + requirement.display(self.provider()), ); // Queue requesting the dependencies of the candidates as well if they are @@ -492,7 +494,7 @@ impl Solver { let no_candidates = candidates.is_empty(); let (clause, conflict) = ClauseState::requires( solvable_id, - version_set_id, + requirement, candidates, &self.decision_tracker, ); @@ -500,7 +502,7 @@ impl Solver { let clause_id = self.clauses.borrow_mut().alloc(clause); let clause = &self.clauses.borrow()[clause_id]; - let &Clause::Requires(solvable_id, version_set_id) = &clause.kind else { + let &Clause::Requires(solvable_id, requirement) = &clause.kind else { unreachable!(); }; @@ -510,7 +512,7 @@ impl Solver { output .new_requires_clauses - .push((solvable_id, version_set_id, clause_id)); + .push((solvable_id, requirement, clause_id)); if conflict { output.conflicting_clauses.push(clause_id); @@ -787,7 +789,7 @@ impl Solver { } // Consider only clauses in which no candidates have been installed - let candidates = &self.cache.version_set_to_sorted_candidates[&deps]; + let candidates = &self.cache.requirement_to_sorted_candidates[&deps]; // Either find the first assignable candidate or determine that one of the // candidates is already assigned in which case the clause has @@ -1091,7 +1093,7 @@ impl Solver { // One of the watched literals is now false if let Some(variable) = clause.next_unwatched_variable( &self.learnt_clauses, - &self.cache.version_set_to_sorted_candidates, + &self.cache.requirement_to_sorted_candidates, self.decision_tracker.map(), ) { debug_assert!(!clause.watched_literals.contains(&variable)); @@ -1202,7 +1204,7 @@ impl Solver { let mut involved = HashSet::new(); self.clauses.borrow()[clause_id].kind.visit_literals( &self.learnt_clauses, - &self.cache.version_set_to_sorted_candidates, + &self.cache.requirement_to_sorted_candidates, |literal| { involved.insert(literal.solvable_id); }, @@ -1240,7 +1242,7 @@ impl Solver { self.clauses.borrow()[why].kind.visit_literals( &self.learnt_clauses, - &self.cache.version_set_to_sorted_candidates, + &self.cache.requirement_to_sorted_candidates, |literal| { if literal.eval(self.decision_tracker.map()) == Some(true) { assert_eq!(literal.solvable_id, decision.solvable_id); @@ -1286,7 +1288,7 @@ impl Solver { self.clauses.borrow()[clause_id].kind.visit_literals( &self.learnt_clauses, - &self.cache.version_set_to_sorted_candidates, + &self.cache.requirement_to_sorted_candidates, |literal| { if !first_iteration && literal.solvable_id == conflicting_solvable { // We are only interested in the causes of the conflict, so we ignore the diff --git a/src/solver/watch_map.rs b/src/solver/watch_map.rs index cf17e17..51ee685 100644 --- a/src/solver/watch_map.rs +++ b/src/solver/watch_map.rs @@ -52,10 +52,7 @@ impl WatchMap { clause.watched_literals[watch_index] = new_watch; clause.link_to_clause( watch_index, - *self - .map - .get(new_watch) - .expect("linking to unknown solvable"), + self.map.get(new_watch).copied().unwrap_or(ClauseId::null()), ); self.map.insert(new_watch, clause_id); } diff --git a/src/utils/pool.rs b/src/utils/pool.rs index 7358fcf..5fffddd 100644 --- a/src/utils/pool.rs +++ b/src/utils/pool.rs @@ -6,7 +6,8 @@ use std::{ use crate::internal::{ arena::Arena, frozen_copy_map::FrozenCopyMap, - id::{NameId, SolvableId, StringId, VersionSetId}, + id::{NameId, SolvableId, StringId, VersionSetId, VersionSetUnionId}, + small_vec::SmallVec, }; /// A solvable represents a single candidate of a package. @@ -47,6 +48,8 @@ pub struct Pool { /// Map from version set to the id of their interned counterpart version_set_to_id: FrozenCopyMap<(NameId, VS), VersionSetId, ahash::RandomState>, + + version_set_unions: Arena>, } impl Default for Pool { @@ -61,6 +64,7 @@ impl Default for Pool { string_to_ids: Default::default(), version_set_to_id: Default::default(), version_sets: Arena::new(), + version_set_unions: Arena::new(), } } } @@ -182,6 +186,32 @@ impl Pool { pub fn resolve_version_set_package_name(&self, id: VersionSetId) -> NameId { self.version_sets[id].0 } + + /// Interns a union of two or more version sets and returns its [`VersionSetUnionId`]. + /// + /// Version set unions are *not* deduplicated, and a unique id is returned on every + /// invocation. + pub fn intern_version_set_union( + &self, + first: VersionSetId, + others: impl Iterator, + ) -> VersionSetUnionId { + self.version_set_unions + .alloc(others.fold(SmallVec::one(first), |mut vec, version_set| { + vec.push(version_set); + vec + })) + } + + /// Returns the version sets in the version set union with the given id. + /// + /// Panics if there is no union with the given id. + pub fn resolve_version_set_union( + &self, + id: VersionSetUnionId, + ) -> impl Iterator + '_ { + self.version_set_unions[id].iter().copied() + } } /// A helper struct to visualize a name. diff --git a/tests/snapshots/solver__snapshot.snap b/tests/snapshots/solver__snapshot.snap index 224e308..2edcdb6 100644 --- a/tests/snapshots/solver__snapshot.snap +++ b/tests/snapshots/solver__snapshot.snap @@ -1,6 +1,5 @@ --- source: tests/solver.rs -assertion_line: 1121 expression: "solve_for_snapshot(snapshot_provider, &[menu_req])" --- dropdown=2 diff --git a/tests/snapshots/solver__snapshot_union_requirements.snap b/tests/snapshots/solver__snapshot_union_requirements.snap new file mode 100644 index 0000000..16b5780 --- /dev/null +++ b/tests/snapshots/solver__snapshot_union_requirements.snap @@ -0,0 +1,6 @@ +--- +source: tests/solver.rs +expression: "solve_for_snapshot(snapshot_provider, &[intl_req, union_req])" +--- +intl=5 +union=1 diff --git a/tests/solver.rs b/tests/solver.rs index 0064255..f5708d0 100644 --- a/tests/solver.rs +++ b/tests/solver.rs @@ -1,5 +1,6 @@ use std::{ any::Any, + borrow::Borrow, cell::{Cell, RefCell}, collections::HashSet, fmt::{Debug, Display, Formatter}, @@ -21,8 +22,9 @@ use itertools::Itertools; use resolvo::{ snapshot::{DependencySnapshot, SnapshotProvider}, utils::{Pool, Range}, - Candidates, Dependencies, DependencyProvider, Interner, KnownDependencies, NameId, SolvableId, - Solver, SolverCache, StringId, UnsolvableOrCancelled, VersionSetId, + Candidates, Dependencies, DependencyProvider, Interner, KnownDependencies, NameId, Requirement, + SolvableId, Solver, SolverCache, StringId, UnsolvableOrCancelled, VersionSetId, + VersionSetUnionId, }; use tracing_test::traced_test; @@ -116,6 +118,14 @@ impl Spec { pub fn new(name: String, versions: Range) -> Self { Self { name, versions } } + + pub fn parse_union( + spec: &str, + ) -> impl Iterator::Err>> + '_ { + spec.split('|') + .map(str::trim) + .map(|dep| Spec::from_str(dep)) + } } impl FromStr for Spec { @@ -173,7 +183,7 @@ struct BundleBoxProvider { #[derive(Debug, Clone)] struct BundleBoxPackageDependencies { - dependencies: Vec, + dependencies: Vec>, constrains: Vec, } @@ -188,18 +198,42 @@ impl BundleBoxProvider { .expect("package missing") } - pub fn requirements(&self, requirements: &[&str]) -> Vec { + pub fn requirements>(&self, requirements: &[&str]) -> Vec { requirements .iter() .map(|dep| Spec::from_str(dep).unwrap()) - .map(|spec| { - let dep_name = self.pool.intern_package_name(&spec.name); - self.pool - .intern_version_set(dep_name, spec.versions.clone()) + .map(|spec| self.intern_version_set(&spec)) + .map(From::from) + .collect() + } + + pub fn parse_requirements(&self, requirements: &[&str]) -> Vec { + requirements + .iter() + .map(|deps| { + let specs = Spec::parse_union(deps).map(Result::unwrap); + self.intern_version_set_union(specs).into() }) .collect() } + pub fn intern_version_set(&self, spec: &Spec) -> VersionSetId { + let dep_name = self.pool.intern_package_name(&spec.name); + self.pool + .intern_version_set(dep_name, spec.versions.clone()) + } + + pub fn intern_version_set_union( + &self, + specs: impl IntoIterator>, + ) -> VersionSetUnionId { + let mut specs = specs + .into_iter() + .map(|spec| self.intern_version_set(spec.borrow())); + self.pool + .intern_version_set_union(specs.next().unwrap(), specs) + } + pub fn from_packages(packages: &[(&str, u32, Vec<&str>)]) -> Self { let mut result = Self::new(); for (name, version, deps) in packages { @@ -236,7 +270,7 @@ impl BundleBoxProvider { let dependencies = dependencies .iter() - .map(|dep| Spec::from_str(dep)) + .map(|dep| Spec::parse_union(dep).collect()) .collect::, _>>() .unwrap(); @@ -319,6 +353,12 @@ impl Interner for BundleBoxProvider { fn solvable_name(&self, solvable: SolvableId) -> NameId { self.pool.resolve_solvable(solvable).name } + fn version_sets_in_union( + &self, + version_set_union: VersionSetUnionId, + ) -> impl Iterator { + self.pool.resolve_version_set_union(version_set_union) + } } impl DependencyProvider for BundleBoxProvider { @@ -436,9 +476,33 @@ impl DependencyProvider for BundleBoxProvider { constrains: Vec::with_capacity(deps.constrains.len()), }; for req in &deps.dependencies { - let dep_name = self.pool.intern_package_name(&req.name); - let dep_spec = self.pool.intern_version_set(dep_name, req.versions.clone()); - result.requirements.push(dep_spec); + let mut remaining_req_specs = req.iter(); + + let first = remaining_req_specs + .next() + .expect("Dependency spec must have at least one constraint"); + + let first_name = self.pool.intern_package_name(&first.name); + let first_version_set = self + .pool + .intern_version_set(first_name, first.versions.clone()); + + let requirement = if remaining_req_specs.len() == 0 { + first_version_set.into() + } else { + let other_version_sets = remaining_req_specs.map(|spec| { + self.pool.intern_version_set( + self.pool.intern_package_name(&spec.name), + spec.versions.clone(), + ) + }); + + self.pool + .intern_version_set_union(first_version_set, other_version_sets) + .into() + }; + + result.requirements.push(requirement); } for req in &deps.constrains { @@ -509,7 +573,7 @@ fn solve_snapshot(mut provider: BundleBoxProvider, specs: &[&str]) -> String { provider.sleep_before_return = true; - let requirements = provider.requirements(specs); + let requirements = provider.parse_requirements(specs); let mut solver = Solver::new(provider).with_runtime(runtime); match solver.solve(requirements, Vec::new()) { Ok(solvables) => transaction_to_string(solver.provider(), &solvables), @@ -833,6 +897,28 @@ fn test_resolve_cyclic() { "###); } +#[test] +fn test_resolve_union_requirements() { + let mut provider = BundleBoxProvider::from_packages(&[ + ("a", 1, vec![]), + ("b", 1, vec![]), + ("c", 1, vec!["a"]), + ("d", 1, vec!["b"]), + ("e", 1, vec!["a | b"]), + ]); + + // Make d conflict with a=1 + provider.add_package("f", 1.into(), &["b"], &["a 2"]); + + let result = solve_snapshot(provider, &["c | d", "e", "f"]); + assert_snapshot!(result, @r###" + b=1 + d=1 + e=1 + f=1 + "###); +} + #[test] fn test_unsat_locked_and_excluded() { let mut provider = BundleBoxProvider::from_packages(&[ @@ -1121,6 +1207,32 @@ fn test_snapshot() { assert_snapshot!(solve_for_snapshot(snapshot_provider, &[menu_req])); } +#[test] +fn test_snapshot_union_requirements() { + let provider = BundleBoxProvider::from_packages(&[ + ("icons", 2, vec![]), + ("icons", 1, vec![]), + ("intl", 5, vec![]), + ("intl", 3, vec![]), + ("union", 1, vec!["icons 2 | intl"]), + ]); + + let intl_name_id = provider.package_name("intl"); + let union_name_id = provider.package_name("union"); + + let snapshot = provider.into_snapshot(); + + let mut snapshot_provider = snapshot.provider(); + + let intl_req = snapshot_provider.add_package_requirement(intl_name_id); + let union_req = snapshot_provider.add_package_requirement(union_name_id); + + assert_snapshot!(solve_for_snapshot( + snapshot_provider, + &[intl_req, union_req] + )); +} + #[cfg(feature = "serde")] fn serialize_snapshot(snapshot: &DependencySnapshot, destination: impl AsRef) { let file = std::io::BufWriter::new(std::fs::File::create(destination.as_ref()).unwrap()); @@ -1129,7 +1241,10 @@ fn serialize_snapshot(snapshot: &DependencySnapshot, destination: impl AsRef String { let mut solver = Solver::new(provider); - match solver.solve(root_reqs.to_vec(), Vec::new()) { + match solver.solve( + root_reqs.iter().copied().map(Into::into).collect(), + Vec::new(), + ) { Ok(solvables) => transaction_to_string(solver.provider(), &solvables), Err(UnsolvableOrCancelled::Unsolvable(problem)) => { // Write the problem graphviz to stderr diff --git a/tools/solve-snapshot/src/main.rs b/tools/solve-snapshot/src/main.rs index 0ae50d2..a58e1e4 100644 --- a/tools/solve-snapshot/src/main.rs +++ b/tools/solve-snapshot/src/main.rs @@ -50,7 +50,7 @@ fn main() { let mut solver = Solver::new(provider); let mut records = None; let mut error = None; - match solver.solve(vec![package_requirement], vec![]) { + match solver.solve(vec![package_requirement.into()], vec![]) { Ok(solution) => { eprintln!("OK"); records = Some(solution.len()) From 0c77ba3d4c44b908f609b0d34e3a9a44c80907ad Mon Sep 17 00:00:00 2001 From: Arav Singhal Date: Mon, 5 Aug 2024 23:50:18 -0500 Subject: [PATCH 5/7] feat(solver)!: Solve for optional solvables in addition to the root solvable (#54) * fix(solver): Call `clear` if level = 0 in `DecisionTracker::undo_until` * fix(solver): Remove unneeded assert when processing package candidates The requirement of ensuring that no decisions are made for any candidate before the clauses for a package are added is unnecessary, since we now perform propagation every time after adding any new clauses. * fix(solver): Make `Solver::run_sat` solve for any solvable at any level * feat(solver): Solve for optional solvables in addition to the root * test(solver): Add tests for `Solver::solve_with_additional` * fix(solver): Remove unused `DecisionTracker::is_empty` * feat(solver)!: Rename `Problem`s to `Conflict`s * feat(solver)!: Add `Problem` struct to describe solver input --- cpp/include/resolvo.h | 7 +- cpp/src/lib.rs | 45 ++++-- cpp/tests/solve.cpp | 20 ++- src/{problem.rs => conflict.rs} | 141 +++++++++--------- src/lib.rs | 4 +- src/solver/decision_tracker.rs | 9 +- src/solver/mod.rs | 248 ++++++++++++++++++++------------ tests/solver.rs | 230 ++++++++++++++++++++++++----- 8 files changed, 486 insertions(+), 218 deletions(-) rename src/{problem.rs => conflict.rs} (88%) diff --git a/cpp/include/resolvo.h b/cpp/include/resolvo.h index 48844d6..97d00f5 100644 --- a/cpp/include/resolvo.h +++ b/cpp/include/resolvo.h @@ -4,6 +4,7 @@ #include "resolvo_internal.h" namespace resolvo { +using cbindgen_private::Problem; using cbindgen_private::Requirement; /** @@ -30,8 +31,8 @@ inline Requirement requirement_union(VersionSetUnionId id) { * stored in `result`. If the solve was unsuccesfull an error describing the reason is returned and * the result vector will be empty. */ -inline String solve(DependencyProvider &provider, Slice requirements, - Slice constraints, Vector &result) { +inline String solve(DependencyProvider &provider, const Problem &problem, + Vector &result) { cbindgen_private::DependencyProvider bridge{ static_cast(&provider), private_api::bridge_display_solvable, @@ -50,7 +51,7 @@ inline String solve(DependencyProvider &provider, Slice requirement }; String error; - cbindgen_private::resolvo_solve(&bridge, requirements, constraints, &error, &result); + cbindgen_private::resolvo_solve(&bridge, &problem, &error, &result); return error; } } // namespace resolvo diff --git a/cpp/src/lib.rs b/cpp/src/lib.rs index 4aa9998..a0ab2ac 100644 --- a/cpp/src/lib.rs +++ b/cpp/src/lib.rs @@ -473,28 +473,45 @@ impl<'d> resolvo::DependencyProvider for &'d DependencyProvider { } } +#[repr(C)] +pub struct Problem<'a> { + pub requirements: Slice<'a, Requirement>, + pub constraints: Slice<'a, VersionSetId>, + pub soft_requirements: Slice<'a, SolvableId>, +} + #[no_mangle] #[allow(unused)] pub extern "C" fn resolvo_solve( provider: &DependencyProvider, - requirements: Slice, - constraints: Slice, + problem: &Problem, error: &mut String, result: &mut Vector, ) -> bool { - let requirements = requirements - .into_iter() - .copied() - .map(Into::into) - .collect::>(); - let constraints = constraints - .into_iter() - .copied() - .map(Into::into) - .collect::>(); - let mut solver = resolvo::Solver::new(provider); - match solver.solve(requirements, constraints) { + + let problem = resolvo::Problem { + requirements: problem + .requirements + .into_iter() + .copied() + .map(Into::into) + .collect(), + constraints: problem + .constraints + .into_iter() + .copied() + .map(Into::into) + .collect(), + soft_requirements: problem + .soft_requirements + .into_iter() + .copied() + .map(Into::into) + .collect(), + }; + + match solver.solve(problem) { Ok(solution) => { *result = solution.into_iter().map(Into::into).collect(); true diff --git a/cpp/tests/solve.cpp b/cpp/tests/solve.cpp index 13a05fa..1bb02b7 100644 --- a/cpp/tests/solve.cpp +++ b/cpp/tests/solve.cpp @@ -216,19 +216,27 @@ SCENARIO("Solve") { auto c_1 = db.alloc_candidate("c", 1, {}); + const auto d_1 = db.alloc_candidate("d", 1, {}); + // Construct a problem to be solved by the solver resolvo::Vector requirements = {db.alloc_requirement("a", 1, 3)}; - resolvo::Vector constraints = {db.alloc_version_set("b", 1, 3), - db.alloc_version_set("c", 1, 3)}; + resolvo::Vector constraints = { + db.alloc_version_set("b", 1, 3), + db.alloc_version_set("c", 1, 3), + db.alloc_version_set("d", 2, 2), + }; + resolvo::Vector soft_requirements{c_1, d_1}; // Solve the problem resolvo::Vector result; - resolvo::solve(db, requirements, constraints, result); + resolvo::Problem problem = {requirements, constraints, soft_requirements}; + resolvo::solve(db, problem, result); // Check the result - REQUIRE(result.size() == 2); + REQUIRE(result.size() == 3); REQUIRE(result[0] == a_2); REQUIRE(result[1] == b_2); + REQUIRE(result[2] == c_1); } SCENARIO("Solve Union") { @@ -264,7 +272,9 @@ SCENARIO("Solve Union") { // Solve the problem resolvo::Vector result; - resolvo::solve(db, requirements, constraints, result); + resolvo::Problem problem = {requirements, constraints, {}}; + resolvo::solve(db, problem, result); + ; // Check the result REQUIRE(result.size() == 4); diff --git a/src/problem.rs b/src/conflict.rs similarity index 88% rename from src/problem.rs rename to src/conflict.rs index 3033244..5996dbe 100644 --- a/src/problem.rs +++ b/src/conflict.rs @@ -20,12 +20,12 @@ use crate::{ /// Represents the cause of the solver being unable to find a solution #[derive(Debug)] -pub struct Problem { +pub struct Conflict { /// The clauses involved in an unsatisfiable conflict clauses: Vec, } -impl Problem { +impl Conflict { pub(crate) fn default() -> Self { Self { clauses: Vec::new(), @@ -38,18 +38,18 @@ impl Problem { } } - /// Generates a graph representation of the problem (see [`ProblemGraph`] + /// Generates a graph representation of the conflict (see [`ConflictGraph`] /// for details) pub fn graph( &self, solver: &Solver, - ) -> ProblemGraph { - let mut graph = DiGraph::::default(); + ) -> ConflictGraph { + let mut graph = DiGraph::::default(); let mut nodes: HashMap = HashMap::default(); let mut excluded_nodes: HashMap = HashMap::default(); let root_node = Self::add_node(&mut graph, &mut nodes, InternalSolvableId::root()); - let unresolved_node = graph.add_node(ProblemNode::UnresolvedDependency); + let unresolved_node = graph.add_node(ConflictNode::UnresolvedDependency); for clause_id in &self.clauses { let clause = &solver.clauses.borrow()[*clause_id].kind; @@ -60,12 +60,12 @@ impl Problem { let package_node = Self::add_node(&mut graph, &mut nodes, *solvable); let excluded_node = excluded_nodes .entry(*reason) - .or_insert_with(|| graph.add_node(ProblemNode::Excluded(*reason))); + .or_insert_with(|| graph.add_node(ConflictNode::Excluded(*reason))); graph.add_edge( package_node, *excluded_node, - ProblemEdge::Conflict(ConflictCause::Excluded), + ConflictEdge::Conflict(ConflictCause::Excluded), ); } Clause::Learnt(..) => unreachable!(), @@ -82,7 +82,7 @@ impl Problem { graph.add_edge( package_node, unresolved_node, - ProblemEdge::Requires(version_set_id), + ConflictEdge::Requires(version_set_id), ); } else { for &candidate_id in candidates { @@ -93,7 +93,7 @@ impl Problem { graph.add_edge( package_node, candidate_node, - ProblemEdge::Requires(version_set_id), + ConflictEdge::Requires(version_set_id), ); } } @@ -101,14 +101,14 @@ impl Problem { &Clause::Lock(locked, forbidden) => { let node2_id = Self::add_node(&mut graph, &mut nodes, forbidden); let conflict = ConflictCause::Locked(locked); - graph.add_edge(root_node, node2_id, ProblemEdge::Conflict(conflict)); + graph.add_edge(root_node, node2_id, ConflictEdge::Conflict(conflict)); } &Clause::ForbidMultipleInstances(instance1_id, instance2_id, _) => { let node1_id = Self::add_node(&mut graph, &mut nodes, instance1_id); let node2_id = Self::add_node(&mut graph, &mut nodes, instance2_id); let conflict = ConflictCause::ForbidMultipleInstances; - graph.add_edge(node1_id, node2_id, ProblemEdge::Conflict(conflict)); + graph.add_edge(node1_id, node2_id, ConflictEdge::Conflict(conflict)); } &Clause::Constrains(package_id, dep_id, version_set_id) => { let package_node = Self::add_node(&mut graph, &mut nodes, package_id); @@ -117,7 +117,7 @@ impl Problem { graph.add_edge( package_node, dep_node, - ProblemEdge::Conflict(ConflictCause::Constrains(version_set_id)), + ConflictEdge::Conflict(ConflictCause::Constrains(version_set_id)), ); } } @@ -142,7 +142,7 @@ impl Problem { } assert_eq!(graph.node_count(), visited_nodes.len()); - ProblemGraph { + ConflictGraph { graph, root_node, unresolved_node, @@ -150,16 +150,16 @@ impl Problem { } fn add_node( - graph: &mut DiGraph, + graph: &mut DiGraph, nodes: &mut HashMap, solvable_id: InternalSolvableId, ) -> NodeIndex { *nodes .entry(solvable_id) - .or_insert_with(|| graph.add_node(ProblemNode::Solvable(solvable_id))) + .or_insert_with(|| graph.add_node(ConflictNode::Solvable(solvable_id))) } - /// Display a user-friendly error explaining the problem + /// Display a user-friendly error explaining the conflict pub fn display_user_friendly<'a, D: DependencyProvider, RT: AsyncRuntime>( &self, solver: &'a Solver, @@ -169,9 +169,9 @@ impl Problem { } } -/// A node in the graph representation of a [`Problem`] +/// A node in the graph representation of a [`Conflict`] #[derive(Copy, Clone, Eq, PartialEq)] -pub(crate) enum ProblemNode { +pub(crate) enum ConflictNode { /// Node corresponding to a solvable Solvable(InternalSolvableId), /// Node representing a dependency without candidates @@ -180,23 +180,23 @@ pub(crate) enum ProblemNode { Excluded(StringId), } -impl ProblemNode { +impl ConflictNode { fn solvable_id(self) -> InternalSolvableId { match self { - ProblemNode::Solvable(solvable_id) => solvable_id, - ProblemNode::UnresolvedDependency => { + ConflictNode::Solvable(solvable_id) => solvable_id, + ConflictNode::UnresolvedDependency => { panic!("expected solvable node, found unresolved dependency") } - ProblemNode::Excluded(_) => { + ConflictNode::Excluded(_) => { panic!("expected solvable node, found excluded node") } } } } -/// An edge in the graph representation of a [`Problem`] +/// An edge in the graph representation of a [`Conflict`] #[derive(Copy, Clone, Hash, Eq, PartialEq, Ord, PartialOrd)] -pub(crate) enum ProblemEdge { +pub(crate) enum ConflictEdge { /// The target node is a candidate for the dependency specified by the /// [`Requirement`] Requires(Requirement), @@ -204,18 +204,18 @@ pub(crate) enum ProblemEdge { Conflict(ConflictCause), } -impl ProblemEdge { +impl ConflictEdge { fn try_requires(self) -> Option { match self { - ProblemEdge::Requires(match_spec_id) => Some(match_spec_id), - ProblemEdge::Conflict(_) => None, + ConflictEdge::Requires(match_spec_id) => Some(match_spec_id), + ConflictEdge::Conflict(_) => None, } } fn requires(self) -> Requirement { match self { - ProblemEdge::Requires(match_spec_id) => match_spec_id, - ProblemEdge::Conflict(_) => panic!("expected requires edge, found conflict"), + ConflictEdge::Requires(match_spec_id) => match_spec_id, + ConflictEdge::Conflict(_) => panic!("expected requires edge, found conflict"), } } } @@ -241,22 +241,22 @@ pub(crate) enum ConflictCause { /// - They all have the same name /// - They all have the same predecessor nodes /// - They all have the same successor nodes -pub(crate) struct MergedProblemNode { +pub(crate) struct MergedConflictNode { pub ids: Vec, } -/// Graph representation of [`Problem`] +/// Graph representation of [`Conflict`] /// /// The root of the graph is the "root solvable". Note that not all the /// solvable's requirements are included in the graph, only those that are /// directly or indirectly involved in the conflict. -pub struct ProblemGraph { - graph: DiGraph, +pub struct ConflictGraph { + graph: DiGraph, root_node: NodeIndex, unresolved_node: Option, } -impl ProblemGraph { +impl ConflictGraph { /// Writes a graphviz graph that represents this instance to the specified /// output. pub fn graphviz( @@ -276,7 +276,7 @@ impl ProblemGraph { write!(f, "digraph {{")?; for nx in graph.node_indices() { let id = match graph.node_weight(nx).as_ref().unwrap() { - ProblemNode::Solvable(id) => *id, + ConflictNode::Solvable(id) => *id, _ => continue, }; @@ -294,26 +294,28 @@ impl ProblemGraph { let target = *graph.node_weight(edge.target()).unwrap(); let color = match edge.weight() { - ProblemEdge::Requires(_) if target != ProblemNode::UnresolvedDependency => { + ConflictEdge::Requires(_) if target != ConflictNode::UnresolvedDependency => { "black" } _ => "red", }; let label = match edge.weight() { - ProblemEdge::Requires(requirement) => requirement.display(interner).to_string(), - ProblemEdge::Conflict(ConflictCause::Constrains(version_set_id)) => { + ConflictEdge::Requires(requirement) => { + requirement.display(interner).to_string() + } + ConflictEdge::Conflict(ConflictCause::Constrains(version_set_id)) => { interner.display_version_set(*version_set_id).to_string() } - ProblemEdge::Conflict(ConflictCause::ForbidMultipleInstances) - | ProblemEdge::Conflict(ConflictCause::Locked(_)) => { + ConflictEdge::Conflict(ConflictCause::ForbidMultipleInstances) + | ConflictEdge::Conflict(ConflictCause::Locked(_)) => { "already installed".to_string() } - ProblemEdge::Conflict(ConflictCause::Excluded) => "excluded".to_string(), + ConflictEdge::Conflict(ConflictCause::Excluded) => "excluded".to_string(), }; let target = match target { - ProblemNode::Solvable(mut solvable_2) => { + ConflictNode::Solvable(mut solvable_2) => { // If the target node has been merged, replace it by the first id in the // group if let Some(solvable_id) = solvable_2.as_solvable() { @@ -329,8 +331,8 @@ impl ProblemGraph { solvable_2.display(interner).to_string() } - ProblemNode::UnresolvedDependency => "unresolved".to_string(), - ProblemNode::Excluded(reason) => { + ConflictNode::UnresolvedDependency => "unresolved".to_string(), + ConflictNode::Excluded(reason) => { format!("reason: {}", interner.display_string(reason)) } }; @@ -348,15 +350,15 @@ impl ProblemGraph { /// Simplifies and collapses nodes so that these can be considered the same /// candidate - fn simplify(&self, interner: &impl Interner) -> HashMap> { + fn simplify(&self, interner: &impl Interner) -> HashMap> { let graph = &self.graph; // Gather information about nodes that can be merged let mut maybe_merge = HashMap::default(); for node_id in graph.node_indices() { let candidate = match graph[node_id] { - ProblemNode::UnresolvedDependency | ProblemNode::Excluded(_) => continue, - ProblemNode::Solvable(solvable_id) => { + ConflictNode::UnresolvedDependency | ConflictNode::Excluded(_) => continue, + ConflictNode::Solvable(solvable_id) => { if solvable_id.is_root() { continue; } else { @@ -395,7 +397,7 @@ impl ProblemGraph { let mut merged_candidates = HashMap::default(); for m in maybe_merge.into_values() { if m.len() > 1 { - let m = Rc::new(MergedProblemNode { + let m = Rc::new(MergedConflictNode { ids: m.into_iter().map(|(_, snd)| snd).collect(), }); for &id in &m.ids { @@ -427,7 +429,7 @@ impl ProblemGraph { let excluding_edges = self .graph .edges_directed(nx, Direction::Incoming) - .any(|e| matches!(e.weight(), ProblemEdge::Conflict(ConflictCause::Excluded))); + .any(|e| matches!(e.weight(), ConflictEdge::Conflict(ConflictCause::Excluded))); if excluding_edges { // Nodes with incoming disabling edges aren't installable continue; @@ -436,7 +438,7 @@ impl ProblemGraph { let outgoing_conflicts = self .graph .edges_directed(nx, Direction::Outgoing) - .any(|e| matches!(e.weight(), ProblemEdge::Conflict(_))); + .any(|e| matches!(e.weight(), ConflictEdge::Conflict(_))); if outgoing_conflicts { // Nodes with outgoing conflicts aren't installable continue; @@ -447,8 +449,8 @@ impl ProblemGraph { .graph .edges_directed(nx, Direction::Outgoing) .map(|e| match e.weight() { - ProblemEdge::Requires(version_set_id) => (version_set_id, e.target()), - ProblemEdge::Conflict(_) => unreachable!(), + ConflictEdge::Requires(version_set_id) => (version_set_id, e.target()), + ConflictEdge::Conflict(_) => unreachable!(), }) .chunk_by(|(&version_set_id, _)| version_set_id); @@ -482,7 +484,7 @@ impl ProblemGraph { let outgoing_conflicts = self .graph .edges_directed(nx, Direction::Outgoing) - .any(|e| matches!(e.weight(), ProblemEdge::Conflict(_))); + .any(|e| matches!(e.weight(), ConflictEdge::Conflict(_))); if outgoing_conflicts { // Nodes with outgoing conflicts aren't missing continue; @@ -493,8 +495,8 @@ impl ProblemGraph { .graph .edges_directed(nx, Direction::Outgoing) .map(|e| match e.weight() { - ProblemEdge::Requires(version_set_id) => (version_set_id, e.target()), - ProblemEdge::Conflict(_) => unreachable!(), + ConflictEdge::Requires(version_set_id) => (version_set_id, e.target()), + ConflictEdge::Conflict(_) => unreachable!(), }) .chunk_by(|(&version_set_id, _)| version_set_id); @@ -620,17 +622,17 @@ mod tests { } /// A struct implementing [`fmt::Display`] that generates a user-friendly -/// representation of a problem graph +/// representation of a conflict graph pub struct DisplayUnsat<'i, I: Interner> { - graph: ProblemGraph, - merged_candidates: HashMap>, + graph: ConflictGraph, + merged_candidates: HashMap>, installable_set: HashSet, missing_set: HashSet, interner: &'i I, } impl<'i, I: Interner> DisplayUnsat<'i, I> { - pub(crate) fn new(graph: ProblemGraph, interner: &'i I) -> Self { + pub(crate) fn new(graph: ConflictGraph, interner: &'i I) -> Self { let merged_candidates = graph.simplify(interner); let installable_set = graph.get_installable_set(); let missing_set = graph.get_missing_set(); @@ -647,7 +649,7 @@ impl<'i, I: Interner> DisplayUnsat<'i, I> { fn fmt_graph( &self, f: &mut Formatter<'_>, - top_level_edges: &[EdgeReference<'_, ProblemEdge>], + top_level_edges: &[EdgeReference<'_, ConflictEdge>], top_level_indent: bool, ) -> fmt::Result { pub enum DisplayOp { @@ -705,7 +707,7 @@ impl<'i, I: Interner> DisplayUnsat<'i, I> { let target_nx = graph.edge_endpoints(edges[0]).unwrap().1; let missing = - edges.len() == 1 && graph[target_nx] == ProblemNode::UnresolvedDependency; + edges.len() == 1 && graph[target_nx] == ConflictNode::UnresolvedDependency; if missing { // No candidates for requirement if top_level { @@ -845,8 +847,8 @@ impl<'i, I: Interner> DisplayUnsat<'i, I> { let excluded = graph .edges_directed(candidate, Direction::Outgoing) .find_map(|e| match e.weight() { - ProblemEdge::Conflict(ConflictCause::Excluded) => { - let ProblemNode::Excluded(reason) = graph[e.target()] else { + ConflictEdge::Conflict(ConflictCause::Excluded) => { + let ConflictNode::Excluded(reason) = graph[e.target()] else { unreachable!(); }; Some(reason) @@ -854,12 +856,13 @@ impl<'i, I: Interner> DisplayUnsat<'i, I> { _ => None, }); let already_installed = graph.edges(candidate).any(|e| { - e.weight() == &ProblemEdge::Conflict(ConflictCause::ForbidMultipleInstances) + e.weight() + == &ConflictEdge::Conflict(ConflictCause::ForbidMultipleInstances) }); let constrains_conflict = graph.edges(candidate).any(|e| { matches!( e.weight(), - ProblemEdge::Conflict(ConflictCause::Constrains(_)) + ConflictEdge::Conflict(ConflictCause::Constrains(_)) ) }); let is_leaf = graph.edges(candidate).next().is_none(); @@ -881,7 +884,7 @@ impl<'i, I: Interner> DisplayUnsat<'i, I> { let mut version_sets = graph .edges(candidate) .flat_map(|e| match e.weight() { - ProblemEdge::Conflict(ConflictCause::Constrains( + ConflictEdge::Conflict(ConflictCause::Constrains( version_set_id, )) => Some(version_set_id), _ => None, @@ -972,8 +975,8 @@ impl<'i, I: Interner> fmt::Display for DisplayUnsat<'i, I> { let indent = indenter.get_indent(); let conflict = match e.weight() { - ProblemEdge::Requires(_) => continue, - ProblemEdge::Conflict(conflict) => conflict, + ConflictEdge::Requires(_) => continue, + ConflictEdge::Conflict(conflict) => conflict, }; // The only possible conflict at the root level is a Locked conflict diff --git a/src/lib.rs b/src/lib.rs index 5ca2ba1..ece1a5e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -10,8 +10,8 @@ #![deny(missing_docs)] +pub mod conflict; pub(crate) mod internal; -pub mod problem; mod requirement; pub mod runtime; pub mod snapshot; @@ -29,7 +29,7 @@ pub use internal::{ }; use itertools::Itertools; pub use requirement::Requirement; -pub use solver::{Solver, SolverCache, UnsolvableOrCancelled}; +pub use solver::{Problem, Solver, SolverCache, UnsolvableOrCancelled}; /// An object that is used by the solver to query certain properties of /// different internalized objects. diff --git a/src/solver/decision_tracker.rs b/src/solver/decision_tracker.rs index 080061a..abbfd12 100644 --- a/src/solver/decision_tracker.rs +++ b/src/solver/decision_tracker.rs @@ -27,10 +27,6 @@ impl DecisionTracker { self.propagate_index = 0; } - pub(crate) fn is_empty(&self) -> bool { - self.stack.is_empty() - } - pub(crate) fn assigned_value(&self, solvable_id: InternalSolvableId) -> Option { self.map.value(solvable_id) } @@ -77,6 +73,11 @@ impl DecisionTracker { } pub(crate) fn undo_until(&mut self, level: u32) { + if level == 0 { + self.clear(); + return; + } + while let Some(decision) = self.stack.last() { if self.level(decision.solvable_id) <= level { break; diff --git a/src/solver/mod.rs b/src/solver/mod.rs index 45205cc..66bed80 100644 --- a/src/solver/mod.rs +++ b/src/solver/mod.rs @@ -9,12 +9,12 @@ use std::{any::Any, cell::RefCell, collections::HashSet, future::ready, ops::Con use watch_map::WatchMap; use crate::{ + conflict::Conflict, internal::{ arena::Arena, id::{ClauseId, InternalSolvableId, LearntClauseId, NameId, SolvableId}, mapping::Mapping, }, - problem::Problem, runtime::{AsyncRuntime, NowOrNeverRuntime}, Candidates, Dependencies, DependencyProvider, KnownDependencies, Requirement, VersionSetId, }; @@ -34,7 +34,30 @@ struct AddClauseOutput { clauses_to_watch: Vec, } -/// Drives the SAT solving process +/// Describes the problem that is to be solved by the solver. +#[derive(Default)] +pub struct Problem { + /// The requirements that _must_ have one candidate solvable be included in the + /// solution. + pub requirements: Vec, + + /// Additional constraints imposed on individual packages that the solvable (if any) + /// chosen for that package _must_ adhere to. + pub constraints: Vec, + + /// A set of additional requirements that the solver should _try_ and fulfill once it has + /// found a solution to the main problem. + /// + /// An unsatisfiable soft requirement does not cause a conflict; the solver will try + /// and fulfill as many soft requirements as possible and skip the unsatisfiable ones. + /// + /// Soft requirements are currently only specified as individual solvables to be + /// included in the solution, however in the future they will be able to be specified + /// as version sets. + pub soft_requirements: Vec, +} + +/// Drives the SAT solving process. pub struct Solver { pub(crate) async_runtime: RT, pub(crate) cache: SolverCache, @@ -88,13 +111,13 @@ impl Solver { #[derive(Debug)] pub enum UnsolvableOrCancelled { /// The problem was unsolvable. - Unsolvable(Problem), + Unsolvable(Conflict), /// The solving process was cancelled. Cancelled(Box), } -impl From for UnsolvableOrCancelled { - fn from(value: Problem) -> Self { +impl From for UnsolvableOrCancelled { + fn from(value: Conflict) -> Self { UnsolvableOrCancelled::Unsolvable(value) } } @@ -155,57 +178,70 @@ impl Solver { } } - /// Solves for the provided `root_requirements` and `root_constraints`. The - /// `root_requirements` are package that will be included in the - /// solution. `root_constraints` are additional constrains which do not - /// necesarily need to be included in the solution. + /// Solves the given [`Problem`]. /// - /// Returns a [`Problem`] if no solution was found, which provides ways to - /// inspect the causes and report them to the user. - pub fn solve( - &mut self, - root_requirements: Vec, - root_constraints: Vec, - ) -> Result, UnsolvableOrCancelled> { - // Clear state + /// The solver first solves for the root requirements and constraints, and then + /// tries to include in the solution as many of the soft requirements as it can. + /// Each soft requirement is subject to all the clauses and decisions introduced + /// for all the previously decided solvables in the solution. + /// + /// Unless the corresponding package has been requested by a version set in another + /// solvable's clauses, each soft requirement is _not_ subject to the + /// package-level clauses introduced in [`DependencyProvider::get_candidates`] since the + /// solvables have been requested specifically (not through a version set) in the solution. + /// + /// # Returns + /// + /// If a solution was found, returns a [`Vec`] of the solvables included in the solution. + /// + /// If no solution to the _root_ requirements and constraints was found, returns a + /// [`Conflict`] wrapped in a [`UnsolvableOrCancelled::Unsolvable`], which provides ways to + /// inspect the causes and report them to the user. If a soft requirement is unsolvable, + /// it is simply not included in the solution. + /// + /// If the solution process is cancelled (see [`DependencyProvider::should_cancel_with_value`]), + /// returns an [`UnsolvableOrCancelled::Cancelled`] containing the cancellation value. + pub fn solve(&mut self, problem: Problem) -> Result, UnsolvableOrCancelled> { self.decision_tracker.clear(); self.negative_assertions.clear(); self.learnt_clauses.clear(); self.learnt_why = Mapping::new(); self.clauses = Default::default(); - self.root_requirements = root_requirements; - self.root_constraints = root_constraints; + self.root_requirements = problem.requirements; + self.root_constraints = problem.constraints; // The first clause will always be the install root clause. Here we verify that // this is indeed the case. let root_clause = self.clauses.borrow_mut().alloc(ClauseState::root()); assert_eq!(root_clause, ClauseId::install_root()); - // Run SAT - self.run_sat()?; + assert!( + self.run_sat(InternalSolvableId::root())?, + "bug: Since root is the first requested solvable, \ + should have returned Err instead of Ok(false) if root is unsolvable" + ); - let steps: Vec = self - .decision_tracker - .stack() - .filter_map(|d| { - if d.value { - d.solvable_id.as_solvable() - } else { - // Ignore things that are set to false - None - } - }) - .collect(); + for additional in problem.soft_requirements { + let additional = additional.into(); - tracing::trace!("Solvables found:"); - for step in &steps { - tracing::trace!( - " - {}", - InternalSolvableId::from(*step).display(self.provider()) - ); + if self.decision_tracker.assigned_value(additional).is_none() { + self.run_sat(additional)?; + } } - Ok(steps) + Ok(self.chosen_solvables().collect()) + } + + /// Returns the solvables that the solver has chosen to include in the solution so far. + fn chosen_solvables(&self) -> impl Iterator + '_ { + self.decision_tracker.stack().filter_map(|d| { + if d.value { + d.solvable_id.as_solvable() + } else { + // Ignore things that are set to false + None + } + }) } /// Adds clauses for a solvable. These clauses include requirements and @@ -410,15 +446,6 @@ impl Solver { let locked_solvable_id = package_candidates.locked; let candidates = &package_candidates.candidates; - // Check the assumption that no decision has been made about any of the - // solvables. - for &candidate in candidates { - debug_assert!( - self.decision_tracker.assigned_value(candidate.into()).is_none(), - "a decision has been made about a candidate of a package that was not properly added yet." - ); - } - // Each candidate gets a clause to disallow other candidates. for (i, &candidate) in candidates.iter().enumerate() { for &other_candidate in &candidates[i + 1..] { @@ -581,32 +608,49 @@ impl Solver { /// implementation of this step. /// /// The solver loop can be found in [`Solver::resolve_dependencies`]. - fn run_sat(&mut self) -> Result<(), UnsolvableOrCancelled> { - assert!(self.decision_tracker.is_empty()); - let mut level = 0; + /// + /// Returns `Ok(true)` if a solution was found for `solvable`. If a solution was not + /// found, returns `Ok(false)` if some decisions have already been made by the solver + /// (i.e. the decision tracker stack is not empty). Otherwise, returns + /// [`UnsolvableOrCancelled::Unsolvable`] as an `Err` on no solution. + /// + /// If the solution process is cancelled (see [`DependencyProvider::should_cancel_with_value`]), + /// returns [`UnsolvableOrCancelled::Cancelled`] as an `Err`. + fn run_sat(&mut self, solvable: InternalSolvableId) -> Result { + let starting_level = self + .decision_tracker + .stack() + .next_back() + .map(|decision| self.decision_tracker.level(decision.solvable_id)) + .unwrap_or(0); + + let mut level = starting_level; loop { - if level == 0 { - tracing::trace!("Level 0: Resetting the decision loop"); + if level == starting_level { + tracing::trace!("Level {starting_level}: Resetting the decision loop"); } else { tracing::trace!("Level {}: Starting the decision loop", level); } - // A level of 0 means the decision loop has been completely reset because a - // partial solution was invalidated by newly added clauses. - if level == 0 { - // Level 1 is the initial decision level - level = 1; + // A level of starting_level means the decision loop has been completely reset + // because a partial solution was invalidated by newly added clauses. + if level == starting_level { + // Level starting_level + 1 is the initial decision level + level = starting_level + 1; // Assign `true` to the root solvable. This must be installed to satisfy the // solution. The root solvable contains the dependencies that // were injected when calling `Solver::solve`. If we can find a // solution were the root is installable we found a // solution that satisfies the user requirements. - tracing::trace!("╤══ Install at level {level}",); + tracing::trace!( + "╤══ Install {} at level {level}", + solvable.display(self.provider()) + ); self.decision_tracker .try_add_decision( - Decision::new(InternalSolvableId::root(), true, ClauseId::install_root()), + Decision::new(solvable, true, ClauseId::install_root()), level, ) .expect("already decided"); @@ -614,12 +658,9 @@ impl Solver { // Add the clauses for the root solvable. let output = self .async_runtime - .block_on(self.add_clauses_for_solvables(vec![InternalSolvableId::root()]))?; + .block_on(self.add_clauses_for_solvables([solvable]))?; if let Err(clause_id) = self.process_add_clause_output(output) { - tracing::trace!("Unsolvable: {:?}", clause_id); - return Err(UnsolvableOrCancelled::Unsolvable( - self.analyze_unsolvable(clause_id), - )); + return self.run_sat_process_unsolvable(solvable, starting_level, clause_id); } } @@ -634,17 +675,19 @@ impl Solver { match propagate_result { Ok(()) => {} Err(PropagationError::Conflict(_, _, clause_id)) => { - if level == 1 { - return Err(UnsolvableOrCancelled::Unsolvable( - self.analyze_unsolvable(clause_id), - )); + if level == starting_level + 1 { + return self.run_sat_process_unsolvable( + solvable, + starting_level, + clause_id, + ); } else { // The conflict was caused because new clauses have been added dynamically. // We need to start over. tracing::debug!("├─ added clause {clause} introduces a conflict which invalidates the partial solution", clause=self.clauses.borrow()[clause_id].display(self.provider())); - level = 0; - self.decision_tracker.clear(); + level = starting_level; + self.decision_tracker.undo_until(starting_level); continue; } } @@ -687,7 +730,7 @@ impl Solver { "Level {}: No new solvables selected, solution is complete", level ); - return Ok(()); + return Ok(true); } tracing::debug!("==== Found newly selected solvables"); @@ -716,12 +759,41 @@ impl Solver { } if let Err(_first_conflicting_clause_id) = self.process_add_clause_output(output) { - self.decision_tracker.clear(); - level = 0; + self.decision_tracker.undo_until(starting_level); + level = starting_level; } } } + /// Decides how to terminate the solver algorithm when the given `solvable` was + /// deemed unsolvable by [`Solver::run_sat`]. + /// + /// Returns an `Err` value of [`UnsolvableOrCancelled::Unsolvable`] only if `solvable` is + /// the very first solvable we are solving for. Otherwise, undoes all the decisions made + /// when trying to solve for `solvable`, sets it to `false` and returns `Ok(false)`. + fn run_sat_process_unsolvable( + &mut self, + solvable: InternalSolvableId, + starting_level: u32, + clause_id: ClauseId, + ) -> Result { + if starting_level == 0 { + tracing::trace!("Unsolvable: {:?}", clause_id); + Err(UnsolvableOrCancelled::Unsolvable( + self.analyze_unsolvable(clause_id), + )) + } else { + self.decision_tracker.undo_until(starting_level); + self.decision_tracker + .try_add_decision( + Decision::new(solvable, false, ClauseId::install_root()), + starting_level + 1, + ) + .expect("bug: already decided - decision should have been undone"); + Ok(false) + } + } + fn process_add_clause_output(&mut self, mut output: AddClauseOutput) -> Result<(), ClauseId> { let mut clauses = self.clauses.borrow_mut(); for clause_id in output.clauses_to_watch { @@ -855,7 +927,7 @@ impl Solver { /// CDCL algorithm. /// /// Returns the new level after this set-propagate-learn round, or a - /// [`Problem`] if we discovered that the requested jobs are + /// [`Conflict`] if we discovered that the requested jobs are /// unsatisfiable. fn set_propagate_learn( &mut self, @@ -915,7 +987,7 @@ impl Solver { conflicting_solvable: InternalSolvableId, attempted_value: bool, conflicting_clause: ClauseId, - ) -> Result { + ) -> Result { { tracing::info!( "├─ Propagation conflicted: could not set {solvable} to {attempted_value}", @@ -1161,15 +1233,15 @@ impl Solver { Ok(()) } - /// Adds the clause with `clause_id` to the current `Problem` + /// Adds the clause with `clause_id` to the current [`Conflict`] /// /// Because learnt clauses are not relevant for the user, they are not added - /// to the `Problem`. Instead, we report the clauses that caused them. + /// to the [`Conflict`]. Instead, we report the clauses that caused them. fn analyze_unsolvable_clause( clauses: &Arena, learnt_why: &Mapping>, clause_id: ClauseId, - problem: &mut Problem, + conflict: &mut Conflict, seen: &mut HashSet, ) { let clause = &clauses[clause_id]; @@ -1183,21 +1255,21 @@ impl Solver { .get(learnt_clause_id) .expect("no cause for learnt clause available") { - Self::analyze_unsolvable_clause(clauses, learnt_why, cause, problem, seen); + Self::analyze_unsolvable_clause(clauses, learnt_why, cause, conflict, seen); } } - _ => problem.add_clause(clause_id), + _ => conflict.add_clause(clause_id), } } - /// Create a [`Problem`] based on the id of the clause that triggered an + /// Create a [`Conflict`] based on the id of the clause that triggered an /// unrecoverable conflict - fn analyze_unsolvable(&mut self, clause_id: ClauseId) -> Problem { + fn analyze_unsolvable(&mut self, clause_id: ClauseId) -> Conflict { let last_decision = self.decision_tracker.stack().last().unwrap(); let highest_level = self.decision_tracker.level(last_decision.solvable_id); debug_assert_eq!(highest_level, 1); - let mut problem = Problem::default(); + let mut conflict = Conflict::default(); tracing::info!("=== ANALYZE UNSOLVABLE"); @@ -1215,7 +1287,7 @@ impl Solver { &self.clauses.borrow(), &self.learnt_why, clause_id, - &mut problem, + &mut conflict, &mut seen, ); @@ -1236,7 +1308,7 @@ impl Solver { &self.clauses.borrow(), &self.learnt_why, why, - &mut problem, + &mut conflict, &mut seen, ); @@ -1253,7 +1325,7 @@ impl Solver { ); } - problem + conflict } /// Analyze the causes of the conflict and learn from it diff --git a/tests/solver.rs b/tests/solver.rs index f5708d0..6aa1cea 100644 --- a/tests/solver.rs +++ b/tests/solver.rs @@ -22,8 +22,8 @@ use itertools::Itertools; use resolvo::{ snapshot::{DependencySnapshot, SnapshotProvider}, utils::{Pool, Range}, - Candidates, Dependencies, DependencyProvider, Interner, KnownDependencies, NameId, Requirement, - SolvableId, Solver, SolverCache, StringId, UnsolvableOrCancelled, VersionSetId, + Candidates, Dependencies, DependencyProvider, Interner, KnownDependencies, NameId, Problem, + Requirement, SolvableId, Solver, SolverCache, StringId, UnsolvableOrCancelled, VersionSetId, VersionSetUnionId, }; use tracing_test::traced_test; @@ -179,6 +179,7 @@ struct BundleBoxProvider { // duplicate requests. requested_candidates: RefCell>, requested_dependencies: RefCell>, + interned_solvables: RefCell>, } #[derive(Debug, Clone)] @@ -313,6 +314,18 @@ impl BundleBoxProvider { .collect::>(); DependencySnapshot::from_provider(self, name_ids, [], []).unwrap() } + + pub fn intern_solvable(&self, name_id: NameId, pack: Pack) -> SolvableId { + *self + .interned_solvables + .borrow_mut() + .entry((name_id, pack)) + .or_insert_with_key(|&(name_id, pack)| self.pool.intern_solvable(name_id, pack)) + } + + pub fn solvable_id(&self, name: impl Into, version: impl Into) -> SolvableId { + self.intern_solvable(self.pool.intern_package_name(name.into()), version.into()) + } } impl Interner for BundleBoxProvider { @@ -411,7 +424,7 @@ impl DependencyProvider for BundleBoxProvider { let locked = self.locked.get(package_name); let excluded = self.excluded.get(package_name); for pack in package.keys() { - let solvable = self.pool.intern_solvable(name, *pack); + let solvable = self.intern_solvable(name, *pack); candidates.candidates.push(solvable); if Some(pack) == favor { candidates.favored = Some(solvable); @@ -539,15 +552,19 @@ fn transaction_to_string(interner: &impl Interner, solvables: &Vec) buf } -/// Unsat so that we can view the problem +/// Unsat so that we can view the conflict fn solve_unsat(provider: BundleBoxProvider, specs: &[&str]) -> String { let requirements = provider.requirements(specs); let mut solver = Solver::new(provider); - match solver.solve(requirements, Vec::new()) { + let problem = Problem { + requirements, + ..Default::default() + }; + match solver.solve(problem) { Ok(_) => panic!("expected unsat, but a solution was found"), - Err(UnsolvableOrCancelled::Unsolvable(problem)) => { - // Write the problem graphviz to stderr - let graph = problem.graph(&solver); + Err(UnsolvableOrCancelled::Unsolvable(conflict)) => { + // Write the conflict graphviz to stderr + let graph = conflict.graph(&solver); let mut output = stderr(); writeln!(output, "UNSOLVABLE:").unwrap(); graph @@ -556,7 +573,7 @@ fn solve_unsat(provider: BundleBoxProvider, specs: &[&str]) -> String { writeln!(output, "\n").unwrap(); // Format a user friendly error message - problem.display_user_friendly(&solver).to_string() + conflict.display_user_friendly(&solver).to_string() } Err(UnsolvableOrCancelled::Cancelled(reason)) => *reason.downcast().unwrap(), } @@ -575,11 +592,15 @@ fn solve_snapshot(mut provider: BundleBoxProvider, specs: &[&str]) -> String { let requirements = provider.parse_requirements(specs); let mut solver = Solver::new(provider).with_runtime(runtime); - match solver.solve(requirements, Vec::new()) { + let problem = Problem { + requirements, + ..Default::default() + }; + match solver.solve(problem) { Ok(solvables) => transaction_to_string(solver.provider(), &solvables), - Err(UnsolvableOrCancelled::Unsolvable(problem)) => { - // Write the problem graphviz to stderr - let graph = problem.graph(&solver); + Err(UnsolvableOrCancelled::Unsolvable(conflict)) => { + // Write the conflict graphviz to stderr + let graph = conflict.graph(&solver); let mut output = stderr(); writeln!(output, "UNSOLVABLE:").unwrap(); graph @@ -588,7 +609,7 @@ fn solve_snapshot(mut provider: BundleBoxProvider, specs: &[&str]) -> String { writeln!(output, "\n").unwrap(); // Format a user friendly error message - problem.display_user_friendly(&solver).to_string() + conflict.display_user_friendly(&solver).to_string() } Err(UnsolvableOrCancelled::Cancelled(reason)) => *reason.downcast().unwrap(), } @@ -598,9 +619,13 @@ fn solve_snapshot(mut provider: BundleBoxProvider, specs: &[&str]) -> String { #[test] fn test_unit_propagation_1() { let provider = BundleBoxProvider::from_packages(&[("asdf", 1, vec![])]); - let root_requirements = provider.requirements(&["asdf"]); + let requirements = provider.requirements(&["asdf"]); let mut solver = Solver::new(provider); - let solved = solver.solve(root_requirements, Vec::new()).unwrap(); + let problem = Problem { + requirements, + ..Default::default() + }; + let solved = solver.solve(problem).unwrap(); let pool = &solver.provider().pool; assert_eq!(solved.len(), 1); @@ -620,7 +645,11 @@ fn test_unit_propagation_nested() { ]); let requirements = provider.requirements(&["asdf"]); let mut solver = Solver::new(provider); - let solved = solver.solve(requirements, Vec::new()).unwrap(); + let problem = Problem { + requirements, + ..Default::default() + }; + let solved = solver.solve(problem).unwrap(); let pool = &solver.provider().pool; assert_eq!(solved.len(), 2); @@ -647,7 +676,11 @@ fn test_resolve_multiple() { ]); let requirements = provider.requirements(&["asdf", "efgh"]); let mut solver = Solver::new(provider); - let solved = solver.solve(requirements, Vec::new()).unwrap(); + let problem = Problem { + requirements, + ..Default::default() + }; + let solved = solver.solve(problem).unwrap(); let pool = &solver.provider().pool; assert_eq!(solved.len(), 2); @@ -705,7 +738,11 @@ fn test_resolve_with_nonexisting() { ]); let requirements = provider.requirements(&["asdf"]); let mut solver = Solver::new(provider); - let solved = solver.solve(requirements, Vec::new()).unwrap(); + let problem = Problem { + requirements, + ..Default::default() + }; + let solved = solver.solve(problem).unwrap(); let pool = &solver.provider().pool; assert_eq!(solved.len(), 1); @@ -739,7 +776,11 @@ fn test_resolve_with_nested_deps() { ]); let requirements = provider.requirements(&["apache-airflow"]); let mut solver = Solver::new(provider); - let solved = solver.solve(requirements, Vec::new()).unwrap(); + let problem = Problem { + requirements, + ..Default::default() + }; + let solved = solver.solve(problem).unwrap(); let pool = &solver.provider().pool; assert_eq!(solved.len(), 1); @@ -763,7 +804,11 @@ fn test_resolve_with_unknown_deps() { provider.add_package("opentelemetry-api", Pack::new(2), &[], &[]); let requirements = provider.requirements(&["opentelemetry-api"]); let mut solver = Solver::new(provider); - let solved = solver.solve(requirements, Vec::new()).unwrap(); + let problem = Problem { + requirements, + ..Default::default() + }; + let solved = solver.solve(problem).unwrap(); let pool = &solver.provider().pool; assert_eq!(solved.len(), 1); @@ -808,7 +853,11 @@ fn test_resolve_locked_top_level() { let requirements = provider.requirements(&["asdf"]); let mut solver = Solver::new(provider); - let solved = solver.solve(requirements, Vec::new()).unwrap(); + let problem = Problem { + requirements, + ..Default::default() + }; + let solved = solver.solve(problem).unwrap(); let pool = &solver.provider().pool; assert_eq!(solved.len(), 1); @@ -830,7 +879,11 @@ fn test_resolve_ignored_locked_top_level() { let requirements = provider.requirements(&["asdf"]); let mut solver = Solver::new(provider); - let solved = solver.solve(requirements, Vec::new()).unwrap(); + let problem = Problem { + requirements, + ..Default::default() + }; + let solved = solver.solve(problem).unwrap(); let pool = &solver.provider().pool; assert_eq!(solved.len(), 1); @@ -888,7 +941,11 @@ fn test_resolve_cyclic() { BundleBoxProvider::from_packages(&[("a", 2, vec!["b 0..10"]), ("b", 5, vec!["a 2..4"])]); let requirements = provider.requirements(&["a 0..100"]); let mut solver = Solver::new(provider); - let solved = solver.solve(requirements, Vec::new()).unwrap(); + let problem = Problem { + requirements, + ..Default::default() + }; + let solved = solver.solve(problem).unwrap(); let result = transaction_to_string(solver.provider(), &solved); insta::assert_snapshot!(result, @r###" @@ -1171,7 +1228,12 @@ fn test_constraints() { let requirements = provider.requirements(&["a 0..10"]); let constraints = provider.requirements(&["b 1..2", "c"]); let mut solver = Solver::new(provider); - let solved = solver.solve(requirements, constraints).unwrap(); + let problem = Problem { + requirements, + constraints, + ..Default::default() + }; + let solved = solver.solve(problem).unwrap(); let result = transaction_to_string(solver.provider(), &solved); insta::assert_snapshot!(result, @r###" @@ -1180,6 +1242,107 @@ fn test_constraints() { "###); } +#[test] +fn test_solve_with_additional() { + let mut provider = BundleBoxProvider::from_packages(&[ + ("a", 1, vec!["b 0..10"]), + ("b", 1, vec![]), + ("b", 2, vec![]), + ("c", 1, vec![]), + ("d", 1, vec![]), + ("e", 1, vec!["d"]), + ("locked", 1, vec![]), + ("locked", 2, vec![]), + ]); + + provider.set_locked("locked", 2); + + let requirements = provider.requirements(&["a 0..10"]); + let constraints = provider.requirements(&["b 1..2", "c"]); + + let extra_solvables = [ + provider.solvable_id("b", 2), + provider.solvable_id("c", 1), + provider.solvable_id("e", 1), + // Does not obey the locked clause since it has not been requested + // in a version set by another solvable + provider.solvable_id("locked", 1), + provider.solvable_id("unknown-deps", Pack::new(1).with_unknown_deps()), + ]; + + let mut solver = Solver::new(provider); + + let problem = Problem { + requirements, + constraints, + soft_requirements: extra_solvables.to_vec(), + }; + let solved = solver.solve(problem).unwrap(); + + let result = transaction_to_string(solver.provider(), &solved); + assert_snapshot!(result, @r###" + a=1 + b=1 + c=1 + d=1 + e=1 + locked=1 + "###); +} + +#[test] +fn test_solve_with_additional_with_constrains() { + let mut provider = BundleBoxProvider::from_packages(&[ + ("a", 1, vec!["b 0..10"]), + ("b", 1, vec![]), + ("b", 2, vec![]), + ("b", 3, vec![]), + ("c", 1, vec![]), + ("d", 1, vec!["f"]), + ("e", 1, vec!["c"]), + ]); + + provider.add_package("f", 1.into(), &[], &["c 2..3"]); + provider.add_package("g", 1.into(), &[], &["b 2..3"]); + provider.add_package("h", 1.into(), &[], &["b 1..2"]); + provider.add_package("i", 1.into(), &[], &[]); + provider.add_package("j", 1.into(), &["i"], &[]); + provider.add_package("k", 1.into(), &["i"], &[]); + provider.add_package("l", 1.into(), &["j", "k"], &[]); + + let requirements = provider.requirements(&["a 0..10", "e"]); + let constraints = provider.requirements(&["b 1..2", "c", "k 2..3"]); + + let extra_solvables = [ + provider.solvable_id("d", 1), + provider.solvable_id("g", 1), + provider.solvable_id("h", 1), + provider.solvable_id("j", 1), + provider.solvable_id("l", 1), + provider.solvable_id("k", 1), + ]; + + let mut solver = Solver::new(provider); + + let problem = Problem { + requirements, + constraints, + soft_requirements: extra_solvables.to_vec(), + }; + let solved = solver.solve(problem).unwrap(); + + let result = transaction_to_string(solver.provider(), &solved); + assert_snapshot!(result, @r###" + a=1 + b=1 + c=1 + e=1 + h=1 + i=1 + j=1 + "###); +} + #[test] fn test_snapshot() { let provider = BundleBoxProvider::from_packages(&[ @@ -1241,14 +1404,15 @@ fn serialize_snapshot(snapshot: &DependencySnapshot, destination: impl AsRef String { let mut solver = Solver::new(provider); - match solver.solve( - root_reqs.iter().copied().map(Into::into).collect(), - Vec::new(), - ) { + let problem = Problem { + requirements: root_reqs.iter().copied().map(Into::into).collect(), + ..Default::default() + }; + match solver.solve(problem) { Ok(solvables) => transaction_to_string(solver.provider(), &solvables), - Err(UnsolvableOrCancelled::Unsolvable(problem)) => { - // Write the problem graphviz to stderr - let graph = problem.graph(&solver); + Err(UnsolvableOrCancelled::Unsolvable(conflict)) => { + // Write the conflict graphviz to stderr + let graph = conflict.graph(&solver); let mut output = stderr(); writeln!(output, "UNSOLVABLE:").unwrap(); graph @@ -1257,7 +1421,7 @@ fn solve_for_snapshot(provider: SnapshotProvider, root_reqs: &[VersionSetId]) -> writeln!(output, "\n").unwrap(); // Format a user friendly error message - problem.display_user_friendly(&solver).to_string() + conflict.display_user_friendly(&solver).to_string() } Err(UnsolvableOrCancelled::Cancelled(reason)) => *reason.downcast().unwrap(), } From 0681a0f38a92bfee73d9d33d6f2ee7aefb9ece4a Mon Sep 17 00:00:00 2001 From: Bas Zalmstra Date: Tue, 6 Aug 2024 16:27:17 +0200 Subject: [PATCH 6/7] chore: release v0.7.0 (#52) --- CHANGELOG.md | 16 ++++++++++++++++ Cargo.lock | 2 +- Cargo.toml | 2 +- cpp/Cargo.toml | 2 +- 4 files changed, 19 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 96694b1..3fe1a53 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,22 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [0.7.0](https://github.com/mamba-org/resolvo/compare/resolvo-v0.6.2...resolvo-v0.7.0) - 2024-08-06 + +### Added +- *(solver)* [**breaking**] Solve for optional solvables in addition to the root solvable ([#54](https://github.com/mamba-org/resolvo/pull/54)) +- [**breaking**] Version set unions as solvable requirements ([#56](https://github.com/mamba-org/resolvo/pull/56)) + +### Fixed +- Fix off-by-one error in `Mapping::serialize` ([#58](https://github.com/mamba-org/resolvo/pull/58)) + +### Other +- *(ci)* bump prefix-dev/rattler-build-action from 0.2.12 to 0.2.13 ([#59](https://github.com/mamba-org/resolvo/pull/59)) +- *(ci)* bump prefix-dev/rattler-build-action from 0.2.11 to 0.2.12 ([#57](https://github.com/mamba-org/resolvo/pull/57)) +- Add more tracing ([#55](https://github.com/mamba-org/resolvo/pull/55)) +- *(ci)* bump prefix-dev/rattler-build-action from 0.2.10 to 0.2.11 ([#53](https://github.com/mamba-org/resolvo/pull/53)) +- *(ci)* bump prefix-dev/rattler-build-action from 0.2.9 to 0.2.10 ([#51](https://github.com/mamba-org/resolvo/pull/51)) + ## [0.6.2](https://github.com/mamba-org/resolvo/compare/resolvo-v0.6.1...resolvo-v0.6.2) - 2024-06-11 ### Added diff --git a/Cargo.lock b/Cargo.lock index a42f4cd..ad81fe7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1153,7 +1153,7 @@ checksum = "7a66a03ae7c801facd77a29370b4faec201768915ac14a721ba36f20bc9c209b" [[package]] name = "resolvo" -version = "0.6.2" +version = "0.7.0" dependencies = [ "ahash", "async-std", diff --git a/Cargo.toml b/Cargo.toml index 4375bea..eda7af9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,7 +3,7 @@ members = ["cpp", "tools/*"] resolver = "2" [workspace.package] -version = "0.6.2" +version = "0.7.0" authors = ["Adolfo Ochagavía ", "Bas Zalmstra ", "Tim de Jager "] homepage = "https://github.com/mamba-org/resolvo" repository = "https://github.com/mamba-org/resolvo" diff --git a/cpp/Cargo.toml b/cpp/Cargo.toml index 7357387..ed7e179 100644 --- a/cpp/Cargo.toml +++ b/cpp/Cargo.toml @@ -16,7 +16,7 @@ publish = false crate-type = ["lib", "cdylib", "staticlib"] [dependencies] -resolvo = { version = "0.6.2", path = "../" } +resolvo = { version = "0.7.0", path = "../" } [build-dependencies] anyhow = "1" From 2700ea08b7f2bb9a4e0e99b9594a0cd83316c11a Mon Sep 17 00:00:00 2001 From: Arav Singhal Date: Mon, 12 Aug 2024 03:58:36 -0500 Subject: [PATCH 7/7] feat(solver)!: Make `Problem` use builder pattern (#62) --- cpp/src/lib.rs | 44 +++++++++++++----------- src/solver/mod.rs | 77 ++++++++++++++++++++++++++++++++++++----- tests/solver.rs | 87 +++++++++++++---------------------------------- 3 files changed, 116 insertions(+), 92 deletions(-) diff --git a/cpp/src/lib.rs b/cpp/src/lib.rs index a0ab2ac..b732cf1 100644 --- a/cpp/src/lib.rs +++ b/cpp/src/lib.rs @@ -490,26 +490,30 @@ pub extern "C" fn resolvo_solve( ) -> bool { let mut solver = resolvo::Solver::new(provider); - let problem = resolvo::Problem { - requirements: problem - .requirements - .into_iter() - .copied() - .map(Into::into) - .collect(), - constraints: problem - .constraints - .into_iter() - .copied() - .map(Into::into) - .collect(), - soft_requirements: problem - .soft_requirements - .into_iter() - .copied() - .map(Into::into) - .collect(), - }; + let problem = resolvo::Problem::new() + .requirements( + problem + .requirements + .into_iter() + .copied() + .map(Into::into) + .collect(), + ) + .constraints( + problem + .constraints + .into_iter() + .copied() + .map(Into::into) + .collect(), + ) + .soft_requirements( + problem + .soft_requirements + .into_iter() + .copied() + .map(Into::into), + ); match solver.solve(problem) { Ok(solution) => { diff --git a/src/solver/mod.rs b/src/solver/mod.rs index 66bed80..0ed9ffc 100644 --- a/src/solver/mod.rs +++ b/src/solver/mod.rs @@ -35,17 +35,60 @@ struct AddClauseOutput { } /// Describes the problem that is to be solved by the solver. -#[derive(Default)] -pub struct Problem { - /// The requirements that _must_ have one candidate solvable be included in the +/// +/// This struct is generic over the type `S` of the collection of soft requirements passed +/// to the solver, typically expected to be a type implementing [`IntoIterator`]. +/// +/// This struct follows the builder pattern and can have its fields set by one of the available +/// setter methods. +pub struct Problem { + requirements: Vec, + constraints: Vec, + soft_requirements: S, +} + +impl Default for Problem> { + fn default() -> Self { + Self::new() + } +} + +impl Problem> { + /// Creates a new empty [`Problem`]. Use the setter methods to build the problem + /// before passing it to the solver to be solved. + pub fn new() -> Self { + Self { + requirements: Default::default(), + constraints: Default::default(), + soft_requirements: Default::default(), + } + } +} + +impl> Problem { + /// Sets the requirements that _must_ have one candidate solvable be included in the /// solution. - pub requirements: Vec, + /// + /// Returns the [`Problem`] for further mutation or to pass to [`Solver::solve`]. + pub fn requirements(self, requirements: Vec) -> Self { + Self { + requirements, + ..self + } + } - /// Additional constraints imposed on individual packages that the solvable (if any) + /// Sets the additional constraints imposed on individual packages that the solvable (if any) /// chosen for that package _must_ adhere to. - pub constraints: Vec, + /// + /// Returns the [`Problem`] for further mutation or to pass to [`Solver::solve`]. + pub fn constraints(self, constraints: Vec) -> Self { + Self { + constraints, + ..self + } + } - /// A set of additional requirements that the solver should _try_ and fulfill once it has + /// Sets the additional requirements that the solver should _try_ and fulfill once it has /// found a solution to the main problem. /// /// An unsatisfiable soft requirement does not cause a conflict; the solver will try @@ -54,7 +97,20 @@ pub struct Problem { /// Soft requirements are currently only specified as individual solvables to be /// included in the solution, however in the future they will be able to be specified /// as version sets. - pub soft_requirements: Vec, + /// + /// # Returns + /// + /// Returns the [`Problem`] for further mutation or to pass to [`Solver::solve`]. + pub fn soft_requirements>( + self, + soft_requirements: I, + ) -> Problem { + Problem { + requirements: self.requirements, + constraints: self.constraints, + soft_requirements, + } + } } /// Drives the SAT solving process. @@ -201,7 +257,10 @@ impl Solver { /// /// If the solution process is cancelled (see [`DependencyProvider::should_cancel_with_value`]), /// returns an [`UnsolvableOrCancelled::Cancelled`] containing the cancellation value. - pub fn solve(&mut self, problem: Problem) -> Result, UnsolvableOrCancelled> { + pub fn solve( + &mut self, + problem: Problem>, + ) -> Result, UnsolvableOrCancelled> { self.decision_tracker.clear(); self.negative_assertions.clear(); self.learnt_clauses.clear(); diff --git a/tests/solver.rs b/tests/solver.rs index 6aa1cea..fe37b83 100644 --- a/tests/solver.rs +++ b/tests/solver.rs @@ -556,10 +556,7 @@ fn transaction_to_string(interner: &impl Interner, solvables: &Vec) fn solve_unsat(provider: BundleBoxProvider, specs: &[&str]) -> String { let requirements = provider.requirements(specs); let mut solver = Solver::new(provider); - let problem = Problem { - requirements, - ..Default::default() - }; + let problem = Problem::new().requirements(requirements); match solver.solve(problem) { Ok(_) => panic!("expected unsat, but a solution was found"), Err(UnsolvableOrCancelled::Unsolvable(conflict)) => { @@ -592,10 +589,7 @@ fn solve_snapshot(mut provider: BundleBoxProvider, specs: &[&str]) -> String { let requirements = provider.parse_requirements(specs); let mut solver = Solver::new(provider).with_runtime(runtime); - let problem = Problem { - requirements, - ..Default::default() - }; + let problem = Problem::new().requirements(requirements); match solver.solve(problem) { Ok(solvables) => transaction_to_string(solver.provider(), &solvables), Err(UnsolvableOrCancelled::Unsolvable(conflict)) => { @@ -621,10 +615,7 @@ fn test_unit_propagation_1() { let provider = BundleBoxProvider::from_packages(&[("asdf", 1, vec![])]); let requirements = provider.requirements(&["asdf"]); let mut solver = Solver::new(provider); - let problem = Problem { - requirements, - ..Default::default() - }; + let problem = Problem::new().requirements(requirements); let solved = solver.solve(problem).unwrap(); let pool = &solver.provider().pool; @@ -645,10 +636,7 @@ fn test_unit_propagation_nested() { ]); let requirements = provider.requirements(&["asdf"]); let mut solver = Solver::new(provider); - let problem = Problem { - requirements, - ..Default::default() - }; + let problem = Problem::new().requirements(requirements); let solved = solver.solve(problem).unwrap(); let pool = &solver.provider().pool; @@ -676,10 +664,7 @@ fn test_resolve_multiple() { ]); let requirements = provider.requirements(&["asdf", "efgh"]); let mut solver = Solver::new(provider); - let problem = Problem { - requirements, - ..Default::default() - }; + let problem = Problem::new().requirements(requirements); let solved = solver.solve(problem).unwrap(); let pool = &solver.provider().pool; @@ -738,10 +723,7 @@ fn test_resolve_with_nonexisting() { ]); let requirements = provider.requirements(&["asdf"]); let mut solver = Solver::new(provider); - let problem = Problem { - requirements, - ..Default::default() - }; + let problem = Problem::new().requirements(requirements); let solved = solver.solve(problem).unwrap(); let pool = &solver.provider().pool; @@ -776,10 +758,7 @@ fn test_resolve_with_nested_deps() { ]); let requirements = provider.requirements(&["apache-airflow"]); let mut solver = Solver::new(provider); - let problem = Problem { - requirements, - ..Default::default() - }; + let problem = Problem::new().requirements(requirements); let solved = solver.solve(problem).unwrap(); let pool = &solver.provider().pool; @@ -804,10 +783,7 @@ fn test_resolve_with_unknown_deps() { provider.add_package("opentelemetry-api", Pack::new(2), &[], &[]); let requirements = provider.requirements(&["opentelemetry-api"]); let mut solver = Solver::new(provider); - let problem = Problem { - requirements, - ..Default::default() - }; + let problem = Problem::new().requirements(requirements); let solved = solver.solve(problem).unwrap(); let pool = &solver.provider().pool; @@ -853,10 +829,7 @@ fn test_resolve_locked_top_level() { let requirements = provider.requirements(&["asdf"]); let mut solver = Solver::new(provider); - let problem = Problem { - requirements, - ..Default::default() - }; + let problem = Problem::new().requirements(requirements); let solved = solver.solve(problem).unwrap(); let pool = &solver.provider().pool; @@ -879,10 +852,7 @@ fn test_resolve_ignored_locked_top_level() { let requirements = provider.requirements(&["asdf"]); let mut solver = Solver::new(provider); - let problem = Problem { - requirements, - ..Default::default() - }; + let problem = Problem::new().requirements(requirements); let solved = solver.solve(problem).unwrap(); let pool = &solver.provider().pool; @@ -941,10 +911,7 @@ fn test_resolve_cyclic() { BundleBoxProvider::from_packages(&[("a", 2, vec!["b 0..10"]), ("b", 5, vec!["a 2..4"])]); let requirements = provider.requirements(&["a 0..100"]); let mut solver = Solver::new(provider); - let problem = Problem { - requirements, - ..Default::default() - }; + let problem = Problem::new().requirements(requirements); let solved = solver.solve(problem).unwrap(); let result = transaction_to_string(solver.provider(), &solved); @@ -1228,11 +1195,9 @@ fn test_constraints() { let requirements = provider.requirements(&["a 0..10"]); let constraints = provider.requirements(&["b 1..2", "c"]); let mut solver = Solver::new(provider); - let problem = Problem { - requirements, - constraints, - ..Default::default() - }; + let problem = Problem::new() + .requirements(requirements) + .constraints(constraints); let solved = solver.solve(problem).unwrap(); let result = transaction_to_string(solver.provider(), &solved); @@ -1272,11 +1237,10 @@ fn test_solve_with_additional() { let mut solver = Solver::new(provider); - let problem = Problem { - requirements, - constraints, - soft_requirements: extra_solvables.to_vec(), - }; + let problem = Problem::new() + .requirements(requirements) + .constraints(constraints) + .soft_requirements(extra_solvables); let solved = solver.solve(problem).unwrap(); let result = transaction_to_string(solver.provider(), &solved); @@ -1324,11 +1288,11 @@ fn test_solve_with_additional_with_constrains() { let mut solver = Solver::new(provider); - let problem = Problem { - requirements, - constraints, - soft_requirements: extra_solvables.to_vec(), - }; + let problem = Problem::new() + .requirements(requirements) + .constraints(constraints) + .soft_requirements(extra_solvables); + let solved = solver.solve(problem).unwrap(); let result = transaction_to_string(solver.provider(), &solved); @@ -1404,10 +1368,7 @@ fn serialize_snapshot(snapshot: &DependencySnapshot, destination: impl AsRef String { let mut solver = Solver::new(provider); - let problem = Problem { - requirements: root_reqs.iter().copied().map(Into::into).collect(), - ..Default::default() - }; + let problem = Problem::new().requirements(root_reqs.iter().copied().map(Into::into).collect()); match solver.solve(problem) { Ok(solvables) => transaction_to_string(solver.provider(), &solvables), Err(UnsolvableOrCancelled::Unsolvable(conflict)) => {