Add plotting 7/n (#1593)
* more plotting

* changelog
SkafteNicki authored Mar 6, 2023
1 parent 163cb61 commit fdececb
Showing 6 changed files with 415 additions and 25 deletions.
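For orientation before the per-file diffs: this PR continues rolling out the `.plot()` API across the classification metrics. A minimal sketch of how the new API is used, assuming a torchmetrics build that includes this commit and matplotlib installed (the same pattern the doctests below follow):

# Sketch of the plotting API this PR extends (illustrative, not part of the diff).
import torch
from torchmetrics.classification import BinaryAUROC

metric = BinaryAUROC()

# Plot a single value: accumulate state, then plot the computed result.
for _ in range(5):
    metric.update(torch.rand(20), torch.randint(2, (20,)))
fig, ax = metric.plot()  # with no argument, plot() calls compute() internally

# Plot multiple values: forward() returns the per-call value; collect and plot.
metric.reset()
values = [metric(torch.rand(20), torch.randint(2, (20,))) for _ in range(10)]
fig, ax = metric.plot(values)
fig.savefig("auroc.png")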
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -23,6 +23,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
    [#1480](https://github.com/Lightning-AI/metrics/pull/1480),
    [#1490](https://github.com/Lightning-AI/metrics/pull/1490),
    [#1581](https://github.com/Lightning-AI/metrics/pull/1581),
+   [#1593](https://github.com/Lightning-AI/metrics/pull/1593),
)


70 changes: 48 additions & 22 deletions src/torchmetrics/classification/auroc.py
@@ -140,15 +140,24 @@ def plot(
        .. plot::
            :scale: 75

-            >>> from torch import randn, randint
-            >>> import torch.nn.functional as F
-            >>> # Example plotting a combined value across all classes
+            >>> # Example plotting a single value
+            >>> import torch
            >>> from torchmetrics.classification import BinaryAUROC
-            >>> preds = F.softmax(randn(20, 2), dim=1)
-            >>> target = randint(2, (20,))
            >>> metric = BinaryAUROC()
-            >>> metric.update(preds[:, 1], target)
+            >>> metric.update(torch.rand(20,), torch.randint(2, (20,)))
            >>> fig_, ax_ = metric.plot()

+        .. plot::
+            :scale: 75
+
+            >>> # Example plotting multiple values
+            >>> import torch
+            >>> from torchmetrics.classification import BinaryAUROC
+            >>> metric = BinaryAUROC()
+            >>> values = []
+            >>> for _ in range(10):
+            ...     values.append(metric(torch.rand(20,), torch.randint(2, (20,))))
+            >>> fig_, ax_ = metric.plot(values)
        """
        return self._plot(val, ax)
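The `ax` parameter in these new `plot` signatures makes the methods composable with ordinary matplotlib figures. A hedged sketch of that use (layout and names are illustrative, not from this diff):

# Sketch: pass an existing matplotlib axis so several metrics share one figure.
import matplotlib.pyplot as plt
import torch
from torchmetrics.classification import BinaryAUROC, MulticlassAUROC

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(9, 3))
binary, multi = BinaryAUROC(), MulticlassAUROC(num_classes=3)
binary.update(torch.rand(20), torch.randint(2, (20,)))
multi.update(torch.randn(20, 3), torch.randint(3, (20,)))
binary.plot(ax=ax1)  # draws onto the provided axis instead of opening a new figure
multi.plot(ax=ax2)
fig.tight_layout()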

@@ -276,12 +285,24 @@ def plot(
        .. plot::
            :scale: 75

-            >>> from torch import randn, randint
-            >>> # Example plotting a combined value across all classes
+            >>> # Example plotting a single value
+            >>> import torch
            >>> from torchmetrics.classification import MulticlassAUROC
-            >>> metric = MulticlassAUROC(num_classes=3, average="macro")
-            >>> metric.update(randn(20, 3), randint(3, (20,)))
+            >>> metric = MulticlassAUROC(num_classes=3)
+            >>> metric.update(torch.randn(20, 3), torch.randint(3, (20,)))
            >>> fig_, ax_ = metric.plot()

+        .. plot::
+            :scale: 75
+
+            >>> # Example plotting multiple values
+            >>> import torch
+            >>> from torchmetrics.classification import MulticlassAUROC
+            >>> metric = MulticlassAUROC(num_classes=3)
+            >>> values = []
+            >>> for _ in range(10):
+            ...     values.append(metric(torch.randn(20, 3), torch.randint(3, (20,))))
+            >>> fig_, ax_ = metric.plot(values)
        """
        return self._plot(val, ax)

@@ -367,7 +388,7 @@ class MultilabelAUROC(MultilabelPrecisionRecallCurve):
    full_state_update: bool = False
    plot_lower_bound = 0.0
    plot_upper_bound = 1.0
-    plot_legend_name = "Class"
+    plot_legend_name = "Label"

def __init__(
self,
@@ -411,19 +432,24 @@ def plot(
        .. plot::
            :scale: 75

-            >>> from torch import tensor
+            >>> # Example plotting a single value
+            >>> import torch
            >>> from torchmetrics.classification import MultilabelAUROC
-            >>> preds = tensor([[0.75, 0.05, 0.35],
-            ...                 [0.45, 0.75, 0.05],
-            ...                 [0.05, 0.55, 0.75],
-            ...                 [0.05, 0.65, 0.05]])
-            >>> target = tensor([[1, 0, 1],
-            ...                  [0, 0, 0],
-            ...                  [0, 1, 1],
-            ...                  [1, 1, 1]])
-            >>> metric = MultilabelAUROC(num_labels=3, average="macro", thresholds=None)
-            >>> metric.update(preds, target)
+            >>> metric = MultilabelAUROC(num_labels=3)
+            >>> metric.update(torch.rand(20, 3), torch.randint(2, (20, 3)))
            >>> fig_, ax_ = metric.plot()

+        .. plot::
+            :scale: 75
+
+            >>> # Example plotting multiple values
+            >>> import torch
+            >>> from torchmetrics.classification import MultilabelAUROC
+            >>> metric = MultilabelAUROC(num_labels=3)
+            >>> values = []
+            >>> for _ in range(10):
+            ...     values.append(metric(torch.rand(20, 3), torch.randint(2, (20, 3))))
+            >>> fig_, ax_ = metric.plot(values)
        """
        return self._plot(val, ax)

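One side effect of the doctest rewrite in this file: the old binary example built probabilities by softmaxing two logit columns, while the new ones draw scores from torch.rand directly. If a model emits logits, an explicit sigmoid keeps the inputs in the [0, 1] range the binary metrics expect (torchmetrics also auto-applies a sigmoid when it sees out-of-range floats); a brief sketch:

# Sketch: logits vs. probabilities for the binary metrics shown above.
import torch
from torchmetrics.classification import BinaryAUROC

metric = BinaryAUROC()
logits = torch.randn(20)        # raw, unbounded model outputs
probs = torch.sigmoid(logits)   # explicit squash into [0, 1]
target = torch.randint(2, (20,))
metric.update(probs, target)
fig, ax = metric.plot()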
143 changes: 142 additions & 1 deletion src/torchmetrics/classification/average_precision.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, List, Optional, Union
+from typing import Any, List, Optional, Sequence, Union

from torch import Tensor
from typing_extensions import Literal
@@ -31,6 +31,15 @@
from torchmetrics.metric import Metric
from torchmetrics.utilities.data import dim_zero_cat
from torchmetrics.utilities.enums import ClassificationTask
+from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
+from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
+
+if not _MATPLOTLIB_AVAILABLE:
+    __doctest_skip__ = [
+        "BinaryAveragePrecision.plot",
+        "MulticlassAveragePrecision.plot",
+        "MultilabelAveragePrecision.plot",
+    ]
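The `_MATPLOTLIB_AVAILABLE` guard above registers the three new `plot` docstrings for skipping, so the doctest suite does not fail when the optional matplotlib dependency is missing. A self-contained sketch of the same pattern (the try/except here is a hypothetical stand-in for `torchmetrics.utilities.imports`):

# Sketch of the optional-dependency doctest guard used above.
try:
    import matplotlib  # noqa: F401
    _MATPLOTLIB_AVAILABLE = True
except ImportError:
    _MATPLOTLIB_AVAILABLE = False

if not _MATPLOTLIB_AVAILABLE:
    # Picked up by doctest runners such as pytest-doctestplus: the listed
    # docstrings are skipped rather than erroring on the missing import.
    __doctest_skip__ = ["BinaryAveragePrecision.plot"]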


class BinaryAveragePrecision(BinaryPrecisionRecallCurve):
@@ -97,12 +106,56 @@ class BinaryAveragePrecision(BinaryPrecisionRecallCurve):
    is_differentiable: bool = False
    higher_is_better: Optional[bool] = None
    full_state_update: bool = False
+    plot_lower_bound = 0.0
+    plot_upper_bound = 1.0

    def compute(self) -> Tensor:
        """Compute metric."""
        state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] if self.thresholds is None else self.confmat
        return _binary_average_precision_compute(state, self.thresholds)

+    def plot(
+        self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
+    ) -> _PLOT_OUT_TYPE:
+        """Plot a single or multiple values from the metric.
+
+        Args:
+            val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
+                If no value is provided, will automatically call `metric.compute` and plot that result.
+            ax: A matplotlib axis object. If provided, the plot will be added to that axis.
+
+        Returns:
+            Figure and Axes object
+
+        Raises:
+            ModuleNotFoundError:
+                If `matplotlib` is not installed
+
+        .. plot::
+            :scale: 75
+
+            >>> # Example plotting a single value
+            >>> import torch
+            >>> from torchmetrics.classification import BinaryAveragePrecision
+            >>> metric = BinaryAveragePrecision()
+            >>> metric.update(torch.rand(20,), torch.randint(2, (20,)))
+            >>> fig_, ax_ = metric.plot()
+
+        .. plot::
+            :scale: 75
+
+            >>> # Example plotting multiple values
+            >>> import torch
+            >>> from torchmetrics.classification import BinaryAveragePrecision
+            >>> metric = BinaryAveragePrecision()
+            >>> values = []
+            >>> for _ in range(10):
+            ...     values.append(metric(torch.rand(20,), torch.randint(2, (20,))))
+            >>> fig_, ax_ = metric.plot(values)
+        """
+        return self._plot(val, ax)


class MulticlassAveragePrecision(MulticlassPrecisionRecallCurve):
r"""Compute the average precision (AP) score for binary tasks.
@@ -185,6 +238,9 @@ class MulticlassAveragePrecision(MulticlassPrecisionRecallCurve):
    is_differentiable: bool = False
    higher_is_better: Optional[bool] = None
    full_state_update: bool = False
+    plot_lower_bound = 0.0
+    plot_upper_bound = 1.0
+    plot_legend_name = "Class"

def __init__(
self,
@@ -208,6 +264,47 @@ def compute(self) -> Tensor:
state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] if self.thresholds is None else self.confmat
return _multiclass_average_precision_compute(state, self.num_classes, self.average, self.thresholds)

+    def plot(
+        self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
+    ) -> _PLOT_OUT_TYPE:
+        """Plot a single or multiple values from the metric.
+
+        Args:
+            val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
+                If no value is provided, will automatically call `metric.compute` and plot that result.
+            ax: A matplotlib axis object. If provided, the plot will be added to that axis.
+
+        Returns:
+            Figure and Axes object
+
+        Raises:
+            ModuleNotFoundError:
+                If `matplotlib` is not installed
+
+        .. plot::
+            :scale: 75
+
+            >>> # Example plotting a single value
+            >>> import torch
+            >>> from torchmetrics.classification import MulticlassAveragePrecision
+            >>> metric = MulticlassAveragePrecision(num_classes=3)
+            >>> metric.update(torch.randn(20, 3), torch.randint(3, (20,)))
+            >>> fig_, ax_ = metric.plot()
+
+        .. plot::
+            :scale: 75
+
+            >>> # Example plotting multiple values
+            >>> import torch
+            >>> from torchmetrics.classification import MulticlassAveragePrecision
+            >>> metric = MulticlassAveragePrecision(num_classes=3)
+            >>> values = []
+            >>> for _ in range(10):
+            ...     values.append(metric(torch.randn(20, 3), torch.randint(3, (20,))))
+            >>> fig_, ax_ = metric.plot(values)
+        """
+        return self._plot(val, ax)


class MultilabelAveragePrecision(MultilabelPrecisionRecallCurve):
r"""Compute the average precision (AP) score for binary tasks.
@@ -293,6 +390,9 @@ class MultilabelAveragePrecision(MultilabelPrecisionRecallCurve):
    is_differentiable: bool = False
    higher_is_better: Optional[bool] = None
    full_state_update: bool = False
+    plot_lower_bound = 0.0
+    plot_upper_bound = 1.0
+    plot_legend_name = "Label"

def __init__(
self,
Expand All @@ -318,6 +418,47 @@ def compute(self) -> Tensor:
state, self.num_labels, self.average, self.thresholds, self.ignore_index
)

+    def plot(
+        self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
+    ) -> _PLOT_OUT_TYPE:
+        """Plot a single or multiple values from the metric.
+
+        Args:
+            val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
+                If no value is provided, will automatically call `metric.compute` and plot that result.
+            ax: A matplotlib axis object. If provided, the plot will be added to that axis.
+
+        Returns:
+            Figure and Axes object
+
+        Raises:
+            ModuleNotFoundError:
+                If `matplotlib` is not installed
+
+        .. plot::
+            :scale: 75
+
+            >>> # Example plotting a single value
+            >>> import torch
+            >>> from torchmetrics.classification import MultilabelAveragePrecision
+            >>> metric = MultilabelAveragePrecision(num_labels=3)
+            >>> metric.update(torch.rand(20, 3), torch.randint(2, (20, 3)))
+            >>> fig_, ax_ = metric.plot()
+
+        .. plot::
+            :scale: 75
+
+            >>> # Example plotting multiple values
+            >>> import torch
+            >>> from torchmetrics.classification import MultilabelAveragePrecision
+            >>> metric = MultilabelAveragePrecision(num_labels=3)
+            >>> values = []
+            >>> for _ in range(10):
+            ...     values.append(metric(torch.rand(20, 3), torch.randint(2, (20, 3))))
+            >>> fig_, ax_ = metric.plot(values)
+        """
+        return self._plot(val, ax)


class AveragePrecision:
r"""Compute the average precision (AP) score.
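The class attributes added throughout this file (`plot_lower_bound`, `plot_upper_bound`, `plot_legend_name`) appear to parameterize the shared `Metric._plot` helper: the bounds pin the y-axis to the metric's valid range, and the legend name titles the per-class or per-label curves. A hedged sketch of a custom metric reusing those hooks (attribute semantics inferred from this diff; `_plot` is internal API and may change):

# Sketch: a custom metric opting into the same plotting hooks (illustrative).
import torch
from torchmetrics import Metric

class FractionPositive(Metric):
    plot_lower_bound = 0.0  # y-axis floor used when plotting this metric
    plot_upper_bound = 1.0  # y-axis ceiling; the metric lives in [0, 1]

    def __init__(self) -> None:
        super().__init__()
        self.add_state("positives", torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total", torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, target: torch.Tensor) -> None:
        self.positives += target.sum()
        self.total += target.numel()

    def compute(self) -> torch.Tensor:
        return self.positives / self.total

    def plot(self, val=None, ax=None):
        # Delegate to the shared helper, as the metrics in this PR do.
        return self._plot(val, ax)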
