Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix MetricCollection with repeated compute calls #2211

Merged
merged 9 commits into from
Nov 28, 2023
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed bug in `Metric._reduce_states(...)` when using `dist_sync_fn="cat"` ([#2226](https://github.com/Lightning-AI/torchmetrics/pull/2226))


- Fixed bug in `MetricCollection` when using compute groups and `compute` is called more than once ([#2211](https://github.com/Lightning-AI/torchmetrics/pull/2211))

Borda marked this conversation as resolved.
Show resolved Hide resolved
## [1.2.0] - 2023-09-22

### Added
Expand Down
1 change: 1 addition & 0 deletions src/torchmetrics/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,7 @@ def _compute_groups_create_state_ref(self, copy: bool = False) -> None:
# Determine if we just should set a reference or a full copy
setattr(mi, state, deepcopy(m0_state) if copy else m0_state)
mi._update_count = deepcopy(m0._update_count) if copy else m0._update_count
mi._computed = deepcopy(m0._computed) if copy else m0._computed
self._state_is_copy = copy

def compute(self) -> Dict[str, Any]:
Expand Down
8 changes: 5 additions & 3 deletions tests/unittests/bases/test_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,8 @@ class TestComputeGroups:
("prefix_", "_postfix"),
],
)
def test_check_compute_groups_correctness(self, metrics, expected, preds, target, prefix, postfix):
@pytest.mark.parametrize("with_reset", [True, False])
def test_check_compute_groups_correctness(self, metrics, expected, preds, target, prefix, postfix, with_reset):
"""Check that compute groups are formed after initialization and that metrics are correctly computed."""
if isinstance(metrics, MetricCollection):
prefix, postfix = None, None # disable for nested collections
Expand Down Expand Up @@ -445,8 +446,9 @@ def test_check_compute_groups_correctness(self, metrics, expected, preds, target
for key in res_cg:
assert torch.allclose(res_cg[key], res_without_cg[key])

m.reset()
m2.reset()
if with_reset:
m.reset()
m2.reset()

@pytest.mark.parametrize("method", ["items", "values", "keys"])
def test_check_compute_groups_items_and_values(self, metrics, expected, preds, target, method):
Expand Down
Loading