diff --git a/composer/callbacks/__init__.py b/composer/callbacks/__init__.py
index 1fdfed4767..d403ee0ee2 100644
--- a/composer/callbacks/__init__.py
+++ b/composer/callbacks/__init__.py
@@ -17,6 +17,7 @@
 from composer.callbacks.lr_monitor import LRMonitor
 from composer.callbacks.memory_monitor import MemoryMonitor
 from composer.callbacks.mlperf import MLPerfCallback
+from composer.callbacks.nan_monitor import NaNMonitor
 from composer.callbacks.optimizer_monitor import OptimizerMonitor
 from composer.callbacks.runtime_estimator import RuntimeEstimator
 from composer.callbacks.speed_monitor import SpeedMonitor
@@ -28,6 +29,7 @@
     'OptimizerMonitor',
     'LRMonitor',
     'MemoryMonitor',
+    'NaNMonitor',
     'SpeedMonitor',
     'CheckpointSaver',
     'MLPerfCallback',
diff --git a/composer/callbacks/nan_monitor.py b/composer/callbacks/nan_monitor.py
new file mode 100644
index 0000000000..16d3c47bf6
--- /dev/null
+++ b/composer/callbacks/nan_monitor.py
@@ -0,0 +1,28 @@
+# Copyright 2022 MosaicML Composer authors
+# SPDX-License-Identifier: Apache-2.0
+
+"""Callback for catching loss NaNs."""
+
+from typing import Sequence
+
+import torch
+
+from composer import Callback, Logger, State
+
+__all__ = ['NaNMonitor']
+
+
+class NaNMonitor(Callback):
+    """Catches NaNs in the loss and raises an error if one is found."""
+
+    def after_loss(self, state: State, logger: Logger):
+        """Check if loss is NaN and raise an error if so."""
+        if isinstance(state.loss, torch.Tensor):
+            if torch.isnan(state.loss).any():
+                raise RuntimeError('Train loss contains a NaN.')
+        elif isinstance(state.loss, Sequence):
+            for loss in state.loss:
+                if torch.isnan(loss).any():
+                    raise RuntimeError('Train loss contains a NaN.')
+        else:
+            raise TypeError(f'Loss is of type {type(state.loss)}, but should be a tensor or a sequence of tensors')
diff --git a/docs/source/trainer/callbacks.rst b/docs/source/trainer/callbacks.rst
index a12cdf5e52..a2c02c71cd 100644
--- a/docs/source/trainer/callbacks.rst
+++ b/docs/source/trainer/callbacks.rst
@@ -50,6 +50,7 @@ components of training.
     ~lr_monitor.LRMonitor
     ~optimizer_monitor.OptimizerMonitor
     ~memory_monitor.MemoryMonitor
+    ~nan_monitor.NaNMonitor
     ~image_visualizer.ImageVisualizer
     ~mlperf.MLPerfCallback
     ~threshold_stopper.ThresholdStopper
diff --git a/tests/callbacks/test_nan_monitor.py b/tests/callbacks/test_nan_monitor.py
new file mode 100644
index 0000000000..fe0c50018a
--- /dev/null
+++ b/tests/callbacks/test_nan_monitor.py
@@ -0,0 +1,34 @@
+# Copyright 2022 MosaicML Composer authors
+# SPDX-License-Identifier: Apache-2.0
+
+import pytest
+from torch.utils.data import DataLoader
+
+from composer.callbacks import NaNMonitor
+from composer.optim import DecoupledAdamW
+from composer.trainer import Trainer
+from tests.common import RandomClassificationDataset, SimpleModel
+
+
+@pytest.mark.parametrize('should_nan', [True, False])
+def test_nan_monitor(should_nan):
+    # Make the callback
+    nan_monitor = NaNMonitor()
+    # Test model
+    model = SimpleModel()
+    # Construct the trainer and train. Make the LR huge to force a NaN, small if it shouldn't
+    lr = 1e20 if should_nan else 1e-10
+    trainer = Trainer(
+        model=model,
+        callbacks=nan_monitor,
+        train_dataloader=DataLoader(RandomClassificationDataset()),
+        optimizers=DecoupledAdamW(model.parameters(), lr=lr),
+        max_duration='100ba',
+    )
+    # If it should NaN, expect a RuntimeError
+    if should_nan:
+        with pytest.raises(RuntimeError) as excinfo:
+            trainer.fit()
+        assert 'Train loss contains a NaN.' in str(excinfo.value)
+    else:
+        trainer.fit()
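Usage sketch (not part of the diff above): attaching the new NaNMonitor to a Trainer makes training fail fast with a RuntimeError the first time the loss goes NaN, mirroring the test case. The `model` and `train_dataset` names here are hypothetical placeholders, not objects from this PR.

    from torch.utils.data import DataLoader

    from composer.callbacks import NaNMonitor
    from composer.trainer import Trainer

    # NaNMonitor takes no arguments; it inspects state.loss after every
    # loss computation via the after_loss callback event.
    trainer = Trainer(
        model=model,  # hypothetical: any ComposerModel
        train_dataloader=DataLoader(train_dataset),  # hypothetical dataset
        callbacks=[NaNMonitor()],
        max_duration='1ep',
    )
    trainer.fit()  # raises RuntimeError('Train loss contains a NaN.') if the loss NaNs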