
Feature: Plotting of metrics #1328
Merged: 75 commits, Dec 23, 2022
Commits:
3641b2c plot code (SkafteNicki, Aug 29, 2022)
b1688dd Merge branch 'master' into feature/plot (SkafteNicki, Oct 8, 2022)
edff40b curve plot (SkafteNicki, Oct 10, 2022)
ff1926c Merge branch 'master' into feature/plot (SkafteNicki, Nov 8, 2022)
b241b90 plot of single value (SkafteNicki, Nov 10, 2022)
275af7f Merge branch 'master' into feature/plot (SkafteNicki, Nov 10, 2022)
7e753d8 Merge branch 'master' into feature/plot (SkafteNicki, Nov 10, 2022)
eb545e6 changelog (SkafteNicki, Nov 10, 2022)
f68190a enable doc plotting (SkafteNicki, Nov 10, 2022)
be8de34 update plot functions (SkafteNicki, Nov 10, 2022)
e65d08f enable doc plotting (SkafteNicki, Nov 10, 2022)
45c0527 example (SkafteNicki, Nov 10, 2022)
0aecc5c try something (SkafteNicki, Nov 11, 2022)
3688891 Merge branch 'master' into feature/plot (Borda, Nov 11, 2022)
64e738a move requirement (SkafteNicki, Nov 13, 2022)
fa80944 requirement shenanigans (SkafteNicki, Nov 13, 2022)
0ab579d fix (SkafteNicki, Nov 13, 2022)
1b4673b Merge branch 'master' into feature/plot (SkafteNicki, Nov 14, 2022)
1c01d03 try fixing doctesting (SkafteNicki, Nov 14, 2022)
0d472de Apply suggestions from code review (SkafteNicki, Nov 15, 2022)
81f5e0e some docstring (SkafteNicki, Nov 15, 2022)
98b9c3e Merge branch 'master' into feature/plot (SkafteNicki, Nov 16, 2022)
42ed083 add to base class + docstring (SkafteNicki, Nov 16, 2022)
39aaf89 suggestions (SkafteNicki, Nov 16, 2022)
c501a7f fx testing (SkafteNicki, Nov 16, 2022)
07528c3 testing (SkafteNicki, Nov 16, 2022)
d4632cb some fixes (SkafteNicki, Nov 16, 2022)
c07da9e fix source inclusion (SkafteNicki, Nov 16, 2022)
cf17a69 Merge branch 'master' into feature/plot (SkafteNicki, Nov 16, 2022)
b3fecaf Merge branch 'master' into feature/plot (Borda, Nov 17, 2022)
33e6501 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Nov 17, 2022)
4a9e2e9 Apply suggestions from code review (Borda, Nov 17, 2022)
0e49be3 visual (Borda, Nov 17, 2022)
a602002 requirements (SkafteNicki, Nov 18, 2022)
4b18038 cython (Borda, Nov 18, 2022)
f727a13 Merge branch 'master' into feature/plot (SkafteNicki, Nov 18, 2022)
9f22a48 Merge branch 'master' into feature/plot (SkafteNicki, Nov 18, 2022)
a99530c Merge branch 'master' into feature/plot (mergify[bot], Nov 19, 2022)
79b3f6d Merge branch 'master' into feature/plot (mergify[bot], Nov 21, 2022)
622fbfc Merge branch 'master' into feature/plot (SkafteNicki, Nov 22, 2022)
c5b9845 update with ax arg (SkafteNicki, Nov 22, 2022)
895c7ff change requirement (SkafteNicki, Nov 22, 2022)
b27f436 Merge branch 'master' into feature/plot (mergify[bot], Nov 22, 2022)
54cced9 Merge branch 'master' into feature/plot (mergify[bot], Nov 22, 2022)
28fc65a Merge branch 'master' into feature/plot (mergify[bot], Nov 22, 2022)
ad5e443 Merge branch 'master' into feature/plot (Borda, Nov 30, 2022)
f674dda Merge branch 'master' into feature/plot (mergify[bot], Dec 2, 2022)
a35edbf Merge branch 'master' into feature/plot (mergify[bot], Dec 3, 2022)
8c82865 Merge branch 'master' into feature/plot (mergify[bot], Dec 4, 2022)
72fb033 Merge branch 'master' into feature/plot (mergify[bot], Dec 7, 2022)
8fe409b Merge branch 'master' into feature/plot (mergify[bot], Dec 7, 2022)
6c4a4cb Merge branch 'master' into feature/plot (mergify[bot], Dec 7, 2022)
70a964b Merge branch 'master' into feature/plot (mergify[bot], Dec 7, 2022)
42eb260 Merge branch 'master' into feature/plot (mergify[bot], Dec 9, 2022)
b6b71c0 Merge branch 'master' into feature/plot (mergify[bot], Dec 9, 2022)
1f77bbf Merge branch 'master' into feature/plot (mergify[bot], Dec 13, 2022)
a011d5f Merge branch 'master' into feature/plot (mergify[bot], Dec 14, 2022)
591b0b7 Merge branch 'master' into feature/plot (mergify[bot], Dec 14, 2022)
79b8b65 Merge branch 'master' into feature/plot (mergify[bot], Dec 14, 2022)
e5fce3e Merge branch 'master' into feature/plot (mergify[bot], Dec 14, 2022)
b52201e Merge branch 'master' into feature/plot (mergify[bot], Dec 14, 2022)
115cbb9 Merge branch 'master' into feature/plot (mergify[bot], Dec 14, 2022)
71f60cb Merge branch 'master' into feature/plot (mergify[bot], Dec 15, 2022)
c7c5030 Merge branch 'master' into feature/plot (mergify[bot], Dec 16, 2022)
7666186 Merge branch 'master' into feature/plot (mergify[bot], Dec 16, 2022)
f319929 ci: maxfail=25 (Borda, Dec 17, 2022)
d06b0e9 Merge branch 'master' into feature/plot (Borda, Dec 17, 2022)
dc3fcb2 Merge branch 'master' into feature/plot (Borda, Dec 22, 2022)
ac5777f oldest fixed (Borda, Dec 22, 2022)
d1c17e0 fix docstring (Borda, Dec 22, 2022)
ac3749a plt 3.2.0 (Borda, Dec 22, 2022)
b5d47f2 typing (Borda, Dec 22, 2022)
c7fd476 mypy (Borda, Dec 22, 2022)
b232bef "" (Borda, Dec 22, 2022)
d8caeeb example (Borda, Dec 23, 2022)
5 changes: 4 additions & 1 deletion CHANGELOG.md
@@ -31,6 +31,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `normalize` argument to `Inception`, `FID`, `KID` metrics ([#1246](https://github.com/Lightning-AI/metrics/pull/1246))


- Added support for plotting of metrics through `.plot()` method ([#1328](https://github.com/Lightning-AI/metrics/pull/1328))


### Changed

- Changed `MeanAveragePrecision` to vectorize `_find_best_gt_match` operation ([#1259](https://github.com/Lightning-AI/metrics/pull/1259))
@@ -52,7 +55,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

-
-


## [0.10.3] - 2022-11-16
4 changes: 4 additions & 0 deletions docs/source/conf.py
@@ -101,8 +101,12 @@ def _transform_changelog(path_in: str, path_out: str) -> None:
"sphinx_paramlinks",
"sphinx.ext.githubpages",
"pt_lightning_sphinx_theme.extensions.lightning",
"matplotlib.sphinxext.plot_directive",
]

# Always include the source code that generated each plot
plot_include_source = True

# Add any paths that contain templates here, relative to this directory.
templates_path = ["_templates"]

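The `conf.py` hunk above enables matplotlib's `plot_directive` and turns on source inclusion. The directive exposes a few more options a docs build might pair with this; none of the settings below are part of this PR, so treat them as an optional sketch:

```python
# docs/source/conf.py -- optional plot_directive settings (not part of this PR)
plot_include_source = True          # show the code that produced each plot
plot_html_show_source_link = False  # hide the "Source code" download link
plot_formats = [("png", 120)]       # render PNG output at 120 dpi
plot_html_show_formats = False      # hide links to the raw image files
```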
96 changes: 96 additions & 0 deletions examples/plotting.py
@@ -0,0 +1,96 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse

import matplotlib.pyplot as plt
import torch


def accuracy_example():
from torchmetrics.classification import MulticlassAccuracy

p = lambda: torch.randn(20, 5)
t = lambda: torch.randint(5, (20,))

# plot single value
metric = MulticlassAccuracy(num_classes=5)
metric.update(p(), t())
fig, ax = metric.plot()

# plot a value per class
metric = MulticlassAccuracy(num_classes=5, average=None)
metric.update(p(), t())
fig, ax = metric.plot()

# plot two values as a series
metric = MulticlassAccuracy(num_classes=5)
val1 = metric(p(), t())
val2 = metric(p(), t())
fig, ax = metric.plot([val1, val2])

# plot a series of values per class
metric = MulticlassAccuracy(num_classes=5, average=None)
val1 = metric(p(), t())
val2 = metric(p(), t())
fig, ax = metric.plot([val1, val2])
return fig, ax


def mean_squared_error_example():
from torchmetrics.regression import MeanSquaredError

p = lambda: torch.randn(20)
t = lambda: torch.randn(20)

# single val
metric = MeanSquaredError()
metric.update(p(), t())
fig, ax = metric.plot()

# multiple values
metric = MeanSquaredError()
vals = [metric(p(), t()) for _ in range(10)]
fig, ax = metric.plot(vals)
return fig, ax


def confusion_matrix_example():
from torchmetrics.classification import MulticlassConfusionMatrix

p = lambda: torch.randn(20, 5)
t = lambda: torch.randint(5, (20,))

# plot single value
metric = MulticlassConfusionMatrix(num_classes=5)
metric.update(p(), t())
fig, ax = metric.plot()
return fig, ax


if __name__ == "__main__":
list_of_choices = ["accuracy", "mean_squared_error", "confusion_matrix"]
parser = argparse.ArgumentParser(description="Example script for plotting metrics.")
parser.add_argument("metric", choices=list_of_choices)
args = parser.parse_args()

if args.metric == "accuracy":
fig, ax = accuracy_example()

if args.metric == "mean_squared_error":
fig, ax = mean_squared_error_example()

if args.metric == "confusion_matrix":
fig, ax = confusion_matrix_example()

plt.show()
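The accuracy example above builds a list of per-step results and hands it to `.plot()`. The shape of that series can be sketched without torchmetrics at all; the `step_accuracy` helper below is illustrative, not a torchmetrics API:

```python
import random

def step_accuracy(preds, targets):
    """Fraction of predictions that match the targets in one batch."""
    return sum(p == t for p, t in zip(preds, targets)) / len(targets)

random.seed(0)
values = []
for _ in range(10):
    preds = [random.randrange(5) for _ in range(20)]
    targets = [random.randrange(5) for _ in range(20)]
    values.append(step_accuracy(preds, targets))

# `values` is the kind of per-step series that MulticlassAccuracy.plot(values)
# would draw as a line over steps.
print(len(values))  # 10
```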
1 change: 1 addition & 0 deletions requirements/integrate.txt
@@ -1 +1,2 @@
pytorch-lightning>=1.5
matplotlib>=3.0.0
100 changes: 99 additions & 1 deletion src/torchmetrics/classification/accuracy.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional
from typing import Any, Optional, Sequence, Union

import torch
from torch import Tensor, tensor
@@ -28,8 +28,13 @@
)
from torchmetrics.metric import Metric
from torchmetrics.utilities.enums import AverageMethod, DataType
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _PLOT_OUT_TYPE, plot_single_or_multi_val
from torchmetrics.utilities.prints import rank_zero_warn

if not _MATPLOTLIB_AVAILABLE:
__doctest_skip__ = ["BinaryAccuracy.plot", "MulticlassAccuracy.plot"]

from torchmetrics.classification.stat_scores import ( # isort:skip
StatScores,
BinaryStatScores,
@@ -107,12 +112,58 @@ class BinaryAccuracy(BinaryStatScores):
is_differentiable = False
higher_is_better = True
full_state_update: bool = False
plot_options: dict = {"lower_bound": 0.0, "upper_bound": 1.0}

def compute(self) -> Tensor:
"""Computes accuracy based on inputs passed in to ``update`` previously."""
tp, fp, tn, fn = self._final_state()
return _accuracy_reduce(tp, fp, tn, fn, average="binary", multidim_average=self.multidim_average)

def plot(self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.

Args:
val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
If no value is provided, will automatically call `metric.compute` and plot that result.

Returns:
fig: Figure object
ax: Axes object

Raises:
ModuleNotFoundError:
If `matplotlib` is not installed

Examples:

.. plot::
:scale: 75

>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.classification import BinaryAccuracy
>>> metric = BinaryAccuracy()
>>> metric.update(torch.rand(10), torch.randint(2,(10,)))
>>> metric.plot()

.. plot::
:scale: 75

>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.classification import BinaryAccuracy
>>> metric = BinaryAccuracy()
>>> values = []
>>> for _ in range(10):
... values.append(metric(torch.rand(10), torch.randint(2,(10,))))
>>> metric.plot(values)
"""
val = val or self.compute()
fig, ax = plot_single_or_multi_val(
val, higher_is_better=self.higher_is_better, **self.plot_options, name=self.__class__.__name__
)
return fig, ax


class MulticlassAccuracy(MulticlassStatScores):
r"""Computes `Accuracy`_ for multiclass tasks:
@@ -213,12 +264,58 @@ class MulticlassAccuracy(MulticlassStatScores):
is_differentiable = False
higher_is_better = True
full_state_update: bool = False
plot_options = {"lower_bound": 0.0, "upper_bound": 1.0, "legend_name": "Class"}

def compute(self) -> Tensor:
"""Computes accuracy based on inputs passed in to ``update`` previously."""
tp, fp, tn, fn = self._final_state()
return _accuracy_reduce(tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average)

def plot(self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.

Args:
val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
If no value is provided, will automatically call `metric.compute` and plot that result.

Returns:
fig: Figure object
ax: Axes object

Raises:
ModuleNotFoundError:
If `matplotlib` is not installed

Examples:

.. plot::
:scale: 75

>>> # Example plotting a single value per class
>>> import torch
>>> from torchmetrics.classification import MulticlassAccuracy
>>> metric = MulticlassAccuracy(num_classes=3, average=None)
>>> metric.update(torch.randint(3, (20,)), torch.randint(3, (20,)))
>>> metric.plot()

.. plot::
:scale: 75

>>> # Example plotting multiple values per class
>>> import torch
>>> from torchmetrics.classification import MulticlassAccuracy
>>> metric = MulticlassAccuracy(num_classes=3, average=None)
>>> values = []
>>> for _ in range(20):
... values.append(metric(torch.randint(3, (20,)), torch.randint(3, (20,))))
>>> metric.plot(values)
"""
val = val or self.compute()
fig, ax = plot_single_or_multi_val(
val, higher_is_better=self.higher_is_better, **self.plot_options, name=self.__class__.__name__
)
return fig, ax


class MultilabelAccuracy(MultilabelStatScores):
r"""Computes `Accuracy`_ for multilabel tasks:
@@ -317,6 +414,7 @@ class MultilabelAccuracy(MultilabelStatScores):
is_differentiable = False
higher_is_better = True
full_state_update: bool = False
plot_options: dict = {"lower_bound": 0.0, "upper_bound": 1.0, "legend_name": "Label"}

def compute(self) -> Tensor:
"""Computes accuracy based on inputs passed in to ``update`` previously."""
35 changes: 35 additions & 0 deletions src/torchmetrics/classification/confusion_matrix.py
@@ -37,8 +37,13 @@
_multilabel_confusion_matrix_update,
)
from torchmetrics.metric import Metric
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _PLOT_OUT_TYPE, plot_confusion_matrix
from torchmetrics.utilities.prints import rank_zero_warn

if not _MATPLOTLIB_AVAILABLE:
__doctest_skip__ = ["MulticlassConfusionMatrix.plot"]


class BinaryConfusionMatrix(Metric):
r"""Computes the `confusion matrix`_ for binary tasks.
@@ -220,6 +225,36 @@ def compute(self) -> Tensor:
"""
return _multiclass_confusion_matrix_compute(self.confmat, self.normalize)

def plot(self, val: Optional[Tensor] = None) -> _PLOT_OUT_TYPE:
"""Plot a single confusion matrix from the metric.

Args:
val: A single result from calling `metric.forward` or `metric.compute`; a list of values is not supported here.
If no value is provided, will automatically call `metric.compute` and plot that result.

Returns:
fig: Figure object
ax: Axes object

Raises:
ModuleNotFoundError:
If `matplotlib` is not installed

.. plot::
:scale: 75

>>> import torch
>>> from torchmetrics.classification import MulticlassConfusionMatrix
>>> metric = MulticlassConfusionMatrix(num_classes=5)
>>> metric.update(torch.randint(5, (20,)), torch.randint(5, (20,)))
>>> metric.plot()
"""
val = val or self.compute()
if not isinstance(val, Tensor):
raise TypeError(f"Expected val to be a single tensor but got {val}")
fig, ax = plot_confusion_matrix(val)
return fig, ax


class MultilabelConfusionMatrix(Metric):
r"""Computes the `confusion matrix`_ for multilabel tasks.
4 changes: 4 additions & 0 deletions src/torchmetrics/metric.py
@@ -558,6 +558,10 @@ def compute(self) -> Any:
"""Override this method to compute the final metric value from state variables synchronized across the
distributed backend."""

def plot(self, *_: Any, **__: Any) -> None:
"""Override this method to plot the metric value."""
raise NotImplementedError

def reset(self) -> None:
"""This method automatically resets the metric state variables to their default value."""
self._update_count = 0
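The `metric.py` hunk is the whole contract: the base `Metric.plot` raises `NotImplementedError`, and each metric that supports plotting overrides it, falling back to `compute()` when no value is passed. A stdlib-only sketch of that dispatch pattern (the `RunningMean` class is illustrative, not a torchmetrics API):

```python
class Metric:
    """Base class: plotting is opt-in, so the default raises."""

    def plot(self, *args, **kwargs):
        raise NotImplementedError


class RunningMean(Metric):
    """Toy metric with the update/compute/plot shape used in the PR."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update(self, x):
        self.total += x
        self.count += 1

    def compute(self):
        return self.total / self.count

    def plot(self, val=None):
        # Same fallback idea as the PR; `is not None` sidesteps the
        # truthiness surprises that `val or self.compute()` can hit
        # when the supplied value is zero.
        val = val if val is not None else self.compute()
        # Return the data a real implementation would hand to matplotlib.
        return (type(self).__name__, val)


m = RunningMean()
m.update(2.0)
m.update(4.0)
print(m.plot())  # ('RunningMean', 3.0)
```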