diff --git a/cpp/include/resolvo.h b/cpp/include/resolvo.h index 6a41aa3..48844d6 100644 --- a/cpp/include/resolvo.h +++ b/cpp/include/resolvo.h @@ -4,6 +4,24 @@ #include "resolvo_internal.h" namespace resolvo { +using cbindgen_private::Requirement; + +/** + * Specifies a requirement (dependency) of a single version set. + */ +inline Requirement requirement_single(VersionSetId id) { + return cbindgen_private::resolvo_requirement_single(id); +} + +/** + * Specifies a requirement (dependency) of the union (logical OR) of multiple version sets. + * A solvable belonging to any of the version sets contained in the union satisfies the + * requirement. This variant is typically used for requirements that can be satisfied by two + * or more version sets belonging to different packages. + */ +inline Requirement requirement_union(VersionSetUnionId id) { + return cbindgen_private::resolvo_requirement_union(id); +} /** * Called to solve a package problem. @@ -12,7 +30,7 @@ namespace resolvo { * stored in `result`. If the solve was unsuccesfull an error describing the reason is returned and * the result vector will be empty. */ -inline String solve(DependencyProvider &provider, Slice requirements, +inline String solve(DependencyProvider &provider, Slice requirements, Slice constraints, Vector &result) { cbindgen_private::DependencyProvider bridge{ static_cast(&provider), @@ -24,6 +42,7 @@ inline String solve(DependencyProvider &provider, Slice requiremen private_api::bridge_display_string, private_api::bridge_version_set_name, private_api::bridge_solvable_name, + private_api::bridge_version_sets_in_union, private_api::bridge_get_candidates, private_api::bridge_sort_candidates, private_api::bridge_filter_candidates, diff --git a/cpp/include/resolvo_dependency_provider.h b/cpp/include/resolvo_dependency_provider.h index 8cd9258..9484079 100644 --- a/cpp/include/resolvo_dependency_provider.h +++ b/cpp/include/resolvo_dependency_provider.h @@ -13,6 +13,7 @@ using cbindgen_private::NameId; using cbindgen_private::SolvableId; using cbindgen_private::StringId; using cbindgen_private::VersionSetId; +using cbindgen_private::VersionSetUnionId; /** * An interface that implements ecosystem specific logic. @@ -75,6 +76,11 @@ struct DependencyProvider { */ virtual NameId solvable_name(SolvableId solvable_id) = 0; + /** + * Returns the version sets comprising the given union. + */ + virtual Slice version_sets_in_union(VersionSetUnionId version_set_union_id) = 0; + /** * Obtains a list of solvables that should be considered when a package * with the given name is requested. @@ -133,6 +139,15 @@ extern "C" inline NameId bridge_solvable_name(void *data, SolvableId solvable_id return reinterpret_cast(data)->solvable_name(solvable_id); } +// HACK(clang): For some reason, clang needs this to know that the return type is complete +static_assert(sizeof(Slice)); + +extern "C" inline Slice bridge_version_sets_in_union( + void *data, VersionSetUnionId version_set_union_id) { + return reinterpret_cast(data)->version_sets_in_union( + version_set_union_id); +} + extern "C" inline void bridge_get_candidates(void *data, NameId package, Candidates *result) { *result = reinterpret_cast(data)->get_candidates(package); } diff --git a/cpp/src/lib.rs b/cpp/src/lib.rs index f8bd5c6..4aa9998 100644 --- a/cpp/src/lib.rs +++ b/cpp/src/lib.rs @@ -31,6 +31,66 @@ impl From for resolvo::SolvableId { } } +/// Specifies the dependency of a solvable on a set of version sets. +/// cbindgen:derive-eq +/// cbindgen:derive-neq +#[repr(C)] +#[derive(Copy, Clone)] +pub enum Requirement { + /// Specifies a dependency on a single version set. + /// cbindgen:derive-eq + /// cbindgen:derive-neq + Single(VersionSetId), + /// Specifies a dependency on the union (logical OR) of multiple version sets. A solvable + /// belonging to ANY of the version sets contained in the union satisfies the requirement. + /// This variant is typically used for requirements that can be satisfied by two or more + /// version sets belonging to different packages. + /// cbindgen:derive-eq + /// cbindgen:derive-neq + Union(VersionSetUnionId), +} + +impl From for crate::Requirement { + fn from(value: resolvo::Requirement) -> Self { + match value { + resolvo::Requirement::Single(id) => Requirement::Single(id.into()), + resolvo::Requirement::Union(id) => Requirement::Union(id.into()), + } + } +} + +impl From for resolvo::Requirement { + fn from(value: crate::Requirement) -> Self { + match value { + Requirement::Single(id) => resolvo::Requirement::Single(id.into()), + Requirement::Union(id) => resolvo::Requirement::Union(id.into()), + } + } +} + +/// A unique identifier for a version set union. A version set union describes +/// the union (logical OR) of a non-empty set of version sets belonging to +/// more than one package. +/// cbindgen:derive-eq +/// cbindgen:derive-neq +#[repr(C)] +#[derive(Copy, Clone)] +pub struct VersionSetUnionId { + id: u32, +} + +impl From for crate::VersionSetUnionId { + fn from(id: resolvo::VersionSetUnionId) -> Self { + Self { id: id.0 } + } +} + +impl From for resolvo::VersionSetUnionId { + fn from(id: crate::VersionSetUnionId) -> Self { + Self(id.id) + } +} + /// A unique identifier for a single version set. A version set describes a /// set of versions. /// cbindgen:derive-eq @@ -102,7 +162,7 @@ pub struct Dependencies { /// A pointer to the first element of a list of requirements. Requirements /// defines which packages should be installed alongside the depending /// package and the constraints applied to the package. - pub requirements: Vector, + pub requirements: Vector, /// Defines additional constraints on packages that may or may not be part /// of the solution. Different from `requirements`, packages in this set @@ -230,6 +290,12 @@ pub struct DependencyProvider { /// Returns the name of the package for the given solvable. pub solvable_name: unsafe extern "C" fn(data: *mut c_void, solvable_id: SolvableId) -> NameId, + /// Returns the version sets comprising the given union. + pub version_sets_in_union: unsafe extern "C" fn( + data: *mut c_void, + version_set_union_id: VersionSetUnionId, + ) -> Slice<'static, VersionSetId>, + /// Obtains a list of solvables that should be considered when a package /// with the given name is requested. pub get_candidates: @@ -314,6 +380,17 @@ impl<'d> resolvo::Interner for &'d DependencyProvider { fn solvable_name(&self, solvable: resolvo::SolvableId) -> resolvo::NameId { unsafe { (self.solvable_name)(self.data, solvable.into()) }.into() } + + fn version_sets_in_union( + &self, + version_set_union: resolvo::VersionSetUnionId, + ) -> impl Iterator { + unsafe { (self.version_sets_in_union)(self.data, version_set_union.into()) } + .as_slice() + .into_iter() + .copied() + .map(Into::into) + } } impl<'d> resolvo::DependencyProvider for &'d DependencyProvider { @@ -400,7 +477,7 @@ impl<'d> resolvo::DependencyProvider for &'d DependencyProvider { #[allow(unused)] pub extern "C" fn resolvo_solve( provider: &DependencyProvider, - requirements: Slice, + requirements: Slice, constraints: Slice, error: &mut String, result: &mut Vector, @@ -433,6 +510,20 @@ pub extern "C" fn resolvo_solve( } } +#[no_mangle] +#[allow(unused)] +pub extern "C" fn resolvo_requirement_single(version_set_id: VersionSetId) -> Requirement { + Requirement::Single(version_set_id) +} + +#[no_mangle] +#[allow(unused)] +pub extern "C" fn resolvo_requirement_union( + version_set_union_id: VersionSetUnionId, +) -> Requirement { + Requirement::Union(version_set_union_id) +} + #[cfg(test)] mod tests { use super::*; diff --git a/cpp/tests/solve.cpp b/cpp/tests/solve.cpp index ffb8c66..13a05fa 100644 --- a/cpp/tests/solve.cpp +++ b/cpp/tests/solve.cpp @@ -16,9 +16,9 @@ struct Candidate { }; /** - * A requirement for a package. + * A version set for a package. */ -struct Requirement { +struct VersionSet { resolvo::NameId name; uint32_t version_start; uint32_t version_end; @@ -31,19 +31,47 @@ struct PackageDatabase : public resolvo::DependencyProvider { resolvo::Pool names; resolvo::Pool strings; std::vector candidates; - std::vector requirements; + std::vector version_sets; + std::vector> version_set_unions; /** - * Allocates a new requirement and return the id of the requirement. + * Allocates a new version set and return the id of the version set. */ - resolvo::VersionSetId alloc_requirement(std::string_view package, uint32_t version_start, + resolvo::VersionSetId alloc_version_set(std::string_view package, uint32_t version_start, uint32_t version_end) { auto name_id = names.alloc(std::move(package)); - auto id = resolvo::VersionSetId{static_cast(requirements.size())}; - requirements.push_back(Requirement{name_id, version_start, version_end}); + auto id = resolvo::VersionSetId{static_cast(version_sets.size())}; + version_sets.push_back(VersionSet{name_id, version_start, version_end}); return id; } + /** + * Allocates a new requirement for a single version set. + */ + resolvo::Requirement alloc_requirement(std::string_view package, uint32_t version_start, + uint32_t version_end) { + auto id = alloc_version_set(package, version_start, version_end); + return resolvo::requirement_single(id); + } + + /** + * Allocates a new requirement for a version set union. + */ + resolvo::Requirement alloc_requirement_union( + std::initializer_list> version_sets) { + std::vector version_set_union{version_sets.size()}; + + auto version_sets_it = version_sets.begin(); + for (size_t i = 0; i < version_sets.size(); ++i, ++version_sets_it) { + auto [package, version_start, version_end] = *version_sets_it; + version_set_union[i] = alloc_version_set(package, version_start, version_end); + } + + auto id = resolvo::VersionSetUnionId{static_cast(version_set_unions.size())}; + version_set_unions.push_back(std::move(version_set_union)); + return resolvo::requirement_union(id); + } + /** * Allocates a new candidate and return the id of the candidate. */ @@ -90,7 +118,7 @@ struct PackageDatabase : public resolvo::DependencyProvider { } resolvo::String display_version_set(resolvo::VersionSetId version_set) override { - const auto& req = requirements[version_set.id]; + const auto& req = version_sets[version_set.id]; std::stringstream ss; ss << req.version_start << ".." << req.version_end; return resolvo::String(ss.str()); @@ -101,13 +129,19 @@ struct PackageDatabase : public resolvo::DependencyProvider { } resolvo::NameId version_set_name(resolvo::VersionSetId version_set_id) override { - return requirements[version_set_id.id].name; + return version_sets[version_set_id.id].name; } resolvo::NameId solvable_name(resolvo::SolvableId solvable_id) override { return candidates[solvable_id.id].name; } + resolvo::Slice version_sets_in_union( + resolvo::VersionSetUnionId version_set_union_id) override { + const auto& version_set_ids = version_set_unions[version_set_union_id.id]; + return {version_set_ids.data(), version_set_ids.size()}; + } + resolvo::Candidates get_candidates(resolvo::NameId package) override { resolvo::Candidates result; @@ -137,11 +171,11 @@ struct PackageDatabase : public resolvo::DependencyProvider { resolvo::Slice solvables, resolvo::VersionSetId version_set_id, bool inverse) override { resolvo::Vector result; - const auto& requirement = requirements[version_set_id.id]; + const auto& version_set = version_sets[version_set_id.id]; for (auto solvable : solvables) { const auto& candidate = candidates[solvable.id]; - bool matches = candidate.version >= requirement.version_start && - candidate.version < requirement.version_end; + bool matches = candidate.version >= version_set.version_start && + candidate.version < version_set.version_end; if (matches != inverse) { result.push_back(solvable); } @@ -183,9 +217,9 @@ SCENARIO("Solve") { auto c_1 = db.alloc_candidate("c", 1, {}); // Construct a problem to be solved by the solver - resolvo::Vector requirements = {db.alloc_requirement("a", 1, 3)}; - resolvo::Vector constraints = {db.alloc_requirement("b", 1, 3), - db.alloc_requirement("c", 1, 3)}; + resolvo::Vector requirements = {db.alloc_requirement("a", 1, 3)}; + resolvo::Vector constraints = {db.alloc_version_set("b", 1, 3), + db.alloc_version_set("c", 1, 3)}; // Solve the problem resolvo::Vector result; @@ -196,3 +230,46 @@ SCENARIO("Solve") { REQUIRE(result[0] == a_2); REQUIRE(result[1] == b_2); } + +SCENARIO("Solve Union") { + /// Construct a database with packages a, b, and c. + PackageDatabase db; + + // Check that PackageDatabase correctly implements the DependencyProvider interface + static_assert(std::has_virtual_destructor_v); + static_assert(std::is_polymorphic_v); + static_assert(std::is_base_of_v); + + auto a_1 = db.alloc_candidate("a", 1, {}); + + auto b_1 = db.alloc_candidate("b", 1, {}); + + auto c_1 = db.alloc_candidate("c", 1, {{db.alloc_requirement("a", 1, 10)}, {}}); + + auto d_1 = db.alloc_candidate("d", 1, {{db.alloc_requirement("b", 1, 10)}, {}}); + + auto e_1 = db.alloc_candidate("e", 1, + {{db.alloc_requirement_union({{"a", 1, 10}, {"b", 1, 10}})}, {}}); + + auto f_1 = db.alloc_candidate( + "f", 1, {{db.alloc_requirement("b", 1, 10)}, {db.alloc_version_set("a", 10, 20)}}); + + // Construct a problem to be solved by the solver + resolvo::Vector requirements = { + db.alloc_requirement_union({{"c", 1, 10}, {"d", 1, 10}}), + db.alloc_requirement("e", 1, 10), + db.alloc_requirement("f", 1, 10), + }; + resolvo::Vector constraints = {}; + + // Solve the problem + resolvo::Vector result; + resolvo::solve(db, requirements, constraints, result); + + // Check the result + REQUIRE(result.size() == 4); + REQUIRE(result[0] == f_1); + REQUIRE(result[1] == e_1); + REQUIRE(result[2] == b_1); + REQUIRE(result[3] == d_1); +} diff --git a/src/internal/id.rs b/src/internal/id.rs index 1ebc041..7f5a44a 100644 --- a/src/internal/id.rs +++ b/src/internal/id.rs @@ -53,6 +53,23 @@ impl ArenaId for VersionSetId { } } +/// The id associated with a union (logical OR) of two or more version sets. +#[repr(transparent)] +#[derive(Clone, Default, Copy, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "serde", serde(transparent))] +pub struct VersionSetUnionId(pub u32); + +impl ArenaId for VersionSetUnionId { + fn from_usize(x: usize) -> Self { + Self(x as u32) + } + + fn to_usize(self) -> usize { + self.0 as usize + } +} + /// The id associated to a solvable #[repr(transparent)] #[derive(Copy, Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)] diff --git a/src/lib.rs b/src/lib.rs index 40bac7d..5ca2ba1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -12,6 +12,7 @@ pub(crate) mod internal; pub mod problem; +mod requirement; pub mod runtime; pub mod snapshot; mod solver; @@ -23,10 +24,11 @@ use std::{ }; pub use internal::{ - id::{NameId, SolvableId, StringId, VersionSetId}, + id::{NameId, SolvableId, StringId, VersionSetId, VersionSetUnionId}, mapping::Mapping, }; use itertools::Itertools; +pub use requirement::Requirement; pub use solver::{Solver, SolverCache, UnsolvableOrCancelled}; /// An object that is used by the solver to query certain properties of @@ -87,6 +89,12 @@ pub trait Interner { /// Returns the name of the package for the given solvable. fn solvable_name(&self, solvable: SolvableId) -> NameId; + + /// Returns the version sets comprising the given union. + fn version_sets_in_union( + &self, + version_set_union: VersionSetUnionId, + ) -> impl Iterator; } /// Defines implementation specific behavior for the solver and a way for the @@ -195,7 +203,7 @@ pub struct KnownDependencies { feature = "serde", serde(default, skip_serializing_if = "Vec::is_empty") )] - pub requirements: Vec, + pub requirements: Vec, /// Defines additional constraints on packages that may or may not be part /// of the solution. Different from `requirements`, packages in this set diff --git a/src/problem.rs b/src/problem.rs index a084453..3033244 100644 --- a/src/problem.rs +++ b/src/problem.rs @@ -15,7 +15,7 @@ use crate::{ internal::id::{ClauseId, InternalSolvableId, SolvableId, StringId, VersionSetId}, runtime::AsyncRuntime, solver::{clause::Clause, Solver}, - DependencyProvider, Interner, + DependencyProvider, Interner, Requirement, }; /// Represents the cause of the solver being unable to find a solution @@ -198,21 +198,21 @@ impl ProblemNode { #[derive(Copy, Clone, Hash, Eq, PartialEq, Ord, PartialOrd)] pub(crate) enum ProblemEdge { /// The target node is a candidate for the dependency specified by the - /// version set - Requires(VersionSetId), + /// [`Requirement`] + Requires(Requirement), /// The target node is involved in a conflict, caused by `ConflictCause` Conflict(ConflictCause), } impl ProblemEdge { - fn try_requires(self) -> Option { + fn try_requires(self) -> Option { match self { ProblemEdge::Requires(match_spec_id) => Some(match_spec_id), ProblemEdge::Conflict(_) => None, } } - fn requires(self) -> VersionSetId { + fn requires(self) -> Requirement { match self { ProblemEdge::Requires(match_spec_id) => match_spec_id, ProblemEdge::Conflict(_) => panic!("expected requires edge, found conflict"), @@ -301,8 +301,8 @@ impl ProblemGraph { }; let label = match edge.weight() { - ProblemEdge::Requires(version_set_id) - | ProblemEdge::Conflict(ConflictCause::Constrains(version_set_id)) => { + ProblemEdge::Requires(requirement) => requirement.display(interner).to_string(), + ProblemEdge::Conflict(ConflictCause::Constrains(version_set_id)) => { interner.display_version_set(*version_set_id).to_string() } ProblemEdge::Conflict(ConflictCause::ForbidMultipleInstances) @@ -651,7 +651,7 @@ impl<'i, I: Interner> DisplayUnsat<'i, I> { top_level_indent: bool, ) -> fmt::Result { pub enum DisplayOp { - Requirement(VersionSetId, Vec), + Requirement(Requirement, Vec), Candidate(NodeIndex), } @@ -693,7 +693,7 @@ impl<'i, I: Interner> DisplayUnsat<'i, I> { let indent = indenter.get_indent(); match node { - DisplayOp::Requirement(version_set_id, edges) => { + DisplayOp::Requirement(requirement, edges) => { debug_assert!(!edges.is_empty()); let installable = edges.iter().any(|&e| { @@ -701,12 +701,7 @@ impl<'i, I: Interner> DisplayUnsat<'i, I> { installable_nodes.contains(&target) }); - let req = self - .interner - .display_version_set(version_set_id) - .to_string(); - let name = self.interner.version_set_name(version_set_id); - let name = self.interner.display_name(name).to_string(); + let req = requirement.display(self.interner).to_string(); let target_nx = graph.edge_endpoints(edges[0]).unwrap().1; let missing = @@ -714,22 +709,19 @@ impl<'i, I: Interner> DisplayUnsat<'i, I> { if missing { // No candidates for requirement if top_level { - writeln!(f, "{indent}No candidates were found for {name} {req}.")?; + writeln!(f, "{indent}No candidates were found for {req}.")?; } else { - writeln!( - f, - "{indent}{name} {req}, for which no candidates were found.", - )?; + writeln!(f, "{indent}{req}, for which no candidates were found.",)?; } } else if installable { // Package can be installed (only mentioned for top-level requirements) if top_level { writeln!( f, - "{indent}{name} {req} can be installed with any of the following options:" + "{indent}{req} can be installed with any of the following options:" )?; } else { - writeln!(f, "{indent}{name} {req}, which can be installed with any of the following options:")?; + writeln!(f, "{indent}{req}, which can be installed with any of the following options:")?; } let children: Vec<_> = edges @@ -780,9 +772,9 @@ impl<'i, I: Interner> DisplayUnsat<'i, I> { // Package cannot be installed (the conflicting requirement is further down // the tree) if top_level { - writeln!(f, "{indent}{name} {req} cannot be installed because there are no viable options:")?; + writeln!(f, "{indent}{req} cannot be installed because there are no viable options:")?; } else { - writeln!(f, "{indent}{name} {req}, which cannot be installed because there are no viable options:")?; + writeln!(f, "{indent}{req}, which cannot be installed because there are no viable options:")?; } let children: Vec<_> = edges diff --git a/src/requirement.rs b/src/requirement.rs new file mode 100644 index 0000000..10b9e46 --- /dev/null +++ b/src/requirement.rs @@ -0,0 +1,91 @@ +use crate::{Interner, VersionSetId, VersionSetUnionId}; +use itertools::Itertools; +use std::fmt::Display; + +/// Specifies the dependency of a solvable on a set of version sets. +#[derive(Clone, Copy, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +pub enum Requirement { + /// Specifies a dependency on a single version set. + Single(VersionSetId), + /// Specifies a dependency on the union (logical OR) of multiple version sets. A solvable + /// belonging to _any_ of the version sets contained in the union satisfies the requirement. + /// This variant is typically used for requirements that can be satisfied by two or more + /// version sets belonging to _different_ packages. + Union(VersionSetUnionId), +} + +impl Default for Requirement { + fn default() -> Self { + Self::Single(Default::default()) + } +} + +impl From for Requirement { + fn from(value: VersionSetId) -> Self { + Requirement::Single(value) + } +} + +impl From for Requirement { + fn from(value: VersionSetUnionId) -> Self { + Requirement::Union(value) + } +} + +impl Requirement { + pub(crate) fn display<'i>(&'i self, interner: &'i impl Interner) -> impl Display + '_ { + DisplayRequirement { + interner, + requirement: self, + } + } + + pub(crate) fn version_sets<'i>( + &'i self, + interner: &'i impl Interner, + ) -> impl Iterator + 'i { + match *self { + Requirement::Single(version_set) => { + itertools::Either::Left(std::iter::once(version_set)) + } + Requirement::Union(version_set_union) => { + itertools::Either::Right(interner.version_sets_in_union(version_set_union)) + } + } + } +} + +pub(crate) struct DisplayRequirement<'i, I: Interner> { + interner: &'i I, + requirement: &'i Requirement, +} + +impl<'i, I: Interner> Display for DisplayRequirement<'i, I> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match *self.requirement { + Requirement::Single(version_set) => write!( + f, + "{} {}", + self.interner + .display_name(self.interner.version_set_name(version_set)), + self.interner.display_version_set(version_set) + ), + Requirement::Union(version_set_union) => { + let formatted_version_sets = self + .interner + .version_sets_in_union(version_set_union) + .format_with(" | ", |version_set, f| { + f(&format_args!( + "{} {}", + self.interner + .display_name(self.interner.version_set_name(version_set)), + self.interner.display_version_set(version_set) + )) + }); + + write!(f, "{}", formatted_version_sets) + } + } + } +} diff --git a/src/snapshot.rs b/src/snapshot.rs index 279759d..08b3fab 100644 --- a/src/snapshot.rs +++ b/src/snapshot.rs @@ -16,7 +16,7 @@ use futures::FutureExt; use crate::{ internal::arena::ArenaId, Candidates, Dependencies, DependencyProvider, Interner, Mapping, - NameId, SolvableId, SolverCache, StringId, VersionSetId, + NameId, Requirement, SolvableId, SolverCache, StringId, VersionSetId, VersionSetUnionId, }; /// A single solvable in a [`DependencySnapshot`]. @@ -84,12 +84,19 @@ pub struct DependencySnapshot { )] pub solvables: Mapping, - /// All the requirements in the snapshot + /// All the version set unions in the snapshot #[cfg_attr( feature = "serde", serde(default, skip_serializing_if = "Mapping::is_empty") )] - pub requirements: Mapping, + pub version_set_unions: Mapping>, + + /// All the version sets in the snapshot + #[cfg_attr( + feature = "serde", + serde(default, skip_serializing_if = "Mapping::is_empty") + )] + pub version_sets: Mapping, /// All the packages in the snapshot #[cfg_attr( @@ -151,7 +158,8 @@ impl DependencySnapshot { let mut result = Self { solvables: Mapping::new(), - requirements: Mapping::new(), + version_set_unions: Mapping::new(), + version_sets: Mapping::new(), packages: Mapping::new(), strings: Mapping::new(), }; @@ -206,11 +214,37 @@ impl DependencySnapshot { } } Dependencies::Known(deps) => { - for &dep in deps.requirements.iter().chain(deps.constrains.iter()) { + for &dep in deps.constrains.iter() { if seen.insert(Element::VersionSet(dep)) { queue.push_back(Element::VersionSet(dep)); } } + + for &requirement in deps.requirements.iter() { + match requirement { + Requirement::Single(version_set) => { + if seen.insert(Element::VersionSet(version_set)) { + queue.push_back(Element::VersionSet(version_set)); + } + } + Requirement::Union(version_set_union_id) => { + let version_sets: HashSet<_> = cache + .provider() + .version_sets_in_union(version_set_union_id) + .collect(); + + for &version_set in version_sets.iter() { + if seen.insert(Element::VersionSet(version_set)) { + queue.push_back(Element::VersionSet(version_set)); + } + } + + result + .version_set_unions + .insert(version_set_union_id, version_sets); + } + } + } } } @@ -255,7 +289,7 @@ impl DependencySnapshot { matching_candidates: matching_candidates.iter().copied().collect(), }; - result.requirements.insert(version_set_id, version_set); + result.version_sets.insert(version_set_id, version_set); } } } @@ -323,7 +357,7 @@ impl<'s> SnapshotProvider<'s> { /// Adds another requirement that matches any version of a package pub fn add_package_requirement(&mut self, name: NameId) -> VersionSetId { - let id = self.snapshot.requirements.max() + self.additional_version_sets.len(); + let id = self.snapshot.version_sets.max() + self.additional_version_sets.len(); let package = self.package(name); @@ -360,12 +394,12 @@ impl<'s> SnapshotProvider<'s> { fn version_set(&self, version_set: VersionSetId) -> &VersionSet { let idx = version_set.to_usize(); - let max_idx = self.snapshot.requirements.max(); + let max_idx = self.snapshot.version_sets.max(); if idx >= max_idx { &self.additional_version_sets[idx - max_idx] } else { self.snapshot - .requirements + .version_sets .get(version_set) .expect("missing version set") } @@ -396,6 +430,18 @@ impl<'s> Interner for SnapshotProvider<'s> { fn solvable_name(&self, solvable: SolvableId) -> NameId { self.solvable(solvable).name } + + fn version_sets_in_union( + &self, + version_set_union_id: VersionSetUnionId, + ) -> impl Iterator { + self.snapshot + .version_set_unions + .get(version_set_union_id) + .expect("missing constraint") + .iter() + .copied() + } } impl<'s> DependencyProvider for SnapshotProvider<'s> { diff --git a/src/solver/cache.rs b/src/solver/cache.rs index cd4cc9b..e66a950 100644 --- a/src/solver/cache.rs +++ b/src/solver/cache.rs @@ -11,7 +11,7 @@ use crate::{ frozen_copy_map::FrozenCopyMap, id::{CandidatesId, DependenciesId}, }, - Candidates, Dependencies, DependencyProvider, NameId, SolvableId, VersionSetId, + Candidates, Dependencies, DependencyProvider, NameId, Requirement, SolvableId, VersionSetId, }; /// Keeps a cache of previously computed and/or requested information about @@ -32,10 +32,10 @@ pub struct SolverCache { /// included). version_set_inverse_candidates: FrozenMap, ahash::RandomState>, - /// A mapping of `VersionSetId` to a sorted list of candidates that match - /// that set. - pub(crate) version_set_to_sorted_candidates: - FrozenMap, ahash::RandomState>, + /// A mapping of [`Requirement`] to a sorted list of candidates that fulfill + /// that requirement. + pub(crate) requirement_to_sorted_candidates: + FrozenMap, ahash::RandomState>, /// A mapping from a solvable to a list of dependencies solvable_dependencies: Arena, @@ -59,7 +59,7 @@ impl SolverCache { package_name_to_candidates_in_flight: Default::default(), version_set_candidates: Default::default(), version_set_inverse_candidates: Default::default(), - version_set_to_sorted_candidates: Default::default(), + requirement_to_sorted_candidates: Default::default(), solvable_dependencies: Default::default(), solvable_to_dependencies: Default::default(), hint_dependencies_available: Default::default(), @@ -242,50 +242,89 @@ impl SolverCache { } } - /// Returns the candidates for the package with the given name similar to - /// [`Self::get_or_cache_candidates`] sorted from highest to lowest. + /// Returns the candidates fulfilling the [`Requirement`] sorted from highest to lowest + /// within each version set comprising the [`Requirement`]. /// /// If the provider has requested the solving process to be cancelled, the /// cancellation value will be returned as an `Err(...)`. pub async fn get_or_cache_sorted_candidates( &self, - version_set_id: VersionSetId, + requirement: Requirement, ) -> Result<&[SolvableId], Box> { - match self.version_set_to_sorted_candidates.get(&version_set_id) { - Some(candidates) => Ok(candidates), - None => { - let package_name_id = self.provider.version_set_name(version_set_id); - tracing::trace!( - "Getting sorted matching candidates for package: {:?}", - self.provider.display_name(package_name_id).to_string() - ); - - let matching_candidates = self - .get_or_cache_matching_candidates(version_set_id) - .await?; - let candidates = self.get_or_cache_candidates(package_name_id).await?; - - // Sort all the candidates in order in which they should be tried by the solver. - let mut sorted_candidates = Vec::new(); - sorted_candidates.extend_from_slice(matching_candidates); - self.provider - .sort_candidates(self, &mut sorted_candidates) - .await; - - // If we have a solvable that we favor, we sort that to the front. This ensures - // that the version that is favored is picked first. - if let Some(favored_id) = candidates.favored { - if let Some(pos) = sorted_candidates.iter().position(|&s| s == favored_id) { - // Move the element at `pos` to the front of the array - sorted_candidates[0..=pos].rotate_right(1); + match requirement { + Requirement::Single(version_set_id) => { + self.get_or_cache_sorted_candidates_for_version_set(version_set_id) + .await + } + Requirement::Union(version_set_union_id) => { + match self.requirement_to_sorted_candidates.get(&requirement) { + Some(candidates) => Ok(candidates), + None => { + let sorted_candidates = futures::future::try_join_all( + self.provider() + .version_sets_in_union(version_set_union_id) + .map(|version_set_id| { + self.get_or_cache_sorted_candidates_for_version_set( + version_set_id, + ) + }), + ) + .await? + .into_iter() + .flatten() + .copied() + .collect(); + + Ok(self + .requirement_to_sorted_candidates + .insert(requirement, sorted_candidates)) } } + } + } + } - Ok(self - .version_set_to_sorted_candidates - .insert(version_set_id, sorted_candidates)) + /// Returns the sorted candidates for a singular version set requirement + /// (akin to a [`Requirement::Single`]). + async fn get_or_cache_sorted_candidates_for_version_set( + &self, + version_set_id: VersionSetId, + ) -> Result<&[SolvableId], Box> { + let requirement = version_set_id.into(); + if let Some(candidates) = self.requirement_to_sorted_candidates.get(&requirement) { + return Ok(candidates); + } + + let package_name_id = self.provider.version_set_name(version_set_id); + tracing::trace!( + "Getting sorted matching candidates for package: {:?}", + self.provider.display_name(package_name_id).to_string() + ); + + let matching_candidates = self + .get_or_cache_matching_candidates(version_set_id) + .await?; + let candidates = self.get_or_cache_candidates(package_name_id).await?; + + // Sort all the candidates in order in which they should be tried by the solver. + let mut sorted_candidates = Vec::with_capacity(matching_candidates.len()); + sorted_candidates.extend_from_slice(matching_candidates); + self.provider + .sort_candidates(self, &mut sorted_candidates) + .await; + + // If we have a solvable that we favor, we sort that to the front. This ensures + // that the version that is favored is picked first. + if let Some(favored_id) = candidates.favored { + if let Some(pos) = sorted_candidates.iter().position(|&s| s == favored_id) { + // Move the element at `pos` to the front of the array + sorted_candidates[0..=pos].rotate_right(1); } } + + Ok(self + .requirement_to_sorted_candidates + .insert(requirement, sorted_candidates)) } /// Returns the dependencies of a solvable. Requests the solvables from the diff --git a/src/solver/clause.rs b/src/solver/clause.rs index 8ea3b99..1476c33 100644 --- a/src/solver/clause.rs +++ b/src/solver/clause.rs @@ -11,7 +11,7 @@ use crate::{ id::{ClauseId, InternalSolvableId, LearntClauseId, StringId, VersionSetId}, }, solver::{decision_map::DecisionMap, decision_tracker::DecisionTracker}, - Interner, NameId, SolvableId, + Interner, NameId, Requirement, SolvableId, }; /// Represents a single clause in the SAT problem @@ -47,11 +47,11 @@ pub(crate) enum Clause { /// /// In SAT terms: (root) InstallRoot, - /// The solvable requires the candidates associated with the version set + /// Makes the solvable require the candidates associated with the [`Requirement`]. /// /// In SAT terms: (¬A ∨ B1 ∨ B2 ∨ ... ∨ B99), where B1 to B99 represent the - /// possible candidates for the provided version set - Requires(InternalSolvableId, VersionSetId), + /// possible candidates for the provided [`Requirement`]. + Requires(InternalSolvableId, Requirement), /// Ensures only a single version of a package is installed /// /// Usage: generate one [`Clause::ForbidMultipleInstances`] clause for each @@ -108,7 +108,7 @@ impl Clause { /// added dynamically. fn requires( parent: InternalSolvableId, - requirement: VersionSetId, + requirement: Requirement, candidates: &[SolvableId], decision_tracker: &DecisionTracker, ) -> (Self, Option<[InternalSolvableId; 2]>, bool) { @@ -228,8 +228,8 @@ impl Clause { pub fn visit_literals( &self, learnt_clauses: &Arena>, - version_set_to_sorted_candidates: &FrozenMap< - VersionSetId, + requirements_to_sorted_candidates: &FrozenMap< + Requirement, Vec, ahash::RandomState, >, @@ -254,7 +254,7 @@ impl Clause { negate: true, }); - for &solvable_id in &version_set_to_sorted_candidates[&match_spec_id] { + for &solvable_id in &requirements_to_sorted_candidates[&match_spec_id] { visit(Literal { solvable_id: solvable_id.into(), negate: false, @@ -320,7 +320,7 @@ impl ClauseState { /// conflict. pub fn requires( candidate: InternalSolvableId, - requirement: VersionSetId, + requirement: Requirement, matching_candidates: &[SolvableId], decision_tracker: &DecisionTracker, ) -> (Self, bool) { @@ -519,8 +519,8 @@ impl ClauseState { pub fn next_unwatched_variable( &self, learnt_clauses: &Arena>, - version_set_to_sorted_candidates: &FrozenMap< - VersionSetId, + requirement_to_sorted_candidates: &FrozenMap< + Requirement, Vec, ahash::RandomState, >, @@ -543,7 +543,7 @@ impl ClauseState { .find(|&l| can_watch(l)) .map(|l| l.solvable_id), Clause::Constrains(..) | Clause::ForbidMultipleInstances(..) | Clause::Lock(..) => None, - Clause::Requires(solvable_id, version_set_id) => { + Clause::Requires(solvable_id, requirement) => { // The solvable that added this clause let solvable_lit = Literal { solvable_id, @@ -554,7 +554,7 @@ impl ClauseState { } // The available candidates - for &candidate in &version_set_to_sorted_candidates[&version_set_id] { + for &candidate in &requirement_to_sorted_candidates[&requirement] { let lit = Literal { solvable_id: candidate.into(), negate: false, @@ -622,13 +622,13 @@ impl<'i, I: Interner> Display for ClauseDisplay<'i, I> { ) } Clause::Learnt(learnt_id) => write!(f, "Learnt({learnt_id:?})"), - Clause::Requires(solvable_id, version_set_id) => { + Clause::Requires(solvable_id, requirement) => { write!( f, "Requires({}({:?}), {})", solvable_id.display(self.interner), solvable_id, - self.interner.display_version_set(version_set_id) + requirement.display(self.interner), ) } Clause::Constrains(s1, s2, version_set_id) => { @@ -850,7 +850,7 @@ mod test { // No conflict, all candidates available let (clause, conflict) = ClauseState::requires( parent, - VersionSetId::from_usize(0), + VersionSetId::from_usize(0).into(), &[candidate1, candidate2], &decisions, ); @@ -864,7 +864,7 @@ mod test { .unwrap(); let (clause, conflict) = ClauseState::requires( parent, - VersionSetId::from_usize(0), + VersionSetId::from_usize(0).into(), &[candidate1, candidate2], &decisions, ); @@ -878,7 +878,7 @@ mod test { .unwrap(); let (clause, conflict) = ClauseState::requires( parent, - VersionSetId::from_usize(0), + VersionSetId::from_usize(0).into(), &[candidate1, candidate2], &decisions, ); @@ -893,7 +893,7 @@ mod test { let panicked = std::panic::catch_unwind(|| { ClauseState::requires( parent, - VersionSetId::from_usize(0), + VersionSetId::from_usize(0).into(), &[candidate1, candidate2], &decisions, ) diff --git a/src/solver/mod.rs b/src/solver/mod.rs index 61fe85a..45205cc 100644 --- a/src/solver/mod.rs +++ b/src/solver/mod.rs @@ -3,7 +3,7 @@ use clause::{Clause, ClauseState, Literal}; use decision::Decision; use decision_tracker::DecisionTracker; use futures::{stream::FuturesUnordered, FutureExt, StreamExt}; -use itertools::{chain, Itertools}; +use itertools::Itertools; use std::fmt::Display; use std::{any::Any, cell::RefCell, collections::HashSet, future::ready, ops::ControlFlow}; use watch_map::WatchMap; @@ -16,7 +16,7 @@ use crate::{ }, problem::Problem, runtime::{AsyncRuntime, NowOrNeverRuntime}, - Candidates, Dependencies, DependencyProvider, KnownDependencies, VersionSetId, + Candidates, Dependencies, DependencyProvider, KnownDependencies, Requirement, VersionSetId, }; mod cache; @@ -28,7 +28,7 @@ mod watch_map; #[derive(Default)] struct AddClauseOutput { - new_requires_clauses: Vec<(InternalSolvableId, VersionSetId, ClauseId)>, + new_requires_clauses: Vec<(InternalSolvableId, Requirement, ClauseId)>, conflicting_clauses: Vec, negative_assertions: Vec<(InternalSolvableId, ClauseId)>, clauses_to_watch: Vec, @@ -40,7 +40,7 @@ pub struct Solver { pub(crate) cache: SolverCache, pub(crate) clauses: RefCell>, - requires_clauses: Vec<(InternalSolvableId, VersionSetId, ClauseId)>, + requires_clauses: Vec<(InternalSolvableId, Requirement, ClauseId)>, watches: WatchMap, negative_assertions: Vec<(InternalSolvableId, ClauseId)>, @@ -54,8 +54,8 @@ pub struct Solver { decision_tracker: DecisionTracker, - /// The version sets that must be installed as part of the solution. - root_requirements: Vec, + /// The [`Requirement`]s that must be installed as part of the solution. + root_requirements: Vec, /// Additional constraints imposed by the root. root_constraints: Vec, @@ -164,7 +164,7 @@ impl Solver { /// inspect the causes and report them to the user. pub fn solve( &mut self, - root_requirements: Vec, + root_requirements: Vec, root_constraints: Vec, ) -> Result, UnsolvableOrCancelled> { // Clear state @@ -231,7 +231,7 @@ impl Solver { }, SortedCandidates { solvable_id: InternalSolvableId, - version_set_id: VersionSetId, + requirement: Requirement, candidates: &'i [SolvableId], }, NonMatchingCandidates { @@ -335,7 +335,11 @@ impl Solver { } }; - for version_set_id in chain(requirements.iter(), constrains.iter()).copied() { + for version_set_id in requirements + .iter() + .flat_map(|requirement| requirement.version_sets(self.provider())) + .chain(constrains.iter().copied()) + { let dependency_name = self.provider().version_set_name(version_set_id); if clauses_added_for_package.insert(dependency_name) { tracing::trace!( @@ -357,17 +361,17 @@ impl Solver { } } - for version_set_id in requirements { + for requirement in requirements { // Find all the solvable that match for the given version set pending_futures.push( async move { let candidates = self .cache - .get_or_cache_sorted_candidates(version_set_id) + .get_or_cache_sorted_candidates(requirement) .await?; Ok(TaskResult::SortedCandidates { solvable_id, - version_set_id, + requirement, candidates, }) } @@ -467,14 +471,12 @@ impl Solver { } TaskResult::SortedCandidates { solvable_id, - version_set_id, + requirement, candidates, } => { tracing::trace!( - "Sorted candidates available for {} {}", - self.provider() - .display_name(self.provider().version_set_name(version_set_id)), - self.provider().display_version_set(version_set_id), + "Sorted candidates available for {}", + requirement.display(self.provider()), ); // Queue requesting the dependencies of the candidates as well if they are @@ -492,7 +494,7 @@ impl Solver { let no_candidates = candidates.is_empty(); let (clause, conflict) = ClauseState::requires( solvable_id, - version_set_id, + requirement, candidates, &self.decision_tracker, ); @@ -500,7 +502,7 @@ impl Solver { let clause_id = self.clauses.borrow_mut().alloc(clause); let clause = &self.clauses.borrow()[clause_id]; - let &Clause::Requires(solvable_id, version_set_id) = &clause.kind else { + let &Clause::Requires(solvable_id, requirement) = &clause.kind else { unreachable!(); }; @@ -510,7 +512,7 @@ impl Solver { output .new_requires_clauses - .push((solvable_id, version_set_id, clause_id)); + .push((solvable_id, requirement, clause_id)); if conflict { output.conflicting_clauses.push(clause_id); @@ -787,7 +789,7 @@ impl Solver { } // Consider only clauses in which no candidates have been installed - let candidates = &self.cache.version_set_to_sorted_candidates[&deps]; + let candidates = &self.cache.requirement_to_sorted_candidates[&deps]; // Either find the first assignable candidate or determine that one of the // candidates is already assigned in which case the clause has @@ -1091,7 +1093,7 @@ impl Solver { // One of the watched literals is now false if let Some(variable) = clause.next_unwatched_variable( &self.learnt_clauses, - &self.cache.version_set_to_sorted_candidates, + &self.cache.requirement_to_sorted_candidates, self.decision_tracker.map(), ) { debug_assert!(!clause.watched_literals.contains(&variable)); @@ -1202,7 +1204,7 @@ impl Solver { let mut involved = HashSet::new(); self.clauses.borrow()[clause_id].kind.visit_literals( &self.learnt_clauses, - &self.cache.version_set_to_sorted_candidates, + &self.cache.requirement_to_sorted_candidates, |literal| { involved.insert(literal.solvable_id); }, @@ -1240,7 +1242,7 @@ impl Solver { self.clauses.borrow()[why].kind.visit_literals( &self.learnt_clauses, - &self.cache.version_set_to_sorted_candidates, + &self.cache.requirement_to_sorted_candidates, |literal| { if literal.eval(self.decision_tracker.map()) == Some(true) { assert_eq!(literal.solvable_id, decision.solvable_id); @@ -1286,7 +1288,7 @@ impl Solver { self.clauses.borrow()[clause_id].kind.visit_literals( &self.learnt_clauses, - &self.cache.version_set_to_sorted_candidates, + &self.cache.requirement_to_sorted_candidates, |literal| { if !first_iteration && literal.solvable_id == conflicting_solvable { // We are only interested in the causes of the conflict, so we ignore the diff --git a/src/solver/watch_map.rs b/src/solver/watch_map.rs index cf17e17..51ee685 100644 --- a/src/solver/watch_map.rs +++ b/src/solver/watch_map.rs @@ -52,10 +52,7 @@ impl WatchMap { clause.watched_literals[watch_index] = new_watch; clause.link_to_clause( watch_index, - *self - .map - .get(new_watch) - .expect("linking to unknown solvable"), + self.map.get(new_watch).copied().unwrap_or(ClauseId::null()), ); self.map.insert(new_watch, clause_id); } diff --git a/src/utils/pool.rs b/src/utils/pool.rs index 7358fcf..5fffddd 100644 --- a/src/utils/pool.rs +++ b/src/utils/pool.rs @@ -6,7 +6,8 @@ use std::{ use crate::internal::{ arena::Arena, frozen_copy_map::FrozenCopyMap, - id::{NameId, SolvableId, StringId, VersionSetId}, + id::{NameId, SolvableId, StringId, VersionSetId, VersionSetUnionId}, + small_vec::SmallVec, }; /// A solvable represents a single candidate of a package. @@ -47,6 +48,8 @@ pub struct Pool { /// Map from version set to the id of their interned counterpart version_set_to_id: FrozenCopyMap<(NameId, VS), VersionSetId, ahash::RandomState>, + + version_set_unions: Arena>, } impl Default for Pool { @@ -61,6 +64,7 @@ impl Default for Pool { string_to_ids: Default::default(), version_set_to_id: Default::default(), version_sets: Arena::new(), + version_set_unions: Arena::new(), } } } @@ -182,6 +186,32 @@ impl Pool { pub fn resolve_version_set_package_name(&self, id: VersionSetId) -> NameId { self.version_sets[id].0 } + + /// Interns a union of two or more version sets and returns its [`VersionSetUnionId`]. + /// + /// Version set unions are *not* deduplicated, and a unique id is returned on every + /// invocation. + pub fn intern_version_set_union( + &self, + first: VersionSetId, + others: impl Iterator, + ) -> VersionSetUnionId { + self.version_set_unions + .alloc(others.fold(SmallVec::one(first), |mut vec, version_set| { + vec.push(version_set); + vec + })) + } + + /// Returns the version sets in the version set union with the given id. + /// + /// Panics if there is no union with the given id. + pub fn resolve_version_set_union( + &self, + id: VersionSetUnionId, + ) -> impl Iterator + '_ { + self.version_set_unions[id].iter().copied() + } } /// A helper struct to visualize a name. diff --git a/tests/snapshots/solver__snapshot.snap b/tests/snapshots/solver__snapshot.snap index 224e308..2edcdb6 100644 --- a/tests/snapshots/solver__snapshot.snap +++ b/tests/snapshots/solver__snapshot.snap @@ -1,6 +1,5 @@ --- source: tests/solver.rs -assertion_line: 1121 expression: "solve_for_snapshot(snapshot_provider, &[menu_req])" --- dropdown=2 diff --git a/tests/snapshots/solver__snapshot_union_requirements.snap b/tests/snapshots/solver__snapshot_union_requirements.snap new file mode 100644 index 0000000..16b5780 --- /dev/null +++ b/tests/snapshots/solver__snapshot_union_requirements.snap @@ -0,0 +1,6 @@ +--- +source: tests/solver.rs +expression: "solve_for_snapshot(snapshot_provider, &[intl_req, union_req])" +--- +intl=5 +union=1 diff --git a/tests/solver.rs b/tests/solver.rs index 0064255..f5708d0 100644 --- a/tests/solver.rs +++ b/tests/solver.rs @@ -1,5 +1,6 @@ use std::{ any::Any, + borrow::Borrow, cell::{Cell, RefCell}, collections::HashSet, fmt::{Debug, Display, Formatter}, @@ -21,8 +22,9 @@ use itertools::Itertools; use resolvo::{ snapshot::{DependencySnapshot, SnapshotProvider}, utils::{Pool, Range}, - Candidates, Dependencies, DependencyProvider, Interner, KnownDependencies, NameId, SolvableId, - Solver, SolverCache, StringId, UnsolvableOrCancelled, VersionSetId, + Candidates, Dependencies, DependencyProvider, Interner, KnownDependencies, NameId, Requirement, + SolvableId, Solver, SolverCache, StringId, UnsolvableOrCancelled, VersionSetId, + VersionSetUnionId, }; use tracing_test::traced_test; @@ -116,6 +118,14 @@ impl Spec { pub fn new(name: String, versions: Range) -> Self { Self { name, versions } } + + pub fn parse_union( + spec: &str, + ) -> impl Iterator::Err>> + '_ { + spec.split('|') + .map(str::trim) + .map(|dep| Spec::from_str(dep)) + } } impl FromStr for Spec { @@ -173,7 +183,7 @@ struct BundleBoxProvider { #[derive(Debug, Clone)] struct BundleBoxPackageDependencies { - dependencies: Vec, + dependencies: Vec>, constrains: Vec, } @@ -188,18 +198,42 @@ impl BundleBoxProvider { .expect("package missing") } - pub fn requirements(&self, requirements: &[&str]) -> Vec { + pub fn requirements>(&self, requirements: &[&str]) -> Vec { requirements .iter() .map(|dep| Spec::from_str(dep).unwrap()) - .map(|spec| { - let dep_name = self.pool.intern_package_name(&spec.name); - self.pool - .intern_version_set(dep_name, spec.versions.clone()) + .map(|spec| self.intern_version_set(&spec)) + .map(From::from) + .collect() + } + + pub fn parse_requirements(&self, requirements: &[&str]) -> Vec { + requirements + .iter() + .map(|deps| { + let specs = Spec::parse_union(deps).map(Result::unwrap); + self.intern_version_set_union(specs).into() }) .collect() } + pub fn intern_version_set(&self, spec: &Spec) -> VersionSetId { + let dep_name = self.pool.intern_package_name(&spec.name); + self.pool + .intern_version_set(dep_name, spec.versions.clone()) + } + + pub fn intern_version_set_union( + &self, + specs: impl IntoIterator>, + ) -> VersionSetUnionId { + let mut specs = specs + .into_iter() + .map(|spec| self.intern_version_set(spec.borrow())); + self.pool + .intern_version_set_union(specs.next().unwrap(), specs) + } + pub fn from_packages(packages: &[(&str, u32, Vec<&str>)]) -> Self { let mut result = Self::new(); for (name, version, deps) in packages { @@ -236,7 +270,7 @@ impl BundleBoxProvider { let dependencies = dependencies .iter() - .map(|dep| Spec::from_str(dep)) + .map(|dep| Spec::parse_union(dep).collect()) .collect::, _>>() .unwrap(); @@ -319,6 +353,12 @@ impl Interner for BundleBoxProvider { fn solvable_name(&self, solvable: SolvableId) -> NameId { self.pool.resolve_solvable(solvable).name } + fn version_sets_in_union( + &self, + version_set_union: VersionSetUnionId, + ) -> impl Iterator { + self.pool.resolve_version_set_union(version_set_union) + } } impl DependencyProvider for BundleBoxProvider { @@ -436,9 +476,33 @@ impl DependencyProvider for BundleBoxProvider { constrains: Vec::with_capacity(deps.constrains.len()), }; for req in &deps.dependencies { - let dep_name = self.pool.intern_package_name(&req.name); - let dep_spec = self.pool.intern_version_set(dep_name, req.versions.clone()); - result.requirements.push(dep_spec); + let mut remaining_req_specs = req.iter(); + + let first = remaining_req_specs + .next() + .expect("Dependency spec must have at least one constraint"); + + let first_name = self.pool.intern_package_name(&first.name); + let first_version_set = self + .pool + .intern_version_set(first_name, first.versions.clone()); + + let requirement = if remaining_req_specs.len() == 0 { + first_version_set.into() + } else { + let other_version_sets = remaining_req_specs.map(|spec| { + self.pool.intern_version_set( + self.pool.intern_package_name(&spec.name), + spec.versions.clone(), + ) + }); + + self.pool + .intern_version_set_union(first_version_set, other_version_sets) + .into() + }; + + result.requirements.push(requirement); } for req in &deps.constrains { @@ -509,7 +573,7 @@ fn solve_snapshot(mut provider: BundleBoxProvider, specs: &[&str]) -> String { provider.sleep_before_return = true; - let requirements = provider.requirements(specs); + let requirements = provider.parse_requirements(specs); let mut solver = Solver::new(provider).with_runtime(runtime); match solver.solve(requirements, Vec::new()) { Ok(solvables) => transaction_to_string(solver.provider(), &solvables), @@ -833,6 +897,28 @@ fn test_resolve_cyclic() { "###); } +#[test] +fn test_resolve_union_requirements() { + let mut provider = BundleBoxProvider::from_packages(&[ + ("a", 1, vec![]), + ("b", 1, vec![]), + ("c", 1, vec!["a"]), + ("d", 1, vec!["b"]), + ("e", 1, vec!["a | b"]), + ]); + + // Make d conflict with a=1 + provider.add_package("f", 1.into(), &["b"], &["a 2"]); + + let result = solve_snapshot(provider, &["c | d", "e", "f"]); + assert_snapshot!(result, @r###" + b=1 + d=1 + e=1 + f=1 + "###); +} + #[test] fn test_unsat_locked_and_excluded() { let mut provider = BundleBoxProvider::from_packages(&[ @@ -1121,6 +1207,32 @@ fn test_snapshot() { assert_snapshot!(solve_for_snapshot(snapshot_provider, &[menu_req])); } +#[test] +fn test_snapshot_union_requirements() { + let provider = BundleBoxProvider::from_packages(&[ + ("icons", 2, vec![]), + ("icons", 1, vec![]), + ("intl", 5, vec![]), + ("intl", 3, vec![]), + ("union", 1, vec!["icons 2 | intl"]), + ]); + + let intl_name_id = provider.package_name("intl"); + let union_name_id = provider.package_name("union"); + + let snapshot = provider.into_snapshot(); + + let mut snapshot_provider = snapshot.provider(); + + let intl_req = snapshot_provider.add_package_requirement(intl_name_id); + let union_req = snapshot_provider.add_package_requirement(union_name_id); + + assert_snapshot!(solve_for_snapshot( + snapshot_provider, + &[intl_req, union_req] + )); +} + #[cfg(feature = "serde")] fn serialize_snapshot(snapshot: &DependencySnapshot, destination: impl AsRef) { let file = std::io::BufWriter::new(std::fs::File::create(destination.as_ref()).unwrap()); @@ -1129,7 +1241,10 @@ fn serialize_snapshot(snapshot: &DependencySnapshot, destination: impl AsRef String { let mut solver = Solver::new(provider); - match solver.solve(root_reqs.to_vec(), Vec::new()) { + match solver.solve( + root_reqs.iter().copied().map(Into::into).collect(), + Vec::new(), + ) { Ok(solvables) => transaction_to_string(solver.provider(), &solvables), Err(UnsolvableOrCancelled::Unsolvable(problem)) => { // Write the problem graphviz to stderr diff --git a/tools/solve-snapshot/src/main.rs b/tools/solve-snapshot/src/main.rs index 0ae50d2..a58e1e4 100644 --- a/tools/solve-snapshot/src/main.rs +++ b/tools/solve-snapshot/src/main.rs @@ -50,7 +50,7 @@ fn main() { let mut solver = Solver::new(provider); let mut records = None; let mut error = None; - match solver.solve(vec![package_requirement], vec![]) { + match solver.solve(vec![package_requirement.into()], vec![]) { Ok(solution) => { eprintln!("OK"); records = Some(solution.len())