Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add device in HoVerNetNuclearTypePostProcessing and HoVerNetInstanceMapPostProcessing #6333

Merged
merged 3 commits into from
Apr 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions monai/apps/pathology/transforms/post/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from monai.transforms.utils_pytorch_numpy_unification import max, maximum, min, sum, unique
from monai.utils import TransformBackends, convert_to_numpy, optional_import
from monai.utils.misc import ensure_tuple_rep
from monai.utils.type_conversion import convert_to_dst_type
from monai.utils.type_conversion import convert_to_dst_type, convert_to_tensor

label, _ = optional_import("scipy.ndimage.measurements", name="label")
disk, _ = optional_import("skimage.morphology", name="disk")
Expand Down Expand Up @@ -671,6 +671,7 @@ class HoVerNetInstanceMapPostProcessing(Transform):
min_num_points: minimum number of points to be considered as a contour. Defaults to 3.
contour_level: an optional value for `skimage.measure.find_contours` to find contours in the array.
If not provided, the level is set to `(max(image) + min(image)) / 2`.
device: target device to put the output Tensor data.
"""

def __init__(
Expand All @@ -686,9 +687,10 @@ def __init__(
watershed_connectivity: int | None = 1,
min_num_points: int = 3,
contour_level: float | None = None,
device: str | torch.device | None = None,
) -> None:
super().__init__()

self.device = device
self.generate_watershed_mask = GenerateWatershedMask(
activation=activation, threshold=mask_threshold, min_object_size=min_object_size
)
Expand Down Expand Up @@ -742,7 +744,7 @@ def __call__( # type: ignore
"centroid": instance_centroid,
"contour": instance_contour,
}

instance_map = convert_to_tensor(instance_map, device=self.device)
return instance_info, instance_map


Expand All @@ -758,13 +760,19 @@ class HoVerNetNuclearTypePostProcessing(Transform):
threshold: an optional float value to threshold to binarize probability map.
If not provided, defaults to 0.5 when activation is not "softmax", otherwise None.
return_type_map: whether to calculate and return pixel-level type map.
device: target device to put the output Tensor data.

"""

def __init__(
self, activation: str | Callable = "softmax", threshold: float | None = None, return_type_map: bool = True
self,
activation: str | Callable = "softmax",
threshold: float | None = None,
return_type_map: bool = True,
device: str | torch.device | None = None,
) -> None:
super().__init__()
self.device = device
self.return_type_map = return_type_map
self.generate_instance_type = GenerateInstanceType()

Expand Down Expand Up @@ -824,5 +832,6 @@ def __call__( # type: ignore
# update instance type map
if type_map is not None:
type_map[instance_map == inst_id] = instance_type
type_map = convert_to_tensor(type_map, device=self.device)

return instance_info, type_map
9 changes: 7 additions & 2 deletions monai/apps/pathology/transforms/post/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from collections.abc import Callable, Hashable, Mapping

import numpy as np
import torch

from monai.apps.pathology.transforms.post.array import (
GenerateDistanceMap,
Expand Down Expand Up @@ -488,6 +489,7 @@ class HoVerNetInstanceMapPostProcessingd(Transform):
min_num_points: minimum number of points to be considered as a contour. Defaults to 3.
contour_level: an optional value for `skimage.measure.find_contours` to find contours in the array.
If not provided, the level is set to `(max(image) + min(image)) / 2`.
device: target device to put the output Tensor data.
"""

def __init__(
Expand All @@ -507,6 +509,7 @@ def __init__(
watershed_connectivity: int | None = 1,
min_num_points: int = 3,
contour_level: float | None = None,
device: str | torch.device | None = None,
) -> None:
super().__init__()
self.instance_map_post_process = HoVerNetInstanceMapPostProcessing(
Expand All @@ -521,6 +524,7 @@ def __init__(
watershed_connectivity=watershed_connectivity,
min_num_points=min_num_points,
contour_level=contour_level,
device=device,
)
self.nuclear_prediction_key = nuclear_prediction_key
self.hover_map_key = hover_map_key
Expand Down Expand Up @@ -553,7 +557,7 @@ class HoVerNetNuclearTypePostProcessingd(Transform):
Defaults to `"instance_info"`.
instance_map_key: the key where instance map is stored. Defaults to `"instance_map"`.
type_map_key: the output key where type map is written. Defaults to `"type_map"`.

device: target device to put the output Tensor data.

"""

Expand All @@ -566,10 +570,11 @@ def __init__(
activation: str | Callable = "softmax",
threshold: float | None = None,
return_type_map: bool = True,
device: str | torch.device | None = None,
) -> None:
super().__init__()
self.type_post_process = HoVerNetNuclearTypePostProcessing(
activation=activation, threshold=threshold, return_type_map=return_type_map
activation=activation, threshold=threshold, return_type_map=return_type_map, device=device
)
self.type_prediction_key = type_prediction_key
self.instance_info_key = instance_info_key
Expand Down