Skip to content

Commit

Permalink
Cucim support for get_mask_edges and get_surface_distance (#7008)
Browse files Browse the repository at this point in the history
### Description
Add support for cucim for `get_mask_edges` and `get_surface_distance`.
This provides significant speedup in surface related metrics. Profiling
on my system gave 3-20x speedups depending on the input shape:
[---------------- (250, 250, 250) -----------------]
                               |    cpu    |   cuda 
1 threads: -----------------------------------------
      random()>0.2             |  26400.8  |  1306.3
      random()>0.5             |  26411.8  |  1399.1
      random()>0.8             |  29993.2  |  1009.5
      create_spherical_seg_3d  |    623.8  |    45.0

Times are in milliseconds (ms).

[--------------- (100, 100, 100) ----------------]
                               |   cpu    |   cuda
1 threads: ---------------------------------------
      random()>0.2             |  1332.5  |  140.2
      random()>0.5             |  1276.3  |  128.1
      random()>0.8             |  1179.2  |   89.1
      create_spherical_seg_3d  |   111.7  |   44.0

Times are in milliseconds (ms).

[---------------- (50, 50, 50) ----------------]
                               |   cpu   |  cuda
1 threads: -------------------------------------
      random()>0.2             |  154.5  |  47.4
      random()>0.5             |  166.7  |  39.3
      random()>0.8             |  165.0  |  38.0
      create_spherical_seg_3d  |   77.2  |  44.4

Times are in milliseconds (ms).

where create_spherical_seg_3d uses the same function from
test_hausdorff_distance, and binarizes random array using
`random(shape)>ratio`.

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [x] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: John Zielke <john.zielke@snkeos.com>
  • Loading branch information
john-zielke-snkeos authored Oct 1, 2023
1 parent 21028ee commit f140e06
Show file tree
Hide file tree
Showing 9 changed files with 262 additions and 147 deletions.
27 changes: 14 additions & 13 deletions monai/losses/hausdorff_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,11 @@
import warnings
from typing import Callable

import numpy as np
import torch
from torch.nn.modules.loss import _Loss

from monai.metrics.utils import distance_transform_edt
from monai.networks import one_hot
from monai.transforms.utils import distance_transform_edt
from monai.utils import LossReduction


Expand Down Expand Up @@ -95,7 +94,7 @@ def __init__(
self.batch = batch

@torch.no_grad()
def distance_field(self, img: np.ndarray) -> np.ndarray:
def distance_field(self, img: torch.Tensor) -> torch.Tensor:
"""Generate distance transform.
Args:
Expand All @@ -104,18 +103,20 @@ def distance_field(self, img: np.ndarray) -> np.ndarray:
Returns:
np.ndarray: Distance field.
"""
field = np.zeros_like(img)
field = torch.zeros_like(img)

for batch in range(len(img)):
fg_mask = img[batch] > 0.5
for batch_idx in range(len(img)):
fg_mask = img[batch_idx] > 0.5

if fg_mask.any():
# For cases where the mask is entirely background or entirely foreground
# the distance transform is not well defined for all 1s,
# which always would happen on either foreground or background, so skip
if fg_mask.any() and not fg_mask.all():
fg_dist: torch.Tensor = distance_transform_edt(fg_mask) # type: ignore
bg_mask = ~fg_mask
bg_dist: torch.Tensor = distance_transform_edt(bg_mask) # type: ignore

fg_dist = distance_transform_edt(fg_mask)
bg_dist = distance_transform_edt(bg_mask)

field[batch] = fg_dist + bg_dist
field[batch_idx] = fg_dist + bg_dist

return field

Expand Down Expand Up @@ -181,8 +182,8 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
for i in range(input.shape[1]):
ch_input = input[:, [i]]
ch_target = target[:, [i]]
pred_dt = torch.from_numpy(self.distance_field(ch_input.detach().cpu().numpy())).float()
target_dt = torch.from_numpy(self.distance_field(ch_target.detach().cpu().numpy())).float()
pred_dt = self.distance_field(ch_input.detach()).float()
target_dt = self.distance_field(ch_target.detach()).float()

pred_error = (ch_input - ch_target) ** 2
distance = pred_dt**self.alpha + target_dt**self.alpha
Expand Down
58 changes: 37 additions & 21 deletions monai/metrics/hausdorff_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@

from __future__ import annotations

import warnings
from collections.abc import Sequence
from typing import Any

Expand All @@ -20,12 +19,12 @@

from monai.metrics.utils import (
do_metric_reduction,
get_mask_edges,
get_edge_surface_distance,
get_surface_distance,
ignore_background,
prepare_spacing,
)
from monai.utils import MetricReduction, convert_data_type
from monai.utils import MetricReduction, convert_data_type, deprecated

from .metric import CumulativeIterationMetric

Expand Down Expand Up @@ -180,31 +179,46 @@ def compute_hausdorff_distance(
raise ValueError(f"y_pred and y should have same shapes, got {y_pred.shape} and {y.shape}.")

batch_size, n_class = y_pred.shape[:2]
hd = np.empty((batch_size, n_class))
hd = torch.empty((batch_size, n_class), dtype=torch.float, device=y_pred.device)

img_dim = y_pred.ndim - 2
spacing_list = prepare_spacing(spacing=spacing, batch_size=batch_size, img_dim=img_dim)

for b, c in np.ndindex(batch_size, n_class):
(edges_pred, edges_gt) = get_mask_edges(y_pred[b, c], y[b, c])
if not np.any(edges_gt):
warnings.warn(f"the ground truth of class {c} is all 0, this may result in nan/inf distance.")
if not np.any(edges_pred):
warnings.warn(f"the prediction of class {c} is all 0, this may result in nan/inf distance.")

distance_1 = compute_percent_hausdorff_distance(
edges_pred, edges_gt, distance_metric, percentile, spacing_list[b]
_, distances, _ = get_edge_surface_distance(
y_pred[b, c],
y[b, c],
distance_metric=distance_metric,
spacing=spacing_list[b],
symetric=not directed,
class_index=c,
)
if directed:
hd[b, c] = distance_1
else:
distance_2 = compute_percent_hausdorff_distance(
edges_gt, edges_pred, distance_metric, percentile, spacing_list[b]
)
hd[b, c] = max(distance_1, distance_2)
return convert_data_type(hd, output_type=torch.Tensor, device=y_pred.device, dtype=torch.float)[0]
percentile_distances = [_compute_percentile_hausdorff_distance(d, percentile) for d in distances]
max_distance = torch.max(torch.stack(percentile_distances))
hd[b, c] = max_distance
return hd


def _compute_percentile_hausdorff_distance(
surface_distance: torch.Tensor, percentile: float | None = None
) -> torch.Tensor:
"""
This function is used to compute the Hausdorff distance.
"""

# for both pred and gt do not have foreground
if surface_distance.shape == (0,):
return torch.tensor(torch.nan, dtype=torch.float, device=surface_distance.device)

if not percentile:
return surface_distance.max() # type: ignore[no-any-return]

if 0 <= percentile <= 100:
return torch.quantile(surface_distance, percentile / 100) # type: ignore[no-any-return]
raise ValueError(f"percentile should be a value between 0 and 100, get {percentile}.")


@deprecated(since="1.3.0", removed="1.5.0")
def compute_percent_hausdorff_distance(
edges_pred: np.ndarray,
edges_gt: np.ndarray,
Expand All @@ -216,7 +230,9 @@ def compute_percent_hausdorff_distance(
This function is used to compute the directed Hausdorff distance.
"""

surface_distance = get_surface_distance(edges_pred, edges_gt, distance_metric=distance_metric, spacing=spacing)
surface_distance: np.ndarray = get_surface_distance(
edges_pred, edges_gt, distance_metric=distance_metric, spacing=spacing
) # type: ignore

# for both pred and gt do not have foreground
if surface_distance.shape == (0,):
Expand Down
57 changes: 21 additions & 36 deletions monai/metrics/surface_dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,14 @@

from __future__ import annotations

import warnings
from collections.abc import Sequence
from typing import Any

import numpy as np
import torch

from monai.metrics.utils import (
do_metric_reduction,
get_mask_edges,
get_surface_distance,
ignore_background,
prepare_spacing,
)
from monai.utils import MetricReduction, convert_data_type
from monai.metrics.utils import do_metric_reduction, get_edge_surface_distance, ignore_background, prepare_spacing
from monai.utils import MetricReduction

from .metric import CumulativeIterationMetric

Expand Down Expand Up @@ -251,47 +244,39 @@ def compute_surface_dice(
if any(np.array(class_thresholds) < 0):
raise ValueError("All class thresholds need to be >= 0.")

nsd = np.empty((batch_size, n_class))
nsd = torch.empty((batch_size, n_class), device=y_pred.device, dtype=torch.float)

img_dim = y_pred.ndim - 2
spacing_list = prepare_spacing(spacing=spacing, batch_size=batch_size, img_dim=img_dim)

for b, c in np.ndindex(batch_size, n_class):
(edges_pred, edges_gt), (distances_pred_gt, distances_gt_pred), areas = get_edge_surface_distance( # type: ignore
y_pred[b, c],
y[b, c],
distance_metric=distance_metric,
spacing=spacing_list[b],
use_subvoxels=use_subvoxels,
symetric=True,
class_index=c,
)
boundary_correct: int | torch.Tensor | float
boundary_complete: int | torch.Tensor | float
if not use_subvoxels:
(edges_pred, edges_gt) = get_mask_edges(y_pred[b, c], y[b, c], crop=True)
distances_pred_gt = get_surface_distance(
edges_pred, edges_gt, distance_metric=distance_metric, spacing=spacing_list[b]
)
distances_gt_pred = get_surface_distance(
edges_gt, edges_pred, distance_metric=distance_metric, spacing=spacing_list[b]
)

boundary_complete = len(distances_pred_gt) + len(distances_gt_pred)
boundary_correct = np.sum(distances_pred_gt <= class_thresholds[c]) + np.sum(
boundary_correct = torch.sum(distances_pred_gt <= class_thresholds[c]) + torch.sum(
distances_gt_pred <= class_thresholds[c]
)
else:
_spacing = spacing_list[b] if spacing_list[b] is not None else [1] * img_dim
areas_pred: np.ndarray
areas_gt: np.ndarray
edges_pred, edges_gt, areas_pred, areas_gt = get_mask_edges( # type: ignore
y_pred[b, c], y[b, c], crop=True, spacing=_spacing # type: ignore
)
dist_pred_to_gt = get_surface_distance(edges_pred, edges_gt, distance_metric, spacing=spacing_list[b])
dist_gt_to_pred = get_surface_distance(edges_gt, edges_pred, distance_metric, spacing=spacing_list[b])
areas_pred, areas_gt = areas # type: ignore
areas_gt, areas_pred = areas_gt[edges_gt], areas_pred[edges_pred]
boundary_complete = areas_gt.sum() + areas_pred.sum()
gt_true = areas_gt[dist_gt_to_pred <= class_thresholds[c]].sum() if len(areas_gt) > 0 else 0.0
pred_true = areas_pred[dist_pred_to_gt <= class_thresholds[c]].sum() if len(areas_pred) > 0 else 0.0
boundary_complete = areas_gt.sum() + areas_pred.sum() # type: ignore
gt_true = areas_gt[distances_gt_pred <= class_thresholds[c]].sum() if len(areas_gt) > 0 else 0.0
pred_true = areas_pred[distances_pred_gt <= class_thresholds[c]].sum() if len(areas_pred) > 0 else 0.0
boundary_correct = gt_true + pred_true
if not np.any(edges_gt):
warnings.warn(f"the ground truth of class {c} is all 0, this may result in nan/inf distance.")
if not np.any(edges_pred):
warnings.warn(f"the prediction of class {c} is all 0, this may result in nan/inf distance.")
if boundary_complete == 0:
# the class is neither present in the prediction, nor in the reference segmentation
nsd[b, c] = np.nan
nsd[b, c] = torch.nan
else:
nsd[b, c] = boundary_correct / boundary_complete

return convert_data_type(nsd, output_type=torch.Tensor, device=y_pred.device, dtype=torch.float)[0]
return nsd
33 changes: 11 additions & 22 deletions monai/metrics/surface_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,13 @@

from __future__ import annotations

import warnings
from collections.abc import Sequence
from typing import Any

import numpy as np
import torch

from monai.metrics.utils import (
do_metric_reduction,
get_mask_edges,
get_surface_distance,
ignore_background,
prepare_spacing,
)
from monai.metrics.utils import do_metric_reduction, get_edge_surface_distance, ignore_background, prepare_spacing
from monai.utils import MetricReduction, convert_data_type

from .metric import CumulativeIterationMetric
Expand Down Expand Up @@ -173,25 +166,21 @@ def compute_average_surface_distance(
raise ValueError(f"y_pred and y should have same shapes, got {y_pred.shape} and {y.shape}.")

batch_size, n_class = y_pred.shape[:2]
asd = np.empty((batch_size, n_class))
asd = torch.empty((batch_size, n_class), dtype=torch.float32, device=y_pred.device)

img_dim = y_pred.ndim - 2
spacing_list = prepare_spacing(spacing=spacing, batch_size=batch_size, img_dim=img_dim)

for b, c in np.ndindex(batch_size, n_class):
(edges_pred, edges_gt) = get_mask_edges(y_pred[b, c], y[b, c])
if not np.any(edges_gt):
warnings.warn(f"the ground truth of class {c} is all 0, this may result in nan/inf distance.")
if not np.any(edges_pred):
warnings.warn(f"the prediction of class {c} is all 0, this may result in nan/inf distance.")
surface_distance = get_surface_distance(
edges_pred, edges_gt, distance_metric=distance_metric, spacing=spacing_list[b]
_, distances, _ = get_edge_surface_distance(
y_pred[b, c],
y[b, c],
distance_metric=distance_metric,
spacing=spacing_list[b],
symetric=symmetric,
class_index=c,
)
if symmetric:
surface_distance_2 = get_surface_distance(
edges_gt, edges_pred, distance_metric=distance_metric, spacing=spacing_list[b]
)
surface_distance = np.concatenate([surface_distance, surface_distance_2])
asd[b, c] = np.nan if surface_distance.shape == (0,) else surface_distance.mean()
surface_distance = torch.cat(distances)
asd[b, c] = torch.nan if surface_distance.shape == (0,) else surface_distance.mean()

return convert_data_type(asd, output_type=torch.Tensor, device=y_pred.device, dtype=torch.float)[0]
Loading

0 comments on commit f140e06

Please sign in to comment.