Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GridPatch with both count and threshold filtering #6055

Merged
merged 20 commits into from
Mar 1, 2023
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 56 additions & 31 deletions monai/transforms/spatial/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@
optional_import,
)
from monai.utils.deprecate_utils import deprecated_arg
from monai.utils.enums import GridPatchSort, PytorchPadMode, TraceKeys, TransformBackends, WSIPatchKeys
from monai.utils.enums import GridPatchSort, PatchKeys, PytorchPadMode, TraceKeys, TransformBackends
from monai.utils.misc import ImageMetaKey as Key
from monai.utils.module import look_up_option
from monai.utils.type_conversion import convert_data_type, get_equivalent_dtype, get_torch_dtype_from_string
Expand Down Expand Up @@ -3139,14 +3139,18 @@ class GridPatch(Transform, MultiSampleTrait):
Args:
patch_size: size of patches to generate slices for, 0 or None selects whole dimension
offset: offset of starting position in the array, default is 0 for each dimension.
num_patches: number of patches to return. Defaults to None, which returns all the available patches.
If the required patches are more than the available patches, padding will be applied.
num_patches: number of patches (or maximum number of patches) to return.
If the requested number of patches is greater than the number of available patches,
padding will be applied to provide exactly `num_patches` patches unless `threshold` is set.
Defaults to None, which returns all the available patches.
overlap: the amount of overlap of neighboring patches in each dimension (a value between 0.0 and 1.0).
If only one float number is given, it will be applied to all dimensions. Defaults to 0.0.
sort_fn: when `num_patches` is provided, it determines if keep patches with highest values (`"max"`),
lowest values (`"min"`), or in their default order (`None`). Default to None.
threshold: a value to keep only the patches whose sum of intensities is less than the threshold.
Defaults to no filtering.
threshold_first: whether to apply threshold filtering before limiting the number of patches to `num_patches`.
drbeh marked this conversation as resolved.
Show resolved Hide resolved
Defaults to True.
pad_mode: refer to NumpyPadMode and PytorchPadMode. If None, no padding will be applied. Defaults to ``"constant"``.
pad_kwargs: other arguments for the `np.pad` or `torch.pad` function.

Expand All @@ -3165,6 +3169,7 @@ def __init__(
overlap: Sequence[float] | float = 0.0,
sort_fn: str | None = None,
threshold: float | None = None,
threshold_first: bool = True,
pad_mode: str = PytorchPadMode.CONSTANT,
**pad_kwargs,
):
Expand All @@ -3176,27 +3181,30 @@ def __init__(
self.num_patches = num_patches
self.sort_fn = sort_fn.lower() if sort_fn else None
self.threshold = threshold
self.threshold_first = threshold_first

def filter_threshold(self, image_np: np.ndarray, locations: np.ndarray):
"""
Filter the patches and their locations according to a threshold
Filter the patches and their locations according to a threshold.

Args:
image_np: a numpy.ndarray representing a stack of patches
locations: a numpy.ndarray representing the stack of location of each patch
image_np: a numpy.ndarray representing a stack of patches.
locations: a numpy.ndarray representing the stack of locations of each patch.

Returns:
tuple[numpy.ndarray, numpy.ndarray]: tuple of filtered patches and locations.
"""
if self.threshold is not None:
n_dims = len(image_np.shape)
idx = np.argwhere(image_np.sum(axis=tuple(range(1, n_dims))) < self.threshold).reshape(-1)
image_np = image_np[idx]
locations = locations[idx]
return image_np, locations
n_dims = len(image_np.shape)
idx = np.argwhere(image_np.sum(axis=tuple(range(1, n_dims))) < self.threshold).reshape(-1)
return image_np[idx], locations[idx]

def filter_count(self, image_np: np.ndarray, locations: np.ndarray):
"""
Sort the patches based on the sum of their intensity, and just keep `self.num_patches` of them.

Args:
image_np: a numpy.ndarray representing a stack of patches
locations: a numpy.ndarray representing the stack of location of each patch
image_np: a numpy.ndarray representing a stack of patches.
locations: a numpy.ndarray representing the stack of locations of each patch.
"""
if self.sort_fn is None:
image_np = image_np[: self.num_patches]
Expand All @@ -3214,7 +3222,17 @@ def filter_count(self, image_np: np.ndarray, locations: np.ndarray):
locations = locations[idx]
return image_np, locations

def __call__(self, array: NdarrayOrTensor):
def __call__(self, array: NdarrayOrTensor) -> MetaTensor:
"""
Extract the patches (sweeping the entire image in a row-major sliding-window manner with possible overlaps).

Args:
array: an input image as `numpy.ndarray` or `torch.Tensor`

Returns:
MetaTensor: the extracted patches as a single tensor (with patch dimension as the first dimension),
with defined `PatchKeys.LOCATION` and `PatchKeys.COUNT` metadata.
"""
# create the patch iterator which sweeps the image row-by-row
array_np, *_ = convert_data_type(array, np.ndarray)
patch_iterator = iter_patch(
Expand All @@ -3228,29 +3246,36 @@ def __call__(self, array: NdarrayOrTensor):
)
patches = list(zip(*patch_iterator))
patched_image = np.array(patches[0])
locations = np.array(patches[1])[:, 1:, 0] # only keep the starting location
del patches[0]
locations = np.array(patches[0])[:, 1:, 0] # only keep the starting location
del patches[0]

# Filter patches
if self.num_patches:
patched_image, locations = self.filter_count(patched_image, locations)
drbeh marked this conversation as resolved.
Show resolved Hide resolved
elif self.threshold:
# Apply threshold filter before filtering by count (if threshold_first is set)
if self.threshold_first and self.threshold is not None:
patched_image, locations = self.filter_threshold(patched_image, locations)

# Pad the patch list to have the requested number of patches
if self.num_patches:
padding = self.num_patches - len(patched_image)
if padding > 0:
patched_image = np.pad(
patched_image,
[[0, padding], [0, 0]] + [[0, 0]] * len(self.patch_size),
constant_values=self.pad_kwargs.get("constant_values", 0),
)
locations = np.pad(locations, [[0, padding], [0, 0]], constant_values=0)
# Limit number of patches
patched_image, locations = self.filter_count(patched_image, locations)
# Pad the patch list to have the requested number of patches
if self.threshold is None:
padding = self.num_patches - len(patched_image)
if padding > 0:
patched_image = np.pad(
drbeh marked this conversation as resolved.
Show resolved Hide resolved
patched_image,
[[0, padding], [0, 0]] + [[0, 0]] * len(self.patch_size),
constant_values=self.pad_kwargs.get("constant_values", 0),
)
locations = np.pad(locations, [[0, padding], [0, 0]], constant_values=0)

# Apply threshold filter after filtering by count (if threshold_first is not set)
if not self.threshold_first and self.threshold is not None:
patched_image, locations = self.filter_threshold(patched_image, locations)

# Convert to MetaTensor
metadata = array.meta if isinstance(array, MetaTensor) else MetaTensor.get_default_meta()
metadata[WSIPatchKeys.LOCATION] = locations.T
metadata[WSIPatchKeys.COUNT] = len(locations)
metadata[PatchKeys.LOCATION] = locations.T
metadata[PatchKeys.COUNT] = len(locations)
metadata["spatial_shape"] = np.tile(np.array(self.patch_size), (len(locations), 1)).T
output = MetaTensor(x=patched_image, meta=metadata)
output.is_batch = True
Expand Down
21 changes: 18 additions & 3 deletions tests/test_grid_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,16 +47,26 @@
A,
[np.zeros((3, 3, 3)), np.pad(A[:, :1, 1:4], ((0, 0), (2, 0), (0, 0)), mode="constant")],
]
# Only threshold filtering
TEST_CASE_13 = [{"patch_size": (2, 2), "threshold": 50.0}, A, [A11]]
TEST_CASE_14 = [{"patch_size": (2, 2), "threshold": 150.0}, A, [A11, A12, A21]]
# threshold filtering with num_patches more than available patches (no effect)
TEST_CASE_15 = [{"patch_size": (2, 2), "num_patches": 3, "threshold": 50.0}, A, [A11]]
# threshold filtering with num_patches less than available patches (count filtering)
TEST_CASE_16 = [{"patch_size": (2, 2), "num_patches": 2, "threshold": 150.0}, A, [A11, A12]]
# threshold filtering before count filtering
TEST_CASE_17 = [{"patch_size": (2, 2), "num_patches": 2, "threshold": -50.0, "threshold_first": True}, -A, [-A12, -A21]]
# threshold filtering after count filtering (causes desirable or undesirable data reduction)
TEST_CASE_18 = [{"patch_size": (2, 2), "num_patches": 2, "threshold": -50.0, "threshold_first": False}, -A, [-A12]]

TEST_CASE_MEAT_0 = [
TEST_CASE_META_0 = [
{"patch_size": (2, 2)},
A,
[A11, A12, A21, A22],
[{"location": [0, 0]}, {"location": [0, 2]}, {"location": [2, 0]}, {"location": [2, 2]}],
]

TEST_CASE_MEAT_1 = [
TEST_CASE_META_1 = [
{"patch_size": (2, 2)},
MetaTensor(x=A, meta={"path": "path/to/file"}),
[A11, A12, A21, A22],
Expand Down Expand Up @@ -84,6 +94,11 @@
TEST_CASES.append([p, *TEST_CASE_11])
TEST_CASES.append([p, *TEST_CASE_12])
TEST_CASES.append([p, *TEST_CASE_13])
TEST_CASES.append([p, *TEST_CASE_14])
TEST_CASES.append([p, *TEST_CASE_15])
TEST_CASES.append([p, *TEST_CASE_16])
TEST_CASES.append([p, *TEST_CASE_17])
TEST_CASES.append([p, *TEST_CASE_18])


class TestGridPatch(unittest.TestCase):
Expand All @@ -96,7 +111,7 @@ def test_grid_patch(self, in_type, input_parameters, image, expected):
for output_patch, expected_patch in zip(output, expected):
assert_allclose(output_patch, expected_patch, type_test=False)

@parameterized.expand([TEST_CASE_MEAT_0, TEST_CASE_MEAT_1])
@parameterized.expand([TEST_CASE_META_0, TEST_CASE_META_1])
@SkipIfBeforePyTorchVersion((1, 9, 1))
def test_grid_patch_meta(self, input_parameters, image, expected, expected_meta):
set_track_meta(True)
Expand Down
23 changes: 23 additions & 0 deletions tests/test_grid_patchd.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,25 @@
{"image": A},
[np.zeros((3, 3, 3)), np.pad(A[:, :1, 1:4], ((0, 0), (2, 0), (0, 0)), mode="constant")],
]
# Only threshold filtering
TEST_CASE_13 = [{"patch_size": (2, 2), "threshold": 50.0}, {"image": A}, [A11]]
TEST_CASE_14 = [{"patch_size": (2, 2), "threshold": 150.0}, {"image": A}, [A11, A12, A21]]
# threshold filtering with num_patches more than available patches (no effect)
TEST_CASE_15 = [{"patch_size": (2, 2), "threshold": 50.0, "num_patches": 3}, {"image": A}, [A11]]
# threshold filtering with num_patches less than available patches (count filtering)
TEST_CASE_16 = [{"patch_size": (2, 2), "threshold": 150.0, "num_patches": 2}, {"image": A}, [A11, A12]]
# threshold filtering before count filtering
TEST_CASE_17 = [
{"patch_size": (2, 2), "num_patches": 2, "threshold": -50.0, "threshold_first": True},
{"image": -A},
[-A12, -A21],
]
# threshold filtering after count filtering (causes desirable or undesirable data reduction)
TEST_CASE_18 = [
{"patch_size": (2, 2), "num_patches": 2, "threshold": -50.0, "threshold_first": False},
{"image": -A},
[-A12],
]

TEST_SINGLE = []
for p in TEST_NDARRAYS:
Expand All @@ -64,6 +82,11 @@
TEST_SINGLE.append([p, *TEST_CASE_11])
TEST_SINGLE.append([p, *TEST_CASE_12])
TEST_SINGLE.append([p, *TEST_CASE_13])
TEST_SINGLE.append([p, *TEST_CASE_14])
TEST_SINGLE.append([p, *TEST_CASE_15])
TEST_SINGLE.append([p, *TEST_CASE_16])
TEST_SINGLE.append([p, *TEST_CASE_17])
TEST_SINGLE.append([p, *TEST_CASE_18])


class TestGridPatchd(unittest.TestCase):
Expand Down