Skip to content

Commit

Permalink
Add plotting 10/n (#1610)
Browse files Browse the repository at this point in the history
* more plotting
* changelgo
* imports

---------

Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
Co-authored-by: Jirka <jirka.borovec@seznam.cz>
  • Loading branch information
3 people authored Mar 13, 2023
1 parent 980b32f commit 7a2e88a
Show file tree
Hide file tree
Showing 6 changed files with 221 additions and 2 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
[#1593](https://github.com/Lightning-AI/metrics/pull/1593),
[#1600](https://github.com/Lightning-AI/metrics/pull/1600),
[#1605](https://github.com/Lightning-AI/metrics/pull/1605),
[#1610](https://github.com/Lightning-AI/metrics/pull/1610),
[#1609](https://github.com/Lightning-AI/metrics/pull/1609),
)

Expand Down
48 changes: 47 additions & 1 deletion src/torchmetrics/retrieval/precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional
from typing import Any, Optional, Sequence, Union

from torch import Tensor

from torchmetrics.functional.retrieval.precision import retrieval_precision
from torchmetrics.retrieval.base import RetrievalMetric
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

if not _MATPLOTLIB_AVAILABLE:
__doctest_skip__ = ["RetrievalPrecision.plot"]


class RetrievalPrecision(RetrievalMetric):
Expand Down Expand Up @@ -103,3 +108,44 @@ def __init__(

def _metric(self, preds: Tensor, target: Tensor) -> Tensor:
return retrieval_precision(preds, target, top_k=self.top_k, adaptive_k=self.adaptive_k)

def plot(
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.
Args:
val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
If no value is provided, will automatically call `metric.compute` and plot that result.
ax: An matplotlib axis object. If provided will add plot to that axis
Returns:
Figure and Axes object
Raises:
ModuleNotFoundError:
If `matplotlib` is not installed
.. plot::
:scale: 75
>>> import torch
>>> from torchmetrics.retrieval import RetrievalPrecision
>>> # Example plotting a single value
>>> metric = RetrievalPrecision()
>>> metric.update(torch.rand(10,), torch.randint(2, (10,)), indexes=torch.randint(2,(10,)))
>>> fig_, ax_ = metric.plot()
.. plot::
:scale: 75
>>> import torch
>>> from torchmetrics.retrieval import RetrievalPrecision
>>> # Example plotting multiple values
>>> metric = RetrievalPrecision()
>>> values = []
>>> for _ in range(10):
... values.append(metric(torch.rand(10,), torch.randint(2, (10,)), indexes=torch.randint(2,(10,))))
>>> fig, ax = metric.plot(values)
"""
return self._plot(val, ax)
48 changes: 48 additions & 0 deletions src/torchmetrics/retrieval/r_precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Sequence, Union

from torch import Tensor

from torchmetrics.functional.retrieval.r_precision import retrieval_r_precision
from torchmetrics.retrieval.base import RetrievalMetric
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

if not _MATPLOTLIB_AVAILABLE:
__doctest_skip__ = ["RetrievalRPrecision.plot"]


class RetrievalRPrecision(RetrievalMetric):
Expand Down Expand Up @@ -74,3 +81,44 @@ class RetrievalRPrecision(RetrievalMetric):

def _metric(self, preds: Tensor, target: Tensor) -> Tensor:
return retrieval_r_precision(preds, target)

def plot(
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.
Args:
val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
If no value is provided, will automatically call `metric.compute` and plot that result.
ax: An matplotlib axis object. If provided will add plot to that axis
Returns:
Figure and Axes object
Raises:
ModuleNotFoundError:
If `matplotlib` is not installed
.. plot::
:scale: 75
>>> import torch
>>> from torchmetrics.retrieval import RetrievalRPrecision
>>> # Example plotting a single value
>>> metric = RetrievalRPrecision()
>>> metric.update(torch.rand(10,), torch.randint(2, (10,)), indexes=torch.randint(2,(10,)))
>>> fig_, ax_ = metric.plot()
.. plot::
:scale: 75
>>> import torch
>>> from torchmetrics.retrieval import RetrievalRPrecision
>>> # Example plotting multiple values
>>> metric = RetrievalRPrecision()
>>> values = []
>>> for _ in range(10):
... values.append(metric(torch.rand(10,), torch.randint(2, (10,)), indexes=torch.randint(2,(10,))))
>>> fig, ax = metric.plot(values)
"""
return self._plot(val, ax)
48 changes: 47 additions & 1 deletion src/torchmetrics/retrieval/recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional
from typing import Any, Optional, Sequence, Union

from torch import Tensor

from torchmetrics.functional.retrieval.recall import retrieval_recall
from torchmetrics.retrieval.base import RetrievalMetric
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

if not _MATPLOTLIB_AVAILABLE:
__doctest_skip__ = ["RetrievalRecall.plot"]


class RetrievalRecall(RetrievalMetric):
Expand Down Expand Up @@ -95,3 +100,44 @@ def __init__(

def _metric(self, preds: Tensor, target: Tensor) -> Tensor:
return retrieval_recall(preds, target, top_k=self.top_k)

def plot(
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.
Args:
val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
If no value is provided, will automatically call `metric.compute` and plot that result.
ax: An matplotlib axis object. If provided will add plot to that axis
Returns:
Figure and Axes object
Raises:
ModuleNotFoundError:
If `matplotlib` is not installed
.. plot::
:scale: 75
>>> import torch
>>> from torchmetrics.retrieval import RetrievalRecall
>>> # Example plotting a single value
>>> metric = RetrievalRecall()
>>> metric.update(torch.rand(10,), torch.randint(2, (10,)), indexes=torch.randint(2,(10,)))
>>> fig_, ax_ = metric.plot()
.. plot::
:scale: 75
>>> import torch
>>> from torchmetrics.retrieval import RetrievalRecall
>>> # Example plotting multiple values
>>> metric = RetrievalRecall()
>>> values = []
>>> for _ in range(10):
... values.append(metric(torch.rand(10,), torch.randint(2, (10,)), indexes=torch.randint(2,(10,))))
>>> fig, ax = metric.plot(values)
"""
return self._plot(val, ax)
48 changes: 48 additions & 0 deletions src/torchmetrics/retrieval/reciprocal_rank.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Sequence, Union

from torch import Tensor

from torchmetrics.functional.retrieval.reciprocal_rank import retrieval_reciprocal_rank
from torchmetrics.retrieval.base import RetrievalMetric
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

if not _MATPLOTLIB_AVAILABLE:
__doctest_skip__ = ["RetrievalMRR.plot"]


class RetrievalMRR(RetrievalMetric):
Expand Down Expand Up @@ -73,3 +80,44 @@ class RetrievalMRR(RetrievalMetric):

def _metric(self, preds: Tensor, target: Tensor) -> Tensor:
return retrieval_reciprocal_rank(preds, target)

def plot(
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.
Args:
val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
If no value is provided, will automatically call `metric.compute` and plot that result.
ax: An matplotlib axis object. If provided will add plot to that axis
Returns:
Figure and Axes object
Raises:
ModuleNotFoundError:
If `matplotlib` is not installed
.. plot::
:scale: 75
>>> import torch
>>> from torchmetrics.retrieval import RetrievalMRR
>>> # Example plotting a single value
>>> metric = RetrievalMRR()
>>> metric.update(torch.rand(10,), torch.randint(2, (10,)), indexes=torch.randint(2,(10,)))
>>> fig_, ax_ = metric.plot()
.. plot::
:scale: 75
>>> import torch
>>> from torchmetrics.retrieval import RetrievalMRR
>>> # Example plotting multiple values
>>> metric = RetrievalMRR()
>>> values = []
>>> for _ in range(10):
... values.append(metric(torch.rand(10,), torch.randint(2, (10,)), indexes=torch.randint(2,(10,))))
>>> fig, ax = metric.plot(values)
"""
return self._plot(val, ax)
30 changes: 30 additions & 0 deletions tests/unittests/utilities/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@
MeanSquaredLogError,
MinkowskiDistance,
)
from torchmetrics.retrieval import RetrievalMRR, RetrievalPrecision, RetrievalRecall, RetrievalRPrecision

_rand_input = lambda: torch.rand(10)
_binary_randint_input = lambda: torch.randint(2, (10,))
Expand Down Expand Up @@ -371,6 +372,35 @@ def test_plot_methods_special_image_metrics(metric_class, preds, target, index_0
assert isinstance(ax, matplotlib.axes.Axes)


@pytest.mark.parametrize(
("metric_class", "preds", "target", "indexes"),
[
pytest.param(RetrievalMRR, _rand_input, _binary_randint_input, _binary_randint_input, id="retrieval mrr"),
pytest.param(RetrievalPrecision, _rand_input, _binary_randint_input, _binary_randint_input, id="retrieval mrr"),
pytest.param(
RetrievalRPrecision, _rand_input, _binary_randint_input, _binary_randint_input, id="retrieval mrr"
),
pytest.param(RetrievalRecall, _rand_input, _binary_randint_input, _binary_randint_input, id="retrieval mrr"),
],
)
@pytest.mark.parametrize("num_vals", [1, 2])
def test_plot_methods_retrieval(metric_class, preds, target, indexes, num_vals):
"""Test the plot method for retrieval metrics by themselves, since retrieval metrics requires an extra argument."""
metric = metric_class()

if num_vals == 1:
metric.update(preds(), target(), indexes=indexes())
fig, ax = metric.plot()
else:
vals = []
for _ in range(num_vals):
vals.append(metric(preds(), target(), indexes=indexes()))
fig, ax = metric.plot(vals)

assert isinstance(fig, plt.Figure)
assert isinstance(ax, matplotlib.axes.Axes)


@pytest.mark.parametrize(
("metric_class", "preds", "target", "labels"),
[
Expand Down

0 comments on commit 7a2e88a

Please sign in to comment.