Skip to content

Commit

Permalink
test: add snapshot tests to verify solver sorting order (#895)
Browse files Browse the repository at this point in the history
  • Loading branch information
baszalmstra authored Oct 7, 2024
1 parent 7d692da commit 2e2cf94
Show file tree
Hide file tree
Showing 8 changed files with 227 additions and 10 deletions.
10 changes: 5 additions & 5 deletions crates/rattler_conda_types/src/version_spec/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -500,20 +500,20 @@ mod tests {
#[test]
fn issue_bracket_printing() {
let v = VersionSpec::from_str("(>=1,<2)|>3", ParseStrictness::Lenient).unwrap();
assert_eq!(format!("{}", v), ">=1,<2|>3");
assert_eq!(format!("{v}"), ">=1,<2|>3");

let v = VersionSpec::from_str("(>=1|<2),>3", ParseStrictness::Lenient).unwrap();
assert_eq!(format!("{}", v), "(>=1|<2),>3");
assert_eq!(format!("{v}"), "(>=1|<2),>3");

let v = VersionSpec::from_str("(>=1|<2)|>3", ParseStrictness::Lenient).unwrap();
assert_eq!(format!("{}", v), ">=1|<2|>3");
assert_eq!(format!("{v}"), ">=1|<2|>3");

let v = VersionSpec::from_str("(>=1,<2),>3", ParseStrictness::Lenient).unwrap();
assert_eq!(format!("{}", v), ">=1,<2,>3");
assert_eq!(format!("{v}"), ">=1,<2,>3");

let v =
VersionSpec::from_str("((>=1|>2),(>3|>4))|(>5,<6)", ParseStrictness::Lenient).unwrap();
assert_eq!(format!("{}", v), "(>=1|>2),(>3|>4)|>5,<6");
assert_eq!(format!("{v}"), "(>=1|>2),(>3|>4)|>5,<6");
}

#[test]
Expand Down
11 changes: 8 additions & 3 deletions crates/rattler_solve/src/resolvo/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ impl<'a> SolverRepoData<'a> for RepoData<'a> {}
/// Wrapper around `MatchSpec` so that we can use it in the `resolvo` pool
#[repr(transparent)]
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
struct SolverMatchSpec<'a> {
pub struct SolverMatchSpec<'a> {
inner: NamelessMatchSpec,
_marker: PhantomData<&'a PackageRecord>,
}
Expand Down Expand Up @@ -84,8 +84,11 @@ impl<'a> VersionSet for SolverMatchSpec<'a> {

/// Wrapper around [`PackageRecord`] so that we can use it in resolvo pool
#[derive(Eq, PartialEq)]
enum SolverPackageRecord<'a> {
pub enum SolverPackageRecord<'a> {
/// Represents a record from the repodata
Record(&'a RepoDataRecord),

/// Represents a virtual package.
VirtualPackage(&'a GenericVirtualPackage),
}

Expand Down Expand Up @@ -161,7 +164,8 @@ impl<'a> Display for SolverPackageRecord<'a> {
/// packages.
#[derive(Default)]
pub struct CondaDependencyProvider<'a> {
pool: Pool<SolverMatchSpec<'a>, String>,
/// The pool that deduplicates data used by the provider.
pub pool: Pool<SolverMatchSpec<'a>, String>,

records: HashMap<NameId, Candidates>,

Expand Down Expand Up @@ -512,6 +516,7 @@ impl<'a> DependencyProvider for CondaDependencyProvider<'a> {
return Dependencies::Unknown(reason);
}
};

dependencies.requirements.push(version_set_id.into());
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
---
source: crates/rattler_solve/tests/sorting.rs
expression: create_sorting_snapshot(package_name)
---
abess=0.4.5=py39hac2352c_1
abess=0.4.5=py37h48bf904_1
abess=0.4.5=py38h514daf8_1
abess=0.4.5=py310hc4a4660_1
abess=0.4.5=py37h48bf904_0
abess=0.4.5=py38h514daf8_0
abess=0.4.5=py39hac2352c_0
abess=0.4.4=py38h514daf8_1
abess=0.4.4=py37h48bf904_1
abess=0.4.4=py39hac2352c_1
abess=0.4.4=py38h514daf8_0
abess=0.4.4=py39hac2352c_0
abess=0.4.4=py37h48bf904_0
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
source: crates/rattler_solve/tests/sorting.rs
expression: create_sorting_snapshot(package_name)
---
libuuid=2.32.1=h7f98852_1000
libuuid=2.32.1=h14c3975_1000
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
---
source: crates/rattler_solve/tests/sorting.rs
expression: create_sorting_snapshot(spec)
---
python=3.10.5=ha86cf86_0_cpython
python=3.10.5=h582c2e5_0_cpython
python=3.10.4=h2660328_0_cpython
python=3.10.4=h9a8a25e_0_cpython
python=3.10.2=hc74c709_4_cpython
python=3.10.2=h85951f9_4_cpython
python=3.10.2=hc74c709_3_cpython
python=3.10.2=h85951f9_3_cpython
python=3.10.2=h85951f9_2_cpython
python=3.10.2=hc74c709_2_cpython
python=3.10.2=h85951f9_1_cpython
python=3.10.2=hc74c709_1_cpython
python=3.10.2=h543edf9_0_cpython
python=3.10.2=h62f1059_0_cpython
python=3.10.1=h543edf9_2_cpython
python=3.10.1=h62f1059_2_cpython
python=3.10.1=h62f1059_1_cpython
python=3.10.1=h543edf9_1_cpython
python=3.10.1=h543edf9_0_cpython
python=3.10.1=h62f1059_0_cpython
python=3.10.0=h62f1059_3_cpython
python=3.10.0=h543edf9_3_cpython
python=3.10.0=h543edf9_2_cpython
python=3.10.0=h62f1059_2_cpython
python=3.10.0=h543edf9_1_cpython
python=3.10.0=h62f1059_1_cpython
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
---
source: crates/rattler_solve/tests/sorting.rs
expression: create_sorting_snapshot(spec)
---
pytorch=1.12.1=cuda110py39h0a9da28_200
pytorch=1.12.1=cuda110py37h0def887_200
pytorch=1.12.1=cuda110py38h386aa8f_200
pytorch=1.12.1=cuda102py37haad9b4f_200
pytorch=1.12.1=cuda110py310hfdf97d1_200
pytorch=1.12.1=cuda111py310h385535d_200
pytorch=1.12.1=cuda111py37hdb2541a_200
pytorch=1.12.1=cuda112py310h51fe464_200
pytorch=1.12.1=cuda112py38habe9d5a_200
pytorch=1.12.1=cuda102py38hfdb21e3_200
pytorch=1.12.1=cuda111py38h2d04dd0_200
pytorch=1.12.1=cuda111py39h9f128c5_200
pytorch=1.12.1=cuda112py39ha0cca9b_200
pytorch=1.12.1=cuda112py37hfcfbd4c_200
pytorch=1.12.1=cuda102py39hbbcd3cb_200
pytorch=1.12.1=cuda102py310hdf4a2db_200
pytorch=1.12.1=cpu_py310h75c9ab6_0
pytorch=1.12.1=cpu_py39h5d22d69_0
pytorch=1.12.1=cpu_py38h39c826d_0
pytorch=1.12.1=cpu_py37h14e09b7_0
pytorch=1.12.0=cuda111py310h385535d_202
pytorch=1.12.0=cuda112py310h51fe464_202
pytorch=1.12.0=cuda110py39h0a9da28_202
pytorch=1.12.0=cuda111py38h2d04dd0_202
pytorch=1.12.0=cuda112py39ha0cca9b_202
pytorch=1.12.0=cuda102py39hbbcd3cb_202
pytorch=1.12.0=cuda111py39h9f128c5_202
pytorch=1.12.0=cuda110py310hfdf97d1_202
pytorch=1.12.0=cuda102py37haad9b4f_202
pytorch=1.12.0=cuda112py37hfcfbd4c_202
pytorch=1.12.0=cuda111py37hdb2541a_202
pytorch=1.12.0=cuda110py37h0def887_202
pytorch=1.12.0=cuda110py38h386aa8f_202
pytorch=1.12.0=cuda102py310hdf4a2db_202
pytorch=1.12.0=cuda112py38habe9d5a_202
pytorch=1.12.0=cuda102py38hfdb21e3_202
pytorch=1.12.0=cuda111py310h385535d_200
pytorch=1.12.0=cuda112py310h51fe464_200
pytorch=1.12.0=cuda110py39h0a9da28_200
pytorch=1.12.0=cuda111py38h2d04dd0_200
pytorch=1.12.0=cuda112py39ha0cca9b_200
pytorch=1.12.0=cuda102py39hbbcd3cb_200
pytorch=1.12.0=cuda111py39h9f128c5_200
pytorch=1.12.0=cuda110py310hfdf97d1_200
pytorch=1.12.0=cuda102py37haad9b4f_200
pytorch=1.12.0=cuda112py37hfcfbd4c_200
pytorch=1.12.0=cuda111py37hdb2541a_200
pytorch=1.12.0=cuda110py37h0def887_200
pytorch=1.12.0=cuda110py38h386aa8f_200
pytorch=1.12.0=cuda102py310hdf4a2db_200
pytorch=1.12.0=cuda112py38habe9d5a_200
pytorch=1.12.0=cuda102py38hfdb21e3_200
pytorch=1.12.0=cpu_py39h5d22d69_2
pytorch=1.12.0=cpu_py38h39c826d_2
pytorch=1.12.0=cpu_py310h75c9ab6_2
pytorch=1.12.0=cpu_py37h14e09b7_2
pytorch=1.12.0=cpu_py38h39c826d_1
pytorch=1.12.0=cpu_py310h75c9ab6_1
pytorch=1.12.0=cpu_py37h14e09b7_1
pytorch=1.12.0=cpu_py39h5d22d69_1
pytorch=1.12.0=cpu_py37h14e09b7_0
pytorch=1.12.0=cpu_py38h39c826d_0
pytorch=1.12.0=cpu_py310h75c9ab6_0
pytorch=1.12.0=cpu_py39h5d22d69_0
91 changes: 91 additions & 0 deletions crates/rattler_solve/tests/sorting.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
//! Tests that the sorting of candidates remains the same.
use std::path::Path;

use futures::FutureExt;
use itertools::Itertools;
use rattler_conda_types::{
Channel, MatchSpec, Matches, PackageName, ParseStrictness::Lenient, RepoDataRecord,
};
use rattler_repodata_gateway::sparse::SparseRepoData;
use rattler_solve::{resolvo::CondaDependencyProvider, ChannelPriority, SolveStrategy};
use resolvo::{Interner, SolverCache};
use rstest::*;

fn load_repodata(package_name: &PackageName) -> Vec<RepoDataRecord> {
let channel_path = Path::new(env!("CARGO_MANIFEST_DIR"))
.join("../..")
.join("test-data")
.join("channels")
.join("conda-forge");
let repodata_json_path = channel_path.join("linux-64").join("repodata.json");
let channel = Channel::from_directory(&channel_path);

let sparse_repo_data = SparseRepoData::new(channel, "linux-64", repodata_json_path, None)
.expect("failed to load sparse repodata");
sparse_repo_data
.load_records(package_name)
.expect("failed to load records")
}

fn create_sorting_snapshot(package_name: &str) -> String {
let match_spec = MatchSpec::from_str(package_name, Lenient).unwrap();
let package_name = match_spec.name.clone().unwrap();

// Load repodata
let mut repodata = load_repodata(&package_name);
repodata.retain(|record| match_spec.matches(record));

// Construct dependency provider
let dependency_provider = CondaDependencyProvider::new(
[repodata.iter().collect()],
&[],
&[],
&[],
&[],
None,
ChannelPriority::default(),
None,
SolveStrategy::default(),
)
.expect("failed to create dependency provider");

let name = dependency_provider
.pool
.intern_package_name(package_name.as_normalized());
let version_set = dependency_provider
.pool
.intern_version_set(name, match_spec.into_nameless().1.into());

// Construct a cache
let cache = SolverCache::new(dependency_provider);

// Get the candidates for the package
let sorted_candidates = cache
.get_or_cache_sorted_candidates(version_set.into())
.now_or_never()
.expect("failed to get candidates")
.expect("solver requested cancellation");

sorted_candidates
.iter()
.map(|&candidate| cache.provider().display_solvable(candidate))
.format("\n")
.to_string()
}

#[rstest]
#[case::pytorch("pytorch >=1.12.0")]
#[case::python("python ~=3.10.*")]
#[case::libuuid("libuuid")]
#[case::abess("abess")]
// #[case::certifi("certifi >=2016.9.26")]
fn test_ordering(#[case] spec: &str) {
insta::assert_snapshot!(
format!(
"test_ordering_{}",
spec.split_whitespace().next().unwrap_or(spec)
),
create_sorting_snapshot(spec)
);
}
4 changes: 2 additions & 2 deletions crates/rattler_virtual_packages/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -609,11 +609,11 @@ mod test {

use rattler_conda_types::Version;

use crate::{Cuda, EnvOverride, LibC, Osx, Override, VirtualPackage};
use crate::{Cuda, EnvOverride, LibC, Osx, Override, VirtualPackage, VirtualPackageOverrides};

#[test]
fn doesnt_crash() {
let virtual_packages = VirtualPackage::detect(&Default::default()).unwrap();
let virtual_packages = VirtualPackage::detect(&VirtualPackageOverrides::default()).unwrap();
println!("{virtual_packages:?}");
}
#[test]
Expand Down

0 comments on commit 2e2cf94

Please sign in to comment.