Skip to content
This repository has been archived by the owner on Oct 19, 2024. It is now read-only.

Commit

Permalink
Implement better tile aggregation (#98)
Browse files Browse the repository at this point in the history
  • Loading branch information
AjeyPaiK authored Jul 15, 2024
1 parent 36397a4 commit fc84395
Show file tree
Hide file tree
Showing 3 changed files with 345 additions and 19 deletions.
56 changes: 38 additions & 18 deletions ahcore/readers.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ class FileImageReader(abc.ABC):
def __init__(self, filename: Path, stitching_mode: StitchingMode) -> None:
self._filename = filename
self._stitching_mode = stitching_mode

self.__empty_tile: GenericNumberArray | None = None

self._file: Optional[Any] = None
Expand Down Expand Up @@ -185,8 +184,8 @@ def read_region(self, location: tuple[int, int], level: int, size: tuple[int, in
if self._file is None:
self._open_file()
assert self._file, "File is not open. Should not happen"
assert self._tile_size
assert self._tile_overlap
assert self._tile_size is not None, "self._tile_size should not be None"
assert self._tile_overlap is not None, "self._tile_overlap should not be None"

image_dataset = self._file["data"]
num_tiles = self._metadata["num_tiles"]
Expand All @@ -203,14 +202,19 @@ def read_region(self, location: tuple[int, int], level: int, size: tuple[int, in
logger.error(f"Requested region is out of bounds: {location}, {self._size}")
raise ValueError("Requested region is out of bounds")

if self._stitching_mode not in StitchingMode:
raise ValueError(f"Stitching mode {self._stitching_mode} is not supported.")

start_row = y // self._stride[1]
end_row = min((y + h - 1) // self._stride[1] + 1, total_rows)
start_col = x // self._stride[0]
end_col = min((x + w - 1) // self._stride[0] + 1, total_cols)

if self._stitching_mode == StitchingMode.AVERAGE:
divisor_array = np.zeros((h, w), dtype=np.uint8)
stitched_image = np.zeros((self._num_channels, h, w), dtype=self._dtype)

if self._stitching_mode == StitchingMode.AVERAGE:
average_mask = np.zeros((h, w), dtype=self._dtype)

for i in range(start_row, end_row):
for j in range(start_col, end_col):
tile_idx = (i * total_cols) + j
Expand Down Expand Up @@ -245,27 +249,43 @@ def read_region(self, location: tuple[int, int], level: int, size: tuple[int, in
stitched_image[:, img_start_y:img_end_y, img_start_x:img_end_x] = cropped_tile

elif self._stitching_mode == StitchingMode.AVERAGE:
raise NotImplementedError
tile_start_y = max(0, -start_y)
tile_end_y = img_end_y - img_start_y
tile_end_y = min(self._tile_size[1], h - start_y)
tile_start_x = max(0, -start_x)
tile_end_x = img_end_x - img_start_x
tile_end_x = min(self._tile_size[0], w - start_x)

# TODO: Replace this with crop_to_bbox
cropped_tile = tile[tile_start_y:tile_end_y, tile_start_x:tile_end_x]
stitched_image[img_start_y:img_end_y, img_start_x:img_end_x] += cropped_tile
divisor_array[img_start_y:img_end_y, img_start_x:img_end_x] += 1
else:
raise ValueError("Unsupported stitching mode")
average_mask[img_start_y:img_end_y, img_start_x:img_end_x] += 1
stitched_image[:, img_start_y:img_end_y, img_start_x:img_end_x] += tile[
:, tile_start_y:tile_end_y, tile_start_x:tile_end_x
]

if self._stitching_mode == StitchingMode.AVERAGE:
stitched_image = (stitched_image / divisor_array[..., np.newaxis]).astype(float)
elif self._stitching_mode == StitchingMode.MAXIMUM:
tile_start_y = max(0, -start_y)
tile_end_y = min(self._tile_size[1], h - start_y)
tile_start_x = max(0, -start_x)
tile_end_x = min(self._tile_size[0], w - start_x)

if self._precision != str(InferencePrecision.FP32):
# Always convert to float32.
if i == start_row and j == start_col:
# The first tile cannot be compared with anything. So, we just copy it.
stitched_image[:, img_start_y:img_end_y, img_start_x:img_end_x] = tile[
:, tile_start_y:tile_end_y, tile_start_x:tile_end_x
]

stitched_image[:, img_start_y:img_end_y, img_start_x:img_end_x] = np.maximum(
stitched_image[:, img_start_y:img_end_y, img_start_x:img_end_x],
tile[:, tile_start_y:tile_end_y, tile_start_x:tile_end_x],
)

# Adjust the precision and convert to float32 before averaging to avoid loss of precision.
if self._precision != str(InferencePrecision.UINT8) or self._stitching_mode == StitchingMode.AVERAGE:
stitched_image = stitched_image / self._multiplier
stitched_image = stitched_image.astype(np.float32)

if self._stitching_mode == StitchingMode.AVERAGE:
overlap_regions = average_mask > 0
# Perform division to average the accumulated pixel values
stitched_image[:, overlap_regions] = stitched_image[:, overlap_regions] / average_mask[overlap_regions]

return pyvips.Image.new_from_array(stitched_image.transpose(1, 2, 0))

@abc.abstractmethod
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
"hydra-submitit-launcher>=1.2.0",
"hydra-optuna-sweeper>=1.3.0.dev0",
"hydra-colorlog>=1.2.0",
"dlup>=0.5.1",
"dlup>=0.5.3",
"kornia>=0.7.2",
"h5py>=3.8.0",
"monai[einops]>=1.3.0",
Expand Down
Loading

0 comments on commit fc84395

Please sign in to comment.