Fix prefix postfix in nested collection (#1773)
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
SkafteNicki and mergify[bot] authored May 12, 2023
1 parent 3ea6062 commit 17c0e9f
Showing 4 changed files with 70 additions and 9 deletions.
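In short: before this patch, `prefix` and `postfix` set on a nested `MetricCollection` were dropped when the inner collection's results were flattened into its parent. A minimal sketch of the fixed behavior, assuming a torchmetrics build that includes this patch (the metric choices are illustrative, mirroring the existing nested-collection tests):

import torch
from torchmetrics import MetricCollection
from torchmetrics.classification import MulticlassAccuracy, MulticlassPrecision

# Inner collection carries its own prefix; the outer one adds another.
inner = MetricCollection(
    [MulticlassAccuracy(num_classes=3), MulticlassPrecision(num_classes=3)],
    prefix="inner_",
)
outer = MetricCollection([inner], prefix="outer_")

preds = torch.randn(10, 3).softmax(dim=-1)
target = torch.randint(3, (10,))
out = outer(preds, target)

# With the fix, result keys carry the affixes from both nesting levels:
assert "outer_inner_MulticlassAccuracy" in out
assert "outer_inner_MulticlassPrecision" in out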
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -183,6 +183,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Fixed states not being correctly synced and device transferred in `MeanAveragePrecision` for `iou_type="segm"` ([#1763](https://github.com/Lightning-AI/torchmetrics/pull/1763))

+
+- Fixed use of `prefix` and `postfix` in nested `MetricCollection` ([#1773](https://github.com/Lightning-AI/torchmetrics/pull/1773))
+
## [0.11.4] - 2023-03-10

### Fixed
56 changes: 48 additions & 8 deletions src/torchmetrics/collections.py
@@ -19,10 +19,11 @@
import torch
from torch import Tensor
from torch.nn import Module, ModuleDict
+from typing_extensions import Literal

from torchmetrics.metric import Metric
from torchmetrics.utilities import rank_zero_warn
-from torchmetrics.utilities.data import _flatten_dict, allclose
+from torchmetrics.utilities.data import allclose
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE, plot_single_or_multi_val

@@ -176,9 +177,7 @@ def forward(self, *args: Any, **kwargs: Any) -> Dict[str, Any]:
Positional arguments (args) will be passed to every metric in the collection, while keyword arguments (kwargs)
will be filtered based on the signature of the individual metric.
"""
-res = {k: m(*args, **m._filter_kwargs(**kwargs)) for k, m in self.items(keep_base=True, copy_state=False)}
-res = _flatten_dict(res)
-return {self._set_name(k): v for k, v in res.items()}
+return self._compute_and_reduce("forward", *args, **kwargs)

def update(self, *args: Any, **kwargs: Any) -> None:
"""Call update for each metric sequentially.
Expand Down Expand Up @@ -288,9 +287,43 @@ def _compute_groups_create_state_ref(self, copy: bool = False) -> None:

def compute(self) -> Dict[str, Any]:
"""Compute the result for each metric in the collection."""
-res = {k: m.compute() for k, m in self.items(keep_base=True, copy_state=False)}
-res = _flatten_dict(res)
-return {self._set_name(k): v for k, v in res.items()}
+return self._compute_and_reduce("compute")
+
+def _compute_and_reduce(
+    self, method_name: Literal["compute", "forward"], *args: Any, **kwargs: Any
+) -> Dict[str, Any]:
+    """Compute result from collection and reduce into a single dictionary.
+
+    Args:
+        method_name: The method to call on each metric in the collection.
+            Should be either `compute` or `forward`.
+        args: Positional arguments to pass to each metric (if method_name is `forward`)
+        kwargs: Keyword arguments to pass to each metric (if method_name is `forward`)
+
+    Raises:
+        ValueError:
+            If method_name is not `compute` or `forward`.
+    """
+    result = {}
+    for k, m in self.items(keep_base=True, copy_state=False):
+        if method_name == "compute":
+            res = m.compute()
+        elif method_name == "forward":
+            res = m(*args, **m._filter_kwargs(**kwargs))
+        else:
+            raise ValueError(f"method_name should be either 'compute' or 'forward', but got {method_name}")
+
+        if isinstance(res, dict):
+            for key, v in res.items():
+                if hasattr(m, "prefix") and m.prefix is not None:
+                    key = f"{m.prefix}{key}"
+                if hasattr(m, "postfix") and m.postfix is not None:
+                    key = f"{key}{m.postfix}"
+                result[key] = v
+        else:
+            result[k] = res
+    return {self._set_name(k): v for k, v in result.items()}
+
def reset(self) -> None:
"""Call reset for each metric sequentially."""
@@ -358,6 +391,8 @@ def add_metrics(
self[name] = metric
else:
for k, v in metric.items(keep_base=False):
+v.postfix = metric.postfix
+v.prefix = metric.prefix
self[f"{name}_{k}"] = v
elif isinstance(metrics, Sequence):
for metric in metrics:
@@ -373,9 +408,14 @@
self[name] = metric
else:
for k, v in metric.items(keep_base=False):
+v.postfix = metric.postfix
+v.prefix = metric.prefix
self[k] = v
else:
-raise ValueError("Unknown input to MetricCollection.")
+raise ValueError(
+    "Unknown input to MetricCollection. Expected `Metric`, `MetricCollection` or `dict`/`sequence` of the"
+    f" previous, but got {metrics}"
+)

self._groups_checked = False
if self._enable_compute_groups:
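Context for the `add_metrics` changes above: when a nested `MetricCollection` is flattened into its parent, each child metric now also inherits the inner collection's `prefix`/`postfix` attributes, which `_compute_and_reduce` later applies to dict-valued results; previously the inner affixes were simply lost. A small sketch of the construction-time flattening, assuming a patched build (the metric choices are illustrative):

import torch
from torchmetrics import MeanAbsoluteError, MeanSquaredError, MetricCollection

inner = MetricCollection([MeanSquaredError(), MeanAbsoluteError()], prefix="reg_")
outer = MetricCollection([inner])

# The nested collection is flattened at construction time: the parent's keys
# already carry the inner prefix, and each child metric remembers that prefix
# so dict-valued results can be reduced correctly later.
assert sorted(outer.keys()) == ["reg_MeanAbsoluteError", "reg_MeanSquaredError"]
assert outer["reg_MeanSquaredError"].prefix == "reg_"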
12 changes: 11 additions & 1 deletion tests/unittests/bases/test_collections.py
@@ -36,7 +36,7 @@
)
from torchmetrics.utilities.checks import _allclose_recursive
from unittests.helpers import seed_all
-from unittests.helpers.testers import DummyMetricDiff, DummyMetricSum
+from unittests.helpers.testers import DummyMetricDiff, DummyMetricMultiOutputDict, DummyMetricSum

seed_all(42)

@@ -618,3 +618,13 @@ def test_nested_collections(input_collections):
assert "valmetrics/macro_MulticlassPrecision" in val
assert "valmetrics/micro_MulticlassAccuracy" in val
assert "valmetrics/micro_MulticlassPrecision" in val
+
+
+def test_double_nested_collections():
+    """Test that double-nested collections get flattened into a single collection."""
+    collection1 = MetricCollection([DummyMetricMultiOutputDict()], prefix="prefix1_", postfix="_postfix1")
+    collection2 = MetricCollection([collection1], prefix="prefix2_", postfix="_postfix2")
+    x = torch.randn(10).sum()
+    val = collection2(x)
+    assert "prefix2_prefix1_output1_postfix1_postfix2" in val
+    assert "prefix2_prefix1_output2_postfix1_postfix2" in val
8 changes: 8 additions & 0 deletions tests/unittests/helpers/testers.py
@@ -613,6 +613,14 @@ def compute(self):
return [self.x, self.x]


+class DummyMetricMultiOutputDict(DummyMetricSum):
+    """DummyMetricMultiOutputDict for testing core components."""
+
+    def compute(self):
+        """Compute value."""
+        return {"output1": self.x, "output2": self.x}
+
+
def inject_ignore_index(x: Tensor, ignore_index: int) -> Tensor:
"""Injecting the ignored index value into a tensor randomly."""
if any(x.flatten() == ignore_index): # ignore index is a class label