Cucim support for get_mask_edges and get_surface_distance #7008

Merged
27 changes: 14 additions & 13 deletions monai/losses/hausdorff_loss.py
@@ -19,12 +19,11 @@
import warnings
from typing import Callable

import numpy as np
import torch
from torch.nn.modules.loss import _Loss

from monai.metrics.utils import distance_transform_edt
from monai.networks import one_hot
from monai.transforms.utils import distance_transform_edt
from monai.utils import LossReduction


@@ -95,7 +94,7 @@ def __init__(
self.batch = batch

@torch.no_grad()
def distance_field(self, img: np.ndarray) -> np.ndarray:
def distance_field(self, img: torch.Tensor) -> torch.Tensor:
"""Generate distance transform.

Args:
@@ -104,18 +103,20 @@ def distance_field(self, img: np.ndarray) -> np.ndarray:
Returns:
np.ndarray: Distance field.
"""
field = np.zeros_like(img)
field = torch.zeros_like(img)

for batch in range(len(img)):
fg_mask = img[batch] > 0.5
for batch_idx in range(len(img)):
fg_mask = img[batch_idx] > 0.5

if fg_mask.any():
# For cases where the mask is entirely background or entirely foreground
# the distance transform is not well defined for all 1s,
# which always would happen on either foreground or background, so skip
if fg_mask.any() and not fg_mask.all():
fg_dist: torch.Tensor = distance_transform_edt(fg_mask) # type: ignore
bg_mask = ~fg_mask
bg_dist: torch.Tensor = distance_transform_edt(bg_mask) # type: ignore

fg_dist = distance_transform_edt(fg_mask)
bg_dist = distance_transform_edt(bg_mask)

field[batch] = fg_dist + bg_dist
field[batch_idx] = fg_dist + bg_dist

return field
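
For intuition: the field built above is, for each voxel, its distance to the mask boundary. The foreground EDT covers the inside of the mask, the background EDT covers the outside, and their sum covers the whole image. A minimal CPU sketch of the same idea using SciPy (illustrative only; the change above does this per batch item with a torch-native EDT so the tensors can stay on the GPU):

```python
# Illustrative CPU sketch of the distance field built in `distance_field`,
# using SciPy. The change above computes the same thing per batch item with
# a torch-native EDT so the tensors never leave the device.
import numpy as np
from scipy.ndimage import distance_transform_edt

def reference_distance_field(mask: np.ndarray) -> np.ndarray:
    """mask: binary array of shape (H, W) or (H, W, D)."""
    fg = mask > 0.5
    if not fg.any() or fg.all():
        # All-background or all-foreground: the boundary distance is not well
        # defined, so return zeros, matching the skip in the loss above.
        return np.zeros(mask.shape, dtype=float)
    fg_dist = distance_transform_edt(fg)   # distance to background, inside the mask
    bg_dist = distance_transform_edt(~fg)  # distance to foreground, outside the mask
    return fg_dist + bg_dist               # distance to the mask boundary everywhere
```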

@@ -181,8 +182,8 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
for i in range(input.shape[1]):
ch_input = input[:, [i]]
ch_target = target[:, [i]]
pred_dt = torch.from_numpy(self.distance_field(ch_input.detach().cpu().numpy())).float()
target_dt = torch.from_numpy(self.distance_field(ch_target.detach().cpu().numpy())).float()
pred_dt = self.distance_field(ch_input.detach()).float()
target_dt = self.distance_field(ch_target.detach()).float()

pred_error = (ch_input - ch_target) ** 2
distance = pred_dt**self.alpha + target_dt**self.alpha
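
With the distance fields computed in torch, the loss no longer needs the `.cpu().numpy()` round trip in `forward`. A hedged usage sketch of `HausdorffDTLoss` on an accelerator; the keyword names follow the DiceLoss-style signature this loss is believed to share, so treat them as assumptions and verify against the released API:

```python
import torch
from monai.losses import HausdorffDTLoss

device = "cuda" if torch.cuda.is_available() else "cpu"
# Two-class logits (background + foreground) and integer labels, NCHW;
# 3D NCHWD volumes work the same way.
pred = torch.randn(2, 2, 64, 64, device=device, requires_grad=True)
target = torch.randint(0, 2, (2, 1, 64, 64), device=device)

loss_fn = HausdorffDTLoss(softmax=True, to_onehot_y=True)
loss = loss_fn(pred, target)  # distance transforms now run in torch, no NumPy round trip
loss.backward()
```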
58 changes: 37 additions & 21 deletions monai/metrics/hausdorff_distance.py
@@ -11,7 +11,6 @@

from __future__ import annotations

import warnings
from collections.abc import Sequence
from typing import Any

@@ -20,12 +19,12 @@

from monai.metrics.utils import (
do_metric_reduction,
get_mask_edges,
get_edge_surface_distance,
get_surface_distance,
ignore_background,
prepare_spacing,
)
from monai.utils import MetricReduction, convert_data_type
from monai.utils import MetricReduction, convert_data_type, deprecated

from .metric import CumulativeIterationMetric

@@ -180,31 +179,46 @@ def compute_hausdorff_distance(
raise ValueError(f"y_pred and y should have same shapes, got {y_pred.shape} and {y.shape}.")

batch_size, n_class = y_pred.shape[:2]
hd = np.empty((batch_size, n_class))
hd = torch.empty((batch_size, n_class), dtype=torch.float, device=y_pred.device)

img_dim = y_pred.ndim - 2
spacing_list = prepare_spacing(spacing=spacing, batch_size=batch_size, img_dim=img_dim)

for b, c in np.ndindex(batch_size, n_class):
(edges_pred, edges_gt) = get_mask_edges(y_pred[b, c], y[b, c])
if not np.any(edges_gt):
warnings.warn(f"the ground truth of class {c} is all 0, this may result in nan/inf distance.")
if not np.any(edges_pred):
warnings.warn(f"the prediction of class {c} is all 0, this may result in nan/inf distance.")

distance_1 = compute_percent_hausdorff_distance(
edges_pred, edges_gt, distance_metric, percentile, spacing_list[b]
_, distances, _ = get_edge_surface_distance(
y_pred[b, c],
y[b, c],
distance_metric=distance_metric,
spacing=spacing_list[b],
symetric=not directed,
class_index=c,
)
if directed:
hd[b, c] = distance_1
else:
distance_2 = compute_percent_hausdorff_distance(
edges_gt, edges_pred, distance_metric, percentile, spacing_list[b]
)
hd[b, c] = max(distance_1, distance_2)
return convert_data_type(hd, output_type=torch.Tensor, device=y_pred.device, dtype=torch.float)[0]
percentile_distances = [_compute_percentile_hausdorff_distance(d, percentile) for d in distances]
max_distance = torch.max(torch.stack(percentile_distances))
hd[b, c] = max_distance
return hd
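
For a single (batch, class) pair, the new helper returns the mask edges, a tuple of directed surface distances, and (optionally) surface areas; the Hausdorff value is then a percentile (or max) over each direction. A sketch under the signature visible in this diff — `symetric` is the keyword as actually spelled in `monai.metrics.utils`, and the defaults are assumptions:

```python
# Sketch: a 95th-percentile Hausdorff distance for one (batch, class) pair via
# the helper used throughout this PR. Signature inferred from the diff above.
import torch
from monai.metrics.utils import get_edge_surface_distance

def percentile_hd(pred_bc: torch.Tensor, gt_bc: torch.Tensor, percentile: float = 95.0) -> torch.Tensor:
    _, distances, _ = get_edge_surface_distance(
        pred_bc, gt_bc, distance_metric="euclidean", symetric=True
    )
    # `distances` holds the (pred->gt, gt->pred) surface distances; take the worse
    # of the two per-direction percentiles, as compute_hausdorff_distance does above.
    per_direction = [
        torch.quantile(d, percentile / 100) if d.numel() else torch.tensor(float("nan"), device=d.device)
        for d in distances
    ]
    return torch.max(torch.stack(per_direction))
```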


def _compute_percentile_hausdorff_distance(
surface_distance: torch.Tensor, percentile: float | None = None
) -> torch.Tensor:
"""
This function is used to compute the Hausdorff distance.
"""

# for both pred and gt do not have foreground
if surface_distance.shape == (0,):
return torch.tensor(torch.nan, dtype=torch.float, device=surface_distance.device)

if not percentile:
return surface_distance.max() # type: ignore[no-any-return]

if 0 <= percentile <= 100:
return torch.quantile(surface_distance, percentile / 100) # type: ignore[no-any-return]
raise ValueError(f"percentile should be a value between 0 and 100, get {percentile}.")
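
The reduction above is simply `max` when no percentile is requested and `torch.quantile` otherwise. A small worked example:

```python
import torch

d = torch.tensor([0.0, 1.0, 1.0, 2.0, 10.0])  # toy surface distances

d.max()                      # tensor(10.) -> classic (max) Hausdorff distance
torch.quantile(d, 95 / 100)  # tensor(8.4000) -> HD95, less sensitive to the single outlier
```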


@deprecated(since="1.3.0", removed="1.5.0")
def compute_percent_hausdorff_distance(
edges_pred: np.ndarray,
edges_gt: np.ndarray,
@@ -216,7 +230,9 @@ def compute_percent_hausdorff_distance(
This function is used to compute the directed Hausdorff distance.
"""

surface_distance = get_surface_distance(edges_pred, edges_gt, distance_metric=distance_metric, spacing=spacing)
surface_distance: np.ndarray = get_surface_distance(
edges_pred, edges_gt, distance_metric=distance_metric, spacing=spacing
) # type: ignore

# for both pred and gt do not have foreground
if surface_distance.shape == (0,):
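
End to end, the public metric now keeps everything in torch on the input device. A hedged usage sketch with binarized one-hot inputs of shape (batch, class, spatial...), which is what the metric expects:

```python
import torch
from monai.metrics import compute_hausdorff_distance

device = "cuda" if torch.cuda.is_available() else "cpu"
# Binarized one-hot inputs of shape (batch, class, H, W).
y_pred = (torch.rand(2, 2, 64, 64, device=device) > 0.5).float()
y = (torch.rand(2, 2, 64, 64, device=device) > 0.5).float()

hd95 = compute_hausdorff_distance(y_pred, y, include_background=False, percentile=95)
print(hd95.shape, hd95.device)  # torch.Size([2, 1]), same device as the inputs
```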
57 changes: 21 additions & 36 deletions monai/metrics/surface_dice.py
@@ -11,21 +11,14 @@

from __future__ import annotations

import warnings
from collections.abc import Sequence
from typing import Any

import numpy as np
import torch

from monai.metrics.utils import (
do_metric_reduction,
get_mask_edges,
get_surface_distance,
ignore_background,
prepare_spacing,
)
from monai.utils import MetricReduction, convert_data_type
from monai.metrics.utils import do_metric_reduction, get_edge_surface_distance, ignore_background, prepare_spacing
from monai.utils import MetricReduction

from .metric import CumulativeIterationMetric

@@ -251,47 +244,39 @@ def compute_surface_dice(
if any(np.array(class_thresholds) < 0):
raise ValueError("All class thresholds need to be >= 0.")

nsd = np.empty((batch_size, n_class))
nsd = torch.empty((batch_size, n_class), device=y_pred.device, dtype=torch.float)

img_dim = y_pred.ndim - 2
spacing_list = prepare_spacing(spacing=spacing, batch_size=batch_size, img_dim=img_dim)

for b, c in np.ndindex(batch_size, n_class):
(edges_pred, edges_gt), (distances_pred_gt, distances_gt_pred), areas = get_edge_surface_distance( # type: ignore
y_pred[b, c],
y[b, c],
distance_metric=distance_metric,
spacing=spacing_list[b],
use_subvoxels=use_subvoxels,
symetric=True,
class_index=c,
)
boundary_correct: int | torch.Tensor | float
boundary_complete: int | torch.Tensor | float
if not use_subvoxels:
(edges_pred, edges_gt) = get_mask_edges(y_pred[b, c], y[b, c], crop=True)
distances_pred_gt = get_surface_distance(
edges_pred, edges_gt, distance_metric=distance_metric, spacing=spacing_list[b]
)
distances_gt_pred = get_surface_distance(
edges_gt, edges_pred, distance_metric=distance_metric, spacing=spacing_list[b]
)

boundary_complete = len(distances_pred_gt) + len(distances_gt_pred)
boundary_correct = np.sum(distances_pred_gt <= class_thresholds[c]) + np.sum(
boundary_correct = torch.sum(distances_pred_gt <= class_thresholds[c]) + torch.sum(
distances_gt_pred <= class_thresholds[c]
)
else:
_spacing = spacing_list[b] if spacing_list[b] is not None else [1] * img_dim
areas_pred: np.ndarray
areas_gt: np.ndarray
edges_pred, edges_gt, areas_pred, areas_gt = get_mask_edges( # type: ignore
y_pred[b, c], y[b, c], crop=True, spacing=_spacing # type: ignore
)
dist_pred_to_gt = get_surface_distance(edges_pred, edges_gt, distance_metric, spacing=spacing_list[b])
dist_gt_to_pred = get_surface_distance(edges_gt, edges_pred, distance_metric, spacing=spacing_list[b])
areas_pred, areas_gt = areas # type: ignore
areas_gt, areas_pred = areas_gt[edges_gt], areas_pred[edges_pred]
boundary_complete = areas_gt.sum() + areas_pred.sum()
gt_true = areas_gt[dist_gt_to_pred <= class_thresholds[c]].sum() if len(areas_gt) > 0 else 0.0
pred_true = areas_pred[dist_pred_to_gt <= class_thresholds[c]].sum() if len(areas_pred) > 0 else 0.0
boundary_complete = areas_gt.sum() + areas_pred.sum() # type: ignore
gt_true = areas_gt[distances_gt_pred <= class_thresholds[c]].sum() if len(areas_gt) > 0 else 0.0
pred_true = areas_pred[distances_pred_gt <= class_thresholds[c]].sum() if len(areas_pred) > 0 else 0.0
boundary_correct = gt_true + pred_true
if not np.any(edges_gt):
warnings.warn(f"the ground truth of class {c} is all 0, this may result in nan/inf distance.")
if not np.any(edges_pred):
warnings.warn(f"the prediction of class {c} is all 0, this may result in nan/inf distance.")
if boundary_complete == 0:
# the class is neither present in the prediction, nor in the reference segmentation
nsd[b, c] = np.nan
nsd[b, c] = torch.nan
else:
nsd[b, c] = boundary_correct / boundary_complete

return convert_data_type(nsd, output_type=torch.Tensor, device=y_pred.device, dtype=torch.float)[0]
return nsd
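
A hedged usage sketch of the normalized surface dice after this change; thresholds are given per foreground class, in the same physical units as `spacing`:

```python
import torch
import torch.nn.functional as F
from monai.metrics import compute_surface_dice

device = "cuda" if torch.cuda.is_available() else "cpu"
# Build genuinely one-hot inputs: (batch, class, H, W) with class 0 as background.
pred_labels = torch.randint(0, 3, (1, 96, 96), device=device)
gt_labels = torch.randint(0, 3, (1, 96, 96), device=device)
y_pred = F.one_hot(pred_labels, num_classes=3).permute(0, 3, 1, 2).float()
y = F.one_hot(gt_labels, num_classes=3).permute(0, 3, 1, 2).float()

nsd = compute_surface_dice(
    y_pred,
    y,
    class_thresholds=[1.0, 1.0],  # one tolerance per foreground class, in spacing units
    spacing=(1.0, 1.0),           # per-dimension pixel spacing
)
print(nsd)  # shape (1, 2): one score in [0, 1] (or nan) per image and foreground class
```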
33 changes: 11 additions & 22 deletions monai/metrics/surface_distance.py
@@ -11,20 +11,13 @@

from __future__ import annotations

import warnings
from collections.abc import Sequence
from typing import Any

import numpy as np
import torch

from monai.metrics.utils import (
do_metric_reduction,
get_mask_edges,
get_surface_distance,
ignore_background,
prepare_spacing,
)
from monai.metrics.utils import do_metric_reduction, get_edge_surface_distance, ignore_background, prepare_spacing
from monai.utils import MetricReduction, convert_data_type

from .metric import CumulativeIterationMetric
@@ -173,25 +166,21 @@ def compute_average_surface_distance(
raise ValueError(f"y_pred and y should have same shapes, got {y_pred.shape} and {y.shape}.")

batch_size, n_class = y_pred.shape[:2]
asd = np.empty((batch_size, n_class))
asd = torch.empty((batch_size, n_class), dtype=torch.float32, device=y_pred.device)

img_dim = y_pred.ndim - 2
spacing_list = prepare_spacing(spacing=spacing, batch_size=batch_size, img_dim=img_dim)

for b, c in np.ndindex(batch_size, n_class):
(edges_pred, edges_gt) = get_mask_edges(y_pred[b, c], y[b, c])
if not np.any(edges_gt):
warnings.warn(f"the ground truth of class {c} is all 0, this may result in nan/inf distance.")
if not np.any(edges_pred):
warnings.warn(f"the prediction of class {c} is all 0, this may result in nan/inf distance.")
surface_distance = get_surface_distance(
edges_pred, edges_gt, distance_metric=distance_metric, spacing=spacing_list[b]
_, distances, _ = get_edge_surface_distance(
y_pred[b, c],
y[b, c],
distance_metric=distance_metric,
spacing=spacing_list[b],
symetric=symmetric,
class_index=c,
)
if symmetric:
surface_distance_2 = get_surface_distance(
edges_gt, edges_pred, distance_metric=distance_metric, spacing=spacing_list[b]
)
surface_distance = np.concatenate([surface_distance, surface_distance_2])
asd[b, c] = np.nan if surface_distance.shape == (0,) else surface_distance.mean()
surface_distance = torch.cat(distances)
asd[b, c] = torch.nan if surface_distance.shape == (0,) else surface_distance.mean()

return convert_data_type(asd, output_type=torch.Tensor, device=y_pred.device, dtype=torch.float)[0]
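
The average surface distance follows the same pattern; a hedged usage sketch:

```python
import torch
from monai.metrics import compute_average_surface_distance

device = "cuda" if torch.cuda.is_available() else "cpu"
y_pred = (torch.rand(2, 2, 64, 64, device=device) > 0.5).float()
y = (torch.rand(2, 2, 64, 64, device=device) > 0.5).float()

# symmetric=True concatenates the pred->gt and gt->pred distances before
# averaging, mirroring the torch.cat(distances) above.
asd = compute_average_surface_distance(y_pred, y, symmetric=True)
print(asd)  # shape (2, 1), returned on the same device as the inputs
```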