Skip to content

Commit

Permalink
Replicate in New Branch
Browse files Browse the repository at this point in the history
  • Loading branch information
srilman committed Jan 22, 2023
1 parent 97ae704 commit a438337
Show file tree
Hide file tree
Showing 5 changed files with 108 additions and 15 deletions.
3 changes: 2 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,5 @@ repos:
hooks:
- id: mypy
additional_dependencies: [types-filelock, types-requests, types-toml, types-PyYAML, types-freezegun, types-setuptools, pydantic]
exclude: ^tests/test-local-pip/setup.py$
exclude: ^tests/test-local-pip/setup.py$
# args: [--enable-incomplete-feature=Unpack, --ignore-missing-imports]
7 changes: 7 additions & 0 deletions conda_lock/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,10 @@ class ChannelAggregationError(CondaLockError):
"""
Error thrown when lists of channels cannot be combined
"""


class DependencyAggregationError(CondaLockError):
"""
Error thrown when merging dependency specifications together that
are incompatible
"""
59 changes: 56 additions & 3 deletions conda_lock/src_parser/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from typing_extensions import Literal

from conda_lock.common import suffix_union
from conda_lock.errors import ChannelAggregationError
from conda_lock.errors import ChannelAggregationError, DependencyAggregationError
from conda_lock.models import StrictModel
from conda_lock.models.channel import Channel
from conda_lock.virtual_package import FakeRepoData
Expand Down Expand Up @@ -153,13 +153,66 @@ def validate_channels(cls, v: List[Union[Channel, str]]) -> List[Channel]:
return typing.cast(List[Channel], v)


def aggregate_deps(grouped_deps: List[List[Dependency]]) -> List[Dependency]:
def merge_dependencies(
dep_a: Dependency,
dep_b: Dependency,
) -> Dependency:
"""
Merge 2 Dependency Specifications Together if Valid
Either by comparing URL locations or by combining the Versioning Code
"""
assert dep_a.manager == dep_b.manager and dep_a.name == dep_b.name

if isinstance(dep_a, URLDependency) and isinstance(dep_b, URLDependency):
if dep_a != dep_b:
raise DependencyAggregationError(
f"Found conflicting URL dependency specifications for {dep_a.name} on {dep_a.manager}:\n"
f" URL 1: {dep_a.url}\n URL 2: {dep_b.url}"
)
return dep_a

# If bold old and new are VersionedDependency, combine version strings together
# If there are conflicting versions, they will be handled by the solver
if isinstance(dep_a, VersionedDependency) and isinstance(
dep_b, VersionedDependency
):
if dep_a.manager == "pip":
return VersionedDependency(
name=dep_a.name,
version=f"{dep_a.version},{dep_b.version}",
manager="pip",
optional=dep_a.optional,
category=dep_a.category,
extras=dep_a.extras,
)

from conda_lock.src_parser.conda_common import merge_version_specs

return VersionedDependency(
name=dep_a.name,
version=merge_version_specs(dep_a.version, dep_b.version),
manager="conda",
optional=dep_a.optional,
category=dep_a.category,
extras=dep_a.extras,
)

# Case when one dependency specifies a version and another a URL
raise DependencyAggregationError(
f"Found both a URL and Version Dependency Specification for {dep_a.name} on {dep_a.manager}."
"They can not be combined or solved together."
)


def aggregate_deps(grouped_deps: List[List[Dependency]]) -> List[Dependency]:
# List unique dependencies
unique_deps: Dict[Tuple[str, str], Dependency] = {}
for dep in chain.from_iterable(grouped_deps):
key = (dep.manager, dep.name)
unique_deps[key] = dep
if key in unique_deps:
unique_deps[key] = merge_dependencies(unique_deps[key], dep)
else:
unique_deps[key] = dep

return list(unique_deps.values())

Expand Down
20 changes: 20 additions & 0 deletions conda_lock/src_parser/conda_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from .._vendor.conda.models.channel import Channel
from .._vendor.conda.models.match_spec import MatchSpec
from .._vendor.conda.models.version import treeify, untreeify
from ..src_parser import SourceDependency, VersionedDependency


Expand Down Expand Up @@ -31,3 +32,22 @@ def conda_spec_to_versioned_dep(spec: str, category: str) -> SourceDependency:
build=ms.get("build"),
conda_channel=channel_str,
).to_source()


def merge_version_specs(ver_a: str, ver_b: str) -> str:
"""Merge / And 2 Conda VersionSpec Strings Together"""
if ver_a == ver_b:
return ver_a

# Conda has tools for parsing VersionSpec into a tree format
ver_a_tree = treeify(ver_a)
ver_b_tree = treeify(ver_b)

if (
isinstance(ver_a_tree, tuple)
and isinstance(ver_b_tree, tuple)
and ver_a_tree[0] == ver_b_tree[0] == ","
):
return untreeify((",", *ver_a_tree, *ver_b_tree))
else:
return untreeify((",", ver_a_tree, ver_b_tree))
34 changes: 23 additions & 11 deletions tests/test_conda_lock.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,14 @@ def _make_source_dep(
)


def _make_spec(name: str, constraint: str = "*", **kwargs):
return VersionedDependency(
name=name,
version=constraint,
**kwargs,
)


def test_parse_environment_file(gdal_environment: Path):
res = parse_environment_file(gdal_environment, pip_support=True)
assert all(
Expand Down Expand Up @@ -1089,13 +1097,6 @@ def test_run_with_channel_inversion(
raise ValueError("cuda-python not found!")


def _make_spec(name: str, constraint: str = "*"):
return VersionedDependency(
name=name,
version=constraint,
)


def test_lock_spec(lock_spec_env: List[Path]):
"""Ensure that the way two files combine when both specify channels is correct"""
actual = make_lock_spec(src_file_paths=lock_spec_env, virtual_package_repo=None) # type: ignore
Expand Down Expand Up @@ -1135,11 +1136,22 @@ def test_aggregate_lock_specs_multiple_platforms(lock_spec_env_platforms: List[P
assert actual.content_hash() == expected.content_hash()


def test_aggregate_lock_specs_override_version():
base_deps: List[Dependency] = [_make_spec("package", "=1.0")]
override_deps: List[Dependency] = [_make_spec("package", "=2.0")]
def test_aggregate_lock_specs_combine_version():
for manager in ("conda", "pip"):
base_deps: List[Dependency] = [_make_spec("package", "=1.0", manager=manager)]
override_deps: List[Dependency] = [
_make_spec("package", "=2.0", manager=manager)
]
agg_deps = aggregate_deps([base_deps, override_deps])
assert agg_deps == [_make_spec("package", "=1.0,=2.0", manager=manager)]


def test_aggregate_lock_specs_with_union_version():
base_deps: List[Dependency] = [_make_spec("package", "=1.0|>2")]
override_deps: List[Dependency] = [_make_spec("package", "=2.0,<3.0.0")]

agg_deps = aggregate_deps([base_deps, override_deps])
assert agg_deps == override_deps
assert agg_deps == [_make_spec("package", "(=1.0|>2),(=2.0,<3.0.0)")]


def test_aggregate_channels():
Expand Down

0 comments on commit a438337

Please sign in to comment.