From 5484a88e64a3c8475c625ddcb8961b4171ac6484 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Josef=20Kemetm=C3=BCller?= Date: Fri, 10 Nov 2023 10:43:27 +0100 Subject: [PATCH] Attempt supporting multi-file version constraints --- conda_lock/models/lock_spec.py | 109 ++++++++++++++++++++++++++- conda_lock/src_parser/aggregation.py | 2 +- tests/test_conda_lock.py | 20 +++-- 3 files changed, 122 insertions(+), 9 deletions(-) diff --git a/conda_lock/models/lock_spec.py b/conda_lock/models/lock_spec.py index 6448800b0..0a16d8364 100644 --- a/conda_lock/models/lock_spec.py +++ b/conda_lock/models/lock_spec.py @@ -1,9 +1,11 @@ +from __future__ import annotations + import hashlib import json import pathlib import typing -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional, Type, Union from pydantic import BaseModel, Field, validator from typing_extensions import Literal @@ -24,23 +26,128 @@ class _BaseDependency(StrictModel): def sorted_extras(cls, v: List[str]) -> List[str]: return sorted(v) + def merge(self, other): + if other is None: + return self + if ( + self.name != other.name + or self.manager != other.manager + or self.category != other.category + ): + raise ValueError( + "Cannot merge incompatible dependencies: {self} != {other}" + ) + return _BaseDependency( + name=self.name, + manager=self.manager, + category=self.category, + extras=list(set(self.extras + other.extras)), + ) + class VersionedDependency(_BaseDependency): version: str build: Optional[str] = None conda_channel: Optional[str] = None + @staticmethod + def _merge_versions(version1, version2): + if version1 is None or version1 == "": + return version2 + if version2 is None or version2 == "": + return version1 + return f"{version1},{version2}" + + def merge(self, other): + if other is None: + return self + if ( + self.build is not None + and other.build is not None + and self.build != other.build + ): + raise ValueError( + f"VersionedDependency has two different builds:\n{self}\n{other}" + ) + + if ( + self.conda_channel is not None + and other.conda_channel is not None + and self.conda_channel != other.conda_channel + ): + raise ValueError( + f"VersionedDependency has two different conda_channels:\n{self}\n{other}" + ) + merged_base = super().merge(other) + return VersionedDependency( + name=merged_base.name, + manager=merged_base.manager, + category=merged_base.category, + extras=merged_base.extras, + version=self._merge_versions(self.version, other.version), + build=self.build or other.build, + conda_channel=self.conda_channel or other.conda_channel, + ) + class URLDependency(_BaseDependency): url: str hashes: List[str] + def merge(self, other): + if other is None: + return self + if self.url != other.url: + raise ValueError(f"URLDependency has two different urls:\n{self}\n{other}") + + if self.hashes != other.hashes: + raise ValueError( + f"URLDependency has two different hashess:\n{self}\n{other}" + ) + merged_base = super().merge(other) + + return URLDependency( + name=merged_base.name, + manager=merged_base.manager, + category=merged_base.category, + extras=merged_base.extras, + url=self.url, + hashes=self.hashes, + ) + class VCSDependency(_BaseDependency): source: str vcs: str rev: Optional[str] = None + def merge(self, other): + if other is None: + return self + if self.source != other.source: + raise ValueError( + f"VCSDependency has two different sources:\n{self}\n{other}" + ) + + if self.vcs != other.vcs: + raise ValueError()( + f"VCSDependency has two different vcss:\n{self}\n{other}" + ) + + if self.rev is not None and other.rev is not None and self.rev != other.rev: + raise ValueError(f"VCSDependency has two different revs:\n{self}\n{other}") + merged_base = super().merge(other) + + return VCSDependency( + name=merged_base.name, + manager=merged_base.manager, + category=merged_base.category, + extras=merged_base.extras, + source=self.source, + vcs=self.vcs, + rev=self.rev or other.rev, + ) + Dependency = Union[VersionedDependency, URLDependency, VCSDependency] diff --git a/conda_lock/src_parser/aggregation.py b/conda_lock/src_parser/aggregation.py index d2b5349b3..401ac3004 100644 --- a/conda_lock/src_parser/aggregation.py +++ b/conda_lock/src_parser/aggregation.py @@ -34,7 +34,7 @@ def aggregate_lock_specs( lock_spec.dependencies.get(platform, []) for lock_spec in lock_specs ): key = (dep.manager, dep.name) - unique_deps[key] = dep + unique_deps[key] = dep.merge(unique_deps.get(key)) dependencies[platform] = list(unique_deps.values()) diff --git a/tests/test_conda_lock.py b/tests/test_conda_lock.py index afbc5f402..a63cf4639 100644 --- a/tests/test_conda_lock.py +++ b/tests/test_conda_lock.py @@ -1622,22 +1622,28 @@ def test_aggregate_lock_specs(): assert actual.content_hash() == expected.content_hash() -def test_aggregate_lock_specs_override_version(): - base_spec = LockSpecification( - dependencies={"linux-64": [_make_spec("package", "=1.0")]}, +def test_aggregate_lock_specs_combine_version(): + first_spec = LockSpecification( + dependencies={"linux-64": [_make_spec("package", ">1.0")]}, channels=[Channel.from_string("conda-forge")], sources=[Path("base.yml")], ) - override_spec = LockSpecification( - dependencies={"linux-64": [_make_spec("package", "=2.0")]}, + second_spec = LockSpecification( + dependencies={"linux-64": [_make_spec("package", "<2.0")]}, + channels=[Channel.from_string("internal"), Channel.from_string("conda-forge")], + sources=[Path("override.yml")], + ) + + result_spec = LockSpecification( + dependencies={"linux-64": [_make_spec("package", ">1.0,<2.0")]}, channels=[Channel.from_string("internal"), Channel.from_string("conda-forge")], sources=[Path("override.yml")], ) - agg_spec = aggregate_lock_specs([base_spec, override_spec], platforms=["linux-64"]) + agg_spec = aggregate_lock_specs([first_spec, second_spec], platforms=["linux-64"]) - assert agg_spec.dependencies == override_spec.dependencies + assert agg_spec.dependencies == result_spec.dependencies def test_aggregate_lock_specs_invalid_channels():