Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Add support for more backbones(mobilnet, vgg, densenet, resnext) & refactor #45

Merged
merged 12 commits into from
Feb 2, 2021
2 changes: 2 additions & 0 deletions flash/vision/classification/components/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from flash.vision.classification.components.backbones import torchvision_backbone_and_num_features
from flash.vision.classification.components.model_zoo import TORCHVISION_MODEL_ZOO
32 changes: 32 additions & 0 deletions flash/vision/classification/components/backbones.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from typing import Tuple

import torch.nn as nn

from flash.vision.classification.components.torchvision_model_zoo import TORCHVISION_MODEL_ZOO


def torchvision_backbone_and_num_features(model_name: str, pretrained: bool = True) -> Tuple[nn.Module, int]:
"""
Returns CNN backbone & it's final num of features from Torchvision supported models.
Args:
model_name: Name of the model. E.g. resnet18
pretrained: Pretrained weights on the ImageNet dataset
"""
model = TORCHVISION_MODEL_ZOO[model_name]
model = model(pretrained=pretrained)
if model_name in ["mobilenet_v2", "vgg11", "vgg13", "vgg16", "vgg19"]:
backbone = model.features
num_features = model.classifier[-1].in_features
return backbone, num_features

elif model_name in [
"resnet18", "resnet34", "resnet50", "resnet101", "resnet152", "resnext50_32x4d", "resnext101_32x8d"
]:
backbone = nn.Sequential(*list(model.children())[:-2])
kaushikb11 marked this conversation as resolved.
Show resolved Hide resolved
num_features = model.fc.in_features
return backbone, num_features

elif model_name in ["dense121", "densenet169", "densenet161", "densenet161"]:
backbone = nn.Sequential(*model.features, nn.ReLU(inplace=True))
kaushikb11 marked this conversation as resolved.
Show resolved Hide resolved
num_features = model.classifier.in_features
return backbone, num_features
20 changes: 20 additions & 0 deletions flash/vision/classification/components/model_zoo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import torchvision

TORCHVISION_MODEL_ZOO = {
carmocca marked this conversation as resolved.
Show resolved Hide resolved
kaushikb11 marked this conversation as resolved.
Show resolved Hide resolved
"vgg11": torchvision.models.vgg11,
"vgg13": torchvision.models.vgg13,
"vgg16": torchvision.models.vgg16,
"vgg19": torchvision.models.vgg19,
"resnet18": torchvision.models.resnet18,
"resnet34": torchvision.models.resnet34,
"resnet50": torchvision.models.resnet50,
"resnet101": torchvision.models.resnet101,
"resnet152": torchvision.models.resnet152,
"resnext50_32x4d": torchvision.models.resnext50_32x4d,
"resnext50_32x8d": torchvision.models.resnext101_32x8d,
"mobilenet_v2": torchvision.models.mobilenet_v2,
"densenet121": torchvision.models.densenet121,
"densenet169": torchvision.models.densenet169,
"densenet161": torchvision.models.densenet161,
"densenet201": torchvision.models.densenet201,
}
19 changes: 3 additions & 16 deletions flash/vision/classification/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,9 @@
from torch.nn import functional as F

from flash.core.classification import ClassificationTask
from flash.vision.classification.components import torchvision_backbone_and_num_features, TORCHVISION_MODEL_ZOO
from flash.vision.classification.data import ImageClassificationData, ImageClassificationDataPipeline

_resnet_backbone = lambda model: nn.Sequential(*list(model.children())[:-2]) # noqa: E731
_resnet_feats = lambda model: model.fc.in_features # noqa: E731

_backbones = {
"resnet18": (torchvision.models.resnet18, _resnet_backbone, _resnet_feats),
"resnet34": (torchvision.models.resnet34, _resnet_backbone, _resnet_feats),
"resnet50": (torchvision.models.resnet50, _resnet_backbone, _resnet_feats),
"resnet101": (torchvision.models.resnet101, _resnet_backbone, _resnet_feats),
"resnet152": (torchvision.models.resnet152, _resnet_backbone, _resnet_feats),
}


class ImageClassifier(ClassificationTask):
"""Task that classifies images.
Expand Down Expand Up @@ -67,13 +57,10 @@ def __init__(

self.save_hyperparameters()

if backbone not in _backbones:
if backbone not in TORCHVISION_MODEL_ZOO:
raise NotImplementedError(f"Backbone {backbone} is not yet supported")

backbone_fn, split, num_feats = _backbones[backbone]
backbone = backbone_fn(pretrained=pretrained)
self.backbone = split(backbone)
num_features = num_feats(backbone)
self.backbone, num_features = torchvision_backbone_and_num_features(backbone, pretrained)

self.head = nn.Sequential(
nn.AdaptiveAvgPool2d((1, 1)),
Expand Down