From e67474d5f51176fa8db579c1bb4c3c104cc047ad Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Wed, 7 Apr 2021 21:51:24 +0200 Subject: [PATCH 01/10] fix --- tests/bases/test_metric.py | 28 ++++++++++++++++++++++++++++ torchmetrics/metric.py | 14 +++++++++++++- 2 files changed, 41 insertions(+), 1 deletion(-) diff --git a/tests/bases/test_metric.py b/tests/bases/test_metric.py index cb4eb553f50..a6bcb8dabb1 100644 --- a/tests/bases/test_metric.py +++ b/tests/bases/test_metric.py @@ -259,3 +259,31 @@ def test_device_and_dtype_transfer(tmpdir): metric = metric.half() assert metric.x.dtype == torch.float16 + + +def test_warning_on_compute_before_update(): + metric = DummyMetricSum() + + # make sure everything is fine with forward + with pytest.warns(None) as record: + val = metric(1) + assert not record + + metric.reset() + + with pytest.warns(UserWarning, match=r'The ``compute`` method of metric .*'): + val = metric.compute() + assert val == 0.0 + + # after update things should be fine + metric.update(2.0) + with pytest.warns(None) as record: + val = metric.compute() + assert not record + assert val == 2.0 + + + + + + \ No newline at end of file diff --git a/torchmetrics/metric.py b/torchmetrics/metric.py index 4ea08e36025..91cef6623cc 100644 --- a/torchmetrics/metric.py +++ b/torchmetrics/metric.py @@ -21,7 +21,7 @@ import torch from torch import Tensor, nn -from torchmetrics.utilities import apply_to_collection +from torchmetrics.utilities import apply_to_collection, rank_zero_warn from torchmetrics.utilities.data import _flatten, dim_zero_cat, dim_zero_mean, dim_zero_sum from torchmetrics.utilities.distributed import gather_all_tensors @@ -79,6 +79,7 @@ def __init__( self.compute = self._wrap_compute(self.compute) self._computed = None self._forward_cache = None + self._update_called = False # initialize state self._defaults = {} @@ -198,6 +199,7 @@ def _wrap_update(self, update): @functools.wraps(update) def wrapped_func(*args, **kwargs): self._computed = None + self._update_called = True return update(*args, **kwargs) return wrapped_func @@ -206,6 +208,15 @@ def _wrap_compute(self, compute): @functools.wraps(compute) def wrapped_func(*args, **kwargs): + if not self._update_called: + rank_zero_warn( + f"The ``compute`` method of metric {self.__class__.__name__}" + " was called before the ``update`` method which may lead to errors," + " as metric states have not yet been updated. Will return 0.0 instead.", + UserWarning + ) + return 0.0 + # return cached value if self._computed is not None: return self._computed @@ -255,6 +266,7 @@ def reset(self): This method automatically resets the metric state variables to their default value. """ self._computed = None + self._update_called = False for attr, default in self._defaults.items(): current_val = getattr(self, attr) From 3bd5c2396bbbd15ae6ccd56694fa1bbb724bbd72 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Wed, 7 Apr 2021 21:55:22 +0200 Subject: [PATCH 02/10] remove space --- tests/bases/test_metric.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/tests/bases/test_metric.py b/tests/bases/test_metric.py index a6bcb8dabb1..c22326c0cf4 100644 --- a/tests/bases/test_metric.py +++ b/tests/bases/test_metric.py @@ -281,9 +281,3 @@ def test_warning_on_compute_before_update(): val = metric.compute() assert not record assert val == 2.0 - - - - - - \ No newline at end of file From c4edba5b1c3b8cf408d1e834b5494421f6fc1901 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Wed, 7 Apr 2021 21:57:41 +0200 Subject: [PATCH 03/10] pep8 --- tests/bases/test_metric.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/bases/test_metric.py b/tests/bases/test_metric.py index c22326c0cf4..1e679e20ce0 100644 --- a/tests/bases/test_metric.py +++ b/tests/bases/test_metric.py @@ -263,18 +263,18 @@ def test_device_and_dtype_transfer(tmpdir): def test_warning_on_compute_before_update(): metric = DummyMetricSum() - + # make sure everything is fine with forward with pytest.warns(None) as record: val = metric(1) assert not record - + metric.reset() - + with pytest.warns(UserWarning, match=r'The ``compute`` method of metric .*'): val = metric.compute() assert val == 0.0 - + # after update things should be fine metric.update(2.0) with pytest.warns(None) as record: From 8f8d33eedb8525b5ebd10f510308f863f2425a4f Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Wed, 7 Apr 2021 22:41:56 +0200 Subject: [PATCH 04/10] fix tests --- tests/bases/test_composition.py | 43 ++++++++++++++++++++++++++++----- 1 file changed, 37 insertions(+), 6 deletions(-) diff --git a/tests/bases/test_composition.py b/tests/bases/test_composition.py index 98749e783d9..cf17a0bb36c 100644 --- a/tests/bases/test_composition.py +++ b/tests/bases/test_composition.py @@ -27,6 +27,7 @@ def __init__(self, val_to_return): super().__init__() self._num_updates = 0 self._val_to_return = val_to_return + self._update_called = True def update(self, *args, **kwargs) -> None: self._num_updates += 1 @@ -57,6 +58,9 @@ def test_metrics_add(second_operand, expected_result): assert isinstance(final_add, CompositionalMetric) assert isinstance(final_radd, CompositionalMetric) + final_add.update() + final_radd.update() + assert torch.allclose(expected_result, final_add.compute()) assert torch.allclose(expected_result, final_radd.compute()) @@ -75,6 +79,8 @@ def test_metrics_and(second_operand, expected_result): assert isinstance(final_and, CompositionalMetric) assert isinstance(final_rand, CompositionalMetric) + final_and.update() + final_rand.update() assert torch.allclose(expected_result, final_and.compute()) assert torch.allclose(expected_result, final_rand.compute()) @@ -95,6 +101,7 @@ def test_metrics_eq(second_operand, expected_result): assert isinstance(final_eq, CompositionalMetric) + final_eq.update() # can't use allclose for bool tensors assert (expected_result == final_eq.compute()).all() @@ -116,6 +123,7 @@ def test_metrics_floordiv(second_operand, expected_result): assert isinstance(final_floordiv, CompositionalMetric) + final_floordiv.update() assert torch.allclose(expected_result, final_floordiv.compute()) @@ -135,6 +143,7 @@ def test_metrics_ge(second_operand, expected_result): assert isinstance(final_ge, CompositionalMetric) + final_ge.update() # can't use allclose for bool tensors assert (expected_result == final_ge.compute()).all() @@ -155,6 +164,7 @@ def test_metrics_gt(second_operand, expected_result): assert isinstance(final_gt, CompositionalMetric) + final_gt.update() # can't use allclose for bool tensors assert (expected_result == final_gt.compute()).all() @@ -175,6 +185,7 @@ def test_metrics_le(second_operand, expected_result): assert isinstance(final_le, CompositionalMetric) + final_le.update() # can't use allclose for bool tensors assert (expected_result == final_le.compute()).all() @@ -195,6 +206,7 @@ def test_metrics_lt(second_operand, expected_result): assert isinstance(final_lt, CompositionalMetric) + final_lt.update() # can't use allclose for bool tensors assert (expected_result == final_lt.compute()).all() @@ -210,6 +222,7 @@ def test_metrics_matmul(second_operand, expected_result): assert isinstance(final_matmul, CompositionalMetric) + final_matmul.update() assert torch.allclose(expected_result, final_matmul.compute()) @@ -228,6 +241,8 @@ def test_metrics_mod(second_operand, expected_result): final_mod = first_metric % second_operand assert isinstance(final_mod, CompositionalMetric) + + final_mod.update() # prevent Runtime error for PT 1.8 - Long did not match Float assert torch.allclose(expected_result.to(float), final_mod.compute().to(float)) @@ -250,6 +265,8 @@ def test_metrics_mul(second_operand, expected_result): assert isinstance(final_mul, CompositionalMetric) assert isinstance(final_rmul, CompositionalMetric) + final_mul.update() + final_rmul.update() assert torch.allclose(expected_result, final_mul.compute()) assert torch.allclose(expected_result, final_rmul.compute()) @@ -270,6 +287,7 @@ def test_metrics_ne(second_operand, expected_result): assert isinstance(final_ne, CompositionalMetric) + final_ne.update() # can't use allclose for bool tensors assert (expected_result == final_ne.compute()).all() @@ -288,6 +306,8 @@ def test_metrics_or(second_operand, expected_result): assert isinstance(final_or, CompositionalMetric) assert isinstance(final_ror, CompositionalMetric) + final_or.update() + final_ror.update() assert torch.allclose(expected_result, final_or.compute()) assert torch.allclose(expected_result, final_ror.compute()) @@ -308,6 +328,7 @@ def test_metrics_pow(second_operand, expected_result): assert isinstance(final_pow, CompositionalMetric) + final_pow.update() assert torch.allclose(expected_result, final_pow.compute()) @@ -322,6 +343,8 @@ def test_metrics_rfloordiv(first_operand, expected_result): final_rfloordiv = first_operand // second_operand assert isinstance(final_rfloordiv, CompositionalMetric) + + final_rfloordiv.update() assert torch.allclose(expected_result, final_rfloordiv.compute()) @@ -336,6 +359,7 @@ def test_metrics_rmatmul(first_operand, expected_result): assert isinstance(final_rmatmul, CompositionalMetric) + final_rmatmul.update() assert torch.allclose(expected_result, final_rmatmul.compute()) @@ -350,6 +374,7 @@ def test_metrics_rmod(first_operand, expected_result): assert isinstance(final_rmod, CompositionalMetric) + final_rmod.update() assert torch.allclose(expected_result, final_rmod.compute()) @@ -367,7 +392,7 @@ def test_metrics_rpow(first_operand, expected_result): final_rpow = first_operand**second_operand assert isinstance(final_rpow, CompositionalMetric) - + final_rpow.update() assert torch.allclose(expected_result, final_rpow.compute()) @@ -386,7 +411,7 @@ def test_metrics_rsub(first_operand, expected_result): final_rsub = first_operand - second_operand assert isinstance(final_rsub, CompositionalMetric) - + final_rsub.update() assert torch.allclose(expected_result, final_rsub.compute()) @@ -406,7 +431,7 @@ def test_metrics_rtruediv(first_operand, expected_result): final_rtruediv = first_operand / second_operand assert isinstance(final_rtruediv, CompositionalMetric) - + final_rtruediv.update() assert torch.allclose(expected_result, final_rtruediv.compute()) @@ -425,7 +450,7 @@ def test_metrics_sub(second_operand, expected_result): final_sub = first_metric - second_operand assert isinstance(final_sub, CompositionalMetric) - + final_sub.update() assert torch.allclose(expected_result, final_sub.compute()) @@ -445,7 +470,7 @@ def test_metrics_truediv(second_operand, expected_result): final_truediv = first_metric / second_operand assert isinstance(final_truediv, CompositionalMetric) - + final_truediv.update() assert torch.allclose(expected_result, final_truediv.compute()) @@ -463,6 +488,8 @@ def test_metrics_xor(second_operand, expected_result): assert isinstance(final_xor, CompositionalMetric) assert isinstance(final_rxor, CompositionalMetric) + final_xor.update() + final_rxor.update() assert torch.allclose(expected_result, final_xor.compute()) assert torch.allclose(expected_result, final_rxor.compute()) @@ -473,7 +500,7 @@ def test_metrics_abs(): final_abs = abs(first_metric) assert isinstance(final_abs, CompositionalMetric) - + final_abs.update() assert torch.allclose(tensor(1), final_abs.compute()) @@ -482,6 +509,7 @@ def test_metrics_invert(): final_inverse = ~first_metric assert isinstance(final_inverse, CompositionalMetric) + final_inverse.update() assert torch.allclose(tensor(-2), final_inverse.compute()) @@ -490,6 +518,7 @@ def test_metrics_neg(): final_neg = neg(first_metric) assert isinstance(final_neg, CompositionalMetric) + final_neg.update() assert torch.allclose(tensor(-1), final_neg.compute()) @@ -498,6 +527,7 @@ def test_metrics_pos(): final_pos = pos(first_metric) assert isinstance(final_pos, CompositionalMetric) + final_pos.update() assert torch.allclose(tensor(1), final_pos.compute()) @@ -510,6 +540,7 @@ def test_metrics_getitem(value, idx, expected_result): final_getitem = first_metric[idx] assert isinstance(final_getitem, CompositionalMetric) + final_getitem.update() assert torch.allclose(expected_result, final_getitem.compute()) From 9f58d95e682c8c1dbe94c7413c81fc6f0e42c48d Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Wed, 7 Apr 2021 22:43:37 +0200 Subject: [PATCH 05/10] changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6200571f0ac..94fcaa3eae8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -34,6 +34,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Changed `ExplainedVariance` from storing all preds/targets to tracking 5 statistics ([#68](https://github.com/PyTorchLightning/metrics/pull/68)) - Changed behaviour of `confusionmatrix` for multilabel data to better match `multilabel_confusion_matrix` from sklearn ([#134](https://github.com/PyTorchLightning/metrics/pull/134)) - Updated FBeta arguments ([#111](https://github.com/PyTorchLightning/metrics/pull/111)) +- Calling `compute` before `update` will now return 0 and give an warning ([#164](https://github.com/PyTorchLightning/metrics/pull/164)) ### Deprecated From 2ae96dc2a9010906ae6b163c9b047c527d40765f Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Thu, 8 Apr 2021 00:35:31 +0200 Subject: [PATCH 06/10] test --- tests/bases/test_metric.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/bases/test_metric.py b/tests/bases/test_metric.py index 1e679e20ce0..6dcb632cb7c 100644 --- a/tests/bases/test_metric.py +++ b/tests/bases/test_metric.py @@ -273,7 +273,7 @@ def test_warning_on_compute_before_update(): with pytest.warns(UserWarning, match=r'The ``compute`` method of metric .*'): val = metric.compute() - assert val == 0.0 + assert val == 0.0 # after update things should be fine metric.update(2.0) From 15e7297bd6e240e733efa6f1c9e82099a5262a94 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 28 Apr 2021 21:48:40 +0000 Subject: [PATCH 07/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/bases/test_metric.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/bases/test_metric.py b/tests/bases/test_metric.py index db041ac6d8a..6ae59125320 100644 --- a/tests/bases/test_metric.py +++ b/tests/bases/test_metric.py @@ -299,4 +299,3 @@ def test_warning_on_compute_before_update(): def test_metric_scripts(): torch.jit.script(DummyMetric()) torch.jit.script(DummyMetricSum()) - From 88cd07fc0383ecfa6091b8ac40ffb1d8e1f3c443 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Wed, 28 Apr 2021 23:49:37 +0200 Subject: [PATCH 08/10] Update CHANGELOG.md --- CHANGELOG.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e6fa2834fdb..32139adbd37 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed +- Calling `compute` before `update` will now return 0 and give an warning ([#164](https://github.com/PyTorchLightning/metrics/pull/164)) + ### Deprecated @@ -73,7 +75,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Changed `ExplainedVariance` from storing all preds/targets to tracking 5 statistics ([#68](https://github.com/PyTorchLightning/metrics/pull/68)) - Changed behaviour of `confusionmatrix` for multilabel data to better match `multilabel_confusion_matrix` from sklearn ([#134](https://github.com/PyTorchLightning/metrics/pull/134)) - Updated FBeta arguments ([#111](https://github.com/PyTorchLightning/metrics/pull/111)) -- Calling `compute` before `update` will now return 0 and give an warning ([#164](https://github.com/PyTorchLightning/metrics/pull/164)) - Changed `reset` method to use `detach.clone()` instead of `deepcopy` when resetting to default ([#163](https://github.com/PyTorchLightning/metrics/pull/163)) - Metrics passed as dict to `MetricCollection` will now always be in deterministic order ([#173](https://github.com/PyTorchLightning/metrics/pull/173)) - Allowed `MetricCollection` pass metrics as arguments ([#176](https://github.com/PyTorchLightning/metrics/pull/176)) From ed706687d5850007ee01568c98f5b8eda4071fc9 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Wed, 28 Apr 2021 23:49:44 +0200 Subject: [PATCH 09/10] Create CHANGELOG.md From 4726553409866aef6445e70cdcca406dc679c06f Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Thu, 29 Apr 2021 14:25:11 +0200 Subject: [PATCH 10/10] update based on discussion --- CHANGELOG.md | 2 +- torchmetrics/metric.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 32139adbd37..46e89df0feb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,7 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed -- Calling `compute` before `update` will now return 0 and give an warning ([#164](https://github.com/PyTorchLightning/metrics/pull/164)) +- Calling `compute` before `update` will now give an warning ([#164](https://github.com/PyTorchLightning/metrics/pull/164)) ### Deprecated diff --git a/torchmetrics/metric.py b/torchmetrics/metric.py index d5f8eed7aa1..7b5ed89e9dc 100644 --- a/torchmetrics/metric.py +++ b/torchmetrics/metric.py @@ -217,10 +217,9 @@ def wrapped_func(*args, **kwargs): rank_zero_warn( f"The ``compute`` method of metric {self.__class__.__name__}" " was called before the ``update`` method which may lead to errors," - " as metric states have not yet been updated. Will return 0.0 instead.", + " as metric states have not yet been updated.", UserWarning ) - return 0.0 # return cached value if self._computed is not None: