Skip to content

Commit

Permalink
Fix precision-recall curve based computations for float target (#1642)
Browse files Browse the repository at this point in the history
  • Loading branch information
SkafteNicki authored Mar 28, 2023
1 parent 7015b94 commit 6756b74
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 3 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed support in `MetricTracker` for `MultioutputWrapper` and nested structures ([#1608](https://github.com/Lightning-AI/metrics/pull/1608))


- Fix precision-recall curve based computations for float target ([#1642](https://github.com/Lightning-AI/metrics/pull/1642))


## [0.11.4] - 2023-03-10

### Fixed
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ def _binary_precision_recall_curve_update_vectorized(
"""
len_t = len(thresholds)
preds_t = (preds.unsqueeze(-1) >= thresholds.unsqueeze(0)).long() # num_samples x num_thresholds
unique_mapping = preds_t + 2 * target.unsqueeze(-1) + 4 * torch.arange(len_t, device=target.device)
unique_mapping = preds_t + 2 * target.long().unsqueeze(-1) + 4 * torch.arange(len_t, device=target.device)
bins = _bincount(unique_mapping.flatten(), minlength=4 * len_t)
return bins.reshape(len_t, 2, 2)

Expand Down Expand Up @@ -469,7 +469,7 @@ def _multiclass_precision_recall_curve_update_vectorized(
len_t = len(thresholds)
preds_t = (preds.unsqueeze(-1) >= thresholds.unsqueeze(0).unsqueeze(0)).long()
target_t = torch.nn.functional.one_hot(target, num_classes=num_classes)
unique_mapping = preds_t + 2 * target_t.unsqueeze(-1)
unique_mapping = preds_t + 2 * target_t.long().unsqueeze(-1)
unique_mapping += 4 * torch.arange(num_classes, device=preds.device).unsqueeze(0).unsqueeze(-1)
unique_mapping += 4 * num_classes * torch.arange(len_t, device=preds.device)
bins = _bincount(unique_mapping.flatten(), minlength=4 * num_classes * len_t)
Expand Down Expand Up @@ -714,7 +714,7 @@ def _multilabel_precision_recall_curve_update(
len_t = len(thresholds)
# num_samples x num_labels x num_thresholds
preds_t = (preds.unsqueeze(-1) >= thresholds.unsqueeze(0).unsqueeze(0)).long()
unique_mapping = preds_t + 2 * target.unsqueeze(-1)
unique_mapping = preds_t + 2 * target.long().unsqueeze(-1)
unique_mapping += 4 * torch.arange(num_labels, device=preds.device).unsqueeze(0).unsqueeze(-1)
unique_mapping += 4 * num_labels * torch.arange(len_t, device=preds.device)
unique_mapping = unique_mapping[unique_mapping >= 0]
Expand Down

0 comments on commit 6756b74

Please sign in to comment.