Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
fix channel dim selection on segmentation target (#1509)
Browse files Browse the repository at this point in the history
Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka <jirka.borovec@seznam.cz>
  • Loading branch information
4 people authored May 11, 2023
1 parent e10fca5 commit 8a0a962
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 5 deletions.
2 changes: 1 addition & 1 deletion src/flash/image/segmentation/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def load_data(

def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
if DataKeys.TARGET in sample:
sample[DataKeys.TARGET] = np.array(load_image(sample[DataKeys.TARGET])).transpose((2, 0, 1))[:, :, 0]
sample[DataKeys.TARGET] = np.array(load_image(sample[DataKeys.TARGET])).transpose((2, 0, 1))[0, :, :]
return super().load_sample(sample)


Expand Down
70 changes: 66 additions & 4 deletions tests/image/semantic_segm/test_data.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import os
from pathlib import Path
from typing import Dict, List, Tuple
from typing import Callable, Dict, List, Tuple

import numpy as np
import pytest
import torch

from flash import Trainer
from flash.core.data.io.input import DataKeys
from flash.core.data.io.input_transform import InputTransform
from flash.core.data.transforms import ApplyToKeys
from flash.core.utilities.imports import (
_FIFTYONE_AVAILABLE,
_MATPLOTLIB_AVAILABLE,
Expand Down Expand Up @@ -43,12 +45,22 @@ def _rand_labels(size: Tuple[int, int], num_classes: int):
return Image.fromarray(data.astype(np.uint8))


def create_random_data(image_files: List[str], label_files: List[str], size: Tuple[int, int], num_classes: int):
def create_random_data(
image_files: List[str], label_files: List[str], size: Tuple[int, int], num_classes: int
) -> Tuple[List[Image.Image], List[Image.Image]]:
imgs = []
for img_file in image_files:
_rand_image(size).save(img_file)
img = _rand_image(size)
img.save(img_file)
imgs.append(img)

labels = []
for label_file in label_files:
_rand_labels(size, num_classes).save(label_file)
label = _rand_labels(size, num_classes)
label.save(label_file)
labels.append(label)

return imgs, labels


@pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.")
Expand All @@ -58,6 +70,56 @@ def test_smoke():
dm = SemanticSegmentationData(batch_size=1)
assert dm is not None

@staticmethod
@pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.")
def test_identity(tmpdir):
class IdentityTransform(InputTransform):
def per_sample_transform(self) -> Callable:
return ApplyToKeys(
DataKeys.INPUT,
np.array,
)

def per_batch_transform(self) -> Callable:
return lambda x: x

tmp_dir = Path(tmpdir)

# create random dummy data

os.makedirs(str(tmp_dir / "images"))
os.makedirs(str(tmp_dir / "targets"))

images = [str(tmp_dir / "images" / "img1.png")]

targets = [str(tmp_dir / "targets" / "img1.png")]

num_classes: int = 2
img_size: Tuple[int, int] = (128, 128)
images_data, targets_data = create_random_data(images, targets, img_size, num_classes)

# instantiate the data module

dm = SemanticSegmentationData.from_files(
test_files=images,
test_targets=targets,
batch_size=1,
num_workers=0,
num_classes=num_classes,
transform=IdentityTransform(),
)

assert dm is not None
assert dm.test_dataloader() is not None

# check test data
data = next(iter(dm.test_dataloader()))
imgs, labels = data[DataKeys.INPUT], data[DataKeys.TARGET]
assert imgs.shape == (1, 128, 128, 3)
assert labels.shape == (1, 128, 128)
assert torch.allclose(imgs, torch.from_numpy(np.array(images_data[0])))
assert torch.allclose(labels, torch.from_numpy(np.array(targets_data[0]))[:, :, 0])

@staticmethod
def test_from_folders(tmpdir):
tmp_dir = Path(tmpdir)
Expand Down

0 comments on commit 8a0a962

Please sign in to comment.