diff --git a/CHANGELOG.md b/CHANGELOG.md index 784a1581ee97a..cb1ee9d5eb07d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added +- Added unified API for figure logging `log_figure` in Tensorboard, Comet, ML Flow, Neptune, Weights and Biases. + + - Add `dataclass` support for `pytorch_lightning.utilities.apply_to_collection` ([#7935](https://github.com/PyTorchLightning/pytorch-lightning/pull/7935)) diff --git a/docs/source/extensions/logging.rst b/docs/source/extensions/logging.rst index 12760f0ee6898..b091579845e8d 100644 --- a/docs/source/extensions/logging.rst +++ b/docs/source/extensions/logging.rst @@ -319,6 +319,20 @@ in the `hparams tab Callable: @@ -184,6 +192,21 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None): """ pass + def log_figure(self, name: str, figure: plt.figure, step: Optional[int] = None, close: bool = True) -> None: + """ + Logs a matplotlib figure. + + Args: + name: name of the figure + figure: plt figure handle + step: step number at which the figure should be recorded + close: close figure after logging + """ + # Default is silent and not NotImplementedError because we want to support LoggerCollection + # where some loggers might others might not have implemented this method. + if close: + plt.close(figure) + @staticmethod def _convert_params(params: Union[Dict[str, Any], Namespace]) -> Dict[str, Any]: # in case converting from namespace @@ -392,6 +415,14 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> for logger in self._logger_iterable: logger.log_metrics(metrics, step) + def log_figure(self, name: str, figure, step: Optional[int] = None, close: bool = True) -> None: + for logger in self._logger_iterable: + # don't close in the individual loggers, but once at the end + logger.log_figure(name, figure, step=step, close=False) + + if close: + plt.close(figure) + def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: for logger in self._logger_iterable: logger.log_hyperparams(params) diff --git a/pytorch_lightning/loggers/comet.py b/pytorch_lightning/loggers/comet.py index 148e512f5e439..06cf0398e1f24 100644 --- a/pytorch_lightning/loggers/comet.py +++ b/pytorch_lightning/loggers/comet.py @@ -28,6 +28,8 @@ from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment from pytorch_lightning.utilities import _module_available, rank_zero_only from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.imports import _MATPLOTLIB_AVAILABLE + log = logging.getLogger(__name__) _COMET_AVAILABLE = _module_available("comet_ml") @@ -49,6 +51,12 @@ CometExperiment, CometExistingExperiment, CometOfflineExperiment = None, None, None API = None +if _MATPLOTLIB_AVAILABLE: + import matplotlib.pyplot as plt +else: + from pytorch_lightning.utilities.mock_types import matplotlib as _matplotlib + plt = _matplotlib.pyplot + class CometLogger(LightningLoggerBase): r""" @@ -252,6 +260,13 @@ def log_metrics(self, metrics: Dict[str, Union[torch.Tensor, float]], step: Opti metrics_without_epoch = self._add_prefix(metrics_without_epoch) self.experiment.log_metrics(metrics_without_epoch, step=step, epoch=epoch) + @rank_zero_only + def log_figure(self, name: str, figure: plt.figure, step: Optional[int] = None, close: bool = True) -> None: + + self.experiment.log_figure(figure_name=name, figure=figure, step=step) + if close: + plt.close(figure) + def reset_experiment(self): self._experiment = None diff --git a/pytorch_lightning/loggers/mlflow.py b/pytorch_lightning/loggers/mlflow.py index 1426adbe1104a..6930c2967fe5c 100644 --- a/pytorch_lightning/loggers/mlflow.py +++ b/pytorch_lightning/loggers/mlflow.py @@ -18,12 +18,15 @@ import logging import os import re +import tempfile from argparse import Namespace +from pathlib import Path from time import time from typing import Any, Dict, Optional, Union from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment from pytorch_lightning.utilities import _module_available, rank_zero_only, rank_zero_warn +from pytorch_lightning.utilities.imports import _MATPLOTLIB_AVAILABLE log = logging.getLogger(__name__) LOCAL_FILE_URI_PREFIX = "file:" @@ -38,6 +41,12 @@ mlflow, MlflowClient, context = None, None, None MLFLOW_RUN_NAME = "mlflow.runName" +if _MATPLOTLIB_AVAILABLE: + import matplotlib.pyplot as plt +else: + from pytorch_lightning.utilities.mock_types import matplotlib as _matplotlib + plt = _matplotlib.pyplot + # before v1.1.0 if hasattr(context, 'resolve_tags'): from mlflow.tracking.context import resolve_tags @@ -99,6 +108,7 @@ def any_lightning_module_function_or_hook(self): prefix: A string to put at the beginning of metric keys. artifact_location: The location to store run artifacts. If not provided, the server picks an appropriate default. + figure_file_extension: File extension with which matplotlib saves figure Raises: ImportError: @@ -116,6 +126,7 @@ def __init__( save_dir: Optional[str] = './mlruns', prefix: str = '', artifact_location: Optional[str] = None, + figure_file_extension='.png', ): if mlflow is None: raise ImportError( @@ -136,6 +147,7 @@ def __init__( self._artifact_location = artifact_location self._mlflow_client = MlflowClient(tracking_uri) + self._figure_file_extension = figure_file_extension @property @rank_zero_experiment @@ -220,6 +232,28 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> self.experiment.log_metric(self.run_id, k, v, timestamp_ms, step) + @rank_zero_only + def log_figure(self, name: str, figure: plt.figure, step: Optional[int] = None, close: bool = True) -> None: + + if step is not None: + figure_fname = f"figure_{name}_step_{step}{self._figure_file_extension}" + else: + figure_fname = f"figure_{name}{self._figure_file_extension}" + + # create tmp directory and semantically named filebecause + # apparently one should not write to artifact location directly + # ToDo: Once its stable, use ml_flow.log_figure + with tempfile.TemporaryDirectory() as tmp_dir: + figure_path = Path(tmp_dir) / figure_fname + figure.savefig(figure_path) + self.experiment.log_artifact( + self.run_id, + figure_path, + ) + + if close: + plt.close(figure) + @rank_zero_only def finalize(self, status: str = 'FINISHED') -> None: super().finalize(status) diff --git a/pytorch_lightning/loggers/neptune.py b/pytorch_lightning/loggers/neptune.py index aed09f11464f8..a1b6e1fc9a212 100644 --- a/pytorch_lightning/loggers/neptune.py +++ b/pytorch_lightning/loggers/neptune.py @@ -24,6 +24,7 @@ from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment from pytorch_lightning.utilities import _module_available, rank_zero_only +from pytorch_lightning.utilities.imports import _MATPLOTLIB_AVAILABLE log = logging.getLogger(__name__) _NEPTUNE_AVAILABLE = _module_available("neptune") @@ -35,6 +36,12 @@ # needed for test mocks, these tests shall be updated neptune, Experiment = None, None +if _MATPLOTLIB_AVAILABLE: + import matplotlib.pyplot as plt +else: + from pytorch_lightning.utilities.mock_types import matplotlib as _matplotlib + plt = _matplotlib.pyplot + class NeptuneLogger(LightningLoggerBase): r""" @@ -263,6 +270,15 @@ def log_metrics(self, metrics: Dict[str, Union[torch.Tensor, float]], step: Opti # Lighting does not always guarantee. self.log_metric(key, val) + @rank_zero_only + def log_figure(self, name: str, figure: plt.figure, step: Optional[int] = None, close: bool = True) -> None: + import matplotlib.pyplot as plt + + description = f"step_{step}" if step is not None else None + self.experiment.log_image(name, figure, description=description) + if close: + plt.close(figure) + @rank_zero_only def finalize(self, status: str) -> None: super().finalize(status) diff --git a/pytorch_lightning/loggers/tensorboard.py b/pytorch_lightning/loggers/tensorboard.py index b69f31ae53b32..2d1a52f2152fb 100644 --- a/pytorch_lightning/loggers/tensorboard.py +++ b/pytorch_lightning/loggers/tensorboard.py @@ -30,12 +30,19 @@ from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, rank_zero_only, rank_zero_warn from pytorch_lightning.utilities.cloud_io import get_filesystem +from pytorch_lightning.utilities.imports import _MATPLOTLIB_AVAILABLE log = logging.getLogger(__name__) if _OMEGACONF_AVAILABLE: from omegaconf import Container, OmegaConf +if _MATPLOTLIB_AVAILABLE: + import matplotlib.pyplot as plt +else: + from pytorch_lightning.utilities.mock_types import matplotlib as _matplotlib + plt = _matplotlib.pyplot + class TensorBoardLogger(LightningLoggerBase): r""" @@ -222,6 +229,10 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> m = f'\n you tried to log {v} which is not currently supported. Try a dict or a scalar/tensor.' raise ValueError(m) from ex + @rank_zero_only + def log_figure(self, name: str, figure: plt.figure, step: Optional[int] = None, close: bool = True) -> None: + self.experiment.add_figure(tag=name, figure=figure, global_step=step, close=close) + @rank_zero_only def log_graph(self, model: LightningModule, input_array=None): if self._log_graph: diff --git a/pytorch_lightning/loggers/wandb.py b/pytorch_lightning/loggers/wandb.py index c127fa037ed6b..1d991ec9bc34c 100644 --- a/pytorch_lightning/loggers/wandb.py +++ b/pytorch_lightning/loggers/wandb.py @@ -28,7 +28,7 @@ from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment from pytorch_lightning.utilities import _module_available, rank_zero_only from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.imports import _compare_version +from pytorch_lightning.utilities.imports import _compare_version, _MATPLOTLIB_AVAILABLE from pytorch_lightning.utilities.warnings import WarningCache warning_cache = WarningCache() @@ -43,6 +43,12 @@ # needed for test mocks, these tests shall be updated wandb, Run = None, None +if _MATPLOTLIB_AVAILABLE: + import matplotlib.pyplot as plt +else: + from pytorch_lightning.utilities.mock_types import matplotlib as _matplotlib + plt = _matplotlib.pyplot + class WandbLogger(LightningLoggerBase): r""" @@ -219,6 +225,14 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> else: self.experiment.log(metrics) + @rank_zero_only + def log_figure(self, name: str, figure: plt.figure, step: Optional[int] = None, close: bool = True) -> None: + import matplotlib.pyplot as plt + + self.experiment.log({name: wandb.Image(figure)}, step=step) + if close: + plt.close(figure) + @property def save_dir(self) -> Optional[str]: return self._save_dir diff --git a/pytorch_lightning/utilities/imports.py b/pytorch_lightning/utilities/imports.py index 2a51b01404821..2313754321868 100644 --- a/pytorch_lightning/utilities/imports.py +++ b/pytorch_lightning/utilities/imports.py @@ -83,6 +83,7 @@ def _compare_version(package: str, op, version) -> bool: _HYDRA_AVAILABLE = _module_available("hydra") _HYDRA_EXPERIMENTAL_AVAILABLE = _module_available("hydra.experimental") _KINETO_AVAILABLE = _TORCH_GREATER_EQUAL_1_8_1 and torch.profiler.kineto_available() +_MATPLOTLIB_AVAILABLE = _module_available("matplotlib") _NATIVE_AMP_AVAILABLE = _module_available("torch.cuda.amp") and hasattr(torch.cuda.amp, "autocast") _OMEGACONF_AVAILABLE = _module_available("omegaconf") _POPTORCH_AVAILABLE = _module_available('poptorch') diff --git a/pytorch_lightning/utilities/mock_types.py b/pytorch_lightning/utilities/mock_types.py new file mode 100644 index 0000000000000..7b6e326c353a2 --- /dev/null +++ b/pytorch_lightning/utilities/mock_types.py @@ -0,0 +1,20 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +class matplotlib: + + class pyplot: + figure = None + close = None diff --git a/tests/helpers/plotting.py b/tests/helpers/plotting.py new file mode 100644 index 0000000000000..ef9e1b7559dca --- /dev/null +++ b/tests/helpers/plotting.py @@ -0,0 +1,20 @@ +import numpy as np + +from pytorch_lightning.utilities import _module_available + +_MATPLOTLIB_AVAILABLE = _module_available("matplotlib") +if _MATPLOTLIB_AVAILABLE: + import matplotlib.pyplot as plt +else: + + class plt: + figure = None + + +def dummy_figure() -> plt.figure: + """Dummy figure to test logging capability of figures for loggers.""" + + f = plt.figure() + plt.plot(np.linspace(0., 1., 100), np.linspace(0., 10., 100)**2) + + return f diff --git a/tests/loggers/test_all.py b/tests/loggers/test_all.py index 1bad12b1f9a3d..3819ca9c9cb77 100644 --- a/tests/loggers/test_all.py +++ b/tests/loggers/test_all.py @@ -32,7 +32,8 @@ WandbLogger, ) from pytorch_lightning.loggers.base import DummyExperiment -from tests.helpers import BoringModel +from pytorch_lightning.utilities.imports import _MATPLOTLIB_AVAILABLE +from tests.helpers import BoringModel, plotting from tests.helpers.runif import RunIf from tests.loggers.test_comet import _patch_comet_atexit from tests.loggers.test_mlflow import mock_mlflow_run_creation @@ -407,3 +408,43 @@ def test_logger_with_prefix_all(tmpdir, monkeypatch): wandb.init().step = 0 logger.log_metrics({"test": 1.0}, step=0) logger.experiment.log.assert_called_once_with({'tmp-test': 1.0, 'trainer/global_step': 0}) + + +@pytest.mark.skipif( + not _MATPLOTLIB_AVAILABLE, reason="close figure test requires matplotlib to be installed." +) +@pytest.mark.parametrize("close", [True, False]) +@pytest.mark.parametrize( + "logger_class", + [ + CometLogger, + CSVLogger, + MLFlowLogger, + NeptuneLogger, + TensorBoardLogger, + # Wandb has its own close_figure test + ] +) +def test_logger_close_figure_all(tmpdir, monkeypatch, logger_class, close): + _patch_comet_atexit(monkeypatch) + try: + _test_logger_close_figure(tmpdir, monkeypatch, logger_class, close) + except (ImportError, ModuleNotFoundError): + pytest.xfail(f"pickle test requires {logger_class.__class__} dependencies to be installed.") + + +def _test_logger_close_figure(tmpdir, monkeypatch, logger_class, close): + _patch_comet_atexit(monkeypatch) + + logger_args = _get_logger_args(logger_class, tmpdir) + logger = logger_class(**logger_args) + + f = plotting.dummy_figure() + + with mock.patch('matplotlib.pyplot.close') as plt_close: + logger.log_figure('dummy', f, 0, close=close) + + if close: + plt_close.assert_called_once_with(f) + else: + plt_close.assert_not_called() diff --git a/tests/loggers/test_base.py b/tests/loggers/test_base.py index 9209083148265..f8a80930bd1f2 100644 --- a/tests/loggers/test_base.py +++ b/tests/loggers/test_base.py @@ -22,7 +22,7 @@ from pytorch_lightning.loggers import LightningLoggerBase, LoggerCollection, TensorBoardLogger from pytorch_lightning.loggers.base import DummyExperiment, DummyLogger from pytorch_lightning.utilities import rank_zero_only -from tests.helpers import BoringModel +from tests.helpers import BoringModel, plotting def test_logger_collection(): @@ -73,6 +73,10 @@ def log_hyperparams(self, params): def log_metrics(self, metrics, step): self.metrics_logged = metrics + @rank_zero_only + def log_figure(self, name, figure, step): + self.figure_logged = figure + @rank_zero_only def finalize(self, status): self.finalized_status = status @@ -170,6 +174,18 @@ def test_multiple_loggers_pickle(tmpdir): assert trainer2.logger[1].metrics_logged == {"acc": 1.0} +def test_multiple_loggers_figure(): + + logger1 = MagicMock() + logger2 = MagicMock() + + logger = LoggerCollection([logger1, logger2]) + logger.log_figure('dummy_figure', plotting.dummy_figure(), 0) + + logger1.log_figure.assert_called_once() + logger2.log_figure.assert_called_once() + + def test_adding_step_key(tmpdir): class CustomTensorBoardLogger(TensorBoardLogger): diff --git a/tests/loggers/test_comet.py b/tests/loggers/test_comet.py index 1d686c6ba8c15..1ee2fd81f0287 100644 --- a/tests/loggers/test_comet.py +++ b/tests/loggers/test_comet.py @@ -19,7 +19,8 @@ from pytorch_lightning import Trainer from pytorch_lightning.loggers import CometLogger from pytorch_lightning.utilities.exceptions import MisconfigurationException -from tests.helpers import BoringModel +from pytorch_lightning.utilities.imports import _MATPLOTLIB_AVAILABLE +from tests.helpers import BoringModel, plotting def _patch_comet_atexit(monkeypatch): @@ -222,3 +223,24 @@ def test_comet_epoch_logging(comet, comet_experiment, tmpdir, monkeypatch): logger = CometLogger(project_name="test", save_dir=tmpdir) logger.log_metrics({"test": 1, "epoch": 1}, step=123) logger.experiment.log_metrics.assert_called_once_with({"test": 1}, epoch=1, step=123) + + +@pytest.mark.skipif( + not _MATPLOTLIB_AVAILABLE, reason="close figure test requires matplotlib to be installed." +) +@patch("pytorch_lightning.loggers.comet.CometExperiment") +@patch('pytorch_lightning.loggers.comet.comet_ml') +@pytest.mark.parametrize("step_idx", [10, None]) +def test_comet_log_figure(comet, comet_experiment, tmpdir, monkeypatch, step_idx): + + _patch_comet_atexit(monkeypatch) + + logger = CometLogger(project_name="test", save_dir=tmpdir) + logger.log_figure('dummy', plotting.dummy_figure(), step_idx, close=True) # functional test + + # test whether figure is closed etc. + with patch.object(logger.experiment, 'log_figure') as mock_log: + f = plotting.dummy_figure() + logger.log_figure('dummy', f, step_idx, close=True) + + mock_log.assert_called_once_with(figure_name='dummy', figure=f, step=step_idx) diff --git a/tests/loggers/test_mlflow.py b/tests/loggers/test_mlflow.py index d798cb9f16f7e..2fec404641834 100644 --- a/tests/loggers/test_mlflow.py +++ b/tests/loggers/test_mlflow.py @@ -20,7 +20,8 @@ from pytorch_lightning import Trainer from pytorch_lightning.loggers import _MLFLOW_AVAILABLE, MLFlowLogger from pytorch_lightning.loggers.mlflow import MLFLOW_RUN_NAME, resolve_tags -from tests.helpers import BoringModel +from pytorch_lightning.utilities.imports import _MATPLOTLIB_AVAILABLE +from tests.helpers import BoringModel, plotting def mock_mlflow_run_creation(logger, experiment_name=None, experiment_id=None, run_id=None): @@ -255,3 +256,32 @@ def test_mlflow_logger_experiment_calls(client, mlflow, time, tmpdir): name='test', artifact_location='my_artifact_location', ) + + +@pytest.mark.skipif( + not _MATPLOTLIB_AVAILABLE, reason="close figure test requires matplotlib to be installed." +) +@pytest.mark.parametrize("step_idx", [10, None]) +@pytest.mark.parametrize("figure_format", ['.png', '.pdf']) +def test_mlflow_log_figure(step_idx, figure_format, tmpdir): + + logger = MLFlowLogger('test', save_dir=tmpdir, figure_file_extension=figure_format) + logger.log_figure('dummy', plotting.dummy_figure(), step_idx, close=True) + + if step_idx is not None: + fname_expct = f'figure_dummy_step_{step_idx}{figure_format}' + else: + fname_expct = f'figure_dummy{figure_format}' + path_expct = tmpdir / logger.experiment_id / logger.run_id / 'artifacts' / fname_expct + assert path_expct.check(file=True) + + # tests arguments to log_artifact + with mock.patch.object(logger.experiment, 'log_artifact') as mock_log: + f = plotting.dummy_figure() + logger.log_figure('dummy', f, step_idx, close=True) + + mock_log.assert_called_once() + if step_idx is not None: + assert mock_log.call_args_list[0][0][1].stem == 'figure_dummy_step_10' + else: + assert mock_log.call_args_list[0][0][1].stem == 'figure_dummy' diff --git a/tests/loggers/test_neptune.py b/tests/loggers/test_neptune.py index 25b417c8eae3f..ea83dd0eeab8c 100644 --- a/tests/loggers/test_neptune.py +++ b/tests/loggers/test_neptune.py @@ -13,11 +13,13 @@ # limitations under the License. from unittest.mock import MagicMock, patch +import pytest import torch from pytorch_lightning import Trainer from pytorch_lightning.loggers import NeptuneLogger -from tests.helpers import BoringModel +from pytorch_lightning.utilities.imports import _MATPLOTLIB_AVAILABLE +from tests.helpers import BoringModel, plotting @patch('pytorch_lightning.loggers.neptune.neptune') @@ -124,3 +126,19 @@ def _run_training(logger): logger_open_after_fit = _run_training(NeptuneLogger(offline_mode=True, close_after_fit=False)) assert logger_open_after_fit._experiment.stop.call_count == 0 + + +@pytest.mark.skipif( + not _MATPLOTLIB_AVAILABLE, reason="close figure test requires matplotlib to be installed." +) +@patch('pytorch_lightning.loggers.neptune.neptune') +@pytest.mark.parametrize("step_idx", [10, None]) +def test_neptune_log_figure(neptune, step_idx): + logger = NeptuneLogger(api_key='test', project_name='project') + logger.log_figure('dummy', plotting.dummy_figure(), step=42, close=True) + + with patch.object(logger.experiment, 'log_image') as mock_log: + f = plotting.dummy_figure() + logger.log_figure('dummy', f, step_idx, close=True) + + mock_log.assert_called_once_with('dummy', f, description=f"step_{step_idx}" if step_idx is not None else None) diff --git a/tests/loggers/test_tensorboard.py b/tests/loggers/test_tensorboard.py index ffd89a0c14984..4bb6fc4b7b64c 100644 --- a/tests/loggers/test_tensorboard.py +++ b/tests/loggers/test_tensorboard.py @@ -25,7 +25,8 @@ from pytorch_lightning import Trainer from pytorch_lightning.loggers import TensorBoardLogger -from tests.helpers import BoringModel +from pytorch_lightning.utilities.imports import _MATPLOTLIB_AVAILABLE +from tests.helpers import BoringModel, plotting from tests.helpers.runif import RunIf @@ -177,6 +178,22 @@ def test_tensorboard_log_metrics(tmpdir, step_idx): logger.log_metrics(metrics, step_idx) +@pytest.mark.skipif( + not _MATPLOTLIB_AVAILABLE, reason="close figure test requires matplotlib to be installed." +) +@pytest.mark.parametrize("step_idx", [10, None]) +def test_tensorboard_log_figure(tmpdir, step_idx): + logger = TensorBoardLogger(tmpdir) + logger.log_figure('dummy', plotting.dummy_figure(), step_idx, close=True) # functional test + + # test whether figure is closed etc. + with mock.patch.object(logger.experiment, 'add_figure') as mock_log: + f = plotting.dummy_figure() + logger.log_figure('dummy', f, step_idx, close=True) + + mock_log.assert_called_once_with(tag="dummy", figure=f, global_step=step_idx, close=True) + + def test_tensorboard_log_hyperparams(tmpdir): logger = TensorBoardLogger(tmpdir) hparams = { diff --git a/tests/loggers/test_wandb.py b/tests/loggers/test_wandb.py index 27185b911b6d0..46ca8edfd343c 100644 --- a/tests/loggers/test_wandb.py +++ b/tests/loggers/test_wandb.py @@ -21,7 +21,8 @@ from pytorch_lightning import Trainer from pytorch_lightning.loggers import WandbLogger from pytorch_lightning.utilities.exceptions import MisconfigurationException -from tests.helpers import BoringModel +from pytorch_lightning.utilities.imports import _MATPLOTLIB_AVAILABLE +from tests.helpers import BoringModel, plotting @mock.patch('pytorch_lightning.loggers.wandb.wandb') @@ -254,3 +255,39 @@ def test_wandb_logger_offline_log_model(wandb, tmpdir): """ Test that log_model=True raises an error in offline mode """ with pytest.raises(MisconfigurationException, match='checkpoints cannot be uploaded in offline mode'): _ = WandbLogger(save_dir=str(tmpdir), offline=True, log_model=True) + + +@pytest.mark.skipif( + not _MATPLOTLIB_AVAILABLE, reason="close figure test requires matplotlib to be installed." +) +@mock.patch('pytorch_lightning.loggers.wandb.wandb') +@pytest.mark.parametrize("step_idx", [10, None]) +def test_wandb_logger_log_figure(wandb, tmpdir, step_idx): + + logger = WandbLogger(save_dir=str(tmpdir), offline=True) + logger.log_figure('dummy', plotting.dummy_figure(), step_idx, close=True) # functional test + + with mock.patch.object(logger.experiment, 'log') as mock_log: + f = plotting.dummy_figure() + logger.log_figure('dummy', f, step_idx, close=True) + + mock_log.assert_called_once_with({'dummy': wandb.Image(f)}, step=step_idx) + + +@pytest.mark.skipif( + not _MATPLOTLIB_AVAILABLE, reason="close figure test requires matplotlib to be installed." +) +@pytest.mark.parametrize("close", [True, False]) +@mock.patch('pytorch_lightning.loggers.wandb.wandb') +def test_close_figure(wandb, tmpdir, close): + f = plotting.dummy_figure() + + logger = WandbLogger(save_dir=str(tmpdir), offline=True) + + with mock.patch('matplotlib.pyplot.close') as plt_close: + logger.log_figure('dummy', f, 0, close=close) + + if close: + plt_close.assert_called_once_with(f) + else: + plt_close.assert_not_called()