diff --git a/CHANGELOG.md b/CHANGELOG.md
index d3eafaf8be5..ef9ed3c896f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -44,6 +44,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Fixed
 
+- Fixed wrong aggregation in `segmentation.MeanIoU` ([#2698](https://github.com/Lightning-AI/torchmetrics/pull/2698))
+
+
 - Fixed handling zero division error in binary IoU (Jaccard index) calculation ([#2726](https://github.com/Lightning-AI/torchmetrics/pull/2726))
 
 
diff --git a/src/torchmetrics/metric.py b/src/torchmetrics/metric.py
index ef9055ed560..940e393c6d1 100644
--- a/src/torchmetrics/metric.py
+++ b/src/torchmetrics/metric.py
@@ -284,7 +284,7 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
         """Aggregate and evaluate batch input directly.
 
         Serves the dual purpose of both computing the metric on the current batch of inputs but also add the batch
-        statistics to the overall accumululating metric state. Input arguments are the exact same as corresponding
+        statistics to the overall accumulating metric state. Input arguments are the exact same as corresponding
         ``update`` method. The returned output is the exact same as the output of ``compute``.
 
         Args:
@@ -361,7 +361,7 @@ def _forward_full_state_update(self, *args: Any, **kwargs: Any) -> Any:
     def _forward_reduce_state_update(self, *args: Any, **kwargs: Any) -> Any:
         """Forward computation using single call to `update`.
 
-        This can be done when the global metric state is a sinple reduction of batch states. This can be unsafe for
+        This can be done when the global metric state is a simple reduction of batch states. This can be unsafe for
         certain metric cases but is also the fastest way to both accumulate globally and compute locally.
 
         """
@@ -802,7 +802,7 @@ def _apply(self, fn: Callable, exclude_state: Sequence[str] = "") -> Module:
         """Overwrite `_apply` function such that we can also move metric states to the correct device.
 
         This method is called by the base ``nn.Module`` class whenever `.to`, `.cuda`, `.float`, `.half` etc. methods
-        are called. Dtype conversion is garded and will only happen through the special `set_dtype` method.
+        are called. Dtype conversion is guarded and will only happen through the special `set_dtype` method.
 
         Args:
             fn: the function to apply
@@ -1166,7 +1166,7 @@ def _sync_dist(self, dist_sync_fn: Optional[Callable] = None, process_group: Opt
         """
 
     def update(self, *args: Any, **kwargs: Any) -> None:
-        """Redirect the call to the input which the conposition was formed from."""
+        """Redirect the call to the input which the composition was formed from."""
         if isinstance(self.metric_a, Metric):
             self.metric_a.update(*args, **self.metric_a._filter_kwargs(**kwargs))
 
@@ -1174,7 +1174,7 @@ def update(self, *args: Any, **kwargs: Any) -> None:
             self.metric_b.update(*args, **self.metric_b._filter_kwargs(**kwargs))
 
     def compute(self) -> Any:
-        """Redirect the call to the input which the conposition was formed from."""
+        """Redirect the call to the input which the composition was formed from."""
         # also some parsing for kwargs?
         val_a = self.metric_a.compute() if isinstance(self.metric_a, Metric) else self.metric_a
         val_b = self.metric_b.compute() if isinstance(self.metric_b, Metric) else self.metric_b
@@ -1216,7 +1216,7 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
         return self._forward_cache
 
     def reset(self) -> None:
-        """Redirect the call to the input which the conposition was formed from."""
+        """Redirect the call to the input which the composition was formed from."""
         if isinstance(self.metric_a, Metric):
             self.metric_a.reset()
 
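Context for the docstring fixes above: they live in torchmetrics' metric-arithmetic machinery, where `update`, `compute`, and `reset` on a composed metric simply redirect to the two inputs the composition was formed from. A minimal sketch of that behavior, not part of the patch, with `BinaryAccuracy` and `BinaryF1Score` used purely as convenient stand-ins:

```python
import torch
from torchmetrics.classification import BinaryAccuracy, BinaryF1Score

# Arithmetic on metrics builds a CompositionalMetric whose update/compute/
# reset calls are redirected to the constituent metrics.
combined = BinaryAccuracy() + BinaryF1Score()

preds = torch.tensor([1, 0, 1, 1])
target = torch.tensor([1, 1, 1, 0])

combined.update(preds, target)  # forwarded to both constituent metrics
print(combined.compute())       # accuracy + f1, each computed separately
combined.reset()                # resets both constituents
```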
diff --git a/src/torchmetrics/segmentation/mean_iou.py b/src/torchmetrics/segmentation/mean_iou.py
index c298254585a..0fe831f5231 100644
--- a/src/torchmetrics/segmentation/mean_iou.py
+++ b/src/torchmetrics/segmentation/mean_iou.py
@@ -110,7 +110,8 @@ def __init__(
         self.input_format = input_format
 
         num_classes = num_classes - 1 if not include_background else num_classes
-        self.add_state("score", default=torch.zeros(num_classes if per_class else 1), dist_reduce_fx="mean")
+        self.add_state("score", default=torch.zeros(num_classes if per_class else 1), dist_reduce_fx="sum")
+        self.add_state("num_batches", default=torch.tensor(0), dist_reduce_fx="sum")
 
     def update(self, preds: Tensor, target: Tensor) -> None:
         """Update the state with the new data."""
@@ -119,10 +120,11 @@ def update(self, preds: Tensor, target: Tensor) -> None:
         )
         score = _mean_iou_compute(intersection, union, per_class=self.per_class)
         self.score += score.mean(0) if self.per_class else score.mean()
+        self.num_batches += 1
 
     def compute(self) -> Tensor:
-        """Update the state with the new data."""
-        return self.score  # / self.num_batches
+        """Compute the final Mean Intersection over Union (mIoU)."""
+        return self.score / self.num_batches
 
     def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE:
         """Plot a single or multiple values from the metric.
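Why the state change matters: `score` accumulates a running *sum* of per-batch (m)IoU values via `+=`, but `compute()` returned that sum as-is (the division was commented out), and `dist_reduce_fx="mean"` averaged the rank-local sums, which is only correct when every rank sees the same number of batches. Tracking the sum together with a batch counter, both reduced with `"sum"`, makes the final mean exact. A toy sketch of the two behaviors, with made-up scores standing in for `_mean_iou_compute` output:

```python
import torch

# Fabricated per-batch mIoU scores for illustration.
batch_scores = [torch.tensor(0.5), torch.tensor(0.7), torch.tensor(0.6)]

# Old: update() summed the batch scores and compute() never divided,
# so three batches yielded 1.8 instead of a mean IoU in [0, 1].
old_result = torch.stack(batch_scores).sum()

# New: keep the running sum plus a batch counter and divide in
# compute(), giving the true average of 0.6. Reducing both states
# with "sum" across DDP ranks keeps the ratio exact even when ranks
# process different numbers of batches.
score = torch.stack(batch_scores).sum()
num_batches = torch.tensor(len(batch_scores))
new_result = score / num_batches
```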
diff --git a/tests/unittests/_helpers/testers.py b/tests/unittests/_helpers/testers.py
index deb4c12324e..c5a69077f3c 100644
--- a/tests/unittests/_helpers/testers.py
+++ b/tests/unittests/_helpers/testers.py
@@ -32,7 +32,12 @@ def _assert_allclose(tm_result: Any, ref_result: Any, atol: float = 1e-8, key: O
     """Recursively assert that two results are within a certain tolerance."""
     # single output compare
     if isinstance(tm_result, Tensor):
-        assert np.allclose(tm_result.detach().cpu().numpy(), ref_result, atol=atol, equal_nan=True)
+        assert np.allclose(
+            tm_result.detach().cpu().numpy() if isinstance(tm_result, Tensor) else tm_result,
+            ref_result.detach().cpu().numpy() if isinstance(ref_result, Tensor) else ref_result,
+            atol=atol,
+            equal_nan=True,
+        )
     # multi output compare
     elif isinstance(tm_result, Sequence):
         for pl_res, ref_res in zip(tm_result, ref_result):
@@ -40,7 +45,12 @@ def _assert_allclose(tm_result: Any, ref_result: Any, atol: float = 1e-8, key: O
     elif isinstance(tm_result, Dict):
         if key is None:
             raise KeyError("Provide Key for Dict based metric results.")
-        assert np.allclose(tm_result[key].detach().cpu().numpy(), ref_result, atol=atol, equal_nan=True)
+        assert np.allclose(
+            tm_result[key].detach().cpu().numpy() if isinstance(tm_result[key], Tensor) else tm_result[key],
+            ref_result.detach().cpu().numpy() if isinstance(ref_result, Tensor) else ref_result,
+            atol=atol,
+            equal_nan=True,
+        )
     else:
         raise ValueError("Unknown format for comparison")
 
@@ -147,6 +157,7 @@ def _class_test(
     # verify metrics work after being loaded from pickled state
     pickled_metric = pickle.dumps(metric)
     metric = pickle.loads(pickled_metric)
+    metric_clone = deepcopy(metric)
 
     for i in range(rank, num_batches, world_size):
         batch_kwargs_update = {k: v[i] if isinstance(v, Tensor) else v for k, v in kwargs_update.items()}
@@ -154,6 +165,16 @@ def _class_test(
         # compute batch stats and aggregate for global stats
         batch_result = metric(preds[i], target[i], **batch_kwargs_update)
 
+        if rank == 0 and world_size == 1 and i == 0:  # check only in non-ddp mode and first batch
+            # dummy check to make sure that forward/update works as expected
+            metric_clone.update(preds[i], target[i], **batch_kwargs_update)
+            update_result = metric_clone.compute()
+            if isinstance(batch_result, dict):
+                for key in batch_result:
+                    _assert_allclose(batch_result, update_result[key], key=key)
+            else:
+                _assert_allclose(batch_result, update_result)
+
         if metric.dist_sync_on_step and check_dist_sync_on_step and rank == 0:
             if isinstance(preds, Tensor):
                 ddp_preds = torch.cat([preds[i + r] for r in range(world_size)]).cpu()
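The new block in `_class_test` pins down the invariant that a single `forward` call on a fresh metric returns the same value as `update` followed by `compute` on an identical clone. Roughly what it asserts, shown standalone against the fixed `MeanIoU`; the shapes and data here are fabricated for illustration:

```python
from copy import deepcopy

import torch
from torchmetrics.segmentation import MeanIoU

# Fabricated one-hot segmentation masks: batch of 2, 3 classes, 16x16.
labels = torch.randint(0, 3, (2, 16, 16))
one_hot = torch.nn.functional.one_hot(labels, num_classes=3).permute(0, 3, 1, 2)

metric = MeanIoU(num_classes=3)
clone = deepcopy(metric)

# forward() on the first batch must agree with update() + compute()
# on an untouched clone -- the check added to the test harness above.
forward_result = metric(one_hot, one_hot)
clone.update(one_hot, one_hot)
assert torch.allclose(forward_result, clone.compute())
```

Before this patch the check would have passed only by accident: with `compute()` returning the raw sum, a single batch still divided out correctly, which is exactly why the aggregation bug escaped the existing multi-batch-unaware assertions.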