Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Workaround for lack of zsd support in czifile #1142

Merged
merged 9 commits into from
Jul 15, 2024
74 changes: 72 additions & 2 deletions package/PartSegImage/image_reader.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
import inspect
import os.path
import typing
from abc import abstractmethod
from contextlib import suppress
from importlib.metadata import version
from io import BytesIO
from pathlib import Path
from threading import Lock

import imagecodecs
import numpy as np
import tifffile
from czifile.czifile import CziFile
from czifile.czifile import DECOMPRESS, CziFile
from defusedxml import ElementTree
from oiffile import OifFile
from packaging.version import parse as parse_version

from PartSegImage.image import Image

Expand All @@ -21,6 +25,72 @@
from xml.etree.ElementTree import Element # nosec


class ZSTD1Header(typing.NamedTuple):
"""
ZSTD1 header structure
based on:
https://github.com/ZEISS/libczi/blob/4a60e22200cbf0c8ff2a59f69a81ef1b2b89bf4f/Src/libCZI/decoder_zstd.cpp#L19
"""

header_size: int
hiLoByteUnpackPreprocessing: bool


def parse_zstd1_header(data, size):
"""
Parse ZSTD header

https://github.com/ZEISS/libczi/blob/4a60e22200cbf0c8ff2a59f69a81ef1b2b89bf4f/Src/libCZI/decoder_zstd.cpp#L84
"""
if size < 1:
return ZSTD1Header(0, False)

if data[0] == 1:
return ZSTD1Header(1, False)

if data[0] == 3 and size < 3:
return ZSTD1Header(0, False)

if data[1] == 1:
return ZSTD1Header(3, bool(data[2] & 1))

return ZSTD1Header(0, False)


def _get_dtype():
return inspect.currentframe().f_back.f_back.f_locals["de"].dtype


def decode_zstd1(data):
"""
Decode ZSTD1 data
"""
header = parse_zstd1_header(data, len(data))
dtype = _get_dtype()
if header.hiLoByteUnpackPreprocessing:
array_ = np.fromstring(imagecodecs.zstd_decode(data[header.header_size :]), np.uint8)
half_size = array_.size // 2
array = np.empty(half_size, np.uint16)
array[:] = array_[:half_size] + (array_[half_size:].astype(np.uint16) << 8)
array = array.view(dtype)
else:
array = np.fromstring(imagecodecs.zstd_decode(data[header.header_size :]), dtype)
return array
Czaki marked this conversation as resolved.
Show resolved Hide resolved


def decode_zstd0(data):
"""
Decode ZSTD0 data
"""
dtype = _get_dtype()
return np.fromstring(imagecodecs.zstd_decode(data), dtype)
Czaki marked this conversation as resolved.
Show resolved Hide resolved


if parse_version(version("czifile")) == parse_version("2019.7.2"):
DECOMPRESS[5] = decode_zstd0
DECOMPRESS[6] = decode_zstd1


def _empty(_, __):
"""Empty function for callback"""

Expand Down Expand Up @@ -270,7 +340,7 @@ class CziImageReader(BaseImageReaderBuffer):

def read(self, image_path: typing.Union[str, BytesIO, Path], mask_path=None, ext=None) -> Image:
image_file = CziFile(image_path)
image_data = image_file.asarray()
image_data = image_file.asarray(max_workers=1)
image_data = self.update_array_shape(image_data, image_file.axes)
metadata = image_file.metadata(False)
with suppress(KeyError):
Expand Down
9 changes: 9 additions & 0 deletions package/tests/test_PartSegImage/test_image_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,15 @@ def test_czi_file_read(self, data_test_dir):

assert np.all(np.isclose(image.spacing, (7.752248561753867e-08,) * 2))

def test_czi_file_read_compressed(self, data_test_dir):
image = CziImageReader.read_image(os.path.join(data_test_dir, "test_czi_compressed.czi"))
assert image.channels == 4
assert image.layers == 1

assert image.file_path == os.path.join(data_test_dir, "test_czi_compressed.czi")

assert np.all(np.isclose(image.spacing, (7.752248561753867e-08,) * 2))

def test_czi_file_read_buffer(self, data_test_dir):
with open(os.path.join(data_test_dir, "test_czi.czi"), "rb") as f_p:
buffer = BytesIO(f_p.read())
Expand Down
Loading