Skip to content

Commit

Permalink
Support return dice for each class in DiceMetric (Project-MONAI#7163)
Browse files Browse the repository at this point in the history
Fixes Project-MONAI#7162
Fixes Project-MONAI#7164

### Description
Add `return_with_label`, if True or a list, will return the metrics with
the corresponding label name, only works when reduction="mean_batch".

https://github.com/pytorch/ignite/blob/47b95d087a0f8713a9d24bcfe3a539b08101ba7a/ignite/metrics/metric.py#L424

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [x] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [x] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: KumoLiu <yunl@nvidia.com>
  • Loading branch information
KumoLiu authored Oct 25, 2023
1 parent cc20c9b commit 85243f5
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 3 deletions.
1 change: 1 addition & 0 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ def generate_apidocs(*args):
{"name": "Twitter", "url": "https://twitter.com/projectmonai", "icon": "fab fa-twitter-square"},
],
"collapse_navigation": True,
"navigation_with_keys": True,
"navigation_depth": 1,
"show_toc_level": 1,
"footer_start": ["copyright"],
Expand Down
12 changes: 11 additions & 1 deletion monai/handlers/mean_dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def __init__(
num_classes: int | None = None,
output_transform: Callable = lambda x: x,
save_details: bool = True,
return_with_label: bool | list[str] = False,
) -> None:
"""
Expand All @@ -50,9 +51,18 @@ def __init__(
https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb.
save_details: whether to save metric computation details per image, for example: mean dice of every image.
default to True, will save to `engine.state.metric_details` dict with the metric name as key.
return_with_label: whether to return the metrics with label, only works when reduction is "mean_batch".
If `True`, use "label_{index}" as the key corresponding to C channels; if 'include_background' is True,
the index begins at "0", otherwise at "1". It can also take a list of label names.
The outcome will then be returned as a dictionary.
See also:
:py:meth:`monai.metrics.meandice.compute_dice`
"""
metric_fn = DiceMetric(include_background=include_background, reduction=reduction, num_classes=num_classes)
metric_fn = DiceMetric(
include_background=include_background,
reduction=reduction,
num_classes=num_classes,
return_with_label=return_with_label,
)
super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=save_details)
16 changes: 16 additions & 0 deletions monai/metrics/meandice.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ class DiceMetric(CumulativeIterationMetric):
num_classes: number of input channels (always including the background). When this is None,
``y_pred.shape[1]`` will be used. This option is useful when both ``y_pred`` and ``y`` are
single-channel class indices and the number of classes is not automatically inferred from data.
return_with_label: whether to return the metrics with label, only works when reduction is "mean_batch".
If `True`, use "label_{index}" as the key corresponding to C channels; if 'include_background' is True,
the index begins at "0", otherwise at "1". It can also take a list of label names.
The outcome will then be returned as a dictionary.
"""

Expand All @@ -60,13 +64,15 @@ def __init__(
get_not_nans: bool = False,
ignore_empty: bool = True,
num_classes: int | None = None,
return_with_label: bool | list[str] = False,
) -> None:
super().__init__()
self.include_background = include_background
self.reduction = reduction
self.get_not_nans = get_not_nans
self.ignore_empty = ignore_empty
self.num_classes = num_classes
self.return_with_label = return_with_label
self.dice_helper = DiceHelper(
include_background=self.include_background,
reduction=MetricReduction.NONE,
Expand Down Expand Up @@ -112,6 +118,16 @@ def aggregate(

# do metric reduction
f, not_nans = do_metric_reduction(data, reduction or self.reduction)
if self.reduction == MetricReduction.MEAN_BATCH and self.return_with_label:
_f = {}
if isinstance(self.return_with_label, bool):
for i, v in enumerate(f):
_label_key = f"label_{i+1}" if not self.include_background else f"label_{i}"
_f[_label_key] = round(v.item(), 4)
else:
for key, v in zip(self.return_with_label, f):
_f[key] = round(v.item(), 4)
f = _f
return (f, not_nans) if self.get_not_nans else f


Expand Down
74 changes: 72 additions & 2 deletions tests/test_compute_meandice.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,71 @@
[[0.0000, 0.0000], [0.0000, 0.0000]],
]

# test return_with_label
TEST_CASE_13 = [
{
"include_background": True,
"reduction": "mean_batch",
"get_not_nans": True,
"return_with_label": ["bg", "fg0", "fg1"],
},
{
"y_pred": torch.tensor(
[
[[[1.0, 1.0], [1.0, 0.0]], [[0.0, 1.0], [0.0, 0.0]], [[0.0, 1.0], [1.0, 1.0]]],
[[[1.0, 0.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]]],
]
),
"y": torch.tensor(
[
[[[1.0, 1.0], [1.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]],
[[[0.0, 0.0], [0.0, 1.0]], [[1.0, 1.0], [0.0, 0.0]], [[0.0, 0.0], [1.0, 0.0]]],
]
),
},
{"bg": 0.6786, "fg0": 0.4000, "fg1": 0.6667},
]

# test return_with_label, include_background
TEST_CASE_14 = [
{"include_background": True, "reduction": "mean_batch", "get_not_nans": True, "return_with_label": True},
{
"y_pred": torch.tensor(
[
[[[1.0, 1.0], [1.0, 0.0]], [[0.0, 1.0], [0.0, 0.0]], [[0.0, 1.0], [1.0, 1.0]]],
[[[1.0, 0.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]]],
]
),
"y": torch.tensor(
[
[[[1.0, 1.0], [1.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]],
[[[0.0, 0.0], [0.0, 1.0]], [[1.0, 1.0], [0.0, 0.0]], [[0.0, 0.0], [1.0, 0.0]]],
]
),
},
{"label_0": 0.6786, "label_1": 0.4000, "label_2": 0.6667},
]

# test return_with_label, not include_background
TEST_CASE_15 = [
{"include_background": False, "reduction": "mean_batch", "get_not_nans": True, "return_with_label": True},
{
"y_pred": torch.tensor(
[
[[[1.0, 1.0], [1.0, 0.0]], [[0.0, 1.0], [0.0, 0.0]], [[0.0, 1.0], [1.0, 1.0]]],
[[[1.0, 0.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]]],
]
),
"y": torch.tensor(
[
[[[1.0, 1.0], [1.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]],
[[[0.0, 0.0], [0.0, 1.0]], [[1.0, 1.0], [0.0, 0.0]], [[0.0, 0.0], [1.0, 0.0]]],
]
),
},
{"label_1": 0.4000, "label_2": 0.6667},
]


class TestComputeMeanDice(unittest.TestCase):
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_9, TEST_CASE_11, TEST_CASE_12])
Expand Down Expand Up @@ -223,12 +288,17 @@ def test_value_class(self, input_data, expected_value):
result = dice_metric.aggregate(reduction="none")
np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4)

@parameterized.expand([TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7, TEST_CASE_8])
@parameterized.expand(
[TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7, TEST_CASE_8, TEST_CASE_13, TEST_CASE_14, TEST_CASE_15]
)
def test_nans_class(self, params, input_data, expected_value):
dice_metric = DiceMetric(**params)
dice_metric(**input_data)
result, _ = dice_metric.aggregate()
np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4)
if isinstance(result, dict):
self.assertEqual(result, expected_value)
else:
np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4)


if __name__ == "__main__":
Expand Down

0 comments on commit 85243f5

Please sign in to comment.