Skip to content

Commit

Permalink
Merge pull request #51 from brainglobe/add_save_any
Browse files Browse the repository at this point in the history
Add save_any function to image_io
  • Loading branch information
alessandrofelder authored Mar 15, 2024
2 parents 8dd06d2 + d9161b6 commit d336fc5
Show file tree
Hide file tree
Showing 2 changed files with 118 additions and 112 deletions.
36 changes: 36 additions & 0 deletions brainglobe_utils/image_io/save.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import warnings
from pathlib import Path

import nrrd
import numpy as np
Expand Down Expand Up @@ -104,3 +105,38 @@ def to_nrrd(img_volume, dest_path):
"""
dest_path = str(dest_path)
nrrd.write(dest_path, img_volume)


def save_any(img_volume, dest_path):
"""
Save the image volume (numpy array) to the given file path, using the save
function matching its file extension.
Parameters
----------
img_volume : np.ndarray
The image to be saved.
dest_path : str or pathlib.Path
The file path to save the image to.
Supports directories (will save a sequence of tiffs), .tif, .tiff,
.nrrd and .nii.
"""
dest_path = Path(dest_path)

if dest_path.is_dir():
to_tiffs(img_volume, str(dest_path / "image"))

elif dest_path.suffix == ".tif" or dest_path.suffix == ".tiff":
to_tiff(img_volume, str(dest_path))

elif dest_path.suffix == ".nrrd":
to_nrrd(img_volume, str(dest_path))

elif dest_path.suffix == ".nii":
to_nii(img_volume, str(dest_path))

else:
raise NotImplementedError(
f"Could not guess data type for path {dest_path}"
)
194 changes: 82 additions & 112 deletions tests/tests/test_image_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,90 +29,53 @@ def image_array(request, array_2d, array_3d):


@pytest.fixture()
def shuffled_txt_path(tmp_path, array_3d):
def txt_path(tmp_path, array_3d):
"""
Return the path to a text file containing the paths of a series of 2D tiffs
in a random order
in order
"""
txt_path = tmp_path / "imgs_file.txt"
write_tiff_sequence_with_txt_file(txt_path, array_3d)

# Shuffle paths in the text file into a random order
with open(txt_path, "r+") as f:
tiff_paths = f.read().splitlines()
random.Random(4).shuffle(tiff_paths)
f.seek(0)
f.writelines(line + "\n" for line in tiff_paths)
f.truncate()

return txt_path


def write_tiff_sequence_with_txt_file(txt_path, image_array):
"""
Write an image array to a series of tiffs, and write a text file
containing all the tiff file paths in order (one per line).
The tiff sequence will be saved to a sub-folder inside the same folder
as the text file.
Parameters
----------
txt_path : pathlib.Path
Filepath of text file to create
image_array : np.ndarray
Image to write as sequence of tiffs
"""
directory = txt_path.parent

# Write tiff sequence to sub-folder
sub_dir = directory / "sub"
sub_dir.mkdir()
save.to_tiffs(image_array, str(sub_dir / "image_array"))
save.to_tiffs(array_3d, str(sub_dir / "image"))

# Write txt file containing all tiff file paths (one per line)
tiff_paths = sorted(sub_dir.iterdir())
txt_path.write_text(
"\n".join([str(sub_dir / fname) for fname in tiff_paths])
)

return txt_path

def save_any(file_path, image_array):
"""
Save image_array to given file path, using the save function matching
its file extension

Parameters
----------
file_path : pathlib.Path
File path of image to save
image_array : np.ndarray
Image to save
@pytest.fixture()
def shuffled_txt_path(txt_path):
"""
if file_path.is_dir():
save.to_tiffs(image_array, str(file_path / "image_array"))

elif file_path.suffix == ".txt":
write_tiff_sequence_with_txt_file(file_path, image_array)

elif file_path.suffix == ".tif" or file_path.suffix == ".tiff":
save.to_tiff(image_array, str(file_path))

elif file_path.suffix == ".nrrd":
save.to_nrrd(image_array, str(file_path))
Return the path to a text file containing the paths of a series of 2D tiffs
in a random order
"""
# Shuffle paths in the text file into a random order
with open(txt_path, "r+") as f:
tiff_paths = f.read().splitlines()
random.Random(4).shuffle(tiff_paths)
f.seek(0)
f.writelines(line + "\n" for line in tiff_paths)
f.truncate()

elif file_path.suffix == ".nii":
save.to_nii(image_array, str(file_path), scale=(1, 1, 1))
return txt_path


def test_tiff_io(tmp_path, image_array):
"""
Test that a 2D/3D tiff can be written and read correctly
"""
dest_path = tmp_path / "image_array.tiff"
save_any(dest_path, image_array)
save.save_any(image_array, dest_path)
reloaded = load.load_any(str(dest_path))

reloaded = load.load_img_stack(str(dest_path), 1, 1, 1)
assert (reloaded == image_array).all()


Expand All @@ -127,11 +90,14 @@ def test_3d_tiff_scaling(
Test that a 3D tiff is scaled correctly on loading
"""
dest_path = tmp_path / "image_array.tiff"
save_any(dest_path, array_3d)

reloaded = load.load_img_stack(
str(dest_path), x_scaling_factor, y_scaling_factor, z_scaling_factor
save.save_any(array_3d, dest_path)
reloaded = load.load_any(
str(dest_path),
x_scaling_factor=x_scaling_factor,
y_scaling_factor=y_scaling_factor,
z_scaling_factor=z_scaling_factor,
)

assert reloaded.shape[0] == array_3d.shape[0] * z_scaling_factor
assert reloaded.shape[1] == array_3d.shape[1] * y_scaling_factor
assert reloaded.shape[2] == array_3d.shape[2] * x_scaling_factor
Expand All @@ -149,10 +115,10 @@ def test_tiff_sequence_io(tmp_path, array_3d, load_parallel):
Test that a 3D image can be written and read correctly as a sequence
of 2D tiffs (with or without parallel loading)
"""
save_any(tmp_path, array_3d)
reloaded_array = load.load_from_folder(
str(tmp_path), 1, 1, 1, load_parallel=load_parallel
)
save.save_any(array_3d, tmp_path)
assert len(list(tmp_path.glob("*.tif"))) == array_3d.shape[0]

reloaded_array = load.load_any(str(tmp_path), load_parallel=load_parallel)
assert (reloaded_array == array_3d).all()


Expand All @@ -166,29 +132,19 @@ def test_tiff_sequence_scaling(
"""
Test that a tiff sequence is scaled correctly on loading
"""
save_any(tmp_path, array_3d)
reloaded_array = load.load_from_folder(
str(tmp_path), x_scaling_factor, y_scaling_factor, z_scaling_factor
save.save_any(array_3d, tmp_path)
reloaded_array = load.load_any(
str(tmp_path),
x_scaling_factor=x_scaling_factor,
y_scaling_factor=y_scaling_factor,
z_scaling_factor=z_scaling_factor,
)

assert reloaded_array.shape[0] == array_3d.shape[0] * z_scaling_factor
assert reloaded_array.shape[1] == array_3d.shape[1] * y_scaling_factor
assert reloaded_array.shape[2] == array_3d.shape[2] * x_scaling_factor


def test_load_img_sequence_from_txt(tmp_path, array_3d):
"""
Test that a tiff sequence can be loaded from a text file containing an
ordered list of the tiff file paths (one per line)
"""
img_sequence_file = tmp_path / "imgs_file.txt"
save_any(img_sequence_file, array_3d)

# Load image from paths in text file
reloaded_array = load.load_img_sequence(str(img_sequence_file), 1, 1, 1)
assert (reloaded_array == array_3d).all()


@pytest.mark.parametrize(
"sort",
[True, False],
Expand All @@ -209,54 +165,58 @@ def test_sort_img_sequence_from_txt(shuffled_txt_path, array_3d, sort):

def test_nii_io(tmp_path, array_3d):
"""
Test that a 3D image can be written and read correctly as nii
Test that a 3D image can be written and read correctly as nii with scale
(keeping it as a nifty object with no numpy conversion on loading)
"""
nii_path = tmp_path / "test_array.nii"
save_any(nii_path, array_3d)
assert (load.load_nii(str(nii_path)).get_fdata() == array_3d).all()
nii_path = str(tmp_path / "test_array.nii")
scale = (5, 5, 5)
save.to_nii(array_3d, nii_path, scale=scale)
reloaded = load.load_nii(nii_path)

assert (reloaded.get_fdata() == array_3d).all()
assert reloaded.header.get_zooms() == scale


def test_nii_read_to_numpy(tmp_path, array_3d):
"""
Test that conversion of loaded nii image to an in-memory numpy array works
"""
nii_path = tmp_path / "test_array.nii"
save_any(nii_path, array_3d)
save.save_any(array_3d, nii_path)
reloaded_array = load.load_any(str(nii_path), as_numpy=True)

reloaded_array = load.load_nii(str(nii_path), as_array=True, as_numpy=True)
assert (reloaded_array == array_3d).all()


def test_nrrd_io(tmp_path, array_3d):
"""
Test that a 3D image can be written and read correctly as nrrd
"""
nrrd_path = tmp_path / "test_array.nrrd"
save_any(nrrd_path, array_3d)
assert (load.load_nrrd(str(nrrd_path)) == array_3d).all()


@pytest.mark.parametrize(
"file_name",
[
"test_array.tiff",
"test_array.tif",
"test_array.txt",
"test_array.nrrd",
"test_array.nii",
pytest.param("", id="dir of tiffs"),
],
)
def test_load_any(tmp_path, array_3d, file_name):
def test_save_and_load_any(tmp_path, array_3d, file_name):
"""
Test that load_any can read all required image file types
Test that save_any/load_any can write/read all required image
file types.
"""
src_path = tmp_path / file_name
save_any(src_path, array_3d)
save.save_any(array_3d, src_path)

assert (load.load_any(str(src_path)) == array_3d).all()


def test_load_any_txt(txt_path, array_3d):
"""
Test that load_any can read a tiff sequence from a text file containing an
ordered list of the tiff file paths (one per line)
"""
assert (load.load_any(str(txt_path)) == array_3d).all()


def test_load_any_error(tmp_path):
"""
Test that load_any throws an error for an unknown file extension
Expand All @@ -265,6 +225,14 @@ def test_load_any_error(tmp_path):
load.load_any(str(tmp_path / "test.unknown"))


def test_save_any_error(tmp_path, array_3d):
"""
Test that save_any throws an error for an unknown file extension
"""
with pytest.raises(NotImplementedError):
save.save_any(array_3d, str(tmp_path / "test.unknown"))


def test_scale_z(array_3d):
"""
Test that a 3D image can be scaled in z by float and integer values
Expand All @@ -273,22 +241,24 @@ def test_scale_z(array_3d):
assert utils.scale_z(array_3d, 2).shape[0] == array_3d.shape[0] * 2


@pytest.mark.parametrize(
"file_name",
[
"test_array.txt",
pytest.param("", id="dir of tiffs"),
],
)
def test_image_size(tmp_path, array_3d, file_name):
def test_image_size_dir(tmp_path, array_3d):
"""
Test that image size can be detected from a directory of 2D tiffs, or
a text file containing the paths of a sequence of 2D tiffs
Test that image size can be detected from a directory of 2D tiffs
"""
file_path = tmp_path / file_name
save_any(file_path, array_3d)
save.save_any(array_3d, tmp_path)

image_shape = load.get_size_image_from_file_paths(str(file_path))
image_shape = load.get_size_image_from_file_paths(str(tmp_path))
assert image_shape["x"] == array_3d.shape[2]
assert image_shape["y"] == array_3d.shape[1]
assert image_shape["z"] == array_3d.shape[0]


def test_image_size_txt(txt_path, array_3d):
"""
Test that image size can be detected from a text file containing the paths
of a sequence of 2D tiffs
"""
image_shape = load.get_size_image_from_file_paths(str(txt_path))
assert image_shape["x"] == array_3d.shape[2]
assert image_shape["y"] == array_3d.shape[1]
assert image_shape["z"] == array_3d.shape[0]

0 comments on commit d336fc5

Please sign in to comment.