Skip to content

Commit

Permalink
Construct YOLOv5 models with TorchVision MobileNetV3 backbone (#342)
Browse files Browse the repository at this point in the history
* Add yolov5n and yolov5lite for training with ultralytics

* Apply pre-commit

* Minor fixes for docstrings

* Use frcnn layout

* Fix configurations for mobilenetv3 with yolov5

* Apply pre-commit

* Move yolov5lite into yolort.models

* Minor fixes

* Apply pre-commit

* Change to mobilenet_v3_small
  • Loading branch information
zhiqwang authored Mar 5, 2022
1 parent de11966 commit b96c225
Show file tree
Hide file tree
Showing 2 changed files with 162 additions and 0 deletions.
2 changes: 2 additions & 0 deletions yolort/models/box_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ def __init__(
num_classes: int,
):
super().__init__()
if not isinstance(in_channels, list):
in_channels = [in_channels] * num_anchors
self.num_anchors = num_anchors # anchors
self.num_classes = num_classes
self.num_outputs = num_classes + 5 # number of outputs per anchor
Expand Down
160 changes: 160 additions & 0 deletions yolort/models/yolo_lite.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
from torch import nn
from torchvision.models import mobilenet
from torchvision.models._utils import IntermediateLayerGetter
from torchvision.models.detection.backbone_utils import _validate_trainable_layers
from torchvision.ops import misc as misc_nn_ops
from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork, LastLevelMaxPool

from .yolo import YOLO

__all__ = ["yolov5_mobilenet_v3_small_fpn"]


class BackboneWithFPN(nn.Module):
"""
Adds a FPN on top of a model.
Internally, it uses torchvision.models._utils.IntermediateLayerGetter to
extract a submodel that returns the feature maps specified in return_layers.
The same limitations of IntermediateLayerGetter apply here.
Args:
backbone (nn.Module)
return_layers (Dict[name, new_name]): a dict containing the names
of the modules for which the activations will be returned as
the key of the dict, and the value of the dict is the name
of the returned activation (which the user can specify).
in_channels_list (List[int]): number of channels for each feature map
that is returned, in the order they are present in the OrderedDict
out_channels (int): number of channels in the FPN.
Attributes:
out_channels (int): the number of channels in the FPN
"""

def __init__(self, backbone, return_layers, in_channels_list, out_channels, extra_blocks=None):
super().__init__()

if extra_blocks is None:
extra_blocks = LastLevelMaxPool()

self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
self.fpn = FeaturePyramidNetwork(
in_channels_list=in_channels_list,
out_channels=out_channels,
extra_blocks=extra_blocks,
)
self.out_channels = out_channels

def forward(self, x):
x = self.body(x)
x = self.fpn(x)

return list(x.values()) # unpack OrderedDict into two lists for easier handling


def mobilenet_backbone(
backbone_name,
pretrained,
norm_layer=misc_nn_ops.FrozenBatchNorm2d,
trainable_layers=2,
returned_layers=None,
):
backbone = mobilenet.__dict__[backbone_name](pretrained=pretrained, norm_layer=norm_layer).features

# Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
# The first and last blocks are always included because they are the C0 (conv1) and Cn.
stage_indices = (
[0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + [len(backbone) - 1]
)
num_stages = len(stage_indices)

# find the index of the layer from which we wont freeze
assert 0 <= trainable_layers <= num_stages
freeze_before = len(backbone) if trainable_layers == 0 else stage_indices[num_stages - trainable_layers]

for b in backbone[:freeze_before]:
for parameter in b.parameters():
parameter.requires_grad_(False)

out_channels = 256

if returned_layers is None:
returned_layers = [num_stages - 2, num_stages - 1]
assert min(returned_layers) >= 0 and max(returned_layers) < num_stages
return_layers = {f"{stage_indices[k]}": str(v) for v, k in enumerate(returned_layers)}

in_channels_list = [backbone[stage_indices[i]].out_channels for i in returned_layers]
return BackboneWithFPN(
backbone,
return_layers,
in_channels_list,
out_channels,
extra_blocks=LastLevelMaxPool(),
)


def _yolov5_mobilenet_v3_small_fpn(
weights_name,
pretrained=False,
progress=True,
num_classes=80,
pretrained_backbone=True,
trainable_backbone_layers=None,
**kwargs,
):
trainable_backbone_layers = _validate_trainable_layers(
pretrained or pretrained_backbone, trainable_backbone_layers, 6, 3
)

if pretrained:
pretrained_backbone = False
backbone = mobilenet_backbone(
"mobilenet_v3_small",
pretrained_backbone,
trainable_layers=trainable_backbone_layers,
)

model = YOLO(backbone, num_classes, **kwargs)

return model


def yolov5_mobilenet_v3_small_fpn(
pretrained=False,
progress=True,
num_classes=80,
pretrained_backbone=True,
trainable_backbone_layers=None,
**kwargs,
):
"""
Constructs a high resolution YOLOv5 model with a MobileNetV3-Large FPN backbone.
It works similarly to Faster R-CNN with ResNet-50 FPN backbone. See
:func:`~torchvision.models.detection.fasterrcnn_resnet50_fpn` for more
details.
Note:
We do not provide a pre-trained model with mobilenet as the backbone now, this function
is just used as an example of how to construct a YOLOv5 model with TorchVision's pre-trained
MobileNetV3-Small FPN backbone.
Args:
pretrained (bool): If True, returns a model pre-trained on COCO train2017
progress (bool): If True, displays a progress bar of the download to stderr
num_classes (int): number of output classes of the model (including the background)
pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet
trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting
from final block. Valid values are between 0 and 6, with 6 meaning all backbone layers
are trainable.
"""
weights_name = "yolov5_mobilenet_v3_small_fpn_coco"

return _yolov5_mobilenet_v3_small_fpn(
weights_name,
pretrained=pretrained,
progress=progress,
num_classes=num_classes,
pretrained_backbone=pretrained_backbone,
trainable_backbone_layers=trainable_backbone_layers,
**kwargs,
)

0 comments on commit b96c225

Please sign in to comment.