From 7b2d2d8028a8a50f3b5634857f5fba750617faee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lucas-Rapahel=20Mu=CC=88ller?= Date: Fri, 26 Feb 2021 16:57:51 +0100 Subject: [PATCH 01/25] Abstract log_figure API that implements reasonable default for all loggers Takes a plt.figure (along with some metadata) and logs it to the respective logger - CSV logger and testtube are in silent no-op - matplotlib becomes base requirement --- CHANGELOG.md | 3 +++ docs/source/extensions/logging.rst | 13 ++++++++++++ pytorch_lightning/loggers/base.py | 24 ++++++++++++++++++++++ pytorch_lightning/loggers/comet.py | 8 ++++++++ pytorch_lightning/loggers/mlflow.py | 18 ++++++++++++++++ pytorch_lightning/loggers/neptune.py | 13 ++++++++++++ pytorch_lightning/loggers/tensorboard.py | 5 +++++ pytorch_lightning/loggers/wandb.py | 8 ++++++++ requirements.txt | 1 + requirements/extra.txt | 1 - tests/helpers/plotting.py | 10 +++++++++ tests/loggers/test_all.py | 26 +++++++++++++++++++++++- tests/loggers/test_base.py | 18 +++++++++++++++- tests/loggers/test_comet.py | 19 ++++++++++++++++- tests/loggers/test_mlflow.py | 22 +++++++++++++++++++- tests/loggers/test_neptune.py | 17 +++++++++++++++- tests/loggers/test_tensorboard.py | 15 +++++++++++++- tests/loggers/test_wandb.py | 15 +++++++++++++- 18 files changed, 228 insertions(+), 8 deletions(-) create mode 100644 tests/helpers/plotting.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 6e1d4b1d6c983..ce5848006ddbb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added +- Logger have unified API for figure logging `log_figure`. CSV and Testtube excluded. + + - Added a way to print to terminal without breaking up the progress bar ([#5470](https://github.com/PyTorchLightning/pytorch-lightning/pull/5470)) diff --git a/docs/source/extensions/logging.rst b/docs/source/extensions/logging.rst index 026040f03a330..5d747ca27a6d7 100644 --- a/docs/source/extensions/logging.rst +++ b/docs/source/extensions/logging.rst @@ -290,6 +290,19 @@ in the `hparams tab None: + """ + Logs a matplotlib figure. + Args: + name: name of the figure + figure: plt figure handle + step: step number at which the figure should be recorded + close: close figure after logging + """ + """Default is silent. 
+ Not raising NotImplemented because one could have multiple logger where only some support log_figure.""" + if close: + plt.close(figure) + @staticmethod def _convert_params(params: Union[Dict[str, Any], Namespace]) -> Dict[str, Any]: # in case converting from namespace @@ -375,6 +390,15 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> for logger in self._logger_iterable: logger.log_metrics(metrics, step) + def log_figure(self, name: str, figure: plt.figure, step: Optional[int] = None, close: bool = True) -> None: + + for logger in self._logger_iterable: + # don't close in the individual loggers, but once at the end + logger.log_figure(name, figure, step=step, close=False) + + if close: + plt.close(figure) + def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: for logger in self._logger_iterable: logger.log_hyperparams(params) diff --git a/pytorch_lightning/loggers/comet.py b/pytorch_lightning/loggers/comet.py index 31c768fa5f37b..057456fc2c21b 100644 --- a/pytorch_lightning/loggers/comet.py +++ b/pytorch_lightning/loggers/comet.py @@ -20,6 +20,7 @@ from argparse import Namespace from typing import Any, Dict, Optional, Union +import matplotlib.pyplot as plt import torch from torch import is_tensor @@ -251,6 +252,13 @@ def log_metrics(self, metrics: Dict[str, Union[torch.Tensor, float]], step: Opti metrics_without_epoch = self._add_prefix(metrics_without_epoch) self.experiment.log_metrics(metrics_without_epoch, step=step, epoch=epoch) + @rank_zero_only + def log_figure(self, name: str, figure: plt.figure, step: Optional[int] = None, close: bool = True) -> None: + self.experiment.log_figure(figure_name=name, figure=figure, step=step) + + if close: + plt.close(figure) + def reset_experiment(self): self._experiment = None diff --git a/pytorch_lightning/loggers/mlflow.py b/pytorch_lightning/loggers/mlflow.py index 4aa4c67b576ec..901636fc70c7b 100644 --- a/pytorch_lightning/loggers/mlflow.py +++ b/pytorch_lightning/loggers/mlflow.py @@ -17,9 +17,12 @@ """ import re from argparse import Namespace +from pathlib import Path from time import time from typing import Any, Dict, Optional, Union +import matplotlib.pyplot as plt + from pytorch_lightning import _logger as log from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment from pytorch_lightning.utilities import _module_available, rank_zero_only, rank_zero_warn @@ -78,6 +81,7 @@ def any_lightning_module_function_or_hook(self): Defaults to `./mlflow` if `tracking_uri` is not provided. Has no effect if `tracking_uri` is provided. prefix: A string to put at the beginning of metric keys. 
+ figure_file_extension: File extension with which matplotlib saves figure Raises: ImportError: @@ -93,6 +97,7 @@ def __init__( tags: Optional[Dict[str, Any]] = None, save_dir: Optional[str] = './mlruns', prefix: str = '', + figure_file_extension='.png', ): if mlflow is None: raise ImportError( @@ -110,6 +115,7 @@ def __init__( self.tags = tags self._prefix = prefix self._mlflow_client = MlflowClient(tracking_uri) + self._figure_file_extension = figure_file_extension @property @rank_zero_experiment @@ -183,6 +189,18 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> self.experiment.log_metric(self.run_id, k, v, timestamp_ms, step) + @rank_zero_only + def log_figure(self, name: str, figure: plt.figure, step: Optional[int] = None, close: bool = True) -> None: + + filename = self.save_dir + '/' + name + f"_step_{step}" + self._figure_file_extension + figure.savefig(filename) + self.experiment.log_artifact(self.run_id, filename, artifact_path="figure_" + name) + + Path(filename).unlink(missing_ok=False) # delete temporary file + + if close: + plt.close(figure) + @rank_zero_only def finalize(self, status: str = 'FINISHED') -> None: super().finalize(status) diff --git a/pytorch_lightning/loggers/neptune.py b/pytorch_lightning/loggers/neptune.py index d4f24567cab6a..1d32744dab592 100644 --- a/pytorch_lightning/loggers/neptune.py +++ b/pytorch_lightning/loggers/neptune.py @@ -18,6 +18,7 @@ from argparse import Namespace from typing import Any, Dict, Iterable, Optional, Union +import matplotlib.pyplot as plt import torch from torch import is_tensor @@ -262,6 +263,18 @@ def log_metrics(self, metrics: Dict[str, Union[torch.Tensor, float]], step: Opti # Lighting does not always guarantee. self.log_metric(key, val) + @rank_zero_only + def log_figure(self, name: str, figure: plt.figure, step: Optional[int] = None, close: bool = True) -> None: + if step is not None: + description = f"step_{step}" + else: + description = None + + self.experiment.log_image(name, figure, description=description) + + if close: + plt.close(figure) + @rank_zero_only def finalize(self, status: str) -> None: super().finalize(status) diff --git a/pytorch_lightning/loggers/tensorboard.py b/pytorch_lightning/loggers/tensorboard.py index 0485868fa2ef1..9583e897ef046 100644 --- a/pytorch_lightning/loggers/tensorboard.py +++ b/pytorch_lightning/loggers/tensorboard.py @@ -20,6 +20,7 @@ from argparse import Namespace from typing import Any, Dict, Optional, Union +import matplotlib.pyplot as plt import torch from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard.summary import hparams @@ -204,6 +205,10 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> m = f'\n you tried to log {v} which is not currently supported. Try a dict or a scalar/tensor.' 
type(ex)(ex.message + m) + @rank_zero_only + def log_figure(self, name: str, figure: plt.figure, step: Optional[int] = None, close: bool = True) -> None: + self.experiment.add_figure(tag=name, figure=figure, global_step=step, close=close) + @rank_zero_only def log_graph(self, model: LightningModule, input_array=None): if self._log_graph: diff --git a/pytorch_lightning/loggers/wandb.py b/pytorch_lightning/loggers/wandb.py index 68c7e228cc14a..3007b8c9fdffb 100644 --- a/pytorch_lightning/loggers/wandb.py +++ b/pytorch_lightning/loggers/wandb.py @@ -19,6 +19,7 @@ from argparse import Namespace from typing import Any, Dict, Optional, Union +import matplotlib.pyplot as plt import torch.nn as nn from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment @@ -200,6 +201,13 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> else: self.experiment.log(metrics) + @rank_zero_only + def log_figure(self, name: str, figure: plt.figure, step: Optional[int] = None, close: bool = True) -> None: + self.experiment.log({name: wandb.Image(figure)}, step=step) + + if close: + plt.close(figure) + @property def save_dir(self) -> Optional[str]: return self._save_dir diff --git a/requirements.txt b/requirements.txt index bdfd6601ba4c2..e0871772f5086 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,7 @@ # the default package dependencies numpy>=1.16.6 +matplotlib>3.1 torch>=1.4 future>=0.17.1 # required for builtins in setup.py # pyyaml>=3.13 diff --git a/requirements/extra.txt b/requirements/extra.txt index 0e7dffbcb39b0..76b84ab7c0151 100644 --- a/requirements/extra.txt +++ b/requirements/extra.txt @@ -1,6 +1,5 @@ # extended list of package dependencies to reach full functionality -matplotlib>3.1 horovod>=0.21.2 # no need to install with [pytorch] as pytorch is already installed omegaconf>=2.0.1 torchtext>=0.5, <0.7 # TODO: temporary fix fix for compatibility diff --git a/tests/helpers/plotting.py b/tests/helpers/plotting.py new file mode 100644 index 0000000000000..f3910240657c5 --- /dev/null +++ b/tests/helpers/plotting.py @@ -0,0 +1,10 @@ +import numpy as np +from matplotlib import pyplot as plt + + +def dummy_figure(): + """Dummy figure to test logging capability of figures for loggers.""" + f = plt.figure() + plt.plot(np.linspace(0., 1., 100), np.linspace(0., 10., 100) ** 2) + + return f diff --git a/tests/loggers/test_all.py b/tests/loggers/test_all.py index 02721ba436743..f217fb27540b6 100644 --- a/tests/loggers/test_all.py +++ b/tests/loggers/test_all.py @@ -25,6 +25,7 @@ from pytorch_lightning import Callback, Trainer from pytorch_lightning.loggers import ( CometLogger, + CSVLogger, MLFlowLogger, NeptuneLogger, TensorBoardLogger, @@ -33,7 +34,7 @@ ) from pytorch_lightning.loggers.base import DummyExperiment from pytorch_lightning.trainer.states import TrainerState -from tests.helpers import BoringModel +from tests.helpers import BoringModel, plotting from tests.loggers.test_comet import _patch_comet_atexit from tests.loggers.test_mlflow import mock_mlflow_run_creation @@ -405,3 +406,26 @@ def test_logger_with_prefix_all(tmpdir, monkeypatch): wandb.init().step = 0 logger.log_metrics({"test": 1.0}, step=0) logger.experiment.log.assert_called_once_with({'tmp-test': 1.0}, step=0) + + +@pytest.mark.parametrize("close", [True, False]) +@pytest.mark.parametrize("logger_class", [ + CometLogger, + CSVLogger, + MLFlowLogger, + NeptuneLogger, + TensorBoardLogger, + WandbLogger, +]) +def test_logger_close_figure_all(logger_class, close, tmpdir): + 
f = plotting.dummy_figure() + + logger = _instantiate_logger(logger_class, save_dir=tmpdir) + + with mock.patch('matplotlib.pyplot.close') as plt_close: + logger.log_figure('dummy', f, 0, close=close) + + if close: + plt_close.assert_called_once_with(f) + else: + plt_close.assert_not_called() diff --git a/tests/loggers/test_base.py b/tests/loggers/test_base.py index c48fef5e04b49..2cdd94ded3272 100644 --- a/tests/loggers/test_base.py +++ b/tests/loggers/test_base.py @@ -23,7 +23,7 @@ from pytorch_lightning.loggers.base import DummyExperiment, DummyLogger from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.utilities import rank_zero_only -from tests.helpers import BoringModel +from tests.helpers import BoringModel, plotting def test_logger_collection(): @@ -73,6 +73,10 @@ def log_hyperparams(self, params): def log_metrics(self, metrics, step): self.metrics_logged = metrics + @rank_zero_only + def log_figure(self, name, figure, step): + self.figure_logged = figure + @rank_zero_only def finalize(self, status): self.finalized_status = status @@ -166,6 +170,18 @@ def test_multiple_loggers_pickle(tmpdir): assert trainer2.logger[1].metrics_logged == {"acc": 1.0} +def test_multiple_loggers_figure(): + + logger1 = MagicMock() + logger2 = MagicMock() + + logger = LoggerCollection([logger1, logger2]) + logger.log_figure('dummy_figure', plotting.dummy_figure(), 0) + + logger1.log_figure.assert_called_once() + logger2.log_figure.assert_called_once() + + def test_adding_step_key(tmpdir): class CustomTensorBoardLogger(TensorBoardLogger): diff --git a/tests/loggers/test_comet.py b/tests/loggers/test_comet.py index 1d686c6ba8c15..a0d7357da4108 100644 --- a/tests/loggers/test_comet.py +++ b/tests/loggers/test_comet.py @@ -19,7 +19,7 @@ from pytorch_lightning import Trainer from pytorch_lightning.loggers import CometLogger from pytorch_lightning.utilities.exceptions import MisconfigurationException -from tests.helpers import BoringModel +from tests.helpers import BoringModel, plotting def _patch_comet_atexit(monkeypatch): @@ -222,3 +222,20 @@ def test_comet_epoch_logging(comet, comet_experiment, tmpdir, monkeypatch): logger = CometLogger(project_name="test", save_dir=tmpdir) logger.log_metrics({"test": 1, "epoch": 1}, step=123) logger.experiment.log_metrics.assert_called_once_with({"test": 1}, epoch=1, step=123) + + +@patch("pytorch_lightning.loggers.comet.CometExperiment") +@patch('pytorch_lightning.loggers.comet.comet_ml') +@pytest.mark.parametrize("step_idx", [10, None]) +def test_comet_log_figure(comet, comet_experiment, tmpdir, monkeypatch, step_idx): + + _patch_comet_atexit(monkeypatch) + logger = CometLogger(project_name="test", save_dir=tmpdir) + logger.log_figure('dummy', plotting.dummy_figure(), step_idx, close=True) # functional test + + # test whether figure is closed etc. 
+ with patch.object(logger.experiment, 'log_figure') as mock_log: + f = plotting.dummy_figure() + logger.log_figure('dummy', f, step_idx, close=True) + + mock_log.assert_called_once_with(figure_name='dummy', figure=f, step=step_idx) \ No newline at end of file diff --git a/tests/loggers/test_mlflow.py b/tests/loggers/test_mlflow.py index d2673f48b871b..c21e2976ef47b 100644 --- a/tests/loggers/test_mlflow.py +++ b/tests/loggers/test_mlflow.py @@ -19,7 +19,7 @@ from pytorch_lightning import Trainer from pytorch_lightning.loggers import _MLFLOW_AVAILABLE, MLFlowLogger -from tests.helpers import BoringModel +from tests.helpers import BoringModel, plotting def mock_mlflow_run_creation(logger, experiment_name=None, experiment_id=None, run_id=None): @@ -199,3 +199,23 @@ def test_mlflow_logger_with_long_param_value(client, mlflow, tmpdir): with pytest.warns(RuntimeWarning, match=f'Discard {key}={value}'): logger.log_hyperparams(params) + + +@mock.patch('pytorch_lightning.loggers.mlflow.mlflow') +@mock.patch('pytorch_lightning.loggers.mlflow.MlflowClient') +@pytest.mark.parametrize("step_idx", [10, None]) +@pytest.mark.parametrize("figure_format", ['.png', '.pdf']) +def test_mlflow_log_figure(client, mlflow, step_idx, figure_format, tmpdir): + + logger = MLFlowLogger('test', save_dir=tmpdir, figure_file_extension=figure_format) + logger.log_figure('dummy', plotting.dummy_figure(), step_idx, close=True) # functional test + + # test whether figure is closed etc. + with mock.patch.object(logger.experiment, 'log_artifact') as mock_log: + f = plotting.dummy_figure() + logger.log_figure('dummy', f, step_idx, close=True) + + fname_expect = logger.save_dir + f'/dummy_step_{step_idx}{figure_format}' + artifact_expect = 'figure_dummy' + + mock_log.assert_called_once_with(logger.run_id, fname_expect, artifact_path=artifact_expect) diff --git a/tests/loggers/test_neptune.py b/tests/loggers/test_neptune.py index 3ac763cc87b4f..4de0d1c93d34e 100644 --- a/tests/loggers/test_neptune.py +++ b/tests/loggers/test_neptune.py @@ -13,11 +13,12 @@ # limitations under the License. 
from unittest.mock import MagicMock, patch +import pytest import torch from pytorch_lightning import Trainer from pytorch_lightning.loggers import NeptuneLogger -from tests.helpers import BoringModel +from tests.helpers import BoringModel, plotting @patch('pytorch_lightning.loggers.neptune.neptune') @@ -124,3 +125,17 @@ def _run_training(logger): logger_open_after_fit = _run_training(NeptuneLogger(offline_mode=True, close_after_fit=False)) assert logger_open_after_fit._experiment.stop.call_count == 0 + + +@patch('pytorch_lightning.loggers.neptune.neptune') +@pytest.mark.parametrize("step_idx", [10, None]) +def test_neptune_log_figure(neptune, step_idx): + logger = NeptuneLogger(api_key='test', project_name='project') + logger.log_figure('dummy', plotting.dummy_figure(), step=42, close=True) + + with patch.object(logger.experiment, 'log_image') as mock_log: + f = plotting.dummy_figure() + logger.log_figure('dummy', f, step_idx, close=True) + + mock_log.assert_called_once_with('dummy', f, + description=f"step_{step_idx}" if step_idx is not None else None) diff --git a/tests/loggers/test_tensorboard.py b/tests/loggers/test_tensorboard.py index e5e3f231d3ac7..cbf8c0cf915e4 100644 --- a/tests/loggers/test_tensorboard.py +++ b/tests/loggers/test_tensorboard.py @@ -24,7 +24,7 @@ from pytorch_lightning import Trainer from pytorch_lightning.loggers import TensorBoardLogger -from tests.helpers import BoringModel +from tests.helpers import BoringModel, plotting @pytest.mark.skipif( @@ -133,6 +133,19 @@ def test_tensorboard_log_metrics(tmpdir, step_idx): logger.log_metrics(metrics, step_idx) +@pytest.mark.parametrize("step_idx", [10, None]) +def test_tensorboard_log_figure(tmpdir, step_idx): + logger = TensorBoardLogger(tmpdir) + logger.log_figure('dummy', plotting.dummy_figure(), step_idx, close=True) # functional test + + # test whether figure is closed etc. 
+ with mock.patch.object(logger.experiment, 'add_figure') as mock_log: + f = plotting.dummy_figure() + logger.log_figure('dummy', f, step_idx, close=True) + + mock_log.assert_called_once_with(tag="dummy", figure=f, global_step=step_idx, close=True) + + def test_tensorboard_log_hyperparams(tmpdir): logger = TensorBoardLogger(tmpdir) hparams = { diff --git a/tests/loggers/test_wandb.py b/tests/loggers/test_wandb.py index e5b9b891b88c1..205db93f3bace 100644 --- a/tests/loggers/test_wandb.py +++ b/tests/loggers/test_wandb.py @@ -22,7 +22,7 @@ from pytorch_lightning import Trainer from pytorch_lightning.loggers import WandbLogger from pytorch_lightning.utilities.exceptions import MisconfigurationException -from tests.helpers import BoringModel +from tests.helpers import BoringModel, plotting def get_warnings(recwarn): @@ -214,3 +214,16 @@ def test_wandb_logger_offline_log_model(wandb, tmpdir): """ Test that log_model=True raises an error in offline mode """ with pytest.raises(MisconfigurationException, match='checkpoints cannot be uploaded in offline mode'): _ = WandbLogger(save_dir=str(tmpdir), offline=True, log_model=True) + + +@mock.patch('pytorch_lightning.loggers.wandb.wandb') +@pytest.mark.parametrize("step_idx", [10, None]) +def test_wandb_logger_log_figure(wandb, step_idx): + logger = WandbLogger(anonymous=True, offline=True) + logger.log_figure('dummy', plotting.dummy_figure(), step_idx, close=True) # functional test + + with mock.patch.object(logger.experiment, 'log') as mock_log: + f = plotting.dummy_figure() + logger.log_figure('dummy', f, step_idx, close=True) + + mock_log.assert_called_once_with({'dummy': wandb.Image(f)}, step=step_idx) From 6f8bd9abb366027cd487701a4f9f40eb51ca34ed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lucas-Rapahel=20Mu=CC=88ller?= Date: Fri, 26 Feb 2021 18:58:53 +0100 Subject: [PATCH 02/25] Fix missing line at end of file, rm default arg for unlink (compatibility) --- pytorch_lightning/loggers/mlflow.py | 2 +- tests/loggers/test_comet.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/loggers/mlflow.py b/pytorch_lightning/loggers/mlflow.py index 901636fc70c7b..f6453d3c6d92f 100644 --- a/pytorch_lightning/loggers/mlflow.py +++ b/pytorch_lightning/loggers/mlflow.py @@ -196,7 +196,7 @@ def log_figure(self, name: str, figure: plt.figure, step: Optional[int] = None, figure.savefig(filename) self.experiment.log_artifact(self.run_id, filename, artifact_path="figure_" + name) - Path(filename).unlink(missing_ok=False) # delete temporary file + Path(filename).unlink() # delete temporary file if close: plt.close(figure) diff --git a/tests/loggers/test_comet.py b/tests/loggers/test_comet.py index a0d7357da4108..eee7eddea7ab7 100644 --- a/tests/loggers/test_comet.py +++ b/tests/loggers/test_comet.py @@ -238,4 +238,4 @@ def test_comet_log_figure(comet, comet_experiment, tmpdir, monkeypatch, step_idx f = plotting.dummy_figure() logger.log_figure('dummy', f, step_idx, close=True) - mock_log.assert_called_once_with(figure_name='dummy', figure=f, step=step_idx) \ No newline at end of file + mock_log.assert_called_once_with(figure_name='dummy', figure=f, step=step_idx) From 95b4d4e359e326cfb594eb24b58e0ef43b1f13da Mon Sep 17 00:00:00 2001 From: Lucas Date: Fri, 12 Mar 2021 10:32:21 +0100 Subject: [PATCH 03/25] Update CHANGELOG.md MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Adrian Wälchli --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/CHANGELOG.md b/CHANGELOG.md index ce5848006ddbb..78d0170eb0cd8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,7 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added -- Logger have unified API for figure logging `log_figure`. CSV and Testtube excluded. +- Added unified API for figure logging `log_figure` in {insert all supported logger names here}. - Added a way to print to terminal without breaking up the progress bar ([#5470](https://github.com/PyTorchLightning/pytorch-lightning/pull/5470)) From dcdfbee56eabc158081f1fba51dbceeb99193199 Mon Sep 17 00:00:00 2001 From: Lucas Date: Fri, 12 Mar 2021 10:32:45 +0100 Subject: [PATCH 04/25] Update pytorch_lightning/loggers/neptune.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Adrian Wälchli --- pytorch_lightning/loggers/neptune.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/pytorch_lightning/loggers/neptune.py b/pytorch_lightning/loggers/neptune.py index 1d32744dab592..f62a3f04d619c 100644 --- a/pytorch_lightning/loggers/neptune.py +++ b/pytorch_lightning/loggers/neptune.py @@ -265,10 +265,7 @@ def log_metrics(self, metrics: Dict[str, Union[torch.Tensor, float]], step: Opti @rank_zero_only def log_figure(self, name: str, figure: plt.figure, step: Optional[int] = None, close: bool = True) -> None: - if step is not None: - description = f"step_{step}" - else: - description = None + description = f"step_{step}" if step is not None else None self.experiment.log_image(name, figure, description=description) From eca36b2eaefe0102ace89980a3cf5b16f08fb15d Mon Sep 17 00:00:00 2001 From: Lucas Date: Fri, 12 Mar 2021 10:33:31 +0100 Subject: [PATCH 05/25] Update docs/source/extensions/logging.rst Co-authored-by: Nicki Skafte --- docs/source/extensions/logging.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/extensions/logging.rst b/docs/source/extensions/logging.rst index 5d747ca27a6d7..e76cd9c6ed3ec 100644 --- a/docs/source/extensions/logging.rst +++ b/docs/source/extensions/logging.rst @@ -295,7 +295,7 @@ Logging Figures *************** When training a model, often it is very indicative to log figures, e.g. of the in- and output. -For standard matplotlib figures, Lightning has a unified API that works with the implemented loggers. +For standard ``matplotlib.pyplot`` figures, Lightning has a unified API that works with most of the implemented loggers. .. 
code-block:: python f = plt.figure() From b0b67d70eeac80d6a315506265a0fe21097af2e5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lucas-Rapahel=20Mu=CC=88ller?= Date: Fri, 12 Mar 2021 11:15:42 +0100 Subject: [PATCH 06/25] use pathlib for path --- pytorch_lightning/loggers/base.py | 1 - pytorch_lightning/loggers/mlflow.py | 6 +++--- tests/loggers/test_mlflow.py | 3 ++- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/loggers/base.py b/pytorch_lightning/loggers/base.py index f460808ff1632..2236318e8b991 100644 --- a/pytorch_lightning/loggers/base.py +++ b/pytorch_lightning/loggers/base.py @@ -391,7 +391,6 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> logger.log_metrics(metrics, step) def log_figure(self, name: str, figure: plt.figure, step: Optional[int] = None, close: bool = True) -> None: - for logger in self._logger_iterable: # don't close in the individual loggers, but once at the end logger.log_figure(name, figure, step=step, close=False) diff --git a/pytorch_lightning/loggers/mlflow.py b/pytorch_lightning/loggers/mlflow.py index f6453d3c6d92f..f2591efb68014 100644 --- a/pytorch_lightning/loggers/mlflow.py +++ b/pytorch_lightning/loggers/mlflow.py @@ -191,12 +191,12 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> @rank_zero_only def log_figure(self, name: str, figure: plt.figure, step: Optional[int] = None, close: bool = True) -> None: - - filename = self.save_dir + '/' + name + f"_step_{step}" + self._figure_file_extension + # save temporary file until logged + filename = Path(self.save_dir) / (name + f"_step_{step}" + self._figure_file_extension) figure.savefig(filename) self.experiment.log_artifact(self.run_id, filename, artifact_path="figure_" + name) - Path(filename).unlink() # delete temporary file + filename.unlink() # delete temporary file if close: plt.close(figure) diff --git a/tests/loggers/test_mlflow.py b/tests/loggers/test_mlflow.py index c21e2976ef47b..40b7c741b702e 100644 --- a/tests/loggers/test_mlflow.py +++ b/tests/loggers/test_mlflow.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import os +from pathlib import Path from unittest import mock from unittest.mock import MagicMock @@ -218,4 +219,4 @@ def test_mlflow_log_figure(client, mlflow, step_idx, figure_format, tmpdir): fname_expect = logger.save_dir + f'/dummy_step_{step_idx}{figure_format}' artifact_expect = 'figure_dummy' - mock_log.assert_called_once_with(logger.run_id, fname_expect, artifact_path=artifact_expect) + mock_log.assert_called_once_with(logger.run_id, Path(fname_expect), artifact_path=artifact_expect) From a77028a550ffbea3ceece92163d1f2d5b36ef17f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lucas-Rapahel=20Mu=CC=88ller?= Date: Fri, 12 Mar 2021 11:54:01 +0100 Subject: [PATCH 07/25] Make matplotlib extra --- requirements.txt | 1 - requirements/extra.txt | 1 + 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index e0871772f5086..bdfd6601ba4c2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,6 @@ # the default package dependencies numpy>=1.16.6 -matplotlib>3.1 torch>=1.4 future>=0.17.1 # required for builtins in setup.py # pyyaml>=3.13 diff --git a/requirements/extra.txt b/requirements/extra.txt index 76b84ab7c0151..0e7dffbcb39b0 100644 --- a/requirements/extra.txt +++ b/requirements/extra.txt @@ -1,5 +1,6 @@ # extended list of package dependencies to reach full functionality +matplotlib>3.1 horovod>=0.21.2 # no need to install with [pytorch] as pytorch is already installed omegaconf>=2.0.1 torchtext>=0.5, <0.7 # TODO: temporary fix fix for compatibility From e11fc0c6ffe8c61a814462a65ae8b8e45ac7fe0f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lucas-Rapahel=20Mu=CC=88ller?= Date: Fri, 12 Mar 2021 12:24:18 +0100 Subject: [PATCH 08/25] Make matplotlib local import --- pytorch_lightning/loggers/comet.py | 6 +++--- pytorch_lightning/loggers/mlflow.py | 2 +- pytorch_lightning/loggers/neptune.py | 7 +++---- pytorch_lightning/loggers/tensorboard.py | 3 +-- pytorch_lightning/loggers/wandb.py | 6 +++--- 5 files changed, 11 insertions(+), 13 deletions(-) diff --git a/pytorch_lightning/loggers/comet.py b/pytorch_lightning/loggers/comet.py index cc6e84998e75b..1290d100b2241 100644 --- a/pytorch_lightning/loggers/comet.py +++ b/pytorch_lightning/loggers/comet.py @@ -21,7 +21,6 @@ from argparse import Namespace from typing import Any, Dict, Optional, Union -import matplotlib.pyplot as plt import torch from torch import is_tensor @@ -254,9 +253,10 @@ def log_metrics(self, metrics: Dict[str, Union[torch.Tensor, float]], step: Opti self.experiment.log_metrics(metrics_without_epoch, step=step, epoch=epoch) @rank_zero_only - def log_figure(self, name: str, figure: plt.figure, step: Optional[int] = None, close: bool = True) -> None: - self.experiment.log_figure(figure_name=name, figure=figure, step=step) + def log_figure(self, name: str, figure, step: Optional[int] = None, close: bool = True) -> None: + import matplotlib.pyplot as plt + self.experiment.log_figure(figure_name=name, figure=figure, step=step) if close: plt.close(figure) diff --git a/pytorch_lightning/loggers/mlflow.py b/pytorch_lightning/loggers/mlflow.py index 4ec522c4f5c89..4903febc1f633 100644 --- a/pytorch_lightning/loggers/mlflow.py +++ b/pytorch_lightning/loggers/mlflow.py @@ -190,13 +190,13 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> @rank_zero_only def log_figure(self, name: str, figure, step: Optional[int] = None, close: bool = True) -> None: import matplotlib.pyplot as plt + # save temporary file until logged filename = Path(self.save_dir) / (name + 
f"_step_{step}" + self._figure_file_extension) figure.savefig(filename) self.experiment.log_artifact(self.run_id, filename, artifact_path="figure_" + name) filename.unlink() # delete temporary file - if close: plt.close(figure) diff --git a/pytorch_lightning/loggers/neptune.py b/pytorch_lightning/loggers/neptune.py index 315b3b4f4d3ce..2169dd175e621 100644 --- a/pytorch_lightning/loggers/neptune.py +++ b/pytorch_lightning/loggers/neptune.py @@ -19,7 +19,6 @@ from argparse import Namespace from typing import Any, Dict, Iterable, Optional, Union -import matplotlib.pyplot as plt import torch from torch import is_tensor @@ -265,11 +264,11 @@ def log_metrics(self, metrics: Dict[str, Union[torch.Tensor, float]], step: Opti self.log_metric(key, val) @rank_zero_only - def log_figure(self, name: str, figure: plt.figure, step: Optional[int] = None, close: bool = True) -> None: - description = f"step_{step}" if step is not None else None + def log_figure(self, name: str, figure, step: Optional[int] = None, close: bool = True) -> None: + import matplotlib.pyplot as plt + description = f"step_{step}" if step is not None else None self.experiment.log_image(name, figure, description=description) - if close: plt.close(figure) diff --git a/pytorch_lightning/loggers/tensorboard.py b/pytorch_lightning/loggers/tensorboard.py index aac944c5a8e02..c5a0ba71a103e 100644 --- a/pytorch_lightning/loggers/tensorboard.py +++ b/pytorch_lightning/loggers/tensorboard.py @@ -21,7 +21,6 @@ from argparse import Namespace from typing import Any, Dict, Optional, Union -import matplotlib.pyplot as plt import torch from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard.summary import hparams @@ -211,7 +210,7 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> type(ex)(ex.message + m) @rank_zero_only - def log_figure(self, name: str, figure: plt.figure, step: Optional[int] = None, close: bool = True) -> None: + def log_figure(self, name: str, figure, step: Optional[int] = None, close: bool = True) -> None: self.experiment.add_figure(tag=name, figure=figure, global_step=step, close=close) @rank_zero_only diff --git a/pytorch_lightning/loggers/wandb.py b/pytorch_lightning/loggers/wandb.py index a22999241151d..25273bc92c736 100644 --- a/pytorch_lightning/loggers/wandb.py +++ b/pytorch_lightning/loggers/wandb.py @@ -19,7 +19,6 @@ from argparse import Namespace from typing import Any, Dict, Optional, Union -import matplotlib.pyplot as plt import torch.nn as nn from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment @@ -201,9 +200,10 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> self.experiment.log(metrics) @rank_zero_only - def log_figure(self, name: str, figure: plt.figure, step: Optional[int] = None, close: bool = True) -> None: - self.experiment.log({name: wandb.Image(figure)}, step=step) + def log_figure(self, name: str, figure, step: Optional[int] = None, close: bool = True) -> None: + import matplotlib.pyplot as plt + self.experiment.log({name: wandb.Image(figure)}, step=step) if close: plt.close(figure) From 32f67c6dce617067938c4d9c64eef05b344c7ee6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lucas-Rapahel=20Mu=CC=88ller?= Date: Fri, 26 Mar 2021 12:13:30 +0100 Subject: [PATCH 09/25] Merge master --- .github/CODEOWNERS | 6 +- .github/workflows/ci_dockers.yml | 18 +- .github/workflows/ci_test-base.yml | 5 +- .github/workflows/ci_test-conda.yml | 2 +- .github/workflows/ci_test-full.yml | 6 +- 
.github/workflows/docs-checks.yml | 2 +- .github/workflows/events-nightly.yml | 23 + .github/workflows/release-docker.yml | 26 +- .gitignore | 1 + .pre-commit-config.yaml | 5 - CHANGELOG.md | 113 ++- Makefile | 2 +- azure-pipelines.yml | 34 +- dockers/nvidia/Dockerfile | 19 +- dockers/release/Dockerfile | 1 - docs/source/advanced/multi_gpu.rst | 56 +- docs/source/advanced/multiple_loaders.rst | 50 +- docs/source/benchmarking/performance.rst | 16 + docs/source/common/lightning_module.rst | 101 +- docs/source/common/trainer.rst | 5 +- docs/source/conf.py | 25 +- docs/source/extensions/callbacks.rst | 12 + docs/source/extensions/datamodules.rst | 13 +- docs/source/extensions/logging.rst | 2 +- docs/source/extensions/metrics.rst | 884 +----------------- docs/source/starter/introduction_guide.rst | 6 +- docs/source/starter/new-project.rst | 2 +- pl_examples/__init__.py | 4 +- pl_examples/basic_examples/autoencoder.py | 14 +- .../basic_examples/conv_sequential_example.py | 2 +- .../basic_examples/profiler_example.py | 102 ++ pl_examples/basic_examples/submit_ddp2_job.sh | 2 +- pl_examples/basic_examples/submit_ddp_job.sh | 2 +- .../computer_vision_fine_tuning.py | 5 +- pytorch_lightning/__init__.py | 81 +- pytorch_lightning/accelerators/accelerator.py | 86 +- pytorch_lightning/accelerators/cpu.py | 13 + pytorch_lightning/accelerators/gpu.py | 15 +- pytorch_lightning/accelerators/tpu.py | 39 +- pytorch_lightning/callbacks/base.py | 12 +- pytorch_lightning/callbacks/early_stopping.py | 2 +- .../gradient_accumulation_scheduler.py | 2 +- .../callbacks/model_checkpoint.py | 40 +- pytorch_lightning/callbacks/progress.py | 9 +- pytorch_lightning/core/datamodule.py | 67 +- pytorch_lightning/core/hooks.py | 106 ++- pytorch_lightning/core/lightning.py | 23 +- pytorch_lightning/core/memory.py | 12 +- pytorch_lightning/core/step_result.py | 2 +- pytorch_lightning/distributed/dist.py | 51 +- pytorch_lightning/info.py | 36 + pytorch_lightning/loggers/mlflow.py | 14 +- pytorch_lightning/metrics/__init__.py | 7 + .../metrics/classification/accuracy.py | 130 +-- .../metrics/classification/auc.py | 67 +- .../metrics/classification/auroc.py | 158 +--- .../classification/average_precision.py | 103 +- .../classification/confusion_matrix.py | 90 +- .../metrics/classification/f_beta.py | 180 +--- .../classification/hamming_distance.py | 88 +- .../metrics/classification/helpers.py | 539 ----------- .../metrics/classification/iou.py | 83 +- .../classification/precision_recall.py | 283 +----- .../classification/precision_recall_curve.py | 131 +-- .../metrics/classification/roc.py | 126 +-- .../metrics/classification/stat_scores.py | 245 +---- pytorch_lightning/metrics/compositional.py | 97 +- .../metrics/functional/accuracy.py | 101 +- pytorch_lightning/metrics/functional/auc.py | 59 +- pytorch_lightning/metrics/functional/auroc.py | 168 +--- .../metrics/functional/average_precision.py | 78 +- .../metrics/functional/classification.py | 331 +------ .../metrics/functional/confusion_matrix.py | 74 +- .../metrics/functional/explained_variance.py | 69 +- .../metrics/functional/f_beta.py | 120 +-- .../metrics/functional/hamming_distance.py | 61 +- .../metrics/functional/image_gradients.py | 58 +- pytorch_lightning/metrics/functional/iou.py | 87 +- .../metrics/functional/mean_absolute_error.py | 35 +- .../metrics/functional/mean_relative_error.py | 38 +- .../metrics/functional/mean_squared_error.py | 35 +- .../functional/mean_squared_log_error.py | 35 +- pytorch_lightning/metrics/functional/nlp.py | 89 +- 
.../metrics/functional/precision_recall.py | 453 +-------- .../functional/precision_recall_curve.py | 196 +--- pytorch_lightning/metrics/functional/psnr.py | 84 +- .../metrics/functional/r2score.py | 112 +-- pytorch_lightning/metrics/functional/roc.py | 132 +-- .../metrics/functional/self_supervised.py | 41 +- pytorch_lightning/metrics/functional/ssim.py | 130 +-- .../metrics/functional/stat_scores.py | 260 +----- pytorch_lightning/metrics/metric.py | 603 +----------- .../metrics/regression/explained_variance.py | 106 +-- .../metrics/regression/mean_absolute_error.py | 64 +- .../metrics/regression/mean_squared_error.py | 65 +- .../regression/mean_squared_log_error.py | 67 +- pytorch_lightning/metrics/regression/psnr.py | 123 +-- .../metrics/regression/r2score.py | 122 +-- pytorch_lightning/metrics/regression/ssim.py | 78 +- pytorch_lightning/metrics/utils.py | 287 +----- pytorch_lightning/overrides/base.py | 2 +- .../overrides/torch_distributed.py | 94 ++ pytorch_lightning/plugins/__init__.py | 2 + .../plugins/precision/__init__.py | 1 + .../plugins/precision/apex_amp.py | 1 - pytorch_lightning/plugins/precision/double.py | 95 ++ .../plugins/precision/native_amp.py | 18 + .../plugins/precision/precision_plugin.py | 1 - .../plugins/training_type/ddp.py | 71 +- .../plugins/training_type/ddp_spawn.py | 8 +- .../plugins/training_type/deepspeed.py | 10 - pytorch_lightning/plugins/training_type/dp.py | 6 +- .../plugins/training_type/horovod.py | 17 +- .../plugins/training_type/parallel.py | 38 +- .../plugins/training_type/rpc.py | 2 +- .../plugins/training_type/rpc_sequential.py | 2 +- .../plugins/training_type/single_device.py | 12 +- .../plugins/training_type/single_tpu.py | 7 +- .../plugins/training_type/tpu_spawn.py | 31 +- .../training_type/training_type_plugin.py | 45 +- .../plugins/training_type/utils.py | 13 + pytorch_lightning/profiler/__init__.py | 25 +- pytorch_lightning/profiler/profilers.py | 277 ++++-- pytorch_lightning/profiler/pytorch.py | 601 ++++++++---- pytorch_lightning/setup_tools.py | 6 +- pytorch_lightning/trainer/callback_hook.py | 60 +- .../trainer/configuration_validator.py | 9 +- .../connectors/accelerator_connector.py | 15 +- .../trainer/connectors/data_connector.py | 4 + .../trainer/connectors/env_vars_connector.py | 14 +- .../logger_connector/epoch_result_store.py | 9 +- .../logger_connector/logger_connector.py | 11 +- .../logger_connector/metrics_holder.py | 57 +- .../trainer/connectors/profiler_connector.py | 8 +- pytorch_lightning/trainer/data_loading.py | 23 +- pytorch_lightning/trainer/deprecated_api.py | 56 +- pytorch_lightning/trainer/evaluation_loop.py | 47 +- pytorch_lightning/trainer/predict_loop.py | 25 +- pytorch_lightning/trainer/properties.py | 10 + pytorch_lightning/trainer/trainer.py | 92 +- pytorch_lightning/trainer/training_loop.py | 39 +- pytorch_lightning/tuner/tuning.py | 13 +- pytorch_lightning/utilities/__init__.py | 1 + pytorch_lightning/utilities/argparse.py | 20 +- pytorch_lightning/utilities/argparse_utils.py | 6 +- pytorch_lightning/utilities/distributed.py | 6 +- pytorch_lightning/utilities/imports.py | 24 +- pytorch_lightning/utilities/model_utils.py | 7 +- .../utilities/signature_utils.py | 22 + pytorch_lightning/utilities/warning_utils.py | 6 +- .../utilities/xla_device_utils.py | 7 +- requirements.txt | 2 + requirements/adjust_versions.py | 1 + requirements/extra.txt | 3 +- requirements/test.txt | 7 +- setup.cfg | 17 +- setup.py | 46 +- tests/__init__.py | 4 +- .../test_accelerator_connector.py | 30 +- 
tests/accelerators/test_common.py | 12 +- tests/accelerators/test_cpu.py | 32 + tests/accelerators/test_ddp.py | 30 +- tests/base/model_template.py | 3 +- tests/callbacks/test_callback_hook_outputs.py | 63 ++ tests/callbacks/test_callbacks.py | 18 +- .../test_checkpoint_callback_frequency.py | 40 + .../checkpointing/test_legacy_checkpoints.py | 5 +- tests/checkpointing/test_model_checkpoint.py | 5 +- tests/checkpointing/test_torch_saving.py | 1 + tests/core/test_datamodules.py | 61 +- tests/core/test_hooks.py | 56 ++ tests/core/test_memory.py | 24 +- tests/core/test_metric_result_integration.py | 2 +- tests/deprecated_api/__init__.py | 18 + tests/deprecated_api/test_remove_1-4.py | 18 - tests/deprecated_api/test_remove_1-5.py | 107 +++ tests/helpers/advanced_models.py | 4 +- tests/helpers/datasets.py | 15 +- tests/helpers/runif.py | 7 + tests/helpers/test_datasets.py | 13 +- tests/loggers/test_mlflow.py | 28 + tests/metrics/classification/__init__.py | 0 tests/metrics/classification/inputs.py | 66 -- tests/metrics/classification/test_accuracy.py | 175 ---- tests/metrics/classification/test_auc.py | 64 -- tests/metrics/classification/test_auroc.py | 142 --- .../classification/test_average_precision.py | 97 -- .../classification/test_confusion_matrix.py | 128 --- tests/metrics/classification/test_f_beta.py | 153 --- .../classification/test_hamming_distance.py | 80 -- tests/metrics/classification/test_inputs.py | 311 ------ tests/metrics/classification/test_iou.py | 216 ----- .../classification/test_precision_recall.py | 347 ------- .../test_precision_recall_curve.py | 97 -- tests/metrics/classification/test_roc.py | 99 -- .../classification/test_stat_scores.py | 255 ----- tests/metrics/functional/__init__.py | 0 .../metrics/functional/test_classification.py | 89 -- .../functional/test_image_gradients.py | 109 --- tests/metrics/functional/test_nlp.py | 68 -- tests/metrics/functional/test_reduction.py | 28 - .../functional/test_self_supervised.py | 32 - tests/metrics/regression/__init__.py | 0 .../regression/test_explained_variance.py | 77 -- tests/metrics/regression/test_mean_error.py | 87 -- tests/metrics/regression/test_psnr.py | 133 --- tests/metrics/regression/test_r2score.py | 114 --- tests/metrics/regression/test_ssim.py | 104 --- tests/metrics/test_composition.py | 510 ---------- tests/metrics/test_ddp.py | 71 -- tests/metrics/test_metric.py | 395 -------- tests/metrics/test_metric_lightning.py | 8 +- tests/metrics/test_remove_1-5_metrics.py | 348 +++++++ tests/metrics/utils.py | 3 +- .../data/horovod/train_default_model.py | 9 +- tests/models/test_amp.py | 31 +- tests/models/test_hooks.py | 261 ++++-- tests/models/test_horovod.py | 118 ++- tests/models/test_tpu.py | 41 + tests/overrides/test_data_parallel.py | 2 +- tests/plugins/test_custom_plugin.py | 41 + tests/plugins/test_deepspeed_plugin.py | 12 +- tests/plugins/test_double_plugin.py | 129 +++ tests/plugins/test_sharded_plugin.py | 10 +- tests/special_tests.sh | 13 +- tests/test_profiler.py | 360 ++++++- tests/trainer/flags/test_env_vars.py | 31 +- .../logging_/test_eval_loop_logging_1_0.py | 16 +- .../trainer/logging_/test_logger_connector.py | 32 +- tests/trainer/optimization/test_optimizers.py | 2 + tests/trainer/properties/test_get_model.py | 23 - tests/trainer/test_config_validator.py | 50 +- tests/trainer/test_dataloaders.py | 69 ++ tests/trainer/test_evaluation_loop.py | 42 + tests/trainer/test_lr_finder.py | 24 + tests/trainer/test_trainer.py | 177 +--- tests/tuner/test_scale_batch_size.py | 65 ++ 
tests/utilities/test_all_gather_grad.py | 23 + ...est_argparse_utils.py => test_argparse.py} | 47 +- 239 files changed, 4902 insertions(+), 12644 deletions(-) create mode 100644 pl_examples/basic_examples/profiler_example.py create mode 100644 pytorch_lightning/info.py delete mode 100644 pytorch_lightning/metrics/classification/helpers.py create mode 100644 pytorch_lightning/overrides/torch_distributed.py create mode 100644 pytorch_lightning/plugins/precision/double.py create mode 100644 pytorch_lightning/utilities/signature_utils.py create mode 100644 tests/core/test_hooks.py delete mode 100644 tests/metrics/classification/__init__.py delete mode 100644 tests/metrics/classification/inputs.py delete mode 100644 tests/metrics/classification/test_accuracy.py delete mode 100644 tests/metrics/classification/test_auc.py delete mode 100644 tests/metrics/classification/test_auroc.py delete mode 100644 tests/metrics/classification/test_average_precision.py delete mode 100644 tests/metrics/classification/test_confusion_matrix.py delete mode 100644 tests/metrics/classification/test_f_beta.py delete mode 100644 tests/metrics/classification/test_hamming_distance.py delete mode 100644 tests/metrics/classification/test_inputs.py delete mode 100644 tests/metrics/classification/test_iou.py delete mode 100644 tests/metrics/classification/test_precision_recall.py delete mode 100644 tests/metrics/classification/test_precision_recall_curve.py delete mode 100644 tests/metrics/classification/test_roc.py delete mode 100644 tests/metrics/classification/test_stat_scores.py delete mode 100644 tests/metrics/functional/__init__.py delete mode 100644 tests/metrics/functional/test_classification.py delete mode 100644 tests/metrics/functional/test_image_gradients.py delete mode 100644 tests/metrics/functional/test_nlp.py delete mode 100644 tests/metrics/functional/test_reduction.py delete mode 100644 tests/metrics/functional/test_self_supervised.py delete mode 100644 tests/metrics/regression/__init__.py delete mode 100644 tests/metrics/regression/test_explained_variance.py delete mode 100644 tests/metrics/regression/test_mean_error.py delete mode 100644 tests/metrics/regression/test_psnr.py delete mode 100644 tests/metrics/regression/test_r2score.py delete mode 100644 tests/metrics/regression/test_ssim.py delete mode 100644 tests/metrics/test_composition.py delete mode 100644 tests/metrics/test_ddp.py delete mode 100644 tests/metrics/test_metric.py create mode 100644 tests/metrics/test_remove_1-5_metrics.py create mode 100644 tests/plugins/test_custom_plugin.py create mode 100644 tests/plugins/test_double_plugin.py create mode 100644 tests/trainer/test_evaluation_loop.py create mode 100644 tests/tuner/test_scale_batch_size.py rename tests/utilities/{test_argparse_utils.py => test_argparse.py} (80%) diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 6afdcc4cbe29f..4ac6944c7a31a 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -34,9 +34,9 @@ /pytorch_lightning/utilities @borda @tchaton @SeanNaren @carmocca # Metrics -/pytorch_lightning/metrics/ @teddykoker @ananyahjha93 @justusschock -/tests/metrics/ @teddykoker @ananyahjha93 @justusschock -/docs/source/metrics.rst @teddykoker @ananyahjha93 @justusschock +/pytorch_lightning/metrics/ @SkafteNicki @ananyahjha93 @justusschock +/tests/metrics/ @SkafteNicki @ananyahjha93 @justusschock +/docs/source/metrics.rst @SkafteNicki @ananyahjha93 @justusschock # API /pytorch_lightning/callbacks/base.py @williamfalcon diff --git a/.github/workflows/ci_dockers.yml 
b/.github/workflows/ci_dockers.yml index 9f77fb76aa593..897e16a12d44f 100644 --- a/.github/workflows/ci_dockers.yml +++ b/.github/workflows/ci_dockers.yml @@ -29,9 +29,6 @@ jobs: - name: Checkout uses: actions/checkout@v2 - # https://github.com/docker/setup-buildx-action - # Set up Docker Buildx - to use cache-from and cache-to argument of buildx command - - uses: docker/setup-buildx-action@v1 - name: Build PL Docker # publish master/release uses: docker/build-push-action@v2 @@ -54,9 +51,6 @@ jobs: - name: Checkout uses: actions/checkout@v2 - # https://github.com/docker/setup-buildx-action - # Set up Docker Buildx - to use cache-from and cache-to argument of buildx command - - uses: docker/setup-buildx-action@v1 - name: Build XLA Docker # publish master/release uses: docker/build-push-action@v2 @@ -93,9 +87,6 @@ jobs: echo "::set-output name=CUDA::$cuda" id: extend - # https://github.com/docker/setup-buildx-action - # Set up Docker Buildx - to use cache-from and cache-to argument of buildx command - - uses: docker/setup-buildx-action@v1 - name: Build CUDA Docker # publish master/release uses: docker/build-push-action@v2 @@ -130,9 +121,6 @@ jobs: echo "::set-output name=CUDA::$cuda" id: extend - # https://github.com/docker/setup-buildx-action - # Set up Docker Buildx - to use cache-from and cache-to argument of buildx command - - uses: docker/setup-buildx-action@v1 - name: Build CUDA Docker # publish master/release uses: docker/build-push-action@v2 @@ -150,10 +138,8 @@ jobs: steps: - name: Checkout uses: actions/checkout@v2 - # https://github.com/docker/setup-buildx-action - # Set up Docker Buildx - to use cache-from and cache-to argument of buildx command - - uses: docker/setup-buildx-action@v1 - - name: Build CUDA Docker + + - name: Build NVIDIA Docker uses: docker/build-push-action@v2 with: file: dockers/nvidia/Dockerfile diff --git a/.github/workflows/ci_test-base.yml b/.github/workflows/ci_test-base.yml index ed8a2e30949b7..77363992718af 100644 --- a/.github/workflows/ci_test-base.yml +++ b/.github/workflows/ci_test-base.yml @@ -51,9 +51,8 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade --user pip - pip install --requirement ./requirements.txt --quiet --find-links https://download.pytorch.org/whl/cpu/torch_stable.html --upgrade - pip install --requirement ./requirements/test.txt --quiet --upgrade-strategy only-if-needed - # pip install tox coverage + pip install --requirement ./requirements.txt --find-links https://download.pytorch.org/whl/cpu/torch_stable.html --upgrade + pip install "pytest>6.0" "pytest-cov>2.10" --upgrade-strategy only-if-needed python --version pip --version pip list diff --git a/.github/workflows/ci_test-conda.yml b/.github/workflows/ci_test-conda.yml index 812d06f310812..da853bf623d1b 100644 --- a/.github/workflows/ci_test-conda.yml +++ b/.github/workflows/ci_test-conda.yml @@ -44,7 +44,7 @@ jobs: - name: Tests run: | # NOTE: run coverage on tests does not propagare faler status for Win, https://github.com/nedbat/coveragepy/issues/1003 - python -m pytest pytorch_lightning tests --cov=pytorch_lightning -v --durations=50 --junitxml=junit/test-results-${{ runner.os }}-torch${{ matrix.pytorch-version }}.xml + coverage run --source pytorch_lightning -m pytest pytorch_lightning tests -v --durations=50 --junitxml=junit/test-results-${{ runner.os }}-torch${{ matrix.pytorch-version }}.xml shell: bash -l {0} - name: Upload pytest results diff --git a/.github/workflows/ci_test-full.yml b/.github/workflows/ci_test-full.yml index 
3d3f7d11570a4..5a3e23a37fd0b 100644 --- a/.github/workflows/ci_test-full.yml +++ b/.github/workflows/ci_test-full.yml @@ -17,10 +17,6 @@ jobs: os: [ubuntu-18.04, windows-2019, macOS-10.15] python-version: [3.6, 3.7, 3.8] requires: ['minimal', 'latest'] - exclude: - # # todo: segmentation fault for minimal and hanging for latest - - python-version: 3.8 - os: ubuntu-18.04 # Timeout: https://stackoverflow.com/a/59076067/4521646 timeout-minutes: 35 # TODO: the macOS is taking too long, probably caching did not work... @@ -138,7 +134,7 @@ jobs: - name: Tests run: | # NOTE: do not include coverage report here, see: https://github.com/nedbat/coveragepy/issues/1003 - python -m pytest pytorch_lightning tests --cov=pytorch_lightning -v --durations=50 --junitxml=junit/test-results-${{ runner.os }}-py${{ matrix.python-version }}-${{ matrix.requires }}.xml + coverage run --source pytorch_lightning -m pytest pytorch_lightning tests -v --durations=50 --junitxml=junit/test-results-${{ runner.os }}-py${{ matrix.python-version }}-${{ matrix.requires }}.xml - name: Examples run: | diff --git a/.github/workflows/docs-checks.yml b/.github/workflows/docs-checks.yml index 5ee4f23b4b3cc..4488c598c8ac7 100644 --- a/.github/workflows/docs-checks.yml +++ b/.github/workflows/docs-checks.yml @@ -98,7 +98,7 @@ jobs: # First run the same pipeline as Read-The-Docs cd docs make clean - make html --debug --jobs $(nproc) SPHINXOPTS="-W" + make html --debug --jobs $(nproc) SPHINXOPTS="-W --keep-going" - name: Upload built docs uses: actions/upload-artifact@v2 diff --git a/.github/workflows/events-nightly.yml b/.github/workflows/events-nightly.yml index 24d8ce4002e5d..5ad4396a006f7 100644 --- a/.github/workflows/events-nightly.yml +++ b/.github/workflows/events-nightly.yml @@ -126,3 +126,26 @@ jobs: push: true tags: pytorchlightning/pytorch_lightning:base-conda-py${{ matrix.python_version }}-torch${{ matrix.pytorch_version }} timeout-minutes: 55 + +# docker-nvidia: +# runs-on: ubuntu-20.04 +# steps: +# - name: Checkout +# uses: actions/checkout@v2 +# +# # https://github.com/docker/setup-buildx-action +# # Set up Docker Buildx - to use cache-from and cache-to argument of buildx command +# - uses: docker/setup-buildx-action@v1 +# - name: Login to DockerHub +# uses: docker/login-action@v1 +# with: +# username: ${{ secrets.DOCKER_USERNAME }} +# password: ${{ secrets.DOCKER_PASSWORD }} +# +# - name: Publish NVIDIA to Docker Hub +# uses: docker/build-push-action@v2 +# with: +# file: dockers/nvidia/Dockerfile +# push: true +# tags: nvcr.io/pytorchlightning/pytorch_lightning:nvidia +# timeout-minutes: 55 diff --git a/.github/workflows/release-docker.yml b/.github/workflows/release-docker.yml index f285794cbc33b..36ecbe229ac7c 100644 --- a/.github/workflows/release-docker.yml +++ b/.github/workflows/release-docker.yml @@ -8,7 +8,7 @@ on: types: [created] jobs: - build-PL: + cuda-PL: runs-on: ubuntu-20.04 strategy: fail-fast: false @@ -36,3 +36,27 @@ jobs: build_args: PYTHON_VERSION=${{ matrix.python_version }},PYTORCH_VERSION=${{ matrix.pytorch_version }},LIGHTNING_VERSION=${{ steps.get_version.outputs.RELEASE_VERSION }} tags: "${{ steps.get_version.outputs.RELEASE_VERSION }}-py${{ matrix.python_version }}-torch${{ matrix.pytorch_version }},latest-py${{ matrix.python_version }}-torch${{ matrix.pytorch_version }}" timeout-minutes: 55 + +# nvidia-PL: +# runs-on: ubuntu-20.04 +# steps: +# - name: Checkout +# uses: actions/checkout@v2 +# +# - name: Get release version +# if: startsWith(github.ref, 'refs/tags/') || github.event_name == 
'release' +# id: get_version +# run: echo "::set-output name=RELEASE_VERSION::$(echo ${GITHUB_REF##*/})" +# +# - name: Publish Releases to Docker +# # only on releases +# uses: docker/build-push-action@v1.1.0 +# if: startsWith(github.ref, 'refs/tags/') || github.event_name == 'release' +# with: +# repository: nvcr.io/pytorchlightning/pytorch_lightning +# username: ${{ secrets.DOCKER_USERNAME }} +# password: ${{ secrets.DOCKER_PASSWORD }} +# dockerfile: dockers/nvidia/Dockerfile +# build_args: LIGHTNING_VERSION=${{ steps.get_version.outputs.RELEASE_VERSION }} +# tags: "${{ steps.get_version.outputs.RELEASE_VERSION }}-nvidia" +# timeout-minutes: 55 diff --git a/.gitignore b/.gitignore index cd0ba22453512..99939ff7fce0c 100644 --- a/.gitignore +++ b/.gitignore @@ -157,3 +157,4 @@ tags data MNIST runs +*trace* diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 21c52539a890d..45eca43de93ac 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -33,8 +33,3 @@ repos: hooks: - id: yapf args: [--parallel, --in-place] - - - repo: https://github.com/pre-commit/mirrors-mypy - rev: v0.790 - hooks: - - id: mypy diff --git a/CHANGELOG.md b/CHANGELOG.md index 4139a87d9f27b..6c50e3c54e305 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,8 +14,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added a way to print to terminal without breaking up the progress bar ([#5470](https://github.com/PyTorchLightning/pytorch-lightning/pull/5470)) + - Added support to checkpoint after training steps in `ModelCheckpoint` callback ([#6146](https://github.com/PyTorchLightning/pytorch-lightning/pull/6146)) + - Added `checkpoint` parameter to callback's `on_save_checkpoint` hook ([#6072](https://github.com/PyTorchLightning/pytorch-lightning/pull/6072)) @@ -31,15 +33,45 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Added `LightningEnvironment` for Lightning-specific DDP ([#5915](https://github.com/PyTorchLightning/pytorch-lightning/pull/5915)) +- Added `teardown()` hook to LightningDataModule ([#4673](https://github.com/PyTorchLightning/pytorch-lightning/pull/4673)) + + - Added `auto_insert_metric_name` parameter to `ModelCheckpoint` ([#6277](https://github.com/PyTorchLightning/pytorch-lightning/pull/6277)) - Added arg to `self.log` that enables users to give custom names when dealing with multiple dataloaders ([#6274](https://github.com/PyTorchLightning/pytorch-lightning/pull/6274)) +- Added `teardown` method to `BaseProfiler` to enable subclasses defining post-profiling steps outside of `__del__` ([#6370](https://github.com/PyTorchLightning/pytorch-lightning/pull/6370)) + + +- Added `setup` method to `BaseProfiler` to enable subclasses defining pre-profiling steps for every process ([#6633](https://github.com/PyTorchLightning/pytorch-lightning/pull/6633)) + + - Added no return warning to predict ([#6139](https://github.com/PyTorchLightning/pytorch-lightning/pull/6139)) +- Added `Trainer.predict` config validation ([#6543](https://github.com/PyTorchLightning/pytorch-lightning/pull/6543)) + + +- Added `AbstractProfiler` interface ([#6621](https://github.com/PyTorchLightning/pytorch-lightning/pull/6621)) + + +- Added support for including module names for forward in the autograd trace of `PyTorchProfiler` ([#6349](https://github.com/PyTorchLightning/pytorch-lightning/pull/6349)) + + +- Added support for the PyTorch 1.8.1 autograd profiler ([#6618](https://github.com/PyTorchLightning/pytorch-lightning/pull/6618)) + + +- Added `outputs` parameter to callback's `on_validation_epoch_end` & `on_test_epoch_end` hooks ([#6120](https://github.com/PyTorchLightning/pytorch-lightning/pull/6120)) + + +- Added support for `precision=64`, enabling training with double precision ([#6595](https://github.com/PyTorchLightning/pytorch-lightning/pull/6595)) + + +- Added `artifact_location` argument to `MLFlowLogger` which will be passed to the `MlflowClient.create_experiment` call ([#6677](https://github.com/PyTorchLightning/pytorch-lightning/pull/6677)) + + ### Changed - Renamed `pytorch_lightning.callbacks.swa` to `pytorch_lightning.callbacks.stochastic_weight_avg` ([#6259](https://github.com/PyTorchLightning/pytorch-lightning/pull/6259)) @@ -54,7 +86,13 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Changed `setup()` and `teardown()` stage argument to take any of `{fit,validate,test,predict}` ([#6386](https://github.com/PyTorchLightning/pytorch-lightning/pull/6386)) -- Changed the default of `find_unused_parameters` back to `True` in DDP and DDP Spawn ([#6438](https://github.com/PyTorchLightning/pytorch-lightning/pull/6438)) +- Changed profilers to save separate report files per state and rank ([#6621](https://github.com/PyTorchLightning/pytorch-lightning/pull/6621)) + + +- Changed `PyTorchProfiler` to use `torch.autograd.profiler.record_function` to record functions ([#6349](https://github.com/PyTorchLightning/pytorch-lightning/pull/6349)) + + +- Changed the behavior of `on_epoch_start` to run at the beginning of validation & test epoch ([#6498](https://github.com/PyTorchLightning/pytorch-lightning/pull/6498)) ### Deprecated @@ -65,6 +103,27 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Deprecated `trainer.running_sanity_check` in favor of `trainer.sanity_checking` ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945)) +- Deprecated `Profiler(output_filename)` in favor of `dirpath` and `filename` ([#6621](https://github.com/PyTorchLightning/pytorch-lightning/pull/6621)) + + +- Deprecated `PytorchProfiler(profiled_functions)` in favor of `record_functions` ([#6349](https://github.com/PyTorchLightning/pytorch-lightning/pull/6349)) + + +- Deprecated metrics in favor of `torchmetrics` ([#6505](https://github.com/PyTorchLightning/pytorch-lightning/pull/6505), + [#6530](https://github.com/PyTorchLightning/pytorch-lightning/pull/6530), + [#6540](https://github.com/PyTorchLightning/pytorch-lightning/pull/6540), + [#6547](https://github.com/PyTorchLightning/pytorch-lightning/pull/6547), + [#6515](https://github.com/PyTorchLightning/pytorch-lightning/pull/6515), + [#6572](https://github.com/PyTorchLightning/pytorch-lightning/pull/6572), + [#6573](https://github.com/PyTorchLightning/pytorch-lightning/pull/6573), + [#6584](https://github.com/PyTorchLightning/pytorch-lightning/pull/6584), + [#6636](https://github.com/PyTorchLightning/pytorch-lightning/pull/6636), + [#6637](https://github.com/PyTorchLightning/pytorch-lightning/pull/6637), + [#6649](https://github.com/PyTorchLightning/pytorch-lightning/pull/6649), + [#6659](https://github.com/PyTorchLightning/pytorch-lightning/pull/6659), +) + + ### Removed - Removed support for passing a bool value to `profiler` argument of Trainer ([#6164](https://github.com/PyTorchLightning/pytorch-lightning/pull/6164)) @@ -98,6 +157,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed +- Added Autocast in validation, test and predict modes for Native AMP ([#6565](https://github.com/PyTorchLightning/pytorch-lightning/pull/6565)) + + - Made the `Plugin.reduce` method more consistent across all Plugins to reflect a mean-reduction by default ([#6011](https://github.com/PyTorchLightning/pytorch-lightning/pull/6011)) @@ -113,31 +175,57 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Fixed `ModelCheckpoint(save_top_k=0, save_last=True)` not saving the `last` checkpoint ([#6136](https://github.com/PyTorchLightning/pytorch-lightning/pull/6136)) -- Expose DeepSpeed loss parameters to allow users to fix loss instability ([#6115](https://github.com/PyTorchLightning/pytorch-lightning/pull/6115)) +- Fixed duplicate logs appearing in console when using the python logging module ([#5509](https://github.com/PyTorchLightning/pytorch-lightning/pull/5509), [#6275](https://github.com/PyTorchLightning/pytorch-lightning/pull/6275)) -- Fixed duplicate logs appearing in console when using the python logging module ([#5509](https://github.com/PyTorchLightning/pytorch-lightning/pull/5509), [#6275](https://github.com/PyTorchLightning/pytorch-lightning/pull/6275)) +- Fixed `.teardown(stage='fit')` getting called during `trainer.test` ([#6386](https://github.com/PyTorchLightning/pytorch-lightning/pull/6386)) -- Fixed DP reduction with collection ([#6324](https://github.com/PyTorchLightning/pytorch-lightning/pull/6324)) +- Fixed `.on_fit_{start,end}()` getting called during `trainer.test` ([#6386](https://github.com/PyTorchLightning/pytorch-lightning/pull/6386)) -- Fixed `.teardown(stage='fit')` getting called during `trainer.test` ([#6386](https://github.com/PyTorchLightning/pytorch-lightning/pull/6386)) +- Fixed LightningModule `all_gather` on cpu tensors ([#6416](https://github.com/PyTorchLightning/pytorch-lightning/pull/6416)) -- Fixed `.on_fit_{start,end}()` getting called during `trainer.test` ([#6386](https://github.com/PyTorchLightning/pytorch-lightning/pull/6386)) +- Fixed a bug where `all_gather` would not work correctly with `tpu_cores=8` ([#6587](https://github.com/PyTorchLightning/pytorch-lightning/pull/6587)) -- Fixed an issue where the tuner would not tune the learning rate if also tuning the batch size ([#4688](https://github.com/PyTorchLightning/pytorch-lightning/pull/4688)) +- Update Gradient Clipping for the TPU Accelerator ([#6576](https://github.com/PyTorchLightning/pytorch-lightning/pull/6576)) -- Fixed logger creating directory structure too early in DDP ([#6380](https://github.com/PyTorchLightning/pytorch-lightning/pull/6380)) +- Fixed torch distributed not available in setup hook for DDP ([#6506](https://github.com/PyTorchLightning/pytorch-lightning/pull/6506)) -- Fixed DeepSpeed additional memory use on rank 0 when default device not set early enough ([#6460](https://github.com/PyTorchLightning/pytorch-lightning/pull/6460)) +- Fixed comparing required versions ([#6434](https://github.com/PyTorchLightning/pytorch-lightning/pull/6434)) -- Fixed LightningModule `all_gather` on cpu tensors ([#6416](https://github.com/PyTorchLightning/pytorch-lightning/pull/6416)) +- Fixed error on TPUs when there was no `ModelCheckpoint` ([#6654](https://github.com/PyTorchLightning/pytorch-lightning/pull/6654)) + + +- Fixed `trainer.test` freeze on TPUs ([#6654](https://github.com/PyTorchLightning/pytorch-lightning/pull/6654)) + + +- Fixed a bug where gradients were disabled after calling `Trainer.predict` ([#6657](https://github.com/PyTorchLightning/pytorch-lightning/pull/6657)) + + +## [1.2.4] - 2021-03-16 + +### Changed + +- Changed the default of `find_unused_parameters` back to `True` in DDP and DDP Spawn ([#6438](https://github.com/PyTorchLightning/pytorch-lightning/pull/6438)) + +### Fixed + +- Expose DeepSpeed loss parameters to allow users to fix loss instability ([#6115](https://github.com/PyTorchLightning/pytorch-lightning/pull/6115)) +- Fixed DP reduction with collection 
([#6324](https://github.com/PyTorchLightning/pytorch-lightning/pull/6324)) +- Fixed an issue where the tuner would not tune the learning rate if also tuning the batch size ([#4688](https://github.com/PyTorchLightning/pytorch-lightning/pull/4688)) +- Fixed broadcast to use PyTorch `broadcast_object_list` and add `reduce_decision` ([#6410](https://github.com/PyTorchLightning/pytorch-lightning/pull/6410)) +- Fixed logger creating directory structure too early in DDP ([#6380](https://github.com/PyTorchLightning/pytorch-lightning/pull/6380)) +- Fixed DeepSpeed additional memory use on rank 0 when default device not set early enough ([#6460](https://github.com/PyTorchLightning/pytorch-lightning/pull/6460)) +- Fixed `DummyLogger.log_hyperparams` raising a `TypeError` when running with `fast_dev_run=True` ([#6398](https://github.com/PyTorchLightning/pytorch-lightning/pull/6398)) +- Fixed an issue with `Tuner.scale_batch_size` not finding the batch size attribute in the datamodule ([#5968](https://github.com/PyTorchLightning/pytorch-lightning/pull/5968)) +- Fixed an exception in the layer summary when the model contains torch.jit scripted submodules ([#6511](https://github.com/PyTorchLightning/pytorch-lightning/pull/6511)) +- Fixed when Train loop config was run during `Trainer.predict` ([#6541](https://github.com/PyTorchLightning/pytorch-lightning/pull/6541)) ## [1.2.3] - 2021-03-09 @@ -157,9 +245,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed `Trainer` not resetting `lightning_optimizers` when calling `Trainer.fit()` multiple times ([#6372](https://github.com/PyTorchLightning/pytorch-lightning/pull/6372)) -- Fixed `DummyLogger.log_hyperparams` raising a `TypeError` when running with `fast_dev_run=True` ([#6398](https://github.com/PyTorchLightning/pytorch-lightning/pull/6398)) - - ## [1.2.2] - 2021-03-02 ### Added @@ -188,8 +273,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Fixed incorrect yield logic for the amp autocast context manager ([#6080](https://github.com/PyTorchLightning/pytorch-lightning/pull/6080)) - Fixed priority of plugin/accelerator when setting distributed mode ([#6089](https://github.com/PyTorchLightning/pytorch-lightning/pull/6089)) - Fixed error message for AMP + CPU incompatibility ([#6107](https://github.com/PyTorchLightning/pytorch-lightning/pull/6107)) - - - Disabled batch transfer in DP mode ([#6093](https://github.com/PyTorchLightning/pytorch-lightning/pull/6093)) diff --git a/Makefile b/Makefile index d35e0b77f8429..04b08fa2d27d1 100644 --- a/Makefile +++ b/Makefile @@ -29,4 +29,4 @@ test: clean docs: clean pip install --quiet -r requirements/docs.txt - python -m sphinx -b html -W docs/source docs/build + python -m sphinx -b html -W --keep-going docs/source docs/build diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 6dfddda0295fe..d88a31ae9775a 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -78,7 +78,7 @@ jobs: displayName: 'Get legacy checkpoints' - bash: | - python -m coverage run --source pytorch_lightning -m pytest pytorch_lightning tests -v --durations=50 + python -m coverage run --source pytorch_lightning -m pytest pytorch_lightning tests -v --junitxml=$(Build.StagingDirectory)/test-results.xml --durations=50 displayName: 'Testing: standard' - bash: | @@ -88,19 +88,39 @@ jobs: - bash: | python -m coverage report python -m coverage xml - codecov --token=$(CODECOV_TOKEN) --flags=gpu,pytest --name="GPU-coverage" --env=linux,azure + python -m coverage html + python -m codecov --token=$(CODECOV_TOKEN) --commit=$(Build.SourceVersion) --flags=gpu,pytest --name="GPU-coverage" --env=linux,azure + ls -l displayName: 'Statistics' + - task: PublishTestResults@2 + displayName: 'Publish test results' + inputs: + testResultsFiles: '$(Build.StagingDirectory)/test-results.xml' + testRunTitle: '$(Agent.OS) - $(Build.DefinitionName) - Python $(python.version)' + condition: succeededOrFailed() + + - task: PublishCodeCoverageResults@1 + displayName: 'Publish coverage report' + inputs: + codeCoverageTool: 'cobertura' + summaryFileLocation: 'coverage.xml' + reportDirectory: '$(Build.SourcesDirectory)/htmlcov' + testRunTitle: '$(Agent.OS) - $(Build.BuildNumber)[$(Agent.JobName)] - Python $(python.version)' + condition: succeededOrFailed() + - bash: | python -m pytest benchmarks -v --maxfail=2 --durations=0 displayName: 'Testing: benchmarks' - - bash: | + - script: | + set -e python -m pytest pl_examples -v --maxfail=2 --durations=0 python setup.py install --user --quiet bash pl_examples/run_ddp-example.sh - cd pl_examples/basic_examples - bash submit_ddp_job.sh - bash submit_ddp2_job.sh - pip uninstall -y pytorch-lightning + # cd pl_examples/basic_examples + # bash submit_ddp_job.sh + # bash submit_ddp2_job.sh + env: + PL_USE_MOCKED_MNIST: "1" displayName: 'Examples' diff --git a/dockers/nvidia/Dockerfile b/dockers/nvidia/Dockerfile index ea567a5306eed..4b04bc9426d4d 100644 --- a/dockers/nvidia/Dockerfile +++ b/dockers/nvidia/Dockerfile @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-FROM nvcr.io/nvidia/pytorch:20.12-py3 +FROM nvcr.io/nvidia/pytorch:21.02-py3 MAINTAINER PyTorchLightning @@ -22,16 +22,17 @@ COPY ./ ./pytorch-lightning/ # install dependencies RUN \ - # Disable cache #conda install "pip>20.1" && \ - #pip config set global.cache-dir false && \ - if [ -z $LIGHTNING_VERSION ] ; then \ - pip install ./pytorch-lightning --no-cache-dir ; \ + pip list | grep torch && \ + if [ ! -z "$LIGHTNING_VERSION" ] ; then \ rm -rf pytorch-lightning ; \ - else \ - rm -rf pytorch-lightning ; \ - pip install https://github.com/PyTorchLightning/pytorch-lightning/archive/${LIGHTNING_VERSION}.zip --no-cache-dir ; \ - fi + wget https://github.com/PyTorchLightning/pytorch-lightning/archive/${LIGHTNING_VERSION}.zip --progress=bar:force:noscroll ; \ + unzip ${LIGHTNING_VERSION}.zip ; \ + mv pytorch-lightning-*/ pytorch-lightning ; \ + rm *.zip ; \ + fi && \ + pip install ./pytorch-lightning["extra"] --no-cache-dir && \ + rm -rf pytorch-lightning RUN python --version && \ pip --version && \ diff --git a/dockers/release/Dockerfile b/dockers/release/Dockerfile index 3584ee02746e3..0eec1e41a5a3f 100644 --- a/dockers/release/Dockerfile +++ b/dockers/release/Dockerfile @@ -25,7 +25,6 @@ COPY ./ ./pytorch-lightning/ # install dependencies RUN \ - # Disable cache #conda install "pip>20.1" && \ if [ ! -z "$LIGHTNING_VERSION" ] ; then \ rm -rf pytorch-lightning ; \ diff --git a/docs/source/advanced/multi_gpu.rst b/docs/source/advanced/multi_gpu.rst index 4fb90e7829fb4..5cdb0b377f2b7 100644 --- a/docs/source/advanced/multi_gpu.rst +++ b/docs/source/advanced/multi_gpu.rst @@ -267,7 +267,7 @@ Lightning allows multiple ways of training - TPUs (``tpu_cores=8|x``) (tpu or TPU pod) .. note:: - If you request multiple GPUs or nodes without setting a mode, DDP will be automatically used. + If you request multiple GPUs or nodes without setting a mode, DDP Spawn will be automatically used. For a deeper understanding of what Lightning is doing, feel free to read this `guide `_. @@ -697,24 +697,23 @@ To use DeepSpeed, you first need to install DeepSpeed using the commands below. .. code-block:: bash - pip install deepspeed mpi4py + pip install deepspeed If you run into an issue with the install or later in training, ensure that the CUDA version of the pytorch you've installed matches your locally installed CUDA (you can see which one has been recognized by running ``nvcc --version``). -Additionally if you run into any issues installing m4py, ensure you have openmpi installed using ``sudo apt install libopenmpi-dev`` or ``brew install mpich`` before running ``pip install mpi4py``. .. note:: Currently ``resume_from_checkpoint`` and manual optimization are not supported. DeepSpeed currently only supports single optimizer, single scheduler within the training loop. -ZeRO-Offload -"""""""""""" +DeepSpeed ZeRO Stage 2 +"""""""""""""""""""""" -Below we show an example of running `ZeRO-Offload `_. ZeRO-Offload leverages the host CPU to offload optimizer memory/computation, reducing the overall memory consumption. -For even more speed benefit, they offer an optimized CPU version of ADAM to run the offloaded computation, which is faster than the standard PyTorch implementation. By default we enable ZeRO-Offload. +By default, we enable `DeepSpeed ZeRO Stage 2 `_, which partitions your optimizer states (Stage 1) and your gradients (Stage 2) across your GPUs to reduce memory. In most cases, this is more efficient or at parity with DDP, primarily due to the optimized custom communications written by the DeepSpeed team. 
+As a result, benefits can also be seen on a single GPU. Do note that the default bucket sizes allocate around ``3.6GB`` of VRAM to use during distributed communications, which can be tweaked when instantiating the plugin described in a few sections below. .. note:: - To use ZeRO-Offload, you must use ``precision=16`` or set precision via `the DeepSpeed config. `_. + To use ZeRO, you must use ``precision=16``. .. code-block:: python @@ -725,6 +724,24 @@ For even more speed benefit, they offer an optimized CPU version of ADAM to run trainer.fit(model) +DeepSpeed ZeRO Stage 2 Offload +"""""""""""""""""""""""""""""" + +Below we show an example of running `ZeRO-Offload `_. ZeRO-Offload leverages the host CPU to offload optimizer memory/computation, reducing the overall memory consumption. + +.. note:: + To use ZeRO-Offload, you must use ``precision=16``. + +.. code-block:: python + + from pytorch_lightning import Trainer + from pytorch_lightning.plugins import DeepSpeedPlugin + + model = MyModel() + trainer = Trainer(gpus=4, plugins=DeepSpeedPlugin(cpu_offload=True), precision=16) + trainer.fit(model) + + This can also be done via the command line using a Pytorch Lightning script: .. code-block:: bash @@ -740,7 +757,7 @@ You can also modify the ZeRO-Offload parameters via the plugin as below. from pytorch_lightning.plugins import DeepSpeedPlugin model = MyModel() - trainer = Trainer(gpus=4, plugins=DeepSpeedPlugin(allgather_bucket_size=5e8, reduce_bucket_size=5e8), precision=16) + trainer = Trainer(gpus=4, plugins=DeepSpeedPlugin(cpu_offload=True, allgather_bucket_size=5e8, reduce_bucket_size=5e8), precision=16) trainer.fit(model) @@ -752,11 +769,30 @@ You can also modify the ZeRO-Offload parameters via the plugin as below. The plugin sets a reasonable default of ``2e8``, which should work for most low VRAM GPUs (less than ``7GB``), allocating roughly ``3.6GB`` of VRAM as buffer. Higher VRAM GPUs should aim for values around ``5e8``. +For even more speed benefit, DeepSpeed offers an optimized CPU version of ADAM called `DeepSpeedCPUAdam `_ to run the offloaded computation, which is faster than the standard PyTorch implementation. + +.. code-block:: python + + import pytorch_lightning as pl + from pytorch_lightning import Trainer + from pytorch_lightning.plugins import DeepSpeedPlugin + from deepspeed.ops.adam import DeepSpeedCPUAdam + + class MyModel(pl.LightningModule): + ... + def configure_optimizers(self): + # DeepSpeedCPUAdam provides a 5x to 7x speedup over torch.optim.Adam(W) + return DeepSpeedCPUAdam(self.parameters()) + + model = MyModel() + trainer = Trainer(gpus=4, plugins=DeepSpeedPlugin(cpu_offload=True), precision=16) + trainer.fit(model) + Custom DeepSpeed Config """"""""""""""""""""""" -DeepSpeed allows use of custom DeepSpeed optimizers and schedulers defined within a config file. This allows you to enable optimizers such as `1-bit Adam `_. +In some cases you may want to define your own DeepSpeed config to access all available parameters. We've exposed most of the important parameters; however, there may be debugging parameters you want to enable. DeepSpeed also allows the use of custom optimizers and schedulers defined within a config file. .. note:: All plugin default parameters will be ignored when a config object is passed.
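As a rough sketch of what the note above implies (assuming ``DeepSpeedPlugin`` accepts a ``config`` argument that takes either a dict or a path to a DeepSpeed JSON file; the keys below are illustrative, not a complete configuration), a custom config could be passed like this:

.. code-block:: python

    from pytorch_lightning import Trainer
    from pytorch_lightning.plugins import DeepSpeedPlugin

    # Illustrative config following the DeepSpeed config schema (assumed, not taken from the patch)
    deepspeed_config = {
        "zero_allow_untested_optimizer": True,
        "zero_optimization": {
            "stage": 2,
            "cpu_offload": True,
        },
    }

    model = MyModel()
    # Per the note above, the plugin's default parameters are ignored once a config object is supplied
    trainer = Trainer(gpus=4, plugins=DeepSpeedPlugin(config=deepspeed_config), precision=16)
    trainer.fit(model)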
diff --git a/docs/source/advanced/multiple_loaders.rst b/docs/source/advanced/multiple_loaders.rst index 3f230957ca283..1a82641953c3c 100644 --- a/docs/source/advanced/multiple_loaders.rst +++ b/docs/source/advanced/multiple_loaders.rst @@ -9,7 +9,7 @@ Multiple Datasets Lightning supports multiple dataloaders in a few ways. 1. Create a dataloader that iterates multiple datasets under the hood. -2. In the training loop you can pass multiple loaders as a dict or list/tuple and lightning +2. In the training loop you can pass multiple loaders as a dict or list/tuple and lightning will automatically combine the batches from different loaders. 3. In the validation and test loop you also have the option to return multiple dataloaders which lightning will call sequentially. @@ -75,13 +75,13 @@ For more details please have a look at :paramref:`~pytorch_lightning.trainer.tra loader_a = torch.utils.data.DataLoader(range(6), batch_size=4) loader_b = torch.utils.data.DataLoader(range(15), batch_size=5) - + # pass loaders as a dict. This will create batches like this: # {'a': batch from loader_a, 'b': batch from loader_b} loaders = {'a': loader_a, 'b': loader_b} - # OR: + # OR: # pass loaders as sequence. This will create batches like this: # [batch from loader_a, batch from loader_b] loaders = [loader_a, loader_b] @@ -89,7 +89,24 @@ For more details please have a look at :paramref:`~pytorch_lightning.trainer.tra return loaders Furthermore, Lightning also supports that nested lists and dicts (or a combination) can -be returned +be returned. + +.. testcode:: + + class LitModel(LightningModule): + + def train_dataloader(self): + + loader_a = torch.utils.data.DataLoader(range(8), batch_size=4) + loader_b = torch.utils.data.DataLoader(range(16), batch_size=2) + + return {'a': loader_a, 'b': loader_b} + + def training_step(self, batch, batch_idx): + # access a dictionary with a batch from each dataloader + batch_a = batch["a"] + batch_b = batch["b"] + .. testcode:: @@ -103,12 +120,29 @@ be returned loader_c = torch.utils.data.DataLoader(range(64), batch_size=4) # pass loaders as a nested dict. This will create batches like this: - # {'loader_a_b': {'a': batch from loader a, 'b': batch from loader b}, - # 'loader_c_d': {'c': batch from loader c, 'd': batch from loader d}} - loaders = {'loaders_a_b': {'a': loader_a, 'b': loader_b}, - 'loaders_c_d': {'c': loader_c, 'd': loader_d}} + loaders = { + 'loaders_a_b': { + 'a': loader_a, + 'b': loader_b + }, + 'loaders_c_d': { + 'c': loader_c, + 'd': loader_d + } + } return loaders + def training_step(self, batch, batch_idx): + # access the data + batch_a_b = batch["loaders_a_b"] + batch_c_d = batch["loaders_c_d"] + + batch_a = batch_a_b["a"] + batch_b = batch_a_b["b"] + + batch_c = batch_c_d["c"] + batch_d = batch_c_d["d"] + ---------- Test/Val dataloaders diff --git a/docs/source/benchmarking/performance.rst b/docs/source/benchmarking/performance.rst index d1bc2c9ebc009..dbddaad3a5e3c 100644 --- a/docs/source/benchmarking/performance.rst +++ b/docs/source/benchmarking/performance.rst @@ -181,3 +181,19 @@ Most UNIX-based operating systems provide direct access to tmpfs through a mount .. code-block:: python datamodule = MyDataModule(data_root="/dev/shm/my_data") + + +Zero Grad ``set_to_none=True`` +------------------------------ + +In order to modestly improve performance, one can override :meth:`~pytorch_lightning.core.lightning.LightningModule.optimizer_zero_grad`.
+ +For a more detailed explanation of pros / cons of this technique, +read `this `_ documentation by the PyTorch team. + +.. testcode:: + + class Model(LightningModule): + + def optimizer_zero_grad(self, epoch, batch_idx, optimizer, optimizer_idx): + optimizer.zero_grad(set_to_none=True) diff --git a/docs/source/common/lightning_module.rst b/docs/source/common/lightning_module.rst index f6deb9adf58d3..6e67f591da7c7 100644 --- a/docs/source/common/lightning_module.rst +++ b/docs/source/common/lightning_module.rst @@ -1020,12 +1020,14 @@ This is the pseudocode to describe how all the hooks are called during a call to .. code-block:: python def fit(...): - on_fit_start() - if global_rank == 0: # prepare data is called on GLOBAL_ZERO only prepare_data() + configure_callbacks() + + on_fit_start() + for gpu/tpu in gpu/tpus: train_on_device(model.copy()) @@ -1043,6 +1045,7 @@ This is the pseudocode to describe how all the hooks are called during a call to teardown() def train_loop(): + on_epoch_start() on_train_epoch_start() train_outs = [] for train_batch in train_dataloader(): @@ -1068,12 +1071,15 @@ This is the pseudocode to describe how all the hooks are called during a call to val_loop() # end training epoch - logs = training_epoch_end(outs) + outs = training_epoch_end(outs) + on_train_epoch_end(outs) + on_epoch_end() def val_loop(): model.eval() torch.set_grad_enabled(False) + on_epoch_start() on_validation_epoch_start() val_outs = [] for val_batch in val_dataloader(): @@ -1087,6 +1093,7 @@ This is the pseudocode to describe how all the hooks are called during a call to validation_epoch_end(val_outs) on_validation_epoch_end() + on_epoch_end() # set up for train model.train() @@ -1114,12 +1121,12 @@ manual_backward on_after_backward ~~~~~~~~~~~~~~~~~ -.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_after_backward +.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_after_backward :noindex: on_before_zero_grad ~~~~~~~~~~~~~~~~~~~ -.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_before_zero_grad +.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_before_zero_grad :noindex: on_fit_start @@ -1138,15 +1145,38 @@ on_fit_end on_load_checkpoint ~~~~~~~~~~~~~~~~~~ -.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_load_checkpoint +.. automethod:: pytorch_lightning.core.hooks.CheckpointHooks.on_load_checkpoint :noindex: on_save_checkpoint ~~~~~~~~~~~~~~~~~~ -.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_save_checkpoint +.. automethod:: pytorch_lightning.core.hooks.CheckpointHooks.on_save_checkpoint + :noindex: + +on_train_start +~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_train_start + :noindex: + +on_train_end +~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_train_end :noindex: +on_validation_start +~~~~~~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_validation_start + :noindex: + +on_validation_end +~~~~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_validation_end + :noindex: on_pretrain_routine_start ~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -1184,6 +1214,11 @@ on_test_epoch_end .. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_test_epoch_end :noindex: +on_test_end +~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_test_end + :noindex: on_train_batch_start ~~~~~~~~~~~~~~~~~~~~ @@ -1197,6 +1232,18 @@ on_train_batch_end .. 
automethod:: pytorch_lightning.core.hooks.ModelHooks.on_train_batch_end :noindex: +on_epoch_start +~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_epoch_start + :noindex: + +on_epoch_end +~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_epoch_end + :noindex: + on_train_epoch_start ~~~~~~~~~~~~~~~~~~~~ @@ -1233,6 +1280,36 @@ on_validation_epoch_end .. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_validation_epoch_end :noindex: +on_post_move_to_device +~~~~~~~~~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_post_move_to_device + :noindex: + +on_validation_model_eval +~~~~~~~~~~~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_validation_model_eval + :noindex: + +on_validation_model_train +~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_validation_model_train + :noindex: + +on_test_model_eval +~~~~~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_test_model_eval + :noindex: + +on_test_model_train +~~~~~~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_test_model_train + :noindex: + optimizer_step ~~~~~~~~~~~~~~ @@ -1254,7 +1331,7 @@ prepare_data setup ~~~~~ -.. automethod:: pytorch_lightning.core.hooks.ModelHooks.setup +.. automethod:: pytorch_lightning.core.hooks.DataHooks.setup :noindex: tbptt_split_batch @@ -1266,25 +1343,25 @@ tbptt_split_batch teardown ~~~~~~~~ -.. automethod:: pytorch_lightning.core.hooks.ModelHooks.teardown +.. automethod:: pytorch_lightning.core.hooks.DataHooks.teardown :noindex: train_dataloader ~~~~~~~~~~~~~~~~ -.. automethod:: pytorch_lightning.core.lightning.LightningModule.train_dataloader +.. automethod:: pytorch_lightning.core.hooks.DataHooks.train_dataloader :noindex: val_dataloader ~~~~~~~~~~~~~~ -.. automethod:: pytorch_lightning.core.lightning.LightningModule.val_dataloader +.. automethod:: pytorch_lightning.core.hooks.DataHooks.val_dataloader :noindex: test_dataloader ~~~~~~~~~~~~~~~ -.. automethod:: pytorch_lightning.core.lightning.LightningModule.test_dataloader +.. automethod:: pytorch_lightning.core.hooks.DataHooks.test_dataloader :noindex: transfer_batch_to_device diff --git a/docs/source/common/trainer.rst b/docs/source/common/trainer.rst index 5614e481e0888..d86a8dc1ff472 100644 --- a/docs/source/common/trainer.rst +++ b/docs/source/common/trainer.rst @@ -1157,7 +1157,7 @@ precision | -Full precision (32), half precision (16). +Double precision (64), full precision (32) or half precision (16). Can be used on CPU, GPU or TPUs. If used on TPU will use torch.bfloat16 but tensor printing @@ -1172,6 +1172,9 @@ will still show torch.float32. # 16-bit precision trainer = Trainer(precision=16, gpus=1) + # 64-bit precision + trainer = Trainer(precision=64) + Example:: # one day diff --git a/docs/source/conf.py b/docs/source/conf.py index ccf824bb37d9b..1c1f3be8a636a 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -13,7 +13,6 @@ # documentation root, use os.path.abspath to make it absolute, like shown here. 
# import m2r -import builtins import glob import os import shutil @@ -27,10 +26,13 @@ FOLDER_GENERATED = 'generated' SPHINX_MOCK_REQUIREMENTS = int(os.environ.get('SPHINX_MOCK_REQUIREMENTS', True)) -if SPHINX_MOCK_REQUIREMENTS: - builtins.__LIGHTNING_SETUP__ = True -import pytorch_lightning # noqa: E402 +try: + from pytorch_lightning import info +except ImportError: + # alternative https://stackoverflow.com/a/67692/4521646 + sys.path.append(os.path.join(PATH_ROOT, "pytorch_lightning")) + import info # -- Project documents ------------------------------------------------------- @@ -79,13 +81,13 @@ def _transform_changelog(path_in: str, path_out: str) -> None: # -- Project information ----------------------------------------------------- project = 'PyTorch Lightning' -copyright = pytorch_lightning.__copyright__ -author = pytorch_lightning.__author__ +copyright = info.__copyright__ +author = info.__author__ # The short X.Y version -version = pytorch_lightning.__version__ +version = info.__version__ # The full version, including alpha/beta/rc tags -release = pytorch_lightning.__version__ +release = info.__version__ # -- General configuration --------------------------------------------------- @@ -176,8 +178,8 @@ def _transform_changelog(path_in: str, path_out: str) -> None: # documentation. html_theme_options = { - 'pytorch_project': pytorch_lightning.__homepage__, - 'canonical_url': pytorch_lightning.__homepage__, + 'pytorch_project': 'https://pytorchlightning.ai', + 'canonical_url': info.__docs_url__, 'collapse_navigation': False, 'display_version': True, 'logo_only': False, @@ -279,6 +281,7 @@ def _transform_changelog(path_in: str, path_out: str) -> None: 'torch': ('https://pytorch.org/docs/stable/', None), 'numpy': ('https://numpy.org/doc/stable/', None), 'PIL': ('https://pillow.readthedocs.io/en/stable/', None), + 'torchmetrics': ('https://torchmetrics.readthedocs.io/en/stable/', None), } # -- Options for todo extension ---------------------------------------------- @@ -328,9 +331,11 @@ def package_list_from_file(file): 'comet-ml': 'comet_ml', 'neptune-client': 'neptune', 'hydra-core': 'hydra', + 'pyDeprecate': 'deprecate', } MOCK_PACKAGES = [] if SPHINX_MOCK_REQUIREMENTS: + MOCK_PACKAGES += ['fairscale'] # mock also base packages when we are on RTD since we don't install them there MOCK_PACKAGES += package_list_from_file(os.path.join(PATH_ROOT, 'requirements.txt')) MOCK_PACKAGES += package_list_from_file(os.path.join(PATH_ROOT, 'requirements', 'extra.txt')) diff --git a/docs/source/extensions/callbacks.rst b/docs/source/extensions/callbacks.rst index 63a221a06119f..73691c6dd76f5 100644 --- a/docs/source/extensions/callbacks.rst +++ b/docs/source/extensions/callbacks.rst @@ -349,3 +349,15 @@ on_load_checkpoint .. automethod:: pytorch_lightning.callbacks.Callback.on_load_checkpoint :noindex: + +on_after_backward +^^^^^^^^^^^^^^^^^ + +.. automethod:: pytorch_lightning.callbacks.Callback.on_after_backward + :noindex: + +on_before_zero_grad +^^^^^^^^^^^^^^^^^^^ + +.. 
automethod:: pytorch_lightning.callbacks.Callback.on_before_zero_grad + :noindex: diff --git a/docs/source/extensions/datamodules.rst b/docs/source/extensions/datamodules.rst index 85134fda06fa2..881febe21316d 100644 --- a/docs/source/extensions/datamodules.rst +++ b/docs/source/extensions/datamodules.rst @@ -94,6 +94,10 @@ The equivalent DataModule just organizes the same exact code, but makes it reusa def test_dataloader(self): return DataLoader(self.mnist_test, batch_size=self.batch_size) + def teardown(self, stage: Optional[str] = None): + # Used to clean-up when the run is finished + ... + But now, as the complexity of your processing grows (transforms, multiple-GPU training), you can let Lightning handle those details for you while making this dataset reusable so you can share with colleagues or use in different projects. @@ -243,7 +247,10 @@ There are also data operations you might want to perform on every GPU. Use setup self.dims = getattr(self, 'dims', self.mnist_test[0][0].shape) -.. warning:: `setup` is called from every process. Setting state here is okay. +.. warning:: ``setup`` is called from every process. Setting state here is okay. + + +.. note:: ``teardown`` can be used to clean up the state. It is also called from every process train_dataloader @@ -411,10 +418,14 @@ You can of course use DataModules in plain PyTorch code as well. for batch in dm.val_dataloader(): ... + dm.teardown(stage='fit') + # lazy load test data dm.setup(stage='test') for batch in dm.test_dataloader(): ... + dm.teardown(stage='test') + But overall, DataModules encourage reproducibility by allowing all details of a dataset to be specified in a unified structure. diff --git a/docs/source/extensions/logging.rst b/docs/source/extensions/logging.rst index a17d595f1fc44..9ad17b5fd1821 100644 --- a/docs/source/extensions/logging.rst +++ b/docs/source/extensions/logging.rst @@ -90,7 +90,7 @@ The :func:`~~pytorch_lightning.core.lightning.LightningModule.log` method has a .. note:: - Setting ``on_epoch=True`` will cache all your logged values during the full training epoch and perform a - reduction `on_epoch_end`. We recommend using the :doc:`metrics <../extensions/metrics>` API when working with custom reduction. + reduction in ``on_train_epoch_end``. We recommend using the :doc:`metrics <../extensions/metrics>` API when working with custom reduction. - Setting both ``on_step=True`` and ``on_epoch=True`` will create two keys per metric you log with suffix ``_step`` and ``_epoch``, respectively. You can refer to these keys e.g. in the `monitor` diff --git a/docs/source/extensions/metrics.rst b/docs/source/extensions/metrics.rst index 6a64c42ec2753..74a4a15deb2be 100644 --- a/docs/source/extensions/metrics.rst +++ b/docs/source/extensions/metrics.rst @@ -1,887 +1,9 @@ -.. testsetup:: * - - import torch - from torch.nn import Module - from pytorch_lightning.core.lightning import LightningModule - from pytorch_lightning.metrics import Metric - -.. _metrics: - ####### Metrics ####### -``pytorch_lightning.metrics`` is a Metrics API created for easy metric development and usage in -PyTorch and PyTorch Lightning. It is rigorously tested for all edge cases and includes a growing list of -common metric implementations. - -The metrics API provides ``update()``, ``compute()``, ``reset()`` functions to the user. The metric base class inherits -``nn.Module`` which allows us to call ``metric(...)`` directly. 
The ``forward()`` method of the base ``Metric`` class -serves the dual purpose of calling ``update()`` on its input and simultaneously returning the value of the metric over the -provided input. +``pytorch_lightning.metrics`` has been moved to a separate package `TorchMetrics `_. +We will preserve compatibility for the next few releases, nevertheless, we encourage users to update to use this stand-alone package. .. warning:: - From v1.2 onward ``compute()`` will no longer automatically call ``reset()``, - and it is up to the user to reset metrics between epochs, except in the case where the - metric is directly passed to ``LightningModule``'s ``self.log``. - -These metrics work with DDP in PyTorch and PyTorch Lightning by default. When ``.compute()`` is called in -distributed mode, the internal state of each metric is synced and reduced across each process, so that the -logic present in ``.compute()`` is applied to state information from all processes. - -The example below shows how to use a metric in your ``LightningModule``: - -.. code-block:: python - - def __init__(self): - ... - self.accuracy = pl.metrics.Accuracy() - - def training_step(self, batch, batch_idx): - x, y = batch - preds = self(x) - ... - # log step metric - self.log('train_acc_step', self.accuracy(preds, y)) - ... - - def training_epoch_end(self, outs): - # log epoch metric - self.log('train_acc_epoch', self.accuracy.compute()) - - -``Metric`` objects can also be directly logged, in which case Lightning will log -the metric based on ``on_step`` and ``on_epoch`` flags present in ``self.log(...)``. -If ``on_epoch`` is True, the logger automatically logs the end of epoch metric value by calling -``.compute()``. - -.. note:: - ``sync_dist``, ``sync_dist_op``, ``sync_dist_group``, ``reduce_fx`` and ``tbptt_reduce_fx`` - flags from ``self.log(...)`` don't affect the metric logging in any manner. The metric class - contains its own distributed synchronization logic. - - This however is only true for metrics that inherit the base class ``Metric``, - and thus the functional metric API provides no support for in-built distributed synchronization - or reduction functions. - - -.. code-block:: python - - def __init__(self): - ... - self.train_acc = pl.metrics.Accuracy() - self.valid_acc = pl.metrics.Accuracy() - - def training_step(self, batch, batch_idx): - x, y = batch - preds = self(x) - ... - self.train_acc(preds, y) - self.log('train_acc', self.train_acc, on_step=True, on_epoch=False) - - def validation_step(self, batch, batch_idx): - logits = self(x) - ... - self.valid_acc(logits, y) - self.log('valid_acc', self.valid_acc, on_step=True, on_epoch=True) - -.. note:: - - If using metrics in data parallel mode (dp), the metric update/logging should be done - in the ``_step_end`` method (where ```` is either ``training``, ``validation`` - or ``test``). This is due to metric states else being destroyed after each forward pass, - leading to wrong accumulation. In practice do the following: - - .. code-block:: python - - def training_step(self, batch, batch_idx): - data, target = batch - preds = self(data) - ... - return {'loss' : loss, 'preds' : preds, 'target' : target} - - def training_step_end(self, outputs): - #update and log - self.metric(outputs['preds'], outputs['target']) - self.log('metric', self.metric) - -This metrics API is independent of PyTorch Lightning. Metrics can directly be used in PyTorch as shown in the example: - -.. 
code-block:: python - - from pytorch_lightning import metrics - - train_accuracy = metrics.Accuracy() - valid_accuracy = metrics.Accuracy(compute_on_step=False) - - for epoch in range(epochs): - for x, y in train_data: - y_hat = model(x) - - # training step accuracy - batch_acc = train_accuracy(y_hat, y) - - for x, y in valid_data: - y_hat = model(x) - valid_accuracy(y_hat, y) - - # total accuracy over all training batches - total_train_accuracy = train_accuracy.compute() - - # total accuracy over all validation batches - total_valid_accuracy = valid_accuracy.compute() - -.. note:: - - Metrics contain internal states that keep track of the data seen so far. - Do not mix metric states across training, validation and testing. - It is highly recommended to re-initialize the metric per mode as - shown in the examples above. For easy initializing the same metric multiple - times, the ``.clone()`` method can be used: - - .. testcode:: - - from pytorch_lightning.metrics import Accuracy - - def __init__(self): - ... - metric = Accuracy() - self.train_acc = metric.clone() - self.val_acc = metric.clone() - self.test_acc = metric.clone() - -.. note:: - - Metric states are **not** added to the models ``state_dict`` by default. - To change this, after initializing the metric, the method ``.persistent(mode)`` can - be used to enable (``mode=True``) or disable (``mode=False``) this behaviour. - -******************* -Metrics and devices -******************* - -Metrics are simple subclasses of :class:`~torch.nn.Module` and their metric states behave -similar to buffers and parameters of modules. This means that metrics states should -be moved to the same device as the input of the metric: - -.. code-block:: python - - from pytorch_lightning.metrics import Accuracy - - target = torch.tensor([1, 1, 0, 0], device=torch.device("cuda", 0)) - preds = torch.tensor([0, 1, 0, 0], device=torch.device("cuda", 0)) - - # Metric states are always initialized on cpu, and needs to be moved to - # the correct device - confmat = Accuracy(num_classes=2).to(torch.device("cuda", 0)) - out = confmat(preds, target) - print(out.device) # cuda:0 - -However, when **properly defined** inside a :class:`~pytorch_lightning.core.lightning.LightningModule` -, Lightning will automatically move the metrics to the same device as the data. Being -**properly defined** means that the metric is correctly identified as a child module of the -model (check ``.children()`` attribute of the model). Therefore, metrics cannot be placed -in native python ``list`` and ``dict``, as they will not be correctly identified -as child modules. Instead of ``list`` use :class:`~torch.nn.ModuleList` and instead of -``dict`` use :class:`~torch.nn.ModuleDict`. - -.. testcode:: - - from pytorch_lightning.metrics import Accuracy - - class MyModule(LightningModule): - def __init__(self): - ... - # valid ways metrics will be identified as child modules - self.metric1 = Accuracy() - self.metric2 = nn.ModuleList(Accuracy()) - self.metric3 = nn.ModuleDict({'accuracy': Accuracy()}) - - def training_step(self, batch, batch_idx): - # all metrics will be on the same device as the input batch - data, target = batch - preds = self(data) - ... 
- val1 = self.metric1(preds, target) - val2 = self.metric2[0](preds, target) - val3 = self.metric3['accuracy'](preds, target) - - -********************* -Implementing a Metric -********************* - -To implement your custom metric, subclass the base ``Metric`` class and implement the following methods: - -- ``__init__()``: Each state variable should be called using ``self.add_state(...)``. -- ``update()``: Any code needed to update the state given any inputs to the metric. -- ``compute()``: Computes a final value from the state of the metric. - -All you need to do is call ``add_state`` correctly to implement a custom metric with DDP. -``reset()`` is called on metric state variables added using ``add_state()``. - -To see how metric states are synchronized across distributed processes, refer to ``add_state()`` docs -from the base ``Metric`` class. - -Example implementation: - -.. testcode:: - - from pytorch_lightning.metrics import Metric - - class MyAccuracy(Metric): - def __init__(self, dist_sync_on_step=False): - super().__init__(dist_sync_on_step=dist_sync_on_step) - - self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum") - self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") - - def update(self, preds: torch.Tensor, target: torch.Tensor): - preds, target = self._input_format(preds, target) - assert preds.shape == target.shape - - self.correct += torch.sum(preds == target) - self.total += target.numel() - - def compute(self): - return self.correct.float() / self.total - -Metrics support backpropagation, if all computations involved in the metric calculation -are differentiable. However, note that the cached state is detached from the computational -graph and cannot be backpropagated. Not doing this would mean storing the computational -graph for each update call, which can lead to out-of-memory errors. -In practise this means that: - -.. code-block:: python - - metric = MyMetric() - val = metric(pred, target) # this value can be backpropagated - val = metric.compute() # this value cannot be backpropagated - - -Metric API ----------- - -.. autoclass:: pytorch_lightning.metrics.Metric - :noindex: - -Internal implementation details -------------------------------- - -This section briefly describe how metrics work internally. We encourage looking at the source code for more info. -Internally, Lightning wraps the user defined ``update()`` and ``compute()`` method. We do this to automatically -synchronize and reduce metric states across multiple devices. More precisely, calling ``update()`` does the -following internally: - -1. Clears computed cache -2. Calls user-defined ``update()`` - -Simiarly, calling ``compute()`` does the following internally - -1. Syncs metric states between processes -2. Reduce gathered metric states -3. Calls the user defined ``compute()`` method on the gathered metric states -4. Cache computed result - -From a user's standpoint this has one important side-effect: computed results are cached. This means that no -matter how many times ``compute`` is called after one and another, it will continue to return the same result. -The cache is first emptied on the next call to ``update``. - -``forward`` serves the dual purpose of both returning the metric on the current data and updating the internal -metric state for accumulating over multiple batches. The ``forward()`` method achives this by combining calls -to ``update`` and ``compute`` in the following way (assuming metric is initialized with ``compute_on_step=True``): - -1. 
Calls ``update()`` to update the global metric states (for accumulation over multiple batches) -2. Caches the global state -3. Calls ``reset()`` to clear global metric state -4. Calls ``update()`` to update local metric state -5. Calls ``compute()`` to calculate metric for current batch -6. Restores the global state - -This procedure has the consequence of calling the user defined ``update`` **twice** during a single -forward call (one to update global statistics and one for getting the batch statistics). - - -****************** -Metric Arithmetics -****************** - -Metrics support most of python built-in operators for arithmetic, logic and bitwise operations. - -For example for a metric that should return the sum of two different metrics, implementing a new metric is an overhead that is not necessary. -It can now be done with: - -.. code-block:: python - - first_metric = MyFirstMetric() - second_metric = MySecondMetric() - - new_metric = first_metric + second_metric - -``new_metric.update(*args, **kwargs)`` now calls update of ``first_metric`` and ``second_metric``. It forwards all positional arguments but -forwards only the keyword arguments that are available in respective metric's update declaration. - -Similarly ``new_metric.compute()`` now calls compute of ``first_metric`` and ``second_metric`` and adds the results up. - -This pattern is implemented for the following operators (with ``a`` being metrics and ``b`` being metrics, tensors, integer or floats): - -* Addition (``a + b``) -* Bitwise AND (``a & b``) -* Equality (``a == b``) -* Floordivision (``a // b``) -* Greater Equal (``a >= b``) -* Greater (``a > b``) -* Less Equal (``a <= b``) -* Less (``a < b``) -* Matrix Multiplication (``a @ b``) -* Modulo (``a % b``) -* Multiplication (``a * b``) -* Inequality (``a != b``) -* Bitwise OR (``a | b``) -* Power (``a ** b``) -* Substraction (``a - b``) -* True Division (``a / b``) -* Bitwise XOR (``a ^ b``) -* Absolute Value (``abs(a)``) -* Inversion (``~a``) -* Negative Value (``neg(a)``) -* Positive Value (``pos(a)``) - -**************** -MetricCollection -**************** - -In many cases it is beneficial to evaluate the model output by multiple metrics. -In this case the `MetricCollection` class may come in handy. It accepts a sequence -of metrics and wraps theses into a single callable metric class, with the same -interface as any other metric. - -Example: - -.. testcode:: - - from pytorch_lightning.metrics import MetricCollection, Accuracy, Precision, Recall - target = torch.tensor([0, 2, 0, 2, 0, 1, 0, 2]) - preds = torch.tensor([2, 1, 2, 0, 1, 2, 2, 2]) - metric_collection = MetricCollection([ - Accuracy(), - Precision(num_classes=3, average='macro'), - Recall(num_classes=3, average='macro') - ]) - print(metric_collection(preds, target)) - -.. testoutput:: - :options: +NORMALIZE_WHITESPACE - - {'Accuracy': tensor(0.1250), - 'Precision': tensor(0.0667), - 'Recall': tensor(0.1111)} - -Similarly it can also reduce the amount of code required to log multiple metrics -inside your LightningModule - -.. code-block:: python - - def __init__(self): - ... - metrics = pl.metrics.MetricCollection(...) - self.train_metrics = metrics.clone() - self.valid_metrics = metrics.clone() - - def training_step(self, batch, batch_idx): - logits = self(x) - ... - self.train_metrics(logits, y) - # use log_dict instead of log - self.log_dict(self.train_metrics, on_step=True, on_epoch=False, prefix='train') - - def validation_step(self, batch, batch_idx): - logits = self(x) - ... 
- self.valid_metrics(logits, y) - # use log_dict instead of log - self.log_dict(self.valid_metrics, on_step=True, on_epoch=True, prefix='val') - -.. note:: - - `MetricCollection` as default assumes that all the metrics in the collection - have the same call signature. If this is not the case, input that should be - given to different metrics can given as keyword arguments to the collection. - -.. autoclass:: pytorch_lightning.metrics.MetricCollection - :noindex: - - -*************************** -Class vs Functional Metrics -*************************** - -The functional metrics follow the simple paradigm input in, output out. This means, they don't provide any advanced mechanisms for syncing across DDP nodes or aggregation over batches. They simply compute the metric value based on the given inputs. - -Also, the integration within other parts of PyTorch Lightning will never be as tight as with the class-based interface. -If you look for just computing the values, the functional metrics are the way to go. However, if you are looking for the best integration and user experience, please consider also using the class interface. - -********************** -Classification Metrics -********************** - -Input types ------------ - -For the purposes of classification metrics, inputs (predictions and targets) are split -into these categories (``N`` stands for the batch size and ``C`` for number of classes): - -.. csv-table:: \*dtype ``binary`` means integers that are either 0 or 1 - :header: "Type", "preds shape", "preds dtype", "target shape", "target dtype" - :widths: 20, 10, 10, 10, 10 - - "Binary", "(N,)", "``float``", "(N,)", "``binary``\*" - "Multi-class", "(N,)", "``int``", "(N,)", "``int``" - "Multi-class with probabilities", "(N, C)", "``float``", "(N,)", "``int``" - "Multi-label", "(N, ...)", "``float``", "(N, ...)", "``binary``\*" - "Multi-dimensional multi-class", "(N, ...)", "``int``", "(N, ...)", "``int``" - "Multi-dimensional multi-class with probabilities", "(N, C, ...)", "``float``", "(N, ...)", "``int``" - -.. note:: - All dimensions of size 1 (except ``N``) are "squeezed out" at the beginning, so - that, for example, a tensor of shape ``(N, 1)`` is treated as ``(N, )``. - -When predictions or targets are integers, it is assumed that class labels start at 0, i.e. -the possible class labels are 0, 1, 2, 3, etc. Below are some examples of different input types - -.. testcode:: - - # Binary inputs - binary_preds = torch.tensor([0.6, 0.1, 0.9]) - binary_target = torch.tensor([1, 0, 2]) - - # Multi-class inputs - mc_preds = torch.tensor([0, 2, 1]) - mc_target = torch.tensor([0, 1, 2]) - - # Multi-class inputs with probabilities - mc_preds_probs = torch.tensor([[0.8, 0.2, 0], [0.1, 0.2, 0.7], [0.3, 0.6, 0.1]]) - mc_target_probs = torch.tensor([0, 1, 2]) - - # Multi-label inputs - ml_preds = torch.tensor([[0.2, 0.8, 0.9], [0.5, 0.6, 0.1], [0.3, 0.1, 0.1]]) - ml_target = torch.tensor([[0, 1, 1], [1, 0, 0], [0, 0, 0]]) - - -Using the is_multiclass parameter -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -In some cases, you might have inputs which appear to be (multi-dimensional) multi-class -but are actually binary/multi-label - for example, if both predictions and targets are -integer (binary) tensors. Or it could be the other way around, you want to treat -binary/multi-label inputs as 2-class (multi-dimensional) multi-class inputs. - -For these cases, the metrics where this distinction would make a difference, expose the -``is_multiclass`` argument. 
Let's see how this is used on the example of -:class:`~pytorch_lightning.metrics.StatScores` metric. - -First, let's consider the case with label predictions with 2 classes, which we want to -treat as binary. - -.. testcode:: - - from pytorch_lightning.metrics.functional import stat_scores - - # These inputs are supposed to be binary, but appear as multi-class - preds = torch.tensor([0, 1, 0]) - target = torch.tensor([1, 1, 0]) - -As you can see below, by default the inputs are treated -as multi-class. We can set ``is_multiclass=False`` to treat the inputs as binary - -which is the same as converting the predictions to float beforehand. - -.. doctest:: - - >>> stat_scores(preds, target, reduce='macro', num_classes=2) - tensor([[1, 1, 1, 0, 1], - [1, 0, 1, 1, 2]]) - >>> stat_scores(preds, target, reduce='macro', num_classes=1, is_multiclass=False) - tensor([[1, 0, 1, 1, 2]]) - >>> stat_scores(preds.float(), target, reduce='macro', num_classes=1) - tensor([[1, 0, 1, 1, 2]]) - -Next, consider the opposite example: inputs are binary (as predictions are probabilities), -but we would like to treat them as 2-class multi-class, to obtain the metric for both classes. - -.. testcode:: - - preds = torch.tensor([0.2, 0.7, 0.3]) - target = torch.tensor([1, 1, 0]) - -In this case we can set ``is_multiclass=True``, to treat the inputs as multi-class. - -.. doctest:: - - >>> stat_scores(preds, target, reduce='macro', num_classes=1) - tensor([[1, 0, 1, 1, 2]]) - >>> stat_scores(preds, target, reduce='macro', num_classes=2, is_multiclass=True) - tensor([[1, 1, 1, 0, 1], - [1, 0, 1, 1, 2]]) - - -Class Metrics (Classification) ------------------------------- - -Accuracy -~~~~~~~~ - -.. autoclass:: pytorch_lightning.metrics.Accuracy - :noindex: - -AveragePrecision -~~~~~~~~~~~~~~~~ - -.. autoclass:: pytorch_lightning.metrics.AveragePrecision - :noindex: - -AUC -~~~ - -.. autoclass:: pytorch_lightning.metrics.AUC - :noindex: - -AUROC -~~~~~ - -.. autoclass:: pytorch_lightning.metrics.AUROC - :noindex: - -ConfusionMatrix -~~~~~~~~~~~~~~~ - -.. autoclass:: pytorch_lightning.metrics.ConfusionMatrix - :noindex: - -F1 -~~ - -.. autoclass:: pytorch_lightning.metrics.F1 - :noindex: - -FBeta -~~~~~ - -.. autoclass:: pytorch_lightning.metrics.FBeta - :noindex: - -IoU -~~~ - -.. autoclass:: pytorch_lightning.metrics.IoU - :noindex: - -Hamming Distance -~~~~~~~~~~~~~~~~ - -.. autoclass:: pytorch_lightning.metrics.HammingDistance - :noindex: - -Precision -~~~~~~~~~ - -.. autoclass:: pytorch_lightning.metrics.Precision - :noindex: - -PrecisionRecallCurve -~~~~~~~~~~~~~~~~~~~~ - -.. autoclass:: pytorch_lightning.metrics.PrecisionRecallCurve - :noindex: - -Recall -~~~~~~ - -.. autoclass:: pytorch_lightning.metrics.Recall - :noindex: - -ROC -~~~ - -.. autoclass:: pytorch_lightning.metrics.ROC - :noindex: - - -StatScores -~~~~~~~~~~ - -.. autoclass:: pytorch_lightning.metrics.StatScores - :noindex: - - -Functional Metrics (Classification) ------------------------------------ - -accuracy [func] -~~~~~~~~~~~~~~~ - -.. autofunction:: pytorch_lightning.metrics.functional.accuracy - :noindex: - - -auc [func] -~~~~~~~~~~ - -.. autofunction:: pytorch_lightning.metrics.functional.auc - :noindex: - - -auroc [func] -~~~~~~~~~~~~ - -.. autofunction:: pytorch_lightning.metrics.functional.auroc - :noindex: - - -average_precision [func] -~~~~~~~~~~~~~~~~~~~~~~~~ - -.. autofunction:: pytorch_lightning.metrics.functional.average_precision - :noindex: - - -confusion_matrix [func] -~~~~~~~~~~~~~~~~~~~~~~~ - -.. 
autofunction:: pytorch_lightning.metrics.functional.confusion_matrix - :noindex: - - -dice_score [func] -~~~~~~~~~~~~~~~~~ - -.. autofunction:: pytorch_lightning.metrics.functional.dice_score - :noindex: - - -f1 [func] -~~~~~~~~~~~~~~~ - -.. autofunction:: pytorch_lightning.metrics.functional.f1 - :noindex: - - -fbeta [func] -~~~~~~~~~~~~~~~~~~ - -.. autofunction:: pytorch_lightning.metrics.functional.fbeta - :noindex: - -hamming_distance [func] -~~~~~~~~~~~~~~~~~~~~~~~ - -.. autofunction:: pytorch_lightning.metrics.functional.hamming_distance - :noindex: - -iou [func] -~~~~~~~~~~ - -.. autofunction:: pytorch_lightning.metrics.functional.iou - :noindex: - - -roc [func] -~~~~~~~~~~~~~~~~~~~~~ - -.. autofunction:: pytorch_lightning.metrics.functional.roc - :noindex: - - -precision [func] -~~~~~~~~~~~~~~~~ - -.. autofunction:: pytorch_lightning.metrics.functional.precision - :noindex: - - -precision_recall [func] -~~~~~~~~~~~~~~~~~~~~~~~ - -.. autofunction:: pytorch_lightning.metrics.functional.precision_recall - :noindex: - - -precision_recall_curve [func] -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. autofunction:: pytorch_lightning.metrics.functional.precision_recall_curve - :noindex: - - -recall [func] -~~~~~~~~~~~~~ - -.. autofunction:: pytorch_lightning.metrics.functional.recall - :noindex: - -select_topk [func] -~~~~~~~~~~~~~~~~~~~~~ - -.. autofunction:: pytorch_lightning.metrics.utils.select_topk - :noindex: - - -stat_scores [func] -~~~~~~~~~~~~~~~~~~ - -.. autofunction:: pytorch_lightning.metrics.functional.stat_scores - :noindex: - - -stat_scores_multiple_classes [func] -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. autofunction:: pytorch_lightning.metrics.functional.stat_scores_multiple_classes - :noindex: - - -to_categorical [func] -~~~~~~~~~~~~~~~~~~~~~ - -.. autofunction:: pytorch_lightning.metrics.utils.to_categorical - :noindex: - - -to_onehot [func] -~~~~~~~~~~~~~~~~ - -.. autofunction:: pytorch_lightning.metrics.utils.to_onehot - :noindex: - -****************** -Regression Metrics -****************** - -Class Metrics (Regression) --------------------------- - -ExplainedVariance -~~~~~~~~~~~~~~~~~ - -.. autoclass:: pytorch_lightning.metrics.ExplainedVariance - :noindex: - - -MeanAbsoluteError -~~~~~~~~~~~~~~~~~ - -.. autoclass:: pytorch_lightning.metrics.MeanAbsoluteError - :noindex: - - -MeanSquaredError -~~~~~~~~~~~~~~~~ - -.. autoclass:: pytorch_lightning.metrics.MeanSquaredError - :noindex: - - -MeanSquaredLogError -~~~~~~~~~~~~~~~~~~~ - -.. autoclass:: pytorch_lightning.metrics.MeanSquaredLogError - :noindex: - - -PSNR -~~~~ - -.. autoclass:: pytorch_lightning.metrics.PSNR - :noindex: - - -SSIM -~~~~ - -.. autoclass:: pytorch_lightning.metrics.SSIM - :noindex: - - -R2Score -~~~~~~~ - -.. autoclass:: pytorch_lightning.metrics.R2Score - :noindex: - -Functional Metrics (Regression) -------------------------------- - -explained_variance [func] -~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. autofunction:: pytorch_lightning.metrics.functional.explained_variance - :noindex: - - -image_gradients [func] -~~~~~~~~~~~~~~~~~~~~~~ - -.. autofunction:: pytorch_lightning.metrics.functional.image_gradients - :noindex: - - -mean_absolute_error [func] -~~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. autofunction:: pytorch_lightning.metrics.functional.mean_absolute_error - :noindex: - - -mean_squared_error [func] -~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. autofunction:: pytorch_lightning.metrics.functional.mean_squared_error - :noindex: - - -mean_squared_log_error [func] -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. 
autofunction:: pytorch_lightning.metrics.functional.mean_squared_log_error - :noindex: - - -psnr [func] -~~~~~~~~~~~ - -.. autofunction:: pytorch_lightning.metrics.functional.psnr - :noindex: - - -ssim [func] -~~~~~~~~~~~ - -.. autofunction:: pytorch_lightning.metrics.functional.ssim - :noindex: - - -r2score [func] -~~~~~~~~~~~~~~ - -.. autofunction:: pytorch_lightning.metrics.functional.r2score - :noindex: - - -*** -NLP -*** - -bleu_score [func] ------------------ - -.. autofunction:: pytorch_lightning.metrics.functional.bleu_score - :noindex: - -******** -Pairwise -******** - -embedding_similarity [func] ---------------------------- - -.. autofunction:: pytorch_lightning.metrics.functional.embedding_similarity - :noindex: + ``pytorch_lightning.metrics`` is deprecated from v1.3 and will be removed in v1.5. diff --git a/docs/source/starter/introduction_guide.rst b/docs/source/starter/introduction_guide.rst index c65894367a39e..551b8182caa7d 100644 --- a/docs/source/starter/introduction_guide.rst +++ b/docs/source/starter/introduction_guide.rst @@ -882,8 +882,8 @@ Or maybe we have a model that we use to do generation generated_imgs = model(z) -To perform inference at scale, it is possible to use ``trainer.predict`` with LightningModule ``predict`` function -By default, LightningModule ``predict`` calls forward, but it can be overriden to add any processing logic. +To perform inference at scale, it is possible to use ``trainer.predict`` with LightningModule ``predict_step`` function +By default, LightningModule ``predict_step`` calls forward, but it can be overriden to add any processing logic. .. code-block:: python @@ -893,7 +893,7 @@ By default, LightningModule ``predict`` calls forward, but it can be overriden t imgs = self.decoder(z) return imgs - def predict(self, batch, batch_idx: int , dataloader_idx: int = None): + def predict_step(self, batch, batch_idx: int , dataloader_idx: int = None): return self(batch) diff --git a/docs/source/starter/new-project.rst b/docs/source/starter/new-project.rst index f68865f3695c3..7a1164b1bdf3a 100644 --- a/docs/source/starter/new-project.rst +++ b/docs/source/starter/new-project.rst @@ -83,7 +83,7 @@ Step 1: Define LightningModule .. 
testcode:: - class LitAutoEncoder(LightningModule): + class LitAutoEncoder(pl.LightningModule): def __init__(self): super().__init__() diff --git a/pl_examples/__init__.py b/pl_examples/__init__.py index ffd60f9ed71af..150ac309ddceb 100644 --- a/pl_examples/__init__.py +++ b/pl_examples/__init__.py @@ -15,10 +15,10 @@ _DATASETS_PATH = os.path.join(_PACKAGE_ROOT, 'Datasets') _TORCHVISION_AVAILABLE = _module_available("torchvision") -_TORCHVISION_MNIST_AVAILABLE = True +_TORCHVISION_MNIST_AVAILABLE = not bool(os.environ.get("PL_USE_MOCKED_MNIST", False)) _DALI_AVAILABLE = _module_available("nvidia.dali") -if _TORCHVISION_AVAILABLE: +if _TORCHVISION_MNIST_AVAILABLE: try: from torchvision.datasets.mnist import MNIST MNIST(_DATASETS_PATH, download=True) diff --git a/pl_examples/basic_examples/autoencoder.py b/pl_examples/basic_examples/autoencoder.py index a2010a89f4461..6841b8555ef1f 100644 --- a/pl_examples/basic_examples/autoencoder.py +++ b/pl_examples/basic_examples/autoencoder.py @@ -39,17 +39,17 @@ class LitAutoEncoder(pl.LightningModule): ) """ - def __init__(self): + def __init__(self, hidden_dim: int = 64): super().__init__() self.encoder = nn.Sequential( - nn.Linear(28 * 28, 64), + nn.Linear(28 * 28, hidden_dim), nn.ReLU(), - nn.Linear(64, 3), + nn.Linear(hidden_dim, 3), ) self.decoder = nn.Sequential( - nn.Linear(3, 64), + nn.Linear(3, hidden_dim), nn.ReLU(), - nn.Linear(64, 28 * 28), + nn.Linear(hidden_dim, 28 * 28), ) def forward(self, x): @@ -94,7 +94,7 @@ def cli_main(): # ------------ parser = ArgumentParser() parser.add_argument('--batch_size', default=32, type=int) - parser.add_argument('--hidden_dim', type=int, default=128) + parser.add_argument('--hidden_dim', type=int, default=64) parser = pl.Trainer.add_argparse_args(parser) args = parser.parse_args() @@ -112,7 +112,7 @@ def cli_main(): # ------------ # model # ------------ - model = LitAutoEncoder() + model = LitAutoEncoder(args.hidden_dim) # ------------ # training diff --git a/pl_examples/basic_examples/conv_sequential_example.py b/pl_examples/basic_examples/conv_sequential_example.py index 1c35c69d29f37..f3d9469144f50 100644 --- a/pl_examples/basic_examples/conv_sequential_example.py +++ b/pl_examples/basic_examples/conv_sequential_example.py @@ -27,11 +27,11 @@ import torch.nn as nn import torch.nn.functional as F import torchvision +from torchmetrics.functional import accuracy import pytorch_lightning as pl from pl_examples import cli_lightning_logo from pytorch_lightning import Trainer -from pytorch_lightning.metrics.functional import accuracy from pytorch_lightning.plugins import RPCSequentialPlugin from pytorch_lightning.utilities import _BOLTS_AVAILABLE, _FAIRSCALE_PIPE_AVAILABLE diff --git a/pl_examples/basic_examples/profiler_example.py b/pl_examples/basic_examples/profiler_example.py new file mode 100644 index 0000000000000..ca640a96f9588 --- /dev/null +++ b/pl_examples/basic_examples/profiler_example.py @@ -0,0 +1,102 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +This script will generate 2 traces: one for `training_step` and one for `validation_step`. +The traces can be visualized in 2 ways: +* With Chrome: + 1. Open Chrome and copy/paste this url: `chrome://tracing/`. + 2. Once tracing opens, click on `Load` at the top-right and load one of the generated traces. +* With PyTorch Tensorboard Profiler (Instructions are here: https://github.com/pytorch/kineto/tree/master/tb_plugin) + 1. pip install tensorboard torch-tb-profiler + 2. tensorboard --logdir={FOLDER} +""" + +import sys +from argparse import ArgumentParser + +import torch +import torchvision +import torchvision.models as models +import torchvision.transforms as T + +from pl_examples import cli_lightning_logo +from pytorch_lightning import LightningDataModule, LightningModule, Trainer + +DEFAULT_CMD_LINE = ( + "--max_epochs", + "1", + "--limit_train_batches", + "15", + "--limit_val_batches", + "15", + "--profiler", + "pytorch", + "--gpus", + f"{int(torch.cuda.is_available())}", +) + + +class ModelToProfile(LightningModule): + + def __init__(self, model): + super().__init__() + self.model = model + self.criterion = torch.nn.CrossEntropyLoss() + + def training_step(self, batch, batch_idx): + inputs, labels = batch + outputs = self.model(inputs) + loss = self.criterion(outputs, labels) + self.log("train_loss", loss) + return loss + + def validation_step(self, batch, batch_idx): + inputs, labels = batch + outputs = self.model(inputs) + loss = self.criterion(outputs, labels) + self.log("val_loss", loss) + + def configure_optimizers(self): + return torch.optim.SGD(self.parameters(), lr=0.001, momentum=0.9) + + +class CIFAR10DataModule(LightningDataModule): + + transform = T.Compose([T.Resize(256), T.CenterCrop(224), T.ToTensor()]) + + def train_dataloader(self, *args, **kwargs): + trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=self.transform) + return torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True, num_workers=0) + + def val_dataloader(self, *args, **kwargs): + valset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=self.transform) + return torch.utils.data.DataLoader(valset, batch_size=32, shuffle=True, num_workers=0) + + +def cli_main(): + + parser = ArgumentParser() + parser = Trainer.add_argparse_args(parser) + cmd_line = None if len(sys.argv) != 1 else DEFAULT_CMD_LINE + args = parser.parse_args(args=cmd_line) + + model = ModelToProfile(models.resnet50(pretrained=True)) + datamodule = CIFAR10DataModule() + trainer = Trainer(**vars(args)) + trainer.fit(model, datamodule=datamodule) + + +if __name__ == '__main__': + cli_lightning_logo() + cli_main() diff --git a/pl_examples/basic_examples/submit_ddp2_job.sh b/pl_examples/basic_examples/submit_ddp2_job.sh index 6fed6afef0d1c..026589a604c36 100755 --- a/pl_examples/basic_examples/submit_ddp2_job.sh +++ b/pl_examples/basic_examples/submit_ddp2_job.sh @@ -24,4 +24,4 @@ source activate $1 # ------------------------- # run script from above -srun python3 image_classifier.py --accelerator 'ddp2' --gpus 2 --num_nodes 2 +srun python3 simple_image_classifier.py --accelerator 'ddp2' --gpus 2 --num_nodes 2 --max_epochs 5 diff --git a/pl_examples/basic_examples/submit_ddp_job.sh b/pl_examples/basic_examples/submit_ddp_job.sh index 383579c4346b6..b4f5ff0a64d92 100755 --- a/pl_examples/basic_examples/submit_ddp_job.sh +++ b/pl_examples/basic_examples/submit_ddp_job.sh @@ -24,4 +24,4 @@ source activate $1 # ------------------------- # run script from above 
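The defaults baked into ``DEFAULT_CMD_LINE`` above amount to a one-epoch run over 15 train and 15 val batches with the PyTorch profiler enabled. A minimal sketch of the roughly equivalent direct ``Trainer`` call, reusing the ``ModelToProfile`` and ``CIFAR10DataModule`` classes defined in the example script:

.. code-block:: python

    import torch
    import torchvision.models as models

    from pytorch_lightning import Trainer

    # roughly what profiler_example.py does when run with its default arguments
    trainer = Trainer(
        max_epochs=1,
        limit_train_batches=15,
        limit_val_batches=15,
        profiler="pytorch",
        gpus=int(torch.cuda.is_available()),
    )
    trainer.fit(ModelToProfile(models.resnet50(pretrained=True)), datamodule=CIFAR10DataModule())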
-srun python3 image_classifier.py --accelerator 'ddp' --gpus 2 --num_nodes 2 +srun python3 simple_image_classifier.py --accelerator 'ddp' --gpus 2 --num_nodes 2 --max_epochs 5 diff --git a/pl_examples/domain_templates/computer_vision_fine_tuning.py b/pl_examples/domain_templates/computer_vision_fine_tuning.py index 88f4e66605741..4e148a18433a6 100644 --- a/pl_examples/domain_templates/computer_vision_fine_tuning.py +++ b/pl_examples/domain_templates/computer_vision_fine_tuning.py @@ -49,6 +49,7 @@ from torch.optim.lr_scheduler import MultiStepLR from torch.optim.optimizer import Optimizer from torch.utils.data import DataLoader +from torchmetrics import Accuracy from torchvision import models, transforms from torchvision.datasets import ImageFolder from torchvision.datasets.utils import download_and_extract_archive @@ -188,8 +189,8 @@ def __init__( self.__build_model() - self.train_acc = pl.metrics.Accuracy() - self.valid_acc = pl.metrics.Accuracy() + self.train_acc = Accuracy() + self.valid_acc = Accuracy() self.save_hyperparameters() def __build_model(self): diff --git a/pytorch_lightning/__init__.py b/pytorch_lightning/__init__.py index 569078c994ba4..b9660475bf2f7 100644 --- a/pytorch_lightning/__init__.py +++ b/pytorch_lightning/__init__.py @@ -2,42 +2,17 @@ import logging import os -import sys -import time -_this_year = time.strftime("%Y") -__version__ = '1.3.0dev' -__author__ = 'William Falcon et al.' -__author_email__ = 'waf2107@columbia.edu' -__license__ = 'Apache-2.0' -__copyright__ = f'Copyright (c) 2018-{_this_year}, {__author__}.' -__homepage__ = 'https://github.com/PyTorchLightning/pytorch-lightning' -# this has to be simple string, see: https://github.com/pypa/twine/issues/522 -__docs__ = ( - "PyTorch Lightning is the lightweight PyTorch wrapper for ML researchers." - " Scale your models. Write less boilerplate." +from pytorch_lightning.info import ( # noqa: F401 + __author__, + __author_email__, + __copyright__, + __docs__, + __homepage__, + __license__, + __version__, ) -__long_docs__ = """ -Lightning is a way to organize your PyTorch code to decouple the science code from the engineering. - It's more of a style-guide than a framework. -In Lightning, you organize your code into 3 distinct categories: - -1. Research code (goes in the LightningModule). -2. Engineering code (you delete, and is handled by the Trainer). -3. Non-essential research code (logging, etc. this goes in Callbacks). - -Although your research/production project might start simple, once you add things like GPU AND TPU training, - 16-bit precision, etc, you end up spending more time engineering than researching. - Lightning automates AND rigorously tests those parts for you. - -Overall, Lightning guarantees rigorously tested, correct, modern best practices for the automated parts. - -Documentation -------------- -- https://pytorch-lightning.readthedocs.io/en/latest -- https://pytorch-lightning.readthedocs.io/en/stable -""" _root_logger = logging.getLogger() _logger = logging.getLogger(__name__) _logger.setLevel(logging.INFO) @@ -50,32 +25,20 @@ _PACKAGE_ROOT = os.path.dirname(__file__) _PROJECT_ROOT = os.path.dirname(_PACKAGE_ROOT) -try: - # This variable is injected in the __builtins__ by the build - # process. 
It used to enable importing subpackages of skimage when - # the binaries are not built - _ = None if __LIGHTNING_SETUP__ else None -except NameError: - __LIGHTNING_SETUP__: bool = False - -if __LIGHTNING_SETUP__: # pragma: no-cover - sys.stdout.write(f'Partial import of `{__name__}` during the build process.\n') # pragma: no-cover - # We are not importing the rest of the lightning during the build process, as it may not be compiled yet -else: - from pytorch_lightning import metrics - from pytorch_lightning.callbacks import Callback - from pytorch_lightning.core import LightningDataModule, LightningModule - from pytorch_lightning.trainer import Trainer - from pytorch_lightning.utilities.seed import seed_everything - - __all__ = [ - 'Trainer', - 'LightningDataModule', - 'LightningModule', - 'Callback', - 'seed_everything', - 'metrics', - ] +from pytorch_lightning import metrics # noqa: E402 +from pytorch_lightning.callbacks import Callback # noqa: E402 +from pytorch_lightning.core import LightningDataModule, LightningModule # noqa: E402 +from pytorch_lightning.trainer import Trainer # noqa: E402 +from pytorch_lightning.utilities.seed import seed_everything # noqa: E402 + +__all__ = [ + 'Trainer', + 'LightningDataModule', + 'LightningModule', + 'Callback', + 'seed_everything', + 'metrics', +] # for compatibility with namespace packages __import__('pkg_resources').declare_namespace(__name__) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 06191dcff6d80..1dcd541ca0610 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -21,8 +21,8 @@ from pytorch_lightning.plugins.precision import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin, PrecisionPlugin from pytorch_lightning.plugins.training_type import TrainingTypePlugin from pytorch_lightning.trainer.states import TrainerState +from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.apply_func import move_data_to_device -from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available from pytorch_lightning.utilities.enums import AMPType, LightningEnum if TYPE_CHECKING: @@ -66,17 +66,29 @@ def __init__( self.lr_schedulers: Sequence = [] self.optimizer_frequencies: Sequence = [] - def setup(self, trainer: 'Trainer', model: LightningModule) -> None: + def connect(self, model: LightningModule) -> None: + """Transfers ownership of the model to this plugin""" + self.training_type_plugin.connect(model) + + def setup_environment(self) -> None: + """ + Setup any processes or distributed connections. + This is called before the LightningModule/DataModule setup hook + which allows the user to access the accelerator environment before setup is complete. """ - Connects the plugins to the training process, creates optimizers + self.training_type_plugin.setup_environment() + def setup(self, trainer: 'Trainer', model: LightningModule) -> None: + """ + Setup plugins for the trainer fit and creates optimizers. 
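Because the package metadata now lives in ``pytorch_lightning/info.py`` and is re-exported from the top-level package, existing attribute lookups keep working. A quick sketch:

.. code-block:: python

    import pytorch_lightning as pl

    # the dunder metadata moved to pytorch_lightning/info.py but is still re-exported
    print(pl.__version__)  # e.g. '1.3.0dev'
    print(pl.__license__)  # 'Apache-2.0'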
Args: - trainer: the trainer instance to connect to - model: the model to train + trainer: the trainer instance + model: the LightningModule """ - self.connect_training_type_plugin(self.training_type_plugin, model) - self.setup_optimizers(trainer) - self.connect_precision_plugin(self.precision_plugin) + self.setup_training_type_plugin(self.training_type_plugin, model) + if not self.training_type_plugin.setup_optimizers_in_pre_dispatch: + self.setup_optimizers(trainer) + self.setup_precision_plugin(self.precision_plugin) def start_training(self, trainer: 'Trainer') -> None: self.training_type_plugin.start_training(trainer) @@ -87,12 +99,14 @@ def start_evaluating(self, trainer: 'Trainer') -> None: def start_predicting(self, trainer: 'Trainer') -> None: self.training_type_plugin.start_predicting(trainer) - def pre_dispatch(self) -> None: + def pre_dispatch(self, trainer: 'Trainer') -> None: """Hook to do something before the training/evaluation/prediction starts.""" self.training_type_plugin.pre_dispatch() + if self.training_type_plugin.setup_optimizers_in_pre_dispatch: + self.setup_optimizers(trainer) self.precision_plugin.pre_dispatch() - def post_dispatch(self) -> None: + def post_dispatch(self, trainer: 'Trainer') -> None: """Hook to do something before the training/evaluation/prediction starts.""" self.training_type_plugin.post_dispatch() self.precision_plugin.post_dispatch() @@ -206,7 +220,7 @@ def test_step(self, args: List[Union[Any, int]]) -> _STEP_OUTPUT_TYPE: with self.precision_plugin.test_step_context(), self.training_type_plugin.test_step_context(): return self.training_type_plugin.test_step(*args) - def predict(self, args: List[Union[Any, int]]) -> _STEP_OUTPUT_TYPE: + def predict_step(self, args: List[Union[Any, int]]) -> _STEP_OUTPUT_TYPE: """The actual predict step. Args: @@ -222,7 +236,7 @@ def predict(self, args: List[Union[Any, int]]) -> _STEP_OUTPUT_TYPE: args[0] = batch with self.precision_plugin.predict_context(), self.training_type_plugin.predict_context(): - return self.training_type_plugin.predict(*args) + return self.training_type_plugin.predict_step(*args) def training_step_end(self, output: _STEP_OUTPUT_TYPE) -> _STEP_OUTPUT_TYPE: """A hook to do something at the end of the training step @@ -333,14 +347,11 @@ def setup_optimizers(self, trainer: 'Trainer') -> None: self.lr_schedulers = lr_schedulers self.optimizer_frequencies = optimizer_frequencies - def connect_training_type_plugin(self, plugin: TrainingTypePlugin, model: LightningModule) -> None: - """Attaches the training type plugin to the accelerator. 
- Also transfers ownership of the model to this plugin - - """ - plugin.connect(model) + def setup_training_type_plugin(self, plugin: TrainingTypePlugin, model: LightningModule) -> None: + """Attaches the training type plugin to the accelerator.""" + plugin.setup(model) - def connect_precision_plugin(self, plugin: PrecisionPlugin) -> None: + def setup_precision_plugin(self, plugin: PrecisionPlugin) -> None: """Attaches the precision plugin to the accelerator""" model, optimizers, schedulers = plugin.connect(self.model, self.optimizers, self.lr_schedulers) self.model = model @@ -349,7 +360,12 @@ def connect_precision_plugin(self, plugin: PrecisionPlugin) -> None: def to_device(self, batch: Any) -> Any: """Pushes the batch to the root device""" - return self.batch_to_device(batch, self.root_device) + # Todo (tchaton) Better fix + is_dict = isinstance(batch, dict) + if is_dict: + batch = [batch] + batch = self.batch_to_device(batch, self.root_device) + return batch[0] if is_dict else batch @property def amp_backend(self) -> Optional[LightningEnum]: @@ -405,7 +421,7 @@ def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_gra Return: A tensor of shape (world_size, batch, ...) """ - return all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads) + return self.training_type_plugin.all_gather(tensor, group=group, sync_grads=sync_grads) def process_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]: """Wraps the dataloader if necessary @@ -422,3 +438,31 @@ def results(self) -> Any: In distributed training, we make sure to transfer the results to the appropriate master process. """ return self.training_type_plugin.results + + # todo: remove in v1.5 + def connect_training_type_plugin(self, plugin: TrainingTypePlugin, model: LightningModule) -> None: + """ + Attaches the training type plugin to the accelerator. + Also transfers ownership of the model to this plugin + + .. deprecated::v1.3 + Will be removed in v1.5.0. + """ + rank_zero_warn( + 'Accelerator method `connect_training_type_plugin` was deprecated in v1.3.' + ' It will be removed in v1.5.' + ) + self.setup_training_type_plugin(plugin, model) + + # todo: remove in v1.5 + def connect_precision_plugin(self, plugin: PrecisionPlugin) -> None: + """Attaches the precision plugin to the accelerator + + .. deprecated::v1.3 + Will be removed in v1.5.0. + """ + rank_zero_warn( + 'Accelerator method `connect_precision_plugin` was deprecated in v1.3.' + ' It will be removed in v1.5.' + ) + self.setup_precision_plugin(plugin) diff --git a/pytorch_lightning/accelerators/cpu.py b/pytorch_lightning/accelerators/cpu.py index f428951b16932..22ea8f1e1b7aa 100644 --- a/pytorch_lightning/accelerators/cpu.py +++ b/pytorch_lightning/accelerators/cpu.py @@ -1,3 +1,16 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
from typing import TYPE_CHECKING from pytorch_lightning.accelerators.accelerator import Accelerator diff --git a/pytorch_lightning/accelerators/gpu.py b/pytorch_lightning/accelerators/gpu.py index af9ce25f902b3..c23960e4fd9e3 100644 --- a/pytorch_lightning/accelerators/gpu.py +++ b/pytorch_lightning/accelerators/gpu.py @@ -1,6 +1,19 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import logging import os -from typing import TYPE_CHECKING, Any +from typing import Any, TYPE_CHECKING import torch diff --git a/pytorch_lightning/accelerators/tpu.py b/pytorch_lightning/accelerators/tpu.py index 57e65a62f6783..35a475e3e790d 100644 --- a/pytorch_lightning/accelerators/tpu.py +++ b/pytorch_lightning/accelerators/tpu.py @@ -1,4 +1,17 @@ -from typing import Any, Callable, Optional, TYPE_CHECKING +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Callable, Optional, TYPE_CHECKING, Union import torch from torch.optim import Optimizer @@ -12,6 +25,9 @@ if _XLA_AVAILABLE: import torch_xla.core.xla_model as xm + from torch_xla._patched_functions import clip_grad_norm_ + + xla_clip_grad_norm_ = clip_grad_norm_ if TYPE_CHECKING: from pytorch_lightning.core.lightning import LightningModule @@ -46,12 +62,25 @@ def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_gra Function to gather a tensor from several distributed processes Args: tensor: tensor of shape (batch, ...) - group: the process group to gather results from. Defaults to all processes (world) - sync_grads: flag that allows users to synchronize gradients for all_gather op + group: not available with TPUs + sync_grads: not available with TPUs Return: A tensor of shape (world_size, batch, ...) 
""" # todo: Add support for backward with all_gather - if torch.distributed.is_initialized(): - return xm.all_gather(tensor, group=group, sync_grads=sync_grads) + if isinstance(self.training_type_plugin, TPUSpawnPlugin) and self.training_type_plugin.is_distributed: + return xm.all_gather(tensor).view(-1, *tensor.shape) return tensor + + def clip_gradients(self, optimizer: Optimizer, clip_val: Union[float, int], norm_type: float = 2.0): + + model = self.lightning_module + parameters = model.parameters() + + grad_clip_val = float(clip_val) + if grad_clip_val <= 0: + return + + max_norm = grad_clip_val + + xla_clip_grad_norm_(parameters, max_norm, norm_type) diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py index 0ba1fd4ff7785..7757902bd3baf 100644 --- a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -17,7 +17,7 @@ """ import abc -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional from pytorch_lightning.core.lightning import LightningModule @@ -81,7 +81,7 @@ def on_train_epoch_start(self, trainer, pl_module: LightningModule) -> None: """Called when the train epoch begins.""" pass - def on_train_epoch_end(self, trainer, pl_module: LightningModule, outputs: Any) -> None: + def on_train_epoch_end(self, trainer, pl_module: LightningModule, outputs: List[Any]) -> None: """Called when the train epoch ends.""" pass @@ -89,7 +89,7 @@ def on_validation_epoch_start(self, trainer, pl_module: LightningModule) -> None """Called when the val epoch begins.""" pass - def on_validation_epoch_end(self, trainer, pl_module: LightningModule) -> None: + def on_validation_epoch_end(self, trainer, pl_module: LightningModule, outputs: List[Any]) -> None: """Called when the val epoch ends.""" pass @@ -97,16 +97,16 @@ def on_test_epoch_start(self, trainer, pl_module: LightningModule) -> None: """Called when the test epoch begins.""" pass - def on_test_epoch_end(self, trainer, pl_module: LightningModule) -> None: + def on_test_epoch_end(self, trainer, pl_module: LightningModule, outputs: List[Any]) -> None: """Called when the test epoch ends.""" pass def on_epoch_start(self, trainer, pl_module: LightningModule) -> None: - """Called when the epoch begins.""" + """Called when either of train/val/test epoch begins.""" pass def on_epoch_end(self, trainer, pl_module: LightningModule) -> None: - """Called when the epoch ends.""" + """Called when either of train/val/test epoch ends.""" pass def on_batch_start(self, trainer, pl_module: LightningModule) -> None: diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index 38ccce648502a..4448de8e4834b 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -172,4 +172,4 @@ def _run_early_stopping_check(self, trainer): trainer.should_stop = True # stop every ddp process if any world process decides to stop - trainer.should_stop = trainer.training_type_plugin.reduce_early_stopping_decision(trainer.should_stop) + trainer.should_stop = trainer.training_type_plugin.reduce_boolean_decision(trainer.should_stop) diff --git a/pytorch_lightning/callbacks/gradient_accumulation_scheduler.py b/pytorch_lightning/callbacks/gradient_accumulation_scheduler.py index 0af7d61bf5dec..b1885087f4da0 100644 --- a/pytorch_lightning/callbacks/gradient_accumulation_scheduler.py +++ b/pytorch_lightning/callbacks/gradient_accumulation_scheduler.py @@ -74,7 +74,7 @@ def __init__(self, scheduling: 
Dict[int, int]): def going_to_accumulate_grad_batches(self): return any([v > 1 for v in self.scheduling.values()]) - def on_epoch_start(self, trainer, pl_module): + def on_train_epoch_start(self, trainer, pl_module): epoch = trainer.current_epoch for i in reversed(range(len(self.epochs))): if epoch >= self.epochs[i]: diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index bf6c799ef728a..2781586730151 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -30,7 +30,7 @@ import yaml from pytorch_lightning.callbacks.base import Callback -from pytorch_lightning.utilities import rank_zero_info, rank_zero_only, rank_zero_warn +from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_info, rank_zero_only, rank_zero_warn from pytorch_lightning.utilities.cloud_io import get_filesystem from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.warnings import WarningCache @@ -258,9 +258,9 @@ def save_checkpoint(self, trainer, unused: Optional = None): to handle correct behaviour in distributed training, i.e., saving only on rank 0. """ if unused is not None: - rank_zero_warn( + rank_zero_deprecation( "`ModelCheckpoint.save_checkpoint` signature has changed in v1.3. The `pl_module` parameter" - " has been removed. Support for the old signature will be removed in v1.5", DeprecationWarning + " has been removed. Support for the old signature will be removed in v1.5" ) global_step = trainer.global_step @@ -371,9 +371,9 @@ def __init_triggers( # period takes precedence over every_n_val_epochs for backwards compatibility if period is not None: - rank_zero_warn( + rank_zero_deprecation( 'Argument `period` in `ModelCheckpoint` is deprecated in v1.3 and will be removed in v1.5.' - ' Please use `every_n_val_epochs` instead.', DeprecationWarning + ' Please use `every_n_val_epochs` instead.' ) self._every_n_val_epochs = period @@ -381,17 +381,17 @@ def __init_triggers( @property def period(self) -> Optional[int]: - rank_zero_warn( + rank_zero_deprecation( 'Property `period` in `ModelCheckpoint` is deprecated in v1.3 and will be removed in v1.5.' - ' Please use `every_n_val_epochs` instead.', DeprecationWarning + ' Please use `every_n_val_epochs` instead.' ) return self._period @period.setter def period(self, value: Optional[int]) -> None: - rank_zero_warn( + rank_zero_deprecation( 'Property `period` in `ModelCheckpoint` is deprecated in v1.3 and will be removed in v1.5.' - ' Please use `every_n_val_epochs` instead.', DeprecationWarning + ' Please use `every_n_val_epochs` instead.' ) self._period = value @@ -424,7 +424,7 @@ def _do_save(self, trainer, filepath: str): else: raise ValueError(".save_function() not set") - def check_monitor_top_k(self, current: torch.Tensor) -> bool: + def check_monitor_top_k(self, trainer, current: Optional[torch.Tensor] = None) -> bool: if current is None: return False @@ -444,7 +444,12 @@ def check_monitor_top_k(self, current: torch.Tensor) -> bool: current = torch.tensor(current) monitor_op = {"min": torch.lt, "max": torch.gt}[self.mode] - return monitor_op(current, self.best_k_models[self.kth_best_model_path]).item() + should_update_best_and_save = monitor_op(current, self.best_k_models[self.kth_best_model_path]) + + # If using multiple devices, make sure all processes are unanimous on the decision. 
+ should_update_best_and_save = trainer.training_type_plugin.reduce_boolean_decision(should_update_best_and_save) + + return should_update_best_and_save @classmethod def _format_checkpoint_name( @@ -638,15 +643,7 @@ def _save_top_k_checkpoint(self, trainer, monitor_candidates: Dict[str, Any]): epoch = monitor_candidates.get("epoch") step = monitor_candidates.get("step") - # when `val_loss` is being logged and no ModelCheckpoint is being provided - # `val_loss` will be selected for monitor and need to be reduced to - # prevent processes divergence - # TODO: Move this logic to logger_connector. This also needs to be fixed for any - # other monitor logged value which aren't produced from a Metric. - if self.monitor == "val_loss": - current = trainer.training_type_plugin.reduce(current, reduce_op="mean") - - if self.check_monitor_top_k(current): + if self.check_monitor_top_k(trainer, current): self._update_best_and_save(current, epoch, step, trainer, monitor_candidates) elif self.verbose: rank_zero_info(f"Epoch {epoch:d}, step {step:d}: {self.monitor} was not in top {self.save_top_k}") @@ -731,5 +728,4 @@ def file_exists(self, filepath: Union[str, Path], trainer) -> bool: the internal state to diverge between ranks. """ exists = self._fs.exists(filepath) - exists = trainer.training_type_plugin.broadcast(exists) - return exists + return trainer.training_type_plugin.broadcast(exists) diff --git a/pytorch_lightning/callbacks/progress.py b/pytorch_lightning/callbacks/progress.py index 74e57e2b5642e..7dc4202530d04 100644 --- a/pytorch_lightning/callbacks/progress.py +++ b/pytorch_lightning/callbacks/progress.py @@ -39,8 +39,7 @@ class tqdm(_tqdm): """ - Custom tqdm progressbar where we append 0 to floating points/strings to - prevent the progress bar from flickering + Custom tqdm progressbar where we append 0 to floating points/strings to prevent the progress bar from flickering """ @staticmethod @@ -201,7 +200,7 @@ def on_init_end(self, trainer): def on_train_start(self, trainer, pl_module): self._train_batch_idx = trainer.batch_idx - def on_epoch_start(self, trainer, pl_module): + def on_train_epoch_start(self, trainer, pl_module): self._train_batch_idx = 0 def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): @@ -393,8 +392,8 @@ def on_train_start(self, trainer, pl_module): super().on_train_start(trainer, pl_module) self.main_progress_bar = self.init_train_tqdm() - def on_epoch_start(self, trainer, pl_module): - super().on_epoch_start(trainer, pl_module) + def on_train_epoch_start(self, trainer, pl_module): + super().on_train_epoch_start(trainer, pl_module) total_train_batches = self.total_train_batches total_val_batches = self.total_val_batches if total_train_batches != float('inf'): diff --git a/pytorch_lightning/core/datamodule.py b/pytorch_lightning/core/datamodule.py index 994c259f48964..4178c9eeacd50 100644 --- a/pytorch_lightning/core/datamodule.py +++ b/pytorch_lightning/core/datamodule.py @@ -14,7 +14,6 @@ """LightningDataModule for loading DataLoaders with ease.""" import functools -from abc import abstractmethod from argparse import ArgumentParser, Namespace from typing import Any, List, Mapping, Optional, Sequence, Tuple, Union @@ -44,6 +43,8 @@ def __call__(cls, *args, **kwargs): cls.prepare_data = track_data_hook_calls(rank_zero_only(cls.prepare_data)) # Track setup calls cls.setup = track_data_hook_calls(cls.setup) + # Track teardown calls + cls.teardown = track_data_hook_calls(cls.teardown) # Get instance of LightningDataModule by mocking 
its __init__ via __call__ obj = type.__call__(cls, *args, **kwargs) @@ -52,12 +53,13 @@ def __call__(cls, *args, **kwargs): def track_data_hook_calls(fn): - """A decorator that checks if prepare_data/setup have been called. + """A decorator that checks if prepare_data/setup/teardown has been called. - When ``dm.prepare_data()`` is called, ``dm.has_prepared_data`` gets set to True - When ``dm.setup()``, ``dm.has_setup_{fit,validate,test}`` get set to True - When ``dm.setup(stage)`` is called, where stage is any of ``{fit,validate,test,predict}``. Its corresponding `dm_has_setup_{stage}` attribute gets set to True + - ``dm.teardown()`` and ``dm.teardown(stage)`` act exactly like ``dm.setup`` Args: fn (function): Function that will be tracked to see if it has been called. @@ -71,9 +73,10 @@ def wrapped_fn(*args, **kwargs): # The object instance from which setup or prepare_data was called obj = args[0] + name = fn.__name__ # If calling setup, we check the stage and assign stage-specific bool args - if fn.__name__ == "setup": + if name in ("setup", "teardown"): # Get stage either by grabbing from args or checking kwargs. # If not provided, set call status of 'fit', 'validate', and 'test' to True. @@ -82,11 +85,11 @@ def wrapped_fn(*args, **kwargs): if stage is None: for s in ("fit", "validate", "test"): - setattr(obj, f"_has_setup_{s}", True) + setattr(obj, f"_has_{name}_{s}", True) else: - setattr(obj, f"_has_setup_{stage}", True) + setattr(obj, f"_has_{name}_{stage}", True) - if fn.__name__ == "prepare_data": + elif name == "prepare_data": obj._has_prepared_data = True return fn(*args, **kwargs) @@ -119,14 +122,18 @@ def val_dataloader(self): def test_dataloader(self): test_split = Dataset(...) return DataLoader(test_split) + def teardown(self): + # clean up after fit or test + # called on every process in DDP - A DataModule implements 5 key methods: + A DataModule implements 6 key methods: * **prepare_data** (things to do on 1 GPU/TPU not on every GPU/TPU in distributed mode). * **setup** (things to do on every accelerator in distributed mode). * **train_dataloader** the training dataloader. * **val_dataloader** the val dataloader(s). * **test_dataloader** the test dataloader(s). + * **teardown** (things to do on every accelerator in distributed mode when finished) This allows you to share a full dataset without explaining how to download, @@ -154,11 +161,17 @@ def __init__( # Private attrs to keep track of whether or not data hooks have been called yet self._has_prepared_data = False + self._has_setup_fit = False self._has_setup_validate = False self._has_setup_test = False self._has_setup_predict = False + self._has_teardown_fit = False + self._has_teardown_validate = False + self._has_teardown_test = False + self._has_teardown_predict = False + @property def train_transforms(self): """ @@ -259,13 +272,41 @@ def has_setup_predict(self) -> bool: """ return self._has_setup_predict - @abstractmethod - def prepare_data(self, *args, **kwargs): - pass + @property + def has_teardown_fit(self) -> bool: + """Return bool letting you know if ``datamodule.teardown(stage='fit')`` has been called or not. - @abstractmethod - def setup(self, stage: Optional[str] = None): - pass + Returns: + bool: True ``if datamodule.teardown(stage='fit')`` has been called. False by default. + """ + return self._has_teardown_fit + + @property + def has_teardown_validate(self) -> bool: + """Return bool letting you know if ``datamodule.teardown(stage='validate')`` has been called or not. 
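A minimal sketch of the new ``teardown`` tracking, using a hypothetical ``MyDataModule``; the ``has_teardown_*`` properties are driven by the same decorator that already tracks ``setup``:

.. code-block:: python

    from pytorch_lightning import LightningDataModule


    class MyDataModule(LightningDataModule):

        def setup(self, stage=None):
            self.dataset = ...  # load or split data here

        def teardown(self, stage=None):
            # called on every process once fit/validate/test/predict has finished
            self.dataset = None


    dm = MyDataModule()
    dm.teardown(stage="fit")
    assert dm.has_teardown_fit
    assert not dm.has_teardown_test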
+ + Returns: + bool: True if ``datamodule.teardown(stage='validate')`` has been called. False by default. + """ + return self._has_teardown_validate + + @property + def has_teardown_test(self) -> bool: + """Return bool letting you know if ``datamodule.teardown(stage='test')`` has been called or not. + + Returns: + bool: True if ``datamodule.teardown(stage='test')`` has been called. False by default. + """ + return self._has_teardown_test + + @property + def has_teardown_predict(self) -> bool: + """Return bool letting you know if ``datamodule.teardown(stage='predict')`` has been called or not. + + Returns: + bool: True if ``datamodule.teardown(stage='predict')`` has been called. False by default. + """ + return self._has_teardown_predict @classmethod def add_argparse_args(cls, parent_parser: ArgumentParser, **kwargs) -> ArgumentParser: diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index 1399d1b3c66ba..bf3b0bf605679 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -25,42 +25,6 @@ class ModelHooks: """Hooks to be used in LightningModule.""" - def setup(self, stage: Optional[str] = None) -> None: - """ - Called at the beginning of fit (train + validate), validate, test, predict, or tune. - This is a good hook when you need to build models dynamically or adjust something about them. - This hook is called on every process when using DDP. - - Args: - stage: either ``'fit'``, ``'validate'``, ``'test'``, or ``'predict'`` - - Example:: - - class LitModel(...): - def __init__(self): - self.l1 = None - - def prepare_data(self): - download_data() - tokenize() - - # don't do this - self.something = else - - def setup(stage): - data = Load_data(...) - self.l1 = nn.Linear(28, data.num_classes) - - """ - - def teardown(self, stage: Optional[str] = None) -> None: - """ - Called at the end of fit (train + validate), validate, test, predict, or tune. - - Args: - stage: either ``'fit'``, ``'validate'``, ``'test'``, or ``'predict'`` - """ - def on_fit_start(self) -> None: """ Called at the very beginning of fit. @@ -224,13 +188,13 @@ def on_predict_model_eval(self) -> None: def on_epoch_start(self) -> None: """ - Called in the training loop at the very beginning of the epoch. + Called when either of train/val/test epoch begins. """ # do something when the epoch starts def on_epoch_end(self) -> None: """ - Called in the training loop at the very end of the epoch. + Called when either of train/val/test epoch ends. """ # do something when the epoch ends @@ -240,7 +204,7 @@ def on_train_epoch_start(self) -> None: """ # do something when the epoch starts - def on_train_epoch_end(self, outputs) -> None: + def on_train_epoch_end(self, outputs: List[Any]) -> None: """ Called in the training loop at the very end of the epoch. """ @@ -252,7 +216,7 @@ def on_validation_epoch_start(self) -> None: """ # do something when the epoch starts - def on_validation_epoch_end(self) -> None: + def on_validation_epoch_end(self, outputs: List[Any]) -> None: """ Called in the validation loop at the very end of the epoch. """ @@ -264,7 +228,7 @@ def on_test_epoch_start(self) -> None: """ # do something when the epoch starts - def on_test_epoch_end(self) -> None: + def on_test_epoch_end(self, outputs: List[Any]) -> None: """ Called in the test loop at the very end of the epoch. """ @@ -282,6 +246,18 @@ def on_test_end(self) -> None: """ # do something at the end of testing + def on_predict_start(self) -> None: + """ + Called at the beginning of predicting. 
+ """ + # do something at the start of predicting + + def on_predict_end(self) -> None: + """ + Called at the end of predicting. + """ + # do something at the end of predicting + def on_before_zero_grad(self, optimizer: Optimizer) -> None: """ Called after optimizer.step() and before optimizer.zero_grad(). @@ -383,6 +359,42 @@ def prepare_data(self): model.test_dataloader() """ + def setup(self, stage: Optional[str] = None) -> None: + """ + Called at the beginning of fit (train + validate), validate, test, predict, or tune. + This is a good hook when you need to build models dynamically or adjust something about them. + This hook is called on every process when using DDP. + + Args: + stage: either ``'fit'``, ``'validate'``, ``'test'``, or ``'predict'`` + + Example:: + + class LitModel(...): + def __init__(self): + self.l1 = None + + def prepare_data(self): + download_data() + tokenize() + + # don't do this + self.something = else + + def setup(stage): + data = Load_data(...) + self.l1 = nn.Linear(28, data.num_classes) + + """ + + def teardown(self, stage: Optional[str] = None) -> None: + """ + Called at the end of fit (train + validate), validate, test, predict, or tune. + + Args: + stage: either ``'fit'``, ``'validate'``, ``'test'``, or ``'predict'`` + """ + def train_dataloader(self) -> Any: """ Implement one or more PyTorch DataLoaders for training. @@ -594,6 +606,18 @@ def predict_dataloader(self) -> Union[DataLoader, List[DataLoader]]: will have an argument ``dataloader_idx`` which matches the order here. """ + def on_train_dataloader(self) -> None: + """Called before requesting the train dataloader.""" + + def on_val_dataloader(self) -> None: + """Called before requesting the val dataloader.""" + + def on_test_dataloader(self) -> None: + """Called before requesting the test dataloader.""" + + def on_predict_dataloader(self) -> None: + """Called before requesting the predict dataloader.""" + def transfer_batch_to_device(self, batch: Any, device: Optional[torch.device] = None) -> Any: """ Override this hook if your :class:`~torch.utils.data.DataLoader` returns tensors diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 4c839f3a6c906..7efe88515b37e 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -38,7 +38,7 @@ from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, ModelIO, PRIMITIVE_TYPES from pytorch_lightning.core.step_result import Result -from pytorch_lightning.utilities import rank_zero_warn +from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -105,6 +105,7 @@ def __init__(self, *args, **kwargs): self._current_hook_fx_name = None self._current_dataloader_idx = None self._automatic_optimization: bool = True + self._param_requires_grad_state = dict() def optimizers(self, use_pl_optimizer: bool = True) -> Union[Optimizer, List[Optimizer], List[LightningOptimizer]]: if use_pl_optimizer: @@ -719,10 +720,13 @@ def validation_step(self, *args, **kwargs): .. 
code-block:: python # pseudocode of order - out = validation_step() - if defined('validation_step_end'): - out = validation_step_end(out) - out = validation_epoch_end(out) + val_outs = [] + for val_batch in val_data: + out = validation_step(val_batch) + if defined('validation_step_end'): + out = validation_step_end(out) + val_outs.append(out) + val_outs = validation_epoch_end(val_outs) .. code-block:: python @@ -1053,7 +1057,7 @@ def test_epoch_end(self, outputs): self.log('final_metric', final_value) """ - def predict(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int] = None): + def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int] = None): """ Use this function with trainer.predict(...). Override if you need to add any processing logic. """ @@ -1225,9 +1229,8 @@ def training_step(...): opt_a.step() """ if optimizer is not None: - rank_zero_warn( - "`optimizer` argument to `manual_backward` is deprecated in v1.2 and will be removed in v1.4", - DeprecationWarning + rank_zero_deprecation( + "`optimizer` argument to `manual_backward` is deprecated in v1.2 and will be removed in v1.4" ) # make sure we're using manual opt @@ -1311,7 +1314,7 @@ def untoggle_optimizer(self, optimizer_idx: int): if param in self._param_requires_grad_state: param.requires_grad = self._param_requires_grad_state[param] # save memory - del self._param_requires_grad_state + self._param_requires_grad_state = dict() def optimizer_step( self, diff --git a/pytorch_lightning/core/memory.py b/pytorch_lightning/core/memory.py index afb64535d1470..a3eab728f8ea8 100644 --- a/pytorch_lightning/core/memory.py +++ b/pytorch_lightning/core/memory.py @@ -16,7 +16,7 @@ import shutil import subprocess from collections import OrderedDict -from typing import Any, Dict, List, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np import torch @@ -71,14 +71,15 @@ def __init__(self, module: nn.Module): def __del__(self): self.detach_hook() - def _register_hook(self) -> RemovableHandle: + def _register_hook(self) -> Optional[RemovableHandle]: """ Registers a hook on the module that computes the input- and output size(s) on the first forward pass. If the hook is called, it will remove itself from the from the module, meaning that recursive models will only record their input- and output shapes once. + Registering hooks on :class:`~torch.jit.ScriptModule` is not supported. Return: - A handle for the installed hook. + A handle for the installed hook, or ``None`` if registering the hook is not possible. 
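With the rename from ``predict`` to ``predict_step``, inference code overrides the new hook and runs it through ``trainer.predict``. A minimal sketch, assuming a ``trainer`` and a ``predict_loader`` already exist in the surrounding code:

.. code-block:: python

    import torch
    from torch import nn

    from pytorch_lightning import LightningModule


    class LitModel(LightningModule):

        def __init__(self):
            super().__init__()
            self.layer = nn.Linear(32, 2)

        def forward(self, x):
            return self.layer(x)

        def predict_step(self, batch, batch_idx, dataloader_idx=None):
            # defaults to calling forward; override to add any pre/post-processing
            return torch.softmax(self(batch), dim=-1)


    # `trainer` and `predict_loader` are assumed to be defined elsewhere
    predictions = trainer.predict(LitModel(), dataloaders=predict_loader)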
""" def hook(module, inp, out): @@ -88,7 +89,10 @@ def hook(module, inp, out): self._out_size = parse_batch_shape(out) self._hook_handle.remove() - return self._module.register_forward_hook(hook) + handle = None + if not isinstance(self._module, torch.jit.ScriptModule): + handle = self._module.register_forward_hook(hook) + return handle def detach_hook(self): """ diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py index f8d7a2ffe3a23..3961586f4946a 100644 --- a/pytorch_lightning/core/step_result.py +++ b/pytorch_lightning/core/step_result.py @@ -20,8 +20,8 @@ import torch from torch import Tensor +from torchmetrics import Metric -from pytorch_lightning.metrics import Metric from pytorch_lightning.utilities.distributed import sync_ddp_if_available diff --git a/pytorch_lightning/distributed/dist.py b/pytorch_lightning/distributed/dist.py index 5da7dfa86084d..37ac5d8b13462 100644 --- a/pytorch_lightning/distributed/dist.py +++ b/pytorch_lightning/distributed/dist.py @@ -11,18 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import io from typing import Any -import torch -from torch import distributed as torch_distrib - -from pytorch_lightning.utilities import _GROUP_AVAILABLE - -WORLD = None -if _GROUP_AVAILABLE: - from torch.distributed import group - WORLD = group.WORLD +from pytorch_lightning.overrides.torch_distributed import broadcast_object_list +from pytorch_lightning.utilities.distributed import group as _group class LightningDistributed: @@ -31,32 +23,13 @@ def __init__(self, rank=None, device=None): self.rank = rank self.device = device - def broadcast(self, obj: Any, group=WORLD): - if self.rank == 0: - self._emit(obj, group) - else: - obj = self._receive(group) - return obj - - def _broadcast(self, tensor, src=0, group=WORLD): - if group is None: - return torch_distrib.broadcast(tensor, src=src) - return torch_distrib.broadcast(tensor, src=0, group=group) - - def _emit(self, obj: Any, group=WORLD): - buffer = io.BytesIO() - torch.save(obj, buffer) - data = bytearray(buffer.getbuffer()) - length_tensor = torch.tensor([len(data)]).long().to(self.device) - self._broadcast(length_tensor, src=0, group=group) - data_tensor = torch.ByteTensor(data).to(self.device) - self._broadcast(data_tensor, src=0, group=group) - - def _receive(self, group=WORLD): - length_tensor = torch.tensor([0]).long().to(self.device) - self._broadcast(length_tensor, src=0, group=group) - data_tensor = torch.empty([length_tensor.item()], dtype=torch.uint8).to(self.device) - self._broadcast(data_tensor, src=0, group=group) - buffer = io.BytesIO(data_tensor.cpu().numpy()) - obj = torch.load(buffer) - return obj + def broadcast(self, obj: Any, group=_group.WORLD): + # always wrap into a list so list can be brodcasted. + obj = [obj] + + if self.rank != 0: + obj = [None] * len(obj) + + broadcast_object_list(obj, 0, group=group or _group.WORLD) + + return obj[0] diff --git a/pytorch_lightning/info.py b/pytorch_lightning/info.py new file mode 100644 index 0000000000000..b00d1946424e7 --- /dev/null +++ b/pytorch_lightning/info.py @@ -0,0 +1,36 @@ +import time + +_this_year = time.strftime("%Y") +__version__ = '1.3.0dev' +__author__ = 'William Falcon et al.' +__author_email__ = 'waf2107@columbia.edu' +__license__ = 'Apache-2.0' +__copyright__ = f'Copyright (c) 2018-{_this_year}, {__author__}.' 
+__homepage__ = 'https://github.com/PyTorchLightning/pytorch-lightning' +__docs_url__ = "https://pytorch-lightning.readthedocs.io/en/stable/" +# this has to be simple string, see: https://github.com/pypa/twine/issues/522 +__docs__ = ( + "PyTorch Lightning is the lightweight PyTorch wrapper for ML researchers." + " Scale your models. Write less boilerplate." +) +__long_docs__ = """ +Lightning is a way to organize your PyTorch code to decouple the science code from the engineering. + It's more of a style-guide than a framework. + +In Lightning, you organize your code into 3 distinct categories: + +1. Research code (goes in the LightningModule). +2. Engineering code (you delete, and is handled by the Trainer). +3. Non-essential research code (logging, etc. this goes in Callbacks). + +Although your research/production project might start simple, once you add things like GPU AND TPU training, + 16-bit precision, etc, you end up spending more time engineering than researching. + Lightning automates AND rigorously tests those parts for you. + +Overall, Lightning guarantees rigorously tested, correct, modern best practices for the automated parts. + +Documentation +------------- +- https://pytorch-lightning.readthedocs.io/en/latest +- https://pytorch-lightning.readthedocs.io/en/stable +""" diff --git a/pytorch_lightning/loggers/mlflow.py b/pytorch_lightning/loggers/mlflow.py index 4903febc1f633..82412c79d3f1f 100644 --- a/pytorch_lightning/loggers/mlflow.py +++ b/pytorch_lightning/loggers/mlflow.py @@ -79,7 +79,9 @@ def any_lightning_module_function_or_hook(self): Defaults to `./mlflow` if `tracking_uri` is not provided. Has no effect if `tracking_uri` is provided. prefix: A string to put at the beginning of metric keys. - figure_file_extension: File extension with which matplotlib saves figure + artifact_location: The location to store run artifacts. If not provided, the server picks an appropriate + default. + figure_file_extension: File extension with which matplotlib saves figures. Raises: ImportError: @@ -95,6 +97,7 @@ def __init__( tags: Optional[Dict[str, Any]] = None, save_dir: Optional[str] = './mlruns', prefix: str = '', + artifact_location: Optional[str] = None, figure_file_extension='.png', ): if mlflow is None: @@ -112,9 +115,11 @@ def __init__( self._run_id = None self.tags = tags self._prefix = prefix - self._mlflow_client = MlflowClient(tracking_uri) + self._artifact_location = artifact_location self._figure_file_extension = figure_file_extension + self._mlflow_client = MlflowClient(tracking_uri) + @property @rank_zero_experiment def experiment(self) -> MlflowClient: @@ -133,7 +138,10 @@ def experiment(self) -> MlflowClient: self._experiment_id = expt.experiment_id else: log.warning(f'Experiment with name {self._experiment_name} not found. Creating it.') - self._experiment_id = self._mlflow_client.create_experiment(name=self._experiment_name) + self._experiment_id = self._mlflow_client.create_experiment( + name=self._experiment_name, + artifact_location=self._artifact_location, + ) if self._run_id is None: run = self._mlflow_client.create_run(experiment_id=self._experiment_id, tags=self.tags) diff --git a/pytorch_lightning/metrics/__init__.py b/pytorch_lightning/metrics/__init__.py index a5a337f2cba9d..9b27fdf0cb253 100644 --- a/pytorch_lightning/metrics/__init__.py +++ b/pytorch_lightning/metrics/__init__.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. + from pytorch_lightning.metrics.classification import ( # noqa: F401 Accuracy, AUC, @@ -37,3 +38,9 @@ R2Score, SSIM, ) +from pytorch_lightning.utilities import rank_zero_deprecation + +rank_zero_deprecation( + "`pytorch_lightning.metrics.*` module has been renamed to `torchmetrics.*` and split off to its own package" + " (https://github.com/PyTorchLightning/metrics) since v1.3 and will be removed in v1.5" +) diff --git a/pytorch_lightning/metrics/classification/accuracy.py b/pytorch_lightning/metrics/classification/accuracy.py index 9d97cbec1a387..1a9febe0c831c 100644 --- a/pytorch_lightning/metrics/classification/accuracy.py +++ b/pytorch_lightning/metrics/classification/accuracy.py @@ -13,94 +13,14 @@ # limitations under the License. from typing import Any, Callable, Optional -import torch +from torchmetrics import Accuracy as _Accuracy -from pytorch_lightning.metrics.functional.accuracy import _accuracy_compute, _accuracy_update -from pytorch_lightning.metrics.metric import Metric +from pytorch_lightning.metrics.utils import deprecated_metrics -class Accuracy(Metric): - r""" - Computes `Accuracy `__: - - .. math:: - \text{Accuracy} = \frac{1}{N}\sum_i^N 1(y_i = \hat{y}_i) - - Where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a - tensor of predictions. - - For multi-class and multi-dimensional multi-class data with probability predictions, the - parameter ``top_k`` generalizes this metric to a Top-K accuracy metric: for each sample the - top-K highest probability items are considered to find the correct label. - - For multi-label and multi-dimensional multi-class inputs, this metric computes the "global" - accuracy by default, which counts all labels or sub-samples separately. This can be - changed to subset accuracy (which requires all labels or sub-samples in the sample to - be correctly predicted) by setting ``subset_accuracy=True``. - - Accepts all input types listed in :ref:`extensions/metrics:input types`. - - Args: - threshold: - Threshold probability value for transforming probability predictions to binary - (0,1) predictions, in the case of binary or multi-label inputs. - top_k: - Number of highest probability predictions considered to find the correct label, relevant - only for (multi-dimensional) multi-class inputs with probability predictions. The - default value (``None``) will be interpreted as 1 for these inputs. - - Should be left at default (``None``) for all other types of inputs. - subset_accuracy: - Whether to compute subset accuracy for multi-label and multi-dimensional - multi-class inputs (has no effect for other input types). - - - For multi-label inputs, if the parameter is set to ``True``, then all labels for - each sample must be correctly predicted for the sample to count as correct. If it - is set to ``False``, then all labels are counted separately - this is equivalent to - flattening inputs beforehand (i.e. ``preds = preds.flatten()`` and same for ``target``). - - - For multi-dimensional multi-class inputs, if the parameter is set to ``True``, then all - sub-sample (on the extra axis) must be correct for the sample to be counted as correct. - If it is set to ``False``, then all sub-samples are counter separately - this is equivalent, - in the case of label predictions, to flattening the inputs beforehand (i.e. - ``preds = preds.flatten()`` and same for ``target``). Note that the ``top_k`` parameter - still applies in both cases, if set. 
- - compute_on_step: - Forward only calls ``update()`` and return ``None`` if this is set to ``False``. - dist_sync_on_step: - Synchronize metric state across processes at each ``forward()`` - before returning the value at the step - process_group: - Specify the process group on which synchronization is called. - default: ``None`` (which selects the entire world) - dist_sync_fn: - Callback that performs the allgather operation on the metric state. When ``None``, DDP - will be used to perform the allgather - - Raises: - ValueError: - If ``threshold`` is not between ``0`` and ``1``. - ValueError: - If ``top_k`` is not an ``integer`` larger than ``0``. - - Example: - - >>> from pytorch_lightning.metrics import Accuracy - >>> target = torch.tensor([0, 1, 2, 3]) - >>> preds = torch.tensor([0, 2, 1, 3]) - >>> accuracy = Accuracy() - >>> accuracy(preds, target) - tensor(0.5000) - - >>> target = torch.tensor([0, 1, 2]) - >>> preds = torch.tensor([[0.1, 0.9, 0], [0.3, 0.1, 0.6], [0.2, 0.5, 0.3]]) - >>> accuracy = Accuracy(top_k=2) - >>> accuracy(preds, target) - tensor(0.6667) - - """ +class Accuracy(_Accuracy): + @deprecated_metrics(target=_Accuracy) def __init__( self, threshold: float = 0.5, @@ -111,45 +31,9 @@ def __init__( process_group: Optional[Any] = None, dist_sync_fn: Callable = None, ): - super().__init__( - compute_on_step=compute_on_step, - dist_sync_on_step=dist_sync_on_step, - process_group=process_group, - dist_sync_fn=dist_sync_fn, - ) - - self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum") - self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") - - if not 0 < threshold < 1: - raise ValueError(f"The `threshold` should be a float in the (0,1) interval, got {threshold}") - - if top_k is not None and (not isinstance(top_k, int) or top_k <= 0): - raise ValueError(f"The `top_k` should be an integer larger than 0, got {top_k}") - - self.threshold = threshold - self.top_k = top_k - self.subset_accuracy = subset_accuracy - - def update(self, preds: torch.Tensor, target: torch.Tensor): - """ - Update state with predictions and targets. See :ref:`extensions/metrics:input types` for more information - on input types. - - Args: - preds: Predictions from model (probabilities, or labels) - target: Ground truth labels """ + This implementation refers to :class:`~torchmetrics.Accuracy`. - correct, total = _accuracy_update( - preds, target, threshold=self.threshold, top_k=self.top_k, subset_accuracy=self.subset_accuracy - ) - - self.correct += correct - self.total += total - - def compute(self) -> torch.Tensor: - """ - Computes accuracy based on inputs passed in to ``update`` previously. + .. deprecated:: + Use :class:`~torchmetrics.Accuracy`. Will be removed in v1.5.0. """ - return _accuracy_compute(self.correct, self.total) diff --git a/pytorch_lightning/metrics/classification/auc.py b/pytorch_lightning/metrics/classification/auc.py index 6c5a29173d20a..05bc7b27d7e68 100644 --- a/pytorch_lightning/metrics/classification/auc.py +++ b/pytorch_lightning/metrics/classification/auc.py @@ -13,36 +13,14 @@ # limitations under the License. 
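The pattern above, shown for ``Accuracy`` and repeated for every classification metric that follows, keeps each deprecated class as a thin subclass of its torchmetrics counterpart: the decorated ``__init__`` carries only the deprecation docstring, and construction is delegated to the parent class. A minimal sketch of how such a forwarding decorator could be written (the name ``deprecated_metrics`` and its import path come from this patch; the body below is an illustrative assumption, not the actual implementation in pytorch_lightning/metrics/utils.py):

    from functools import wraps
    from typing import Callable, Type

    from pytorch_lightning.utilities import rank_zero_deprecation


    def deprecated_metrics(target: Type) -> Callable:
        """Hypothetical sketch: warn on rank zero, then delegate construction to the torchmetrics parent."""

        def decorator(init: Callable) -> Callable:

            @wraps(init)
            def wrapped_init(self, *args, **kwargs) -> None:
                cls_name = init.__qualname__.split(".")[0]
                rank_zero_deprecation(
                    f"`pytorch_lightning.metrics.{cls_name}` was deprecated in favor of"
                    f" `torchmetrics.{target.__name__}` and will be removed in v1.5."
                )
                # The deprecated __init__ body is empty; construct the torchmetrics parent instead.
                target.__init__(self, *args, **kwargs)

            return wrapped_init

        return decorator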
from typing import Any, Callable, Optional -import torch +from torchmetrics import AUC as _AUC -from pytorch_lightning.metrics.functional.auc import _auc_compute, _auc_update -from pytorch_lightning.metrics.metric import Metric -from pytorch_lightning.utilities import rank_zero_warn +from pytorch_lightning.metrics.utils import deprecated_metrics -class AUC(Metric): - r""" - Computes Area Under the Curve (AUC) using the trapezoidal rule - - Forward accepts two input tensors that should be 1D and have the same number - of elements - - Args: - reorder: AUC expects its first input to be sorted. If this is not the case, - setting this argument to ``True`` will use a stable sorting algorithm to - sort the input in decending order - compute_on_step: - Forward only calls ``update()`` and return None if this is set to False. - dist_sync_on_step: - Synchronize metric state across processes at each ``forward()`` - before returning the value at the step. - process_group: - Specify the process group on which synchronization is called. default: None (which selects the entire world) - dist_sync_fn: - Callback that performs the allgather operation on the metric state. When ``None``, DDP - will be used to perform the allgather - """ +class AUC(_AUC): + @deprecated_metrics(target=_AUC) def __init__( self, reorder: bool = False, @@ -51,40 +29,9 @@ def __init__( process_group: Optional[Any] = None, dist_sync_fn: Callable = None, ): - super().__init__( - compute_on_step=compute_on_step, - dist_sync_on_step=dist_sync_on_step, - process_group=process_group, - dist_sync_fn=dist_sync_fn, - ) - - self.reorder = reorder - - self.add_state("x", default=[], dist_reduce_fx=None) - self.add_state("y", default=[], dist_reduce_fx=None) - - rank_zero_warn( - 'Metric `AUC` will save all targets and predictions in buffer.' - ' For large datasets this may lead to large memory footprint.' - ) - - def update(self, x: torch.Tensor, y: torch.Tensor): - """ - Update state with predictions and targets. - - Args: - x: Predictions from model (probabilities, or labels) - y: Ground truth labels """ - x, y = _auc_update(x, y) + This implementation refers to :class:`~torchmetrics.AUC`. - self.x.append(x) - self.y.append(y) - - def compute(self) -> torch.Tensor: - """ - Computes AUC based on inputs passed in to ``update`` previously. + .. deprecated:: + Use :class:`~torchmetrics.AUC`. Will be removed in v1.5.0. """ - x = torch.cat(self.x, dim=0) - y = torch.cat(self.y, dim=0) - return _auc_compute(x, y, reorder=self.reorder) diff --git a/pytorch_lightning/metrics/classification/auroc.py b/pytorch_lightning/metrics/classification/auroc.py index 6b9b5ae9f021f..e10b094fd5a2e 100644 --- a/pytorch_lightning/metrics/classification/auroc.py +++ b/pytorch_lightning/metrics/classification/auroc.py @@ -11,95 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from distutils.version import LooseVersion from typing import Any, Callable, Optional -import torch +from torchmetrics import AUROC as _AUROC -from pytorch_lightning.metrics.functional.auroc import _auroc_compute, _auroc_update -from pytorch_lightning.metrics.metric import Metric -from pytorch_lightning.utilities import rank_zero_warn +from pytorch_lightning.metrics.utils import deprecated_metrics -class AUROC(Metric): - r"""Compute `Area Under the Receiver Operating Characteristic Curve (ROC AUC) - `_. 
- Works for both binary, multilabel and multiclass problems. In the case of - multiclass, the values will be calculated based on a one-vs-the-rest approach. - - Forward accepts - - - ``preds`` (float tensor): ``(N, ...)`` (binary) or ``(N, C, ...)`` (multiclass) tensor - with probabilities, where C is the number of classes. - - - ``target`` (long tensor): ``(N, ...)`` or ``(N, C, ...)`` with integer labels - - For non-binary input, if the ``preds`` and ``target`` tensor have the same - size the input will be interpretated as multilabel and if ``preds`` have one - dimension more than the ``target`` tensor the input will be interpretated as - multiclass. - - Args: - num_classes: integer with number of classes. Not nessesary to provide - for binary problems. - pos_label: integer determining the positive class. Default is ``None`` - which for binary problem is translate to 1. For multiclass problems - this argument should not be set as we iteratively change it in the - range [0,num_classes-1] - average: - - ``'macro'`` computes metric for each class and uniformly averages them - - ``'weighted'`` computes metric for each class and does a weighted-average, - where each class is weighted by their support (accounts for class imbalance) - - ``None`` computes and returns the metric per class - max_fpr: - If not ``None``, calculates standardized partial AUC over the - range [0, max_fpr]. Should be a float between 0 and 1. - compute_on_step: - Forward only calls ``update()`` and return None if this is set to False. default: True - dist_sync_on_step: - Synchronize metric state across processes at each ``forward()`` - before returning the value at the step. - process_group: - Specify the process group on which synchronization is called. default: None (which selects the entire world) - dist_sync_fn: - Callback that performs the allgather operation on the metric state. When ``None``, DDP - will be used to perform the allgather - - Raises: - ValueError: - If ``average`` is none of ``None``, ``"macro"`` or ``"weighted"``. - ValueError: - If ``max_fpr`` is not a ``float`` in the range ``(0, 1]``. - RuntimeError: - If ``PyTorch version`` is ``below 1.6`` since max_fpr requires ``torch.bucketize`` - which is not available below 1.6. - ValueError: - If the mode of data (binary, multi-label, multi-class) changes between batches. - - Example (binary case): - - >>> from pytorch_lightning.metrics import AUROC - >>> preds = torch.tensor([0.13, 0.26, 0.08, 0.19, 0.34]) - >>> target = torch.tensor([0, 0, 1, 1, 1]) - >>> auroc = AUROC(pos_label=1) - >>> auroc(preds, target) - tensor(0.5000) - - Example (multiclass case): - - >>> from pytorch_lightning.metrics import AUROC - >>> preds = torch.tensor([[0.90, 0.05, 0.05], - ... [0.05, 0.90, 0.05], - ... [0.05, 0.05, 0.90], - ... [0.85, 0.05, 0.10], - ... 
[0.10, 0.10, 0.80]]) - >>> target = torch.tensor([0, 1, 1, 2, 2]) - >>> auroc = AUROC(num_classes=3) - >>> auroc(preds, target) - tensor(0.7778) - - """ +class AUROC(_AUROC): + @deprecated_metrics(target=_AUROC) def __init__( self, num_classes: Optional[int] = None, @@ -111,74 +32,9 @@ def __init__( process_group: Optional[Any] = None, dist_sync_fn: Callable = None, ): - super().__init__( - compute_on_step=compute_on_step, - dist_sync_on_step=dist_sync_on_step, - process_group=process_group, - dist_sync_fn=dist_sync_fn, - ) - - self.num_classes = num_classes - self.pos_label = pos_label - self.average = average - self.max_fpr = max_fpr - - allowed_average = (None, 'macro', 'weighted') - if self.average not in allowed_average: - raise ValueError( - f'Argument `average` expected to be one of the following: {allowed_average} but got {average}' - ) - - if self.max_fpr is not None: - if (not isinstance(max_fpr, float) and 0 < max_fpr <= 1): - raise ValueError(f"`max_fpr` should be a float in range (0, 1], got: {max_fpr}") - - if LooseVersion(torch.__version__) < LooseVersion('1.6.0'): - raise RuntimeError( - '`max_fpr` argument requires `torch.bucketize` which is not available below PyTorch version 1.6' - ) - - self.mode = None - self.add_state("preds", default=[], dist_reduce_fx=None) - self.add_state("target", default=[], dist_reduce_fx=None) - - rank_zero_warn( - 'Metric `AUROC` will save all targets and predictions in buffer.' - ' For large datasets this may lead to large memory footprint.' - ) - - def update(self, preds: torch.Tensor, target: torch.Tensor): """ - Update state with predictions and targets. + This implementation refers to :class:`~torchmetrics.AUROC`. - Args: - preds: Predictions from model (probabilities, or labels) - target: Ground truth labels - """ - preds, target, mode = _auroc_update(preds, target) - - self.preds.append(preds) - self.target.append(target) - - if self.mode is not None and self.mode != mode: - raise ValueError( - 'The mode of data (binary, multi-label, multi-class) should be constant, but changed' - f' between batches from {self.mode} to {mode}' - ) - self.mode = mode - - def compute(self) -> torch.Tensor: - """ - Computes AUROC based on inputs passed in to ``update`` previously. + .. deprecated:: + Use :class:`~torchmetrics.AUROC`. Will be removed in v1.5.0. """ - preds = torch.cat(self.preds, dim=0) - target = torch.cat(self.target, dim=0) - return _auroc_compute( - preds, - target, - self.mode, - self.num_classes, - self.pos_label, - self.average, - self.max_fpr, - ) diff --git a/pytorch_lightning/metrics/classification/average_precision.py b/pytorch_lightning/metrics/classification/average_precision.py index f9c7bde158383..6c8cdbd52891d 100644 --- a/pytorch_lightning/metrics/classification/average_precision.py +++ b/pytorch_lightning/metrics/classification/average_precision.py @@ -11,64 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
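From the user side the deprecated classes remain importable and delegate to torchmetrics, so computed values do not change; only the import path and the emitted warning differ. A small usage sketch, assuming both packages are installed (the tensor values are taken from the removed ``Accuracy`` docstring above):

    import torch

    # Deprecated path: still works, but triggers the rank_zero_deprecation warning added in this patch.
    from pytorch_lightning.metrics import Accuracy as DeprecatedAccuracy
    # Replacement path after the split into the torchmetrics package.
    from torchmetrics import Accuracy

    target = torch.tensor([0, 1, 2, 3])
    preds = torch.tensor([0, 2, 1, 3])

    # Both return tensor(0.5000), since the deprecated class is a thin wrapper
    # around the torchmetrics implementation.
    assert Accuracy()(preds, target) == DeprecatedAccuracy()(preds, target)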
-from typing import Any, List, Optional, Union +from typing import Any, Optional -import torch +from torchmetrics import AveragePrecision as _AveragePrecision -from pytorch_lightning.metrics.functional.average_precision import _average_precision_compute, _average_precision_update -from pytorch_lightning.metrics.metric import Metric -from pytorch_lightning.utilities import rank_zero_warn +from pytorch_lightning.metrics.utils import deprecated_metrics -class AveragePrecision(Metric): - """ - Computes the average precision score, which summarises the precision recall - curve into one number. Works for both binary and multiclass problems. - In the case of multiclass, the values will be calculated based on a one-vs-the-rest approach. - - Forward accepts - - - ``preds`` (float tensor): ``(N, ...)`` (binary) or ``(N, C, ...)`` (multiclass) tensor - with probabilities, where C is the number of classes. - - - ``target`` (long tensor): ``(N, ...)`` with integer labels - - Args: - num_classes: integer with number of classes. Not nessesary to provide - for binary problems. - pos_label: integer determining the positive class. Default is ``None`` - which for binary problem is translate to 1. For multiclass problems - this argument should not be set as we iteratively change it in the - range [0,num_classes-1] - compute_on_step: - Forward only calls ``update()`` and return None if this is set to False. default: True - dist_sync_on_step: - Synchronize metric state across processes at each ``forward()`` - before returning the value at the step. default: False - process_group: - Specify the process group on which synchronization is called. default: None (which selects the entire world) - - Example (binary case): - - >>> pred = torch.tensor([0, 1, 2, 3]) - >>> target = torch.tensor([0, 1, 1, 1]) - >>> average_precision = AveragePrecision(pos_label=1) - >>> average_precision(pred, target) - tensor(1.) - - Example (multiclass case): - - >>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05], - ... [0.05, 0.75, 0.05, 0.05, 0.05], - ... [0.05, 0.05, 0.75, 0.05, 0.05], - ... [0.05, 0.05, 0.05, 0.75, 0.05]]) - >>> target = torch.tensor([0, 1, 3, 2]) - >>> average_precision = AveragePrecision(num_classes=5) - >>> average_precision(pred, target) - [tensor(1.), tensor(1.), tensor(0.2500), tensor(0.2500), tensor(nan)] - - """ +class AveragePrecision(_AveragePrecision): + @deprecated_metrics(target=_AveragePrecision) def __init__( self, num_classes: Optional[int] = None, @@ -77,48 +29,9 @@ def __init__( dist_sync_on_step: bool = False, process_group: Optional[Any] = None, ): - super().__init__( - compute_on_step=compute_on_step, - dist_sync_on_step=dist_sync_on_step, - process_group=process_group, - ) - - self.num_classes = num_classes - self.pos_label = pos_label - - self.add_state("preds", default=[], dist_reduce_fx=None) - self.add_state("target", default=[], dist_reduce_fx=None) - - rank_zero_warn( - 'Metric `AveragePrecision` will save all targets and predictions in buffer.' - ' For large datasets this may lead to large memory footprint.' - ) - - def update(self, preds: torch.Tensor, target: torch.Tensor): """ - Update state with predictions and targets. 
- - Args: - preds: Predictions from model - target: Ground truth values - """ - preds, target, num_classes, pos_label = _average_precision_update( - preds, target, self.num_classes, self.pos_label - ) - self.preds.append(preds) - self.target.append(target) - self.num_classes = num_classes - self.pos_label = pos_label - - def compute(self) -> Union[torch.Tensor, List[torch.Tensor]]: - """ - Compute the average precision score - - Returns: - tensor with average precision. If multiclass will return list - of such tensors, one for each class + This implementation refers to :class:`~torchmetrics.AveragePrecision`. + .. deprecated:: + Use :class:`~torchmetrics.AveragePrecision`. Will be removed in v1.5.0. """ - preds = torch.cat(self.preds, dim=0) - target = torch.cat(self.target, dim=0) - return _average_precision_compute(preds, target, self.num_classes, self.pos_label) diff --git a/pytorch_lightning/metrics/classification/confusion_matrix.py b/pytorch_lightning/metrics/classification/confusion_matrix.py index c3defc82bc92d..2995f668380de 100644 --- a/pytorch_lightning/metrics/classification/confusion_matrix.py +++ b/pytorch_lightning/metrics/classification/confusion_matrix.py @@ -13,64 +13,14 @@ # limitations under the License. from typing import Any, Optional -import torch +from torchmetrics import ConfusionMatrix as _ConfusionMatrix -from pytorch_lightning.metrics.functional.confusion_matrix import _confusion_matrix_compute, _confusion_matrix_update -from pytorch_lightning.metrics.metric import Metric +from pytorch_lightning.metrics.utils import deprecated_metrics -class ConfusionMatrix(Metric): - """ - Computes the `confusion matrix - `_. Works with binary, - multiclass, and multilabel data. Accepts probabilities from a model output or - integer class values in prediction. Works with multi-dimensional preds and - target. - - Note: - This metric produces a multi-dimensional output, so it can not be directly logged. - - Forward accepts - - - ``preds`` (float or long tensor): ``(N, ...)`` or ``(N, C, ...)`` where C is the number of classes - - ``target`` (long tensor): ``(N, ...)`` - - If preds and target are the same shape and preds is a float tensor, we use the ``self.threshold`` argument - to convert into integer labels. This is the case for binary and multi-label probabilities. - - If preds has an extra dimension as in the case of multi-class scores we perform an argmax on ``dim=1``. - - Args: - num_classes: Number of classes in the dataset. - normalize: Normalization mode for confusion matrix. Choose from - - - ``None`` or ``'none'``: no normalization (default) - - ``'true'``: normalization over the targets (most commonly used) - - ``'pred'``: normalization over the predictions - - ``'all'``: normalization over the whole matrix - - threshold: - Threshold value for binary or multi-label probabilites. default: 0.5 - compute_on_step: - Forward only calls ``update()`` and return None if this is set to False. default: True - dist_sync_on_step: - Synchronize metric state across processes at each ``forward()`` - before returning the value at the step. default: False - process_group: - Specify the process group on which synchronization is called. 
default: None (which selects the entire world) - - Example: - - >>> from pytorch_lightning.metrics import ConfusionMatrix - >>> target = torch.tensor([1, 1, 0, 0]) - >>> preds = torch.tensor([0, 1, 0, 0]) - >>> confmat = ConfusionMatrix(num_classes=2) - >>> confmat(preds, target) - tensor([[2., 0.], - [1., 1.]]) - - """ +class ConfusionMatrix(_ConfusionMatrix): + @deprecated_metrics(target=_ConfusionMatrix) def __init__( self, num_classes: int, @@ -80,35 +30,9 @@ def __init__( dist_sync_on_step: bool = False, process_group: Optional[Any] = None, ): - - super().__init__( - compute_on_step=compute_on_step, - dist_sync_on_step=dist_sync_on_step, - process_group=process_group, - ) - self.num_classes = num_classes - self.normalize = normalize - self.threshold = threshold - - allowed_normalize = ('true', 'pred', 'all', 'none', None) - assert self.normalize in allowed_normalize, \ - f"Argument average needs to one of the following: {allowed_normalize}" - - self.add_state("confmat", default=torch.zeros(num_classes, num_classes), dist_reduce_fx="sum") - - def update(self, preds: torch.Tensor, target: torch.Tensor): - """ - Update state with predictions and targets. - - Args: - preds: Predictions from model - target: Ground truth values """ - confmat = _confusion_matrix_update(preds, target, self.num_classes, self.threshold) - self.confmat += confmat + This implementation refers to :class:`~torchmetrics.ConfusionMatrix`. - def compute(self) -> torch.Tensor: - """ - Computes confusion matrix + .. deprecated:: + Use :class:`~torchmetrics.ConfusionMatrix`. Will be removed in v1.5.0. """ - return _confusion_matrix_compute(self.confmat, self.normalize) diff --git a/pytorch_lightning/metrics/classification/f_beta.py b/pytorch_lightning/metrics/classification/f_beta.py index ae01b80966868..a3f4172f05400 100644 --- a/pytorch_lightning/metrics/classification/f_beta.py +++ b/pytorch_lightning/metrics/classification/f_beta.py @@ -13,72 +13,15 @@ # limitations under the License. from typing import Any, Optional -import torch +from torchmetrics import F1 as _F1 +from torchmetrics import FBeta as _FBeta -from pytorch_lightning.metrics.functional.f_beta import _fbeta_compute, _fbeta_update -from pytorch_lightning.metrics.metric import Metric -from pytorch_lightning.utilities import rank_zero_warn +from pytorch_lightning.metrics.utils import deprecated_metrics -class FBeta(Metric): - r""" - Computes `F-score `_, specifically: - - .. math:: - F_\beta = (1 + \beta^2) * \frac{\text{precision} * \text{recall}} - {(\beta^2 * \text{precision}) + \text{recall}} - - Where :math:`\beta` is some positive real factor. Works with binary, multiclass, and multilabel data. - Accepts probabilities from a model output or integer class values in prediction. - Works with multi-dimensional preds and target. - - Forward accepts - - - ``preds`` (float or long tensor): ``(N, ...)`` or ``(N, C, ...)`` where C is the number of classes - - ``target`` (long tensor): ``(N, ...)`` - - If preds and target are the same shape and preds is a float tensor, we use the ``self.threshold`` argument - to convert into integer labels. This is the case for binary and multi-label probabilities. - - If preds has an extra dimension as in the case of multi-class scores we perform an argmax on ``dim=1``. - - Args: - num_classes: Number of classes in the dataset. - beta: Beta coefficient in the F measure. - threshold: - Threshold value for binary or multi-label probabilities. 
default: 0.5 - - average: - - ``'micro'`` computes metric globally - - ``'macro'`` computes metric for each class and uniformly averages them - - ``'weighted'`` computes metric for each class and does a weighted-average, - where each class is weighted by their support (accounts for class imbalance) - - ``'none'`` or ``None`` computes and returns the metric per class - - multilabel: If predictions are from multilabel classification. - compute_on_step: - Forward only calls ``update()`` and return None if this is set to False. default: True - dist_sync_on_step: - Synchronize metric state across processes at each ``forward()`` - before returning the value at the step. default: False - process_group: - Specify the process group on which synchronization is called. default: None (which selects the entire world) - - Raises: - ValueError: - If ``average`` is none of ``"micro"``, ``"macro"``, ``"weighted"``, ``"none"``, ``None``. - - Example: - - >>> from pytorch_lightning.metrics import FBeta - >>> target = torch.tensor([0, 1, 2, 0, 1, 2]) - >>> preds = torch.tensor([0, 2, 1, 0, 0, 1]) - >>> f_beta = FBeta(num_classes=3, beta=0.5) - >>> f_beta(preds, target) - tensor(0.3333) - - """ +class FBeta(_FBeta): + @deprecated_metrics(target=_FBeta) def __init__( self, num_classes: int, @@ -90,103 +33,17 @@ def __init__( dist_sync_on_step: bool = False, process_group: Optional[Any] = None, ): - super().__init__( - compute_on_step=compute_on_step, - dist_sync_on_step=dist_sync_on_step, - process_group=process_group, - ) - - self.num_classes = num_classes - self.beta = beta - self.threshold = threshold - self.average = average - self.multilabel = multilabel - - allowed_average = ("micro", "macro", "weighted", "none", None) - if self.average not in allowed_average: - raise ValueError( - 'Argument `average` expected to be one of the following:' - f' {allowed_average} but got {self.average}' - ) - - self.add_state("true_positives", default=torch.zeros(num_classes), dist_reduce_fx="sum") - self.add_state("predicted_positives", default=torch.zeros(num_classes), dist_reduce_fx="sum") - self.add_state("actual_positives", default=torch.zeros(num_classes), dist_reduce_fx="sum") - - def update(self, preds: torch.Tensor, target: torch.Tensor): - """ - Update state with predictions and targets. - - Args: - preds: Predictions from model - target: Ground truth values """ - true_positives, predicted_positives, actual_positives = _fbeta_update( - preds, target, self.num_classes, self.threshold, self.multilabel - ) - - self.true_positives += true_positives - self.predicted_positives += predicted_positives - self.actual_positives += actual_positives + This implementation refers to :class:`~torchmetrics.FBeta`. - def compute(self) -> torch.Tensor: + .. deprecated:: + Use :class:`~torchmetrics.FBeta`. Will be removed in v1.5.0. """ - Computes fbeta over state. - """ - return _fbeta_compute( - self.true_positives, self.predicted_positives, self.actual_positives, self.beta, self.average - ) - - -class F1(FBeta): - """ - Computes F1 metric. F1 metrics correspond to a harmonic mean of the - precision and recall scores. - - Works with binary, multiclass, and multilabel data. - Accepts logits from a model output or integer class values in prediction. - Works with multi-dimensional preds and target. 
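As the docstring removed here notes, ``F1`` is ``FBeta`` with ``beta=1``: substituting beta = 1 into the F-score formula above reduces it to the harmonic mean 2 * precision * recall / (precision + recall). A quick numerical check of that reduction (the precision and recall values are arbitrary):

    # With beta = 1 the general F-beta expression equals the harmonic mean of precision and recall.
    precision, recall, beta = 0.5, 0.25, 1.0
    f_beta = (1 + beta ** 2) * precision * recall / (beta ** 2 * precision + recall)
    harmonic_mean = 2 * precision * recall / (precision + recall)
    assert abs(f_beta - harmonic_mean) < 1e-9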
- Forward accepts - - ``preds`` (float or long tensor): ``(N, ...)`` or ``(N, C, ...)`` where C is the number of classes - - ``target`` (long tensor): ``(N, ...)`` - - If preds and target are the same shape and preds is a float tensor, we use the ``self.threshold`` argument. - This is the case for binary and multi-label logits. - - If preds has an extra dimension as in the case of multi-class scores we perform an argmax on ``dim=1``. - - Args: - num_classes: Number of classes in the dataset. - threshold: - Threshold value for binary or multi-label logits. default: 0.5 - - average: - - ``'micro'`` computes metric globally - - ``'macro'`` computes metric for each class and uniformly averages them - - ``'weighted'`` computes metric for each class and does a weighted-average, - where each class is weighted by their support (accounts for class imbalance) - - ``'none'`` or ``None`` computes and returns the metric per class - - multilabel: If predictions are from multilabel classification. - compute_on_step: - Forward only calls ``update()`` and returns None if this is set to False. default: True - dist_sync_on_step: - Synchronize metric state across processes at each ``forward()`` - before returning the value at the step. default: False - process_group: - Specify the process group on which synchronization is called. default: None (which selects the entire world) - - Example: - >>> from pytorch_lightning.metrics import F1 - >>> target = torch.tensor([0, 1, 2, 0, 1, 2]) - >>> preds = torch.tensor([0, 2, 1, 0, 0, 1]) - >>> f1 = F1(num_classes=3) - >>> f1(preds, target) - tensor(0.3333) - """ +class F1(_F1): + @deprecated_metrics(target=_F1) def __init__( self, num_classes: int, @@ -197,16 +54,9 @@ def __init__( dist_sync_on_step: bool = False, process_group: Optional[Any] = None, ): - if multilabel is not False: - rank_zero_warn(f'The `multilabel={multilabel}` parameter is unused and will not have any effect.') + """ + This implementation refers to :class:`~torchmetrics.F1`. - super().__init__( - num_classes=num_classes, - beta=1.0, - threshold=threshold, - average=average, - multilabel=multilabel, - compute_on_step=compute_on_step, - dist_sync_on_step=dist_sync_on_step, - process_group=process_group, - ) + .. deprecated:: + Use :class:`~torchmetrics.F1`. Will be removed in v1.5.0. + """ diff --git a/pytorch_lightning/metrics/classification/hamming_distance.py b/pytorch_lightning/metrics/classification/hamming_distance.py index 62b4ae824a6d1..d66b0c2d9cfa8 100644 --- a/pytorch_lightning/metrics/classification/hamming_distance.py +++ b/pytorch_lightning/metrics/classification/hamming_distance.py @@ -13,61 +13,14 @@ # limitations under the License. from typing import Any, Callable, Optional -import torch +from torchmetrics import HammingDistance as _HammingDistance -from pytorch_lightning.metrics.functional.hamming_distance import _hamming_distance_compute, _hamming_distance_update -from pytorch_lightning.metrics.metric import Metric +from pytorch_lightning.metrics.utils import deprecated_metrics -class HammingDistance(Metric): - r""" - Computes the average `Hamming distance `_ (also - known as Hamming loss) between targets and predictions: - - .. math:: - \text{Hamming distance} = \frac{1}{N \cdot L}\sum_i^N \sum_l^L 1(y_{il} \neq \hat{y_{il}}) - - Where :math:`y` is a tensor of target values, :math:`\hat{y}` is a tensor of predictions, - and :math:`\bullet_{il}` refers to the :math:`l`-th label of the :math:`i`-th sample of that - tensor. 
- - This is the same as ``1-accuracy`` for binary data, while for all other types of inputs it - treats each possible label separately - meaning that, for example, multi-class data is - treated as if it were multi-label. - - Accepts all input types listed in :ref:`extensions/metrics:input types`. - - Args: - threshold: - Threshold probability value for transforming probability predictions to binary - (0 or 1) predictions, in the case of binary or multi-label inputs. - compute_on_step: - Forward only calls ``update()`` and return ``None`` if this is set to ``False``. - dist_sync_on_step: - Synchronize metric state across processes at each ``forward()`` - before returning the value at the step. - process_group: - Specify the process group on which synchronization is called. - default: ``None`` (which selects the entire world) - dist_sync_fn: - Callback that performs the allgather operation on the metric state. When ``None``, DDP - will be used to perform the all gather. - - Raises: - ValueError: - If ``threshold`` is not between ``0`` and ``1``. - - Example: - - >>> from pytorch_lightning.metrics import HammingDistance - >>> target = torch.tensor([[0, 1], [1, 1]]) - >>> preds = torch.tensor([[0, 1], [0, 1]]) - >>> hamming_distance = HammingDistance() - >>> hamming_distance(preds, target) - tensor(0.2500) - - """ +class HammingDistance(_HammingDistance): + @deprecated_metrics(target=_HammingDistance) def __init__( self, threshold: float = 0.5, @@ -76,36 +29,9 @@ def __init__( process_group: Optional[Any] = None, dist_sync_fn: Callable = None, ): - super().__init__( - compute_on_step=compute_on_step, - dist_sync_on_step=dist_sync_on_step, - process_group=process_group, - dist_sync_fn=dist_sync_fn, - ) - - self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum") - self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") - - if not 0 < threshold < 1: - raise ValueError("The `threshold` should lie in the (0,1) interval.") - self.threshold = threshold - - def update(self, preds: torch.Tensor, target: torch.Tensor): """ - Update state with predictions and targets. See :ref:`extensions/metrics:input types` for more information - on input types. + This implementation refers to :class:`~torchmetrics.HammingDistance`. - Args: - preds: Predictions from model (probabilities, or labels) - target: Ground truth labels - """ - correct, total = _hamming_distance_update(preds, target, self.threshold) - - self.correct += correct - self.total += total - - def compute(self) -> torch.Tensor: - """ - Computes hamming distance based on inputs passed in to ``update`` previously. + .. deprecated:: + Use :class:`~torchmetrics.HammingDistance`. Will be removed in v1.5.0. """ - return _hamming_distance_compute(self.correct, self.total) diff --git a/pytorch_lightning/metrics/classification/helpers.py b/pytorch_lightning/metrics/classification/helpers.py deleted file mode 100644 index ea6d5722b3041..0000000000000 --- a/pytorch_lightning/metrics/classification/helpers.py +++ /dev/null @@ -1,539 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Optional, Tuple - -import numpy as np -import torch - -from pytorch_lightning.metrics.utils import select_topk, to_onehot -from pytorch_lightning.utilities import LightningEnum - - -class DataType(LightningEnum): - """ - Enum to represent data type - """ - - BINARY = "binary" - MULTILABEL = "multi-label" - MULTICLASS = "multi-class" - MULTIDIM_MULTICLASS = "multi-dim multi-class" - - -class AverageMethod(LightningEnum): - """ - Enum to represent average method - """ - - MICRO = "micro" - MACRO = "macro" - WEIGHTED = "weighted" - NONE = "none" - SAMPLES = "samples" - - -class MDMCAverageMethod(LightningEnum): - """ - Enum to represent multi-dim multi-class average method - """ - - GLOBAL = "global" - SAMPLEWISE = "samplewise" - - -def _basic_input_validation(preds: torch.Tensor, target: torch.Tensor, threshold: float, is_multiclass: bool): - """ - Perform basic validation of inputs that does not require deducing any information - of the type of inputs. - """ - - if target.is_floating_point(): - raise ValueError("The `target` has to be an integer tensor.") - if target.min() < 0: - raise ValueError("The `target` has to be a non-negative tensor.") - - preds_float = preds.is_floating_point() - if not preds_float and preds.min() < 0: - raise ValueError("If `preds` are integers, they have to be non-negative.") - - if not preds.shape[0] == target.shape[0]: - raise ValueError("The `preds` and `target` should have the same first dimension.") - - if preds_float and (preds.min() < 0 or preds.max() > 1): - raise ValueError("The `preds` should be probabilities, but values were detected outside of [0,1] range.") - - if not 0 < threshold < 1: - raise ValueError(f"The `threshold` should be a float in the (0,1) interval, got {threshold}") - - if is_multiclass is False and target.max() > 1: - raise ValueError("If you set `is_multiclass=False`, then `target` should not exceed 1.") - - if is_multiclass is False and not preds_float and preds.max() > 1: - raise ValueError("If you set `is_multiclass=False` and `preds` are integers, then `preds` should not exceed 1.") - - -def _check_shape_and_type_consistency(preds: torch.Tensor, target: torch.Tensor) -> Tuple[str, int]: - """ - This checks that the shape and type of inputs are consistent with - each other and fall into one of the allowed input types (see the - documentation of docstring of ``_input_format_classification``). It does - not check for consistency of number of classes, other functions take - care of that. - - It returns the name of the case in which the inputs fall, and the implied - number of classes (from the ``C`` dim for multi-class data, or extra dim(s) for - multi-label data). - """ - - preds_float = preds.is_floating_point() - - if preds.ndim == target.ndim: - if preds.shape != target.shape: - raise ValueError( - "The `preds` and `target` should have the same shape,", - f" got `preds` with shape={preds.shape} and `target` with shape={target.shape}.", - ) - if preds_float and target.max() > 1: - raise ValueError( - "If `preds` and `target` are of shape (N, ...) and `preds` are floats, `target` should be binary." 
- ) - - # Get the case - if preds.ndim == 1 and preds_float: - case = DataType.BINARY - elif preds.ndim == 1 and not preds_float: - case = DataType.MULTICLASS - elif preds.ndim > 1 and preds_float: - case = DataType.MULTILABEL - else: - case = DataType.MULTIDIM_MULTICLASS - - implied_classes = preds[0].numel() - - elif preds.ndim == target.ndim + 1: - if not preds_float: - raise ValueError("If `preds` have one dimension more than `target`, `preds` should be a float tensor.") - if preds.shape[2:] != target.shape[1:]: - raise ValueError( - "If `preds` have one dimension more than `target`, the shape of `preds` should be" - " (N, C, ...), and the shape of `target` should be (N, ...)." - ) - - implied_classes = preds.shape[1] - - if preds.ndim == 2: - case = DataType.MULTICLASS - else: - case = DataType.MULTIDIM_MULTICLASS - else: - raise ValueError( - "Either `preds` and `target` both should have the (same) shape (N, ...), or `target` should be (N, ...)" - " and `preds` should be (N, C, ...)." - ) - - return case, implied_classes - - -def _check_num_classes_binary(num_classes: int, is_multiclass: bool): - """ - This checks that the consistency of `num_classes` with the data - and `is_multiclass` param for binary data. - """ - - if num_classes > 2: - raise ValueError("Your data is binary, but `num_classes` is larger than 2.") - if num_classes == 2 and not is_multiclass: - raise ValueError( - "Your data is binary and `num_classes=2`, but `is_multiclass` is not True." - " Set it to True if you want to transform binary data to multi-class format." - ) - if num_classes == 1 and is_multiclass: - raise ValueError( - "You have binary data and have set `is_multiclass=True`, but `num_classes` is 1." - " Either set `is_multiclass=None`(default) or set `num_classes=2`" - " to transform binary data to multi-class format." - ) - - -def _check_num_classes_mc( - preds: torch.Tensor, target: torch.Tensor, num_classes: int, is_multiclass: bool, implied_classes: int -): - """ - This checks that the consistency of `num_classes` with the data - and `is_multiclass` param for (multi-dimensional) multi-class data. - """ - - if num_classes == 1 and is_multiclass is not False: - raise ValueError( - "You have set `num_classes=1`, but predictions are integers." - " If you want to convert (multi-dimensional) multi-class data with 2 classes" - " to binary/multi-label, set `is_multiclass=False`." - ) - if num_classes > 1: - if is_multiclass is False: - if implied_classes != num_classes: - raise ValueError( - "You have set `is_multiclass=False`, but the implied number of classes " - " (from shape of inputs) does not match `num_classes`. If you are trying to" - " transform multi-dim multi-class data with 2 classes to multi-label, `num_classes`" - " should be either None or the product of the size of extra dimensions (...)." - " See Input Types in Metrics documentation." - ) - if num_classes <= target.max(): - raise ValueError("The highest label in `target` should be smaller than `num_classes`.") - if num_classes <= preds.max(): - raise ValueError("The highest label in `preds` should be smaller than `num_classes`.") - if preds.shape != target.shape and num_classes != implied_classes: - raise ValueError("The size of C dimension of `preds` does not match `num_classes`.") - - -def _check_num_classes_ml(num_classes: int, is_multiclass: bool, implied_classes: int): - """ - This checks that the consistency of `num_classes` with the data - and `is_multiclass` param for multi-label data. 
- """ - - if is_multiclass and num_classes != 2: - raise ValueError( - "Your have set `is_multiclass=True`, but `num_classes` is not equal to 2." - " If you are trying to transform multi-label data to 2 class multi-dimensional" - " multi-class, you should set `num_classes` to either 2 or None." - ) - if not is_multiclass and num_classes != implied_classes: - raise ValueError("The implied number of classes (from shape of inputs) does not match num_classes.") - - -def _check_top_k(top_k: int, case: str, implied_classes: int, is_multiclass: Optional[bool], preds_float: bool): - if case == DataType.BINARY: - raise ValueError("You can not use `top_k` parameter with binary data.") - if not isinstance(top_k, int) or top_k <= 0: - raise ValueError("The `top_k` has to be an integer larger than 0.") - if not preds_float: - raise ValueError("You have set `top_k`, but you do not have probability predictions.") - if is_multiclass is False: - raise ValueError("If you set `is_multiclass=False`, you can not set `top_k`.") - if case == DataType.MULTILABEL and is_multiclass: - raise ValueError( - "If you want to transform multi-label data to 2 class multi-dimensional" - "multi-class data using `is_multiclass=True`, you can not use `top_k`." - ) - if top_k >= implied_classes: - raise ValueError("The `top_k` has to be strictly smaller than the `C` dimension of `preds`.") - - -def _check_classification_inputs( - preds: torch.Tensor, - target: torch.Tensor, - threshold: float, - num_classes: Optional[int], - is_multiclass: bool, - top_k: Optional[int], -) -> str: - """Performs error checking on inputs for classification. - - This ensures that preds and target take one of the shape/type combinations that are - specified in ``_input_format_classification`` docstring. It also checks the cases of - over-rides with ``is_multiclass`` by checking (for multi-class and multi-dim multi-class - cases) that there are only up to 2 distinct labels. - - In case where preds are floats (probabilities), it is checked whether they are in [0,1] interval. - - When ``num_classes`` is given, it is checked that it is consitent with input cases (binary, - multi-label, ...), and that, if availible, the implied number of classes in the ``C`` - dimension is consistent with it (as well as that max label in target is smaller than it). - - When ``num_classes`` is not specified in these cases, consistency of the highest target - value against ``C`` dimension is checked for (multi-dimensional) multi-class cases. - - If ``top_k`` is set (not None) for inputs that do not have probability predictions (and - are not binary), an error is raised. Similarly if ``top_k`` is set to a number that - is higher than or equal to the ``C`` dimension of ``preds``, an error is raised. - - Preds and target tensors are expected to be squeezed already - all dimensions should be - greater than 1, except perhaps the first one (``N``). - - Args: - preds: Tensor with predictions (labels or probabilities) - target: Tensor with ground truth labels, always integers (labels) - threshold: - Threshold probability value for transforming probability predictions to binary - (0,1) predictions, in the case of binary or multi-label inputs. - num_classes: - Number of classes. If not explicitly set, the number of classes will be infered - either from the shape of inputs, or the maximum label in the ``target`` and ``preds`` - tensor, where applicable. 
- top_k: - Number of highest probability entries for each sample to convert to 1s - relevant - only for inputs with probability predictions. The default value (``None``) will be - interepreted as 1 for these inputs. If this parameter is set for multi-label inputs, - it will take precedence over threshold. - - Should be left unset (``None``) for inputs with label predictions. - is_multiclass: - Used only in certain special cases, where you want to treat inputs as a different type - than what they appear to be. See the parameter's - :ref:`documentation section ` - for a more detailed explanation and examples. - - - Return: - case: The case the inputs fall in, one of 'binary', 'multi-class', 'multi-label' or - 'multi-dim multi-class' - """ - - # Baisc validation (that does not need case/type information) - _basic_input_validation(preds, target, threshold, is_multiclass) - - # Check that shape/types fall into one of the cases - case, implied_classes = _check_shape_and_type_consistency(preds, target) - - # For (multi-dim) multi-class case with prob preds, check that preds sum up to 1 - if case in (DataType.MULTICLASS, DataType.MULTIDIM_MULTICLASS) and preds.is_floating_point(): - if not torch.isclose(preds.sum(dim=1), torch.ones_like(preds.sum(dim=1))).all(): - raise ValueError("Probabilities in `preds` must sum up to 1 accross the `C` dimension.") - - # Check consistency with the `C` dimension in case of multi-class data - if preds.shape != target.shape: - if is_multiclass is False and implied_classes != 2: - raise ValueError( - "You have set `is_multiclass=False`, but have more than 2 classes in your data," - " based on the C dimension of `preds`." - ) - if target.max() >= implied_classes: - raise ValueError( - "The highest label in `target` should be smaller than the size of the `C` dimension of `preds`." - ) - - # Check that num_classes is consistent - if num_classes: - if case == DataType.BINARY: - _check_num_classes_binary(num_classes, is_multiclass) - elif case in (DataType.MULTICLASS, DataType.MULTIDIM_MULTICLASS): - _check_num_classes_mc(preds, target, num_classes, is_multiclass, implied_classes) - elif case.MULTILABEL: - _check_num_classes_ml(num_classes, is_multiclass, implied_classes) - - # Check that top_k is consistent - if top_k is not None: - _check_top_k(top_k, case, implied_classes, is_multiclass, preds.is_floating_point()) - - return case - - -def _input_format_classification( - preds: torch.Tensor, - target: torch.Tensor, - threshold: float = 0.5, - top_k: Optional[int] = None, - num_classes: Optional[int] = None, - is_multiclass: Optional[bool] = None, -) -> Tuple[torch.Tensor, torch.Tensor, str]: - """Convert preds and target tensors into common format. 
- - Preds and targets are supposed to fall into one of these categories (and are - validated to make sure this is the case): - - * Both preds and target are of shape ``(N,)``, and both are integers (multi-class) - * Both preds and target are of shape ``(N,)``, and target is binary, while preds - are a float (binary) - * preds are of shape ``(N, C)`` and are floats, and target is of shape ``(N,)`` and - is integer (multi-class) - * preds and target are of shape ``(N, ...)``, target is binary and preds is a float - (multi-label) - * preds are of shape ``(N, C, ...)`` and are floats, target is of shape ``(N, ...)`` - and is integer (multi-dimensional multi-class) - * preds and target are of shape ``(N, ...)`` both are integers (multi-dimensional - multi-class) - - To avoid ambiguities, all dimensions of size 1, except the first one, are squeezed out. - - The returned output tensors will be binary tensors of the same shape, either ``(N, C)`` - of ``(N, C, X)``, the details for each case are described below. The function also returns - a ``case`` string, which describes which of the above cases the inputs belonged to - regardless - of whether this was "overridden" by other settings (like ``is_multiclass``). - - In binary case, targets are normally returned as ``(N,1)`` tensor, while preds are transformed - into a binary tensor (elements become 1 if the probability is greater than or equal to - ``threshold`` or 0 otherwise). If ``is_multiclass=True``, then then both targets are preds - become ``(N, 2)`` tensors by a one-hot transformation; with the thresholding being applied to - preds first. - - In multi-class case, normally both preds and targets become ``(N, C)`` binary tensors; targets - by a one-hot transformation and preds by selecting ``top_k`` largest entries (if their original - shape was ``(N,C)``). However, if ``is_multiclass=False``, then targets and preds will be - returned as ``(N,1)`` tensor. - - In multi-label case, normally targets and preds are returned as ``(N, C)`` binary tensors, with - preds being binarized as in the binary case. Here the ``C`` dimension is obtained by flattening - all dimensions after the first one. However if ``is_multiclass=True``, then both are returned as - ``(N, 2, C)``, by an equivalent transformation as in the binary case. - - In multi-dimensional multi-class case, normally both target and preds are returned as - ``(N, C, X)`` tensors, with ``X`` resulting from flattening of all dimensions except ``N`` and - ``C``. The transformations performed here are equivalent to the multi-class case. However, if - ``is_multiclass=False`` (and there are up to two classes), then the data is returned as - ``(N, X)`` binary tensors (multi-label). - - Note that where a one-hot transformation needs to be performed and the number of classes - is not implicitly given by a ``C`` dimension, the new ``C`` dimension will either be - equal to ``num_classes``, if it is given, or the maximum label value in preds and - target. - - Args: - preds: Tensor with predictions (labels or probabilities) - target: Tensor with ground truth labels, always integers (labels) - threshold: - Threshold probability value for transforming probability predictions to binary - (0 or 1) predictions, in the case of binary or multi-label inputs. - num_classes: - Number of classes. If not explicitly set, the number of classes will be infered - either from the shape of inputs, or the maximum label in the ``target`` and ``preds`` - tensor, where applicable. 
- top_k: - Number of highest probability entries for each sample to convert to 1s - relevant - only for (multi-dimensional) multi-class inputs with probability predictions. The - default value (``None``) will be interepreted as 1 for these inputs. - - Should be left unset (``None``) for all other types of inputs. - is_multiclass: - Used only in certain special cases, where you want to treat inputs as a different type - than what they appear to be. See the parameter's - :ref:`documentation section ` - for a more detailed explanation and examples. - - - Returns: - preds: binary tensor of shape ``(N, C)`` or ``(N, C, X)`` - target: binary tensor of shape ``(N, C)`` or ``(N, C, X)`` - case: The case the inputs fall in, one of ``'binary'``, ``'multi-class'``, ``'multi-label'`` or - ``'multi-dim multi-class'`` - """ - # Remove excess dimensions - if preds.shape[0] == 1: - preds, target = preds.squeeze().unsqueeze(0), target.squeeze().unsqueeze(0) - else: - preds, target = preds.squeeze(), target.squeeze() - - # Convert half precision tensors to full precision, as not all ops are supported - # for example, min() is not supported - if preds.dtype == torch.float16: - preds = preds.float() - - case = _check_classification_inputs( - preds, - target, - threshold=threshold, - num_classes=num_classes, - is_multiclass=is_multiclass, - top_k=top_k, - ) - - if case in (DataType.BINARY, DataType.MULTILABEL) and not top_k: - preds = (preds >= threshold).int() - num_classes = num_classes if not is_multiclass else 2 - - if case == DataType.MULTILABEL and top_k: - preds = select_topk(preds, top_k) - - if case in (DataType.MULTICLASS, DataType.MULTIDIM_MULTICLASS) or is_multiclass: - if preds.is_floating_point(): - num_classes = preds.shape[1] - preds = select_topk(preds, top_k or 1) - else: - num_classes = num_classes if num_classes else max(preds.max(), target.max()) + 1 - preds = to_onehot(preds, max(2, num_classes)) - - target = to_onehot(target, max(2, num_classes)) - - if is_multiclass is False: - preds, target = preds[:, 1, ...], target[:, 1, ...] - - if (case in (DataType.MULTICLASS, DataType.MULTIDIM_MULTICLASS) and is_multiclass is not False) or is_multiclass: - target = target.reshape(target.shape[0], target.shape[1], -1) - preds = preds.reshape(preds.shape[0], preds.shape[1], -1) - else: - target = target.reshape(target.shape[0], -1) - preds = preds.reshape(preds.shape[0], -1) - - # Some operatins above create an extra dimension for MC/binary case - this removes it - if preds.ndim > 2: - preds, target = preds.squeeze(-1), target.squeeze(-1) - - return preds.int(), target.int(), case - - -def _reduce_stat_scores( - numerator: torch.Tensor, - denominator: torch.Tensor, - weights: Optional[torch.Tensor], - average: str, - mdmc_average: Optional[str], - zero_division: int = 0, -) -> torch.Tensor: - """ - Reduces scores of type ``numerator/denominator`` or - ``weights * (numerator/denominator)``, if ``average='weighted'``. - - Args: - numerator: A tensor with numerator numbers. - denominator: A tensor with denominator numbers. If a denominator is - negative, the class will be ignored (if averaging), or its score - will be returned as ``nan`` (if ``average=None``). - If the denominator is zero, then ``zero_division`` score will be - used for those elements. - weights: - A tensor of weights to be used if ``average='weighted'``. - average: - The method to average the scores. Should be one of ``'micro'``, ``'macro'``, - ``'weighted'``, ``'none'``, ``None`` or ``'samples'``. 
The behavior - corresponds to `sklearn averaging methods `__. - mdmc_average: - The method to average the scores if inputs were multi-dimensional multi-class (MDMC). - Should be either ``'global'`` or ``'samplewise'``. If inputs were not - multi-dimensional multi-class, it should be ``None`` (default). - zero_division: - The value to use for the score if denominator equals zero. - """ - numerator, denominator = numerator.float(), denominator.float() - zero_div_mask = denominator == 0 - ignore_mask = denominator < 0 - - if weights is None: - weights = torch.ones_like(denominator) - else: - weights = weights.float() - - numerator = torch.where(zero_div_mask, torch.tensor(float(zero_division), device=numerator.device), numerator) - denominator = torch.where(zero_div_mask | ignore_mask, torch.tensor(1.0, device=denominator.device), denominator) - weights = torch.where(ignore_mask, torch.tensor(0.0, device=weights.device), weights) - - if average not in (AverageMethod.MICRO, AverageMethod.NONE, None): - weights = weights / weights.sum(dim=-1, keepdim=True) - - scores = weights * (numerator / denominator) - - # This is in case where sum(weights) = 0, which happens if we ignore the only present class with average='weighted' - scores = torch.where(torch.isnan(scores), torch.tensor(float(zero_division), device=scores.device), scores) - - if mdmc_average == MDMCAverageMethod.SAMPLEWISE: - scores = scores.mean(dim=0) - ignore_mask = ignore_mask.sum(dim=0).bool() - - if average in (AverageMethod.NONE, None): - scores = torch.where(ignore_mask, torch.tensor(np.nan, device=scores.device), scores) - else: - scores = scores.sum() - - return scores diff --git a/pytorch_lightning/metrics/classification/iou.py b/pytorch_lightning/metrics/classification/iou.py index a261b767a8190..f1d9d0945511a 100644 --- a/pytorch_lightning/metrics/classification/iou.py +++ b/pytorch_lightning/metrics/classification/iou.py @@ -13,70 +13,14 @@ # limitations under the License. from typing import Any, Optional -import torch +from torchmetrics import IoU as _IoU -from pytorch_lightning.metrics.classification.confusion_matrix import ConfusionMatrix -from pytorch_lightning.metrics.functional.iou import _iou_from_confmat +from pytorch_lightning.metrics.utils import deprecated_metrics -class IoU(ConfusionMatrix): - r""" - Computes `Intersection over union, or Jaccard index calculation `_: - - .. math:: J(A,B) = \frac{|A\cap B|}{|A\cup B|} - - Where: :math:`A` and :math:`B` are both tensors of the same size, containing integer class values. - They may be subject to conversion from input data (see description below). Note that it is different from box IoU. - - Works with binary, multiclass and multi-label data. - Accepts probabilities from a model output or integer class values in prediction. - Works with multi-dimensional preds and target. - - Forward accepts - - - ``preds`` (float or long tensor): ``(N, ...)`` or ``(N, C, ...)`` where C is the number of classes - - ``target`` (long tensor): ``(N, ...)`` - - If preds and target are the same shape and preds is a float tensor, we use the ``self.threshold`` argument - to convert into integer labels. This is the case for binary and multi-label probabilities. - - If preds has an extra dimension as in the case of multi-class scores we perform an argmax on ``dim=1``. - - Args: - num_classes: Number of classes in the dataset. - ignore_index: optional int specifying a target class to ignore. 
If given, this class index does not contribute - to the returned score, regardless of reduction method. Has no effect if given an int that is not in the - range [0, num_classes-1]. By default, no index is ignored, and all classes are used. - absent_score: score to use for an individual class, if no instances of the class index were present in - `pred` AND no instances of the class index were present in `target`. For example, if we have 3 classes, - [0, 0] for `pred`, and [0, 2] for `target`, then class 1 would be assigned the `absent_score`. - threshold: - Threshold value for binary or multi-label probabilities. - reduction: a method to reduce metric score over labels. - - - ``'elementwise_mean'``: takes the mean (default) - - ``'sum'``: takes the sum - - ``'none'``: no reduction will be applied - - compute_on_step: - Forward only calls ``update()`` and return None if this is set to False. - dist_sync_on_step: - Synchronize metric state across processes at each ``forward()`` - before returning the value at the step. - process_group: - Specify the process group on which synchronization is called. default: None (which selects the entire world) - - Example: - >>> from pytorch_lightning.metrics import IoU - >>> target = torch.randint(0, 2, (10, 25, 25)) - >>> pred = torch.tensor(target) - >>> pred[2:5, 7:13, 9:15] = 1 - pred[2:5, 7:13, 9:15] - >>> iou = IoU(num_classes=2) - >>> iou(pred, target) - tensor(0.9660) - - """ +class IoU(_IoU): + @deprecated_metrics(target=_IoU) def __init__( self, num_classes: int, @@ -88,20 +32,9 @@ def __init__( dist_sync_on_step: bool = False, process_group: Optional[Any] = None, ): - super().__init__( - num_classes=num_classes, - normalize=None, - threshold=threshold, - compute_on_step=compute_on_step, - dist_sync_on_step=dist_sync_on_step, - process_group=process_group, - ) - self.reduction = reduction - self.ignore_index = ignore_index - self.absent_score = absent_score - - def compute(self) -> torch.Tensor: """ - Computes intersection over union (IoU) + This implementation refers to :class:`~torchmetrics.IoU`. + + .. deprecated:: + Use :class:`~torchmetrics.IoU`. Will be removed in v1.5.0. """ - return _iou_from_confmat(self.confmat, self.num_classes, self.ignore_index, self.absent_score, self.reduction) diff --git a/pytorch_lightning/metrics/classification/precision_recall.py b/pytorch_lightning/metrics/classification/precision_recall.py index 11862769e62a8..7b95d21dae97c 100644 --- a/pytorch_lightning/metrics/classification/precision_recall.py +++ b/pytorch_lightning/metrics/classification/precision_recall.py @@ -13,116 +13,15 @@ # limitations under the License. from typing import Any, Callable, Optional -import torch +from torchmetrics import Precision as _Precision +from torchmetrics import Recall as _Recall -from pytorch_lightning.metrics.classification.stat_scores import StatScores -from pytorch_lightning.metrics.functional.precision_recall import _precision_compute, _recall_compute +from pytorch_lightning.metrics.utils import deprecated_metrics -class Precision(StatScores): - r""" - Computes `Precision `_: - - .. math:: \text{Precision} = \frac{\text{TP}}{\text{TP} + \text{FP}} - - Where :math:`\text{TP}` and :math:`\text{FP}` represent the number of true positives and - false positives respecitively. With the use of ``top_k`` parameter, this metric can - generalize to Precision@K. 
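For the example shown further down in this docstring (``preds = [2, 0, 2, 1]``, ``target = [1, 1, 2, 0]``), the per-class precisions are 0/1 for class 0, 0/1 for class 1 and 1/2 for class 2, so macro averaging gives (0 + 0 + 0.5) / 3 ≈ 0.1667, while micro averaging pools all predictions into 1 correct out of 4, i.e. 0.2500; these are the two values the example returns.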
- - The reduction method (how the precision scores are aggregated) is controlled by the - ``average`` parameter, and additionally by the ``mdmc_average`` parameter in the - multi-dimensional multi-class case. Accepts all inputs listed in :ref:`extensions/metrics:input types`. - - Args: - num_classes: - Number of classes. Necessary for ``'macro'``, ``'weighted'`` and ``None`` average methods. - threshold: - Threshold probability value for transforming probability predictions to binary - (0,1) predictions, in the case of binary or multi-label inputs. - average: - Defines the reduction that is applied. Should be one of the following: - - - ``'micro'`` [default]: Calculate the metric globally, accross all samples and classes. - - ``'macro'``: Calculate the metric for each class separately, and average the - metrics accross classes (with equal weights for each class). - - ``'weighted'``: Calculate the metric for each class separately, and average the - metrics accross classes, weighting each class by its support (``tp + fn``). - - ``'none'`` or ``None``: Calculate the metric for each class separately, and return - the metric for every class. - - ``'samples'``: Calculate the metric for each sample, and average the metrics - across samples (with equal weights for each sample). - - Note that what is considered a sample in the multi-dimensional multi-class case - depends on the value of ``mdmc_average``. - multilabel: - .. warning :: This parameter is deprecated and has no effect. Will be removed in v1.4.0. - - mdmc_average: - Defines how averaging is done for multi-dimensional multi-class inputs (on top of the - ``average`` parameter). Should be one of the following: - - - ``None`` [default]: Should be left unchanged if your data is not multi-dimensional - multi-class. - - - ``'samplewise'``: In this case, the statistics are computed separately for each - sample on the ``N`` axis, and then averaged over samples. - The computation for each sample is done by treating the flattened extra axes ``...`` - (see :ref:`extensions/metrics:input types`) as the ``N`` dimension within the sample, - and computing the metric for the sample based on that. - - - ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs - (see :ref:`extensions/metrics:input types`) - are flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they - were ``(N_X, C)``. From here on the ``average`` parameter applies as usual. - - ignore_index: - Integer specifying a target class to ignore. If given, this class index does not contribute - to the returned score, regardless of reduction method. If an index is ignored, and ``average=None`` - or ``'none'``, the score for the ignored class will be returned as ``nan``. - - top_k: - Number of highest probability entries for each sample to convert to 1s - relevant - only for inputs with probability predictions. If this parameter is set for multi-label - inputs, it will take precedence over ``threshold``. For (multi-dim) multi-class inputs, - this parameter defaults to 1. - - Should be left unset (``None``) for inputs with label predictions. - is_multiclass: - Used only in certain special cases, where you want to treat inputs as a different type - than what they appear to be. See the parameter's - :ref:`documentation section ` - for a more detailed explanation and examples. - - compute_on_step: - Forward only calls ``update()`` and return ``None`` if this is set to ``False``. 
- dist_sync_on_step: - Synchronize metric state across processes at each ``forward()`` - before returning the value at the step - process_group: - Specify the process group on which synchronization is called. - default: ``None`` (which selects the entire world) - dist_sync_fn: - Callback that performs the allgather operation on the metric state. When ``None``, DDP - will be used to perform the allgather. - - Raises: - ValueError: - If ``average`` is none of ``"micro"``, ``"macro"``, ``"weighted"``, ``"samples"``, ``"none"``, ``None``. - - Example: - - >>> from pytorch_lightning.metrics import Precision - >>> preds = torch.tensor([2, 0, 2, 1]) - >>> target = torch.tensor([1, 1, 2, 0]) - >>> precision = Precision(average='macro', num_classes=3) - >>> precision(preds, target) - tensor(0.1667) - >>> precision = Precision(average='micro') - >>> precision(preds, target) - tensor(0.2500) - - """ +class Precision(_Precision): + @deprecated_metrics(target=_Precision) def __init__( self, num_classes: Optional[int] = None, @@ -138,146 +37,17 @@ def __init__( process_group: Optional[Any] = None, dist_sync_fn: Callable = None, ): - allowed_average = ["micro", "macro", "weighted", "samples", "none", None] - if average not in allowed_average: - raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") - - super().__init__( - reduce="macro" if average in ["weighted", "none", None] else average, - mdmc_reduce=mdmc_average, - threshold=threshold, - top_k=top_k, - num_classes=num_classes, - is_multiclass=is_multiclass, - ignore_index=ignore_index, - compute_on_step=compute_on_step, - dist_sync_on_step=dist_sync_on_step, - process_group=process_group, - dist_sync_fn=dist_sync_fn, - ) - - self.average = average - - def compute(self) -> torch.Tensor: """ - Computes the precision score based on inputs passed in to ``update`` previously. + This implementation refers to :class:`~torchmetrics.Precision`. - Return: - The shape of the returned tensor depends on the ``average`` parameter - - - If ``average in ['micro', 'macro', 'weighted', 'samples']``, a one-element tensor will be returned - - If ``average in ['none', None]``, the shape will be ``(C,)``, where ``C`` stands for the number - of classes + .. deprecated:: + Use :class:`~torchmetrics.Precision`. Will be removed in v1.5.0. """ - tp, fp, tn, fn = self._get_final_stats() - return _precision_compute(tp, fp, tn, fn, self.average, self.mdmc_reduce) - -class Recall(StatScores): - r""" - Computes `Recall `_: - .. math:: \text{Recall} = \frac{\text{TP}}{\text{TP} + \text{FN}} - - Where :math:`\text{TP}` and :math:`\text{FN}` represent the number of true positives and - false negatives respecitively. With the use of ``top_k`` parameter, this metric can - generalize to Recall@K. - - The reduction method (how the recall scores are aggregated) is controlled by the - ``average`` parameter, and additionally by the ``mdmc_average`` parameter in the - multi-dimensional multi-class case. Accepts all inputs listed in :ref:`extensions/metrics:input types`. - - Args: - num_classes: - Number of classes. Necessary for ``'macro'``, ``'weighted'`` and ``None`` average methods. - threshold: - Threshold probability value for transforming probability predictions to binary - (0,1) predictions, in the case of binary or multi-label inputs. - average: - Defines the reduction that is applied. Should be one of the following: - - - ``'micro'`` [default]: Calculate the metric globally, accross all samples and classes. 
- - ``'macro'``: Calculate the metric for each class separately, and average the - metrics accross classes (with equal weights for each class). - - ``'weighted'``: Calculate the metric for each class separately, and average the - metrics accross classes, weighting each class by its support (``tp + fn``). - - ``'none'`` or ``None``: Calculate the metric for each class separately, and return - the metric for every class. - - ``'samples'``: Calculate the metric for each sample, and average the metrics - across samples (with equal weights for each sample). - - Note that what is considered a sample in the multi-dimensional multi-class case - depends on the value of ``mdmc_average``. - multilabel: - .. warning :: This parameter is deprecated and has no effect. Will be removed in v1.4.0. - - mdmc_average: - Defines how averaging is done for multi-dimensional multi-class inputs (on top of the - ``average`` parameter). Should be one of the following: - - - ``None`` [default]: Should be left unchanged if your data is not multi-dimensional - multi-class. - - - ``'samplewise'``: In this case, the statistics are computed separately for each - sample on the ``N`` axis, and then averaged over samples. - The computation for each sample is done by treating the flattened extra axes ``...`` - (see :ref:`extensions/metrics:input types`) as the ``N`` dimension within the sample, - and computing the metric for the sample based on that. - - - ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs - (see :ref:`extensions/metrics:input types`) - are flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they - were ``(N_X, C)``. From here on the ``average`` parameter applies as usual. - - ignore_index: - Integer specifying a target class to ignore. If given, this class index does not contribute - to the returned score, regardless of reduction method. If an index is ignored, and ``average=None`` - or ``'none'``, the score for the ignored class will be returned as ``nan``. - - top_k: - Number of highest probability entries for each sample to convert to 1s - relevant - only for inputs with probability predictions. If this parameter is set for multi-label - inputs, it will take precedence over ``threshold``. For (multi-dim) multi-class inputs, - this parameter defaults to 1. - - Should be left unset (``None``) for inputs with label predictions. - - is_multiclass: - Used only in certain special cases, where you want to treat inputs as a different type - than what they appear to be. See the parameter's - :ref:`documentation section ` - for a more detailed explanation and examples. - - compute_on_step: - Forward only calls ``update()`` and return ``None`` if this is set to ``False``. - dist_sync_on_step: - Synchronize metric state across processes at each ``forward()`` - before returning the value at the step - process_group: - Specify the process group on which synchronization is called. - default: ``None`` (which selects the entire world) - dist_sync_fn: - Callback that performs the allgather operation on the metric state. When ``None``, DDP - will be used to perform the allgather. - - Raises: - ValueError: - If ``average`` is none of ``"micro"``, ``"macro"``, ``"weighted"``, ``"samples"``, ``"none"``, ``None``. 
- - Example: - - >>> from pytorch_lightning.metrics import Recall - >>> preds = torch.tensor([2, 0, 2, 1]) - >>> target = torch.tensor([1, 1, 2, 0]) - >>> recall = Recall(average='macro', num_classes=3) - >>> recall(preds, target) - tensor(0.3333) - >>> recall = Recall(average='micro') - >>> recall(preds, target) - tensor(0.2500) - - """ +class Recall(_Recall): + @deprecated_metrics(target=_Recall) def __init__( self, num_classes: Optional[int] = None, @@ -293,36 +63,9 @@ def __init__( process_group: Optional[Any] = None, dist_sync_fn: Callable = None, ): - allowed_average = ["micro", "macro", "weighted", "samples", "none", None] - if average not in allowed_average: - raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") - - super().__init__( - reduce="macro" if average in ["weighted", "none", None] else average, - mdmc_reduce=mdmc_average, - threshold=threshold, - top_k=top_k, - num_classes=num_classes, - is_multiclass=is_multiclass, - ignore_index=ignore_index, - compute_on_step=compute_on_step, - dist_sync_on_step=dist_sync_on_step, - process_group=process_group, - dist_sync_fn=dist_sync_fn, - ) - - self.average = average - - def compute(self) -> torch.Tensor: """ - Computes the recall score based on inputs passed in to ``update`` previously. - - Return: - The shape of the returned tensor depends on the ``average`` parameter + This implementation refers to :class:`~torchmetrics.Recall`. - - If ``average in ['micro', 'macro', 'weighted', 'samples']``, a one-element tensor will be returned - - If ``average in ['none', None]``, the shape will be ``(C,)``, where ``C`` stands for the number - of classes + .. deprecated:: + Use :class:`~torchmetrics.Recall`. Will be removed in v1.5.0. """ - tp, fp, tn, fn = self._get_final_stats() - return _recall_compute(tp, fp, tn, fn, self.average, self.mdmc_reduce) diff --git a/pytorch_lightning/metrics/classification/precision_recall_curve.py b/pytorch_lightning/metrics/classification/precision_recall_curve.py index 5a02a99ed17fd..285cb2fb78ccc 100644 --- a/pytorch_lightning/metrics/classification/precision_recall_curve.py +++ b/pytorch_lightning/metrics/classification/precision_recall_curve.py @@ -11,80 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Tuple, Union +from typing import Any, Optional -import torch +from torchmetrics import PrecisionRecallCurve as _PrecisionRecallCurve -from pytorch_lightning.metrics.functional.precision_recall_curve import ( - _precision_recall_curve_compute, - _precision_recall_curve_update, -) -from pytorch_lightning.metrics.metric import Metric -from pytorch_lightning.utilities import rank_zero_warn +from pytorch_lightning.metrics.utils import deprecated_metrics -class PrecisionRecallCurve(Metric): - """ - Computes precision-recall pairs for different thresholds. Works for both - binary and multiclass problems. In the case of multiclass, the values will - be calculated based on a one-vs-the-rest approach. - - Forward accepts - - - ``preds`` (float tensor): ``(N, ...)`` (binary) or ``(N, C, ...)`` (multiclass) tensor - with probabilities, where C is the number of classes. - - - ``target`` (long tensor): ``(N, ...)`` or ``(N, C, ...)`` with integer labels - - Args: - num_classes: integer with number of classes. Not nessesary to provide - for binary problems. - pos_label: integer determining the positive class. 
Default is ``None`` - which for binary problem is translate to 1. For multiclass problems - this argument should not be set as we iteratively change it in the - range [0,num_classes-1] - compute_on_step: - Forward only calls ``update()`` and return None if this is set to False. default: True - dist_sync_on_step: - Synchronize metric state across processes at each ``forward()`` - before returning the value at the step. default: False - process_group: - Specify the process group on which synchronization is called. default: None (which selects the entire world) - - Example (binary case): - - >>> from pytorch_lightning.metrics import PrecisionRecallCurve - >>> pred = torch.tensor([0, 1, 2, 3]) - >>> target = torch.tensor([0, 1, 1, 0]) - >>> pr_curve = PrecisionRecallCurve(pos_label=1) - >>> precision, recall, thresholds = pr_curve(pred, target) - >>> precision - tensor([0.6667, 0.5000, 0.0000, 1.0000]) - >>> recall - tensor([1.0000, 0.5000, 0.0000, 0.0000]) - >>> thresholds - tensor([1, 2, 3]) - - Example (multiclass case): - - >>> from pytorch_lightning.metrics import PrecisionRecallCurve - >>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05], - ... [0.05, 0.75, 0.05, 0.05, 0.05], - ... [0.05, 0.05, 0.75, 0.05, 0.05], - ... [0.05, 0.05, 0.05, 0.75, 0.05]]) - >>> target = torch.tensor([0, 1, 3, 2]) - >>> pr_curve = PrecisionRecallCurve(num_classes=5) - >>> precision, recall, thresholds = pr_curve(pred, target) - >>> precision # doctest: +NORMALIZE_WHITESPACE - [tensor([1., 1.]), tensor([1., 1.]), tensor([0.2500, 0.0000, 1.0000]), - tensor([0.2500, 0.0000, 1.0000]), tensor([0., 1.])] - >>> recall - [tensor([1., 0.]), tensor([1., 0.]), tensor([1., 0., 0.]), tensor([1., 0., 0.]), tensor([nan, 0.])] - >>> thresholds - [tensor([0.7500]), tensor([0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500])] - - """ +class PrecisionRecallCurve(_PrecisionRecallCurve): + @deprecated_metrics(target=_PrecisionRecallCurve) def __init__( self, num_classes: Optional[int] = None, @@ -93,60 +29,9 @@ def __init__( dist_sync_on_step: bool = False, process_group: Optional[Any] = None, ): - super().__init__( - compute_on_step=compute_on_step, - dist_sync_on_step=dist_sync_on_step, - process_group=process_group, - ) - - self.num_classes = num_classes - self.pos_label = pos_label - - self.add_state("preds", default=[], dist_reduce_fx=None) - self.add_state("target", default=[], dist_reduce_fx=None) - - rank_zero_warn( - 'Metric `PrecisionRecallCurve` will save all targets and predictions in buffer.' - ' For large datasets this may lead to large memory footprint.' - ) - - def update(self, preds: torch.Tensor, target: torch.Tensor): """ - Update state with predictions and targets. - - Args: - preds: Predictions from model - target: Ground truth values - """ - preds, target, num_classes, pos_label = _precision_recall_curve_update( - preds, target, self.num_classes, self.pos_label - ) - self.preds.append(preds) - self.target.append(target) - self.num_classes = num_classes - self.pos_label = pos_label - - def compute( - self - ) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], Tuple[List[torch.Tensor], List[torch.Tensor], - List[torch.Tensor]]]: - """ - Compute the precision-recall curve - - Returns: - 3-element tuple containing + This implementation refers to :class:`~torchmetrics.PrecisionRecallCurve`. - precision: - tensor where element i is the precision of predictions with - score >= thresholds[i] and the last element is 1. 
- If multiclass, this is a list of such tensors, one for each class. - recall: - tensor where element i is the recall of predictions with - score >= thresholds[i] and the last element is 0. - If multiclass, this is a list of such tensors, one for each class. - thresholds: - Thresholds used for computing precision/recall scores + .. deprecated:: + Use :class:`~torchmetrics.PrecisionRecallCurve`. Will be removed in v1.5.0. """ - preds = torch.cat(self.preds, dim=0) - target = torch.cat(self.target, dim=0) - return _precision_recall_curve_compute(preds, target, self.num_classes, self.pos_label) diff --git a/pytorch_lightning/metrics/classification/roc.py b/pytorch_lightning/metrics/classification/roc.py index 598646cde3861..3f6cf50803c86 100644 --- a/pytorch_lightning/metrics/classification/roc.py +++ b/pytorch_lightning/metrics/classification/roc.py @@ -11,79 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Tuple, Union +from typing import Any, Optional -import torch +from torchmetrics import ROC as _ROC -from pytorch_lightning.metrics.functional.roc import _roc_compute, _roc_update -from pytorch_lightning.metrics.metric import Metric -from pytorch_lightning.utilities import rank_zero_warn +from pytorch_lightning.metrics.utils import deprecated_metrics -class ROC(Metric): - """ - Computes the Receiver Operating Characteristic (ROC). Works for both - binary and multiclass problems. In the case of multiclass, the values will - be calculated based on a one-vs-the-rest approach. - - Forward accepts - - - ``preds`` (float tensor): ``(N, ...)`` (binary) or ``(N, C, ...)`` (multiclass) tensor - with probabilities, where C is the number of classes. - - - ``target`` (long tensor): ``(N, ...)`` or ``(N, C, ...)`` with integer labels - - Args: - num_classes: integer with number of classes. Not nessesary to provide - for binary problems. - pos_label: integer determining the positive class. Default is ``None`` - which for binary problem is translate to 1. For multiclass problems - this argument should not be set as we iteratively change it in the - range [0,num_classes-1] - compute_on_step: - Forward only calls ``update()`` and return None if this is set to False. default: True - dist_sync_on_step: - Synchronize metric state across processes at each ``forward()`` - before returning the value at the step. default: False - process_group: - Specify the process group on which synchronization is called. default: None (which selects the entire world) - - Example (binary case): - - >>> from pytorch_lightning.metrics import ROC - >>> pred = torch.tensor([0, 1, 2, 3]) - >>> target = torch.tensor([0, 1, 1, 1]) - >>> roc = ROC(pos_label=1) - >>> fpr, tpr, thresholds = roc(pred, target) - >>> fpr - tensor([0., 0., 0., 0., 1.]) - >>> tpr - tensor([0.0000, 0.3333, 0.6667, 1.0000, 1.0000]) - >>> thresholds - tensor([4, 3, 2, 1, 0]) - - Example (multiclass case): - - >>> from pytorch_lightning.metrics import ROC - >>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05], - ... [0.05, 0.75, 0.05, 0.05], - ... [0.05, 0.05, 0.75, 0.05], - ... 
[0.05, 0.05, 0.05, 0.75]]) - >>> target = torch.tensor([0, 1, 3, 2]) - >>> roc = ROC(num_classes=4) - >>> fpr, tpr, thresholds = roc(pred, target) - >>> fpr - [tensor([0., 0., 1.]), tensor([0., 0., 1.]), tensor([0.0000, 0.3333, 1.0000]), tensor([0.0000, 0.3333, 1.0000])] - >>> tpr - [tensor([0., 1., 1.]), tensor([0., 1., 1.]), tensor([0., 0., 1.]), tensor([0., 0., 1.])] - >>> thresholds # doctest: +NORMALIZE_WHITESPACE - [tensor([1.7500, 0.7500, 0.0500]), - tensor([1.7500, 0.7500, 0.0500]), - tensor([1.7500, 0.7500, 0.0500]), - tensor([1.7500, 0.7500, 0.0500])] - - """ +class ROC(_ROC): + @deprecated_metrics(target=_ROC) def __init__( self, num_classes: Optional[int] = None, @@ -92,56 +29,9 @@ def __init__( dist_sync_on_step: bool = False, process_group: Optional[Any] = None, ): - super().__init__( - compute_on_step=compute_on_step, - dist_sync_on_step=dist_sync_on_step, - process_group=process_group, - ) - - self.num_classes = num_classes - self.pos_label = pos_label - - self.add_state("preds", default=[], dist_reduce_fx=None) - self.add_state("target", default=[], dist_reduce_fx=None) - - rank_zero_warn( - 'Metric `ROC` will save all targets and predictions in buffer.' - ' For large datasets this may lead to large memory footprint.' - ) - - def update(self, preds: torch.Tensor, target: torch.Tensor): """ - Update state with predictions and targets. - - Args: - preds: Predictions from model - target: Ground truth values - """ - preds, target, num_classes, pos_label = _roc_update(preds, target, self.num_classes, self.pos_label) - self.preds.append(preds) - self.target.append(target) - self.num_classes = num_classes - self.pos_label = pos_label - - def compute( - self - ) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], Tuple[List[torch.Tensor], List[torch.Tensor], - List[torch.Tensor]]]: - """ - Compute the receiver operating characteristic - - Returns: - 3-element tuple containing + This implementation refers to :class:`~torchmetrics.ROC`. - fpr: - tensor with false positive rates. - If multiclass, this is a list of such tensors, one for each class. - tpr: - tensor with true positive rates. - If multiclass, this is a list of such tensors, one for each class. - thresholds: - thresholds used for computing false- and true postive rates + .. deprecated:: + Use :class:`~torchmetrics.ROC`. Will be removed in v1.5.0. """ - preds = torch.cat(self.preds, dim=0) - target = torch.cat(self.target, dim=0) - return _roc_compute(preds, target, self.num_classes, self.pos_label) diff --git a/pytorch_lightning/metrics/classification/stat_scores.py b/pytorch_lightning/metrics/classification/stat_scores.py index 4ac47ea466ada..1eed815d4b4cd 100644 --- a/pytorch_lightning/metrics/classification/stat_scores.py +++ b/pytorch_lightning/metrics/classification/stat_scores.py @@ -11,125 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Optional, Tuple +from typing import Any, Callable, Optional -import torch +from torchmetrics import StatScores as _StatScores -from pytorch_lightning.metrics.functional.stat_scores import _stat_scores_compute, _stat_scores_update -from pytorch_lightning.metrics.metric import Metric +from pytorch_lightning.metrics.utils import deprecated_metrics -class StatScores(Metric): - """Computes the number of true positives, false positives, true negatives, false negatives. 
- Related to `Type I and Type II errors `__ - and the `confusion matrix `__. - - The reduction method (how the statistics are aggregated) is controlled by the - ``reduce`` parameter, and additionally by the ``mdmc_reduce`` parameter in the - multi-dimensional multi-class case. - - Accepts all inputs listed in :ref:`extensions/metrics:input types`. - - Args: - threshold: - Threshold probability value for transforming probability predictions to binary - (0 or 1) predictions, in the case of binary or multi-label inputs. - - top_k: - Number of highest probability entries for each sample to convert to 1s - relevant - only for inputs with probability predictions. If this parameter is set for multi-label - inputs, it will take precedence over ``threshold``. For (multi-dim) multi-class inputs, - this parameter defaults to 1. - - Should be left unset (``None``) for inputs with label predictions. - - reduce: - Defines the reduction that is applied. Should be one of the following: - - - ``'micro'`` [default]: Counts the statistics by summing over all [sample, class] - combinations (globally). Each statistic is represented by a single integer. - - ``'macro'``: Counts the statistics for each class separately (over all samples). - Each statistic is represented by a ``(C,)`` tensor. Requires ``num_classes`` - to be set. - - ``'samples'``: Counts the statistics for each sample separately (over all classes). - Each statistic is represented by a ``(N, )`` 1d tensor. - - Note that what is considered a sample in the multi-dimensional multi-class case - depends on the value of ``mdmc_reduce``. - - num_classes: - Number of classes. Necessary for (multi-dimensional) multi-class or multi-label data. - - ignore_index: - Specify a class (label) to ignore. If given, this class index does not contribute - to the returned score, regardless of reduction method. If an index is ignored, and - ``reduce='macro'``, the class statistics for the ignored class will all be returned - as ``-1``. - - mdmc_reduce: - Defines how the multi-dimensional multi-class inputs are handeled. Should be - one of the following: - - - ``None`` [default]: Should be left unchanged if your data is not multi-dimensional - multi-class (see :ref:`extensions/metrics:input types` for the definition of input types). - - - ``'samplewise'``: In this case, the statistics are computed separately for each - sample on the ``N`` axis, and then the outputs are concatenated together. In each - sample the extra axes ``...`` are flattened to become the sub-sample axis, and - statistics for each sample are computed by treating the sub-sample axis as the - ``N`` axis for that sample. - - - ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs are - flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they - were ``(N_X, C)``. From here on the ``reduce`` parameter applies as usual. - - is_multiclass: - Used only in certain special cases, where you want to treat inputs as a different type - than what they appear to be. See the parameter's - :ref:`documentation section ` - for a more detailed explanation and examples. - - compute_on_step: - Forward only calls ``update()`` and return ``None`` if this is set to ``False``. - dist_sync_on_step: - Synchronize metric state across processes at each ``forward()`` - before returning the value at the step - process_group: - Specify the process group on which synchronization is called. 
- default: ``None`` (which selects the entire world) - dist_sync_fn: - Callback that performs the allgather operation on the metric state. When ``None``, DDP - will be used to perform the allgather. - - Raises: - ValueError: - If ``threshold`` is not a ``float`` between ``0`` and ``1``. - ValueError: - If ``reduce`` is none of ``"micro"``, ``"macro"`` or ``"samples"``. - ValueError: - If ``mdmc_reduce`` is none of ``None``, ``"samplewise"``, ``"global"``. - ValueError: - If ``reduce`` is set to ``"macro"`` and ``num_classes`` is not provided. - ValueError: - If ``num_classes`` is set - and ``ignore_index`` is not in the range ``0`` <= ``ignore_index`` < ``num_classes``. - - Example: - - >>> from pytorch_lightning.metrics.classification import StatScores - >>> preds = torch.tensor([1, 0, 2, 1]) - >>> target = torch.tensor([1, 1, 2, 0]) - >>> stat_scores = StatScores(reduce='macro', num_classes=3) - >>> stat_scores(preds, target) - tensor([[0, 1, 2, 1, 1], - [1, 1, 1, 1, 2], - [1, 0, 3, 0, 1]]) - >>> stat_scores = StatScores(reduce='micro') - >>> stat_scores(preds, target) - tensor([2, 2, 6, 2, 4]) - - """ +class StatScores(_StatScores): + @deprecated_metrics(target=_StatScores) def __init__( self, threshold: float = 0.5, @@ -144,129 +35,9 @@ def __init__( process_group: Optional[Any] = None, dist_sync_fn: Callable = None, ): - super().__init__( - compute_on_step=compute_on_step, - dist_sync_on_step=dist_sync_on_step, - process_group=process_group, - dist_sync_fn=dist_sync_fn, - ) - - self.reduce = reduce - self.mdmc_reduce = mdmc_reduce - self.num_classes = num_classes - self.threshold = threshold - self.is_multiclass = is_multiclass - self.ignore_index = ignore_index - self.top_k = top_k - - if not 0 < threshold < 1: - raise ValueError(f"The `threshold` should be a float in the (0,1) interval, got {threshold}") - - if reduce not in ["micro", "macro", "samples"]: - raise ValueError(f"The `reduce` {reduce} is not valid.") - - if mdmc_reduce not in [None, "samplewise", "global"]: - raise ValueError(f"The `mdmc_reduce` {mdmc_reduce} is not valid.") - - if reduce == "macro" and (not num_classes or num_classes < 1): - raise ValueError("When you set `reduce` as 'macro', you have to provide the number of classes.") - - if num_classes and ignore_index is not None and (not 0 <= ignore_index < num_classes or num_classes == 1): - raise ValueError(f"The `ignore_index` {ignore_index} is not valid for inputs with {num_classes} classes") - - if mdmc_reduce != "samplewise" and reduce != "samples": - if reduce == "micro": - zeros_shape = [] - elif reduce == "macro": - zeros_shape = (num_classes, ) - default, reduce_fn = lambda: torch.zeros(zeros_shape, dtype=torch.long), "sum" - else: - default, reduce_fn = lambda: [], None - - for s in ("tp", "fp", "tn", "fn"): - self.add_state(s, default=default(), dist_reduce_fx=reduce_fn) - - def update(self, preds: torch.Tensor, target: torch.Tensor): - """ - Update state with predictions and targets. See :ref:`extensions/metrics:input types` for more information - on input types. 
- - Args: - preds: Predictions from model (probabilities or labels) - target: Ground truth values - """ - - tp, fp, tn, fn = _stat_scores_update( - preds, - target, - reduce=self.reduce, - mdmc_reduce=self.mdmc_reduce, - threshold=self.threshold, - num_classes=self.num_classes, - top_k=self.top_k, - is_multiclass=self.is_multiclass, - ignore_index=self.ignore_index, - ) - - # Update states - if self.reduce != "samples" and self.mdmc_reduce != "samplewise": - self.tp += tp - self.fp += fp - self.tn += tn - self.fn += fn - else: - self.tp.append(tp) - self.fp.append(fp) - self.tn.append(tn) - self.fn.append(fn) - - def _get_final_stats(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - """Performs concatenation on the stat scores if neccesary, - before passing them to a compute function. """ + This implementation refers to :class:`~torchmetrics.StatScores`. - if isinstance(self.tp, list): - tp = torch.cat(self.tp) - fp = torch.cat(self.fp) - tn = torch.cat(self.tn) - fn = torch.cat(self.fn) - else: - tp, fp, tn, fn = self.tp, self.fp, self.tn, self.fn - - return tp, fp, tn, fn - - def compute(self) -> torch.Tensor: - """ - Computes the stat scores based on inputs passed in to ``update`` previously. - - Return: - The metric returns a tensor of shape ``(..., 5)``, where the last dimension corresponds - to ``[tp, fp, tn, fn, sup]`` (``sup`` stands for support and equals ``tp + fn``). The - shape depends on the ``reduce`` and ``mdmc_reduce`` (in case of multi-dimensional - multi-class data) parameters: - - - If the data is not multi-dimensional multi-class, then - - - If ``reduce='micro'``, the shape will be ``(5, )`` - - If ``reduce='macro'``, the shape will be ``(C, 5)``, - where ``C`` stands for the number of classes - - If ``reduce='samples'``, the shape will be ``(N, 5)``, where ``N`` stands for - the number of samples - - - If the data is multi-dimensional multi-class and ``mdmc_reduce='global'``, then - - - If ``reduce='micro'``, the shape will be ``(5, )`` - - If ``reduce='macro'``, the shape will be ``(C, 5)`` - - If ``reduce='samples'``, the shape will be ``(N*X, 5)``, where ``X`` stands for - the product of sizes of all "extra" dimensions of the data (i.e. all dimensions - except for ``C`` and ``N``) - - - If the data is multi-dimensional multi-class and ``mdmc_reduce='samplewise'``, then - - - If ``reduce='micro'``, the shape will be ``(N, 5)`` - - If ``reduce='macro'``, the shape will be ``(N, C, 5)`` - - If ``reduce='samples'``, the shape will be ``(N, X, 5)`` - + .. deprecated:: + Use :class:`~torchmetrics.StatScores`. Will be removed in v1.5.0. """ - tp, fp, tn, fn = self._get_final_stats() - return _stat_scores_compute(tp, fp, tn, fn) diff --git a/pytorch_lightning/metrics/compositional.py b/pytorch_lightning/metrics/compositional.py index df98d16a3ef7e..56bb1912e48e6 100644 --- a/pytorch_lightning/metrics/compositional.py +++ b/pytorch_lightning/metrics/compositional.py @@ -1,16 +1,28 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. from typing import Callable, Union import torch +from torchmetrics import Metric +from torchmetrics.metric import CompositionalMetric as _CompositionalMetric -from pytorch_lightning.metrics.metric import Metric +from pytorch_lightning.metrics.utils import deprecated_metrics -class CompositionalMetric(Metric): - """Composition of two metrics with a specific operator - which will be executed upon metric's compute - - """ +class CompositionalMetric(_CompositionalMetric): + @deprecated_metrics(target=_CompositionalMetric) def __init__( self, operator: Callable, @@ -18,75 +30,6 @@ def __init__( metric_b: Union[Metric, int, float, torch.Tensor, None], ): """ - - Args: - operator: the operator taking in one (if metric_b is None) - or two arguments. Will be applied to outputs of metric_a.compute() - and (optionally if metric_b is not None) metric_b.compute() - metric_a: first metric whose compute() result is the first argument of operator - metric_b: second metric whose compute() result is the second argument of operator. - For operators taking in only one input, this should be None + .. deprecated:: + Use :class:`torchmetrics.metric.CompositionalMetric`. Will be removed in v1.5.0. """ - super().__init__() - - self.op = operator - - if isinstance(metric_a, torch.Tensor): - self.register_buffer("metric_a", metric_a) - else: - self.metric_a = metric_a - - if isinstance(metric_b, torch.Tensor): - self.register_buffer("metric_b", metric_b) - else: - self.metric_b = metric_b - - def _sync_dist(self, dist_sync_fn=None): - # No syncing required here. syncing will be done in metric_a and metric_b - pass - - def update(self, *args, **kwargs): - if isinstance(self.metric_a, Metric): - self.metric_a.update(*args, **self.metric_a._filter_kwargs(**kwargs)) - - if isinstance(self.metric_b, Metric): - self.metric_b.update(*args, **self.metric_b._filter_kwargs(**kwargs)) - - def compute(self): - - # also some parsing for kwargs? - if isinstance(self.metric_a, Metric): - val_a = self.metric_a.compute() - else: - val_a = self.metric_a - - if isinstance(self.metric_b, Metric): - val_b = self.metric_b.compute() - else: - val_b = self.metric_b - - if val_b is None: - return self.op(val_a) - - return self.op(val_a, val_b) - - def reset(self): - if isinstance(self.metric_a, Metric): - self.metric_a.reset() - - if isinstance(self.metric_b, Metric): - self.metric_b.reset() - - def persistent(self, mode: bool = False): - if isinstance(self.metric_a, Metric): - self.metric_a.persistent(mode=mode) - if isinstance(self.metric_b, Metric): - self.metric_b.persistent(mode=mode) - - def __repr__(self): - repr_str = ( - self.__class__.__name__ - + f"(\n {self.op.__name__}(\n {repr(self.metric_a)},\n {repr(self.metric_b)}\n )\n)" - ) - - return repr_str diff --git a/pytorch_lightning/metrics/functional/accuracy.py b/pytorch_lightning/metrics/functional/accuracy.py index b51ce2e678996..69fa9d75590e0 100644 --- a/pytorch_lightning/metrics/functional/accuracy.py +++ b/pytorch_lightning/metrics/functional/accuracy.py @@ -11,43 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Optional, Tuple +from typing import Optional import torch +from torchmetrics.functional import accuracy as _accuracy -from pytorch_lightning.metrics.classification.helpers import _input_format_classification, DataType - - -def _accuracy_update( - preds: torch.Tensor, target: torch.Tensor, threshold: float, top_k: Optional[int], subset_accuracy: bool -) -> Tuple[torch.Tensor, torch.Tensor]: - - preds, target, mode = _input_format_classification(preds, target, threshold=threshold, top_k=top_k) - - if mode == DataType.MULTILABEL and top_k: - raise ValueError("You can not use the `top_k` parameter to calculate accuracy for multi-label inputs.") - - if mode == DataType.BINARY or (mode == DataType.MULTILABEL and subset_accuracy): - correct = (preds == target).all(dim=1).sum() - total = torch.tensor(target.shape[0], device=target.device) - elif mode == DataType.MULTILABEL and not subset_accuracy: - correct = (preds == target).sum() - total = torch.tensor(target.numel(), device=target.device) - elif mode == DataType.MULTICLASS or (mode == DataType.MULTIDIM_MULTICLASS and not subset_accuracy): - correct = (preds * target).sum() - total = target.sum() - elif mode == DataType.MULTIDIM_MULTICLASS and subset_accuracy: - sample_correct = (preds * target).sum(dim=(1, 2)) - correct = (sample_correct == target.shape[2]).sum() - total = torch.tensor(target.shape[0], device=target.device) - - return correct, total - - -def _accuracy_compute(correct: torch.Tensor, total: torch.Tensor) -> torch.Tensor: - return correct.float() / total +from pytorch_lightning.metrics.utils import deprecated_metrics +@deprecated_metrics(target=_accuracy) def accuracy( preds: torch.Tensor, target: torch.Tensor, @@ -55,66 +27,7 @@ def accuracy( top_k: Optional[int] = None, subset_accuracy: bool = False, ) -> torch.Tensor: - r"""Computes `Accuracy `_: - - .. math:: - \text{Accuracy} = \frac{1}{N}\sum_i^N 1(y_i = \hat{y}_i) - - Where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a - tensor of predictions. - - For multi-class and multi-dimensional multi-class data with probability predictions, the - parameter ``top_k`` generalizes this metric to a Top-K accuracy metric: for each sample the - top-K highest probability items are considered to find the correct label. - - For multi-label and multi-dimensional multi-class inputs, this metric computes the "global" - accuracy by default, which counts all labels or sub-samples separately. This can be - changed to subset accuracy (which requires all labels or sub-samples in the sample to - be correctly predicted) by setting ``subset_accuracy=True``. - - Accepts all input types listed in :ref:`extensions/metrics:input types`. - - Args: - preds: Predictions from model (probabilities, or labels) - target: Ground truth labels - threshold: - Threshold probability value for transforming probability predictions to binary - (0,1) predictions, in the case of binary or multi-label inputs. - top_k: - Number of highest probability predictions considered to find the correct label, relevant - only for (multi-dimensional) multi-class inputs with probability predictions. The - default value (``None``) will be interpreted as 1 for these inputs. - - Should be left at default (``None``) for all other types of inputs. - subset_accuracy: - Whether to compute subset accuracy for multi-label and multi-dimensional - multi-class inputs (has no effect for other input types). 
- - - For multi-label inputs, if the parameter is set to ``True``, then all labels for - each sample must be correctly predicted for the sample to count as correct. If it - is set to ``False``, then all labels are counted separately - this is equivalent to - flattening inputs beforehand (i.e. ``preds = preds.flatten()`` and same for ``target``). - - - For multi-dimensional multi-class inputs, if the parameter is set to ``True``, then all - sub-sample (on the extra axis) must be correct for the sample to be counted as correct. - If it is set to ``False``, then all sub-samples are counter separately - this is equivalent, - in the case of label predictions, to flattening the inputs beforehand (i.e. - ``preds = preds.flatten()`` and same for ``target``). Note that the ``top_k`` parameter - still applies in both cases, if set. - - Example: - - >>> from pytorch_lightning.metrics.functional import accuracy - >>> target = torch.tensor([0, 1, 2, 3]) - >>> preds = torch.tensor([0, 2, 1, 3]) - >>> accuracy(preds, target) - tensor(0.5000) - - >>> target = torch.tensor([0, 1, 2]) - >>> preds = torch.tensor([[0.1, 0.9, 0], [0.3, 0.1, 0.6], [0.2, 0.5, 0.3]]) - >>> accuracy(preds, target, top_k=2) - tensor(0.6667) """ - - correct, total = _accuracy_update(preds, target, threshold, top_k, subset_accuracy) - return _accuracy_compute(correct, total) + .. deprecated:: + Use :func:`torchmetrics.functional.accuracy`. Will be removed in v1.5.0. + """ diff --git a/pytorch_lightning/metrics/functional/auc.py b/pytorch_lightning/metrics/functional/auc.py index 57ff9fe97fac2..7cc6aa458d397 100644 --- a/pytorch_lightning/metrics/functional/auc.py +++ b/pytorch_lightning/metrics/functional/auc.py @@ -11,64 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple - import torch +from torchmetrics.functional import auc as _auc -from pytorch_lightning.metrics.utils import _stable_1d_sort - - -def _auc_update(x: torch.Tensor, y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - if x.ndim > 1 or y.ndim > 1: - raise ValueError( - f'Expected both `x` and `y` tensor to be 1d, but got' - f' tensors with dimention {x.ndim} and {y.ndim}' - ) - if x.numel() != y.numel(): - raise ValueError( - f'Expected the same number of elements in `x` and `y`' - f' tensor but received {x.numel()} and {y.numel()}' - ) - return x, y - - -def _auc_compute(x: torch.Tensor, y: torch.Tensor, reorder: bool = False) -> torch.Tensor: - if reorder: - x, x_idx = _stable_1d_sort(x) - y = y[x_idx] - - dx = x[1:] - x[:-1] - if (dx < 0).any(): - if (dx <= 0).all(): - direction = -1. - else: - raise ValueError( - "The `x` tensor is neither increasing or decreasing." - " Try setting the reorder argument to `True`." - ) - else: - direction = 1. - return direction * torch.trapz(y, x) +from pytorch_lightning.metrics.utils import deprecated_metrics +@deprecated_metrics(target=_auc) def auc(x: torch.Tensor, y: torch.Tensor, reorder: bool = False) -> torch.Tensor: """ - Computes Area Under the Curve (AUC) using the trapezoidal rule - - Args: - x: x-coordinates - y: y-coordinates - reorder: if True, will reorder the arrays - - Return: - Tensor containing AUC score (float) - - Example: - >>> from pytorch_lightning.metrics.functional import auc - >>> x = torch.tensor([0, 1, 2, 3]) - >>> y = torch.tensor([0, 1, 2, 2]) - >>> auc(x, y) - tensor(4.) + .. deprecated:: + Use :func:`torchmetrics.functional.auc`. 
Will be removed in v1.5.0. """ - x, y = _auc_update(x, y) - return _auc_compute(x, y, reorder=reorder) diff --git a/pytorch_lightning/metrics/functional/auroc.py b/pytorch_lightning/metrics/functional/auroc.py index 2a8b18d7c6b66..c49aa1a8fdc48 100644 --- a/pytorch_lightning/metrics/functional/auroc.py +++ b/pytorch_lightning/metrics/functional/auroc.py @@ -11,129 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from distutils.version import LooseVersion -from typing import Optional, Sequence, Tuple +from typing import Optional, Sequence import torch +from torchmetrics.functional import auroc as _auroc -from pytorch_lightning.metrics.classification.helpers import _input_format_classification, DataType -from pytorch_lightning.metrics.functional.auc import auc -from pytorch_lightning.metrics.functional.roc import roc -from pytorch_lightning.utilities import LightningEnum - - -class AverageMethods(LightningEnum): - """ Type of averages """ - MACRO = 'macro' - WEIGHTED = 'weighted' - NONE = None - - -def _auroc_update(preds: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, str]: - # use _input_format_classification for validating the input and get the mode of data - _, _, mode = _input_format_classification(preds, target) - - if mode == 'multi class multi dim': - n_classes = preds.shape[1] - preds = preds.transpose(0, 1).reshape(n_classes, -1).transpose(0, 1) - target = target.flatten() - if mode == 'multi-label' and preds.ndim > 2: - n_classes = preds.shape[1] - preds = preds.transpose(0, 1).reshape(n_classes, -1).transpose(0, 1) - target = target.transpose(0, 1).reshape(n_classes, -1).transpose(0, 1) - - return preds, target, mode - - -def _auroc_compute( - preds: torch.Tensor, - target: torch.Tensor, - mode: str, - num_classes: Optional[int] = None, - pos_label: Optional[int] = None, - average: Optional[str] = 'macro', - max_fpr: Optional[float] = None, - sample_weights: Optional[Sequence] = None, -) -> torch.Tensor: - # binary mode override num_classes - if mode == 'binary': - num_classes = 1 - - # check max_fpr parameter - if max_fpr is not None: - if (not isinstance(max_fpr, float) and 0 < max_fpr <= 1): - raise ValueError(f"`max_fpr` should be a float in range (0, 1], got: {max_fpr}") - - if LooseVersion(torch.__version__) < LooseVersion('1.6.0'): - raise RuntimeError( - "`max_fpr` argument requires `torch.bucketize` which" - " is not available below PyTorch version 1.6" - ) - - # max_fpr parameter is only support for binary - if mode != 'binary': - raise ValueError( - f"Partial AUC computation not available in" - f" multilabel/multiclass setting, 'max_fpr' must be" - f" set to `None`, received `{max_fpr}`." 
- ) - - # calculate fpr, tpr - if mode == 'multi-label': - # for multilabel we iteratively evaluate roc in a binary fashion - output = [ - roc(preds[:, i], target[:, i], num_classes=1, pos_label=1, sample_weights=sample_weights) - for i in range(num_classes) - ] - fpr = [o[0] for o in output] - tpr = [o[1] for o in output] - else: - fpr, tpr, _ = roc(preds, target, num_classes, pos_label, sample_weights) - - # calculate standard roc auc score - if max_fpr is None or max_fpr == 1: - if num_classes != 1: - # calculate auc scores per class - auc_scores = [auc(x, y) for x, y in zip(fpr, tpr)] - - # calculate average - if average == AverageMethods.NONE: - return auc_scores - elif average == AverageMethods.MACRO: - return torch.mean(torch.stack(auc_scores)) - elif average == AverageMethods.WEIGHTED: - if mode == DataType.MULTILABEL: - support = torch.sum(target, dim=0) - else: - support = torch.bincount(target.flatten(), minlength=num_classes) - return torch.sum(torch.stack(auc_scores) * support / support.sum()) - - allowed_average = [e.value for e in AverageMethods] - raise ValueError( - f"Argument `average` expected to be one of the following:" - f" {allowed_average} but got {average}" - ) - - return auc(fpr, tpr) - - max_fpr = torch.tensor(max_fpr, device=fpr.device) - # Add a single point at max_fpr and interpolate its tpr value - stop = torch.bucketize(max_fpr, fpr, out_int32=True, right=True) - weight = (max_fpr - fpr[stop - 1]) / (fpr[stop] - fpr[stop - 1]) - interp_tpr = torch.lerp(tpr[stop - 1], tpr[stop], weight) - tpr = torch.cat([tpr[:stop], interp_tpr.view(1)]) - fpr = torch.cat([fpr[:stop], max_fpr.view(1)]) - - # Compute partial AUC - partial_auc = auc(fpr, tpr) - - # McClish correction: standardize result to be 0.5 if non-discriminant - # and 1 if maximal - min_area = 0.5 * max_fpr**2 - max_area = max_fpr - return 0.5 * (1 + (partial_auc - min_area) / (max_area - min_area)) +from pytorch_lightning.metrics.utils import deprecated_metrics +@deprecated_metrics(target=_auroc) def auroc( preds: torch.Tensor, target: torch.Tensor, @@ -143,47 +29,7 @@ def auroc( max_fpr: Optional[float] = None, sample_weights: Optional[Sequence] = None, ) -> torch.Tensor: - """ Compute `Area Under the Receiver Operating Characteristic Curve (ROC AUC) - `_ - - Args: - preds: predictions from model (logits or probabilities) - target: Ground truth labels - num_classes: integer with number of classes. Not nessesary to provide - for binary problems. - pos_label: integer determining the positive class. Default is ``None`` - which for binary problem is translate to 1. For multiclass problems - this argument should not be set as we iteratively change it in the - range [0,num_classes-1] - average: - - ``'macro'`` computes metric for each class and uniformly averages them - - ``'weighted'`` computes metric for each class and does a weighted-average, - where each class is weighted by their support (accounts for class imbalance) - - ``None`` computes and returns the metric per class - max_fpr: - If not ``None``, calculates standardized partial AUC over the - range [0, max_fpr]. Should be a float between 0 and 1. 
- sample_weight: sample weights for each data point - - Example (binary case): - - >>> from pytorch_lightning.metrics.functional import auroc - >>> preds = torch.tensor([0.13, 0.26, 0.08, 0.19, 0.34]) - >>> target = torch.tensor([0, 0, 1, 1, 1]) - >>> auroc(preds, target, pos_label=1) - tensor(0.5000) - - Example (multiclass case): - - >>> from pytorch_lightning.metrics.functional import auroc - >>> preds = torch.tensor([[0.90, 0.05, 0.05], - ... [0.05, 0.90, 0.05], - ... [0.05, 0.05, 0.90], - ... [0.85, 0.05, 0.10], - ... [0.10, 0.10, 0.80]]) - >>> target = torch.tensor([0, 1, 1, 2, 2]) - >>> auroc(preds, target, num_classes=3) - tensor(0.7778) """ - preds, target, mode = _auroc_update(preds, target) - return _auroc_compute(preds, target, mode, num_classes, pos_label, average, max_fpr, sample_weights) + .. deprecated:: + Use :func:`torchmetrics.functional.auroc`. Will be removed in v1.5.0. + """ diff --git a/pytorch_lightning/metrics/functional/average_precision.py b/pytorch_lightning/metrics/functional/average_precision.py index 2a82c4f38f20e..017b34739a0f4 100644 --- a/pytorch_lightning/metrics/functional/average_precision.py +++ b/pytorch_lightning/metrics/functional/average_precision.py @@ -11,45 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional, Sequence, Tuple, Union +from typing import List, Optional, Sequence, Union import torch +from torchmetrics.functional import average_precision as _average_precision -from pytorch_lightning.metrics.functional.precision_recall_curve import ( - _precision_recall_curve_compute, - _precision_recall_curve_update, -) - - -def _average_precision_update( - preds: torch.Tensor, - target: torch.Tensor, - num_classes: Optional[int] = None, - pos_label: Optional[int] = None, -) -> Tuple[torch.Tensor, torch.Tensor, int, int]: - return _precision_recall_curve_update(preds, target, num_classes, pos_label) - - -def _average_precision_compute( - preds: torch.Tensor, - target: torch.Tensor, - num_classes: int, - pos_label: int, - sample_weights: Optional[Sequence] = None -) -> Union[List[torch.Tensor], torch.Tensor]: - precision, recall, _ = _precision_recall_curve_compute(preds, target, num_classes, pos_label) - # Return the step function integral - # The following works because the last entry of precision is - # guaranteed to be 1, as returned by precision_recall_curve - if num_classes == 1: - return -torch.sum((recall[1:] - recall[:-1]) * precision[:-1]) - - res = [] - for p, r in zip(precision, recall): - res.append(-torch.sum((r[1:] - r[:-1]) * p[:-1])) - return res +from pytorch_lightning.metrics.utils import deprecated_metrics +@deprecated_metrics(target=_average_precision) def average_precision( preds: torch.Tensor, target: torch.Tensor, @@ -58,42 +28,6 @@ def average_precision( sample_weights: Optional[Sequence] = None, ) -> Union[List[torch.Tensor], torch.Tensor]: """ - Computes the average precision score. - - Args: - preds: predictions from model (logits or probabilities) - target: ground truth values - num_classes: integer with number of classes. Not nessesary to provide - for binary problems. - pos_label: integer determining the positive class. Default is ``None`` - which for binary problem is translate to 1. 
For multiclass problems - this argument should not be set as we iteratively change it in the - range [0,num_classes-1] - sample_weights: sample weights for each data point - - Returns: - tensor with average precision. If multiclass will return list - of such tensors, one for each class - - Example (binary case): - - >>> from pytorch_lightning.metrics.functional import average_precision - >>> pred = torch.tensor([0, 1, 2, 3]) - >>> target = torch.tensor([0, 1, 1, 1]) - >>> average_precision(pred, target, pos_label=1) - tensor(1.) - - Example (multiclass case): - - >>> from pytorch_lightning.metrics.functional import average_precision - >>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05], - ... [0.05, 0.75, 0.05, 0.05, 0.05], - ... [0.05, 0.05, 0.75, 0.05, 0.05], - ... [0.05, 0.05, 0.05, 0.75, 0.05]]) - >>> target = torch.tensor([0, 1, 3, 2]) - >>> average_precision(pred, target, num_classes=5) - [tensor(1.), tensor(1.), tensor(0.2500), tensor(0.2500), tensor(nan)] - + .. deprecated:: + Use :func:`torchmetrics.functional.average_precision`. Will be removed in v1.5.0. """ - preds, target, num_classes, pos_label = _average_precision_update(preds, target, num_classes, pos_label) - return _average_precision_compute(preds, target, num_classes, pos_label, sample_weights) diff --git a/pytorch_lightning/metrics/functional/classification.py b/pytorch_lightning/metrics/functional/classification.py index e697ade9be16b..be1fec196a346 100644 --- a/pytorch_lightning/metrics/functional/classification.py +++ b/pytorch_lightning/metrics/functional/classification.py @@ -15,12 +15,13 @@ from typing import Callable, Optional, Sequence, Tuple import torch +from torchmetrics.utilities import class_reduce, reduce +from torchmetrics.utilities.data import get_num_classes, to_categorical from pytorch_lightning.metrics.functional.auc import auc as __auc from pytorch_lightning.metrics.functional.auroc import auroc as __auroc from pytorch_lightning.metrics.functional.iou import iou as __iou -from pytorch_lightning.metrics.utils import class_reduce, get_num_classes, reduce, to_categorical -from pytorch_lightning.utilities import rank_zero_warn +from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn def stat_scores( @@ -30,26 +31,8 @@ def stat_scores( argmax_dim: int = 1, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ - Calculates the number of true positive, false positive, true negative - and false negative for a specific class - - Args: - pred: prediction tensor - target: target tensor - class_index: class to calculate over - argmax_dim: if pred is a tensor of probabilities, this indicates the - axis the argmax transformation will be applied over - - Return: - True Positive, False Positive, True Negative, False Negative, Support - - Example: - - >>> x = torch.tensor([1, 2, 3]) - >>> y = torch.tensor([0, 2, 3]) - >>> tp, fp, tn, fn, sup = stat_scores(x, y, class_index=1) - >>> tp, fp, tn, fn, sup - (tensor(0), tensor(1), tensor(2), tensor(0), tensor(0)) + .. deprecated:: + Use :func:`torchmetrics.functional.stat_scores`. Will be removed in v1.4.0. """ if pred.ndim == target.ndim + 1: pred = to_categorical(pred, argmax_dim=argmax_dim) @@ -72,17 +55,13 @@ def stat_scores_multiple_classes( reduction: str = 'none', ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ - Calculates the number of true positive, false positive, true negative - and false negative for each class - - .. 
warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.functional.stat_scores` - + .. deprecated:: + Use :func:`torchmetrics.functional.stat_scores`. Will be removed in v1.4.0. """ - - rank_zero_warn( + rank_zero_deprecation( "This `stat_scores_multiple_classes` was deprecated in v1.2.0 in favor of" " `from pytorch_lightning.metrics.functional import stat_scores`." - " It will be removed in v1.4.0", DeprecationWarning + " It will be removed in v1.4.0" ) if pred.ndim == target.ndim + 1: pred = to_categorical(pred, argmax_dim=argmax_dim) @@ -162,42 +141,13 @@ def precision_recall( return_state: bool = False ) -> Tuple[torch.Tensor, torch.Tensor]: """ - Computes precision and recall for different thresholds - - .. warning :: Deprecated in favor of - :func:`~pytorch_lightning.metrics.functional.precision_recall`. - Will be removed in v1.4.0. - - Args: - pred: estimated probabilities - target: ground-truth labels - num_classes: number of classes - class_reduction: method to reduce metric score over labels - - - ``'micro'``: calculate metrics globally (default) - - ``'macro'``: calculate metrics for each label, and find their unweighted mean. - - ``'weighted'``: calculate metrics for each label, and find their weighted mean. - - ``'none'``: returns calculated metric per class - - return_support: returns the support for each class, need for fbeta/f1 calculations - return_state: returns a internal state that can be ddp reduced - before doing the final calculation - - Return: - Tensor with precision and recall - - Example: - - >>> x = torch.tensor([0, 1, 2, 3]) - >>> y = torch.tensor([0, 2, 2, 2]) - >>> precision_recall(x, y, class_reduction='macro') - (tensor(0.5000), tensor(0.3333)) - + .. deprecated:: + Use :func:`torchmetrics.functional.precision_recall`. Will be removed in v1.4.0. """ - rank_zero_warn( + rank_zero_deprecation( "This `precision_recall` was deprecated in v1.2.0 in favor of" " `from pytorch_lightning.metrcs.functional import precision_recall`." - " It will be removed in v1.4.0", DeprecationWarning + " It will be removed in v1.4.0" ) tps, fps, tns, fns, sups = stat_scores_multiple_classes(pred=pred, target=target, num_classes=num_classes) @@ -219,37 +169,13 @@ def precision( class_reduction: str = 'micro', ) -> torch.Tensor: """ - Computes precision score. - - .. warning :: Deprecated in favor of - :func:`~pytorch_lightning.metrics.functional.recall`. Will be removed in v1.4.0. - - Args: - pred: estimated probabilities - target: ground-truth labels - num_classes: number of classes - class_reduction: method to reduce metric score over labels - - - ``'micro'``: calculate metrics globally (default) - - ``'macro'``: calculate metrics for each label, and find their unweighted mean. - - ``'weighted'``: calculate metrics for each label, and find their weighted mean. - - ``'none'``: returns calculated metric per class - - Return: - Tensor with precision. - - Example: - - >>> x = torch.tensor([0, 1, 2, 3]) - >>> y = torch.tensor([0, 1, 2, 2]) - >>> precision(x, y) - tensor(0.7500) - + .. deprecated:: + Use :func:`torchmetrics.functional.precision`. Will be removed in v1.4.0. """ - rank_zero_warn( + rank_zero_deprecation( "This `precision` was deprecated in v1.2.0 in favor of" " `from pytorch_lightning.metrics.functional import precision`." 
- " It will be removed in v1.4.0", DeprecationWarning + " It will be removed in v1.4.0" ) return precision_recall(pred=pred, target=target, num_classes=num_classes, class_reduction=class_reduction)[0] @@ -263,36 +189,13 @@ def recall( class_reduction: str = 'micro', ) -> torch.Tensor: """ - Computes recall score. - - .. warning :: Deprecated in favor of - :func:`~pytorch_lightning.metrics.functional.recall`. Will be removed in v1.4.0. - - Args: - pred: estimated probabilities - target: ground-truth labels - num_classes: number of classes - class_reduction: method to reduce metric score over labels - - - ``'micro'``: calculate metrics globally (default) - - ``'macro'``: calculate metrics for each label, and find their unweighted mean. - - ``'weighted'``: calculate metrics for each label, and find their weighted mean. - - ``'none'``: returns calculated metric per class - - Return: - Tensor with recall. - - Example: - - >>> x = torch.tensor([0, 1, 2, 3]) - >>> y = torch.tensor([0, 1, 2, 2]) - >>> recall(x, y) - tensor(0.7500) + .. deprecated:: + Use :func:`torchmetrics.functional.recall`. Will be removed in v1.4.0. """ - rank_zero_warn( + rank_zero_deprecation( "This `recall` was deprecated in v1.2.0 in favor of" " `from pytorch_lightning.metrics.functional import recall`." - " It will be removed in v1.4.0", DeprecationWarning + " It will be removed in v1.4.0" ) return precision_recall(pred=pred, target=target, num_classes=num_classes, class_reduction=class_reduction)[1] @@ -304,37 +207,19 @@ def auc( y: torch.Tensor, ) -> torch.Tensor: """ - Computes Area Under the Curve (AUC) using the trapezoidal rule - - .. warning :: Deprecated in favor of - :func:`~pytorch_lightning.metrics.functional.auc.auc`. Will be removed - in v1.4.0. - - Args: - x: x-coordinates - y: y-coordinates - - Return: - Tensor containing AUC score (float) - - Example: - - >>> x = torch.tensor([0, 1, 2, 3]) - >>> y = torch.tensor([0, 1, 2, 2]) - >>> auc(x, y) - tensor(4.) + .. deprecated:: + Use :func:`torchmetrics.functional.auc`. Will be removed in v1.4.0. """ - rank_zero_warn( + rank_zero_deprecation( "This `auc` was deprecated in v1.2.0 in favor of" " `pytorch_lightning.metrics.functional.auc import auc`." - " It will be removed in v1.4.0", DeprecationWarning + " It will be removed in v1.4.0" ) return __auc(x, y) # todo: remove in 1.4 -def auc_decorator() -> Callable: - rank_zero_warn("This `auc_decorator` was deprecated in v1.2.0." " It will be removed in v1.4.0", DeprecationWarning) +def _auc_decorator() -> Callable: def wrapper(func_to_decorate: Callable) -> Callable: @@ -350,11 +235,7 @@ def new_func(*args, **kwargs) -> torch.Tensor: # todo: remove in 1.4 -def multiclass_auc_decorator() -> Callable: - rank_zero_warn( - "This `multiclass_auc_decorator` was deprecated in v1.2.0." - " It will be removed in v1.4.0", DeprecationWarning - ) +def _multiclass_auc_decorator() -> Callable: def wrapper(func_to_decorate: Callable) -> Callable: @@ -381,34 +262,12 @@ def auroc( max_fpr: float = None, ) -> torch.Tensor: """ - Compute Area Under the Receiver Operating Characteristic Curve (ROC AUC) from prediction scores - - .. warning :: Deprecated in favor of - :func:`~pytorch_lightning.metrics.functional.auroc.auroc`. Will be removed - in v1.4.0. - - Args: - pred: estimated probabilities - target: ground-truth labels - sample_weight: sample weights - pos_label: the label for the positive class - max_fpr: If not ``None``, calculates standardized partial AUC over the - range [0, max_fpr]. Should be a float between 0 and 1. 
- - Return: - Tensor containing ROCAUC score - - Example: - - >>> x = torch.tensor([0, 1, 2, 3]) - >>> y = torch.tensor([0, 1, 1, 0]) - >>> auroc(x, y) - tensor(0.5000) + .. deprecated:: + Use :func:`torchmetrics.functional.auroc`. Will be removed in v1.4.0. """ - rank_zero_warn( - "This `auroc` was deprecated in v1.2.0 in favor of" - " `pytorch_lightning.metrics.functional.auroc import auroc`." - " It will be removed in v1.4.0", DeprecationWarning + rank_zero_deprecation( + "This `auroc` was deprecated in v1.2.0 in favor of `pytorch_lightning.metrics.functional.auroc import auroc`." + " It will be removed in v1.4.0" ) return __auroc( preds=pred, target=target, sample_weights=sample_weight, pos_label=pos_label, max_fpr=max_fpr, num_classes=1 @@ -423,58 +282,15 @@ def multiclass_auroc( num_classes: Optional[int] = None, ) -> torch.Tensor: """ - Compute Area Under the Receiver Operating Characteristic Curve (ROC AUC) from multiclass - prediction scores - - .. warning :: Deprecated in favor of - :func:`~pytorch_lightning.metrics.functional.auroc.auroc`. Will be removed - in v1.4.0. - - Args: - pred: estimated probabilities, with shape [N, C] - target: ground-truth labels, with shape [N,] - sample_weight: sample weights - num_classes: number of classes (default: None, computes automatically from data) - - Return: - Tensor containing ROCAUC score - - Example: - - >>> pred = torch.tensor([[0.85, 0.05, 0.05, 0.05], - ... [0.05, 0.85, 0.05, 0.05], - ... [0.05, 0.05, 0.85, 0.05], - ... [0.05, 0.05, 0.05, 0.85]]) - >>> target = torch.tensor([0, 1, 3, 2]) - >>> multiclass_auroc(pred, target, num_classes=4) - tensor(0.6667) + .. deprecated:: + Use :func:`torchmetrics.functional.auroc`. Will be removed in v1.4.0. """ - rank_zero_warn( + rank_zero_deprecation( "This `multiclass_auroc` was deprecated in v1.2.0 in favor of" " `pytorch_lightning.metrics.functional.auroc import auroc`." - " It will be removed in v1.4.0", DeprecationWarning + " It will be removed in v1.4.0" ) - if not torch.allclose(pred.sum(dim=1), torch.tensor(1.0)): - raise ValueError( - "Multiclass AUROC metric expects the target scores to be" - " probabilities, i.e. they should sum up to 1.0 over classes" - ) - - if torch.unique(target).size(0) != pred.size(1): - raise ValueError( - f"Number of classes found in in 'target' ({torch.unique(target).size(0)})" - f" does not equal the number of columns in 'pred' ({pred.size(1)})." - " Multiclass AUROC is not defined when all of the classes do not" - " occur in the target labels." - ) - - if num_classes is not None and num_classes != pred.size(1): - raise ValueError( - f"Number of classes deduced from 'pred' ({pred.size(1)}) does not equal" - f" the number of classes passed in 'num_classes' ({num_classes})." - ) - return __auroc(preds=pred, target=target, sample_weights=sample_weight, num_classes=num_classes) @@ -487,34 +303,8 @@ def dice_score( reduction: str = 'elementwise_mean', ) -> torch.Tensor: """ - Compute dice score from prediction scores - - Args: - pred: estimated probabilities - target: ground-truth labels - bg: whether to also compute dice for the background - nan_score: score to return, if a NaN occurs during computation - no_fg_score: score to return, if no foreground pixel was found in target - reduction: a method to reduce metric score over labels. 
- - - ``'elementwise_mean'``: takes the mean (default) - - ``'sum'``: takes the sum - - ``'none'``: no reduction will be applied - - Return: - Tensor containing dice score - - Example: - - >>> from pytorch_lightning.metrics.functional import dice_score - >>> pred = torch.tensor([[0.85, 0.05, 0.05, 0.05], - ... [0.05, 0.85, 0.05, 0.05], - ... [0.05, 0.05, 0.85, 0.05], - ... [0.05, 0.05, 0.05, 0.85]]) - >>> target = torch.tensor([0, 1, 3, 2]) - >>> dice_score(pred, target) - tensor(0.3333) - + .. deprecated:: + Use :func:`torchmetrics.functional.dice_score`. Will be removed in v1.4.0. """ num_classes = pred.shape[1] bg = (1 - int(bool(bg))) @@ -544,47 +334,12 @@ def iou( reduction: str = 'elementwise_mean', ) -> torch.Tensor: """ - Intersection over union, or Jaccard index calculation. - - .. warning :: Deprecated in favor of - :func:`~pytorch_lightning.metrics.functional.iou.iou`. Will be removed in - v1.4.0. - - Args: - pred: Tensor containing integer predictions, with shape [N, d1, d2, ...] - target: Tensor containing integer targets, with shape [N, d1, d2, ...] - ignore_index: optional int specifying a target class to ignore. If given, this class index does not contribute - to the returned score, regardless of reduction method. Has no effect if given an int that is not in the - range [0, num_classes-1], where num_classes is either given or derived from pred and target. By default, no - index is ignored, and all classes are used. - absent_score: score to use for an individual class, if no instances of the class index were present in - `pred` AND no instances of the class index were present in `target`. For example, if we have 3 classes, - [0, 0] for `pred`, and [0, 2] for `target`, then class 1 would be assigned the `absent_score`. Default is - 0.0. - num_classes: Optionally specify the number of classes - reduction: a method to reduce metric score over labels. - - - ``'elementwise_mean'``: takes the mean (default) - - ``'sum'``: takes the sum - - ``'none'``: no reduction will be applied - - Return: - IoU score : Tensor containing single value if reduction is - 'elementwise_mean', or number of classes if reduction is 'none' - - Example: - - >>> target = torch.randint(0, 2, (10, 25, 25)) - >>> pred = torch.tensor(target) - >>> pred[2:5, 7:13, 9:15] = 1 - pred[2:5, 7:13, 9:15] - >>> iou(pred, target) - tensor(0.9660) - + .. deprecated:: + Use :func:`torchmetrics.functional.iou`. Will be removed in v1.4.0. """ - rank_zero_warn( - "This `iou` was deprecated in v1.2.0 in favor of" - " `from pytorch_lightning.metrics.functional.iou import iou`." - " It will be removed in v1.4.0", DeprecationWarning + rank_zero_deprecation( + "This `iou` was deprecated in v1.2.0 in favor of `from pytorch_lightning.metrics.functional.iou import iou`." 
+ " It will be removed in v1.4.0" ) return __iou( pred=pred, diff --git a/pytorch_lightning/metrics/functional/confusion_matrix.py b/pytorch_lightning/metrics/functional/confusion_matrix.py index 58947f2cb19ed..038bd8b49b730 100644 --- a/pytorch_lightning/metrics/functional/confusion_matrix.py +++ b/pytorch_lightning/metrics/functional/confusion_matrix.py @@ -14,44 +14,12 @@ from typing import Optional import torch +from torchmetrics.functional import confusion_matrix as _confusion_matrix -from pytorch_lightning.metrics.classification.helpers import _input_format_classification, DataType -from pytorch_lightning.utilities import rank_zero_warn - - -def _confusion_matrix_update( - preds: torch.Tensor, target: torch.Tensor, num_classes: int, threshold: float = 0.5 -) -> torch.Tensor: - preds, target, mode = _input_format_classification(preds, target, threshold) - if mode not in (DataType.BINARY, DataType.MULTILABEL): - preds = preds.argmax(dim=1) - target = target.argmax(dim=1) - unique_mapping = (target.view(-1) * num_classes + preds.view(-1)).to(torch.long) - bins = torch.bincount(unique_mapping, minlength=num_classes**2) - confmat = bins.reshape(num_classes, num_classes) - return confmat - - -def _confusion_matrix_compute(confmat: torch.Tensor, normalize: Optional[str] = None) -> torch.Tensor: - allowed_normalize = ('true', 'pred', 'all', 'none', None) - assert normalize in allowed_normalize, \ - f"Argument average needs to one of the following: {allowed_normalize}" - confmat = confmat.float() - if normalize is not None and normalize != 'none': - if normalize == 'true': - cm = confmat / confmat.sum(axis=1, keepdim=True) - elif normalize == 'pred': - cm = confmat / confmat.sum(axis=0, keepdim=True) - elif normalize == 'all': - cm = confmat / confmat.sum() - nan_elements = cm[torch.isnan(cm)].nelement() - if nan_elements != 0: - cm[torch.isnan(cm)] = 0 - rank_zero_warn(f'{nan_elements} nan values found in confusion matrix have been replaced with zeros.') - return cm - return confmat +from pytorch_lightning.metrics.utils import deprecated_metrics +@deprecated_metrics(target=_confusion_matrix) def confusion_matrix( preds: torch.Tensor, target: torch.Tensor, @@ -60,38 +28,6 @@ def confusion_matrix( threshold: float = 0.5 ) -> torch.Tensor: """ - Computes the confusion matrix. Works with binary, multiclass, and multilabel data. - Accepts probabilities from a model output or integer class values in prediction. - Works with multi-dimensional preds and target. - - If preds and target are the same shape and preds is a float tensor, we use the ``self.threshold`` argument - to convert into integer labels. This is the case for binary and multi-label probabilities. - - If preds has an extra dimension as in the case of multi-class scores we perform an argmax on ``dim=1``. - - Args: - preds: (float or long tensor), Either a ``(N, ...)`` tensor with labels or - ``(N, C, ...)`` where C is the number of classes, tensor with labels/probabilities - target: ``target`` (long tensor), tensor with shape ``(N, ...)`` with ground true labels - num_classes: Number of classes in the dataset. - normalize: Normalization mode for confusion matrix. Choose from - - - ``None`` or ``'none'``: no normalization (default) - - ``'true'``: normalization over the targets (most commonly used) - - ``'pred'``: normalization over the predictions - - ``'all'``: normalization over the whole matrix - - threshold: - Threshold value for binary or multi-label probabilities. 
default: 0.5 - - Example: - - >>> from pytorch_lightning.metrics.functional import confusion_matrix - >>> target = torch.tensor([1, 1, 0, 0]) - >>> preds = torch.tensor([0, 1, 0, 0]) - >>> confusion_matrix(preds, target, num_classes=2) - tensor([[2., 0.], - [1., 1.]]) + .. deprecated:: + Use :func:`torchmetrics.functional.confusion_matrix`. Will be removed in v1.5.0. """ - confmat = _confusion_matrix_update(preds, target, num_classes, threshold) - return _confusion_matrix_compute(confmat, normalize) diff --git a/pytorch_lightning/metrics/functional/explained_variance.py b/pytorch_lightning/metrics/functional/explained_variance.py index 617d800c754e3..233a0851b8d56 100644 --- a/pytorch_lightning/metrics/functional/explained_variance.py +++ b/pytorch_lightning/metrics/functional/explained_variance.py @@ -11,78 +11,21 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Sequence, Tuple, Union +from typing import Sequence, Union import torch +from torchmetrics.functional import explained_variance as _explained_variance -from pytorch_lightning.metrics.utils import _check_same_shape - - -def _explained_variance_update(preds: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - _check_same_shape(preds, target) - return preds, target - - -def _explained_variance_compute( - preds: torch.Tensor, - target: torch.Tensor, - multioutput: str = 'uniform_average', -) -> Union[torch.Tensor, Sequence[torch.Tensor]]: - diff_avg = torch.mean(target - preds, dim=0) - numerator = torch.mean((target - preds - diff_avg)**2, dim=0) - - target_avg = torch.mean(target, dim=0) - denominator = torch.mean((target - target_avg)**2, dim=0) - - # Take care of division by zero - nonzero_numerator = numerator != 0 - nonzero_denominator = denominator != 0 - valid_score = nonzero_numerator & nonzero_denominator - output_scores = torch.ones_like(diff_avg) - output_scores[valid_score] = 1.0 - (numerator[valid_score] / denominator[valid_score]) - output_scores[nonzero_numerator & ~nonzero_denominator] = 0. - - # Decide what to do in multioutput case - # Todo: allow user to pass in tensor with weights - if multioutput == 'raw_values': - return output_scores - if multioutput == 'uniform_average': - return torch.mean(output_scores) - if multioutput == 'variance_weighted': - denom_sum = torch.sum(denominator) - return torch.sum(denominator / denom_sum * output_scores) +from pytorch_lightning.metrics.utils import deprecated_metrics +@deprecated_metrics(target=_explained_variance) def explained_variance( preds: torch.Tensor, target: torch.Tensor, multioutput: str = 'uniform_average', ) -> Union[torch.Tensor, Sequence[torch.Tensor]]: """ - Computes explained variance. - - Args: - preds: estimated labels - target: ground truth labels - multioutput: Defines aggregation in the case of multiple output scores. 
Can be one - of the following strings (default is `'uniform_average'`.): - - * `'raw_values'` returns full set of scores - * `'uniform_average'` scores are uniformly averaged - * `'variance_weighted'` scores are weighted by their individual variances - - Example: - - >>> from pytorch_lightning.metrics.functional import explained_variance - >>> target = torch.tensor([3, -0.5, 2, 7]) - >>> preds = torch.tensor([2.5, 0.0, 2, 8]) - >>> explained_variance(preds, target) - tensor(0.9572) - - >>> target = torch.tensor([[0.5, 1], [-1, 1], [7, -6]]) - >>> preds = torch.tensor([[0, 2], [-1, 2], [8, -5]]) - >>> explained_variance(preds, target, multioutput='raw_values') - tensor([0.9677, 1.0000]) + .. deprecated:: + Use :func:`torchmetrics.functional.explained_variance`. Will be removed in v1.5.0. """ - preds, target = _explained_variance_update(preds, target) - return _explained_variance_compute(preds, target, multioutput) diff --git a/pytorch_lightning/metrics/functional/f_beta.py b/pytorch_lightning/metrics/functional/f_beta.py index debb6c8285fc9..f994c9a8a3271 100644 --- a/pytorch_lightning/metrics/functional/f_beta.py +++ b/pytorch_lightning/metrics/functional/f_beta.py @@ -11,46 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple - import torch +from torchmetrics.functional import f1 as _f1 +from torchmetrics.functional import fbeta as _fbeta -from pytorch_lightning.metrics.utils import _input_format_classification_one_hot, class_reduce - - -def _fbeta_update( - preds: torch.Tensor, - target: torch.Tensor, - num_classes: int, - threshold: float = 0.5, - multilabel: bool = False -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - preds, target = _input_format_classification_one_hot(num_classes, preds, target, threshold, multilabel) - true_positives = torch.sum(preds * target, dim=1) - predicted_positives = torch.sum(preds, dim=1) - actual_positives = torch.sum(target, dim=1) - return true_positives, predicted_positives, actual_positives - - -def _fbeta_compute( - true_positives: torch.Tensor, - predicted_positives: torch.Tensor, - actual_positives: torch.Tensor, - beta: float = 1.0, - average: str = "micro" -) -> torch.Tensor: - if average == "micro": - precision = true_positives.sum().float() / predicted_positives.sum() - recall = true_positives.sum().float() / actual_positives.sum() - else: - precision = true_positives.float() / predicted_positives - recall = true_positives.float() / actual_positives - - num = (1 + beta**2) * precision * recall - denom = beta**2 * precision + recall - return class_reduce(num, denom, weights=actual_positives, class_reduction=average) +from pytorch_lightning.metrics.utils import deprecated_metrics +@deprecated_metrics(target=_fbeta) def fbeta( preds: torch.Tensor, target: torch.Tensor, @@ -61,49 +29,12 @@ def fbeta( multilabel: bool = False ) -> torch.Tensor: """ - Computes f_beta metric. - - Works with binary, multiclass, and multilabel data. - Accepts probabilities from a model output or integer class values in prediction. - Works with multi-dimensional preds and target. - - If preds and target are the same shape and preds is a float tensor, we use the ``self.threshold`` argument - to convert into integer labels. This is the case for binary and multi-label probabilities. - - If preds has an extra dimension as in the case of multi-class scores we perform an argmax on ``dim=1``. 
-
-    Args:
-        preds: predictions from model (probabilities, or labels)
-        target: ground truth labels
-        num_classes: Number of classes in the dataset.
-        beta: Beta coefficient in the F measure.
-        threshold:
-            Threshold value for binary or multi-label probabilities. default: 0.5
-
-        average:
-            - ``'micro'`` computes metric globally
-            - ``'macro'`` computes metric for each class and uniformly averages them
-            - ``'weighted'`` computes metric for each class and does a weighted-average,
-              where each class is weighted by their support (accounts for class imbalance)
-            - ``'none'`` or ``None`` computes and returns the metric per class
-
-        multilabel: If predictions are from multilabel classification.
-
-    Example:
-
-        >>> from pytorch_lightning.metrics.functional import fbeta
-        >>> target = torch.tensor([0, 1, 2, 0, 1, 2])
-        >>> preds = torch.tensor([0, 2, 1, 0, 0, 1])
-        >>> fbeta(preds, target, num_classes=3, beta=0.5)
-        tensor(0.3333)
-
+    .. deprecated::
+        Use :func:`torchmetrics.functional.fbeta`. Will be removed in v1.5.0.
     """
-    true_positives, predicted_positives, actual_positives = _fbeta_update(
-        preds, target, num_classes, threshold, multilabel
-    )
-    return _fbeta_compute(true_positives, predicted_positives, actual_positives, beta, average)


+@deprecated_metrics(target=_f1)
 def f1(
     preds: torch.Tensor,
     target: torch.Tensor,
@@ -113,39 +44,6 @@ def f1(
     multilabel: bool = False
 ) -> torch.Tensor:
     """
-    Computes F1 metric. F1 metrics correspond to a equally weighted average of the
-    precision and recall scores.
-
-    Works with binary, multiclass, and multilabel data.
-    Accepts probabilities from a model output or integer class values in prediction.
-    Works with multi-dimensional preds and target.
-
-    If preds and target are the same shape and preds is a float tensor, we use the ``self.threshold`` argument
-    to convert into integer labels. This is the case for binary and multi-label probabilities.
-
-    If preds has an extra dimension as in the case of multi-class scores we perform an argmax on ``dim=1``.
-
-    Args:
-        preds: predictions from model (probabilities, or labels)
-        target: ground truth labels
-        num_classes: Number of classes in the dataset.
-        threshold:
-            Threshold value for binary or multi-label probabilities. default: 0.5
-
-        average:
-            - ``'micro'`` computes metric globally
-            - ``'macro'`` computes metric for each class and uniformly averages them
-            - ``'weighted'`` computes metric for each class and does a weighted-average,
-              where each class is weighted by their support (accounts for class imbalance)
-            - ``'none'`` or ``None`` computes and returns the metric per class
-
-        multilabel: If predictions are from multilabel classification.
-
-    Example:
-        >>> from pytorch_lightning.metrics.functional import f1
-        >>> target = torch.tensor([0, 1, 2, 0, 1, 2])
-        >>> preds = torch.tensor([0, 2, 1, 0, 0, 1])
-        >>> f1(preds, target, num_classes=3)
-        tensor(0.3333)
+    .. deprecated::
+        Use :func:`torchmetrics.functional.f1`. Will be removed in v1.5.0.
     """
-    return fbeta(preds, target, num_classes, 1.0, threshold, average, multilabel)
diff --git a/pytorch_lightning/metrics/functional/hamming_distance.py b/pytorch_lightning/metrics/functional/hamming_distance.py
index 60409751fc9f0..6a390e776f111 100644
--- a/pytorch_lightning/metrics/functional/hamming_distance.py
+++ b/pytorch_lightning/metrics/functional/hamming_distance.py
@@ -11,64 +11,15 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple, Union - import torch +from torchmetrics.functional import hamming_distance as _hamming_distance -from pytorch_lightning.metrics.classification.helpers import _input_format_classification - - -def _hamming_distance_update( - preds: torch.Tensor, - target: torch.Tensor, - threshold: float = 0.5, -) -> Tuple[torch.Tensor, int]: - preds, target, _ = _input_format_classification(preds, target, threshold=threshold) - - correct = (preds == target).sum() - total = preds.numel() - - return correct, total - - -def _hamming_distance_compute(correct: torch.Tensor, total: Union[int, torch.Tensor]) -> torch.Tensor: - return 1 - correct.float() / total +from pytorch_lightning.metrics.utils import deprecated_metrics +@deprecated_metrics(target=_hamming_distance) def hamming_distance(preds: torch.Tensor, target: torch.Tensor, threshold: float = 0.5) -> torch.Tensor: - r""" - Computes the average `Hamming distance `_ (also - known as Hamming loss) between targets and predictions: - - .. math:: - \text{Hamming distance} = \frac{1}{N \cdot L} \sum_i^N \sum_l^L 1(y_{il} \neq \hat{y}_{il}) - - Where :math:`y` is a tensor of target values, :math:`\hat{y}` is a tensor of predictions, - and :math:`\bullet_{il}` refers to the :math:`l`-th label of the :math:`i`-th sample of that - tensor. - - This is the same as ``1-accuracy`` for binary data, while for all other types of inputs it - treats each possible label separately - meaning that, for example, multi-class data is - treated as if it were multi-label. - - Accepts all input types listed in :ref:`extensions/metrics:input types`. - - Args: - preds: Predictions from model - target: Ground truth - threshold: - Threshold probability value for transforming probability predictions to binary - (0 or 1) predictions, in the case of binary or multi-label inputs. - - Example: - - >>> from pytorch_lightning.metrics.functional import hamming_distance - >>> target = torch.tensor([[0, 1], [1, 1]]) - >>> preds = torch.tensor([[0, 1], [0, 1]]) - >>> hamming_distance(preds, target) - tensor(0.2500) - """ - - correct, total = _hamming_distance_update(preds, target, threshold) - return _hamming_distance_compute(correct, total) + .. deprecated:: + Use :func:`torchmetrics.functional.hamming_distance`. Will be removed in v1.5.0. 
+ """ diff --git a/pytorch_lightning/metrics/functional/image_gradients.py b/pytorch_lightning/metrics/functional/image_gradients.py index 3fbed571e008e..e2151c5fc1d93 100644 --- a/pytorch_lightning/metrics/functional/image_gradients.py +++ b/pytorch_lightning/metrics/functional/image_gradients.py @@ -14,62 +14,14 @@ from typing import Tuple import torch +from torchmetrics.functional import image_gradients as _image_gradients - -def _image_gradients_validate(img: torch.Tensor) -> torch.Tensor: - """ Validates whether img is a 4D torch Tensor """ - - if not isinstance(img, torch.Tensor): - raise TypeError(f"The `img` expects a value of type but got {type(img)}") - if img.ndim != 4: - raise RuntimeError(f"The `img` expects a 4D tensor but got {img.ndim}D tensor") - - -def _compute_image_gradients(img: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - """ Computes image gradients (dy/dx) for a given image """ - - batch_size, channels, height, width = img.shape - - dy = img[..., 1:, :] - img[..., :-1, :] - dx = img[..., :, 1:] - img[..., :, :-1] - - shapey = [batch_size, channels, 1, width] - dy = torch.cat([dy, torch.zeros(shapey, device=img.device, dtype=img.dtype)], dim=2) - dy = dy.view(img.shape) - - shapex = [batch_size, channels, height, 1] - dx = torch.cat([dx, torch.zeros(shapex, device=img.device, dtype=img.dtype)], dim=3) - dx = dx.view(img.shape) - - return dy, dx +from pytorch_lightning.metrics.utils import deprecated_metrics +@deprecated_metrics(target=_image_gradients) def image_gradients(img: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ - Computes the `gradients `_ of a given image using finite difference - - Args: - img: An ``(N, C, H, W)`` input tensor where C is the number of image channels - - Return: - Tuple of (dy, dx) with each gradient of shape ``[N, C, H, W]`` - - Example: - >>> from pytorch_lightning.metrics.functional import image_gradients - >>> image = torch.arange(0, 1*1*5*5, dtype=torch.float32) - >>> image = torch.reshape(image, (1, 1, 5, 5)) - >>> dy, dx = image_gradients(image) - >>> dy[0, 0, :, :] - tensor([[5., 5., 5., 5., 5.], - [5., 5., 5., 5., 5.], - [5., 5., 5., 5., 5.], - [5., 5., 5., 5., 5.], - [0., 0., 0., 0., 0.]]) - - .. note:: The implementation follows the 1-step finite difference method as followed - by the TF implementation. The values are organized such that the gradient of - [I(x+1, y)-[I(x, y)]] are at the (x, y) location + .. deprecated:: + Use :func:`torchmetrics.functional.image_gradients`. Will be removed in v1.5.0. """ - _image_gradients_validate(img) - - return _compute_image_gradients(img) diff --git a/pytorch_lightning/metrics/functional/iou.py b/pytorch_lightning/metrics/functional/iou.py index 7b6851b5cebd0..76f59854ad4bf 100644 --- a/pytorch_lightning/metrics/functional/iou.py +++ b/pytorch_lightning/metrics/functional/iou.py @@ -14,34 +14,12 @@ from typing import Optional import torch +from torchmetrics.functional import iou as _iou -from pytorch_lightning.metrics.functional.confusion_matrix import _confusion_matrix_update -from pytorch_lightning.metrics.utils import get_num_classes, reduce - - -def _iou_from_confmat( - confmat: torch.Tensor, - num_classes: int, - ignore_index: Optional[int] = None, - absent_score: float = 0.0, - reduction: str = 'elementwise_mean', -): - intersection = torch.diag(confmat) - union = confmat.sum(0) + confmat.sum(1) - intersection - - # If this class is absent in both target AND pred (union == 0), then use the absent_score for this class. 
- scores = intersection.float() / union.float() - scores[union == 0] = absent_score - - # Remove the ignored class index from the scores. - if ignore_index is not None and ignore_index >= 0 and ignore_index < num_classes: - scores = torch.cat([ - scores[:ignore_index], - scores[ignore_index + 1:], - ]) - return reduce(scores, reduction=reduction) +from pytorch_lightning.metrics.utils import deprecated_metrics +@deprecated_metrics(target=_iou) def iou( pred: torch.Tensor, target: torch.Tensor, @@ -51,60 +29,7 @@ def iou( num_classes: Optional[int] = None, reduction: str = 'elementwise_mean', ) -> torch.Tensor: - r""" - Computes `Intersection over union, or Jaccard index calculation `_: - - .. math:: J(A,B) = \frac{|A\cap B|}{|A\cup B|} - - Where: :math:`A` and :math:`B` are both tensors of the same size, - containing integer class values. They may be subject to conversion from - input data (see description below). - - Note that it is different from box IoU. - - If preds and target are the same shape and preds is a float tensor, we use the ``self.threshold`` argument - to convert into integer labels. This is the case for binary and multi-label probabilities. - - If pred has an extra dimension as in the case of multi-class scores we - perform an argmax on ``dim=1``. - - Args: - preds: tensor containing predictions from model (probabilities, or labels) with shape ``[N, d1, d2, ...]`` - target: tensor containing ground truth labels with shape ``[N, d1, d2, ...]`` - ignore_index: optional int specifying a target class to ignore. If given, - this class index does not contribute to the returned score, regardless - of reduction method. Has no effect if given an int that is not in the - range [0, num_classes-1], where num_classes is either given or derived - from pred and target. By default, no index is ignored, and all classes are used. - absent_score: score to use for an individual class, if no instances of - the class index were present in `pred` AND no instances of the class - index were present in `target`. For example, if we have 3 classes, - [0, 0] for `pred`, and [0, 2] for `target`, then class 1 would be - assigned the `absent_score`. - threshold: - Threshold value for binary or multi-label probabilities. default: 0.5 - num_classes: - Optionally specify the number of classes - reduction: a method to reduce metric score over labels. - - - ``'elementwise_mean'``: takes the mean (default) - - ``'sum'``: takes the sum - - ``'none'``: no reduction will be applied - - Return: - IoU score : Tensor containing single value if reduction is - 'elementwise_mean', or number of classes if reduction is 'none' - - Example: - - >>> from pytorch_lightning.metrics.functional import iou - >>> target = torch.randint(0, 2, (10, 25, 25)) - >>> pred = torch.tensor(target) - >>> pred[2:5, 7:13, 9:15] = 1 - pred[2:5, 7:13, 9:15] - >>> iou(pred, target) - tensor(0.9660) """ - - num_classes = get_num_classes(pred=pred, target=target, num_classes=num_classes) - confmat = _confusion_matrix_update(pred, target, num_classes, threshold) - return _iou_from_confmat(confmat, num_classes, ignore_index, absent_score, reduction) + .. deprecated:: + Use :func:`torchmetrics.functional.iou`. Will be removed in v1.5.0. 
+ """ diff --git a/pytorch_lightning/metrics/functional/mean_absolute_error.py b/pytorch_lightning/metrics/functional/mean_absolute_error.py index 671368ba240f9..219284d79d623 100644 --- a/pytorch_lightning/metrics/functional/mean_absolute_error.py +++ b/pytorch_lightning/metrics/functional/mean_absolute_error.py @@ -11,41 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple import torch +from torchmetrics.functional import mean_absolute_error as _mean_absolute_error -from pytorch_lightning.metrics.utils import _check_same_shape - - -def _mean_absolute_error_update(preds: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, int]: - _check_same_shape(preds, target) - sum_abs_error = torch.sum(torch.abs(preds - target)) - n_obs = target.numel() - return sum_abs_error, n_obs - - -def _mean_absolute_error_compute(sum_abs_error: torch.Tensor, n_obs: int) -> torch.Tensor: - return sum_abs_error / n_obs +from pytorch_lightning.metrics.utils import deprecated_metrics +@deprecated_metrics(target=_mean_absolute_error) def mean_absolute_error(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ - Computes mean absolute error - - Args: - pred: estimated labels - target: ground truth labels - - Return: - Tensor with MAE - - Example: - >>> from pytorch_lightning.metrics.functional import mean_absolute_error - >>> x = torch.tensor([0., 1, 2, 3]) - >>> y = torch.tensor([0., 1, 2, 2]) - >>> mean_absolute_error(x, y) - tensor(0.2500) + .. deprecated:: + Use :func:`torchmetrics.functional.mean_absolute_error`. Will be removed in v1.5.0. """ - sum_abs_error, n_obs = _mean_absolute_error_update(preds, target) - return _mean_absolute_error_compute(sum_abs_error, n_obs) diff --git a/pytorch_lightning/metrics/functional/mean_relative_error.py b/pytorch_lightning/metrics/functional/mean_relative_error.py index eedaea1a26a4f..329fe040ebc7d 100644 --- a/pytorch_lightning/metrics/functional/mean_relative_error.py +++ b/pytorch_lightning/metrics/functional/mean_relative_error.py @@ -11,44 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple import torch +from torchmetrics.functional.regression.mean_relative_error import mean_relative_error as _mean_relative_error -from pytorch_lightning.metrics.utils import _check_same_shape - - -def _mean_relative_error_update(preds: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, int]: - _check_same_shape(preds, target) - target_nz = target.clone() - target_nz[target == 0] = 1 - sum_rltv_error = torch.sum(torch.abs((preds - target) / target_nz)) - n_obs = target.numel() - return sum_rltv_error, n_obs - - -def _mean_relative_error_compute(sum_rltv_error: torch.Tensor, n_obs: int) -> torch.Tensor: - return sum_rltv_error / n_obs +from pytorch_lightning.metrics.utils import deprecated_metrics +@deprecated_metrics(target=_mean_relative_error) def mean_relative_error(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ - Computes mean relative error - - Args: - pred: estimated labels - target: ground truth labels - - Return: - Tensor with mean relative error - - Example: - - >>> x = torch.tensor([0., 1, 2, 3]) - >>> y = torch.tensor([0., 1, 2, 2]) - >>> mean_relative_error(x, y) - tensor(0.1250) - + .. 
deprecated:: + Use :func:`torchmetrics.functional.regression.mean_relative_error`. Will be removed in v1.5.0. """ - sum_rltv_error, n_obs = _mean_relative_error_update(preds, target) - return _mean_relative_error_compute(sum_rltv_error, n_obs) diff --git a/pytorch_lightning/metrics/functional/mean_squared_error.py b/pytorch_lightning/metrics/functional/mean_squared_error.py index 2cdd4ea679043..5bbc0bb1c6a83 100644 --- a/pytorch_lightning/metrics/functional/mean_squared_error.py +++ b/pytorch_lightning/metrics/functional/mean_squared_error.py @@ -11,41 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple import torch +from torchmetrics.functional import mean_squared_error as _mean_squared_error -from pytorch_lightning.metrics.utils import _check_same_shape - - -def _mean_squared_error_update(preds: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, int]: - _check_same_shape(preds, target) - sum_squared_error = torch.sum(torch.pow(preds - target, 2)) - n_obs = target.numel() - return sum_squared_error, n_obs - - -def _mean_squared_error_compute(sum_squared_error: torch.Tensor, n_obs: int) -> torch.Tensor: - return sum_squared_error / n_obs +from pytorch_lightning.metrics.utils import deprecated_metrics +@deprecated_metrics(target=_mean_squared_error) def mean_squared_error(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ - Computes mean squared error - - Args: - preds: estimated labels - target: ground truth labels - - Return: - Tensor with MSE - - Example: - >>> from pytorch_lightning.metrics.functional import mean_squared_error - >>> x = torch.tensor([0., 1, 2, 3]) - >>> y = torch.tensor([0., 1, 2, 2]) - >>> mean_squared_error(x, y) - tensor(0.2500) + .. deprecated:: + Use :func:`torchmetrics.functional.mean_squared_error`. Will be removed in v1.5.0. """ - sum_squared_error, n_obs = _mean_squared_error_update(preds, target) - return _mean_squared_error_compute(sum_squared_error, n_obs) diff --git a/pytorch_lightning/metrics/functional/mean_squared_log_error.py b/pytorch_lightning/metrics/functional/mean_squared_log_error.py index 45c255eb61d78..29786529381d5 100644 --- a/pytorch_lightning/metrics/functional/mean_squared_log_error.py +++ b/pytorch_lightning/metrics/functional/mean_squared_log_error.py @@ -11,41 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Tuple import torch +from torchmetrics.functional import mean_squared_log_error as _mean_squared_log_error -from pytorch_lightning.metrics.utils import _check_same_shape - - -def _mean_squared_log_error_update(preds: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, int]: - _check_same_shape(preds, target) - sum_squared_log_error = torch.sum(torch.pow(torch.log1p(preds) - torch.log1p(target), 2)) - n_obs = target.numel() - return sum_squared_log_error, n_obs - - -def _mean_squared_log_error_compute(sum_squared_log_error: torch.Tensor, n_obs: int) -> torch.Tensor: - return sum_squared_log_error / n_obs +from pytorch_lightning.metrics.utils import deprecated_metrics +@deprecated_metrics(target=_mean_squared_log_error) def mean_squared_log_error(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ - Computes mean squared log error - - Args: - preds: estimated labels - target: ground truth labels - - Return: - Tensor with RMSLE - - Example: - >>> from pytorch_lightning.metrics.functional import mean_squared_log_error - >>> x = torch.tensor([0., 1, 2, 3]) - >>> y = torch.tensor([0., 1, 2, 2]) - >>> mean_squared_log_error(x, y) - tensor(0.0207) + .. deprecated:: + Use :func:`torchmetrics.functional.mean_squared_log_error`. Will be removed in v1.5.0. """ - sum_squared_log_error, n_obs = _mean_squared_log_error_update(preds, target) - return _mean_squared_log_error_compute(sum_squared_log_error, n_obs) diff --git a/pytorch_lightning/metrics/functional/nlp.py b/pytorch_lightning/metrics/functional/nlp.py index b1466c66112bc..c59d7cf2b8976 100644 --- a/pytorch_lightning/metrics/functional/nlp.py +++ b/pytorch_lightning/metrics/functional/nlp.py @@ -16,34 +16,15 @@ # Authors: torchtext authors and @sluks # Date: 2020-07-18 # Link: https://pytorch.org/text/_modules/torchtext/data/metrics.html#bleu_score -from collections import Counter -from typing import List, Sequence +from typing import Sequence import torch +from torchmetrics.functional import bleu_score as _bleu_score - -def _count_ngram(ngram_input_list: List[str], n_gram: int) -> Counter: - """ - Counting how many times each word appears in a given text with ngram - - Args: - ngram_input_list: A list of translated text or reference texts - n_gram: gram value ranged 1 to 4 - - Return: - ngram_counter: a collections.Counter object of ngram - """ - - ngram_counter = Counter() - - for i in range(1, n_gram + 1): - for j in range(len(ngram_input_list) - i + 1): - ngram_key = tuple(ngram_input_list[j:(i + j)]) - ngram_counter[ngram_key] += 1 - - return ngram_counter +from pytorch_lightning.metrics.utils import deprecated_metrics +@deprecated_metrics(target=_bleu_score) def bleu_score( translate_corpus: Sequence[str], reference_corpus: Sequence[str], @@ -51,64 +32,6 @@ def bleu_score( smooth: bool = False ) -> torch.Tensor: """ - Calculate BLEU score of machine translated text with one or more references - - Args: - translate_corpus: An iterable of machine translated corpus - reference_corpus: An iterable of iterables of reference corpus - n_gram: Gram value ranged from 1 to 4 (Default 4) - smooth: Whether or not to apply smoothing – Lin et al. 2004 - - Return: - Tensor with BLEU Score - - Example: - >>> from pytorch_lightning.metrics.functional import bleu_score - >>> translate_corpus = ['the cat is on the mat'.split()] - >>> reference_corpus = [['there is a cat on the mat'.split(), 'a cat is on the mat'.split()]] - >>> bleu_score(translate_corpus, reference_corpus) - tensor(0.7598) + .. 
deprecated:: + Use :func:`torchmetrics.functional.bleu_score`. Will be removed in v1.5.0. """ - - assert len(translate_corpus) == len(reference_corpus) - numerator = torch.zeros(n_gram) - denominator = torch.zeros(n_gram) - c = 0.0 - r = 0.0 - - for (translation, references) in zip(translate_corpus, reference_corpus): - c += len(translation) - ref_len_list = [len(ref) for ref in references] - ref_len_diff = [abs(len(translation) - x) for x in ref_len_list] - r += ref_len_list[ref_len_diff.index(min(ref_len_diff))] - translation_counter = _count_ngram(translation, n_gram) - reference_counter = Counter() - - for ref in references: - reference_counter |= _count_ngram(ref, n_gram) - - ngram_counter_clip = translation_counter & reference_counter - - for counter_clip in ngram_counter_clip: - numerator[len(counter_clip) - 1] += ngram_counter_clip[counter_clip] - - for counter in translation_counter: - denominator[len(counter) - 1] += translation_counter[counter] - - trans_len = torch.tensor(c) - ref_len = torch.tensor(r) - - if min(numerator) == 0.0: - return torch.tensor(0.0) - - if smooth: - precision_scores = torch.add(numerator, torch.ones(n_gram)) / torch.add(denominator, torch.ones(n_gram)) - else: - precision_scores = numerator / denominator - - log_precision_scores = torch.tensor([1.0 / n_gram] * n_gram) * torch.log(precision_scores) - geometric_mean = torch.exp(torch.sum(log_precision_scores)) - brevity_penalty = torch.tensor(1.0) if c > r else torch.exp(1 - (ref_len / trans_len)) - bleu = brevity_penalty * geometric_mean - - return bleu diff --git a/pytorch_lightning/metrics/functional/precision_recall.py b/pytorch_lightning/metrics/functional/precision_recall.py index 6f5aafd79d109..7b6c8641b5829 100644 --- a/pytorch_lightning/metrics/functional/precision_recall.py +++ b/pytorch_lightning/metrics/functional/precision_recall.py @@ -14,29 +14,14 @@ from typing import Optional import torch +from torchmetrics.functional import precision as _precision +from torchmetrics.functional import precision_recall as _precision_recall +from torchmetrics.functional import recall as _recall -from pytorch_lightning.metrics.classification.helpers import _reduce_stat_scores -from pytorch_lightning.metrics.functional.stat_scores import _stat_scores_update -from pytorch_lightning.utilities import rank_zero_warn - - -def _precision_compute( - tp: torch.Tensor, - fp: torch.Tensor, - tn: torch.Tensor, - fn: torch.Tensor, - average: str, - mdmc_average: Optional[str], -) -> torch.Tensor: - return _reduce_stat_scores( - numerator=tp, - denominator=tp + fp, - weights=None if average != "weighted" else tp + fn, - average=average, - mdmc_average=mdmc_average, - ) +from pytorch_lightning.metrics.utils import deprecated_metrics +@deprecated_metrics(target=_precision) def precision( preds: torch.Tensor, target: torch.Tensor, @@ -47,158 +32,14 @@ def precision( threshold: float = 0.5, top_k: Optional[int] = None, is_multiclass: Optional[bool] = None, - class_reduction: Optional[str] = None, ) -> torch.Tensor: - r""" - Computes `Precision `_: - - .. math:: \text{Precision} = \frac{\text{TP}}{\text{TP} + \text{FP}} - - Where :math:`\text{TP}` and :math:`\text{FP}` represent the number of true positives and - false positives respecitively. With the use of ``top_k`` parameter, this metric can - generalize to Precision@K. 
- - The reduction method (how the precision scores are aggregated) is controlled by the - ``average`` parameter, and additionally by the ``mdmc_average`` parameter in the - multi-dimensional multi-class case. Accepts all inputs listed in :ref:`extensions/metrics:input types`. - - Args: - preds: Predictions from model (probabilities or labels) - target: Ground truth values - average: - Defines the reduction that is applied. Should be one of the following: - - - ``'micro'`` [default]: Calculate the metric globally, accross all samples and classes. - - ``'macro'``: Calculate the metric for each class separately, and average the - metrics accross classes (with equal weights for each class). - - ``'weighted'``: Calculate the metric for each class separately, and average the - metrics accross classes, weighting each class by its support (``tp + fn``). - - ``'none'`` or ``None``: Calculate the metric for each class separately, and return - the metric for every class. - - ``'samples'``: Calculate the metric for each sample, and average the metrics - across samples (with equal weights for each sample). - - Note that what is considered a sample in the multi-dimensional multi-class case - depends on the value of ``mdmc_average``. - - mdmc_average: - Defines how averaging is done for multi-dimensional multi-class inputs (on top of the - ``average`` parameter). Should be one of the following: - - - ``None`` [default]: Should be left unchanged if your data is not multi-dimensional - multi-class. - - - ``'samplewise'``: In this case, the statistics are computed separately for each - sample on the ``N`` axis, and then averaged over samples. - The computation for each sample is done by treating the flattened extra axes ``...`` - (see :ref:`extensions/metrics:input types`) as the ``N`` dimension within the sample, - and computing the metric for the sample based on that. - - - ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs - (see :ref:`extensions/metrics:input types`) - are flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they - were ``(N_X, C)``. From here on the ``average`` parameter applies as usual. - - ignore_index: - Integer specifying a target class to ignore. If given, this class index does not contribute - to the returned score, regardless of reduction method. If an index is ignored, and ``average=None`` - or ``'none'``, the score for the ignored class will be returned as ``nan``. - - num_classes: - Number of classes. Necessary for ``'macro'``, ``'weighted'`` and ``None`` average methods. - - threshold: - Threshold probability value for transforming probability predictions to binary - (0,1) predictions, in the case of binary or multi-label inputs. - top_k: - Number of highest probability entries for each sample to convert to 1s - relevant - only for inputs with probability predictions. If this parameter is set for multi-label - inputs, it will take precedence over ``threshold``. For (multi-dim) multi-class inputs, - this parameter defaults to 1. - - Should be left unset (``None``) for inputs with label predictions. - is_multiclass: - Used only in certain special cases, where you want to treat inputs as a different type - than what they appear to be. See the parameter's - :ref:`documentation section ` - for a more detailed explanation and examples. - - class_reduction: - .. warning :: This parameter is deprecated, use ``average``. Will be removed in v1.4.0. 
- - Return: - The shape of the returned tensor depends on the ``average`` parameter - - - If ``average in ['micro', 'macro', 'weighted', 'samples']``, a one-element tensor will be returned - - If ``average in ['none', None]``, the shape will be ``(C,)``, where ``C`` stands for the number - of classes - - Example: - - >>> from pytorch_lightning.metrics.functional import precision - >>> preds = torch.tensor([2, 0, 2, 1]) - >>> target = torch.tensor([1, 1, 2, 0]) - >>> precision(preds, target, average='macro', num_classes=3) - tensor(0.1667) - >>> precision(preds, target, average='micro') - tensor(0.2500) - """ - if class_reduction: - rank_zero_warn( - "This `class_reduction` parameter was deprecated in v1.2.0 in favor of" - " `reduce`. It will be removed in v1.4.0", - DeprecationWarning, - ) - average = class_reduction - - allowed_average = ["micro", "macro", "weighted", "samples", "none", None] - if average not in allowed_average: - raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") - - allowed_mdmc_average = [None, "samplewise", "global"] - if mdmc_average not in allowed_mdmc_average: - raise ValueError(f"The `mdmc_average` has to be one of {allowed_mdmc_average}, got {mdmc_average}.") - - if average in ["macro", "weighted", "none", None] and (not num_classes or num_classes < 1): - raise ValueError(f"When you set `average` as {average}, you have to provide the number of classes.") - - if num_classes and ignore_index is not None and (not 0 <= ignore_index < num_classes or num_classes == 1): - raise ValueError(f"The `ignore_index` {ignore_index} is not valid for inputs with {num_classes} classes") - - reduce = "macro" if average in ["weighted", "none", None] else average - tp, fp, tn, fn = _stat_scores_update( - preds, - target, - reduce=reduce, - mdmc_reduce=mdmc_average, - threshold=threshold, - num_classes=num_classes, - top_k=top_k, - is_multiclass=is_multiclass, - ignore_index=ignore_index, - ) - - return _precision_compute(tp, fp, tn, fn, average, mdmc_average) - - -def _recall_compute( - tp: torch.Tensor, - fp: torch.Tensor, - tn: torch.Tensor, - fn: torch.Tensor, - average: str, - mdmc_average: Optional[str], -) -> torch.Tensor: - return _reduce_stat_scores( - numerator=tp, - denominator=tp + fn, - weights=None if average != "weighted" else tp + fn, - average=average, - mdmc_average=mdmc_average, - ) + .. deprecated:: + Use :func:`torchmetrics.functional.precision`. Will be removed in v1.5.0. + """ +@deprecated_metrics(target=_recall) def recall( preds: torch.Tensor, target: torch.Tensor, @@ -209,141 +50,14 @@ def recall( threshold: float = 0.5, top_k: Optional[int] = None, is_multiclass: Optional[bool] = None, - class_reduction: Optional[str] = None, ) -> torch.Tensor: - r""" - Computes `Recall `_: - - .. math:: \text{Recall} = \frac{\text{TP}}{\text{TP} + \text{FN}} - - Where :math:`\text{TP}` and :math:`\text{FN}` represent the number of true positives and - false negatives respecitively. With the use of ``top_k`` parameter, this metric can - generalize to Recall@K. - - The reduction method (how the recall scores are aggregated) is controlled by the - ``average`` parameter, and additionally by the ``mdmc_average`` parameter in the - multi-dimensional multi-class case. Accepts all inputs listed in :ref:`extensions/metrics:input types`. - - Args: - preds: Predictions from model (probabilities, or labels) - target: Ground truth values - average: - Defines the reduction that is applied. 
Should be one of the following: - - - ``'micro'`` [default]: Calculate the metric globally, accross all samples and classes. - - ``'macro'``: Calculate the metric for each class separately, and average the - metrics accross classes (with equal weights for each class). - - ``'weighted'``: Calculate the metric for each class separately, and average the - metrics accross classes, weighting each class by its support (``tp + fn``). - - ``'none'`` or ``None``: Calculate the metric for each class separately, and return - the metric for every class. - - ``'samples'``: Calculate the metric for each sample, and average the metrics - across samples (with equal weights for each sample). - - Note that what is considered a sample in the multi-dimensional multi-class case - depends on the value of ``mdmc_average``. - - mdmc_average: - Defines how averaging is done for multi-dimensional multi-class inputs (on top of the - ``average`` parameter). Should be one of the following: - - - ``None`` [default]: Should be left unchanged if your data is not multi-dimensional - multi-class. - - - ``'samplewise'``: In this case, the statistics are computed separately for each - sample on the ``N`` axis, and then averaged over samples. - The computation for each sample is done by treating the flattened extra axes ``...`` - (see :ref:`extensions/metrics:input types`) as the ``N`` dimension within the sample, - and computing the metric for the sample based on that. - - - ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs - (see :ref:`extensions/metrics:input types`) - are flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they - were ``(N_X, C)``. From here on the ``average`` parameter applies as usual. - - ignore_index: - Integer specifying a target class to ignore. If given, this class index does not contribute - to the returned score, regardless of reduction method. If an index is ignored, and ``average=None`` - or ``'none'``, the score for the ignored class will be returned as ``nan``. - - num_classes: - Number of classes. Necessary for ``'macro'``, ``'weighted'`` and ``None`` average methods. - - threshold: - Threshold probability value for transforming probability predictions to binary - (0,1) predictions, in the case of binary or multi-label inputs - top_k: - Number of highest probability entries for each sample to convert to 1s - relevant - only for inputs with probability predictions. If this parameter is set for multi-label - inputs, it will take precedence over ``threshold``. For (multi-dim) multi-class inputs, - this parameter defaults to 1. - - Should be left unset (``None``) for inputs with label predictions. - is_multiclass: - Used only in certain special cases, where you want to treat inputs as a different type - than what they appear to be. See the parameter's - :ref:`documentation section ` - for a more detailed explanation and examples. - - class_reduction: - .. warning :: This parameter is deprecated, use ``average``. Will be removed in v1.4.0. 
-
-    Return:
-        The shape of the returned tensor depends on the ``average`` parameter
-
-        - If ``average in ['micro', 'macro', 'weighted', 'samples']``, a one-element tensor will be returned
-        - If ``average in ['none', None]``, the shape will be ``(C,)``, where ``C`` stands for the number
-          of classes
-
-    Example:
-
-        >>> from pytorch_lightning.metrics.functional import recall
-        >>> preds = torch.tensor([2, 0, 2, 1])
-        >>> target = torch.tensor([1, 1, 2, 0])
-        >>> recall(preds, target, average='macro', num_classes=3)
-        tensor(0.3333)
-        >>> recall(preds, target, average='micro')
-        tensor(0.2500)
-    """
-    if class_reduction:
-        rank_zero_warn(
-            "This `class_reduction` parameter was deprecated in v1.2.0 in favor of"
-            " `reduce`. It will be removed in v1.4.0",
-            DeprecationWarning,
-        )
-        average = class_reduction
-
-    allowed_average = ["micro", "macro", "weighted", "samples", "none", None]
-    if average not in allowed_average:
-        raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.")
-
-    allowed_mdmc_average = [None, "samplewise", "global"]
-    if mdmc_average not in allowed_mdmc_average:
-        raise ValueError("The `mdmc_average` has to be one of {allowed_mdmc_average}, got {mdmc_average}.")
-
-    if average in ["macro", "weighted", "none", None] and (not num_classes or num_classes < 1):
-        raise ValueError(f"When you set `average` as {average}, you have to provide the number of classes.")
-
-    if num_classes and ignore_index is not None and (not 0 <= ignore_index < num_classes or num_classes == 1):
-        raise ValueError(f"The `ignore_index` {ignore_index} is not valid for inputs with {num_classes} classes")
-
-    reduce = "macro" if average in ["weighted", "none", None] else average
-    tp, fp, tn, fn = _stat_scores_update(
-        preds,
-        target,
-        reduce=reduce,
-        mdmc_reduce=mdmc_average,
-        threshold=threshold,
-        num_classes=num_classes,
-        top_k=top_k,
-        is_multiclass=is_multiclass,
-        ignore_index=ignore_index,
-    )
-
-    return _recall_compute(tp, fp, tn, fn, average, mdmc_average)
+    .. deprecated::
+        Use :func:`torchmetrics.functional.recall`. Will be removed in v1.5.0.
+    """
 
 
+@deprecated_metrics(target=_precision_recall)
 def precision_recall(
     preds: torch.Tensor,
     target: torch.Tensor,
@@ -354,143 +68,8 @@ def precision_recall(
     threshold: float = 0.5,
     top_k: Optional[int] = None,
     is_multiclass: Optional[bool] = None,
-    class_reduction: Optional[str] = None,
 ) -> torch.Tensor:
-    r"""
-    Computes `Precision and Recall `_:
-
-    .. math:: \text{Precision} = \frac{\text{TP}}{\text{TP} + \text{FP}}
-
-
-    .. math:: \text{Recall} = \frac{\text{TP}}{\text{TP} + \text{FN}}
-
-    Where :math:`\text{TP}`m :math:`\text{FN}` and :math:`\text{FP}` represent the number
-    of true positives, false negatives and false positives respecitively. With the use of
-    ``top_k`` parameter, this metric can generalize to Recall@K and Precision@K.
-
-    The reduction method (how the recall scores are aggregated) is controlled by the
-    ``average`` parameter, and additionally by the ``mdmc_average`` parameter in the
-    multi-dimensional multi-class case. Accepts all inputs listed in :ref:`extensions/metrics:input types`.
-
-    Args:
-        preds: Predictions from model (probabilities, or labels)
-        target: Ground truth values
-        average:
-            Defines the reduction that is applied. Should be one of the following:
-
-            - ``'micro'`` [default]: Calculate the metric globally, accross all samples and classes.
- - ``'macro'``: Calculate the metric for each class separately, and average the - metrics accross classes (with equal weights for each class). - - ``'weighted'``: Calculate the metric for each class separately, and average the - metrics accross classes, weighting each class by its support (``tp + fn``). - - ``'none'`` or ``None``: Calculate the metric for each class separately, and return - the metric for every class. - - ``'samples'``: Calculate the metric for each sample, and average the metrics - across samples (with equal weights for each sample). - - Note that what is considered a sample in the multi-dimensional multi-class case - depends on the value of ``mdmc_average``. - - mdmc_average: - Defines how averaging is done for multi-dimensional multi-class inputs (on top of the - ``average`` parameter). Should be one of the following: - - - ``None`` [default]: Should be left unchanged if your data is not multi-dimensional - multi-class. - - - ``'samplewise'``: In this case, the statistics are computed separately for each - sample on the ``N`` axis, and then averaged over samples. - The computation for each sample is done by treating the flattened extra axes ``...`` - (see :ref:`extensions/metrics:input types`) as the ``N`` dimension within the sample, - and computing the metric for the sample based on that. - - - ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs - (see :ref:`extensions/metrics:input types`) - are flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they - were ``(N_X, C)``. From here on the ``average`` parameter applies as usual. - - ignore_index: - Integer specifying a target class to ignore. If given, this class index does not contribute - to the returned score, regardless of reduction method. If an index is ignored, and ``average=None`` - or ``'none'``, the score for the ignored class will be returned as ``nan``. - - num_classes: - Number of classes. Necessary for ``'macro'``, ``'weighted'`` and ``None`` average methods. - - threshold: - Threshold probability value for transforming probability predictions to binary - (0,1) predictions, in the case of binary or multi-label inputs - top_k: - Number of highest probability entries for each sample to convert to 1s - relevant - only for inputs with probability predictions. If this parameter is set for multi-label - inputs, it will take precedence over ``threshold``. For (multi-dim) multi-class inputs, - this parameter defaults to 1. - - Should be left unset (``None``) for inputs with label predictions. - is_multiclass: - Used only in certain special cases, where you want to treat inputs as a different type - than what they appear to be. See the parameter's - :ref:`documentation section ` - for a more detailed explanation and examples. - - class_reduction: - .. warning :: This parameter is deprecated, use ``average``. Will be removed in v1.4.0. - - Return: - The function returns a tuple with two elements: precision and recall. 
Their shape - depends on the ``average`` parameter - - - If ``average in ['micro', 'macro', 'weighted', 'samples']``, they are a single element tensor - - If ``average in ['none', None]``, they are a tensor of shape ``(C, )``, where ``C`` stands for - the number of classes - - Example: - - >>> from pytorch_lightning.metrics.functional import precision_recall - >>> preds = torch.tensor([2, 0, 2, 1]) - >>> target = torch.tensor([1, 1, 2, 0]) - >>> precision_recall(preds, target, average='macro', num_classes=3) - (tensor(0.1667), tensor(0.3333)) - >>> precision_recall(preds, target, average='micro') - (tensor(0.2500), tensor(0.2500)) - """ - if class_reduction: - rank_zero_warn( - "This `class_reduction` parameter was deprecated in v1.2.0 in favor of" - " `reduce`. It will be removed in v1.4.0", - DeprecationWarning, - ) - average = class_reduction - - allowed_average = ["micro", "macro", "weighted", "samples", "none", None] - if average not in allowed_average: - raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") - - allowed_mdmc_average = [None, "samplewise", "global"] - if mdmc_average not in allowed_mdmc_average: - raise ValueError("The `mdmc_average` has to be one of {allowed_mdmc_average}, got {mdmc_average}.") - - if average in ["macro", "weighted", "none", None] and (not num_classes or num_classes < 1): - raise ValueError(f"When you set `average` as {average}, you have to provide the number of classes.") - - if num_classes and ignore_index is not None and (not 0 <= ignore_index < num_classes or num_classes == 1): - raise ValueError(f"The `ignore_index` {ignore_index} is not valid for inputs with {num_classes} classes") - - reduce = "macro" if average in ["weighted", "none", None] else average - tp, fp, tn, fn = _stat_scores_update( - preds, - target, - reduce=reduce, - mdmc_reduce=mdmc_average, - threshold=threshold, - num_classes=num_classes, - top_k=top_k, - is_multiclass=is_multiclass, - ignore_index=ignore_index, - ) - - precision = _precision_compute(tp, fp, tn, fn, average, mdmc_average) - recall = _recall_compute(tp, fp, tn, fn, average, mdmc_average) - - return precision, recall + .. deprecated:: + Use :func:`torchmetrics.functional.precision_recall`. Will be removed in v1.5.0. 
+ """ diff --git a/pytorch_lightning/metrics/functional/precision_recall_curve.py b/pytorch_lightning/metrics/functional/precision_recall_curve.py index fb442b020af88..dc9863cbb47c4 100644 --- a/pytorch_lightning/metrics/functional/precision_recall_curve.py +++ b/pytorch_lightning/metrics/functional/precision_recall_curve.py @@ -14,140 +14,12 @@ from typing import List, Optional, Sequence, Tuple, Union import torch -import torch.nn.functional as F +from torchmetrics.functional import precision_recall_curve as _precision_recall_curve -from pytorch_lightning.utilities import rank_zero_warn - - -def _binary_clf_curve( - preds: torch.Tensor, - target: torch.Tensor, - sample_weights: Optional[Sequence] = None, - pos_label: int = 1., -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - adapted from https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/metrics/_ranking.py - """ - if sample_weights is not None and not isinstance(sample_weights, torch.Tensor): - sample_weights = torch.tensor(sample_weights, device=preds.device, dtype=torch.float) - - # remove class dimension if necessary - if preds.ndim > target.ndim: - preds = preds[:, 0] - desc_score_indices = torch.argsort(preds, descending=True) - - preds = preds[desc_score_indices] - target = target[desc_score_indices] - - if sample_weights is not None: - weight = sample_weights[desc_score_indices] - else: - weight = 1. - - # pred typically has many tied values. Here we extract - # the indices associated with the distinct values. We also - # concatenate a value for the end of the curve. - distinct_value_indices = torch.where(preds[1:] - preds[:-1])[0] - threshold_idxs = F.pad(distinct_value_indices, (0, 1), value=target.size(0) - 1) - target = (target == pos_label).to(torch.long) - tps = torch.cumsum(target * weight, dim=0)[threshold_idxs] - - if sample_weights is not None: - # express fps as a cumsum to ensure fps is increasing even in - # the presence of floating point errors - fps = torch.cumsum((1 - target) * weight, dim=0)[threshold_idxs] - else: - fps = 1 + threshold_idxs - tps - - return fps, tps, preds[threshold_idxs] - - -def _precision_recall_curve_update( - preds: torch.Tensor, - target: torch.Tensor, - num_classes: Optional[int] = None, - pos_label: Optional[int] = None, -) -> Tuple[torch.Tensor, torch.Tensor, int, int]: - if not (len(preds.shape) == len(target.shape) or len(preds.shape) == len(target.shape) + 1): - raise ValueError("preds and target must have same number of dimensions, or one additional dimension for preds") - # single class evaluation - if len(preds.shape) == len(target.shape): - num_classes = 1 - if pos_label is None: - rank_zero_warn('`pos_label` automatically set 1.') - pos_label = 1 - preds = preds.flatten() - target = target.flatten() - - # multi class evaluation - if len(preds.shape) == len(target.shape) + 1: - if pos_label is not None: - rank_zero_warn( - 'Argument `pos_label` should be `None` when running' - f' multiclass precision recall curve. 
Got {pos_label}' - ) - if num_classes != preds.shape[1]: - raise ValueError( - f'Argument `num_classes` was set to {num_classes} in' - f' metric `precision_recall_curve` but detected {preds.shape[1]}' - ' number of classes from predictions' - ) - preds = preds.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1) - target = target.flatten() - - return preds, target, num_classes, pos_label - - -def _precision_recall_curve_compute( - preds: torch.Tensor, - target: torch.Tensor, - num_classes: int, - pos_label: int, - sample_weights: Optional[Sequence] = None, -) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], Tuple[List[torch.Tensor], List[torch.Tensor], - List[torch.Tensor]]]: - - if num_classes == 1: - fps, tps, thresholds = _binary_clf_curve( - preds=preds, target=target, sample_weights=sample_weights, pos_label=pos_label - ) - - precision = tps / (tps + fps) - recall = tps / tps[-1] - - # stop when full recall attained - # and reverse the outputs so recall is decreasing - last_ind = torch.where(tps == tps[-1])[0][0] - sl = slice(0, last_ind.item() + 1) - - # need to call reversed explicitly, since including that to slice would - # introduce negative strides that are not yet supported in pytorch - precision = torch.cat([reversed(precision[sl]), torch.ones(1, dtype=precision.dtype, device=precision.device)]) - - recall = torch.cat([reversed(recall[sl]), torch.zeros(1, dtype=recall.dtype, device=recall.device)]) - - thresholds = reversed(thresholds[sl]).clone() - - return precision, recall, thresholds - - # Recursively call per class - precision, recall, thresholds = [], [], [] - for c in range(num_classes): - preds_c = preds[:, c] - res = precision_recall_curve( - preds=preds_c, - target=target, - num_classes=1, - pos_label=c, - sample_weights=sample_weights, - ) - precision.append(res[0]) - recall.append(res[1]) - thresholds.append(res[2]) - - return precision, recall, thresholds +from pytorch_lightning.metrics.utils import deprecated_metrics +@deprecated_metrics(target=_precision_recall_curve) def precision_recall_curve( preds: torch.Tensor, target: torch.Tensor, @@ -155,64 +27,8 @@ def precision_recall_curve( pos_label: Optional[int] = None, sample_weights: Optional[Sequence] = None, ) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], Tuple[List[torch.Tensor], List[torch.Tensor], - List[torch.Tensor]]]: + List[torch.Tensor]], ]: """ - Computes precision-recall pairs for different thresholds. - - Args: - preds: predictions from model (probabilities) - target: ground truth labels - num_classes: integer with number of classes. Not nessesary to provide - for binary problems. - pos_label: integer determining the positive class. Default is ``None`` - which for binary problem is translate to 1. For multiclass problems - this argument should not be set as we iteratively change it in the - range [0,num_classes-1] - sample_weights: sample weights for each data point - - Returns: - 3-element tuple containing - - precision: - tensor where element i is the precision of predictions with - score >= thresholds[i] and the last element is 1. - If multiclass, this is a list of such tensors, one for each class. - recall: - tensor where element i is the recall of predictions with - score >= thresholds[i] and the last element is 0. - If multiclass, this is a list of such tensors, one for each class. 
-        thresholds:
-            Thresholds used for computing precision/recall scores
-
-    Example (binary case):
-
-        >>> from pytorch_lightning.metrics.functional import precision_recall_curve
-        >>> pred = torch.tensor([0, 1, 2, 3])
-        >>> target = torch.tensor([0, 1, 1, 0])
-        >>> precision, recall, thresholds = precision_recall_curve(pred, target, pos_label=1)
-        >>> precision
-        tensor([0.6667, 0.5000, 0.0000, 1.0000])
-        >>> recall
-        tensor([1.0000, 0.5000, 0.0000, 0.0000])
-        >>> thresholds
-        tensor([1, 2, 3])
-
-    Example (multiclass case):
-
-        >>> from pytorch_lightning.metrics.functional import precision_recall_curve
-        >>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
-        ...                      [0.05, 0.75, 0.05, 0.05, 0.05],
-        ...                      [0.05, 0.05, 0.75, 0.05, 0.05],
-        ...                      [0.05, 0.05, 0.05, 0.75, 0.05]])
-        >>> target = torch.tensor([0, 1, 3, 2])
-        >>> precision, recall, thresholds = precision_recall_curve(pred, target, num_classes=5)
-        >>> precision # doctest: +NORMALIZE_WHITESPACE
-        [tensor([1., 1.]), tensor([1., 1.]), tensor([0.2500, 0.0000, 1.0000]),
-         tensor([0.2500, 0.0000, 1.0000]), tensor([0., 1.])]
-        >>> recall
-        [tensor([1., 0.]), tensor([1., 0.]), tensor([1., 0., 0.]), tensor([1., 0., 0.]), tensor([nan, 0.])]
-        >>> thresholds
-        [tensor([0.7500]), tensor([0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500])]
+    .. deprecated::
+        Use :func:`torchmetrics.functional.precision_recall_curve`. Will be removed in v1.5.0.
     """
-    preds, target, num_classes, pos_label = _precision_recall_curve_update(preds, target, num_classes, pos_label)
-    return _precision_recall_curve_compute(preds, target, num_classes, pos_label, sample_weights)
diff --git a/pytorch_lightning/metrics/functional/psnr.py b/pytorch_lightning/metrics/functional/psnr.py
index bd513d4ca21dd..51be9d47b91f9 100644
--- a/pytorch_lightning/metrics/functional/psnr.py
+++ b/pytorch_lightning/metrics/functional/psnr.py
@@ -14,46 +14,12 @@
 from typing import Optional, Tuple, Union
 
 import torch
+from torchmetrics.functional import psnr as _psnr
 
-from pytorch_lightning import utilities
-from pytorch_lightning.metrics import utils
-
-
-def _psnr_compute(
-    sum_squared_error: torch.Tensor,
-    n_obs: torch.Tensor,
-    data_range: torch.Tensor,
-    base: float = 10.0,
-    reduction: str = 'elementwise_mean',
-) -> torch.Tensor:
-    psnr_base_e = 2 * torch.log(data_range) - torch.log(sum_squared_error / n_obs)
-    psnr = psnr_base_e * (10 / torch.log(torch.tensor(base)))
-    return utils.reduce(psnr, reduction=reduction)
-
-
-def _psnr_update(preds: torch.Tensor,
-                 target: torch.Tensor,
-                 dim: Optional[Union[int, Tuple[int, ...]]] = None) -> Tuple[torch.Tensor, torch.Tensor]:
-    if dim is None:
-        sum_squared_error = torch.sum(torch.pow(preds - target, 2))
-        n_obs = torch.tensor(target.numel(), device=target.device)
-        return sum_squared_error, n_obs
-
-    sum_squared_error = torch.sum(torch.pow(preds - target, 2), dim=dim)
-
-    if isinstance(dim, int):
-        dim_list = [dim]
-    else:
-        dim_list = list(dim)
-    if not dim_list:
-        n_obs = torch.tensor(target.numel(), device=target.device)
-    else:
-        n_obs = torch.tensor(target.size(), device=target.device)[dim_list].prod()
-        n_obs = n_obs.expand_as(sum_squared_error)
-
-    return sum_squared_error, n_obs
+from pytorch_lightning.metrics.utils import deprecated_metrics
 
 
+@deprecated_metrics(target=_psnr)
 def psnr(
     preds: torch.Tensor,
     target: torch.Tensor,
@@ -63,46 +29,6 @@ def psnr(
     dim: Optional[Union[int, Tuple[int, ...]]] = None,
 ) -> torch.Tensor:
     """
-    Computes the peak signal-to-noise ratio
-
-    Args:
-        preds: estimated signal
-        target:
groun truth signal - data_range: - the range of the data. If None, it is determined from the data (max - min). ``data_range`` must be given - when ``dim`` is not None. - base: a base of a logarithm to use (default: 10) - reduction: a method to reduce metric score over labels. - - - ``'elementwise_mean'``: takes the mean (default) - - ``'sum'``: takes the sum - - ``'none'``: no reduction will be applied - - dim: - Dimensions to reduce PSNR scores over provided as either an integer or a list of integers. Default is - None meaning scores will be reduced across all dimensions. - Return: - Tensor with PSNR score - - Example: - >>> from pytorch_lightning.metrics.functional import psnr - >>> pred = torch.tensor([[0.0, 1.0], [2.0, 3.0]]) - >>> target = torch.tensor([[3.0, 2.0], [1.0, 0.0]]) - >>> psnr(pred, target) - tensor(2.5527) - + .. deprecated:: + Use :func:`torchmetrics.functional.psnr`. Will be removed in v1.5.0. """ - if dim is None and reduction != 'elementwise_mean': - utilities.rank_zero_warn(f'The `reduction={reduction}` will not have any effect when `dim` is None.') - - if data_range is None: - if dim is not None: - # Maybe we could use `torch.amax(target, dim=dim) - torch.amin(target, dim=dim)` in PyTorch 1.7 to calculate - # `data_range` in the future. - raise ValueError("The `data_range` must be given when `dim` is not None.") - - data_range = target.max() - target.min() - else: - data_range = torch.tensor(float(data_range)) - sum_squared_error, n_obs = _psnr_update(preds, target, dim=dim) - return _psnr_compute(sum_squared_error, n_obs, data_range, base=base, reduction=reduction) diff --git a/pytorch_lightning/metrics/functional/r2score.py b/pytorch_lightning/metrics/functional/r2score.py index ef8a20c806ee9..fe4b541989358 100644 --- a/pytorch_lightning/metrics/functional/r2score.py +++ b/pytorch_lightning/metrics/functional/r2score.py @@ -11,121 +11,21 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Tuple import torch +from torchmetrics.functional import r2score as _r2score -from pytorch_lightning.metrics.utils import _check_same_shape -from pytorch_lightning.utilities import rank_zero_warn - - -def _r2score_update( - preds: torch.tensor, - target: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - _check_same_shape(preds, target) - if preds.ndim > 2: - raise ValueError( - 'Expected both prediction and target to be 1D or 2D tensors,' - f' but recevied tensors with dimension {preds.shape}' - ) - if len(preds) < 2: - raise ValueError('Needs atleast two samples to calculate r2 score.') - - sum_error = torch.sum(target, dim=0) - sum_squared_error = torch.sum(torch.pow(target, 2.0), dim=0) - residual = torch.sum(torch.pow(target - preds, 2.0), dim=0) - total = target.size(0) - - return sum_squared_error, sum_error, residual, total - - -def _r2score_compute( - sum_squared_error: torch.Tensor, - sum_error: torch.Tensor, - residual: torch.Tensor, - total: torch.Tensor, - adjusted: int = 0, - multioutput: str = "uniform_average" -) -> torch.Tensor: - mean_error = sum_error / total - diff = sum_squared_error - sum_error * mean_error - raw_scores = 1 - (residual / diff) - - if multioutput == "raw_values": - r2score = raw_scores - elif multioutput == "uniform_average": - r2score = torch.mean(raw_scores) - elif multioutput == "variance_weighted": - diff_sum = torch.sum(diff) - r2score = torch.sum(diff / diff_sum * raw_scores) - else: - raise ValueError( - 'Argument `multioutput` must be either `raw_values`,' - f' `uniform_average` or `variance_weighted`. Received {multioutput}.' - ) - - if adjusted < 0 or not isinstance(adjusted, int): - raise ValueError('`adjusted` parameter should be an integer larger or' ' equal to 0.') - - if adjusted != 0: - if adjusted > total - 1: - rank_zero_warn( - "More independent regressions than datapoints in" - " adjusted r2 score. Falls back to standard r2 score.", UserWarning - ) - elif adjusted == total - 1: - rank_zero_warn("Division by zero in adjusted r2 score. Falls back to" " standard r2 score.", UserWarning) - else: - r2score = 1 - (1 - r2score) * (total - 1) / (total - adjusted - 1) - return r2score +from pytorch_lightning.metrics.utils import deprecated_metrics +@deprecated_metrics(target=_r2score) def r2score( preds: torch.Tensor, target: torch.Tensor, adjusted: int = 0, multioutput: str = "uniform_average", ) -> torch.Tensor: - r""" - Computes r2 score also known as `coefficient of determination - `_: - - .. math:: R^2 = 1 - \frac{SS_res}{SS_tot} - - where :math:`SS_res=\sum_i (y_i - f(x_i))^2` is the sum of residual squares, and - :math:`SS_tot=\sum_i (y_i - \bar{y})^2` is total sum of squares. Can also calculate - adjusted r2 score given by - - .. math:: R^2_adj = 1 - \frac{(1-R^2)(n-1)}{n-k-1} - - where the parameter :math:`k` (the number of independent regressors) should - be provided as the ``adjusted`` argument. - - Args: - preds: estimated labels - target: ground truth labels - adjusted: number of independent regressors for calculating adjusted r2 score. - Default 0 (standard r2 score). - multioutput: Defines aggregation in the case of multiple output scores. 
Can be one - of the following strings (default is ``'uniform_average'``.): - - * ``'raw_values'`` returns full set of scores - * ``'uniform_average'`` scores are uniformly averaged - * ``'variance_weighted'`` scores are weighted by their individual variances - - Example: - - >>> from pytorch_lightning.metrics.functional import r2score - >>> target = torch.tensor([3, -0.5, 2, 7]) - >>> preds = torch.tensor([2.5, 0.0, 2, 8]) - >>> r2score(preds, target) - tensor(0.9486) - - >>> target = torch.tensor([[0.5, 1], [-1, 1], [7, -6]]) - >>> preds = torch.tensor([[0, 2], [-1, 2], [8, -5]]) - >>> r2score(preds, target, multioutput='raw_values') - tensor([0.9654, 0.9082]) """ - sum_squared_error, sum_error, residual, total = _r2score_update(preds, target) - return _r2score_compute(sum_squared_error, sum_error, residual, total, adjusted, multioutput) + .. deprecated:: + Use :func:`torchmetrics.functional.r2score`. Will be removed in v1.5.0. + """ diff --git a/pytorch_lightning/metrics/functional/roc.py b/pytorch_lightning/metrics/functional/roc.py index 030c974365807..928a0b40fca54 100644 --- a/pytorch_lightning/metrics/functional/roc.py +++ b/pytorch_lightning/metrics/functional/roc.py @@ -13,135 +13,21 @@ # limitations under the License. from typing import List, Optional, Sequence, Tuple, Union -import torch +from torch import Tensor +from torchmetrics.functional import roc as _roc -from pytorch_lightning.metrics.functional.precision_recall_curve import ( - _binary_clf_curve, - _precision_recall_curve_update, -) - - -def _roc_update( - preds: torch.Tensor, - target: torch.Tensor, - num_classes: Optional[int] = None, - pos_label: Optional[int] = None, -) -> Tuple[torch.Tensor, torch.Tensor, int, int]: - return _precision_recall_curve_update(preds, target, num_classes, pos_label) - - -def _roc_compute( - preds: torch.Tensor, - target: torch.Tensor, - num_classes: int, - pos_label: int, - sample_weights: Optional[Sequence] = None, -) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], Tuple[List[torch.Tensor], List[torch.Tensor], - List[torch.Tensor]]]: - - if num_classes == 1: - fps, tps, thresholds = _binary_clf_curve( - preds=preds, target=target, sample_weights=sample_weights, pos_label=pos_label - ) - # Add an extra threshold position - # to make sure that the curve starts at (0, 0) - tps = torch.cat([torch.zeros(1, dtype=tps.dtype, device=tps.device), tps]) - fps = torch.cat([torch.zeros(1, dtype=fps.dtype, device=fps.device), fps]) - thresholds = torch.cat([thresholds[0][None] + 1, thresholds]) - - if fps[-1] <= 0: - raise ValueError("No negative samples in targets, false positive value should be meaningless") - fpr = fps / fps[-1] - - if tps[-1] <= 0: - raise ValueError("No positive samples in targets, true positive value should be meaningless") - tpr = tps / tps[-1] - - return fpr, tpr, thresholds - - # Recursively call per class - fpr, tpr, thresholds = [], [], [] - for c in range(num_classes): - preds_c = preds[:, c] - res = roc( - preds=preds_c, - target=target, - num_classes=1, - pos_label=c, - sample_weights=sample_weights, - ) - fpr.append(res[0]) - tpr.append(res[1]) - thresholds.append(res[2]) - - return fpr, tpr, thresholds +from pytorch_lightning.metrics.utils import deprecated_metrics +@deprecated_metrics(target=_roc) def roc( - preds: torch.Tensor, - target: torch.Tensor, + preds: Tensor, + target: Tensor, num_classes: Optional[int] = None, pos_label: Optional[int] = None, sample_weights: Optional[Sequence] = None, -) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], 
Tuple[List[torch.Tensor], List[torch.Tensor], - List[torch.Tensor]]]: +) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: """ - Computes the Receiver Operating Characteristic (ROC). - - Args: - preds: predictions from model (logits or probabilities) - target: ground truth values - num_classes: integer with number of classes. Not nessesary to provide - for binary problems. - pos_label: integer determining the positive class. Default is ``None`` - which for binary problem is translate to 1. For multiclass problems - this argument should not be set as we iteratively change it in the - range [0,num_classes-1] - sample_weights: sample weights for each data point - - Returns: - 3-element tuple containing - - fpr: - tensor with false positive rates. - If multiclass, this is a list of such tensors, one for each class. - tpr: - tensor with true positive rates. - If multiclass, this is a list of such tensors, one for each class. - thresholds: - thresholds used for computing false- and true postive rates - - Example (binary case): - - >>> from pytorch_lightning.metrics.functional import roc - >>> pred = torch.tensor([0, 1, 2, 3]) - >>> target = torch.tensor([0, 1, 1, 1]) - >>> fpr, tpr, thresholds = roc(pred, target, pos_label=1) - >>> fpr - tensor([0., 0., 0., 0., 1.]) - >>> tpr - tensor([0.0000, 0.3333, 0.6667, 1.0000, 1.0000]) - >>> thresholds - tensor([4, 3, 2, 1, 0]) - - Example (multiclass case): - - >>> from pytorch_lightning.metrics.functional import roc - >>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05], - ... [0.05, 0.75, 0.05, 0.05], - ... [0.05, 0.05, 0.75, 0.05], - ... [0.05, 0.05, 0.05, 0.75]]) - >>> target = torch.tensor([0, 1, 3, 2]) - >>> fpr, tpr, thresholds = roc(pred, target, num_classes=4) - >>> fpr - [tensor([0., 0., 1.]), tensor([0., 0., 1.]), tensor([0.0000, 0.3333, 1.0000]), tensor([0.0000, 0.3333, 1.0000])] - >>> tpr - [tensor([0., 1., 1.]), tensor([0., 1., 1.]), tensor([0., 0., 1.]), tensor([0., 0., 1.])] - >>> thresholds # doctest: +NORMALIZE_WHITESPACE - [tensor([1.7500, 0.7500, 0.0500]), - tensor([1.7500, 0.7500, 0.0500]), - tensor([1.7500, 0.7500, 0.0500]), - tensor([1.7500, 0.7500, 0.0500])] + .. deprecated:: + Use :func:`torchmetrics.functional.roc`. Will be removed in v1.5.0. """ - preds, target, num_classes, pos_label = _roc_update(preds, target, num_classes, pos_label) - return _roc_compute(preds, target, num_classes, pos_label, sample_weights) diff --git a/pytorch_lightning/metrics/functional/self_supervised.py b/pytorch_lightning/metrics/functional/self_supervised.py index ed00677bb32d3..65dec211e938a 100644 --- a/pytorch_lightning/metrics/functional/self_supervised.py +++ b/pytorch_lightning/metrics/functional/self_supervised.py @@ -12,8 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import torch +from torchmetrics.functional import embedding_similarity as _embedding_similarity +from pytorch_lightning.metrics.utils import deprecated_metrics + +@deprecated_metrics(target=_embedding_similarity) def embedding_similarity( batch: torch.Tensor, similarity: str = 'cosine', @@ -21,39 +25,6 @@ def embedding_similarity( zero_diagonal: bool = True ) -> torch.Tensor: """ - Computes representation similarity - - Example: - >>> from pytorch_lightning.metrics.functional import embedding_similarity - >>> embeddings = torch.tensor([[1., 2., 3., 4.], [1., 2., 3., 4.], [4., 5., 6., 7.]]) - >>> embedding_similarity(embeddings) - tensor([[0.0000, 1.0000, 0.9759], - [1.0000, 0.0000, 0.9759], - [0.9759, 0.9759, 0.0000]]) - - Args: - batch: (batch, dim) - similarity: 'dot' or 'cosine' - reduction: 'none', 'sum', 'mean' (all along dim -1) - zero_diagonal: if True, the diagonals are set to zero - - Return: - A square matrix (batch, batch) with the similarity scores between all elements - If sum or mean are used, then returns (b, 1) with the reduced value for each row + .. deprecated:: + Use :func:`torchmetrics.functional.embedding_similarity`. Will be removed in v1.5.0. """ - if similarity == 'cosine': - norm = torch.norm(batch, p=2, dim=1) - batch = batch / norm.unsqueeze(1) - - sqr_mtx = batch.mm(batch.transpose(1, 0)) - - if zero_diagonal: - sqr_mtx = sqr_mtx.fill_diagonal_(0) - - if reduction == 'mean': - sqr_mtx = sqr_mtx.mean(dim=-1) - - if reduction == 'sum': - sqr_mtx = sqr_mtx.sum(dim=-1) - - return sqr_mtx diff --git a/pytorch_lightning/metrics/functional/ssim.py b/pytorch_lightning/metrics/functional/ssim.py index 459c1855f6fee..31cff7fcfb9b4 100644 --- a/pytorch_lightning/metrics/functional/ssim.py +++ b/pytorch_lightning/metrics/functional/ssim.py @@ -11,107 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Sequence, Tuple +from typing import Optional, Sequence import torch -from torch.nn import functional as F +from torchmetrics.functional import ssim as _ssim -from pytorch_lightning.metrics.utils import _check_same_shape, reduce - - -def _gaussian(kernel_size: int, sigma: int, dtype: torch.dtype, device: torch.device): - dist = torch.arange(start=(1 - kernel_size) / 2, end=(1 + kernel_size) / 2, step=1, dtype=dtype, device=device) - gauss = torch.exp(-torch.pow(dist / sigma, 2) / 2) - return (gauss / gauss.sum()).unsqueeze(dim=0) # (1, kernel_size) - - -def _gaussian_kernel( - channel: int, kernel_size: Sequence[int], sigma: Sequence[float], dtype: torch.dtype, device: torch.device -): - gaussian_kernel_x = _gaussian(kernel_size[0], sigma[0], dtype, device) - gaussian_kernel_y = _gaussian(kernel_size[1], sigma[1], dtype, device) - kernel = torch.matmul(gaussian_kernel_x.t(), gaussian_kernel_y) # (kernel_size, 1) * (1, kernel_size) - - return kernel.expand(channel, 1, kernel_size[0], kernel_size[1]) - - -def _ssim_update( - preds: torch.Tensor, - target: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor]: - if preds.dtype != target.dtype: - raise TypeError( - "Expected `preds` and `target` to have the same data type." - f" Got pred: {preds.dtype} and target: {target.dtype}." - ) - _check_same_shape(preds, target) - if len(preds.shape) != 4: - raise ValueError( - "Expected `preds` and `target` to have BxCxHxW shape." - f" Got pred: {preds.shape} and target: {target.shape}." 
- ) - return preds, target - - -def _ssim_compute( - preds: torch.Tensor, - target: torch.Tensor, - kernel_size: Sequence[int] = (11, 11), - sigma: Sequence[float] = (1.5, 1.5), - reduction: str = "elementwise_mean", - data_range: Optional[float] = None, - k1: float = 0.01, - k2: float = 0.03, -): - if len(kernel_size) != 2 or len(sigma) != 2: - raise ValueError( - "Expected `kernel_size` and `sigma` to have the length of two." - f" Got kernel_size: {len(kernel_size)} and sigma: {len(sigma)}." - ) - - if any(x % 2 == 0 or x <= 0 for x in kernel_size): - raise ValueError(f"Expected `kernel_size` to have odd positive number. Got {kernel_size}.") - - if any(y <= 0 for y in sigma): - raise ValueError(f"Expected `sigma` to have positive number. Got {sigma}.") - - if data_range is None: - data_range = max(preds.max() - preds.min(), target.max() - target.min()) - - c1 = pow(k1 * data_range, 2) - c2 = pow(k2 * data_range, 2) - device = preds.device - - channel = preds.size(1) - dtype = preds.dtype - kernel = _gaussian_kernel(channel, kernel_size, sigma, dtype, device) - pad_w = (kernel_size[0] - 1) // 2 - pad_h = (kernel_size[1] - 1) // 2 - - preds = F.pad(preds, (pad_w, pad_w, pad_h, pad_h), mode='reflect') - target = F.pad(target, (pad_w, pad_w, pad_h, pad_h), mode='reflect') - - input_list = torch.cat((preds, target, preds * preds, target * target, preds * target)) # (5 * B, C, H, W) - outputs = F.conv2d(input_list, kernel, groups=channel) - output_list = [outputs[x * preds.size(0):(x + 1) * preds.size(0)] for x in range(len(outputs))] - - mu_pred_sq = output_list[0].pow(2) - mu_target_sq = output_list[1].pow(2) - mu_pred_target = output_list[0] * output_list[1] - - sigma_pred_sq = output_list[2] - mu_pred_sq - sigma_target_sq = output_list[3] - mu_target_sq - sigma_pred_target = output_list[4] - mu_pred_target - - upper = 2 * sigma_pred_target + c2 - lower = sigma_pred_sq + sigma_target_sq + c2 - - ssim_idx = ((2 * mu_pred_target + c1) * upper) / ((mu_pred_sq + mu_target_sq + c1) * lower) - ssim_idx = ssim_idx[..., pad_h:-pad_h, pad_w:-pad_w] - - return reduce(ssim_idx, reduction) +from pytorch_lightning.metrics.utils import deprecated_metrics +@deprecated_metrics(target=_ssim) def ssim( preds: torch.Tensor, target: torch.Tensor, @@ -123,32 +31,6 @@ def ssim( k2: float = 0.03, ) -> torch.Tensor: """ - Computes Structual Similarity Index Measure - - Args: - preds: estimated image - target: ground truth image - kernel_size: size of the gaussian kernel (default: (11, 11)) - sigma: Standard deviation of the gaussian kernel (default: (1.5, 1.5)) - reduction: a method to reduce metric score over labels. - - - ``'elementwise_mean'``: takes the mean (default) - - ``'sum'``: takes the sum - - ``'none'``: no reduction will be applied - - data_range: Range of the image. If ``None``, it is determined from the image (max - min) - k1: Parameter of SSIM. Default: 0.01 - k2: Parameter of SSIM. Default: 0.03 - - Return: - Tensor with SSIM score - - Example: - >>> from pytorch_lightning.metrics.functional import ssim - >>> preds = torch.rand([16, 1, 16, 16]) - >>> target = preds * 0.75 - >>> ssim(preds, target) - tensor(0.9219) + .. deprecated:: + Use :func:`torchmetrics.functional.ssim`. Will be removed in v1.5.0. 
""" - preds, target = _ssim_update(preds, target) - return _ssim_compute(preds, target, kernel_size, sigma, reduction, data_range, k1, k2) diff --git a/pytorch_lightning/metrics/functional/stat_scores.py b/pytorch_lightning/metrics/functional/stat_scores.py index 44b4434f4dcf1..30c03da237fe6 100644 --- a/pytorch_lightning/metrics/functional/stat_scores.py +++ b/pytorch_lightning/metrics/functional/stat_scores.py @@ -11,131 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple +from typing import Optional import torch +from torchmetrics.functional import stat_scores as _stat_scores -from pytorch_lightning.metrics.classification.helpers import _input_format_classification - - -def _del_column(tensor: torch.Tensor, index: int): - """ Delete the column at index.""" - - return torch.cat([tensor[:, :index], tensor[:, (index + 1):]], 1) - - -def _stat_scores( - preds: torch.Tensor, - target: torch.Tensor, - reduce: str = "micro", -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - """Calculate the number of tp, fp, tn, fn. - - Args: - preds: - An ``(N, C)`` or ``(N, C, X)`` tensor of predictions (0 or 1) - target: - An ``(N, C)`` or ``(N, C, X)`` tensor of true labels (0 or 1) - reduce: - One of ``'micro'``, ``'macro'``, ``'samples'`` - - Return: - Returns a list of 4 tensors; tp, fp, tn, fn. - The shape of the returned tensors depnds on the shape of the inputs - and the ``reduce`` parameter: - - If inputs are of the shape ``(N, C)``, then - - If ``reduce='micro'``, the returned tensors are 1 element tensors - - If ``reduce='macro'``, the returned tensors are ``(C,)`` tensors - - If ``reduce'samples'``, the returned tensors are ``(N,)`` tensors - - If inputs are of the shape ``(N, C, X)``, then - - If ``reduce='micro'``, the returned tensors are ``(N,)`` tensors - - If ``reduce='macro'``, the returned tensors are ``(N,C)`` tensors - - If ``reduce='samples'``, the returned tensors are ``(N,X)`` tensors - """ - if reduce == "micro": - dim = [0, 1] if preds.ndim == 2 else [1, 2] - elif reduce == "macro": - dim = 0 if preds.ndim == 2 else 2 - elif reduce == "samples": - dim = 1 - - true_pred, false_pred = target == preds, target != preds - pos_pred, neg_pred = preds == 1, preds == 0 - - tp = (true_pred * pos_pred).sum(dim=dim) - fp = (false_pred * pos_pred).sum(dim=dim) - - tn = (true_pred * neg_pred).sum(dim=dim) - fn = (false_pred * neg_pred).sum(dim=dim) - - return tp.long(), fp.long(), tn.long(), fn.long() - - -def _stat_scores_update( - preds: torch.Tensor, - target: torch.Tensor, - reduce: str = "micro", - mdmc_reduce: Optional[str] = None, - num_classes: Optional[int] = None, - top_k: Optional[int] = None, - threshold: float = 0.5, - is_multiclass: Optional[bool] = None, - ignore_index: Optional[int] = None, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - - preds, target, _ = _input_format_classification( - preds, target, threshold=threshold, num_classes=num_classes, is_multiclass=is_multiclass, top_k=top_k - ) - - if ignore_index is not None and not 0 <= ignore_index < preds.shape[1]: - raise ValueError(f"The `ignore_index` {ignore_index} is not valid for inputs with {preds.shape[0]} classes") - - if ignore_index is not None and preds.shape[1] == 1: - raise ValueError("You can not use `ignore_index` with binary data.") - - if preds.ndim == 3: - if not mdmc_reduce: - raise ValueError( 
- "When your inputs are multi-dimensional multi-class, you have to set the `mdmc_reduce` parameter" - ) - if mdmc_reduce == "global": - preds = torch.transpose(preds, 1, 2).reshape(-1, preds.shape[1]) - target = torch.transpose(target, 1, 2).reshape(-1, target.shape[1]) - - # Delete what is in ignore_index, if applicable (and classes don't matter): - if ignore_index is not None and reduce != "macro": - preds = _del_column(preds, ignore_index) - target = _del_column(target, ignore_index) - - tp, fp, tn, fn = _stat_scores(preds, target, reduce=reduce) - - # Take care of ignore_index - if ignore_index is not None and reduce == "macro": - tp[..., ignore_index] = -1 - fp[..., ignore_index] = -1 - tn[..., ignore_index] = -1 - fn[..., ignore_index] = -1 - - return tp, fp, tn, fn - - -def _stat_scores_compute(tp: torch.Tensor, fp: torch.Tensor, tn: torch.Tensor, fn: torch.Tensor) -> torch.Tensor: - - outputs = [ - tp.unsqueeze(-1), - fp.unsqueeze(-1), - tn.unsqueeze(-1), - fn.unsqueeze(-1), - tp.unsqueeze(-1) + fn.unsqueeze(-1), # support - ] - outputs = torch.cat(outputs, -1) - outputs = torch.where(outputs < 0, torch.tensor(-1, device=outputs.device), outputs) - - return outputs +from pytorch_lightning.metrics.utils import deprecated_metrics +@deprecated_metrics(target=_stat_scores) def stat_scores( preds: torch.Tensor, target: torch.Tensor, @@ -147,137 +31,7 @@ def stat_scores( is_multiclass: Optional[bool] = None, ignore_index: Optional[int] = None, ) -> torch.Tensor: - """Computes the number of true positives, false positives, true negatives, false negatives. - Related to `Type I and Type II errors `__ - and the `confusion matrix `__. - - The reduction method (how the statistics are aggregated) is controlled by the - ``reduce`` parameter, and additionally by the ``mdmc_reduce`` parameter in the - multi-dimensional multi-class case. Accepts all inputs listed in :ref:`extensions/metrics:input types`. - - Args: - preds: Predictions from model (probabilities or labels) - target: Ground truth values - threshold: - Threshold probability value for transforming probability predictions to binary - (0 or 1) predictions, in the case of binary or multi-label inputs. - - top_k: - Number of highest probability entries for each sample to convert to 1s - relevant - only for inputs with probability predictions. If this parameter is set for multi-label - inputs, it will take precedence over ``threshold``. For (multi-dim) multi-class inputs, - this parameter defaults to 1. - - Should be left unset (``None``) for inputs with label predictions. - - reduce: - Defines the reduction that is applied. Should be one of the following: - - - ``'micro'`` [default]: Counts the statistics by summing over all [sample, class] - combinations (globally). Each statistic is represented by a single integer. - - ``'macro'``: Counts the statistics for each class separately (over all samples). - Each statistic is represented by a ``(C,)`` tensor. Requires ``num_classes`` - to be set. - - ``'samples'``: Counts the statistics for each sample separately (over all classes). - Each statistic is represented by a ``(N, )`` 1d tensor. - - Note that what is considered a sample in the multi-dimensional multi-class case - depends on the value of ``mdmc_reduce``. - - num_classes: - Number of classes. Necessary for (multi-dimensional) multi-class or multi-label data. - - ignore_index: - Specify a class (label) to ignore. If given, this class index does not contribute - to the returned score, regardless of reduction method. 
If an index is ignored, and - ``reduce='macro'``, the class statistics for the ignored class will all be returned - as ``-1``. - - mdmc_reduce: - Defines how the multi-dimensional multi-class inputs are handeled. Should be - one of the following: - - - ``None`` [default]: Should be left unchanged if your data is not multi-dimensional - multi-class (see :ref:`extensions/metrics:input types` for the definition of input types). - - - ``'samplewise'``: In this case, the statistics are computed separately for each - sample on the ``N`` axis, and then the outputs are concatenated together. In each - sample the extra axes ``...`` are flattened to become the sub-sample axis, and - statistics for each sample are computed by treating the sub-sample axis as the - ``N`` axis for that sample. - - - ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs are - flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they - were ``(N_X, C)``. From here on the ``reduce`` parameter applies as usual. - - is_multiclass: - Used only in certain special cases, where you want to treat inputs as a different type - than what they appear to be. See the parameter's - :ref:`documentation section ` - for a more detailed explanation and examples. - - Return: - The metric returns a tensor of shape ``(..., 5)``, where the last dimension corresponds - to ``[tp, fp, tn, fn, sup]`` (``sup`` stands for support and equals ``tp + fn``). The - shape depends on the ``reduce`` and ``mdmc_reduce`` (in case of multi-dimensional - multi-class data) parameters: - - - If the data is not multi-dimensional multi-class, then - - - If ``reduce='micro'``, the shape will be ``(5, )`` - - If ``reduce='macro'``, the shape will be ``(C, 5)``, - where ``C`` stands for the number of classes - - If ``reduce='samples'``, the shape will be ``(N, 5)``, where ``N`` stands for - the number of samples - - - If the data is multi-dimensional multi-class and ``mdmc_reduce='global'``, then - - - If ``reduce='micro'``, the shape will be ``(5, )`` - - If ``reduce='macro'``, the shape will be ``(C, 5)`` - - If ``reduce='samples'``, the shape will be ``(N*X, 5)``, where ``X`` stands for - the product of sizes of all "extra" dimensions of the data (i.e. 
all dimensions - except for ``C`` and ``N``) - - - If the data is multi-dimensional multi-class and ``mdmc_reduce='samplewise'``, then - - - If ``reduce='micro'``, the shape will be ``(N, 5)`` - - If ``reduce='macro'``, the shape will be ``(N, C, 5)`` - - If ``reduce='samples'``, the shape will be ``(N, X, 5)`` - - Example: - - >>> from pytorch_lightning.metrics.functional import stat_scores - >>> preds = torch.tensor([1, 0, 2, 1]) - >>> target = torch.tensor([1, 1, 2, 0]) - >>> stat_scores(preds, target, reduce='macro', num_classes=3) - tensor([[0, 1, 2, 1, 1], - [1, 1, 1, 1, 2], - [1, 0, 3, 0, 1]]) - >>> stat_scores(preds, target, reduce='micro') - tensor([2, 2, 6, 2, 4]) """ - - if reduce not in ["micro", "macro", "samples"]: - raise ValueError(f"The `reduce` {reduce} is not valid.") - - if mdmc_reduce not in [None, "samplewise", "global"]: - raise ValueError(f"The `mdmc_reduce` {mdmc_reduce} is not valid.") - - if reduce == "macro" and (not num_classes or num_classes < 1): - raise ValueError("When you set `reduce` as 'macro', you have to provide the number of classes.") - - if num_classes and ignore_index is not None and (not 0 <= ignore_index < num_classes or num_classes == 1): - raise ValueError(f"The `ignore_index` {ignore_index} is not valid for inputs with {num_classes} classes") - - tp, fp, tn, fn = _stat_scores_update( - preds, - target, - reduce=reduce, - mdmc_reduce=mdmc_reduce, - top_k=top_k, - threshold=threshold, - num_classes=num_classes, - is_multiclass=is_multiclass, - ignore_index=ignore_index, - ) - return _stat_scores_compute(tp, fp, tn, fn) + .. deprecated:: + Use :func:`torchmetrics.functional.stat_scores`. Will be removed in v1.5.0. + """ diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py index 3ff3039cb99b1..ee0fcdb8a92e1 100644 --- a/pytorch_lightning/metrics/metric.py +++ b/pytorch_lightning/metrics/metric.py @@ -11,54 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import functools -import inspect -from abc import ABC, abstractmethod -from collections.abc import Sequence -from copy import deepcopy from typing import Any, Callable, Dict, List, Optional, Tuple, Union -import torch -from torch import nn +from torchmetrics import Metric as _Metric +from torchmetrics.collections import MetricCollection as _MetricCollection -from pytorch_lightning.metrics.utils import _flatten, dim_zero_cat, dim_zero_mean, dim_zero_sum -from pytorch_lightning.utilities.apply_func import apply_to_collection -from pytorch_lightning.utilities.distributed import gather_all_tensors +from pytorch_lightning.metrics.utils import deprecated_metrics -class Metric(nn.Module, ABC): - """ - Base class for all metrics present in the Metrics API. - - Implements ``add_state()``, ``forward()``, ``reset()`` and a few other things to - handle distributed synchronization and per-step metric computation. - - Override ``update()`` and ``compute()`` functions to implement your own metric. Use - ``add_state()`` to register metric state variables which keep track of state on each - call of ``update()`` and are synchronized across processes when ``compute()`` is called. - - Note: - Metric state variables can either be ``torch.Tensors`` or an empty list which can we used - to store `torch.Tensors``. - - Note: - Different metrics only override ``update()`` and not ``forward()``. 
A call to ``update()`` - is valid, but it won't return the metric value at the current step. A call to ``forward()`` - automatically calls ``update()`` and also returns the metric value at the current step. - - Args: - compute_on_step: - Forward only calls ``update()`` and returns None if this is set to False. default: True - dist_sync_on_step: - Synchronize metric state across processes at each ``forward()`` - before returning the value at the step. - process_group: - Specify the process group on which synchronization is called. default: None (which selects the entire world) - dist_sync_fn: - Callback that performs the allgather operation on the metric state. When `None`, DDP - will be used to perform the allgather. default: None - """ +class Metric(_Metric): + @deprecated_metrics(target=_Metric) def __init__( self, compute_on_step: bool = True, @@ -66,559 +29,17 @@ def __init__( process_group: Optional[Any] = None, dist_sync_fn: Callable = None, ): - super().__init__() - - self.dist_sync_on_step = dist_sync_on_step - self.compute_on_step = compute_on_step - self.process_group = process_group - self.dist_sync_fn = dist_sync_fn - self._to_sync = True - - self._update_signature = inspect.signature(self.update) - self.update = self._wrap_update(self.update) - self.compute = self._wrap_compute(self.compute) - self._computed = None - self._forward_cache = None - - # initialize state - self._defaults = {} - self._persistent = {} - self._reductions = {} - - def add_state( - self, name: str, default, dist_reduce_fx: Optional[Union[str, Callable]] = None, persistent: bool = False - ): - """ - Adds metric state variable. Only used by subclasses. - - Args: - name: The name of the state variable. The variable will then be accessible at ``self.name``. - default: Default value of the state; can either be a ``torch.Tensor`` or an empty list. The state will be - reset to this value when ``self.reset()`` is called. - dist_reduce_fx (Optional): Function to reduce state accross mutliple processes in distributed mode. - If value is ``"sum"``, ``"mean"``, or ``"cat"``, we will use ``torch.sum``, ``torch.mean``, - and ``torch.cat`` respectively, each with argument ``dim=0``. Note that the ``"cat"`` reduction - only makes sense if the state is a list, and not a tensor. The user can also pass a custom - function in this parameter. - persistent (Optional): whether the state will be saved as part of the modules ``state_dict``. - Default is ``False``. - - Note: - Setting ``dist_reduce_fx`` to None will return the metric state synchronized across different processes. - However, there won't be any reduction function applied to the synchronized metric state. - - The metric states would be synced as follows - - - If the metric state is ``torch.Tensor``, the synced value will be a stacked ``torch.Tensor`` across - the process dimension if the metric state was a ``torch.Tensor``. The original ``torch.Tensor`` metric - state retains dimension and hence the synchronized output will be of shape ``(num_process, ...)``. - - - If the metric state is a ``list``, the synced value will be a ``list`` containing the - combined elements from all processes. - - Note: - When passing a custom function to ``dist_reduce_fx``, expect the synchronized metric state to follow - the format discussed in the above note. - - Raises: - ValueError: - If ``default`` is not a ``tensor`` or an ``empty list``. - ValueError: - If ``dist_reduce_fx`` is not callable or one of ``"mean"``, ``"sum"``, ``"cat"``, ``None``. 
- """ - if ( - not isinstance(default, torch.Tensor) and not isinstance(default, list) # noqa: W503 - or (isinstance(default, list) and len(default) != 0) # noqa: W503 - ): - raise ValueError("state variable must be a tensor or any empty list (where you can append tensors)") - - if dist_reduce_fx == "sum": - dist_reduce_fx = dim_zero_sum - elif dist_reduce_fx == "mean": - dist_reduce_fx = dim_zero_mean - elif dist_reduce_fx == "cat": - dist_reduce_fx = dim_zero_cat - elif dist_reduce_fx is not None and not isinstance(dist_reduce_fx, Callable): - raise ValueError("`dist_reduce_fx` must be callable or one of ['mean', 'sum', 'cat', None]") - - setattr(self, name, default) - - self._defaults[name] = deepcopy(default) - self._persistent[name] = persistent - self._reductions[name] = dist_reduce_fx - - @torch.jit.unused - def forward(self, *args, **kwargs): - """ - Automatically calls ``update()``. Returns the metric value over inputs if ``compute_on_step`` is True. - """ - # add current step - with torch.no_grad(): - self.update(*args, **kwargs) - self._forward_cache = None - - if self.compute_on_step: - self._to_sync = self.dist_sync_on_step - - # save context before switch - cache = {attr: getattr(self, attr) for attr in self._defaults.keys()} - - # call reset, update, compute, on single batch - self.reset() - self.update(*args, **kwargs) - self._forward_cache = self.compute() - - # restore context - for attr, val in cache.items(): - setattr(self, attr, val) - self._to_sync = True - self._computed = None - - return self._forward_cache - - def _sync_dist(self, dist_sync_fn=gather_all_tensors): - input_dict = {attr: getattr(self, attr) for attr in self._reductions.keys()} - output_dict = apply_to_collection( - input_dict, - torch.Tensor, - dist_sync_fn, - group=self.process_group, - ) - - for attr, reduction_fn in self._reductions.items(): - # pre-processing ops (stack or flatten for inputs) - if isinstance(output_dict[attr][0], torch.Tensor): - output_dict[attr] = torch.stack(output_dict[attr]) - elif isinstance(output_dict[attr][0], list): - output_dict[attr] = _flatten(output_dict[attr]) - - assert isinstance(reduction_fn, (Callable)) or reduction_fn is None - reduced = reduction_fn(output_dict[attr]) if reduction_fn is not None else output_dict[attr] - setattr(self, attr, reduced) - - def _wrap_update(self, update): - - @functools.wraps(update) - def wrapped_func(*args, **kwargs): - self._computed = None - return update(*args, **kwargs) - - return wrapped_func - - def _wrap_compute(self, compute): - - @functools.wraps(compute) - def wrapped_func(*args, **kwargs): - # return cached value - if self._computed is not None: - return self._computed - - dist_sync_fn = self.dist_sync_fn - if dist_sync_fn is None and torch.distributed.is_available() and torch.distributed.is_initialized(): - # User provided a bool, so we assume DDP if available - dist_sync_fn = gather_all_tensors - - synced = False - if self._to_sync and dist_sync_fn is not None: - # cache prior to syncing - cache = {attr: getattr(self, attr) for attr in self._defaults.keys()} - - # sync - self._sync_dist(dist_sync_fn) - synced = True - - self._computed = compute(*args, **kwargs) - if synced: - # if we synced, restore to cache so that we can continue to accumulate un-synced state - for attr, val in cache.items(): - setattr(self, attr, val) - - return self._computed - - return wrapped_func - - @abstractmethod - def update(self) -> None: # pylint: disable=E0202 - """ - Override this method to update the state variables of your metric 
class. - """ - pass - - @abstractmethod - def compute(self): # pylint: disable=E0202 - """ - Override this method to compute the final metric value from state variables - synchronized across the distributed backend. - """ - pass - - def reset(self): - """ - This method automatically resets the metric state variables to their default value. - """ - for attr, default in self._defaults.items(): - current_val = getattr(self, attr) - if isinstance(default, torch.Tensor): - setattr(self, attr, deepcopy(default).to(current_val.device)) - else: - setattr(self, attr, deepcopy(default)) - - def clone(self): - """ Make a copy of the metric """ - return deepcopy(self) - - def __getstate__(self): - # ignore update and compute functions for pickling - return {k: v for k, v in self.__dict__.items() if k not in ["update", "compute"]} - - def __setstate__(self, state): - # manually restore update and compute functions for pickling - self.__dict__.update(state) - self.update = self._wrap_update(self.update) - self.compute = self._wrap_compute(self.compute) - - def _apply(self, fn): - """Overwrite _apply function such that we can also move metric states - to the correct device when `.to`, `.cuda`, etc methods are called + r""" + .. deprecated:: + Use :class:`torchmetrics.Metric`. Will be removed in v1.5.0. """ - self = super()._apply(fn) - # Also apply fn to metric states - for key in self._defaults.keys(): - current_val = getattr(self, key) - if isinstance(current_val, torch.Tensor): - setattr(self, key, fn(current_val)) - elif isinstance(current_val, Sequence): - setattr(self, key, [fn(cur_v) for cur_v in current_val]) - else: - raise TypeError( - "Expected metric state to be either a torch.Tensor" - f"or a list of torch.Tensor, but encountered {current_val}" - ) - return self - - def persistent(self, mode: bool = False): - """Method for post-init to change if metric states should be saved to - its state_dict - """ - for key in self._persistent.keys(): - self._persistent[key] = mode - - def state_dict(self, destination=None, prefix='', keep_vars=False): - destination = super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars) - # Register metric states to be part of the state_dict - for key in self._defaults.keys(): - if self._persistent[key]: - current_val = getattr(self, key) - if not keep_vars: - if torch.is_tensor(current_val): - current_val = current_val.detach() - elif isinstance(current_val, list): - current_val = [cur_v.detach() if torch.is_tensor(cur_v) else cur_v for cur_v in current_val] - destination[prefix + key] = current_val - return destination - - def _filter_kwargs(self, **kwargs): - """ filter kwargs such that they match the update signature of the metric """ - - # filter all parameters based on update signature except those of - # type VAR_POSITIONAL (*args) and VAR_KEYWORD (**kwargs) - _params = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD) - filtered_kwargs = { - k: v - for k, v in kwargs.items() if k in self._update_signature.parameters.keys() - and self._update_signature.parameters[k].kind not in _params - } - - # if no kwargs filtered, return al kwargs as default - if not filtered_kwargs: - filtered_kwargs = kwargs - return filtered_kwargs - - def __hash__(self): - hash_vals = [self.__class__.__name__] - - for key in self._defaults.keys(): - val = getattr(self, key) - # Special case: allow list values, so long - # as their elements are hashable - if hasattr(val, '__iter__') and not isinstance(val, torch.Tensor): - hash_vals.extend(val) - 
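The ``persistent``/``state_dict`` handling in this hunk keeps metric states out of checkpoints unless explicitly requested. A short illustration, again written against the ``torchmetrics`` implementation the shim now points to (the state names shown are those of its ``Accuracy`` at the version this patch targets)::

    import torch
    from torchmetrics import Accuracy

    acc = Accuracy()
    acc.update(torch.tensor([0, 1, 1]), torch.tensor([0, 1, 0]))

    # states are registered with persistent=False by default, so none are saved
    print(acc.state_dict())          # OrderedDict()

    # opt in after construction to checkpoint the accumulated state
    acc.persistent(True)
    print(acc.state_dict().keys())   # e.g. odict_keys(['correct', 'total'])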
else: - hash_vals.append(val) - - return hash(tuple(hash_vals)) - - def __add__(self, other: Any): - from pytorch_lightning.metrics.compositional import CompositionalMetric - - return CompositionalMetric(torch.add, self, other) - - def __and__(self, other: Any): - from pytorch_lightning.metrics.compositional import CompositionalMetric - - return CompositionalMetric(torch.bitwise_and, self, other) - - def __eq__(self, other: Any): - from pytorch_lightning.metrics.compositional import CompositionalMetric - - return CompositionalMetric(torch.eq, self, other) - - def __floordiv__(self, other: Any): - from pytorch_lightning.metrics.compositional import CompositionalMetric - - return CompositionalMetric(torch.floor_divide, self, other) - - def __ge__(self, other: Any): - from pytorch_lightning.metrics.compositional import CompositionalMetric - - return CompositionalMetric(torch.ge, self, other) - def __gt__(self, other: Any): - from pytorch_lightning.metrics.compositional import CompositionalMetric - return CompositionalMetric(torch.gt, self, other) - - def __le__(self, other: Any): - from pytorch_lightning.metrics.compositional import CompositionalMetric - - return CompositionalMetric(torch.le, self, other) - - def __lt__(self, other: Any): - from pytorch_lightning.metrics.compositional import CompositionalMetric - - return CompositionalMetric(torch.lt, self, other) - - def __matmul__(self, other: Any): - from pytorch_lightning.metrics.compositional import CompositionalMetric - - return CompositionalMetric(torch.matmul, self, other) - - def __mod__(self, other: Any): - from pytorch_lightning.metrics.compositional import CompositionalMetric - - return CompositionalMetric(torch.fmod, self, other) - - def __mul__(self, other: Any): - from pytorch_lightning.metrics.compositional import CompositionalMetric - - return CompositionalMetric(torch.mul, self, other) - - def __ne__(self, other: Any): - from pytorch_lightning.metrics.compositional import CompositionalMetric - - return CompositionalMetric(torch.ne, self, other) - - def __or__(self, other: Any): - from pytorch_lightning.metrics.compositional import CompositionalMetric - - return CompositionalMetric(torch.bitwise_or, self, other) - - def __pow__(self, other: Any): - from pytorch_lightning.metrics.compositional import CompositionalMetric - - return CompositionalMetric(torch.pow, self, other) - - def __radd__(self, other: Any): - from pytorch_lightning.metrics.compositional import CompositionalMetric - - return CompositionalMetric(torch.add, other, self) - - def __rand__(self, other: Any): - from pytorch_lightning.metrics.compositional import CompositionalMetric - - # swap them since bitwise_and only supports that way and it's commutative - return CompositionalMetric(torch.bitwise_and, self, other) - - def __rfloordiv__(self, other: Any): - from pytorch_lightning.metrics.compositional import CompositionalMetric - - return CompositionalMetric(torch.floor_divide, other, self) - - def __rmatmul__(self, other: Any): - from pytorch_lightning.metrics.compositional import CompositionalMetric - - return CompositionalMetric(torch.matmul, other, self) - - def __rmod__(self, other: Any): - from pytorch_lightning.metrics.compositional import CompositionalMetric - - return CompositionalMetric(torch.fmod, other, self) - - def __rmul__(self, other: Any): - from pytorch_lightning.metrics.compositional import CompositionalMetric - - return CompositionalMetric(torch.mul, other, self) - - def __ror__(self, other: Any): - from 
pytorch_lightning.metrics.compositional import CompositionalMetric - - return CompositionalMetric(torch.bitwise_or, other, self) - - def __rpow__(self, other: Any): - from pytorch_lightning.metrics.compositional import CompositionalMetric - - return CompositionalMetric(torch.pow, other, self) - - def __rsub__(self, other: Any): - from pytorch_lightning.metrics.compositional import CompositionalMetric - - return CompositionalMetric(torch.sub, other, self) - - def __rtruediv__(self, other: Any): - from pytorch_lightning.metrics.compositional import CompositionalMetric - - return CompositionalMetric(torch.true_divide, other, self) - - def __rxor__(self, other: Any): - from pytorch_lightning.metrics.compositional import CompositionalMetric - - return CompositionalMetric(torch.bitwise_xor, other, self) - - def __sub__(self, other: Any): - from pytorch_lightning.metrics.compositional import CompositionalMetric - - return CompositionalMetric(torch.sub, self, other) - - def __truediv__(self, other: Any): - from pytorch_lightning.metrics.compositional import CompositionalMetric - - return CompositionalMetric(torch.true_divide, self, other) - - def __xor__(self, other: Any): - from pytorch_lightning.metrics.compositional import CompositionalMetric - - return CompositionalMetric(torch.bitwise_xor, self, other) - - def __abs__(self): - from pytorch_lightning.metrics.compositional import CompositionalMetric - - return CompositionalMetric(torch.abs, self, None) - - def __inv__(self): - from pytorch_lightning.metrics.compositional import CompositionalMetric - - return CompositionalMetric(torch.bitwise_not, self, None) - - def __invert__(self): - return self.__inv__() - - def __neg__(self): - from pytorch_lightning.metrics.compositional import CompositionalMetric - - return CompositionalMetric(_neg, self, None) - - def __pos__(self): - from pytorch_lightning.metrics.compositional import CompositionalMetric - - return CompositionalMetric(torch.abs, self, None) - - -def _neg(tensor: torch.Tensor): - return -torch.abs(tensor) - - -class MetricCollection(nn.ModuleDict): - """ - MetricCollection class can be used to chain metrics that have the same - call pattern into one single class. - - Args: - metrics: One of the following - - * list or tuple: if metrics are passed in as a list, will use the - metrics class name as key for output dict. Therefore, two metrics - of the same class cannot be chained this way. - - * dict: if metrics are passed in as a dict, will use each key in the - dict as key for output dict. Use this format if you want to chain - together multiple of the same metric with different parameters. - - Raises: - ValueError: - If one of the elements of ``metrics`` is not an instance of ``pl.metrics.Metric``. - ValueError: - If two elements in ``metrics`` have the same ``name``. - ValueError: - If ``metrics`` is not a ``list``, ``tuple`` or a ``dict``. - - Example (input as list): - - >>> from pytorch_lightning.metrics import MetricCollection, Accuracy, Precision, Recall - >>> target = torch.tensor([0, 2, 0, 2, 0, 1, 0, 2]) - >>> preds = torch.tensor([2, 1, 2, 0, 1, 2, 2, 2]) - >>> metrics = MetricCollection([Accuracy(), - ... Precision(num_classes=3, average='macro'), - ... Recall(num_classes=3, average='macro')]) - >>> metrics(preds, target) - {'Accuracy': tensor(0.1250), 'Precision': tensor(0.0667), 'Recall': tensor(0.1111)} - - Example (input as dict): - - >>> metrics = MetricCollection({'micro_recall': Recall(num_classes=3, average='micro'), - ... 
'macro_recall': Recall(num_classes=3, average='macro')}) - >>> metrics(preds, target) - {'micro_recall': tensor(0.1250), 'macro_recall': tensor(0.1111)} - - """ +class MetricCollection(_MetricCollection): + @deprecated_metrics(target=_MetricCollection) def __init__(self, metrics: Union[List[Metric], Tuple[Metric], Dict[str, Metric]]): - super().__init__() - if isinstance(metrics, dict): - # Check all values are metrics - for name, metric in metrics.items(): - if not isinstance(metric, Metric): - raise ValueError( - f"Value {metric} belonging to key {name}" - " is not an instance of `pl.metrics.Metric`" - ) - self[name] = metric - elif isinstance(metrics, (tuple, list)): - for metric in metrics: - if not isinstance(metric, Metric): - raise ValueError( - f"Input {metric} to `MetricCollection` is not a instance" - " of `pl.metrics.Metric`" - ) - name = metric.__class__.__name__ - if name in self: - raise ValueError(f"Encountered two metrics both named {name}") - self[name] = metric - else: - raise ValueError("Unknown input to MetricCollection.") - - def forward(self, *args, **kwargs) -> Dict[str, Any]: # pylint: disable=E0202 - """ - Iteratively call forward for each metric. Positional arguments (args) will - be passed to every metric in the collection, while keyword arguments (kwargs) - will be filtered based on the signature of the individual metric. - """ - return {k: m(*args, **m._filter_kwargs(**kwargs)) for k, m in self.items()} - - def update(self, *args, **kwargs): # pylint: disable=E0202 - """ - Iteratively call update for each metric. Positional arguments (args) will - be passed to every metric in the collection, while keyword arguments (kwargs) - will be filtered based on the signature of the individual metric. """ - for _, m in self.items(): - m_kwargs = m._filter_kwargs(**kwargs) - m.update(*args, **m_kwargs) - - def compute(self) -> Dict[str, Any]: - return {k: m.compute() for k, m in self.items()} - - def reset(self): - """ Iteratively call reset for each metric """ - for _, m in self.items(): - m.reset() - - def clone(self): - """ Make a copy of the metric collection """ - return deepcopy(self) - - def persistent(self, mode: bool = True): - """Method for post-init to change if metric states should be saved to - its state_dict + .. deprecated:: + Use :class:`torchmetrics.MetricCollection`. Will be removed in v1.5.0. """ - for _, m in self.items(): - m.persistent(mode) diff --git a/pytorch_lightning/metrics/regression/explained_variance.py b/pytorch_lightning/metrics/regression/explained_variance.py index fc033fcd16759..0f94ae2fb3754 100644 --- a/pytorch_lightning/metrics/regression/explained_variance.py +++ b/pytorch_lightning/metrics/regression/explained_variance.py @@ -13,72 +13,14 @@ # limitations under the License. from typing import Any, Callable, Optional -import torch +from torchmetrics import ExplainedVariance as _ExplainedVariance -from pytorch_lightning.metrics.functional.explained_variance import ( - _explained_variance_compute, - _explained_variance_update, -) -from pytorch_lightning.metrics.metric import Metric -from pytorch_lightning.utilities import rank_zero_warn +from pytorch_lightning.metrics.utils import deprecated_metrics -class ExplainedVariance(Metric): - r""" - Computes `explained variance - `_: - - .. math:: \text{ExplainedVariance} = 1 - \frac{\text{Var}(y - \hat{y})}{\text{Var}(y)} - - Where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a - tensor of predictions. 
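The arithmetic overloads removed further up are what make metric objects composable: each operator returns a lazy ``CompositionalMetric`` whose ``update`` is forwarded to both operands and whose ``compute`` applies the corresponding torch op to their results. A sketch of that behaviour using the ``torchmetrics`` classes the deprecated ones now wrap::

    import torch
    from torchmetrics import MeanAbsoluteError, MeanSquaredError

    mae = MeanAbsoluteError()
    mse = MeanSquaredError()

    combined = mae + mse          # builds a CompositionalMetric; nothing is evaluated yet

    preds = torch.tensor([2.5, 0.0, 2.0, 8.0])
    target = torch.tensor([3.0, -0.5, 2.0, 7.0])

    combined.update(preds, target)   # forwards the update to both operands
    print(combined.compute())        # mae.compute() + mse.compute()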
- - Forward accepts - - - ``preds`` (float tensor): ``(N,)`` or ``(N, ...)`` (multioutput) - - ``target`` (long tensor): ``(N,)`` or ``(N, ...)`` (multioutput) - - In the case of multioutput, as default the variances will be uniformly - averaged over the additional dimensions. Please see argument `multioutput` - for changing this behavior. - - Args: - multioutput: - Defines aggregation in the case of multiple output scores. Can be one - of the following strings (default is `'uniform_average'`.): - - * `'raw_values'` returns full set of scores - * `'uniform_average'` scores are uniformly averaged - * `'variance_weighted'` scores are weighted by their individual variances - - compute_on_step: - Forward only calls ``update()`` and return None if this is set to False. default: True - dist_sync_on_step: - Synchronize metric state across processes at each ``forward()`` - before returning the value at the step. default: False - process_group: - Specify the process group on which synchronization is called. default: None (which selects the entire world) - - Raises: - ValueError: - If ``multioutput`` is not one of ``"raw_values"``, ``"uniform_average"`` or ``"variance_weighted"``. - - Example: - - >>> from pytorch_lightning.metrics import ExplainedVariance - >>> target = torch.tensor([3, -0.5, 2, 7]) - >>> preds = torch.tensor([2.5, 0.0, 2, 8]) - >>> explained_variance = ExplainedVariance() - >>> explained_variance(preds, target) - tensor(0.9572) - - >>> target = torch.tensor([[0.5, 1], [-1, 1], [7, -6]]) - >>> preds = torch.tensor([[0, 2], [-1, 2], [8, -5]]) - >>> explained_variance = ExplainedVariance(multioutput='raw_values') - >>> explained_variance(preds, target) - tensor([0.9677, 1.0000]) - """ +class ExplainedVariance(_ExplainedVariance): + @deprecated_metrics(target=_ExplainedVariance) def __init__( self, multioutput: str = 'uniform_average', @@ -87,43 +29,9 @@ def __init__( process_group: Optional[Any] = None, dist_sync_fn: Callable = None, ): - super().__init__( - compute_on_step=compute_on_step, - dist_sync_on_step=dist_sync_on_step, - process_group=process_group, - dist_sync_fn=dist_sync_fn, - ) - allowed_multioutput = ('raw_values', 'uniform_average', 'variance_weighted') - if multioutput not in allowed_multioutput: - raise ValueError( - f'Invalid input to argument `multioutput`. Choose one of the following: {allowed_multioutput}' - ) - self.multioutput = multioutput - self.add_state("y", default=[], dist_reduce_fx=None) - self.add_state("y_pred", default=[], dist_reduce_fx=None) - - rank_zero_warn( - 'Metric `ExplainedVariance` will save all targets and' - ' predictions in buffer. For large datasets this may lead' - ' to large memory footprint.' - ) - - def update(self, preds: torch.Tensor, target: torch.Tensor): - """ - Update state with predictions and targets. - - Args: - preds: Predictions from model - target: Ground truth values """ - preds, target = _explained_variance_update(preds, target) - self.y_pred.append(preds) - self.y.append(target) + This implementation refers to :class:`~torchmetrics.ExplainedVariance`. - def compute(self): - """ - Computes explained variance over state. + .. deprecated:: + Use :class:`~torchmetrics.ExplainedVariance`. Will be removed in v1.5.0. 
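In practice the shim keeps the old import path working while pointing users at the new package: constructing the deprecated class emits a message through ``rank_zero_deprecation`` and otherwise behaves exactly like the ``torchmetrics`` class it subclasses. A small migration sketch using the values from the doctest above::

    import torch

    from pytorch_lightning.metrics import ExplainedVariance       # deprecated in 1.3, removed in 1.5
    from torchmetrics import ExplainedVariance as TMExplainedVariance

    target = torch.tensor([3.0, -0.5, 2.0, 7.0])
    preds = torch.tensor([2.5, 0.0, 2.0, 8.0])

    old = ExplainedVariance()        # triggers the deprecation message
    new = TMExplainedVariance()
    assert torch.allclose(old(preds, target), new(preds, target))  # tensor(0.9572)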
""" - preds = torch.cat(self.y_pred, dim=0) - target = torch.cat(self.y, dim=0) - return _explained_variance_compute(preds, target, self.multioutput) diff --git a/pytorch_lightning/metrics/regression/mean_absolute_error.py b/pytorch_lightning/metrics/regression/mean_absolute_error.py index ca184daf736b8..57c7db420445b 100644 --- a/pytorch_lightning/metrics/regression/mean_absolute_error.py +++ b/pytorch_lightning/metrics/regression/mean_absolute_error.py @@ -13,42 +13,14 @@ # limitations under the License. from typing import Any, Callable, Optional -import torch +from torchmetrics import MeanAbsoluteError as _MeanAbsoluteError -from pytorch_lightning.metrics.functional.mean_absolute_error import ( - _mean_absolute_error_compute, - _mean_absolute_error_update, -) -from pytorch_lightning.metrics.metric import Metric +from pytorch_lightning.metrics.utils import deprecated_metrics -class MeanAbsoluteError(Metric): - r""" - Computes `mean absolute error `_ (MAE): - - .. math:: \text{MAE} = \frac{1}{N}\sum_i^N | y_i - \hat{y_i} | - - Where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a tensor of predictions. - - Args: - compute_on_step: - Forward only calls ``update()`` and return None if this is set to False. default: True - dist_sync_on_step: - Synchronize metric state across processes at each ``forward()`` - before returning the value at the step. default: False - process_group: - Specify the process group on which synchronization is called. default: None (which selects the entire world) - - Example: - - >>> from pytorch_lightning.metrics import MeanAbsoluteError - >>> target = torch.tensor([3.0, -0.5, 2.0, 7.0]) - >>> preds = torch.tensor([2.5, 0.0, 2.0, 8.0]) - >>> mean_absolute_error = MeanAbsoluteError() - >>> mean_absolute_error(preds, target) - tensor(0.5000) - """ +class MeanAbsoluteError(_MeanAbsoluteError): + @deprecated_metrics(target=_MeanAbsoluteError) def __init__( self, compute_on_step: bool = True, @@ -56,31 +28,9 @@ def __init__( process_group: Optional[Any] = None, dist_sync_fn: Callable = None, ): - super().__init__( - compute_on_step=compute_on_step, - dist_sync_on_step=dist_sync_on_step, - process_group=process_group, - dist_sync_fn=dist_sync_fn, - ) - - self.add_state("sum_abs_error", default=torch.tensor(0.0), dist_reduce_fx="sum") - self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") - - def update(self, preds: torch.Tensor, target: torch.Tensor): - """ - Update state with predictions and targets. - - Args: - preds: Predictions from model - target: Ground truth values """ - sum_abs_error, n_obs = _mean_absolute_error_update(preds, target) + This implementation refers to :class:`~torchmetrics.MeanAbsoluteError`. - self.sum_abs_error += sum_abs_error - self.total += n_obs - - def compute(self): - """ - Computes mean absolute error over state. + .. deprecated:: + Use :class:`~torchmetrics.MeanAbsoluteError`. Will be removed in v1.5.0. """ - return _mean_absolute_error_compute(self.sum_abs_error, self.total) diff --git a/pytorch_lightning/metrics/regression/mean_squared_error.py b/pytorch_lightning/metrics/regression/mean_squared_error.py index 09f275ded8638..c8e9c151c99d9 100644 --- a/pytorch_lightning/metrics/regression/mean_squared_error.py +++ b/pytorch_lightning/metrics/regression/mean_squared_error.py @@ -13,43 +13,14 @@ # limitations under the License. 
from typing import Any, Callable, Optional -import torch +from torchmetrics import MeanSquaredError as _MeanSquaredError -from pytorch_lightning.metrics.functional.mean_squared_error import ( - _mean_squared_error_compute, - _mean_squared_error_update, -) -from pytorch_lightning.metrics.metric import Metric +from pytorch_lightning.metrics.utils import deprecated_metrics -class MeanSquaredError(Metric): - r""" - Computes `mean squared error `_ (MSE): - - .. math:: \text{MSE} = \frac{1}{N}\sum_i^N(y_i - \hat{y_i})^2 - - Where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a tensor of predictions. - - Args: - compute_on_step: - Forward only calls ``update()`` and return None if this is set to False. default: True - dist_sync_on_step: - Synchronize metric state across processes at each ``forward()`` - before returning the value at the step. default: False - process_group: - Specify the process group on which synchronization is called. default: None (which selects the entire world) - - Example: - - >>> from pytorch_lightning.metrics import MeanSquaredError - >>> target = torch.tensor([2.5, 5.0, 4.0, 8.0]) - >>> preds = torch.tensor([3.0, 5.0, 2.5, 7.0]) - >>> mean_squared_error = MeanSquaredError() - >>> mean_squared_error(preds, target) - tensor(0.8750) - - """ +class MeanSquaredError(_MeanSquaredError): + @deprecated_metrics(target=_MeanSquaredError) def __init__( self, compute_on_step: bool = True, @@ -57,31 +28,9 @@ def __init__( process_group: Optional[Any] = None, dist_sync_fn: Callable = None, ): - super().__init__( - compute_on_step=compute_on_step, - dist_sync_on_step=dist_sync_on_step, - process_group=process_group, - dist_sync_fn=dist_sync_fn, - ) - - self.add_state("sum_squared_error", default=torch.tensor(0.0), dist_reduce_fx="sum") - self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") - - def update(self, preds: torch.Tensor, target: torch.Tensor): - """ - Update state with predictions and targets. - - Args: - preds: Predictions from model - target: Ground truth values """ - sum_squared_error, n_obs = _mean_squared_error_update(preds, target) - - self.sum_squared_error += sum_squared_error - self.total += n_obs + This implementation refers to :class:`~torchmetrics.MeanSquaredError`. - def compute(self): - """ - Computes mean squared error over state. + .. deprecated:: + Use :class:`~torchmetrics.MeanSquaredError`. Will be removed in v1.5.0. """ - return _mean_squared_error_compute(self.sum_squared_error, self.total) diff --git a/pytorch_lightning/metrics/regression/mean_squared_log_error.py b/pytorch_lightning/metrics/regression/mean_squared_log_error.py index 18105e687b0b1..c8ee8a7069115 100644 --- a/pytorch_lightning/metrics/regression/mean_squared_log_error.py +++ b/pytorch_lightning/metrics/regression/mean_squared_log_error.py @@ -13,45 +13,14 @@ # limitations under the License. from typing import Any, Callable, Optional -import torch +from torchmetrics import MeanSquaredLogError as _MeanSquaredLogError -from pytorch_lightning.metrics.functional.mean_squared_log_error import ( - _mean_squared_log_error_compute, - _mean_squared_log_error_update, -) -from pytorch_lightning.metrics.metric import Metric +from pytorch_lightning.metrics.utils import deprecated_metrics -class MeanSquaredLogError(Metric): - r""" - Computes `mean squared logarithmic error - `_ - (MSLE): - - .. 
math:: \text{MSLE} = \frac{1}{N}\sum_i^N (\log_e(1 + y_i) - \log_e(1 + \hat{y_i}))^2 - - Where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a tensor of predictions. - - Args: - compute_on_step: - Forward only calls ``update()`` and return None if this is set to False. default: True - dist_sync_on_step: - Synchronize metric state across processes at each ``forward()`` - before returning the value at the step. default: False - process_group: - Specify the process group on which synchronization is called. default: None (which selects the entire world) - - Example: - - >>> from pytorch_lightning.metrics import MeanSquaredLogError - >>> target = torch.tensor([2.5, 5, 4, 8]) - >>> preds = torch.tensor([3, 5, 2.5, 7]) - >>> mean_squared_log_error = MeanSquaredLogError() - >>> mean_squared_log_error(preds, target) - tensor(0.0397) - - """ +class MeanSquaredLogError(_MeanSquaredLogError): + @deprecated_metrics(target=_MeanSquaredLogError) def __init__( self, compute_on_step: bool = True, @@ -59,31 +28,9 @@ def __init__( process_group: Optional[Any] = None, dist_sync_fn: Callable = None, ): - super().__init__( - compute_on_step=compute_on_step, - dist_sync_on_step=dist_sync_on_step, - process_group=process_group, - dist_sync_fn=dist_sync_fn, - ) - - self.add_state("sum_squared_log_error", default=torch.tensor(0.0), dist_reduce_fx="sum") - self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") - - def update(self, preds: torch.Tensor, target: torch.Tensor): - """ - Update state with predictions and targets. - - Args: - preds: Predictions from model - target: Ground truth values """ - sum_squared_log_error, n_obs = _mean_squared_log_error_update(preds, target) - - self.sum_squared_log_error += sum_squared_log_error - self.total += n_obs + This implementation refers to :class:`~torchmetrics.MeanSquaredLogError`. - def compute(self): - """ - Compute mean squared logarithmic error over state. + .. deprecated:: + Use :class:`~torchmetrics.MeanSquaredLogError`. Will be removed in v1.5.0. """ - return _mean_squared_log_error_compute(self.sum_squared_log_error, self.total) diff --git a/pytorch_lightning/metrics/regression/psnr.py b/pytorch_lightning/metrics/regression/psnr.py index 8a38bf515ebca..f972e9a8e2b5e 100644 --- a/pytorch_lightning/metrics/regression/psnr.py +++ b/pytorch_lightning/metrics/regression/psnr.py @@ -11,61 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Tuple, Union +from typing import Any, Optional, Tuple, Union -import torch +from torchmetrics import PSNR as _PSNR -from pytorch_lightning import utilities -from pytorch_lightning.metrics.functional.psnr import _psnr_compute, _psnr_update -from pytorch_lightning.metrics.metric import Metric +from pytorch_lightning.metrics.utils import deprecated_metrics -class PSNR(Metric): - r""" - Computes `peak signal-to-noise ratio `_ (PSNR): - - .. math:: \text{PSNR}(I, J) = 10 * \log_{10} \left(\frac{\max(I)^2}{\text{MSE}(I, J)}\right) - - Where :math:`\text{MSE}` denotes the `mean-squared-error - `_ function. - - Args: - data_range: - the range of the data. If None, it is determined from the data (max - min). - The ``data_range`` must be given when ``dim`` is not None. - base: a base of a logarithm to use (default: 10) - reduction: a method to reduce metric score over labels. 
- - - ``'elementwise_mean'``: takes the mean (default) - - ``'sum'``: takes the sum - - ``'none'``: no reduction will be applied - - dim: - Dimensions to reduce PSNR scores over, provided as either an integer or a list of integers. Default is - None meaning scores will be reduced across all dimensions and all batches. - compute_on_step: - Forward only calls ``update()`` and return None if this is set to False. default: True - dist_sync_on_step: - Synchronize metric state across processes at each ``forward()`` - before returning the value at the step. default: False - process_group: - Specify the process group on which synchronization is called. default: None (which selects the entire world) - - Raises: - ValueError: - If ``dim`` is not ``None`` and ``data_range`` is not given. - - Example: - - >>> from pytorch_lightning.metrics import PSNR - >>> psnr = PSNR() - >>> preds = torch.tensor([[0.0, 1.0], [2.0, 3.0]]) - >>> target = torch.tensor([[3.0, 2.0], [1.0, 0.0]]) - >>> psnr(preds, target) - tensor(2.5527) - - """ +class PSNR(_PSNR): + @deprecated_metrics(target=_PSNR) def __init__( self, data_range: Optional[float] = None, @@ -76,71 +31,9 @@ def __init__( dist_sync_on_step: bool = False, process_group: Optional[Any] = None, ): - super().__init__( - compute_on_step=compute_on_step, - dist_sync_on_step=dist_sync_on_step, - process_group=process_group, - ) - - if dim is None and reduction != 'elementwise_mean': - utilities.rank_zero_warn(f'The `reduction={reduction}` will not have any effect when `dim` is None.') - - if dim is None: - self.add_state("sum_squared_error", default=torch.tensor(0.0), dist_reduce_fx="sum") - self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") - else: - self.add_state("sum_squared_error", default=[]) - self.add_state("total", default=[]) - - if data_range is None: - if dim is not None: - # Maybe we could use `torch.amax(target, dim=dim) - torch.amin(target, dim=dim)` in PyTorch 1.7 to - # calculate `data_range` in the future. - raise ValueError("The `data_range` must be given when `dim` is not None.") - - self.data_range = None - self.add_state("min_target", default=torch.tensor(0.0), dist_reduce_fx=torch.min) - self.add_state("max_target", default=torch.tensor(0.0), dist_reduce_fx=torch.max) - else: - self.register_buffer("data_range", torch.tensor(float(data_range))) - self.base = base - self.reduction = reduction - self.dim = tuple(dim) if isinstance(dim, Sequence) else dim - - def update(self, preds: torch.Tensor, target: torch.Tensor): """ - Update state with predictions and targets. + This implementation refers to :class:`~torchmetrics.PSNR`. - Args: - preds: Predictions from model - target: Ground truth values + .. deprecated:: + Use :class:`~torchmetrics.PSNR`. Will be removed in v1.5.0. """ - sum_squared_error, n_obs = _psnr_update(preds, target, dim=self.dim) - if self.dim is None: - if self.data_range is None: - # keep track of min and max target values - self.min_target = min(target.min(), self.min_target) - self.max_target = max(target.max(), self.max_target) - - self.sum_squared_error += sum_squared_error - self.total += n_obs - else: - self.sum_squared_error.append(sum_squared_error) - self.total.append(n_obs) - - def compute(self): - """ - Compute peak signal-to-noise ratio over state. 
- """ - if self.data_range is not None: - data_range = self.data_range - else: - data_range = self.max_target - self.min_target - - if self.dim is None: - sum_squared_error = self.sum_squared_error - total = self.total - else: - sum_squared_error = torch.cat([values.flatten() for values in self.sum_squared_error]) - total = torch.cat([values.flatten() for values in self.total]) - return _psnr_compute(sum_squared_error, total, data_range, base=self.base, reduction=self.reduction) diff --git a/pytorch_lightning/metrics/regression/r2score.py b/pytorch_lightning/metrics/regression/r2score.py index 40d9d24711375..ad5f7f3bd8d07 100644 --- a/pytorch_lightning/metrics/regression/r2score.py +++ b/pytorch_lightning/metrics/regression/r2score.py @@ -13,81 +13,14 @@ # limitations under the License. from typing import Any, Callable, Optional -import torch +from torchmetrics import R2Score as _R2Score -from pytorch_lightning.metrics.functional.r2score import _r2score_compute, _r2score_update -from pytorch_lightning.metrics.metric import Metric +from pytorch_lightning.metrics.utils import deprecated_metrics -class R2Score(Metric): - r""" - Computes r2 score also known as `coefficient of determination - `_: - - .. math:: R^2 = 1 - \frac{SS_res}{SS_tot} - - where :math:`SS_res=\sum_i (y_i - f(x_i))^2` is the sum of residual squares, and - :math:`SS_tot=\sum_i (y_i - \bar{y})^2` is total sum of squares. Can also calculate - adjusted r2 score given by - - .. math:: R^2_adj = 1 - \frac{(1-R^2)(n-1)}{n-k-1} - - where the parameter :math:`k` (the number of independent regressors) should - be provided as the `adjusted` argument. - - Forward accepts - - - ``preds`` (float tensor): ``(N,)`` or ``(N, M)`` (multioutput) - - ``target`` (float tensor): ``(N,)`` or ``(N, M)`` (multioutput) - - In the case of multioutput, as default the variances will be uniformly - averaged over the additional dimensions. Please see argument `multioutput` - for changing this behavior. - - Args: - num_outputs: - Number of outputs in multioutput setting (default is 1) - adjusted: - number of independent regressors for calculating adjusted r2 score. - Default 0 (standard r2 score). - multioutput: - Defines aggregation in the case of multiple output scores. Can be one - of the following strings (default is ``'uniform_average'``.): - - * ``'raw_values'`` returns full set of scores - * ``'uniform_average'`` scores are uniformly averaged - * ``'variance_weighted'`` scores are weighted by their individual variances - - compute_on_step: - Forward only calls ``update()`` and return None if this is set to False. default: True - dist_sync_on_step: - Synchronize metric state across processes at each ``forward()`` - before returning the value at the step. default: False - process_group: - Specify the process group on which synchronization is called. default: None (which selects the entire world) - - Raises: - ValueError: - If ``adjusted`` parameter is not an integer larger or equal to 0. - ValueError: - If ``multioutput`` is not one of ``"raw_values"``, ``"uniform_average"`` or ``"variance_weighted"``. 
- - Example: - - >>> from pytorch_lightning.metrics import R2Score - >>> target = torch.tensor([3, -0.5, 2, 7]) - >>> preds = torch.tensor([2.5, 0.0, 2, 8]) - >>> r2score = R2Score() - >>> r2score(preds, target) - tensor(0.9486) - - >>> target = torch.tensor([[0.5, 1], [-1, 1], [7, -6]]) - >>> preds = torch.tensor([[0, 2], [-1, 2], [8, -5]]) - >>> r2score = R2Score(num_outputs=2, multioutput='raw_values') - >>> r2score(preds, target) - tensor([0.9654, 0.9082]) - """ +class R2Score(_R2Score): + @deprecated_metrics(target=_R2Score) def __init__( self, num_outputs: int = 1, @@ -98,50 +31,9 @@ def __init__( process_group: Optional[Any] = None, dist_sync_fn: Callable = None, ): - super().__init__( - compute_on_step=compute_on_step, - dist_sync_on_step=dist_sync_on_step, - process_group=process_group, - dist_sync_fn=dist_sync_fn, - ) - - self.num_outputs = num_outputs - - if adjusted < 0 or not isinstance(adjusted, int): - raise ValueError('`adjusted` parameter should be an integer larger or equal to 0.') - self.adjusted = adjusted - - allowed_multioutput = ('raw_values', 'uniform_average', 'variance_weighted') - if multioutput not in allowed_multioutput: - raise ValueError( - f'Invalid input to argument `multioutput`. Choose one of the following: {allowed_multioutput}' - ) - self.multioutput = multioutput - - self.add_state("sum_squared_error", default=torch.zeros(self.num_outputs), dist_reduce_fx="sum") - self.add_state("sum_error", default=torch.zeros(self.num_outputs), dist_reduce_fx="sum") - self.add_state("residual", default=torch.zeros(self.num_outputs), dist_reduce_fx="sum") - self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") - - def update(self, preds: torch.Tensor, target: torch.Tensor): - """ - Update state with predictions and targets. - - Args: - preds: Predictions from model - target: Ground truth values """ - sum_squared_error, sum_error, residual, total = _r2score_update(preds, target) + This implementation refers to :class:`~torchmetrics.R2Score`. - self.sum_squared_error += sum_squared_error - self.sum_error += sum_error - self.residual += residual - self.total += total - - def compute(self) -> torch.Tensor: - """ - Computes r2 score over the metric states. + .. deprecated:: + Use :class:`~torchmetrics.R2Score`. Will be removed in v1.5.0. """ - return _r2score_compute( - self.sum_squared_error, self.sum_error, self.residual, self.total, self.adjusted, self.multioutput - ) diff --git a/pytorch_lightning/metrics/regression/ssim.py b/pytorch_lightning/metrics/regression/ssim.py index 09b55fb2bb456..cf5571f3e68f4 100644 --- a/pytorch_lightning/metrics/regression/ssim.py +++ b/pytorch_lightning/metrics/regression/ssim.py @@ -13,43 +13,14 @@ # limitations under the License. from typing import Any, Optional, Sequence -import torch +from torchmetrics import SSIM as _SSIM -from pytorch_lightning.metrics.functional.ssim import _ssim_compute, _ssim_update -from pytorch_lightning.metrics.metric import Metric -from pytorch_lightning.utilities import rank_zero_warn +from pytorch_lightning.metrics.utils import deprecated_metrics -class SSIM(Metric): - """ - Computes `Structual Similarity Index Measure - `_ (SSIM). - - Args: - kernel_size: size of the gaussian kernel (default: (11, 11)) - sigma: Standard deviation of the gaussian kernel (default: (1.5, 1.5)) - reduction: a method to reduce metric score over labels. 
- - - ``'elementwise_mean'``: takes the mean (default) - - ``'sum'``: takes the sum - - ``'none'``: no reduction will be applied - - data_range: Range of the image. If ``None``, it is determined from the image (max - min) - k1: Parameter of SSIM. Default: 0.01 - k2: Parameter of SSIM. Default: 0.03 - - Return: - Tensor with SSIM score - - Example: - >>> from pytorch_lightning.metrics import SSIM - >>> preds = torch.rand([16, 1, 16, 16]) - >>> target = preds * 0.75 - >>> ssim = SSIM() - >>> ssim(preds, target) - tensor(0.9219) - """ +class SSIM(_SSIM): + @deprecated_metrics(target=_SSIM) def __init__( self, kernel_size: Sequence[int] = (11, 11), @@ -62,44 +33,9 @@ def __init__( dist_sync_on_step: bool = False, process_group: Optional[Any] = None, ): - super().__init__( - compute_on_step=compute_on_step, - dist_sync_on_step=dist_sync_on_step, - process_group=process_group, - ) - rank_zero_warn( - 'Metric `SSIM` will save all targets and' - ' predictions in buffer. For large datasets this may lead' - ' to large memory footprint.' - ) - - self.add_state("y", default=[], dist_reduce_fx=None) - self.add_state("y_pred", default=[], dist_reduce_fx=None) - self.kernel_size = kernel_size - self.sigma = sigma - self.data_range = data_range - self.k1 = k1 - self.k2 = k2 - self.reduction = reduction - - def update(self, preds: torch.Tensor, target: torch.Tensor): """ - Update state with predictions and targets. + This implementation refers to :class:`~torchmetrics.SSIM`. - Args: - preds: Predictions from model - target: Ground truth values - """ - preds, target = _ssim_update(preds, target) - self.y_pred.append(preds) - self.y.append(target) - - def compute(self): - """ - Computes explained variance over state. + .. deprecated:: + Use :class:`~torchmetrics.SSIM`. Will be removed in v1.5.0. """ - preds = torch.cat(self.y_pred, dim=0) - target = torch.cat(self.y, dim=0) - return _ssim_compute( - preds, target, self.kernel_size, self.sigma, self.reduction, self.data_range, self.k1, self.k2 - ) diff --git a/pytorch_lightning/metrics/utils.py b/pytorch_lightning/metrics/utils.py index cd0713fde0173..4adc88a37ba21 100644 --- a/pytorch_lightning/metrics/utils.py +++ b/pytorch_lightning/metrics/utils.py @@ -11,293 +11,86 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Optional, Tuple +from functools import partial +from typing import Optional import torch +from deprecate import deprecated +from torchmetrics.utilities.data import dim_zero_cat as _dim_zero_cat +from torchmetrics.utilities.data import dim_zero_mean as _dim_zero_mean +from torchmetrics.utilities.data import dim_zero_sum as _dim_zero_sum +from torchmetrics.utilities.data import get_num_classes as _get_num_classes +from torchmetrics.utilities.data import select_topk as _select_topk +from torchmetrics.utilities.data import to_categorical as _to_categorical +from torchmetrics.utilities.data import to_onehot as _to_onehot +from torchmetrics.utilities.distributed import class_reduce as _class_reduce +from torchmetrics.utilities.distributed import reduce as _reduce -from pytorch_lightning.utilities import rank_zero_warn +from pytorch_lightning.utilities import rank_zero_deprecation -METRIC_EPS = 1e-6 +deprecated_metrics = partial(deprecated, deprecated_in="1.3.0", remove_in="1.5.0", stream=rank_zero_deprecation) +@deprecated_metrics(target=_dim_zero_cat) def dim_zero_cat(x): - x = x if isinstance(x, (list, tuple)) else [x] - return torch.cat(x, dim=0) + pass +@deprecated_metrics(target=_dim_zero_sum) def dim_zero_sum(x): - return torch.sum(x, dim=0) + pass +@deprecated_metrics(target=_dim_zero_mean) def dim_zero_mean(x): - return torch.mean(x, dim=0) + pass -def _flatten(x): - return [item for sublist in x for item in sublist] - - -def _check_same_shape(pred: torch.Tensor, target: torch.Tensor): - """ Check that predictions and target have the same shape, else raise error """ - if pred.shape != target.shape: - raise RuntimeError("Predictions and targets are expected to have the same shape") - - -def _input_format_classification_one_hot( - num_classes: int, - preds: torch.Tensor, - target: torch.Tensor, - threshold: float = 0.5, - multilabel: bool = False -) -> Tuple[torch.Tensor, torch.Tensor]: - """Convert preds and target tensors into one hot spare label tensors - - Args: - num_classes: number of classes - preds: either tensor with labels, tensor with probabilities/logits or - multilabel tensor - target: tensor with ground true labels - threshold: float used for thresholding multilabel input - multilabel: boolean flag indicating if input is multilabel - - Raises: - ValueError: - If ``preds`` and ``target`` don't have the same number of dimensions - or one additional dimension for ``preds``. 
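``deprecated_metrics`` above is just ``pyDeprecate``'s ``deprecated`` decorator with the version metadata pre-filled: the decorated body is never executed, the call is routed to ``target``, and the deprecation message goes through ``rank_zero_deprecation``. A self-contained sketch of the same pattern; ``_new_scale``/``old_scale`` are made-up names, not part of this change::

    from functools import partial

    from deprecate import deprecated


    def _new_scale(x: float, factor: float = 2.0) -> float:
        """The replacement implementation."""
        return x * factor


    deprecated_example = partial(deprecated, deprecated_in="1.3.0", remove_in="1.5.0")


    @deprecated_example(target=_new_scale)
    def old_scale(x: float, factor: float = 2.0) -> float:
        """
        .. deprecated::
            Use ``_new_scale``. Will be removed in v1.5.0.
        """


    print(old_scale(3.0))  # routed to _new_scale -> 6.0, plus a deprecation message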
- - Returns: - preds: one hot tensor of shape [num_classes, -1] with predicted labels - target: one hot tensors of shape [num_classes, -1] with true labels +@deprecated_metrics(target=_to_onehot) +def to_onehot(label_tensor: torch.Tensor, num_classes: Optional[int] = None) -> torch.Tensor: """ - if not (preds.ndim == target.ndim or preds.ndim == target.ndim + 1): - raise ValueError("preds and target must have same number of dimensions, or one additional dimension for preds") - - if preds.ndim == target.ndim + 1: - # multi class probabilites - preds = torch.argmax(preds, dim=1) - - if preds.ndim == target.ndim and preds.dtype in (torch.long, torch.int) and num_classes > 1 and not multilabel: - # multi-class - preds = to_onehot(preds, num_classes=num_classes) - target = to_onehot(target, num_classes=num_classes) - - elif preds.ndim == target.ndim and preds.is_floating_point(): - # binary or multilabel probablities - preds = (preds >= threshold).long() - - # transpose class as first dim and reshape - if preds.ndim > 1: - preds = preds.transpose(1, 0) - target = target.transpose(1, 0) - - return preds.reshape(num_classes, -1), target.reshape(num_classes, -1) - - -def to_onehot( - label_tensor: torch.Tensor, - num_classes: Optional[int] = None, -) -> torch.Tensor: + .. deprecated:: + Use :func:`torchmetrics.utilities.data.to_onehot`. Will be removed in v1.5.0. """ - Converts a dense label tensor to one-hot format - - Args: - label_tensor: dense label tensor, with shape [N, d1, d2, ...] - num_classes: number of classes C - - Returns: - A sparse label tensor with shape [N, C, d1, d2, ...] - - Example: - - >>> from pytorch_lightning.metrics.utils import to_onehot - >>> x = torch.tensor([1, 2, 3]) - >>> to_onehot(x) - tensor([[0, 1, 0, 0], - [0, 0, 1, 0], - [0, 0, 0, 1]]) - """ - if num_classes is None: - num_classes = int(label_tensor.max().detach().item() + 1) - - tensor_onehot = torch.zeros( - label_tensor.shape[0], - num_classes, - *label_tensor.shape[1:], - dtype=label_tensor.dtype, - device=label_tensor.device, - ) - index = label_tensor.long().unsqueeze(1).expand_as(tensor_onehot) - return tensor_onehot.scatter_(1, index, 1.0) +@deprecated_metrics(target=_select_topk) def select_topk(prob_tensor: torch.Tensor, topk: int = 1, dim: int = 1) -> torch.Tensor: """ - Convert a probability tensor to binary by selecting top-k highest entries. - - Args: - prob_tensor: dense tensor of shape ``[..., C, ...]``, where ``C`` is in the - position defined by the ``dim`` argument - topk: number of highest entries to turn into 1s - dim: dimension on which to compare entries - - Returns: - A binary tensor of the same shape as the input tensor of type torch.int32 - - Example: - - >>> from pytorch_lightning.metrics.utils import select_topk - >>> x = torch.tensor([[1.1, 2.0, 3.0], [2.0, 1.0, 0.5]]) - >>> select_topk(x, topk=2) - tensor([[0, 1, 1], - [1, 1, 0]], dtype=torch.int32) + .. deprecated:: + Use :func:`torchmetrics.utilities.data.select_topk`. Will be removed in v1.5.0. """ - zeros = torch.zeros_like(prob_tensor) - topk_tensor = zeros.scatter(dim, prob_tensor.topk(k=topk, dim=dim).indices, 1.0) - return topk_tensor.int() +@deprecated_metrics(target=_to_categorical) def to_categorical(tensor: torch.Tensor, argmax_dim: int = 1) -> torch.Tensor: """ - Converts a tensor of probabilities to a dense label tensor - - Args: - tensor: probabilities to get the categorical label [N, d1, d2, ...] - argmax_dim: dimension to apply - - Return: - A tensor with categorical labels [N, d2, ...] 
- - Example: - - >>> from pytorch_lightning.metrics.utils import to_categorical - >>> x = torch.tensor([[0.2, 0.5], [0.9, 0.1]]) - >>> to_categorical(x) - tensor([1, 0]) + .. deprecated:: + Use :func:`torchmetrics.utilities.data.to_categorical`. Will be removed in v1.5.0. """ - return torch.argmax(tensor, dim=argmax_dim) -def get_num_classes( - pred: torch.Tensor, - target: torch.Tensor, - num_classes: Optional[int] = None, -) -> int: +@deprecated_metrics(target=_get_num_classes) +def get_num_classes(pred: torch.Tensor, target: torch.Tensor, num_classes: Optional[int] = None) -> int: """ - Calculates the number of classes for a given prediction and target tensor. - - Args: - pred: predicted values - target: true labels - num_classes: number of classes if known - - Return: - An integer that represents the number of classes. + .. deprecated:: + Use :func:`torchmetrics.utilities.data.get_num_classes`. Will be removed in v1.5.0. """ - num_target_classes = int(target.max().detach().item() + 1) - num_pred_classes = int(pred.max().detach().item() + 1) - num_all_classes = max(num_target_classes, num_pred_classes) - - if num_classes is None: - num_classes = num_all_classes - elif num_classes != num_all_classes: - rank_zero_warn( - f"You have set {num_classes} number of classes which is" - f" different from predicted ({num_pred_classes}) and" - f" target ({num_target_classes}) number of classes", - RuntimeWarning, - ) - return num_classes +@deprecated_metrics(target=_reduce) def reduce(to_reduce: torch.Tensor, reduction: str) -> torch.Tensor: """ - Reduces a given tensor by a given reduction method - - Args: - to_reduce : the tensor, which shall be reduced - reduction : a string specifying the reduction method ('elementwise_mean', 'none', 'sum') - - Return: - reduced Tensor - - Raise: - ValueError if an invalid reduction parameter was given + .. deprecated:: + Use :func:`torchmetrics.utilities.reduce`. Will be removed in v1.5.0. """ - if reduction == "elementwise_mean": - return torch.mean(to_reduce) - if reduction == "none": - return to_reduce - if reduction == "sum": - return torch.sum(to_reduce) - raise ValueError("Reduction parameter unknown.") +@deprecated_metrics(target=_class_reduce) def class_reduce( num: torch.Tensor, denom: torch.Tensor, weights: torch.Tensor, class_reduction: str = "none" ) -> torch.Tensor: """ - Function used to reduce classification metrics of the form `num / denom * weights`. - For example for calculating standard accuracy the num would be number of - true positives per class, denom would be the support per class, and weights - would be a tensor of 1s - - Args: - num: numerator tensor - denom: denominator tensor - weights: weights for each class - class_reduction: reduction method for multiclass problems - - - ``'micro'``: calculate metrics globally (default) - - ``'macro'``: calculate metrics for each label, and find their unweighted mean. - - ``'weighted'``: calculate metrics for each label, and find their weighted mean. - - ``'none'`` or ``None``: returns calculated metric per class - - Raises: - ValueError: - If ``class_reduction`` is none of ``"micro"``, ``"macro"``, ``"weighted"``, ``"none"`` or ``None``. 
- """ - valid_reduction = ("micro", "macro", "weighted", "none", None) - if class_reduction == "micro": - fraction = torch.sum(num) / torch.sum(denom) - else: - fraction = num / denom - - # We need to take care of instances where the denom can be 0 - # for some (or all) classes which will produce nans - fraction[fraction != fraction] = 0 - - if class_reduction == "micro": - return fraction - elif class_reduction == "macro": - return torch.mean(fraction) - elif class_reduction == "weighted": - return torch.sum(fraction * (weights.float() / torch.sum(weights))) - elif class_reduction == "none" or class_reduction is None: - return fraction - - raise ValueError( - f"Reduction parameter {class_reduction} unknown." - f" Choose between one of these: {valid_reduction}" - ) - - -def _stable_1d_sort(x: torch, N: int = 2049): - """ - Stable sort of 1d tensors. Pytorch defaults to a stable sorting algorithm - if number of elements are larger than 2048. This function pads the tensors, - makes the sort and returns the sorted array (with the padding removed) - See this discussion: https://discuss.pytorch.org/t/is-torch-sort-stable/20714 - - Raises: - ValueError: - If dim of ``x`` is greater than 1 since stable sort works with only 1d tensors. + .. deprecated:: + Use :func:`torchmetrics.utilities.class_reduce`. Will be removed in v1.5.0. """ - if x.ndim > 1: - raise ValueError('Stable sort only works on 1d tensors') - n = x.numel() - if N - n > 0: - x_max = x.max() - x = torch.cat([x, (x_max + 1) * torch.ones(N - n, dtype=x.dtype, device=x.device)], 0) - x_sort = x.sort() - i = min(N, n) - return x_sort.values[:i], x_sort.indices[:i] diff --git a/pytorch_lightning/overrides/base.py b/pytorch_lightning/overrides/base.py index 1d6f4e93b5779..0c1ac7b359fd0 100644 --- a/pytorch_lightning/overrides/base.py +++ b/pytorch_lightning/overrides/base.py @@ -53,7 +53,7 @@ def forward(self, *inputs, **kwargs): elif trainer and (trainer.sanity_checking or trainer.validating): output = self.module.validation_step(*inputs, **kwargs) elif trainer and trainer.predicting: - output = self.module.predict(*inputs, **kwargs) + output = self.module.predict_step(*inputs, **kwargs) else: output = self.module(*inputs, **kwargs) diff --git a/pytorch_lightning/overrides/torch_distributed.py b/pytorch_lightning/overrides/torch_distributed.py new file mode 100644 index 0000000000000..67b64c046dc18 --- /dev/null +++ b/pytorch_lightning/overrides/torch_distributed.py @@ -0,0 +1,94 @@ +import logging +import pickle + +import torch + +from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_7 + +log = logging.getLogger(__name__) + +if torch.distributed.is_available(): + from torch.distributed import Backend, broadcast, get_backend, get_rank, GroupMember + +# The code underneath is taken from PyTorch ``torch/distributed/distributed_c10d.py`` +# and enable broadcasting for PyTorch 1.6 and lower. + + +# https://github.com/pytorch/pytorch/blob/1.7/torch/distributed/distributed_c10d.py#L160 +def _rank_not_in_group(group): + """ + Helper that checks if the current process's rank is not in a given group. 
+ """ + if group is None: + return False + return group == GroupMember.NON_GROUP_MEMBER + + +# Taken from https://github.com/pytorch/pytorch/blob/1.7/torch/distributed/distributed_c10d.py#L1164 +def _object_to_tensor(obj): + buffer = pickle.dumps(obj) + byte_storage = torch.ByteStorage.from_buffer(buffer) # type: ignore[attr-defined] + byte_tensor = torch.ByteTensor(byte_storage) + local_size = torch.LongTensor([byte_tensor.numel()]) + return byte_tensor, local_size + + +# Taken from https://github.com/pytorch/pytorch/blob/1.7/torch/distributed/distributed_c10d.py +def _tensor_to_object(tensor, tensor_size): + buf = tensor.numpy().tobytes()[:tensor_size] + out = pickle.loads(buf) + return out + + +# Taken from https://github.com/pytorch/pytorch/blob/1.7/torch/distributed/distributed_c10d.py#L1327 +def _broadcast_object_list(object_list, src=0, group=None): + if _rank_not_in_group(group): + return + + my_rank = get_rank() + # Serialize object_list elements to tensors on src rank. + if my_rank == src: + tensor_list, size_list = zip(*[_object_to_tensor(obj) for obj in object_list]) + object_sizes_tensor = torch.cat(size_list) + else: + object_sizes_tensor = torch.LongTensor(len(object_list)) + + group_backend = get_backend(group) + is_nccl_backend = group_backend == Backend.NCCL + current_device = torch.device("cpu") + if is_nccl_backend: + # See note about using torch.cuda.current_device() here in docstring. + # We cannot simply use my_rank since rank == device is not necessarily + # true. + current_device = torch.device('cuda', torch.cuda.current_device()) + object_sizes_tensor = object_sizes_tensor.to(current_device) + object_sizes_tensor = object_sizes_tensor.to(current_device) + + # Broadcast object sizes + broadcast(object_sizes_tensor, src=src, group=group) + + # Concatenate and broadcast serialized object tensors + if my_rank == src: + object_tensor = torch.cat(tensor_list) + else: + object_tensor = torch.ByteTensor(torch.sum(object_sizes_tensor).item()) + + if is_nccl_backend: + object_tensor = object_tensor.to(current_device) + + broadcast(object_tensor, src=src, group=group) + + # Deserialize objects using their stored sizes. 
+ offset = 0 + if my_rank != src: + for i, obj_size in enumerate(object_sizes_tensor): + obj_view = object_tensor[offset:offset + obj_size] + obj_view = obj_view.type(torch.ByteTensor) # type: ignore[call-overload] + offset += obj_size + object_list[i] = _tensor_to_object(obj_view, obj_size) + + +if _TORCH_GREATER_EQUAL_1_7 and torch.distributed.is_available(): + from torch.distributed.distributed_c10d import broadcast_object_list +else: + broadcast_object_list = _broadcast_object_list diff --git a/pytorch_lightning/plugins/__init__.py b/pytorch_lightning/plugins/__init__.py index dec672d025294..a67235baa4767 100644 --- a/pytorch_lightning/plugins/__init__.py +++ b/pytorch_lightning/plugins/__init__.py @@ -1,6 +1,7 @@ from pytorch_lightning.plugins.base_plugin import Plugin # noqa: F401 from pytorch_lightning.plugins.precision.apex_amp import ApexMixedPrecisionPlugin # noqa: F401 from pytorch_lightning.plugins.precision.deepspeed_precision import DeepSpeedPrecisionPlugin # noqa: F401 +from pytorch_lightning.plugins.precision.double import DoublePrecisionPlugin # noqa: F401 from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin # noqa: F401 from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin # noqa: F401 from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin # noqa: F401 @@ -29,6 +30,7 @@ "DDPSpawnPlugin", "DeepSpeedPlugin", "DeepSpeedPrecisionPlugin", + "DoublePrecisionPlugin", "HorovodPlugin", "NativeMixedPrecisionPlugin", "PrecisionPlugin", diff --git a/pytorch_lightning/plugins/precision/__init__.py b/pytorch_lightning/plugins/precision/__init__.py index fc60deffcbb77..d32aac829a13d 100644 --- a/pytorch_lightning/plugins/precision/__init__.py +++ b/pytorch_lightning/plugins/precision/__init__.py @@ -1,5 +1,6 @@ from pytorch_lightning.plugins.precision.apex_amp import ApexMixedPrecisionPlugin # noqa: F401 from pytorch_lightning.plugins.precision.deepspeed_precision import DeepSpeedPrecisionPlugin # noqa: F401 +from pytorch_lightning.plugins.precision.double import DoublePrecisionPlugin # noqa: F401 from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin # noqa: F401 from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin # noqa: F401 from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin # noqa: F401 diff --git a/pytorch_lightning/plugins/precision/apex_amp.py b/pytorch_lightning/plugins/precision/apex_amp.py index 75570e453ec1b..b600eca5e6bc2 100644 --- a/pytorch_lightning/plugins/precision/apex_amp.py +++ b/pytorch_lightning/plugins/precision/apex_amp.py @@ -169,5 +169,4 @@ def pre_optimizer_step( pl_module.trainer.call_hook("on_after_backward") optimizer.step(**kwargs) - return False diff --git a/pytorch_lightning/plugins/precision/double.py b/pytorch_lightning/plugins/precision/double.py new file mode 100644 index 0000000000000..4720f0f874fd0 --- /dev/null +++ b/pytorch_lightning/plugins/precision/double.py @@ -0,0 +1,95 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
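The fallback defined above mirrors ``torch.distributed.broadcast_object_list`` for PyTorch 1.6 and older, so call sites can use a single name regardless of the torch version. A usage sketch assuming a default process group has already been initialised on every rank; the payload is a placeholder::

    import torch.distributed as dist

    from pytorch_lightning.overrides.torch_distributed import broadcast_object_list

    # assumes dist.init_process_group(...) has already run on every rank
    if dist.get_rank() == 0:
        objects = ["best.ckpt", {"epoch": 3}]   # placeholder payload on the source rank
    else:
        objects = [None, None]                  # same length on every other rank

    broadcast_object_list(objects, src=0)
    # after the call, every rank holds the objects pickled on rank 0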
+# See the License for the specific language governing permissions and +# limitations under the License. +from functools import wraps +from typing import Any, Sequence, Tuple, TYPE_CHECKING, List + +import torch + +from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin +from pytorch_lightning.utilities.apply_func import apply_to_collection + +if TYPE_CHECKING: + from torch.nn import Module + from torch.optim import Optimizer + + +class _DoublePrecisionPatch: + """Class to handle patching of methods in the ``LightningModule`` and subsequent teardown.""" + + def __init__(self, model: 'Module', method_name: str, old_method: Any) -> None: + self.model = model + self.method_name = method_name + self.old_method = old_method + + def teardown(self) -> None: + setattr(self.model, self.method_name, self.old_method) + + @staticmethod + def _to_double_precision(data: torch.Tensor) -> torch.Tensor: + if data.is_floating_point(): + return data.double() + return data + + @staticmethod + def _move_float_tensors_to_double(collection: Any) -> Any: + return apply_to_collection( + collection, torch.Tensor, function=_DoublePrecisionPatch._to_double_precision + ) + + @classmethod + def patch(cls, model: 'Module', method_name: str) -> '_DoublePrecisionPatch': + old_method = getattr(model, method_name) + + @wraps(old_method) + def new_method(*args: Any, **kwargs: Any) -> Any: + return old_method( + *_DoublePrecisionPatch._move_float_tensors_to_double(args), + **_DoublePrecisionPatch._move_float_tensors_to_double(kwargs) + ) + + setattr(model, method_name, new_method if callable(old_method) else old_method) + return cls(model, method_name, old_method) + + +class DoublePrecisionPlugin(PrecisionPlugin): + """Plugin for training with double (``torch.float64``) precision.""" + + precision: int = 64 + + def __init__(self) -> None: + self.patches: List[_DoublePrecisionPatch] = [] + + def connect( + self, + model: 'Module', + optimizers: Sequence['Optimizer'], + lr_schedulers: Sequence[Any], + ) -> Tuple['Module', Sequence['Optimizer'], Sequence[Any]]: + """Converts the model to double precision and wraps the `training_step`, `validation_step`, `test_step`, + `predict_step`, and `forward` methods to convert incoming floating point data to double. 
Does not alter + `optimizers` or `lr_schedulers`.""" + model = model.to(dtype=torch.float64) + if isinstance(model, LightningModule): + self.patches.append(_DoublePrecisionPatch.patch(model, 'training_step')) + self.patches.append(_DoublePrecisionPatch.patch(model, 'validation_step')) + self.patches.append(_DoublePrecisionPatch.patch(model, 'test_step')) + self.patches.append(_DoublePrecisionPatch.patch(model, 'predict_step')) + self.patches.append(_DoublePrecisionPatch.patch(model, 'forward')) + + return super().connect(model, optimizers, lr_schedulers) + + def post_dispatch(self) -> None: + while len(self.patches) > 0: + self.patches.pop().teardown() diff --git a/pytorch_lightning/plugins/precision/native_amp.py b/pytorch_lightning/plugins/precision/native_amp.py index dc822680bcbda..3c83945c8a1b7 100644 --- a/pytorch_lightning/plugins/precision/native_amp.py +++ b/pytorch_lightning/plugins/precision/native_amp.py @@ -103,3 +103,21 @@ def train_step_context(self) -> Generator[None, None, None]: """Enable autocast context""" with torch.cuda.amp.autocast(): yield + + @contextmanager + def val_step_context(self) -> Generator[None, None, None]: + """Enable autocast context""" + with torch.cuda.amp.autocast(): + yield + + @contextmanager + def test_step_context(self) -> Generator[None, None, None]: + """Enable autocast context""" + with torch.cuda.amp.autocast(): + yield + + @contextmanager + def predict_context(self) -> Generator[None, None, None]: + """Enable autocast context""" + with torch.cuda.amp.autocast(): + yield diff --git a/pytorch_lightning/plugins/precision/precision_plugin.py b/pytorch_lightning/plugins/precision/precision_plugin.py index 2b1579cf497c0..7172d82391bd3 100644 --- a/pytorch_lightning/plugins/precision/precision_plugin.py +++ b/pytorch_lightning/plugins/precision/precision_plugin.py @@ -100,7 +100,6 @@ def post_optimizer_step(self, optimizer: 'Optimizer', optimizer_idx: int) -> Non def clip_gradients(self, optimizer: 'Optimizer', clip_val: Union[int, float], norm_type: float = 2.0) -> None: """Clips the gradients to a specific value""" - # TODO: separate TPU case from here if clip_val is None: return diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index f857ad50399cf..58e26e7db32d8 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -80,9 +80,7 @@ def distributed_sampler_kwargs(self): distributed_sampler_kwargs = dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank) return distributed_sampler_kwargs - def setup(self, model): - self._model = model - + def setup_environment(self): # start the other scripts if not self.cluster_environment.creates_children() and os.environ.get("PL_IN_DDP_SUBPROCESS", "0") != "1": self._call_children_scripts() @@ -90,6 +88,8 @@ def setup(self, model): # set the task idx self.task_idx = self.cluster_environment.local_rank() + self.setup_distributed() + def _call_children_scripts(self): # bookkeeping of spawned processes @@ -161,6 +161,34 @@ def _call_children_scripts(self): delay = np.random.uniform(1, 5, 1)[0] sleep(delay) + def setup_distributed(self): + # TODO: check if needed + seed = os.environ.get("PL_GLOBAL_SEED") + if seed is not None: + seed_everything(int(seed)) + + # determine which process we are and world size + self.set_world_ranks() + + # set warning rank + rank_zero_only.rank = self.global_rank + + # set up server using proc 0's ip address + # try to init for 20 times at max 
in case ports are taken + # where to store ip_table + self.init_ddp_connection(self.global_rank, self.world_size) + + # on world_size=0 let everyone know training is starting + if self.is_global_zero and not torch.distributed.is_initialized(): + log.info("-" * 100) + log.info(f"distributed_backend={self.distributed_backend}") + log.info(f"All DDP processes registered. Starting ddp with {self.world_size} processes") + log.info("-" * 100) + + # set the ranks and devices + self.dist.rank = self.global_rank + self.dist.device = self.root_device + def _check_can_spawn_children(self): if self._has_spawned_children: raise RuntimeError( @@ -179,9 +207,7 @@ def pre_configure_ddp(self): # Many models require setting this parameter to True, as there are corner cases # when not all parameter backward hooks are fired by the autograd engine even if require_grad is set to True. # This flag does come with a performance hit, so it is suggested to disable in cases where it is possible. - self._ddp_kwargs["find_unused_parameters"] = self._ddp_kwargs.get( - "find_unused_parameters", True - ) + self._ddp_kwargs["find_unused_parameters"] = self._ddp_kwargs.get("find_unused_parameters", True) # todo: PyTorch 1.7.0 DDP introduces ``self.reducer._rebuild_buckets()`` breaking manual_optimization if _TORCH_GREATER_EQUAL_1_7 and not self.lightning_module.automatic_optimization and not self._ddp_kwargs.get( "find_unused_parameters", False @@ -215,37 +241,6 @@ def init_ddp_connection(self, global_rank: int, world_size: int) -> None: torch_distrib.init_process_group(self.torch_distributed_backend, rank=global_rank, world_size=world_size) def pre_dispatch(self): - # TODO: check if needed - seed = os.environ.get("PL_GLOBAL_SEED") - if seed is not None: - seed_everything(int(seed)) - - # determine which process we are and world size - self.set_world_ranks() - - # set warning rank - rank_zero_only.rank = self.global_rank - - # set up server using proc 0's ip address - # try to init for 20 times at max in case ports are taken - # where to store ip_table - self.init_ddp_connection(self.global_rank, self.world_size) - - # TODO: we moved it to the trainer.fit after calling pre_dispatch - # ... need to double check that it is the correct place - # self.trainer.call_setup_hook(self.model) - - # on world_size=0 let everyone know training is starting - if self.is_global_zero and not torch.distributed.is_initialized(): - log.info("-" * 100) - log.info(f"distributed_backend={self.distributed_backend}") - log.info(f"All DDP processes registered. 
Starting ddp with {self.world_size} processes") - log.info("-" * 100) - - # set the ranks and devices - self.dist.rank = self.global_rank - self.dist.device = self.root_device - if self.sync_batchnorm: self.model = self.configure_sync_batchnorm(self.model) @@ -303,7 +298,7 @@ def validation_step(self, *args, **kwargs): def test_step(self, *args, **kwargs): return self.model(*args, **kwargs) - def predict(self, *args, **kwargs): + def predict_step(self, *args, **kwargs): return self.model(*args, **kwargs) def post_training_step(self): diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 3636b2fb92fa2..87d7fa5faecac 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -77,8 +77,6 @@ def distributed_sampler_kwargs(self): return distributed_sampler_kwargs def setup(self, model): - self._model = model - os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port()) # pass in a state q @@ -172,9 +170,7 @@ def pre_configure_ddp(self): # Many models require setting this parameter to True, as there are corner cases # when not all parameter backward hooks are fired by the autograd engine even if require_grad is set to True. # This flag does come with a performance hit, so it is suggested to disable in cases where it is possible. - self._ddp_kwargs["find_unused_parameters"] = self._ddp_kwargs.get( - "find_unused_parameters", True - ) + self._ddp_kwargs["find_unused_parameters"] = self._ddp_kwargs.get("find_unused_parameters", True) # todo: PyTorch 1.7.0 DDP introduces ``self.reducer._rebuild_buckets()`` breaking manual_optimization if _TORCH_GREATER_EQUAL_1_7 and not self.lightning_module.automatic_optimization and not self._ddp_kwargs.get( "find_unused_parameters", False @@ -286,7 +282,7 @@ def validation_step(self, *args, **kwargs): def test_step(self, *args, **kwargs): return self.model(*args, **kwargs) - def predict(self, *args, **kwargs): + def predict_step(self, *args, **kwargs): return self.model(*args, **kwargs) def post_training_step(self): diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index b54155d60eae5..b196044937414 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -192,17 +192,7 @@ def _load_config(self, config): return config def pre_dispatch(self): - self.set_world_ranks() - self.init_ddp_connection(self.global_rank, self.world_size) - self.init_deepspeed() - - # set warning rank - rank_zero_only.rank = self.global_rank - - # set the ranks and devices - self.dist.rank = self.global_rank - self.dist.device = self.root_device self.barrier() def init_deepspeed(self): diff --git a/pytorch_lightning/plugins/training_type/dp.py b/pytorch_lightning/plugins/training_type/dp.py index 1d5398778c0df..a8e42e0fa747a 100644 --- a/pytorch_lightning/plugins/training_type/dp.py +++ b/pytorch_lightning/plugins/training_type/dp.py @@ -71,8 +71,8 @@ def barrier(self, *args, **kwargs): def broadcast(self, obj: object, src: int = 0) -> object: return obj - def reduce_early_stopping_decision(self, should_stop: bool) -> bool: - return should_stop + def reduce_boolean_decision(self, decision: bool) -> bool: + return decision def training_step(self, *args, **kwargs): return self.model(*args, **kwargs) @@ -83,7 +83,7 @@ def validation_step(self, *args, **kwargs): def test_step(self, *args, **kwargs): 
return self.model(*args, **kwargs) - def predict(self, *args, **kwargs): + def predict_step(self, *args, **kwargs): return self.model(*args, **kwargs) def training_step_end(self, output): diff --git a/pytorch_lightning/plugins/training_type/horovod.py b/pytorch_lightning/plugins/training_type/horovod.py index 2fe3906cb01d0..8d0add27cbb29 100644 --- a/pytorch_lightning/plugins/training_type/horovod.py +++ b/pytorch_lightning/plugins/training_type/horovod.py @@ -21,7 +21,7 @@ from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin from pytorch_lightning.utilities import _HOROVOD_AVAILABLE -from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp +from pytorch_lightning.utilities.distributed import group, rank_zero_only, ReduceOp if _HOROVOD_AVAILABLE: import horovod.torch as hvd @@ -96,14 +96,14 @@ def start_training(self, trainer): stack.enter_context(optimizer.skip_synchronize()) # set up training routine - self._results = trainer.run_train() + self._results = trainer.run_stage() # Make sure all workers have finished training before returning to the user hvd.join() def start_evaluating(self, trainer): with ExitStack(): - self._results = trainer.run_evaluate() + self._results = trainer.run_stage() # Make sure all workers have finished training before returning to the user hvd.join() @@ -111,7 +111,7 @@ def start_evaluating(self, trainer): def start_predicting(self, trainer): with ExitStack(): # set up training routine - self._results = trainer.run_predict() + self._results = trainer.run_stage() # Make sure all workers have finished training before returning to the user hvd.join() @@ -159,8 +159,13 @@ def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ hvd.join() return hvd.allreduce(tensor, op=reduce_op) - def gather_all_tensors(self, result: Union[torch.Tensor], group: Optional[Any] = None): - if group is not None: + def all_gather( + self, + result: Union[torch.Tensor], + group: Optional[Any] = group.WORLD, + sync_grads: bool = False + ) -> torch.Tensor: + if group is not None and group != group.WORLD: raise ValueError( "Horovod does not support allgather using a subcommunicator at this time. " "Unset `group`." diff --git a/pytorch_lightning/plugins/training_type/parallel.py b/pytorch_lightning/plugins/training_type/parallel.py index 715c5332e231c..d9a8e70588c43 100644 --- a/pytorch_lightning/plugins/training_type/parallel.py +++ b/pytorch_lightning/plugins/training_type/parallel.py @@ -11,11 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import io import os from abc import ABC, abstractmethod from contextlib import contextmanager -from typing import List, Optional +from typing import Any, List, Optional import torch from torch.nn.parallel import DistributedDataParallel @@ -36,9 +35,10 @@ def __init__( ): super().__init__() self.parallel_devices = parallel_devices + self.cluster_environment = cluster_environment + self.global_rank = 0 self.world_size = 1 self.local_rank = 0 - self.cluster_environment = cluster_environment @property @abstractmethod @@ -53,14 +53,6 @@ def on_gpu(self): def lightning_module(self): return unwrap_lightning_module(self._model) - @abstractmethod - def setup(self, model): - raise NotImplementedError - - def connect(self, model, *args, **kwargs): - self.setup(model) - return self.model - @property def is_global_zero(self) -> bool: return self.global_rank == 0 @@ -70,11 +62,15 @@ def distributed_sampler_kwargs(self): distributed_sampler_kwargs = dict(num_replicas=len(self.parallel_devices), rank=self.global_rank) return distributed_sampler_kwargs - def reduce_early_stopping_decision(self, should_stop: bool) -> bool: - should_stop = torch.tensor(int(should_stop), device=self.lightning_module.device) - should_stop = self.reduce(should_stop, reduce_op=ReduceOp.SUM) - should_stop = bool(should_stop == self.world_size) - return should_stop + def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor: + """Perform a all_gather on all processes """ + return all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads) + + def reduce_boolean_decision(self, decision: bool) -> bool: + decision = torch.tensor(int(decision), device=self.lightning_module.device) + decision = self.reduce(decision, reduce_op=ReduceOp.SUM) + decision = bool(decision == self.world_size) + return decision @property def torch_distributed_backend(self): @@ -112,13 +108,3 @@ def block_backward_sync(self): yield None else: yield None - - def broadcast(self, obj: object, src: int) -> object: - buffer = io.BytesIO() - torch.save(obj, buffer) - data = bytearray(buffer.getbuffer()) - data_tensor = torch.tensor(data).to(self.root_device, dtype=torch.float) - data = all_gather_ddp_if_available(data_tensor) - buffer = io.BytesIO(data.cpu().byte().numpy()) - obj = torch.load(buffer) - return obj diff --git a/pytorch_lightning/plugins/training_type/rpc.py b/pytorch_lightning/plugins/training_type/rpc.py index faf528d76b768..3e0f57daef001 100644 --- a/pytorch_lightning/plugins/training_type/rpc.py +++ b/pytorch_lightning/plugins/training_type/rpc.py @@ -13,7 +13,7 @@ # limitations under the License. 
import os from contextlib import suppress -from typing import List, Optional, Callable +from typing import Callable, List, Optional import torch diff --git a/pytorch_lightning/plugins/training_type/rpc_sequential.py b/pytorch_lightning/plugins/training_type/rpc_sequential.py index 336c16f0f1a03..ba26fc9f58ec5 100644 --- a/pytorch_lightning/plugins/training_type/rpc_sequential.py +++ b/pytorch_lightning/plugins/training_type/rpc_sequential.py @@ -13,7 +13,7 @@ # limitations under the License import logging import os -from typing import List, Optional, Callable +from typing import Callable, List, Optional import torch import torch.distributed as torch_distrib diff --git a/pytorch_lightning/plugins/training_type/single_device.py b/pytorch_lightning/plugins/training_type/single_device.py index d11ae87bed660..d70779adf3ba1 100644 --- a/pytorch_lightning/plugins/training_type/single_device.py +++ b/pytorch_lightning/plugins/training_type/single_device.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Union +from typing import Any, Optional, Union import torch @@ -23,6 +23,9 @@ class SingleDevicePlugin(TrainingTypePlugin): def __init__(self, device: torch.device): super().__init__() self.device: torch.device = device + self.global_rank = 0 + self.local_rank = 0 + self.world_size = 1 @property def on_tpu(self) -> bool: @@ -47,6 +50,10 @@ def reduce(self, tensor: Union[Any, torch.Tensor], *args: Any, **kwargs: Any) -> """ return tensor + def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor: + """Perform a all_gather on all processes """ + return tensor + @property def root_device(self) -> torch.device: return self.device @@ -57,8 +64,7 @@ def model_to_device(self) -> None: self._model.to(self.root_device) - def connect(self, model: torch.nn.Module) -> torch.nn.Module: - self._model = model + def setup(self, model: torch.nn.Module) -> torch.nn.Module: self.model_to_device() return self.model diff --git a/pytorch_lightning/plugins/training_type/single_tpu.py b/pytorch_lightning/plugins/training_type/single_tpu.py index d3cbd0d6b5d79..b8d670ff16881 100644 --- a/pytorch_lightning/plugins/training_type/single_tpu.py +++ b/pytorch_lightning/plugins/training_type/single_tpu.py @@ -39,13 +39,8 @@ def __init__(self, device: Union[torch.device, int]): def on_tpu(self) -> bool: return True - def connect(self, model: torch.nn.Module) -> torch.nn.Module: - self._model = model - self.model_to_device() - return self._model - def model_to_device(self) -> None: - self._model.to(self.root_device) + self.model.to(self.root_device) def pre_dispatch(self) -> None: if isinstance(self.device, int): diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index e05a7bc03ef5c..a8706d54cb5c9 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -17,7 +17,6 @@ from typing import Any, Dict, Iterable, List, Optional, Union import torch -import torch.distributed as torch_distrib import torch.multiprocessing as mp from pytorch_lightning.core.lightning import LightningModule @@ -53,10 +52,9 @@ def __init__( self.tpu_local_core_rank = 0 self.start_method = None - def connect(self, model: torch.nn.Module) -> torch.nn.Module: + def setup(self, model: torch.nn.Module) 
-> torch.nn.Module: self.create_mp_queue() - self._model = model - return self._model + return self.model def create_mp_queue(self): self.start_method = 'fork' @@ -110,13 +108,15 @@ def new_process(self, process_idx: int, trainer, mp_queue) -> None: # replace trainer save_checkpoint to use `xm.save` trainer.save_checkpoint = self.save_checkpoint - self.barrier() + self.barrier("pre-run-stage") results = trainer.run_stage() self.__save_end_of_training_weights(self.lightning_module) self.transfer_distrib_spawn_state_on_fit_end(results) + self.barrier("end-process") + def __save_end_of_training_weights(self, model: LightningModule) -> None: # when training ends on these platforms dump weights to get out of the main process if on_colab_kaggle(): @@ -127,11 +127,11 @@ def model_to_device(self) -> None: self._model.to(xm.xla_device()) def barrier(self, name: Optional[str] = None) -> None: - if torch_distrib.is_initialized(): - rendezvous(f"pl.Trainer.{name}") + rendezvous(name) def transfer_distrib_spawn_state_on_fit_end(self, results): - best_model_path = self.lightning_module.trainer.checkpoint_callback.best_model_path + checkpoint_callback = self.lightning_module.trainer.checkpoint_callback + best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None if self.mp_queue is not None: rank_zero_warn("cleaning up ddp environment...") @@ -203,12 +203,11 @@ def save_spawn_weights(self, model: LightningModule) -> Optional[str]: model.trainer.save_checkpoint(path) return path - def reduce_early_stopping_decision(self, should_stop: bool) -> bool: - should_stop = torch.tensor(int(should_stop), device=self.lightning_module.device) - stop = xm.mesh_reduce('stop_signal', should_stop, sum) - rendezvous("pl.EarlyStoppingCallback.stop_distributed_training_check") - should_stop = int(stop.item()) == self.world_size - return should_stop + def reduce_decision(self, decision: bool) -> bool: + decision = torch.tensor(int(decision), device=self.device) + decision = self.reduce(decision, "sum") + decision = bool(decision == self.world_size) + return decision def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None): if not isinstance(output, torch.Tensor): @@ -296,8 +295,8 @@ def validation_step(self, *args, **kwargs): def test_step(self, *args, **kwargs): return self.lightning_module.test_step(*args, **kwargs) - def predict(self, *args, **kwargs): - return self.lightning_module.predict(*args, **kwargs) + def predict_step(self, *args, **kwargs): + return self.lightning_module.predict_step(*args, **kwargs) def save_checkpoint(self, filepath, weights_only: bool = False): """Save model/training states as a checkpoint file through state-dump and file-write. 
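The plugin hooks touched above (`connect`, `setup_environment`, `setup`, `pre_dispatch`, `start_training`) and formalized in the base class diff that follows are easiest to read as a call sequence. Below is a minimal sketch of the order these hooks assume; the `run_plugin` driver is illustrative only (the real call sites live in the Accelerator and Trainer), but the hook names and their roles are taken from the diffs above.

    # Illustrative sketch only; not the actual Accelerator/Trainer code paths.
    def run_plugin(plugin, model, trainer):
        plugin.connect(model)            # attach the LightningModule to the plugin
        plugin.setup_environment()       # e.g. DDPPlugin starts child scripts and calls setup_distributed()
        plugin.setup(model)              # finish plugin-specific setup (device placement, mp queue, ...)
        plugin.pre_dispatch()            # e.g. wrap the model in DistributedDataParallel
        plugin.start_training(trainer)   # double dispatch into trainer.run_stage()
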
diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py
index 7783f066dbc61..08dca63a7c925 100644
--- a/pytorch_lightning/plugins/training_type/training_type_plugin.py
+++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -33,11 +33,20 @@ class TrainingTypePlugin(Plugin, ABC):
     def __init__(self) -> None:
         self._model = None
         self._results = None
-        self.global_rank = 0

-    @abstractmethod
     def connect(self, model: 'Module') -> None:
-        """Called by the accelerator to connect it with this plugin"""
+        """Called by the accelerator to connect the accelerator and the model with this plugin"""
+        self.model = model
+
+    def setup_environment(self) -> None:
+        """
+        Set up any processes or distributed connections.
+        This is called before the LightningModule/DataModule setup hook
+        which allows the user to access the accelerator environment before setup is complete.
+        """
+
+    def setup(self, model: 'Module') -> None:
+        """Called by the accelerator to finish setup."""

     @property
     @abstractmethod
@@ -77,9 +86,13 @@ def barrier(self, name: Optional[str] = None) -> None:
     def broadcast(self, obj: object, src: int = 0) -> object:
         """Broadcasts an object to all processes"""

-    def reduce_early_stopping_decision(self, should_stop: bool) -> bool:
-        """Reduce the early stopping decision across all possibly spawned processes"""
-        return should_stop
+    @abstractmethod
+    def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor:
+        """Perform an all_gather on all processes"""
+
+    def reduce_boolean_decision(self, decision: bool) -> bool:
+        """Reduce a boolean decision (e.g. early stopping) across all processes"""
+        return decision

     def pre_backward(self, closure_loss: torch.Tensor, should_accumulate: bool, optimizer: Optimizer, opt_idx: int):
         """Run before precision plugin executes backward"""
@@ -119,15 +132,15 @@ def rpc_enabled(self) -> bool:

     def start_training(self, trainer: 'Trainer') -> None:
         # double dispatch to initiate the training loop
-        self._results = trainer.run_train()
+        self._results = trainer.run_stage()

     def start_evaluating(self, trainer: 'Trainer') -> None:
         # double dispatch to initiate the test loop
-        self._results = trainer.run_evaluate()
+        self._results = trainer.run_stage()

     def start_predicting(self, trainer: 'Trainer') -> None:
         # double dispatch to initiate the predicting loop
-        self._results = trainer.run_predict()
+        self._results = trainer.run_stage()

     def training_step(self, *args, **kwargs):
         return self.lightning_module.training_step(*args, **kwargs)
@@ -141,8 +154,8 @@ def validation_step(self, *args, **kwargs):
     def test_step(self, *args, **kwargs):
         return self.lightning_module.test_step(*args, **kwargs)

-    def predict(self, *args, **kwargs):
-        return self.lightning_module.predict(*args, **kwargs)
+    def predict_step(self, *args, **kwargs):
+        return self.lightning_module.predict_step(*args, **kwargs)

     def training_step_end(self, output):
         return output
@@ -169,3 +182,13 @@ def init_optimizers(self, trainer: "Trainer", model: LightningModule):

     def optimizer_step(self, optimizer: torch.optim.Optimizer, lambda_closure: Callable, **kwargs):
         optimizer.step(closure=lambda_closure, **kwargs)
+
+    @property
+    def setup_optimizers_in_pre_dispatch(self) -> bool:
+        """
+        Override to delay setting optimizers and schedulers till after dispatch.
+        This is useful when the `TrainingTypePlugin` requires operating on the wrapped accelerator model.
+        However this may break certain precision plugins such as APEX which require optimizers to be set.
+        Returns: If True, delay setup optimizers till pre_dispatch, else call within setup.
+        """
+        return False
diff --git a/pytorch_lightning/plugins/training_type/utils.py b/pytorch_lightning/plugins/training_type/utils.py
index 7380f871f59a5..eddb9077116dc 100644
--- a/pytorch_lightning/plugins/training_type/utils.py
+++ b/pytorch_lightning/plugins/training_type/utils.py
@@ -1,3 +1,16 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 import os
diff --git a/pytorch_lightning/profiler/__init__.py b/pytorch_lightning/profiler/__init__.py
index e09a5ea11a084..6ac6e16c18529 100644
--- a/pytorch_lightning/profiler/__init__.py
+++ b/pytorch_lightning/profiler/__init__.py
@@ -121,7 +121,8 @@ def custom_processing_step(self, data):
 Autograd includes a profiler that lets you inspect the cost of different operators
 inside your model - both on the CPU and GPU

-Find the Pytorch Profiler doc at [PyTorch Profiler](https://pytorch-lightning.readthedocs.io/en/stable/profiler.html)
+To read more about the PyTorch Profiler and all its options,
+have a look at its `docs `__

 .. code-block:: python

@@ -134,16 +135,16 @@ def custom_processing_step(self, data):

 This profiler works with PyTorch ``DistributedDataParallel``.

-If ``output_filename`` is provided, each rank will save their profiled operation to their own file.
+If ``filename`` is provided, each rank will save their profiled operation to their own file. The profiler
+report can be quite long, so setting a ``filename`` will save the report instead of logging it to the
+output in your terminal. If no filename is given, it will be logged only on rank 0.
+The profiler's results will be printed on the completion of ``{fit,validate,test,predict}``.

-The profiler's results will be printed on the completion of a training `fit()`. This profiler
-report can be quite long, so you can also specify an `output_filename` to save the report instead
-of logging it to the output in your terminal.
-
-This profiler will record only for `training_step_and_backward`, `evaluation_step` and `test_step` functions by default.
-The output below shows the profiling for the action `training_step_and_backward`.
-The user can provide ``PyTorchProfiler(profiled_functions=[...])`` to extend the scope of profiled functions.
+This profiler will record ``training_step_and_backward``, ``training_step``, ``backward``,
+``validation_step``, ``test_step``, and ``predict_step`` by default.
+The output below shows the profiling for the action ``training_step_and_backward``.
+The user can provide ``PyTorchProfiler(record_functions={...})`` to extend the scope of profiled functions.

 .. note:: When using the PyTorch Profiler, wall clock time will not not be representative of the true wall clock time. This is due to forcing profiled operations to be measured synchronously, when many CUDA ops happen asynchronously.
It is recommended to use this Profiler to find bottlenecks/breakdowns, however for end to end wall clock time use the `SimpleProfiler`. # noqa E501 @@ -184,13 +185,13 @@ def custom_processing_step(self, data): To visualize the profiled operation, you can either: -* Use:: +Use:: nvvp trace_name.prof -* Use:: +Or:: - python -c 'import torch; print(torch.autograd.profiler.load_nvprof("trace_name.prof"))' + python -c 'import torch; print(torch.autograd.profiler.load_nvprof("trace_name.prof"))' """ diff --git a/pytorch_lightning/profiler/profilers.py b/pytorch_lightning/profiler/profilers.py index d704ba83236c1..bc9e3541dbaa8 100644 --- a/pytorch_lightning/profiler/profilers.py +++ b/pytorch_lightning/profiler/profilers.py @@ -21,31 +21,19 @@ from abc import ABC, abstractmethod from collections import defaultdict from contextlib import contextmanager -from typing import Optional, Union +from pathlib import Path +from typing import Any, Callable, Dict, Optional, TextIO, Tuple, Union import numpy as np +from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.cloud_io import get_filesystem log = logging.getLogger(__name__) -class BaseProfiler(ABC): - """ - If you wish to write a custom profiler, you should inhereit from this class. - """ - - def __init__(self, output_streams: Optional[Union[list, tuple]] = None): - """ - Args: - output_streams: callable - """ - if output_streams: - if not isinstance(output_streams, (list, tuple)): - output_streams = [output_streams] - else: - output_streams = [] - self.write_streams = output_streams +class AbstractProfiler(ABC): + """Specification of a profiler.""" @abstractmethod def start(self, action_name: str) -> None: @@ -55,6 +43,48 @@ def start(self, action_name: str) -> None: def stop(self, action_name: str) -> None: """Defines how to record the duration once an action is complete.""" + @abstractmethod + def summary(self) -> str: + """Create profiler summary in text format.""" + + @abstractmethod + def setup(self, **kwargs: Any) -> None: + """Execute arbitrary pre-profiling set-up steps as defined by subclass.""" + + @abstractmethod + def teardown(self, **kwargs: Any) -> None: + """Execute arbitrary post-profiling tear-down steps as defined by subclass.""" + + +class BaseProfiler(AbstractProfiler): + """ + If you wish to write a custom profiler, you should inherit from this class. + """ + + def __init__( + self, + dirpath: Optional[Union[str, Path]] = None, + filename: Optional[str] = None, + output_filename: Optional[str] = None, + ) -> None: + self.dirpath = dirpath + self.filename = filename + if output_filename is not None: + rank_zero_warn( + "`Profiler` signature has changed in v1.3. The `output_filename` parameter has been removed in" + " favor of `dirpath` and `filename`. 
Support for the old signature will be removed in v1.5", + DeprecationWarning + ) + filepath = Path(output_filename) + self.dirpath = filepath.parent + self.filename = filepath.stem + + self._output_file: Optional[TextIO] = None + self._write_stream: Optional[Callable] = None + self._local_rank: Optional[int] = None + self._log_dir: Optional[str] = None + self._stage: Optional[str] = None + @contextmanager def profile(self, action_name: str) -> None: """ @@ -86,17 +116,92 @@ def profile_iterable(self, iterable, action_name: str) -> None: self.stop(action_name) break + def _rank_zero_info(self, *args, **kwargs) -> None: + if self._local_rank in (None, 0): + log.info(*args, **kwargs) + + def _prepare_filename(self, extension: str = ".txt") -> str: + filename = "" + if self._stage is not None: + filename += f"{self._stage}-" + filename += str(self.filename) + if self._local_rank is not None: + filename += f"-{self._local_rank}" + filename += extension + return filename + + def _prepare_streams(self) -> None: + if self._write_stream is not None: + return + if self.filename: + filepath = os.path.join(self.dirpath, self._prepare_filename()) + fs = get_filesystem(filepath) + file = fs.open(filepath, "a") + self._output_file = file + self._write_stream = file.write + else: + self._write_stream = self._rank_zero_info + def describe(self) -> None: - """Logs a profile report after the conclusion of the training run.""" - for write in self.write_streams: - write(self.summary()) + """Logs a profile report after the conclusion of run.""" + # there are pickling issues with open file handles in Python 3.6 + # so to avoid them, we open and close the files within this function + # by calling `_prepare_streams` and `teardown` + self._prepare_streams() + self._write_stream(self.summary()) + if self._output_file is not None: + self._output_file.flush() + self.teardown(stage=self._stage) + + def _stats_to_str(self, stats: Dict[str, str]) -> str: + stage = f"{self._stage.upper()} " if self._stage is not None else "" + output = [stage + "Profiler Report"] + for action, value in stats.items(): + header = f"Profile stats for: {action}" + if self._local_rank is not None: + header += f" rank: {self._local_rank}" + output.append(header) + output.append(value) + return os.linesep.join(output) + + def setup( + self, + stage: Optional[str] = None, + local_rank: Optional[int] = None, + log_dir: Optional[str] = None, + ) -> None: + """Execute arbitrary pre-profiling set-up steps.""" + self._stage = stage + self._local_rank = local_rank + self._log_dir = log_dir + self.dirpath = self.dirpath or log_dir + + def teardown(self, stage: Optional[str] = None) -> None: + """ + Execute arbitrary post-profiling tear-down steps. + + Closes the currently open file and stream. 
+ """ + self._write_stream = None + if self._output_file is not None: + self._output_file.close() + self._output_file = None # can't pickle TextIOWrapper + + def __del__(self) -> None: + self.teardown(stage=self._stage) + + def start(self, action_name: str) -> None: + raise NotImplementedError + + def stop(self, action_name: str) -> None: + raise NotImplementedError - @abstractmethod def summary(self) -> str: - """Create profiler summary in text format.""" + raise NotImplementedError - def on_train_start(self, local_rank: Optional[int] = None): - self.local_rank = local_rank + @property + def local_rank(self) -> int: + return 0 if self._local_rank is None else self._local_rank class PassThroughProfiler(BaseProfiler): @@ -105,9 +210,6 @@ class PassThroughProfiler(BaseProfiler): The Trainer uses this class by default. """ - def __init__(self): - super().__init__(output_streams=None) - def start(self, action_name: str) -> None: pass @@ -124,30 +226,32 @@ class SimpleProfiler(BaseProfiler): the mean duration of each action and the total time spent over the entire training run. """ - def __init__(self, output_filename: Optional[str] = None, extended=True): + def __init__( + self, + dirpath: Optional[Union[str, Path]] = None, + filename: Optional[str] = None, + extended: bool = True, + output_filename: Optional[str] = None, + ) -> None: """ Args: - output_filename: optionally save profile results to file instead of printing - to std out when training is finished. + dirpath: Directory path for the ``filename``. If ``dirpath`` is ``None`` but ``filename`` is present, the + ``trainer.log_dir`` (from :class:`~pytorch_lightning.loggers.tensorboard.TensorBoardLogger`) + will be used. + + filename: If present, filename where the profiler results will be saved instead of printing to stdout. + The ``.txt`` extension will be used automatically. Raises: ValueError: If you attempt to start an action which has already started, or if you attempt to stop recording an action which was never started. """ - self.current_actions = {} + super().__init__(dirpath=dirpath, filename=filename, output_filename=output_filename) + self.current_actions: Dict[str, float] = {} self.recorded_durations = defaultdict(list) self.extended = extended - - self.output_fname = output_filename - self.output_file = None - if self.output_fname: - fs = get_filesystem(self.output_fname) - self.output_file = fs.open(self.output_fname, "w") - - streaming_out = [self.output_file.write] if self.output_file else [log.info] self.start_time = time.monotonic() - super().__init__(output_streams=streaming_out) def start(self, action_name: str) -> None: if action_name in self.current_actions: @@ -162,14 +266,18 @@ def stop(self, action_name: str) -> None: duration = end_time - start_time self.recorded_durations[action_name].append(duration) - def make_report(self): + def _make_report(self) -> Tuple[list, float]: total_duration = time.monotonic() - self.start_time report = [[a, d, 100. 
* np.sum(d) / total_duration] for a, d in self.recorded_durations.items()] report.sort(key=lambda x: x[2], reverse=True) return report, total_duration def summary(self) -> str: - output_string = "\n\nProfiler Report\n" + sep = os.linesep + output_string = "" + if self._stage is not None: + output_string += f"{self._stage.upper()} " + output_string += f"Profiler Report{sep}" if self.extended: @@ -177,16 +285,16 @@ def summary(self) -> str: max_key = np.max([len(k) for k in self.recorded_durations.keys()]) def log_row(action, mean, num_calls, total, per): - row = f"{os.linesep}{action:<{max_key}s}\t| {mean:<15}\t|" + row = f"{sep}{action:<{max_key}s}\t| {mean:<15}\t|" row += f"{num_calls:<15}\t| {total:<15}\t| {per:<15}\t|" return row output_string += log_row("Action", "Mean duration (s)", "Num calls", "Total time (s)", "Percentage %") output_string_len = len(output_string) - output_string += f"{os.linesep}{'-' * output_string_len}" - report, total_duration = self.make_report() + output_string += f"{sep}{'-' * output_string_len}" + report, total_duration = self._make_report() output_string += log_row("Total", "-", "_", f"{total_duration:.5}", "100 %") - output_string += f"{os.linesep}{'-' * output_string_len}" + output_string += f"{sep}{'-' * output_string_len}" for action, durations, duration_per in report: output_string += log_row( action, @@ -198,27 +306,16 @@ def log_row(action, mean, num_calls, total, per): else: def log_row(action, mean, total): - return f"{os.linesep}{action:<20s}\t| {mean:<15}\t| {total:<15}" + return f"{sep}{action:<20s}\t| {mean:<15}\t| {total:<15}" output_string += log_row("Action", "Mean duration (s)", "Total time (s)") - output_string += f"{os.linesep}{'-' * 65}" + output_string += f"{sep}{'-' * 65}" for action, durations in self.recorded_durations.items(): output_string += log_row(action, f"{np.mean(durations):.5}", f"{np.sum(durations):.5}") - output_string += os.linesep + output_string += sep return output_string - def describe(self): - """Logs a profile report after the conclusion of the training run.""" - super().describe() - if self.output_file: - self.output_file.flush() - - def __del__(self): - """Close profiler's stream.""" - if self.output_file: - self.output_file.close() - class AdvancedProfiler(BaseProfiler): """ @@ -227,11 +324,22 @@ class AdvancedProfiler(BaseProfiler): verbose and you should only use this if you want very detailed reports. """ - def __init__(self, output_filename: Optional[str] = None, line_count_restriction: float = 1.0): + def __init__( + self, + dirpath: Optional[Union[str, Path]] = None, + filename: Optional[str] = None, + line_count_restriction: float = 1.0, + output_filename: Optional[str] = None, + ) -> None: """ Args: - output_filename: optionally save profile results to file instead of printing - to std out when training is finished. + dirpath: Directory path for the ``filename``. If ``dirpath`` is ``None`` but ``filename`` is present, the + ``trainer.log_dir`` (from :class:`~pytorch_lightning.loggers.tensorboard.TensorBoardLogger`) + will be used. + + filename: If present, filename where the profiler results will be saved instead of printing to stdout. + The ``.txt`` extension will be used automatically. + line_count_restriction: this can be used to limit the number of functions reported for each action. 
either an integer (to select a count of lines), or a decimal fraction between 0.0 and 1.0 inclusive (to select a percentage of lines) @@ -240,18 +348,10 @@ def __init__(self, output_filename: Optional[str] = None, line_count_restriction ValueError: If you attempt to stop recording an action which was never started. """ - self.profiled_actions = {} + super().__init__(dirpath=dirpath, filename=filename, output_filename=output_filename) + self.profiled_actions: Dict[str, cProfile.Profile] = {} self.line_count_restriction = line_count_restriction - self.output_fname = output_filename - self.output_file = None - if self.output_fname: - fs = get_filesystem(self.output_fname) - self.output_file = fs.open(self.output_fname, "w") - - streaming_out = [self.output_file.write] if self.output_file else [log.info] - super().__init__(output_streams=streaming_out) - def start(self, action_name: str) -> None: if action_name not in self.profiled_actions: self.profiled_actions[action_name] = cProfile.Profile() @@ -260,9 +360,7 @@ def start(self, action_name: str) -> None: def stop(self, action_name: str) -> None: pr = self.profiled_actions.get(action_name) if pr is None: - raise ValueError( # pragma: no-cover - f"Attempting to stop recording an action ({action_name}) which was never started." - ) + raise ValueError(f"Attempting to stop recording an action ({action_name}) which was never started.") pr.disable() def summary(self) -> str: @@ -272,21 +370,16 @@ def summary(self) -> str: ps = pstats.Stats(pr, stream=s).strip_dirs().sort_stats('cumulative') ps.print_stats(self.line_count_restriction) recorded_stats[action_name] = s.getvalue() + return self._stats_to_str(recorded_stats) - # log to standard out - output_string = f"{os.linesep}Profiler Report{os.linesep}" - for action, stats in recorded_stats.items(): - output_string += f"{os.linesep}Profile stats for: {action}{os.linesep}{stats}" - - return output_string - - def describe(self): - """Logs a profile report after the conclusion of the training run.""" - super().describe() - if self.output_file: - self.output_file.flush() + def teardown(self, stage: Optional[str] = None) -> None: + super().teardown(stage=stage) + self.profiled_actions = {} - def __del__(self): - """Close profiler's stream.""" - if self.output_file: - self.output_file.close() + def __reduce__(self): + # avoids `TypeError: cannot pickle 'cProfile.Profile' object` + return ( + self.__class__, + tuple(), + dict(dirpath=self.dirpath, filename=self.filename, line_count_restriction=self.line_count_restriction), + ) diff --git a/pytorch_lightning/profiler/pytorch.py b/pytorch_lightning/profiler/pytorch.py index 88a33a3d367f8..fa2c2917f98a2 100644 --- a/pytorch_lightning/profiler/pytorch.py +++ b/pytorch_lightning/profiler/pytorch.py @@ -12,27 +12,197 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Profiler to check if there are any bottlenecks in your code.""" - import inspect import logging import os -from typing import List, Optional +from functools import partial +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Set, Type, TYPE_CHECKING, Union import torch +from torch import nn, Tensor +from torch.autograd.profiler import record_function from pytorch_lightning.profiler.profilers import BaseProfiler -from pytorch_lightning.utilities import rank_zero_only -from pytorch_lightning.utilities.cloud_io import get_filesystem from pytorch_lightning.utilities.distributed import rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.imports import _KINETO_AVAILABLE + +if TYPE_CHECKING: + from torch.autograd.profiler import EventList + from torch.utils.hooks import RemovableHandle + + from pytorch_lightning.core.lightning import LightningModule + +if _KINETO_AVAILABLE: + from torch.profiler import ProfilerAction, ProfilerActivity, tensorboard_trace_handler log = logging.getLogger(__name__) +_PROFILER = Union[torch.autograd.profiler.profile, torch.cuda.profiler.profile, torch.autograd.profiler.emit_nvtx] + + +class RegisterRecordFunction: + """ + While profiling autograd operations, this class will add labels for module names around the forward function. + + The Lightning PyTorch Profiler will activate this feature automatically. It can be deactivated as follows: + + Example:: + from pytorch_lightning.profilers import PyTorchProfiler + profiler = PyTorchProfiler(record_module_names=False) + Trainer(profiler=profiler) + + It can be used outside of Lightning as follows: + + Example:: + from pytorch_lightning import Trainer, seed_everything + with RegisterRecordFunction(model): + out = model(batch) + """ + + def __init__(self, model: nn.Module) -> None: + self._model = model + self._records: Dict[str, record_function] = {} + self._handles: Dict[str, List['RemovableHandle']] = {} + + def _start_recording_forward(self, _: nn.Module, input: Tensor, record_name: str) -> Tensor: + record = record_function(record_name) + record.__enter__() + self._records[record_name] = record + return input + + def _stop_recording_forward(self, _: nn.Module, __: Tensor, output: Tensor, record_name: str) -> Tensor: + self._records[record_name].__exit__(None, None, None) + return output + + def __enter__(self) -> None: + for module_name, module in self._model.named_modules(): + if module_name: + full_name = f"{type(module).__module__}.{type(module).__name__}" + record_name = f"{full_name}: {module_name}" + pre_forward_handle = module.register_forward_pre_hook( + partial(self._start_recording_forward, record_name=record_name) + ) + post_forward_handle = module.register_forward_hook( + partial(self._stop_recording_forward, record_name=record_name) + ) + + self._handles[module_name] = [pre_forward_handle, post_forward_handle] + + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: + for handles in self._handles.values(): + for h in handles: + h.remove() + self._handles = {} + + +class ScheduleWrapper: + """ + This class is used to override the schedule logic from the profiler and perform + recording for both `training_step`, `validation_step`. 
+ """ + + def __init__(self, schedule: Callable) -> None: + if not _KINETO_AVAILABLE: + raise ModuleNotFoundError("You are trying to use `ScheduleWrapper` which require kineto install.") + self._schedule = schedule + self.reset() + + def setup(self, start_action_name: str) -> None: + self._start_action_name = start_action_name + + def pre_step(self, current_action: str) -> None: + self._current_action = current_action + + def reset(self): + self._num_training_step_and_backward = 0 + self._num_validation_step = 0 + self._num_test_step = 0 + self._num_predict_step = 0 + self._training_step_and_backward_reached_end = False + self._validation_step_reached_end = False + self._test_step_reached_end = False + self._predict_step_reached_end = False + # used to stop profiler when `ProfilerAction.RECORD_AND_SAVE` is reached. + self._current_action: Optional[str] = None + self._start_action_name: Optional[str] = None + + @property + def num_step(self) -> int: + if self._current_action == "training_step_and_backward": + return self._num_training_step_and_backward + elif self._current_action == "validation_step": + return self._num_validation_step + elif self._current_action == "test_step": + return self._num_test_step + elif self._current_action == "predict_step": + return self._num_predict_step + else: + return 0 + + def _step(self) -> None: + if self._current_action == "training_step_and_backward": + self._num_training_step_and_backward += 1 + elif self._current_action == "validation_step": + if self._start_action_name == "on_fit_start": + if self._num_training_step_and_backward > 0: + self._num_validation_step += 1 + else: + self._num_validation_step += 1 + elif self._current_action == "test_step": + self._num_test_step += 1 + elif self._current_action == "predict_step": + self._num_predict_step += 1 + + @property + def has_finished(self) -> bool: + if self._current_action == "training_step_and_backward": + return self._training_step_and_backward_reached_end + elif self._current_action == "validation_step": + return self._validation_step_reached_end + elif self._current_action == "test_step": + return self._test_step_reached_end + elif self._current_action == "predict_step": + return self._predict_step_reached_end + return False + + def __call__(self, num_step: int) -> 'ProfilerAction': + # ignore the provided input. Keep internal state instead. 
+ if self.has_finished: + return ProfilerAction.NONE + + self._step() + action = self._schedule(self.num_step) + if action == ProfilerAction.RECORD_AND_SAVE: + if self._current_action == "training_step_and_backward": + self._training_step_and_backward_reached_end = True + elif self._current_action == "validation_step": + self._validation_step_reached_end = True + elif self._current_action == "test_step": + self._test_step_reached_end = True + elif self._current_action == "predict_step": + self._predict_step_reached_end = True + return action + class PyTorchProfiler(BaseProfiler): - PROFILED_FUNCTIONS = ("training_step_and_backward", "validation_step", "test_step") - AVAILABLE_SORT_KEYS = ( + RECORD_FUNCTIONS = { + "training_step_and_backward", + "training_step", + "backward", + "validation_step", + "test_step", + "predict_step", + } + STEP_FUNCTIONS = { + "training_step_and_backward", + "validation_step", + "test_step", + "predict_step", + } + AVAILABLE_SORT_KEYS = { "cpu_time", "cuda_time", "cpu_time_total", @@ -42,56 +212,43 @@ class PyTorchProfiler(BaseProfiler): "self_cpu_memory_usage", "self_cuda_memory_usage", "count", - ) + } + START_RECORD_FUNCTIONS = { + 'on_fit_start', + 'on_validation_start', + 'on_test_start', + 'on_predict_start', + } def __init__( self, - output_filename: Optional[str] = None, - enabled: bool = True, - use_cuda: bool = False, - record_shapes: bool = False, - profile_memory: bool = False, + dirpath: Optional[Union[str, Path]] = None, + filename: Optional[str] = None, group_by_input_shapes: bool = False, - with_stack: bool = False, - use_kineto: bool = False, - use_cpu: bool = True, emit_nvtx: bool = False, - export_to_chrome: bool = False, - path_to_export_trace: str = None, + export_to_chrome: bool = True, row_limit: int = 20, sort_by_key: Optional[str] = None, + record_functions: Set[str] = None, + record_module_names: bool = True, profiled_functions: Optional[List] = None, - local_rank: Optional[int] = None, - ): + output_filename: Optional[str] = None, + **profiler_kwargs: Any, + ) -> None: """ This profiler uses PyTorch's Autograd Profiler and lets you inspect the cost of different operators inside your model - both on the CPU and GPU Args: + dirpath: Directory path for the ``filename``. If ``dirpath`` is ``None`` but ``filename`` is present, the + ``trainer.log_dir`` (from :class:`~pytorch_lightning.loggers.tensorboard.TensorBoardLogger`) + will be used. - output_filename: optionally save profile results to file instead of printing - to std out when training is finished. When using ``ddp``, - each rank will stream the profiled operation to their own file - with the extension ``_{rank}.txt`` - - enabled: Setting this to False makes this context manager a no-op. - - use_cuda: Enables timing of CUDA events as well using the cudaEvent API. - Adds approximately 4us of overhead to each tensor operation. - - record_shapes: If shapes recording is set, information about input dimensions will be collected. - - profile_memory: Whether to report memory usage, default: True (Introduced in PyTorch 1.6.0) + filename: If present, filename where the profiler results will be saved instead of printing to stdout. + The ``.txt`` extension will be used automatically. group_by_input_shapes: Include operator input shapes and group calls by shape. 
- with_stack: record source information (file and line number) for the ops (Introduced in PyTorch 1.7.0) - - use_kineto: experimental support for Kineto profiler (Introduced in PyTorch 1.8.0) - - use_cpu: use_kineto=True and can be used to lower the overhead - for GPU-only profiling (Introduced in PyTorch 1.8.0) - emit_nvtx: Context manager that makes every autograd operation emit an NVTX range Run:: @@ -102,202 +259,254 @@ def __init__( nvvp trace_name.prof torch.autograd.profiler.load_nvprof(path) - export_to_chrome: Wether to export the sequence of profiled operators for Chrome. + export_to_chrome: Whether to export the sequence of profiled operators for Chrome. It will generate a ``.json`` file which can be read by Chrome. - path_to_export_trace: Directory path to export ``.json`` traces when using ``export_to_chrome=True``. - By default, it will be save where the file being is being run. - - row_limit: Limit the number of rows in a table, `0` is a special value that + row_limit: Limit the number of rows in a table, ``-1`` is a special value that removes the limit completely. - sort_by_key: Keys to sort out profiled table + sort_by_key: Attribute used to sort entries. By default + they are printed in the same order as they were registered. + Valid keys include: ``cpu_time``, ``cuda_time``, ``cpu_time_total``, + ``cuda_time_total``, ``cpu_memory_usage``, ``cuda_memory_usage``, + ``self_cpu_memory_usage``, ``self_cuda_memory_usage``, ``count``. - profiled_functions: list of profiled functions which will create a context manager on. + record_functions: Set of profiled functions which will create a context manager on. Any other will be pass through. - local_rank: When running in distributed setting, local_rank is used for each process - to write to their own file if `output_fname` is provided. + record_module_names: Whether to add module names while recording autograd operation. + + profiler_kwargs: Keyword arguments for the PyTorch profiler. This depends on your PyTorch version Raises: MisconfigurationException: - If arg ``sort_by_key`` is not present in ``AVAILABLE_SORT_KEYS``, or - if log file is not a ``.txt`` file. - ValueError: - If you attempt to stop recording an action which was never started. + If arg ``sort_by_key`` is not present in ``AVAILABLE_SORT_KEYS``. + If arg ``schedule`` is not a ``Callable``. + If arg ``schedule`` does not return a ``torch.profiler.ProfilerAction``. """ - - self.profiled_actions = {} - self.enabled = enabled - self.profiled_functions = profiled_functions or self.PROFILED_FUNCTIONS - self.use_cuda = use_cuda - self.record_shapes = record_shapes - self.profile_memory = profile_memory - self.sort_by_key = sort_by_key or ("cuda_time_total" if self.use_cuda else "cpu_time_total") - self.with_stack = with_stack - self.group_by_input_shapes = group_by_input_shapes and record_shapes - self.use_kineto = use_kineto - self.use_cpu = use_cpu - self.row_limit = row_limit - self.emit_nvtx = emit_nvtx - self.export_to_chrome = export_to_chrome - self.path_to_export_trace = path_to_export_trace - - if export_to_chrome and path_to_export_trace is None: - rank_zero_warn( - "The exported trace would be save locally as `path_to_export_trace` is empty." - " Note: Each functions will generate its own traced file." 
+ super().__init__(dirpath=dirpath, filename=filename, output_filename=output_filename)
+
+ record_functions = self.__deprecation_check(profiled_functions, record_functions)
+
+ self._group_by_input_shapes = group_by_input_shapes and profiler_kwargs.get("record_shapes", False)
+ self._emit_nvtx = emit_nvtx
+ self._export_to_chrome = export_to_chrome
+ self._row_limit = row_limit
+ self._sort_by_key = sort_by_key or f"{'cuda' if profiler_kwargs.get('use_cuda', False) else 'cpu'}_time_total"
+ self._user_record_functions = record_functions
+ self._record_functions_start = self._user_record_functions | self.START_RECORD_FUNCTIONS
+ self._record_functions = self._user_record_functions | self.RECORD_FUNCTIONS
+ self._record_module_names = record_module_names
+ self._profiler_kwargs = profiler_kwargs
+
+ self.profiler: Optional[_PROFILER] = None
+ self.function_events: Optional['EventList'] = None
+ self._lightning_module: Optional['LightningModule'] = None # set by ProfilerConnector
+ self._register: Optional[RegisterRecordFunction] = None
+ self._parent_profiler: Optional[_PROFILER] = None
+ self._recording_map: Dict[str, record_function] = {}
+ self._start_action_name: Optional[str] = None
+ self._schedule: Optional[ScheduleWrapper] = None
+
+ if _KINETO_AVAILABLE:
+ self._init_kineto(profiler_kwargs)
+
+ if self._sort_by_key not in self.AVAILABLE_SORT_KEYS:
+ raise MisconfigurationException(
+ f"Found sort_by_key: {self._sort_by_key}. Should be within {self.AVAILABLE_SORT_KEYS}. "
)
- if self.sort_by_key not in self.AVAILABLE_SORT_KEYS:
- raise MisconfigurationException(
- f"Found sort_by_key: {sort_by_key}. Should be within {self.AVAILABLE_SORT_KEYS}. "
+ def _init_kineto(self, profiler_kwargs: Any) -> None:
+ has_schedule = "schedule" in profiler_kwargs
+ self._has_on_trace_ready = "on_trace_ready" in profiler_kwargs
+
+ schedule = profiler_kwargs.get("schedule", None)
+ if schedule is not None:
+ if not isinstance(schedule, Callable):
+ raise MisconfigurationException(f"Schedule should be a callable. Found: {schedule}")
+ action = schedule(0)
+ if not isinstance(action, ProfilerAction):
+ raise MisconfigurationException(
+ f"Schedule should return a `torch.profiler.ProfilerAction`. Found: {action}"
+ )
+ schedule = schedule if has_schedule else self._default_schedule()
+ self._schedule = ScheduleWrapper(schedule) if schedule is not None else schedule
+ self._profiler_kwargs["schedule"] = self._schedule
+
+ activities = profiler_kwargs.get("activities", None)
+ self._profiler_kwargs["activities"] = activities or self._default_activities()
+ self._export_to_flame_graph = profiler_kwargs.get("export_to_flame_graph", False)
+ self._metric = profiler_kwargs.get("metric", "self_cpu_time_total")
+ with_stack = profiler_kwargs.get("with_stack", False) or self._export_to_flame_graph
+ self._profiler_kwargs["with_stack"] = with_stack
+
+ def __deprecation_check(
+ self,
+ profiled_functions: Optional[List[str]],
+ record_functions: Optional[Set[str]],
+ ) -> Set[str]:
+ if record_functions is None:
+ record_functions = set()
+
+ if profiled_functions is not None:
+ rank_zero_warn(
+ "`PyTorchProfiler.profiled_functions` has been renamed to"
+ " `record_functions` in v1.3 and will be removed in v1.5", DeprecationWarning
+ )
+ if not record_functions:
+ record_functions |= set(profiled_functions)
+ else:
+ raise MisconfigurationException(
+ "You set `PyTorchProfiler.profiled_functions` and `PyTorchProfiler.record_functions`."
+ " Please use only the latter."
+ ) + + return record_functions + + @staticmethod + def _default_schedule() -> Optional[callable]: + if _KINETO_AVAILABLE: + # Those schedule defaults allow the profiling overhead to be negligible over training time. + return torch.profiler.schedule(wait=1, warmup=1, active=3) + + def _default_activities(self) -> List['ProfilerActivity']: + activities = [] + if not _KINETO_AVAILABLE: + return activities + if self._profiler_kwargs.get("use_cpu", True): + activities.append(ProfilerActivity.CPU) + if self._profiler_kwargs.get("use_cuda", torch.cuda.is_available()): + activities.append(ProfilerActivity.CUDA) + return activities - self.profiled_actions = {} - self.context_names = {} - self.running_stack = [] - self.profiler = None + def start(self, action_name: str) -> None: + if self.profiler is None and action_name in self._record_functions_start: - self.output_fname = output_filename - self.output_file = None - if local_rank is not None: - self.on_train_start(local_rank=local_rank) - self.on_train_start = super().on_train_start + # close profiler if it is already opened. might happen if 2 profilers + # are created and the first one did not call `describe` + try: + torch.autograd._disable_profiler() # noqa + except (AttributeError, RuntimeError): + pass - def on_train_start(self, local_rank: Optional[str] = None): - self.local_rank = local_rank + if self._schedule is not None: + self._schedule.setup(action_name) - # when logging to `log.info`, only perform profiling on rank 0 - if local_rank != 0 and self.output_fname is None: - self.wrap_functions_into_rank_zero_only() + self._create_profilers() - if self.output_fname: - if local_rank is not None: - if '.txt' not in self.output_fname: - raise MisconfigurationException("Log file should be .txt file.") + profiler = self.profiler.__enter__() + if profiler is not None: + self.profiler = profiler - self.output_fname = self.output_fname.replace(".txt", f"_{self.local_rank}.txt") + if self._parent_profiler is not None: + self._parent_profiler.__enter__() - fs = get_filesystem(self.output_fname) - self.output_file = fs.open(self.output_fname, "w") + if self._register is not None: + self._register.__enter__() - streaming_out = [self.output_file.write] if self.output_file else [log.info] - super().__init__(output_streams=streaming_out) + if ( + self.profiler is not None and action_name in self._record_functions + and action_name not in self._recording_map + ): + recording = record_function(action_name) + recording.__enter__() + self._recording_map[action_name] = recording - def wrap_functions_into_rank_zero_only(self): - self.start = rank_zero_only(self.start) - self.stop = rank_zero_only(self.stop) - self.summary = rank_zero_only(self.summary) - self.describe = rank_zero_only(self.describe) + def stop(self, action_name: str) -> None: + if action_name in self._recording_map: + self._recording_map[action_name].__exit__(None, None, None) + del self._recording_map[action_name] - def start(self, action_name: str) -> None: - if action_name not in self.profiled_functions: + if not _KINETO_AVAILABLE or self._emit_nvtx: return - if len(self.running_stack) > 0: - self._stop(self.running_stack[-1]) - self.running_stack.append(action_name) + if self.profiler is not None and action_name in self.STEP_FUNCTIONS: + if self._schedule is not None: + self._schedule.pre_step(action_name) - self.context_names[action_name] = "/".join(self.running_stack) + def on_trace_ready(profiler): + if self.dirpath is not None: + if self._export_to_chrome: + handler = 
tensorboard_trace_handler(self.dirpath, self._prepare_filename(extension="")) + handler(profiler) - self._start(action_name) + if self._export_to_flame_graph: + path = os.path.join(self.dirpath, self._prepare_filename(extension=".stack")) + profiler.export_stacks(path, metric=self._metric) + else: + rank_zero_warn("The PyTorchProfiler failed to export trace as `dirpath` is None") - def _start(self, action_name: str) -> None: - if self.emit_nvtx: - self._parent_profiler = self._create_profiler(action_name, torch.cuda.profiler.profile, enter=True) - self._create_profiler(action_name, torch.autograd.profiler.emit_nvtx) - else: - self._create_profiler(action_name, torch.autograd.profiler.profile) - - def _create_profiler(self, action_name, profiler, enter=True): - init_args = inspect.signature(profiler.__init__).parameters - profiler_args = {k: v for k, v in vars(self).items() if k in init_args} - pr = profiler(**profiler_args) - if enter: - out_pr = pr.__enter__() - if out_pr is not None: - pr = out_pr - self.profiler = pr - return self.profiler - - def _stop(self, action_name: str) -> None: - if self.profiler is None: - return + if not self._has_on_trace_ready: + self.profiler.on_trace_ready = on_trace_ready - self.profiler.__exit__(exc_type=None, exc_val=None, exc_tb=None) + if self._schedule is not None: + self.profiler.step_num = self._schedule.num_step + self.profiler.step() - if isinstance(self.profiler, torch.autograd.profiler.emit_nvtx): - # when running ``emit_nvtx``, PyTorch requires 2 context manager. - # The parent_profiler is being closed too. - self._parent_profiler.__exit__(None, None, None) - return + def summary(self) -> str: + if not self._profiler_kwargs.get("enabled", True) or self._emit_nvtx: + return "" - function_events = self.profiler.function_events - self.profiler = None - for name in self.running_stack: - if name not in self.profiled_actions: - self.profiled_actions[name] = function_events - else: - self.profiled_actions[name] += function_events + self._delete_profilers() - def stop(self, action_name: str) -> None: - if action_name not in self.profiled_functions: - return + if not self.function_events: + return "" + + if self._export_to_chrome and not _KINETO_AVAILABLE: + filename = f"{self.local_rank}_trace.json" + path_to_trace = (filename if self.dirpath is None else os.path.join(self.dirpath, filename)) + self.function_events.export_chrome_trace(path_to_trace) + + data = self.function_events.key_averages(group_by_input_shapes=self._group_by_input_shapes) + table = data.table(sort_by=self._sort_by_key, row_limit=self._row_limit) - if len(self.running_stack) == 0 or self.running_stack[-1] != action_name: - raise ValueError( # pragma: no-cover - f"Attempting to stop recording an action ({action_name}) which was never started." 
+ recorded_stats = {"records": table} + return self._stats_to_str(recorded_stats) + + def _create_profilers(self) -> None: + if self._emit_nvtx: + self._parent_profiler = self._create_profiler(torch.cuda.profiler.profile) + self.profiler = self._create_profiler(torch.autograd.profiler.emit_nvtx) + else: + self._parent_profiler = None + self.profiler = self._create_profiler( + torch.profiler.profile if _KINETO_AVAILABLE else torch.autograd.profiler.profile ) - self._stop(action_name) - self.running_stack.pop() - # restore running profiler - if len(self.running_stack) > 0: - self._start(self.running_stack[-1]) + if self._record_module_names and self._lightning_module is not None: + self._register = RegisterRecordFunction(self._lightning_module) - def summary(self) -> str: - recorded_stats = {} - output_string = '' - local_rank = '0' if self.local_rank is None else self.local_rank + def _create_profiler(self, profiler: Type[_PROFILER]) -> _PROFILER: + init_parameters = inspect.signature(profiler.__init__).parameters + kwargs = {k: v for k, v in self._profiler_kwargs.items() if k in init_parameters} + return profiler(**kwargs) - if not self.enabled: - return output_string + def _cache_functions_events(self) -> None: + if self._emit_nvtx: + return + self.function_events = self.profiler.events() if _KINETO_AVAILABLE else self.profiler.function_events - for action_name, function_events in self.profiled_actions.items(): + def _delete_profilers(self) -> None: + if self.profiler is not None: + self.profiler.__exit__(None, None, None) + self._cache_functions_events() + self.profiler = None - # next line is a workaround for a pytorch issue (fixed on master, still present - # on 1.7). Without it the code fails with `AssertionError: There is already a CPU - # parent event for detach` - function_events.populate_cpu_children = lambda: None + if self._schedule is not None: + self._schedule.reset() - if self.export_to_chrome: - filename = f"{action_name}_{local_rank}_trace.json" - path_to_trace = filename if self.path_to_export_trace is None \ - else os.path.join(self.path_to_export_trace, filename) - function_events.export_chrome_trace(path_to_trace) + if self._parent_profiler is not None: + self._parent_profiler.__exit__(None, None, None) + self._parent_profiler = None - if self.emit_nvtx: - return output_string + if self._register is not None: + self._register.__exit__(None, None, None) + self._register = None - else: - data = function_events.key_averages(group_by_input_shapes=self.group_by_input_shapes) - table = data.table(sort_by=self.sort_by_key, row_limit=self.row_limit) - recorded_stats[action_name] = table - - # log to standard out - output_string = f"{os.linesep}Profiler Report{os.linesep}" - for action, stats in recorded_stats.items(): - output_string += (f"{os.linesep}Profile stats for: {action} rank: {local_rank} {os.linesep}{stats}") - - return output_string - - def describe(self): - """Logs a profile report after the conclusion of the training run.""" - super().describe() - if self.output_file: - self.output_file.flush() - - def __del__(self): - """Close profiler's stream.""" - if self.output_file: - self.output_file.close() + def teardown(self, stage: Optional[str] = None) -> None: + self._delete_profilers() + + for k in self._recording_map: + self.stop(k) + self._recording_map = {} + + super().teardown(stage=stage) diff --git a/pytorch_lightning/setup_tools.py b/pytorch_lightning/setup_tools.py index f5aed2608635e..3362ccb479895 100644 --- a/pytorch_lightning/setup_tools.py +++ 
b/pytorch_lightning/setup_tools.py @@ -16,7 +16,7 @@ import re from typing import List -from pytorch_lightning import __homepage__, __version__, _PROJECT_ROOT +_PROJECT_ROOT = os.path.dirname(os.path.dirname(__file__)) def _load_requirements(path_dir: str, file_name: str = 'requirements.txt', comment_char: str = '#') -> List[str]: @@ -40,10 +40,10 @@ def _load_requirements(path_dir: str, file_name: str = 'requirements.txt', comme return reqs -def _load_readme_description(path_dir: str, homepage: str = __homepage__, version: str = __version__) -> str: +def _load_readme_description(path_dir: str, homepage: str, version: str) -> str: """Load readme as decribtion - >>> _load_readme_description(_PROJECT_ROOT) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + >>> _load_readme_description(_PROJECT_ROOT, "", "") # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE '
...' """ path_readme = os.path.join(path_dir, "README.md") diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index 5aa9f1a44276b..6d434e12a2e78 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -15,11 +15,15 @@ from abc import ABC from copy import deepcopy from inspect import signature -from typing import Any, Callable, Dict, List, Type, Optional +from typing import Any, Callable, Dict, List, Optional, Type from pytorch_lightning.callbacks import Callback from pytorch_lightning.core.lightning import LightningModule -from pytorch_lightning.utilities import rank_zero_warn +from pytorch_lightning.utilities import rank_zero_deprecation +from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature +from pytorch_lightning.utilities.warnings import WarningCache + +warning_cache = WarningCache() class TrainerCallbackHookMixin(ABC): @@ -79,8 +83,12 @@ def on_train_epoch_start(self): for callback in self.callbacks: callback.on_train_epoch_start(self, self.lightning_module) - def on_train_epoch_end(self, outputs): - """Called when the epoch ends.""" + def on_train_epoch_end(self, outputs: List[Any]): + """Called when the epoch ends. + + Args: + outputs: List of outputs on each ``train`` epoch + """ for callback in self.callbacks: callback.on_train_epoch_end(self, self.lightning_module, outputs) @@ -89,28 +97,52 @@ def on_validation_epoch_start(self): for callback in self.callbacks: callback.on_validation_epoch_start(self, self.lightning_module) - def on_validation_epoch_end(self): - """Called when the epoch ends.""" + def on_validation_epoch_end(self, outputs: List[Any]): + """Called when the epoch ends. + + Args: + outputs: List of outputs on each ``validation`` epoch + """ for callback in self.callbacks: - callback.on_validation_epoch_end(self, self.lightning_module) + if is_param_in_hook_signature(callback.on_validation_epoch_end, "outputs"): + callback.on_validation_epoch_end(self, self.lightning_module, outputs) + else: + warning_cache.warn( + "`Callback.on_validation_epoch_end` signature has changed in v1.3." + " `outputs` parameter has been added." + " Support for the old signature will be removed in v1.5", DeprecationWarning + ) + callback.on_validation_epoch_end(self, self.lightning_module) def on_test_epoch_start(self): """Called when the epoch begins.""" for callback in self.callbacks: callback.on_test_epoch_start(self, self.lightning_module) - def on_test_epoch_end(self): - """Called when the epoch ends.""" + def on_test_epoch_end(self, outputs: List[Any]): + """Called when the epoch ends. + + Args: + outputs: List of outputs on each ``test`` epoch + """ for callback in self.callbacks: - callback.on_test_epoch_end(self, self.lightning_module) + if is_param_in_hook_signature(callback.on_test_epoch_end, "outputs"): + callback.on_test_epoch_end(self, self.lightning_module, outputs) + else: + warning_cache.warn( + "`Callback.on_test_epoch_end` signature has changed in v1.3." + " `outputs` parameter has been added." 
+ " Support for the old signature will be removed in v1.5", DeprecationWarning + ) + callback.on_test_epoch_end(self, self.lightning_module) def on_epoch_start(self): - """Called when the epoch begins.""" + """Called when either of train/val/test epoch begins.""" for callback in self.callbacks: callback.on_epoch_start(self, self.lightning_module) def on_epoch_end(self): - """Called when the epoch ends.""" + """Called when either of train/val/test epoch ends.""" for callback in self.callbacks: callback.on_epoch_end(self, self.lightning_module) @@ -211,10 +243,10 @@ def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> Dict[Type, dict]: callback_states = {} for callback in self.callbacks: if self.__is_old_signature(callback.on_save_checkpoint): - rank_zero_warn( + rank_zero_deprecation( "`Callback.on_save_checkpoint` signature has changed in v1.3." " A `checkpoint` parameter has been added." - " Support for the old signature will be removed in v1.5", DeprecationWarning + " Support for the old signature will be removed in v1.5" ) state = callback.on_save_checkpoint(self, self.lightning_module) # noqa: parameter-unfilled else: diff --git a/pytorch_lightning/trainer/configuration_validator.py b/pytorch_lightning/trainer/configuration_validator.py index 8c539b5ff478d..a7ba2b1c40123 100644 --- a/pytorch_lightning/trainer/configuration_validator.py +++ b/pytorch_lightning/trainer/configuration_validator.py @@ -40,7 +40,8 @@ def verify_loop_configurations(self, model: LightningModule) -> None: self.__verify_eval_loop_configuration(model, 'val') elif self.trainer.state == TrainerState.TESTING: self.__verify_eval_loop_configuration(model, 'test') - # TODO: add predict + elif self.trainer.state == TrainerState.PREDICTING: + self.__verify_predict_loop_configuration(model) def __verify_train_loop_configuration(self, model): # ----------------------------------- @@ -99,3 +100,9 @@ def __verify_eval_loop_configuration(self, model: LightningModule, stage: str) - rank_zero_warn(f'you passed in a {loader_name} but have no {step_name}. Skipping {stage} loop') if has_step and not has_loader: rank_zero_warn(f'you defined a {step_name} but have no {loader_name}. Skipping {stage} loop') + + def __verify_predict_loop_configuration(self, model: LightningModule) -> None: + + has_predict_dataloader = is_overridden('predict_dataloader', model) + if not has_predict_dataloader: + raise MisconfigurationException('Dataloader not found for `Trainer.predict`') diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index 99d716f6b5a8c..30d2b48975a84 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -32,6 +32,7 @@ DDPSpawnShardedPlugin, DeepSpeedPlugin, DeepSpeedPrecisionPlugin, + DoublePrecisionPlugin, HorovodPlugin, NativeMixedPrecisionPlugin, PrecisionPlugin, @@ -273,6 +274,10 @@ def use_deepspeed(self) -> bool: @property def is_distributed(self) -> bool: + # Used for custom plugins. + # Custom plugins should implement is_distributed property. 
+ if hasattr(self.training_type_plugin, 'is_distributed') and not self.on_tpu: + return self.training_type_plugin.is_distributed is_distributed = self.use_ddp or self.use_ddp2 or self.use_horovod if self.on_tpu: is_distributed |= self.training_type_plugin.is_distributed @@ -315,7 +320,8 @@ def select_precision_plugin(self) -> PrecisionPlugin: if self.precision == 32: return PrecisionPlugin() - + elif self.precision == 64: + return DoublePrecisionPlugin() elif self.precision == 16: if self.on_tpu: return TPUHalfPrecisionPlugin() @@ -354,7 +360,7 @@ def select_precision_plugin(self) -> PrecisionPlugin: log.info("Using APEX 16bit precision.") return ApexMixedPrecisionPlugin(self.amp_level) - raise NotImplementedError("We only support precisions 32 and 16!") + raise NotImplementedError("We only support precisions 64, 32 and 16!") def select_training_type_plugin(self) -> TrainingTypePlugin: if self.use_ddp2: @@ -426,6 +432,11 @@ def resolve_training_type_plugin(self, training_type: TrainingTypePlugin) -> Tra if hasattr(training_type, 'num_nodes') and getattr(training_type, 'num_nodes') is None: training_type.num_nodes = self.num_nodes + # Automatically set sync_batchnorm if None. + # Useful for custom plugins. + if hasattr(training_type, 'sync_batchnorm') and getattr(training_type, 'sync_batchnorm') is None: + training_type.sync_batchnorm = self.sync_batchnorm + return training_type def select_accelerator(self) -> Accelerator: diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index b3fc0b4eb7b29..5d2f141dc64a8 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -150,6 +150,10 @@ def attach_datamodule(self, model, datamodule: Optional[LightningDataModule] = N self.trainer.datamodule = datamodule datamodule.trainer = self.trainer + # experimental feature for Flash + if hasattr(datamodule, "data_pipeline"): + model.data_pipeline = datamodule.data_pipeline + class _PatchDataLoader(object): r""" diff --git a/pytorch_lightning/trainer/connectors/env_vars_connector.py b/pytorch_lightning/trainer/connectors/env_vars_connector.py index 2e788c256af0d..1f1c41c6eb2f0 100644 --- a/pytorch_lightning/trainer/connectors/env_vars_connector.py +++ b/pytorch_lightning/trainer/connectors/env_vars_connector.py @@ -18,27 +18,25 @@ from pytorch_lightning.utilities.argparse import get_init_arguments_and_types, parse_env_variables -def overwrite_by_env_vars(fn: Callable) -> Callable: +def _defaults_from_env_vars(fn: Callable) -> Callable: """ Decorator for :class:`~pytorch_lightning.trainer.trainer.Trainer` methods for which input arguments should be moved automatically to the correct device. 
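In practice the renamed decorator makes environment variables act as defaults that an explicit argument overrides, instead of overwriting what the user passed; a hedged sketch of the intended behaviour (the variable name assumes the ``PL_%(cls_name)s_%(cls_argument)s`` template used by ``parse_env_variables``)::

    import os
    from pytorch_lightning import Trainer

    os.environ["PL_TRAINER_MAX_EPOCHS"] = "7"

    Trainer().max_epochs              # 7 -> filled in from the environment
    Trainer(max_epochs=3).max_epochs  # 3 -> an explicit argument still wins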
- """ @wraps(fn) - def overwrite_by_env_vars(self, *args, **kwargs): - # get the class - cls = self.__class__ + def insert_env_defaults(self, *args, **kwargs): + cls = self.__class__ # get the class if args: # inace any args passed move them to kwargs # parse only the argument names cls_arg_names = [arg[0] for arg in get_init_arguments_and_types(cls)] # convert args to kwargs kwargs.update({k: v for k, v in zip(cls_arg_names, args)}) + env_variables = vars(parse_env_variables(cls)) # update the kwargs by env variables - # todo: maybe add a warning that some init args were overwritten by Env arguments - kwargs.update(vars(parse_env_variables(cls))) + kwargs = dict(list(env_variables.items()) + list(kwargs.items())) # all args were already moved to kwargs return fn(self, **kwargs) - return overwrite_by_env_vars + return insert_env_defaults diff --git a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py index 223216846758f..7759c8028d325 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py @@ -13,9 +13,11 @@ # limitations under the License. from collections import defaultdict from typing import Any, Dict, List, Optional, Tuple +from weakref import proxy import torch +import pytorch_lightning as pl from pytorch_lightning.core.step_result import Result from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.utilities import DistributedType, LightningEnum @@ -50,7 +52,7 @@ class HookResultStore: Those data structures enables us to reduce properly Result object when batch loop is finished. """ - def __init__(self, fx_name): + def __init__(self, fx_name: str) -> None: self._fx_name = fx_name self._internals = {} self._internals_reduced = {} @@ -104,6 +106,7 @@ def get_batch_log_metrics(self, *args, **kwargs): def run_epoch_func(self, results, opt_metric, func_name, *args, **kwargs) -> None: if not isinstance(opt_metric, Result): raise Exception("The provided opt_metric should be a Result Object. 
Something is wrong") + func = getattr(opt_metric, func_name) metrics_to_log = func(*args, add_dataloader_idx=self.has_several_dataloaders, **kwargs) results.append(metrics_to_log) @@ -222,8 +225,8 @@ class EpochResultStore: ``` """ - def __init__(self, trainer) -> None: - self.trainer = trainer + def __init__(self, trainer: 'pl.Trainer') -> None: + self.trainer = proxy(trainer) self.reset() def __getitem__(self, key: str) -> Any: diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 15428c5d5c248..0d0c3781c7724 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -81,16 +81,13 @@ def cached_results(self) -> Union[EpochResultStore, None]: return self._cached_results.get(self.trainer._running_stage) # type: ignore def get_metrics(self, key: str) -> Dict: - metrics_holder = getattr(self, f"_{key}", None) - model_ref = self.trainer.lightning_module - metrics_holder.convert( - self.trainer._device_type == DeviceType.TPU, - model_ref.device if model_ref is not None else model_ref, - ) + metrics_holder: MetricsHolder = getattr(self, f"_{key}") + model = self.trainer.lightning_module + metrics_holder.convert(model.device if model is not None else None) return metrics_holder.metrics def set_metrics(self, key: str, val: Dict) -> None: - metrics_holder = getattr(self, f"_{key}", None) + metrics_holder: MetricsHolder = getattr(self, f"_{key}") metrics_holder.reset(val) def reset(self) -> None: diff --git a/pytorch_lightning/trainer/connectors/logger_connector/metrics_holder.py b/pytorch_lightning/trainer/connectors/logger_connector/metrics_holder.py index 82f328a927485..1efbcc638674f 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/metrics_holder.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/metrics_holder.py @@ -12,44 +12,52 @@ # See the License for the specific language governing permissions and # limitations under the License. import numbers -from typing import Any +from typing import Any, Dict, Optional, Union import torch +from torchmetrics import Metric -from pytorch_lightning.metrics.metric import Metric +from pytorch_lightning.utilities.exceptions import MisconfigurationException + +_METRIC_TYPE = Union[Metric, torch.Tensor, int, float, Any] class MetricsHolder: """ - This class acts as a dictonary holder. + This class acts as a dictionary holder. It holds metrics and implements conversion functions. Those functions will be triggered within LoggerConnector when the property is being requested from the user. 
""" - def __init__(self, to_float: bool = False): - self.metrics = {} + def __init__(self, to_float: bool = False) -> None: + self.metrics: Dict[str, _METRIC_TYPE] = {} self._to_float = to_float - def update(self, metrics): + def update(self, metrics: dict) -> None: self.metrics.update(metrics) - def pop(self, key, default): + def pop(self, key: str, default: _METRIC_TYPE) -> _METRIC_TYPE: return self.metrics.pop(key, default) - def reset(self, metrics): + def reset(self, metrics: Dict[str, _METRIC_TYPE]) -> None: self.metrics = metrics - def convert(self, use_tpu: bool, device: torch.device): + def convert(self, device: Optional[torch.device]) -> None: for key, value in self.metrics.items(): - self.metrics[key] = self._convert(value, use_tpu, device) - - def _convert(self, current: Any, use_tpu: bool, device: torch.device): - if self._to_float: - return self._convert_to_float(current, use_tpu, device) - return self._convert_to_tensor(current, use_tpu, device) - - def _convert_to_float(self, current, use_tpu: bool, device: torch.device): + if self._to_float: + if isinstance(value, torch.Tensor) and value.numel() != 1: + raise MisconfigurationException( + f"The metric `{key}` does not contain a single element" + f" thus it cannot be converted to float. Found `{value}`" + ) + converted = self._convert_to_float(value) + else: + converted = self._convert_to_tensor(value, device) + self.metrics[key] = converted + + @staticmethod + def _convert_to_float(current: _METRIC_TYPE) -> float: if isinstance(current, Metric): current = current.compute().detach() @@ -61,16 +69,13 @@ def _convert_to_float(self, current, use_tpu: bool, device: torch.device): return current - def _convert_to_tensor(self, current: Any, use_tpu: bool, device: torch.device): - if current is not None: - if isinstance(current, Metric): - current = current.compute().detach() + @staticmethod + def _convert_to_tensor(current: _METRIC_TYPE, device: Optional[torch.device]) -> torch.Tensor: + if isinstance(current, Metric): + current = current.compute().detach() - elif isinstance(current, numbers.Number): - if device is None: - current = torch.tensor(current, dtype=torch.float) - else: - current = torch.tensor(current, device=device, dtype=torch.float) + elif isinstance(current, numbers.Number): + current = torch.tensor(current, device=device, dtype=torch.float) if isinstance(current, torch.Tensor) and current.device.type == "xla": current = current.cpu() diff --git a/pytorch_lightning/trainer/connectors/profiler_connector.py b/pytorch_lightning/trainer/connectors/profiler_connector.py index 98d65c1285ff7..fa1002d70a7ce 100644 --- a/pytorch_lightning/trainer/connectors/profiler_connector.py +++ b/pytorch_lightning/trainer/connectors/profiler_connector.py @@ -11,8 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License - from typing import Union +from weakref import proxy from pytorch_lightning.profiler import ( AdvancedProfiler, @@ -54,6 +54,8 @@ def on_trainer_init(self, profiler: Union[BaseProfiler, str]): ) self.trainer.profiler = profiler or PassThroughProfiler() - def on_train_start(self, trainer): + def setup(self) -> None: + trainer = self.trainer local_rank = trainer.local_rank if trainer.world_size > 1 else None - self.trainer.profiler.on_train_start(local_rank) + trainer.profiler._lightning_module = proxy(trainer.lightning_module) + trainer.profiler.setup(stage=trainer._setup_state, local_rank=local_rank, log_dir=trainer.log_dir) diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 83505913d0186..59ec40c3df2e8 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -16,7 +16,7 @@ import platform from abc import ABC from copy import deepcopy -from typing import Callable, Iterable, List, Tuple, Union +from typing import Iterable, List, Tuple, Union from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler from torch.utils.data.distributed import DistributedSampler @@ -41,7 +41,7 @@ class TrainerDataLoadingMixin(ABC): tpu_local_core_rank: int train_dataloader: DataLoader num_training_batches: Union[int, float] - val_check_batch:... + val_check_batch: float val_dataloaders: List[DataLoader] num_val_batches: List[Union[int, float]] test_dataloaders: List[DataLoader] @@ -191,7 +191,7 @@ def reset_train_dataloader(self, model: LightningModule) -> None: Args: model: The current `LightningModule` """ - self.train_dataloader = self.request_dataloader(model.train_dataloader) + self.train_dataloader = self.request_dataloader(model, "train") if self.overfit_batches > 0: if hasattr(self.train_dataloader, 'sampler') and isinstance(self.train_dataloader.sampler, RandomSampler): @@ -271,7 +271,7 @@ def _reset_eval_dataloader( """ # always get the loaders first so we can count how many there are loader_name = f'{mode}_dataloader' - dataloaders = self.request_dataloader(getattr(model, loader_name)) + dataloaders = self.request_dataloader(model, mode) if not isinstance(dataloaders, list): dataloaders = [dataloaders] @@ -280,7 +280,7 @@ def _reset_eval_dataloader( # duplicate it the numb of times needed to match the train loaders if self.overfit_batches > 0: num_loaders = len(dataloaders) - train_dataloader = self.request_dataloader(getattr(model, 'train_dataloader')) + train_dataloader = self.request_dataloader(model, 'train') dataloaders = [deepcopy(train_dataloader) for _ in range(num_loaders)] self.dev_debugger.track_load_dataloader_call(loader_name, dataloaders=dataloaders) @@ -293,9 +293,9 @@ def _reset_eval_dataloader( if mode in modes and hasattr(loader, 'sampler') and isinstance(loader.sampler, RandomSampler): # when overfitting, the dataloader should not have sampler - if self.overfit_batches > 0: + if self.overfit_batches > 0 and mode != 'predict': rank_zero_warn( - 'You requested to overfit but enabled test/val dataloader shuffling.' + 'You requested to overfit but enabled val/test dataloader shuffling.' ' We are turning it off for you.' 
) dataloaders[loader_i] = self.replace_sampler(loader, SequentialSampler(loader.dataset)) @@ -303,7 +303,7 @@ def _reset_eval_dataloader( else: rank_zero_warn( f'Your {mode}_dataloader has `shuffle=True`, it is best practice to turn' - ' this off for validation and test dataloaders.' + ' this off for val/test/predict dataloaders.' ) if any([dl is None for dl in dataloaders]): @@ -380,7 +380,7 @@ def reset_predict_dataloader(self, model) -> None: if has_loader: self.num_predict_batches, self.predict_dataloaders = self._reset_eval_dataloader(model, 'predict') - def request_dataloader(self, dataloader_fx: Callable) -> DataLoader: + def request_dataloader(self, model: LightningModule, stage: str) -> DataLoader: """Handles downloading data in the GPU or TPU case. Args: @@ -389,9 +389,10 @@ def request_dataloader(self, dataloader_fx: Callable) -> DataLoader: Returns: The dataloader """ - dataloader = dataloader_fx() + if model.trainer is not None: + model.trainer.call_hook(f"on_{stage}_dataloader") + dataloader: DataLoader = getattr(model, f'{stage}_dataloader')() dataloader = self._flatten_dl_only(dataloader) - self.accelerator.barrier('get_dataloaders') return dataloader diff --git a/pytorch_lightning/trainer/deprecated_api.py b/pytorch_lightning/trainer/deprecated_api.py index 69d3887fc7718..32dbc8c4088a3 100644 --- a/pytorch_lightning/trainer/deprecated_api.py +++ b/pytorch_lightning/trainer/deprecated_api.py @@ -14,7 +14,7 @@ from pytorch_lightning.accelerators import Accelerator from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector -from pytorch_lightning.utilities import DeviceType, DistributedType, rank_zero_warn +from pytorch_lightning.utilities import DeviceType, DistributedType, rank_zero_deprecation class DeprecatedDistDeviceAttributes: @@ -24,96 +24,94 @@ class DeprecatedDistDeviceAttributes: @property def on_cpu(self) -> bool: - rank_zero_warn("Internal: `on_cpu` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning) + rank_zero_deprecation("Internal: `on_cpu` is deprecated in v1.2 and will be removed in v1.4.") return self.accelerator_connector._device_type == DeviceType.CPU @on_cpu.setter def on_cpu(self, val: bool) -> None: - rank_zero_warn("Internal: `on_cpu` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning) + rank_zero_deprecation("Internal: `on_cpu` is deprecated in v1.2 and will be removed in v1.4.") if val: self.accelerator_connector._device_type = DeviceType.CPU @property def on_tpu(self) -> bool: - rank_zero_warn("Internal: `on_tpu` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning) + rank_zero_deprecation("Internal: `on_tpu` is deprecated in v1.2 and will be removed in v1.4.") return self.accelerator_connector._device_type == DeviceType.TPU @on_tpu.setter def on_tpu(self, val: bool) -> None: - rank_zero_warn("Internal: `on_tpu` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning) + rank_zero_deprecation("Internal: `on_tpu` is deprecated in v1.2 and will be removed in v1.4.") if val: self.accelerator_connector._device_type = DeviceType.TPU @property def use_tpu(self) -> bool: - rank_zero_warn("Internal: `use_tpu` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning) + rank_zero_deprecation("Internal: `use_tpu` is deprecated in v1.2 and will be removed in v1.4.") return self.on_tpu @use_tpu.setter def use_tpu(self, val: bool) -> None: - rank_zero_warn("Internal: `use_tpu` is 
deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning) + rank_zero_deprecation("Internal: `use_tpu` is deprecated in v1.2 and will be removed in v1.4.") self.on_tpu = val @property def on_gpu(self) -> bool: - rank_zero_warn("Internal: `on_gpu` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning) + rank_zero_deprecation("Internal: `on_gpu` is deprecated in v1.2 and will be removed in v1.4.") return self.accelerator_connector._device_type == DeviceType.GPU @on_gpu.setter def on_gpu(self, val: bool) -> None: - rank_zero_warn("Internal: `on_gpu` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning) + rank_zero_deprecation("Internal: `on_gpu` is deprecated in v1.2 and will be removed in v1.4.") if val: self.accelerator_connector._device_type = DeviceType.GPU @property def use_dp(self) -> bool: - rank_zero_warn("Internal: `use_dp` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning) + rank_zero_deprecation("Internal: `use_dp` is deprecated in v1.2 and will be removed in v1.4.") return self.accelerator_connector._distrib_type == DistributedType.DP @use_dp.setter def use_dp(self, val: bool) -> None: - rank_zero_warn("Internal: `use_dp` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning) + rank_zero_deprecation("Internal: `use_dp` is deprecated in v1.2 and will be removed in v1.4.") if val: self.accelerator_connector._distrib_type = DistributedType.DP @property def use_ddp(self) -> bool: - rank_zero_warn("Internal: `use_ddp` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning) + rank_zero_deprecation("Internal: `use_ddp` is deprecated in v1.2 and will be removed in v1.4.") return self.accelerator_connector._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN) @use_ddp.setter def use_ddp(self, val: bool) -> None: - rank_zero_warn("Internal: `use_ddp` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning) + rank_zero_deprecation("Internal: `use_ddp` is deprecated in v1.2 and will be removed in v1.4.") if val: self.accelerator_connector._distrib_type = DistributedType.DDP @property def use_ddp2(self) -> bool: - rank_zero_warn("Internal: `use_ddp2` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning) + rank_zero_deprecation("Internal: `use_ddp2` is deprecated in v1.2 and will be removed in v1.4.") return self.accelerator_connector._distrib_type == DistributedType.DDP2 @use_ddp2.setter def use_ddp2(self, val: bool) -> None: - rank_zero_warn("Internal: `use_ddp2` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning) + rank_zero_deprecation("Internal: `use_ddp2` is deprecated in v1.2 and will be removed in v1.4.") if val: self.accelerator_connector._distrib_type = DistributedType.DDP2 @property def use_horovod(self) -> bool: - rank_zero_warn("Internal: `use_horovod` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning) + rank_zero_deprecation("Internal: `use_horovod` is deprecated in v1.2 and will be removed in v1.4.") return self.accelerator_connector._distrib_type == DistributedType.HOROVOD @use_horovod.setter def use_horovod(self, val: bool) -> None: - rank_zero_warn("Internal: `use_horovod` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning) + rank_zero_deprecation("Internal: `use_horovod` is deprecated in v1.2 and will be removed in v1.4.") if val: self.accelerator_connector._distrib_type = DistributedType.HOROVOD @property def use_single_gpu(self) -> bool: - rank_zero_warn( - 
"Internal: `use_single_gpu` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning - ) + rank_zero_deprecation("Internal: `use_single_gpu` is deprecated in v1.2 and will be removed in v1.4.") # todo, limiting to exclude DDP2 is not clear but it comes from connectors... return ( self.accelerator_connector._device_type and self.accelerator_connector._device_type == DeviceType.GPU @@ -122,10 +120,7 @@ def use_single_gpu(self) -> bool: @use_single_gpu.setter def use_single_gpu(self, val: bool) -> None: - rank_zero_warn( - "Internal: `use_single_gpu` is deprecated in v1.2 and will be removed in v1.4.", - DeprecationWarning, - ) + rank_zero_deprecation("Internal: `use_single_gpu` is deprecated in v1.2 and will be removed in v1.4.") if val: self.accelerator_connector._device_type = DeviceType.GPU @@ -138,23 +133,22 @@ class DeprecatedTrainerAttributes: @property def accelerator_backend(self) -> Accelerator: - rank_zero_warn( + rank_zero_deprecation( "The `Trainer.accelerator_backend` attribute is deprecated in favor of `Trainer.accelerator`" - " since 1.2 and will be removed in v1.4.", DeprecationWarning + " since 1.2 and will be removed in v1.4." ) return self.accelerator def get_model(self) -> LightningModule: - rank_zero_warn( + rank_zero_deprecation( "The use of `Trainer.get_model()` is deprecated in favor of `Trainer.lightning_module`" - " and will be removed in v1.4.", DeprecationWarning + " and will be removed in v1.4." ) return self.lightning_module @property def running_sanity_check(self) -> bool: - rank_zero_warn( - "`Trainer.running_sanity_check` has been renamed to `Trainer.sanity_checking`" - " and will be removed in v1.5.", DeprecationWarning + rank_zero_deprecation( + "`Trainer.running_sanity_check` has been renamed to `Trainer.sanity_checking` and will be removed in v1.5." ) return self.sanity_checking diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 91cfc2ec757d5..a87073428e725 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -11,12 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+ import torch from pytorch_lightning.core.step_result import Result +from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.trainer.supporters import PredictionCollection from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.model_helpers import is_overridden +from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature from pytorch_lightning.utilities.warnings import WarningCache @@ -97,6 +100,10 @@ def on_evaluation_end(self, *args, **kwargs): else: self.trainer.call_hook('on_validation_end', *args, **kwargs) + if self.trainer.state != TrainerState.FITTING: + # summarize profile results + self.trainer.profiler.describe() + def reload_evaluation_dataloaders(self): model = self.trainer.lightning_module if self.trainer.testing: @@ -118,6 +125,8 @@ def setup(self, model, max_batches, dataloaders): self._predictions = [[] for _ in range(self.num_dataloaders)] def on_evaluation_epoch_start(self, *args, **kwargs): + self.trainer.call_hook('on_epoch_start', *args, **kwargs) + if self.trainer.testing: self.trainer.call_hook('on_test_epoch_start', *args, **kwargs) else: @@ -202,9 +211,6 @@ def __run_eval_epoch_end(self, num_dataloaders): # with a single dataloader don't pass an array outputs = self.outputs - # free memory - self.outputs = [] - eval_results = outputs if num_dataloaders == 1: eval_results = outputs[0] @@ -313,13 +319,40 @@ def store_predictions(self, output, batch_idx, dataloader_idx): def on_evaluation_epoch_end(self, *args, **kwargs): # call the callback hook - if self.trainer.testing: - self.trainer.call_hook('on_test_epoch_end', *args, **kwargs) - else: - self.trainer.call_hook('on_validation_epoch_end', *args, **kwargs) + self.call_on_evaluation_epoch_end_hook() self.trainer.call_hook('on_epoch_end') + def call_on_evaluation_epoch_end_hook(self): + outputs = self.outputs + + # free memory + self.outputs = [] + + model_ref = self.trainer.lightning_module + hook_name = "on_test_epoch_end" if self.trainer.testing else "on_validation_epoch_end" + + self.trainer._reset_result_and_set_hook_fx_name(hook_name) + + with self.trainer.profiler.profile(hook_name): + + if hasattr(self.trainer, hook_name): + on_evaluation_epoch_end_hook = getattr(self.trainer, hook_name) + on_evaluation_epoch_end_hook(outputs) + + if is_overridden(hook_name, model_ref): + model_hook_fx = getattr(model_ref, hook_name) + if is_param_in_hook_signature(model_hook_fx, "outputs"): + model_hook_fx(outputs) + else: + self.warning_cache.warn( + f"`ModelHooks.{hook_name}` signature has changed in v1.3. `outputs` parameter has been added." 
+ " Support for the old signature will be removed in v1.5", DeprecationWarning + ) + model_hook_fx() + + self.trainer._cache_logged_metrics() + def log_evaluation_step_metrics(self, output, batch_idx): if self.trainer.sanity_checking: return diff --git a/pytorch_lightning/trainer/predict_loop.py b/pytorch_lightning/trainer/predict_loop.py index 4fe6960055ca9..b33f41cb2ea48 100644 --- a/pytorch_lightning/trainer/predict_loop.py +++ b/pytorch_lightning/trainer/predict_loop.py @@ -28,23 +28,24 @@ def __init__(self, trainer): def on_trainer_init(self): self.trainer.num_predict_batches = [] - def get_predict_dataloaders(self, max_batches): + def get_predict_dataloaders(self): self.trainer.reset_predict_dataloader(self.trainer.lightning_module) dataloaders = self.trainer.predict_dataloaders - if max_batches is None: - max_batches = self.trainer.num_predict_batches + max_batches = self.trainer.num_predict_batches return dataloaders, max_batches - def should_skip_predict(self, dataloaders, max_batches): - return dataloaders is None or not sum(max_batches) + def should_skip_predict(self, max_batches): + return sum(max_batches) == 0 def on_predict_model_eval(self, *_, **__): model_ref = self.trainer.lightning_module model_ref.on_predict_model_eval() def setup(self, model, max_batches, dataloaders): + self.trainer.call_hook("on_predict_start") + # copy properties for forward overrides self.trainer.model_connector.copy_trainer_model_properties(model) @@ -66,7 +67,7 @@ def _get_num_dataloaders(self, dataloaders): length = len(dataloaders[0]) return length - def predict(self, batch, batch_idx, dataloader_idx): + def predict_step(self, batch, batch_idx, dataloader_idx): # configure args args = [batch, batch_idx] if self.num_dataloaders: @@ -75,7 +76,7 @@ def predict(self, batch, batch_idx, dataloader_idx): model_ref = self.trainer.lightning_module model_ref._current_fx_name = "predict" - predictions = self.trainer.accelerator.predict(args) + predictions = self.trainer.accelerator.predict_step(args) if predictions is None: self.warning_cache.warn("predict returned None if it was on purpose, ignore this warning...") @@ -87,6 +88,8 @@ def predict(self, batch, batch_idx, dataloader_idx): return def on_predict_epoch_end(self): + self.trainer.profiler.describe() + self.trainer._progress_bar_callback.on_predict_end(self.trainer, self.trainer.lightning_module) results = self._predictions @@ -100,3 +103,11 @@ def _convert_to_numpy(v): return results[0] return results + + def on_predict_start(self): + # hook + self.trainer.call_hook("on_predict_start") + + def on_predict_end(self): + # hook + self.trainer.call_hook("on_predict_end") diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py index b5654b148afc6..315e3c60c0557 100644 --- a/pytorch_lightning/trainer/properties.py +++ b/pytorch_lightning/trainer/properties.py @@ -491,6 +491,16 @@ def sanity_checking(self, val: bool) -> None: elif self.sanity_checking: self._running_stage = None + @property + def _setup_state(self) -> TrainerState: + # 'fit' is passed for `trainer.tune()` as there aren't "tune_dataloaders" + return TrainerState.FITTING if self.state == TrainerState.TUNING else self.state + + @property + def _teardown_state(self) -> Optional[TrainerState]: + if self.state.running: + return self._setup_state + # Used to represent the concrete type TrainerProperties class methods are called on. 
_T = TypeVar('_T', bound=TrainerProperties) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index c3039d24aadc0..c692c3f1c113f 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -38,7 +38,7 @@ from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector from pytorch_lightning.trainer.connectors.data_connector import DataConnector from pytorch_lightning.trainer.connectors.debugging_connector import DebuggingConnector -from pytorch_lightning.trainer.connectors.env_vars_connector import overwrite_by_env_vars +from pytorch_lightning.trainer.connectors.env_vars_connector import _defaults_from_env_vars from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector from pytorch_lightning.trainer.connectors.model_connector import ModelConnector from pytorch_lightning.trainer.connectors.optimizer_connector import OptimizerConnector @@ -57,7 +57,7 @@ from pytorch_lightning.trainer.training_loop import TrainLoop from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin from pytorch_lightning.tuner.tuning import Tuner -from pytorch_lightning.utilities import rank_zero_warn +from pytorch_lightning.utilities import DeviceType, rank_zero_warn from pytorch_lightning.utilities.cloud_io import load as pl_load from pytorch_lightning.utilities.debugging import InternalDebugger from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -84,7 +84,7 @@ class Trainer( DeprecatedTrainerAttributes, ): - @overwrite_by_env_vars + @_defaults_from_env_vars def __init__( self, logger: Union[LightningLoggerBase, Iterable[LightningLoggerBase], bool] = True, @@ -198,11 +198,13 @@ def __init__( gradient_clip_val: 0 means don't clip. - limit_train_batches: How much of training dataset to check (floats = percent, int = num_batches) + limit_train_batches: How much of training dataset to check (float = fraction, int = num_batches) - limit_val_batches: How much of validation dataset to check (floats = percent, int = num_batches) + limit_val_batches: How much of validation dataset to check (float = fraction, int = num_batches) - limit_test_batches: How much of test dataset to check (floats = percent, int = num_batches) + limit_test_batches: How much of test dataset to check (float = fraction, int = num_batches) + + limit_predict_batches: How much of prediction dataset to check (float = fraction, int = num_batches) logger: Logger (or iterable collection of loggers) for experiment tracking. @@ -221,11 +223,12 @@ def __init__( profiler: To profile individual steps during training and assist in identifying bottlenecks. - overfit_batches: Overfit a percent of training data (float) or a set number of batches (int). + overfit_batches: Overfit a fraction of training data (float) or a set number of batches (int). plugins: Plugins allow modification of core behavior like ddp and amp, and enable custom lightning plugins. - precision: Full precision (32), half precision (16). Can be used on CPU, GPU or TPUs. + precision: Double precision (64), full precision (32) or half precision (16). Can be used on CPU, GPU or + TPUs. max_epochs: Stop training once this number of epochs is reached. Disabled by default (None). If both max_epochs and max_steps are not specified, defaults to ``max_epochs`` = 1000. 
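To make the fraction-versus-count convention and the new double precision option concrete, a small illustrative sketch (values are arbitrary)::

    from pytorch_lightning import Trainer

    # float = fraction of the dataloader, int = absolute number of batches
    trainer = Trainer(limit_train_batches=0.1, limit_val_batches=5)

    # precision=64 selects the DoublePrecisionPlugin introduced in this patch
    trainer_fp64 = Trainer(precision=64)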
@@ -318,6 +321,10 @@ def __init__( self.predict_loop = PredictLoop(self) # training state + if weights_summary is not None and weights_summary not in ModelSummary.MODES: + raise MisconfigurationException( + f"`weights_summary` can be None, {', '.join(ModelSummary.MODES)}, but got {weights_summary}" + ) self.weights_summary = weights_summary self.shown_warnings = set() @@ -349,7 +356,6 @@ def __init__( max_steps, min_steps, num_sanity_val_steps, - weights_summary, ) self.evaluation_loop.on_trainer_init() @@ -426,8 +432,10 @@ def fit( # ---------------------------- # SET UP TRAINING # ---------------------------- - self.call_setup_hook(model) self.call_hook("on_before_accelerator_backend_setup", model) + self.accelerator.connect(model) + self.accelerator.setup_environment() + self.call_setup_hook(model) # allow user to setup lightning_module in accelerator environment self.accelerator.setup(self, model) # note: this sets up self.lightning_module # ---------------------------- @@ -441,13 +449,15 @@ def fit( | || {self.dispatch} || | || LIGHTNING - {self.accelerator.start_training} or || - {self.accelerator.start_evaluating} or || FLOW - {self.accelerator.start_predicting} || + {self.accelerator.start_training} || + or {self.accelerator.start_evaluating} || + or {self.accelerator.start_predicting} || FLOW + | || + {self.run_stage} || | || DIRECTION - {self.run_train} or || - {self.run_evaluation} or || - {self.run_predict} || + {self.run_train} || + or {self.run_evaluation} || + or {self.run_predict} || | || results \/ This is used to guide readers to the core loops: train, test, predict. @@ -491,7 +501,7 @@ def fit( return self.accelerator.results or 1 def pre_dispatch(self): - self.accelerator.pre_dispatch() + self.accelerator.pre_dispatch(self) # log hyper-parameters if self.logger is not None: @@ -501,7 +511,7 @@ def pre_dispatch(self): self.logger.save() def post_dispatch(self): - self.accelerator.post_dispatch() + self.accelerator.post_dispatch(self) self.accelerator.teardown() def dispatch(self): @@ -514,6 +524,9 @@ def dispatch(self): def run_stage(self): results = None + + self.profile_connector.setup() + if self.evaluating: results = self.run_evaluate() elif self.predicting: @@ -540,10 +553,7 @@ def _pre_training_routine(self): # print model summary if self.is_global_zero and self.weights_summary is not None and not self.testing: - if self.weights_summary in ModelSummary.MODES: - ref_model.summarize(mode=self.weights_summary) - else: - raise MisconfigurationException("weights_summary can be None, " + ", ".join(ModelSummary.MODES)) + ref_model.summarize(mode=self.weights_summary) # restore training and model before hpc is called self.checkpoint_connector.restore_weights() @@ -753,11 +763,13 @@ def run_evaluate(self): return eval_loop_results def run_predict(self): + self.predict_loop.on_predict_start() + # prepare dataloaders - dataloaders, max_batches = self.predict_loop.get_predict_dataloaders(None) + dataloaders, max_batches = self.predict_loop.get_predict_dataloaders() # check if we want to skip this evaluation - if self.predict_loop.should_skip_predict(dataloaders, max_batches): + if self.predict_loop.should_skip_predict(max_batches): return [] # ref model @@ -775,7 +787,6 @@ def run_predict(self): for dataloader_idx, dataloader in enumerate(dataloaders): dataloader = self.accelerator.process_dataloader(dataloader) dl_max_batches = self.predict_loop.max_batches[dataloader_idx] - for batch_idx, batch in enumerate(dataloader): if batch is None: continue @@ -785,10 +796,15 @@ def 
run_predict(self): break # lightning module methods - with self.profiler.profile("predict"): - self.predict_loop.predict(batch, batch_idx, dataloader_idx) + with self.profiler.profile("predict_step"): + self.predict_loop.predict_step(batch, batch_idx, dataloader_idx) results = self.predict_loop.on_predict_epoch_end() + self.predict_loop.on_predict_end() + + # re-enable grads + torch.set_grad_enabled(True) + return results def run_sanity_check(self, ref_model): @@ -922,9 +938,7 @@ def test( # If you supply a datamodule you can't supply test_dataloaders if test_dataloaders and datamodule: - raise MisconfigurationException( - 'You cannot pass both `trainer.test(test_dataloaders=..., datamodule=...)`' - ) + raise MisconfigurationException('You cannot pass both `trainer.test(test_dataloaders=..., datamodule=...)`') model_provided = model is not None model = model or self.lightning_module @@ -969,7 +983,9 @@ def __load_ckpt_weights( ' specify a path for a checkpoint `.{fn}(ckpt_path=PATH)`' ) - self.training_type_plugin.barrier() + # only one process running at this point for TPUs, as spawn isn't triggered yet + if not self._device_type == DeviceType.TPU: + self.training_type_plugin.barrier() ckpt = pl_load(ckpt_path, map_location=lambda storage, loc: storage) model.load_state_dict(ckpt['state_dict']) @@ -1058,8 +1074,7 @@ def tune( def call_setup_hook(self, model: LightningModule) -> None: assert self.state.running, f"TrainerState: {self.state}" - # 'fit' is passed for `trainer.tune()` as there aren't "tune_dataloaders" - state = TrainerState.FITTING if self.state == TrainerState.TUNING else self.state + state = self._setup_state if self.datamodule is not None: called = getattr(self.datamodule, f'has_setup_{state}') @@ -1070,11 +1085,14 @@ def call_setup_hook(self, model: LightningModule) -> None: model.setup(stage=state) def call_teardown_hook(self, model: LightningModule) -> None: - if self.state.running: - state = TrainerState.FITTING if self.state == TrainerState.TUNING else self.state - else: - state = None + state = self._teardown_state + + if self.datamodule is not None: + called = getattr(self.datamodule, f'has_teardown_{state}') + if not called: + self.datamodule.teardown(stage=state) + self.profiler.teardown(stage=state) self.teardown(stage=state) model.teardown(stage=state) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 88b87afcb9358..427ef8100af28 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -14,12 +14,12 @@ from contextlib import contextmanager, suppress from copy import copy, deepcopy +from typing import Optional import numpy as np import torch from pytorch_lightning.callbacks import EarlyStopping -from pytorch_lightning.core.memory import ModelSummary from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.core.step_result import Result from pytorch_lightning.plugins import ParallelPlugin @@ -36,7 +36,7 @@ class TrainLoop: - def __init__(self, trainer, multiple_trainloader_mode): + def __init__(self, trainer, multiple_trainloader_mode: str): self.trainer = trainer self.early_stopping_accumulator = None self.checkpoint_accumulator = None @@ -53,13 +53,12 @@ def __init__(self, trainer, multiple_trainloader_mode): def on_trainer_init( self, - max_epochs, - min_epochs, - max_steps, - min_steps, - num_sanity_val_steps, - weights_summary, - ): + max_epochs: Optional[int], + min_epochs: Optional[int], + max_steps: Optional[int], + 
min_steps: Optional[int], + num_sanity_val_steps: int, + ) -> None: self.trainer.global_step = 0 self.trainer.current_epoch = 0 self.trainer.should_stop = False @@ -82,12 +81,6 @@ def on_trainer_init( else: self.trainer.num_sanity_val_steps = num_sanity_val_steps - self.trainer.weights_summary = weights_summary - if weights_summary is not None and weights_summary not in ModelSummary.MODES: - raise MisconfigurationException( - f"`weights_summary` can be None, {', '.join(ModelSummary.MODES)}, got {weights_summary}" - ) - @property def num_optimizers(self): num_optimizers = len(self.get_optimizers_iterable()) @@ -102,10 +95,7 @@ def on_train_start(self): # hook self.trainer.call_hook("on_train_start") - # provide rank to profiler - self.trainer.profile_connector.on_train_start(self.trainer) - - def setup_fit(self, model, train_dataloader, val_dataloaders, datamodule): + def setup_fit(self, model, train_dataloader=None, val_dataloaders=None, datamodule=None): # clean hparams if hasattr(model, "hparams"): parsing.clean_namespace(model.hparams) @@ -140,8 +130,7 @@ def on_train_end(self): self.trainer.logger.finalize("success") # summarize profile results - if self.trainer.global_rank == 0: - self.trainer.profiler.describe() + self.trainer.profiler.describe() # give accelerators a chance to finish self.trainer.accelerator.on_train_end() @@ -188,7 +177,7 @@ def on_train_epoch_start(self, epoch): self.trainer.train_dataloader.sampler.set_epoch(epoch) # changing gradient according accumulation_scheduler - self.trainer.accumulation_scheduler.on_epoch_start(self.trainer, self.trainer.lightning_module) + self.trainer.accumulation_scheduler.on_train_epoch_start(self.trainer, self.trainer.lightning_module) # stores accumulated grad fractions per batch self.accumulated_loss = TensorRunningAccum(window_length=self.trainer.accumulate_grad_batches) @@ -551,7 +540,7 @@ def run_training_epoch(self): self.increment_accumulated_grad_global_step() # epoch end hook - self.run_on_epoch_end_hook(epoch_output) + self.on_train_epoch_end(epoch_output) # log epoch metrics self.trainer.logger_connector.log_train_epoch_end_metrics( @@ -747,7 +736,7 @@ def training_step_and_backward(self, split_batch, batch_idx, opt_idx, optimizer, # backward pass if result is not None: - with self.trainer.profiler.profile("model_backward"): + with self.trainer.profiler.profile("backward"): self.backward(result, optimizer, opt_idx) # hook - call this hook only @@ -793,7 +782,7 @@ def update_train_loop_lr_schedulers(self, monitor_metrics=None): # update lr self.trainer.optimizer_connector.update_learning_rates(interval="step", monitor_metrics=monitor_metrics) - def run_on_epoch_end_hook(self, epoch_output): + def on_train_epoch_end(self, epoch_output): # inform logger the batch loop has finished self.trainer.logger_connector.on_train_epoch_end() diff --git a/pytorch_lightning/tuner/tuning.py b/pytorch_lightning/tuner/tuning.py index 78810141b1369..b9fa9afe0e77e 100644 --- a/pytorch_lightning/tuner/tuning.py +++ b/pytorch_lightning/tuner/tuning.py @@ -33,13 +33,20 @@ def on_trainer_init(self, auto_lr_find, auto_scale_batch_size): self.trainer.auto_lr_find = auto_lr_find self.trainer.auto_scale_batch_size = auto_scale_batch_size - def tune(self, model, train_dataloader, val_dataloaders, datamodule): + def setup_trainer( + self, + model: LightningModule, + train_dataloader: Optional[DataLoader] = None, + val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, + datamodule: LightningDataModule = None, + ): + 
self.trainer.model_connector.copy_trainer_model_properties(model) # setup data, etc... self.trainer.train_loop.setup_fit(model, train_dataloader, val_dataloaders, datamodule) - # hook self.trainer.data_connector.prepare_data(model) + def tune(self, model, train_dataloader, val_dataloaders, datamodule): # Run auto batch size scaling if self.trainer.auto_scale_batch_size: if isinstance(self.trainer.auto_scale_batch_size, bool): @@ -104,6 +111,7 @@ def scale_batch_size( or datamodule. """ + self.setup_trainer(model, **fit_kwargs) return scale_batch_size( self.trainer, model, @@ -128,6 +136,7 @@ def lr_find( datamodule: Optional[LightningDataModule] = None, update_attr: bool = False, ): + self.setup_trainer(model, train_dataloader, val_dataloaders, datamodule) return lr_find( self.trainer, model, diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py index 3e2ee3e51efe1..f4617c23da383 100644 --- a/pytorch_lightning/utilities/__init__.py +++ b/pytorch_lightning/utilities/__init__.py @@ -17,6 +17,7 @@ from pytorch_lightning.utilities.apply_func import move_data_to_device # noqa: F401 from pytorch_lightning.utilities.distributed import ( # noqa: F401 AllGatherGrad, + rank_zero_deprecation, rank_zero_info, rank_zero_only, rank_zero_warn, diff --git a/pytorch_lightning/utilities/argparse.py b/pytorch_lightning/utilities/argparse.py index ee42ab3241ff6..46d88184ee190 100644 --- a/pytorch_lightning/utilities/argparse.py +++ b/pytorch_lightning/utilities/argparse.py @@ -67,7 +67,7 @@ def parse_argparser(cls, arg_parser: Union[ArgumentParser, Namespace]) -> Namesp # Value has been passed as a flag => It is currently None, so we need to set it to True # We always set to True, regardless of the default value. # Users must pass False directly, but when passing nothing True is assumed. - # i.e. the only way to disable somthing that defaults to True is to use the long form: + # i.e. the only way to disable something that defaults to True is to use the long form: # "--a_default_true_arg False" becomes False, while "--a_default_false_arg" becomes None, # which then becomes True here. @@ -107,7 +107,7 @@ def parse_env_variables(cls, template: str = "PL_%(cls_name)s_%(cls_argument)s") def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]: - r"""Scans the Trainer signature and returns argument names, types and default values. + r"""Scans the class signature and returns argument names, types and default values. 
Returns: List with tuples of 3 values: @@ -119,11 +119,11 @@ def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]: >>> args = get_init_arguments_and_types(Trainer) """ - trainer_default_params = inspect.signature(cls).parameters + cls_default_params = inspect.signature(cls).parameters name_type_default = [] - for arg in trainer_default_params: - arg_type = trainer_default_params[arg].annotation - arg_default = trainer_default_params[arg].default + for arg in cls_default_params: + arg_type = cls_default_params[arg].annotation + arg_default = cls_default_params[arg].default try: arg_types = tuple(arg_type.__args__) except AttributeError: @@ -242,9 +242,6 @@ def add_argparse_args( if arg == 'track_grad_norm': use_type = float - if arg_default is inspect._empty: - arg_default = None - parser.add_argument( f'--{arg}', dest=arg, @@ -291,10 +288,7 @@ def _gpus_allowed_type(x) -> Union[int, str]: def _gpus_arg_default(x) -> Union[int, str]: - if ',' in x: - return str(x) - else: - return int(x) + return _gpus_allowed_type(x) def _int_or_float_type(x) -> Union[int, float]: diff --git a/pytorch_lightning/utilities/argparse_utils.py b/pytorch_lightning/utilities/argparse_utils.py index eb53579f948e8..80db2429f7d2a 100644 --- a/pytorch_lightning/utilities/argparse_utils.py +++ b/pytorch_lightning/utilities/argparse_utils.py @@ -1,7 +1,5 @@ -from warnings import warn +from pytorch_lightning.utilities import rank_zero_deprecation -warn( - "`argparse_utils` package has been renamed to `argparse` since v1.2 and will be removed in v1.4", DeprecationWarning -) +rank_zero_deprecation("`argparse_utils` package has been renamed to `argparse` since v1.2 and will be removed in v1.4") from pytorch_lightning.utilities.argparse import * # noqa: F403 E402 F401 diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py index e797c32bbf917..bf7a199fc08dc 100644 --- a/pytorch_lightning/utilities/distributed.py +++ b/pytorch_lightning/utilities/distributed.py @@ -15,7 +15,7 @@ import logging import os import warnings -from functools import wraps +from functools import partial, wraps from typing import Any, Optional, Union import torch @@ -24,6 +24,7 @@ if torch.distributed.is_available(): from torch.distributed import group, ReduceOp + else: class ReduceOp: @@ -62,6 +63,7 @@ def _debug(*args, **kwargs): rank_zero_debug = rank_zero_only(_debug) rank_zero_info = rank_zero_only(_info) rank_zero_warn = rank_zero_only(_warn) +rank_zero_deprecation = partial(rank_zero_warn, category=DeprecationWarning) def gather_all_tensors(result: Union[torch.Tensor], group: Optional[Any] = None): @@ -171,7 +173,7 @@ def backward(ctx, *grad_output): torch.distributed.all_reduce(grad_output, op=torch.distributed.ReduceOp.SUM, async_op=False, group=ctx.group) - return grad_output[torch.distributed.get_rank()] + return grad_output[torch.distributed.get_rank()], None def all_gather_ddp_if_available( diff --git a/pytorch_lightning/utilities/imports.py b/pytorch_lightning/utilities/imports.py index 41a13d6c678a0..baeac9be57218 100644 --- a/pytorch_lightning/utilities/imports.py +++ b/pytorch_lightning/utilities/imports.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""General utilities""" +import importlib import operator import platform import sys @@ -19,7 +20,7 @@ from importlib.util import find_spec import torch -from pkg_resources import DistributionNotFound, get_distribution +from pkg_resources import DistributionNotFound def _module_available(module_path: str) -> bool: @@ -42,11 +43,24 @@ def _module_available(module_path: str) -> bool: def _compare_version(package: str, op, version) -> bool: + """ + Compare package version with some requirements + + >>> _compare_version("torch", operator.ge, "0.1") + True + """ try: - pkg_version = LooseVersion(get_distribution(package).version) - return op(pkg_version, LooseVersion(version)) - except DistributionNotFound: + pkg = importlib.import_module(package) + except (ModuleNotFoundError, DistributionNotFound): + return False + try: + pkg_version = LooseVersion(pkg.__version__) + except AttributeError: return False + if not (hasattr(pkg_version, "vstring") and hasattr(pkg_version, "version")): + # this is mock by sphinx, so it shall return True ro generate all summaries + return True + return op(pkg_version, LooseVersion(version)) _IS_WINDOWS = platform.system() == "Windows" @@ -54,7 +68,9 @@ def _compare_version(package: str, op, version) -> bool: _TORCH_LOWER_EQUAL_1_4 = _compare_version("torch", operator.le, "1.5.0") _TORCH_GREATER_EQUAL_1_6 = _compare_version("torch", operator.ge, "1.6.0") _TORCH_GREATER_EQUAL_1_7 = _compare_version("torch", operator.ge, "1.7.0") +_TORCH_GREATER_EQUAL_1_8 = _compare_version("torch", operator.ge, "1.8.0") +_KINETO_AVAILABLE = torch.profiler.kineto_available() if _TORCH_GREATER_EQUAL_1_8 else False _APEX_AVAILABLE = _module_available("apex.amp") _BOLTS_AVAILABLE = _module_available('pl_bolts') _DEEPSPEED_AVAILABLE = not _IS_WINDOWS and _module_available('deepspeed') diff --git a/pytorch_lightning/utilities/model_utils.py b/pytorch_lightning/utilities/model_utils.py index 7fd5b287f7ba3..728f73f4f0d32 100644 --- a/pytorch_lightning/utilities/model_utils.py +++ b/pytorch_lightning/utilities/model_utils.py @@ -1,8 +1,7 @@ -from warnings import warn +from pytorch_lightning.utilities import rank_zero_deprecation -warn( - "`model_utils` package has been renamed to `model_helpers` since v1.2 and will be removed in v1.4", - DeprecationWarning +rank_zero_deprecation( + "`model_utils` package has been renamed to `model_helpers` since v1.2 and will be removed in v1.4" ) from pytorch_lightning.utilities.model_helpers import * # noqa: F403 E402 F401 diff --git a/pytorch_lightning/utilities/signature_utils.py b/pytorch_lightning/utilities/signature_utils.py new file mode 100644 index 0000000000000..546d8e845ecb1 --- /dev/null +++ b/pytorch_lightning/utilities/signature_utils.py @@ -0,0 +1,22 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import inspect +from typing import Callable + + +def is_param_in_hook_signature(hook_fx: Callable, param: str) -> bool: + hook_params = list(inspect.signature(hook_fx).parameters) + if "args" in hook_params or param in hook_params: + return True + return False diff --git a/pytorch_lightning/utilities/warning_utils.py b/pytorch_lightning/utilities/warning_utils.py index c520086f62a81..0668bababa609 100644 --- a/pytorch_lightning/utilities/warning_utils.py +++ b/pytorch_lightning/utilities/warning_utils.py @@ -1,7 +1,5 @@ -from warnings import warn +from pytorch_lightning.utilities import rank_zero_deprecation -warn( - "`warning_utils` package has been renamed to `warnings` since v1.2 and will be removed in v1.4", DeprecationWarning -) +rank_zero_deprecation("`warning_utils` package has been renamed to `warnings` since v1.2 and will be removed in v1.4") from pytorch_lightning.utilities.warnings import * # noqa: F403 E402 F401 diff --git a/pytorch_lightning/utilities/xla_device_utils.py b/pytorch_lightning/utilities/xla_device_utils.py index aa0af1697ac51..f028222e3930b 100644 --- a/pytorch_lightning/utilities/xla_device_utils.py +++ b/pytorch_lightning/utilities/xla_device_utils.py @@ -11,11 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from warnings import warn +from pytorch_lightning.utilities import rank_zero_deprecation -warn( - "`xla_device_utils` package has been renamed to `xla_device` since v1.2 and will be removed in v1.4", - DeprecationWarning +rank_zero_deprecation( + "`xla_device_utils` package has been renamed to `xla_device` since v1.2 and will be removed in v1.4" ) from pytorch_lightning.utilities.xla_device import * # noqa: F403 E402 F401 diff --git a/requirements.txt b/requirements.txt index bdfd6601ba4c2..4649983b79d78 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,3 +8,5 @@ PyYAML>=5.1, !=5.4.* # OmegaConf requirement >=5.1 tqdm>=4.41.0 fsspec[http]>=0.8.1 tensorboard>=2.2.0 +torchmetrics>=0.2.0 +pyDeprecate==0.1.1 \ No newline at end of file diff --git a/requirements/adjust_versions.py b/requirements/adjust_versions.py index c1499cd4ea5ee..d0dfbc59e2352 100644 --- a/requirements/adjust_versions.py +++ b/requirements/adjust_versions.py @@ -11,6 +11,7 @@ "1.7.0": dict(torchvision="0.8.1", torchtext="0.8"), "1.7.1": dict(torchvision="0.8.2", torchtext="0.8.1"), "1.8.0": dict(torchvision="0.9.0", torchtext="0.9"), + "1.8.1": dict(torchvision="0.9.0", torchtext="0.9"), } diff --git a/requirements/extra.txt b/requirements/extra.txt index 85437327bce06..715916c4e36ac 100644 --- a/requirements/extra.txt +++ b/requirements/extra.txt @@ -4,7 +4,8 @@ matplotlib>3.1 horovod>=0.21.2 # no need to install with [pytorch] as pytorch is already installed omegaconf>=2.0.1 torchtext>=0.5 -onnx>=1.7.0 +# onnx>=1.7.0 onnxruntime>=1.3.0 hydra-core>=1.0 +# todo: when switch to standard package stream, drop `fairscale` from hard mocked docs libs https://github.com/PyTorchLightning/fairscale/archive/pl_1.2.0.zip diff --git a/requirements/test.txt b/requirements/test.txt index 60c861cea9c50..259cc2e2d6442 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -1,12 +1,11 @@ -coverage>=5.2 +coverage>5.2.0 codecov>=2.1 pytest>=6.0 -pytest-cov>2.10 -pytest-xdist +#pytest-cov>2.10 +#pytest-xdist flake8>=3.6 check-manifest twine==3.2 -# scipy>=0.13.3 scikit-learn>=0.22.2 scikit-image>=0.17.2 isort>=5.6.4 diff --git a/setup.cfg b/setup.cfg 
index 0e64df0530d82..6365482e32aa8 100644 --- a/setup.cfg +++ b/setup.cfg @@ -39,23 +39,14 @@ exclude_lines = pass rank_zero_warn raise NotImplementedError - # TODO: figure out how to get codecov to pick up the test results on these backends # The actual coverage for each is 90%+ # *metrics (94%+) are temporarily removed from testing while tests speed up omit = - pytorch_lightning/accelerators/ddp_*.py - pytorch_lightning/accelerators/ddp2_*.py - pytorch_lightning/accelerators/dp_*.py - pytorch_lightning/accelerators/tpu_*.py + pytorch_lightning/cluster_environments/*.py pytorch_lightning/utilities/xla_device_utils.py pytorch_lightning/utilities/distributed.py pytorch_lightning/tuner/auto_gpu_select.py - # TODO: temporary, until accelerator refactor is finished - pytorch_lightning/accelerators/accelerator.py - pytorch_lightning/plugins/training_type/*.py - pytorch_lightning/plugins/precision/*.py - pytorch_lightning/plugins/base_plugin.py [flake8] @@ -73,10 +64,8 @@ verbose = 2 # https://pep8.readthedocs.io/en/latest/intro.html#error-codes format = pylint ignore = - E731 # do not assign a lambda expression, use a def - W503 # line break before binary operator - # because of YAPF - till https://github.com/google/yapf/issues/897 is resolved - E231 # missing whitespace after ',', ';', or ':'; for black + E731 # Ignore "Do not assign a lambda expression, use a def" + W503 # Ignore "Line break occurred before a binary operator" # setup.cfg or tox.ini diff --git a/setup.py b/setup.py index 5d619d51977b2..e53e24ebf0702 100755 --- a/setup.py +++ b/setup.py @@ -16,20 +16,22 @@ import os # Always prefer setuptools over distutils +import sys + from setuptools import find_packages, setup try: - import builtins + from pytorch_lightning import info, setup_tools except ImportError: - import __builtin__ as builtins + # alternative https://stackoverflow.com/a/67692/4521646 + sys.path.append("pytorch_lightning") + import info + import setup_tools # https://packaging.python.org/guides/single-sourcing-package-version/ # http://blog.ionelmc.ro/2014/05/25/python-packaging/ -PATH_ROOT = os.path.dirname(__file__) -builtins.__LIGHTNING_SETUP__ = True - -import pytorch_lightning # noqa: E402 -from pytorch_lightning.setup_tools import _load_readme_description, _load_requirements # noqa: E402 +_PATH_ROOT = os.path.dirname(__file__) +_PATH_REQUIRE = os.path.join(_PATH_ROOT, 'requirements') # https://setuptools.readthedocs.io/en/latest/setuptools.html#declaring-extras # Define package extras. These are only installed if you specify them. 
@@ -37,10 +39,10 @@ # From local copy of repo, use like `pip install ".[dev, docs]"` extras = { # 'docs': load_requirements(file_name='docs.txt'), - 'examples': _load_requirements(path_dir=os.path.join(PATH_ROOT, 'requirements'), file_name='examples.txt'), - 'loggers': _load_requirements(path_dir=os.path.join(PATH_ROOT, 'requirements'), file_name='loggers.txt'), - 'extra': _load_requirements(path_dir=os.path.join(PATH_ROOT, 'requirements'), file_name='extra.txt'), - 'test': _load_requirements(path_dir=os.path.join(PATH_ROOT, 'requirements'), file_name='test.txt') + 'examples': setup_tools._load_requirements(path_dir=_PATH_REQUIRE, file_name='examples.txt'), + 'loggers': setup_tools._load_requirements(path_dir=_PATH_REQUIRE, file_name='loggers.txt'), + 'extra': setup_tools._load_requirements(path_dir=_PATH_REQUIRE, file_name='extra.txt'), + 'test': setup_tools._load_requirements(path_dir=_PATH_REQUIRE, file_name='test.txt') } extras['dev'] = extras['extra'] + extras['loggers'] + extras['test'] extras['all'] = extras['dev'] + extras['examples'] # + extras['docs'] @@ -53,6 +55,12 @@ # filter cpu only packages extras[ex] = [pkg for pkg in extras[kw] if not any(pgpu.lower() in pkg.lower() for pgpu in PACKAGES_GPU_ONLY)] +long_description = setup_tools._load_readme_description( + _PATH_ROOT, + homepage=info.__homepage__, + version=info.__version__, +) + # https://packaging.python.org/discussions/install-requires-vs-requirements / # keep the meta-data here for simplicity in reading this file... it's not obvious # what happens and to non-engineers they won't know to look in init ... @@ -60,22 +68,22 @@ # engineer specific practices setup( name="pytorch-lightning", - version=pytorch_lightning.__version__, - description=pytorch_lightning.__docs__, - author=pytorch_lightning.__author__, - author_email=pytorch_lightning.__author_email__, - url=pytorch_lightning.__homepage__, + version=info.__version__, + description=info.__docs__, + author=info.__author__, + author_email=info.__author_email__, + url=info.__homepage__, download_url='https://github.com/PyTorchLightning/pytorch-lightning', - license=pytorch_lightning.__license__, + license=info.__license__, packages=find_packages(exclude=['tests', 'tests/*', 'benchmarks', 'legacy', 'legacy/*']), - long_description=_load_readme_description(PATH_ROOT), + long_description=long_description, long_description_content_type='text/markdown', include_package_data=True, zip_safe=False, keywords=['deep learning', 'pytorch', 'AI'], python_requires='>=3.6', setup_requires=[], - install_requires=_load_requirements(PATH_ROOT), + install_requires=setup_tools._load_requirements(_PATH_ROOT), extras_require=extras, project_urls={ "Bug Tracker": "https://github.com/PyTorchLightning/pytorch-lightning/issues", diff --git a/tests/__init__.py b/tests/__init__.py index 433f183896dee..fc634e6b73fec 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -19,8 +19,8 @@ _TEST_ROOT = os.path.dirname(__file__) _PROJECT_ROOT = os.path.dirname(_TEST_ROOT) _TEMP_PATH = os.path.join(_PROJECT_ROOT, 'test_temp') -DATASETS_PATH = os.path.join(_PROJECT_ROOT, 'Datasets') -LEGACY_PATH = os.path.join(_PROJECT_ROOT, 'legacy') +PATH_DATASETS = os.path.join(_PROJECT_ROOT, 'Datasets') +PATH_LEGACY = os.path.join(_PROJECT_ROOT, 'legacy') # todo: this setting `PYTHONPATH` may not be used by other evns like Conda for import packages if _PROJECT_ROOT not in os.getenv('PYTHONPATH', ""): diff --git a/tests/accelerators/test_accelerator_connector.py b/tests/accelerators/test_accelerator_connector.py 
index e6139de5d3028..79a17df074e35 100644 --- a/tests/accelerators/test_accelerator_connector.py +++ b/tests/accelerators/test_accelerator_connector.py @@ -98,7 +98,8 @@ def test_accelerator_choice_ddp_spawn(cuda_available_mock, device_count_mock): "SLURM_LOCALID": "10" } ) -def test_accelerator_choice_ddp_slurm(): +@mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', autospec=True) +def test_accelerator_choice_ddp_slurm(setup_distributed_mock): class CB(Callback): @@ -136,7 +137,8 @@ def on_fit_start(self, trainer, pl_module): } ) @mock.patch('torch.cuda.device_count', return_value=2) -def test_accelerator_choice_ddp2_slurm(device_count_mock): +@mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', autospec=True) +def test_accelerator_choice_ddp2_slurm(device_count_mock, setup_distributed_mock): class CB(Callback): @@ -165,7 +167,8 @@ def on_fit_start(self, trainer, pl_module): @RunIf(min_gpus=1) @mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1", "WORLD_SIZE": "2", "LOCAL_RANK": "10", "NODE_RANK": "0"}) @mock.patch('torch.cuda.device_count', return_value=2) -def test_accelerator_choice_ddp_te(device_count_mock): +@mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', autospec=True) +def test_accelerator_choice_ddp_te(device_count_mock, setup_distributed_mock): class CB(Callback): @@ -193,7 +196,8 @@ def on_fit_start(self, trainer, pl_module): @RunIf(min_gpus=1) @mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1", "WORLD_SIZE": "2", "LOCAL_RANK": "10", "NODE_RANK": "0"}) @mock.patch('torch.cuda.device_count', return_value=2) -def test_accelerator_choice_ddp2_te(device_count_mock): +@mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', autospec=True) +def test_accelerator_choice_ddp2_te(device_count_mock, setup_distributed_mock): class CB(Callback): @@ -224,7 +228,8 @@ def on_fit_start(self, trainer, pl_module): "NODE_RANK": "0", }) @mock.patch('torch.cuda.device_count', return_value=0) -def test_accelerator_choice_ddp_cpu_te(device_count_mock): +@mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', autospec=True) +def test_accelerator_choice_ddp_cpu_te(device_count_mock, setup_distributed_mock): class CB(Callback): @@ -259,7 +264,8 @@ def on_fit_start(self, trainer, pl_module): } ) @mock.patch('torch.cuda.device_count', return_value=0) -def test_accelerator_choice_ddp_cpu_slurm(device_count_mock): +@mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', autospec=True) +def test_accelerator_choice_ddp_cpu_slurm(device_count_mock, setup_distributed_mock): class CB(Callback): @@ -294,7 +300,8 @@ def on_fit_start(self, trainer, pl_module): } ) @mock.patch('torch.cuda.device_count', return_value=0) -def test_accelerator_choice_ddp_cpu_custom_cluster(device_count_mock): +@mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', autospec=True) +def test_accelerator_choice_ddp_cpu_custom_cluster(device_count_mock, setup_distributed_mock): """ Test that we choose the custom cluster even when SLURM or TE flags are around """ @@ -304,6 +311,9 @@ class CustomCluster(LightningEnvironment): def master_address(self): return 'asdf' + def creates_children(self) -> bool: + return True + class CB(Callback): def on_fit_start(self, trainer, pl_module): @@ -336,7 +346,8 @@ def on_fit_start(self, trainer, pl_module): } ) @mock.patch('torch.cuda.device_count', return_value=0) -def test_custom_accelerator(device_count_mock): +@mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', 
autospec=True) +def test_custom_accelerator(device_count_mock, setup_distributed_mock): class Accel(Accelerator): pass @@ -371,7 +382,8 @@ class TrainTypePlugin(SingleDevicePlugin): } ) @mock.patch('torch.cuda.device_count', return_value=0) -def test_dist_backend_accelerator_mapping(device_count_mock): +@mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', autospec=True) +def test_dist_backend_accelerator_mapping(device_count_mock, setup_distributed_mock): class CB(Callback): diff --git a/tests/accelerators/test_common.py b/tests/accelerators/test_common.py index 6962af7249d1b..bd8636ba839f9 100644 --- a/tests/accelerators/test_common.py +++ b/tests/accelerators/test_common.py @@ -8,11 +8,13 @@ from tests.helpers.runif import RunIf -@pytest.mark.parametrize("trainer_kwargs", ( - pytest.param({"gpus": 1}, marks=RunIf(min_gpus=1)), - pytest.param({"accelerator": "dp", "gpus": 2}, marks=RunIf(min_gpus=2)), - pytest.param({"accelerator": "ddp_spawn", "gpus": 2}, marks=RunIf(min_gpus=2)), -)) +@pytest.mark.parametrize( + "trainer_kwargs", ( + pytest.param(dict(gpus=1), marks=RunIf(min_gpus=1)), + pytest.param(dict(accelerator="dp", gpus=2), marks=RunIf(min_gpus=2)), + pytest.param(dict(accelerator="ddp_spawn", gpus=2), marks=RunIf(min_gpus=2)), + ) +) def test_evaluate(tmpdir, trainer_kwargs): tutils.set_random_master_port() diff --git a/tests/accelerators/test_cpu.py b/tests/accelerators/test_cpu.py index 81a5132e47356..46379a9d10c14 100644 --- a/tests/accelerators/test_cpu.py +++ b/tests/accelerators/test_cpu.py @@ -3,10 +3,12 @@ import pytest import torch +from pytorch_lightning import Trainer from pytorch_lightning.accelerators import CPUAccelerator from pytorch_lightning.plugins import SingleDevicePlugin from pytorch_lightning.plugins.precision import MixedPrecisionPlugin from pytorch_lightning.utilities.exceptions import MisconfigurationException +from tests.helpers.boring_model import BoringModel def test_unsupported_precision_plugins(): @@ -18,3 +20,33 @@ def test_unsupported_precision_plugins(): ) with pytest.raises(MisconfigurationException, match=r"amp \+ cpu is not supported."): accelerator.setup(trainer=trainer, model=model) + + +@pytest.mark.parametrize("delay_dispatch", [True, False]) +def test_plugin_setup_optimizers_in_pre_dispatch(tmpdir, delay_dispatch): + """ + Test when using a custom training type plugin that delays setup optimizers, + we do not call setup optimizers till ``pre_dispatch``. + """ + + class TestModel(BoringModel): + + def on_fit_start(self): + if delay_dispatch: + # Ensure we haven't setup optimizers if we've delayed dispatch + assert len(self.trainer.optimizers) == 0 + else: + assert len(self.trainer.optimizers) > 0 + + def on_fit_end(self): + assert len(self.trainer.optimizers) > 0 + + class CustomPlugin(SingleDevicePlugin): + + @property + def setup_optimizers_in_pre_dispatch(self) -> bool: + return delay_dispatch + + model = TestModel() + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, plugins=CustomPlugin(device=torch.device("cpu"))) + trainer.fit(model) diff --git a/tests/accelerators/test_ddp.py b/tests/accelerators/test_ddp.py index 14e73d920af4b..541110ac8846b 100644 --- a/tests/accelerators/test_ddp.py +++ b/tests/accelerators/test_ddp.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import os +from typing import Optional +from unittest import mock from unittest.mock import patch import pytest @@ -91,7 +93,6 @@ def test_torch_distributed_backend_env_variables(tmpdir): _environ = {"PL_TORCH_DISTRIBUTED_BACKEND": "undefined", "CUDA_VISIBLE_DEVICES": "0,1", "WORLD_SIZE": "2"} with patch.dict(os.environ, _environ), \ patch('torch.cuda.device_count', return_value=2): - with pytest.raises(ValueError, match="Invalid backend: 'undefined'"): model = BoringModel() trainer = Trainer( @@ -102,3 +103,30 @@ def test_torch_distributed_backend_env_variables(tmpdir): logger=False, ) trainer.fit(model) + + +@RunIf(skip_windows=True) +@mock.patch('torch.cuda.device_count', return_value=1) +@mock.patch('torch.cuda.is_available', return_value=True) +@mock.patch('torch.cuda.set_device') +@mock.patch.dict(os.environ, {'PL_TORCH_DISTRIBUTED_BACKEND': 'gloo'}, clear=True) +def test_ddp_torch_dist_is_available_in_setup(mock_set_device, mock_is_available, mock_device_count, tmpdir): + """ + Test to ensure torch distributed is available within the setup hook using ddp + """ + + class TestModel(BoringModel): + + def setup(self, stage: Optional[str] = None) -> None: + assert torch.distributed.is_initialized() + raise SystemExit() + + model = TestModel() + trainer = Trainer( + default_root_dir=tmpdir, + fast_dev_run=True, + accelerator="ddp", + gpus=1, + ) + with pytest.raises(SystemExit): + trainer.fit(model) diff --git a/tests/base/model_template.py b/tests/base/model_template.py index 1ec2df7865caa..86578fef4c699 100644 --- a/tests/base/model_template.py +++ b/tests/base/model_template.py @@ -18,6 +18,7 @@ import torch.nn.functional as F from pytorch_lightning.core.lightning import LightningModule +from tests import PATH_DATASETS from tests.base.model_optimizers import ConfigureOptimizersPool from tests.base.model_test_dataloaders import TestDataloaderVariations from tests.base.model_test_epoch_ends import TestEpochEndVariations @@ -28,7 +29,7 @@ from tests.base.model_valid_dataloaders import ValDataloaderVariations from tests.base.model_valid_epoch_ends import ValidationEpochEndVariations from tests.base.model_valid_steps import ValidationStepVariations -from tests.helpers.datasets import PATH_DATASETS, TrialMNIST +from tests.helpers.datasets import TrialMNIST class EvalModelTemplate( diff --git a/tests/callbacks/test_callback_hook_outputs.py b/tests/callbacks/test_callback_hook_outputs.py index 78926cc9a7dd4..df0eab31aac37 100644 --- a/tests/callbacks/test_callback_hook_outputs.py +++ b/tests/callbacks/test_callback_hook_outputs.py @@ -71,3 +71,66 @@ def on_train_epoch_end(self, outputs) -> None: results = trainer.fit(model) assert results + + +def test_on_val_epoch_end_outputs(tmpdir): + + class CB(Callback): + + def on_validation_epoch_end(self, trainer, pl_module, outputs): + if trainer.running_sanity_check: + assert len(outputs[0]) == trainer.num_sanity_val_batches[0] + else: + assert len(outputs[0]) == trainer.num_val_batches[0] + + model = BoringModel() + + trainer = Trainer( + callbacks=CB(), + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=2, + max_epochs=1, + weights_summary=None, + ) + + trainer.fit(model) + + +def test_on_test_epoch_end_outputs(tmpdir): + + class CB(Callback): + + def on_test_epoch_end(self, trainer, pl_module, outputs): + assert len(outputs[0]) == trainer.num_test_batches[0] + + model = BoringModel() + + trainer = Trainer( + callbacks=CB(), + default_root_dir=tmpdir, + weights_summary=None, + ) + + trainer.test(model) + + +def 
test_free_memory_on_eval_outputs(tmpdir): + + class CB(Callback): + + def on_epoch_end(self, trainer, pl_module): + assert len(trainer.evaluation_loop.outputs) == 0 + + model = BoringModel() + + trainer = Trainer( + callbacks=CB(), + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=2, + max_epochs=1, + weights_summary=None, + ) + + trainer.fit(model) diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py index 626eb59dffb9c..713971629bdf4 100644 --- a/tests/callbacks/test_callbacks.py +++ b/tests/callbacks/test_callbacks.py @@ -46,17 +46,18 @@ def test_trainer_callback_hook_system_fit(_, tmpdir): assert callback_mock.method_calls == [ call.on_init_start(trainer), call.on_init_end(trainer), - call.setup(trainer, model, 'fit'), call.on_before_accelerator_backend_setup(trainer, model), + call.setup(trainer, model, 'fit'), call.on_fit_start(trainer, model), call.on_pretrain_routine_start(trainer, model), call.on_pretrain_routine_end(trainer, model), call.on_sanity_check_start(trainer, model), call.on_validation_start(trainer, model), + call.on_epoch_start(trainer, model), call.on_validation_epoch_start(trainer, model), call.on_validation_batch_start(trainer, model, ANY, 0, 0), call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0), - call.on_validation_epoch_end(trainer, model), + call.on_validation_epoch_end(trainer, model, ANY), call.on_epoch_end(trainer, model), call.on_validation_end(trainer, model), call.on_sanity_check_end(trainer, model), @@ -84,10 +85,11 @@ def test_trainer_callback_hook_system_fit(_, tmpdir): call.on_train_epoch_end(trainer, model, ANY), call.on_epoch_end(trainer, model), call.on_validation_start(trainer, model), + call.on_epoch_start(trainer, model), call.on_validation_epoch_start(trainer, model), call.on_validation_batch_start(trainer, model, ANY, 0, 0), call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0), - call.on_validation_epoch_end(trainer, model), + call.on_validation_epoch_end(trainer, model, ANY), call.on_epoch_end(trainer, model), call.on_validation_end(trainer, model), call.on_save_checkpoint(trainer, model), # should take ANY but we are inspecting signature for BC @@ -115,15 +117,16 @@ def test_trainer_callback_hook_system_test(tmpdir): assert callback_mock.method_calls == [ call.on_init_start(trainer), call.on_init_end(trainer), - call.setup(trainer, model, 'test'), call.on_before_accelerator_backend_setup(trainer, model), + call.setup(trainer, model, 'test'), call.on_test_start(trainer, model), + call.on_epoch_start(trainer, model), call.on_test_epoch_start(trainer, model), call.on_test_batch_start(trainer, model, ANY, 0, 0), call.on_test_batch_end(trainer, model, ANY, ANY, 0, 0), call.on_test_batch_start(trainer, model, ANY, 1, 0), call.on_test_batch_end(trainer, model, ANY, ANY, 1, 0), - call.on_test_epoch_end(trainer, model), + call.on_test_epoch_end(trainer, model, ANY), call.on_epoch_end(trainer, model), call.on_test_end(trainer, model), call.teardown(trainer, model, 'test'), @@ -148,15 +151,16 @@ def test_trainer_callback_hook_system_validate(tmpdir): assert callback_mock.method_calls == [ call.on_init_start(trainer), call.on_init_end(trainer), - call.setup(trainer, model, 'validate'), call.on_before_accelerator_backend_setup(trainer, model), + call.setup(trainer, model, 'validate'), call.on_validation_start(trainer, model), + call.on_epoch_start(trainer, model), call.on_validation_epoch_start(trainer, model), call.on_validation_batch_start(trainer, model, ANY, 0, 0), 
call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0), call.on_validation_batch_start(trainer, model, ANY, 1, 0), call.on_validation_batch_end(trainer, model, ANY, ANY, 1, 0), - call.on_validation_epoch_end(trainer, model), + call.on_validation_epoch_end(trainer, model, ANY), call.on_epoch_end(trainer, model), call.on_validation_end(trainer, model), call.teardown(trainer, model, 'validate'), diff --git a/tests/checkpointing/test_checkpoint_callback_frequency.py b/tests/checkpointing/test_checkpoint_callback_frequency.py index 397e471e8a4b8..7926bc46dd290 100644 --- a/tests/checkpointing/test_checkpoint_callback_frequency.py +++ b/tests/checkpointing/test_checkpoint_callback_frequency.py @@ -19,6 +19,7 @@ from pytorch_lightning import callbacks, seed_everything, Trainer from tests.helpers import BoringModel +from tests.helpers.runif import RunIf @mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) @@ -102,3 +103,42 @@ def training_step(self, batch, batch_idx): # make sure types are correct assert save_mock.call_count == expected + + +@mock.patch('torch.save') +@RunIf(special=True, min_gpus=2) +@pytest.mark.parametrize(['k', 'epochs', 'val_check_interval', 'expected'], [(1, 1, 1.0, 1), (2, 2, 0.3, 5)]) +def test_top_k_ddp(save_mock, tmpdir, k, epochs, val_check_interval, expected): + + class TestModel(BoringModel): + + def training_step(self, batch, batch_idx): + local_rank = int(os.getenv("LOCAL_RANK")) + self.log('my_loss', batch_idx * (1 + local_rank), on_epoch=True) + return super().training_step(batch, batch_idx) + + def training_epoch_end(self, outputs) -> None: + data = str(self.global_rank) + obj = [[data], (data, ), set(data)] + out = self.trainer.training_type_plugin.broadcast(obj) + assert obj == [[str(self.global_rank)], (str(self.global_rank), ), set(str(self.global_rank))] + assert out == [['0'], ('0', ), set('0')] + + model = TestModel() + trainer = Trainer( + callbacks=[callbacks.ModelCheckpoint(dirpath=tmpdir, monitor='my_loss_step', save_top_k=k, mode="max")], + default_root_dir=tmpdir, + max_epochs=epochs, + weights_summary=None, + val_check_interval=val_check_interval, + accelerator="ddp", + gpus=2, + limit_train_batches=64, + limit_val_batches=32, + ) + if os.getenv("LOCAL_RANK") == "0": + with pytest.raises(UserWarning, match="The value associated to the key my_loss_epoch: [15.5, 31.0]"): + trainer.fit(model) + assert save_mock.call_count == expected + else: + trainer.fit(model) diff --git a/tests/checkpointing/test_legacy_checkpoints.py b/tests/checkpointing/test_legacy_checkpoints.py index a22e8b77e67a3..325cc4925f4f4 100644 --- a/tests/checkpointing/test_legacy_checkpoints.py +++ b/tests/checkpointing/test_legacy_checkpoints.py @@ -18,9 +18,9 @@ import pytest from pytorch_lightning import Trainer -from tests import LEGACY_PATH +from tests import PATH_LEGACY -LEGACY_CHECKPOINTS_PATH = os.path.join(LEGACY_PATH, 'checkpoints') +LEGACY_CHECKPOINTS_PATH = os.path.join(PATH_LEGACY, 'checkpoints') CHECKPOINT_EXTENSION = ".ckpt" @@ -56,6 +56,7 @@ "1.2.1", "1.2.2", "1.2.3", + "1.2.4", ] ) def test_resume_legacy_checkpoints(tmpdir, pl_version: str): diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index e5583b9bbdf86..75f25b90fa45f 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -356,7 +356,7 @@ def on_train_start(self, trainer, pl_module): torch.save = Mock(wraps=torch.save) def on_save_checkpoint(self, trainer, pl_module, checkpoint): - # expect all 
ranks to run but only rank 0 will actually write the checkpoint file + # only rank 0 will call ``torch.save`` super().on_save_checkpoint(trainer, pl_module, checkpoint) self.on_save_checkpoint_count += 1 @@ -366,8 +366,7 @@ def on_train_end(self, trainer, pl_module): assert self.best_model_score assert self.on_save_checkpoint_count == self.expected_count if trainer.is_global_zero: - # twice the calls expected because ddp broadcast also uses torch.save - assert torch.save.call_count == self.expected_count * 2 + assert torch.save.call_count == self.expected_count else: assert torch.save.call_count == 0 diff --git a/tests/checkpointing/test_torch_saving.py b/tests/checkpointing/test_torch_saving.py index c8b1e96aeaf0a..8eabc4640046f 100644 --- a/tests/checkpointing/test_torch_saving.py +++ b/tests/checkpointing/test_torch_saving.py @@ -47,6 +47,7 @@ def test_model_torch_save_ddp_cpu(tmpdir): max_epochs=num_epochs, accelerator="ddp_cpu", num_processes=2, + logger=False, ) temp_path = os.path.join(tmpdir, 'temp.pt') trainer.fit(model) diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index 2118fec6c207b..c8808ec37326c 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -128,6 +128,10 @@ def test_data_hooks_called(tmpdir): assert not dm.has_setup_test assert not dm.has_setup_validate assert not dm.has_setup_predict + assert not dm.has_teardown_fit + assert not dm.has_teardown_test + assert not dm.has_teardown_validate + assert not dm.has_teardown_predict dm.prepare_data() assert dm.has_prepared_data @@ -135,6 +139,10 @@ def test_data_hooks_called(tmpdir): assert not dm.has_setup_test assert not dm.has_setup_validate assert not dm.has_setup_predict + assert not dm.has_teardown_fit + assert not dm.has_teardown_test + assert not dm.has_teardown_validate + assert not dm.has_teardown_predict dm.setup() assert dm.has_prepared_data @@ -142,49 +150,84 @@ def test_data_hooks_called(tmpdir): assert dm.has_setup_test assert dm.has_setup_validate assert not dm.has_setup_predict + assert not dm.has_teardown_fit + assert not dm.has_teardown_test + assert not dm.has_teardown_validate + assert not dm.has_teardown_predict + + dm.teardown() + assert dm.has_prepared_data + assert dm.has_setup_fit + assert dm.has_setup_test + assert dm.has_setup_validate + assert not dm.has_setup_predict + assert dm.has_teardown_fit + assert dm.has_teardown_test + assert dm.has_teardown_validate + assert not dm.has_teardown_predict @pytest.mark.parametrize("use_kwarg", (False, True)) def test_data_hooks_called_verbose(tmpdir, use_kwarg): dm = BoringDataModule() - assert not dm.has_prepared_data - assert not dm.has_setup_fit - assert not dm.has_setup_test - dm.prepare_data() - assert dm.has_prepared_data assert not dm.has_setup_fit assert not dm.has_setup_test + assert not dm.has_setup_validate assert not dm.has_setup_predict + assert not dm.has_teardown_fit + assert not dm.has_teardown_test + assert not dm.has_teardown_validate + assert not dm.has_teardown_predict dm.setup(stage='fit') if use_kwarg else dm.setup('fit') - assert dm.has_prepared_data assert dm.has_setup_fit assert not dm.has_setup_validate assert not dm.has_setup_test assert not dm.has_setup_predict dm.setup(stage='validate') if use_kwarg else dm.setup('validate') - assert dm.has_prepared_data assert dm.has_setup_fit assert dm.has_setup_validate assert not dm.has_setup_test assert not dm.has_setup_predict dm.setup(stage='test') if use_kwarg else dm.setup('test') - assert dm.has_prepared_data assert 
dm.has_setup_fit assert dm.has_setup_validate assert dm.has_setup_test assert not dm.has_setup_predict dm.setup(stage='predict') if use_kwarg else dm.setup('predict') - assert dm.has_prepared_data assert dm.has_setup_fit assert dm.has_setup_validate assert dm.has_setup_test assert dm.has_setup_predict + dm.teardown(stage='fit') if use_kwarg else dm.teardown('fit') + assert dm.has_teardown_fit + assert not dm.has_teardown_validate + assert not dm.has_teardown_test + assert not dm.has_teardown_predict + + dm.teardown(stage='validate') if use_kwarg else dm.teardown('validate') + assert dm.has_teardown_fit + assert dm.has_teardown_validate + assert not dm.has_teardown_test + assert not dm.has_teardown_predict + + dm.teardown(stage='test') if use_kwarg else dm.teardown('test') + assert dm.has_teardown_fit + assert dm.has_teardown_validate + assert dm.has_teardown_test + assert not dm.has_teardown_predict + + dm.teardown(stage='predict') if use_kwarg else dm.teardown('predict') + assert dm.has_teardown_fit + assert dm.has_teardown_validate + assert dm.has_teardown_test + assert dm.has_teardown_predict + def test_dm_add_argparse_args(tmpdir): parser = ArgumentParser() diff --git a/tests/core/test_hooks.py b/tests/core/test_hooks.py new file mode 100644 index 0000000000000..191da0a1400c7 --- /dev/null +++ b/tests/core/test_hooks.py @@ -0,0 +1,56 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from pytorch_lightning import Trainer +from tests.helpers.boring_model import BoringModel + + +def test_on_val_epoch_end_outputs(tmpdir): + + class TestModel(BoringModel): + + def on_validation_epoch_end(self, outputs): + if trainer.running_sanity_check: + assert len(outputs[0]) == trainer.num_sanity_val_batches[0] + else: + assert len(outputs[0]) == trainer.num_val_batches[0] + + model = TestModel() + + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=2, + max_epochs=1, + weights_summary=None, + ) + + trainer.fit(model) + + +def test_on_test_epoch_end_outputs(tmpdir): + + class TestModel(BoringModel): + + def on_test_epoch_end(self, outputs): + assert len(outputs[0]) == trainer.num_test_batches[0] + + model = TestModel() + + trainer = Trainer( + default_root_dir=tmpdir, + fast_dev_run=2, + weights_summary=None, + ) + + trainer.test(model) diff --git a/tests/core/test_memory.py b/tests/core/test_memory.py index 903154adf823d..3088743f71488 100644 --- a/tests/core/test_memory.py +++ b/tests/core/test_memory.py @@ -88,6 +88,19 @@ def forward(self, x): return self.reduce(self.embed(x)) +class PartialScriptModel(LightningModule): + """ A model which contains scripted layers. """ + + def __init__(self): + super().__init__() + self.layer1 = torch.jit.script(nn.Linear(5, 3)) + self.layer2 = nn.Linear(3, 2) + self.example_input_array = torch.rand(2, 5) + + def forward(self, x): + return self.layer2(self.layer1(x)) + + def test_invalid_weights_summmary(): """ Test that invalid value for weights_summary raises an error. 
""" with pytest.raises(MisconfigurationException, match='`mode` can be None, .* got temp'): @@ -214,6 +227,15 @@ def test_summary_layer_types(mode): ] +@pytest.mark.parametrize('mode', [ModelSummary.MODE_FULL, ModelSummary.MODE_TOP]) +def test_summary_with_scripted_modules(mode): + model = PartialScriptModel() + summary = model.summarize(mode=mode) + assert summary.layer_types == ["RecursiveScriptModule", "Linear"] + assert summary.in_sizes == [UNKNOWN_SIZE, [2, 3]] + assert summary.out_sizes == [UNKNOWN_SIZE, [2, 2]] + + @pytest.mark.parametrize('mode', [ModelSummary.MODE_FULL, ModelSummary.MODE_TOP]) @pytest.mark.parametrize(['example_input', 'expected_size'], [ pytest.param([], UNKNOWN_SIZE), @@ -265,7 +287,7 @@ def test_empty_model_size(mode): @RunIf(min_gpus=1, amp_native=True) -def test_model_size_precision(monkeypatch, tmpdir): +def test_model_size_precision(tmpdir): """ Test model size for half and full precision. """ model = PreCalculatedModel() diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index 9d31688d9bcc0..0b797dff0e42f 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -15,10 +15,10 @@ import torch import torch.distributed as dist import torch.multiprocessing as mp +from torchmetrics import Metric import tests.helpers.utils as tutils from pytorch_lightning.core.step_result import Result -from pytorch_lightning.metrics import Metric from tests.helpers.runif import RunIf diff --git a/tests/deprecated_api/__init__.py b/tests/deprecated_api/__init__.py index 99e21d1ed6b22..ccfae3ec8dcf2 100644 --- a/tests/deprecated_api/__init__.py +++ b/tests/deprecated_api/__init__.py @@ -13,9 +13,27 @@ # limitations under the License. 
"""Test deprecated functionality which will be removed in vX.Y.Z""" import sys +from contextlib import contextmanager +from typing import Optional + +import pytest def _soft_unimport_module(str_module): # once the module is imported e.g with parsing with pytest it lives in memory if str_module in sys.modules: del sys.modules[str_module] + + +@contextmanager +def no_deprecated_call(match: Optional[str] = None): + with pytest.warns(None) as record: + yield + try: + w = record.pop(DeprecationWarning) + if match is not None and match not in str(w.message): + return + except AssertionError: + # no DeprecationWarning raised + return + raise AssertionError(f"`DeprecationWarning` was raised: {w}") diff --git a/tests/deprecated_api/test_remove_1-4.py b/tests/deprecated_api/test_remove_1-4.py index 39f5e0dca5075..99e1b31f6edad 100644 --- a/tests/deprecated_api/test_remove_1-4.py +++ b/tests/deprecated_api/test_remove_1-4.py @@ -130,16 +130,6 @@ def test_v1_4_0_deprecated_metrics(): with pytest.deprecated_call(match='will be removed in v1.4'): precision_recall(torch.randint(0, 2, (10, 3, 3)), torch.randint(0, 2, (10, 3, 3))) - # Testing deprecation of class_reduction arg in the *new* precision - from pytorch_lightning.metrics.functional import precision - with pytest.deprecated_call(match='will be removed in v1.4'): - precision(torch.randint(0, 2, (10, )), torch.randint(0, 2, (10, )), class_reduction='micro') - - # Testing deprecation of class_reduction arg in the *new* recall - from pytorch_lightning.metrics.functional import recall - with pytest.deprecated_call(match='will be removed in v1.4'): - recall(torch.randint(0, 2, (10, )), torch.randint(0, 2, (10, )), class_reduction='micro') - from pytorch_lightning.metrics.functional.classification import auc with pytest.deprecated_call(match='will be removed in v1.4'): auc(torch.rand(10, ).sort().values, torch.rand(10, )) @@ -152,14 +142,6 @@ def test_v1_4_0_deprecated_metrics(): with pytest.deprecated_call(match='will be removed in v1.4'): multiclass_auroc(torch.rand(20, 5).softmax(dim=-1), torch.randint(0, 5, (20, )), num_classes=5) - from pytorch_lightning.metrics.functional.classification import auc_decorator - with pytest.deprecated_call(match='will be removed in v1.4'): - auc_decorator() - - from pytorch_lightning.metrics.functional.classification import multiclass_auc_decorator - with pytest.deprecated_call(match='will be removed in v1.4'): - multiclass_auc_decorator() - class CustomDDPPlugin(DDPSpawnPlugin): diff --git a/tests/deprecated_api/test_remove_1-5.py b/tests/deprecated_api/test_remove_1-5.py index e65ebbab254de..fc3fe3112e71e 100644 --- a/tests/deprecated_api/test_remove_1-5.py +++ b/tests/deprecated_api/test_remove_1-5.py @@ -20,6 +20,9 @@ from pytorch_lightning import Callback, Trainer from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.loggers import WandbLogger +from pytorch_lightning.profiler import AdvancedProfiler, BaseProfiler, PyTorchProfiler, SimpleProfiler +from pytorch_lightning.trainer.callback_hook import warning_cache as callback_warning_cache +from tests.deprecated_api import no_deprecated_call from tests.helpers import BoringModel from tests.helpers.utils import no_warning_call @@ -78,6 +81,11 @@ def on_save_checkpoint(self, *args): trainer.save_checkpoint(filepath) +def test_v1_5_0_legacy_profiler_argument(): + with pytest.deprecated_call(match="renamed to `record_functions` in v1.3"): + PyTorchProfiler(profiled_functions=[]) + + def test_v1_5_0_running_sanity_check(): trainer = Trainer() 
with pytest.deprecated_call(match='has been renamed to `Trainer.sanity_checking`'): @@ -111,3 +119,102 @@ def test_v1_5_0_model_checkpoint_period(tmpdir): ModelCheckpoint(dirpath=tmpdir) with pytest.deprecated_call(match="is deprecated in v1.3 and will be removed in v1.5"): ModelCheckpoint(dirpath=tmpdir, period=1) + + +def test_v1_5_0_old_on_validation_epoch_end(tmpdir): + callback_warning_cache.clear() + + class OldSignature(Callback): + + def on_validation_epoch_end(self, trainer, pl_module): # noqa + ... + + model = BoringModel() + trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, callbacks=OldSignature()) + + with pytest.deprecated_call(match="old signature will be removed in v1.5"): + trainer.fit(model) + + class OldSignatureModel(BoringModel): + + def on_validation_epoch_end(self): # noqa + ... + + model = OldSignatureModel() + + with pytest.deprecated_call(match="old signature will be removed in v1.5"): + trainer.fit(model) + + callback_warning_cache.clear() + + class NewSignature(Callback): + + def on_validation_epoch_end(self, trainer, pl_module, outputs): + ... + + trainer.callbacks = [NewSignature()] + with no_deprecated_call(match="`Callback.on_validation_epoch_end` signature has changed in v1.3."): + trainer.fit(model) + + class NewSignatureModel(BoringModel): + + def on_validation_epoch_end(self, outputs): + ... + + model = NewSignatureModel() + with no_deprecated_call(match="`ModelHooks.on_validation_epoch_end` signature has changed in v1.3."): + trainer.fit(model) + + +def test_v1_5_0_old_on_test_epoch_end(tmpdir): + callback_warning_cache.clear() + + class OldSignature(Callback): + + def on_test_epoch_end(self, trainer, pl_module): # noqa + ... + + model = BoringModel() + trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, callbacks=OldSignature()) + + with pytest.deprecated_call(match="old signature will be removed in v1.5"): + trainer.test(model) + + class OldSignatureModel(BoringModel): + + def on_test_epoch_end(self): # noqa + ... + + model = OldSignatureModel() + + with pytest.deprecated_call(match="old signature will be removed in v1.5"): + trainer.test(model) + + callback_warning_cache.clear() + + class NewSignature(Callback): + + def on_test_epoch_end(self, trainer, pl_module, outputs): + ... + + trainer.callbacks = [NewSignature()] + with no_deprecated_call(match="`Callback.on_test_epoch_end` signature has changed in v1.3."): + trainer.test(model) + + class NewSignatureModel(BoringModel): + + def on_test_epoch_end(self, outputs): + ... 
+ + model = NewSignatureModel() + with no_deprecated_call(match="`ModelHooks.on_test_epoch_end` signature has changed in v1.3."): + trainer.test(model) + + +@pytest.mark.parametrize("cls", (BaseProfiler, SimpleProfiler, AdvancedProfiler, PyTorchProfiler)) +def test_v1_5_0_profiler_output_filename(tmpdir, cls): + filepath = str(tmpdir / "test.txt") + with pytest.deprecated_call(match="`output_filename` parameter has been removed"): + profiler = cls(output_filename=filepath) + assert profiler.dirpath == tmpdir + assert profiler.filename == "test" diff --git a/tests/helpers/advanced_models.py b/tests/helpers/advanced_models.py index 7ad678b3046fd..2b0146e1ee099 100644 --- a/tests/helpers/advanced_models.py +++ b/tests/helpers/advanced_models.py @@ -20,6 +20,7 @@ from torch.utils.data import DataLoader from pytorch_lightning.core.lightning import LightningModule +from tests import PATH_DATASETS from tests.helpers.datasets import AverageDataset, MNIST, TrialMNIST @@ -165,7 +166,7 @@ def configure_optimizers(self): return [opt_g, opt_d], [] def train_dataloader(self): - return DataLoader(TrialMNIST(train=True, download=True), batch_size=16) + return DataLoader(TrialMNIST(root=PATH_DATASETS, train=True, download=True), batch_size=16) class ParityModuleRNN(LightningModule): @@ -223,6 +224,7 @@ def configure_optimizers(self): def train_dataloader(self): return DataLoader(MNIST( + root=PATH_DATASETS, train=True, download=True, ), batch_size=128, num_workers=1) diff --git a/tests/helpers/datasets.py b/tests/helpers/datasets.py index e7bdad0f1538c..77035796ca3b1 100644 --- a/tests/helpers/datasets.py +++ b/tests/helpers/datasets.py @@ -22,11 +22,6 @@ from torch import Tensor from torch.utils.data import Dataset -from tests import _PROJECT_ROOT - -#: local path to test datasets -PATH_DATASETS = os.path.join(_PROJECT_ROOT, 'Datasets') - class MNIST(Dataset): """ @@ -47,7 +42,7 @@ class MNIST(Dataset): downloaded again. 
Examples: - >>> dataset = MNIST(download=True) + >>> dataset = MNIST(".", download=True) >>> len(dataset) 60000 >>> torch.bincount(dataset.targets) @@ -65,7 +60,7 @@ class MNIST(Dataset): def __init__( self, - root: str = PATH_DATASETS, + root: str, train: bool = True, normalize: tuple = (0.1307, 0.3081), download: bool = True, @@ -152,7 +147,7 @@ class TrialMNIST(MNIST): kwargs: Same as MNIST Examples: - >>> dataset = TrialMNIST(download=True) + >>> dataset = TrialMNIST(".", download=True) >>> len(dataset) 300 >>> sorted(set([d.item() for d in dataset.targets])) @@ -161,7 +156,7 @@ class TrialMNIST(MNIST): tensor([100, 100, 100]) """ - def __init__(self, num_samples: int = 100, digits: Optional[Sequence] = (0, 1, 2), **kwargs): + def __init__(self, root: str, num_samples: int = 100, digits: Optional[Sequence] = (0, 1, 2), **kwargs): # number of examples per class self.num_samples = num_samples # take just a subset of MNIST dataset @@ -169,7 +164,7 @@ def __init__(self, num_samples: int = 100, digits: Optional[Sequence] = (0, 1, 2 self.cache_folder_name = f"digits-{'-'.join(str(d) for d in self.digits)}_nb-{self.num_samples}" - super().__init__(normalize=(0.5, 1.0), **kwargs) + super().__init__(root, normalize=(0.5, 1.0), **kwargs) @staticmethod def _prepare_subset(full_data: torch.Tensor, full_targets: torch.Tensor, num_samples: int, digits: Sequence): diff --git a/tests/helpers/runif.py b/tests/helpers/runif.py index fe85fbaea9025..5483e33d9cddb 100644 --- a/tests/helpers/runif.py +++ b/tests/helpers/runif.py @@ -56,6 +56,7 @@ def __new__( *args, min_gpus: int = 0, min_torch: Optional[str] = None, + max_torch: Optional[str] = None, min_python: Optional[str] = None, quantization: bool = False, amp_apex: bool = False, @@ -76,6 +77,7 @@ def __new__( args: native pytest.mark.skipif arguments min_gpus: min number of gpus required to run test min_torch: minimum pytorch version to run test + max_torch: maximum pytorch version to run test min_python: minimum python version required to run test quantization: if `torch.quantization` package is required to run test amp_apex: NVIDIA Apex is installed @@ -102,6 +104,11 @@ def __new__( conditions.append(torch_version < LooseVersion(min_torch)) reasons.append(f"torch>={min_torch}") + if max_torch: + torch_version = LooseVersion(get_distribution("torch").version) + conditions.append(torch_version >= LooseVersion(max_torch)) + reasons.append(f"torch<{max_torch}") + if min_python: py_version = f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}" conditions.append(py_version < LooseVersion(min_python)) diff --git a/tests/helpers/test_datasets.py b/tests/helpers/test_datasets.py index 6319fdb562504..8c866bdbab789 100644 --- a/tests/helpers/test_datasets.py +++ b/tests/helpers/test_datasets.py @@ -16,12 +16,19 @@ import cloudpickle import pytest +from tests import PATH_DATASETS from tests.helpers.datasets import AverageDataset, MNIST, TrialMNIST -@pytest.mark.parametrize('dataset_cls', [MNIST, TrialMNIST, AverageDataset]) -def test_pickling_dataset_mnist(tmpdir, dataset_cls): - mnist = dataset_cls() +@pytest.mark.parametrize( + 'dataset_cls,args', [ + (MNIST, dict(root=PATH_DATASETS)), + (TrialMNIST, dict(root=PATH_DATASETS)), + (AverageDataset, dict()), + ] +) +def test_pickling_dataset_mnist(tmpdir, dataset_cls, args): + mnist = dataset_cls(**args) mnist_pickled = pickle.dumps(mnist) pickle.loads(mnist_pickled) diff --git a/tests/loggers/test_mlflow.py b/tests/loggers/test_mlflow.py index 40b7c741b702e..a49e7bd23f065 100644 --- 
a/tests/loggers/test_mlflow.py +++ b/tests/loggers/test_mlflow.py @@ -202,6 +202,34 @@ def test_mlflow_logger_with_long_param_value(client, mlflow, tmpdir): logger.log_hyperparams(params) +@mock.patch('pytorch_lightning.loggers.mlflow.time') +@mock.patch('pytorch_lightning.loggers.mlflow.mlflow') +@mock.patch('pytorch_lightning.loggers.mlflow.MlflowClient') +def test_mlflow_logger_experiment_calls(client, mlflow, time, tmpdir): + """ + Test that the logger calls methods on the mlflow experiment correctly. + """ + time.return_value = 1 + + logger = MLFlowLogger('test', save_dir=tmpdir, artifact_location='my_artifact_location') + logger._mlflow_client.get_experiment_by_name.return_value = None + + params = {'test': 'test_param'} + logger.log_hyperparams(params) + + logger.experiment.log_param.assert_called_once_with(logger.run_id, 'test', 'test_param') + + metrics = {'some_metric': 10} + logger.log_metrics(metrics) + + logger.experiment.log_metric.assert_called_once_with(logger.run_id, 'some_metric', 10, 1000, None) + + logger._mlflow_client.create_experiment.assert_called_once_with( + name='test', + artifact_location='my_artifact_location', + ) + + @mock.patch('pytorch_lightning.loggers.mlflow.mlflow') @mock.patch('pytorch_lightning.loggers.mlflow.MlflowClient') @pytest.mark.parametrize("step_idx", [10, None]) diff --git a/tests/metrics/classification/__init__.py b/tests/metrics/classification/__init__.py deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/tests/metrics/classification/inputs.py b/tests/metrics/classification/inputs.py deleted file mode 100644 index 7f2ac450385fe..0000000000000 --- a/tests/metrics/classification/inputs.py +++ /dev/null @@ -1,66 +0,0 @@ -from collections import namedtuple - -import torch - -from tests.metrics.utils import BATCH_SIZE, EXTRA_DIM, NUM_BATCHES, NUM_CLASSES - -Input = namedtuple('Input', ["preds", "target"]) - -_input_binary_prob = Input( - preds=torch.rand(NUM_BATCHES, BATCH_SIZE), target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)) -) - -_input_binary = Input( - preds=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)), - target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)) -) - -_input_multilabel_prob = Input( - preds=torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES), - target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)) -) - -_input_multilabel_multidim_prob = Input( - preds=torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM), - target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM)) -) - -_input_multilabel = Input( - preds=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)), - target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)) -) - -_input_multilabel_multidim = Input( - preds=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM)), - target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM)) -) - -# Generate edge multilabel edge case, where nothing matches (scores are undefined) -__temp_preds = torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)) -__temp_target = abs(__temp_preds - 1) - -_input_multilabel_no_match = Input(preds=__temp_preds, target=__temp_target) - -__mc_prob_preds = torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES) -__mc_prob_preds = __mc_prob_preds / __mc_prob_preds.sum(dim=2, keepdim=True) - -_input_multiclass_prob = Input( - preds=__mc_prob_preds, target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)) -) - 
-_input_multiclass = Input( - preds=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), - target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)) -) - -__mdmc_prob_preds = torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM) -__mdmc_prob_preds = __mdmc_prob_preds / __mdmc_prob_preds.sum(dim=2, keepdim=True) - -_input_multidim_multiclass_prob = Input( - preds=__mdmc_prob_preds, target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)) -) - -_input_multidim_multiclass = Input( - preds=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), - target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)) -) diff --git a/tests/metrics/classification/test_accuracy.py b/tests/metrics/classification/test_accuracy.py deleted file mode 100644 index bed60aa88388f..0000000000000 --- a/tests/metrics/classification/test_accuracy.py +++ /dev/null @@ -1,175 +0,0 @@ -from functools import partial - -import numpy as np -import pytest -import torch -from sklearn.metrics import accuracy_score as sk_accuracy - -from pytorch_lightning.metrics import Accuracy -from pytorch_lightning.metrics.classification.helpers import _input_format_classification, DataType -from pytorch_lightning.metrics.functional import accuracy -from tests.metrics.classification.inputs import _input_binary, _input_binary_prob -from tests.metrics.classification.inputs import _input_multiclass as _input_mcls -from tests.metrics.classification.inputs import _input_multiclass_prob as _input_mcls_prob -from tests.metrics.classification.inputs import _input_multidim_multiclass as _input_mdmc -from tests.metrics.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob -from tests.metrics.classification.inputs import _input_multilabel as _input_mlb -from tests.metrics.classification.inputs import _input_multilabel_multidim as _input_mlmd -from tests.metrics.classification.inputs import _input_multilabel_multidim_prob as _input_mlmd_prob -from tests.metrics.classification.inputs import _input_multilabel_prob as _input_mlb_prob -from tests.metrics.utils import MetricTester, THRESHOLD - -torch.manual_seed(42) - - -def _sk_accuracy(preds, target, subset_accuracy): - sk_preds, sk_target, mode = _input_format_classification(preds, target, threshold=THRESHOLD) - sk_preds, sk_target = sk_preds.numpy(), sk_target.numpy() - - if mode == DataType.MULTIDIM_MULTICLASS and not subset_accuracy: - sk_preds, sk_target = np.transpose(sk_preds, (0, 2, 1)), np.transpose(sk_target, (0, 2, 1)) - sk_preds, sk_target = sk_preds.reshape(-1, sk_preds.shape[2]), sk_target.reshape(-1, sk_target.shape[2]) - elif mode == DataType.MULTIDIM_MULTICLASS and subset_accuracy: - return np.all(sk_preds == sk_target, axis=(1, 2)).mean() - elif mode == DataType.MULTILABEL and not subset_accuracy: - sk_preds, sk_target = sk_preds.reshape(-1), sk_target.reshape(-1) - - return sk_accuracy(y_true=sk_target, y_pred=sk_preds) - - -@pytest.mark.parametrize( - "preds, target, subset_accuracy", - [ - (_input_binary_prob.preds, _input_binary_prob.target, False), - (_input_binary.preds, _input_binary.target, False), - (_input_mlb_prob.preds, _input_mlb_prob.target, True), - (_input_mlb_prob.preds, _input_mlb_prob.target, False), - (_input_mlb.preds, _input_mlb.target, True), - (_input_mlb.preds, _input_mlb.target, False), - (_input_mcls_prob.preds, _input_mcls_prob.target, False), - (_input_mcls.preds, _input_mcls.target, False), - (_input_mdmc_prob.preds, 
_input_mdmc_prob.target, False), - (_input_mdmc_prob.preds, _input_mdmc_prob.target, True), - (_input_mdmc.preds, _input_mdmc.target, False), - (_input_mdmc.preds, _input_mdmc.target, True), - (_input_mlmd_prob.preds, _input_mlmd_prob.target, True), - (_input_mlmd_prob.preds, _input_mlmd_prob.target, False), - (_input_mlmd.preds, _input_mlmd.target, True), - (_input_mlmd.preds, _input_mlmd.target, False), - ], -) -class TestAccuracies(MetricTester): - - @pytest.mark.parametrize("ddp", [False, True]) - @pytest.mark.parametrize("dist_sync_on_step", [False, True]) - def test_accuracy_class(self, ddp, dist_sync_on_step, preds, target, subset_accuracy): - self.run_class_metric_test( - ddp=ddp, - preds=preds, - target=target, - metric_class=Accuracy, - sk_metric=partial(_sk_accuracy, subset_accuracy=subset_accuracy), - dist_sync_on_step=dist_sync_on_step, - metric_args={ - "threshold": THRESHOLD, - "subset_accuracy": subset_accuracy - }, - ) - - def test_accuracy_fn(self, preds, target, subset_accuracy): - self.run_functional_metric_test( - preds, - target, - metric_functional=accuracy, - sk_metric=partial(_sk_accuracy, subset_accuracy=subset_accuracy), - metric_args={ - "threshold": THRESHOLD, - "subset_accuracy": subset_accuracy - }, - ) - - -_l1to4 = [0.1, 0.2, 0.3, 0.4] -_l1to4t3 = np.array([_l1to4, _l1to4, _l1to4]) -_l1to4t3_mcls = [_l1to4t3.T, _l1to4t3.T, _l1to4t3.T] - -# The preds in these examples always put highest probability on class 3, second highest on class 2, -# third highest on class 1, and lowest on class 0 -_topk_preds_mcls = torch.tensor([_l1to4t3, _l1to4t3]).float() -_topk_target_mcls = torch.tensor([[1, 2, 3], [2, 1, 0]]) - -# This is like for MC case, but one sample in each batch is sabotaged with 0 class prediction :) -_topk_preds_mdmc = torch.tensor([_l1to4t3_mcls, _l1to4t3_mcls]).float() -_topk_target_mdmc = torch.tensor([[[1, 1, 0], [2, 2, 2], [3, 3, 3]], [[2, 2, 0], [1, 1, 1], [0, 0, 0]]]) - - -# Replace with a proper sk_metric test once sklearn 0.24 hits :) -@pytest.mark.parametrize( - "preds, target, exp_result, k, subset_accuracy", - [ - (_topk_preds_mcls, _topk_target_mcls, 1 / 6, 1, False), - (_topk_preds_mcls, _topk_target_mcls, 3 / 6, 2, False), - (_topk_preds_mcls, _topk_target_mcls, 5 / 6, 3, False), - (_topk_preds_mcls, _topk_target_mcls, 1 / 6, 1, True), - (_topk_preds_mcls, _topk_target_mcls, 3 / 6, 2, True), - (_topk_preds_mcls, _topk_target_mcls, 5 / 6, 3, True), - (_topk_preds_mdmc, _topk_target_mdmc, 1 / 6, 1, False), - (_topk_preds_mdmc, _topk_target_mdmc, 8 / 18, 2, False), - (_topk_preds_mdmc, _topk_target_mdmc, 13 / 18, 3, False), - (_topk_preds_mdmc, _topk_target_mdmc, 1 / 6, 1, True), - (_topk_preds_mdmc, _topk_target_mdmc, 2 / 6, 2, True), - (_topk_preds_mdmc, _topk_target_mdmc, 3 / 6, 3, True), - ], -) -def test_topk_accuracy(preds, target, exp_result, k, subset_accuracy): - topk = Accuracy(top_k=k, subset_accuracy=subset_accuracy) - - for batch in range(preds.shape[0]): - topk(preds[batch], target[batch]) - - assert topk.compute() == exp_result - - # Test functional - total_samples = target.shape[0] * target.shape[1] - - preds = preds.view(total_samples, 4, -1) - target = target.view(total_samples, -1) - - assert accuracy(preds, target, top_k=k, subset_accuracy=subset_accuracy) == exp_result - - -# Only MC and MDMC with probs input type should be accepted for top_k -@pytest.mark.parametrize( - "preds, target", - [ - (_input_binary_prob.preds, _input_binary_prob.target), - (_input_binary.preds, _input_binary.target), - (_input_mlb_prob.preds, 
_input_mlb_prob.target), - (_input_mlb.preds, _input_mlb.target), - (_input_mcls.preds, _input_mcls.target), - (_input_mdmc.preds, _input_mdmc.target), - (_input_mlmd_prob.preds, _input_mlmd_prob.target), - (_input_mlmd.preds, _input_mlmd.target), - ], -) -def test_topk_accuracy_wrong_input_types(preds, target): - topk = Accuracy(top_k=1) - - with pytest.raises(ValueError): - topk(preds[0], target[0]) - - with pytest.raises(ValueError): - accuracy(preds[0], target[0], top_k=1) - - -@pytest.mark.parametrize("top_k, threshold", [(0, 0.5), (None, 1.5)]) -def test_wrong_params(top_k, threshold): - preds, target = _input_mcls_prob.preds, _input_mcls_prob.target - - with pytest.raises(ValueError): - acc = Accuracy(threshold=threshold, top_k=top_k) - acc(preds, target) - acc.compute() - - with pytest.raises(ValueError): - accuracy(preds, target, threshold=threshold, top_k=top_k) diff --git a/tests/metrics/classification/test_auc.py b/tests/metrics/classification/test_auc.py deleted file mode 100644 index e902151ecffce..0000000000000 --- a/tests/metrics/classification/test_auc.py +++ /dev/null @@ -1,64 +0,0 @@ -from collections import namedtuple - -import numpy as np -import pytest -import torch -from sklearn.metrics import auc as _sk_auc - -from pytorch_lightning.metrics.classification.auc import AUC -from pytorch_lightning.metrics.functional.auc import auc -from tests.metrics.utils import MetricTester, NUM_BATCHES - -torch.manual_seed(42) - - -def sk_auc(x, y): - x = x.flatten() - y = y.flatten() - return _sk_auc(x, y) - - -Input = namedtuple('Input', ["x", "y"]) - -_examples = [] -# generate already ordered samples, sorted in both directions -for i in range(4): - x = np.random.randint(0, 5, (NUM_BATCHES * 8)) - y = np.random.randint(0, 5, (NUM_BATCHES * 8)) - idx = np.argsort(x, kind='stable') - x = x[idx] if i % 2 == 0 else x[idx[::-1]] - y = y[idx] if i % 2 == 0 else x[idx[::-1]] - x = x.reshape(NUM_BATCHES, 8) - y = y.reshape(NUM_BATCHES, 8) - _examples.append(Input(x=torch.tensor(x), y=torch.tensor(y))) - - -@pytest.mark.parametrize("x, y", _examples) -class TestAUC(MetricTester): - - @pytest.mark.parametrize("ddp", [False]) - @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_auc(self, x, y, ddp, dist_sync_on_step): - self.run_class_metric_test( - ddp=ddp, - preds=x, - target=y, - metric_class=AUC, - sk_metric=sk_auc, - dist_sync_on_step=dist_sync_on_step, - ) - - def test_auc_functional(self, x, y): - self.run_functional_metric_test(x, y, metric_functional=auc, sk_metric=sk_auc, metric_args={"reorder": False}) - - -@pytest.mark.parametrize(['x', 'y', 'expected'], [ - pytest.param([0, 1], [0, 1], 0.5), - pytest.param([1, 0], [0, 1], 0.5), - pytest.param([1, 0, 0], [0, 1, 1], 0.5), - pytest.param([0, 1], [1, 1], 1), - pytest.param([0, 0.5, 1], [0, 0.5, 1], 0.5), -]) -def test_auc(x, y, expected): - # Test Area Under Curve (AUC) computation - assert auc(torch.tensor(x), torch.tensor(y), reorder=True) == expected diff --git a/tests/metrics/classification/test_auroc.py b/tests/metrics/classification/test_auroc.py deleted file mode 100644 index 0affcb1010225..0000000000000 --- a/tests/metrics/classification/test_auroc.py +++ /dev/null @@ -1,142 +0,0 @@ -from distutils.version import LooseVersion -from functools import partial - -import pytest -import torch -from sklearn.metrics import roc_auc_score as sk_roc_auc_score - -from pytorch_lightning.metrics.classification.auroc import AUROC -from pytorch_lightning.metrics.functional.auroc import auroc -from 
tests.metrics.classification.inputs import _input_binary_prob -from tests.metrics.classification.inputs import _input_multiclass_prob as _input_mcls_prob -from tests.metrics.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob -from tests.metrics.classification.inputs import _input_multilabel_multidim_prob as _input_mlmd_prob -from tests.metrics.classification.inputs import _input_multilabel_prob as _input_mlb_prob -from tests.metrics.utils import MetricTester, NUM_CLASSES - -torch.manual_seed(42) - - -def _sk_auroc_binary_prob(preds, target, num_classes, average='macro', max_fpr=None, multi_class='ovr'): - sk_preds = preds.view(-1).numpy() - sk_target = target.view(-1).numpy() - return sk_roc_auc_score(y_true=sk_target, y_score=sk_preds, average=average, max_fpr=max_fpr) - - -def _sk_auroc_multiclass_prob(preds, target, num_classes, average='macro', max_fpr=None, multi_class='ovr'): - sk_preds = preds.reshape(-1, num_classes).numpy() - sk_target = target.view(-1).numpy() - return sk_roc_auc_score( - y_true=sk_target, - y_score=sk_preds, - average=average, - max_fpr=max_fpr, - multi_class=multi_class, - ) - - -def _sk_auroc_multidim_multiclass_prob(preds, target, num_classes, average='macro', max_fpr=None, multi_class='ovr'): - sk_preds = preds.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1).numpy() - sk_target = target.view(-1).numpy() - return sk_roc_auc_score( - y_true=sk_target, - y_score=sk_preds, - average=average, - max_fpr=max_fpr, - multi_class=multi_class, - ) - - -def _sk_auroc_multilabel_prob(preds, target, num_classes, average='macro', max_fpr=None, multi_class='ovr'): - sk_preds = preds.reshape(-1, num_classes).numpy() - sk_target = target.reshape(-1, num_classes).numpy() - return sk_roc_auc_score( - y_true=sk_target, - y_score=sk_preds, - average=average, - max_fpr=max_fpr, - multi_class=multi_class, - ) - - -def _sk_auroc_multilabel_multidim_prob(preds, target, num_classes, average='macro', max_fpr=None, multi_class='ovr'): - sk_preds = preds.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1).numpy() - sk_target = target.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1).numpy() - return sk_roc_auc_score( - y_true=sk_target, - y_score=sk_preds, - average=average, - max_fpr=max_fpr, - multi_class=multi_class, - ) - - -@pytest.mark.parametrize( - "preds, target, sk_metric, num_classes", - [(_input_binary_prob.preds, _input_binary_prob.target, _sk_auroc_binary_prob, 1), - (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_auroc_multiclass_prob, NUM_CLASSES), - (_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_auroc_multidim_multiclass_prob, NUM_CLASSES), - (_input_mlb_prob.preds, _input_mlb_prob.target, _sk_auroc_multilabel_prob, NUM_CLASSES), - (_input_mlmd_prob.preds, _input_mlmd_prob.target, _sk_auroc_multilabel_multidim_prob, NUM_CLASSES)] -) -@pytest.mark.parametrize("average", ['macro', 'weighted']) -@pytest.mark.parametrize("max_fpr", [None, 0.8, 0.5]) -class TestAUROC(MetricTester): - - @pytest.mark.parametrize("ddp", [True, False]) - @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_auroc(self, preds, target, sk_metric, num_classes, average, max_fpr, ddp, dist_sync_on_step): - # max_fpr different from None is not support in multi class - if max_fpr is not None and num_classes != 1: - pytest.skip('max_fpr parameter not support for multi class or multi label') - - # max_fpr only supported for torch v1.6 or higher - if max_fpr is not None and LooseVersion(torch.__version__) < 
LooseVersion('1.6.0'): - pytest.skip('requires torch v1.6 or higher to test max_fpr argument') - - self.run_class_metric_test( - ddp=ddp, - preds=preds, - target=target, - metric_class=AUROC, - sk_metric=partial(sk_metric, num_classes=num_classes, average=average, max_fpr=max_fpr), - dist_sync_on_step=dist_sync_on_step, - metric_args={ - "num_classes": num_classes, - "average": average, - "max_fpr": max_fpr - }, - ) - - def test_auroc_functional(self, preds, target, sk_metric, num_classes, average, max_fpr): - # max_fpr different from None is not support in multi class - if max_fpr is not None and num_classes != 1: - pytest.skip('max_fpr parameter not support for multi class or multi label') - - # max_fpr only supported for torch v1.6 or higher - if max_fpr is not None and LooseVersion(torch.__version__) < LooseVersion('1.6.0'): - pytest.skip('requires torch v1.6 or higher to test max_fpr argument') - - self.run_functional_metric_test( - preds, - target, - metric_functional=auroc, - sk_metric=partial(sk_metric, num_classes=num_classes, average=average, max_fpr=max_fpr), - metric_args={ - "num_classes": num_classes, - "average": average, - "max_fpr": max_fpr - }, - ) - - -def test_error_on_different_mode(): - """ test that an error is raised if the user pass in data of - different modes (binary, multi-label, multi-class) - """ - metric = AUROC() - # pass in multi-class data - metric.update(torch.randn(10, 5).softmax(dim=-1), torch.randint(0, 5, (10, ))) - with pytest.raises(ValueError, match=r"The mode of data.* should be constant.*"): - # pass in multi-label data - metric.update(torch.rand(10, 5), torch.randint(0, 2, (10, 5))) diff --git a/tests/metrics/classification/test_average_precision.py b/tests/metrics/classification/test_average_precision.py deleted file mode 100644 index 7cab20883e970..0000000000000 --- a/tests/metrics/classification/test_average_precision.py +++ /dev/null @@ -1,97 +0,0 @@ -from functools import partial - -import numpy as np -import pytest -import torch -from sklearn.metrics import average_precision_score as sk_average_precision_score - -from pytorch_lightning.metrics.classification.average_precision import AveragePrecision -from pytorch_lightning.metrics.functional.average_precision import average_precision -from tests.metrics.classification.inputs import _input_binary_prob -from tests.metrics.classification.inputs import _input_multiclass_prob as _input_mcls_prob -from tests.metrics.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob -from tests.metrics.utils import MetricTester, NUM_CLASSES - -torch.manual_seed(42) - - -def _sk_average_precision_score(y_true, probas_pred, num_classes=1): - if num_classes == 1: - return sk_average_precision_score(y_true, probas_pred) - - res = [] - for i in range(num_classes): - y_true_temp = np.zeros_like(y_true) - y_true_temp[y_true == i] = 1 - res.append(sk_average_precision_score(y_true_temp, probas_pred[:, i])) - return res - - -def _sk_avg_prec_binary_prob(preds, target, num_classes=1): - sk_preds = preds.view(-1).numpy() - sk_target = target.view(-1).numpy() - - return _sk_average_precision_score(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes) - - -def _sk_avg_prec_multiclass_prob(preds, target, num_classes=1): - sk_preds = preds.reshape(-1, num_classes).numpy() - sk_target = target.view(-1).numpy() - - return _sk_average_precision_score(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes) - - -def _sk_avg_prec_multidim_multiclass_prob(preds, target, 
num_classes=1): - sk_preds = preds.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1).numpy() - sk_target = target.view(-1).numpy() - return _sk_average_precision_score(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes) - - -@pytest.mark.parametrize( - "preds, target, sk_metric, num_classes", [ - (_input_binary_prob.preds, _input_binary_prob.target, _sk_avg_prec_binary_prob, 1), - (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_avg_prec_multiclass_prob, NUM_CLASSES), - (_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_avg_prec_multidim_multiclass_prob, NUM_CLASSES), - ] -) -class TestAveragePrecision(MetricTester): - - @pytest.mark.parametrize("ddp", [True, False]) - @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_average_precision(self, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step): - self.run_class_metric_test( - ddp=ddp, - preds=preds, - target=target, - metric_class=AveragePrecision, - sk_metric=partial(sk_metric, num_classes=num_classes), - dist_sync_on_step=dist_sync_on_step, - metric_args={"num_classes": num_classes} - ) - - def test_average_precision_functional(self, preds, target, sk_metric, num_classes): - self.run_functional_metric_test( - preds, - target, - metric_functional=average_precision, - sk_metric=partial(sk_metric, num_classes=num_classes), - metric_args={"num_classes": num_classes}, - ) - - -@pytest.mark.parametrize( - ['scores', 'target', 'expected_score'], - [ - # Check the average_precision_score of a constant predictor is - # the TPR - # Generate a dataset with 25% of positives - # And a constant score - # The precision is then the fraction of positive whatever the recall - # is, as there is only one threshold: - pytest.param(torch.tensor([1, 1, 1, 1]), torch.tensor([0, 0, 0, 1]), .25), - # With threshold 0.8 : 1 TP and 2 TN and one FN - pytest.param(torch.tensor([.6, .7, .8, 9]), torch.tensor([1, 0, 0, 1]), .75), - ] -) -def test_average_precision(scores, target, expected_score): - assert average_precision(scores, target) == expected_score diff --git a/tests/metrics/classification/test_confusion_matrix.py b/tests/metrics/classification/test_confusion_matrix.py deleted file mode 100644 index 5371044d6d4b0..0000000000000 --- a/tests/metrics/classification/test_confusion_matrix.py +++ /dev/null @@ -1,128 +0,0 @@ -from functools import partial - -import numpy as np -import pytest -import torch -from sklearn.metrics import confusion_matrix as sk_confusion_matrix - -from pytorch_lightning.metrics.classification.confusion_matrix import ConfusionMatrix -from pytorch_lightning.metrics.functional.confusion_matrix import confusion_matrix -from tests.metrics.classification.inputs import _input_binary, _input_binary_prob -from tests.metrics.classification.inputs import _input_multiclass as _input_mcls -from tests.metrics.classification.inputs import _input_multiclass_prob as _input_mcls_prob -from tests.metrics.classification.inputs import _input_multidim_multiclass as _input_mdmc -from tests.metrics.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob -from tests.metrics.classification.inputs import _input_multilabel as _input_mlb -from tests.metrics.classification.inputs import _input_multilabel_prob as _input_mlb_prob -from tests.metrics.utils import MetricTester, NUM_CLASSES, THRESHOLD - -torch.manual_seed(42) - - -def _sk_cm_binary_prob(preds, target, normalize=None): - sk_preds = (preds.view(-1).numpy() >= THRESHOLD).astype(np.uint8) - sk_target = target.view(-1).numpy() - 
- return sk_confusion_matrix(y_true=sk_target, y_pred=sk_preds, normalize=normalize) - - -def _sk_cm_binary(preds, target, normalize=None): - sk_preds = preds.view(-1).numpy() - sk_target = target.view(-1).numpy() - - return sk_confusion_matrix(y_true=sk_target, y_pred=sk_preds, normalize=normalize) - - -def _sk_cm_multilabel_prob(preds, target, normalize=None): - sk_preds = (preds.view(-1).numpy() >= THRESHOLD).astype(np.uint8) - sk_target = target.view(-1).numpy() - - return sk_confusion_matrix(y_true=sk_target, y_pred=sk_preds, normalize=normalize) - - -def _sk_cm_multilabel(preds, target, normalize=None): - sk_preds = preds.view(-1).numpy() - sk_target = target.view(-1).numpy() - - return sk_confusion_matrix(y_true=sk_target, y_pred=sk_preds, normalize=normalize) - - -def _sk_cm_multiclass_prob(preds, target, normalize=None): - sk_preds = torch.argmax(preds, dim=len(preds.shape) - 1).view(-1).numpy() - sk_target = target.view(-1).numpy() - - return sk_confusion_matrix(y_true=sk_target, y_pred=sk_preds, normalize=normalize) - - -def _sk_cm_multiclass(preds, target, normalize=None): - sk_preds = preds.view(-1).numpy() - sk_target = target.view(-1).numpy() - - return sk_confusion_matrix(y_true=sk_target, y_pred=sk_preds, normalize=normalize) - - -def _sk_cm_multidim_multiclass_prob(preds, target, normalize=None): - sk_preds = torch.argmax(preds, dim=len(preds.shape) - 2).view(-1).numpy() - sk_target = target.view(-1).numpy() - - return sk_confusion_matrix(y_true=sk_target, y_pred=sk_preds, normalize=normalize) - - -def _sk_cm_multidim_multiclass(preds, target, normalize=None): - sk_preds = preds.view(-1).numpy() - sk_target = target.view(-1).numpy() - - return sk_confusion_matrix(y_true=sk_target, y_pred=sk_preds, normalize=normalize) - - -@pytest.mark.parametrize("normalize", ['true', 'pred', 'all', None]) -@pytest.mark.parametrize( - "preds, target, sk_metric, num_classes", - [(_input_binary_prob.preds, _input_binary_prob.target, _sk_cm_binary_prob, 2), - (_input_binary.preds, _input_binary.target, _sk_cm_binary, 2), - (_input_mlb_prob.preds, _input_mlb_prob.target, _sk_cm_multilabel_prob, 2), - (_input_mlb.preds, _input_mlb.target, _sk_cm_multilabel, 2), - (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_cm_multiclass_prob, NUM_CLASSES), - (_input_mcls.preds, _input_mcls.target, _sk_cm_multiclass, NUM_CLASSES), - (_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_cm_multidim_multiclass_prob, NUM_CLASSES), - (_input_mdmc.preds, _input_mdmc.target, _sk_cm_multidim_multiclass, NUM_CLASSES)] -) -class TestConfusionMatrix(MetricTester): - - @pytest.mark.parametrize("ddp", [True, False]) - @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_confusion_matrix(self, normalize, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step): - self.run_class_metric_test( - ddp=ddp, - preds=preds, - target=target, - metric_class=ConfusionMatrix, - sk_metric=partial(sk_metric, normalize=normalize), - dist_sync_on_step=dist_sync_on_step, - metric_args={ - "num_classes": num_classes, - "threshold": THRESHOLD, - "normalize": normalize - } - ) - - def test_confusion_matrix_functional(self, normalize, preds, target, sk_metric, num_classes): - self.run_functional_metric_test( - preds, - target, - metric_functional=confusion_matrix, - sk_metric=partial(sk_metric, normalize=normalize), - metric_args={ - "num_classes": num_classes, - "threshold": THRESHOLD, - "normalize": normalize - } - ) - - -def test_warning_on_nan(tmpdir): - preds = torch.randint(3, size=(20, )) - target = 
torch.randint(3, size=(20, )) - - with pytest.warns(UserWarning, match='.* nan values found in confusion matrix have been replaced with zeros.'): - confusion_matrix(preds, target, num_classes=5, normalize='true') diff --git a/tests/metrics/classification/test_f_beta.py b/tests/metrics/classification/test_f_beta.py deleted file mode 100644 index b9458fb6c530c..0000000000000 --- a/tests/metrics/classification/test_f_beta.py +++ /dev/null @@ -1,153 +0,0 @@ -from functools import partial - -import numpy as np -import pytest -import torch -from sklearn.metrics import fbeta_score - -from pytorch_lightning.metrics import F1, FBeta -from pytorch_lightning.metrics.functional import f1, fbeta -from tests.metrics.classification.inputs import _input_binary, _input_binary_prob -from tests.metrics.classification.inputs import _input_multiclass as _input_mcls -from tests.metrics.classification.inputs import _input_multiclass_prob as _input_mcls_prob -from tests.metrics.classification.inputs import _input_multidim_multiclass as _input_mdmc -from tests.metrics.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob -from tests.metrics.classification.inputs import _input_multilabel as _input_mlb -from tests.metrics.classification.inputs import _input_multilabel_no_match as _input_mlb_nomatch -from tests.metrics.classification.inputs import _input_multilabel_prob as _mlb_prob_inputs -from tests.metrics.utils import MetricTester, NUM_CLASSES, THRESHOLD - -torch.manual_seed(42) - - -def _sk_fbeta_binary_prob(preds, target, average='micro', beta=1.0): - sk_preds = (preds.view(-1).numpy() >= THRESHOLD).astype(np.uint8) - sk_target = target.view(-1).numpy() - - return fbeta_score(y_true=sk_target, y_pred=sk_preds, average='binary', beta=beta) - - -def _sk_fbeta_binary(preds, target, average='micro', beta=1.0): - sk_preds = preds.view(-1).numpy() - sk_target = target.view(-1).numpy() - - return fbeta_score(y_true=sk_target, y_pred=sk_preds, average='binary', beta=beta) - - -def _sk_fbeta_multilabel_prob(preds, target, average='micro', beta=1.0): - sk_preds = (preds.view(-1, NUM_CLASSES).numpy() >= THRESHOLD).astype(np.uint8) - sk_target = target.view(-1, NUM_CLASSES).numpy() - - return fbeta_score(y_true=sk_target, y_pred=sk_preds, average=average, beta=beta) - - -def _sk_fbeta_multilabel(preds, target, average='micro', beta=1.0): - sk_preds = preds.view(-1, NUM_CLASSES).numpy() - sk_target = target.view(-1, NUM_CLASSES).numpy() - - return fbeta_score(y_true=sk_target, y_pred=sk_preds, average=average, beta=beta) - - -def _sk_fbeta_multiclass_prob(preds, target, average='micro', beta=1.0): - sk_preds = torch.argmax(preds, dim=len(preds.shape) - 1).view(-1).numpy() - sk_target = target.view(-1).numpy() - - return fbeta_score(y_true=sk_target, y_pred=sk_preds, average=average, beta=beta) - - -def _sk_fbeta_multiclass(preds, target, average='micro', beta=1.0): - sk_preds = preds.view(-1).numpy() - sk_target = target.view(-1).numpy() - - return fbeta_score(y_true=sk_target, y_pred=sk_preds, average=average, beta=beta) - - -def _sk_fbeta_multidim_multiclass_prob(preds, target, average='micro', beta=1.0): - sk_preds = torch.argmax(preds, dim=len(preds.shape) - 2).view(-1).numpy() - sk_target = target.view(-1).numpy() - - return fbeta_score(y_true=sk_target, y_pred=sk_preds, average=average, beta=beta) - - -def _sk_fbeta_multidim_multiclass(preds, target, average='micro', beta=1.0): - sk_preds = preds.view(-1).numpy() - sk_target = target.view(-1).numpy() - - return fbeta_score(y_true=sk_target, 
y_pred=sk_preds, average=average, beta=beta) - - -@pytest.mark.parametrize( - "preds, target, sk_metric, num_classes, multilabel", - [ - (_input_binary_prob.preds, _input_binary_prob.target, _sk_fbeta_binary_prob, 1, False), - (_input_binary.preds, _input_binary.target, _sk_fbeta_binary, 1, False), - (_mlb_prob_inputs.preds, _mlb_prob_inputs.target, _sk_fbeta_multilabel_prob, NUM_CLASSES, True), - (_input_mlb.preds, _input_mlb.target, _sk_fbeta_multilabel, NUM_CLASSES, True), - (_input_mlb_nomatch.preds, _input_mlb_nomatch.target, _sk_fbeta_multilabel, NUM_CLASSES, True), - (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_fbeta_multiclass_prob, NUM_CLASSES, False), - (_input_mcls.preds, _input_mcls.target, _sk_fbeta_multiclass, NUM_CLASSES, False), - (_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_fbeta_multidim_multiclass_prob, NUM_CLASSES, False), - (_input_mdmc.preds, _input_mdmc.target, _sk_fbeta_multidim_multiclass, NUM_CLASSES, False), - ], -) -@pytest.mark.parametrize("average", ['micro', 'macro', 'weighted', None]) -@pytest.mark.parametrize("beta", [0.5, 1.0, 2.0]) -class TestFBeta(MetricTester): - - @pytest.mark.parametrize("ddp", [True, False]) - @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_fbeta(self, preds, target, sk_metric, num_classes, multilabel, average, beta, ddp, dist_sync_on_step): - metric_class = F1 if beta == 1.0 else partial(FBeta, beta=beta) - - self.run_class_metric_test( - ddp=ddp, - preds=preds, - target=target, - metric_class=metric_class, - sk_metric=partial(sk_metric, average=average, beta=beta), - dist_sync_on_step=dist_sync_on_step, - metric_args={ - "num_classes": num_classes, - "average": average, - "multilabel": multilabel, - "threshold": THRESHOLD, - }, - check_dist_sync_on_step=False, - check_batch=False, - ) - - def test_fbeta_functional(self, preds, target, sk_metric, num_classes, multilabel, average, beta): - metric_functional = f1 if beta == 1.0 else partial(fbeta, beta=beta) - - self.run_functional_metric_test( - preds=preds, - target=target, - metric_functional=metric_functional, - sk_metric=partial(sk_metric, average=average, beta=beta), - metric_args={ - "num_classes": num_classes, - "average": average, - "multilabel": multilabel, - "threshold": THRESHOLD - } - ) - - -@pytest.mark.parametrize(['pred', 'target', 'beta', 'exp_score'], [ - pytest.param([1., 0., 1., 0.], [0., 1., 1., 0.], 0.5, [0.5, 0.5]), - pytest.param([1., 0., 1., 0.], [0., 1., 1., 0.], 1, [0.5, 0.5]), - pytest.param([1., 0., 1., 0.], [0., 1., 1., 0.], 2, [0.5, 0.5]), -]) -def test_fbeta_score(pred, target, beta, exp_score): - score = fbeta(torch.tensor(pred), torch.tensor(target), num_classes=1, beta=beta, average='none') - assert torch.allclose(score, torch.tensor(exp_score)) - - -@pytest.mark.parametrize(['pred', 'target', 'exp_score'], [ - pytest.param([0., 0., 0., 0.], [1., 1., 1., 1.], [0.0, 0.0]), - pytest.param([1., 0., 1., 0.], [0., 1., 1., 0.], [0.5, 0.5]), - pytest.param([1., 0., 1., 0.], [1., 0., 1., 0.], [1.0, 1.0]), -]) -def test_f1_score(pred, target, exp_score): - score = f1(torch.tensor(pred), torch.tensor(target), num_classes=1, average='none') - assert torch.allclose(score, torch.tensor(exp_score)) diff --git a/tests/metrics/classification/test_hamming_distance.py b/tests/metrics/classification/test_hamming_distance.py deleted file mode 100644 index c57072c033c8c..0000000000000 --- a/tests/metrics/classification/test_hamming_distance.py +++ /dev/null @@ -1,80 +0,0 @@ -import pytest -import torch -from sklearn.metrics 
import hamming_loss as sk_hamming_loss - -from pytorch_lightning.metrics import HammingDistance -from pytorch_lightning.metrics.classification.helpers import _input_format_classification -from pytorch_lightning.metrics.functional import hamming_distance -from tests.metrics.classification.inputs import _input_binary, _input_binary_prob -from tests.metrics.classification.inputs import _input_multiclass as _input_mcls -from tests.metrics.classification.inputs import _input_multiclass_prob as _input_mcls_prob -from tests.metrics.classification.inputs import _input_multidim_multiclass as _input_mdmc -from tests.metrics.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob -from tests.metrics.classification.inputs import _input_multilabel as _input_mlb -from tests.metrics.classification.inputs import _input_multilabel_multidim as _input_mlmd -from tests.metrics.classification.inputs import _input_multilabel_multidim_prob as _input_mlmd_prob -from tests.metrics.classification.inputs import _input_multilabel_prob as _input_mlb_prob -from tests.metrics.utils import MetricTester, THRESHOLD - -torch.manual_seed(42) - - -def _sk_hamming_loss(preds, target): - sk_preds, sk_target, _ = _input_format_classification(preds, target, threshold=THRESHOLD) - sk_preds, sk_target = sk_preds.numpy(), sk_target.numpy() - sk_preds, sk_target = sk_preds.reshape(sk_preds.shape[0], -1), sk_target.reshape(sk_target.shape[0], -1) - - return sk_hamming_loss(y_true=sk_target, y_pred=sk_preds) - - -@pytest.mark.parametrize( - "preds, target", - [ - (_input_binary_prob.preds, _input_binary_prob.target), - (_input_binary.preds, _input_binary.target), - (_input_mlb_prob.preds, _input_mlb_prob.target), - (_input_mlb.preds, _input_mlb.target), - (_input_mcls_prob.preds, _input_mcls_prob.target), - (_input_mcls.preds, _input_mcls.target), - (_input_mdmc_prob.preds, _input_mdmc_prob.target), - (_input_mdmc.preds, _input_mdmc.target), - (_input_mlmd_prob.preds, _input_mlmd_prob.target), - (_input_mlmd.preds, _input_mlmd.target), - ], -) -class TestHammingDistance(MetricTester): - - @pytest.mark.parametrize("ddp", [True, False]) - @pytest.mark.parametrize("dist_sync_on_step", [False, True]) - def test_hamming_distance_class(self, ddp, dist_sync_on_step, preds, target): - self.run_class_metric_test( - ddp=ddp, - preds=preds, - target=target, - metric_class=HammingDistance, - sk_metric=_sk_hamming_loss, - dist_sync_on_step=dist_sync_on_step, - metric_args={"threshold": THRESHOLD}, - ) - - def test_hamming_distance_fn(self, preds, target): - self.run_functional_metric_test( - preds, - target, - metric_functional=hamming_distance, - sk_metric=_sk_hamming_loss, - metric_args={"threshold": THRESHOLD}, - ) - - -@pytest.mark.parametrize("threshold", [1.5]) -def test_wrong_params(threshold): - preds, target = _input_mcls_prob.preds, _input_mcls_prob.target - - with pytest.raises(ValueError): - ham_dist = HammingDistance(threshold=threshold) - ham_dist(preds, target) - ham_dist.compute() - - with pytest.raises(ValueError): - hamming_distance(preds, target, threshold=threshold) diff --git a/tests/metrics/classification/test_inputs.py b/tests/metrics/classification/test_inputs.py deleted file mode 100644 index a78d799b1a07d..0000000000000 --- a/tests/metrics/classification/test_inputs.py +++ /dev/null @@ -1,311 +0,0 @@ -import pytest -import torch -from torch import rand, randint - -from pytorch_lightning.metrics.classification.helpers import _input_format_classification, DataType -from 
pytorch_lightning.metrics.utils import select_topk, to_onehot -from tests.metrics.classification.inputs import _input_binary as _bin -from tests.metrics.classification.inputs import _input_binary_prob as _bin_prob -from tests.metrics.classification.inputs import _input_multiclass as _mc -from tests.metrics.classification.inputs import _input_multiclass_prob as _mc_prob -from tests.metrics.classification.inputs import _input_multidim_multiclass as _mdmc -from tests.metrics.classification.inputs import _input_multidim_multiclass_prob as _mdmc_prob -from tests.metrics.classification.inputs import _input_multilabel as _ml -from tests.metrics.classification.inputs import _input_multilabel_multidim as _mlmd -from tests.metrics.classification.inputs import _input_multilabel_multidim_prob as _mlmd_prob -from tests.metrics.classification.inputs import _input_multilabel_prob as _ml_prob -from tests.metrics.classification.inputs import Input -from tests.metrics.utils import BATCH_SIZE, EXTRA_DIM, NUM_BATCHES, NUM_CLASSES, THRESHOLD - -torch.manual_seed(42) - -# Some additional inputs to test on -_ml_prob_half = Input(_ml_prob.preds.half(), _ml_prob.target) - -_mc_prob_2cls_preds = rand(NUM_BATCHES, BATCH_SIZE, 2) -_mc_prob_2cls_preds /= _mc_prob_2cls_preds.sum(dim=2, keepdim=True) -_mc_prob_2cls = Input(_mc_prob_2cls_preds, randint(high=2, size=(NUM_BATCHES, BATCH_SIZE))) - -_mdmc_prob_many_dims_preds = rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM, EXTRA_DIM) -_mdmc_prob_many_dims_preds /= _mdmc_prob_many_dims_preds.sum(dim=2, keepdim=True) -_mdmc_prob_many_dims = Input( - _mdmc_prob_many_dims_preds, - randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM, EXTRA_DIM)), -) - -_mdmc_prob_2cls_preds = rand(NUM_BATCHES, BATCH_SIZE, 2, EXTRA_DIM) -_mdmc_prob_2cls_preds /= _mdmc_prob_2cls_preds.sum(dim=2, keepdim=True) -_mdmc_prob_2cls = Input(_mdmc_prob_2cls_preds, randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM))) - -# Some utils -T = torch.Tensor - - -def _idn(x): - return x - - -def _usq(x): - return x.unsqueeze(-1) - - -def _thrs(x): - return x >= THRESHOLD - - -def _rshp1(x): - return x.reshape(x.shape[0], -1) - - -def _rshp2(x): - return x.reshape(x.shape[0], x.shape[1], -1) - - -def _onehot(x): - return to_onehot(x, NUM_CLASSES) - - -def _onehot2(x): - return to_onehot(x, 2) - - -def _top1(x): - return select_topk(x, 1) - - -def _top2(x): - return select_topk(x, 2) - - -# To avoid ugly black line wrapping -def _ml_preds_tr(x): - return _rshp1(_thrs(x)) - - -def _onehot_rshp1(x): - return _onehot(_rshp1(x)) - - -def _onehot2_rshp1(x): - return _onehot2(_rshp1(x)) - - -def _top1_rshp2(x): - return _top1(_rshp2(x)) - - -def _top2_rshp2(x): - return _top2(_rshp2(x)) - - -def _probs_to_mc_preds_tr(x): - return _onehot2(_thrs(x)) - - -def _mlmd_prob_to_mc_preds_tr(x): - return _onehot2(_rshp1(_thrs(x))) - - -######################## -# Test correct inputs -######################## - - -@pytest.mark.parametrize( - "inputs, num_classes, is_multiclass, top_k, exp_mode, post_preds, post_target", - [ - ############################# - # Test usual expected cases - (_bin, None, False, None, "multi-class", _usq, _usq), - (_bin, 1, False, None, "multi-class", _usq, _usq), - (_bin_prob, None, None, None, "binary", lambda x: _usq(_thrs(x)), _usq), - (_ml_prob, None, None, None, "multi-label", _thrs, _idn), - (_ml, None, False, None, "multi-dim multi-class", _idn, _idn), - (_ml_prob, None, None, None, "multi-label", _ml_preds_tr, _rshp1), - (_ml_prob, None, None, 2, "multi-label", _top2, _rshp1), 
- (_mlmd, None, False, None, "multi-dim multi-class", _rshp1, _rshp1), - (_mc, NUM_CLASSES, None, None, "multi-class", _onehot, _onehot), - (_mc_prob, None, None, None, "multi-class", _top1, _onehot), - (_mc_prob, None, None, 2, "multi-class", _top2, _onehot), - (_mdmc, NUM_CLASSES, None, None, "multi-dim multi-class", _onehot, _onehot), - (_mdmc_prob, None, None, None, "multi-dim multi-class", _top1_rshp2, _onehot), - (_mdmc_prob, None, None, 2, "multi-dim multi-class", _top2_rshp2, _onehot), - (_mdmc_prob_many_dims, None, None, None, "multi-dim multi-class", _top1_rshp2, _onehot_rshp1), - (_mdmc_prob_many_dims, None, None, 2, "multi-dim multi-class", _top2_rshp2, _onehot_rshp1), - ########################### - # Test some special cases - # Make sure that half precision works, i.e. is converted to full precision - (_ml_prob_half, None, None, None, "multi-label", lambda x: _ml_preds_tr(x.float()), _rshp1), - # Binary as multiclass - (_bin, None, None, None, "multi-class", _onehot2, _onehot2), - # Binary probs as multiclass - (_bin_prob, None, True, None, "binary", _probs_to_mc_preds_tr, _onehot2), - # Multilabel as multiclass - (_ml, None, True, None, "multi-dim multi-class", _onehot2, _onehot2), - # Multilabel probs as multiclass - (_ml_prob, None, True, None, "multi-label", _probs_to_mc_preds_tr, _onehot2), - # Multidim multilabel as multiclass - (_mlmd, None, True, None, "multi-dim multi-class", _onehot2_rshp1, _onehot2_rshp1), - # Multidim multilabel probs as multiclass - (_mlmd_prob, None, True, None, "multi-label", _mlmd_prob_to_mc_preds_tr, _onehot2_rshp1), - # Multiclass prob with 2 classes as binary - (_mc_prob_2cls, None, False, None, "multi-class", lambda x: _top1(x)[:, [1]], _usq), - # Multi-dim multi-class with 2 classes as multi-label - (_mdmc_prob_2cls, None, False, None, "multi-dim multi-class", lambda x: _top1(x)[:, 1], _idn), - ], -) -def test_usual_cases(inputs, num_classes, is_multiclass, top_k, exp_mode, post_preds, post_target): - - def __get_data_type_enum(str_exp_mode): - return next(DataType[n] for n in dir(DataType) if DataType[n] == str_exp_mode) - - for exp_mode in (exp_mode, __get_data_type_enum(exp_mode)): - preds_out, target_out, mode = _input_format_classification( - preds=inputs.preds[0], - target=inputs.target[0], - threshold=THRESHOLD, - num_classes=num_classes, - is_multiclass=is_multiclass, - top_k=top_k, - ) - - assert mode == exp_mode - assert torch.equal(preds_out, post_preds(inputs.preds[0]).int()) - assert torch.equal(target_out, post_target(inputs.target[0]).int()) - - # Test that things work when batch_size = 1 - preds_out, target_out, mode = _input_format_classification( - preds=inputs.preds[0][[0], ...], - target=inputs.target[0][[0], ...], - threshold=THRESHOLD, - num_classes=num_classes, - is_multiclass=is_multiclass, - top_k=top_k, - ) - - assert mode == exp_mode - assert torch.equal(preds_out, post_preds(inputs.preds[0][[0], ...]).int()) - assert torch.equal(target_out, post_target(inputs.target[0][[0], ...]).int()) - - -# Test that threshold is correctly applied -def test_threshold(): - target = T([1, 1, 1]).int() - preds_probs = T([0.5 - 1e-5, 0.5, 0.5 + 1e-5]) - - preds_probs_out, _, _ = _input_format_classification(preds_probs, target, threshold=0.5) - - assert torch.equal(torch.tensor([0, 1, 1], dtype=torch.int), preds_probs_out.squeeze().int()) - - -######################################################################## -# Test incorrect inputs -######################################################################## - - 
-@pytest.mark.parametrize("threshold", [-0.5, 0.0, 1.0, 1.5]) -def test_incorrect_threshold(threshold): - preds, target = rand(size=(7, )), randint(high=2, size=(7, )) - with pytest.raises(ValueError): - _input_format_classification(preds, target, threshold=threshold) - - -@pytest.mark.parametrize( - "preds, target, num_classes, is_multiclass", - [ - # Target not integer - (randint(high=2, size=(7, )), randint(high=2, size=(7, )).float(), None, None), - # Target negative - (randint(high=2, size=(7, )), -randint(high=2, size=(7, )), None, None), - # Preds negative integers - (-randint(high=2, size=(7, )), randint(high=2, size=(7, )), None, None), - # Negative probabilities - (-rand(size=(7, )), randint(high=2, size=(7, )), None, None), - # is_multiclass=False and target > 1 - (rand(size=(7, )), randint(low=2, high=4, size=(7, )), None, False), - # is_multiclass=False and preds integers with > 1 - (randint(low=2, high=4, size=(7, )), randint(high=2, size=(7, )), None, False), - # Wrong batch size - (randint(high=2, size=(8, )), randint(high=2, size=(7, )), None, None), - # Completely wrong shape - (randint(high=2, size=(7, )), randint(high=2, size=(7, 4)), None, None), - # Same #dims, different shape - (randint(high=2, size=(7, 3)), randint(high=2, size=(7, 4)), None, None), - # Same shape and preds floats, target not binary - (rand(size=(7, 3)), randint(low=2, high=4, size=(7, 3)), None, None), - # #dims in preds = 1 + #dims in target, C shape not second or last - (rand(size=(7, 3, 4, 3)), randint(high=4, size=(7, 3, 3)), None, None), - # #dims in preds = 1 + #dims in target, preds not float - (randint(high=2, size=(7, 3, 3, 4)), randint(high=4, size=(7, 3, 3)), None, None), - # is_multiclass=False, with C dimension > 2 - (_mc_prob.preds[0], randint(high=2, size=(BATCH_SIZE, )), None, False), - # Probs of multiclass preds do not sum up to 1 - (rand(size=(7, 3, 5)), randint(high=2, size=(7, 5)), None, None), - # Max target larger or equal to C dimension - (_mc_prob.preds[0], randint(low=NUM_CLASSES + 1, high=100, size=(BATCH_SIZE, )), None, None), - # C dimension not equal to num_classes - (_mc_prob.preds[0], _mc_prob.target[0], NUM_CLASSES + 1, None), - # Max target larger than num_classes (with #dim preds = 1 + #dims target) - (_mc_prob.preds[0], randint(low=NUM_CLASSES + 1, high=100, size=(BATCH_SIZE, NUM_CLASSES)), 4, None), - # Max target larger than num_classes (with #dim preds = #dims target) - (randint(high=4, size=(7, 3)), randint(low=5, high=7, size=(7, 3)), 4, None), - # Max preds larger than num_classes (with #dim preds = #dims target) - (randint(low=5, high=7, size=(7, 3)), randint(high=4, size=(7, 3)), 4, None), - # Num_classes=1, but is_multiclass not false - (randint(high=2, size=(7, )), randint(high=2, size=(7, )), 1, None), - # is_multiclass=False, but implied class dimension (for multi-label, from shape) != num_classes - (randint(high=2, size=(7, 3, 3)), randint(high=2, size=(7, 3, 3)), 4, False), - # Multilabel input with implied class dimension != num_classes - (rand(size=(7, 3, 3)), randint(high=2, size=(7, 3, 3)), 4, False), - # Multilabel input with is_multiclass=True, but num_classes != 2 (or None) - (rand(size=(7, 3)), randint(high=2, size=(7, 3)), 4, True), - # Binary input, num_classes > 2 - (rand(size=(7, )), randint(high=2, size=(7, )), 4, None), - # Binary input, num_classes == 2 and is_multiclass not True - (rand(size=(7, )), randint(high=2, size=(7, )), 2, None), - (rand(size=(7, )), randint(high=2, size=(7, )), 2, False), - # Binary input, num_classes == 1 
and is_multiclass=True - (rand(size=(7, )), randint(high=2, size=(7, )), 1, True), - ], -) -def test_incorrect_inputs(preds, target, num_classes, is_multiclass): - with pytest.raises(ValueError): - _input_format_classification( - preds=preds, target=target, threshold=THRESHOLD, num_classes=num_classes, is_multiclass=is_multiclass - ) - - -@pytest.mark.parametrize( - "preds, target, num_classes, is_multiclass, top_k", - [ - # Topk set with non (md)mc or ml prob data - (_bin.preds[0], _bin.target[0], None, None, 2), - (_bin_prob.preds[0], _bin_prob.target[0], None, None, 2), - (_mc.preds[0], _mc.target[0], None, None, 2), - (_ml.preds[0], _ml.target[0], None, None, 2), - (_mlmd.preds[0], _mlmd.target[0], None, None, 2), - (_mdmc.preds[0], _mdmc.target[0], None, None, 2), - # top_k = 0 - (_mc_prob_2cls.preds[0], _mc_prob_2cls.target[0], None, None, 0), - # top_k = float - (_mc_prob_2cls.preds[0], _mc_prob_2cls.target[0], None, None, 0.123), - # top_k =2 with 2 classes, is_multiclass=False - (_mc_prob_2cls.preds[0], _mc_prob_2cls.target[0], None, False, 2), - # top_k = number of classes (C dimension) - (_mc_prob.preds[0], _mc_prob.target[0], None, None, NUM_CLASSES), - # is_multiclass = True for ml prob inputs, top_k set - (_ml_prob.preds[0], _ml_prob.target[0], None, True, 2), - # top_k = num_classes for ml prob inputs - (_ml_prob.preds[0], _ml_prob.target[0], None, True, NUM_CLASSES), - ], -) -def test_incorrect_inputs_topk(preds, target, num_classes, is_multiclass, top_k): - with pytest.raises(ValueError): - _input_format_classification( - preds=preds, - target=target, - threshold=THRESHOLD, - num_classes=num_classes, - is_multiclass=is_multiclass, - top_k=top_k, - ) diff --git a/tests/metrics/classification/test_iou.py b/tests/metrics/classification/test_iou.py deleted file mode 100644 index 6bb100f68165a..0000000000000 --- a/tests/metrics/classification/test_iou.py +++ /dev/null @@ -1,216 +0,0 @@ -from functools import partial - -import numpy as np -import pytest -import torch -from sklearn.metrics import jaccard_score as sk_jaccard_score - -from pytorch_lightning.metrics.classification.iou import IoU -from pytorch_lightning.metrics.functional.iou import iou -from tests.metrics.classification.inputs import _input_binary, _input_binary_prob -from tests.metrics.classification.inputs import _input_multiclass as _input_mcls -from tests.metrics.classification.inputs import _input_multiclass_prob as _input_mcls_prob -from tests.metrics.classification.inputs import _input_multidim_multiclass as _input_mdmc -from tests.metrics.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob -from tests.metrics.classification.inputs import _input_multilabel as _input_mlb -from tests.metrics.classification.inputs import _input_multilabel_prob as _input_mlb_prob -from tests.metrics.utils import MetricTester, NUM_CLASSES, THRESHOLD - - -def _sk_iou_binary_prob(preds, target, average=None): - sk_preds = (preds.view(-1).numpy() >= THRESHOLD).astype(np.uint8) - sk_target = target.view(-1).numpy() - - return sk_jaccard_score(y_true=sk_target, y_pred=sk_preds, average=average) - - -def _sk_iou_binary(preds, target, average=None): - sk_preds = preds.view(-1).numpy() - sk_target = target.view(-1).numpy() - - return sk_jaccard_score(y_true=sk_target, y_pred=sk_preds, average=average) - - -def _sk_iou_multilabel_prob(preds, target, average=None): - sk_preds = (preds.view(-1).numpy() >= THRESHOLD).astype(np.uint8) - sk_target = target.view(-1).numpy() - - return 
sk_jaccard_score(y_true=sk_target, y_pred=sk_preds, average=average) - - -def _sk_iou_multilabel(preds, target, average=None): - sk_preds = preds.view(-1).numpy() - sk_target = target.view(-1).numpy() - - return sk_jaccard_score(y_true=sk_target, y_pred=sk_preds, average=average) - - -def _sk_iou_multiclass_prob(preds, target, average=None): - sk_preds = torch.argmax(preds, dim=len(preds.shape) - 1).view(-1).numpy() - sk_target = target.view(-1).numpy() - - return sk_jaccard_score(y_true=sk_target, y_pred=sk_preds, average=average) - - -def _sk_iou_multiclass(preds, target, average=None): - sk_preds = preds.view(-1).numpy() - sk_target = target.view(-1).numpy() - - return sk_jaccard_score(y_true=sk_target, y_pred=sk_preds, average=average) - - -def _sk_iou_multidim_multiclass_prob(preds, target, average=None): - sk_preds = torch.argmax(preds, dim=len(preds.shape) - 2).view(-1).numpy() - sk_target = target.view(-1).numpy() - - return sk_jaccard_score(y_true=sk_target, y_pred=sk_preds, average=average) - - -def _sk_iou_multidim_multiclass(preds, target, average=None): - sk_preds = preds.view(-1).numpy() - sk_target = target.view(-1).numpy() - - return sk_jaccard_score(y_true=sk_target, y_pred=sk_preds, average=average) - - -@pytest.mark.parametrize("reduction", ['elementwise_mean', 'none']) -@pytest.mark.parametrize( - "preds, target, sk_metric, num_classes", - [(_input_binary_prob.preds, _input_binary_prob.target, _sk_iou_binary_prob, 2), - (_input_binary.preds, _input_binary.target, _sk_iou_binary, 2), - (_input_mlb_prob.preds, _input_mlb_prob.target, _sk_iou_multilabel_prob, 2), - (_input_mlb.preds, _input_mlb.target, _sk_iou_multilabel, 2), - (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_iou_multiclass_prob, NUM_CLASSES), - (_input_mcls.preds, _input_mcls.target, _sk_iou_multiclass, NUM_CLASSES), - (_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_iou_multidim_multiclass_prob, NUM_CLASSES), - (_input_mdmc.preds, _input_mdmc.target, _sk_iou_multidim_multiclass, NUM_CLASSES)] -) -class TestIoU(MetricTester): - - @pytest.mark.parametrize("ddp", [True, False]) - @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_confusion_matrix(self, reduction, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step): - average = 'macro' if reduction == 'elementwise_mean' else None # convert tags - self.run_class_metric_test( - ddp=ddp, - preds=preds, - target=target, - metric_class=IoU, - sk_metric=partial(sk_metric, average=average), - dist_sync_on_step=dist_sync_on_step, - metric_args={ - "num_classes": num_classes, - "threshold": THRESHOLD, - "reduction": reduction - } - ) - - def test_confusion_matrix_functional(self, reduction, preds, target, sk_metric, num_classes): - average = 'macro' if reduction == 'elementwise_mean' else None # convert tags - self.run_functional_metric_test( - preds, - target, - metric_functional=iou, - sk_metric=partial(sk_metric, average=average), - metric_args={ - "num_classes": num_classes, - "threshold": THRESHOLD, - "reduction": reduction - } - ) - - -@pytest.mark.parametrize(['half_ones', 'reduction', 'ignore_index', 'expected'], [ - pytest.param(False, 'none', None, torch.Tensor([1, 1, 1])), - pytest.param(False, 'elementwise_mean', None, torch.Tensor([1])), - pytest.param(False, 'none', 0, torch.Tensor([1, 1])), - pytest.param(True, 'none', None, torch.Tensor([0.5, 0.5, 0.5])), - pytest.param(True, 'elementwise_mean', None, torch.Tensor([0.5])), - pytest.param(True, 'none', 0, torch.Tensor([0.5, 0.5])), -]) -def test_iou(half_ones, 
reduction, ignore_index, expected): - pred = (torch.arange(120) % 3).view(-1, 1) - target = (torch.arange(120) % 3).view(-1, 1) - if half_ones: - pred[:60] = 1 - iou_val = iou( - pred=pred, - target=target, - ignore_index=ignore_index, - reduction=reduction, - ) - assert torch.allclose(iou_val, expected, atol=1e-9) - - -# test `absent_score` -@pytest.mark.parametrize( - ['pred', 'target', 'ignore_index', 'absent_score', 'num_classes', 'expected'], - [ - # Note that -1 is used as the absent_score in almost all tests here to distinguish it from the range of valid - # scores the function can return ([0., 1.] range, inclusive). - # 2 classes, class 0 is correct everywhere, class 1 is absent. - pytest.param([0], [0], None, -1., 2, [1., -1.]), - pytest.param([0, 0], [0, 0], None, -1., 2, [1., -1.]), - # absent_score not applied if only class 0 is present and it's the only class. - pytest.param([0], [0], None, -1., 1, [1.]), - # 2 classes, class 1 is correct everywhere, class 0 is absent. - pytest.param([1], [1], None, -1., 2, [-1., 1.]), - pytest.param([1, 1], [1, 1], None, -1., 2, [-1., 1.]), - # When 0 index ignored, class 0 does not get a score (not even the absent_score). - pytest.param([1], [1], 0, -1., 2, [1.0]), - # 3 classes. Only 0 and 2 are present, and are perfectly predicted. 1 should get absent_score. - pytest.param([0, 2], [0, 2], None, -1., 3, [1., -1., 1.]), - pytest.param([2, 0], [2, 0], None, -1., 3, [1., -1., 1.]), - # 3 classes. Only 0 and 1 are present, and are perfectly predicted. 2 should get absent_score. - pytest.param([0, 1], [0, 1], None, -1., 3, [1., 1., -1.]), - pytest.param([1, 0], [1, 0], None, -1., 3, [1., 1., -1.]), - # 3 classes, class 0 is 0.5 IoU, class 1 is 0 IoU (in pred but not target; should not get absent_score), class - # 2 is absent. - pytest.param([0, 1], [0, 0], None, -1., 3, [0.5, 0., -1.]), - # 3 classes, class 0 is 0.5 IoU, class 1 is 0 IoU (in target but not pred; should not get absent_score), class - # 2 is absent. - pytest.param([0, 0], [0, 1], None, -1., 3, [0.5, 0., -1.]), - # Sanity checks with absent_score of 1.0. - pytest.param([0, 2], [0, 2], None, 1.0, 3, [1., 1., 1.]), - pytest.param([0, 2], [0, 2], 0, 1.0, 3, [1., 1.]), - ] -) -def test_iou_absent_score(pred, target, ignore_index, absent_score, num_classes, expected): - iou_val = iou( - pred=torch.tensor(pred), - target=torch.tensor(target), - ignore_index=ignore_index, - absent_score=absent_score, - num_classes=num_classes, - reduction='none', - ) - assert torch.allclose(iou_val, torch.tensor(expected).to(iou_val)) - - -# example data taken from -# https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/metrics/tests/test_ranking.py -@pytest.mark.parametrize( - ['pred', 'target', 'ignore_index', 'num_classes', 'reduction', 'expected'], - [ - # Ignoring an index outside of [0, num_classes-1] should have no effect. - pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], None, 3, 'none', [1, 1 / 2, 2 / 3]), - pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], -1, 3, 'none', [1, 1 / 2, 2 / 3]), - pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 255, 3, 'none', [1, 1 / 2, 2 / 3]), - # Ignoring a valid index drops only that index from the result. - pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, 'none', [1 / 2, 2 / 3]), - pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 1, 3, 'none', [1, 2 / 3]), - pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 2, 3, 'none', [1, 1 / 2]), - # When reducing to mean or sum, the ignored index does not contribute to the output. 
- pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, 'elementwise_mean', [7 / 12]), - pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, 'sum', [7 / 6]), - ] -) -def test_iou_ignore_index(pred, target, ignore_index, num_classes, reduction, expected): - iou_val = iou( - pred=torch.tensor(pred), - target=torch.tensor(target), - ignore_index=ignore_index, - num_classes=num_classes, - reduction=reduction, - ) - assert torch.allclose(iou_val, torch.tensor(expected).to(iou_val)) diff --git a/tests/metrics/classification/test_precision_recall.py b/tests/metrics/classification/test_precision_recall.py deleted file mode 100644 index a9bf39044174a..0000000000000 --- a/tests/metrics/classification/test_precision_recall.py +++ /dev/null @@ -1,347 +0,0 @@ -from functools import partial -from typing import Callable, Optional - -import numpy as np -import pytest -import torch -from sklearn.metrics import precision_score, recall_score - -from pytorch_lightning.metrics import Metric, Precision, Recall -from pytorch_lightning.metrics.classification.helpers import _input_format_classification -from pytorch_lightning.metrics.functional import precision, precision_recall, recall -from tests.metrics.classification.inputs import _input_binary, _input_binary_prob -from tests.metrics.classification.inputs import _input_multiclass as _input_mcls -from tests.metrics.classification.inputs import _input_multiclass_prob as _input_mcls_prob -from tests.metrics.classification.inputs import _input_multidim_multiclass as _input_mdmc -from tests.metrics.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob -from tests.metrics.classification.inputs import _input_multilabel as _input_mlb -from tests.metrics.classification.inputs import _input_multilabel_prob as _input_mlb_prob -from tests.metrics.utils import MetricTester, NUM_CLASSES, THRESHOLD - -torch.manual_seed(42) - - -def _sk_prec_recall(preds, target, sk_fn, num_classes, average, is_multiclass, ignore_index, mdmc_average=None): - if average == "none": - average = None - if num_classes == 1: - average = "binary" - - labels = list(range(num_classes)) - try: - labels.remove(ignore_index) - except ValueError: - pass - - sk_preds, sk_target, _ = _input_format_classification( - preds, target, THRESHOLD, num_classes=num_classes, is_multiclass=is_multiclass - ) - sk_preds, sk_target = sk_preds.numpy(), sk_target.numpy() - - sk_scores = sk_fn(sk_target, sk_preds, average=average, zero_division=0, labels=labels) - - if len(labels) != num_classes and not average: - sk_scores = np.insert(sk_scores, ignore_index, np.nan) - - return sk_scores - - -def _sk_prec_recall_multidim_multiclass( - preds, target, sk_fn, num_classes, average, is_multiclass, ignore_index, mdmc_average -): - preds, target, _ = _input_format_classification( - preds, target, threshold=THRESHOLD, num_classes=num_classes, is_multiclass=is_multiclass - ) - - if mdmc_average == "global": - preds = torch.transpose(preds, 1, 2).reshape(-1, preds.shape[1]) - target = torch.transpose(target, 1, 2).reshape(-1, target.shape[1]) - - return _sk_prec_recall(preds, target, sk_fn, num_classes, average, False, ignore_index) - elif mdmc_average == "samplewise": - scores = [] - - for i in range(preds.shape[0]): - pred_i = preds[i, ...].T - target_i = target[i, ...].T - scores_i = _sk_prec_recall(pred_i, target_i, sk_fn, num_classes, average, False, ignore_index) - - scores.append(np.expand_dims(scores_i, 0)) - - return np.concatenate(scores).mean(axis=0) - - -@pytest.mark.parametrize("metric, 
fn_metric", [(Precision, precision), (Recall, recall)]) -@pytest.mark.parametrize( - "average, mdmc_average, num_classes, ignore_index, match_str", - [ - ("wrong", None, None, None, "`average`"), - ("micro", "wrong", None, None, "`mdmc"), - ("macro", None, None, None, "number of classes"), - ("macro", None, 1, 0, "ignore_index"), - ], -) -def test_wrong_params(metric, fn_metric, average, mdmc_average, num_classes, ignore_index, match_str): - with pytest.raises(ValueError, match=match_str): - metric( - average=average, - mdmc_average=mdmc_average, - num_classes=num_classes, - ignore_index=ignore_index, - ) - - with pytest.raises(ValueError, match=match_str): - fn_metric( - _input_binary.preds[0], - _input_binary.target[0], - average=average, - mdmc_average=mdmc_average, - num_classes=num_classes, - ignore_index=ignore_index, - ) - - with pytest.raises(ValueError, match=match_str): - precision_recall( - _input_binary.preds[0], - _input_binary.target[0], - average=average, - mdmc_average=mdmc_average, - num_classes=num_classes, - ignore_index=ignore_index, - ) - - -@pytest.mark.parametrize("metric_class, metric_fn", [(Recall, recall), (Precision, precision)]) -def test_zero_division(metric_class, metric_fn): - """ Test that zero_division works correctly (currently should just set to 0). """ - - preds = torch.tensor([1, 2, 1, 1]) - target = torch.tensor([2, 1, 2, 1]) - - cl_metric = metric_class(average="none", num_classes=3) - cl_metric(preds, target) - - result_cl = cl_metric.compute() - result_fn = metric_fn(preds, target, average="none", num_classes=3) - - assert result_cl[0] == result_fn[0] == 0 - - -@pytest.mark.parametrize("metric_class, metric_fn", [(Recall, recall), (Precision, precision)]) -def test_no_support(metric_class, metric_fn): - """This tests a rare edge case, where there is only one class present - in target, and ignore_index is set to exactly that class - and the - average method is equal to 'weighted'. - - This would mean that the sum of weights equals zero, and would, without - taking care of this case, return NaN. However, the reduction function - should catch that and set the metric to equal the value of zero_division - in this case (zero_division is for now not configurable and equals 0). 
- """ - - preds = torch.tensor([1, 1, 0, 0]) - target = torch.tensor([0, 0, 0, 0]) - - cl_metric = metric_class(average="weighted", num_classes=2, ignore_index=0) - cl_metric(preds, target) - - result_cl = cl_metric.compute() - result_fn = metric_fn(preds, target, average="weighted", num_classes=2, ignore_index=0) - - assert result_cl == result_fn == 0 - - -@pytest.mark.parametrize( - "metric_class, metric_fn, sk_fn", [(Recall, recall, recall_score), (Precision, precision, precision_score)] -) -@pytest.mark.parametrize("average", ["micro", "macro", None, "weighted", "samples"]) -@pytest.mark.parametrize("ignore_index", [None, 0]) -@pytest.mark.parametrize( - "preds, target, num_classes, is_multiclass, mdmc_average, sk_wrapper", - [ - (_input_binary_prob.preds, _input_binary_prob.target, 1, None, None, _sk_prec_recall), - (_input_binary.preds, _input_binary.target, 1, False, None, _sk_prec_recall), - (_input_mlb_prob.preds, _input_mlb_prob.target, NUM_CLASSES, None, None, _sk_prec_recall), - (_input_mlb.preds, _input_mlb.target, NUM_CLASSES, False, None, _sk_prec_recall), - (_input_mcls_prob.preds, _input_mcls_prob.target, NUM_CLASSES, None, None, _sk_prec_recall), - (_input_mcls.preds, _input_mcls.target, NUM_CLASSES, None, None, _sk_prec_recall), - (_input_mdmc.preds, _input_mdmc.target, NUM_CLASSES, None, "global", _sk_prec_recall_multidim_multiclass), - ( - _input_mdmc_prob.preds, _input_mdmc_prob.target, NUM_CLASSES, None, "global", - _sk_prec_recall_multidim_multiclass - ), - (_input_mdmc.preds, _input_mdmc.target, NUM_CLASSES, None, "samplewise", _sk_prec_recall_multidim_multiclass), - ( - _input_mdmc_prob.preds, _input_mdmc_prob.target, NUM_CLASSES, None, "samplewise", - _sk_prec_recall_multidim_multiclass - ), - ], -) -class TestPrecisionRecall(MetricTester): - - @pytest.mark.parametrize("ddp", [False]) - @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_precision_recall_class( - self, - ddp: bool, - dist_sync_on_step: bool, - preds: torch.Tensor, - target: torch.Tensor, - sk_wrapper: Callable, - metric_class: Metric, - metric_fn: Callable, - sk_fn: Callable, - is_multiclass: Optional[bool], - num_classes: Optional[int], - average: str, - mdmc_average: Optional[str], - ignore_index: Optional[int], - ): - if num_classes == 1 and average != "micro": - pytest.skip("Only test binary data for 'micro' avg (equivalent of 'binary' in sklearn)") - - if ignore_index is not None and preds.ndim == 2: - pytest.skip("Skipping ignore_index test with binary inputs.") - - if average == "weighted" and ignore_index is not None and mdmc_average is not None: - pytest.skip("Ignore special case where we are ignoring entire sample for 'weighted' average") - - self.run_class_metric_test( - ddp=ddp, - preds=preds, - target=target, - metric_class=metric_class, - sk_metric=partial( - sk_wrapper, - sk_fn=sk_fn, - average=average, - num_classes=num_classes, - is_multiclass=is_multiclass, - ignore_index=ignore_index, - mdmc_average=mdmc_average, - ), - dist_sync_on_step=dist_sync_on_step, - metric_args={ - "num_classes": num_classes, - "average": average, - "threshold": THRESHOLD, - "is_multiclass": is_multiclass, - "ignore_index": ignore_index, - "mdmc_average": mdmc_average, - }, - check_dist_sync_on_step=True, - check_batch=True, - ) - - def test_precision_recall_fn( - self, - preds: torch.Tensor, - target: torch.Tensor, - sk_wrapper: Callable, - metric_class: Metric, - metric_fn: Callable, - sk_fn: Callable, - is_multiclass: Optional[bool], - num_classes: Optional[int], - average: str, 
- mdmc_average: Optional[str], - ignore_index: Optional[int], - ): - if num_classes == 1 and average != "micro": - pytest.skip("Only test binary data for 'micro' avg (equivalent of 'binary' in sklearn)") - - if ignore_index is not None and preds.ndim == 2: - pytest.skip("Skipping ignore_index test with binary inputs.") - - if average == "weighted" and ignore_index is not None and mdmc_average is not None: - pytest.skip("Ignore special case where we are ignoring entire sample for 'weighted' average") - - self.run_functional_metric_test( - preds, - target, - metric_functional=metric_fn, - sk_metric=partial( - sk_wrapper, - sk_fn=sk_fn, - average=average, - num_classes=num_classes, - is_multiclass=is_multiclass, - ignore_index=ignore_index, - mdmc_average=mdmc_average, - ), - metric_args={ - "num_classes": num_classes, - "average": average, - "threshold": THRESHOLD, - "is_multiclass": is_multiclass, - "ignore_index": ignore_index, - "mdmc_average": mdmc_average, - }, - ) - - -@pytest.mark.parametrize("average", ["micro", "macro", None, "weighted", "samples"]) -def test_precision_recall_joint(average): - """A simple test of the joint precision_recall metric. - - No need to test this thoroughly, as it is just a combination of precision and recall, - which are already tested thoroughly. - """ - - precision_result = precision( - _input_mcls_prob.preds[0], _input_mcls_prob.target[0], average=average, num_classes=NUM_CLASSES - ) - recall_result = recall( - _input_mcls_prob.preds[0], _input_mcls_prob.target[0], average=average, num_classes=NUM_CLASSES - ) - - prec_recall_result = precision_recall( - _input_mcls_prob.preds[0], _input_mcls_prob.target[0], average=average, num_classes=NUM_CLASSES - ) - - assert torch.equal(precision_result, prec_recall_result[0]) - assert torch.equal(recall_result, prec_recall_result[1]) - - -_mc_k_target = torch.tensor([0, 1, 2]) -_mc_k_preds = torch.tensor([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]]) -_ml_k_target = torch.tensor([[0, 1, 0], [1, 1, 0], [0, 0, 0]]) -_ml_k_preds = torch.tensor([[0.9, 0.2, 0.75], [0.1, 0.7, 0.8], [0.6, 0.1, 0.7]]) - - -@pytest.mark.parametrize("metric_class, metric_fn", [(Recall, recall), (Precision, precision)]) -@pytest.mark.parametrize( - "k, preds, target, average, expected_prec, expected_recall", - [ - (1, _mc_k_preds, _mc_k_target, "micro", torch.tensor(2 / 3), torch.tensor(2 / 3)), - (2, _mc_k_preds, _mc_k_target, "micro", torch.tensor(1 / 2), torch.tensor(1.0)), - (1, _ml_k_preds, _ml_k_target, "micro", torch.tensor(0.0), torch.tensor(0.0)), - (2, _ml_k_preds, _ml_k_target, "micro", torch.tensor(1 / 6), torch.tensor(1 / 3)), - ], -) -def test_top_k( - metric_class, - metric_fn, - k: int, - preds: torch.Tensor, - target: torch.Tensor, - average: str, - expected_prec: torch.Tensor, - expected_recall: torch.Tensor, -): - """A simple test to check that top_k works as expected. - - Just a sanity check, the tests in StatScores should already guarantee - the correctness of results.
- """ - - class_metric = metric_class(top_k=k, average=average, num_classes=3) - class_metric.update(preds, target) - - if metric_class.__name__ == "Precision": - result = expected_prec - else: - result = expected_recall - - assert torch.equal(class_metric.compute(), result) - assert torch.equal(metric_fn(preds, target, top_k=k, average=average, num_classes=3), result) diff --git a/tests/metrics/classification/test_precision_recall_curve.py b/tests/metrics/classification/test_precision_recall_curve.py deleted file mode 100644 index 6a60e1fd36fdd..0000000000000 --- a/tests/metrics/classification/test_precision_recall_curve.py +++ /dev/null @@ -1,97 +0,0 @@ -from functools import partial - -import numpy as np -import pytest -import torch -from sklearn.metrics import precision_recall_curve as sk_precision_recall_curve - -from pytorch_lightning.metrics.classification.precision_recall_curve import PrecisionRecallCurve -from pytorch_lightning.metrics.functional.precision_recall_curve import precision_recall_curve -from tests.metrics.classification.inputs import _input_binary_prob -from tests.metrics.classification.inputs import _input_multiclass_prob as _input_mcls_prob -from tests.metrics.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob -from tests.metrics.utils import MetricTester, NUM_CLASSES - -torch.manual_seed(42) - - -def _sk_precision_recall_curve(y_true, probas_pred, num_classes=1): - """ Adjusted comparison function that can also handles multiclass """ - if num_classes == 1: - return sk_precision_recall_curve(y_true, probas_pred) - - precision, recall, thresholds = [], [], [] - for i in range(num_classes): - y_true_temp = np.zeros_like(y_true) - y_true_temp[y_true == i] = 1 - res = sk_precision_recall_curve(y_true_temp, probas_pred[:, i]) - precision.append(res[0]) - recall.append(res[1]) - thresholds.append(res[2]) - return precision, recall, thresholds - - -def _sk_prec_rc_binary_prob(preds, target, num_classes=1): - sk_preds = preds.view(-1).numpy() - sk_target = target.view(-1).numpy() - - return _sk_precision_recall_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes) - - -def _sk_prec_rc_multiclass_prob(preds, target, num_classes=1): - sk_preds = preds.reshape(-1, num_classes).numpy() - sk_target = target.view(-1).numpy() - - return _sk_precision_recall_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes) - - -def _sk_prec_rc_multidim_multiclass_prob(preds, target, num_classes=1): - sk_preds = preds.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1).numpy() - sk_target = target.view(-1).numpy() - return _sk_precision_recall_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes) - - -@pytest.mark.parametrize( - "preds, target, sk_metric, num_classes", [ - (_input_binary_prob.preds, _input_binary_prob.target, _sk_prec_rc_binary_prob, 1), - (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_prec_rc_multiclass_prob, NUM_CLASSES), - (_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_prec_rc_multidim_multiclass_prob, NUM_CLASSES), - ] -) -class TestPrecisionRecallCurve(MetricTester): - - @pytest.mark.parametrize("ddp", [True, False]) - @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_precision_recall_curve(self, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step): - self.run_class_metric_test( - ddp=ddp, - preds=preds, - target=target, - metric_class=PrecisionRecallCurve, - sk_metric=partial(sk_metric, num_classes=num_classes), - 
dist_sync_on_step=dist_sync_on_step, - metric_args={"num_classes": num_classes} - ) - - def test_precision_recall_curve_functional(self, preds, target, sk_metric, num_classes): - self.run_functional_metric_test( - preds, - target, - metric_functional=precision_recall_curve, - sk_metric=partial(sk_metric, num_classes=num_classes), - metric_args={"num_classes": num_classes}, - ) - - -@pytest.mark.parametrize( - ['pred', 'target', 'expected_p', 'expected_r', 'expected_t'], - [pytest.param([1, 2, 3, 4], [1, 0, 0, 1], [0.5, 1 / 3, 0.5, 1., 1.], [1, 0.5, 0.5, 0.5, 0.], [1, 2, 3, 4])] -) -def test_pr_curve(pred, target, expected_p, expected_r, expected_t): - p, r, t = precision_recall_curve(torch.tensor(pred), torch.tensor(target)) - assert p.size() == r.size() - assert p.size(0) == t.size(0) + 1 - - assert torch.allclose(p, torch.tensor(expected_p).to(p)) - assert torch.allclose(r, torch.tensor(expected_r).to(r)) - assert torch.allclose(t, torch.tensor(expected_t).to(t)) diff --git a/tests/metrics/classification/test_roc.py b/tests/metrics/classification/test_roc.py deleted file mode 100644 index 46a23322ca1c0..0000000000000 --- a/tests/metrics/classification/test_roc.py +++ /dev/null @@ -1,99 +0,0 @@ -from functools import partial - -import numpy as np -import pytest -import torch -from sklearn.metrics import roc_curve as sk_roc_curve - -from pytorch_lightning.metrics.classification.roc import ROC -from pytorch_lightning.metrics.functional.roc import roc -from tests.metrics.classification.inputs import _input_binary_prob -from tests.metrics.classification.inputs import _input_multiclass_prob as _input_mcls_prob -from tests.metrics.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob -from tests.metrics.utils import MetricTester, NUM_CLASSES - -torch.manual_seed(42) - - -def _sk_roc_curve(y_true, probas_pred, num_classes=1): - """ Adjusted comparison function that can also handle multiclass """ - if num_classes == 1: - return sk_roc_curve(y_true, probas_pred, drop_intermediate=False) - - fpr, tpr, thresholds = [], [], [] - for i in range(num_classes): - y_true_temp = np.zeros_like(y_true) - y_true_temp[y_true == i] = 1 - res = sk_roc_curve(y_true_temp, probas_pred[:, i], drop_intermediate=False) - fpr.append(res[0]) - tpr.append(res[1]) - thresholds.append(res[2]) - return fpr, tpr, thresholds - - -def _sk_roc_binary_prob(preds, target, num_classes=1): - sk_preds = preds.view(-1).numpy() - sk_target = target.view(-1).numpy() - - return _sk_roc_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes) - - -def _sk_roc_multiclass_prob(preds, target, num_classes=1): - sk_preds = preds.reshape(-1, num_classes).numpy() - sk_target = target.view(-1).numpy() - - return _sk_roc_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes) - - -def _sk_roc_multidim_multiclass_prob(preds, target, num_classes=1): - sk_preds = preds.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1).numpy() - sk_target = target.view(-1).numpy() - return _sk_roc_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes) - - -@pytest.mark.parametrize( - "preds, target, sk_metric, num_classes", [ - (_input_binary_prob.preds, _input_binary_prob.target, _sk_roc_binary_prob, 1), - (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_roc_multiclass_prob, NUM_CLASSES), - (_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_roc_multidim_multiclass_prob, NUM_CLASSES), - ] -) -class TestROC(MetricTester): - - @pytest.mark.parametrize("ddp", [True, False]) -
@pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_roc(self, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step): - self.run_class_metric_test( - ddp=ddp, - preds=preds, - target=target, - metric_class=ROC, - sk_metric=partial(sk_metric, num_classes=num_classes), - dist_sync_on_step=dist_sync_on_step, - metric_args={"num_classes": num_classes} - ) - - def test_roc_functional(self, preds, target, sk_metric, num_classes): - self.run_functional_metric_test( - preds, - target, - metric_functional=roc, - sk_metric=partial(sk_metric, num_classes=num_classes), - metric_args={"num_classes": num_classes}, - ) - - -@pytest.mark.parametrize(['pred', 'target', 'expected_tpr', 'expected_fpr'], [ - pytest.param([0, 1], [0, 1], [0, 1, 1], [0, 0, 1]), - pytest.param([1, 0], [0, 1], [0, 0, 1], [0, 1, 1]), - pytest.param([1, 1], [1, 0], [0, 1], [0, 1]), - pytest.param([1, 0], [1, 0], [0, 1, 1], [0, 0, 1]), - pytest.param([0.5, 0.5], [0, 1], [0, 1], [0, 1]), -]) -def test_roc_curve(pred, target, expected_tpr, expected_fpr): - fpr, tpr, thresh = roc(torch.tensor(pred), torch.tensor(target)) - - assert fpr.shape == tpr.shape - assert fpr.size(0) == thresh.size(0) - assert torch.allclose(fpr, torch.tensor(expected_fpr).to(fpr)) - assert torch.allclose(tpr, torch.tensor(expected_tpr).to(tpr)) diff --git a/tests/metrics/classification/test_stat_scores.py b/tests/metrics/classification/test_stat_scores.py deleted file mode 100644 index 659765931c433..0000000000000 --- a/tests/metrics/classification/test_stat_scores.py +++ /dev/null @@ -1,255 +0,0 @@ -from functools import partial -from typing import Callable, Optional - -import numpy as np -import pytest -import torch -from sklearn.metrics import multilabel_confusion_matrix - -from pytorch_lightning.metrics import StatScores -from pytorch_lightning.metrics.classification.helpers import _input_format_classification -from pytorch_lightning.metrics.functional import stat_scores -from tests.metrics.classification.inputs import _input_binary, _input_binary_prob, _input_multiclass -from tests.metrics.classification.inputs import _input_multiclass_prob as _input_mccls_prob -from tests.metrics.classification.inputs import _input_multidim_multiclass as _input_mdmc -from tests.metrics.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob -from tests.metrics.classification.inputs import _input_multilabel as _input_mcls -from tests.metrics.classification.inputs import _input_multilabel_prob as _input_mlb_prob -from tests.metrics.utils import MetricTester, NUM_CLASSES, THRESHOLD - -torch.manual_seed(42) - - -def _sk_stat_scores(preds, target, reduce, num_classes, is_multiclass, ignore_index, top_k, mdmc_reduce=None): - preds, target, _ = _input_format_classification( - preds, target, threshold=THRESHOLD, num_classes=num_classes, is_multiclass=is_multiclass, top_k=top_k - ) - sk_preds, sk_target = preds.numpy(), target.numpy() - - if reduce != "macro" and ignore_index is not None and preds.shape[1] > 1: - sk_preds = np.delete(sk_preds, ignore_index, 1) - sk_target = np.delete(sk_target, ignore_index, 1) - - if preds.shape[1] == 1 and reduce == "samples": - sk_target = sk_target.T - sk_preds = sk_preds.T - - sk_stats = multilabel_confusion_matrix( - sk_target, sk_preds, samplewise=(reduce == "samples") and preds.shape[1] != 1 - ) - - if preds.shape[1] == 1 and reduce != "samples": - sk_stats = sk_stats[[1]].reshape(-1, 4)[:, [3, 1, 0, 2]] - else: - sk_stats = sk_stats.reshape(-1, 4)[:, [3, 1, 0, 2]] - - if reduce == 
"micro": - sk_stats = sk_stats.sum(axis=0, keepdims=True) - - sk_stats = np.concatenate([sk_stats, sk_stats[:, [3]] + sk_stats[:, [0]]], 1) - - if reduce == "micro": - sk_stats = sk_stats[0] - - if reduce == "macro" and ignore_index is not None and preds.shape[1]: - sk_stats[ignore_index, :] = -1 - - return sk_stats - - -def _sk_stat_scores_mdim_mcls(preds, target, reduce, mdmc_reduce, num_classes, is_multiclass, ignore_index, top_k): - preds, target, _ = _input_format_classification( - preds, target, threshold=THRESHOLD, num_classes=num_classes, is_multiclass=is_multiclass, top_k=top_k - ) - - if mdmc_reduce == "global": - preds = torch.transpose(preds, 1, 2).reshape(-1, preds.shape[1]) - target = torch.transpose(target, 1, 2).reshape(-1, target.shape[1]) - - return _sk_stat_scores(preds, target, reduce, None, False, ignore_index, top_k) - elif mdmc_reduce == "samplewise": - scores = [] - - for i in range(preds.shape[0]): - pred_i = preds[i, ...].T - target_i = target[i, ...].T - scores_i = _sk_stat_scores(pred_i, target_i, reduce, None, False, ignore_index, top_k) - - scores.append(np.expand_dims(scores_i, 0)) - - return np.concatenate(scores) - - -@pytest.mark.parametrize( - "reduce, mdmc_reduce, num_classes, inputs, ignore_index", - [ - ["unknown", None, None, _input_binary, None], - ["micro", "unknown", None, _input_binary, None], - ["macro", None, None, _input_binary, None], - ["micro", None, None, _input_mdmc_prob, None], - ["micro", None, None, _input_binary_prob, 0], - ["micro", None, None, _input_mccls_prob, NUM_CLASSES], - ["micro", None, NUM_CLASSES, _input_mccls_prob, NUM_CLASSES], - ], -) -def test_wrong_params(reduce, mdmc_reduce, num_classes, inputs, ignore_index): - """Test a combination of parameters that are invalid and should raise an error. - - This includes invalid ``reduce`` and ``mdmc_reduce`` parameter values, not setting - ``num_classes`` when ``reduce='macro'`, not setting ``mdmc_reduce`` when inputs - are multi-dim multi-class``, setting ``ignore_index`` when inputs are binary, as well - as setting ``ignore_index`` to a value higher than the number of classes. 
- """ - with pytest.raises(ValueError): - stat_scores( - inputs.preds[0], inputs.target[0], reduce, mdmc_reduce, num_classes=num_classes, ignore_index=ignore_index - ) - - with pytest.raises(ValueError): - sts = StatScores(reduce=reduce, mdmc_reduce=mdmc_reduce, num_classes=num_classes, ignore_index=ignore_index) - sts(inputs.preds[0], inputs.target[0]) - - -def test_wrong_threshold(): - with pytest.raises(ValueError): - StatScores(threshold=1.5) - - -@pytest.mark.parametrize("ignore_index", [None, 0]) -@pytest.mark.parametrize("reduce", ["micro", "macro", "samples"]) -@pytest.mark.parametrize( - "preds, target, sk_fn, mdmc_reduce, num_classes, is_multiclass, top_k", - [ - (_input_binary_prob.preds, _input_binary_prob.target, _sk_stat_scores, None, 1, None, None), - (_input_binary.preds, _input_binary.target, _sk_stat_scores, None, 1, False, None), - (_input_mlb_prob.preds, _input_mlb_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, None), - (_input_mlb_prob.preds, _input_mlb_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, 2), - (_input_mcls.preds, _input_mcls.target, _sk_stat_scores, None, NUM_CLASSES, False, None), - (_input_mccls_prob.preds, _input_mccls_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, None), - (_input_mccls_prob.preds, _input_mccls_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, 2), - (_input_multiclass.preds, _input_multiclass.target, _sk_stat_scores, None, NUM_CLASSES, None, None), - (_input_mdmc.preds, _input_mdmc.target, _sk_stat_scores_mdim_mcls, "samplewise", NUM_CLASSES, None, None), - ( - _input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_stat_scores_mdim_mcls, "samplewise", NUM_CLASSES, None, - None - ), - (_input_mdmc.preds, _input_mdmc.target, _sk_stat_scores_mdim_mcls, "global", NUM_CLASSES, None, None), - (_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_stat_scores_mdim_mcls, "global", NUM_CLASSES, None, None), - ], -) -class TestStatScores(MetricTester): - # DDP tests temporarily disabled due to hanging issues - @pytest.mark.parametrize("ddp", [False]) - @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_stat_scores_class( - self, - ddp: bool, - dist_sync_on_step: bool, - sk_fn: Callable, - preds: torch.Tensor, - target: torch.Tensor, - reduce: str, - mdmc_reduce: Optional[str], - num_classes: Optional[int], - is_multiclass: Optional[bool], - ignore_index: Optional[int], - top_k: Optional[int], - ): - if ignore_index is not None and preds.ndim == 2: - pytest.skip("Skipping ignore_index test with binary inputs.") - - self.run_class_metric_test( - ddp=ddp, - preds=preds, - target=target, - metric_class=StatScores, - sk_metric=partial( - sk_fn, - reduce=reduce, - mdmc_reduce=mdmc_reduce, - num_classes=num_classes, - is_multiclass=is_multiclass, - ignore_index=ignore_index, - top_k=top_k, - ), - dist_sync_on_step=dist_sync_on_step, - metric_args={ - "num_classes": num_classes, - "reduce": reduce, - "mdmc_reduce": mdmc_reduce, - "threshold": THRESHOLD, - "is_multiclass": is_multiclass, - "ignore_index": ignore_index, - "top_k": top_k, - }, - check_dist_sync_on_step=True, - check_batch=True, - ) - - def test_stat_scores_fn( - self, - sk_fn: Callable, - preds: torch.Tensor, - target: torch.Tensor, - reduce: str, - mdmc_reduce: Optional[str], - num_classes: Optional[int], - is_multiclass: Optional[bool], - ignore_index: Optional[int], - top_k: Optional[int], - ): - if ignore_index is not None and preds.ndim == 2: - pytest.skip("Skipping ignore_index test with binary inputs.") - - 
self.run_functional_metric_test( - preds, - target, - metric_functional=stat_scores, - sk_metric=partial( - sk_fn, - reduce=reduce, - mdmc_reduce=mdmc_reduce, - num_classes=num_classes, - is_multiclass=is_multiclass, - ignore_index=ignore_index, - top_k=top_k, - ), - metric_args={ - "num_classes": num_classes, - "reduce": reduce, - "mdmc_reduce": mdmc_reduce, - "threshold": THRESHOLD, - "is_multiclass": is_multiclass, - "ignore_index": ignore_index, - "top_k": top_k, - }, - ) - - -_mc_k_target = torch.tensor([0, 1, 2]) -_mc_k_preds = torch.tensor([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]]) -_ml_k_target = torch.tensor([[0, 1, 0], [1, 1, 0], [0, 0, 0]]) -_ml_k_preds = torch.tensor([[0.9, 0.2, 0.75], [0.1, 0.7, 0.8], [0.6, 0.1, 0.7]]) - - -@pytest.mark.parametrize( - "k, preds, target, reduce, expected", - [ - (1, _mc_k_preds, _mc_k_target, "micro", torch.tensor([2, 1, 5, 1, 3])), - (2, _mc_k_preds, _mc_k_target, "micro", torch.tensor([3, 3, 3, 0, 3])), - (1, _ml_k_preds, _ml_k_target, "micro", torch.tensor([0, 3, 3, 3, 3])), - (2, _ml_k_preds, _ml_k_target, "micro", torch.tensor([1, 5, 1, 2, 3])), - (1, _mc_k_preds, _mc_k_target, "macro", torch.tensor([[0, 1, 1], [0, 1, 0], [2, 1, 2], [1, 0, 0], [1, 1, 1]])), - (2, _mc_k_preds, _mc_k_target, "macro", torch.tensor([[1, 1, 1], [1, 1, 1], [1, 1, 1], [0, 0, 0], [1, 1, 1]])), - (1, _ml_k_preds, _ml_k_target, "macro", torch.tensor([[0, 0, 0], [1, 0, 2], [1, 1, 1], [1, 2, 0], [1, 2, 0]])), - (2, _ml_k_preds, _ml_k_target, "macro", torch.tensor([[0, 1, 0], [2, 0, 3], [0, 1, 0], [1, 1, 0], [1, 2, 0]])), - ], -) -def test_top_k(k: int, preds: torch.Tensor, target: torch.Tensor, reduce: str, expected: torch.Tensor): - """ A simple test to check that top_k works as expected """ - - class_metric = StatScores(top_k=k, reduce=reduce, num_classes=3) - class_metric.update(preds, target) - - assert torch.equal(class_metric.compute(), expected.T) - assert torch.equal(stat_scores(preds, target, top_k=k, reduce=reduce, num_classes=3), expected.T) diff --git a/tests/metrics/functional/__init__.py b/tests/metrics/functional/__init__.py deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/tests/metrics/functional/test_classification.py b/tests/metrics/functional/test_classification.py deleted file mode 100644 index 39622c4cd3550..0000000000000 --- a/tests/metrics/functional/test_classification.py +++ /dev/null @@ -1,89 +0,0 @@ -import pytest -import torch - -from pytorch_lightning import seed_everything -from pytorch_lightning.metrics.functional.classification import dice_score -from pytorch_lightning.metrics.functional.precision_recall_curve import _binary_clf_curve -from pytorch_lightning.metrics.utils import get_num_classes, to_categorical, to_onehot - - -def test_onehot(): - test_tensor = torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]) - expected = torch.stack([ - torch.cat([torch.eye(5, dtype=int), torch.zeros((5, 5), dtype=int)]), - torch.cat([torch.zeros((5, 5), dtype=int), torch.eye(5, dtype=int)]) - ]) - - assert test_tensor.shape == (2, 5) - assert expected.shape == (2, 10, 5) - - onehot_classes = to_onehot(test_tensor, num_classes=10) - onehot_no_classes = to_onehot(test_tensor) - - assert torch.allclose(onehot_classes, onehot_no_classes) - - assert onehot_classes.shape == expected.shape - assert onehot_no_classes.shape == expected.shape - - assert torch.allclose(expected.to(onehot_no_classes), onehot_no_classes) - assert torch.allclose(expected.to(onehot_classes), onehot_classes) - - -def test_to_categorical(): - 
test_tensor = torch.stack([ - torch.cat([torch.eye(5, dtype=int), torch.zeros((5, 5), dtype=int)]), - torch.cat([torch.zeros((5, 5), dtype=int), torch.eye(5, dtype=int)]) - ]).to(torch.float) - - expected = torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]) - assert expected.shape == (2, 5) - assert test_tensor.shape == (2, 10, 5) - - result = to_categorical(test_tensor) - - assert result.shape == expected.shape - assert torch.allclose(result, expected.to(result.dtype)) - - -@pytest.mark.parametrize(['pred', 'target', 'num_classes', 'expected_num_classes'], [ - pytest.param(torch.rand(32, 10, 28, 28), torch.randint(10, (32, 28, 28)), 10, 10), - pytest.param(torch.rand(32, 10, 28, 28), torch.randint(10, (32, 28, 28)), None, 10), - pytest.param(torch.rand(32, 28, 28), torch.randint(10, (32, 28, 28)), None, 10), -]) -def test_get_num_classes(pred, target, num_classes, expected_num_classes): - assert get_num_classes(pred, target, num_classes) == expected_num_classes - - -@pytest.mark.parametrize(['sample_weight', 'pos_label', "exp_shape"], [ - pytest.param(1, 1., 42), - pytest.param(None, 1., 42), -]) -def test_binary_clf_curve(sample_weight, pos_label, exp_shape): - # TODO: move back the pred and target to test func arguments - # if you fix the array inside the function, you'd also have fix the shape, - # because when the array changes, you also have to fix the shape - seed_everything(0) - pred = torch.randint(low=51, high=99, size=(100, ), dtype=torch.float) / 100 - target = torch.tensor([0, 1] * 50, dtype=torch.int) - if sample_weight is not None: - sample_weight = torch.ones_like(pred) * sample_weight - - fps, tps, thresh = _binary_clf_curve(preds=pred, target=target, sample_weights=sample_weight, pos_label=pos_label) - - assert isinstance(tps, torch.Tensor) - assert isinstance(fps, torch.Tensor) - assert isinstance(thresh, torch.Tensor) - assert tps.shape == (exp_shape, ) - assert fps.shape == (exp_shape, ) - assert thresh.shape == (exp_shape, ) - - -@pytest.mark.parametrize(['pred', 'target', 'expected'], [ - pytest.param([[0, 0], [1, 1]], [[0, 0], [1, 1]], 1.), - pytest.param([[1, 1], [0, 0]], [[0, 0], [1, 1]], 0.), - pytest.param([[1, 1], [1, 1]], [[1, 1], [0, 0]], 2 / 3), - pytest.param([[1, 1], [0, 0]], [[1, 1], [0, 0]], 1.), -]) -def test_dice_score(pred, target, expected): - score = dice_score(torch.tensor(pred), torch.tensor(target)) - assert score == expected diff --git a/tests/metrics/functional/test_image_gradients.py b/tests/metrics/functional/test_image_gradients.py deleted file mode 100644 index 2e406793b4370..0000000000000 --- a/tests/metrics/functional/test_image_gradients.py +++ /dev/null @@ -1,109 +0,0 @@ -import pytest -import torch - -from pytorch_lightning.metrics.functional.image_gradients import image_gradients - - -def test_invalid_input_img_type(): - """Test Whether the module successfully handles invalid input data type""" - invalid_dummy_input = [[1, 2, 3], [4, 5, 6], [7, 8, 9]] - - with pytest.raises(TypeError): - image_gradients(invalid_dummy_input) - - -def test_invalid_input_ndims(): - """ - Test whether the module successfully handles invalid number of dimensions - of input tensor - """ - - BATCH_SIZE = 1 - HEIGHT = 5 - WIDTH = 5 - CHANNELS = 1 - - image = torch.arange(0, BATCH_SIZE * HEIGHT * WIDTH * CHANNELS, dtype=torch.float32) - image = torch.reshape(image, (HEIGHT, WIDTH)) - - with pytest.raises(RuntimeError): - image_gradients(image) - - -def test_multi_batch_image_gradients(): - """Test whether the module correctly calculates gradients for known input 
- with non-unity batch size.Example input-output pair taken from TF's implementation of i - mage-gradients - """ - - BATCH_SIZE = 5 - HEIGHT = 5 - WIDTH = 5 - CHANNELS = 1 - - single_channel_img = torch.arange(0, 1 * HEIGHT * WIDTH * CHANNELS, dtype=torch.float32) - single_channel_img = torch.reshape(single_channel_img, (CHANNELS, HEIGHT, WIDTH)) - image = torch.stack([single_channel_img for _ in range(BATCH_SIZE)], dim=0) - - true_dy = [ - [5., 5., 5., 5., 5.], - [5., 5., 5., 5., 5.], - [5., 5., 5., 5., 5.], - [5., 5., 5., 5., 5.], - [0., 0., 0., 0., 0.], - ] - - true_dx = [ - [1., 1., 1., 1., 0.], - [1., 1., 1., 1., 0.], - [1., 1., 1., 1., 0.], - [1., 1., 1., 1., 0.], - [1., 1., 1., 1., 0.], - ] - true_dy = torch.Tensor(true_dy) - true_dx = torch.Tensor(true_dx) - - dy, dx = image_gradients(image) - - for batch_id in range(BATCH_SIZE): - assert torch.allclose(dy[batch_id, 0, :, :], true_dy) - assert dy.shape == (BATCH_SIZE, 1, HEIGHT, WIDTH) - assert dx.shape == (BATCH_SIZE, 1, HEIGHT, WIDTH) - - -def test_image_gradients(): - """Test whether the module correctly calculates gradients for known input. - Example input-output pair taken from TF's implementation of image-gradients - """ - - BATCH_SIZE = 1 - HEIGHT = 5 - WIDTH = 5 - CHANNELS = 1 - - image = torch.arange(0, BATCH_SIZE * HEIGHT * WIDTH * CHANNELS, dtype=torch.float32) - image = torch.reshape(image, (BATCH_SIZE, CHANNELS, HEIGHT, WIDTH)) - - true_dy = [ - [5., 5., 5., 5., 5.], - [5., 5., 5., 5., 5.], - [5., 5., 5., 5., 5.], - [5., 5., 5., 5., 5.], - [0., 0., 0., 0., 0.], - ] - - true_dx = [ - [1., 1., 1., 1., 0.], - [1., 1., 1., 1., 0.], - [1., 1., 1., 1., 0.], - [1., 1., 1., 1., 0.], - [1., 1., 1., 1., 0.], - ] - - true_dy = torch.Tensor(true_dy) - true_dx = torch.Tensor(true_dx) - - dy, dx = image_gradients(image) - - assert torch.allclose(dy, true_dy), "dy fails test" - assert torch.allclose(dx, true_dx), "dx fails tests" diff --git a/tests/metrics/functional/test_nlp.py b/tests/metrics/functional/test_nlp.py deleted file mode 100644 index b8faadc16085f..0000000000000 --- a/tests/metrics/functional/test_nlp.py +++ /dev/null @@ -1,68 +0,0 @@ -import pytest -import torch -from nltk.translate.bleu_score import corpus_bleu, sentence_bleu, SmoothingFunction - -from pytorch_lightning.metrics.functional.nlp import bleu_score - -# example taken from -# https://www.nltk.org/api/nltk.translate.html?highlight=bleu%20score#nltk.translate.bleu_score.sentence_bleu -HYPOTHESIS1 = tuple( - "It is a guide to action which ensures that the military always obeys the commands of the party".split() -) -REFERENCE1 = tuple("It is a guide to action that ensures that the military will forever heed Party commands".split()) -REFERENCE2 = tuple( - "It is a guiding principle which makes the military forces always being under the command of the Party".split() -) -REFERENCE3 = tuple("It is the practical guide for the army always to heed the directions of the party".split()) - -# example taken from -# https://www.nltk.org/api/nltk.translate.html?highlight=bleu%20score#nltk.translate.bleu_score.corpus_bleu -HYP1 = "It is a guide to action which ensures that the military always obeys the commands of the party".split() -HYP2 = "he read the book because he was interested in world history".split() - -REF1A = "It is a guide to action that ensures that the military will forever heed Party commands".split() -REF1B = "It is a guiding principle which makes the military force always being under the command of the Party".split() -REF1C = "It is the practical guide for 
the army always to heed the directions of the party".split() -REF2A = "he was interested in world history because he read the book".split() - -LIST_OF_REFERENCES = [[REF1A, REF1B, REF1C], [REF2A]] -HYPOTHESES = [HYP1, HYP2] - -# https://www.nltk.org/api/nltk.translate.html?highlight=bleu%20score#nltk.translate.bleu_score.SmoothingFunction -smooth_func = SmoothingFunction().method2 - - -@pytest.mark.parametrize( - ["weights", "n_gram", "smooth_func", "smooth"], - [ - pytest.param([1], 1, None, False), - pytest.param([0.5, 0.5], 2, smooth_func, True), - pytest.param([0.333333, 0.333333, 0.333333], 3, None, False), - pytest.param([0.25, 0.25, 0.25, 0.25], 4, smooth_func, True), - ], -) -def test_bleu_score(weights, n_gram, smooth_func, smooth): - nltk_output = sentence_bleu( - [REFERENCE1, REFERENCE2, REFERENCE3], - HYPOTHESIS1, - weights=weights, - smoothing_function=smooth_func, - ) - pl_output = bleu_score([HYPOTHESIS1], [[REFERENCE1, REFERENCE2, REFERENCE3]], n_gram=n_gram, smooth=smooth) - assert torch.allclose(pl_output, torch.tensor(nltk_output)) - - nltk_output = corpus_bleu(LIST_OF_REFERENCES, HYPOTHESES, weights=weights, smoothing_function=smooth_func) - pl_output = bleu_score(HYPOTHESES, LIST_OF_REFERENCES, n_gram=n_gram, smooth=smooth) - assert torch.allclose(pl_output, torch.tensor(nltk_output)) - - -def test_bleu_empty(): - hyp = [[]] - ref = [[[]]] - assert bleu_score(hyp, ref) == torch.tensor(0.0) - - -def test_no_4_gram(): - hyps = [["My", "full", "pytorch-lightning"]] - refs = [[["My", "full", "pytorch-lightning", "test"], ["Completely", "Different"]]] - assert bleu_score(hyps, refs) == torch.tensor(0.0) diff --git a/tests/metrics/functional/test_reduction.py b/tests/metrics/functional/test_reduction.py deleted file mode 100644 index 03a34f6c5a25b..0000000000000 --- a/tests/metrics/functional/test_reduction.py +++ /dev/null @@ -1,28 +0,0 @@ -import pytest -import torch - -from pytorch_lightning.metrics.utils import class_reduce, reduce - - -def test_reduce(): - start_tensor = torch.rand(50, 40, 30) - - assert torch.allclose(reduce(start_tensor, 'elementwise_mean'), torch.mean(start_tensor)) - assert torch.allclose(reduce(start_tensor, 'sum'), torch.sum(start_tensor)) - assert torch.allclose(reduce(start_tensor, 'none'), start_tensor) - - with pytest.raises(ValueError): - reduce(start_tensor, 'error_reduction') - - -def test_class_reduce(): - num = torch.randint(1, 10, (100, )).float() - denom = torch.randint(10, 20, (100, )).float() - weights = torch.randint(1, 100, (100, )).float() - - assert torch.allclose(class_reduce(num, denom, weights, 'micro'), torch.sum(num) / torch.sum(denom)) - assert torch.allclose(class_reduce(num, denom, weights, 'macro'), torch.mean(num / denom)) - assert torch.allclose( - class_reduce(num, denom, weights, 'weighted'), torch.sum(num / denom * (weights / torch.sum(weights))) - ) - assert torch.allclose(class_reduce(num, denom, weights, 'none'), num / denom) diff --git a/tests/metrics/functional/test_self_supervised.py b/tests/metrics/functional/test_self_supervised.py deleted file mode 100644 index fbabc5e93cffc..0000000000000 --- a/tests/metrics/functional/test_self_supervised.py +++ /dev/null @@ -1,32 +0,0 @@ -import pytest -import torch -from sklearn.metrics import pairwise - -from pytorch_lightning.metrics.functional.self_supervised import embedding_similarity - - -@pytest.mark.parametrize('similarity', ['cosine', 'dot']) -@pytest.mark.parametrize('reduction', ['none', 'mean', 'sum']) -def test_against_sklearn(similarity, reduction): - 
"""Compare PL metrics to sklearn version.""" - device = 'cuda' if torch.cuda.is_available() else 'cpu' - - batch = torch.randn(5, 10, device=device) # 100 samples in 10 dimensions - - pl_dist = embedding_similarity(batch, similarity=similarity, reduction=reduction, zero_diagonal=False) - - def sklearn_embedding_distance(batch, similarity, reduction): - - metric_func = {'cosine': pairwise.cosine_similarity, 'dot': pairwise.linear_kernel}[similarity] - - dist = metric_func(batch, batch) - if reduction == 'mean': - return dist.mean(axis=-1) - if reduction == 'sum': - return dist.sum(axis=-1) - return dist - - sk_dist = sklearn_embedding_distance(batch.cpu().detach().numpy(), similarity=similarity, reduction=reduction) - sk_dist = torch.tensor(sk_dist, dtype=torch.float, device=device) - - assert torch.allclose(sk_dist, pl_dist) diff --git a/tests/metrics/regression/__init__.py b/tests/metrics/regression/__init__.py deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/tests/metrics/regression/test_explained_variance.py b/tests/metrics/regression/test_explained_variance.py deleted file mode 100644 index adab562ac6055..0000000000000 --- a/tests/metrics/regression/test_explained_variance.py +++ /dev/null @@ -1,77 +0,0 @@ -from collections import namedtuple -from functools import partial - -import pytest -import torch -from sklearn.metrics import explained_variance_score - -from pytorch_lightning.metrics.functional import explained_variance -from pytorch_lightning.metrics.regression import ExplainedVariance -from tests.metrics.utils import BATCH_SIZE, MetricTester, NUM_BATCHES - -torch.manual_seed(42) - -num_targets = 5 - -Input = namedtuple('Input', ["preds", "target"]) - -_single_target_inputs = Input( - preds=torch.rand(NUM_BATCHES, BATCH_SIZE), - target=torch.rand(NUM_BATCHES, BATCH_SIZE), -) - -_multi_target_inputs = Input( - preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), - target=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), -) - - -def _single_target_sk_metric(preds, target, sk_fn=explained_variance_score): - sk_preds = preds.view(-1).numpy() - sk_target = target.view(-1).numpy() - return sk_fn(sk_target, sk_preds) - - -def _multi_target_sk_metric(preds, target, sk_fn=explained_variance_score): - sk_preds = preds.view(-1, num_targets).numpy() - sk_target = target.view(-1, num_targets).numpy() - return sk_fn(sk_target, sk_preds) - - -@pytest.mark.parametrize("multioutput", ['raw_values', 'uniform_average', 'variance_weighted']) -@pytest.mark.parametrize( - "preds, target, sk_metric", - [ - (_single_target_inputs.preds, _single_target_inputs.target, _single_target_sk_metric), - (_multi_target_inputs.preds, _multi_target_inputs.target, _multi_target_sk_metric), - ], -) -class TestExplainedVariance(MetricTester): - - @pytest.mark.parametrize("ddp", [True, False]) - @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_explained_variance(self, multioutput, preds, target, sk_metric, ddp, dist_sync_on_step): - self.run_class_metric_test( - ddp, - preds, - target, - ExplainedVariance, - partial(sk_metric, sk_fn=partial(explained_variance_score, multioutput=multioutput)), - dist_sync_on_step, - metric_args=dict(multioutput=multioutput), - ) - - def test_explained_variance_functional(self, multioutput, preds, target, sk_metric): - self.run_functional_metric_test( - preds, - target, - explained_variance, - partial(sk_metric, sk_fn=partial(explained_variance_score, multioutput=multioutput)), - metric_args=dict(multioutput=multioutput), - ) - - -def 
test_error_on_different_shape(metric_class=ExplainedVariance): - metric = metric_class() - with pytest.raises(RuntimeError, match='Predictions and targets are expected to have the same shape'): - metric(torch.randn(100, ), torch.randn(50, )) diff --git a/tests/metrics/regression/test_mean_error.py b/tests/metrics/regression/test_mean_error.py deleted file mode 100644 index 041ce12f11164..0000000000000 --- a/tests/metrics/regression/test_mean_error.py +++ /dev/null @@ -1,87 +0,0 @@ -from collections import namedtuple -from functools import partial - -import pytest -import torch -from sklearn.metrics import mean_absolute_error as sk_mean_absolute_error -from sklearn.metrics import mean_squared_error as sk_mean_squared_error -from sklearn.metrics import mean_squared_log_error as sk_mean_squared_log_error - -from pytorch_lightning.metrics.functional import mean_absolute_error, mean_squared_error, mean_squared_log_error -from pytorch_lightning.metrics.regression import MeanAbsoluteError, MeanSquaredError, MeanSquaredLogError -from tests.metrics.utils import BATCH_SIZE, MetricTester, NUM_BATCHES - -torch.manual_seed(42) - -num_targets = 5 - -Input = namedtuple('Input', ["preds", "target"]) - -_single_target_inputs = Input( - preds=torch.rand(NUM_BATCHES, BATCH_SIZE), - target=torch.rand(NUM_BATCHES, BATCH_SIZE), -) - -_multi_target_inputs = Input( - preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), - target=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), -) - - -def _single_target_sk_metric(preds, target, sk_fn=mean_squared_error): - sk_preds = preds.view(-1).numpy() - sk_target = target.view(-1).numpy() - return sk_fn(sk_preds, sk_target) - - -def _multi_target_sk_metric(preds, target, sk_fn=mean_squared_error): - sk_preds = preds.view(-1, num_targets).numpy() - sk_target = target.view(-1, num_targets).numpy() - return sk_fn(sk_preds, sk_target) - - -@pytest.mark.parametrize( - "preds, target, sk_metric", - [ - (_single_target_inputs.preds, _single_target_inputs.target, _single_target_sk_metric), - (_multi_target_inputs.preds, _multi_target_inputs.target, _multi_target_sk_metric), - ], -) -@pytest.mark.parametrize( - "metric_class, metric_functional, sk_fn", - [ - (MeanSquaredError, mean_squared_error, sk_mean_squared_error), - (MeanAbsoluteError, mean_absolute_error, sk_mean_absolute_error), - (MeanSquaredLogError, mean_squared_log_error, sk_mean_squared_log_error), - ], -) -class TestMeanError(MetricTester): - - @pytest.mark.parametrize("ddp", [True, False]) - @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_mean_error_class( - self, preds, target, sk_metric, metric_class, metric_functional, sk_fn, ddp, dist_sync_on_step - ): - self.run_class_metric_test( - ddp=ddp, - preds=preds, - target=target, - metric_class=metric_class, - sk_metric=partial(sk_metric, sk_fn=sk_fn), - dist_sync_on_step=dist_sync_on_step, - ) - - def test_mean_error_functional(self, preds, target, sk_metric, metric_class, metric_functional, sk_fn): - self.run_functional_metric_test( - preds=preds, - target=target, - metric_functional=metric_functional, - sk_metric=partial(sk_metric, sk_fn=sk_fn), - ) - - -@pytest.mark.parametrize("metric_class", [MeanSquaredError, MeanAbsoluteError, MeanSquaredLogError]) -def test_error_on_different_shape(metric_class): - metric = metric_class() - with pytest.raises(RuntimeError, match='Predictions and targets are expected to have the same shape'): - metric(torch.randn(100, ), torch.randn(50, )) diff --git a/tests/metrics/regression/test_psnr.py 
b/tests/metrics/regression/test_psnr.py deleted file mode 100644 index eb07fffb9d55c..0000000000000 --- a/tests/metrics/regression/test_psnr.py +++ /dev/null @@ -1,133 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from collections import namedtuple -from functools import partial - -import numpy as np -import pytest -import torch -from skimage.metrics import peak_signal_noise_ratio - -from pytorch_lightning.metrics.functional import psnr -from pytorch_lightning.metrics.regression import PSNR -from tests.metrics.utils import BATCH_SIZE, MetricTester, NUM_BATCHES - -torch.manual_seed(42) - -Input = namedtuple('Input', ["preds", "target"]) - -_input_size = (NUM_BATCHES, BATCH_SIZE, 32, 32) -_inputs = [ - Input( - preds=torch.randint(n_cls_pred, _input_size, dtype=torch.float), - target=torch.randint(n_cls_target, _input_size, dtype=torch.float), - ) for n_cls_pred, n_cls_target in [(10, 10), (5, 10), (10, 5)] -] - - -def _to_sk_peak_signal_noise_ratio_inputs(value, dim): - value = value.numpy() - batches = value[None] if value.ndim == len(_input_size) - 1 else value - - if dim is None: - return [batches] - - num_dims = np.size(dim) - if not num_dims: - return batches - - inputs = [] - for batch in batches: - batch = np.moveaxis(batch, dim, np.arange(-num_dims, 0)) - psnr_input_shape = batch.shape[-num_dims:] - inputs.extend(batch.reshape(-1, *psnr_input_shape)) - return inputs - - -def _sk_psnr(preds, target, data_range, reduction, dim): - sk_preds_lists = _to_sk_peak_signal_noise_ratio_inputs(preds, dim=dim) - sk_target_lists = _to_sk_peak_signal_noise_ratio_inputs(target, dim=dim) - np_reduce_map = {"elementwise_mean": np.mean, "none": np.array, "sum": np.sum} - return np_reduce_map[reduction]([ - peak_signal_noise_ratio(sk_target, sk_preds, data_range=data_range) - for sk_target, sk_preds in zip(sk_target_lists, sk_preds_lists) - ]) - - -def _base_e_sk_psnr(preds, target, data_range, reduction, dim): - return _sk_psnr(preds, target, data_range, reduction, dim) * np.log(10) - - -@pytest.mark.parametrize( - "preds, target, data_range, reduction, dim", - [ - (_inputs[0].preds, _inputs[0].target, 10, "elementwise_mean", None), - (_inputs[1].preds, _inputs[1].target, 10, "elementwise_mean", None), - (_inputs[2].preds, _inputs[2].target, 5, "elementwise_mean", None), - (_inputs[2].preds, _inputs[2].target, 5, "elementwise_mean", 1), - (_inputs[2].preds, _inputs[2].target, 5, "elementwise_mean", (1, 2)), - (_inputs[2].preds, _inputs[2].target, 5, "sum", (1, 2)), - ], -) -@pytest.mark.parametrize( - "base, sk_metric", - [ - (10.0, _sk_psnr), - (2.718281828459045, _base_e_sk_psnr), - ], -) -class TestPSNR(MetricTester): - - @pytest.mark.parametrize("ddp", [True, False]) - @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_psnr(self, preds, target, data_range, base, reduction, dim, sk_metric, ddp, dist_sync_on_step): - _args = {"data_range": data_range, "base": base, "reduction": reduction, "dim": dim} - 
self.run_class_metric_test( - ddp, - preds, - target, - PSNR, - partial(sk_metric, data_range=data_range, reduction=reduction, dim=dim), - metric_args=_args, - dist_sync_on_step=dist_sync_on_step, - ) - - def test_psnr_functional(self, preds, target, sk_metric, data_range, base, reduction, dim): - _args = {"data_range": data_range, "base": base, "reduction": reduction, "dim": dim} - self.run_functional_metric_test( - preds, - target, - psnr, - partial(sk_metric, data_range=data_range, reduction=reduction, dim=dim), - metric_args=_args, - ) - - -@pytest.mark.parametrize("reduction", ["none", "sum"]) -def test_reduction_for_dim_none(reduction): - match = f"The `reduction={reduction}` will not have any effect when `dim` is None." - with pytest.warns(UserWarning, match=match): - PSNR(reduction=reduction, dim=None) - - with pytest.warns(UserWarning, match=match): - psnr(_inputs[0].preds, _inputs[0].target, reduction=reduction, dim=None) - - -def test_missing_data_range(): - with pytest.raises(ValueError): - PSNR(data_range=None, dim=0) - - with pytest.raises(ValueError): - psnr(_inputs[0].preds, _inputs[0].target, data_range=None, dim=0) diff --git a/tests/metrics/regression/test_r2score.py b/tests/metrics/regression/test_r2score.py deleted file mode 100644 index 232b003e6116a..0000000000000 --- a/tests/metrics/regression/test_r2score.py +++ /dev/null @@ -1,114 +0,0 @@ -from collections import namedtuple -from functools import partial - -import pytest -import torch -from sklearn.metrics import r2_score as sk_r2score - -from pytorch_lightning.metrics.functional import r2score -from pytorch_lightning.metrics.regression import R2Score -from tests.metrics.utils import BATCH_SIZE, MetricTester, NUM_BATCHES - -torch.manual_seed(42) - -num_targets = 5 - -Input = namedtuple('Input', ["preds", "target"]) - -_single_target_inputs = Input( - preds=torch.rand(NUM_BATCHES, BATCH_SIZE), - target=torch.rand(NUM_BATCHES, BATCH_SIZE), -) - -_multi_target_inputs = Input( - preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), - target=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), -) - - -def _single_target_sk_metric(preds, target, adjusted, multioutput): - sk_preds = preds.view(-1).numpy() - sk_target = target.view(-1).numpy() - r2_score = sk_r2score(sk_target, sk_preds, multioutput=multioutput) - if adjusted != 0: - r2_score = 1 - (1 - r2_score) * (sk_preds.shape[0] - 1) / (sk_preds.shape[0] - adjusted - 1) - return r2_score - - -def _multi_target_sk_metric(preds, target, adjusted, multioutput): - sk_preds = preds.view(-1, num_targets).numpy() - sk_target = target.view(-1, num_targets).numpy() - r2_score = sk_r2score(sk_target, sk_preds, multioutput=multioutput) - if adjusted != 0: - r2_score = 1 - (1 - r2_score) * (sk_preds.shape[0] - 1) / (sk_preds.shape[0] - adjusted - 1) - return r2_score - - -@pytest.mark.parametrize("adjusted", [0, 5, 10]) -@pytest.mark.parametrize("multioutput", ['raw_values', 'uniform_average', 'variance_weighted']) -@pytest.mark.parametrize( - "preds, target, sk_metric, num_outputs", - [ - (_single_target_inputs.preds, _single_target_inputs.target, _single_target_sk_metric, 1), - (_multi_target_inputs.preds, _multi_target_inputs.target, _multi_target_sk_metric, num_targets), - ], -) -class TestR2Score(MetricTester): - - @pytest.mark.parametrize("ddp", [True, False]) - @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_r2(self, adjusted, multioutput, preds, target, sk_metric, num_outputs, ddp, dist_sync_on_step): - self.run_class_metric_test( - ddp, - preds, - 
target, - R2Score, - partial(sk_metric, adjusted=adjusted, multioutput=multioutput), - dist_sync_on_step, - metric_args=dict(adjusted=adjusted, multioutput=multioutput, num_outputs=num_outputs), - ) - - def test_r2_functional(self, adjusted, multioutput, preds, target, sk_metric, num_outputs): - self.run_functional_metric_test( - preds, - target, - r2score, - partial(sk_metric, adjusted=adjusted, multioutput=multioutput), - metric_args=dict(adjusted=adjusted, multioutput=multioutput), - ) - - -def test_error_on_different_shape(metric_class=R2Score): - metric = metric_class() - with pytest.raises(RuntimeError, match='Predictions and targets are expected to have the same shape'): - metric(torch.randn(100, ), torch.randn(50, )) - - -def test_error_on_multidim_tensors(metric_class=R2Score): - metric = metric_class() - with pytest.raises( - ValueError, - match=r'Expected both prediction and target to be 1D or 2D tensors,' - r' but recevied tensors with dimension .' - ): - metric(torch.randn(10, 20, 5), torch.randn(10, 20, 5)) - - -def test_error_on_too_few_samples(metric_class=R2Score): - metric = metric_class() - with pytest.raises(ValueError, match='Needs atleast two samples to calculate r2 score.'): - metric(torch.randn(1, ), torch.randn(1, )) - - -def test_warning_on_too_large_adjusted(metric_class=R2Score): - metric = metric_class(adjusted=10) - - with pytest.warns( - UserWarning, - match="More independent regressions than datapoints in" - " adjusted r2 score. Falls back to standard r2 score." - ): - metric(torch.randn(10, ), torch.randn(10, )) - - with pytest.warns(UserWarning, match="Division by zero in adjusted r2 score. Falls back to" " standard r2 score."): - metric(torch.randn(11, ), torch.randn(11, )) diff --git a/tests/metrics/regression/test_ssim.py b/tests/metrics/regression/test_ssim.py deleted file mode 100644 index f7e4b7a58e001..0000000000000 --- a/tests/metrics/regression/test_ssim.py +++ /dev/null @@ -1,104 +0,0 @@ -from collections import namedtuple -from functools import partial - -import pytest -import torch -from skimage.metrics import structural_similarity - -from pytorch_lightning.metrics.functional import ssim -from pytorch_lightning.metrics.regression import SSIM -from tests.metrics.utils import BATCH_SIZE, MetricTester, NUM_BATCHES - -torch.manual_seed(42) - -Input = namedtuple('Input', ["preds", "target", "multichannel"]) - -_inputs = [] -for size, channel, coef, multichannel, dtype in [ - (12, 3, 0.9, True, torch.float), - (13, 1, 0.8, False, torch.float32), - (14, 1, 0.7, False, torch.double), - (15, 3, 0.6, True, torch.float64), -]: - preds = torch.rand(NUM_BATCHES, BATCH_SIZE, channel, size, size, dtype=dtype) - _inputs.append(Input( - preds=preds, - target=preds * coef, - multichannel=multichannel, - )) - - -def _sk_metric(preds, target, data_range, multichannel): - c, h, w = preds.shape[-3:] - sk_preds = preds.view(-1, c, h, w).permute(0, 2, 3, 1).numpy() - sk_target = target.view(-1, c, h, w).permute(0, 2, 3, 1).numpy() - if not multichannel: - sk_preds = sk_preds[:, :, :, 0] - sk_target = sk_target[:, :, :, 0] - - return structural_similarity( - sk_target, - sk_preds, - data_range=data_range, - multichannel=multichannel, - gaussian_weights=True, - win_size=11, - sigma=1.5, - use_sample_covariance=False - ) - - -@pytest.mark.parametrize( - "preds, target, multichannel", - [(i.preds, i.target, i.multichannel) for i in _inputs], -) -class TestSSIM(MetricTester): - atol = 6e-5 - - @pytest.mark.parametrize("ddp", [True, False]) - 
@pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_ssim(self, preds, target, multichannel, ddp, dist_sync_on_step): - self.run_class_metric_test( - ddp, - preds, - target, - SSIM, - partial(_sk_metric, data_range=1.0, multichannel=multichannel), - metric_args={"data_range": 1.0}, - dist_sync_on_step=dist_sync_on_step, - ) - - def test_ssim_functional(self, preds, target, multichannel): - self.run_functional_metric_test( - preds, - target, - ssim, - partial(_sk_metric, data_range=1.0, multichannel=multichannel), - metric_args={"data_range": 1.0}, - ) - - -@pytest.mark.parametrize( - ['pred', 'target', 'kernel', 'sigma'], - [ - pytest.param([1, 16, 16], [1, 16, 16], [11, 11], [1.5, 1.5]), # len(shape) - pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11, 11], [1.5]), # len(kernel), len(sigma) - pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11], [1.5, 1.5]), # len(kernel), len(sigma) - pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11], [1.5]), # len(kernel), len(sigma) - pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11, 0], [1.5, 1.5]), # invalid kernel input - pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11, 10], [1.5, 1.5]), # invalid kernel input - pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11, -11], [1.5, 1.5]), # invalid kernel input - pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11, 11], [1.5, 0]), # invalid sigma input - pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11, 0], [1.5, -1.5]), # invalid sigma input - ], -) -def test_ssim_invalid_inputs(pred, target, kernel, sigma): - pred_t = torch.rand(pred) - target_t = torch.rand(target, dtype=torch.float64) - with pytest.raises(TypeError): - ssim(pred_t, target_t) - - pred = torch.rand(pred) - target = torch.rand(target) - with pytest.raises(ValueError): - ssim(pred, target, kernel, sigma) diff --git a/tests/metrics/test_composition.py b/tests/metrics/test_composition.py deleted file mode 100644 index 7845e86f514ff..0000000000000 --- a/tests/metrics/test_composition.py +++ /dev/null @@ -1,510 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-from operator import neg, pos - -import pytest -import torch - -from pytorch_lightning.metrics.compositional import CompositionalMetric -from pytorch_lightning.metrics.metric import Metric -from tests.helpers.runif import RunIf - - -class DummyMetric(Metric): - - def __init__(self, val_to_return): - super().__init__() - self._num_updates = 0 - self._val_to_return = val_to_return - - def update(self, *args, **kwargs) -> None: - self._num_updates += 1 - - def compute(self): - return torch.tensor(self._val_to_return) - - def reset(self): - self._num_updates = 0 - return super().reset() - - -@pytest.mark.parametrize( - ["second_operand", "expected_result"], - [ - (DummyMetric(2), torch.tensor(4)), - (2, torch.tensor(4)), - (2.0, torch.tensor(4.0)), - (torch.tensor(2), torch.tensor(4)), - ], -) -def test_metrics_add(second_operand, expected_result): - first_metric = DummyMetric(2) - - final_add = first_metric + second_operand - final_radd = second_operand + first_metric - - assert isinstance(final_add, CompositionalMetric) - assert isinstance(final_radd, CompositionalMetric) - - assert torch.allclose(expected_result, final_add.compute()) - assert torch.allclose(expected_result, final_radd.compute()) - - -@pytest.mark.parametrize( - ["second_operand", "expected_result"], - [(DummyMetric(3), torch.tensor(2)), (3, torch.tensor(2)), (3, torch.tensor(2)), (torch.tensor(3), torch.tensor(2))], -) -@RunIf(min_torch="1.5.0") -def test_metrics_and(second_operand, expected_result): - first_metric = DummyMetric(2) - - final_and = first_metric & second_operand - final_rand = second_operand & first_metric - - assert isinstance(final_and, CompositionalMetric) - assert isinstance(final_rand, CompositionalMetric) - - assert torch.allclose(expected_result, final_and.compute()) - assert torch.allclose(expected_result, final_rand.compute()) - - -@pytest.mark.parametrize( - ["second_operand", "expected_result"], - [ - (DummyMetric(2), torch.tensor(True)), - (2, torch.tensor(True)), - (2.0, torch.tensor(True)), - (torch.tensor(2), torch.tensor(True)), - ], -) -def test_metrics_eq(second_operand, expected_result): - first_metric = DummyMetric(2) - - final_eq = first_metric == second_operand - - assert isinstance(final_eq, CompositionalMetric) - - # can't use allclose for bool tensors - assert (expected_result == final_eq.compute()).all() - - -@pytest.mark.parametrize( - ["second_operand", "expected_result"], - [ - (DummyMetric(2), torch.tensor(2)), - (2, torch.tensor(2)), - (2.0, torch.tensor(2.0)), - (torch.tensor(2), torch.tensor(2)), - ], -) -@RunIf(min_torch="1.5.0") -def test_metrics_floordiv(second_operand, expected_result): - first_metric = DummyMetric(5) - - final_floordiv = first_metric // second_operand - - assert isinstance(final_floordiv, CompositionalMetric) - - assert torch.allclose(expected_result, final_floordiv.compute()) - - -@pytest.mark.parametrize( - ["second_operand", "expected_result"], - [ - (DummyMetric(2), torch.tensor(True)), - (2, torch.tensor(True)), - (2.0, torch.tensor(True)), - (torch.tensor(2), torch.tensor(True)), - ], -) -def test_metrics_ge(second_operand, expected_result): - first_metric = DummyMetric(5) - - final_ge = first_metric >= second_operand - - assert isinstance(final_ge, CompositionalMetric) - - # can't use allclose for bool tensors - assert (expected_result == final_ge.compute()).all() - - -@pytest.mark.parametrize( - ["second_operand", "expected_result"], - [ - (DummyMetric(2), torch.tensor(True)), - (2, torch.tensor(True)), - (2.0, torch.tensor(True)), - 
(torch.tensor(2), torch.tensor(True)), - ], -) -def test_metrics_gt(second_operand, expected_result): - first_metric = DummyMetric(5) - - final_gt = first_metric > second_operand - - assert isinstance(final_gt, CompositionalMetric) - - # can't use allclose for bool tensors - assert (expected_result == final_gt.compute()).all() - - -@pytest.mark.parametrize( - ["second_operand", "expected_result"], - [ - (DummyMetric(2), torch.tensor(False)), - (2, torch.tensor(False)), - (2.0, torch.tensor(False)), - (torch.tensor(2), torch.tensor(False)), - ], -) -def test_metrics_le(second_operand, expected_result): - first_metric = DummyMetric(5) - - final_le = first_metric <= second_operand - - assert isinstance(final_le, CompositionalMetric) - - # can't use allclose for bool tensors - assert (expected_result == final_le.compute()).all() - - -@pytest.mark.parametrize( - ["second_operand", "expected_result"], - [ - (DummyMetric(2), torch.tensor(False)), - (2, torch.tensor(False)), - (2.0, torch.tensor(False)), - (torch.tensor(2), torch.tensor(False)), - ], -) -def test_metrics_lt(second_operand, expected_result): - first_metric = DummyMetric(5) - - final_lt = first_metric < second_operand - - assert isinstance(final_lt, CompositionalMetric) - - # can't use allclose for bool tensors - assert (expected_result == final_lt.compute()).all() - - -@pytest.mark.parametrize( - ["second_operand", "expected_result"], - [(DummyMetric([2, 2, 2]), torch.tensor(12)), (torch.tensor([2, 2, 2]), torch.tensor(12))], -) -def test_metrics_matmul(second_operand, expected_result): - first_metric = DummyMetric([2, 2, 2]) - - final_matmul = first_metric @ second_operand - - assert isinstance(final_matmul, CompositionalMetric) - - assert torch.allclose(expected_result, final_matmul.compute()) - - -@pytest.mark.parametrize( - ["second_operand", "expected_result"], - [ - (DummyMetric(2), torch.tensor(1)), - (2, torch.tensor(1)), - (2.0, torch.tensor(1)), - (torch.tensor(2), torch.tensor(1)), - ], -) -def test_metrics_mod(second_operand, expected_result): - first_metric = DummyMetric(5) - - final_mod = first_metric % second_operand - - assert isinstance(final_mod, CompositionalMetric) - # prevent Runtime error for PT 1.8 - Long did not match Float - assert torch.allclose(expected_result.to(float), final_mod.compute().to(float)) - - -@pytest.mark.parametrize( - ["second_operand", "expected_result"], - [ - (DummyMetric(2), torch.tensor(4)), - (2, torch.tensor(4)), - (2.0, torch.tensor(4.0)), - (torch.tensor(2), torch.tensor(4)), - ], -) -def test_metrics_mul(second_operand, expected_result): - first_metric = DummyMetric(2) - - final_mul = first_metric * second_operand - final_rmul = second_operand * first_metric - - assert isinstance(final_mul, CompositionalMetric) - assert isinstance(final_rmul, CompositionalMetric) - - assert torch.allclose(expected_result, final_mul.compute()) - assert torch.allclose(expected_result, final_rmul.compute()) - - -@pytest.mark.parametrize( - ["second_operand", "expected_result"], - [ - (DummyMetric(2), torch.tensor(False)), - (2, torch.tensor(False)), - (2.0, torch.tensor(False)), - (torch.tensor(2), torch.tensor(False)), - ], -) -def test_metrics_ne(second_operand, expected_result): - first_metric = DummyMetric(2) - - final_ne = first_metric != second_operand - - assert isinstance(final_ne, CompositionalMetric) - - # can't use allclose for bool tensors - assert (expected_result == final_ne.compute()).all() - - -@pytest.mark.parametrize( - ["second_operand", "expected_result"], - [(DummyMetric([1, 0, 
3]), torch.tensor([-1, -2, 3])), (torch.tensor([1, 0, 3]), torch.tensor([-1, -2, 3]))], -) -@RunIf(min_torch="1.5.0") -def test_metrics_or(second_operand, expected_result): - first_metric = DummyMetric([-1, -2, 3]) - - final_or = first_metric | second_operand - final_ror = second_operand | first_metric - - assert isinstance(final_or, CompositionalMetric) - assert isinstance(final_ror, CompositionalMetric) - - assert torch.allclose(expected_result, final_or.compute()) - assert torch.allclose(expected_result, final_ror.compute()) - - -@pytest.mark.parametrize( - ["second_operand", "expected_result"], - [ - pytest.param(DummyMetric(2), torch.tensor(4)), - pytest.param(2, torch.tensor(4)), - pytest.param(2.0, torch.tensor(4.0), marks=RunIf(min_torch="1.6.0")), - pytest.param(torch.tensor(2), torch.tensor(4)), - ], -) -def test_metrics_pow(second_operand, expected_result): - first_metric = DummyMetric(2) - - final_pow = first_metric**second_operand - - assert isinstance(final_pow, CompositionalMetric) - - assert torch.allclose(expected_result, final_pow.compute()) - - -@pytest.mark.parametrize( - ["first_operand", "expected_result"], - [(5, torch.tensor(2)), (5.0, torch.tensor(2.0)), (torch.tensor(5), torch.tensor(2))], -) -@RunIf(min_torch="1.5.0") -def test_metrics_rfloordiv(first_operand, expected_result): - second_operand = DummyMetric(2) - - final_rfloordiv = first_operand // second_operand - - assert isinstance(final_rfloordiv, CompositionalMetric) - assert torch.allclose(expected_result, final_rfloordiv.compute()) - - -@pytest.mark.parametrize(["first_operand", "expected_result"], [(torch.tensor([2, 2, 2]), torch.tensor(12))]) -def test_metrics_rmatmul(first_operand, expected_result): - second_operand = DummyMetric([2, 2, 2]) - - final_rmatmul = first_operand @ second_operand - - assert isinstance(final_rmatmul, CompositionalMetric) - - assert torch.allclose(expected_result, final_rmatmul.compute()) - - -@pytest.mark.parametrize(["first_operand", "expected_result"], [(torch.tensor(2), torch.tensor(2))]) -def test_metrics_rmod(first_operand, expected_result): - second_operand = DummyMetric(5) - - final_rmod = first_operand % second_operand - - assert isinstance(final_rmod, CompositionalMetric) - - assert torch.allclose(expected_result, final_rmod.compute()) - - -@pytest.mark.parametrize( - "first_operand,expected_result", - [ - pytest.param(DummyMetric(2), torch.tensor(4)), - pytest.param(2, torch.tensor(4)), - pytest.param(2.0, torch.tensor(4.0), marks=RunIf(min_torch="1.6.0")), - ], -) -def test_metrics_rpow(first_operand, expected_result): - second_operand = DummyMetric(2) - - final_rpow = first_operand**second_operand - - assert isinstance(final_rpow, CompositionalMetric) - - assert torch.allclose(expected_result, final_rpow.compute()) - - -@pytest.mark.parametrize( - ["first_operand", "expected_result"], - [ - (DummyMetric(3), torch.tensor(1)), - (3, torch.tensor(1)), - (3.0, torch.tensor(1.0)), - (torch.tensor(3), torch.tensor(1)), - ], -) -def test_metrics_rsub(first_operand, expected_result): - second_operand = DummyMetric(2) - - final_rsub = first_operand - second_operand - - assert isinstance(final_rsub, CompositionalMetric) - - assert torch.allclose(expected_result, final_rsub.compute()) - - -@pytest.mark.parametrize( - ["first_operand", "expected_result"], - [ - (DummyMetric(6), torch.tensor(2.0)), - (6, torch.tensor(2.0)), - (6.0, torch.tensor(2.0)), - (torch.tensor(6), torch.tensor(2.0)), - ], -) -@RunIf(min_torch="1.5.0") -def test_metrics_rtruediv(first_operand, 
expected_result): - second_operand = DummyMetric(3) - - final_rtruediv = first_operand / second_operand - - assert isinstance(final_rtruediv, CompositionalMetric) - - assert torch.allclose(expected_result, final_rtruediv.compute()) - - -@pytest.mark.parametrize( - ["second_operand", "expected_result"], - [ - (DummyMetric(2), torch.tensor(1)), - (2, torch.tensor(1)), - (2.0, torch.tensor(1.0)), - (torch.tensor(2), torch.tensor(1)), - ], -) -def test_metrics_sub(second_operand, expected_result): - first_metric = DummyMetric(3) - - final_sub = first_metric - second_operand - - assert isinstance(final_sub, CompositionalMetric) - - assert torch.allclose(expected_result, final_sub.compute()) - - -@pytest.mark.parametrize( - ["second_operand", "expected_result"], - [ - (DummyMetric(3), torch.tensor(2.0)), - (3, torch.tensor(2.0)), - (3.0, torch.tensor(2.0)), - (torch.tensor(3), torch.tensor(2.0)), - ], -) -@RunIf(min_torch="1.5.0") -def test_metrics_truediv(second_operand, expected_result): - first_metric = DummyMetric(6) - - final_truediv = first_metric / second_operand - - assert isinstance(final_truediv, CompositionalMetric) - - assert torch.allclose(expected_result, final_truediv.compute()) - - -@pytest.mark.parametrize( - ["second_operand", "expected_result"], - [(DummyMetric([1, 0, 3]), torch.tensor([-2, -2, 0])), (torch.tensor([1, 0, 3]), torch.tensor([-2, -2, 0]))], -) -def test_metrics_xor(second_operand, expected_result): - first_metric = DummyMetric([-1, -2, 3]) - - final_xor = first_metric ^ second_operand - final_rxor = second_operand ^ first_metric - - assert isinstance(final_xor, CompositionalMetric) - assert isinstance(final_rxor, CompositionalMetric) - - assert torch.allclose(expected_result, final_xor.compute()) - assert torch.allclose(expected_result, final_rxor.compute()) - - -def test_metrics_abs(): - first_metric = DummyMetric(-1) - - final_abs = abs(first_metric) - - assert isinstance(final_abs, CompositionalMetric) - - assert torch.allclose(torch.tensor(1), final_abs.compute()) - - -def test_metrics_invert(): - first_metric = DummyMetric(1) - - final_inverse = ~first_metric - assert isinstance(final_inverse, CompositionalMetric) - assert torch.allclose(torch.tensor(-2), final_inverse.compute()) - - -def test_metrics_neg(): - first_metric = DummyMetric(1) - - final_neg = neg(first_metric) - assert isinstance(final_neg, CompositionalMetric) - assert torch.allclose(torch.tensor(-1), final_neg.compute()) - - -def test_metrics_pos(): - first_metric = DummyMetric(-1) - - final_pos = pos(first_metric) - assert isinstance(final_pos, CompositionalMetric) - assert torch.allclose(torch.tensor(1), final_pos.compute()) - - -def test_compositional_metrics_update(): - - compos = DummyMetric(5) + DummyMetric(4) - - assert isinstance(compos, CompositionalMetric) - compos.update() - compos.update() - compos.update() - - assert isinstance(compos.metric_a, DummyMetric) - assert isinstance(compos.metric_b, DummyMetric) - - assert compos.metric_a._num_updates == 3 - assert compos.metric_b._num_updates == 3 diff --git a/tests/metrics/test_ddp.py b/tests/metrics/test_ddp.py deleted file mode 100644 index 5120cce0a0425..0000000000000 --- a/tests/metrics/test_ddp.py +++ /dev/null @@ -1,71 +0,0 @@ -import pytest -import torch - -from pytorch_lightning.metrics import Metric -from tests.helpers.runif import RunIf -from tests.metrics.test_metric import Dummy -from tests.metrics.utils import setup_ddp - -torch.manual_seed(42) - - -def _test_ddp_sum(rank, worldsize): - setup_ddp(rank, worldsize) - dummy 
= Dummy() - dummy._reductions = {"foo": torch.sum} - dummy.foo = torch.tensor(1) - - dummy._sync_dist() - assert dummy.foo == worldsize - - -def _test_ddp_cat(rank, worldsize): - setup_ddp(rank, worldsize) - dummy = Dummy() - dummy._reductions = {"foo": torch.cat} - dummy.foo = [torch.tensor([1])] - dummy._sync_dist() - assert torch.all(torch.eq(dummy.foo, torch.tensor([1, 1]))) - - -def _test_ddp_sum_cat(rank, worldsize): - setup_ddp(rank, worldsize) - dummy = Dummy() - dummy._reductions = {"foo": torch.cat, "bar": torch.sum} - dummy.foo = [torch.tensor([1])] - dummy.bar = torch.tensor(1) - dummy._sync_dist() - assert torch.all(torch.eq(dummy.foo, torch.tensor([1, 1]))) - assert dummy.bar == worldsize - - -@RunIf(skip_windows=True) -@pytest.mark.parametrize("process", [_test_ddp_cat, _test_ddp_sum, _test_ddp_sum_cat]) -def test_ddp(process): - torch.multiprocessing.spawn(process, args=(2, ), nprocs=2) - - -def _test_non_contiguous_tensors(rank, worldsize): - setup_ddp(rank, worldsize) - - class DummyMetric(Metric): - - def __init__(self): - super().__init__() - self.add_state("x", default=[], dist_reduce_fx=None) - - def update(self, x): - self.x.append(x) - - def compute(self): - x = torch.cat(self.x, dim=0) - return x.sum() - - metric = DummyMetric() - metric.update(torch.randn(10, 5)[:, 0]) - - -@RunIf(skip_windows=True) -def test_non_contiguous_tensors(): - """ Test that gather_all operation works for non contiguous tensors """ - torch.multiprocessing.spawn(_test_non_contiguous_tensors, args=(2, ), nprocs=2) diff --git a/tests/metrics/test_metric.py b/tests/metrics/test_metric.py deleted file mode 100644 index ad7b4566dc012..0000000000000 --- a/tests/metrics/test_metric.py +++ /dev/null @@ -1,395 +0,0 @@ -import pickle -from collections import OrderedDict -from distutils.version import LooseVersion - -import cloudpickle -import numpy as np -import pytest -import torch -from torch import nn - -from pytorch_lightning.metrics.metric import Metric, MetricCollection -from tests.helpers.runif import RunIf - -torch.manual_seed(42) - - -class Dummy(Metric): - name = "Dummy" - - def __init__(self): - super().__init__() - self.add_state("x", torch.tensor(0.0), dist_reduce_fx=None) - - def update(self): - pass - - def compute(self): - pass - - -class DummyList(Metric): - name = "DummyList" - - def __init__(self): - super().__init__() - self.add_state("x", list(), dist_reduce_fx=None) - - def update(self): - pass - - def compute(self): - pass - - -def test_inherit(): - Dummy() - - -def test_add_state(): - a = Dummy() - - a.add_state("a", torch.tensor(0), "sum") - assert a._reductions["a"](torch.tensor([1, 1])) == 2 - - a.add_state("b", torch.tensor(0), "mean") - assert np.allclose(a._reductions["b"](torch.tensor([1.0, 2.0])).numpy(), 1.5) - - a.add_state("c", torch.tensor(0), "cat") - assert a._reductions["c"]([torch.tensor([1]), torch.tensor([1])]).shape == (2, ) - - with pytest.raises(ValueError): - a.add_state("d1", torch.tensor(0), 'xyz') - - with pytest.raises(ValueError): - a.add_state("d2", torch.tensor(0), 42) - - with pytest.raises(ValueError): - a.add_state("d3", [torch.tensor(0)], 'sum') - - with pytest.raises(ValueError): - a.add_state("d4", 42, 'sum') - - def custom_fx(x): - return -1 - - a.add_state("e", torch.tensor(0), custom_fx) - assert a._reductions["e"](torch.tensor([1, 1])) == -1 - - -def test_add_state_persistent(): - a = Dummy() - - a.add_state("a", torch.tensor(0), "sum", persistent=True) - assert "a" in a.state_dict() - - a.add_state("b", torch.tensor(0), "sum", 
persistent=False) - - if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"): - assert "b" not in a.state_dict() - - -def test_reset(): - - class A(Dummy): - pass - - class B(DummyList): - pass - - a = A() - assert a.x == 0 - a.x = torch.tensor(5) - a.reset() - assert a.x == 0 - - b = B() - assert isinstance(b.x, list) and len(b.x) == 0 - b.x = torch.tensor(5) - b.reset() - assert isinstance(b.x, list) and len(b.x) == 0 - - -def test_update(): - - class A(Dummy): - - def update(self, x): - self.x += x - - a = A() - assert a.x == 0 - assert a._computed is None - a.update(1) - assert a._computed is None - assert a.x == 1 - a.update(2) - assert a.x == 3 - assert a._computed is None - - -def test_compute(): - - class A(Dummy): - - def update(self, x): - self.x += x - - def compute(self): - return self.x - - a = A() - assert 0 == a.compute() - assert 0 == a.x - a.update(1) - assert a._computed is None - assert a.compute() == 1 - assert a._computed == 1 - a.update(2) - assert a._computed is None - assert a.compute() == 3 - assert a._computed == 3 - - # called without update, should return cached value - a._computed = 5 - assert a.compute() == 5 - - -def test_hash(): - - class A(Dummy): - pass - - class B(DummyList): - pass - - a1 = A() - a2 = A() - assert hash(a1) != hash(a2) - - b1 = B() - b2 = B() - assert hash(b1) == hash(b2) - assert isinstance(b1.x, list) and len(b1.x) == 0 - b1.x.append(torch.tensor(5)) - assert isinstance(hash(b1), int) # <- check that nothing crashes - assert isinstance(b1.x, list) and len(b1.x) == 1 - b2.x.append(torch.tensor(5)) - # Sanity: - assert isinstance(b2.x, list) and len(b2.x) == 1 - # Now that they have tensor contents, they should have different hashes: - assert hash(b1) != hash(b2) - - -def test_forward(): - - class A(Dummy): - - def update(self, x): - self.x += x - - def compute(self): - return self.x - - a = A() - assert a(5) == 5 - assert a._forward_cache == 5 - - assert a(8) == 8 - assert a._forward_cache == 8 - - assert a.compute() == 13 - - -class DummyMetric1(Dummy): - - def update(self, x): - self.x += x - - def compute(self): - return self.x - - -class DummyMetric2(Dummy): - - def update(self, y): - self.x -= y - - def compute(self): - return self.x - - -def test_pickle(tmpdir): - # doesn't tests for DDP - a = DummyMetric1() - a.update(1) - - metric_pickled = pickle.dumps(a) - metric_loaded = pickle.loads(metric_pickled) - - assert metric_loaded.compute() == 1 - - metric_loaded.update(5) - assert metric_loaded.compute() == 6 - - metric_pickled = cloudpickle.dumps(a) - metric_loaded = cloudpickle.loads(metric_pickled) - - assert metric_loaded.compute() == 1 - - -def test_state_dict(tmpdir): - """ test that metric states can be removed and added to state dict """ - metric = Dummy() - assert metric.state_dict() == OrderedDict() - metric.persistent(True) - assert metric.state_dict() == OrderedDict(x=0) - metric.persistent(False) - assert metric.state_dict() == OrderedDict() - - -def test_child_metric_state_dict(): - """ test that child metric states will be added to parent state dict """ - - class TestModule(nn.Module): - - def __init__(self): - super().__init__() - self.metric = Dummy() - self.metric.add_state('a', torch.tensor(0), persistent=True) - self.metric.add_state('b', [], persistent=True) - self.metric.register_buffer('c', torch.tensor(0)) - - module = TestModule() - expected_state_dict = { - 'metric.a': torch.tensor(0), - 'metric.b': [], - 'metric.c': torch.tensor(0), - } - assert module.state_dict() == expected_state_dict - - 
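The surrounding deleted tests exercise the low-level Metric state API (add_state, persistent, state_dict). A minimal sketch of that pattern, assuming the pytorch_lightning.metrics.Metric base class these tests import; the RunningSum class and its "total" state are hypothetical and only illustrate how a persistent state ends up in state_dict():

import torch

from pytorch_lightning.metrics import Metric


class RunningSum(Metric):
    # hypothetical metric, used only for illustration

    def __init__(self):
        super().__init__()
        # persistent=True registers "total" as a persistent buffer,
        # so it shows up in state_dict() and in checkpoints
        self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum", persistent=True)

    def update(self, x: torch.Tensor):
        self.total += x.sum()

    def compute(self):
        return self.total


metric = RunningSum()
metric.update(torch.tensor([1.0, 2.0]))
assert metric.compute() == 3.0
assert "total" in metric.state_dict()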
-@RunIf(min_gpus=1) -def test_device_and_dtype_transfer(tmpdir): - metric = DummyMetric1() - assert metric.x.is_cuda is False - assert metric.x.dtype == torch.float32 - - metric = metric.to(device='cuda') - assert metric.x.is_cuda - - metric = metric.double() - assert metric.x.dtype == torch.float64 - - metric = metric.half() - assert metric.x.dtype == torch.float16 - - -def test_metric_collection(tmpdir): - m1 = DummyMetric1() - m2 = DummyMetric2() - - metric_collection = MetricCollection([m1, m2]) - - # Test correct dict structure - assert len(metric_collection) == 2 - assert metric_collection['DummyMetric1'] == m1 - assert metric_collection['DummyMetric2'] == m2 - - # Test correct initialization - for name, metric in metric_collection.items(): - assert metric.x == 0, f'Metric {name} not initialized correctly' - - # Test every metric gets updated - metric_collection.update(5) - for name, metric in metric_collection.items(): - assert metric.x.abs() == 5, f'Metric {name} not updated correctly' - - # Test compute on each metric - metric_collection.update(-5) - metric_vals = metric_collection.compute() - assert len(metric_vals) == 2 - for name, metric_val in metric_vals.items(): - assert metric_val == 0, f'Metric {name}.compute not called correctly' - - # Test that everything is reset - for name, metric in metric_collection.items(): - assert metric.x == 0, f'Metric {name} not reset correctly' - - # Test pickable - metric_pickled = pickle.dumps(metric_collection) - metric_loaded = pickle.loads(metric_pickled) - assert isinstance(metric_loaded, MetricCollection) - - -@RunIf(min_gpus=1) -def test_device_and_dtype_transfer_metriccollection(tmpdir): - m1 = DummyMetric1() - m2 = DummyMetric2() - - metric_collection = MetricCollection([m1, m2]) - for _, metric in metric_collection.items(): - assert metric.x.is_cuda is False - assert metric.x.dtype == torch.float32 - - metric_collection = metric_collection.to(device='cuda') - for _, metric in metric_collection.items(): - assert metric.x.is_cuda - - metric_collection = metric_collection.double() - for _, metric in metric_collection.items(): - assert metric.x.dtype == torch.float64 - - metric_collection = metric_collection.half() - for _, metric in metric_collection.items(): - assert metric.x.dtype == torch.float16 - - -def test_metric_collection_wrong_input(tmpdir): - """ Check that errors are raised on wrong input """ - m1 = DummyMetric1() - - # Not all input are metrics (list) - with pytest.raises(ValueError): - _ = MetricCollection([m1, 5]) - - # Not all input are metrics (dict) - with pytest.raises(ValueError): - _ = MetricCollection({'metric1': m1, 'metric2': 5}) - - # Same metric passed in multiple times - with pytest.raises(ValueError, match='Encountered two metrics both named *.'): - _ = MetricCollection([m1, m1]) - - # Not a list or dict passed in - with pytest.raises(ValueError, match='Unknown input to MetricCollection.'): - _ = MetricCollection(m1) - - -def test_metric_collection_args_kwargs(tmpdir): - """ Check that args and kwargs gets passed correctly in metric collection, - Checks both update and forward method - """ - m1 = DummyMetric1() - m2 = DummyMetric2() - - metric_collection = MetricCollection([m1, m2]) - - # args gets passed to all metrics - metric_collection.update(5) - assert metric_collection['DummyMetric1'].x == 5 - assert metric_collection['DummyMetric2'].x == -5 - metric_collection.reset() - _ = metric_collection(5) - assert metric_collection['DummyMetric1'].x == 5 - assert metric_collection['DummyMetric2'].x == -5 - 
metric_collection.reset() - - # kwargs gets only passed to metrics that it matches - metric_collection.update(x=10, y=20) - assert metric_collection['DummyMetric1'].x == 10 - assert metric_collection['DummyMetric2'].x == -20 - metric_collection.reset() - _ = metric_collection(x=10, y=20) - assert metric_collection['DummyMetric1'].x == 10 - assert metric_collection['DummyMetric2'].x == -20 diff --git a/tests/metrics/test_metric_lightning.py b/tests/metrics/test_metric_lightning.py index 895305fa9da7e..e52e39cb16488 100644 --- a/tests/metrics/test_metric_lightning.py +++ b/tests/metrics/test_metric_lightning.py @@ -1,11 +1,13 @@ import torch +from torchmetrics import Metric as TMetric from pytorch_lightning import Trainer -from pytorch_lightning.metrics import Metric, MetricCollection +from pytorch_lightning.metrics import Metric as PLMetric +from pytorch_lightning.metrics import MetricCollection from tests.helpers.boring_model import BoringModel -class SumMetric(Metric): +class SumMetric(TMetric): def __init__(self): super().__init__() @@ -18,7 +20,7 @@ def compute(self): return self.x -class DiffMetric(Metric): +class DiffMetric(PLMetric): def __init__(self): super().__init__() diff --git a/tests/metrics/test_remove_1-5_metrics.py b/tests/metrics/test_remove_1-5_metrics.py new file mode 100644 index 0000000000000..d3703bf3691c9 --- /dev/null +++ b/tests/metrics/test_remove_1-5_metrics.py @@ -0,0 +1,348 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Test deprecated functionality which will be removed in v1.5.0""" + +import pytest +import torch + +from pytorch_lightning.metrics import ( + Accuracy, + AUC, + AUROC, + AveragePrecision, + ConfusionMatrix, + ExplainedVariance, + F1, + FBeta, + HammingDistance, + IoU, + MeanAbsoluteError, + MeanSquaredError, + MeanSquaredLogError, + MetricCollection, + Precision, + PrecisionRecallCurve, + PSNR, + R2Score, + Recall, + ROC, + SSIM, + StatScores, +) +from pytorch_lightning.metrics.functional import ( + auc, + auroc, + average_precision, + bleu_score, + confusion_matrix, + embedding_similarity, + explained_variance, + f1, + fbeta, + hamming_distance, + iou, + mean_absolute_error, + mean_squared_error, + mean_squared_log_error, + precision, + precision_recall, + precision_recall_curve, + psnr, + r2score, + recall, + roc, + ssim, + stat_scores, +) +from pytorch_lightning.metrics.functional.accuracy import accuracy +from pytorch_lightning.metrics.functional.mean_relative_error import mean_relative_error +from pytorch_lightning.metrics.utils import get_num_classes, select_topk, to_categorical, to_onehot + + +def test_v1_5_metrics_utils(): + x = torch.tensor([1, 2, 3]) + with pytest.deprecated_call(match="It will be removed in v1.5.0"): + assert torch.equal(to_onehot(x), torch.Tensor([[0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]).to(int)) + + with pytest.deprecated_call(match="It will be removed in v1.5.0"): + assert get_num_classes(torch.tensor([1, 2, 3]), torch.tensor([1, 2, 0])) == 4 + + x = torch.tensor([[1.1, 2.0, 3.0], [2.0, 1.0, 0.5]]) + with pytest.deprecated_call(match="It will be removed in v1.5.0"): + assert torch.equal(select_topk(x, topk=2), torch.Tensor([[0, 1, 1], [1, 1, 0]]).to(torch.int32)) + + x = torch.tensor([[0.2, 0.5], [0.9, 0.1]]) + with pytest.deprecated_call(match="It will be removed in v1.5.0"): + assert torch.equal(to_categorical(x), torch.Tensor([1, 0]).to(int)) + + +def test_v1_5_metrics_collection(): + target = torch.tensor([0, 2, 0, 2, 0, 1, 0, 2]) + preds = torch.tensor([2, 1, 2, 0, 1, 2, 2, 2]) + + MetricCollection.__init__._warned = False + with pytest.deprecated_call(match="It will be removed in v1.5.0."): + metrics = MetricCollection([Accuracy()]) + assert metrics(preds, target) == {'Accuracy': torch.tensor(0.1250)} + + +def test_v1_5_metric_accuracy(): + accuracy._warned = False + + preds = torch.tensor([0, 0, 1, 0, 1]) + target = torch.tensor([0, 0, 1, 1, 1]) + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + assert accuracy(preds, target) == torch.tensor(0.8) + + Accuracy.__init__._warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + Accuracy() + + +def test_v1_5_metric_auc_auroc(): + AUC.__init__._warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + AUC() + + ROC.__init__._warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + ROC() + + AUROC.__init__._warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + AUROC() + + x = torch.tensor([0, 1, 2, 3]) + y = torch.tensor([0, 1, 2, 2]) + auc._warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + assert auc(x, y) == torch.tensor(4.) 
+ + preds = torch.tensor([0, 1, 2, 3]) + target = torch.tensor([0, 1, 1, 1]) + roc._warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + fpr, tpr, thrs = roc(preds, target, pos_label=1) + assert torch.equal(fpr, torch.tensor([0., 0., 0., 0., 1.])) + assert torch.allclose(tpr, torch.tensor([0.0000, 0.3333, 0.6667, 1.0000, 1.0000]), atol=1e-4) + assert torch.equal(thrs, torch.tensor([4, 3, 2, 1, 0])) + + preds = torch.tensor([0.13, 0.26, 0.08, 0.19, 0.34]) + target = torch.tensor([0, 0, 1, 1, 1]) + auroc._warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + assert auroc(preds, target) == torch.tensor(0.5) + + +def test_v1_5_metric_precision_recall(): + AveragePrecision.__init__._warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + AveragePrecision() + + Precision.__init__._warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + Precision() + + Recall.__init__._warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + Recall() + + PrecisionRecallCurve.__init__._warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + PrecisionRecallCurve() + + pred = torch.tensor([0, 1, 2, 3]) + target = torch.tensor([0, 1, 1, 1]) + average_precision._warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + assert average_precision(pred, target) == torch.tensor(1.) + + precision._warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + assert precision(pred, target) == torch.tensor(0.5) + + recall._warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + assert recall(pred, target) == torch.tensor(0.5) + + precision_recall._warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + prec, rc = precision_recall(pred, target) + assert prec == torch.tensor(0.5) + assert rc == torch.tensor(0.5) + + precision_recall_curve._warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + prec, rc, thrs = precision_recall_curve(pred, target) + assert torch.equal(prec, torch.tensor([1., 1., 1., 1.])) + assert torch.allclose(rc, torch.tensor([1., 0.6667, 0.3333, 0.]), atol=1e-4) + assert torch.equal(thrs, torch.tensor([1, 2, 3])) + + +def test_v1_5_metric_classif_mix(): + ConfusionMatrix.__init__._warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + ConfusionMatrix(num_classes=1) + + FBeta.__init__._warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + FBeta(num_classes=1) + + F1.__init__._warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + F1(num_classes=1) + + HammingDistance.__init__._warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + HammingDistance() + + StatScores.__init__._warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + StatScores() + + target = torch.tensor([1, 1, 0, 0]) + preds = torch.tensor([0, 1, 0, 0]) + confusion_matrix._warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + assert torch.equal(confusion_matrix(preds, target, num_classes=2), torch.tensor([[2., 0.], [1., 1.]])) + + target = torch.tensor([0, 1, 2, 0, 1, 2]) + preds = torch.tensor([0, 2, 1, 0, 0, 1]) + fbeta._warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + assert torch.allclose(fbeta(preds, 
target, num_classes=3, beta=0.5), torch.tensor(0.3333), atol=1e-4) + + f1._warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + assert torch.allclose(f1(preds, target, num_classes=3), torch.tensor(0.3333), atol=1e-4) + + target = torch.tensor([[0, 1], [1, 1]]) + preds = torch.tensor([[0, 1], [0, 1]]) + hamming_distance._warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + assert hamming_distance(preds, target) == torch.tensor(0.25) + + preds = torch.tensor([1, 0, 2, 1]) + target = torch.tensor([1, 1, 2, 0]) + stat_scores._warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + assert torch.equal(stat_scores(preds, target, reduce='micro'), torch.tensor([2, 2, 6, 2, 4])) + + +def test_v1_5_metric_detect(): + IoU.__init__._warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + IoU(num_classes=1) + + target = torch.randint(0, 2, (10, 25, 25)) + preds = torch.tensor(target) + preds[2:5, 7:13, 9:15] = 1 - preds[2:5, 7:13, 9:15] + iou._warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + res = iou(preds, target) + assert torch.allclose(res, torch.tensor(0.9660), atol=1e-4) + + +def test_v1_5_metric_regress(): + ExplainedVariance.__init__._warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + ExplainedVariance() + + MeanAbsoluteError.__init__._warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + MeanAbsoluteError() + + MeanSquaredError.__init__._warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + MeanSquaredError() + + MeanSquaredLogError.__init__._warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + MeanSquaredLogError() + + target = torch.tensor([3, -0.5, 2, 7]) + preds = torch.tensor([2.5, 0.0, 2, 8]) + explained_variance._warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + res = explained_variance(preds, target) + assert torch.allclose(res, torch.tensor(0.9572), atol=1e-4) + + x = torch.tensor([0., 1, 2, 3]) + y = torch.tensor([0., 1, 2, 2]) + mean_absolute_error._warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + assert mean_absolute_error(x, y) == 0.25 + + mean_relative_error._warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + assert mean_relative_error(x, y) == 0.125 + + mean_squared_error._warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + assert mean_squared_error(x, y) == 0.25 + + mean_squared_log_error._warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + res = mean_squared_log_error(x, y) + assert torch.allclose(res, torch.tensor(0.0207), atol=1e-4) + + PSNR.__init__._warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + PSNR() + + R2Score.__init__._warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + R2Score() + + SSIM.__init__._warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + SSIM() + + preds = torch.tensor([[0.0, 1.0], [2.0, 3.0]]) + target = torch.tensor([[3.0, 2.0], [1.0, 0.0]]) + psnr._warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + res = psnr(preds, target) + assert torch.allclose(res, torch.tensor(2.5527), atol=1e-4) + + target = torch.tensor([3, -0.5, 2, 7]) + preds = 
torch.tensor([2.5, 0.0, 2, 8]) + r2score._warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + res = r2score(preds, target) + assert torch.allclose(res, torch.tensor(0.9486), atol=1e-4) + + preds = torch.rand([16, 1, 16, 16]) + target = preds * 0.75 + ssim._warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + res = ssim(preds, target) + assert torch.allclose(res, torch.tensor(0.9219), atol=1e-4) + + +def test_v1_5_metric_others(): + translate_corpus = ['the cat is on the mat'.split()] + reference_corpus = [['there is a cat on the mat'.split(), 'a cat is on the mat'.split()]] + bleu_score._warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + res = bleu_score(translate_corpus, reference_corpus) + assert torch.allclose(res, torch.tensor(0.7598), atol=1e-4) + + embeddings = torch.tensor([[1., 2., 3., 4.], [1., 2., 3., 4.], [4., 5., 6., 7.]]) + embedding_similarity._warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + res = embedding_similarity(embeddings) + assert torch.allclose( + res, torch.tensor([[0.0000, 1.0000, 0.9759], [1.0000, 0.0000, 0.9759], [0.9759, 0.9759, 0.0000]]), atol=1e-4 + ) diff --git a/tests/metrics/utils.py b/tests/metrics/utils.py index 4bd6608ce3fcf..f1f17d0624936 100644 --- a/tests/metrics/utils.py +++ b/tests/metrics/utils.py @@ -8,8 +8,7 @@ import pytest import torch from torch.multiprocessing import Pool, set_start_method - -from pytorch_lightning.metrics import Metric +from torchmetrics import Metric try: set_start_method("spawn") diff --git a/tests/models/data/horovod/train_default_model.py b/tests/models/data/horovod/train_default_model.py index d3868cfd979e6..46ab64afccb03 100644 --- a/tests/models/data/horovod/train_default_model.py +++ b/tests/models/data/horovod/train_default_model.py @@ -21,6 +21,8 @@ import os import sys +import torch + # this is needed because Conda does not use `PYTHONPATH` env var while pip and virtualenv do PYTHONPATH = os.getenv('PYTHONPATH', '') if ':' in PYTHONPATH: @@ -52,8 +54,13 @@ def run_test_from_config(trainer_options, on_gpu, check_size=True): ckpt_path = trainer_options['weights_save_path'] trainer_options.update(callbacks=[ModelCheckpoint(dirpath=ckpt_path)]) - model = BoringModel() + class TestModel(BoringModel): + + def training_epoch_end(self, outputs) -> None: + res = self.trainer.training_type_plugin.reduce(torch.tensor(1., device=self.device), reduce_op="sum") + assert res.sum() == self.trainer.training_type_plugin.world_size + model = TestModel() trainer = Trainer(**trainer_options) trainer.fit(model) assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" diff --git a/tests/models/test_amp.py b/tests/models/test_amp.py index 9853db342436b..0b9d6776c1aaa 100644 --- a/tests/models/test_amp.py +++ b/tests/models/test_amp.py @@ -17,24 +17,43 @@ import pytest import torch from torch import optim +from torch.utils.data import DataLoader import tests.helpers.utils as tutils from pytorch_lightning import Trainer from pytorch_lightning.plugins.environments import SLURMEnvironment from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.utilities.exceptions import MisconfigurationException -from tests.helpers import BoringModel +from tests.helpers import BoringModel, RandomDataset from tests.helpers.runif import RunIf class AMPTestModel(BoringModel): - def training_step(self, batch, batch_idx): + def _step(self, batch, batch_idx): assert 
torch.is_autocast_enabled() output = self(batch) assert output.dtype == torch.float16 loss = self.loss(batch, output) - return {"loss": loss} + return loss + + def training_step(self, batch, batch_idx): + output = self._step(batch, batch_idx) + return {"loss": output} + + def validation_step(self, batch, batch_idx): + output = self._step(batch, batch_idx) + return {"x": output} + + def test_step(self, batch, batch_idx): + output = self._step(batch, batch_idx) + return {"y": output} + + def predict(self, batch, batch_idx, dataloader_idx=None): + assert torch.is_autocast_enabled() + output = self(batch) + assert output.dtype == torch.float16 + return output @pytest.mark.skip(reason='dp + amp not supported currently') # TODO @@ -54,6 +73,8 @@ def test_amp_single_gpu_dp(tmpdir): model = AMPTestModel() # tutils.run_model_test(trainer_options, model) trainer.fit(model) + trainer.test(model) + trainer.predict(model, DataLoader(RandomDataset(32, 64))) assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" @@ -73,6 +94,8 @@ def test_amp_single_gpu_ddp_spawn(tmpdir): model = AMPTestModel() # tutils.run_model_test(trainer_options, model) trainer.fit(model) + trainer.test(model) + trainer.predict(model, DataLoader(RandomDataset(32, 64))) assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" @@ -112,6 +135,8 @@ def test_amp_multi_gpu_ddp_spawn(tmpdir): model = AMPTestModel() # tutils.run_model_test(trainer_options, model) trainer.fit(model) + trainer.test(model) + trainer.predict(model, DataLoader(RandomDataset(32, 64))) assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 0d1c7cf40a2bf..1d55d4a5a63b7 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import inspect from unittest import mock from unittest.mock import PropertyMock @@ -20,7 +19,7 @@ from pytorch_lightning import Callback, Trainer from pytorch_lightning.trainer.states import TrainerState -from tests.helpers import BoringModel, RandomDataset +from tests.helpers import BoringModel, RandomDataset, BoringDataModule from tests.helpers.runif import RunIf @@ -260,7 +259,7 @@ def on_train_batch_start(self, batch, batch_idx, dataloader_idx): def test_trainer_model_hook_system(tmpdir): - """Test the hooks system.""" + """Test the LightningModule hook system.""" class HookedModel(BoringModel): @@ -269,149 +268,151 @@ def __init__(self): self.called = [] def on_after_backward(self): - self.called.append(inspect.currentframe().f_code.co_name) + self.called.append("on_after_backward") super().on_after_backward() - def on_before_zero_grad(self, optimizer): - self.called.append(inspect.currentframe().f_code.co_name) - super().on_before_zero_grad(optimizer) + def on_before_zero_grad(self, *args, **kwargs): + self.called.append("on_before_zero_grad") + super().on_before_zero_grad(*args, **kwargs) def on_epoch_start(self): - self.called.append(inspect.currentframe().f_code.co_name) + self.called.append("on_epoch_start") super().on_epoch_start() def on_epoch_end(self): - self.called.append(inspect.currentframe().f_code.co_name) + self.called.append("on_epoch_end") super().on_epoch_end() def on_fit_start(self): - self.called.append(inspect.currentframe().f_code.co_name) + self.called.append("on_fit_start") super().on_fit_start() def on_fit_end(self): - self.called.append(inspect.currentframe().f_code.co_name) + self.called.append("on_fit_end") super().on_fit_end() - def on_hpc_load(self, checkpoint): - self.called.append(inspect.currentframe().f_code.co_name) - super().on_hpc_load(checkpoint) + def on_hpc_load(self, *args, **kwargs): + self.called.append("on_hpc_load") + super().on_hpc_load(*args, **kwargs) - def on_hpc_save(self, checkpoint): - self.called.append(inspect.currentframe().f_code.co_name) - super().on_hpc_save(checkpoint) + def on_hpc_save(self, *args, **kwargs): + self.called.append("on_hpc_save") + super().on_hpc_save(*args, **kwargs) - def on_load_checkpoint(self, checkpoint): - self.called.append(inspect.currentframe().f_code.co_name) - super().on_load_checkpoint(checkpoint) + def on_load_checkpoint(self, *args, **kwargs): + self.called.append("on_load_checkpoint") + super().on_load_checkpoint(*args, **kwargs) - def on_save_checkpoint(self, checkpoint): - self.called.append(inspect.currentframe().f_code.co_name) - super().on_save_checkpoint(checkpoint) + def on_save_checkpoint(self, *args, **kwargs): + self.called.append("on_save_checkpoint") + super().on_save_checkpoint(*args, **kwargs) def on_pretrain_routine_start(self): - self.called.append(inspect.currentframe().f_code.co_name) + self.called.append("on_pretrain_routine_start") super().on_pretrain_routine_start() def on_pretrain_routine_end(self): - self.called.append(inspect.currentframe().f_code.co_name) + self.called.append("on_pretrain_routine_end") super().on_pretrain_routine_end() def on_train_start(self): - self.called.append(inspect.currentframe().f_code.co_name) + self.called.append("on_train_start") super().on_train_start() def on_train_end(self): - self.called.append(inspect.currentframe().f_code.co_name) + self.called.append("on_train_end") super().on_train_end() - def on_train_batch_start(self, batch, batch_idx, dataloader_idx): - self.called.append(inspect.currentframe().f_code.co_name) - 
super().on_train_batch_start(batch, batch_idx, dataloader_idx) + def on_train_batch_start(self, *args, **kwargs): + self.called.append("on_train_batch_start") + super().on_train_batch_start(*args, **kwargs) - def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx): - self.called.append(inspect.currentframe().f_code.co_name) - super().on_train_batch_end(outputs, batch, batch_idx, dataloader_idx) + def on_train_batch_end(self, *args, **kwargs): + self.called.append("on_train_batch_end") + super().on_train_batch_end(*args, **kwargs) def on_train_epoch_start(self): - self.called.append(inspect.currentframe().f_code.co_name) + self.called.append("on_train_epoch_start") super().on_train_epoch_start() def on_train_epoch_end(self, outputs): - self.called.append(inspect.currentframe().f_code.co_name) + self.called.append("on_train_epoch_end") super().on_train_epoch_end(outputs) def on_validation_start(self): - self.called.append(inspect.currentframe().f_code.co_name) + self.called.append("on_validation_start") super().on_validation_start() def on_validation_end(self): - self.called.append(inspect.currentframe().f_code.co_name) + self.called.append("on_validation_end") super().on_validation_end() - def on_validation_batch_start(self, batch, batch_idx, dataloader_idx): - self.called.append(inspect.currentframe().f_code.co_name) - super().on_validation_batch_start(batch, batch_idx, dataloader_idx) + def on_validation_batch_start(self, *args, **kwargs): + self.called.append("on_validation_batch_start") + super().on_validation_batch_start(*args, **kwargs) - def on_validation_batch_end(self, outputs, batch, batch_idx, dataloader_idx): - self.called.append(inspect.currentframe().f_code.co_name) - super().on_validation_batch_end(outputs, batch, batch_idx, dataloader_idx) + def on_validation_batch_end(self, *args, **kwargs): + self.called.append("on_validation_batch_end") + super().on_validation_batch_end(*args, **kwargs) def on_validation_epoch_start(self): - self.called.append(inspect.currentframe().f_code.co_name) + self.called.append("on_validation_epoch_start") super().on_validation_epoch_start() - def on_validation_epoch_end(self): - self.called.append(inspect.currentframe().f_code.co_name) - super().on_validation_epoch_end() + def on_validation_epoch_end(self, *args, **kwargs): + self.called.append("on_validation_epoch_end") + super().on_validation_epoch_end(*args, **kwargs) def on_test_start(self): - self.called.append(inspect.currentframe().f_code.co_name) + self.called.append("on_test_start") super().on_test_start() - def on_test_batch_start(self, batch, batch_idx, dataloader_idx): - self.called.append(inspect.currentframe().f_code.co_name) - super().on_test_batch_start(batch, batch_idx, dataloader_idx) + def on_test_batch_start(self, *args, **kwargs): + self.called.append("on_test_batch_start") + super().on_test_batch_start(*args, **kwargs) - def on_test_batch_end(self, outputs, batch, batch_idx, dataloader_idx): - self.called.append(inspect.currentframe().f_code.co_name) - super().on_test_batch_end(outputs, batch, batch_idx, dataloader_idx) + def on_test_batch_end(self, *args, **kwargs): + self.called.append("on_test_batch_end") + super().on_test_batch_end(*args, **kwargs) def on_test_epoch_start(self): - self.called.append(inspect.currentframe().f_code.co_name) + self.called.append("on_test_epoch_start") super().on_test_epoch_start() - def on_test_epoch_end(self): - self.called.append(inspect.currentframe().f_code.co_name) - super().on_test_epoch_end() + def 
on_test_epoch_end(self, *args, **kwargs): + self.called.append("on_test_epoch_end") + super().on_test_epoch_end(*args, **kwargs) def on_validation_model_eval(self): - self.called.append(inspect.currentframe().f_code.co_name) + self.called.append("on_validation_model_eval") super().on_validation_model_eval() def on_validation_model_train(self): - self.called.append(inspect.currentframe().f_code.co_name) + self.called.append("on_validation_model_train") super().on_validation_model_train() def on_test_model_eval(self): - self.called.append(inspect.currentframe().f_code.co_name) + self.called.append("on_test_model_eval") super().on_test_model_eval() def on_test_model_train(self): - self.called.append(inspect.currentframe().f_code.co_name) + self.called.append("on_test_model_train") super().on_test_model_train() def on_test_end(self): - self.called.append(inspect.currentframe().f_code.co_name) + self.called.append("on_test_end") super().on_test_end() + def setup(self, stage=None): + self.called.append(f"setup_{stage}") + super().setup(stage=stage) + def teardown(self, stage=None): - self.called.append(inspect.currentframe().f_code.co_name) + self.called.append(f"teardown_{stage}") super().teardown(stage) model = HookedModel() - assert model.called == [] - # fit model trainer = Trainer( default_root_dir=tmpdir, @@ -427,11 +428,13 @@ def teardown(self, stage=None): trainer.fit(model) expected = [ + 'setup_fit', 'on_fit_start', 'on_pretrain_routine_start', 'on_pretrain_routine_end', 'on_validation_model_eval', 'on_validation_start', + 'on_epoch_start', 'on_validation_epoch_start', 'on_validation_batch_start', 'on_validation_batch_end', @@ -454,6 +457,7 @@ def teardown(self, stage=None): 'on_epoch_end', 'on_validation_model_eval', 'on_validation_start', + 'on_epoch_start', 'on_validation_epoch_start', 'on_validation_batch_start', 'on_validation_batch_end', @@ -464,7 +468,7 @@ def teardown(self, stage=None): 'on_validation_model_train', 'on_train_end', 'on_fit_end', - 'teardown', + 'teardown_fit', ] assert model.called == expected @@ -472,8 +476,10 @@ def teardown(self, stage=None): trainer.validate(model, verbose=False) expected = [ + 'setup_validate', 'on_validation_model_eval', 'on_validation_start', + 'on_epoch_start', 'on_validation_epoch_start', 'on_validation_batch_start', 'on_validation_batch_end', @@ -481,16 +487,18 @@ def teardown(self, stage=None): 'on_epoch_end', 'on_validation_end', 'on_validation_model_train', - 'teardown', + 'teardown_validate', ] assert model.called == expected model = HookedModel() - trainer.test(model, verbose=False) + expected = [ + 'setup_test', 'on_test_model_eval', 'on_test_start', + 'on_epoch_start', 'on_test_epoch_start', 'on_test_batch_start', 'on_test_batch_end', @@ -498,6 +506,119 @@ def teardown(self, stage=None): 'on_epoch_end', 'on_test_end', 'on_test_model_train', - 'teardown', + 'teardown_test', ] assert model.called == expected + + +def test_trainer_datamodule_hook_system(tmpdir): + """Test the LightningDataModule hook system.""" + + class HookedDataModule(BoringDataModule): + def __init__(self): + super().__init__() + self.called = [] + + def prepare_data(self): + self.called.append("prepare_data") + super().prepare_data() + + def setup(self, stage=None): + self.called.append(f"setup_{stage}") + super().setup(stage=stage) + + def teardown(self, stage=None): + self.called.append(f"teardown_{stage}") + super().teardown(stage=stage) + + def train_dataloader(self): + self.called.append("train_dataloader") + return super().train_dataloader() + + def 
test_dataloader(self): + self.called.append("test_dataloader") + return super().test_dataloader() + + def val_dataloader(self): + self.called.append("val_dataloader") + return super().val_dataloader() + + def predict_dataloader(self): + self.called.append("predict_dataloader") + + def transfer_batch_to_device(self, *args, **kwargs): + self.called.append("transfer_batch_to_device") + return super().transfer_batch_to_device(*args, **kwargs) + + def on_before_batch_transfer(self, *args, **kwargs): + self.called.append("on_before_batch_transfer") + return super().on_before_batch_transfer(*args, **kwargs) + + def on_after_batch_transfer(self, *args, **kwargs): + self.called.append("on_after_batch_transfer") + return super().on_after_batch_transfer(*args, **kwargs) + + model = BoringModel() + dm = HookedDataModule() + + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + limit_val_batches=1, + limit_train_batches=2, + limit_test_batches=1, + progress_bar_refresh_rate=0, + weights_summary=None, + reload_dataloaders_every_epoch=True, + ) + trainer.fit(model, datamodule=dm) + + expected = [ + 'prepare_data', + 'setup_fit', + 'val_dataloader', + 'on_before_batch_transfer', + 'transfer_batch_to_device', + 'on_after_batch_transfer', + 'train_dataloader', + 'on_before_batch_transfer', + 'transfer_batch_to_device', + 'on_after_batch_transfer', + 'on_before_batch_transfer', + 'transfer_batch_to_device', + 'on_after_batch_transfer', + 'val_dataloader', + 'on_before_batch_transfer', + 'transfer_batch_to_device', + 'on_after_batch_transfer', + 'teardown_fit' + ] + assert dm.called == expected + + dm = HookedDataModule() + trainer.validate(model, datamodule=dm, verbose=False) + + expected = [ + 'prepare_data', + 'setup_validate', + 'val_dataloader', + 'on_before_batch_transfer', + 'transfer_batch_to_device', + 'on_after_batch_transfer', + 'teardown_validate' + ] + assert dm.called == expected + + dm = HookedDataModule() + trainer.test(model, datamodule=dm, verbose=False) + + expected = [ + 'prepare_data', + 'setup_test', + 'test_dataloader', + 'on_before_batch_transfer', + 'transfer_batch_to_device', + 'on_after_batch_transfer', + 'teardown_test' + ] + assert dm.called == expected diff --git a/tests/models/test_horovod.py b/tests/models/test_horovod.py index 636979821b313..3c8c9b0f36041 100644 --- a/tests/models/test_horovod.py +++ b/tests/models/test_horovod.py @@ -16,11 +16,13 @@ import shlex import subprocess import sys +from unittest.mock import patch import numpy as np import pytest import torch from sklearn.metrics import accuracy_score +from torch import optim import tests.helpers.pipelines as tpipes import tests.helpers.utils as tutils @@ -47,6 +49,9 @@ def _run_horovod(trainer_options, on_gpu=False): # for Horovod, we interpret `gpus` to be set per worker trainer_options.update(gpus=1 if on_gpu else None) tutils.reset_seed() + # todo: Find why coverage breaks CI. 
+ # append = '-a' if '.coverage' in os.listdir(_PROJECT_ROOT) else '' # noqa E265 + # str(num_processes), sys.executable, '-m', 'coverage', 'run', '--source', 'pytorch_lightning', append, # noqa E265 cmdline = [ 'horovodrun', '-np', str(num_processes), sys.executable, TEST_SCRIPT, '--trainer-options', @@ -109,7 +114,9 @@ def test_horovod_multi_gpu(tmpdir): _run_horovod(trainer_options, on_gpu=True) -@pytest.mark.skip(reason="Horovod has a problem with broadcast when using apex?") # todo +# https://discuss.pytorch.org/t/torch-cuda-amp-vs-nvidia-apex/74994 +# Check with @tgaddair on the Horovod issue tracker whether this feature is still needed +@pytest.mark.skip(reason="Horovod currently doesn't work with Apex") # todo @RunIf(min_gpus=2, skip_windows=True, amp_apex=True, horovod_nccl=True) def test_horovod_apex(tmpdir): """Test Horovod with multi-GPU support using apex amp.""" @@ -130,7 +137,6 @@ def test_horovod_apex(tmpdir): _run_horovod(trainer_options, on_gpu=True) -@pytest.mark.skip(reason="Skip till Horovod fixes integration with Native torch.cuda.amp") # todo @RunIf(min_gpus=2, skip_windows=True, amp_native=True, horovod_nccl=True) def test_horovod_amp(tmpdir): """Test Horovod with multi-GPU support using native amp.""" @@ -151,6 +157,24 @@ def test_horovod_amp(tmpdir): _run_horovod(trainer_options, on_gpu=True) +@RunIf(min_gpus=2, skip_windows=True, horovod_nccl=True) +def test_horovod_gather(tmpdir): + """Test Horovod with multi-GPU support using all-gather.""" + trainer_options = dict( + default_root_dir=str(tmpdir), + weights_save_path=str(tmpdir), + gradient_clip_val=1.0, + progress_bar_refresh_rate=0, + max_epochs=1, + limit_train_batches=0.4, + limit_val_batches=0.2, + gpus=2, + deterministic=True, + accelerator='horovod', + ) + _run_horovod(trainer_options, on_gpu=True) + + @RunIf(min_gpus=1, skip_windows=True, horovod_nccl=True) def test_horovod_transfer_batch_to_gpu(tmpdir): @@ -179,7 +203,7 @@ def validation_step(self, batch, *args, **kwargs): tpipes.run_model_test_without_loggers(trainer_options, model) -@RunIf(skip_windows=True) +@RunIf(skip_windows=True, horovod=True) def test_horovod_multi_optimizer(tmpdir): model = BasicGAN() @@ -211,8 +235,7 @@ def get_optimizer_params(optimizer): assert get_model_params(model.discriminator) == get_optimizer_params(trainer.optimizers[1]) -# TODO: unclear Horovod failure... -@pytest.mark.skip(reason="unclear Horovod failure...") +@pytest.mark.skipif(reason="CI agent.jobstatus=Succeeded: Permission denied") @RunIf(skip_windows=True, horovod=True) def test_result_reduce_horovod(tmpdir): """Make sure result logging works with Horovod. @@ -254,6 +277,7 @@ def training_epoch_end(self, outputs) -> None: max_epochs=1, log_every_n_steps=1, weights_summary=None, + logger=False ) trainer.fit(model) @@ -261,9 +285,8 @@ def training_epoch_end(self, outputs) -> None: horovod.run(hvd_test_fn, np=2) -# TODO: unclear Horovod failure...
-@pytest.mark.skip(reason="unclear Horovod failure...") -@RunIf(skip_windows=True, horovod=True) +@pytest.mark.skipif(reason="CI agent.jobstatus=Succeeded: Permission denied") +@RunIf(skip_windows=True, horovod=True, num_gpus=2) def test_accuracy_metric_horovod(): num_batches = 10 batch_size = 16 @@ -278,10 +301,7 @@ def sk_metric(preds, target): target = torch.randint(high=2, size=(num_batches, batch_size)) def _compute_batch(): - trainer = Trainer( - fast_dev_run=True, - accelerator='horovod', - ) + trainer = Trainer(fast_dev_run=True, accelerator='horovod', logger=False) assert isinstance(trainer.accelerator, CPUAccelerator) # TODO: test that we selected the correct training_type_plugin based on horovod flags @@ -289,7 +309,7 @@ def _compute_batch(): metric = Accuracy( compute_on_step=True, dist_sync_on_step=True, - dist_sync_fn=trainer.training_type_plugin.gather_all_tensors, + dist_sync_fn=trainer.training_type_plugin.all_gather, threshold=threshold ) @@ -314,33 +334,45 @@ def _compute_batch(): horovod.run(_compute_batch, np=2) -# @RunIf(skip_windows=True) -# def test_horovod_multi_optimizer_with_scheduling_stepping(tmpdir): -# model = BoringModel() -# model.configure_optimizers = model.configure_optimizers__multiple_schedulers -# -# num_workers = 8 -# init_lr = hparams.get('learning_rate') * num_workers -# -# with patch('pytorch_lightning.accelerators.legacy.horovod_backend.hvd.size') as mock_hvd_size: -# mock_hvd_size.return_value = 8 -# -# # fit model -# trainer = Trainer( -# default_root_dir=tmpdir, -# max_epochs=1, -# limit_val_batches=0.5, -# limit_train_batches=0.2, -# distributed_backend='horovod' -# ) -# results = trainer.fit(model) -# assert results == 1 -# -# adjusted_lr1 = [pg['lr'] for pg in trainer.optimizers[0].param_groups][0] -# adjusted_lr2 = [pg['lr'] for pg in trainer.optimizers[1].param_groups][0] -# -# # Called ones after end of epoch with gamma=0.1 -# assert pytest.approx(init_lr * 0.1) == adjusted_lr1 -# -# # Called every 3 steps, meaning for 1 epoch of 11 batches, it is called 3 times with gamma=0.1 -# assert pytest.approx(init_lr * 0.1) == adjusted_lr2 +@RunIf(skip_windows=True, horovod=True) +def test_horovod_multi_optimizer_with_scheduling_stepping(tmpdir): + + class TestModel(BoringModel): + + def training_step(self, batch, batch_idx, optimizer_idx): + return super().training_step(batch, batch_idx) + + def configure_optimizers(self): + optimizer1 = optim.Adam(self.parameters(), lr=0.1) + optimizer2 = optim.Adam(self.parameters(), lr=0.1) + lr_scheduler1 = optim.lr_scheduler.StepLR(optimizer1, 1, gamma=0.1) + lr_scheduler2 = optim.lr_scheduler.StepLR(optimizer2, 1, gamma=0.1) + return [optimizer1, optimizer2], [lr_scheduler1, lr_scheduler2] + + model = TestModel() + model.training_epoch_end = None + + num_workers = 8 + init_lr = 0.1 * num_workers + + with patch('horovod.torch.size', return_value=8): + + # fit model + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + limit_val_batches=0.5, + limit_train_batches=0.2, + accelerator='horovod' + ) + results = trainer.fit(model) + assert results == 1 + + adjusted_lr1 = [pg['lr'] for pg in trainer.optimizers[0].param_groups][0] + adjusted_lr2 = [pg['lr'] for pg in trainer.optimizers[1].param_groups][0] + + # Scheduler 1 is stepped once after the end of the epoch with gamma=0.1 + assert pytest.approx(init_lr * 0.1) == adjusted_lr1 + + # Scheduler 2 is likewise stepped once after the end of the epoch, so its lr is also reduced by gamma=0.1 + assert pytest.approx(init_lr * 0.1) == adjusted_lr2 diff --git a/tests/models/test_tpu.py
b/tests/models/test_tpu.py index 0c922c99149fa..b2ed0db87d8d5 100644 --- a/tests/models/test_tpu.py +++ b/tests/models/test_tpu.py @@ -355,3 +355,44 @@ def test_reduce(rank): assert result.item() == 8 xmp.spawn(test_reduce, nprocs=8, start_method='fork') + + +@RunIf(tpu=True) +@pl_multi_process_test +@pytest.mark.parametrize("clip_val", [10]) +@mock.patch("pytorch_lightning.accelerators.tpu.xla_clip_grad_norm_") +def test_tpu_precision_16_clip_gradients(mock_clip_grad_norm, clip_val, tmpdir): + """ + Ensure that clip gradients is only called if the value is greater than 0. + TODO: Fix (test fails with parametrize) + """ + tutils.reset_seed() + trainer_options = dict( + default_root_dir=tmpdir, + progress_bar_refresh_rate=0, + max_epochs=1, + tpu_cores=1, + precision=16, + limit_train_batches=4, + limit_val_batches=4, + gradient_clip_val=clip_val, + ) + model = BoringModel() + tpipes.run_model_test(trainer_options, model, on_gpu=False, with_hpc=False) + + if clip_val > 0: + mock_clip_grad_norm.assert_called() + else: + mock_clip_grad_norm.assert_not_called() + + +@RunIf(tpu=True) +@pl_multi_process_test +def test_if_test_works_with_checkpoint_false(tmpdir): + """Ensure that model trains properly when `checkpoint_callback` is set to False.""" + + # Train a model on TPU + model = BoringModel() + trainer = Trainer(max_epochs=1, tpu_cores=8, default_root_dir=tmpdir, fast_dev_run=True, checkpoint_callback=False) + trainer.fit(model) + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" diff --git a/tests/overrides/test_data_parallel.py b/tests/overrides/test_data_parallel.py index 3921e7ef33b8e..aaf47c82d5f08 100644 --- a/tests/overrides/test_data_parallel.py +++ b/tests/overrides/test_data_parallel.py @@ -24,7 +24,7 @@ ("training", "training_step"), ("testing", "test_step"), ("validating", "validation_step"), - ("predicting", "predict"), + ("predicting", "predict_step"), ] ) def test_lightning_wrapper_module_methods(wrapper_class, stage): diff --git a/tests/plugins/test_custom_plugin.py b/tests/plugins/test_custom_plugin.py new file mode 100644 index 0000000000000..872b49ef48635 --- /dev/null +++ b/tests/plugins/test_custom_plugin.py @@ -0,0 +1,41 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from pytorch_lightning import Trainer +from pytorch_lightning.plugins import DDPPlugin +from tests.helpers import BoringModel +from tests.helpers.runif import RunIf + + +class CustomParallelPlugin(DDPPlugin): + + def __init__(self, **kwargs): + super().__init__(**kwargs) + # Set to None so it will be overwritten by the accelerator connector. 
+ self.sync_batchnorm = None + + +@RunIf(skip_windows=True) +def test_sync_batchnorm_set(tmpdir): + """Tests if sync_batchnorm is automatically set for custom plugin.""" + model = BoringModel() + plugin = CustomParallelPlugin() + assert plugin.sync_batchnorm is None + trainer = Trainer( + max_epochs=1, + plugins=[plugin], + default_root_dir=tmpdir, + sync_batchnorm=True, + ) + trainer.fit(model) + assert plugin.sync_batchnorm is True diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py index cf5c23a824732..e6b15069f256a 100644 --- a/tests/plugins/test_deepspeed_plugin.py +++ b/tests/plugins/test_deepspeed_plugin.py @@ -180,7 +180,7 @@ def test_deepspeed_defaults(tmpdir): assert isinstance(plugin.config["zero_optimization"], dict) -@RunIf(deepspeed=True) +@RunIf(min_gpus=1, deepspeed=True, special=True) def test_invalid_deepspeed_defaults_no_precision(tmpdir): """Test to ensure that using defaults, if precision is not set to 16, we throw an exception.""" model = BoringModel() @@ -195,7 +195,7 @@ def test_invalid_deepspeed_defaults_no_precision(tmpdir): trainer.fit(model) -@RunIf(min_gpus=1, deepspeed=True) +@RunIf(min_gpus=1, deepspeed=True, special=True) def test_warn_deepspeed_override_backward(tmpdir): """Test to ensure that if the backward hook in the LightningModule is overridden, we throw a warning.""" @@ -216,7 +216,7 @@ def backward(self, loss: Tensor, optimizer: Optimizer, optimizer_idx: int, *args trainer.fit(model) -@RunIf(min_gpus=1, deepspeed=True) +@RunIf(min_gpus=1, deepspeed=True, special=True) def test_deepspeed_run_configure_optimizers(tmpdir): """Test end to end that deepspeed works with defaults (without ZeRO as that requires compilation), whilst using configure_optimizers for optimizers and schedulers.""" @@ -246,7 +246,7 @@ def on_train_start(self) -> None: _assert_save_model_is_equal(model, tmpdir, trainer) -@RunIf(min_gpus=1, deepspeed=True) +@RunIf(min_gpus=1, deepspeed=True, special=True) def test_deepspeed_config(tmpdir, deepspeed_zero_config): """ Test to ensure deepspeed works correctly when passed a DeepSpeed config object including optimizers/schedulers @@ -280,7 +280,7 @@ def on_train_start(self) -> None: _assert_save_model_is_equal(model, tmpdir, trainer) -@RunIf(min_gpus=1, deepspeed=True) +@RunIf(min_gpus=1, deepspeed=True, special=True) def test_deepspeed_custom_precision_params(tmpdir): """Ensure if we modify the FP16 parameters via the DeepSpeedPlugin, the deepspeed config contains these changes.""" @@ -301,7 +301,7 @@ def on_train_start(self) -> None: trainer.fit(model) -@RunIf(min_gpus=1, deepspeed=True) +@RunIf(min_gpus=1, deepspeed=True, special=True) def test_deepspeed_assert_config_zero_offload_disabled(tmpdir, deepspeed_zero_config): """Ensure if we use a config and turn off cpu_offload, that this is set to False within the config.""" diff --git a/tests/plugins/test_double_plugin.py b/tests/plugins/test_double_plugin.py new file mode 100644 index 0000000000000..f089b1c23149e --- /dev/null +++ b/tests/plugins/test_double_plugin.py @@ -0,0 +1,129 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest + +import torch +from torch.utils.data import DataLoader, Dataset + +from pytorch_lightning import Trainer +from tests.helpers.boring_model import BoringModel, RandomDataset + + +class RandomFloatIntDataset(Dataset): + + def __init__(self, size, length): + self.len = length + self.float_data = torch.randn(length, size) + self.int_data = torch.randint(10, (length, 1)) + + def __getitem__(self, index): + return self.float_data[index], self.int_data[index] + + def __len__(self): + return self.len + + +class DoublePrecisionBoringModel(BoringModel): + + def training_step(self, batch, batch_idx): + float_data, int_data = batch + assert float_data.dtype == torch.float64 + output = self(float_data) + loss = self.loss(batch, output) + return {"loss": loss} + + def validation_step(self, batch, batch_idx): + assert batch.dtype == torch.float64 + output = self(batch) + loss = self.loss(batch, output) + return {"x": loss} + + def test_step(self, batch, batch_idx): + assert batch.dtype == torch.float64 + output = self(batch) + loss = self.loss(batch, output) + return {"y": loss} + + def predict_step(self, batch, batch_idx, dataloader_idx=None): + assert batch.dtype == torch.float64 + return self(batch) + + def on_fit_start(self): + assert self.layer.weight.dtype == torch.float64 + + def on_after_backward(self): + assert self.layer.weight.grad.dtype == torch.float64 + + def train_dataloader(self): + dataset = RandomFloatIntDataset(32, 64) + assert dataset.float_data.dtype == torch.float32 # Don't start with double data + return DataLoader(dataset) + + def predict_dataloader(self): + return DataLoader(RandomDataset(32, 64)) + + +class DoublePrecisionBoringModelNoForward(BoringModel): + + def training_step(self, batch, batch_idx): + assert batch.dtype == torch.float64 + output = self.layer(batch) + assert output.dtype == torch.float64 + loss = self.loss(batch, output) + return {"loss": loss} + + def validation_step(self, batch, batch_idx): + assert batch.dtype == torch.float64 + output = self.layer(batch) + assert output.dtype == torch.float64 + loss = self.loss(batch, output) + return {"x": loss} + + def test_step(self, batch, batch_idx): + assert batch.dtype == torch.float64 + output = self.layer(batch) + assert output.dtype == torch.float64 + loss = self.loss(batch, output) + return {"y": loss} + + def predict_step(self, batch, batch_idx, dataloader_idx=None): + assert batch.dtype == torch.float64 + output = self.layer(batch) + assert output.dtype == torch.float64 + return output + + def predict_dataloader(self): + return DataLoader(RandomDataset(32, 64)) + + +@pytest.mark.parametrize( + 'boring_model', + (DoublePrecisionBoringModel, DoublePrecisionBoringModelNoForward) +) +def test_double_precision(tmpdir, boring_model): + model = boring_model() + original_training_step = model.training_step + + trainer = Trainer( + max_epochs=2, + default_root_dir=tmpdir, + fast_dev_run=2, + precision=64, + log_every_n_steps=1, + ) + trainer.fit(model) + trainer.test(model) + trainer.predict(model) + + assert model.training_step == original_training_step diff --git 
a/tests/plugins/test_sharded_plugin.py b/tests/plugins/test_sharded_plugin.py index a48f048160ee5..655e12f046e04 100644 --- a/tests/plugins/test_sharded_plugin.py +++ b/tests/plugins/test_sharded_plugin.py @@ -259,10 +259,12 @@ def test_ddp_sharded_plugin_resume_from_checkpoint_gpu_to_cpu(tmpdir): @RunIf(skip_windows=True, special=True, fairscale=True) -@pytest.mark.parametrize("trainer_kwargs", ( - {'num_processes': 2}, - pytest.param({'gpus': 2}, marks=RunIf(min_gpus=2)) -)) +@pytest.mark.parametrize( + "trainer_kwargs", ( + dict(num_processes=2), + pytest.param(dict(gpus=2), marks=RunIf(min_gpus=2)), + ) +) def test_ddp_sharded_plugin_test_multigpu(tmpdir, trainer_kwargs): """ Test to ensure we can use validate and test without fit diff --git a/tests/special_tests.sh b/tests/special_tests.sh index b2ef6dfdacbf3..c381b5e9feeb6 100644 --- a/tests/special_tests.sh +++ b/tests/special_tests.sh @@ -14,9 +14,15 @@ # Running special tests set -e export PL_RUNNING_SPECIAL_TESTS=1 -DEFAULTS="-m coverage run --source pytorch_lightning -a -m pytest --verbose --capture=no" +DEFAULTS="-m coverage run --source pytorch_lightning --append -m pytest --verbose --capture=no" python ${DEFAULTS} tests/trainer/optimization/test_manual_optimization.py::test_step_with_optimizer_closure_with_different_frequencies_ddp python ${DEFAULTS} tests/models/test_sync_batchnorm.py::test_sync_batchnorm_ddp +python ${DEFAULTS} tests/plugins/test_deepspeed_plugin.py::test_invalid_deepspeed_defaults_no_precision +python ${DEFAULTS} tests/plugins/test_deepspeed_plugin.py::test_warn_deepspeed_override_backward +python ${DEFAULTS} tests/plugins/test_deepspeed_plugin.py::test_deepspeed_run_configure_optimizers +python ${DEFAULTS} tests/plugins/test_deepspeed_plugin.py::test_deepspeed_config +python ${DEFAULTS} tests/plugins/test_deepspeed_plugin.py::test_deepspeed_custom_precision_params +python ${DEFAULTS} tests/plugins/test_deepspeed_plugin.py::test_deepspeed_assert_config_zero_offload_disabled python ${DEFAULTS} tests/plugins/test_deepspeed_plugin.py::test_deepspeed_multigpu python ${DEFAULTS} tests/plugins/test_rpc_plugin.py::test_rpc_function_calls_ddp python ${DEFAULTS} tests/plugins/test_rpc_sequential_plugin.py::test_rpc_sequential_plugin_manual @@ -28,8 +34,9 @@ python ${DEFAULTS} tests/trainer/test_trainer.py::test_trainer_predict_ddp python ${DEFAULTS} tests/trainer/test_trainer.py::test_trainer_predict_dp python ${DEFAULTS} tests/trainer/logging_/test_train_loop_logging_1_0.py::test_logging_sync_dist_true_ddp python ${DEFAULTS} tests/callbacks/test_pruning.py::test_pruning_callback_ddp -python ${DEFAULTS} tests/trainer/test_trainer.py::test_pytorch_profiler_trainer_ddp +python ${DEFAULTS} tests/test_profiler.py::test_pytorch_profiler_trainer_ddp python ${DEFAULTS} tests/models/test_hooks.py::test_transfer_batch_hook_ddp python ${DEFAULTS} tests/trainer/test_data_loading.py::test_replace_distrubuted_sampler_custom_dataloader_custom_batch_sampler python ${DEFAULTS} tests/trainer/optimization/test_manual_optimization.py::test_step_with_optimizer_closure_with_different_frequencies_ddp_with_toggle_model -nvprof --profile-from-start off -o trace_name.prof -- python ${DEFAULTS} tests/trainer/test_trainer.py::test_pytorch_profiler_nested_emit_nvtx +python ${DEFAULTS} tests/checkpointing/test_checkpoint_callback_frequency.py::test_top_k_ddp +nvprof --profile-from-start off -o trace_name.prof -- python ${DEFAULTS} tests/test_profiler.py::test_pytorch_profiler_nested_emit_nvtx diff --git a/tests/test_profiler.py 
b/tests/test_profiler.py index 9b51ca7f7c6d2..a6e33b3366f33 100644 --- a/tests/test_profiler.py +++ b/tests/test_profiler.py @@ -13,13 +13,22 @@ # limitations under the License. import logging import os +import platform import time -from pathlib import Path +from copy import deepcopy +from distutils.version import LooseVersion import numpy as np import pytest +import torch -from pytorch_lightning.profiler import AdvancedProfiler, SimpleProfiler +from pytorch_lightning import Callback, Trainer +from pytorch_lightning.profiler import AdvancedProfiler, PyTorchProfiler, SimpleProfiler +from pytorch_lightning.profiler.pytorch import RegisterRecordFunction +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.imports import _KINETO_AVAILABLE +from tests.helpers import BoringModel +from tests.helpers.runif import RunIf PROFILER_OVERHEAD_MAX_TOLERANCE = 0.0005 @@ -40,14 +49,7 @@ def _sleep_generator(durations): @pytest.fixture def simple_profiler(): - profiler = SimpleProfiler() - return profiler - - -@pytest.fixture -def advanced_profiler(tmpdir): - profiler = AdvancedProfiler(output_filename=os.path.join(tmpdir, "profiler.txt")) - return profiler + return SimpleProfiler() @pytest.mark.parametrize(["action", "expected"], [ @@ -93,14 +95,6 @@ def test_simple_profiler_overhead(simple_profiler, n_iter=5): assert all(durations < PROFILER_OVERHEAD_MAX_TOLERANCE) -def test_simple_profiler_describe(caplog, simple_profiler): - """Ensure the profiler won't fail when reporting the summary.""" - with caplog.at_level(logging.INFO): - simple_profiler.describe() - - assert "Profiler Report" in caplog.text - - def test_simple_profiler_value_errors(simple_profiler): """Ensure errors are raised where expected.""" @@ -116,6 +110,77 @@ def test_simple_profiler_value_errors(simple_profiler): simple_profiler.stop(action) +def test_simple_profiler_deepcopy(tmpdir): + simple_profiler = SimpleProfiler(dirpath=tmpdir, filename="test") + simple_profiler.describe() + assert deepcopy(simple_profiler) + + +def test_simple_profiler_log_dir(tmpdir): + """Ensure the profiler dirpath defaults to `trainer.log_dir` when not present""" + profiler = SimpleProfiler(filename="profiler") + assert profiler._log_dir is None + + model = BoringModel() + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + profiler=profiler, + ) + trainer.fit(model) + + expected = tmpdir / "lightning_logs" / "version_0" + assert trainer.log_dir == expected + assert profiler._log_dir == trainer.log_dir + assert expected.join("fit-profiler.txt").exists() + + +@RunIf(skip_windows=True) +def test_simple_profiler_distributed_files(tmpdir): + """Ensure the proper files are saved in distributed""" + profiler = SimpleProfiler(dirpath=tmpdir, filename='profiler') + model = BoringModel() + trainer = Trainer( + default_root_dir=tmpdir, + fast_dev_run=2, + accelerator="ddp_cpu", + num_processes=2, + profiler=profiler, + logger=False, + ) + trainer.fit(model) + trainer.validate(model) + trainer.test(model) + + actual = set(os.listdir(profiler.dirpath)) + expected = {f"{stage}-profiler-{rank}.txt" for stage in ("fit", "validate", "test") for rank in (0, 1)} + assert actual == expected + + for f in profiler.dirpath.listdir(): + assert f.read_text('utf-8') + + +def test_simple_profiler_logs(tmpdir, caplog, simple_profiler): + """Ensure that the number of printed logs is correct""" + model = BoringModel() + trainer = Trainer( + default_root_dir=tmpdir, + fast_dev_run=2, + profiler=simple_profiler, + 
logger=False, + ) + with caplog.at_level(logging.INFO, logger="pytorch_lightning.profiler.profilers"): + trainer.fit(model) + trainer.test(model) + + assert caplog.text.count("Profiler Report") == 2 + + +@pytest.fixture +def advanced_profiler(tmpdir): + return AdvancedProfiler(dirpath=tmpdir, filename="profiler") + + @pytest.mark.parametrize(["action", "expected"], [ pytest.param("a", [3, 1]), pytest.param("b", [2]), @@ -174,7 +239,8 @@ def test_advanced_profiler_describe(tmpdir, advanced_profiler): pass # log to stdout and print to file advanced_profiler.describe() - data = Path(advanced_profiler.output_fname).read_text() + path = advanced_profiler.dirpath / f"{advanced_profiler.filename}.txt" + data = path.read_text("utf-8") assert len(data) > 0 @@ -187,3 +253,259 @@ def test_advanced_profiler_value_errors(advanced_profiler): advanced_profiler.start(action) advanced_profiler.stop(action) + + +def test_advanced_profiler_deepcopy(advanced_profiler): + advanced_profiler.describe() + assert deepcopy(advanced_profiler) + + +@pytest.fixture +def pytorch_profiler(tmpdir): + return PyTorchProfiler(dirpath=tmpdir, filename="profiler") + + +@RunIf(max_torch="1.8.1") +def test_pytorch_profiler_describe(pytorch_profiler): + """Ensure the profiler won't fail when reporting the summary.""" + with pytorch_profiler.profile("on_test_start"): + torch.tensor(0) + + # log to stdout and print to file + pytorch_profiler.describe() + path = pytorch_profiler.dirpath / f"{pytorch_profiler.filename}.txt" + data = path.read_text("utf-8") + assert len(data) > 0 + + +def test_pytorch_profiler_raises(pytorch_profiler): + """Ensure errors are raised where expected.""" + with pytest.raises(MisconfigurationException, match="profiled_functions` and `PyTorchProfiler.record"): + PyTorchProfiler(profiled_functions=["a"], record_functions=["b"]) + + +@RunIf(min_torch="1.6.0") +def test_advanced_profiler_cprofile_deepcopy(tmpdir): + """Checks for pickle issue reported in #6522""" + model = BoringModel() + trainer = Trainer( + default_root_dir=tmpdir, + fast_dev_run=True, + profiler="advanced", + stochastic_weight_avg=True, + ) + trainer.fit(model) + + +@RunIf(min_gpus=2, special=True) +def test_pytorch_profiler_trainer_ddp(tmpdir, pytorch_profiler): + """Ensure that the profiler can be given to the trainer and the default steps are properly recorded. """ + model = BoringModel() + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + limit_train_batches=5, + limit_val_batches=5, + profiler=pytorch_profiler, + accelerator="ddp", + gpus=2, + ) + trainer.fit(model) + + expected = {'validation_step'} + if not _KINETO_AVAILABLE: + expected |= {'training_step_and_backward', 'training_step', 'backward'} + for name in expected: + assert sum(e.name == name for e in pytorch_profiler.function_events), name + + files = set(os.listdir(pytorch_profiler.dirpath)) + expected = f"fit-profiler-{trainer.local_rank}.txt" + assert expected in files + + path = pytorch_profiler.dirpath / expected + assert path.read_text("utf-8") + + if _KINETO_AVAILABLE: + files = os.listdir(pytorch_profiler.dirpath) + files = [file for file in files if file.endswith('.json')] + assert len(files) == 2, files + local_rank = trainer.local_rank + assert any(f'training_step_{local_rank}' in f for f in files) + assert any(f'validation_step_{local_rank}' in f for f in files) + + +def test_pytorch_profiler_trainer_test(tmpdir): + """Ensure that the profiler can be given to the trainer and the test step is properly recorded. 
""" + pytorch_profiler = PyTorchProfiler(dirpath=tmpdir, filename="profile", schedule=None) + model = BoringModel() + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + limit_test_batches=2, + profiler=pytorch_profiler, + ) + trainer.test(model) + + assert sum(e.name == 'test_step' for e in pytorch_profiler.function_events) + + path = pytorch_profiler.dirpath / f"test-{pytorch_profiler.filename}.txt" + assert path.read_text("utf-8") + + if _KINETO_AVAILABLE: + files = sorted([file for file in os.listdir(tmpdir) if file.endswith('.json')]) + assert any(f'test_step_{trainer.local_rank}' in f for f in files) + + +def test_pytorch_profiler_trainer_predict(tmpdir): + """Ensure that the profiler can be given to the trainer and predict function are properly recorded. """ + pytorch_profiler = PyTorchProfiler(dirpath=tmpdir, filename="profile", schedule=None) + model = BoringModel() + model.predict_dataloader = model.train_dataloader + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + limit_predict_batches=2, + profiler=pytorch_profiler, + ) + trainer.predict(model) + + assert sum(e.name == 'predict_step' for e in pytorch_profiler.function_events) + path = pytorch_profiler.dirpath / f"predict-{pytorch_profiler.filename}.txt" + assert path.read_text("utf-8") + + +def test_pytorch_profiler_trainer_validate(tmpdir): + """Ensure that the profiler can be given to the trainer and validate function are properly recorded. """ + pytorch_profiler = PyTorchProfiler(dirpath=tmpdir, filename="profile", schedule=None) + model = BoringModel() + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + limit_val_batches=2, + profiler=pytorch_profiler, + ) + trainer.validate(model) + + assert sum(e.name == 'validation_step' for e in pytorch_profiler.function_events) + + path = pytorch_profiler.dirpath / f"validate-{pytorch_profiler.filename}.txt" + assert path.read_text("utf-8") + + +def test_pytorch_profiler_nested(tmpdir): + """Ensure that the profiler handles nested context""" + + pytorch_profiler = PyTorchProfiler( + record_functions={"a", "b", "c"}, use_cuda=False, dirpath=tmpdir, filename="profiler", schedule=None + ) + + with pytorch_profiler.profile("a"): + a = torch.ones(42) + with pytorch_profiler.profile("b"): + b = torch.zeros(42) + with pytorch_profiler.profile("c"): + _ = a + b + + pytorch_profiler.describe() + + events_name = {e.name for e in pytorch_profiler.function_events} + + if platform.system() == "Windows": + expected = {'a', 'add', 'b', 'c', 'profiler::_record_function_enter', 'profiler::_record_function_exit'} + else: + expected = { + 'signed char', 'add', 'profiler::_record_function_exit', 'bool', 'char', 'profiler::_record_function_enter' + } + + if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"): + expected = {'add', 'zeros', 'ones', 'zero_', 'b', 'fill_', 'c', 'a', 'empty'} + + if LooseVersion(torch.__version__) >= LooseVersion("1.7.0"): + expected = { + 'aten::zeros', 'aten::add', 'aten::zero_', 'c', 'b', 'a', 'aten::fill_', 'aten::empty', 'aten::ones' + } + + assert events_name == expected, (events_name, torch.__version__, platform.system()) + + +@RunIf(min_gpus=1, special=True) +def test_pytorch_profiler_nested_emit_nvtx(tmpdir): + """ + This test check emit_nvtx is correctly supported + """ + profiler = PyTorchProfiler(use_cuda=True, emit_nvtx=True) + + model = BoringModel() + trainer = Trainer( + fast_dev_run=True, + profiler=profiler, + gpus=1, + ) + trainer.fit(model) + + +@RunIf(min_torch="1.5.0") +def 
test_register_record_function(tmpdir): + + use_cuda = torch.cuda.is_available() + pytorch_profiler = PyTorchProfiler( + export_to_chrome=False, + record_functions={"a"}, + use_cuda=use_cuda, + dirpath=tmpdir, + filename="profiler", + schedule=None, + on_trace_ready=None, + ) + + class TestModel(BoringModel): + + def __init__(self): + super().__init__() + self.layer = torch.nn.Sequential(torch.nn.Linear(1, 1), torch.nn.ReLU(), torch.nn.Linear(1, 1)) + + model = TestModel() + input = torch.rand((1, 1)) + + if use_cuda: + model = model.cuda() + input = input.cuda() + + with pytorch_profiler.profile("a"): + with RegisterRecordFunction(model): + model(input) + + pytorch_profiler.describe() + event_names = [e.name for e in pytorch_profiler.function_events] + assert 'torch.nn.modules.container.Sequential: layer' in event_names + assert 'torch.nn.modules.linear.Linear: layer.0' in event_names + assert 'torch.nn.modules.activation.ReLU: layer.1' in event_names + assert 'torch.nn.modules.linear.Linear: layer.2' in event_names + + +@pytest.mark.parametrize("cls", (SimpleProfiler, AdvancedProfiler, PyTorchProfiler)) +def test_profiler_teardown(tmpdir, cls): + """ + This test checks if profiler teardown method is called when trainer is exiting. + """ + + class TestCallback(Callback): + + def on_fit_end(self, trainer, *args, **kwargs) -> None: + # describe sets it to None + assert trainer.profiler._output_file is None + + profiler = cls(dirpath=tmpdir, filename="profiler") + model = BoringModel() + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, profiler=profiler, callbacks=[TestCallback()]) + trainer.fit(model) + + assert profiler._output_file is None + + +def test_pytorch_profiler_deepcopy(tmpdir): + pytorch_profiler = PyTorchProfiler(dirpath=tmpdir, filename="profiler", schedule=None) + pytorch_profiler.start("on_train_start") + torch.tensor(1) + pytorch_profiler.describe() + assert deepcopy(pytorch_profiler) diff --git a/tests/trainer/flags/test_env_vars.py b/tests/trainer/flags/test_env_vars.py index ba76820d15ee8..65b251a6633b5 100644 --- a/tests/trainer/flags/test_env_vars.py +++ b/tests/trainer/flags/test_env_vars.py @@ -12,11 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import os +from unittest import mock from pytorch_lightning import Trainer -def test_passing_env_variables(tmpdir): +def test_passing_no_env_variables(): """Testing overwriting trainer arguments """ trainer = Trainer() assert trainer.logger is not None @@ -25,17 +26,29 @@ def test_passing_env_variables(tmpdir): assert trainer.logger is None assert trainer.max_steps == 42 - os.environ['PL_TRAINER_LOGGER'] = 'False' - os.environ['PL_TRAINER_MAX_STEPS'] = '7' + +@mock.patch.dict(os.environ, {"PL_TRAINER_LOGGER": "False", "PL_TRAINER_MAX_STEPS": "7"}) +def test_passing_env_variables_only(): + """Testing overwriting trainer arguments """ trainer = Trainer() assert trainer.logger is None assert trainer.max_steps == 7 - os.environ['PL_TRAINER_LOGGER'] = 'True' + +@mock.patch.dict(os.environ, {"PL_TRAINER_LOGGER": "True", "PL_TRAINER_MAX_STEPS": "7"}) +def test_passing_env_variables_defaults(): + """Testing overwriting trainer arguments """ trainer = Trainer(False, max_steps=42) - assert trainer.logger is not None - assert trainer.max_steps == 7 + assert trainer.logger is None + assert trainer.max_steps == 42 + - # this has to be cleaned - del os.environ['PL_TRAINER_LOGGER'] - del os.environ['PL_TRAINER_MAX_STEPS'] +@mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1", "PL_TRAINER_GPUS": "2"}) +@mock.patch('torch.cuda.device_count', return_value=2) +@mock.patch('torch.cuda.is_available', return_value=True) +def test_passing_env_variables_gpus(cuda_available_mock, device_count_mock): + """Testing overwriting trainer arguments """ + trainer = Trainer() + assert trainer.gpus == 2 + trainer = Trainer(gpus=1) + assert trainer.gpus == 1 diff --git a/tests/trainer/logging_/test_eval_loop_logging_1_0.py b/tests/trainer/logging_/test_eval_loop_logging_1_0.py index 72084454ba10d..674e2aeb6511b 100644 --- a/tests/trainer/logging_/test_eval_loop_logging_1_0.py +++ b/tests/trainer/logging_/test_eval_loop_logging_1_0.py @@ -126,7 +126,6 @@ def validation_step_end(self, acc): def validation_epoch_end(self, outputs): self.log('g', torch.tensor(2, device=self.device), on_epoch=True) self.validation_epoch_end_called = True - assert len(self.trainer.evaluation_loop.outputs) == 0 def backward(self, loss, optimizer, optimizer_idx): return LightningModule.backward(self, loss, optimizer, optimizer_idx) @@ -496,9 +495,15 @@ def on_validation_start(self, trainer, pl_module): ) def on_epoch_start(self, trainer, pl_module): - self.make_logging( - pl_module, 'on_epoch_start', 2, on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices - ) + if trainer.validating: + self.make_logging( + pl_module, + 'on_epoch_start', + 2, + on_steps=self.choices, + on_epochs=self.choices, + prob_bars=self.choices + ) def on_validation_epoch_start(self, trainer, pl_module): self.make_logging( @@ -530,7 +535,7 @@ def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, self.count += 1 def on_epoch_end(self, trainer, pl_module): - if not trainer.training: + if trainer.validating: self.make_logging( pl_module, 'on_epoch_end', 8, on_steps=[False], on_epochs=self.choices, prob_bars=self.choices ) @@ -568,7 +573,6 @@ def validation_step(self, batch, batch_idx): callbacks=[test_callback], ) trainer.fit(model) - trainer.test() assert test_callback.funcs_called_count["on_epoch_start"] == 1 # assert test_callback.funcs_called_count["on_batch_start"] == 1 diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py index 3db0a8eaa065b..d14ed71940328 100644 --- 
a/tests/trainer/logging_/test_logger_connector.py +++ b/tests/trainer/logging_/test_logger_connector.py @@ -447,13 +447,38 @@ def is_float(value: Any) -> bool: "y": torch.tensor(2), "z": acc(preds, targets), }) - metric_holder.convert(False, device) + metric_holder.convert(device) metrics = metric_holder.metrics assert excepted_function(metrics["x"]) assert excepted_function(metrics["y"]) assert excepted_function(metrics["z"]) +def test_metric_holder_raises(tmpdir): + """Check that an error is raised when trying to convert non-scalar tensors""" + + class TestModel(BoringModel): + + def validation_step(self, batch, *args, **kwargs): + output = self(batch) + return {"test": output} + + def test_step(self, *args, **kwargs): + return self.validation_step(*args, **kwargs) + + model = TestModel() + model.validation_epoch_end = None + model.test_epoch_end = None + + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) + + match = "The metric `test` does not contain a single element" + with pytest.raises(MisconfigurationException, match=match): + trainer.validate(model) + with pytest.raises(MisconfigurationException, match=match): + trainer.test(model) + + def test_logging_to_progress_bar_with_reserved_key(tmpdir): """ Test that logging a metric with a reserved name to the progress bar raises a warning. """ @@ -465,10 +490,7 @@ def training_step(self, *args, **kwargs): return output model = TestModel() - trainer = Trainer( - default_root_dir=tmpdir, - max_steps=2, - ) + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) with pytest.warns(UserWarning, match="The progress bar already tracks a metric with the .* 'loss'"): trainer.fit(model) diff --git a/tests/trainer/optimization/test_optimizers.py b/tests/trainer/optimization/test_optimizers.py index 34845c46b45eb..f13448187364c 100644 --- a/tests/trainer/optimization/test_optimizers.py +++ b/tests/trainer/optimization/test_optimizers.py @@ -292,7 +292,9 @@ def test_init_optimizers_during_evaluation(tmpdir, fn): """ Test that optimizers is an empty list during evaluation """ + class TestModel(BoringModel): + def configure_optimizers(self): optimizer1 = torch.optim.Adam(self.parameters(), lr=0.1) optimizer2 = torch.optim.Adam(self.parameters(), lr=0.1) diff --git a/tests/trainer/properties/test_get_model.py b/tests/trainer/properties/test_get_model.py index 4dc5b5f34b50c..5dc1ea5de4e8a 100644 --- a/tests/trainer/properties/test_get_model.py +++ b/tests/trainer/properties/test_get_model.py @@ -13,7 +13,6 @@ # limitations under the License. 
from pytorch_lightning import Trainer -from tests.accelerators import DDPLauncher from tests.helpers.boring_model import BoringModel from tests.helpers.runif import RunIf @@ -81,25 +80,3 @@ def test_get_model_gpu(tmpdir): gpus=1, ) trainer.fit(model) - - -@RunIf(min_gpus=1, skip_windows=True) -@DDPLauncher.run("--accelerator [accelerator]", max_epochs=["1"], accelerator=["ddp", "ddp_spawn"]) -def test_get_model_ddp_gpu(tmpdir, args=None): - """ - Tests that `trainer.lightning_module` extracts the model correctly when using GPU + ddp accelerators - """ - - model = TrainerGetModel() - - limit_train_batches = 2 - trainer = Trainer( - default_root_dir=tmpdir, - limit_train_batches=limit_train_batches, - limit_val_batches=2, - max_epochs=1, - gpus=1, - accelerator=args.accelerator - ) - trainer.fit(model) - return 1 diff --git a/tests/trainer/test_config_validator.py b/tests/trainer/test_config_validator.py index 59e10480a485e..9fccd9b36440a 100644 --- a/tests/trainer/test_config_validator.py +++ b/tests/trainer/test_config_validator.py @@ -12,10 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. import pytest +import torch -from pytorch_lightning import Trainer +from pytorch_lightning import LightningDataModule, LightningModule, Trainer from pytorch_lightning.utilities.exceptions import MisconfigurationException -from tests.helpers import BoringModel +from tests.helpers import BoringModel, RandomDataset def test_wrong_train_setting(tmpdir): @@ -101,3 +102,48 @@ def test_val_loop_config(tmpdir): model = BoringModel() model.validation_step = None trainer.validate(model) + + +@pytest.mark.parametrize("datamodule", [False, True]) +def test_trainer_predict_verify_config(tmpdir, datamodule): + + class TestModel(LightningModule): + + def __init__(self): + super().__init__() + self.layer = torch.nn.Linear(32, 2) + + def forward(self, x): + return self.layer(x) + + class TestLightningDataModule(LightningDataModule): + + def __init__(self, dataloaders): + super().__init__() + self._dataloaders = dataloaders + + def test_dataloader(self): + return self._dataloaders + + def predict_dataloader(self): + return self._dataloaders + + dataloaders = [torch.utils.data.DataLoader(RandomDataset(32, 2)), torch.utils.data.DataLoader(RandomDataset(32, 2))] + + model = TestModel() + + trainer = Trainer(default_root_dir=tmpdir) + + if datamodule: + datamodule = TestLightningDataModule(dataloaders) + results = trainer.predict(model, datamodule=datamodule) + else: + results = trainer.predict(model, dataloaders=dataloaders) + + assert len(results) == 2 + assert results[0][0].shape == torch.Size([1, 2]) + + model.predict_dataloader = None + + with pytest.raises(MisconfigurationException, match="Dataloader not found for `Trainer.predict`"): + trainer.predict(model) diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index e4aea38fb7f37..505af173b7910 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -137,6 +137,7 @@ def test_multiple_eval_dataloader(tmpdir, ckpt_path): """Verify multiple evaluation dataloaders.""" class MultipleTestDataloaderModel(EvalModelTemplate): + def test_dataloader(self): return [self.dataloader(train=False), self.dataloader(train=False)] @@ -1158,3 +1159,71 @@ def test_replace_sampler_with_multiprocessing_context(tmpdir): new_data_loader = trainer.replace_sampler(train, SequentialSampler(train.dataset)) assert (new_data_loader.multiprocessing_context == 
train.multiprocessing_context) + + +def test_request_dataloader(tmpdir): + """ + This test asserts dataloader can be modified and properly set to the trainer. + """ + + class DataLoaderWrapper: + + def __init__(self, loader): + self.loader = loader + self._iter = iter(self.loader) + + def __iter__(self): + self._iter = iter(self.loader) + return self._iter + + def __next__(self): + return next(self._iter) + + class DataLoaderFunc: + + def __init__(self, loader): + self.loader = loader + + def __call__(self): + return self.loader + + class TestModel(BoringModel): + + def __init__(self): + super().__init__() + self.on_train_dataloader_called = False + self.on_train_batch_start_called = False + self.on_val_dataloader_called = False + self.on_val_batch_start_called = False + + def on_train_dataloader(self) -> None: + loader = self.train_dataloader() + self.train_dataloader = DataLoaderFunc(DataLoaderWrapper(loader)) + self.on_train_dataloader_called = True + + def on_train_batch_start(self, batch, batch_idx: int, dataloader_idx: int) -> None: + assert isinstance(self.trainer.train_dataloader.loaders, DataLoaderWrapper) + self.on_train_batch_start_called = True + + def on_val_dataloader(self) -> None: + loader = self.val_dataloader() + self.val_dataloader = DataLoaderFunc(DataLoaderWrapper(loader)) + self.on_val_dataloader_called = True + + def on_validation_batch_start(self, batch, batch_idx: int, dataloader_idx: int) -> None: + assert isinstance(self.trainer.val_dataloaders[0], DataLoaderWrapper) + self.on_val_batch_start_called = True + + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=2, + max_epochs=1, + ) + model = TestModel() + trainer.fit(model) + trainer.test(model) + assert model.on_train_dataloader_called + assert model.on_train_batch_start_called + assert model.on_val_dataloader_called + assert model.on_val_batch_start_called diff --git a/tests/trainer/test_evaluation_loop.py b/tests/trainer/test_evaluation_loop.py new file mode 100644 index 0000000000000..3fe58afde7341 --- /dev/null +++ b/tests/trainer/test_evaluation_loop.py @@ -0,0 +1,42 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from unittest import mock + +from pytorch_lightning import Trainer +from tests.helpers.boring_model import BoringModel + + +@mock.patch("pytorch_lightning.trainer.evaluation_loop.EvaluationLoop.call_on_evaluation_epoch_end_hook") +def test_call_on_evaluation_epoch_end_hook(eval_epoch_end_mock, tmpdir): + """ + Tests that `call_on_evaluation_epoch_end_hook` is called + for `on_validation_epoch_end` and `on_test_epoch_end` hooks + """ + model = BoringModel() + + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=2, + max_epochs=2, + weights_summary=None, + ) + + trainer.fit(model) + # sanity + 2 epochs + assert eval_epoch_end_mock.call_count == 3 + + trainer.test() + # sanity + 2 epochs + called once for test + assert eval_epoch_end_mock.call_count == 4 diff --git a/tests/trainer/test_lr_finder.py b/tests/trainer/test_lr_finder.py index e85c43361976d..44510eb16184d 100644 --- a/tests/trainer/test_lr_finder.py +++ b/tests/trainer/test_lr_finder.py @@ -271,3 +271,27 @@ def test_lr_finder_fails_fast_on_bad_config(tmpdir): trainer = Trainer(default_root_dir=tmpdir, max_steps=2, auto_lr_find=True) with pytest.raises(MisconfigurationException, match='should have one of these fields'): trainer.tune(BoringModel()) + + +def test_lr_find_with_bs_scale(tmpdir): + """ Test that lr_find runs with batch_size_scaling """ + + class BoringModelTune(BoringModel): + + def __init__(self, learning_rate=0.1, batch_size=2): + super().__init__() + self.save_hyperparameters() + + model = BoringModelTune() + before_lr = model.hparams.learning_rate + + # logger file to get meta + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=3, + ) + bs = trainer.tuner.scale_batch_size(model) + lr = trainer.tuner.lr_find(model).suggestion() + + assert lr != before_lr + assert isinstance(bs, int) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 5b06879b1f6d1..4ca2f737f5106 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -17,7 +17,6 @@ import sys from argparse import Namespace from copy import deepcopy -from distutils.version import LooseVersion from pathlib import Path from unittest.mock import ANY, call, patch @@ -43,12 +42,6 @@ from tests.helpers.runif import RunIf -@pytest.fixture -def pytorch_profiler(tmpdir): - profiler = PyTorchProfiler(output_filename=os.path.join(tmpdir, "profiler.txt"), local_rank=0) - return profiler - - @pytest.mark.parametrize("url_ckpt", [True, False]) def test_no_val_module(monkeypatch, tmpdir, tmpdir_server, url_ckpt): """Tests use case where trainer saves the model, and user loads it from tags independently.""" @@ -603,7 +596,9 @@ def test_benchmark_option(tmpdir): @pytest.mark.parametrize("save_top_k", (-1, 0, 1, 2)) @pytest.mark.parametrize("fn", ("validate", "test")) def test_tested_checkpoint_path(tmpdir, ckpt_path, save_top_k, fn): + class TestModel(BoringModel): + def validation_step(self, batch, batch_idx): self.log("foo", -batch_idx) return super().validation_step(batch, batch_idx) @@ -1445,16 +1440,30 @@ def test_trainer_predict_no_return(tmpdir): class CustomBoringModel(BoringModel): - def predict(self, batch, batch_idx, dataloader_idx=None): + def predict_step(self, batch, batch_idx, dataloader_idx=None): if (batch_idx + 1) % 2 == 0: return - return super().predict(batch, batch_idx, dataloader_idx) + return super().predict_step(batch, batch_idx, dataloader_idx) with pytest.warns(UserWarning, match='predict returned None'): predict(tmpdir, None, None, 1, 
model=CustomBoringModel()) +def test_trainer_predict_grad(tmpdir): + + class CustomBoringModel(BoringModel): + + def predict_step(self, batch, batch_idx, dataloader_idx=None): + assert batch.expand_as(batch).grad_fn is None + return super().predict_step(batch, batch_idx, dataloader_idx) + + predict(tmpdir, None, None, 1, model=CustomBoringModel()) + + x = torch.zeros(1, requires_grad=True) + assert x.expand_as(x).grad_fn is not None + + @pytest.mark.parametrize('datamodule', [False, True]) def test_trainer_predict_cpu(tmpdir, datamodule): predict(tmpdir, None, None, 1, datamodule=datamodule) @@ -1486,124 +1495,6 @@ def test_trainer_predict_ddp_cpu(tmpdir): predict(tmpdir, "ddp_cpu", 0, 2) -def test_pytorch_profiler_describe(pytorch_profiler): - """Ensure the profiler won't fail when reporting the summary.""" - with pytorch_profiler.profile("test_step"): - pass - - # log to stdout and print to file - pytorch_profiler.describe() - data = Path(pytorch_profiler.output_fname).read_text() - assert len(data) > 0 - - -def test_pytorch_profiler_value_errors(pytorch_profiler): - """Ensure errors are raised where expected.""" - - action = "test_step" - with pytest.raises(ValueError): - pytorch_profiler.stop(action) - - pytorch_profiler.start(action) - pytorch_profiler.stop(action) - - -@RunIf(min_gpus=2, special=True) -@pytest.mark.parametrize("use_output_filename", [False, True]) -def test_pytorch_profiler_trainer_ddp(tmpdir, use_output_filename): - """Ensure that the profiler can be given to the training and default step are properly recorded. """ - - if use_output_filename: - output_filename = os.path.join(tmpdir, "profiler.txt") - else: - output_filename = None - - profiler = PyTorchProfiler(output_filename=output_filename) - - model = BoringModel() - trainer = Trainer( - fast_dev_run=True, - profiler=profiler, - accelerator="ddp", - gpus=2, - ) - trainer.fit(model) - - enabled = use_output_filename or not use_output_filename and profiler.local_rank == 0 - - if enabled: - assert len(profiler.summary()) > 0 - assert set(profiler.profiled_actions.keys()) == {'training_step_and_backward', 'validation_step'} - else: - assert profiler.summary() is None - assert set(profiler.profiled_actions.keys()) == set() - - if use_output_filename: - profiler.describe() - data = Path(profiler.output_fname).read_text() - assert len(data) > 0 - - -def test_pytorch_profiler_nested(tmpdir): - """Ensure that the profiler handles nested context""" - - pytorch_profiler = PyTorchProfiler( - profiled_functions=["a", "b", "c"], use_cuda=False, output_filename=os.path.join(tmpdir, "profiler.txt") - ) - - with pytorch_profiler.profile("a"): - a = torch.ones(42) - with pytorch_profiler.profile("b"): - b = torch.zeros(42) - with pytorch_profiler.profile("c"): - _ = a + b - - pa = pytorch_profiler.profiled_actions - - # From PyTorch 1.8.0, less operation are being traced. - if LooseVersion(torch.__version__) >= LooseVersion("1.8.0"): - expected_ = { - 'a': ['ones', 'empty', 'fill_', 'zeros', 'empty', 'zero_', 'add'], - 'b': ['zeros', 'empty', 'zero_'], - 'c': ['add'], - } - # From PyTorch 1.6.0, more operation are being traced. 
- elif LooseVersion(torch.__version__) >= LooseVersion("1.6.0"): - expected_ = { - 'a': ['ones', 'empty', 'fill_', 'zeros', 'empty', 'zero_', 'fill_', 'add', 'empty'], - 'b': ['zeros', 'empty', 'zero_', 'fill_'], - 'c': ['add', 'empty'], - } - else: - expected_ = { - 'a': ['add'], - 'b': [], - 'c': ['add'], - } - - for n in ('a', 'b', 'c'): - pa[n] = [e.name for e in pa[n]] - if LooseVersion(torch.__version__) >= LooseVersion("1.7.1"): - pa[n] = [e.replace("aten::", "") for e in pa[n]] - assert pa[n] == expected_[n] - - -@RunIf(min_gpus=1, special=True) -def test_pytorch_profiler_nested_emit_nvtx(tmpdir): - """ - This test check emit_nvtx is correctly supported - """ - profiler = PyTorchProfiler(use_cuda=True, emit_nvtx=True) - - model = BoringModel() - trainer = Trainer( - fast_dev_run=True, - profiler=profiler, - gpus=1, - ) - trainer.fit(model) - - @pytest.mark.parametrize( ["limit_train_batches", "global_step", "num_training_batches", "current_epoch", "should_train"], [(0.2, 0, 0, 0, False), (0.5, 10, 2, 4, True)], @@ -1854,3 +1745,35 @@ def test_check_val_every_n_epoch_exception(tmpdir): max_epochs=1, check_val_every_n_epoch=1.2, ) + + +def test_trainer_attach_data_pipeline_to_model(tmpdir): + + class DataPipeline: + + pass + + class TestDataModule(LightningDataModule): + + data_pipeline = DataPipeline() + + def train_dataloader(self): + return DataLoader(RandomDataset(32, 64)) + + def val_dataloader(self): + return DataLoader(RandomDataset(32, 64)) + + def test_dataloader(self): + return DataLoader(RandomDataset(32, 64)) + + class TestCallback(Callback): + + def on_fit_start(self, trainer, pl_module: LightningModule) -> None: + """Called when fit begins""" + assert isinstance(pl_module.data_pipeline, DataPipeline) + + model = BoringModel() + dm = TestDataModule() + + trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, callbacks=[TestCallback()]) + trainer.fit(model, datamodule=dm) diff --git a/tests/tuner/test_scale_batch_size.py b/tests/tuner/test_scale_batch_size.py new file mode 100644 index 0000000000000..ad7fc57092f32 --- /dev/null +++ b/tests/tuner/test_scale_batch_size.py @@ -0,0 +1,65 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import pytest +from torch.utils.data import DataLoader + +from pytorch_lightning import Trainer +from pytorch_lightning.tuner.tuning import Tuner +from tests.helpers import BoringDataModule, BoringModel + + +class BatchSizeDataModule(BoringDataModule): + + def __init__(self, batch_size=None): + super().__init__() + if batch_size is not None: + self.batch_size = batch_size + + def train_dataloader(self): + return DataLoader(self.random_train, batch_size=getattr(self, "batch_size", 1)) + + +class BatchSizeModel(BoringModel): + + def __init__(self, batch_size=None): + super().__init__() + if batch_size is not None: + self.batch_size = batch_size + + +@pytest.mark.parametrize( + "model,datamodule", [ + (BatchSizeModel(2), None), + (BatchSizeModel(2), BatchSizeDataModule(2)), + (BatchSizeModel(2), BatchSizeDataModule(None)), + (BatchSizeModel(None), BatchSizeDataModule(2)), + ] +) +def test_scale_batch_size_method_with_model_or_datamodule(tmpdir, model, datamodule): + """ Test the tuner method `Tuner.scale_batch_size` with a datamodule. """ + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=1, + limit_val_batches=0, + max_epochs=1, + ) + tuner = Tuner(trainer) + new_batch_size = tuner.scale_batch_size( + model=model, mode="binsearch", init_val=4, max_trials=2, datamodule=datamodule + ) + assert new_batch_size == 16 + if hasattr(model, "batch_size"): + assert model.batch_size == 16 + if datamodule is not None and hasattr(datamodule, "batch_size"): + assert datamodule.batch_size == 16 diff --git a/tests/utilities/test_all_gather_grad.py b/tests/utilities/test_all_gather_grad.py index 259f9f4c09871..d67c9473bbb2e 100644 --- a/tests/utilities/test_all_gather_grad.py +++ b/tests/utilities/test_all_gather_grad.py @@ -95,3 +95,26 @@ def training_epoch_end(self, outputs) -> None: trainer.fit(model) assert model.training_epoch_end_called + + +@RunIf(min_gpus=2, skip_windows=True, special=True) +def test_all_gather_sync_grads(tmpdir): + + class TestModel(BoringModel): + + training_step_called = False + + def training_step(self, batch, batch_idx): + self.training_step_called = True + tensor = torch.rand(2, 2, requires_grad=True, device=self.device) + gathered_tensor = self.all_gather(tensor, sync_grads=True) + assert gathered_tensor.shape == torch.Size([2, 2, 2]) + + loss = gathered_tensor.sum() + + return loss + + model = TestModel() + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, gpus=2) + trainer.fit(model) + assert model.training_step_called diff --git a/tests/utilities/test_argparse_utils.py b/tests/utilities/test_argparse.py similarity index 80% rename from tests/utilities/test_argparse_utils.py rename to tests/utilities/test_argparse.py index b2eac514941e6..f13af4362364c 100644 --- a/tests/utilities/test_argparse_utils.py +++ b/tests/utilities/test_argparse.py @@ -1,17 +1,52 @@ import io -from argparse import ArgumentParser +from argparse import ArgumentParser, Namespace from typing import List +from unittest.mock import MagicMock import pytest from pytorch_lightning import Trainer from pytorch_lightning.utilities.argparse import ( + _gpus_arg_default, + _int_or_float_type, add_argparse_args, + from_argparse_args, get_abbrev_qualified_cls_name, + parse_argparser, parse_args_from_docstring, ) +class ArgparseExample: + + def __init__(self, a: int = 0, b: str = '', c: bool = False): + self.a = a + self.b = b + self.c = c + + +def test_from_argparse_args(): + args = Namespace(a=1, b='test', c=True, d='not valid') + my_instance = from_argparse_args(ArgparseExample, 
args) + assert my_instance.a == 1 + assert my_instance.b == 'test' + assert my_instance.c + + parser = ArgumentParser() + mock_trainer = MagicMock() + _ = from_argparse_args(mock_trainer, parser) + mock_trainer.parse_argparser.assert_called_once_with(parser) + + +def test_parse_argparser(): + args = Namespace(a=1, b='test', c=None, d='not valid') + new_args = parse_argparser(ArgparseExample, args) + assert new_args.a == 1 + assert new_args.b == 'test' + assert new_args.c + assert new_args.d == 'not valid' + + def test_parse_args_from_docstring_normal(): args_help = parse_args_from_docstring( """Constrain image dataset @@ -168,3 +203,13 @@ def test_add_argparse_args_no_argument_group(): args = parser.parse_args(fake_argv) assert args.main_arg == "abc" assert args.my_parameter == 2 + + +def test_gpus_arg_default(): + assert _gpus_arg_default('1,2') == '1,2' + assert _gpus_arg_default('1') == 1 + + +def test_int_or_float_type(): + assert isinstance(_int_or_float_type('0.0'), float) + assert isinstance(_int_or_float_type('0'), int) From 6af3f1e0948c50de4653825c2fa783276652e2a3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lucas-Rapahel=20Mu=CC=88ller?= Date: Fri, 26 Mar 2021 12:23:39 +0100 Subject: [PATCH 10/25] Change from manual to pythonic tempfile --- pytorch_lightning/loggers/mlflow.py | 13 ++++++++----- tests/loggers/test_mlflow.py | 3 ++- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/loggers/mlflow.py b/pytorch_lightning/loggers/mlflow.py index 82412c79d3f1f..75e62539dd0a8 100644 --- a/pytorch_lightning/loggers/mlflow.py +++ b/pytorch_lightning/loggers/mlflow.py @@ -17,6 +17,7 @@ """ import logging import re +import tempfile from argparse import Namespace from pathlib import Path from time import time @@ -199,12 +200,14 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> def log_figure(self, name: str, figure, step: Optional[int] = None, close: bool = True) -> None: import matplotlib.pyplot as plt - # save temporary file until logged - filename = Path(self.save_dir) / (name + f"_step_{step}" + self._figure_file_extension) - figure.savefig(filename) - self.experiment.log_artifact(self.run_id, filename, artifact_path="figure_" + name) + with tempfile.NamedTemporaryFile(suffix=self._figure_file_extension) as tmp_file: + figure.savefig(tmp_file) + self.experiment.log_artifact( + self.run_id, + tmp_file.name, + artifact_path=Path(self.save_dir) / ("figure_" + name) + ) - filename.unlink() # delete temporary file if close: plt.close(figure) diff --git a/tests/loggers/test_mlflow.py b/tests/loggers/test_mlflow.py index a49e7bd23f065..cb461fe4ef387 100644 --- a/tests/loggers/test_mlflow.py +++ b/tests/loggers/test_mlflow.py @@ -247,4 +247,5 @@ def test_mlflow_log_figure(client, mlflow, step_idx, figure_format, tmpdir): fname_expect = logger.save_dir + f'/dummy_step_{step_idx}{figure_format}' artifact_expect = 'figure_dummy' - mock_log.assert_called_once_with(logger.run_id, Path(fname_expect), artifact_path=artifact_expect) + mock_log.assert_called_once() + mock_log.call_args_list[0][1]['artifact_path'] == artifact_expect From 7f2407b67d0ca9eeea2b582902ab8970966256c9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lucas-Rapahel=20Mu=CC=88ller?= Date: Fri, 26 Mar 2021 12:33:09 +0100 Subject: [PATCH 11/25] Revert "Merge master" This reverts commit 32f67c6dce617067938c4d9c64eef05b344c7ee6. 
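The tempfile change in patch 10 above boils down to the following pattern. A minimal sketch, where `client`, `run_id`, and the `log_figure_as_artifact` name are stand-ins for the logger's `MlflowClient`, its run id, and the method on the logger, and the explicit `flush()` is an assumption added here so the artifact upload sees the fully written file:

```python
import tempfile

import matplotlib.pyplot as plt


def log_figure_as_artifact(client, run_id, name, figure, extension=".png", close=True):
    # NamedTemporaryFile removes the file when the context exits, so the manual
    # unlink() of the previous implementation is no longer needed.
    with tempfile.NamedTemporaryFile(suffix=extension) as tmp_file:
        figure.savefig(tmp_file)
        tmp_file.flush()  # make sure the bytes are on disk before MLflow copies the file
        client.log_artifact(run_id, tmp_file.name, artifact_path="figure_" + name)
    if close:
        plt.close(figure)
```

Because the artifact is now read back from `tmp_file.name`, the corresponding test can only assert on the `artifact_path` keyword rather than on a deterministic filename, which is what the `test_mlflow_log_figure` change above reflects.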
--- .github/CODEOWNERS | 6 +- .github/workflows/ci_dockers.yml | 18 +- .github/workflows/ci_test-base.yml | 5 +- .github/workflows/ci_test-conda.yml | 2 +- .github/workflows/ci_test-full.yml | 6 +- .github/workflows/docs-checks.yml | 2 +- .github/workflows/events-nightly.yml | 23 - .github/workflows/release-docker.yml | 26 +- .gitignore | 1 - .pre-commit-config.yaml | 5 + CHANGELOG.md | 113 +-- Makefile | 2 +- azure-pipelines.yml | 34 +- dockers/nvidia/Dockerfile | 19 +- dockers/release/Dockerfile | 1 + docs/source/advanced/multi_gpu.rst | 56 +- docs/source/advanced/multiple_loaders.rst | 50 +- docs/source/benchmarking/performance.rst | 16 - docs/source/common/lightning_module.rst | 101 +- docs/source/common/trainer.rst | 5 +- docs/source/conf.py | 25 +- docs/source/extensions/callbacks.rst | 12 - docs/source/extensions/datamodules.rst | 13 +- docs/source/extensions/logging.rst | 2 +- docs/source/extensions/metrics.rst | 884 +++++++++++++++++- docs/source/starter/introduction_guide.rst | 6 +- docs/source/starter/new-project.rst | 2 +- pl_examples/__init__.py | 4 +- pl_examples/basic_examples/autoencoder.py | 14 +- .../basic_examples/conv_sequential_example.py | 2 +- .../basic_examples/profiler_example.py | 102 -- pl_examples/basic_examples/submit_ddp2_job.sh | 2 +- pl_examples/basic_examples/submit_ddp_job.sh | 2 +- .../computer_vision_fine_tuning.py | 5 +- pytorch_lightning/__init__.py | 81 +- pytorch_lightning/accelerators/accelerator.py | 86 +- pytorch_lightning/accelerators/cpu.py | 13 - pytorch_lightning/accelerators/gpu.py | 15 +- pytorch_lightning/accelerators/tpu.py | 39 +- pytorch_lightning/callbacks/base.py | 12 +- pytorch_lightning/callbacks/early_stopping.py | 2 +- .../gradient_accumulation_scheduler.py | 2 +- .../callbacks/model_checkpoint.py | 40 +- pytorch_lightning/callbacks/progress.py | 9 +- pytorch_lightning/core/datamodule.py | 67 +- pytorch_lightning/core/hooks.py | 106 +-- pytorch_lightning/core/lightning.py | 23 +- pytorch_lightning/core/memory.py | 12 +- pytorch_lightning/core/step_result.py | 2 +- pytorch_lightning/distributed/dist.py | 51 +- pytorch_lightning/info.py | 36 - pytorch_lightning/loggers/mlflow.py | 14 +- pytorch_lightning/metrics/__init__.py | 7 - .../metrics/classification/accuracy.py | 130 ++- .../metrics/classification/auc.py | 67 +- .../metrics/classification/auroc.py | 158 +++- .../classification/average_precision.py | 103 +- .../classification/confusion_matrix.py | 90 +- .../metrics/classification/f_beta.py | 180 +++- .../classification/hamming_distance.py | 88 +- .../metrics/classification/helpers.py | 539 +++++++++++ .../metrics/classification/iou.py | 83 +- .../classification/precision_recall.py | 283 +++++- .../classification/precision_recall_curve.py | 131 ++- .../metrics/classification/roc.py | 126 ++- .../metrics/classification/stat_scores.py | 245 ++++- pytorch_lightning/metrics/compositional.py | 97 +- .../metrics/functional/accuracy.py | 101 +- pytorch_lightning/metrics/functional/auc.py | 59 +- pytorch_lightning/metrics/functional/auroc.py | 168 +++- .../metrics/functional/average_precision.py | 78 +- .../metrics/functional/classification.py | 331 ++++++- .../metrics/functional/confusion_matrix.py | 74 +- .../metrics/functional/explained_variance.py | 69 +- .../metrics/functional/f_beta.py | 120 ++- .../metrics/functional/hamming_distance.py | 61 +- .../metrics/functional/image_gradients.py | 58 +- pytorch_lightning/metrics/functional/iou.py | 87 +- .../metrics/functional/mean_absolute_error.py | 35 +- 
.../metrics/functional/mean_relative_error.py | 38 +- .../metrics/functional/mean_squared_error.py | 35 +- .../functional/mean_squared_log_error.py | 35 +- pytorch_lightning/metrics/functional/nlp.py | 89 +- .../metrics/functional/precision_recall.py | 453 ++++++++- .../functional/precision_recall_curve.py | 196 +++- pytorch_lightning/metrics/functional/psnr.py | 84 +- .../metrics/functional/r2score.py | 112 ++- pytorch_lightning/metrics/functional/roc.py | 132 ++- .../metrics/functional/self_supervised.py | 41 +- pytorch_lightning/metrics/functional/ssim.py | 130 ++- .../metrics/functional/stat_scores.py | 260 +++++- pytorch_lightning/metrics/metric.py | 603 +++++++++++- .../metrics/regression/explained_variance.py | 106 ++- .../metrics/regression/mean_absolute_error.py | 64 +- .../metrics/regression/mean_squared_error.py | 65 +- .../regression/mean_squared_log_error.py | 67 +- pytorch_lightning/metrics/regression/psnr.py | 123 ++- .../metrics/regression/r2score.py | 122 ++- pytorch_lightning/metrics/regression/ssim.py | 78 +- pytorch_lightning/metrics/utils.py | 287 +++++- pytorch_lightning/overrides/base.py | 2 +- .../overrides/torch_distributed.py | 94 -- pytorch_lightning/plugins/__init__.py | 2 - .../plugins/precision/__init__.py | 1 - .../plugins/precision/apex_amp.py | 1 + pytorch_lightning/plugins/precision/double.py | 95 -- .../plugins/precision/native_amp.py | 18 - .../plugins/precision/precision_plugin.py | 1 + .../plugins/training_type/ddp.py | 71 +- .../plugins/training_type/ddp_spawn.py | 8 +- .../plugins/training_type/deepspeed.py | 10 + pytorch_lightning/plugins/training_type/dp.py | 6 +- .../plugins/training_type/horovod.py | 17 +- .../plugins/training_type/parallel.py | 38 +- .../plugins/training_type/rpc.py | 2 +- .../plugins/training_type/rpc_sequential.py | 2 +- .../plugins/training_type/single_device.py | 12 +- .../plugins/training_type/single_tpu.py | 7 +- .../plugins/training_type/tpu_spawn.py | 31 +- .../training_type/training_type_plugin.py | 45 +- .../plugins/training_type/utils.py | 13 - pytorch_lightning/profiler/__init__.py | 25 +- pytorch_lightning/profiler/profilers.py | 277 ++---- pytorch_lightning/profiler/pytorch.py | 601 ++++-------- pytorch_lightning/setup_tools.py | 6 +- pytorch_lightning/trainer/callback_hook.py | 60 +- .../trainer/configuration_validator.py | 9 +- .../connectors/accelerator_connector.py | 15 +- .../trainer/connectors/data_connector.py | 4 - .../trainer/connectors/env_vars_connector.py | 14 +- .../logger_connector/epoch_result_store.py | 9 +- .../logger_connector/logger_connector.py | 11 +- .../logger_connector/metrics_holder.py | 57 +- .../trainer/connectors/profiler_connector.py | 8 +- pytorch_lightning/trainer/data_loading.py | 23 +- pytorch_lightning/trainer/deprecated_api.py | 56 +- pytorch_lightning/trainer/evaluation_loop.py | 47 +- pytorch_lightning/trainer/predict_loop.py | 25 +- pytorch_lightning/trainer/properties.py | 10 - pytorch_lightning/trainer/trainer.py | 92 +- pytorch_lightning/trainer/training_loop.py | 39 +- pytorch_lightning/tuner/tuning.py | 13 +- pytorch_lightning/utilities/__init__.py | 1 - pytorch_lightning/utilities/argparse.py | 20 +- pytorch_lightning/utilities/argparse_utils.py | 6 +- pytorch_lightning/utilities/distributed.py | 6 +- pytorch_lightning/utilities/imports.py | 24 +- pytorch_lightning/utilities/model_utils.py | 7 +- .../utilities/signature_utils.py | 22 - pytorch_lightning/utilities/warning_utils.py | 6 +- .../utilities/xla_device_utils.py | 7 +- requirements.txt | 2 - 
requirements/adjust_versions.py | 1 - requirements/extra.txt | 3 +- requirements/test.txt | 7 +- setup.cfg | 17 +- setup.py | 46 +- tests/__init__.py | 4 +- .../test_accelerator_connector.py | 30 +- tests/accelerators/test_common.py | 12 +- tests/accelerators/test_cpu.py | 32 - tests/accelerators/test_ddp.py | 30 +- tests/base/model_template.py | 3 +- tests/callbacks/test_callback_hook_outputs.py | 63 -- tests/callbacks/test_callbacks.py | 18 +- .../test_checkpoint_callback_frequency.py | 40 - .../checkpointing/test_legacy_checkpoints.py | 5 +- tests/checkpointing/test_model_checkpoint.py | 5 +- tests/checkpointing/test_torch_saving.py | 1 - tests/core/test_datamodules.py | 61 +- tests/core/test_hooks.py | 56 -- tests/core/test_memory.py | 24 +- tests/core/test_metric_result_integration.py | 2 +- tests/deprecated_api/__init__.py | 18 - tests/deprecated_api/test_remove_1-4.py | 18 + tests/deprecated_api/test_remove_1-5.py | 107 --- tests/helpers/advanced_models.py | 4 +- tests/helpers/datasets.py | 15 +- tests/helpers/runif.py | 7 - tests/helpers/test_datasets.py | 13 +- tests/loggers/test_mlflow.py | 28 - tests/metrics/classification/__init__.py | 0 tests/metrics/classification/inputs.py | 66 ++ tests/metrics/classification/test_accuracy.py | 175 ++++ tests/metrics/classification/test_auc.py | 64 ++ tests/metrics/classification/test_auroc.py | 142 +++ .../classification/test_average_precision.py | 97 ++ .../classification/test_confusion_matrix.py | 128 +++ tests/metrics/classification/test_f_beta.py | 153 +++ .../classification/test_hamming_distance.py | 80 ++ tests/metrics/classification/test_inputs.py | 311 ++++++ tests/metrics/classification/test_iou.py | 216 +++++ .../classification/test_precision_recall.py | 347 +++++++ .../test_precision_recall_curve.py | 97 ++ tests/metrics/classification/test_roc.py | 99 ++ .../classification/test_stat_scores.py | 255 +++++ tests/metrics/functional/__init__.py | 0 .../metrics/functional/test_classification.py | 89 ++ .../functional/test_image_gradients.py | 109 +++ tests/metrics/functional/test_nlp.py | 68 ++ tests/metrics/functional/test_reduction.py | 28 + .../functional/test_self_supervised.py | 32 + tests/metrics/regression/__init__.py | 0 .../regression/test_explained_variance.py | 77 ++ tests/metrics/regression/test_mean_error.py | 87 ++ tests/metrics/regression/test_psnr.py | 133 +++ tests/metrics/regression/test_r2score.py | 114 +++ tests/metrics/regression/test_ssim.py | 104 +++ tests/metrics/test_composition.py | 510 ++++++++++ tests/metrics/test_ddp.py | 71 ++ tests/metrics/test_metric.py | 395 ++++++++ tests/metrics/test_metric_lightning.py | 8 +- tests/metrics/test_remove_1-5_metrics.py | 348 ------- tests/metrics/utils.py | 3 +- .../data/horovod/train_default_model.py | 9 +- tests/models/test_amp.py | 31 +- tests/models/test_hooks.py | 261 ++---- tests/models/test_horovod.py | 118 +-- tests/models/test_tpu.py | 41 - tests/overrides/test_data_parallel.py | 2 +- tests/plugins/test_custom_plugin.py | 41 - tests/plugins/test_deepspeed_plugin.py | 12 +- tests/plugins/test_double_plugin.py | 129 --- tests/plugins/test_sharded_plugin.py | 10 +- tests/special_tests.sh | 13 +- tests/test_profiler.py | 360 +------ tests/trainer/flags/test_env_vars.py | 31 +- .../logging_/test_eval_loop_logging_1_0.py | 16 +- .../trainer/logging_/test_logger_connector.py | 32 +- tests/trainer/optimization/test_optimizers.py | 2 - tests/trainer/properties/test_get_model.py | 23 + tests/trainer/test_config_validator.py | 50 +- tests/trainer/test_dataloaders.py | 
69 -- tests/trainer/test_evaluation_loop.py | 42 - tests/trainer/test_lr_finder.py | 24 - tests/trainer/test_trainer.py | 177 +++- tests/tuner/test_scale_batch_size.py | 65 -- tests/utilities/test_all_gather_grad.py | 23 - ...est_argparse.py => test_argparse_utils.py} | 47 +- 239 files changed, 12644 insertions(+), 4902 deletions(-) delete mode 100644 pl_examples/basic_examples/profiler_example.py delete mode 100644 pytorch_lightning/info.py create mode 100644 pytorch_lightning/metrics/classification/helpers.py delete mode 100644 pytorch_lightning/overrides/torch_distributed.py delete mode 100644 pytorch_lightning/plugins/precision/double.py delete mode 100644 pytorch_lightning/utilities/signature_utils.py delete mode 100644 tests/core/test_hooks.py create mode 100644 tests/metrics/classification/__init__.py create mode 100644 tests/metrics/classification/inputs.py create mode 100644 tests/metrics/classification/test_accuracy.py create mode 100644 tests/metrics/classification/test_auc.py create mode 100644 tests/metrics/classification/test_auroc.py create mode 100644 tests/metrics/classification/test_average_precision.py create mode 100644 tests/metrics/classification/test_confusion_matrix.py create mode 100644 tests/metrics/classification/test_f_beta.py create mode 100644 tests/metrics/classification/test_hamming_distance.py create mode 100644 tests/metrics/classification/test_inputs.py create mode 100644 tests/metrics/classification/test_iou.py create mode 100644 tests/metrics/classification/test_precision_recall.py create mode 100644 tests/metrics/classification/test_precision_recall_curve.py create mode 100644 tests/metrics/classification/test_roc.py create mode 100644 tests/metrics/classification/test_stat_scores.py create mode 100644 tests/metrics/functional/__init__.py create mode 100644 tests/metrics/functional/test_classification.py create mode 100644 tests/metrics/functional/test_image_gradients.py create mode 100644 tests/metrics/functional/test_nlp.py create mode 100644 tests/metrics/functional/test_reduction.py create mode 100644 tests/metrics/functional/test_self_supervised.py create mode 100644 tests/metrics/regression/__init__.py create mode 100644 tests/metrics/regression/test_explained_variance.py create mode 100644 tests/metrics/regression/test_mean_error.py create mode 100644 tests/metrics/regression/test_psnr.py create mode 100644 tests/metrics/regression/test_r2score.py create mode 100644 tests/metrics/regression/test_ssim.py create mode 100644 tests/metrics/test_composition.py create mode 100644 tests/metrics/test_ddp.py create mode 100644 tests/metrics/test_metric.py delete mode 100644 tests/metrics/test_remove_1-5_metrics.py delete mode 100644 tests/plugins/test_custom_plugin.py delete mode 100644 tests/plugins/test_double_plugin.py delete mode 100644 tests/trainer/test_evaluation_loop.py delete mode 100644 tests/tuner/test_scale_batch_size.py rename tests/utilities/{test_argparse.py => test_argparse_utils.py} (80%) diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 4ac6944c7a31a..6afdcc4cbe29f 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -34,9 +34,9 @@ /pytorch_lightning/utilities @borda @tchaton @SeanNaren @carmocca # Metrics -/pytorch_lightning/metrics/ @SkafteNicki @ananyahjha93 @justusschock -/tests/metrics/ @SkafteNicki @ananyahjha93 @justusschock -/docs/source/metrics.rst @SkafteNicki @ananyahjha93 @justusschock +/pytorch_lightning/metrics/ @teddykoker @ananyahjha93 @justusschock +/tests/metrics/ @teddykoker @ananyahjha93 @justusschock 
+/docs/source/metrics.rst @teddykoker @ananyahjha93 @justusschock # API /pytorch_lightning/callbacks/base.py @williamfalcon diff --git a/.github/workflows/ci_dockers.yml b/.github/workflows/ci_dockers.yml index 897e16a12d44f..9f77fb76aa593 100644 --- a/.github/workflows/ci_dockers.yml +++ b/.github/workflows/ci_dockers.yml @@ -29,6 +29,9 @@ jobs: - name: Checkout uses: actions/checkout@v2 + # https://github.com/docker/setup-buildx-action + # Set up Docker Buildx - to use cache-from and cache-to argument of buildx command + - uses: docker/setup-buildx-action@v1 - name: Build PL Docker # publish master/release uses: docker/build-push-action@v2 @@ -51,6 +54,9 @@ jobs: - name: Checkout uses: actions/checkout@v2 + # https://github.com/docker/setup-buildx-action + # Set up Docker Buildx - to use cache-from and cache-to argument of buildx command + - uses: docker/setup-buildx-action@v1 - name: Build XLA Docker # publish master/release uses: docker/build-push-action@v2 @@ -87,6 +93,9 @@ jobs: echo "::set-output name=CUDA::$cuda" id: extend + # https://github.com/docker/setup-buildx-action + # Set up Docker Buildx - to use cache-from and cache-to argument of buildx command + - uses: docker/setup-buildx-action@v1 - name: Build CUDA Docker # publish master/release uses: docker/build-push-action@v2 @@ -121,6 +130,9 @@ jobs: echo "::set-output name=CUDA::$cuda" id: extend + # https://github.com/docker/setup-buildx-action + # Set up Docker Buildx - to use cache-from and cache-to argument of buildx command + - uses: docker/setup-buildx-action@v1 - name: Build CUDA Docker # publish master/release uses: docker/build-push-action@v2 @@ -138,8 +150,10 @@ jobs: steps: - name: Checkout uses: actions/checkout@v2 - - - name: Build NVIDIA Docker + # https://github.com/docker/setup-buildx-action + # Set up Docker Buildx - to use cache-from and cache-to argument of buildx command + - uses: docker/setup-buildx-action@v1 + - name: Build CUDA Docker uses: docker/build-push-action@v2 with: file: dockers/nvidia/Dockerfile diff --git a/.github/workflows/ci_test-base.yml b/.github/workflows/ci_test-base.yml index 77363992718af..ed8a2e30949b7 100644 --- a/.github/workflows/ci_test-base.yml +++ b/.github/workflows/ci_test-base.yml @@ -51,8 +51,9 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade --user pip - pip install --requirement ./requirements.txt --find-links https://download.pytorch.org/whl/cpu/torch_stable.html --upgrade - pip install "pytest>6.0" "pytest-cov>2.10" --upgrade-strategy only-if-needed + pip install --requirement ./requirements.txt --quiet --find-links https://download.pytorch.org/whl/cpu/torch_stable.html --upgrade + pip install --requirement ./requirements/test.txt --quiet --upgrade-strategy only-if-needed + # pip install tox coverage python --version pip --version pip list diff --git a/.github/workflows/ci_test-conda.yml b/.github/workflows/ci_test-conda.yml index da853bf623d1b..812d06f310812 100644 --- a/.github/workflows/ci_test-conda.yml +++ b/.github/workflows/ci_test-conda.yml @@ -44,7 +44,7 @@ jobs: - name: Tests run: | # NOTE: run coverage on tests does not propagare faler status for Win, https://github.com/nedbat/coveragepy/issues/1003 - coverage run --source pytorch_lightning -m pytest pytorch_lightning tests -v --durations=50 --junitxml=junit/test-results-${{ runner.os }}-torch${{ matrix.pytorch-version }}.xml + python -m pytest pytorch_lightning tests --cov=pytorch_lightning -v --durations=50 --junitxml=junit/test-results-${{ runner.os }}-torch${{ 
matrix.pytorch-version }}.xml shell: bash -l {0} - name: Upload pytest results diff --git a/.github/workflows/ci_test-full.yml b/.github/workflows/ci_test-full.yml index 5a3e23a37fd0b..3d3f7d11570a4 100644 --- a/.github/workflows/ci_test-full.yml +++ b/.github/workflows/ci_test-full.yml @@ -17,6 +17,10 @@ jobs: os: [ubuntu-18.04, windows-2019, macOS-10.15] python-version: [3.6, 3.7, 3.8] requires: ['minimal', 'latest'] + exclude: + # # todo: segmentation fault for minimal and hanging for latest + - python-version: 3.8 + os: ubuntu-18.04 # Timeout: https://stackoverflow.com/a/59076067/4521646 timeout-minutes: 35 # TODO: the macOS is taking too long, probably caching did not work... @@ -134,7 +138,7 @@ jobs: - name: Tests run: | # NOTE: do not include coverage report here, see: https://github.com/nedbat/coveragepy/issues/1003 - coverage run --source pytorch_lightning -m pytest pytorch_lightning tests -v --durations=50 --junitxml=junit/test-results-${{ runner.os }}-py${{ matrix.python-version }}-${{ matrix.requires }}.xml + python -m pytest pytorch_lightning tests --cov=pytorch_lightning -v --durations=50 --junitxml=junit/test-results-${{ runner.os }}-py${{ matrix.python-version }}-${{ matrix.requires }}.xml - name: Examples run: | diff --git a/.github/workflows/docs-checks.yml b/.github/workflows/docs-checks.yml index 4488c598c8ac7..5ee4f23b4b3cc 100644 --- a/.github/workflows/docs-checks.yml +++ b/.github/workflows/docs-checks.yml @@ -98,7 +98,7 @@ jobs: # First run the same pipeline as Read-The-Docs cd docs make clean - make html --debug --jobs $(nproc) SPHINXOPTS="-W --keep-going" + make html --debug --jobs $(nproc) SPHINXOPTS="-W" - name: Upload built docs uses: actions/upload-artifact@v2 diff --git a/.github/workflows/events-nightly.yml b/.github/workflows/events-nightly.yml index 5ad4396a006f7..24d8ce4002e5d 100644 --- a/.github/workflows/events-nightly.yml +++ b/.github/workflows/events-nightly.yml @@ -126,26 +126,3 @@ jobs: push: true tags: pytorchlightning/pytorch_lightning:base-conda-py${{ matrix.python_version }}-torch${{ matrix.pytorch_version }} timeout-minutes: 55 - -# docker-nvidia: -# runs-on: ubuntu-20.04 -# steps: -# - name: Checkout -# uses: actions/checkout@v2 -# -# # https://github.com/docker/setup-buildx-action -# # Set up Docker Buildx - to use cache-from and cache-to argument of buildx command -# - uses: docker/setup-buildx-action@v1 -# - name: Login to DockerHub -# uses: docker/login-action@v1 -# with: -# username: ${{ secrets.DOCKER_USERNAME }} -# password: ${{ secrets.DOCKER_PASSWORD }} -# -# - name: Publish NVIDIA to Docker Hub -# uses: docker/build-push-action@v2 -# with: -# file: dockers/nvidia/Dockerfile -# push: true -# tags: nvcr.io/pytorchlightning/pytorch_lightning:nvidia -# timeout-minutes: 55 diff --git a/.github/workflows/release-docker.yml b/.github/workflows/release-docker.yml index 36ecbe229ac7c..f285794cbc33b 100644 --- a/.github/workflows/release-docker.yml +++ b/.github/workflows/release-docker.yml @@ -8,7 +8,7 @@ on: types: [created] jobs: - cuda-PL: + build-PL: runs-on: ubuntu-20.04 strategy: fail-fast: false @@ -36,27 +36,3 @@ jobs: build_args: PYTHON_VERSION=${{ matrix.python_version }},PYTORCH_VERSION=${{ matrix.pytorch_version }},LIGHTNING_VERSION=${{ steps.get_version.outputs.RELEASE_VERSION }} tags: "${{ steps.get_version.outputs.RELEASE_VERSION }}-py${{ matrix.python_version }}-torch${{ matrix.pytorch_version }},latest-py${{ matrix.python_version }}-torch${{ matrix.pytorch_version }}" timeout-minutes: 55 - -# nvidia-PL: -# runs-on: 
ubuntu-20.04 -# steps: -# - name: Checkout -# uses: actions/checkout@v2 -# -# - name: Get release version -# if: startsWith(github.ref, 'refs/tags/') || github.event_name == 'release' -# id: get_version -# run: echo "::set-output name=RELEASE_VERSION::$(echo ${GITHUB_REF##*/})" -# -# - name: Publish Releases to Docker -# # only on releases -# uses: docker/build-push-action@v1.1.0 -# if: startsWith(github.ref, 'refs/tags/') || github.event_name == 'release' -# with: -# repository: nvcr.io/pytorchlightning/pytorch_lightning -# username: ${{ secrets.DOCKER_USERNAME }} -# password: ${{ secrets.DOCKER_PASSWORD }} -# dockerfile: dockers/nvidia/Dockerfile -# build_args: LIGHTNING_VERSION=${{ steps.get_version.outputs.RELEASE_VERSION }} -# tags: "${{ steps.get_version.outputs.RELEASE_VERSION }}-nvidia" -# timeout-minutes: 55 diff --git a/.gitignore b/.gitignore index 99939ff7fce0c..cd0ba22453512 100644 --- a/.gitignore +++ b/.gitignore @@ -157,4 +157,3 @@ tags data MNIST runs -*trace* diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 45eca43de93ac..21c52539a890d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -33,3 +33,8 @@ repos: hooks: - id: yapf args: [--parallel, --in-place] + + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v0.790 + hooks: + - id: mypy diff --git a/CHANGELOG.md b/CHANGELOG.md index 6c50e3c54e305..4139a87d9f27b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,10 +14,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added a way to print to terminal without breaking up the progress bar ([#5470](https://github.com/PyTorchLightning/pytorch-lightning/pull/5470)) - - Added support to checkpoint after training steps in `ModelCheckpoint` callback ([#6146](https://github.com/PyTorchLightning/pytorch-lightning/pull/6146)) - - Added `checkpoint` parameter to callback's `on_save_checkpoint` hook ([#6072](https://github.com/PyTorchLightning/pytorch-lightning/pull/6072)) @@ -33,45 +31,15 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Added `LightningEnvironment` for Lightning-specific DDP ([#5915](https://github.com/PyTorchLightning/pytorch-lightning/pull/5915)) -- Added `teardown()` hook to LightningDataModule ([#4673](https://github.com/PyTorchLightning/pytorch-lightning/pull/4673)) - - - Added `auto_insert_metric_name` parameter to `ModelCheckpoint` ([#6277](https://github.com/PyTorchLightning/pytorch-lightning/pull/6277)) - Added arg to `self.log` that enables users to give custom names when dealing with multiple dataloaders ([#6274](https://github.com/PyTorchLightning/pytorch-lightning/pull/6274)) -- Added `teardown` method to `BaseProfiler` to enable subclasses defining post-profiling steps outside of `__del__` ([#6370](https://github.com/PyTorchLightning/pytorch-lightning/pull/6370)) - - -- Added `setup` method to `BaseProfiler` to enable subclasses defining pre-profiling steps for every process ([#6633](https://github.com/PyTorchLightning/pytorch-lightning/pull/6633)) - - - Added no return warning to predict ([#6139](https://github.com/PyTorchLightning/pytorch-lightning/pull/6139)) -- Added `Trainer.predict` config validation ([#6543](https://github.com/PyTorchLightning/pytorch-lightning/pull/6543)) - - -- Added `AbstractProfiler` interface ([#6621](https://github.com/PyTorchLightning/pytorch-lightning/pull/6621)) - - -- Added support for including module names for forward in the autograd trace of `PyTorchProfiler` ([#6349](https://github.com/PyTorchLightning/pytorch-lightning/pull/6349)) - - -- Added support for the PyTorch 1.8.1 autograd profiler ([#6618](https://github.com/PyTorchLightning/pytorch-lightning/pull/6618)) - - -- Added `outputs` parameter to callback's `on_validation_epoch_end` & `on_test_epoch_end` hooks ([#6120](https://github.com/PyTorchLightning/pytorch-lightning/pull/6120)) - - -- Added support for `precision=64`, enabling training with double precision ([#6595](https://github.com/PyTorchLightning/pytorch-lightning/pull/6595)) - - -- Added `artifact_location` argument to `MLFlowLogger` which will be passed to the `MlflowClient.create_experiment` call ([#6677](https://github.com/PyTorchLightning/pytorch-lightning/pull/6677)) - - ### Changed - Renamed `pytorch_lightning.callbacks.swa` to `pytorch_lightning.callbacks.stochastic_weight_avg` ([#6259](https://github.com/PyTorchLightning/pytorch-lightning/pull/6259)) @@ -86,13 +54,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Changed `setup()` and `teardown()` stage argument to take any of `{fit,validate,test,predict}` ([#6386](https://github.com/PyTorchLightning/pytorch-lightning/pull/6386)) -- Changed profilers to save separate report files per state and rank ([#6621](https://github.com/PyTorchLightning/pytorch-lightning/pull/6621)) - - -- Changed `PyTorchProfiler` to use `torch.autograd.profiler.record_function` to record functions ([#6349](https://github.com/PyTorchLightning/pytorch-lightning/pull/6349)) - - -- Changed the behavior of `on_epoch_start` to run at the beginning of validation & test epoch ([#6498](https://github.com/PyTorchLightning/pytorch-lightning/pull/6498)) +- Changed the default of `find_unused_parameters` back to `True` in DDP and DDP Spawn ([#6438](https://github.com/PyTorchLightning/pytorch-lightning/pull/6438)) ### Deprecated @@ -103,27 +65,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Deprecated `trainer.running_sanity_check` in favor of `trainer.sanity_checking` ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945)) -- Deprecated `Profiler(output_filename)` in favor of `dirpath` and `filename` ([#6621](https://github.com/PyTorchLightning/pytorch-lightning/pull/6621)) - - -- Deprecated `PytorchProfiler(profiled_functions)` in favor of `record_functions` ([#6349](https://github.com/PyTorchLightning/pytorch-lightning/pull/6349)) - - -- Deprecated metrics in favor of `torchmetrics` ([#6505](https://github.com/PyTorchLightning/pytorch-lightning/pull/6505), - [#6530](https://github.com/PyTorchLightning/pytorch-lightning/pull/6530), - [#6540](https://github.com/PyTorchLightning/pytorch-lightning/pull/6540), - [#6547](https://github.com/PyTorchLightning/pytorch-lightning/pull/6547), - [#6515](https://github.com/PyTorchLightning/pytorch-lightning/pull/6515), - [#6572](https://github.com/PyTorchLightning/pytorch-lightning/pull/6572), - [#6573](https://github.com/PyTorchLightning/pytorch-lightning/pull/6573), - [#6584](https://github.com/PyTorchLightning/pytorch-lightning/pull/6584), - [#6636](https://github.com/PyTorchLightning/pytorch-lightning/pull/6636), - [#6637](https://github.com/PyTorchLightning/pytorch-lightning/pull/6637), - [#6649](https://github.com/PyTorchLightning/pytorch-lightning/pull/6649), - [#6659](https://github.com/PyTorchLightning/pytorch-lightning/pull/6659), -) - - ### Removed - Removed support for passing a bool value to `profiler` argument of Trainer ([#6164](https://github.com/PyTorchLightning/pytorch-lightning/pull/6164)) @@ -157,9 +98,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed -- Added Autocast in validation, test and predict modes for Native AMP ([#6565](https://github.com/PyTorchLightning/pytorch-lightning/pull/6565)) - - - Made the `Plugin.reduce` method more consistent across all Plugins to reflect a mean-reduction by default ([#6011](https://github.com/PyTorchLightning/pytorch-lightning/pull/6011)) @@ -175,57 +113,31 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Fixed `ModelCheckpoint(save_top_k=0, save_last=True)` not saving the `last` checkpoint ([#6136](https://github.com/PyTorchLightning/pytorch-lightning/pull/6136)) -- Fixed duplicate logs appearing in console when using the python logging module ([#5509](https://github.com/PyTorchLightning/pytorch-lightning/pull/5509), [#6275](https://github.com/PyTorchLightning/pytorch-lightning/pull/6275)) - - -- Fixed `.teardown(stage='fit')` getting called during `trainer.test` ([#6386](https://github.com/PyTorchLightning/pytorch-lightning/pull/6386)) - - -- Fixed `.on_fit_{start,end}()` getting called during `trainer.test` ([#6386](https://github.com/PyTorchLightning/pytorch-lightning/pull/6386)) - - -- Fixed LightningModule `all_gather` on cpu tensors ([#6416](https://github.com/PyTorchLightning/pytorch-lightning/pull/6416)) - - -- Fixed a bug where `all_gather` would not work correctly with `tpu_cores=8` ([#6587](https://github.com/PyTorchLightning/pytorch-lightning/pull/6587)) - - -- Update Gradient Clipping for the TPU Accelerator ([#6576](https://github.com/PyTorchLightning/pytorch-lightning/pull/6576)) +- Expose DeepSpeed loss parameters to allow users to fix loss instability ([#6115](https://github.com/PyTorchLightning/pytorch-lightning/pull/6115)) -- Fixed torch distributed not available in setup hook for DDP ([#6506](https://github.com/PyTorchLightning/pytorch-lightning/pull/6506)) +- Fixed duplicate logs appearing in console when using the python logging module ([#5509](https://github.com/PyTorchLightning/pytorch-lightning/pull/5509), [#6275](https://github.com/PyTorchLightning/pytorch-lightning/pull/6275)) -- Fixed comparing required versions ([#6434](https://github.com/PyTorchLightning/pytorch-lightning/pull/6434)) +- Fixed DP reduction with collection ([#6324](https://github.com/PyTorchLightning/pytorch-lightning/pull/6324)) -- Fixed error on TPUs when there was no `ModelCheckpoint` ([#6654](https://github.com/PyTorchLightning/pytorch-lightning/pull/6654)) +- Fixed `.teardown(stage='fit')` getting called during `trainer.test` ([#6386](https://github.com/PyTorchLightning/pytorch-lightning/pull/6386)) -- Fixed `trainer.test` freeze on TPUs ([#6654](https://github.com/PyTorchLightning/pytorch-lightning/pull/6654)) +- Fixed `.on_fit_{start,end}()` getting called during `trainer.test` ([#6386](https://github.com/PyTorchLightning/pytorch-lightning/pull/6386)) -- Fixed a bug where gradients were disabled after calling `Trainer.predict` ([#6657](https://github.com/PyTorchLightning/pytorch-lightning/pull/6657)) +- Fixed an issue where the tuner would not tune the learning rate if also tuning the batch size ([#4688](https://github.com/PyTorchLightning/pytorch-lightning/pull/4688)) -## [1.2.4] - 2021-03-16 +- Fixed logger creating directory structure too early in DDP ([#6380](https://github.com/PyTorchLightning/pytorch-lightning/pull/6380)) -### Changed -- Changed the default of `find_unused_parameters` back to `True` in DDP and DDP Spawn ([#6438](https://github.com/PyTorchLightning/pytorch-lightning/pull/6438)) +- Fixed DeepSpeed additional memory use on rank 0 when default device not set early enough ([#6460](https://github.com/PyTorchLightning/pytorch-lightning/pull/6460)) -### Fixed -- Expose DeepSpeed loss parameters to allow users to fix loss instability ([#6115](https://github.com/PyTorchLightning/pytorch-lightning/pull/6115)) -- Fixed DP reduction with collection ([#6324](https://github.com/PyTorchLightning/pytorch-lightning/pull/6324)) -- Fixed an issue where the tuner would not tune the 
learning rate if also tuning the batch size ([#4688](https://github.com/PyTorchLightning/pytorch-lightning/pull/4688)) -- Fixed broadcast to use PyTorch `broadcast_object_list` and add `reduce_decision` ([#6410](https://github.com/PyTorchLightning/pytorch-lightning/pull/6410)) -- Fixed logger creating directory structure too early in DDP ([#6380](https://github.com/PyTorchLightning/pytorch-lightning/pull/6380)) -- Fixed DeepSpeed additional memory use on rank 0 when default device not set early enough ([#6460](https://github.com/PyTorchLightning/pytorch-lightning/pull/6460)) -- Fixed `DummyLogger.log_hyperparams` raising a `TypeError` when running with `fast_dev_run=True` ([#6398](https://github.com/PyTorchLightning/pytorch-lightning/pull/6398)) -- Fixed an issue with `Tuner.scale_batch_size` not finding the batch size attribute in the datamodule ([#5968](https://github.com/PyTorchLightning/pytorch-lightning/pull/5968)) -- Fixed an exception in the layer summary when the model contains torch.jit scripted submodules ([#6511](https://github.com/PyTorchLightning/pytorch-lightning/pull/6511)) -- Fixed when Train loop config was run during `Trainer.predict` ([#6541](https://github.com/PyTorchLightning/pytorch-lightning/pull/6541)) +- Fixed LightningModule `all_gather` on cpu tensors ([#6416](https://github.com/PyTorchLightning/pytorch-lightning/pull/6416)) ## [1.2.3] - 2021-03-09 @@ -245,6 +157,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed `Trainer` not resetting `lightning_optimizers` when calling `Trainer.fit()` multiple times ([#6372](https://github.com/PyTorchLightning/pytorch-lightning/pull/6372)) +- Fixed `DummyLogger.log_hyperparams` raising a `TypeError` when running with `fast_dev_run=True` ([#6398](https://github.com/PyTorchLightning/pytorch-lightning/pull/6398)) + + ## [1.2.2] - 2021-03-02 ### Added @@ -273,6 +188,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Fixed incorrect yield logic for the amp autocast context manager ([#6080](https://github.com/PyTorchLightning/pytorch-lightning/pull/6080)) - Fixed priority of plugin/accelerator when setting distributed mode ([#6089](https://github.com/PyTorchLightning/pytorch-lightning/pull/6089)) - Fixed error message for AMP + CPU incompatibility ([#6107](https://github.com/PyTorchLightning/pytorch-lightning/pull/6107)) + + - Disabled batch transfer in DP mode ([#6093](https://github.com/PyTorchLightning/pytorch-lightning/pull/6093)) diff --git a/Makefile b/Makefile index 04b08fa2d27d1..d35e0b77f8429 100644 --- a/Makefile +++ b/Makefile @@ -29,4 +29,4 @@ test: clean docs: clean pip install --quiet -r requirements/docs.txt - python -m sphinx -b html -W --keep-going docs/source docs/build + python -m sphinx -b html -W docs/source docs/build diff --git a/azure-pipelines.yml b/azure-pipelines.yml index d88a31ae9775a..6dfddda0295fe 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -78,7 +78,7 @@ jobs: displayName: 'Get legacy checkpoints' - bash: | - python -m coverage run --source pytorch_lightning -m pytest pytorch_lightning tests -v --junitxml=$(Build.StagingDirectory)/test-results.xml --durations=50 + python -m coverage run --source pytorch_lightning -m pytest pytorch_lightning tests -v --durations=50 displayName: 'Testing: standard' - bash: | @@ -88,39 +88,19 @@ jobs: - bash: | python -m coverage report python -m coverage xml - python -m coverage html - python -m codecov --token=$(CODECOV_TOKEN) --commit=$(Build.SourceVersion) --flags=gpu,pytest --name="GPU-coverage" --env=linux,azure - ls -l + codecov --token=$(CODECOV_TOKEN) --flags=gpu,pytest --name="GPU-coverage" --env=linux,azure displayName: 'Statistics' - - task: PublishTestResults@2 - displayName: 'Publish test results' - inputs: - testResultsFiles: '$(Build.StagingDirectory)/test-results.xml' - testRunTitle: '$(Agent.OS) - $(Build.DefinitionName) - Python $(python.version)' - condition: succeededOrFailed() - - - task: PublishCodeCoverageResults@1 - displayName: 'Publish coverage report' - inputs: - codeCoverageTool: 'cobertura' - summaryFileLocation: 'coverage.xml' - reportDirectory: '$(Build.SourcesDirectory)/htmlcov' - testRunTitle: '$(Agent.OS) - $(Build.BuildNumber)[$(Agent.JobName)] - Python $(python.version)' - condition: succeededOrFailed() - - bash: | python -m pytest benchmarks -v --maxfail=2 --durations=0 displayName: 'Testing: benchmarks' - - script: | - set -e + - bash: | python -m pytest pl_examples -v --maxfail=2 --durations=0 python setup.py install --user --quiet bash pl_examples/run_ddp-example.sh - # cd pl_examples/basic_examples - # bash submit_ddp_job.sh - # bash submit_ddp2_job.sh - env: - PL_USE_MOCKED_MNIST: "1" + cd pl_examples/basic_examples + bash submit_ddp_job.sh + bash submit_ddp2_job.sh + pip uninstall -y pytorch-lightning displayName: 'Examples' diff --git a/dockers/nvidia/Dockerfile b/dockers/nvidia/Dockerfile index 4b04bc9426d4d..ea567a5306eed 100644 --- a/dockers/nvidia/Dockerfile +++ b/dockers/nvidia/Dockerfile @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -FROM nvcr.io/nvidia/pytorch:21.02-py3 +FROM nvcr.io/nvidia/pytorch:20.12-py3 MAINTAINER PyTorchLightning @@ -22,17 +22,16 @@ COPY ./ ./pytorch-lightning/ # install dependencies RUN \ + # Disable cache #conda install "pip>20.1" && \ - pip list | grep torch && \ - if [ ! 
-z "$LIGHTNING_VERSION" ] ; then \ + #pip config set global.cache-dir false && \ + if [ -z $LIGHTNING_VERSION ] ; then \ + pip install ./pytorch-lightning --no-cache-dir ; \ rm -rf pytorch-lightning ; \ - wget https://github.com/PyTorchLightning/pytorch-lightning/archive/${LIGHTNING_VERSION}.zip --progress=bar:force:noscroll ; \ - unzip ${LIGHTNING_VERSION}.zip ; \ - mv pytorch-lightning-*/ pytorch-lightning ; \ - rm *.zip ; \ - fi && \ - pip install ./pytorch-lightning["extra"] --no-cache-dir && \ - rm -rf pytorch-lightning + else \ + rm -rf pytorch-lightning ; \ + pip install https://github.com/PyTorchLightning/pytorch-lightning/archive/${LIGHTNING_VERSION}.zip --no-cache-dir ; \ + fi RUN python --version && \ pip --version && \ diff --git a/dockers/release/Dockerfile b/dockers/release/Dockerfile index 0eec1e41a5a3f..3584ee02746e3 100644 --- a/dockers/release/Dockerfile +++ b/dockers/release/Dockerfile @@ -25,6 +25,7 @@ COPY ./ ./pytorch-lightning/ # install dependencies RUN \ + # Disable cache #conda install "pip>20.1" && \ if [ ! -z "$LIGHTNING_VERSION" ] ; then \ rm -rf pytorch-lightning ; \ diff --git a/docs/source/advanced/multi_gpu.rst b/docs/source/advanced/multi_gpu.rst index 5cdb0b377f2b7..4fb90e7829fb4 100644 --- a/docs/source/advanced/multi_gpu.rst +++ b/docs/source/advanced/multi_gpu.rst @@ -267,7 +267,7 @@ Lightning allows multiple ways of training - TPUs (``tpu_cores=8|x``) (tpu or TPU pod) .. note:: - If you request multiple GPUs or nodes without setting a mode, DDP Spawn will be automatically used. + If you request multiple GPUs or nodes without setting a mode, DDP will be automatically used. For a deeper understanding of what Lightning is doing, feel free to read this `guide `_. @@ -697,48 +697,31 @@ To use DeepSpeed, you first need to install DeepSpeed using the commands below. .. code-block:: bash - pip install deepspeed + pip install deepspeed mpi4py If you run into an issue with the install or later in training, ensure that the CUDA version of the pytorch you've installed matches your locally installed CUDA (you can see which one has been recognized by running ``nvcc --version``). +Additionally if you run into any issues installing m4py, ensure you have openmpi installed using ``sudo apt install libopenmpi-dev`` or ``brew install mpich`` before running ``pip install mpi4py``. .. note:: Currently ``resume_from_checkpoint`` and manual optimization are not supported. DeepSpeed currently only supports single optimizer, single scheduler within the training loop. -DeepSpeed ZeRO Stage 2 -"""""""""""""""""""""" - -By default, we enable `DeepSpeed ZeRO Stage 2 `_, which partitions your optimizer states (Stage 1) and your gradients (Stage 2) across your GPUs to reduce memory. In most cases, this is more efficient or at parity with DDP, primarily due to the optimized custom communications written by the DeepSpeed team. -As a result, benefits can also be seen on a single GPU. Do note that the default bucket sizes allocate around ``3.6GB`` of VRAM to use during distributed communications, which can be tweaked when instantiating the plugin described in a few sections below. - -.. note:: - To use ZeRO, you must use ``precision=16``. - -.. code-block:: python - - from pytorch_lightning import Trainer - - model = MyModel() - trainer = Trainer(gpus=4, plugins='deepspeed', precision=16) - trainer.fit(model) - - -DeepSpeed ZeRO Stage 2 Offload -"""""""""""""""""""""""""""""" +ZeRO-Offload +"""""""""""" Below we show an example of running `ZeRO-Offload `_. 
ZeRO-Offload leverages the host CPU to offload optimizer memory/computation, reducing the overall memory consumption. +For even more speed benefit, they offer an optimized CPU version of ADAM to run the offloaded computation, which is faster than the standard PyTorch implementation. By default we enable ZeRO-Offload. .. note:: - To use ZeRO-Offload, you must use ``precision=16``. + To use ZeRO-Offload, you must use ``precision=16`` or set precision via `the DeepSpeed config. `_. .. code-block:: python from pytorch_lightning import Trainer - from pytorch_lightning.plugins import DeepSpeedPlugin model = MyModel() - trainer = Trainer(gpus=4, plugins=DeepSpeedPlugin(cpu_offload=True), precision=16) + trainer = Trainer(gpus=4, plugins='deepspeed', precision=16) trainer.fit(model) @@ -757,7 +740,7 @@ You can also modify the ZeRO-Offload parameters via the plugin as below. from pytorch_lightning.plugins import DeepSpeedPlugin model = MyModel() - trainer = Trainer(gpus=4, plugins=DeepSpeedPlugin(cpu_offload=True, allgather_bucket_size=5e8, reduce_bucket_size=5e8), precision=16) + trainer = Trainer(gpus=4, plugins=DeepSpeedPlugin(allgather_bucket_size=5e8, reduce_bucket_size=5e8), precision=16) trainer.fit(model) @@ -769,30 +752,11 @@ You can also modify the ZeRO-Offload parameters via the plugin as below. The plugin sets a reasonable default of ``2e8``, which should work for most low VRAM GPUs (less than ``7GB``), allocating roughly ``3.6GB`` of VRAM as buffer. Higher VRAM GPUs should aim for values around ``5e8``. -For even more speed benefit, DeepSpeed offers an optimized CPU version of ADAM called `DeepSpeedCPUAdam `_ to run the offloaded computation, which is faster than the standard PyTorch implementation. - -.. code-block:: python - - import pytorch_lightning - from pytorch_lightning import Trainer - from pytorch_lightning.plugins import DeepSpeedPlugin - from deepspeed.ops.adam import DeepSpeedCPUAdam - - class MyModel(pl.LightningModule): - ... - def configure_optimizers(self): - # DeepSpeedCPUAdam provides 5x to 7x speedup over torch.optim.adam(w) - return DeepSpeedCPUAdam(self.parameters()) - - model = MyModel() - trainer = Trainer(gpus=4, plugins=DeepSpeedPlugin(cpu_offload=True), precision=16) - trainer.fit(model) - Custom DeepSpeed Config """"""""""""""""""""""" -In some cases you may want to define your own DeepSpeed Config, to access all parameters defined. We've exposed most of the important parameters, however, there may be debugging parameters to enable. Also, DeepSpeed allows the use of custom DeepSpeed optimizers and schedulers defined within a config file that is supported. +DeepSpeed allows use of custom DeepSpeed optimizers and schedulers defined within a config file. This allows you to enable optimizers such as `1-bit Adam `_. .. note:: All plugin default parameters will be ignored when a config object is passed. diff --git a/docs/source/advanced/multiple_loaders.rst b/docs/source/advanced/multiple_loaders.rst index 1a82641953c3c..3f230957ca283 100644 --- a/docs/source/advanced/multiple_loaders.rst +++ b/docs/source/advanced/multiple_loaders.rst @@ -9,7 +9,7 @@ Multiple Datasets Lightning supports multiple dataloaders in a few ways. 1. Create a dataloader that iterates multiple datasets under the hood. -2. In the training loop you can pass multiple loaders as a dict or list/tuple and lightning +2. In the training loop you can pass multiple loaders as a dict or list/tuple and lightning will automatically combine the batches from different loaders. 3. 
In the validation and test loop you also have the option to return multiple dataloaders which lightning will call sequentially. @@ -75,13 +75,13 @@ For more details please have a look at :paramref:`~pytorch_lightning.trainer.tra loader_a = torch.utils.data.DataLoader(range(6), batch_size=4) loader_b = torch.utils.data.DataLoader(range(15), batch_size=5) - + # pass loaders as a dict. This will create batches like this: # {'a': batch from loader_a, 'b': batch from loader_b} loaders = {'a': loader_a, 'b': loader_b} - # OR: + # OR: # pass loaders as sequence. This will create batches like this: # [batch from loader_a, batch from loader_b] loaders = [loader_a, loader_b] @@ -89,24 +89,7 @@ For more details please have a look at :paramref:`~pytorch_lightning.trainer.tra return loaders Furthermore, Lightning also supports that nested lists and dicts (or a combination) can -be returned. - -.. testcode:: - - class LitModel(LightningModule): - - def train_dataloader(self): - - loader_a = torch.utils.data.DataLoader(range(8), batch_size=4) - loader_b = torch.utils.data.DataLoader(range(16), batch_size=2) - - return {'a': loader_a, 'b': loader_b} - - def training_step(self, batch, batch_idx): - # access a dictionnary with a batch from each dataloader - batch_a = batch["a"] - batch_b = batch["b"] - +be returned .. testcode:: @@ -120,29 +103,12 @@ be returned. loader_c = torch.utils.data.DataLoader(range(64), batch_size=4) # pass loaders as a nested dict. This will create batches like this: - loaders = { - 'loaders_a_b': { - 'a': loader_a, - 'b': loader_b - }, - 'loaders_c_d': { - 'c': loader_c, - 'd': loader_d - } - } + # {'loader_a_b': {'a': batch from loader a, 'b': batch from loader b}, + # 'loader_c_d': {'c': batch from loader c, 'd': batch from loader d}} + loaders = {'loaders_a_b': {'a': loader_a, 'b': loader_b}, + 'loaders_c_d': {'c': loader_c, 'd': loader_d}} return loaders - def training_step(self, batch, batch_idx): - # access the data - batch_a_b = batch["loaders_a_b"] - batch_c_d = batch["loaders_c_d"] - - batch_a = batch_a_b["a"] - batch_b = batch_a_b["a"] - - batch_c = batch_c_d["c"] - batch_d = batch_c_d["d"] - ---------- Test/Val dataloaders diff --git a/docs/source/benchmarking/performance.rst b/docs/source/benchmarking/performance.rst index dbddaad3a5e3c..d1bc2c9ebc009 100644 --- a/docs/source/benchmarking/performance.rst +++ b/docs/source/benchmarking/performance.rst @@ -181,19 +181,3 @@ Most UNIX-based operating systems provide direct access to tmpfs through a mount .. code-block:: python datamodule = MyDataModule(data_root="/dev/shm/my_data") - - -Zero Grad ``set_to_none=True`` ------------------------------- - -In order to modestly improve performance, once can override :meth:`~pytorch_lightning.core.lightning.LightningModule.optimizer_zero_grad`. - -For a more detailed explanation of pros / cons of this technique, -read `this `_ documentation by the PyTorch team. - -.. testcode:: - - class Model(LightningModule): - - def optimizer_zero_grad(self, epoch, batch_idx, optimizer, optimizer_idx): - optimizer.zero_grad(set_to_none=True) diff --git a/docs/source/common/lightning_module.rst b/docs/source/common/lightning_module.rst index 6e67f591da7c7..f6deb9adf58d3 100644 --- a/docs/source/common/lightning_module.rst +++ b/docs/source/common/lightning_module.rst @@ -1020,14 +1020,12 @@ This is the pseudocode to describe how all the hooks are called during a call to .. 
code-block:: python def fit(...): + on_fit_start() + if global_rank == 0: # prepare data is called on GLOBAL_ZERO only prepare_data() - configure_callbacks() - - on_fit_start() - for gpu/tpu in gpu/tpus: train_on_device(model.copy()) @@ -1045,7 +1043,6 @@ This is the pseudocode to describe how all the hooks are called during a call to teardown() def train_loop(): - on_epoch_start() on_train_epoch_start() train_outs = [] for train_batch in train_dataloader(): @@ -1071,15 +1068,12 @@ This is the pseudocode to describe how all the hooks are called during a call to val_loop() # end training epoch - outs = training_epoch_end(outs) - on_train_epoch_end(outs) - on_epoch_end() + logs = training_epoch_end(outs) def val_loop(): model.eval() torch.set_grad_enabled(False) - on_epoch_start() on_validation_epoch_start() val_outs = [] for val_batch in val_dataloader(): @@ -1093,7 +1087,6 @@ This is the pseudocode to describe how all the hooks are called during a call to validation_epoch_end(val_outs) on_validation_epoch_end() - on_epoch_end() # set up for train model.train() @@ -1121,12 +1114,12 @@ manual_backward on_after_backward ~~~~~~~~~~~~~~~~~ -.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_after_backward +.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_after_backward :noindex: on_before_zero_grad ~~~~~~~~~~~~~~~~~~~ -.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_before_zero_grad +.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_before_zero_grad :noindex: on_fit_start @@ -1145,38 +1138,15 @@ on_fit_end on_load_checkpoint ~~~~~~~~~~~~~~~~~~ -.. automethod:: pytorch_lightning.core.hooks.CheckpointHooks.on_load_checkpoint +.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_load_checkpoint :noindex: on_save_checkpoint ~~~~~~~~~~~~~~~~~~ -.. automethod:: pytorch_lightning.core.hooks.CheckpointHooks.on_save_checkpoint - :noindex: - -on_train_start -~~~~~~~~~~~~~~ - -.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_train_start - :noindex: - -on_train_end -~~~~~~~~~~~~ - -.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_train_end +.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_save_checkpoint :noindex: -on_validation_start -~~~~~~~~~~~~~~~~~~~ - -.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_validation_start - :noindex: - -on_validation_end -~~~~~~~~~~~~~~~~~ - -.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_validation_end - :noindex: on_pretrain_routine_start ~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -1214,11 +1184,6 @@ on_test_epoch_end .. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_test_epoch_end :noindex: -on_test_end -~~~~~~~~~~~ - -.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_test_end - :noindex: on_train_batch_start ~~~~~~~~~~~~~~~~~~~~ @@ -1232,18 +1197,6 @@ on_train_batch_end .. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_train_batch_end :noindex: -on_epoch_start -~~~~~~~~~~~~~~ - -.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_epoch_start - :noindex: - -on_epoch_end -~~~~~~~~~~~~ - -.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_epoch_end - :noindex: - on_train_epoch_start ~~~~~~~~~~~~~~~~~~~~ @@ -1280,36 +1233,6 @@ on_validation_epoch_end .. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_validation_epoch_end :noindex: -on_post_move_to_device -~~~~~~~~~~~~~~~~~~~~~~ - -.. 
automethod:: pytorch_lightning.core.hooks.ModelHooks.on_post_move_to_device - :noindex: - -on_validation_model_eval -~~~~~~~~~~~~~~~~~~~~~~~~ - -.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_validation_model_eval - :noindex: - -on_validation_model_train -~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_validation_model_train - :noindex: - -on_test_model_eval -~~~~~~~~~~~~~~~~~~ - -.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_test_model_eval - :noindex: - -on_test_model_train -~~~~~~~~~~~~~~~~~~~ - -.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_test_model_train - :noindex: - optimizer_step ~~~~~~~~~~~~~~ @@ -1331,7 +1254,7 @@ prepare_data setup ~~~~~ -.. automethod:: pytorch_lightning.core.hooks.DataHooks.setup +.. automethod:: pytorch_lightning.core.hooks.ModelHooks.setup :noindex: tbptt_split_batch @@ -1343,25 +1266,25 @@ tbptt_split_batch teardown ~~~~~~~~ -.. automethod:: pytorch_lightning.core.hooks.DataHooks.teardown +.. automethod:: pytorch_lightning.core.hooks.ModelHooks.teardown :noindex: train_dataloader ~~~~~~~~~~~~~~~~ -.. automethod:: pytorch_lightning.core.hooks.DataHooks.train_dataloader +.. automethod:: pytorch_lightning.core.lightning.LightningModule.train_dataloader :noindex: val_dataloader ~~~~~~~~~~~~~~ -.. automethod:: pytorch_lightning.core.hooks.DataHooks.val_dataloader +.. automethod:: pytorch_lightning.core.lightning.LightningModule.val_dataloader :noindex: test_dataloader ~~~~~~~~~~~~~~~ -.. automethod:: pytorch_lightning.core.hooks.DataHooks.test_dataloader +.. automethod:: pytorch_lightning.core.lightning.LightningModule.test_dataloader :noindex: transfer_batch_to_device diff --git a/docs/source/common/trainer.rst b/docs/source/common/trainer.rst index d86a8dc1ff472..5614e481e0888 100644 --- a/docs/source/common/trainer.rst +++ b/docs/source/common/trainer.rst @@ -1157,7 +1157,7 @@ precision | -Double precision (64), full precision (32) or half precision (16). +Full precision (32), half precision (16). Can be used on CPU, GPU or TPUs. If used on TPU will use torch.bfloat16 but tensor printing @@ -1172,9 +1172,6 @@ will still show torch.float32. # 16-bit precision trainer = Trainer(precision=16, gpus=1) - # 64-bit precision - trainer = Trainer(precision=64) - Example:: # one day diff --git a/docs/source/conf.py b/docs/source/conf.py index 1c1f3be8a636a..ccf824bb37d9b 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -13,6 +13,7 @@ # documentation root, use os.path.abspath to make it absolute, like shown here. 
# import m2r +import builtins import glob import os import shutil @@ -26,13 +27,10 @@ FOLDER_GENERATED = 'generated' SPHINX_MOCK_REQUIREMENTS = int(os.environ.get('SPHINX_MOCK_REQUIREMENTS', True)) +if SPHINX_MOCK_REQUIREMENTS: + builtins.__LIGHTNING_SETUP__ = True -try: - from pytorch_lightning import info -except ImportError: - # alternative https://stackoverflow.com/a/67692/4521646 - sys.path.append(os.path.join(PATH_ROOT, "pytorch_lightning")) - import info +import pytorch_lightning # noqa: E402 # -- Project documents ------------------------------------------------------- @@ -81,13 +79,13 @@ def _transform_changelog(path_in: str, path_out: str) -> None: # -- Project information ----------------------------------------------------- project = 'PyTorch Lightning' -copyright = info.__copyright__ -author = info.__author__ +copyright = pytorch_lightning.__copyright__ +author = pytorch_lightning.__author__ # The short X.Y version -version = info.__version__ +version = pytorch_lightning.__version__ # The full version, including alpha/beta/rc tags -release = info.__version__ +release = pytorch_lightning.__version__ # -- General configuration --------------------------------------------------- @@ -178,8 +176,8 @@ def _transform_changelog(path_in: str, path_out: str) -> None: # documentation. html_theme_options = { - 'pytorch_project': 'https://pytorchlightning.ai', - 'canonical_url': info.__docs_url__, + 'pytorch_project': pytorch_lightning.__homepage__, + 'canonical_url': pytorch_lightning.__homepage__, 'collapse_navigation': False, 'display_version': True, 'logo_only': False, @@ -281,7 +279,6 @@ def _transform_changelog(path_in: str, path_out: str) -> None: 'torch': ('https://pytorch.org/docs/stable/', None), 'numpy': ('https://numpy.org/doc/stable/', None), 'PIL': ('https://pillow.readthedocs.io/en/stable/', None), - 'torchmetrics': ('https://torchmetrics.readthedocs.io/en/stable/', None), } # -- Options for todo extension ---------------------------------------------- @@ -331,11 +328,9 @@ def package_list_from_file(file): 'comet-ml': 'comet_ml', 'neptune-client': 'neptune', 'hydra-core': 'hydra', - 'pyDeprecate': 'deprecate', } MOCK_PACKAGES = [] if SPHINX_MOCK_REQUIREMENTS: - MOCK_PACKAGES += ['fairscale'] # mock also base packages when we are on RTD since we don't install them there MOCK_PACKAGES += package_list_from_file(os.path.join(PATH_ROOT, 'requirements.txt')) MOCK_PACKAGES += package_list_from_file(os.path.join(PATH_ROOT, 'requirements', 'extra.txt')) diff --git a/docs/source/extensions/callbacks.rst b/docs/source/extensions/callbacks.rst index 73691c6dd76f5..63a221a06119f 100644 --- a/docs/source/extensions/callbacks.rst +++ b/docs/source/extensions/callbacks.rst @@ -349,15 +349,3 @@ on_load_checkpoint .. automethod:: pytorch_lightning.callbacks.Callback.on_load_checkpoint :noindex: - -on_after_backward -^^^^^^^^^^^^^^^^^ - -.. automethod:: pytorch_lightning.callbacks.Callback.on_after_backward - :noindex: - -on_before_zero_grad -^^^^^^^^^^^^^^^^^^^ - -.. 
automethod:: pytorch_lightning.callbacks.Callback.on_before_zero_grad - :noindex: diff --git a/docs/source/extensions/datamodules.rst b/docs/source/extensions/datamodules.rst index 881febe21316d..85134fda06fa2 100644 --- a/docs/source/extensions/datamodules.rst +++ b/docs/source/extensions/datamodules.rst @@ -94,10 +94,6 @@ The equivalent DataModule just organizes the same exact code, but makes it reusa def test_dataloader(self): return DataLoader(self.mnist_test, batch_size=self.batch_size) - def teardown(self, stage: Optional[str] = None): - # Used to clean-up when the run is finished - ... - But now, as the complexity of your processing grows (transforms, multiple-GPU training), you can let Lightning handle those details for you while making this dataset reusable so you can share with colleagues or use in different projects. @@ -247,10 +243,7 @@ There are also data operations you might want to perform on every GPU. Use setup self.dims = getattr(self, 'dims', self.mnist_test[0][0].shape) -.. warning:: ``setup`` is called from every process. Setting state here is okay. - - -.. note:: ``teardown`` can be used to clean up the state. It is also called from every process +.. warning:: `setup` is called from every process. Setting state here is okay. train_dataloader @@ -418,14 +411,10 @@ You can of course use DataModules in plain PyTorch code as well. for batch in dm.val_dataloader(): ... - dm.teardown(stage='fit') - # lazy load test data dm.setup(stage='test') for batch in dm.test_dataloader(): ... - dm.teardown(stage='test') - But overall, DataModules encourage reproducibility by allowing all details of a dataset to be specified in a unified structure. diff --git a/docs/source/extensions/logging.rst b/docs/source/extensions/logging.rst index 9ad17b5fd1821..a17d595f1fc44 100644 --- a/docs/source/extensions/logging.rst +++ b/docs/source/extensions/logging.rst @@ -90,7 +90,7 @@ The :func:`~~pytorch_lightning.core.lightning.LightningModule.log` method has a .. note:: - Setting ``on_epoch=True`` will cache all your logged values during the full training epoch and perform a - reduction in ``on_train_epoch_end``. We recommend using the :doc:`metrics <../extensions/metrics>` API when working with custom reduction. + reduction `on_epoch_end`. We recommend using the :doc:`metrics <../extensions/metrics>` API when working with custom reduction. - Setting both ``on_step=True`` and ``on_epoch=True`` will create two keys per metric you log with suffix ``_step`` and ``_epoch``, respectively. You can refer to these keys e.g. in the `monitor` diff --git a/docs/source/extensions/metrics.rst b/docs/source/extensions/metrics.rst index 74a4a15deb2be..6a64c42ec2753 100644 --- a/docs/source/extensions/metrics.rst +++ b/docs/source/extensions/metrics.rst @@ -1,9 +1,887 @@ +.. testsetup:: * + + import torch + from torch.nn import Module + from pytorch_lightning.core.lightning import LightningModule + from pytorch_lightning.metrics import Metric + +.. _metrics: + ####### Metrics ####### -``pytorch_lightning.metrics`` has been moved to a separate package `TorchMetrics `_. -We will preserve compatibility for the next few releases, nevertheless, we encourage users to update to use this stand-alone package. +``pytorch_lightning.metrics`` is a Metrics API created for easy metric development and usage in +PyTorch and PyTorch Lightning. It is rigorously tested for all edge cases and includes a growing list of +common metric implementations. 
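+
+As a quick orientation, a minimal sketch of the typical cycle (``Accuracy`` here stands in for
+any metric in the package): call ``update()`` per batch, aggregate with ``compute()``, and clear
+the accumulated state with ``reset()``.
+
+.. code-block:: python
+
+    import torch
+    from pytorch_lightning.metrics import Accuracy
+
+    accuracy = Accuracy()
+
+    preds = torch.tensor([0, 1, 1, 0])
+    target = torch.tensor([0, 1, 0, 0])
+
+    accuracy.update(preds, target)   # accumulate state for this batch
+    print(accuracy.compute())        # aggregate over all batches seen so far
+    accuracy.reset()                 # clear state before the next epoch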
+ +The metrics API provides ``update()``, ``compute()``, ``reset()`` functions to the user. The metric base class inherits +``nn.Module`` which allows us to call ``metric(...)`` directly. The ``forward()`` method of the base ``Metric`` class +serves the dual purpose of calling ``update()`` on its input and simultaneously returning the value of the metric over the +provided input. .. warning:: - ``pytorch_lightning.metrics`` is deprecated from v1.3 and will be removed in v1.5. + From v1.2 onward ``compute()`` will no longer automatically call ``reset()``, + and it is up to the user to reset metrics between epochs, except in the case where the + metric is directly passed to ``LightningModule``'s ``self.log``. + +These metrics work with DDP in PyTorch and PyTorch Lightning by default. When ``.compute()`` is called in +distributed mode, the internal state of each metric is synced and reduced across each process, so that the +logic present in ``.compute()`` is applied to state information from all processes. + +The example below shows how to use a metric in your ``LightningModule``: + +.. code-block:: python + + def __init__(self): + ... + self.accuracy = pl.metrics.Accuracy() + + def training_step(self, batch, batch_idx): + x, y = batch + preds = self(x) + ... + # log step metric + self.log('train_acc_step', self.accuracy(preds, y)) + ... + + def training_epoch_end(self, outs): + # log epoch metric + self.log('train_acc_epoch', self.accuracy.compute()) + + +``Metric`` objects can also be directly logged, in which case Lightning will log +the metric based on ``on_step`` and ``on_epoch`` flags present in ``self.log(...)``. +If ``on_epoch`` is True, the logger automatically logs the end of epoch metric value by calling +``.compute()``. + +.. note:: + ``sync_dist``, ``sync_dist_op``, ``sync_dist_group``, ``reduce_fx`` and ``tbptt_reduce_fx`` + flags from ``self.log(...)`` don't affect the metric logging in any manner. The metric class + contains its own distributed synchronization logic. + + This however is only true for metrics that inherit the base class ``Metric``, + and thus the functional metric API provides no support for in-built distributed synchronization + or reduction functions. + + +.. code-block:: python + + def __init__(self): + ... + self.train_acc = pl.metrics.Accuracy() + self.valid_acc = pl.metrics.Accuracy() + + def training_step(self, batch, batch_idx): + x, y = batch + preds = self(x) + ... + self.train_acc(preds, y) + self.log('train_acc', self.train_acc, on_step=True, on_epoch=False) + + def validation_step(self, batch, batch_idx): + logits = self(x) + ... + self.valid_acc(logits, y) + self.log('valid_acc', self.valid_acc, on_step=True, on_epoch=True) + +.. note:: + + If using metrics in data parallel mode (dp), the metric update/logging should be done + in the ``_step_end`` method (where ```` is either ``training``, ``validation`` + or ``test``). This is due to metric states else being destroyed after each forward pass, + leading to wrong accumulation. In practice do the following: + + .. code-block:: python + + def training_step(self, batch, batch_idx): + data, target = batch + preds = self(data) + ... + return {'loss' : loss, 'preds' : preds, 'target' : target} + + def training_step_end(self, outputs): + #update and log + self.metric(outputs['preds'], outputs['target']) + self.log('metric', self.metric) + +This metrics API is independent of PyTorch Lightning. Metrics can directly be used in PyTorch as shown in the example: + +.. 
code-block:: python + + from pytorch_lightning import metrics + + train_accuracy = metrics.Accuracy() + valid_accuracy = metrics.Accuracy(compute_on_step=False) + + for epoch in range(epochs): + for x, y in train_data: + y_hat = model(x) + + # training step accuracy + batch_acc = train_accuracy(y_hat, y) + + for x, y in valid_data: + y_hat = model(x) + valid_accuracy(y_hat, y) + + # total accuracy over all training batches + total_train_accuracy = train_accuracy.compute() + + # total accuracy over all validation batches + total_valid_accuracy = valid_accuracy.compute() + +.. note:: + + Metrics contain internal states that keep track of the data seen so far. + Do not mix metric states across training, validation and testing. + It is highly recommended to re-initialize the metric per mode as + shown in the examples above. For easy initializing the same metric multiple + times, the ``.clone()`` method can be used: + + .. testcode:: + + from pytorch_lightning.metrics import Accuracy + + def __init__(self): + ... + metric = Accuracy() + self.train_acc = metric.clone() + self.val_acc = metric.clone() + self.test_acc = metric.clone() + +.. note:: + + Metric states are **not** added to the models ``state_dict`` by default. + To change this, after initializing the metric, the method ``.persistent(mode)`` can + be used to enable (``mode=True``) or disable (``mode=False``) this behaviour. + +******************* +Metrics and devices +******************* + +Metrics are simple subclasses of :class:`~torch.nn.Module` and their metric states behave +similar to buffers and parameters of modules. This means that metrics states should +be moved to the same device as the input of the metric: + +.. code-block:: python + + from pytorch_lightning.metrics import Accuracy + + target = torch.tensor([1, 1, 0, 0], device=torch.device("cuda", 0)) + preds = torch.tensor([0, 1, 0, 0], device=torch.device("cuda", 0)) + + # Metric states are always initialized on cpu, and needs to be moved to + # the correct device + confmat = Accuracy(num_classes=2).to(torch.device("cuda", 0)) + out = confmat(preds, target) + print(out.device) # cuda:0 + +However, when **properly defined** inside a :class:`~pytorch_lightning.core.lightning.LightningModule` +, Lightning will automatically move the metrics to the same device as the data. Being +**properly defined** means that the metric is correctly identified as a child module of the +model (check ``.children()`` attribute of the model). Therefore, metrics cannot be placed +in native python ``list`` and ``dict``, as they will not be correctly identified +as child modules. Instead of ``list`` use :class:`~torch.nn.ModuleList` and instead of +``dict`` use :class:`~torch.nn.ModuleDict`. + +.. testcode:: + + from pytorch_lightning.metrics import Accuracy + + class MyModule(LightningModule): + def __init__(self): + ... + # valid ways metrics will be identified as child modules + self.metric1 = Accuracy() + self.metric2 = nn.ModuleList(Accuracy()) + self.metric3 = nn.ModuleDict({'accuracy': Accuracy()}) + + def training_step(self, batch, batch_idx): + # all metrics will be on the same device as the input batch + data, target = batch + preds = self(data) + ... 
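+            # metric2 and metric3 live inside ModuleList / ModuleDict containers,
+            # hence the explicit indexing when they are called below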
+ val1 = self.metric1(preds, target) + val2 = self.metric2[0](preds, target) + val3 = self.metric3['accuracy'](preds, target) + + +********************* +Implementing a Metric +********************* + +To implement your custom metric, subclass the base ``Metric`` class and implement the following methods: + +- ``__init__()``: Each state variable should be called using ``self.add_state(...)``. +- ``update()``: Any code needed to update the state given any inputs to the metric. +- ``compute()``: Computes a final value from the state of the metric. + +All you need to do is call ``add_state`` correctly to implement a custom metric with DDP. +``reset()`` is called on metric state variables added using ``add_state()``. + +To see how metric states are synchronized across distributed processes, refer to ``add_state()`` docs +from the base ``Metric`` class. + +Example implementation: + +.. testcode:: + + from pytorch_lightning.metrics import Metric + + class MyAccuracy(Metric): + def __init__(self, dist_sync_on_step=False): + super().__init__(dist_sync_on_step=dist_sync_on_step) + + self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum") + self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") + + def update(self, preds: torch.Tensor, target: torch.Tensor): + preds, target = self._input_format(preds, target) + assert preds.shape == target.shape + + self.correct += torch.sum(preds == target) + self.total += target.numel() + + def compute(self): + return self.correct.float() / self.total + +Metrics support backpropagation, if all computations involved in the metric calculation +are differentiable. However, note that the cached state is detached from the computational +graph and cannot be backpropagated. Not doing this would mean storing the computational +graph for each update call, which can lead to out-of-memory errors. +In practise this means that: + +.. code-block:: python + + metric = MyMetric() + val = metric(pred, target) # this value can be backpropagated + val = metric.compute() # this value cannot be backpropagated + + +Metric API +---------- + +.. autoclass:: pytorch_lightning.metrics.Metric + :noindex: + +Internal implementation details +------------------------------- + +This section briefly describe how metrics work internally. We encourage looking at the source code for more info. +Internally, Lightning wraps the user defined ``update()`` and ``compute()`` method. We do this to automatically +synchronize and reduce metric states across multiple devices. More precisely, calling ``update()`` does the +following internally: + +1. Clears computed cache +2. Calls user-defined ``update()`` + +Simiarly, calling ``compute()`` does the following internally + +1. Syncs metric states between processes +2. Reduce gathered metric states +3. Calls the user defined ``compute()`` method on the gathered metric states +4. Cache computed result + +From a user's standpoint this has one important side-effect: computed results are cached. This means that no +matter how many times ``compute`` is called after one and another, it will continue to return the same result. +The cache is first emptied on the next call to ``update``. + +``forward`` serves the dual purpose of both returning the metric on the current data and updating the internal +metric state for accumulating over multiple batches. The ``forward()`` method achives this by combining calls +to ``update`` and ``compute`` in the following way (assuming metric is initialized with ``compute_on_step=True``): + +1. 
Calls ``update()`` to update the global metric states (for accumulation over multiple batches) +2. Caches the global state +3. Calls ``reset()`` to clear global metric state +4. Calls ``update()`` to update local metric state +5. Calls ``compute()`` to calculate metric for current batch +6. Restores the global state + +This procedure has the consequence of calling the user defined ``update`` **twice** during a single +forward call (one to update global statistics and one for getting the batch statistics). + + +****************** +Metric Arithmetics +****************** + +Metrics support most of python built-in operators for arithmetic, logic and bitwise operations. + +For example for a metric that should return the sum of two different metrics, implementing a new metric is an overhead that is not necessary. +It can now be done with: + +.. code-block:: python + + first_metric = MyFirstMetric() + second_metric = MySecondMetric() + + new_metric = first_metric + second_metric + +``new_metric.update(*args, **kwargs)`` now calls update of ``first_metric`` and ``second_metric``. It forwards all positional arguments but +forwards only the keyword arguments that are available in respective metric's update declaration. + +Similarly ``new_metric.compute()`` now calls compute of ``first_metric`` and ``second_metric`` and adds the results up. + +This pattern is implemented for the following operators (with ``a`` being metrics and ``b`` being metrics, tensors, integer or floats): + +* Addition (``a + b``) +* Bitwise AND (``a & b``) +* Equality (``a == b``) +* Floordivision (``a // b``) +* Greater Equal (``a >= b``) +* Greater (``a > b``) +* Less Equal (``a <= b``) +* Less (``a < b``) +* Matrix Multiplication (``a @ b``) +* Modulo (``a % b``) +* Multiplication (``a * b``) +* Inequality (``a != b``) +* Bitwise OR (``a | b``) +* Power (``a ** b``) +* Substraction (``a - b``) +* True Division (``a / b``) +* Bitwise XOR (``a ^ b``) +* Absolute Value (``abs(a)``) +* Inversion (``~a``) +* Negative Value (``neg(a)``) +* Positive Value (``pos(a)``) + +**************** +MetricCollection +**************** + +In many cases it is beneficial to evaluate the model output by multiple metrics. +In this case the `MetricCollection` class may come in handy. It accepts a sequence +of metrics and wraps theses into a single callable metric class, with the same +interface as any other metric. + +Example: + +.. testcode:: + + from pytorch_lightning.metrics import MetricCollection, Accuracy, Precision, Recall + target = torch.tensor([0, 2, 0, 2, 0, 1, 0, 2]) + preds = torch.tensor([2, 1, 2, 0, 1, 2, 2, 2]) + metric_collection = MetricCollection([ + Accuracy(), + Precision(num_classes=3, average='macro'), + Recall(num_classes=3, average='macro') + ]) + print(metric_collection(preds, target)) + +.. testoutput:: + :options: +NORMALIZE_WHITESPACE + + {'Accuracy': tensor(0.1250), + 'Precision': tensor(0.0667), + 'Recall': tensor(0.1111)} + +Similarly it can also reduce the amount of code required to log multiple metrics +inside your LightningModule + +.. code-block:: python + + def __init__(self): + ... + metrics = pl.metrics.MetricCollection(...) + self.train_metrics = metrics.clone() + self.valid_metrics = metrics.clone() + + def training_step(self, batch, batch_idx): + logits = self(x) + ... + self.train_metrics(logits, y) + # use log_dict instead of log + self.log_dict(self.train_metrics, on_step=True, on_epoch=False, prefix='train') + + def validation_step(self, batch, batch_idx): + logits = self(x) + ... 
+ self.valid_metrics(logits, y) + # use log_dict instead of log + self.log_dict(self.valid_metrics, on_step=True, on_epoch=True, prefix='val') + +.. note:: + + `MetricCollection` as default assumes that all the metrics in the collection + have the same call signature. If this is not the case, input that should be + given to different metrics can given as keyword arguments to the collection. + +.. autoclass:: pytorch_lightning.metrics.MetricCollection + :noindex: + + +*************************** +Class vs Functional Metrics +*************************** + +The functional metrics follow the simple paradigm input in, output out. This means, they don't provide any advanced mechanisms for syncing across DDP nodes or aggregation over batches. They simply compute the metric value based on the given inputs. + +Also, the integration within other parts of PyTorch Lightning will never be as tight as with the class-based interface. +If you look for just computing the values, the functional metrics are the way to go. However, if you are looking for the best integration and user experience, please consider also using the class interface. + +********************** +Classification Metrics +********************** + +Input types +----------- + +For the purposes of classification metrics, inputs (predictions and targets) are split +into these categories (``N`` stands for the batch size and ``C`` for number of classes): + +.. csv-table:: \*dtype ``binary`` means integers that are either 0 or 1 + :header: "Type", "preds shape", "preds dtype", "target shape", "target dtype" + :widths: 20, 10, 10, 10, 10 + + "Binary", "(N,)", "``float``", "(N,)", "``binary``\*" + "Multi-class", "(N,)", "``int``", "(N,)", "``int``" + "Multi-class with probabilities", "(N, C)", "``float``", "(N,)", "``int``" + "Multi-label", "(N, ...)", "``float``", "(N, ...)", "``binary``\*" + "Multi-dimensional multi-class", "(N, ...)", "``int``", "(N, ...)", "``int``" + "Multi-dimensional multi-class with probabilities", "(N, C, ...)", "``float``", "(N, ...)", "``int``" + +.. note:: + All dimensions of size 1 (except ``N``) are "squeezed out" at the beginning, so + that, for example, a tensor of shape ``(N, 1)`` is treated as ``(N, )``. + +When predictions or targets are integers, it is assumed that class labels start at 0, i.e. +the possible class labels are 0, 1, 2, 3, etc. Below are some examples of different input types + +.. testcode:: + + # Binary inputs + binary_preds = torch.tensor([0.6, 0.1, 0.9]) + binary_target = torch.tensor([1, 0, 2]) + + # Multi-class inputs + mc_preds = torch.tensor([0, 2, 1]) + mc_target = torch.tensor([0, 1, 2]) + + # Multi-class inputs with probabilities + mc_preds_probs = torch.tensor([[0.8, 0.2, 0], [0.1, 0.2, 0.7], [0.3, 0.6, 0.1]]) + mc_target_probs = torch.tensor([0, 1, 2]) + + # Multi-label inputs + ml_preds = torch.tensor([[0.2, 0.8, 0.9], [0.5, 0.6, 0.1], [0.3, 0.1, 0.1]]) + ml_target = torch.tensor([[0, 1, 1], [1, 0, 0], [0, 0, 0]]) + + +Using the is_multiclass parameter +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +In some cases, you might have inputs which appear to be (multi-dimensional) multi-class +but are actually binary/multi-label - for example, if both predictions and targets are +integer (binary) tensors. Or it could be the other way around, you want to treat +binary/multi-label inputs as 2-class (multi-dimensional) multi-class inputs. + +For these cases, the metrics where this distinction would make a difference, expose the +``is_multiclass`` argument. 
Let's see how this is used on the example of +:class:`~pytorch_lightning.metrics.StatScores` metric. + +First, let's consider the case with label predictions with 2 classes, which we want to +treat as binary. + +.. testcode:: + + from pytorch_lightning.metrics.functional import stat_scores + + # These inputs are supposed to be binary, but appear as multi-class + preds = torch.tensor([0, 1, 0]) + target = torch.tensor([1, 1, 0]) + +As you can see below, by default the inputs are treated +as multi-class. We can set ``is_multiclass=False`` to treat the inputs as binary - +which is the same as converting the predictions to float beforehand. + +.. doctest:: + + >>> stat_scores(preds, target, reduce='macro', num_classes=2) + tensor([[1, 1, 1, 0, 1], + [1, 0, 1, 1, 2]]) + >>> stat_scores(preds, target, reduce='macro', num_classes=1, is_multiclass=False) + tensor([[1, 0, 1, 1, 2]]) + >>> stat_scores(preds.float(), target, reduce='macro', num_classes=1) + tensor([[1, 0, 1, 1, 2]]) + +Next, consider the opposite example: inputs are binary (as predictions are probabilities), +but we would like to treat them as 2-class multi-class, to obtain the metric for both classes. + +.. testcode:: + + preds = torch.tensor([0.2, 0.7, 0.3]) + target = torch.tensor([1, 1, 0]) + +In this case we can set ``is_multiclass=True``, to treat the inputs as multi-class. + +.. doctest:: + + >>> stat_scores(preds, target, reduce='macro', num_classes=1) + tensor([[1, 0, 1, 1, 2]]) + >>> stat_scores(preds, target, reduce='macro', num_classes=2, is_multiclass=True) + tensor([[1, 1, 1, 0, 1], + [1, 0, 1, 1, 2]]) + + +Class Metrics (Classification) +------------------------------ + +Accuracy +~~~~~~~~ + +.. autoclass:: pytorch_lightning.metrics.Accuracy + :noindex: + +AveragePrecision +~~~~~~~~~~~~~~~~ + +.. autoclass:: pytorch_lightning.metrics.AveragePrecision + :noindex: + +AUC +~~~ + +.. autoclass:: pytorch_lightning.metrics.AUC + :noindex: + +AUROC +~~~~~ + +.. autoclass:: pytorch_lightning.metrics.AUROC + :noindex: + +ConfusionMatrix +~~~~~~~~~~~~~~~ + +.. autoclass:: pytorch_lightning.metrics.ConfusionMatrix + :noindex: + +F1 +~~ + +.. autoclass:: pytorch_lightning.metrics.F1 + :noindex: + +FBeta +~~~~~ + +.. autoclass:: pytorch_lightning.metrics.FBeta + :noindex: + +IoU +~~~ + +.. autoclass:: pytorch_lightning.metrics.IoU + :noindex: + +Hamming Distance +~~~~~~~~~~~~~~~~ + +.. autoclass:: pytorch_lightning.metrics.HammingDistance + :noindex: + +Precision +~~~~~~~~~ + +.. autoclass:: pytorch_lightning.metrics.Precision + :noindex: + +PrecisionRecallCurve +~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: pytorch_lightning.metrics.PrecisionRecallCurve + :noindex: + +Recall +~~~~~~ + +.. autoclass:: pytorch_lightning.metrics.Recall + :noindex: + +ROC +~~~ + +.. autoclass:: pytorch_lightning.metrics.ROC + :noindex: + + +StatScores +~~~~~~~~~~ + +.. autoclass:: pytorch_lightning.metrics.StatScores + :noindex: + + +Functional Metrics (Classification) +----------------------------------- + +accuracy [func] +~~~~~~~~~~~~~~~ + +.. autofunction:: pytorch_lightning.metrics.functional.accuracy + :noindex: + + +auc [func] +~~~~~~~~~~ + +.. autofunction:: pytorch_lightning.metrics.functional.auc + :noindex: + + +auroc [func] +~~~~~~~~~~~~ + +.. autofunction:: pytorch_lightning.metrics.functional.auroc + :noindex: + + +average_precision [func] +~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autofunction:: pytorch_lightning.metrics.functional.average_precision + :noindex: + + +confusion_matrix [func] +~~~~~~~~~~~~~~~~~~~~~~~ + +.. 
autofunction:: pytorch_lightning.metrics.functional.confusion_matrix + :noindex: + + +dice_score [func] +~~~~~~~~~~~~~~~~~ + +.. autofunction:: pytorch_lightning.metrics.functional.dice_score + :noindex: + + +f1 [func] +~~~~~~~~~~~~~~~ + +.. autofunction:: pytorch_lightning.metrics.functional.f1 + :noindex: + + +fbeta [func] +~~~~~~~~~~~~~~~~~~ + +.. autofunction:: pytorch_lightning.metrics.functional.fbeta + :noindex: + +hamming_distance [func] +~~~~~~~~~~~~~~~~~~~~~~~ + +.. autofunction:: pytorch_lightning.metrics.functional.hamming_distance + :noindex: + +iou [func] +~~~~~~~~~~ + +.. autofunction:: pytorch_lightning.metrics.functional.iou + :noindex: + + +roc [func] +~~~~~~~~~~~~~~~~~~~~~ + +.. autofunction:: pytorch_lightning.metrics.functional.roc + :noindex: + + +precision [func] +~~~~~~~~~~~~~~~~ + +.. autofunction:: pytorch_lightning.metrics.functional.precision + :noindex: + + +precision_recall [func] +~~~~~~~~~~~~~~~~~~~~~~~ + +.. autofunction:: pytorch_lightning.metrics.functional.precision_recall + :noindex: + + +precision_recall_curve [func] +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autofunction:: pytorch_lightning.metrics.functional.precision_recall_curve + :noindex: + + +recall [func] +~~~~~~~~~~~~~ + +.. autofunction:: pytorch_lightning.metrics.functional.recall + :noindex: + +select_topk [func] +~~~~~~~~~~~~~~~~~~~~~ + +.. autofunction:: pytorch_lightning.metrics.utils.select_topk + :noindex: + + +stat_scores [func] +~~~~~~~~~~~~~~~~~~ + +.. autofunction:: pytorch_lightning.metrics.functional.stat_scores + :noindex: + + +stat_scores_multiple_classes [func] +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autofunction:: pytorch_lightning.metrics.functional.stat_scores_multiple_classes + :noindex: + + +to_categorical [func] +~~~~~~~~~~~~~~~~~~~~~ + +.. autofunction:: pytorch_lightning.metrics.utils.to_categorical + :noindex: + + +to_onehot [func] +~~~~~~~~~~~~~~~~ + +.. autofunction:: pytorch_lightning.metrics.utils.to_onehot + :noindex: + +****************** +Regression Metrics +****************** + +Class Metrics (Regression) +-------------------------- + +ExplainedVariance +~~~~~~~~~~~~~~~~~ + +.. autoclass:: pytorch_lightning.metrics.ExplainedVariance + :noindex: + + +MeanAbsoluteError +~~~~~~~~~~~~~~~~~ + +.. autoclass:: pytorch_lightning.metrics.MeanAbsoluteError + :noindex: + + +MeanSquaredError +~~~~~~~~~~~~~~~~ + +.. autoclass:: pytorch_lightning.metrics.MeanSquaredError + :noindex: + + +MeanSquaredLogError +~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: pytorch_lightning.metrics.MeanSquaredLogError + :noindex: + + +PSNR +~~~~ + +.. autoclass:: pytorch_lightning.metrics.PSNR + :noindex: + + +SSIM +~~~~ + +.. autoclass:: pytorch_lightning.metrics.SSIM + :noindex: + + +R2Score +~~~~~~~ + +.. autoclass:: pytorch_lightning.metrics.R2Score + :noindex: + +Functional Metrics (Regression) +------------------------------- + +explained_variance [func] +~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autofunction:: pytorch_lightning.metrics.functional.explained_variance + :noindex: + + +image_gradients [func] +~~~~~~~~~~~~~~~~~~~~~~ + +.. autofunction:: pytorch_lightning.metrics.functional.image_gradients + :noindex: + + +mean_absolute_error [func] +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autofunction:: pytorch_lightning.metrics.functional.mean_absolute_error + :noindex: + + +mean_squared_error [func] +~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autofunction:: pytorch_lightning.metrics.functional.mean_squared_error + :noindex: + + +mean_squared_log_error [func] +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. 
autofunction:: pytorch_lightning.metrics.functional.mean_squared_log_error + :noindex: + + +psnr [func] +~~~~~~~~~~~ + +.. autofunction:: pytorch_lightning.metrics.functional.psnr + :noindex: + + +ssim [func] +~~~~~~~~~~~ + +.. autofunction:: pytorch_lightning.metrics.functional.ssim + :noindex: + + +r2score [func] +~~~~~~~~~~~~~~ + +.. autofunction:: pytorch_lightning.metrics.functional.r2score + :noindex: + + +*** +NLP +*** + +bleu_score [func] +----------------- + +.. autofunction:: pytorch_lightning.metrics.functional.bleu_score + :noindex: + +******** +Pairwise +******** + +embedding_similarity [func] +--------------------------- + +.. autofunction:: pytorch_lightning.metrics.functional.embedding_similarity + :noindex: diff --git a/docs/source/starter/introduction_guide.rst b/docs/source/starter/introduction_guide.rst index 551b8182caa7d..c65894367a39e 100644 --- a/docs/source/starter/introduction_guide.rst +++ b/docs/source/starter/introduction_guide.rst @@ -882,8 +882,8 @@ Or maybe we have a model that we use to do generation generated_imgs = model(z) -To perform inference at scale, it is possible to use ``trainer.predict`` with LightningModule ``predict_step`` function -By default, LightningModule ``predict_step`` calls forward, but it can be overriden to add any processing logic. +To perform inference at scale, it is possible to use ``trainer.predict`` with LightningModule ``predict`` function +By default, LightningModule ``predict`` calls forward, but it can be overriden to add any processing logic. .. code-block:: python @@ -893,7 +893,7 @@ By default, LightningModule ``predict_step`` calls forward, but it can be overri imgs = self.decoder(z) return imgs - def predict_step(self, batch, batch_idx: int , dataloader_idx: int = None): + def predict(self, batch, batch_idx: int , dataloader_idx: int = None): return self(batch) diff --git a/docs/source/starter/new-project.rst b/docs/source/starter/new-project.rst index 7a1164b1bdf3a..f68865f3695c3 100644 --- a/docs/source/starter/new-project.rst +++ b/docs/source/starter/new-project.rst @@ -83,7 +83,7 @@ Step 1: Define LightningModule .. 
testcode:: - class LitAutoEncoder(pl.LightningModule): + class LitAutoEncoder(LightningModule): def __init__(self): super().__init__() diff --git a/pl_examples/__init__.py b/pl_examples/__init__.py index 150ac309ddceb..ffd60f9ed71af 100644 --- a/pl_examples/__init__.py +++ b/pl_examples/__init__.py @@ -15,10 +15,10 @@ _DATASETS_PATH = os.path.join(_PACKAGE_ROOT, 'Datasets') _TORCHVISION_AVAILABLE = _module_available("torchvision") -_TORCHVISION_MNIST_AVAILABLE = not bool(os.environ.get("PL_USE_MOCKED_MNIST", False)) +_TORCHVISION_MNIST_AVAILABLE = True _DALI_AVAILABLE = _module_available("nvidia.dali") -if _TORCHVISION_MNIST_AVAILABLE: +if _TORCHVISION_AVAILABLE: try: from torchvision.datasets.mnist import MNIST MNIST(_DATASETS_PATH, download=True) diff --git a/pl_examples/basic_examples/autoencoder.py b/pl_examples/basic_examples/autoencoder.py index 6841b8555ef1f..a2010a89f4461 100644 --- a/pl_examples/basic_examples/autoencoder.py +++ b/pl_examples/basic_examples/autoencoder.py @@ -39,17 +39,17 @@ class LitAutoEncoder(pl.LightningModule): ) """ - def __init__(self, hidden_dim: int = 64): + def __init__(self): super().__init__() self.encoder = nn.Sequential( - nn.Linear(28 * 28, hidden_dim), + nn.Linear(28 * 28, 64), nn.ReLU(), - nn.Linear(hidden_dim, 3), + nn.Linear(64, 3), ) self.decoder = nn.Sequential( - nn.Linear(3, hidden_dim), + nn.Linear(3, 64), nn.ReLU(), - nn.Linear(hidden_dim, 28 * 28), + nn.Linear(64, 28 * 28), ) def forward(self, x): @@ -94,7 +94,7 @@ def cli_main(): # ------------ parser = ArgumentParser() parser.add_argument('--batch_size', default=32, type=int) - parser.add_argument('--hidden_dim', type=int, default=64) + parser.add_argument('--hidden_dim', type=int, default=128) parser = pl.Trainer.add_argparse_args(parser) args = parser.parse_args() @@ -112,7 +112,7 @@ def cli_main(): # ------------ # model # ------------ - model = LitAutoEncoder(args.hidden_dim) + model = LitAutoEncoder() # ------------ # training diff --git a/pl_examples/basic_examples/conv_sequential_example.py b/pl_examples/basic_examples/conv_sequential_example.py index f3d9469144f50..1c35c69d29f37 100644 --- a/pl_examples/basic_examples/conv_sequential_example.py +++ b/pl_examples/basic_examples/conv_sequential_example.py @@ -27,11 +27,11 @@ import torch.nn as nn import torch.nn.functional as F import torchvision -from torchmetrics.functional import accuracy import pytorch_lightning as pl from pl_examples import cli_lightning_logo from pytorch_lightning import Trainer +from pytorch_lightning.metrics.functional import accuracy from pytorch_lightning.plugins import RPCSequentialPlugin from pytorch_lightning.utilities import _BOLTS_AVAILABLE, _FAIRSCALE_PIPE_AVAILABLE diff --git a/pl_examples/basic_examples/profiler_example.py b/pl_examples/basic_examples/profiler_example.py deleted file mode 100644 index ca640a96f9588..0000000000000 --- a/pl_examples/basic_examples/profiler_example.py +++ /dev/null @@ -1,102 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-""" -This script will generate 2 traces: one for `training_step` and one for `validation_step`. -The traces can be visualized in 2 ways: -* With Chrome: - 1. Open Chrome and copy/paste this url: `chrome://tracing/`. - 2. Once tracing opens, click on `Load` at the top-right and load one of the generated traces. -* With PyTorch Tensorboard Profiler (Instructions are here: https://github.com/pytorch/kineto/tree/master/tb_plugin) - 1. pip install tensorboard torch-tb-profiler - 2. tensorboard --logdir={FOLDER} -""" - -import sys -from argparse import ArgumentParser - -import torch -import torchvision -import torchvision.models as models -import torchvision.transforms as T - -from pl_examples import cli_lightning_logo -from pytorch_lightning import LightningDataModule, LightningModule, Trainer - -DEFAULT_CMD_LINE = ( - "--max_epochs", - "1", - "--limit_train_batches", - "15", - "--limit_val_batches", - "15", - "--profiler", - "pytorch", - "--gpus", - f"{int(torch.cuda.is_available())}", -) - - -class ModelToProfile(LightningModule): - - def __init__(self, model): - super().__init__() - self.model = model - self.criterion = torch.nn.CrossEntropyLoss() - - def training_step(self, batch, batch_idx): - inputs, labels = batch - outputs = self.model(inputs) - loss = self.criterion(outputs, labels) - self.log("train_loss", loss) - return loss - - def validation_step(self, batch, batch_idx): - inputs, labels = batch - outputs = self.model(inputs) - loss = self.criterion(outputs, labels) - self.log("val_loss", loss) - - def configure_optimizers(self): - return torch.optim.SGD(self.parameters(), lr=0.001, momentum=0.9) - - -class CIFAR10DataModule(LightningDataModule): - - transform = T.Compose([T.Resize(256), T.CenterCrop(224), T.ToTensor()]) - - def train_dataloader(self, *args, **kwargs): - trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=self.transform) - return torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True, num_workers=0) - - def val_dataloader(self, *args, **kwargs): - valset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=self.transform) - return torch.utils.data.DataLoader(valset, batch_size=32, shuffle=True, num_workers=0) - - -def cli_main(): - - parser = ArgumentParser() - parser = Trainer.add_argparse_args(parser) - cmd_line = None if len(sys.argv) != 1 else DEFAULT_CMD_LINE - args = parser.parse_args(args=cmd_line) - - model = ModelToProfile(models.resnet50(pretrained=True)) - datamodule = CIFAR10DataModule() - trainer = Trainer(**vars(args)) - trainer.fit(model, datamodule=datamodule) - - -if __name__ == '__main__': - cli_lightning_logo() - cli_main() diff --git a/pl_examples/basic_examples/submit_ddp2_job.sh b/pl_examples/basic_examples/submit_ddp2_job.sh index 026589a604c36..6fed6afef0d1c 100755 --- a/pl_examples/basic_examples/submit_ddp2_job.sh +++ b/pl_examples/basic_examples/submit_ddp2_job.sh @@ -24,4 +24,4 @@ source activate $1 # ------------------------- # run script from above -srun python3 simple_image_classifier.py --accelerator 'ddp2' --gpus 2 --num_nodes 2 --max_epochs 5 +srun python3 image_classifier.py --accelerator 'ddp2' --gpus 2 --num_nodes 2 diff --git a/pl_examples/basic_examples/submit_ddp_job.sh b/pl_examples/basic_examples/submit_ddp_job.sh index b4f5ff0a64d92..383579c4346b6 100755 --- a/pl_examples/basic_examples/submit_ddp_job.sh +++ b/pl_examples/basic_examples/submit_ddp_job.sh @@ -24,4 +24,4 @@ source activate $1 # ------------------------- # run script from above 
-srun python3 simple_image_classifier.py --accelerator 'ddp' --gpus 2 --num_nodes 2 --max_epochs 5 +srun python3 image_classifier.py --accelerator 'ddp' --gpus 2 --num_nodes 2 diff --git a/pl_examples/domain_templates/computer_vision_fine_tuning.py b/pl_examples/domain_templates/computer_vision_fine_tuning.py index 4e148a18433a6..88f4e66605741 100644 --- a/pl_examples/domain_templates/computer_vision_fine_tuning.py +++ b/pl_examples/domain_templates/computer_vision_fine_tuning.py @@ -49,7 +49,6 @@ from torch.optim.lr_scheduler import MultiStepLR from torch.optim.optimizer import Optimizer from torch.utils.data import DataLoader -from torchmetrics import Accuracy from torchvision import models, transforms from torchvision.datasets import ImageFolder from torchvision.datasets.utils import download_and_extract_archive @@ -189,8 +188,8 @@ def __init__( self.__build_model() - self.train_acc = Accuracy() - self.valid_acc = Accuracy() + self.train_acc = pl.metrics.Accuracy() + self.valid_acc = pl.metrics.Accuracy() self.save_hyperparameters() def __build_model(self): diff --git a/pytorch_lightning/__init__.py b/pytorch_lightning/__init__.py index b9660475bf2f7..569078c994ba4 100644 --- a/pytorch_lightning/__init__.py +++ b/pytorch_lightning/__init__.py @@ -2,17 +2,42 @@ import logging import os +import sys +import time -from pytorch_lightning.info import ( # noqa: F401 - __author__, - __author_email__, - __copyright__, - __docs__, - __homepage__, - __license__, - __version__, +_this_year = time.strftime("%Y") +__version__ = '1.3.0dev' +__author__ = 'William Falcon et al.' +__author_email__ = 'waf2107@columbia.edu' +__license__ = 'Apache-2.0' +__copyright__ = f'Copyright (c) 2018-{_this_year}, {__author__}.' +__homepage__ = 'https://github.com/PyTorchLightning/pytorch-lightning' +# this has to be simple string, see: https://github.com/pypa/twine/issues/522 +__docs__ = ( + "PyTorch Lightning is the lightweight PyTorch wrapper for ML researchers." + " Scale your models. Write less boilerplate." ) +__long_docs__ = """ +Lightning is a way to organize your PyTorch code to decouple the science code from the engineering. + It's more of a style-guide than a framework. +In Lightning, you organize your code into 3 distinct categories: + +1. Research code (goes in the LightningModule). +2. Engineering code (you delete, and is handled by the Trainer). +3. Non-essential research code (logging, etc. this goes in Callbacks). + +Although your research/production project might start simple, once you add things like GPU AND TPU training, + 16-bit precision, etc, you end up spending more time engineering than researching. + Lightning automates AND rigorously tests those parts for you. + +Overall, Lightning guarantees rigorously tested, correct, modern best practices for the automated parts. 
+ +Documentation +------------- +- https://pytorch-lightning.readthedocs.io/en/latest +- https://pytorch-lightning.readthedocs.io/en/stable +""" _root_logger = logging.getLogger() _logger = logging.getLogger(__name__) _logger.setLevel(logging.INFO) @@ -25,20 +50,32 @@ _PACKAGE_ROOT = os.path.dirname(__file__) _PROJECT_ROOT = os.path.dirname(_PACKAGE_ROOT) -from pytorch_lightning import metrics # noqa: E402 -from pytorch_lightning.callbacks import Callback # noqa: E402 -from pytorch_lightning.core import LightningDataModule, LightningModule # noqa: E402 -from pytorch_lightning.trainer import Trainer # noqa: E402 -from pytorch_lightning.utilities.seed import seed_everything # noqa: E402 - -__all__ = [ - 'Trainer', - 'LightningDataModule', - 'LightningModule', - 'Callback', - 'seed_everything', - 'metrics', -] +try: + # This variable is injected in the __builtins__ by the build + # process. It used to enable importing subpackages of skimage when + # the binaries are not built + _ = None if __LIGHTNING_SETUP__ else None +except NameError: + __LIGHTNING_SETUP__: bool = False + +if __LIGHTNING_SETUP__: # pragma: no-cover + sys.stdout.write(f'Partial import of `{__name__}` during the build process.\n') # pragma: no-cover + # We are not importing the rest of the lightning during the build process, as it may not be compiled yet +else: + from pytorch_lightning import metrics + from pytorch_lightning.callbacks import Callback + from pytorch_lightning.core import LightningDataModule, LightningModule + from pytorch_lightning.trainer import Trainer + from pytorch_lightning.utilities.seed import seed_everything + + __all__ = [ + 'Trainer', + 'LightningDataModule', + 'LightningModule', + 'Callback', + 'seed_everything', + 'metrics', + ] # for compatibility with namespace packages __import__('pkg_resources').declare_namespace(__name__) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 1dcd541ca0610..06191dcff6d80 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -21,8 +21,8 @@ from pytorch_lightning.plugins.precision import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin, PrecisionPlugin from pytorch_lightning.plugins.training_type import TrainingTypePlugin from pytorch_lightning.trainer.states import TrainerState -from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.apply_func import move_data_to_device +from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available from pytorch_lightning.utilities.enums import AMPType, LightningEnum if TYPE_CHECKING: @@ -66,29 +66,17 @@ def __init__( self.lr_schedulers: Sequence = [] self.optimizer_frequencies: Sequence = [] - def connect(self, model: LightningModule) -> None: - """Transfers ownership of the model to this plugin""" - self.training_type_plugin.connect(model) - - def setup_environment(self) -> None: - """ - Setup any processes or distributed connections. - This is called before the LightningModule/DataModule setup hook - which allows the user to access the accelerator environment before setup is complete. - """ - self.training_type_plugin.setup_environment() - def setup(self, trainer: 'Trainer', model: LightningModule) -> None: """ - Setup plugins for the trainer fit and creates optimizers. 
+ Connects the plugins to the training process, creates optimizers + Args: - trainer: the trainer instance - model: the LightningModule + trainer: the trainer instance to connect to + model: the model to train """ - self.setup_training_type_plugin(self.training_type_plugin, model) - if not self.training_type_plugin.setup_optimizers_in_pre_dispatch: - self.setup_optimizers(trainer) - self.setup_precision_plugin(self.precision_plugin) + self.connect_training_type_plugin(self.training_type_plugin, model) + self.setup_optimizers(trainer) + self.connect_precision_plugin(self.precision_plugin) def start_training(self, trainer: 'Trainer') -> None: self.training_type_plugin.start_training(trainer) @@ -99,14 +87,12 @@ def start_evaluating(self, trainer: 'Trainer') -> None: def start_predicting(self, trainer: 'Trainer') -> None: self.training_type_plugin.start_predicting(trainer) - def pre_dispatch(self, trainer: 'Trainer') -> None: + def pre_dispatch(self) -> None: """Hook to do something before the training/evaluation/prediction starts.""" self.training_type_plugin.pre_dispatch() - if self.training_type_plugin.setup_optimizers_in_pre_dispatch: - self.setup_optimizers(trainer) self.precision_plugin.pre_dispatch() - def post_dispatch(self, trainer: 'Trainer') -> None: + def post_dispatch(self) -> None: """Hook to do something before the training/evaluation/prediction starts.""" self.training_type_plugin.post_dispatch() self.precision_plugin.post_dispatch() @@ -220,7 +206,7 @@ def test_step(self, args: List[Union[Any, int]]) -> _STEP_OUTPUT_TYPE: with self.precision_plugin.test_step_context(), self.training_type_plugin.test_step_context(): return self.training_type_plugin.test_step(*args) - def predict_step(self, args: List[Union[Any, int]]) -> _STEP_OUTPUT_TYPE: + def predict(self, args: List[Union[Any, int]]) -> _STEP_OUTPUT_TYPE: """The actual predict step. Args: @@ -236,7 +222,7 @@ def predict_step(self, args: List[Union[Any, int]]) -> _STEP_OUTPUT_TYPE: args[0] = batch with self.precision_plugin.predict_context(), self.training_type_plugin.predict_context(): - return self.training_type_plugin.predict_step(*args) + return self.training_type_plugin.predict(*args) def training_step_end(self, output: _STEP_OUTPUT_TYPE) -> _STEP_OUTPUT_TYPE: """A hook to do something at the end of the training step @@ -347,11 +333,14 @@ def setup_optimizers(self, trainer: 'Trainer') -> None: self.lr_schedulers = lr_schedulers self.optimizer_frequencies = optimizer_frequencies - def setup_training_type_plugin(self, plugin: TrainingTypePlugin, model: LightningModule) -> None: - """Attaches the training type plugin to the accelerator.""" - plugin.setup(model) + def connect_training_type_plugin(self, plugin: TrainingTypePlugin, model: LightningModule) -> None: + """Attaches the training type plugin to the accelerator. 
+ Also transfers ownership of the model to this plugin + + """ + plugin.connect(model) - def setup_precision_plugin(self, plugin: PrecisionPlugin) -> None: + def connect_precision_plugin(self, plugin: PrecisionPlugin) -> None: """Attaches the precision plugin to the accelerator""" model, optimizers, schedulers = plugin.connect(self.model, self.optimizers, self.lr_schedulers) self.model = model @@ -360,12 +349,7 @@ def setup_precision_plugin(self, plugin: PrecisionPlugin) -> None: def to_device(self, batch: Any) -> Any: """Pushes the batch to the root device""" - # Todo (tchaton) Better fix - is_dict = isinstance(batch, dict) - if is_dict: - batch = [batch] - batch = self.batch_to_device(batch, self.root_device) - return batch[0] if is_dict else batch + return self.batch_to_device(batch, self.root_device) @property def amp_backend(self) -> Optional[LightningEnum]: @@ -421,7 +405,7 @@ def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_gra Return: A tensor of shape (world_size, batch, ...) """ - return self.training_type_plugin.all_gather(tensor, group=group, sync_grads=sync_grads) + return all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads) def process_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]: """Wraps the dataloader if necessary @@ -438,31 +422,3 @@ def results(self) -> Any: In distributed training, we make sure to transfer the results to the appropriate master process. """ return self.training_type_plugin.results - - # todo: remove in v1.5 - def connect_training_type_plugin(self, plugin: TrainingTypePlugin, model: LightningModule) -> None: - """ - Attaches the training type plugin to the accelerator. - Also transfers ownership of the model to this plugin - - .. deprecated::v1.3 - Will be removed in v1.5.0. - """ - rank_zero_warn( - 'Accelerator method `connect_training_type_plugin` was deprecated in v1.3.' - ' It will be removed in v1.5.' - ) - self.setup_training_type_plugin(plugin, model) - - # todo: remove in v1.5 - def connect_precision_plugin(self, plugin: PrecisionPlugin) -> None: - """Attaches the precision plugin to the accelerator - - .. deprecated::v1.3 - Will be removed in v1.5.0. - """ - rank_zero_warn( - 'Accelerator method `connect_precision_plugin` was deprecated in v1.3.' - ' It will be removed in v1.5.' - ) - self.setup_precision_plugin(plugin) diff --git a/pytorch_lightning/accelerators/cpu.py b/pytorch_lightning/accelerators/cpu.py index 22ea8f1e1b7aa..f428951b16932 100644 --- a/pytorch_lightning/accelerators/cpu.py +++ b/pytorch_lightning/accelerators/cpu.py @@ -1,16 +1,3 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
from typing import TYPE_CHECKING from pytorch_lightning.accelerators.accelerator import Accelerator diff --git a/pytorch_lightning/accelerators/gpu.py b/pytorch_lightning/accelerators/gpu.py index c23960e4fd9e3..af9ce25f902b3 100644 --- a/pytorch_lightning/accelerators/gpu.py +++ b/pytorch_lightning/accelerators/gpu.py @@ -1,19 +1,6 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. import logging import os -from typing import Any, TYPE_CHECKING +from typing import TYPE_CHECKING, Any import torch diff --git a/pytorch_lightning/accelerators/tpu.py b/pytorch_lightning/accelerators/tpu.py index 35a475e3e790d..57e65a62f6783 100644 --- a/pytorch_lightning/accelerators/tpu.py +++ b/pytorch_lightning/accelerators/tpu.py @@ -1,17 +1,4 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Any, Callable, Optional, TYPE_CHECKING, Union +from typing import Any, Callable, Optional, TYPE_CHECKING import torch from torch.optim import Optimizer @@ -25,9 +12,6 @@ if _XLA_AVAILABLE: import torch_xla.core.xla_model as xm - from torch_xla._patched_functions import clip_grad_norm_ - - xla_clip_grad_norm_ = clip_grad_norm_ if TYPE_CHECKING: from pytorch_lightning.core.lightning import LightningModule @@ -62,25 +46,12 @@ def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_gra Function to gather a tensor from several distributed processes Args: tensor: tensor of shape (batch, ...) - group: not available with TPUs - sync_grads: not available with TPUs + group: the process group to gather results from. Defaults to all processes (world) + sync_grads: flag that allows users to synchronize gradients for all_gather op Return: A tensor of shape (world_size, batch, ...) 
""" # todo: Add support for backward with all_gather - if isinstance(self.training_type_plugin, TPUSpawnPlugin) and self.training_type_plugin.is_distributed: - return xm.all_gather(tensor).view(-1, *tensor.shape) + if torch.distributed.is_initialized(): + return xm.all_gather(tensor, group=group, sync_grads=sync_grads) return tensor - - def clip_gradients(self, optimizer: Optimizer, clip_val: Union[float, int], norm_type: float = 2.0): - - model = self.lightning_module - parameters = model.parameters() - - grad_clip_val = float(clip_val) - if grad_clip_val <= 0: - return - - max_norm = grad_clip_val - - xla_clip_grad_norm_(parameters, max_norm, norm_type) diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py index 7757902bd3baf..0ba1fd4ff7785 100644 --- a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -17,7 +17,7 @@ """ import abc -from typing import Any, Dict, List, Optional +from typing import Any, Dict, Optional from pytorch_lightning.core.lightning import LightningModule @@ -81,7 +81,7 @@ def on_train_epoch_start(self, trainer, pl_module: LightningModule) -> None: """Called when the train epoch begins.""" pass - def on_train_epoch_end(self, trainer, pl_module: LightningModule, outputs: List[Any]) -> None: + def on_train_epoch_end(self, trainer, pl_module: LightningModule, outputs: Any) -> None: """Called when the train epoch ends.""" pass @@ -89,7 +89,7 @@ def on_validation_epoch_start(self, trainer, pl_module: LightningModule) -> None """Called when the val epoch begins.""" pass - def on_validation_epoch_end(self, trainer, pl_module: LightningModule, outputs: List[Any]) -> None: + def on_validation_epoch_end(self, trainer, pl_module: LightningModule) -> None: """Called when the val epoch ends.""" pass @@ -97,16 +97,16 @@ def on_test_epoch_start(self, trainer, pl_module: LightningModule) -> None: """Called when the test epoch begins.""" pass - def on_test_epoch_end(self, trainer, pl_module: LightningModule, outputs: List[Any]) -> None: + def on_test_epoch_end(self, trainer, pl_module: LightningModule) -> None: """Called when the test epoch ends.""" pass def on_epoch_start(self, trainer, pl_module: LightningModule) -> None: - """Called when either of train/val/test epoch begins.""" + """Called when the epoch begins.""" pass def on_epoch_end(self, trainer, pl_module: LightningModule) -> None: - """Called when either of train/val/test epoch ends.""" + """Called when the epoch ends.""" pass def on_batch_start(self, trainer, pl_module: LightningModule) -> None: diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index 4448de8e4834b..38ccce648502a 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -172,4 +172,4 @@ def _run_early_stopping_check(self, trainer): trainer.should_stop = True # stop every ddp process if any world process decides to stop - trainer.should_stop = trainer.training_type_plugin.reduce_boolean_decision(trainer.should_stop) + trainer.should_stop = trainer.training_type_plugin.reduce_early_stopping_decision(trainer.should_stop) diff --git a/pytorch_lightning/callbacks/gradient_accumulation_scheduler.py b/pytorch_lightning/callbacks/gradient_accumulation_scheduler.py index b1885087f4da0..0af7d61bf5dec 100644 --- a/pytorch_lightning/callbacks/gradient_accumulation_scheduler.py +++ b/pytorch_lightning/callbacks/gradient_accumulation_scheduler.py @@ -74,7 +74,7 @@ def __init__(self, scheduling: 
Dict[int, int]): def going_to_accumulate_grad_batches(self): return any([v > 1 for v in self.scheduling.values()]) - def on_train_epoch_start(self, trainer, pl_module): + def on_epoch_start(self, trainer, pl_module): epoch = trainer.current_epoch for i in reversed(range(len(self.epochs))): if epoch >= self.epochs[i]: diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 2781586730151..bf6c799ef728a 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -30,7 +30,7 @@ import yaml from pytorch_lightning.callbacks.base import Callback -from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_info, rank_zero_only, rank_zero_warn +from pytorch_lightning.utilities import rank_zero_info, rank_zero_only, rank_zero_warn from pytorch_lightning.utilities.cloud_io import get_filesystem from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.warnings import WarningCache @@ -258,9 +258,9 @@ def save_checkpoint(self, trainer, unused: Optional = None): to handle correct behaviour in distributed training, i.e., saving only on rank 0. """ if unused is not None: - rank_zero_deprecation( + rank_zero_warn( "`ModelCheckpoint.save_checkpoint` signature has changed in v1.3. The `pl_module` parameter" - " has been removed. Support for the old signature will be removed in v1.5" + " has been removed. Support for the old signature will be removed in v1.5", DeprecationWarning ) global_step = trainer.global_step @@ -371,9 +371,9 @@ def __init_triggers( # period takes precedence over every_n_val_epochs for backwards compatibility if period is not None: - rank_zero_deprecation( + rank_zero_warn( 'Argument `period` in `ModelCheckpoint` is deprecated in v1.3 and will be removed in v1.5.' - ' Please use `every_n_val_epochs` instead.' + ' Please use `every_n_val_epochs` instead.', DeprecationWarning ) self._every_n_val_epochs = period @@ -381,17 +381,17 @@ def __init_triggers( @property def period(self) -> Optional[int]: - rank_zero_deprecation( + rank_zero_warn( 'Property `period` in `ModelCheckpoint` is deprecated in v1.3 and will be removed in v1.5.' - ' Please use `every_n_val_epochs` instead.' + ' Please use `every_n_val_epochs` instead.', DeprecationWarning ) return self._period @period.setter def period(self, value: Optional[int]) -> None: - rank_zero_deprecation( + rank_zero_warn( 'Property `period` in `ModelCheckpoint` is deprecated in v1.3 and will be removed in v1.5.' - ' Please use `every_n_val_epochs` instead.' + ' Please use `every_n_val_epochs` instead.', DeprecationWarning ) self._period = value @@ -424,7 +424,7 @@ def _do_save(self, trainer, filepath: str): else: raise ValueError(".save_function() not set") - def check_monitor_top_k(self, trainer, current: Optional[torch.Tensor] = None) -> bool: + def check_monitor_top_k(self, current: torch.Tensor) -> bool: if current is None: return False @@ -444,12 +444,7 @@ def check_monitor_top_k(self, trainer, current: Optional[torch.Tensor] = None) - current = torch.tensor(current) monitor_op = {"min": torch.lt, "max": torch.gt}[self.mode] - should_update_best_and_save = monitor_op(current, self.best_k_models[self.kth_best_model_path]) - - # If using multiple devices, make sure all processes are unanimous on the decision. 
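The top-k check shown above reduces to a single comparison against the current k-th best score, with the operator selected by ``mode``. A standalone sketch of that decision, assuming only ``torch`` (``should_save`` is an illustrative helper, not part of the callback)::

    import torch

    def should_save(current: torch.Tensor, kth_best: torch.Tensor, mode: str = "min") -> bool:
        # "min": a lower score is better; "max": a higher score is better.
        monitor_op = {"min": torch.lt, "max": torch.gt}[mode]
        return bool(monitor_op(current, kth_best))

    print(should_save(torch.tensor(0.20), torch.tensor(0.25), mode="min"))  # True
    print(should_save(torch.tensor(0.30), torch.tensor(0.25), mode="min"))  # False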
- should_update_best_and_save = trainer.training_type_plugin.reduce_boolean_decision(should_update_best_and_save) - - return should_update_best_and_save + return monitor_op(current, self.best_k_models[self.kth_best_model_path]).item() @classmethod def _format_checkpoint_name( @@ -643,7 +638,15 @@ def _save_top_k_checkpoint(self, trainer, monitor_candidates: Dict[str, Any]): epoch = monitor_candidates.get("epoch") step = monitor_candidates.get("step") - if self.check_monitor_top_k(trainer, current): + # when `val_loss` is being logged and no ModelCheckpoint is being provided + # `val_loss` will be selected for monitor and need to be reduced to + # prevent processes divergence + # TODO: Move this logic to logger_connector. This also needs to be fixed for any + # other monitor logged value which aren't produced from a Metric. + if self.monitor == "val_loss": + current = trainer.training_type_plugin.reduce(current, reduce_op="mean") + + if self.check_monitor_top_k(current): self._update_best_and_save(current, epoch, step, trainer, monitor_candidates) elif self.verbose: rank_zero_info(f"Epoch {epoch:d}, step {step:d}: {self.monitor} was not in top {self.save_top_k}") @@ -728,4 +731,5 @@ def file_exists(self, filepath: Union[str, Path], trainer) -> bool: the internal state to diverge between ranks. """ exists = self._fs.exists(filepath) - return trainer.training_type_plugin.broadcast(exists) + exists = trainer.training_type_plugin.broadcast(exists) + return exists diff --git a/pytorch_lightning/callbacks/progress.py b/pytorch_lightning/callbacks/progress.py index 7dc4202530d04..74e57e2b5642e 100644 --- a/pytorch_lightning/callbacks/progress.py +++ b/pytorch_lightning/callbacks/progress.py @@ -39,7 +39,8 @@ class tqdm(_tqdm): """ - Custom tqdm progressbar where we append 0 to floating points/strings to prevent the progress bar from flickering + Custom tqdm progressbar where we append 0 to floating points/strings to + prevent the progress bar from flickering """ @staticmethod @@ -200,7 +201,7 @@ def on_init_end(self, trainer): def on_train_start(self, trainer, pl_module): self._train_batch_idx = trainer.batch_idx - def on_train_epoch_start(self, trainer, pl_module): + def on_epoch_start(self, trainer, pl_module): self._train_batch_idx = 0 def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): @@ -392,8 +393,8 @@ def on_train_start(self, trainer, pl_module): super().on_train_start(trainer, pl_module) self.main_progress_bar = self.init_train_tqdm() - def on_train_epoch_start(self, trainer, pl_module): - super().on_train_epoch_start(trainer, pl_module) + def on_epoch_start(self, trainer, pl_module): + super().on_epoch_start(trainer, pl_module) total_train_batches = self.total_train_batches total_val_batches = self.total_val_batches if total_train_batches != float('inf'): diff --git a/pytorch_lightning/core/datamodule.py b/pytorch_lightning/core/datamodule.py index 4178c9eeacd50..994c259f48964 100644 --- a/pytorch_lightning/core/datamodule.py +++ b/pytorch_lightning/core/datamodule.py @@ -14,6 +14,7 @@ """LightningDataModule for loading DataLoaders with ease.""" import functools +from abc import abstractmethod from argparse import ArgumentParser, Namespace from typing import Any, List, Mapping, Optional, Sequence, Tuple, Union @@ -43,8 +44,6 @@ def __call__(cls, *args, **kwargs): cls.prepare_data = track_data_hook_calls(rank_zero_only(cls.prepare_data)) # Track setup calls cls.setup = track_data_hook_calls(cls.setup) - # Track teardown calls - cls.teardown = 
track_data_hook_calls(cls.teardown) # Get instance of LightningDataModule by mocking its __init__ via __call__ obj = type.__call__(cls, *args, **kwargs) @@ -53,13 +52,12 @@ def __call__(cls, *args, **kwargs): def track_data_hook_calls(fn): - """A decorator that checks if prepare_data/setup/teardown has been called. + """A decorator that checks if prepare_data/setup have been called. - When ``dm.prepare_data()`` is called, ``dm.has_prepared_data`` gets set to True - When ``dm.setup()``, ``dm.has_setup_{fit,validate,test}`` get set to True - When ``dm.setup(stage)`` is called, where stage is any of ``{fit,validate,test,predict}``. Its corresponding `dm_has_setup_{stage}` attribute gets set to True - - ``dm.teardown()`` and ``dm.teardown(stage)`` act exactly like ``dm.setup`` Args: fn (function): Function that will be tracked to see if it has been called. @@ -73,10 +71,9 @@ def wrapped_fn(*args, **kwargs): # The object instance from which setup or prepare_data was called obj = args[0] - name = fn.__name__ # If calling setup, we check the stage and assign stage-specific bool args - if name in ("setup", "teardown"): + if fn.__name__ == "setup": # Get stage either by grabbing from args or checking kwargs. # If not provided, set call status of 'fit', 'validate', and 'test' to True. @@ -85,11 +82,11 @@ def wrapped_fn(*args, **kwargs): if stage is None: for s in ("fit", "validate", "test"): - setattr(obj, f"_has_{name}_{s}", True) + setattr(obj, f"_has_setup_{s}", True) else: - setattr(obj, f"_has_{name}_{stage}", True) + setattr(obj, f"_has_setup_{stage}", True) - elif name == "prepare_data": + if fn.__name__ == "prepare_data": obj._has_prepared_data = True return fn(*args, **kwargs) @@ -122,18 +119,14 @@ def val_dataloader(self): def test_dataloader(self): test_split = Dataset(...) return DataLoader(test_split) - def teardown(self): - # clean up after fit or test - # called on every process in DDP - A DataModule implements 6 key methods: + A DataModule implements 5 key methods: * **prepare_data** (things to do on 1 GPU/TPU not on every GPU/TPU in distributed mode). * **setup** (things to do on every accelerator in distributed mode). * **train_dataloader** the training dataloader. * **val_dataloader** the val dataloader(s). * **test_dataloader** the test dataloader(s). - * **teardown** (things to do on every accelerator in distributed mode when finished) This allows you to share a full dataset without explaining how to download, @@ -161,17 +154,11 @@ def __init__( # Private attrs to keep track of whether or not data hooks have been called yet self._has_prepared_data = False - self._has_setup_fit = False self._has_setup_validate = False self._has_setup_test = False self._has_setup_predict = False - self._has_teardown_fit = False - self._has_teardown_validate = False - self._has_teardown_test = False - self._has_teardown_predict = False - @property def train_transforms(self): """ @@ -272,41 +259,13 @@ def has_setup_predict(self) -> bool: """ return self._has_setup_predict - @property - def has_teardown_fit(self) -> bool: - """Return bool letting you know if ``datamodule.teardown(stage='fit')`` has been called or not. + @abstractmethod + def prepare_data(self, *args, **kwargs): + pass - Returns: - bool: True ``if datamodule.teardown(stage='fit')`` has been called. False by default. - """ - return self._has_teardown_fit - - @property - def has_teardown_validate(self) -> bool: - """Return bool letting you know if ``datamodule.teardown(stage='validate')`` has been called or not. 
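The data-hook tracking above amounts to a wrapper that flips a per-stage flag before delegating to the real hook. A simplified sketch of that decorator pattern (``DemoDataModule`` and ``track_setup_calls`` are hypothetical stand-ins, not the Lightning implementation)::

    import functools

    def track_setup_calls(fn):
        @functools.wraps(fn)
        def wrapped(obj, stage=None, *args, **kwargs):
            # Record which stage(s) the hook was called for before running it.
            stages = ("fit", "validate", "test") if stage is None else (stage,)
            for s in stages:
                setattr(obj, f"_has_setup_{s}", True)
            return fn(obj, stage, *args, **kwargs)
        return wrapped

    class DemoDataModule:
        _has_setup_fit = _has_setup_validate = _has_setup_test = False

        @track_setup_calls
        def setup(self, stage=None):
            pass

    dm = DemoDataModule()
    dm.setup("fit")
    assert dm._has_setup_fit and not dm._has_setup_test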
- - Returns: - bool: True if ``datamodule.teardown(stage='validate')`` has been called. False by default. - """ - return self._has_teardown_validate - - @property - def has_teardown_test(self) -> bool: - """Return bool letting you know if ``datamodule.teardown(stage='test')`` has been called or not. - - Returns: - bool: True if ``datamodule.teardown(stage='test')`` has been called. False by default. - """ - return self._has_teardown_test - - @property - def has_teardown_predict(self) -> bool: - """Return bool letting you know if ``datamodule.teardown(stage='predict')`` has been called or not. - - Returns: - bool: True if ``datamodule.teardown(stage='predict')`` has been called. False by default. - """ - return self._has_teardown_predict + @abstractmethod + def setup(self, stage: Optional[str] = None): + pass @classmethod def add_argparse_args(cls, parent_parser: ArgumentParser, **kwargs) -> ArgumentParser: diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index bf3b0bf605679..1399d1b3c66ba 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -25,6 +25,42 @@ class ModelHooks: """Hooks to be used in LightningModule.""" + def setup(self, stage: Optional[str] = None) -> None: + """ + Called at the beginning of fit (train + validate), validate, test, predict, or tune. + This is a good hook when you need to build models dynamically or adjust something about them. + This hook is called on every process when using DDP. + + Args: + stage: either ``'fit'``, ``'validate'``, ``'test'``, or ``'predict'`` + + Example:: + + class LitModel(...): + def __init__(self): + self.l1 = None + + def prepare_data(self): + download_data() + tokenize() + + # don't do this + self.something = else + + def setup(stage): + data = Load_data(...) + self.l1 = nn.Linear(28, data.num_classes) + + """ + + def teardown(self, stage: Optional[str] = None) -> None: + """ + Called at the end of fit (train + validate), validate, test, predict, or tune. + + Args: + stage: either ``'fit'``, ``'validate'``, ``'test'``, or ``'predict'`` + """ + def on_fit_start(self) -> None: """ Called at the very beginning of fit. @@ -188,13 +224,13 @@ def on_predict_model_eval(self) -> None: def on_epoch_start(self) -> None: """ - Called when either of train/val/test epoch begins. + Called in the training loop at the very beginning of the epoch. """ # do something when the epoch starts def on_epoch_end(self) -> None: """ - Called when either of train/val/test epoch ends. + Called in the training loop at the very end of the epoch. """ # do something when the epoch ends @@ -204,7 +240,7 @@ def on_train_epoch_start(self) -> None: """ # do something when the epoch starts - def on_train_epoch_end(self, outputs: List[Any]) -> None: + def on_train_epoch_end(self, outputs) -> None: """ Called in the training loop at the very end of the epoch. """ @@ -216,7 +252,7 @@ def on_validation_epoch_start(self) -> None: """ # do something when the epoch starts - def on_validation_epoch_end(self, outputs: List[Any]) -> None: + def on_validation_epoch_end(self) -> None: """ Called in the validation loop at the very end of the epoch. """ @@ -228,7 +264,7 @@ def on_test_epoch_start(self) -> None: """ # do something when the epoch starts - def on_test_epoch_end(self, outputs: List[Any]) -> None: + def on_test_epoch_end(self) -> None: """ Called in the test loop at the very end of the epoch. 
""" @@ -246,18 +282,6 @@ def on_test_end(self) -> None: """ # do something at the end of testing - def on_predict_start(self) -> None: - """ - Called at the beginning of predicting. - """ - # do something at the start of predicting - - def on_predict_end(self) -> None: - """ - Called at the end of predicting. - """ - # do something at the end of predicting - def on_before_zero_grad(self, optimizer: Optimizer) -> None: """ Called after optimizer.step() and before optimizer.zero_grad(). @@ -359,42 +383,6 @@ def prepare_data(self): model.test_dataloader() """ - def setup(self, stage: Optional[str] = None) -> None: - """ - Called at the beginning of fit (train + validate), validate, test, predict, or tune. - This is a good hook when you need to build models dynamically or adjust something about them. - This hook is called on every process when using DDP. - - Args: - stage: either ``'fit'``, ``'validate'``, ``'test'``, or ``'predict'`` - - Example:: - - class LitModel(...): - def __init__(self): - self.l1 = None - - def prepare_data(self): - download_data() - tokenize() - - # don't do this - self.something = else - - def setup(stage): - data = Load_data(...) - self.l1 = nn.Linear(28, data.num_classes) - - """ - - def teardown(self, stage: Optional[str] = None) -> None: - """ - Called at the end of fit (train + validate), validate, test, predict, or tune. - - Args: - stage: either ``'fit'``, ``'validate'``, ``'test'``, or ``'predict'`` - """ - def train_dataloader(self) -> Any: """ Implement one or more PyTorch DataLoaders for training. @@ -606,18 +594,6 @@ def predict_dataloader(self) -> Union[DataLoader, List[DataLoader]]: will have an argument ``dataloader_idx`` which matches the order here. """ - def on_train_dataloader(self) -> None: - """Called before requesting the train dataloader.""" - - def on_val_dataloader(self) -> None: - """Called before requesting the val dataloader.""" - - def on_test_dataloader(self) -> None: - """Called before requesting the test dataloader.""" - - def on_predict_dataloader(self) -> None: - """Called before requesting the predict dataloader.""" - def transfer_batch_to_device(self, batch: Any, device: Optional[torch.device] = None) -> Any: """ Override this hook if your :class:`~torch.utils.data.DataLoader` returns tensors diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 7efe88515b37e..4c839f3a6c906 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -38,7 +38,7 @@ from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, ModelIO, PRIMITIVE_TYPES from pytorch_lightning.core.step_result import Result -from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn +from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -105,7 +105,6 @@ def __init__(self, *args, **kwargs): self._current_hook_fx_name = None self._current_dataloader_idx = None self._automatic_optimization: bool = True - self._param_requires_grad_state = dict() def optimizers(self, use_pl_optimizer: bool = True) -> Union[Optimizer, List[Optimizer], List[LightningOptimizer]]: if use_pl_optimizer: @@ -720,13 +719,10 @@ def validation_step(self, *args, **kwargs): .. 
code-block:: python # pseudocode of order - val_outs = [] - for val_batch in val_data: - out = validation_step(val_batch) - if defined('validation_step_end'): - out = validation_step_end(out) - val_outs.append(out) - val_outs = validation_epoch_end(val_outs) + out = validation_step() + if defined('validation_step_end'): + out = validation_step_end(out) + out = validation_epoch_end(out) .. code-block:: python @@ -1057,7 +1053,7 @@ def test_epoch_end(self, outputs): self.log('final_metric', final_value) """ - def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int] = None): + def predict(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int] = None): """ Use this function with trainer.predict(...). Override if you need to add any processing logic. """ @@ -1229,8 +1225,9 @@ def training_step(...): opt_a.step() """ if optimizer is not None: - rank_zero_deprecation( - "`optimizer` argument to `manual_backward` is deprecated in v1.2 and will be removed in v1.4" + rank_zero_warn( + "`optimizer` argument to `manual_backward` is deprecated in v1.2 and will be removed in v1.4", + DeprecationWarning ) # make sure we're using manual opt @@ -1314,7 +1311,7 @@ def untoggle_optimizer(self, optimizer_idx: int): if param in self._param_requires_grad_state: param.requires_grad = self._param_requires_grad_state[param] # save memory - self._param_requires_grad_state = dict() + del self._param_requires_grad_state def optimizer_step( self, diff --git a/pytorch_lightning/core/memory.py b/pytorch_lightning/core/memory.py index a3eab728f8ea8..afb64535d1470 100644 --- a/pytorch_lightning/core/memory.py +++ b/pytorch_lightning/core/memory.py @@ -16,7 +16,7 @@ import shutil import subprocess from collections import OrderedDict -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Tuple, Union import numpy as np import torch @@ -71,15 +71,14 @@ def __init__(self, module: nn.Module): def __del__(self): self.detach_hook() - def _register_hook(self) -> Optional[RemovableHandle]: + def _register_hook(self) -> RemovableHandle: """ Registers a hook on the module that computes the input- and output size(s) on the first forward pass. If the hook is called, it will remove itself from the from the module, meaning that recursive models will only record their input- and output shapes once. - Registering hooks on :class:`~torch.jit.ScriptModule` is not supported. Return: - A handle for the installed hook, or ``None`` if registering the hook is not possible. + A handle for the installed hook. 
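The ``_register_hook`` behaviour described above relies on a forward hook that records the input and output shapes and then removes itself, so recursive modules are only measured once. A small standalone sketch of that pattern, assuming only ``torch`` (``ShapeProbe`` is an illustrative name, not the Lightning class)::

    import torch
    from torch import nn

    class ShapeProbe:
        def __init__(self, module: nn.Module):
            self.in_shape = self.out_shape = None
            self._handle = module.register_forward_hook(self._hook)

        def _hook(self, module, inp, out):
            # Record shapes once, then detach the hook so it never fires again.
            self.in_shape = tuple(inp[0].shape)
            self.out_shape = tuple(out.shape)
            self._handle.remove()

    layer = nn.Linear(28, 10)
    probe = ShapeProbe(layer)
    layer(torch.randn(2, 28))
    assert probe.in_shape == (2, 28) and probe.out_shape == (2, 10)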
""" def hook(module, inp, out): @@ -89,10 +88,7 @@ def hook(module, inp, out): self._out_size = parse_batch_shape(out) self._hook_handle.remove() - handle = None - if not isinstance(self._module, torch.jit.ScriptModule): - handle = self._module.register_forward_hook(hook) - return handle + return self._module.register_forward_hook(hook) def detach_hook(self): """ diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py index 3961586f4946a..f8d7a2ffe3a23 100644 --- a/pytorch_lightning/core/step_result.py +++ b/pytorch_lightning/core/step_result.py @@ -20,8 +20,8 @@ import torch from torch import Tensor -from torchmetrics import Metric +from pytorch_lightning.metrics import Metric from pytorch_lightning.utilities.distributed import sync_ddp_if_available diff --git a/pytorch_lightning/distributed/dist.py b/pytorch_lightning/distributed/dist.py index 37ac5d8b13462..5da7dfa86084d 100644 --- a/pytorch_lightning/distributed/dist.py +++ b/pytorch_lightning/distributed/dist.py @@ -11,10 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import io from typing import Any -from pytorch_lightning.overrides.torch_distributed import broadcast_object_list -from pytorch_lightning.utilities.distributed import group as _group +import torch +from torch import distributed as torch_distrib + +from pytorch_lightning.utilities import _GROUP_AVAILABLE + +WORLD = None +if _GROUP_AVAILABLE: + from torch.distributed import group + WORLD = group.WORLD class LightningDistributed: @@ -23,13 +31,32 @@ def __init__(self, rank=None, device=None): self.rank = rank self.device = device - def broadcast(self, obj: Any, group=_group.WORLD): - # always wrap into a list so list can be brodcasted. - obj = [obj] - - if self.rank != 0: - obj = [None] * len(obj) - - broadcast_object_list(obj, 0, group=group or _group.WORLD) - - return obj[0] + def broadcast(self, obj: Any, group=WORLD): + if self.rank == 0: + self._emit(obj, group) + else: + obj = self._receive(group) + return obj + + def _broadcast(self, tensor, src=0, group=WORLD): + if group is None: + return torch_distrib.broadcast(tensor, src=src) + return torch_distrib.broadcast(tensor, src=0, group=group) + + def _emit(self, obj: Any, group=WORLD): + buffer = io.BytesIO() + torch.save(obj, buffer) + data = bytearray(buffer.getbuffer()) + length_tensor = torch.tensor([len(data)]).long().to(self.device) + self._broadcast(length_tensor, src=0, group=group) + data_tensor = torch.ByteTensor(data).to(self.device) + self._broadcast(data_tensor, src=0, group=group) + + def _receive(self, group=WORLD): + length_tensor = torch.tensor([0]).long().to(self.device) + self._broadcast(length_tensor, src=0, group=group) + data_tensor = torch.empty([length_tensor.item()], dtype=torch.uint8).to(self.device) + self._broadcast(data_tensor, src=0, group=group) + buffer = io.BytesIO(data_tensor.cpu().numpy()) + obj = torch.load(buffer) + return obj diff --git a/pytorch_lightning/info.py b/pytorch_lightning/info.py deleted file mode 100644 index b00d1946424e7..0000000000000 --- a/pytorch_lightning/info.py +++ /dev/null @@ -1,36 +0,0 @@ -import time - -_this_year = time.strftime("%Y") -__version__ = '1.3.0dev' -__author__ = 'William Falcon et al.' -__author_email__ = 'waf2107@columbia.edu' -__license__ = 'Apache-2.0' -__copyright__ = f'Copyright (c) 2018-{_this_year}, {__author__}.' 
-__homepage__ = 'https://github.com/PyTorchLightning/pytorch-lightning' -__docs_url__ = "https://pytorch-lightning.readthedocs.io/en/stable/" -# this has to be simple string, see: https://github.com/pypa/twine/issues/522 -__docs__ = ( - "PyTorch Lightning is the lightweight PyTorch wrapper for ML researchers." - " Scale your models. Write less boilerplate." -) -__long_docs__ = """ -Lightning is a way to organize your PyTorch code to decouple the science code from the engineering. - It's more of a style-guide than a framework. - -In Lightning, you organize your code into 3 distinct categories: - -1. Research code (goes in the LightningModule). -2. Engineering code (you delete, and is handled by the Trainer). -3. Non-essential research code (logging, etc. this goes in Callbacks). - -Although your research/production project might start simple, once you add things like GPU AND TPU training, - 16-bit precision, etc, you end up spending more time engineering than researching. - Lightning automates AND rigorously tests those parts for you. - -Overall, Lightning guarantees rigorously tested, correct, modern best practices for the automated parts. - -Documentation -------------- -- https://pytorch-lightning.readthedocs.io/en/latest -- https://pytorch-lightning.readthedocs.io/en/stable -""" diff --git a/pytorch_lightning/loggers/mlflow.py b/pytorch_lightning/loggers/mlflow.py index 75e62539dd0a8..b6d25be7f0b0c 100644 --- a/pytorch_lightning/loggers/mlflow.py +++ b/pytorch_lightning/loggers/mlflow.py @@ -80,9 +80,7 @@ def any_lightning_module_function_or_hook(self): Defaults to `./mlflow` if `tracking_uri` is not provided. Has no effect if `tracking_uri` is provided. prefix: A string to put at the beginning of metric keys. - artifact_location: The location to store run artifacts. If not provided, the server picks an appropriate - default. - figure_file_extension: File extension with which matplotlib saves figures. + figure_file_extension: File extension with which matplotlib saves figure Raises: ImportError: @@ -98,7 +96,6 @@ def __init__( tags: Optional[Dict[str, Any]] = None, save_dir: Optional[str] = './mlruns', prefix: str = '', - artifact_location: Optional[str] = None, figure_file_extension='.png', ): if mlflow is None: @@ -116,10 +113,8 @@ def __init__( self._run_id = None self.tags = tags self._prefix = prefix - self._artifact_location = artifact_location - self._figure_file_extension = figure_file_extension - self._mlflow_client = MlflowClient(tracking_uri) + self._figure_file_extension = figure_file_extension @property @rank_zero_experiment @@ -139,10 +134,7 @@ def experiment(self) -> MlflowClient: self._experiment_id = expt.experiment_id else: log.warning(f'Experiment with name {self._experiment_name} not found. Creating it.') - self._experiment_id = self._mlflow_client.create_experiment( - name=self._experiment_name, - artifact_location=self._artifact_location, - ) + self._experiment_id = self._mlflow_client.create_experiment(name=self._experiment_name) if self._run_id is None: run = self._mlflow_client.create_run(experiment_id=self._experiment_id, tags=self.tags) diff --git a/pytorch_lightning/metrics/__init__.py b/pytorch_lightning/metrics/__init__.py index 9b27fdf0cb253..a5a337f2cba9d 100644 --- a/pytorch_lightning/metrics/__init__.py +++ b/pytorch_lightning/metrics/__init__.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- from pytorch_lightning.metrics.classification import ( # noqa: F401 Accuracy, AUC, @@ -38,9 +37,3 @@ R2Score, SSIM, ) -from pytorch_lightning.utilities import rank_zero_deprecation - -rank_zero_deprecation( - "`pytorch_lightning.metrics.*` module has been renamed to `torchmetrics.*` and split off to its own package" - " (https://github.com/PyTorchLightning/metrics) since v1.3 and will be removed in v1.5" -) diff --git a/pytorch_lightning/metrics/classification/accuracy.py b/pytorch_lightning/metrics/classification/accuracy.py index 1a9febe0c831c..9d97cbec1a387 100644 --- a/pytorch_lightning/metrics/classification/accuracy.py +++ b/pytorch_lightning/metrics/classification/accuracy.py @@ -13,14 +13,94 @@ # limitations under the License. from typing import Any, Callable, Optional -from torchmetrics import Accuracy as _Accuracy +import torch -from pytorch_lightning.metrics.utils import deprecated_metrics +from pytorch_lightning.metrics.functional.accuracy import _accuracy_compute, _accuracy_update +from pytorch_lightning.metrics.metric import Metric -class Accuracy(_Accuracy): +class Accuracy(Metric): + r""" + Computes `Accuracy `__: + + .. math:: + \text{Accuracy} = \frac{1}{N}\sum_i^N 1(y_i = \hat{y}_i) + + Where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a + tensor of predictions. + + For multi-class and multi-dimensional multi-class data with probability predictions, the + parameter ``top_k`` generalizes this metric to a Top-K accuracy metric: for each sample the + top-K highest probability items are considered to find the correct label. + + For multi-label and multi-dimensional multi-class inputs, this metric computes the "global" + accuracy by default, which counts all labels or sub-samples separately. This can be + changed to subset accuracy (which requires all labels or sub-samples in the sample to + be correctly predicted) by setting ``subset_accuracy=True``. + + Accepts all input types listed in :ref:`extensions/metrics:input types`. + + Args: + threshold: + Threshold probability value for transforming probability predictions to binary + (0,1) predictions, in the case of binary or multi-label inputs. + top_k: + Number of highest probability predictions considered to find the correct label, relevant + only for (multi-dimensional) multi-class inputs with probability predictions. The + default value (``None``) will be interpreted as 1 for these inputs. + + Should be left at default (``None``) for all other types of inputs. + subset_accuracy: + Whether to compute subset accuracy for multi-label and multi-dimensional + multi-class inputs (has no effect for other input types). + + - For multi-label inputs, if the parameter is set to ``True``, then all labels for + each sample must be correctly predicted for the sample to count as correct. If it + is set to ``False``, then all labels are counted separately - this is equivalent to + flattening inputs beforehand (i.e. ``preds = preds.flatten()`` and same for ``target``). + + - For multi-dimensional multi-class inputs, if the parameter is set to ``True``, then all + sub-sample (on the extra axis) must be correct for the sample to be counted as correct. + If it is set to ``False``, then all sub-samples are counter separately - this is equivalent, + in the case of label predictions, to flattening the inputs beforehand (i.e. + ``preds = preds.flatten()`` and same for ``target``). Note that the ``top_k`` parameter + still applies in both cases, if set. 
+ + compute_on_step: + Forward only calls ``update()`` and return ``None`` if this is set to ``False``. + dist_sync_on_step: + Synchronize metric state across processes at each ``forward()`` + before returning the value at the step + process_group: + Specify the process group on which synchronization is called. + default: ``None`` (which selects the entire world) + dist_sync_fn: + Callback that performs the allgather operation on the metric state. When ``None``, DDP + will be used to perform the allgather + + Raises: + ValueError: + If ``threshold`` is not between ``0`` and ``1``. + ValueError: + If ``top_k`` is not an ``integer`` larger than ``0``. + + Example: + + >>> from pytorch_lightning.metrics import Accuracy + >>> target = torch.tensor([0, 1, 2, 3]) + >>> preds = torch.tensor([0, 2, 1, 3]) + >>> accuracy = Accuracy() + >>> accuracy(preds, target) + tensor(0.5000) + + >>> target = torch.tensor([0, 1, 2]) + >>> preds = torch.tensor([[0.1, 0.9, 0], [0.3, 0.1, 0.6], [0.2, 0.5, 0.3]]) + >>> accuracy = Accuracy(top_k=2) + >>> accuracy(preds, target) + tensor(0.6667) + + """ - @deprecated_metrics(target=_Accuracy) def __init__( self, threshold: float = 0.5, @@ -31,9 +111,45 @@ def __init__( process_group: Optional[Any] = None, dist_sync_fn: Callable = None, ): + super().__init__( + compute_on_step=compute_on_step, + dist_sync_on_step=dist_sync_on_step, + process_group=process_group, + dist_sync_fn=dist_sync_fn, + ) + + self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum") + self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") + + if not 0 < threshold < 1: + raise ValueError(f"The `threshold` should be a float in the (0,1) interval, got {threshold}") + + if top_k is not None and (not isinstance(top_k, int) or top_k <= 0): + raise ValueError(f"The `top_k` should be an integer larger than 0, got {top_k}") + + self.threshold = threshold + self.top_k = top_k + self.subset_accuracy = subset_accuracy + + def update(self, preds: torch.Tensor, target: torch.Tensor): + """ + Update state with predictions and targets. See :ref:`extensions/metrics:input types` for more information + on input types. + + Args: + preds: Predictions from model (probabilities, or labels) + target: Ground truth labels """ - This implementation refers to :class:`~torchmetrics.Accuracy`. - .. deprecated:: - Use :class:`~torchmetrics.Accuracy`. Will be removed in v1.5.0. + correct, total = _accuracy_update( + preds, target, threshold=self.threshold, top_k=self.top_k, subset_accuracy=self.subset_accuracy + ) + + self.correct += correct + self.total += total + + def compute(self) -> torch.Tensor: + """ + Computes accuracy based on inputs passed in to ``update`` previously. """ + return _accuracy_compute(self.correct, self.total) diff --git a/pytorch_lightning/metrics/classification/auc.py b/pytorch_lightning/metrics/classification/auc.py index 05bc7b27d7e68..6c5a29173d20a 100644 --- a/pytorch_lightning/metrics/classification/auc.py +++ b/pytorch_lightning/metrics/classification/auc.py @@ -13,14 +13,36 @@ # limitations under the License. 
from typing import Any, Callable, Optional -from torchmetrics import AUC as _AUC +import torch -from pytorch_lightning.metrics.utils import deprecated_metrics +from pytorch_lightning.metrics.functional.auc import _auc_compute, _auc_update +from pytorch_lightning.metrics.metric import Metric +from pytorch_lightning.utilities import rank_zero_warn -class AUC(_AUC): +class AUC(Metric): + r""" + Computes Area Under the Curve (AUC) using the trapezoidal rule + + Forward accepts two input tensors that should be 1D and have the same number + of elements + + Args: + reorder: AUC expects its first input to be sorted. If this is not the case, + setting this argument to ``True`` will use a stable sorting algorithm to + sort the input in decending order + compute_on_step: + Forward only calls ``update()`` and return None if this is set to False. + dist_sync_on_step: + Synchronize metric state across processes at each ``forward()`` + before returning the value at the step. + process_group: + Specify the process group on which synchronization is called. default: None (which selects the entire world) + dist_sync_fn: + Callback that performs the allgather operation on the metric state. When ``None``, DDP + will be used to perform the allgather + """ - @deprecated_metrics(target=_AUC) def __init__( self, reorder: bool = False, @@ -29,9 +51,40 @@ def __init__( process_group: Optional[Any] = None, dist_sync_fn: Callable = None, ): + super().__init__( + compute_on_step=compute_on_step, + dist_sync_on_step=dist_sync_on_step, + process_group=process_group, + dist_sync_fn=dist_sync_fn, + ) + + self.reorder = reorder + + self.add_state("x", default=[], dist_reduce_fx=None) + self.add_state("y", default=[], dist_reduce_fx=None) + + rank_zero_warn( + 'Metric `AUC` will save all targets and predictions in buffer.' + ' For large datasets this may lead to large memory footprint.' + ) + + def update(self, x: torch.Tensor, y: torch.Tensor): + """ + Update state with predictions and targets. + + Args: + x: Predictions from model (probabilities, or labels) + y: Ground truth labels """ - This implementation refers to :class:`~torchmetrics.AUC`. + x, y = _auc_update(x, y) - .. deprecated:: - Use :class:`~torchmetrics.AUC`. Will be removed in v1.5.0. + self.x.append(x) + self.y.append(y) + + def compute(self) -> torch.Tensor: + """ + Computes AUC based on inputs passed in to ``update`` previously. """ + x = torch.cat(self.x, dim=0) + y = torch.cat(self.y, dim=0) + return _auc_compute(x, y, reorder=self.reorder) diff --git a/pytorch_lightning/metrics/classification/auroc.py b/pytorch_lightning/metrics/classification/auroc.py index e10b094fd5a2e..6b9b5ae9f021f 100644 --- a/pytorch_lightning/metrics/classification/auroc.py +++ b/pytorch_lightning/metrics/classification/auroc.py @@ -11,16 +11,95 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from distutils.version import LooseVersion from typing import Any, Callable, Optional -from torchmetrics import AUROC as _AUROC +import torch -from pytorch_lightning.metrics.utils import deprecated_metrics +from pytorch_lightning.metrics.functional.auroc import _auroc_compute, _auroc_update +from pytorch_lightning.metrics.metric import Metric +from pytorch_lightning.utilities import rank_zero_warn -class AUROC(_AUROC): +class AUROC(Metric): + r"""Compute `Area Under the Receiver Operating Characteristic Curve (ROC AUC) + `_. 
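The ``AUC`` metric above buffers the ``x``/``y`` pairs and applies the trapezoidal rule at ``compute`` time. The rule itself is easy to verify on curves with known area; a minimal sketch using plain ``torch`` (not the metric class)::

    import torch

    # A constant curve y = 1 over x in [0, 1] has area exactly 1.0.
    x = torch.tensor([0.0, 0.5, 1.0])
    y = torch.tensor([1.0, 1.0, 1.0])
    assert torch.trapz(y, x).item() == 1.0

    # A triangle rising from (0, 0) to (1, 1) has area 0.5.
    y = torch.tensor([0.0, 0.5, 1.0])
    assert torch.trapz(y, x).item() == 0.5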
+ Works for both binary, multilabel and multiclass problems. In the case of + multiclass, the values will be calculated based on a one-vs-the-rest approach. + + Forward accepts + + - ``preds`` (float tensor): ``(N, ...)`` (binary) or ``(N, C, ...)`` (multiclass) tensor + with probabilities, where C is the number of classes. + + - ``target`` (long tensor): ``(N, ...)`` or ``(N, C, ...)`` with integer labels + + For non-binary input, if the ``preds`` and ``target`` tensor have the same + size the input will be interpretated as multilabel and if ``preds`` have one + dimension more than the ``target`` tensor the input will be interpretated as + multiclass. + + Args: + num_classes: integer with number of classes. Not nessesary to provide + for binary problems. + pos_label: integer determining the positive class. Default is ``None`` + which for binary problem is translate to 1. For multiclass problems + this argument should not be set as we iteratively change it in the + range [0,num_classes-1] + average: + - ``'macro'`` computes metric for each class and uniformly averages them + - ``'weighted'`` computes metric for each class and does a weighted-average, + where each class is weighted by their support (accounts for class imbalance) + - ``None`` computes and returns the metric per class + max_fpr: + If not ``None``, calculates standardized partial AUC over the + range [0, max_fpr]. Should be a float between 0 and 1. + compute_on_step: + Forward only calls ``update()`` and return None if this is set to False. default: True + dist_sync_on_step: + Synchronize metric state across processes at each ``forward()`` + before returning the value at the step. + process_group: + Specify the process group on which synchronization is called. default: None (which selects the entire world) + dist_sync_fn: + Callback that performs the allgather operation on the metric state. When ``None``, DDP + will be used to perform the allgather + + Raises: + ValueError: + If ``average`` is none of ``None``, ``"macro"`` or ``"weighted"``. + ValueError: + If ``max_fpr`` is not a ``float`` in the range ``(0, 1]``. + RuntimeError: + If ``PyTorch version`` is ``below 1.6`` since max_fpr requires ``torch.bucketize`` + which is not available below 1.6. + ValueError: + If the mode of data (binary, multi-label, multi-class) changes between batches. + + Example (binary case): + + >>> from pytorch_lightning.metrics import AUROC + >>> preds = torch.tensor([0.13, 0.26, 0.08, 0.19, 0.34]) + >>> target = torch.tensor([0, 0, 1, 1, 1]) + >>> auroc = AUROC(pos_label=1) + >>> auroc(preds, target) + tensor(0.5000) + + Example (multiclass case): + + >>> from pytorch_lightning.metrics import AUROC + >>> preds = torch.tensor([[0.90, 0.05, 0.05], + ... [0.05, 0.90, 0.05], + ... [0.05, 0.05, 0.90], + ... [0.85, 0.05, 0.10], + ... 
[0.10, 0.10, 0.80]]) + >>> target = torch.tensor([0, 1, 1, 2, 2]) + >>> auroc = AUROC(num_classes=3) + >>> auroc(preds, target) + tensor(0.7778) + + """ - @deprecated_metrics(target=_AUROC) def __init__( self, num_classes: Optional[int] = None, @@ -32,9 +111,74 @@ def __init__( process_group: Optional[Any] = None, dist_sync_fn: Callable = None, ): + super().__init__( + compute_on_step=compute_on_step, + dist_sync_on_step=dist_sync_on_step, + process_group=process_group, + dist_sync_fn=dist_sync_fn, + ) + + self.num_classes = num_classes + self.pos_label = pos_label + self.average = average + self.max_fpr = max_fpr + + allowed_average = (None, 'macro', 'weighted') + if self.average not in allowed_average: + raise ValueError( + f'Argument `average` expected to be one of the following: {allowed_average} but got {average}' + ) + + if self.max_fpr is not None: + if (not isinstance(max_fpr, float) and 0 < max_fpr <= 1): + raise ValueError(f"`max_fpr` should be a float in range (0, 1], got: {max_fpr}") + + if LooseVersion(torch.__version__) < LooseVersion('1.6.0'): + raise RuntimeError( + '`max_fpr` argument requires `torch.bucketize` which is not available below PyTorch version 1.6' + ) + + self.mode = None + self.add_state("preds", default=[], dist_reduce_fx=None) + self.add_state("target", default=[], dist_reduce_fx=None) + + rank_zero_warn( + 'Metric `AUROC` will save all targets and predictions in buffer.' + ' For large datasets this may lead to large memory footprint.' + ) + + def update(self, preds: torch.Tensor, target: torch.Tensor): """ - This implementation refers to :class:`~torchmetrics.AUROC`. + Update state with predictions and targets. - .. deprecated:: - Use :class:`~torchmetrics.AUROC`. Will be removed in v1.5.0. + Args: + preds: Predictions from model (probabilities, or labels) + target: Ground truth labels + """ + preds, target, mode = _auroc_update(preds, target) + + self.preds.append(preds) + self.target.append(target) + + if self.mode is not None and self.mode != mode: + raise ValueError( + 'The mode of data (binary, multi-label, multi-class) should be constant, but changed' + f' between batches from {self.mode} to {mode}' + ) + self.mode = mode + + def compute(self) -> torch.Tensor: + """ + Computes AUROC based on inputs passed in to ``update`` previously. """ + preds = torch.cat(self.preds, dim=0) + target = torch.cat(self.target, dim=0) + return _auroc_compute( + preds, + target, + self.mode, + self.num_classes, + self.pos_label, + self.average, + self.max_fpr, + ) diff --git a/pytorch_lightning/metrics/classification/average_precision.py b/pytorch_lightning/metrics/classification/average_precision.py index 6c8cdbd52891d..f9c7bde158383 100644 --- a/pytorch_lightning/metrics/classification/average_precision.py +++ b/pytorch_lightning/metrics/classification/average_precision.py @@ -11,16 +11,64 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Optional +from typing import Any, List, Optional, Union -from torchmetrics import AveragePrecision as _AveragePrecision +import torch -from pytorch_lightning.metrics.utils import deprecated_metrics +from pytorch_lightning.metrics.functional.average_precision import _average_precision_compute, _average_precision_update +from pytorch_lightning.metrics.metric import Metric +from pytorch_lightning.utilities import rank_zero_warn -class AveragePrecision(_AveragePrecision): +class AveragePrecision(Metric): + """ + Computes the average precision score, which summarises the precision recall + curve into one number. Works for both binary and multiclass problems. + In the case of multiclass, the values will be calculated based on a one-vs-the-rest approach. + + Forward accepts + + - ``preds`` (float tensor): ``(N, ...)`` (binary) or ``(N, C, ...)`` (multiclass) tensor + with probabilities, where C is the number of classes. + + - ``target`` (long tensor): ``(N, ...)`` with integer labels + + Args: + num_classes: integer with number of classes. Not nessesary to provide + for binary problems. + pos_label: integer determining the positive class. Default is ``None`` + which for binary problem is translate to 1. For multiclass problems + this argument should not be set as we iteratively change it in the + range [0,num_classes-1] + compute_on_step: + Forward only calls ``update()`` and return None if this is set to False. default: True + dist_sync_on_step: + Synchronize metric state across processes at each ``forward()`` + before returning the value at the step. default: False + process_group: + Specify the process group on which synchronization is called. default: None (which selects the entire world) + + Example (binary case): + + >>> pred = torch.tensor([0, 1, 2, 3]) + >>> target = torch.tensor([0, 1, 1, 1]) + >>> average_precision = AveragePrecision(pos_label=1) + >>> average_precision(pred, target) + tensor(1.) + + Example (multiclass case): + + >>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05], + ... [0.05, 0.75, 0.05, 0.05, 0.05], + ... [0.05, 0.05, 0.75, 0.05, 0.05], + ... [0.05, 0.05, 0.05, 0.75, 0.05]]) + >>> target = torch.tensor([0, 1, 3, 2]) + >>> average_precision = AveragePrecision(num_classes=5) + >>> average_precision(pred, target) + [tensor(1.), tensor(1.), tensor(0.2500), tensor(0.2500), tensor(nan)] + + """ - @deprecated_metrics(target=_AveragePrecision) def __init__( self, num_classes: Optional[int] = None, @@ -29,9 +77,48 @@ def __init__( dist_sync_on_step: bool = False, process_group: Optional[Any] = None, ): + super().__init__( + compute_on_step=compute_on_step, + dist_sync_on_step=dist_sync_on_step, + process_group=process_group, + ) + + self.num_classes = num_classes + self.pos_label = pos_label + + self.add_state("preds", default=[], dist_reduce_fx=None) + self.add_state("target", default=[], dist_reduce_fx=None) + + rank_zero_warn( + 'Metric `AveragePrecision` will save all targets and predictions in buffer.' + ' For large datasets this may lead to large memory footprint.' + ) + + def update(self, preds: torch.Tensor, target: torch.Tensor): """ - This implementation refers to :class:`~torchmetrics.AveragePrecision`. + Update state with predictions and targets. 
+ + Args: + preds: Predictions from model + target: Ground truth values + """ + preds, target, num_classes, pos_label = _average_precision_update( + preds, target, self.num_classes, self.pos_label + ) + self.preds.append(preds) + self.target.append(target) + self.num_classes = num_classes + self.pos_label = pos_label + + def compute(self) -> Union[torch.Tensor, List[torch.Tensor]]: + """ + Compute the average precision score + + Returns: + tensor with average precision. If multiclass will return list + of such tensors, one for each class - .. deprecated:: - Use :class:`~torchmetrics.AveragePrecision`. Will be removed in v1.5.0. """ + preds = torch.cat(self.preds, dim=0) + target = torch.cat(self.target, dim=0) + return _average_precision_compute(preds, target, self.num_classes, self.pos_label) diff --git a/pytorch_lightning/metrics/classification/confusion_matrix.py b/pytorch_lightning/metrics/classification/confusion_matrix.py index 2995f668380de..c3defc82bc92d 100644 --- a/pytorch_lightning/metrics/classification/confusion_matrix.py +++ b/pytorch_lightning/metrics/classification/confusion_matrix.py @@ -13,14 +13,64 @@ # limitations under the License. from typing import Any, Optional -from torchmetrics import ConfusionMatrix as _ConfusionMatrix +import torch -from pytorch_lightning.metrics.utils import deprecated_metrics +from pytorch_lightning.metrics.functional.confusion_matrix import _confusion_matrix_compute, _confusion_matrix_update +from pytorch_lightning.metrics.metric import Metric -class ConfusionMatrix(_ConfusionMatrix): +class ConfusionMatrix(Metric): + """ + Computes the `confusion matrix + `_. Works with binary, + multiclass, and multilabel data. Accepts probabilities from a model output or + integer class values in prediction. Works with multi-dimensional preds and + target. + + Note: + This metric produces a multi-dimensional output, so it can not be directly logged. + + Forward accepts + + - ``preds`` (float or long tensor): ``(N, ...)`` or ``(N, C, ...)`` where C is the number of classes + - ``target`` (long tensor): ``(N, ...)`` + + If preds and target are the same shape and preds is a float tensor, we use the ``self.threshold`` argument + to convert into integer labels. This is the case for binary and multi-label probabilities. + + If preds has an extra dimension as in the case of multi-class scores we perform an argmax on ``dim=1``. + + Args: + num_classes: Number of classes in the dataset. + normalize: Normalization mode for confusion matrix. Choose from + + - ``None`` or ``'none'``: no normalization (default) + - ``'true'``: normalization over the targets (most commonly used) + - ``'pred'``: normalization over the predictions + - ``'all'``: normalization over the whole matrix + + threshold: + Threshold value for binary or multi-label probabilites. default: 0.5 + compute_on_step: + Forward only calls ``update()`` and return None if this is set to False. default: True + dist_sync_on_step: + Synchronize metric state across processes at each ``forward()`` + before returning the value at the step. default: False + process_group: + Specify the process group on which synchronization is called. 
default: None (which selects the entire world) + + Example: + + >>> from pytorch_lightning.metrics import ConfusionMatrix + >>> target = torch.tensor([1, 1, 0, 0]) + >>> preds = torch.tensor([0, 1, 0, 0]) + >>> confmat = ConfusionMatrix(num_classes=2) + >>> confmat(preds, target) + tensor([[2., 0.], + [1., 1.]]) + + """ - @deprecated_metrics(target=_ConfusionMatrix) def __init__( self, num_classes: int, @@ -30,9 +80,35 @@ def __init__( dist_sync_on_step: bool = False, process_group: Optional[Any] = None, ): + + super().__init__( + compute_on_step=compute_on_step, + dist_sync_on_step=dist_sync_on_step, + process_group=process_group, + ) + self.num_classes = num_classes + self.normalize = normalize + self.threshold = threshold + + allowed_normalize = ('true', 'pred', 'all', 'none', None) + assert self.normalize in allowed_normalize, \ + f"Argument average needs to one of the following: {allowed_normalize}" + + self.add_state("confmat", default=torch.zeros(num_classes, num_classes), dist_reduce_fx="sum") + + def update(self, preds: torch.Tensor, target: torch.Tensor): + """ + Update state with predictions and targets. + + Args: + preds: Predictions from model + target: Ground truth values """ - This implementation refers to :class:`~torchmetrics.ConfusionMatrix`. + confmat = _confusion_matrix_update(preds, target, self.num_classes, self.threshold) + self.confmat += confmat - .. deprecated:: - Use :class:`~torchmetrics.ConfusionMatrix`. Will be removed in v1.5.0. + def compute(self) -> torch.Tensor: + """ + Computes confusion matrix """ + return _confusion_matrix_compute(self.confmat, self.normalize) diff --git a/pytorch_lightning/metrics/classification/f_beta.py b/pytorch_lightning/metrics/classification/f_beta.py index a3f4172f05400..ae01b80966868 100644 --- a/pytorch_lightning/metrics/classification/f_beta.py +++ b/pytorch_lightning/metrics/classification/f_beta.py @@ -13,15 +13,72 @@ # limitations under the License. from typing import Any, Optional -from torchmetrics import F1 as _F1 -from torchmetrics import FBeta as _FBeta +import torch -from pytorch_lightning.metrics.utils import deprecated_metrics +from pytorch_lightning.metrics.functional.f_beta import _fbeta_compute, _fbeta_update +from pytorch_lightning.metrics.metric import Metric +from pytorch_lightning.utilities import rank_zero_warn -class FBeta(_FBeta): +class FBeta(Metric): + r""" + Computes `F-score `_, specifically: + + .. math:: + F_\beta = (1 + \beta^2) * \frac{\text{precision} * \text{recall}} + {(\beta^2 * \text{precision}) + \text{recall}} + + Where :math:`\beta` is some positive real factor. Works with binary, multiclass, and multilabel data. + Accepts probabilities from a model output or integer class values in prediction. + Works with multi-dimensional preds and target. + + Forward accepts + + - ``preds`` (float or long tensor): ``(N, ...)`` or ``(N, C, ...)`` where C is the number of classes + - ``target`` (long tensor): ``(N, ...)`` + + If preds and target are the same shape and preds is a float tensor, we use the ``self.threshold`` argument + to convert into integer labels. This is the case for binary and multi-label probabilities. + + If preds has an extra dimension as in the case of multi-class scores we perform an argmax on ``dim=1``. + + Args: + num_classes: Number of classes in the dataset. + beta: Beta coefficient in the F measure. + threshold: + Threshold value for binary or multi-label probabilities. 
default: 0.5 + + average: + - ``'micro'`` computes metric globally + - ``'macro'`` computes metric for each class and uniformly averages them + - ``'weighted'`` computes metric for each class and does a weighted-average, + where each class is weighted by their support (accounts for class imbalance) + - ``'none'`` or ``None`` computes and returns the metric per class + + multilabel: If predictions are from multilabel classification. + compute_on_step: + Forward only calls ``update()`` and return None if this is set to False. default: True + dist_sync_on_step: + Synchronize metric state across processes at each ``forward()`` + before returning the value at the step. default: False + process_group: + Specify the process group on which synchronization is called. default: None (which selects the entire world) + + Raises: + ValueError: + If ``average`` is none of ``"micro"``, ``"macro"``, ``"weighted"``, ``"none"``, ``None``. + + Example: + + >>> from pytorch_lightning.metrics import FBeta + >>> target = torch.tensor([0, 1, 2, 0, 1, 2]) + >>> preds = torch.tensor([0, 2, 1, 0, 0, 1]) + >>> f_beta = FBeta(num_classes=3, beta=0.5) + >>> f_beta(preds, target) + tensor(0.3333) + + """ - @deprecated_metrics(target=_FBeta) def __init__( self, num_classes: int, @@ -33,17 +90,103 @@ def __init__( dist_sync_on_step: bool = False, process_group: Optional[Any] = None, ): + super().__init__( + compute_on_step=compute_on_step, + dist_sync_on_step=dist_sync_on_step, + process_group=process_group, + ) + + self.num_classes = num_classes + self.beta = beta + self.threshold = threshold + self.average = average + self.multilabel = multilabel + + allowed_average = ("micro", "macro", "weighted", "none", None) + if self.average not in allowed_average: + raise ValueError( + 'Argument `average` expected to be one of the following:' + f' {allowed_average} but got {self.average}' + ) + + self.add_state("true_positives", default=torch.zeros(num_classes), dist_reduce_fx="sum") + self.add_state("predicted_positives", default=torch.zeros(num_classes), dist_reduce_fx="sum") + self.add_state("actual_positives", default=torch.zeros(num_classes), dist_reduce_fx="sum") + + def update(self, preds: torch.Tensor, target: torch.Tensor): + """ + Update state with predictions and targets. + + Args: + preds: Predictions from model + target: Ground truth values """ - This implementation refers to :class:`~torchmetrics.FBeta`. + true_positives, predicted_positives, actual_positives = _fbeta_update( + preds, target, self.num_classes, self.threshold, self.multilabel + ) + + self.true_positives += true_positives + self.predicted_positives += predicted_positives + self.actual_positives += actual_positives - .. deprecated:: - Use :class:`~torchmetrics.FBeta`. Will be removed in v1.5.0. + def compute(self) -> torch.Tensor: """ + Computes fbeta over state. + """ + return _fbeta_compute( + self.true_positives, self.predicted_positives, self.actual_positives, self.beta, self.average + ) + + +class F1(FBeta): + """ + Computes F1 metric. F1 metrics correspond to a harmonic mean of the + precision and recall scores. + + Works with binary, multiclass, and multilabel data. + Accepts logits from a model output or integer class values in prediction. + Works with multi-dimensional preds and target. 
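The F-beta definition above can be sanity-checked with a small worked example: with precision 0.5, recall 1.0 and beta = 0.5, the score sits closer to the precision than the plain F1 score does. A minimal sketch in plain Python (not the metric class)::

    def f_beta(precision: float, recall: float, beta: float) -> float:
        # F_beta = (1 + beta^2) * P * R / (beta^2 * P + R)
        return (1 + beta**2) * precision * recall / (beta**2 * precision + recall)

    print(round(f_beta(0.5, 1.0, beta=0.5), 4))  # 0.5556 -> weighted toward precision
    print(round(f_beta(0.5, 1.0, beta=1.0), 4))  # 0.6667 -> the plain F1 score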
+ Forward accepts -class F1(_F1): + - ``preds`` (float or long tensor): ``(N, ...)`` or ``(N, C, ...)`` where C is the number of classes + - ``target`` (long tensor): ``(N, ...)`` + + If preds and target are the same shape and preds is a float tensor, we use the ``self.threshold`` argument. + This is the case for binary and multi-label logits. + + If preds has an extra dimension as in the case of multi-class scores we perform an argmax on ``dim=1``. + + Args: + num_classes: Number of classes in the dataset. + threshold: + Threshold value for binary or multi-label logits. default: 0.5 + + average: + - ``'micro'`` computes metric globally + - ``'macro'`` computes metric for each class and uniformly averages them + - ``'weighted'`` computes metric for each class and does a weighted-average, + where each class is weighted by their support (accounts for class imbalance) + - ``'none'`` or ``None`` computes and returns the metric per class + + multilabel: If predictions are from multilabel classification. + compute_on_step: + Forward only calls ``update()`` and returns None if this is set to False. default: True + dist_sync_on_step: + Synchronize metric state across processes at each ``forward()`` + before returning the value at the step. default: False + process_group: + Specify the process group on which synchronization is called. default: None (which selects the entire world) + + Example: + >>> from pytorch_lightning.metrics import F1 + >>> target = torch.tensor([0, 1, 2, 0, 1, 2]) + >>> preds = torch.tensor([0, 2, 1, 0, 0, 1]) + >>> f1 = F1(num_classes=3) + >>> f1(preds, target) + tensor(0.3333) + """ - @deprecated_metrics(target=_F1) def __init__( self, num_classes: int, @@ -54,9 +197,16 @@ def __init__( dist_sync_on_step: bool = False, process_group: Optional[Any] = None, ): - """ - This implementation refers to :class:`~torchmetrics.F1`. + if multilabel is not False: + rank_zero_warn(f'The `multilabel={multilabel}` parameter is unused and will not have any effect.') - .. deprecated:: - Use :class:`~torchmetrics.F1`. Will be removed in v1.5.0. - """ + super().__init__( + num_classes=num_classes, + beta=1.0, + threshold=threshold, + average=average, + multilabel=multilabel, + compute_on_step=compute_on_step, + dist_sync_on_step=dist_sync_on_step, + process_group=process_group, + ) diff --git a/pytorch_lightning/metrics/classification/hamming_distance.py b/pytorch_lightning/metrics/classification/hamming_distance.py index d66b0c2d9cfa8..62b4ae824a6d1 100644 --- a/pytorch_lightning/metrics/classification/hamming_distance.py +++ b/pytorch_lightning/metrics/classification/hamming_distance.py @@ -13,14 +13,61 @@ # limitations under the License. from typing import Any, Callable, Optional -from torchmetrics import HammingDistance as _HammingDistance +import torch -from pytorch_lightning.metrics.utils import deprecated_metrics +from pytorch_lightning.metrics.functional.hamming_distance import _hamming_distance_compute, _hamming_distance_update +from pytorch_lightning.metrics.metric import Metric -class HammingDistance(_HammingDistance): +class HammingDistance(Metric): + r""" + Computes the average `Hamming distance `_ (also + known as Hamming loss) between targets and predictions: + + .. math:: + \text{Hamming distance} = \frac{1}{N \cdot L}\sum_i^N \sum_l^L 1(y_{il} \neq \hat{y_{il}}) + + Where :math:`y` is a tensor of target values, :math:`\hat{y}` is a tensor of predictions, + and :math:`\bullet_{il}` refers to the :math:`l`-th label of the :math:`i`-th sample of that + tensor. 
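+
+    As a small worked example (using the same values as the doctest below): with
+    ``target = [[0, 1], [1, 1]]`` and ``preds = [[0, 1], [0, 1]]``, exactly one of the
+    :math:`N \cdot L = 4` labels differs, so the Hamming distance is :math:`1 / 4 = 0.25`.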
+ + This is the same as ``1-accuracy`` for binary data, while for all other types of inputs it + treats each possible label separately - meaning that, for example, multi-class data is + treated as if it were multi-label. + + Accepts all input types listed in :ref:`extensions/metrics:input types`. + + Args: + threshold: + Threshold probability value for transforming probability predictions to binary + (0 or 1) predictions, in the case of binary or multi-label inputs. + compute_on_step: + Forward only calls ``update()`` and return ``None`` if this is set to ``False``. + dist_sync_on_step: + Synchronize metric state across processes at each ``forward()`` + before returning the value at the step. + process_group: + Specify the process group on which synchronization is called. + default: ``None`` (which selects the entire world) + dist_sync_fn: + Callback that performs the allgather operation on the metric state. When ``None``, DDP + will be used to perform the all gather. + + Raises: + ValueError: + If ``threshold`` is not between ``0`` and ``1``. + + Example: + + >>> from pytorch_lightning.metrics import HammingDistance + >>> target = torch.tensor([[0, 1], [1, 1]]) + >>> preds = torch.tensor([[0, 1], [0, 1]]) + >>> hamming_distance = HammingDistance() + >>> hamming_distance(preds, target) + tensor(0.2500) + + """ - @deprecated_metrics(target=_HammingDistance) def __init__( self, threshold: float = 0.5, @@ -29,9 +76,36 @@ def __init__( process_group: Optional[Any] = None, dist_sync_fn: Callable = None, ): + super().__init__( + compute_on_step=compute_on_step, + dist_sync_on_step=dist_sync_on_step, + process_group=process_group, + dist_sync_fn=dist_sync_fn, + ) + + self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum") + self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") + + if not 0 < threshold < 1: + raise ValueError("The `threshold` should lie in the (0,1) interval.") + self.threshold = threshold + + def update(self, preds: torch.Tensor, target: torch.Tensor): """ - This implementation refers to :class:`~torchmetrics.HammingDistance`. + Update state with predictions and targets. See :ref:`extensions/metrics:input types` for more information + on input types. - .. deprecated:: - Use :class:`~torchmetrics.HammingDistance`. Will be removed in v1.5.0. + Args: + preds: Predictions from model (probabilities, or labels) + target: Ground truth labels + """ + correct, total = _hamming_distance_update(preds, target, self.threshold) + + self.correct += correct + self.total += total + + def compute(self) -> torch.Tensor: + """ + Computes hamming distance based on inputs passed in to ``update`` previously. """ + return _hamming_distance_compute(self.correct, self.total) diff --git a/pytorch_lightning/metrics/classification/helpers.py b/pytorch_lightning/metrics/classification/helpers.py new file mode 100644 index 0000000000000..ea6d5722b3041 --- /dev/null +++ b/pytorch_lightning/metrics/classification/helpers.py @@ -0,0 +1,539 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Optional, Tuple
+
+import numpy as np
+import torch
+
+from pytorch_lightning.metrics.utils import select_topk, to_onehot
+from pytorch_lightning.utilities import LightningEnum
+
+
+class DataType(LightningEnum):
+    """
+    Enum to represent data type
+    """
+
+    BINARY = "binary"
+    MULTILABEL = "multi-label"
+    MULTICLASS = "multi-class"
+    MULTIDIM_MULTICLASS = "multi-dim multi-class"
+
+
+class AverageMethod(LightningEnum):
+    """
+    Enum to represent average method
+    """
+
+    MICRO = "micro"
+    MACRO = "macro"
+    WEIGHTED = "weighted"
+    NONE = "none"
+    SAMPLES = "samples"
+
+
+class MDMCAverageMethod(LightningEnum):
+    """
+    Enum to represent multi-dim multi-class average method
+    """
+
+    GLOBAL = "global"
+    SAMPLEWISE = "samplewise"
+
+
+def _basic_input_validation(preds: torch.Tensor, target: torch.Tensor, threshold: float, is_multiclass: bool):
+    """
+    Perform basic validation of inputs that does not require deducing any information
+    of the type of inputs.
+    """
+
+    if target.is_floating_point():
+        raise ValueError("The `target` has to be an integer tensor.")
+    if target.min() < 0:
+        raise ValueError("The `target` has to be a non-negative tensor.")
+
+    preds_float = preds.is_floating_point()
+    if not preds_float and preds.min() < 0:
+        raise ValueError("If `preds` are integers, they have to be non-negative.")
+
+    if not preds.shape[0] == target.shape[0]:
+        raise ValueError("The `preds` and `target` should have the same first dimension.")
+
+    if preds_float and (preds.min() < 0 or preds.max() > 1):
+        raise ValueError("The `preds` should be probabilities, but values were detected outside of [0,1] range.")
+
+    if not 0 < threshold < 1:
+        raise ValueError(f"The `threshold` should be a float in the (0,1) interval, got {threshold}")
+
+    if is_multiclass is False and target.max() > 1:
+        raise ValueError("If you set `is_multiclass=False`, then `target` should not exceed 1.")
+
+    if is_multiclass is False and not preds_float and preds.max() > 1:
+        raise ValueError("If you set `is_multiclass=False` and `preds` are integers, then `preds` should not exceed 1.")
+
+
+def _check_shape_and_type_consistency(preds: torch.Tensor, target: torch.Tensor) -> Tuple[str, int]:
+    """
+    This checks that the shape and type of inputs are consistent with
+    each other and fall into one of the allowed input types (see the
+    docstring of ``_input_format_classification``). It does
+    not check for consistency of the number of classes; other functions take
+    care of that.
+
+    It returns the name of the case in which the inputs fall, and the implied
+    number of classes (from the ``C`` dim for multi-class data, or extra dim(s) for
+    multi-label data).
+    """
+
+    preds_float = preds.is_floating_point()
+
+    if preds.ndim == target.ndim:
+        if preds.shape != target.shape:
+            raise ValueError(
+                "The `preds` and `target` should have the same shape,"
+                f" got `preds` with shape={preds.shape} and `target` with shape={target.shape}."
+            )
+        if preds_float and target.max() > 1:
+            raise ValueError(
+                "If `preds` and `target` are of shape (N, ...) and `preds` are floats, `target` should be binary."
+ ) + + # Get the case + if preds.ndim == 1 and preds_float: + case = DataType.BINARY + elif preds.ndim == 1 and not preds_float: + case = DataType.MULTICLASS + elif preds.ndim > 1 and preds_float: + case = DataType.MULTILABEL + else: + case = DataType.MULTIDIM_MULTICLASS + + implied_classes = preds[0].numel() + + elif preds.ndim == target.ndim + 1: + if not preds_float: + raise ValueError("If `preds` have one dimension more than `target`, `preds` should be a float tensor.") + if preds.shape[2:] != target.shape[1:]: + raise ValueError( + "If `preds` have one dimension more than `target`, the shape of `preds` should be" + " (N, C, ...), and the shape of `target` should be (N, ...)." + ) + + implied_classes = preds.shape[1] + + if preds.ndim == 2: + case = DataType.MULTICLASS + else: + case = DataType.MULTIDIM_MULTICLASS + else: + raise ValueError( + "Either `preds` and `target` both should have the (same) shape (N, ...), or `target` should be (N, ...)" + " and `preds` should be (N, C, ...)." + ) + + return case, implied_classes + + +def _check_num_classes_binary(num_classes: int, is_multiclass: bool): + """ + This checks that the consistency of `num_classes` with the data + and `is_multiclass` param for binary data. + """ + + if num_classes > 2: + raise ValueError("Your data is binary, but `num_classes` is larger than 2.") + if num_classes == 2 and not is_multiclass: + raise ValueError( + "Your data is binary and `num_classes=2`, but `is_multiclass` is not True." + " Set it to True if you want to transform binary data to multi-class format." + ) + if num_classes == 1 and is_multiclass: + raise ValueError( + "You have binary data and have set `is_multiclass=True`, but `num_classes` is 1." + " Either set `is_multiclass=None`(default) or set `num_classes=2`" + " to transform binary data to multi-class format." + ) + + +def _check_num_classes_mc( + preds: torch.Tensor, target: torch.Tensor, num_classes: int, is_multiclass: bool, implied_classes: int +): + """ + This checks that the consistency of `num_classes` with the data + and `is_multiclass` param for (multi-dimensional) multi-class data. + """ + + if num_classes == 1 and is_multiclass is not False: + raise ValueError( + "You have set `num_classes=1`, but predictions are integers." + " If you want to convert (multi-dimensional) multi-class data with 2 classes" + " to binary/multi-label, set `is_multiclass=False`." + ) + if num_classes > 1: + if is_multiclass is False: + if implied_classes != num_classes: + raise ValueError( + "You have set `is_multiclass=False`, but the implied number of classes " + " (from shape of inputs) does not match `num_classes`. If you are trying to" + " transform multi-dim multi-class data with 2 classes to multi-label, `num_classes`" + " should be either None or the product of the size of extra dimensions (...)." + " See Input Types in Metrics documentation." + ) + if num_classes <= target.max(): + raise ValueError("The highest label in `target` should be smaller than `num_classes`.") + if num_classes <= preds.max(): + raise ValueError("The highest label in `preds` should be smaller than `num_classes`.") + if preds.shape != target.shape and num_classes != implied_classes: + raise ValueError("The size of C dimension of `preds` does not match `num_classes`.") + + +def _check_num_classes_ml(num_classes: int, is_multiclass: bool, implied_classes: int): + """ + This checks that the consistency of `num_classes` with the data + and `is_multiclass` param for multi-label data. 
+    """
+
+    if is_multiclass and num_classes != 2:
+        raise ValueError(
+            "You have set `is_multiclass=True`, but `num_classes` is not equal to 2."
+            " If you are trying to transform multi-label data to 2 class multi-dimensional"
+            " multi-class, you should set `num_classes` to either 2 or None."
+        )
+    if not is_multiclass and num_classes != implied_classes:
+        raise ValueError("The implied number of classes (from shape of inputs) does not match num_classes.")
+
+
+def _check_top_k(top_k: int, case: str, implied_classes: int, is_multiclass: Optional[bool], preds_float: bool):
+    if case == DataType.BINARY:
+        raise ValueError("You can not use the `top_k` parameter with binary data.")
+    if not isinstance(top_k, int) or top_k <= 0:
+        raise ValueError("The `top_k` has to be an integer larger than 0.")
+    if not preds_float:
+        raise ValueError("You have set `top_k`, but you do not have probability predictions.")
+    if is_multiclass is False:
+        raise ValueError("If you set `is_multiclass=False`, you can not set `top_k`.")
+    if case == DataType.MULTILABEL and is_multiclass:
+        raise ValueError(
+            "If you want to transform multi-label data to 2 class multi-dimensional"
+            " multi-class data using `is_multiclass=True`, you can not use `top_k`."
+        )
+    if top_k >= implied_classes:
+        raise ValueError("The `top_k` has to be strictly smaller than the `C` dimension of `preds`.")
+
+
+def _check_classification_inputs(
+    preds: torch.Tensor,
+    target: torch.Tensor,
+    threshold: float,
+    num_classes: Optional[int],
+    is_multiclass: bool,
+    top_k: Optional[int],
+) -> str:
+    """Performs error checking on inputs for classification.
+
+    This ensures that preds and target take one of the shape/type combinations that are
+    specified in the ``_input_format_classification`` docstring. It also checks the cases of
+    overrides with ``is_multiclass`` by checking (for multi-class and multi-dim multi-class
+    cases) that there are only up to 2 distinct labels.
+
+    In the case where preds are floats (probabilities), it is checked whether they are in the [0,1] interval.
+
+    When ``num_classes`` is given, it is checked that it is consistent with the input cases (binary,
+    multi-label, ...), and that, if available, the implied number of classes in the ``C``
+    dimension is consistent with it (as well as that the max label in target is smaller than it).
+
+    When ``num_classes`` is not specified in these cases, consistency of the highest target
+    value against the ``C`` dimension is checked for (multi-dimensional) multi-class cases.
+
+    If ``top_k`` is set (not None) for inputs that do not have probability predictions (and
+    are not binary), an error is raised. Similarly, if ``top_k`` is set to a number that
+    is higher than or equal to the ``C`` dimension of ``preds``, an error is raised.
+
+    Preds and target tensors are expected to be squeezed already - all dimensions should be
+    greater than 1, except perhaps the first one (``N``).
+
+    Args:
+        preds: Tensor with predictions (labels or probabilities)
+        target: Tensor with ground truth labels, always integers (labels)
+        threshold:
+            Threshold probability value for transforming probability predictions to binary
+            (0,1) predictions, in the case of binary or multi-label inputs.
+        num_classes:
+            Number of classes. If not explicitly set, the number of classes will be inferred
+            either from the shape of inputs, or the maximum label in the ``target`` and ``preds``
+            tensor, where applicable.
+        top_k:
+            Number of highest probability entries for each sample to convert to 1s - relevant
+            only for inputs with probability predictions. The default value (``None``) will be
+            interpreted as 1 for these inputs. If this parameter is set for multi-label inputs,
+            it will take precedence over threshold.
+
+            Should be left unset (``None``) for inputs with label predictions.
+        is_multiclass:
+            Used only in certain special cases, where you want to treat inputs as a different type
+            than what they appear to be. See the parameter's
+            :ref:`documentation section `
+            for a more detailed explanation and examples.
+
+
+    Return:
+        case: The case the inputs fall in, one of 'binary', 'multi-class', 'multi-label' or
+            'multi-dim multi-class'
+    """
+
+    # Basic validation (that does not need case/type information)
+    _basic_input_validation(preds, target, threshold, is_multiclass)
+
+    # Check that shape/types fall into one of the cases
+    case, implied_classes = _check_shape_and_type_consistency(preds, target)
+
+    # For (multi-dim) multi-class case with prob preds, check that preds sum up to 1
+    if case in (DataType.MULTICLASS, DataType.MULTIDIM_MULTICLASS) and preds.is_floating_point():
+        if not torch.isclose(preds.sum(dim=1), torch.ones_like(preds.sum(dim=1))).all():
+            raise ValueError("Probabilities in `preds` must sum up to 1 across the `C` dimension.")
+
+    # Check consistency with the `C` dimension in case of multi-class data
+    if preds.shape != target.shape:
+        if is_multiclass is False and implied_classes != 2:
+            raise ValueError(
+                "You have set `is_multiclass=False`, but have more than 2 classes in your data,"
+                " based on the C dimension of `preds`."
+            )
+        if target.max() >= implied_classes:
+            raise ValueError(
+                "The highest label in `target` should be smaller than the size of the `C` dimension of `preds`."
+            )
+
+    # Check that num_classes is consistent
+    if num_classes:
+        if case == DataType.BINARY:
+            _check_num_classes_binary(num_classes, is_multiclass)
+        elif case in (DataType.MULTICLASS, DataType.MULTIDIM_MULTICLASS):
+            _check_num_classes_mc(preds, target, num_classes, is_multiclass, implied_classes)
+        elif case == DataType.MULTILABEL:
+            _check_num_classes_ml(num_classes, is_multiclass, implied_classes)
+
+    # Check that top_k is consistent
+    if top_k is not None:
+        _check_top_k(top_k, case, implied_classes, is_multiclass, preds.is_floating_point())
+
+    return case
+
+
+def _input_format_classification(
+    preds: torch.Tensor,
+    target: torch.Tensor,
+    threshold: float = 0.5,
+    top_k: Optional[int] = None,
+    num_classes: Optional[int] = None,
+    is_multiclass: Optional[bool] = None,
+) -> Tuple[torch.Tensor, torch.Tensor, str]:
+    """Convert preds and target tensors into common format.
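+
+    As a rough sketch of the simplest (binary) case: with the default ``threshold=0.5``,
+    ``preds = torch.tensor([0.2, 0.7])`` and ``target = torch.tensor([0, 1])`` come out as the
+    ``(N, 1)`` integer tensors ``[[0], [1]]`` and ``[[0], [1]]``, together with the ``'binary'`` case string.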
+
+    Preds and targets are supposed to fall into one of these categories (and are
+    validated to make sure this is the case):
+
+    * Both preds and target are of shape ``(N,)``, and both are integers (multi-class)
+    * Both preds and target are of shape ``(N,)``, and target is binary, while preds
+      are a float (binary)
+    * preds are of shape ``(N, C)`` and are floats, and target is of shape ``(N,)`` and
+      is integer (multi-class)
+    * preds and target are of shape ``(N, ...)``, target is binary and preds is a float
+      (multi-label)
+    * preds are of shape ``(N, C, ...)`` and are floats, target is of shape ``(N, ...)``
+      and is integer (multi-dimensional multi-class)
+    * preds and target are of shape ``(N, ...)`` both are integers (multi-dimensional
+      multi-class)
+
+    To avoid ambiguities, all dimensions of size 1, except the first one, are squeezed out.
+
+    The returned output tensors will be binary tensors of the same shape, either ``(N, C)``
+    or ``(N, C, X)``; the details for each case are described below. The function also returns
+    a ``case`` string, which describes which of the above cases the inputs belonged to - regardless
+    of whether this was "overridden" by other settings (like ``is_multiclass``).
+
+    In the binary case, targets are normally returned as a ``(N,1)`` tensor, while preds are transformed
+    into a binary tensor (elements become 1 if the probability is greater than or equal to
+    ``threshold`` or 0 otherwise). If ``is_multiclass=True``, then both targets and preds
+    become ``(N, 2)`` tensors by a one-hot transformation; with the thresholding being applied to
+    preds first.
+
+    In the multi-class case, normally both preds and targets become ``(N, C)`` binary tensors; targets
+    by a one-hot transformation and preds by selecting ``top_k`` largest entries (if their original
+    shape was ``(N,C)``). However, if ``is_multiclass=False``, then targets and preds will be
+    returned as a ``(N,1)`` tensor.
+
+    In the multi-label case, normally targets and preds are returned as ``(N, C)`` binary tensors, with
+    preds being binarized as in the binary case. Here the ``C`` dimension is obtained by flattening
+    all dimensions after the first one. However, if ``is_multiclass=True``, then both are returned as
+    ``(N, 2, C)``, by an equivalent transformation as in the binary case.
+
+    In the multi-dimensional multi-class case, normally both target and preds are returned as
+    ``(N, C, X)`` tensors, with ``X`` resulting from flattening of all dimensions except ``N`` and
+    ``C``. The transformations performed here are equivalent to the multi-class case. However, if
+    ``is_multiclass=False`` (and there are up to two classes), then the data is returned as
+    ``(N, X)`` binary tensors (multi-label).
+
+    Note that where a one-hot transformation needs to be performed and the number of classes
+    is not implicitly given by a ``C`` dimension, the new ``C`` dimension will either be
+    equal to ``num_classes``, if it is given, or the maximum label value in preds and
+    target.
+
+    Args:
+        preds: Tensor with predictions (labels or probabilities)
+        target: Tensor with ground truth labels, always integers (labels)
+        threshold:
+            Threshold probability value for transforming probability predictions to binary
+            (0 or 1) predictions, in the case of binary or multi-label inputs.
+        num_classes:
+            Number of classes. If not explicitly set, the number of classes will be inferred
+            either from the shape of inputs, or the maximum label in the ``target`` and ``preds``
+            tensor, where applicable.
+ top_k: + Number of highest probability entries for each sample to convert to 1s - relevant + only for (multi-dimensional) multi-class inputs with probability predictions. The + default value (``None``) will be interepreted as 1 for these inputs. + + Should be left unset (``None``) for all other types of inputs. + is_multiclass: + Used only in certain special cases, where you want to treat inputs as a different type + than what they appear to be. See the parameter's + :ref:`documentation section ` + for a more detailed explanation and examples. + + + Returns: + preds: binary tensor of shape ``(N, C)`` or ``(N, C, X)`` + target: binary tensor of shape ``(N, C)`` or ``(N, C, X)`` + case: The case the inputs fall in, one of ``'binary'``, ``'multi-class'``, ``'multi-label'`` or + ``'multi-dim multi-class'`` + """ + # Remove excess dimensions + if preds.shape[0] == 1: + preds, target = preds.squeeze().unsqueeze(0), target.squeeze().unsqueeze(0) + else: + preds, target = preds.squeeze(), target.squeeze() + + # Convert half precision tensors to full precision, as not all ops are supported + # for example, min() is not supported + if preds.dtype == torch.float16: + preds = preds.float() + + case = _check_classification_inputs( + preds, + target, + threshold=threshold, + num_classes=num_classes, + is_multiclass=is_multiclass, + top_k=top_k, + ) + + if case in (DataType.BINARY, DataType.MULTILABEL) and not top_k: + preds = (preds >= threshold).int() + num_classes = num_classes if not is_multiclass else 2 + + if case == DataType.MULTILABEL and top_k: + preds = select_topk(preds, top_k) + + if case in (DataType.MULTICLASS, DataType.MULTIDIM_MULTICLASS) or is_multiclass: + if preds.is_floating_point(): + num_classes = preds.shape[1] + preds = select_topk(preds, top_k or 1) + else: + num_classes = num_classes if num_classes else max(preds.max(), target.max()) + 1 + preds = to_onehot(preds, max(2, num_classes)) + + target = to_onehot(target, max(2, num_classes)) + + if is_multiclass is False: + preds, target = preds[:, 1, ...], target[:, 1, ...] + + if (case in (DataType.MULTICLASS, DataType.MULTIDIM_MULTICLASS) and is_multiclass is not False) or is_multiclass: + target = target.reshape(target.shape[0], target.shape[1], -1) + preds = preds.reshape(preds.shape[0], preds.shape[1], -1) + else: + target = target.reshape(target.shape[0], -1) + preds = preds.reshape(preds.shape[0], -1) + + # Some operatins above create an extra dimension for MC/binary case - this removes it + if preds.ndim > 2: + preds, target = preds.squeeze(-1), target.squeeze(-1) + + return preds.int(), target.int(), case + + +def _reduce_stat_scores( + numerator: torch.Tensor, + denominator: torch.Tensor, + weights: Optional[torch.Tensor], + average: str, + mdmc_average: Optional[str], + zero_division: int = 0, +) -> torch.Tensor: + """ + Reduces scores of type ``numerator/denominator`` or + ``weights * (numerator/denominator)``, if ``average='weighted'``. + + Args: + numerator: A tensor with numerator numbers. + denominator: A tensor with denominator numbers. If a denominator is + negative, the class will be ignored (if averaging), or its score + will be returned as ``nan`` (if ``average=None``). + If the denominator is zero, then ``zero_division`` score will be + used for those elements. + weights: + A tensor of weights to be used if ``average='weighted'``. + average: + The method to average the scores. Should be one of ``'micro'``, ``'macro'``, + ``'weighted'``, ``'none'``, ``None`` or ``'samples'``. 
The behavior + corresponds to `sklearn averaging methods `__. + mdmc_average: + The method to average the scores if inputs were multi-dimensional multi-class (MDMC). + Should be either ``'global'`` or ``'samplewise'``. If inputs were not + multi-dimensional multi-class, it should be ``None`` (default). + zero_division: + The value to use for the score if denominator equals zero. + """ + numerator, denominator = numerator.float(), denominator.float() + zero_div_mask = denominator == 0 + ignore_mask = denominator < 0 + + if weights is None: + weights = torch.ones_like(denominator) + else: + weights = weights.float() + + numerator = torch.where(zero_div_mask, torch.tensor(float(zero_division), device=numerator.device), numerator) + denominator = torch.where(zero_div_mask | ignore_mask, torch.tensor(1.0, device=denominator.device), denominator) + weights = torch.where(ignore_mask, torch.tensor(0.0, device=weights.device), weights) + + if average not in (AverageMethod.MICRO, AverageMethod.NONE, None): + weights = weights / weights.sum(dim=-1, keepdim=True) + + scores = weights * (numerator / denominator) + + # This is in case where sum(weights) = 0, which happens if we ignore the only present class with average='weighted' + scores = torch.where(torch.isnan(scores), torch.tensor(float(zero_division), device=scores.device), scores) + + if mdmc_average == MDMCAverageMethod.SAMPLEWISE: + scores = scores.mean(dim=0) + ignore_mask = ignore_mask.sum(dim=0).bool() + + if average in (AverageMethod.NONE, None): + scores = torch.where(ignore_mask, torch.tensor(np.nan, device=scores.device), scores) + else: + scores = scores.sum() + + return scores diff --git a/pytorch_lightning/metrics/classification/iou.py b/pytorch_lightning/metrics/classification/iou.py index f1d9d0945511a..a261b767a8190 100644 --- a/pytorch_lightning/metrics/classification/iou.py +++ b/pytorch_lightning/metrics/classification/iou.py @@ -13,14 +13,70 @@ # limitations under the License. from typing import Any, Optional -from torchmetrics import IoU as _IoU +import torch -from pytorch_lightning.metrics.utils import deprecated_metrics +from pytorch_lightning.metrics.classification.confusion_matrix import ConfusionMatrix +from pytorch_lightning.metrics.functional.iou import _iou_from_confmat -class IoU(_IoU): +class IoU(ConfusionMatrix): + r""" + Computes `Intersection over union, or Jaccard index calculation `_: + + .. math:: J(A,B) = \frac{|A\cap B|}{|A\cup B|} + + Where: :math:`A` and :math:`B` are both tensors of the same size, containing integer class values. + They may be subject to conversion from input data (see description below). Note that it is different from box IoU. + + Works with binary, multiclass and multi-label data. + Accepts probabilities from a model output or integer class values in prediction. + Works with multi-dimensional preds and target. + + Forward accepts + + - ``preds`` (float or long tensor): ``(N, ...)`` or ``(N, C, ...)`` where C is the number of classes + - ``target`` (long tensor): ``(N, ...)`` + + If preds and target are the same shape and preds is a float tensor, we use the ``self.threshold`` argument + to convert into integer labels. This is the case for binary and multi-label probabilities. + + If preds has an extra dimension as in the case of multi-class scores we perform an argmax on ``dim=1``. + + Args: + num_classes: Number of classes in the dataset. + ignore_index: optional int specifying a target class to ignore. 
If given, this class index does not contribute + to the returned score, regardless of reduction method. Has no effect if given an int that is not in the + range [0, num_classes-1]. By default, no index is ignored, and all classes are used. + absent_score: score to use for an individual class, if no instances of the class index were present in + `pred` AND no instances of the class index were present in `target`. For example, if we have 3 classes, + [0, 0] for `pred`, and [0, 2] for `target`, then class 1 would be assigned the `absent_score`. + threshold: + Threshold value for binary or multi-label probabilities. + reduction: a method to reduce metric score over labels. + + - ``'elementwise_mean'``: takes the mean (default) + - ``'sum'``: takes the sum + - ``'none'``: no reduction will be applied + + compute_on_step: + Forward only calls ``update()`` and return None if this is set to False. + dist_sync_on_step: + Synchronize metric state across processes at each ``forward()`` + before returning the value at the step. + process_group: + Specify the process group on which synchronization is called. default: None (which selects the entire world) + + Example: + >>> from pytorch_lightning.metrics import IoU + >>> target = torch.randint(0, 2, (10, 25, 25)) + >>> pred = torch.tensor(target) + >>> pred[2:5, 7:13, 9:15] = 1 - pred[2:5, 7:13, 9:15] + >>> iou = IoU(num_classes=2) + >>> iou(pred, target) + tensor(0.9660) + + """ - @deprecated_metrics(target=_IoU) def __init__( self, num_classes: int, @@ -32,9 +88,20 @@ def __init__( dist_sync_on_step: bool = False, process_group: Optional[Any] = None, ): - """ - This implementation refers to :class:`~torchmetrics.IoU`. + super().__init__( + num_classes=num_classes, + normalize=None, + threshold=threshold, + compute_on_step=compute_on_step, + dist_sync_on_step=dist_sync_on_step, + process_group=process_group, + ) + self.reduction = reduction + self.ignore_index = ignore_index + self.absent_score = absent_score - .. deprecated:: - Use :class:`~torchmetrics.IoU`. Will be removed in v1.5.0. + def compute(self) -> torch.Tensor: + """ + Computes intersection over union (IoU) """ + return _iou_from_confmat(self.confmat, self.num_classes, self.ignore_index, self.absent_score, self.reduction) diff --git a/pytorch_lightning/metrics/classification/precision_recall.py b/pytorch_lightning/metrics/classification/precision_recall.py index 7b95d21dae97c..11862769e62a8 100644 --- a/pytorch_lightning/metrics/classification/precision_recall.py +++ b/pytorch_lightning/metrics/classification/precision_recall.py @@ -13,15 +13,116 @@ # limitations under the License. from typing import Any, Callable, Optional -from torchmetrics import Precision as _Precision -from torchmetrics import Recall as _Recall +import torch -from pytorch_lightning.metrics.utils import deprecated_metrics +from pytorch_lightning.metrics.classification.stat_scores import StatScores +from pytorch_lightning.metrics.functional.precision_recall import _precision_compute, _recall_compute -class Precision(_Precision): +class Precision(StatScores): + r""" + Computes `Precision `_: + + .. math:: \text{Precision} = \frac{\text{TP}}{\text{TP} + \text{FP}} + + Where :math:`\text{TP}` and :math:`\text{FP}` represent the number of true positives and + false positives respecitively. With the use of ``top_k`` parameter, this metric can + generalize to Precision@K. 
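+
+    As a quick sanity check against the example below: with ``preds = [2, 0, 2, 1]`` and
+    ``target = [1, 1, 2, 0]``, only one of the four predictions is correct, so micro-averaged
+    precision is :math:`1/4 = 0.25`, while the per-class precisions are :math:`0`, :math:`0`
+    and :math:`1/2`, giving a macro average of roughly :math:`0.1667`.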
+ + The reduction method (how the precision scores are aggregated) is controlled by the + ``average`` parameter, and additionally by the ``mdmc_average`` parameter in the + multi-dimensional multi-class case. Accepts all inputs listed in :ref:`extensions/metrics:input types`. + + Args: + num_classes: + Number of classes. Necessary for ``'macro'``, ``'weighted'`` and ``None`` average methods. + threshold: + Threshold probability value for transforming probability predictions to binary + (0,1) predictions, in the case of binary or multi-label inputs. + average: + Defines the reduction that is applied. Should be one of the following: + + - ``'micro'`` [default]: Calculate the metric globally, accross all samples and classes. + - ``'macro'``: Calculate the metric for each class separately, and average the + metrics accross classes (with equal weights for each class). + - ``'weighted'``: Calculate the metric for each class separately, and average the + metrics accross classes, weighting each class by its support (``tp + fn``). + - ``'none'`` or ``None``: Calculate the metric for each class separately, and return + the metric for every class. + - ``'samples'``: Calculate the metric for each sample, and average the metrics + across samples (with equal weights for each sample). + + Note that what is considered a sample in the multi-dimensional multi-class case + depends on the value of ``mdmc_average``. + multilabel: + .. warning :: This parameter is deprecated and has no effect. Will be removed in v1.4.0. + + mdmc_average: + Defines how averaging is done for multi-dimensional multi-class inputs (on top of the + ``average`` parameter). Should be one of the following: + + - ``None`` [default]: Should be left unchanged if your data is not multi-dimensional + multi-class. + + - ``'samplewise'``: In this case, the statistics are computed separately for each + sample on the ``N`` axis, and then averaged over samples. + The computation for each sample is done by treating the flattened extra axes ``...`` + (see :ref:`extensions/metrics:input types`) as the ``N`` dimension within the sample, + and computing the metric for the sample based on that. + + - ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs + (see :ref:`extensions/metrics:input types`) + are flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they + were ``(N_X, C)``. From here on the ``average`` parameter applies as usual. + + ignore_index: + Integer specifying a target class to ignore. If given, this class index does not contribute + to the returned score, regardless of reduction method. If an index is ignored, and ``average=None`` + or ``'none'``, the score for the ignored class will be returned as ``nan``. + + top_k: + Number of highest probability entries for each sample to convert to 1s - relevant + only for inputs with probability predictions. If this parameter is set for multi-label + inputs, it will take precedence over ``threshold``. For (multi-dim) multi-class inputs, + this parameter defaults to 1. + + Should be left unset (``None``) for inputs with label predictions. + is_multiclass: + Used only in certain special cases, where you want to treat inputs as a different type + than what they appear to be. See the parameter's + :ref:`documentation section ` + for a more detailed explanation and examples. + + compute_on_step: + Forward only calls ``update()`` and return ``None`` if this is set to ``False``. 
+ dist_sync_on_step: + Synchronize metric state across processes at each ``forward()`` + before returning the value at the step + process_group: + Specify the process group on which synchronization is called. + default: ``None`` (which selects the entire world) + dist_sync_fn: + Callback that performs the allgather operation on the metric state. When ``None``, DDP + will be used to perform the allgather. + + Raises: + ValueError: + If ``average`` is none of ``"micro"``, ``"macro"``, ``"weighted"``, ``"samples"``, ``"none"``, ``None``. + + Example: + + >>> from pytorch_lightning.metrics import Precision + >>> preds = torch.tensor([2, 0, 2, 1]) + >>> target = torch.tensor([1, 1, 2, 0]) + >>> precision = Precision(average='macro', num_classes=3) + >>> precision(preds, target) + tensor(0.1667) + >>> precision = Precision(average='micro') + >>> precision(preds, target) + tensor(0.2500) + + """ - @deprecated_metrics(target=_Precision) def __init__( self, num_classes: Optional[int] = None, @@ -37,17 +138,146 @@ def __init__( process_group: Optional[Any] = None, dist_sync_fn: Callable = None, ): + allowed_average = ["micro", "macro", "weighted", "samples", "none", None] + if average not in allowed_average: + raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") + + super().__init__( + reduce="macro" if average in ["weighted", "none", None] else average, + mdmc_reduce=mdmc_average, + threshold=threshold, + top_k=top_k, + num_classes=num_classes, + is_multiclass=is_multiclass, + ignore_index=ignore_index, + compute_on_step=compute_on_step, + dist_sync_on_step=dist_sync_on_step, + process_group=process_group, + dist_sync_fn=dist_sync_fn, + ) + + self.average = average + + def compute(self) -> torch.Tensor: """ - This implementation refers to :class:`~torchmetrics.Precision`. + Computes the precision score based on inputs passed in to ``update`` previously. - .. deprecated:: - Use :class:`~torchmetrics.Precision`. Will be removed in v1.5.0. + Return: + The shape of the returned tensor depends on the ``average`` parameter + + - If ``average in ['micro', 'macro', 'weighted', 'samples']``, a one-element tensor will be returned + - If ``average in ['none', None]``, the shape will be ``(C,)``, where ``C`` stands for the number + of classes """ + tp, fp, tn, fn = self._get_final_stats() + return _precision_compute(tp, fp, tn, fn, self.average, self.mdmc_reduce) + +class Recall(StatScores): + r""" + Computes `Recall `_: -class Recall(_Recall): + .. math:: \text{Recall} = \frac{\text{TP}}{\text{TP} + \text{FN}} + + Where :math:`\text{TP}` and :math:`\text{FN}` represent the number of true positives and + false negatives respecitively. With the use of ``top_k`` parameter, this metric can + generalize to Recall@K. + + The reduction method (how the recall scores are aggregated) is controlled by the + ``average`` parameter, and additionally by the ``mdmc_average`` parameter in the + multi-dimensional multi-class case. Accepts all inputs listed in :ref:`extensions/metrics:input types`. + + Args: + num_classes: + Number of classes. Necessary for ``'macro'``, ``'weighted'`` and ``None`` average methods. + threshold: + Threshold probability value for transforming probability predictions to binary + (0,1) predictions, in the case of binary or multi-label inputs. + average: + Defines the reduction that is applied. Should be one of the following: + + - ``'micro'`` [default]: Calculate the metric globally, accross all samples and classes. 
+ - ``'macro'``: Calculate the metric for each class separately, and average the + metrics accross classes (with equal weights for each class). + - ``'weighted'``: Calculate the metric for each class separately, and average the + metrics accross classes, weighting each class by its support (``tp + fn``). + - ``'none'`` or ``None``: Calculate the metric for each class separately, and return + the metric for every class. + - ``'samples'``: Calculate the metric for each sample, and average the metrics + across samples (with equal weights for each sample). + + Note that what is considered a sample in the multi-dimensional multi-class case + depends on the value of ``mdmc_average``. + multilabel: + .. warning :: This parameter is deprecated and has no effect. Will be removed in v1.4.0. + + mdmc_average: + Defines how averaging is done for multi-dimensional multi-class inputs (on top of the + ``average`` parameter). Should be one of the following: + + - ``None`` [default]: Should be left unchanged if your data is not multi-dimensional + multi-class. + + - ``'samplewise'``: In this case, the statistics are computed separately for each + sample on the ``N`` axis, and then averaged over samples. + The computation for each sample is done by treating the flattened extra axes ``...`` + (see :ref:`extensions/metrics:input types`) as the ``N`` dimension within the sample, + and computing the metric for the sample based on that. + + - ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs + (see :ref:`extensions/metrics:input types`) + are flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they + were ``(N_X, C)``. From here on the ``average`` parameter applies as usual. + + ignore_index: + Integer specifying a target class to ignore. If given, this class index does not contribute + to the returned score, regardless of reduction method. If an index is ignored, and ``average=None`` + or ``'none'``, the score for the ignored class will be returned as ``nan``. + + top_k: + Number of highest probability entries for each sample to convert to 1s - relevant + only for inputs with probability predictions. If this parameter is set for multi-label + inputs, it will take precedence over ``threshold``. For (multi-dim) multi-class inputs, + this parameter defaults to 1. + + Should be left unset (``None``) for inputs with label predictions. + + is_multiclass: + Used only in certain special cases, where you want to treat inputs as a different type + than what they appear to be. See the parameter's + :ref:`documentation section ` + for a more detailed explanation and examples. + + compute_on_step: + Forward only calls ``update()`` and return ``None`` if this is set to ``False``. + dist_sync_on_step: + Synchronize metric state across processes at each ``forward()`` + before returning the value at the step + process_group: + Specify the process group on which synchronization is called. + default: ``None`` (which selects the entire world) + dist_sync_fn: + Callback that performs the allgather operation on the metric state. When ``None``, DDP + will be used to perform the allgather. + + Raises: + ValueError: + If ``average`` is none of ``"micro"``, ``"macro"``, ``"weighted"``, ``"samples"``, ``"none"``, ``None``. 
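+
+    As a quick sanity check against the example below: with ``preds = [2, 0, 2, 1]`` and
+    ``target = [1, 1, 2, 0]``, the per-class recalls are :math:`0`, :math:`0` and :math:`1`,
+    so macro-averaged recall is roughly :math:`1/3 \approx 0.3333`, while micro-averaged
+    recall is :math:`1/4 = 0.25`.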
+ + Example: + + >>> from pytorch_lightning.metrics import Recall + >>> preds = torch.tensor([2, 0, 2, 1]) + >>> target = torch.tensor([1, 1, 2, 0]) + >>> recall = Recall(average='macro', num_classes=3) + >>> recall(preds, target) + tensor(0.3333) + >>> recall = Recall(average='micro') + >>> recall(preds, target) + tensor(0.2500) + + """ - @deprecated_metrics(target=_Recall) def __init__( self, num_classes: Optional[int] = None, @@ -63,9 +293,36 @@ def __init__( process_group: Optional[Any] = None, dist_sync_fn: Callable = None, ): + allowed_average = ["micro", "macro", "weighted", "samples", "none", None] + if average not in allowed_average: + raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") + + super().__init__( + reduce="macro" if average in ["weighted", "none", None] else average, + mdmc_reduce=mdmc_average, + threshold=threshold, + top_k=top_k, + num_classes=num_classes, + is_multiclass=is_multiclass, + ignore_index=ignore_index, + compute_on_step=compute_on_step, + dist_sync_on_step=dist_sync_on_step, + process_group=process_group, + dist_sync_fn=dist_sync_fn, + ) + + self.average = average + + def compute(self) -> torch.Tensor: """ - This implementation refers to :class:`~torchmetrics.Recall`. + Computes the recall score based on inputs passed in to ``update`` previously. + + Return: + The shape of the returned tensor depends on the ``average`` parameter - .. deprecated:: - Use :class:`~torchmetrics.Recall`. Will be removed in v1.5.0. + - If ``average in ['micro', 'macro', 'weighted', 'samples']``, a one-element tensor will be returned + - If ``average in ['none', None]``, the shape will be ``(C,)``, where ``C`` stands for the number + of classes """ + tp, fp, tn, fn = self._get_final_stats() + return _recall_compute(tp, fp, tn, fn, self.average, self.mdmc_reduce) diff --git a/pytorch_lightning/metrics/classification/precision_recall_curve.py b/pytorch_lightning/metrics/classification/precision_recall_curve.py index 285cb2fb78ccc..5a02a99ed17fd 100644 --- a/pytorch_lightning/metrics/classification/precision_recall_curve.py +++ b/pytorch_lightning/metrics/classification/precision_recall_curve.py @@ -11,16 +11,80 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional +from typing import Any, List, Optional, Tuple, Union -from torchmetrics import PrecisionRecallCurve as _PrecisionRecallCurve +import torch -from pytorch_lightning.metrics.utils import deprecated_metrics +from pytorch_lightning.metrics.functional.precision_recall_curve import ( + _precision_recall_curve_compute, + _precision_recall_curve_update, +) +from pytorch_lightning.metrics.metric import Metric +from pytorch_lightning.utilities import rank_zero_warn -class PrecisionRecallCurve(_PrecisionRecallCurve): +class PrecisionRecallCurve(Metric): + """ + Computes precision-recall pairs for different thresholds. Works for both + binary and multiclass problems. In the case of multiclass, the values will + be calculated based on a one-vs-the-rest approach. + + Forward accepts + + - ``preds`` (float tensor): ``(N, ...)`` (binary) or ``(N, C, ...)`` (multiclass) tensor + with probabilities, where C is the number of classes. + + - ``target`` (long tensor): ``(N, ...)`` or ``(N, C, ...)`` with integer labels + + Args: + num_classes: integer with number of classes. Not nessesary to provide + for binary problems. 
+ pos_label: integer determining the positive class. Default is ``None`` + which for binary problem is translate to 1. For multiclass problems + this argument should not be set as we iteratively change it in the + range [0,num_classes-1] + compute_on_step: + Forward only calls ``update()`` and return None if this is set to False. default: True + dist_sync_on_step: + Synchronize metric state across processes at each ``forward()`` + before returning the value at the step. default: False + process_group: + Specify the process group on which synchronization is called. default: None (which selects the entire world) + + Example (binary case): + + >>> from pytorch_lightning.metrics import PrecisionRecallCurve + >>> pred = torch.tensor([0, 1, 2, 3]) + >>> target = torch.tensor([0, 1, 1, 0]) + >>> pr_curve = PrecisionRecallCurve(pos_label=1) + >>> precision, recall, thresholds = pr_curve(pred, target) + >>> precision + tensor([0.6667, 0.5000, 0.0000, 1.0000]) + >>> recall + tensor([1.0000, 0.5000, 0.0000, 0.0000]) + >>> thresholds + tensor([1, 2, 3]) + + Example (multiclass case): + + >>> from pytorch_lightning.metrics import PrecisionRecallCurve + >>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05], + ... [0.05, 0.75, 0.05, 0.05, 0.05], + ... [0.05, 0.05, 0.75, 0.05, 0.05], + ... [0.05, 0.05, 0.05, 0.75, 0.05]]) + >>> target = torch.tensor([0, 1, 3, 2]) + >>> pr_curve = PrecisionRecallCurve(num_classes=5) + >>> precision, recall, thresholds = pr_curve(pred, target) + >>> precision # doctest: +NORMALIZE_WHITESPACE + [tensor([1., 1.]), tensor([1., 1.]), tensor([0.2500, 0.0000, 1.0000]), + tensor([0.2500, 0.0000, 1.0000]), tensor([0., 1.])] + >>> recall + [tensor([1., 0.]), tensor([1., 0.]), tensor([1., 0., 0.]), tensor([1., 0., 0.]), tensor([nan, 0.])] + >>> thresholds + [tensor([0.7500]), tensor([0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500])] + + """ - @deprecated_metrics(target=_PrecisionRecallCurve) def __init__( self, num_classes: Optional[int] = None, @@ -29,9 +93,60 @@ def __init__( dist_sync_on_step: bool = False, process_group: Optional[Any] = None, ): + super().__init__( + compute_on_step=compute_on_step, + dist_sync_on_step=dist_sync_on_step, + process_group=process_group, + ) + + self.num_classes = num_classes + self.pos_label = pos_label + + self.add_state("preds", default=[], dist_reduce_fx=None) + self.add_state("target", default=[], dist_reduce_fx=None) + + rank_zero_warn( + 'Metric `PrecisionRecallCurve` will save all targets and predictions in buffer.' + ' For large datasets this may lead to large memory footprint.' + ) + + def update(self, preds: torch.Tensor, target: torch.Tensor): """ - This implementation refers to :class:`~torchmetrics.PrecisionRecallCurve`. + Update state with predictions and targets. + + Args: + preds: Predictions from model + target: Ground truth values + """ + preds, target, num_classes, pos_label = _precision_recall_curve_update( + preds, target, self.num_classes, self.pos_label + ) + self.preds.append(preds) + self.target.append(target) + self.num_classes = num_classes + self.pos_label = pos_label + + def compute( + self + ) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], Tuple[List[torch.Tensor], List[torch.Tensor], + List[torch.Tensor]]]: + """ + Compute the precision-recall curve + + Returns: + 3-element tuple containing - .. deprecated:: - Use :class:`~torchmetrics.PrecisionRecallCurve`. Will be removed in v1.5.0. 
+ precision: + tensor where element i is the precision of predictions with + score >= thresholds[i] and the last element is 1. + If multiclass, this is a list of such tensors, one for each class. + recall: + tensor where element i is the recall of predictions with + score >= thresholds[i] and the last element is 0. + If multiclass, this is a list of such tensors, one for each class. + thresholds: + Thresholds used for computing precision/recall scores """ + preds = torch.cat(self.preds, dim=0) + target = torch.cat(self.target, dim=0) + return _precision_recall_curve_compute(preds, target, self.num_classes, self.pos_label) diff --git a/pytorch_lightning/metrics/classification/roc.py b/pytorch_lightning/metrics/classification/roc.py index 3f6cf50803c86..598646cde3861 100644 --- a/pytorch_lightning/metrics/classification/roc.py +++ b/pytorch_lightning/metrics/classification/roc.py @@ -11,16 +11,79 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional +from typing import Any, List, Optional, Tuple, Union -from torchmetrics import ROC as _ROC +import torch -from pytorch_lightning.metrics.utils import deprecated_metrics +from pytorch_lightning.metrics.functional.roc import _roc_compute, _roc_update +from pytorch_lightning.metrics.metric import Metric +from pytorch_lightning.utilities import rank_zero_warn -class ROC(_ROC): +class ROC(Metric): + """ + Computes the Receiver Operating Characteristic (ROC). Works for both + binary and multiclass problems. In the case of multiclass, the values will + be calculated based on a one-vs-the-rest approach. + + Forward accepts + + - ``preds`` (float tensor): ``(N, ...)`` (binary) or ``(N, C, ...)`` (multiclass) tensor + with probabilities, where C is the number of classes. + + - ``target`` (long tensor): ``(N, ...)`` or ``(N, C, ...)`` with integer labels + + Args: + num_classes: integer with number of classes. Not nessesary to provide + for binary problems. + pos_label: integer determining the positive class. Default is ``None`` + which for binary problem is translate to 1. For multiclass problems + this argument should not be set as we iteratively change it in the + range [0,num_classes-1] + compute_on_step: + Forward only calls ``update()`` and return None if this is set to False. default: True + dist_sync_on_step: + Synchronize metric state across processes at each ``forward()`` + before returning the value at the step. default: False + process_group: + Specify the process group on which synchronization is called. default: None (which selects the entire world) + + Example (binary case): + + >>> from pytorch_lightning.metrics import ROC + >>> pred = torch.tensor([0, 1, 2, 3]) + >>> target = torch.tensor([0, 1, 1, 1]) + >>> roc = ROC(pos_label=1) + >>> fpr, tpr, thresholds = roc(pred, target) + >>> fpr + tensor([0., 0., 0., 0., 1.]) + >>> tpr + tensor([0.0000, 0.3333, 0.6667, 1.0000, 1.0000]) + >>> thresholds + tensor([4, 3, 2, 1, 0]) + + Example (multiclass case): + + >>> from pytorch_lightning.metrics import ROC + >>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05], + ... [0.05, 0.75, 0.05, 0.05], + ... [0.05, 0.05, 0.75, 0.05], + ... 
[0.05, 0.05, 0.05, 0.75]]) + >>> target = torch.tensor([0, 1, 3, 2]) + >>> roc = ROC(num_classes=4) + >>> fpr, tpr, thresholds = roc(pred, target) + >>> fpr + [tensor([0., 0., 1.]), tensor([0., 0., 1.]), tensor([0.0000, 0.3333, 1.0000]), tensor([0.0000, 0.3333, 1.0000])] + >>> tpr + [tensor([0., 1., 1.]), tensor([0., 1., 1.]), tensor([0., 0., 1.]), tensor([0., 0., 1.])] + >>> thresholds # doctest: +NORMALIZE_WHITESPACE + [tensor([1.7500, 0.7500, 0.0500]), + tensor([1.7500, 0.7500, 0.0500]), + tensor([1.7500, 0.7500, 0.0500]), + tensor([1.7500, 0.7500, 0.0500])] + + """ - @deprecated_metrics(target=_ROC) def __init__( self, num_classes: Optional[int] = None, @@ -29,9 +92,56 @@ def __init__( dist_sync_on_step: bool = False, process_group: Optional[Any] = None, ): + super().__init__( + compute_on_step=compute_on_step, + dist_sync_on_step=dist_sync_on_step, + process_group=process_group, + ) + + self.num_classes = num_classes + self.pos_label = pos_label + + self.add_state("preds", default=[], dist_reduce_fx=None) + self.add_state("target", default=[], dist_reduce_fx=None) + + rank_zero_warn( + 'Metric `ROC` will save all targets and predictions in buffer.' + ' For large datasets this may lead to large memory footprint.' + ) + + def update(self, preds: torch.Tensor, target: torch.Tensor): """ - This implementation refers to :class:`~torchmetrics.ROC`. + Update state with predictions and targets. + + Args: + preds: Predictions from model + target: Ground truth values + """ + preds, target, num_classes, pos_label = _roc_update(preds, target, self.num_classes, self.pos_label) + self.preds.append(preds) + self.target.append(target) + self.num_classes = num_classes + self.pos_label = pos_label + + def compute( + self + ) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], Tuple[List[torch.Tensor], List[torch.Tensor], + List[torch.Tensor]]]: + """ + Compute the receiver operating characteristic + + Returns: + 3-element tuple containing - .. deprecated:: - Use :class:`~torchmetrics.ROC`. Will be removed in v1.5.0. + fpr: + tensor with false positive rates. + If multiclass, this is a list of such tensors, one for each class. + tpr: + tensor with true positive rates. + If multiclass, this is a list of such tensors, one for each class. + thresholds: + thresholds used for computing false- and true postive rates """ + preds = torch.cat(self.preds, dim=0) + target = torch.cat(self.target, dim=0) + return _roc_compute(preds, target, self.num_classes, self.pos_label) diff --git a/pytorch_lightning/metrics/classification/stat_scores.py b/pytorch_lightning/metrics/classification/stat_scores.py index 1eed815d4b4cd..4ac47ea466ada 100644 --- a/pytorch_lightning/metrics/classification/stat_scores.py +++ b/pytorch_lightning/metrics/classification/stat_scores.py @@ -11,16 +11,125 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Optional +from typing import Any, Callable, Optional, Tuple -from torchmetrics import StatScores as _StatScores +import torch -from pytorch_lightning.metrics.utils import deprecated_metrics +from pytorch_lightning.metrics.functional.stat_scores import _stat_scores_compute, _stat_scores_update +from pytorch_lightning.metrics.metric import Metric -class StatScores(_StatScores): +class StatScores(Metric): + """Computes the number of true positives, false positives, true negatives, false negatives. 
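+
+    As the examples at the end of this docstring illustrate, the statistics are reported in the
+    order ``[tp, fp, tn, fn, support]`` (the last column being the support, ``tp + fn``); e.g. the
+    micro-averaged call with ``preds = [1, 0, 2, 1]`` and ``target = [1, 1, 2, 0]`` yields
+    ``tensor([2, 2, 6, 2, 4])``.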
+ Related to `Type I and Type II errors `__ + and the `confusion matrix `__. + + The reduction method (how the statistics are aggregated) is controlled by the + ``reduce`` parameter, and additionally by the ``mdmc_reduce`` parameter in the + multi-dimensional multi-class case. + + Accepts all inputs listed in :ref:`extensions/metrics:input types`. + + Args: + threshold: + Threshold probability value for transforming probability predictions to binary + (0 or 1) predictions, in the case of binary or multi-label inputs. + + top_k: + Number of highest probability entries for each sample to convert to 1s - relevant + only for inputs with probability predictions. If this parameter is set for multi-label + inputs, it will take precedence over ``threshold``. For (multi-dim) multi-class inputs, + this parameter defaults to 1. + + Should be left unset (``None``) for inputs with label predictions. + + reduce: + Defines the reduction that is applied. Should be one of the following: + + - ``'micro'`` [default]: Counts the statistics by summing over all [sample, class] + combinations (globally). Each statistic is represented by a single integer. + - ``'macro'``: Counts the statistics for each class separately (over all samples). + Each statistic is represented by a ``(C,)`` tensor. Requires ``num_classes`` + to be set. + - ``'samples'``: Counts the statistics for each sample separately (over all classes). + Each statistic is represented by a ``(N, )`` 1d tensor. + + Note that what is considered a sample in the multi-dimensional multi-class case + depends on the value of ``mdmc_reduce``. + + num_classes: + Number of classes. Necessary for (multi-dimensional) multi-class or multi-label data. + + ignore_index: + Specify a class (label) to ignore. If given, this class index does not contribute + to the returned score, regardless of reduction method. If an index is ignored, and + ``reduce='macro'``, the class statistics for the ignored class will all be returned + as ``-1``. + + mdmc_reduce: + Defines how the multi-dimensional multi-class inputs are handeled. Should be + one of the following: + + - ``None`` [default]: Should be left unchanged if your data is not multi-dimensional + multi-class (see :ref:`extensions/metrics:input types` for the definition of input types). + + - ``'samplewise'``: In this case, the statistics are computed separately for each + sample on the ``N`` axis, and then the outputs are concatenated together. In each + sample the extra axes ``...`` are flattened to become the sub-sample axis, and + statistics for each sample are computed by treating the sub-sample axis as the + ``N`` axis for that sample. + + - ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs are + flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they + were ``(N_X, C)``. From here on the ``reduce`` parameter applies as usual. + + is_multiclass: + Used only in certain special cases, where you want to treat inputs as a different type + than what they appear to be. See the parameter's + :ref:`documentation section ` + for a more detailed explanation and examples. + + compute_on_step: + Forward only calls ``update()`` and return ``None`` if this is set to ``False``. + dist_sync_on_step: + Synchronize metric state across processes at each ``forward()`` + before returning the value at the step + process_group: + Specify the process group on which synchronization is called. 
+ default: ``None`` (which selects the entire world) + dist_sync_fn: + Callback that performs the allgather operation on the metric state. When ``None``, DDP + will be used to perform the allgather. + + Raises: + ValueError: + If ``threshold`` is not a ``float`` between ``0`` and ``1``. + ValueError: + If ``reduce`` is none of ``"micro"``, ``"macro"`` or ``"samples"``. + ValueError: + If ``mdmc_reduce`` is none of ``None``, ``"samplewise"``, ``"global"``. + ValueError: + If ``reduce`` is set to ``"macro"`` and ``num_classes`` is not provided. + ValueError: + If ``num_classes`` is set + and ``ignore_index`` is not in the range ``0`` <= ``ignore_index`` < ``num_classes``. + + Example: + + >>> from pytorch_lightning.metrics.classification import StatScores + >>> preds = torch.tensor([1, 0, 2, 1]) + >>> target = torch.tensor([1, 1, 2, 0]) + >>> stat_scores = StatScores(reduce='macro', num_classes=3) + >>> stat_scores(preds, target) + tensor([[0, 1, 2, 1, 1], + [1, 1, 1, 1, 2], + [1, 0, 3, 0, 1]]) + >>> stat_scores = StatScores(reduce='micro') + >>> stat_scores(preds, target) + tensor([2, 2, 6, 2, 4]) + + """ - @deprecated_metrics(target=_StatScores) def __init__( self, threshold: float = 0.5, @@ -35,9 +144,129 @@ def __init__( process_group: Optional[Any] = None, dist_sync_fn: Callable = None, ): + super().__init__( + compute_on_step=compute_on_step, + dist_sync_on_step=dist_sync_on_step, + process_group=process_group, + dist_sync_fn=dist_sync_fn, + ) + + self.reduce = reduce + self.mdmc_reduce = mdmc_reduce + self.num_classes = num_classes + self.threshold = threshold + self.is_multiclass = is_multiclass + self.ignore_index = ignore_index + self.top_k = top_k + + if not 0 < threshold < 1: + raise ValueError(f"The `threshold` should be a float in the (0,1) interval, got {threshold}") + + if reduce not in ["micro", "macro", "samples"]: + raise ValueError(f"The `reduce` {reduce} is not valid.") + + if mdmc_reduce not in [None, "samplewise", "global"]: + raise ValueError(f"The `mdmc_reduce` {mdmc_reduce} is not valid.") + + if reduce == "macro" and (not num_classes or num_classes < 1): + raise ValueError("When you set `reduce` as 'macro', you have to provide the number of classes.") + + if num_classes and ignore_index is not None and (not 0 <= ignore_index < num_classes or num_classes == 1): + raise ValueError(f"The `ignore_index` {ignore_index} is not valid for inputs with {num_classes} classes") + + if mdmc_reduce != "samplewise" and reduce != "samples": + if reduce == "micro": + zeros_shape = [] + elif reduce == "macro": + zeros_shape = (num_classes, ) + default, reduce_fn = lambda: torch.zeros(zeros_shape, dtype=torch.long), "sum" + else: + default, reduce_fn = lambda: [], None + + for s in ("tp", "fp", "tn", "fn"): + self.add_state(s, default=default(), dist_reduce_fx=reduce_fn) + + def update(self, preds: torch.Tensor, target: torch.Tensor): + """ + Update state with predictions and targets. See :ref:`extensions/metrics:input types` for more information + on input types. 
+ + Args: + preds: Predictions from model (probabilities or labels) + target: Ground truth values + """ + + tp, fp, tn, fn = _stat_scores_update( + preds, + target, + reduce=self.reduce, + mdmc_reduce=self.mdmc_reduce, + threshold=self.threshold, + num_classes=self.num_classes, + top_k=self.top_k, + is_multiclass=self.is_multiclass, + ignore_index=self.ignore_index, + ) + + # Update states + if self.reduce != "samples" and self.mdmc_reduce != "samplewise": + self.tp += tp + self.fp += fp + self.tn += tn + self.fn += fn + else: + self.tp.append(tp) + self.fp.append(fp) + self.tn.append(tn) + self.fn.append(fn) + + def _get_final_stats(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Performs concatenation on the stat scores if neccesary, + before passing them to a compute function. """ - This implementation refers to :class:`~torchmetrics.StatScores`. - .. deprecated:: - Use :class:`~torchmetrics.StatScores`. Will be removed in v1.5.0. + if isinstance(self.tp, list): + tp = torch.cat(self.tp) + fp = torch.cat(self.fp) + tn = torch.cat(self.tn) + fn = torch.cat(self.fn) + else: + tp, fp, tn, fn = self.tp, self.fp, self.tn, self.fn + + return tp, fp, tn, fn + + def compute(self) -> torch.Tensor: + """ + Computes the stat scores based on inputs passed in to ``update`` previously. + + Return: + The metric returns a tensor of shape ``(..., 5)``, where the last dimension corresponds + to ``[tp, fp, tn, fn, sup]`` (``sup`` stands for support and equals ``tp + fn``). The + shape depends on the ``reduce`` and ``mdmc_reduce`` (in case of multi-dimensional + multi-class data) parameters: + + - If the data is not multi-dimensional multi-class, then + + - If ``reduce='micro'``, the shape will be ``(5, )`` + - If ``reduce='macro'``, the shape will be ``(C, 5)``, + where ``C`` stands for the number of classes + - If ``reduce='samples'``, the shape will be ``(N, 5)``, where ``N`` stands for + the number of samples + + - If the data is multi-dimensional multi-class and ``mdmc_reduce='global'``, then + + - If ``reduce='micro'``, the shape will be ``(5, )`` + - If ``reduce='macro'``, the shape will be ``(C, 5)`` + - If ``reduce='samples'``, the shape will be ``(N*X, 5)``, where ``X`` stands for + the product of sizes of all "extra" dimensions of the data (i.e. all dimensions + except for ``C`` and ``N``) + + - If the data is multi-dimensional multi-class and ``mdmc_reduce='samplewise'``, then + + - If ``reduce='micro'``, the shape will be ``(N, 5)`` + - If ``reduce='macro'``, the shape will be ``(N, C, 5)`` + - If ``reduce='samples'``, the shape will be ``(N, X, 5)`` + """ + tp, fp, tn, fn = self._get_final_stats() + return _stat_scores_compute(tp, fp, tn, fn) diff --git a/pytorch_lightning/metrics/compositional.py b/pytorch_lightning/metrics/compositional.py index 56bb1912e48e6..df98d16a3ef7e 100644 --- a/pytorch_lightning/metrics/compositional.py +++ b/pytorch_lightning/metrics/compositional.py @@ -1,28 +1,16 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. from typing import Callable, Union import torch -from torchmetrics import Metric -from torchmetrics.metric import CompositionalMetric as _CompositionalMetric -from pytorch_lightning.metrics.utils import deprecated_metrics +from pytorch_lightning.metrics.metric import Metric -class CompositionalMetric(_CompositionalMetric): +class CompositionalMetric(Metric): + """Composition of two metrics with a specific operator + which will be executed upon metric's compute + + """ - @deprecated_metrics(target=_CompositionalMetric) def __init__( self, operator: Callable, @@ -30,6 +18,75 @@ def __init__( metric_b: Union[Metric, int, float, torch.Tensor, None], ): """ - .. deprecated:: - Use :class:`torchmetrics.metric.CompositionalMetric`. Will be removed in v1.5.0. + + Args: + operator: the operator taking in one (if metric_b is None) + or two arguments. Will be applied to outputs of metric_a.compute() + and (optionally if metric_b is not None) metric_b.compute() + metric_a: first metric whose compute() result is the first argument of operator + metric_b: second metric whose compute() result is the second argument of operator. + For operators taking in only one input, this should be None """ + super().__init__() + + self.op = operator + + if isinstance(metric_a, torch.Tensor): + self.register_buffer("metric_a", metric_a) + else: + self.metric_a = metric_a + + if isinstance(metric_b, torch.Tensor): + self.register_buffer("metric_b", metric_b) + else: + self.metric_b = metric_b + + def _sync_dist(self, dist_sync_fn=None): + # No syncing required here. syncing will be done in metric_a and metric_b + pass + + def update(self, *args, **kwargs): + if isinstance(self.metric_a, Metric): + self.metric_a.update(*args, **self.metric_a._filter_kwargs(**kwargs)) + + if isinstance(self.metric_b, Metric): + self.metric_b.update(*args, **self.metric_b._filter_kwargs(**kwargs)) + + def compute(self): + + # also some parsing for kwargs? + if isinstance(self.metric_a, Metric): + val_a = self.metric_a.compute() + else: + val_a = self.metric_a + + if isinstance(self.metric_b, Metric): + val_b = self.metric_b.compute() + else: + val_b = self.metric_b + + if val_b is None: + return self.op(val_a) + + return self.op(val_a, val_b) + + def reset(self): + if isinstance(self.metric_a, Metric): + self.metric_a.reset() + + if isinstance(self.metric_b, Metric): + self.metric_b.reset() + + def persistent(self, mode: bool = False): + if isinstance(self.metric_a, Metric): + self.metric_a.persistent(mode=mode) + if isinstance(self.metric_b, Metric): + self.metric_b.persistent(mode=mode) + + def __repr__(self): + repr_str = ( + self.__class__.__name__ + + f"(\n {self.op.__name__}(\n {repr(self.metric_a)},\n {repr(self.metric_b)}\n )\n)" + ) + + return repr_str diff --git a/pytorch_lightning/metrics/functional/accuracy.py b/pytorch_lightning/metrics/functional/accuracy.py index 69fa9d75590e0..b51ce2e678996 100644 --- a/pytorch_lightning/metrics/functional/accuracy.py +++ b/pytorch_lightning/metrics/functional/accuracy.py @@ -11,15 +11,43 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Optional +from typing import Optional, Tuple import torch -from torchmetrics.functional import accuracy as _accuracy -from pytorch_lightning.metrics.utils import deprecated_metrics +from pytorch_lightning.metrics.classification.helpers import _input_format_classification, DataType + + +def _accuracy_update( + preds: torch.Tensor, target: torch.Tensor, threshold: float, top_k: Optional[int], subset_accuracy: bool +) -> Tuple[torch.Tensor, torch.Tensor]: + + preds, target, mode = _input_format_classification(preds, target, threshold=threshold, top_k=top_k) + + if mode == DataType.MULTILABEL and top_k: + raise ValueError("You can not use the `top_k` parameter to calculate accuracy for multi-label inputs.") + + if mode == DataType.BINARY or (mode == DataType.MULTILABEL and subset_accuracy): + correct = (preds == target).all(dim=1).sum() + total = torch.tensor(target.shape[0], device=target.device) + elif mode == DataType.MULTILABEL and not subset_accuracy: + correct = (preds == target).sum() + total = torch.tensor(target.numel(), device=target.device) + elif mode == DataType.MULTICLASS or (mode == DataType.MULTIDIM_MULTICLASS and not subset_accuracy): + correct = (preds * target).sum() + total = target.sum() + elif mode == DataType.MULTIDIM_MULTICLASS and subset_accuracy: + sample_correct = (preds * target).sum(dim=(1, 2)) + correct = (sample_correct == target.shape[2]).sum() + total = torch.tensor(target.shape[0], device=target.device) + + return correct, total + + +def _accuracy_compute(correct: torch.Tensor, total: torch.Tensor) -> torch.Tensor: + return correct.float() / total -@deprecated_metrics(target=_accuracy) def accuracy( preds: torch.Tensor, target: torch.Tensor, @@ -27,7 +55,66 @@ def accuracy( top_k: Optional[int] = None, subset_accuracy: bool = False, ) -> torch.Tensor: + r"""Computes `Accuracy `_: + + .. math:: + \text{Accuracy} = \frac{1}{N}\sum_i^N 1(y_i = \hat{y}_i) + + Where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a + tensor of predictions. + + For multi-class and multi-dimensional multi-class data with probability predictions, the + parameter ``top_k`` generalizes this metric to a Top-K accuracy metric: for each sample the + top-K highest probability items are considered to find the correct label. + + For multi-label and multi-dimensional multi-class inputs, this metric computes the "global" + accuracy by default, which counts all labels or sub-samples separately. This can be + changed to subset accuracy (which requires all labels or sub-samples in the sample to + be correctly predicted) by setting ``subset_accuracy=True``. + + Accepts all input types listed in :ref:`extensions/metrics:input types`. + + Args: + preds: Predictions from model (probabilities, or labels) + target: Ground truth labels + threshold: + Threshold probability value for transforming probability predictions to binary + (0,1) predictions, in the case of binary or multi-label inputs. + top_k: + Number of highest probability predictions considered to find the correct label, relevant + only for (multi-dimensional) multi-class inputs with probability predictions. The + default value (``None``) will be interpreted as 1 for these inputs. + + Should be left at default (``None``) for all other types of inputs. + subset_accuracy: + Whether to compute subset accuracy for multi-label and multi-dimensional + multi-class inputs (has no effect for other input types). 
+
+            - For multi-label inputs, if the parameter is set to ``True``, then all labels for
+              each sample must be correctly predicted for the sample to count as correct. If it
+              is set to ``False``, then all labels are counted separately - this is equivalent to
+              flattening inputs beforehand (i.e. ``preds = preds.flatten()`` and same for ``target``).
+
+            - For multi-dimensional multi-class inputs, if the parameter is set to ``True``, then all
+              sub-samples (on the extra axis) must be correct for the sample to be counted as correct.
+              If it is set to ``False``, then all sub-samples are counted separately - this is equivalent,
+              in the case of label predictions, to flattening the inputs beforehand (i.e.
+              ``preds = preds.flatten()`` and same for ``target``). Note that the ``top_k`` parameter
+              still applies in both cases, if set.
+
+    Example:
+
+        >>> from pytorch_lightning.metrics.functional import accuracy
+        >>> target = torch.tensor([0, 1, 2, 3])
+        >>> preds = torch.tensor([0, 2, 1, 3])
+        >>> accuracy(preds, target)
+        tensor(0.5000)
+
+        >>> target = torch.tensor([0, 1, 2])
+        >>> preds = torch.tensor([[0.1, 0.9, 0], [0.3, 0.1, 0.6], [0.2, 0.5, 0.3]])
+        >>> accuracy(preds, target, top_k=2)
+        tensor(0.6667)
     """
-    .. deprecated::
-        Use :func:`torchmetrics.functional.accuracy`. Will be removed in v1.5.0.
-    """
+
+    correct, total = _accuracy_update(preds, target, threshold, top_k, subset_accuracy)
+    return _accuracy_compute(correct, total)
diff --git a/pytorch_lightning/metrics/functional/auc.py b/pytorch_lightning/metrics/functional/auc.py
index 7cc6aa458d397..57ff9fe97fac2 100644
--- a/pytorch_lightning/metrics/functional/auc.py
+++ b/pytorch_lightning/metrics/functional/auc.py
@@ -11,15 +11,64 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Tuple
+
 import torch
-from torchmetrics.functional import auc as _auc
-from pytorch_lightning.metrics.utils import deprecated_metrics
+from pytorch_lightning.metrics.utils import _stable_1d_sort
+
+
+def _auc_update(x: torch.Tensor, y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    if x.ndim > 1 or y.ndim > 1:
+        raise ValueError(
+            f'Expected both `x` and `y` tensor to be 1d, but got'
+            f' tensors with dimension {x.ndim} and {y.ndim}'
+        )
+    if x.numel() != y.numel():
+        raise ValueError(
+            f'Expected the same number of elements in `x` and `y`'
+            f' tensor but received {x.numel()} and {y.numel()}'
+        )
+    return x, y
+
+
+def _auc_compute(x: torch.Tensor, y: torch.Tensor, reorder: bool = False) -> torch.Tensor:
+    if reorder:
+        x, x_idx = _stable_1d_sort(x)
+        y = y[x_idx]
+
+    dx = x[1:] - x[:-1]
+    if (dx < 0).any():
+        if (dx <= 0).all():
+            direction = -1.
+        else:
+            raise ValueError(
+                "The `x` tensor is neither increasing nor decreasing."
+                " Try setting the reorder argument to `True`."
+            )
+    else:
+        direction = 1.
+    return direction * torch.trapz(y, x)
 
 
-@deprecated_metrics(target=_auc)
 def auc(x: torch.Tensor, y: torch.Tensor, reorder: bool = False) -> torch.Tensor:
     """
-    .. deprecated::
-        Use :func:`torchmetrics.functional.auc`. Will be removed in v1.5.0.
+    Computes Area Under the Curve (AUC) using the trapezoidal rule
+
+    Args:
+        x: x-coordinates
+        y: y-coordinates
+        reorder: if True, will reorder the arrays
+
+    Return:
+        Tensor containing AUC score (float)
+
+    Example:
+        >>> from pytorch_lightning.metrics.functional import auc
+        >>> x = torch.tensor([0, 1, 2, 3])
+        >>> y = torch.tensor([0, 1, 2, 2])
+        >>> auc(x, y)
+        tensor(4.)
     """
+    x, y = _auc_update(x, y)
+    return _auc_compute(x, y, reorder=reorder)
diff --git a/pytorch_lightning/metrics/functional/auroc.py b/pytorch_lightning/metrics/functional/auroc.py
index c49aa1a8fdc48..2a8b18d7c6b66 100644
--- a/pytorch_lightning/metrics/functional/auroc.py
+++ b/pytorch_lightning/metrics/functional/auroc.py
@@ -11,15 +11,129 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Optional, Sequence
+from distutils.version import LooseVersion
+from typing import Optional, Sequence, Tuple
 
 import torch
-from torchmetrics.functional import auroc as _auroc
 
-from pytorch_lightning.metrics.utils import deprecated_metrics
+from pytorch_lightning.metrics.classification.helpers import _input_format_classification, DataType
+from pytorch_lightning.metrics.functional.auc import auc
+from pytorch_lightning.metrics.functional.roc import roc
+from pytorch_lightning.utilities import LightningEnum
+
+
+class AverageMethods(LightningEnum):
+    """ Type of averages """
+    MACRO = 'macro'
+    WEIGHTED = 'weighted'
+    NONE = None
+
+
+def _auroc_update(preds: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, str]:
+    # use _input_format_classification to validate the input and get the mode of the data
+    _, _, mode = _input_format_classification(preds, target)
+
+    if mode == DataType.MULTIDIM_MULTICLASS:
+        n_classes = preds.shape[1]
+        preds = preds.transpose(0, 1).reshape(n_classes, -1).transpose(0, 1)
+        target = target.flatten()
+    if mode == 'multi-label' and preds.ndim > 2:
+        n_classes = preds.shape[1]
+        preds = preds.transpose(0, 1).reshape(n_classes, -1).transpose(0, 1)
+        target = target.transpose(0, 1).reshape(n_classes, -1).transpose(0, 1)
+
+    return preds, target, mode
+
+
+def _auroc_compute(
+    preds: torch.Tensor,
+    target: torch.Tensor,
+    mode: str,
+    num_classes: Optional[int] = None,
+    pos_label: Optional[int] = None,
+    average: Optional[str] = 'macro',
+    max_fpr: Optional[float] = None,
+    sample_weights: Optional[Sequence] = None,
+) -> torch.Tensor:
+    # binary mode overrides num_classes
+    if mode == 'binary':
+        num_classes = 1
+
+    # check the max_fpr parameter
+    if max_fpr is not None:
+        if not isinstance(max_fpr, float) or not 0 < max_fpr <= 1:
+            raise ValueError(f"`max_fpr` should be a float in range (0, 1], got: {max_fpr}")
+
+        if LooseVersion(torch.__version__) < LooseVersion('1.6.0'):
+            raise RuntimeError(
+                "`max_fpr` argument requires `torch.bucketize` which"
+                " is not available below PyTorch version 1.6"
+            )
+
+        # the max_fpr parameter is only supported for binary problems
+        if mode != 'binary':
+            raise ValueError(
+                f"Partial AUC computation not available in"
+                f" multilabel/multiclass setting, 'max_fpr' must be"
+                f" set to `None`, received `{max_fpr}`."
+ ) + + # calculate fpr, tpr + if mode == 'multi-label': + # for multilabel we iteratively evaluate roc in a binary fashion + output = [ + roc(preds[:, i], target[:, i], num_classes=1, pos_label=1, sample_weights=sample_weights) + for i in range(num_classes) + ] + fpr = [o[0] for o in output] + tpr = [o[1] for o in output] + else: + fpr, tpr, _ = roc(preds, target, num_classes, pos_label, sample_weights) + + # calculate standard roc auc score + if max_fpr is None or max_fpr == 1: + if num_classes != 1: + # calculate auc scores per class + auc_scores = [auc(x, y) for x, y in zip(fpr, tpr)] + + # calculate average + if average == AverageMethods.NONE: + return auc_scores + elif average == AverageMethods.MACRO: + return torch.mean(torch.stack(auc_scores)) + elif average == AverageMethods.WEIGHTED: + if mode == DataType.MULTILABEL: + support = torch.sum(target, dim=0) + else: + support = torch.bincount(target.flatten(), minlength=num_classes) + return torch.sum(torch.stack(auc_scores) * support / support.sum()) + + allowed_average = [e.value for e in AverageMethods] + raise ValueError( + f"Argument `average` expected to be one of the following:" + f" {allowed_average} but got {average}" + ) + + return auc(fpr, tpr) + + max_fpr = torch.tensor(max_fpr, device=fpr.device) + # Add a single point at max_fpr and interpolate its tpr value + stop = torch.bucketize(max_fpr, fpr, out_int32=True, right=True) + weight = (max_fpr - fpr[stop - 1]) / (fpr[stop] - fpr[stop - 1]) + interp_tpr = torch.lerp(tpr[stop - 1], tpr[stop], weight) + tpr = torch.cat([tpr[:stop], interp_tpr.view(1)]) + fpr = torch.cat([fpr[:stop], max_fpr.view(1)]) + + # Compute partial AUC + partial_auc = auc(fpr, tpr) + + # McClish correction: standardize result to be 0.5 if non-discriminant + # and 1 if maximal + min_area = 0.5 * max_fpr**2 + max_area = max_fpr + return 0.5 * (1 + (partial_auc - min_area) / (max_area - min_area)) -@deprecated_metrics(target=_auroc) def auroc( preds: torch.Tensor, target: torch.Tensor, @@ -29,7 +143,47 @@ def auroc( max_fpr: Optional[float] = None, sample_weights: Optional[Sequence] = None, ) -> torch.Tensor: + """ Compute `Area Under the Receiver Operating Characteristic Curve (ROC AUC) + `_ + + Args: + preds: predictions from model (logits or probabilities) + target: Ground truth labels + num_classes: integer with number of classes. Not nessesary to provide + for binary problems. + pos_label: integer determining the positive class. Default is ``None`` + which for binary problem is translate to 1. For multiclass problems + this argument should not be set as we iteratively change it in the + range [0,num_classes-1] + average: + - ``'macro'`` computes metric for each class and uniformly averages them + - ``'weighted'`` computes metric for each class and does a weighted-average, + where each class is weighted by their support (accounts for class imbalance) + - ``None`` computes and returns the metric per class + max_fpr: + If not ``None``, calculates standardized partial AUC over the + range [0, max_fpr]. Should be a float between 0 and 1. + sample_weight: sample weights for each data point + + Example (binary case): + + >>> from pytorch_lightning.metrics.functional import auroc + >>> preds = torch.tensor([0.13, 0.26, 0.08, 0.19, 0.34]) + >>> target = torch.tensor([0, 0, 1, 1, 1]) + >>> auroc(preds, target, pos_label=1) + tensor(0.5000) + + Example (multiclass case): + + >>> from pytorch_lightning.metrics.functional import auroc + >>> preds = torch.tensor([[0.90, 0.05, 0.05], + ... 
[0.05, 0.90, 0.05], + ... [0.05, 0.05, 0.90], + ... [0.85, 0.05, 0.10], + ... [0.10, 0.10, 0.80]]) + >>> target = torch.tensor([0, 1, 1, 2, 2]) + >>> auroc(preds, target, num_classes=3) + tensor(0.7778) """ - .. deprecated:: - Use :func:`torchmetrics.functional.auroc`. Will be removed in v1.5.0. - """ + preds, target, mode = _auroc_update(preds, target) + return _auroc_compute(preds, target, mode, num_classes, pos_label, average, max_fpr, sample_weights) diff --git a/pytorch_lightning/metrics/functional/average_precision.py b/pytorch_lightning/metrics/functional/average_precision.py index 017b34739a0f4..2a82c4f38f20e 100644 --- a/pytorch_lightning/metrics/functional/average_precision.py +++ b/pytorch_lightning/metrics/functional/average_precision.py @@ -11,15 +11,45 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional, Sequence, Union +from typing import List, Optional, Sequence, Tuple, Union import torch -from torchmetrics.functional import average_precision as _average_precision -from pytorch_lightning.metrics.utils import deprecated_metrics +from pytorch_lightning.metrics.functional.precision_recall_curve import ( + _precision_recall_curve_compute, + _precision_recall_curve_update, +) + + +def _average_precision_update( + preds: torch.Tensor, + target: torch.Tensor, + num_classes: Optional[int] = None, + pos_label: Optional[int] = None, +) -> Tuple[torch.Tensor, torch.Tensor, int, int]: + return _precision_recall_curve_update(preds, target, num_classes, pos_label) + + +def _average_precision_compute( + preds: torch.Tensor, + target: torch.Tensor, + num_classes: int, + pos_label: int, + sample_weights: Optional[Sequence] = None +) -> Union[List[torch.Tensor], torch.Tensor]: + precision, recall, _ = _precision_recall_curve_compute(preds, target, num_classes, pos_label) + # Return the step function integral + # The following works because the last entry of precision is + # guaranteed to be 1, as returned by precision_recall_curve + if num_classes == 1: + return -torch.sum((recall[1:] - recall[:-1]) * precision[:-1]) + + res = [] + for p, r in zip(precision, recall): + res.append(-torch.sum((r[1:] - r[:-1]) * p[:-1])) + return res -@deprecated_metrics(target=_average_precision) def average_precision( preds: torch.Tensor, target: torch.Tensor, @@ -28,6 +58,42 @@ def average_precision( sample_weights: Optional[Sequence] = None, ) -> Union[List[torch.Tensor], torch.Tensor]: """ - .. deprecated:: - Use :func:`torchmetrics.functional.average_precision`. Will be removed in v1.5.0. + Computes the average precision score. + + Args: + preds: predictions from model (logits or probabilities) + target: ground truth values + num_classes: integer with number of classes. Not nessesary to provide + for binary problems. + pos_label: integer determining the positive class. Default is ``None`` + which for binary problem is translate to 1. For multiclass problems + this argument should not be set as we iteratively change it in the + range [0,num_classes-1] + sample_weights: sample weights for each data point + + Returns: + tensor with average precision. 
If multiclass will return list + of such tensors, one for each class + + Example (binary case): + + >>> from pytorch_lightning.metrics.functional import average_precision + >>> pred = torch.tensor([0, 1, 2, 3]) + >>> target = torch.tensor([0, 1, 1, 1]) + >>> average_precision(pred, target, pos_label=1) + tensor(1.) + + Example (multiclass case): + + >>> from pytorch_lightning.metrics.functional import average_precision + >>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05], + ... [0.05, 0.75, 0.05, 0.05, 0.05], + ... [0.05, 0.05, 0.75, 0.05, 0.05], + ... [0.05, 0.05, 0.05, 0.75, 0.05]]) + >>> target = torch.tensor([0, 1, 3, 2]) + >>> average_precision(pred, target, num_classes=5) + [tensor(1.), tensor(1.), tensor(0.2500), tensor(0.2500), tensor(nan)] + """ + preds, target, num_classes, pos_label = _average_precision_update(preds, target, num_classes, pos_label) + return _average_precision_compute(preds, target, num_classes, pos_label, sample_weights) diff --git a/pytorch_lightning/metrics/functional/classification.py b/pytorch_lightning/metrics/functional/classification.py index be1fec196a346..e697ade9be16b 100644 --- a/pytorch_lightning/metrics/functional/classification.py +++ b/pytorch_lightning/metrics/functional/classification.py @@ -15,13 +15,12 @@ from typing import Callable, Optional, Sequence, Tuple import torch -from torchmetrics.utilities import class_reduce, reduce -from torchmetrics.utilities.data import get_num_classes, to_categorical from pytorch_lightning.metrics.functional.auc import auc as __auc from pytorch_lightning.metrics.functional.auroc import auroc as __auroc from pytorch_lightning.metrics.functional.iou import iou as __iou -from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn +from pytorch_lightning.metrics.utils import class_reduce, get_num_classes, reduce, to_categorical +from pytorch_lightning.utilities import rank_zero_warn def stat_scores( @@ -31,8 +30,26 @@ def stat_scores( argmax_dim: int = 1, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ - .. deprecated:: - Use :func:`torchmetrics.functional.stat_scores`. Will be removed in v1.4.0. + Calculates the number of true positive, false positive, true negative + and false negative for a specific class + + Args: + pred: prediction tensor + target: target tensor + class_index: class to calculate over + argmax_dim: if pred is a tensor of probabilities, this indicates the + axis the argmax transformation will be applied over + + Return: + True Positive, False Positive, True Negative, False Negative, Support + + Example: + + >>> x = torch.tensor([1, 2, 3]) + >>> y = torch.tensor([0, 2, 3]) + >>> tp, fp, tn, fn, sup = stat_scores(x, y, class_index=1) + >>> tp, fp, tn, fn, sup + (tensor(0), tensor(1), tensor(2), tensor(0), tensor(0)) """ if pred.ndim == target.ndim + 1: pred = to_categorical(pred, argmax_dim=argmax_dim) @@ -55,13 +72,17 @@ def stat_scores_multiple_classes( reduction: str = 'none', ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ - .. deprecated:: - Use :func:`torchmetrics.functional.stat_scores`. Will be removed in v1.4.0. + Calculates the number of true positive, false positive, true negative + and false negative for each class + + .. 
warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.functional.stat_scores` + """ - rank_zero_deprecation( + + rank_zero_warn( "This `stat_scores_multiple_classes` was deprecated in v1.2.0 in favor of" " `from pytorch_lightning.metrics.functional import stat_scores`." - " It will be removed in v1.4.0" + " It will be removed in v1.4.0", DeprecationWarning ) if pred.ndim == target.ndim + 1: pred = to_categorical(pred, argmax_dim=argmax_dim) @@ -141,13 +162,42 @@ def precision_recall( return_state: bool = False ) -> Tuple[torch.Tensor, torch.Tensor]: """ - .. deprecated:: - Use :func:`torchmetrics.functional.precision_recall`. Will be removed in v1.4.0. + Computes precision and recall for different thresholds + + .. warning :: Deprecated in favor of + :func:`~pytorch_lightning.metrics.functional.precision_recall`. + Will be removed in v1.4.0. + + Args: + pred: estimated probabilities + target: ground-truth labels + num_classes: number of classes + class_reduction: method to reduce metric score over labels + + - ``'micro'``: calculate metrics globally (default) + - ``'macro'``: calculate metrics for each label, and find their unweighted mean. + - ``'weighted'``: calculate metrics for each label, and find their weighted mean. + - ``'none'``: returns calculated metric per class + + return_support: returns the support for each class, need for fbeta/f1 calculations + return_state: returns a internal state that can be ddp reduced + before doing the final calculation + + Return: + Tensor with precision and recall + + Example: + + >>> x = torch.tensor([0, 1, 2, 3]) + >>> y = torch.tensor([0, 2, 2, 2]) + >>> precision_recall(x, y, class_reduction='macro') + (tensor(0.5000), tensor(0.3333)) + """ - rank_zero_deprecation( + rank_zero_warn( "This `precision_recall` was deprecated in v1.2.0 in favor of" " `from pytorch_lightning.metrcs.functional import precision_recall`." - " It will be removed in v1.4.0" + " It will be removed in v1.4.0", DeprecationWarning ) tps, fps, tns, fns, sups = stat_scores_multiple_classes(pred=pred, target=target, num_classes=num_classes) @@ -169,13 +219,37 @@ def precision( class_reduction: str = 'micro', ) -> torch.Tensor: """ - .. deprecated:: - Use :func:`torchmetrics.functional.precision`. Will be removed in v1.4.0. + Computes precision score. + + .. warning :: Deprecated in favor of + :func:`~pytorch_lightning.metrics.functional.recall`. Will be removed in v1.4.0. + + Args: + pred: estimated probabilities + target: ground-truth labels + num_classes: number of classes + class_reduction: method to reduce metric score over labels + + - ``'micro'``: calculate metrics globally (default) + - ``'macro'``: calculate metrics for each label, and find their unweighted mean. + - ``'weighted'``: calculate metrics for each label, and find their weighted mean. + - ``'none'``: returns calculated metric per class + + Return: + Tensor with precision. + + Example: + + >>> x = torch.tensor([0, 1, 2, 3]) + >>> y = torch.tensor([0, 1, 2, 2]) + >>> precision(x, y) + tensor(0.7500) + """ - rank_zero_deprecation( + rank_zero_warn( "This `precision` was deprecated in v1.2.0 in favor of" " `from pytorch_lightning.metrics.functional import precision`." - " It will be removed in v1.4.0" + " It will be removed in v1.4.0", DeprecationWarning ) return precision_recall(pred=pred, target=target, num_classes=num_classes, class_reduction=class_reduction)[0] @@ -189,13 +263,36 @@ def recall( class_reduction: str = 'micro', ) -> torch.Tensor: """ - .. 
deprecated:: - Use :func:`torchmetrics.functional.recall`. Will be removed in v1.4.0. + Computes recall score. + + .. warning :: Deprecated in favor of + :func:`~pytorch_lightning.metrics.functional.recall`. Will be removed in v1.4.0. + + Args: + pred: estimated probabilities + target: ground-truth labels + num_classes: number of classes + class_reduction: method to reduce metric score over labels + + - ``'micro'``: calculate metrics globally (default) + - ``'macro'``: calculate metrics for each label, and find their unweighted mean. + - ``'weighted'``: calculate metrics for each label, and find their weighted mean. + - ``'none'``: returns calculated metric per class + + Return: + Tensor with recall. + + Example: + + >>> x = torch.tensor([0, 1, 2, 3]) + >>> y = torch.tensor([0, 1, 2, 2]) + >>> recall(x, y) + tensor(0.7500) """ - rank_zero_deprecation( + rank_zero_warn( "This `recall` was deprecated in v1.2.0 in favor of" " `from pytorch_lightning.metrics.functional import recall`." - " It will be removed in v1.4.0" + " It will be removed in v1.4.0", DeprecationWarning ) return precision_recall(pred=pred, target=target, num_classes=num_classes, class_reduction=class_reduction)[1] @@ -207,19 +304,37 @@ def auc( y: torch.Tensor, ) -> torch.Tensor: """ - .. deprecated:: - Use :func:`torchmetrics.functional.auc`. Will be removed in v1.4.0. + Computes Area Under the Curve (AUC) using the trapezoidal rule + + .. warning :: Deprecated in favor of + :func:`~pytorch_lightning.metrics.functional.auc.auc`. Will be removed + in v1.4.0. + + Args: + x: x-coordinates + y: y-coordinates + + Return: + Tensor containing AUC score (float) + + Example: + + >>> x = torch.tensor([0, 1, 2, 3]) + >>> y = torch.tensor([0, 1, 2, 2]) + >>> auc(x, y) + tensor(4.) """ - rank_zero_deprecation( + rank_zero_warn( "This `auc` was deprecated in v1.2.0 in favor of" " `pytorch_lightning.metrics.functional.auc import auc`." - " It will be removed in v1.4.0" + " It will be removed in v1.4.0", DeprecationWarning ) return __auc(x, y) # todo: remove in 1.4 -def _auc_decorator() -> Callable: +def auc_decorator() -> Callable: + rank_zero_warn("This `auc_decorator` was deprecated in v1.2.0." " It will be removed in v1.4.0", DeprecationWarning) def wrapper(func_to_decorate: Callable) -> Callable: @@ -235,7 +350,11 @@ def new_func(*args, **kwargs) -> torch.Tensor: # todo: remove in 1.4 -def _multiclass_auc_decorator() -> Callable: +def multiclass_auc_decorator() -> Callable: + rank_zero_warn( + "This `multiclass_auc_decorator` was deprecated in v1.2.0." + " It will be removed in v1.4.0", DeprecationWarning + ) def wrapper(func_to_decorate: Callable) -> Callable: @@ -262,12 +381,34 @@ def auroc( max_fpr: float = None, ) -> torch.Tensor: """ - .. deprecated:: - Use :func:`torchmetrics.functional.auroc`. Will be removed in v1.4.0. + Compute Area Under the Receiver Operating Characteristic Curve (ROC AUC) from prediction scores + + .. warning :: Deprecated in favor of + :func:`~pytorch_lightning.metrics.functional.auroc.auroc`. Will be removed + in v1.4.0. + + Args: + pred: estimated probabilities + target: ground-truth labels + sample_weight: sample weights + pos_label: the label for the positive class + max_fpr: If not ``None``, calculates standardized partial AUC over the + range [0, max_fpr]. Should be a float between 0 and 1. 
+ + Return: + Tensor containing ROCAUC score + + Example: + + >>> x = torch.tensor([0, 1, 2, 3]) + >>> y = torch.tensor([0, 1, 1, 0]) + >>> auroc(x, y) + tensor(0.5000) """ - rank_zero_deprecation( - "This `auroc` was deprecated in v1.2.0 in favor of `pytorch_lightning.metrics.functional.auroc import auroc`." - " It will be removed in v1.4.0" + rank_zero_warn( + "This `auroc` was deprecated in v1.2.0 in favor of" + " `pytorch_lightning.metrics.functional.auroc import auroc`." + " It will be removed in v1.4.0", DeprecationWarning ) return __auroc( preds=pred, target=target, sample_weights=sample_weight, pos_label=pos_label, max_fpr=max_fpr, num_classes=1 @@ -282,15 +423,58 @@ def multiclass_auroc( num_classes: Optional[int] = None, ) -> torch.Tensor: """ - .. deprecated:: - Use :func:`torchmetrics.functional.auroc`. Will be removed in v1.4.0. + Compute Area Under the Receiver Operating Characteristic Curve (ROC AUC) from multiclass + prediction scores + + .. warning :: Deprecated in favor of + :func:`~pytorch_lightning.metrics.functional.auroc.auroc`. Will be removed + in v1.4.0. + + Args: + pred: estimated probabilities, with shape [N, C] + target: ground-truth labels, with shape [N,] + sample_weight: sample weights + num_classes: number of classes (default: None, computes automatically from data) + + Return: + Tensor containing ROCAUC score + + Example: + + >>> pred = torch.tensor([[0.85, 0.05, 0.05, 0.05], + ... [0.05, 0.85, 0.05, 0.05], + ... [0.05, 0.05, 0.85, 0.05], + ... [0.05, 0.05, 0.05, 0.85]]) + >>> target = torch.tensor([0, 1, 3, 2]) + >>> multiclass_auroc(pred, target, num_classes=4) + tensor(0.6667) """ - rank_zero_deprecation( + rank_zero_warn( "This `multiclass_auroc` was deprecated in v1.2.0 in favor of" " `pytorch_lightning.metrics.functional.auroc import auroc`." - " It will be removed in v1.4.0" + " It will be removed in v1.4.0", DeprecationWarning ) + if not torch.allclose(pred.sum(dim=1), torch.tensor(1.0)): + raise ValueError( + "Multiclass AUROC metric expects the target scores to be" + " probabilities, i.e. they should sum up to 1.0 over classes" + ) + + if torch.unique(target).size(0) != pred.size(1): + raise ValueError( + f"Number of classes found in in 'target' ({torch.unique(target).size(0)})" + f" does not equal the number of columns in 'pred' ({pred.size(1)})." + " Multiclass AUROC is not defined when all of the classes do not" + " occur in the target labels." + ) + + if num_classes is not None and num_classes != pred.size(1): + raise ValueError( + f"Number of classes deduced from 'pred' ({pred.size(1)}) does not equal" + f" the number of classes passed in 'num_classes' ({num_classes})." + ) + return __auroc(preds=pred, target=target, sample_weights=sample_weight, num_classes=num_classes) @@ -303,8 +487,34 @@ def dice_score( reduction: str = 'elementwise_mean', ) -> torch.Tensor: """ - .. deprecated:: - Use :func:`torchmetrics.functional.dice_score`. Will be removed in v1.4.0. + Compute dice score from prediction scores + + Args: + pred: estimated probabilities + target: ground-truth labels + bg: whether to also compute dice for the background + nan_score: score to return, if a NaN occurs during computation + no_fg_score: score to return, if no foreground pixel was found in target + reduction: a method to reduce metric score over labels. 
+ + - ``'elementwise_mean'``: takes the mean (default) + - ``'sum'``: takes the sum + - ``'none'``: no reduction will be applied + + Return: + Tensor containing dice score + + Example: + + >>> from pytorch_lightning.metrics.functional import dice_score + >>> pred = torch.tensor([[0.85, 0.05, 0.05, 0.05], + ... [0.05, 0.85, 0.05, 0.05], + ... [0.05, 0.05, 0.85, 0.05], + ... [0.05, 0.05, 0.05, 0.85]]) + >>> target = torch.tensor([0, 1, 3, 2]) + >>> dice_score(pred, target) + tensor(0.3333) + """ num_classes = pred.shape[1] bg = (1 - int(bool(bg))) @@ -334,12 +544,47 @@ def iou( reduction: str = 'elementwise_mean', ) -> torch.Tensor: """ - .. deprecated:: - Use :func:`torchmetrics.functional.iou`. Will be removed in v1.4.0. + Intersection over union, or Jaccard index calculation. + + .. warning :: Deprecated in favor of + :func:`~pytorch_lightning.metrics.functional.iou.iou`. Will be removed in + v1.4.0. + + Args: + pred: Tensor containing integer predictions, with shape [N, d1, d2, ...] + target: Tensor containing integer targets, with shape [N, d1, d2, ...] + ignore_index: optional int specifying a target class to ignore. If given, this class index does not contribute + to the returned score, regardless of reduction method. Has no effect if given an int that is not in the + range [0, num_classes-1], where num_classes is either given or derived from pred and target. By default, no + index is ignored, and all classes are used. + absent_score: score to use for an individual class, if no instances of the class index were present in + `pred` AND no instances of the class index were present in `target`. For example, if we have 3 classes, + [0, 0] for `pred`, and [0, 2] for `target`, then class 1 would be assigned the `absent_score`. Default is + 0.0. + num_classes: Optionally specify the number of classes + reduction: a method to reduce metric score over labels. + + - ``'elementwise_mean'``: takes the mean (default) + - ``'sum'``: takes the sum + - ``'none'``: no reduction will be applied + + Return: + IoU score : Tensor containing single value if reduction is + 'elementwise_mean', or number of classes if reduction is 'none' + + Example: + + >>> target = torch.randint(0, 2, (10, 25, 25)) + >>> pred = torch.tensor(target) + >>> pred[2:5, 7:13, 9:15] = 1 - pred[2:5, 7:13, 9:15] + >>> iou(pred, target) + tensor(0.9660) + """ - rank_zero_deprecation( - "This `iou` was deprecated in v1.2.0 in favor of `from pytorch_lightning.metrics.functional.iou import iou`." - " It will be removed in v1.4.0" + rank_zero_warn( + "This `iou` was deprecated in v1.2.0 in favor of" + " `from pytorch_lightning.metrics.functional.iou import iou`." 
+ " It will be removed in v1.4.0", DeprecationWarning ) return __iou( pred=pred, diff --git a/pytorch_lightning/metrics/functional/confusion_matrix.py b/pytorch_lightning/metrics/functional/confusion_matrix.py index 038bd8b49b730..58947f2cb19ed 100644 --- a/pytorch_lightning/metrics/functional/confusion_matrix.py +++ b/pytorch_lightning/metrics/functional/confusion_matrix.py @@ -14,12 +14,44 @@ from typing import Optional import torch -from torchmetrics.functional import confusion_matrix as _confusion_matrix -from pytorch_lightning.metrics.utils import deprecated_metrics +from pytorch_lightning.metrics.classification.helpers import _input_format_classification, DataType +from pytorch_lightning.utilities import rank_zero_warn + + +def _confusion_matrix_update( + preds: torch.Tensor, target: torch.Tensor, num_classes: int, threshold: float = 0.5 +) -> torch.Tensor: + preds, target, mode = _input_format_classification(preds, target, threshold) + if mode not in (DataType.BINARY, DataType.MULTILABEL): + preds = preds.argmax(dim=1) + target = target.argmax(dim=1) + unique_mapping = (target.view(-1) * num_classes + preds.view(-1)).to(torch.long) + bins = torch.bincount(unique_mapping, minlength=num_classes**2) + confmat = bins.reshape(num_classes, num_classes) + return confmat + + +def _confusion_matrix_compute(confmat: torch.Tensor, normalize: Optional[str] = None) -> torch.Tensor: + allowed_normalize = ('true', 'pred', 'all', 'none', None) + assert normalize in allowed_normalize, \ + f"Argument average needs to one of the following: {allowed_normalize}" + confmat = confmat.float() + if normalize is not None and normalize != 'none': + if normalize == 'true': + cm = confmat / confmat.sum(axis=1, keepdim=True) + elif normalize == 'pred': + cm = confmat / confmat.sum(axis=0, keepdim=True) + elif normalize == 'all': + cm = confmat / confmat.sum() + nan_elements = cm[torch.isnan(cm)].nelement() + if nan_elements != 0: + cm[torch.isnan(cm)] = 0 + rank_zero_warn(f'{nan_elements} nan values found in confusion matrix have been replaced with zeros.') + return cm + return confmat -@deprecated_metrics(target=_confusion_matrix) def confusion_matrix( preds: torch.Tensor, target: torch.Tensor, @@ -28,6 +60,38 @@ def confusion_matrix( threshold: float = 0.5 ) -> torch.Tensor: """ - .. deprecated:: - Use :func:`torchmetrics.functional.confusion_matrix`. Will be removed in v1.5.0. + Computes the confusion matrix. Works with binary, multiclass, and multilabel data. + Accepts probabilities from a model output or integer class values in prediction. + Works with multi-dimensional preds and target. + + If preds and target are the same shape and preds is a float tensor, we use the ``self.threshold`` argument + to convert into integer labels. This is the case for binary and multi-label probabilities. + + If preds has an extra dimension as in the case of multi-class scores we perform an argmax on ``dim=1``. + + Args: + preds: (float or long tensor), Either a ``(N, ...)`` tensor with labels or + ``(N, C, ...)`` where C is the number of classes, tensor with labels/probabilities + target: ``target`` (long tensor), tensor with shape ``(N, ...)`` with ground true labels + num_classes: Number of classes in the dataset. + normalize: Normalization mode for confusion matrix. 
Choose from + + - ``None`` or ``'none'``: no normalization (default) + - ``'true'``: normalization over the targets (most commonly used) + - ``'pred'``: normalization over the predictions + - ``'all'``: normalization over the whole matrix + + threshold: + Threshold value for binary or multi-label probabilities. default: 0.5 + + Example: + + >>> from pytorch_lightning.metrics.functional import confusion_matrix + >>> target = torch.tensor([1, 1, 0, 0]) + >>> preds = torch.tensor([0, 1, 0, 0]) + >>> confusion_matrix(preds, target, num_classes=2) + tensor([[2., 0.], + [1., 1.]]) """ + confmat = _confusion_matrix_update(preds, target, num_classes, threshold) + return _confusion_matrix_compute(confmat, normalize) diff --git a/pytorch_lightning/metrics/functional/explained_variance.py b/pytorch_lightning/metrics/functional/explained_variance.py index 233a0851b8d56..617d800c754e3 100644 --- a/pytorch_lightning/metrics/functional/explained_variance.py +++ b/pytorch_lightning/metrics/functional/explained_variance.py @@ -11,21 +11,78 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Sequence, Union +from typing import Sequence, Tuple, Union import torch -from torchmetrics.functional import explained_variance as _explained_variance -from pytorch_lightning.metrics.utils import deprecated_metrics +from pytorch_lightning.metrics.utils import _check_same_shape + + +def _explained_variance_update(preds: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + _check_same_shape(preds, target) + return preds, target + + +def _explained_variance_compute( + preds: torch.Tensor, + target: torch.Tensor, + multioutput: str = 'uniform_average', +) -> Union[torch.Tensor, Sequence[torch.Tensor]]: + diff_avg = torch.mean(target - preds, dim=0) + numerator = torch.mean((target - preds - diff_avg)**2, dim=0) + + target_avg = torch.mean(target, dim=0) + denominator = torch.mean((target - target_avg)**2, dim=0) + + # Take care of division by zero + nonzero_numerator = numerator != 0 + nonzero_denominator = denominator != 0 + valid_score = nonzero_numerator & nonzero_denominator + output_scores = torch.ones_like(diff_avg) + output_scores[valid_score] = 1.0 - (numerator[valid_score] / denominator[valid_score]) + output_scores[nonzero_numerator & ~nonzero_denominator] = 0. + + # Decide what to do in multioutput case + # Todo: allow user to pass in tensor with weights + if multioutput == 'raw_values': + return output_scores + if multioutput == 'uniform_average': + return torch.mean(output_scores) + if multioutput == 'variance_weighted': + denom_sum = torch.sum(denominator) + return torch.sum(denominator / denom_sum * output_scores) -@deprecated_metrics(target=_explained_variance) def explained_variance( preds: torch.Tensor, target: torch.Tensor, multioutput: str = 'uniform_average', ) -> Union[torch.Tensor, Sequence[torch.Tensor]]: """ - .. deprecated:: - Use :func:`torchmetrics.functional.explained_variance`. Will be removed in v1.5.0. + Computes explained variance. + + Args: + preds: estimated labels + target: ground truth labels + multioutput: Defines aggregation in the case of multiple output scores. 
Can be one + of the following strings (default is `'uniform_average'`.): + + * `'raw_values'` returns full set of scores + * `'uniform_average'` scores are uniformly averaged + * `'variance_weighted'` scores are weighted by their individual variances + + Example: + + >>> from pytorch_lightning.metrics.functional import explained_variance + >>> target = torch.tensor([3, -0.5, 2, 7]) + >>> preds = torch.tensor([2.5, 0.0, 2, 8]) + >>> explained_variance(preds, target) + tensor(0.9572) + + >>> target = torch.tensor([[0.5, 1], [-1, 1], [7, -6]]) + >>> preds = torch.tensor([[0, 2], [-1, 2], [8, -5]]) + >>> explained_variance(preds, target, multioutput='raw_values') + tensor([0.9677, 1.0000]) """ + preds, target = _explained_variance_update(preds, target) + return _explained_variance_compute(preds, target, multioutput) diff --git a/pytorch_lightning/metrics/functional/f_beta.py b/pytorch_lightning/metrics/functional/f_beta.py index f994c9a8a3271..debb6c8285fc9 100644 --- a/pytorch_lightning/metrics/functional/f_beta.py +++ b/pytorch_lightning/metrics/functional/f_beta.py @@ -11,14 +11,46 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Tuple + import torch -from torchmetrics.functional import f1 as _f1 -from torchmetrics.functional import fbeta as _fbeta -from pytorch_lightning.metrics.utils import deprecated_metrics +from pytorch_lightning.metrics.utils import _input_format_classification_one_hot, class_reduce + + +def _fbeta_update( + preds: torch.Tensor, + target: torch.Tensor, + num_classes: int, + threshold: float = 0.5, + multilabel: bool = False +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + preds, target = _input_format_classification_one_hot(num_classes, preds, target, threshold, multilabel) + true_positives = torch.sum(preds * target, dim=1) + predicted_positives = torch.sum(preds, dim=1) + actual_positives = torch.sum(target, dim=1) + return true_positives, predicted_positives, actual_positives + + +def _fbeta_compute( + true_positives: torch.Tensor, + predicted_positives: torch.Tensor, + actual_positives: torch.Tensor, + beta: float = 1.0, + average: str = "micro" +) -> torch.Tensor: + if average == "micro": + precision = true_positives.sum().float() / predicted_positives.sum() + recall = true_positives.sum().float() / actual_positives.sum() + else: + precision = true_positives.float() / predicted_positives + recall = true_positives.float() / actual_positives + + num = (1 + beta**2) * precision * recall + denom = beta**2 * precision + recall + return class_reduce(num, denom, weights=actual_positives, class_reduction=average) -@deprecated_metrics(target=_fbeta) def fbeta( preds: torch.Tensor, target: torch.Tensor, @@ -29,12 +61,49 @@ def fbeta( multilabel: bool = False ) -> torch.Tensor: """ - .. deprecated:: - Use :func:`torchmetrics.functional.accuracy`. Will be removed in v1.5.0. + Computes f_beta metric. + + Works with binary, multiclass, and multilabel data. + Accepts probabilities from a model output or integer class values in prediction. + Works with multi-dimensional preds and target. + + If preds and target are the same shape and preds is a float tensor, we use the ``self.threshold`` argument + to convert into integer labels. This is the case for binary and multi-label probabilities. + + If preds has an extra dimension as in the case of multi-class scores we perform an argmax on ``dim=1``. 
+ + Args: + preds: predictions from model (probabilities, or labels) + target: ground truth labels + num_classes: Number of classes in the dataset. + beta: Beta coefficient in the F measure. + threshold: + Threshold value for binary or multi-label probabilities. default: 0.5 + + average: + - ``'micro'`` computes metric globally + - ``'macro'`` computes metric for each class and uniformly averages them + - ``'weighted'`` computes metric for each class and does a weighted-average, + where each class is weighted by their support (accounts for class imbalance) + - ``'none'`` or ``None`` computes and returns the metric per class + + multilabel: If predictions are from multilabel classification. + + Example: + + >>> from pytorch_lightning.metrics.functional import fbeta + >>> target = torch.tensor([0, 1, 2, 0, 1, 2]) + >>> preds = torch.tensor([0, 2, 1, 0, 0, 1]) + >>> fbeta(preds, target, num_classes=3, beta=0.5) + tensor(0.3333) + """ + true_positives, predicted_positives, actual_positives = _fbeta_update( + preds, target, num_classes, threshold, multilabel + ) + return _fbeta_compute(true_positives, predicted_positives, actual_positives, beta, average) -@deprecated_metrics(target=_f1) def f1( preds: torch.Tensor, target: torch.Tensor, @@ -44,6 +113,39 @@ def f1( multilabel: bool = False ) -> torch.Tensor: """ - .. deprecated:: - Use :func:`torchmetrics.functional.f1`. Will be removed in v1.5.0. + Computes F1 metric. F1 metrics correspond to a equally weighted average of the + precision and recall scores. + + Works with binary, multiclass, and multilabel data. + Accepts probabilities from a model output or integer class values in prediction. + Works with multi-dimensional preds and target. + + If preds and target are the same shape and preds is a float tensor, we use the ``self.threshold`` argument + to convert into integer labels. This is the case for binary and multi-label probabilities. + + If preds has an extra dimension as in the case of multi-class scores we perform an argmax on ``dim=1``. + + Args: + preds: predictions from model (probabilities, or labels) + target: ground truth labels + num_classes: Number of classes in the dataset. + threshold: + Threshold value for binary or multi-label probabilities. default: 0.5 + + average: + - ``'micro'`` computes metric globally + - ``'macro'`` computes metric for each class and uniformly averages them + - ``'weighted'`` computes metric for each class and does a weighted-average, + where each class is weighted by their support (accounts for class imbalance) + - ``'none'`` or ``None`` computes and returns the metric per class + + multilabel: If predictions are from multilabel classification. + + Example: + >>> from pytorch_lightning.metrics.functional import f1 + >>> target = torch.tensor([0, 1, 2, 0, 1, 2]) + >>> preds = torch.tensor([0, 2, 1, 0, 0, 1]) + >>> f1(preds, target, num_classes=3) + tensor(0.3333) """ + return fbeta(preds, target, num_classes, 1.0, threshold, average, multilabel) diff --git a/pytorch_lightning/metrics/functional/hamming_distance.py b/pytorch_lightning/metrics/functional/hamming_distance.py index 6a390e776f111..60409751fc9f0 100644 --- a/pytorch_lightning/metrics/functional/hamming_distance.py +++ b/pytorch_lightning/metrics/functional/hamming_distance.py @@ -11,15 +11,64 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from typing import Tuple, Union + import torch -from torchmetrics.functional import hamming_distance as _hamming_distance -from pytorch_lightning.metrics.utils import deprecated_metrics +from pytorch_lightning.metrics.classification.helpers import _input_format_classification + + +def _hamming_distance_update( + preds: torch.Tensor, + target: torch.Tensor, + threshold: float = 0.5, +) -> Tuple[torch.Tensor, int]: + preds, target, _ = _input_format_classification(preds, target, threshold=threshold) + + correct = (preds == target).sum() + total = preds.numel() + + return correct, total + + +def _hamming_distance_compute(correct: torch.Tensor, total: Union[int, torch.Tensor]) -> torch.Tensor: + return 1 - correct.float() / total -@deprecated_metrics(target=_hamming_distance) def hamming_distance(preds: torch.Tensor, target: torch.Tensor, threshold: float = 0.5) -> torch.Tensor: + r""" + Computes the average `Hamming distance `_ (also + known as Hamming loss) between targets and predictions: + + .. math:: + \text{Hamming distance} = \frac{1}{N \cdot L} \sum_i^N \sum_l^L 1(y_{il} \neq \hat{y}_{il}) + + Where :math:`y` is a tensor of target values, :math:`\hat{y}` is a tensor of predictions, + and :math:`\bullet_{il}` refers to the :math:`l`-th label of the :math:`i`-th sample of that + tensor. + + This is the same as ``1-accuracy`` for binary data, while for all other types of inputs it + treats each possible label separately - meaning that, for example, multi-class data is + treated as if it were multi-label. + + Accepts all input types listed in :ref:`extensions/metrics:input types`. + + Args: + preds: Predictions from model + target: Ground truth + threshold: + Threshold probability value for transforming probability predictions to binary + (0 or 1) predictions, in the case of binary or multi-label inputs. + + Example: + + >>> from pytorch_lightning.metrics.functional import hamming_distance + >>> target = torch.tensor([[0, 1], [1, 1]]) + >>> preds = torch.tensor([[0, 1], [0, 1]]) + >>> hamming_distance(preds, target) + tensor(0.2500) + """ - .. deprecated:: - Use :func:`torchmetrics.functional.hamming_distance`. Will be removed in v1.5.0. 
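A small illustrative check (not part of the patch) that ``hamming_distance`` is the fraction of mismatched labels, i.e. ``1 - accuracy`` for binary inputs as stated above:

    import torch
    from pytorch_lightning.metrics.functional import hamming_distance

    target = torch.tensor([[0, 1], [1, 1]])
    preds = torch.tensor([[0, 1], [0, 1]])
    # one of four labels differs -> distance 0.25
    manual = 1 - (preds == target).float().mean()
    assert torch.allclose(hamming_distance(preds, target), manual)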
- """ + + correct, total = _hamming_distance_update(preds, target, threshold) + return _hamming_distance_compute(correct, total) diff --git a/pytorch_lightning/metrics/functional/image_gradients.py b/pytorch_lightning/metrics/functional/image_gradients.py index e2151c5fc1d93..3fbed571e008e 100644 --- a/pytorch_lightning/metrics/functional/image_gradients.py +++ b/pytorch_lightning/metrics/functional/image_gradients.py @@ -14,14 +14,62 @@ from typing import Tuple import torch -from torchmetrics.functional import image_gradients as _image_gradients -from pytorch_lightning.metrics.utils import deprecated_metrics + +def _image_gradients_validate(img: torch.Tensor) -> torch.Tensor: + """ Validates whether img is a 4D torch Tensor """ + + if not isinstance(img, torch.Tensor): + raise TypeError(f"The `img` expects a value of type but got {type(img)}") + if img.ndim != 4: + raise RuntimeError(f"The `img` expects a 4D tensor but got {img.ndim}D tensor") + + +def _compute_image_gradients(img: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ Computes image gradients (dy/dx) for a given image """ + + batch_size, channels, height, width = img.shape + + dy = img[..., 1:, :] - img[..., :-1, :] + dx = img[..., :, 1:] - img[..., :, :-1] + + shapey = [batch_size, channels, 1, width] + dy = torch.cat([dy, torch.zeros(shapey, device=img.device, dtype=img.dtype)], dim=2) + dy = dy.view(img.shape) + + shapex = [batch_size, channels, height, 1] + dx = torch.cat([dx, torch.zeros(shapex, device=img.device, dtype=img.dtype)], dim=3) + dx = dx.view(img.shape) + + return dy, dx -@deprecated_metrics(target=_image_gradients) def image_gradients(img: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ - .. deprecated:: - Use :func:`torchmetrics.functional.image_gradients`. Will be removed in v1.5.0. + Computes the `gradients `_ of a given image using finite difference + + Args: + img: An ``(N, C, H, W)`` input tensor where C is the number of image channels + + Return: + Tuple of (dy, dx) with each gradient of shape ``[N, C, H, W]`` + + Example: + >>> from pytorch_lightning.metrics.functional import image_gradients + >>> image = torch.arange(0, 1*1*5*5, dtype=torch.float32) + >>> image = torch.reshape(image, (1, 1, 5, 5)) + >>> dy, dx = image_gradients(image) + >>> dy[0, 0, :, :] + tensor([[5., 5., 5., 5., 5.], + [5., 5., 5., 5., 5.], + [5., 5., 5., 5., 5.], + [5., 5., 5., 5., 5.], + [0., 0., 0., 0., 0.]]) + + .. note:: The implementation follows the 1-step finite difference method as followed + by the TF implementation. 
The values are organized such that the gradient of + [I(x+1, y)-[I(x, y)]] are at the (x, y) location """ + _image_gradients_validate(img) + + return _compute_image_gradients(img) diff --git a/pytorch_lightning/metrics/functional/iou.py b/pytorch_lightning/metrics/functional/iou.py index 76f59854ad4bf..7b6851b5cebd0 100644 --- a/pytorch_lightning/metrics/functional/iou.py +++ b/pytorch_lightning/metrics/functional/iou.py @@ -14,12 +14,34 @@ from typing import Optional import torch -from torchmetrics.functional import iou as _iou -from pytorch_lightning.metrics.utils import deprecated_metrics +from pytorch_lightning.metrics.functional.confusion_matrix import _confusion_matrix_update +from pytorch_lightning.metrics.utils import get_num_classes, reduce + + +def _iou_from_confmat( + confmat: torch.Tensor, + num_classes: int, + ignore_index: Optional[int] = None, + absent_score: float = 0.0, + reduction: str = 'elementwise_mean', +): + intersection = torch.diag(confmat) + union = confmat.sum(0) + confmat.sum(1) - intersection + + # If this class is absent in both target AND pred (union == 0), then use the absent_score for this class. + scores = intersection.float() / union.float() + scores[union == 0] = absent_score + + # Remove the ignored class index from the scores. + if ignore_index is not None and ignore_index >= 0 and ignore_index < num_classes: + scores = torch.cat([ + scores[:ignore_index], + scores[ignore_index + 1:], + ]) + return reduce(scores, reduction=reduction) -@deprecated_metrics(target=_iou) def iou( pred: torch.Tensor, target: torch.Tensor, @@ -29,7 +51,60 @@ def iou( num_classes: Optional[int] = None, reduction: str = 'elementwise_mean', ) -> torch.Tensor: + r""" + Computes `Intersection over union, or Jaccard index calculation `_: + + .. math:: J(A,B) = \frac{|A\cap B|}{|A\cup B|} + + Where: :math:`A` and :math:`B` are both tensors of the same size, + containing integer class values. They may be subject to conversion from + input data (see description below). + + Note that it is different from box IoU. + + If preds and target are the same shape and preds is a float tensor, we use the ``self.threshold`` argument + to convert into integer labels. This is the case for binary and multi-label probabilities. + + If pred has an extra dimension as in the case of multi-class scores we + perform an argmax on ``dim=1``. + + Args: + preds: tensor containing predictions from model (probabilities, or labels) with shape ``[N, d1, d2, ...]`` + target: tensor containing ground truth labels with shape ``[N, d1, d2, ...]`` + ignore_index: optional int specifying a target class to ignore. If given, + this class index does not contribute to the returned score, regardless + of reduction method. Has no effect if given an int that is not in the + range [0, num_classes-1], where num_classes is either given or derived + from pred and target. By default, no index is ignored, and all classes are used. + absent_score: score to use for an individual class, if no instances of + the class index were present in `pred` AND no instances of the class + index were present in `target`. For example, if we have 3 classes, + [0, 0] for `pred`, and [0, 2] for `target`, then class 1 would be + assigned the `absent_score`. + threshold: + Threshold value for binary or multi-label probabilities. default: 0.5 + num_classes: + Optionally specify the number of classes + reduction: a method to reduce metric score over labels. 
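For ``image_gradients`` above, only ``dy`` is shown in the docstring example; the sketch below (illustrative only) checks ``dx`` against a manual one-step forward difference padded with a zero column:

    import torch
    from pytorch_lightning.metrics.functional import image_gradients

    image = torch.arange(0, 25, dtype=torch.float32).reshape(1, 1, 5, 5)
    dy, dx = image_gradients(image)
    # forward difference along the width axis, last column padded with zeros
    manual_dx = torch.zeros_like(image)
    manual_dx[..., :, :-1] = image[..., :, 1:] - image[..., :, :-1]
    assert torch.allclose(dx, manual_dx)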
+ + - ``'elementwise_mean'``: takes the mean (default) + - ``'sum'``: takes the sum + - ``'none'``: no reduction will be applied + + Return: + IoU score : Tensor containing single value if reduction is + 'elementwise_mean', or number of classes if reduction is 'none' + + Example: + + >>> from pytorch_lightning.metrics.functional import iou + >>> target = torch.randint(0, 2, (10, 25, 25)) + >>> pred = torch.tensor(target) + >>> pred[2:5, 7:13, 9:15] = 1 - pred[2:5, 7:13, 9:15] + >>> iou(pred, target) + tensor(0.9660) """ - .. deprecated:: - Use :func:`torchmetrics.functional.iou`. Will be removed in v1.5.0. - """ + + num_classes = get_num_classes(pred=pred, target=target, num_classes=num_classes) + confmat = _confusion_matrix_update(pred, target, num_classes, threshold) + return _iou_from_confmat(confmat, num_classes, ignore_index, absent_score, reduction) diff --git a/pytorch_lightning/metrics/functional/mean_absolute_error.py b/pytorch_lightning/metrics/functional/mean_absolute_error.py index 219284d79d623..671368ba240f9 100644 --- a/pytorch_lightning/metrics/functional/mean_absolute_error.py +++ b/pytorch_lightning/metrics/functional/mean_absolute_error.py @@ -11,16 +11,41 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Tuple import torch -from torchmetrics.functional import mean_absolute_error as _mean_absolute_error -from pytorch_lightning.metrics.utils import deprecated_metrics +from pytorch_lightning.metrics.utils import _check_same_shape + + +def _mean_absolute_error_update(preds: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, int]: + _check_same_shape(preds, target) + sum_abs_error = torch.sum(torch.abs(preds - target)) + n_obs = target.numel() + return sum_abs_error, n_obs + + +def _mean_absolute_error_compute(sum_abs_error: torch.Tensor, n_obs: int) -> torch.Tensor: + return sum_abs_error / n_obs -@deprecated_metrics(target=_mean_absolute_error) def mean_absolute_error(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ - .. deprecated:: - Use :func:`torchmetrics.functional.mean_absolute_error`. Will be removed in v1.5.0. + Computes mean absolute error + + Args: + pred: estimated labels + target: ground truth labels + + Return: + Tensor with MAE + + Example: + >>> from pytorch_lightning.metrics.functional import mean_absolute_error + >>> x = torch.tensor([0., 1, 2, 3]) + >>> y = torch.tensor([0., 1, 2, 2]) + >>> mean_absolute_error(x, y) + tensor(0.2500) """ + sum_abs_error, n_obs = _mean_absolute_error_update(preds, target) + return _mean_absolute_error_compute(sum_abs_error, n_obs) diff --git a/pytorch_lightning/metrics/functional/mean_relative_error.py b/pytorch_lightning/metrics/functional/mean_relative_error.py index 329fe040ebc7d..eedaea1a26a4f 100644 --- a/pytorch_lightning/metrics/functional/mean_relative_error.py +++ b/pytorch_lightning/metrics/functional/mean_relative_error.py @@ -11,16 +11,44 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
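An illustrative sketch (not part of the patch) of the ``absent_score`` behaviour of ``iou`` described above, with a class that appears in neither ``pred`` nor ``target``:

    import torch
    from pytorch_lightning.metrics.functional import iou

    target = torch.tensor([0, 0, 2, 2])
    pred = torch.tensor([0, 2, 2, 2])
    # class 1 occurs in neither tensor, so its slot is filled with absent_score
    print(iou(pred, target, absent_score=0.5, num_classes=3, reduction='none'))
    # tensor([0.5000, 0.5000, 0.6667])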
+from typing import Tuple import torch -from torchmetrics.functional.regression.mean_relative_error import mean_relative_error as _mean_relative_error -from pytorch_lightning.metrics.utils import deprecated_metrics +from pytorch_lightning.metrics.utils import _check_same_shape + + +def _mean_relative_error_update(preds: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, int]: + _check_same_shape(preds, target) + target_nz = target.clone() + target_nz[target == 0] = 1 + sum_rltv_error = torch.sum(torch.abs((preds - target) / target_nz)) + n_obs = target.numel() + return sum_rltv_error, n_obs + + +def _mean_relative_error_compute(sum_rltv_error: torch.Tensor, n_obs: int) -> torch.Tensor: + return sum_rltv_error / n_obs -@deprecated_metrics(target=_mean_relative_error) def mean_relative_error(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ - .. deprecated:: - Use :func:`torchmetrics.functional.regression.mean_relative_error`. Will be removed in v1.5.0. + Computes mean relative error + + Args: + pred: estimated labels + target: ground truth labels + + Return: + Tensor with mean relative error + + Example: + + >>> x = torch.tensor([0., 1, 2, 3]) + >>> y = torch.tensor([0., 1, 2, 2]) + >>> mean_relative_error(x, y) + tensor(0.1250) + """ + sum_rltv_error, n_obs = _mean_relative_error_update(preds, target) + return _mean_relative_error_compute(sum_rltv_error, n_obs) diff --git a/pytorch_lightning/metrics/functional/mean_squared_error.py b/pytorch_lightning/metrics/functional/mean_squared_error.py index 5bbc0bb1c6a83..2cdd4ea679043 100644 --- a/pytorch_lightning/metrics/functional/mean_squared_error.py +++ b/pytorch_lightning/metrics/functional/mean_squared_error.py @@ -11,16 +11,41 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Tuple import torch -from torchmetrics.functional import mean_squared_error as _mean_squared_error -from pytorch_lightning.metrics.utils import deprecated_metrics +from pytorch_lightning.metrics.utils import _check_same_shape + + +def _mean_squared_error_update(preds: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, int]: + _check_same_shape(preds, target) + sum_squared_error = torch.sum(torch.pow(preds - target, 2)) + n_obs = target.numel() + return sum_squared_error, n_obs + + +def _mean_squared_error_compute(sum_squared_error: torch.Tensor, n_obs: int) -> torch.Tensor: + return sum_squared_error / n_obs -@deprecated_metrics(target=_mean_squared_error) def mean_squared_error(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ - .. deprecated:: - Use :func:`torchmetrics.functional.mean_squared_error`. Will be removed in v1.5.0. 
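A small check (illustrative only) of the zero-target handling in ``mean_relative_error`` above; the import path follows the module added in this patch:

    import torch
    from pytorch_lightning.metrics.functional.mean_relative_error import mean_relative_error

    x = torch.tensor([0., 1, 2, 3])
    y = torch.tensor([0., 1, 2, 2])
    # zero targets are replaced by 1 in the denominator, as in _mean_relative_error_update
    denom = torch.where(y == 0, torch.ones_like(y), y)
    manual = (x - y).abs().div(denom).mean()
    assert torch.allclose(mean_relative_error(x, y), manual)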
+ Computes mean squared error + + Args: + preds: estimated labels + target: ground truth labels + + Return: + Tensor with MSE + + Example: + >>> from pytorch_lightning.metrics.functional import mean_squared_error + >>> x = torch.tensor([0., 1, 2, 3]) + >>> y = torch.tensor([0., 1, 2, 2]) + >>> mean_squared_error(x, y) + tensor(0.2500) """ + sum_squared_error, n_obs = _mean_squared_error_update(preds, target) + return _mean_squared_error_compute(sum_squared_error, n_obs) diff --git a/pytorch_lightning/metrics/functional/mean_squared_log_error.py b/pytorch_lightning/metrics/functional/mean_squared_log_error.py index 29786529381d5..45c255eb61d78 100644 --- a/pytorch_lightning/metrics/functional/mean_squared_log_error.py +++ b/pytorch_lightning/metrics/functional/mean_squared_log_error.py @@ -11,16 +11,41 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Tuple import torch -from torchmetrics.functional import mean_squared_log_error as _mean_squared_log_error -from pytorch_lightning.metrics.utils import deprecated_metrics +from pytorch_lightning.metrics.utils import _check_same_shape + + +def _mean_squared_log_error_update(preds: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, int]: + _check_same_shape(preds, target) + sum_squared_log_error = torch.sum(torch.pow(torch.log1p(preds) - torch.log1p(target), 2)) + n_obs = target.numel() + return sum_squared_log_error, n_obs + + +def _mean_squared_log_error_compute(sum_squared_log_error: torch.Tensor, n_obs: int) -> torch.Tensor: + return sum_squared_log_error / n_obs -@deprecated_metrics(target=_mean_squared_log_error) def mean_squared_log_error(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ - .. deprecated:: - Use :func:`torchmetrics.functional.mean_squared_log_error`. Will be removed in v1.5.0. 
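For ``mean_squared_error`` above, a one-line cross-check (illustrative only) against ``torch.nn.functional.mse_loss`` with its default mean reduction:

    import torch
    import torch.nn.functional as F
    from pytorch_lightning.metrics.functional import mean_squared_error

    x = torch.tensor([0., 1, 2, 3])
    y = torch.tensor([0., 1, 2, 2])
    # both compute the mean of squared element-wise differences
    assert torch.allclose(mean_squared_error(x, y), F.mse_loss(x, y))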
+ Computes mean squared log error + + Args: + preds: estimated labels + target: ground truth labels + + Return: + Tensor with RMSLE + + Example: + >>> from pytorch_lightning.metrics.functional import mean_squared_log_error + >>> x = torch.tensor([0., 1, 2, 3]) + >>> y = torch.tensor([0., 1, 2, 2]) + >>> mean_squared_log_error(x, y) + tensor(0.0207) """ + sum_squared_log_error, n_obs = _mean_squared_log_error_update(preds, target) + return _mean_squared_log_error_compute(sum_squared_log_error, n_obs) diff --git a/pytorch_lightning/metrics/functional/nlp.py b/pytorch_lightning/metrics/functional/nlp.py index c59d7cf2b8976..b1466c66112bc 100644 --- a/pytorch_lightning/metrics/functional/nlp.py +++ b/pytorch_lightning/metrics/functional/nlp.py @@ -16,15 +16,34 @@ # Authors: torchtext authors and @sluks # Date: 2020-07-18 # Link: https://pytorch.org/text/_modules/torchtext/data/metrics.html#bleu_score -from typing import Sequence +from collections import Counter +from typing import List, Sequence import torch -from torchmetrics.functional import bleu_score as _bleu_score -from pytorch_lightning.metrics.utils import deprecated_metrics + +def _count_ngram(ngram_input_list: List[str], n_gram: int) -> Counter: + """ + Counting how many times each word appears in a given text with ngram + + Args: + ngram_input_list: A list of translated text or reference texts + n_gram: gram value ranged 1 to 4 + + Return: + ngram_counter: a collections.Counter object of ngram + """ + + ngram_counter = Counter() + + for i in range(1, n_gram + 1): + for j in range(len(ngram_input_list) - i + 1): + ngram_key = tuple(ngram_input_list[j:(i + j)]) + ngram_counter[ngram_key] += 1 + + return ngram_counter -@deprecated_metrics(target=_bleu_score) def bleu_score( translate_corpus: Sequence[str], reference_corpus: Sequence[str], @@ -32,6 +51,64 @@ def bleu_score( smooth: bool = False ) -> torch.Tensor: """ - .. deprecated:: - Use :func:`torchmetrics.functional.bleu_score`. Will be removed in v1.5.0. + Calculate BLEU score of machine translated text with one or more references + + Args: + translate_corpus: An iterable of machine translated corpus + reference_corpus: An iterable of iterables of reference corpus + n_gram: Gram value ranged from 1 to 4 (Default 4) + smooth: Whether or not to apply smoothing – Lin et al. 
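``mean_squared_log_error`` above is the mean squared error of ``log1p``-transformed inputs; the illustrative check below makes that explicit:

    import torch
    from pytorch_lightning.metrics.functional import mean_squared_error, mean_squared_log_error

    x = torch.tensor([0., 1, 2, 3])
    y = torch.tensor([0., 1, 2, 2])
    assert torch.allclose(mean_squared_log_error(x, y),
                          mean_squared_error(torch.log1p(x), torch.log1p(y)))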
2004 + + Return: + Tensor with BLEU Score + + Example: + >>> from pytorch_lightning.metrics.functional import bleu_score + >>> translate_corpus = ['the cat is on the mat'.split()] + >>> reference_corpus = [['there is a cat on the mat'.split(), 'a cat is on the mat'.split()]] + >>> bleu_score(translate_corpus, reference_corpus) + tensor(0.7598) """ + + assert len(translate_corpus) == len(reference_corpus) + numerator = torch.zeros(n_gram) + denominator = torch.zeros(n_gram) + c = 0.0 + r = 0.0 + + for (translation, references) in zip(translate_corpus, reference_corpus): + c += len(translation) + ref_len_list = [len(ref) for ref in references] + ref_len_diff = [abs(len(translation) - x) for x in ref_len_list] + r += ref_len_list[ref_len_diff.index(min(ref_len_diff))] + translation_counter = _count_ngram(translation, n_gram) + reference_counter = Counter() + + for ref in references: + reference_counter |= _count_ngram(ref, n_gram) + + ngram_counter_clip = translation_counter & reference_counter + + for counter_clip in ngram_counter_clip: + numerator[len(counter_clip) - 1] += ngram_counter_clip[counter_clip] + + for counter in translation_counter: + denominator[len(counter) - 1] += translation_counter[counter] + + trans_len = torch.tensor(c) + ref_len = torch.tensor(r) + + if min(numerator) == 0.0: + return torch.tensor(0.0) + + if smooth: + precision_scores = torch.add(numerator, torch.ones(n_gram)) / torch.add(denominator, torch.ones(n_gram)) + else: + precision_scores = numerator / denominator + + log_precision_scores = torch.tensor([1.0 / n_gram] * n_gram) * torch.log(precision_scores) + geometric_mean = torch.exp(torch.sum(log_precision_scores)) + brevity_penalty = torch.tensor(1.0) if c > r else torch.exp(1 - (ref_len / trans_len)) + bleu = brevity_penalty * geometric_mean + + return bleu diff --git a/pytorch_lightning/metrics/functional/precision_recall.py b/pytorch_lightning/metrics/functional/precision_recall.py index 7b6c8641b5829..6f5aafd79d109 100644 --- a/pytorch_lightning/metrics/functional/precision_recall.py +++ b/pytorch_lightning/metrics/functional/precision_recall.py @@ -14,14 +14,29 @@ from typing import Optional import torch -from torchmetrics.functional import precision as _precision -from torchmetrics.functional import precision_recall as _precision_recall -from torchmetrics.functional import recall as _recall -from pytorch_lightning.metrics.utils import deprecated_metrics +from pytorch_lightning.metrics.classification.helpers import _reduce_stat_scores +from pytorch_lightning.metrics.functional.stat_scores import _stat_scores_update +from pytorch_lightning.utilities import rank_zero_warn + + +def _precision_compute( + tp: torch.Tensor, + fp: torch.Tensor, + tn: torch.Tensor, + fn: torch.Tensor, + average: str, + mdmc_average: Optional[str], +) -> torch.Tensor: + return _reduce_stat_scores( + numerator=tp, + denominator=tp + fp, + weights=None if average != "weighted" else tp + fn, + average=average, + mdmc_average=mdmc_average, + ) -@deprecated_metrics(target=_precision) def precision( preds: torch.Tensor, target: torch.Tensor, @@ -32,14 +47,158 @@ def precision( threshold: float = 0.5, top_k: Optional[int] = None, is_multiclass: Optional[bool] = None, + class_reduction: Optional[str] = None, ) -> torch.Tensor: + r""" + Computes `Precision `_: + + .. math:: \text{Precision} = \frac{\text{TP}}{\text{TP} + \text{FP}} + + Where :math:`\text{TP}` and :math:`\text{FP}` represent the number of true positives and + false positives respecitively. 
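A corpus-level usage sketch (illustrative only) for ``bleu_score`` above, with two translated sentences and a per-sentence set of references:

    from pytorch_lightning.metrics.functional import bleu_score

    translate_corpus = ['the cat is on the mat'.split(),
                        'there is a cat on the mat'.split()]
    reference_corpus = [['there is a cat on the mat'.split(), 'a cat is on the mat'.split()],
                        ['there is a cat on the mat'.split()]]
    # n-gram statistics are aggregated over the whole corpus before the score is computed
    print(bleu_score(translate_corpus, reference_corpus, n_gram=2))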
With the use of ``top_k`` parameter, this metric can + generalize to Precision@K. + + The reduction method (how the precision scores are aggregated) is controlled by the + ``average`` parameter, and additionally by the ``mdmc_average`` parameter in the + multi-dimensional multi-class case. Accepts all inputs listed in :ref:`extensions/metrics:input types`. + + Args: + preds: Predictions from model (probabilities or labels) + target: Ground truth values + average: + Defines the reduction that is applied. Should be one of the following: + + - ``'micro'`` [default]: Calculate the metric globally, accross all samples and classes. + - ``'macro'``: Calculate the metric for each class separately, and average the + metrics accross classes (with equal weights for each class). + - ``'weighted'``: Calculate the metric for each class separately, and average the + metrics accross classes, weighting each class by its support (``tp + fn``). + - ``'none'`` or ``None``: Calculate the metric for each class separately, and return + the metric for every class. + - ``'samples'``: Calculate the metric for each sample, and average the metrics + across samples (with equal weights for each sample). + + Note that what is considered a sample in the multi-dimensional multi-class case + depends on the value of ``mdmc_average``. + + mdmc_average: + Defines how averaging is done for multi-dimensional multi-class inputs (on top of the + ``average`` parameter). Should be one of the following: + + - ``None`` [default]: Should be left unchanged if your data is not multi-dimensional + multi-class. + + - ``'samplewise'``: In this case, the statistics are computed separately for each + sample on the ``N`` axis, and then averaged over samples. + The computation for each sample is done by treating the flattened extra axes ``...`` + (see :ref:`extensions/metrics:input types`) as the ``N`` dimension within the sample, + and computing the metric for the sample based on that. + + - ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs + (see :ref:`extensions/metrics:input types`) + are flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they + were ``(N_X, C)``. From here on the ``average`` parameter applies as usual. + + ignore_index: + Integer specifying a target class to ignore. If given, this class index does not contribute + to the returned score, regardless of reduction method. If an index is ignored, and ``average=None`` + or ``'none'``, the score for the ignored class will be returned as ``nan``. + + num_classes: + Number of classes. Necessary for ``'macro'``, ``'weighted'`` and ``None`` average methods. + + threshold: + Threshold probability value for transforming probability predictions to binary + (0,1) predictions, in the case of binary or multi-label inputs. + top_k: + Number of highest probability entries for each sample to convert to 1s - relevant + only for inputs with probability predictions. If this parameter is set for multi-label + inputs, it will take precedence over ``threshold``. For (multi-dim) multi-class inputs, + this parameter defaults to 1. + + Should be left unset (``None``) for inputs with label predictions. + is_multiclass: + Used only in certain special cases, where you want to treat inputs as a different type + than what they appear to be. See the parameter's + :ref:`documentation section ` + for a more detailed explanation and examples. + + class_reduction: + .. warning :: This parameter is deprecated, use ``average``. Will be removed in v1.4.0. 
+ + Return: + The shape of the returned tensor depends on the ``average`` parameter + + - If ``average in ['micro', 'macro', 'weighted', 'samples']``, a one-element tensor will be returned + - If ``average in ['none', None]``, the shape will be ``(C,)``, where ``C`` stands for the number + of classes + + Example: + + >>> from pytorch_lightning.metrics.functional import precision + >>> preds = torch.tensor([2, 0, 2, 1]) + >>> target = torch.tensor([1, 1, 2, 0]) + >>> precision(preds, target, average='macro', num_classes=3) + tensor(0.1667) + >>> precision(preds, target, average='micro') + tensor(0.2500) + """ - .. deprecated:: - Use :func:`torchmetrics.functional.precision`. Will be removed in v1.5.0. - """ + if class_reduction: + rank_zero_warn( + "This `class_reduction` parameter was deprecated in v1.2.0 in favor of" + " `reduce`. It will be removed in v1.4.0", + DeprecationWarning, + ) + average = class_reduction + + allowed_average = ["micro", "macro", "weighted", "samples", "none", None] + if average not in allowed_average: + raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") + + allowed_mdmc_average = [None, "samplewise", "global"] + if mdmc_average not in allowed_mdmc_average: + raise ValueError(f"The `mdmc_average` has to be one of {allowed_mdmc_average}, got {mdmc_average}.") + + if average in ["macro", "weighted", "none", None] and (not num_classes or num_classes < 1): + raise ValueError(f"When you set `average` as {average}, you have to provide the number of classes.") + + if num_classes and ignore_index is not None and (not 0 <= ignore_index < num_classes or num_classes == 1): + raise ValueError(f"The `ignore_index` {ignore_index} is not valid for inputs with {num_classes} classes") + + reduce = "macro" if average in ["weighted", "none", None] else average + tp, fp, tn, fn = _stat_scores_update( + preds, + target, + reduce=reduce, + mdmc_reduce=mdmc_average, + threshold=threshold, + num_classes=num_classes, + top_k=top_k, + is_multiclass=is_multiclass, + ignore_index=ignore_index, + ) + + return _precision_compute(tp, fp, tn, fn, average, mdmc_average) + + +def _recall_compute( + tp: torch.Tensor, + fp: torch.Tensor, + tn: torch.Tensor, + fn: torch.Tensor, + average: str, + mdmc_average: Optional[str], +) -> torch.Tensor: + return _reduce_stat_scores( + numerator=tp, + denominator=tp + fn, + weights=None if average != "weighted" else tp + fn, + average=average, + mdmc_average=mdmc_average, + ) -@deprecated_metrics(target=_recall) def recall( preds: torch.Tensor, target: torch.Tensor, @@ -50,14 +209,141 @@ def recall( threshold: float = 0.5, top_k: Optional[int] = None, is_multiclass: Optional[bool] = None, + class_reduction: Optional[str] = None, ) -> torch.Tensor: + r""" + Computes `Recall `_: + + .. math:: \text{Recall} = \frac{\text{TP}}{\text{TP} + \text{FN}} + + Where :math:`\text{TP}` and :math:`\text{FN}` represent the number of true positives and + false negatives respecitively. With the use of ``top_k`` parameter, this metric can + generalize to Recall@K. + + The reduction method (how the recall scores are aggregated) is controlled by the + ``average`` parameter, and additionally by the ``mdmc_average`` parameter in the + multi-dimensional multi-class case. Accepts all inputs listed in :ref:`extensions/metrics:input types`. + + Args: + preds: Predictions from model (probabilities, or labels) + target: Ground truth values + average: + Defines the reduction that is applied. 
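An illustrative sketch of per-class output for ``precision`` above with ``average='none'`` (the docstring example covers ``'macro'`` and ``'micro'``):

    import torch
    from pytorch_lightning.metrics.functional import precision

    preds = torch.tensor([2, 0, 2, 1])
    target = torch.tensor([1, 1, 2, 0])
    # shape (3,): one precision score per class, as described in the Return section
    print(precision(preds, target, average='none', num_classes=3))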
Should be one of the following: + + - ``'micro'`` [default]: Calculate the metric globally, accross all samples and classes. + - ``'macro'``: Calculate the metric for each class separately, and average the + metrics accross classes (with equal weights for each class). + - ``'weighted'``: Calculate the metric for each class separately, and average the + metrics accross classes, weighting each class by its support (``tp + fn``). + - ``'none'`` or ``None``: Calculate the metric for each class separately, and return + the metric for every class. + - ``'samples'``: Calculate the metric for each sample, and average the metrics + across samples (with equal weights for each sample). + + Note that what is considered a sample in the multi-dimensional multi-class case + depends on the value of ``mdmc_average``. + + mdmc_average: + Defines how averaging is done for multi-dimensional multi-class inputs (on top of the + ``average`` parameter). Should be one of the following: + + - ``None`` [default]: Should be left unchanged if your data is not multi-dimensional + multi-class. + + - ``'samplewise'``: In this case, the statistics are computed separately for each + sample on the ``N`` axis, and then averaged over samples. + The computation for each sample is done by treating the flattened extra axes ``...`` + (see :ref:`extensions/metrics:input types`) as the ``N`` dimension within the sample, + and computing the metric for the sample based on that. + + - ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs + (see :ref:`extensions/metrics:input types`) + are flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they + were ``(N_X, C)``. From here on the ``average`` parameter applies as usual. + + ignore_index: + Integer specifying a target class to ignore. If given, this class index does not contribute + to the returned score, regardless of reduction method. If an index is ignored, and ``average=None`` + or ``'none'``, the score for the ignored class will be returned as ``nan``. + + num_classes: + Number of classes. Necessary for ``'macro'``, ``'weighted'`` and ``None`` average methods. + + threshold: + Threshold probability value for transforming probability predictions to binary + (0,1) predictions, in the case of binary or multi-label inputs + top_k: + Number of highest probability entries for each sample to convert to 1s - relevant + only for inputs with probability predictions. If this parameter is set for multi-label + inputs, it will take precedence over ``threshold``. For (multi-dim) multi-class inputs, + this parameter defaults to 1. + + Should be left unset (``None``) for inputs with label predictions. + is_multiclass: + Used only in certain special cases, where you want to treat inputs as a different type + than what they appear to be. See the parameter's + :ref:`documentation section ` + for a more detailed explanation and examples. + + class_reduction: + .. warning :: This parameter is deprecated, use ``average``. Will be removed in v1.4.0. 
+ + Return: + The shape of the returned tensor depends on the ``average`` parameter + + - If ``average in ['micro', 'macro', 'weighted', 'samples']``, a one-element tensor will be returned + - If ``average in ['none', None]``, the shape will be ``(C,)``, where ``C`` stands for the number + of classes + + Example: + + >>> from pytorch_lightning.metrics.functional import recall + >>> preds = torch.tensor([2, 0, 2, 1]) + >>> target = torch.tensor([1, 1, 2, 0]) + >>> recall(preds, target, average='macro', num_classes=3) + tensor(0.3333) + >>> recall(preds, target, average='micro') + tensor(0.2500) + """ - .. deprecated:: - Use :func:`torchmetrics.functional.accuracy`. Will be removed in v1.5.0. - """ + if class_reduction: + rank_zero_warn( + "This `class_reduction` parameter was deprecated in v1.2.0 in favor of" + " `reduce`. It will be removed in v1.4.0", + DeprecationWarning, + ) + average = class_reduction + + allowed_average = ["micro", "macro", "weighted", "samples", "none", None] + if average not in allowed_average: + raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") + + allowed_mdmc_average = [None, "samplewise", "global"] + if mdmc_average not in allowed_mdmc_average: + raise ValueError("The `mdmc_average` has to be one of {allowed_mdmc_average}, got {mdmc_average}.") + + if average in ["macro", "weighted", "none", None] and (not num_classes or num_classes < 1): + raise ValueError(f"When you set `average` as {average}, you have to provide the number of classes.") + + if num_classes and ignore_index is not None and (not 0 <= ignore_index < num_classes or num_classes == 1): + raise ValueError(f"The `ignore_index` {ignore_index} is not valid for inputs with {num_classes} classes") + + reduce = "macro" if average in ["weighted", "none", None] else average + tp, fp, tn, fn = _stat_scores_update( + preds, + target, + reduce=reduce, + mdmc_reduce=mdmc_average, + threshold=threshold, + num_classes=num_classes, + top_k=top_k, + is_multiclass=is_multiclass, + ignore_index=ignore_index, + ) + + return _recall_compute(tp, fp, tn, fn, average, mdmc_average) -@deprecated_metrics(target=_precision_recall) def precision_recall( preds: torch.Tensor, target: torch.Tensor, @@ -68,8 +354,143 @@ def precision_recall( threshold: float = 0.5, top_k: Optional[int] = None, is_multiclass: Optional[bool] = None, + class_reduction: Optional[str] = None, ) -> torch.Tensor: + r""" + Computes `Precision and Recall `_: + + .. math:: \text{Precision} = \frac{\text{TP}}{\text{TP} + \text{FP}} + + + .. math:: \text{Recall} = \frac{\text{TP}}{\text{TP} + \text{FN}} + + Where :math:`\text{TP}`m :math:`\text{FN}` and :math:`\text{FP}` represent the number + of true positives, false negatives and false positives respecitively. With the use of + ``top_k`` parameter, this metric can generalize to Recall@K and Precision@K. + + The reduction method (how the recall scores are aggregated) is controlled by the + ``average`` parameter, and additionally by the ``mdmc_average`` parameter in the + multi-dimensional multi-class case. Accepts all inputs listed in :ref:`extensions/metrics:input types`. + + Args: + preds: Predictions from model (probabilities, or labels) + target: Ground truth values + average: + Defines the reduction that is applied. Should be one of the following: + + - ``'micro'`` [default]: Calculate the metric globally, accross all samples and classes. 
+ - ``'macro'``: Calculate the metric for each class separately, and average the + metrics accross classes (with equal weights for each class). + - ``'weighted'``: Calculate the metric for each class separately, and average the + metrics accross classes, weighting each class by its support (``tp + fn``). + - ``'none'`` or ``None``: Calculate the metric for each class separately, and return + the metric for every class. + - ``'samples'``: Calculate the metric for each sample, and average the metrics + across samples (with equal weights for each sample). + + Note that what is considered a sample in the multi-dimensional multi-class case + depends on the value of ``mdmc_average``. + + mdmc_average: + Defines how averaging is done for multi-dimensional multi-class inputs (on top of the + ``average`` parameter). Should be one of the following: + + - ``None`` [default]: Should be left unchanged if your data is not multi-dimensional + multi-class. + + - ``'samplewise'``: In this case, the statistics are computed separately for each + sample on the ``N`` axis, and then averaged over samples. + The computation for each sample is done by treating the flattened extra axes ``...`` + (see :ref:`extensions/metrics:input types`) as the ``N`` dimension within the sample, + and computing the metric for the sample based on that. + + - ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs + (see :ref:`extensions/metrics:input types`) + are flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they + were ``(N_X, C)``. From here on the ``average`` parameter applies as usual. + + ignore_index: + Integer specifying a target class to ignore. If given, this class index does not contribute + to the returned score, regardless of reduction method. If an index is ignored, and ``average=None`` + or ``'none'``, the score for the ignored class will be returned as ``nan``. + + num_classes: + Number of classes. Necessary for ``'macro'``, ``'weighted'`` and ``None`` average methods. + + threshold: + Threshold probability value for transforming probability predictions to binary + (0,1) predictions, in the case of binary or multi-label inputs + top_k: + Number of highest probability entries for each sample to convert to 1s - relevant + only for inputs with probability predictions. If this parameter is set for multi-label + inputs, it will take precedence over ``threshold``. For (multi-dim) multi-class inputs, + this parameter defaults to 1. + + Should be left unset (``None``) for inputs with label predictions. + is_multiclass: + Used only in certain special cases, where you want to treat inputs as a different type + than what they appear to be. See the parameter's + :ref:`documentation section ` + for a more detailed explanation and examples. + + class_reduction: + .. warning :: This parameter is deprecated, use ``average``. Will be removed in v1.4.0. + + Return: + The function returns a tuple with two elements: precision and recall. 
Their shape + depends on the ``average`` parameter + + - If ``average in ['micro', 'macro', 'weighted', 'samples']``, they are a single element tensor + - If ``average in ['none', None]``, they are a tensor of shape ``(C, )``, where ``C`` stands for + the number of classes + + Example: + + >>> from pytorch_lightning.metrics.functional import precision_recall + >>> preds = torch.tensor([2, 0, 2, 1]) + >>> target = torch.tensor([1, 1, 2, 0]) + >>> precision_recall(preds, target, average='macro', num_classes=3) + (tensor(0.1667), tensor(0.3333)) + >>> precision_recall(preds, target, average='micro') + (tensor(0.2500), tensor(0.2500)) + """ - .. deprecated:: - Use :func:`torchmetrics.functional.precision_recall`. Will be removed in v1.5.0. - """ + if class_reduction: + rank_zero_warn( + "This `class_reduction` parameter was deprecated in v1.2.0 in favor of" + " `reduce`. It will be removed in v1.4.0", + DeprecationWarning, + ) + average = class_reduction + + allowed_average = ["micro", "macro", "weighted", "samples", "none", None] + if average not in allowed_average: + raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") + + allowed_mdmc_average = [None, "samplewise", "global"] + if mdmc_average not in allowed_mdmc_average: + raise ValueError("The `mdmc_average` has to be one of {allowed_mdmc_average}, got {mdmc_average}.") + + if average in ["macro", "weighted", "none", None] and (not num_classes or num_classes < 1): + raise ValueError(f"When you set `average` as {average}, you have to provide the number of classes.") + + if num_classes and ignore_index is not None and (not 0 <= ignore_index < num_classes or num_classes == 1): + raise ValueError(f"The `ignore_index` {ignore_index} is not valid for inputs with {num_classes} classes") + + reduce = "macro" if average in ["weighted", "none", None] else average + tp, fp, tn, fn = _stat_scores_update( + preds, + target, + reduce=reduce, + mdmc_reduce=mdmc_average, + threshold=threshold, + num_classes=num_classes, + top_k=top_k, + is_multiclass=is_multiclass, + ignore_index=ignore_index, + ) + + precision = _precision_compute(tp, fp, tn, fn, average, mdmc_average) + recall = _recall_compute(tp, fp, tn, fn, average, mdmc_average) + + return precision, recall diff --git a/pytorch_lightning/metrics/functional/precision_recall_curve.py b/pytorch_lightning/metrics/functional/precision_recall_curve.py index dc9863cbb47c4..fb442b020af88 100644 --- a/pytorch_lightning/metrics/functional/precision_recall_curve.py +++ b/pytorch_lightning/metrics/functional/precision_recall_curve.py @@ -14,12 +14,140 @@ from typing import List, Optional, Sequence, Tuple, Union import torch -from torchmetrics.functional import precision_recall_curve as _precision_recall_curve +import torch.nn.functional as F -from pytorch_lightning.metrics.utils import deprecated_metrics +from pytorch_lightning.utilities import rank_zero_warn + + +def _binary_clf_curve( + preds: torch.Tensor, + target: torch.Tensor, + sample_weights: Optional[Sequence] = None, + pos_label: int = 1., +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + adapted from https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/metrics/_ranking.py + """ + if sample_weights is not None and not isinstance(sample_weights, torch.Tensor): + sample_weights = torch.tensor(sample_weights, device=preds.device, dtype=torch.float) + + # remove class dimension if necessary + if preds.ndim > target.ndim: + preds = preds[:, 0] + desc_score_indices = torch.argsort(preds, 
descending=True) + + preds = preds[desc_score_indices] + target = target[desc_score_indices] + + if sample_weights is not None: + weight = sample_weights[desc_score_indices] + else: + weight = 1. + + # pred typically has many tied values. Here we extract + # the indices associated with the distinct values. We also + # concatenate a value for the end of the curve. + distinct_value_indices = torch.where(preds[1:] - preds[:-1])[0] + threshold_idxs = F.pad(distinct_value_indices, (0, 1), value=target.size(0) - 1) + target = (target == pos_label).to(torch.long) + tps = torch.cumsum(target * weight, dim=0)[threshold_idxs] + + if sample_weights is not None: + # express fps as a cumsum to ensure fps is increasing even in + # the presence of floating point errors + fps = torch.cumsum((1 - target) * weight, dim=0)[threshold_idxs] + else: + fps = 1 + threshold_idxs - tps + + return fps, tps, preds[threshold_idxs] + + +def _precision_recall_curve_update( + preds: torch.Tensor, + target: torch.Tensor, + num_classes: Optional[int] = None, + pos_label: Optional[int] = None, +) -> Tuple[torch.Tensor, torch.Tensor, int, int]: + if not (len(preds.shape) == len(target.shape) or len(preds.shape) == len(target.shape) + 1): + raise ValueError("preds and target must have same number of dimensions, or one additional dimension for preds") + # single class evaluation + if len(preds.shape) == len(target.shape): + num_classes = 1 + if pos_label is None: + rank_zero_warn('`pos_label` automatically set 1.') + pos_label = 1 + preds = preds.flatten() + target = target.flatten() + + # multi class evaluation + if len(preds.shape) == len(target.shape) + 1: + if pos_label is not None: + rank_zero_warn( + 'Argument `pos_label` should be `None` when running' + f' multiclass precision recall curve. 
Got {pos_label}' + ) + if num_classes != preds.shape[1]: + raise ValueError( + f'Argument `num_classes` was set to {num_classes} in' + f' metric `precision_recall_curve` but detected {preds.shape[1]}' + ' number of classes from predictions' + ) + preds = preds.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1) + target = target.flatten() + + return preds, target, num_classes, pos_label + + +def _precision_recall_curve_compute( + preds: torch.Tensor, + target: torch.Tensor, + num_classes: int, + pos_label: int, + sample_weights: Optional[Sequence] = None, +) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], Tuple[List[torch.Tensor], List[torch.Tensor], + List[torch.Tensor]]]: + + if num_classes == 1: + fps, tps, thresholds = _binary_clf_curve( + preds=preds, target=target, sample_weights=sample_weights, pos_label=pos_label + ) + + precision = tps / (tps + fps) + recall = tps / tps[-1] + + # stop when full recall attained + # and reverse the outputs so recall is decreasing + last_ind = torch.where(tps == tps[-1])[0][0] + sl = slice(0, last_ind.item() + 1) + + # need to call reversed explicitly, since including that to slice would + # introduce negative strides that are not yet supported in pytorch + precision = torch.cat([reversed(precision[sl]), torch.ones(1, dtype=precision.dtype, device=precision.device)]) + + recall = torch.cat([reversed(recall[sl]), torch.zeros(1, dtype=recall.dtype, device=recall.device)]) + + thresholds = reversed(thresholds[sl]).clone() + + return precision, recall, thresholds + + # Recursively call per class + precision, recall, thresholds = [], [], [] + for c in range(num_classes): + preds_c = preds[:, c] + res = precision_recall_curve( + preds=preds_c, + target=target, + num_classes=1, + pos_label=c, + sample_weights=sample_weights, + ) + precision.append(res[0]) + recall.append(res[1]) + thresholds.append(res[2]) + + return precision, recall, thresholds -@deprecated_metrics(target=_precision_recall_curve) def precision_recall_curve( preds: torch.Tensor, target: torch.Tensor, @@ -27,8 +155,64 @@ def precision_recall_curve( pos_label: Optional[int] = None, sample_weights: Optional[Sequence] = None, ) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], Tuple[List[torch.Tensor], List[torch.Tensor], - List[torch.Tensor]], ]: + List[torch.Tensor]]]: """ - .. deprecated:: - Use :func:`torchmetrics.functional.accuracy`. Will be removed in v1.5.0. + Computes precision-recall pairs for different thresholds. + + Args: + preds: predictions from model (probabilities) + target: ground truth labels + num_classes: integer with number of classes. Not nessesary to provide + for binary problems. + pos_label: integer determining the positive class. Default is ``None`` + which for binary problem is translate to 1. For multiclass problems + this argument should not be set as we iteratively change it in the + range [0,num_classes-1] + sample_weights: sample weights for each data point + + Returns: + 3-element tuple containing + + precision: + tensor where element i is the precision of predictions with + score >= thresholds[i] and the last element is 1. + If multiclass, this is a list of such tensors, one for each class. + recall: + tensor where element i is the recall of predictions with + score >= thresholds[i] and the last element is 0. + If multiclass, this is a list of such tensors, one for each class. 
+ thresholds: + Thresholds used for computing precision/recall scores + + Example (binary case): + + >>> from pytorch_lightning.metrics.functional import precision_recall_curve + >>> pred = torch.tensor([0, 1, 2, 3]) + >>> target = torch.tensor([0, 1, 1, 0]) + >>> precision, recall, thresholds = precision_recall_curve(pred, target, pos_label=1) + >>> precision + tensor([0.6667, 0.5000, 0.0000, 1.0000]) + >>> recall + tensor([1.0000, 0.5000, 0.0000, 0.0000]) + >>> thresholds + tensor([1, 2, 3]) + + Example (multiclass case): + + >>> from pytorch_lightning.metrics.functional import precision_recall_curve + >>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05], + ... [0.05, 0.75, 0.05, 0.05, 0.05], + ... [0.05, 0.05, 0.75, 0.05, 0.05], + ... [0.05, 0.05, 0.05, 0.75, 0.05]]) + >>> target = torch.tensor([0, 1, 3, 2]) + >>> precision, recall, thresholds = precision_recall_curve(pred, target, num_classes=5) + >>> precision # doctest: +NORMALIZE_WHITESPACE + [tensor([1., 1.]), tensor([1., 1.]), tensor([0.2500, 0.0000, 1.0000]), + tensor([0.2500, 0.0000, 1.0000]), tensor([0., 1.])] + >>> recall + [tensor([1., 0.]), tensor([1., 0.]), tensor([1., 0., 0.]), tensor([1., 0., 0.]), tensor([nan, 0.])] + >>> thresholds + [tensor([0.7500]), tensor([0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500])] """ + preds, target, num_classes, pos_label = _precision_recall_curve_update(preds, target, num_classes, pos_label) + return _precision_recall_curve_compute(preds, target, num_classes, pos_label, sample_weights) diff --git a/pytorch_lightning/metrics/functional/psnr.py b/pytorch_lightning/metrics/functional/psnr.py index 51be9d47b91f9..bd513d4ca21dd 100644 --- a/pytorch_lightning/metrics/functional/psnr.py +++ b/pytorch_lightning/metrics/functional/psnr.py @@ -14,12 +14,46 @@ from typing import Optional, Tuple, Union import torch -from torchmetrics.functional import psnr as _psnr -from pytorch_lightning.metrics.utils import deprecated_metrics +from pytorch_lightning import utilities +from pytorch_lightning.metrics import utils + + +def _psnr_compute( + sum_squared_error: torch.Tensor, + n_obs: torch.Tensor, + data_range: torch.Tensor, + base: float = 10.0, + reduction: str = 'elementwise_mean', +) -> torch.Tensor: + psnr_base_e = 2 * torch.log(data_range) - torch.log(sum_squared_error / n_obs) + psnr = psnr_base_e * (10 / torch.log(torch.tensor(base))) + return utils.reduce(psnr, reduction=reduction) + + +def _psnr_update(preds: torch.Tensor, + target: torch.Tensor, + dim: Optional[Union[int, Tuple[int, ...]]] = None) -> Tuple[torch.Tensor, torch.Tensor]: + if dim is None: + sum_squared_error = torch.sum(torch.pow(preds - target, 2)) + n_obs = torch.tensor(target.numel(), device=target.device) + return sum_squared_error, n_obs + + sum_squared_error = torch.sum(torch.pow(preds - target, 2), dim=dim) + + if isinstance(dim, int): + dim_list = [dim] + else: + dim_list = list(dim) + if not dim_list: + n_obs = torch.tensor(target.numel(), device=target.device) + else: + n_obs = torch.tensor(target.size(), device=target.device)[dim_list].prod() + n_obs = n_obs.expand_as(sum_squared_error) + + return sum_squared_error, n_obs -@deprecated_metrics(target=_psnr) def psnr( preds: torch.Tensor, target: torch.Tensor, @@ -29,6 +63,46 @@ def psnr( dim: Optional[Union[int, Tuple[int, ...]]] = None, ) -> torch.Tensor: """ - .. deprecated:: - Use :func:`torchmetrics.functional.psnr`. Will be removed in v1.5.0. 
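A short illustrative follow-up for ``precision_recall_curve`` above: picking the threshold with the best F1 score from the returned curve. Note that ``precision`` and ``recall`` carry one extra end-of-curve element, so they are truncated to align with ``thresholds``:

    import torch
    from pytorch_lightning.metrics.functional import precision_recall_curve

    pred = torch.tensor([0.1, 0.4, 0.35, 0.8])
    target = torch.tensor([0, 0, 1, 1])
    precision, recall, thresholds = precision_recall_curve(pred, target, pos_label=1)
    # drop the appended end-of-curve points so the arrays align with thresholds
    f1_scores = 2 * precision[:-1] * recall[:-1] / (precision[:-1] + recall[:-1])
    best_threshold = thresholds[torch.argmax(f1_scores)]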
+ Computes the peak signal-to-noise ratio + + Args: + preds: estimated signal + target: groun truth signal + data_range: + the range of the data. If None, it is determined from the data (max - min). ``data_range`` must be given + when ``dim`` is not None. + base: a base of a logarithm to use (default: 10) + reduction: a method to reduce metric score over labels. + + - ``'elementwise_mean'``: takes the mean (default) + - ``'sum'``: takes the sum + - ``'none'``: no reduction will be applied + + dim: + Dimensions to reduce PSNR scores over provided as either an integer or a list of integers. Default is + None meaning scores will be reduced across all dimensions. + Return: + Tensor with PSNR score + + Example: + >>> from pytorch_lightning.metrics.functional import psnr + >>> pred = torch.tensor([[0.0, 1.0], [2.0, 3.0]]) + >>> target = torch.tensor([[3.0, 2.0], [1.0, 0.0]]) + >>> psnr(pred, target) + tensor(2.5527) + """ + if dim is None and reduction != 'elementwise_mean': + utilities.rank_zero_warn(f'The `reduction={reduction}` will not have any effect when `dim` is None.') + + if data_range is None: + if dim is not None: + # Maybe we could use `torch.amax(target, dim=dim) - torch.amin(target, dim=dim)` in PyTorch 1.7 to calculate + # `data_range` in the future. + raise ValueError("The `data_range` must be given when `dim` is not None.") + + data_range = target.max() - target.min() + else: + data_range = torch.tensor(float(data_range)) + sum_squared_error, n_obs = _psnr_update(preds, target, dim=dim) + return _psnr_compute(sum_squared_error, n_obs, data_range, base=base, reduction=reduction) diff --git a/pytorch_lightning/metrics/functional/r2score.py b/pytorch_lightning/metrics/functional/r2score.py index fe4b541989358..ef8a20c806ee9 100644 --- a/pytorch_lightning/metrics/functional/r2score.py +++ b/pytorch_lightning/metrics/functional/r2score.py @@ -11,21 +11,121 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
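An illustrative sketch of per-sample PSNR using the ``dim`` argument of ``psnr`` above; ``data_range`` must be passed explicitly in this mode:

    import torch
    from pytorch_lightning.metrics.functional import psnr

    preds = torch.rand(4, 3, 16, 16)
    target = torch.rand(4, 3, 16, 16)
    # reduce over channel/height/width only -> one PSNR value per image, shape (4,)
    scores = psnr(preds, target, data_range=1.0, reduction='none', dim=(1, 2, 3))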
+from typing import Tuple import torch -from torchmetrics.functional import r2score as _r2score -from pytorch_lightning.metrics.utils import deprecated_metrics +from pytorch_lightning.metrics.utils import _check_same_shape +from pytorch_lightning.utilities import rank_zero_warn + + +def _r2score_update( + preds: torch.tensor, + target: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + _check_same_shape(preds, target) + if preds.ndim > 2: + raise ValueError( + 'Expected both prediction and target to be 1D or 2D tensors,' + f' but recevied tensors with dimension {preds.shape}' + ) + if len(preds) < 2: + raise ValueError('Needs atleast two samples to calculate r2 score.') + + sum_error = torch.sum(target, dim=0) + sum_squared_error = torch.sum(torch.pow(target, 2.0), dim=0) + residual = torch.sum(torch.pow(target - preds, 2.0), dim=0) + total = target.size(0) + + return sum_squared_error, sum_error, residual, total + + +def _r2score_compute( + sum_squared_error: torch.Tensor, + sum_error: torch.Tensor, + residual: torch.Tensor, + total: torch.Tensor, + adjusted: int = 0, + multioutput: str = "uniform_average" +) -> torch.Tensor: + mean_error = sum_error / total + diff = sum_squared_error - sum_error * mean_error + raw_scores = 1 - (residual / diff) + + if multioutput == "raw_values": + r2score = raw_scores + elif multioutput == "uniform_average": + r2score = torch.mean(raw_scores) + elif multioutput == "variance_weighted": + diff_sum = torch.sum(diff) + r2score = torch.sum(diff / diff_sum * raw_scores) + else: + raise ValueError( + 'Argument `multioutput` must be either `raw_values`,' + f' `uniform_average` or `variance_weighted`. Received {multioutput}.' + ) + + if adjusted < 0 or not isinstance(adjusted, int): + raise ValueError('`adjusted` parameter should be an integer larger or' ' equal to 0.') + + if adjusted != 0: + if adjusted > total - 1: + rank_zero_warn( + "More independent regressions than datapoints in" + " adjusted r2 score. Falls back to standard r2 score.", UserWarning + ) + elif adjusted == total - 1: + rank_zero_warn("Division by zero in adjusted r2 score. Falls back to" " standard r2 score.", UserWarning) + else: + r2score = 1 - (1 - r2score) * (total - 1) / (total - adjusted - 1) + return r2score -@deprecated_metrics(target=_r2score) def r2score( preds: torch.Tensor, target: torch.Tensor, adjusted: int = 0, multioutput: str = "uniform_average", ) -> torch.Tensor: + r""" + Computes r2 score also known as `coefficient of determination + `_: + + .. math:: R^2 = 1 - \frac{SS_res}{SS_tot} + + where :math:`SS_res=\sum_i (y_i - f(x_i))^2` is the sum of residual squares, and + :math:`SS_tot=\sum_i (y_i - \bar{y})^2` is total sum of squares. Can also calculate + adjusted r2 score given by + + .. math:: R^2_adj = 1 - \frac{(1-R^2)(n-1)}{n-k-1} + + where the parameter :math:`k` (the number of independent regressors) should + be provided as the ``adjusted`` argument. + + Args: + preds: estimated labels + target: ground truth labels + adjusted: number of independent regressors for calculating adjusted r2 score. + Default 0 (standard r2 score). + multioutput: Defines aggregation in the case of multiple output scores. 
Can be one + of the following strings (default is ``'uniform_average'``.): + + * ``'raw_values'`` returns full set of scores + * ``'uniform_average'`` scores are uniformly averaged + * ``'variance_weighted'`` scores are weighted by their individual variances + + Example: + + >>> from pytorch_lightning.metrics.functional import r2score + >>> target = torch.tensor([3, -0.5, 2, 7]) + >>> preds = torch.tensor([2.5, 0.0, 2, 8]) + >>> r2score(preds, target) + tensor(0.9486) + + >>> target = torch.tensor([[0.5, 1], [-1, 1], [7, -6]]) + >>> preds = torch.tensor([[0, 2], [-1, 2], [8, -5]]) + >>> r2score(preds, target, multioutput='raw_values') + tensor([0.9654, 0.9082]) """ - .. deprecated:: - Use :func:`torchmetrics.functional.r2score`. Will be removed in v1.5.0. - """ + sum_squared_error, sum_error, residual, total = _r2score_update(preds, target) + return _r2score_compute(sum_squared_error, sum_error, residual, total, adjusted, multioutput) diff --git a/pytorch_lightning/metrics/functional/roc.py b/pytorch_lightning/metrics/functional/roc.py index 928a0b40fca54..030c974365807 100644 --- a/pytorch_lightning/metrics/functional/roc.py +++ b/pytorch_lightning/metrics/functional/roc.py @@ -13,21 +13,135 @@ # limitations under the License. from typing import List, Optional, Sequence, Tuple, Union -from torch import Tensor -from torchmetrics.functional import roc as _roc +import torch -from pytorch_lightning.metrics.utils import deprecated_metrics +from pytorch_lightning.metrics.functional.precision_recall_curve import ( + _binary_clf_curve, + _precision_recall_curve_update, +) + + +def _roc_update( + preds: torch.Tensor, + target: torch.Tensor, + num_classes: Optional[int] = None, + pos_label: Optional[int] = None, +) -> Tuple[torch.Tensor, torch.Tensor, int, int]: + return _precision_recall_curve_update(preds, target, num_classes, pos_label) + + +def _roc_compute( + preds: torch.Tensor, + target: torch.Tensor, + num_classes: int, + pos_label: int, + sample_weights: Optional[Sequence] = None, +) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], Tuple[List[torch.Tensor], List[torch.Tensor], + List[torch.Tensor]]]: + + if num_classes == 1: + fps, tps, thresholds = _binary_clf_curve( + preds=preds, target=target, sample_weights=sample_weights, pos_label=pos_label + ) + # Add an extra threshold position + # to make sure that the curve starts at (0, 0) + tps = torch.cat([torch.zeros(1, dtype=tps.dtype, device=tps.device), tps]) + fps = torch.cat([torch.zeros(1, dtype=fps.dtype, device=fps.device), fps]) + thresholds = torch.cat([thresholds[0][None] + 1, thresholds]) + + if fps[-1] <= 0: + raise ValueError("No negative samples in targets, false positive value should be meaningless") + fpr = fps / fps[-1] + + if tps[-1] <= 0: + raise ValueError("No positive samples in targets, true positive value should be meaningless") + tpr = tps / tps[-1] + + return fpr, tpr, thresholds + + # Recursively call per class + fpr, tpr, thresholds = [], [], [] + for c in range(num_classes): + preds_c = preds[:, c] + res = roc( + preds=preds_c, + target=target, + num_classes=1, + pos_label=c, + sample_weights=sample_weights, + ) + fpr.append(res[0]) + tpr.append(res[1]) + thresholds.append(res[2]) + + return fpr, tpr, thresholds -@deprecated_metrics(target=_roc) def roc( - preds: Tensor, - target: Tensor, + preds: torch.Tensor, + target: torch.Tensor, num_classes: Optional[int] = None, pos_label: Optional[int] = None, sample_weights: Optional[Sequence] = None, -) -> Union[Tuple[Tensor, Tensor, Tensor], 
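An illustrative call of the adjusted variant of ``r2score`` above, where ``adjusted`` is the number of independent regressors :math:`k`:

    import torch
    from pytorch_lightning.metrics.functional import r2score

    target = torch.tensor([3., -0.5, 2., 7., 4.2])
    preds = torch.tensor([2.5, 0.0, 2., 8., 4.0])
    # adjusted R^2 with k=2 regressors; needs more than k+1 samples to apply
    print(r2score(preds, target, adjusted=2))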
Tuple[List[Tensor], List[Tensor], List[Tensor]]]: +) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], Tuple[List[torch.Tensor], List[torch.Tensor], + List[torch.Tensor]]]: """ - .. deprecated:: - Use :func:`torchmetrics.functional.roc`. Will be removed in v1.5.0. + Computes the Receiver Operating Characteristic (ROC). + + Args: + preds: predictions from model (logits or probabilities) + target: ground truth values + num_classes: integer with number of classes. Not nessesary to provide + for binary problems. + pos_label: integer determining the positive class. Default is ``None`` + which for binary problem is translate to 1. For multiclass problems + this argument should not be set as we iteratively change it in the + range [0,num_classes-1] + sample_weights: sample weights for each data point + + Returns: + 3-element tuple containing + + fpr: + tensor with false positive rates. + If multiclass, this is a list of such tensors, one for each class. + tpr: + tensor with true positive rates. + If multiclass, this is a list of such tensors, one for each class. + thresholds: + thresholds used for computing false- and true postive rates + + Example (binary case): + + >>> from pytorch_lightning.metrics.functional import roc + >>> pred = torch.tensor([0, 1, 2, 3]) + >>> target = torch.tensor([0, 1, 1, 1]) + >>> fpr, tpr, thresholds = roc(pred, target, pos_label=1) + >>> fpr + tensor([0., 0., 0., 0., 1.]) + >>> tpr + tensor([0.0000, 0.3333, 0.6667, 1.0000, 1.0000]) + >>> thresholds + tensor([4, 3, 2, 1, 0]) + + Example (multiclass case): + + >>> from pytorch_lightning.metrics.functional import roc + >>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05], + ... [0.05, 0.75, 0.05, 0.05], + ... [0.05, 0.05, 0.75, 0.05], + ... [0.05, 0.05, 0.05, 0.75]]) + >>> target = torch.tensor([0, 1, 3, 2]) + >>> fpr, tpr, thresholds = roc(pred, target, num_classes=4) + >>> fpr + [tensor([0., 0., 1.]), tensor([0., 0., 1.]), tensor([0.0000, 0.3333, 1.0000]), tensor([0.0000, 0.3333, 1.0000])] + >>> tpr + [tensor([0., 1., 1.]), tensor([0., 1., 1.]), tensor([0., 0., 1.]), tensor([0., 0., 1.])] + >>> thresholds # doctest: +NORMALIZE_WHITESPACE + [tensor([1.7500, 0.7500, 0.0500]), + tensor([1.7500, 0.7500, 0.0500]), + tensor([1.7500, 0.7500, 0.0500]), + tensor([1.7500, 0.7500, 0.0500])] """ + preds, target, num_classes, pos_label = _roc_update(preds, target, num_classes, pos_label) + return _roc_compute(preds, target, num_classes, pos_label, sample_weights) diff --git a/pytorch_lightning/metrics/functional/self_supervised.py b/pytorch_lightning/metrics/functional/self_supervised.py index 65dec211e938a..ed00677bb32d3 100644 --- a/pytorch_lightning/metrics/functional/self_supervised.py +++ b/pytorch_lightning/metrics/functional/self_supervised.py @@ -12,12 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. import torch -from torchmetrics.functional import embedding_similarity as _embedding_similarity -from pytorch_lightning.metrics.utils import deprecated_metrics - -@deprecated_metrics(target=_embedding_similarity) def embedding_similarity( batch: torch.Tensor, similarity: str = 'cosine', @@ -25,6 +21,39 @@ def embedding_similarity( zero_diagonal: bool = True ) -> torch.Tensor: """ - .. deprecated:: - Use :func:`torchmetrics.functional.embedding_similarity`. Will be removed in v1.5.0. 
+ Computes representation similarity + + Example: + >>> from pytorch_lightning.metrics.functional import embedding_similarity + >>> embeddings = torch.tensor([[1., 2., 3., 4.], [1., 2., 3., 4.], [4., 5., 6., 7.]]) + >>> embedding_similarity(embeddings) + tensor([[0.0000, 1.0000, 0.9759], + [1.0000, 0.0000, 0.9759], + [0.9759, 0.9759, 0.0000]]) + + Args: + batch: (batch, dim) + similarity: 'dot' or 'cosine' + reduction: 'none', 'sum', 'mean' (all along dim -1) + zero_diagonal: if True, the diagonals are set to zero + + Return: + A square matrix (batch, batch) with the similarity scores between all elements + If sum or mean are used, then returns (b, 1) with the reduced value for each row """ + if similarity == 'cosine': + norm = torch.norm(batch, p=2, dim=1) + batch = batch / norm.unsqueeze(1) + + sqr_mtx = batch.mm(batch.transpose(1, 0)) + + if zero_diagonal: + sqr_mtx = sqr_mtx.fill_diagonal_(0) + + if reduction == 'mean': + sqr_mtx = sqr_mtx.mean(dim=-1) + + if reduction == 'sum': + sqr_mtx = sqr_mtx.sum(dim=-1) + + return sqr_mtx diff --git a/pytorch_lightning/metrics/functional/ssim.py b/pytorch_lightning/metrics/functional/ssim.py index 31cff7fcfb9b4..459c1855f6fee 100644 --- a/pytorch_lightning/metrics/functional/ssim.py +++ b/pytorch_lightning/metrics/functional/ssim.py @@ -11,15 +11,107 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Sequence +from typing import Optional, Sequence, Tuple import torch -from torchmetrics.functional import ssim as _ssim +from torch.nn import functional as F -from pytorch_lightning.metrics.utils import deprecated_metrics +from pytorch_lightning.metrics.utils import _check_same_shape, reduce + + +def _gaussian(kernel_size: int, sigma: int, dtype: torch.dtype, device: torch.device): + dist = torch.arange(start=(1 - kernel_size) / 2, end=(1 + kernel_size) / 2, step=1, dtype=dtype, device=device) + gauss = torch.exp(-torch.pow(dist / sigma, 2) / 2) + return (gauss / gauss.sum()).unsqueeze(dim=0) # (1, kernel_size) + + +def _gaussian_kernel( + channel: int, kernel_size: Sequence[int], sigma: Sequence[float], dtype: torch.dtype, device: torch.device +): + gaussian_kernel_x = _gaussian(kernel_size[0], sigma[0], dtype, device) + gaussian_kernel_y = _gaussian(kernel_size[1], sigma[1], dtype, device) + kernel = torch.matmul(gaussian_kernel_x.t(), gaussian_kernel_y) # (kernel_size, 1) * (1, kernel_size) + + return kernel.expand(channel, 1, kernel_size[0], kernel_size[1]) + + +def _ssim_update( + preds: torch.Tensor, + target: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + if preds.dtype != target.dtype: + raise TypeError( + "Expected `preds` and `target` to have the same data type." + f" Got pred: {preds.dtype} and target: {target.dtype}." + ) + _check_same_shape(preds, target) + if len(preds.shape) != 4: + raise ValueError( + "Expected `preds` and `target` to have BxCxHxW shape." + f" Got pred: {preds.shape} and target: {target.shape}." + ) + return preds, target + + +def _ssim_compute( + preds: torch.Tensor, + target: torch.Tensor, + kernel_size: Sequence[int] = (11, 11), + sigma: Sequence[float] = (1.5, 1.5), + reduction: str = "elementwise_mean", + data_range: Optional[float] = None, + k1: float = 0.01, + k2: float = 0.03, +): + if len(kernel_size) != 2 or len(sigma) != 2: + raise ValueError( + "Expected `kernel_size` and `sigma` to have the length of two." 
+ f" Got kernel_size: {len(kernel_size)} and sigma: {len(sigma)}." + ) + + if any(x % 2 == 0 or x <= 0 for x in kernel_size): + raise ValueError(f"Expected `kernel_size` to have odd positive number. Got {kernel_size}.") + + if any(y <= 0 for y in sigma): + raise ValueError(f"Expected `sigma` to have positive number. Got {sigma}.") + + if data_range is None: + data_range = max(preds.max() - preds.min(), target.max() - target.min()) + + c1 = pow(k1 * data_range, 2) + c2 = pow(k2 * data_range, 2) + device = preds.device + + channel = preds.size(1) + dtype = preds.dtype + kernel = _gaussian_kernel(channel, kernel_size, sigma, dtype, device) + pad_w = (kernel_size[0] - 1) // 2 + pad_h = (kernel_size[1] - 1) // 2 + + preds = F.pad(preds, (pad_w, pad_w, pad_h, pad_h), mode='reflect') + target = F.pad(target, (pad_w, pad_w, pad_h, pad_h), mode='reflect') + + input_list = torch.cat((preds, target, preds * preds, target * target, preds * target)) # (5 * B, C, H, W) + outputs = F.conv2d(input_list, kernel, groups=channel) + output_list = [outputs[x * preds.size(0):(x + 1) * preds.size(0)] for x in range(len(outputs))] + + mu_pred_sq = output_list[0].pow(2) + mu_target_sq = output_list[1].pow(2) + mu_pred_target = output_list[0] * output_list[1] + + sigma_pred_sq = output_list[2] - mu_pred_sq + sigma_target_sq = output_list[3] - mu_target_sq + sigma_pred_target = output_list[4] - mu_pred_target + + upper = 2 * sigma_pred_target + c2 + lower = sigma_pred_sq + sigma_target_sq + c2 + + ssim_idx = ((2 * mu_pred_target + c1) * upper) / ((mu_pred_sq + mu_target_sq + c1) * lower) + ssim_idx = ssim_idx[..., pad_h:-pad_h, pad_w:-pad_w] + + return reduce(ssim_idx, reduction) -@deprecated_metrics(target=_ssim) def ssim( preds: torch.Tensor, target: torch.Tensor, @@ -31,6 +123,32 @@ def ssim( k2: float = 0.03, ) -> torch.Tensor: """ - .. deprecated:: - Use :func:`torchmetrics.functional.ssim`. Will be removed in v1.5.0. + Computes Structual Similarity Index Measure + + Args: + preds: estimated image + target: ground truth image + kernel_size: size of the gaussian kernel (default: (11, 11)) + sigma: Standard deviation of the gaussian kernel (default: (1.5, 1.5)) + reduction: a method to reduce metric score over labels. + + - ``'elementwise_mean'``: takes the mean (default) + - ``'sum'``: takes the sum + - ``'none'``: no reduction will be applied + + data_range: Range of the image. If ``None``, it is determined from the image (max - min) + k1: Parameter of SSIM. Default: 0.01 + k2: Parameter of SSIM. Default: 0.03 + + Return: + Tensor with SSIM score + + Example: + >>> from pytorch_lightning.metrics.functional import ssim + >>> preds = torch.rand([16, 1, 16, 16]) + >>> target = preds * 0.75 + >>> ssim(preds, target) + tensor(0.9219) """ + preds, target = _ssim_update(preds, target) + return _ssim_compute(preds, target, kernel_size, sigma, reduction, data_range, k1, k2) diff --git a/pytorch_lightning/metrics/functional/stat_scores.py b/pytorch_lightning/metrics/functional/stat_scores.py index 30c03da237fe6..44b4434f4dcf1 100644 --- a/pytorch_lightning/metrics/functional/stat_scores.py +++ b/pytorch_lightning/metrics/functional/stat_scores.py @@ -11,15 +11,131 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Optional +from typing import Optional, Tuple import torch -from torchmetrics.functional import stat_scores as _stat_scores -from pytorch_lightning.metrics.utils import deprecated_metrics +from pytorch_lightning.metrics.classification.helpers import _input_format_classification + + +def _del_column(tensor: torch.Tensor, index: int): + """ Delete the column at index.""" + + return torch.cat([tensor[:, :index], tensor[:, (index + 1):]], 1) + + +def _stat_scores( + preds: torch.Tensor, + target: torch.Tensor, + reduce: str = "micro", +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Calculate the number of tp, fp, tn, fn. + + Args: + preds: + An ``(N, C)`` or ``(N, C, X)`` tensor of predictions (0 or 1) + target: + An ``(N, C)`` or ``(N, C, X)`` tensor of true labels (0 or 1) + reduce: + One of ``'micro'``, ``'macro'``, ``'samples'`` + + Return: + Returns a list of 4 tensors; tp, fp, tn, fn. + The shape of the returned tensors depnds on the shape of the inputs + and the ``reduce`` parameter: + + If inputs are of the shape ``(N, C)``, then + - If ``reduce='micro'``, the returned tensors are 1 element tensors + - If ``reduce='macro'``, the returned tensors are ``(C,)`` tensors + - If ``reduce'samples'``, the returned tensors are ``(N,)`` tensors + + If inputs are of the shape ``(N, C, X)``, then + - If ``reduce='micro'``, the returned tensors are ``(N,)`` tensors + - If ``reduce='macro'``, the returned tensors are ``(N,C)`` tensors + - If ``reduce='samples'``, the returned tensors are ``(N,X)`` tensors + """ + if reduce == "micro": + dim = [0, 1] if preds.ndim == 2 else [1, 2] + elif reduce == "macro": + dim = 0 if preds.ndim == 2 else 2 + elif reduce == "samples": + dim = 1 + + true_pred, false_pred = target == preds, target != preds + pos_pred, neg_pred = preds == 1, preds == 0 + + tp = (true_pred * pos_pred).sum(dim=dim) + fp = (false_pred * pos_pred).sum(dim=dim) + + tn = (true_pred * neg_pred).sum(dim=dim) + fn = (false_pred * neg_pred).sum(dim=dim) + + return tp.long(), fp.long(), tn.long(), fn.long() + + +def _stat_scores_update( + preds: torch.Tensor, + target: torch.Tensor, + reduce: str = "micro", + mdmc_reduce: Optional[str] = None, + num_classes: Optional[int] = None, + top_k: Optional[int] = None, + threshold: float = 0.5, + is_multiclass: Optional[bool] = None, + ignore_index: Optional[int] = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + + preds, target, _ = _input_format_classification( + preds, target, threshold=threshold, num_classes=num_classes, is_multiclass=is_multiclass, top_k=top_k + ) + + if ignore_index is not None and not 0 <= ignore_index < preds.shape[1]: + raise ValueError(f"The `ignore_index` {ignore_index} is not valid for inputs with {preds.shape[0]} classes") + + if ignore_index is not None and preds.shape[1] == 1: + raise ValueError("You can not use `ignore_index` with binary data.") + + if preds.ndim == 3: + if not mdmc_reduce: + raise ValueError( + "When your inputs are multi-dimensional multi-class, you have to set the `mdmc_reduce` parameter" + ) + if mdmc_reduce == "global": + preds = torch.transpose(preds, 1, 2).reshape(-1, preds.shape[1]) + target = torch.transpose(target, 1, 2).reshape(-1, target.shape[1]) + + # Delete what is in ignore_index, if applicable (and classes don't matter): + if ignore_index is not None and reduce != "macro": + preds = _del_column(preds, ignore_index) + target = _del_column(target, ignore_index) + + tp, fp, tn, fn = _stat_scores(preds, target, 
reduce=reduce) + + # Take care of ignore_index + if ignore_index is not None and reduce == "macro": + tp[..., ignore_index] = -1 + fp[..., ignore_index] = -1 + tn[..., ignore_index] = -1 + fn[..., ignore_index] = -1 + + return tp, fp, tn, fn + + +def _stat_scores_compute(tp: torch.Tensor, fp: torch.Tensor, tn: torch.Tensor, fn: torch.Tensor) -> torch.Tensor: + + outputs = [ + tp.unsqueeze(-1), + fp.unsqueeze(-1), + tn.unsqueeze(-1), + fn.unsqueeze(-1), + tp.unsqueeze(-1) + fn.unsqueeze(-1), # support + ] + outputs = torch.cat(outputs, -1) + outputs = torch.where(outputs < 0, torch.tensor(-1, device=outputs.device), outputs) + + return outputs -@deprecated_metrics(target=_stat_scores) def stat_scores( preds: torch.Tensor, target: torch.Tensor, @@ -31,7 +147,137 @@ def stat_scores( is_multiclass: Optional[bool] = None, ignore_index: Optional[int] = None, ) -> torch.Tensor: + """Computes the number of true positives, false positives, true negatives, false negatives. + Related to `Type I and Type II errors `__ + and the `confusion matrix `__. + + The reduction method (how the statistics are aggregated) is controlled by the + ``reduce`` parameter, and additionally by the ``mdmc_reduce`` parameter in the + multi-dimensional multi-class case. Accepts all inputs listed in :ref:`extensions/metrics:input types`. + + Args: + preds: Predictions from model (probabilities or labels) + target: Ground truth values + threshold: + Threshold probability value for transforming probability predictions to binary + (0 or 1) predictions, in the case of binary or multi-label inputs. + + top_k: + Number of highest probability entries for each sample to convert to 1s - relevant + only for inputs with probability predictions. If this parameter is set for multi-label + inputs, it will take precedence over ``threshold``. For (multi-dim) multi-class inputs, + this parameter defaults to 1. + + Should be left unset (``None``) for inputs with label predictions. + + reduce: + Defines the reduction that is applied. Should be one of the following: + + - ``'micro'`` [default]: Counts the statistics by summing over all [sample, class] + combinations (globally). Each statistic is represented by a single integer. + - ``'macro'``: Counts the statistics for each class separately (over all samples). + Each statistic is represented by a ``(C,)`` tensor. Requires ``num_classes`` + to be set. + - ``'samples'``: Counts the statistics for each sample separately (over all classes). + Each statistic is represented by a ``(N, )`` 1d tensor. + + Note that what is considered a sample in the multi-dimensional multi-class case + depends on the value of ``mdmc_reduce``. + + num_classes: + Number of classes. Necessary for (multi-dimensional) multi-class or multi-label data. + + ignore_index: + Specify a class (label) to ignore. If given, this class index does not contribute + to the returned score, regardless of reduction method. If an index is ignored, and + ``reduce='macro'``, the class statistics for the ignored class will all be returned + as ``-1``. + + mdmc_reduce: + Defines how the multi-dimensional multi-class inputs are handeled. Should be + one of the following: + + - ``None`` [default]: Should be left unchanged if your data is not multi-dimensional + multi-class (see :ref:`extensions/metrics:input types` for the definition of input types). + + - ``'samplewise'``: In this case, the statistics are computed separately for each + sample on the ``N`` axis, and then the outputs are concatenated together. 
In each + sample the extra axes ``...`` are flattened to become the sub-sample axis, and + statistics for each sample are computed by treating the sub-sample axis as the + ``N`` axis for that sample. + + - ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs are + flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they + were ``(N_X, C)``. From here on the ``reduce`` parameter applies as usual. + + is_multiclass: + Used only in certain special cases, where you want to treat inputs as a different type + than what they appear to be. See the parameter's + :ref:`documentation section ` + for a more detailed explanation and examples. + + Return: + The metric returns a tensor of shape ``(..., 5)``, where the last dimension corresponds + to ``[tp, fp, tn, fn, sup]`` (``sup`` stands for support and equals ``tp + fn``). The + shape depends on the ``reduce`` and ``mdmc_reduce`` (in case of multi-dimensional + multi-class data) parameters: + + - If the data is not multi-dimensional multi-class, then + + - If ``reduce='micro'``, the shape will be ``(5, )`` + - If ``reduce='macro'``, the shape will be ``(C, 5)``, + where ``C`` stands for the number of classes + - If ``reduce='samples'``, the shape will be ``(N, 5)``, where ``N`` stands for + the number of samples + + - If the data is multi-dimensional multi-class and ``mdmc_reduce='global'``, then + + - If ``reduce='micro'``, the shape will be ``(5, )`` + - If ``reduce='macro'``, the shape will be ``(C, 5)`` + - If ``reduce='samples'``, the shape will be ``(N*X, 5)``, where ``X`` stands for + the product of sizes of all "extra" dimensions of the data (i.e. all dimensions + except for ``C`` and ``N``) + + - If the data is multi-dimensional multi-class and ``mdmc_reduce='samplewise'``, then + + - If ``reduce='micro'``, the shape will be ``(N, 5)`` + - If ``reduce='macro'``, the shape will be ``(N, C, 5)`` + - If ``reduce='samples'``, the shape will be ``(N, X, 5)`` + + Example: + + >>> from pytorch_lightning.metrics.functional import stat_scores + >>> preds = torch.tensor([1, 0, 2, 1]) + >>> target = torch.tensor([1, 1, 2, 0]) + >>> stat_scores(preds, target, reduce='macro', num_classes=3) + tensor([[0, 1, 2, 1, 1], + [1, 1, 1, 1, 2], + [1, 0, 3, 0, 1]]) + >>> stat_scores(preds, target, reduce='micro') + tensor([2, 2, 6, 2, 4]) """ - .. deprecated:: - Use :func:`torchmetrics.functional.stat_scores`. Will be removed in v1.5.0. 
- """ + + if reduce not in ["micro", "macro", "samples"]: + raise ValueError(f"The `reduce` {reduce} is not valid.") + + if mdmc_reduce not in [None, "samplewise", "global"]: + raise ValueError(f"The `mdmc_reduce` {mdmc_reduce} is not valid.") + + if reduce == "macro" and (not num_classes or num_classes < 1): + raise ValueError("When you set `reduce` as 'macro', you have to provide the number of classes.") + + if num_classes and ignore_index is not None and (not 0 <= ignore_index < num_classes or num_classes == 1): + raise ValueError(f"The `ignore_index` {ignore_index} is not valid for inputs with {num_classes} classes") + + tp, fp, tn, fn = _stat_scores_update( + preds, + target, + reduce=reduce, + mdmc_reduce=mdmc_reduce, + top_k=top_k, + threshold=threshold, + num_classes=num_classes, + is_multiclass=is_multiclass, + ignore_index=ignore_index, + ) + return _stat_scores_compute(tp, fp, tn, fn) diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py index ee0fcdb8a92e1..3ff3039cb99b1 100644 --- a/pytorch_lightning/metrics/metric.py +++ b/pytorch_lightning/metrics/metric.py @@ -11,17 +11,54 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import functools +import inspect +from abc import ABC, abstractmethod +from collections.abc import Sequence +from copy import deepcopy from typing import Any, Callable, Dict, List, Optional, Tuple, Union -from torchmetrics import Metric as _Metric -from torchmetrics.collections import MetricCollection as _MetricCollection +import torch +from torch import nn -from pytorch_lightning.metrics.utils import deprecated_metrics +from pytorch_lightning.metrics.utils import _flatten, dim_zero_cat, dim_zero_mean, dim_zero_sum +from pytorch_lightning.utilities.apply_func import apply_to_collection +from pytorch_lightning.utilities.distributed import gather_all_tensors -class Metric(_Metric): +class Metric(nn.Module, ABC): + """ + Base class for all metrics present in the Metrics API. + + Implements ``add_state()``, ``forward()``, ``reset()`` and a few other things to + handle distributed synchronization and per-step metric computation. + + Override ``update()`` and ``compute()`` functions to implement your own metric. Use + ``add_state()`` to register metric state variables which keep track of state on each + call of ``update()`` and are synchronized across processes when ``compute()`` is called. + + Note: + Metric state variables can either be ``torch.Tensors`` or an empty list which can we used + to store `torch.Tensors``. + + Note: + Different metrics only override ``update()`` and not ``forward()``. A call to ``update()`` + is valid, but it won't return the metric value at the current step. A call to ``forward()`` + automatically calls ``update()`` and also returns the metric value at the current step. + + Args: + compute_on_step: + Forward only calls ``update()`` and returns None if this is set to False. default: True + dist_sync_on_step: + Synchronize metric state across processes at each ``forward()`` + before returning the value at the step. + process_group: + Specify the process group on which synchronization is called. default: None (which selects the entire world) + dist_sync_fn: + Callback that performs the allgather operation on the metric state. When `None`, DDP + will be used to perform the allgather. 
default: None + """ - @deprecated_metrics(target=_Metric) def __init__( self, compute_on_step: bool = True, @@ -29,17 +66,559 @@ def __init__( process_group: Optional[Any] = None, dist_sync_fn: Callable = None, ): - r""" - .. deprecated:: - Use :class:`torchmetrics.Metric`. Will be removed in v1.5.0. + super().__init__() + + self.dist_sync_on_step = dist_sync_on_step + self.compute_on_step = compute_on_step + self.process_group = process_group + self.dist_sync_fn = dist_sync_fn + self._to_sync = True + + self._update_signature = inspect.signature(self.update) + self.update = self._wrap_update(self.update) + self.compute = self._wrap_compute(self.compute) + self._computed = None + self._forward_cache = None + + # initialize state + self._defaults = {} + self._persistent = {} + self._reductions = {} + + def add_state( + self, name: str, default, dist_reduce_fx: Optional[Union[str, Callable]] = None, persistent: bool = False + ): + """ + Adds metric state variable. Only used by subclasses. + + Args: + name: The name of the state variable. The variable will then be accessible at ``self.name``. + default: Default value of the state; can either be a ``torch.Tensor`` or an empty list. The state will be + reset to this value when ``self.reset()`` is called. + dist_reduce_fx (Optional): Function to reduce state accross mutliple processes in distributed mode. + If value is ``"sum"``, ``"mean"``, or ``"cat"``, we will use ``torch.sum``, ``torch.mean``, + and ``torch.cat`` respectively, each with argument ``dim=0``. Note that the ``"cat"`` reduction + only makes sense if the state is a list, and not a tensor. The user can also pass a custom + function in this parameter. + persistent (Optional): whether the state will be saved as part of the modules ``state_dict``. + Default is ``False``. + + Note: + Setting ``dist_reduce_fx`` to None will return the metric state synchronized across different processes. + However, there won't be any reduction function applied to the synchronized metric state. + + The metric states would be synced as follows + + - If the metric state is ``torch.Tensor``, the synced value will be a stacked ``torch.Tensor`` across + the process dimension if the metric state was a ``torch.Tensor``. The original ``torch.Tensor`` metric + state retains dimension and hence the synchronized output will be of shape ``(num_process, ...)``. + + - If the metric state is a ``list``, the synced value will be a ``list`` containing the + combined elements from all processes. + + Note: + When passing a custom function to ``dist_reduce_fx``, expect the synchronized metric state to follow + the format discussed in the above note. + + Raises: + ValueError: + If ``default`` is not a ``tensor`` or an ``empty list``. + ValueError: + If ``dist_reduce_fx`` is not callable or one of ``"mean"``, ``"sum"``, ``"cat"``, ``None``. 
+ """ + if ( + not isinstance(default, torch.Tensor) and not isinstance(default, list) # noqa: W503 + or (isinstance(default, list) and len(default) != 0) # noqa: W503 + ): + raise ValueError("state variable must be a tensor or any empty list (where you can append tensors)") + + if dist_reduce_fx == "sum": + dist_reduce_fx = dim_zero_sum + elif dist_reduce_fx == "mean": + dist_reduce_fx = dim_zero_mean + elif dist_reduce_fx == "cat": + dist_reduce_fx = dim_zero_cat + elif dist_reduce_fx is not None and not isinstance(dist_reduce_fx, Callable): + raise ValueError("`dist_reduce_fx` must be callable or one of ['mean', 'sum', 'cat', None]") + + setattr(self, name, default) + + self._defaults[name] = deepcopy(default) + self._persistent[name] = persistent + self._reductions[name] = dist_reduce_fx + + @torch.jit.unused + def forward(self, *args, **kwargs): + """ + Automatically calls ``update()``. Returns the metric value over inputs if ``compute_on_step`` is True. + """ + # add current step + with torch.no_grad(): + self.update(*args, **kwargs) + self._forward_cache = None + + if self.compute_on_step: + self._to_sync = self.dist_sync_on_step + + # save context before switch + cache = {attr: getattr(self, attr) for attr in self._defaults.keys()} + + # call reset, update, compute, on single batch + self.reset() + self.update(*args, **kwargs) + self._forward_cache = self.compute() + + # restore context + for attr, val in cache.items(): + setattr(self, attr, val) + self._to_sync = True + self._computed = None + + return self._forward_cache + + def _sync_dist(self, dist_sync_fn=gather_all_tensors): + input_dict = {attr: getattr(self, attr) for attr in self._reductions.keys()} + output_dict = apply_to_collection( + input_dict, + torch.Tensor, + dist_sync_fn, + group=self.process_group, + ) + + for attr, reduction_fn in self._reductions.items(): + # pre-processing ops (stack or flatten for inputs) + if isinstance(output_dict[attr][0], torch.Tensor): + output_dict[attr] = torch.stack(output_dict[attr]) + elif isinstance(output_dict[attr][0], list): + output_dict[attr] = _flatten(output_dict[attr]) + + assert isinstance(reduction_fn, (Callable)) or reduction_fn is None + reduced = reduction_fn(output_dict[attr]) if reduction_fn is not None else output_dict[attr] + setattr(self, attr, reduced) + + def _wrap_update(self, update): + + @functools.wraps(update) + def wrapped_func(*args, **kwargs): + self._computed = None + return update(*args, **kwargs) + + return wrapped_func + + def _wrap_compute(self, compute): + + @functools.wraps(compute) + def wrapped_func(*args, **kwargs): + # return cached value + if self._computed is not None: + return self._computed + + dist_sync_fn = self.dist_sync_fn + if dist_sync_fn is None and torch.distributed.is_available() and torch.distributed.is_initialized(): + # User provided a bool, so we assume DDP if available + dist_sync_fn = gather_all_tensors + + synced = False + if self._to_sync and dist_sync_fn is not None: + # cache prior to syncing + cache = {attr: getattr(self, attr) for attr in self._defaults.keys()} + + # sync + self._sync_dist(dist_sync_fn) + synced = True + + self._computed = compute(*args, **kwargs) + if synced: + # if we synced, restore to cache so that we can continue to accumulate un-synced state + for attr, val in cache.items(): + setattr(self, attr, val) + + return self._computed + + return wrapped_func + + @abstractmethod + def update(self) -> None: # pylint: disable=E0202 + """ + Override this method to update the state variables of your metric 
class. + """ + pass + + @abstractmethod + def compute(self): # pylint: disable=E0202 + """ + Override this method to compute the final metric value from state variables + synchronized across the distributed backend. + """ + pass + + def reset(self): + """ + This method automatically resets the metric state variables to their default value. + """ + for attr, default in self._defaults.items(): + current_val = getattr(self, attr) + if isinstance(default, torch.Tensor): + setattr(self, attr, deepcopy(default).to(current_val.device)) + else: + setattr(self, attr, deepcopy(default)) + + def clone(self): + """ Make a copy of the metric """ + return deepcopy(self) + + def __getstate__(self): + # ignore update and compute functions for pickling + return {k: v for k, v in self.__dict__.items() if k not in ["update", "compute"]} + + def __setstate__(self, state): + # manually restore update and compute functions for pickling + self.__dict__.update(state) + self.update = self._wrap_update(self.update) + self.compute = self._wrap_compute(self.compute) + + def _apply(self, fn): + """Overwrite _apply function such that we can also move metric states + to the correct device when `.to`, `.cuda`, etc methods are called """ + self = super()._apply(fn) + # Also apply fn to metric states + for key in self._defaults.keys(): + current_val = getattr(self, key) + if isinstance(current_val, torch.Tensor): + setattr(self, key, fn(current_val)) + elif isinstance(current_val, Sequence): + setattr(self, key, [fn(cur_v) for cur_v in current_val]) + else: + raise TypeError( + "Expected metric state to be either a torch.Tensor" + f"or a list of torch.Tensor, but encountered {current_val}" + ) + return self + + def persistent(self, mode: bool = False): + """Method for post-init to change if metric states should be saved to + its state_dict + """ + for key in self._persistent.keys(): + self._persistent[key] = mode + + def state_dict(self, destination=None, prefix='', keep_vars=False): + destination = super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars) + # Register metric states to be part of the state_dict + for key in self._defaults.keys(): + if self._persistent[key]: + current_val = getattr(self, key) + if not keep_vars: + if torch.is_tensor(current_val): + current_val = current_val.detach() + elif isinstance(current_val, list): + current_val = [cur_v.detach() if torch.is_tensor(cur_v) else cur_v for cur_v in current_val] + destination[prefix + key] = current_val + return destination + + def _filter_kwargs(self, **kwargs): + """ filter kwargs such that they match the update signature of the metric """ + + # filter all parameters based on update signature except those of + # type VAR_POSITIONAL (*args) and VAR_KEYWORD (**kwargs) + _params = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD) + filtered_kwargs = { + k: v + for k, v in kwargs.items() if k in self._update_signature.parameters.keys() + and self._update_signature.parameters[k].kind not in _params + } + + # if no kwargs filtered, return al kwargs as default + if not filtered_kwargs: + filtered_kwargs = kwargs + return filtered_kwargs + + def __hash__(self): + hash_vals = [self.__class__.__name__] + + for key in self._defaults.keys(): + val = getattr(self, key) + # Special case: allow list values, so long + # as their elements are hashable + if hasattr(val, '__iter__') and not isinstance(val, torch.Tensor): + hash_vals.extend(val) + else: + hash_vals.append(val) + + return hash(tuple(hash_vals)) + + def __add__(self, 
other: Any): + from pytorch_lightning.metrics.compositional import CompositionalMetric + + return CompositionalMetric(torch.add, self, other) + + def __and__(self, other: Any): + from pytorch_lightning.metrics.compositional import CompositionalMetric + + return CompositionalMetric(torch.bitwise_and, self, other) + + def __eq__(self, other: Any): + from pytorch_lightning.metrics.compositional import CompositionalMetric + + return CompositionalMetric(torch.eq, self, other) + + def __floordiv__(self, other: Any): + from pytorch_lightning.metrics.compositional import CompositionalMetric + + return CompositionalMetric(torch.floor_divide, self, other) + + def __ge__(self, other: Any): + from pytorch_lightning.metrics.compositional import CompositionalMetric + + return CompositionalMetric(torch.ge, self, other) + def __gt__(self, other: Any): + from pytorch_lightning.metrics.compositional import CompositionalMetric -class MetricCollection(_MetricCollection): + return CompositionalMetric(torch.gt, self, other) + + def __le__(self, other: Any): + from pytorch_lightning.metrics.compositional import CompositionalMetric + + return CompositionalMetric(torch.le, self, other) + + def __lt__(self, other: Any): + from pytorch_lightning.metrics.compositional import CompositionalMetric + + return CompositionalMetric(torch.lt, self, other) + + def __matmul__(self, other: Any): + from pytorch_lightning.metrics.compositional import CompositionalMetric + + return CompositionalMetric(torch.matmul, self, other) + + def __mod__(self, other: Any): + from pytorch_lightning.metrics.compositional import CompositionalMetric + + return CompositionalMetric(torch.fmod, self, other) + + def __mul__(self, other: Any): + from pytorch_lightning.metrics.compositional import CompositionalMetric + + return CompositionalMetric(torch.mul, self, other) + + def __ne__(self, other: Any): + from pytorch_lightning.metrics.compositional import CompositionalMetric + + return CompositionalMetric(torch.ne, self, other) + + def __or__(self, other: Any): + from pytorch_lightning.metrics.compositional import CompositionalMetric + + return CompositionalMetric(torch.bitwise_or, self, other) + + def __pow__(self, other: Any): + from pytorch_lightning.metrics.compositional import CompositionalMetric + + return CompositionalMetric(torch.pow, self, other) + + def __radd__(self, other: Any): + from pytorch_lightning.metrics.compositional import CompositionalMetric + + return CompositionalMetric(torch.add, other, self) + + def __rand__(self, other: Any): + from pytorch_lightning.metrics.compositional import CompositionalMetric + + # swap them since bitwise_and only supports that way and it's commutative + return CompositionalMetric(torch.bitwise_and, self, other) + + def __rfloordiv__(self, other: Any): + from pytorch_lightning.metrics.compositional import CompositionalMetric + + return CompositionalMetric(torch.floor_divide, other, self) + + def __rmatmul__(self, other: Any): + from pytorch_lightning.metrics.compositional import CompositionalMetric + + return CompositionalMetric(torch.matmul, other, self) + + def __rmod__(self, other: Any): + from pytorch_lightning.metrics.compositional import CompositionalMetric + + return CompositionalMetric(torch.fmod, other, self) + + def __rmul__(self, other: Any): + from pytorch_lightning.metrics.compositional import CompositionalMetric + + return CompositionalMetric(torch.mul, other, self) + + def __ror__(self, other: Any): + from pytorch_lightning.metrics.compositional import CompositionalMetric + + return 
CompositionalMetric(torch.bitwise_or, other, self) + + def __rpow__(self, other: Any): + from pytorch_lightning.metrics.compositional import CompositionalMetric + + return CompositionalMetric(torch.pow, other, self) + + def __rsub__(self, other: Any): + from pytorch_lightning.metrics.compositional import CompositionalMetric + + return CompositionalMetric(torch.sub, other, self) + + def __rtruediv__(self, other: Any): + from pytorch_lightning.metrics.compositional import CompositionalMetric + + return CompositionalMetric(torch.true_divide, other, self) + + def __rxor__(self, other: Any): + from pytorch_lightning.metrics.compositional import CompositionalMetric + + return CompositionalMetric(torch.bitwise_xor, other, self) + + def __sub__(self, other: Any): + from pytorch_lightning.metrics.compositional import CompositionalMetric + + return CompositionalMetric(torch.sub, self, other) + + def __truediv__(self, other: Any): + from pytorch_lightning.metrics.compositional import CompositionalMetric + + return CompositionalMetric(torch.true_divide, self, other) + + def __xor__(self, other: Any): + from pytorch_lightning.metrics.compositional import CompositionalMetric + + return CompositionalMetric(torch.bitwise_xor, self, other) + + def __abs__(self): + from pytorch_lightning.metrics.compositional import CompositionalMetric + + return CompositionalMetric(torch.abs, self, None) + + def __inv__(self): + from pytorch_lightning.metrics.compositional import CompositionalMetric + + return CompositionalMetric(torch.bitwise_not, self, None) + + def __invert__(self): + return self.__inv__() + + def __neg__(self): + from pytorch_lightning.metrics.compositional import CompositionalMetric + + return CompositionalMetric(_neg, self, None) + + def __pos__(self): + from pytorch_lightning.metrics.compositional import CompositionalMetric + + return CompositionalMetric(torch.abs, self, None) + + +def _neg(tensor: torch.Tensor): + return -torch.abs(tensor) + + +class MetricCollection(nn.ModuleDict): + """ + MetricCollection class can be used to chain metrics that have the same + call pattern into one single class. + + Args: + metrics: One of the following + + * list or tuple: if metrics are passed in as a list, will use the + metrics class name as key for output dict. Therefore, two metrics + of the same class cannot be chained this way. + + * dict: if metrics are passed in as a dict, will use each key in the + dict as key for output dict. Use this format if you want to chain + together multiple of the same metric with different parameters. + + Raises: + ValueError: + If one of the elements of ``metrics`` is not an instance of ``pl.metrics.Metric``. + ValueError: + If two elements in ``metrics`` have the same ``name``. + ValueError: + If ``metrics`` is not a ``list``, ``tuple`` or a ``dict``. + + Example (input as list): + + >>> from pytorch_lightning.metrics import MetricCollection, Accuracy, Precision, Recall + >>> target = torch.tensor([0, 2, 0, 2, 0, 1, 0, 2]) + >>> preds = torch.tensor([2, 1, 2, 0, 1, 2, 2, 2]) + >>> metrics = MetricCollection([Accuracy(), + ... Precision(num_classes=3, average='macro'), + ... Recall(num_classes=3, average='macro')]) + >>> metrics(preds, target) + {'Accuracy': tensor(0.1250), 'Precision': tensor(0.0667), 'Recall': tensor(0.1111)} + + Example (input as dict): + + >>> metrics = MetricCollection({'micro_recall': Recall(num_classes=3, average='micro'), + ... 
'macro_recall': Recall(num_classes=3, average='macro')}) + >>> metrics(preds, target) + {'micro_recall': tensor(0.1250), 'macro_recall': tensor(0.1111)} + + """ - @deprecated_metrics(target=_MetricCollection) def __init__(self, metrics: Union[List[Metric], Tuple[Metric], Dict[str, Metric]]): + super().__init__() + if isinstance(metrics, dict): + # Check all values are metrics + for name, metric in metrics.items(): + if not isinstance(metric, Metric): + raise ValueError( + f"Value {metric} belonging to key {name}" + " is not an instance of `pl.metrics.Metric`" + ) + self[name] = metric + elif isinstance(metrics, (tuple, list)): + for metric in metrics: + if not isinstance(metric, Metric): + raise ValueError( + f"Input {metric} to `MetricCollection` is not a instance" + " of `pl.metrics.Metric`" + ) + name = metric.__class__.__name__ + if name in self: + raise ValueError(f"Encountered two metrics both named {name}") + self[name] = metric + else: + raise ValueError("Unknown input to MetricCollection.") + + def forward(self, *args, **kwargs) -> Dict[str, Any]: # pylint: disable=E0202 + """ + Iteratively call forward for each metric. Positional arguments (args) will + be passed to every metric in the collection, while keyword arguments (kwargs) + will be filtered based on the signature of the individual metric. + """ + return {k: m(*args, **m._filter_kwargs(**kwargs)) for k, m in self.items()} + + def update(self, *args, **kwargs): # pylint: disable=E0202 + """ + Iteratively call update for each metric. Positional arguments (args) will + be passed to every metric in the collection, while keyword arguments (kwargs) + will be filtered based on the signature of the individual metric. """ - .. deprecated:: - Use :class:`torchmetrics.MetricCollection`. Will be removed in v1.5.0. + for _, m in self.items(): + m_kwargs = m._filter_kwargs(**kwargs) + m.update(*args, **m_kwargs) + + def compute(self) -> Dict[str, Any]: + return {k: m.compute() for k, m in self.items()} + + def reset(self): + """ Iteratively call reset for each metric """ + for _, m in self.items(): + m.reset() + + def clone(self): + """ Make a copy of the metric collection """ + return deepcopy(self) + + def persistent(self, mode: bool = True): + """Method for post-init to change if metric states should be saved to + its state_dict """ + for _, m in self.items(): + m.persistent(mode) diff --git a/pytorch_lightning/metrics/regression/explained_variance.py b/pytorch_lightning/metrics/regression/explained_variance.py index 0f94ae2fb3754..fc033fcd16759 100644 --- a/pytorch_lightning/metrics/regression/explained_variance.py +++ b/pytorch_lightning/metrics/regression/explained_variance.py @@ -13,14 +13,72 @@ # limitations under the License. from typing import Any, Callable, Optional -from torchmetrics import ExplainedVariance as _ExplainedVariance +import torch -from pytorch_lightning.metrics.utils import deprecated_metrics +from pytorch_lightning.metrics.functional.explained_variance import ( + _explained_variance_compute, + _explained_variance_update, +) +from pytorch_lightning.metrics.metric import Metric +from pytorch_lightning.utilities import rank_zero_warn -class ExplainedVariance(_ExplainedVariance): +class ExplainedVariance(Metric): + r""" + Computes `explained variance + `_: + + .. math:: \text{ExplainedVariance} = 1 - \frac{\text{Var}(y - \hat{y})}{\text{Var}(y)} + + Where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a + tensor of predictions. 
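To make the ``Metric`` base-class contract above concrete, here is a minimal custom-metric sketch (not part of the patch). ``SumOfSquaredErrors`` is a hypothetical name chosen for illustration, and the import assumes ``Metric`` is re-exported from ``pytorch_lightning.metrics`` as in the docstring examples; the key point is the ``add_state`` / ``update`` / ``compute`` split, where tensor states are accumulated across batches and reduced across processes on ``compute()``.

    import torch
    from pytorch_lightning.metrics import Metric

    class SumOfSquaredErrors(Metric):
        """Toy metric illustrating the add_state / update / compute contract."""

        def __init__(self, compute_on_step: bool = True):
            super().__init__(compute_on_step=compute_on_step)
            # tensor states registered here are reset by reset() and summed
            # across processes when compute() synchronizes
            self.add_state("sum_sq_error", default=torch.tensor(0.0), dist_reduce_fx="sum")
            self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

        def update(self, preds: torch.Tensor, target: torch.Tensor):
            self.sum_sq_error += torch.sum((preds - target) ** 2)
            self.total += target.numel()

        def compute(self):
            return self.sum_sq_error / self.total

    metric = SumOfSquaredErrors()
    metric.update(torch.tensor([2.5, 0.0]), torch.tensor([3.0, -0.5]))
    metric.update(torch.tensor([2.0, 8.0]), torch.tensor([2.0, 7.0]))
    print(metric.compute())  # tensor(0.3750): (0.25 + 0.25 + 0.0 + 1.0) / 4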
+ + Forward accepts + + - ``preds`` (float tensor): ``(N,)`` or ``(N, ...)`` (multioutput) + - ``target`` (long tensor): ``(N,)`` or ``(N, ...)`` (multioutput) + + In the case of multioutput, as default the variances will be uniformly + averaged over the additional dimensions. Please see argument `multioutput` + for changing this behavior. + + Args: + multioutput: + Defines aggregation in the case of multiple output scores. Can be one + of the following strings (default is `'uniform_average'`.): + + * `'raw_values'` returns full set of scores + * `'uniform_average'` scores are uniformly averaged + * `'variance_weighted'` scores are weighted by their individual variances + + compute_on_step: + Forward only calls ``update()`` and return None if this is set to False. default: True + dist_sync_on_step: + Synchronize metric state across processes at each ``forward()`` + before returning the value at the step. default: False + process_group: + Specify the process group on which synchronization is called. default: None (which selects the entire world) + + Raises: + ValueError: + If ``multioutput`` is not one of ``"raw_values"``, ``"uniform_average"`` or ``"variance_weighted"``. + + Example: + + >>> from pytorch_lightning.metrics import ExplainedVariance + >>> target = torch.tensor([3, -0.5, 2, 7]) + >>> preds = torch.tensor([2.5, 0.0, 2, 8]) + >>> explained_variance = ExplainedVariance() + >>> explained_variance(preds, target) + tensor(0.9572) + + >>> target = torch.tensor([[0.5, 1], [-1, 1], [7, -6]]) + >>> preds = torch.tensor([[0, 2], [-1, 2], [8, -5]]) + >>> explained_variance = ExplainedVariance(multioutput='raw_values') + >>> explained_variance(preds, target) + tensor([0.9677, 1.0000]) + """ - @deprecated_metrics(target=_ExplainedVariance) def __init__( self, multioutput: str = 'uniform_average', @@ -29,9 +87,43 @@ def __init__( process_group: Optional[Any] = None, dist_sync_fn: Callable = None, ): + super().__init__( + compute_on_step=compute_on_step, + dist_sync_on_step=dist_sync_on_step, + process_group=process_group, + dist_sync_fn=dist_sync_fn, + ) + allowed_multioutput = ('raw_values', 'uniform_average', 'variance_weighted') + if multioutput not in allowed_multioutput: + raise ValueError( + f'Invalid input to argument `multioutput`. Choose one of the following: {allowed_multioutput}' + ) + self.multioutput = multioutput + self.add_state("y", default=[], dist_reduce_fx=None) + self.add_state("y_pred", default=[], dist_reduce_fx=None) + + rank_zero_warn( + 'Metric `ExplainedVariance` will save all targets and' + ' predictions in buffer. For large datasets this may lead' + ' to large memory footprint.' + ) + + def update(self, preds: torch.Tensor, target: torch.Tensor): + """ + Update state with predictions and targets. + + Args: + preds: Predictions from model + target: Ground truth values """ - This implementation refers to :class:`~torchmetrics.ExplainedVariance`. + preds, target = _explained_variance_update(preds, target) + self.y_pred.append(preds) + self.y.append(target) - .. deprecated:: - Use :class:`~torchmetrics.ExplainedVariance`. Will be removed in v1.5.0. + def compute(self): + """ + Computes explained variance over state. 
""" + preds = torch.cat(self.y_pred, dim=0) + target = torch.cat(self.y, dim=0) + return _explained_variance_compute(preds, target, self.multioutput) diff --git a/pytorch_lightning/metrics/regression/mean_absolute_error.py b/pytorch_lightning/metrics/regression/mean_absolute_error.py index 57c7db420445b..ca184daf736b8 100644 --- a/pytorch_lightning/metrics/regression/mean_absolute_error.py +++ b/pytorch_lightning/metrics/regression/mean_absolute_error.py @@ -13,14 +13,42 @@ # limitations under the License. from typing import Any, Callable, Optional -from torchmetrics import MeanAbsoluteError as _MeanAbsoluteError +import torch -from pytorch_lightning.metrics.utils import deprecated_metrics +from pytorch_lightning.metrics.functional.mean_absolute_error import ( + _mean_absolute_error_compute, + _mean_absolute_error_update, +) +from pytorch_lightning.metrics.metric import Metric -class MeanAbsoluteError(_MeanAbsoluteError): +class MeanAbsoluteError(Metric): + r""" + Computes `mean absolute error `_ (MAE): + + .. math:: \text{MAE} = \frac{1}{N}\sum_i^N | y_i - \hat{y_i} | + + Where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a tensor of predictions. + + Args: + compute_on_step: + Forward only calls ``update()`` and return None if this is set to False. default: True + dist_sync_on_step: + Synchronize metric state across processes at each ``forward()`` + before returning the value at the step. default: False + process_group: + Specify the process group on which synchronization is called. default: None (which selects the entire world) + + Example: + + >>> from pytorch_lightning.metrics import MeanAbsoluteError + >>> target = torch.tensor([3.0, -0.5, 2.0, 7.0]) + >>> preds = torch.tensor([2.5, 0.0, 2.0, 8.0]) + >>> mean_absolute_error = MeanAbsoluteError() + >>> mean_absolute_error(preds, target) + tensor(0.5000) + """ - @deprecated_metrics(target=_MeanAbsoluteError) def __init__( self, compute_on_step: bool = True, @@ -28,9 +56,31 @@ def __init__( process_group: Optional[Any] = None, dist_sync_fn: Callable = None, ): + super().__init__( + compute_on_step=compute_on_step, + dist_sync_on_step=dist_sync_on_step, + process_group=process_group, + dist_sync_fn=dist_sync_fn, + ) + + self.add_state("sum_abs_error", default=torch.tensor(0.0), dist_reduce_fx="sum") + self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") + + def update(self, preds: torch.Tensor, target: torch.Tensor): + """ + Update state with predictions and targets. + + Args: + preds: Predictions from model + target: Ground truth values """ - This implementation refers to :class:`~torchmetrics.MeanAbsoluteError`. + sum_abs_error, n_obs = _mean_absolute_error_update(preds, target) - .. deprecated:: - Use :class:`~torchmetrics.MeanAbsoluteError`. Will be removed in v1.5.0. + self.sum_abs_error += sum_abs_error + self.total += n_obs + + def compute(self): + """ + Computes mean absolute error over state. """ + return _mean_absolute_error_compute(self.sum_abs_error, self.total) diff --git a/pytorch_lightning/metrics/regression/mean_squared_error.py b/pytorch_lightning/metrics/regression/mean_squared_error.py index c8e9c151c99d9..09f275ded8638 100644 --- a/pytorch_lightning/metrics/regression/mean_squared_error.py +++ b/pytorch_lightning/metrics/regression/mean_squared_error.py @@ -13,14 +13,43 @@ # limitations under the License. 
from typing import Any, Callable, Optional -from torchmetrics import MeanSquaredError as _MeanSquaredError +import torch -from pytorch_lightning.metrics.utils import deprecated_metrics +from pytorch_lightning.metrics.functional.mean_squared_error import ( + _mean_squared_error_compute, + _mean_squared_error_update, +) +from pytorch_lightning.metrics.metric import Metric -class MeanSquaredError(_MeanSquaredError): +class MeanSquaredError(Metric): + r""" + Computes `mean squared error `_ (MSE): + + .. math:: \text{MSE} = \frac{1}{N}\sum_i^N(y_i - \hat{y_i})^2 + + Where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a tensor of predictions. + + Args: + compute_on_step: + Forward only calls ``update()`` and return None if this is set to False. default: True + dist_sync_on_step: + Synchronize metric state across processes at each ``forward()`` + before returning the value at the step. default: False + process_group: + Specify the process group on which synchronization is called. default: None (which selects the entire world) + + Example: + + >>> from pytorch_lightning.metrics import MeanSquaredError + >>> target = torch.tensor([2.5, 5.0, 4.0, 8.0]) + >>> preds = torch.tensor([3.0, 5.0, 2.5, 7.0]) + >>> mean_squared_error = MeanSquaredError() + >>> mean_squared_error(preds, target) + tensor(0.8750) + + """ - @deprecated_metrics(target=_MeanSquaredError) def __init__( self, compute_on_step: bool = True, @@ -28,9 +57,31 @@ def __init__( process_group: Optional[Any] = None, dist_sync_fn: Callable = None, ): + super().__init__( + compute_on_step=compute_on_step, + dist_sync_on_step=dist_sync_on_step, + process_group=process_group, + dist_sync_fn=dist_sync_fn, + ) + + self.add_state("sum_squared_error", default=torch.tensor(0.0), dist_reduce_fx="sum") + self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") + + def update(self, preds: torch.Tensor, target: torch.Tensor): + """ + Update state with predictions and targets. + + Args: + preds: Predictions from model + target: Ground truth values """ - This implementation refers to :class:`~torchmetrics.MeanSquaredError`. + sum_squared_error, n_obs = _mean_squared_error_update(preds, target) + + self.sum_squared_error += sum_squared_error + self.total += n_obs - .. deprecated:: - Use :class:`~torchmetrics.MeanSquaredError`. Will be removed in v1.5.0. + def compute(self): + """ + Computes mean squared error over state. """ + return _mean_squared_error_compute(self.sum_squared_error, self.total) diff --git a/pytorch_lightning/metrics/regression/mean_squared_log_error.py b/pytorch_lightning/metrics/regression/mean_squared_log_error.py index c8ee8a7069115..18105e687b0b1 100644 --- a/pytorch_lightning/metrics/regression/mean_squared_log_error.py +++ b/pytorch_lightning/metrics/regression/mean_squared_log_error.py @@ -13,14 +13,45 @@ # limitations under the License. from typing import Any, Callable, Optional -from torchmetrics import MeanSquaredLogError as _MeanSquaredLogError +import torch -from pytorch_lightning.metrics.utils import deprecated_metrics +from pytorch_lightning.metrics.functional.mean_squared_log_error import ( + _mean_squared_log_error_compute, + _mean_squared_log_error_update, +) +from pytorch_lightning.metrics.metric import Metric -class MeanSquaredLogError(_MeanSquaredLogError): +class MeanSquaredLogError(Metric): + r""" + Computes `mean squared logarithmic error + `_ + (MSLE): + + .. 
math:: \text{MSLE} = \frac{1}{N}\sum_i^N (\log_e(1 + y_i) - \log_e(1 + \hat{y_i}))^2 + + Where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a tensor of predictions. + + Args: + compute_on_step: + Forward only calls ``update()`` and return None if this is set to False. default: True + dist_sync_on_step: + Synchronize metric state across processes at each ``forward()`` + before returning the value at the step. default: False + process_group: + Specify the process group on which synchronization is called. default: None (which selects the entire world) + + Example: + + >>> from pytorch_lightning.metrics import MeanSquaredLogError + >>> target = torch.tensor([2.5, 5, 4, 8]) + >>> preds = torch.tensor([3, 5, 2.5, 7]) + >>> mean_squared_log_error = MeanSquaredLogError() + >>> mean_squared_log_error(preds, target) + tensor(0.0397) + + """ - @deprecated_metrics(target=_MeanSquaredLogError) def __init__( self, compute_on_step: bool = True, @@ -28,9 +59,31 @@ def __init__( process_group: Optional[Any] = None, dist_sync_fn: Callable = None, ): + super().__init__( + compute_on_step=compute_on_step, + dist_sync_on_step=dist_sync_on_step, + process_group=process_group, + dist_sync_fn=dist_sync_fn, + ) + + self.add_state("sum_squared_log_error", default=torch.tensor(0.0), dist_reduce_fx="sum") + self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") + + def update(self, preds: torch.Tensor, target: torch.Tensor): + """ + Update state with predictions and targets. + + Args: + preds: Predictions from model + target: Ground truth values """ - This implementation refers to :class:`~torchmetrics.MeanSquaredLogError`. + sum_squared_log_error, n_obs = _mean_squared_log_error_update(preds, target) + + self.sum_squared_log_error += sum_squared_log_error + self.total += n_obs - .. deprecated:: - Use :class:`~torchmetrics.MeanSquaredLogError`. Will be removed in v1.5.0. + def compute(self): + """ + Compute mean squared logarithmic error over state. """ + return _mean_squared_log_error_compute(self.sum_squared_log_error, self.total) diff --git a/pytorch_lightning/metrics/regression/psnr.py b/pytorch_lightning/metrics/regression/psnr.py index f972e9a8e2b5e..8a38bf515ebca 100644 --- a/pytorch_lightning/metrics/regression/psnr.py +++ b/pytorch_lightning/metrics/regression/psnr.py @@ -11,16 +11,61 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Tuple, Union +from typing import Any, Optional, Sequence, Tuple, Union -from torchmetrics import PSNR as _PSNR +import torch -from pytorch_lightning.metrics.utils import deprecated_metrics +from pytorch_lightning import utilities +from pytorch_lightning.metrics.functional.psnr import _psnr_compute, _psnr_update +from pytorch_lightning.metrics.metric import Metric -class PSNR(_PSNR): +class PSNR(Metric): + r""" + Computes `peak signal-to-noise ratio `_ (PSNR): + + .. math:: \text{PSNR}(I, J) = 10 * \log_{10} \left(\frac{\max(I)^2}{\text{MSE}(I, J)}\right) + + Where :math:`\text{MSE}` denotes the `mean-squared-error + `_ function. + + Args: + data_range: + the range of the data. If None, it is determined from the data (max - min). + The ``data_range`` must be given when ``dim`` is not None. + base: a base of a logarithm to use (default: 10) + reduction: a method to reduce metric score over labels. 
+ + - ``'elementwise_mean'``: takes the mean (default) + - ``'sum'``: takes the sum + - ``'none'``: no reduction will be applied + + dim: + Dimensions to reduce PSNR scores over, provided as either an integer or a list of integers. Default is + None meaning scores will be reduced across all dimensions and all batches. + compute_on_step: + Forward only calls ``update()`` and return None if this is set to False. default: True + dist_sync_on_step: + Synchronize metric state across processes at each ``forward()`` + before returning the value at the step. default: False + process_group: + Specify the process group on which synchronization is called. default: None (which selects the entire world) + + Raises: + ValueError: + If ``dim`` is not ``None`` and ``data_range`` is not given. + + Example: + + >>> from pytorch_lightning.metrics import PSNR + >>> psnr = PSNR() + >>> preds = torch.tensor([[0.0, 1.0], [2.0, 3.0]]) + >>> target = torch.tensor([[3.0, 2.0], [1.0, 0.0]]) + >>> psnr(preds, target) + tensor(2.5527) + + """ - @deprecated_metrics(target=_PSNR) def __init__( self, data_range: Optional[float] = None, @@ -31,9 +76,71 @@ def __init__( dist_sync_on_step: bool = False, process_group: Optional[Any] = None, ): + super().__init__( + compute_on_step=compute_on_step, + dist_sync_on_step=dist_sync_on_step, + process_group=process_group, + ) + + if dim is None and reduction != 'elementwise_mean': + utilities.rank_zero_warn(f'The `reduction={reduction}` will not have any effect when `dim` is None.') + + if dim is None: + self.add_state("sum_squared_error", default=torch.tensor(0.0), dist_reduce_fx="sum") + self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") + else: + self.add_state("sum_squared_error", default=[]) + self.add_state("total", default=[]) + + if data_range is None: + if dim is not None: + # Maybe we could use `torch.amax(target, dim=dim) - torch.amin(target, dim=dim)` in PyTorch 1.7 to + # calculate `data_range` in the future. + raise ValueError("The `data_range` must be given when `dim` is not None.") + + self.data_range = None + self.add_state("min_target", default=torch.tensor(0.0), dist_reduce_fx=torch.min) + self.add_state("max_target", default=torch.tensor(0.0), dist_reduce_fx=torch.max) + else: + self.register_buffer("data_range", torch.tensor(float(data_range))) + self.base = base + self.reduction = reduction + self.dim = tuple(dim) if isinstance(dim, Sequence) else dim + + def update(self, preds: torch.Tensor, target: torch.Tensor): """ - This implementation refers to :class:`~torchmetrics.PSNR`. + Update state with predictions and targets. - .. deprecated:: - Use :class:`~torchmetrics.PSNR`. Will be removed in v1.5.0. + Args: + preds: Predictions from model + target: Ground truth values """ + sum_squared_error, n_obs = _psnr_update(preds, target, dim=self.dim) + if self.dim is None: + if self.data_range is None: + # keep track of min and max target values + self.min_target = min(target.min(), self.min_target) + self.max_target = max(target.max(), self.max_target) + + self.sum_squared_error += sum_squared_error + self.total += n_obs + else: + self.sum_squared_error.append(sum_squared_error) + self.total.append(n_obs) + + def compute(self): + """ + Compute peak signal-to-noise ratio over state. 
+ """ + if self.data_range is not None: + data_range = self.data_range + else: + data_range = self.max_target - self.min_target + + if self.dim is None: + sum_squared_error = self.sum_squared_error + total = self.total + else: + sum_squared_error = torch.cat([values.flatten() for values in self.sum_squared_error]) + total = torch.cat([values.flatten() for values in self.total]) + return _psnr_compute(sum_squared_error, total, data_range, base=self.base, reduction=self.reduction) diff --git a/pytorch_lightning/metrics/regression/r2score.py b/pytorch_lightning/metrics/regression/r2score.py index ad5f7f3bd8d07..40d9d24711375 100644 --- a/pytorch_lightning/metrics/regression/r2score.py +++ b/pytorch_lightning/metrics/regression/r2score.py @@ -13,14 +13,81 @@ # limitations under the License. from typing import Any, Callable, Optional -from torchmetrics import R2Score as _R2Score +import torch -from pytorch_lightning.metrics.utils import deprecated_metrics +from pytorch_lightning.metrics.functional.r2score import _r2score_compute, _r2score_update +from pytorch_lightning.metrics.metric import Metric -class R2Score(_R2Score): +class R2Score(Metric): + r""" + Computes r2 score also known as `coefficient of determination + `_: + + .. math:: R^2 = 1 - \frac{SS_res}{SS_tot} + + where :math:`SS_res=\sum_i (y_i - f(x_i))^2` is the sum of residual squares, and + :math:`SS_tot=\sum_i (y_i - \bar{y})^2` is total sum of squares. Can also calculate + adjusted r2 score given by + + .. math:: R^2_adj = 1 - \frac{(1-R^2)(n-1)}{n-k-1} + + where the parameter :math:`k` (the number of independent regressors) should + be provided as the `adjusted` argument. + + Forward accepts + + - ``preds`` (float tensor): ``(N,)`` or ``(N, M)`` (multioutput) + - ``target`` (float tensor): ``(N,)`` or ``(N, M)`` (multioutput) + + In the case of multioutput, as default the variances will be uniformly + averaged over the additional dimensions. Please see argument `multioutput` + for changing this behavior. + + Args: + num_outputs: + Number of outputs in multioutput setting (default is 1) + adjusted: + number of independent regressors for calculating adjusted r2 score. + Default 0 (standard r2 score). + multioutput: + Defines aggregation in the case of multiple output scores. Can be one + of the following strings (default is ``'uniform_average'``.): + + * ``'raw_values'`` returns full set of scores + * ``'uniform_average'`` scores are uniformly averaged + * ``'variance_weighted'`` scores are weighted by their individual variances + + compute_on_step: + Forward only calls ``update()`` and return None if this is set to False. default: True + dist_sync_on_step: + Synchronize metric state across processes at each ``forward()`` + before returning the value at the step. default: False + process_group: + Specify the process group on which synchronization is called. default: None (which selects the entire world) + + Raises: + ValueError: + If ``adjusted`` parameter is not an integer larger or equal to 0. + ValueError: + If ``multioutput`` is not one of ``"raw_values"``, ``"uniform_average"`` or ``"variance_weighted"``. 
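As an illustrative cross-check of the formula above (plain arithmetic, not part of the metric's API), the doctest value below can be reproduced directly:

.. code-block:: python

    # sanity check of R^2 = 1 - SS_res / SS_tot with the same numbers as the doctest below
    import torch

    target = torch.tensor([3.0, -0.5, 2.0, 7.0])
    preds = torch.tensor([2.5, 0.0, 2.0, 8.0])

    ss_res = torch.sum((target - preds) ** 2)            # 1.5
    ss_tot = torch.sum((target - target.mean()) ** 2)    # 29.1875
    r2 = 1 - ss_res / ss_tot                              # ~0.9486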
+ + Example: + + >>> from pytorch_lightning.metrics import R2Score + >>> target = torch.tensor([3, -0.5, 2, 7]) + >>> preds = torch.tensor([2.5, 0.0, 2, 8]) + >>> r2score = R2Score() + >>> r2score(preds, target) + tensor(0.9486) + + >>> target = torch.tensor([[0.5, 1], [-1, 1], [7, -6]]) + >>> preds = torch.tensor([[0, 2], [-1, 2], [8, -5]]) + >>> r2score = R2Score(num_outputs=2, multioutput='raw_values') + >>> r2score(preds, target) + tensor([0.9654, 0.9082]) + """ - @deprecated_metrics(target=_R2Score) def __init__( self, num_outputs: int = 1, @@ -31,9 +98,50 @@ def __init__( process_group: Optional[Any] = None, dist_sync_fn: Callable = None, ): + super().__init__( + compute_on_step=compute_on_step, + dist_sync_on_step=dist_sync_on_step, + process_group=process_group, + dist_sync_fn=dist_sync_fn, + ) + + self.num_outputs = num_outputs + + if adjusted < 0 or not isinstance(adjusted, int): + raise ValueError('`adjusted` parameter should be an integer larger or equal to 0.') + self.adjusted = adjusted + + allowed_multioutput = ('raw_values', 'uniform_average', 'variance_weighted') + if multioutput not in allowed_multioutput: + raise ValueError( + f'Invalid input to argument `multioutput`. Choose one of the following: {allowed_multioutput}' + ) + self.multioutput = multioutput + + self.add_state("sum_squared_error", default=torch.zeros(self.num_outputs), dist_reduce_fx="sum") + self.add_state("sum_error", default=torch.zeros(self.num_outputs), dist_reduce_fx="sum") + self.add_state("residual", default=torch.zeros(self.num_outputs), dist_reduce_fx="sum") + self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") + + def update(self, preds: torch.Tensor, target: torch.Tensor): + """ + Update state with predictions and targets. + + Args: + preds: Predictions from model + target: Ground truth values """ - This implementation refers to :class:`~torchmetrics.R2Score`. + sum_squared_error, sum_error, residual, total = _r2score_update(preds, target) - .. deprecated:: - Use :class:`~torchmetrics.R2Score`. Will be removed in v1.5.0. + self.sum_squared_error += sum_squared_error + self.sum_error += sum_error + self.residual += residual + self.total += total + + def compute(self) -> torch.Tensor: + """ + Computes r2 score over the metric states. """ + return _r2score_compute( + self.sum_squared_error, self.sum_error, self.residual, self.total, self.adjusted, self.multioutput + ) diff --git a/pytorch_lightning/metrics/regression/ssim.py b/pytorch_lightning/metrics/regression/ssim.py index cf5571f3e68f4..09b55fb2bb456 100644 --- a/pytorch_lightning/metrics/regression/ssim.py +++ b/pytorch_lightning/metrics/regression/ssim.py @@ -13,14 +13,43 @@ # limitations under the License. from typing import Any, Optional, Sequence -from torchmetrics import SSIM as _SSIM +import torch -from pytorch_lightning.metrics.utils import deprecated_metrics +from pytorch_lightning.metrics.functional.ssim import _ssim_compute, _ssim_update +from pytorch_lightning.metrics.metric import Metric +from pytorch_lightning.utilities import rank_zero_warn -class SSIM(_SSIM): +class SSIM(Metric): + """ + Computes `Structual Similarity Index Measure + `_ (SSIM). + + Args: + kernel_size: size of the gaussian kernel (default: (11, 11)) + sigma: Standard deviation of the gaussian kernel (default: (1.5, 1.5)) + reduction: a method to reduce metric score over labels. + + - ``'elementwise_mean'``: takes the mean (default) + - ``'sum'``: takes the sum + - ``'none'``: no reduction will be applied + + data_range: Range of the image. 
If ``None``, it is determined from the image (max - min) + k1: Parameter of SSIM. Default: 0.01 + k2: Parameter of SSIM. Default: 0.03 + + Return: + Tensor with SSIM score + + Example: + >>> from pytorch_lightning.metrics import SSIM + >>> preds = torch.rand([16, 1, 16, 16]) + >>> target = preds * 0.75 + >>> ssim = SSIM() + >>> ssim(preds, target) + tensor(0.9219) + """ - @deprecated_metrics(target=_SSIM) def __init__( self, kernel_size: Sequence[int] = (11, 11), @@ -33,9 +62,44 @@ def __init__( dist_sync_on_step: bool = False, process_group: Optional[Any] = None, ): + super().__init__( + compute_on_step=compute_on_step, + dist_sync_on_step=dist_sync_on_step, + process_group=process_group, + ) + rank_zero_warn( + 'Metric `SSIM` will save all targets and' + ' predictions in buffer. For large datasets this may lead' + ' to large memory footprint.' + ) + + self.add_state("y", default=[], dist_reduce_fx=None) + self.add_state("y_pred", default=[], dist_reduce_fx=None) + self.kernel_size = kernel_size + self.sigma = sigma + self.data_range = data_range + self.k1 = k1 + self.k2 = k2 + self.reduction = reduction + + def update(self, preds: torch.Tensor, target: torch.Tensor): """ - This implementation refers to :class:`~torchmetrics.SSIM`. + Update state with predictions and targets. - .. deprecated:: - Use :class:`~torchmetrics.SSIM`. Will be removed in v1.5.0. + Args: + preds: Predictions from model + target: Ground truth values + """ + preds, target = _ssim_update(preds, target) + self.y_pred.append(preds) + self.y.append(target) + + def compute(self): + """ + Computes explained variance over state. """ + preds = torch.cat(self.y_pred, dim=0) + target = torch.cat(self.y, dim=0) + return _ssim_compute( + preds, target, self.kernel_size, self.sigma, self.reduction, self.data_range, self.k1, self.k2 + ) diff --git a/pytorch_lightning/metrics/utils.py b/pytorch_lightning/metrics/utils.py index 4adc88a37ba21..cd0713fde0173 100644 --- a/pytorch_lightning/metrics/utils.py +++ b/pytorch_lightning/metrics/utils.py @@ -11,86 +11,293 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
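Since ``SSIM`` above buffers every ``preds``/``target`` pair until ``compute()`` is called, a typical evaluation loop looks roughly like the sketch below; ``batches`` is a placeholder for real data:

.. code-block:: python

    # rough usage sketch for the stateful SSIM metric above; `batches` is illustrative data
    import torch
    from pytorch_lightning.metrics import SSIM

    ssim = SSIM()
    batches = [(torch.rand(4, 1, 16, 16), torch.rand(4, 1, 16, 16)) for _ in range(3)]
    for preds, target in batches:
        ssim.update(preds, target)   # tensors are appended to the state buffers
    score = ssim.compute()           # buffers are concatenated and SSIM is computed once
    ssim.reset()                     # frees the buffered tensors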
-from functools import partial -from typing import Optional +from typing import Optional, Tuple import torch -from deprecate import deprecated -from torchmetrics.utilities.data import dim_zero_cat as _dim_zero_cat -from torchmetrics.utilities.data import dim_zero_mean as _dim_zero_mean -from torchmetrics.utilities.data import dim_zero_sum as _dim_zero_sum -from torchmetrics.utilities.data import get_num_classes as _get_num_classes -from torchmetrics.utilities.data import select_topk as _select_topk -from torchmetrics.utilities.data import to_categorical as _to_categorical -from torchmetrics.utilities.data import to_onehot as _to_onehot -from torchmetrics.utilities.distributed import class_reduce as _class_reduce -from torchmetrics.utilities.distributed import reduce as _reduce -from pytorch_lightning.utilities import rank_zero_deprecation +from pytorch_lightning.utilities import rank_zero_warn -deprecated_metrics = partial(deprecated, deprecated_in="1.3.0", remove_in="1.5.0", stream=rank_zero_deprecation) +METRIC_EPS = 1e-6 -@deprecated_metrics(target=_dim_zero_cat) def dim_zero_cat(x): - pass + x = x if isinstance(x, (list, tuple)) else [x] + return torch.cat(x, dim=0) -@deprecated_metrics(target=_dim_zero_sum) def dim_zero_sum(x): - pass + return torch.sum(x, dim=0) -@deprecated_metrics(target=_dim_zero_mean) def dim_zero_mean(x): - pass + return torch.mean(x, dim=0) -@deprecated_metrics(target=_to_onehot) -def to_onehot(label_tensor: torch.Tensor, num_classes: Optional[int] = None) -> torch.Tensor: +def _flatten(x): + return [item for sublist in x for item in sublist] + + +def _check_same_shape(pred: torch.Tensor, target: torch.Tensor): + """ Check that predictions and target have the same shape, else raise error """ + if pred.shape != target.shape: + raise RuntimeError("Predictions and targets are expected to have the same shape") + + +def _input_format_classification_one_hot( + num_classes: int, + preds: torch.Tensor, + target: torch.Tensor, + threshold: float = 0.5, + multilabel: bool = False +) -> Tuple[torch.Tensor, torch.Tensor]: + """Convert preds and target tensors into one hot spare label tensors + + Args: + num_classes: number of classes + preds: either tensor with labels, tensor with probabilities/logits or + multilabel tensor + target: tensor with ground true labels + threshold: float used for thresholding multilabel input + multilabel: boolean flag indicating if input is multilabel + + Raises: + ValueError: + If ``preds`` and ``target`` don't have the same number of dimensions + or one additional dimension for ``preds``. + + Returns: + preds: one hot tensor of shape [num_classes, -1] with predicted labels + target: one hot tensors of shape [num_classes, -1] with true labels """ - .. deprecated:: - Use :func:`torchmetrics.utilities.data.to_onehot`. Will be removed in v1.5.0. 
+ if not (preds.ndim == target.ndim or preds.ndim == target.ndim + 1): + raise ValueError("preds and target must have same number of dimensions, or one additional dimension for preds") + + if preds.ndim == target.ndim + 1: + # multi class probabilites + preds = torch.argmax(preds, dim=1) + + if preds.ndim == target.ndim and preds.dtype in (torch.long, torch.int) and num_classes > 1 and not multilabel: + # multi-class + preds = to_onehot(preds, num_classes=num_classes) + target = to_onehot(target, num_classes=num_classes) + + elif preds.ndim == target.ndim and preds.is_floating_point(): + # binary or multilabel probablities + preds = (preds >= threshold).long() + + # transpose class as first dim and reshape + if preds.ndim > 1: + preds = preds.transpose(1, 0) + target = target.transpose(1, 0) + + return preds.reshape(num_classes, -1), target.reshape(num_classes, -1) + + +def to_onehot( + label_tensor: torch.Tensor, + num_classes: Optional[int] = None, +) -> torch.Tensor: """ + Converts a dense label tensor to one-hot format + + Args: + label_tensor: dense label tensor, with shape [N, d1, d2, ...] + num_classes: number of classes C + + Returns: + A sparse label tensor with shape [N, C, d1, d2, ...] + + Example: + + >>> from pytorch_lightning.metrics.utils import to_onehot + >>> x = torch.tensor([1, 2, 3]) + >>> to_onehot(x) + tensor([[0, 1, 0, 0], + [0, 0, 1, 0], + [0, 0, 0, 1]]) + """ + if num_classes is None: + num_classes = int(label_tensor.max().detach().item() + 1) + + tensor_onehot = torch.zeros( + label_tensor.shape[0], + num_classes, + *label_tensor.shape[1:], + dtype=label_tensor.dtype, + device=label_tensor.device, + ) + index = label_tensor.long().unsqueeze(1).expand_as(tensor_onehot) + return tensor_onehot.scatter_(1, index, 1.0) -@deprecated_metrics(target=_select_topk) def select_topk(prob_tensor: torch.Tensor, topk: int = 1, dim: int = 1) -> torch.Tensor: """ - .. deprecated:: - Use :func:`torchmetrics.utilities.data.select_topk`. Will be removed in v1.5.0. + Convert a probability tensor to binary by selecting top-k highest entries. + + Args: + prob_tensor: dense tensor of shape ``[..., C, ...]``, where ``C`` is in the + position defined by the ``dim`` argument + topk: number of highest entries to turn into 1s + dim: dimension on which to compare entries + + Returns: + A binary tensor of the same shape as the input tensor of type torch.int32 + + Example: + + >>> from pytorch_lightning.metrics.utils import select_topk + >>> x = torch.tensor([[1.1, 2.0, 3.0], [2.0, 1.0, 0.5]]) + >>> select_topk(x, topk=2) + tensor([[0, 1, 1], + [1, 1, 0]], dtype=torch.int32) """ + zeros = torch.zeros_like(prob_tensor) + topk_tensor = zeros.scatter(dim, prob_tensor.topk(k=topk, dim=dim).indices, 1.0) + return topk_tensor.int() -@deprecated_metrics(target=_to_categorical) def to_categorical(tensor: torch.Tensor, argmax_dim: int = 1) -> torch.Tensor: """ - .. deprecated:: - Use :func:`torchmetrics.utilities.data.to_categorical`. Will be removed in v1.5.0. + Converts a tensor of probabilities to a dense label tensor + + Args: + tensor: probabilities to get the categorical label [N, d1, d2, ...] + argmax_dim: dimension to apply + + Return: + A tensor with categorical labels [N, d2, ...] 
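The ``to_onehot`` doctest above infers ``num_classes`` from the data; passing it explicitly simply widens the one-hot dimension, as in this small illustrative check:

.. code-block:: python

    # illustrative: an explicit num_classes widens the one-hot dimension
    import torch
    from pytorch_lightning.metrics.utils import to_onehot

    x = torch.tensor([1, 2, 3])
    onehot = to_onehot(x, num_classes=5)    # shape [3, 5] instead of the inferred [3, 4]
    assert onehot.shape == (3, 5)
    assert onehot.sum().item() == 3         # exactly one non-zero entry per row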
+ + Example: + + >>> from pytorch_lightning.metrics.utils import to_categorical + >>> x = torch.tensor([[0.2, 0.5], [0.9, 0.1]]) + >>> to_categorical(x) + tensor([1, 0]) """ + return torch.argmax(tensor, dim=argmax_dim) -@deprecated_metrics(target=_get_num_classes) -def get_num_classes(pred: torch.Tensor, target: torch.Tensor, num_classes: Optional[int] = None) -> int: +def get_num_classes( + pred: torch.Tensor, + target: torch.Tensor, + num_classes: Optional[int] = None, +) -> int: """ - .. deprecated:: - Use :func:`torchmetrics.utilities.data.get_num_classes`. Will be removed in v1.5.0. + Calculates the number of classes for a given prediction and target tensor. + + Args: + pred: predicted values + target: true labels + num_classes: number of classes if known + + Return: + An integer that represents the number of classes. """ + num_target_classes = int(target.max().detach().item() + 1) + num_pred_classes = int(pred.max().detach().item() + 1) + num_all_classes = max(num_target_classes, num_pred_classes) + + if num_classes is None: + num_classes = num_all_classes + elif num_classes != num_all_classes: + rank_zero_warn( + f"You have set {num_classes} number of classes which is" + f" different from predicted ({num_pred_classes}) and" + f" target ({num_target_classes}) number of classes", + RuntimeWarning, + ) + return num_classes -@deprecated_metrics(target=_reduce) def reduce(to_reduce: torch.Tensor, reduction: str) -> torch.Tensor: """ - .. deprecated:: - Use :func:`torchmetrics.utilities.reduce`. Will be removed in v1.5.0. + Reduces a given tensor by a given reduction method + + Args: + to_reduce : the tensor, which shall be reduced + reduction : a string specifying the reduction method ('elementwise_mean', 'none', 'sum') + + Return: + reduced Tensor + + Raise: + ValueError if an invalid reduction parameter was given """ + if reduction == "elementwise_mean": + return torch.mean(to_reduce) + if reduction == "none": + return to_reduce + if reduction == "sum": + return torch.sum(to_reduce) + raise ValueError("Reduction parameter unknown.") -@deprecated_metrics(target=_class_reduce) def class_reduce( num: torch.Tensor, denom: torch.Tensor, weights: torch.Tensor, class_reduction: str = "none" ) -> torch.Tensor: """ - .. deprecated:: - Use :func:`torchmetrics.utilities.class_reduce`. Will be removed in v1.5.0. + Function used to reduce classification metrics of the form `num / denom * weights`. + For example for calculating standard accuracy the num would be number of + true positives per class, denom would be the support per class, and weights + would be a tensor of 1s + + Args: + num: numerator tensor + denom: denominator tensor + weights: weights for each class + class_reduction: reduction method for multiclass problems + + - ``'micro'``: calculate metrics globally (default) + - ``'macro'``: calculate metrics for each label, and find their unweighted mean. + - ``'weighted'``: calculate metrics for each label, and find their weighted mean. + - ``'none'`` or ``None``: returns calculated metric per class + + Raises: + ValueError: + If ``class_reduction`` is none of ``"micro"``, ``"macro"``, ``"weighted"``, ``"none"`` or ``None``. 
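For intuition about the reduction modes documented above, a small illustrative example (numbers chosen arbitrarily):

.. code-block:: python

    # illustrative: two classes with per-class fractions 0.5 and 1.0
    import torch
    from pytorch_lightning.metrics.utils import class_reduce

    num = torch.tensor([2.0, 1.0])       # e.g. true positives per class
    denom = torch.tensor([4.0, 1.0])     # e.g. support per class
    weights = torch.tensor([4.0, 1.0])

    class_reduce(num, denom, weights, class_reduction="micro")     # 3 / 5 = 0.6
    class_reduce(num, denom, weights, class_reduction="macro")     # mean(0.5, 1.0) = 0.75
    class_reduce(num, denom, weights, class_reduction="weighted")  # 0.5 * 4/5 + 1.0 * 1/5 = 0.6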
+ """ + valid_reduction = ("micro", "macro", "weighted", "none", None) + if class_reduction == "micro": + fraction = torch.sum(num) / torch.sum(denom) + else: + fraction = num / denom + + # We need to take care of instances where the denom can be 0 + # for some (or all) classes which will produce nans + fraction[fraction != fraction] = 0 + + if class_reduction == "micro": + return fraction + elif class_reduction == "macro": + return torch.mean(fraction) + elif class_reduction == "weighted": + return torch.sum(fraction * (weights.float() / torch.sum(weights))) + elif class_reduction == "none" or class_reduction is None: + return fraction + + raise ValueError( + f"Reduction parameter {class_reduction} unknown." + f" Choose between one of these: {valid_reduction}" + ) + + +def _stable_1d_sort(x: torch, N: int = 2049): + """ + Stable sort of 1d tensors. Pytorch defaults to a stable sorting algorithm + if number of elements are larger than 2048. This function pads the tensors, + makes the sort and returns the sorted array (with the padding removed) + See this discussion: https://discuss.pytorch.org/t/is-torch-sort-stable/20714 + + Raises: + ValueError: + If dim of ``x`` is greater than 1 since stable sort works with only 1d tensors. """ + if x.ndim > 1: + raise ValueError('Stable sort only works on 1d tensors') + n = x.numel() + if N - n > 0: + x_max = x.max() + x = torch.cat([x, (x_max + 1) * torch.ones(N - n, dtype=x.dtype, device=x.device)], 0) + x_sort = x.sort() + i = min(N, n) + return x_sort.values[:i], x_sort.indices[:i] diff --git a/pytorch_lightning/overrides/base.py b/pytorch_lightning/overrides/base.py index 0c1ac7b359fd0..1d6f4e93b5779 100644 --- a/pytorch_lightning/overrides/base.py +++ b/pytorch_lightning/overrides/base.py @@ -53,7 +53,7 @@ def forward(self, *inputs, **kwargs): elif trainer and (trainer.sanity_checking or trainer.validating): output = self.module.validation_step(*inputs, **kwargs) elif trainer and trainer.predicting: - output = self.module.predict_step(*inputs, **kwargs) + output = self.module.predict(*inputs, **kwargs) else: output = self.module(*inputs, **kwargs) diff --git a/pytorch_lightning/overrides/torch_distributed.py b/pytorch_lightning/overrides/torch_distributed.py deleted file mode 100644 index 67b64c046dc18..0000000000000 --- a/pytorch_lightning/overrides/torch_distributed.py +++ /dev/null @@ -1,94 +0,0 @@ -import logging -import pickle - -import torch - -from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_7 - -log = logging.getLogger(__name__) - -if torch.distributed.is_available(): - from torch.distributed import Backend, broadcast, get_backend, get_rank, GroupMember - -# The code underneath is taken from PyTorch ``torch/distributed/distributed_c10d.py`` -# and enable broadcasting for PyTorch 1.6 and lower. - - -# https://github.com/pytorch/pytorch/blob/1.7/torch/distributed/distributed_c10d.py#L160 -def _rank_not_in_group(group): - """ - Helper that checks if the current process's rank is not in a given group. 
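A minimal illustration of ``_stable_1d_sort`` defined above: the padding is an internal detail, and the returned values and indices are simply the sorted view of the input:

.. code-block:: python

    # illustrative: _stable_1d_sort returns sorted values plus the matching original indices
    import torch
    from pytorch_lightning.metrics.utils import _stable_1d_sort

    x = torch.tensor([3.0, 1.0, 2.0])
    values, indices = _stable_1d_sort(x)
    # values  -> tensor([1., 2., 3.])
    # indices -> tensor([1, 2, 0])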
- """ - if group is None: - return False - return group == GroupMember.NON_GROUP_MEMBER - - -# Taken from https://github.com/pytorch/pytorch/blob/1.7/torch/distributed/distributed_c10d.py#L1164 -def _object_to_tensor(obj): - buffer = pickle.dumps(obj) - byte_storage = torch.ByteStorage.from_buffer(buffer) # type: ignore[attr-defined] - byte_tensor = torch.ByteTensor(byte_storage) - local_size = torch.LongTensor([byte_tensor.numel()]) - return byte_tensor, local_size - - -# Taken from https://github.com/pytorch/pytorch/blob/1.7/torch/distributed/distributed_c10d.py -def _tensor_to_object(tensor, tensor_size): - buf = tensor.numpy().tobytes()[:tensor_size] - out = pickle.loads(buf) - return out - - -# Taken from https://github.com/pytorch/pytorch/blob/1.7/torch/distributed/distributed_c10d.py#L1327 -def _broadcast_object_list(object_list, src=0, group=None): - if _rank_not_in_group(group): - return - - my_rank = get_rank() - # Serialize object_list elements to tensors on src rank. - if my_rank == src: - tensor_list, size_list = zip(*[_object_to_tensor(obj) for obj in object_list]) - object_sizes_tensor = torch.cat(size_list) - else: - object_sizes_tensor = torch.LongTensor(len(object_list)) - - group_backend = get_backend(group) - is_nccl_backend = group_backend == Backend.NCCL - current_device = torch.device("cpu") - if is_nccl_backend: - # See note about using torch.cuda.current_device() here in docstring. - # We cannot simply use my_rank since rank == device is not necessarily - # true. - current_device = torch.device('cuda', torch.cuda.current_device()) - object_sizes_tensor = object_sizes_tensor.to(current_device) - object_sizes_tensor = object_sizes_tensor.to(current_device) - - # Broadcast object sizes - broadcast(object_sizes_tensor, src=src, group=group) - - # Concatenate and broadcast serialized object tensors - if my_rank == src: - object_tensor = torch.cat(tensor_list) - else: - object_tensor = torch.ByteTensor(torch.sum(object_sizes_tensor).item()) - - if is_nccl_backend: - object_tensor = object_tensor.to(current_device) - - broadcast(object_tensor, src=src, group=group) - - # Deserialize objects using their stored sizes. 
- offset = 0 - if my_rank != src: - for i, obj_size in enumerate(object_sizes_tensor): - obj_view = object_tensor[offset:offset + obj_size] - obj_view = obj_view.type(torch.ByteTensor) # type: ignore[call-overload] - offset += obj_size - object_list[i] = _tensor_to_object(obj_view, obj_size) - - -if _TORCH_GREATER_EQUAL_1_7 and torch.distributed.is_available(): - from torch.distributed.distributed_c10d import broadcast_object_list -else: - broadcast_object_list = _broadcast_object_list diff --git a/pytorch_lightning/plugins/__init__.py b/pytorch_lightning/plugins/__init__.py index a67235baa4767..dec672d025294 100644 --- a/pytorch_lightning/plugins/__init__.py +++ b/pytorch_lightning/plugins/__init__.py @@ -1,7 +1,6 @@ from pytorch_lightning.plugins.base_plugin import Plugin # noqa: F401 from pytorch_lightning.plugins.precision.apex_amp import ApexMixedPrecisionPlugin # noqa: F401 from pytorch_lightning.plugins.precision.deepspeed_precision import DeepSpeedPrecisionPlugin # noqa: F401 -from pytorch_lightning.plugins.precision.double import DoublePrecisionPlugin # noqa: F401 from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin # noqa: F401 from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin # noqa: F401 from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin # noqa: F401 @@ -30,7 +29,6 @@ "DDPSpawnPlugin", "DeepSpeedPlugin", "DeepSpeedPrecisionPlugin", - "DoublePrecisionPlugin", "HorovodPlugin", "NativeMixedPrecisionPlugin", "PrecisionPlugin", diff --git a/pytorch_lightning/plugins/precision/__init__.py b/pytorch_lightning/plugins/precision/__init__.py index d32aac829a13d..fc60deffcbb77 100644 --- a/pytorch_lightning/plugins/precision/__init__.py +++ b/pytorch_lightning/plugins/precision/__init__.py @@ -1,6 +1,5 @@ from pytorch_lightning.plugins.precision.apex_amp import ApexMixedPrecisionPlugin # noqa: F401 from pytorch_lightning.plugins.precision.deepspeed_precision import DeepSpeedPrecisionPlugin # noqa: F401 -from pytorch_lightning.plugins.precision.double import DoublePrecisionPlugin # noqa: F401 from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin # noqa: F401 from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin # noqa: F401 from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin # noqa: F401 diff --git a/pytorch_lightning/plugins/precision/apex_amp.py b/pytorch_lightning/plugins/precision/apex_amp.py index b600eca5e6bc2..75570e453ec1b 100644 --- a/pytorch_lightning/plugins/precision/apex_amp.py +++ b/pytorch_lightning/plugins/precision/apex_amp.py @@ -169,4 +169,5 @@ def pre_optimizer_step( pl_module.trainer.call_hook("on_after_backward") optimizer.step(**kwargs) + return False diff --git a/pytorch_lightning/plugins/precision/double.py b/pytorch_lightning/plugins/precision/double.py deleted file mode 100644 index 4720f0f874fd0..0000000000000 --- a/pytorch_lightning/plugins/precision/double.py +++ /dev/null @@ -1,95 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from functools import wraps -from typing import Any, Sequence, Tuple, TYPE_CHECKING, List - -import torch - -from pytorch_lightning.core.lightning import LightningModule -from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin -from pytorch_lightning.utilities.apply_func import apply_to_collection - -if TYPE_CHECKING: - from torch.nn import Module - from torch.optim import Optimizer - - -class _DoublePrecisionPatch: - """Class to handle patching of methods in the ``LightningModule`` and subsequent teardown.""" - - def __init__(self, model: 'Module', method_name: str, old_method: Any) -> None: - self.model = model - self.method_name = method_name - self.old_method = old_method - - def teardown(self) -> None: - setattr(self.model, self.method_name, self.old_method) - - @staticmethod - def _to_double_precision(data: torch.Tensor) -> torch.Tensor: - if data.is_floating_point(): - return data.double() - return data - - @staticmethod - def _move_float_tensors_to_double(collection: Any) -> Any: - return apply_to_collection( - collection, torch.Tensor, function=_DoublePrecisionPatch._to_double_precision - ) - - @classmethod - def patch(cls, model: 'Module', method_name: str) -> '_DoublePrecisionPatch': - old_method = getattr(model, method_name) - - @wraps(old_method) - def new_method(*args: Any, **kwargs: Any) -> Any: - return old_method( - *_DoublePrecisionPatch._move_float_tensors_to_double(args), - **_DoublePrecisionPatch._move_float_tensors_to_double(kwargs) - ) - - setattr(model, method_name, new_method if callable(old_method) else old_method) - return cls(model, method_name, old_method) - - -class DoublePrecisionPlugin(PrecisionPlugin): - """Plugin for training with double (``torch.float64``) precision.""" - - precision: int = 64 - - def __init__(self) -> None: - self.patches: List[_DoublePrecisionPatch] = [] - - def connect( - self, - model: 'Module', - optimizers: Sequence['Optimizer'], - lr_schedulers: Sequence[Any], - ) -> Tuple['Module', Sequence['Optimizer'], Sequence[Any]]: - """Converts the model to double precision and wraps the `training_step`, `validation_step`, `test_step`, - `predict_step`, and `forward` methods to convert incoming floating point data to double. 
Does not alter - `optimizers` or `lr_schedulers`.""" - model = model.to(dtype=torch.float64) - if isinstance(model, LightningModule): - self.patches.append(_DoublePrecisionPatch.patch(model, 'training_step')) - self.patches.append(_DoublePrecisionPatch.patch(model, 'validation_step')) - self.patches.append(_DoublePrecisionPatch.patch(model, 'test_step')) - self.patches.append(_DoublePrecisionPatch.patch(model, 'predict_step')) - self.patches.append(_DoublePrecisionPatch.patch(model, 'forward')) - - return super().connect(model, optimizers, lr_schedulers) - - def post_dispatch(self) -> None: - while len(self.patches) > 0: - self.patches.pop().teardown() diff --git a/pytorch_lightning/plugins/precision/native_amp.py b/pytorch_lightning/plugins/precision/native_amp.py index 3c83945c8a1b7..dc822680bcbda 100644 --- a/pytorch_lightning/plugins/precision/native_amp.py +++ b/pytorch_lightning/plugins/precision/native_amp.py @@ -103,21 +103,3 @@ def train_step_context(self) -> Generator[None, None, None]: """Enable autocast context""" with torch.cuda.amp.autocast(): yield - - @contextmanager - def val_step_context(self) -> Generator[None, None, None]: - """Enable autocast context""" - with torch.cuda.amp.autocast(): - yield - - @contextmanager - def test_step_context(self) -> Generator[None, None, None]: - """Enable autocast context""" - with torch.cuda.amp.autocast(): - yield - - @contextmanager - def predict_context(self) -> Generator[None, None, None]: - """Enable autocast context""" - with torch.cuda.amp.autocast(): - yield diff --git a/pytorch_lightning/plugins/precision/precision_plugin.py b/pytorch_lightning/plugins/precision/precision_plugin.py index 7172d82391bd3..2b1579cf497c0 100644 --- a/pytorch_lightning/plugins/precision/precision_plugin.py +++ b/pytorch_lightning/plugins/precision/precision_plugin.py @@ -100,6 +100,7 @@ def post_optimizer_step(self, optimizer: 'Optimizer', optimizer_idx: int) -> Non def clip_gradients(self, optimizer: 'Optimizer', clip_val: Union[int, float], norm_type: float = 2.0) -> None: """Clips the gradients to a specific value""" + # TODO: separate TPU case from here if clip_val is None: return diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index 58e26e7db32d8..f857ad50399cf 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -80,7 +80,9 @@ def distributed_sampler_kwargs(self): distributed_sampler_kwargs = dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank) return distributed_sampler_kwargs - def setup_environment(self): + def setup(self, model): + self._model = model + # start the other scripts if not self.cluster_environment.creates_children() and os.environ.get("PL_IN_DDP_SUBPROCESS", "0") != "1": self._call_children_scripts() @@ -88,8 +90,6 @@ def setup_environment(self): # set the task idx self.task_idx = self.cluster_environment.local_rank() - self.setup_distributed() - def _call_children_scripts(self): # bookkeeping of spawned processes @@ -161,34 +161,6 @@ def _call_children_scripts(self): delay = np.random.uniform(1, 5, 1)[0] sleep(delay) - def setup_distributed(self): - # TODO: check if needed - seed = os.environ.get("PL_GLOBAL_SEED") - if seed is not None: - seed_everything(int(seed)) - - # determine which process we are and world size - self.set_world_ranks() - - # set warning rank - rank_zero_only.rank = self.global_rank - - # set up server using proc 0's ip address - # try to init for 20 times 
at max in case ports are taken - # where to store ip_table - self.init_ddp_connection(self.global_rank, self.world_size) - - # on world_size=0 let everyone know training is starting - if self.is_global_zero and not torch.distributed.is_initialized(): - log.info("-" * 100) - log.info(f"distributed_backend={self.distributed_backend}") - log.info(f"All DDP processes registered. Starting ddp with {self.world_size} processes") - log.info("-" * 100) - - # set the ranks and devices - self.dist.rank = self.global_rank - self.dist.device = self.root_device - def _check_can_spawn_children(self): if self._has_spawned_children: raise RuntimeError( @@ -207,7 +179,9 @@ def pre_configure_ddp(self): # Many models require setting this parameter to True, as there are corner cases # when not all parameter backward hooks are fired by the autograd engine even if require_grad is set to True. # This flag does come with a performance hit, so it is suggested to disable in cases where it is possible. - self._ddp_kwargs["find_unused_parameters"] = self._ddp_kwargs.get("find_unused_parameters", True) + self._ddp_kwargs["find_unused_parameters"] = self._ddp_kwargs.get( + "find_unused_parameters", True + ) # todo: PyTorch 1.7.0 DDP introduces ``self.reducer._rebuild_buckets()`` breaking manual_optimization if _TORCH_GREATER_EQUAL_1_7 and not self.lightning_module.automatic_optimization and not self._ddp_kwargs.get( "find_unused_parameters", False @@ -241,6 +215,37 @@ def init_ddp_connection(self, global_rank: int, world_size: int) -> None: torch_distrib.init_process_group(self.torch_distributed_backend, rank=global_rank, world_size=world_size) def pre_dispatch(self): + # TODO: check if needed + seed = os.environ.get("PL_GLOBAL_SEED") + if seed is not None: + seed_everything(int(seed)) + + # determine which process we are and world size + self.set_world_ranks() + + # set warning rank + rank_zero_only.rank = self.global_rank + + # set up server using proc 0's ip address + # try to init for 20 times at max in case ports are taken + # where to store ip_table + self.init_ddp_connection(self.global_rank, self.world_size) + + # TODO: we moved it to the trainer.fit after calling pre_dispatch + # ... need to double check that it is the correct place + # self.trainer.call_setup_hook(self.model) + + # on world_size=0 let everyone know training is starting + if self.is_global_zero and not torch.distributed.is_initialized(): + log.info("-" * 100) + log.info(f"distributed_backend={self.distributed_backend}") + log.info(f"All DDP processes registered. 
Starting ddp with {self.world_size} processes") + log.info("-" * 100) + + # set the ranks and devices + self.dist.rank = self.global_rank + self.dist.device = self.root_device + if self.sync_batchnorm: self.model = self.configure_sync_batchnorm(self.model) @@ -298,7 +303,7 @@ def validation_step(self, *args, **kwargs): def test_step(self, *args, **kwargs): return self.model(*args, **kwargs) - def predict_step(self, *args, **kwargs): + def predict(self, *args, **kwargs): return self.model(*args, **kwargs) def post_training_step(self): diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 87d7fa5faecac..3636b2fb92fa2 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -77,6 +77,8 @@ def distributed_sampler_kwargs(self): return distributed_sampler_kwargs def setup(self, model): + self._model = model + os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port()) # pass in a state q @@ -170,7 +172,9 @@ def pre_configure_ddp(self): # Many models require setting this parameter to True, as there are corner cases # when not all parameter backward hooks are fired by the autograd engine even if require_grad is set to True. # This flag does come with a performance hit, so it is suggested to disable in cases where it is possible. - self._ddp_kwargs["find_unused_parameters"] = self._ddp_kwargs.get("find_unused_parameters", True) + self._ddp_kwargs["find_unused_parameters"] = self._ddp_kwargs.get( + "find_unused_parameters", True + ) # todo: PyTorch 1.7.0 DDP introduces ``self.reducer._rebuild_buckets()`` breaking manual_optimization if _TORCH_GREATER_EQUAL_1_7 and not self.lightning_module.automatic_optimization and not self._ddp_kwargs.get( "find_unused_parameters", False @@ -282,7 +286,7 @@ def validation_step(self, *args, **kwargs): def test_step(self, *args, **kwargs): return self.model(*args, **kwargs) - def predict_step(self, *args, **kwargs): + def predict(self, *args, **kwargs): return self.model(*args, **kwargs) def post_training_step(self): diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index b196044937414..b54155d60eae5 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -192,7 +192,17 @@ def _load_config(self, config): return config def pre_dispatch(self): + self.set_world_ranks() + self.init_ddp_connection(self.global_rank, self.world_size) + self.init_deepspeed() + + # set warning rank + rank_zero_only.rank = self.global_rank + + # set the ranks and devices + self.dist.rank = self.global_rank + self.dist.device = self.root_device self.barrier() def init_deepspeed(self): diff --git a/pytorch_lightning/plugins/training_type/dp.py b/pytorch_lightning/plugins/training_type/dp.py index a8e42e0fa747a..1d5398778c0df 100644 --- a/pytorch_lightning/plugins/training_type/dp.py +++ b/pytorch_lightning/plugins/training_type/dp.py @@ -71,8 +71,8 @@ def barrier(self, *args, **kwargs): def broadcast(self, obj: object, src: int = 0) -> object: return obj - def reduce_boolean_decision(self, decision: bool) -> bool: - return decision + def reduce_early_stopping_decision(self, should_stop: bool) -> bool: + return should_stop def training_step(self, *args, **kwargs): return self.model(*args, **kwargs) @@ -83,7 +83,7 @@ def validation_step(self, *args, **kwargs): def test_step(self, *args, **kwargs): 
return self.model(*args, **kwargs) - def predict_step(self, *args, **kwargs): + def predict(self, *args, **kwargs): return self.model(*args, **kwargs) def training_step_end(self, output): diff --git a/pytorch_lightning/plugins/training_type/horovod.py b/pytorch_lightning/plugins/training_type/horovod.py index 8d0add27cbb29..2fe3906cb01d0 100644 --- a/pytorch_lightning/plugins/training_type/horovod.py +++ b/pytorch_lightning/plugins/training_type/horovod.py @@ -21,7 +21,7 @@ from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin from pytorch_lightning.utilities import _HOROVOD_AVAILABLE -from pytorch_lightning.utilities.distributed import group, rank_zero_only, ReduceOp +from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp if _HOROVOD_AVAILABLE: import horovod.torch as hvd @@ -96,14 +96,14 @@ def start_training(self, trainer): stack.enter_context(optimizer.skip_synchronize()) # set up training routine - self._results = trainer.run_stage() + self._results = trainer.run_train() # Make sure all workers have finished training before returning to the user hvd.join() def start_evaluating(self, trainer): with ExitStack(): - self._results = trainer.run_stage() + self._results = trainer.run_evaluate() # Make sure all workers have finished training before returning to the user hvd.join() @@ -111,7 +111,7 @@ def start_evaluating(self, trainer): def start_predicting(self, trainer): with ExitStack(): # set up training routine - self._results = trainer.run_stage() + self._results = trainer.run_predict() # Make sure all workers have finished training before returning to the user hvd.join() @@ -159,13 +159,8 @@ def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ hvd.join() return hvd.allreduce(tensor, op=reduce_op) - def all_gather( - self, - result: Union[torch.Tensor], - group: Optional[Any] = group.WORLD, - sync_grads: bool = False - ) -> torch.Tensor: - if group is not None and group != group.WORLD: + def gather_all_tensors(self, result: Union[torch.Tensor], group: Optional[Any] = None): + if group is not None: raise ValueError( "Horovod does not support allgather using a subcommunicator at this time. " "Unset `group`." diff --git a/pytorch_lightning/plugins/training_type/parallel.py b/pytorch_lightning/plugins/training_type/parallel.py index d9a8e70588c43..715c5332e231c 100644 --- a/pytorch_lightning/plugins/training_type/parallel.py +++ b/pytorch_lightning/plugins/training_type/parallel.py @@ -11,10 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
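The distributed early-stopping decision below follows a vote-counting idea: every rank contributes 0 or 1, the votes are SUM-reduced, and training stops only when the sum equals the world size. A single-process sketch of that logic with made-up votes:

.. code-block:: python

    # illustrative, single-process stand-in for the SUM-reduce vote used by reduce_early_stopping_decision
    import torch

    world_size = 4
    votes = torch.tensor([1, 1, 0, 1])          # 1 = the rank wants to stop
    summed = votes.sum()                        # what an all-reduce with SUM would return
    should_stop = bool(summed == world_size)    # stop only on a unanimous vote
    assert should_stop is False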
+import io import os from abc import ABC, abstractmethod from contextlib import contextmanager -from typing import Any, List, Optional +from typing import List, Optional import torch from torch.nn.parallel import DistributedDataParallel @@ -35,10 +36,9 @@ def __init__( ): super().__init__() self.parallel_devices = parallel_devices - self.cluster_environment = cluster_environment - self.global_rank = 0 self.world_size = 1 self.local_rank = 0 + self.cluster_environment = cluster_environment @property @abstractmethod @@ -53,6 +53,14 @@ def on_gpu(self): def lightning_module(self): return unwrap_lightning_module(self._model) + @abstractmethod + def setup(self, model): + raise NotImplementedError + + def connect(self, model, *args, **kwargs): + self.setup(model) + return self.model + @property def is_global_zero(self) -> bool: return self.global_rank == 0 @@ -62,15 +70,11 @@ def distributed_sampler_kwargs(self): distributed_sampler_kwargs = dict(num_replicas=len(self.parallel_devices), rank=self.global_rank) return distributed_sampler_kwargs - def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor: - """Perform a all_gather on all processes """ - return all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads) - - def reduce_boolean_decision(self, decision: bool) -> bool: - decision = torch.tensor(int(decision), device=self.lightning_module.device) - decision = self.reduce(decision, reduce_op=ReduceOp.SUM) - decision = bool(decision == self.world_size) - return decision + def reduce_early_stopping_decision(self, should_stop: bool) -> bool: + should_stop = torch.tensor(int(should_stop), device=self.lightning_module.device) + should_stop = self.reduce(should_stop, reduce_op=ReduceOp.SUM) + should_stop = bool(should_stop == self.world_size) + return should_stop @property def torch_distributed_backend(self): @@ -108,3 +112,13 @@ def block_backward_sync(self): yield None else: yield None + + def broadcast(self, obj: object, src: int) -> object: + buffer = io.BytesIO() + torch.save(obj, buffer) + data = bytearray(buffer.getbuffer()) + data_tensor = torch.tensor(data).to(self.root_device, dtype=torch.float) + data = all_gather_ddp_if_available(data_tensor) + buffer = io.BytesIO(data.cpu().byte().numpy()) + obj = torch.load(buffer) + return obj diff --git a/pytorch_lightning/plugins/training_type/rpc.py b/pytorch_lightning/plugins/training_type/rpc.py index 3e0f57daef001..faf528d76b768 100644 --- a/pytorch_lightning/plugins/training_type/rpc.py +++ b/pytorch_lightning/plugins/training_type/rpc.py @@ -13,7 +13,7 @@ # limitations under the License. 
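The ``broadcast`` implementation above moves arbitrary picklable objects between ranks by serializing them into a byte tensor that can be all-gathered. A single-process sketch of that round trip (the object is illustrative):

.. code-block:: python

    # illustrative round trip of the serialize-to-tensor idea behind ParallelPlugin.broadcast
    import io

    import torch

    obj = {"best_model_path": "example.ckpt"}   # any picklable object
    buffer = io.BytesIO()
    torch.save(obj, buffer)
    payload = torch.tensor(bytearray(buffer.getbuffer()), dtype=torch.float)  # could be all-gathered

    restored = torch.load(io.BytesIO(payload.byte().numpy().tobytes()))
    assert restored == obj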
import os from contextlib import suppress -from typing import Callable, List, Optional +from typing import List, Optional, Callable import torch diff --git a/pytorch_lightning/plugins/training_type/rpc_sequential.py b/pytorch_lightning/plugins/training_type/rpc_sequential.py index ba26fc9f58ec5..336c16f0f1a03 100644 --- a/pytorch_lightning/plugins/training_type/rpc_sequential.py +++ b/pytorch_lightning/plugins/training_type/rpc_sequential.py @@ -13,7 +13,7 @@ # limitations under the License import logging import os -from typing import Callable, List, Optional +from typing import List, Optional, Callable import torch import torch.distributed as torch_distrib diff --git a/pytorch_lightning/plugins/training_type/single_device.py b/pytorch_lightning/plugins/training_type/single_device.py index d70779adf3ba1..d11ae87bed660 100644 --- a/pytorch_lightning/plugins/training_type/single_device.py +++ b/pytorch_lightning/plugins/training_type/single_device.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Union +from typing import Any, Union import torch @@ -23,9 +23,6 @@ class SingleDevicePlugin(TrainingTypePlugin): def __init__(self, device: torch.device): super().__init__() self.device: torch.device = device - self.global_rank = 0 - self.local_rank = 0 - self.world_size = 1 @property def on_tpu(self) -> bool: @@ -50,10 +47,6 @@ def reduce(self, tensor: Union[Any, torch.Tensor], *args: Any, **kwargs: Any) -> """ return tensor - def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor: - """Perform a all_gather on all processes """ - return tensor - @property def root_device(self) -> torch.device: return self.device @@ -64,7 +57,8 @@ def model_to_device(self) -> None: self._model.to(self.root_device) - def setup(self, model: torch.nn.Module) -> torch.nn.Module: + def connect(self, model: torch.nn.Module) -> torch.nn.Module: + self._model = model self.model_to_device() return self.model diff --git a/pytorch_lightning/plugins/training_type/single_tpu.py b/pytorch_lightning/plugins/training_type/single_tpu.py index b8d670ff16881..d3cbd0d6b5d79 100644 --- a/pytorch_lightning/plugins/training_type/single_tpu.py +++ b/pytorch_lightning/plugins/training_type/single_tpu.py @@ -39,8 +39,13 @@ def __init__(self, device: Union[torch.device, int]): def on_tpu(self) -> bool: return True + def connect(self, model: torch.nn.Module) -> torch.nn.Module: + self._model = model + self.model_to_device() + return self._model + def model_to_device(self) -> None: - self.model.to(self.root_device) + self._model.to(self.root_device) def pre_dispatch(self) -> None: if isinstance(self.device, int): diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index a8706d54cb5c9..e05a7bc03ef5c 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -17,6 +17,7 @@ from typing import Any, Dict, Iterable, List, Optional, Union import torch +import torch.distributed as torch_distrib import torch.multiprocessing as mp from pytorch_lightning.core.lightning import LightningModule @@ -52,9 +53,10 @@ def __init__( self.tpu_local_core_rank = 0 self.start_method = None - def setup(self, model: torch.nn.Module) -> torch.nn.Module: + def connect(self, model: torch.nn.Module) 
-> torch.nn.Module: self.create_mp_queue() - return self.model + self._model = model + return self._model def create_mp_queue(self): self.start_method = 'fork' @@ -108,15 +110,13 @@ def new_process(self, process_idx: int, trainer, mp_queue) -> None: # replace trainer save_checkpoint to use `xm.save` trainer.save_checkpoint = self.save_checkpoint - self.barrier("pre-run-stage") + self.barrier() results = trainer.run_stage() self.__save_end_of_training_weights(self.lightning_module) self.transfer_distrib_spawn_state_on_fit_end(results) - self.barrier("end-process") - def __save_end_of_training_weights(self, model: LightningModule) -> None: # when training ends on these platforms dump weights to get out of the main process if on_colab_kaggle(): @@ -127,11 +127,11 @@ def model_to_device(self) -> None: self._model.to(xm.xla_device()) def barrier(self, name: Optional[str] = None) -> None: - rendezvous(name) + if torch_distrib.is_initialized(): + rendezvous(f"pl.Trainer.{name}") def transfer_distrib_spawn_state_on_fit_end(self, results): - checkpoint_callback = self.lightning_module.trainer.checkpoint_callback - best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None + best_model_path = self.lightning_module.trainer.checkpoint_callback.best_model_path if self.mp_queue is not None: rank_zero_warn("cleaning up ddp environment...") @@ -203,11 +203,12 @@ def save_spawn_weights(self, model: LightningModule) -> Optional[str]: model.trainer.save_checkpoint(path) return path - def reduce_decision(self, decision: bool) -> bool: - decision = torch.tensor(int(decision), device=self.device) - decision = self.reduce(decision, "sum") - decision = bool(decision == self.world_size) - return decision + def reduce_early_stopping_decision(self, should_stop: bool) -> bool: + should_stop = torch.tensor(int(should_stop), device=self.lightning_module.device) + stop = xm.mesh_reduce('stop_signal', should_stop, sum) + rendezvous("pl.EarlyStoppingCallback.stop_distributed_training_check") + should_stop = int(stop.item()) == self.world_size + return should_stop def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None): if not isinstance(output, torch.Tensor): @@ -295,8 +296,8 @@ def validation_step(self, *args, **kwargs): def test_step(self, *args, **kwargs): return self.lightning_module.test_step(*args, **kwargs) - def predict_step(self, *args, **kwargs): - return self.lightning_module.predict_step(*args, **kwargs) + def predict(self, *args, **kwargs): + return self.lightning_module.predict(*args, **kwargs) def save_checkpoint(self, filepath, weights_only: bool = False): """Save model/training states as a checkpoint file through state-dump and file-write. diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index 08dca63a7c925..7783f066dbc61 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -33,20 +33,11 @@ class TrainingTypePlugin(Plugin, ABC): def __init__(self) -> None: self._model = None self._results = None + self.global_rank = 0 + @abstractmethod def connect(self, model: 'Module') -> None: - """Called by the accelerator to connect the accelerator and the model with this plugin""" - self.model = model - - def setup_environment(self) -> None: - """ - Setup any processes or distributed connections. 
- This is called before the LightningModule/DataModule setup hook - which allows the user to access the accelerator environment before setup is complete. - """ - - def setup(self, model: 'Module') -> None: - """Called by the accelerator to finish setup.""" + """Called by the accelerator to connect it with this plugin""" @property @abstractmethod @@ -86,13 +77,9 @@ def barrier(self, name: Optional[str] = None) -> None: def broadcast(self, obj: object, src: int = 0) -> object: """Broadcasts an object to all processes""" - @abstractmethod - def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor: - """Perform a all_gather on all processes """ - - def reduce_boolean_decision(self, decision: bool) -> bool: - """Reduce the early stopping decision across all processes""" - return decision + def reduce_early_stopping_decision(self, should_stop: bool) -> bool: + """Reduce the early stopping decision across all possibly spawned processes""" + return should_stop def pre_backward(self, closure_loss: torch.Tensor, should_accumulate: bool, optimizer: Optimizer, opt_idx: int): """Run before precision plugin executes backward""" @@ -132,15 +119,15 @@ def rpc_enabled(self) -> bool: def start_training(self, trainer: 'Trainer') -> None: # double dispatch to initiate the training loop - self._results = trainer.run_stage() + self._results = trainer.run_train() def start_evaluating(self, trainer: 'Trainer') -> None: # double dispatch to initiate the test loop - self._results = trainer.run_stage() + self._results = trainer.run_evaluate() def start_predicting(self, trainer: 'Trainer') -> None: # double dispatch to initiate the predicting loop - self._results = trainer.run_stage() + self._results = trainer.run_predict() def training_step(self, *args, **kwargs): return self.lightning_module.training_step(*args, **kwargs) @@ -154,8 +141,8 @@ def validation_step(self, *args, **kwargs): def test_step(self, *args, **kwargs): return self.lightning_module.test_step(*args, **kwargs) - def predict_step(self, *args, **kwargs): - return self.lightning_module.predict_step(*args, **kwargs) + def predict(self, *args, **kwargs): + return self.lightning_module.predict(*args, **kwargs) def training_step_end(self, output): return output @@ -182,13 +169,3 @@ def init_optimizers(self, trainer: "Trainer", model: LightningModule): def optimizer_step(self, optimizer: torch.optim.Optimizer, lambda_closure: Callable, **kwargs): optimizer.step(closure=lambda_closure, **kwargs) - - @property - def setup_optimizers_in_pre_dispatch(self) -> bool: - """ - Override to delay setting optimizers and schedulers till after dispatch. - This is useful when the `TrainingTypePlugin` requires operating on the wrapped accelerator model. - However this may break certain precision plugins such as APEX which require optimizers to be set. - Returns: If True, delay setup optimizers till pre_dispatch, else call within setup. - """ - return False diff --git a/pytorch_lightning/plugins/training_type/utils.py b/pytorch_lightning/plugins/training_type/utils.py index eddb9077116dc..7380f871f59a5 100644 --- a/pytorch_lightning/plugins/training_type/utils.py +++ b/pytorch_lightning/plugins/training_type/utils.py @@ -1,16 +1,3 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. import os diff --git a/pytorch_lightning/profiler/__init__.py b/pytorch_lightning/profiler/__init__.py index 6ac6e16c18529..e09a5ea11a084 100644 --- a/pytorch_lightning/profiler/__init__.py +++ b/pytorch_lightning/profiler/__init__.py @@ -121,8 +121,7 @@ def custom_processing_step(self, data): Autograd includes a profiler that lets you inspect the cost of different operators inside your model - both on the CPU and GPU. -To read more about the PyTorch Profiler and all its options, -have a look at its `docs `__ +Find the Pytorch Profiler doc at [PyTorch Profiler](https://pytorch-lightning.readthedocs.io/en/stable/profiler.html) .. code-block:: python @@ -135,16 +134,16 @@ def custom_processing_step(self, data): This profiler works with PyTorch ``DistributedDataParallel``. -If ``filename`` is provided, each rank will save their profiled operation to their own file. The profiler -report can be quite long, so you setting a ``filename`` will save the report instead of logging it to the -output in your terminal. If no filename is given, it will be logged only on rank 0. +If ``output_filename`` is provided, each rank will save their profiled operation to their own file. -The profiler's results will be printed on the completion of ``{fit,validate,test,predict}``. -This profiler will record ``training_step_and_backward``, ``training_step``, ``backward``, -``validation_step``, ``test_step``, and ``predict_step`` by default. -The output below shows the profiling for the action ``training_step_and_backward``. -The user can provide ``PyTorchProfiler(record_functions={...})`` to extend the scope of profiled functions. +The profiler's results will be printed on the completion of a training `fit()`. This profiler +report can be quite long, so you can also specify an `output_filename` to save the report instead +of logging it to the output in your terminal. + +This profiler will record only for `training_step_and_backward`, `evaluation_step` and `test_step` functions by default. +The output below shows the profiling for the action `training_step_and_backward`. +The user can provide ``PyTorchProfiler(profiled_functions=[...])`` to extend the scope of profiled functions. .. note:: When using the PyTorch Profiler, wall clock time will not not be representative of the true wall clock time. This is due to forcing profiled operations to be measured synchronously, when many CUDA ops happen asynchronously. It is recommended to use this Profiler to find bottlenecks/breakdowns, however for end to end wall clock time use the `SimpleProfiler`. 
# noqa E501 @@ -185,13 +184,13 @@ def custom_processing_step(self, data): To visualize the profiled operation, you can either: -Use:: +* Use:: nvvp trace_name.prof -Or:: +* Use:: - python -c 'import torch; print(torch.autograd.profiler.load_nvprof("trace_name.prof"))' + python -c 'import torch; print(torch.autograd.profiler.load_nvprof("trace_name.prof"))' """ diff --git a/pytorch_lightning/profiler/profilers.py b/pytorch_lightning/profiler/profilers.py index bc9e3541dbaa8..d704ba83236c1 100644 --- a/pytorch_lightning/profiler/profilers.py +++ b/pytorch_lightning/profiler/profilers.py @@ -21,19 +21,31 @@ from abc import ABC, abstractmethod from collections import defaultdict from contextlib import contextmanager -from pathlib import Path -from typing import Any, Callable, Dict, Optional, TextIO, Tuple, Union +from typing import Optional, Union import numpy as np -from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.cloud_io import get_filesystem log = logging.getLogger(__name__) -class AbstractProfiler(ABC): - """Specification of a profiler.""" +class BaseProfiler(ABC): + """ + If you wish to write a custom profiler, you should inhereit from this class. + """ + + def __init__(self, output_streams: Optional[Union[list, tuple]] = None): + """ + Args: + output_streams: callable + """ + if output_streams: + if not isinstance(output_streams, (list, tuple)): + output_streams = [output_streams] + else: + output_streams = [] + self.write_streams = output_streams @abstractmethod def start(self, action_name: str) -> None: @@ -43,48 +55,6 @@ def start(self, action_name: str) -> None: def stop(self, action_name: str) -> None: """Defines how to record the duration once an action is complete.""" - @abstractmethod - def summary(self) -> str: - """Create profiler summary in text format.""" - - @abstractmethod - def setup(self, **kwargs: Any) -> None: - """Execute arbitrary pre-profiling set-up steps as defined by subclass.""" - - @abstractmethod - def teardown(self, **kwargs: Any) -> None: - """Execute arbitrary post-profiling tear-down steps as defined by subclass.""" - - -class BaseProfiler(AbstractProfiler): - """ - If you wish to write a custom profiler, you should inherit from this class. - """ - - def __init__( - self, - dirpath: Optional[Union[str, Path]] = None, - filename: Optional[str] = None, - output_filename: Optional[str] = None, - ) -> None: - self.dirpath = dirpath - self.filename = filename - if output_filename is not None: - rank_zero_warn( - "`Profiler` signature has changed in v1.3. The `output_filename` parameter has been removed in" - " favor of `dirpath` and `filename`. 
Support for the old signature will be removed in v1.5", - DeprecationWarning - ) - filepath = Path(output_filename) - self.dirpath = filepath.parent - self.filename = filepath.stem - - self._output_file: Optional[TextIO] = None - self._write_stream: Optional[Callable] = None - self._local_rank: Optional[int] = None - self._log_dir: Optional[str] = None - self._stage: Optional[str] = None - @contextmanager def profile(self, action_name: str) -> None: """ @@ -116,92 +86,17 @@ def profile_iterable(self, iterable, action_name: str) -> None: self.stop(action_name) break - def _rank_zero_info(self, *args, **kwargs) -> None: - if self._local_rank in (None, 0): - log.info(*args, **kwargs) - - def _prepare_filename(self, extension: str = ".txt") -> str: - filename = "" - if self._stage is not None: - filename += f"{self._stage}-" - filename += str(self.filename) - if self._local_rank is not None: - filename += f"-{self._local_rank}" - filename += extension - return filename - - def _prepare_streams(self) -> None: - if self._write_stream is not None: - return - if self.filename: - filepath = os.path.join(self.dirpath, self._prepare_filename()) - fs = get_filesystem(filepath) - file = fs.open(filepath, "a") - self._output_file = file - self._write_stream = file.write - else: - self._write_stream = self._rank_zero_info - def describe(self) -> None: - """Logs a profile report after the conclusion of run.""" - # there are pickling issues with open file handles in Python 3.6 - # so to avoid them, we open and close the files within this function - # by calling `_prepare_streams` and `teardown` - self._prepare_streams() - self._write_stream(self.summary()) - if self._output_file is not None: - self._output_file.flush() - self.teardown(stage=self._stage) - - def _stats_to_str(self, stats: Dict[str, str]) -> str: - stage = f"{self._stage.upper()} " if self._stage is not None else "" - output = [stage + "Profiler Report"] - for action, value in stats.items(): - header = f"Profile stats for: {action}" - if self._local_rank is not None: - header += f" rank: {self._local_rank}" - output.append(header) - output.append(value) - return os.linesep.join(output) - - def setup( - self, - stage: Optional[str] = None, - local_rank: Optional[int] = None, - log_dir: Optional[str] = None, - ) -> None: - """Execute arbitrary pre-profiling set-up steps.""" - self._stage = stage - self._local_rank = local_rank - self._log_dir = log_dir - self.dirpath = self.dirpath or log_dir - - def teardown(self, stage: Optional[str] = None) -> None: - """ - Execute arbitrary post-profiling tear-down steps. - - Closes the currently open file and stream. 
- """ - self._write_stream = None - if self._output_file is not None: - self._output_file.close() - self._output_file = None # can't pickle TextIOWrapper - - def __del__(self) -> None: - self.teardown(stage=self._stage) - - def start(self, action_name: str) -> None: - raise NotImplementedError - - def stop(self, action_name: str) -> None: - raise NotImplementedError + """Logs a profile report after the conclusion of the training run.""" + for write in self.write_streams: + write(self.summary()) + @abstractmethod def summary(self) -> str: - raise NotImplementedError + """Create profiler summary in text format.""" - @property - def local_rank(self) -> int: - return 0 if self._local_rank is None else self._local_rank + def on_train_start(self, local_rank: Optional[int] = None): + self.local_rank = local_rank class PassThroughProfiler(BaseProfiler): @@ -210,6 +105,9 @@ class PassThroughProfiler(BaseProfiler): The Trainer uses this class by default. """ + def __init__(self): + super().__init__(output_streams=None) + def start(self, action_name: str) -> None: pass @@ -226,32 +124,30 @@ class SimpleProfiler(BaseProfiler): the mean duration of each action and the total time spent over the entire training run. """ - def __init__( - self, - dirpath: Optional[Union[str, Path]] = None, - filename: Optional[str] = None, - extended: bool = True, - output_filename: Optional[str] = None, - ) -> None: + def __init__(self, output_filename: Optional[str] = None, extended=True): """ Args: - dirpath: Directory path for the ``filename``. If ``dirpath`` is ``None`` but ``filename`` is present, the - ``trainer.log_dir`` (from :class:`~pytorch_lightning.loggers.tensorboard.TensorBoardLogger`) - will be used. - - filename: If present, filename where the profiler results will be saved instead of printing to stdout. - The ``.txt`` extension will be used automatically. + output_filename: optionally save profile results to file instead of printing + to std out when training is finished. Raises: ValueError: If you attempt to start an action which has already started, or if you attempt to stop recording an action which was never started. """ - super().__init__(dirpath=dirpath, filename=filename, output_filename=output_filename) - self.current_actions: Dict[str, float] = {} + self.current_actions = {} self.recorded_durations = defaultdict(list) self.extended = extended + + self.output_fname = output_filename + self.output_file = None + if self.output_fname: + fs = get_filesystem(self.output_fname) + self.output_file = fs.open(self.output_fname, "w") + + streaming_out = [self.output_file.write] if self.output_file else [log.info] self.start_time = time.monotonic() + super().__init__(output_streams=streaming_out) def start(self, action_name: str) -> None: if action_name in self.current_actions: @@ -266,18 +162,14 @@ def stop(self, action_name: str) -> None: duration = end_time - start_time self.recorded_durations[action_name].append(duration) - def _make_report(self) -> Tuple[list, float]: + def make_report(self): total_duration = time.monotonic() - self.start_time report = [[a, d, 100. 
* np.sum(d) / total_duration] for a, d in self.recorded_durations.items()] report.sort(key=lambda x: x[2], reverse=True) return report, total_duration def summary(self) -> str: - sep = os.linesep - output_string = "" - if self._stage is not None: - output_string += f"{self._stage.upper()} " - output_string += f"Profiler Report{sep}" + output_string = "\n\nProfiler Report\n" if self.extended: @@ -285,16 +177,16 @@ def summary(self) -> str: max_key = np.max([len(k) for k in self.recorded_durations.keys()]) def log_row(action, mean, num_calls, total, per): - row = f"{sep}{action:<{max_key}s}\t| {mean:<15}\t|" + row = f"{os.linesep}{action:<{max_key}s}\t| {mean:<15}\t|" row += f"{num_calls:<15}\t| {total:<15}\t| {per:<15}\t|" return row output_string += log_row("Action", "Mean duration (s)", "Num calls", "Total time (s)", "Percentage %") output_string_len = len(output_string) - output_string += f"{sep}{'-' * output_string_len}" - report, total_duration = self._make_report() + output_string += f"{os.linesep}{'-' * output_string_len}" + report, total_duration = self.make_report() output_string += log_row("Total", "-", "_", f"{total_duration:.5}", "100 %") - output_string += f"{sep}{'-' * output_string_len}" + output_string += f"{os.linesep}{'-' * output_string_len}" for action, durations, duration_per in report: output_string += log_row( action, @@ -306,16 +198,27 @@ def log_row(action, mean, num_calls, total, per): else: def log_row(action, mean, total): - return f"{sep}{action:<20s}\t| {mean:<15}\t| {total:<15}" + return f"{os.linesep}{action:<20s}\t| {mean:<15}\t| {total:<15}" output_string += log_row("Action", "Mean duration (s)", "Total time (s)") - output_string += f"{sep}{'-' * 65}" + output_string += f"{os.linesep}{'-' * 65}" for action, durations in self.recorded_durations.items(): output_string += log_row(action, f"{np.mean(durations):.5}", f"{np.sum(durations):.5}") - output_string += sep + output_string += os.linesep return output_string + def describe(self): + """Logs a profile report after the conclusion of the training run.""" + super().describe() + if self.output_file: + self.output_file.flush() + + def __del__(self): + """Close profiler's stream.""" + if self.output_file: + self.output_file.close() + class AdvancedProfiler(BaseProfiler): """ @@ -324,22 +227,11 @@ class AdvancedProfiler(BaseProfiler): verbose and you should only use this if you want very detailed reports. """ - def __init__( - self, - dirpath: Optional[Union[str, Path]] = None, - filename: Optional[str] = None, - line_count_restriction: float = 1.0, - output_filename: Optional[str] = None, - ) -> None: + def __init__(self, output_filename: Optional[str] = None, line_count_restriction: float = 1.0): """ Args: - dirpath: Directory path for the ``filename``. If ``dirpath`` is ``None`` but ``filename`` is present, the - ``trainer.log_dir`` (from :class:`~pytorch_lightning.loggers.tensorboard.TensorBoardLogger`) - will be used. - - filename: If present, filename where the profiler results will be saved instead of printing to stdout. - The ``.txt`` extension will be used automatically. - + output_filename: optionally save profile results to file instead of printing + to std out when training is finished. line_count_restriction: this can be used to limit the number of functions reported for each action. 
either an integer (to select a count of lines), or a decimal fraction between 0.0 and 1.0 inclusive (to select a percentage of lines) @@ -348,10 +240,18 @@ def __init__( ValueError: If you attempt to stop recording an action which was never started. """ - super().__init__(dirpath=dirpath, filename=filename, output_filename=output_filename) - self.profiled_actions: Dict[str, cProfile.Profile] = {} + self.profiled_actions = {} self.line_count_restriction = line_count_restriction + self.output_fname = output_filename + self.output_file = None + if self.output_fname: + fs = get_filesystem(self.output_fname) + self.output_file = fs.open(self.output_fname, "w") + + streaming_out = [self.output_file.write] if self.output_file else [log.info] + super().__init__(output_streams=streaming_out) + def start(self, action_name: str) -> None: if action_name not in self.profiled_actions: self.profiled_actions[action_name] = cProfile.Profile() @@ -360,7 +260,9 @@ def start(self, action_name: str) -> None: def stop(self, action_name: str) -> None: pr = self.profiled_actions.get(action_name) if pr is None: - raise ValueError(f"Attempting to stop recording an action ({action_name}) which was never started.") + raise ValueError( # pragma: no-cover + f"Attempting to stop recording an action ({action_name}) which was never started." + ) pr.disable() def summary(self) -> str: @@ -370,16 +272,21 @@ def summary(self) -> str: ps = pstats.Stats(pr, stream=s).strip_dirs().sort_stats('cumulative') ps.print_stats(self.line_count_restriction) recorded_stats[action_name] = s.getvalue() - return self._stats_to_str(recorded_stats) - def teardown(self, stage: Optional[str] = None) -> None: - super().teardown(stage=stage) - self.profiled_actions = {} + # log to standard out + output_string = f"{os.linesep}Profiler Report{os.linesep}" + for action, stats in recorded_stats.items(): + output_string += f"{os.linesep}Profile stats for: {action}{os.linesep}{stats}" + + return output_string + + def describe(self): + """Logs a profile report after the conclusion of the training run.""" + super().describe() + if self.output_file: + self.output_file.flush() - def __reduce__(self): - # avoids `TypeError: cannot pickle 'cProfile.Profile' object` - return ( - self.__class__, - tuple(), - dict(dirpath=self.dirpath, filename=self.filename, line_count_restriction=self.line_count_restriction), - ) + def __del__(self): + """Close profiler's stream.""" + if self.output_file: + self.output_file.close() diff --git a/pytorch_lightning/profiler/pytorch.py b/pytorch_lightning/profiler/pytorch.py index fa2c2917f98a2..88a33a3d367f8 100644 --- a/pytorch_lightning/profiler/pytorch.py +++ b/pytorch_lightning/profiler/pytorch.py @@ -12,197 +12,27 @@ # See the License for the specific language governing permissions and # limitations under the License. 
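# ---------------------------------------------------------------------------
# Illustrative usage sketch (assumption: the `output_filename`-style
# constructors defined above; the file names and the `model` object are
# placeholders). It shows how a profiler is handed to the Trainer.
from pytorch_lightning import Trainer
from pytorch_lightning.profiler import AdvancedProfiler, SimpleProfiler

# SimpleProfiler reports the mean/total duration per action and streams the
# report to the given file when training finishes.
profiler = SimpleProfiler(output_filename="profile.txt")

# AdvancedProfiler wraps cProfile; `line_count_restriction` limits how many
# lines are reported per action (an int count or a 0.0-1.0 fraction).
# profiler = AdvancedProfiler(output_filename="profile_advanced.txt", line_count_restriction=20)

trainer = Trainer(profiler=profiler, max_epochs=1)
# trainer.fit(model)  # `model` is any LightningModule; the report is written on completion
# ---------------------------------------------------------------------------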
"""Profiler to check if there are any bottlenecks in your code.""" + import inspect import logging import os -from functools import partial -from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Set, Type, TYPE_CHECKING, Union +from typing import List, Optional import torch -from torch import nn, Tensor -from torch.autograd.profiler import record_function from pytorch_lightning.profiler.profilers import BaseProfiler +from pytorch_lightning.utilities import rank_zero_only +from pytorch_lightning.utilities.cloud_io import get_filesystem from pytorch_lightning.utilities.distributed import rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.imports import _KINETO_AVAILABLE - -if TYPE_CHECKING: - from torch.autograd.profiler import EventList - from torch.utils.hooks import RemovableHandle - - from pytorch_lightning.core.lightning import LightningModule - -if _KINETO_AVAILABLE: - from torch.profiler import ProfilerAction, ProfilerActivity, tensorboard_trace_handler log = logging.getLogger(__name__) -_PROFILER = Union[torch.autograd.profiler.profile, torch.cuda.profiler.profile, torch.autograd.profiler.emit_nvtx] - - -class RegisterRecordFunction: - """ - While profiling autograd operations, this class will add labels for module names around the forward function. - - The Lightning PyTorch Profiler will activate this feature automatically. It can be deactivated as follows: - - Example:: - from pytorch_lightning.profilers import PyTorchProfiler - profiler = PyTorchProfiler(record_module_names=False) - Trainer(profiler=profiler) - - It can be used outside of Lightning as follows: - - Example:: - from pytorch_lightning import Trainer, seed_everything - with RegisterRecordFunction(model): - out = model(batch) - """ - - def __init__(self, model: nn.Module) -> None: - self._model = model - self._records: Dict[str, record_function] = {} - self._handles: Dict[str, List['RemovableHandle']] = {} - - def _start_recording_forward(self, _: nn.Module, input: Tensor, record_name: str) -> Tensor: - record = record_function(record_name) - record.__enter__() - self._records[record_name] = record - return input - - def _stop_recording_forward(self, _: nn.Module, __: Tensor, output: Tensor, record_name: str) -> Tensor: - self._records[record_name].__exit__(None, None, None) - return output - - def __enter__(self) -> None: - for module_name, module in self._model.named_modules(): - if module_name: - full_name = f"{type(module).__module__}.{type(module).__name__}" - record_name = f"{full_name}: {module_name}" - pre_forward_handle = module.register_forward_pre_hook( - partial(self._start_recording_forward, record_name=record_name) - ) - post_forward_handle = module.register_forward_hook( - partial(self._stop_recording_forward, record_name=record_name) - ) - - self._handles[module_name] = [pre_forward_handle, post_forward_handle] - - def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: - for handles in self._handles.values(): - for h in handles: - h.remove() - self._handles = {} - - -class ScheduleWrapper: - """ - This class is used to override the schedule logic from the profiler and perform - recording for both `training_step`, `validation_step`. 
- """ - - def __init__(self, schedule: Callable) -> None: - if not _KINETO_AVAILABLE: - raise ModuleNotFoundError("You are trying to use `ScheduleWrapper` which require kineto install.") - self._schedule = schedule - self.reset() - - def setup(self, start_action_name: str) -> None: - self._start_action_name = start_action_name - - def pre_step(self, current_action: str) -> None: - self._current_action = current_action - - def reset(self): - self._num_training_step_and_backward = 0 - self._num_validation_step = 0 - self._num_test_step = 0 - self._num_predict_step = 0 - self._training_step_and_backward_reached_end = False - self._validation_step_reached_end = False - self._test_step_reached_end = False - self._predict_step_reached_end = False - # used to stop profiler when `ProfilerAction.RECORD_AND_SAVE` is reached. - self._current_action: Optional[str] = None - self._start_action_name: Optional[str] = None - - @property - def num_step(self) -> int: - if self._current_action == "training_step_and_backward": - return self._num_training_step_and_backward - elif self._current_action == "validation_step": - return self._num_validation_step - elif self._current_action == "test_step": - return self._num_test_step - elif self._current_action == "predict_step": - return self._num_predict_step - else: - return 0 - - def _step(self) -> None: - if self._current_action == "training_step_and_backward": - self._num_training_step_and_backward += 1 - elif self._current_action == "validation_step": - if self._start_action_name == "on_fit_start": - if self._num_training_step_and_backward > 0: - self._num_validation_step += 1 - else: - self._num_validation_step += 1 - elif self._current_action == "test_step": - self._num_test_step += 1 - elif self._current_action == "predict_step": - self._num_predict_step += 1 - - @property - def has_finished(self) -> bool: - if self._current_action == "training_step_and_backward": - return self._training_step_and_backward_reached_end - elif self._current_action == "validation_step": - return self._validation_step_reached_end - elif self._current_action == "test_step": - return self._test_step_reached_end - elif self._current_action == "predict_step": - return self._predict_step_reached_end - return False - - def __call__(self, num_step: int) -> 'ProfilerAction': - # ignore the provided input. Keep internal state instead. 
- if self.has_finished: - return ProfilerAction.NONE - - self._step() - action = self._schedule(self.num_step) - if action == ProfilerAction.RECORD_AND_SAVE: - if self._current_action == "training_step_and_backward": - self._training_step_and_backward_reached_end = True - elif self._current_action == "validation_step": - self._validation_step_reached_end = True - elif self._current_action == "test_step": - self._test_step_reached_end = True - elif self._current_action == "predict_step": - self._predict_step_reached_end = True - return action - class PyTorchProfiler(BaseProfiler): - RECORD_FUNCTIONS = { - "training_step_and_backward", - "training_step", - "backward", - "validation_step", - "test_step", - "predict_step", - } - STEP_FUNCTIONS = { - "training_step_and_backward", - "validation_step", - "test_step", - "predict_step", - } - AVAILABLE_SORT_KEYS = { + PROFILED_FUNCTIONS = ("training_step_and_backward", "validation_step", "test_step") + AVAILABLE_SORT_KEYS = ( "cpu_time", "cuda_time", "cpu_time_total", @@ -212,43 +42,56 @@ class PyTorchProfiler(BaseProfiler): "self_cpu_memory_usage", "self_cuda_memory_usage", "count", - } - START_RECORD_FUNCTIONS = { - 'on_fit_start', - 'on_validation_start', - 'on_test_start', - 'on_predict_start', - } + ) def __init__( self, - dirpath: Optional[Union[str, Path]] = None, - filename: Optional[str] = None, + output_filename: Optional[str] = None, + enabled: bool = True, + use_cuda: bool = False, + record_shapes: bool = False, + profile_memory: bool = False, group_by_input_shapes: bool = False, + with_stack: bool = False, + use_kineto: bool = False, + use_cpu: bool = True, emit_nvtx: bool = False, - export_to_chrome: bool = True, + export_to_chrome: bool = False, + path_to_export_trace: str = None, row_limit: int = 20, sort_by_key: Optional[str] = None, - record_functions: Set[str] = None, - record_module_names: bool = True, profiled_functions: Optional[List] = None, - output_filename: Optional[str] = None, - **profiler_kwargs: Any, - ) -> None: + local_rank: Optional[int] = None, + ): """ This profiler uses PyTorch's Autograd Profiler and lets you inspect the cost of different operators inside your model - both on the CPU and GPU Args: - dirpath: Directory path for the ``filename``. If ``dirpath`` is ``None`` but ``filename`` is present, the - ``trainer.log_dir`` (from :class:`~pytorch_lightning.loggers.tensorboard.TensorBoardLogger`) - will be used. - filename: If present, filename where the profiler results will be saved instead of printing to stdout. - The ``.txt`` extension will be used automatically. + output_filename: optionally save profile results to file instead of printing + to std out when training is finished. When using ``ddp``, + each rank will stream the profiled operation to their own file + with the extension ``_{rank}.txt`` + + enabled: Setting this to False makes this context manager a no-op. + + use_cuda: Enables timing of CUDA events as well using the cudaEvent API. + Adds approximately 4us of overhead to each tensor operation. + + record_shapes: If shapes recording is set, information about input dimensions will be collected. + + profile_memory: Whether to report memory usage, default: True (Introduced in PyTorch 1.6.0) group_by_input_shapes: Include operator input shapes and group calls by shape. 
+ with_stack: record source information (file and line number) for the ops (Introduced in PyTorch 1.7.0) + + use_kineto: experimental support for Kineto profiler (Introduced in PyTorch 1.8.0) + + use_cpu: use_kineto=True and can be used to lower the overhead + for GPU-only profiling (Introduced in PyTorch 1.8.0) + emit_nvtx: Context manager that makes every autograd operation emit an NVTX range Run:: @@ -259,254 +102,202 @@ def __init__( nvvp trace_name.prof torch.autograd.profiler.load_nvprof(path) - export_to_chrome: Whether to export the sequence of profiled operators for Chrome. + export_to_chrome: Wether to export the sequence of profiled operators for Chrome. It will generate a ``.json`` file which can be read by Chrome. - row_limit: Limit the number of rows in a table, ``-1`` is a special value that + path_to_export_trace: Directory path to export ``.json`` traces when using ``export_to_chrome=True``. + By default, it will be save where the file being is being run. + + row_limit: Limit the number of rows in a table, `0` is a special value that removes the limit completely. - sort_by_key: Attribute used to sort entries. By default - they are printed in the same order as they were registered. - Valid keys include: ``cpu_time``, ``cuda_time``, ``cpu_time_total``, - ``cuda_time_total``, ``cpu_memory_usage``, ``cuda_memory_usage``, - ``self_cpu_memory_usage``, ``self_cuda_memory_usage``, ``count``. + sort_by_key: Keys to sort out profiled table - record_functions: Set of profiled functions which will create a context manager on. + profiled_functions: list of profiled functions which will create a context manager on. Any other will be pass through. - record_module_names: Whether to add module names while recording autograd operation. - - profiler_kwargs: Keyword arguments for the PyTorch profiler. This depends on your PyTorch version + local_rank: When running in distributed setting, local_rank is used for each process + to write to their own file if `output_fname` is provided. Raises: MisconfigurationException: - If arg ``sort_by_key`` is not present in ``AVAILABLE_SORT_KEYS``. - If arg ``schedule`` is not a ``Callable``. - If arg ``schedule`` does not return a ``torch.profiler.ProfilerAction``. + If arg ``sort_by_key`` is not present in ``AVAILABLE_SORT_KEYS``, or + if log file is not a ``.txt`` file. + ValueError: + If you attempt to stop recording an action which was never started. 
""" - super().__init__(dirpath=dirpath, filename=filename, output_filename=output_filename) - - record_functions = self.__deprecation_check(profiled_functions, record_functions) - - self._group_by_input_shapes = group_by_input_shapes and profiler_kwargs.get("record_shapes", False) - self._emit_nvtx = emit_nvtx - self._export_to_chrome = export_to_chrome - self._row_limit = row_limit - self._sort_by_key = sort_by_key or f"{'cuda' if profiler_kwargs.get('use_cuda', False) else 'cpu'}_time_total" - self._user_record_functions = record_functions - self._record_functions_start = self._user_record_functions | self.START_RECORD_FUNCTIONS - self._record_functions = self._user_record_functions | self.RECORD_FUNCTIONS - self._record_module_names = record_module_names - self._profiler_kwargs = profiler_kwargs - - self.profiler: Optional[_PROFILER] = None - self.function_events: Optional['EventList'] = None - self._lightning_module: Optional['LightningModule'] = None # set by ProfilerConnector - self._register: Optional[RegisterRecordFunction] = None - self._parent_profiler: Optional[_PROFILER] = None - self._recording_map: Dict[str, record_function] = {} - self._start_action_name: Optional[str] = None - self._schedule: Optional[ScheduleWrapper] = None - - if _KINETO_AVAILABLE: - self._init_kineto(profiler_kwargs) - - if self._sort_by_key not in self.AVAILABLE_SORT_KEYS: - raise MisconfigurationException( - f"Found sort_by_key: {self._sort_by_key}. Should be within {self.AVAILABLE_SORT_KEYS}. " - ) - - def _init_kineto(self, profiler_kwargs: Any) -> None: - has_schedule = "schedule" in profiler_kwargs - self._has_on_trace_ready = "on_trace_ready" in profiler_kwargs - - schedule = profiler_kwargs.get("schedule", None) - if schedule is not None: - if not isinstance(schedule, Callable): - raise MisconfigurationException(f"Schedule should be a callable. Found: {schedule}") - action = schedule(0) - if not isinstance(action, ProfilerAction): - raise MisconfigurationException( - f"Schedule should return a `torch.profiler.ProfilerAction`. 
Found: {action}" - ) - schedule = schedule if has_schedule else self._default_schedule() - self._schedule = ScheduleWrapper(schedule) if schedule is not None else schedule - self._profiler_kwargs["schedule"] = self._schedule - - activities = profiler_kwargs.get("activities", None) - self._profiler_kwargs["activities"] = activities or self._default_activities() - self._export_to_flame_graph = profiler_kwargs.get("export_to_flame_graph", False) - self._metric = profiler_kwargs.get("metric", "self_cpu_time_total") - with_stack = profiler_kwargs.get("with_stack", False) or self._export_to_flame_graph - self._profiler_kwargs["with_stack"] = with_stack - - def __deprecation_check( - self, - profiled_functions: Optional[List[str]], - record_functions: Optional[Set[str]], - ) -> Set[str]: - if record_functions is None: - record_functions = set() - if profiled_functions is not None: + self.profiled_actions = {} + self.enabled = enabled + self.profiled_functions = profiled_functions or self.PROFILED_FUNCTIONS + self.use_cuda = use_cuda + self.record_shapes = record_shapes + self.profile_memory = profile_memory + self.sort_by_key = sort_by_key or ("cuda_time_total" if self.use_cuda else "cpu_time_total") + self.with_stack = with_stack + self.group_by_input_shapes = group_by_input_shapes and record_shapes + self.use_kineto = use_kineto + self.use_cpu = use_cpu + self.row_limit = row_limit + self.emit_nvtx = emit_nvtx + self.export_to_chrome = export_to_chrome + self.path_to_export_trace = path_to_export_trace + + if export_to_chrome and path_to_export_trace is None: rank_zero_warn( - "`PyTorchProfiler.profiled_functions` has been renamed to" - " `record_functions` in v1.3 and will be removed in v1.5", DeprecationWarning + "The exported trace would be save locally as `path_to_export_trace` is empty." + " Note: Each functions will generate its own traced file." ) - if not record_functions: - record_functions |= set(profiled_functions) - else: - raise MisconfigurationException( - "You set `PytorchProfiler.profiled_functions` and `PyTorchProfiler.record_functions`." - " Please use only the later." - ) - - return record_functions - - @staticmethod - def _default_schedule() -> Optional[callable]: - if _KINETO_AVAILABLE: - # Those schedule defaults allow the profiling overhead to be negligible over training time. - return torch.profiler.schedule(wait=1, warmup=1, active=3) - - def _default_activities(self) -> List['ProfilerActivity']: - activities = [] - if not _KINETO_AVAILABLE: - return activities - if self._profiler_kwargs.get("use_cpu", True): - activities.append(ProfilerActivity.CPU) - if self._profiler_kwargs.get("use_cuda", torch.cuda.is_available()): - activities.append(ProfilerActivity.CUDA) - return activities - def start(self, action_name: str) -> None: - if self.profiler is None and action_name in self._record_functions_start: - - # close profiler if it is already opened. might happen if 2 profilers - # are created and the first one did not call `describe` - try: - torch.autograd._disable_profiler() # noqa - except (AttributeError, RuntimeError): - pass - - if self._schedule is not None: - self._schedule.setup(action_name) + if self.sort_by_key not in self.AVAILABLE_SORT_KEYS: + raise MisconfigurationException( + f"Found sort_by_key: {sort_by_key}. Should be within {self.AVAILABLE_SORT_KEYS}. 
" + ) - self._create_profilers() + self.profiled_actions = {} + self.context_names = {} + self.running_stack = [] + self.profiler = None - profiler = self.profiler.__enter__() - if profiler is not None: - self.profiler = profiler + self.output_fname = output_filename + self.output_file = None + if local_rank is not None: + self.on_train_start(local_rank=local_rank) + self.on_train_start = super().on_train_start - if self._parent_profiler is not None: - self._parent_profiler.__enter__() + def on_train_start(self, local_rank: Optional[str] = None): + self.local_rank = local_rank - if self._register is not None: - self._register.__enter__() + # when logging to `log.info`, only perform profiling on rank 0 + if local_rank != 0 and self.output_fname is None: + self.wrap_functions_into_rank_zero_only() - if ( - self.profiler is not None and action_name in self._record_functions - and action_name not in self._recording_map - ): - recording = record_function(action_name) - recording.__enter__() - self._recording_map[action_name] = recording + if self.output_fname: + if local_rank is not None: + if '.txt' not in self.output_fname: + raise MisconfigurationException("Log file should be .txt file.") - def stop(self, action_name: str) -> None: - if action_name in self._recording_map: - self._recording_map[action_name].__exit__(None, None, None) - del self._recording_map[action_name] + self.output_fname = self.output_fname.replace(".txt", f"_{self.local_rank}.txt") - if not _KINETO_AVAILABLE or self._emit_nvtx: - return + fs = get_filesystem(self.output_fname) + self.output_file = fs.open(self.output_fname, "w") - if self.profiler is not None and action_name in self.STEP_FUNCTIONS: - if self._schedule is not None: - self._schedule.pre_step(action_name) + streaming_out = [self.output_file.write] if self.output_file else [log.info] + super().__init__(output_streams=streaming_out) - def on_trace_ready(profiler): - if self.dirpath is not None: - if self._export_to_chrome: - handler = tensorboard_trace_handler(self.dirpath, self._prepare_filename(extension="")) - handler(profiler) + def wrap_functions_into_rank_zero_only(self): + self.start = rank_zero_only(self.start) + self.stop = rank_zero_only(self.stop) + self.summary = rank_zero_only(self.summary) + self.describe = rank_zero_only(self.describe) - if self._export_to_flame_graph: - path = os.path.join(self.dirpath, self._prepare_filename(extension=".stack")) - profiler.export_stacks(path, metric=self._metric) - else: - rank_zero_warn("The PyTorchProfiler failed to export trace as `dirpath` is None") + def start(self, action_name: str) -> None: + if action_name not in self.profiled_functions: + return - if not self._has_on_trace_ready: - self.profiler.on_trace_ready = on_trace_ready + if len(self.running_stack) > 0: + self._stop(self.running_stack[-1]) + self.running_stack.append(action_name) - if self._schedule is not None: - self.profiler.step_num = self._schedule.num_step - self.profiler.step() + self.context_names[action_name] = "/".join(self.running_stack) - def summary(self) -> str: - if not self._profiler_kwargs.get("enabled", True) or self._emit_nvtx: - return "" + self._start(action_name) - self._delete_profilers() + def _start(self, action_name: str) -> None: + if self.emit_nvtx: + self._parent_profiler = self._create_profiler(action_name, torch.cuda.profiler.profile, enter=True) + self._create_profiler(action_name, torch.autograd.profiler.emit_nvtx) + else: + self._create_profiler(action_name, torch.autograd.profiler.profile) + + def 
_create_profiler(self, action_name, profiler, enter=True): + init_args = inspect.signature(profiler.__init__).parameters + profiler_args = {k: v for k, v in vars(self).items() if k in init_args} + pr = profiler(**profiler_args) + if enter: + out_pr = pr.__enter__() + if out_pr is not None: + pr = out_pr + self.profiler = pr + return self.profiler + + def _stop(self, action_name: str) -> None: + if self.profiler is None: + return - if not self.function_events: - return "" + self.profiler.__exit__(exc_type=None, exc_val=None, exc_tb=None) - if self._export_to_chrome and not _KINETO_AVAILABLE: - filename = f"{self.local_rank}_trace.json" - path_to_trace = (filename if self.dirpath is None else os.path.join(self.dirpath, filename)) - self.function_events.export_chrome_trace(path_to_trace) + if isinstance(self.profiler, torch.autograd.profiler.emit_nvtx): + # when running ``emit_nvtx``, PyTorch requires 2 context manager. + # The parent_profiler is being closed too. + self._parent_profiler.__exit__(None, None, None) + return - data = self.function_events.key_averages(group_by_input_shapes=self._group_by_input_shapes) - table = data.table(sort_by=self._sort_by_key, row_limit=self._row_limit) + function_events = self.profiler.function_events + self.profiler = None + for name in self.running_stack: + if name not in self.profiled_actions: + self.profiled_actions[name] = function_events + else: + self.profiled_actions[name] += function_events - recorded_stats = {"records": table} - return self._stats_to_str(recorded_stats) + def stop(self, action_name: str) -> None: + if action_name not in self.profiled_functions: + return - def _create_profilers(self) -> None: - if self._emit_nvtx: - self._parent_profiler = self._create_profiler(torch.cuda.profiler.profile) - self.profiler = self._create_profiler(torch.autograd.profiler.emit_nvtx) - else: - self._parent_profiler = None - self.profiler = self._create_profiler( - torch.profiler.profile if _KINETO_AVAILABLE else torch.autograd.profiler.profile + if len(self.running_stack) == 0 or self.running_stack[-1] != action_name: + raise ValueError( # pragma: no-cover + f"Attempting to stop recording an action ({action_name}) which was never started." 
) - if self._record_module_names and self._lightning_module is not None: - self._register = RegisterRecordFunction(self._lightning_module) - - def _create_profiler(self, profiler: Type[_PROFILER]) -> _PROFILER: - init_parameters = inspect.signature(profiler.__init__).parameters - kwargs = {k: v for k, v in self._profiler_kwargs.items() if k in init_parameters} - return profiler(**kwargs) + self._stop(action_name) + self.running_stack.pop() + # restore running profiler + if len(self.running_stack) > 0: + self._start(self.running_stack[-1]) - def _cache_functions_events(self) -> None: - if self._emit_nvtx: - return - self.function_events = self.profiler.events() if _KINETO_AVAILABLE else self.profiler.function_events + def summary(self) -> str: + recorded_stats = {} + output_string = '' + local_rank = '0' if self.local_rank is None else self.local_rank - def _delete_profilers(self) -> None: - if self.profiler is not None: - self.profiler.__exit__(None, None, None) - self._cache_functions_events() - self.profiler = None + if not self.enabled: + return output_string - if self._schedule is not None: - self._schedule.reset() + for action_name, function_events in self.profiled_actions.items(): - if self._parent_profiler is not None: - self._parent_profiler.__exit__(None, None, None) - self._parent_profiler = None + # next line is a workaround for a pytorch issue (fixed on master, still present + # on 1.7). Without it the code fails with `AssertionError: There is already a CPU + # parent event for detach` + function_events.populate_cpu_children = lambda: None - if self._register is not None: - self._register.__exit__(None, None, None) - self._register = None + if self.export_to_chrome: + filename = f"{action_name}_{local_rank}_trace.json" + path_to_trace = filename if self.path_to_export_trace is None \ + else os.path.join(self.path_to_export_trace, filename) + function_events.export_chrome_trace(path_to_trace) - def teardown(self, stage: Optional[str] = None) -> None: - self._delete_profilers() + if self.emit_nvtx: + return output_string - for k in self._recording_map: - self.stop(k) - self._recording_map = {} - - super().teardown(stage=stage) + else: + data = function_events.key_averages(group_by_input_shapes=self.group_by_input_shapes) + table = data.table(sort_by=self.sort_by_key, row_limit=self.row_limit) + recorded_stats[action_name] = table + + # log to standard out + output_string = f"{os.linesep}Profiler Report{os.linesep}" + for action, stats in recorded_stats.items(): + output_string += (f"{os.linesep}Profile stats for: {action} rank: {local_rank} {os.linesep}{stats}") + + return output_string + + def describe(self): + """Logs a profile report after the conclusion of the training run.""" + super().describe() + if self.output_file: + self.output_file.flush() + + def __del__(self): + """Close profiler's stream.""" + if self.output_file: + self.output_file.close() diff --git a/pytorch_lightning/setup_tools.py b/pytorch_lightning/setup_tools.py index 3362ccb479895..f5aed2608635e 100644 --- a/pytorch_lightning/setup_tools.py +++ b/pytorch_lightning/setup_tools.py @@ -16,7 +16,7 @@ import re from typing import List -_PROJECT_ROOT = os.path.dirname(os.path.dirname(__file__)) +from pytorch_lightning import __homepage__, __version__, _PROJECT_ROOT def _load_requirements(path_dir: str, file_name: str = 'requirements.txt', comment_char: str = '#') -> List[str]: @@ -40,10 +40,10 @@ def _load_requirements(path_dir: str, file_name: str = 'requirements.txt', comme return reqs -def 
_load_readme_description(path_dir: str, homepage: str, version: str) -> str:
+def _load_readme_description(path_dir: str, homepage: str = __homepage__, version: str = __version__) -> str:
     """Load readme as description

-    >>> _load_readme_description(_PROJECT_ROOT, "", "") # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
+    >>> _load_readme_description(_PROJECT_ROOT) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
     '
...' """ path_readme = os.path.join(path_dir, "README.md") diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index 6d434e12a2e78..5aa9f1a44276b 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -15,15 +15,11 @@ from abc import ABC from copy import deepcopy from inspect import signature -from typing import Any, Callable, Dict, List, Optional, Type +from typing import Any, Callable, Dict, List, Type, Optional from pytorch_lightning.callbacks import Callback from pytorch_lightning.core.lightning import LightningModule -from pytorch_lightning.utilities import rank_zero_deprecation -from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature -from pytorch_lightning.utilities.warnings import WarningCache - -warning_cache = WarningCache() +from pytorch_lightning.utilities import rank_zero_warn class TrainerCallbackHookMixin(ABC): @@ -83,12 +79,8 @@ def on_train_epoch_start(self): for callback in self.callbacks: callback.on_train_epoch_start(self, self.lightning_module) - def on_train_epoch_end(self, outputs: List[Any]): - """Called when the epoch ends. - - Args: - outputs: List of outputs on each ``train`` epoch - """ + def on_train_epoch_end(self, outputs): + """Called when the epoch ends.""" for callback in self.callbacks: callback.on_train_epoch_end(self, self.lightning_module, outputs) @@ -97,52 +89,28 @@ def on_validation_epoch_start(self): for callback in self.callbacks: callback.on_validation_epoch_start(self, self.lightning_module) - def on_validation_epoch_end(self, outputs: List[Any]): - """Called when the epoch ends. - - Args: - outputs: List of outputs on each ``validation`` epoch - """ + def on_validation_epoch_end(self): + """Called when the epoch ends.""" for callback in self.callbacks: - if is_param_in_hook_signature(callback.on_validation_epoch_end, "outputs"): - callback.on_validation_epoch_end(self, self.lightning_module, outputs) - else: - warning_cache.warn( - "`Callback.on_validation_epoch_end` signature has changed in v1.3." - " `outputs` parameter has been added." - " Support for the old signature will be removed in v1.5", DeprecationWarning - ) - callback.on_validation_epoch_end(self, self.lightning_module) + callback.on_validation_epoch_end(self, self.lightning_module) def on_test_epoch_start(self): """Called when the epoch begins.""" for callback in self.callbacks: callback.on_test_epoch_start(self, self.lightning_module) - def on_test_epoch_end(self, outputs: List[Any]): - """Called when the epoch ends. - - Args: - outputs: List of outputs on each ``test`` epoch - """ + def on_test_epoch_end(self): + """Called when the epoch ends.""" for callback in self.callbacks: - if is_param_in_hook_signature(callback.on_test_epoch_end, "outputs"): - callback.on_test_epoch_end(self, self.lightning_module, outputs) - else: - warning_cache.warn( - "`Callback.on_test_epoch_end` signature has changed in v1.3." - " `outputs` parameter has been added." 
- " Support for the old signature will be removed in v1.5", DeprecationWarning - ) - callback.on_test_epoch_end(self, self.lightning_module) + callback.on_test_epoch_end(self, self.lightning_module) def on_epoch_start(self): - """Called when either of train/val/test epoch begins.""" + """Called when the epoch begins.""" for callback in self.callbacks: callback.on_epoch_start(self, self.lightning_module) def on_epoch_end(self): - """Called when either of train/val/test epoch ends.""" + """Called when the epoch ends.""" for callback in self.callbacks: callback.on_epoch_end(self, self.lightning_module) @@ -243,10 +211,10 @@ def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> Dict[Type, dict]: callback_states = {} for callback in self.callbacks: if self.__is_old_signature(callback.on_save_checkpoint): - rank_zero_deprecation( + rank_zero_warn( "`Callback.on_save_checkpoint` signature has changed in v1.3." " A `checkpoint` parameter has been added." - " Support for the old signature will be removed in v1.5" + " Support for the old signature will be removed in v1.5", DeprecationWarning ) state = callback.on_save_checkpoint(self, self.lightning_module) # noqa: parameter-unfilled else: diff --git a/pytorch_lightning/trainer/configuration_validator.py b/pytorch_lightning/trainer/configuration_validator.py index a7ba2b1c40123..8c539b5ff478d 100644 --- a/pytorch_lightning/trainer/configuration_validator.py +++ b/pytorch_lightning/trainer/configuration_validator.py @@ -40,8 +40,7 @@ def verify_loop_configurations(self, model: LightningModule) -> None: self.__verify_eval_loop_configuration(model, 'val') elif self.trainer.state == TrainerState.TESTING: self.__verify_eval_loop_configuration(model, 'test') - elif self.trainer.state == TrainerState.PREDICTING: - self.__verify_predict_loop_configuration(model) + # TODO: add predict def __verify_train_loop_configuration(self, model): # ----------------------------------- @@ -100,9 +99,3 @@ def __verify_eval_loop_configuration(self, model: LightningModule, stage: str) - rank_zero_warn(f'you passed in a {loader_name} but have no {step_name}. Skipping {stage} loop') if has_step and not has_loader: rank_zero_warn(f'you defined a {step_name} but have no {loader_name}. Skipping {stage} loop') - - def __verify_predict_loop_configuration(self, model: LightningModule) -> None: - - has_predict_dataloader = is_overridden('predict_dataloader', model) - if not has_predict_dataloader: - raise MisconfigurationException('Dataloader not found for `Trainer.predict`') diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index 30d2b48975a84..99d716f6b5a8c 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -32,7 +32,6 @@ DDPSpawnShardedPlugin, DeepSpeedPlugin, DeepSpeedPrecisionPlugin, - DoublePrecisionPlugin, HorovodPlugin, NativeMixedPrecisionPlugin, PrecisionPlugin, @@ -274,10 +273,6 @@ def use_deepspeed(self) -> bool: @property def is_distributed(self) -> bool: - # Used for custom plugins. - # Custom plugins should implement is_distributed property. 
- if hasattr(self.training_type_plugin, 'is_distributed') and not self.on_tpu: - return self.training_type_plugin.is_distributed is_distributed = self.use_ddp or self.use_ddp2 or self.use_horovod if self.on_tpu: is_distributed |= self.training_type_plugin.is_distributed @@ -320,8 +315,7 @@ def select_precision_plugin(self) -> PrecisionPlugin: if self.precision == 32: return PrecisionPlugin() - elif self.precision == 64: - return DoublePrecisionPlugin() + elif self.precision == 16: if self.on_tpu: return TPUHalfPrecisionPlugin() @@ -360,7 +354,7 @@ def select_precision_plugin(self) -> PrecisionPlugin: log.info("Using APEX 16bit precision.") return ApexMixedPrecisionPlugin(self.amp_level) - raise NotImplementedError("We only support precisions 64, 32 and 16!") + raise NotImplementedError("We only support precisions 32 and 16!") def select_training_type_plugin(self) -> TrainingTypePlugin: if self.use_ddp2: @@ -432,11 +426,6 @@ def resolve_training_type_plugin(self, training_type: TrainingTypePlugin) -> Tra if hasattr(training_type, 'num_nodes') and getattr(training_type, 'num_nodes') is None: training_type.num_nodes = self.num_nodes - # Automatically set sync_batchnorm if None. - # Useful for custom plugins. - if hasattr(training_type, 'sync_batchnorm') and getattr(training_type, 'sync_batchnorm') is None: - training_type.sync_batchnorm = self.sync_batchnorm - return training_type def select_accelerator(self) -> Accelerator: diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index 5d2f141dc64a8..b3fc0b4eb7b29 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -150,10 +150,6 @@ def attach_datamodule(self, model, datamodule: Optional[LightningDataModule] = N self.trainer.datamodule = datamodule datamodule.trainer = self.trainer - # experimental feature for Flash - if hasattr(datamodule, "data_pipeline"): - model.data_pipeline = datamodule.data_pipeline - class _PatchDataLoader(object): r""" diff --git a/pytorch_lightning/trainer/connectors/env_vars_connector.py b/pytorch_lightning/trainer/connectors/env_vars_connector.py index 1f1c41c6eb2f0..2e788c256af0d 100644 --- a/pytorch_lightning/trainer/connectors/env_vars_connector.py +++ b/pytorch_lightning/trainer/connectors/env_vars_connector.py @@ -18,25 +18,27 @@ from pytorch_lightning.utilities.argparse import get_init_arguments_and_types, parse_env_variables -def _defaults_from_env_vars(fn: Callable) -> Callable: +def overwrite_by_env_vars(fn: Callable) -> Callable: """ Decorator for :class:`~pytorch_lightning.trainer.trainer.Trainer` methods for which input arguments should be moved automatically to the correct device. 
+ """ @wraps(fn) - def insert_env_defaults(self, *args, **kwargs): - cls = self.__class__ # get the class + def overwrite_by_env_vars(self, *args, **kwargs): + # get the class + cls = self.__class__ if args: # inace any args passed move them to kwargs # parse only the argument names cls_arg_names = [arg[0] for arg in get_init_arguments_and_types(cls)] # convert args to kwargs kwargs.update({k: v for k, v in zip(cls_arg_names, args)}) - env_variables = vars(parse_env_variables(cls)) # update the kwargs by env variables - kwargs = dict(list(env_variables.items()) + list(kwargs.items())) + # todo: maybe add a warning that some init args were overwritten by Env arguments + kwargs.update(vars(parse_env_variables(cls))) # all args were already moved to kwargs return fn(self, **kwargs) - return insert_env_defaults + return overwrite_by_env_vars diff --git a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py index 7759c8028d325..223216846758f 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py @@ -13,11 +13,9 @@ # limitations under the License. from collections import defaultdict from typing import Any, Dict, List, Optional, Tuple -from weakref import proxy import torch -import pytorch_lightning as pl from pytorch_lightning.core.step_result import Result from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.utilities import DistributedType, LightningEnum @@ -52,7 +50,7 @@ class HookResultStore: Those data structures enables us to reduce properly Result object when batch loop is finished. """ - def __init__(self, fx_name: str) -> None: + def __init__(self, fx_name): self._fx_name = fx_name self._internals = {} self._internals_reduced = {} @@ -106,7 +104,6 @@ def get_batch_log_metrics(self, *args, **kwargs): def run_epoch_func(self, results, opt_metric, func_name, *args, **kwargs) -> None: if not isinstance(opt_metric, Result): raise Exception("The provided opt_metric should be a Result Object. 
Something is wrong") - func = getattr(opt_metric, func_name) metrics_to_log = func(*args, add_dataloader_idx=self.has_several_dataloaders, **kwargs) results.append(metrics_to_log) @@ -225,8 +222,8 @@ class EpochResultStore: ``` """ - def __init__(self, trainer: 'pl.Trainer') -> None: - self.trainer = proxy(trainer) + def __init__(self, trainer) -> None: + self.trainer = trainer self.reset() def __getitem__(self, key: str) -> Any: diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 0d0c3781c7724..15428c5d5c248 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -81,13 +81,16 @@ def cached_results(self) -> Union[EpochResultStore, None]: return self._cached_results.get(self.trainer._running_stage) # type: ignore def get_metrics(self, key: str) -> Dict: - metrics_holder: MetricsHolder = getattr(self, f"_{key}") - model = self.trainer.lightning_module - metrics_holder.convert(model.device if model is not None else None) + metrics_holder = getattr(self, f"_{key}", None) + model_ref = self.trainer.lightning_module + metrics_holder.convert( + self.trainer._device_type == DeviceType.TPU, + model_ref.device if model_ref is not None else model_ref, + ) return metrics_holder.metrics def set_metrics(self, key: str, val: Dict) -> None: - metrics_holder: MetricsHolder = getattr(self, f"_{key}") + metrics_holder = getattr(self, f"_{key}", None) metrics_holder.reset(val) def reset(self) -> None: diff --git a/pytorch_lightning/trainer/connectors/logger_connector/metrics_holder.py b/pytorch_lightning/trainer/connectors/logger_connector/metrics_holder.py index 1efbcc638674f..82f328a927485 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/metrics_holder.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/metrics_holder.py @@ -12,52 +12,44 @@ # See the License for the specific language governing permissions and # limitations under the License. import numbers -from typing import Any, Dict, Optional, Union +from typing import Any import torch -from torchmetrics import Metric -from pytorch_lightning.utilities.exceptions import MisconfigurationException - -_METRIC_TYPE = Union[Metric, torch.Tensor, int, float, Any] +from pytorch_lightning.metrics.metric import Metric class MetricsHolder: """ - This class acts as a dictionary holder. + This class acts as a dictonary holder. It holds metrics and implements conversion functions. Those functions will be triggered within LoggerConnector when the property is being requested from the user. 
""" - def __init__(self, to_float: bool = False) -> None: - self.metrics: Dict[str, _METRIC_TYPE] = {} + def __init__(self, to_float: bool = False): + self.metrics = {} self._to_float = to_float - def update(self, metrics: dict) -> None: + def update(self, metrics): self.metrics.update(metrics) - def pop(self, key: str, default: _METRIC_TYPE) -> _METRIC_TYPE: + def pop(self, key, default): return self.metrics.pop(key, default) - def reset(self, metrics: Dict[str, _METRIC_TYPE]) -> None: + def reset(self, metrics): self.metrics = metrics - def convert(self, device: Optional[torch.device]) -> None: + def convert(self, use_tpu: bool, device: torch.device): for key, value in self.metrics.items(): - if self._to_float: - if isinstance(value, torch.Tensor) and value.numel() != 1: - raise MisconfigurationException( - f"The metric `{key}` does not contain a single element" - f" thus it cannot be converted to float. Found `{value}`" - ) - converted = self._convert_to_float(value) - else: - converted = self._convert_to_tensor(value, device) - self.metrics[key] = converted - - @staticmethod - def _convert_to_float(current: _METRIC_TYPE) -> float: + self.metrics[key] = self._convert(value, use_tpu, device) + + def _convert(self, current: Any, use_tpu: bool, device: torch.device): + if self._to_float: + return self._convert_to_float(current, use_tpu, device) + return self._convert_to_tensor(current, use_tpu, device) + + def _convert_to_float(self, current, use_tpu: bool, device: torch.device): if isinstance(current, Metric): current = current.compute().detach() @@ -69,13 +61,16 @@ def _convert_to_float(current: _METRIC_TYPE) -> float: return current - @staticmethod - def _convert_to_tensor(current: _METRIC_TYPE, device: Optional[torch.device]) -> torch.Tensor: - if isinstance(current, Metric): - current = current.compute().detach() + def _convert_to_tensor(self, current: Any, use_tpu: bool, device: torch.device): + if current is not None: + if isinstance(current, Metric): + current = current.compute().detach() - elif isinstance(current, numbers.Number): - current = torch.tensor(current, device=device, dtype=torch.float) + elif isinstance(current, numbers.Number): + if device is None: + current = torch.tensor(current, dtype=torch.float) + else: + current = torch.tensor(current, device=device, dtype=torch.float) if isinstance(current, torch.Tensor) and current.device.type == "xla": current = current.cpu() diff --git a/pytorch_lightning/trainer/connectors/profiler_connector.py b/pytorch_lightning/trainer/connectors/profiler_connector.py index fa1002d70a7ce..98d65c1285ff7 100644 --- a/pytorch_lightning/trainer/connectors/profiler_connector.py +++ b/pytorch_lightning/trainer/connectors/profiler_connector.py @@ -11,8 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License + from typing import Union -from weakref import proxy from pytorch_lightning.profiler import ( AdvancedProfiler, @@ -54,8 +54,6 @@ def on_trainer_init(self, profiler: Union[BaseProfiler, str]): ) self.trainer.profiler = profiler or PassThroughProfiler() - def setup(self) -> None: - trainer = self.trainer + def on_train_start(self, trainer): local_rank = trainer.local_rank if trainer.world_size > 1 else None - trainer.profiler._lightning_module = proxy(trainer.lightning_module) - trainer.profiler.setup(stage=trainer._setup_state, local_rank=local_rank, log_dir=trainer.log_dir) + self.trainer.profiler.on_train_start(local_rank) diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 59ec40c3df2e8..83505913d0186 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -16,7 +16,7 @@ import platform from abc import ABC from copy import deepcopy -from typing import Iterable, List, Tuple, Union +from typing import Callable, Iterable, List, Tuple, Union from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler from torch.utils.data.distributed import DistributedSampler @@ -41,7 +41,7 @@ class TrainerDataLoadingMixin(ABC): tpu_local_core_rank: int train_dataloader: DataLoader num_training_batches: Union[int, float] - val_check_batch: float + val_check_batch:... val_dataloaders: List[DataLoader] num_val_batches: List[Union[int, float]] test_dataloaders: List[DataLoader] @@ -191,7 +191,7 @@ def reset_train_dataloader(self, model: LightningModule) -> None: Args: model: The current `LightningModule` """ - self.train_dataloader = self.request_dataloader(model, "train") + self.train_dataloader = self.request_dataloader(model.train_dataloader) if self.overfit_batches > 0: if hasattr(self.train_dataloader, 'sampler') and isinstance(self.train_dataloader.sampler, RandomSampler): @@ -271,7 +271,7 @@ def _reset_eval_dataloader( """ # always get the loaders first so we can count how many there are loader_name = f'{mode}_dataloader' - dataloaders = self.request_dataloader(model, mode) + dataloaders = self.request_dataloader(getattr(model, loader_name)) if not isinstance(dataloaders, list): dataloaders = [dataloaders] @@ -280,7 +280,7 @@ def _reset_eval_dataloader( # duplicate it the numb of times needed to match the train loaders if self.overfit_batches > 0: num_loaders = len(dataloaders) - train_dataloader = self.request_dataloader(model, 'train') + train_dataloader = self.request_dataloader(getattr(model, 'train_dataloader')) dataloaders = [deepcopy(train_dataloader) for _ in range(num_loaders)] self.dev_debugger.track_load_dataloader_call(loader_name, dataloaders=dataloaders) @@ -293,9 +293,9 @@ def _reset_eval_dataloader( if mode in modes and hasattr(loader, 'sampler') and isinstance(loader.sampler, RandomSampler): # when overfitting, the dataloader should not have sampler - if self.overfit_batches > 0 and mode != 'predict': + if self.overfit_batches > 0: rank_zero_warn( - 'You requested to overfit but enabled val/test dataloader shuffling.' + 'You requested to overfit but enabled test/val dataloader shuffling.' ' We are turning it off for you.' 
) dataloaders[loader_i] = self.replace_sampler(loader, SequentialSampler(loader.dataset)) @@ -303,7 +303,7 @@ def _reset_eval_dataloader( else: rank_zero_warn( f'Your {mode}_dataloader has `shuffle=True`, it is best practice to turn' - ' this off for val/test/predict dataloaders.' + ' this off for validation and test dataloaders.' ) if any([dl is None for dl in dataloaders]): @@ -380,7 +380,7 @@ def reset_predict_dataloader(self, model) -> None: if has_loader: self.num_predict_batches, self.predict_dataloaders = self._reset_eval_dataloader(model, 'predict') - def request_dataloader(self, model: LightningModule, stage: str) -> DataLoader: + def request_dataloader(self, dataloader_fx: Callable) -> DataLoader: """Handles downloading data in the GPU or TPU case. Args: @@ -389,10 +389,9 @@ def request_dataloader(self, model: LightningModule, stage: str) -> DataLoader: Returns: The dataloader """ - if model.trainer is not None: - model.trainer.call_hook(f"on_{stage}_dataloader") - dataloader: DataLoader = getattr(model, f'{stage}_dataloader')() + dataloader = dataloader_fx() dataloader = self._flatten_dl_only(dataloader) + self.accelerator.barrier('get_dataloaders') return dataloader diff --git a/pytorch_lightning/trainer/deprecated_api.py b/pytorch_lightning/trainer/deprecated_api.py index 32dbc8c4088a3..69d3887fc7718 100644 --- a/pytorch_lightning/trainer/deprecated_api.py +++ b/pytorch_lightning/trainer/deprecated_api.py @@ -14,7 +14,7 @@ from pytorch_lightning.accelerators import Accelerator from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector -from pytorch_lightning.utilities import DeviceType, DistributedType, rank_zero_deprecation +from pytorch_lightning.utilities import DeviceType, DistributedType, rank_zero_warn class DeprecatedDistDeviceAttributes: @@ -24,94 +24,96 @@ class DeprecatedDistDeviceAttributes: @property def on_cpu(self) -> bool: - rank_zero_deprecation("Internal: `on_cpu` is deprecated in v1.2 and will be removed in v1.4.") + rank_zero_warn("Internal: `on_cpu` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning) return self.accelerator_connector._device_type == DeviceType.CPU @on_cpu.setter def on_cpu(self, val: bool) -> None: - rank_zero_deprecation("Internal: `on_cpu` is deprecated in v1.2 and will be removed in v1.4.") + rank_zero_warn("Internal: `on_cpu` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning) if val: self.accelerator_connector._device_type = DeviceType.CPU @property def on_tpu(self) -> bool: - rank_zero_deprecation("Internal: `on_tpu` is deprecated in v1.2 and will be removed in v1.4.") + rank_zero_warn("Internal: `on_tpu` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning) return self.accelerator_connector._device_type == DeviceType.TPU @on_tpu.setter def on_tpu(self, val: bool) -> None: - rank_zero_deprecation("Internal: `on_tpu` is deprecated in v1.2 and will be removed in v1.4.") + rank_zero_warn("Internal: `on_tpu` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning) if val: self.accelerator_connector._device_type = DeviceType.TPU @property def use_tpu(self) -> bool: - rank_zero_deprecation("Internal: `use_tpu` is deprecated in v1.2 and will be removed in v1.4.") + rank_zero_warn("Internal: `use_tpu` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning) return self.on_tpu @use_tpu.setter def use_tpu(self, val: bool) -> None: - 
rank_zero_deprecation("Internal: `use_tpu` is deprecated in v1.2 and will be removed in v1.4.") + rank_zero_warn("Internal: `use_tpu` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning) self.on_tpu = val @property def on_gpu(self) -> bool: - rank_zero_deprecation("Internal: `on_gpu` is deprecated in v1.2 and will be removed in v1.4.") + rank_zero_warn("Internal: `on_gpu` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning) return self.accelerator_connector._device_type == DeviceType.GPU @on_gpu.setter def on_gpu(self, val: bool) -> None: - rank_zero_deprecation("Internal: `on_gpu` is deprecated in v1.2 and will be removed in v1.4.") + rank_zero_warn("Internal: `on_gpu` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning) if val: self.accelerator_connector._device_type = DeviceType.GPU @property def use_dp(self) -> bool: - rank_zero_deprecation("Internal: `use_dp` is deprecated in v1.2 and will be removed in v1.4.") + rank_zero_warn("Internal: `use_dp` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning) return self.accelerator_connector._distrib_type == DistributedType.DP @use_dp.setter def use_dp(self, val: bool) -> None: - rank_zero_deprecation("Internal: `use_dp` is deprecated in v1.2 and will be removed in v1.4.") + rank_zero_warn("Internal: `use_dp` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning) if val: self.accelerator_connector._distrib_type = DistributedType.DP @property def use_ddp(self) -> bool: - rank_zero_deprecation("Internal: `use_ddp` is deprecated in v1.2 and will be removed in v1.4.") + rank_zero_warn("Internal: `use_ddp` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning) return self.accelerator_connector._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN) @use_ddp.setter def use_ddp(self, val: bool) -> None: - rank_zero_deprecation("Internal: `use_ddp` is deprecated in v1.2 and will be removed in v1.4.") + rank_zero_warn("Internal: `use_ddp` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning) if val: self.accelerator_connector._distrib_type = DistributedType.DDP @property def use_ddp2(self) -> bool: - rank_zero_deprecation("Internal: `use_ddp2` is deprecated in v1.2 and will be removed in v1.4.") + rank_zero_warn("Internal: `use_ddp2` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning) return self.accelerator_connector._distrib_type == DistributedType.DDP2 @use_ddp2.setter def use_ddp2(self, val: bool) -> None: - rank_zero_deprecation("Internal: `use_ddp2` is deprecated in v1.2 and will be removed in v1.4.") + rank_zero_warn("Internal: `use_ddp2` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning) if val: self.accelerator_connector._distrib_type = DistributedType.DDP2 @property def use_horovod(self) -> bool: - rank_zero_deprecation("Internal: `use_horovod` is deprecated in v1.2 and will be removed in v1.4.") + rank_zero_warn("Internal: `use_horovod` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning) return self.accelerator_connector._distrib_type == DistributedType.HOROVOD @use_horovod.setter def use_horovod(self, val: bool) -> None: - rank_zero_deprecation("Internal: `use_horovod` is deprecated in v1.2 and will be removed in v1.4.") + rank_zero_warn("Internal: `use_horovod` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning) if val: self.accelerator_connector._distrib_type = DistributedType.HOROVOD @property def 
use_single_gpu(self) -> bool: - rank_zero_deprecation("Internal: `use_single_gpu` is deprecated in v1.2 and will be removed in v1.4.") + rank_zero_warn( + "Internal: `use_single_gpu` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning + ) # todo, limiting to exclude DDP2 is not clear but it comes from connectors... return ( self.accelerator_connector._device_type and self.accelerator_connector._device_type == DeviceType.GPU @@ -120,7 +122,10 @@ def use_single_gpu(self) -> bool: @use_single_gpu.setter def use_single_gpu(self, val: bool) -> None: - rank_zero_deprecation("Internal: `use_single_gpu` is deprecated in v1.2 and will be removed in v1.4.") + rank_zero_warn( + "Internal: `use_single_gpu` is deprecated in v1.2 and will be removed in v1.4.", + DeprecationWarning, + ) if val: self.accelerator_connector._device_type = DeviceType.GPU @@ -133,22 +138,23 @@ class DeprecatedTrainerAttributes: @property def accelerator_backend(self) -> Accelerator: - rank_zero_deprecation( + rank_zero_warn( "The `Trainer.accelerator_backend` attribute is deprecated in favor of `Trainer.accelerator`" - " since 1.2 and will be removed in v1.4." + " since 1.2 and will be removed in v1.4.", DeprecationWarning ) return self.accelerator def get_model(self) -> LightningModule: - rank_zero_deprecation( + rank_zero_warn( "The use of `Trainer.get_model()` is deprecated in favor of `Trainer.lightning_module`" - " and will be removed in v1.4." + " and will be removed in v1.4.", DeprecationWarning ) return self.lightning_module @property def running_sanity_check(self) -> bool: - rank_zero_deprecation( - "`Trainer.running_sanity_check` has been renamed to `Trainer.sanity_checking` and will be removed in v1.5." + rank_zero_warn( + "`Trainer.running_sanity_check` has been renamed to `Trainer.sanity_checking`" + " and will be removed in v1.5.", DeprecationWarning ) return self.sanity_checking diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index a87073428e725..91cfc2ec757d5 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -11,15 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
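As context for the `rank_zero_warn(..., DeprecationWarning)` pattern used throughout the `deprecated_api.py` hunk above, the following is a minimal, self-contained sketch of such a deprecated property; the class and attribute names are illustrative only and not part of this patch:

    import warnings

    def rank_zero_warn(message, category=UserWarning):
        # simplified stand-in for pytorch_lightning.utilities.rank_zero_warn,
        # which additionally restricts the warning to global rank zero
        warnings.warn(message, category)

    class LegacyDeviceAttributes:
        """Hypothetical mixin mirroring DeprecatedDistDeviceAttributes."""

        _device_type = "CPU"  # normally set by the accelerator connector

        @property
        def on_cpu(self) -> bool:
            rank_zero_warn("Internal: `on_cpu` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning)
            return self._device_type == "CPU"

        @on_cpu.setter
        def on_cpu(self, val: bool) -> None:
            rank_zero_warn("Internal: `on_cpu` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning)
            if val:
                self._device_type = "CPU"
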
- import torch from pytorch_lightning.core.step_result import Result -from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.trainer.supporters import PredictionCollection from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.model_helpers import is_overridden -from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature from pytorch_lightning.utilities.warnings import WarningCache @@ -100,10 +97,6 @@ def on_evaluation_end(self, *args, **kwargs): else: self.trainer.call_hook('on_validation_end', *args, **kwargs) - if self.trainer.state != TrainerState.FITTING: - # summarize profile results - self.trainer.profiler.describe() - def reload_evaluation_dataloaders(self): model = self.trainer.lightning_module if self.trainer.testing: @@ -125,8 +118,6 @@ def setup(self, model, max_batches, dataloaders): self._predictions = [[] for _ in range(self.num_dataloaders)] def on_evaluation_epoch_start(self, *args, **kwargs): - self.trainer.call_hook('on_epoch_start', *args, **kwargs) - if self.trainer.testing: self.trainer.call_hook('on_test_epoch_start', *args, **kwargs) else: @@ -211,6 +202,9 @@ def __run_eval_epoch_end(self, num_dataloaders): # with a single dataloader don't pass an array outputs = self.outputs + # free memory + self.outputs = [] + eval_results = outputs if num_dataloaders == 1: eval_results = outputs[0] @@ -319,40 +313,13 @@ def store_predictions(self, output, batch_idx, dataloader_idx): def on_evaluation_epoch_end(self, *args, **kwargs): # call the callback hook - self.call_on_evaluation_epoch_end_hook() + if self.trainer.testing: + self.trainer.call_hook('on_test_epoch_end', *args, **kwargs) + else: + self.trainer.call_hook('on_validation_epoch_end', *args, **kwargs) self.trainer.call_hook('on_epoch_end') - def call_on_evaluation_epoch_end_hook(self): - outputs = self.outputs - - # free memory - self.outputs = [] - - model_ref = self.trainer.lightning_module - hook_name = "on_test_epoch_end" if self.trainer.testing else "on_validation_epoch_end" - - self.trainer._reset_result_and_set_hook_fx_name(hook_name) - - with self.trainer.profiler.profile(hook_name): - - if hasattr(self.trainer, hook_name): - on_evaluation_epoch_end_hook = getattr(self.trainer, hook_name) - on_evaluation_epoch_end_hook(outputs) - - if is_overridden(hook_name, model_ref): - model_hook_fx = getattr(model_ref, hook_name) - if is_param_in_hook_signature(model_hook_fx, "outputs"): - model_hook_fx(outputs) - else: - self.warning_cache.warn( - f"`ModelHooks.{hook_name}` signature has changed in v1.3. `outputs` parameter has been added." 
- " Support for the old signature will be removed in v1.5", DeprecationWarning - ) - model_hook_fx() - - self.trainer._cache_logged_metrics() - def log_evaluation_step_metrics(self, output, batch_idx): if self.trainer.sanity_checking: return diff --git a/pytorch_lightning/trainer/predict_loop.py b/pytorch_lightning/trainer/predict_loop.py index b33f41cb2ea48..4fe6960055ca9 100644 --- a/pytorch_lightning/trainer/predict_loop.py +++ b/pytorch_lightning/trainer/predict_loop.py @@ -28,24 +28,23 @@ def __init__(self, trainer): def on_trainer_init(self): self.trainer.num_predict_batches = [] - def get_predict_dataloaders(self): + def get_predict_dataloaders(self, max_batches): self.trainer.reset_predict_dataloader(self.trainer.lightning_module) dataloaders = self.trainer.predict_dataloaders - max_batches = self.trainer.num_predict_batches + if max_batches is None: + max_batches = self.trainer.num_predict_batches return dataloaders, max_batches - def should_skip_predict(self, max_batches): - return sum(max_batches) == 0 + def should_skip_predict(self, dataloaders, max_batches): + return dataloaders is None or not sum(max_batches) def on_predict_model_eval(self, *_, **__): model_ref = self.trainer.lightning_module model_ref.on_predict_model_eval() def setup(self, model, max_batches, dataloaders): - self.trainer.call_hook("on_predict_start") - # copy properties for forward overrides self.trainer.model_connector.copy_trainer_model_properties(model) @@ -67,7 +66,7 @@ def _get_num_dataloaders(self, dataloaders): length = len(dataloaders[0]) return length - def predict_step(self, batch, batch_idx, dataloader_idx): + def predict(self, batch, batch_idx, dataloader_idx): # configure args args = [batch, batch_idx] if self.num_dataloaders: @@ -76,7 +75,7 @@ def predict_step(self, batch, batch_idx, dataloader_idx): model_ref = self.trainer.lightning_module model_ref._current_fx_name = "predict" - predictions = self.trainer.accelerator.predict_step(args) + predictions = self.trainer.accelerator.predict(args) if predictions is None: self.warning_cache.warn("predict returned None if it was on purpose, ignore this warning...") @@ -88,8 +87,6 @@ def predict_step(self, batch, batch_idx, dataloader_idx): return def on_predict_epoch_end(self): - self.trainer.profiler.describe() - self.trainer._progress_bar_callback.on_predict_end(self.trainer, self.trainer.lightning_module) results = self._predictions @@ -103,11 +100,3 @@ def _convert_to_numpy(v): return results[0] return results - - def on_predict_start(self): - # hook - self.trainer.call_hook("on_predict_start") - - def on_predict_end(self): - # hook - self.trainer.call_hook("on_predict_end") diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py index 315e3c60c0557..b5654b148afc6 100644 --- a/pytorch_lightning/trainer/properties.py +++ b/pytorch_lightning/trainer/properties.py @@ -491,16 +491,6 @@ def sanity_checking(self, val: bool) -> None: elif self.sanity_checking: self._running_stage = None - @property - def _setup_state(self) -> TrainerState: - # 'fit' is passed for `trainer.tune()` as there aren't "tune_dataloaders" - return TrainerState.FITTING if self.state == TrainerState.TUNING else self.state - - @property - def _teardown_state(self) -> Optional[TrainerState]: - if self.state.running: - return self._setup_state - # Used to represent the concrete type TrainerProperties class methods are called on. 
_T = TypeVar('_T', bound=TrainerProperties) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index c692c3f1c113f..c3039d24aadc0 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -38,7 +38,7 @@ from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector from pytorch_lightning.trainer.connectors.data_connector import DataConnector from pytorch_lightning.trainer.connectors.debugging_connector import DebuggingConnector -from pytorch_lightning.trainer.connectors.env_vars_connector import _defaults_from_env_vars +from pytorch_lightning.trainer.connectors.env_vars_connector import overwrite_by_env_vars from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector from pytorch_lightning.trainer.connectors.model_connector import ModelConnector from pytorch_lightning.trainer.connectors.optimizer_connector import OptimizerConnector @@ -57,7 +57,7 @@ from pytorch_lightning.trainer.training_loop import TrainLoop from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin from pytorch_lightning.tuner.tuning import Tuner -from pytorch_lightning.utilities import DeviceType, rank_zero_warn +from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.cloud_io import load as pl_load from pytorch_lightning.utilities.debugging import InternalDebugger from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -84,7 +84,7 @@ class Trainer( DeprecatedTrainerAttributes, ): - @_defaults_from_env_vars + @overwrite_by_env_vars def __init__( self, logger: Union[LightningLoggerBase, Iterable[LightningLoggerBase], bool] = True, @@ -198,13 +198,11 @@ def __init__( gradient_clip_val: 0 means don't clip. - limit_train_batches: How much of training dataset to check (float = fraction, int = num_batches) + limit_train_batches: How much of training dataset to check (floats = percent, int = num_batches) - limit_val_batches: How much of validation dataset to check (float = fraction, int = num_batches) + limit_val_batches: How much of validation dataset to check (floats = percent, int = num_batches) - limit_test_batches: How much of test dataset to check (float = fraction, int = num_batches) - - limit_predict_batches: How much of prediction dataset to check (float = fraction, int = num_batches) + limit_test_batches: How much of test dataset to check (floats = percent, int = num_batches) logger: Logger (or iterable collection of loggers) for experiment tracking. @@ -223,12 +221,11 @@ def __init__( profiler: To profile individual steps during training and assist in identifying bottlenecks. - overfit_batches: Overfit a fraction of training data (float) or a set number of batches (int). + overfit_batches: Overfit a percent of training data (float) or a set number of batches (int). plugins: Plugins allow modification of core behavior like ddp and amp, and enable custom lightning plugins. - precision: Double precision (64), full precision (32) or half precision (16). Can be used on CPU, GPU or - TPUs. + precision: Full precision (32), half precision (16). Can be used on CPU, GPU or TPUs. max_epochs: Stop training once this number of epochs is reached. Disabled by default (None). If both max_epochs and max_steps are not specified, defaults to ``max_epochs`` = 1000. 
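As a usage note for the `limit_*_batches` and `overfit_batches` semantics described in the docstring hunk above (a float is read as a fraction/percent of the dataset, an int as an exact number of batches), a minimal sketch:

    from pytorch_lightning import Trainer

    # float => use that fraction of the available batches per epoch
    trainer = Trainer(limit_train_batches=0.25, limit_val_batches=0.5)

    # int => use exactly that many batches per epoch
    trainer = Trainer(limit_train_batches=100, limit_test_batches=10)

    # overfit_batches follows the same convention:
    # 0.01 => 1% of the training data, 4 => exactly four training batches
    trainer = Trainer(overfit_batches=4)
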
@@ -321,10 +318,6 @@ def __init__( self.predict_loop = PredictLoop(self) # training state - if weights_summary is not None and weights_summary not in ModelSummary.MODES: - raise MisconfigurationException( - f"`weights_summary` can be None, {', '.join(ModelSummary.MODES)}, but got {weights_summary}" - ) self.weights_summary = weights_summary self.shown_warnings = set() @@ -356,6 +349,7 @@ def __init__( max_steps, min_steps, num_sanity_val_steps, + weights_summary, ) self.evaluation_loop.on_trainer_init() @@ -432,10 +426,8 @@ def fit( # ---------------------------- # SET UP TRAINING # ---------------------------- + self.call_setup_hook(model) self.call_hook("on_before_accelerator_backend_setup", model) - self.accelerator.connect(model) - self.accelerator.setup_environment() - self.call_setup_hook(model) # allow user to setup lightning_module in accelerator environment self.accelerator.setup(self, model) # note: this sets up self.lightning_module # ---------------------------- @@ -449,15 +441,13 @@ def fit( | || {self.dispatch} || | || LIGHTNING - {self.accelerator.start_training} || - or {self.accelerator.start_evaluating} || - or {self.accelerator.start_predicting} || FLOW - | || - {self.run_stage} || + {self.accelerator.start_training} or || + {self.accelerator.start_evaluating} or || FLOW + {self.accelerator.start_predicting} || | || DIRECTION - {self.run_train} || - or {self.run_evaluation} || - or {self.run_predict} || + {self.run_train} or || + {self.run_evaluation} or || + {self.run_predict} || | || results \/ This is used to guide readers to the core loops: train, test, predict. @@ -501,7 +491,7 @@ def fit( return self.accelerator.results or 1 def pre_dispatch(self): - self.accelerator.pre_dispatch(self) + self.accelerator.pre_dispatch() # log hyper-parameters if self.logger is not None: @@ -511,7 +501,7 @@ def pre_dispatch(self): self.logger.save() def post_dispatch(self): - self.accelerator.post_dispatch(self) + self.accelerator.post_dispatch() self.accelerator.teardown() def dispatch(self): @@ -524,9 +514,6 @@ def dispatch(self): def run_stage(self): results = None - - self.profile_connector.setup() - if self.evaluating: results = self.run_evaluate() elif self.predicting: @@ -553,7 +540,10 @@ def _pre_training_routine(self): # print model summary if self.is_global_zero and self.weights_summary is not None and not self.testing: - ref_model.summarize(mode=self.weights_summary) + if self.weights_summary in ModelSummary.MODES: + ref_model.summarize(mode=self.weights_summary) + else: + raise MisconfigurationException("weights_summary can be None, " + ", ".join(ModelSummary.MODES)) # restore training and model before hpc is called self.checkpoint_connector.restore_weights() @@ -763,13 +753,11 @@ def run_evaluate(self): return eval_loop_results def run_predict(self): - self.predict_loop.on_predict_start() - # prepare dataloaders - dataloaders, max_batches = self.predict_loop.get_predict_dataloaders() + dataloaders, max_batches = self.predict_loop.get_predict_dataloaders(None) # check if we want to skip this evaluation - if self.predict_loop.should_skip_predict(max_batches): + if self.predict_loop.should_skip_predict(dataloaders, max_batches): return [] # ref model @@ -787,6 +775,7 @@ def run_predict(self): for dataloader_idx, dataloader in enumerate(dataloaders): dataloader = self.accelerator.process_dataloader(dataloader) dl_max_batches = self.predict_loop.max_batches[dataloader_idx] + for batch_idx, batch in enumerate(dataloader): if batch is None: continue @@ -796,15 +785,10 @@ def 
run_predict(self): break # lightning module methods - with self.profiler.profile("predict_step"): - self.predict_loop.predict_step(batch, batch_idx, dataloader_idx) + with self.profiler.profile("predict"): + self.predict_loop.predict(batch, batch_idx, dataloader_idx) results = self.predict_loop.on_predict_epoch_end() - self.predict_loop.on_predict_end() - - # re-enable grads - torch.set_grad_enabled(True) - return results def run_sanity_check(self, ref_model): @@ -938,7 +922,9 @@ def test( # If you supply a datamodule you can't supply test_dataloaders if test_dataloaders and datamodule: - raise MisconfigurationException('You cannot pass both `trainer.test(test_dataloaders=..., datamodule=...)`') + raise MisconfigurationException( + 'You cannot pass both `trainer.test(test_dataloaders=..., datamodule=...)`' + ) model_provided = model is not None model = model or self.lightning_module @@ -983,9 +969,7 @@ def __load_ckpt_weights( ' specify a path for a checkpoint `.{fn}(ckpt_path=PATH)`' ) - # only one process running at this point for TPUs, as spawn isn't triggered yet - if not self._device_type == DeviceType.TPU: - self.training_type_plugin.barrier() + self.training_type_plugin.barrier() ckpt = pl_load(ckpt_path, map_location=lambda storage, loc: storage) model.load_state_dict(ckpt['state_dict']) @@ -1074,7 +1058,8 @@ def tune( def call_setup_hook(self, model: LightningModule) -> None: assert self.state.running, f"TrainerState: {self.state}" - state = self._setup_state + # 'fit' is passed for `trainer.tune()` as there aren't "tune_dataloaders" + state = TrainerState.FITTING if self.state == TrainerState.TUNING else self.state if self.datamodule is not None: called = getattr(self.datamodule, f'has_setup_{state}') @@ -1085,14 +1070,11 @@ def call_setup_hook(self, model: LightningModule) -> None: model.setup(stage=state) def call_teardown_hook(self, model: LightningModule) -> None: - state = self._teardown_state - - if self.datamodule is not None: - called = getattr(self.datamodule, f'has_teardown_{state}') - if not called: - self.datamodule.teardown(stage=state) + if self.state.running: + state = TrainerState.FITTING if self.state == TrainerState.TUNING else self.state + else: + state = None - self.profiler.teardown(stage=state) self.teardown(stage=state) model.teardown(stage=state) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 427ef8100af28..88b87afcb9358 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -14,12 +14,12 @@ from contextlib import contextmanager, suppress from copy import copy, deepcopy -from typing import Optional import numpy as np import torch from pytorch_lightning.callbacks import EarlyStopping +from pytorch_lightning.core.memory import ModelSummary from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.core.step_result import Result from pytorch_lightning.plugins import ParallelPlugin @@ -36,7 +36,7 @@ class TrainLoop: - def __init__(self, trainer, multiple_trainloader_mode: str): + def __init__(self, trainer, multiple_trainloader_mode): self.trainer = trainer self.early_stopping_accumulator = None self.checkpoint_accumulator = None @@ -53,12 +53,13 @@ def __init__(self, trainer, multiple_trainloader_mode: str): def on_trainer_init( self, - max_epochs: Optional[int], - min_epochs: Optional[int], - max_steps: Optional[int], - min_steps: Optional[int], - num_sanity_val_steps: int, - ) -> None: + max_epochs, + min_epochs, + 
max_steps, + min_steps, + num_sanity_val_steps, + weights_summary, + ): self.trainer.global_step = 0 self.trainer.current_epoch = 0 self.trainer.should_stop = False @@ -81,6 +82,12 @@ def on_trainer_init( else: self.trainer.num_sanity_val_steps = num_sanity_val_steps + self.trainer.weights_summary = weights_summary + if weights_summary is not None and weights_summary not in ModelSummary.MODES: + raise MisconfigurationException( + f"`weights_summary` can be None, {', '.join(ModelSummary.MODES)}, got {weights_summary}" + ) + @property def num_optimizers(self): num_optimizers = len(self.get_optimizers_iterable()) @@ -95,7 +102,10 @@ def on_train_start(self): # hook self.trainer.call_hook("on_train_start") - def setup_fit(self, model, train_dataloader=None, val_dataloaders=None, datamodule=None): + # provide rank to profiler + self.trainer.profile_connector.on_train_start(self.trainer) + + def setup_fit(self, model, train_dataloader, val_dataloaders, datamodule): # clean hparams if hasattr(model, "hparams"): parsing.clean_namespace(model.hparams) @@ -130,7 +140,8 @@ def on_train_end(self): self.trainer.logger.finalize("success") # summarize profile results - self.trainer.profiler.describe() + if self.trainer.global_rank == 0: + self.trainer.profiler.describe() # give accelerators a chance to finish self.trainer.accelerator.on_train_end() @@ -177,7 +188,7 @@ def on_train_epoch_start(self, epoch): self.trainer.train_dataloader.sampler.set_epoch(epoch) # changing gradient according accumulation_scheduler - self.trainer.accumulation_scheduler.on_train_epoch_start(self.trainer, self.trainer.lightning_module) + self.trainer.accumulation_scheduler.on_epoch_start(self.trainer, self.trainer.lightning_module) # stores accumulated grad fractions per batch self.accumulated_loss = TensorRunningAccum(window_length=self.trainer.accumulate_grad_batches) @@ -540,7 +551,7 @@ def run_training_epoch(self): self.increment_accumulated_grad_global_step() # epoch end hook - self.on_train_epoch_end(epoch_output) + self.run_on_epoch_end_hook(epoch_output) # log epoch metrics self.trainer.logger_connector.log_train_epoch_end_metrics( @@ -736,7 +747,7 @@ def training_step_and_backward(self, split_batch, batch_idx, opt_idx, optimizer, # backward pass if result is not None: - with self.trainer.profiler.profile("backward"): + with self.trainer.profiler.profile("model_backward"): self.backward(result, optimizer, opt_idx) # hook - call this hook only @@ -782,7 +793,7 @@ def update_train_loop_lr_schedulers(self, monitor_metrics=None): # update lr self.trainer.optimizer_connector.update_learning_rates(interval="step", monitor_metrics=monitor_metrics) - def on_train_epoch_end(self, epoch_output): + def run_on_epoch_end_hook(self, epoch_output): # inform logger the batch loop has finished self.trainer.logger_connector.on_train_epoch_end() diff --git a/pytorch_lightning/tuner/tuning.py b/pytorch_lightning/tuner/tuning.py index b9fa9afe0e77e..78810141b1369 100644 --- a/pytorch_lightning/tuner/tuning.py +++ b/pytorch_lightning/tuner/tuning.py @@ -33,20 +33,13 @@ def on_trainer_init(self, auto_lr_find, auto_scale_batch_size): self.trainer.auto_lr_find = auto_lr_find self.trainer.auto_scale_batch_size = auto_scale_batch_size - def setup_trainer( - self, - model: LightningModule, - train_dataloader: Optional[DataLoader] = None, - val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, - datamodule: LightningDataModule = None, - ): - self.trainer.model_connector.copy_trainer_model_properties(model) + def tune(self, model, 
train_dataloader, val_dataloaders, datamodule): # setup data, etc... self.trainer.train_loop.setup_fit(model, train_dataloader, val_dataloaders, datamodule) + # hook self.trainer.data_connector.prepare_data(model) - def tune(self, model, train_dataloader, val_dataloaders, datamodule): # Run auto batch size scaling if self.trainer.auto_scale_batch_size: if isinstance(self.trainer.auto_scale_batch_size, bool): @@ -111,7 +104,6 @@ def scale_batch_size( or datamodule. """ - self.setup_trainer(model, **fit_kwargs) return scale_batch_size( self.trainer, model, @@ -136,7 +128,6 @@ def lr_find( datamodule: Optional[LightningDataModule] = None, update_attr: bool = False, ): - self.setup_trainer(model, train_dataloader, val_dataloaders, datamodule) return lr_find( self.trainer, model, diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py index f4617c23da383..3e2ee3e51efe1 100644 --- a/pytorch_lightning/utilities/__init__.py +++ b/pytorch_lightning/utilities/__init__.py @@ -17,7 +17,6 @@ from pytorch_lightning.utilities.apply_func import move_data_to_device # noqa: F401 from pytorch_lightning.utilities.distributed import ( # noqa: F401 AllGatherGrad, - rank_zero_deprecation, rank_zero_info, rank_zero_only, rank_zero_warn, diff --git a/pytorch_lightning/utilities/argparse.py b/pytorch_lightning/utilities/argparse.py index 46d88184ee190..ee42ab3241ff6 100644 --- a/pytorch_lightning/utilities/argparse.py +++ b/pytorch_lightning/utilities/argparse.py @@ -67,7 +67,7 @@ def parse_argparser(cls, arg_parser: Union[ArgumentParser, Namespace]) -> Namesp # Value has been passed as a flag => It is currently None, so we need to set it to True # We always set to True, regardless of the default value. # Users must pass False directly, but when passing nothing True is assumed. - # i.e. the only way to disable something that defaults to True is to use the long form: + # i.e. the only way to disable something that defaults to True is to use the long form: # "--a_default_true_arg False" becomes False, while "--a_default_false_arg" becomes None, # which then becomes True here. @@ -107,7 +107,7 @@ def parse_env_variables(cls, template: str = "PL_%(cls_name)s_%(cls_argument)s") def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]: - r"""Scans the class signature and returns argument names, types and default values. + r"""Scans the Trainer signature and returns argument names, types and default values. 
Returns: List with tuples of 3 values: @@ -119,11 +119,11 @@ def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]: >>> args = get_init_arguments_and_types(Trainer) """ - cls_default_params = inspect.signature(cls).parameters + trainer_default_params = inspect.signature(cls).parameters name_type_default = [] - for arg in cls_default_params: - arg_type = cls_default_params[arg].annotation - arg_default = cls_default_params[arg].default + for arg in trainer_default_params: + arg_type = trainer_default_params[arg].annotation + arg_default = trainer_default_params[arg].default try: arg_types = tuple(arg_type.__args__) except AttributeError: @@ -242,6 +242,9 @@ def add_argparse_args( if arg == 'track_grad_norm': use_type = float + if arg_default is inspect._empty: + arg_default = None + parser.add_argument( f'--{arg}', dest=arg, @@ -288,7 +291,10 @@ def _gpus_allowed_type(x) -> Union[int, str]: def _gpus_arg_default(x) -> Union[int, str]: - return _gpus_allowed_type(x) + if ',' in x: + return str(x) + else: + return int(x) def _int_or_float_type(x) -> Union[int, float]: diff --git a/pytorch_lightning/utilities/argparse_utils.py b/pytorch_lightning/utilities/argparse_utils.py index 80db2429f7d2a..eb53579f948e8 100644 --- a/pytorch_lightning/utilities/argparse_utils.py +++ b/pytorch_lightning/utilities/argparse_utils.py @@ -1,5 +1,7 @@ -from pytorch_lightning.utilities import rank_zero_deprecation +from warnings import warn -rank_zero_deprecation("`argparse_utils` package has been renamed to `argparse` since v1.2 and will be removed in v1.4") +warn( + "`argparse_utils` package has been renamed to `argparse` since v1.2 and will be removed in v1.4", DeprecationWarning +) from pytorch_lightning.utilities.argparse import * # noqa: F403 E402 F401 diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py index bf7a199fc08dc..e797c32bbf917 100644 --- a/pytorch_lightning/utilities/distributed.py +++ b/pytorch_lightning/utilities/distributed.py @@ -15,7 +15,7 @@ import logging import os import warnings -from functools import partial, wraps +from functools import wraps from typing import Any, Optional, Union import torch @@ -24,7 +24,6 @@ if torch.distributed.is_available(): from torch.distributed import group, ReduceOp - else: class ReduceOp: @@ -63,7 +62,6 @@ def _debug(*args, **kwargs): rank_zero_debug = rank_zero_only(_debug) rank_zero_info = rank_zero_only(_info) rank_zero_warn = rank_zero_only(_warn) -rank_zero_deprecation = partial(rank_zero_warn, category=DeprecationWarning) def gather_all_tensors(result: Union[torch.Tensor], group: Optional[Any] = None): @@ -173,7 +171,7 @@ def backward(ctx, *grad_output): torch.distributed.all_reduce(grad_output, op=torch.distributed.ReduceOp.SUM, async_op=False, group=ctx.group) - return grad_output[torch.distributed.get_rank()], None + return grad_output[torch.distributed.get_rank()] def all_gather_ddp_if_available( diff --git a/pytorch_lightning/utilities/imports.py b/pytorch_lightning/utilities/imports.py index baeac9be57218..41a13d6c678a0 100644 --- a/pytorch_lightning/utilities/imports.py +++ b/pytorch_lightning/utilities/imports.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
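To illustrate the `_gpus_arg_default` behaviour restored above (a comma-separated device list is kept as a string, a bare count becomes an int), here is a small standalone sketch that mirrors the helper rather than importing it:

    def gpus_arg_default(x: str):
        # "0,1"-style device lists stay strings, "2"-style counts become ints
        if ',' in x:
            return str(x)
        return int(x)

    assert gpus_arg_default("2") == 2
    assert gpus_arg_default("0,1") == "0,1"
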
"""General utilities""" -import importlib import operator import platform import sys @@ -20,7 +19,7 @@ from importlib.util import find_spec import torch -from pkg_resources import DistributionNotFound +from pkg_resources import DistributionNotFound, get_distribution def _module_available(module_path: str) -> bool: @@ -43,24 +42,11 @@ def _module_available(module_path: str) -> bool: def _compare_version(package: str, op, version) -> bool: - """ - Compare package version with some requirements - - >>> _compare_version("torch", operator.ge, "0.1") - True - """ try: - pkg = importlib.import_module(package) - except (ModuleNotFoundError, DistributionNotFound): - return False - try: - pkg_version = LooseVersion(pkg.__version__) - except AttributeError: + pkg_version = LooseVersion(get_distribution(package).version) + return op(pkg_version, LooseVersion(version)) + except DistributionNotFound: return False - if not (hasattr(pkg_version, "vstring") and hasattr(pkg_version, "version")): - # this is mock by sphinx, so it shall return True ro generate all summaries - return True - return op(pkg_version, LooseVersion(version)) _IS_WINDOWS = platform.system() == "Windows" @@ -68,9 +54,7 @@ def _compare_version(package: str, op, version) -> bool: _TORCH_LOWER_EQUAL_1_4 = _compare_version("torch", operator.le, "1.5.0") _TORCH_GREATER_EQUAL_1_6 = _compare_version("torch", operator.ge, "1.6.0") _TORCH_GREATER_EQUAL_1_7 = _compare_version("torch", operator.ge, "1.7.0") -_TORCH_GREATER_EQUAL_1_8 = _compare_version("torch", operator.ge, "1.8.0") -_KINETO_AVAILABLE = torch.profiler.kineto_available() if _TORCH_GREATER_EQUAL_1_8 else False _APEX_AVAILABLE = _module_available("apex.amp") _BOLTS_AVAILABLE = _module_available('pl_bolts') _DEEPSPEED_AVAILABLE = not _IS_WINDOWS and _module_available('deepspeed') diff --git a/pytorch_lightning/utilities/model_utils.py b/pytorch_lightning/utilities/model_utils.py index 728f73f4f0d32..7fd5b287f7ba3 100644 --- a/pytorch_lightning/utilities/model_utils.py +++ b/pytorch_lightning/utilities/model_utils.py @@ -1,7 +1,8 @@ -from pytorch_lightning.utilities import rank_zero_deprecation +from warnings import warn -rank_zero_deprecation( - "`model_utils` package has been renamed to `model_helpers` since v1.2 and will be removed in v1.4" +warn( + "`model_utils` package has been renamed to `model_helpers` since v1.2 and will be removed in v1.4", + DeprecationWarning ) from pytorch_lightning.utilities.model_helpers import * # noqa: F403 E402 F401 diff --git a/pytorch_lightning/utilities/signature_utils.py b/pytorch_lightning/utilities/signature_utils.py deleted file mode 100644 index 546d8e845ecb1..0000000000000 --- a/pytorch_lightning/utilities/signature_utils.py +++ /dev/null @@ -1,22 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-import inspect -from typing import Callable - - -def is_param_in_hook_signature(hook_fx: Callable, param: str) -> bool: - hook_params = list(inspect.signature(hook_fx).parameters) - if "args" in hook_params or param in hook_params: - return True - return False diff --git a/pytorch_lightning/utilities/warning_utils.py b/pytorch_lightning/utilities/warning_utils.py index 0668bababa609..c520086f62a81 100644 --- a/pytorch_lightning/utilities/warning_utils.py +++ b/pytorch_lightning/utilities/warning_utils.py @@ -1,5 +1,7 @@ -from pytorch_lightning.utilities import rank_zero_deprecation +from warnings import warn -rank_zero_deprecation("`warning_utils` package has been renamed to `warnings` since v1.2 and will be removed in v1.4") +warn( + "`warning_utils` package has been renamed to `warnings` since v1.2 and will be removed in v1.4", DeprecationWarning +) from pytorch_lightning.utilities.warnings import * # noqa: F403 E402 F401 diff --git a/pytorch_lightning/utilities/xla_device_utils.py b/pytorch_lightning/utilities/xla_device_utils.py index f028222e3930b..aa0af1697ac51 100644 --- a/pytorch_lightning/utilities/xla_device_utils.py +++ b/pytorch_lightning/utilities/xla_device_utils.py @@ -11,10 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from pytorch_lightning.utilities import rank_zero_deprecation +from warnings import warn -rank_zero_deprecation( - "`xla_device_utils` package has been renamed to `xla_device` since v1.2 and will be removed in v1.4" +warn( + "`xla_device_utils` package has been renamed to `xla_device` since v1.2 and will be removed in v1.4", + DeprecationWarning ) from pytorch_lightning.utilities.xla_device import * # noqa: F403 E402 F401 diff --git a/requirements.txt b/requirements.txt index 4649983b79d78..bdfd6601ba4c2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,5 +8,3 @@ PyYAML>=5.1, !=5.4.* # OmegaConf requirement >=5.1 tqdm>=4.41.0 fsspec[http]>=0.8.1 tensorboard>=2.2.0 -torchmetrics>=0.2.0 -pyDeprecate==0.1.1 \ No newline at end of file diff --git a/requirements/adjust_versions.py b/requirements/adjust_versions.py index d0dfbc59e2352..c1499cd4ea5ee 100644 --- a/requirements/adjust_versions.py +++ b/requirements/adjust_versions.py @@ -11,7 +11,6 @@ "1.7.0": dict(torchvision="0.8.1", torchtext="0.8"), "1.7.1": dict(torchvision="0.8.2", torchtext="0.8.1"), "1.8.0": dict(torchvision="0.9.0", torchtext="0.9"), - "1.8.1": dict(torchvision="0.9.0", torchtext="0.9"), } diff --git a/requirements/extra.txt b/requirements/extra.txt index 715916c4e36ac..85437327bce06 100644 --- a/requirements/extra.txt +++ b/requirements/extra.txt @@ -4,8 +4,7 @@ matplotlib>3.1 horovod>=0.21.2 # no need to install with [pytorch] as pytorch is already installed omegaconf>=2.0.1 torchtext>=0.5 -# onnx>=1.7.0 +onnx>=1.7.0 onnxruntime>=1.3.0 hydra-core>=1.0 -# todo: when switch to standard package stream, drop `fairscale` from hard mocked docs libs https://github.com/PyTorchLightning/fairscale/archive/pl_1.2.0.zip diff --git a/requirements/test.txt b/requirements/test.txt index 259cc2e2d6442..60c861cea9c50 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -1,11 +1,12 @@ -coverage>5.2.0 +coverage>=5.2 codecov>=2.1 pytest>=6.0 -#pytest-cov>2.10 -#pytest-xdist +pytest-cov>2.10 +pytest-xdist flake8>=3.6 check-manifest twine==3.2 +# scipy>=0.13.3 scikit-learn>=0.22.2 scikit-image>=0.17.2 isort>=5.6.4 diff --git a/setup.cfg b/setup.cfg 
index 6365482e32aa8..0e64df0530d82 100644 --- a/setup.cfg +++ b/setup.cfg @@ -39,14 +39,23 @@ exclude_lines = pass rank_zero_warn raise NotImplementedError + # TODO: figure out how to get codecov to pick up the test results on these backends # The actual coverage for each is 90%+ # *metrics (94%+) are temporarily removed from testing while tests speed up omit = - pytorch_lightning/cluster_environments/*.py + pytorch_lightning/accelerators/ddp_*.py + pytorch_lightning/accelerators/ddp2_*.py + pytorch_lightning/accelerators/dp_*.py + pytorch_lightning/accelerators/tpu_*.py pytorch_lightning/utilities/xla_device_utils.py pytorch_lightning/utilities/distributed.py pytorch_lightning/tuner/auto_gpu_select.py + # TODO: temporary, until accelerator refactor is finished + pytorch_lightning/accelerators/accelerator.py + pytorch_lightning/plugins/training_type/*.py + pytorch_lightning/plugins/precision/*.py + pytorch_lightning/plugins/base_plugin.py [flake8] @@ -64,8 +73,10 @@ verbose = 2 # https://pep8.readthedocs.io/en/latest/intro.html#error-codes format = pylint ignore = - E731 # Ignore "Do not assign a lambda expression, use a def" - W503 # Ignore "Line break occurred before a binary operator" + E731 # do not assign a lambda expression, use a def + W503 # line break before binary operator + # because of YAPF - till https://github.com/google/yapf/issues/897 is resolved + E231 # missing whitespace after ',', ';', or ':'; for black # setup.cfg or tox.ini diff --git a/setup.py b/setup.py index e53e24ebf0702..5d619d51977b2 100755 --- a/setup.py +++ b/setup.py @@ -16,22 +16,20 @@ import os # Always prefer setuptools over distutils -import sys - from setuptools import find_packages, setup try: - from pytorch_lightning import info, setup_tools + import builtins except ImportError: - # alternative https://stackoverflow.com/a/67692/4521646 - sys.path.append("pytorch_lightning") - import info - import setup_tools + import __builtin__ as builtins # https://packaging.python.org/guides/single-sourcing-package-version/ # http://blog.ionelmc.ro/2014/05/25/python-packaging/ -_PATH_ROOT = os.path.dirname(__file__) -_PATH_REQUIRE = os.path.join(_PATH_ROOT, 'requirements') +PATH_ROOT = os.path.dirname(__file__) +builtins.__LIGHTNING_SETUP__ = True + +import pytorch_lightning # noqa: E402 +from pytorch_lightning.setup_tools import _load_readme_description, _load_requirements # noqa: E402 # https://setuptools.readthedocs.io/en/latest/setuptools.html#declaring-extras # Define package extras. These are only installed if you specify them. 
@@ -39,10 +37,10 @@ # From local copy of repo, use like `pip install ".[dev, docs]"` extras = { # 'docs': load_requirements(file_name='docs.txt'), - 'examples': setup_tools._load_requirements(path_dir=_PATH_REQUIRE, file_name='examples.txt'), - 'loggers': setup_tools._load_requirements(path_dir=_PATH_REQUIRE, file_name='loggers.txt'), - 'extra': setup_tools._load_requirements(path_dir=_PATH_REQUIRE, file_name='extra.txt'), - 'test': setup_tools._load_requirements(path_dir=_PATH_REQUIRE, file_name='test.txt') + 'examples': _load_requirements(path_dir=os.path.join(PATH_ROOT, 'requirements'), file_name='examples.txt'), + 'loggers': _load_requirements(path_dir=os.path.join(PATH_ROOT, 'requirements'), file_name='loggers.txt'), + 'extra': _load_requirements(path_dir=os.path.join(PATH_ROOT, 'requirements'), file_name='extra.txt'), + 'test': _load_requirements(path_dir=os.path.join(PATH_ROOT, 'requirements'), file_name='test.txt') } extras['dev'] = extras['extra'] + extras['loggers'] + extras['test'] extras['all'] = extras['dev'] + extras['examples'] # + extras['docs'] @@ -55,12 +53,6 @@ # filter cpu only packages extras[ex] = [pkg for pkg in extras[kw] if not any(pgpu.lower() in pkg.lower() for pgpu in PACKAGES_GPU_ONLY)] -long_description = setup_tools._load_readme_description( - _PATH_ROOT, - homepage=info.__homepage__, - version=info.__version__, -) - # https://packaging.python.org/discussions/install-requires-vs-requirements / # keep the meta-data here for simplicity in reading this file... it's not obvious # what happens and to non-engineers they won't know to look in init ... @@ -68,22 +60,22 @@ # engineer specific practices setup( name="pytorch-lightning", - version=info.__version__, - description=info.__docs__, - author=info.__author__, - author_email=info.__author_email__, - url=info.__homepage__, + version=pytorch_lightning.__version__, + description=pytorch_lightning.__docs__, + author=pytorch_lightning.__author__, + author_email=pytorch_lightning.__author_email__, + url=pytorch_lightning.__homepage__, download_url='https://github.com/PyTorchLightning/pytorch-lightning', - license=info.__license__, + license=pytorch_lightning.__license__, packages=find_packages(exclude=['tests', 'tests/*', 'benchmarks', 'legacy', 'legacy/*']), - long_description=long_description, + long_description=_load_readme_description(PATH_ROOT), long_description_content_type='text/markdown', include_package_data=True, zip_safe=False, keywords=['deep learning', 'pytorch', 'AI'], python_requires='>=3.6', setup_requires=[], - install_requires=setup_tools._load_requirements(_PATH_ROOT), + install_requires=_load_requirements(PATH_ROOT), extras_require=extras, project_urls={ "Bug Tracker": "https://github.com/PyTorchLightning/pytorch-lightning/issues", diff --git a/tests/__init__.py b/tests/__init__.py index fc634e6b73fec..433f183896dee 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -19,8 +19,8 @@ _TEST_ROOT = os.path.dirname(__file__) _PROJECT_ROOT = os.path.dirname(_TEST_ROOT) _TEMP_PATH = os.path.join(_PROJECT_ROOT, 'test_temp') -PATH_DATASETS = os.path.join(_PROJECT_ROOT, 'Datasets') -PATH_LEGACY = os.path.join(_PROJECT_ROOT, 'legacy') +DATASETS_PATH = os.path.join(_PROJECT_ROOT, 'Datasets') +LEGACY_PATH = os.path.join(_PROJECT_ROOT, 'legacy') # todo: this setting `PYTHONPATH` may not be used by other evns like Conda for import packages if _PROJECT_ROOT not in os.getenv('PYTHONPATH', ""): diff --git a/tests/accelerators/test_accelerator_connector.py b/tests/accelerators/test_accelerator_connector.py 
index 79a17df074e35..e6139de5d3028 100644 --- a/tests/accelerators/test_accelerator_connector.py +++ b/tests/accelerators/test_accelerator_connector.py @@ -98,8 +98,7 @@ def test_accelerator_choice_ddp_spawn(cuda_available_mock, device_count_mock): "SLURM_LOCALID": "10" } ) -@mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', autospec=True) -def test_accelerator_choice_ddp_slurm(setup_distributed_mock): +def test_accelerator_choice_ddp_slurm(): class CB(Callback): @@ -137,8 +136,7 @@ def on_fit_start(self, trainer, pl_module): } ) @mock.patch('torch.cuda.device_count', return_value=2) -@mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', autospec=True) -def test_accelerator_choice_ddp2_slurm(device_count_mock, setup_distributed_mock): +def test_accelerator_choice_ddp2_slurm(device_count_mock): class CB(Callback): @@ -167,8 +165,7 @@ def on_fit_start(self, trainer, pl_module): @RunIf(min_gpus=1) @mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1", "WORLD_SIZE": "2", "LOCAL_RANK": "10", "NODE_RANK": "0"}) @mock.patch('torch.cuda.device_count', return_value=2) -@mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', autospec=True) -def test_accelerator_choice_ddp_te(device_count_mock, setup_distributed_mock): +def test_accelerator_choice_ddp_te(device_count_mock): class CB(Callback): @@ -196,8 +193,7 @@ def on_fit_start(self, trainer, pl_module): @RunIf(min_gpus=1) @mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1", "WORLD_SIZE": "2", "LOCAL_RANK": "10", "NODE_RANK": "0"}) @mock.patch('torch.cuda.device_count', return_value=2) -@mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', autospec=True) -def test_accelerator_choice_ddp2_te(device_count_mock, setup_distributed_mock): +def test_accelerator_choice_ddp2_te(device_count_mock): class CB(Callback): @@ -228,8 +224,7 @@ def on_fit_start(self, trainer, pl_module): "NODE_RANK": "0", }) @mock.patch('torch.cuda.device_count', return_value=0) -@mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', autospec=True) -def test_accelerator_choice_ddp_cpu_te(device_count_mock, setup_distributed_mock): +def test_accelerator_choice_ddp_cpu_te(device_count_mock): class CB(Callback): @@ -264,8 +259,7 @@ def on_fit_start(self, trainer, pl_module): } ) @mock.patch('torch.cuda.device_count', return_value=0) -@mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', autospec=True) -def test_accelerator_choice_ddp_cpu_slurm(device_count_mock, setup_distributed_mock): +def test_accelerator_choice_ddp_cpu_slurm(device_count_mock): class CB(Callback): @@ -300,8 +294,7 @@ def on_fit_start(self, trainer, pl_module): } ) @mock.patch('torch.cuda.device_count', return_value=0) -@mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', autospec=True) -def test_accelerator_choice_ddp_cpu_custom_cluster(device_count_mock, setup_distributed_mock): +def test_accelerator_choice_ddp_cpu_custom_cluster(device_count_mock): """ Test that we choose the custom cluster even when SLURM or TE flags are around """ @@ -311,9 +304,6 @@ class CustomCluster(LightningEnvironment): def master_address(self): return 'asdf' - def creates_children(self) -> bool: - return True - class CB(Callback): def on_fit_start(self, trainer, pl_module): @@ -346,8 +336,7 @@ def on_fit_start(self, trainer, pl_module): } ) @mock.patch('torch.cuda.device_count', return_value=0) -@mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', autospec=True) -def 
test_custom_accelerator(device_count_mock, setup_distributed_mock): +def test_custom_accelerator(device_count_mock): class Accel(Accelerator): pass @@ -382,8 +371,7 @@ class TrainTypePlugin(SingleDevicePlugin): } ) @mock.patch('torch.cuda.device_count', return_value=0) -@mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', autospec=True) -def test_dist_backend_accelerator_mapping(device_count_mock, setup_distributed_mock): +def test_dist_backend_accelerator_mapping(device_count_mock): class CB(Callback): diff --git a/tests/accelerators/test_common.py b/tests/accelerators/test_common.py index bd8636ba839f9..6962af7249d1b 100644 --- a/tests/accelerators/test_common.py +++ b/tests/accelerators/test_common.py @@ -8,13 +8,11 @@ from tests.helpers.runif import RunIf -@pytest.mark.parametrize( - "trainer_kwargs", ( - pytest.param(dict(gpus=1), marks=RunIf(min_gpus=1)), - pytest.param(dict(accelerator="dp", gpus=2), marks=RunIf(min_gpus=2)), - pytest.param(dict(accelerator="ddp_spawn", gpus=2), marks=RunIf(min_gpus=2)), - ) -) +@pytest.mark.parametrize("trainer_kwargs", ( + pytest.param({"gpus": 1}, marks=RunIf(min_gpus=1)), + pytest.param({"accelerator": "dp", "gpus": 2}, marks=RunIf(min_gpus=2)), + pytest.param({"accelerator": "ddp_spawn", "gpus": 2}, marks=RunIf(min_gpus=2)), +)) def test_evaluate(tmpdir, trainer_kwargs): tutils.set_random_master_port() diff --git a/tests/accelerators/test_cpu.py b/tests/accelerators/test_cpu.py index 46379a9d10c14..81a5132e47356 100644 --- a/tests/accelerators/test_cpu.py +++ b/tests/accelerators/test_cpu.py @@ -3,12 +3,10 @@ import pytest import torch -from pytorch_lightning import Trainer from pytorch_lightning.accelerators import CPUAccelerator from pytorch_lightning.plugins import SingleDevicePlugin from pytorch_lightning.plugins.precision import MixedPrecisionPlugin from pytorch_lightning.utilities.exceptions import MisconfigurationException -from tests.helpers.boring_model import BoringModel def test_unsupported_precision_plugins(): @@ -20,33 +18,3 @@ def test_unsupported_precision_plugins(): ) with pytest.raises(MisconfigurationException, match=r"amp \+ cpu is not supported."): accelerator.setup(trainer=trainer, model=model) - - -@pytest.mark.parametrize("delay_dispatch", [True, False]) -def test_plugin_setup_optimizers_in_pre_dispatch(tmpdir, delay_dispatch): - """ - Test when using a custom training type plugin that delays setup optimizers, - we do not call setup optimizers till ``pre_dispatch``. - """ - - class TestModel(BoringModel): - - def on_fit_start(self): - if delay_dispatch: - # Ensure we haven't setup optimizers if we've delayed dispatch - assert len(self.trainer.optimizers) == 0 - else: - assert len(self.trainer.optimizers) > 0 - - def on_fit_end(self): - assert len(self.trainer.optimizers) > 0 - - class CustomPlugin(SingleDevicePlugin): - - @property - def setup_optimizers_in_pre_dispatch(self) -> bool: - return delay_dispatch - - model = TestModel() - trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, plugins=CustomPlugin(device=torch.device("cpu"))) - trainer.fit(model) diff --git a/tests/accelerators/test_ddp.py b/tests/accelerators/test_ddp.py index 541110ac8846b..14e73d920af4b 100644 --- a/tests/accelerators/test_ddp.py +++ b/tests/accelerators/test_ddp.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import os -from typing import Optional -from unittest import mock from unittest.mock import patch import pytest @@ -93,6 +91,7 @@ def test_torch_distributed_backend_env_variables(tmpdir): _environ = {"PL_TORCH_DISTRIBUTED_BACKEND": "undefined", "CUDA_VISIBLE_DEVICES": "0,1", "WORLD_SIZE": "2"} with patch.dict(os.environ, _environ), \ patch('torch.cuda.device_count', return_value=2): + with pytest.raises(ValueError, match="Invalid backend: 'undefined'"): model = BoringModel() trainer = Trainer( @@ -103,30 +102,3 @@ def test_torch_distributed_backend_env_variables(tmpdir): logger=False, ) trainer.fit(model) - - -@RunIf(skip_windows=True) -@mock.patch('torch.cuda.device_count', return_value=1) -@mock.patch('torch.cuda.is_available', return_value=True) -@mock.patch('torch.cuda.set_device') -@mock.patch.dict(os.environ, {'PL_TORCH_DISTRIBUTED_BACKEND': 'gloo'}, clear=True) -def test_ddp_torch_dist_is_available_in_setup(mock_set_device, mock_is_available, mock_device_count, tmpdir): - """ - Test to ensure torch distributed is available within the setup hook using ddp - """ - - class TestModel(BoringModel): - - def setup(self, stage: Optional[str] = None) -> None: - assert torch.distributed.is_initialized() - raise SystemExit() - - model = TestModel() - trainer = Trainer( - default_root_dir=tmpdir, - fast_dev_run=True, - accelerator="ddp", - gpus=1, - ) - with pytest.raises(SystemExit): - trainer.fit(model) diff --git a/tests/base/model_template.py b/tests/base/model_template.py index 86578fef4c699..1ec2df7865caa 100644 --- a/tests/base/model_template.py +++ b/tests/base/model_template.py @@ -18,7 +18,6 @@ import torch.nn.functional as F from pytorch_lightning.core.lightning import LightningModule -from tests import PATH_DATASETS from tests.base.model_optimizers import ConfigureOptimizersPool from tests.base.model_test_dataloaders import TestDataloaderVariations from tests.base.model_test_epoch_ends import TestEpochEndVariations @@ -29,7 +28,7 @@ from tests.base.model_valid_dataloaders import ValDataloaderVariations from tests.base.model_valid_epoch_ends import ValidationEpochEndVariations from tests.base.model_valid_steps import ValidationStepVariations -from tests.helpers.datasets import TrialMNIST +from tests.helpers.datasets import PATH_DATASETS, TrialMNIST class EvalModelTemplate( diff --git a/tests/callbacks/test_callback_hook_outputs.py b/tests/callbacks/test_callback_hook_outputs.py index df0eab31aac37..78926cc9a7dd4 100644 --- a/tests/callbacks/test_callback_hook_outputs.py +++ b/tests/callbacks/test_callback_hook_outputs.py @@ -71,66 +71,3 @@ def on_train_epoch_end(self, outputs) -> None: results = trainer.fit(model) assert results - - -def test_on_val_epoch_end_outputs(tmpdir): - - class CB(Callback): - - def on_validation_epoch_end(self, trainer, pl_module, outputs): - if trainer.running_sanity_check: - assert len(outputs[0]) == trainer.num_sanity_val_batches[0] - else: - assert len(outputs[0]) == trainer.num_val_batches[0] - - model = BoringModel() - - trainer = Trainer( - callbacks=CB(), - default_root_dir=tmpdir, - limit_train_batches=2, - limit_val_batches=2, - max_epochs=1, - weights_summary=None, - ) - - trainer.fit(model) - - -def test_on_test_epoch_end_outputs(tmpdir): - - class CB(Callback): - - def on_test_epoch_end(self, trainer, pl_module, outputs): - assert len(outputs[0]) == trainer.num_test_batches[0] - - model = BoringModel() - - trainer = Trainer( - callbacks=CB(), - default_root_dir=tmpdir, - weights_summary=None, - ) - - trainer.test(model) - - -def 
test_free_memory_on_eval_outputs(tmpdir): - - class CB(Callback): - - def on_epoch_end(self, trainer, pl_module): - assert len(trainer.evaluation_loop.outputs) == 0 - - model = BoringModel() - - trainer = Trainer( - callbacks=CB(), - default_root_dir=tmpdir, - limit_train_batches=2, - limit_val_batches=2, - max_epochs=1, - weights_summary=None, - ) - - trainer.fit(model) diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py index 713971629bdf4..626eb59dffb9c 100644 --- a/tests/callbacks/test_callbacks.py +++ b/tests/callbacks/test_callbacks.py @@ -46,18 +46,17 @@ def test_trainer_callback_hook_system_fit(_, tmpdir): assert callback_mock.method_calls == [ call.on_init_start(trainer), call.on_init_end(trainer), - call.on_before_accelerator_backend_setup(trainer, model), call.setup(trainer, model, 'fit'), + call.on_before_accelerator_backend_setup(trainer, model), call.on_fit_start(trainer, model), call.on_pretrain_routine_start(trainer, model), call.on_pretrain_routine_end(trainer, model), call.on_sanity_check_start(trainer, model), call.on_validation_start(trainer, model), - call.on_epoch_start(trainer, model), call.on_validation_epoch_start(trainer, model), call.on_validation_batch_start(trainer, model, ANY, 0, 0), call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0), - call.on_validation_epoch_end(trainer, model, ANY), + call.on_validation_epoch_end(trainer, model), call.on_epoch_end(trainer, model), call.on_validation_end(trainer, model), call.on_sanity_check_end(trainer, model), @@ -85,11 +84,10 @@ def test_trainer_callback_hook_system_fit(_, tmpdir): call.on_train_epoch_end(trainer, model, ANY), call.on_epoch_end(trainer, model), call.on_validation_start(trainer, model), - call.on_epoch_start(trainer, model), call.on_validation_epoch_start(trainer, model), call.on_validation_batch_start(trainer, model, ANY, 0, 0), call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0), - call.on_validation_epoch_end(trainer, model, ANY), + call.on_validation_epoch_end(trainer, model), call.on_epoch_end(trainer, model), call.on_validation_end(trainer, model), call.on_save_checkpoint(trainer, model), # should take ANY but we are inspecting signature for BC @@ -117,16 +115,15 @@ def test_trainer_callback_hook_system_test(tmpdir): assert callback_mock.method_calls == [ call.on_init_start(trainer), call.on_init_end(trainer), - call.on_before_accelerator_backend_setup(trainer, model), call.setup(trainer, model, 'test'), + call.on_before_accelerator_backend_setup(trainer, model), call.on_test_start(trainer, model), - call.on_epoch_start(trainer, model), call.on_test_epoch_start(trainer, model), call.on_test_batch_start(trainer, model, ANY, 0, 0), call.on_test_batch_end(trainer, model, ANY, ANY, 0, 0), call.on_test_batch_start(trainer, model, ANY, 1, 0), call.on_test_batch_end(trainer, model, ANY, ANY, 1, 0), - call.on_test_epoch_end(trainer, model, ANY), + call.on_test_epoch_end(trainer, model), call.on_epoch_end(trainer, model), call.on_test_end(trainer, model), call.teardown(trainer, model, 'test'), @@ -151,16 +148,15 @@ def test_trainer_callback_hook_system_validate(tmpdir): assert callback_mock.method_calls == [ call.on_init_start(trainer), call.on_init_end(trainer), - call.on_before_accelerator_backend_setup(trainer, model), call.setup(trainer, model, 'validate'), + call.on_before_accelerator_backend_setup(trainer, model), call.on_validation_start(trainer, model), - call.on_epoch_start(trainer, model), call.on_validation_epoch_start(trainer, model), 
call.on_validation_batch_start(trainer, model, ANY, 0, 0), call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0), call.on_validation_batch_start(trainer, model, ANY, 1, 0), call.on_validation_batch_end(trainer, model, ANY, ANY, 1, 0), - call.on_validation_epoch_end(trainer, model, ANY), + call.on_validation_epoch_end(trainer, model), call.on_epoch_end(trainer, model), call.on_validation_end(trainer, model), call.teardown(trainer, model, 'validate'), diff --git a/tests/checkpointing/test_checkpoint_callback_frequency.py b/tests/checkpointing/test_checkpoint_callback_frequency.py index 7926bc46dd290..397e471e8a4b8 100644 --- a/tests/checkpointing/test_checkpoint_callback_frequency.py +++ b/tests/checkpointing/test_checkpoint_callback_frequency.py @@ -19,7 +19,6 @@ from pytorch_lightning import callbacks, seed_everything, Trainer from tests.helpers import BoringModel -from tests.helpers.runif import RunIf @mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) @@ -103,42 +102,3 @@ def training_step(self, batch, batch_idx): # make sure types are correct assert save_mock.call_count == expected - - -@mock.patch('torch.save') -@RunIf(special=True, min_gpus=2) -@pytest.mark.parametrize(['k', 'epochs', 'val_check_interval', 'expected'], [(1, 1, 1.0, 1), (2, 2, 0.3, 5)]) -def test_top_k_ddp(save_mock, tmpdir, k, epochs, val_check_interval, expected): - - class TestModel(BoringModel): - - def training_step(self, batch, batch_idx): - local_rank = int(os.getenv("LOCAL_RANK")) - self.log('my_loss', batch_idx * (1 + local_rank), on_epoch=True) - return super().training_step(batch, batch_idx) - - def training_epoch_end(self, outputs) -> None: - data = str(self.global_rank) - obj = [[data], (data, ), set(data)] - out = self.trainer.training_type_plugin.broadcast(obj) - assert obj == [[str(self.global_rank)], (str(self.global_rank), ), set(str(self.global_rank))] - assert out == [['0'], ('0', ), set('0')] - - model = TestModel() - trainer = Trainer( - callbacks=[callbacks.ModelCheckpoint(dirpath=tmpdir, monitor='my_loss_step', save_top_k=k, mode="max")], - default_root_dir=tmpdir, - max_epochs=epochs, - weights_summary=None, - val_check_interval=val_check_interval, - accelerator="ddp", - gpus=2, - limit_train_batches=64, - limit_val_batches=32, - ) - if os.getenv("LOCAL_RANK") == "0": - with pytest.raises(UserWarning, match="The value associated to the key my_loss_epoch: [15.5, 31.0]"): - trainer.fit(model) - assert save_mock.call_count == expected - else: - trainer.fit(model) diff --git a/tests/checkpointing/test_legacy_checkpoints.py b/tests/checkpointing/test_legacy_checkpoints.py index 325cc4925f4f4..a22e8b77e67a3 100644 --- a/tests/checkpointing/test_legacy_checkpoints.py +++ b/tests/checkpointing/test_legacy_checkpoints.py @@ -18,9 +18,9 @@ import pytest from pytorch_lightning import Trainer -from tests import PATH_LEGACY +from tests import LEGACY_PATH -LEGACY_CHECKPOINTS_PATH = os.path.join(PATH_LEGACY, 'checkpoints') +LEGACY_CHECKPOINTS_PATH = os.path.join(LEGACY_PATH, 'checkpoints') CHECKPOINT_EXTENSION = ".ckpt" @@ -56,7 +56,6 @@ "1.2.1", "1.2.2", "1.2.3", - "1.2.4", ] ) def test_resume_legacy_checkpoints(tmpdir, pl_version: str): diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 75f25b90fa45f..e5583b9bbdf86 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -356,7 +356,7 @@ def on_train_start(self, trainer, pl_module): torch.save = Mock(wraps=torch.save) def 
on_save_checkpoint(self, trainer, pl_module, checkpoint): - # only rank 0 will call ``torch.save`` + # expect all ranks to run but only rank 0 will actually write the checkpoint file super().on_save_checkpoint(trainer, pl_module, checkpoint) self.on_save_checkpoint_count += 1 @@ -366,7 +366,8 @@ def on_train_end(self, trainer, pl_module): assert self.best_model_score assert self.on_save_checkpoint_count == self.expected_count if trainer.is_global_zero: - assert torch.save.call_count == self.expected_count + # twice the calls expected because ddp broadcast also uses torch.save + assert torch.save.call_count == self.expected_count * 2 else: assert torch.save.call_count == 0 diff --git a/tests/checkpointing/test_torch_saving.py b/tests/checkpointing/test_torch_saving.py index 8eabc4640046f..c8b1e96aeaf0a 100644 --- a/tests/checkpointing/test_torch_saving.py +++ b/tests/checkpointing/test_torch_saving.py @@ -47,7 +47,6 @@ def test_model_torch_save_ddp_cpu(tmpdir): max_epochs=num_epochs, accelerator="ddp_cpu", num_processes=2, - logger=False, ) temp_path = os.path.join(tmpdir, 'temp.pt') trainer.fit(model) diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index c8808ec37326c..2118fec6c207b 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -128,10 +128,6 @@ def test_data_hooks_called(tmpdir): assert not dm.has_setup_test assert not dm.has_setup_validate assert not dm.has_setup_predict - assert not dm.has_teardown_fit - assert not dm.has_teardown_test - assert not dm.has_teardown_validate - assert not dm.has_teardown_predict dm.prepare_data() assert dm.has_prepared_data @@ -139,10 +135,6 @@ def test_data_hooks_called(tmpdir): assert not dm.has_setup_test assert not dm.has_setup_validate assert not dm.has_setup_predict - assert not dm.has_teardown_fit - assert not dm.has_teardown_test - assert not dm.has_teardown_validate - assert not dm.has_teardown_predict dm.setup() assert dm.has_prepared_data @@ -150,84 +142,49 @@ def test_data_hooks_called(tmpdir): assert dm.has_setup_test assert dm.has_setup_validate assert not dm.has_setup_predict - assert not dm.has_teardown_fit - assert not dm.has_teardown_test - assert not dm.has_teardown_validate - assert not dm.has_teardown_predict - - dm.teardown() - assert dm.has_prepared_data - assert dm.has_setup_fit - assert dm.has_setup_test - assert dm.has_setup_validate - assert not dm.has_setup_predict - assert dm.has_teardown_fit - assert dm.has_teardown_test - assert dm.has_teardown_validate - assert not dm.has_teardown_predict @pytest.mark.parametrize("use_kwarg", (False, True)) def test_data_hooks_called_verbose(tmpdir, use_kwarg): dm = BoringDataModule() + assert not dm.has_prepared_data + assert not dm.has_setup_fit + assert not dm.has_setup_test + dm.prepare_data() + assert dm.has_prepared_data assert not dm.has_setup_fit assert not dm.has_setup_test - assert not dm.has_setup_validate assert not dm.has_setup_predict - assert not dm.has_teardown_fit - assert not dm.has_teardown_test - assert not dm.has_teardown_validate - assert not dm.has_teardown_predict dm.setup(stage='fit') if use_kwarg else dm.setup('fit') + assert dm.has_prepared_data assert dm.has_setup_fit assert not dm.has_setup_validate assert not dm.has_setup_test assert not dm.has_setup_predict dm.setup(stage='validate') if use_kwarg else dm.setup('validate') + assert dm.has_prepared_data assert dm.has_setup_fit assert dm.has_setup_validate assert not dm.has_setup_test assert not dm.has_setup_predict dm.setup(stage='test') if 
use_kwarg else dm.setup('test') + assert dm.has_prepared_data assert dm.has_setup_fit assert dm.has_setup_validate assert dm.has_setup_test assert not dm.has_setup_predict dm.setup(stage='predict') if use_kwarg else dm.setup('predict') + assert dm.has_prepared_data assert dm.has_setup_fit assert dm.has_setup_validate assert dm.has_setup_test assert dm.has_setup_predict - dm.teardown(stage='fit') if use_kwarg else dm.teardown('fit') - assert dm.has_teardown_fit - assert not dm.has_teardown_validate - assert not dm.has_teardown_test - assert not dm.has_teardown_predict - - dm.teardown(stage='validate') if use_kwarg else dm.teardown('validate') - assert dm.has_teardown_fit - assert dm.has_teardown_validate - assert not dm.has_teardown_test - assert not dm.has_teardown_predict - - dm.teardown(stage='test') if use_kwarg else dm.teardown('test') - assert dm.has_teardown_fit - assert dm.has_teardown_validate - assert dm.has_teardown_test - assert not dm.has_teardown_predict - - dm.teardown(stage='predict') if use_kwarg else dm.teardown('predict') - assert dm.has_teardown_fit - assert dm.has_teardown_validate - assert dm.has_teardown_test - assert dm.has_teardown_predict - def test_dm_add_argparse_args(tmpdir): parser = ArgumentParser() diff --git a/tests/core/test_hooks.py b/tests/core/test_hooks.py deleted file mode 100644 index 191da0a1400c7..0000000000000 --- a/tests/core/test_hooks.py +++ /dev/null @@ -1,56 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from pytorch_lightning import Trainer -from tests.helpers.boring_model import BoringModel - - -def test_on_val_epoch_end_outputs(tmpdir): - - class TestModel(BoringModel): - - def on_validation_epoch_end(self, outputs): - if trainer.running_sanity_check: - assert len(outputs[0]) == trainer.num_sanity_val_batches[0] - else: - assert len(outputs[0]) == trainer.num_val_batches[0] - - model = TestModel() - - trainer = Trainer( - default_root_dir=tmpdir, - limit_train_batches=2, - limit_val_batches=2, - max_epochs=1, - weights_summary=None, - ) - - trainer.fit(model) - - -def test_on_test_epoch_end_outputs(tmpdir): - - class TestModel(BoringModel): - - def on_test_epoch_end(self, outputs): - assert len(outputs[0]) == trainer.num_test_batches[0] - - model = TestModel() - - trainer = Trainer( - default_root_dir=tmpdir, - fast_dev_run=2, - weights_summary=None, - ) - - trainer.test(model) diff --git a/tests/core/test_memory.py b/tests/core/test_memory.py index 3088743f71488..903154adf823d 100644 --- a/tests/core/test_memory.py +++ b/tests/core/test_memory.py @@ -88,19 +88,6 @@ def forward(self, x): return self.reduce(self.embed(x)) -class PartialScriptModel(LightningModule): - """ A model which contains scripted layers. 
""" - - def __init__(self): - super().__init__() - self.layer1 = torch.jit.script(nn.Linear(5, 3)) - self.layer2 = nn.Linear(3, 2) - self.example_input_array = torch.rand(2, 5) - - def forward(self, x): - return self.layer2(self.layer1(x)) - - def test_invalid_weights_summmary(): """ Test that invalid value for weights_summary raises an error. """ with pytest.raises(MisconfigurationException, match='`mode` can be None, .* got temp'): @@ -227,15 +214,6 @@ def test_summary_layer_types(mode): ] -@pytest.mark.parametrize('mode', [ModelSummary.MODE_FULL, ModelSummary.MODE_TOP]) -def test_summary_with_scripted_modules(mode): - model = PartialScriptModel() - summary = model.summarize(mode=mode) - assert summary.layer_types == ["RecursiveScriptModule", "Linear"] - assert summary.in_sizes == [UNKNOWN_SIZE, [2, 3]] - assert summary.out_sizes == [UNKNOWN_SIZE, [2, 2]] - - @pytest.mark.parametrize('mode', [ModelSummary.MODE_FULL, ModelSummary.MODE_TOP]) @pytest.mark.parametrize(['example_input', 'expected_size'], [ pytest.param([], UNKNOWN_SIZE), @@ -287,7 +265,7 @@ def test_empty_model_size(mode): @RunIf(min_gpus=1, amp_native=True) -def test_model_size_precision(tmpdir): +def test_model_size_precision(monkeypatch, tmpdir): """ Test model size for half and full precision. """ model = PreCalculatedModel() diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index 0b797dff0e42f..9d31688d9bcc0 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -15,10 +15,10 @@ import torch import torch.distributed as dist import torch.multiprocessing as mp -from torchmetrics import Metric import tests.helpers.utils as tutils from pytorch_lightning.core.step_result import Result +from pytorch_lightning.metrics import Metric from tests.helpers.runif import RunIf diff --git a/tests/deprecated_api/__init__.py b/tests/deprecated_api/__init__.py index ccfae3ec8dcf2..99e21d1ed6b22 100644 --- a/tests/deprecated_api/__init__.py +++ b/tests/deprecated_api/__init__.py @@ -13,27 +13,9 @@ # limitations under the License. 
"""Test deprecated functionality which will be removed in vX.Y.Z""" import sys -from contextlib import contextmanager -from typing import Optional - -import pytest def _soft_unimport_module(str_module): # once the module is imported e.g with parsing with pytest it lives in memory if str_module in sys.modules: del sys.modules[str_module] - - -@contextmanager -def no_deprecated_call(match: Optional[str] = None): - with pytest.warns(None) as record: - yield - try: - w = record.pop(DeprecationWarning) - if match is not None and match not in str(w.message): - return - except AssertionError: - # no DeprecationWarning raised - return - raise AssertionError(f"`DeprecationWarning` was raised: {w}") diff --git a/tests/deprecated_api/test_remove_1-4.py b/tests/deprecated_api/test_remove_1-4.py index 99e1b31f6edad..39f5e0dca5075 100644 --- a/tests/deprecated_api/test_remove_1-4.py +++ b/tests/deprecated_api/test_remove_1-4.py @@ -130,6 +130,16 @@ def test_v1_4_0_deprecated_metrics(): with pytest.deprecated_call(match='will be removed in v1.4'): precision_recall(torch.randint(0, 2, (10, 3, 3)), torch.randint(0, 2, (10, 3, 3))) + # Testing deprecation of class_reduction arg in the *new* precision + from pytorch_lightning.metrics.functional import precision + with pytest.deprecated_call(match='will be removed in v1.4'): + precision(torch.randint(0, 2, (10, )), torch.randint(0, 2, (10, )), class_reduction='micro') + + # Testing deprecation of class_reduction arg in the *new* recall + from pytorch_lightning.metrics.functional import recall + with pytest.deprecated_call(match='will be removed in v1.4'): + recall(torch.randint(0, 2, (10, )), torch.randint(0, 2, (10, )), class_reduction='micro') + from pytorch_lightning.metrics.functional.classification import auc with pytest.deprecated_call(match='will be removed in v1.4'): auc(torch.rand(10, ).sort().values, torch.rand(10, )) @@ -142,6 +152,14 @@ def test_v1_4_0_deprecated_metrics(): with pytest.deprecated_call(match='will be removed in v1.4'): multiclass_auroc(torch.rand(20, 5).softmax(dim=-1), torch.randint(0, 5, (20, )), num_classes=5) + from pytorch_lightning.metrics.functional.classification import auc_decorator + with pytest.deprecated_call(match='will be removed in v1.4'): + auc_decorator() + + from pytorch_lightning.metrics.functional.classification import multiclass_auc_decorator + with pytest.deprecated_call(match='will be removed in v1.4'): + multiclass_auc_decorator() + class CustomDDPPlugin(DDPSpawnPlugin): diff --git a/tests/deprecated_api/test_remove_1-5.py b/tests/deprecated_api/test_remove_1-5.py index fc3fe3112e71e..e65ebbab254de 100644 --- a/tests/deprecated_api/test_remove_1-5.py +++ b/tests/deprecated_api/test_remove_1-5.py @@ -20,9 +20,6 @@ from pytorch_lightning import Callback, Trainer from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.loggers import WandbLogger -from pytorch_lightning.profiler import AdvancedProfiler, BaseProfiler, PyTorchProfiler, SimpleProfiler -from pytorch_lightning.trainer.callback_hook import warning_cache as callback_warning_cache -from tests.deprecated_api import no_deprecated_call from tests.helpers import BoringModel from tests.helpers.utils import no_warning_call @@ -81,11 +78,6 @@ def on_save_checkpoint(self, *args): trainer.save_checkpoint(filepath) -def test_v1_5_0_legacy_profiler_argument(): - with pytest.deprecated_call(match="renamed to `record_functions` in v1.3"): - PyTorchProfiler(profiled_functions=[]) - - def test_v1_5_0_running_sanity_check(): trainer = Trainer() 
with pytest.deprecated_call(match='has been renamed to `Trainer.sanity_checking`'): @@ -119,102 +111,3 @@ def test_v1_5_0_model_checkpoint_period(tmpdir): ModelCheckpoint(dirpath=tmpdir) with pytest.deprecated_call(match="is deprecated in v1.3 and will be removed in v1.5"): ModelCheckpoint(dirpath=tmpdir, period=1) - - -def test_v1_5_0_old_on_validation_epoch_end(tmpdir): - callback_warning_cache.clear() - - class OldSignature(Callback): - - def on_validation_epoch_end(self, trainer, pl_module): # noqa - ... - - model = BoringModel() - trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, callbacks=OldSignature()) - - with pytest.deprecated_call(match="old signature will be removed in v1.5"): - trainer.fit(model) - - class OldSignatureModel(BoringModel): - - def on_validation_epoch_end(self): # noqa - ... - - model = OldSignatureModel() - - with pytest.deprecated_call(match="old signature will be removed in v1.5"): - trainer.fit(model) - - callback_warning_cache.clear() - - class NewSignature(Callback): - - def on_validation_epoch_end(self, trainer, pl_module, outputs): - ... - - trainer.callbacks = [NewSignature()] - with no_deprecated_call(match="`Callback.on_validation_epoch_end` signature has changed in v1.3."): - trainer.fit(model) - - class NewSignatureModel(BoringModel): - - def on_validation_epoch_end(self, outputs): - ... - - model = NewSignatureModel() - with no_deprecated_call(match="`ModelHooks.on_validation_epoch_end` signature has changed in v1.3."): - trainer.fit(model) - - -def test_v1_5_0_old_on_test_epoch_end(tmpdir): - callback_warning_cache.clear() - - class OldSignature(Callback): - - def on_test_epoch_end(self, trainer, pl_module): # noqa - ... - - model = BoringModel() - trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, callbacks=OldSignature()) - - with pytest.deprecated_call(match="old signature will be removed in v1.5"): - trainer.test(model) - - class OldSignatureModel(BoringModel): - - def on_test_epoch_end(self): # noqa - ... - - model = OldSignatureModel() - - with pytest.deprecated_call(match="old signature will be removed in v1.5"): - trainer.test(model) - - callback_warning_cache.clear() - - class NewSignature(Callback): - - def on_test_epoch_end(self, trainer, pl_module, outputs): - ... - - trainer.callbacks = [NewSignature()] - with no_deprecated_call(match="`Callback.on_test_epoch_end` signature has changed in v1.3."): - trainer.test(model) - - class NewSignatureModel(BoringModel): - - def on_test_epoch_end(self, outputs): - ... 
- - model = NewSignatureModel() - with no_deprecated_call(match="`ModelHooks.on_test_epoch_end` signature has changed in v1.3."): - trainer.test(model) - - -@pytest.mark.parametrize("cls", (BaseProfiler, SimpleProfiler, AdvancedProfiler, PyTorchProfiler)) -def test_v1_5_0_profiler_output_filename(tmpdir, cls): - filepath = str(tmpdir / "test.txt") - with pytest.deprecated_call(match="`output_filename` parameter has been removed"): - profiler = cls(output_filename=filepath) - assert profiler.dirpath == tmpdir - assert profiler.filename == "test" diff --git a/tests/helpers/advanced_models.py b/tests/helpers/advanced_models.py index 2b0146e1ee099..7ad678b3046fd 100644 --- a/tests/helpers/advanced_models.py +++ b/tests/helpers/advanced_models.py @@ -20,7 +20,6 @@ from torch.utils.data import DataLoader from pytorch_lightning.core.lightning import LightningModule -from tests import PATH_DATASETS from tests.helpers.datasets import AverageDataset, MNIST, TrialMNIST @@ -166,7 +165,7 @@ def configure_optimizers(self): return [opt_g, opt_d], [] def train_dataloader(self): - return DataLoader(TrialMNIST(root=PATH_DATASETS, train=True, download=True), batch_size=16) + return DataLoader(TrialMNIST(train=True, download=True), batch_size=16) class ParityModuleRNN(LightningModule): @@ -224,7 +223,6 @@ def configure_optimizers(self): def train_dataloader(self): return DataLoader(MNIST( - root=PATH_DATASETS, train=True, download=True, ), batch_size=128, num_workers=1) diff --git a/tests/helpers/datasets.py b/tests/helpers/datasets.py index 77035796ca3b1..e7bdad0f1538c 100644 --- a/tests/helpers/datasets.py +++ b/tests/helpers/datasets.py @@ -22,6 +22,11 @@ from torch import Tensor from torch.utils.data import Dataset +from tests import _PROJECT_ROOT + +#: local path to test datasets +PATH_DATASETS = os.path.join(_PROJECT_ROOT, 'Datasets') + class MNIST(Dataset): """ @@ -42,7 +47,7 @@ class MNIST(Dataset): downloaded again. 
Examples: - >>> dataset = MNIST(".", download=True) + >>> dataset = MNIST(download=True) >>> len(dataset) 60000 >>> torch.bincount(dataset.targets) @@ -60,7 +65,7 @@ class MNIST(Dataset): def __init__( self, - root: str, + root: str = PATH_DATASETS, train: bool = True, normalize: tuple = (0.1307, 0.3081), download: bool = True, @@ -147,7 +152,7 @@ class TrialMNIST(MNIST): kwargs: Same as MNIST Examples: - >>> dataset = TrialMNIST(".", download=True) + >>> dataset = TrialMNIST(download=True) >>> len(dataset) 300 >>> sorted(set([d.item() for d in dataset.targets])) @@ -156,7 +161,7 @@ class TrialMNIST(MNIST): tensor([100, 100, 100]) """ - def __init__(self, root: str, num_samples: int = 100, digits: Optional[Sequence] = (0, 1, 2), **kwargs): + def __init__(self, num_samples: int = 100, digits: Optional[Sequence] = (0, 1, 2), **kwargs): # number of examples per class self.num_samples = num_samples # take just a subset of MNIST dataset @@ -164,7 +169,7 @@ def __init__(self, root: str, num_samples: int = 100, digits: Optional[Sequence] self.cache_folder_name = f"digits-{'-'.join(str(d) for d in self.digits)}_nb-{self.num_samples}" - super().__init__(root, normalize=(0.5, 1.0), **kwargs) + super().__init__(normalize=(0.5, 1.0), **kwargs) @staticmethod def _prepare_subset(full_data: torch.Tensor, full_targets: torch.Tensor, num_samples: int, digits: Sequence): diff --git a/tests/helpers/runif.py b/tests/helpers/runif.py index 5483e33d9cddb..fe85fbaea9025 100644 --- a/tests/helpers/runif.py +++ b/tests/helpers/runif.py @@ -56,7 +56,6 @@ def __new__( *args, min_gpus: int = 0, min_torch: Optional[str] = None, - max_torch: Optional[str] = None, min_python: Optional[str] = None, quantization: bool = False, amp_apex: bool = False, @@ -77,7 +76,6 @@ def __new__( args: native pytest.mark.skipif arguments min_gpus: min number of gpus required to run test min_torch: minimum pytorch version to run test - max_torch: maximum pytorch version to run test min_python: minimum python version required to run test quantization: if `torch.quantization` package is required to run test amp_apex: NVIDIA Apex is installed @@ -104,11 +102,6 @@ def __new__( conditions.append(torch_version < LooseVersion(min_torch)) reasons.append(f"torch>={min_torch}") - if max_torch: - torch_version = LooseVersion(get_distribution("torch").version) - conditions.append(torch_version >= LooseVersion(max_torch)) - reasons.append(f"torch<{max_torch}") - if min_python: py_version = f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}" conditions.append(py_version < LooseVersion(min_python)) diff --git a/tests/helpers/test_datasets.py b/tests/helpers/test_datasets.py index 8c866bdbab789..6319fdb562504 100644 --- a/tests/helpers/test_datasets.py +++ b/tests/helpers/test_datasets.py @@ -16,19 +16,12 @@ import cloudpickle import pytest -from tests import PATH_DATASETS from tests.helpers.datasets import AverageDataset, MNIST, TrialMNIST -@pytest.mark.parametrize( - 'dataset_cls,args', [ - (MNIST, dict(root=PATH_DATASETS)), - (TrialMNIST, dict(root=PATH_DATASETS)), - (AverageDataset, dict()), - ] -) -def test_pickling_dataset_mnist(tmpdir, dataset_cls, args): - mnist = dataset_cls(**args) +@pytest.mark.parametrize('dataset_cls', [MNIST, TrialMNIST, AverageDataset]) +def test_pickling_dataset_mnist(tmpdir, dataset_cls): + mnist = dataset_cls() mnist_pickled = pickle.dumps(mnist) pickle.loads(mnist_pickled) diff --git a/tests/loggers/test_mlflow.py b/tests/loggers/test_mlflow.py index cb461fe4ef387..af951369cc49d 100644 --- 
a/tests/loggers/test_mlflow.py +++ b/tests/loggers/test_mlflow.py @@ -202,34 +202,6 @@ def test_mlflow_logger_with_long_param_value(client, mlflow, tmpdir): logger.log_hyperparams(params) -@mock.patch('pytorch_lightning.loggers.mlflow.time') -@mock.patch('pytorch_lightning.loggers.mlflow.mlflow') -@mock.patch('pytorch_lightning.loggers.mlflow.MlflowClient') -def test_mlflow_logger_experiment_calls(client, mlflow, time, tmpdir): - """ - Test that the logger calls methods on the mlflow experiment correctly. - """ - time.return_value = 1 - - logger = MLFlowLogger('test', save_dir=tmpdir, artifact_location='my_artifact_location') - logger._mlflow_client.get_experiment_by_name.return_value = None - - params = {'test': 'test_param'} - logger.log_hyperparams(params) - - logger.experiment.log_param.assert_called_once_with(logger.run_id, 'test', 'test_param') - - metrics = {'some_metric': 10} - logger.log_metrics(metrics) - - logger.experiment.log_metric.assert_called_once_with(logger.run_id, 'some_metric', 10, 1000, None) - - logger._mlflow_client.create_experiment.assert_called_once_with( - name='test', - artifact_location='my_artifact_location', - ) - - @mock.patch('pytorch_lightning.loggers.mlflow.mlflow') @mock.patch('pytorch_lightning.loggers.mlflow.MlflowClient') @pytest.mark.parametrize("step_idx", [10, None]) diff --git a/tests/metrics/classification/__init__.py b/tests/metrics/classification/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/metrics/classification/inputs.py b/tests/metrics/classification/inputs.py new file mode 100644 index 0000000000000..7f2ac450385fe --- /dev/null +++ b/tests/metrics/classification/inputs.py @@ -0,0 +1,66 @@ +from collections import namedtuple + +import torch + +from tests.metrics.utils import BATCH_SIZE, EXTRA_DIM, NUM_BATCHES, NUM_CLASSES + +Input = namedtuple('Input', ["preds", "target"]) + +_input_binary_prob = Input( + preds=torch.rand(NUM_BATCHES, BATCH_SIZE), target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)) +) + +_input_binary = Input( + preds=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)), + target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)) +) + +_input_multilabel_prob = Input( + preds=torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES), + target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)) +) + +_input_multilabel_multidim_prob = Input( + preds=torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM), + target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM)) +) + +_input_multilabel = Input( + preds=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)), + target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)) +) + +_input_multilabel_multidim = Input( + preds=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM)), + target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM)) +) + +# Generate edge multilabel edge case, where nothing matches (scores are undefined) +__temp_preds = torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)) +__temp_target = abs(__temp_preds - 1) + +_input_multilabel_no_match = Input(preds=__temp_preds, target=__temp_target) + +__mc_prob_preds = torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES) +__mc_prob_preds = __mc_prob_preds / __mc_prob_preds.sum(dim=2, keepdim=True) + +_input_multiclass_prob = Input( + preds=__mc_prob_preds, target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)) +) + 
+_input_multiclass = Input( + preds=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), + target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)) +) + +__mdmc_prob_preds = torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM) +__mdmc_prob_preds = __mdmc_prob_preds / __mdmc_prob_preds.sum(dim=2, keepdim=True) + +_input_multidim_multiclass_prob = Input( + preds=__mdmc_prob_preds, target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)) +) + +_input_multidim_multiclass = Input( + preds=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), + target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)) +) diff --git a/tests/metrics/classification/test_accuracy.py b/tests/metrics/classification/test_accuracy.py new file mode 100644 index 0000000000000..bed60aa88388f --- /dev/null +++ b/tests/metrics/classification/test_accuracy.py @@ -0,0 +1,175 @@ +from functools import partial + +import numpy as np +import pytest +import torch +from sklearn.metrics import accuracy_score as sk_accuracy + +from pytorch_lightning.metrics import Accuracy +from pytorch_lightning.metrics.classification.helpers import _input_format_classification, DataType +from pytorch_lightning.metrics.functional import accuracy +from tests.metrics.classification.inputs import _input_binary, _input_binary_prob +from tests.metrics.classification.inputs import _input_multiclass as _input_mcls +from tests.metrics.classification.inputs import _input_multiclass_prob as _input_mcls_prob +from tests.metrics.classification.inputs import _input_multidim_multiclass as _input_mdmc +from tests.metrics.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob +from tests.metrics.classification.inputs import _input_multilabel as _input_mlb +from tests.metrics.classification.inputs import _input_multilabel_multidim as _input_mlmd +from tests.metrics.classification.inputs import _input_multilabel_multidim_prob as _input_mlmd_prob +from tests.metrics.classification.inputs import _input_multilabel_prob as _input_mlb_prob +from tests.metrics.utils import MetricTester, THRESHOLD + +torch.manual_seed(42) + + +def _sk_accuracy(preds, target, subset_accuracy): + sk_preds, sk_target, mode = _input_format_classification(preds, target, threshold=THRESHOLD) + sk_preds, sk_target = sk_preds.numpy(), sk_target.numpy() + + if mode == DataType.MULTIDIM_MULTICLASS and not subset_accuracy: + sk_preds, sk_target = np.transpose(sk_preds, (0, 2, 1)), np.transpose(sk_target, (0, 2, 1)) + sk_preds, sk_target = sk_preds.reshape(-1, sk_preds.shape[2]), sk_target.reshape(-1, sk_target.shape[2]) + elif mode == DataType.MULTIDIM_MULTICLASS and subset_accuracy: + return np.all(sk_preds == sk_target, axis=(1, 2)).mean() + elif mode == DataType.MULTILABEL and not subset_accuracy: + sk_preds, sk_target = sk_preds.reshape(-1), sk_target.reshape(-1) + + return sk_accuracy(y_true=sk_target, y_pred=sk_preds) + + +@pytest.mark.parametrize( + "preds, target, subset_accuracy", + [ + (_input_binary_prob.preds, _input_binary_prob.target, False), + (_input_binary.preds, _input_binary.target, False), + (_input_mlb_prob.preds, _input_mlb_prob.target, True), + (_input_mlb_prob.preds, _input_mlb_prob.target, False), + (_input_mlb.preds, _input_mlb.target, True), + (_input_mlb.preds, _input_mlb.target, False), + (_input_mcls_prob.preds, _input_mcls_prob.target, False), + (_input_mcls.preds, _input_mcls.target, False), + (_input_mdmc_prob.preds, _input_mdmc_prob.target, 
False), + (_input_mdmc_prob.preds, _input_mdmc_prob.target, True), + (_input_mdmc.preds, _input_mdmc.target, False), + (_input_mdmc.preds, _input_mdmc.target, True), + (_input_mlmd_prob.preds, _input_mlmd_prob.target, True), + (_input_mlmd_prob.preds, _input_mlmd_prob.target, False), + (_input_mlmd.preds, _input_mlmd.target, True), + (_input_mlmd.preds, _input_mlmd.target, False), + ], +) +class TestAccuracies(MetricTester): + + @pytest.mark.parametrize("ddp", [False, True]) + @pytest.mark.parametrize("dist_sync_on_step", [False, True]) + def test_accuracy_class(self, ddp, dist_sync_on_step, preds, target, subset_accuracy): + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=Accuracy, + sk_metric=partial(_sk_accuracy, subset_accuracy=subset_accuracy), + dist_sync_on_step=dist_sync_on_step, + metric_args={ + "threshold": THRESHOLD, + "subset_accuracy": subset_accuracy + }, + ) + + def test_accuracy_fn(self, preds, target, subset_accuracy): + self.run_functional_metric_test( + preds, + target, + metric_functional=accuracy, + sk_metric=partial(_sk_accuracy, subset_accuracy=subset_accuracy), + metric_args={ + "threshold": THRESHOLD, + "subset_accuracy": subset_accuracy + }, + ) + + +_l1to4 = [0.1, 0.2, 0.3, 0.4] +_l1to4t3 = np.array([_l1to4, _l1to4, _l1to4]) +_l1to4t3_mcls = [_l1to4t3.T, _l1to4t3.T, _l1to4t3.T] + +# The preds in these examples always put highest probability on class 3, second highest on class 2, +# third highest on class 1, and lowest on class 0 +_topk_preds_mcls = torch.tensor([_l1to4t3, _l1to4t3]).float() +_topk_target_mcls = torch.tensor([[1, 2, 3], [2, 1, 0]]) + +# This is like for MC case, but one sample in each batch is sabotaged with 0 class prediction :) +_topk_preds_mdmc = torch.tensor([_l1to4t3_mcls, _l1to4t3_mcls]).float() +_topk_target_mdmc = torch.tensor([[[1, 1, 0], [2, 2, 2], [3, 3, 3]], [[2, 2, 0], [1, 1, 1], [0, 0, 0]]]) + + +# Replace with a proper sk_metric test once sklearn 0.24 hits :) +@pytest.mark.parametrize( + "preds, target, exp_result, k, subset_accuracy", + [ + (_topk_preds_mcls, _topk_target_mcls, 1 / 6, 1, False), + (_topk_preds_mcls, _topk_target_mcls, 3 / 6, 2, False), + (_topk_preds_mcls, _topk_target_mcls, 5 / 6, 3, False), + (_topk_preds_mcls, _topk_target_mcls, 1 / 6, 1, True), + (_topk_preds_mcls, _topk_target_mcls, 3 / 6, 2, True), + (_topk_preds_mcls, _topk_target_mcls, 5 / 6, 3, True), + (_topk_preds_mdmc, _topk_target_mdmc, 1 / 6, 1, False), + (_topk_preds_mdmc, _topk_target_mdmc, 8 / 18, 2, False), + (_topk_preds_mdmc, _topk_target_mdmc, 13 / 18, 3, False), + (_topk_preds_mdmc, _topk_target_mdmc, 1 / 6, 1, True), + (_topk_preds_mdmc, _topk_target_mdmc, 2 / 6, 2, True), + (_topk_preds_mdmc, _topk_target_mdmc, 3 / 6, 3, True), + ], +) +def test_topk_accuracy(preds, target, exp_result, k, subset_accuracy): + topk = Accuracy(top_k=k, subset_accuracy=subset_accuracy) + + for batch in range(preds.shape[0]): + topk(preds[batch], target[batch]) + + assert topk.compute() == exp_result + + # Test functional + total_samples = target.shape[0] * target.shape[1] + + preds = preds.view(total_samples, 4, -1) + target = target.view(total_samples, -1) + + assert accuracy(preds, target, top_k=k, subset_accuracy=subset_accuracy) == exp_result + + +# Only MC and MDMC with probs input type should be accepted for top_k +@pytest.mark.parametrize( + "preds, target", + [ + (_input_binary_prob.preds, _input_binary_prob.target), + (_input_binary.preds, _input_binary.target), + (_input_mlb_prob.preds, _input_mlb_prob.target), + 
(_input_mlb.preds, _input_mlb.target),
+        (_input_mcls.preds, _input_mcls.target),
+        (_input_mdmc.preds, _input_mdmc.target),
+        (_input_mlmd_prob.preds, _input_mlmd_prob.target),
+        (_input_mlmd.preds, _input_mlmd.target),
+    ],
+)
+def test_topk_accuracy_wrong_input_types(preds, target):
+    topk = Accuracy(top_k=1)
+
+    with pytest.raises(ValueError):
+        topk(preds[0], target[0])
+
+    with pytest.raises(ValueError):
+        accuracy(preds[0], target[0], top_k=1)
+
+
+@pytest.mark.parametrize("top_k, threshold", [(0, 0.5), (None, 1.5)])
+def test_wrong_params(top_k, threshold):
+    preds, target = _input_mcls_prob.preds, _input_mcls_prob.target
+
+    with pytest.raises(ValueError):
+        acc = Accuracy(threshold=threshold, top_k=top_k)
+        acc(preds, target)
+        acc.compute()
+
+    with pytest.raises(ValueError):
+        accuracy(preds, target, threshold=threshold, top_k=top_k)
diff --git a/tests/metrics/classification/test_auc.py b/tests/metrics/classification/test_auc.py
new file mode 100644
index 0000000000000..e902151ecffce
--- /dev/null
+++ b/tests/metrics/classification/test_auc.py
@@ -0,0 +1,64 @@
+from collections import namedtuple
+
+import numpy as np
+import pytest
+import torch
+from sklearn.metrics import auc as _sk_auc
+
+from pytorch_lightning.metrics.classification.auc import AUC
+from pytorch_lightning.metrics.functional.auc import auc
+from tests.metrics.utils import MetricTester, NUM_BATCHES
+
+torch.manual_seed(42)
+
+
+def sk_auc(x, y):
+    x = x.flatten()
+    y = y.flatten()
+    return _sk_auc(x, y)
+
+
+Input = namedtuple('Input', ["x", "y"])
+
+_examples = []
+# generate already ordered samples, sorted in both directions
+for i in range(4):
+    x = np.random.randint(0, 5, (NUM_BATCHES * 8))
+    y = np.random.randint(0, 5, (NUM_BATCHES * 8))
+    idx = np.argsort(x, kind='stable')
+    x = x[idx] if i % 2 == 0 else x[idx[::-1]]
+    y = y[idx] if i % 2 == 0 else y[idx[::-1]]
+    x = x.reshape(NUM_BATCHES, 8)
+    y = y.reshape(NUM_BATCHES, 8)
+    _examples.append(Input(x=torch.tensor(x), y=torch.tensor(y)))
+
+
+@pytest.mark.parametrize("x, y", _examples)
+class TestAUC(MetricTester):
+
+    @pytest.mark.parametrize("ddp", [False])
+    @pytest.mark.parametrize("dist_sync_on_step", [True, False])
+    def test_auc(self, x, y, ddp, dist_sync_on_step):
+        self.run_class_metric_test(
+            ddp=ddp,
+            preds=x,
+            target=y,
+            metric_class=AUC,
+            sk_metric=sk_auc,
+            dist_sync_on_step=dist_sync_on_step,
+        )
+
+    def test_auc_functional(self, x, y):
+        self.run_functional_metric_test(x, y, metric_functional=auc, sk_metric=sk_auc, metric_args={"reorder": False})
+
+
+@pytest.mark.parametrize(['x', 'y', 'expected'], [
+    pytest.param([0, 1], [0, 1], 0.5),
+    pytest.param([1, 0], [0, 1], 0.5),
+    pytest.param([1, 0, 0], [0, 1, 1], 0.5),
+    pytest.param([0, 1], [1, 1], 1),
+    pytest.param([0, 0.5, 1], [0, 0.5, 1], 0.5),
+])
+def test_auc(x, y, expected):
+    # Test Area Under Curve (AUC) computation
+    assert auc(torch.tensor(x), torch.tensor(y), reorder=True) == expected
diff --git a/tests/metrics/classification/test_auroc.py b/tests/metrics/classification/test_auroc.py
new file mode 100644
index 0000000000000..0affcb1010225
--- /dev/null
+++ b/tests/metrics/classification/test_auroc.py
@@ -0,0 +1,142 @@
+from distutils.version import LooseVersion
+from functools import partial
+
+import pytest
+import torch
+from sklearn.metrics import roc_auc_score as sk_roc_auc_score
+
+from pytorch_lightning.metrics.classification.auroc import AUROC
+from pytorch_lightning.metrics.functional.auroc import auroc
+from tests.metrics.classification.inputs import
_input_binary_prob +from tests.metrics.classification.inputs import _input_multiclass_prob as _input_mcls_prob +from tests.metrics.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob +from tests.metrics.classification.inputs import _input_multilabel_multidim_prob as _input_mlmd_prob +from tests.metrics.classification.inputs import _input_multilabel_prob as _input_mlb_prob +from tests.metrics.utils import MetricTester, NUM_CLASSES + +torch.manual_seed(42) + + +def _sk_auroc_binary_prob(preds, target, num_classes, average='macro', max_fpr=None, multi_class='ovr'): + sk_preds = preds.view(-1).numpy() + sk_target = target.view(-1).numpy() + return sk_roc_auc_score(y_true=sk_target, y_score=sk_preds, average=average, max_fpr=max_fpr) + + +def _sk_auroc_multiclass_prob(preds, target, num_classes, average='macro', max_fpr=None, multi_class='ovr'): + sk_preds = preds.reshape(-1, num_classes).numpy() + sk_target = target.view(-1).numpy() + return sk_roc_auc_score( + y_true=sk_target, + y_score=sk_preds, + average=average, + max_fpr=max_fpr, + multi_class=multi_class, + ) + + +def _sk_auroc_multidim_multiclass_prob(preds, target, num_classes, average='macro', max_fpr=None, multi_class='ovr'): + sk_preds = preds.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1).numpy() + sk_target = target.view(-1).numpy() + return sk_roc_auc_score( + y_true=sk_target, + y_score=sk_preds, + average=average, + max_fpr=max_fpr, + multi_class=multi_class, + ) + + +def _sk_auroc_multilabel_prob(preds, target, num_classes, average='macro', max_fpr=None, multi_class='ovr'): + sk_preds = preds.reshape(-1, num_classes).numpy() + sk_target = target.reshape(-1, num_classes).numpy() + return sk_roc_auc_score( + y_true=sk_target, + y_score=sk_preds, + average=average, + max_fpr=max_fpr, + multi_class=multi_class, + ) + + +def _sk_auroc_multilabel_multidim_prob(preds, target, num_classes, average='macro', max_fpr=None, multi_class='ovr'): + sk_preds = preds.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1).numpy() + sk_target = target.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1).numpy() + return sk_roc_auc_score( + y_true=sk_target, + y_score=sk_preds, + average=average, + max_fpr=max_fpr, + multi_class=multi_class, + ) + + +@pytest.mark.parametrize( + "preds, target, sk_metric, num_classes", + [(_input_binary_prob.preds, _input_binary_prob.target, _sk_auroc_binary_prob, 1), + (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_auroc_multiclass_prob, NUM_CLASSES), + (_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_auroc_multidim_multiclass_prob, NUM_CLASSES), + (_input_mlb_prob.preds, _input_mlb_prob.target, _sk_auroc_multilabel_prob, NUM_CLASSES), + (_input_mlmd_prob.preds, _input_mlmd_prob.target, _sk_auroc_multilabel_multidim_prob, NUM_CLASSES)] +) +@pytest.mark.parametrize("average", ['macro', 'weighted']) +@pytest.mark.parametrize("max_fpr", [None, 0.8, 0.5]) +class TestAUROC(MetricTester): + + @pytest.mark.parametrize("ddp", [True, False]) + @pytest.mark.parametrize("dist_sync_on_step", [True, False]) + def test_auroc(self, preds, target, sk_metric, num_classes, average, max_fpr, ddp, dist_sync_on_step): + # max_fpr different from None is not support in multi class + if max_fpr is not None and num_classes != 1: + pytest.skip('max_fpr parameter not support for multi class or multi label') + + # max_fpr only supported for torch v1.6 or higher + if max_fpr is not None and LooseVersion(torch.__version__) < LooseVersion('1.6.0'): + pytest.skip('requires torch v1.6 or 
higher to test max_fpr argument') + + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=AUROC, + sk_metric=partial(sk_metric, num_classes=num_classes, average=average, max_fpr=max_fpr), + dist_sync_on_step=dist_sync_on_step, + metric_args={ + "num_classes": num_classes, + "average": average, + "max_fpr": max_fpr + }, + ) + + def test_auroc_functional(self, preds, target, sk_metric, num_classes, average, max_fpr): + # max_fpr different from None is not support in multi class + if max_fpr is not None and num_classes != 1: + pytest.skip('max_fpr parameter not support for multi class or multi label') + + # max_fpr only supported for torch v1.6 or higher + if max_fpr is not None and LooseVersion(torch.__version__) < LooseVersion('1.6.0'): + pytest.skip('requires torch v1.6 or higher to test max_fpr argument') + + self.run_functional_metric_test( + preds, + target, + metric_functional=auroc, + sk_metric=partial(sk_metric, num_classes=num_classes, average=average, max_fpr=max_fpr), + metric_args={ + "num_classes": num_classes, + "average": average, + "max_fpr": max_fpr + }, + ) + + +def test_error_on_different_mode(): + """ test that an error is raised if the user pass in data of + different modes (binary, multi-label, multi-class) + """ + metric = AUROC() + # pass in multi-class data + metric.update(torch.randn(10, 5).softmax(dim=-1), torch.randint(0, 5, (10, ))) + with pytest.raises(ValueError, match=r"The mode of data.* should be constant.*"): + # pass in multi-label data + metric.update(torch.rand(10, 5), torch.randint(0, 2, (10, 5))) diff --git a/tests/metrics/classification/test_average_precision.py b/tests/metrics/classification/test_average_precision.py new file mode 100644 index 0000000000000..7cab20883e970 --- /dev/null +++ b/tests/metrics/classification/test_average_precision.py @@ -0,0 +1,97 @@ +from functools import partial + +import numpy as np +import pytest +import torch +from sklearn.metrics import average_precision_score as sk_average_precision_score + +from pytorch_lightning.metrics.classification.average_precision import AveragePrecision +from pytorch_lightning.metrics.functional.average_precision import average_precision +from tests.metrics.classification.inputs import _input_binary_prob +from tests.metrics.classification.inputs import _input_multiclass_prob as _input_mcls_prob +from tests.metrics.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob +from tests.metrics.utils import MetricTester, NUM_CLASSES + +torch.manual_seed(42) + + +def _sk_average_precision_score(y_true, probas_pred, num_classes=1): + if num_classes == 1: + return sk_average_precision_score(y_true, probas_pred) + + res = [] + for i in range(num_classes): + y_true_temp = np.zeros_like(y_true) + y_true_temp[y_true == i] = 1 + res.append(sk_average_precision_score(y_true_temp, probas_pred[:, i])) + return res + + +def _sk_avg_prec_binary_prob(preds, target, num_classes=1): + sk_preds = preds.view(-1).numpy() + sk_target = target.view(-1).numpy() + + return _sk_average_precision_score(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes) + + +def _sk_avg_prec_multiclass_prob(preds, target, num_classes=1): + sk_preds = preds.reshape(-1, num_classes).numpy() + sk_target = target.view(-1).numpy() + + return _sk_average_precision_score(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes) + + +def _sk_avg_prec_multidim_multiclass_prob(preds, target, num_classes=1): + sk_preds = preds.transpose(0, 1).reshape(num_classes, 
-1).transpose(0, 1).numpy() + sk_target = target.view(-1).numpy() + return _sk_average_precision_score(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes) + + +@pytest.mark.parametrize( + "preds, target, sk_metric, num_classes", [ + (_input_binary_prob.preds, _input_binary_prob.target, _sk_avg_prec_binary_prob, 1), + (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_avg_prec_multiclass_prob, NUM_CLASSES), + (_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_avg_prec_multidim_multiclass_prob, NUM_CLASSES), + ] +) +class TestAveragePrecision(MetricTester): + + @pytest.mark.parametrize("ddp", [True, False]) + @pytest.mark.parametrize("dist_sync_on_step", [True, False]) + def test_average_precision(self, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step): + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=AveragePrecision, + sk_metric=partial(sk_metric, num_classes=num_classes), + dist_sync_on_step=dist_sync_on_step, + metric_args={"num_classes": num_classes} + ) + + def test_average_precision_functional(self, preds, target, sk_metric, num_classes): + self.run_functional_metric_test( + preds, + target, + metric_functional=average_precision, + sk_metric=partial(sk_metric, num_classes=num_classes), + metric_args={"num_classes": num_classes}, + ) + + +@pytest.mark.parametrize( + ['scores', 'target', 'expected_score'], + [ + # Check the average_precision_score of a constant predictor is + # the TPR + # Generate a dataset with 25% of positives + # And a constant score + # The precision is then the fraction of positive whatever the recall + # is, as there is only one threshold: + pytest.param(torch.tensor([1, 1, 1, 1]), torch.tensor([0, 0, 0, 1]), .25), + # With threshold 0.8 : 1 TP and 2 TN and one FN + pytest.param(torch.tensor([.6, .7, .8, 9]), torch.tensor([1, 0, 0, 1]), .75), + ] +) +def test_average_precision(scores, target, expected_score): + assert average_precision(scores, target) == expected_score diff --git a/tests/metrics/classification/test_confusion_matrix.py b/tests/metrics/classification/test_confusion_matrix.py new file mode 100644 index 0000000000000..5371044d6d4b0 --- /dev/null +++ b/tests/metrics/classification/test_confusion_matrix.py @@ -0,0 +1,128 @@ +from functools import partial + +import numpy as np +import pytest +import torch +from sklearn.metrics import confusion_matrix as sk_confusion_matrix + +from pytorch_lightning.metrics.classification.confusion_matrix import ConfusionMatrix +from pytorch_lightning.metrics.functional.confusion_matrix import confusion_matrix +from tests.metrics.classification.inputs import _input_binary, _input_binary_prob +from tests.metrics.classification.inputs import _input_multiclass as _input_mcls +from tests.metrics.classification.inputs import _input_multiclass_prob as _input_mcls_prob +from tests.metrics.classification.inputs import _input_multidim_multiclass as _input_mdmc +from tests.metrics.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob +from tests.metrics.classification.inputs import _input_multilabel as _input_mlb +from tests.metrics.classification.inputs import _input_multilabel_prob as _input_mlb_prob +from tests.metrics.utils import MetricTester, NUM_CLASSES, THRESHOLD + +torch.manual_seed(42) + + +def _sk_cm_binary_prob(preds, target, normalize=None): + sk_preds = (preds.view(-1).numpy() >= THRESHOLD).astype(np.uint8) + sk_target = target.view(-1).numpy() + + return sk_confusion_matrix(y_true=sk_target, y_pred=sk_preds, 
normalize=normalize) + + +def _sk_cm_binary(preds, target, normalize=None): + sk_preds = preds.view(-1).numpy() + sk_target = target.view(-1).numpy() + + return sk_confusion_matrix(y_true=sk_target, y_pred=sk_preds, normalize=normalize) + + +def _sk_cm_multilabel_prob(preds, target, normalize=None): + sk_preds = (preds.view(-1).numpy() >= THRESHOLD).astype(np.uint8) + sk_target = target.view(-1).numpy() + + return sk_confusion_matrix(y_true=sk_target, y_pred=sk_preds, normalize=normalize) + + +def _sk_cm_multilabel(preds, target, normalize=None): + sk_preds = preds.view(-1).numpy() + sk_target = target.view(-1).numpy() + + return sk_confusion_matrix(y_true=sk_target, y_pred=sk_preds, normalize=normalize) + + +def _sk_cm_multiclass_prob(preds, target, normalize=None): + sk_preds = torch.argmax(preds, dim=len(preds.shape) - 1).view(-1).numpy() + sk_target = target.view(-1).numpy() + + return sk_confusion_matrix(y_true=sk_target, y_pred=sk_preds, normalize=normalize) + + +def _sk_cm_multiclass(preds, target, normalize=None): + sk_preds = preds.view(-1).numpy() + sk_target = target.view(-1).numpy() + + return sk_confusion_matrix(y_true=sk_target, y_pred=sk_preds, normalize=normalize) + + +def _sk_cm_multidim_multiclass_prob(preds, target, normalize=None): + sk_preds = torch.argmax(preds, dim=len(preds.shape) - 2).view(-1).numpy() + sk_target = target.view(-1).numpy() + + return sk_confusion_matrix(y_true=sk_target, y_pred=sk_preds, normalize=normalize) + + +def _sk_cm_multidim_multiclass(preds, target, normalize=None): + sk_preds = preds.view(-1).numpy() + sk_target = target.view(-1).numpy() + + return sk_confusion_matrix(y_true=sk_target, y_pred=sk_preds, normalize=normalize) + + +@pytest.mark.parametrize("normalize", ['true', 'pred', 'all', None]) +@pytest.mark.parametrize( + "preds, target, sk_metric, num_classes", + [(_input_binary_prob.preds, _input_binary_prob.target, _sk_cm_binary_prob, 2), + (_input_binary.preds, _input_binary.target, _sk_cm_binary, 2), + (_input_mlb_prob.preds, _input_mlb_prob.target, _sk_cm_multilabel_prob, 2), + (_input_mlb.preds, _input_mlb.target, _sk_cm_multilabel, 2), + (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_cm_multiclass_prob, NUM_CLASSES), + (_input_mcls.preds, _input_mcls.target, _sk_cm_multiclass, NUM_CLASSES), + (_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_cm_multidim_multiclass_prob, NUM_CLASSES), + (_input_mdmc.preds, _input_mdmc.target, _sk_cm_multidim_multiclass, NUM_CLASSES)] +) +class TestConfusionMatrix(MetricTester): + + @pytest.mark.parametrize("ddp", [True, False]) + @pytest.mark.parametrize("dist_sync_on_step", [True, False]) + def test_confusion_matrix(self, normalize, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step): + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=ConfusionMatrix, + sk_metric=partial(sk_metric, normalize=normalize), + dist_sync_on_step=dist_sync_on_step, + metric_args={ + "num_classes": num_classes, + "threshold": THRESHOLD, + "normalize": normalize + } + ) + + def test_confusion_matrix_functional(self, normalize, preds, target, sk_metric, num_classes): + self.run_functional_metric_test( + preds, + target, + metric_functional=confusion_matrix, + sk_metric=partial(sk_metric, normalize=normalize), + metric_args={ + "num_classes": num_classes, + "threshold": THRESHOLD, + "normalize": normalize + } + ) + + +def test_warning_on_nan(tmpdir): + preds = torch.randint(3, size=(20, )) + target = torch.randint(3, size=(20, )) + + with pytest.warns(UserWarning, 
match='.* nan values found in confusion matrix have been replaced with zeros.'): + confusion_matrix(preds, target, num_classes=5, normalize='true') diff --git a/tests/metrics/classification/test_f_beta.py b/tests/metrics/classification/test_f_beta.py new file mode 100644 index 0000000000000..b9458fb6c530c --- /dev/null +++ b/tests/metrics/classification/test_f_beta.py @@ -0,0 +1,153 @@ +from functools import partial + +import numpy as np +import pytest +import torch +from sklearn.metrics import fbeta_score + +from pytorch_lightning.metrics import F1, FBeta +from pytorch_lightning.metrics.functional import f1, fbeta +from tests.metrics.classification.inputs import _input_binary, _input_binary_prob +from tests.metrics.classification.inputs import _input_multiclass as _input_mcls +from tests.metrics.classification.inputs import _input_multiclass_prob as _input_mcls_prob +from tests.metrics.classification.inputs import _input_multidim_multiclass as _input_mdmc +from tests.metrics.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob +from tests.metrics.classification.inputs import _input_multilabel as _input_mlb +from tests.metrics.classification.inputs import _input_multilabel_no_match as _input_mlb_nomatch +from tests.metrics.classification.inputs import _input_multilabel_prob as _mlb_prob_inputs +from tests.metrics.utils import MetricTester, NUM_CLASSES, THRESHOLD + +torch.manual_seed(42) + + +def _sk_fbeta_binary_prob(preds, target, average='micro', beta=1.0): + sk_preds = (preds.view(-1).numpy() >= THRESHOLD).astype(np.uint8) + sk_target = target.view(-1).numpy() + + return fbeta_score(y_true=sk_target, y_pred=sk_preds, average='binary', beta=beta) + + +def _sk_fbeta_binary(preds, target, average='micro', beta=1.0): + sk_preds = preds.view(-1).numpy() + sk_target = target.view(-1).numpy() + + return fbeta_score(y_true=sk_target, y_pred=sk_preds, average='binary', beta=beta) + + +def _sk_fbeta_multilabel_prob(preds, target, average='micro', beta=1.0): + sk_preds = (preds.view(-1, NUM_CLASSES).numpy() >= THRESHOLD).astype(np.uint8) + sk_target = target.view(-1, NUM_CLASSES).numpy() + + return fbeta_score(y_true=sk_target, y_pred=sk_preds, average=average, beta=beta) + + +def _sk_fbeta_multilabel(preds, target, average='micro', beta=1.0): + sk_preds = preds.view(-1, NUM_CLASSES).numpy() + sk_target = target.view(-1, NUM_CLASSES).numpy() + + return fbeta_score(y_true=sk_target, y_pred=sk_preds, average=average, beta=beta) + + +def _sk_fbeta_multiclass_prob(preds, target, average='micro', beta=1.0): + sk_preds = torch.argmax(preds, dim=len(preds.shape) - 1).view(-1).numpy() + sk_target = target.view(-1).numpy() + + return fbeta_score(y_true=sk_target, y_pred=sk_preds, average=average, beta=beta) + + +def _sk_fbeta_multiclass(preds, target, average='micro', beta=1.0): + sk_preds = preds.view(-1).numpy() + sk_target = target.view(-1).numpy() + + return fbeta_score(y_true=sk_target, y_pred=sk_preds, average=average, beta=beta) + + +def _sk_fbeta_multidim_multiclass_prob(preds, target, average='micro', beta=1.0): + sk_preds = torch.argmax(preds, dim=len(preds.shape) - 2).view(-1).numpy() + sk_target = target.view(-1).numpy() + + return fbeta_score(y_true=sk_target, y_pred=sk_preds, average=average, beta=beta) + + +def _sk_fbeta_multidim_multiclass(preds, target, average='micro', beta=1.0): + sk_preds = preds.view(-1).numpy() + sk_target = target.view(-1).numpy() + + return fbeta_score(y_true=sk_target, y_pred=sk_preds, average=average, beta=beta) + + 
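+# A minimal sanity sketch of the comparison pattern above (illustrative only, not
+# collected by pytest): assuming the module-level THRESHOLD from tests.metrics.utils,
+# the sklearn reference in _sk_fbeta_binary_prob binarizes probabilistic predictions
+# the same way as the Lightning metric, so the two scores should agree for plain
+# binary inputs.
+def _example_binary_fbeta_equivalence(beta: float = 0.5) -> None:
+    preds = torch.rand(16)
+    target = torch.randint(0, 2, (16, ))
+    sk_score = _sk_fbeta_binary_prob(preds, target, beta=beta)
+    pl_score = fbeta(preds, target, num_classes=1, beta=beta, threshold=THRESHOLD)
+    assert torch.allclose(pl_score, torch.tensor(sk_score, dtype=pl_score.dtype))
+
+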
+@pytest.mark.parametrize( + "preds, target, sk_metric, num_classes, multilabel", + [ + (_input_binary_prob.preds, _input_binary_prob.target, _sk_fbeta_binary_prob, 1, False), + (_input_binary.preds, _input_binary.target, _sk_fbeta_binary, 1, False), + (_mlb_prob_inputs.preds, _mlb_prob_inputs.target, _sk_fbeta_multilabel_prob, NUM_CLASSES, True), + (_input_mlb.preds, _input_mlb.target, _sk_fbeta_multilabel, NUM_CLASSES, True), + (_input_mlb_nomatch.preds, _input_mlb_nomatch.target, _sk_fbeta_multilabel, NUM_CLASSES, True), + (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_fbeta_multiclass_prob, NUM_CLASSES, False), + (_input_mcls.preds, _input_mcls.target, _sk_fbeta_multiclass, NUM_CLASSES, False), + (_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_fbeta_multidim_multiclass_prob, NUM_CLASSES, False), + (_input_mdmc.preds, _input_mdmc.target, _sk_fbeta_multidim_multiclass, NUM_CLASSES, False), + ], +) +@pytest.mark.parametrize("average", ['micro', 'macro', 'weighted', None]) +@pytest.mark.parametrize("beta", [0.5, 1.0, 2.0]) +class TestFBeta(MetricTester): + + @pytest.mark.parametrize("ddp", [True, False]) + @pytest.mark.parametrize("dist_sync_on_step", [True, False]) + def test_fbeta(self, preds, target, sk_metric, num_classes, multilabel, average, beta, ddp, dist_sync_on_step): + metric_class = F1 if beta == 1.0 else partial(FBeta, beta=beta) + + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=metric_class, + sk_metric=partial(sk_metric, average=average, beta=beta), + dist_sync_on_step=dist_sync_on_step, + metric_args={ + "num_classes": num_classes, + "average": average, + "multilabel": multilabel, + "threshold": THRESHOLD, + }, + check_dist_sync_on_step=False, + check_batch=False, + ) + + def test_fbeta_functional(self, preds, target, sk_metric, num_classes, multilabel, average, beta): + metric_functional = f1 if beta == 1.0 else partial(fbeta, beta=beta) + + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=metric_functional, + sk_metric=partial(sk_metric, average=average, beta=beta), + metric_args={ + "num_classes": num_classes, + "average": average, + "multilabel": multilabel, + "threshold": THRESHOLD + } + ) + + +@pytest.mark.parametrize(['pred', 'target', 'beta', 'exp_score'], [ + pytest.param([1., 0., 1., 0.], [0., 1., 1., 0.], 0.5, [0.5, 0.5]), + pytest.param([1., 0., 1., 0.], [0., 1., 1., 0.], 1, [0.5, 0.5]), + pytest.param([1., 0., 1., 0.], [0., 1., 1., 0.], 2, [0.5, 0.5]), +]) +def test_fbeta_score(pred, target, beta, exp_score): + score = fbeta(torch.tensor(pred), torch.tensor(target), num_classes=1, beta=beta, average='none') + assert torch.allclose(score, torch.tensor(exp_score)) + + +@pytest.mark.parametrize(['pred', 'target', 'exp_score'], [ + pytest.param([0., 0., 0., 0.], [1., 1., 1., 1.], [0.0, 0.0]), + pytest.param([1., 0., 1., 0.], [0., 1., 1., 0.], [0.5, 0.5]), + pytest.param([1., 0., 1., 0.], [1., 0., 1., 0.], [1.0, 1.0]), +]) +def test_f1_score(pred, target, exp_score): + score = f1(torch.tensor(pred), torch.tensor(target), num_classes=1, average='none') + assert torch.allclose(score, torch.tensor(exp_score)) diff --git a/tests/metrics/classification/test_hamming_distance.py b/tests/metrics/classification/test_hamming_distance.py new file mode 100644 index 0000000000000..c57072c033c8c --- /dev/null +++ b/tests/metrics/classification/test_hamming_distance.py @@ -0,0 +1,80 @@ +import pytest +import torch +from sklearn.metrics import hamming_loss as sk_hamming_loss + +from 
pytorch_lightning.metrics import HammingDistance +from pytorch_lightning.metrics.classification.helpers import _input_format_classification +from pytorch_lightning.metrics.functional import hamming_distance +from tests.metrics.classification.inputs import _input_binary, _input_binary_prob +from tests.metrics.classification.inputs import _input_multiclass as _input_mcls +from tests.metrics.classification.inputs import _input_multiclass_prob as _input_mcls_prob +from tests.metrics.classification.inputs import _input_multidim_multiclass as _input_mdmc +from tests.metrics.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob +from tests.metrics.classification.inputs import _input_multilabel as _input_mlb +from tests.metrics.classification.inputs import _input_multilabel_multidim as _input_mlmd +from tests.metrics.classification.inputs import _input_multilabel_multidim_prob as _input_mlmd_prob +from tests.metrics.classification.inputs import _input_multilabel_prob as _input_mlb_prob +from tests.metrics.utils import MetricTester, THRESHOLD + +torch.manual_seed(42) + + +def _sk_hamming_loss(preds, target): + sk_preds, sk_target, _ = _input_format_classification(preds, target, threshold=THRESHOLD) + sk_preds, sk_target = sk_preds.numpy(), sk_target.numpy() + sk_preds, sk_target = sk_preds.reshape(sk_preds.shape[0], -1), sk_target.reshape(sk_target.shape[0], -1) + + return sk_hamming_loss(y_true=sk_target, y_pred=sk_preds) + + +@pytest.mark.parametrize( + "preds, target", + [ + (_input_binary_prob.preds, _input_binary_prob.target), + (_input_binary.preds, _input_binary.target), + (_input_mlb_prob.preds, _input_mlb_prob.target), + (_input_mlb.preds, _input_mlb.target), + (_input_mcls_prob.preds, _input_mcls_prob.target), + (_input_mcls.preds, _input_mcls.target), + (_input_mdmc_prob.preds, _input_mdmc_prob.target), + (_input_mdmc.preds, _input_mdmc.target), + (_input_mlmd_prob.preds, _input_mlmd_prob.target), + (_input_mlmd.preds, _input_mlmd.target), + ], +) +class TestHammingDistance(MetricTester): + + @pytest.mark.parametrize("ddp", [True, False]) + @pytest.mark.parametrize("dist_sync_on_step", [False, True]) + def test_hamming_distance_class(self, ddp, dist_sync_on_step, preds, target): + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=HammingDistance, + sk_metric=_sk_hamming_loss, + dist_sync_on_step=dist_sync_on_step, + metric_args={"threshold": THRESHOLD}, + ) + + def test_hamming_distance_fn(self, preds, target): + self.run_functional_metric_test( + preds, + target, + metric_functional=hamming_distance, + sk_metric=_sk_hamming_loss, + metric_args={"threshold": THRESHOLD}, + ) + + +@pytest.mark.parametrize("threshold", [1.5]) +def test_wrong_params(threshold): + preds, target = _input_mcls_prob.preds, _input_mcls_prob.target + + with pytest.raises(ValueError): + ham_dist = HammingDistance(threshold=threshold) + ham_dist(preds, target) + ham_dist.compute() + + with pytest.raises(ValueError): + hamming_distance(preds, target, threshold=threshold) diff --git a/tests/metrics/classification/test_inputs.py b/tests/metrics/classification/test_inputs.py new file mode 100644 index 0000000000000..a78d799b1a07d --- /dev/null +++ b/tests/metrics/classification/test_inputs.py @@ -0,0 +1,311 @@ +import pytest +import torch +from torch import rand, randint + +from pytorch_lightning.metrics.classification.helpers import _input_format_classification, DataType +from pytorch_lightning.metrics.utils import select_topk, to_onehot +from 
tests.metrics.classification.inputs import _input_binary as _bin +from tests.metrics.classification.inputs import _input_binary_prob as _bin_prob +from tests.metrics.classification.inputs import _input_multiclass as _mc +from tests.metrics.classification.inputs import _input_multiclass_prob as _mc_prob +from tests.metrics.classification.inputs import _input_multidim_multiclass as _mdmc +from tests.metrics.classification.inputs import _input_multidim_multiclass_prob as _mdmc_prob +from tests.metrics.classification.inputs import _input_multilabel as _ml +from tests.metrics.classification.inputs import _input_multilabel_multidim as _mlmd +from tests.metrics.classification.inputs import _input_multilabel_multidim_prob as _mlmd_prob +from tests.metrics.classification.inputs import _input_multilabel_prob as _ml_prob +from tests.metrics.classification.inputs import Input +from tests.metrics.utils import BATCH_SIZE, EXTRA_DIM, NUM_BATCHES, NUM_CLASSES, THRESHOLD + +torch.manual_seed(42) + +# Some additional inputs to test on +_ml_prob_half = Input(_ml_prob.preds.half(), _ml_prob.target) + +_mc_prob_2cls_preds = rand(NUM_BATCHES, BATCH_SIZE, 2) +_mc_prob_2cls_preds /= _mc_prob_2cls_preds.sum(dim=2, keepdim=True) +_mc_prob_2cls = Input(_mc_prob_2cls_preds, randint(high=2, size=(NUM_BATCHES, BATCH_SIZE))) + +_mdmc_prob_many_dims_preds = rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM, EXTRA_DIM) +_mdmc_prob_many_dims_preds /= _mdmc_prob_many_dims_preds.sum(dim=2, keepdim=True) +_mdmc_prob_many_dims = Input( + _mdmc_prob_many_dims_preds, + randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM, EXTRA_DIM)), +) + +_mdmc_prob_2cls_preds = rand(NUM_BATCHES, BATCH_SIZE, 2, EXTRA_DIM) +_mdmc_prob_2cls_preds /= _mdmc_prob_2cls_preds.sum(dim=2, keepdim=True) +_mdmc_prob_2cls = Input(_mdmc_prob_2cls_preds, randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM))) + +# Some utils +T = torch.Tensor + + +def _idn(x): + return x + + +def _usq(x): + return x.unsqueeze(-1) + + +def _thrs(x): + return x >= THRESHOLD + + +def _rshp1(x): + return x.reshape(x.shape[0], -1) + + +def _rshp2(x): + return x.reshape(x.shape[0], x.shape[1], -1) + + +def _onehot(x): + return to_onehot(x, NUM_CLASSES) + + +def _onehot2(x): + return to_onehot(x, 2) + + +def _top1(x): + return select_topk(x, 1) + + +def _top2(x): + return select_topk(x, 2) + + +# To avoid ugly black line wrapping +def _ml_preds_tr(x): + return _rshp1(_thrs(x)) + + +def _onehot_rshp1(x): + return _onehot(_rshp1(x)) + + +def _onehot2_rshp1(x): + return _onehot2(_rshp1(x)) + + +def _top1_rshp2(x): + return _top1(_rshp2(x)) + + +def _top2_rshp2(x): + return _top2(_rshp2(x)) + + +def _probs_to_mc_preds_tr(x): + return _onehot2(_thrs(x)) + + +def _mlmd_prob_to_mc_preds_tr(x): + return _onehot2(_rshp1(_thrs(x))) + + +######################## +# Test correct inputs +######################## + + +@pytest.mark.parametrize( + "inputs, num_classes, is_multiclass, top_k, exp_mode, post_preds, post_target", + [ + ############################# + # Test usual expected cases + (_bin, None, False, None, "multi-class", _usq, _usq), + (_bin, 1, False, None, "multi-class", _usq, _usq), + (_bin_prob, None, None, None, "binary", lambda x: _usq(_thrs(x)), _usq), + (_ml_prob, None, None, None, "multi-label", _thrs, _idn), + (_ml, None, False, None, "multi-dim multi-class", _idn, _idn), + (_ml_prob, None, None, None, "multi-label", _ml_preds_tr, _rshp1), + (_ml_prob, None, None, 2, "multi-label", _top2, _rshp1), + (_mlmd, None, False, None, "multi-dim multi-class", _rshp1, 
_rshp1), + (_mc, NUM_CLASSES, None, None, "multi-class", _onehot, _onehot), + (_mc_prob, None, None, None, "multi-class", _top1, _onehot), + (_mc_prob, None, None, 2, "multi-class", _top2, _onehot), + (_mdmc, NUM_CLASSES, None, None, "multi-dim multi-class", _onehot, _onehot), + (_mdmc_prob, None, None, None, "multi-dim multi-class", _top1_rshp2, _onehot), + (_mdmc_prob, None, None, 2, "multi-dim multi-class", _top2_rshp2, _onehot), + (_mdmc_prob_many_dims, None, None, None, "multi-dim multi-class", _top1_rshp2, _onehot_rshp1), + (_mdmc_prob_many_dims, None, None, 2, "multi-dim multi-class", _top2_rshp2, _onehot_rshp1), + ########################### + # Test some special cases + # Make sure that half precision works, i.e. is converted to full precision + (_ml_prob_half, None, None, None, "multi-label", lambda x: _ml_preds_tr(x.float()), _rshp1), + # Binary as multiclass + (_bin, None, None, None, "multi-class", _onehot2, _onehot2), + # Binary probs as multiclass + (_bin_prob, None, True, None, "binary", _probs_to_mc_preds_tr, _onehot2), + # Multilabel as multiclass + (_ml, None, True, None, "multi-dim multi-class", _onehot2, _onehot2), + # Multilabel probs as multiclass + (_ml_prob, None, True, None, "multi-label", _probs_to_mc_preds_tr, _onehot2), + # Multidim multilabel as multiclass + (_mlmd, None, True, None, "multi-dim multi-class", _onehot2_rshp1, _onehot2_rshp1), + # Multidim multilabel probs as multiclass + (_mlmd_prob, None, True, None, "multi-label", _mlmd_prob_to_mc_preds_tr, _onehot2_rshp1), + # Multiclass prob with 2 classes as binary + (_mc_prob_2cls, None, False, None, "multi-class", lambda x: _top1(x)[:, [1]], _usq), + # Multi-dim multi-class with 2 classes as multi-label + (_mdmc_prob_2cls, None, False, None, "multi-dim multi-class", lambda x: _top1(x)[:, 1], _idn), + ], +) +def test_usual_cases(inputs, num_classes, is_multiclass, top_k, exp_mode, post_preds, post_target): + + def __get_data_type_enum(str_exp_mode): + return next(DataType[n] for n in dir(DataType) if DataType[n] == str_exp_mode) + + for exp_mode in (exp_mode, __get_data_type_enum(exp_mode)): + preds_out, target_out, mode = _input_format_classification( + preds=inputs.preds[0], + target=inputs.target[0], + threshold=THRESHOLD, + num_classes=num_classes, + is_multiclass=is_multiclass, + top_k=top_k, + ) + + assert mode == exp_mode + assert torch.equal(preds_out, post_preds(inputs.preds[0]).int()) + assert torch.equal(target_out, post_target(inputs.target[0]).int()) + + # Test that things work when batch_size = 1 + preds_out, target_out, mode = _input_format_classification( + preds=inputs.preds[0][[0], ...], + target=inputs.target[0][[0], ...], + threshold=THRESHOLD, + num_classes=num_classes, + is_multiclass=is_multiclass, + top_k=top_k, + ) + + assert mode == exp_mode + assert torch.equal(preds_out, post_preds(inputs.preds[0][[0], ...]).int()) + assert torch.equal(target_out, post_target(inputs.target[0][[0], ...]).int()) + + +# Test that threshold is correctly applied +def test_threshold(): + target = T([1, 1, 1]).int() + preds_probs = T([0.5 - 1e-5, 0.5, 0.5 + 1e-5]) + + preds_probs_out, _, _ = _input_format_classification(preds_probs, target, threshold=0.5) + + assert torch.equal(torch.tensor([0, 1, 1], dtype=torch.int), preds_probs_out.squeeze().int()) + + +######################################################################## +# Test incorrect inputs +######################################################################## + + +@pytest.mark.parametrize("threshold", [-0.5, 0.0, 1.0, 1.5]) +def 
test_incorrect_threshold(threshold): + preds, target = rand(size=(7, )), randint(high=2, size=(7, )) + with pytest.raises(ValueError): + _input_format_classification(preds, target, threshold=threshold) + + +@pytest.mark.parametrize( + "preds, target, num_classes, is_multiclass", + [ + # Target not integer + (randint(high=2, size=(7, )), randint(high=2, size=(7, )).float(), None, None), + # Target negative + (randint(high=2, size=(7, )), -randint(high=2, size=(7, )), None, None), + # Preds negative integers + (-randint(high=2, size=(7, )), randint(high=2, size=(7, )), None, None), + # Negative probabilities + (-rand(size=(7, )), randint(high=2, size=(7, )), None, None), + # is_multiclass=False and target > 1 + (rand(size=(7, )), randint(low=2, high=4, size=(7, )), None, False), + # is_multiclass=False and preds integers with > 1 + (randint(low=2, high=4, size=(7, )), randint(high=2, size=(7, )), None, False), + # Wrong batch size + (randint(high=2, size=(8, )), randint(high=2, size=(7, )), None, None), + # Completely wrong shape + (randint(high=2, size=(7, )), randint(high=2, size=(7, 4)), None, None), + # Same #dims, different shape + (randint(high=2, size=(7, 3)), randint(high=2, size=(7, 4)), None, None), + # Same shape and preds floats, target not binary + (rand(size=(7, 3)), randint(low=2, high=4, size=(7, 3)), None, None), + # #dims in preds = 1 + #dims in target, C shape not second or last + (rand(size=(7, 3, 4, 3)), randint(high=4, size=(7, 3, 3)), None, None), + # #dims in preds = 1 + #dims in target, preds not float + (randint(high=2, size=(7, 3, 3, 4)), randint(high=4, size=(7, 3, 3)), None, None), + # is_multiclass=False, with C dimension > 2 + (_mc_prob.preds[0], randint(high=2, size=(BATCH_SIZE, )), None, False), + # Probs of multiclass preds do not sum up to 1 + (rand(size=(7, 3, 5)), randint(high=2, size=(7, 5)), None, None), + # Max target larger or equal to C dimension + (_mc_prob.preds[0], randint(low=NUM_CLASSES + 1, high=100, size=(BATCH_SIZE, )), None, None), + # C dimension not equal to num_classes + (_mc_prob.preds[0], _mc_prob.target[0], NUM_CLASSES + 1, None), + # Max target larger than num_classes (with #dim preds = 1 + #dims target) + (_mc_prob.preds[0], randint(low=NUM_CLASSES + 1, high=100, size=(BATCH_SIZE, NUM_CLASSES)), 4, None), + # Max target larger than num_classes (with #dim preds = #dims target) + (randint(high=4, size=(7, 3)), randint(low=5, high=7, size=(7, 3)), 4, None), + # Max preds larger than num_classes (with #dim preds = #dims target) + (randint(low=5, high=7, size=(7, 3)), randint(high=4, size=(7, 3)), 4, None), + # Num_classes=1, but is_multiclass not false + (randint(high=2, size=(7, )), randint(high=2, size=(7, )), 1, None), + # is_multiclass=False, but implied class dimension (for multi-label, from shape) != num_classes + (randint(high=2, size=(7, 3, 3)), randint(high=2, size=(7, 3, 3)), 4, False), + # Multilabel input with implied class dimension != num_classes + (rand(size=(7, 3, 3)), randint(high=2, size=(7, 3, 3)), 4, False), + # Multilabel input with is_multiclass=True, but num_classes != 2 (or None) + (rand(size=(7, 3)), randint(high=2, size=(7, 3)), 4, True), + # Binary input, num_classes > 2 + (rand(size=(7, )), randint(high=2, size=(7, )), 4, None), + # Binary input, num_classes == 2 and is_multiclass not True + (rand(size=(7, )), randint(high=2, size=(7, )), 2, None), + (rand(size=(7, )), randint(high=2, size=(7, )), 2, False), + # Binary input, num_classes == 1 and is_multiclass=True + (rand(size=(7, )), randint(high=2, size=(7, 
)), 1, True), + ], +) +def test_incorrect_inputs(preds, target, num_classes, is_multiclass): + with pytest.raises(ValueError): + _input_format_classification( + preds=preds, target=target, threshold=THRESHOLD, num_classes=num_classes, is_multiclass=is_multiclass + ) + + +@pytest.mark.parametrize( + "preds, target, num_classes, is_multiclass, top_k", + [ + # Topk set with non (md)mc or ml prob data + (_bin.preds[0], _bin.target[0], None, None, 2), + (_bin_prob.preds[0], _bin_prob.target[0], None, None, 2), + (_mc.preds[0], _mc.target[0], None, None, 2), + (_ml.preds[0], _ml.target[0], None, None, 2), + (_mlmd.preds[0], _mlmd.target[0], None, None, 2), + (_mdmc.preds[0], _mdmc.target[0], None, None, 2), + # top_k = 0 + (_mc_prob_2cls.preds[0], _mc_prob_2cls.target[0], None, None, 0), + # top_k = float + (_mc_prob_2cls.preds[0], _mc_prob_2cls.target[0], None, None, 0.123), + # top_k =2 with 2 classes, is_multiclass=False + (_mc_prob_2cls.preds[0], _mc_prob_2cls.target[0], None, False, 2), + # top_k = number of classes (C dimension) + (_mc_prob.preds[0], _mc_prob.target[0], None, None, NUM_CLASSES), + # is_multiclass = True for ml prob inputs, top_k set + (_ml_prob.preds[0], _ml_prob.target[0], None, True, 2), + # top_k = num_classes for ml prob inputs + (_ml_prob.preds[0], _ml_prob.target[0], None, True, NUM_CLASSES), + ], +) +def test_incorrect_inputs_topk(preds, target, num_classes, is_multiclass, top_k): + with pytest.raises(ValueError): + _input_format_classification( + preds=preds, + target=target, + threshold=THRESHOLD, + num_classes=num_classes, + is_multiclass=is_multiclass, + top_k=top_k, + ) diff --git a/tests/metrics/classification/test_iou.py b/tests/metrics/classification/test_iou.py new file mode 100644 index 0000000000000..6bb100f68165a --- /dev/null +++ b/tests/metrics/classification/test_iou.py @@ -0,0 +1,216 @@ +from functools import partial + +import numpy as np +import pytest +import torch +from sklearn.metrics import jaccard_score as sk_jaccard_score + +from pytorch_lightning.metrics.classification.iou import IoU +from pytorch_lightning.metrics.functional.iou import iou +from tests.metrics.classification.inputs import _input_binary, _input_binary_prob +from tests.metrics.classification.inputs import _input_multiclass as _input_mcls +from tests.metrics.classification.inputs import _input_multiclass_prob as _input_mcls_prob +from tests.metrics.classification.inputs import _input_multidim_multiclass as _input_mdmc +from tests.metrics.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob +from tests.metrics.classification.inputs import _input_multilabel as _input_mlb +from tests.metrics.classification.inputs import _input_multilabel_prob as _input_mlb_prob +from tests.metrics.utils import MetricTester, NUM_CLASSES, THRESHOLD + + +def _sk_iou_binary_prob(preds, target, average=None): + sk_preds = (preds.view(-1).numpy() >= THRESHOLD).astype(np.uint8) + sk_target = target.view(-1).numpy() + + return sk_jaccard_score(y_true=sk_target, y_pred=sk_preds, average=average) + + +def _sk_iou_binary(preds, target, average=None): + sk_preds = preds.view(-1).numpy() + sk_target = target.view(-1).numpy() + + return sk_jaccard_score(y_true=sk_target, y_pred=sk_preds, average=average) + + +def _sk_iou_multilabel_prob(preds, target, average=None): + sk_preds = (preds.view(-1).numpy() >= THRESHOLD).astype(np.uint8) + sk_target = target.view(-1).numpy() + + return sk_jaccard_score(y_true=sk_target, y_pred=sk_preds, average=average) + + +def _sk_iou_multilabel(preds, 
target, average=None): + sk_preds = preds.view(-1).numpy() + sk_target = target.view(-1).numpy() + + return sk_jaccard_score(y_true=sk_target, y_pred=sk_preds, average=average) + + +def _sk_iou_multiclass_prob(preds, target, average=None): + sk_preds = torch.argmax(preds, dim=len(preds.shape) - 1).view(-1).numpy() + sk_target = target.view(-1).numpy() + + return sk_jaccard_score(y_true=sk_target, y_pred=sk_preds, average=average) + + +def _sk_iou_multiclass(preds, target, average=None): + sk_preds = preds.view(-1).numpy() + sk_target = target.view(-1).numpy() + + return sk_jaccard_score(y_true=sk_target, y_pred=sk_preds, average=average) + + +def _sk_iou_multidim_multiclass_prob(preds, target, average=None): + sk_preds = torch.argmax(preds, dim=len(preds.shape) - 2).view(-1).numpy() + sk_target = target.view(-1).numpy() + + return sk_jaccard_score(y_true=sk_target, y_pred=sk_preds, average=average) + + +def _sk_iou_multidim_multiclass(preds, target, average=None): + sk_preds = preds.view(-1).numpy() + sk_target = target.view(-1).numpy() + + return sk_jaccard_score(y_true=sk_target, y_pred=sk_preds, average=average) + + +@pytest.mark.parametrize("reduction", ['elementwise_mean', 'none']) +@pytest.mark.parametrize( + "preds, target, sk_metric, num_classes", + [(_input_binary_prob.preds, _input_binary_prob.target, _sk_iou_binary_prob, 2), + (_input_binary.preds, _input_binary.target, _sk_iou_binary, 2), + (_input_mlb_prob.preds, _input_mlb_prob.target, _sk_iou_multilabel_prob, 2), + (_input_mlb.preds, _input_mlb.target, _sk_iou_multilabel, 2), + (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_iou_multiclass_prob, NUM_CLASSES), + (_input_mcls.preds, _input_mcls.target, _sk_iou_multiclass, NUM_CLASSES), + (_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_iou_multidim_multiclass_prob, NUM_CLASSES), + (_input_mdmc.preds, _input_mdmc.target, _sk_iou_multidim_multiclass, NUM_CLASSES)] +) +class TestIoU(MetricTester): + + @pytest.mark.parametrize("ddp", [True, False]) + @pytest.mark.parametrize("dist_sync_on_step", [True, False]) + def test_confusion_matrix(self, reduction, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step): + average = 'macro' if reduction == 'elementwise_mean' else None # convert tags + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=IoU, + sk_metric=partial(sk_metric, average=average), + dist_sync_on_step=dist_sync_on_step, + metric_args={ + "num_classes": num_classes, + "threshold": THRESHOLD, + "reduction": reduction + } + ) + + def test_confusion_matrix_functional(self, reduction, preds, target, sk_metric, num_classes): + average = 'macro' if reduction == 'elementwise_mean' else None # convert tags + self.run_functional_metric_test( + preds, + target, + metric_functional=iou, + sk_metric=partial(sk_metric, average=average), + metric_args={ + "num_classes": num_classes, + "threshold": THRESHOLD, + "reduction": reduction + } + ) + + +@pytest.mark.parametrize(['half_ones', 'reduction', 'ignore_index', 'expected'], [ + pytest.param(False, 'none', None, torch.Tensor([1, 1, 1])), + pytest.param(False, 'elementwise_mean', None, torch.Tensor([1])), + pytest.param(False, 'none', 0, torch.Tensor([1, 1])), + pytest.param(True, 'none', None, torch.Tensor([0.5, 0.5, 0.5])), + pytest.param(True, 'elementwise_mean', None, torch.Tensor([0.5])), + pytest.param(True, 'none', 0, torch.Tensor([0.5, 0.5])), +]) +def test_iou(half_ones, reduction, ignore_index, expected): + pred = (torch.arange(120) % 3).view(-1, 1) + target = 
(torch.arange(120) % 3).view(-1, 1) + if half_ones: + pred[:60] = 1 + iou_val = iou( + pred=pred, + target=target, + ignore_index=ignore_index, + reduction=reduction, + ) + assert torch.allclose(iou_val, expected, atol=1e-9) + + +# test `absent_score` +@pytest.mark.parametrize( + ['pred', 'target', 'ignore_index', 'absent_score', 'num_classes', 'expected'], + [ + # Note that -1 is used as the absent_score in almost all tests here to distinguish it from the range of valid + # scores the function can return ([0., 1.] range, inclusive). + # 2 classes, class 0 is correct everywhere, class 1 is absent. + pytest.param([0], [0], None, -1., 2, [1., -1.]), + pytest.param([0, 0], [0, 0], None, -1., 2, [1., -1.]), + # absent_score not applied if only class 0 is present and it's the only class. + pytest.param([0], [0], None, -1., 1, [1.]), + # 2 classes, class 1 is correct everywhere, class 0 is absent. + pytest.param([1], [1], None, -1., 2, [-1., 1.]), + pytest.param([1, 1], [1, 1], None, -1., 2, [-1., 1.]), + # When 0 index ignored, class 0 does not get a score (not even the absent_score). + pytest.param([1], [1], 0, -1., 2, [1.0]), + # 3 classes. Only 0 and 2 are present, and are perfectly predicted. 1 should get absent_score. + pytest.param([0, 2], [0, 2], None, -1., 3, [1., -1., 1.]), + pytest.param([2, 0], [2, 0], None, -1., 3, [1., -1., 1.]), + # 3 classes. Only 0 and 1 are present, and are perfectly predicted. 2 should get absent_score. + pytest.param([0, 1], [0, 1], None, -1., 3, [1., 1., -1.]), + pytest.param([1, 0], [1, 0], None, -1., 3, [1., 1., -1.]), + # 3 classes, class 0 is 0.5 IoU, class 1 is 0 IoU (in pred but not target; should not get absent_score), class + # 2 is absent. + pytest.param([0, 1], [0, 0], None, -1., 3, [0.5, 0., -1.]), + # 3 classes, class 0 is 0.5 IoU, class 1 is 0 IoU (in target but not pred; should not get absent_score), class + # 2 is absent. + pytest.param([0, 0], [0, 1], None, -1., 3, [0.5, 0., -1.]), + # Sanity checks with absent_score of 1.0. + pytest.param([0, 2], [0, 2], None, 1.0, 3, [1., 1., 1.]), + pytest.param([0, 2], [0, 2], 0, 1.0, 3, [1., 1.]), + ] +) +def test_iou_absent_score(pred, target, ignore_index, absent_score, num_classes, expected): + iou_val = iou( + pred=torch.tensor(pred), + target=torch.tensor(target), + ignore_index=ignore_index, + absent_score=absent_score, + num_classes=num_classes, + reduction='none', + ) + assert torch.allclose(iou_val, torch.tensor(expected).to(iou_val)) + + +# example data taken from +# https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/metrics/tests/test_ranking.py +@pytest.mark.parametrize( + ['pred', 'target', 'ignore_index', 'num_classes', 'reduction', 'expected'], + [ + # Ignoring an index outside of [0, num_classes-1] should have no effect. + pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], None, 3, 'none', [1, 1 / 2, 2 / 3]), + pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], -1, 3, 'none', [1, 1 / 2, 2 / 3]), + pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 255, 3, 'none', [1, 1 / 2, 2 / 3]), + # Ignoring a valid index drops only that index from the result. + pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, 'none', [1 / 2, 2 / 3]), + pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 1, 3, 'none', [1, 2 / 3]), + pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 2, 3, 'none', [1, 1 / 2]), + # When reducing to mean or sum, the ignored index does not contribute to the output. 
+ pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, 'elementwise_mean', [7 / 12]), + pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, 'sum', [7 / 6]), + ] +) +def test_iou_ignore_index(pred, target, ignore_index, num_classes, reduction, expected): + iou_val = iou( + pred=torch.tensor(pred), + target=torch.tensor(target), + ignore_index=ignore_index, + num_classes=num_classes, + reduction=reduction, + ) + assert torch.allclose(iou_val, torch.tensor(expected).to(iou_val)) diff --git a/tests/metrics/classification/test_precision_recall.py b/tests/metrics/classification/test_precision_recall.py new file mode 100644 index 0000000000000..a9bf39044174a --- /dev/null +++ b/tests/metrics/classification/test_precision_recall.py @@ -0,0 +1,347 @@ +from functools import partial +from typing import Callable, Optional + +import numpy as np +import pytest +import torch +from sklearn.metrics import precision_score, recall_score + +from pytorch_lightning.metrics import Metric, Precision, Recall +from pytorch_lightning.metrics.classification.helpers import _input_format_classification +from pytorch_lightning.metrics.functional import precision, precision_recall, recall +from tests.metrics.classification.inputs import _input_binary, _input_binary_prob +from tests.metrics.classification.inputs import _input_multiclass as _input_mcls +from tests.metrics.classification.inputs import _input_multiclass_prob as _input_mcls_prob +from tests.metrics.classification.inputs import _input_multidim_multiclass as _input_mdmc +from tests.metrics.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob +from tests.metrics.classification.inputs import _input_multilabel as _input_mlb +from tests.metrics.classification.inputs import _input_multilabel_prob as _input_mlb_prob +from tests.metrics.utils import MetricTester, NUM_CLASSES, THRESHOLD + +torch.manual_seed(42) + + +def _sk_prec_recall(preds, target, sk_fn, num_classes, average, is_multiclass, ignore_index, mdmc_average=None): + if average == "none": + average = None + if num_classes == 1: + average = "binary" + + labels = list(range(num_classes)) + try: + labels.remove(ignore_index) + except ValueError: + pass + + sk_preds, sk_target, _ = _input_format_classification( + preds, target, THRESHOLD, num_classes=num_classes, is_multiclass=is_multiclass + ) + sk_preds, sk_target = sk_preds.numpy(), sk_target.numpy() + + sk_scores = sk_fn(sk_target, sk_preds, average=average, zero_division=0, labels=labels) + + if len(labels) != num_classes and not average: + sk_scores = np.insert(sk_scores, ignore_index, np.nan) + + return sk_scores + + +def _sk_prec_recall_multidim_multiclass( + preds, target, sk_fn, num_classes, average, is_multiclass, ignore_index, mdmc_average +): + preds, target, _ = _input_format_classification( + preds, target, threshold=THRESHOLD, num_classes=num_classes, is_multiclass=is_multiclass + ) + + if mdmc_average == "global": + preds = torch.transpose(preds, 1, 2).reshape(-1, preds.shape[1]) + target = torch.transpose(target, 1, 2).reshape(-1, target.shape[1]) + + return _sk_prec_recall(preds, target, sk_fn, num_classes, average, False, ignore_index) + elif mdmc_average == "samplewise": + scores = [] + + for i in range(preds.shape[0]): + pred_i = preds[i, ...].T + target_i = target[i, ...].T + scores_i = _sk_prec_recall(pred_i, target_i, sk_fn, num_classes, average, False, ignore_index) + + scores.append(np.expand_dims(scores_i, 0)) + + return np.concatenate(scores).mean(axis=0) + + +@pytest.mark.parametrize("metric, 
fn_metric", [(Precision, precision), (Recall, recall)]) +@pytest.mark.parametrize( + "average, mdmc_average, num_classes, ignore_index, match_str", + [ + ("wrong", None, None, None, "`average`"), + ("micro", "wrong", None, None, "`mdmc"), + ("macro", None, None, None, "number of classes"), + ("macro", None, 1, 0, "ignore_index"), + ], +) +def test_wrong_params(metric, fn_metric, average, mdmc_average, num_classes, ignore_index, match_str): + with pytest.raises(ValueError, match=match_str): + metric( + average=average, + mdmc_average=mdmc_average, + num_classes=num_classes, + ignore_index=ignore_index, + ) + + with pytest.raises(ValueError, match=match_str): + fn_metric( + _input_binary.preds[0], + _input_binary.target[0], + average=average, + mdmc_average=mdmc_average, + num_classes=num_classes, + ignore_index=ignore_index, + ) + + with pytest.raises(ValueError, match=match_str): + precision_recall( + _input_binary.preds[0], + _input_binary.target[0], + average=average, + mdmc_average=mdmc_average, + num_classes=num_classes, + ignore_index=ignore_index, + ) + + +@pytest.mark.parametrize("metric_class, metric_fn", [(Recall, recall), (Precision, precision)]) +def test_zero_division(metric_class, metric_fn): + """ Test that zero_division works correctly (currently should just set to 0). """ + + preds = torch.tensor([1, 2, 1, 1]) + target = torch.tensor([2, 1, 2, 1]) + + cl_metric = metric_class(average="none", num_classes=3) + cl_metric(preds, target) + + result_cl = cl_metric.compute() + result_fn = metric_fn(preds, target, average="none", num_classes=3) + + assert result_cl[0] == result_fn[0] == 0 + + +@pytest.mark.parametrize("metric_class, metric_fn", [(Recall, recall), (Precision, precision)]) +def test_no_support(metric_class, metric_fn): + """This tests a rare edge case, where there is only one class present + in target, and ignore_index is set to exactly that class - and the + average method is equal to 'weighted'. + + This would mean that the sum of weights equals zero, and would, without + taking care of this case, return NaN. However, the reduction function + should catch that and set the metric to equal the value of zero_division + in this case (zero_division is for now not configurable and equals 0). 
+ """ + + preds = torch.tensor([1, 1, 0, 0]) + target = torch.tensor([0, 0, 0, 0]) + + cl_metric = metric_class(average="weighted", num_classes=2, ignore_index=0) + cl_metric(preds, target) + + result_cl = cl_metric.compute() + result_fn = metric_fn(preds, target, average="weighted", num_classes=2, ignore_index=0) + + assert result_cl == result_fn == 0 + + +@pytest.mark.parametrize( + "metric_class, metric_fn, sk_fn", [(Recall, recall, recall_score), (Precision, precision, precision_score)] +) +@pytest.mark.parametrize("average", ["micro", "macro", None, "weighted", "samples"]) +@pytest.mark.parametrize("ignore_index", [None, 0]) +@pytest.mark.parametrize( + "preds, target, num_classes, is_multiclass, mdmc_average, sk_wrapper", + [ + (_input_binary_prob.preds, _input_binary_prob.target, 1, None, None, _sk_prec_recall), + (_input_binary.preds, _input_binary.target, 1, False, None, _sk_prec_recall), + (_input_mlb_prob.preds, _input_mlb_prob.target, NUM_CLASSES, None, None, _sk_prec_recall), + (_input_mlb.preds, _input_mlb.target, NUM_CLASSES, False, None, _sk_prec_recall), + (_input_mcls_prob.preds, _input_mcls_prob.target, NUM_CLASSES, None, None, _sk_prec_recall), + (_input_mcls.preds, _input_mcls.target, NUM_CLASSES, None, None, _sk_prec_recall), + (_input_mdmc.preds, _input_mdmc.target, NUM_CLASSES, None, "global", _sk_prec_recall_multidim_multiclass), + ( + _input_mdmc_prob.preds, _input_mdmc_prob.target, NUM_CLASSES, None, "global", + _sk_prec_recall_multidim_multiclass + ), + (_input_mdmc.preds, _input_mdmc.target, NUM_CLASSES, None, "samplewise", _sk_prec_recall_multidim_multiclass), + ( + _input_mdmc_prob.preds, _input_mdmc_prob.target, NUM_CLASSES, None, "samplewise", + _sk_prec_recall_multidim_multiclass + ), + ], +) +class TestPrecisionRecall(MetricTester): + + @pytest.mark.parametrize("ddp", [False]) + @pytest.mark.parametrize("dist_sync_on_step", [True, False]) + def test_precision_recall_class( + self, + ddp: bool, + dist_sync_on_step: bool, + preds: torch.Tensor, + target: torch.Tensor, + sk_wrapper: Callable, + metric_class: Metric, + metric_fn: Callable, + sk_fn: Callable, + is_multiclass: Optional[bool], + num_classes: Optional[int], + average: str, + mdmc_average: Optional[str], + ignore_index: Optional[int], + ): + if num_classes == 1 and average != "micro": + pytest.skip("Only test binary data for 'micro' avg (equivalent of 'binary' in sklearn)") + + if ignore_index is not None and preds.ndim == 2: + pytest.skip("Skipping ignore_index test with binary inputs.") + + if average == "weighted" and ignore_index is not None and mdmc_average is not None: + pytest.skip("Ignore special case where we are ignoring entire sample for 'weighted' average") + + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=metric_class, + sk_metric=partial( + sk_wrapper, + sk_fn=sk_fn, + average=average, + num_classes=num_classes, + is_multiclass=is_multiclass, + ignore_index=ignore_index, + mdmc_average=mdmc_average, + ), + dist_sync_on_step=dist_sync_on_step, + metric_args={ + "num_classes": num_classes, + "average": average, + "threshold": THRESHOLD, + "is_multiclass": is_multiclass, + "ignore_index": ignore_index, + "mdmc_average": mdmc_average, + }, + check_dist_sync_on_step=True, + check_batch=True, + ) + + def test_precision_recall_fn( + self, + preds: torch.Tensor, + target: torch.Tensor, + sk_wrapper: Callable, + metric_class: Metric, + metric_fn: Callable, + sk_fn: Callable, + is_multiclass: Optional[bool], + num_classes: Optional[int], + average: str, 
+        mdmc_average: Optional[str],
+        ignore_index: Optional[int],
+    ):
+        if num_classes == 1 and average != "micro":
+            pytest.skip("Only test binary data for 'micro' avg (equivalent of 'binary' in sklearn)")
+
+        if ignore_index is not None and preds.ndim == 2:
+            pytest.skip("Skipping ignore_index test with binary inputs.")
+
+        if average == "weighted" and ignore_index is not None and mdmc_average is not None:
+            pytest.skip("Ignore special case where we are ignoring entire sample for 'weighted' average")
+
+        self.run_functional_metric_test(
+            preds,
+            target,
+            metric_functional=metric_fn,
+            sk_metric=partial(
+                sk_wrapper,
+                sk_fn=sk_fn,
+                average=average,
+                num_classes=num_classes,
+                is_multiclass=is_multiclass,
+                ignore_index=ignore_index,
+                mdmc_average=mdmc_average,
+            ),
+            metric_args={
+                "num_classes": num_classes,
+                "average": average,
+                "threshold": THRESHOLD,
+                "is_multiclass": is_multiclass,
+                "ignore_index": ignore_index,
+                "mdmc_average": mdmc_average,
+            },
+        )
+
+
+@pytest.mark.parametrize("average", ["micro", "macro", None, "weighted", "samples"])
+def test_precision_recall_joint(average):
+    """A simple test of the joint precision_recall metric.
+
+    No need to test this thoroughly, as it is just a combination of precision and recall,
+    which are already tested thoroughly.
+    """
+
+    precision_result = precision(
+        _input_mcls_prob.preds[0], _input_mcls_prob.target[0], average=average, num_classes=NUM_CLASSES
+    )
+    recall_result = recall(
+        _input_mcls_prob.preds[0], _input_mcls_prob.target[0], average=average, num_classes=NUM_CLASSES
+    )
+
+    prec_recall_result = precision_recall(
+        _input_mcls_prob.preds[0], _input_mcls_prob.target[0], average=average, num_classes=NUM_CLASSES
+    )
+
+    assert torch.equal(precision_result, prec_recall_result[0])
+    assert torch.equal(recall_result, prec_recall_result[1])
+
+
+_mc_k_target = torch.tensor([0, 1, 2])
+_mc_k_preds = torch.tensor([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]])
+_ml_k_target = torch.tensor([[0, 1, 0], [1, 1, 0], [0, 0, 0]])
+_ml_k_preds = torch.tensor([[0.9, 0.2, 0.75], [0.1, 0.7, 0.8], [0.6, 0.1, 0.7]])
+
+
+@pytest.mark.parametrize("metric_class, metric_fn", [(Recall, recall), (Precision, precision)])
+@pytest.mark.parametrize(
+    "k, preds, target, average, expected_prec, expected_recall",
+    [
+        (1, _mc_k_preds, _mc_k_target, "micro", torch.tensor(2 / 3), torch.tensor(2 / 3)),
+        (2, _mc_k_preds, _mc_k_target, "micro", torch.tensor(1 / 2), torch.tensor(1.0)),
+        (1, _ml_k_preds, _ml_k_target, "micro", torch.tensor(0.0), torch.tensor(0.0)),
+        (2, _ml_k_preds, _ml_k_target, "micro", torch.tensor(1 / 6), torch.tensor(1 / 3)),
+    ],
+)
+def test_top_k(
+    metric_class,
+    metric_fn,
+    k: int,
+    preds: torch.Tensor,
+    target: torch.Tensor,
+    average: str,
+    expected_prec: torch.Tensor,
+    expected_recall: torch.Tensor,
+):
+    """A simple test to check that top_k works as expected.
+
+    Just a sanity check, the tests in StatScores should already guarantee
+    the correctness of results.
+ """ + + class_metric = metric_class(top_k=k, average=average, num_classes=3) + class_metric.update(preds, target) + + if metric_class.__name__ == "Precision": + result = expected_prec + else: + result = expected_recall + + assert torch.equal(class_metric.compute(), result) + assert torch.equal(metric_fn(preds, target, top_k=k, average=average, num_classes=3), result) diff --git a/tests/metrics/classification/test_precision_recall_curve.py b/tests/metrics/classification/test_precision_recall_curve.py new file mode 100644 index 0000000000000..6a60e1fd36fdd --- /dev/null +++ b/tests/metrics/classification/test_precision_recall_curve.py @@ -0,0 +1,97 @@ +from functools import partial + +import numpy as np +import pytest +import torch +from sklearn.metrics import precision_recall_curve as sk_precision_recall_curve + +from pytorch_lightning.metrics.classification.precision_recall_curve import PrecisionRecallCurve +from pytorch_lightning.metrics.functional.precision_recall_curve import precision_recall_curve +from tests.metrics.classification.inputs import _input_binary_prob +from tests.metrics.classification.inputs import _input_multiclass_prob as _input_mcls_prob +from tests.metrics.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob +from tests.metrics.utils import MetricTester, NUM_CLASSES + +torch.manual_seed(42) + + +def _sk_precision_recall_curve(y_true, probas_pred, num_classes=1): + """ Adjusted comparison function that can also handles multiclass """ + if num_classes == 1: + return sk_precision_recall_curve(y_true, probas_pred) + + precision, recall, thresholds = [], [], [] + for i in range(num_classes): + y_true_temp = np.zeros_like(y_true) + y_true_temp[y_true == i] = 1 + res = sk_precision_recall_curve(y_true_temp, probas_pred[:, i]) + precision.append(res[0]) + recall.append(res[1]) + thresholds.append(res[2]) + return precision, recall, thresholds + + +def _sk_prec_rc_binary_prob(preds, target, num_classes=1): + sk_preds = preds.view(-1).numpy() + sk_target = target.view(-1).numpy() + + return _sk_precision_recall_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes) + + +def _sk_prec_rc_multiclass_prob(preds, target, num_classes=1): + sk_preds = preds.reshape(-1, num_classes).numpy() + sk_target = target.view(-1).numpy() + + return _sk_precision_recall_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes) + + +def _sk_prec_rc_multidim_multiclass_prob(preds, target, num_classes=1): + sk_preds = preds.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1).numpy() + sk_target = target.view(-1).numpy() + return _sk_precision_recall_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes) + + +@pytest.mark.parametrize( + "preds, target, sk_metric, num_classes", [ + (_input_binary_prob.preds, _input_binary_prob.target, _sk_prec_rc_binary_prob, 1), + (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_prec_rc_multiclass_prob, NUM_CLASSES), + (_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_prec_rc_multidim_multiclass_prob, NUM_CLASSES), + ] +) +class TestPrecisionRecallCurve(MetricTester): + + @pytest.mark.parametrize("ddp", [True, False]) + @pytest.mark.parametrize("dist_sync_on_step", [True, False]) + def test_precision_recall_curve(self, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step): + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=PrecisionRecallCurve, + sk_metric=partial(sk_metric, num_classes=num_classes), + 
dist_sync_on_step=dist_sync_on_step, + metric_args={"num_classes": num_classes} + ) + + def test_precision_recall_curve_functional(self, preds, target, sk_metric, num_classes): + self.run_functional_metric_test( + preds, + target, + metric_functional=precision_recall_curve, + sk_metric=partial(sk_metric, num_classes=num_classes), + metric_args={"num_classes": num_classes}, + ) + + +@pytest.mark.parametrize( + ['pred', 'target', 'expected_p', 'expected_r', 'expected_t'], + [pytest.param([1, 2, 3, 4], [1, 0, 0, 1], [0.5, 1 / 3, 0.5, 1., 1.], [1, 0.5, 0.5, 0.5, 0.], [1, 2, 3, 4])] +) +def test_pr_curve(pred, target, expected_p, expected_r, expected_t): + p, r, t = precision_recall_curve(torch.tensor(pred), torch.tensor(target)) + assert p.size() == r.size() + assert p.size(0) == t.size(0) + 1 + + assert torch.allclose(p, torch.tensor(expected_p).to(p)) + assert torch.allclose(r, torch.tensor(expected_r).to(r)) + assert torch.allclose(t, torch.tensor(expected_t).to(t)) diff --git a/tests/metrics/classification/test_roc.py b/tests/metrics/classification/test_roc.py new file mode 100644 index 0000000000000..46a23322ca1c0 --- /dev/null +++ b/tests/metrics/classification/test_roc.py @@ -0,0 +1,99 @@ +from functools import partial + +import numpy as np +import pytest +import torch +from sklearn.metrics import roc_curve as sk_roc_curve + +from pytorch_lightning.metrics.classification.roc import ROC +from pytorch_lightning.metrics.functional.roc import roc +from tests.metrics.classification.inputs import _input_binary_prob +from tests.metrics.classification.inputs import _input_multiclass_prob as _input_mcls_prob +from tests.metrics.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob +from tests.metrics.utils import MetricTester, NUM_CLASSES + +torch.manual_seed(42) + + +def _sk_roc_curve(y_true, probas_pred, num_classes=1): + """ Adjusted comparison function that can also handles multiclass """ + if num_classes == 1: + return sk_roc_curve(y_true, probas_pred, drop_intermediate=False) + + fpr, tpr, thresholds = [], [], [] + for i in range(num_classes): + y_true_temp = np.zeros_like(y_true) + y_true_temp[y_true == i] = 1 + res = sk_roc_curve(y_true_temp, probas_pred[:, i], drop_intermediate=False) + fpr.append(res[0]) + tpr.append(res[1]) + thresholds.append(res[2]) + return fpr, tpr, thresholds + + +def _sk_roc_binary_prob(preds, target, num_classes=1): + sk_preds = preds.view(-1).numpy() + sk_target = target.view(-1).numpy() + + return _sk_roc_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes) + + +def _sk_roc_multiclass_prob(preds, target, num_classes=1): + sk_preds = preds.reshape(-1, num_classes).numpy() + sk_target = target.view(-1).numpy() + + return _sk_roc_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes) + + +def _sk_roc_multidim_multiclass_prob(preds, target, num_classes=1): + sk_preds = preds.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1).numpy() + sk_target = target.view(-1).numpy() + return _sk_roc_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes) + + +@pytest.mark.parametrize( + "preds, target, sk_metric, num_classes", [ + (_input_binary_prob.preds, _input_binary_prob.target, _sk_roc_binary_prob, 1), + (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_roc_multiclass_prob, NUM_CLASSES), + (_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_roc_multidim_multiclass_prob, NUM_CLASSES), + ] +) +class TestROC(MetricTester): + + @pytest.mark.parametrize("ddp", [True, False]) + 
@pytest.mark.parametrize("dist_sync_on_step", [True, False]) + def test_roc(self, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step): + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=ROC, + sk_metric=partial(sk_metric, num_classes=num_classes), + dist_sync_on_step=dist_sync_on_step, + metric_args={"num_classes": num_classes} + ) + + def test_roc_functional(self, preds, target, sk_metric, num_classes): + self.run_functional_metric_test( + preds, + target, + metric_functional=roc, + sk_metric=partial(sk_metric, num_classes=num_classes), + metric_args={"num_classes": num_classes}, + ) + + +@pytest.mark.parametrize(['pred', 'target', 'expected_tpr', 'expected_fpr'], [ + pytest.param([0, 1], [0, 1], [0, 1, 1], [0, 0, 1]), + pytest.param([1, 0], [0, 1], [0, 0, 1], [0, 1, 1]), + pytest.param([1, 1], [1, 0], [0, 1], [0, 1]), + pytest.param([1, 0], [1, 0], [0, 1, 1], [0, 0, 1]), + pytest.param([0.5, 0.5], [0, 1], [0, 1], [0, 1]), +]) +def test_roc_curve(pred, target, expected_tpr, expected_fpr): + fpr, tpr, thresh = roc(torch.tensor(pred), torch.tensor(target)) + + assert fpr.shape == tpr.shape + assert fpr.size(0) == thresh.size(0) + assert torch.allclose(fpr, torch.tensor(expected_fpr).to(fpr)) + assert torch.allclose(tpr, torch.tensor(expected_tpr).to(tpr)) diff --git a/tests/metrics/classification/test_stat_scores.py b/tests/metrics/classification/test_stat_scores.py new file mode 100644 index 0000000000000..659765931c433 --- /dev/null +++ b/tests/metrics/classification/test_stat_scores.py @@ -0,0 +1,255 @@ +from functools import partial +from typing import Callable, Optional + +import numpy as np +import pytest +import torch +from sklearn.metrics import multilabel_confusion_matrix + +from pytorch_lightning.metrics import StatScores +from pytorch_lightning.metrics.classification.helpers import _input_format_classification +from pytorch_lightning.metrics.functional import stat_scores +from tests.metrics.classification.inputs import _input_binary, _input_binary_prob, _input_multiclass +from tests.metrics.classification.inputs import _input_multiclass_prob as _input_mccls_prob +from tests.metrics.classification.inputs import _input_multidim_multiclass as _input_mdmc +from tests.metrics.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob +from tests.metrics.classification.inputs import _input_multilabel as _input_mcls +from tests.metrics.classification.inputs import _input_multilabel_prob as _input_mlb_prob +from tests.metrics.utils import MetricTester, NUM_CLASSES, THRESHOLD + +torch.manual_seed(42) + + +def _sk_stat_scores(preds, target, reduce, num_classes, is_multiclass, ignore_index, top_k, mdmc_reduce=None): + preds, target, _ = _input_format_classification( + preds, target, threshold=THRESHOLD, num_classes=num_classes, is_multiclass=is_multiclass, top_k=top_k + ) + sk_preds, sk_target = preds.numpy(), target.numpy() + + if reduce != "macro" and ignore_index is not None and preds.shape[1] > 1: + sk_preds = np.delete(sk_preds, ignore_index, 1) + sk_target = np.delete(sk_target, ignore_index, 1) + + if preds.shape[1] == 1 and reduce == "samples": + sk_target = sk_target.T + sk_preds = sk_preds.T + + sk_stats = multilabel_confusion_matrix( + sk_target, sk_preds, samplewise=(reduce == "samples") and preds.shape[1] != 1 + ) + + if preds.shape[1] == 1 and reduce != "samples": + sk_stats = sk_stats[[1]].reshape(-1, 4)[:, [3, 1, 0, 2]] + else: + sk_stats = sk_stats.reshape(-1, 4)[:, [3, 1, 0, 2]] + + if reduce == "micro": 
+        sk_stats = sk_stats.sum(axis=0, keepdims=True)
+
+    sk_stats = np.concatenate([sk_stats, sk_stats[:, [3]] + sk_stats[:, [0]]], 1)
+
+    if reduce == "micro":
+        sk_stats = sk_stats[0]
+
+    if reduce == "macro" and ignore_index is not None and preds.shape[1]:
+        sk_stats[ignore_index, :] = -1
+
+    return sk_stats
+
+
+def _sk_stat_scores_mdim_mcls(preds, target, reduce, mdmc_reduce, num_classes, is_multiclass, ignore_index, top_k):
+    preds, target, _ = _input_format_classification(
+        preds, target, threshold=THRESHOLD, num_classes=num_classes, is_multiclass=is_multiclass, top_k=top_k
+    )
+
+    if mdmc_reduce == "global":
+        preds = torch.transpose(preds, 1, 2).reshape(-1, preds.shape[1])
+        target = torch.transpose(target, 1, 2).reshape(-1, target.shape[1])
+
+        return _sk_stat_scores(preds, target, reduce, None, False, ignore_index, top_k)
+    elif mdmc_reduce == "samplewise":
+        scores = []
+
+        for i in range(preds.shape[0]):
+            pred_i = preds[i, ...].T
+            target_i = target[i, ...].T
+            scores_i = _sk_stat_scores(pred_i, target_i, reduce, None, False, ignore_index, top_k)
+
+            scores.append(np.expand_dims(scores_i, 0))
+
+        return np.concatenate(scores)
+
+
+@pytest.mark.parametrize(
+    "reduce, mdmc_reduce, num_classes, inputs, ignore_index",
+    [
+        ["unknown", None, None, _input_binary, None],
+        ["micro", "unknown", None, _input_binary, None],
+        ["macro", None, None, _input_binary, None],
+        ["micro", None, None, _input_mdmc_prob, None],
+        ["micro", None, None, _input_binary_prob, 0],
+        ["micro", None, None, _input_mccls_prob, NUM_CLASSES],
+        ["micro", None, NUM_CLASSES, _input_mccls_prob, NUM_CLASSES],
+    ],
+)
+def test_wrong_params(reduce, mdmc_reduce, num_classes, inputs, ignore_index):
+    """Test a combination of parameters that are invalid and should raise an error.
+
+    This includes invalid ``reduce`` and ``mdmc_reduce`` parameter values, not setting
+    ``num_classes`` when ``reduce='macro'``, not setting ``mdmc_reduce`` when inputs
+    are multi-dim multi-class, setting ``ignore_index`` when inputs are binary, as well
+    as setting ``ignore_index`` to a value higher than the number of classes.
+ """ + with pytest.raises(ValueError): + stat_scores( + inputs.preds[0], inputs.target[0], reduce, mdmc_reduce, num_classes=num_classes, ignore_index=ignore_index + ) + + with pytest.raises(ValueError): + sts = StatScores(reduce=reduce, mdmc_reduce=mdmc_reduce, num_classes=num_classes, ignore_index=ignore_index) + sts(inputs.preds[0], inputs.target[0]) + + +def test_wrong_threshold(): + with pytest.raises(ValueError): + StatScores(threshold=1.5) + + +@pytest.mark.parametrize("ignore_index", [None, 0]) +@pytest.mark.parametrize("reduce", ["micro", "macro", "samples"]) +@pytest.mark.parametrize( + "preds, target, sk_fn, mdmc_reduce, num_classes, is_multiclass, top_k", + [ + (_input_binary_prob.preds, _input_binary_prob.target, _sk_stat_scores, None, 1, None, None), + (_input_binary.preds, _input_binary.target, _sk_stat_scores, None, 1, False, None), + (_input_mlb_prob.preds, _input_mlb_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, None), + (_input_mlb_prob.preds, _input_mlb_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, 2), + (_input_mcls.preds, _input_mcls.target, _sk_stat_scores, None, NUM_CLASSES, False, None), + (_input_mccls_prob.preds, _input_mccls_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, None), + (_input_mccls_prob.preds, _input_mccls_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, 2), + (_input_multiclass.preds, _input_multiclass.target, _sk_stat_scores, None, NUM_CLASSES, None, None), + (_input_mdmc.preds, _input_mdmc.target, _sk_stat_scores_mdim_mcls, "samplewise", NUM_CLASSES, None, None), + ( + _input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_stat_scores_mdim_mcls, "samplewise", NUM_CLASSES, None, + None + ), + (_input_mdmc.preds, _input_mdmc.target, _sk_stat_scores_mdim_mcls, "global", NUM_CLASSES, None, None), + (_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_stat_scores_mdim_mcls, "global", NUM_CLASSES, None, None), + ], +) +class TestStatScores(MetricTester): + # DDP tests temporarily disabled due to hanging issues + @pytest.mark.parametrize("ddp", [False]) + @pytest.mark.parametrize("dist_sync_on_step", [True, False]) + def test_stat_scores_class( + self, + ddp: bool, + dist_sync_on_step: bool, + sk_fn: Callable, + preds: torch.Tensor, + target: torch.Tensor, + reduce: str, + mdmc_reduce: Optional[str], + num_classes: Optional[int], + is_multiclass: Optional[bool], + ignore_index: Optional[int], + top_k: Optional[int], + ): + if ignore_index is not None and preds.ndim == 2: + pytest.skip("Skipping ignore_index test with binary inputs.") + + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=StatScores, + sk_metric=partial( + sk_fn, + reduce=reduce, + mdmc_reduce=mdmc_reduce, + num_classes=num_classes, + is_multiclass=is_multiclass, + ignore_index=ignore_index, + top_k=top_k, + ), + dist_sync_on_step=dist_sync_on_step, + metric_args={ + "num_classes": num_classes, + "reduce": reduce, + "mdmc_reduce": mdmc_reduce, + "threshold": THRESHOLD, + "is_multiclass": is_multiclass, + "ignore_index": ignore_index, + "top_k": top_k, + }, + check_dist_sync_on_step=True, + check_batch=True, + ) + + def test_stat_scores_fn( + self, + sk_fn: Callable, + preds: torch.Tensor, + target: torch.Tensor, + reduce: str, + mdmc_reduce: Optional[str], + num_classes: Optional[int], + is_multiclass: Optional[bool], + ignore_index: Optional[int], + top_k: Optional[int], + ): + if ignore_index is not None and preds.ndim == 2: + pytest.skip("Skipping ignore_index test with binary inputs.") + + 
self.run_functional_metric_test( + preds, + target, + metric_functional=stat_scores, + sk_metric=partial( + sk_fn, + reduce=reduce, + mdmc_reduce=mdmc_reduce, + num_classes=num_classes, + is_multiclass=is_multiclass, + ignore_index=ignore_index, + top_k=top_k, + ), + metric_args={ + "num_classes": num_classes, + "reduce": reduce, + "mdmc_reduce": mdmc_reduce, + "threshold": THRESHOLD, + "is_multiclass": is_multiclass, + "ignore_index": ignore_index, + "top_k": top_k, + }, + ) + + +_mc_k_target = torch.tensor([0, 1, 2]) +_mc_k_preds = torch.tensor([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]]) +_ml_k_target = torch.tensor([[0, 1, 0], [1, 1, 0], [0, 0, 0]]) +_ml_k_preds = torch.tensor([[0.9, 0.2, 0.75], [0.1, 0.7, 0.8], [0.6, 0.1, 0.7]]) + + +@pytest.mark.parametrize( + "k, preds, target, reduce, expected", + [ + (1, _mc_k_preds, _mc_k_target, "micro", torch.tensor([2, 1, 5, 1, 3])), + (2, _mc_k_preds, _mc_k_target, "micro", torch.tensor([3, 3, 3, 0, 3])), + (1, _ml_k_preds, _ml_k_target, "micro", torch.tensor([0, 3, 3, 3, 3])), + (2, _ml_k_preds, _ml_k_target, "micro", torch.tensor([1, 5, 1, 2, 3])), + (1, _mc_k_preds, _mc_k_target, "macro", torch.tensor([[0, 1, 1], [0, 1, 0], [2, 1, 2], [1, 0, 0], [1, 1, 1]])), + (2, _mc_k_preds, _mc_k_target, "macro", torch.tensor([[1, 1, 1], [1, 1, 1], [1, 1, 1], [0, 0, 0], [1, 1, 1]])), + (1, _ml_k_preds, _ml_k_target, "macro", torch.tensor([[0, 0, 0], [1, 0, 2], [1, 1, 1], [1, 2, 0], [1, 2, 0]])), + (2, _ml_k_preds, _ml_k_target, "macro", torch.tensor([[0, 1, 0], [2, 0, 3], [0, 1, 0], [1, 1, 0], [1, 2, 0]])), + ], +) +def test_top_k(k: int, preds: torch.Tensor, target: torch.Tensor, reduce: str, expected: torch.Tensor): + """ A simple test to check that top_k works as expected """ + + class_metric = StatScores(top_k=k, reduce=reduce, num_classes=3) + class_metric.update(preds, target) + + assert torch.equal(class_metric.compute(), expected.T) + assert torch.equal(stat_scores(preds, target, top_k=k, reduce=reduce, num_classes=3), expected.T) diff --git a/tests/metrics/functional/__init__.py b/tests/metrics/functional/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/metrics/functional/test_classification.py b/tests/metrics/functional/test_classification.py new file mode 100644 index 0000000000000..39622c4cd3550 --- /dev/null +++ b/tests/metrics/functional/test_classification.py @@ -0,0 +1,89 @@ +import pytest +import torch + +from pytorch_lightning import seed_everything +from pytorch_lightning.metrics.functional.classification import dice_score +from pytorch_lightning.metrics.functional.precision_recall_curve import _binary_clf_curve +from pytorch_lightning.metrics.utils import get_num_classes, to_categorical, to_onehot + + +def test_onehot(): + test_tensor = torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]) + expected = torch.stack([ + torch.cat([torch.eye(5, dtype=int), torch.zeros((5, 5), dtype=int)]), + torch.cat([torch.zeros((5, 5), dtype=int), torch.eye(5, dtype=int)]) + ]) + + assert test_tensor.shape == (2, 5) + assert expected.shape == (2, 10, 5) + + onehot_classes = to_onehot(test_tensor, num_classes=10) + onehot_no_classes = to_onehot(test_tensor) + + assert torch.allclose(onehot_classes, onehot_no_classes) + + assert onehot_classes.shape == expected.shape + assert onehot_no_classes.shape == expected.shape + + assert torch.allclose(expected.to(onehot_no_classes), onehot_no_classes) + assert torch.allclose(expected.to(onehot_classes), onehot_classes) + + +def test_to_categorical(): + test_tensor = 
torch.stack([ + torch.cat([torch.eye(5, dtype=int), torch.zeros((5, 5), dtype=int)]), + torch.cat([torch.zeros((5, 5), dtype=int), torch.eye(5, dtype=int)]) + ]).to(torch.float) + + expected = torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]) + assert expected.shape == (2, 5) + assert test_tensor.shape == (2, 10, 5) + + result = to_categorical(test_tensor) + + assert result.shape == expected.shape + assert torch.allclose(result, expected.to(result.dtype)) + + +@pytest.mark.parametrize(['pred', 'target', 'num_classes', 'expected_num_classes'], [ + pytest.param(torch.rand(32, 10, 28, 28), torch.randint(10, (32, 28, 28)), 10, 10), + pytest.param(torch.rand(32, 10, 28, 28), torch.randint(10, (32, 28, 28)), None, 10), + pytest.param(torch.rand(32, 28, 28), torch.randint(10, (32, 28, 28)), None, 10), +]) +def test_get_num_classes(pred, target, num_classes, expected_num_classes): + assert get_num_classes(pred, target, num_classes) == expected_num_classes + + +@pytest.mark.parametrize(['sample_weight', 'pos_label', "exp_shape"], [ + pytest.param(1, 1., 42), + pytest.param(None, 1., 42), +]) +def test_binary_clf_curve(sample_weight, pos_label, exp_shape): + # TODO: move back the pred and target to test func arguments + # if you fix the array inside the function, you'd also have fix the shape, + # because when the array changes, you also have to fix the shape + seed_everything(0) + pred = torch.randint(low=51, high=99, size=(100, ), dtype=torch.float) / 100 + target = torch.tensor([0, 1] * 50, dtype=torch.int) + if sample_weight is not None: + sample_weight = torch.ones_like(pred) * sample_weight + + fps, tps, thresh = _binary_clf_curve(preds=pred, target=target, sample_weights=sample_weight, pos_label=pos_label) + + assert isinstance(tps, torch.Tensor) + assert isinstance(fps, torch.Tensor) + assert isinstance(thresh, torch.Tensor) + assert tps.shape == (exp_shape, ) + assert fps.shape == (exp_shape, ) + assert thresh.shape == (exp_shape, ) + + +@pytest.mark.parametrize(['pred', 'target', 'expected'], [ + pytest.param([[0, 0], [1, 1]], [[0, 0], [1, 1]], 1.), + pytest.param([[1, 1], [0, 0]], [[0, 0], [1, 1]], 0.), + pytest.param([[1, 1], [1, 1]], [[1, 1], [0, 0]], 2 / 3), + pytest.param([[1, 1], [0, 0]], [[1, 1], [0, 0]], 1.), +]) +def test_dice_score(pred, target, expected): + score = dice_score(torch.tensor(pred), torch.tensor(target)) + assert score == expected diff --git a/tests/metrics/functional/test_image_gradients.py b/tests/metrics/functional/test_image_gradients.py new file mode 100644 index 0000000000000..2e406793b4370 --- /dev/null +++ b/tests/metrics/functional/test_image_gradients.py @@ -0,0 +1,109 @@ +import pytest +import torch + +from pytorch_lightning.metrics.functional.image_gradients import image_gradients + + +def test_invalid_input_img_type(): + """Test Whether the module successfully handles invalid input data type""" + invalid_dummy_input = [[1, 2, 3], [4, 5, 6], [7, 8, 9]] + + with pytest.raises(TypeError): + image_gradients(invalid_dummy_input) + + +def test_invalid_input_ndims(): + """ + Test whether the module successfully handles invalid number of dimensions + of input tensor + """ + + BATCH_SIZE = 1 + HEIGHT = 5 + WIDTH = 5 + CHANNELS = 1 + + image = torch.arange(0, BATCH_SIZE * HEIGHT * WIDTH * CHANNELS, dtype=torch.float32) + image = torch.reshape(image, (HEIGHT, WIDTH)) + + with pytest.raises(RuntimeError): + image_gradients(image) + + +def test_multi_batch_image_gradients(): + """Test whether the module correctly calculates gradients for known input + with non-unity 
batch size.Example input-output pair taken from TF's implementation of i + mage-gradients + """ + + BATCH_SIZE = 5 + HEIGHT = 5 + WIDTH = 5 + CHANNELS = 1 + + single_channel_img = torch.arange(0, 1 * HEIGHT * WIDTH * CHANNELS, dtype=torch.float32) + single_channel_img = torch.reshape(single_channel_img, (CHANNELS, HEIGHT, WIDTH)) + image = torch.stack([single_channel_img for _ in range(BATCH_SIZE)], dim=0) + + true_dy = [ + [5., 5., 5., 5., 5.], + [5., 5., 5., 5., 5.], + [5., 5., 5., 5., 5.], + [5., 5., 5., 5., 5.], + [0., 0., 0., 0., 0.], + ] + + true_dx = [ + [1., 1., 1., 1., 0.], + [1., 1., 1., 1., 0.], + [1., 1., 1., 1., 0.], + [1., 1., 1., 1., 0.], + [1., 1., 1., 1., 0.], + ] + true_dy = torch.Tensor(true_dy) + true_dx = torch.Tensor(true_dx) + + dy, dx = image_gradients(image) + + for batch_id in range(BATCH_SIZE): + assert torch.allclose(dy[batch_id, 0, :, :], true_dy) + assert dy.shape == (BATCH_SIZE, 1, HEIGHT, WIDTH) + assert dx.shape == (BATCH_SIZE, 1, HEIGHT, WIDTH) + + +def test_image_gradients(): + """Test whether the module correctly calculates gradients for known input. + Example input-output pair taken from TF's implementation of image-gradients + """ + + BATCH_SIZE = 1 + HEIGHT = 5 + WIDTH = 5 + CHANNELS = 1 + + image = torch.arange(0, BATCH_SIZE * HEIGHT * WIDTH * CHANNELS, dtype=torch.float32) + image = torch.reshape(image, (BATCH_SIZE, CHANNELS, HEIGHT, WIDTH)) + + true_dy = [ + [5., 5., 5., 5., 5.], + [5., 5., 5., 5., 5.], + [5., 5., 5., 5., 5.], + [5., 5., 5., 5., 5.], + [0., 0., 0., 0., 0.], + ] + + true_dx = [ + [1., 1., 1., 1., 0.], + [1., 1., 1., 1., 0.], + [1., 1., 1., 1., 0.], + [1., 1., 1., 1., 0.], + [1., 1., 1., 1., 0.], + ] + + true_dy = torch.Tensor(true_dy) + true_dx = torch.Tensor(true_dx) + + dy, dx = image_gradients(image) + + assert torch.allclose(dy, true_dy), "dy fails test" + assert torch.allclose(dx, true_dx), "dx fails tests" diff --git a/tests/metrics/functional/test_nlp.py b/tests/metrics/functional/test_nlp.py new file mode 100644 index 0000000000000..b8faadc16085f --- /dev/null +++ b/tests/metrics/functional/test_nlp.py @@ -0,0 +1,68 @@ +import pytest +import torch +from nltk.translate.bleu_score import corpus_bleu, sentence_bleu, SmoothingFunction + +from pytorch_lightning.metrics.functional.nlp import bleu_score + +# example taken from +# https://www.nltk.org/api/nltk.translate.html?highlight=bleu%20score#nltk.translate.bleu_score.sentence_bleu +HYPOTHESIS1 = tuple( + "It is a guide to action which ensures that the military always obeys the commands of the party".split() +) +REFERENCE1 = tuple("It is a guide to action that ensures that the military will forever heed Party commands".split()) +REFERENCE2 = tuple( + "It is a guiding principle which makes the military forces always being under the command of the Party".split() +) +REFERENCE3 = tuple("It is the practical guide for the army always to heed the directions of the party".split()) + +# example taken from +# https://www.nltk.org/api/nltk.translate.html?highlight=bleu%20score#nltk.translate.bleu_score.corpus_bleu +HYP1 = "It is a guide to action which ensures that the military always obeys the commands of the party".split() +HYP2 = "he read the book because he was interested in world history".split() + +REF1A = "It is a guide to action that ensures that the military will forever heed Party commands".split() +REF1B = "It is a guiding principle which makes the military force always being under the command of the Party".split() +REF1C = "It is the practical guide for the army always to 
heed the directions of the party".split() +REF2A = "he was interested in world history because he read the book".split() + +LIST_OF_REFERENCES = [[REF1A, REF1B, REF1C], [REF2A]] +HYPOTHESES = [HYP1, HYP2] + +# https://www.nltk.org/api/nltk.translate.html?highlight=bleu%20score#nltk.translate.bleu_score.SmoothingFunction +smooth_func = SmoothingFunction().method2 + + +@pytest.mark.parametrize( + ["weights", "n_gram", "smooth_func", "smooth"], + [ + pytest.param([1], 1, None, False), + pytest.param([0.5, 0.5], 2, smooth_func, True), + pytest.param([0.333333, 0.333333, 0.333333], 3, None, False), + pytest.param([0.25, 0.25, 0.25, 0.25], 4, smooth_func, True), + ], +) +def test_bleu_score(weights, n_gram, smooth_func, smooth): + nltk_output = sentence_bleu( + [REFERENCE1, REFERENCE2, REFERENCE3], + HYPOTHESIS1, + weights=weights, + smoothing_function=smooth_func, + ) + pl_output = bleu_score([HYPOTHESIS1], [[REFERENCE1, REFERENCE2, REFERENCE3]], n_gram=n_gram, smooth=smooth) + assert torch.allclose(pl_output, torch.tensor(nltk_output)) + + nltk_output = corpus_bleu(LIST_OF_REFERENCES, HYPOTHESES, weights=weights, smoothing_function=smooth_func) + pl_output = bleu_score(HYPOTHESES, LIST_OF_REFERENCES, n_gram=n_gram, smooth=smooth) + assert torch.allclose(pl_output, torch.tensor(nltk_output)) + + +def test_bleu_empty(): + hyp = [[]] + ref = [[[]]] + assert bleu_score(hyp, ref) == torch.tensor(0.0) + + +def test_no_4_gram(): + hyps = [["My", "full", "pytorch-lightning"]] + refs = [[["My", "full", "pytorch-lightning", "test"], ["Completely", "Different"]]] + assert bleu_score(hyps, refs) == torch.tensor(0.0) diff --git a/tests/metrics/functional/test_reduction.py b/tests/metrics/functional/test_reduction.py new file mode 100644 index 0000000000000..03a34f6c5a25b --- /dev/null +++ b/tests/metrics/functional/test_reduction.py @@ -0,0 +1,28 @@ +import pytest +import torch + +from pytorch_lightning.metrics.utils import class_reduce, reduce + + +def test_reduce(): + start_tensor = torch.rand(50, 40, 30) + + assert torch.allclose(reduce(start_tensor, 'elementwise_mean'), torch.mean(start_tensor)) + assert torch.allclose(reduce(start_tensor, 'sum'), torch.sum(start_tensor)) + assert torch.allclose(reduce(start_tensor, 'none'), start_tensor) + + with pytest.raises(ValueError): + reduce(start_tensor, 'error_reduction') + + +def test_class_reduce(): + num = torch.randint(1, 10, (100, )).float() + denom = torch.randint(10, 20, (100, )).float() + weights = torch.randint(1, 100, (100, )).float() + + assert torch.allclose(class_reduce(num, denom, weights, 'micro'), torch.sum(num) / torch.sum(denom)) + assert torch.allclose(class_reduce(num, denom, weights, 'macro'), torch.mean(num / denom)) + assert torch.allclose( + class_reduce(num, denom, weights, 'weighted'), torch.sum(num / denom * (weights / torch.sum(weights))) + ) + assert torch.allclose(class_reduce(num, denom, weights, 'none'), num / denom) diff --git a/tests/metrics/functional/test_self_supervised.py b/tests/metrics/functional/test_self_supervised.py new file mode 100644 index 0000000000000..fbabc5e93cffc --- /dev/null +++ b/tests/metrics/functional/test_self_supervised.py @@ -0,0 +1,32 @@ +import pytest +import torch +from sklearn.metrics import pairwise + +from pytorch_lightning.metrics.functional.self_supervised import embedding_similarity + + +@pytest.mark.parametrize('similarity', ['cosine', 'dot']) +@pytest.mark.parametrize('reduction', ['none', 'mean', 'sum']) +def test_against_sklearn(similarity, reduction): + """Compare PL metrics to sklearn 
version.""" + device = 'cuda' if torch.cuda.is_available() else 'cpu' + + batch = torch.randn(5, 10, device=device) # 100 samples in 10 dimensions + + pl_dist = embedding_similarity(batch, similarity=similarity, reduction=reduction, zero_diagonal=False) + + def sklearn_embedding_distance(batch, similarity, reduction): + + metric_func = {'cosine': pairwise.cosine_similarity, 'dot': pairwise.linear_kernel}[similarity] + + dist = metric_func(batch, batch) + if reduction == 'mean': + return dist.mean(axis=-1) + if reduction == 'sum': + return dist.sum(axis=-1) + return dist + + sk_dist = sklearn_embedding_distance(batch.cpu().detach().numpy(), similarity=similarity, reduction=reduction) + sk_dist = torch.tensor(sk_dist, dtype=torch.float, device=device) + + assert torch.allclose(sk_dist, pl_dist) diff --git a/tests/metrics/regression/__init__.py b/tests/metrics/regression/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/metrics/regression/test_explained_variance.py b/tests/metrics/regression/test_explained_variance.py new file mode 100644 index 0000000000000..adab562ac6055 --- /dev/null +++ b/tests/metrics/regression/test_explained_variance.py @@ -0,0 +1,77 @@ +from collections import namedtuple +from functools import partial + +import pytest +import torch +from sklearn.metrics import explained_variance_score + +from pytorch_lightning.metrics.functional import explained_variance +from pytorch_lightning.metrics.regression import ExplainedVariance +from tests.metrics.utils import BATCH_SIZE, MetricTester, NUM_BATCHES + +torch.manual_seed(42) + +num_targets = 5 + +Input = namedtuple('Input', ["preds", "target"]) + +_single_target_inputs = Input( + preds=torch.rand(NUM_BATCHES, BATCH_SIZE), + target=torch.rand(NUM_BATCHES, BATCH_SIZE), +) + +_multi_target_inputs = Input( + preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), + target=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), +) + + +def _single_target_sk_metric(preds, target, sk_fn=explained_variance_score): + sk_preds = preds.view(-1).numpy() + sk_target = target.view(-1).numpy() + return sk_fn(sk_target, sk_preds) + + +def _multi_target_sk_metric(preds, target, sk_fn=explained_variance_score): + sk_preds = preds.view(-1, num_targets).numpy() + sk_target = target.view(-1, num_targets).numpy() + return sk_fn(sk_target, sk_preds) + + +@pytest.mark.parametrize("multioutput", ['raw_values', 'uniform_average', 'variance_weighted']) +@pytest.mark.parametrize( + "preds, target, sk_metric", + [ + (_single_target_inputs.preds, _single_target_inputs.target, _single_target_sk_metric), + (_multi_target_inputs.preds, _multi_target_inputs.target, _multi_target_sk_metric), + ], +) +class TestExplainedVariance(MetricTester): + + @pytest.mark.parametrize("ddp", [True, False]) + @pytest.mark.parametrize("dist_sync_on_step", [True, False]) + def test_explained_variance(self, multioutput, preds, target, sk_metric, ddp, dist_sync_on_step): + self.run_class_metric_test( + ddp, + preds, + target, + ExplainedVariance, + partial(sk_metric, sk_fn=partial(explained_variance_score, multioutput=multioutput)), + dist_sync_on_step, + metric_args=dict(multioutput=multioutput), + ) + + def test_explained_variance_functional(self, multioutput, preds, target, sk_metric): + self.run_functional_metric_test( + preds, + target, + explained_variance, + partial(sk_metric, sk_fn=partial(explained_variance_score, multioutput=multioutput)), + metric_args=dict(multioutput=multioutput), + ) + + +def 
test_error_on_different_shape(metric_class=ExplainedVariance): + metric = metric_class() + with pytest.raises(RuntimeError, match='Predictions and targets are expected to have the same shape'): + metric(torch.randn(100, ), torch.randn(50, )) diff --git a/tests/metrics/regression/test_mean_error.py b/tests/metrics/regression/test_mean_error.py new file mode 100644 index 0000000000000..041ce12f11164 --- /dev/null +++ b/tests/metrics/regression/test_mean_error.py @@ -0,0 +1,87 @@ +from collections import namedtuple +from functools import partial + +import pytest +import torch +from sklearn.metrics import mean_absolute_error as sk_mean_absolute_error +from sklearn.metrics import mean_squared_error as sk_mean_squared_error +from sklearn.metrics import mean_squared_log_error as sk_mean_squared_log_error + +from pytorch_lightning.metrics.functional import mean_absolute_error, mean_squared_error, mean_squared_log_error +from pytorch_lightning.metrics.regression import MeanAbsoluteError, MeanSquaredError, MeanSquaredLogError +from tests.metrics.utils import BATCH_SIZE, MetricTester, NUM_BATCHES + +torch.manual_seed(42) + +num_targets = 5 + +Input = namedtuple('Input', ["preds", "target"]) + +_single_target_inputs = Input( + preds=torch.rand(NUM_BATCHES, BATCH_SIZE), + target=torch.rand(NUM_BATCHES, BATCH_SIZE), +) + +_multi_target_inputs = Input( + preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), + target=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), +) + + +def _single_target_sk_metric(preds, target, sk_fn=mean_squared_error): + sk_preds = preds.view(-1).numpy() + sk_target = target.view(-1).numpy() + return sk_fn(sk_preds, sk_target) + + +def _multi_target_sk_metric(preds, target, sk_fn=mean_squared_error): + sk_preds = preds.view(-1, num_targets).numpy() + sk_target = target.view(-1, num_targets).numpy() + return sk_fn(sk_preds, sk_target) + + +@pytest.mark.parametrize( + "preds, target, sk_metric", + [ + (_single_target_inputs.preds, _single_target_inputs.target, _single_target_sk_metric), + (_multi_target_inputs.preds, _multi_target_inputs.target, _multi_target_sk_metric), + ], +) +@pytest.mark.parametrize( + "metric_class, metric_functional, sk_fn", + [ + (MeanSquaredError, mean_squared_error, sk_mean_squared_error), + (MeanAbsoluteError, mean_absolute_error, sk_mean_absolute_error), + (MeanSquaredLogError, mean_squared_log_error, sk_mean_squared_log_error), + ], +) +class TestMeanError(MetricTester): + + @pytest.mark.parametrize("ddp", [True, False]) + @pytest.mark.parametrize("dist_sync_on_step", [True, False]) + def test_mean_error_class( + self, preds, target, sk_metric, metric_class, metric_functional, sk_fn, ddp, dist_sync_on_step + ): + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=metric_class, + sk_metric=partial(sk_metric, sk_fn=sk_fn), + dist_sync_on_step=dist_sync_on_step, + ) + + def test_mean_error_functional(self, preds, target, sk_metric, metric_class, metric_functional, sk_fn): + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=metric_functional, + sk_metric=partial(sk_metric, sk_fn=sk_fn), + ) + + +@pytest.mark.parametrize("metric_class", [MeanSquaredError, MeanAbsoluteError, MeanSquaredLogError]) +def test_error_on_different_shape(metric_class): + metric = metric_class() + with pytest.raises(RuntimeError, match='Predictions and targets are expected to have the same shape'): + metric(torch.randn(100, ), torch.randn(50, )) diff --git a/tests/metrics/regression/test_psnr.py 
b/tests/metrics/regression/test_psnr.py new file mode 100644 index 0000000000000..eb07fffb9d55c --- /dev/null +++ b/tests/metrics/regression/test_psnr.py @@ -0,0 +1,133 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import namedtuple +from functools import partial + +import numpy as np +import pytest +import torch +from skimage.metrics import peak_signal_noise_ratio + +from pytorch_lightning.metrics.functional import psnr +from pytorch_lightning.metrics.regression import PSNR +from tests.metrics.utils import BATCH_SIZE, MetricTester, NUM_BATCHES + +torch.manual_seed(42) + +Input = namedtuple('Input', ["preds", "target"]) + +_input_size = (NUM_BATCHES, BATCH_SIZE, 32, 32) +_inputs = [ + Input( + preds=torch.randint(n_cls_pred, _input_size, dtype=torch.float), + target=torch.randint(n_cls_target, _input_size, dtype=torch.float), + ) for n_cls_pred, n_cls_target in [(10, 10), (5, 10), (10, 5)] +] + + +def _to_sk_peak_signal_noise_ratio_inputs(value, dim): + value = value.numpy() + batches = value[None] if value.ndim == len(_input_size) - 1 else value + + if dim is None: + return [batches] + + num_dims = np.size(dim) + if not num_dims: + return batches + + inputs = [] + for batch in batches: + batch = np.moveaxis(batch, dim, np.arange(-num_dims, 0)) + psnr_input_shape = batch.shape[-num_dims:] + inputs.extend(batch.reshape(-1, *psnr_input_shape)) + return inputs + + +def _sk_psnr(preds, target, data_range, reduction, dim): + sk_preds_lists = _to_sk_peak_signal_noise_ratio_inputs(preds, dim=dim) + sk_target_lists = _to_sk_peak_signal_noise_ratio_inputs(target, dim=dim) + np_reduce_map = {"elementwise_mean": np.mean, "none": np.array, "sum": np.sum} + return np_reduce_map[reduction]([ + peak_signal_noise_ratio(sk_target, sk_preds, data_range=data_range) + for sk_target, sk_preds in zip(sk_target_lists, sk_preds_lists) + ]) + + +def _base_e_sk_psnr(preds, target, data_range, reduction, dim): + return _sk_psnr(preds, target, data_range, reduction, dim) * np.log(10) + + +@pytest.mark.parametrize( + "preds, target, data_range, reduction, dim", + [ + (_inputs[0].preds, _inputs[0].target, 10, "elementwise_mean", None), + (_inputs[1].preds, _inputs[1].target, 10, "elementwise_mean", None), + (_inputs[2].preds, _inputs[2].target, 5, "elementwise_mean", None), + (_inputs[2].preds, _inputs[2].target, 5, "elementwise_mean", 1), + (_inputs[2].preds, _inputs[2].target, 5, "elementwise_mean", (1, 2)), + (_inputs[2].preds, _inputs[2].target, 5, "sum", (1, 2)), + ], +) +@pytest.mark.parametrize( + "base, sk_metric", + [ + (10.0, _sk_psnr), + (2.718281828459045, _base_e_sk_psnr), + ], +) +class TestPSNR(MetricTester): + + @pytest.mark.parametrize("ddp", [True, False]) + @pytest.mark.parametrize("dist_sync_on_step", [True, False]) + def test_psnr(self, preds, target, data_range, base, reduction, dim, sk_metric, ddp, dist_sync_on_step): + _args = {"data_range": data_range, "base": base, "reduction": reduction, "dim": dim} + self.run_class_metric_test( + 
ddp, + preds, + target, + PSNR, + partial(sk_metric, data_range=data_range, reduction=reduction, dim=dim), + metric_args=_args, + dist_sync_on_step=dist_sync_on_step, + ) + + def test_psnr_functional(self, preds, target, sk_metric, data_range, base, reduction, dim): + _args = {"data_range": data_range, "base": base, "reduction": reduction, "dim": dim} + self.run_functional_metric_test( + preds, + target, + psnr, + partial(sk_metric, data_range=data_range, reduction=reduction, dim=dim), + metric_args=_args, + ) + + +@pytest.mark.parametrize("reduction", ["none", "sum"]) +def test_reduction_for_dim_none(reduction): + match = f"The `reduction={reduction}` will not have any effect when `dim` is None." + with pytest.warns(UserWarning, match=match): + PSNR(reduction=reduction, dim=None) + + with pytest.warns(UserWarning, match=match): + psnr(_inputs[0].preds, _inputs[0].target, reduction=reduction, dim=None) + + +def test_missing_data_range(): + with pytest.raises(ValueError): + PSNR(data_range=None, dim=0) + + with pytest.raises(ValueError): + psnr(_inputs[0].preds, _inputs[0].target, data_range=None, dim=0) diff --git a/tests/metrics/regression/test_r2score.py b/tests/metrics/regression/test_r2score.py new file mode 100644 index 0000000000000..232b003e6116a --- /dev/null +++ b/tests/metrics/regression/test_r2score.py @@ -0,0 +1,114 @@ +from collections import namedtuple +from functools import partial + +import pytest +import torch +from sklearn.metrics import r2_score as sk_r2score + +from pytorch_lightning.metrics.functional import r2score +from pytorch_lightning.metrics.regression import R2Score +from tests.metrics.utils import BATCH_SIZE, MetricTester, NUM_BATCHES + +torch.manual_seed(42) + +num_targets = 5 + +Input = namedtuple('Input', ["preds", "target"]) + +_single_target_inputs = Input( + preds=torch.rand(NUM_BATCHES, BATCH_SIZE), + target=torch.rand(NUM_BATCHES, BATCH_SIZE), +) + +_multi_target_inputs = Input( + preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), + target=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), +) + + +def _single_target_sk_metric(preds, target, adjusted, multioutput): + sk_preds = preds.view(-1).numpy() + sk_target = target.view(-1).numpy() + r2_score = sk_r2score(sk_target, sk_preds, multioutput=multioutput) + if adjusted != 0: + r2_score = 1 - (1 - r2_score) * (sk_preds.shape[0] - 1) / (sk_preds.shape[0] - adjusted - 1) + return r2_score + + +def _multi_target_sk_metric(preds, target, adjusted, multioutput): + sk_preds = preds.view(-1, num_targets).numpy() + sk_target = target.view(-1, num_targets).numpy() + r2_score = sk_r2score(sk_target, sk_preds, multioutput=multioutput) + if adjusted != 0: + r2_score = 1 - (1 - r2_score) * (sk_preds.shape[0] - 1) / (sk_preds.shape[0] - adjusted - 1) + return r2_score + + +@pytest.mark.parametrize("adjusted", [0, 5, 10]) +@pytest.mark.parametrize("multioutput", ['raw_values', 'uniform_average', 'variance_weighted']) +@pytest.mark.parametrize( + "preds, target, sk_metric, num_outputs", + [ + (_single_target_inputs.preds, _single_target_inputs.target, _single_target_sk_metric, 1), + (_multi_target_inputs.preds, _multi_target_inputs.target, _multi_target_sk_metric, num_targets), + ], +) +class TestR2Score(MetricTester): + + @pytest.mark.parametrize("ddp", [True, False]) + @pytest.mark.parametrize("dist_sync_on_step", [True, False]) + def test_r2(self, adjusted, multioutput, preds, target, sk_metric, num_outputs, ddp, dist_sync_on_step): + self.run_class_metric_test( + ddp, + preds, + target, + R2Score, + 
partial(sk_metric, adjusted=adjusted, multioutput=multioutput), + dist_sync_on_step, + metric_args=dict(adjusted=adjusted, multioutput=multioutput, num_outputs=num_outputs), + ) + + def test_r2_functional(self, adjusted, multioutput, preds, target, sk_metric, num_outputs): + self.run_functional_metric_test( + preds, + target, + r2score, + partial(sk_metric, adjusted=adjusted, multioutput=multioutput), + metric_args=dict(adjusted=adjusted, multioutput=multioutput), + ) + + +def test_error_on_different_shape(metric_class=R2Score): + metric = metric_class() + with pytest.raises(RuntimeError, match='Predictions and targets are expected to have the same shape'): + metric(torch.randn(100, ), torch.randn(50, )) + + +def test_error_on_multidim_tensors(metric_class=R2Score): + metric = metric_class() + with pytest.raises( + ValueError, + match=r'Expected both prediction and target to be 1D or 2D tensors,' + r' but recevied tensors with dimension .' + ): + metric(torch.randn(10, 20, 5), torch.randn(10, 20, 5)) + + +def test_error_on_too_few_samples(metric_class=R2Score): + metric = metric_class() + with pytest.raises(ValueError, match='Needs atleast two samples to calculate r2 score.'): + metric(torch.randn(1, ), torch.randn(1, )) + + +def test_warning_on_too_large_adjusted(metric_class=R2Score): + metric = metric_class(adjusted=10) + + with pytest.warns( + UserWarning, + match="More independent regressions than datapoints in" + " adjusted r2 score. Falls back to standard r2 score." + ): + metric(torch.randn(10, ), torch.randn(10, )) + + with pytest.warns(UserWarning, match="Division by zero in adjusted r2 score. Falls back to" " standard r2 score."): + metric(torch.randn(11, ), torch.randn(11, )) diff --git a/tests/metrics/regression/test_ssim.py b/tests/metrics/regression/test_ssim.py new file mode 100644 index 0000000000000..f7e4b7a58e001 --- /dev/null +++ b/tests/metrics/regression/test_ssim.py @@ -0,0 +1,104 @@ +from collections import namedtuple +from functools import partial + +import pytest +import torch +from skimage.metrics import structural_similarity + +from pytorch_lightning.metrics.functional import ssim +from pytorch_lightning.metrics.regression import SSIM +from tests.metrics.utils import BATCH_SIZE, MetricTester, NUM_BATCHES + +torch.manual_seed(42) + +Input = namedtuple('Input', ["preds", "target", "multichannel"]) + +_inputs = [] +for size, channel, coef, multichannel, dtype in [ + (12, 3, 0.9, True, torch.float), + (13, 1, 0.8, False, torch.float32), + (14, 1, 0.7, False, torch.double), + (15, 3, 0.6, True, torch.float64), +]: + preds = torch.rand(NUM_BATCHES, BATCH_SIZE, channel, size, size, dtype=dtype) + _inputs.append(Input( + preds=preds, + target=preds * coef, + multichannel=multichannel, + )) + + +def _sk_metric(preds, target, data_range, multichannel): + c, h, w = preds.shape[-3:] + sk_preds = preds.view(-1, c, h, w).permute(0, 2, 3, 1).numpy() + sk_target = target.view(-1, c, h, w).permute(0, 2, 3, 1).numpy() + if not multichannel: + sk_preds = sk_preds[:, :, :, 0] + sk_target = sk_target[:, :, :, 0] + + return structural_similarity( + sk_target, + sk_preds, + data_range=data_range, + multichannel=multichannel, + gaussian_weights=True, + win_size=11, + sigma=1.5, + use_sample_covariance=False + ) + + +@pytest.mark.parametrize( + "preds, target, multichannel", + [(i.preds, i.target, i.multichannel) for i in _inputs], +) +class TestSSIM(MetricTester): + atol = 6e-5 + + @pytest.mark.parametrize("ddp", [True, False]) + @pytest.mark.parametrize("dist_sync_on_step", [True, 
False]) + def test_ssim(self, preds, target, multichannel, ddp, dist_sync_on_step): + self.run_class_metric_test( + ddp, + preds, + target, + SSIM, + partial(_sk_metric, data_range=1.0, multichannel=multichannel), + metric_args={"data_range": 1.0}, + dist_sync_on_step=dist_sync_on_step, + ) + + def test_ssim_functional(self, preds, target, multichannel): + self.run_functional_metric_test( + preds, + target, + ssim, + partial(_sk_metric, data_range=1.0, multichannel=multichannel), + metric_args={"data_range": 1.0}, + ) + + +@pytest.mark.parametrize( + ['pred', 'target', 'kernel', 'sigma'], + [ + pytest.param([1, 16, 16], [1, 16, 16], [11, 11], [1.5, 1.5]), # len(shape) + pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11, 11], [1.5]), # len(kernel), len(sigma) + pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11], [1.5, 1.5]), # len(kernel), len(sigma) + pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11], [1.5]), # len(kernel), len(sigma) + pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11, 0], [1.5, 1.5]), # invalid kernel input + pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11, 10], [1.5, 1.5]), # invalid kernel input + pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11, -11], [1.5, 1.5]), # invalid kernel input + pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11, 11], [1.5, 0]), # invalid sigma input + pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11, 0], [1.5, -1.5]), # invalid sigma input + ], +) +def test_ssim_invalid_inputs(pred, target, kernel, sigma): + pred_t = torch.rand(pred) + target_t = torch.rand(target, dtype=torch.float64) + with pytest.raises(TypeError): + ssim(pred_t, target_t) + + pred = torch.rand(pred) + target = torch.rand(target) + with pytest.raises(ValueError): + ssim(pred, target, kernel, sigma) diff --git a/tests/metrics/test_composition.py b/tests/metrics/test_composition.py new file mode 100644 index 0000000000000..7845e86f514ff --- /dev/null +++ b/tests/metrics/test_composition.py @@ -0,0 +1,510 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
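For orientation, the base conversion that the `_base_e_sk_psnr` reference in the PSNR tests above relies on can be checked directly. A minimal sketch, assuming the functional `psnr` keeps its default `reduction` and `dim` arguments:

```python
import math

import torch
from skimage.metrics import peak_signal_noise_ratio

from pytorch_lightning.metrics.functional import psnr

preds = torch.rand(1, 1, 16, 16)
target = torch.rand(1, 1, 16, 16)

psnr_db = psnr(preds, target, data_range=1.0)                # 10 * log10(data_range**2 / MSE)
psnr_nat = psnr(preds, target, data_range=1.0, base=math.e)  # same quantity with a natural log
sk_db = peak_signal_noise_ratio(target.numpy(), preds.numpy(), data_range=1.0)

# Switching the log base only rescales the score, which is why the reference above
# multiplies the (base-10) scikit-image value by ln(10) for the base-e case.
assert abs(psnr_db.item() - sk_db) < 1e-3
assert abs(psnr_nat.item() - psnr_db.item() * math.log(10)) < 1e-3
```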
+from operator import neg, pos + +import pytest +import torch + +from pytorch_lightning.metrics.compositional import CompositionalMetric +from pytorch_lightning.metrics.metric import Metric +from tests.helpers.runif import RunIf + + +class DummyMetric(Metric): + + def __init__(self, val_to_return): + super().__init__() + self._num_updates = 0 + self._val_to_return = val_to_return + + def update(self, *args, **kwargs) -> None: + self._num_updates += 1 + + def compute(self): + return torch.tensor(self._val_to_return) + + def reset(self): + self._num_updates = 0 + return super().reset() + + +@pytest.mark.parametrize( + ["second_operand", "expected_result"], + [ + (DummyMetric(2), torch.tensor(4)), + (2, torch.tensor(4)), + (2.0, torch.tensor(4.0)), + (torch.tensor(2), torch.tensor(4)), + ], +) +def test_metrics_add(second_operand, expected_result): + first_metric = DummyMetric(2) + + final_add = first_metric + second_operand + final_radd = second_operand + first_metric + + assert isinstance(final_add, CompositionalMetric) + assert isinstance(final_radd, CompositionalMetric) + + assert torch.allclose(expected_result, final_add.compute()) + assert torch.allclose(expected_result, final_radd.compute()) + + +@pytest.mark.parametrize( + ["second_operand", "expected_result"], + [(DummyMetric(3), torch.tensor(2)), (3, torch.tensor(2)), (3, torch.tensor(2)), (torch.tensor(3), torch.tensor(2))], +) +@RunIf(min_torch="1.5.0") +def test_metrics_and(second_operand, expected_result): + first_metric = DummyMetric(2) + + final_and = first_metric & second_operand + final_rand = second_operand & first_metric + + assert isinstance(final_and, CompositionalMetric) + assert isinstance(final_rand, CompositionalMetric) + + assert torch.allclose(expected_result, final_and.compute()) + assert torch.allclose(expected_result, final_rand.compute()) + + +@pytest.mark.parametrize( + ["second_operand", "expected_result"], + [ + (DummyMetric(2), torch.tensor(True)), + (2, torch.tensor(True)), + (2.0, torch.tensor(True)), + (torch.tensor(2), torch.tensor(True)), + ], +) +def test_metrics_eq(second_operand, expected_result): + first_metric = DummyMetric(2) + + final_eq = first_metric == second_operand + + assert isinstance(final_eq, CompositionalMetric) + + # can't use allclose for bool tensors + assert (expected_result == final_eq.compute()).all() + + +@pytest.mark.parametrize( + ["second_operand", "expected_result"], + [ + (DummyMetric(2), torch.tensor(2)), + (2, torch.tensor(2)), + (2.0, torch.tensor(2.0)), + (torch.tensor(2), torch.tensor(2)), + ], +) +@RunIf(min_torch="1.5.0") +def test_metrics_floordiv(second_operand, expected_result): + first_metric = DummyMetric(5) + + final_floordiv = first_metric // second_operand + + assert isinstance(final_floordiv, CompositionalMetric) + + assert torch.allclose(expected_result, final_floordiv.compute()) + + +@pytest.mark.parametrize( + ["second_operand", "expected_result"], + [ + (DummyMetric(2), torch.tensor(True)), + (2, torch.tensor(True)), + (2.0, torch.tensor(True)), + (torch.tensor(2), torch.tensor(True)), + ], +) +def test_metrics_ge(second_operand, expected_result): + first_metric = DummyMetric(5) + + final_ge = first_metric >= second_operand + + assert isinstance(final_ge, CompositionalMetric) + + # can't use allclose for bool tensors + assert (expected_result == final_ge.compute()).all() + + +@pytest.mark.parametrize( + ["second_operand", "expected_result"], + [ + (DummyMetric(2), torch.tensor(True)), + (2, torch.tensor(True)), + (2.0, torch.tensor(True)), + 
(torch.tensor(2), torch.tensor(True)), + ], +) +def test_metrics_gt(second_operand, expected_result): + first_metric = DummyMetric(5) + + final_gt = first_metric > second_operand + + assert isinstance(final_gt, CompositionalMetric) + + # can't use allclose for bool tensors + assert (expected_result == final_gt.compute()).all() + + +@pytest.mark.parametrize( + ["second_operand", "expected_result"], + [ + (DummyMetric(2), torch.tensor(False)), + (2, torch.tensor(False)), + (2.0, torch.tensor(False)), + (torch.tensor(2), torch.tensor(False)), + ], +) +def test_metrics_le(second_operand, expected_result): + first_metric = DummyMetric(5) + + final_le = first_metric <= second_operand + + assert isinstance(final_le, CompositionalMetric) + + # can't use allclose for bool tensors + assert (expected_result == final_le.compute()).all() + + +@pytest.mark.parametrize( + ["second_operand", "expected_result"], + [ + (DummyMetric(2), torch.tensor(False)), + (2, torch.tensor(False)), + (2.0, torch.tensor(False)), + (torch.tensor(2), torch.tensor(False)), + ], +) +def test_metrics_lt(second_operand, expected_result): + first_metric = DummyMetric(5) + + final_lt = first_metric < second_operand + + assert isinstance(final_lt, CompositionalMetric) + + # can't use allclose for bool tensors + assert (expected_result == final_lt.compute()).all() + + +@pytest.mark.parametrize( + ["second_operand", "expected_result"], + [(DummyMetric([2, 2, 2]), torch.tensor(12)), (torch.tensor([2, 2, 2]), torch.tensor(12))], +) +def test_metrics_matmul(second_operand, expected_result): + first_metric = DummyMetric([2, 2, 2]) + + final_matmul = first_metric @ second_operand + + assert isinstance(final_matmul, CompositionalMetric) + + assert torch.allclose(expected_result, final_matmul.compute()) + + +@pytest.mark.parametrize( + ["second_operand", "expected_result"], + [ + (DummyMetric(2), torch.tensor(1)), + (2, torch.tensor(1)), + (2.0, torch.tensor(1)), + (torch.tensor(2), torch.tensor(1)), + ], +) +def test_metrics_mod(second_operand, expected_result): + first_metric = DummyMetric(5) + + final_mod = first_metric % second_operand + + assert isinstance(final_mod, CompositionalMetric) + # prevent Runtime error for PT 1.8 - Long did not match Float + assert torch.allclose(expected_result.to(float), final_mod.compute().to(float)) + + +@pytest.mark.parametrize( + ["second_operand", "expected_result"], + [ + (DummyMetric(2), torch.tensor(4)), + (2, torch.tensor(4)), + (2.0, torch.tensor(4.0)), + (torch.tensor(2), torch.tensor(4)), + ], +) +def test_metrics_mul(second_operand, expected_result): + first_metric = DummyMetric(2) + + final_mul = first_metric * second_operand + final_rmul = second_operand * first_metric + + assert isinstance(final_mul, CompositionalMetric) + assert isinstance(final_rmul, CompositionalMetric) + + assert torch.allclose(expected_result, final_mul.compute()) + assert torch.allclose(expected_result, final_rmul.compute()) + + +@pytest.mark.parametrize( + ["second_operand", "expected_result"], + [ + (DummyMetric(2), torch.tensor(False)), + (2, torch.tensor(False)), + (2.0, torch.tensor(False)), + (torch.tensor(2), torch.tensor(False)), + ], +) +def test_metrics_ne(second_operand, expected_result): + first_metric = DummyMetric(2) + + final_ne = first_metric != second_operand + + assert isinstance(final_ne, CompositionalMetric) + + # can't use allclose for bool tensors + assert (expected_result == final_ne.compute()).all() + + +@pytest.mark.parametrize( + ["second_operand", "expected_result"], + [(DummyMetric([1, 0, 
3]), torch.tensor([-1, -2, 3])), (torch.tensor([1, 0, 3]), torch.tensor([-1, -2, 3]))], +) +@RunIf(min_torch="1.5.0") +def test_metrics_or(second_operand, expected_result): + first_metric = DummyMetric([-1, -2, 3]) + + final_or = first_metric | second_operand + final_ror = second_operand | first_metric + + assert isinstance(final_or, CompositionalMetric) + assert isinstance(final_ror, CompositionalMetric) + + assert torch.allclose(expected_result, final_or.compute()) + assert torch.allclose(expected_result, final_ror.compute()) + + +@pytest.mark.parametrize( + ["second_operand", "expected_result"], + [ + pytest.param(DummyMetric(2), torch.tensor(4)), + pytest.param(2, torch.tensor(4)), + pytest.param(2.0, torch.tensor(4.0), marks=RunIf(min_torch="1.6.0")), + pytest.param(torch.tensor(2), torch.tensor(4)), + ], +) +def test_metrics_pow(second_operand, expected_result): + first_metric = DummyMetric(2) + + final_pow = first_metric**second_operand + + assert isinstance(final_pow, CompositionalMetric) + + assert torch.allclose(expected_result, final_pow.compute()) + + +@pytest.mark.parametrize( + ["first_operand", "expected_result"], + [(5, torch.tensor(2)), (5.0, torch.tensor(2.0)), (torch.tensor(5), torch.tensor(2))], +) +@RunIf(min_torch="1.5.0") +def test_metrics_rfloordiv(first_operand, expected_result): + second_operand = DummyMetric(2) + + final_rfloordiv = first_operand // second_operand + + assert isinstance(final_rfloordiv, CompositionalMetric) + assert torch.allclose(expected_result, final_rfloordiv.compute()) + + +@pytest.mark.parametrize(["first_operand", "expected_result"], [(torch.tensor([2, 2, 2]), torch.tensor(12))]) +def test_metrics_rmatmul(first_operand, expected_result): + second_operand = DummyMetric([2, 2, 2]) + + final_rmatmul = first_operand @ second_operand + + assert isinstance(final_rmatmul, CompositionalMetric) + + assert torch.allclose(expected_result, final_rmatmul.compute()) + + +@pytest.mark.parametrize(["first_operand", "expected_result"], [(torch.tensor(2), torch.tensor(2))]) +def test_metrics_rmod(first_operand, expected_result): + second_operand = DummyMetric(5) + + final_rmod = first_operand % second_operand + + assert isinstance(final_rmod, CompositionalMetric) + + assert torch.allclose(expected_result, final_rmod.compute()) + + +@pytest.mark.parametrize( + "first_operand,expected_result", + [ + pytest.param(DummyMetric(2), torch.tensor(4)), + pytest.param(2, torch.tensor(4)), + pytest.param(2.0, torch.tensor(4.0), marks=RunIf(min_torch="1.6.0")), + ], +) +def test_metrics_rpow(first_operand, expected_result): + second_operand = DummyMetric(2) + + final_rpow = first_operand**second_operand + + assert isinstance(final_rpow, CompositionalMetric) + + assert torch.allclose(expected_result, final_rpow.compute()) + + +@pytest.mark.parametrize( + ["first_operand", "expected_result"], + [ + (DummyMetric(3), torch.tensor(1)), + (3, torch.tensor(1)), + (3.0, torch.tensor(1.0)), + (torch.tensor(3), torch.tensor(1)), + ], +) +def test_metrics_rsub(first_operand, expected_result): + second_operand = DummyMetric(2) + + final_rsub = first_operand - second_operand + + assert isinstance(final_rsub, CompositionalMetric) + + assert torch.allclose(expected_result, final_rsub.compute()) + + +@pytest.mark.parametrize( + ["first_operand", "expected_result"], + [ + (DummyMetric(6), torch.tensor(2.0)), + (6, torch.tensor(2.0)), + (6.0, torch.tensor(2.0)), + (torch.tensor(6), torch.tensor(2.0)), + ], +) +@RunIf(min_torch="1.5.0") +def test_metrics_rtruediv(first_operand, 
expected_result): + second_operand = DummyMetric(3) + + final_rtruediv = first_operand / second_operand + + assert isinstance(final_rtruediv, CompositionalMetric) + + assert torch.allclose(expected_result, final_rtruediv.compute()) + + +@pytest.mark.parametrize( + ["second_operand", "expected_result"], + [ + (DummyMetric(2), torch.tensor(1)), + (2, torch.tensor(1)), + (2.0, torch.tensor(1.0)), + (torch.tensor(2), torch.tensor(1)), + ], +) +def test_metrics_sub(second_operand, expected_result): + first_metric = DummyMetric(3) + + final_sub = first_metric - second_operand + + assert isinstance(final_sub, CompositionalMetric) + + assert torch.allclose(expected_result, final_sub.compute()) + + +@pytest.mark.parametrize( + ["second_operand", "expected_result"], + [ + (DummyMetric(3), torch.tensor(2.0)), + (3, torch.tensor(2.0)), + (3.0, torch.tensor(2.0)), + (torch.tensor(3), torch.tensor(2.0)), + ], +) +@RunIf(min_torch="1.5.0") +def test_metrics_truediv(second_operand, expected_result): + first_metric = DummyMetric(6) + + final_truediv = first_metric / second_operand + + assert isinstance(final_truediv, CompositionalMetric) + + assert torch.allclose(expected_result, final_truediv.compute()) + + +@pytest.mark.parametrize( + ["second_operand", "expected_result"], + [(DummyMetric([1, 0, 3]), torch.tensor([-2, -2, 0])), (torch.tensor([1, 0, 3]), torch.tensor([-2, -2, 0]))], +) +def test_metrics_xor(second_operand, expected_result): + first_metric = DummyMetric([-1, -2, 3]) + + final_xor = first_metric ^ second_operand + final_rxor = second_operand ^ first_metric + + assert isinstance(final_xor, CompositionalMetric) + assert isinstance(final_rxor, CompositionalMetric) + + assert torch.allclose(expected_result, final_xor.compute()) + assert torch.allclose(expected_result, final_rxor.compute()) + + +def test_metrics_abs(): + first_metric = DummyMetric(-1) + + final_abs = abs(first_metric) + + assert isinstance(final_abs, CompositionalMetric) + + assert torch.allclose(torch.tensor(1), final_abs.compute()) + + +def test_metrics_invert(): + first_metric = DummyMetric(1) + + final_inverse = ~first_metric + assert isinstance(final_inverse, CompositionalMetric) + assert torch.allclose(torch.tensor(-2), final_inverse.compute()) + + +def test_metrics_neg(): + first_metric = DummyMetric(1) + + final_neg = neg(first_metric) + assert isinstance(final_neg, CompositionalMetric) + assert torch.allclose(torch.tensor(-1), final_neg.compute()) + + +def test_metrics_pos(): + first_metric = DummyMetric(-1) + + final_pos = pos(first_metric) + assert isinstance(final_pos, CompositionalMetric) + assert torch.allclose(torch.tensor(1), final_pos.compute()) + + +def test_compositional_metrics_update(): + + compos = DummyMetric(5) + DummyMetric(4) + + assert isinstance(compos, CompositionalMetric) + compos.update() + compos.update() + compos.update() + + assert isinstance(compos.metric_a, DummyMetric) + assert isinstance(compos.metric_b, DummyMetric) + + assert compos.metric_a._num_updates == 3 + assert compos.metric_b._num_updates == 3 diff --git a/tests/metrics/test_ddp.py b/tests/metrics/test_ddp.py new file mode 100644 index 0000000000000..5120cce0a0425 --- /dev/null +++ b/tests/metrics/test_ddp.py @@ -0,0 +1,71 @@ +import pytest +import torch + +from pytorch_lightning.metrics import Metric +from tests.helpers.runif import RunIf +from tests.metrics.test_metric import Dummy +from tests.metrics.utils import setup_ddp + +torch.manual_seed(42) + + +def _test_ddp_sum(rank, worldsize): + setup_ddp(rank, worldsize) + dummy = 
Dummy() + dummy._reductions = {"foo": torch.sum} + dummy.foo = torch.tensor(1) + + dummy._sync_dist() + assert dummy.foo == worldsize + + +def _test_ddp_cat(rank, worldsize): + setup_ddp(rank, worldsize) + dummy = Dummy() + dummy._reductions = {"foo": torch.cat} + dummy.foo = [torch.tensor([1])] + dummy._sync_dist() + assert torch.all(torch.eq(dummy.foo, torch.tensor([1, 1]))) + + +def _test_ddp_sum_cat(rank, worldsize): + setup_ddp(rank, worldsize) + dummy = Dummy() + dummy._reductions = {"foo": torch.cat, "bar": torch.sum} + dummy.foo = [torch.tensor([1])] + dummy.bar = torch.tensor(1) + dummy._sync_dist() + assert torch.all(torch.eq(dummy.foo, torch.tensor([1, 1]))) + assert dummy.bar == worldsize + + +@RunIf(skip_windows=True) +@pytest.mark.parametrize("process", [_test_ddp_cat, _test_ddp_sum, _test_ddp_sum_cat]) +def test_ddp(process): + torch.multiprocessing.spawn(process, args=(2, ), nprocs=2) + + +def _test_non_contiguous_tensors(rank, worldsize): + setup_ddp(rank, worldsize) + + class DummyMetric(Metric): + + def __init__(self): + super().__init__() + self.add_state("x", default=[], dist_reduce_fx=None) + + def update(self, x): + self.x.append(x) + + def compute(self): + x = torch.cat(self.x, dim=0) + return x.sum() + + metric = DummyMetric() + metric.update(torch.randn(10, 5)[:, 0]) + + +@RunIf(skip_windows=True) +def test_non_contiguous_tensors(): + """ Test that gather_all operation works for non contiguous tensors """ + torch.multiprocessing.spawn(_test_non_contiguous_tensors, args=(2, ), nprocs=2) diff --git a/tests/metrics/test_metric.py b/tests/metrics/test_metric.py new file mode 100644 index 0000000000000..ad7b4566dc012 --- /dev/null +++ b/tests/metrics/test_metric.py @@ -0,0 +1,395 @@ +import pickle +from collections import OrderedDict +from distutils.version import LooseVersion + +import cloudpickle +import numpy as np +import pytest +import torch +from torch import nn + +from pytorch_lightning.metrics.metric import Metric, MetricCollection +from tests.helpers.runif import RunIf + +torch.manual_seed(42) + + +class Dummy(Metric): + name = "Dummy" + + def __init__(self): + super().__init__() + self.add_state("x", torch.tensor(0.0), dist_reduce_fx=None) + + def update(self): + pass + + def compute(self): + pass + + +class DummyList(Metric): + name = "DummyList" + + def __init__(self): + super().__init__() + self.add_state("x", list(), dist_reduce_fx=None) + + def update(self): + pass + + def compute(self): + pass + + +def test_inherit(): + Dummy() + + +def test_add_state(): + a = Dummy() + + a.add_state("a", torch.tensor(0), "sum") + assert a._reductions["a"](torch.tensor([1, 1])) == 2 + + a.add_state("b", torch.tensor(0), "mean") + assert np.allclose(a._reductions["b"](torch.tensor([1.0, 2.0])).numpy(), 1.5) + + a.add_state("c", torch.tensor(0), "cat") + assert a._reductions["c"]([torch.tensor([1]), torch.tensor([1])]).shape == (2, ) + + with pytest.raises(ValueError): + a.add_state("d1", torch.tensor(0), 'xyz') + + with pytest.raises(ValueError): + a.add_state("d2", torch.tensor(0), 42) + + with pytest.raises(ValueError): + a.add_state("d3", [torch.tensor(0)], 'sum') + + with pytest.raises(ValueError): + a.add_state("d4", 42, 'sum') + + def custom_fx(x): + return -1 + + a.add_state("e", torch.tensor(0), custom_fx) + assert a._reductions["e"](torch.tensor([1, 1])) == -1 + + +def test_add_state_persistent(): + a = Dummy() + + a.add_state("a", torch.tensor(0), "sum", persistent=True) + assert "a" in a.state_dict() + + a.add_state("b", torch.tensor(0), "sum", 
persistent=False) + + if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"): + assert "b" not in a.state_dict() + + +def test_reset(): + + class A(Dummy): + pass + + class B(DummyList): + pass + + a = A() + assert a.x == 0 + a.x = torch.tensor(5) + a.reset() + assert a.x == 0 + + b = B() + assert isinstance(b.x, list) and len(b.x) == 0 + b.x = torch.tensor(5) + b.reset() + assert isinstance(b.x, list) and len(b.x) == 0 + + +def test_update(): + + class A(Dummy): + + def update(self, x): + self.x += x + + a = A() + assert a.x == 0 + assert a._computed is None + a.update(1) + assert a._computed is None + assert a.x == 1 + a.update(2) + assert a.x == 3 + assert a._computed is None + + +def test_compute(): + + class A(Dummy): + + def update(self, x): + self.x += x + + def compute(self): + return self.x + + a = A() + assert 0 == a.compute() + assert 0 == a.x + a.update(1) + assert a._computed is None + assert a.compute() == 1 + assert a._computed == 1 + a.update(2) + assert a._computed is None + assert a.compute() == 3 + assert a._computed == 3 + + # called without update, should return cached value + a._computed = 5 + assert a.compute() == 5 + + +def test_hash(): + + class A(Dummy): + pass + + class B(DummyList): + pass + + a1 = A() + a2 = A() + assert hash(a1) != hash(a2) + + b1 = B() + b2 = B() + assert hash(b1) == hash(b2) + assert isinstance(b1.x, list) and len(b1.x) == 0 + b1.x.append(torch.tensor(5)) + assert isinstance(hash(b1), int) # <- check that nothing crashes + assert isinstance(b1.x, list) and len(b1.x) == 1 + b2.x.append(torch.tensor(5)) + # Sanity: + assert isinstance(b2.x, list) and len(b2.x) == 1 + # Now that they have tensor contents, they should have different hashes: + assert hash(b1) != hash(b2) + + +def test_forward(): + + class A(Dummy): + + def update(self, x): + self.x += x + + def compute(self): + return self.x + + a = A() + assert a(5) == 5 + assert a._forward_cache == 5 + + assert a(8) == 8 + assert a._forward_cache == 8 + + assert a.compute() == 13 + + +class DummyMetric1(Dummy): + + def update(self, x): + self.x += x + + def compute(self): + return self.x + + +class DummyMetric2(Dummy): + + def update(self, y): + self.x -= y + + def compute(self): + return self.x + + +def test_pickle(tmpdir): + # doesn't tests for DDP + a = DummyMetric1() + a.update(1) + + metric_pickled = pickle.dumps(a) + metric_loaded = pickle.loads(metric_pickled) + + assert metric_loaded.compute() == 1 + + metric_loaded.update(5) + assert metric_loaded.compute() == 6 + + metric_pickled = cloudpickle.dumps(a) + metric_loaded = cloudpickle.loads(metric_pickled) + + assert metric_loaded.compute() == 1 + + +def test_state_dict(tmpdir): + """ test that metric states can be removed and added to state dict """ + metric = Dummy() + assert metric.state_dict() == OrderedDict() + metric.persistent(True) + assert metric.state_dict() == OrderedDict(x=0) + metric.persistent(False) + assert metric.state_dict() == OrderedDict() + + +def test_child_metric_state_dict(): + """ test that child metric states will be added to parent state dict """ + + class TestModule(nn.Module): + + def __init__(self): + super().__init__() + self.metric = Dummy() + self.metric.add_state('a', torch.tensor(0), persistent=True) + self.metric.add_state('b', [], persistent=True) + self.metric.register_buffer('c', torch.tensor(0)) + + module = TestModule() + expected_state_dict = { + 'metric.a': torch.tensor(0), + 'metric.b': [], + 'metric.c': torch.tensor(0), + } + assert module.state_dict() == expected_state_dict + + 
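# Illustrative sketch, not taken from this patch: the tests above exercise the
# `Metric` API from `pytorch_lightning.metrics` -- states are registered with
# `add_state`, accumulated in `update`, reduced in `compute`, and restored to
# their defaults by `reset`. A minimal, hypothetical subclass (the name
# `RunningSum` is invented for illustration) could look like this:

import torch

from pytorch_lightning.metrics import Metric


class RunningSum(Metric):

    def __init__(self):
        super().__init__()
        # "sum" is one of the built-in reductions checked in `test_add_state`
        self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, x: torch.Tensor):
        # accumulate into the registered state
        self.total += x.sum()

    def compute(self):
        return self.total


running_sum = RunningSum()
running_sum.update(torch.tensor([1.0, 2.0]))
assert running_sum.compute() == 3.0  # 1.0 + 2.0
running_sum.reset()  # the registered state goes back to its default
assert running_sum.total == 0.0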
+@RunIf(min_gpus=1)
+def test_device_and_dtype_transfer(tmpdir):
+    metric = DummyMetric1()
+    assert metric.x.is_cuda is False
+    assert metric.x.dtype == torch.float32
+
+    metric = metric.to(device='cuda')
+    assert metric.x.is_cuda
+
+    metric = metric.double()
+    assert metric.x.dtype == torch.float64
+
+    metric = metric.half()
+    assert metric.x.dtype == torch.float16
+
+
+def test_metric_collection(tmpdir):
+    m1 = DummyMetric1()
+    m2 = DummyMetric2()
+
+    metric_collection = MetricCollection([m1, m2])
+
+    # Test correct dict structure
+    assert len(metric_collection) == 2
+    assert metric_collection['DummyMetric1'] == m1
+    assert metric_collection['DummyMetric2'] == m2
+
+    # Test correct initialization
+    for name, metric in metric_collection.items():
+        assert metric.x == 0, f'Metric {name} not initialized correctly'
+
+    # Test every metric gets updated
+    metric_collection.update(5)
+    for name, metric in metric_collection.items():
+        assert metric.x.abs() == 5, f'Metric {name} not updated correctly'
+
+    # Test compute on each metric
+    metric_collection.update(-5)
+    metric_vals = metric_collection.compute()
+    assert len(metric_vals) == 2
+    for name, metric_val in metric_vals.items():
+        assert metric_val == 0, f'Metric {name}.compute not called correctly'
+
+    # Test that everything is reset
+    for name, metric in metric_collection.items():
+        assert metric.x == 0, f'Metric {name} not reset correctly'
+
+    # Test that the collection is picklable
+    metric_pickled = pickle.dumps(metric_collection)
+    metric_loaded = pickle.loads(metric_pickled)
+    assert isinstance(metric_loaded, MetricCollection)
+
+
+@RunIf(min_gpus=1)
+def test_device_and_dtype_transfer_metriccollection(tmpdir):
+    m1 = DummyMetric1()
+    m2 = DummyMetric2()
+
+    metric_collection = MetricCollection([m1, m2])
+    for _, metric in metric_collection.items():
+        assert metric.x.is_cuda is False
+        assert metric.x.dtype == torch.float32
+
+    metric_collection = metric_collection.to(device='cuda')
+    for _, metric in metric_collection.items():
+        assert metric.x.is_cuda
+
+    metric_collection = metric_collection.double()
+    for _, metric in metric_collection.items():
+        assert metric.x.dtype == torch.float64
+
+    metric_collection = metric_collection.half()
+    for _, metric in metric_collection.items():
+        assert metric.x.dtype == torch.float16
+
+
+def test_metric_collection_wrong_input(tmpdir):
+    """ Check that errors are raised on wrong input """
+    m1 = DummyMetric1()
+
+    # Not all inputs are metrics (list)
+    with pytest.raises(ValueError):
+        _ = MetricCollection([m1, 5])
+
+    # Not all inputs are metrics (dict)
+    with pytest.raises(ValueError):
+        _ = MetricCollection({'metric1': m1, 'metric2': 5})
+
+    # Same metric passed in multiple times
+    with pytest.raises(ValueError, match='Encountered two metrics both named *.'):
+        _ = MetricCollection([m1, m1])
+
+    # Not a list or dict passed in
+    with pytest.raises(ValueError, match='Unknown input to MetricCollection.'):
+        _ = MetricCollection(m1)
+
+
+def test_metric_collection_args_kwargs(tmpdir):
+    """ Check that args and kwargs get passed correctly in a metric collection;
+    checks both the update and forward methods.
+    """
+    m1 = DummyMetric1()
+    m2 = DummyMetric2()
+
+    metric_collection = MetricCollection([m1, m2])
+
+    # args get passed to all metrics
+    metric_collection.update(5)
+    assert metric_collection['DummyMetric1'].x == 5
+    assert metric_collection['DummyMetric2'].x == -5
+    metric_collection.reset()
+    _ = metric_collection(5)
+    assert metric_collection['DummyMetric1'].x == 5
+    assert metric_collection['DummyMetric2'].x == -5
+    metric_collection.reset()
+
+    # kwargs are only passed to the metrics whose update signature accepts them
+    metric_collection.update(x=10, y=20)
+    assert metric_collection['DummyMetric1'].x == 10
+    assert metric_collection['DummyMetric2'].x == -20
+    metric_collection.reset()
+    _ = metric_collection(x=10, y=20)
+    assert metric_collection['DummyMetric1'].x == 10
+    assert metric_collection['DummyMetric2'].x == -20
diff --git a/tests/metrics/test_metric_lightning.py b/tests/metrics/test_metric_lightning.py
index e52e39cb16488..895305fa9da7e 100644
--- a/tests/metrics/test_metric_lightning.py
+++ b/tests/metrics/test_metric_lightning.py
@@ -1,13 +1,11 @@
 import torch
-from torchmetrics import Metric as TMetric
 
 from pytorch_lightning import Trainer
-from pytorch_lightning.metrics import Metric as PLMetric
-from pytorch_lightning.metrics import MetricCollection
+from pytorch_lightning.metrics import Metric, MetricCollection
 from tests.helpers.boring_model import BoringModel
 
 
-class SumMetric(TMetric):
+class SumMetric(Metric):
 
     def __init__(self):
         super().__init__()
@@ -20,7 +18,7 @@ def compute(self):
         return self.x
 
 
-class DiffMetric(PLMetric):
+class DiffMetric(Metric):
 
     def __init__(self):
         super().__init__()
diff --git a/tests/metrics/test_remove_1-5_metrics.py b/tests/metrics/test_remove_1-5_metrics.py
deleted file mode 100644
index d3703bf3691c9..0000000000000
--- a/tests/metrics/test_remove_1-5_metrics.py
+++ /dev/null
@@ -1,348 +0,0 @@
-# Copyright The PyTorch Lightning team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Test deprecated functionality which will be removed in v1.5.0""" - -import pytest -import torch - -from pytorch_lightning.metrics import ( - Accuracy, - AUC, - AUROC, - AveragePrecision, - ConfusionMatrix, - ExplainedVariance, - F1, - FBeta, - HammingDistance, - IoU, - MeanAbsoluteError, - MeanSquaredError, - MeanSquaredLogError, - MetricCollection, - Precision, - PrecisionRecallCurve, - PSNR, - R2Score, - Recall, - ROC, - SSIM, - StatScores, -) -from pytorch_lightning.metrics.functional import ( - auc, - auroc, - average_precision, - bleu_score, - confusion_matrix, - embedding_similarity, - explained_variance, - f1, - fbeta, - hamming_distance, - iou, - mean_absolute_error, - mean_squared_error, - mean_squared_log_error, - precision, - precision_recall, - precision_recall_curve, - psnr, - r2score, - recall, - roc, - ssim, - stat_scores, -) -from pytorch_lightning.metrics.functional.accuracy import accuracy -from pytorch_lightning.metrics.functional.mean_relative_error import mean_relative_error -from pytorch_lightning.metrics.utils import get_num_classes, select_topk, to_categorical, to_onehot - - -def test_v1_5_metrics_utils(): - x = torch.tensor([1, 2, 3]) - with pytest.deprecated_call(match="It will be removed in v1.5.0"): - assert torch.equal(to_onehot(x), torch.Tensor([[0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]).to(int)) - - with pytest.deprecated_call(match="It will be removed in v1.5.0"): - assert get_num_classes(torch.tensor([1, 2, 3]), torch.tensor([1, 2, 0])) == 4 - - x = torch.tensor([[1.1, 2.0, 3.0], [2.0, 1.0, 0.5]]) - with pytest.deprecated_call(match="It will be removed in v1.5.0"): - assert torch.equal(select_topk(x, topk=2), torch.Tensor([[0, 1, 1], [1, 1, 0]]).to(torch.int32)) - - x = torch.tensor([[0.2, 0.5], [0.9, 0.1]]) - with pytest.deprecated_call(match="It will be removed in v1.5.0"): - assert torch.equal(to_categorical(x), torch.Tensor([1, 0]).to(int)) - - -def test_v1_5_metrics_collection(): - target = torch.tensor([0, 2, 0, 2, 0, 1, 0, 2]) - preds = torch.tensor([2, 1, 2, 0, 1, 2, 2, 2]) - - MetricCollection.__init__._warned = False - with pytest.deprecated_call(match="It will be removed in v1.5.0."): - metrics = MetricCollection([Accuracy()]) - assert metrics(preds, target) == {'Accuracy': torch.tensor(0.1250)} - - -def test_v1_5_metric_accuracy(): - accuracy._warned = False - - preds = torch.tensor([0, 0, 1, 0, 1]) - target = torch.tensor([0, 0, 1, 1, 1]) - with pytest.deprecated_call(match='It will be removed in v1.5.0'): - assert accuracy(preds, target) == torch.tensor(0.8) - - Accuracy.__init__._warned = False - with pytest.deprecated_call(match='It will be removed in v1.5.0'): - Accuracy() - - -def test_v1_5_metric_auc_auroc(): - AUC.__init__._warned = False - with pytest.deprecated_call(match='It will be removed in v1.5.0'): - AUC() - - ROC.__init__._warned = False - with pytest.deprecated_call(match='It will be removed in v1.5.0'): - ROC() - - AUROC.__init__._warned = False - with pytest.deprecated_call(match='It will be removed in v1.5.0'): - AUROC() - - x = torch.tensor([0, 1, 2, 3]) - y = torch.tensor([0, 1, 2, 2]) - auc._warned = False - with pytest.deprecated_call(match='It will be removed in v1.5.0'): - assert auc(x, y) == torch.tensor(4.) 
- - preds = torch.tensor([0, 1, 2, 3]) - target = torch.tensor([0, 1, 1, 1]) - roc._warned = False - with pytest.deprecated_call(match='It will be removed in v1.5.0'): - fpr, tpr, thrs = roc(preds, target, pos_label=1) - assert torch.equal(fpr, torch.tensor([0., 0., 0., 0., 1.])) - assert torch.allclose(tpr, torch.tensor([0.0000, 0.3333, 0.6667, 1.0000, 1.0000]), atol=1e-4) - assert torch.equal(thrs, torch.tensor([4, 3, 2, 1, 0])) - - preds = torch.tensor([0.13, 0.26, 0.08, 0.19, 0.34]) - target = torch.tensor([0, 0, 1, 1, 1]) - auroc._warned = False - with pytest.deprecated_call(match='It will be removed in v1.5.0'): - assert auroc(preds, target) == torch.tensor(0.5) - - -def test_v1_5_metric_precision_recall(): - AveragePrecision.__init__._warned = False - with pytest.deprecated_call(match='It will be removed in v1.5.0'): - AveragePrecision() - - Precision.__init__._warned = False - with pytest.deprecated_call(match='It will be removed in v1.5.0'): - Precision() - - Recall.__init__._warned = False - with pytest.deprecated_call(match='It will be removed in v1.5.0'): - Recall() - - PrecisionRecallCurve.__init__._warned = False - with pytest.deprecated_call(match='It will be removed in v1.5.0'): - PrecisionRecallCurve() - - pred = torch.tensor([0, 1, 2, 3]) - target = torch.tensor([0, 1, 1, 1]) - average_precision._warned = False - with pytest.deprecated_call(match='It will be removed in v1.5.0'): - assert average_precision(pred, target) == torch.tensor(1.) - - precision._warned = False - with pytest.deprecated_call(match='It will be removed in v1.5.0'): - assert precision(pred, target) == torch.tensor(0.5) - - recall._warned = False - with pytest.deprecated_call(match='It will be removed in v1.5.0'): - assert recall(pred, target) == torch.tensor(0.5) - - precision_recall._warned = False - with pytest.deprecated_call(match='It will be removed in v1.5.0'): - prec, rc = precision_recall(pred, target) - assert prec == torch.tensor(0.5) - assert rc == torch.tensor(0.5) - - precision_recall_curve._warned = False - with pytest.deprecated_call(match='It will be removed in v1.5.0'): - prec, rc, thrs = precision_recall_curve(pred, target) - assert torch.equal(prec, torch.tensor([1., 1., 1., 1.])) - assert torch.allclose(rc, torch.tensor([1., 0.6667, 0.3333, 0.]), atol=1e-4) - assert torch.equal(thrs, torch.tensor([1, 2, 3])) - - -def test_v1_5_metric_classif_mix(): - ConfusionMatrix.__init__._warned = False - with pytest.deprecated_call(match='It will be removed in v1.5.0'): - ConfusionMatrix(num_classes=1) - - FBeta.__init__._warned = False - with pytest.deprecated_call(match='It will be removed in v1.5.0'): - FBeta(num_classes=1) - - F1.__init__._warned = False - with pytest.deprecated_call(match='It will be removed in v1.5.0'): - F1(num_classes=1) - - HammingDistance.__init__._warned = False - with pytest.deprecated_call(match='It will be removed in v1.5.0'): - HammingDistance() - - StatScores.__init__._warned = False - with pytest.deprecated_call(match='It will be removed in v1.5.0'): - StatScores() - - target = torch.tensor([1, 1, 0, 0]) - preds = torch.tensor([0, 1, 0, 0]) - confusion_matrix._warned = False - with pytest.deprecated_call(match='It will be removed in v1.5.0'): - assert torch.equal(confusion_matrix(preds, target, num_classes=2), torch.tensor([[2., 0.], [1., 1.]])) - - target = torch.tensor([0, 1, 2, 0, 1, 2]) - preds = torch.tensor([0, 2, 1, 0, 0, 1]) - fbeta._warned = False - with pytest.deprecated_call(match='It will be removed in v1.5.0'): - assert torch.allclose(fbeta(preds, 
target, num_classes=3, beta=0.5), torch.tensor(0.3333), atol=1e-4) - - f1._warned = False - with pytest.deprecated_call(match='It will be removed in v1.5.0'): - assert torch.allclose(f1(preds, target, num_classes=3), torch.tensor(0.3333), atol=1e-4) - - target = torch.tensor([[0, 1], [1, 1]]) - preds = torch.tensor([[0, 1], [0, 1]]) - hamming_distance._warned = False - with pytest.deprecated_call(match='It will be removed in v1.5.0'): - assert hamming_distance(preds, target) == torch.tensor(0.25) - - preds = torch.tensor([1, 0, 2, 1]) - target = torch.tensor([1, 1, 2, 0]) - stat_scores._warned = False - with pytest.deprecated_call(match='It will be removed in v1.5.0'): - assert torch.equal(stat_scores(preds, target, reduce='micro'), torch.tensor([2, 2, 6, 2, 4])) - - -def test_v1_5_metric_detect(): - IoU.__init__._warned = False - with pytest.deprecated_call(match='It will be removed in v1.5.0'): - IoU(num_classes=1) - - target = torch.randint(0, 2, (10, 25, 25)) - preds = torch.tensor(target) - preds[2:5, 7:13, 9:15] = 1 - preds[2:5, 7:13, 9:15] - iou._warned = False - with pytest.deprecated_call(match='It will be removed in v1.5.0'): - res = iou(preds, target) - assert torch.allclose(res, torch.tensor(0.9660), atol=1e-4) - - -def test_v1_5_metric_regress(): - ExplainedVariance.__init__._warned = False - with pytest.deprecated_call(match='It will be removed in v1.5.0'): - ExplainedVariance() - - MeanAbsoluteError.__init__._warned = False - with pytest.deprecated_call(match='It will be removed in v1.5.0'): - MeanAbsoluteError() - - MeanSquaredError.__init__._warned = False - with pytest.deprecated_call(match='It will be removed in v1.5.0'): - MeanSquaredError() - - MeanSquaredLogError.__init__._warned = False - with pytest.deprecated_call(match='It will be removed in v1.5.0'): - MeanSquaredLogError() - - target = torch.tensor([3, -0.5, 2, 7]) - preds = torch.tensor([2.5, 0.0, 2, 8]) - explained_variance._warned = False - with pytest.deprecated_call(match='It will be removed in v1.5.0'): - res = explained_variance(preds, target) - assert torch.allclose(res, torch.tensor(0.9572), atol=1e-4) - - x = torch.tensor([0., 1, 2, 3]) - y = torch.tensor([0., 1, 2, 2]) - mean_absolute_error._warned = False - with pytest.deprecated_call(match='It will be removed in v1.5.0'): - assert mean_absolute_error(x, y) == 0.25 - - mean_relative_error._warned = False - with pytest.deprecated_call(match='It will be removed in v1.5.0'): - assert mean_relative_error(x, y) == 0.125 - - mean_squared_error._warned = False - with pytest.deprecated_call(match='It will be removed in v1.5.0'): - assert mean_squared_error(x, y) == 0.25 - - mean_squared_log_error._warned = False - with pytest.deprecated_call(match='It will be removed in v1.5.0'): - res = mean_squared_log_error(x, y) - assert torch.allclose(res, torch.tensor(0.0207), atol=1e-4) - - PSNR.__init__._warned = False - with pytest.deprecated_call(match='It will be removed in v1.5.0'): - PSNR() - - R2Score.__init__._warned = False - with pytest.deprecated_call(match='It will be removed in v1.5.0'): - R2Score() - - SSIM.__init__._warned = False - with pytest.deprecated_call(match='It will be removed in v1.5.0'): - SSIM() - - preds = torch.tensor([[0.0, 1.0], [2.0, 3.0]]) - target = torch.tensor([[3.0, 2.0], [1.0, 0.0]]) - psnr._warned = False - with pytest.deprecated_call(match='It will be removed in v1.5.0'): - res = psnr(preds, target) - assert torch.allclose(res, torch.tensor(2.5527), atol=1e-4) - - target = torch.tensor([3, -0.5, 2, 7]) - preds = 
torch.tensor([2.5, 0.0, 2, 8]) - r2score._warned = False - with pytest.deprecated_call(match='It will be removed in v1.5.0'): - res = r2score(preds, target) - assert torch.allclose(res, torch.tensor(0.9486), atol=1e-4) - - preds = torch.rand([16, 1, 16, 16]) - target = preds * 0.75 - ssim._warned = False - with pytest.deprecated_call(match='It will be removed in v1.5.0'): - res = ssim(preds, target) - assert torch.allclose(res, torch.tensor(0.9219), atol=1e-4) - - -def test_v1_5_metric_others(): - translate_corpus = ['the cat is on the mat'.split()] - reference_corpus = [['there is a cat on the mat'.split(), 'a cat is on the mat'.split()]] - bleu_score._warned = False - with pytest.deprecated_call(match='It will be removed in v1.5.0'): - res = bleu_score(translate_corpus, reference_corpus) - assert torch.allclose(res, torch.tensor(0.7598), atol=1e-4) - - embeddings = torch.tensor([[1., 2., 3., 4.], [1., 2., 3., 4.], [4., 5., 6., 7.]]) - embedding_similarity._warned = False - with pytest.deprecated_call(match='It will be removed in v1.5.0'): - res = embedding_similarity(embeddings) - assert torch.allclose( - res, torch.tensor([[0.0000, 1.0000, 0.9759], [1.0000, 0.0000, 0.9759], [0.9759, 0.9759, 0.0000]]), atol=1e-4 - ) diff --git a/tests/metrics/utils.py b/tests/metrics/utils.py index f1f17d0624936..4bd6608ce3fcf 100644 --- a/tests/metrics/utils.py +++ b/tests/metrics/utils.py @@ -8,7 +8,8 @@ import pytest import torch from torch.multiprocessing import Pool, set_start_method -from torchmetrics import Metric + +from pytorch_lightning.metrics import Metric try: set_start_method("spawn") diff --git a/tests/models/data/horovod/train_default_model.py b/tests/models/data/horovod/train_default_model.py index 46ab64afccb03..d3868cfd979e6 100644 --- a/tests/models/data/horovod/train_default_model.py +++ b/tests/models/data/horovod/train_default_model.py @@ -21,8 +21,6 @@ import os import sys -import torch - # this is needed because Conda does not use `PYTHONPATH` env var while pip and virtualenv do PYTHONPATH = os.getenv('PYTHONPATH', '') if ':' in PYTHONPATH: @@ -54,13 +52,8 @@ def run_test_from_config(trainer_options, on_gpu, check_size=True): ckpt_path = trainer_options['weights_save_path'] trainer_options.update(callbacks=[ModelCheckpoint(dirpath=ckpt_path)]) - class TestModel(BoringModel): - - def training_epoch_end(self, outputs) -> None: - res = self.trainer.training_type_plugin.reduce(torch.tensor(1., device=self.device), reduce_op="sum") - assert res.sum() == self.trainer.training_type_plugin.world_size + model = BoringModel() - model = TestModel() trainer = Trainer(**trainer_options) trainer.fit(model) assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" diff --git a/tests/models/test_amp.py b/tests/models/test_amp.py index 0b9d6776c1aaa..9853db342436b 100644 --- a/tests/models/test_amp.py +++ b/tests/models/test_amp.py @@ -17,43 +17,24 @@ import pytest import torch from torch import optim -from torch.utils.data import DataLoader import tests.helpers.utils as tutils from pytorch_lightning import Trainer from pytorch_lightning.plugins.environments import SLURMEnvironment from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.utilities.exceptions import MisconfigurationException -from tests.helpers import BoringModel, RandomDataset +from tests.helpers import BoringModel from tests.helpers.runif import RunIf class AMPTestModel(BoringModel): - def _step(self, batch, batch_idx): - assert torch.is_autocast_enabled() - output = 
self(batch) - assert output.dtype == torch.float16 - loss = self.loss(batch, output) - return loss - def training_step(self, batch, batch_idx): - output = self._step(batch, batch_idx) - return {"loss": output} - - def validation_step(self, batch, batch_idx): - output = self._step(batch, batch_idx) - return {"x": output} - - def test_step(self, batch, batch_idx): - output = self._step(batch, batch_idx) - return {"y": output} - - def predict(self, batch, batch_idx, dataloader_idx=None): assert torch.is_autocast_enabled() output = self(batch) assert output.dtype == torch.float16 - return output + loss = self.loss(batch, output) + return {"loss": loss} @pytest.mark.skip(reason='dp + amp not supported currently') # TODO @@ -73,8 +54,6 @@ def test_amp_single_gpu_dp(tmpdir): model = AMPTestModel() # tutils.run_model_test(trainer_options, model) trainer.fit(model) - trainer.test(model) - trainer.predict(model, DataLoader(RandomDataset(32, 64))) assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" @@ -94,8 +73,6 @@ def test_amp_single_gpu_ddp_spawn(tmpdir): model = AMPTestModel() # tutils.run_model_test(trainer_options, model) trainer.fit(model) - trainer.test(model) - trainer.predict(model, DataLoader(RandomDataset(32, 64))) assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" @@ -135,8 +112,6 @@ def test_amp_multi_gpu_ddp_spawn(tmpdir): model = AMPTestModel() # tutils.run_model_test(trainer_options, model) trainer.fit(model) - trainer.test(model) - trainer.predict(model, DataLoader(RandomDataset(32, 64))) assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 1d55d4a5a63b7..0d1c7cf40a2bf 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import inspect from unittest import mock from unittest.mock import PropertyMock @@ -19,7 +20,7 @@ from pytorch_lightning import Callback, Trainer from pytorch_lightning.trainer.states import TrainerState -from tests.helpers import BoringModel, RandomDataset, BoringDataModule +from tests.helpers import BoringModel, RandomDataset from tests.helpers.runif import RunIf @@ -259,7 +260,7 @@ def on_train_batch_start(self, batch, batch_idx, dataloader_idx): def test_trainer_model_hook_system(tmpdir): - """Test the LightningModule hook system.""" + """Test the hooks system.""" class HookedModel(BoringModel): @@ -268,151 +269,149 @@ def __init__(self): self.called = [] def on_after_backward(self): - self.called.append("on_after_backward") + self.called.append(inspect.currentframe().f_code.co_name) super().on_after_backward() - def on_before_zero_grad(self, *args, **kwargs): - self.called.append("on_before_zero_grad") - super().on_before_zero_grad(*args, **kwargs) + def on_before_zero_grad(self, optimizer): + self.called.append(inspect.currentframe().f_code.co_name) + super().on_before_zero_grad(optimizer) def on_epoch_start(self): - self.called.append("on_epoch_start") + self.called.append(inspect.currentframe().f_code.co_name) super().on_epoch_start() def on_epoch_end(self): - self.called.append("on_epoch_end") + self.called.append(inspect.currentframe().f_code.co_name) super().on_epoch_end() def on_fit_start(self): - self.called.append("on_fit_start") + self.called.append(inspect.currentframe().f_code.co_name) super().on_fit_start() def on_fit_end(self): - self.called.append("on_fit_end") + self.called.append(inspect.currentframe().f_code.co_name) super().on_fit_end() - def on_hpc_load(self, *args, **kwargs): - self.called.append("on_hpc_load") - super().on_hpc_load(*args, **kwargs) + def on_hpc_load(self, checkpoint): + self.called.append(inspect.currentframe().f_code.co_name) + super().on_hpc_load(checkpoint) - def on_hpc_save(self, *args, **kwargs): - self.called.append("on_hpc_save") - super().on_hpc_save(*args, **kwargs) + def on_hpc_save(self, checkpoint): + self.called.append(inspect.currentframe().f_code.co_name) + super().on_hpc_save(checkpoint) - def on_load_checkpoint(self, *args, **kwargs): - self.called.append("on_load_checkpoint") - super().on_load_checkpoint(*args, **kwargs) + def on_load_checkpoint(self, checkpoint): + self.called.append(inspect.currentframe().f_code.co_name) + super().on_load_checkpoint(checkpoint) - def on_save_checkpoint(self, *args, **kwargs): - self.called.append("on_save_checkpoint") - super().on_save_checkpoint(*args, **kwargs) + def on_save_checkpoint(self, checkpoint): + self.called.append(inspect.currentframe().f_code.co_name) + super().on_save_checkpoint(checkpoint) def on_pretrain_routine_start(self): - self.called.append("on_pretrain_routine_start") + self.called.append(inspect.currentframe().f_code.co_name) super().on_pretrain_routine_start() def on_pretrain_routine_end(self): - self.called.append("on_pretrain_routine_end") + self.called.append(inspect.currentframe().f_code.co_name) super().on_pretrain_routine_end() def on_train_start(self): - self.called.append("on_train_start") + self.called.append(inspect.currentframe().f_code.co_name) super().on_train_start() def on_train_end(self): - self.called.append("on_train_end") + self.called.append(inspect.currentframe().f_code.co_name) super().on_train_end() - def on_train_batch_start(self, *args, **kwargs): - self.called.append("on_train_batch_start") - super().on_train_batch_start(*args, **kwargs) + 
def on_train_batch_start(self, batch, batch_idx, dataloader_idx): + self.called.append(inspect.currentframe().f_code.co_name) + super().on_train_batch_start(batch, batch_idx, dataloader_idx) - def on_train_batch_end(self, *args, **kwargs): - self.called.append("on_train_batch_end") - super().on_train_batch_end(*args, **kwargs) + def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx): + self.called.append(inspect.currentframe().f_code.co_name) + super().on_train_batch_end(outputs, batch, batch_idx, dataloader_idx) def on_train_epoch_start(self): - self.called.append("on_train_epoch_start") + self.called.append(inspect.currentframe().f_code.co_name) super().on_train_epoch_start() def on_train_epoch_end(self, outputs): - self.called.append("on_train_epoch_end") + self.called.append(inspect.currentframe().f_code.co_name) super().on_train_epoch_end(outputs) def on_validation_start(self): - self.called.append("on_validation_start") + self.called.append(inspect.currentframe().f_code.co_name) super().on_validation_start() def on_validation_end(self): - self.called.append("on_validation_end") + self.called.append(inspect.currentframe().f_code.co_name) super().on_validation_end() - def on_validation_batch_start(self, *args, **kwargs): - self.called.append("on_validation_batch_start") - super().on_validation_batch_start(*args, **kwargs) + def on_validation_batch_start(self, batch, batch_idx, dataloader_idx): + self.called.append(inspect.currentframe().f_code.co_name) + super().on_validation_batch_start(batch, batch_idx, dataloader_idx) - def on_validation_batch_end(self, *args, **kwargs): - self.called.append("on_validation_batch_end") - super().on_validation_batch_end(*args, **kwargs) + def on_validation_batch_end(self, outputs, batch, batch_idx, dataloader_idx): + self.called.append(inspect.currentframe().f_code.co_name) + super().on_validation_batch_end(outputs, batch, batch_idx, dataloader_idx) def on_validation_epoch_start(self): - self.called.append("on_validation_epoch_start") + self.called.append(inspect.currentframe().f_code.co_name) super().on_validation_epoch_start() - def on_validation_epoch_end(self, *args, **kwargs): - self.called.append("on_validation_epoch_end") - super().on_validation_epoch_end(*args, **kwargs) + def on_validation_epoch_end(self): + self.called.append(inspect.currentframe().f_code.co_name) + super().on_validation_epoch_end() def on_test_start(self): - self.called.append("on_test_start") + self.called.append(inspect.currentframe().f_code.co_name) super().on_test_start() - def on_test_batch_start(self, *args, **kwargs): - self.called.append("on_test_batch_start") - super().on_test_batch_start(*args, **kwargs) + def on_test_batch_start(self, batch, batch_idx, dataloader_idx): + self.called.append(inspect.currentframe().f_code.co_name) + super().on_test_batch_start(batch, batch_idx, dataloader_idx) - def on_test_batch_end(self, *args, **kwargs): - self.called.append("on_test_batch_end") - super().on_test_batch_end(*args, **kwargs) + def on_test_batch_end(self, outputs, batch, batch_idx, dataloader_idx): + self.called.append(inspect.currentframe().f_code.co_name) + super().on_test_batch_end(outputs, batch, batch_idx, dataloader_idx) def on_test_epoch_start(self): - self.called.append("on_test_epoch_start") + self.called.append(inspect.currentframe().f_code.co_name) super().on_test_epoch_start() - def on_test_epoch_end(self, *args, **kwargs): - self.called.append("on_test_epoch_end") - super().on_test_epoch_end(*args, **kwargs) + def 
on_test_epoch_end(self): + self.called.append(inspect.currentframe().f_code.co_name) + super().on_test_epoch_end() def on_validation_model_eval(self): - self.called.append("on_validation_model_eval") + self.called.append(inspect.currentframe().f_code.co_name) super().on_validation_model_eval() def on_validation_model_train(self): - self.called.append("on_validation_model_train") + self.called.append(inspect.currentframe().f_code.co_name) super().on_validation_model_train() def on_test_model_eval(self): - self.called.append("on_test_model_eval") + self.called.append(inspect.currentframe().f_code.co_name) super().on_test_model_eval() def on_test_model_train(self): - self.called.append("on_test_model_train") + self.called.append(inspect.currentframe().f_code.co_name) super().on_test_model_train() def on_test_end(self): - self.called.append("on_test_end") + self.called.append(inspect.currentframe().f_code.co_name) super().on_test_end() - def setup(self, stage=None): - self.called.append(f"setup_{stage}") - super().setup(stage=stage) - def teardown(self, stage=None): - self.called.append(f"teardown_{stage}") + self.called.append(inspect.currentframe().f_code.co_name) super().teardown(stage) model = HookedModel() + assert model.called == [] + # fit model trainer = Trainer( default_root_dir=tmpdir, @@ -428,13 +427,11 @@ def teardown(self, stage=None): trainer.fit(model) expected = [ - 'setup_fit', 'on_fit_start', 'on_pretrain_routine_start', 'on_pretrain_routine_end', 'on_validation_model_eval', 'on_validation_start', - 'on_epoch_start', 'on_validation_epoch_start', 'on_validation_batch_start', 'on_validation_batch_end', @@ -457,7 +454,6 @@ def teardown(self, stage=None): 'on_epoch_end', 'on_validation_model_eval', 'on_validation_start', - 'on_epoch_start', 'on_validation_epoch_start', 'on_validation_batch_start', 'on_validation_batch_end', @@ -468,7 +464,7 @@ def teardown(self, stage=None): 'on_validation_model_train', 'on_train_end', 'on_fit_end', - 'teardown_fit', + 'teardown', ] assert model.called == expected @@ -476,10 +472,8 @@ def teardown(self, stage=None): trainer.validate(model, verbose=False) expected = [ - 'setup_validate', 'on_validation_model_eval', 'on_validation_start', - 'on_epoch_start', 'on_validation_epoch_start', 'on_validation_batch_start', 'on_validation_batch_end', @@ -487,18 +481,16 @@ def teardown(self, stage=None): 'on_epoch_end', 'on_validation_end', 'on_validation_model_train', - 'teardown_validate', + 'teardown', ] assert model.called == expected model = HookedModel() - trainer.test(model, verbose=False) + trainer.test(model, verbose=False) expected = [ - 'setup_test', 'on_test_model_eval', 'on_test_start', - 'on_epoch_start', 'on_test_epoch_start', 'on_test_batch_start', 'on_test_batch_end', @@ -506,119 +498,6 @@ def teardown(self, stage=None): 'on_epoch_end', 'on_test_end', 'on_test_model_train', - 'teardown_test', + 'teardown', ] assert model.called == expected - - -def test_trainer_datamodule_hook_system(tmpdir): - """Test the LightningDataModule hook system.""" - - class HookedDataModule(BoringDataModule): - def __init__(self): - super().__init__() - self.called = [] - - def prepare_data(self): - self.called.append("prepare_data") - super().prepare_data() - - def setup(self, stage=None): - self.called.append(f"setup_{stage}") - super().setup(stage=stage) - - def teardown(self, stage=None): - self.called.append(f"teardown_{stage}") - super().teardown(stage=stage) - - def train_dataloader(self): - self.called.append("train_dataloader") - return 
super().train_dataloader() - - def test_dataloader(self): - self.called.append("test_dataloader") - return super().test_dataloader() - - def val_dataloader(self): - self.called.append("val_dataloader") - return super().val_dataloader() - - def predict_dataloader(self): - self.called.append("predict_dataloader") - - def transfer_batch_to_device(self, *args, **kwargs): - self.called.append("transfer_batch_to_device") - return super().transfer_batch_to_device(*args, **kwargs) - - def on_before_batch_transfer(self, *args, **kwargs): - self.called.append("on_before_batch_transfer") - return super().on_before_batch_transfer(*args, **kwargs) - - def on_after_batch_transfer(self, *args, **kwargs): - self.called.append("on_after_batch_transfer") - return super().on_after_batch_transfer(*args, **kwargs) - - model = BoringModel() - dm = HookedDataModule() - - trainer = Trainer( - default_root_dir=tmpdir, - max_epochs=1, - limit_val_batches=1, - limit_train_batches=2, - limit_test_batches=1, - progress_bar_refresh_rate=0, - weights_summary=None, - reload_dataloaders_every_epoch=True, - ) - trainer.fit(model, datamodule=dm) - - expected = [ - 'prepare_data', - 'setup_fit', - 'val_dataloader', - 'on_before_batch_transfer', - 'transfer_batch_to_device', - 'on_after_batch_transfer', - 'train_dataloader', - 'on_before_batch_transfer', - 'transfer_batch_to_device', - 'on_after_batch_transfer', - 'on_before_batch_transfer', - 'transfer_batch_to_device', - 'on_after_batch_transfer', - 'val_dataloader', - 'on_before_batch_transfer', - 'transfer_batch_to_device', - 'on_after_batch_transfer', - 'teardown_fit' - ] - assert dm.called == expected - - dm = HookedDataModule() - trainer.validate(model, datamodule=dm, verbose=False) - - expected = [ - 'prepare_data', - 'setup_validate', - 'val_dataloader', - 'on_before_batch_transfer', - 'transfer_batch_to_device', - 'on_after_batch_transfer', - 'teardown_validate' - ] - assert dm.called == expected - - dm = HookedDataModule() - trainer.test(model, datamodule=dm, verbose=False) - - expected = [ - 'prepare_data', - 'setup_test', - 'test_dataloader', - 'on_before_batch_transfer', - 'transfer_batch_to_device', - 'on_after_batch_transfer', - 'teardown_test' - ] - assert dm.called == expected diff --git a/tests/models/test_horovod.py b/tests/models/test_horovod.py index 3c8c9b0f36041..636979821b313 100644 --- a/tests/models/test_horovod.py +++ b/tests/models/test_horovod.py @@ -16,13 +16,11 @@ import shlex import subprocess import sys -from unittest.mock import patch import numpy as np import pytest import torch from sklearn.metrics import accuracy_score -from torch import optim import tests.helpers.pipelines as tpipes import tests.helpers.utils as tutils @@ -49,9 +47,6 @@ def _run_horovod(trainer_options, on_gpu=False): # for Horovod, we interpret `gpus` to be set per worker trainer_options.update(gpus=1 if on_gpu else None) tutils.reset_seed() - # todo: Find why coverage breaks CI. 
- # append = '-a' if '.coverage' in os.listdir(_PROJECT_ROOT) else '' # noqa E265 - # str(num_processes), sys.executable, '-m', 'coverage', 'run', '--source', 'pytorch_lightning', append, # noqa E265 cmdline = [ 'horovodrun', '-np', str(num_processes), sys.executable, TEST_SCRIPT, '--trainer-options', @@ -114,9 +109,7 @@ def test_horovod_multi_gpu(tmpdir): _run_horovod(trainer_options, on_gpu=True) -# https://discuss.pytorch.org/t/torch-cuda-amp-vs-nvidia-apex/74994 -# Check with (tgaddair) on Horovod issues if this feature is needed -@pytest.mark.skip(reason="Horovod currently doesn't work with Apex") # todo +@pytest.mark.skip(reason="Horovod has a problem with broadcast when using apex?") # todo @RunIf(min_gpus=2, skip_windows=True, amp_apex=True, horovod_nccl=True) def test_horovod_apex(tmpdir): """Test Horovod with multi-GPU support using apex amp.""" @@ -137,6 +130,7 @@ def test_horovod_apex(tmpdir): _run_horovod(trainer_options, on_gpu=True) +@pytest.mark.skip(reason="Skip till Horovod fixes integration with Native torch.cuda.amp") # todo @RunIf(min_gpus=2, skip_windows=True, amp_native=True, horovod_nccl=True) def test_horovod_amp(tmpdir): """Test Horovod with multi-GPU support using native amp.""" @@ -157,24 +151,6 @@ def test_horovod_amp(tmpdir): _run_horovod(trainer_options, on_gpu=True) -@RunIf(min_gpus=2, skip_windows=True, horovod_nccl=True) -def test_horovod_gather(tmpdir): - """Test Horovod with multi-GPU support using native amp.""" - trainer_options = dict( - default_root_dir=str(tmpdir), - weights_save_path=str(tmpdir), - gradient_clip_val=1.0, - progress_bar_refresh_rate=0, - max_epochs=1, - limit_train_batches=0.4, - limit_val_batches=0.2, - gpus=2, - deterministic=True, - accelerator='horovod', - ) - _run_horovod(trainer_options, on_gpu=True) - - @RunIf(min_gpus=1, skip_windows=True, horovod_nccl=True) def test_horovod_transfer_batch_to_gpu(tmpdir): @@ -203,7 +179,7 @@ def validation_step(self, batch, *args, **kwargs): tpipes.run_model_test_without_loggers(trainer_options, model) -@RunIf(skip_windows=True, horovod=True) +@RunIf(skip_windows=True) def test_horovod_multi_optimizer(tmpdir): model = BasicGAN() @@ -235,7 +211,8 @@ def get_optimizer_params(optimizer): assert get_model_params(model.discriminator) == get_optimizer_params(trainer.optimizers[1]) -@pytest.mark.skipif(reason="CI agent.jobstatus=Succeeded: Permission denied") +# TODO: unclear Horovod failure... +@pytest.mark.skip(reason="unclear Horovod failure...") @RunIf(skip_windows=True, horovod=True) def test_result_reduce_horovod(tmpdir): """Make sure result logging works with Horovod. @@ -277,7 +254,6 @@ def training_epoch_end(self, outputs) -> None: max_epochs=1, log_every_n_steps=1, weights_summary=None, - logger=False ) trainer.fit(model) @@ -285,8 +261,9 @@ def training_epoch_end(self, outputs) -> None: horovod.run(hvd_test_fn, np=2) -@pytest.mark.skipif(reason="CI agent.jobstatus=Succeeded: Permission denied") -@RunIf(skip_windows=True, horovod=True, num_gpus=2) +# TODO: unclear Horovod failure... 
+@pytest.mark.skip(reason="unclear Horovod failure...") +@RunIf(skip_windows=True, horovod=True) def test_accuracy_metric_horovod(): num_batches = 10 batch_size = 16 @@ -301,7 +278,10 @@ def sk_metric(preds, target): target = torch.randint(high=2, size=(num_batches, batch_size)) def _compute_batch(): - trainer = Trainer(fast_dev_run=True, accelerator='horovod', logger=False) + trainer = Trainer( + fast_dev_run=True, + accelerator='horovod', + ) assert isinstance(trainer.accelerator, CPUAccelerator) # TODO: test that we selected the correct training_type_plugin based on horovod flags @@ -309,7 +289,7 @@ def _compute_batch(): metric = Accuracy( compute_on_step=True, dist_sync_on_step=True, - dist_sync_fn=trainer.training_type_plugin.all_gather, + dist_sync_fn=trainer.training_type_plugin.gather_all_tensors, threshold=threshold ) @@ -334,45 +314,33 @@ def _compute_batch(): horovod.run(_compute_batch, np=2) -@RunIf(skip_windows=True, horovod=True) -def test_horovod_multi_optimizer_with_scheduling_stepping(tmpdir): - - class TestModel(BoringModel): - - def training_step(self, batch, batch_idx, optimizer_idx): - return super().training_step(batch, batch_idx) - - def configure_optimizers(self): - optimizer1 = optim.Adam(self.parameters(), lr=0.1) - optimizer2 = optim.Adam(self.parameters(), lr=0.1) - lr_scheduler1 = optim.lr_scheduler.StepLR(optimizer1, 1, gamma=0.1) - lr_scheduler2 = optim.lr_scheduler.StepLR(optimizer2, 1, gamma=0.1) - return [optimizer1, optimizer2], [lr_scheduler1, lr_scheduler2] - - model = TestModel() - model.training_epoch_end = None - - num_workers = 8 - init_lr = 0.1 * num_workers - - with patch('horovod.torch.size', return_value=8): - - # fit model - trainer = Trainer( - default_root_dir=tmpdir, - max_epochs=1, - limit_val_batches=0.5, - limit_train_batches=0.2, - accelerator='horovod' - ) - results = trainer.fit(model) - assert results == 1 - - adjusted_lr1 = [pg['lr'] for pg in trainer.optimizers[0].param_groups][0] - adjusted_lr2 = [pg['lr'] for pg in trainer.optimizers[1].param_groups][0] - - # Called ones after end of epoch with gamma=0.1 - assert pytest.approx(init_lr * 0.1) == adjusted_lr1 - - # Called every 3 steps, meaning for 1 epoch of 11 batches, it is called 3 times with gamma=0.1 - assert pytest.approx(init_lr * 0.1) == adjusted_lr2 +# @RunIf(skip_windows=True) +# def test_horovod_multi_optimizer_with_scheduling_stepping(tmpdir): +# model = BoringModel() +# model.configure_optimizers = model.configure_optimizers__multiple_schedulers +# +# num_workers = 8 +# init_lr = hparams.get('learning_rate') * num_workers +# +# with patch('pytorch_lightning.accelerators.legacy.horovod_backend.hvd.size') as mock_hvd_size: +# mock_hvd_size.return_value = 8 +# +# # fit model +# trainer = Trainer( +# default_root_dir=tmpdir, +# max_epochs=1, +# limit_val_batches=0.5, +# limit_train_batches=0.2, +# distributed_backend='horovod' +# ) +# results = trainer.fit(model) +# assert results == 1 +# +# adjusted_lr1 = [pg['lr'] for pg in trainer.optimizers[0].param_groups][0] +# adjusted_lr2 = [pg['lr'] for pg in trainer.optimizers[1].param_groups][0] +# +# # Called ones after end of epoch with gamma=0.1 +# assert pytest.approx(init_lr * 0.1) == adjusted_lr1 +# +# # Called every 3 steps, meaning for 1 epoch of 11 batches, it is called 3 times with gamma=0.1 +# assert pytest.approx(init_lr * 0.1) == adjusted_lr2 diff --git a/tests/models/test_tpu.py b/tests/models/test_tpu.py index b2ed0db87d8d5..0c922c99149fa 100644 --- a/tests/models/test_tpu.py +++ b/tests/models/test_tpu.py @@ 
-355,44 +355,3 @@ def test_reduce(rank): assert result.item() == 8 xmp.spawn(test_reduce, nprocs=8, start_method='fork') - - -@RunIf(tpu=True) -@pl_multi_process_test -@pytest.mark.parametrize("clip_val", [10]) -@mock.patch("pytorch_lightning.accelerators.tpu.xla_clip_grad_norm_") -def test_tpu_precision_16_clip_gradients(mock_clip_grad_norm, clip_val, tmpdir): - """ - Ensure that clip gradients is only called if the value is greater than 0. - TODO: Fix (test fails with parametrize) - """ - tutils.reset_seed() - trainer_options = dict( - default_root_dir=tmpdir, - progress_bar_refresh_rate=0, - max_epochs=1, - tpu_cores=1, - precision=16, - limit_train_batches=4, - limit_val_batches=4, - gradient_clip_val=clip_val, - ) - model = BoringModel() - tpipes.run_model_test(trainer_options, model, on_gpu=False, with_hpc=False) - - if clip_val > 0: - mock_clip_grad_norm.assert_called() - else: - mock_clip_grad_norm.assert_not_called() - - -@RunIf(tpu=True) -@pl_multi_process_test -def test_if_test_works_with_checkpoint_false(tmpdir): - """Ensure that model trains properly when `checkpoint_callback` is set to False.""" - - # Train a model on TPU - model = BoringModel() - trainer = Trainer(max_epochs=1, tpu_cores=8, default_root_dir=tmpdir, fast_dev_run=True, checkpoint_callback=False) - trainer.fit(model) - assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" diff --git a/tests/overrides/test_data_parallel.py b/tests/overrides/test_data_parallel.py index aaf47c82d5f08..3921e7ef33b8e 100644 --- a/tests/overrides/test_data_parallel.py +++ b/tests/overrides/test_data_parallel.py @@ -24,7 +24,7 @@ ("training", "training_step"), ("testing", "test_step"), ("validating", "validation_step"), - ("predicting", "predict_step"), + ("predicting", "predict"), ] ) def test_lightning_wrapper_module_methods(wrapper_class, stage): diff --git a/tests/plugins/test_custom_plugin.py b/tests/plugins/test_custom_plugin.py deleted file mode 100644 index 872b49ef48635..0000000000000 --- a/tests/plugins/test_custom_plugin.py +++ /dev/null @@ -1,41 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from pytorch_lightning import Trainer -from pytorch_lightning.plugins import DDPPlugin -from tests.helpers import BoringModel -from tests.helpers.runif import RunIf - - -class CustomParallelPlugin(DDPPlugin): - - def __init__(self, **kwargs): - super().__init__(**kwargs) - # Set to None so it will be overwritten by the accelerator connector. 
- self.sync_batchnorm = None - - -@RunIf(skip_windows=True) -def test_sync_batchnorm_set(tmpdir): - """Tests if sync_batchnorm is automatically set for custom plugin.""" - model = BoringModel() - plugin = CustomParallelPlugin() - assert plugin.sync_batchnorm is None - trainer = Trainer( - max_epochs=1, - plugins=[plugin], - default_root_dir=tmpdir, - sync_batchnorm=True, - ) - trainer.fit(model) - assert plugin.sync_batchnorm is True diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py index e6b15069f256a..cf5c23a824732 100644 --- a/tests/plugins/test_deepspeed_plugin.py +++ b/tests/plugins/test_deepspeed_plugin.py @@ -180,7 +180,7 @@ def test_deepspeed_defaults(tmpdir): assert isinstance(plugin.config["zero_optimization"], dict) -@RunIf(min_gpus=1, deepspeed=True, special=True) +@RunIf(deepspeed=True) def test_invalid_deepspeed_defaults_no_precision(tmpdir): """Test to ensure that using defaults, if precision is not set to 16, we throw an exception.""" model = BoringModel() @@ -195,7 +195,7 @@ def test_invalid_deepspeed_defaults_no_precision(tmpdir): trainer.fit(model) -@RunIf(min_gpus=1, deepspeed=True, special=True) +@RunIf(min_gpus=1, deepspeed=True) def test_warn_deepspeed_override_backward(tmpdir): """Test to ensure that if the backward hook in the LightningModule is overridden, we throw a warning.""" @@ -216,7 +216,7 @@ def backward(self, loss: Tensor, optimizer: Optimizer, optimizer_idx: int, *args trainer.fit(model) -@RunIf(min_gpus=1, deepspeed=True, special=True) +@RunIf(min_gpus=1, deepspeed=True) def test_deepspeed_run_configure_optimizers(tmpdir): """Test end to end that deepspeed works with defaults (without ZeRO as that requires compilation), whilst using configure_optimizers for optimizers and schedulers.""" @@ -246,7 +246,7 @@ def on_train_start(self) -> None: _assert_save_model_is_equal(model, tmpdir, trainer) -@RunIf(min_gpus=1, deepspeed=True, special=True) +@RunIf(min_gpus=1, deepspeed=True) def test_deepspeed_config(tmpdir, deepspeed_zero_config): """ Test to ensure deepspeed works correctly when passed a DeepSpeed config object including optimizers/schedulers @@ -280,7 +280,7 @@ def on_train_start(self) -> None: _assert_save_model_is_equal(model, tmpdir, trainer) -@RunIf(min_gpus=1, deepspeed=True, special=True) +@RunIf(min_gpus=1, deepspeed=True) def test_deepspeed_custom_precision_params(tmpdir): """Ensure if we modify the FP16 parameters via the DeepSpeedPlugin, the deepspeed config contains these changes.""" @@ -301,7 +301,7 @@ def on_train_start(self) -> None: trainer.fit(model) -@RunIf(min_gpus=1, deepspeed=True, special=True) +@RunIf(min_gpus=1, deepspeed=True) def test_deepspeed_assert_config_zero_offload_disabled(tmpdir, deepspeed_zero_config): """Ensure if we use a config and turn off cpu_offload, that this is set to False within the config.""" diff --git a/tests/plugins/test_double_plugin.py b/tests/plugins/test_double_plugin.py deleted file mode 100644 index f089b1c23149e..0000000000000 --- a/tests/plugins/test_double_plugin.py +++ /dev/null @@ -1,129 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import pytest - -import torch -from torch.utils.data import DataLoader, Dataset - -from pytorch_lightning import Trainer -from tests.helpers.boring_model import BoringModel, RandomDataset - - -class RandomFloatIntDataset(Dataset): - - def __init__(self, size, length): - self.len = length - self.float_data = torch.randn(length, size) - self.int_data = torch.randint(10, (length, 1)) - - def __getitem__(self, index): - return self.float_data[index], self.int_data[index] - - def __len__(self): - return self.len - - -class DoublePrecisionBoringModel(BoringModel): - - def training_step(self, batch, batch_idx): - float_data, int_data = batch - assert float_data.dtype == torch.float64 - output = self(float_data) - loss = self.loss(batch, output) - return {"loss": loss} - - def validation_step(self, batch, batch_idx): - assert batch.dtype == torch.float64 - output = self(batch) - loss = self.loss(batch, output) - return {"x": loss} - - def test_step(self, batch, batch_idx): - assert batch.dtype == torch.float64 - output = self(batch) - loss = self.loss(batch, output) - return {"y": loss} - - def predict_step(self, batch, batch_idx, dataloader_idx=None): - assert batch.dtype == torch.float64 - return self(batch) - - def on_fit_start(self): - assert self.layer.weight.dtype == torch.float64 - - def on_after_backward(self): - assert self.layer.weight.grad.dtype == torch.float64 - - def train_dataloader(self): - dataset = RandomFloatIntDataset(32, 64) - assert dataset.float_data.dtype == torch.float32 # Don't start with double data - return DataLoader(dataset) - - def predict_dataloader(self): - return DataLoader(RandomDataset(32, 64)) - - -class DoublePrecisionBoringModelNoForward(BoringModel): - - def training_step(self, batch, batch_idx): - assert batch.dtype == torch.float64 - output = self.layer(batch) - assert output.dtype == torch.float64 - loss = self.loss(batch, output) - return {"loss": loss} - - def validation_step(self, batch, batch_idx): - assert batch.dtype == torch.float64 - output = self.layer(batch) - assert output.dtype == torch.float64 - loss = self.loss(batch, output) - return {"x": loss} - - def test_step(self, batch, batch_idx): - assert batch.dtype == torch.float64 - output = self.layer(batch) - assert output.dtype == torch.float64 - loss = self.loss(batch, output) - return {"y": loss} - - def predict_step(self, batch, batch_idx, dataloader_idx=None): - assert batch.dtype == torch.float64 - output = self.layer(batch) - assert output.dtype == torch.float64 - return output - - def predict_dataloader(self): - return DataLoader(RandomDataset(32, 64)) - - -@pytest.mark.parametrize( - 'boring_model', - (DoublePrecisionBoringModel, DoublePrecisionBoringModelNoForward) -) -def test_double_precision(tmpdir, boring_model): - model = boring_model() - original_training_step = model.training_step - - trainer = Trainer( - max_epochs=2, - default_root_dir=tmpdir, - fast_dev_run=2, - precision=64, - log_every_n_steps=1, - ) - trainer.fit(model) - trainer.test(model) - trainer.predict(model) - - assert model.training_step == original_training_step diff --git 
a/tests/plugins/test_sharded_plugin.py b/tests/plugins/test_sharded_plugin.py index 655e12f046e04..a48f048160ee5 100644 --- a/tests/plugins/test_sharded_plugin.py +++ b/tests/plugins/test_sharded_plugin.py @@ -259,12 +259,10 @@ def test_ddp_sharded_plugin_resume_from_checkpoint_gpu_to_cpu(tmpdir): @RunIf(skip_windows=True, special=True, fairscale=True) -@pytest.mark.parametrize( - "trainer_kwargs", ( - dict(num_processes=2), - pytest.param(dict(gpus=2), marks=RunIf(min_gpus=2)), - ) -) +@pytest.mark.parametrize("trainer_kwargs", ( + {'num_processes': 2}, + pytest.param({'gpus': 2}, marks=RunIf(min_gpus=2)) +)) def test_ddp_sharded_plugin_test_multigpu(tmpdir, trainer_kwargs): """ Test to ensure we can use validate and test without fit diff --git a/tests/special_tests.sh b/tests/special_tests.sh index c381b5e9feeb6..b2ef6dfdacbf3 100644 --- a/tests/special_tests.sh +++ b/tests/special_tests.sh @@ -14,15 +14,9 @@ # Running special tests set -e export PL_RUNNING_SPECIAL_TESTS=1 -DEFAULTS="-m coverage run --source pytorch_lightning --append -m pytest --verbose --capture=no" +DEFAULTS="-m coverage run --source pytorch_lightning -a -m pytest --verbose --capture=no" python ${DEFAULTS} tests/trainer/optimization/test_manual_optimization.py::test_step_with_optimizer_closure_with_different_frequencies_ddp python ${DEFAULTS} tests/models/test_sync_batchnorm.py::test_sync_batchnorm_ddp -python ${DEFAULTS} tests/plugins/test_deepspeed_plugin.py::test_invalid_deepspeed_defaults_no_precision -python ${DEFAULTS} tests/plugins/test_deepspeed_plugin.py::test_warn_deepspeed_override_backward -python ${DEFAULTS} tests/plugins/test_deepspeed_plugin.py::test_deepspeed_run_configure_optimizers -python ${DEFAULTS} tests/plugins/test_deepspeed_plugin.py::test_deepspeed_config -python ${DEFAULTS} tests/plugins/test_deepspeed_plugin.py::test_deepspeed_custom_precision_params -python ${DEFAULTS} tests/plugins/test_deepspeed_plugin.py::test_deepspeed_assert_config_zero_offload_disabled python ${DEFAULTS} tests/plugins/test_deepspeed_plugin.py::test_deepspeed_multigpu python ${DEFAULTS} tests/plugins/test_rpc_plugin.py::test_rpc_function_calls_ddp python ${DEFAULTS} tests/plugins/test_rpc_sequential_plugin.py::test_rpc_sequential_plugin_manual @@ -34,9 +28,8 @@ python ${DEFAULTS} tests/trainer/test_trainer.py::test_trainer_predict_ddp python ${DEFAULTS} tests/trainer/test_trainer.py::test_trainer_predict_dp python ${DEFAULTS} tests/trainer/logging_/test_train_loop_logging_1_0.py::test_logging_sync_dist_true_ddp python ${DEFAULTS} tests/callbacks/test_pruning.py::test_pruning_callback_ddp -python ${DEFAULTS} tests/test_profiler.py::test_pytorch_profiler_trainer_ddp +python ${DEFAULTS} tests/trainer/test_trainer.py::test_pytorch_profiler_trainer_ddp python ${DEFAULTS} tests/models/test_hooks.py::test_transfer_batch_hook_ddp python ${DEFAULTS} tests/trainer/test_data_loading.py::test_replace_distrubuted_sampler_custom_dataloader_custom_batch_sampler python ${DEFAULTS} tests/trainer/optimization/test_manual_optimization.py::test_step_with_optimizer_closure_with_different_frequencies_ddp_with_toggle_model -python ${DEFAULTS} tests/checkpointing/test_checkpoint_callback_frequency.py::test_top_k_ddp -nvprof --profile-from-start off -o trace_name.prof -- python ${DEFAULTS} tests/test_profiler.py::test_pytorch_profiler_nested_emit_nvtx +nvprof --profile-from-start off -o trace_name.prof -- python ${DEFAULTS} tests/trainer/test_trainer.py::test_pytorch_profiler_nested_emit_nvtx diff --git a/tests/test_profiler.py 
b/tests/test_profiler.py index a6e33b3366f33..9b51ca7f7c6d2 100644 --- a/tests/test_profiler.py +++ b/tests/test_profiler.py @@ -13,22 +13,13 @@ # limitations under the License. import logging import os -import platform import time -from copy import deepcopy -from distutils.version import LooseVersion +from pathlib import Path import numpy as np import pytest -import torch -from pytorch_lightning import Callback, Trainer -from pytorch_lightning.profiler import AdvancedProfiler, PyTorchProfiler, SimpleProfiler -from pytorch_lightning.profiler.pytorch import RegisterRecordFunction -from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.imports import _KINETO_AVAILABLE -from tests.helpers import BoringModel -from tests.helpers.runif import RunIf +from pytorch_lightning.profiler import AdvancedProfiler, SimpleProfiler PROFILER_OVERHEAD_MAX_TOLERANCE = 0.0005 @@ -49,7 +40,14 @@ def _sleep_generator(durations): @pytest.fixture def simple_profiler(): - return SimpleProfiler() + profiler = SimpleProfiler() + return profiler + + +@pytest.fixture +def advanced_profiler(tmpdir): + profiler = AdvancedProfiler(output_filename=os.path.join(tmpdir, "profiler.txt")) + return profiler @pytest.mark.parametrize(["action", "expected"], [ @@ -95,6 +93,14 @@ def test_simple_profiler_overhead(simple_profiler, n_iter=5): assert all(durations < PROFILER_OVERHEAD_MAX_TOLERANCE) +def test_simple_profiler_describe(caplog, simple_profiler): + """Ensure the profiler won't fail when reporting the summary.""" + with caplog.at_level(logging.INFO): + simple_profiler.describe() + + assert "Profiler Report" in caplog.text + + def test_simple_profiler_value_errors(simple_profiler): """Ensure errors are raised where expected.""" @@ -110,77 +116,6 @@ def test_simple_profiler_value_errors(simple_profiler): simple_profiler.stop(action) -def test_simple_profiler_deepcopy(tmpdir): - simple_profiler = SimpleProfiler(dirpath=tmpdir, filename="test") - simple_profiler.describe() - assert deepcopy(simple_profiler) - - -def test_simple_profiler_log_dir(tmpdir): - """Ensure the profiler dirpath defaults to `trainer.log_dir` when not present""" - profiler = SimpleProfiler(filename="profiler") - assert profiler._log_dir is None - - model = BoringModel() - trainer = Trainer( - default_root_dir=tmpdir, - max_epochs=1, - profiler=profiler, - ) - trainer.fit(model) - - expected = tmpdir / "lightning_logs" / "version_0" - assert trainer.log_dir == expected - assert profiler._log_dir == trainer.log_dir - assert expected.join("fit-profiler.txt").exists() - - -@RunIf(skip_windows=True) -def test_simple_profiler_distributed_files(tmpdir): - """Ensure the proper files are saved in distributed""" - profiler = SimpleProfiler(dirpath=tmpdir, filename='profiler') - model = BoringModel() - trainer = Trainer( - default_root_dir=tmpdir, - fast_dev_run=2, - accelerator="ddp_cpu", - num_processes=2, - profiler=profiler, - logger=False, - ) - trainer.fit(model) - trainer.validate(model) - trainer.test(model) - - actual = set(os.listdir(profiler.dirpath)) - expected = {f"{stage}-profiler-{rank}.txt" for stage in ("fit", "validate", "test") for rank in (0, 1)} - assert actual == expected - - for f in profiler.dirpath.listdir(): - assert f.read_text('utf-8') - - -def test_simple_profiler_logs(tmpdir, caplog, simple_profiler): - """Ensure that the number of printed logs is correct""" - model = BoringModel() - trainer = Trainer( - default_root_dir=tmpdir, - fast_dev_run=2, - profiler=simple_profiler, - 
logger=False, - ) - with caplog.at_level(logging.INFO, logger="pytorch_lightning.profiler.profilers"): - trainer.fit(model) - trainer.test(model) - - assert caplog.text.count("Profiler Report") == 2 - - -@pytest.fixture -def advanced_profiler(tmpdir): - return AdvancedProfiler(dirpath=tmpdir, filename="profiler") - - @pytest.mark.parametrize(["action", "expected"], [ pytest.param("a", [3, 1]), pytest.param("b", [2]), @@ -239,8 +174,7 @@ def test_advanced_profiler_describe(tmpdir, advanced_profiler): pass # log to stdout and print to file advanced_profiler.describe() - path = advanced_profiler.dirpath / f"{advanced_profiler.filename}.txt" - data = path.read_text("utf-8") + data = Path(advanced_profiler.output_fname).read_text() assert len(data) > 0 @@ -253,259 +187,3 @@ def test_advanced_profiler_value_errors(advanced_profiler): advanced_profiler.start(action) advanced_profiler.stop(action) - - -def test_advanced_profiler_deepcopy(advanced_profiler): - advanced_profiler.describe() - assert deepcopy(advanced_profiler) - - -@pytest.fixture -def pytorch_profiler(tmpdir): - return PyTorchProfiler(dirpath=tmpdir, filename="profiler") - - -@RunIf(max_torch="1.8.1") -def test_pytorch_profiler_describe(pytorch_profiler): - """Ensure the profiler won't fail when reporting the summary.""" - with pytorch_profiler.profile("on_test_start"): - torch.tensor(0) - - # log to stdout and print to file - pytorch_profiler.describe() - path = pytorch_profiler.dirpath / f"{pytorch_profiler.filename}.txt" - data = path.read_text("utf-8") - assert len(data) > 0 - - -def test_pytorch_profiler_raises(pytorch_profiler): - """Ensure errors are raised where expected.""" - with pytest.raises(MisconfigurationException, match="profiled_functions` and `PyTorchProfiler.record"): - PyTorchProfiler(profiled_functions=["a"], record_functions=["b"]) - - -@RunIf(min_torch="1.6.0") -def test_advanced_profiler_cprofile_deepcopy(tmpdir): - """Checks for pickle issue reported in #6522""" - model = BoringModel() - trainer = Trainer( - default_root_dir=tmpdir, - fast_dev_run=True, - profiler="advanced", - stochastic_weight_avg=True, - ) - trainer.fit(model) - - -@RunIf(min_gpus=2, special=True) -def test_pytorch_profiler_trainer_ddp(tmpdir, pytorch_profiler): - """Ensure that the profiler can be given to the training and default step are properly recorded. """ - model = BoringModel() - trainer = Trainer( - default_root_dir=tmpdir, - max_epochs=1, - limit_train_batches=5, - limit_val_batches=5, - profiler=pytorch_profiler, - accelerator="ddp", - gpus=2, - ) - trainer.fit(model) - - expected = {'validation_step'} - if not _KINETO_AVAILABLE: - expected |= {'training_step_and_backward', 'training_step', 'backward'} - for name in expected: - assert sum(e.name == name for e in pytorch_profiler.function_events), name - - files = set(os.listdir(pytorch_profiler.dirpath)) - expected = f"fit-profiler-{trainer.local_rank}.txt" - assert expected in files - - path = pytorch_profiler.dirpath / expected - assert path.read_text("utf-8") - - if _KINETO_AVAILABLE: - files = os.listdir(pytorch_profiler.dirpath) - files = [file for file in files if file.endswith('.json')] - assert len(files) == 2, files - local_rank = trainer.local_rank - assert any(f'training_step_{local_rank}' in f for f in files) - assert any(f'validation_step_{local_rank}' in f for f in files) - - -def test_pytorch_profiler_trainer_test(tmpdir): - """Ensure that the profiler can be given to the trainer and test step are properly recorded. 
""" - pytorch_profiler = PyTorchProfiler(dirpath=tmpdir, filename="profile", schedule=None) - model = BoringModel() - trainer = Trainer( - default_root_dir=tmpdir, - max_epochs=1, - limit_test_batches=2, - profiler=pytorch_profiler, - ) - trainer.test(model) - - assert sum(e.name == 'test_step' for e in pytorch_profiler.function_events) - - path = pytorch_profiler.dirpath / f"test-{pytorch_profiler.filename}.txt" - assert path.read_text("utf-8") - - if _KINETO_AVAILABLE: - files = sorted([file for file in os.listdir(tmpdir) if file.endswith('.json')]) - assert any(f'test_step_{trainer.local_rank}' in f for f in files) - - -def test_pytorch_profiler_trainer_predict(tmpdir): - """Ensure that the profiler can be given to the trainer and predict function are properly recorded. """ - pytorch_profiler = PyTorchProfiler(dirpath=tmpdir, filename="profile", schedule=None) - model = BoringModel() - model.predict_dataloader = model.train_dataloader - trainer = Trainer( - default_root_dir=tmpdir, - max_epochs=1, - limit_predict_batches=2, - profiler=pytorch_profiler, - ) - trainer.predict(model) - - assert sum(e.name == 'predict_step' for e in pytorch_profiler.function_events) - path = pytorch_profiler.dirpath / f"predict-{pytorch_profiler.filename}.txt" - assert path.read_text("utf-8") - - -def test_pytorch_profiler_trainer_validate(tmpdir): - """Ensure that the profiler can be given to the trainer and validate function are properly recorded. """ - pytorch_profiler = PyTorchProfiler(dirpath=tmpdir, filename="profile", schedule=None) - model = BoringModel() - trainer = Trainer( - default_root_dir=tmpdir, - max_epochs=1, - limit_val_batches=2, - profiler=pytorch_profiler, - ) - trainer.validate(model) - - assert sum(e.name == 'validation_step' for e in pytorch_profiler.function_events) - - path = pytorch_profiler.dirpath / f"validate-{pytorch_profiler.filename}.txt" - assert path.read_text("utf-8") - - -def test_pytorch_profiler_nested(tmpdir): - """Ensure that the profiler handles nested context""" - - pytorch_profiler = PyTorchProfiler( - record_functions={"a", "b", "c"}, use_cuda=False, dirpath=tmpdir, filename="profiler", schedule=None - ) - - with pytorch_profiler.profile("a"): - a = torch.ones(42) - with pytorch_profiler.profile("b"): - b = torch.zeros(42) - with pytorch_profiler.profile("c"): - _ = a + b - - pytorch_profiler.describe() - - events_name = {e.name for e in pytorch_profiler.function_events} - - if platform.system() == "Windows": - expected = {'a', 'add', 'b', 'c', 'profiler::_record_function_enter', 'profiler::_record_function_exit'} - else: - expected = { - 'signed char', 'add', 'profiler::_record_function_exit', 'bool', 'char', 'profiler::_record_function_enter' - } - - if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"): - expected = {'add', 'zeros', 'ones', 'zero_', 'b', 'fill_', 'c', 'a', 'empty'} - - if LooseVersion(torch.__version__) >= LooseVersion("1.7.0"): - expected = { - 'aten::zeros', 'aten::add', 'aten::zero_', 'c', 'b', 'a', 'aten::fill_', 'aten::empty', 'aten::ones' - } - - assert events_name == expected, (events_name, torch.__version__, platform.system()) - - -@RunIf(min_gpus=1, special=True) -def test_pytorch_profiler_nested_emit_nvtx(tmpdir): - """ - This test check emit_nvtx is correctly supported - """ - profiler = PyTorchProfiler(use_cuda=True, emit_nvtx=True) - - model = BoringModel() - trainer = Trainer( - fast_dev_run=True, - profiler=profiler, - gpus=1, - ) - trainer.fit(model) - - -@RunIf(min_torch="1.5.0") -def 
test_register_record_function(tmpdir): - - use_cuda = torch.cuda.is_available() - pytorch_profiler = PyTorchProfiler( - export_to_chrome=False, - record_functions={"a"}, - use_cuda=use_cuda, - dirpath=tmpdir, - filename="profiler", - schedule=None, - on_trace_ready=None, - ) - - class TestModel(BoringModel): - - def __init__(self): - super().__init__() - self.layer = torch.nn.Sequential(torch.nn.Linear(1, 1), torch.nn.ReLU(), torch.nn.Linear(1, 1)) - - model = TestModel() - input = torch.rand((1, 1)) - - if use_cuda: - model = model.cuda() - input = input.cuda() - - with pytorch_profiler.profile("a"): - with RegisterRecordFunction(model): - model(input) - - pytorch_profiler.describe() - event_names = [e.name for e in pytorch_profiler.function_events] - assert 'torch.nn.modules.container.Sequential: layer' in event_names - assert 'torch.nn.modules.linear.Linear: layer.0' in event_names - assert 'torch.nn.modules.activation.ReLU: layer.1' in event_names - assert 'torch.nn.modules.linear.Linear: layer.2' in event_names - - -@pytest.mark.parametrize("cls", (SimpleProfiler, AdvancedProfiler, PyTorchProfiler)) -def test_profiler_teardown(tmpdir, cls): - """ - This test checks if profiler teardown method is called when trainer is exiting. - """ - - class TestCallback(Callback): - - def on_fit_end(self, trainer, *args, **kwargs) -> None: - # describe sets it to None - assert trainer.profiler._output_file is None - - profiler = cls(dirpath=tmpdir, filename="profiler") - model = BoringModel() - trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, profiler=profiler, callbacks=[TestCallback()]) - trainer.fit(model) - - assert profiler._output_file is None - - -def test_pytorch_profiler_deepcopy(tmpdir): - pytorch_profiler = PyTorchProfiler(dirpath=tmpdir, filename="profiler", schedule=None) - pytorch_profiler.start("on_train_start") - torch.tensor(1) - pytorch_profiler.describe() - assert deepcopy(pytorch_profiler) diff --git a/tests/trainer/flags/test_env_vars.py b/tests/trainer/flags/test_env_vars.py index 65b251a6633b5..ba76820d15ee8 100644 --- a/tests/trainer/flags/test_env_vars.py +++ b/tests/trainer/flags/test_env_vars.py @@ -12,12 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import os -from unittest import mock from pytorch_lightning import Trainer -def test_passing_no_env_variables(): +def test_passing_env_variables(tmpdir): """Testing overwriting trainer arguments """ trainer = Trainer() assert trainer.logger is not None @@ -26,29 +25,17 @@ def test_passing_no_env_variables(): assert trainer.logger is None assert trainer.max_steps == 42 - -@mock.patch.dict(os.environ, {"PL_TRAINER_LOGGER": "False", "PL_TRAINER_MAX_STEPS": "7"}) -def test_passing_env_variables_only(): - """Testing overwriting trainer arguments """ + os.environ['PL_TRAINER_LOGGER'] = 'False' + os.environ['PL_TRAINER_MAX_STEPS'] = '7' trainer = Trainer() assert trainer.logger is None assert trainer.max_steps == 7 - -@mock.patch.dict(os.environ, {"PL_TRAINER_LOGGER": "True", "PL_TRAINER_MAX_STEPS": "7"}) -def test_passing_env_variables_defaults(): - """Testing overwriting trainer arguments """ + os.environ['PL_TRAINER_LOGGER'] = 'True' trainer = Trainer(False, max_steps=42) - assert trainer.logger is None - assert trainer.max_steps == 42 - + assert trainer.logger is not None + assert trainer.max_steps == 7 -@mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1", "PL_TRAINER_GPUS": "2"}) -@mock.patch('torch.cuda.device_count', return_value=2) -@mock.patch('torch.cuda.is_available', return_value=True) -def test_passing_env_variables_gpus(cuda_available_mock, device_count_mock): - """Testing overwriting trainer arguments """ - trainer = Trainer() - assert trainer.gpus == 2 - trainer = Trainer(gpus=1) - assert trainer.gpus == 1 + # this has to be cleaned + del os.environ['PL_TRAINER_LOGGER'] + del os.environ['PL_TRAINER_MAX_STEPS'] diff --git a/tests/trainer/logging_/test_eval_loop_logging_1_0.py b/tests/trainer/logging_/test_eval_loop_logging_1_0.py index 674e2aeb6511b..72084454ba10d 100644 --- a/tests/trainer/logging_/test_eval_loop_logging_1_0.py +++ b/tests/trainer/logging_/test_eval_loop_logging_1_0.py @@ -126,6 +126,7 @@ def validation_step_end(self, acc): def validation_epoch_end(self, outputs): self.log('g', torch.tensor(2, device=self.device), on_epoch=True) self.validation_epoch_end_called = True + assert len(self.trainer.evaluation_loop.outputs) == 0 def backward(self, loss, optimizer, optimizer_idx): return LightningModule.backward(self, loss, optimizer, optimizer_idx) @@ -495,15 +496,9 @@ def on_validation_start(self, trainer, pl_module): ) def on_epoch_start(self, trainer, pl_module): - if trainer.validating: - self.make_logging( - pl_module, - 'on_epoch_start', - 2, - on_steps=self.choices, - on_epochs=self.choices, - prob_bars=self.choices - ) + self.make_logging( + pl_module, 'on_epoch_start', 2, on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices + ) def on_validation_epoch_start(self, trainer, pl_module): self.make_logging( @@ -535,7 +530,7 @@ def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, self.count += 1 def on_epoch_end(self, trainer, pl_module): - if trainer.validating: + if not trainer.training: self.make_logging( pl_module, 'on_epoch_end', 8, on_steps=[False], on_epochs=self.choices, prob_bars=self.choices ) @@ -573,6 +568,7 @@ def validation_step(self, batch, batch_idx): callbacks=[test_callback], ) trainer.fit(model) + trainer.test() assert test_callback.funcs_called_count["on_epoch_start"] == 1 # assert test_callback.funcs_called_count["on_batch_start"] == 1 diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py index d14ed71940328..3db0a8eaa065b 100644 --- 
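The environment-variable behaviour exercised by the test above boils down to: a `PL_TRAINER_<ARG>` variable supplies the value for the corresponding `Trainer` argument when that argument is not given explicitly. A minimal sketch (the values are illustrative only):

import os

from pytorch_lightning import Trainer

os.environ["PL_TRAINER_MAX_STEPS"] = "7"    # picked up when max_steps is not passed explicitly
os.environ["PL_TRAINER_LOGGER"] = "False"   # disables the default logger

trainer = Trainer()                          # no arguments given, so the environment values apply
assert trainer.max_steps == 7
assert trainer.logger is None

# clean up so the process-wide environment does not leak into later code
del os.environ["PL_TRAINER_MAX_STEPS"]
del os.environ["PL_TRAINER_LOGGER"]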
a/tests/trainer/logging_/test_logger_connector.py +++ b/tests/trainer/logging_/test_logger_connector.py @@ -447,38 +447,13 @@ def is_float(value: Any) -> bool: "y": torch.tensor(2), "z": acc(preds, targets), }) - metric_holder.convert(device) + metric_holder.convert(False, device) metrics = metric_holder.metrics assert excepted_function(metrics["x"]) assert excepted_function(metrics["y"]) assert excepted_function(metrics["z"]) -def test_metric_holder_raises(tmpdir): - """Check that an error is raised when trying to convert non-scalar tensors""" - - class TestModel(BoringModel): - - def validation_step(self, batch, *args, **kwargs): - output = self(batch) - return {"test": output} - - def test_step(self, *args, **kwargs): - return self.validation_step(*args, **kwargs) - - model = TestModel() - model.validation_epoch_end = None - model.test_epoch_end = None - - trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) - - match = "The metric `test` does not contain a single element" - with pytest.raises(MisconfigurationException, match=match): - trainer.validate(model) - with pytest.raises(MisconfigurationException, match=match): - trainer.test(model) - - def test_logging_to_progress_bar_with_reserved_key(tmpdir): """ Test that logging a metric with a reserved name to the progress bar raises a warning. """ @@ -490,7 +465,10 @@ def training_step(self, *args, **kwargs): return output model = TestModel() - trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) + trainer = Trainer( + default_root_dir=tmpdir, + max_steps=2, + ) with pytest.warns(UserWarning, match="The progress bar already tracks a metric with the .* 'loss'"): trainer.fit(model) diff --git a/tests/trainer/optimization/test_optimizers.py b/tests/trainer/optimization/test_optimizers.py index f13448187364c..34845c46b45eb 100644 --- a/tests/trainer/optimization/test_optimizers.py +++ b/tests/trainer/optimization/test_optimizers.py @@ -292,9 +292,7 @@ def test_init_optimizers_during_evaluation(tmpdir, fn): """ Test that optimizers is an empty list during evaluation """ - class TestModel(BoringModel): - def configure_optimizers(self): optimizer1 = torch.optim.Adam(self.parameters(), lr=0.1) optimizer2 = torch.optim.Adam(self.parameters(), lr=0.1) diff --git a/tests/trainer/properties/test_get_model.py b/tests/trainer/properties/test_get_model.py index 5dc1ea5de4e8a..4dc5b5f34b50c 100644 --- a/tests/trainer/properties/test_get_model.py +++ b/tests/trainer/properties/test_get_model.py @@ -13,6 +13,7 @@ # limitations under the License. 
from pytorch_lightning import Trainer +from tests.accelerators import DDPLauncher from tests.helpers.boring_model import BoringModel from tests.helpers.runif import RunIf @@ -80,3 +81,25 @@ def test_get_model_gpu(tmpdir): gpus=1, ) trainer.fit(model) + + +@RunIf(min_gpus=1, skip_windows=True) +@DDPLauncher.run("--accelerator [accelerator]", max_epochs=["1"], accelerator=["ddp", "ddp_spawn"]) +def test_get_model_ddp_gpu(tmpdir, args=None): + """ + Tests that `trainer.lightning_module` extracts the model correctly when using GPU + ddp accelerators + """ + + model = TrainerGetModel() + + limit_train_batches = 2 + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=limit_train_batches, + limit_val_batches=2, + max_epochs=1, + gpus=1, + accelerator=args.accelerator + ) + trainer.fit(model) + return 1 diff --git a/tests/trainer/test_config_validator.py b/tests/trainer/test_config_validator.py index 9fccd9b36440a..59e10480a485e 100644 --- a/tests/trainer/test_config_validator.py +++ b/tests/trainer/test_config_validator.py @@ -12,11 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. import pytest -import torch -from pytorch_lightning import LightningDataModule, LightningModule, Trainer +from pytorch_lightning import Trainer from pytorch_lightning.utilities.exceptions import MisconfigurationException -from tests.helpers import BoringModel, RandomDataset +from tests.helpers import BoringModel def test_wrong_train_setting(tmpdir): @@ -102,48 +101,3 @@ def test_val_loop_config(tmpdir): model = BoringModel() model.validation_step = None trainer.validate(model) - - -@pytest.mark.parametrize("datamodule", [False, True]) -def test_trainer_predict_verify_config(tmpdir, datamodule): - - class TestModel(LightningModule): - - def __init__(self): - super().__init__() - self.layer = torch.nn.Linear(32, 2) - - def forward(self, x): - return self.layer(x) - - class TestLightningDataModule(LightningDataModule): - - def __init__(self, dataloaders): - super().__init__() - self._dataloaders = dataloaders - - def test_dataloader(self): - return self._dataloaders - - def predict_dataloader(self): - return self._dataloaders - - dataloaders = [torch.utils.data.DataLoader(RandomDataset(32, 2)), torch.utils.data.DataLoader(RandomDataset(32, 2))] - - model = TestModel() - - trainer = Trainer(default_root_dir=tmpdir) - - if datamodule: - datamodule = TestLightningDataModule(dataloaders) - results = trainer.predict(model, datamodule=datamodule) - else: - results = trainer.predict(model, dataloaders=dataloaders) - - assert len(results) == 2 - assert results[0][0].shape == torch.Size([1, 2]) - - model.predict_dataloader = None - - with pytest.raises(MisconfigurationException, match="Dataloader not found for `Trainer.predict`"): - trainer.predict(model) diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index 505af173b7910..e4aea38fb7f37 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -137,7 +137,6 @@ def test_multiple_eval_dataloader(tmpdir, ckpt_path): """Verify multiple evaluation dataloaders.""" class MultipleTestDataloaderModel(EvalModelTemplate): - def test_dataloader(self): return [self.dataloader(train=False), self.dataloader(train=False)] @@ -1159,71 +1158,3 @@ def test_replace_sampler_with_multiprocessing_context(tmpdir): new_data_loader = trainer.replace_sampler(train, SequentialSampler(train.dataset)) assert (new_data_loader.multiprocessing_context == 
train.multiprocessing_context) - - -def test_request_dataloader(tmpdir): - """ - This test asserts dataloader can be modified and properly set to the trainer. - """ - - class DataLoaderWrapper: - - def __init__(self, loader): - self.loader = loader - self._iter = iter(self.loader) - - def __iter__(self): - self._iter = iter(self.loader) - return self._iter - - def __next__(self): - return next(self._iter) - - class DataLoaderFunc: - - def __init__(self, loader): - self.loader = loader - - def __call__(self): - return self.loader - - class TestModel(BoringModel): - - def __init__(self): - super().__init__() - self.on_train_dataloader_called = False - self.on_train_batch_start_called = False - self.on_val_dataloader_called = False - self.on_val_batch_start_called = False - - def on_train_dataloader(self) -> None: - loader = self.train_dataloader() - self.train_dataloader = DataLoaderFunc(DataLoaderWrapper(loader)) - self.on_train_dataloader_called = True - - def on_train_batch_start(self, batch, batch_idx: int, dataloader_idx: int) -> None: - assert isinstance(self.trainer.train_dataloader.loaders, DataLoaderWrapper) - self.on_train_batch_start_called = True - - def on_val_dataloader(self) -> None: - loader = self.val_dataloader() - self.val_dataloader = DataLoaderFunc(DataLoaderWrapper(loader)) - self.on_val_dataloader_called = True - - def on_validation_batch_start(self, batch, batch_idx: int, dataloader_idx: int) -> None: - assert isinstance(self.trainer.val_dataloaders[0], DataLoaderWrapper) - self.on_val_batch_start_called = True - - trainer = Trainer( - default_root_dir=tmpdir, - limit_train_batches=2, - limit_val_batches=2, - max_epochs=1, - ) - model = TestModel() - trainer.fit(model) - trainer.test(model) - assert model.on_train_dataloader_called - assert model.on_train_batch_start_called - assert model.on_val_dataloader_called - assert model.on_val_batch_start_called diff --git a/tests/trainer/test_evaluation_loop.py b/tests/trainer/test_evaluation_loop.py deleted file mode 100644 index 3fe58afde7341..0000000000000 --- a/tests/trainer/test_evaluation_loop.py +++ /dev/null @@ -1,42 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-from unittest import mock - -from pytorch_lightning import Trainer -from tests.helpers.boring_model import BoringModel - - -@mock.patch("pytorch_lightning.trainer.evaluation_loop.EvaluationLoop.call_on_evaluation_epoch_end_hook") -def test_call_on_evaluation_epoch_end_hook(eval_epoch_end_mock, tmpdir): - """ - Tests that `call_on_evaluation_epoch_end_hook` is called - for `on_validation_epoch_end` and `on_test_epoch_end` hooks - """ - model = BoringModel() - - trainer = Trainer( - default_root_dir=tmpdir, - limit_train_batches=2, - limit_val_batches=2, - max_epochs=2, - weights_summary=None, - ) - - trainer.fit(model) - # sanity + 2 epochs - assert eval_epoch_end_mock.call_count == 3 - - trainer.test() - # sanity + 2 epochs + called once for test - assert eval_epoch_end_mock.call_count == 4 diff --git a/tests/trainer/test_lr_finder.py b/tests/trainer/test_lr_finder.py index 44510eb16184d..e85c43361976d 100644 --- a/tests/trainer/test_lr_finder.py +++ b/tests/trainer/test_lr_finder.py @@ -271,27 +271,3 @@ def test_lr_finder_fails_fast_on_bad_config(tmpdir): trainer = Trainer(default_root_dir=tmpdir, max_steps=2, auto_lr_find=True) with pytest.raises(MisconfigurationException, match='should have one of these fields'): trainer.tune(BoringModel()) - - -def test_lr_find_with_bs_scale(tmpdir): - """ Test that lr_find runs with batch_size_scaling """ - - class BoringModelTune(BoringModel): - - def __init__(self, learning_rate=0.1, batch_size=2): - super().__init__() - self.save_hyperparameters() - - model = BoringModelTune() - before_lr = model.hparams.learning_rate - - # logger file to get meta - trainer = Trainer( - default_root_dir=tmpdir, - max_epochs=3, - ) - bs = trainer.tuner.scale_batch_size(model) - lr = trainer.tuner.lr_find(model).suggestion() - - assert lr != before_lr - assert isinstance(bs, int) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 4ca2f737f5106..5b06879b1f6d1 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -17,6 +17,7 @@ import sys from argparse import Namespace from copy import deepcopy +from distutils.version import LooseVersion from pathlib import Path from unittest.mock import ANY, call, patch @@ -42,6 +43,12 @@ from tests.helpers.runif import RunIf +@pytest.fixture +def pytorch_profiler(tmpdir): + profiler = PyTorchProfiler(output_filename=os.path.join(tmpdir, "profiler.txt"), local_rank=0) + return profiler + + @pytest.mark.parametrize("url_ckpt", [True, False]) def test_no_val_module(monkeypatch, tmpdir, tmpdir_server, url_ckpt): """Tests use case where trainer saves the model, and user loads it from tags independently.""" @@ -596,9 +603,7 @@ def test_benchmark_option(tmpdir): @pytest.mark.parametrize("save_top_k", (-1, 0, 1, 2)) @pytest.mark.parametrize("fn", ("validate", "test")) def test_tested_checkpoint_path(tmpdir, ckpt_path, save_top_k, fn): - class TestModel(BoringModel): - def validation_step(self, batch, batch_idx): self.log("foo", -batch_idx) return super().validation_step(batch, batch_idx) @@ -1440,30 +1445,16 @@ def test_trainer_predict_no_return(tmpdir): class CustomBoringModel(BoringModel): - def predict_step(self, batch, batch_idx, dataloader_idx=None): + def predict(self, batch, batch_idx, dataloader_idx=None): if (batch_idx + 1) % 2 == 0: return - return super().predict_step(batch, batch_idx, dataloader_idx) + return super().predict(batch, batch_idx, dataloader_idx) with pytest.warns(UserWarning, match='predict returned None'): predict(tmpdir, None, None, 1, 
model=CustomBoringModel()) -def test_trainer_predict_grad(tmpdir): - - class CustomBoringModel(BoringModel): - - def predict_step(self, batch, batch_idx, dataloader_idx=None): - assert batch.expand_as(batch).grad_fn is None - return super().predict_step(batch, batch_idx, dataloader_idx) - - predict(tmpdir, None, None, 1, model=CustomBoringModel()) - - x = torch.zeros(1, requires_grad=True) - assert x.expand_as(x).grad_fn is not None - - @pytest.mark.parametrize('datamodule', [False, True]) def test_trainer_predict_cpu(tmpdir, datamodule): predict(tmpdir, None, None, 1, datamodule=datamodule) @@ -1495,6 +1486,124 @@ def test_trainer_predict_ddp_cpu(tmpdir): predict(tmpdir, "ddp_cpu", 0, 2) +def test_pytorch_profiler_describe(pytorch_profiler): + """Ensure the profiler won't fail when reporting the summary.""" + with pytorch_profiler.profile("test_step"): + pass + + # log to stdout and print to file + pytorch_profiler.describe() + data = Path(pytorch_profiler.output_fname).read_text() + assert len(data) > 0 + + +def test_pytorch_profiler_value_errors(pytorch_profiler): + """Ensure errors are raised where expected.""" + + action = "test_step" + with pytest.raises(ValueError): + pytorch_profiler.stop(action) + + pytorch_profiler.start(action) + pytorch_profiler.stop(action) + + +@RunIf(min_gpus=2, special=True) +@pytest.mark.parametrize("use_output_filename", [False, True]) +def test_pytorch_profiler_trainer_ddp(tmpdir, use_output_filename): + """Ensure that the profiler can be given to the training and default step are properly recorded. """ + + if use_output_filename: + output_filename = os.path.join(tmpdir, "profiler.txt") + else: + output_filename = None + + profiler = PyTorchProfiler(output_filename=output_filename) + + model = BoringModel() + trainer = Trainer( + fast_dev_run=True, + profiler=profiler, + accelerator="ddp", + gpus=2, + ) + trainer.fit(model) + + enabled = use_output_filename or not use_output_filename and profiler.local_rank == 0 + + if enabled: + assert len(profiler.summary()) > 0 + assert set(profiler.profiled_actions.keys()) == {'training_step_and_backward', 'validation_step'} + else: + assert profiler.summary() is None + assert set(profiler.profiled_actions.keys()) == set() + + if use_output_filename: + profiler.describe() + data = Path(profiler.output_fname).read_text() + assert len(data) > 0 + + +def test_pytorch_profiler_nested(tmpdir): + """Ensure that the profiler handles nested context""" + + pytorch_profiler = PyTorchProfiler( + profiled_functions=["a", "b", "c"], use_cuda=False, output_filename=os.path.join(tmpdir, "profiler.txt") + ) + + with pytorch_profiler.profile("a"): + a = torch.ones(42) + with pytorch_profiler.profile("b"): + b = torch.zeros(42) + with pytorch_profiler.profile("c"): + _ = a + b + + pa = pytorch_profiler.profiled_actions + + # From PyTorch 1.8.0, less operation are being traced. + if LooseVersion(torch.__version__) >= LooseVersion("1.8.0"): + expected_ = { + 'a': ['ones', 'empty', 'fill_', 'zeros', 'empty', 'zero_', 'add'], + 'b': ['zeros', 'empty', 'zero_'], + 'c': ['add'], + } + # From PyTorch 1.6.0, more operation are being traced. 
+ elif LooseVersion(torch.__version__) >= LooseVersion("1.6.0"): + expected_ = { + 'a': ['ones', 'empty', 'fill_', 'zeros', 'empty', 'zero_', 'fill_', 'add', 'empty'], + 'b': ['zeros', 'empty', 'zero_', 'fill_'], + 'c': ['add', 'empty'], + } + else: + expected_ = { + 'a': ['add'], + 'b': [], + 'c': ['add'], + } + + for n in ('a', 'b', 'c'): + pa[n] = [e.name for e in pa[n]] + if LooseVersion(torch.__version__) >= LooseVersion("1.7.1"): + pa[n] = [e.replace("aten::", "") for e in pa[n]] + assert pa[n] == expected_[n] + + +@RunIf(min_gpus=1, special=True) +def test_pytorch_profiler_nested_emit_nvtx(tmpdir): + """ + This test check emit_nvtx is correctly supported + """ + profiler = PyTorchProfiler(use_cuda=True, emit_nvtx=True) + + model = BoringModel() + trainer = Trainer( + fast_dev_run=True, + profiler=profiler, + gpus=1, + ) + trainer.fit(model) + + @pytest.mark.parametrize( ["limit_train_batches", "global_step", "num_training_batches", "current_epoch", "should_train"], [(0.2, 0, 0, 0, False), (0.5, 10, 2, 4, True)], @@ -1745,35 +1854,3 @@ def test_check_val_every_n_epoch_exception(tmpdir): max_epochs=1, check_val_every_n_epoch=1.2, ) - - -def test_trainer_attach_data_pipeline_to_model(tmpdir): - - class DataPipeline: - - pass - - class TestDataModule(LightningDataModule): - - data_pipeline = DataPipeline() - - def train_dataloader(self): - return DataLoader(RandomDataset(32, 64)) - - def val_dataloader(self): - return DataLoader(RandomDataset(32, 64)) - - def test_dataloader(self): - return DataLoader(RandomDataset(32, 64)) - - class TestCallback(Callback): - - def on_fit_start(self, trainer, pl_module: LightningModule) -> None: - """Called when fit begins""" - assert isinstance(pl_module.data_pipeline, DataPipeline) - - model = BoringModel() - dm = TestDataModule() - - trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, callbacks=[TestCallback()]) - trainer.fit(model, datamodule=dm) diff --git a/tests/tuner/test_scale_batch_size.py b/tests/tuner/test_scale_batch_size.py deleted file mode 100644 index ad7fc57092f32..0000000000000 --- a/tests/tuner/test_scale_batch_size.py +++ /dev/null @@ -1,65 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-import pytest -from torch.utils.data import DataLoader - -from pytorch_lightning import Trainer -from pytorch_lightning.tuner.tuning import Tuner -from tests.helpers import BoringDataModule, BoringModel - - -class BatchSizeDataModule(BoringDataModule): - - def __init__(self, batch_size=None): - super().__init__() - if batch_size is not None: - self.batch_size = batch_size - - def train_dataloader(self): - return DataLoader(self.random_train, batch_size=getattr(self, "batch_size", 1)) - - -class BatchSizeModel(BoringModel): - - def __init__(self, batch_size=None): - super().__init__() - if batch_size is not None: - self.batch_size = batch_size - - -@pytest.mark.parametrize( - "model,datamodule", [ - (BatchSizeModel(2), None), - (BatchSizeModel(2), BatchSizeDataModule(2)), - (BatchSizeModel(2), BatchSizeDataModule(None)), - (BatchSizeModel(None), BatchSizeDataModule(2)), - ] -) -def test_scale_batch_size_method_with_model_or_datamodule(tmpdir, model, datamodule): - """ Test the tuner method `Tuner.scale_batch_size` with a datamodule. """ - trainer = Trainer( - default_root_dir=tmpdir, - limit_train_batches=1, - limit_val_batches=0, - max_epochs=1, - ) - tuner = Tuner(trainer) - new_batch_size = tuner.scale_batch_size( - model=model, mode="binsearch", init_val=4, max_trials=2, datamodule=datamodule - ) - assert new_batch_size == 16 - if hasattr(model, "batch_size"): - assert model.batch_size == 16 - if datamodule is not None and hasattr(datamodule, "batch_size"): - assert datamodule.batch_size == 16 diff --git a/tests/utilities/test_all_gather_grad.py b/tests/utilities/test_all_gather_grad.py index d67c9473bbb2e..259f9f4c09871 100644 --- a/tests/utilities/test_all_gather_grad.py +++ b/tests/utilities/test_all_gather_grad.py @@ -95,26 +95,3 @@ def training_epoch_end(self, outputs) -> None: trainer.fit(model) assert model.training_epoch_end_called - - -@RunIf(min_gpus=2, skip_windows=True, special=True) -def test_all_gather_sync_grads(tmpdir): - - class TestModel(BoringModel): - - training_step_called = False - - def training_step(self, batch, batch_idx): - self.training_step_called = True - tensor = torch.rand(2, 2, requires_grad=True, device=self.device) - gathered_tensor = self.all_gather(tensor, sync_grads=True) - assert gathered_tensor.shape == torch.Size([2, 2, 2]) - - loss = gathered_tensor.sum() - - return loss - - model = TestModel() - trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, gpus=2) - trainer.fit(model) - assert model.training_step_called diff --git a/tests/utilities/test_argparse.py b/tests/utilities/test_argparse_utils.py similarity index 80% rename from tests/utilities/test_argparse.py rename to tests/utilities/test_argparse_utils.py index f13af4362364c..b2eac514941e6 100644 --- a/tests/utilities/test_argparse.py +++ b/tests/utilities/test_argparse_utils.py @@ -1,52 +1,17 @@ import io -from argparse import ArgumentParser, Namespace +from argparse import ArgumentParser from typing import List -from unittest.mock import MagicMock import pytest from pytorch_lightning import Trainer from pytorch_lightning.utilities.argparse import ( - _gpus_arg_default, - _int_or_float_type, add_argparse_args, - from_argparse_args, get_abbrev_qualified_cls_name, - parse_argparser, parse_args_from_docstring, ) -class ArgparseExample: - - def __init__(self, a: int = 0, b: str = '', c: bool = False): - self.a = a - self.b = b - self.c = c - - -def test_from_argparse_args(): - args = Namespace(a=1, b='test', c=True, d='not valid') - my_instance = from_argparse_args(ArgparseExample, 
args) - assert my_instance.a == 1 - assert my_instance.b == 'test' - assert my_instance.c - - parser = ArgumentParser() - mock_trainer = MagicMock() - _ = from_argparse_args(mock_trainer, parser) - mock_trainer.parse_argparser.assert_called_once_with(parser) - - -def test_parse_argparser(): - args = Namespace(a=1, b='test', c=None, d='not valid') - new_args = parse_argparser(ArgparseExample, args) - assert new_args.a == 1 - assert new_args.b == 'test' - assert new_args.c - assert new_args.d == 'not valid' - - def test_parse_args_from_docstring_normal(): args_help = parse_args_from_docstring( """Constrain image dataset @@ -203,13 +168,3 @@ def test_add_argparse_args_no_argument_group(): args = parser.parse_args(fake_argv) assert args.main_arg == "abc" assert args.my_parameter == 2 - - -def test_gpus_arg_default(): - assert _gpus_arg_default('1,2') == '1,2' - assert _gpus_arg_default('1') == 1 - - -def test_int_or_float_type(): - assert isinstance(_int_or_float_type('0.0'), float) - assert isinstance(_int_or_float_type('0'), int) From f545b94d638c646c9da1283ee002c79684ab0c5c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lucas-Rapahel=20Mu=CC=88ller?= Date: Fri, 26 Mar 2021 12:38:36 +0100 Subject: [PATCH 12/25] Rm old fname check that is obsolete due to tempfile --- tests/loggers/test_mlflow.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/loggers/test_mlflow.py b/tests/loggers/test_mlflow.py index cb461fe4ef387..0c49a305583a7 100644 --- a/tests/loggers/test_mlflow.py +++ b/tests/loggers/test_mlflow.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import os -from pathlib import Path from unittest import mock from unittest.mock import MagicMock @@ -244,7 +243,6 @@ def test_mlflow_log_figure(client, mlflow, step_idx, figure_format, tmpdir): f = plotting.dummy_figure() logger.log_figure('dummy', f, step_idx, close=True) - fname_expect = logger.save_dir + f'/dummy_step_{step_idx}{figure_format}' artifact_expect = 'figure_dummy' mock_log.assert_called_once() From a1a3eedabc2903ec97b7e1222c769f3211194fb4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lucas-Rapahel=20Mu=CC=88ller?= Date: Fri, 9 Apr 2021 10:36:17 +0200 Subject: [PATCH 13/25] Clear spec of supported loggers. --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9a259bb8881e2..99ac98638975a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,7 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added -- Added unified API for figure logging `log_figure` in {insert all supported logger names here}. +- Added unified API for figure logging `log_figure` in Tensorboard, Comet, ML Flow, Neptune, Weights and Biases. ## [1.3.0] - 2021-MM-DD From 1559b78bdfd60f4b8f68048bdcb70854e935ee4a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lucas-Rapahel=20Mu=CC=88ller?= Date: Fri, 9 Apr 2021 14:01:04 +0200 Subject: [PATCH 14/25] Style --- pytorch_lightning/loggers/base.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/loggers/base.py b/pytorch_lightning/loggers/base.py index bd7fbb7805194..9718f5391d90f 100644 --- a/pytorch_lightning/loggers/base.py +++ b/pytorch_lightning/loggers/base.py @@ -177,14 +177,15 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None): def log_figure(self, name: str, figure: plt.figure, step: Optional[int] = None, close: bool = True) -> None: """ Logs a matplotlib figure. 
+ Args: name: name of the figure figure: plt figure handle step: step number at which the figure should be recorded close: close figure after logging """ - """Default is silent. - Not raising NotImplemented because one could have multiple logger where only some support log_figure.""" + # Default is silent and not NotImplementedError because we want to support LoggerCollection + # where some loggers might others might not have implemented this method. if close: plt.close(figure) From 653bd369bdb24ad6acc037a2fdb3147fd027c0f0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lucas-Rapahel=20Mu=CC=88ller?= Date: Fri, 23 Apr 2021 10:32:26 +0200 Subject: [PATCH 15/25] Add matplotlib is available checks. Refactor close_figure all test. --- pytorch_lightning/loggers/base.py | 9 ++++++--- tests/helpers/plotting.py | 13 +++++++++++-- tests/loggers/test_all.py | 23 +++++++++++++++++++---- tests/loggers/test_comet.py | 5 +++++ tests/loggers/test_mlflow.py | 4 ++++ tests/loggers/test_neptune.py | 4 ++++ tests/loggers/test_tensorboard.py | 4 ++++ tests/loggers/test_wandb.py | 28 ++++++++++++++++++++++++++-- 8 files changed, 79 insertions(+), 11 deletions(-) diff --git a/pytorch_lightning/loggers/base.py b/pytorch_lightning/loggers/base.py index 9718f5391d90f..927c7b48290fd 100644 --- a/pytorch_lightning/loggers/base.py +++ b/pytorch_lightning/loggers/base.py @@ -21,7 +21,6 @@ from functools import wraps from typing import Any, Callable, Dict, Iterable, List, Mapping, MutableMapping, Optional, Sequence, Tuple, Union -import matplotlib.pyplot as plt import numpy as np import torch @@ -174,7 +173,7 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None): """ pass - def log_figure(self, name: str, figure: plt.figure, step: Optional[int] = None, close: bool = True) -> None: + def log_figure(self, name: str, figure, step: Optional[int] = None, close: bool = True) -> None: """ Logs a matplotlib figure. @@ -184,6 +183,8 @@ def log_figure(self, name: str, figure: plt.figure, step: Optional[int] = None, step: step number at which the figure should be recorded close: close figure after logging """ + import matplotlib.pyplot as plt + # Default is silent and not NotImplementedError because we want to support LoggerCollection # where some loggers might others might not have implemented this method. 
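As a usage sketch of the default-and-override behaviour just described (the choice of loggers and the plotted data are illustrative, not part of the patch): with several loggers attached, each one receives the call, backends without their own implementation fall back to the silent default, and the figure is closed once at the end.

import matplotlib.pyplot as plt

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger

tb_logger = TensorBoardLogger("logs/")    # provides a backend-specific log_figure
csv_logger = CSVLogger("logs/")           # relies on the silent default shown above

trainer = Trainer(logger=[tb_logger, csv_logger])   # the loggers are wrapped in a LoggerCollection

f = plt.figure()
plt.plot([0.0, 0.5, 1.0], [0.0, 0.25, 1.0])
trainer.logger.log_figure("calibration", f, step=0, close=True)   # closed once, after all loggers ran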
if close: @@ -393,7 +394,9 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> for logger in self._logger_iterable: logger.log_metrics(metrics, step) - def log_figure(self, name: str, figure: plt.figure, step: Optional[int] = None, close: bool = True) -> None: + def log_figure(self, name: str, figure, step: Optional[int] = None, close: bool = True) -> None: + import matplotlib.pyplot as plt + for logger in self._logger_iterable: # don't close in the individual loggers, but once at the end logger.log_figure(name, figure, step=step, close=False) diff --git a/tests/helpers/plotting.py b/tests/helpers/plotting.py index f3910240657c5..b7c5e88a68bb4 100644 --- a/tests/helpers/plotting.py +++ b/tests/helpers/plotting.py @@ -1,9 +1,18 @@ import numpy as np -from matplotlib import pyplot as plt +from pytorch_lightning.utilities import _module_available -def dummy_figure(): +_MATPLOTLIB_AVAILABLE = _module_available("matplotlib") +if _MATPLOTLIB_AVAILABLE: + import matplotlib.pyplot as plt +else: + class plt: + figure = None + + +def dummy_figure() -> plt.figure: """Dummy figure to test logging capability of figures for loggers.""" + f = plt.figure() plt.plot(np.linspace(0., 1., 100), np.linspace(0., 10., 100) ** 2) diff --git a/tests/loggers/test_all.py b/tests/loggers/test_all.py index eb5f12f48324e..2028fbb68034d 100644 --- a/tests/loggers/test_all.py +++ b/tests/loggers/test_all.py @@ -33,6 +33,7 @@ ) from pytorch_lightning.loggers.base import DummyExperiment from pytorch_lightning.trainer.states import TrainerState +from pytorch_lightning.utilities import _module_available from tests.helpers import BoringModel, plotting from tests.helpers.runif import RunIf from tests.loggers.test_comet import _patch_comet_atexit @@ -408,6 +409,9 @@ def test_logger_with_prefix_all(tmpdir, monkeypatch): logger.experiment.log.assert_called_once_with({'tmp-test': 1.0}, step=0) +@pytest.mark.skipif( + not _module_available("matplotlib"), + reason="close figure test requires matplotlib to be installed.") @pytest.mark.parametrize("close", [True, False]) @pytest.mark.parametrize("logger_class", [ CometLogger, @@ -415,12 +419,23 @@ def test_logger_with_prefix_all(tmpdir, monkeypatch): MLFlowLogger, NeptuneLogger, TensorBoardLogger, - WandbLogger, + # Wandb has its own close_figure test ]) -def test_logger_close_figure_all(logger_class, close, tmpdir): - f = plotting.dummy_figure() +def test_logger_close_figure_all(tmpdir, monkeypatch, logger_class, close): + _patch_comet_atexit(monkeypatch) + try: + _test_logger_close_figure(tmpdir, monkeypatch, logger_class, close) + except (ImportError, ModuleNotFoundError): + pytest.xfail(f"pickle test requires {logger_class.__class__} dependencies to be installed.") + - logger = _instantiate_logger(logger_class, save_idr=tmpdir) +def _test_logger_close_figure(tmpdir, monkeypatch, logger_class, close): + _patch_comet_atexit(monkeypatch) + + logger_args = _get_logger_args(logger_class, tmpdir) + logger = logger_class(**logger_args) + + f = plotting.dummy_figure() with mock.patch('matplotlib.pyplot.close') as plt_close: logger.log_figure('dummy', f, 0, close=close) diff --git a/tests/loggers/test_comet.py b/tests/loggers/test_comet.py index eee7eddea7ab7..f1862aef138ab 100644 --- a/tests/loggers/test_comet.py +++ b/tests/loggers/test_comet.py @@ -18,6 +18,7 @@ from pytorch_lightning import Trainer from pytorch_lightning.loggers import CometLogger +from pytorch_lightning.utilities import _module_available from pytorch_lightning.utilities.exceptions import 
MisconfigurationException from tests.helpers import BoringModel, plotting @@ -224,12 +225,16 @@ def test_comet_epoch_logging(comet, comet_experiment, tmpdir, monkeypatch): logger.experiment.log_metrics.assert_called_once_with({"test": 1}, epoch=1, step=123) +@pytest.mark.skipif( + not _module_available("matplotlib"), + reason="close figure test requires matplotlib to be installed.") @patch("pytorch_lightning.loggers.comet.CometExperiment") @patch('pytorch_lightning.loggers.comet.comet_ml') @pytest.mark.parametrize("step_idx", [10, None]) def test_comet_log_figure(comet, comet_experiment, tmpdir, monkeypatch, step_idx): _patch_comet_atexit(monkeypatch) + logger = CometLogger(project_name="test", save_dir=tmpdir) logger.log_figure('dummy', plotting.dummy_figure(), step_idx, close=True) # functional test diff --git a/tests/loggers/test_mlflow.py b/tests/loggers/test_mlflow.py index 0c49a305583a7..f7a165fa59194 100644 --- a/tests/loggers/test_mlflow.py +++ b/tests/loggers/test_mlflow.py @@ -19,6 +19,7 @@ from pytorch_lightning import Trainer from pytorch_lightning.loggers import _MLFLOW_AVAILABLE, MLFlowLogger +from pytorch_lightning.utilities import _module_available from tests.helpers import BoringModel, plotting @@ -229,6 +230,9 @@ def test_mlflow_logger_experiment_calls(client, mlflow, time, tmpdir): ) +@pytest.mark.skipif( + not _module_available("matplotlib"), + reason="close figure test requires matplotlib to be installed.") @mock.patch('pytorch_lightning.loggers.mlflow.mlflow') @mock.patch('pytorch_lightning.loggers.mlflow.MlflowClient') @pytest.mark.parametrize("step_idx", [10, None]) diff --git a/tests/loggers/test_neptune.py b/tests/loggers/test_neptune.py index 4de0d1c93d34e..de10e45ab903f 100644 --- a/tests/loggers/test_neptune.py +++ b/tests/loggers/test_neptune.py @@ -18,6 +18,7 @@ from pytorch_lightning import Trainer from pytorch_lightning.loggers import NeptuneLogger +from pytorch_lightning.utilities import _module_available from tests.helpers import BoringModel, plotting @@ -127,6 +128,9 @@ def _run_training(logger): assert logger_open_after_fit._experiment.stop.call_count == 0 +@pytest.mark.skipif( + not _module_available("matplotlib"), + reason="close figure test requires matplotlib to be installed.") @patch('pytorch_lightning.loggers.neptune.neptune') @pytest.mark.parametrize("step_idx", [10, None]) def test_neptune_log_figure(neptune, step_idx): diff --git a/tests/loggers/test_tensorboard.py b/tests/loggers/test_tensorboard.py index b780b059b92f4..a586eae9d18c3 100644 --- a/tests/loggers/test_tensorboard.py +++ b/tests/loggers/test_tensorboard.py @@ -24,6 +24,7 @@ from pytorch_lightning import Trainer from pytorch_lightning.loggers import TensorBoardLogger +from pytorch_lightning.utilities import _module_available from tests.helpers import BoringModel, plotting from tests.helpers.runif import RunIf @@ -131,6 +132,9 @@ def test_tensorboard_log_metrics(tmpdir, step_idx): logger.log_metrics(metrics, step_idx) +@pytest.mark.skipif( + not _module_available("matplotlib"), + reason="close figure test requires matplotlib to be installed.") @pytest.mark.parametrize("step_idx", [10, None]) def test_tensorboard_log_figure(tmpdir, step_idx): logger = TensorBoardLogger(tmpdir) diff --git a/tests/loggers/test_wandb.py b/tests/loggers/test_wandb.py index d114846e9195f..c3e20e604243b 100644 --- a/tests/loggers/test_wandb.py +++ b/tests/loggers/test_wandb.py @@ -21,6 +21,7 @@ from pytorch_lightning import Trainer from pytorch_lightning.loggers import WandbLogger +from 
pytorch_lightning.utilities import _module_available from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers import BoringModel, plotting @@ -189,10 +190,14 @@ def test_wandb_logger_offline_log_model(wandb, tmpdir): _ = WandbLogger(save_dir=str(tmpdir), offline=True, log_model=True) +@pytest.mark.skipif( + not _module_available("matplotlib"), + reason="close figure test requires matplotlib to be installed.") @mock.patch('pytorch_lightning.loggers.wandb.wandb') @pytest.mark.parametrize("step_idx", [10, None]) -def test_wandb_logger_log_figure(wandb, step_idx): - logger = WandbLogger(anonymous=True, offline=True) +def test_wandb_logger_log_figure(wandb, tmpdir, step_idx): + + logger = WandbLogger(save_dir=str(tmpdir), offline=True) logger.log_figure('dummy', plotting.dummy_figure(), step_idx, close=True) # functional test with mock.patch.object(logger.experiment, 'log') as mock_log: @@ -200,3 +205,22 @@ def test_wandb_logger_log_figure(wandb, step_idx): logger.log_figure('dummy', f, step_idx, close=True) mock_log.assert_called_once_with({'dummy': wandb.Image(f)}, step=step_idx) + + +@pytest.mark.skipif( + not _module_available("matplotlib"), + reason="close figure test requires matplotlib to be installed.") +@pytest.mark.parametrize("close", [True, False]) +@mock.patch('pytorch_lightning.loggers.wandb.wandb') +def test_close_figure(wandb, tmpdir, close): + f = plotting.dummy_figure() + + logger = WandbLogger(save_dir=str(tmpdir), offline=True) + + with mock.patch('matplotlib.pyplot.close') as plt_close: + logger.log_figure('dummy', f, 0, close=close) + + if close: + plt_close.assert_called_once_with(f) + else: + plt_close.assert_not_called() \ No newline at end of file From 855efd104d01bd54a2795039d91ae7122fd0f738 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lucas-Rapahel=20Mu=CC=88ller?= Date: Fri, 23 Apr 2021 11:23:20 +0200 Subject: [PATCH 16/25] Add mocked types when matplotlib is not available. Style. --- docs/source/extensions/logging.rst | 3 ++- pytorch_lightning/loggers/base.py | 12 ++++++++-- pytorch_lightning/loggers/comet.py | 11 +++++++-- pytorch_lightning/loggers/mlflow.py | 11 +++++++-- pytorch_lightning/loggers/neptune.py | 10 ++++++++- pytorch_lightning/loggers/tensorboard.py | 13 +++++++++-- pytorch_lightning/loggers/wandb.py | 10 ++++++++- pytorch_lightning/utilities/mock_types.py | 27 +++++++++++++++++++++++ tests/loggers/test_wandb.py | 2 +- 9 files changed, 87 insertions(+), 12 deletions(-) create mode 100644 pytorch_lightning/utilities/mock_types.py diff --git a/docs/source/extensions/logging.rst b/docs/source/extensions/logging.rst index 9ad17b5fd1821..239874d29b754 100644 --- a/docs/source/extensions/logging.rst +++ b/docs/source/extensions/logging.rst @@ -320,9 +320,10 @@ Logging Figures *************** When training a model, often it is very indicative to log figures, e.g. of the in- and output. -For standard ``matplotlib.pyplot`` figures, Lightning has a unified API that works with most of the implemented loggers. +For standard ``matplotlib.pyplot`` figures, Lightning has a unified API that works with Tensorboard, Comet, ML Flow, Neptune, Weights and Biases. .. 
code-block:: python + f = plt.figure() logger.log_figure(name='dummy_figure', figure=f, step=0, close=True) diff --git a/pytorch_lightning/loggers/base.py b/pytorch_lightning/loggers/base.py index 927c7b48290fd..ee77d0ea1cecf 100644 --- a/pytorch_lightning/loggers/base.py +++ b/pytorch_lightning/loggers/base.py @@ -25,7 +25,15 @@ import torch from pytorch_lightning.core.lightning import LightningModule -from pytorch_lightning.utilities import rank_zero_only +from pytorch_lightning.utilities import _module_available, rank_zero_only + +_MATPLOTLIB_AVAILABLE = _module_available("matplotlib") + +if _MATPLOTLIB_AVAILABLE: + import matplotlib.pyplot as plt +else: + from pytorch_lightning.utilities.mock_types import matplotlib as _matplotlib + plt = _matplotlib.pyplot def rank_zero_experiment(fn: Callable) -> Callable: @@ -173,7 +181,7 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None): """ pass - def log_figure(self, name: str, figure, step: Optional[int] = None, close: bool = True) -> None: + def log_figure(self, name: str, figure: plt.figure, step: Optional[int] = None, close: bool = True) -> None: """ Logs a matplotlib figure. diff --git a/pytorch_lightning/loggers/comet.py b/pytorch_lightning/loggers/comet.py index 1290d100b2241..ee343ee74682f 100644 --- a/pytorch_lightning/loggers/comet.py +++ b/pytorch_lightning/loggers/comet.py @@ -49,6 +49,14 @@ CometExperiment, CometExistingExperiment, CometOfflineExperiment = None, None, None API = None +_MATPLOTLIB_AVAILABLE = _module_available("matplotlib") + +if _MATPLOTLIB_AVAILABLE: + import matplotlib.pyplot as plt +else: + from pytorch_lightning.utilities.mock_types import matplotlib as _matplotlib + plt = _matplotlib.pyplot + class CometLogger(LightningLoggerBase): r""" @@ -253,8 +261,7 @@ def log_metrics(self, metrics: Dict[str, Union[torch.Tensor, float]], step: Opti self.experiment.log_metrics(metrics_without_epoch, step=step, epoch=epoch) @rank_zero_only - def log_figure(self, name: str, figure, step: Optional[int] = None, close: bool = True) -> None: - import matplotlib.pyplot as plt + def log_figure(self, name: str, figure: plt.figure, step: Optional[int] = None, close: bool = True) -> None: self.experiment.log_figure(figure_name=name, figure=figure, step=step) if close: diff --git a/pytorch_lightning/loggers/mlflow.py b/pytorch_lightning/loggers/mlflow.py index 69592ad1ff5ff..95ec81af015bf 100644 --- a/pytorch_lightning/loggers/mlflow.py +++ b/pytorch_lightning/loggers/mlflow.py @@ -38,6 +38,14 @@ _MLFLOW_AVAILABLE = False mlflow, MlflowClient, context = None, None, None +_MATPLOTLIB_AVAILABLE = _module_available("matplotlib") + +if _MATPLOTLIB_AVAILABLE: + import matplotlib.pyplot as plt +else: + from pytorch_lightning.utilities.mock_types import matplotlib as _matplotlib + plt = _matplotlib.pyplot + # before v1.1.0 if hasattr(context, 'resolve_tags'): @@ -210,8 +218,7 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> self.experiment.log_metric(self.run_id, k, v, timestamp_ms, step) @rank_zero_only - def log_figure(self, name: str, figure, step: Optional[int] = None, close: bool = True) -> None: - import matplotlib.pyplot as plt + def log_figure(self, name: str, figure: plt.figure, step: Optional[int] = None, close: bool = True) -> None: with tempfile.NamedTemporaryFile(suffix=self._figure_file_extension) as tmp_file: figure.savefig(tmp_file) diff --git a/pytorch_lightning/loggers/neptune.py b/pytorch_lightning/loggers/neptune.py index 2169dd175e621..f0557569e1786 100644 --- 
a/pytorch_lightning/loggers/neptune.py +++ b/pytorch_lightning/loggers/neptune.py @@ -35,6 +35,14 @@ # needed for test mocks, these tests shall be updated neptune, Experiment = None, None +_MATPLOTLIB_AVAILABLE = _module_available("matplotlib") + +if _MATPLOTLIB_AVAILABLE: + import matplotlib.pyplot as plt +else: + from pytorch_lightning.utilities.mock_types import matplotlib as _matplotlib + plt = _matplotlib.pyplot + class NeptuneLogger(LightningLoggerBase): r""" @@ -264,7 +272,7 @@ def log_metrics(self, metrics: Dict[str, Union[torch.Tensor, float]], step: Opti self.log_metric(key, val) @rank_zero_only - def log_figure(self, name: str, figure, step: Optional[int] = None, close: bool = True) -> None: + def log_figure(self, name: str, figure: plt.figure, step: Optional[int] = None, close: bool = True) -> None: import matplotlib.pyplot as plt description = f"step_{step}" if step is not None else None diff --git a/pytorch_lightning/loggers/tensorboard.py b/pytorch_lightning/loggers/tensorboard.py index 8fe6cbfa9ef7d..1a420705a9c41 100644 --- a/pytorch_lightning/loggers/tensorboard.py +++ b/pytorch_lightning/loggers/tensorboard.py @@ -28,7 +28,8 @@ from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.core.saving import save_hparams_to_yaml from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment -from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, rank_zero_only, rank_zero_warn +from pytorch_lightning.utilities import \ + _OMEGACONF_AVAILABLE, _module_available, rank_zero_only, rank_zero_warn from pytorch_lightning.utilities.cloud_io import get_filesystem log = logging.getLogger(__name__) @@ -36,6 +37,14 @@ if _OMEGACONF_AVAILABLE: from omegaconf import Container, OmegaConf +_MATPLOTLIB_AVAILABLE = _module_available("matplotlib") + +if _MATPLOTLIB_AVAILABLE: + import matplotlib.pyplot as plt +else: + from pytorch_lightning.utilities.mock_types import matplotlib as _matplotlib + plt = _matplotlib.pyplot + class TensorBoardLogger(LightningLoggerBase): r""" @@ -210,7 +219,7 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> raise ValueError(m) from ex @rank_zero_only - def log_figure(self, name: str, figure, step: Optional[int] = None, close: bool = True) -> None: + def log_figure(self, name: str, figure: plt.figure, step: Optional[int] = None, close: bool = True) -> None: self.experiment.add_figure(tag=name, figure=figure, global_step=step, close=close) @rank_zero_only diff --git a/pytorch_lightning/loggers/wandb.py b/pytorch_lightning/loggers/wandb.py index 25273bc92c736..1b5ccd2f22051 100644 --- a/pytorch_lightning/loggers/wandb.py +++ b/pytorch_lightning/loggers/wandb.py @@ -37,6 +37,14 @@ # needed for test mocks, these tests shall be updated wandb, Run = None, None +_MATPLOTLIB_AVAILABLE = _module_available("matplotlib") + +if _MATPLOTLIB_AVAILABLE: + import matplotlib.pyplot as plt +else: + from pytorch_lightning.utilities.mock_types import matplotlib as _matplotlib + plt = _matplotlib.pyplot + class WandbLogger(LightningLoggerBase): r""" @@ -200,7 +208,7 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> self.experiment.log(metrics) @rank_zero_only - def log_figure(self, name: str, figure, step: Optional[int] = None, close: bool = True) -> None: + def log_figure(self, name: str, figure: plt.figure, step: Optional[int] = None, close: bool = True) -> None: import matplotlib.pyplot as plt self.experiment.log({name: wandb.Image(figure)}, step=step) diff --git 
a/pytorch_lightning/utilities/mock_types.py b/pytorch_lightning/utilities/mock_types.py new file mode 100644 index 0000000000000..d0ba0c880e89e --- /dev/null +++ b/pytorch_lightning/utilities/mock_types.py @@ -0,0 +1,27 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Mock typing for non-required packages. They can be imported as normal. + +Examples: + >>> from mock_types import package + # simulate `import package.submodule as short` + >>> plt = __import__("mock_types").matplotlib.pyplot + +""" + +class matplotlib: + class pyplot: + figure = None + close = None diff --git a/tests/loggers/test_wandb.py b/tests/loggers/test_wandb.py index c3e20e604243b..6384a06202132 100644 --- a/tests/loggers/test_wandb.py +++ b/tests/loggers/test_wandb.py @@ -223,4 +223,4 @@ def test_close_figure(wandb, tmpdir, close): if close: plt_close.assert_called_once_with(f) else: - plt_close.assert_not_called() \ No newline at end of file + plt_close.assert_not_called() From 76bd3b2b63b79a14cf0cc2354cf6f014574bdfa4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lucas-Rapahel=20Mu=CC=88ller?= Date: Fri, 23 Apr 2021 15:29:15 +0200 Subject: [PATCH 17/25] PEP --- pytorch_lightning/utilities/mock_types.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytorch_lightning/utilities/mock_types.py b/pytorch_lightning/utilities/mock_types.py index d0ba0c880e89e..9ae2019be740a 100644 --- a/pytorch_lightning/utilities/mock_types.py +++ b/pytorch_lightning/utilities/mock_types.py @@ -21,6 +21,7 @@ """ + class matplotlib: class pyplot: figure = None From 87150dcd211f5f10612db39f1ee733715994774a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lucas-Rapahel=20Mu=CC=88ller?= Date: Fri, 7 May 2021 09:28:51 +0200 Subject: [PATCH 18/25] Rm deprecated example section --- pytorch_lightning/__about__.py | 2 +- pytorch_lightning/utilities/mock_types.py | 9 --------- 2 files changed, 1 insertion(+), 10 deletions(-) diff --git a/pytorch_lightning/__about__.py b/pytorch_lightning/__about__.py index 0ce2273febf00..05a81e9184399 100644 --- a/pytorch_lightning/__about__.py +++ b/pytorch_lightning/__about__.py @@ -1,7 +1,7 @@ import time _this_year = time.strftime("%Y") -__version__ = '1.3.0rc2' +__version__ = "20210429" __author__ = 'William Falcon et al.' __author_email__ = 'waf2107@columbia.edu' __license__ = 'Apache-2.0' diff --git a/pytorch_lightning/utilities/mock_types.py b/pytorch_lightning/utilities/mock_types.py index 9ae2019be740a..14f9be9631f61 100644 --- a/pytorch_lightning/utilities/mock_types.py +++ b/pytorch_lightning/utilities/mock_types.py @@ -11,15 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" -Mock typing for non-required packages. They can be imported as normal. 
- -Examples: - >>> from mock_types import package - # simulate `import package.submodule as short` - >>> plt = __import__("mock_types").matplotlib.pyplot - -""" class matplotlib: From a4175f26edf132ee2e4b84c4479fe2c9089c9200 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lucas-Rapahel=20Mu=CC=88ller?= Date: Fri, 21 May 2021 09:54:35 +0200 Subject: [PATCH 19/25] Rm unnecessary imports --- pytorch_lightning/loggers/base.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/pytorch_lightning/loggers/base.py b/pytorch_lightning/loggers/base.py index ee77d0ea1cecf..5cde5e29dbf4d 100644 --- a/pytorch_lightning/loggers/base.py +++ b/pytorch_lightning/loggers/base.py @@ -191,8 +191,6 @@ def log_figure(self, name: str, figure: plt.figure, step: Optional[int] = None, step: step number at which the figure should be recorded close: close figure after logging """ - import matplotlib.pyplot as plt - # Default is silent and not NotImplementedError because we want to support LoggerCollection # where some loggers might others might not have implemented this method. if close: @@ -403,7 +401,6 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> logger.log_metrics(metrics, step) def log_figure(self, name: str, figure, step: Optional[int] = None, close: bool = True) -> None: - import matplotlib.pyplot as plt for logger in self._logger_iterable: # don't close in the individual loggers, but once at the end From 812d3c134c06c254f91c7a56d953c8669acc88c2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 21 May 2021 07:57:33 +0000 Subject: [PATCH 20/25] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/loggers/mlflow.py | 5 +---- pytorch_lightning/loggers/tensorboard.py | 3 +-- pytorch_lightning/utilities/mock_types.py | 1 + tests/helpers/plotting.py | 3 ++- tests/loggers/test_all.py | 23 +++++++++++++---------- tests/loggers/test_comet.py | 4 ++-- tests/loggers/test_mlflow.py | 4 ++-- tests/loggers/test_neptune.py | 7 +++---- tests/loggers/test_tensorboard.py | 4 ++-- tests/loggers/test_wandb.py | 8 ++++---- 10 files changed, 31 insertions(+), 31 deletions(-) diff --git a/pytorch_lightning/loggers/mlflow.py b/pytorch_lightning/loggers/mlflow.py index e81aa06904dcd..0ddfe7e8762fb 100644 --- a/pytorch_lightning/loggers/mlflow.py +++ b/pytorch_lightning/loggers/mlflow.py @@ -46,7 +46,6 @@ from pytorch_lightning.utilities.mock_types import matplotlib as _matplotlib plt = _matplotlib.pyplot - # before v1.1.0 if hasattr(context, 'resolve_tags'): from mlflow.tracking.context import resolve_tags @@ -226,9 +225,7 @@ def log_figure(self, name: str, figure: plt.figure, step: Optional[int] = None, with tempfile.NamedTemporaryFile(suffix=self._figure_file_extension) as tmp_file: figure.savefig(tmp_file) self.experiment.log_artifact( - self.run_id, - tmp_file.name, - artifact_path=Path(self.save_dir) / ("figure_" + name) + self.run_id, tmp_file.name, artifact_path=Path(self.save_dir) / ("figure_" + name) ) if close: diff --git a/pytorch_lightning/loggers/tensorboard.py b/pytorch_lightning/loggers/tensorboard.py index aab15e5bdc814..85b935f8f6582 100644 --- a/pytorch_lightning/loggers/tensorboard.py +++ b/pytorch_lightning/loggers/tensorboard.py @@ -28,8 +28,7 @@ from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.core.saving import save_hparams_to_yaml from pytorch_lightning.loggers.base import LightningLoggerBase, 
rank_zero_experiment -from pytorch_lightning.utilities import \ - _OMEGACONF_AVAILABLE, _module_available, rank_zero_only, rank_zero_warn +from pytorch_lightning.utilities import _module_available, _OMEGACONF_AVAILABLE, rank_zero_only, rank_zero_warn from pytorch_lightning.utilities.cloud_io import get_filesystem log = logging.getLogger(__name__) diff --git a/pytorch_lightning/utilities/mock_types.py b/pytorch_lightning/utilities/mock_types.py index 14f9be9631f61..7b6e326c353a2 100644 --- a/pytorch_lightning/utilities/mock_types.py +++ b/pytorch_lightning/utilities/mock_types.py @@ -14,6 +14,7 @@ class matplotlib: + class pyplot: figure = None close = None diff --git a/tests/helpers/plotting.py b/tests/helpers/plotting.py index b7c5e88a68bb4..ef9e1b7559dca 100644 --- a/tests/helpers/plotting.py +++ b/tests/helpers/plotting.py @@ -6,6 +6,7 @@ if _MATPLOTLIB_AVAILABLE: import matplotlib.pyplot as plt else: + class plt: figure = None @@ -14,6 +15,6 @@ def dummy_figure() -> plt.figure: """Dummy figure to test logging capability of figures for loggers.""" f = plt.figure() - plt.plot(np.linspace(0., 1., 100), np.linspace(0., 10., 100) ** 2) + plt.plot(np.linspace(0., 1., 100), np.linspace(0., 10., 100)**2) return f diff --git a/tests/loggers/test_all.py b/tests/loggers/test_all.py index 970a7b2ab1f38..a9692485efd4c 100644 --- a/tests/loggers/test_all.py +++ b/tests/loggers/test_all.py @@ -409,17 +409,20 @@ def test_logger_with_prefix_all(tmpdir, monkeypatch): @pytest.mark.skipif( - not _module_available("matplotlib"), - reason="close figure test requires matplotlib to be installed.") + not _module_available("matplotlib"), reason="close figure test requires matplotlib to be installed." +) @pytest.mark.parametrize("close", [True, False]) -@pytest.mark.parametrize("logger_class", [ - CometLogger, - CSVLogger, - MLFlowLogger, - NeptuneLogger, - TensorBoardLogger, - # Wandb has its own close_figure test -]) +@pytest.mark.parametrize( + "logger_class", + [ + CometLogger, + CSVLogger, + MLFlowLogger, + NeptuneLogger, + TensorBoardLogger, + # Wandb has its own close_figure test + ] +) def test_logger_close_figure_all(tmpdir, monkeypatch, logger_class, close): _patch_comet_atexit(monkeypatch) try: diff --git a/tests/loggers/test_comet.py b/tests/loggers/test_comet.py index f1862aef138ab..79dee595005fe 100644 --- a/tests/loggers/test_comet.py +++ b/tests/loggers/test_comet.py @@ -226,8 +226,8 @@ def test_comet_epoch_logging(comet, comet_experiment, tmpdir, monkeypatch): @pytest.mark.skipif( - not _module_available("matplotlib"), - reason="close figure test requires matplotlib to be installed.") + not _module_available("matplotlib"), reason="close figure test requires matplotlib to be installed." +) @patch("pytorch_lightning.loggers.comet.CometExperiment") @patch('pytorch_lightning.loggers.comet.comet_ml') @pytest.mark.parametrize("step_idx", [10, None]) diff --git a/tests/loggers/test_mlflow.py b/tests/loggers/test_mlflow.py index f7a165fa59194..2152d94ea8d0a 100644 --- a/tests/loggers/test_mlflow.py +++ b/tests/loggers/test_mlflow.py @@ -231,8 +231,8 @@ def test_mlflow_logger_experiment_calls(client, mlflow, time, tmpdir): @pytest.mark.skipif( - not _module_available("matplotlib"), - reason="close figure test requires matplotlib to be installed.") + not _module_available("matplotlib"), reason="close figure test requires matplotlib to be installed." 
+) @mock.patch('pytorch_lightning.loggers.mlflow.mlflow') @mock.patch('pytorch_lightning.loggers.mlflow.MlflowClient') @pytest.mark.parametrize("step_idx", [10, None]) diff --git a/tests/loggers/test_neptune.py b/tests/loggers/test_neptune.py index 6fd514a2ac23b..5ef831d35a3bb 100644 --- a/tests/loggers/test_neptune.py +++ b/tests/loggers/test_neptune.py @@ -129,8 +129,8 @@ def _run_training(logger): @pytest.mark.skipif( - not _module_available("matplotlib"), - reason="close figure test requires matplotlib to be installed.") + not _module_available("matplotlib"), reason="close figure test requires matplotlib to be installed." +) @patch('pytorch_lightning.loggers.neptune.neptune') @pytest.mark.parametrize("step_idx", [10, None]) def test_neptune_log_figure(neptune, step_idx): @@ -141,5 +141,4 @@ def test_neptune_log_figure(neptune, step_idx): f = plotting.dummy_figure() logger.log_figure('dummy', f, step_idx, close=True) - mock_log.assert_called_once_with('dummy', f, - description=f"step_{step_idx}" if step_idx is not None else None) + mock_log.assert_called_once_with('dummy', f, description=f"step_{step_idx}" if step_idx is not None else None) diff --git a/tests/loggers/test_tensorboard.py b/tests/loggers/test_tensorboard.py index bd4662517eaf2..88990dfab27d8 100644 --- a/tests/loggers/test_tensorboard.py +++ b/tests/loggers/test_tensorboard.py @@ -178,8 +178,8 @@ def test_tensorboard_log_metrics(tmpdir, step_idx): @pytest.mark.skipif( - not _module_available("matplotlib"), - reason="close figure test requires matplotlib to be installed.") + not _module_available("matplotlib"), reason="close figure test requires matplotlib to be installed." +) @pytest.mark.parametrize("step_idx", [10, None]) def test_tensorboard_log_figure(tmpdir, step_idx): logger = TensorBoardLogger(tmpdir) diff --git a/tests/loggers/test_wandb.py b/tests/loggers/test_wandb.py index 5e3aba4010792..5f21e9d080640 100644 --- a/tests/loggers/test_wandb.py +++ b/tests/loggers/test_wandb.py @@ -203,8 +203,8 @@ def test_wandb_logger_offline_log_model(wandb, tmpdir): @pytest.mark.skipif( - not _module_available("matplotlib"), - reason="close figure test requires matplotlib to be installed.") + not _module_available("matplotlib"), reason="close figure test requires matplotlib to be installed." +) @mock.patch('pytorch_lightning.loggers.wandb.wandb') @pytest.mark.parametrize("step_idx", [10, None]) def test_wandb_logger_log_figure(wandb, tmpdir, step_idx): @@ -220,8 +220,8 @@ def test_wandb_logger_log_figure(wandb, tmpdir, step_idx): @pytest.mark.skipif( - not _module_available("matplotlib"), - reason="close figure test requires matplotlib to be installed.") + not _module_available("matplotlib"), reason="close figure test requires matplotlib to be installed." 
+) @pytest.mark.parametrize("close", [True, False]) @mock.patch('pytorch_lightning.loggers.wandb.wandb') def test_close_figure(wandb, tmpdir, close): From 00b59a1c11220d5fdd5167c2d4e3ac2e4a8e7b74 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lucas-Rapahel=20Mu=CC=88ller?= Date: Fri, 21 May 2021 10:12:10 +0200 Subject: [PATCH 21/25] Bugfix in artifact path --- pytorch_lightning/loggers/mlflow.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/loggers/mlflow.py b/pytorch_lightning/loggers/mlflow.py index e81aa06904dcd..85bc1d1d796a1 100644 --- a/pytorch_lightning/loggers/mlflow.py +++ b/pytorch_lightning/loggers/mlflow.py @@ -228,7 +228,7 @@ def log_figure(self, name: str, figure: plt.figure, step: Optional[int] = None, self.experiment.log_artifact( self.run_id, tmp_file.name, - artifact_path=Path(self.save_dir) / ("figure_" + name) + artifact_path=Path(self.save_dir) / f"figure_{name}{self._figure_file_extension}" ) if close: From 82c6da9fbecb6ec44d7eec723d086a550227ae3d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lucas-Rapahel=20Mu=CC=88ller?= Date: Fri, 21 May 2021 12:22:47 +0200 Subject: [PATCH 22/25] Change log_figure saving logic for mlflow --- pytorch_lightning/loggers/mlflow.py | 16 ++++++++++++---- tests/loggers/test_mlflow.py | 7 ++++--- 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/loggers/mlflow.py b/pytorch_lightning/loggers/mlflow.py index fcc0469596150..036a018b7653e 100644 --- a/pytorch_lightning/loggers/mlflow.py +++ b/pytorch_lightning/loggers/mlflow.py @@ -222,12 +222,20 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> @rank_zero_only def log_figure(self, name: str, figure: plt.figure, step: Optional[int] = None, close: bool = True) -> None: - with tempfile.NamedTemporaryFile(suffix=self._figure_file_extension) as tmp_file: - figure.savefig(tmp_file) + if step is not None: + figure_fname = f"figure_{name}_step_{step}{self._figure_file_extension}" + else: + figure_fname = f"figure_{name}{self._figure_file_extension}" + + # create tmp directory and semantically named filebecause + # apparently one should not write to artifact location directly + # ToDo: Once its stable, use ml_flow.log_figure + with tempfile.TemporaryDirectory() as tmp_dir: + figure_path = Path(tmp_dir) / figure_fname + figure.savefig(figure_path) self.experiment.log_artifact( self.run_id, - tmp_file.name, - artifact_path=Path(self.save_dir) / f"figure_{name}{self._figure_file_extension}" + figure_path, ) if close: diff --git a/tests/loggers/test_mlflow.py b/tests/loggers/test_mlflow.py index 2152d94ea8d0a..ac19f65ade197 100644 --- a/tests/loggers/test_mlflow.py +++ b/tests/loggers/test_mlflow.py @@ -247,7 +247,8 @@ def test_mlflow_log_figure(client, mlflow, step_idx, figure_format, tmpdir): f = plotting.dummy_figure() logger.log_figure('dummy', f, step_idx, close=True) - artifact_expect = 'figure_dummy' - mock_log.assert_called_once() - mock_log.call_args_list[0][1]['artifact_path'] == artifact_expect + if step_idx is not None: + assert mock_log.call_args_list[0][0][1].stem == 'figure_dummy_step_10' + else: + assert mock_log.call_args_list[0][0][1].stem == 'figure_dummy' From 126c0ec9de2a3dfdca525e6f85bd02154dff41f7 Mon Sep 17 00:00:00 2001 From: Lucas-Raphael Mueller Date: Fri, 18 Jun 2021 11:01:12 +0200 Subject: [PATCH 23/25] Change mlflow log_figure test to actual functional test instead of mocking --- tests/loggers/test_mlflow.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff 
--git a/tests/loggers/test_mlflow.py b/tests/loggers/test_mlflow.py index d3543053a5044..03d03e865008d 100644 --- a/tests/loggers/test_mlflow.py +++ b/tests/loggers/test_mlflow.py @@ -261,16 +261,21 @@ def test_mlflow_logger_experiment_calls(client, mlflow, time, tmpdir): @pytest.mark.skipif( not _module_available("matplotlib"), reason="close figure test requires matplotlib to be installed." ) -@mock.patch('pytorch_lightning.loggers.mlflow.mlflow') -@mock.patch('pytorch_lightning.loggers.mlflow.MlflowClient') @pytest.mark.parametrize("step_idx", [10, None]) @pytest.mark.parametrize("figure_format", ['.png', '.pdf']) -def test_mlflow_log_figure(client, mlflow, step_idx, figure_format, tmpdir): +def test_mlflow_log_figure(step_idx, figure_format, tmpdir): logger = MLFlowLogger('test', save_dir=tmpdir, figure_file_extension=figure_format) - logger.log_figure('dummy', plotting.dummy_figure(), step_idx, close=True) # functional test + logger.log_figure('dummy', plotting.dummy_figure(), step_idx, close=True) + + if step_idx is not None: + fname_expct = f'figure_dummy_step_{step_idx}{figure_format}' + else: + fname_expct = f'figure_dummy{figure_format}' + path_expct = tmpdir / logger.experiment_id / logger.run_id / 'artifacts' / fname_expct + assert path_expct.check(file=True) - # test whether figure is closed etc. + # tests arguments to log_artifact with mock.patch.object(logger.experiment, 'log_artifact') as mock_log: f = plotting.dummy_figure() logger.log_figure('dummy', f, step_idx, close=True) From f972410ac160ec3ef5e6632e8c63c03f3c977ae5 Mon Sep 17 00:00:00 2001 From: Lucas-Raphael Mueller Date: Fri, 18 Jun 2021 11:49:03 +0200 Subject: [PATCH 24/25] Move matplotlib available check to utilities.import --- pytorch_lightning/loggers/comet.py | 4 ++-- pytorch_lightning/loggers/mlflow.py | 3 +-- pytorch_lightning/loggers/neptune.py | 3 +-- pytorch_lightning/loggers/tensorboard.py | 5 ++--- pytorch_lightning/loggers/wandb.py | 4 +--- pytorch_lightning/utilities/imports.py | 1 + tests/loggers/test_all.py | 4 ++-- tests/loggers/test_comet.py | 4 ++-- tests/loggers/test_mlflow.py | 4 ++-- tests/loggers/test_neptune.py | 4 ++-- tests/loggers/test_tensorboard.py | 4 ++-- tests/loggers/test_wandb.py | 6 +++--- 12 files changed, 21 insertions(+), 25 deletions(-) diff --git a/pytorch_lightning/loggers/comet.py b/pytorch_lightning/loggers/comet.py index ee343ee74682f..06cf0398e1f24 100644 --- a/pytorch_lightning/loggers/comet.py +++ b/pytorch_lightning/loggers/comet.py @@ -28,6 +28,8 @@ from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment from pytorch_lightning.utilities import _module_available, rank_zero_only from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.imports import _MATPLOTLIB_AVAILABLE + log = logging.getLogger(__name__) _COMET_AVAILABLE = _module_available("comet_ml") @@ -49,8 +51,6 @@ CometExperiment, CometExistingExperiment, CometOfflineExperiment = None, None, None API = None -_MATPLOTLIB_AVAILABLE = _module_available("matplotlib") - if _MATPLOTLIB_AVAILABLE: import matplotlib.pyplot as plt else: diff --git a/pytorch_lightning/loggers/mlflow.py b/pytorch_lightning/loggers/mlflow.py index 81dc30e05617c..6930c2967fe5c 100644 --- a/pytorch_lightning/loggers/mlflow.py +++ b/pytorch_lightning/loggers/mlflow.py @@ -26,6 +26,7 @@ from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment from pytorch_lightning.utilities import _module_available, rank_zero_only, rank_zero_warn 
+from pytorch_lightning.utilities.imports import _MATPLOTLIB_AVAILABLE log = logging.getLogger(__name__) LOCAL_FILE_URI_PREFIX = "file:" @@ -40,8 +41,6 @@ mlflow, MlflowClient, context = None, None, None MLFLOW_RUN_NAME = "mlflow.runName" -_MATPLOTLIB_AVAILABLE = _module_available("matplotlib") - if _MATPLOTLIB_AVAILABLE: import matplotlib.pyplot as plt else: diff --git a/pytorch_lightning/loggers/neptune.py b/pytorch_lightning/loggers/neptune.py index 29a3224f1231c..a1b6e1fc9a212 100644 --- a/pytorch_lightning/loggers/neptune.py +++ b/pytorch_lightning/loggers/neptune.py @@ -24,6 +24,7 @@ from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment from pytorch_lightning.utilities import _module_available, rank_zero_only +from pytorch_lightning.utilities.imports import _MATPLOTLIB_AVAILABLE log = logging.getLogger(__name__) _NEPTUNE_AVAILABLE = _module_available("neptune") @@ -35,8 +36,6 @@ # needed for test mocks, these tests shall be updated neptune, Experiment = None, None -_MATPLOTLIB_AVAILABLE = _module_available("matplotlib") - if _MATPLOTLIB_AVAILABLE: import matplotlib.pyplot as plt else: diff --git a/pytorch_lightning/loggers/tensorboard.py b/pytorch_lightning/loggers/tensorboard.py index 9dada19d0adbc..2d1a52f2152fb 100644 --- a/pytorch_lightning/loggers/tensorboard.py +++ b/pytorch_lightning/loggers/tensorboard.py @@ -28,16 +28,15 @@ from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.core.saving import save_hparams_to_yaml from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment -from pytorch_lightning.utilities import _module_available, _OMEGACONF_AVAILABLE, rank_zero_only, rank_zero_warn +from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, rank_zero_only, rank_zero_warn from pytorch_lightning.utilities.cloud_io import get_filesystem +from pytorch_lightning.utilities.imports import _MATPLOTLIB_AVAILABLE log = logging.getLogger(__name__) if _OMEGACONF_AVAILABLE: from omegaconf import Container, OmegaConf -_MATPLOTLIB_AVAILABLE = _module_available("matplotlib") - if _MATPLOTLIB_AVAILABLE: import matplotlib.pyplot as plt else: diff --git a/pytorch_lightning/loggers/wandb.py b/pytorch_lightning/loggers/wandb.py index 70ab8c3d16f7b..1d991ec9bc34c 100644 --- a/pytorch_lightning/loggers/wandb.py +++ b/pytorch_lightning/loggers/wandb.py @@ -28,7 +28,7 @@ from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment from pytorch_lightning.utilities import _module_available, rank_zero_only from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.imports import _compare_version +from pytorch_lightning.utilities.imports import _compare_version, _MATPLOTLIB_AVAILABLE from pytorch_lightning.utilities.warnings import WarningCache warning_cache = WarningCache() @@ -43,8 +43,6 @@ # needed for test mocks, these tests shall be updated wandb, Run = None, None -_MATPLOTLIB_AVAILABLE = _module_available("matplotlib") - if _MATPLOTLIB_AVAILABLE: import matplotlib.pyplot as plt else: diff --git a/pytorch_lightning/utilities/imports.py b/pytorch_lightning/utilities/imports.py index 2a51b01404821..2313754321868 100644 --- a/pytorch_lightning/utilities/imports.py +++ b/pytorch_lightning/utilities/imports.py @@ -83,6 +83,7 @@ def _compare_version(package: str, op, version) -> bool: _HYDRA_AVAILABLE = _module_available("hydra") _HYDRA_EXPERIMENTAL_AVAILABLE = _module_available("hydra.experimental") _KINETO_AVAILABLE = 
_TORCH_GREATER_EQUAL_1_8_1 and torch.profiler.kineto_available() +_MATPLOTLIB_AVAILABLE = _module_available("matplotlib") _NATIVE_AMP_AVAILABLE = _module_available("torch.cuda.amp") and hasattr(torch.cuda.amp, "autocast") _OMEGACONF_AVAILABLE = _module_available("omegaconf") _POPTORCH_AVAILABLE = _module_available('poptorch') diff --git a/tests/loggers/test_all.py b/tests/loggers/test_all.py index 25d6a34cff10d..3819ca9c9cb77 100644 --- a/tests/loggers/test_all.py +++ b/tests/loggers/test_all.py @@ -32,7 +32,7 @@ WandbLogger, ) from pytorch_lightning.loggers.base import DummyExperiment -from pytorch_lightning.utilities import _module_available +from pytorch_lightning.utilities.imports import _MATPLOTLIB_AVAILABLE from tests.helpers import BoringModel, plotting from tests.helpers.runif import RunIf from tests.loggers.test_comet import _patch_comet_atexit @@ -411,7 +411,7 @@ def test_logger_with_prefix_all(tmpdir, monkeypatch): @pytest.mark.skipif( - not _module_available("matplotlib"), reason="close figure test requires matplotlib to be installed." + not _MATPLOTLIB_AVAILABLE, reason="close figure test requires matplotlib to be installed." ) @pytest.mark.parametrize("close", [True, False]) @pytest.mark.parametrize( diff --git a/tests/loggers/test_comet.py b/tests/loggers/test_comet.py index 79dee595005fe..1ee2fd81f0287 100644 --- a/tests/loggers/test_comet.py +++ b/tests/loggers/test_comet.py @@ -18,8 +18,8 @@ from pytorch_lightning import Trainer from pytorch_lightning.loggers import CometLogger -from pytorch_lightning.utilities import _module_available from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.imports import _MATPLOTLIB_AVAILABLE from tests.helpers import BoringModel, plotting @@ -226,7 +226,7 @@ def test_comet_epoch_logging(comet, comet_experiment, tmpdir, monkeypatch): @pytest.mark.skipif( - not _module_available("matplotlib"), reason="close figure test requires matplotlib to be installed." + not _MATPLOTLIB_AVAILABLE, reason="close figure test requires matplotlib to be installed." ) @patch("pytorch_lightning.loggers.comet.CometExperiment") @patch('pytorch_lightning.loggers.comet.comet_ml') diff --git a/tests/loggers/test_mlflow.py b/tests/loggers/test_mlflow.py index 03d03e865008d..2fec404641834 100644 --- a/tests/loggers/test_mlflow.py +++ b/tests/loggers/test_mlflow.py @@ -20,7 +20,7 @@ from pytorch_lightning import Trainer from pytorch_lightning.loggers import _MLFLOW_AVAILABLE, MLFlowLogger from pytorch_lightning.loggers.mlflow import MLFLOW_RUN_NAME, resolve_tags -from pytorch_lightning.utilities import _module_available +from pytorch_lightning.utilities.imports import _MATPLOTLIB_AVAILABLE from tests.helpers import BoringModel, plotting @@ -259,7 +259,7 @@ def test_mlflow_logger_experiment_calls(client, mlflow, time, tmpdir): @pytest.mark.skipif( - not _module_available("matplotlib"), reason="close figure test requires matplotlib to be installed." + not _MATPLOTLIB_AVAILABLE, reason="close figure test requires matplotlib to be installed." 
) @pytest.mark.parametrize("step_idx", [10, None]) @pytest.mark.parametrize("figure_format", ['.png', '.pdf']) diff --git a/tests/loggers/test_neptune.py b/tests/loggers/test_neptune.py index 5ef831d35a3bb..ea83dd0eeab8c 100644 --- a/tests/loggers/test_neptune.py +++ b/tests/loggers/test_neptune.py @@ -18,7 +18,7 @@ from pytorch_lightning import Trainer from pytorch_lightning.loggers import NeptuneLogger -from pytorch_lightning.utilities import _module_available +from pytorch_lightning.utilities.imports import _MATPLOTLIB_AVAILABLE from tests.helpers import BoringModel, plotting @@ -129,7 +129,7 @@ def _run_training(logger): @pytest.mark.skipif( - not _module_available("matplotlib"), reason="close figure test requires matplotlib to be installed." + not _MATPLOTLIB_AVAILABLE, reason="close figure test requires matplotlib to be installed." ) @patch('pytorch_lightning.loggers.neptune.neptune') @pytest.mark.parametrize("step_idx", [10, None]) diff --git a/tests/loggers/test_tensorboard.py b/tests/loggers/test_tensorboard.py index 2ebd95ef1a388..4bb6fc4b7b64c 100644 --- a/tests/loggers/test_tensorboard.py +++ b/tests/loggers/test_tensorboard.py @@ -25,7 +25,7 @@ from pytorch_lightning import Trainer from pytorch_lightning.loggers import TensorBoardLogger -from pytorch_lightning.utilities import _module_available +from pytorch_lightning.utilities.imports import _MATPLOTLIB_AVAILABLE from tests.helpers import BoringModel, plotting from tests.helpers.runif import RunIf @@ -179,7 +179,7 @@ def test_tensorboard_log_metrics(tmpdir, step_idx): @pytest.mark.skipif( - not _module_available("matplotlib"), reason="close figure test requires matplotlib to be installed." + not _MATPLOTLIB_AVAILABLE, reason="close figure test requires matplotlib to be installed." ) @pytest.mark.parametrize("step_idx", [10, None]) def test_tensorboard_log_figure(tmpdir, step_idx): diff --git a/tests/loggers/test_wandb.py b/tests/loggers/test_wandb.py index 3c28ca4fb236d..46ca8edfd343c 100644 --- a/tests/loggers/test_wandb.py +++ b/tests/loggers/test_wandb.py @@ -20,8 +20,8 @@ from pytorch_lightning import Trainer from pytorch_lightning.loggers import WandbLogger -from pytorch_lightning.utilities import _module_available from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.imports import _MATPLOTLIB_AVAILABLE from tests.helpers import BoringModel, plotting @@ -258,7 +258,7 @@ def test_wandb_logger_offline_log_model(wandb, tmpdir): @pytest.mark.skipif( - not _module_available("matplotlib"), reason="close figure test requires matplotlib to be installed." + not _MATPLOTLIB_AVAILABLE, reason="close figure test requires matplotlib to be installed." ) @mock.patch('pytorch_lightning.loggers.wandb.wandb') @pytest.mark.parametrize("step_idx", [10, None]) @@ -275,7 +275,7 @@ def test_wandb_logger_log_figure(wandb, tmpdir, step_idx): @pytest.mark.skipif( - not _module_available("matplotlib"), reason="close figure test requires matplotlib to be installed." + not _MATPLOTLIB_AVAILABLE, reason="close figure test requires matplotlib to be installed." 
) @pytest.mark.parametrize("close", [True, False]) @mock.patch('pytorch_lightning.loggers.wandb.wandb') From 568425848f308f3c0770e94e2d8d40842b619e61 Mon Sep 17 00:00:00 2001 From: Lucas-Raphael Mueller Date: Fri, 18 Jun 2021 13:17:01 +0200 Subject: [PATCH 25/25] Rm unnecessary line --- pytorch_lightning/loggers/base.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_lightning/loggers/base.py b/pytorch_lightning/loggers/base.py index 2922758ab55dd..1ce31f840b69b 100644 --- a/pytorch_lightning/loggers/base.py +++ b/pytorch_lightning/loggers/base.py @@ -416,7 +416,6 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> logger.log_metrics(metrics, step) def log_figure(self, name: str, figure, step: Optional[int] = None, close: bool = True) -> None: - for logger in self._logger_iterable: # don't close in the individual loggers, but once at the end logger.log_figure(name, figure, step=step, close=False)
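With the full series applied, every bundled logger, and a LoggerCollection wrapping several of them, exposes the same log_figure(name, figure, step=None, close=True) entry point. The sketch below only illustrates that call path and is not part of the patches: the TensorBoardLogger, the save directory, the figure name and the plotted curve are placeholder choices, and matplotlib must be installed for the call to record anything.

.. code-block:: python

    import matplotlib.pyplot as plt
    import numpy as np

    from pytorch_lightning.loggers import TensorBoardLogger

    logger = TensorBoardLogger(save_dir="lightning_logs")

    # any matplotlib figure works; this one only draws a placeholder curve
    f = plt.figure()
    plt.plot(np.linspace(0.0, 1.0, 100), np.linspace(0.0, 10.0, 100) ** 2)

    # the same signature is available on every logger and on a LoggerCollection
    # (for example via self.logger inside a LightningModule); close=True disposes
    # of the figure once it has been recorded
    logger.log_figure(name="dummy_figure", figure=f, step=0, close=True)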