Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Add more ways to load image data for classification and detection (#1372)
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanwharris authored Jul 1, 2022
1 parent d751b83 commit ff92d44
Show file tree
Hide file tree
Showing 5 changed files with 590 additions and 5 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added support for Flash serve to the `ObjectDetector` ([#1370](https://github.com/PyTorchLightning/lightning-flash/pull/1370))

- Added support for loading `ImageClassificationData` from PIL images with `from_images` ([#1372](https://github.com/PyTorchLightning/lightning-flash/pull/1372))

- Added support for loading `ObjectDetectionData` with `from_numpy`, `from_images`, and `from_tensors` ([#1372](https://github.com/PyTorchLightning/lightning-flash/pull/1372))

### Changed

- Changed the `ImageEmbedder` dependency on VISSL to optional ([#1276](https://github.com/PyTorchLightning/lightning-flash/pull/1276))
Expand Down
93 changes: 93 additions & 0 deletions flash/image/classification/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
ImageClassificationFiftyOneInput,
ImageClassificationFilesInput,
ImageClassificationFolderInput,
ImageClassificationImageInput,
ImageClassificationNumpyInput,
ImageClassificationTensorInput,
)
Expand All @@ -64,6 +65,7 @@
"ImageClassificationData.from_files",
"ImageClassificationData.from_folders",
"ImageClassificationData.from_numpy",
"ImageClassificationData.from_images",
"ImageClassificationData.from_tensors",
"ImageClassificationData.from_data_frame",
"ImageClassificationData.from_csv",
Expand Down Expand Up @@ -385,6 +387,97 @@ def from_numpy(
**data_module_kwargs,
)

@classmethod
def from_images(
    cls,
    train_images: Optional[List[Image.Image]] = None,
    train_targets: Optional[Sequence[Any]] = None,
    val_images: Optional[List[Image.Image]] = None,
    val_targets: Optional[Sequence[Any]] = None,
    test_images: Optional[List[Image.Image]] = None,
    test_targets: Optional[Sequence[Any]] = None,
    predict_images: Optional[List[Image.Image]] = None,
    target_formatter: Optional[TargetFormatter] = None,
    input_cls: Type[Input] = ImageClassificationImageInput,
    transform: INPUT_TRANSFORM_TYPE = ImageClassificationInputTransform,
    transform_kwargs: Optional[Dict] = None,
    **data_module_kwargs: Any,
) -> "ImageClassificationData":
    """Build an :class:`~flash.image.classification.data.ImageClassificationData` from in-memory PIL images and
    their matching targets.

    Every stage is described by a list of PIL images and, except for prediction, a sequence of targets. Targets
    may use any of our :ref:`supported classification target formats <formatting_classification_targets>`.
    For details on customizing the per-stage transforms, see our
    :ref:`customizing transforms guide <customizing_transforms>`.

    Args:
        train_images: PIL images for the training stage.
        train_targets: Targets for the training stage.
        val_images: PIL images for the validation stage.
        val_targets: Targets for the validation stage.
        test_images: PIL images for the testing stage.
        test_targets: Targets for the testing stage.
        predict_images: PIL images for the prediction stage.
        target_formatter: An optional :class:`~flash.core.data.utilities.classification.TargetFormatter` that
            controls how targets are handled. See :ref:`formatting_classification_targets` for more details.
        input_cls: The :class:`~flash.core.data.io.input.Input` type used to load the data of every stage.
        transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use.
        transform_kwargs: Keyword arguments forwarded when the transforms are instantiated.
        data_module_kwargs: Extra keyword arguments forwarded to the
            :class:`~flash.core.data.data_module.DataModule` constructor.

    Returns:
        The constructed :class:`~flash.image.classification.data.ImageClassificationData`.

    Examples
    ________

    .. doctest::

        >>> from PIL import Image
        >>> import numpy as np
        >>> from flash import Trainer
        >>> from flash.image import ImageClassifier, ImageClassificationData
        >>> datamodule = ImageClassificationData.from_images(
        ...     train_images=[
        ...         Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype="uint8")),
        ...         Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype="uint8")),
        ...         Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype="uint8")),
        ...     ],
        ...     train_targets=["cat", "dog", "cat"],
        ...     predict_images=[Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype="uint8"))],
        ...     transform_kwargs=dict(image_size=(128, 128)),
        ...     batch_size=2,
        ... )
        >>> datamodule.num_classes
        2
        >>> datamodule.labels
        ['cat', 'dog']
        >>> model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes)
        >>> trainer = Trainer(fast_dev_run=True)
        >>> trainer.fit(model, datamodule=datamodule)  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
        Training...
        >>> trainer.predict(model, datamodule=datamodule)  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
        Predicting...
    """
    # The training input is built first, using the formatter supplied by the caller (possibly ``None``).
    train_input = input_cls(
        RunningStage.TRAINING,
        train_images,
        train_targets,
        target_formatter=target_formatter,
    )

    # Every other stage re-uses whatever formatter the training input exposes, so that all stages agree on it.
    # Falls back to ``None`` when the training input has no ``target_formatter`` attribute.
    shared_formatter = getattr(train_input, "target_formatter", None)

    val_input = input_cls(RunningStage.VALIDATING, val_images, val_targets, target_formatter=shared_formatter)
    test_input = input_cls(RunningStage.TESTING, test_images, test_targets, target_formatter=shared_formatter)
    predict_input = input_cls(RunningStage.PREDICTING, predict_images, target_formatter=shared_formatter)

    return cls(
        train_input,
        val_input,
        test_input,
        predict_input,
        transform=transform,
        transform_kwargs=transform_kwargs,
        **data_module_kwargs,
    )

@classmethod
def from_tensors(
cls,
Expand Down
24 changes: 23 additions & 1 deletion flash/image/classification/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,14 @@
from flash.core.data.utilities.samples import to_samples
from flash.core.integrations.fiftyone.utils import FiftyOneLabelUtilities
from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, lazy_import, requires
from flash.image.data import ImageFilesInput, ImageNumpyInput, ImageTensorInput, IMG_EXTENSIONS, NP_EXTENSIONS
from flash.image.data import (
ImageFilesInput,
ImageInput,
ImageNumpyInput,
ImageTensorInput,
IMG_EXTENSIONS,
NP_EXTENSIONS,
)

if _FIFTYONE_AVAILABLE:
fol = lazy_import("fiftyone.core.labels")
Expand Down Expand Up @@ -115,6 +122,21 @@ def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
return sample


class ImageClassificationImageInput(ClassificationInputMixin, ImageInput):
    """Classification input over in-memory images with optional targets.

    This is the default ``input_cls`` of ``ImageClassificationData.from_images``, which passes it lists of PIL
    images.
    """

    def load_data(
        self, images: Any, targets: Optional[List[Any]] = None, target_formatter: Optional[TargetFormatter] = None
    ) -> List[Dict[str, Any]]:
        """Pair the images with their targets and return them as samples.

        Args:
            images: The images to load.
            targets: Optional targets, one per image. When given, the target metadata is loaded from them.
            target_formatter: Optional formatter forwarded to ``load_target_metadata``.

        Returns:
            The result of ``to_samples(images, targets)``.
        """
        has_targets = targets is not None
        if has_targets:
            # Target metadata is only registered when targets are actually provided.
            self.load_target_metadata(targets, target_formatter=target_formatter)
        samples = to_samples(images, targets)
        return samples

    def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Load a sample through ``ImageInput.load_sample`` and format its target, if it has one.

        Args:
            sample: The sample to load.

        Returns:
            The loaded sample, with ``DataKeys.TARGET`` passed through ``format_target`` when present.
        """
        loaded = super().load_sample(sample)
        if DataKeys.TARGET not in loaded:
            # Nothing to format for samples without a target.
            return loaded
        loaded[DataKeys.TARGET] = self.format_target(loaded[DataKeys.TARGET])
        return loaded


class ImageClassificationDataFrameInput(ImageClassificationFilesInput):
labels: list

Expand Down
Loading

0 comments on commit ff92d44

Please sign in to comment.