diff --git a/CHANGELOG.md b/CHANGELOG.md index 9f6dc80500e..f175866e243 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 [#1480](https://github.com/Lightning-AI/metrics/pull/1480), [#1490](https://github.com/Lightning-AI/metrics/pull/1490), [#1581](https://github.com/Lightning-AI/metrics/pull/1581), + [#1593](https://github.com/Lightning-AI/metrics/pull/1593), ) diff --git a/src/torchmetrics/classification/auroc.py b/src/torchmetrics/classification/auroc.py index fb5b33296d0..32ae8ebcf11 100644 --- a/src/torchmetrics/classification/auroc.py +++ b/src/torchmetrics/classification/auroc.py @@ -140,15 +140,24 @@ def plot( .. plot:: :scale: 75 - >>> from torch import randn, randint - >>> import torch.nn.functional as F - >>> # Example plotting a combined value across all classes + >>> # Example plotting a single value + >>> import torch >>> from torchmetrics.classification import BinaryAUROC - >>> preds = F.softmax(randn(20, 2), dim=1) - >>> target = randint(2, (20,)) >>> metric = BinaryAUROC() - >>> metric.update(preds[:, 1], target) + >>> metric.update(torch.rand(20,), torch.randint(2, (20,))) >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> import torch + >>> from torchmetrics.classification import BinaryAUROC + >>> metric = BinaryAUROC() + >>> values = [ ] + >>> for _ in range(10): + ... values.append(metric(torch.rand(20,), torch.randint(2, (20,)))) + >>> fig_, ax_ = metric.plot(values) """ return self._plot(val, ax) @@ -276,12 +285,24 @@ def plot( ..
plot:: :scale: 75 - >>> from torch import randn, randint - >>> # Example plotting a combined value across all classes + >>> # Example plotting a single value + >>> import torch >>> from torchmetrics.classification import MulticlassAUROC - >>> metric = MulticlassAUROC(num_classes=3, average="macro") - >>> metric.update(randn(20, 3), randint(3, (20,))) + >>> metric = MulticlassAUROC(num_classes=3) + >>> metric.update(torch.randn(20, 3), torch.randint(3,(20,))) >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> import torch + >>> from torchmetrics.classification import MulticlassAUROC + >>> metric = MulticlassAUROC(num_classes=3) + >>> values = [ ] + >>> for _ in range(10): + ... values.append(metric(torch.randn(20, 3), torch.randint(3, (20,)))) + >>> fig_, ax_ = metric.plot(values) """ return self._plot(val, ax) @@ -367,7 +388,7 @@ class MultilabelAUROC(MultilabelPrecisionRecallCurve): full_state_update: bool = False plot_lower_bound = 0.0 plot_upper_bound = 1.0 - plot_legend_name = "Class" + plot_legend_name = "Label" def __init__( self, @@ -411,19 +432,24 @@ def plot( .. plot:: :scale: 75 - >>> from torch import tensor + >>> # Example plotting a single value + >>> import torch >>> from torchmetrics.classification import MultilabelAUROC - >>> preds = tensor([[0.75, 0.05, 0.35], - ... [0.45, 0.75, 0.05], - ... [0.05, 0.55, 0.75], - ... [0.05, 0.65, 0.05]]) - >>> target = tensor([[1, 0, 1], - ... [0, 0, 0], - ... [0, 1, 1], - ... [1, 1, 1]]) - >>> metric = MultilabelAUROC(num_labels=3, average="macro", thresholds=None) - >>> metric.update(preds, target) + >>> metric = MultilabelAUROC(num_labels=3) + >>> metric.update(torch.rand(20,3), torch.randint(2, (20,3))) >>> fig_, ax_ = metric.plot() + + ..
plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> import torch + >>> from torchmetrics.classification import MultilabelAUROC + >>> metric = MultilabelAUROC(num_labels=3) + >>> values = [ ] + >>> for _ in range(10): + ... values.append(metric(torch.rand(20,3), torch.randint(2, (20,3)))) + >>> fig_, ax_ = metric.plot(values) """ return self._plot(val, ax) diff --git a/src/torchmetrics/classification/average_precision.py b/src/torchmetrics/classification/average_precision.py index d1a0d80e011..432c78519dd 100644 --- a/src/torchmetrics/classification/average_precision.py +++ b/src/torchmetrics/classification/average_precision.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Union +from typing import Any, List, Optional, Sequence, Union from torch import Tensor from typing_extensions import Literal @@ -31,6 +31,15 @@ from torchmetrics.metric import Metric from torchmetrics.utilities.data import dim_zero_cat from torchmetrics.utilities.enums import ClassificationTask +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = [ + "BinaryAveragePrecision.plot", + "MulticlassAveragePrecision.plot", + "MultilabelAveragePrecision.plot", + ] class BinaryAveragePrecision(BinaryPrecisionRecallCurve): @@ -97,12 +106,56 @@ class BinaryAveragePrecision(BinaryPrecisionRecallCurve): is_differentiable: bool = False higher_is_better: Optional[bool] = None full_state_update: bool = False + plot_lower_bound = 0.0 + plot_upper_bound = 1.0 def compute(self) -> Tensor: """Compute metric.""" state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] if self.thresholds is None else self.confmat return _binary_average_precision_compute(state, self.thresholds) 
+ def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: A matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> import torch + >>> from torchmetrics.classification import BinaryAveragePrecision + >>> metric = BinaryAveragePrecision() + >>> metric.update(torch.rand(20,), torch.randint(2, (20,))) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> import torch + >>> from torchmetrics.classification import BinaryAveragePrecision + >>> metric = BinaryAveragePrecision() + >>> values = [ ] + >>> for _ in range(10): + ... values.append(metric(torch.rand(20,), torch.randint(2, (20,)))) + >>> fig_, ax_ = metric.plot(values) + + """ + return self._plot(val, ax) + class MulticlassAveragePrecision(MulticlassPrecisionRecallCurve): r"""Compute the average precision (AP) score for binary tasks.
@@ -185,6 +238,9 @@ class MulticlassAveragePrecision(MulticlassPrecisionRecallCurve): is_differentiable: bool = False higher_is_better: Optional[bool] = None full_state_update: bool = False + plot_lower_bound = 0.0 + plot_upper_bound = 1.0 + plot_legend_name = "Class" def __init__( self, @@ -208,6 +264,47 @@ def compute(self) -> Tensor: state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] if self.thresholds is None else self.confmat return _multiclass_average_precision_compute(state, self.num_classes, self.average, self.thresholds) + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: A matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> import torch + >>> from torchmetrics.classification import MulticlassAveragePrecision + >>> metric = MulticlassAveragePrecision(num_classes=3) + >>> metric.update(torch.randn(20, 3), torch.randint(3,(20,))) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> import torch + >>> from torchmetrics.classification import MulticlassAveragePrecision + >>> metric = MulticlassAveragePrecision(num_classes=3) + >>> values = [ ] + >>> for _ in range(10): + ... values.append(metric(torch.randn(20, 3), torch.randint(3, (20,)))) + >>> fig_, ax_ = metric.plot(values) + """ + return self._plot(val, ax) + class MultilabelAveragePrecision(MultilabelPrecisionRecallCurve): r"""Compute the average precision (AP) score for binary tasks.
@@ -293,6 +390,9 @@ class MultilabelAveragePrecision(MultilabelPrecisionRecallCurve): is_differentiable: bool = False higher_is_better: Optional[bool] = None full_state_update: bool = False + plot_lower_bound = 0.0 + plot_upper_bound = 1.0 + plot_legend_name = "Label" def __init__( self, @@ -318,6 +418,47 @@ def compute(self) -> Tensor: state, self.num_labels, self.average, self.thresholds, self.ignore_index ) + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: A matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> import torch + >>> from torchmetrics.classification import MultilabelAveragePrecision + >>> metric = MultilabelAveragePrecision(num_labels=3) + >>> metric.update(torch.rand(20,3), torch.randint(2, (20,3))) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> import torch + >>> from torchmetrics.classification import MultilabelAveragePrecision + >>> metric = MultilabelAveragePrecision(num_labels=3) + >>> values = [ ] + >>> for _ in range(10): + ... values.append(metric(torch.rand(20,3), torch.randint(2, (20,3)))) + >>> fig_, ax_ = metric.plot(values) + """ + return self._plot(val, ax) + class AveragePrecision: r"""Compute the average precision (AP) score.
diff --git a/src/torchmetrics/classification/calibration_error.py b/src/torchmetrics/classification/calibration_error.py index 4956ae69567..a8c3ea5ca05 100644 --- a/src/torchmetrics/classification/calibration_error.py +++ b/src/torchmetrics/classification/calibration_error.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional +from typing import Any, Optional, Sequence, Union from torch import Tensor from typing_extensions import Literal @@ -30,6 +30,11 @@ from torchmetrics.metric import Metric from torchmetrics.utilities.data import dim_zero_cat from torchmetrics.utilities.enums import ClassificationTaskNoMultilabel +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["BinaryCalibrationError.plot", "MulticlassCalibrationError.plot"] class BinaryCalibrationError(Metric): @@ -94,6 +99,8 @@ class BinaryCalibrationError(Metric): is_differentiable: bool = False higher_is_better: bool = False full_state_update: bool = False + plot_lower_bound = 0.0 + plot_upper_bound = 1.0 def __init__( self, @@ -130,6 +137,47 @@ def compute(self) -> Tensor: accuracies = dim_zero_cat(self.accuracies) return _ce_compute(confidences, accuracies, self.n_bins, norm=self.norm) + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: A matplotlib axis object.
If provided will add plot to that axis + + Returns: + Figure object and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> from torch import rand, randint + >>> # Example plotting a single value + >>> from torchmetrics.classification import BinaryCalibrationError + >>> metric = BinaryCalibrationError(n_bins=2, norm='l1') + >>> metric.update(rand(10), randint(2,(10,))) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> from torch import rand, randint + >>> # Example plotting multiple values + >>> from torchmetrics.classification import BinaryCalibrationError + >>> metric = BinaryCalibrationError(n_bins=2, norm='l1') + >>> values = [ ] + >>> for _ in range(10): + ... values.append(metric(rand(10), randint(2,(10,)))) + >>> fig_, ax_ = metric.plot(values) + """ + return self._plot(val, ax) + class MulticlassCalibrationError(Metric): r"""`Top-label Calibration Error`_ for multiclass tasks. @@ -197,6 +245,8 @@ class MulticlassCalibrationError(Metric): is_differentiable: bool = False higher_is_better: bool = False full_state_update: bool = False + plot_lower_bound = 0.0 + plot_upper_bound = 1.0 def __init__( self, @@ -235,6 +285,47 @@ def compute(self) -> Tensor: accuracies = dim_zero_cat(self.accuracies) return _ce_compute(confidences, accuracies, self.n_bins, norm=self.norm) + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: A matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure object and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + ..
plot:: + :scale: 75 + + >>> from torch import randn, randint + >>> # Example plotting a single value + >>> from torchmetrics.classification import MulticlassCalibrationError + >>> metric = MulticlassCalibrationError(num_classes=3, n_bins=3, norm='l1') + >>> metric.update(randn(20,3).softmax(dim=-1), randint(3, (20,))) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> from torch import randn, randint + >>> # Example plotting multiple values + >>> from torchmetrics.classification import MulticlassCalibrationError + >>> metric = MulticlassCalibrationError(num_classes=3, n_bins=3, norm='l1') + >>> values = [] + >>> for _ in range(20): + ... values.append(metric(randn(20,3).softmax(dim=-1), randint(3, (20,)))) + >>> fig_, ax_ = metric.plot(values) + """ + return self._plot(val, ax) + class CalibrationError: r"""`Top-label Calibration Error`_. diff --git a/src/torchmetrics/classification/cohen_kappa.py b/src/torchmetrics/classification/cohen_kappa.py index bd3599973d7..40e8e482ec1 100644 --- a/src/torchmetrics/classification/cohen_kappa.py +++ b/src/torchmetrics/classification/cohen_kappa.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional +from typing import Any, Optional, Sequence, Union from torch import Tensor from typing_extensions import Literal @@ -24,6 +24,11 @@ ) from torchmetrics.metric import Metric from torchmetrics.utilities.enums import ClassificationTaskNoMultilabel +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["BinaryCohenKappa.plot", "MulticlassCohenKappa.plot"] class BinaryCohenKappa(BinaryConfusionMatrix): @@ -85,6 +90,8 @@ class labels.
is_differentiable: bool = False higher_is_better: bool = True full_state_update: bool = False + plot_lower_bound = 0.0 + plot_upper_bound = 1.0 def __init__( self, @@ -104,6 +111,47 @@ def compute(self) -> Tensor: """Compute metric.""" return _cohen_kappa_reduce(self.confmat, self.weights) + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: A matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure object and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> from torch import rand, randint + >>> # Example plotting a single value + >>> from torchmetrics.classification import BinaryCohenKappa + >>> metric = BinaryCohenKappa() + >>> metric.update(rand(10), randint(2,(10,))) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> from torch import rand, randint + >>> # Example plotting multiple values + >>> from torchmetrics.classification import BinaryCohenKappa + >>> metric = BinaryCohenKappa() + >>> values = [ ] + >>> for _ in range(10): + ... values.append(metric(rand(10), randint(2,(10,)))) + >>> fig_, ax_ = metric.plot(values) + """ + return self._plot(val, ax) + class MulticlassCohenKappa(MulticlassConfusionMatrix): r"""Calculate `Cohen's kappa score`_ that measures inter-annotator agreement for multiclass tasks. @@ -167,6 +215,8 @@ class labels.
is_differentiable: bool = False higher_is_better: bool = True full_state_update: bool = False + plot_lower_bound = 0.0 + plot_upper_bound = 1.0 def __init__( self, @@ -186,6 +236,47 @@ def compute(self) -> Tensor: """Compute metric.""" return _cohen_kappa_reduce(self.confmat, self.weights) + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: A matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure object and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> from torch import randn, randint + >>> # Example plotting a single value + >>> from torchmetrics.classification import MulticlassCohenKappa + >>> metric = MulticlassCohenKappa(num_classes=3) + >>> metric.update(randn(20,3).softmax(dim=-1), randint(3, (20,))) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> from torch import randn, randint + >>> # Example plotting multiple values + >>> from torchmetrics.classification import MulticlassCohenKappa + >>> metric = MulticlassCohenKappa(num_classes=3) + >>> values = [] + >>> for _ in range(20): + ... values.append(metric(randn(20,3).softmax(dim=-1), randint(3, (20,)))) + >>> fig_, ax_ = metric.plot(values) + """ + return self._plot(val, ax) + class CohenKappa: r"""Calculate `Cohen's kappa score`_ that measures inter-annotator agreement.
diff --git a/tests/unittests/utilities/test_plot.py b/tests/unittests/utilities/test_plot.py index 659bd9dfdfa..030a3dcaf31 100644 --- a/tests/unittests/utilities/test_plot.py +++ b/tests/unittests/utilities/test_plot.py @@ -33,11 +33,18 @@ from torchmetrics.classification import ( BinaryAccuracy, BinaryAUROC, + BinaryAveragePrecision, + BinaryCalibrationError, + BinaryCohenKappa, BinaryConfusionMatrix, BinaryROC, MulticlassAccuracy, MulticlassAUROC, + MulticlassAveragePrecision, + MulticlassCalibrationError, + MulticlassCohenKappa, MulticlassConfusionMatrix, + MultilabelAveragePrecision, MultilabelConfusionMatrix, ) from torchmetrics.functional.audio import scale_invariant_signal_noise_ratio @@ -57,6 +64,7 @@ _binary_randint_input = lambda: torch.randint(2, (10,)) _multiclass_randint_input = lambda: torch.randint(3, (10,)) _multiclass_randn_input = lambda: torch.randn(10, 3).softmax(dim=-1) +_multilabel_rand_input = lambda: torch.rand(10, 3) _multilabel_randint_input = lambda: torch.randint(2, (10, 3)) _audio_input = lambda: torch.randn(8000) _image_input = lambda: torch.rand([8, 3, 16, 16]) @@ -184,6 +192,38 @@ pytest.param(MeanMetric, _rand_input, None, id="mean metric"), pytest.param(MinMetric, _rand_input, None, id="min metric"), pytest.param(MaxMetric, _rand_input, None, id="min metric"), + pytest.param(BinaryAveragePrecision, _rand_input, _binary_randint_input, id="binary average precision"), + pytest.param( + partial(BinaryCalibrationError, n_bins=2, norm="l1"), + _rand_input, + _binary_randint_input, + id="binary calibration error", + ), + pytest.param(BinaryCohenKappa, _rand_input, _binary_randint_input, id="binary cohen kappa"), + pytest.param( + partial(MulticlassAveragePrecision, num_classes=3), + _multiclass_randn_input, + _multiclass_randint_input, + id="multiclass average precision", + ), + pytest.param( + partial(MulticlassCalibrationError, num_classes=3, n_bins=3, norm="l1"), + _multiclass_randn_input, + _multiclass_randint_input, + id="multiclass 
calibration error", + ), + pytest.param( + partial(MulticlassCohenKappa, num_classes=3), + _multiclass_randn_input, + _multiclass_randint_input, + id="multiclass cohen kappa", + ), + pytest.param( + partial(MultilabelAveragePrecision, num_labels=3), + _multilabel_rand_input, + _multilabel_randint_input, + id="multilabel average precision", + ), ], ) @pytest.mark.parametrize("num_vals", [1, 5])