Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Merge branch 'optional_input_108' of https://github.com/aniketmaurya/…
Browse files Browse the repository at this point in the history
…lightning-flash into optional_input_108
  • Loading branch information
Aniket Maurya committed Feb 12, 2021
2 parents 2f8e716 + b033748 commit f4e8750
Show file tree
Hide file tree
Showing 12 changed files with 194 additions and 29 deletions.
47 changes: 42 additions & 5 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,56 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).


## [0.1.0] - 02/02/2021
## [Unreleased] - 2021-MM-DD

### Added

- Added flash_notebook examples ([#9](https://github.com/PyTorchLightning/lightning-flash/pull/9))
- Added `strategy` to `trainer.finetune` with `NoFreeze`, `Freeze`, `FreezeUnfreeze`, `UnfreezeMilestones` Callbacks([#39](https://github.com/PyTorchLightning/lightning-flash/pull/39))
- Added `SummarizationData`, `SummarizationTask` and `TranslationData`, `TranslationTask` ([#37](https://github.com/PyTorchLightning/lightning-flash/pull/37))
- Added `ImageEmbedder` ([#36](https://github.com/PyTorchLightning/lightning-flash/pull/36))


### Changed



### Fixed



### Removed




## [0.2.0] - 2021-02-12

### Added

- Added `ObjectDetector` Task ([#56](https://github.com/PyTorchLightning/lightning-flash/pull/56))
- Added TabNet for tabular classification ([#101](https://github.com/PyTorchLightning/lightning-flash/pull/#101))
- Added support for more backbones(mobilnet, vgg, densenet, resnext) ([#45](https://github.com/PyTorchLightning/lightning-flash/pull/45))
- Added backbones for image embedding model ([#63](https://github.com/PyTorchLightning/lightning-flash/pull/63))
- Added SWAV and SimCLR models to `imageclassifier` + backbone reorg ([#68](https://github.com/PyTorchLightning/lightning-flash/pull/68))

### Changed

- Applied transform in `FilePathDataset` ([#97](https://github.com/PyTorchLightning/lightning-flash/pull/97))
- Moved classification integration from vision root to folder ([#86](https://github.com/PyTorchLightning/lightning-flash/pull/86))

### Fixed

- Unfreeze default number of workers in datamodule ([#57](https://github.com/PyTorchLightning/lightning-flash/pull/57))
- Fixed wrong label in `FilePathDataset` ([#94](https://github.com/PyTorchLightning/lightning-flash/pull/94))

### Removed

- Removed `densenet161` duplicate in `DENSENET_MODELS` ([#76](https://github.com/PyTorchLightning/lightning-flash/pull/76))
- Removed redundant `num_features` arg from Classification model ([#88](https://github.com/PyTorchLightning/lightning-flash/pull/88))


## [0.1.0] - 2021-02-02

### Added

- Added flash_notebook examples ([#9](https://github.com/PyTorchLightning/lightning-flash/pull/9))
- Added `strategy` to `trainer.finetune` with `NoFreeze`, `Freeze`, `FreezeUnfreeze`, `UnfreezeMilestones` Callbacks([#39](https://github.com/PyTorchLightning/lightning-flash/pull/39))
- Added `SummarizationData`, `SummarizationTask` and `TranslationData`, `TranslationTask` ([#37](https://github.com/PyTorchLightning/lightning-flash/pull/37))
- Added `ImageEmbedder` ([#36](https://github.com/PyTorchLightning/lightning-flash/pull/36))
1 change: 1 addition & 0 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
'sphinx.ext.intersphinx',
# 'sphinx.ext.todo',
# 'sphinx.ext.coverage',
'sphinx.ext.viewcode',
'sphinx.ext.autosummary',
'sphinx.ext.napoleon',
'sphinx.ext.imgmath',
Expand Down
1 change: 1 addition & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ Lightning Flash
reference/text_classification
reference/tabular_classification
reference/translation
reference/object_detection

.. toctree::
:maxdepth: 1
Expand Down
5 changes: 2 additions & 3 deletions docs/source/reference/image_embedder.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ Use the :class:`~flash.vision.ImageEmbedder` pretrained model for inference on a
embedder = ImageEmbedder(backbone="resnet18")
# 2. Perform inference on an image file
embeddings = model.predict("path/to/image.png")
embeddings = embedder.predict("path/to/image.png")
print(embeddings)
Or on a random image tensor
Expand Down Expand Up @@ -91,13 +91,12 @@ By default, we use the encoder from `SwAV <https://arxiv.org/pdf/2006.09882.pdf>

.. note::

When changing the backbone, make sure you pass in the same backbone to the Task and the Data object!
When changing the backbone, make sure you pass in the same backbone to the Task!

.. code-block:: python
# 1. organize the data
data = ImageClassificationData.from_folders(
backbone="resnet34",
train_folder="data/hymenoptera_data/train/",
valid_folder="data/hymenoptera_data/val/"
)
Expand Down
132 changes: 132 additions & 0 deletions docs/source/reference/object_detection.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@

.. _object_detection:

################
Object Detection
################

********
The task
********

The object detection task identifies instances of objects of a certain class within an image.

------

*********
Inference
*********

The :class:`~flash.vision.ObjectDetector` is already pre-trained on `COCO train2017 <https://cocodataset.org/>`_, a dataset with `91 classes <https://cocodataset.org/#explore>`_ (123,287 images, 886,284 instances).

.. code-block::
annotation{
"id": int,
"image_id": int,
"category_id": int,
"segmentation": RLE or [polygon],
"area": float,
"bbox": [x,y,width,height],
"iscrowd": 0 or 1,
}
categories[{
"id": int,
"name": str,
"supercategory": str,
}]
Use the :class:`~flash.vision.ObjectDetector` pretrained model for inference on any image tensor or image path using :func:`~flash.vision.ObjectDetector.predict`:

.. code-block:: python
from flash.vision import ObjectDetector
# 1. Load the model
detector = ObjectDetector()
# 2. Perform inference on an image file
predictions = detector.predict("path/to/image.png")
print(predictions)
Or on a random image tensor

.. code-block:: python
# Perform inference on a random image tensor
import torch
images = torch.rand(32, 3, 1080, 1920)
predictions = detector.predict(images)
print(predictions)
For more advanced inference options, see :ref:`predictions`.

------

**********
Finetuning
**********

To tailor the object detector to your dataset, you would need to have it in `COCO Format <https://cocodataset.org/#format-data>`_, and then finetune the model.

.. code-block:: python
import flash
from flash.core.data import download_data
from flash.vision import ObjectDetectionData, ObjectDetector
# 1. Download the data
# Dataset Credit: https://www.kaggle.com/ultralytics/coco128
download_data("https://github.com/zhiqwang/yolov5-rt-stack/releases/download/v0.3.0/coco128.zip", "data/")
# 2. Load the Data
datamodule = ObjectDetectionData.from_coco(
train_folder="data/coco128/images/train2017/",
train_ann_file="data/coco128/annotations/instances_train2017.json",
batch_size=2
)
# 3. Build the model
model = ObjectDetector(num_classes=datamodule.num_classes)
# 4. Create the trainer. Run thrice on data
trainer = flash.Trainer(max_epochs=3)
# 5. Finetune the model
trainer.finetune(model, datamodule)
# 6. Save it!
trainer.save_checkpoint("object_detection_model.pt")
------

*****
Model
*****

By default, we use the `Faster R-CNN <https://arxiv.org/abs/1506.01497>`_ model with a ResNet-50 FPN backbone. The inputs could be images of different sizes. The model behaves differently for training and evaluation. For training, it expects both the input tensors as well as the targets. And during evaluation, it expects only the input tensors and returns predictions for each image. The predictions are a list of boxes, labels and scores.

------

*************
API reference
*************

.. _object_detector:

ObjectDetector
--------------

.. autoclass:: flash.vision.ObjectDetector
:members:
:exclude-members: forward

.. _object_detection_data:

ObjectDetectionData
-------------------

.. autoclass:: flash.vision.ObjectDetectionData

.. automethod:: flash.vision.ObjectDetectionData.from_coco
12 changes: 10 additions & 2 deletions flash/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
"""Root package info."""
import os

__version__ = "0.2.0rc1"
__version__ = "0.2.1-dev"
__author__ = "PyTorchLightning et al."
__author_email__ = "name@pytorchlightning.ai"
__license__ = 'Apache-2.0'
Expand Down Expand Up @@ -56,5 +56,13 @@
from flash.core.trainer import Trainer

__all__ = [
"Task", "ClassificationTask", "DataModule", "vision", "text", "tabular", "data", "utils", "download_data"
"Task",
"ClassificationTask",
"DataModule",
"vision",
"text",
"tabular",
"data",
"utils",
"download_data",
]
3 changes: 1 addition & 2 deletions flash/tabular/classification/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, List, Optional, Tuple, Type, Union
from typing import Any, Callable, List, Optional, Tuple, Type

import torch
from pytorch_lightning.metrics import Metric
from pytorch_tabnet.tab_network import TabNet
from torch import nn
from torch.nn import functional as F

from flash.core.classification import ClassificationTask
Expand Down
14 changes: 5 additions & 9 deletions flash/vision/detection/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Mapping, Optional, Sequence, Type, Union
from typing import Any, Callable, Mapping, Sequence, Type, Union

import torch
import torchvision
Expand All @@ -20,7 +20,6 @@
from torchvision.ops import box_iou

from flash.core import Task
from flash.core.data import DataPipeline
from flash.vision.detection.data import ObjectDetectionDataPipeline
from flash.vision.detection.finetuning import ObjectDetectionFineTuning

Expand All @@ -29,8 +28,7 @@

def _evaluate_iou(target, pred):
"""
Evaluate intersection over union (IOU) for target from dataset and output prediction
from model
Evaluate intersection over union (IOU) for target from dataset and output prediction from model
"""
if pred["boxes"].shape[0] == 0:
# no box detected, 0 IOU
Expand All @@ -42,17 +40,16 @@ class ObjectDetector(Task):
"""Image detection task
Ref: Lightning Bolts https://github.com/PyTorchLightning/pytorch-lightning-bolts
Args:
num_classes: the number of classes for detection, including background
model: either a string of :attr`_models` or a custom nn.Module.
Defaults to 'fasterrcnn_resnet50_fpn'.
loss: the function(s) to update the model with. Has no effect for torchvision detection models.
metrics: The provided metrics. All metrics here will be logged to progress bar and the respective logger.
Defaults to None.
optimizer: The optimizer to use for training. Can either be the actual class or the class name.
Defaults to Adam.
pretrained: Whether the model from torchvision should be loaded with it's pretrained weights.
Has no effect for custom models. Defaults to True.
Has no effect for custom models.
learning_rate: The learning rate to use for training
"""
Expand Down Expand Up @@ -89,8 +86,7 @@ def __init__(
)

def training_step(self, batch, batch_idx) -> Any:
"""The training step.
Overrides Task.training_step
"""The training step. Overrides ``Task.training_step``
"""
images, targets = batch
targets = [{k: v for k, v in t.items()} for t in targets]
Expand Down
1 change: 0 additions & 1 deletion flash/vision/embedding/image_embedder_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from typing import Any, Callable, Mapping, Optional, Sequence, Type, Union

import torch
import torchvision
from pytorch_lightning.metrics import Accuracy
from pytorch_lightning.utilities.distributed import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
Expand Down
1 change: 0 additions & 1 deletion tests/vision/classification/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.
import pytest
import torch
from pytorch_lightning.utilities.exceptions import MisconfigurationException

from flash import Trainer
from flash.vision import ImageClassifier
Expand Down
4 changes: 0 additions & 4 deletions tests/vision/detection/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,12 @@
from pathlib import Path

import pytest
import torch
from PIL import Image
from pytorch_lightning.utilities import _module_available
from torchvision import transforms as T

from flash.vision.detection.data import ObjectDetectionData

_COCO_AVAILABLE = _module_available("pycocotools")
if _COCO_AVAILABLE:
from pycocotools.coco import COCO


def _create_dummy_coco_json(dummy_json_path):
Expand Down
2 changes: 0 additions & 2 deletions tests/vision/detection/test_data_model_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@
from tests.vision.detection.test_data import _create_synth_coco_dataset

_COCO_AVAILABLE = _module_available("pycocotools")
if _COCO_AVAILABLE:
from pycocotools.coco import COCO


@pytest.mark.skipif(not _COCO_AVAILABLE, reason="pycocotools is not installed for testing")
Expand Down

0 comments on commit f4e8750

Please sign in to comment.