Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More resillient RequirementCache that checks for module importability #112

Merged
merged 9 commits into from
Feb 22, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 26 additions & 45 deletions src/lightning_utilities/core/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,67 +84,48 @@ def compare_version(package: str, op: Callable, version: str, use_base_version:


class RequirementCache:
"""Boolean-like class for check of requirement with extras and version specifiers.
"""Boolean-like class to check for requirement and module availability.

Args:
requirement: The requirement to check, version specifiers are allowed.
module: The optional module to try to import if the requirement check fails.

>>> RequirementCache("torch>=0.1")
Requirement 'torch>=0.1' met
>>> bool(RequirementCache("torch>=0.1"))
True
>>> bool(RequirementCache("torch>100.0"))
False
"""

def __init__(self, requirement: str) -> None:
self.requirement = requirement

def _check_requirement(self) -> None:
if not hasattr(self, "available"):
try:
pkg_resources.require(self.requirement)
self.available = True
self.message = f"Requirement {self.requirement!r} met"
except Exception as ex:
self.available = False
self.message = f"{ex.__class__.__name__}: {ex}. HINT: Try running `pip install -U {self.requirement!r}`"

def __bool__(self) -> bool:
"""Format as bool."""
self._check_requirement()
return self.available

def __str__(self) -> str:
"""Format as string."""
self._check_requirement()
return self.message

def __repr__(self) -> str:
"""Format as string."""
return self.__str__()


class ModuleAvailableCache:
Borda marked this conversation as resolved.
Show resolved Hide resolved
"""Boolean-like class for check of module availability.

>>> ModuleAvailableCache("torch")
Module 'torch' available
>>> bool(ModuleAvailableCache("torch"))
>>> RequirementCache("torch")
Requirement 'torch' met
>>> bool(RequirementCache("torch"))
True
>>> bool(ModuleAvailableCache("unknown_package"))
>>> bool(RequirementCache("unknown_package"))
False
"""

def __init__(self, module: str) -> None:
def __init__(self, requirement: str, module: Optional[str] = None) -> None:
self.requirement = requirement
self.module = module

def _check_requirement(self) -> None:
Borda marked this conversation as resolved.
Show resolved Hide resolved
if hasattr(self, "available"):
return

self.available = module_available(self.module)
if self.available:
self.message = f"Module {self.module!r} available"
else:
self.message = f"Module not found: {self.module!r}. HINT: Try running `pip install -U {self.module}`"
try:
# first try the pkg_resources requirement
pkg_resources.require(self.requirement)
self.available = True
self.message = f"Requirement {self.requirement!r} met"
except Exception as ex:
self.available = False
self.message = f"{ex.__class__.__name__}: {ex}. HINT: Try running `pip install -U {self.requirement!r}`"
requirement_contains_version_specifier = any(c in self.requirement for c in "=<>")
if not requirement_contains_version_specifier or self.module is not None:
module = self.requirement if self.module is None else self.module
# sometimes `pkg_resources.require()` fails but the module is importable
self.available = module_available(module)
if self.available:
self.message = f"Module {module!r} available"

def __bool__(self) -> bool:
"""Format as bool."""
Expand Down
14 changes: 9 additions & 5 deletions tests/unittests/core/test_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
get_dependency_min_version_spec,
lazy_import,
module_available,
ModuleAvailableCache,
RequirementCache,
requires,
)
Expand Down Expand Up @@ -52,11 +51,16 @@ def test_requirement_cache():
assert not RequirementCache(f"pytest<{pytest.__version__}")
assert "pip install -U '-'" in str(RequirementCache("-"))

# invalid requirement is skipped by valid module
assert RequirementCache(f"pytest<{pytest.__version__}", "pytest")

def test_module_available_cache():
assert ModuleAvailableCache("pytest")
assert not ModuleAvailableCache("this_module_is_not_installed")
assert "pip install -U this_module_is_not_installed" in str(ModuleAvailableCache("this_module_is_not_installed"))
cache = RequirementCache("this_module_is_not_installed")
assert not cache
assert "pip install -U 'this_module_is_not_installed" in str(cache)

cache = RequirementCache("this_module_is_not_installed", "this_also_is_not")
assert not cache
assert "pip install -U 'this_module_is_not_installed" in str(cache)


def test_get_dependency_min_version_spec():
Expand Down