Skip to content

Commit

Permalink
Refactor and API update
Browse files Browse the repository at this point in the history
  • Loading branch information
sameeul committed Oct 11, 2024
1 parent d91ae24 commit 8dd64ac
Showing 1 changed file with 59 additions and 60 deletions.
119 changes: 59 additions & 60 deletions src/python/argolid/pyramid_compositor.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,28 +95,28 @@ def get_zarr_write_spec(

class PyramidCompositor:
"""
A class for compositing well images into a pyramid structure.
A class for composing a group of pyramid images into an assembled pyramid structure.
"""

def __init__(
self, well_pyramid_loc: str, out_dir: str, pyramid_file_name: str
self, input_pyramids_loc: str, out_dir: str, output_pyramid_name: str
) -> None:
"""
Initializes the PyramidCompositor object.
Args:
well_pyramid_loc (str): The location of the well pyramids.
input_pyramids_loc (str): The location of the input pyramid images.
out_dir (str): The output directory for the composed zarr pyramid file.
pyramid_file_name (str): The name of the zarr pyramid file.
output_pyramid_name (str): The name of the zarr pyramid file.
"""
self._well_pyramid_loc: str = well_pyramid_loc
self._tile_cache: set = set()
self._pyramid_file_name: str = f"{out_dir}/{pyramid_file_name}"
self._ome_metadata_file: str = f"{out_dir}/{pyramid_file_name}/METADATA.ome.xml"
self._well_map: dict = None
self._input_pyramids_loc: str = input_pyramids_loc
self._chunk_cache: set = set()
self._output_pyramid_name: str = f"{out_dir}/{output_pyramid_name}"
self._ome_metadata_file: str = f"{out_dir}/{output_pyramid_name}/METADATA.ome.xml"
self._composition_map: dict = None
self._plate_image_shapes: dict = {}
self._zarr_arrays: dict = {}
self._well_image_shapes: dict = {}
self._unit_image_shapes: dict = {}
self._pyramid_levels: int = None
self._image_dtype: np.dtype = None
self._num_channels: int = None
Expand Down Expand Up @@ -163,12 +163,12 @@ def _create_zattr_file(self) -> None:
multiscale_metadata.append({"path": str(key)})
attr_dict["datasets"] = multiscale_metadata
attr_dict["version"] = "0.1"
attr_dict["name"] = self._pyramid_file_name
attr_dict["name"] = self._output_pyramid_name
attr_dict["metadata"] = {"method": "mean"}

final_attr_dict = {"multiscales": [attr_dict]}

with open(f"{self._pyramid_file_name}/data.zarr/0/.zattrs", "w") as json_file:
with open(f"{self._output_pyramid_name}/data.zarr/0/.zattrs", "w") as json_file:
json.dump(final_attr_dict, json_file)

def _create_zgroup_file(self) -> None:
Expand All @@ -177,10 +177,10 @@ def _create_zgroup_file(self) -> None:
"""
zgroup_dict = {"zarr_format": 2}

with open(f"{self._pyramid_file_name}/data.zarr/0/.zgroup", "w") as json_file:
with open(f"{self._output_pyramid_name}/data.zarr/0/.zgroup", "w") as json_file:
json.dump(zgroup_dict, json_file)

with open(f"{self._pyramid_file_name}/data.zarr/.zgroup", "w") as json_file:
with open(f"{self._output_pyramid_name}/data.zarr/.zgroup", "w") as json_file:
json.dump(zgroup_dict, json_file)

def _create_auxilary_files(self) -> None:
Expand All @@ -191,7 +191,7 @@ def _create_auxilary_files(self) -> None:
self._create_zattr_file()
self._create_zgroup_file()

def _write_tile_data(
def _write_zarr_chunk(
self, level: int, channel: int, y_index: int, x_index: int
) -> None:
"""
Expand Down Expand Up @@ -220,31 +220,31 @@ def _write_tile_data(
(assembled_height, assembled_width), dtype=self._image_dtype
)

# find what well images are needed
# find what input images are needed

# are we at the begining of the well image?
well_image_height = self._well_image_shapes[level][0]
well_image_width = self._well_image_shapes[level][1]
# are we at the beginning of the input image?
unit_image_height = self._unit_image_shapes[level][0]
unit_image_width = self._unit_image_shapes[level][1]

row_start_pos = y_range[0]
while row_start_pos < y_range[1]:
# row and col are well map coordinates
row = row_start_pos // well_image_height
# row and col are unit map coordinates
row = row_start_pos // unit_image_height
# local_y* and local_x* are coordinates in the chunk
local_y_start = row_start_pos - y_range[0]
tile_y_start = row_start_pos - row * well_image_height
tile_y_dim = min((row + 1) * well_image_height - row_start_pos, y_range[1] - row_start_pos)
tile_y_start = row_start_pos - row * unit_image_height
tile_y_dim = min((row + 1) * unit_image_height - row_start_pos, y_range[1] - row_start_pos)
tile_y_end = tile_y_start + tile_y_dim
col_start_pos = x_range[0]
while col_start_pos < x_range[1]:
col = col_start_pos // well_image_width
col = col_start_pos // unit_image_width
local_x_start = col_start_pos - x_range[0]
tile_x_start = col_start_pos - col * well_image_width
tile_x_dim = min((col + 1) * well_image_width - col_start_pos, x_range[1] - col_start_pos)
tile_x_start = col_start_pos - col * unit_image_width
tile_x_dim = min((col + 1) * unit_image_width - col_start_pos, x_range[1] - col_start_pos)
tile_x_end = tile_x_start + tile_x_dim
# read well zarr file
well_file_name = self._well_map.get((col, row, channel))
zarr_file_loc = Path(well_file_name) / "data.zarr/0/"
# read input zarr file
input_file_name = self._composition_map.get((col, row, channel))
zarr_file_loc = Path(input_file_name) / "data.zarr/0/"
zarr_array_loc = zarr_file_loc / str(level)
zarr_file = ts.open(get_zarr_read_spec(str(zarr_array_loc))).result()

Expand All @@ -262,22 +262,21 @@ def _write_tile_data(
row_start_pos += tile_y_end - tile_y_start

zarr_array = self._zarr_arrays[level]
print(assembled_image.sum())
zarr_array[
0, channel, 0, y_range[0] : y_range[1], x_range[0] : x_range[1]
].write(assembled_image).result()

def set_well_map(self, well_map: dict) -> None:
def set_composition(self, composition_map: dict) -> None:
"""
Sets the well map for the pyramid.
Sets the composition for the pyramid.
Args:
well_map (dict): A dictionary mapping well coordinates to file paths.
composition_map (dict): A dictionary mapping (col, row, channel) coordinates to input pyramid file paths.
"""
self._well_map = well_map
self._well_image_shapes = {}
for coord in well_map:
file = well_map[coord]
self._composition_map = composition_map
self._unit_image_shapes = {}
for coord in composition_map:
file = composition_map[coord]
zarr_file_loc = Path(file) / "data.zarr/0/"
attr_file_loc = Path(file) / "data.zarr/0/.zattrs"
if attr_file_loc.exists():
Expand All @@ -291,7 +290,7 @@ def set_well_map(self, well_map: dict) -> None:
zarr_file = ts.open(
get_zarr_read_spec(str(zarr_array_loc))
).result()
self._well_image_shapes[int(res_key)] = (
self._unit_image_shapes[int(res_key)] = (
zarr_file.shape[-2],
zarr_file.shape[-1],
)
Expand All @@ -304,7 +303,7 @@ def set_well_map(self, well_map: dict) -> None:
num_cols = 0
num_channels = 0

for coord in well_map:
for coord in composition_map:
num_rows = max(num_rows, coord[1])
num_cols = max(num_cols, coord[0])
num_channels = max(num_channels, coord[2])
Expand All @@ -317,29 +316,29 @@ def set_well_map(self, well_map: dict) -> None:

self._plate_image_shapes = {}
self._zarr_arrays = {}
self._tile_cache = set()
for l in self._well_image_shapes:
self._chunk_cache = set()
for l in self._unit_image_shapes:
level = int(l)
self._plate_image_shapes[level] = (
1,
num_channels,
1,
num_rows * self._well_image_shapes[level][0],
num_cols * self._well_image_shapes[level][1],
num_rows * self._unit_image_shapes[level][0],
num_cols * self._unit_image_shapes[level][1],
)
num_row_tiles = math.ceil(
1.0 * num_rows * self._well_image_shapes[level][0] / CHUNK_SIZE
1.0 * num_rows * self._unit_image_shapes[level][0] / CHUNK_SIZE
)
num_col_tiles = math.ceil(
1.0 * num_cols * self._well_image_shapes[level][1] / CHUNK_SIZE
1.0 * num_cols * self._unit_image_shapes[level][1] / CHUNK_SIZE
)
# Guarantee at least one chunk along each axis, even when the computed
# tile count rounds down to zero (e.g. a zero-sized level).
if num_row_tiles == 0:
    num_row_tiles = 1
if num_col_tiles == 0:
    # Was `num_col_tiles == 1`: a comparison used as a statement, which
    # silently did nothing. Assign, mirroring the row case above.
    num_col_tiles = 1
self._zarr_arrays[level] = ts.open(
get_zarr_write_spec(
f"{self._pyramid_file_name}/data.zarr/0/{level}",
f"{self._output_pyramid_name}/data.zarr/0/{level}",
CHUNK_SIZE,
self._plate_image_shapes[level],
np.dtype(self._image_dtype).str,
Expand All @@ -352,21 +351,21 @@ def reset_composition(self) -> None:
"""
Resets the pyramid composition by removing the pyramid file and clearing internal data structures.
"""
shutil.rmtree(self._pyramid_file_name)
self._well_map = None
shutil.rmtree(self._output_pyramid_name)
self._composition_map = None
self._plate_image_shapes = None
self._tile_cache = None
self._chunk_cache = None
self._plate_image_shapes = {}
self._zarr_arrays = {}

def get_tile_data(
def get_zarr_chunk(
self, level: int, channel: int, y_index: int, x_index: int
) -> None:
"""
Retrieves tile data from the pyramid at the specified level, channel, y_index, and x_index.
Retrieves zarr chunk data from the pyramid at the specified level, channel, y_index, and x_index.
This method checks if the requested tile data is already in the cache. If not, it calls
the `_write_tile_data` method to generate the tile data and add it to the cache.
This method checks if the requested zarr chunk is already in the cache. If not, it calls
the `_write_zarr_chunk` method to generate the chunk and add it to the cache.
Args:
level (int): The level of the pyramid.
Expand All @@ -375,14 +374,14 @@ def get_tile_data(
x_index (int): The x-index of the tile.
Raises:
ValueError: If the well map is not set, the requested level does not exist,
ValueError: If the composition map is not set, the requested level does not exist,
the requested channel does not exist, or the requested y_index or x_index
is out of bounds.
"""
if self._well_map is None:
raise ValueError("No well map is set. Unable to generate pyramid")
if self._composition_map is None:
raise ValueError("No composition map is set. Unable to generate pyramid")

if level not in self._well_image_shapes:
if level not in self._unit_image_shapes:
raise ValueError(f"Requested level ({level}) does not exist")

if channel >= self._num_channels:
Expand All @@ -394,9 +393,9 @@ def get_tile_data(
# Reject x chunk indices beyond the assembled image width (axis 4 of the
# 5-D plate shape) at this level.
if x_index > (self._plate_image_shapes[level][4] // CHUNK_SIZE):
    # The message previously said "y index" while reporting x_index.
    raise ValueError(f"Requested x index ({x_index}) does not exist")

if (level, channel, y_index, x_index) in self._tile_cache:
if (level, channel, y_index, x_index) in self._chunk_cache:
return
else:
self._write_tile_data(level, channel, y_index, x_index)
self._tile_cache.add((level, channel, y_index, x_index))
self._write_zarr_chunk(level, channel, y_index, x_index)
self._chunk_cache.add((level, channel, y_index, x_index))
return

0 comments on commit 8dd64ac

Please sign in to comment.