Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add requires wrapper #70

Merged
merged 14 commits into from
Dec 13, 2022
Merged
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Added method to lazily import modules ([#71](https://github.com/Lightning-AI/utilities/pull/71))


- Added `requires` wrapper ([#70](https://github.com/Lightning-AI/utilities/pull/70))


### Changed

### Removed
Expand Down
39 changes: 38 additions & 1 deletion src/lightning_utilities/core/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# http://www.apache.org/licenses/LICENSE-2.0

import functools
import importlib
import operator
import os
import warnings
from functools import lru_cache
from importlib.util import find_spec
from types import ModuleType
Expand Down Expand Up @@ -121,7 +124,7 @@ def __repr__(self) -> str:
def get_dependency_min_version_spec(package_name: str, dependency_name: str) -> str:
"""Returns the minimum version specifier of a dependency of a package.

>>> get_dependency_min_version_spec("pytorch-lightning", "jsonargparse")
>>> get_dependency_min_version_spec("pytorch-lightning==1.8.0", "jsonargparse")
'>=4.12.0'
"""
dependencies = metadata.requires(package_name) or []
Expand Down Expand Up @@ -189,3 +192,37 @@ def lazy_import(module_name: str, callback: Optional[Callable] = None) -> LazyMo
a proxy module object that will be lazily imported when first used
"""
return LazyModule(module_name, callback=callback)


def requires(*module_path: str, raise_exception: bool = True) -> Callable:
"""Wrapper for early import failure with some nice exception message.

Example:

>>> @requires("libpath", raise_exception=bool(int(os.getenv("LIGHTING_TESTING", "0"))))
... def my_cwd():
... from pathlib import Path
... return Path(__file__).parent

>>> class MyRndPower:
... @requires("math", "random")
... def __init__(self):
... from math import pow
... from random import randint
... self._rnd = pow(randint(1, 9), 2)
"""

def decorator(func: Callable) -> Callable:
@functools.wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> Any:
unavailable_modules = [module for module in module_path if not module_available(module)]
if any(unavailable_modules):
msg = f"Required dependencies not available. Please run `pip install {' '.join(unavailable_modules)}`"
if raise_exception:
raise ModuleNotFoundError(msg)
warnings.warn(msg)
return func(*args, **kwargs)

return wrapper

return decorator
60 changes: 56 additions & 4 deletions tests/unittests/core/test_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
lazy_import,
module_available,
RequirementCache,
requires,
)

try:
Expand All @@ -27,8 +28,6 @@ def test_module_exists():


def testcompare_version(monkeypatch):
import pytest

monkeypatch.setattr(pytest, "__version__", "1.8.9")
assert not compare_version("pytest", operator.ge, "1.10.0")
assert compare_version("pytest", operator.lt, "1.10.0")
Expand All @@ -48,8 +47,6 @@ def testcompare_version(monkeypatch):


def test_requirement_cache():
import pytest

assert RequirementCache(f"pytest>={pytest.__version__}")
assert not RequirementCache(f"pytest<{pytest.__version__}")
assert "pip install -U '-'" in str(RequirementCache("-"))
Expand Down Expand Up @@ -78,3 +75,58 @@ def callback_fcn():
print(module)
os = lazy_import("os")
assert os.getcwd()


@requires("torch")
def my_torch_func(i: int) -> int:
import torch # noqa

return i


def test_torch_func_raised():
with pytest.raises(
ModuleNotFoundError, match="Required dependencies not available. Please run `pip install torch`"
):
my_torch_func(42)


@requires("random")
def my_random_func(nb: int) -> int:
from random import randint

return randint(0, nb)


def test_rand_func_passed():
assert 0 <= my_random_func(42) <= 42


class MyTorchClass:
@requires("torch", "random")
def __init__(self):
from random import randint

import torch # noqa

self._rnd = randint(1, 9)


def test_torch_class_raised():
with pytest.raises(
ModuleNotFoundError, match="Required dependencies not available. Please run `pip install torch`"
):
MyTorchClass()


class MyRandClass:
@requires("random")
def __init__(self, nb: int):
from random import randint

self._rnd = randint(1, nb)


def test_rand_class_passed():
cls = MyRandClass(42)
assert 0 <= cls._rnd <= 42