From 1a31510620bfc21acd210144b0d0480d96710106 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 11 Oct 2021 14:59:58 +0100 Subject: [PATCH 1/4] update --- flash/core/data/datasets.py | 41 ++++++++++++++++++++++++++----------- 1 file changed, 29 insertions(+), 12 deletions(-) diff --git a/flash/core/data/datasets.py b/flash/core/data/datasets.py index d2adc39d14..4629cef5b9 100644 --- a/flash/core/data/datasets.py +++ b/flash/core/data/datasets.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Generic, Iterable, Optional, Sequence, Type, TypeVar +from abc import abstractmethod +from typing import Any, Callable, Generic, Iterable, Mapping, Optional, Type, TypeVar from pytorch_lightning.trainer.states import RunningStage from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -29,15 +30,19 @@ class BaseDataset(Generic[DATA_TYPE], Properties): transform_registry: Optional[FlashRegistry] = None - def load_data(self, data: Any) -> Any: - return data + @abstractmethod + def load_data(self, data: Any) -> DATA_TYPE: + """The `load_data` hook should return either a Mapping or an Iterable. + + Override to add your dataset logic creation logic. 
+ """ + @abstractmethod def load_sample(self, data: Any) -> Any: - return data + """The `load_sample` hook contains the logic to load a single sample.""" def __init__(self, running_stage: RunningStage) -> None: super().__init__() - self.running_stage = running_stage def pass_args_to_load_data( @@ -80,7 +85,7 @@ def _call_load_sample(self, sample: Any) -> Any: def from_data( cls, *load_data_args, - running_stage: RunningStage = None, + running_stage: Optional[RunningStage] = None, **dataset_kwargs: Any, ) -> "BaseDataset": if not running_stage: @@ -95,20 +100,26 @@ def from_data( def resolve_functions(self): raise NotImplementedError - _load_data = load_data - _load_sample = load_sample + _load_data = None + _load_sample = None -class FlashDataset(BaseDataset[Sequence], Dataset): +class FlashDataset(BaseDataset[Mapping], Dataset): """The ``FlashDataset`` is a ``BaseDataset`` and a :class:`~torch.utils.data.Dataset`. - The `data` argument must be a ``Sequence`` (it must have a length). + The `data` argument must be a ``Mapping`` (it must have a length). """ - def load_data(self, data: Any) -> Any: + def load_data(self, data: Any) -> Mapping: + """By default, the `load_data` perform an identity operation Override to add your own logic to load the + data.""" return data def load_sample(self, data: Any) -> Any: + """By default, the `load_sample` perform an identity operation. + + Override to add your own logic to load a single sample. + """ return data def resolve_functions(self): @@ -128,10 +139,16 @@ class FlashIterableDataset(BaseDataset[Iterable], IterableDataset): The `data` argument must be an ``Iterable``. """ - def load_data(self, data: Any) -> Any: + def load_data(self, data: Any) -> Iterable: + """By default, the `load_data` perform an identity operation Override to add your own logic to load the + data.""" return data def load_sample(self, data: Any) -> Any: + """By default, the `load_sample` perform an identity operation. 
+ + Override to add your own logic to load a single sample. + """ return data def resolve_functions(self): From 54182d2a7cd0e28574bc826167d731141e4f3d0b Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 11 Oct 2021 15:01:47 +0100 Subject: [PATCH 2/4] update --- flash/core/data/datasets.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/flash/core/data/datasets.py b/flash/core/data/datasets.py index 4629cef5b9..6417680d4c 100644 --- a/flash/core/data/datasets.py +++ b/flash/core/data/datasets.py @@ -111,8 +111,10 @@ class FlashDataset(BaseDataset[Mapping], Dataset): """ def load_data(self, data: Any) -> Mapping: - """By default, the `load_data` perform an identity operation Override to add your own logic to load the - data.""" + """By default, the `load_data` perform an identity operation. + + Override to add your own logic to load the data. + """ return data def load_sample(self, data: Any) -> Any: @@ -140,8 +142,10 @@ class FlashIterableDataset(BaseDataset[Iterable], IterableDataset): """ def load_data(self, data: Any) -> Iterable: - """By default, the `load_data` perform an identity operation Override to add your own logic to load the - data.""" + """By default, the `load_data` perform an identity operation. + + Override to add your own logic to load the data. + """ return data def load_sample(self, data: Any) -> Any: From b7c4297dc7cc6892a59d2d0b4f17c26af87e8d3d Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 11 Oct 2021 15:07:22 +0100 Subject: [PATCH 3/4] update on comments --- flash/core/data/datasets.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/flash/core/data/datasets.py b/flash/core/data/datasets.py index 6417680d4c..052a0d7e9e 100644 --- a/flash/core/data/datasets.py +++ b/flash/core/data/datasets.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
from abc import abstractmethod -from typing import Any, Callable, Generic, Iterable, Mapping, Optional, Type, TypeVar +from typing import Any, Callable, Iterable, Mapping, Optional, Type, TypeVar, Union from pytorch_lightning.trainer.states import RunningStage from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -24,14 +24,14 @@ DATA_TYPE = TypeVar("DATA_TYPE") -class BaseDataset(Generic[DATA_TYPE], Properties): +class BaseDataset(Properties): DATASET_KEY = "dataset" transform_registry: Optional[FlashRegistry] = None @abstractmethod - def load_data(self, data: Any) -> DATA_TYPE: + def load_data(self, data: Any) -> Union[Iterable, Mapping]: """The `load_data` hook should return either a Mapping or an Iterable. Override to add your dataset logic creation logic. @@ -104,7 +104,7 @@ def resolve_functions(self): _load_sample = None -class FlashDataset(BaseDataset[Mapping], Dataset): +class FlashDataset(Dataset, BaseDataset): """The ``FlashDataset`` is a ``BaseDataset`` and a :class:`~torch.utils.data.Dataset`. The `data` argument must be a ``Mapping`` (it must have a length). @@ -135,7 +135,7 @@ def __len__(self) -> int: return len(self.data) -class FlashIterableDataset(BaseDataset[Iterable], IterableDataset): +class FlashIterableDataset(IterableDataset, BaseDataset): """The ``IterableAutoDataset`` is a ``BaseDataset`` and a :class:`~torch.utils.data.IterableDataset`. The `data` argument must be an ``Iterable``. From d0b9086b45d3c2311f5b8e2d8594ef77f15eb5f6 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 11 Oct 2021 15:11:09 +0100 Subject: [PATCH 4/4] update --- flash/core/data/datasets.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/flash/core/data/datasets.py b/flash/core/data/datasets.py index 052a0d7e9e..a1a1ef6b5b 100644 --- a/flash/core/data/datasets.py +++ b/flash/core/data/datasets.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
from abc import abstractmethod -from typing import Any, Callable, Iterable, Mapping, Optional, Type, TypeVar, Union +from typing import Any, Callable, Iterable, Mapping, Optional, Type, Union from pytorch_lightning.trainer.states import RunningStage from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -21,8 +21,6 @@ from flash.core.data.properties import Properties from flash.core.registry import FlashRegistry -DATA_TYPE = TypeVar("DATA_TYPE") - class BaseDataset(Properties): @@ -65,7 +63,7 @@ def running_stage(self, running_stage: RunningStage) -> None: def _resolve_functions(self, func_name: str, cls: Type["BaseDataset"]) -> None: from flash.core.data.data_pipeline import DataPipeline # noqa F811 - function: Callable[[DATA_TYPE, Optional[Any]], Any] = getattr( + function: Callable[[Any, Optional[Any]], Any] = getattr( self, DataPipeline._resolve_function_hierarchy( func_name, @@ -100,6 +98,8 @@ def from_data( def resolve_functions(self): raise NotImplementedError + # Set to None as they are dynamically resolved when the dataset is made stage aware + # (cf. `running_stage`, which is set in the `__init__` function). _load_data = None _load_sample = None