Skip to content

Commit

Permalink
Raise error in MeanAveragePrecision if too little/many detection ar…
Browse files Browse the repository at this point in the history
…e provided for `pycocotools` backend (#2219)
  • Loading branch information
SkafteNicki authored Nov 28, 2023
1 parent f45945e commit 56c0fdf
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 1 deletion.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed numerical stability issue in `UniversalImageQualityIndex` metric ([#2222](https://github.com/Lightning-AI/torchmetrics/pull/2222))


- Fixed incompatibility for `MeanAveragePrecision` with `pycocotools` backend when too little `max_detection_thresholds` are provided ([#2219](https://github.com/Lightning-AI/torchmetrics/pull/2219))


- Fixed support for half precision in Perplexity metric ([#2235](https://github.com/Lightning-AI/torchmetrics/pull/2235))


Expand Down
8 changes: 7 additions & 1 deletion src/torchmetrics/detection/mean_ap.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,8 @@ class MeanAveragePrecision(Metric):
with step ``0.01``. Else provide a list of floats.
max_detection_thresholds:
Thresholds on max detections per image. If set to `None` will use thresholds ``[1, 10, 100]``.
Else, please provide a list of ints.
Else, please provide a list of ints. If the `pycocotools` backend is used then the list needs to have
length 3. If this is a problem, shift to `faster_coco_eval` which supports more detection thresholds.
class_metrics:
Option to enable per-class metrics for mAP and mAR_100. Has a performance impact that scales linearly with
the number of classes in the dataset.
Expand Down Expand Up @@ -409,6 +410,11 @@ def __init__(
f"Expected argument `max_detection_thresholds` to either be `None` or a list of ints"
f" but got {max_detection_thresholds}"
)
if max_detection_thresholds is not None and backend == "pycocotools" and len(max_detection_thresholds) != 3:
raise ValueError(
"When using `pycocotools` backend the number of max detection thresholds should be 3 else"
f" it will not work correctly with the backend. Got value {len(max_detection_thresholds)}."
)
max_det_thr, _ = torch.sort(torch.tensor(max_detection_thresholds or [1, 10, 100], dtype=torch.int))
self.max_detection_thresholds = max_det_thr.tolist()

Expand Down
31 changes: 31 additions & 0 deletions tests/unittests/detection/test_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -860,3 +860,34 @@ def test_many_detection_thresholds(self, backend):
assert round(res["map"].item(), 5) != 0.6
else:
assert round(res["map"].item(), 5) == 0.6

@pytest.mark.parametrize("max_detection_thresholds", [[1, 10], [1, 10, 50, 100]])
def test_with_more_and_less_detection_thresholds(self, max_detection_thresholds, backend):
"""Test how metric is working when list of max detection thresholds is not 3.
This is a known limitation of the pycocotools where values are hardcoded to expect at least 3 elements
https://github.com/ppwwyyxx/cocoapi/blob/master/PythonAPI/pycocotools/cocoeval.py#L461
"""
preds = [
{
"boxes": torch.tensor([[258.0, 41.0, 606.0, 285.0]]),
"scores": torch.tensor([0.536]),
"labels": torch.tensor([0]),
}
]
target = [
{
"boxes": torch.tensor([[214.0, 41.0, 562.0, 285.0]]),
"labels": torch.tensor([0]),
}
]

if backend == "pycocotools":
with pytest.raises(
ValueError, match="When using `pycocotools` backend the number of max detection thresholds should.*"
):
metric = MeanAveragePrecision(max_detection_thresholds=max_detection_thresholds, backend=backend)
else:
metric = MeanAveragePrecision(max_detection_thresholds=max_detection_thresholds, backend=backend)
metric(preds, target)

0 comments on commit 56c0fdf

Please sign in to comment.