diff --git a/src/lightning_utilities/core/imports.py b/src/lightning_utilities/core/imports.py index 38456c4d..4455bb52 100644 --- a/src/lightning_utilities/core/imports.py +++ b/src/lightning_utilities/core/imports.py @@ -8,11 +8,15 @@ from functools import lru_cache from importlib.util import find_spec from types import ModuleType -from typing import Any, Callable, List, Optional +from typing import Any, Callable, List, Optional, TypeVar import pkg_resources from packaging.requirements import Requirement from packaging.version import Version +from typing_extensions import ParamSpec + +T = TypeVar("T") +P = ParamSpec("P") try: from importlib import metadata @@ -260,7 +264,7 @@ def lazy_import(module_name: str, callback: Optional[Callable] = None) -> LazyMo return LazyModule(module_name, callback=callback) -def requires(*module_path: str, raise_exception: bool = True) -> Callable: +def requires(*module_path: str, raise_exception: bool = True) -> Callable[[Callable[P, T]], Callable[P, T]]: """Wrap early import failure with some nice exception message. Example: @@ -277,15 +281,15 @@ def requires(*module_path: str, raise_exception: bool = True) -> Callable: ... self._rnd = pow(randint(1, 9), 2) """ - def decorator(func: Callable) -> Callable: + def decorator(func: Callable[P, T]) -> Callable[P, T]: @functools.wraps(func) - def wrapper(*args: Any, **kwargs: Any) -> Any: + def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: unavailable_modules = [module for module in module_path if not module_available(module)] if any(unavailable_modules): msg = f"Required dependencies not available. Please run `pip install {' '.join(unavailable_modules)}`" if raise_exception: raise ModuleNotFoundError(msg) - warnings.warn(msg) + warnings.warn(msg, stacklevel=2) return func(*args, **kwargs) return wrapper diff --git a/src/lightning_utilities/core/rank_zero.py b/src/lightning_utilities/core/rank_zero.py index 59b1d2a2..74fde563 100644 --- a/src/lightning_utilities/core/rank_zero.py +++ b/src/lightning_utilities/core/rank_zero.py @@ -7,19 +7,24 @@ import warnings from functools import wraps from platform import python_version -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Optional, TypeVar, Union + +from typing_extensions import ParamSpec log = logging.getLogger(__name__) +T = TypeVar("T") +P = ParamSpec("P") + -def rank_zero_only(fn: Callable) -> Callable: +def rank_zero_only(fn: Callable[P, T]) -> Callable[P, Optional[T]]: """Wrap a function to call internal function only in rank zero. Function that can be used as a decorator to enable a function/method being called only on global rank 0. """ @wraps(fn) - def wrapped_fn(*args: Any, **kwargs: Any) -> Optional[Any]: + def wrapped_fn(*args: P.args, **kwargs: P.kwargs) -> Optional[T]: rank = getattr(rank_zero_only, "rank", None) if rank is None: raise RuntimeError("The `rank_zero_only.rank` needs to be set before use")