Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Fix get_ranges for empty channels in Image #1136

Merged
merged 6 commits into from
Jul 10, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 48 additions & 43 deletions package/PartSegImage/image.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from __future__ import annotations

import re
import typing
import warnings
from collections.abc import Iterable
from contextlib import suppress
from typing import Union

import numpy as np

Expand Down Expand Up @@ -33,8 +34,8 @@ def minimal_dtype(val: int):

def reduce_array(
array: np.ndarray,
components: typing.Optional[typing.Collection[int]] = None,
max_val: typing.Optional[int] = None,
components: typing.Collection[int] | None = None,
max_val: int | None = None,
dtype=None,
) -> np.ndarray:
"""
Expand Down Expand Up @@ -102,12 +103,12 @@ def __init__(
data: _IMAGE_DATA,
image_spacing: Spacing,
file_path=None,
mask: typing.Union[None, np.ndarray] = None,
mask: None | np.ndarray = None,
default_coloring=None,
ranges=None,
channel_names=None,
axes_order: typing.Optional[str] = None,
shift: typing.Optional[Spacing] = None,
axes_order: str | None = None,
shift: Spacing | None = None,
name: str = "",
):
# TODO add time distance to image spacing
Expand Down Expand Up @@ -144,18 +145,26 @@ def __init__(
self.default_coloring = [np.array(x) for x in default_coloring]

self._channel_names = self._prepare_channel_names(channel_names, self.channels)

self.ranges = self._adjust_ranges(ranges, self._channel_arrays)
self._mask_array = self._fit_mask(mask, data, axes_order)

@staticmethod
def _adjust_ranges(
ranges: list[tuple[float, float]] | None, channel_arrays: list[np.ndarray]
) -> list[tuple[float, float]]:
if ranges is None:
self.ranges = list(
zip((np.min(c) for c in self._channel_arrays), (np.max(c) for c in self._channel_arrays))
)
else:
self.ranges = ranges
self._mask_array = self._prepare_mask(mask, data, axes_order)
if self._mask_array is not None:
self._mask_array = self.fit_mask_to_image(self._mask_array)
ranges = list(zip((np.min(c) for c in channel_arrays), (np.max(c) for c in channel_arrays)))
return [(min_val, max_val) if (min_val != max_val) else (min_val, min_val + 1) for (min_val, max_val) in ranges]

def _fit_mask(self, mask, data, axes_order):
mask_array = self._prepare_mask(mask, data, axes_order)
if mask_array is not None:
mask_array = self.fit_mask_to_image(mask_array)
return mask_array

@classmethod
def _prepare_mask(cls, mask, data, axes_order) -> typing.Optional[np.ndarray]:
def _prepare_mask(cls, mask, data, axes_order) -> np.ndarray | None:
if mask is None:
return None

Expand All @@ -170,7 +179,7 @@ def _prepare_mask(cls, mask, data, axes_order) -> typing.Optional[np.ndarray]:
return cls.reorder_axes(mask, axes_order.replace("C", ""))

@staticmethod
def _prepare_channel_names(channel_names, channels_num) -> typing.List[str]:
def _prepare_channel_names(channel_names, channels_num) -> list[str]:
default_channel_names = [f"channel {i + 1}" for i in range(channels_num)]
if isinstance(channel_names, str):
channel_names = [channel_names]
Expand All @@ -182,9 +191,7 @@ def _prepare_channel_names(channel_names, channels_num) -> typing.List[str]:
return channel_names_list[:channels_num]

@classmethod
def _split_data_on_channels(
cls, data: typing.Union[np.ndarray, typing.List[np.ndarray]], axes_order: str
) -> typing.List[np.ndarray]:
def _split_data_on_channels(cls, data: np.ndarray | list[np.ndarray], axes_order: str) -> list[np.ndarray]:
if isinstance(data, list) and not axes_order.startswith("C"): # pragma: no cover
raise ValueError("When passing data as list of numpy arrays then Channel must be first axis.")
if "C" not in axes_order:
Expand All @@ -199,7 +206,7 @@ def _split_data_on_channels(

if not isinstance(data, np.ndarray):
raise TypeError("If `data` is list of arrays then `axes_order` must start with `C`") # pragma: no cover
pos: typing.List[typing.Union[slice, int]] = [slice(None) for _ in range(data.ndim)]
pos: list[slice | int] = [slice(None) for _ in range(data.ndim)]
c_pos = axes_order.index("C")
res = []
for i in range(data.shape[c_pos]):
Expand All @@ -208,9 +215,7 @@ def _split_data_on_channels(
return res

@staticmethod
def _merge_channel_names(
base_channel_names: typing.List[str], new_channel_names: typing.List[str]
) -> typing.List[str]:
def _merge_channel_names(base_channel_names: list[str], new_channel_names: list[str]) -> list[str]:
base_channel_names = base_channel_names[:]
reg = re.compile(r"channel \d+")
for name in new_channel_names:
Expand All @@ -228,7 +233,7 @@ def _merge_channel_names(
base_channel_names.append(new_name)
return base_channel_names

def merge(self, image: "Image", axis: str) -> "Image":
def merge(self, image: Image, axis: str) -> Image:
"""
Produce new image merging image data along given axis. All metadata
are obtained from self.
Expand Down Expand Up @@ -256,7 +261,7 @@ def merge(self, image: "Image", axis: str) -> "Image":
return self.substitute(data=data, ranges=self.ranges + image.ranges, channel_names=channel_names)

@property
def channel_names(self) -> typing.List[str]:
def channel_names(self) -> list[str]:
return self._channel_names[:]

@property
Expand Down Expand Up @@ -333,7 +338,7 @@ def substitute(
default_coloring=None,
ranges=None,
channel_names=None,
) -> "Image":
) -> Image:
"""Create copy of image with substitution of not None elements"""
data = self._channel_arrays if data is None else data
image_spacing = self._image_spacing if image_spacing is None else image_spacing
Expand All @@ -353,7 +358,7 @@ def substitute(
axes_order=self.axis_order,
)

def set_mask(self, mask: typing.Optional[np.ndarray], axes: typing.Optional[str] = None):
def set_mask(self, mask: np.ndarray | None, axes: str | None = None):
"""
Set mask for image, check if it has proper shape.

Expand All @@ -374,7 +379,7 @@ def get_data(self) -> np.ndarray:
return self._channel_arrays[0]

@property
def mask(self) -> typing.Optional[np.ndarray]:
def mask(self) -> np.ndarray | None:
return self._mask_array[:] if self._mask_array is not None else None

@staticmethod
Expand Down Expand Up @@ -430,7 +435,7 @@ def get_image_for_save(self) -> np.ndarray:
)
return self._reorder_axes(self._channel_arrays[0], self.axis_order, "TZCYX")

def get_mask_for_save(self) -> typing.Optional[np.ndarray]:
def get_mask_for_save(self) -> np.ndarray | None:
"""
:return: if image has mask then return mask with axes in proper order
"""
Expand Down Expand Up @@ -469,7 +474,7 @@ def times(self) -> int:
return self._channel_arrays[0].shape[self.time_pos]

@property
def plane_shape(self) -> typing.Tuple[int, int]:
def plane_shape(self) -> tuple[int, int]:
"""y,x size of image"""
return self._channel_arrays[0].shape[self.y_pos], self._channel_arrays[0].shape[self.x_pos]

Expand All @@ -487,15 +492,15 @@ def swap_time_and_stack(self):
return self.substitute(data=self._image_data_normalize(image_array_list))

@classmethod
def get_axis_positions(cls) -> typing.Dict[str, int]:
def get_axis_positions(cls) -> dict[str, int]:
"""
:return: dict with mapping axis to its position
:rtype: dict
"""
return {letter: i for i, letter in enumerate(cls.axis_order)}

@classmethod
def get_array_axis_positions(cls) -> typing.Dict[str, int]:
def get_array_axis_positions(cls) -> dict[str, int]:
"""
:return: dict with mapping axis to its position for array fitted to image
:rtype: dict
Expand All @@ -510,7 +515,7 @@ def get_data_by_axis(self, **kwargs) -> np.ndarray:
:return:
:rtype:
"""
slices: typing.List[typing.Union[int, slice]] = [slice(None) for _ in range(len(self.array_axis_order))]
slices: list[int | slice] = [slice(None) for _ in range(len(self.array_axis_order))]
axis_pos = self.get_array_axis_positions()
if "c" in kwargs:
kwargs["C"] = kwargs.pop("c")
Expand All @@ -533,7 +538,7 @@ def get_data_by_axis(self, **kwargs) -> np.ndarray:
return self._channel_arrays[channel][slices_t]
return np.stack([x[slices_t] for x in self._channel_arrays[channel]], axis=axis_order.index("C"))

def clip_array(self, array: np.ndarray, **kwargs: typing.Union[int, slice]) -> np.ndarray:
def clip_array(self, array: np.ndarray, **kwargs: int | slice) -> np.ndarray:
"""
Clip array by axis. Axis is selected by single letter from :py:attr:`axis_order`

Expand All @@ -542,14 +547,14 @@ def clip_array(self, array: np.ndarray, **kwargs: typing.Union[int, slice]) -> n
:return: clipped array
"""
array = self.fit_array_to_image(array)
slices: typing.List[typing.Union[int, slice]] = [slice(None) for _ in range(len(self.array_axis_order))]
slices: list[int | slice] = [slice(None) for _ in range(len(self.array_axis_order))]
axis_pos = self.get_array_axis_positions()
for name in kwargs:
if (n := name.upper()) in axis_pos:
slices[axis_pos[n]] = kwargs[name]
return array[tuple(slices)]

def get_channel(self, num: Union[int, str, Channel]) -> np.ndarray:
def get_channel(self, num: int | str | Channel) -> np.ndarray:
"""
Alias for :py:func:`get_sub_data` with argument ``c=num``

Expand Down Expand Up @@ -611,7 +616,7 @@ def set_spacing(self, value: Spacing):
self._image_spacing = tuple(value)

@staticmethod
def _frame_array(array: typing.Optional[np.ndarray], index_to_add: typing.List[int], frame=FRAME_THICKNESS):
def _frame_array(array: np.ndarray | None, index_to_add: list[int], frame=FRAME_THICKNESS):
if array is None: # pragma: no cover
return array
result_shape = list(array.shape)
Expand All @@ -626,7 +631,7 @@ def _frame_array(array: typing.Optional[np.ndarray], index_to_add: typing.List[i
return data

@staticmethod
def calc_index_to_frame(array_axis: str, important_axis: str) -> typing.List[int]:
def calc_index_to_frame(array_axis: str, important_axis: str) -> list[int]:
"""
calculate in which axis frame should be added

Expand All @@ -650,15 +655,15 @@ def _frame_cut_area(self, cut_area: typing.Iterable[slice], frame: int):

def _cut_image_slices(
self, cut_area: typing.Iterable[slice], frame: int
) -> typing.Tuple[typing.List[np.ndarray], typing.Optional[np.ndarray]]:
) -> tuple[list[np.ndarray], np.ndarray | None]:
new_mask = None
cut_area = self._frame_cut_area(cut_area, frame)
new_image = [x[tuple(cut_area)] for x in self._channel_arrays]
if self._mask_array is not None:
new_mask = self._mask_array[tuple(cut_area)]
return new_image, new_mask

def _roi_to_slices(self, roi: np.ndarray) -> typing.List[slice]:
def _roi_to_slices(self, roi: np.ndarray) -> list[slice]:
cut_area = self.fit_array_to_image(roi)
points = np.nonzero(cut_area)
lower_bound = np.min(points, axis=1)
Expand Down Expand Up @@ -689,11 +694,11 @@ def _cut_with_roi(self, cut_area: np.ndarray, replace_mask: bool, frame: int):

def cut_image(
self,
cut_area: typing.Union[np.ndarray, typing.Iterable[slice]],
cut_area: np.ndarray | typing.Iterable[slice],
replace_mask=False,
frame: int = FRAME_THICKNESS,
zero_out_cut_area: bool = True,
) -> "Image":
) -> Image:
"""
Create new image base on mask or list of slices
:param bool replace_mask: if cut area is represented by mask array,
Expand Down Expand Up @@ -771,7 +776,7 @@ def get_um_shift(self) -> Spacing:
"""image spacing in micrometers"""
return tuple(float(x * 10**6) for x in self.shift)

def get_ranges(self) -> typing.List[typing.Tuple[float, float]]:
def get_ranges(self) -> list[tuple[float, float]]:
"""image brightness ranges for each channel"""
return self.ranges[:]

Expand Down
Loading