diff --git a/src/flash/image/segmentation/input.py b/src/flash/image/segmentation/input.py index 4662e1c986..02068c96be 100644 --- a/src/flash/image/segmentation/input.py +++ b/src/flash/image/segmentation/input.py @@ -94,7 +94,7 @@ def load_data( def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]: if DataKeys.TARGET in sample: - sample[DataKeys.TARGET] = np.array(load_image(sample[DataKeys.TARGET])).transpose((2, 0, 1))[:, :, 0] + sample[DataKeys.TARGET] = np.array(load_image(sample[DataKeys.TARGET])).transpose((2, 0, 1))[0, :, :] return super().load_sample(sample) diff --git a/tests/image/semantic_segm/test_data.py b/tests/image/semantic_segm/test_data.py index 0c07e99fb0..aae5129c90 100644 --- a/tests/image/semantic_segm/test_data.py +++ b/tests/image/semantic_segm/test_data.py @@ -1,6 +1,6 @@ import os from pathlib import Path -from typing import Dict, List, Tuple +from typing import Callable, Dict, List, Tuple import numpy as np import pytest @@ -8,6 +8,8 @@ from flash import Trainer from flash.core.data.io.input import DataKeys +from flash.core.data.io.input_transform import InputTransform +from flash.core.data.transforms import ApplyToKeys from flash.core.utilities.imports import ( _FIFTYONE_AVAILABLE, _MATPLOTLIB_AVAILABLE, @@ -43,12 +45,22 @@ def _rand_labels(size: Tuple[int, int], num_classes: int): return Image.fromarray(data.astype(np.uint8)) -def create_random_data(image_files: List[str], label_files: List[str], size: Tuple[int, int], num_classes: int): +def create_random_data( + image_files: List[str], label_files: List[str], size: Tuple[int, int], num_classes: int +) -> Tuple[List[Image.Image], List[Image.Image]]: + imgs = [] for img_file in image_files: - _rand_image(size).save(img_file) + img = _rand_image(size) + img.save(img_file) + imgs.append(img) + labels = [] for label_file in label_files: - _rand_labels(size, num_classes).save(label_file) + label = _rand_labels(size, num_classes) + label.save(label_file) + labels.append(label) + + return imgs, labels @pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.") @@ -58,6 +70,56 @@ def test_smoke(): dm = SemanticSegmentationData(batch_size=1) assert dm is not None + @staticmethod + @pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.") + def test_identity(tmpdir): + class IdentityTransform(InputTransform): + def per_sample_transform(self) -> Callable: + return ApplyToKeys( + DataKeys.INPUT, + np.array, + ) + + def per_batch_transform(self) -> Callable: + return lambda x: x + + tmp_dir = Path(tmpdir) + + # create random dummy data + + os.makedirs(str(tmp_dir / "images")) + os.makedirs(str(tmp_dir / "targets")) + + images = [str(tmp_dir / "images" / "img1.png")] + + targets = [str(tmp_dir / "targets" / "img1.png")] + + num_classes: int = 2 + img_size: Tuple[int, int] = (128, 128) + images_data, targets_data = create_random_data(images, targets, img_size, num_classes) + + # instantiate the data module + + dm = SemanticSegmentationData.from_files( + test_files=images, + test_targets=targets, + batch_size=1, + num_workers=0, + num_classes=num_classes, + transform=IdentityTransform(), + ) + + assert dm is not None + assert dm.test_dataloader() is not None + + # check test data + data = next(iter(dm.test_dataloader())) + imgs, labels = data[DataKeys.INPUT], data[DataKeys.TARGET] + assert imgs.shape == (1, 128, 128, 3) + assert labels.shape == (1, 128, 128) + assert torch.allclose(imgs, torch.from_numpy(np.array(images_data[0]))) + assert torch.allclose(labels, torch.from_numpy(np.array(targets_data[0]))[:, :, 0]) + @staticmethod def test_from_folders(tmpdir): tmp_dir = Path(tmpdir)