From f535fe5db0e023227479c2c271236e3674e7649b Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Wed, 28 Jun 2023 08:53:39 +0200 Subject: [PATCH 1/2] fixes --- .../classification/matthews_corrcoef.py | 6 +++--- .../classification/test_matthews_corrcoef.py | 18 ++++++++++++++++-- 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/src/torchmetrics/functional/classification/matthews_corrcoef.py b/src/torchmetrics/functional/classification/matthews_corrcoef.py index 3e2f3c63149..62bb480c9ba 100644 --- a/src/torchmetrics/functional/classification/matthews_corrcoef.py +++ b/src/torchmetrics/functional/classification/matthews_corrcoef.py @@ -44,10 +44,10 @@ def _matthews_corrcoef_reduce(confmat: Tensor) -> Tensor: if confmat.numel() == 4: # binary case tn, fp, fn, tp = confmat.reshape(-1) - if tp != 0 and tn == 0 and fp == 0 and fn == 0: + if tp + tn != 0 and fp + fn == 0: return torch.tensor(1.0, dtype=confmat.dtype, device=confmat.device) - if tp == 0 and tn != 0 and fp == 0 and fn == 0: + if tp + tn == 0 and fp + fn != 0: return torch.tensor(-1.0, dtype=confmat.dtype, device=confmat.device) tk = confmat.sum(dim=-1).float() @@ -71,7 +71,7 @@ def _matthews_corrcoef_reduce(confmat: Tensor) -> Tensor: eps = torch.tensor(torch.finfo(torch.float32).eps, dtype=torch.float32, device=confmat.device) numerator = torch.sqrt(eps) * (a - b) - denom = torch.sqrt(2 * (a + b) * (a + eps) * (b + eps)) + denom = (tp + fp + eps) * (tp + fn + eps) * (tn + fp + eps) * (tn + fn + eps) elif denom == 0: return torch.tensor(0, dtype=confmat.dtype, device=confmat.device) return numerator / torch.sqrt(denom) diff --git a/tests/unittests/classification/test_matthews_corrcoef.py b/tests/unittests/classification/test_matthews_corrcoef.py index 79acf7480e5..795fc503479 100644 --- a/tests/unittests/classification/test_matthews_corrcoef.py +++ b/tests/unittests/classification/test_matthews_corrcoef.py @@ -321,7 +321,7 @@ def test_zero_case_in_multiclass(): @pytest.mark.parametrize( ("metric_fn", "preds", "target", "expected"), [ - (binary_matthews_corrcoef, torch.zeros(10), torch.zeros(10), -1.0), + (binary_matthews_corrcoef, torch.zeros(10), torch.zeros(10), 1.0), (binary_matthews_corrcoef, torch.ones(10), torch.ones(10), 1.0), ( binary_matthews_corrcoef, @@ -329,11 +329,13 @@ def test_zero_case_in_multiclass(): torch.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1]), 0.0, ), + (binary_matthews_corrcoef, torch.zeros(10), torch.ones(10), -1.0), + (binary_matthews_corrcoef, torch.ones(10), torch.zeros(10), -1.0), ( partial(multilabel_matthews_corrcoef, num_labels=NUM_CLASSES), torch.zeros(10, NUM_CLASSES).long(), torch.zeros(10, NUM_CLASSES).long(), - -1.0, + 1.0, ), ( partial(multilabel_matthews_corrcoef, num_labels=NUM_CLASSES), @@ -341,6 +343,18 @@ def test_zero_case_in_multiclass(): torch.ones(10, NUM_CLASSES).long(), 1.0, ), + ( + partial(multilabel_matthews_corrcoef, num_labels=NUM_CLASSES), + torch.zeros(10, NUM_CLASSES).long(), + torch.ones(10, NUM_CLASSES).long(), + -1.0, + ), + ( + partial(multilabel_matthews_corrcoef, num_labels=NUM_CLASSES), + torch.ones(10, NUM_CLASSES).long(), + torch.zeros(10, NUM_CLASSES).long(), + -1.0, + ), ], ) def test_corner_cases(metric_fn, preds, target, expected): From 7f84e286c9e5f3d92ecdfb197430b81f75c0d3d6 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Wed, 28 Jun 2023 08:56:53 +0200 Subject: [PATCH 2/2] changelog --- CHANGELOG.md | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6e355c900ad..390a83cc0f6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -212,7 +212,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fixed several bugs in `SpectralDistortionIndex` metric ([#1808](https://github.com/Lightning-AI/torchmetrics/pull/1808)) -- Fixed bug for corner cases in `MatthewsCorrCoef` ([#1812](https://github.com/Lightning-AI/torchmetrics/pull/1812)) +- Fixed bug for corner cases in `MatthewsCorrCoef` ( + [#1812](https://github.com/Lightning-AI/torchmetrics/pull/1812), + [#1863](https://github.com/Lightning-AI/torchmetrics/pull/1863) +) - Fixed support for half precision in `PearsonCorrCoef` ([#1819](https://github.com/Lightning-AI/torchmetrics/pull/1819))