Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Add FlashDataset update #853

Merged
merged 4 commits into from
Oct 11, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 38 additions & 17 deletions flash/core/data/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Generic, Iterable, Optional, Sequence, Type, TypeVar
from abc import abstractmethod
from typing import Any, Callable, Iterable, Mapping, Optional, Type, Union

from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities.exceptions import MisconfigurationException
Expand All @@ -20,24 +21,26 @@
from flash.core.data.properties import Properties
from flash.core.registry import FlashRegistry

DATA_TYPE = TypeVar("DATA_TYPE")


class BaseDataset(Generic[DATA_TYPE], Properties):
class BaseDataset(Properties):

DATASET_KEY = "dataset"

transform_registry: Optional[FlashRegistry] = None

def load_data(self, data: Any) -> Any:
return data
@abstractmethod
def load_data(self, data: Any) -> Union[Iterable, Mapping]:
"""The `load_data` hook should return either a Mapping or an Iterable.

Override to add your dataset creation logic.
"""

@abstractmethod
def load_sample(self, data: Any) -> Any:
return data
"""The `load_sample` hook contains the logic to load a single sample."""

def __init__(self, running_stage: RunningStage) -> None:
super().__init__()

self.running_stage = running_stage

def pass_args_to_load_data(
Expand All @@ -60,7 +63,7 @@ def running_stage(self, running_stage: RunningStage) -> None:
def _resolve_functions(self, func_name: str, cls: Type["BaseDataset"]) -> None:
from flash.core.data.data_pipeline import DataPipeline # noqa F811

function: Callable[[DATA_TYPE, Optional[Any]], Any] = getattr(
function: Callable[[Any, Optional[Any]], Any] = getattr(
self,
DataPipeline._resolve_function_hierarchy(
func_name,
Expand All @@ -80,7 +83,7 @@ def _call_load_sample(self, sample: Any) -> Any:
def from_data(
cls,
*load_data_args,
running_stage: RunningStage = None,
running_stage: Optional[RunningStage] = None,
**dataset_kwargs: Any,
) -> "BaseDataset":
if not running_stage:
Expand All @@ -95,20 +98,30 @@ def from_data(
def resolve_functions(self):
raise NotImplementedError

_load_data = load_data
_load_sample = load_sample
# Set to None as they are dynamically resolved when the dataset is made stage aware,
# cf. `running_stage` is set in the `__init__` function.
_load_data = None
_load_sample = None
SeanNaren marked this conversation as resolved.
Show resolved Hide resolved


class FlashDataset(BaseDataset[Sequence], Dataset):
class FlashDataset(Dataset, BaseDataset):
"""The ``FlashDataset`` is a ``BaseDataset`` and a :class:`~torch.utils.data.Dataset`.

The `data` argument must be a ``Sequence`` (it must have a length).
The `data` argument must be a ``Mapping`` (it must have a length).
"""

def load_data(self, data: Any) -> Any:
def load_data(self, data: Any) -> Mapping:
"""By default, `load_data` performs an identity operation.

Override to add your own logic to load the data.
"""
return data

def load_sample(self, data: Any) -> Any:
"""By default, `load_sample` performs an identity operation.

Override to add your own logic to load a single sample.
"""
return data

def resolve_functions(self):
Expand All @@ -122,16 +135,24 @@ def __len__(self) -> int:
return len(self.data)


class FlashIterableDataset(BaseDataset[Iterable], IterableDataset):
class FlashIterableDataset(IterableDataset, BaseDataset):
"""The ``FlashIterableDataset`` is a ``BaseDataset`` and a :class:`~torch.utils.data.IterableDataset`.

The `data` argument must be an ``Iterable``.
"""

def load_data(self, data: Any) -> Any:
def load_data(self, data: Any) -> Iterable:
"""By default, `load_data` performs an identity operation.

Override to add your own logic to load the data.
"""
return data

def load_sample(self, data: Any) -> Any:
"""By default, `load_sample` performs an identity operation.

Override to add your own logic to load a single sample.
"""
return data

def resolve_functions(self):
Expand Down