Skip to content

Commit

Permalink
Torch-based mAP (#632)
Browse files Browse the repository at this point in the history
* First draft
* Remove double score
* Calculate num_class only
* Support empty predictions
* Remove pycocotools from tests
* Fix annotation id evals to false if zero
* Improve method descriptions
* Fix returning metrics if parameters are changed
* Apply suggestions from code review

Co-authored-by: Tobias Kupek <tkupek@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Nicki Skafte Detlefsen <skaftenicki@gmail.com>
Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Jirka <jirka.borovec@seznam.cz>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>

(cherry picked from commit 2494e68)
  • Loading branch information
twsl authored and Borda committed Dec 5, 2021
1 parent 09acc67 commit 2c08715
Show file tree
Hide file tree
Showing 3 changed files with 503 additions and 244 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed

- Migrate MAP metrics from pycocotools to PyTorch ([#632](https://github.com/PyTorchLightning/metrics/pull/632))


- Use `torch.topk` instead of `torch.argsort` in retrieval precision for speedup ([#627](https://github.com/PyTorchLightning/metrics/pull/627))


### Deprecated


Expand Down
41 changes: 14 additions & 27 deletions tests/detection/test_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,7 @@

from tests.helpers.testers import MetricTester
from torchmetrics.detection.map import MAP
from torchmetrics.utilities.imports import (
_PYCOCOTOOLS_AVAILABLE,
_TORCHVISION_AVAILABLE,
_TORCHVISION_GREATER_EQUAL_0_8,
)
from torchmetrics.utilities.imports import _TORCHVISION_AVAILABLE, _TORCHVISION_GREATER_EQUAL_0_8

Input = namedtuple("Input", ["preds", "target"])

Expand Down Expand Up @@ -59,7 +55,7 @@
), # coco image id 74
dict(
boxes=torch.Tensor([[0.00, 2.87, 601.00, 421.52]]),
scores=torch.Tensor([0.699, 0.423]),
scores=torch.Tensor([0.699]),
labels=torch.IntTensor([5]),
), # coco image id 133
],
Expand Down Expand Up @@ -164,10 +160,10 @@ def _compare_fn(preds, target) -> dict:
}


_pytest_condition = not (_PYCOCOTOOLS_AVAILABLE and _TORCHVISION_AVAILABLE and _TORCHVISION_GREATER_EQUAL_0_8)
_pytest_condition = not (_TORCHVISION_AVAILABLE and _TORCHVISION_GREATER_EQUAL_0_8)


@pytest.mark.skipif(_pytest_condition, reason="test requires that pycocotools and torchvision=>0.8.0 is installed")
@pytest.mark.skipif(_pytest_condition, reason="test requires that torchvision=>0.8.0 is installed")
class TestMAP(MetricTester):
"""Test the MAP metric for object detection predictions.
Expand All @@ -194,7 +190,7 @@ def test_map(self, ddp):


# noinspection PyTypeChecker
@pytest.mark.skipif(_pytest_condition, reason="test requires that pycocotools and torchvision=>0.8.0 is installed")
@pytest.mark.skipif(_pytest_condition, reason="test requires that torchvision=>0.8.0 is installed")
def test_error_on_wrong_init():
"""Test class raises the expected errors."""
MAP() # no error
Expand All @@ -203,20 +199,11 @@ def test_error_on_wrong_init():
MAP(class_metrics=0)


@pytest.mark.skipif(_pytest_condition, reason="test requires that pycocotools and torchvision=>0.8.0 is installed")
@pytest.mark.skipif(_pytest_condition, reason="test requires that torchvision=>0.8.0 is installed")
def test_empty_preds():
"""Test empty predictions."""
metric = MAP()

metric.update(
[
dict(boxes=torch.Tensor([[]]), scores=torch.Tensor([]), labels=torch.IntTensor([])),
],
[
dict(boxes=torch.Tensor([[214.1500, 41.2900, 562.4100, 285.0700]]), labels=torch.IntTensor([4])),
],
)

metric.update(
[
dict(boxes=torch.Tensor([]), scores=torch.Tensor([]), labels=torch.IntTensor([])),
Expand All @@ -235,17 +222,17 @@ def test_empty_metric():
metric.compute()


@pytest.mark.skipif(_pytest_condition, reason="test requires that pycocotools and torchvision=>0.8.0 is installed")
@pytest.mark.skipif(_pytest_condition, reason="test requires that torchvision=>0.8.0 is installed")
def test_error_on_wrong_input():
"""Test class input validation."""
metric = MAP()

metric.update([], []) # no error

with pytest.raises(ValueError, match="Expected argument `preds` to be of type List"):
with pytest.raises(ValueError, match="Expected argument `preds` to be of type Sequence"):
metric.update(torch.Tensor(), []) # type: ignore

with pytest.raises(ValueError, match="Expected argument `target` to be of type List"):
with pytest.raises(ValueError, match="Expected argument `target` to be of type Sequence"):
metric.update([], torch.Tensor()) # type: ignore

with pytest.raises(ValueError, match="Expected argument `preds` and `target` to have the same length"):
Expand Down Expand Up @@ -281,31 +268,31 @@ def test_error_on_wrong_input():
[dict(boxes=torch.IntTensor())],
)

with pytest.raises(ValueError, match="Expected all boxes in `preds` to be of type torch.Tensor"):
with pytest.raises(ValueError, match="Expected all boxes in `preds` to be of type Tensor"):
metric.update(
[dict(boxes=[], scores=torch.Tensor(), labels=torch.IntTensor())],
[dict(boxes=torch.Tensor(), labels=torch.IntTensor())],
)

with pytest.raises(ValueError, match="Expected all scores in `preds` to be of type torch.Tensor"):
with pytest.raises(ValueError, match="Expected all scores in `preds` to be of type Tensor"):
metric.update(
[dict(boxes=torch.Tensor(), scores=[], labels=torch.IntTensor())],
[dict(boxes=torch.Tensor(), labels=torch.IntTensor())],
)

with pytest.raises(ValueError, match="Expected all labels in `preds` to be of type torch.Tensor"):
with pytest.raises(ValueError, match="Expected all labels in `preds` to be of type Tensor"):
metric.update(
[dict(boxes=torch.Tensor(), scores=torch.Tensor(), labels=[])],
[dict(boxes=torch.Tensor(), labels=torch.IntTensor())],
)

with pytest.raises(ValueError, match="Expected all boxes in `target` to be of type torch.Tensor"):
with pytest.raises(ValueError, match="Expected all boxes in `target` to be of type Tensor"):
metric.update(
[dict(boxes=torch.Tensor(), scores=torch.Tensor(), labels=torch.IntTensor())],
[dict(boxes=[], labels=torch.IntTensor())],
)

with pytest.raises(ValueError, match="Expected all labels in `target` to be of type torch.Tensor"):
with pytest.raises(ValueError, match="Expected all labels in `target` to be of type Tensor"):
metric.update(
[dict(boxes=torch.Tensor(), scores=torch.Tensor(), labels=torch.IntTensor())],
[dict(boxes=torch.Tensor(), labels=[])],
Expand Down
Loading

0 comments on commit 2c08715

Please sign in to comment.