Skip to content

Commit

Permalink
Bugfix/topk accuracy (#2423)
Browse files Browse the repository at this point in the history
* add tests
* fix implementation
* tests
* changelog

(cherry picked from commit 1351009)
  • Loading branch information
SkafteNicki authored and Borda committed Mar 18, 2024
1 parent fc9fa74 commit d30d06d
Show file tree
Hide file tree
Showing 8 changed files with 32 additions and 10 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed dtype being changed by deepspeed for certain regression metrics ([#2379](https://github.com/Lightning-AI/torchmetrics/pull/2379))


- Fixed bug when `top_k>1` and `average="macro"` for classification metrics ([#2423](https://github.com/Lightning-AI/torchmetrics/pull/2423))


## [1.3.1] - 2024-02-12

### Fixed
Expand Down
4 changes: 3 additions & 1 deletion src/torchmetrics/classification/accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,9 @@ class MulticlassAccuracy(MulticlassStatScores):
def compute(self) -> Tensor:
"""Compute accuracy based on inputs passed in to ``update`` previously."""
tp, fp, tn, fn = self._final_state()
return _accuracy_reduce(tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average)
return _accuracy_reduce(
tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average, top_k=self.top_k
)

def plot(
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
Expand Down
4 changes: 2 additions & 2 deletions src/torchmetrics/classification/precision_recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ def compute(self) -> Tensor:
"""Compute metric."""
tp, fp, tn, fn = self._final_state()
return _precision_recall_reduce(
"precision", tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average
"precision", tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average, top_k=self.top_k
)

def plot(
Expand Down Expand Up @@ -702,7 +702,7 @@ def compute(self) -> Tensor:
"""Compute metric."""
tp, fp, tn, fn = self._final_state()
return _precision_recall_reduce(
"recall", tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average
"recall", tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average, top_k=self.top_k
)

def plot(
Expand Down
6 changes: 4 additions & 2 deletions src/torchmetrics/functional/classification/accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def _accuracy_reduce(
average: Optional[Literal["binary", "micro", "macro", "weighted", "none"]],
multidim_average: Literal["global", "samplewise"] = "global",
multilabel: bool = False,
top_k: int = 1,
) -> Tensor:
"""Reduce classification statistics into accuracy score.
Expand All @@ -66,6 +67,7 @@ def _accuracy_reduce(
- ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis.
multilabel: If input is multilabel or not
top_k: value for top-k accuracy, else 1
Returns:
Accuracy score
Expand All @@ -83,7 +85,7 @@ def _accuracy_reduce(
return _safe_divide(tp, tp + fn)

score = _safe_divide(tp + tn, tp + tn + fp + fn) if multilabel else _safe_divide(tp, tp + fn)
return _adjust_weights_safe_divide(score, average, multilabel, tp, fp, fn)
return _adjust_weights_safe_divide(score, average, multilabel, tp, fp, fn, top_k)


def binary_accuracy(
Expand Down Expand Up @@ -266,7 +268,7 @@ def multiclass_accuracy(
tp, fp, tn, fn = _multiclass_stat_scores_update(
preds, target, num_classes, top_k, average, multidim_average, ignore_index
)
return _accuracy_reduce(tp, fp, tn, fn, average=average, multidim_average=multidim_average)
return _accuracy_reduce(tp, fp, tn, fn, average=average, multidim_average=multidim_average, top_k=top_k)


def multilabel_accuracy(
Expand Down
11 changes: 8 additions & 3 deletions src/torchmetrics/functional/classification/precision_recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def _precision_recall_reduce(
average: Optional[Literal["binary", "micro", "macro", "weighted", "none"]],
multidim_average: Literal["global", "samplewise"] = "global",
multilabel: bool = False,
top_k: int = 1,
) -> Tensor:
different_stat = fp if stat == "precision" else fn # this is what differs between the two scores
if average == "binary":
Expand All @@ -54,7 +55,7 @@ def _precision_recall_reduce(
return _safe_divide(tp, tp + different_stat)

score = _safe_divide(tp, tp + different_stat)
return _adjust_weights_safe_divide(score, average, multilabel, tp, fp, fn)
return _adjust_weights_safe_divide(score, average, multilabel, tp, fp, fn, top_k=top_k)


def binary_precision(
Expand Down Expand Up @@ -235,7 +236,9 @@ def multiclass_precision(
tp, fp, tn, fn = _multiclass_stat_scores_update(
preds, target, num_classes, top_k, average, multidim_average, ignore_index
)
return _precision_recall_reduce("precision", tp, fp, tn, fn, average=average, multidim_average=multidim_average)
return _precision_recall_reduce(
"precision", tp, fp, tn, fn, average=average, multidim_average=multidim_average, top_k=top_k
)


def multilabel_precision(
Expand Down Expand Up @@ -519,7 +522,9 @@ def multiclass_recall(
tp, fp, tn, fn = _multiclass_stat_scores_update(
preds, target, num_classes, top_k, average, multidim_average, ignore_index
)
return _precision_recall_reduce("recall", tp, fp, tn, fn, average=average, multidim_average=multidim_average)
return _precision_recall_reduce(
"recall", tp, fp, tn, fn, average=average, multidim_average=multidim_average, top_k=top_k
)


def multilabel_recall(
Expand Down
4 changes: 2 additions & 2 deletions src/torchmetrics/utilities/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def _safe_divide(num: Tensor, denom: Tensor) -> Tensor:


def _adjust_weights_safe_divide(
score: Tensor, average: Optional[str], multilabel: bool, tp: Tensor, fp: Tensor, fn: Tensor
score: Tensor, average: Optional[str], multilabel: bool, tp: Tensor, fp: Tensor, fn: Tensor, top_k: int = 1
) -> Tensor:
if average is None or average == "none":
return score
Expand All @@ -65,7 +65,7 @@ def _adjust_weights_safe_divide(
else:
weights = torch.ones_like(score)
if not multilabel:
weights[tp + fp + fn == 0] = 0.0
weights[tp + fp + fn == 0 if top_k == 1 else tp + fn == 0] = 0.0
return _safe_divide(weights * score, weights.sum(-1, keepdim=True)).sum(-1)


Expand Down
5 changes: 5 additions & 0 deletions tests/unittests/classification/test_accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,12 +339,17 @@ def test_multiclass_accuracy_half_gpu(self, inputs, dtype):
_mc_k_target = torch.tensor([0, 1, 2])
_mc_k_preds = torch.tensor([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]])

_mc_k_targets2 = torch.tensor([0, 0, 2])
_mc_k_preds2 = torch.tensor([[0.9, 0.1, 0.0], [0.9, 0.1, 0.0], [0.9, 0.1, 0.0]])


@pytest.mark.parametrize(
("k", "preds", "target", "average", "expected"),
[
(1, _mc_k_preds, _mc_k_target, "micro", torch.tensor(2 / 3)),
(2, _mc_k_preds, _mc_k_target, "micro", torch.tensor(3 / 3)),
(1, _mc_k_preds2, _mc_k_targets2, "macro", torch.tensor(1 / 2)),
(2, _mc_k_preds2, _mc_k_targets2, "macro", torch.tensor(1 / 2)),
],
)
def test_top_k(k, preds, target, average, expected):
Expand Down
5 changes: 5 additions & 0 deletions tests/unittests/classification/test_precision_recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,9 @@ def test_multiclass_precision_recall_half_gpu(self, inputs, module, functional,
_mc_k_target = tensor([0, 1, 2])
_mc_k_preds = tensor([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]])

_mc_k_targets2 = torch.tensor([0, 0, 2])
_mc_k_preds2 = torch.tensor([[0.9, 0.1, 0.0], [0.9, 0.1, 0.0], [0.9, 0.1, 0.0]])


@pytest.mark.parametrize(
("metric_class", "metric_fn"), [(MulticlassPrecision, multiclass_precision), (MulticlassRecall, multiclass_recall)]
Expand All @@ -340,6 +343,8 @@ def test_multiclass_precision_recall_half_gpu(self, inputs, module, functional,
[
(1, _mc_k_preds, _mc_k_target, "micro", tensor(2 / 3), tensor(2 / 3)),
(2, _mc_k_preds, _mc_k_target, "micro", tensor(1 / 2), tensor(1.0)),
(1, _mc_k_preds2, _mc_k_targets2, "macro", tensor(1 / 3), tensor(1 / 2)),
(2, _mc_k_preds2, _mc_k_targets2, "macro", tensor(1 / 3), tensor(1 / 2)),
],
)
def test_top_k(
Expand Down

0 comments on commit d30d06d

Please sign in to comment.