diff --git a/CHANGELOG.md b/CHANGELOG.md index b74da5e12c..b733661ba5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Changed +- Changed `Preprocess` to `InputTransform` ([#951](https://github.com/PyTorchLightning/lightning-flash/pull/951)) + - Changed classes named `*Serializer` and properties / variables named `serializer` to be `*Output` and `output` respectively ([#927](https://github.com/PyTorchLightning/lightning-flash/pull/927)) - Changed `Postprocess` to `OutputTransform` ([#942](https://github.com/PyTorchLightning/lightning-flash/pull/942)) diff --git a/README.md b/README.md index 743a88c13b..5056694562 100644 --- a/README.md +++ b/README.md @@ -227,7 +227,7 @@ model = ImageClassifier(backbone="resnet18", num_classes=2, optimizer="Adam", lr Flash includes some simple augmentations for each task by default, however, you will often want to override these and control your own augmentation recipe. To this end, Flash supports custom transformations backed by our powerful data pipeline. -The transform requires to be passed as a dictionary of transforms where the keys are the [hook's name](https://lightning-flash.readthedocs.io/en/latest/api/generated/flash.core.data.process.Preprocess.html?highlight=Preprocess). +The transforms must be passed as a dictionary where the keys are the [hook names](https://lightning-flash.readthedocs.io/en/latest/api/generated/flash.core.data.io.input_transform.InputTransform.html?highlight=InputTransform). This enable transforms to be applied per sample or per batch either on or off device. It is important to note that data are being processed as a dictionary for all tasks (typically containing `input`, `target`, and `metadata`), Therefore, you can use [`ApplyToKeys`](https://lightning-flash.readthedocs.io/en/latest/api/generated/flash.core.data.transforms.ApplyToKeys.html#flash.core.data.transforms.ApplyToKeys) utility to apply the transform to a specific key. 
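To make the new README wording concrete, here is a minimal sketch of such a transform dictionary after the rename — the hook name ``post_tensor_transform``, the augmentations, and the folder paths are illustrative placeholders rather than library defaults:

.. code-block:: python

    import torchvision.transforms as T

    from flash.core.data.transforms import ApplyToKeys
    from flash.image import ImageClassificationData

    # Keys are InputTransform hook names; ApplyToKeys restricts each transform
    # to the "input" entry of the sample dictionary, leaving "target" untouched.
    train_transform = {
        "post_tensor_transform": ApplyToKeys(
            "input",
            T.Compose([T.RandomHorizontalFlip(p=0.5), T.ColorJitter(brightness=0.2)]),
        ),
    }

    datamodule = ImageClassificationData.from_folders(
        train_folder="data/hymenoptera_data/train/",  # placeholder path
        val_folder="data/hymenoptera_data/val/",
        train_transform=train_transform,
    )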
diff --git a/docs/extensions/autodatasources.py b/docs/extensions/autodatasources.py index e2e00b4017..b1bc85a1ff 100644 --- a/docs/extensions/autodatasources.py +++ b/docs/extensions/autodatasources.py @@ -53,17 +53,17 @@ def run(self): data_module = getattr(importlib.import_module(data_module_path), data_module_name) - class PatchedPreprocess(data_module.preprocess_cls): + class PatchedInputTransform(data_module.input_transform_cls): """TODO: This is a hack to prevent default transforms form being created""" @staticmethod def _resolve_transforms(_): return None - preprocess = PatchedPreprocess() + input_transform = PatchedInputTransform() data_sources = { - data_source: preprocess.data_source_of_name(data_source) - for data_source in preprocess.available_data_sources() + data_source: input_transform.data_source_of_name(data_source) + for data_source in input_transform.available_data_sources() } ENVIRONMENT.get_template("base.rst") diff --git a/docs/source/api/audio.rst b/docs/source/api/audio.rst index a5aec6cacc..16cdf31d88 100644 --- a/docs/source/api/audio.rst +++ b/docs/source/api/audio.rst @@ -18,7 +18,7 @@ ______________ :template: classtemplate.rst ~classification.data.AudioClassificationData - ~classification.data.AudioClassificationPreprocess + ~classification.data.AudioClassificationInputTransform Speech Recognition __________________ @@ -31,7 +31,7 @@ __________________ ~speech_recognition.data.SpeechRecognitionData ~speech_recognition.model.SpeechRecognition - speech_recognition.data.SpeechRecognitionPreprocess + speech_recognition.data.SpeechRecognitionInputTransform speech_recognition.data.SpeechRecognitionBackboneState speech_recognition.data.SpeechRecognitionOutputTransform speech_recognition.data.SpeechRecognitionCSVDataSource diff --git a/docs/source/api/data.rst b/docs/source/api/data.rst index 08a59a4d30..0d56e52cdd 100644 --- a/docs/source/api/data.rst +++ b/docs/source/api/data.rst @@ -118,12 +118,13 @@ _______________________ :nosignatures: :template: classtemplate.rst - ~flash.core.data.process.BasePreprocess - ~flash.core.data.process.DefaultPreprocess + ~flash.core.data.io.input_transform.BaseInputTransform + ~flash.core.data.io.input_transform.DefaultInputTransform ~flash.core.data.process.DeserializerMapping ~flash.core.data.process.Deserializer ~flash.core.data.io.output_transform.OutputTransform - ~flash.core.data.process.Preprocess + ~flash.core.data.io.input_transform.InputTransform + ~flash.core.data.process.Serializer flash.core.data.properties __________________________ diff --git a/docs/source/api/flash.rst b/docs/source/api/flash.rst index f61471eeef..c83d6fb5f4 100644 --- a/docs/source/api/flash.rst +++ b/docs/source/api/flash.rst @@ -10,8 +10,8 @@ flash ~flash.core.data.data_source.DataSource ~flash.core.data.data_module.DataModule ~flash.core.data.callback.FlashCallback - ~flash.core.data.process.Preprocess ~flash.core.data.io.output_transform.OutputTransform ~flash.core.data.io.output.Output + ~flash.core.data.io.input_transform.InputTransform ~flash.core.model.Task ~flash.core.trainer.Trainer diff --git a/docs/source/api/graph.rst b/docs/source/api/graph.rst index bf94475ab2..74becb27df 100644 --- a/docs/source/api/graph.rst +++ b/docs/source/api/graph.rst @@ -20,7 +20,7 @@ ______________ ~classification.model.GraphClassifier ~classification.data.GraphClassificationData - classification.data.GraphClassificationPreprocess + classification.data.GraphClassificationInputTransform flash.graph.data ________________ diff --git 
a/docs/source/api/image.rst b/docs/source/api/image.rst index 3cedc69058..7d1280a3f3 100644 --- a/docs/source/api/image.rst +++ b/docs/source/api/image.rst @@ -19,7 +19,7 @@ ______________ ~classification.model.ImageClassifier ~classification.data.ImageClassificationData - ~classification.data.ImageClassificationPreprocess + ~classification.data.ImageClassificationInputTransform classification.data.MatplotlibVisualization @@ -44,8 +44,8 @@ ________________ detection.data.FiftyOneParser detection.data.ObjectDetectionFiftyOneDataSource - detection.data.ObjectDetectionPreprocess detection.output.FiftyOneDetectionLabels + detection.data.ObjectDetectionInputTransform Keypoint Detection __________________ @@ -58,7 +58,7 @@ __________________ ~keypoint_detection.model.KeypointDetector ~keypoint_detection.data.KeypointDetectionData - keypoint_detection.data.KeypointDetectionPreprocess + keypoint_detection.data.KeypointDetectionInputTransform Instance Segmentation _____________________ @@ -71,7 +71,7 @@ _____________________ ~instance_segmentation.model.InstanceSegmentation ~instance_segmentation.data.InstanceSegmentationData - instance_segmentation.data.InstanceSegmentationPreprocess + instance_segmentation.data.InstanceSegmentationInputTransform Embedding _________ @@ -93,7 +93,7 @@ ____________ ~segmentation.model.SemanticSegmentation ~segmentation.data.SemanticSegmentationData - ~segmentation.data.SemanticSegmentationPreprocess + ~segmentation.data.SemanticSegmentationInputTransform segmentation.data.SegmentationMatplotlibVisualization segmentation.data.SemanticSegmentationNumpyDataSource @@ -123,7 +123,7 @@ ______________ ~style_transfer.model.StyleTransfer ~style_transfer.data.StyleTransferData - ~style_transfer.data.StyleTransferPreprocess + ~style_transfer.data.StyleTransferInputTransform .. 
autosummary:: :toctree: generated/ diff --git a/docs/source/api/pointcloud.rst b/docs/source/api/pointcloud.rst index d3c7b94797..dc4b777423 100644 --- a/docs/source/api/pointcloud.rst +++ b/docs/source/api/pointcloud.rst @@ -20,7 +20,7 @@ ____________ ~segmentation.model.PointCloudSegmentation ~segmentation.data.PointCloudSegmentationData - segmentation.data.PointCloudSegmentationPreprocess + segmentation.data.PointCloudSegmentationInputTransform segmentation.data.PointCloudSegmentationFoldersDataSource segmentation.data.PointCloudSegmentationDatasetDataSource @@ -35,6 +35,6 @@ ________________ ~detection.model.PointCloudObjectDetector ~detection.data.PointCloudObjectDetectorData - detection.data.PointCloudObjectDetectorPreprocess + detection.data.PointCloudObjectDetectorInputTransform detection.data.PointCloudObjectDetectorFoldersDataSource detection.data.PointCloudObjectDetectorDatasetDataSource diff --git a/docs/source/api/tabular.rst b/docs/source/api/tabular.rst index 1b8b8add8b..a258890495 100644 --- a/docs/source/api/tabular.rst +++ b/docs/source/api/tabular.rst @@ -42,7 +42,7 @@ ___________ ~forecasting.model.TabularForecaster ~forecasting.data.TabularForecastingData - forecasting.data.TabularForecastingPreprocess + forecasting.data.TabularForecastingInputTransform forecasting.data.TabularForecastingDataFrameDataSource forecasting.data.TimeSeriesDataSetParametersState @@ -58,5 +58,5 @@ __________________ ~data.TabularDataFrameDataSource ~data.TabularCSVDataSource ~data.TabularDeserializer - ~data.TabularPreprocess ~data.TabularOutputTransform + ~data.TabularInputTransform diff --git a/docs/source/api/text.rst b/docs/source/api/text.rst index 50750dad05..d692994aa8 100644 --- a/docs/source/api/text.rst +++ b/docs/source/api/text.rst @@ -21,7 +21,7 @@ ______________ ~classification.data.TextClassificationData classification.data.TextClassificationOutputTransform - classification.data.TextClassificationPreprocess + classification.data.TextClassificationInputTransform classification.data.TextDeserializer classification.data.TextDataSource classification.data.TextCSVDataSource @@ -49,7 +49,7 @@ __________________ question_answering.data.QuestionAnsweringFileDataSource question_answering.data.QuestionAnsweringJSONDataSource question_answering.data.QuestionAnsweringOutputTransform - question_answering.data.QuestionAnsweringPreprocess + question_answering.data.QuestionAnsweringInputTransform question_answering.data.SQuADDataSource @@ -64,7 +64,7 @@ _____________ ~seq2seq.summarization.model.SummarizationTask ~seq2seq.summarization.data.SummarizationData - seq2seq.summarization.data.SummarizationPreprocess + seq2seq.summarization.data.SummarizationInputTransform Translation ___________ @@ -77,7 +77,7 @@ ___________ ~seq2seq.translation.model.TranslationTask ~seq2seq.translation.data.TranslationData - seq2seq.translation.data.TranslationPreprocess + seq2seq.translation.data.TranslationInputTransform General Seq2Seq _______________ @@ -97,5 +97,5 @@ _______________ seq2seq.core.data.Seq2SeqFileDataSource seq2seq.core.data.Seq2SeqJSONDataSource seq2seq.core.data.Seq2SeqOutputTransform - seq2seq.core.data.Seq2SeqPreprocess + seq2seq.core.data.Seq2SeqInputTransform seq2seq.core.data.Seq2SeqSentencesDataSource diff --git a/docs/source/api/video.rst b/docs/source/api/video.rst index ade63234ca..f9825041f0 100644 --- a/docs/source/api/video.rst +++ b/docs/source/api/video.rst @@ -23,5 +23,5 @@ ______________ classification.data.BaseVideoClassification 
classification.data.VideoClassificationFiftyOneDataSource classification.data.VideoClassificationPathsDataSource - classification.data.VideoClassificationPreprocess + classification.data.VideoClassificationInputTransform classification.model.VideoClassifierFinetuning diff --git a/docs/source/general/data.rst b/docs/source/general/data.rst index ff2ca87fb8..5dc0e71387 100644 --- a/docs/source/general/data.rst +++ b/docs/source/general/data.rst @@ -26,14 +26,14 @@ Here are common terms you need to be familiar with: * - :class:`~flash.core.data.data_module.DataModule` - The :class:`~flash.core.data.data_module.DataModule` contains the datasets, transforms and dataloaders. * - :class:`~flash.core.data.data_pipeline.DataPipeline` - - The :class:`~flash.core.data.data_pipeline.DataPipeline` is Flash internal object to manage :class:`~flash.core.data.Deserializer`, :class:`~flash.core.data.data_source.DataSource`, :class:`~flash.core.data.process.Preprocess`, :class:`~flash.core.data.io.output_transform.OutputTransform`, and :class:`~flash.core.data.io.output.Output` objects. + - The :class:`~flash.core.data.data_pipeline.DataPipeline` is Flash internal object to manage :class:`~flash.core.data.Deserializer`, :class:`~flash.core.data.data_source.DataSource`, :class:`~flash.core.data.io.input_transform.InputTransform`, :class:`~flash.core.data.io.output_transform.OutputTransform`, and :class:`~flash.core.data.io.output.Output` objects. * - :class:`~flash.core.data.data_source.DataSource` - The :class:`~flash.core.data.data_source.DataSource` provides :meth:`~flash.core.data.data_source.DataSource.load_data` and :meth:`~flash.core.data.data_source.DataSource.load_sample` hooks for creating data sets from metadata (such as folder names). - * - :class:`~flash.core.data.process.Preprocess` - - The :class:`~flash.core.data.process.Preprocess` provides a simple hook-based API to encapsulate your pre-processing logic. - These hooks (such as :meth:`~flash.core.data.process.Preprocess.pre_tensor_transform`) enable transformations to be applied to your data at every point along the pipeline (including on the device). + * - :class:`~flash.core.data.io.input_transform.InputTransform` + - The :class:`~flash.core.data.io.input_transform.InputTransform` provides a simple hook-based API to encapsulate your pre-processing logic. + These hooks (such as :meth:`~flash.core.data.io.input_transform.InputTransform.pre_tensor_transform`) enable transformations to be applied to your data at every point along the pipeline (including on the device). The :class:`~flash.core.data.data_pipeline.DataPipeline` contains a system to call the right hooks when needed. - The :class:`~flash.core.data.process.Preprocess` hooks can be either overridden directly or provided as a dictionary of transforms (mapping hook name to callable transform). + The :class:`~flash.core.data.io.input_transform.InputTransform` hooks can be either overridden directly or provided as a dictionary of transforms (mapping hook name to callable transform). * - :class:`~flash.core.data.io.output_transform.OutputTransform` - The :class:`~flash.core.data.io.output_transform.OutputTransform` provides a simple hook-based API to encapsulate your post-processing logic. The :class:`~flash.core.data.io.output_transform.OutputTransform` hooks cover from model outputs to predictions export. 
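To make the ``InputTransform`` row concrete — its hooks can either be overridden directly or supplied as a dictionary of transforms — here is a minimal sketch in which ``my_post_tensor_transform`` is a placeholder callable:

.. code-block:: python

    from typing import Any, Dict

    from flash.core.data.io.input_transform import InputTransform


    def my_post_tensor_transform(sample: Dict[str, Any]) -> Dict[str, Any]:
        # Placeholder: augment sample["input"] here and return the sample.
        return sample


    # Option 1: override the hook directly in an InputTransform subclass.
    class MyInputTransform(InputTransform):
        def post_tensor_transform(self, sample: Dict[str, Any]) -> Dict[str, Any]:
            return my_post_tensor_transform(sample)


    # Option 2: provide the same callable as a dictionary of transforms,
    # keyed by hook name (e.g. passed as ``train_transform=`` to a DataModule).
    train_transform = {"post_tensor_transform": my_post_tensor_transform}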
@@ -58,7 +58,8 @@ However, after model training, it requires a lot of engineering overhead to make Usually, extra processing logic should be added to bridge the gap between training data and raw data. The :class:`~flash.core.data.data_source.DataSource` class can be used to generate data sets from multiple sources (e.g. folders, numpy, etc.), that can then all be transformed in the same way. -The :class:`~flash.core.data.process.Preprocess` and :class:`~flash.core.data.io.output_transform.OutputTransform` classes can be used to manage the preprocessing and postprocessing transforms. + +The :class:`~flash.core.data.io.input_transform.InputTransform` and :class:`~flash.core.data.io.output_transform.OutputTransform` classes can be used to manage the input and output transforms. The :class:`~flash.core.data.io.output.Output` class provides the logic for converting :class:`~flash.core.data.io.output_transform.OutputTransform` outputs to the desired predict format (e.g. classes, labels, probabilities, etc.). By providing a series of hooks that can be overridden with custom data processing logic (or just targeted with transforms), @@ -72,10 +73,10 @@ Here are the primary advantages: To change the processing behavior only on specific stages for a given hook, -you can prefix each of the :class:`~flash.core.data.process.Preprocess` and :class:`~flash.core.data.io.output_transform.OutputTransform` +you can prefix each of the :class:`~flash.core.data.io.input_transform.InputTransform` and :class:`~flash.core.data.io.output_transform.OutputTransform` hooks by adding ``train``, ``val``, ``test`` or ``predict``. -Check out :class:`~flash.core.data.process.Preprocess` for some examples. +Check out :class:`~flash.core.data.io.input_transform.InputTransform` for some examples. ************************************* How to customize existing DataModules @@ -93,7 +94,7 @@ Any Flash :class:`~flash.core.data.data_module.DataModule` can be created direct The :class:`~flash.core.data.data_module.DataModule` provides additional ``classmethod`` helpers (``from_*``) for loading data from various sources. -In each ``from_*`` method, the :class:`~flash.core.data.data_module.DataModule` internally retrieves the correct :class:`~flash.core.data.data_source.DataSource` to use from the :class:`~flash.core.data.process.Preprocess`. +In each ``from_*`` method, the :class:`~flash.core.data.data_module.DataModule` internally retrieves the correct :class:`~flash.core.data.data_source.DataSource` to use from the :class:`~flash.core.data.io.input_transform.InputTransform`. Flash :class:`~flash.core.data.auto_dataset.AutoDataset` instances are created from the :class:`~flash.core.data.data_source.DataSource` for train, val, test, and predict. The :class:`~flash.core.data.data_module.DataModule` populates the ``DataLoader`` for each stage with the corresponding :class:`~flash.core.data.auto_dataset.AutoDataset`. @@ -101,8 +102,8 @@ The :class:`~flash.core.data.data_module.DataModule` populates the ``DataLoader` Customize preprocessing of DataModules ************************************** -The :class:`~flash.core.data.process.Preprocess` contains the processing logic related to a given task. -Each :class:`~flash.core.data.process.Preprocess` provides some default transforms through the :meth:`~flash.core.data.process.Preprocess.default_transforms` method. +The :class:`~flash.core.data.io.input_transform.InputTransform` contains the processing logic related to a given task. 
+Each :class:`~flash.core.data.io.input_transform.InputTransform` provides some default transforms through the :meth:`~flash.core.data.io.input_transform.InputTransform.default_transforms` method. Users can easily override these by providing their own transforms to the :class:`~flash.core.data.data_module.DataModule`. Here's an example: @@ -127,10 +128,10 @@ Alternatively, the user may directly override the hooks for their needs like thi .. code-block:: python from typing import Any, Dict - from flash.image import ImageClassificationData, ImageClassifier, ImageClassificationPreprocess + from flash.image import ImageClassificationData, ImageClassifier, ImageClassificationInputTransform - class CustomImageClassificationPreprocess(ImageClassificationPreprocess): + class CustomImageClassificationInputTransform(ImageClassificationInputTransform): def to_tensor_transform(sample: Dict[str, Any]) -> Dict[str, Any]: sample["input"] = my_to_tensor_transform(sample["input"]) return sample @@ -140,15 +141,15 @@ Alternatively, the user may directly override the hooks for their needs like thi train_folder="data/hymenoptera_data/train/", val_folder="data/hymenoptera_data/val/", test_folder="data/hymenoptera_data/test/", - preprocess=CustomImageClassificationPreprocess(), + input_transform=CustomImageClassificationInputTransform(), ) -***************************************** -Create your own Preprocess and DataModule -***************************************** +********************************************* +Create your own InputTransform and DataModule +********************************************* -The example below shows a very simple ``ImageClassificationPreprocess`` with a single ``ImageClassificationFoldersDataSource`` and an ``ImageClassificationDataModule``. +The example below shows a very simple ``ImageClassificationInputTransform`` with a single ``ImageClassificationFoldersDataSource`` and an ``ImageClassificationDataModule``. 1. User-Facing API design _________________________ @@ -226,20 +227,20 @@ Here's the full ``ImageClassificationFoldersDataSource``: .. note:: We return samples as dictionaries using the :class:`~flash.core.data.data_source.DefaultDataKeys` by convention. This is the recommended (although not required) way to represent data in Flash. -3. The Preprocess -__________________ +3. The InputTransform +_____________________ -Next, implement your custom ``ImageClassificationPreprocess`` with some default transforms and a reference to the data source: +Next, implement your custom ``ImageClassificationInputTransform`` with some default transforms and a reference to the data source: .. code-block:: python from typing import Any, Callable, Dict, Optional from flash.core.data.data_source import DefaultDataKeys, DefaultDataSources - from flash.core.data.process import Preprocess + from flash.core.data.io.input_transform import InputTransform import torchvision.transforms.functional as T - # Subclass `Preprocess` - class ImageClassificationPreprocess(Preprocess): + # Subclass `InputTransform` + class ImageClassificationInputTransform(InputTransform): def __init__( self, train_transform: Optional[Dict[str, Callable]] = None, @@ -272,8 +273,8 @@ Next, implement your custom ``ImageClassificationPreprocess`` with some default _________________ Finally, let's implement the ``ImageClassificationDataModule``. -We get the ``from_folders`` classmethod for free as we've registered a ``DefaultDataSources.FOLDERS`` data source in our ``ImageClassificationPreprocess``. 
-All we need to do is attach our :class:`~flash.core.data.process.Preprocess` class like this: +We get the ``from_folders`` classmethod for free as we've registered a ``DefaultDataSources.FOLDERS`` data source in our ``ImageClassificationInputTransform``. +All we need to do is attach our :class:`~flash.core.data.io.input_transform.InputTransform` class like this: .. code-block:: python @@ -282,8 +283,8 @@ All we need to do is attach our :class:`~flash.core.data.process.Preprocess` cla class ImageClassificationDataModule(DataModule): - # Set `preprocess_cls` with your custom `Preprocess`. - preprocess_cls = ImageClassificationPreprocess + # Set `input_transform_cls` with your custom `InputTransform`. + input_transform_cls = ImageClassificationInputTransform ****************************** @@ -319,27 +320,27 @@ Here is the :class:`~flash.core.data.auto_dataset.AutoDataset` pseudo-code. def __len__(self): return len(self.data) -Preprocess -__________ +InputTransform +______________ .. note:: - The :meth:`~flash.core.data.process.Preprocess.pre_tensor_transform`, - :meth:`~flash.core.data.process.Preprocess.to_tensor_transform`, - :meth:`~flash.core.data.process.Preprocess.post_tensor_transform`, - :meth:`~flash.core.data.process.Preprocess.collate`, - :meth:`~flash.core.data.process.Preprocess.per_batch_transform` are injected as the + The :meth:`~flash.core.data.io.input_transform.InputTransform.pre_tensor_transform`, + :meth:`~flash.core.data.io.input_transform.InputTransform.to_tensor_transform`, + :meth:`~flash.core.data.io.input_transform.InputTransform.post_tensor_transform`, + :meth:`~flash.core.data.io.input_transform.InputTransform.collate`, + :meth:`~flash.core.data.io.input_transform.InputTransform.per_batch_transform` are injected as the :paramref:`torch.utils.data.DataLoader.collate_fn` function of the DataLoader. -Here is the pseudo code using the preprocess hooks name. +Here is the pseudo code using the input transform hook names. Flash takes care of calling the right hooks for each stage. Example:: - # This will be wrapped into a :class:`~flash.core.data.batch._Preprocessor`. + # This will be wrapped into a :class:`~flash.core.data.io.input_transform._InputTransformProcessor`. def collate_fn(samples: Sequence[Any]) -> Any: - # This will be wrapped into a :class:`~flash.core.data.batch._Sequential` + # This will be wrapped into a :class:`~flash.core.data.io.input_transform._InputTransformSequential` for sample in samples: sample = pre_tensor_transform(sample) sample = to_tensor_transform(sample) @@ -347,7 +348,7 @@ Example:: samples = type(samples)(samples) - # if :func:`flash.core.data.process.Preprocess.per_sample_transform_on_device` hook is overridden, + # if :func:`flash.core.data.io.input_transform.InputTransform.per_sample_transform_on_device` hook is overridden, # those functions below will be no-ops samples = collate(samples) @@ -361,12 +362,12 @@ Example:: The ``per_sample_transform_on_device``, ``collate``, ``per_batch_transform_on_device`` are injected after the ``LightningModule`` ``transfer_batch_to_device`` hook. -Here is the pseudo code using the preprocess hooks name. +Here is the pseudo code using the input transform hook names. Flash takes care of calling the right hooks for each stage. 
Example:: - # This will be wrapped into a :class:`~flash.core.data.batch._Preprocessor` + # This will be wrapped into a :class:`~flash.core.data.io.input_transform._InputTransformProcessor` def collate_fn(samples: Sequence[Any]) -> Any: # if ``per_batch_transform`` hook is overridden, those functions below will be no-ops diff --git a/docs/source/reference/image_classification.rst b/docs/source/reference/image_classification.rst index 93e2983a4e..fc10f2b713 100644 --- a/docs/source/reference/image_classification.rst +++ b/docs/source/reference/image_classification.rst @@ -103,7 +103,7 @@ Custom Transformations ********************** Flash automatically applies some default image transformations and augmentations, but you may wish to customize these for your own use case. -The base :class:`~flash.core.data.process.Preprocess` defines 7 hooks for different stages in the data loading pipeline. +The base :class:`~flash.core.data.io.input_transform.InputTransform` defines 7 hooks for different stages in the data loading pipeline. To apply image augmentations you can directly import the ``default_transforms`` from ``flash.image.classification.transforms`` and then merge your custom image transformations with them using the :func:`~flash.core.data.transforms.merge_transforms` helper function. Here's an example where we load the default transforms and merge with custom `torchvision` transformations. We use the `post_tensor_transform` hook to apply the transformations after the image has been converted to a `torch.Tensor`. diff --git a/docs/source/reference/object_detection.rst b/docs/source/reference/object_detection.rst index dd70e4bfed..1203b582c7 100644 --- a/docs/source/reference/object_detection.rst +++ b/docs/source/reference/object_detection.rst @@ -81,7 +81,7 @@ Custom Transformations ********************** Flash automatically applies some default image / mask transformations and augmentations, but you may wish to customize these for your own use case. -The base :class:`~flash.core.data.process.Preprocess` defines 7 hooks for different stages in the data loading pipeline. +The base :class:`~flash.core.data.io.input_transform.InputTransform` defines 7 hooks for different stages in the data loading pipeline. For object-detection tasks, you can leverage the transformations from `Albumentations `__ with the :class:`~flash.core.integrations.icevision.transforms.IceVisionTransformAdapter`. .. code-block:: python diff --git a/docs/source/template/data.rst b/docs/source/template/data.rst index 1cd7fcd10c..d5eb6b03e6 100644 --- a/docs/source/template/data.rst +++ b/docs/source/template/data.rst @@ -8,7 +8,7 @@ The first step to contributing a task is to implement the classes we need to loa Inside `data.py `_ you should implement: #. some :class:`~flash.core.data.data_source.DataSource` classes *(optional)* -#. a :class:`~flash.core.data.process.Preprocess` +#. a :class:`~flash.core.data.io.input_transform.InputTransform` #. a :class:`~flash.core.data.data_module.DataModule` #. a :class:`~flash.core.data.base_viz.BaseVisualization` *(optional)* #. a :class:`~flash.core.data.io.output_transform.OutputTransform` *(optional)* @@ -74,7 +74,7 @@ DataSource vs Dataset ~~~~~~~~~~~~~~~~~~~~~ A :class:`~flash.core.data.data_source.DataSource` is not the same as a :class:`torch.utils.data.Dataset`. 
-When a ``from_*`` method is called on your :class:`~flash.core.data.data_module.DataModule`, it gets the :class:`~flash.core.data.data_source.DataSource` to use from the :class:`~flash.core.data.process.Preprocess`. +When a ``from_*`` method is called on your :class:`~flash.core.data.data_module.DataModule`, it gets the :class:`~flash.core.data.data_source.DataSource` to use from the :class:`~flash.core.data.io.input_transform.InputTransform`. A :class:`~torch.utils.data.Dataset` is then created from the :class:`~flash.core.data.data_source.DataSource` for each stage (`train`, `val`, `test`, `predict`) using the provided metadata (e.g. folder name, numpy array etc.). The output of the :meth:`~flash.core.data.data_source.DataSource.load_data` can just be a :class:`torch.utils.data.Dataset` instance. @@ -87,14 +87,14 @@ Here's how it looks (from `video/classification.data.py `_ which creates some default transforms given the desired image size: .. literalinclude:: ../../../flash/image/classification/transforms.py :language: python :pyobject: default_transforms -Here's how we create our transforms in the :class:`~flash.image.classification.data.ImageClassificationPreprocess`: +Here's how we create our transforms in the :class:`~flash.image.classification.data.ImageClassificationInputTransform`: .. literalinclude:: ../../../flash/image/classification/data.py :language: python - :pyobject: ImageClassificationPreprocess.default_transforms + :pyobject: ImageClassificationInputTransform.default_transforms Add outputs to your Task ======================== diff --git a/flash/__init__.py b/flash/__init__.py index 77687c1e41..5b5414f9f3 100644 --- a/flash/__init__.py +++ b/flash/__init__.py @@ -23,10 +23,10 @@ from flash.core.data.data_module import DataModule # noqa: E402 from flash.core.data.data_source import DataSource from flash.core.data.datasets import FlashDataset, FlashIterableDataset - from flash.core.data.input_transform import InputTransform + from flash.core.data.io.input_transform import InputTransform from flash.core.data.io.output import Output from flash.core.data.io.output_transform import OutputTransform - from flash.core.data.process import Preprocess, Serializer + from flash.core.data.process import Serializer from flash.core.model import Task # noqa: E402 from flash.core.trainer import Trainer # noqa: E402 @@ -49,7 +49,6 @@ "InputTransform", "Output", "OutputTransform", - "Preprocess", "Serializer", "Task", "Trainer", diff --git a/flash/audio/__init__.py b/flash/audio/__init__.py index b90bc6d06e..9aff9a8b4a 100644 --- a/flash/audio/__init__.py +++ b/flash/audio/__init__.py @@ -1,2 +1,2 @@ -from flash.audio.classification import AudioClassificationData, AudioClassificationPreprocess # noqa: F401 +from flash.audio.classification import AudioClassificationData, AudioClassificationInputTransform # noqa: F401 from flash.audio.speech_recognition import SpeechRecognition, SpeechRecognitionData # noqa: F401 diff --git a/flash/audio/classification/__init__.py b/flash/audio/classification/__init__.py index 476a303d49..13f8fb612b 100644 --- a/flash/audio/classification/__init__.py +++ b/flash/audio/classification/__init__.py @@ -1 +1 @@ -from flash.audio.classification.data import AudioClassificationData, AudioClassificationPreprocess # noqa: F401 +from flash.audio.classification.data import AudioClassificationData, AudioClassificationInputTransform # noqa: F401 diff --git a/flash/audio/classification/cli.py b/flash/audio/classification/cli.py index c69b1e540c..ac84ffc21d 100644 --- 
a/flash/audio/classification/cli.py +++ b/flash/audio/classification/cli.py @@ -23,7 +23,7 @@ def from_urban8k( batch_size: int = 4, num_workers: int = 0, - **preprocess_kwargs, + **input_transform_kwargs, ) -> AudioClassificationData: """Downloads and loads the Urban 8k sounds images data set.""" download_data("https://pl-flash-data.s3.amazonaws.com/urban8k_images.zip", "./data") @@ -32,7 +32,7 @@ def from_urban8k( val_folder="data/urban8k_images/val", batch_size=batch_size, num_workers=num_workers, - **preprocess_kwargs, + **input_transform_kwargs, ) diff --git a/flash/audio/classification/data.py b/flash/audio/classification/data.py index 9c5bd805c1..4d6edb02bd 100644 --- a/flash/audio/classification/data.py +++ b/flash/audio/classification/data.py @@ -24,7 +24,8 @@ NumpyDataSource, PathsDataSource, ) -from flash.core.data.process import Deserializer, Preprocess +from flash.core.data.io.input_transform import InputTransform +from flash.core.data.process import Deserializer from flash.core.data.utils import image_default_loader from flash.image.classification.data import ImageClassificationData from flash.image.data import ImageDeserializer, IMG_EXTENSIONS, NP_EXTENSIONS @@ -61,7 +62,7 @@ def __init__(self): super().__init__(spectrogram_loader) -class AudioClassificationPreprocess(Preprocess): +class AudioClassificationInputTransform(InputTransform): def __init__( self, train_transform: Optional[Dict[str, Callable]] = None, @@ -116,4 +117,4 @@ def train_default_transforms(self) -> Optional[Dict[str, Callable]]: class AudioClassificationData(ImageClassificationData): """Data module for audio classification.""" - preprocess_cls = AudioClassificationPreprocess + input_transform_cls = AudioClassificationInputTransform diff --git a/flash/audio/speech_recognition/cli.py b/flash/audio/speech_recognition/cli.py index f8a7ad26dd..a74a930d25 100644 --- a/flash/audio/speech_recognition/cli.py +++ b/flash/audio/speech_recognition/cli.py @@ -23,7 +23,7 @@ def from_timit( val_split: float = 0.1, batch_size: int = 4, num_workers: int = 0, - **preprocess_kwargs, + **input_transform_kwargs, ) -> SpeechRecognitionData: """Downloads and loads the timit data set.""" download_data("https://pl-flash-data.s3.amazonaws.com/timit_data.zip", "./data") @@ -35,7 +35,7 @@ def from_timit( val_split=val_split, batch_size=batch_size, num_workers=num_workers, - **preprocess_kwargs, + **input_transform_kwargs, ) diff --git a/flash/audio/speech_recognition/data.py b/flash/audio/speech_recognition/data.py index c50a117fb6..6f1fbb48bb 100644 --- a/flash/audio/speech_recognition/data.py +++ b/flash/audio/speech_recognition/data.py @@ -30,8 +30,9 @@ DefaultDataSources, PathsDataSource, ) +from flash.core.data.io.input_transform import InputTransform from flash.core.data.io.output_transform import OutputTransform -from flash.core.data.process import Deserializer, Preprocess +from flash.core.data.process import Deserializer from flash.core.data.properties import ProcessState from flash.core.utilities.imports import _AUDIO_AVAILABLE, requires @@ -92,7 +93,7 @@ def load_data( self, data: Tuple[str, Union[str, List[str]], Union[str, List[str]]], dataset: Optional[Any] = None, - ) -> Union[Sequence[Mapping[str, Any]]]: + ) -> Sequence[Mapping[str, Any]]: if self.filetype == "json": file, input_key, target_key, field = data else: @@ -134,7 +135,7 @@ def __init__(self, sampling_rate: int): self.sampling_rate = sampling_rate - def load_data(self, data: Dataset, dataset: Optional[Any] = None) -> Union[Sequence[Mapping[str, Any]]]: + 
def load_data(self, data: Dataset, dataset: Optional[Any] = None) -> Sequence[Mapping[str, Any]]: if isinstance(data, HFDataset): data = list(zip(data["file"], data["text"])) return super().load_data(data, dataset) @@ -155,7 +156,7 @@ def load_sample(self, sample: Dict[str, Any], dataset: Any = None) -> Any: return self._load_sample(sample, self.sampling_rate) -class SpeechRecognitionPreprocess(Preprocess): +class SpeechRecognitionInputTransform(InputTransform): @requires("audio") def __init__( self, @@ -237,5 +238,5 @@ def __setstate__(self, state): class SpeechRecognitionData(DataModule): """Data Module for text classification tasks.""" - preprocess_cls = SpeechRecognitionPreprocess + input_transform_cls = SpeechRecognitionInputTransform output_transform_cls = SpeechRecognitionOutputTransform diff --git a/flash/core/data/base_viz.py b/flash/core/data/base_viz.py index f46a5d558a..513714a8db 100644 --- a/flash/core/data/base_viz.py +++ b/flash/core/data/base_viz.py @@ -22,10 +22,10 @@ class BaseVisualization(BaseDataFetcher): - """This Base Class is used to create visualization tool on top of :class:`~flash.core.data.process.Preprocess` - hooks. + """This Base Class is used to create a visualization tool on top of + :class:`~flash.core.data.io.input_transform.InputTransform` hooks. - Override any of the ``show_{preprocess_hook_name}`` to receive the associated data and visualize them. + Override any of the ``show_{hook_name}`` hooks to receive the associated data and visualize them. Example:: @@ -102,7 +102,7 @@ def show(self, batch: Dict[str, Any], running_stage: RunningStage): .. note:: - As the :class:`~flash.core.data.process.Preprocess` hooks are injected within + As the :class:`~flash.core.data.io.input_transform.InputTransform` hooks are injected within the threaded workers of the DataLoader, the data won't be accessible when using ``num_workers > 0``. 
""" @@ -123,25 +123,25 @@ def show(self, batch: Dict[str, Any], running_stage: RunningStage, func_names_li getattr(self, hook_name)(batch[func_name], running_stage) def show_load_sample(self, samples: List[Any], running_stage: RunningStage): - """Override to visualize preprocess ``load_sample`` output data.""" + """Override to visualize ``load_sample`` output data.""" def show_pre_tensor_transform(self, samples: List[Any], running_stage: RunningStage): - """Override to visualize preprocess ``pre_tensor_transform`` output data.""" + """Override to visualize ``pre_tensor_transform`` output data.""" def show_to_tensor_transform(self, samples: List[Any], running_stage: RunningStage): - """Override to visualize preprocess ``to_tensor_transform`` output data.""" + """Override to visualize ``to_tensor_transform`` output data.""" def show_post_tensor_transform(self, samples: List[Any], running_stage: RunningStage): - """Override to visualize preprocess ``post_tensor_transform`` output data.""" + """Override to visualize ``post_tensor_transform`` output data.""" def show_collate(self, batch: List[Any], running_stage: RunningStage) -> None: - """Override to visualize preprocess ``collate`` output data.""" + """Override to visualize ``collate`` output data.""" def show_per_batch_transform(self, batch: List[Any], running_stage: RunningStage) -> None: - """Override to visualize preprocess ``per_batch_transform`` output data.""" + """Override to visualize ``per_batch_transform`` output data.""" def show_per_sample_transform_on_device(self, samples: List[Any], running_stage: RunningStage) -> None: - """Override to visualize preprocess ``per_sample_transform_on_device`` output data.""" + """Override to visualize ``per_sample_transform_on_device`` output data.""" def show_per_batch_transform_on_device(self, batch: List[Any], running_stage: RunningStage) -> None: - """Override to visualize preprocess ``per_batch_transform_on_device`` output data.""" + """Override to visualize ``per_batch_transform_on_device`` output data.""" diff --git a/flash/core/data/batch.py b/flash/core/data/batch.py index c704fc9999..2f77f2b56d 100644 --- a/flash/core/data/batch.py +++ b/flash/core/data/batch.py @@ -11,113 +11,38 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, TYPE_CHECKING, Union +from typing import Any, Callable, Sequence, TYPE_CHECKING import torch -from pytorch_lightning.utilities.exceptions import MisconfigurationException from torch import Tensor from flash.core.data.callback import ControlFlow -from flash.core.data.data_source import DefaultDataKeys -from flash.core.data.utils import ( - _contains_any_tensor, - convert_to_modules, - CurrentFuncContext, - CurrentRunningStageContext, -) +from flash.core.data.utils import convert_to_modules, CurrentFuncContext, CurrentRunningStageContext from flash.core.utilities.stages import RunningStage if TYPE_CHECKING: - from flash.core.data.process import Deserializer, Preprocess - - -class _Sequential(torch.nn.Module): - """This class is used to chain 3 functions together for the _Preprocessor ``per_sample_transform`` function. - - 1. ``pre_tensor_transform`` - 2. ``to_tensor_transform`` - 3. 
``post_tensor_transform`` - """ - - def __init__( - self, - preprocess: "Preprocess", - pre_tensor_transform: Optional[Callable], - to_tensor_transform: Optional[Callable], - post_tensor_transform: Callable, - stage: RunningStage, - assert_contains_tensor: bool = False, - ): - super().__init__() - self.preprocess = preprocess - self.callback = ControlFlow(self.preprocess.callbacks) - self.pre_tensor_transform = convert_to_modules(pre_tensor_transform) - self.to_tensor_transform = convert_to_modules(to_tensor_transform) - self.post_tensor_transform = convert_to_modules(post_tensor_transform) - self.stage = stage - self.assert_contains_tensor = assert_contains_tensor - - self._current_stage_context = CurrentRunningStageContext(stage, preprocess, reset=False) - self._pre_tensor_transform_context = CurrentFuncContext("pre_tensor_transform", preprocess) - self._to_tensor_transform_context = CurrentFuncContext("to_tensor_transform", preprocess) - self._post_tensor_transform_context = CurrentFuncContext("post_tensor_transform", preprocess) - - def forward(self, sample: Any) -> Any: - self.callback.on_load_sample(sample, self.stage) - - with self._current_stage_context: - if self.pre_tensor_transform is not None: - with self._pre_tensor_transform_context: - sample = self.pre_tensor_transform(sample) - self.callback.on_pre_tensor_transform(sample, self.stage) - - if self.to_tensor_transform is not None: - with self._to_tensor_transform_context: - sample = self.to_tensor_transform(sample) - self.callback.on_to_tensor_transform(sample, self.stage) - - if self.assert_contains_tensor: - if not _contains_any_tensor(sample): - raise MisconfigurationException( - "When ``to_tensor_transform`` is overriden, " - "``DataPipeline`` expects the outputs to be ``tensors``" - ) - - with self._post_tensor_transform_context: - sample = self.post_tensor_transform(sample) - self.callback.on_post_tensor_transform(sample, self.stage) - - return sample - - def __str__(self) -> str: - return ( - f"{self.__class__.__name__}:\n" - f"\t(pre_tensor_transform): {str(self.pre_tensor_transform)}\n" - f"\t(to_tensor_transform): {str(self.to_tensor_transform)}\n" - f"\t(post_tensor_transform): {str(self.post_tensor_transform)}\n" - f"\t(assert_contains_tensor): {str(self.assert_contains_tensor)}\n" - f"\t(stage): {str(self.stage)}" - ) + from flash.core.data.io.input_transform import InputTransform + from flash.core.data.process import Deserializer class _DeserializeProcessor(torch.nn.Module): def __init__( self, deserializer: "Deserializer", - preprocess: "Preprocess", + input_transform: "InputTransform", pre_tensor_transform: Callable, to_tensor_transform: Callable, ): super().__init__() - self.preprocess = preprocess - self.callback = ControlFlow(self.preprocess.callbacks) + self.input_transform = input_transform + self.callback = ControlFlow(self.input_transform.callbacks) self.deserializer = convert_to_modules(deserializer) self.pre_tensor_transform = convert_to_modules(pre_tensor_transform) self.to_tensor_transform = convert_to_modules(to_tensor_transform) - self._current_stage_context = CurrentRunningStageContext(RunningStage.PREDICTING, preprocess, reset=False) - self._pre_tensor_transform_context = CurrentFuncContext("pre_tensor_transform", preprocess) - self._to_tensor_transform_context = CurrentFuncContext("to_tensor_transform", preprocess) + self._current_stage_context = CurrentRunningStageContext(RunningStage.PREDICTING, input_transform, reset=False) + self._pre_tensor_transform_context = 
CurrentFuncContext("pre_tensor_transform", input_transform) + self._to_tensor_transform_context = CurrentFuncContext("to_tensor_transform", input_transform) def forward(self, sample: str): @@ -135,115 +60,6 @@ def forward(self, sample: str): return sample -class _Preprocessor(torch.nn.Module): - """ - This class is used to encapsultate the following functions of a Preprocess Object: - Inside a worker: - per_sample_transform: Function to transform an individual sample - Inside a worker, it is actually make of 3 functions: - * pre_tensor_transform - * to_tensor_transform - * post_tensor_transform - collate: Function to merge sample into a batch - per_batch_transform: Function to transform an individual batch - * per_batch_transform - - Inside main process: - per_sample_transform: Function to transform an individual sample - * per_sample_transform_on_device - collate: Function to merge sample into a batch - per_batch_transform: Function to transform an individual batch - * per_batch_transform_on_device - """ - - def __init__( - self, - preprocess: "Preprocess", - collate_fn: Callable, - per_sample_transform: Union[Callable, _Sequential], - per_batch_transform: Callable, - stage: RunningStage, - apply_per_sample_transform: bool = True, - on_device: bool = False, - ): - super().__init__() - self.preprocess = preprocess - self.callback = ControlFlow(self.preprocess.callbacks) - self.collate_fn = convert_to_modules(collate_fn) - self.per_sample_transform = convert_to_modules(per_sample_transform) - self.per_batch_transform = convert_to_modules(per_batch_transform) - self.apply_per_sample_transform = apply_per_sample_transform - self.stage = stage - self.on_device = on_device - - extension = f"{'_on_device' if self.on_device else ''}" - self._current_stage_context = CurrentRunningStageContext(stage, preprocess) - self._per_sample_transform_context = CurrentFuncContext(f"per_sample_transform{extension}", preprocess) - self._collate_context = CurrentFuncContext("collate", preprocess) - self._per_batch_transform_context = CurrentFuncContext(f"per_batch_transform{extension}", preprocess) - - @staticmethod - def _extract_metadata( - samples: List[Dict[str, Any]], - ) -> Tuple[List[Dict[str, Any]], Optional[List[Dict[str, Any]]]]: - metadata = [s.pop(DefaultDataKeys.METADATA, None) if isinstance(s, Mapping) else None for s in samples] - return samples, metadata if any(m is not None for m in metadata) else None - - def forward(self, samples: Sequence[Any]) -> Any: - # we create a new dict to prevent from potential memory leaks - # assuming that the dictionary samples are stored in between and - # potentially modified before the transforms are applied. 
- if isinstance(samples, dict): - samples = dict(samples.items()) - - with self._current_stage_context: - - if self.apply_per_sample_transform: - with self._per_sample_transform_context: - _samples = [] - - if isinstance(samples, Mapping): - samples = [samples] - - for sample in samples: - sample = self.per_sample_transform(sample) - if self.on_device: - self.callback.on_per_sample_transform_on_device(sample, self.stage) - _samples.append(sample) - - samples = type(_samples)(_samples) - - with self._collate_context: - samples, metadata = self._extract_metadata(samples) - try: - samples = self.collate_fn(samples, metadata) - except TypeError: - samples = self.collate_fn(samples) - if metadata and isinstance(samples, dict): - samples[DefaultDataKeys.METADATA] = metadata - self.callback.on_collate(samples, self.stage) - - with self._per_batch_transform_context: - samples = self.per_batch_transform(samples) - if self.on_device: - self.callback.on_per_batch_transform_on_device(samples, self.stage) - else: - self.callback.on_per_batch_transform(samples, self.stage) - return samples - - def __str__(self) -> str: - # todo: define repr function which would take object and string attributes to be shown - return ( - "_Preprocessor:\n" - f"\t(per_sample_transform): {str(self.per_sample_transform)}\n" - f"\t(collate_fn): {str(self.collate_fn)}\n" - f"\t(per_batch_transform): {str(self.per_batch_transform)}\n" - f"\t(apply_per_sample_transform): {str(self.apply_per_sample_transform)}\n" - f"\t(on_device): {str(self.on_device)}\n" - f"\t(stage): {str(self.stage)}" - ) - - def default_uncollate(batch: Any): """ This function is used to uncollate a batch into samples. diff --git a/flash/core/data/callback.py b/flash/core/data/callback.py index ea310bc1fd..3669914639 100644 --- a/flash/core/data/callback.py +++ b/flash/core/data/callback.py @@ -81,7 +81,7 @@ def on_per_batch_transform_on_device(self, batch: Any, running_stage: RunningSta class BaseDataFetcher(FlashCallback): - """This class is used to profile :class:`~flash.core.data.process.Preprocess` hook outputs. + """This class is used to profile :class:`~flash.core.data.io.input_transform.InputTransform` hook outputs. By default, the callback won't profile the data being processed as it may lead to ``OOMError``. 
@@ -90,9 +90,9 @@ class BaseDataFetcher(FlashCallback): from flash.core.data.callback import BaseDataFetcher from flash.core.data.data_module import DataModule from flash.core.data.data_source import DataSource - from flash.core.data.process import Preprocess + from flash.core.data.io.input_transform import InputTransform - class CustomPreprocess(Preprocess): + class CustomInputTransform(InputTransform): def __init__(**kwargs): super().__init__( @@ -107,7 +107,7 @@ def print(self): class CustomDataModule(DataModule): - preprocess_cls = CustomPreprocess + input_transform_cls = CustomInputTransform @staticmethod def configure_data_fetcher(): @@ -167,7 +167,7 @@ def from_inputs( def __init__(self, enabled: bool = False): self.enabled = enabled - self._preprocess = None + self._input_transform = None self.reset() def _store(self, data: Any, fn_name: str, running_stage: RunningStage) -> None: @@ -207,9 +207,9 @@ def enable(self): yield self.enabled = False - def attach_to_preprocess(self, preprocess: "flash.core.data.process.Preprocess") -> None: - preprocess.add_callbacks([self]) - self._preprocess = preprocess + def attach_to_input_transform(self, input_transform: "flash.core.data.io.input_transform.InputTransform") -> None: + input_transform.add_callbacks([self]) + self._input_transform = input_transform def reset(self): self.batches = {k: {} for k in _STAGES_PREFIX.values()} diff --git a/flash/core/data/data_module.py b/flash/core/data/data_module.py index 3d89d4bdef..8ece38d667 100644 --- a/flash/core/data/data_module.py +++ b/flash/core/data/data_module.py @@ -39,8 +39,9 @@ from flash.core.data.auto_dataset import BaseAutoDataset, IterableAutoDataset from flash.core.data.base_viz import BaseVisualization from flash.core.data.callback import BaseDataFetcher -from flash.core.data.data_pipeline import DataPipeline, DefaultPreprocess, Preprocess +from flash.core.data.data_pipeline import DataPipeline from flash.core.data.data_source import DataSource, DefaultDataSources +from flash.core.data.io.input_transform import DefaultInputTransform, InputTransform from flash.core.data.io.output_transform import OutputTransform from flash.core.data.splits import SplitDataset from flash.core.data.utils import _STAGES_PREFIX @@ -55,7 +56,7 @@ class DataModule(pl.LightningDataModule): """A basic DataModule class for all Flash tasks. This class includes references to a - :class:`~flash.core.data.data_source.DataSource`, :class:`~flash.core.data.process.Preprocess`, + :class:`~flash.core.data.data_source.DataSource`, :class:`~flash.core.data.io.input_transform.InputTransform`, :class:`~flash.core.data.io.output_transform.OutputTransform`, and a :class:`~flash.core.data.callback.BaseDataFetcher`. @@ -65,14 +66,14 @@ class DataModule(pl.LightningDataModule): test_dataset: Dataset to test model performance. Defaults to None. predict_dataset: Dataset for predicting. Defaults to None. data_source: The :class:`~flash.core.data.data_source.DataSource` that was used to create the datasets. - preprocess: The :class:`~flash.core.data.process.Preprocess` to use when constructing the + input_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` to use when constructing the :class:`~flash.core.data.data_pipeline.DataPipeline`. If ``None``, a - :class:`~flash.core.data.process.DefaultPreprocess` will be used. + :class:`~flash.core.data.io.input_transform.DefaultInputTransform` will be used. 
output_transform: The :class:`~flash.core.data.io.output_transform.OutputTransform` to use when constructing the :class:`~flash.core.data.data_pipeline.DataPipeline`. If ``None``, a plain :class:`~flash.core.data.io.output_transform.OutputTransform` will be used. data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to attach to the - :class:`~flash.core.data.process.Preprocess`. If ``None``, the output from + :class:`~flash.core.data.io.input_transform.InputTransform`. If ``None``, the output from :meth:`~flash.core.data.data_module.DataModule.configure_data_fetcher` will be used. val_split: An optional float which gives the relative amount of the training dataset to use for the validation dataset. @@ -84,7 +85,7 @@ class DataModule(pl.LightningDataModule): Will be passed to the DataLoader for the training dataset. Defaults to None. """ - preprocess_cls = DefaultPreprocess + input_transform_cls = DefaultInputTransform output_transform_cls = OutputTransform def __init__( @@ -94,7 +95,7 @@ def __init__( test_dataset: Optional[Dataset] = None, predict_dataset: Optional[Dataset] = None, data_source: Optional[DataSource] = None, - preprocess: Optional[Preprocess] = None, + input_transform: Optional[InputTransform] = None, output_transform: Optional[OutputTransform] = None, data_fetcher: Optional[BaseDataFetcher] = None, val_split: Optional[float] = None, @@ -109,13 +110,13 @@ def __init__( batch_size = 16 self._data_source: DataSource = data_source - self._preprocess: Optional[Preprocess] = preprocess + self._input_transform: Optional[InputTransform] = input_transform self._output_transform: Optional[OutputTransform] = output_transform self._viz: Optional[BaseVisualization] = None self._data_fetcher: Optional[BaseDataFetcher] = data_fetcher or self.configure_data_fetcher() - # TODO: Preprocess can change - self.data_fetcher.attach_to_preprocess(self.preprocess) + # TODO: InputTransform can change + self.data_fetcher.attach_to_input_transform(self.input_transform) self._train_ds = train_dataset self._val_ds = val_dataset @@ -280,7 +281,7 @@ def set_running_stages(self): def _resolve_collate_fn(self, dataset: Dataset, running_stage: RunningStage) -> Optional[Callable]: if isinstance(dataset, (BaseAutoDataset, SplitDataset)): - return self.data_pipeline.worker_preprocessor(running_stage) + return self.data_pipeline.worker_input_transform_processor(running_stage) def _train_dataloader(self) -> DataLoader: """Configure the train dataloader of the datamodule.""" @@ -430,9 +431,9 @@ def data_source(self) -> Optional[DataSource]: return self._data_source @property - def preprocess(self) -> Preprocess: - """Property that returns the preprocessing class used on input data.""" - return self._preprocess or self.preprocess_cls() + def input_transform(self) -> InputTransform: + """Property that returns the input transform class used on input data.""" + return self._input_transform or self.input_transform_cls() @property def output_transform(self) -> OutputTransform: @@ -442,9 +443,9 @@ def output_transform(self) -> OutputTransform: @property def data_pipeline(self) -> DataPipeline: - """Property that returns the full data pipeline including the data source, preprocessing and + """Property that returns the full data pipeline including the data source, input transform and postprocessing.""" - return DataPipeline(self.data_source, self.preprocess, self.output_transform) + return DataPipeline(self.data_source, self.input_transform, self.output_transform) def available_data_sources(self) ->
Sequence[str]: """Get the list of available data source names for use with this @@ -453,7 +454,7 @@ def available_data_sources(self) -> Sequence[str]: Returns: The list of data source names. """ - return self.preprocess.available_data_sources() + return self.input_transform.available_data_sources() @staticmethod def _split_train_val( @@ -503,18 +504,18 @@ def from_data_source( test_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None, predict_transform: Optional[Dict[str, Callable]] = None, data_fetcher: Optional[BaseDataFetcher] = None, - preprocess: Optional[Preprocess] = None, + input_transform: Optional[InputTransform] = None, val_split: Optional[float] = None, batch_size: int = 4, num_workers: int = 0, sampler: Optional[Type[Sampler]] = None, - **preprocess_kwargs: Any, + **input_transform_kwargs: Any, ) -> "DataModule": """Creates a :class:`~flash.core.data.data_module.DataModule` object from the given inputs to :meth:`~flash.core.data.data_source.DataSource.load_data` (``train_data``, ``val_data``, ``test_data``, ``predict_data``). The data source will be resolved from the instantiated - :class:`~flash.core.data.process.Preprocess` - using :meth:`~flash.core.data.process.Preprocess.data_source_of_name`. + :class:`~flash.core.data.io.input_transform.InputTransform` + using :meth:`~flash.core.data.io.input_transform.InputTransform.data_source_of_name`. Args: data_source: The name of the data source to use for the @@ -528,24 +529,24 @@ def from_data_source( predict_data: The input to :meth:`~flash.core.data.data_source.DataSource.load_data` to use when creating the predict dataset. train_transform: The dictionary of transforms to use during training which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. val_transform: The dictionary of transforms to use during validation which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. test_transform: The dictionary of transforms to use during testing which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. predict_transform: The dictionary of transforms to use during predicting which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to pass to the :class:`~flash.core.data.data_module.DataModule`. - preprocess: The :class:`~flash.core.data.data.Preprocess` to pass to the - :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls`` will be + input_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` to pass to the + :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.input_transform_cls`` will be constructed and used. val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. 
sampler: The ``sampler`` to use for the ``train_dataloader``. - preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used - if ``preprocess = None``. + input_transform_kwargs: Additional keyword arguments to use when constructing the input_transform. + Will only be used if ``input_transform = None``. Returns: The constructed data module. @@ -561,15 +562,15 @@ def from_data_source( ) """ - preprocess = preprocess or cls.preprocess_cls( + input_transform = input_transform or cls.input_transform_cls( train_transform, val_transform, test_transform, predict_transform, - **preprocess_kwargs, + **input_transform_kwargs, ) - data_source = preprocess.data_source_of_name(data_source) + data_source = input_transform.data_source_of_name(data_source) train_dataset, val_dataset, test_dataset, predict_dataset = data_source.to_datasets( train_data, @@ -584,7 +585,7 @@ def from_data_source( test_dataset, predict_dataset, data_source=data_source, - preprocess=preprocess, + input_transform=input_transform, data_fetcher=data_fetcher, val_split=val_split, batch_size=batch_size, @@ -604,17 +605,17 @@ def from_folders( test_transform: Optional[Dict[str, Callable]] = None, predict_transform: Optional[Dict[str, Callable]] = None, data_fetcher: Optional[BaseDataFetcher] = None, - preprocess: Optional[Preprocess] = None, + input_transform: Optional[InputTransform] = None, val_split: Optional[float] = None, batch_size: int = 4, num_workers: int = 0, sampler: Optional[Type[Sampler]] = None, - **preprocess_kwargs: Any, + **input_transform_kwargs: Any, ) -> "DataModule": """Creates a :class:`~flash.core.data.data_module.DataModule` object from the given folders using the :class:`~flash.core.data.data_source.DataSource` of name :attr:`~flash.core.data.data_source.DefaultDataSources.FOLDERS` - from the passed or constructed :class:`~flash.core.data.process.Preprocess`. + from the passed or constructed :class:`~flash.core.data.io.input_transform.InputTransform`. Args: train_folder: The folder containing the train data. @@ -622,24 +623,24 @@ def from_folders( test_folder: The folder containing the test data. predict_folder: The folder containing the predict data. train_transform: The dictionary of transforms to use during training which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. val_transform: The dictionary of transforms to use during validation which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. test_transform: The dictionary of transforms to use during testing which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. predict_transform: The dictionary of transforms to use during predicting which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to pass to the :class:`~flash.core.data.data_module.DataModule`. - preprocess: The :class:`~flash.core.data.data.Preprocess` to pass to the - :class:`~flash.core.data.data_module.DataModule`. 
If ``None``, ``cls.preprocess_cls`` + input_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` to pass to the + :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.input_transform_cls`` will be constructed and used. val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. sampler: The ``sampler`` to use for the ``train_dataloader``. - preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used - if ``preprocess = None``. + input_transform_kwargs: Additional keyword arguments to use when constructing the input_transform. + Will only be used if ``input_transform = None``. Returns: The constructed data module. @@ -655,12 +656,12 @@ def from_folders( test_transform=test_transform, predict_transform=predict_transform, data_fetcher=data_fetcher, - preprocess=preprocess, + input_transform=input_transform, val_split=val_split, batch_size=batch_size, num_workers=num_workers, sampler=sampler, - **preprocess_kwargs, + **input_transform_kwargs, ) @classmethod @@ -678,17 +679,17 @@ def from_files( test_transform: Optional[Dict[str, Callable]] = None, predict_transform: Optional[Dict[str, Callable]] = None, data_fetcher: Optional[BaseDataFetcher] = None, - preprocess: Optional[Preprocess] = None, + input_transform: Optional[InputTransform] = None, val_split: Optional[float] = None, batch_size: int = 4, num_workers: int = 0, sampler: Optional[Type[Sampler]] = None, - **preprocess_kwargs: Any, + **input_transform_kwargs: Any, ) -> "DataModule": """Creates a :class:`~flash.core.data.data_module.DataModule` object from the given sequences of files using the :class:`~flash.core.data.data_source.DataSource` of name :attr:`~flash.core.data.data_source.DefaultDataSources.FILES` from the passed or constructed - :class:`~flash.core.data.process.Preprocess`. + :class:`~flash.core.data.io.input_transform.InputTransform`. Args: train_files: A sequence of files to use as the train inputs. @@ -699,24 +700,24 @@ def from_files( test_targets: A sequence of targets (one per test file) to use as the test targets. predict_files: A sequence of files to use when predicting. train_transform: The dictionary of transforms to use during training which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. val_transform: The dictionary of transforms to use during validation which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. test_transform: The dictionary of transforms to use during testing which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. predict_transform: The dictionary of transforms to use during predicting which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. 
data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to pass to the :class:`~flash.core.data.data_module.DataModule`. - preprocess: The :class:`~flash.core.data.data.Preprocess` to pass to the - :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls`` + input_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` to pass to the + :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.input_transform_cls`` will be constructed and used. val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. sampler: The ``sampler`` to use for the ``train_dataloader``. - preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used - if ``preprocess = None``. + input_transform_kwargs: Additional keyword arguments to use when constructing the input_transform. + Will only be used if ``input_transform = None``. Returns: The constructed data module. @@ -732,12 +733,12 @@ def from_files( test_transform=test_transform, predict_transform=predict_transform, data_fetcher=data_fetcher, - preprocess=preprocess, + input_transform=input_transform, val_split=val_split, batch_size=batch_size, num_workers=num_workers, sampler=sampler, - **preprocess_kwargs, + **input_transform_kwargs, ) @classmethod @@ -755,17 +756,17 @@ def from_tensors( test_transform: Optional[Dict[str, Callable]] = None, predict_transform: Optional[Dict[str, Callable]] = None, data_fetcher: Optional[BaseDataFetcher] = None, - preprocess: Optional[Preprocess] = None, + input_transform: Optional[InputTransform] = None, val_split: Optional[float] = None, batch_size: int = 4, num_workers: int = 0, sampler: Optional[Type[Sampler]] = None, - **preprocess_kwargs: Any, + **input_transform_kwargs: Any, ) -> "DataModule": """Creates a :class:`~flash.core.data.data_module.DataModule` object from the given tensors using the :class:`~flash.core.data.data_source.DataSource` of name :attr:`~flash.core.data.data_source.DefaultDataSources.TENSOR` - from the passed or constructed :class:`~flash.core.data.process.Preprocess`. + from the passed or constructed :class:`~flash.core.data.io.input_transform.InputTransform`. Args: train_data: A tensor or collection of tensors to use as the train inputs. @@ -776,24 +777,24 @@ def from_tensors( test_targets: A sequence of targets (one per test input) to use as the test targets. predict_data: A tensor or collection of tensors to use when predicting. train_transform: The dictionary of transforms to use during training which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. val_transform: The dictionary of transforms to use during validation which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. test_transform: The dictionary of transforms to use during testing which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. 
predict_transform: The dictionary of transforms to use during predicting which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to pass to the :class:`~flash.core.data.data_module.DataModule`. - preprocess: The :class:`~flash.core.data.data.Preprocess` to pass to the - :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls`` + input_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` to pass to the + :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.input_transform_cls`` will be constructed and used. val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. sampler: The ``sampler`` to use for the ``train_dataloader``. - preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used - if ``preprocess = None``. + input_transform_kwargs: Additional keyword arguments to use when constructing the input_transform. + Will only be used if ``input_transform = None``. Returns: The constructed data module. @@ -819,12 +820,12 @@ def from_tensors( test_transform=test_transform, predict_transform=predict_transform, data_fetcher=data_fetcher, - preprocess=preprocess, + input_transform=input_transform, val_split=val_split, batch_size=batch_size, num_workers=num_workers, sampler=sampler, - **preprocess_kwargs, + **input_transform_kwargs, ) @classmethod @@ -842,17 +843,17 @@ def from_numpy( test_transform: Optional[Dict[str, Callable]] = None, predict_transform: Optional[Dict[str, Callable]] = None, data_fetcher: Optional[BaseDataFetcher] = None, - preprocess: Optional[Preprocess] = None, + input_transform: Optional[InputTransform] = None, val_split: Optional[float] = None, batch_size: int = 4, num_workers: int = 0, sampler: Optional[Type[Sampler]] = None, - **preprocess_kwargs: Any, + **input_transform_kwargs: Any, ) -> "DataModule": """Creates a :class:`~flash.core.data.data_module.DataModule` object from the given numpy array using the :class:`~flash.core.data.data_source.DataSource` of name :attr:`~flash.core.data.data_source.DefaultDataSources.NUMPY` - from the passed or constructed :class:`~flash.core.data.process.Preprocess`. + from the passed or constructed :class:`~flash.core.data.io.input_transform.InputTransform`. Args: train_data: A numpy array to use as the train inputs. @@ -863,24 +864,24 @@ def from_numpy( test_targets: A sequence of targets (one per test input) to use as the test targets. predict_data: A numpy array to use when predicting. train_transform: The dictionary of transforms to use during training which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. val_transform: The dictionary of transforms to use during validation which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. 
test_transform: The dictionary of transforms to use during testing which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. predict_transform: The dictionary of transforms to use during predicting which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to pass to the :class:`~flash.core.data.data_module.DataModule`. - preprocess: The :class:`~flash.core.data.data.Preprocess` to pass to the - :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls`` + input_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` to pass to the + :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.input_transform_cls`` will be constructed and used. val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. sampler: The ``sampler`` to use for the ``train_dataloader``. - preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used - if ``preprocess = None``. + input_transform_kwargs: Additional keyword arguments to use when constructing the input_transform. + Will only be used if ``input_transform = None``. Returns: The constructed data module. @@ -906,12 +907,12 @@ def from_numpy( test_transform=test_transform, predict_transform=predict_transform, data_fetcher=data_fetcher, - preprocess=preprocess, + input_transform=input_transform, val_split=val_split, batch_size=batch_size, num_workers=num_workers, sampler=sampler, - **preprocess_kwargs, + **input_transform_kwargs, ) @classmethod @@ -928,18 +929,18 @@ def from_json( test_transform: Optional[Dict[str, Callable]] = None, predict_transform: Optional[Dict[str, Callable]] = None, data_fetcher: Optional[BaseDataFetcher] = None, - preprocess: Optional[Preprocess] = None, + input_transform: Optional[InputTransform] = None, val_split: Optional[float] = None, batch_size: int = 4, num_workers: int = 0, sampler: Optional[Type[Sampler]] = None, field: Optional[str] = None, - **preprocess_kwargs: Any, + **input_transform_kwargs: Any, ) -> "DataModule": """Creates a :class:`~flash.core.data.data_module.DataModule` object from the given JSON files using the :class:`~flash.core.data.data_source.DataSource` of name :attr:`~flash.core.data.data_source.DefaultDataSources.JSON` - from the passed or constructed :class:`~flash.core.data.process.Preprocess`. + from the passed or constructed :class:`~flash.core.data.io.input_transform.InputTransform`. Args: input_fields: The field or fields in the JSON objects to use for the input. @@ -949,25 +950,25 @@ def from_json( test_file: The JSON file containing the testing data. predict_file: The JSON file containing the data to use when predicting. train_transform: The dictionary of transforms to use during training which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. 
val_transform: The dictionary of transforms to use during validation which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. test_transform: The dictionary of transforms to use during testing which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. predict_transform: The dictionary of transforms to use during predicting which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to pass to the :class:`~flash.core.data.data_module.DataModule`. - preprocess: The :class:`~flash.core.data.data.Preprocess` to pass to the - :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls`` + input_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` to pass to the + :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.input_transform_cls`` will be constructed and used. val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. sampler: The ``sampler`` to use for the ``train_dataloader``. field: To specify the field that holds the data in the JSON file. - preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used - if ``preprocess = None``. + input_transform_kwargs: Additional keyword arguments to use when constructing the input_transform. + Will only be used if ``input_transform = None``. Returns: The constructed data module. @@ -1016,12 +1017,12 @@ def from_json( test_transform=test_transform, predict_transform=predict_transform, data_fetcher=data_fetcher, - preprocess=preprocess, + input_transform=input_transform, val_split=val_split, batch_size=batch_size, num_workers=num_workers, sampler=sampler, - **preprocess_kwargs, + **input_transform_kwargs, ) @classmethod @@ -1038,17 +1039,17 @@ def from_csv( test_transform: Optional[Dict[str, Callable]] = None, predict_transform: Optional[Dict[str, Callable]] = None, data_fetcher: Optional[BaseDataFetcher] = None, - preprocess: Optional[Preprocess] = None, + input_transform: Optional[InputTransform] = None, val_split: Optional[float] = None, batch_size: int = 4, num_workers: int = 0, sampler: Optional[Type[Sampler]] = None, - **preprocess_kwargs: Any, + **input_transform_kwargs: Any, ) -> "DataModule": """Creates a :class:`~flash.core.data.data_module.DataModule` object from the given CSV files using the :class:`~flash.core.data.data_source.DataSource` of name :attr:`~flash.core.data.data_source.DefaultDataSources.CSV` - from the passed or constructed :class:`~flash.core.data.process.Preprocess`. + from the passed or constructed :class:`~flash.core.data.io.input_transform.InputTransform`. Args: input_fields: The field or fields (columns) in the CSV file to use for the input. @@ -1058,24 +1059,24 @@ def from_csv( test_file: The CSV file containing the testing data. predict_file: The CSV file containing the data to use when predicting. 
train_transform: The dictionary of transforms to use during training which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. val_transform: The dictionary of transforms to use during validation which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. test_transform: The dictionary of transforms to use during testing which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. predict_transform: The dictionary of transforms to use during predicting which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to pass to the :class:`~flash.core.data.data_module.DataModule`. - preprocess: The :class:`~flash.core.data.data.Preprocess` to pass to the - :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls`` + input_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` to pass to the + :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.input_transform_cls`` will be constructed and used. val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. sampler: The ``sampler`` to use for the ``train_dataloader``. - preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used - if ``preprocess = None``. + input_transform_kwargs: Additional keyword arguments to use when constructing the input_transform. + Will only be used if ``input_transform = None``. Returns: The constructed data module. @@ -1102,12 +1103,12 @@ def from_csv( test_transform=test_transform, predict_transform=predict_transform, data_fetcher=data_fetcher, - preprocess=preprocess, + input_transform=input_transform, val_split=val_split, batch_size=batch_size, num_workers=num_workers, sampler=sampler, - **preprocess_kwargs, + **input_transform_kwargs, ) @classmethod @@ -1122,17 +1123,17 @@ def from_datasets( test_transform: Optional[Dict[str, Callable]] = None, predict_transform: Optional[Dict[str, Callable]] = None, data_fetcher: Optional[BaseDataFetcher] = None, - preprocess: Optional[Preprocess] = None, + input_transform: Optional[InputTransform] = None, val_split: Optional[float] = None, batch_size: int = 4, num_workers: int = 0, sampler: Optional[Type[Sampler]] = None, - **preprocess_kwargs: Any, + **input_transform_kwargs: Any, ) -> "DataModule": """Creates a :class:`~flash.core.data.data_module.DataModule` object from the given datasets using the :class:`~flash.core.data.data_source.DataSource` of name :attr:`~flash.core.data.data_source.DefaultDataSources.DATASETS` - from the passed or constructed :class:`~flash.core.data.process.Preprocess`. + from the passed or constructed :class:`~flash.core.data.io.input_transform.InputTransform`. Args: train_dataset: Dataset used during training. 
@@ -1140,24 +1141,24 @@ def from_datasets( test_dataset: Dataset used during testing. predict_dataset: Dataset used during predicting. train_transform: The dictionary of transforms to use during training which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. val_transform: The dictionary of transforms to use during validation which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. test_transform: The dictionary of transforms to use during testing which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. predict_transform: The dictionary of transforms to use during predicting which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to pass to the :class:`~flash.core.data.data_module.DataModule`. - preprocess: The :class:`~flash.core.data.data.Preprocess` to pass to the - :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls`` + input_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` to pass to the + :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.input_transform_cls`` will be constructed and used. val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. sampler: The ``sampler`` to use for the ``train_dataloader``. - preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used - if ``preprocess = None``. + input_transform_kwargs: Additional keyword arguments to use when constructing the input_transform. + Will only be used if ``input_transform = None``. Returns: The constructed data module. @@ -1182,12 +1183,12 @@ def from_datasets( test_transform=test_transform, predict_transform=predict_transform, data_fetcher=data_fetcher, - preprocess=preprocess, + input_transform=input_transform, val_split=val_split, batch_size=batch_size, num_workers=num_workers, sampler=sampler, - **preprocess_kwargs, + **input_transform_kwargs, ) @classmethod @@ -1203,17 +1204,17 @@ def from_fiftyone( test_transform: Optional[Dict[str, Callable]] = None, predict_transform: Optional[Dict[str, Callable]] = None, data_fetcher: Optional[BaseDataFetcher] = None, - preprocess: Optional[Preprocess] = None, + input_transform: Optional[InputTransform] = None, val_split: Optional[float] = None, batch_size: int = 4, num_workers: int = 0, - **preprocess_kwargs: Any, + **input_transform_kwargs: Any, ) -> "DataModule": """Creates a :class:`~flash.core.data.data_module.DataModule` object from the given FiftyOne Datasets using the :class:`~flash.core.data.data_source.DataSource` of name :attr:`~flash.core.data.data_source.DefaultDataSources.FIFTYONE` - from the passed or constructed :class:`~flash.core.data.process.Preprocess`. 
+ from the passed or constructed :class:`~flash.core.data.io.input_transform.InputTransform`. Args: train_dataset: The ``fiftyone.core.collections.SampleCollection`` containing the train data. @@ -1221,23 +1222,23 @@ def from_fiftyone( test_dataset: The ``fiftyone.core.collections.SampleCollection`` containing the test data. predict_dataset: The ``fiftyone.core.collections.SampleCollection`` containing the predict data. train_transform: The dictionary of transforms to use during training which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. val_transform: The dictionary of transforms to use during validation which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. test_transform: The dictionary of transforms to use during testing which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. predict_transform: The dictionary of transforms to use during predicting which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to pass to the :class:`~flash.core.data.data_module.DataModule`. - preprocess: The :class:`~flash.core.data.data.Preprocess` to pass to the - :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls`` + input_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` to pass to the + :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.input_transform_cls`` will be constructed and used. val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. - preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used - if ``preprocess = None``. + input_transform_kwargs: Additional keyword arguments to use when constructing the input_transform. + Will only be used if ``input_transform = None``. Returns: The constructed data module. 
@@ -1266,11 +1267,11 @@ def from_fiftyone( test_transform=test_transform, predict_transform=predict_transform, data_fetcher=data_fetcher, - preprocess=preprocess, + input_transform=input_transform, val_split=val_split, batch_size=batch_size, num_workers=num_workers, - **preprocess_kwargs, + **input_transform_kwargs, ) @classmethod @@ -1291,17 +1292,17 @@ def from_labelstudio( test_transform: Optional[Dict[str, Callable]] = None, predict_transform: Optional[Dict[str, Callable]] = None, data_fetcher: Optional[BaseDataFetcher] = None, - preprocess: Optional[Preprocess] = None, + input_transform: Optional[InputTransform] = None, val_split: Optional[float] = None, batch_size: int = 4, num_workers: Optional[int] = None, - **preprocess_kwargs: Any, + **input_transform_kwargs: Any, ) -> "DataModule": """Creates a :class:`~flash.core.data.data_module.DataModule` object from the given export file and data directory using the :class:`~flash.core.data.data_source.DataSource` of name :attr:`~flash.core.data.data_source.DefaultDataSources.FOLDERS` - from the passed or constructed :class:`~flash.core.data.process.Preprocess`. + from the passed or constructed :class:`~flash.core.data.io.input_transform.InputTransform`. Args: export_json: path to label studio export file @@ -1317,23 +1318,23 @@ def from_labelstudio( test_data_folder: path to label studio data folder for test data predict_data_folder: path to label studio data folder for predict data train_transform: The dictionary of transforms to use during training which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. val_transform: The dictionary of transforms to use during validation which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. test_transform: The dictionary of transforms to use during testing which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. predict_transform: The dictionary of transforms to use during predicting which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to pass to the :class:`~flash.core.data.data_module.DataModule`. - preprocess: The :class:`~flash.core.data.data.Preprocess` to pass to the - :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls`` + input_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` to pass to the + :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.input_transform_cls`` will be constructed and used. val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. - preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used - if ``preprocess = None``. 
+ input_transform_kwargs: Additional keyword arguments to use when constructing the input_transform. + Will only be used if ``input_transform = None``. Returns: The constructed data module. @@ -1350,7 +1351,7 @@ def from_labelstudio( "data_folder": data_folder, "export_json": export_json, "split": val_split, - "multi_label": preprocess_kwargs.get("multi_label", False), + "multi_label": input_transform_kwargs.get("multi_label", False), } train_data = None val_data = None @@ -1360,25 +1361,25 @@ def from_labelstudio( train_data = { "data_folder": train_data_folder or data_folder, "export_json": train_export_json, - "multi_label": preprocess_kwargs.get("multi_label", False), + "multi_label": input_transform_kwargs.get("multi_label", False), } if (val_data_folder or data_folder) and val_export_json: val_data = { "data_folder": val_data_folder or data_folder, "export_json": val_export_json, - "multi_label": preprocess_kwargs.get("multi_label", False), + "multi_label": input_transform_kwargs.get("multi_label", False), } if (test_data_folder or data_folder) and test_export_json: test_data = { "data_folder": test_data_folder or data_folder, "export_json": test_export_json, - "multi_label": preprocess_kwargs.get("multi_label", False), + "multi_label": input_transform_kwargs.get("multi_label", False), } if (predict_data_folder or data_folder) and predict_export_json: predict_data = { "data_folder": predict_data_folder or data_folder, "export_json": predict_export_json, - "multi_label": preprocess_kwargs.get("multi_label", False), + "multi_label": input_transform_kwargs.get("multi_label", False), } return cls.from_data_source( DefaultDataSources.LABELSTUDIO, @@ -1391,9 +1392,9 @@ def from_labelstudio( test_transform=test_transform, predict_transform=predict_transform, data_fetcher=data_fetcher, - preprocess=preprocess, + input_transform=input_transform, val_split=val_split, batch_size=batch_size, num_workers=num_workers, - **preprocess_kwargs, + **input_transform_kwargs, ) diff --git a/flash/core/data/data_pipeline.py b/flash/core/data/data_pipeline.py index 7f10fbeab0..b671f1347a 100644 --- a/flash/core/data/data_pipeline.py +++ b/flash/core/data/data_pipeline.py @@ -24,13 +24,19 @@ import flash from flash.core.data.auto_dataset import IterableAutoDataset -from flash.core.data.batch import _DeserializeProcessor, _Preprocessor, _Sequential +from flash.core.data.batch import _DeserializeProcessor from flash.core.data.data_source import DataSource +from flash.core.data.io.input_transform import ( + _InputTransformProcessor, + _InputTransformSequential, + DefaultInputTransform, + InputTransform, +) from flash.core.data.io.output import _OutputProcessor, Output from flash.core.data.io.output_transform import _OutputTransformProcessor, OutputTransform -from flash.core.data.process import DefaultPreprocess, Deserializer, Preprocess +from flash.core.data.process import Deserializer from flash.core.data.properties import ProcessState -from flash.core.data.utils import _OUTPUT_TRANSFORM_FUNCS, _PREPROCESS_FUNCS, _STAGES_PREFIX +from flash.core.data.utils import _INPUT_TRANSFORM_FUNCS, _OUTPUT_TRANSFORM_FUNCS, _STAGES_PREFIX from flash.core.utilities.imports import _PL_GREATER_EQUAL_1_4_3, _PL_GREATER_EQUAL_1_5_0 from flash.core.utilities.stages import _RUNNING_STAGE_MAPPING, RunningStage @@ -77,37 +83,38 @@ def __str__(self) -> str: class DataPipeline: """ DataPipeline holds the engineering logic to connect - :class:`~flash.core.data.process.Preprocess` and/or 
:class:`~flash.core.data.io.output_transform.OutputTransform` + :class:`~flash.core.data.io.input_transform.InputTransform` and/or + :class:`~flash.core.data.io.output_transform.OutputTransform` objects to the ``DataModule``, Flash ``Task`` and ``Trainer``. """ - PREPROCESS_FUNCS: Set[str] = _PREPROCESS_FUNCS + INPUT_TRANSFORM_FUNCS: Set[str] = _INPUT_TRANSFORM_FUNCS OUTPUT_TRANSFORM_FUNCS: Set[str] = _OUTPUT_TRANSFORM_FUNCS def __init__( self, data_source: Optional[DataSource] = None, - preprocess: Optional[Preprocess] = None, + input_transform: Optional[InputTransform] = None, output_transform: Optional[OutputTransform] = None, deserializer: Optional[Deserializer] = None, output: Optional[Output] = None, ) -> None: self.data_source = data_source - self._preprocess_pipeline = preprocess or DefaultPreprocess() + self._input_transform_pipeline = input_transform or DefaultInputTransform() self._output_transform = output_transform or OutputTransform() self._output = output or Output() self._deserializer = deserializer or Deserializer() self._running_stage = None def initialize(self, data_pipeline_state: Optional[DataPipelineState] = None) -> DataPipelineState: - """Creates the :class:`.DataPipelineState` and gives the reference to the: :class:`.Preprocess`, + """Creates the :class:`.DataPipelineState` and gives the reference to the: :class:`.InputTransform`, :class:`.OutputTransform`, and :class:`.Output`. Once this has been called, any attempt to add new state will give a warning.""" data_pipeline_state = data_pipeline_state or DataPipelineState() if self.data_source is not None: self.data_source.attach_data_pipeline_state(data_pipeline_state) - self._preprocess_pipeline.attach_data_pipeline_state(data_pipeline_state) + self._input_transform_pipeline.attach_data_pipeline_state(data_pipeline_state) self._output_transform.attach_data_pipeline_state(data_pipeline_state) self._output.attach_data_pipeline_state(data_pipeline_state) return data_pipeline_state @@ -155,15 +162,17 @@ def _identity(samples: Sequence[Any]) -> Sequence[Any]: return samples def deserialize_processor(self) -> _DeserializeProcessor: - return self._create_collate_preprocessors(RunningStage.PREDICTING)[0] + return self._create_collate_input_transform_processors(RunningStage.PREDICTING)[0] - def worker_preprocessor( + def worker_input_transform_processor( self, running_stage: RunningStage, collate_fn: Optional[Callable] = None, is_serving: bool = False - ) -> _Preprocessor: - return self._create_collate_preprocessors(running_stage, collate_fn=collate_fn, is_serving=is_serving)[1] + ) -> _InputTransformProcessor: + return self._create_collate_input_transform_processors( + running_stage, collate_fn=collate_fn, is_serving=is_serving + )[1] - def device_preprocessor(self, running_stage: RunningStage) -> _Preprocessor: - return self._create_collate_preprocessors(running_stage)[2] + def device_input_transform_processor(self, running_stage: RunningStage) -> _InputTransformProcessor: + return self._create_collate_input_transform_processors(running_stage)[2] def output_transform_processor(self, running_stage: RunningStage, is_serving=False) -> _OutputTransformProcessor: return self._create_output_transform_processor(running_stage, is_serving=is_serving) @@ -176,7 +185,7 @@ def _resolve_function_hierarchy( cls, function_name, process_obj, stage: RunningStage, object_type: Optional[Type] = None ) -> str: if object_type is None: - object_type = Preprocess + object_type = InputTransform prefixes = [] @@ -202,37 +211,38 @@ def 
_make_collates(self, on_device: bool, collate: Callable) -> Tuple[Callable, return self._identity, collate return collate, self._identity - def _create_collate_preprocessors( + def _create_collate_input_transform_processors( self, stage: RunningStage, collate_fn: Optional[Callable] = None, is_serving: bool = False, - ) -> Tuple[_DeserializeProcessor, _Preprocessor, _Preprocessor]: + ) -> Tuple[_DeserializeProcessor, _InputTransformProcessor, _InputTransformProcessor]: original_collate_fn = collate_fn - preprocess: Preprocess = self._preprocess_pipeline + input_transform: InputTransform = self._input_transform_pipeline prefix: str = _STAGES_PREFIX[stage] if collate_fn is not None: - preprocess._default_collate = collate_fn + input_transform._default_collate = collate_fn func_names: Dict[str, str] = { - k: self._resolve_function_hierarchy(k, preprocess, stage, Preprocess) for k in self.PREPROCESS_FUNCS + k: self._resolve_function_hierarchy(k, input_transform, stage, InputTransform) + for k in self.INPUT_TRANSFORM_FUNCS } - collate_fn: Callable = getattr(preprocess, func_names["collate"]) + collate_fn: Callable = getattr(input_transform, func_names["collate"]) per_batch_transform_overriden: bool = self._is_overriden_recursive( - "per_batch_transform", preprocess, Preprocess, prefix=prefix + "per_batch_transform", input_transform, InputTransform, prefix=prefix ) per_sample_transform_on_device_overriden: bool = self._is_overriden_recursive( - "per_sample_transform_on_device", preprocess, Preprocess, prefix=prefix + "per_sample_transform_on_device", input_transform, InputTransform, prefix=prefix ) collate_in_worker_from_transform: Optional[bool] = getattr( - preprocess, f"_{prefix}_collate_in_worker_from_transform", None + input_transform, f"_{prefix}_collate_in_worker_from_transform", None ) is_per_overriden = per_batch_transform_overriden and per_sample_transform_on_device_overriden @@ -250,53 +260,55 @@ def _create_collate_preprocessors( ) worker_collate_fn = ( - worker_collate_fn.collate_fn if isinstance(worker_collate_fn, _Preprocessor) else worker_collate_fn + worker_collate_fn.collate_fn + if isinstance(worker_collate_fn, _InputTransformProcessor) + else worker_collate_fn ) assert_contains_tensor = self._is_overriden_recursive( - "to_tensor_transform", preprocess, Preprocess, prefix=_STAGES_PREFIX[stage] + "to_tensor_transform", input_transform, InputTransform, prefix=_STAGES_PREFIX[stage] ) deserialize_processor = _DeserializeProcessor( self._deserializer, - preprocess, - getattr(preprocess, func_names["pre_tensor_transform"]), - getattr(preprocess, func_names["to_tensor_transform"]), + input_transform, + getattr(input_transform, func_names["pre_tensor_transform"]), + getattr(input_transform, func_names["to_tensor_transform"]), ) - worker_preprocessor = _Preprocessor( - preprocess, + worker_input_transform_processor = _InputTransformProcessor( + input_transform, worker_collate_fn, - _Sequential( - preprocess, - None if is_serving else getattr(preprocess, func_names["pre_tensor_transform"]), - None if is_serving else getattr(preprocess, func_names["to_tensor_transform"]), - getattr(preprocess, func_names["post_tensor_transform"]), + _InputTransformSequential( + input_transform, + None if is_serving else getattr(input_transform, func_names["pre_tensor_transform"]), + None if is_serving else getattr(input_transform, func_names["to_tensor_transform"]), + getattr(input_transform, func_names["post_tensor_transform"]), stage, assert_contains_tensor=assert_contains_tensor, ), - 
getattr(preprocess, func_names["per_batch_transform"]), + getattr(input_transform, func_names["per_batch_transform"]), stage, ) - worker_preprocessor._original_collate_fn = original_collate_fn - device_preprocessor = _Preprocessor( - preprocess, + worker_input_transform_processor._original_collate_fn = original_collate_fn + device_input_transform_processor = _InputTransformProcessor( + input_transform, device_collate_fn, - getattr(preprocess, func_names["per_sample_transform_on_device"]), - getattr(preprocess, func_names["per_batch_transform_on_device"]), + getattr(input_transform, func_names["per_sample_transform_on_device"]), + getattr(input_transform, func_names["per_batch_transform_on_device"]), stage, apply_per_sample_transform=device_collate_fn != self._identity, on_device=True, ) - return deserialize_processor, worker_preprocessor, device_preprocessor + return deserialize_processor, worker_input_transform_processor, device_input_transform_processor @staticmethod def _model_transfer_to_device_wrapper( - func: Callable, preprocessor: _Preprocessor, model: "Task", stage: RunningStage + func: Callable, input_transform: _InputTransformProcessor, model: "Task", stage: RunningStage ) -> Callable: if not isinstance(func, _StageOrchestrator): func = _StageOrchestrator(func, model) - func.register_additional_stage(stage, preprocessor) + func.register_additional_stage(stage, input_transform) return func @@ -368,7 +380,7 @@ def _set_loader(model: "Task", loader_name: str, new_loader: DataLoader) -> None setattr(curr_attr, final_name, new_loader) setattr(model, final_name, new_loader) - def _attach_preprocess_to_model( + def _attach_input_transform_to_model( self, model: "Task", stage: Optional[RunningStage] = None, @@ -409,7 +421,7 @@ def _attach_preprocess_to_model( if isinstance(loader, DataLoader): dl_args = {k: v for k, v in vars(loader).items() if not k.startswith("_")} - _, dl_args["collate_fn"], device_collate_fn = self._create_collate_preprocessors( + _, dl_args["collate_fn"], device_collate_fn = self._create_collate_input_transform_processors( stage=stage, collate_fn=dl_args["collate_fn"], is_serving=is_serving ) @@ -474,18 +486,18 @@ def _attach_to_model( is_serving: bool = False, ): # not necessary to detach. preprocessing and postprocessing for stage will be overwritten. 
- self._attach_preprocess_to_model(model, stage) + self._attach_input_transform_to_model(model, stage) if not stage or stage == RunningStage.PREDICTING: self._attach_output_transform_to_model(model, RunningStage.PREDICTING, is_serving=is_serving) def _detach_from_model(self, model: "Task", stage: Optional[RunningStage] = None): - self._detach_preprocessing_from_model(model, stage) + self._detach_input_transform_from_model(model, stage) if not stage or stage == RunningStage.PREDICTING: self._detach_output_transform_from_model(model) - def _detach_preprocessing_from_model(self, model: "Task", stage: Optional[RunningStage] = None): + def _detach_input_transform_from_model(self, model: "Task", stage: Optional[RunningStage] = None): if not stage: stages = [RunningStage.TRAINING, RunningStage.VALIDATING, RunningStage.TESTING, RunningStage.PREDICTING] elif isinstance(stage, RunningStage): @@ -530,7 +542,7 @@ def _detach_preprocessing_from_model(self, model: "Task", stage: Optional[Runnin if default_collate: dl_args["collate_fn"] = default_collate - if isinstance(dl_args["collate_fn"], _Preprocessor): + if isinstance(dl_args["collate_fn"], _InputTransformProcessor): dl_args["collate_fn"] = dl_args["collate_fn"]._original_collate_fn if isinstance(dl_args["dataset"], (IterableAutoDataset, IterableDataset)): @@ -559,7 +571,7 @@ def _detach_output_transform_from_model(model: "Task"): def __str__(self) -> str: data_source: DataSource = self.data_source - preprocess: Preprocess = self._preprocess_pipeline + input_transform: InputTransform = self._input_transform_pipeline output_transform: OutputTransform = self._output_transform output: Output = self._output deserializer: Deserializer = self._deserializer @@ -567,7 +579,7 @@ def __str__(self) -> str: f"{self.__class__.__name__}(" f"data_source={str(data_source)}, " f"deserializer={deserializer}, " - f"preprocess={preprocess}, " + f"input_transform={input_transform}, " f"output_transform={output_transform}, " f"output={output})" ) diff --git a/flash/core/data/data_source.py b/flash/core/data/data_source.py index cf279dcbba..bd875e81ac 100644 --- a/flash/core/data/data_source.py +++ b/flash/core/data/data_source.py @@ -266,7 +266,7 @@ def load_sample(sample: Mapping[str, Any], dataset: Optional[Any] = None) -> Any Returns: The loaded sample as a mapping with string keys (e.g. "input", "target") that can be processed by the - :meth:`~flash.core.data.process.Preprocess.pre_tensor_transform`. + :meth:`~flash.core.data.io.input_transform.InputTransform.pre_tensor_transform`. 
Example:: diff --git a/flash/core/data/input_transform.py b/flash/core/data/input_transform.py index a9853ecfda..99604c355d 100644 --- a/flash/core/data/input_transform.py +++ b/flash/core/data/input_transform.py @@ -19,11 +19,12 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException from torch.utils.data._utils.collate import default_collate -from flash.core.data.data_pipeline import _Preprocessor, DataPipeline +from flash.core.data.data_pipeline import DataPipeline from flash.core.data.data_source import DefaultDataKeys +from flash.core.data.io.input_transform import _InputTransformProcessor from flash.core.data.properties import Properties from flash.core.data.states import CollateFn -from flash.core.data.utils import _PREPROCESS_FUNCS, _STAGES_PREFIX +from flash.core.data.utils import _INPUT_TRANSFORM_FUNCS, _STAGES_PREFIX from flash.core.registry import FlashRegistry from flash.core.utilities.stages import RunningStage @@ -119,7 +120,8 @@ def current_transform(self) -> Callable: @property def transforms(self) -> Dict[str, Optional[Dict[str, Callable]]]: - """The transforms currently being used by this :class:`~flash.core.data.process.Preprocess`.""" + """The transforms currently being used by this + :class:`~flash.core.data.io.input_transform.InputTransform`.""" return { "transform": self.transform, } @@ -288,7 +290,7 @@ def _check_transforms( if len(keys_diff) > 0: raise MisconfigurationException( - f"{stage}_transform contains {keys_diff}. Only {_PREPROCESS_FUNCS} keys are supported." + f"{stage}_transform contains {keys_diff}. Only {_INPUT_TRANSFORM_FUNCS} keys are supported." ) is_per_batch_transform_in = "per_batch_transform" in transform @@ -354,14 +356,14 @@ def _make_collates(self, on_device: bool, collate: Callable) -> Tuple[Callable, @property def dataloader_collate_fn(self): """Generate the function to be injected within the DataLoader as the collate_fn.""" - return self._create_collate_preprocessors()[0] + return self._create_collate_input_transform_processors()[0] @property def on_after_batch_transfer_fn(self): """Generate the function to be injected after the on_after_batch_transfer from the LightningModule.""" - return self._create_collate_preprocessors()[1] + return self._create_collate_input_transform_processors()[1] - def _create_collate_preprocessors(self) -> Tuple[Any]: + def _create_collate_input_transform_processors(self) -> Tuple[Any]: prefix: str = _STAGES_PREFIX[self.running_stage] func_names: Dict[str, str] = { @@ -396,17 +398,19 @@ def _create_collate_preprocessors(self) -> Tuple[Any]: ) worker_collate_fn = ( - worker_collate_fn.collate_fn if isinstance(worker_collate_fn, _Preprocessor) else worker_collate_fn + worker_collate_fn.collate_fn + if isinstance(worker_collate_fn, _InputTransformProcessor) + else worker_collate_fn ) - worker_preprocessor = _Preprocessor( + worker_input_transform_processor = _InputTransformProcessor( self, worker_collate_fn, getattr(self, func_names["per_sample_transform"]), getattr(self, func_names["per_batch_transform"]), self.running_stage, ) - device_preprocessor = _Preprocessor( + device_input_transform_processor = _InputTransformProcessor( self, device_collate_fn, getattr(self, func_names["per_sample_transform_on_device"]), @@ -415,4 +419,4 @@ def _create_collate_preprocessors(self) -> Tuple[Any]: apply_per_sample_transform=device_collate_fn != self._identity, on_device=True, ) - return worker_preprocessor, device_preprocessor + return worker_input_transform_processor, device_input_transform_processor 
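To illustrate the renamed API in practice, here is a minimal usage sketch (not part of the patch). It mirrors the hook mapping described in the new ``InputTransform`` docstring below; the ``ImageClassificationData`` data module, the folder paths, and the torchvision transforms are assumptions chosen only for illustration, while the hook names, ``default_transforms``/``train_default_transforms``, and the ``input_transform`` argument come from this change::

    from typing import Callable, Mapping

    from torch.utils.data._utils.collate import default_collate
    from torchvision import transforms as T

    from flash.core.data.io.input_transform import InputTransform


    class CustomInputTransform(InputTransform):
        """Maps InputTransform hook names to callables (previously a ``Preprocess`` subclass)."""

        @staticmethod
        def default_transforms() -> Mapping[str, Callable]:
            # Used for every stage unless a stage-specific default is provided.
            return {
                "to_tensor_transform": T.ToTensor(),
                "collate": default_collate,
            }

        @staticmethod
        def train_default_transforms() -> Mapping[str, Callable]:
            # Training-only augmentation runs in ``pre_tensor_transform``.
            return {
                "pre_tensor_transform": T.RandomHorizontalFlip(),
                "to_tensor_transform": T.ToTensor(),
                "collate": default_collate,
            }


    # Hypothetical wiring with a task-specific DataModule (e.g. ImageClassificationData);
    # before this change the keyword argument was ``preprocess=...``:
    #
    # datamodule = ImageClassificationData.from_folders(
    #     train_folder="data/train",
    #     val_folder="data/val",
    #     input_transform=CustomInputTransform(),
    #     batch_size=4,
    # )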
diff --git a/flash/core/data/io/input_transform.py b/flash/core/data/io/input_transform.py new file mode 100644 index 0000000000..ae759ed929 --- /dev/null +++ b/flash/core/data/io/input_transform.py @@ -0,0 +1,713 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import inspect +from abc import ABC, abstractclassmethod, abstractmethod +from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union + +import torch +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from torch import Tensor +from torch.utils.data._utils.collate import default_collate + +from flash.core.data.callback import ControlFlow, FlashCallback +from flash.core.data.data_source import DatasetDataSource, DataSource, DefaultDataKeys, DefaultDataSources +from flash.core.data.process import Deserializer +from flash.core.data.properties import ProcessState, Properties +from flash.core.data.states import ( + CollateFn, + PerBatchTransform, + PerBatchTransformOnDevice, + PerSampleTransformOnDevice, + PostTensorTransform, + PreTensorTransform, + ToTensorTransform, +) +from flash.core.data.transforms import ApplyToKeys +from flash.core.data.utils import ( + _contains_any_tensor, + _INPUT_TRANSFORM_FUNCS, + _STAGES_PREFIX, + convert_to_modules, + CurrentFuncContext, + CurrentRunningStageContext, + CurrentRunningStageFuncContext, +) +from flash.core.utilities.stages import RunningStage + + +class BaseInputTransform(ABC): + @abstractmethod + def get_state_dict(self) -> Dict[str, Any]: + """Override this method to return state_dict.""" + + @abstractclassmethod + def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = False): + """Override this method to load from state_dict.""" + + +class InputTransform(BaseInputTransform, Properties): + """The :class:`~flash.core.data.io.input_transform.InputTransform` encapsulates all the data processing logic + that should run before the data is passed to the model. It is particularly useful when you want to provide an + end to end implementation which works with 4 different stages: ``train``, ``validation``, ``test``, and + inference (``predict``). + + The :class:`~flash.core.data.io.input_transform.InputTransform` supports the following hooks: + + - ``pre_tensor_transform``: Performs transforms on a single data sample. + Example:: + + * Input: Receive a PIL Image and its label. + + * Action: Rotate the PIL Image. + + * Output: Return the rotated PIL image and its label. + + - ``to_tensor_transform``: Converts a single data sample to a tensor / data structure containing tensors. + Example:: + + * Input: Receive the rotated PIL Image and its label. + + * Action: Convert the rotated PIL Image to a tensor. + + * Output: Return the tensored image and its label. + + - ``post_tensor_transform``: Performs transform on a single tensor sample. + Example:: + + * Input: Receive the tensored image and its label. + + * Action: Flip the tensored image randomly. 
+ + * Output: Return the tensored image and its label. + + - ``per_batch_transform``: Performs transforms on a batch. + In this example, we decided not to override the hook. + + - ``per_sample_transform_on_device``: Performs transform on a sample already on a ``GPU`` or ``TPU``. + Example:: + + * Input: Receive a tensored image on device and its label. + + * Action: Apply random transforms. + + * Output: Return an augmented tensored image on device and its label. + + - ``collate``: Converts a sequence of data samples into a batch. + Defaults to ``torch.utils.data._utils.collate.default_collate``. + Example:: + + * Input: Receive a list of augmented tensored images and their respective labels. + + * Action: Collate the list of images into batch. + + * Output: Return a batch of images and their labels. + + - ``per_batch_transform_on_device``: Performs transform on a batch already on ``GPU`` or ``TPU``. + Example:: + + * Input: Receive a batch of images and their labels. + + * Action: Apply normalization on the batch by subtracting the mean + and dividing by the standard deviation from ImageNet. + + * Output: Return a normalized augmented batch of images and their labels. + + .. note:: + + The ``per_sample_transform_on_device`` and ``per_batch_transform`` are mutually exclusive + as it will impact performances. + + Data processing can be configured by overriding hooks or through transforms. The input transforms are given as + a mapping from hook names to callables. Default transforms can be configured by overriding the + ``default_transforms`` or ``{train,val,test,predict}_default_transforms`` methods. These can then be overridden by + the user with the ``{train,val,test,predict}_transform`` arguments to the ``InputTransform``. + All of the hooks can be used in the transform mappings. + + Example:: + + class CustomInputTransform(InputTransform): + + def default_transforms() -> Mapping[str, Callable]: + return { + "to_tensor_transform": transforms.ToTensor(), + "collate": torch.utils.data._utils.collate.default_collate, + } + + def train_default_transforms() -> Mapping[str, Callable]: + return { + "pre_tensor_transform": transforms.RandomHorizontalFlip(), + "to_tensor_transform": transforms.ToTensor(), + "collate": torch.utils.data._utils.collate.default_collate, + } + + When overriding hooks for particular stages, you can prefix with ``train``, ``val``, ``test`` or ``predict``. For + example, you can achieve the same as the above example by implementing ``train_pre_tensor_transform`` and + ``train_to_tensor_transform``. + + Example:: + + class CustomInputTransform(InputTransform): + + def train_pre_tensor_transform(self, sample: PIL.Image) -> PIL.Image: + return transforms.RandomHorizontalFlip()(sample) + + def to_tensor_transform(self, sample: PIL.Image) -> torch.Tensor: + return transforms.ToTensor()(sample) + + def collate(self, samples: List[torch.Tensor]) -> torch.Tensor: + return torch.utils.data._utils.collate.default_collate(samples) + + Each hook is aware of the Trainer running stage through booleans. These are useful for adapting functionality for a + stage without duplicating code. 
+ + Example:: + + class CustomInputTransform(InputTransform): + + def pre_tensor_transform(self, sample: PIL.Image) -> PIL.Image: + + if self.training: + # logic for training + + elif self.validating: + # logic for validation + + elif self.testing: + # logic for testing + + elif self.predicting: + # logic for predicting + """ + + def __init__( + self, + train_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None, + val_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None, + test_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None, + predict_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None, + data_sources: Optional[Dict[str, "DataSource"]] = None, + deserializer: Optional["Deserializer"] = None, + default_data_source: Optional[str] = None, + ): + super().__init__() + + # resolve the default transforms + train_transform = train_transform or self._resolve_transforms(RunningStage.TRAINING) + val_transform = val_transform or self._resolve_transforms(RunningStage.VALIDATING) + test_transform = test_transform or self._resolve_transforms(RunningStage.TESTING) + predict_transform = predict_transform or self._resolve_transforms(RunningStage.PREDICTING) + + # used to keep track of provided transforms + self._train_collate_in_worker_from_transform: Optional[bool] = None + self._val_collate_in_worker_from_transform: Optional[bool] = None + self._predict_collate_in_worker_from_transform: Optional[bool] = None + self._test_collate_in_worker_from_transform: Optional[bool] = None + + # store the transform before conversion to modules. + self.train_transform = self._check_transforms(train_transform, RunningStage.TRAINING) + self.val_transform = self._check_transforms(val_transform, RunningStage.VALIDATING) + self.test_transform = self._check_transforms(test_transform, RunningStage.TESTING) + self.predict_transform = self._check_transforms(predict_transform, RunningStage.PREDICTING) + + self._train_transform = convert_to_modules(self.train_transform) + self._val_transform = convert_to_modules(self.val_transform) + self._test_transform = convert_to_modules(self.test_transform) + self._predict_transform = convert_to_modules(self.predict_transform) + + if DefaultDataSources.DATASETS not in data_sources: + data_sources[DefaultDataSources.DATASETS] = DatasetDataSource() + + self._data_sources = data_sources + self._deserializer = deserializer + self._default_data_source = default_data_source + self._callbacks: List[FlashCallback] = [] + self._default_collate: Callable = default_collate + + @property + def deserializer(self) -> Optional["Deserializer"]: + return self._deserializer + + def _resolve_transforms(self, running_stage: RunningStage) -> Optional[Dict[str, Callable]]: + from flash.core.data.data_pipeline import DataPipeline + + resolved_function = getattr( + self, DataPipeline._resolve_function_hierarchy("default_transforms", self, running_stage, InputTransform) + ) + + with CurrentRunningStageFuncContext(running_stage, "default_transforms", self): + transforms: Optional[Dict[str, Callable]] = resolved_function() + return transforms + + def _save_to_state_dict(self, destination, prefix, keep_vars): + input_transform_state_dict = self.get_state_dict() + if not isinstance(input_transform_state_dict, Dict): + raise MisconfigurationException("get_state_dict should return a dictionary") + input_transform_state_dict["_meta"] = {} + input_transform_state_dict["_meta"]["module"] = self.__module__ + 
input_transform_state_dict["_meta"]["class_name"] = self.__class__.__name__ + input_transform_state_dict["_meta"]["_state"] = self._state + destination["input_transform.state_dict"] = input_transform_state_dict + self._ddp_params_and_buffers_to_ignore = ["input_transform.state_dict"] + return super()._save_to_state_dict(destination, prefix, keep_vars) + + def _check_transforms( + self, transform: Optional[Dict[str, Callable]], stage: RunningStage + ) -> Optional[Dict[str, Callable]]: + if transform is None: + return transform + + if isinstance(transform, list): + transform = {"pre_tensor_transform": ApplyToKeys(DefaultDataKeys.INPUT, torch.nn.Sequential(*transform))} + elif callable(transform): + transform = {"pre_tensor_transform": ApplyToKeys(DefaultDataKeys.INPUT, transform)} + + if not isinstance(transform, Dict): + raise MisconfigurationException( + "Transform should be a dict. " + f"Here are the available keys for your transforms: {_INPUT_TRANSFORM_FUNCS}." + ) + + keys_diff = set(transform.keys()).difference(_INPUT_TRANSFORM_FUNCS) + + if len(keys_diff) > 0: + raise MisconfigurationException( + f"{stage}_transform contains {keys_diff}. Only {_INPUT_TRANSFORM_FUNCS} keys are supported." + ) + + is_per_batch_transform_in = "per_batch_transform" in transform + is_per_sample_transform_on_device_in = "per_sample_transform_on_device" in transform + + if is_per_batch_transform_in and is_per_sample_transform_on_device_in: + raise MisconfigurationException( + f"{transform}: `per_batch_transform` and `per_sample_transform_on_device` are mutually exclusive." + ) + + collate_in_worker: Optional[bool] = None + + if is_per_batch_transform_in or (not is_per_batch_transform_in and not is_per_sample_transform_on_device_in): + collate_in_worker = True + + elif is_per_sample_transform_on_device_in: + collate_in_worker = False + + setattr(self, f"_{_STAGES_PREFIX[stage]}_collate_in_worker_from_transform", collate_in_worker) + return transform + + @staticmethod + def _identity(x: Any) -> Any: + return x + + def _get_transform(self, transform: Dict[str, Callable]) -> Callable: + if self.current_fn in transform: + return transform[self.current_fn] + return self._identity + + @property + def current_transform(self) -> Callable: + if self.training and self._train_transform: + return self._get_transform(self._train_transform) + if self.validating and self._val_transform: + return self._get_transform(self._val_transform) + if self.testing and self._test_transform: + return self._get_transform(self._test_transform) + if self.predicting and self._predict_transform: + return self._get_transform(self._predict_transform) + return self._identity + + @property + def transforms(self) -> Dict[str, Optional[Dict[str, Callable]]]: + """The transforms currently being used by this + :class:`~flash.core.data.io.input_transform.InputTransform`.""" + return { + "train_transform": self.train_transform, + "val_transform": self.val_transform, + "test_transform": self.test_transform, + "predict_transform": self.predict_transform, + } + + @property + def callbacks(self) -> List["FlashCallback"]: + if not hasattr(self, "_callbacks"): + self._callbacks: List[FlashCallback] = [] + return self._callbacks + + @callbacks.setter + def callbacks(self, callbacks: List["FlashCallback"]): + self._callbacks = callbacks + + def add_callbacks(self, callbacks: List["FlashCallback"]): + _callbacks = [c for c in callbacks if c not in self._callbacks] + self._callbacks.extend(_callbacks) + + @staticmethod + def default_transforms() -> 
Optional[Dict[str, Callable]]: + """The default transforms to use. + + Will be overridden by transforms passed to the ``__init__``. + """ + + def _apply_sample_transform(self, sample: Any) -> Any: + if isinstance(sample, list): + return [self.current_transform(s) for s in sample] + return self.current_transform(sample) + + def _apply_batch_transform(self, batch: Any): + return self.current_transform(batch) + + def _apply_transform_on_sample(self, sample: Any, transform: Callable): + if isinstance(sample, list): + return [transform(s) for s in sample] + + return transform(sample) + + def _apply_transform_on_batch(self, batch: Any, transform: Callable): + return transform(batch) + + def _apply_process_state_transform( + self, + process_state: ProcessState, + sample: Optional[Any] = None, + batch: Optional[Any] = None, + ): + # assert both sample and batch are not None + if sample is None: + assert batch is not None, "sample not provided, batch should not be None" + mode = "batch" + else: + assert batch is None, "sample provided, batch should be None" + mode = "sample" + + process_state_transform = self.get_state(process_state) + + if process_state_transform is not None: + if process_state_transform.transform is not None: + if mode == "sample": + return self._apply_transform_on_sample(sample, process_state_transform.transform) + else: + return self._apply_transform_on_batch(batch, process_state_transform.transform) + else: + if mode == "sample": + return sample + else: + return batch + else: + if mode == "sample": + return self._apply_sample_transform(sample) + else: + return self._apply_batch_transform(batch) + + def pre_tensor_transform(self, sample: Any) -> Any: + """Transforms to apply on a single object.""" + return self._apply_process_state_transform(PreTensorTransform, sample=sample) + + def to_tensor_transform(self, sample: Any) -> Tensor: + """Transforms to convert single object to a tensor.""" + return self._apply_process_state_transform(ToTensorTransform, sample=sample) + + def post_tensor_transform(self, sample: Tensor) -> Tensor: + """Transforms to apply on a tensor.""" + return self._apply_process_state_transform(PostTensorTransform, sample=sample) + + def per_batch_transform(self, batch: Any) -> Any: + """Transforms to apply to a whole batch (if possible use this for efficiency). + + .. note:: + + This option is mutually exclusive with :meth:`per_sample_transform_on_device`, + since if both are specified, uncollation has to be applied. + """ + return self._apply_process_state_transform(PerBatchTransform, batch=batch) + + def collate(self, samples: Sequence, metadata=None) -> Any: + """Transform to convert a sequence of samples to a collated batch.""" + current_transform = self.current_transform + if current_transform is self._identity: + current_transform = self._default_collate + + # the model can provide a custom ``collate_fn``. + collate_fn = self.get_state(CollateFn) + if collate_fn is not None: + collate_fn = collate_fn.collate_fn + else: + collate_fn = current_transform + # return collate_fn.collate_fn(samples) + + parameters = inspect.signature(collate_fn).parameters + if len(parameters) > 1 and DefaultDataKeys.METADATA in parameters: + return collate_fn(samples, metadata) + return collate_fn(samples) + + def per_sample_transform_on_device(self, sample: Any) -> Any: + """Transforms to apply to the data before the collation (per-sample basis). + + .. 
note:: + + This option is mutually exclusive with :meth:`per_batch_transform`, + since if both are specified, uncollation has to be applied. + + .. note:: + + This function won't be called within the dataloader workers, since to make that happen + each of the workers would have to create its own CUDA context, which would pollute GPU memory (if on GPU). + """ + return self._apply_process_state_transform(PerSampleTransformOnDevice, sample=sample) + + def per_batch_transform_on_device(self, batch: Any) -> Any: + """Transforms to apply to a whole batch (if possible use this for efficiency). + + .. note:: + + This function won't be called within the dataloader workers, since to make that happen + each of the workers would have to create its own CUDA context, which would pollute GPU memory (if on GPU). + """ + return self._apply_process_state_transform(PerBatchTransformOnDevice, batch=batch) + + def available_data_sources(self) -> Sequence[str]: + """Get the list of available data source names for use with this + :class:`~flash.core.data.io.input_transform.InputTransform`. + + Returns: + The list of data source names. + """ + return list(self._data_sources.keys()) + + def data_source_of_name(self, data_source_name: str) -> DataSource: + """Get the :class:`~flash.core.data.data_source.DataSource` of the given name from the + :class:`~flash.core.data.io.input_transform.InputTransform`. + + Args: + data_source_name: The name of the data source to look up. + + Returns: + The :class:`~flash.core.data.data_source.DataSource` of the given name. + + Raises: + MisconfigurationException: If the requested data source is not configured by this + :class:`~flash.core.data.io.input_transform.InputTransform`. + """ + if data_source_name == "default": + data_source_name = self._default_data_source + data_sources = self._data_sources + if data_source_name in data_sources: + return data_sources[data_source_name] + raise MisconfigurationException( + f"No '{data_source_name}' data source is available for use with the {type(self)}. The available data " + f"sources are: {', '.join(self.available_data_sources())}." + ) + + +class DefaultInputTransform(InputTransform): + def __init__( + self, + train_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None, + val_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None, + test_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None, + predict_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None, + data_sources: Optional[Dict[str, "DataSource"]] = None, + default_data_source: Optional[str] = None, + ): + super().__init__( + train_transform=train_transform, + val_transform=val_transform, + test_transform=test_transform, + predict_transform=predict_transform, + data_sources=data_sources or {"default": DataSource()}, + default_data_source=default_data_source or "default", + ) + + def get_state_dict(self) -> Dict[str, Any]: + return {**self.transforms} + + @classmethod + def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool): + return cls(**state_dict) + + +class _InputTransformSequential(torch.nn.Module): + """This class is used to chain 3 functions together for the _InputTransformProcessor ``per_sample_transform`` + function. + + 1. ``pre_tensor_transform`` + 2. ``to_tensor_transform`` + 3. ``post_tensor_transform``
+ """ + + def __init__( + self, + input_transform: InputTransform, + pre_tensor_transform: Optional[Callable], + to_tensor_transform: Optional[Callable], + post_tensor_transform: Callable, + stage: RunningStage, + assert_contains_tensor: bool = False, + ): + super().__init__() + self.input_transform = input_transform + self.callback = ControlFlow(self.input_transform.callbacks) + self.pre_tensor_transform = convert_to_modules(pre_tensor_transform) + self.to_tensor_transform = convert_to_modules(to_tensor_transform) + self.post_tensor_transform = convert_to_modules(post_tensor_transform) + self.stage = stage + self.assert_contains_tensor = assert_contains_tensor + + self._current_stage_context = CurrentRunningStageContext(stage, input_transform, reset=False) + self._pre_tensor_transform_context = CurrentFuncContext("pre_tensor_transform", input_transform) + self._to_tensor_transform_context = CurrentFuncContext("to_tensor_transform", input_transform) + self._post_tensor_transform_context = CurrentFuncContext("post_tensor_transform", input_transform) + + def forward(self, sample: Any) -> Any: + self.callback.on_load_sample(sample, self.stage) + + with self._current_stage_context: + if self.pre_tensor_transform is not None: + with self._pre_tensor_transform_context: + sample = self.pre_tensor_transform(sample) + self.callback.on_pre_tensor_transform(sample, self.stage) + + if self.to_tensor_transform is not None: + with self._to_tensor_transform_context: + sample = self.to_tensor_transform(sample) + self.callback.on_to_tensor_transform(sample, self.stage) + + if self.assert_contains_tensor: + if not _contains_any_tensor(sample): + raise MisconfigurationException( + "When ``to_tensor_transform`` is overridden, " + "``DataPipeline`` expects the outputs to be ``tensors``" + ) + + with self._post_tensor_transform_context: + sample = self.post_tensor_transform(sample) + self.callback.on_post_tensor_transform(sample, self.stage) + + return sample + + def __str__(self) -> str: + return ( + f"{self.__class__.__name__}:\n" + f"\t(pre_tensor_transform): {str(self.pre_tensor_transform)}\n" + f"\t(to_tensor_transform): {str(self.to_tensor_transform)}\n" + f"\t(post_tensor_transform): {str(self.post_tensor_transform)}\n" + f"\t(assert_contains_tensor): {str(self.assert_contains_tensor)}\n" + f"\t(stage): {str(self.stage)}" + ) + + +class _InputTransformProcessor(torch.nn.Module): + """ + This class is used to encapsulate the following functions of an InputTransform object: + Inside a worker: + per_sample_transform: Function to transform an individual sample + Inside a worker, it is actually made of 3 functions: + * pre_tensor_transform + * to_tensor_transform + * post_tensor_transform + collate: Function to merge samples into a batch + per_batch_transform: Function to transform an individual batch + * per_batch_transform + + Inside main process: + per_sample_transform: Function to transform an individual sample + * per_sample_transform_on_device + collate: Function to merge samples into a batch + per_batch_transform: Function to transform an individual batch + * per_batch_transform_on_device + """ + + def __init__( + self, + input_transform: InputTransform, + collate_fn: Callable, + per_sample_transform: Union[Callable, _InputTransformSequential], + per_batch_transform: Callable, + stage: RunningStage, + apply_per_sample_transform: bool = True, + on_device: bool = False, + ): + super().__init__() + self.input_transform = input_transform + self.callback =
ControlFlow(self.input_transform.callbacks) + self.collate_fn = convert_to_modules(collate_fn) + self.per_sample_transform = convert_to_modules(per_sample_transform) + self.per_batch_transform = convert_to_modules(per_batch_transform) + self.apply_per_sample_transform = apply_per_sample_transform + self.stage = stage + self.on_device = on_device + + extension = f"{'_on_device' if self.on_device else ''}" + self._current_stage_context = CurrentRunningStageContext(stage, input_transform) + self._per_sample_transform_context = CurrentFuncContext(f"per_sample_transform{extension}", input_transform) + self._collate_context = CurrentFuncContext("collate", input_transform) + self._per_batch_transform_context = CurrentFuncContext(f"per_batch_transform{extension}", input_transform) + + @staticmethod + def _extract_metadata( + samples: List[Dict[str, Any]], + ) -> Tuple[List[Dict[str, Any]], Optional[List[Dict[str, Any]]]]: + metadata = [s.pop(DefaultDataKeys.METADATA, None) if isinstance(s, Mapping) else None for s in samples] + return samples, metadata if any(m is not None for m in metadata) else None + + def forward(self, samples: Sequence[Any]) -> Any: + # we create a new dict to prevent from potential memory leaks + # assuming that the dictionary samples are stored in between and + # potentially modified before the transforms are applied. + if isinstance(samples, dict): + samples = dict(samples.items()) + + with self._current_stage_context: + + if self.apply_per_sample_transform: + with self._per_sample_transform_context: + _samples = [] + + if isinstance(samples, Mapping): + samples = [samples] + + for sample in samples: + sample = self.per_sample_transform(sample) + if self.on_device: + self.callback.on_per_sample_transform_on_device(sample, self.stage) + _samples.append(sample) + + samples = type(_samples)(_samples) + + with self._collate_context: + samples, metadata = self._extract_metadata(samples) + try: + samples = self.collate_fn(samples, metadata) + except TypeError: + samples = self.collate_fn(samples) + if metadata and isinstance(samples, dict): + samples[DefaultDataKeys.METADATA] = metadata + self.callback.on_collate(samples, self.stage) + + with self._per_batch_transform_context: + samples = self.per_batch_transform(samples) + if self.on_device: + self.callback.on_per_batch_transform_on_device(samples, self.stage) + else: + self.callback.on_per_batch_transform(samples, self.stage) + return samples + + def __str__(self) -> str: + # todo: define repr function which would take object and string attributes to be shown + return ( + "_InputTransformProcessor:\n" + f"\t(per_sample_transform): {str(self.per_sample_transform)}\n" + f"\t(collate_fn): {str(self.collate_fn)}\n" + f"\t(per_batch_transform): {str(self.per_batch_transform)}\n" + f"\t(apply_per_sample_transform): {str(self.apply_per_sample_transform)}\n" + f"\t(on_device): {str(self.on_device)}\n" + f"\t(stage): {str(self.stage)}" + ) diff --git a/flash/core/data/new_data_module.py b/flash/core/data/new_data_module.py index ed309646f2..69b370c278 100644 --- a/flash/core/data/new_data_module.py +++ b/flash/core/data/new_data_module.py @@ -26,9 +26,9 @@ from flash.core.data.base_viz import BaseVisualization from flash.core.data.callback import BaseDataFetcher from flash.core.data.data_module import DataModule -from flash.core.data.data_pipeline import DefaultPreprocess from flash.core.data.datasets import BaseDataset from flash.core.data.input_transform import INPUT_TRANSFORM_TYPE, InputTransform +from 
flash.core.data.io.input_transform import DefaultInputTransform from flash.core.data.io.output_transform import OutputTransform from flash.core.registry import FlashRegistry from flash.core.utilities.stages import RunningStage @@ -44,7 +44,7 @@ class DataModule(DataModule): test_dataset: Dataset to test model performance. Defaults to None. predict_dataset: Dataset for predicting. Defaults to None. data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to attach to the - :class:`~flash.core.data.process.Preprocess`. If ``None``, the output from + :class:`~flash.core.data.io.input_transform.InputTransform`. If ``None``, the output from :meth:`~flash.core.data.data_module.DataModule.configure_data_fetcher` will be used. val_split: An optional float which gives the relative amount of the training dataset to use for the validation dataset. @@ -56,7 +56,7 @@ class DataModule(DataModule): Will be passed to the DataLoader for the training dataset. Defaults to None. """ - preprocess_cls = DefaultPreprocess + input_transform_cls = DefaultInputTransform output_transform_cls = OutputTransform flash_datasets_registry = FlashRegistry("datasets") diff --git a/flash/core/data/process.py b/flash/core/data/process.py index 4396171923..027f992017 100644 --- a/flash/core/data/process.py +++ b/flash/core/data/process.py @@ -12,520 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. import functools -import inspect -from abc import ABC, abstractclassmethod, abstractmethod -from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Union +from abc import abstractmethod +from typing import Any, Mapping +from warnings import warn -import torch -from _warnings import warn from deprecate import deprecated -from pytorch_lightning.utilities.exceptions import MisconfigurationException -from torch import Tensor -from torch.utils.data._utils.collate import default_collate import flash -from flash.core.data.callback import FlashCallback -from flash.core.data.data_source import DatasetDataSource, DataSource, DefaultDataKeys, DefaultDataSources from flash.core.data.io.output import Output -from flash.core.data.properties import ProcessState, Properties -from flash.core.data.states import ( - CollateFn, - PerBatchTransform, - PerBatchTransformOnDevice, - PerSampleTransformOnDevice, - PostTensorTransform, - PreTensorTransform, - ToTensorTransform, -) -from flash.core.data.transforms import ApplyToKeys -from flash.core.data.utils import _PREPROCESS_FUNCS, _STAGES_PREFIX, convert_to_modules, CurrentRunningStageFuncContext -from flash.core.utilities.stages import RunningStage - - -class BasePreprocess(ABC): - @abstractmethod - def get_state_dict(self) -> Dict[str, Any]: - """Override this method to return state_dict.""" - - @abstractclassmethod - def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = False): - """Override this method to load from state_dict.""" - - -class Preprocess(BasePreprocess, Properties): - """The :class:`~flash.core.data.process.Preprocess` encapsulates all the data processing logic that should run - before the data is passed to the model. It is particularly useful when you want to provide an end to end - implementation which works with 4 different stages: ``train``, ``validation``, ``test``, and inference - (``predict``). - - The :class:`~flash.core.data.process.Preprocess` supports the following hooks: - - - ``pre_tensor_transform``: Performs transforms on a single data sample. 
- Example:: - - * Input: Receive a PIL Image and its label. - - * Action: Rotate the PIL Image. - - * Output: Return the rotated PIL image and its label. - - - ``to_tensor_transform``: Converts a single data sample to a tensor / data structure containing tensors. - Example:: - - * Input: Receive the rotated PIL Image and its label. - - * Action: Convert the rotated PIL Image to a tensor. - - * Output: Return the tensored image and its label. - - - ``post_tensor_transform``: Performs transform on a single tensor sample. - Example:: - - * Input: Receive the tensored image and its label. - - * Action: Flip the tensored image randomly. - - * Output: Return the tensored image and its label. - - - ``per_batch_transform``: Performs transforms on a batch. - In this example, we decided not to override the hook. - - - ``per_sample_transform_on_device``: Performs transform on a sample already on a ``GPU`` or ``TPU``. - Example:: - - * Input: Receive a tensored image on device and its label. - - * Action: Apply random transforms. - - * Output: Return an augmented tensored image on device and its label. - - - ``collate``: Converts a sequence of data samples into a batch. - Defaults to ``torch.utils.data._utils.collate.default_collate``. - Example:: - - * Input: Receive a list of augmented tensored images and their respective labels. - - * Action: Collate the list of images into batch. - - * Output: Return a batch of images and their labels. - - - ``per_batch_transform_on_device``: Performs transform on a batch already on ``GPU`` or ``TPU``. - Example:: - - * Input: Receive a batch of images and their labels. - - * Action: Apply normalization on the batch by subtracting the mean - and dividing by the standard deviation from ImageNet. - - * Output: Return a normalized augmented batch of images and their labels. - - .. note:: - - The ``per_sample_transform_on_device`` and ``per_batch_transform`` are mutually exclusive - as it will impact performances. - - Data processing can be configured by overriding hooks or through transforms. The preprocess transforms are given as - a mapping from hook names to callables. Default transforms can be configured by overriding the - ``default_transforms`` or ``{train,val,test,predict}_default_transforms`` methods. These can then be overridden by - the user with the ``{train,val,test,predict}_transform`` arguments to the ``Preprocess``. All of the hooks can be - used in the transform mappings. - - Example:: - - class CustomPreprocess(Preprocess): - - def default_transforms() -> Mapping[str, Callable]: - return { - "to_tensor_transform": transforms.ToTensor(), - "collate": torch.utils.data._utils.collate.default_collate, - } - - def train_default_transforms() -> Mapping[str, Callable]: - return { - "pre_tensor_transform": transforms.RandomHorizontalFlip(), - "to_tensor_transform": transforms.ToTensor(), - "collate": torch.utils.data._utils.collate.default_collate, - } - - When overriding hooks for particular stages, you can prefix with ``train``, ``val``, ``test`` or ``predict``. For - example, you can achieve the same as the above example by implementing ``train_pre_tensor_transform`` and - ``train_to_tensor_transform``. 
- - Example:: - - class CustomPreprocess(Preprocess): - - def train_pre_tensor_transform(self, sample: PIL.Image) -> PIL.Image: - return transforms.RandomHorizontalFlip()(sample) - - def to_tensor_transform(self, sample: PIL.Image) -> torch.Tensor: - return transforms.ToTensor()(sample) - - def collate(self, samples: List[torch.Tensor]) -> torch.Tensor: - return torch.utils.data._utils.collate.default_collate(samples) - - Each hook is aware of the Trainer running stage through booleans. These are useful for adapting functionality for a - stage without duplicating code. - - Example:: - - class CustomPreprocess(Preprocess): - - def pre_tensor_transform(self, sample: PIL.Image) -> PIL.Image: - - if self.training: - # logic for training - - elif self.validating: - # logic for validation - - elif self.testing: - # logic for testing - - elif self.predicting: - # logic for predicting - """ - - def __init__( - self, - train_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None, - val_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None, - test_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None, - predict_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None, - data_sources: Optional[Dict[str, "DataSource"]] = None, - deserializer: Optional["Deserializer"] = None, - default_data_source: Optional[str] = None, - ): - super().__init__() - - # resolve the default transforms - train_transform = train_transform or self._resolve_transforms(RunningStage.TRAINING) - val_transform = val_transform or self._resolve_transforms(RunningStage.VALIDATING) - test_transform = test_transform or self._resolve_transforms(RunningStage.TESTING) - predict_transform = predict_transform or self._resolve_transforms(RunningStage.PREDICTING) - - # used to keep track of provided transforms - self._train_collate_in_worker_from_transform: Optional[bool] = None - self._val_collate_in_worker_from_transform: Optional[bool] = None - self._predict_collate_in_worker_from_transform: Optional[bool] = None - self._test_collate_in_worker_from_transform: Optional[bool] = None - - # store the transform before conversion to modules. 
- self.train_transform = self._check_transforms(train_transform, RunningStage.TRAINING) - self.val_transform = self._check_transforms(val_transform, RunningStage.VALIDATING) - self.test_transform = self._check_transforms(test_transform, RunningStage.TESTING) - self.predict_transform = self._check_transforms(predict_transform, RunningStage.PREDICTING) - - self._train_transform = convert_to_modules(self.train_transform) - self._val_transform = convert_to_modules(self.val_transform) - self._test_transform = convert_to_modules(self.test_transform) - self._predict_transform = convert_to_modules(self.predict_transform) - - if DefaultDataSources.DATASETS not in data_sources: - data_sources[DefaultDataSources.DATASETS] = DatasetDataSource() - - self._data_sources = data_sources - self._deserializer = deserializer - self._default_data_source = default_data_source - self._callbacks: List[FlashCallback] = [] - self._default_collate: Callable = default_collate - - @property - def deserializer(self) -> Optional["Deserializer"]: - return self._deserializer - - def _resolve_transforms(self, running_stage: RunningStage) -> Optional[Dict[str, Callable]]: - from flash.core.data.data_pipeline import DataPipeline - - resolved_function = getattr( - self, DataPipeline._resolve_function_hierarchy("default_transforms", self, running_stage, Preprocess) - ) - - with CurrentRunningStageFuncContext(running_stage, "default_transforms", self): - transforms: Optional[Dict[str, Callable]] = resolved_function() - return transforms - - def _save_to_state_dict(self, destination, prefix, keep_vars): - preprocess_state_dict = self.get_state_dict() - if not isinstance(preprocess_state_dict, Dict): - raise MisconfigurationException("get_state_dict should return a dictionary") - preprocess_state_dict["_meta"] = {} - preprocess_state_dict["_meta"]["module"] = self.__module__ - preprocess_state_dict["_meta"]["class_name"] = self.__class__.__name__ - preprocess_state_dict["_meta"]["_state"] = self._state - destination["preprocess.state_dict"] = preprocess_state_dict - self._ddp_params_and_buffers_to_ignore = ["preprocess.state_dict"] - return super()._save_to_state_dict(destination, prefix, keep_vars) - - def _check_transforms( - self, transform: Optional[Dict[str, Callable]], stage: RunningStage - ) -> Optional[Dict[str, Callable]]: - if transform is None: - return transform - - if isinstance(transform, list): - transform = {"pre_tensor_transform": ApplyToKeys(DefaultDataKeys.INPUT, torch.nn.Sequential(*transform))} - elif callable(transform): - transform = {"pre_tensor_transform": ApplyToKeys(DefaultDataKeys.INPUT, transform)} - - if not isinstance(transform, Dict): - raise MisconfigurationException( - "Transform should be a dict. " f"Here are the available keys for your transforms: {_PREPROCESS_FUNCS}." - ) - - keys_diff = set(transform.keys()).difference(_PREPROCESS_FUNCS) - - if len(keys_diff) > 0: - raise MisconfigurationException( - f"{stage}_transform contains {keys_diff}. Only {_PREPROCESS_FUNCS} keys are supported." - ) - - is_per_batch_transform_in = "per_batch_transform" in transform - is_per_sample_transform_on_device_in = "per_sample_transform_on_device" in transform - - if is_per_batch_transform_in and is_per_sample_transform_on_device_in: - raise MisconfigurationException( - f"{transform}: `per_batch_transform` and `per_sample_transform_on_device` are mutually exclusive." 
- ) - - collate_in_worker: Optional[bool] = None - - if is_per_batch_transform_in or (not is_per_batch_transform_in and not is_per_sample_transform_on_device_in): - collate_in_worker = True - - elif is_per_sample_transform_on_device_in: - collate_in_worker = False - - setattr(self, f"_{_STAGES_PREFIX[stage]}_collate_in_worker_from_transform", collate_in_worker) - return transform - - @staticmethod - def _identity(x: Any) -> Any: - return x - - def _get_transform(self, transform: Dict[str, Callable]) -> Callable: - if self.current_fn in transform: - return transform[self.current_fn] - return self._identity - - @property - def current_transform(self) -> Callable: - if self.training and self._train_transform: - return self._get_transform(self._train_transform) - if self.validating and self._val_transform: - return self._get_transform(self._val_transform) - if self.testing and self._test_transform: - return self._get_transform(self._test_transform) - if self.predicting and self._predict_transform: - return self._get_transform(self._predict_transform) - return self._identity - - @property - def transforms(self) -> Dict[str, Optional[Dict[str, Callable]]]: - """The transforms currently being used by this :class:`~flash.core.data.process.Preprocess`.""" - return { - "train_transform": self.train_transform, - "val_transform": self.val_transform, - "test_transform": self.test_transform, - "predict_transform": self.predict_transform, - } - - @property - def callbacks(self) -> List["FlashCallback"]: - if not hasattr(self, "_callbacks"): - self._callbacks: List[FlashCallback] = [] - return self._callbacks - - @callbacks.setter - def callbacks(self, callbacks: List["FlashCallback"]): - self._callbacks = callbacks - - def add_callbacks(self, callbacks: List["FlashCallback"]): - _callbacks = [c for c in callbacks if c not in self._callbacks] - self._callbacks.extend(_callbacks) - - @staticmethod - def default_transforms() -> Optional[Dict[str, Callable]]: - """The default transforms to use. - - Will be overridden by transforms passed to the ``__init__``. 
- """ - - def _apply_sample_transform(self, sample: Any) -> Any: - if isinstance(sample, list): - return [self.current_transform(s) for s in sample] - return self.current_transform(sample) - - def _apply_batch_transform(self, batch: Any): - return self.current_transform(batch) - - def _apply_transform_on_sample(self, sample: Any, transform: Callable): - if isinstance(sample, list): - return [transform(s) for s in sample] - - return transform(sample) - - def _apply_transform_on_batch(self, batch: Any, transform: Callable): - return transform(batch) - - def _apply_process_state_transform( - self, - process_state: ProcessState, - sample: Optional[Any] = None, - batch: Optional[Any] = None, - ): - # assert both sample and batch are not None - if sample is None: - assert batch is not None, "sample not provided, batch should not be None" - mode = "batch" - else: - assert batch is None, "sample provided, batch should be None" - mode = "sample" - - process_state_transform = self.get_state(process_state) - - if process_state_transform is not None: - if process_state_transform.transform is not None: - if mode == "sample": - return self._apply_transform_on_sample(sample, process_state_transform.transform) - else: - return self._apply_transform_on_batch(batch, process_state_transform.transform) - else: - if mode == "sample": - return sample - else: - return batch - else: - if mode == "sample": - return self._apply_sample_transform(sample) - else: - return self._apply_batch_transform(batch) - - def pre_tensor_transform(self, sample: Any) -> Any: - """Transforms to apply on a single object.""" - return self._apply_process_state_transform(PreTensorTransform, sample=sample) - - def to_tensor_transform(self, sample: Any) -> Tensor: - """Transforms to convert single object to a tensor.""" - return self._apply_process_state_transform(ToTensorTransform, sample=sample) - - def post_tensor_transform(self, sample: Tensor) -> Tensor: - """Transforms to apply on a tensor.""" - return self._apply_process_state_transform(PostTensorTransform, sample=sample) - - def per_batch_transform(self, batch: Any) -> Any: - """Transforms to apply to a whole batch (if possible use this for efficiency). - - .. note:: - - This option is mutually exclusive with :meth:`per_sample_transform_on_device`, - since if both are specified, uncollation has to be applied. - """ - return self._apply_process_state_transform(PerBatchTransform, batch=batch) - - def collate(self, samples: Sequence, metadata=None) -> Any: - """Transform to convert a sequence of samples to a collated batch.""" - current_transform = self.current_transform - if current_transform is self._identity: - current_transform = self._default_collate - - # the model can provide a custom ``collate_fn``. - collate_fn = self.get_state(CollateFn) - if collate_fn is not None: - collate_fn = collate_fn.collate_fn - else: - collate_fn = current_transform - # return collate_fn.collate_fn(samples) - - parameters = inspect.signature(collate_fn).parameters - if len(parameters) > 1 and DefaultDataKeys.METADATA in parameters: - return collate_fn(samples, metadata) - return collate_fn(samples) - - def per_sample_transform_on_device(self, sample: Any) -> Any: - """Transforms to apply to the data before the collation (per-sample basis). - - .. note:: - - This option is mutually exclusive with :meth:`per_batch_transform`, - since if both are specified, uncollation has to be applied. - - .. 
note:: - - This function won't be called within the dataloader workers, since to make that happen - each of the workers would have to create it's own CUDA-context which would pollute GPU memory (if on GPU). - """ - return self._apply_process_state_transform(PerSampleTransformOnDevice, sample=sample) - - def per_batch_transform_on_device(self, batch: Any) -> Any: - """Transforms to apply to a whole batch (if possible use this for efficiency). - - .. note:: - - This function won't be called within the dataloader workers, since to make that happen - each of the workers would have to create it's own CUDA-context which would pollute GPU memory (if on GPU). - """ - return self._apply_process_state_transform(PerBatchTransformOnDevice, batch=batch) - - def available_data_sources(self) -> Sequence[str]: - """Get the list of available data source names for use with this - :class:`~flash.core.data.process.Preprocess`. - - Returns: - The list of data source names. - """ - return list(self._data_sources.keys()) - - def data_source_of_name(self, data_source_name: str) -> DataSource: - """Get the :class:`~flash.core.data.data_source.DataSource` of the given name from the - :class:`~flash.core.data.process.Preprocess`. - - Args: - data_source_name: The name of the data source to look up. - - Returns: - The :class:`~flash.core.data.data_source.DataSource` of the given name. - - Raises: - MisconfigurationException: If the requested data source is not configured by this - :class:`~flash.core.data.process.Preprocess`. - """ - if data_source_name == "default": - data_source_name = self._default_data_source - data_sources = self._data_sources - if data_source_name in data_sources: - return data_sources[data_source_name] - raise MisconfigurationException( - f"No '{data_source_name}' data source is available for use with the {type(self)}. The available data " - f"sources are: {', '.join(self.available_data_sources())}." - ) - - -class DefaultPreprocess(Preprocess): - def __init__( - self, - train_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None, - val_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None, - test_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None, - predict_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None, - data_sources: Optional[Dict[str, "DataSource"]] = None, - default_data_source: Optional[str] = None, - ): - super().__init__( - train_transform=train_transform, - val_transform=val_transform, - test_transform=test_transform, - predict_transform=predict_transform, - data_sources=data_sources or {"default": DataSource()}, - default_data_source=default_data_source or "default", - ) - - def get_state_dict(self) -> Dict[str, Any]: - return {**self.transforms} - - @classmethod - def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool): - return cls(**state_dict) +from flash.core.data.properties import Properties class Deserializer(Properties): diff --git a/flash/core/data/transforms.py b/flash/core/data/transforms.py index 42a5d40fcb..6cb7adf4b4 100644 --- a/flash/core/data/transforms.py +++ b/flash/core/data/transforms.py @@ -17,7 +17,7 @@ from torch import nn from torch.utils.data._utils.collate import default_collate -from flash.core.data.utils import _PREPROCESS_FUNCS, convert_to_modules +from flash.core.data.utils import _INPUT_TRANSFORM_FUNCS, convert_to_modules class ApplyToKeys(nn.Sequential): @@ -135,7 +135,7 @@ def merge_transforms( The new dictionary of transforms. 
""" transforms = {} - for hook in _PREPROCESS_FUNCS: + for hook in _INPUT_TRANSFORM_FUNCS: if hook in base_transforms and hook in additional_transforms: transforms[hook] = nn.Sequential( convert_to_modules(base_transforms[hook]), diff --git a/flash/core/data/utils.py b/flash/core/data/utils.py index 3252952660..37244f5b1b 100644 --- a/flash/core/data/utils.py +++ b/flash/core/data/utils.py @@ -47,7 +47,7 @@ "load_sample", } -_PREPROCESS_FUNCS: Set[str] = { +_INPUT_TRANSFORM_FUNCS: Set[str] = { "pre_tensor_transform", "to_tensor_transform", "post_tensor_transform", @@ -59,7 +59,7 @@ _CALLBACK_FUNCS: Set[str] = { "load_sample", - *_PREPROCESS_FUNCS, + *_INPUT_TRANSFORM_FUNCS, } _OUTPUT_TRANSFORM_FUNCS: Set[str] = { diff --git a/flash/core/model.py b/flash/core/model.py index ffe7063ea8..6bf1ad900c 100644 --- a/flash/core/model.py +++ b/flash/core/model.py @@ -39,9 +39,10 @@ from flash.core.data.auto_dataset import BaseAutoDataset from flash.core.data.data_pipeline import DataPipeline, DataPipelineState from flash.core.data.data_source import DataSource +from flash.core.data.io.input_transform import InputTransform from flash.core.data.io.output import Output from flash.core.data.io.output_transform import OutputTransform -from flash.core.data.process import Deserializer, DeserializerMapping, Preprocess +from flash.core.data.process import Deserializer, DeserializerMapping from flash.core.data.properties import ProcessState from flash.core.optimizers.optimizers import _OPTIMIZERS_REGISTRY from flash.core.optimizers.schedulers import _SCHEDULERS_REGISTRY @@ -53,6 +54,7 @@ from flash.core.utilities.stages import RunningStage from flash.core.utilities.types import ( DESERIALIZER_TYPE, + INPUT_TRANSFORM_TYPE, LOSS_FN_TYPE, LR_SCHEDULER_TYPE, METRICS_TYPE, @@ -60,7 +62,6 @@ OPTIMIZER_TYPE, OUTPUT_TRANSFORM_TYPE, OUTPUT_TYPE, - PREPROCESS_TYPE, ) @@ -318,7 +319,8 @@ class Task(DatasetProcessor, ModuleWrapperBase, LightningModule, metaclass=Check `metric(preds,target)` and return a single scalar tensor. deserializer: Either a single :class:`~flash.core.data.process.Deserializer` or a mapping of these to deserialize the input - preprocess: :class:`~flash.core.data.process.Preprocess` to use as the default for this task. + input_transform: :class:`~flash.core.data.io.input_transform.InputTransform` to use as the default + for this task. output_transform: :class:`~flash.core.data.io.output_transform.OutputTransform` to use as the default for this task. output: The :class:`~flash.core.data.io.output.Output` to use when formatting prediction outputs. 
@@ -338,7 +340,7 @@ def __init__( lr_scheduler: LR_SCHEDULER_TYPE = None, metrics: METRICS_TYPE = None, deserializer: DESERIALIZER_TYPE = None, - preprocess: PREPROCESS_TYPE = None, + input_transform: INPUT_TRANSFORM_TYPE = None, output_transform: OUTPUT_TRANSFORM_TYPE = None, output: OUTPUT_TYPE = None, ): @@ -357,7 +359,7 @@ def __init__( self.save_hyperparameters("learning_rate", "optimizer") self._deserializer: Optional[Deserializer] = None - self._preprocess: Optional[Preprocess] = preprocess + self._input_transform: Optional[InputTransform] = input_transform self._output_transform: Optional[OutputTransform] = output_transform self._output: Optional[Output] = None @@ -490,13 +492,13 @@ def predict( dataset = data_pipeline.data_source.generate_dataset(x, running_stage) dataloader = self.process_predict_dataset(dataset) x = list(dataloader.dataset) - x = data_pipeline.worker_preprocessor(running_stage, collate_fn=dataloader.collate_fn)(x) + x = data_pipeline.worker_input_transform_processor(running_stage, collate_fn=dataloader.collate_fn)(x) # todo (tchaton): Remove this when sync with Lightning master. if len(inspect.signature(self.transfer_batch_to_device).parameters) == 3: x = self.transfer_batch_to_device(x, self.device, 0) else: x = self.transfer_batch_to_device(x, self.device) - x = data_pipeline.device_preprocessor(running_stage)(x) + x = data_pipeline.device_input_transform_processor(running_stage)(x) x = x[0] if isinstance(x, list) else x predictions = self.predict_step(x, 0) # batch_idx is always 0 when running with `model.predict` predictions = data_pipeline.output_transform_processor(running_stage)(predictions) @@ -569,30 +571,31 @@ def configure_finetune_callback() -> List[Callback]: @staticmethod def _resolve( old_deserializer: Optional[Deserializer], - old_preprocess: Optional[Preprocess], + old_input_transform: Optional[InputTransform], old_output_transform: Optional[OutputTransform], old_output: Optional[Output], new_deserializer: Optional[Deserializer], - new_preprocess: Optional[Preprocess], + new_input_transform: Optional[InputTransform], new_output_transform: Optional[OutputTransform], new_output: Optional[Output], - ) -> Tuple[Optional[Deserializer], Optional[Preprocess], Optional[OutputTransform], Optional[Output]]: - """Resolves the correct :class:`~flash.core.data.process.Preprocess`, - :class:`~flash.core.data.io.output_transform.OutputTransform`, and :class:`~flash.core.data.io.output.Output` to - use, choosing ``new_*`` if it is not None or a base class (:class:`~flash.core.data.process.Preprocess`, + ) -> Tuple[Optional[Deserializer], Optional[InputTransform], Optional[OutputTransform], Optional[Output]]: + """Resolves the correct :class:`~flash.core.data.io.input_transform.InputTransform`, + :class:`~flash.core.data.io.output_transform.OutputTransform`, and :class:`~flash.core.data.io.output.Output` + to use, choosing ``new_*`` if it is not None or a base class + (:class:`~flash.core.data.io.input_transform.InputTransform`, :class:`~flash.core.data.io.output_transform.OutputTransform`, or :class:`~flash.core.data.io.output.Output`) and ``old_*`` otherwise. Args: - old_preprocess: :class:`~flash.core.data.process.Preprocess` to be overridden. + old_input_transform: :class:`~flash.core.data.io.input_transform.InputTransform` to be overridden. old_output_transform: :class:`~flash.core.data.io.output_transform.OutputTransform` to be overridden. old_output: :class:`~flash.core.data.io.output.Output` to be overridden. 
- new_preprocess: :class:`~flash.core.data.process.Preprocess` to override with. + new_input_transform: :class:`~flash.core.data.io.input_transform.InputTransform` to override with. new_output_transform: :class:`~flash.core.data.io.output_transform.OutputTransform` to override with. new_output: :class:`~flash.core.data.io.output.Output` to override with. Returns: - The resolved :class:`~flash.core.data.process.Preprocess`, + The resolved :class:`~flash.core.data.io.input_transform.InputTransform`, :class:`~flash.core.data.io.output_transform.OutputTransform`, and :class:`~flash.core.data.io.output.Output`. """ @@ -600,9 +603,9 @@ def _resolve( if new_deserializer is not None and type(new_deserializer) != Deserializer: deserializer = new_deserializer - preprocess = old_preprocess - if new_preprocess is not None and type(new_preprocess) != Preprocess: - preprocess = new_preprocess + input_transform = old_input_transform + if new_input_transform is not None and type(new_input_transform) != InputTransform: + input_transform = new_input_transform output_transform = old_output_transform if new_output_transform is not None and type(new_output_transform) != OutputTransform: @@ -612,7 +615,7 @@ def _resolve( if new_output is not None and type(new_output) != Output: output = new_output - return deserializer, preprocess, output_transform, output + return deserializer, input_transform, output_transform, output @torch.jit.unused @property @@ -673,7 +676,8 @@ def build_data_pipeline( data_pipeline: Optional[DataPipeline] = None, ) -> Optional[DataPipeline]: """Build a :class:`.DataPipeline` incorporating available - :class:`~flash.core.data.process.Preprocess` and :class:`~flash.core.data.io.output_transform.OutputTransform` + :class:`~flash.core.data.io.input_transform.InputTransform` and + :class:`~flash.core.data.io.output_transform.OutputTransform` objects. These will be overridden in the following resolution order (lowest priority first): - Lightning ``Datamodule``, either attached to the :class:`.Trainer` or to the :class:`.Task`. @@ -686,13 +690,13 @@ def build_data_pipeline( the current data source format used. deserializer: deserializer to use data_pipeline: Optional highest priority source of - :class:`~flash.core.data.process.Preprocess` and + :class:`~flash.core.data.io.input_transform.InputTransform` and :class:`~flash.core.data.io.output_transform.OutputTransform`. Returns: The fully resolved :class:`.DataPipeline`. 
""" - deserializer, old_data_source, preprocess, output_transform, output = None, None, None, None, None + deserializer, old_data_source, input_transform, output_transform, output = None, None, None, None, None # Datamodule datamodule = None @@ -703,32 +707,32 @@ def build_data_pipeline( if getattr(datamodule, "data_pipeline", None) is not None: old_data_source = getattr(datamodule.data_pipeline, "data_source", None) - preprocess = getattr(datamodule.data_pipeline, "_preprocess_pipeline", None) + input_transform = getattr(datamodule.data_pipeline, "_input_transform_pipeline", None) output_transform = getattr(datamodule.data_pipeline, "_output_transform", None) output = getattr(datamodule.data_pipeline, "_output", None) deserializer = getattr(datamodule.data_pipeline, "_deserializer", None) # Defaults / task attributes - deserializer, preprocess, output_transform, output = Task._resolve( + deserializer, input_transform, output_transform, output = Task._resolve( deserializer, - preprocess, + input_transform, output_transform, output, self._deserializer, - self._preprocess, + self._input_transform, self._output_transform, self._output, ) # Datapipeline if data_pipeline is not None: - deserializer, preprocess, output_transform, output = Task._resolve( + deserializer, input_transform, output_transform, output = Task._resolve( deserializer, - preprocess, + input_transform, output_transform, output, getattr(data_pipeline, "_deserializer", None), - getattr(data_pipeline, "_preprocess_pipeline", None), + getattr(data_pipeline, "_input_transform_pipeline", None), getattr(data_pipeline, "_output_transform", None), getattr(data_pipeline, "_output", None), ) @@ -736,15 +740,15 @@ def build_data_pipeline( data_source = data_source or old_data_source if isinstance(data_source, str): - if preprocess is None: + if input_transform is None: data_source = DataSource() # TODO: warn the user that we are not using the specified data source else: - data_source = preprocess.data_source_of_name(data_source) + data_source = input_transform.data_source_of_name(data_source) if deserializer is None or type(deserializer) is Deserializer: - deserializer = getattr(preprocess, "deserializer", deserializer) + deserializer = getattr(input_transform, "deserializer", deserializer) - data_pipeline = DataPipeline(data_source, preprocess, output_transform, deserializer, output) + data_pipeline = DataPipeline(data_source, input_transform, output_transform, deserializer, output) self._data_pipeline_state = self._data_pipeline_state or DataPipelineState() self.attach_data_pipeline_state(self._data_pipeline_state) self._data_pipeline_state = data_pipeline.initialize(self._data_pipeline_state) @@ -768,25 +772,25 @@ def data_pipeline(self) -> DataPipeline: @torch.jit.unused @data_pipeline.setter def data_pipeline(self, data_pipeline: Optional[DataPipeline]) -> None: - self._deserializer, self._preprocess, self._output_transform, self.output = Task._resolve( + self._deserializer, self._input_transform, self._output_transform, self.output = Task._resolve( self._deserializer, - self._preprocess, + self._input_transform, self._output_transform, self._output, getattr(data_pipeline, "_deserializer", None), - getattr(data_pipeline, "_preprocess_pipeline", None), + getattr(data_pipeline, "_input_transform_pipeline", None), getattr(data_pipeline, "_output_transform", None), getattr(data_pipeline, "_output", None), ) - # self._preprocess.state_dict() - if getattr(self._preprocess, "_ddp_params_and_buffers_to_ignore", None): - 
self._ddp_params_and_buffers_to_ignore = self._preprocess._ddp_params_and_buffers_to_ignore + # self._input_transform.state_dict() + if getattr(self._input_transform, "_ddp_params_and_buffers_to_ignore", None): + self._ddp_params_and_buffers_to_ignore = self._input_transform._ddp_params_and_buffers_to_ignore @torch.jit.unused @property - def preprocess(self) -> Preprocess: - return getattr(self.data_pipeline, "_preprocess_pipeline", None) + def input_transform(self) -> InputTransform: + return getattr(self.data_pipeline, "_input_transform_pipeline", None) @torch.jit.unused @property @@ -1048,22 +1052,23 @@ def _instantiate_lr_scheduler(self, optimizer: Optimizer) -> Dict[str, Any]: def _load_from_state_dict( self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs ): - if "preprocess.state_dict" in state_dict: + if "input_transform.state_dict" in state_dict: try: - preprocess_state_dict = state_dict["preprocess.state_dict"] - meta = preprocess_state_dict["_meta"] + input_transform_state_dict = state_dict["input_transform.state_dict"] + meta = input_transform_state_dict["_meta"] cls = getattr(import_module(meta["module"]), meta["class_name"]) - self._preprocess = cls.load_state_dict( - {k: v for k, v in preprocess_state_dict.items() if k != "_meta"}, + self._input_transform = cls.load_state_dict( + {k: v for k, v in input_transform_state_dict.items() if k != "_meta"}, strict=strict, ) - self._preprocess._state = meta["_state"] - del state_dict["preprocess.state_dict"] - del preprocess_state_dict["_meta"] + self._input_transform._state = meta["_state"] + del state_dict["input_transform.state_dict"] + del input_transform_state_dict["_meta"] except (ModuleNotFoundError, KeyError): - meta = state_dict["preprocess.state_dict"]["_meta"] + meta = state_dict["input_transform.state_dict"]["_meta"] raise MisconfigurationException( - f"The `Preprocess` {meta['module']}.{meta['class_name']} has been moved and couldn't be imported." + f"The `InputTransform` {meta['module']}.{meta['class_name']}" + " has been moved and couldn't be imported."
) super()._load_from_state_dict( diff --git a/flash/core/serve/flash_components.py b/flash/core/serve/flash_components.py index 70f36879a5..1ff8574c9b 100644 --- a/flash/core/serve/flash_components.py +++ b/flash/core/serve/flash_components.py @@ -55,8 +55,12 @@ def __init__(self, model): self.model = model self.model.eval() self.data_pipeline = model.build_data_pipeline() - self.worker_preprocessor = self.data_pipeline.worker_preprocessor(RunningStage.PREDICTING, is_serving=True) - self.device_preprocessor = self.data_pipeline.device_preprocessor(RunningStage.PREDICTING) + self.worker_input_transform_processor = self.data_pipeline.worker_input_transform_processor( + RunningStage.PREDICTING, is_serving=True + ) + self.device_input_transform_processor = self.data_pipeline.device_input_transform_processor( + RunningStage.PREDICTING + ) self.output_transform_processor = self.data_pipeline.output_transform_processor( RunningStage.PREDICTING, is_serving=True ) @@ -70,12 +74,12 @@ def __init__(self, model): ) def predict(self, inputs): with torch.no_grad(): - inputs = self.worker_preprocessor(inputs) + inputs = self.worker_input_transform_processor(inputs) if self.extra_arguments: inputs = self.model.transfer_batch_to_device(inputs, self.device, 0) else: inputs = self.model.transfer_batch_to_device(inputs, self.device) - inputs = self.device_preprocessor(inputs) + inputs = self.device_input_transform_processor(inputs) preds = self.model.predict_step(inputs, 0) preds = self.output_transform_processor(preds) return preds diff --git a/flash/core/utilities/flash_cli.py b/flash/core/utilities/flash_cli.py index dda14e0b4b..f015da385f 100644 --- a/flash/core/utilities/flash_cli.py +++ b/flash/core/utilities/flash_cli.py @@ -180,7 +180,7 @@ def parse_arguments(self) -> None: def add_arguments_to_parser(self, parser) -> None: subcommands = parser.add_subcommands() - data_sources = self.local_datamodule_class.preprocess_cls().available_data_sources() + data_sources = self.local_datamodule_class.input_transform_cls().available_data_sources() for data_source in data_sources: if isinstance(data_source, DefaultDataSources): @@ -201,10 +201,12 @@ def add_arguments_to_parser(self, parser) -> None: def add_subcommand_from_function(self, subcommands, function, function_name=None): subcommand = ArgumentParser() datamodule_function = class_from_function(drop_kwargs(function)) - preprocess_function = class_from_function(drop_kwargs(self.local_datamodule_class.preprocess_cls)) + input_transform_function = class_from_function(drop_kwargs(self.local_datamodule_class.input_transform_cls)) subcommand.add_class_arguments(datamodule_function, fail_untyped=False) subcommand.add_class_arguments( - preprocess_function, fail_untyped=False, skip=get_overlapping_args(datamodule_function, preprocess_function) + input_transform_function, + fail_untyped=False, + skip=get_overlapping_args(datamodule_function, input_transform_function), ) subcommand_name = function_name or function.__name__ subcommands.add_subcommand(subcommand_name, subcommand) diff --git a/flash/core/utilities/types.py b/flash/core/utilities/types.py index 21f45be693..2059bf9066 100644 --- a/flash/core/utilities/types.py +++ b/flash/core/utilities/types.py @@ -3,9 +3,10 @@ from torch import nn from torchmetrics import Metric +from flash.core.data.io.input_transform import InputTransform from flash.core.data.io.output import Output from flash.core.data.io.output_transform import OutputTransform -from flash.core.data.process import Deserializer, Preprocess +from 
flash.core.data.process import Deserializer MODEL_TYPE = Optional[nn.Module] LOSS_FN_TYPE = Optional[Union[Callable, Mapping, Sequence]] @@ -15,6 +16,6 @@ ] METRICS_TYPE = Union[Metric, Mapping, Sequence, None] DESERIALIZER_TYPE = Optional[Union[Deserializer, Mapping[str, Deserializer]]] -PREPROCESS_TYPE = Optional[Preprocess] +INPUT_TRANSFORM_TYPE = Optional[InputTransform] OUTPUT_TRANSFORM_TYPE = Optional[OutputTransform] OUTPUT_TYPE = Optional[Output] diff --git a/flash/graph/classification/cli.py b/flash/graph/classification/cli.py index d8fd18702c..56bc053aca 100644 --- a/flash/graph/classification/cli.py +++ b/flash/graph/classification/cli.py @@ -23,7 +23,7 @@ def from_tu_dataset( val_split: float = 0.1, batch_size: int = 4, num_workers: int = 0, - **preprocess_kwargs, + **input_transform_kwargs, ) -> GraphClassificationData: """Downloads and loads the TU Dataset.""" from flash.core.utilities.imports import _TORCH_GEOMETRIC_AVAILABLE @@ -40,7 +40,7 @@ def from_tu_dataset( val_split=val_split, batch_size=batch_size, num_workers=num_workers, - **preprocess_kwargs, + **input_transform_kwargs, ) diff --git a/flash/graph/classification/data.py b/flash/graph/classification/data.py index 7d8ff5d2a0..e997a3279e 100644 --- a/flash/graph/classification/data.py +++ b/flash/graph/classification/data.py @@ -15,7 +15,7 @@ from flash.core.data.data_module import DataModule from flash.core.data.data_source import DefaultDataSources -from flash.core.data.process import Preprocess +from flash.core.data.io.input_transform import InputTransform from flash.core.utilities.imports import _GRAPH_AVAILABLE from flash.graph.data import GraphDatasetDataSource @@ -24,7 +24,7 @@ from torch_geometric.transforms import NormalizeFeatures -class GraphClassificationPreprocess(Preprocess): +class GraphClassificationInputTransform(InputTransform): def __init__( self, train_transform: Optional[Dict[str, Callable]] = None, @@ -58,7 +58,7 @@ def default_transforms() -> Optional[Dict[str, Callable]]: class GraphClassificationData(DataModule): """Data module for graph classification tasks.""" - preprocess_cls = GraphClassificationPreprocess + input_transform_cls = GraphClassificationInputTransform @property def num_features(self): diff --git a/flash/image/__init__.py b/flash/image/__init__.py index 788a15ca40..c86531edf5 100644 --- a/flash/image/__init__.py +++ b/flash/image/__init__.py @@ -1,6 +1,6 @@ from flash.image.classification import ( # noqa: F401 ImageClassificationData, - ImageClassificationPreprocess, + ImageClassificationInputTransform, ImageClassifier, ) from flash.image.classification.backbones import IMAGE_CLASSIFIER_BACKBONES # noqa: F401 @@ -12,6 +12,6 @@ from flash.image.segmentation import ( # noqa: F401 SemanticSegmentation, SemanticSegmentationData, - SemanticSegmentationPreprocess, + SemanticSegmentationInputTransform, ) -from flash.image.style_transfer import StyleTransfer, StyleTransferData, StyleTransferPreprocess # noqa: F401 +from flash.image.style_transfer import StyleTransfer, StyleTransferData, StyleTransferInputTransform # noqa: F401 diff --git a/flash/image/classification/__init__.py b/flash/image/classification/__init__.py index 9c2ee8298c..b8c8fdcffa 100644 --- a/flash/image/classification/__init__.py +++ b/flash/image/classification/__init__.py @@ -1,2 +1,2 @@ -from flash.image.classification.data import ImageClassificationData, ImageClassificationPreprocess # noqa: F401 +from flash.image.classification.data import ImageClassificationData, ImageClassificationInputTransform # noqa: F401 
from flash.image.classification.model import ImageClassifier # noqa: F401 diff --git a/flash/image/classification/cli.py b/flash/image/classification/cli.py index 4056387b86..c74c89ded0 100644 --- a/flash/image/classification/cli.py +++ b/flash/image/classification/cli.py @@ -22,7 +22,7 @@ def from_hymenoptera( batch_size: int = 4, num_workers: int = 0, - **preprocess_kwargs, + **input_transform_kwargs, ) -> ImageClassificationData: """Downloads and loads the Hymenoptera (Ants, Bees) data set.""" download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "./data") @@ -31,14 +31,14 @@ def from_hymenoptera( val_folder="data/hymenoptera_data/val/", batch_size=batch_size, num_workers=num_workers, - **preprocess_kwargs, + **input_transform_kwargs, ) def from_movie_posters( batch_size: int = 4, num_workers: int = 0, - **preprocess_kwargs, + **input_transform_kwargs, ) -> ImageClassificationData: """Downloads and loads the movie posters genre classification data set.""" download_data("https://pl-flash-data.s3.amazonaws.com/movie_posters.zip", "./data") @@ -49,7 +49,7 @@ def from_movie_posters( val_file="data/movie_posters/val/metadata.csv", batch_size=batch_size, num_workers=num_workers, - **preprocess_kwargs, + **input_transform_kwargs, ) diff --git a/flash/image/classification/data.py b/flash/image/classification/data.py index 641747c754..45fd5a1648 100644 --- a/flash/image/classification/data.py +++ b/flash/image/classification/data.py @@ -22,7 +22,8 @@ from flash.core.data.callback import BaseDataFetcher from flash.core.data.data_module import DataModule from flash.core.data.data_source import DefaultDataKeys, DefaultDataSources, LoaderDataFrameDataSource -from flash.core.data.process import Deserializer, Preprocess +from flash.core.data.io.input_transform import InputTransform +from flash.core.data.process import Deserializer from flash.core.integrations.labelstudio.data_source import LabelStudioImageClassificationDataSource from flash.core.utilities.imports import _MATPLOTLIB_AVAILABLE, Image, requires from flash.core.utilities.stages import RunningStage @@ -54,7 +55,7 @@ def load_sample(self, sample: Dict[str, Any], dataset: Optional[Any] = None) -> return sample -class ImageClassificationPreprocess(Preprocess): +class ImageClassificationInputTransform(InputTransform): """Preprocssing of data of image classification. Args:: @@ -115,7 +116,7 @@ def train_default_transforms(self) -> Optional[Dict[str, Callable]]: class ImageClassificationData(DataModule): """Data module for image classification tasks.""" - preprocess_cls = ImageClassificationPreprocess + input_transform_cls = ImageClassificationInputTransform @classmethod def from_data_frame( @@ -139,12 +140,12 @@ def from_data_frame( test_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None, predict_transform: Optional[Dict[str, Callable]] = None, data_fetcher: Optional[BaseDataFetcher] = None, - preprocess: Optional[Preprocess] = None, + input_transform: Optional[InputTransform] = None, val_split: Optional[float] = None, batch_size: int = 4, num_workers: int = 0, sampler: Optional[Type[Sampler]] = None, - **preprocess_kwargs: Any, + **input_transform_kwargs: Any, ) -> "DataModule": """Creates a :class:`~flash.image.classification.data.ImageClassificationData` object from the given pandas ``DataFrame`` objects. @@ -173,24 +174,24 @@ def from_data_frame( predict_resolver: The function to use to resolve filenames given the ``predict_images_root`` and IDs from the ``input_field`` column. 
train_transform: The dictionary of transforms to use during training which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. val_transform: The dictionary of transforms to use during validation which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. test_transform: The dictionary of transforms to use during testing which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. predict_transform: The dictionary of transforms to use during predicting which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to pass to the :class:`~flash.core.data.data_module.DataModule`. - preprocess: The :class:`~flash.core.data.data.Preprocess` to pass to the - :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls`` + input_transform: The :class:`~flash.core.data.data.InputTransform` to pass to the + :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.input_transform_cls`` will be constructed and used. val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. sampler: The ``sampler`` to use for the ``train_dataloader``. - preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used - if ``preprocess = None``. + input_transform_kwargs: Additional keyword arguments to use when constructing the input_transform. + Will only be used if ``input_transform = None``. Returns: The constructed data module. @@ -206,12 +207,12 @@ def from_data_frame( test_transform=test_transform, predict_transform=predict_transform, data_fetcher=data_fetcher, - preprocess=preprocess, + input_transform=input_transform, val_split=val_split, batch_size=batch_size, num_workers=num_workers, sampler=sampler, - **preprocess_kwargs, + **input_transform_kwargs, ) @classmethod @@ -236,17 +237,17 @@ def from_csv( test_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None, predict_transform: Optional[Dict[str, Callable]] = None, data_fetcher: Optional[BaseDataFetcher] = None, - preprocess: Optional[Preprocess] = None, + input_transform: Optional[InputTransform] = None, val_split: Optional[float] = None, batch_size: int = 4, num_workers: int = 0, sampler: Optional[Type[Sampler]] = None, - **preprocess_kwargs: Any, + **input_transform_kwargs: Any, ) -> "DataModule": """Creates a :class:`~flash.image.classification.data.ImageClassificationData` object from the given CSV files using the :class:`~flash.core.data.data_source.DataSource` of name :attr:`~flash.core.data.data_source.DefaultDataSources.CSV` from the passed or constructed - :class:`~flash.core.data.process.Preprocess`. + :class:`~flash.core.data.io.input_transform.InputTransform`. 
Args: input_field: The field (column) in the CSV file to use for the input. @@ -272,24 +273,24 @@ def from_csv( predict_resolver: The function to use to resolve filenames given the ``predict_images_root`` and IDs from the ``input_field`` column. train_transform: The dictionary of transforms to use during training which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. val_transform: The dictionary of transforms to use during validation which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. test_transform: The dictionary of transforms to use during testing which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. predict_transform: The dictionary of transforms to use during predicting which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to pass to the :class:`~flash.core.data.data_module.DataModule`. - preprocess: The :class:`~flash.core.data.data.Preprocess` to pass to the - :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls`` + input_transform: The :class:`~flash.core.data.data.InputTransform` to pass to the + :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.input_transform_cls`` will be constructed and used. val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. sampler: The ``sampler`` to use for the ``train_dataloader``. - preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used - if ``preprocess = None``. + input_transform_kwargs: Additional keyword arguments to use when constructing the input_transform. + Will only be used if ``input_transform = None``. Returns: The constructed data module. 
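# --- Editor's illustrative sketch (not part of the patch) --------------------------------
# After this rename, callers pass `input_transform` / `**input_transform_kwargs` where they
# previously passed `preprocess` / `**preprocess_kwargs`. The CSV path, column names, and the
# `image_size` keyword below are hypothetical placeholders, not values taken from this PR.
from flash.image import ImageClassificationData, ImageClassificationInputTransform

# Extra keyword arguments are forwarded to `cls.input_transform_cls` when `input_transform=None`:
datamodule = ImageClassificationData.from_csv(
    "image_id",                   # input_field (hypothetical column name)
    "label",                      # target field (hypothetical column name)
    train_file="data/train.csv",  # hypothetical path
    batch_size=4,
    image_size=(196, 196),        # assumed InputTransform kwarg, routed via **input_transform_kwargs
)

# Alternatively, a constructed transform can be passed explicitly:
datamodule = ImageClassificationData.from_csv(
    "image_id",
    "label",
    train_file="data/train.csv",
    input_transform=ImageClassificationInputTransform(image_size=(196, 196)),
    batch_size=4,
)
# ------------------------------------------------------------------------------------------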
@@ -305,12 +306,12 @@ def from_csv( test_transform=test_transform, predict_transform=predict_transform, data_fetcher=data_fetcher, - preprocess=preprocess, + input_transform=input_transform, val_split=val_split, batch_size=batch_size, num_workers=num_workers, sampler=sampler, - **preprocess_kwargs, + **input_transform_kwargs, ) def set_block_viz_window(self, value: bool) -> None: diff --git a/flash/image/detection/cli.py b/flash/image/detection/cli.py index f955e34bbe..a2ebaeacb0 100644 --- a/flash/image/detection/cli.py +++ b/flash/image/detection/cli.py @@ -23,7 +23,7 @@ def from_coco_128( val_split: float = 0.1, batch_size: int = 4, num_workers: int = 0, - **preprocess_kwargs, + **input_transform_kwargs, ) -> ObjectDetectionData: """Downloads and loads the COCO 128 data set.""" download_data("https://github.com/zhiqwang/yolov5-rt-stack/releases/download/v0.3.0/coco128.zip", "data/") @@ -33,7 +33,7 @@ def from_coco_128( val_split=val_split, batch_size=batch_size, num_workers=num_workers, - **preprocess_kwargs, + **input_transform_kwargs, ) diff --git a/flash/image/detection/data.py b/flash/image/detection/data.py index 9a7e5c31fa..9d57c35233 100644 --- a/flash/image/detection/data.py +++ b/flash/image/detection/data.py @@ -16,7 +16,7 @@ from flash.core.data.callback import BaseDataFetcher from flash.core.data.data_module import DataModule from flash.core.data.data_source import DefaultDataKeys, DefaultDataSources, FiftyOneDataSource -from flash.core.data.process import Preprocess +from flash.core.data.io.input_transform import InputTransform from flash.core.integrations.icevision.data import IceVisionParserDataSource, IceVisionPathsDataSource from flash.core.integrations.icevision.transforms import default_transforms from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _ICEVISION_AVAILABLE, lazy_import, requires @@ -133,7 +133,7 @@ def predict_load_data(data: SampleCollection, dataset: Optional[Any] = None) -> return [{DefaultDataKeys.INPUT: f} for f in data.values("filepath")] -class ObjectDetectionPreprocess(Preprocess): +class ObjectDetectionInputTransform(InputTransform): def __init__( self, train_transform: Optional[Dict[str, Callable]] = None, @@ -180,7 +180,7 @@ def train_default_transforms(self) -> Optional[Dict[str, Callable]]: class ObjectDetectionData(DataModule): - preprocess_cls = ObjectDetectionPreprocess + input_transform_cls = ObjectDetectionInputTransform @classmethod def from_coco( @@ -197,11 +197,11 @@ def from_coco( test_transform: Optional[Dict[str, Callable]] = None, predict_transform: Optional[Dict[str, Callable]] = None, data_fetcher: Optional[BaseDataFetcher] = None, - preprocess: Optional[Preprocess] = None, + input_transform: Optional[InputTransform] = None, val_split: Optional[float] = None, batch_size: int = 4, num_workers: int = 0, - **preprocess_kwargs: Any, + **input_transform_kwargs: Any, ): """Creates a :class:`~flash.image.detection.data.ObjectDetectionData` object from the given data folders and annotation files in the COCO format. @@ -215,23 +215,23 @@ def from_coco( test_ann_file: The COCO format annotation file. predict_folder: The folder containing the predict data. train_transform: The dictionary of transforms to use during training which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. 
val_transform: The dictionary of transforms to use during validation which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. test_transform: The dictionary of transforms to use during testing which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. predict_transform: The dictionary of transforms to use during predicting which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to pass to the :class:`~flash.core.data.data_module.DataModule`. - preprocess: The :class:`~flash.core.data.data.Preprocess` to pass to the - :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls`` + input_transform: The :class:`~flash.core.data.data.InputTransform` to pass to the + :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.input_transform_cls`` will be constructed and used. val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. - preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used - if ``preprocess = None``. + input_transform_kwargs: Additional keyword arguments to use when constructing the input_transform. + Will only be used if ``input_transform = None``. Returns: The constructed data module. @@ -254,11 +254,11 @@ def from_coco( test_transform=test_transform, predict_transform=predict_transform, data_fetcher=data_fetcher, - preprocess=preprocess, + input_transform=input_transform, val_split=val_split, batch_size=batch_size, num_workers=num_workers, - **preprocess_kwargs, + **input_transform_kwargs, ) @classmethod @@ -276,11 +276,11 @@ def from_voc( test_transform: Optional[Dict[str, Callable]] = None, predict_transform: Optional[Dict[str, Callable]] = None, data_fetcher: Optional[BaseDataFetcher] = None, - preprocess: Optional[Preprocess] = None, + input_transform: Optional[InputTransform] = None, val_split: Optional[float] = None, batch_size: int = 4, num_workers: int = 0, - **preprocess_kwargs: Any, + **input_transform_kwargs: Any, ): """Creates a :class:`~flash.image.detection.data.ObjectDetectionData` object from the given data folders and annotation files in the VOC format. @@ -294,23 +294,23 @@ def from_voc( test_ann_file: The COCO format annotation file. predict_folder: The folder containing the predict data. train_transform: The dictionary of transforms to use during training which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. val_transform: The dictionary of transforms to use during validation which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. 
test_transform: The dictionary of transforms to use during testing which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. predict_transform: The dictionary of transforms to use during predicting which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to pass to the :class:`~flash.core.data.data_module.DataModule`. - preprocess: The :class:`~flash.core.data.data.Preprocess` to pass to the - :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls`` + input_transform: The :class:`~flash.core.data.data.InputTransform` to pass to the + :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.input_transform_cls`` will be constructed and used. val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. - preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used - if ``preprocess = None``. + input_transform_kwargs: Additional keyword arguments to use when constructing the input_transform. + Will only be used if ``input_transform = None``. Returns: The constructed data module. @@ -333,11 +333,11 @@ def from_voc( test_transform=test_transform, predict_transform=predict_transform, data_fetcher=data_fetcher, - preprocess=preprocess, + input_transform=input_transform, val_split=val_split, batch_size=batch_size, num_workers=num_workers, - **preprocess_kwargs, + **input_transform_kwargs, ) @classmethod @@ -355,11 +355,11 @@ def from_via( test_transform: Optional[Dict[str, Callable]] = None, predict_transform: Optional[Dict[str, Callable]] = None, data_fetcher: Optional[BaseDataFetcher] = None, - preprocess: Optional[Preprocess] = None, + input_transform: Optional[InputTransform] = None, val_split: Optional[float] = None, batch_size: int = 4, num_workers: int = 0, - **preprocess_kwargs: Any, + **input_transform_kwargs: Any, ): """Creates a :class:`~flash.image.detection.data.ObjectDetectionData` object from the given data folders and annotation files in the VIA format. @@ -373,23 +373,23 @@ def from_via( test_ann_file: The COCO format annotation file. predict_folder: The folder containing the predict data. train_transform: The dictionary of transforms to use during training which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. val_transform: The dictionary of transforms to use during validation which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. test_transform: The dictionary of transforms to use during testing which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. 
predict_transform: The dictionary of transforms to use during predicting which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to pass to the :class:`~flash.core.data.data_module.DataModule`. - preprocess: The :class:`~flash.core.data.data.Preprocess` to pass to the - :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls`` + input_transform: The :class:`~flash.core.data.data.InputTransform` to pass to the + :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.input_transform_cls`` will be constructed and used. val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. - preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used - if ``preprocess = None``. + input_transform_kwargs: Additional keyword arguments to use when constructing the input_transform. + Will only be used if ``input_transform = None``. Returns: The constructed data module. @@ -412,9 +412,9 @@ def from_via( test_transform=test_transform, predict_transform=predict_transform, data_fetcher=data_fetcher, - preprocess=preprocess, + input_transform=input_transform, val_split=val_split, batch_size=batch_size, num_workers=num_workers, - **preprocess_kwargs, + **input_transform_kwargs, ) diff --git a/flash/image/face_detection/data.py b/flash/image/face_detection/data.py index 663862e975..c78a03899d 100644 --- a/flash/image/face_detection/data.py +++ b/flash/image/face_detection/data.py @@ -18,8 +18,8 @@ from torch.utils.data import Dataset from flash.core.data.data_source import DatasetDataSource, DefaultDataKeys, DefaultDataSources +from flash.core.data.io.input_transform import InputTransform from flash.core.data.io.output_transform import OutputTransform -from flash.core.data.process import Preprocess from flash.core.data.transforms import ApplyToKeys from flash.core.data.utils import image_default_loader from flash.core.utilities.imports import _FASTFACE_AVAILABLE, _TORCHVISION_AVAILABLE @@ -98,7 +98,7 @@ def load_sample(self, sample: Any, dataset: Optional[Any] = None) -> Mapping[str return sample -class FaceDetectionPreprocess(Preprocess): +class FaceDetectionInputTransform(InputTransform): """Applies default transform and collate_fn for fastface on FastFaceDataSource.""" def __init__( @@ -169,5 +169,5 @@ def per_batch_transform(batch: Any) -> Any: class FaceDetectionData(ObjectDetectionData): - preprocess_cls = FaceDetectionPreprocess + input_transform_cls = FaceDetectionInputTransform output_transform_cls = FaceDetectionOutputTransform diff --git a/flash/image/face_detection/model.py b/flash/image/face_detection/model.py index a19bab84a7..a4beb464d4 100644 --- a/flash/image/face_detection/model.py +++ b/flash/image/face_detection/model.py @@ -22,15 +22,15 @@ from flash.core.model import Task from flash.core.utilities.imports import _FASTFACE_AVAILABLE from flash.core.utilities.types import ( + INPUT_TRANSFORM_TYPE, LOSS_FN_TYPE, LR_SCHEDULER_TYPE, METRICS_TYPE, OPTIMIZER_TYPE, OUTPUT_TYPE, - PREPROCESS_TYPE, ) from flash.image.face_detection.backbones import 
FACE_DETECTION_BACKBONES -from flash.image.face_detection.data import FaceDetectionPreprocess +from flash.image.face_detection.data import FaceDetectionInputTransform if _FASTFACE_AVAILABLE: import fastface as ff @@ -81,7 +81,7 @@ def __init__( lr_scheduler: LR_SCHEDULER_TYPE = None, learning_rate: float = 1e-4, output: OUTPUT_TYPE = None, - preprocess: PREPROCESS_TYPE = None, + input_transform: INPUT_TRANSFORM_TYPE = None, **kwargs: Any, ): self.save_hyperparameters() @@ -99,7 +99,7 @@ def __init__( optimizer=optimizer, lr_scheduler=lr_scheduler, output=output or DetectionLabels(), - preprocess=preprocess or FaceDetectionPreprocess(), + input_transform=input_transform or FaceDetectionInputTransform(), ) @staticmethod @@ -112,7 +112,7 @@ def get_model( # following steps are required since `get_model` needs to return `torch.nn.Module` # moving some required parameters from `fastface.FaceDetector` to `torch.nn.Module` - # set preprocess params + # set input_transform params model.register_buffer("normalizer", getattr(pl_model, "normalizer")) model.register_buffer("mean", getattr(pl_model, "mean")) model.register_buffer("std", getattr(pl_model, "std")) diff --git a/flash/image/instance_segmentation/cli.py b/flash/image/instance_segmentation/cli.py index 97960ae5c9..a8536b90af 100644 --- a/flash/image/instance_segmentation/cli.py +++ b/flash/image/instance_segmentation/cli.py @@ -30,7 +30,7 @@ def from_pets( batch_size: int = 4, num_workers: int = 0, parser: Optional[Callable] = None, - **preprocess_kwargs, + **input_transform_kwargs, ) -> InstanceSegmentationData: """Downloads and loads the pets data set from icedata.""" data_dir = icedata.pets.load_data() @@ -44,7 +44,7 @@ def from_pets( batch_size=batch_size, num_workers=num_workers, parser=parser, - **preprocess_kwargs, + **input_transform_kwargs, ) diff --git a/flash/image/instance_segmentation/data.py b/flash/image/instance_segmentation/data.py index 15d70d7c95..52ff2c197e 100644 --- a/flash/image/instance_segmentation/data.py +++ b/flash/image/instance_segmentation/data.py @@ -16,8 +16,8 @@ from flash.core.data.callback import BaseDataFetcher from flash.core.data.data_module import DataModule from flash.core.data.data_source import DefaultDataKeys, DefaultDataSources +from flash.core.data.io.input_transform import InputTransform from flash.core.data.io.output_transform import OutputTransform -from flash.core.data.process import Preprocess from flash.core.integrations.icevision.data import IceVisionParserDataSource, IceVisionPathsDataSource from flash.core.integrations.icevision.transforms import default_transforms from flash.core.utilities.imports import _ICEVISION_AVAILABLE @@ -29,7 +29,7 @@ VOCMaskParser = object -class InstanceSegmentationPreprocess(Preprocess): +class InstanceSegmentationInputTransform(InputTransform): def __init__( self, train_transform: Optional[Dict[str, Callable]] = None, @@ -79,7 +79,7 @@ def uncollate(batch: Any) -> Any: class InstanceSegmentationData(DataModule): - preprocess_cls = InstanceSegmentationPreprocess + input_transform_cls = InstanceSegmentationInputTransform output_transform_cls = InstanceSegmentationOutputTransform @classmethod @@ -97,11 +97,11 @@ def from_coco( test_transform: Optional[Dict[str, Callable]] = None, predict_transform: Optional[Dict[str, Callable]] = None, data_fetcher: Optional[BaseDataFetcher] = None, - preprocess: Optional[Preprocess] = None, + input_transform: Optional[InputTransform] = None, val_split: Optional[float] = None, batch_size: int = 4, num_workers: int = 0, - 
**preprocess_kwargs: Any, + **input_transform_kwargs: Any, ): """Creates a :class:`~flash.image.instance_segmentation.data.InstanceSegmentationData` object from the given data folders and annotation files in the COCO format. @@ -115,23 +115,23 @@ def from_coco( test_ann_file: The COCO format annotation file. predict_folder: The folder containing the predict data. train_transform: The dictionary of transforms to use during training which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. val_transform: The dictionary of transforms to use during validation which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. test_transform: The dictionary of transforms to use during testing which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. predict_transform: The dictionary of transforms to use during predicting which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to pass to the :class:`~flash.core.data.data_module.DataModule`. - preprocess: The :class:`~flash.core.data.data.Preprocess` to pass to the - :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls`` + input_transform: The :class:`~flash.core.data.data.InputTransform` to pass to the + :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.input_transform_cls`` will be constructed and used. val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. - preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used - if ``preprocess = None``. + input_transform_kwargs: Additional keyword arguments to use when constructing the input_transform. + Will only be used if ``input_transform = None``. Returns: The constructed data module. @@ -154,11 +154,11 @@ def from_coco( test_transform=test_transform, predict_transform=predict_transform, data_fetcher=data_fetcher, - preprocess=preprocess, + input_transform=input_transform, val_split=val_split, batch_size=batch_size, num_workers=num_workers, - **preprocess_kwargs, + **input_transform_kwargs, ) @classmethod @@ -176,11 +176,11 @@ def from_voc( test_transform: Optional[Dict[str, Callable]] = None, predict_transform: Optional[Dict[str, Callable]] = None, data_fetcher: Optional[BaseDataFetcher] = None, - preprocess: Optional[Preprocess] = None, + input_transform: Optional[InputTransform] = None, val_split: Optional[float] = None, batch_size: int = 4, num_workers: int = 0, - **preprocess_kwargs: Any, + **input_transform_kwargs: Any, ): """Creates a :class:`~flash.image.instance_segmentation.data.InstanceSegmentationData` object from the given data folders and annotation files in the VOC format. @@ -194,23 +194,23 @@ def from_voc( test_ann_file: The COCO format annotation file. 
predict_folder: The folder containing the predict data. train_transform: The dictionary of transforms to use during training which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. val_transform: The dictionary of transforms to use during validation which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. test_transform: The dictionary of transforms to use during testing which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. predict_transform: The dictionary of transforms to use during predicting which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to pass to the :class:`~flash.core.data.data_module.DataModule`. - preprocess: The :class:`~flash.core.data.data.Preprocess` to pass to the - :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls`` + input_transform: The :class:`~flash.core.data.data.InputTransform` to pass to the + :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.input_transform_cls`` will be constructed and used. val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. - preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used - if ``preprocess = None``. + input_transform_kwargs: Additional keyword arguments to use when constructing the input_transform. + Will only be used if ``input_transform = None``. Returns: The constructed data module. @@ -233,9 +233,9 @@ def from_voc( test_transform=test_transform, predict_transform=predict_transform, data_fetcher=data_fetcher, - preprocess=preprocess, + input_transform=input_transform, val_split=val_split, batch_size=batch_size, num_workers=num_workers, - **preprocess_kwargs, + **input_transform_kwargs, ) diff --git a/flash/image/instance_segmentation/model.py b/flash/image/instance_segmentation/model.py index cea0d9b7b5..50c1936b9e 100644 --- a/flash/image/instance_segmentation/model.py +++ b/flash/image/instance_segmentation/model.py @@ -21,7 +21,10 @@ from flash.core.registry import FlashRegistry from flash.core.utilities.types import LR_SCHEDULER_TYPE, OPTIMIZER_TYPE, OUTPUT_TYPE from flash.image.instance_segmentation.backbones import INSTANCE_SEGMENTATION_HEADS -from flash.image.instance_segmentation.data import InstanceSegmentationOutputTransform, InstanceSegmentationPreprocess +from flash.image.instance_segmentation.data import ( + InstanceSegmentationInputTransform, + InstanceSegmentationOutputTransform, +) class InstanceSegmentation(AdapterTask): @@ -90,5 +93,6 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: "If you'd like to change this, extend the InstanceSegmentation Task and override `on_load_checkpoint`." 
) self.data_pipeline = DataPipeline( - preprocess=InstanceSegmentationPreprocess(), output_transform=InstanceSegmentationOutputTransform() + input_transform=InstanceSegmentationInputTransform(), + output_transform=InstanceSegmentationOutputTransform(), ) diff --git a/flash/image/keypoint_detection/cli.py b/flash/image/keypoint_detection/cli.py index 959328a51c..7bc753a033 100644 --- a/flash/image/keypoint_detection/cli.py +++ b/flash/image/keypoint_detection/cli.py @@ -29,7 +29,7 @@ def from_biwi( batch_size: int = 4, num_workers: int = 0, parser: Optional[Callable] = None, - **preprocess_kwargs, + **input_transform_kwargs, ) -> KeypointDetectionData: """Downloads and loads the BIWI data set from icedata.""" data_dir = icedata.biwi.load_data() @@ -43,7 +43,7 @@ def from_biwi( batch_size=batch_size, num_workers=num_workers, parser=parser, - **preprocess_kwargs, + **input_transform_kwargs, ) diff --git a/flash/image/keypoint_detection/data.py b/flash/image/keypoint_detection/data.py index 97948a7d40..3885c6cc04 100644 --- a/flash/image/keypoint_detection/data.py +++ b/flash/image/keypoint_detection/data.py @@ -16,7 +16,7 @@ from flash.core.data.callback import BaseDataFetcher from flash.core.data.data_module import DataModule from flash.core.data.data_source import DefaultDataSources -from flash.core.data.process import Preprocess +from flash.core.data.io.input_transform import InputTransform from flash.core.integrations.icevision.data import IceVisionParserDataSource, IceVisionPathsDataSource from flash.core.integrations.icevision.transforms import default_transforms from flash.core.utilities.imports import _ICEVISION_AVAILABLE @@ -27,7 +27,7 @@ COCOKeyPointsParser = object -class KeypointDetectionPreprocess(Preprocess): +class KeypointDetectionInputTransform(InputTransform): def __init__( self, train_transform: Optional[Dict[str, Callable]] = None, @@ -70,7 +70,7 @@ def train_default_transforms(self) -> Optional[Dict[str, Callable]]: class KeypointDetectionData(DataModule): - preprocess_cls = KeypointDetectionPreprocess + input_transform_cls = KeypointDetectionInputTransform @classmethod def from_coco( @@ -87,11 +87,11 @@ def from_coco( test_transform: Optional[Dict[str, Callable]] = None, predict_transform: Optional[Dict[str, Callable]] = None, data_fetcher: Optional[BaseDataFetcher] = None, - preprocess: Optional[Preprocess] = None, + input_transform: Optional[InputTransform] = None, val_split: Optional[float] = None, batch_size: int = 4, num_workers: int = 0, - **preprocess_kwargs: Any, + **input_transform_kwargs: Any, ): """Creates a :class:`~flash.image.keypoint_detection.data.KeypointDetectionData` object from the given data folders and annotation files in the COCO format. @@ -105,23 +105,23 @@ def from_coco( test_ann_file: The COCO format annotation file. predict_folder: The folder containing the predict data. train_transform: The dictionary of transforms to use during training which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. val_transform: The dictionary of transforms to use during validation which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. test_transform: The dictionary of transforms to use during testing which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. 
+ :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. predict_transform: The dictionary of transforms to use during predicting which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to pass to the :class:`~flash.core.data.data_module.DataModule`. - preprocess: The :class:`~flash.core.data.data.Preprocess` to pass to the - :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls`` + input_transform: The :class:`~flash.core.data.data.InputTransform` to pass to the + :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.input_transform_cls`` will be constructed and used. val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. - preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used - if ``preprocess = None``. + input_transform_kwargs: Additional keyword arguments to use when constructing the input_transform. + Will only be used if ``input_transform = None``. Returns: The constructed data module. @@ -144,9 +144,9 @@ def from_coco( test_transform=test_transform, predict_transform=predict_transform, data_fetcher=data_fetcher, - preprocess=preprocess, + input_transform=input_transform, val_split=val_split, batch_size=batch_size, num_workers=num_workers, - **preprocess_kwargs, + **input_transform_kwargs, ) diff --git a/flash/image/segmentation/__init__.py b/flash/image/segmentation/__init__.py index 8def354fd0..6929888f46 100644 --- a/flash/image/segmentation/__init__.py +++ b/flash/image/segmentation/__init__.py @@ -1,2 +1,2 @@ -from flash.image.segmentation.data import SemanticSegmentationData, SemanticSegmentationPreprocess # noqa: F401 +from flash.image.segmentation.data import SemanticSegmentationData, SemanticSegmentationInputTransform # noqa: F401 from flash.image.segmentation.model import SemanticSegmentation # noqa: F401 diff --git a/flash/image/segmentation/cli.py b/flash/image/segmentation/cli.py index 2e92877015..674bbb86e6 100644 --- a/flash/image/segmentation/cli.py +++ b/flash/image/segmentation/cli.py @@ -24,7 +24,7 @@ def from_carla( val_split: float = 0.1, batch_size: int = 4, num_workers: int = 0, - **preprocess_kwargs, + **input_transform_kwargs, ) -> SemanticSegmentationData: """Downloads and loads the CARLA capture data set.""" download_data( @@ -38,7 +38,7 @@ def from_carla( batch_size=batch_size, num_workers=num_workers, num_classes=num_classes, - **preprocess_kwargs, + **input_transform_kwargs, ) diff --git a/flash/image/segmentation/data.py b/flash/image/segmentation/data.py index 0e55db2ffc..920986ce80 100644 --- a/flash/image/segmentation/data.py +++ b/flash/image/segmentation/data.py @@ -33,7 +33,8 @@ PathsDataSource, TensorDataSource, ) -from flash.core.data.process import Deserializer, Preprocess +from flash.core.data.io.input_transform import InputTransform +from flash.core.data.process import Deserializer from flash.core.data.utils import image_default_loader from flash.core.utilities.imports import ( _FIFTYONE_AVAILABLE, @@ -217,7 +218,7 @@ def deserialize(self, data: str) -> 
torch.Tensor: return result -class SemanticSegmentationPreprocess(Preprocess): +class SemanticSegmentationInputTransform(InputTransform): def __init__( self, train_transform: Optional[Dict[str, Callable]] = None, @@ -230,7 +231,7 @@ def __init__( labels_map: Dict[int, Tuple[int, int, int]] = None, **data_source_kwargs: Any, ) -> None: - """Preprocess pipeline for semantic segmentation tasks. + """InputTransform pipeline for semantic segmentation tasks. Args: train_transform: Dictionary with the set of transforms to apply during training. @@ -291,7 +292,7 @@ def predict_default_transforms(self) -> Optional[Dict[str, Callable]]: class SemanticSegmentationData(DataModule): """Data module for semantic segmentation tasks.""" - preprocess_cls = SemanticSegmentationPreprocess + input_transform_cls = SemanticSegmentationInputTransform @staticmethod def configure_data_fetcher( @@ -316,19 +317,19 @@ def from_data_source( test_transform: Optional[Dict[str, Callable]] = None, predict_transform: Optional[Dict[str, Callable]] = None, data_fetcher: Optional[BaseDataFetcher] = None, - preprocess: Optional[Preprocess] = None, + input_transform: Optional[InputTransform] = None, val_split: Optional[float] = None, batch_size: int = 4, num_workers: int = 0, - **preprocess_kwargs: Any, + **input_transform_kwargs: Any, ) -> "DataModule": - if "num_classes" not in preprocess_kwargs: + if "num_classes" not in input_transform_kwargs: raise MisconfigurationException("`num_classes` should be provided during instantiation.") - num_classes = preprocess_kwargs["num_classes"] + num_classes = input_transform_kwargs["num_classes"] - labels_map = getattr(preprocess_kwargs, "labels_map", None) or SegmentationLabels.create_random_labels_map( + labels_map = getattr(input_transform_kwargs, "labels_map", None) or SegmentationLabels.create_random_labels_map( num_classes ) @@ -348,11 +349,11 @@ def from_data_source( test_transform=test_transform, predict_transform=predict_transform, data_fetcher=data_fetcher, - preprocess=preprocess, + input_transform=input_transform, val_split=val_split, batch_size=batch_size, num_workers=num_workers, - **preprocess_kwargs, + **input_transform_kwargs, ) if dm.train_dataset is not None: @@ -374,13 +375,13 @@ def from_folders( test_transform: Optional[Dict[str, Callable]] = None, predict_transform: Optional[Dict[str, Callable]] = None, data_fetcher: Optional[BaseDataFetcher] = None, - preprocess: Optional[Preprocess] = None, + input_transform: Optional[InputTransform] = None, val_split: Optional[float] = None, batch_size: int = 4, num_workers: int = 0, num_classes: Optional[int] = None, labels_map: Dict[int, Tuple[int, int, int]] = None, - **preprocess_kwargs, + **input_transform_kwargs, ) -> "DataModule": """Creates a :class:`~flash.image.segmentation.data.SemanticSegmentationData` object from the given data folders and corresponding target folders. @@ -397,25 +398,25 @@ def from_folders( corresponding inputs). predict_folder: The folder containing the predict data. train_transform: The dictionary of transforms to use during training which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. val_transform: The dictionary of transforms to use during validation which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. 
test_transform: The dictionary of transforms to use during testing which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. predict_transform: The dictionary of transforms to use during predicting which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to pass to the :class:`~flash.core.data.data_module.DataModule`. - preprocess: The :class:`~flash.core.data.data.Preprocess` to pass to the - :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls`` + input_transform: The :class:`~flash.core.data.data.InputTransform` to pass to the + :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.input_transform_cls`` will be constructed and used. val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. num_classes: Number of classes within the segmentation mask. labels_map: Mapping between a class_id and its corresponding color. - preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used - if ``preprocess = None``. + input_transform_kwargs: Additional keyword arguments to use when constructing the input_transform. + Will only be used if ``input_transform = None``. Returns: The constructed data module. 
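# --- Editor's illustrative sketch (not part of the patch) --------------------------------
# With the rename, `num_classes` (and any other extra keyword arguments) reach
# `SemanticSegmentationInputTransform` through `**input_transform_kwargs` unless an
# `input_transform` instance is passed explicitly. The folder paths and class count below
# are hypothetical placeholders, not values taken from this PR.
from flash.image import SemanticSegmentationData

datamodule = SemanticSegmentationData.from_folders(
    train_folder="data/CameraRGB",         # hypothetical input-image folder
    train_target_folder="data/CameraSeg",  # hypothetical mask folder
    batch_size=4,
    num_classes=21,  # required by `from_data_source` above; forwarded to the InputTransform
)
# ------------------------------------------------------------------------------------------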
@@ -438,13 +439,13 @@ def from_folders( test_transform=test_transform, predict_transform=predict_transform, data_fetcher=data_fetcher, - preprocess=preprocess, + input_transform=input_transform, val_split=val_split, batch_size=batch_size, num_workers=num_workers, num_classes=num_classes, labels_map=labels_map, - **preprocess_kwargs, + **input_transform_kwargs, ) diff --git a/flash/image/style_transfer/__init__.py b/flash/image/style_transfer/__init__.py index 401c253609..fc1f862ced 100644 --- a/flash/image/style_transfer/__init__.py +++ b/flash/image/style_transfer/__init__.py @@ -1,3 +1,3 @@ from flash.image.style_transfer.backbones import STYLE_TRANSFER_BACKBONES # noqa: F401 -from flash.image.style_transfer.data import StyleTransferData, StyleTransferPreprocess # noqa: F401 +from flash.image.style_transfer.data import StyleTransferData, StyleTransferInputTransform # noqa: F401 from flash.image.style_transfer.model import StyleTransfer # noqa: F401 diff --git a/flash/image/style_transfer/cli.py b/flash/image/style_transfer/cli.py index 0aab00a4e3..652e94c82b 100644 --- a/flash/image/style_transfer/cli.py +++ b/flash/image/style_transfer/cli.py @@ -24,7 +24,7 @@ def from_coco_128( batch_size: int = 4, num_workers: int = 0, - **preprocess_kwargs, + **input_transform_kwargs, ) -> StyleTransferData: """Downloads and loads the COCO 128 data set.""" download_data("https://github.com/zhiqwang/yolov5-rt-stack/releases/download/v0.3.0/coco128.zip", "data/") @@ -32,7 +32,7 @@ def from_coco_128( train_folder="data/coco128/images/train2017/", batch_size=batch_size, num_workers=num_workers, - **preprocess_kwargs, + **input_transform_kwargs, ) diff --git a/flash/image/style_transfer/data.py b/flash/image/style_transfer/data.py index f9f63c5905..73a6621f31 100644 --- a/flash/image/style_transfer/data.py +++ b/flash/image/style_transfer/data.py @@ -19,7 +19,7 @@ from flash.core.data.data_module import DataModule from flash.core.data.data_source import DefaultDataKeys, DefaultDataSources -from flash.core.data.process import Preprocess +from flash.core.data.io.input_transform import InputTransform from flash.core.data.transforms import ApplyToKeys from flash.core.utilities.imports import _TORCHVISION_AVAILABLE from flash.image.classification import ImageClassificationData @@ -29,7 +29,7 @@ if _TORCHVISION_AVAILABLE: from torchvision import transforms as T -__all__ = ["StyleTransferPreprocess", "StyleTransferData"] +__all__ = ["StyleTransferInputTransform", "StyleTransferData"] def _apply_to_input( @@ -46,7 +46,7 @@ def wrapper(*args: Any, **kwargs: Any) -> Optional[Dict[str, ApplyToKeys]]: return wrapper -class StyleTransferPreprocess(Preprocess): +class StyleTransferInputTransform(InputTransform): def __init__( self, train_transform: Optional[Union[Dict[str, Callable]]] = None, @@ -107,7 +107,7 @@ def default_transforms(self) -> Optional[Dict[str, Callable]]: class StyleTransferData(ImageClassificationData): - preprocess_cls = StyleTransferPreprocess + input_transform_cls = StyleTransferInputTransform @classmethod def from_folders( @@ -116,7 +116,7 @@ def from_folders( predict_folder: Optional[Union[str, pathlib.Path]] = None, train_transform: Optional[Union[str, Dict]] = None, predict_transform: Optional[Union[str, Dict]] = None, - preprocess: Optional[Preprocess] = None, + input_transform: Optional[InputTransform] = None, **kwargs: Any, ) -> "DataModule": @@ -126,7 +126,7 @@ def from_folders( if any(param in kwargs and kwargs[param] is not None for param in ("test_folder", "test_transform")): 
raise_not_supported("test") - preprocess = preprocess or cls.preprocess_cls( + input_transform = input_transform or cls.input_transform_cls( train_transform=train_transform, predict_transform=predict_transform, ) @@ -135,6 +135,6 @@ def from_folders( DefaultDataSources.FOLDERS, train_data=train_folder, predict_data=predict_folder, - preprocess=preprocess, + input_transform=input_transform, **kwargs, ) diff --git a/flash/pointcloud/detection/cli.py b/flash/pointcloud/detection/cli.py index 1acbef5efa..bdb7663a25 100644 --- a/flash/pointcloud/detection/cli.py +++ b/flash/pointcloud/detection/cli.py @@ -22,7 +22,7 @@ def from_kitti( batch_size: int = 4, num_workers: int = 0, - **preprocess_kwargs, + **input_transform_kwargs, ) -> PointCloudObjectDetectorData: """Downloads and loads the KITTI data set.""" download_data("https://pl-flash-data.s3.amazonaws.com/KITTI_tiny.zip", "data/") @@ -31,7 +31,7 @@ def from_kitti( val_folder="data/KITTI_Tiny/Kitti/val", batch_size=batch_size, num_workers=num_workers, - **preprocess_kwargs, + **input_transform_kwargs, ) diff --git a/flash/pointcloud/detection/data.py b/flash/pointcloud/detection/data.py index e565c358b7..2e6b43f795 100644 --- a/flash/pointcloud/detection/data.py +++ b/flash/pointcloud/detection/data.py @@ -5,7 +5,8 @@ from flash.core.data.base_viz import BaseDataFetcher from flash.core.data.data_module import DataModule from flash.core.data.data_source import BaseDataFormat, DataSource, DefaultDataKeys, DefaultDataSources -from flash.core.data.process import Deserializer, Preprocess +from flash.core.data.io.input_transform import InputTransform +from flash.core.data.process import Deserializer from flash.pointcloud.detection.open3d_ml.data_sources import ( PointCloudObjectDetectionDataFormat, PointCloudObjectDetectorFoldersDataSource, @@ -35,7 +36,7 @@ def load_sample(self, index: int, dataset: Optional[Any] = None) -> Any: } -class PointCloudObjectDetectorPreprocess(Preprocess): +class PointCloudObjectDetectorInputTransform(InputTransform): def __init__( self, train_transform: Optional[Dict[str, Callable]] = None, @@ -72,7 +73,7 @@ def load_state_dict(cls, state_dict, strict: bool = False): class PointCloudObjectDetectorData(DataModule): - preprocess_cls = PointCloudObjectDetectorPreprocess + input_transform_cls = PointCloudObjectDetectorInputTransform @classmethod def from_folders( @@ -86,7 +87,7 @@ def from_folders( test_transform: Optional[Dict[str, Callable]] = None, predict_transform: Optional[Dict[str, Callable]] = None, data_fetcher: Optional[BaseDataFetcher] = None, - preprocess: Optional[Preprocess] = None, + input_transform: Optional[InputTransform] = None, val_split: Optional[float] = None, batch_size: int = 4, num_workers: int = 0, @@ -95,12 +96,12 @@ def from_folders( labels_folder_name: Optional[str] = "labels", calibrations_folder_name: Optional[str] = "calibs", data_format: Optional[BaseDataFormat] = PointCloudObjectDetectionDataFormat.KITTI, - **preprocess_kwargs: Any, + **input_transform_kwargs: Any, ) -> "DataModule": """Creates a :class:`~flash.core.data.data_module.DataModule` object from the given folders using the :class:`~flash.core.data.data_source.DataSource` of name :attr:`~flash.core.data.data_source.DefaultDataSources.FOLDERS` - from the passed or constructed :class:`~flash.core.data.process.Preprocess`. + from the passed or constructed :class:`~flash.core.data.io.input_transform.InputTransform`. Args: train_folder: The folder containing the train data. 
@@ -108,24 +109,24 @@ def from_folders( test_folder: The folder containing the test data. predict_folder: The folder containing the predict data. train_transform: The dictionary of transforms to use during training which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. val_transform: The dictionary of transforms to use during validation which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. test_transform: The dictionary of transforms to use during testing which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. predict_transform: The dictionary of transforms to use during predicting which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to pass to the :class:`~flash.core.data.data_module.DataModule`. - preprocess: The :class:`~flash.core.data.data.Preprocess` to pass to the - :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls`` + input_transform: The :class:`~flash.core.data.data.InputTransform` to pass to the + :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.input_transform_cls`` will be constructed and used. val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. sampler: The ``sampler`` to use for the ``train_dataloader``. - preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used - if ``preprocess = None``. + input_transform_kwargs: Additional keyword arguments to use when constructing the input_transform. + Will only be used if ``input_transform = None``. 
scans_folder_name: The name of the pointcloud scan folder labels_folder_name: The name of the pointcloud scan labels folder calibrations_folder_name: The name of the pointcloud scan calibration folder @@ -154,7 +155,7 @@ def from_folders( test_transform=test_transform, predict_transform=predict_transform, data_fetcher=data_fetcher, - preprocess=preprocess, + input_transform=input_transform, val_split=val_split, batch_size=batch_size, num_workers=num_workers, @@ -163,5 +164,5 @@ def from_folders( labels_folder_name=labels_folder_name, calibrations_folder_name=calibrations_folder_name, data_format=data_format, - **preprocess_kwargs, + **input_transform_kwargs, ) diff --git a/flash/pointcloud/detection/model.py b/flash/pointcloud/detection/model.py index d0671b8648..7fb0500483 100644 --- a/flash/pointcloud/detection/model.py +++ b/flash/pointcloud/detection/model.py @@ -159,7 +159,7 @@ def _process_dataset( if not _POINTCLOUD_AVAILABLE: raise ModuleNotFoundError("Please, run `pip install flash[pointcloud]`.") - dataset.preprocess_fn = self.model.preprocess + dataset.input_transform_fn = self.model.preprocess dataset.transform_fn = self.model.transform return DataLoader( diff --git a/flash/pointcloud/detection/open3d_ml/data_sources.py b/flash/pointcloud/detection/open3d_ml/data_sources.py index ba7b84f670..6b945b0b2d 100644 --- a/flash/pointcloud/detection/open3d_ml/data_sources.py +++ b/flash/pointcloud/detection/open3d_ml/data_sources.py @@ -191,9 +191,9 @@ def load_sample(self, metadata: Dict[str, str], dataset: Optional[BaseAutoDatase data, metadata = self.loader.load_sample(metadata, dataset) - preprocess_fn = getattr(dataset, "preprocess_fn", None) - if preprocess_fn: - data = preprocess_fn(data, metadata) + input_transform_fn = getattr(dataset, "input_transform_fn", None) + if input_transform_fn: + data = input_transform_fn(data, metadata) transform_fn = getattr(dataset, "transform_fn", None) if transform_fn: @@ -230,9 +230,9 @@ def predict_load_sample( data, metadata = self.loader.predict_load_sample(metadata, dataset) - preprocess_fn = getattr(dataset, "preprocess_fn", None) - if preprocess_fn: - data = preprocess_fn(data, metadata) + input_transform_fn = getattr(dataset, "input_transform_fn", None) + if input_transform_fn: + data = input_transform_fn(data, metadata) transform_fn = getattr(dataset, "transform_fn", None) if transform_fn: diff --git a/flash/pointcloud/segmentation/cli.py b/flash/pointcloud/segmentation/cli.py index 26a147d68b..3e142bb535 100644 --- a/flash/pointcloud/segmentation/cli.py +++ b/flash/pointcloud/segmentation/cli.py @@ -22,7 +22,7 @@ def from_kitti( batch_size: int = 4, num_workers: int = 0, - **preprocess_kwargs, + **input_transform_kwargs, ) -> PointCloudSegmentationData: """Downloads and loads the semantic KITTI data set.""" download_data("https://pl-flash-data.s3.amazonaws.com/SemanticKittiTiny.zip", "data/") @@ -31,7 +31,7 @@ def from_kitti( val_folder="data/SemanticKittiTiny/val", batch_size=batch_size, num_workers=num_workers, - **preprocess_kwargs, + **input_transform_kwargs, ) diff --git a/flash/pointcloud/segmentation/data.py b/flash/pointcloud/segmentation/data.py index 984296ae9c..05293dffe3 100644 --- a/flash/pointcloud/segmentation/data.py +++ b/flash/pointcloud/segmentation/data.py @@ -2,7 +2,8 @@ from flash.core.data.data_module import DataModule from flash.core.data.data_source import DataSource, DefaultDataKeys, DefaultDataSources -from flash.core.data.process import Deserializer, Preprocess +from flash.core.data.io.input_transform import 
InputTransform +from flash.core.data.process import Deserializer from flash.core.utilities.imports import requires from flash.pointcloud.segmentation.open3d_ml.sequences_dataset import SequencesDataset @@ -52,7 +53,7 @@ def load_sample(self, index: int, dataset: Optional[Any] = None) -> Any: } -class PointCloudSegmentationPreprocess(Preprocess): +class PointCloudSegmentationInputTransform(InputTransform): def __init__( self, train_transform: Optional[Dict[str, Callable]] = None, @@ -90,4 +91,4 @@ def load_state_dict(cls, state_dict, strict: bool = False): class PointCloudSegmentationData(DataModule): - preprocess_cls = PointCloudSegmentationPreprocess + input_transform_cls = PointCloudSegmentationInputTransform diff --git a/flash/tabular/__init__.py b/flash/tabular/__init__.py index b62af944c0..ae025f94b9 100644 --- a/flash/tabular/__init__.py +++ b/flash/tabular/__init__.py @@ -3,6 +3,6 @@ from flash.tabular.forecasting.data import ( # noqa: F401 TabularForecastingData, TabularForecastingDataFrameDataSource, - TabularForecastingPreprocess, + TabularForecastingInputTransform, ) from flash.tabular.regression import TabularRegressionData, TabularRegressor # noqa: F401 diff --git a/flash/tabular/classification/cli.py b/flash/tabular/classification/cli.py index fe3f8a9ae3..7be4c08818 100644 --- a/flash/tabular/classification/cli.py +++ b/flash/tabular/classification/cli.py @@ -23,7 +23,7 @@ def from_titanic( batch_size: int = 4, num_workers: int = 0, - **preprocess_kwargs, + **input_transform_kwargs, ) -> TabularClassificationData: """Downloads and loads the Titanic data set.""" download_data("https://pl-flash-data.s3.amazonaws.com/titanic.zip", "./data") @@ -35,7 +35,7 @@ def from_titanic( val_split=0.1, batch_size=batch_size, num_workers=num_workers, - **preprocess_kwargs, + **input_transform_kwargs, ) diff --git a/flash/tabular/data.py b/flash/tabular/data.py index 6c981019bd..b21e5e485c 100644 --- a/flash/tabular/data.py +++ b/flash/tabular/data.py @@ -22,8 +22,9 @@ from flash.core.data.callback import BaseDataFetcher from flash.core.data.data_module import DataModule from flash.core.data.data_source import DataSource, DefaultDataKeys, DefaultDataSources +from flash.core.data.io.input_transform import InputTransform from flash.core.data.io.output_transform import OutputTransform -from flash.core.data.process import Deserializer, Preprocess +from flash.core.data.process import Deserializer from flash.core.utilities.imports import _PANDAS_AVAILABLE from flash.tabular.classification.utils import ( _compute_normalization, @@ -158,7 +159,7 @@ def example_input(self) -> str: return str(DataFrame.from_dict(row).to_csv()) -class TabularPreprocess(Preprocess): +class TabularInputTransform(InputTransform): def __init__( self, train_transform: Optional[Dict[str, Callable]] = None, @@ -231,7 +232,7 @@ def get_state_dict(self, strict: bool = False) -> Dict[str, Any]: } @classmethod - def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = True) -> "Preprocess": + def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = True) -> "InputTransform": return cls(**state_dict) @@ -243,7 +244,7 @@ def uncollate(self, batch: Any) -> Any: class TabularData(DataModule): """Data module for tabular tasks.""" - preprocess_cls = TabularPreprocess + input_transform_cls = TabularInputTransform output_transform_cls = TabularOutputTransform is_regression: bool = False @@ -342,12 +343,12 @@ def from_data_frame( test_transform: Optional[Dict[str, Callable]] = None, predict_transform: 
Optional[Dict[str, Callable]] = None, data_fetcher: Optional[BaseDataFetcher] = None, - preprocess: Optional[Preprocess] = None, + input_transform: Optional[InputTransform] = None, val_split: Optional[float] = None, batch_size: int = 4, num_workers: int = 0, sampler: Optional[Type[Sampler]] = None, - **preprocess_kwargs: Any, + **input_transform_kwargs: Any, ): """Creates a :class:`~flash.tabular.data.TabularData` object from the given data frames. @@ -360,24 +361,24 @@ def from_data_frame( test_data_frame: The pandas ``DataFrame`` containing the testing data. predict_data_frame: The pandas ``DataFrame`` containing the data to use when predicting. train_transform: The dictionary of transforms to use during training which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. val_transform: The dictionary of transforms to use during validation which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. test_transform: The dictionary of transforms to use during testing which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. predict_transform: The dictionary of transforms to use during predicting which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to pass to the :class:`~flash.core.data.data_module.DataModule`. - preprocess: The :class:`~flash.core.data.data.Preprocess` to pass to the - :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls`` + input_transform: The :class:`~flash.core.data.data.InputTransform` to pass to the + :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.input_transform_cls`` will be constructed and used. val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. sampler: The ``sampler`` to use for the ``train_dataloader``. - preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used - if ``preprocess = None``. + input_transform_kwargs: Additional keyword arguments to use when constructing the input_transform. + Will only be used if ``input_transform = None``. Returns: The constructed data module. 
@@ -420,7 +421,7 @@ def from_data_frame( test_transform=test_transform, predict_transform=predict_transform, data_fetcher=data_fetcher, - preprocess=preprocess, + input_transform=input_transform, val_split=val_split, batch_size=batch_size, num_workers=num_workers, @@ -434,7 +435,7 @@ def from_data_frame( target_codes=target_codes, classes=classes, is_regression=cls.is_regression, - **preprocess_kwargs, + **input_transform_kwargs, ) @classmethod @@ -452,12 +453,12 @@ def from_csv( test_transform: Optional[Dict[str, Callable]] = None, predict_transform: Optional[Dict[str, Callable]] = None, data_fetcher: Optional[BaseDataFetcher] = None, - preprocess: Optional[Preprocess] = None, + input_transform: Optional[InputTransform] = None, val_split: Optional[float] = None, batch_size: int = 4, num_workers: int = 0, sampler: Optional[Type[Sampler]] = None, - **preprocess_kwargs: Any, + **input_transform_kwargs: Any, ) -> "DataModule": """Creates a :class:`~flash.tabular.data.TabularData` object from the given CSV files. @@ -470,24 +471,24 @@ def from_csv( test_file: The CSV file containing the testing data. predict_file: The CSV file containing the data to use when predicting. train_transform: The dictionary of transforms to use during training which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. val_transform: The dictionary of transforms to use during validation which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. test_transform: The dictionary of transforms to use during testing which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. predict_transform: The dictionary of transforms to use during predicting which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to pass to the :class:`~flash.core.data.data_module.DataModule`. - preprocess: The :class:`~flash.core.data.data.Preprocess` to pass to the - :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls`` + input_transform: The :class:`~flash.core.data.data.InputTransform` to pass to the + :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.input_transform_cls`` will be constructed and used. val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. sampler: The ``sampler`` to use for the ``train_dataloader``. - preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used - if ``preprocess = None``. + input_transform_kwargs: Additional keyword arguments to use when constructing the input_transform. + Will only be used if ``input_transform = None``. Returns: The constructed data module. 
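``TabularData.from_csv`` follows the same pattern. A hedged sketch of a call after this change — the column-selection arguments (``categorical_fields``, ``numerical_fields``, ``target_fields``) and the Titanic file path reflect the existing tabular API and its CLI example rather than lines in this diff, so treat them as illustrative::

    from flash.tabular import TabularClassificationData

    datamodule = TabularClassificationData.from_csv(
        categorical_fields=["Sex", "Embarked"],  # illustrative feature columns
        numerical_fields=["Age", "Fare"],        # illustrative feature columns
        target_fields="Survived",                # illustrative target column
        train_file="data/titanic/titanic.csv",   # illustrative path
        val_split=0.1,
        batch_size=4,
        # Any extra keyword arguments are now collected as **input_transform_kwargs and
        # are only used to construct TabularInputTransform when input_transform is None.
    )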
@@ -509,10 +510,10 @@ def from_csv( val_data_frame=pd.read_csv(val_file) if val_file is not None else None, test_data_frame=pd.read_csv(test_file) if test_file is not None else None, predict_data_frame=pd.read_csv(predict_file) if predict_file is not None else None, - preprocess=preprocess, + input_transform=input_transform, val_split=val_split, batch_size=batch_size, num_workers=num_workers, sampler=sampler, - **preprocess_kwargs, + **input_transform_kwargs, ) diff --git a/flash/tabular/forecasting/cli.py b/flash/tabular/forecasting/cli.py index 5757e01031..c1510fdc74 100644 --- a/flash/tabular/forecasting/cli.py +++ b/flash/tabular/forecasting/cli.py @@ -34,7 +34,7 @@ def from_synthetic_ar_data( max_prediction_length: int = 20, batch_size: int = 32, num_workers: int = 0, - **preprocess_kwargs, + **input_transform_kwargs, ) -> TabularForecastingData: """Creates and loads a synthetic auto-regressive (AR) data set.""" data = generate_ar_data(seasonality=seasonality, timesteps=timesteps, n_series=n_series, seed=42) @@ -55,7 +55,7 @@ def from_synthetic_ar_data( val_data_frame=data, batch_size=batch_size, num_workers=num_workers, - **preprocess_kwargs, + **input_transform_kwargs, ) diff --git a/flash/tabular/forecasting/data.py b/flash/tabular/forecasting/data.py index 85528c4e19..4fe7591a32 100644 --- a/flash/tabular/forecasting/data.py +++ b/flash/tabular/forecasting/data.py @@ -20,7 +20,8 @@ from flash.core.data.callback import BaseDataFetcher from flash.core.data.data_module import DataModule from flash.core.data.data_source import DataSource, DefaultDataKeys, DefaultDataSources -from flash.core.data.process import Deserializer, Preprocess +from flash.core.data.io.input_transform import InputTransform +from flash.core.data.process import Deserializer from flash.core.data.properties import ProcessState from flash.core.utilities.imports import _FORECASTING_AVAILABLE, _PANDAS_AVAILABLE, requires @@ -93,7 +94,7 @@ def load_sample(self, sample: Mapping[str, Any], dataset: Optional[Any] = None) return {DefaultDataKeys.INPUT: sample[0], DefaultDataKeys.TARGET: sample[1]} -class TabularForecastingPreprocess(Preprocess): +class TabularForecastingInputTransform(InputTransform): def __init__( self, train_transform: Optional[Dict[str, Callable]] = None, @@ -120,14 +121,14 @@ def get_state_dict(self, strict: bool = False) -> Dict[str, Any]: return {**self.transforms, **self.data_source_kwargs} @classmethod - def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = True) -> "Preprocess": + def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = True) -> "InputTransform": return cls(**state_dict) class TabularForecastingData(DataModule): """Data module for the tabular forecasting task.""" - preprocess_cls = TabularForecastingPreprocess + input_transform_cls = TabularForecastingInputTransform @property def parameters(self) -> Optional[Dict[str, Any]]: @@ -151,11 +152,11 @@ def from_data_frame( test_transform: Optional[Dict[str, Callable]] = None, predict_transform: Optional[Dict[str, Callable]] = None, data_fetcher: Optional[BaseDataFetcher] = None, - preprocess: Optional[Preprocess] = None, + input_transform: Optional[InputTransform] = None, val_split: Optional[float] = None, batch_size: int = 4, num_workers: Optional[int] = None, - **preprocess_kwargs: Any, + **input_transform_kwargs: Any, ): """Creates a :class:`~flash.tabular.forecasting.data.TabularForecastingData` object from the given data frames. 
@@ -180,23 +181,23 @@ def from_data_frame( test_data_frame: The pandas ``DataFrame`` containing the testing data. predict_data_frame: The pandas ``DataFrame`` containing the data to use when predicting. train_transform: The dictionary of transforms to use during training which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. val_transform: The dictionary of transforms to use during validation which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. test_transform: The dictionary of transforms to use during testing which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. predict_transform: The dictionary of transforms to use during predicting which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to pass to the :class:`~flash.core.data.data_module.DataModule`. - preprocess: The :class:`~flash.core.data.data.Preprocess` to pass to the - :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls`` + input_transform: The :class:`~flash.core.data.data.InputTransform` to pass to the + :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.input_transform_cls`` will be constructed and used. val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. - preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used - if ``preprocess = None``. + input_transform_kwargs: Additional keyword arguments to use when constructing the input_transform. + Will only be used if ``input_transform = None``. Returns: The constructed data module. 
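For ``TabularForecastingData``, the catch-all keyword arguments that describe the underlying time-series dataset are exactly what the ``**preprocess_kwargs`` → ``**input_transform_kwargs`` rename touches. A sketch based on the synthetic AR CLI above; the ``time_idx`` / ``target`` / ``group_ids`` keywords and the ``generate_ar_data`` import path follow pytorch-forecasting conventions and are assumptions, not lines from this diff::

    from flash.tabular import TabularForecastingData
    from pytorch_forecasting.data.examples import generate_ar_data  # assumed import path

    data = generate_ar_data(seasonality=10.0, timesteps=400, n_series=100, seed=42)

    # The dataset-definition keywords below are gathered into **input_transform_kwargs
    # and forwarded to TabularForecastingInputTransform when input_transform is None.
    datamodule = TabularForecastingData.from_data_frame(
        time_idx="time_idx",
        target="value",
        group_ids=["series"],
        time_varying_unknown_reals=["value"],
        max_prediction_length=20,
        train_data_frame=data,
        val_data_frame=data,
        batch_size=32,
    )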
@@ -226,9 +227,9 @@ def from_data_frame( test_transform=test_transform, predict_transform=predict_transform, data_fetcher=data_fetcher, - preprocess=preprocess, + input_transform=input_transform, val_split=val_split, batch_size=batch_size, num_workers=num_workers, - **preprocess_kwargs, + **input_transform_kwargs, ) diff --git a/flash/template/classification/data.py b/flash/template/classification/data.py index 66cb30eb97..989112723f 100644 --- a/flash/template/classification/data.py +++ b/flash/template/classification/data.py @@ -21,7 +21,7 @@ from flash.core.data.callback import BaseDataFetcher from flash.core.data.data_module import DataModule from flash.core.data.data_source import DefaultDataKeys, DefaultDataSources, LabelsState, NumpyDataSource -from flash.core.data.process import Preprocess +from flash.core.data.io.input_transform import InputTransform from flash.core.data.transforms import ApplyToKeys from flash.core.utilities.imports import _SKLEARN_AVAILABLE from flash.core.utilities.stages import RunningStage @@ -82,8 +82,8 @@ def predict_load_data(self, data: Bunch) -> Sequence[Mapping[str, Any]]: return super().predict_load_data(data.data) -class TemplatePreprocess(Preprocess): - """An example :class:`~flash.core.data.process.Preprocess`. +class TemplateInputTransform(InputTransform): + """An example :class:`~flash.core.data.io.input_transform.InputTransform`. Args: train_transform: The user-specified transforms to apply during training. @@ -151,15 +151,15 @@ def default_transforms(self) -> Optional[Dict[str, Callable]]: class TemplateData(DataModule): - """Creating our :class:`~flash.core.data.data_module.DataModule` is as easy as setting the ``preprocess_cls`` - attribute. + """Creating our :class:`~flash.core.data.data_module.DataModule` is as easy as setting the + ``input_transform_cls`` attribute. We get the ``from_numpy`` method for free as we've configured a ``DefaultDataSources.NUMPY`` data source. We'll also add a ``from_sklearn`` method so that we can use our ``TemplateSKLearnDataSource. Finally, we define the ``num_features`` property for convenience. """ - preprocess_cls = TemplatePreprocess + input_transform_cls = TemplateInputTransform @classmethod def from_sklearn( @@ -173,11 +173,11 @@ def from_sklearn( test_transform: Optional[Dict[str, Callable]] = None, predict_transform: Optional[Dict[str, Callable]] = None, data_fetcher: Optional[BaseDataFetcher] = None, - preprocess: Optional[Preprocess] = None, + input_transform: Optional[InputTransform] = None, val_split: Optional[float] = None, batch_size: int = 4, num_workers: int = 0, - **preprocess_kwargs: Any, + **input_transform_kwargs: Any, ): """This is our custom ``from_*`` method. It expects scikit-learn ``Bunch`` objects as input and passes them through to the :meth:`~flash.core.data.data_module.DataModule.from_data_source` method underneath. @@ -188,23 +188,23 @@ def from_sklearn( test_bunch: The scikit-learn ``Bunch`` containing the test data. predict_bunch: The scikit-learn ``Bunch`` containing the predict data. train_transform: The dictionary of transforms to use during training which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. val_transform: The dictionary of transforms to use during validation which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. 
+ :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. test_transform: The dictionary of transforms to use during testing which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. predict_transform: The dictionary of transforms to use during predicting which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to pass to the :class:`~flash.core.data.data_module.DataModule`. - preprocess: The :class:`~flash.core.data.data.Preprocess` to pass to the - :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls`` will be + input_transform: The :class:`~flash.core.data.data.InputTransform` to pass to the + :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.input_transform_cls`` will be constructed and used. val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. - preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used - if ``preprocess = None``. + input_transform_kwargs: Additional keyword arguments to use when constructing the input_transform. + Will only be used if ``input_transform = None``. Returns: The constructed data module. @@ -220,11 +220,11 @@ def from_sklearn( test_transform=test_transform, predict_transform=predict_transform, data_fetcher=data_fetcher, - preprocess=preprocess, + input_transform=input_transform, val_split=val_split, batch_size=batch_size, num_workers=num_workers, - **preprocess_kwargs, + **input_transform_kwargs, ) @property diff --git a/flash/text/classification/cli.py b/flash/text/classification/cli.py index 0b7be2bd11..d12aaa52ec 100644 --- a/flash/text/classification/cli.py +++ b/flash/text/classification/cli.py @@ -23,7 +23,7 @@ def from_imdb( backbone: str = "prajjwal1/bert-medium", batch_size: int = 4, num_workers: int = 0, - **preprocess_kwargs, + **input_transform_kwargs, ) -> TextClassificationData: """Downloads and loads the IMDB sentiment classification data set.""" download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", "./data/") @@ -35,7 +35,7 @@ def from_imdb( backbone=backbone, batch_size=batch_size, num_workers=num_workers, - **preprocess_kwargs, + **input_transform_kwargs, ) @@ -44,7 +44,7 @@ def from_toxic( val_split: float = 0.1, batch_size: int = 4, num_workers: int = 0, - **preprocess_kwargs, + **input_transform_kwargs, ) -> TextClassificationData: """Downloads and loads the Jigsaw toxic comments data set.""" download_data("https://pl-flash-data.s3.amazonaws.com/jigsaw_toxic_comments.zip", "./data") @@ -56,7 +56,7 @@ def from_toxic( val_split=val_split, batch_size=batch_size, num_workers=num_workers, - **preprocess_kwargs, + **input_transform_kwargs, ) diff --git a/flash/text/classification/data.py b/flash/text/classification/data.py index d64e9026c0..6290caa44f 100644 --- a/flash/text/classification/data.py +++ b/flash/text/classification/data.py @@ -24,8 +24,9 @@ from flash.core.data.callback import BaseDataFetcher from 
flash.core.data.data_module import DataModule from flash.core.data.data_source import DataSource, DefaultDataKeys, DefaultDataSources, LabelsState +from flash.core.data.io.input_transform import InputTransform from flash.core.data.io.output_transform import OutputTransform -from flash.core.data.process import Deserializer, Preprocess +from flash.core.data.process import Deserializer from flash.core.integrations.labelstudio.data_source import LabelStudioTextClassificationDataSource from flash.core.utilities.imports import _TEXT_AVAILABLE, requires @@ -252,7 +253,7 @@ def load_data( return hf_dataset -class TextClassificationPreprocess(Preprocess): +class TextClassificationInputTransform(InputTransform): @requires("text") def __init__( self, @@ -323,12 +324,12 @@ def per_batch_transform(self, batch: Any) -> Any: class TextClassificationData(DataModule): """Data Module for text classification tasks.""" - preprocess_cls = TextClassificationPreprocess + input_transform_cls = TextClassificationInputTransform output_transform_cls = TextClassificationOutputTransform @property def backbone(self) -> Optional[str]: - return getattr(self.preprocess, "backbone", None) + return getattr(self.input_transform, "backbone", None) @classmethod def from_data_frame( @@ -344,12 +345,12 @@ def from_data_frame( test_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None, predict_transform: Optional[Dict[str, Callable]] = None, data_fetcher: Optional[BaseDataFetcher] = None, - preprocess: Optional[Preprocess] = None, + input_transform: Optional[InputTransform] = None, val_split: Optional[float] = None, batch_size: int = 4, num_workers: int = 0, sampler: Optional[Type[Sampler]] = None, - **preprocess_kwargs: Any, + **input_transform_kwargs: Any, ) -> "DataModule": """Creates a :class:`~flash.text.classification.data.TextClassificationData` object from the given pandas ``DataFrame`` objects. @@ -362,24 +363,24 @@ def from_data_frame( test_data_frame: The pandas ``DataFrame`` containing the testing data. predict_data_frame: The pandas ``DataFrame`` containing the data to use when predicting. train_transform: The dictionary of transforms to use during training which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. val_transform: The dictionary of transforms to use during validation which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. test_transform: The dictionary of transforms to use during testing which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. predict_transform: The dictionary of transforms to use during predicting which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to pass to the :class:`~flash.core.data.data_module.DataModule`. - preprocess: The :class:`~flash.core.data.data.Preprocess` to pass to the - :class:`~flash.core.data.data_module.DataModule`. 
If ``None``, ``cls.preprocess_cls`` + input_transform: The :class:`~flash.core.data.data.InputTransform` to pass to the + :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.input_transform_cls`` will be constructed and used. val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. sampler: The ``sampler`` to use for the ``train_dataloader``. - preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used - if ``preprocess = None``. + input_transform_kwargs: Additional keyword arguments to use when constructing the input_transform. + Will only be used if ``input_transform = None``. Returns: The constructed data module. @@ -395,12 +396,12 @@ def from_data_frame( test_transform=test_transform, predict_transform=predict_transform, data_fetcher=data_fetcher, - preprocess=preprocess, + input_transform=input_transform, val_split=val_split, batch_size=batch_size, num_workers=num_workers, sampler=sampler, - **preprocess_kwargs, + **input_transform_kwargs, ) @classmethod @@ -418,12 +419,12 @@ def from_lists( test_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None, predict_transform: Optional[Dict[str, Callable]] = None, data_fetcher: Optional[BaseDataFetcher] = None, - preprocess: Optional[Preprocess] = None, + input_transform: Optional[InputTransform] = None, val_split: Optional[float] = None, batch_size: int = 4, num_workers: int = 0, sampler: Optional[Type[Sampler]] = None, - **preprocess_kwargs: Any, + **input_transform_kwargs: Any, ) -> "DataModule": """Creates a :class:`~flash.text.classification.data.TextClassificationData` object from the given Python lists. @@ -440,24 +441,24 @@ def from_lists( should be provided as a list of lists, where each inner list contains the targets for a sample. predict_data: A list of sentences to use when predicting. train_transform: The dictionary of transforms to use during training which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. val_transform: The dictionary of transforms to use during validation which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. test_transform: The dictionary of transforms to use during testing which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. predict_transform: The dictionary of transforms to use during predicting which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to pass to the :class:`~flash.core.data.data_module.DataModule`. - preprocess: The :class:`~flash.core.data.data.Preprocess` to pass to the - :class:`~flash.core.data.data_module.DataModule`. 
If ``None``, ``cls.preprocess_cls`` + input_transform: The :class:`~flash.core.data.data.InputTransform` to pass to the + :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.input_transform_cls`` will be constructed and used. val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. sampler: The ``sampler`` to use for the ``train_dataloader``. - preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used - if ``preprocess = None``. + input_transform_kwargs: Additional keyword arguments to use when constructing the input_transform. + Will only be used if ``input_transform = None``. Returns: The constructed data module. @@ -473,12 +474,12 @@ def from_lists( test_transform=test_transform, predict_transform=predict_transform, data_fetcher=data_fetcher, - preprocess=preprocess, + input_transform=input_transform, val_split=val_split, batch_size=batch_size, num_workers=num_workers, sampler=sampler, - **preprocess_kwargs, + **input_transform_kwargs, ) @classmethod @@ -495,17 +496,17 @@ def from_parquet( test_transform: Optional[Dict[str, Callable]] = None, predict_transform: Optional[Dict[str, Callable]] = None, data_fetcher: Optional[BaseDataFetcher] = None, - preprocess: Optional[Preprocess] = None, + input_transform: Optional[InputTransform] = None, val_split: Optional[float] = None, batch_size: int = 4, num_workers: int = 0, sampler: Optional[Type[Sampler]] = None, - **preprocess_kwargs: Any, + **input_transform_kwargs: Any, ) -> "DataModule": """Creates a :class:`~flash.core.data.data_module.DataModule` object from the given PARQUET files using the :class:`~flash.core.data.data_source.DataSource` of name :attr:`~flash.core.data.data_source.DefaultDataSources.PARQUET` - from the passed or constructed :class:`~flash.core.data.process.Preprocess`. + from the passed or constructed :class:`~flash.core.data.io.input_transform.InputTransform`. Args: input_fields: The field or fields (columns) in the PARQUET file to use for the input. @@ -515,24 +516,24 @@ def from_parquet( test_file: The PARQUET file containing the testing data. predict_file: The PARQUET file containing the data to use when predicting. train_transform: The dictionary of transforms to use during training which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. val_transform: The dictionary of transforms to use during validation which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. test_transform: The dictionary of transforms to use during testing which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. predict_transform: The dictionary of transforms to use during predicting which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. 
data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to pass to the :class:`~flash.core.data.data_module.DataModule`. - preprocess: The :class:`~flash.core.data.data.Preprocess` to pass to the - :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls`` + input_transform: The :class:`~flash.core.data.data.InputTransform` to pass to the + :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.input_transform_cls`` will be constructed and used. val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. sampler: The ``sampler`` to use for the ``train_dataloader``. - preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used - if ``preprocess = None``. + input_transform_kwargs: Additional keyword arguments to use when constructing the input_transform. + Will only be used if ``input_transform = None``. Returns: The constructed data module. @@ -559,12 +560,12 @@ def from_parquet( test_transform=test_transform, predict_transform=predict_transform, data_fetcher=data_fetcher, - preprocess=preprocess, + input_transform=input_transform, val_split=val_split, batch_size=batch_size, num_workers=num_workers, sampler=sampler, - **preprocess_kwargs, + **input_transform_kwargs, ) @classmethod @@ -581,12 +582,12 @@ def from_hf_datasets( test_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None, predict_transform: Optional[Dict[str, Callable]] = None, data_fetcher: Optional[BaseDataFetcher] = None, - preprocess: Optional[Preprocess] = None, + input_transform: Optional[InputTransform] = None, val_split: Optional[float] = None, batch_size: int = 4, num_workers: int = 0, sampler: Optional[Type[Sampler]] = None, - **preprocess_kwargs: Any, + **input_transform_kwargs: Any, ) -> "DataModule": """Creates a :class:`~flash.text.classification.data.TextClassificationData` object from the given Hugging Face datasets ``Dataset`` objects. @@ -599,24 +600,24 @@ def from_hf_datasets( test_hf_dataset: The pandas ``Dataset`` containing the testing data. predict_hf_dataset: The pandas ``Dataset`` containing the data to use when predicting. train_transform: The dictionary of transforms to use during training which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. val_transform: The dictionary of transforms to use during validation which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. test_transform: The dictionary of transforms to use during testing which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. predict_transform: The dictionary of transforms to use during predicting which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. 
data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to pass to the :class:`~flash.core.data.data_module.DataModule`. - preprocess: The :class:`~flash.core.data.data.Preprocess` to pass to the - :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls`` + input_transform: The :class:`~flash.core.data.data.InputTransform` to pass to the + :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.input_transform_cls`` will be constructed and used. val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. sampler: The ``sampler`` to use for the ``train_dataloader``. - preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used - if ``preprocess = None``. + input_transform_kwargs: Additional keyword arguments to use when constructing the input_transform. + Will only be used if ``input_transform = None``. Returns: The constructed data module. @@ -632,10 +633,10 @@ def from_hf_datasets( test_transform=test_transform, predict_transform=predict_transform, data_fetcher=data_fetcher, - preprocess=preprocess, + input_transform=input_transform, val_split=val_split, batch_size=batch_size, num_workers=num_workers, sampler=sampler, - **preprocess_kwargs, + **input_transform_kwargs, ) diff --git a/flash/text/question_answering/cli.py b/flash/text/question_answering/cli.py index 471cf13eca..5bc707b1b5 100644 --- a/flash/text/question_answering/cli.py +++ b/flash/text/question_answering/cli.py @@ -23,7 +23,7 @@ def from_squad( backbone: str = "distilbert-base-uncased", batch_size: int = 4, num_workers: int = 0, - **preprocess_kwargs, + **input_transform_kwargs, ) -> QuestionAnsweringData: """Downloads and loads a tiny subset of the squad V2 data set.""" download_data("https://pl-flash-data.s3.amazonaws.com/squad_tiny.zip", "./data/") @@ -34,7 +34,7 @@ def from_squad( backbone=backbone, batch_size=batch_size, num_workers=num_workers, - **preprocess_kwargs, + **input_transform_kwargs, ) diff --git a/flash/text/question_answering/data.py b/flash/text/question_answering/data.py index c0caf41efc..00ea22a992 100644 --- a/flash/text/question_answering/data.py +++ b/flash/text/question_answering/data.py @@ -30,8 +30,8 @@ from flash.core.data.callback import BaseDataFetcher from flash.core.data.data_module import DataModule from flash.core.data.data_source import DataSource, DefaultDataKeys, DefaultDataSources +from flash.core.data.io.input_transform import InputTransform from flash.core.data.io.output_transform import OutputTransform -from flash.core.data.process import Preprocess from flash.core.data.properties import ProcessState from flash.core.utilities.imports import _TEXT_AVAILABLE, requires from flash.core.utilities.stages import RunningStage @@ -87,7 +87,7 @@ def _tokenize_fn(self, samples: Any) -> Callable: ) if stage == RunningStage.TRAINING: - # Preprocess function for training + # InputTransform function for training tokenized_samples, _, _ = self._prepare_train_features(samples, tokenized_samples) elif self._running_stage.evaluating or stage == RunningStage.PREDICTING: if self._running_stage.evaluating: @@ -98,7 +98,7 @@ def _tokenize_fn(self, samples: Any) -> Callable: tokenized_samples["overflow_to_sample_mapping"] = _sample_mapping 
tokenized_samples["offset_mapping"] = _offset_mapping - # Preprocess function for eval or predict + # InputTransform function for eval or predict tokenized_samples = self._prepare_val_features(samples, tokenized_samples) offset_mappings = tokenized_samples.pop("offset_mapping") @@ -484,13 +484,13 @@ def load_data(self, data: str, dataset: Optional[Any] = None) -> "datasets.Datas @dataclass(unsafe_hash=True, frozen=True) class QuestionAnsweringBackboneState(ProcessState): """The ``QuestionAnsweringBackboneState`` stores the backbone in use by the - :class:`~flash.text.question_answering.data.QuestionAnsweringPreprocess` + :class:`~flash.text.question_answering.data.QuestionAnsweringInputTransform` """ backbone: str -class QuestionAnsweringPreprocess(Preprocess): +class QuestionAnsweringInputTransform(InputTransform): @requires("text") def __init__( self, @@ -638,7 +638,7 @@ def __setstate__(self, state): class QuestionAnsweringData(DataModule): """Data module for QuestionAnswering task.""" - preprocess_cls = QuestionAnsweringPreprocess + input_transform_cls = QuestionAnsweringInputTransform output_transform_cls = QuestionAnsweringOutputTransform @classmethod @@ -651,11 +651,11 @@ def from_squad_v2( val_transform: Optional[Dict[str, Callable]] = None, test_transform: Optional[Dict[str, Callable]] = None, data_fetcher: Optional[BaseDataFetcher] = None, - preprocess: Optional[Preprocess] = None, + input_transform: Optional[InputTransform] = None, val_split: Optional[float] = None, batch_size: int = 4, num_workers: int = 0, - **preprocess_kwargs: Any, + **input_transform_kwargs: Any, ): """Creates a :class:`~flash.text.question_answering.data.QuestionAnsweringData` object from the given data JSON files in the SQuAD2.0 format. @@ -665,21 +665,21 @@ def from_squad_v2( val_file: The JSON file containing the validation data. test_file: The JSON file containing the testing data. train_transform: The dictionary of transforms to use during training which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. val_transform: The dictionary of transforms to use during validation which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. test_transform: The dictionary of transforms to use during testing which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to pass to the :class:`~flash.core.data.data_module.DataModule`. - preprocess: The :class:`~flash.core.data.data.Preprocess` to pass to the - :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls`` + input_transform: The :class:`~flash.core.data.data.InputTransform` to pass to the + :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.input_transform_cls`` will be constructed and used. val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. 
- preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used - if ``preprocess = None``. + input_transform_kwargs: Additional keyword arguments to use when constructing the input_transform. + Will only be used if ``input_transform = None``. Returns: The constructed data module. @@ -700,11 +700,11 @@ def from_squad_v2( val_transform=val_transform, test_transform=test_transform, data_fetcher=data_fetcher, - preprocess=preprocess, + input_transform=input_transform, val_split=val_split, batch_size=batch_size, num_workers=num_workers, - **preprocess_kwargs, + **input_transform_kwargs, ) @classmethod @@ -719,18 +719,18 @@ def from_json( test_transform: Optional[Dict[str, Callable]] = None, predict_transform: Optional[Dict[str, Callable]] = None, data_fetcher: Optional[BaseDataFetcher] = None, - preprocess: Optional[Preprocess] = None, + input_transform: Optional[InputTransform] = None, val_split: Optional[float] = None, batch_size: int = 4, num_workers: int = 0, sampler: Optional[Sampler] = None, field: Optional[str] = None, - **preprocess_kwargs: Any, + **input_transform_kwargs: Any, ) -> "DataModule": """Creates a :class:`~flash.text.question_answering.QuestionAnsweringData` object from the given JSON files using the :class:`~flash.text.question_answering.QuestionAnsweringDataSource`of name :attr:`~flash.core.data.data_source.DefaultDataSources.JSON` from the passed or constructed - :class:`~flash.text.question_answering.QuestionAnsweringPreprocess`. + :class:`~flash.text.question_answering.QuestionAnsweringInputTransform`. Args: train_file: The JSON file containing the training data. @@ -738,27 +738,27 @@ def from_json( test_file: The JSON file containing the testing data. predict_file: The JSON file containing the data to use when predicting. train_transform: The dictionary of transforms to use during training which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. val_transform: The dictionary of transforms to use during validation which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. test_transform: The dictionary of transforms to use during testing which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. predict_transform: The dictionary of transforms to use during predicting which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to pass to the :class:`~flash.core.data.data_module.DataModule`. - preprocess: The :class:`~flash.core.data.data.Preprocess` to pass to the - :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls`` + input_transform: The :class:`~flash.core.data.data.InputTransform` to pass to the + :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.input_transform_cls`` will be constructed and used. val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. 
batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. sampler: The ``sampler`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. field: To specify the field that holds the data in the JSON file. - preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used - if ``preprocess = None``. + input_transform_kwargs: Additional keyword arguments to use when constructing the input_transform. + Will only be used if ``input_transform = None``. - .. note:: The following keyword arguments can be passed through to the preprocess_kwargs + .. note:: The following keyword arguments can be passed through to the input_transform_kwargs - backbone: The HF model to be used for the task. - max_source_length: Max length of the sequence to be considered during tokenization. @@ -800,12 +800,12 @@ def from_json( test_transform=test_transform, predict_transform=predict_transform, data_fetcher=data_fetcher, - preprocess=preprocess, + input_transform=input_transform, val_split=val_split, batch_size=batch_size, num_workers=num_workers, sampler=sampler, - **preprocess_kwargs, + **input_transform_kwargs, ) @classmethod @@ -820,17 +820,17 @@ def from_csv( test_transform: Optional[Dict[str, Callable]] = None, predict_transform: Optional[Dict[str, Callable]] = None, data_fetcher: Optional[BaseDataFetcher] = None, - preprocess: Optional[Preprocess] = None, + input_transform: Optional[InputTransform] = None, val_split: Optional[float] = None, batch_size: int = 4, num_workers: int = 0, sampler: Optional[Sampler] = None, - **preprocess_kwargs: Any, + **input_transform_kwargs: Any, ) -> "DataModule": """Creates a :class:`~flash.core.data.data_module.DataModule` object from the given CSV files using the :class:`~flash.core.data.data_source.DataSource` of name :attr:`~flash.core.data.data_source.DefaultDataSources.CSV` - from the passed or constructed :class:`~flash.core.data.process.Preprocess`. + from the passed or constructed :class:`~flash.core.data.io.input_transform.InputTransform`. Args: input_fields: The field or fields (columns) in the CSV file to use for the input. @@ -840,26 +840,26 @@ def from_csv( test_file: The CSV file containing the testing data. predict_file: The CSV file containing the data to use when predicting. train_transform: The dictionary of transforms to use during training which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. val_transform: The dictionary of transforms to use during validation which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. test_transform: The dictionary of transforms to use during testing which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. predict_transform: The dictionary of transforms to use during predicting which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. 
data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to pass to the :class:`~flash.core.data.data_module.DataModule`. - preprocess: The :class:`~flash.core.data.data.Preprocess` to pass to the - :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls`` + input_transform: The :class:`~flash.core.data.data.InputTransform` to pass to the + :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.input_transform_cls`` will be constructed and used. val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. sampler: The ``sampler`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. - preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used - if ``preprocess = None``. + input_transform_kwargs: Additional keyword arguments to use when constructing the input_transform. + Will only be used if ``input_transform = None``. - .. note:: The following keyword arguments can be passed through to the preprocess_kwargs + .. note:: The following keyword arguments can be passed through to the input_transform_kwargs - backbone: The HF model to be used for the task. - max_source_length: Max length of the sequence to be considered during tokenization. @@ -903,10 +903,10 @@ def from_csv( test_transform=test_transform, predict_transform=predict_transform, data_fetcher=data_fetcher, - preprocess=preprocess, + input_transform=input_transform, val_split=val_split, batch_size=batch_size, num_workers=num_workers, sampler=sampler, - **preprocess_kwargs, + **input_transform_kwargs, ) diff --git a/flash/text/seq2seq/core/data.py b/flash/text/seq2seq/core/data.py index d523af6692..2afbea8790 100644 --- a/flash/text/seq2seq/core/data.py +++ b/flash/text/seq2seq/core/data.py @@ -21,8 +21,8 @@ import flash from flash.core.data.data_module import DataModule from flash.core.data.data_source import DataSource, DefaultDataSources +from flash.core.data.io.input_transform import InputTransform from flash.core.data.io.output_transform import OutputTransform -from flash.core.data.process import Preprocess from flash.core.data.properties import ProcessState from flash.core.utilities.imports import _TEXT_AVAILABLE, requires from flash.text.classification.data import TextDeserializer @@ -242,14 +242,14 @@ def __setstate__(self, state): @dataclass(unsafe_hash=True, frozen=True) class Seq2SeqBackboneState(ProcessState): """The ``Seq2SeqBackboneState`` stores the backbone in use by the - :class:`~flash.text.seq2seq.core.data.Seq2SeqPreprocess` + :class:`~flash.text.seq2seq.core.data.Seq2SeqInputTransform` """ backbone: str backbone_kwargs: Dict[str, Any] = field(default_factory=dict) -class Seq2SeqPreprocess(Preprocess): +class Seq2SeqInputTransform(InputTransform): @requires("text") def __init__( self, @@ -360,5 +360,5 @@ def __setstate__(self, state): class Seq2SeqData(DataModule): """Data module for Seq2Seq tasks.""" - preprocess_cls = Seq2SeqPreprocess + input_transform_cls = Seq2SeqInputTransform output_transform_cls = Seq2SeqOutputTransform diff --git a/flash/text/seq2seq/summarization/cli.py b/flash/text/seq2seq/summarization/cli.py index 25003cb58b..12bbc8b967 100644 --- a/flash/text/seq2seq/summarization/cli.py +++ 
b/flash/text/seq2seq/summarization/cli.py @@ -23,7 +23,7 @@ def from_xsum( backbone: str = "sshleifer/distilbart-xsum-1-1", batch_size: int = 4, num_workers: int = 0, - **preprocess_kwargs, + **input_transform_kwargs, ) -> SummarizationData: """Downloads and loads the XSum data set.""" download_data("https://pl-flash-data.s3.amazonaws.com/xsum.zip", "./data/") @@ -35,7 +35,7 @@ def from_xsum( backbone=backbone, batch_size=batch_size, num_workers=num_workers, - **preprocess_kwargs, + **input_transform_kwargs, ) diff --git a/flash/text/seq2seq/summarization/data.py b/flash/text/seq2seq/summarization/data.py index cd99caa490..0e13cc9aca 100644 --- a/flash/text/seq2seq/summarization/data.py +++ b/flash/text/seq2seq/summarization/data.py @@ -13,10 +13,10 @@ # limitations under the License. from typing import Callable, Dict, Optional, Union -from flash.text.seq2seq.core.data import Seq2SeqData, Seq2SeqOutputTransform, Seq2SeqPreprocess +from flash.text.seq2seq.core.data import Seq2SeqData, Seq2SeqInputTransform, Seq2SeqOutputTransform -class SummarizationPreprocess(Seq2SeqPreprocess): +class SummarizationInputTransform(Seq2SeqInputTransform): def __init__( self, train_transform: Optional[Dict[str, Callable]] = None, @@ -44,5 +44,5 @@ def __init__( class SummarizationData(Seq2SeqData): - preprocess_cls = SummarizationPreprocess + input_transform_cls = SummarizationInputTransform output_transform_cls = Seq2SeqOutputTransform diff --git a/flash/text/seq2seq/translation/cli.py b/flash/text/seq2seq/translation/cli.py index 66ec698791..2e310fc2aa 100644 --- a/flash/text/seq2seq/translation/cli.py +++ b/flash/text/seq2seq/translation/cli.py @@ -23,7 +23,7 @@ def from_wmt_en_ro( backbone: str = "Helsinki-NLP/opus-mt-en-ro", batch_size: int = 4, num_workers: int = 0, - **preprocess_kwargs, + **input_transform_kwargs, ) -> TranslationData: """Downloads and loads the WMT EN RO data set.""" download_data("https://pl-flash-data.s3.amazonaws.com/wmt_en_ro.zip", "./data") @@ -35,7 +35,7 @@ def from_wmt_en_ro( backbone=backbone, batch_size=batch_size, num_workers=num_workers, - **preprocess_kwargs, + **input_transform_kwargs, ) diff --git a/flash/text/seq2seq/translation/data.py b/flash/text/seq2seq/translation/data.py index 89a712d492..e7403302b8 100644 --- a/flash/text/seq2seq/translation/data.py +++ b/flash/text/seq2seq/translation/data.py @@ -13,10 +13,10 @@ # limitations under the License. 
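Outside the diff itself, a hedged sketch of how the renamed keyword plumbing above is meant to flow end to end. The positional column names and file path are illustrative; `backbone` and `max_source_length` are the pass-through keywords named in the docstring note above, and they only reach the transform constructor when no explicit `input_transform` is given.

```python
from flash.text import SummarizationData

# Keywords the DataModule does not consume (here: backbone, max_source_length) are
# forwarded through **input_transform_kwargs and used to build the
# SummarizationInputTransform; this only happens because no explicit input_transform is passed.
datamodule = SummarizationData.from_csv(
    "input",                            # illustrative input column
    "target",                           # illustrative target column
    train_file="data/xsum/train.csv",   # illustrative path
    backbone="sshleifer/distilbart-xsum-1-1",
    max_source_length=128,
    batch_size=4,
    num_workers=0,
)
```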
from typing import Callable, Dict, Optional, Union -from flash.text.seq2seq.core.data import Seq2SeqData, Seq2SeqOutputTransform, Seq2SeqPreprocess +from flash.text.seq2seq.core.data import Seq2SeqData, Seq2SeqInputTransform, Seq2SeqOutputTransform -class TranslationPreprocess(Seq2SeqPreprocess): +class TranslationInputTransform(Seq2SeqInputTransform): def __init__( self, train_transform: Optional[Dict[str, Callable]] = None, @@ -45,5 +45,5 @@ def __init__( class TranslationData(Seq2SeqData): """Data module for Translation tasks.""" - preprocess_cls = TranslationPreprocess + input_transform_cls = TranslationInputTransform output_transform_cls = Seq2SeqOutputTransform diff --git a/flash/video/classification/cli.py b/flash/video/classification/cli.py index 3053d0c1ca..395de259c5 100644 --- a/flash/video/classification/cli.py +++ b/flash/video/classification/cli.py @@ -26,7 +26,7 @@ def from_kinetics( decode_audio: bool = False, batch_size: int = 4, num_workers: int = 0, - **preprocess_kwargs, + **input_transform_kwargs, ) -> VideoClassificationData: """Downloads and loads the Kinetics data set.""" download_data("https://pl-flash-data.s3.amazonaws.com/kinetics.zip", "./data") @@ -38,7 +38,7 @@ def from_kinetics( decode_audio=decode_audio, batch_size=batch_size, num_workers=num_workers, - **preprocess_kwargs, + **input_transform_kwargs, ) diff --git a/flash/video/classification/data.py b/flash/video/classification/data.py index b80058549d..2d72b3ad45 100644 --- a/flash/video/classification/data.py +++ b/flash/video/classification/data.py @@ -27,7 +27,7 @@ LabelsState, PathsDataSource, ) -from flash.core.data.process import Preprocess +from flash.core.data.io.input_transform import InputTransform from flash.core.integrations.labelstudio.data_source import LabelStudioVideoClassificationDataSource from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _KORNIA_AVAILABLE, _PYTORCHVIDEO_AVAILABLE, lazy_import @@ -269,7 +269,7 @@ def _make_encoded_video_dataset(self, data: SampleCollection) -> "LabeledVideoDa return ds -class VideoClassificationPreprocess(Preprocess): +class VideoClassificationInputTransform(InputTransform): def __init__( self, train_transform: Optional[Dict[str, Callable]] = None, @@ -352,7 +352,7 @@ def get_state_dict(self) -> Dict[str, Any]: } @classmethod - def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool) -> "VideoClassificationPreprocess": + def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool) -> "VideoClassificationInputTransform": return cls(**state_dict) def default_transforms(self) -> Dict[str, Callable]: @@ -393,4 +393,4 @@ def default_transforms(self) -> Dict[str, Callable]: class VideoClassificationData(DataModule): """Data module for Video classification tasks.""" - preprocess_cls = VideoClassificationPreprocess + input_transform_cls = VideoClassificationInputTransform diff --git a/flash_examples/flash_components/custom_data_loading.py b/flash_examples/flash_components/custom_data_loading.py index 0c3edf23d3..fed306ee46 100644 --- a/flash_examples/flash_components/custom_data_loading.py +++ b/flash_examples/flash_components/custom_data_loading.py @@ -21,9 +21,9 @@ from pytorch_lightning import seed_everything from torch.utils.data._utils.collate import default_collate -from flash import _PACKAGE_ROOT, FlashDataset, InputTransform +from flash import _PACKAGE_ROOT, FlashDataset from flash.core.data.data_source import DefaultDataKeys -from flash.core.data.input_transform import INPUT_TRANSFORM_TYPE +from flash.core.data.input_transform 
import INPUT_TRANSFORM_TYPE, InputTransform from flash.core.data.new_data_module import DataModule from flash.core.data.transforms import ApplyToKeys from flash.core.data.utils import download_data diff --git a/tests/audio/speech_recognition/test_model.py b/tests/audio/speech_recognition/test_model.py index e801c02224..6fe31b83c9 100644 --- a/tests/audio/speech_recognition/test_model.py +++ b/tests/audio/speech_recognition/test_model.py @@ -22,7 +22,7 @@ from flash import Trainer from flash.__main__ import main from flash.audio import SpeechRecognition -from flash.audio.speech_recognition.data import SpeechRecognitionOutputTransform, SpeechRecognitionPreprocess +from flash.audio.speech_recognition.data import SpeechRecognitionInputTransform, SpeechRecognitionOutputTransform from flash.core.data.data_source import DefaultDataKeys from flash.core.utilities.imports import _AUDIO_AVAILABLE from tests.helpers.utils import _AUDIO_TESTING, _SERVE_TESTING @@ -79,8 +79,9 @@ def test_jit(tmpdir): @mock.patch("flash._IS_TESTING", True) def test_serve(): model = SpeechRecognition(backbone=TEST_BACKBONE) - # TODO: Currently only servable once a preprocess and output_transform have been attached - model._preprocess = SpeechRecognitionPreprocess() + + # TODO: Currently only servable once an input_transform and output_transform have been attached + model._input_transform = SpeechRecognitionInputTransform() model._output_transform = SpeechRecognitionOutputTransform() model.eval() model.serve() diff --git a/tests/core/data/io/test_input_transform.py b/tests/core/data/io/test_input_transform.py new file mode 100644 index 0000000000..92513bfaac --- /dev/null +++ b/tests/core/data/io/test_input_transform.py @@ -0,0 +1,111 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
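Stepping back from the serve test updated above, a hedged sketch of the wiring it exercises after the rename. The backbone string is illustrative and not taken from this diff; the `_input_transform` / `_output_transform` attribute names are.

```python
from flash.audio import SpeechRecognition
from flash.audio.speech_recognition.data import (
    SpeechRecognitionInputTransform,
    SpeechRecognitionOutputTransform,
)

# Per the TODO above, a task is only servable once both renamed transforms are attached.
model = SpeechRecognition(backbone="facebook/wav2vec2-base-960h")  # illustrative backbone
model._input_transform = SpeechRecognitionInputTransform()
model._output_transform = SpeechRecognitionOutputTransform()
model.eval()
model.serve()
```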
+from unittest.mock import Mock + +import pytest +import torch +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from torch.utils.data._utils.collate import default_collate + +from flash import DataModule +from flash.core.data.data_source import DefaultDataSources +from flash.core.data.io.input_transform import ( + _InputTransformProcessor, + _InputTransformSequential, + DefaultInputTransform, +) +from flash.core.utilities.stages import RunningStage + + +class CustomInputTransform(DefaultInputTransform): + def __init__(self): + super().__init__( + data_sources={ + "test": Mock(return_value="test"), + DefaultDataSources.TENSORS: Mock(return_value="tensors"), + }, + default_data_source="test", + ) + + +def test_input_transform_processor_str(): + input_transform_processor = _InputTransformProcessor( + Mock(name="input_transform"), + default_collate, + torch.relu, + torch.softmax, + RunningStage.TRAINING, + False, + True, + ) + assert str(input_transform_processor) == ( + "_InputTransformProcessor:\n" + "\t(per_sample_transform): FuncModule(relu)\n" + "\t(collate_fn): FuncModule(default_collate)\n" + "\t(per_batch_transform): FuncModule(softmax)\n" + "\t(apply_per_sample_transform): False\n" + "\t(on_device): True\n" + "\t(stage): RunningStage.TRAINING" + ) + + +def test_sequential_str(): + sequential = _InputTransformSequential( + Mock(name="input_transform"), + torch.softmax, + torch.as_tensor, + torch.relu, + RunningStage.TRAINING, + True, + ) + assert str(sequential) == ( + "_InputTransformSequential:\n" + "\t(pre_tensor_transform): FuncModule(softmax)\n" + "\t(to_tensor_transform): FuncModule(as_tensor)\n" + "\t(post_tensor_transform): FuncModule(relu)\n" + "\t(assert_contains_tensor): True\n" + "\t(stage): RunningStage.TRAINING" + ) + + +def test_data_source_of_name(): + input_transform = CustomInputTransform() + + assert input_transform.data_source_of_name("test")() == "test" + assert input_transform.data_source_of_name(DefaultDataSources.TENSORS)() == "tensors" + assert input_transform.data_source_of_name("tensors")() == "tensors" + assert input_transform.data_source_of_name("default")() == "test" + + with pytest.raises(MisconfigurationException, match="available data sources are: test, tensor"): + input_transform.data_source_of_name("not available") + + +def test_available_data_sources(): + input_transform = CustomInputTransform() + + assert DefaultDataSources.TENSORS in input_transform.available_data_sources() + assert "test" in input_transform.available_data_sources() + assert len(input_transform.available_data_sources()) == 3 + + data_module = DataModule(input_transform=input_transform) + + assert DefaultDataSources.TENSORS in data_module.available_data_sources() + assert "test" in data_module.available_data_sources() + assert len(data_module.available_data_sources()) == 3 + + +def test_check_transforms(): + transform = torch.nn.Identity() + DefaultInputTransform(train_transform=transform) + DefaultInputTransform(train_transform=[transform]) diff --git a/tests/core/data/io/test_output.py b/tests/core/data/io/test_output.py index e875c2f86b..a0890bfb04 100644 --- a/tests/core/data/io/test_output.py +++ b/tests/core/data/io/test_output.py @@ -20,8 +20,8 @@ from flash.core.classification import Labels from flash.core.data.data_pipeline import DataPipeline, DataPipelineState from flash.core.data.data_source import LabelsState +from flash.core.data.io.input_transform import DefaultInputTransform from flash.core.data.io.output import Output -from 
flash.core.data.process import DefaultPreprocess from flash.core.model import Task from flash.core.trainer import Trainer @@ -47,10 +47,10 @@ def __init__(self): output = Labels(["a", "b"]) model = CustomModel() trainer = Trainer(fast_dev_run=True) - data_pipeline = DataPipeline(preprocess=DefaultPreprocess(), output=output) + data_pipeline = DataPipeline(input_transform=DefaultInputTransform(), output=output) data_pipeline.initialize() model.data_pipeline = data_pipeline - assert isinstance(model.preprocess, DefaultPreprocess) + assert isinstance(model.input_transform, DefaultInputTransform) dummy_data = DataLoader(list(zip(torch.arange(10, dtype=torch.float), torch.arange(10, dtype=torch.float)))) trainer.fit(model, train_dataloader=dummy_data) trainer.save_checkpoint(checkpoint_file) diff --git a/tests/core/data/test_auto_dataset.py b/tests/core/data/test_auto_dataset.py index 3543ec83bc..7c65a160b5 100644 --- a/tests/core/data/test_auto_dataset.py +++ b/tests/core/data/test_auto_dataset.py @@ -168,7 +168,7 @@ def test_iterable_autodataset_smoke(): False, ], ) -def test_preprocessing_data_source_with_running_stage(with_dataset): +def test_input_transforming_data_source_with_running_stage(with_dataset): data_source = _AutoDatasetTestDataSource(with_dataset) running_stage = RunningStage.TRAINING diff --git a/tests/core/data/test_batch.py b/tests/core/data/test_batch.py index 958f31bf85..102fb7196f 100644 --- a/tests/core/data/test_batch.py +++ b/tests/core/data/test_batch.py @@ -12,54 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections import namedtuple -from unittest.mock import Mock import torch from torch.testing import assert_allclose -from torch.utils.data._utils.collate import default_collate -from flash.core.data.batch import _Preprocessor, _Sequential, default_uncollate -from flash.core.utilities.stages import RunningStage - - -def test_sequential_str(): - sequential = _Sequential( - Mock(name="preprocess"), - torch.softmax, - torch.as_tensor, - torch.relu, - RunningStage.TRAINING, - True, - ) - assert str(sequential) == ( - "_Sequential:\n" - "\t(pre_tensor_transform): FuncModule(softmax)\n" - "\t(to_tensor_transform): FuncModule(as_tensor)\n" - "\t(post_tensor_transform): FuncModule(relu)\n" - "\t(assert_contains_tensor): True\n" - "\t(stage): RunningStage.TRAINING" - ) - - -def test_preprocessor_str(): - preprocessor = _Preprocessor( - Mock(name="preprocess"), - default_collate, - torch.relu, - torch.softmax, - RunningStage.TRAINING, - False, - True, - ) - assert str(preprocessor) == ( - "_Preprocessor:\n" - "\t(per_sample_transform): FuncModule(relu)\n" - "\t(collate_fn): FuncModule(default_collate)\n" - "\t(per_batch_transform): FuncModule(softmax)\n" - "\t(apply_per_sample_transform): False\n" - "\t(on_device): True\n" - "\t(stage): RunningStage.TRAINING" - ) +from flash.core.data.batch import default_uncollate class TestDefaultUncollate: diff --git a/tests/core/data/test_callback.py b/tests/core/data/test_callback.py index 7cf49d395b..577c84ecd5 100644 --- a/tests/core/data/test_callback.py +++ b/tests/core/data/test_callback.py @@ -17,7 +17,7 @@ import torch from flash.core.data.data_module import DataModule -from flash.core.data.process import DefaultPreprocess +from flash.core.data.io.input_transform import DefaultInputTransform from flash.core.model import Task from flash.core.trainer import Trainer from flash.core.utilities.stages import RunningStage @@ -32,9 +32,9 @@ def 
test_flash_callback(_, __, tmpdir): inputs = [[torch.rand(1), torch.rand(1)]] dm = DataModule.from_data_source( - "default", inputs, inputs, inputs, None, preprocess=DefaultPreprocess(), batch_size=1, num_workers=0 + "default", inputs, inputs, inputs, None, input_transform=DefaultInputTransform(), batch_size=1, num_workers=0 ) - dm.preprocess.callbacks += [callback_mock] + dm.input_transform.callbacks += [callback_mock] _ = next(iter(dm.train_dataloader())) @@ -59,9 +59,9 @@ def __init__(self): progress_bar_refresh_rate=0, ) dm = DataModule.from_data_source( - "default", inputs, inputs, inputs, None, preprocess=DefaultPreprocess(), batch_size=1, num_workers=0 + "default", inputs, inputs, inputs, None, input_transform=DefaultInputTransform(), batch_size=1, num_workers=0 ) - dm.preprocess.callbacks += [callback_mock] + dm.input_transform.callbacks += [callback_mock] trainer.fit(CustomModel(), datamodule=dm) assert callback_mock.method_calls == [ diff --git a/tests/core/data/test_callbacks.py b/tests/core/data/test_callbacks.py index 271eb595a9..d61a591c94 100644 --- a/tests/core/data/test_callbacks.py +++ b/tests/core/data/test_callbacks.py @@ -18,7 +18,7 @@ from flash.core.data.callback import BaseDataFetcher from flash.core.data.data_module import DataModule -from flash.core.data.process import DefaultPreprocess +from flash.core.data.io.input_transform import DefaultInputTransform from flash.core.utilities.stages import RunningStage @@ -43,7 +43,7 @@ def configure_data_fetcher(): @classmethod def from_inputs(cls, train_data: Any, val_data: Any, test_data: Any, predict_data: Any) -> "CustomDataModule": - preprocess = DefaultPreprocess() + input_transform = DefaultInputTransform() return cls.from_data_source( "default", @@ -51,7 +51,7 @@ def from_inputs(cls, train_data: Any, val_data: Any, test_data: Any, predict_dat val_data=val_data, test_data=test_data, predict_data=predict_data, - preprocess=preprocess, + input_transform=input_transform, batch_size=5, ) diff --git a/tests/core/data/test_data_pipeline.py b/tests/core/data/test_data_pipeline.py index 51c0279661..268850729c 100644 --- a/tests/core/data/test_data_pipeline.py +++ b/tests/core/data/test_data_pipeline.py @@ -25,13 +25,13 @@ from flash import Trainer from flash.core.data.auto_dataset import IterableAutoDataset -from flash.core.data.batch import _Preprocessor from flash.core.data.data_module import DataModule from flash.core.data.data_pipeline import _StageOrchestrator, DataPipeline, DataPipelineState from flash.core.data.data_source import DataSource +from flash.core.data.io.input_transform import _InputTransformProcessor, DefaultInputTransform, InputTransform from flash.core.data.io.output import Output from flash.core.data.io.output_transform import _OutputTransformProcessor, OutputTransform -from flash.core.data.process import DefaultPreprocess, Deserializer, Preprocess +from flash.core.data.process import Deserializer from flash.core.data.properties import ProcessState from flash.core.data.states import PerBatchTransformOnDevice, ToTensorTransform from flash.core.model import Task @@ -73,20 +73,20 @@ def test_get_state(): def test_data_pipeline_str(): data_pipeline = DataPipeline( data_source=cast(DataSource, "data_source"), - preprocess=cast(Preprocess, "preprocess"), + input_transform=cast(InputTransform, "input_transform"), output_transform=cast(OutputTransform, "output_transform"), output=cast(Output, "output"), deserializer=cast(Deserializer, "deserializer"), ) expected = "data_source=data_source, 
deserializer=deserializer, " - expected += "preprocess=preprocess, output_transform=output_transform, output=output" + expected += "input_transform=input_transform, output_transform=output_transform, output=output" assert str(data_pipeline) == (f"DataPipeline({expected})") -@pytest.mark.parametrize("use_preprocess", [False, True]) +@pytest.mark.parametrize("use_input_transform", [False, True]) @pytest.mark.parametrize("use_output_transform", [False, True]) -def test_data_pipeline_init_and_assignement(use_preprocess, use_output_transform, tmpdir): +def test_data_pipeline_init_and_assignement(use_input_transform, use_output_transform, tmpdir): class CustomModel(Task): def __init__(self, output_transform: Optional[OutputTransform] = None): super().__init__(model=torch.nn.Linear(1, 1), loss_fn=torch.nn.MSELoss()) @@ -95,17 +95,19 @@ def __init__(self, output_transform: Optional[OutputTransform] = None): def train_dataloader(self) -> Any: return DataLoader(DummyDataset()) - class SubPreprocess(DefaultPreprocess): + class SubInputTransform(DefaultInputTransform): pass class SubOutputTransform(OutputTransform): pass data_pipeline = DataPipeline( - preprocess=SubPreprocess() if use_preprocess else None, + input_transform=SubInputTransform() if use_input_transform else None, output_transform=SubOutputTransform() if use_output_transform else None, ) - assert isinstance(data_pipeline._preprocess_pipeline, SubPreprocess if use_preprocess else DefaultPreprocess) + assert isinstance( + data_pipeline._input_transform_pipeline, SubInputTransform if use_input_transform else DefaultInputTransform + ) assert isinstance(data_pipeline._output_transform, SubOutputTransform if use_output_transform else OutputTransform) model = CustomModel(output_transform=OutputTransform()) @@ -113,10 +115,10 @@ class SubOutputTransform(OutputTransform): # TODO: the line below should make the same effect but it's not # data_pipeline._attach_to_model(model) - if use_preprocess: - assert isinstance(model._preprocess, SubPreprocess) + if use_input_transform: + assert isinstance(model._input_transform, SubInputTransform) else: - assert model._preprocess is None or isinstance(model._preprocess, Preprocess) + assert model._input_transform is None or isinstance(model._input_transform, InputTransform) if use_output_transform: assert isinstance(model._output_transform, SubOutputTransform) @@ -125,7 +127,7 @@ class SubOutputTransform(OutputTransform): def test_data_pipeline_is_overriden_and_resolve_function_hierarchy(tmpdir): - class CustomPreprocess(DefaultPreprocess): + class CustomInputTransform(DefaultInputTransform): def val_pre_tensor_transform(self, *_, **__): pass @@ -147,32 +149,32 @@ def train_per_batch_transform_on_device(self, *_, **__): def test_per_batch_transform_on_device(self, *_, **__): pass - preprocess = CustomPreprocess() - data_pipeline = DataPipeline(preprocess=preprocess) + input_transform = CustomInputTransform() + data_pipeline = DataPipeline(input_transform=input_transform) train_func_names: Dict[str, str] = { k: data_pipeline._resolve_function_hierarchy( - k, data_pipeline._preprocess_pipeline, RunningStage.TRAINING, Preprocess + k, data_pipeline._input_transform_pipeline, RunningStage.TRAINING, InputTransform ) - for k in data_pipeline.PREPROCESS_FUNCS + for k in data_pipeline.INPUT_TRANSFORM_FUNCS } val_func_names: Dict[str, str] = { k: data_pipeline._resolve_function_hierarchy( - k, data_pipeline._preprocess_pipeline, RunningStage.VALIDATING, Preprocess + k, data_pipeline._input_transform_pipeline, 
RunningStage.VALIDATING, InputTransform ) - for k in data_pipeline.PREPROCESS_FUNCS + for k in data_pipeline.INPUT_TRANSFORM_FUNCS } test_func_names: Dict[str, str] = { k: data_pipeline._resolve_function_hierarchy( - k, data_pipeline._preprocess_pipeline, RunningStage.TESTING, Preprocess + k, data_pipeline._input_transform_pipeline, RunningStage.TESTING, InputTransform ) - for k in data_pipeline.PREPROCESS_FUNCS + for k in data_pipeline.INPUT_TRANSFORM_FUNCS } predict_func_names: Dict[str, str] = { k: data_pipeline._resolve_function_hierarchy( - k, data_pipeline._preprocess_pipeline, RunningStage.PREDICTING, Preprocess + k, data_pipeline._input_transform_pipeline, RunningStage.PREDICTING, InputTransform ) - for k in data_pipeline.PREPROCESS_FUNCS + for k in data_pipeline.INPUT_TRANSFORM_FUNCS } # pre_tensor_transform @@ -211,41 +213,41 @@ def test_per_batch_transform_on_device(self, *_, **__): assert test_func_names["per_batch_transform_on_device"] == "test_per_batch_transform_on_device" assert predict_func_names["per_batch_transform_on_device"] == "per_batch_transform_on_device" - train_worker_preprocessor = data_pipeline.worker_preprocessor(RunningStage.TRAINING) - val_worker_preprocessor = data_pipeline.worker_preprocessor(RunningStage.VALIDATING) - test_worker_preprocessor = data_pipeline.worker_preprocessor(RunningStage.TESTING) - predict_worker_preprocessor = data_pipeline.worker_preprocessor(RunningStage.PREDICTING) - - _seq = train_worker_preprocessor.per_sample_transform - assert _seq.pre_tensor_transform.func == preprocess.pre_tensor_transform - assert _seq.to_tensor_transform.func == preprocess.to_tensor_transform - assert _seq.post_tensor_transform.func == preprocess.train_post_tensor_transform - assert train_worker_preprocessor.collate_fn.func == preprocess.collate - assert train_worker_preprocessor.per_batch_transform.func == preprocess.per_batch_transform - - _seq = val_worker_preprocessor.per_sample_transform - assert _seq.pre_tensor_transform.func == preprocess.val_pre_tensor_transform - assert _seq.to_tensor_transform.func == preprocess.to_tensor_transform - assert _seq.post_tensor_transform.func == preprocess.post_tensor_transform - assert val_worker_preprocessor.collate_fn.func == DataPipeline._identity - assert val_worker_preprocessor.per_batch_transform.func == preprocess.per_batch_transform - - _seq = test_worker_preprocessor.per_sample_transform - assert _seq.pre_tensor_transform.func == preprocess.pre_tensor_transform - assert _seq.to_tensor_transform.func == preprocess.to_tensor_transform - assert _seq.post_tensor_transform.func == preprocess.post_tensor_transform - assert test_worker_preprocessor.collate_fn.func == preprocess.test_collate - assert test_worker_preprocessor.per_batch_transform.func == preprocess.per_batch_transform - - _seq = predict_worker_preprocessor.per_sample_transform - assert _seq.pre_tensor_transform.func == preprocess.pre_tensor_transform - assert _seq.to_tensor_transform.func == preprocess.predict_to_tensor_transform - assert _seq.post_tensor_transform.func == preprocess.post_tensor_transform - assert predict_worker_preprocessor.collate_fn.func == preprocess.collate - assert predict_worker_preprocessor.per_batch_transform.func == preprocess.per_batch_transform - - -class CustomPreprocess(DefaultPreprocess): + train_worker_input_transform_processor = data_pipeline.worker_input_transform_processor(RunningStage.TRAINING) + val_worker_input_transform_processor = data_pipeline.worker_input_transform_processor(RunningStage.VALIDATING) + 
test_worker_input_transform_processor = data_pipeline.worker_input_transform_processor(RunningStage.TESTING) + predict_worker_input_transform_processor = data_pipeline.worker_input_transform_processor(RunningStage.PREDICTING) + + _seq = train_worker_input_transform_processor.per_sample_transform + assert _seq.pre_tensor_transform.func == input_transform.pre_tensor_transform + assert _seq.to_tensor_transform.func == input_transform.to_tensor_transform + assert _seq.post_tensor_transform.func == input_transform.train_post_tensor_transform + assert train_worker_input_transform_processor.collate_fn.func == input_transform.collate + assert train_worker_input_transform_processor.per_batch_transform.func == input_transform.per_batch_transform + + _seq = val_worker_input_transform_processor.per_sample_transform + assert _seq.pre_tensor_transform.func == input_transform.val_pre_tensor_transform + assert _seq.to_tensor_transform.func == input_transform.to_tensor_transform + assert _seq.post_tensor_transform.func == input_transform.post_tensor_transform + assert val_worker_input_transform_processor.collate_fn.func == DataPipeline._identity + assert val_worker_input_transform_processor.per_batch_transform.func == input_transform.per_batch_transform + + _seq = test_worker_input_transform_processor.per_sample_transform + assert _seq.pre_tensor_transform.func == input_transform.pre_tensor_transform + assert _seq.to_tensor_transform.func == input_transform.to_tensor_transform + assert _seq.post_tensor_transform.func == input_transform.post_tensor_transform + assert test_worker_input_transform_processor.collate_fn.func == input_transform.test_collate + assert test_worker_input_transform_processor.per_batch_transform.func == input_transform.per_batch_transform + + _seq = predict_worker_input_transform_processor.per_sample_transform + assert _seq.pre_tensor_transform.func == input_transform.pre_tensor_transform + assert _seq.to_tensor_transform.func == input_transform.predict_to_tensor_transform + assert _seq.post_tensor_transform.func == input_transform.post_tensor_transform + assert predict_worker_input_transform_processor.collate_fn.func == input_transform.collate + assert predict_worker_input_transform_processor.per_batch_transform.func == input_transform.per_batch_transform + + +class CustomInputTransform(DefaultInputTransform): def train_per_sample_transform(self, *_, **__): pass @@ -280,20 +282,20 @@ def predict_per_batch_transform_on_device(self, *_, **__): pass -def test_data_pipeline_predict_worker_preprocessor_and_device_preprocessor(): +def test_data_pipeline_predict_worker_input_transform_processor_and_device_input_transform_processor(): - preprocess = CustomPreprocess() - data_pipeline = DataPipeline(preprocess=preprocess) + input_transform = CustomInputTransform() + data_pipeline = DataPipeline(input_transform=input_transform) - data_pipeline.worker_preprocessor(RunningStage.TRAINING) + data_pipeline.worker_input_transform_processor(RunningStage.TRAINING) with pytest.raises(MisconfigurationException, match="are mutually exclusive"): - data_pipeline.worker_preprocessor(RunningStage.VALIDATING) + data_pipeline.worker_input_transform_processor(RunningStage.VALIDATING) with pytest.raises(MisconfigurationException, match="are mutually exclusive"): - data_pipeline.worker_preprocessor(RunningStage.TESTING) - data_pipeline.worker_preprocessor(RunningStage.PREDICTING) + data_pipeline.worker_input_transform_processor(RunningStage.TESTING) + 
data_pipeline.worker_input_transform_processor(RunningStage.PREDICTING) -def test_detach_preprocessing_from_model(tmpdir): +def test_detach_input_transform_from_model(tmpdir): class CustomModel(Task): def __init__(self, output_transform: Optional[OutputTransform] = None): super().__init__(model=torch.nn.Linear(1, 1), loss_fn=torch.nn.MSELoss()) @@ -302,22 +304,22 @@ def __init__(self, output_transform: Optional[OutputTransform] = None): def train_dataloader(self) -> Any: return DataLoader(DummyDataset()) - preprocess = CustomPreprocess() - data_pipeline = DataPipeline(preprocess=preprocess) + input_transform = CustomInputTransform() + data_pipeline = DataPipeline(input_transform=input_transform) model = CustomModel() model.data_pipeline = data_pipeline assert model.train_dataloader().collate_fn == default_collate assert model.transfer_batch_to_device.__self__ == model model.on_train_dataloader() - assert isinstance(model.train_dataloader().collate_fn, _Preprocessor) + assert isinstance(model.train_dataloader().collate_fn, _InputTransformProcessor) assert isinstance(model.transfer_batch_to_device, _StageOrchestrator) model.on_fit_end() assert model.transfer_batch_to_device.__self__ == model assert model.train_dataloader().collate_fn == default_collate -class TestPreprocess(DefaultPreprocess): +class TestInputTransform(DefaultInputTransform): def train_per_sample_transform(self, *_, **__): pass @@ -347,11 +349,11 @@ def predict_per_batch_transform_on_device(self, *_, **__): def test_attaching_datapipeline_to_model(tmpdir): - class SubPreprocess(DefaultPreprocess): + class SubInputTransform(DefaultInputTransform): pass - preprocess = SubPreprocess() - data_pipeline = DataPipeline(preprocess=preprocess) + input_transform = SubInputTransform() + data_pipeline = DataPipeline(input_transform=input_transform) class CustomModel(Task): def __init__(self): @@ -403,7 +405,7 @@ def _compare_pre_processor(p1, p2): @staticmethod def _assert_stage_orchestrator_state( - stage_mapping: Dict, current_running_stage: RunningStage, cls=_Preprocessor + stage_mapping: Dict, current_running_stage: RunningStage, cls=_InputTransformProcessor ): assert isinstance(stage_mapping[current_running_stage], cls) assert stage_mapping[current_running_stage] @@ -417,7 +419,9 @@ def on_train_dataloader(self) -> None: super().on_train_dataloader() collate_fn = self.train_dataloader().collate_fn # noqa F811 assert collate_fn.stage == current_running_stage - self._compare_pre_processor(collate_fn, self.data_pipeline.worker_preprocessor(current_running_stage)) + self._compare_pre_processor( + collate_fn, self.data_pipeline.worker_input_transform_processor(current_running_stage) + ) assert isinstance(self.transfer_batch_to_device, _StageOrchestrator) self._assert_stage_orchestrator_state(self.transfer_batch_to_device._stage_mapping, current_running_stage) @@ -430,7 +434,9 @@ def on_val_dataloader(self) -> None: super().on_val_dataloader() collate_fn = self.val_dataloader().collate_fn # noqa F811 assert collate_fn.stage == current_running_stage - self._compare_pre_processor(collate_fn, self.data_pipeline.worker_preprocessor(current_running_stage)) + self._compare_pre_processor( + collate_fn, self.data_pipeline.worker_input_transform_processor(current_running_stage) + ) assert isinstance(self.transfer_batch_to_device, _StageOrchestrator) self._assert_stage_orchestrator_state(self.transfer_batch_to_device._stage_mapping, current_running_stage) @@ -443,7 +449,9 @@ def on_test_dataloader(self) -> None: super().on_test_dataloader() 
collate_fn = self.test_dataloader().collate_fn # noqa F811 assert collate_fn.stage == current_running_stage - self._compare_pre_processor(collate_fn, self.data_pipeline.worker_preprocessor(current_running_stage)) + self._compare_pre_processor( + collate_fn, self.data_pipeline.worker_input_transform_processor(current_running_stage) + ) assert isinstance(self.transfer_batch_to_device, _StageOrchestrator) self._assert_stage_orchestrator_state(self.transfer_batch_to_device._stage_mapping, current_running_stage) @@ -457,7 +465,9 @@ def on_predict_dataloader(self) -> None: super().on_predict_dataloader() collate_fn = self.predict_dataloader().collate_fn # noqa F811 assert collate_fn.stage == current_running_stage - self._compare_pre_processor(collate_fn, self.data_pipeline.worker_preprocessor(current_running_stage)) + self._compare_pre_processor( + collate_fn, self.data_pipeline.worker_input_transform_processor(current_running_stage) + ) assert isinstance(self.transfer_batch_to_device, _StageOrchestrator) assert isinstance(self.predict_step, _StageOrchestrator) self._assert_stage_orchestrator_state(self.transfer_batch_to_device._stage_mapping, current_running_stage) @@ -490,7 +500,7 @@ def on_fit_end(self) -> None: def test_stage_orchestrator_state_attach_detach(tmpdir): model = CustomModel() - preprocess = TestPreprocess() + input_transform = TestInputTransform() _original_predict_step = model.predict_step @@ -503,7 +513,7 @@ def _attach_output_transform_to_model( ) return model - data_pipeline = CustomDataPipeline(preprocess=preprocess) + data_pipeline = CustomDataPipeline(input_transform=input_transform) _output_transform_processor = data_pipeline._create_output_transform_processor(RunningStage.PREDICTING) data_pipeline._attach_output_transform_to_model(model, _output_transform_processor) assert model.predict_step._original == _original_predict_step @@ -581,7 +591,7 @@ def predict_load_data(self, sample) -> LamdaDummyDataset: return LamdaDummyDataset(self.fn_predict_load_data) -class TestInputTransformations(DefaultPreprocess): +class TestInputTransformations(DefaultInputTransform): def __init__(self): super().__init__(data_sources={"default": TestInputTransformationsDataSource()}) @@ -682,7 +692,7 @@ def predict_step(self, batch, batch_idx, dataloader_idx=None): def test_datapipeline_transformations(tmpdir): datamodule = DataModule.from_data_source( - "default", 1, 1, 1, 1, batch_size=2, num_workers=0, preprocess=TestInputTransformations() + "default", 1, 1, 1, 1, batch_size=2, num_workers=0, input_transform=TestInputTransformations() ) assert datamodule.train_dataloader().dataset[0] == (0, 1, 2, 3) @@ -695,7 +705,7 @@ def test_datapipeline_transformations(tmpdir): batch = next(iter(datamodule.val_dataloader())) datamodule = DataModule.from_data_source( - "default", 1, 1, 1, 1, batch_size=2, num_workers=0, preprocess=TestInputTransformations2() + "default", 1, 1, 1, 1, batch_size=2, num_workers=0, input_transform=TestInputTransformations2() ) batch = next(iter(datamodule.val_dataloader())) assert torch.equal(batch["a"], tensor([0, 1])) @@ -714,26 +724,26 @@ def test_datapipeline_transformations(tmpdir): trainer.test(model) trainer.predict(model) - preprocess = model._preprocess - data_source = preprocess.data_source_of_name("default") + input_transform = model._input_transform + data_source = input_transform.data_source_of_name("default") assert data_source.train_load_data_called - assert preprocess.train_pre_tensor_transform_called - assert preprocess.train_collate_called - assert 
preprocess.train_per_batch_transform_on_device_called + assert input_transform.train_pre_tensor_transform_called + assert input_transform.train_collate_called + assert input_transform.train_per_batch_transform_on_device_called assert data_source.val_load_data_called assert data_source.val_load_sample_called - assert preprocess.val_to_tensor_transform_called - assert preprocess.val_collate_called - assert preprocess.val_per_batch_transform_on_device_called + assert input_transform.val_to_tensor_transform_called + assert input_transform.val_collate_called + assert input_transform.val_per_batch_transform_on_device_called assert data_source.test_load_data_called - assert preprocess.test_to_tensor_transform_called - assert preprocess.test_post_tensor_transform_called + assert input_transform.test_to_tensor_transform_called + assert input_transform.test_post_tensor_transform_called assert data_source.predict_load_data_called @pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") def test_datapipeline_transformations_overridden_by_task(): - # define preprocess transforms + # define input transforms class ImageDataSource(DataSource): def load_data(self, folder: str): # from folder -> return files paths @@ -743,7 +753,7 @@ def load_sample(self, path: str) -> Image.Image: # from a file path, load the associated image return np.random.uniform(0, 1, (64, 64, 3)) - class ImageClassificationPreprocess(DefaultPreprocess): + class ImageClassificationInputTransform(DefaultInputTransform): def __init__( self, train_transform=None, @@ -788,7 +798,7 @@ def validation_step(self, batch, batch_idx): class CustomDataModule(DataModule): - preprocess_cls = ImageClassificationPreprocess + input_transform_cls = ImageClassificationInputTransform datamodule = CustomDataModule.from_data_source( "default", @@ -810,22 +820,22 @@ class CustomDataModule(DataModule): def test_is_overriden_recursive(tmpdir): - class TestPreprocess(DefaultPreprocess): + class TestInputTransform(DefaultInputTransform): def collate(self, *_): pass def val_collate(self, *_): pass - preprocess = TestPreprocess() - assert DataPipeline._is_overriden_recursive("collate", preprocess, Preprocess, prefix="val") - assert DataPipeline._is_overriden_recursive("collate", preprocess, Preprocess, prefix="train") + input_transform = TestInputTransform() + assert DataPipeline._is_overriden_recursive("collate", input_transform, InputTransform, prefix="val") + assert DataPipeline._is_overriden_recursive("collate", input_transform, InputTransform, prefix="train") assert not DataPipeline._is_overriden_recursive( - "per_batch_transform_on_device", preprocess, Preprocess, prefix="train" + "per_batch_transform_on_device", input_transform, InputTransform, prefix="train" ) - assert not DataPipeline._is_overriden_recursive("per_batch_transform_on_device", preprocess, Preprocess) + assert not DataPipeline._is_overriden_recursive("per_batch_transform_on_device", input_transform, InputTransform) with pytest.raises(MisconfigurationException, match="This function doesn't belong to the parent class"): - assert not DataPipeline._is_overriden_recursive("chocolate", preprocess, Preprocess) + assert not DataPipeline._is_overriden_recursive("chocolate", input_transform, InputTransform) @pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") @@ -841,7 +851,7 @@ def load_sample(self, path: str) -> Image.Image: img8Bit = np.uint8(np.random.uniform(0, 1, (64, 64, 3)) * 255.0) return Image.fromarray(img8Bit) - class 
ImageClassificationPreprocess(DefaultPreprocess): + class ImageClassificationInputTransform(DefaultInputTransform): def __init__( self, train_transform=None, @@ -884,7 +894,7 @@ def test_step(self, batch, batch_idx): class CustomDataModule(DataModule): - preprocess_cls = ImageClassificationPreprocess + input_transform_cls = ImageClassificationInputTransform datamodule = CustomDataModule.from_data_source( "default", @@ -914,79 +924,87 @@ class CustomDataModule(DataModule): trainer.test(model) -def test_preprocess_transforms(tmpdir): - """This test makes sure that when a preprocess is being provided transforms as dictionaries, checking is done - properly, and collate_in_worker_from_transform is properly extracted.""" +def test_input_transform_transforms(tmpdir): + """This test makes sure that when a input_transform is being provided transforms as dictionaries, checking is + done properly, and collate_in_worker_from_transform is properly extracted.""" with pytest.raises(MisconfigurationException, match="Transform should be a dict."): - DefaultPreprocess(train_transform="choco") + DefaultInputTransform(train_transform="choco") with pytest.raises(MisconfigurationException, match="train_transform contains {'choco'}. Only"): - DefaultPreprocess(train_transform={"choco": None}) + DefaultInputTransform(train_transform={"choco": None}) - preprocess = DefaultPreprocess(train_transform={"to_tensor_transform": torch.nn.Linear(1, 1)}) + input_transform = DefaultInputTransform(train_transform={"to_tensor_transform": torch.nn.Linear(1, 1)}) # keep is None - assert preprocess._train_collate_in_worker_from_transform is True - assert preprocess._val_collate_in_worker_from_transform is None - assert preprocess._test_collate_in_worker_from_transform is None - assert preprocess._predict_collate_in_worker_from_transform is None + assert input_transform._train_collate_in_worker_from_transform is True + assert input_transform._val_collate_in_worker_from_transform is None + assert input_transform._test_collate_in_worker_from_transform is None + assert input_transform._predict_collate_in_worker_from_transform is None with pytest.raises(MisconfigurationException, match="`per_batch_transform` and `per_sample_transform_on_device`"): - preprocess = DefaultPreprocess( + input_transform = DefaultInputTransform( train_transform={ "per_batch_transform": torch.nn.Linear(1, 1), "per_sample_transform_on_device": torch.nn.Linear(1, 1), } ) - preprocess = DefaultPreprocess( + input_transform = DefaultInputTransform( train_transform={"per_batch_transform": torch.nn.Linear(1, 1)}, predict_transform={"per_sample_transform_on_device": torch.nn.Linear(1, 1)}, ) # keep is None - assert preprocess._train_collate_in_worker_from_transform is True - assert preprocess._val_collate_in_worker_from_transform is None - assert preprocess._test_collate_in_worker_from_transform is None - assert preprocess._predict_collate_in_worker_from_transform is False - - train_preprocessor = DataPipeline(preprocess=preprocess).worker_preprocessor(RunningStage.TRAINING) - val_preprocessor = DataPipeline(preprocess=preprocess).worker_preprocessor(RunningStage.VALIDATING) - test_preprocessor = DataPipeline(preprocess=preprocess).worker_preprocessor(RunningStage.TESTING) - predict_preprocessor = DataPipeline(preprocess=preprocess).worker_preprocessor(RunningStage.PREDICTING) - - assert train_preprocessor.collate_fn.func == preprocess.collate - assert val_preprocessor.collate_fn.func == preprocess.collate - assert test_preprocessor.collate_fn.func == 
preprocess.collate - assert predict_preprocessor.collate_fn.func == DataPipeline._identity - - class CustomPreprocess(DefaultPreprocess): + assert input_transform._train_collate_in_worker_from_transform is True + assert input_transform._val_collate_in_worker_from_transform is None + assert input_transform._test_collate_in_worker_from_transform is None + assert input_transform._predict_collate_in_worker_from_transform is False + + train_input_transform_processor = DataPipeline(input_transform=input_transform).worker_input_transform_processor( + RunningStage.TRAINING + ) + val_input_transform_processor = DataPipeline(input_transform=input_transform).worker_input_transform_processor( + RunningStage.VALIDATING + ) + test_input_transform_processor = DataPipeline(input_transform=input_transform).worker_input_transform_processor( + RunningStage.TESTING + ) + predict_input_transform_processor = DataPipeline(input_transform=input_transform).worker_input_transform_processor( + RunningStage.PREDICTING + ) + + assert train_input_transform_processor.collate_fn.func == input_transform.collate + assert val_input_transform_processor.collate_fn.func == input_transform.collate + assert test_input_transform_processor.collate_fn.func == input_transform.collate + assert predict_input_transform_processor.collate_fn.func == DataPipeline._identity + + class CustomInputTransform(DefaultInputTransform): def per_sample_transform_on_device(self, sample: Any) -> Any: return super().per_sample_transform_on_device(sample) def per_batch_transform(self, batch: Any) -> Any: return super().per_batch_transform(batch) - preprocess = CustomPreprocess( + input_transform = CustomInputTransform( train_transform={"per_batch_transform": torch.nn.Linear(1, 1)}, predict_transform={"per_sample_transform_on_device": torch.nn.Linear(1, 1)}, ) # keep is None - assert preprocess._train_collate_in_worker_from_transform is True - assert preprocess._val_collate_in_worker_from_transform is None - assert preprocess._test_collate_in_worker_from_transform is None - assert preprocess._predict_collate_in_worker_from_transform is False + assert input_transform._train_collate_in_worker_from_transform is True + assert input_transform._val_collate_in_worker_from_transform is None + assert input_transform._test_collate_in_worker_from_transform is None + assert input_transform._predict_collate_in_worker_from_transform is False - data_pipeline = DataPipeline(preprocess=preprocess) + data_pipeline = DataPipeline(input_transform=input_transform) - train_preprocessor = data_pipeline.worker_preprocessor(RunningStage.TRAINING) + train_input_transform_processor = data_pipeline.worker_input_transform_processor(RunningStage.TRAINING) with pytest.raises(MisconfigurationException, match="`per_batch_transform` and `per_sample_transform_on_device`"): - val_preprocessor = data_pipeline.worker_preprocessor(RunningStage.VALIDATING) + val_input_transform_processor = data_pipeline.worker_input_transform_processor(RunningStage.VALIDATING) with pytest.raises(MisconfigurationException, match="`per_batch_transform` and `per_sample_transform_on_device`"): - test_preprocessor = data_pipeline.worker_preprocessor(RunningStage.TESTING) - predict_preprocessor = data_pipeline.worker_preprocessor(RunningStage.PREDICTING) + test_input_transform_processor = data_pipeline.worker_input_transform_processor(RunningStage.TESTING) + predict_input_transform_processor = data_pipeline.worker_input_transform_processor(RunningStage.PREDICTING) - assert train_preprocessor.collate_fn.func == 
preprocess.collate - assert predict_preprocessor.collate_fn.func == DataPipeline._identity + assert train_input_transform_processor.collate_fn.func == input_transform.collate + assert predict_input_transform_processor.collate_fn.func == DataPipeline._identity def test_iterable_auto_dataset(tmpdir): @@ -1000,7 +1018,7 @@ def load_sample(self, index: int) -> Dict[str, int]: assert v == {"index": index} -class CustomPreprocessHyperparameters(DefaultPreprocess): +class CustomInputTransformHyperparameters(DefaultInputTransform): def __init__(self, token: str, *args, **kwargs): self.token = token super().__init__(*args, **kwargs) @@ -1020,9 +1038,9 @@ def local_fn(x): def test_save_hyperparemeters(tmpdir): kwargs = {"train_transform": {"pre_tensor_transform": local_fn}} - preprocess = CustomPreprocessHyperparameters("token", **kwargs) - state_dict = preprocess.state_dict() + input_transform = CustomInputTransformHyperparameters("token", **kwargs) + state_dict = input_transform.state_dict() torch.save(state_dict, os.path.join(tmpdir, "state_dict.pt")) state_dict = torch.load(os.path.join(tmpdir, "state_dict.pt")) - preprocess = CustomPreprocessHyperparameters.load_from_state_dict(state_dict) - assert isinstance(preprocess, CustomPreprocessHyperparameters) + input_transform = CustomInputTransformHyperparameters.load_from_state_dict(state_dict) + assert isinstance(input_transform, CustomInputTransformHyperparameters) diff --git a/tests/core/data/test_process.py b/tests/core/data/test_process.py deleted file mode 100644 index e792f83080..0000000000 --- a/tests/core/data/test_process.py +++ /dev/null @@ -1,65 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
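The removal of tests/core/data/test_process.py here folds its cases into the new tests/core/data/io/test_input_transform.py added earlier in this diff. As a standalone reminder of what test_input_transform_transforms above asserts: within a single stage, `per_batch_transform` (worker side) and `per_sample_transform_on_device` (device side) are mutually exclusive, while splitting them across stages flips the corresponding collate flag. A minimal sketch, assuming the class and flag names introduced by this PR:

```python
import torch

from flash.core.data.io.input_transform import DefaultInputTransform

# Worker-side and device-side hooks may coexist across stages, but not within one stage.
input_transform = DefaultInputTransform(
    train_transform={"per_batch_transform": torch.nn.Linear(1, 1)},
    predict_transform={"per_sample_transform_on_device": torch.nn.Linear(1, 1)},
)
assert input_transform._train_collate_in_worker_from_transform is True
assert input_transform._predict_collate_in_worker_from_transform is False
```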
-from unittest.mock import Mock - -import pytest -import torch -from pytorch_lightning.utilities.exceptions import MisconfigurationException - -from flash.core.data.data_module import DataModule -from flash.core.data.data_pipeline import DefaultPreprocess -from flash.core.data.data_source import DefaultDataSources - - -class CustomPreprocess(DefaultPreprocess): - def __init__(self): - super().__init__( - data_sources={ - "test": Mock(return_value="test"), - DefaultDataSources.TENSORS: Mock(return_value="tensors"), - }, - default_data_source="test", - ) - - -def test_data_source_of_name(): - preprocess = CustomPreprocess() - - assert preprocess.data_source_of_name("test")() == "test" - assert preprocess.data_source_of_name(DefaultDataSources.TENSORS)() == "tensors" - assert preprocess.data_source_of_name("tensors")() == "tensors" - assert preprocess.data_source_of_name("default")() == "test" - - with pytest.raises(MisconfigurationException, match="available data sources are: test, tensor"): - preprocess.data_source_of_name("not available") - - -def test_available_data_sources(): - preprocess = CustomPreprocess() - - assert DefaultDataSources.TENSORS in preprocess.available_data_sources() - assert "test" in preprocess.available_data_sources() - assert len(preprocess.available_data_sources()) == 3 - - data_module = DataModule(preprocess=preprocess) - - assert DefaultDataSources.TENSORS in data_module.available_data_sources() - assert "test" in data_module.available_data_sources() - assert len(data_module.available_data_sources()) == 3 - - -def test_check_transforms(): - transform = torch.nn.Identity() - DefaultPreprocess(train_transform=transform) - DefaultPreprocess(train_transform=[transform]) diff --git a/tests/core/data/test_serialization.py b/tests/core/data/test_serialization.py index 948f6bee13..94999b37dd 100644 --- a/tests/core/data/test_serialization.py +++ b/tests/core/data/test_serialization.py @@ -20,7 +20,7 @@ from torch.utils.data.dataloader import DataLoader from flash.core.data.data_pipeline import DataPipeline -from flash.core.data.process import DefaultPreprocess +from flash.core.data.io.input_transform import DefaultInputTransform from flash.core.model import Task @@ -29,7 +29,7 @@ def __init__(self): super().__init__(model=torch.nn.Linear(1, 1), loss_fn=torch.nn.MSELoss()) -class CustomPreprocess(DefaultPreprocess): +class CustomInputTransform(DefaultInputTransform): @classmethod def load_data(cls, data): return data @@ -50,22 +50,22 @@ def test_serialization_data_pipeline(tmpdir): loaded_model = CustomModel.load_from_checkpoint(checkpoint_file) assert loaded_model.data_pipeline - model.data_pipeline = DataPipeline(preprocess=CustomPreprocess()) - assert isinstance(model.preprocess, CustomPreprocess) + model.data_pipeline = DataPipeline(input_transform=CustomInputTransform()) + assert isinstance(model.input_transform, CustomInputTransform) trainer.fit(model, dummy_data) assert model.data_pipeline - assert isinstance(model.preprocess, CustomPreprocess) + assert isinstance(model.input_transform, CustomInputTransform) trainer.save_checkpoint(checkpoint_file) def fn(*args, **kwargs): return "0.0.2" - CustomPreprocess.version = fn + CustomInputTransform.version = fn loaded_model = CustomModel.load_from_checkpoint(checkpoint_file) assert loaded_model.data_pipeline - assert isinstance(loaded_model.preprocess, CustomPreprocess) + assert isinstance(loaded_model.input_transform, CustomInputTransform) for file in os.listdir(tmpdir): if file.endswith(".ckpt"): 
os.remove(os.path.join(tmpdir, file)) diff --git a/tests/core/integrations/labelstudio/test_labelstudio.py b/tests/core/integrations/labelstudio/test_labelstudio.py index 4f586dc00b..9e04d839ec 100644 --- a/tests/core/integrations/labelstudio/test_labelstudio.py +++ b/tests/core/integrations/labelstudio/test_labelstudio.py @@ -10,7 +10,7 @@ from flash.core.integrations.labelstudio.visualizer import launch_app from flash.image.classification.data import ImageClassificationData from flash.text.classification.data import TextClassificationData -from flash.video.classification.data import VideoClassificationData, VideoClassificationPreprocess +from flash.video.classification.data import VideoClassificationData, VideoClassificationInputTransform from tests.helpers.utils import _IMAGE_TESTING, _TEXT_TESTING, _VIDEO_TESTING @@ -272,8 +272,8 @@ def test_datasource_labelstudio_video(): """Test creation of LabelStudioVideoClassificationDataSource from video.""" download_data("https://label-studio-testdata.s3.us-east-2.amazonaws.com/lightning-flash/video_data.zip") data = {"data_folder": "data/upload/", "export_json": "data/project.json", "multi_label": True} - preprocess = VideoClassificationPreprocess() - ds = preprocess.data_source_of_name(DefaultDataSources.LABELSTUDIO) + input_transform = VideoClassificationInputTransform() + ds = input_transform.data_source_of_name(DefaultDataSources.LABELSTUDIO) train, val, test, predict = ds.to_datasets(train_data=data, test_data=data) sample_iter = iter(train) sample = next(sample_iter) diff --git a/tests/core/test_model.py b/tests/core/test_model.py index d6b9f96999..722f0f5d32 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -35,8 +35,8 @@ from flash.audio import SpeechRecognition from flash.core.adapter import Adapter from flash.core.classification import ClassificationTask +from flash.core.data.io.input_transform import DefaultInputTransform from flash.core.data.io.output_transform import OutputTransform -from flash.core.data.process import DefaultPreprocess from flash.core.utilities.imports import _TORCH_OPTIMIZER_AVAILABLE, _TRANSFORMERS_AVAILABLE, Image from flash.image import ImageClassificationData, ImageClassifier, SemanticSegmentation from flash.tabular import TabularClassifier @@ -176,7 +176,7 @@ def test_nested_tasks(tmpdir, task): def test_classificationtask_task_predict(): model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.Softmax()) - task = ClassificationTask(model, preprocess=DefaultPreprocess()) + task = ClassificationTask(model, input_transform=DefaultInputTransform()) ds = DummyDataset() expected = list(range(10)) # single item diff --git a/tests/graph/classification/test_data.py b/tests/graph/classification/test_data.py index de4d08ff72..8cce12dd04 100644 --- a/tests/graph/classification/test_data.py +++ b/tests/graph/classification/test_data.py @@ -15,7 +15,7 @@ from flash.core.data.transforms import merge_transforms from flash.core.utilities.imports import _TORCH_GEOMETRIC_AVAILABLE -from flash.graph.classification.data import GraphClassificationData, GraphClassificationPreprocess +from flash.graph.classification.data import GraphClassificationData, GraphClassificationInputTransform from tests.helpers.utils import _GRAPH_TESTING if _TORCH_GEOMETRIC_AVAILABLE: @@ -24,12 +24,12 @@ @pytest.mark.skipif(not _GRAPH_TESTING, reason="graph libraries aren't installed.") -class TestGraphClassificationPreprocess: - """Tests ``GraphClassificationPreprocess``.""" +class TestGraphClassificationInputTransform: + 
"""Tests ``GraphClassificationInputTransform``.""" def test_smoke(self): """A simple test that the class can be instantiated.""" - prep = GraphClassificationPreprocess() + prep = GraphClassificationInputTransform() assert prep is not None @@ -94,19 +94,19 @@ def test_transforms(self, tmpdir): test_dataset=test_dataset, predict_dataset=predict_dataset, train_transform=merge_transforms( - GraphClassificationPreprocess.default_transforms(), + GraphClassificationInputTransform.default_transforms(), {"pre_tensor_transform": OneHotDegree(tudataset.num_features - 1)}, ), val_transform=merge_transforms( - GraphClassificationPreprocess.default_transforms(), + GraphClassificationInputTransform.default_transforms(), {"pre_tensor_transform": OneHotDegree(tudataset.num_features - 1)}, ), test_transform=merge_transforms( - GraphClassificationPreprocess.default_transforms(), + GraphClassificationInputTransform.default_transforms(), {"pre_tensor_transform": OneHotDegree(tudataset.num_features - 1)}, ), predict_transform=merge_transforms( - GraphClassificationPreprocess.default_transforms(), + GraphClassificationInputTransform.default_transforms(), {"pre_tensor_transform": OneHotDegree(tudataset.num_features - 1)}, ), batch_size=2, diff --git a/tests/graph/classification/test_model.py b/tests/graph/classification/test_model.py index 0813a6fb3a..8392e8010d 100644 --- a/tests/graph/classification/test_model.py +++ b/tests/graph/classification/test_model.py @@ -21,7 +21,7 @@ from flash.core.data.data_pipeline import DataPipeline from flash.core.utilities.imports import _TORCH_GEOMETRIC_AVAILABLE from flash.graph.classification import GraphClassifier -from flash.graph.classification.data import GraphClassificationPreprocess +from flash.graph.classification.data import GraphClassificationInputTransform from tests.helpers.utils import _GRAPH_TESTING if _TORCH_GEOMETRIC_AVAILABLE: @@ -40,7 +40,7 @@ def test_train(tmpdir): """Tests that the model can be trained on a pytorch geometric dataset.""" tudataset = datasets.TUDataset(root=tmpdir, name="KKI") model = GraphClassifier(num_features=tudataset.num_features, num_classes=tudataset.num_classes) - model.data_pipeline = DataPipeline(preprocess=GraphClassificationPreprocess()) + model.data_pipeline = DataPipeline(input_transform=GraphClassificationInputTransform()) train_dl = torch.utils.data.DataLoader(tudataset, batch_size=4) trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) trainer.fit(model, train_dl) @@ -51,7 +51,7 @@ def test_val(tmpdir): """Tests that the model can be validated on a pytorch geometric dataset.""" tudataset = datasets.TUDataset(root=tmpdir, name="KKI") model = GraphClassifier(num_features=tudataset.num_features, num_classes=tudataset.num_classes) - model.data_pipeline = DataPipeline(preprocess=GraphClassificationPreprocess()) + model.data_pipeline = DataPipeline(input_transform=GraphClassificationInputTransform()) val_dl = torch.utils.data.DataLoader(tudataset, batch_size=4) trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) trainer.validate(model, val_dl) @@ -62,7 +62,7 @@ def test_test(tmpdir): """Tests that the model can be tested on a pytorch geometric dataset.""" tudataset = datasets.TUDataset(root=tmpdir, name="KKI") model = GraphClassifier(num_features=tudataset.num_features, num_classes=tudataset.num_classes) - model.data_pipeline = DataPipeline(preprocess=GraphClassificationPreprocess()) + model.data_pipeline = DataPipeline(input_transform=GraphClassificationInputTransform()) test_dl = 
     trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
     trainer.test(model, test_dl)
@@ -73,7 +73,7 @@ def test_predict_dataset(tmpdir):
     """Tests that we can generate predictions from a pytorch geometric dataset."""
     tudataset = datasets.TUDataset(root=tmpdir, name="KKI")
     model = GraphClassifier(num_features=tudataset.num_features, num_classes=tudataset.num_classes)
-    data_pipe = DataPipeline(preprocess=GraphClassificationPreprocess())
+    data_pipe = DataPipeline(input_transform=GraphClassificationInputTransform())
     out = model.predict(tudataset, data_source="datasets", data_pipeline=data_pipe)
     assert isinstance(out[0], int)
diff --git a/tests/image/classification/test_model.py b/tests/image/classification/test_model.py
index 96da2f4a11..779da0c646 100644
--- a/tests/image/classification/test_model.py
+++ b/tests/image/classification/test_model.py
@@ -24,7 +24,7 @@
 from flash.core.data.data_source import DefaultDataKeys
 from flash.core.utilities.imports import _IMAGE_AVAILABLE
 from flash.image import ImageClassifier
-from flash.image.classification.data import ImageClassificationPreprocess
+from flash.image.classification.data import ImageClassificationInputTransform
 from tests.helpers.utils import _IMAGE_TESTING, _SERVE_TESTING
 # ======== Mock functions ========
@@ -138,8 +138,8 @@ def test_jit(tmpdir, jitter, args):
 @mock.patch("flash._IS_TESTING", True)
 def test_serve():
     model = ImageClassifier(2)
-    # TODO: Currently only servable once a preprocess has been attached
-    model._preprocess = ImageClassificationPreprocess()
+    # TODO: Currently only servable once an input_transform has been attached
+    model._input_transform = ImageClassificationInputTransform()
     model.eval()
     model.serve()
diff --git a/tests/image/embedding/utils.py b/tests/image/embedding/utils.py
index 3f66ee81db..ee7fe2bd13 100644
--- a/tests/image/embedding/utils.py
+++ b/tests/image/embedding/utils.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from flash.core.data.data_source import DefaultDataKeys
-from flash.core.data.process import DefaultPreprocess
+from flash.core.data.io.input_transform import DefaultInputTransform
 from flash.core.data.transforms import ApplyToKeys
 from flash.core.utilities.imports import _TORCHVISION_AVAILABLE, _VISSL_AVAILABLE
 from flash.image import ImageClassificationData
@@ -41,7 +41,7 @@ def ssl_datamodule(
         DefaultDataKeys.INPUT,
         multi_crop_transform,
     )
-    preprocess = DefaultPreprocess(
+    input_transform = DefaultInputTransform(
         train_transform={
             "to_tensor_transform": to_tensor_transform,
             "collate": collate_fn,
@@ -50,7 +50,7 @@
     datamodule = ImageClassificationData.from_datasets(
         train_dataset=FakeData(),
-        preprocess=preprocess,
+        input_transform=input_transform,
         batch_size=batch_size,
     )
diff --git a/tests/image/segmentation/test_data.py b/tests/image/segmentation/test_data.py
index b44a68da0d..68e5e3f758 100644
--- a/tests/image/segmentation/test_data.py
+++ b/tests/image/segmentation/test_data.py
@@ -10,7 +10,7 @@
 from flash import Trainer
 from flash.core.data.data_source import DefaultDataKeys
 from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _IMAGE_AVAILABLE, _MATPLOTLIB_AVAILABLE, _PIL_AVAILABLE
-from flash.image import SemanticSegmentation, SemanticSegmentationData, SemanticSegmentationPreprocess
+from flash.image import SemanticSegmentation, SemanticSegmentationData, SemanticSegmentationInputTransform
 from tests.helpers.utils import _IMAGE_TESTING
 if _PIL_AVAILABLE:
@@ -47,11 +47,11 @@ def create_random_data(image_files: List[str], label_files: List[str], size: Tup
         _rand_labels(size, num_classes).save(label_file)
-class TestSemanticSegmentationPreprocess:
+class TestSemanticSegmentationInputTransform:
     @staticmethod
     @pytest.mark.xfail(reaspn="parameters are marked as optional but it returns Misconficg error.")
     def test_smoke():
-        prep = SemanticSegmentationPreprocess(num_classes=1)
+        prep = SemanticSegmentationInputTransform(num_classes=1)
         assert prep is not None
diff --git a/tests/image/segmentation/test_model.py b/tests/image/segmentation/test_model.py
index 6715ebfc50..41443e470e 100644
--- a/tests/image/segmentation/test_model.py
+++ b/tests/image/segmentation/test_model.py
@@ -26,7 +26,7 @@
 from flash.core.data.data_source import DefaultDataKeys
 from flash.core.utilities.imports import _IMAGE_AVAILABLE
 from flash.image import SemanticSegmentation
-from flash.image.segmentation.data import SemanticSegmentationPreprocess
+from flash.image.segmentation.data import SemanticSegmentationInputTransform
 from tests.helpers.utils import _IMAGE_TESTING, _SERVE_TESTING
 # ======== Mock functions ========
@@ -106,7 +106,7 @@ def test_unfreeze():
 def test_predict_tensor():
     img = torch.rand(1, 3, 64, 64)
     model = SemanticSegmentation(2, backbone="mobilenetv3_large_100")
-    data_pipe = DataPipeline(preprocess=SemanticSegmentationPreprocess(num_classes=1))
+    data_pipe = DataPipeline(input_transform=SemanticSegmentationInputTransform(num_classes=1))
     out = model.predict(img, data_source="tensors", data_pipeline=data_pipe)
     assert isinstance(out[0], list)
     assert len(out[0]) == 64
@@ -117,7 +117,7 @@ def test_predict_tensor():
 def test_predict_numpy():
     img = np.ones((1, 3, 64, 64))
     model = SemanticSegmentation(2, backbone="mobilenetv3_large_100")
-    data_pipe = DataPipeline(preprocess=SemanticSegmentationPreprocess(num_classes=1))
+    data_pipe = DataPipeline(input_transform=SemanticSegmentationInputTransform(num_classes=1))
     out = model.predict(img, data_source="numpy", data_pipeline=data_pipe)
     assert isinstance(out[0], list)
     assert len(out[0]) == 64
@@ -146,8 +146,8 @@ def test_jit(tmpdir, jitter, args):
 @mock.patch("flash._IS_TESTING", True)
 def test_serve():
     model = SemanticSegmentation(2)
-    # TODO: Currently only servable once a preprocess has been attached
-    model._preprocess = SemanticSegmentationPreprocess()
+    # TODO: Currently only servable once an input_transform has been attached
+    model._input_transform = SemanticSegmentationInputTransform()
     model.eval()
     model.serve()
diff --git a/tests/tabular/classification/test_model.py b/tests/tabular/classification/test_model.py
index 2efe7c316e..9ebfe8eb97 100644
--- a/tests/tabular/classification/test_model.py
+++ b/tests/tabular/classification/test_model.py
@@ -106,8 +106,8 @@ def test_serve():
         pd.DataFrame.from_dict(train_data),
     )
     model = TabularClassifier.from_data(datamodule)
-    # TODO: Currently only servable once a preprocess has been attached
-    model._preprocess = datamodule.preprocess
+    # TODO: Currently only servable once an input_transform has been attached
+    model._input_transform = datamodule.input_transform
     model.eval()
     model.serve()
diff --git a/tests/template/classification/test_data.py b/tests/template/classification/test_data.py
index b793849e08..03f11b3c81 100644
--- a/tests/template/classification/test_data.py
+++ b/tests/template/classification/test_data.py
@@ -16,20 +16,20 @@
 from flash.core.data.data_source import DefaultDataKeys
 from flash.core.utilities.imports import _SKLEARN_AVAILABLE
-from flash.template.classification.data import TemplateData, TemplatePreprocess
+from flash.template.classification.data import TemplateData, TemplateInputTransform
 if _SKLEARN_AVAILABLE:
     from sklearn import datasets
 @pytest.mark.skipif(not _SKLEARN_AVAILABLE, reason="sklearn isn't installed")
-class TestTemplatePreprocess:
-    """Tests ``TemplatePreprocess``."""
+class TestTemplateInputTransform:
+    """Tests ``TemplateInputTransform``."""
     @staticmethod
     def test_smoke():
         """A simple test that the class can be instantiated."""
-        prep = TemplatePreprocess()
+        prep = TemplateInputTransform()
         assert prep is not None
diff --git a/tests/template/classification/test_model.py b/tests/template/classification/test_model.py
index cfd0f77f39..0c585b842d 100644
--- a/tests/template/classification/test_model.py
+++ b/tests/template/classification/test_model.py
@@ -22,7 +22,7 @@
 from flash.core.data.data_source import DefaultDataKeys
 from flash.core.utilities.imports import _SKLEARN_AVAILABLE
 from flash.template import TemplateSKLearnClassifier
-from flash.template.classification.data import TemplatePreprocess
+from flash.template.classification.data import TemplateInputTransform
 if _SKLEARN_AVAILABLE:
     from sklearn import datasets
@@ -105,7 +105,7 @@ def test_predict_numpy():
     """Tests that we can generate predictions from a numpy array."""
     row = np.random.rand(1, DummyDataset.num_features)
     model = TemplateSKLearnClassifier(num_features=DummyDataset.num_features, num_classes=DummyDataset.num_classes)
-    data_pipe = DataPipeline(preprocess=TemplatePreprocess())
+    data_pipe = DataPipeline(input_transform=TemplateInputTransform())
     out = model.predict(row, data_pipeline=data_pipe)
     assert isinstance(out[0], int)
@@ -115,7 +115,7 @@ def test_predict_sklearn():
     """Tests that we can generate predictions from a scikit-learn ``Bunch``."""
     bunch = datasets.load_iris()
     model = TemplateSKLearnClassifier(num_features=DummyDataset.num_features, num_classes=DummyDataset.num_classes)
-    data_pipe = DataPipeline(preprocess=TemplatePreprocess())
+    data_pipe = DataPipeline(input_transform=TemplateInputTransform())
     out = model.predict(bunch, data_source="sklearn", data_pipeline=data_pipe)
     assert isinstance(out[0], int)
diff --git a/tests/text/classification/test_model.py b/tests/text/classification/test_model.py
index ac7c005105..e336ff81c5 100644
--- a/tests/text/classification/test_model.py
+++ b/tests/text/classification/test_model.py
@@ -23,7 +23,7 @@
 from flash.core.data.data_source import DefaultDataKeys
 from flash.core.utilities.imports import _TEXT_AVAILABLE
 from flash.text import TextClassifier
-from flash.text.classification.data import TextClassificationOutputTransform, TextClassificationPreprocess
+from flash.text.classification.data import TextClassificationInputTransform, TextClassificationOutputTransform
 from tests.helpers.utils import _SERVE_TESTING, _TEXT_TESTING
 # ======== Mock functions ========
@@ -77,8 +77,9 @@ def test_jit(tmpdir):
 @mock.patch("flash._IS_TESTING", True)
 def test_serve():
     model = TextClassifier(2, TEST_BACKBONE)
-    # TODO: Currently only servable once a preprocess and output_transform have been attached
-    model._preprocess = TextClassificationPreprocess(backbone=TEST_BACKBONE)
+
+    # TODO: Currently only servable once an input_transform and output_transform have been attached
+    model._input_transform = TextClassificationInputTransform(backbone=TEST_BACKBONE)
     model._output_transform = TextClassificationOutputTransform()
     model.eval()
     model.serve()
diff --git a/tests/text/seq2seq/summarization/test_model.py b/tests/text/seq2seq/summarization/test_model.py
index b6e02e9895..1723f9e42e 100644
--- a/tests/text/seq2seq/summarization/test_model.py
+++ b/tests/text/seq2seq/summarization/test_model.py
@@ -22,7 +22,7 @@
 from flash.core.utilities.imports import _TEXT_AVAILABLE
 from flash.text import SummarizationTask
 from flash.text.seq2seq.core.data import Seq2SeqOutputTransform
-from flash.text.seq2seq.summarization.data import SummarizationPreprocess
+from flash.text.seq2seq.summarization.data import SummarizationInputTransform
 from tests.helpers.utils import _SERVE_TESTING, _TEXT_TESTING
 # ======== Mock functions ========
@@ -78,8 +78,9 @@ def test_jit(tmpdir):
 @mock.patch("flash._IS_TESTING", True)
 def test_serve():
     model = SummarizationTask(TEST_BACKBONE)
-    # TODO: Currently only servable once a preprocess and output_transform have been attached
-    model._preprocess = SummarizationPreprocess(backbone=TEST_BACKBONE)
+
+    # TODO: Currently only servable once an input_transform and output_transform have been attached
+    model._input_transform = SummarizationInputTransform(backbone=TEST_BACKBONE)
     model._output_transform = Seq2SeqOutputTransform()
     model.eval()
     model.serve()
diff --git a/tests/text/seq2seq/translation/test_model.py b/tests/text/seq2seq/translation/test_model.py
index e552b74385..eb4007c59b 100644
--- a/tests/text/seq2seq/translation/test_model.py
+++ b/tests/text/seq2seq/translation/test_model.py
@@ -22,7 +22,7 @@
 from flash.core.utilities.imports import _TEXT_AVAILABLE
 from flash.text import TranslationTask
 from flash.text.seq2seq.core.data import Seq2SeqOutputTransform
-from flash.text.seq2seq.translation.data import TranslationPreprocess
+from flash.text.seq2seq.translation.data import TranslationInputTransform
 from tests.helpers.utils import _SERVE_TESTING, _TEXT_TESTING
 # ======== Mock functions ========
@@ -78,9 +78,10 @@ def test_jit(tmpdir):
 @mock.patch("flash._IS_TESTING", True)
 def test_serve():
     model = TranslationTask(TEST_BACKBONE)
-    # TODO: Currently only servable once a preprocess and output_transform have been attached
-    model._preprocess = TranslationPreprocess(backbone=TEST_BACKBONE)
+    # TODO: Currently only servable once an input_transform and output_transform have been attached
+    model._input_transform = TranslationInputTransform(backbone=TEST_BACKBONE)
     model._output_transform = Seq2SeqOutputTransform()
+
     model.eval()
     model.serve()