diff --git a/src/lightning_utilities/__about__.py b/src/lightning_utilities/__about__.py index e6dfd24c..89f394ac 100644 --- a/src/lightning_utilities/__about__.py +++ b/src/lightning_utilities/__about__.py @@ -1,6 +1,6 @@ import time -__version__ = "0.2.0" +__version__ = "0.3.0dev" __author__ = "Lightning AI et al." __author_email__ = "pytorch@lightning.ai" __license__ = "Apache-2.0" diff --git a/src/lightning_utilities/core/rank_zero.py b/src/lightning_utilities/core/rank_zero.py new file mode 100644 index 00000000..48d9ca67 --- /dev/null +++ b/src/lightning_utilities/core/rank_zero.py @@ -0,0 +1,89 @@ +"""Utilities that can be used for calling functions on a particular rank.""" +import logging +import warnings +from functools import wraps +from platform import python_version +from typing import Any, Callable, Optional, Union + +log = logging.getLogger(__name__) + + +def rank_zero_only(fn: Callable) -> Callable: + """Function that can be used as a decorator to enable a function/method being called only on global rank 0.""" + + @wraps(fn) + def wrapped_fn(*args: Any, **kwargs: Any) -> Optional[Any]: + rank = getattr(rank_zero_only, "rank", None) + if rank is None: + raise RuntimeError("The `rank_zero_only.rank` needs to be set before use") + if rank == 0: + return fn(*args, **kwargs) + return None + + return wrapped_fn + + +def _debug(*args: Any, stacklevel: int = 2, **kwargs: Any) -> None: + if python_version() >= "3.8.0": + kwargs["stacklevel"] = stacklevel + log.debug(*args, **kwargs) + + +@rank_zero_only +def rank_zero_debug(*args: Any, stacklevel: int = 4, **kwargs: Any) -> None: + """Function used to log debug-level messages only on global rank 0.""" + _debug(*args, stacklevel=stacklevel, **kwargs) + + +def _info(*args: Any, stacklevel: int = 2, **kwargs: Any) -> None: + if python_version() >= "3.8.0": + kwargs["stacklevel"] = stacklevel + log.info(*args, **kwargs) + + +@rank_zero_only +def rank_zero_info(*args: Any, stacklevel: int = 4, **kwargs: Any) -> None: + """Function used to log info-level messages only on global rank 0.""" + _info(*args, stacklevel=stacklevel, **kwargs) + + +def _warn(message: Union[str, Warning], stacklevel: int = 2, **kwargs: Any) -> None: + warnings.warn(message, stacklevel=stacklevel, **kwargs) + + +@rank_zero_only +def rank_zero_warn(message: Union[str, Warning], stacklevel: int = 4, **kwargs: Any) -> None: + """Function used to log warn-level messages only on global rank 0.""" + _warn(message, stacklevel=stacklevel, **kwargs) + + +rank_zero_deprecation_category = DeprecationWarning + + +def rank_zero_deprecation(message: Union[str, Warning], stacklevel: int = 5, **kwargs: Any) -> None: + category = kwargs.pop("category", rank_zero_deprecation_category) + rank_zero_warn(message, stacklevel=stacklevel, category=category, **kwargs) + + +def rank_prefixed_message(message: str, rank: Optional[int]) -> str: + if rank is not None: + # specify the rank of the process being logged + return f"[rank: {rank}] {message}" + return message + + +class WarningCache(set): + def warn(self, message: str, stacklevel: int = 5, **kwargs: Any) -> None: + if message not in self: + self.add(message) + rank_zero_warn(message, stacklevel=stacklevel, **kwargs) + + def deprecation(self, message: str, stacklevel: int = 6, **kwargs: Any) -> None: + if message not in self: + self.add(message) + rank_zero_deprecation(message, stacklevel=stacklevel, **kwargs) + + def info(self, message: str, stacklevel: int = 5, **kwargs: Any) -> None: + if message not in self: + self.add(message) + rank_zero_info(message, stacklevel=stacklevel, **kwargs) diff --git a/tests/unittests/core/test_rank_zero.py b/tests/unittests/core/test_rank_zero.py new file mode 100644 index 00000000..854365a7 --- /dev/null +++ b/tests/unittests/core/test_rank_zero.py @@ -0,0 +1,18 @@ +import pytest + +from lightning_utilities.core.rank_zero import rank_prefixed_message, rank_zero_only + + +def test_rank_zero_only_raises(): + foo = rank_zero_only(lambda x: x + 1) + with pytest.raises(RuntimeError, match="rank_zero_only.rank` needs to be set "): + foo(1) + + +@pytest.mark.parametrize("rank", [0, 1, 4]) +def test_rank_prefixed_message(rank): + rank_zero_only.rank = rank + message = rank_prefixed_message("bar", rank) + assert message == f"[rank: {rank}] bar" + # reset + del rank_zero_only.rank