From 0cd5fb989845d1628e89d6e55e7c7d9e26046048 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Sun, 14 Apr 2024 14:35:38 +0200 Subject: [PATCH 1/2] suppress warnings when needed --- src/torchmetrics/detection/mean_ap.py | 5 +++-- tests/unittests/detection/test_map.py | 12 +++++++++--- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/src/torchmetrics/detection/mean_ap.py b/src/torchmetrics/detection/mean_ap.py index 9f37208212f..b36e4ed89a6 100644 --- a/src/torchmetrics/detection/mean_ap.py +++ b/src/torchmetrics/detection/mean_ap.py @@ -827,8 +827,9 @@ def _get_safe_item_values( rle = self.mask_utils.encode(np.asfortranarray(i)) masks.append((tuple(rle["size"]), rle["counts"])) output[1] = tuple(masks) # type: ignore[call-overload] - if (output[0] is not None and len(output[0]) > self.max_detection_thresholds[-1]) or ( - output[1] is not None and len(output[1]) > self.max_detection_thresholds[-1] + if warn and ( + (output[0] is not None and len(output[0]) > self.max_detection_thresholds[-1]) + or (output[1] is not None and len(output[1]) > self.max_detection_thresholds[-1]) ): _warning_on_too_many_detections(self.max_detection_thresholds[-1]) return output # type: ignore[return-value] diff --git a/tests/unittests/detection/test_map.py b/tests/unittests/detection/test_map.py index f0dcdff52f2..64af4aab4ab 100644 --- a/tests/unittests/detection/test_map.py +++ b/tests/unittests/detection/test_map.py @@ -712,7 +712,8 @@ def test_for_box_format(self, box_format, iou_val_expected, map_val_expected, ba assert round(float(result["ious"][(0, 0)]), 3) == iou_val_expected @pytest.mark.parametrize("iou_type", ["bbox", "segm"]) - def test_warning_on_many_detections(self, iou_type, backend): + @pytest.mark.parametrize("warn_on_many_detections", [False, True]) + def test_warning_on_many_detections(self, iou_type, warn_on_many_detections, backend, recwarn): """Test that a warning is raised when there are many detections.""" if iou_type == "bbox": preds = [ @@ -727,8 +728,13 @@ def test_warning_on_many_detections(self, iou_type, backend): preds, targets = _generate_random_segm_input("cpu", 1, 101, 10, False) metric = MeanAveragePrecision(iou_type=iou_type, backend=backend) - with pytest.warns(UserWarning, match="Encountered more than 100 detections in a single image.*"): - metric.update(preds, targets) + metric.warn_on_many_detections = warn_on_many_detections + + if warn_on_many_detections: + with pytest.warns(UserWarning, match="Encountered more than 100 detections in a single image.*"): + metric.update(preds, targets) + else: + assert len(recwarn) == 0 @pytest.mark.parametrize( ("preds", "target", "expected_iou_len", "iou_keys", "precision_shape", "recall_shape", "scores_shape"), From 83d61f207e58bf5f4eadefba37dcd6d6fc172a4f Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Sun, 14 Apr 2024 14:37:45 +0200 Subject: [PATCH 2/2] changelog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 40ed1b62c79..820f242e754 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -36,6 +36,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fixed axis names with Precision-Recall curve ([#2462](https://github.com/Lightning-AI/torchmetrics/pull/2462)) +- Fixed warnings being suppressed in `MeanAveragePrecision` when requested ([#2501](https://github.com/Lightning-AI/torchmetrics/pull/2501)) + + ## [1.3.2] - 2024-03-18 ### Fixed