Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix segmentation.MeanIoU #2698

Merged
merged 8 commits into from
Sep 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

- Fixed wrong aggregation in `segmentation.MeanIoU` ([#2698](https://github.com/Lightning-AI/torchmetrics/pull/2698))


- Fixed handling zero division error in binary IoU (Jaccard index) calculation ([#2726](https://github.com/Lightning-AI/torchmetrics/pull/2726))


Expand Down
12 changes: 6 additions & 6 deletions src/torchmetrics/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
"""Aggregate and evaluate batch input directly.

Serves the dual purpose of both computing the metric on the current batch of inputs but also add the batch
statistics to the overall accumululating metric state. Input arguments are the exact same as corresponding
statistics to the overall accumulating metric state. Input arguments are the exact same as corresponding
``update`` method. The returned output is the exact same as the output of ``compute``.

Args:
Expand Down Expand Up @@ -361,7 +361,7 @@ def _forward_full_state_update(self, *args: Any, **kwargs: Any) -> Any:
def _forward_reduce_state_update(self, *args: Any, **kwargs: Any) -> Any:
"""Forward computation using single call to `update`.

This can be done when the global metric state is a sinple reduction of batch states. This can be unsafe for
This can be done when the global metric state is a simple reduction of batch states. This can be unsafe for
certain metric cases but is also the fastest way to both accumulate globally and compute locally.

"""
Expand Down Expand Up @@ -802,7 +802,7 @@ def _apply(self, fn: Callable, exclude_state: Sequence[str] = "") -> Module:
"""Overwrite `_apply` function such that we can also move metric states to the correct device.

This method is called by the base ``nn.Module`` class whenever `.to`, `.cuda`, `.float`, `.half` etc. methods
are called. Dtype conversion is garded and will only happen through the special `set_dtype` method.
are called. Dtype conversion is guarded and will only happen through the special `set_dtype` method.

Args:
fn: the function to apply
Expand Down Expand Up @@ -1166,15 +1166,15 @@ def _sync_dist(self, dist_sync_fn: Optional[Callable] = None, process_group: Opt
"""

def update(self, *args: Any, **kwargs: Any) -> None:
"""Redirect the call to the input which the conposition was formed from."""
"""Redirect the call to the input which the composition was formed from."""
if isinstance(self.metric_a, Metric):
self.metric_a.update(*args, **self.metric_a._filter_kwargs(**kwargs))

if isinstance(self.metric_b, Metric):
self.metric_b.update(*args, **self.metric_b._filter_kwargs(**kwargs))

def compute(self) -> Any:
"""Redirect the call to the input which the conposition was formed from."""
"""Redirect the call to the input which the composition was formed from."""
# also some parsing for kwargs?
val_a = self.metric_a.compute() if isinstance(self.metric_a, Metric) else self.metric_a
val_b = self.metric_b.compute() if isinstance(self.metric_b, Metric) else self.metric_b
Expand Down Expand Up @@ -1216,7 +1216,7 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
return self._forward_cache

def reset(self) -> None:
"""Redirect the call to the input which the conposition was formed from."""
"""Redirect the call to the input which the composition was formed from."""
if isinstance(self.metric_a, Metric):
self.metric_a.reset()

Expand Down
8 changes: 5 additions & 3 deletions src/torchmetrics/segmentation/mean_iou.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,8 @@ def __init__(
self.input_format = input_format

num_classes = num_classes - 1 if not include_background else num_classes
self.add_state("score", default=torch.zeros(num_classes if per_class else 1), dist_reduce_fx="mean")
self.add_state("score", default=torch.zeros(num_classes if per_class else 1), dist_reduce_fx="sum")
self.add_state("num_batches", default=torch.tensor(0), dist_reduce_fx="sum")

def update(self, preds: Tensor, target: Tensor) -> None:
"""Update the state with the new data."""
Expand All @@ -119,10 +120,11 @@ def update(self, preds: Tensor, target: Tensor) -> None:
)
score = _mean_iou_compute(intersection, union, per_class=self.per_class)
self.score += score.mean(0) if self.per_class else score.mean()
self.num_batches += 1

def compute(self) -> Tensor:
"""Update the state with the new data."""
return self.score # / self.num_batches
SkafteNicki marked this conversation as resolved.
Show resolved Hide resolved
"""Compute the final Mean Intersection over Union (mIoU)."""
return self.score / self.num_batches

def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.
Expand Down
25 changes: 23 additions & 2 deletions tests/unittests/_helpers/testers.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,25 @@ def _assert_allclose(tm_result: Any, ref_result: Any, atol: float = 1e-8, key: O
"""Recursively assert that two results are within a certain tolerance."""
# single output compare
if isinstance(tm_result, Tensor):
assert np.allclose(tm_result.detach().cpu().numpy(), ref_result, atol=atol, equal_nan=True)
assert np.allclose(
tm_result.detach().cpu().numpy() if isinstance(tm_result, Tensor) else tm_result,
ref_result.detach().cpu().numpy() if isinstance(ref_result, Tensor) else ref_result,
atol=atol,
equal_nan=True,
)
# multi output compare
elif isinstance(tm_result, Sequence):
for pl_res, ref_res in zip(tm_result, ref_result):
_assert_allclose(pl_res, ref_res, atol=atol)
elif isinstance(tm_result, Dict):
if key is None:
raise KeyError("Provide Key for Dict based metric results.")
assert np.allclose(tm_result[key].detach().cpu().numpy(), ref_result, atol=atol, equal_nan=True)
assert np.allclose(
tm_result[key].detach().cpu().numpy() if isinstance(tm_result[key], Tensor) else tm_result[key],
ref_result.detach().cpu().numpy() if isinstance(ref_result, Tensor) else ref_result,
atol=atol,
equal_nan=True,
)
else:
raise ValueError("Unknown format for comparison")

Expand Down Expand Up @@ -147,13 +157,24 @@ def _class_test(
# verify metrics work after being loaded from pickled state
pickled_metric = pickle.dumps(metric)
metric = pickle.loads(pickled_metric)
metric_clone = deepcopy(metric)

for i in range(rank, num_batches, world_size):
batch_kwargs_update = {k: v[i] if isinstance(v, Tensor) else v for k, v in kwargs_update.items()}

# compute batch stats and aggregate for global stats
batch_result = metric(preds[i], target[i], **batch_kwargs_update)

if rank == 0 and world_size == 1 and i == 0: # check only in non-ddp mode and first batch
# dummy check to make sure that forward/update works as expected
metric_clone.update(preds[i], target[i], **batch_kwargs_update)
update_result = metric_clone.compute()
if isinstance(batch_result, dict):
for key in batch_result:
_assert_allclose(batch_result, update_result[key], key=key)
else:
_assert_allclose(batch_result, update_result)

if metric.dist_sync_on_step and check_dist_sync_on_step and rank == 0:
if isinstance(preds, Tensor):
ddp_preds = torch.cat([preds[i + r] for r in range(world_size)]).cpu()
Expand Down
Loading