Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement distance_transform_edt and the DistanceTransformEDT transform #6981

Merged
merged 26 commits into from
Sep 27, 2023
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
48a2873
Implement distance_transform_edt and the DistanceTransformEDT transform
matt3o Sep 13, 2023
cbac352
Fix code style
matt3o Sep 13, 2023
150d59b
Add test_distance_transform_edt to min_tests.py
matt3o Sep 13, 2023
2349f89
Add DistanceTransformEDTd
matt3o Sep 13, 2023
e3b3846
Update docs
matt3o Sep 13, 2023
62cbdce
Fix test
matt3o Sep 14, 2023
f32841d
Add typing
matt3o Sep 14, 2023
9a88bae
Fix typing for sampling argument
matt3o Sep 14, 2023
c4ed6c5
Fix typing return value
matt3o Sep 14, 2023
f76957e
Merge branch 'dev' into gpu_edt_transform
wyli Sep 14, 2023
6eccc4a
Update docs to match the code
matt3o Sep 15, 2023
4e26689
CuPy now allows 4D channel-wise input
matt3o Sep 19, 2023
2e97953
Merge branch 'dev' into gpu_edt_transform
matt3o Sep 19, 2023
7943ffd
fixes format
wyli Sep 19, 2023
101cc62
Update monai/transforms/post/array.py
matt3o Sep 23, 2023
24814c4
Apply suggestions from code review
matt3o Sep 23, 2023
e98b10a
Add test for 4D input
matt3o Sep 20, 2023
6d19ae4
Remove force_scipy flag
matt3o Sep 24, 2023
1f60c62
Remove force_scipy flag
matt3o Sep 24, 2023
3cccb54
Rework distance_transform_edt to include more parameters
matt3o Sep 26, 2023
582efb4
Code styling
matt3o Sep 26, 2023
2bb17e1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 26, 2023
aa31e31
DCO Remediation Commit for Matthias Hadlich <matthiashadlich@posteo.de>
matt3o Sep 26, 2023
94db8ea
Final fixes
matt3o Sep 27, 2023
d0075c5
Fix typing
matt3o Sep 27, 2023
d4da2af
Merge branch 'dev' into gpu_edt_transform
wyli Sep 27, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions docs/source/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -602,6 +602,12 @@ Post-processing
:members:
:special-members: __call__

`DistanceTransformEDT`
"""""""""""""""""""""""""""""""
.. autoclass:: DistanceTransformEDT
:members:
:special-members: __call__

`RemoveSmallObjects`
""""""""""""""""""""
.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/RemoveSmallObjects.png
Expand Down Expand Up @@ -1640,6 +1646,12 @@ Post-processing (Dict)
:members:
:special-members: __call__

`DistanceTransformEDTd`
""""""""""""""""""""""""""""""""
.. autoclass:: DistanceTransformEDTd
:members:
:special-members: __call__

`RemoveSmallObjectsd`
"""""""""""""""""""""
.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/RemoveSmallObjectsd.png
Expand Down
4 changes: 4 additions & 0 deletions monai/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,7 @@
from .post.array import (
Activations,
AsDiscrete,
DistanceTransformEDT,
FillHoles,
Invert,
KeepLargestConnectedComponent,
Expand All @@ -295,6 +296,9 @@
AsDiscreteD,
AsDiscreted,
AsDiscreteDict,
DistanceTransformEDTd,
DistanceTransformEDTD,
DistanceTransformEDTDict,
Ensembled,
EnsembleD,
EnsembleDict,
Expand Down
29 changes: 29 additions & 0 deletions monai/transforms/post/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from monai.transforms.utility.array import ToTensor
from monai.transforms.utils import (
convert_applied_interp_mode,
distance_transform_edt,
fill_holes,
get_largest_connected_component_mask,
get_unique_labels,
Expand All @@ -53,6 +54,7 @@
"SobelGradients",
"VoteEnsemble",
"Invert",
"DistanceTransformEDT",
]


Expand Down Expand Up @@ -936,3 +938,30 @@ def __call__(self, image: NdarrayOrTensor) -> torch.Tensor:
grads = convert_to_dst_type(grads.squeeze(0), image_tensor)[0]

return grads


class DistanceTransformEDT(Transform):
"""
Applies the Euclidean distance transform on the input.

Either GPU based with CuPy / cuCIM or CPU based with scipy.ndimage.
Choice only depends on cuCIM being available.
Note that the calculations can deviate, for details look into the cuCIM about distance_transform_edt().
"""

backend = [TransformBackends.NUMPY, TransformBackends.CUPY]

def __init__(self, sampling: None | float | list[float] = None, force_scipy: bool = False) -> None:
super().__init__()
self.force_scipy = force_scipy
self.sampling = sampling

def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
"""
Args:
img: shape must 2D or 3D for cupy, otherwise no restrictions
wyli marked this conversation as resolved.
Show resolved Hide resolved

Returns:
An array with the same shape and data type as img
"""
return distance_transform_edt(img=img, sampling=self.sampling, force_scipy=self.force_scipy)
44 changes: 44 additions & 0 deletions monai/transforms/post/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from monai.transforms.post.array import (
Activations,
AsDiscrete,
DistanceTransformEDT,
FillHoles,
KeepLargestConnectedComponent,
LabelFilter,
Expand Down Expand Up @@ -91,6 +92,9 @@
"VoteEnsembleD",
"VoteEnsembleDict",
"VoteEnsembled",
"DistanceTransformEDTd",
"DistanceTransformEDTD",
"DistanceTransformEDTDict",
]

DEFAULT_POST_FIX = PostFix.meta()
Expand Down Expand Up @@ -855,6 +859,45 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N
return d


class DistanceTransformEDTd(MapTransform):
"""
Applies the Euclidean distance transform on the input.

Either GPU based with CuPy / cuCIM or CPU based with scipy.ndimage.
Choice only depends on cuCIM being available.
Note that the calculations can deviate, for details look into the cuCIM about distance_transform_edt().

Args:
keys: keys of the corresponding items to model output.
allow_missing_keys: don't raise exception if key is missing.
sampling: Spacing of elements along each dimension. If a sequence, must be of length equal to the input rank;
if a single number, this is used for all axes. If not specified, a grid spacing of unity is implied.
force_scipy: Force the CPU based scipy implementation of the euclidean distance transform

"""

backend = DistanceTransformEDT.backend

def __init__(
self,
keys: KeysCollection,
allow_missing_keys: bool = False,
sampling: None | float | list[float] = None,
force_scipy: bool = False,
) -> None:
super().__init__(keys, allow_missing_keys)
self.force_scipy = force_scipy
self.sampling = sampling
self.distance_transform = DistanceTransformEDT(sampling=self.sampling, force_scipy=self.force_scipy)

def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Mapping[Hashable, NdarrayOrTensor]:
d = dict(data)
for key in self.key_iterator(d):
d[key] = self.distance_transform(img=d[key])

return d


ActivationsD = ActivationsDict = Activationsd
AsDiscreteD = AsDiscreteDict = AsDiscreted
FillHolesD = FillHolesDict = FillHolesd
Expand All @@ -869,3 +912,4 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N
VoteEnsembleD = VoteEnsembleDict = VoteEnsembled
EnsembleD = EnsembleDict = Ensembled
SobelGradientsD = SobelGradientsDict = SobelGradientsd
DistanceTransformEDTD = DistanceTransformEDTDict = DistanceTransformEDTd
46 changes: 43 additions & 3 deletions monai/transforms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,17 @@
pytorch_after,
)
from monai.utils.enums import TransformBackends
from monai.utils.type_conversion import convert_data_type, convert_to_cupy, convert_to_dst_type, convert_to_tensor
from monai.utils.type_conversion import (
convert_data_type,
convert_to_cupy,
convert_to_dst_type,
convert_to_numpy,
convert_to_tensor,
)

measure, has_measure = optional_import("skimage.measure", "0.14.2", min_version)
morphology, has_morphology = optional_import("skimage.morphology")
ndimage, _ = optional_import("scipy.ndimage")
ndimage, has_ndimage = optional_import("scipy.ndimage")
cp, has_cp = optional_import("cupy")
cp_ndarray, _ = optional_import("cupy", name="ndarray")
exposure, has_skimage = optional_import("skimage.exposure")
Expand Down Expand Up @@ -124,6 +130,7 @@
"reset_ops_id",
"resolves_modes",
"has_status_keys",
"distance_transform_edt",
]


Expand Down Expand Up @@ -2012,7 +2019,7 @@ def has_status_keys(data: torch.Tensor, status_key: Any, default_message: str =

Status keys are defined in :class:`TraceStatusKeys<monai.utils.enums.TraceStatusKeys>`.

This function also accepts:
This fun ction also accepts:
matt3o marked this conversation as resolved.
Show resolved Hide resolved

* dictionaries of tensors
* lists or tuples of tensors
Expand Down Expand Up @@ -2051,5 +2058,38 @@ def has_status_keys(data: torch.Tensor, status_key: Any, default_message: str =
return True, None


def distance_transform_edt(
img: NdarrayOrTensor, sampling: None | float | list[float] = None, force_scipy: bool = False
) -> NdarrayOrTensor:
"""
Euclidean distance transform, either GPU based with CuPy / cuCIM
or CPU based with scipy.ndimage.
Choice only depends on cuCIM being available.

Args:
img: Input image on which the distance transform shall be run
sampling: Spacing of elements along each dimension. If a sequence, must be of length equal to the input rank;
matt3o marked this conversation as resolved.
Show resolved Hide resolved
if a single number, this is used for all axes. If not specified, a grid spacing of unity is implied.
force_scipy: Force the CPU based scipy implementation of the euclidean distance transform
"""
distance_transform_edt, has_cucim = optional_import(
"cucim.core.operations.morphology", name="distance_transform_edt"
)

if has_cp and has_cucim and not force_scipy:
wyli marked this conversation as resolved.
Show resolved Hide resolved
img_ = convert_to_cupy(img)
# Only accepts 2D and 3D input as of 09-2023
# TODO Add check and switch to scipy then?
distance = distance_transform_edt(img_, sampling=sampling)
else:
if not has_ndimage:
raise RuntimeError("scipy.ndimage required if cupy is not available")
img_ = convert_to_numpy(img)
distance = ndimage.distance_transform_edt(img_, sampling=sampling)

out = convert_to_dst_type(distance, dst=img, dtype=distance.dtype)[0]
return out


if __name__ == "__main__":
print_transform_backends()
1 change: 1 addition & 0 deletions tests/min_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def run_testsuit():
"test_deepgrow_transforms",
"test_detect_envelope",
"test_dints_network",
"test_distance_transform_edt",
"test_efficientnet",
"test_ensemble_evaluator",
"test_ensure_channel_first",
Expand Down
121 changes: 121 additions & 0 deletions tests/test_distance_transform_edt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import unittest

import numpy as np
from parameterized import parameterized

from monai.transforms import DistanceTransformEDT, DistanceTransformEDTd
from tests.utils import HAS_CUPY, assert_allclose, optional_import, skip_if_no_cuda

momorphology, has_cucim = optional_import("cucim.core.operations.morphology")
ndimage, has_ndimage = optional_import("scipy.ndimage")
cp, _ = optional_import("cupy")

TEST_CASES = [
[
wyli marked this conversation as resolved.
Show resolved Hide resolved
np.array(
([0, 1, 1, 1, 1], [0, 0, 1, 1, 1], [0, 1, 1, 1, 1], [0, 1, 1, 1, 0], [0, 1, 1, 0, 0]), dtype=np.float32
),
np.array(
[
[0.0, 1.0, 1.4142, 2.2361, 3.0],
[0.0, 0.0, 1.0, 2.0, 2.0],
[0.0, 1.0, 1.4142, 1.4142, 1.0],
[0.0, 1.0, 1.4142, 1.0, 0.0],
[0.0, 1.0, 1.0, 0.0, 0.0],
]
),
]
]

SAMPLING_TEST_CASES = [
[
2,
np.array(
([0, 1, 1, 1, 1], [0, 0, 1, 1, 1], [0, 1, 1, 1, 1], [0, 1, 1, 1, 0], [0, 1, 1, 0, 0]), dtype=np.float32
),
np.array(
[
[0.0, 2.0, 2.828427, 4.472136, 6.0],
[0.0, 0.0, 2.0, 4.0, 4.0],
[0.0, 2.0, 2.828427, 2.828427, 2.0],
[0.0, 2.0, 2.828427, 2.0, 0.0],
[0.0, 2.0, 2.0, 0.0, 0.0],
]
),
]
]

RAISES_TEST_CASES = (
[
np.array(
[[[[0, 1, 1, 1, 1], [0, 0, 1, 1, 1], [0, 1, 1, 1, 1], [0, 1, 1, 1, 0], [0, 1, 1, 0, 0]]]], dtype=np.float32
)
],
)


class TestDistanceTransformEDT(unittest.TestCase):
@parameterized.expand(TEST_CASES)
def test_scipy_transform(self, input, expected_output):
transform = DistanceTransformEDT(force_scipy=True)
output = transform(input)
assert_allclose(output, expected_output, atol=1e-4, rtol=1e-4, type_test=False)

@parameterized.expand(TEST_CASES)
def test_scipy_transformd(self, input, expected_output):
transform = DistanceTransformEDTd(keys=("to_transform",), force_scipy=True)
data = {"to_transform": input}
data_ = transform(data)
output = data_["to_transform"]
assert_allclose(output, expected_output, atol=1e-4, rtol=1e-4, type_test=False)

@parameterized.expand(TEST_CASES)
@skip_if_no_cuda
@unittest.skipUnless(HAS_CUPY, "CuPy is required.")
@unittest.skipUnless(momorphology, "cuCIM transforms are required.")
def test_cucim_transform(self, input, expected_output):
transform = DistanceTransformEDT()
output = transform(input)
assert_allclose(cp.asnumpy(output), cp.asnumpy(expected_output), atol=1e-4, rtol=1e-4, type_test=False)

@parameterized.expand(SAMPLING_TEST_CASES)
def test_sampling(self, sampling, input, expected_output):
transform = DistanceTransformEDT(force_scipy=True, sampling=sampling)
output = transform(input)
assert_allclose(output, expected_output, atol=1e-4, rtol=1e-4, type_test=False)

@parameterized.expand(SAMPLING_TEST_CASES)
@skip_if_no_cuda
@unittest.skipUnless(HAS_CUPY, "CuPy is required.")
@unittest.skipUnless(momorphology, "cuCIM transforms are required.")
def test_cucim_sampling(self, sampling, input, expected_output):
transform = DistanceTransformEDT(sampling=sampling)
output = transform(input)
assert_allclose(cp.asnumpy(output), cp.asnumpy(expected_output), atol=1e-4, rtol=1e-4, type_test=False)

# @skip_if_no_cuda
# @unittest.skipUnless(HAS_CUPY, "CuPy is required.")
# @unittest.skipUnless(momorphology, "cuCIM transforms are required.")
# @parameterized.expand(RAISES_TEST_CASES)
# def test_cucim_raises(self, raises):
# """Currently only 2D and 3D images are supported by CuPy. This test checks for the according error message"""
# transform = DistanceTransformEDT()
# with self.assertRaises(NotImplementedError):
# output = transform(raises)


if __name__ == "__main__":
unittest.main()
Loading