Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add requires wrapper #70

Merged
merged 14 commits into from
Dec 13, 2022
Merged
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Added method to lazily import modules ([#71](https://github.com/Lightning-AI/utilities/pull/71))


- Added `requires` wrapper ([#70](https://github.com/Lightning-AI/utilities/pull/70))


### Changed

### Removed
Expand Down
43 changes: 42 additions & 1 deletion src/lightning_utilities/core/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# http://www.apache.org/licenses/LICENSE-2.0

import functools
import importlib
import operator
import os
import warnings
from functools import lru_cache
from importlib.util import find_spec
from types import ModuleType
Expand Down Expand Up @@ -121,7 +124,7 @@ def __repr__(self) -> str:
def get_dependency_min_version_spec(package_name: str, dependency_name: str) -> str:
"""Returns the minimum version specifier of a dependency of a package.

>>> get_dependency_min_version_spec("pytorch-lightning", "jsonargparse")
>>> get_dependency_min_version_spec("pytorch-lightning==1.8.0", "jsonargparse")
'>=4.12.0'
"""
dependencies = metadata.requires(package_name) or []
Expand Down Expand Up @@ -189,3 +192,41 @@ def lazy_import(module_name: str, callback: Optional[Callable] = None) -> LazyMo
a proxy module object that will be lazily imported when first used
"""
return LazyModule(module_name, callback=callback)


def requires(*module_path: str, raise_exception: bool = True) -> Callable:
"""Wrapper for early import failure with some nice exception message.

Example:

>>> @requires("libpath", raise_exception=bool(int(os.getenv("LIGHTING_TESTING", "0"))))
... def my_cwd():
... from pathlib import Path
... return Path(__file__).parent

>>> class MyRndPower:
... @requires("math", "random")
... def __init__(self):
... from math import pow
... from random import randint
... self._rnd = pow(randint(1, 9), 2)

.. note::
For downgrading exception to warning you export `LIGHTING_TESTING=1` which is handu for testing
Borda marked this conversation as resolved.
Show resolved Hide resolved
Borda marked this conversation as resolved.
Show resolved Hide resolved
"""

def decorator(func: Callable) -> Callable:
@functools.wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> Any:
unavailable_modules = [module for module in module_path if not module_available(module)]
if any(unavailable_modules):
msg = f"Required dependencies not available. Please run `pip install {' '.join(unavailable_modules)}`"
if raise_exception:
raise ModuleNotFoundError(msg)
else:
warnings.warn(msg)
Borda marked this conversation as resolved.
Show resolved Hide resolved
return func(*args, **kwargs)

return wrapper

return decorator
66 changes: 64 additions & 2 deletions tests/unittests/core/test_imports.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
import operator
import re

import pytest

from lightning_utilities.core.imports import (
compare_version,
get_dependency_min_version_spec,
lazy_import,
module_available,
RequirementCache,
requires,
)

try:
Expand Down Expand Up @@ -56,6 +55,8 @@ def test_requirement_cache():


def test_get_dependency_min_version_spec():
import pytest
Borda marked this conversation as resolved.
Show resolved Hide resolved

attrs_min_version_spec = get_dependency_min_version_spec("pytest", "attrs")
assert re.match(r"^>=[\d.]+$", attrs_min_version_spec)

Expand All @@ -67,6 +68,8 @@ def test_get_dependency_min_version_spec():


def test_lazy_import():
import pytest

def callback_fcn():
raise ValueError

Expand All @@ -78,3 +81,62 @@ def callback_fcn():
print(module)
os = lazy_import("os")
assert os.getcwd()


@requires("torch")
def my_torch_func(i: int) -> int:
import torch # noqa

return i


def test_torch_func_raised():
import pytest

with pytest.raises(
ModuleNotFoundError, match="Required dependencies not available. Please run `pip install torch`"
):
my_torch_func(42)


@requires("random")
def my_random_func(nb: int) -> int:
from random import randint

return randint(0, nb)


def test_rand_func_passed():
assert 0 <= my_random_func(42) <= 42


class MyTorchClass:
@requires("torch", "random")
def __init__(self):
from random import randint

import torch # noqa

self._rnd = randint(1, 9)


def test_torch_class_raised():
import pytest

with pytest.raises(
ModuleNotFoundError, match="Required dependencies not available. Please run `pip install torch`"
):
MyTorchClass()


class MyRandClass:
@requires("random")
def __init__(self, nb: int):
from random import randint

self._rnd = randint(1, nb)


def test_rand_class_passed():
cls = MyRandClass(42)
assert 0 <= cls._rnd <= 42