# Reconstructed from a mangled `git format-patch` email (commit a9037a1b,
# "WIP: initial prototype of zarr_extract function").  The original patch text
# had lost its line structure; below is the extracted Python payload for
# pytools/workflow_functions.py, cleaned up and documented.

from pathlib import Path
from typing import Dict, Sequence, Union
import logging
import math

import SimpleITK as sitk
import SimpleITK.utilities
import zarr

logger = logging.getLogger(__name__)


def zarr_extract(
    input_zarr: Union[Path, str],
    target_size: Sequence[int],
    size_factor: float = 1.5,
    output_filename: Union[Path, str, None] = None,
) -> Union[sitk.Image, None]:
    """Extracts a volume from an OME-NGFF pyramid structured ZARR array.

    The pyramid levels are searched smallest-first for the first level whose
    spatial extent exceeds ``size_factor * target_size`` in every relevant
    dimension, so that the extracted subvolume can have antialiasing applied
    to it before the final resize down to ``target_size``.

    :param input_zarr: The path to an OME-NGFF structured ZARR array.
    :param target_size: The size of the subvolume to extract. [Z, Y, X]
    :param size_factor: The size of the subvolume to extract will be increased
        by this factor so that the extracted subvolume can have antialiasing
        applied to it.
    :param output_filename: If not None then the extracted subvolume will be
        written to this file.
    :return: The extracted subvolume as a SimpleITK image.
    :raises ValueError: If the zarr group lacks OME-NGFF "multiscales"
        metadata, if ``target_size`` does not match the number of spatial
        dimensions, or if the axes layout is unsupported.
    """

    input_zarr = Path(input_zarr)

    store = zarr.DirectoryStore(input_zarr)
    group = zarr.group(store=store)
    logger.debug(group.info)

    if "multiscales" not in group.attrs:
        raise ValueError(f"Missing OME-NGFF multiscales meta data in zarr group: {input_zarr}")

    # NOTE(review): the WIP prototype looped over every "multiscales" entry
    # but only the last iteration's bindings survived; OME-NGFF files normally
    # contain exactly one entry, so the first is used here — confirm.
    image_meta = group.attrs["multiscales"][0]
    axes = image_meta["axes"]

    # Indices of the spatial ("space") axes, in on-disk axis order.
    spacial_dim = [d for d, ax in enumerate(axes) if ax["type"].lower() == "space"]

    if len(target_size) != len(spacial_dim):
        raise ValueError(f"target_size: {target_size} does not match the number of spacial dimensions: {spacial_dim}")

    # Expand target_size to the full axis order, inserting 0 for non-spatial
    # axes.  list() (rather than Sequence.copy()) also accepts tuples.
    remaining = list(target_size)
    _target_size = [remaining.pop(0) if ax["type"].lower() == "space" else 0 for ax in axes]

    logger.debug(f"_target_size: {_target_size}")

    # Map each pyramid level path to its array shape.
    path_to_size: Dict[str, tuple] = {}
    for dataset in image_meta["datasets"]:
        level_path = dataset["path"]

        arr = zarr.Array(store, read_only=True, path=level_path)
        path_to_size[level_path] = arr.shape

        logger.debug(f"path: {level_path} shape: {arr.shape}")

    # sort by total number of pixels, smallest first
    path_to_size = dict(sorted(path_to_size.items(), key=lambda item: math.prod(item[1])))
    logger.debug(f"path_to_size: {path_to_size}")

    # Per-dimension maximum size over all pyramid levels.
    max_size_per_dim = next(iter(path_to_size.values()))
    for size in path_to_size.values():
        max_size_per_dim = [max(s, m) for s, m in zip(size, max_size_per_dim)]

    # The target cannot exceed the largest available size in any dimension.
    _target_size = [min(t, m) for t, m in zip(_target_size, max_size_per_dim)]

    logger.debug(f"max_size_per_dim: {max_size_per_dim}")

    # Search the pyramid, smallest level first, for the first level that is
    # larger than size_factor * target in every dimension that is neither
    # non-spatial (t == 0) nor already capped at the pyramid maximum
    # (t == m).  If no level qualifies, the loop falls through with `path`
    # left at the largest level.
    for path, size in path_to_size.items():
        logger.debug(f"size: {size} _target_size: {_target_size} max_size_per_dim: {max_size_per_dim}")

        if all(
            s > size_factor * t
            for s, t, m in zip(size, _target_size, max_size_per_dim)
            if t > 0 and t != m
        ):
            break
    logger.debug(f"selected path: {path} size: {size}")

    arr = zarr.Array(store, read_only=True, path=path)
    # Force native byte order so GetImageFromArray below accepts the buffer.
    arr = arr.astype(arr.dtype.newbyteorder("="))
    logger.debug(arr.info)

    # Build an index that collapses non-spatial axes: the middle time point,
    # the middle Z slice, and full extents for channels and in-plane axes.
    channel_dim = None
    idx = []
    for d, ax in enumerate(axes):
        ax_type = ax["type"].lower()
        ax_name = ax["name"].lower()
        if ax_type == "channel":
            channel_dim = d
            idx.append(slice(None))
        elif ax_type == "space":
            if ax_name == "z":
                # NOTE(review): only the middle Z slice is taken, producing a
                # 2-D image — confirm against the [Z, Y, X] target_size
                # contract before this leaves WIP.
                idx.append((arr.shape[d] - 1) // 2)
            else:
                idx.append(slice(None))
        elif ax_type == "time":
            idx.append((arr.shape[d] - 1) // 2)

    logger.debug(f"channel_dim: {channel_dim} idx: {idx}")

    if channel_dim is not None:
        img_list = []
        for c in range(arr.shape[channel_dim]):
            idx[channel_dim] = c
            img_list.append(sitk.GetImageFromArray(arr[tuple(idx)]))

        # This should be able to be replaced with a PermuteAxesImageFilter
        # and a ToVector, when available in SimpleITK.
        img = sitk.Compose(img_list)
    elif len(axes) < 4:
        img = sitk.GetImageFromArray(arr[tuple(idx)])
    else:
        raise ValueError(f"Unsupported axes types: {[ax['type'] for ax in axes]}")

    logger.debug(img)

    # SimpleITK sizes are ordered [X, Y, Z]; target_size is [Z, Y, X].
    logger.debug(f"resizing image of: {img.GetSize() } -> {target_size[::-1]}")
    img = sitk.utilities.resize(img, target_size[::-1], interpolator=sitk.sitkLinear)

    if output_filename is not None:
        output_filename = Path(output_filename)
        logger.info(f"Writing image to: {output_filename}")
        sitk.WriteImage(img, str(output_filename))

    return img