WIP: initial prototype of zarr_extract function
blowekamp committed Jun 5, 2023
1 parent 92205d5 commit a9037a1
Showing 2 changed files with 124 additions and 3 deletions.
4 changes: 2 additions & 2 deletions pytools/__init__.py
@@ -12,7 +12,7 @@
# limitations under the License.
#

from .workflow_functions import visual_min_max
from .workflow_functions import visual_min_max, zarr_extract

import logging

@@ -27,4 +27,4 @@
pass


__all__ = ["__version__", "visual_min_max", "logger"]
__all__ = ["__version__", "visual_min_max", "zarr_extract", "logger"]
123 changes: 122 additions & 1 deletion pytools/workflow_functions.py
@@ -13,16 +13,19 @@
#

from pathlib import Path
from typing import Dict, Union
from typing import Dict, Union, Sequence
import logging
import math
from SimpleITK.utilities.dask import from_sitk
from pytools.utils.histogram import (
    DaskHistogramHelper,
    ZARRHistogramHelper,
    histogram_robust_stats,
    weighted_quantile,
)
import SimpleITK as sitk
import dask.array

import zarr


logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)


def visual_min_max(
@@ -93,3 +96,121 @@ def visual_min_max(
}

return output


def zarr_extract(
input_zarr: Union[Path, str],
target_size: Sequence[int],
size_factor: float = 1.5,
output_filename: Union[Path, str, None] = None,
) -> Union[sitk.Image, None]:
"""Extracts a volume from an OME-NGFF pyramid structured ZARR array.
    :param input_zarr: The path to an OME-NGFF structured ZARR array.
    :param target_size: The size of the subvolume to extract, in [Z, Y, X] order.
    :param size_factor: The requested size is enlarged by this factor when selecting a pyramid level, so that
        the extracted subvolume can be antialiased before it is resized to target_size.
    :param output_filename: If not None, the extracted subvolume is also written to this file.
    :return: The extracted subvolume as a SimpleITK image.
    """

input_zarr = Path(input_zarr)

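    # Open the Zarr root group; the OME-NGFF multiscales metadata is stored in its attributes.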
store = zarr.DirectoryStore(input_zarr)
group = zarr.group(store=store)
logger.debug(group.info)

if "multiscales" not in group.attrs:
        raise ValueError(f"Missing OME-NGFF multiscales metadata in zarr group: {input_zarr}")

for image_meta in group.attrs["multiscales"]:
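        # Each OME-NGFF axis entry carries a "name" (e.g. t, c, z, y, x) and a "type" (time, channel or space).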
axes = image_meta["axes"]

        spatial_dim = [d for d, ax in enumerate(axes) if ax["type"].lower() == "space"]

        if len(target_size) != len(spatial_dim):
            raise ValueError(f"target_size: {target_size} does not match the number of spatial dimensions: {spatial_dim}")

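        # Map target_size onto the full axes order, using 0 for non-spatial axes so they are
        # ignored in the pyramid level search below.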
        _target_size = list(target_size)
        _target_size = [_target_size.pop(0) if ax["type"].lower() == "space" else 0 for ax in axes]

logger.debug(f"_target_size: {_target_size}")

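        # Record the shape of every pyramid level, keyed by its path within the group.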
path_to_size = {}
for dataset in image_meta["datasets"]:
level_path = dataset["path"]

arr = zarr.Array(store, read_only=True, path=level_path)
path_to_size[level_path] = arr.shape

logger.debug(f"path: {level_path} shape: {arr.shape}")

# sort by total number of pixels, smallest first
path_to_size = dict(sorted(path_to_size.items(), key=lambda item: math.prod(item[1])))
logger.debug(f"path_to_size: {path_to_size}")

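        # The largest size per dimension across all levels is the full-resolution extent;
        # the target size is clamped to it so it cannot exceed what the data provides.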
max_size_per_dim = next(iter(path_to_size.values()))
        for size in path_to_size.values():
            max_size_per_dim = [max(s, m) for s, m in zip(size, max_size_per_dim)]

_target_size = [min(t, m) for t, m in zip(_target_size, max_size_per_dim)]

logger.debug(f"max_size_per_dim: {max_size_per_dim}")

        # Search the pyramid, smallest level first, for the first level larger than size_factor times the target size.
for path, size in path_to_size.items():
logger.debug(f"size: {size} _target_size: {_target_size} max_size_per_dim: {max_size_per_dim}")

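            # Only spatial dimensions with a positive target (t > 0) that are below full resolution
            # (t != m) are compared; the level is selected once it exceeds size_factor * target in all of them.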
if all([s > size_factor * t for s, t, m in zip(size, _target_size, max_size_per_dim) if t > 0 and t != m]):
break
logger.debug(f"selected path: {path} size: {size}")

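        # Re-open the selected level and normalize to native byte order so the array
        # converts cleanly to a SimpleITK image.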
arr = zarr.Array(store, read_only=True, path=path)
arr = arr.astype(arr.dtype.newbyteorder("="))
logger.debug(arr.info)

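        # Build an index tuple: keep channels and the Y/X extents whole, and take the
        # middle slice along the Z and time axes.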
channel_dim = None
idx = []
for d, ax in enumerate(axes):
ax_type = ax["type"].lower()
ax_name = ax["name"].lower()
if ax_type == "channel":
channel_dim = d
idx.append(slice(None))
elif ax_type == "space":
if ax_name == "z":
idx.append((arr.shape[d] - 1) // 2)
else:
idx.append(slice(None))
elif ax_type == "time":
idx.append((arr.shape[d] - 1) // 2)

logger.debug(f"channel_dim: {channel_dim} idx: {idx}")

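        # With a channel axis, extract each channel separately and compose a multi-component
        # image; otherwise convert the indexed array directly.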
img_list = []
if channel_dim is not None:
for d in range(arr.shape[channel_dim]):
idx[channel_dim] = d
img_list.append(sitk.GetImageFromArray(arr[tuple(idx)]))

            # This should be able to be replaced with a PermuteAxesImageFilter and a ToVector, when available in SimpleITK
img = sitk.Compose(img_list)
elif len(axes) < 4:
img = sitk.GetImageFromArray(arr[tuple(idx)])
else:
raise ValueError(f"Unsupported axes types: {[ax['type'] for ax in axes]}")

logger.debug(img)

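        # The size_factor margin ensures the selected level is large enough for resize to
        # antialias while downsampling to the target size.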
        logger.debug(f"resizing image from {img.GetSize()} to {target_size[::-1]}")
img = sitk.utilities.resize(img, target_size[::-1], interpolator=sitk.sitkLinear)

if output_filename is not None:
output_filename = Path(output_filename)
logger.info(f"Writing image to: {output_filename}")
sitk.WriteImage(img, str(output_filename))

return img
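
A minimal usage sketch of the new function (the input path, target size, and output filename below are hypothetical):

    from pytools import zarr_extract

    # Extract roughly a 64 x 512 x 512 (Z, Y, X) subvolume from an OME-NGFF pyramid
    # and also write it to disk; the return value is a SimpleITK image.
    img = zarr_extract("sample.zarr", target_size=[64, 512, 512], output_filename="extract.mha")
    print(img.GetSize())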
