Skip to content

Commit

Permalink
More resillient RequirementCache that checks for module importabili…
Browse files Browse the repository at this point in the history
…ty (#112)

Co-authored-by: Jirka <jirka.borovec@seznam.cz>
  • Loading branch information
carmocca and Borda authored Feb 22, 2023
1 parent 340e1c8 commit 14d3e21
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 11 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

-
- More resilient RequirementCache that checks for module import-ability ([#112](https://github.com/Lightning-AI/utilities/pull/112))


## [0.7.0] - 2023-02-20
Expand Down
40 changes: 30 additions & 10 deletions src/lightning_utilities/core/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,28 +84,48 @@ def compare_version(package: str, op: Callable, version: str, use_base_version:


class RequirementCache:
"""Boolean-like class for check of requirement with extras and version specifiers.
"""Boolean-like class to check for requirement and module availability.
Args:
requirement: The requirement to check, version specifiers are allowed.
module: The optional module to try to import if the requirement check fails.
>>> RequirementCache("torch>=0.1")
Requirement 'torch>=0.1' met
>>> bool(RequirementCache("torch>=0.1"))
True
>>> bool(RequirementCache("torch>100.0"))
False
>>> RequirementCache("torch")
Requirement 'torch' met
>>> bool(RequirementCache("torch"))
True
>>> bool(RequirementCache("unknown_package"))
False
"""

def __init__(self, requirement: str) -> None:
def __init__(self, requirement: str, module: Optional[str] = None) -> None:
self.requirement = requirement
self.module = module

def _check_requirement(self) -> None:
if not hasattr(self, "available"):
try:
pkg_resources.require(self.requirement)
self.available = True
self.message = f"Requirement {self.requirement!r} met"
except Exception as ex:
self.available = False
self.message = f"{ex.__class__.__name__}: {ex}. HINT: Try running `pip install -U {self.requirement!r}`"
if hasattr(self, "available"):
return
try:
# first try the pkg_resources requirement
pkg_resources.require(self.requirement)
self.available = True
self.message = f"Requirement {self.requirement!r} met"
except Exception as ex:
self.available = False
self.message = f"{ex.__class__.__name__}: {ex}. HINT: Try running `pip install -U {self.requirement!r}`"
requirement_contains_version_specifier = any(c in self.requirement for c in "=<>")
if not requirement_contains_version_specifier or self.module is not None:
module = self.requirement if self.module is None else self.module
# sometimes `pkg_resources.require()` fails but the module is importable
self.available = module_available(module)
if self.available:
self.message = f"Module {module!r} available"

def __bool__(self) -> bool:
"""Format as bool."""
Expand Down
11 changes: 11 additions & 0 deletions tests/unittests/core/test_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,17 @@ def test_requirement_cache():
assert not RequirementCache(f"pytest<{pytest.__version__}")
assert "pip install -U '-'" in str(RequirementCache("-"))

# invalid requirement is skipped by valid module
assert RequirementCache(f"pytest<{pytest.__version__}", "pytest")

cache = RequirementCache("this_module_is_not_installed")
assert not cache
assert "pip install -U 'this_module_is_not_installed" in str(cache)

cache = RequirementCache("this_module_is_not_installed", "this_also_is_not")
assert not cache
assert "pip install -U 'this_module_is_not_installed" in str(cache)


def test_module_available_cache():
assert ModuleAvailableCache("pytest")
Expand Down

0 comments on commit 14d3e21

Please sign in to comment.