Skip to content

Commit

Permalink
Fix extras check in RequirementCache (#283)
Browse files Browse the repository at this point in the history
Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
  • Loading branch information
awaelchli and Borda authored Jul 15, 2024
1 parent d42f7c0 commit 2e4e185
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 3 deletions.
2 changes: 1 addition & 1 deletion src/lightning_utilities/__about__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import time

__version__ = "0.11.4"
__version__ = "0.11.5"
__author__ = "Lightning AI et al."
__author_email__ = "pytorch@lightning.ai"
__license__ = "Apache-2.0"
Expand Down
37 changes: 35 additions & 2 deletions src/lightning_utilities/core/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import os
import warnings
from functools import lru_cache
from importlib.metadata import PackageNotFoundError
from importlib.metadata import PackageNotFoundError, distribution
from importlib.metadata import version as _version
from importlib.util import find_spec
from types import ModuleType
Expand Down Expand Up @@ -128,7 +128,9 @@ def _check_requirement(self) -> None:
try:
req = Requirement(self.requirement)
pkg_version = Version(_version(req.name))
self.available = req.specifier.contains(pkg_version)
self.available = req.specifier.contains(pkg_version) and (
not req.extras or self._check_extras_available(req)
)
except (PackageNotFoundError, InvalidVersion) as ex:
self.available = False
self.message = f"{ex.__class__.__name__}: {ex}. HINT: Try running `pip install -U {self.requirement!r}`"
Expand All @@ -143,6 +145,9 @@ def _check_requirement(self) -> None:
self.available = module_available(module)
if self.available:
self.message = f"Module {module!r} available"
self.message = (
f"Requirement {self.requirement!r} not met. HINT: Try running `pip install -U {self.requirement!r}`"
)

def _check_module(self) -> None:
assert self.module # noqa: S101; needed for typing
Expand All @@ -160,6 +165,34 @@ def _check_available(self) -> None:
if getattr(self, "available", True) and self.module:
self._check_module()

def _check_extras_available(self, requirement: Requirement) -> bool:
if not requirement.extras:
return True

extra_requirements = self._get_extra_requirements(requirement)

if not extra_requirements:
# The specified extra is not found in the package metadata
return False

# Verify each extra requirement is installed
for extra_req in extra_requirements:
try:
extra_dist = distribution(extra_req.name)
extra_installed_version = Version(extra_dist.version)
if extra_req.specifier and not extra_req.specifier.contains(extra_installed_version):
return False
except importlib.metadata.PackageNotFoundError:
return False

return True

def _get_extra_requirements(self, requirement: Requirement) -> List[Requirement]:
dist = distribution(requirement.name)
# Get the required dependencies for the specified extras
extra_requirements = dist.metadata.get_all("Requires-Dist") or []
return [Requirement(r) for r in extra_requirements if any(extra in r for extra in requirement.extras)]

def __bool__(self) -> bool:
"""Format as bool."""
self._check_available()
Expand Down
37 changes: 37 additions & 0 deletions tests/unittests/core/test_imports.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import operator
import re
from unittest import mock
from unittest.mock import Mock

import pytest
from lightning_utilities.core.imports import (
Expand Down Expand Up @@ -61,6 +63,41 @@ def test_requirement_cache():
assert not cache
assert "pip install -U 'this_module_is_not_installed" in str(cache)

cache = RequirementCache("pytest[not-valid-extra]")
assert not cache
assert "pip install -U 'pytest[not-valid-extra]" in str(cache)


@mock.patch("lightning_utilities.core.imports.Requirement")
@mock.patch("lightning_utilities.core.imports._version")
@mock.patch("lightning_utilities.core.imports.distribution")
def test_requirement_cache_with_extras(distribution_mock, version_mock, requirement_mock):
requirement_mock().specifier.contains.return_value = True
requirement_mock().name = "jsonargparse"
requirement_mock().extras = []
version_mock.return_value = "1.0.0"
assert RequirementCache("jsonargparse>=1.0.0")

with mock.patch("lightning_utilities.core.imports.RequirementCache._get_extra_requirements") as get_extra_req_mock:
get_extra_req_mock.return_value = [
# Extra packages, all versions satisfied
Mock(name="extra_package1", specifier=Mock(contains=Mock(return_value=True))),
Mock(name="extra_package2", specifier=Mock(contains=Mock(return_value=True))),
]
distribution_mock.return_value = Mock(version="0.10.0")
requirement_mock().extras = ["signatures"]
assert RequirementCache("jsonargparse[signatures]>=1.0.0")

with mock.patch("lightning_utilities.core.imports.RequirementCache._get_extra_requirements") as get_extra_req_mock:
get_extra_req_mock.return_value = [
# Extra packages, but not all versions are satisfied
Mock(name="extra_package1", specifier=Mock(contains=Mock(return_value=True))),
Mock(name="extra_package2", specifier=Mock(contains=Mock(return_value=False))),
]
distribution_mock.return_value = Mock(version="0.10.0")
requirement_mock().extras = ["signatures"]
assert not RequirementCache("jsonargparse[signatures]>=1.0.0")


def test_module_available_cache():
assert RequirementCache(module="pytest")
Expand Down

0 comments on commit 2e4e185

Please sign in to comment.