Skip to content

Commit

Permalink
feat: add ImageDataset.split (#846)
Browse files Browse the repository at this point in the history
Closes #831 

### Summary of Changes

feat: add `ImageDataset.split`

---------

Co-authored-by: megalinter-bot <129584137+megalinter-bot@users.noreply.github.com>
Co-authored-by: Lars Reimann <mail@larsreimann.com>
  • Loading branch information
3 people authored Jun 24, 2024
1 parent d33cb5d commit 3878751
Show file tree
Hide file tree
Showing 3 changed files with 204 additions and 29 deletions.
4 changes: 2 additions & 2 deletions src/safeds/data/image/containers/_image_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,7 @@ def __contains__(self, item: object) -> bool:
Returns
-------
has_item:
Weather the given item is in this image list
Whether the given item is in this image list
"""
return isinstance(item, Image) and self.has_image(item)

Expand Down Expand Up @@ -524,7 +524,7 @@ def has_image(self, image: Image) -> bool:
Returns
-------
has_image:
Weather the given image is in this image list
Whether the given image is in this image list
"""

# ------------------------------------------------------------------------------------------------------------------
Expand Down
126 changes: 99 additions & 27 deletions src/safeds/data/labeled/containers/_image_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class ImageDataset(Dataset[ImageList, Out_co]):
batch_size:
the batch size used for training
shuffle:
weather the data should be shuffled after each epoch of training
whether the data should be shuffled after each epoch of training
"""

def __init__(self, input_data: ImageList, output_data: Out_co, batch_size: int = 1, shuffle: bool = False) -> None:
Expand Down Expand Up @@ -108,13 +108,13 @@ def __iter__(self) -> ImageDataset:
return im_ds

def __next__(self) -> tuple[Tensor, Tensor]:
if self._next_batch_index * self._batch_size >= len(self._input):
if self._next_batch_index * self._batch_size >= len(self._shuffle_tensor_indices):
raise StopIteration
self._next_batch_index += 1
return self._get_batch(self._next_batch_index - 1)

def __len__(self) -> int:
return self._input.image_count
return len(self._shuffle_tensor_indices)

def __eq__(self, other: object) -> bool:
"""
Expand All @@ -138,6 +138,7 @@ def __eq__(self, other: object) -> bool:
and isinstance(other._output, type(self._output))
and (self._input == other._input)
and (self._output == other._output)
and (self._shuffle_tensor_indices.tolist() == other._shuffle_tensor_indices.tolist())
)

def __hash__(self) -> int:
Expand All @@ -149,7 +150,13 @@ def __hash__(self) -> int:
hash:
the hash value
"""
return _structural_hash(self._input, self._output, self._shuffle_after_epoch, self._batch_size)
return _structural_hash(
self._input,
self._output,
self._shuffle_after_epoch,
self._batch_size,
self._shuffle_tensor_indices.tolist(),
)

def __sizeof__(self) -> int:
"""
Expand Down Expand Up @@ -205,7 +212,7 @@ def get_input(self) -> ImageList:
input:
the input data of this dataset
"""
return self._sort_image_list_with_shuffle_tensor_indices(self._input)
return self._sort_image_list_with_shuffle_tensor_indices_reduce_if_necessary(self._input)

def get_output(self) -> Out_co:
"""
Expand All @@ -222,19 +229,25 @@ def get_output(self) -> Out_co:
elif isinstance(output, _ColumnAsTensor):
return output._to_column(self._shuffle_tensor_indices) # type: ignore[return-value]
else:
return self._sort_image_list_with_shuffle_tensor_indices(self._output) # type: ignore[return-value]
return self._sort_image_list_with_shuffle_tensor_indices_reduce_if_necessary(self._output) # type: ignore[return-value]

def _sort_image_list_with_shuffle_tensor_indices(self, image_list: _SingleSizeImageList) -> _SingleSizeImageList:
def _sort_image_list_with_shuffle_tensor_indices_reduce_if_necessary(
self,
image_list: _SingleSizeImageList,
) -> _SingleSizeImageList:
shuffled_image_list = _SingleSizeImageList()
shuffled_image_list._tensor = image_list._tensor
shuffled_image_list._indices_to_tensor_positions = {
index: self._shuffle_tensor_indices[tensor_position].item()
for index, tensor_position in image_list._indices_to_tensor_positions.items()
tensor_pos = [
image_list._indices_to_tensor_positions[shuffled_index]
for shuffled_index in sorted(self._shuffle_tensor_indices.tolist())
]
temp_pos = {
shuffled_index: new_index for new_index, shuffled_index in enumerate(self._shuffle_tensor_indices.tolist())
}
shuffled_image_list._tensor = image_list._tensor[tensor_pos]
shuffled_image_list._tensor_positions_to_indices = [
index
for index, _ in sorted(shuffled_image_list._indices_to_tensor_positions.items(), key=lambda item: item[1])
new_index for _, new_index in sorted(temp_pos.items(), key=lambda item: item[0])
]
shuffled_image_list._indices_to_tensor_positions = shuffled_image_list._calc_new_indices_to_tensor_positions()
return shuffled_image_list

def _get_batch(self, batch_number: int, batch_size: int | None = None) -> tuple[Tensor, Tensor]:
Expand All @@ -247,18 +260,18 @@ def _get_batch(self, batch_number: int, batch_size: int | None = None) -> tuple[

_check_bounds("batch_size", batch_size, lower_bound=_ClosedBound(1))

if batch_number < 0 or batch_size * batch_number >= len(self._input):
if batch_number < 0 or batch_size * batch_number >= len(self._shuffle_tensor_indices):
raise IndexOutOfBoundsError(batch_size * batch_number)
max_index = (
batch_size * (batch_number + 1) if batch_size * (batch_number + 1) < len(self._input) else len(self._input)
batch_size * (batch_number + 1)
if batch_size * (batch_number + 1) < len(self._shuffle_tensor_indices)
else len(self._shuffle_tensor_indices)
)
input_tensor = (
self._input._tensor[
self._shuffle_tensor_indices[
[
self._input._indices_to_tensor_positions[index]
for index in range(batch_size * batch_number, max_index)
]
[
self._input._indices_to_tensor_positions[index]
for index in self._shuffle_tensor_indices[batch_size * batch_number : max_index].tolist()
]
].to(torch.float32)
/ 255
Expand All @@ -267,11 +280,9 @@ def _get_batch(self, batch_number: int, batch_size: int | None = None) -> tuple[
if isinstance(self._output, _SingleSizeImageList):
output_tensor = (
self._output._tensor[
self._shuffle_tensor_indices[
[
self._output._indices_to_tensor_positions[index]
for index in range(batch_size * batch_number, max_index)
]
[
self._input._indices_to_tensor_positions[index]
for index in self._shuffle_tensor_indices[batch_size * batch_number : max_index].tolist()
]
].to(torch.float32)
/ 255
Expand All @@ -284,7 +295,7 @@ def shuffle(self) -> ImageDataset[Out_co]:
"""
Return a new `ImageDataset` with shuffled data.
The original dataset list is not modified.
The original dataset is not modified.
Returns
-------
Expand All @@ -296,10 +307,71 @@ def shuffle(self) -> ImageDataset[Out_co]:
_init_default_device()

im_dataset: ImageDataset[Out_co] = copy.copy(self)
im_dataset._shuffle_tensor_indices = torch.randperm(len(self))
im_dataset._shuffle_tensor_indices = self._shuffle_tensor_indices[
torch.randperm(len(self._shuffle_tensor_indices))
]
im_dataset._next_batch_index = 0
return im_dataset

def split(
    self,
    percentage_in_first: float,
    *,
    shuffle: bool = True,
) -> tuple[ImageDataset[Out_co], ImageDataset[Out_co]]:
    """
    Create two image datasets by splitting the data of the current dataset.

    The first dataset contains a percentage of the data specified by `percentage_in_first`, and the second dataset
    contains the remaining data.

    The split is taken from the elements this dataset currently exposes, in their current order. Splitting a
    dataset that is itself the result of `split` or `shuffle` therefore only distributes that dataset's own
    elements.

    The original dataset is not modified.
    By default, the data is shuffled before splitting. You can disable this by setting `shuffle` to False.

    Parameters
    ----------
    percentage_in_first:
        The percentage of data to include in the first dataset. Must be between 0 and 1.
    shuffle:
        Whether to shuffle the data before splitting.

    Returns
    -------
    first_dataset:
        The first dataset.
    second_dataset:
        The second dataset.

    Raises
    ------
    OutOfBoundsError
        If `percentage_in_first` is not between 0 and 1.
    """
    import torch

    _check_bounds(
        "percentage_in_first",
        percentage_in_first,
        lower_bound=_ClosedBound(0),
        upper_bound=_ClosedBound(1),
    )

    # Select from the indices this dataset already holds instead of generating fresh positions 0..n-1;
    # otherwise a dataset that is a split part (or was shuffled) would hand out the wrong elements.
    current_indices = self._shuffle_tensor_indices
    if shuffle:
        # Build the permutation on the same device as the indexed tensor so the indexing is always valid.
        permutation = torch.randperm(len(current_indices), device=current_indices.device)
        current_indices = current_indices[permutation]

    element_count = len(current_indices)
    first_size = round(percentage_in_first * element_count)

    first_dataset: ImageDataset[Out_co] = copy.copy(self)
    second_dataset: ImageDataset[Out_co] = copy.copy(self)
    first_dataset._shuffle_tensor_indices, second_dataset._shuffle_tensor_indices = current_indices.split(
        [first_size, element_count - first_size],
    )

    # Like `shuffle`, start iteration of the new datasets at the first batch even if `self` was mid-iteration.
    first_dataset._next_batch_index = 0
    second_dataset._next_batch_index = 0
    return first_dataset, second_dataset


class _TableAsTensor:
def __init__(self, table: Table) -> None:
Expand Down
103 changes: 103 additions & 0 deletions tests/safeds/data/labeled/containers/test_image_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,109 @@ def test_get_batch_device(self, device: Device) -> None:
assert batch[1].device == _get_device()


@pytest.mark.parametrize("device", get_devices(), ids=get_devices_ids())
@pytest.mark.parametrize("shuffle", [True, False])
class TestSplit:
    """Tests for `ImageDataset.split`, run for every device and with shuffling enabled and disabled."""

    @pytest.mark.parametrize(
        "output",
        [
            Column("images", images_all()[:4] + images_all()[5:]),
            Table(
                {
                    "0": [1, 0, 0, 0, 0, 0],
                    "1": [0, 1, 0, 0, 0, 0],
                    "2": [0, 0, 1, 0, 0, 0],
                    "3": [0, 0, 0, 1, 0, 0],
                    "4": [0, 0, 0, 0, 1, 0],
                    "5": [0, 0, 0, 0, 0, 1],
                },
            ),
            _EmptyImageList(),
        ],
        ids=["Column", "Table", "ImageList"],
    )
    def test_should_split(self, device: Device, shuffle: bool, output: Column | Table | ImageList) -> None:
        """Split a dataset 40/60 and check sizes, input/output pairing, and batches of both parts."""
        configure_test_with_device(device)
        image_list = ImageList.from_files(resolve_resource_path(images_all())).remove_duplicate_images().resize(10, 10)
        if isinstance(output, _EmptyImageList):
            # The empty list is only a placeholder in the parametrization: use the input images as output.
            output = image_list
        image_dataset = ImageDataset(image_list, output)  # type: ignore[type-var]
        image_dataset1, image_dataset2 = image_dataset.split(0.4, shuffle=shuffle)
        offset = len(image_dataset1)

        # Sizes of the datasets and of their inputs and outputs.
        assert len(image_dataset1) == round(0.4 * len(image_dataset))
        assert len(image_dataset2) == len(image_dataset) - offset
        assert len(image_dataset1.get_input()) == round(0.4 * len(image_dataset))
        assert len(image_dataset2.get_input()) == len(image_dataset) - offset
        im1_output = image_dataset1.get_output()
        im2_output = image_dataset2.get_output()
        if isinstance(im1_output, Table):
            assert im1_output.row_count == round(0.4 * len(image_dataset))
        else:
            assert len(im1_output) == round(0.4 * len(image_dataset))
        if isinstance(im2_output, Table):
            assert im2_output.row_count == len(image_dataset) - offset
        else:
            assert len(im2_output) == len(image_dataset) - offset

        assert image_dataset != image_dataset1
        assert image_dataset != image_dataset2
        assert image_dataset1 != image_dataset2

        # Each part is paired with the position in the original data at which it starts when not shuffled.
        parts = ((image_dataset1, 0), (image_dataset2, offset))

        # Every input of a part must still be paired with the output it had in the original dataset.
        for part, part_offset in parts:
            out = part.get_output()
            for i, image in enumerate(part.get_input().to_images()):
                index = image_list.index(image)[0]
                if not shuffle:
                    assert index == i + part_offset
                if isinstance(out, ImageList):
                    assert image_list.index(out.get_image(i))[0] == index
                elif isinstance(out, Column) and isinstance(output, Column):
                    assert output.to_list().index(out.to_list()[i]) == index
                elif isinstance(out, Table) and isinstance(output, Table):
                    # Row `index` of the one-hot fixture has its 1 in column str(index), so row `i` of the
                    # split output must have it there too. Inspect `out`, not the fixture itself.
                    assert out.get_column(str(index)).to_list()[i] == 1

        # Compare single-element batches of the parts against one batch holding the whole original dataset.
        image_dataset._batch_size = len(image_dataset)
        image_dataset1._batch_size = 1
        image_dataset2._batch_size = 1
        image_dataset_batch = next(iter(image_dataset))

        for part, part_offset in parts:
            for i, b in enumerate(iter(part)):
                assert b[0] in image_dataset_batch[0]
                index = (b[0] == image_dataset_batch[0]).all(dim=[1, 2, 3]).nonzero()[0][0]
                if not shuffle:
                    assert index == i + part_offset
                assert torch.all(torch.eq(b[0], image_dataset_batch[0][index]))
                assert torch.all(torch.eq(b[1], image_dataset_batch[1][index]))

    @pytest.mark.parametrize("percentage", [-1, -0.1, 1.1, 2])
    def test_should_raise(self, device: Device, shuffle: bool, percentage: float) -> None:
        """A percentage outside [0, 1] must raise `OutOfBoundsError`."""
        configure_test_with_device(device)
        image_list = ImageList.from_files(resolve_resource_path(images_all())).resize(10, 10)
        image_dataset = ImageDataset(image_list, Column("images", images_all()))
        with pytest.raises(OutOfBoundsError):
            image_dataset.split(percentage, shuffle=shuffle)


@pytest.mark.parametrize("device", get_devices(), ids=get_devices_ids())
class TestTableAsTensor:
def test_should_raise_if_not_one_hot_encoded(self, device: Device) -> None:
Expand Down

0 comments on commit 3878751

Please sign in to comment.