Add YOLO object detection model #552

Merged 86 commits on Sep 10, 2021
Changes from 82 commits
Commits
461de96
Add YOLO object detection model
senarvi Feb 2, 2021
2b9b073
Readability improvements
senarvi Feb 3, 2021
cc42540
Documentation improvements
senarvi Feb 3, 2021
876da0d
Fixed style issues.
senarvi Feb 3, 2021
f99930e
Refactoring
senarvi Feb 5, 2021
4415d41
Refactoring
senarvi Feb 8, 2021
7356fe0
Refactoring
senarvi Feb 8, 2021
39eb80d
Fixed YOLO test.
senarvi Feb 8, 2021
291f4be
Fixed style issues
senarvi Feb 9, 2021
8db7947
Comply with isort rules.
senarvi Feb 9, 2021
2831755
Reading Darknet weights works also with truncated files.
senarvi Feb 9, 2021
eb26eba
Fixed code formatting.
senarvi Feb 9, 2021
9c155a9
Trying to fix Python 3.6 import problem.
senarvi Feb 9, 2021
efeb1c8
Fixed Python 3.6 import error.
senarvi Feb 9, 2021
1a1ecd3
Added YOLO to CHANGELOG.
senarvi Feb 9, 2021
26ff979
Use torch.min() instead of torch.minimum() to avoid error with older …
senarvi Feb 9, 2021
3e9bdde
Generalized interface for custom losses
senarvi Feb 12, 2021
c348619
box_area() implementation copied from torchvision
senarvi Feb 12, 2021
c2d7907
Conform to yapf formatter rules.
senarvi Feb 12, 2021
3d7f440
Removed the unnecessary linter instructions.
senarvi Feb 15, 2021
eb6be46
IoU losses use torchvision
senarvi Feb 15, 2021
3d940f2
Merge branch 'yolo' of github.com:groke-technologies/pytorch-lightnin…
senarvi Feb 15, 2021
cf7420c
Improved strange yapf formatting
senarvi Feb 15, 2021
6c90cd4
Refactoring
senarvi Feb 15, 2021
7aeea63
Merge branch 'yolo' of github.com:groke-technologies/pytorch-lightnin…
senarvi Feb 15, 2021
60fda75
get_deprecated_arg_names() is not needed anymore.
senarvi Feb 15, 2021
b2e3e84
Fixed yapf formatting.
senarvi Feb 15, 2021
940947f
Fixed formatting.
senarvi Feb 15, 2021
6e3d5bf
Removed unused imports.
senarvi Feb 15, 2021
a5bed26
Merge branch 'yolo' of github.com:groke-technologies/pytorch-lightnin…
senarvi Feb 15, 2021
b2ea497
Fixed some type hints.
senarvi Feb 16, 2021
6d8fa7d
Sorted imports.
senarvi Feb 16, 2021
e68df7a
Possible to limit the number of predictions per image
senarvi Feb 24, 2021
f895530
None instead of an empty list as default argument
senarvi Feb 24, 2021
58f1456
Fixed capitalization of YOLO class.
senarvi Feb 24, 2021
19b9df7
Merge branch 'origin/master' into yolo
senarvi Mar 4, 2021
4e6d4cf
No need to check for NaN values as Trainer has terminate_on_nan argum…
senarvi Mar 8, 2021
af3e0e6
YOLO test configuration moved to tests/data/yolo.cfg
senarvi Mar 8, 2021
c4ae5ec
Merge branch 'origin/master' into yolo
senarvi Mar 8, 2021
4012247
Use Optional[] as the default value for transforms is now None
senarvi Mar 8, 2021
c8b76a5
Refactoring and documentation improvements
senarvi Mar 23, 2021
71a4c3c
Fixed documentation formatting
senarvi Mar 24, 2021
da4eace
Merge branch 'origin/master' into yolo
senarvi Mar 24, 2021
70f14b0
Coordinate predictions are in image scale
senarvi Mar 31, 2021
3c1a0fb
Merge branch 'origin/master' into yolo
senarvi Mar 31, 2021
9587b46
Use default dtype for torch.arange() to fix export to TensorRT
senarvi Apr 1, 2021
2b6c552
Network input size can differ from the image size specified in the co…
senarvi Apr 10, 2021
e97b198
Merge branch 'yolo' of github.com:groke-technologies/pytorch-lightnin…
senarvi Apr 10, 2021
dbe1d59
Merge branch 'origin/master' into yolo
senarvi Apr 10, 2021
b66dbd3
Merge branch 'master' into yolo
senarvi May 3, 2021
004d1ce
Use torch.true_divide() instead of /
senarvi May 3, 2021
ad1e48e
Use torch.true_divide() instead of /
senarvi May 5, 2021
a254517
Merge branch 'master' into yolo
senarvi May 13, 2021
1cde4f8
Merge branch 'origin/master' into yolo
senarvi Jun 1, 2021
e92c405
Merge branch 'master' into yolo
senarvi Jun 17, 2021
c237b37
Loss is normalized by batch size only once
senarvi Jun 23, 2021
9b010de
Fixed division by zero when there are no targets in a batch
senarvi Jun 23, 2021
dc7ae4c
Merge branch 'master' into yolo
senarvi Jun 23, 2021
f15282d
Always return all losses to avoid deadlock with DDP when there are no…
senarvi Jun 24, 2021
de52b75
Merge branch 'master' into yolo
senarvi Jun 24, 2021
f6d3476
Hit rates are always logged so don't prefix the names
senarvi Jul 1, 2021
8e12359
Merge branch 'master' into yolo
senarvi Jul 1, 2021
86a6b66
Fixed training loss
senarvi Jul 31, 2021
7fd38ca
Merge branch 'origin/master' into yolo
senarvi Jul 31, 2021
3286533
Truncate nms() inputs to avoid it crashing when too many boxes are de…
senarvi Aug 4, 2021
bb92076
Use sum() instead of count_nonzero() as it's available already before…
senarvi Aug 11, 2021
7d08350
Merge branch 'master' into yolo
senarvi Aug 11, 2021
b896112
Squared error loss takes the sum over the predicted attributes
senarvi Aug 17, 2021
55a1180
Swish and logistic activation functions
senarvi Aug 17, 2021
6fb82c1
Merge branch 'master' into yolo
senarvi Aug 17, 2021
7857bea
Added a comment
senarvi Aug 18, 2021
0804699
Fixed code formatting
senarvi Aug 18, 2021
7b32d64
Ran docformatter with correct config
senarvi Aug 19, 2021
7b316b8
Merge branch 'master' into yolo
senarvi Aug 19, 2021
191e865
Ran pyupgrade
senarvi Aug 19, 2021
66c8111
Reformatted
senarvi Aug 19, 2021
572ca4f
VOCDetectionDataModule constructor takes batch size and the transform…
senarvi Aug 30, 2021
198dd32
Merge branch 'master' into yolo
senarvi Aug 30, 2021
db278be
Use true_divide() for integer division
senarvi Aug 30, 2021
4731b24
Fixed doc and package build without Torchvision
senarvi Aug 30, 2021
f63df0b
Merge branch 'master' into yolo
senarvi Aug 31, 2021
3b70b77
Merge branch 'master' into yolo
senarvi Sep 8, 2021
022192a
YOLO moved to unreleased
senarvi Sep 10, 2021
a8e1ce3
Merge branch 'master' into yolo
senarvi Sep 10, 2021
bfc774c
Code formatting
senarvi Sep 10, 2021
8d4e43e
Merge branch 'yolo' of github.com:groke-technologies/pytorch-lightnin…
senarvi Sep 10, 2021
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -88,6 +88,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ### Added

 - Added Pix2Pix model ([#533](https://github.com/PyTorchLightning/lightning-bolts/pull/533))
+- Added YOLO model ([#552](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/552))

 ### Changed

1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -71,6 +71,7 @@ Lightning-Bolts documentation

    autoencoders
    convolutional
+   object_detection
    gans
    reinforce_learn
    self_supervised_models
20 changes: 20 additions & 0 deletions docs/source/object_detection.rst
@@ -0,0 +1,20 @@
+Object Detection
+================
+This package lists contributed object detection models.
+
+--------------
+
+
+Faster R-CNN
+------------
+
+.. autoclass:: pl_bolts.models.detection.faster_rcnn.faster_rcnn_module.FasterRCNN
+    :noindex:
+
+-------------
+
+YOLO
+----
+
+.. autoclass:: pl_bolts.models.detection.yolo.yolo_module.YOLO
+    :noindex:
77 changes: 35 additions & 42 deletions pl_bolts/datamodules/vocdetection_datamodule.py
@@ -1,9 +1,10 @@
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+import os
+from typing import Any, Callable, Dict, List, Optional, Tuple

 import torch
 from pytorch_lightning import LightningDataModule
 from torch import Tensor
-from torch.utils.data import DataLoader
+from torch.utils.data import DataLoader, Dataset

 from pl_bolts.utils import _TORCHVISION_AVAILABLE
 from pl_bolts.utils.warnings import warn_missing_pkg
@@ -107,10 +108,11 @@ class VOCDetectionDataModule(LightningDataModule):

     def __init__(
         self,
-        data_dir: str,
+        data_dir: Optional[str] = None,
         year: str = "2012",
         num_workers: int = 0,
         normalize: bool = False,
+        batch_size: int = 16,
         shuffle: bool = True,
         pin_memory: bool = True,
         drop_last: bool = False,
@@ -125,9 +127,10 @@ def __init__(
         super().__init__(*args, **kwargs)

         self.year = year
-        self.data_dir = data_dir
+        self.data_dir = data_dir if data_dir is not None else os.getcwd()
         self.num_workers = num_workers
         self.normalize = normalize
+        self.batch_size = batch_size
         self.shuffle = shuffle
         self.pin_memory = pin_memory
         self.drop_last = drop_last
@@ -145,60 +148,50 @@ def prepare_data(self) -> None:
         VOCDetection(self.data_dir, year=self.year, image_set="train", download=True)
         VOCDetection(self.data_dir, year=self.year, image_set="val", download=True)

-    def train_dataloader(
-        self, batch_size: int = 1, image_transforms: Union[List[Callable], Callable] = None
-    ) -> DataLoader:
+    def train_dataloader(self, image_transforms: Optional[Callable] = None) -> DataLoader:
         """VOCDetection train set uses the `train` subset.

         Args:
-            batch_size: size of batch
-            transforms: custom transforms
+            image_transforms: custom image-only transforms
         """
-        transforms = [_prepare_voc_instance]
-        image_transforms = image_transforms or self.train_transforms or self._default_transforms()
+        transforms = [
+            _prepare_voc_instance,
+            self.default_transforms() if self.train_transforms is None else self.train_transforms,
+        ]
         transforms = Compose(transforms, image_transforms)
+
         dataset = VOCDetection(self.data_dir, year=self.year, image_set="train", transforms=transforms)
-        loader = DataLoader(
-            dataset,
-            batch_size=batch_size,
-            shuffle=self.shuffle,
-            num_workers=self.num_workers,
-            drop_last=self.drop_last,
-            pin_memory=self.pin_memory,
-            collate_fn=_collate_fn,
-        )
-        return loader
+        return self._data_loader(dataset, shuffle=self.shuffle)

-    def val_dataloader(self, batch_size: int = 1, image_transforms: Optional[List[Callable]] = None) -> DataLoader:
+    def val_dataloader(self, image_transforms: Optional[Callable] = None) -> DataLoader:
         """VOCDetection val set uses the `val` subset.

         Args:
-            batch_size: size of batch
-            transforms: custom transforms
+            image_transforms: custom image-only transforms
         """
-        transforms = [_prepare_voc_instance]
-        image_transforms = image_transforms or self.train_transforms or self._default_transforms()
+        transforms = [
+            _prepare_voc_instance,
+            self.default_transforms() if self.val_transforms is None else self.val_transforms,
+        ]
         transforms = Compose(transforms, image_transforms)
+
         dataset = VOCDetection(self.data_dir, year=self.year, image_set="val", transforms=transforms)
-        loader = DataLoader(
+        return self._data_loader(dataset, shuffle=False)
+
+    def default_transforms(self) -> Callable:
+        voc_transforms = [transform_lib.ToTensor()]
+        if self.normalize:
+            voc_transforms += [transform_lib.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]
+        voc_transforms = transform_lib.Compose(voc_transforms)
+        return lambda image, target: (voc_transforms(image), target)
+
+    def _data_loader(self, dataset: Dataset, shuffle: bool = False) -> DataLoader:
+        return DataLoader(
             dataset,
-            batch_size=batch_size,
-            shuffle=False,
+            batch_size=self.batch_size,
+            shuffle=shuffle,
             num_workers=self.num_workers,
             drop_last=self.drop_last,
             pin_memory=self.pin_memory,
             collate_fn=_collate_fn,
         )
-        return loader
-
-    def _default_transforms(self) -> Callable:
-        if self.normalize:
-            voc_transforms = transform_lib.Compose(
-                [
-                    transform_lib.ToTensor(),
-                    transform_lib.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
-                ]
-            )
-        else:
-            voc_transforms = transform_lib.Compose([transform_lib.ToTensor()])
-        return voc_transforms
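In this diff, `default_transforms()` now returns a callable that takes `(image, target)` and changes only the image, so it can sit in the same list as `_prepare_voc_instance`. The snippet below is a minimal sketch of that pattern in plain Python, with no torchvision dependency. The names `image_only`, `pair_compose`, `scale_pixels`, and `rename_boxes` are hypothetical stand-ins for illustration and are not part of pl_bolts; the real `Compose` used in this file is not shown in the diff and may behave differently.

```python
from typing import Any, Callable, List, Tuple

# An (image, target) -> (image, target) transform, as used by detection datasets.
PairTransform = Callable[[Any, Any], Tuple[Any, Any]]


def image_only(image_transform: Callable[[Any], Any]) -> PairTransform:
    """Wrap an image-only transform so it accepts and returns (image, target)."""
    return lambda image, target: (image_transform(image), target)


def pair_compose(transforms: List[PairTransform]) -> PairTransform:
    """Chain (image, target) transforms, feeding each output pair into the next."""

    def composed(image: Any, target: Any) -> Tuple[Any, Any]:
        for transform in transforms:
            image, target = transform(image, target)
        return image, target

    return composed


def scale_pixels(image: List[int]) -> List[float]:
    """Toy stand-in for an image transform: scale 0..255 integers to 0..1 floats."""
    return [pixel / 255 for pixel in image]


def rename_boxes(image: Any, target: dict) -> Tuple[Any, dict]:
    """Toy stand-in for a target-preparing step that leaves the image untouched."""
    return image, {"boxes": target["bndbox"]}


# Target preparation first, then the wrapped image-only transform.
pipeline = pair_compose([rename_boxes, image_only(scale_pixels)])
image, target = pipeline([0, 255], {"bndbox": [[1, 2, 3, 4]]})
```

Wrapping the image-only transform keeps every element of the list on the same two-argument signature, which is why the target passes through unchanged.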
7 changes: 3 additions & 4 deletions pl_bolts/models/detection/__init__.py
@@ -1,7 +1,6 @@
 from pl_bolts.models.detection import components
 from pl_bolts.models.detection.faster_rcnn import FasterRCNN
+from pl_bolts.models.detection.yolo.yolo_config import YOLOConfiguration
+from pl_bolts.models.detection.yolo.yolo_module import YOLO

-__all__ = [
-    "components",
-    "FasterRCNN",
-]
+__all__ = ["components", "FasterRCNN", "YOLOConfiguration", "YOLO"]
@@ -147,9 +147,8 @@ def run_cli():

     seed_everything(42)
     parser = ArgumentParser()
+    parser = VOCDetectionDataModule.add_argparse_args(parser)
     parser = Trainer.add_argparse_args(parser)
-    parser.add_argument("--data_dir", type=str, default=".")
-    parser.add_argument("--batch_size", type=int, default=1)
     parser = FasterRCNN.add_model_specific_args(parser)

     args = parser.parse_args()
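The change above drops the hand-written `--data_dir` and `--batch_size` options and lets the data module register its own arguments, which works now that `batch_size` is a constructor parameter. The following is a rough sketch of that idea using only the standard library. `ToyDataModule`, `add_args`, and `from_args` are hypothetical names, and the sketch does not reproduce how Lightning's `add_argparse_args` is actually implemented; it only illustrates deriving CLI options from a constructor signature.

```python
import inspect
from argparse import ArgumentParser, Namespace


class ToyDataModule:
    """Hypothetical stand-in for a data module; not the pl_bolts class."""

    def __init__(self, data_dir: str = ".", batch_size: int = 16, num_workers: int = 0):
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers

    @classmethod
    def _init_params(cls) -> dict:
        """Constructor parameters, excluding `self`."""
        params = inspect.signature(cls.__init__).parameters
        return {name: p for name, p in params.items() if name != "self"}

    @classmethod
    def add_args(cls, parser: ArgumentParser) -> ArgumentParser:
        """Register one --option per constructor parameter, typed from its default."""
        for name, param in cls._init_params().items():
            parser.add_argument(f"--{name}", type=type(param.default), default=param.default)
        return parser

    @classmethod
    def from_args(cls, args: Namespace) -> "ToyDataModule":
        """Build an instance from parsed arguments."""
        return cls(**{name: getattr(args, name) for name in cls._init_params()})


parser = ToyDataModule.add_args(ArgumentParser())
args = parser.parse_args(["--batch_size", "4"])
datamodule = ToyDataModule.from_args(args)
```

With this arrangement the defaults live in one place, the constructor, so the CLI and programmatic use cannot drift apart.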