diff --git a/monai/apps/pathology/transforms/post/array.py b/monai/apps/pathology/transforms/post/array.py index 5289dc101c..248ff24bec 100644 --- a/monai/apps/pathology/transforms/post/array.py +++ b/monai/apps/pathology/transforms/post/array.py @@ -31,7 +31,7 @@ from monai.transforms.utils_pytorch_numpy_unification import max, maximum, min, sum, unique from monai.utils import TransformBackends, convert_to_numpy, optional_import from monai.utils.misc import ensure_tuple_rep -from monai.utils.type_conversion import convert_to_dst_type +from monai.utils.type_conversion import convert_to_dst_type, convert_to_tensor label, _ = optional_import("scipy.ndimage.measurements", name="label") disk, _ = optional_import("skimage.morphology", name="disk") @@ -671,6 +671,7 @@ class HoVerNetInstanceMapPostProcessing(Transform): min_num_points: minimum number of points to be considered as a contour. Defaults to 3. contour_level: an optional value for `skimage.measure.find_contours` to find contours in the array. If not provided, the level is set to `(max(image) + min(image)) / 2`. + device: target device to put the output Tensor data. """ def __init__( @@ -686,9 +687,10 @@ def __init__( watershed_connectivity: int | None = 1, min_num_points: int = 3, contour_level: float | None = None, + device: str | torch.device | None = None, ) -> None: super().__init__() - + self.device = device self.generate_watershed_mask = GenerateWatershedMask( activation=activation, threshold=mask_threshold, min_object_size=min_object_size ) @@ -742,7 +744,7 @@ def __call__( # type: ignore "centroid": instance_centroid, "contour": instance_contour, } - + instance_map = convert_to_tensor(instance_map, device=self.device) return instance_info, instance_map @@ -758,13 +760,19 @@ class HoVerNetNuclearTypePostProcessing(Transform): threshold: an optional float value to threshold to binarize probability map. If not provided, defaults to 0.5 when activation is not "softmax", otherwise None. return_type_map: whether to calculate and return pixel-level type map. + device: target device to put the output Tensor data. """ def __init__( - self, activation: str | Callable = "softmax", threshold: float | None = None, return_type_map: bool = True + self, + activation: str | Callable = "softmax", + threshold: float | None = None, + return_type_map: bool = True, + device: str | torch.device | None = None, ) -> None: super().__init__() + self.device = device self.return_type_map = return_type_map self.generate_instance_type = GenerateInstanceType() @@ -824,5 +832,6 @@ def __call__( # type: ignore # update instance type map if type_map is not None: type_map[instance_map == inst_id] = instance_type + type_map = convert_to_tensor(type_map, device=self.device) return instance_info, type_map diff --git a/monai/apps/pathology/transforms/post/dictionary.py b/monai/apps/pathology/transforms/post/dictionary.py index ef6de1b596..a95bdfd48f 100644 --- a/monai/apps/pathology/transforms/post/dictionary.py +++ b/monai/apps/pathology/transforms/post/dictionary.py @@ -14,6 +14,7 @@ from collections.abc import Callable, Hashable, Mapping import numpy as np +import torch from monai.apps.pathology.transforms.post.array import ( GenerateDistanceMap, @@ -488,6 +489,7 @@ class HoVerNetInstanceMapPostProcessingd(Transform): min_num_points: minimum number of points to be considered as a contour. Defaults to 3. contour_level: an optional value for `skimage.measure.find_contours` to find contours in the array. If not provided, the level is set to `(max(image) + min(image)) / 2`. + device: target device to put the output Tensor data. """ def __init__( @@ -507,6 +509,7 @@ def __init__( watershed_connectivity: int | None = 1, min_num_points: int = 3, contour_level: float | None = None, + device: str | torch.device | None = None, ) -> None: super().__init__() self.instance_map_post_process = HoVerNetInstanceMapPostProcessing( @@ -521,6 +524,7 @@ def __init__( watershed_connectivity=watershed_connectivity, min_num_points=min_num_points, contour_level=contour_level, + device=device, ) self.nuclear_prediction_key = nuclear_prediction_key self.hover_map_key = hover_map_key @@ -553,7 +557,7 @@ class HoVerNetNuclearTypePostProcessingd(Transform): Defaults to `"instance_info"`. instance_map_key: the key where instance map is stored. Defaults to `"instance_map"`. type_map_key: the output key where type map is written. Defaults to `"type_map"`. - + device: target device to put the output Tensor data. """ @@ -566,10 +570,11 @@ def __init__( activation: str | Callable = "softmax", threshold: float | None = None, return_type_map: bool = True, + device: str | torch.device | None = None, ) -> None: super().__init__() self.type_post_process = HoVerNetNuclearTypePostProcessing( - activation=activation, threshold=threshold, return_type_map=return_type_map + activation=activation, threshold=threshold, return_type_map=return_type_map, device=device ) self.type_prediction_key = type_prediction_key self.instance_info_key = instance_info_key