feat: added MNIST, Fashion-MNIST and KMNIST datasets (#164)
Closes #161 
Closes #162 
Closes #163 

### Summary of Changes

feat: added MNIST, Fashion-MNIST and KMNIST datasets
build: bump safe-ds to ^0.24.0
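
For reference, a minimal usage sketch of the new loaders (the `data` directory is only an example; with `download=True`, the default, the archives are fetched on first use):

```python
from safeds_datasets.image import load_fashion_mnist, load_kmnist, load_mnist

# Each loader returns a (train, test) pair of ImageDataset[Column] objects and
# downloads any missing archive into the given directory.
mnist_train, mnist_test = load_mnist("data")
fashion_train, fashion_test = load_fashion_mnist("data")
kmnist_train, kmnist_test = load_kmnist("data")
```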

---------

Co-authored-by: megalinter-bot <129584137+megalinter-bot@users.noreply.github.com>
Co-authored-by: Lars Reimann <mail@larsreimann.com>
3 people authored Jul 12, 2024
1 parent 90de957 commit 97ae47a
Showing 8 changed files with 366 additions and 4 deletions.
4 changes: 1 addition & 3 deletions poetry.lock


2 changes: 1 addition & 1 deletion pyproject.toml
@@ -14,7 +14,7 @@ packages = [

[tool.poetry.dependencies]
python = "^3.11,<3.13"
safe-ds = ">=0.17,<0.27"
safe-ds = ">=0.24,<0.27"

[tool.poetry.group.dev.dependencies]
pytest = ">=7.2.1,<9.0.0"
5 changes: 5 additions & 0 deletions src/safeds_datasets/image/__init__.py
@@ -0,0 +1,5 @@
"""Image datasets."""

from ._mnist import load_fashion_mnist, load_kmnist, load_mnist

__all__ = ["load_fashion_mnist", "load_kmnist", "load_mnist"]
5 changes: 5 additions & 0 deletions src/safeds_datasets/image/_mnist/__init__.py
@@ -0,0 +1,5 @@
"""MNIST like Datasets."""

from ._mnist import load_fashion_mnist, load_kmnist, load_mnist

__all__ = ["load_fashion_mnist", "load_kmnist", "load_mnist"]
256 changes: 256 additions & 0 deletions src/safeds_datasets/image/_mnist/_mnist.py
@@ -0,0 +1,256 @@
import gzip
import os
import struct
import sys
import urllib.request
from array import array
from pathlib import Path
from typing import TYPE_CHECKING
from urllib.error import HTTPError

import torch
from safeds._config import _init_default_device
from safeds.data.image.containers._single_size_image_list import _SingleSizeImageList
from safeds.data.labeled.containers import ImageDataset
from safeds.data.tabular.containers import Column

if TYPE_CHECKING:
    from safeds.data.image.containers import ImageList

_mnist_links: list[str] = ["http://yann.lecun.com/exdb/mnist/", "https://ossci-datasets.s3.amazonaws.com/mnist/"]
_mnist_files: dict[str, str] = {
"train-images-idx3": "train-images-idx3-ubyte.gz",
"train-labels-idx1": "train-labels-idx1-ubyte.gz",
"test-images-idx3": "t10k-images-idx3-ubyte.gz",
"test-labels-idx1": "t10k-labels-idx1-ubyte.gz",
}
_mnist_labels: dict[int, str] = {0: "0", 1: "1", 2: "2", 3: "3", 4: "4", 5: "5", 6: "6", 7: "7", 8: "8", 9: "9"}
_mnist_folder: str = "mnist"

_fashion_mnist_links: list[str] = ["http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/"]
_fashion_mnist_files: dict[str, str] = _mnist_files
_fashion_mnist_labels: dict[int, str] = {
0: "T-shirt/top",
1: "Trouser",
2: "Pullover",
3: "Dress",
4: "Coat",
5: "Sandal",
6: "Shirt",
7: "Sneaker",
8: "Bag",
9: "Ankle boot",
}
_fashion_mnist_folder: str = "fashion-mnist"

_kuzushiji_mnist_links: list[str] = ["http://codh.rois.ac.jp/kmnist/dataset/kmnist/"]
_kuzushiji_mnist_files: dict[str, str] = _mnist_files
_kuzushiji_mnist_labels: dict[int, str] = {
0: "\u304a",
1: "\u304d",
2: "\u3059",
3: "\u3064",
4: "\u306a",
5: "\u306f",
6: "\u307e",
7: "\u3084",
8: "\u308c",
9: "\u3092",
}
_kuzushiji_mnist_folder: str = "kmnist"


def load_mnist(path: str | Path, download: bool = True) -> tuple[ImageDataset[Column], ImageDataset[Column]]:
    """
    Load the `MNIST <http://yann.lecun.com/exdb/mnist/>`_ datasets.

    Parameters
    ----------
    path:
        The path where the files are stored or will be downloaded to.
    download:
        Whether the files should be downloaded to the given path.

    Returns
    -------
    train_dataset, test_dataset:
        The train and test datasets.

    Raises
    ------
    FileNotFoundError
        If a file of the dataset cannot be found.
    """
    path = Path(path) / _mnist_folder
    path.mkdir(parents=True, exist_ok=True)
    path_files = os.listdir(path)
    missing_files = []
    for file_path in _mnist_files.values():
        if file_path not in path_files:
            missing_files.append(file_path)
    if len(missing_files) > 0:
        if download:
            _download_mnist_like(
                path,
                {name: f_path for name, f_path in _mnist_files.items() if f_path in missing_files},
                _mnist_links,
            )
        else:
            raise FileNotFoundError(f"Could not find files {[str(path / file) for file in missing_files]}")
    return _load_mnist_like(path, _mnist_files, _mnist_labels)


def load_fashion_mnist(path: str | Path, download: bool = True) -> tuple[ImageDataset[Column], ImageDataset[Column]]:
    """
    Load the `Fashion-MNIST <https://github.com/zalandoresearch/fashion-mnist>`_ datasets.

    Parameters
    ----------
    path:
        The path where the files are stored or will be downloaded to.
    download:
        Whether the files should be downloaded to the given path.

    Returns
    -------
    train_dataset, test_dataset:
        The train and test datasets.

    Raises
    ------
    FileNotFoundError
        If a file of the dataset cannot be found.
    """
    path = Path(path) / _fashion_mnist_folder
    path.mkdir(parents=True, exist_ok=True)
    path_files = os.listdir(path)
    missing_files = []
    for file_path in _fashion_mnist_files.values():
        if file_path not in path_files:
            missing_files.append(file_path)
    if len(missing_files) > 0:
        if download:
            _download_mnist_like(
                path,
                {name: f_path for name, f_path in _fashion_mnist_files.items() if f_path in missing_files},
                _fashion_mnist_links,
            )
        else:
            raise FileNotFoundError(f"Could not find files {[str(path / file) for file in missing_files]}")
    return _load_mnist_like(path, _fashion_mnist_files, _fashion_mnist_labels)


def load_kmnist(path: str | Path, download: bool = True) -> tuple[ImageDataset[Column], ImageDataset[Column]]:
    """
    Load the `Kuzushiji-MNIST <https://github.com/rois-codh/kmnist>`_ datasets.

    Parameters
    ----------
    path:
        The path where the files are stored or will be downloaded to.
    download:
        Whether the files should be downloaded to the given path.

    Returns
    -------
    train_dataset, test_dataset:
        The train and test datasets.

    Raises
    ------
    FileNotFoundError
        If a file of the dataset cannot be found.
    """
    path = Path(path) / _kuzushiji_mnist_folder
    path.mkdir(parents=True, exist_ok=True)
    path_files = os.listdir(path)
    missing_files = []
    for file_path in _kuzushiji_mnist_files.values():
        if file_path not in path_files:
            missing_files.append(file_path)
    if len(missing_files) > 0:
        if download:
            _download_mnist_like(
                path,
                {name: f_path for name, f_path in _kuzushiji_mnist_files.items() if f_path in missing_files},
                _kuzushiji_mnist_links,
            )
        else:
            raise FileNotFoundError(f"Could not find files {[str(path / file) for file in missing_files]}")
    return _load_mnist_like(path, _kuzushiji_mnist_files, _kuzushiji_mnist_labels)


def _load_mnist_like(
    path: str | Path,
    files: dict[str, str],
    labels: dict[int, str],
) -> tuple[ImageDataset[Column], ImageDataset[Column]]:
    _init_default_device()

    path = Path(path)
    test_labels: Column | None = None
    train_labels: Column | None = None
    test_image_list: ImageList | None = None
    train_image_list: ImageList | None = None
    for file_name, file_path in files.items():
        if "idx1" in file_name:
            # Label files: big-endian header (magic 2049, item count), then one byte per label.
            with gzip.open(path / file_path, mode="rb") as label_file:
                magic, size = struct.unpack(">II", label_file.read(8))
                if magic != 2049:
                    raise ValueError(f"Magic number mismatch. Actual {magic} != Expected 2049.")  # pragma: no cover
                if "train" in file_name:
                    train_labels = Column(
                        file_name,
                        [labels[label_index] for label_index in array("B", label_file.read())],
                    )
                else:
                    test_labels = Column(
                        file_name,
                        [labels[label_index] for label_index in array("B", label_file.read())],
                    )
        else:
            # Image files: big-endian header (magic 2051, image count, rows, cols), then raw uint8 pixels.
            with gzip.open(path / file_path, mode="rb") as image_file:
                magic, size, rows, cols = struct.unpack(">IIII", image_file.read(16))
                if magic != 2051:
                    raise ValueError(f"Magic number mismatch. Actual {magic} != Expected 2051.")  # pragma: no cover
                image_data = array("B", image_file.read())
                image_tensor = torch.empty(size, 1, rows, cols, dtype=torch.uint8)
                for i in range(size):
                    image_tensor[i, 0] = torch.frombuffer(
                        image_data[i * rows * cols : (i + 1) * rows * cols],
                        dtype=torch.uint8,
                    ).reshape(rows, cols)
                image_list = _SingleSizeImageList()
                image_list._tensor = image_tensor
                image_list._tensor_positions_to_indices = list(range(size))
                image_list._indices_to_tensor_positions = image_list._calc_new_indices_to_tensor_positions()
                if "train" in file_name:
                    train_image_list = image_list
                else:
                    test_image_list = image_list
    if train_image_list is None or test_image_list is None or train_labels is None or test_labels is None:
        raise ValueError  # pragma: no cover
    # Batch size 32; only the training split is shuffled.
    return ImageDataset[Column](train_image_list, train_labels, 32, shuffle=True), ImageDataset[Column](
        test_image_list,
        test_labels,
        32,
    )


def _download_mnist_like(path: str | Path, files: dict[str, str], links: list[str]) -> None:
    path = Path(path)
    for file_name, file_path in files.items():
        # Try each mirror in order and stop at the first successful download.
        for link in links:
            try:
                print(f"Trying to download file {file_name} via {link + file_path}")  # noqa: T201
                urllib.request.urlretrieve(link + file_path, path / file_path, reporthook=_report_download_progress)
                print()  # noqa: T201
                break
            except HTTPError as e:
                print(f"An error occurred while downloading: {e}")  # noqa: T201 # pragma: no cover


def _report_download_progress(current_packages: int, package_size: int, file_size: int) -> None:
    # urlretrieve reporthook: blocks transferred so far, block size, total file size.
    percentage = min(((current_packages * package_size) / file_size) * 100, 100)
    sys.stdout.write(f"\rDownloading... {percentage:.1f}%")
    sys.stdout.flush()
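
For context, the loaders parse the gzipped IDX files directly instead of relying on an external dataset library. A minimal sketch of the header layout they depend on, assuming an already-downloaded `data/mnist` directory (paths here are illustrative only):

```python
import gzip
import struct

# Label files (idx1): big-endian magic number 2049 and item count, then one byte per label.
with gzip.open("data/mnist/train-labels-idx1-ubyte.gz", "rb") as f:
    magic, count = struct.unpack(">II", f.read(8))
    assert magic == 2049
    labels = f.read(count)  # count unsigned bytes, one label each

# Image files (idx3): big-endian magic number 2051, image count, rows, cols, then raw pixels.
with gzip.open("data/mnist/train-images-idx3-ubyte.gz", "rb") as f:
    magic, count, rows, cols = struct.unpack(">IIII", f.read(16))
    assert magic == 2051
    first_image = f.read(rows * cols)  # rows * cols unsigned bytes per image
```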
Empty file.
Empty file.
98 changes: 98 additions & 0 deletions tests/safeds_datasets/image/_mnist/test_mnist.py
@@ -0,0 +1,98 @@
import os
import tempfile
from pathlib import Path

import pytest
import torch
from safeds.data.labeled.containers import ImageDataset
from safeds_datasets.image import _mnist, load_fashion_mnist, load_kmnist, load_mnist


class TestMNIST:

    def test_should_download_and_return_mnist(self) -> None:
        with tempfile.TemporaryDirectory() as tmpdirname:
            train, test = load_mnist(tmpdirname, download=True)
            files = os.listdir(Path(tmpdirname) / _mnist._mnist._mnist_folder)
            for mnist_file in _mnist._mnist._mnist_files.values():
                assert mnist_file in files
            assert isinstance(train, ImageDataset)
            assert isinstance(test, ImageDataset)
            assert len(train) == 60_000
            assert len(test) == 10_000
            assert (
                train.get_input()._as_single_size_image_list()._tensor.dtype
                == test.get_input()._as_single_size_image_list()._tensor.dtype
                == torch.uint8
            )
            train_output = train.get_output()
            test_output = test.get_output()
            assert (
                set(train_output.get_distinct_values())
                == set(test_output.get_distinct_values())
                == set(_mnist._mnist._mnist_labels.values())
            )

    def test_should_raise_if_file_not_found(self) -> None:
        with tempfile.TemporaryDirectory() as tmpdirname, pytest.raises(FileNotFoundError):
            load_mnist(tmpdirname, download=False)


class TestFashionMNIST:

    def test_should_download_and_return_mnist(self) -> None:
        with tempfile.TemporaryDirectory() as tmpdirname:
            train, test = load_fashion_mnist(tmpdirname, download=True)
            files = os.listdir(Path(tmpdirname) / _mnist._mnist._fashion_mnist_folder)
            for mnist_file in _mnist._mnist._fashion_mnist_files.values():
                assert mnist_file in files
            assert isinstance(train, ImageDataset)
            assert isinstance(test, ImageDataset)
            assert len(train) == 60_000
            assert len(test) == 10_000
            assert (
                train.get_input()._as_single_size_image_list()._tensor.dtype
                == test.get_input()._as_single_size_image_list()._tensor.dtype
                == torch.uint8
            )
            train_output = train.get_output()
            test_output = test.get_output()
            assert (
                set(train_output.get_distinct_values())
                == set(test_output.get_distinct_values())
                == set(_mnist._mnist._fashion_mnist_labels.values())
            )

    def test_should_raise_if_file_not_found(self) -> None:
        with tempfile.TemporaryDirectory() as tmpdirname, pytest.raises(FileNotFoundError):
            load_fashion_mnist(tmpdirname, download=False)


class TestKMNIST:

    def test_should_download_and_return_mnist(self) -> None:
        with tempfile.TemporaryDirectory() as tmpdirname:
            train, test = load_kmnist(tmpdirname, download=True)
            files = os.listdir(Path(tmpdirname) / _mnist._mnist._kuzushiji_mnist_folder)
            for mnist_file in _mnist._mnist._kuzushiji_mnist_files.values():
                assert mnist_file in files
            assert isinstance(train, ImageDataset)
            assert isinstance(test, ImageDataset)
            assert len(train) == 60_000
            assert len(test) == 10_000
            assert (
                train.get_input()._as_single_size_image_list()._tensor.dtype
                == test.get_input()._as_single_size_image_list()._tensor.dtype
                == torch.uint8
            )
            train_output = train.get_output()
            test_output = test.get_output()
            assert (
                set(train_output.get_distinct_values())
                == set(test_output.get_distinct_values())
                == set(_mnist._mnist._kuzushiji_mnist_labels.values())
            )

    def test_should_raise_if_file_not_found(self) -> None:
        with tempfile.TemporaryDirectory() as tmpdirname, pytest.raises(FileNotFoundError):
            load_kmnist(tmpdirname, download=False)
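
As an interactive counterpart to these tests, the label sets of a loaded dataset can be inspected directly; a sketch, assuming the files have been (or may be) downloaded into an example `data` directory:

```python
from safeds_datasets.image import load_kmnist

train, test = load_kmnist("data")  # downloads into data/kmnist on first use
print(len(train), len(test))  # 60000 10000
print(sorted(train.get_output().get_distinct_values()))  # the ten Kuzushiji character labels
```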
