From 262edeed033f14046da3abd09874d521b18b79cc Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Sun, 18 Oct 2020 03:14:49 +0200
Subject: [PATCH 1/8] added data monitor code

---
 pl_bolts/callbacks/data_monitor.py | 270 +++++++++++++++++++++++++++++
 1 file changed, 270 insertions(+)
 create mode 100644 pl_bolts/callbacks/data_monitor.py

diff --git a/pl_bolts/callbacks/data_monitor.py b/pl_bolts/callbacks/data_monitor.py
new file mode 100644
index 0000000000..d62776e7bd
--- /dev/null
+++ b/pl_bolts/callbacks/data_monitor.py
@@ -0,0 +1,270 @@
+from typing import Any, Sequence, Dict
+from typing import List, Optional, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+import wandb
+from torch import Tensor
+from torch.utils.hooks import RemovableHandle
+
+from pytorch_lightning import Callback
+from pytorch_lightning import LightningModule, Trainer
+from pytorch_lightning.loggers import WandbLogger, TensorBoardLogger
+from pytorch_lightning.utilities import rank_zero_warn
+from pytorch_lightning.utilities.apply_func import apply_to_collection
+
+
+class DataMonitorBase(Callback):
+
+    supported_loggers = (
+        TensorBoardLogger,
+        WandbLogger,
+    )
+
+    def __init__(self, log_every_n_steps: int = None):
+        """
+        Base class for monitoring data histograms in a LightningModule.
+        This requires a logger configured in the Trainer, otherwise no data is logged.
+        The specific class that inherits from this base defines what data gets collected.
+
+        Args:
+            log_every_n_steps: The interval at which histograms should be logged. This defaults to the
+                interval defined in the Trainer. Use this to override the Trainer default.
+        """
+        super().__init__()
+        self._log_every_n_steps = log_every_n_steps
+        self._log = False
+        self._trainer = None
+        self._train_batch_idx = None
+
+    def on_train_start(self, trainer, pl_module):
+        self._log = self._is_logger_available(trainer.logger)
+        self._log_every_n_steps = self._log_every_n_steps or trainer.log_every_n_steps
+        self._trainer = trainer
+
+    def on_train_batch_start(
+        self, trainer, pl_module, batch, batch_idx, dataloader_idx
+    ):
+        self._train_batch_idx = batch_idx
+
+    def log_histograms(self, batch, group="") -> None:
+        """
+        Logs the histograms at the interval defined by `log_every_n_steps`, given a logger is available.
+
+        Args:
+            batch: torch or numpy arrays, or a collection of it (tuple, list, dict, ...), can be nested.
+                If the data appears in a dictionary, the keys are used as labels for the corresponding histogram.
+                Otherwise the histograms get labelled with an integer index.
+                Each label also has the tensor's shape as suffix.
+            group: Name under which the histograms will be grouped.
+        """
+        if not self._log or (self._train_batch_idx + 1) % self._log_every_n_steps != 0:
+            return
+
+        batch = apply_to_collection(batch, dtype=np.ndarray, function=torch.from_numpy)
+        named_tensors = dict()
+        collect_and_name_tensors(batch, output=named_tensors, parent_name=group)
+
+        for name, tensor in named_tensors.items():
+            self.log_histogram(tensor, name)
+
+    def log_histogram(self, tensor: Tensor, name: str) -> None:
+        """
+        Override this method to customize the logging of histograms.
+        Detaches the tensor from the graph and moves it to the CPU for logging.
+
+        Args:
+            tensor: The tensor for which to log a histogram
+            name: The name of the tensor as determined by the callback. Example: ``input/0/[64, 1, 28, 28]``
+        """
+        logger = self._trainer.logger
+        tensor = tensor.detach().cpu()
+        if isinstance(logger, TensorBoardLogger):
+            logger.experiment.add_histogram(
+                tag=name, values=tensor, global_step=self._trainer.global_step
+            )
+
+        if isinstance(logger, WandbLogger):
+            logger.experiment.log(
+                data={name: wandb.Histogram(tensor)}, commit=False,
+            )
+
+    def _is_logger_available(self, logger) -> bool:
+        available = True
+        if not logger:
+            rank_zero_warn("Cannot log histograms because Trainer has no logger.")
+            available = False
+        if not isinstance(logger, self.supported_loggers):
+            rank_zero_warn(
+                f"{self.__class__.__name__} does not support logging with {logger.__class__.__name__}."
+                f" Supported loggers are: {', '.join(map(lambda x: str(x.__name__), self.supported_loggers))}"
+            )
+            available = False
+        return available
+
+
+class ModuleDataMonitor(DataMonitorBase):
+
+    GROUP_NAME_INPUT = "input"
+    GROUP_NAME_OUTPUT = "output"
+
+    def __init__(
+        self,
+        submodules: Optional[Union[bool, List[str]]] = None,
+        log_every_n_steps: int = None,
+    ):
+        """
+        Args:
+            submodules: If `True`, logs the in- and output histograms of every submodule in the
+                LightningModule, including the root module itself.
+                This parameter can also take a list of names of specific submodules (see example below).
+                Default: `None`, logs only the in- and output of the root module.
+            log_every_n_steps: The interval at which histograms should be logged. This defaults to the
+                interval defined in the Trainer. Use this to override the Trainer default.
+
+        Note:
+            A too low value for `row_log_interval` may have a significant performance impact
+            especially when many submodules are involved, since the logging occurs during the forward pass.
+            It should only be used for debugging purposes.
+
+        Example:
+
+            .. code-block:: python
+
+                # log the in- and output histograms of the `forward` in LightningModule
+                trainer = Trainer(callbacks=[ModuleDataMonitor()])
+
+                # all submodules in LightningModule
+                trainer = Trainer(callbacks=[ModuleDataMonitor(submodules=True)])
+
+                # specific submodules
+                trainer = Trainer(callbacks=[ModuleDataMonitor(submodules=["generator", "generator.conv1"])])
+
+        """
+        super().__init__(log_every_n_steps=log_every_n_steps)
+        self._submodule_names = submodules
+        self._hook_handles = []
+
+    def on_train_start(self, trainer: Trainer, pl_module: LightningModule):
+        super().on_train_start(trainer, pl_module)
+        submodule_dict = dict(pl_module.named_modules())
+        self._hook_handles = []
+        for name in self._get_submodule_names(pl_module):
+            if name not in submodule_dict:
+                rank_zero_warn(
+                    f"{name} is not a valid identifier for a submodule in {pl_module.__class__.__name__},"
+                    " skipping this key."
+ ) + continue + handle = self._register_hook(name, submodule_dict[name]) + self._hook_handles.append(handle) + + def on_train_end(self, trainer, pl_module): + for handle in self._hook_handles: + handle.remove() + + def _get_submodule_names(self, root_module: nn.Module) -> List[str]: + # default is the root module only + names = [""] + + if isinstance(self._submodule_names, list): + names = self._submodule_names + + if self._submodule_names is True: + names = [name for name, _ in root_module.named_modules()] + + return names + + def _register_hook(self, module_name: str, module: nn.Module) -> RemovableHandle: + input_group_name = ( + f"{self.GROUP_NAME_INPUT}/{module_name}" + if module_name + else self.GROUP_NAME_INPUT + ) + output_group_name = ( + f"{self.GROUP_NAME_OUTPUT}/{module_name}" + if module_name + else self.GROUP_NAME_OUTPUT + ) + + def hook(_, inp, out): + inp = inp[0] if len(inp) == 1 else inp + self.log_histograms(inp, group=input_group_name) + self.log_histograms(out, group=output_group_name) + + handle = module.register_forward_hook(hook) + return handle + + +class TrainingDataMonitor(DataMonitorBase): + + GROUP_NAME = "training_step" + + def __init__(self, log_every_n_steps: int = None): + """ + Callback that logs the histogram of values in the batched data passed to `training_step`. + + Args: + log_every_n_steps: The interval at which histograms should be logged. This defaults to the + interval defined in the Trainer. Use this to override the Trainer default. + + Example: + + .. code-block:: python + + # log histogram of training data passed to `LightningModule.training_step` + trainer = Trainer(callbacks=[TrainingDataMonitor()]) + """ + super().__init__(log_every_n_steps=log_every_n_steps) + + def on_train_batch_start(self, trainer, pl_module, batch, *args, **kwargs): + super().on_train_batch_start(trainer, pl_module, batch, *args, **kwargs) + self.log_histograms(batch, group=self.GROUP_NAME) + + +def collect_and_name_tensors( + data: Any, output: Dict[str, Tensor], parent_name: str = "input" +) -> None: + """ + Recursively fetches all tensors in a (nested) collection of data (depth-first search) and names them. + Data in dictionaries get named by their corresponding keys and otherwise they get indexed by an + increasing integer. The shape of the tensor gets appended to the name as well. + + Args: + data: A collection of data (potentially nested). + output: A dictionary in which the outputs will be stored. + parent_name: Used when called recursively on a nested input data. + + Example: + >>> data = {"x": torch.zeros(2, 3), "y": {"z": torch.zeros(5)}, "w": 1} + >>> output = {} + >>> collect_and_name_tensors(data, output) + >>> output # doctest: +NORMALIZE_WHITESPACE +ELLIPSIS + {'input/x/[2, 3]': ..., 'input/y/z/[5]': ...} + """ + assert isinstance(output, dict) + if isinstance(data, Tensor): + name = f"{parent_name}/{shape2str(data)}" + output[name] = data + + if isinstance(data, dict): + for k, v in data.items(): + collect_and_name_tensors(v, output, parent_name=f"{parent_name}/{k}") + + if isinstance(data, Sequence) and not isinstance(data, str): + for i, item in enumerate(data): + collect_and_name_tensors(item, output, parent_name=f"{parent_name}/{i:d}") + + +def shape2str(tensor: Tensor) -> str: + """ + Returns the shape of a tensor in bracket notation as a string. 
+ + Example: + >>> shape2str(torch.rand(1, 2, 3)) + '[1, 2, 3]' + >>> shape2str(torch.rand(4)) + '[4]' + """ + return "[" + ", ".join(map(str, tensor.shape)) + "]" From 2bc2e98788999e329420cbc4c147e3550f52c53b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 18 Oct 2020 03:15:31 +0200 Subject: [PATCH 2/8] black format --- pl_bolts/callbacks/data_monitor.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pl_bolts/callbacks/data_monitor.py b/pl_bolts/callbacks/data_monitor.py index d62776e7bd..8c69651cc1 100644 --- a/pl_bolts/callbacks/data_monitor.py +++ b/pl_bolts/callbacks/data_monitor.py @@ -44,7 +44,7 @@ def on_train_start(self, trainer, pl_module): self._trainer = trainer def on_train_batch_start( - self, trainer, pl_module, batch, batch_idx, dataloader_idx + self, trainer, pl_module, batch, batch_idx, dataloader_idx ): self._train_batch_idx = batch_idx @@ -110,9 +110,9 @@ class ModuleDataMonitor(DataMonitorBase): GROUP_NAME_OUTPUT = "output" def __init__( - self, - submodules: Optional[Union[bool, List[str]]] = None, - log_every_n_steps: int = None, + self, + submodules: Optional[Union[bool, List[str]]] = None, + log_every_n_steps: int = None, ): """ Args: @@ -224,7 +224,7 @@ def on_train_batch_start(self, trainer, pl_module, batch, *args, **kwargs): def collect_and_name_tensors( - data: Any, output: Dict[str, Tensor], parent_name: str = "input" + data: Any, output: Dict[str, Tensor], parent_name: str = "input" ) -> None: """ Recursively fetches all tensors in a (nested) collection of data (depth-first search) and names them. From da626a983a81fd2d182c808fb43ff54d349ae9d4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 18 Oct 2020 03:25:28 +0200 Subject: [PATCH 3/8] added top import --- pl_bolts/callbacks/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pl_bolts/callbacks/__init__.py b/pl_bolts/callbacks/__init__.py index 62d94bc8b6..a8b27c1e5d 100644 --- a/pl_bolts/callbacks/__init__.py +++ b/pl_bolts/callbacks/__init__.py @@ -1,6 +1,7 @@ """ Collection of PyTorchLightning callbacks """ +from pl_bolts.callbacks.data_monitor import ModuleDataMonitor, TrainingDataMonitor from pl_bolts.callbacks.printing import PrintTableMetricsCallback from pl_bolts.callbacks.variational import LatentDimInterpolator from pl_bolts.callbacks.vision import TensorboardGenerativeModelImageSampler From 86640dceb8e7ff3f8c093d55410ba9cc01b5b8a1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 18 Oct 2020 03:25:39 +0200 Subject: [PATCH 4/8] added basic tests --- tests/callbacks/test_data_monitor.py | 76 ++++++++++++++++++++++++++++ 1 file changed, 76 insertions(+) create mode 100644 tests/callbacks/test_data_monitor.py diff --git a/tests/callbacks/test_data_monitor.py b/tests/callbacks/test_data_monitor.py new file mode 100644 index 0000000000..5493b333b3 --- /dev/null +++ b/tests/callbacks/test_data_monitor.py @@ -0,0 +1,76 @@ +from unittest import mock + +import pytest + +from pl_bolts.callbacks import TrainingDataMonitor +from pytorch_lightning import Trainer +from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger + +from pl_bolts.models import LitMNIST + + +@pytest.mark.parametrize( + ["log_every_n_steps", "max_steps", "expected_calls"], [pytest.param(3, 10, 3)] +) +@mock.patch("pl_bolts.callbacks.data_monitor.TrainingDataMonitor.log_histogram") +def test_log_interval_override( + log_histogram, tmpdir, log_every_n_steps, max_steps, expected_calls +): + """ Test 
logging interval set by log_every_n_steps argument. """ + monitor = TrainingDataMonitor(log_every_n_steps=log_every_n_steps) + model = LitMNIST(hidden_dim=10, batch_size=4) + trainer = Trainer( + default_root_dir=tmpdir, + log_every_n_steps=1, + max_steps=max_steps, + callbacks=[monitor], + ) + + trainer.fit(model) + assert log_histogram.call_count == (expected_calls * 2) # 2 tensors per log call + + +@pytest.mark.parametrize( + ["log_every_n_steps", "max_steps", "expected_calls"], + [ + pytest.param(1, 5, 5), + pytest.param(2, 5, 2), + pytest.param(5, 5, 1), + pytest.param(6, 5, 0), + ], +) +@mock.patch("pl_bolts.callbacks.data_monitor.TrainingDataMonitor.log_histogram") +def test_log_interval_fallback( + log_histogram, tmpdir, log_every_n_steps, max_steps, expected_calls +): + """ Test that if log_every_n_steps not set in the callback, fallback to what is defined in the Trainer. """ + monitor = TrainingDataMonitor() + model = LitMNIST(hidden_dim=10, batch_size=4) + trainer = Trainer( + default_root_dir=tmpdir, + log_every_n_steps=log_every_n_steps, + max_steps=max_steps, + callbacks=[monitor], + ) + trainer.fit(model) + assert log_histogram.call_count == (expected_calls * 2) # 2 tensors per log call + + +def test_no_logger_warning(): + monitor = TrainingDataMonitor() + trainer = Trainer(logger=False, callbacks=[monitor]) + with pytest.warns( + UserWarning, match="Cannot log histograms because Trainer has no logger" + ): + monitor.on_train_start(trainer, pl_module=None) + + +def test_unsupported_logger_warning(tmpdir): + monitor = TrainingDataMonitor() + trainer = Trainer( + logger=LoggerCollection([TensorBoardLogger(tmpdir)]), callbacks=[monitor] + ) + with pytest.warns( + UserWarning, match="does not support logging with LoggerCollection" + ): + monitor.on_train_start(trainer, pl_module=None) From 9f1f85e5a385679bd3b9ecc7d015e751ff8a9b73 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 18 Oct 2020 18:46:45 +0200 Subject: [PATCH 5/8] added more testing --- tests/callbacks/test_data_monitor.py | 173 +++++++++++++++++++++++++-- 1 file changed, 162 insertions(+), 11 deletions(-) diff --git a/tests/callbacks/test_data_monitor.py b/tests/callbacks/test_data_monitor.py index 5493b333b3..0a84149c4e 100644 --- a/tests/callbacks/test_data_monitor.py +++ b/tests/callbacks/test_data_monitor.py @@ -1,24 +1,26 @@ from unittest import mock +from unittest.mock import call, ANY import pytest +import torch +import torch.nn as nn -from pl_bolts.callbacks import TrainingDataMonitor +from pl_bolts.callbacks import TrainingDataMonitor, ModuleDataMonitor +from pl_bolts.models import LitMNIST from pytorch_lightning import Trainer from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger -from pl_bolts.models import LitMNIST - @pytest.mark.parametrize( ["log_every_n_steps", "max_steps", "expected_calls"], [pytest.param(3, 10, 3)] ) -@mock.patch("pl_bolts.callbacks.data_monitor.TrainingDataMonitor.log_histogram") -def test_log_interval_override( +@mock.patch("pl_bolts.callbacks.data_monitor.DataMonitorBase.log_histogram") +def test_base_log_interval_override( log_histogram, tmpdir, log_every_n_steps, max_steps, expected_calls ): """ Test logging interval set by log_every_n_steps argument. 
""" monitor = TrainingDataMonitor(log_every_n_steps=log_every_n_steps) - model = LitMNIST(hidden_dim=10, batch_size=4) + model = LitMNIST(num_workers=0) trainer = Trainer( default_root_dir=tmpdir, log_every_n_steps=1, @@ -39,13 +41,13 @@ def test_log_interval_override( pytest.param(6, 5, 0), ], ) -@mock.patch("pl_bolts.callbacks.data_monitor.TrainingDataMonitor.log_histogram") -def test_log_interval_fallback( +@mock.patch("pl_bolts.callbacks.data_monitor.DataMonitorBase.log_histogram") +def test_base_log_interval_fallback( log_histogram, tmpdir, log_every_n_steps, max_steps, expected_calls ): """ Test that if log_every_n_steps not set in the callback, fallback to what is defined in the Trainer. """ monitor = TrainingDataMonitor() - model = LitMNIST(hidden_dim=10, batch_size=4) + model = LitMNIST(num_workers=0) trainer = Trainer( default_root_dir=tmpdir, log_every_n_steps=log_every_n_steps, @@ -56,7 +58,8 @@ def test_log_interval_fallback( assert log_histogram.call_count == (expected_calls * 2) # 2 tensors per log call -def test_no_logger_warning(): +def test_base_no_logger_warning(): + """ Test a warning is displayed when Trainer has no logger. """ monitor = TrainingDataMonitor() trainer = Trainer(logger=False, callbacks=[monitor]) with pytest.warns( @@ -65,7 +68,8 @@ def test_no_logger_warning(): monitor.on_train_start(trainer, pl_module=None) -def test_unsupported_logger_warning(tmpdir): +def test_base_unsupported_logger_warning(tmpdir): + """ Test a warning is displayed when an unsupported logger is used. """ monitor = TrainingDataMonitor() trainer = Trainer( logger=LoggerCollection([TensorBoardLogger(tmpdir)]), callbacks=[monitor] @@ -74,3 +78,150 @@ def test_unsupported_logger_warning(tmpdir): UserWarning, match="does not support logging with LoggerCollection" ): monitor.on_train_start(trainer, pl_module=None) + + +@mock.patch("pl_bolts.callbacks.data_monitor.TrainingDataMonitor.log_histogram") +def test_training_data_monitor(log_histogram, tmpdir): + """ Test that the TrainingDataMonitor logs histograms of data points going into training_step. 
""" + monitor = TrainingDataMonitor() + model = LitMNIST() + trainer = Trainer( + default_root_dir=tmpdir, log_every_n_steps=1, callbacks=[monitor], + ) + monitor.on_train_start(trainer, model) + + # single tensor + example_data = torch.rand(2, 3, 4) + monitor.on_train_batch_start( + trainer, model, batch=example_data, batch_idx=0, dataloader_idx=0 + ) + assert log_histogram.call_args_list == [ + call(example_data, "training_step/[2, 3, 4]"), + ] + + log_histogram.reset_mock() + + # tuple + example_data = (torch.rand(2, 3, 4), torch.rand(5), "non-tensor") + monitor.on_train_batch_start( + trainer, model, batch=example_data, batch_idx=0, dataloader_idx=0 + ) + assert log_histogram.call_args_list == [ + call(example_data[0], "training_step/0/[2, 3, 4]"), + call(example_data[1], "training_step/1/[5]"), + ] + + log_histogram.reset_mock() + + # dict + example_data = { + "x0": torch.rand(2, 3, 4), + "x1": torch.rand(5), + "non-tensor": "non-tensor", + } + monitor.on_train_batch_start( + trainer, model, batch=example_data, batch_idx=0, dataloader_idx=0 + ) + assert log_histogram.call_args_list == [ + call(example_data["x0"], "training_step/x0/[2, 3, 4]"), + call(example_data["x1"], "training_step/x1/[5]"), + ] + + +class SubModule(nn.Module): + def __init__(self, inp, out): + super().__init__() + self.sub_layer = nn.Linear(inp, out) + + def forward(self, *args, **kwargs): + return self.sub_layer(*args, **kwargs) + + +class ModuleDataMonitorModel(nn.Module): + def __init__(self): + super().__init__() + self.layer1 = nn.Linear(12, 5) + self.layer2 = SubModule(5, 2) + + def forward(self, x): + x = x.flatten(1) + self.layer1_input = x + x = self.layer1(x) + self.layer1_output = x + x = torch.relu(x + 1) + self.layer2_input = x + x = self.layer2(x) + self.layer2_output = x + x = torch.relu(x - 2) + return x + + +@mock.patch("pl_bolts.callbacks.data_monitor.ModuleDataMonitor.log_histogram") +def test_module_data_monitor_forward(log_histogram, tmpdir): + """ Test that the default ModuleDataMonitor logs inputs and outputs of model's forward. """ + monitor = ModuleDataMonitor(submodules=None) + model = ModuleDataMonitorModel() + trainer = Trainer( + default_root_dir=tmpdir, log_every_n_steps=1, callbacks=[monitor], + ) + monitor.on_train_start(trainer, model) + monitor.on_train_batch_start( + trainer, model, batch=None, batch_idx=0, dataloader_idx=0 + ) + + example_input = torch.rand(2, 6, 2) + output = model(example_input) + assert log_histogram.call_args_list == [ + call(example_input, "input/[2, 6, 2]"), + call(output, "output/[2, 2]"), + ] + + +@mock.patch("pl_bolts.callbacks.data_monitor.ModuleDataMonitor.log_histogram") +def test_module_data_monitor_submodules_all(log_histogram, tmpdir): + """ Test that the ModuleDataMonitor logs the inputs and outputs of each submodule. 
""" + monitor = ModuleDataMonitor(submodules=True) + model = ModuleDataMonitorModel() + trainer = Trainer( + default_root_dir=tmpdir, log_every_n_steps=1, callbacks=[monitor], + ) + monitor.on_train_start(trainer, model) + monitor.on_train_batch_start( + trainer, model, batch=None, batch_idx=0, dataloader_idx=0 + ) + + example_input = torch.rand(2, 6, 2) + output = model(example_input) + assert log_histogram.call_args_list == [ + call(model.layer1_input, "input/layer1/[2, 12]"), + call(model.layer1_output, "output/layer1/[2, 5]"), + call(model.layer2_input, "input/layer2.sub_layer/[2, 5]"), + call(model.layer2_output, "output/layer2.sub_layer/[2, 2]"), + call(model.layer2_input, "input/layer2/[2, 5]"), + call(model.layer2_output, "output/layer2/[2, 2]"), + call(example_input, "input/[2, 6, 2]"), + call(output, "output/[2, 2]"), + ] + + +@mock.patch("pl_bolts.callbacks.data_monitor.ModuleDataMonitor.log_histogram") +def test_module_data_monitor_submodules_specific(log_histogram, tmpdir): + """ Test that the ModuleDataMonitor logs the inputs and outputs of selected submodules. """ + monitor = ModuleDataMonitor(submodules=["layer1", "layer2.sub_layer"]) + model = ModuleDataMonitorModel() + trainer = Trainer( + default_root_dir=tmpdir, log_every_n_steps=1, callbacks=[monitor], + ) + monitor.on_train_start(trainer, model) + monitor.on_train_batch_start( + trainer, model, batch=None, batch_idx=0, dataloader_idx=0 + ) + + example_input = torch.rand(2, 6, 2) + _ = model(example_input) + assert log_histogram.call_args_list == [ + call(model.layer1_input, "input/layer1/[2, 12]"), + call(model.layer1_output, "output/layer1/[2, 5]"), + call(model.layer2_input, "input/layer2.sub_layer/[2, 5]"), + call(model.layer2_output, "output/layer2.sub_layer/[2, 2]"), + ] From 5799e6633c8b7ce9ed4bdc808ca368762adea755 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 18 Oct 2020 20:05:29 +0200 Subject: [PATCH 6/8] row_log_interval -> log_evey_n_steps --- pl_bolts/callbacks/data_monitor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pl_bolts/callbacks/data_monitor.py b/pl_bolts/callbacks/data_monitor.py index 8c69651cc1..43207a8275 100644 --- a/pl_bolts/callbacks/data_monitor.py +++ b/pl_bolts/callbacks/data_monitor.py @@ -124,7 +124,7 @@ def __init__( interval defined in the Trainer. Use this to override the Trainer default. Note: - A too low value for `row_log_interval` may have a significant performance impact + A too low value for `log_every_n_steps` may have a significant performance impact especially when many submodules are involved, since the logging occurs during the forward pass. It should only be used for debugging purposes. From a7dcc103420f38fe71b9c655580ae20af005e3df Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 18 Oct 2020 20:05:39 +0200 Subject: [PATCH 7/8] new docs --- docs/source/info_callbacks.rst | 51 +++++++++++++++++++++++++++++++++- 1 file changed, 50 insertions(+), 1 deletion(-) diff --git a/docs/source/info_callbacks.rst b/docs/source/info_callbacks.rst index 2f11fbc902..cfeed28817 100644 --- a/docs/source/info_callbacks.rst +++ b/docs/source/info_callbacks.rst @@ -10,8 +10,57 @@ These callbacks give all sorts of useful information during training. Print Table Metrics ------------------- -This callbacks prints training metrics to a table. +This callback prints training metrics to a table. It's very bare-bones for speed purposes. .. 
autoclass:: pl_bolts.callbacks.printing.PrintTableMetricsCallback :noindex: + + +--------------- + +Data Monitoring in LightningModule +---------------------------------- +The data monitoring callbacks allow you to log and inspect the distribution of data that passes through +the training step and layers of the model. When used in combination with a supported logger, the +:class:`~pl_bolts.callbacks.data_monitor.TrainingDataMonitor` creates a histogram for each `batch` input in +:meth:`~pytorch_lightning.core.lightning.LightningModule.training_step` and sends it to the logger: + +.. code-block:: python + + from pl_bolts.callbacks import TrainingDataMonitor + from pytorch_lightning import Trainer + + # log the histograms of input data sent to LightningModule.training_step + monitor = TrainingDataMonitor(log_every_n_steps=25) + + model = YourLightningModule() + trainer = Trainer(callbacks=[monitor]) + trainer.fit() + + +The second, more advanced :class:`~pl_bolts.callbacks.data_monitor.ModuleDataMonitor` +callback tracks histograms for the data that passes through +the model itself and its submodules, i.e., it tracks all `.forward()` calls and registers the in- and outputs. +You can track all or just a selection of submodules: + +.. code-block:: python + + from pl_bolts.callbacks import ModuleDataMonitor + from pytorch_lightning import Trainer + + # log the in- and output histograms of LightningModule's `forward` + monitor = ModuleDataMonitor() + + # all submodules in LightningModule + monitor = ModuleDataMonitor(submodules=True) + + # specific submodules + monitor = ModuleDataMonitor(submodules=["generator", "generator.conv1"]) + + model = YourLightningModule() + trainer = Trainer(callbacks=[monitor]) + trainer.fit() + +This is especially useful for debugging the data flow in complex models and to identify +numerical instabilities. From 40870f6723cfaed1553f7b9e37adfe585571ab29 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 18 Oct 2020 21:13:30 +0200 Subject: [PATCH 8/8] fix wandb import error --- pl_bolts/callbacks/data_monitor.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/pl_bolts/callbacks/data_monitor.py b/pl_bolts/callbacks/data_monitor.py index 43207a8275..a623951d2e 100644 --- a/pl_bolts/callbacks/data_monitor.py +++ b/pl_bolts/callbacks/data_monitor.py @@ -4,7 +4,6 @@ import numpy as np import torch import torch.nn as nn -import wandb from torch import Tensor from torch.utils.hooks import RemovableHandle @@ -15,6 +14,12 @@ from pytorch_lightning.utilities.apply_func import apply_to_collection +try: + import wandb +except ModuleNotFoundError: + wandb = None + + class DataMonitorBase(Callback): supported_loggers = (
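
For reference, a combined usage sketch of the two callbacks introduced in this series, assuming a TensorBoard
logger; ``YourLightningModule`` and the submodule names are placeholders taken from the docs in PATCH 7/8:

.. code-block:: python

    from pytorch_lightning import Trainer
    from pytorch_lightning.loggers import TensorBoardLogger

    from pl_bolts.callbacks import ModuleDataMonitor, TrainingDataMonitor

    model = YourLightningModule()  # placeholder for any LightningModule
    trainer = Trainer(
        logger=TensorBoardLogger("tb_logs"),
        callbacks=[
            # histograms of every batch entering `training_step`, every 25 steps
            TrainingDataMonitor(log_every_n_steps=25),
            # in- and output histograms of selected submodules' `forward`
            ModuleDataMonitor(submodules=["generator", "generator.conv1"]),
        ],
    )
    trainer.fit(model)

With the guarded import from PATCH 8/8, the ``wandb`` package only needs to be installed when the histograms
are sent to a ``WandbLogger``.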