
Abstract log_figure API that implements reasonable default for all lo… #6227

Closed
Changes from 23 commits (34 commits total)
7b2d2d8: Abstract log_figure API that implements reasonable default for all lo… (Haydnspass, Feb 26, 2021)
6f8bd9a: Fix missing line at end of file, rm default arg for unlink (compatibi… (Haydnspass, Feb 26, 2021)
95b4d4e: Update CHANGELOG.md (Haydnspass, Mar 12, 2021)
dcdfbee: Update pytorch_lightning/loggers/neptune.py (Haydnspass, Mar 12, 2021)
eca36b2: Update docs/source/extensions/logging.rst (Haydnspass, Mar 12, 2021)
b0b67d7: use pathlib for path (Haydnspass, Mar 12, 2021)
a77028a: Make matplotlib extra (Haydnspass, Mar 12, 2021)
4424626: Merge branch 'master' into feature/log_figure_abstract_clean (Haydnspass, Mar 12, 2021)
e11fc0c: Make matplotlib local import (Haydnspass, Mar 12, 2021)
32f67c6: Merge master (Haydnspass, Mar 26, 2021)
6af3f1e: Change from manual to pythonic tempfile (Haydnspass, Mar 26, 2021)
7f2407b: Revert "Merge master" (Haydnspass, Mar 26, 2021)
149d7a4: Merge branch 'master' into feature/log_figure_abstract_clean (Haydnspass, Mar 26, 2021)
f545b94: Rm old fname check that is obsolete due to tempfile (Haydnspass, Mar 26, 2021)
289bd21: Merge branch 'master' into feature/log_figure_abstract_clean (Haydnspass, Apr 9, 2021)
a1a3eed: Clear spec of supported loggers. (Haydnspass, Apr 9, 2021)
1559b78: Style (Haydnspass, Apr 9, 2021)
653bd36: Add matplotlib is available checks. Refactor close_figure all test. (Haydnspass, Apr 23, 2021)
855efd1: Add mocked types when matplotlib is not available. Style. (Haydnspass, Apr 23, 2021)
cf9218c: Merge branch 'master' into feature/log_figure_abstract_clean (Haydnspass, Apr 23, 2021)
76bd3b2: PEP (Haydnspass, Apr 23, 2021)
87150dc: Rm deprecated example section (Haydnspass, May 7, 2021)
602b42c: Merge branch 'master' into feature/log_figure_abstract_clean (Haydnspass, May 7, 2021)
a4175f2: Rm unnecessary imports (Haydnspass, May 21, 2021)
ca0dd51: Merge branch 'master' into feature/log_figure_abstract_clean (Haydnspass, May 21, 2021)
812d3c1: [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], May 21, 2021)
00b59a1: Bugfix in artifact path (Haydnspass, May 21, 2021)
e775e76: Merge branch 'feature/log_figure_abstract_clean' of https://github.co… (Haydnspass, May 21, 2021)
82c6da9: Change log_figure saving logic for mlflow (Haydnspass, May 21, 2021)
fc86940: Merge branch 'master' into feature/log_figure_abstract_clean (Haydnspass, May 21, 2021)
ec64bf8: Merge branch 'master' into feature/log_figure_abstract_clean (tchaton, Jun 14, 2021)
126c0ec: Change mlflow log_figure test to actual functional test instead of mo… (Jun 18, 2021)
f972410: Move matplotlib available check to utilities.import (Jun 18, 2021)
5684258: Rm unnecessary line (Jun 18, 2021)
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -9,6 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added unified API for figure logging `log_figure` in TensorBoard, Comet, MLflow, Neptune, and Weights & Biases.

### Changed

14 changes: 14 additions & 0 deletions docs/source/extensions/logging.rst
@@ -315,6 +315,20 @@ in the `hparams tab <https://pytorch.org/docs/stable/tensorboard.html#torch.util

----------

***************
Logging Figures
***************

When training a model, it is often informative to log figures, e.g. of the input and output.
For standard ``matplotlib.pyplot`` figures, Lightning has a unified API that works with TensorBoard, Comet, MLflow, Neptune, and Weights & Biases.

.. code-block:: python

f = plt.figure()
logger.log_figure(name='dummy_figure', figure=f, step=0, close=True)

----------

*************
Snapshot code
*************
37 changes: 36 additions & 1 deletion pytorch_lightning/loggers/base.py
@@ -25,7 +25,15 @@
import torch

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.utilities import _module_available, rank_zero_only

_MATPLOTLIB_AVAILABLE = _module_available("matplotlib")

if _MATPLOTLIB_AVAILABLE:
import matplotlib.pyplot as plt
else:
from pytorch_lightning.utilities.mock_types import matplotlib as _matplotlib
plt = _matplotlib.pyplot


def rank_zero_experiment(fn: Callable) -> Callable:
@@ -173,6 +181,23 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None):
"""
pass

def log_figure(self, name: str, figure: plt.figure, step: Optional[int] = None, close: bool = True) -> None:
Comment (Contributor): I think we should add additional **kwargs here. When the user knows they are using a particular logger that supports additional arguments, we should allow them to be passed down. For example, one of the loggers has a description, which the user could then customize. This would still be logger agnostic, with loggers simply ignoring unknown **kwargs when they don't apply.

Reply (Author): Funnily enough, I had it with **kwargs twice and changed it back and forth. I do not like it because:

  • Silently ignoring unexpected arguments is not what Python does by default when you pass **kwargs.
  • It will almost certainly break for LoggerCollection unless one implements more complex logic, which then makes the whole implementation unnecessarily complicated.

Reply (Contributor): I also think **kwargs is really necessary here. Each logger can have many different arguments.

> Silently ignoring arguments that are unexpected is not what Python does by default when you pass **kwargs

I don't get this.

Reply (Author): If you add **kwargs, it will break LoggerCollection if you don't do anything about it. So what do you do if logger A in your collection can take a certain argument that logger B cannot? Ignore the argument, or raise? Or add more complex logic that looks something like this:

my_kwargs = {
    'loggerA': {'kwarga': 1},
    'loggerB': {'kwargb': 2},
}

Reply (Author): --> Needs design decision

"""
Logs a matplotlib figure.

Args:
name: name of the figure
figure: plt figure handle
step: step number at which the figure should be recorded
close: close figure after logging
"""
import matplotlib.pyplot as plt

# Default is silent (not NotImplementedError) because we want to support LoggerCollection,
# where some loggers might, and others might not, have implemented this method.
if close:
plt.close(figure)
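The per-logger **kwargs routing the author alludes to above can be prototyped outside Lightning. Below is a hedged sketch (stub loggers and a hypothetical `per_logger_kwargs` mapping, none of which is part of this PR) showing how a collection could forward logger-specific arguments without silently dropping unknown ones:

```python
class LoggerA:
    """Stub logger that supports an extra `description` argument."""

    def log_figure(self, name, figure, step=None, close=True, description=None):
        return {"logger": "A", "name": name, "description": description}


class LoggerB:
    """Stub logger that supports an extra `caption` argument."""

    def log_figure(self, name, figure, step=None, close=True, caption=None):
        return {"logger": "B", "name": name, "caption": caption}


def log_figure_collection(loggers, name, figure, step=None, per_logger_kwargs=None):
    """Fan out to every logger, forwarding only the kwargs addressed to it."""
    per_logger_kwargs = per_logger_kwargs or {}
    results = []
    for logger in loggers:
        extra = per_logger_kwargs.get(type(logger).__name__, {})
        results.append(logger.log_figure(name, figure, step=step, close=False, **extra))
    return results


results = log_figure_collection(
    [LoggerA(), LoggerB()],
    "dummy_figure",
    figure=None,
    per_logger_kwargs={
        "LoggerA": {"description": "input batch"},
        "LoggerB": {"caption": "model output"},
    },
)
```

Keying on the class name is only one possible design; it avoids the silent-ignore behaviour the author objects to, at the cost of the extra mapping.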

@staticmethod
def _convert_params(params: Union[Dict[str, Any], Namespace]) -> Dict[str, Any]:
# in case converting from namespace
@@ -377,6 +402,16 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) ->
for logger in self._logger_iterable:
logger.log_metrics(metrics, step)

def log_figure(self, name: str, figure, step: Optional[int] = None, close: bool = True) -> None:
import matplotlib.pyplot as plt

for logger in self._logger_iterable:
# don't close in the individual loggers, but once at the end
logger.log_figure(name, figure, step=step, close=False)

if close:
plt.close(figure)

def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
for logger in self._logger_iterable:
logger.log_hyperparams(params)
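The `LoggerCollection.log_figure` above forwards `close=False` to each child and closes the figure exactly once at the end. A minimal sketch of those semantics, with stand-ins for `plt.close` and the child loggers (no matplotlib or Lightning required):

```python
closed = []  # records every figure handed to the fake plt.close


def fake_plt_close(figure):
    closed.append(figure)


class ChildLogger:
    """Stub logger that records how it was called."""

    def __init__(self):
        self.calls = []

    def log_figure(self, name, figure, step=None, close=True):
        self.calls.append((name, step, close))
        if close:
            fake_plt_close(figure)


def collection_log_figure(loggers, name, figure, step=None, close=True):
    # Don't close in the individual loggers, but once at the end,
    # so every logger still sees a live figure.
    for logger in loggers:
        logger.log_figure(name, figure, step=step, close=False)
    if close:
        fake_plt_close(figure)


logger_a, logger_b = ChildLogger(), ChildLogger()
fig = object()  # stands in for a matplotlib figure
collection_log_figure([logger_a, logger_b], "dummy", fig, step=0)
```

If the collection naively forwarded `close=True`, the first logger would close the figure before the second one could log it.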
15 changes: 15 additions & 0 deletions pytorch_lightning/loggers/comet.py
@@ -49,6 +49,14 @@
CometExperiment, CometExistingExperiment, CometOfflineExperiment = None, None, None
API = None

_MATPLOTLIB_AVAILABLE = _module_available("matplotlib")

if _MATPLOTLIB_AVAILABLE:
import matplotlib.pyplot as plt
else:
from pytorch_lightning.utilities.mock_types import matplotlib as _matplotlib
plt = _matplotlib.pyplot

Comment (Contributor): Is it worth adding this just for the type annotation?

Reply (Author): Added this because someone suggested it above. Does not hurt, does it?

Reply (Author): --> Needs review decision


class CometLogger(LightningLoggerBase):
r"""
@@ -252,6 +260,13 @@ def log_metrics(self, metrics: Dict[str, Union[torch.Tensor, float]], step: Opti
metrics_without_epoch = self._add_prefix(metrics_without_epoch)
self.experiment.log_metrics(metrics_without_epoch, step=step, epoch=epoch)

@rank_zero_only
def log_figure(self, name: str, figure: plt.figure, step: Optional[int] = None, close: bool = True) -> None:
self.experiment.log_figure(figure_name=name, figure=figure, step=step)
if close:
plt.close(figure)

def reset_experiment(self):
self._experiment = None

28 changes: 28 additions & 0 deletions pytorch_lightning/loggers/mlflow.py
@@ -17,7 +17,9 @@
"""
import logging
import re
import tempfile
from argparse import Namespace
from pathlib import Path
from time import time
from typing import Any, Dict, Optional, Union

@@ -35,6 +37,15 @@
_MLFLOW_AVAILABLE = False
mlflow, MlflowClient, context = None, None, None

_MATPLOTLIB_AVAILABLE = _module_available("matplotlib")

if _MATPLOTLIB_AVAILABLE:
import matplotlib.pyplot as plt
else:
from pytorch_lightning.utilities.mock_types import matplotlib as _matplotlib
plt = _matplotlib.pyplot


# before v1.1.0
if hasattr(context, 'resolve_tags'):
from mlflow.tracking.context import resolve_tags
@@ -93,6 +104,7 @@ def any_lightning_module_function_or_hook(self):
prefix: A string to put at the beginning of metric keys.
artifact_location: The location to store run artifacts. If not provided, the server picks an appropriate
default.
figure_file_extension: File extension with which matplotlib saves the figure.

Raises:
ImportError:
@@ -109,6 +121,7 @@ def __init__(
save_dir: Optional[str] = './mlruns',
prefix: str = '',
artifact_location: Optional[str] = None,
figure_file_extension: str = '.png',
):
if mlflow is None:
raise ImportError(
@@ -128,6 +141,7 @@
self._artifact_location = artifact_location

self._mlflow_client = MlflowClient(tracking_uri)
self._figure_file_extension = figure_file_extension

@property
@rank_zero_experiment
@@ -204,6 +218,20 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) ->

self.experiment.log_metric(self.run_id, k, v, timestamp_ms, step)

@rank_zero_only
def log_figure(self, name: str, figure: plt.figure, step: Optional[int] = None, close: bool = True) -> None:
with tempfile.NamedTemporaryFile(suffix=self._figure_file_extension) as tmp_file:
figure.savefig(tmp_file)
self.experiment.log_artifact(
self.run_id,
tmp_file.name,
artifact_path=Path(self.save_dir) / ("figure_" + name)
)

if close:
plt.close(figure)

@rank_zero_only
def finalize(self, status: str = 'FINISHED') -> None:
super().finalize(status)
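The MLflow implementation above serialises the figure through `tempfile.NamedTemporaryFile` before handing the path to `log_artifact`. The same pattern reduced to the stdlib, with a stub figure and a stub client call in place of matplotlib and `MlflowClient` (note that keeping a named temp file open while re-reading it by name can fail on Windows, a known caveat of this API):

```python
import tempfile
from pathlib import Path

logged = []  # records what the stub client would upload


def log_artifact(run_id, local_path, artifact_path):
    """Stand-in for MlflowClient.log_artifact."""
    logged.append((run_id, Path(local_path).suffix, str(artifact_path)))


class StubFigure:
    """Stands in for a matplotlib figure; writes fake PNG bytes."""

    def savefig(self, fh):
        fh.write(b"\x89PNG...")
        fh.flush()


def log_figure(figure, name, run_id="run0", save_dir="mlruns", extension=".png"):
    # The temp file exists only for the duration of the upload.
    with tempfile.NamedTemporaryFile(suffix=extension) as tmp_file:
        figure.savefig(tmp_file)
        log_artifact(run_id, tmp_file.name, Path(save_dir) / f"figure_{name}")


log_figure(StubFigure(), "dummy")
```

Using a context-managed temp file replaces the manual file-name bookkeeping the earlier commits removed ("Change from manual to pythonic tempfile").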
17 changes: 17 additions & 0 deletions pytorch_lightning/loggers/neptune.py
@@ -35,6 +35,14 @@
# needed for test mocks, these tests shall be updated
neptune, Experiment = None, None

_MATPLOTLIB_AVAILABLE = _module_available("matplotlib")

if _MATPLOTLIB_AVAILABLE:
import matplotlib.pyplot as plt
else:
from pytorch_lightning.utilities.mock_types import matplotlib as _matplotlib
plt = _matplotlib.pyplot


class NeptuneLogger(LightningLoggerBase):
r"""
@@ -263,6 +271,15 @@ def log_metrics(self, metrics: Dict[str, Union[torch.Tensor, float]], step: Opti
# Lightning does not always guarantee.
self.log_metric(key, val)

@rank_zero_only
def log_figure(self, name: str, figure: plt.figure, step: Optional[int] = None, close: bool = True) -> None:
import matplotlib.pyplot as plt

description = f"step_{step}" if step is not None else None
self.experiment.log_image(name, figure, description=description)
if close:
plt.close(figure)

@rank_zero_only
def finalize(self, status: str) -> None:
super().finalize(status)
15 changes: 14 additions & 1 deletion pytorch_lightning/loggers/tensorboard.py
@@ -28,14 +28,23 @@
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.core.saving import save_hparams_to_yaml
from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, rank_zero_only, rank_zero_warn
from pytorch_lightning.utilities import (
    _OMEGACONF_AVAILABLE, _module_available, rank_zero_only, rank_zero_warn
)
from pytorch_lightning.utilities.cloud_io import get_filesystem

log = logging.getLogger(__name__)

if _OMEGACONF_AVAILABLE:
from omegaconf import Container, OmegaConf

_MATPLOTLIB_AVAILABLE = _module_available("matplotlib")

if _MATPLOTLIB_AVAILABLE:
import matplotlib.pyplot as plt
else:
from pytorch_lightning.utilities.mock_types import matplotlib as _matplotlib
plt = _matplotlib.pyplot


class TensorBoardLogger(LightningLoggerBase):
r"""
@@ -209,6 +218,10 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) ->
m = f'\n you tried to log {v} which is not currently supported. Try a dict or a scalar/tensor.'
raise ValueError(m) from ex

@rank_zero_only
def log_figure(self, name: str, figure: plt.figure, step: Optional[int] = None, close: bool = True) -> None:
self.experiment.add_figure(tag=name, figure=figure, global_step=step, close=close)

@rank_zero_only
def log_graph(self, model: LightningModule, input_array=None):
if self._log_graph:
16 changes: 16 additions & 0 deletions pytorch_lightning/loggers/wandb.py
@@ -37,6 +37,14 @@
# needed for test mocks, these tests shall be updated
wandb, Run = None, None

_MATPLOTLIB_AVAILABLE = _module_available("matplotlib")

if _MATPLOTLIB_AVAILABLE:
import matplotlib.pyplot as plt
else:
from pytorch_lightning.utilities.mock_types import matplotlib as _matplotlib
plt = _matplotlib.pyplot


class WandbLogger(LightningLoggerBase):
r"""
@@ -199,6 +207,14 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) ->
else:
self.experiment.log(metrics)

@rank_zero_only
def log_figure(self, name: str, figure: plt.figure, step: Optional[int] = None, close: bool = True) -> None:
import matplotlib.pyplot as plt

self.experiment.log({name: wandb.Image(figure)}, step=step)
if close:
plt.close(figure)

@property
def save_dir(self) -> Optional[str]:
return self._save_dir
19 changes: 19 additions & 0 deletions pytorch_lightning/utilities/mock_types.py
@@ -0,0 +1,19 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


class matplotlib:
class pyplot:
figure = None
close = None
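This mock module exists so annotations like `figure: plt.figure` still evaluate when matplotlib is not installed. A self-contained sketch of the same pattern, using a plain import guard instead of Lightning's `_module_available`:

```python
try:
    import matplotlib.pyplot as plt  # the real module when available
except ImportError:
    class _MockMatplotlib:
        class pyplot:
            figure = None  # placeholders so annotations still resolve
            close = None

    plt = _MockMatplotlib.pyplot


def log_figure(name: str, figure: plt.figure, close: bool = True) -> str:
    """The `plt.figure` annotation evaluates whether or not matplotlib exists."""
    return f"logged {name}"
```

Without such a fallback, merely importing a module whose signatures mention `plt.figure` would raise on matplotlib-free installs, since default-evaluated annotations need the name to exist.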
19 changes: 19 additions & 0 deletions tests/helpers/plotting.py
@@ -0,0 +1,19 @@
import numpy as np

from pytorch_lightning.utilities import _module_available

_MATPLOTLIB_AVAILABLE = _module_available("matplotlib")
if _MATPLOTLIB_AVAILABLE:
import matplotlib.pyplot as plt
else:
class plt:
figure = None


def dummy_figure() -> plt.figure:
"""Dummy figure to test logging capability of figures for loggers."""

f = plt.figure()
plt.plot(np.linspace(0., 1., 100), np.linspace(0., 10., 100) ** 2)

return f
41 changes: 40 additions & 1 deletion tests/loggers/test_all.py
@@ -24,14 +24,16 @@
from pytorch_lightning import Callback, Trainer
from pytorch_lightning.loggers import (
CometLogger,
CSVLogger,
MLFlowLogger,
NeptuneLogger,
TensorBoardLogger,
TestTubeLogger,
WandbLogger,
)
from pytorch_lightning.loggers.base import DummyExperiment
from tests.helpers import BoringModel
from pytorch_lightning.utilities import _module_available
from tests.helpers import BoringModel, plotting
from tests.helpers.runif import RunIf
from tests.loggers.test_comet import _patch_comet_atexit
from tests.loggers.test_mlflow import mock_mlflow_run_creation
@@ -404,3 +406,40 @@ def test_logger_with_prefix_all(tmpdir, monkeypatch):
wandb.init().step = 0
logger.log_metrics({"test": 1.0}, step=0)
logger.experiment.log.assert_called_once_with({'tmp-test': 1.0, 'trainer/global_step': 0})


@pytest.mark.skipif(
not _module_available("matplotlib"),
reason="close figure test requires matplotlib to be installed.")
@pytest.mark.parametrize("close", [True, False])
@pytest.mark.parametrize("logger_class", [
CometLogger,
CSVLogger,
MLFlowLogger,
NeptuneLogger,
TensorBoardLogger,
# Wandb has its own close_figure test
])
def test_logger_close_figure_all(tmpdir, monkeypatch, logger_class, close):
_patch_comet_atexit(monkeypatch)
try:
_test_logger_close_figure(tmpdir, monkeypatch, logger_class, close)
except (ImportError, ModuleNotFoundError):
pytest.xfail(f"close figure test requires {logger_class.__name__} dependencies to be installed.")


def _test_logger_close_figure(tmpdir, monkeypatch, logger_class, close):
_patch_comet_atexit(monkeypatch)

logger_args = _get_logger_args(logger_class, tmpdir)
logger = logger_class(**logger_args)

f = plotting.dummy_figure()

with mock.patch('matplotlib.pyplot.close') as plt_close:
logger.log_figure('dummy', f, 0, close=close)

if close:
plt_close.assert_called_once_with(f)
else:
plt_close.assert_not_called()