Add Pearson correlation coefficient #157

Merged · 26 commits · Apr 13, 2021
Changes from 18 commits
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -28,6 +28,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0)
- Added `prefix` argument to `MetricCollection` ([#70](https://github.com/PyTorchLightning/metrics/pull/70))
- Added `__getitem__` as metric arithmetic operation ([#142](https://github.com/PyTorchLightning/metrics/pull/142))
- Added `PearsonCorrcoef` metric ([#157](https://github.com/PyTorchLightning/metrics/pull/157))


### Changed

7 changes: 7 additions & 0 deletions docs/source/references/functional.rst
@@ -196,6 +196,13 @@ mean_squared_log_error [func]
:noindex:


pearson_corrcoef [func]
~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: torchmetrics.functional.pearson_corrcoef
:noindex:


psnr [func]
~~~~~~~~~~~

7 changes: 7 additions & 0 deletions docs/source/references/modules.rst
@@ -250,6 +250,13 @@ MeanSquaredLogError
:noindex:


PearsonCorrcoef
~~~~~~~~~~~~~~~

.. autoclass:: torchmetrics.PearsonCorrcoef
:noindex:


PSNR
~~~~

88 changes: 88 additions & 0 deletions tests/regression/test_pearson.py
@@ -0,0 +1,88 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import namedtuple

import pytest
import torch
from scipy.stats import pearsonr

from tests.helpers import seed_all
from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester
from torchmetrics.functional.regression.pearson import pearson_corrcoef
from torchmetrics.regression.pearson import PearsonCorrcoef

seed_all(42)


Input = namedtuple('Input', ["preds", "target"])

_single_target_inputs1 = Input(
preds=torch.rand(NUM_BATCHES, BATCH_SIZE),
target=torch.rand(NUM_BATCHES, BATCH_SIZE),
)

_single_target_inputs2 = Input(
preds=torch.randn(NUM_BATCHES, BATCH_SIZE),
target=torch.randn(NUM_BATCHES, BATCH_SIZE),
)


def _sk_pearsonr(preds, target):
sk_preds = preds.view(-1).numpy()
sk_target = target.view(-1).numpy()
return pearsonr(sk_target, sk_preds)[0]


@pytest.mark.parametrize("preds, target", [
(_single_target_inputs1.preds, _single_target_inputs1.target),
(_single_target_inputs2.preds, _single_target_inputs2.target),
])
class TestPearsonCorrcoef(MetricTester):
@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
def test_pearson_corrcoef(self, preds, target, ddp, dist_sync_on_step):
self.run_class_metric_test(
ddp=ddp,
preds=preds,
target=target,
metric_class=PearsonCorrcoef,
sk_metric=_sk_pearsonr,
dist_sync_on_step=dist_sync_on_step,
)

def test_pearson_corrcoef_functional(self, preds, target):
self.run_functional_metric_test(
preds=preds,
target=target,
metric_functional=pearson_corrcoef,
sk_metric=_sk_pearsonr
)

# Pearson half + cpu does not work due to missing support in torch.sqrt
@pytest.mark.xfail(reason="PearsonCorrcoef metric does not support cpu + half precision")
def test_pearson_corrcoef_half_cpu(self, preds, target):
self.run_precision_test_cpu(preds, target, PearsonCorrcoef, pearson_corrcoef)

@pytest.mark.skipif(not torch.cuda.is_available(), reason='test requires cuda')
def test_pearson_corrcoef_half_gpu(self, preds, target):
self.run_precision_test_gpu(preds, target, PearsonCorrcoef, pearson_corrcoef)


def test_error_on_different_shape():
metric = PearsonCorrcoef()
with pytest.raises(RuntimeError, match='Predictions and targets are expected to have the same shape'):
metric(torch.randn(100, ), torch.randn(50, ))

with pytest.raises(ValueError, match='Expected both predictions and target to be 1 dimensional tensors.'):
metric(torch.randn(100, 2), torch.randn(100, 2))
1 change: 1 addition & 0 deletions torchmetrics/__init__.py
@@ -47,6 +47,7 @@
MeanAbsoluteError,
MeanSquaredError,
MeanSquaredLogError,
PearsonCorrcoef,
R2Score,
)
from torchmetrics.retrieval import ( # noqa: F401 E402
1 change: 1 addition & 0 deletions torchmetrics/functional/__init__.py
@@ -34,6 +34,7 @@
from torchmetrics.functional.regression.mean_relative_error import mean_relative_error # noqa: F401
from torchmetrics.functional.regression.mean_squared_error import mean_squared_error # noqa: F401
from torchmetrics.functional.regression.mean_squared_log_error import mean_squared_log_error # noqa: F401
from torchmetrics.functional.regression.pearson import pearson_corrcoef # noqa: F401
from torchmetrics.functional.regression.psnr import psnr # noqa: F401
from torchmetrics.functional.regression.r2score import r2score # noqa: F401
from torchmetrics.functional.regression.ssim import ssim # noqa: F401
2 changes: 1 addition & 1 deletion torchmetrics/functional/classification/f_beta.py
@@ -22,7 +22,7 @@
from torchmetrics.utilities.enums import AverageMethod, MDMCAverageMethod


def _safe_divide(num: torch.Tensor, denom: torch.Tensor):
def _safe_divide(num: Tensor, denom: Tensor):
""" prevent zero division """
denom[denom == 0.] = 1
return num / denom
1 change: 1 addition & 0 deletions torchmetrics/functional/regression/__init__.py
@@ -15,6 +15,7 @@
from torchmetrics.functional.regression.mean_absolute_error import mean_absolute_error # noqa: F401
from torchmetrics.functional.regression.mean_squared_error import mean_squared_error # noqa: F401
from torchmetrics.functional.regression.mean_squared_log_error import mean_squared_log_error # noqa: F401
from torchmetrics.functional.regression.pearson import pearson_corrcoef # noqa: F401
from torchmetrics.functional.regression.psnr import psnr # noqa: F401
from torchmetrics.functional.regression.r2score import r2score # noqa: F401
from torchmetrics.functional.regression.ssim import ssim # noqa: F401
71 changes: 71 additions & 0 deletions torchmetrics/functional/regression/pearson.py
@@ -0,0 +1,71 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple

import torch
from torch import Tensor

from torchmetrics.utilities.checks import _check_same_shape


def _pearson_corrcoef_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
""" Checks that ``preds`` and ``target`` have the same shape and are 1-dimensional (after squeezing),
and returns the squeezed tensors used to compute the Pearson correlation.
"""
# Data checking
_check_same_shape(preds, target)
preds = preds.squeeze()
target = target.squeeze()
if preds.ndim > 1 or target.ndim > 1:
raise ValueError('Expected both predictions and target to be 1 dimensional tensors.')

return preds, target


def _pearson_corrcoef_compute(preds: Tensor, target: Tensor, eps: float = 1e-6) -> Tensor:
""" Computes the final Pearson correlation coefficient from the collected predictions and targets. """
preds_diff = preds - preds.mean()
target_diff = target - target.mean()

cov = (preds_diff * target_diff).mean()
preds_std = torch.sqrt((preds_diff * preds_diff).mean())
target_std = torch.sqrt((target_diff * target_diff).mean())

corrcoef = cov / (preds_std * target_std + eps)
return torch.clamp(corrcoef, -1.0, 1.0)


def pearson_corrcoef(preds: Tensor, target: Tensor) -> Tensor:
"""
Computes the Pearson correlation coefficient.

Args:
preds: estimated scores
target: ground truth scores

Example:
>>> from torchmetrics.functional import pearson_corrcoef
>>> target = torch.tensor([3, -0.5, 2, 7])
>>> preds = torch.tensor([2.5, 0.0, 2, 8])
>>> pearson_corrcoef(preds, target)
tensor(0.9849)
"""
preds, target = _pearson_corrcoef_update(preds, target)
return _pearson_corrcoef_compute(preds, target)
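
As a quick cross-check of the functional metric added above, here is a minimal sketch comparing it against `scipy.stats.pearsonr`, the same reference implementation used by the tests in this PR. The random inputs and the tolerance are illustrative assumptions, not part of the diff.

import torch
from scipy.stats import pearsonr

from torchmetrics.functional import pearson_corrcoef

# illustrative 1-dimensional inputs
preds = torch.randn(100)
target = torch.randn(100)

tm_result = pearson_corrcoef(preds, target)

# scipy returns (correlation, p-value); only the correlation is compared
sk_result = pearsonr(target.numpy(), preds.numpy())[0]

# agreement up to float32 precision and the small eps added to the denominator
assert torch.allclose(tm_result, torch.tensor(sk_result, dtype=tm_result.dtype), atol=1e-4)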
1 change: 1 addition & 0 deletions torchmetrics/regression/__init__.py
@@ -15,6 +15,7 @@
from torchmetrics.regression.mean_absolute_error import MeanAbsoluteError # noqa: F401
from torchmetrics.regression.mean_squared_error import MeanSquaredError # noqa: F401
from torchmetrics.regression.mean_squared_log_error import MeanSquaredLogError # noqa: F401
from torchmetrics.regression.pearson import PearsonCorrcoef # noqa: F401
from torchmetrics.regression.psnr import PSNR # noqa: F401
from torchmetrics.regression.r2score import R2Score # noqa: F401
from torchmetrics.regression.ssim import SSIM # noqa: F401
96 changes: 96 additions & 0 deletions torchmetrics/regression/pearson.py
@@ -0,0 +1,96 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional

import torch
from torch import Tensor

from torchmetrics.functional.regression.pearson import _pearson_corrcoef_compute, _pearson_corrcoef_update
from torchmetrics.metric import Metric
from torchmetrics.utilities import rank_zero_warn


class PearsonCorrcoef(Metric):
r"""
Computes the `Pearson correlation coefficient
<https://en.wikipedia.org/wiki/Pearson_correlation_coefficient>`_:

.. math::
P_{corr}(x,y) = \frac{cov(x,y)}{\sigma_x \sigma_y}

where :math:`y` is a tensor of target values, and :math:`x` is a
tensor of predictions.

Forward accepts

- ``preds`` (float tensor): ``(N,)``
- ``target`` (float tensor): ``(N,)``

Args:
compute_on_step:
Forward only calls ``update()`` and returns ``None`` if this is set to ``False``. default: True
dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step. default: False
process_group:
Specify the process group on which synchronization is called. default: None (which selects the entire world)

Example:
>>> from torchmetrics import PearsonCorrcoef
>>> target = torch.tensor([3, -0.5, 2, 7])
>>> preds = torch.tensor([2.5, 0.0, 2, 8])
>>> pearson = PearsonCorrcoef()
>>> pearson(preds, target)
tensor(0.9849)

"""
def __init__(
self,
compute_on_step: bool = True,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
):
super().__init__(
compute_on_step=compute_on_step,
dist_sync_on_step=dist_sync_on_step,
process_group=process_group,
)

rank_zero_warn(
'Metric `PearsonCorrcoef` will save all targets and predictions in buffer.'
' For large datasets this may lead to large memory footprint.'
)

self.add_state("preds", default=[], dist_reduce_fx=None)
self.add_state("target", default=[], dist_reduce_fx=None)

def update(self, preds: Tensor, target: Tensor):
"""
Update state with predictions and targets.

Args:
preds: Predictions from model
target: Ground truth values
"""
preds, target = _pearson_corrcoef_update(preds, target)
self.preds.append(preds)
self.target.append(target)

def compute(self):
"""
Computes the Pearson correlation coefficient over the accumulated state.
"""
preds = torch.cat(self.preds, dim=0)
target = torch.cat(self.target, dim=0)
return _pearson_corrcoef_compute(preds, target)
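
For context, a minimal usage sketch of the module interface introduced above, accumulating state over several batches before computing the coefficient. The batch count and sizes are illustrative assumptions.

import torch
from torchmetrics import PearsonCorrcoef

pearson = PearsonCorrcoef()

# update() buffers every preds/target pair, hence the memory-footprint warning emitted at init
for _ in range(4):
    preds = torch.randn(32)
    target = torch.randn(32)
    pearson.update(preds, target)

# compute() concatenates the buffers and evaluates the correlation over all samples seen so far
result = pearson.compute()
pearson.reset()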