From ff92d4475cec09fc5f1dfcbcba753ce4bb353f09 Mon Sep 17 00:00:00 2001
From: Ethan Harris
Date: Fri, 1 Jul 2022 10:17:59 +0100
Subject: [PATCH] Add more ways to load image data for classification and detection (#1372)

---
 CHANGELOG.md                        |   4 +
 flash/image/classification/data.py  |  93 +++++++
 flash/image/classification/input.py |  24 +-
 flash/image/detection/data.py       | 380 +++++++++++++++++++++++++++-
 flash/image/detection/input.py      |  94 ++++++-
 5 files changed, 590 insertions(+), 5 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 88d1850fa1..edf685f3c1 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -18,6 +18,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added support for Flash serve to the `ObjectDetector` ([#1370](https://github.com/PyTorchLightning/lightning-flash/pull/1370))
 
+- Added support for loading `ImageClassificationData` from PIL images with `from_images` ([#1372](https://github.com/PyTorchLightning/lightning-flash/pull/1372))
+
+- Added support for loading `ObjectDetectionData` with `from_numpy`, `from_images`, and `from_tensors` ([#1372](https://github.com/PyTorchLightning/lightning-flash/pull/1372))
+
 ### Changed
 
 - Changed the `ImageEmbedder` dependency on VISSL to optional ([#1276](https://github.com/PyTorchLightning/lightning-flash/pull/1276))
diff --git a/flash/image/classification/data.py b/flash/image/classification/data.py
index 6a860eccdb..37743ceb02 100644
--- a/flash/image/classification/data.py
+++ b/flash/image/classification/data.py
@@ -41,6 +41,7 @@
     ImageClassificationFiftyOneInput,
     ImageClassificationFilesInput,
     ImageClassificationFolderInput,
+    ImageClassificationImageInput,
     ImageClassificationNumpyInput,
     ImageClassificationTensorInput,
 )
@@ -64,6 +65,7 @@
         "ImageClassificationData.from_files",
         "ImageClassificationData.from_folders",
         "ImageClassificationData.from_numpy",
+        "ImageClassificationData.from_images",
         "ImageClassificationData.from_tensors",
         "ImageClassificationData.from_data_frame",
         "ImageClassificationData.from_csv",
@@ -385,6 +387,97 @@ def from_numpy(
             **data_module_kwargs,
         )
 
+    @classmethod
+    def from_images(
+        cls,
+        train_images: Optional[List[Image.Image]] = None,
+        train_targets: Optional[Sequence[Any]] = None,
+        val_images: Optional[List[Image.Image]] = None,
+        val_targets: Optional[Sequence[Any]] = None,
+        test_images: Optional[List[Image.Image]] = None,
+        test_targets: Optional[Sequence[Any]] = None,
+        predict_images: Optional[List[Image.Image]] = None,
+        target_formatter: Optional[TargetFormatter] = None,
+        input_cls: Type[Input] = ImageClassificationImageInput,
+        transform: INPUT_TRANSFORM_TYPE = ImageClassificationInputTransform,
+        transform_kwargs: Optional[Dict] = None,
+        **data_module_kwargs: Any,
+    ) -> "ImageClassificationData":
+        """Load the :class:`~flash.image.classification.data.ImageClassificationData` from lists of PIL images and
+        corresponding lists of targets.
+
+        The targets can be in any of our
+        :ref:`supported classification target formats <formatting_classification_targets>`.
+        To learn how to customize the transforms applied for each stage, read our
+        :ref:`customizing transforms guide <customizing_transforms>`.
+
+        Args:
+            train_images: The list of PIL images to use when training.
+            train_targets: The list of targets to use when training.
+            val_images: The list of PIL images to use when validating.
+            val_targets: The list of targets to use when validating.
+            test_images: The list of PIL images to use when testing.
+            test_targets: The list of targets to use when testing.
+            predict_images: The list of PIL images to use when predicting.
+            target_formatter: Optionally provide a :class:`~flash.core.data.utilities.classification.TargetFormatter` to
+                control how targets are handled. See :ref:`formatting_classification_targets` for more details.
+            input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data.
+            transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use.
+            transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms.
+            data_module_kwargs: Additional keyword arguments to provide to the
+                :class:`~flash.core.data.data_module.DataModule` constructor.
+
+        Returns:
+            The constructed :class:`~flash.image.classification.data.ImageClassificationData`.
+
+        Examples
+        ________
+
+        .. doctest::
+
+            >>> from PIL import Image
+            >>> import numpy as np
+            >>> from flash import Trainer
+            >>> from flash.image import ImageClassifier, ImageClassificationData
+            >>> datamodule = ImageClassificationData.from_images(
+            ...     train_images=[
+            ...         Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype="uint8")),
+            ...         Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype="uint8")),
+            ...         Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype="uint8")),
+            ...     ],
+            ...     train_targets=["cat", "dog", "cat"],
+            ...     predict_images=[Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype="uint8"))],
+            ...     transform_kwargs=dict(image_size=(128, 128)),
+            ...     batch_size=2,
+            ... )
+            >>> datamodule.num_classes
+            2
+            >>> datamodule.labels
+            ['cat', 'dog']
+            >>> model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes)
+            >>> trainer = Trainer(fast_dev_run=True)
+            >>> trainer.fit(model, datamodule=datamodule)  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
+            Training...
+            >>> trainer.predict(model, datamodule=datamodule)  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
+            Predicting...
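+
+        As a sketch, a validation split can be supplied in the same format (the data below is
+        hypothetical, purely for illustration):
+
+        .. code-block:: python
+
+            images = [Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype="uint8")) for _ in range(4)]
+            datamodule = ImageClassificationData.from_images(
+                train_images=images[:3],
+                train_targets=["cat", "dog", "cat"],
+                val_images=images[3:],
+                val_targets=["dog"],
+                batch_size=2,
+            )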
+        """
+        ds_kw = dict(
+            target_formatter=target_formatter,
+        )
+
+        train_input = input_cls(RunningStage.TRAINING, train_images, train_targets, **ds_kw)
+        ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None)
+
+        return cls(
+            train_input,
+            input_cls(RunningStage.VALIDATING, val_images, val_targets, **ds_kw),
+            input_cls(RunningStage.TESTING, test_images, test_targets, **ds_kw),
+            input_cls(RunningStage.PREDICTING, predict_images, **ds_kw),
+            transform=transform,
+            transform_kwargs=transform_kwargs,
+            **data_module_kwargs,
+        )
+
     @classmethod
     def from_tensors(
         cls,
diff --git a/flash/image/classification/input.py b/flash/image/classification/input.py
index cea217b0a0..8898a7a6c5 100644
--- a/flash/image/classification/input.py
+++ b/flash/image/classification/input.py
@@ -24,7 +24,14 @@
 from flash.core.data.utilities.samples import to_samples
 from flash.core.integrations.fiftyone.utils import FiftyOneLabelUtilities
 from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, lazy_import, requires
-from flash.image.data import ImageFilesInput, ImageNumpyInput, ImageTensorInput, IMG_EXTENSIONS, NP_EXTENSIONS
+from flash.image.data import (
+    ImageFilesInput,
+    ImageInput,
+    ImageNumpyInput,
+    ImageTensorInput,
+    IMG_EXTENSIONS,
+    NP_EXTENSIONS,
+)
 
 if _FIFTYONE_AVAILABLE:
     fol = lazy_import("fiftyone.core.labels")
@@ -115,6 +122,21 @@ def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
         return sample
 
 
+class ImageClassificationImageInput(ClassificationInputMixin, ImageInput):
+    def load_data(
+        self, images: Any, targets: Optional[List[Any]] = None, target_formatter: Optional[TargetFormatter] = None
+    ) -> List[Dict[str, Any]]:
+        if targets is not None:
+            self.load_target_metadata(targets, target_formatter=target_formatter)
+        return to_samples(images, targets)
+
+    def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
+        sample = super().load_sample(sample)
+        if DataKeys.TARGET in sample:
+            sample[DataKeys.TARGET] = self.format_target(sample[DataKeys.TARGET])
+        return sample
+
+
 class ImageClassificationDataFrameInput(ImageClassificationFilesInput):
     labels: list
 
diff --git a/flash/image/detection/data.py b/flash/image/detection/data.py
index 32c089f1a9..c4ba2fb772 100644
--- a/flash/image/detection/data.py
+++ b/flash/image/detection/data.py
@@ -12,7 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from functools import partial
-from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Union
+from typing import Any, Callable, Collection, Dict, List, Optional, Sequence, Type, Union
+
+import numpy as np
+import torch
 
 from flash.core.data.data_module import DataModule
 from flash.core.data.io.input import Input
@@ -20,10 +23,22 @@
 from flash.core.data.utilities.sort import sorted_alphanumeric
 from flash.core.integrations.icevision.data import IceVisionInput
 from flash.core.integrations.icevision.transforms import IceVisionInputTransform
-from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _ICEVISION_AVAILABLE, _IMAGE_EXTRAS_TESTING, requires
+from flash.core.utilities.imports import (
+    _FIFTYONE_AVAILABLE,
+    _ICEVISION_AVAILABLE,
+    _IMAGE_EXTRAS_TESTING,
+    Image,
+    requires,
+)
 from flash.core.utilities.stages import RunningStage
 from flash.core.utilities.types import INPUT_TRANSFORM_TYPE
-from flash.image.detection.input import ObjectDetectionFiftyOneInput, ObjectDetectionFilesInput
+from flash.image.detection.input import (
+    ObjectDetectionFiftyOneInput,
+    ObjectDetectionFilesInput,
+    ObjectDetectionImageInput,
+    ObjectDetectionNumpyInput,
+    ObjectDetectionTensorInput,
+)
 
 if _FIFTYONE_AVAILABLE:
     SampleCollection = "fiftyone.core.collections.SampleCollection"
@@ -183,6 +198,365 @@ def from_files(
         **data_module_kwargs,
     )
 
+    @classmethod
+    def from_numpy(
+        cls,
+        train_data: Optional[Collection[np.ndarray]] = None,
+        train_targets: Optional[Sequence[Sequence[Any]]] = None,
+        train_bboxes: Optional[Sequence[Sequence[Dict[str, int]]]] = None,
+        val_data: Optional[Collection[np.ndarray]] = None,
+        val_targets: Optional[Sequence[Sequence[Any]]] = None,
+        val_bboxes: Optional[Sequence[Sequence[Dict[str, int]]]] = None,
+        test_data: Optional[Collection[np.ndarray]] = None,
+        test_targets: Optional[Sequence[Sequence[Any]]] = None,
+        test_bboxes: Optional[Sequence[Sequence[Dict[str, int]]]] = None,
+        predict_data: Optional[Collection[np.ndarray]] = None,
+        target_formatter: Optional[TargetFormatter] = None,
+        input_cls: Type[Input] = ObjectDetectionNumpyInput,
+        transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform,
+        transform_kwargs: Optional[Dict] = None,
+        **data_module_kwargs: Any,
+    ) -> "ObjectDetectionData":
+        """Creates a :class:`~flash.image.detection.data.ObjectDetectionData` object from the given numpy
+        arrays (or lists of arrays) and corresponding lists of bounding boxes and targets.
+
+        The targets can be in any of our
+        :ref:`supported classification target formats <formatting_classification_targets>`.
+        The bounding boxes are expected to be dictionaries with integer values (representing pixels) and the following
+        keys: ``xmin``, ``ymin``, ``width``, ``height``.
+        To learn how to customize the transforms applied for each stage, read our
+        :ref:`customizing transforms guide <customizing_transforms>`.
+
+        Args:
+            train_data: The numpy array or list of arrays to use when training.
+            train_targets: The list of lists of targets to use when training.
+            train_bboxes: The list of lists of bounding boxes to use when training.
+            val_data: The numpy array or list of arrays to use when validating.
+            val_targets: The list of lists of targets to use when validating.
+            val_bboxes: The list of lists of bounding boxes to use when validating.
+            test_data: The numpy array or list of arrays to use when testing.
+            test_targets: The list of lists of targets to use when testing.
+            test_bboxes: The list of lists of bounding boxes to use when testing.
+            predict_data: The numpy array or list of arrays to use when predicting.
+            target_formatter: Optionally provide a :class:`~flash.core.data.utilities.classification.TargetFormatter` to
+                control how targets are handled. See :ref:`formatting_classification_targets` for more details.
+            input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data.
+            transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use.
+            transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms.
+            data_module_kwargs: Additional keyword arguments to provide to the
+                :class:`~flash.core.data.data_module.DataModule` constructor.
+
+        Returns:
+            The constructed :class:`~flash.image.detection.data.ObjectDetectionData`.
+
+        Examples
+        ________
+
+        .. doctest::
+
+            >>> import numpy as np
+            >>> from flash import Trainer
+            >>> from flash.image import ObjectDetector, ObjectDetectionData
+            >>> datamodule = ObjectDetectionData.from_numpy(
+            ...     train_data=[np.random.rand(3, 64, 64), np.random.rand(3, 64, 64), np.random.rand(3, 64, 64)],
+            ...     train_targets=[["cat"], ["dog"], ["cat"]],
+            ...     train_bboxes=[
+            ...         [{"xmin": 10, "ymin": 20, "width": 5, "height": 10}],
+            ...         [{"xmin": 20, "ymin": 30, "width": 10, "height": 10}],
+            ...         [{"xmin": 10, "ymin": 20, "width": 5, "height": 25}],
+            ...     ],
+            ...     predict_data=[np.random.rand(3, 64, 64)],
+            ...     transform_kwargs=dict(image_size=(128, 128)),
+            ...     batch_size=2,
+            ... )
+            >>> datamodule.num_classes
+            3
+            >>> datamodule.labels
+            ['background', 'cat', 'dog']
+            >>> model = ObjectDetector(labels=datamodule.labels)
+            >>> trainer = Trainer(fast_dev_run=True)
+            >>> trainer.fit(model, datamodule=datamodule)  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
+            Training...
+            >>> trainer.predict(model, datamodule=datamodule)  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
+            Predicting...
+        """
+
+        ds_kw = dict(
+            target_formatter=target_formatter,
+        )
+
+        train_input = input_cls(
+            RunningStage.TRAINING,
+            train_data,
+            train_targets,
+            train_bboxes,
+            **ds_kw,
+        )
+        ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None)
+
+        return cls(
+            train_input,
+            input_cls(
+                RunningStage.VALIDATING,
+                val_data,
+                val_targets,
+                val_bboxes,
+                **ds_kw,
+            ),
+            input_cls(
+                RunningStage.TESTING,
+                test_data,
+                test_targets,
+                test_bboxes,
+                **ds_kw,
+            ),
+            input_cls(RunningStage.PREDICTING, predict_data, **ds_kw),
+            transform=transform,
+            transform_kwargs=transform_kwargs,
+            **data_module_kwargs,
+        )
+
+    @classmethod
+    def from_images(
+        cls,
+        train_images: Optional[List[Image.Image]] = None,
+        train_targets: Optional[Sequence[Sequence[Any]]] = None,
+        train_bboxes: Optional[Sequence[Sequence[Dict[str, int]]]] = None,
+        val_images: Optional[List[Image.Image]] = None,
+        val_targets: Optional[Sequence[Sequence[Any]]] = None,
+        val_bboxes: Optional[Sequence[Sequence[Dict[str, int]]]] = None,
+        test_images: Optional[List[Image.Image]] = None,
+        test_targets: Optional[Sequence[Sequence[Any]]] = None,
+        test_bboxes: Optional[Sequence[Sequence[Dict[str, int]]]] = None,
+        predict_images: Optional[List[Image.Image]] = None,
+        target_formatter: Optional[TargetFormatter] = None,
+        input_cls: Type[Input] = ObjectDetectionImageInput,
+        transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform,
+        transform_kwargs: Optional[Dict] = None,
+        **data_module_kwargs: Any,
+    ) -> "ObjectDetectionData":
+        """Creates a :class:`~flash.image.detection.data.ObjectDetectionData` object from the given lists of PIL
+        images and corresponding lists of bounding boxes and targets.
+
+        The targets can be in any of our
+        :ref:`supported classification target formats <formatting_classification_targets>`.
+        The bounding boxes are expected to be dictionaries with integer values (representing pixels) and the following
+        keys: ``xmin``, ``ymin``, ``width``, ``height``.
+        To learn how to customize the transforms applied for each stage, read our
+        :ref:`customizing transforms guide <customizing_transforms>`.
+
+        Args:
+            train_images: The list of PIL images to use when training.
+            train_targets: The list of lists of targets to use when training.
+            train_bboxes: The list of lists of bounding boxes to use when training.
+            val_images: The list of PIL images to use when validating.
+            val_targets: The list of lists of targets to use when validating.
+            val_bboxes: The list of lists of bounding boxes to use when validating.
+            test_images: The list of PIL images to use when testing.
+            test_targets: The list of lists of targets to use when testing.
+            test_bboxes: The list of lists of bounding boxes to use when testing.
+            predict_images: The list of PIL images to use when predicting.
+            target_formatter: Optionally provide a :class:`~flash.core.data.utilities.classification.TargetFormatter` to
+                control how targets are handled. See :ref:`formatting_classification_targets` for more details.
+            input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data.
+            transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use.
+            transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms.
+            data_module_kwargs: Additional keyword arguments to provide to the
+                :class:`~flash.core.data.data_module.DataModule` constructor.
+
+        Returns:
+            The constructed :class:`~flash.image.detection.data.ObjectDetectionData`.
+
+        Examples
+        ________
+
+        .. doctest::
+
+            >>> from PIL import Image
+            >>> import numpy as np
+            >>> from flash import Trainer
+            >>> from flash.image import ObjectDetector, ObjectDetectionData
+            >>> datamodule = ObjectDetectionData.from_images(
+            ...     train_images=[
+            ...         Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype="uint8")),
+            ...         Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype="uint8")),
+            ...         Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype="uint8")),
+            ...     ],
+            ...     train_targets=[["cat"], ["dog"], ["cat"]],
+            ...     train_bboxes=[
+            ...         [{"xmin": 10, "ymin": 20, "width": 5, "height": 10}],
+            ...         [{"xmin": 20, "ymin": 30, "width": 10, "height": 10}],
+            ...         [{"xmin": 10, "ymin": 20, "width": 5, "height": 25}],
+            ...     ],
+            ...     predict_images=[Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype="uint8"))],
+            ...     transform_kwargs=dict(image_size=(128, 128)),
+            ...     batch_size=2,
+            ... )
+            >>> datamodule.num_classes
+            3
+            >>> datamodule.labels
+            ['background', 'cat', 'dog']
+            >>> model = ObjectDetector(labels=datamodule.labels)
+            >>> trainer = Trainer(fast_dev_run=True)
+            >>> trainer.fit(model, datamodule=datamodule)  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
+            Training...
+            >>> trainer.predict(model, datamodule=datamodule)  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
+            Predicting...
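+
+        As a sketch, boxes in corner format ``(x1, y1, x2, y2)`` can be converted to the dictionary
+        format expected here (the ``corners_to_bbox`` helper is illustrative, not part of the Flash API):
+
+        .. code-block:: python
+
+            def corners_to_bbox(x1, y1, x2, y2):
+                # Flash detection inputs expect pixel values keyed by xmin, ymin, width, and height.
+                return {"xmin": x1, "ymin": y1, "width": x2 - x1, "height": y2 - y1}
+
+            train_bboxes = [[corners_to_bbox(10, 20, 15, 30)]]  # one box for the first image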
+        """
+
+        ds_kw = dict(
+            target_formatter=target_formatter,
+        )
+
+        train_input = input_cls(
+            RunningStage.TRAINING,
+            train_images,
+            train_targets,
+            train_bboxes,
+            **ds_kw,
+        )
+        ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None)
+
+        return cls(
+            train_input,
+            input_cls(
+                RunningStage.VALIDATING,
+                val_images,
+                val_targets,
+                val_bboxes,
+                **ds_kw,
+            ),
+            input_cls(
+                RunningStage.TESTING,
+                test_images,
+                test_targets,
+                test_bboxes,
+                **ds_kw,
+            ),
+            input_cls(RunningStage.PREDICTING, predict_images, **ds_kw),
+            transform=transform,
+            transform_kwargs=transform_kwargs,
+            **data_module_kwargs,
+        )
+
+    @classmethod
+    def from_tensors(
+        cls,
+        train_data: Optional[Collection[torch.Tensor]] = None,
+        train_targets: Optional[Sequence[Sequence[Any]]] = None,
+        train_bboxes: Optional[Sequence[Sequence[Dict[str, int]]]] = None,
+        val_data: Optional[Collection[torch.Tensor]] = None,
+        val_targets: Optional[Sequence[Sequence[Any]]] = None,
+        val_bboxes: Optional[Sequence[Sequence[Dict[str, int]]]] = None,
+        test_data: Optional[Collection[torch.Tensor]] = None,
+        test_targets: Optional[Sequence[Sequence[Any]]] = None,
+        test_bboxes: Optional[Sequence[Sequence[Dict[str, int]]]] = None,
+        predict_data: Optional[Collection[torch.Tensor]] = None,
+        target_formatter: Optional[TargetFormatter] = None,
+        input_cls: Type[Input] = ObjectDetectionTensorInput,
+        transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform,
+        transform_kwargs: Optional[Dict] = None,
+        **data_module_kwargs: Any,
+    ) -> "ObjectDetectionData":
+        """Creates a :class:`~flash.image.detection.data.ObjectDetectionData` object from the given torch
+        tensors (or lists of tensors) and corresponding lists of bounding boxes and targets.
+
+        The targets can be in any of our
+        :ref:`supported classification target formats <formatting_classification_targets>`.
+        The bounding boxes are expected to be dictionaries with integer values (representing pixels) and the following
+        keys: ``xmin``, ``ymin``, ``width``, ``height``.
+        To learn how to customize the transforms applied for each stage, read our
+        :ref:`customizing transforms guide <customizing_transforms>`.
+
+        Args:
+            train_data: The torch tensor or list of tensors to use when training.
+            train_targets: The list of lists of targets to use when training.
+            train_bboxes: The list of lists of bounding boxes to use when training.
+            val_data: The torch tensor or list of tensors to use when validating.
+            val_targets: The list of lists of targets to use when validating.
+            val_bboxes: The list of lists of bounding boxes to use when validating.
+            test_data: The torch tensor or list of tensors to use when testing.
+            test_targets: The list of lists of targets to use when testing.
+            test_bboxes: The list of lists of bounding boxes to use when testing.
+            predict_data: The torch tensor or list of tensors to use when predicting.
+            target_formatter: Optionally provide a :class:`~flash.core.data.utilities.classification.TargetFormatter` to
+                control how targets are handled. See :ref:`formatting_classification_targets` for more details.
+            input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data.
+            transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use.
+            transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms.
+            data_module_kwargs: Additional keyword arguments to provide to the
+                :class:`~flash.core.data.data_module.DataModule` constructor.
+
+        Returns:
+            The constructed :class:`~flash.image.detection.data.ObjectDetectionData`.
+
+        Examples
+        ________
+
+        .. doctest::
+
+            >>> import torch
+            >>> from flash import Trainer
+            >>> from flash.image import ObjectDetector, ObjectDetectionData
+            >>> datamodule = ObjectDetectionData.from_tensors(
+            ...     train_data=[torch.rand(3, 64, 64), torch.rand(3, 64, 64), torch.rand(3, 64, 64)],
+            ...     train_targets=[["cat"], ["dog"], ["cat"]],
+            ...     train_bboxes=[
+            ...         [{"xmin": 10, "ymin": 20, "width": 5, "height": 10}],
+            ...         [{"xmin": 20, "ymin": 30, "width": 10, "height": 10}],
+            ...         [{"xmin": 10, "ymin": 20, "width": 5, "height": 25}],
+            ...     ],
+            ...     predict_data=[torch.rand(3, 64, 64)],
+            ...     transform_kwargs=dict(image_size=(128, 128)),
+            ...     batch_size=2,
+            ... )
+            >>> datamodule.num_classes
+            3
+            >>> datamodule.labels
+            ['background', 'cat', 'dog']
+            >>> model = ObjectDetector(labels=datamodule.labels)
+            >>> trainer = Trainer(fast_dev_run=True)
+            >>> trainer.fit(model, datamodule=datamodule)  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
+            Training...
+            >>> trainer.predict(model, datamodule=datamodule)  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
+            Predicting...
+        """
+
+        ds_kw = dict(
+            target_formatter=target_formatter,
+        )
+
+        train_input = input_cls(
+            RunningStage.TRAINING,
+            train_data,
+            train_targets,
+            train_bboxes,
+            **ds_kw,
+        )
+        ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None)
+
+        return cls(
+            train_input,
+            input_cls(
+                RunningStage.VALIDATING,
+                val_data,
+                val_targets,
+                val_bboxes,
+                **ds_kw,
+            ),
+            input_cls(
+                RunningStage.TESTING,
+                test_data,
+                test_targets,
+                test_bboxes,
+                **ds_kw,
+            ),
+            input_cls(RunningStage.PREDICTING, predict_data, **ds_kw),
+            transform=transform,
+            transform_kwargs=transform_kwargs,
+            **data_module_kwargs,
+        )
+
     @classmethod
     def from_icedata(
         cls,
diff --git a/flash/image/detection/input.py b/flash/image/detection/input.py
index c75d137943..d8a0d60d5b 100644
--- a/flash/image/detection/input.py
+++ b/flash/image/detection/input.py
@@ -17,10 +17,18 @@
 from flash.core.data.io.input import DataKeys
 from flash.core.data.utilities.classification import TargetFormatter
 from flash.core.data.utilities.paths import filter_valid_files, PATH_TYPE
+from flash.core.data.utilities.samples import to_samples
 from flash.core.integrations.fiftyone.utils import FiftyOneLabelUtilities
 from flash.core.integrations.icevision.data import IceVisionInput
 from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _ICEVISION_AVAILABLE, lazy_import, requires
-from flash.image.data import ImageFilesInput, IMG_EXTENSIONS, NP_EXTENSIONS
+from flash.image.data import (
+    ImageFilesInput,
+    ImageInput,
+    ImageNumpyInput,
+    ImageTensorInput,
+    IMG_EXTENSIONS,
+    NP_EXTENSIONS,
+)
 
 if _FIFTYONE_AVAILABLE:
     fol = lazy_import("fiftyone.core.labels")
@@ -69,6 +77,90 @@ def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
         return sample
 
 
+class ObjectDetectionNumpyInput(ClassificationInputMixin, ImageNumpyInput):
+    def load_data(
+        self,
+        array: Any,
+        targets: Optional[List[List[Any]]] = None,
+        bboxes: Optional[List[List[Dict[str, int]]]] = None,
+        target_formatter: Optional[TargetFormatter] = None,
+    ) -> List[Dict[str, Any]]:
+        if targets is None:
+            return to_samples(array)
+        self.load_target_metadata(
+            [t for target in targets for t in target], add_background=True, target_formatter=target_formatter
+        )
+
+        return [
+            {DataKeys.INPUT: image, DataKeys.TARGET: {"bboxes": bbox, "labels": label}}
+            for image, label, bbox in zip(array, targets, bboxes)
+        ]
+
+    def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
+        sample = super().load_sample(sample)
+        if DataKeys.TARGET in sample:
+            sample[DataKeys.TARGET]["labels"] = [
+                self.format_target(label) for label in sample[DataKeys.TARGET]["labels"]
+            ]
+        return sample
+
+
+class ObjectDetectionImageInput(ClassificationInputMixin, ImageInput):
+    def load_data(
+        self,
+        images: Any,
+        targets: Optional[List[List[Any]]] = None,
+        bboxes: Optional[List[List[Dict[str, int]]]] = None,
+        target_formatter: Optional[TargetFormatter] = None,
+    ) -> List[Dict[str, Any]]:
+        if targets is None:
+            return to_samples(images)
+        self.load_target_metadata(
+            [t for target in targets for t in target], add_background=True, target_formatter=target_formatter
+        )
+
+        return [
+            {DataKeys.INPUT: image, DataKeys.TARGET: {"bboxes": bbox, "labels": label}}
+            for image, label, bbox in zip(images, targets, bboxes)
+        ]
+
+    def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
+        sample = super().load_sample(sample)
+        if DataKeys.TARGET in sample:
+            sample[DataKeys.TARGET]["labels"] = [
+                self.format_target(label) for label in sample[DataKeys.TARGET]["labels"]
+            ]
+        return sample
+
+
+class ObjectDetectionTensorInput(ClassificationInputMixin, ImageTensorInput):
+    def load_data(
+        self,
+        tensor: Any,
+        targets: Optional[List[List[Any]]] = None,
+        bboxes: Optional[List[List[Dict[str, int]]]] = None,
+        target_formatter: Optional[TargetFormatter] = None,
+    ) -> List[Dict[str, Any]]:
+        if targets is None:
+            return to_samples(tensor)
+        self.load_target_metadata(
+            [t for target in targets for t in target], add_background=True, target_formatter=target_formatter
+        )
+
+        return [
+            {DataKeys.INPUT: image, DataKeys.TARGET: {"bboxes": bbox, "labels": label}}
+            for image, label, bbox in zip(tensor, targets, bboxes)
+        ]
+
+    def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
+        sample = super().load_sample(sample)
+        if DataKeys.TARGET in sample:
+            sample[DataKeys.TARGET]["labels"] = [
+                self.format_target(label) for label in sample[DataKeys.TARGET]["labels"]
+            ]
+        return sample
+
+
 class FiftyOneParser(Parser):
     def __init__(self, data, class_map, label_field, iscrowd):
         template_record = ObjectDetectionRecord()