From 4ca1c0dffe78a0038185ea1122f21c58e8980638 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Mon, 5 Apr 2021 14:03:04 +0200 Subject: [PATCH 01/21] init files --- tests/regression/test_pearson.py | 0 .../functional/regression/__init__.py | 1 + torchmetrics/functional/regression/pearson.py | 86 +++++++++++++++++++ torchmetrics/regression/__init__.py | 1 + torchmetrics/regression/pearson.py | 13 +++ 5 files changed, 101 insertions(+) create mode 100644 tests/regression/test_pearson.py create mode 100644 torchmetrics/functional/regression/pearson.py create mode 100644 torchmetrics/regression/pearson.py diff --git a/tests/regression/test_pearson.py b/tests/regression/test_pearson.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/torchmetrics/functional/regression/__init__.py b/torchmetrics/functional/regression/__init__.py index 63c2aabb1e2..9684b5642b4 100644 --- a/torchmetrics/functional/regression/__init__.py +++ b/torchmetrics/functional/regression/__init__.py @@ -15,6 +15,7 @@ from torchmetrics.functional.regression.mean_absolute_error import mean_absolute_error # noqa: F401 from torchmetrics.functional.regression.mean_squared_error import mean_squared_error # noqa: F401 from torchmetrics.functional.regression.mean_squared_log_error import mean_squared_log_error # noqa: F401 +from torchmetrics.functional.regression.pearson import pearson_corrcoef # noqa: F401 from torchmetrics.functional.regression.psnr import psnr # noqa: F401 from torchmetrics.functional.regression.r2score import r2score # noqa: F401 from torchmetrics.functional.regression.ssim import ssim # noqa: F401 diff --git a/torchmetrics/functional/regression/pearson.py b/torchmetrics/functional/regression/pearson.py new file mode 100644 index 00000000000..a1829c6443c --- /dev/null +++ b/torchmetrics/functional/regression/pearson.py @@ -0,0 +1,86 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Sequence, Tuple, Union, Optional + +import torch +from torch import Tensor + +from torchmetrics.utilities.checks import _check_same_shape + + +def _update_mean(old_mean: torch.Tensor, old_nobs: torch.Tensor, data: torch.Tensor) -> torch.Tensor: + """ Update a mean estimate given new data + Args: + old_mean: current mean estimate + old_nobs: number of observation until now + data: data used for updating the estimate + Returns: + new_mean: updated mean estimate + """ + data_size = data.shape[0] + return (old_mean * old_nobs + data.mean(dim=0) * data_size) / (old_nobs + data_size) + + +def _update_cov(old_cov: torch.Tensor, old_mean: torch.Tensor, new_mean: torch.Tensor, data: torch.Tensor): + """ Update a covariance estimate given new data + Args: + old_cov: current covariance estimate + old_mean: current mean estimate + new_mean: updated mean estimate + data: data used for updating the estimate + Returns: + new_mean: updated covariance estimate + """ + return old_cov + (data - new_mean).T @ (data - old_mean) + + +def _pearson_corrcoef_update( + preds: Tensor, + target: Tensor, + old_mean: Optional[Tensor] = None, + old_cov: Optional[Tensor] = None, + old_nobs: Optional[Tensor] = None +) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + # Data checking + _check_same_shape(preds, target) + preds = preds.squeeze() + target = target.squeeze() + if preds.ndim > 1 or target.ndim > 1: + raise ValueError('Expected both predictions and target to be 1 dimensional tensors. Please flatten.') + data = torch.stack([preds, target], dim=1) + + if old_mean is None: + old_mean = 0 + if old_cov is None: + old_cov = 0 + if old_nobs is None: + old_nobs = 0 + + new_mean = _update_mean(old_mean, old_nobs, data) + new_cov = _update_cov(old_cov, old_mean, new_mean, data) + new_size = old_nobs + preds.numel() + + return new_mean, new_cov, new_size + + +def _pearson_corrcoef_compute(c: torch.Tensor): + x_var = c[0,0] + y_var = c[1,1] + cov = c[1,1] + return cov / (x_var * y_var) + + +def pearson_corrcoef(preds: Tensor, target: Tensor) -> Tensor: + _, c, _ = _pearson_corrcoef_update(preds, target) + return _pearson_corrcoef_compute(c) diff --git a/torchmetrics/regression/__init__.py b/torchmetrics/regression/__init__.py index bf2da61095e..5c405e10ff3 100644 --- a/torchmetrics/regression/__init__.py +++ b/torchmetrics/regression/__init__.py @@ -15,6 +15,7 @@ from torchmetrics.regression.mean_absolute_error import MeanAbsoluteError # noqa: F401 from torchmetrics.regression.mean_squared_error import MeanSquaredError # noqa: F401 from torchmetrics.regression.mean_squared_log_error import MeanSquaredLogError # noqa: F401 +from torchmetrics.regression.pearson import PearsonCorrcoef # noqa: F401 from torchmetrics.regression.psnr import PSNR # noqa: F401 from torchmetrics.regression.r2score import R2Score # noqa: F401 from torchmetrics.regression.ssim import SSIM # noqa: F401 diff --git a/torchmetrics/regression/pearson.py b/torchmetrics/regression/pearson.py new file mode 100644 index 00000000000..e708df58645 --- /dev/null +++ b/torchmetrics/regression/pearson.py @@ -0,0 +1,13 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. \ No newline at end of file From d25bcf6d669ba71af459dd2b00ec7daf0a5307a0 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Mon, 5 Apr 2021 15:04:19 +0200 Subject: [PATCH 02/21] rest --- docs/source/references/functional.rst | 7 + docs/source/references/modules.rst | 7 + tests/regression/test_pearson.py | 120 ++++++++++++++++++ torchmetrics/__init__.py | 1 + torchmetrics/functional/__init__.py | 1 + torchmetrics/functional/regression/pearson.py | 28 +++- torchmetrics/regression/pearson.py | 84 +++++++++++- 7 files changed, 241 insertions(+), 7 deletions(-) diff --git a/docs/source/references/functional.rst b/docs/source/references/functional.rst index 5f450eca1df..8c56cf1287c 100644 --- a/docs/source/references/functional.rst +++ b/docs/source/references/functional.rst @@ -196,6 +196,13 @@ mean_squared_log_error [func] :noindex: +pearson_corrcoef [func] +~~~~~~~~~~~~~~~~~~~~~~~ + +.. autofunction:: torchmetrics.functional.pearson_corrcoef + :noindex: + + psnr [func] ~~~~~~~~~~~ diff --git a/docs/source/references/modules.rst b/docs/source/references/modules.rst index f6f6d32af86..dc20c58dec0 100644 --- a/docs/source/references/modules.rst +++ b/docs/source/references/modules.rst @@ -250,6 +250,13 @@ MeanSquaredLogError :noindex: +PearsonCorrcoef +~~~~~~~~~~~~~~~ + +.. autoclass:: torchmetrics.PearsonCorrcoef + :noindex: + + PSNR ~~~~ diff --git a/tests/regression/test_pearson.py b/tests/regression/test_pearson.py index e69de29bb2d..ce46deae1c7 100644 --- a/tests/regression/test_pearson.py +++ b/tests/regression/test_pearson.py @@ -0,0 +1,120 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest +import pickle +from collections import namedtuple +from functools import partial +import torch +from tests.helpers import seed_all +from scipy.stats import pearsonr + +from torchmetrics.regression.pearson import PearsonCorrcoef +from torchmetrics.functional.regression.pearson import pearson_corrcoef, _update_cov, _update_mean +from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester +seed_all(42) + + +def test_update_functions(tmpdir): + """ Test that updating the estimates are equal to estimating them on all data """ + data = torch.randn(100, 2) + batch1, batch2 = data.chunk(2) + + def _mean_cov(data): + mean = data.mean(0) + diff = data - mean + cov = diff.T @ diff + return mean, cov + + mean_update, cov_update, size_update = torch.zeros(2), torch.zeros(2,2), torch.zeros(1) + for batch in [batch1, batch2]: + new_mean = _update_mean(mean_update, size_update, batch) + new_cov = _update_cov(cov_update, mean_update, new_mean, batch) + + assert not torch.allclose(new_mean, mean_update), "mean estimate did not update" + assert not torch.allclose(new_cov, cov_update), "covariance estimate did not update" + + size_update += batch.shape[0] + mean_update = new_mean + cov_update = new_cov + + mean, cov = _mean_cov(data) + + assert torch.allclose(mean, mean_update), "updated mean does not correspond to mean of all data" + assert torch.allclose(cov, cov_update), "updated covariance does not correspond to covariance of all data" + + +Input = namedtuple('Input', ["preds", "target"]) + +_single_target_inputs1 = Input( + preds=torch.rand(NUM_BATCHES, BATCH_SIZE), + target=torch.rand(NUM_BATCHES, BATCH_SIZE), +) + +_single_target_inputs2 = Input( + preds=torch.randn(NUM_BATCHES, BATCH_SIZE), + target=torch.randn(NUM_BATCHES, BATCH_SIZE), +) + + +def _sk_metric(preds, target): + sk_preds = preds.view(-1).numpy() + sk_target = target.view(-1).numpy() + return pearsonr(sk_target, sk_preds)[0] + + +@pytest.mark.parametrize("preds, target", + [ + (_single_target_inputs1.preds, _single_target_inputs1.target), + (_single_target_inputs2.preds, _single_target_inputs2.target), + ], +) +class TestPearsonCorrcoef(MetricTester): + atol=1e-4 + @pytest.mark.parametrize("ddp", [True, False]) + @pytest.mark.parametrize("dist_sync_on_step", [True, False]) + def test_explained_variance(self, preds, target, ddp, dist_sync_on_step): + self.run_class_metric_test( + ddp, + preds, + target, + PearsonCorrcoef, + _sk_metric, + dist_sync_on_step, + ) + + def test_pearson_corrcoef_functional(self, preds, target): + self.run_functional_metric_test( + preds, + target, + pearson_corrcoef, + _sk_metric + ) + + # Pearson half + cpu does not work due to missing support in torch.sqrt + @pytest.mark.xfail(reason="PSNR metric does not support cpu + half precision") + def test_pearson_corrcoef_half_cpu(self, preds, target): + self.run_precision_test_cpu(preds, target, PearsonCorrcoef, pearson_corrcoef) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason='test requires cuda') + def test_pearson_corrcoef_half_gpu(self, preds, target): + self.run_precision_test_gpu(preds, target, PearsonCorrcoef, pearson_corrcoef) + + +def test_error_on_different_shape(): + metric = PearsonCorrcoef() + with pytest.raises(RuntimeError, match='Predictions and targets are expected to have the same shape'): + metric(torch.randn(100, ), torch.randn(50, )) + + with pytest.raises(ValueError, match='Expected both predictions and target to be 1 dimensional tensors.'): + metric(torch.randn(100, 2), torch.randn(100, 2)) diff --git a/torchmetrics/__init__.py b/torchmetrics/__init__.py index 7e21b7886df..8af309c46d6 100644 --- a/torchmetrics/__init__.py +++ b/torchmetrics/__init__.py @@ -41,6 +41,7 @@ from torchmetrics.collections import MetricCollection # noqa: F401 E402 from torchmetrics.metric import Metric # noqa: F401 E402 from torchmetrics.regression import ( # noqa: F401 E402 + PearsonCorrcoef, PSNR, SSIM, ExplainedVariance, diff --git a/torchmetrics/functional/__init__.py b/torchmetrics/functional/__init__.py index 18d90b6d9ce..428ddd54d97 100644 --- a/torchmetrics/functional/__init__.py +++ b/torchmetrics/functional/__init__.py @@ -34,6 +34,7 @@ from torchmetrics.functional.regression.mean_relative_error import mean_relative_error # noqa: F401 from torchmetrics.functional.regression.mean_squared_error import mean_squared_error # noqa: F401 from torchmetrics.functional.regression.mean_squared_log_error import mean_squared_log_error # noqa: F401 +from torchmetrics.functional.regression.pearson import pearson_corrcoef # noqa: F401 from torchmetrics.functional.regression.psnr import psnr # noqa: F401 from torchmetrics.functional.regression.r2score import r2score # noqa: F401 from torchmetrics.functional.regression.ssim import ssim # noqa: F401 diff --git a/torchmetrics/functional/regression/pearson.py b/torchmetrics/functional/regression/pearson.py index a1829c6443c..d6061ded778 100644 --- a/torchmetrics/functional/regression/pearson.py +++ b/torchmetrics/functional/regression/pearson.py @@ -57,7 +57,7 @@ def _pearson_corrcoef_update( preds = preds.squeeze() target = target.squeeze() if preds.ndim > 1 or target.ndim > 1: - raise ValueError('Expected both predictions and target to be 1 dimensional tensors. Please flatten.') + raise ValueError('Expected both predictions and target to be 1 dimensional tensors.') data = torch.stack([preds, target], dim=1) if old_mean is None: @@ -74,13 +74,29 @@ def _pearson_corrcoef_update( return new_mean, new_cov, new_size -def _pearson_corrcoef_compute(c: torch.Tensor): +def _pearson_corrcoef_compute(c: Tensor, nobs: Tensor): + c /= (nobs-1) x_var = c[0,0] y_var = c[1,1] - cov = c[1,1] - return cov / (x_var * y_var) + cov = c[0,1] + corrcoef = cov / (x_var * y_var).sqrt() + return torch.clip(corrcoef, -1.0, 1.0) def pearson_corrcoef(preds: Tensor, target: Tensor) -> Tensor: - _, c, _ = _pearson_corrcoef_update(preds, target) - return _pearson_corrcoef_compute(c) + """ + Computes pearson correlation coefficient. + + Args: + preds: estimated scores + target: ground truth scores + + Example: + >>> from torchmetrics.functional import pearson_corrcoef + >>> target = torch.tensor([3, -0.5, 2, 7]) + >>> preds = torch.tensor([2.5, 0.0, 2, 8]) + >>> pearson_corrcoef(preds, target) + tensor(0.9849) + """ + _, c, nobs = _pearson_corrcoef_update(preds, target) + return _pearson_corrcoef_compute(c, nobs) diff --git a/torchmetrics/regression/pearson.py b/torchmetrics/regression/pearson.py index e708df58645..aed0a460502 100644 --- a/torchmetrics/regression/pearson.py +++ b/torchmetrics/regression/pearson.py @@ -10,4 +10,86 @@ # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and -# limitations under the License. \ No newline at end of file +# limitations under the License. +from typing import Any, Callable, Optional + +import torch +from torch import Tensor, tensor + +from torchmetrics.functional.regression.pearson import ( + _pearson_corrcoef_compute, + _pearson_corrcoef_update +) +from torchmetrics.metric import Metric + + +class PearsonCorrcoef(Metric): + r""" + Computes `pearson correlation coefficient + `_: + + .. math:: \text{P_corr}(x,y) = \frac{cov(x,y)}{\sigma_x \times \sigma_y} + + Where :math:`y` is a tensor of target values, and :math:`x` is a + tensor of predictions. + + Forward accepts + + - ``preds`` (float tensor): ``(N,)`` + - ``target``(float tensor): ``(N,)`` + + Args: + compute_on_step: + Forward only calls ``update()`` and return None if this is set to False. default: True + dist_sync_on_step: + Synchronize metric state across processes at each ``forward()`` + before returning the value at the step. default: False + process_group: + Specify the process group on which synchronization is called. default: None (which selects the entire world) + dist_sync_fn: + Callback that performs the allgather operation on the metric state. When ``None``, DDP + will be used to perform the allgather + + Example: + >>> from torchmetrics import PearsonCorrcoef + >>> target = torch.tensor([3, -0.5, 2, 7]) + >>> preds = torch.tensor([2.5, 0.0, 2, 8]) + >>> pearson = PearsonCorrcoef() + >>> pearson(preds, target) + tensor(0.9849) + + """ + def __init__( + self, + compute_on_step: bool = True, + dist_sync_on_step: bool = False, + process_group: Optional[Any] = None, + dist_sync_fn: Callable = None, + ): + super().__init__( + compute_on_step=compute_on_step, + dist_sync_on_step=dist_sync_on_step, + process_group=process_group, + dist_sync_fn=dist_sync_fn, + ) + self.add_state("cov", default=torch.zeros(2,2), dist_reduce_fx="sum") + self.add_state("mean", default=torch.zeros(2), dist_reduce_fx="sum") + self.add_state("n_obs", default=torch.zeros(1), dist_reduce_fx="sum") + + def update(self, preds: Tensor, target: Tensor): + """ + Update state with predictions and targets. + + Args: + preds: Predictions from model + target: Ground truth values + """ + self.mean, self.cov, self.n_obs = _pearson_corrcoef_update( + preds, target, self.mean, self.cov, self.n_obs + ) + + def compute(self): + """ + Computes pearson correlation coefficient over state. + """ + return _pearson_corrcoef_compute(self.cov, self.n_obs) From ff5326e73743c3f88d61536afdcc155f92074d0c Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Mon, 5 Apr 2021 15:10:11 +0200 Subject: [PATCH 03/21] pep8 --- tests/regression/test_pearson.py | 35 +++++++++---------- torchmetrics/functional/regression/pearson.py | 28 +++++++-------- torchmetrics/regression/pearson.py | 11 +++--- 3 files changed, 35 insertions(+), 39 deletions(-) diff --git a/tests/regression/test_pearson.py b/tests/regression/test_pearson.py index ce46deae1c7..d59b88773d1 100644 --- a/tests/regression/test_pearson.py +++ b/tests/regression/test_pearson.py @@ -11,17 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import pytest -import pickle from collections import namedtuple -from functools import partial + +import pytest import torch -from tests.helpers import seed_all from scipy.stats import pearsonr -from torchmetrics.regression.pearson import PearsonCorrcoef -from torchmetrics.functional.regression.pearson import pearson_corrcoef, _update_cov, _update_mean +from tests.helpers import seed_all from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester +from torchmetrics.functional.regression.pearson import _update_cov, _update_mean, pearson_corrcoef +from torchmetrics.regression.pearson import PearsonCorrcoef + seed_all(42) @@ -36,9 +36,9 @@ def _mean_cov(data): cov = diff.T @ diff return mean, cov - mean_update, cov_update, size_update = torch.zeros(2), torch.zeros(2,2), torch.zeros(1) + mean_update, cov_update, size_update = torch.zeros(2), torch.zeros(2, 2), torch.zeros(1) for batch in [batch1, batch2]: - new_mean = _update_mean(mean_update, size_update, batch) + new_mean = _update_mean(mean_update, size_update, batch) new_cov = _update_cov(cov_update, mean_update, new_mean, batch) assert not torch.allclose(new_mean, mean_update), "mean estimate did not update" @@ -52,7 +52,7 @@ def _mean_cov(data): assert torch.allclose(mean, mean_update), "updated mean does not correspond to mean of all data" assert torch.allclose(cov, cov_update), "updated covariance does not correspond to covariance of all data" - + Input = namedtuple('Input', ["preds", "target"]) @@ -73,14 +73,13 @@ def _sk_metric(preds, target): return pearsonr(sk_target, sk_preds)[0] -@pytest.mark.parametrize("preds, target", - [ - (_single_target_inputs1.preds, _single_target_inputs1.target), - (_single_target_inputs2.preds, _single_target_inputs2.target), - ], -) +@pytest.mark.parametrize("preds, target", [ + (_single_target_inputs1.preds, _single_target_inputs1.target), + (_single_target_inputs2.preds, _single_target_inputs2.target), +]) class TestPearsonCorrcoef(MetricTester): - atol=1e-4 + atol = 1e-4 + @pytest.mark.parametrize("ddp", [True, False]) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) def test_explained_variance(self, preds, target, ddp, dist_sync_on_step): @@ -89,7 +88,7 @@ def test_explained_variance(self, preds, target, ddp, dist_sync_on_step): preds, target, PearsonCorrcoef, - _sk_metric, + _sk_metric, dist_sync_on_step, ) @@ -115,6 +114,6 @@ def test_error_on_different_shape(): metric = PearsonCorrcoef() with pytest.raises(RuntimeError, match='Predictions and targets are expected to have the same shape'): metric(torch.randn(100, ), torch.randn(50, )) - + with pytest.raises(ValueError, match='Expected both predictions and target to be 1 dimensional tensors.'): metric(torch.randn(100, 2), torch.randn(100, 2)) diff --git a/torchmetrics/functional/regression/pearson.py b/torchmetrics/functional/regression/pearson.py index d6061ded778..86dd79a0987 100644 --- a/torchmetrics/functional/regression/pearson.py +++ b/torchmetrics/functional/regression/pearson.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Sequence, Tuple, Union, Optional +from typing import Optional, Tuple import torch from torch import Tensor @@ -46,11 +46,11 @@ def _update_cov(old_cov: torch.Tensor, old_mean: torch.Tensor, new_mean: torch.T def _pearson_corrcoef_update( - preds: Tensor, - target: Tensor, - old_mean: Optional[Tensor] = None, - old_cov: Optional[Tensor] = None, - old_nobs: Optional[Tensor] = None + preds: Tensor, + target: Tensor, + old_mean: Optional[Tensor] = None, + old_cov: Optional[Tensor] = None, + old_nobs: Optional[Tensor] = None ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: # Data checking _check_same_shape(preds, target) @@ -66,19 +66,19 @@ def _pearson_corrcoef_update( old_cov = 0 if old_nobs is None: old_nobs = 0 - + new_mean = _update_mean(old_mean, old_nobs, data) new_cov = _update_cov(old_cov, old_mean, new_mean, data) new_size = old_nobs + preds.numel() - + return new_mean, new_cov, new_size - - + + def _pearson_corrcoef_compute(c: Tensor, nobs: Tensor): - c /= (nobs-1) - x_var = c[0,0] - y_var = c[1,1] - cov = c[0,1] + c /= (nobs - 1) + x_var = c[0, 0] + y_var = c[1, 1] + cov = c[0, 1] corrcoef = cov / (x_var * y_var).sqrt() return torch.clip(corrcoef, -1.0, 1.0) diff --git a/torchmetrics/regression/pearson.py b/torchmetrics/regression/pearson.py index aed0a460502..8598c86117a 100644 --- a/torchmetrics/regression/pearson.py +++ b/torchmetrics/regression/pearson.py @@ -14,12 +14,9 @@ from typing import Any, Callable, Optional import torch -from torch import Tensor, tensor +from torch import Tensor -from torchmetrics.functional.regression.pearson import ( - _pearson_corrcoef_compute, - _pearson_corrcoef_update -) +from torchmetrics.functional.regression.pearson import _pearson_corrcoef_compute, _pearson_corrcoef_update from torchmetrics.metric import Metric @@ -35,7 +32,7 @@ class PearsonCorrcoef(Metric): Forward accepts - - ``preds`` (float tensor): ``(N,)`` + - ``preds`` (float tensor): ``(N,)`` - ``target``(float tensor): ``(N,)`` Args: @@ -72,7 +69,7 @@ def __init__( process_group=process_group, dist_sync_fn=dist_sync_fn, ) - self.add_state("cov", default=torch.zeros(2,2), dist_reduce_fx="sum") + self.add_state("cov", default=torch.zeros(2, 2), dist_reduce_fx="sum") self.add_state("mean", default=torch.zeros(2), dist_reduce_fx="sum") self.add_state("n_obs", default=torch.zeros(1), dist_reduce_fx="sum") From 0ae07326d4a559b979e05e4d71f90fea458fd548 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Mon, 5 Apr 2021 15:14:20 +0200 Subject: [PATCH 04/21] changelog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1ab569b5666..9483904dced 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -47,6 +47,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added `__getitem__` as metric arithmetic operation ([#142](https://github.com/PyTorchLightning/metrics/pull/142)) +- Added `PearsonCorrcoef` metric ([#157](https://github.com/PyTorchLightning/metrics/pull/157)) + + ### Changed - Changed `ExplainedVariance` from storing all preds/targets to tracking 5 statistics ([#68](https://github.com/PyTorchLightning/metrics/pull/68)) From ff36a0ea1a02fe93c21eeeb3b141ed2f20f9887b Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Mon, 5 Apr 2021 16:03:06 +0200 Subject: [PATCH 05/21] clamp --- torchmetrics/functional/regression/pearson.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmetrics/functional/regression/pearson.py b/torchmetrics/functional/regression/pearson.py index 86dd79a0987..7ce53b3a4fe 100644 --- a/torchmetrics/functional/regression/pearson.py +++ b/torchmetrics/functional/regression/pearson.py @@ -80,7 +80,7 @@ def _pearson_corrcoef_compute(c: Tensor, nobs: Tensor): y_var = c[1, 1] cov = c[0, 1] corrcoef = cov / (x_var * y_var).sqrt() - return torch.clip(corrcoef, -1.0, 1.0) + return torch.clamp(corrcoef, -1.0, 1.0) def pearson_corrcoef(preds: Tensor, target: Tensor) -> Tensor: From ca7cb99a24463c13c360665b9d2c147c944b2dbf Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Tue, 6 Apr 2021 09:54:16 +0200 Subject: [PATCH 06/21] suggestions --- tests/regression/test_pearson.py | 24 +++++++++---------- torchmetrics/functional/regression/pearson.py | 4 ++++ torchmetrics/regression/pearson.py | 2 +- 3 files changed, 17 insertions(+), 13 deletions(-) diff --git a/tests/regression/test_pearson.py b/tests/regression/test_pearson.py index d59b88773d1..0e0642355dd 100644 --- a/tests/regression/test_pearson.py +++ b/tests/regression/test_pearson.py @@ -82,26 +82,26 @@ class TestPearsonCorrcoef(MetricTester): @pytest.mark.parametrize("ddp", [True, False]) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_explained_variance(self, preds, target, ddp, dist_sync_on_step): + def test_pearson_corrcoef(self, preds, target, ddp, dist_sync_on_step): self.run_class_metric_test( - ddp, - preds, - target, - PearsonCorrcoef, - _sk_metric, - dist_sync_on_step, + ddp=ddp, + preds=preds, + target=target, + metric_class=PearsonCorrcoef, + sk_metric=_sk_metric, + dist_sync_on_step=dist_sync_on_step, ) def test_pearson_corrcoef_functional(self, preds, target): self.run_functional_metric_test( - preds, - target, - pearson_corrcoef, - _sk_metric + preds=preds, + target=target, + metric_functional=pearson_corrcoef, + sk_metric=_sk_metric ) # Pearson half + cpu does not work due to missing support in torch.sqrt - @pytest.mark.xfail(reason="PSNR metric does not support cpu + half precision") + @pytest.mark.xfail(reason="PearsonCorrcoef metric does not support cpu + half precision") def test_pearson_corrcoef_half_cpu(self, preds, target): self.run_precision_test_cpu(preds, target, PearsonCorrcoef, pearson_corrcoef) diff --git a/torchmetrics/functional/regression/pearson.py b/torchmetrics/functional/regression/pearson.py index 7ce53b3a4fe..d87b994ceb7 100644 --- a/torchmetrics/functional/regression/pearson.py +++ b/torchmetrics/functional/regression/pearson.py @@ -52,6 +52,9 @@ def _pearson_corrcoef_update( old_cov: Optional[Tensor] = None, old_nobs: Optional[Tensor] = None ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + """ updates current estimates of the mean, cov and n_obs with new data for calculating + pearsons correlation + """ # Data checking _check_same_shape(preds, target) preds = preds.squeeze() @@ -75,6 +78,7 @@ def _pearson_corrcoef_update( def _pearson_corrcoef_compute(c: Tensor, nobs: Tensor): + """ computes the final pearson correlation based on covariance matrix and number of observatiosn """ c /= (nobs - 1) x_var = c[0, 0] y_var = c[1, 1] diff --git a/torchmetrics/regression/pearson.py b/torchmetrics/regression/pearson.py index 8598c86117a..9275869dbc5 100644 --- a/torchmetrics/regression/pearson.py +++ b/torchmetrics/regression/pearson.py @@ -61,7 +61,7 @@ def __init__( compute_on_step: bool = True, dist_sync_on_step: bool = False, process_group: Optional[Any] = None, - dist_sync_fn: Callable = None, + dist_sync_fn: Optional[Callable] = None, ): super().__init__( compute_on_step=compute_on_step, From 1188d0d82ac895ede50d90381730b69689468f66 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Tue, 6 Apr 2021 10:46:18 +0200 Subject: [PATCH 07/21] rename --- tests/regression/test_pearson.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/regression/test_pearson.py b/tests/regression/test_pearson.py index 0e0642355dd..1f05d4f958e 100644 --- a/tests/regression/test_pearson.py +++ b/tests/regression/test_pearson.py @@ -67,7 +67,7 @@ def _mean_cov(data): ) -def _sk_metric(preds, target): +def _sk_pearsonr(preds, target): sk_preds = preds.view(-1).numpy() sk_target = target.view(-1).numpy() return pearsonr(sk_target, sk_preds)[0] From df47c7fea5833d8e39d7e750fee2daf7bfc9ae31 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Tue, 6 Apr 2021 10:47:31 +0200 Subject: [PATCH 08/21] format --- torchmetrics/functional/classification/f_beta.py | 2 +- torchmetrics/functional/regression/pearson.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/torchmetrics/functional/classification/f_beta.py b/torchmetrics/functional/classification/f_beta.py index 70e0687e2da..e477cc51c4b 100644 --- a/torchmetrics/functional/classification/f_beta.py +++ b/torchmetrics/functional/classification/f_beta.py @@ -21,7 +21,7 @@ from torchmetrics.utilities.enums import AverageMethod, MDMCAverageMethod -def _safe_divide(num: torch.Tensor, denom: torch.Tensor): +def _safe_divide(num: Tensor, denom: Tensor): """ prevent zero division """ denom[denom == 0.] = 1 return num / denom diff --git a/torchmetrics/functional/regression/pearson.py b/torchmetrics/functional/regression/pearson.py index d87b994ceb7..ac11514c75f 100644 --- a/torchmetrics/functional/regression/pearson.py +++ b/torchmetrics/functional/regression/pearson.py @@ -19,7 +19,7 @@ from torchmetrics.utilities.checks import _check_same_shape -def _update_mean(old_mean: torch.Tensor, old_nobs: torch.Tensor, data: torch.Tensor) -> torch.Tensor: +def _update_mean(old_mean: Tensor, old_nobs: Tensor, data: Tensor) -> Tensor: """ Update a mean estimate given new data Args: old_mean: current mean estimate @@ -32,7 +32,7 @@ def _update_mean(old_mean: torch.Tensor, old_nobs: torch.Tensor, data: torch.Ten return (old_mean * old_nobs + data.mean(dim=0) * data_size) / (old_nobs + data_size) -def _update_cov(old_cov: torch.Tensor, old_mean: torch.Tensor, new_mean: torch.Tensor, data: torch.Tensor): +def _update_cov(old_cov: Tensor, old_mean: Tensor, new_mean: Tensor, data: Tensor): """ Update a covariance estimate given new data Args: old_cov: current covariance estimate From 6b0795c6780065e247946ffa4f4fd186da4819b8 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Tue, 6 Apr 2021 11:04:01 +0200 Subject: [PATCH 09/21] _sk_pearsonr --- tests/regression/test_pearson.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/regression/test_pearson.py b/tests/regression/test_pearson.py index 1f05d4f958e..4152965381e 100644 --- a/tests/regression/test_pearson.py +++ b/tests/regression/test_pearson.py @@ -88,7 +88,7 @@ def test_pearson_corrcoef(self, preds, target, ddp, dist_sync_on_step): preds=preds, target=target, metric_class=PearsonCorrcoef, - sk_metric=_sk_metric, + sk_metric=_sk_pearsonr, dist_sync_on_step=dist_sync_on_step, ) @@ -97,7 +97,7 @@ def test_pearson_corrcoef_functional(self, preds, target): preds=preds, target=target, metric_functional=pearson_corrcoef, - sk_metric=_sk_metric + sk_metric=_sk_pearsonr ) # Pearson half + cpu does not work due to missing support in torch.sqrt From e939d278c2c6c85cba3368fa4e2a3e68738223c7 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Tue, 6 Apr 2021 11:13:17 +0200 Subject: [PATCH 10/21] inline --- torchmetrics/functional/regression/pearson.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/torchmetrics/functional/regression/pearson.py b/torchmetrics/functional/regression/pearson.py index ac11514c75f..25636388231 100644 --- a/torchmetrics/functional/regression/pearson.py +++ b/torchmetrics/functional/regression/pearson.py @@ -63,12 +63,9 @@ def _pearson_corrcoef_update( raise ValueError('Expected both predictions and target to be 1 dimensional tensors.') data = torch.stack([preds, target], dim=1) - if old_mean is None: - old_mean = 0 - if old_cov is None: - old_cov = 0 - if old_nobs is None: - old_nobs = 0 + old_mean = 0 if old_mean is None else old_mean + old_cov = 0 if old_cov is None else old_cov + old_nobs = 0 if old_nobs is None else old_nobs new_mean = _update_mean(old_mean, old_nobs, data) new_cov = _update_cov(old_cov, old_mean, new_mean, data) From 24198009ab1b2f8b424717c7d378db61e0a43adf Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Tue, 6 Apr 2021 19:15:11 +0200 Subject: [PATCH 11/21] fix sync --- torchmetrics/regression/pearson.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/torchmetrics/regression/pearson.py b/torchmetrics/regression/pearson.py index 9275869dbc5..264f92318ac 100644 --- a/torchmetrics/regression/pearson.py +++ b/torchmetrics/regression/pearson.py @@ -43,9 +43,6 @@ class PearsonCorrcoef(Metric): before returning the value at the step. default: False process_group: Specify the process group on which synchronization is called. default: None (which selects the entire world) - dist_sync_fn: - Callback that performs the allgather operation on the metric state. When ``None``, DDP - will be used to perform the allgather Example: >>> from torchmetrics import PearsonCorrcoef @@ -61,13 +58,11 @@ def __init__( compute_on_step: bool = True, dist_sync_on_step: bool = False, process_group: Optional[Any] = None, - dist_sync_fn: Optional[Callable] = None, ): super().__init__( compute_on_step=compute_on_step, dist_sync_on_step=dist_sync_on_step, process_group=process_group, - dist_sync_fn=dist_sync_fn, ) self.add_state("cov", default=torch.zeros(2, 2), dist_reduce_fx="sum") self.add_state("mean", default=torch.zeros(2), dist_reduce_fx="sum") From 7fa6530c24dc04102d72aef51b3e5f94d29bda1b Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Tue, 13 Apr 2021 11:31:54 +0200 Subject: [PATCH 12/21] fix tests --- tests/regression/test_pearson.py | 33 +----------- torchmetrics/__init__.py | 2 +- torchmetrics/functional/regression/pearson.py | 54 ++++--------------- torchmetrics/regression/pearson.py | 24 ++++++--- 4 files changed, 29 insertions(+), 84 deletions(-) diff --git a/tests/regression/test_pearson.py b/tests/regression/test_pearson.py index 4152965381e..ec66fd459e9 100644 --- a/tests/regression/test_pearson.py +++ b/tests/regression/test_pearson.py @@ -19,41 +19,12 @@ from tests.helpers import seed_all from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester -from torchmetrics.functional.regression.pearson import _update_cov, _update_mean, pearson_corrcoef +from torchmetrics.functional.regression.pearson import pearson_corrcoef from torchmetrics.regression.pearson import PearsonCorrcoef seed_all(42) -def test_update_functions(tmpdir): - """ Test that updating the estimates are equal to estimating them on all data """ - data = torch.randn(100, 2) - batch1, batch2 = data.chunk(2) - - def _mean_cov(data): - mean = data.mean(0) - diff = data - mean - cov = diff.T @ diff - return mean, cov - - mean_update, cov_update, size_update = torch.zeros(2), torch.zeros(2, 2), torch.zeros(1) - for batch in [batch1, batch2]: - new_mean = _update_mean(mean_update, size_update, batch) - new_cov = _update_cov(cov_update, mean_update, new_mean, batch) - - assert not torch.allclose(new_mean, mean_update), "mean estimate did not update" - assert not torch.allclose(new_cov, cov_update), "covariance estimate did not update" - - size_update += batch.shape[0] - mean_update = new_mean - cov_update = new_cov - - mean, cov = _mean_cov(data) - - assert torch.allclose(mean, mean_update), "updated mean does not correspond to mean of all data" - assert torch.allclose(cov, cov_update), "updated covariance does not correspond to covariance of all data" - - Input = namedtuple('Input', ["preds", "target"]) _single_target_inputs1 = Input( @@ -78,8 +49,6 @@ def _sk_pearsonr(preds, target): (_single_target_inputs2.preds, _single_target_inputs2.target), ]) class TestPearsonCorrcoef(MetricTester): - atol = 1e-4 - @pytest.mark.parametrize("ddp", [True, False]) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) def test_pearson_corrcoef(self, preds, target, ddp, dist_sync_on_step): diff --git a/torchmetrics/__init__.py b/torchmetrics/__init__.py index 9cab9e0cf2b..e115f00af18 100644 --- a/torchmetrics/__init__.py +++ b/torchmetrics/__init__.py @@ -41,13 +41,13 @@ from torchmetrics.collections import MetricCollection # noqa: F401 E402 from torchmetrics.metric import Metric # noqa: F401 E402 from torchmetrics.regression import ( # noqa: F401 E402 - PearsonCorrcoef, PSNR, SSIM, ExplainedVariance, MeanAbsoluteError, MeanSquaredError, MeanSquaredLogError, + PearsonCorrcoef, R2Score, ) from torchmetrics.retrieval import ( # noqa: F401 E402 diff --git a/torchmetrics/functional/regression/pearson.py b/torchmetrics/functional/regression/pearson.py index 25636388231..d0bd78c948a 100644 --- a/torchmetrics/functional/regression/pearson.py +++ b/torchmetrics/functional/regression/pearson.py @@ -19,32 +19,6 @@ from torchmetrics.utilities.checks import _check_same_shape -def _update_mean(old_mean: Tensor, old_nobs: Tensor, data: Tensor) -> Tensor: - """ Update a mean estimate given new data - Args: - old_mean: current mean estimate - old_nobs: number of observation until now - data: data used for updating the estimate - Returns: - new_mean: updated mean estimate - """ - data_size = data.shape[0] - return (old_mean * old_nobs + data.mean(dim=0) * data_size) / (old_nobs + data_size) - - -def _update_cov(old_cov: Tensor, old_mean: Tensor, new_mean: Tensor, data: Tensor): - """ Update a covariance estimate given new data - Args: - old_cov: current covariance estimate - old_mean: current mean estimate - new_mean: updated mean estimate - data: data used for updating the estimate - Returns: - new_mean: updated covariance estimate - """ - return old_cov + (data - new_mean).T @ (data - old_mean) - - def _pearson_corrcoef_update( preds: Tensor, target: Tensor, @@ -61,26 +35,20 @@ def _pearson_corrcoef_update( target = target.squeeze() if preds.ndim > 1 or target.ndim > 1: raise ValueError('Expected both predictions and target to be 1 dimensional tensors.') - data = torch.stack([preds, target], dim=1) - old_mean = 0 if old_mean is None else old_mean - old_cov = 0 if old_cov is None else old_cov - old_nobs = 0 if old_nobs is None else old_nobs + return preds, target - new_mean = _update_mean(old_mean, old_nobs, data) - new_cov = _update_cov(old_cov, old_mean, new_mean, data) - new_size = old_nobs + preds.numel() - return new_mean, new_cov, new_size +def _pearson_corrcoef_compute(preds: Tensor, target: Tensor) -> Tensor: + """ computes the final pearson correlation based on covariance matrix and number of observatiosn """ + preds_diff = preds - preds.mean() + target_diff = target - target.mean() + cov = (preds_diff * target_diff).mean() + preds_std = torch.sqrt((preds_diff * preds_diff).mean()) + target_std = torch.sqrt((target_diff * target_diff).mean()) -def _pearson_corrcoef_compute(c: Tensor, nobs: Tensor): - """ computes the final pearson correlation based on covariance matrix and number of observatiosn """ - c /= (nobs - 1) - x_var = c[0, 0] - y_var = c[1, 1] - cov = c[0, 1] - corrcoef = cov / (x_var * y_var).sqrt() + corrcoef = cov / (preds_std * target_std) return torch.clamp(corrcoef, -1.0, 1.0) @@ -99,5 +67,5 @@ def pearson_corrcoef(preds: Tensor, target: Tensor) -> Tensor: >>> pearson_corrcoef(preds, target) tensor(0.9849) """ - _, c, nobs = _pearson_corrcoef_update(preds, target) - return _pearson_corrcoef_compute(c, nobs) + preds, target = _pearson_corrcoef_update(preds, target) + return _pearson_corrcoef_compute(preds, target) diff --git a/torchmetrics/regression/pearson.py b/torchmetrics/regression/pearson.py index 264f92318ac..6f497c6340f 100644 --- a/torchmetrics/regression/pearson.py +++ b/torchmetrics/regression/pearson.py @@ -11,13 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Optional +from typing import Any, Optional import torch from torch import Tensor from torchmetrics.functional.regression.pearson import _pearson_corrcoef_compute, _pearson_corrcoef_update from torchmetrics.metric import Metric +from torchmetrics.utilities import rank_zero_warn class PearsonCorrcoef(Metric): @@ -64,9 +65,14 @@ def __init__( dist_sync_on_step=dist_sync_on_step, process_group=process_group, ) - self.add_state("cov", default=torch.zeros(2, 2), dist_reduce_fx="sum") - self.add_state("mean", default=torch.zeros(2), dist_reduce_fx="sum") - self.add_state("n_obs", default=torch.zeros(1), dist_reduce_fx="sum") + + rank_zero_warn( + 'Metric `PearsonCorrcoef` will save all targets and predictions in buffer.' + ' For large datasets this may lead to large memory footprint.' + ) + + self.add_state("preds", default=[], dist_reduce_fx=None) + self.add_state("target", default=[], dist_reduce_fx=None) def update(self, preds: Tensor, target: Tensor): """ @@ -76,12 +82,14 @@ def update(self, preds: Tensor, target: Tensor): preds: Predictions from model target: Ground truth values """ - self.mean, self.cov, self.n_obs = _pearson_corrcoef_update( - preds, target, self.mean, self.cov, self.n_obs - ) + preds, target = _pearson_corrcoef_update(preds, target) + self.preds.append(preds) + self.target.append(target) def compute(self): """ Computes pearson correlation coefficient over state. """ - return _pearson_corrcoef_compute(self.cov, self.n_obs) + preds = torch.cat(self.preds, dim=0) + target = torch.cat(self.target, dim=0) + return _pearson_corrcoef_compute(preds, target) From 61d2d900264fbd4023a85d0c4350d1ad666df132 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Tue, 13 Apr 2021 12:24:10 +0200 Subject: [PATCH 13/21] fix docs --- torchmetrics/regression/pearson.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchmetrics/regression/pearson.py b/torchmetrics/regression/pearson.py index 6f497c6340f..15efce09330 100644 --- a/torchmetrics/regression/pearson.py +++ b/torchmetrics/regression/pearson.py @@ -26,7 +26,8 @@ class PearsonCorrcoef(Metric): Computes `pearson correlation coefficient `_: - .. math:: \text{P_corr}(x,y) = \frac{cov(x,y)}{\sigma_x \times \sigma_y} + .. math:: + P_{corr}(x,y) = \frac{cov(x,y)}{\sigma_x \sigma_y} Where :math:`y` is a tensor of target values, and :math:`x` is a tensor of predictions. From eab5638cae77aa3ea2f8cdf6d1c2f36cc5fd6c3c Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Tue, 13 Apr 2021 13:43:22 +0200 Subject: [PATCH 14/21] Apply suggestions from code review --- torchmetrics/functional/regression/pearson.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchmetrics/functional/regression/pearson.py b/torchmetrics/functional/regression/pearson.py index d0bd78c948a..8911d27bf92 100644 --- a/torchmetrics/functional/regression/pearson.py +++ b/torchmetrics/functional/regression/pearson.py @@ -39,7 +39,7 @@ def _pearson_corrcoef_update( return preds, target -def _pearson_corrcoef_compute(preds: Tensor, target: Tensor) -> Tensor: +def _pearson_corrcoef_compute(preds: Tensor, target: Tensor, eps: float=1e-6) -> Tensor: """ computes the final pearson correlation based on covariance matrix and number of observatiosn """ preds_diff = preds - preds.mean() target_diff = target - target.mean() @@ -48,7 +48,7 @@ def _pearson_corrcoef_compute(preds: Tensor, target: Tensor) -> Tensor: preds_std = torch.sqrt((preds_diff * preds_diff).mean()) target_std = torch.sqrt((target_diff * target_diff).mean()) - corrcoef = cov / (preds_std * target_std) + corrcoef = cov / (preds_std * target_std + eps) return torch.clamp(corrcoef, -1.0, 1.0) From dba602bac8b3439701496d789a1b48ca1da92823 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Tue, 13 Apr 2021 13:44:02 +0200 Subject: [PATCH 15/21] Update torchmetrics/functional/regression/pearson.py --- torchmetrics/functional/regression/pearson.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmetrics/functional/regression/pearson.py b/torchmetrics/functional/regression/pearson.py index 8911d27bf92..dabb203569e 100644 --- a/torchmetrics/functional/regression/pearson.py +++ b/torchmetrics/functional/regression/pearson.py @@ -39,7 +39,7 @@ def _pearson_corrcoef_update( return preds, target -def _pearson_corrcoef_compute(preds: Tensor, target: Tensor, eps: float=1e-6) -> Tensor: +def _pearson_corrcoef_compute(preds: Tensor, target: Tensor, eps: float = 1e-6) -> Tensor: """ computes the final pearson correlation based on covariance matrix and number of observatiosn """ preds_diff = preds - preds.mean() target_diff = target - target.mean() From d41e551616ca092c5313cae749e2483a3662fbf0 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Tue, 13 Apr 2021 14:26:19 +0200 Subject: [PATCH 16/21] atol --- tests/regression/test_pearson.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/regression/test_pearson.py b/tests/regression/test_pearson.py index ec66fd459e9..b2630d62959 100644 --- a/tests/regression/test_pearson.py +++ b/tests/regression/test_pearson.py @@ -49,6 +49,8 @@ def _sk_pearsonr(preds, target): (_single_target_inputs2.preds, _single_target_inputs2.target), ]) class TestPearsonCorrcoef(MetricTester): + atol = 1e-2 + @pytest.mark.parametrize("ddp", [True, False]) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) def test_pearson_corrcoef(self, preds, target, ddp, dist_sync_on_step): From 402ac9f82b1df9802cd48698e11e98a5cf75e335 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Tue, 13 Apr 2021 14:31:14 +0200 Subject: [PATCH 17/21] update --- torchmetrics/functional/regression/pearson.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/torchmetrics/functional/regression/pearson.py b/torchmetrics/functional/regression/pearson.py index dabb203569e..5c22ddad198 100644 --- a/torchmetrics/functional/regression/pearson.py +++ b/torchmetrics/functional/regression/pearson.py @@ -47,8 +47,13 @@ def _pearson_corrcoef_compute(preds: Tensor, target: Tensor, eps: float = 1e-6) cov = (preds_diff * target_diff).mean() preds_std = torch.sqrt((preds_diff * preds_diff).mean()) target_std = torch.sqrt((target_diff * target_diff).mean()) + + denom = preds_std * target_std + # prevent division by zero + if denom == 0: + denom += eps - corrcoef = cov / (preds_std * target_std + eps) + corrcoef = cov / denom return torch.clamp(corrcoef, -1.0, 1.0) From 0db3f4ce7a4e77b0400ad73a8682d3bb6aebbd4a Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Tue, 13 Apr 2021 14:35:57 +0200 Subject: [PATCH 18/21] pep8 --- tests/regression/test_pearson.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/regression/test_pearson.py b/tests/regression/test_pearson.py index b2630d62959..d01952ee38c 100644 --- a/tests/regression/test_pearson.py +++ b/tests/regression/test_pearson.py @@ -50,7 +50,7 @@ def _sk_pearsonr(preds, target): ]) class TestPearsonCorrcoef(MetricTester): atol = 1e-2 - + @pytest.mark.parametrize("ddp", [True, False]) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) def test_pearson_corrcoef(self, preds, target, ddp, dist_sync_on_step): From 40ab5fac9c8644e6207bee1c58f55e9f731590c0 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Tue, 13 Apr 2021 14:38:42 +0200 Subject: [PATCH 19/21] pep8 --- torchmetrics/functional/regression/pearson.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmetrics/functional/regression/pearson.py b/torchmetrics/functional/regression/pearson.py index 5c22ddad198..5382c88bb55 100644 --- a/torchmetrics/functional/regression/pearson.py +++ b/torchmetrics/functional/regression/pearson.py @@ -47,7 +47,7 @@ def _pearson_corrcoef_compute(preds: Tensor, target: Tensor, eps: float = 1e-6) cov = (preds_diff * target_diff).mean() preds_std = torch.sqrt((preds_diff * preds_diff).mean()) target_std = torch.sqrt((target_diff * target_diff).mean()) - + denom = preds_std * target_std # prevent division by zero if denom == 0: From 662b22ce9ea2fefdfe6acbfad7dd704983760748 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Tue, 13 Apr 2021 16:06:57 +0200 Subject: [PATCH 20/21] chlog --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7f924191b54..0721ac183f0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added other metrics: * Added `CohenKappa` ([#69](https://github.com/PyTorchLightning/metrics/pull/69)) * Added `MatthewsCorrcoef` ([#98](https://github.com/PyTorchLightning/metrics/pull/98)) + * Added `PearsonCorrcoef` metric ([#157](https://github.com/PyTorchLightning/metrics/pull/157)) * Added `Hinge` ([#120](https://github.com/PyTorchLightning/metrics/pull/120)) - Added `average='micro'` as an option in AUROC for multilabel problems ([#110](https://github.com/PyTorchLightning/metrics/pull/110)) - Added multilabel support to `ROC` metric ([#114](https://github.com/PyTorchLightning/metrics/pull/114)) @@ -28,7 +29,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ) - Added `prefix` argument to `MetricCollection` ([#70](https://github.com/PyTorchLightning/metrics/pull/70)) - Added `__getitem__` as metric arithmetic operation ([#142](https://github.com/PyTorchLightning/metrics/pull/142)) -- Added `PearsonCorrcoef` metric ([#157](https://github.com/PyTorchLightning/metrics/pull/157)) ### Changed From bead3ed8f73515818c5b3580b1a3afb559d0c7aa Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Tue, 13 Apr 2021 16:07:15 +0200 Subject: [PATCH 21/21] . --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0721ac183f0..7d73cba87cd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,7 +20,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added other metrics: * Added `CohenKappa` ([#69](https://github.com/PyTorchLightning/metrics/pull/69)) * Added `MatthewsCorrcoef` ([#98](https://github.com/PyTorchLightning/metrics/pull/98)) - * Added `PearsonCorrcoef` metric ([#157](https://github.com/PyTorchLightning/metrics/pull/157)) + * Added `PearsonCorrcoef` ([#157](https://github.com/PyTorchLightning/metrics/pull/157)) * Added `Hinge` ([#120](https://github.com/PyTorchLightning/metrics/pull/120)) - Added `average='micro'` as an option in AUROC for multilabel problems ([#110](https://github.com/PyTorchLightning/metrics/pull/110)) - Added multilabel support to `ROC` metric ([#114](https://github.com/PyTorchLightning/metrics/pull/114))