Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
[3/N] Data sources - docs (#272)
Browse files Browse the repository at this point in the history
* Some docs

* Update docs

* Add docs for auto datasets

* Add docs

* Updates

* Updates

* Uopdates

* Finish docs for data_module

* Updates

* Update preprocess docstring

* Docs for transforms

* Add docstrings

* Updates

* Updates to custom_task

* Updates

* Updates

* Update custom_task.rst

* Update custom_task.rst

* Update custom_task.rst

* Fixes

* Updates

* Fix notebook

* Updates
  • Loading branch information
ethanwharris authored May 11, 2021
1 parent e24aa62 commit fb6402b
Show file tree
Hide file tree
Showing 19 changed files with 1,083 additions and 411 deletions.
265 changes: 161 additions & 104 deletions docs/source/custom_task.rst

Large diffs are not rendered by default.

81 changes: 35 additions & 46 deletions docs/source/general/data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -185,31 +185,6 @@ Example::
# Set ``preprocess_cls`` with your custom ``preprocess``.
preprocess_cls = ImageClassificationPreprocess

@classmethod
def from_folders(
cls,
train_folder: Optional[str],
val_folder: Optional[str],
test_folder: Optional[str],
predict_folder: Optional[str],
preprocess: Optional[Preprocess] = None,
**kwargs
):

# Set a custom ``Preprocess`` if none was provided
preprocess = preprocess or cls.preprocess_cls()

# {stage}_load_data_input will be given to your
# ``Preprocess`` ``{stage}_load_data`` function.
return cls.from_load_data_inputs(
train_load_data_input=train_folder,
val_load_data_input=val_folder,
test_load_data_input=test_folder,
predict_load_data_input=predict_folder,
preprocess=preprocess, # DON'T FORGET TO PASS THE CREATED PREPROCESS
**kwargs,
)


3. The Preprocess
__________________
Expand All @@ -218,9 +193,12 @@ Finally, implement your custom ``ImageClassificationPreprocess``.

Example::

from typing import Any, Callable, Dict, Optional, Tuple, Union
import os
import numpy as np
from flash.data.data_source import DefaultDataSources
from flash.data.process import Preprocess
from flash.vision.data import ImageNumpyDataSource, ImagePathsDataSource, ImageTensorDataSource
from PIL import Image
import torchvision.transforms as T
from torch import Tensor
Expand All @@ -231,29 +209,32 @@ Example::

to_tensor = T.ToTensor()

def load_data(self, folder: str, dataset: AutoDataset) -> Iterable:
# The AutoDataset is optional but can be useful to save some metadata.

# metadata contains the image path and its corresponding label with the following structure:
# [(image_path_1, label_1), ... (image_path_n, label_n)].
metadata = make_dataset(folder)

# for the train ``AutoDataset``, we want to store the ``num_classes``.
if self.training:
dataset.num_classes = len(np.unique([m[1] for m in metadata]))

return metadata
def __init__(
self,
train_transform: Optional[Dict[str, Callable]] = None,
val_transform: Optional[Dict[str, Callable]] = None,
test_transform: Optional[Dict[str, Callable]] = None,
predict_transform: Optional[Dict[str, Callable]] = None,
):
super().__init__(
train_transform=train_transform,
val_transform=val_transform,
test_transform=test_transform,
predict_transform=predict_transform,
data_sources={
DefaultDataSources.PATHS: ImagePathsDataSource(),
DefaultDataSources.NUMPY: ImageNumpyDataSource(),
DefaultDataSources.TENSOR: ImageTensorDataSource(),
},
default_data_source=DefaultDataSources.PATHS,
)

def predict_load_data(self, predict_folder: str) -> Iterable:
# This returns [image_path_1, ... image_path_m].
return os.listdir(folder)
def get_state_dict(self) -> Dict[str, Any]:
return {**self.transforms}

def load_sample(self, sample: Union[str, Tuple[str, int]]) -> Tuple[Image, int]
if self.predicting:
return Image.open(image_path)
else:
image_path, label = sample
return Image.open(image_path), label
@classmethod
def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = False):
return cls(**state_dict)

def to_tensor_transform(
self,
Expand Down Expand Up @@ -285,6 +266,14 @@ __________
.. autoclass:: flash.data.data_source.DataSource
:members:

.. autoclass:: flash.data.data_source.DefaultDataSources
:members:
:undoc-members:

.. autoclass:: flash.data.data_source.DefaultDataKeys
:members:
:undoc-members:


----------

Expand Down
21 changes: 12 additions & 9 deletions flash/data/auto_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,21 +27,20 @@


class BaseAutoDataset(Generic[DATA_TYPE]):

DATASET_KEY = "dataset"
"""This class is used to encapsulate a Preprocess Object ``load_data`` and ``load_sample`` functions. ``load_data``
will be called within the ``__init__`` function of the AutoDataset if ``running_stage`` is provided and
``load_sample`` within ``__getitem__``.
"""The ``BaseAutoDataset`` class wraps the output of a call to :meth:`~flash.data.data_source.DataSource.load_data`
and a :class:`~fash.data.data_source.DataSource` and provides the ``_call_load_sample`` method to call
:meth:`~flash.data.data_source.DataSource.load_sample` with the correct
:class:`~flash.data.utils.CurrentRunningStageFuncContext` for the current ``running_stage``. Inheriting classes are
responsible for extracting samples from ``data`` to be given to ``_call_load_sample``.
Args:
data: The output of a call to :meth:`~flash.data.data_source.load_data`.
data: The output of a call to :meth:`~flash.data.data_source.DataSource.load_data`.
data_source: The :class:`~flash.data.data_source.DataSource` which has the ``load_sample`` method.
running_stage: The current running stage.
"""

DATASET_KEY = "dataset"

def __init__(
self,
data: DATA_TYPE,
Expand Down Expand Up @@ -93,6 +92,8 @@ def _call_load_sample(self, sample: Any) -> Any:


class AutoDataset(BaseAutoDataset[Sequence], Dataset):
"""The ``AutoDataset`` is a ``BaseAutoDataset`` and a :class:`~torch.utils.data.Dataset`. The `data` argument
must be a ``Sequence`` (it must have a length)."""

def __getitem__(self, index: int) -> Any:
return self._call_load_sample(self.data[index])
Expand All @@ -102,6 +103,8 @@ def __len__(self) -> int:


class IterableAutoDataset(BaseAutoDataset[Iterable], IterableDataset):
"""The ``IterableAutoDataset`` is a ``BaseAutoDataset`` and a :class:`~torch.utils.data.IterableDataset`. The `data`
argument must be an ``Iterable``."""

def __iter__(self):
self.data_iter = iter(self.data)
Expand Down
32 changes: 21 additions & 11 deletions flash/data/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,16 @@ class BaseDataFetcher(FlashCallback):
from flash.data.callback import BaseDataFetcher
from flash.data.data_module import DataModule
from flash.data.data_source import DataSource
from flash.data.process import Preprocess
class CustomPreprocess(Preprocess):
def __init__(**kwargs):
super().__init__(
data_sources = {"inputs": DataSource()},
**kwargs,
)
class PrintData(BaseDataFetcher):
Expand All @@ -90,6 +99,8 @@ def print(self):
class CustomDataModule(DataModule):
preprocess_cls = CustomPreprocess
@staticmethod
def configure_data_fetcher():
return PrintData()
Expand All @@ -100,17 +111,16 @@ def from_inputs(
train_data: Any,
val_data: Any,
test_data: Any,
predict_data: Any) -> "CustomDataModule":
preprocess = CustomPreprocess()
return cls.from_load_data_inputs(
train_load_data_input=train_data,
val_load_data_input=val_data,
test_load_data_input=test_data,
predict_load_data_input=predict_data,
preprocess=preprocess,
batch_size=5)
predict_data: Any,
) -> "CustomDataModule":
return cls.from_data_source(
"inputs",
train_data=train_data,
val_data=val_data,
test_data=test_data,
predict_data=predict_data,
batch_size=5,
)
dm = CustomDataModule.from_inputs(range(5), range(5), range(5), range(5))
data_fetcher = dm.data_fetcher
Expand Down
Loading

0 comments on commit fb6402b

Please sign in to comment.