Skip to content

Commit

Permalink
Simplify _get_rank() utility function (#19220)
Browse files Browse the repository at this point in the history
  • Loading branch information
awaelchli authored Jan 2, 2024
1 parent 564be3b commit f75f3bc
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 20 deletions.
7 changes: 1 addition & 6 deletions src/lightning/fabric/utilities/rank_zero.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,12 @@
)
from typing_extensions import ParamSpec

import lightning.fabric
from lightning.fabric.utilities.imports import _UTILITIES_GREATER_EQUAL_0_10

rank_zero_module.log = logging.getLogger(__name__)


def _get_rank(
strategy: Optional["lightning.fabric.strategies.Strategy"] = None,
) -> Optional[int]:
if strategy is not None:
return strategy.global_rank
def _get_rank() -> Optional[int]:
# SLURM_PROCID can be set even if SLURM is not managing the multiprocessing,
# therefore LOCAL_RANK needs to be checked first
rank_keys = ("RANK", "LOCAL_RANK", "SLURM_PROCID", "JSM_NAMESPACE_RANK")
Expand Down
9 changes: 2 additions & 7 deletions src/lightning/pytorch/callbacks/early_stopping.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from typing_extensions import override

import lightning.pytorch as pl
from lightning.fabric.utilities.rank_zero import _get_rank
from lightning.pytorch.callbacks.callback import Callback
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.rank_zero import rank_prefixed_message, rank_zero_warn
Expand Down Expand Up @@ -265,12 +264,8 @@ def _improvement_message(self, current: Tensor) -> str:
return msg

@staticmethod
def _log_info(trainer: Optional["pl.Trainer"], message: str, log_rank_zero_only: bool) -> None:
rank = _get_rank(
strategy=(trainer.strategy if trainer is not None else None), # type: ignore[arg-type]
)
if trainer is not None and trainer.world_size <= 1:
rank = None
def _log_info(trainer: "pl.Trainer", message: str, log_rank_zero_only: bool) -> None:
rank = trainer.global_rank if trainer.world_size > 1 else None
message = rank_prefixed_message(message, rank)
if rank is None or not log_rank_zero_only or rank == 0:
log.info(message)
9 changes: 2 additions & 7 deletions tests/tests_pytorch/callbacks/test_early_stopping.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,7 +480,6 @@ def test_early_stopping_squeezes():
es_mock.assert_called_once_with(torch.tensor(0))


@pytest.mark.parametrize("trainer", [Trainer(), None])
@pytest.mark.parametrize(
("log_rank_zero_only", "world_size", "global_rank", "expected_log"),
[
Expand All @@ -492,15 +491,11 @@ def test_early_stopping_squeezes():
(True, 2, 1, None),
],
)
def test_early_stopping_log_info(trainer, log_rank_zero_only, world_size, global_rank, expected_log):
def test_early_stopping_log_info(log_rank_zero_only, world_size, global_rank, expected_log):
"""Checks if log.info() gets called with expected message when used within EarlyStopping."""
# set the global_rank and world_size if trainer is not None
# or else always expect the simple logging message
if trainer:
trainer.strategy.global_rank = global_rank
trainer.strategy.world_size = world_size
else:
expected_log = "bar"
trainer = Mock(global_rank=global_rank, world_size=world_size)

with mock.patch("lightning.pytorch.callbacks.early_stopping.log.info") as log_mock:
EarlyStopping._log_info(trainer, "bar", log_rank_zero_only)
Expand Down

0 comments on commit f75f3bc

Please sign in to comment.