From 59147edca8bc2846c661cf1a8cb3d4652626106a Mon Sep 17 00:00:00 2001 From: Su YR Date: Sat, 17 Feb 2024 03:17:56 +0800 Subject: [PATCH 1/5] fix: MetricCollection did not copy inner state of metric in ClasswiseWrapper when computing groups metrics Issue Link: https://github.com/Lightning-AI/torchmetrics/issues/2389 --- src/torchmetrics/wrappers/classwise.py | 21 +++++++++++++ tests/unittests/bases/test_collections.py | 37 ++++++++++++++++++++++- 2 files changed, 57 insertions(+), 1 deletion(-) diff --git a/src/torchmetrics/wrappers/classwise.py b/src/torchmetrics/wrappers/classwise.py index 3c8d6621bc2..559e04152d7 100644 --- a/src/torchmetrics/wrappers/classwise.py +++ b/src/torchmetrics/wrappers/classwise.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import typing from typing import Any, Dict, List, Optional, Sequence, Union from torch import Tensor @@ -20,6 +21,9 @@ from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE from torchmetrics.wrappers.abstract import WrapperMetric +if typing.TYPE_CHECKING: + from torch.nn import Module + if not _MATPLOTLIB_AVAILABLE: __doctest_skip__ = ["ClasswiseWrapper.plot"] @@ -209,3 +213,20 @@ def plot( """ return self._plot(val, ax) + + def __getattr__(self, name: str) -> Union[Tensor, "Module"]: + """Get attribute from classwise wrapper.""" + # return state from self.metric + if name in ["tp", "fp", "fn", "tn"]: + return getattr(self.metric, name) + + return super().__getattr__(name) + + def __setattr__(self, name: str, value: Any) -> None: + """Set attribute to classwise wrapper.""" + super().__setattr__(name, value) + if name == "metric": + self._defaults = self.metric._defaults + if hasattr(self, "metric") and name in ["tp", "fp", "fn", "tn", "_update_count", "_computed"]: + # update ``_update_count`` and ``_computed`` of internal metric to prevent warning. + setattr(self.metric, name, value) diff --git a/tests/unittests/bases/test_collections.py b/tests/unittests/bases/test_collections.py index 9e4ac4a5897..3c34cb965ad 100644 --- a/tests/unittests/bases/test_collections.py +++ b/tests/unittests/bases/test_collections.py @@ -17,7 +17,7 @@ import pytest import torch -from torchmetrics import Metric, MetricCollection +from torchmetrics import ClasswiseWrapper, Metric, MetricCollection from torchmetrics.classification import ( BinaryAccuracy, MulticlassAccuracy, @@ -540,6 +540,41 @@ def test_compute_group_define_by_user(): assert m.compute() +def test_classwise_wrapper_compute_group(): + """Check that user can provide compute groups.""" + classwise_accuracy = ClasswiseWrapper(MulticlassAccuracy(num_classes=3, average=None), prefix="accuracy") + classwise_recall = ClasswiseWrapper(MulticlassRecall(num_classes=3, average=None), prefix="recall") + classwise_precision = ClasswiseWrapper(MulticlassPrecision(num_classes=3, average=None), prefix="precision") + + m = MetricCollection( + { + "accuracy": ClasswiseWrapper(MulticlassAccuracy(num_classes=3, average=None), prefix="accuracy"), + "recall": ClasswiseWrapper(MulticlassRecall(num_classes=3, average=None), prefix="recall"), + "precision": ClasswiseWrapper(MulticlassPrecision(num_classes=3, average=None), prefix="precision"), + }, + compute_groups=[["accuracy", "recall", "precision"]], + ) + + # Check that we are not going to check the groups in the first update + assert m._groups_checked + assert m.compute_groups == {0: ["accuracy", "recall", "precision"]} + + preds = torch.randn(10, 3).softmax(dim=-1) + target = torch.randint(3, (10,)) + + expected = { + **classwise_accuracy(preds, target), + **classwise_recall(preds, target), + **classwise_precision(preds, target), + } + + m.update(preds, target) + res = m.compute() + + for key in expected: + assert torch.allclose(res[key], expected[key]) + + def test_compute_on_different_dtype(): """Check that extraction of compute groups are robust towards difference in dtype.""" m = MetricCollection([ From d5e19c41de35bfe1dc389fa4c18f2526cb31501c Mon Sep 17 00:00:00 2001 From: Su YR Date: Tue, 20 Feb 2024 09:43:06 +0800 Subject: [PATCH 2/5] fix: set _persistent and _reductions be same as internal metric --- src/torchmetrics/wrappers/classwise.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/torchmetrics/wrappers/classwise.py b/src/torchmetrics/wrappers/classwise.py index 559e04152d7..698d0f51848 100644 --- a/src/torchmetrics/wrappers/classwise.py +++ b/src/torchmetrics/wrappers/classwise.py @@ -227,6 +227,8 @@ def __setattr__(self, name: str, value: Any) -> None: super().__setattr__(name, value) if name == "metric": self._defaults = self.metric._defaults + self._persistent = self.metric._persistent + self._reductions = self.metric._reductions if hasattr(self, "metric") and name in ["tp", "fp", "fn", "tn", "_update_count", "_computed"]: # update ``_update_count`` and ``_computed`` of internal metric to prevent warning. setattr(self.metric, name, value) From ca85a31ddc9402d927a5d7e57f24657f4c5cc8ed Mon Sep 17 00:00:00 2001 From: Su YR Date: Tue, 20 Feb 2024 09:50:43 +0800 Subject: [PATCH 3/5] test: check metric state_dict wrapped in `ClasswiseWrapper` --- tests/unittests/bases/test_collections.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/unittests/bases/test_collections.py b/tests/unittests/bases/test_collections.py index 3c34cb965ad..16c95fc879a 100644 --- a/tests/unittests/bases/test_collections.py +++ b/tests/unittests/bases/test_collections.py @@ -574,6 +574,9 @@ def test_classwise_wrapper_compute_group(): for key in expected: assert torch.allclose(res[key], expected[key]) + # check metric state_dict + m.state_dict() + def test_compute_on_different_dtype(): """Check that extraction of compute groups are robust towards difference in dtype.""" From 08b4530a4d3a10550331eeed8eb8fc45f6d1d353 Mon Sep 17 00:00:00 2001 From: Su YR Date: Tue, 5 Mar 2024 18:27:55 +0800 Subject: [PATCH 4/5] refactor: make __getattr__ and __setattr__ of ClasswiseWrapper more general --- src/torchmetrics/wrappers/classwise.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/src/torchmetrics/wrappers/classwise.py b/src/torchmetrics/wrappers/classwise.py index 698d0f51848..0920118c919 100644 --- a/src/torchmetrics/wrappers/classwise.py +++ b/src/torchmetrics/wrappers/classwise.py @@ -216,19 +216,19 @@ def plot( def __getattr__(self, name: str) -> Union[Tensor, "Module"]: """Get attribute from classwise wrapper.""" - # return state from self.metric - if name in ["tp", "fp", "fn", "tn"]: - return getattr(self.metric, name) + if name == "metric" or (name in self.__dict__ and name not in self.metric.__dict__): + # we need this to prevent from infinite getattribute loop. + return super().__getattr__(name) - return super().__getattr__(name) + return getattr(self.metric, name) def __setattr__(self, name: str, value: Any) -> None: """Set attribute to classwise wrapper.""" - super().__setattr__(name, value) - if name == "metric": - self._defaults = self.metric._defaults - self._persistent = self.metric._persistent - self._reductions = self.metric._reductions - if hasattr(self, "metric") and name in ["tp", "fp", "fn", "tn", "_update_count", "_computed"]: - # update ``_update_count`` and ``_computed`` of internal metric to prevent warning. + if hasattr(self, "metric") and name in self.metric._defaults: setattr(self.metric, name, value) + else: + super().__setattr__(name, value) + if name == "metric": + self._defaults = self.metric._defaults + self._persistent = self.metric._persistent + self._reductions = self.metric._reductions From e14c59c38ca69b98b1bb293dd1cd6912c40d19e3 Mon Sep 17 00:00:00 2001 From: Jirka Date: Tue, 5 Mar 2024 14:52:20 +0100 Subject: [PATCH 5/5] chlog --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3f2aa69ba29..f970dbc9df2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,7 +20,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed -- +- Made `__getattr__` and `__setattr__` of `ClasswiseWrapper` more general ([#2424](https://github.com/Lightning-AI/torchmetrics/pull/2424)) ### Deprecated