From 59e5b9141dde9079cbca704e577ea65aaa056fdc Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Tue, 2 Feb 2021 01:33:34 +0530 Subject: [PATCH 1/8] add support for more backbones & refactor --- .../classification/components/__init__.py | 2 ++ .../classification/components/backbones.py | 32 +++++++++++++++++++ .../classification/components/model_zoo.py | 20 ++++++++++++ flash/vision/classification/model.py | 19 ++--------- 4 files changed, 57 insertions(+), 16 deletions(-) create mode 100644 flash/vision/classification/components/__init__.py create mode 100644 flash/vision/classification/components/backbones.py create mode 100644 flash/vision/classification/components/model_zoo.py diff --git a/flash/vision/classification/components/__init__.py b/flash/vision/classification/components/__init__.py new file mode 100644 index 0000000000..3190cd5a27 --- /dev/null +++ b/flash/vision/classification/components/__init__.py @@ -0,0 +1,2 @@ +from flash.vision.classification.components.backbones import torchvision_backbone_and_num_features +from flash.vision.classification.components.model_zoo import TORCHVISION_MODEL_ZOO diff --git a/flash/vision/classification/components/backbones.py b/flash/vision/classification/components/backbones.py new file mode 100644 index 0000000000..7ee3c076a1 --- /dev/null +++ b/flash/vision/classification/components/backbones.py @@ -0,0 +1,32 @@ +from typing import Tuple + +import torch.nn as nn + +from flash.vision.classification.components.torchvision_model_zoo import TORCHVISION_MODEL_ZOO + + +def torchvision_backbone_and_num_features(model_name: str, pretrained: bool = True) -> Tuple[nn.Module, int]: + """ + Returns CNN backbone & it's final num of features from Torchvision supported models. + Args: + model_name: Name of the model. E.g. resnet18 + pretrained: Pretrained weights on the ImageNet dataset + """ + model = TORCHVISION_MODEL_ZOO[model_name] + model = model(pretrained=pretrained) + if model_name in ["mobilenet_v2", "vgg11", "vgg13", "vgg16", "vgg19"]: + backbone = model.features + num_features = model.classifier[-1].in_features + return backbone, num_features + + elif model_name in [ + "resnet18", "resnet34", "resnet50", "resnet101", "resnet152", "resnext50_32x4d", "resnext101_32x8d" + ]: + backbone = nn.Sequential(*list(model.children())[:-2]) + num_features = model.fc.in_features + return backbone, num_features + + elif model_name in ["dense121", "densenet169", "densenet161", "densenet161"]: + backbone = nn.Sequential(*model.features, nn.ReLU(inplace=True)) + num_features = model.classifier.in_features + return backbone, num_features diff --git a/flash/vision/classification/components/model_zoo.py b/flash/vision/classification/components/model_zoo.py new file mode 100644 index 0000000000..46cfda516e --- /dev/null +++ b/flash/vision/classification/components/model_zoo.py @@ -0,0 +1,20 @@ +import torchvision + +TORCHVISION_MODEL_ZOO = { + "vgg11": torchvision.models.vgg11, + "vgg13": torchvision.models.vgg13, + "vgg16": torchvision.models.vgg16, + "vgg19": torchvision.models.vgg19, + "resnet18": torchvision.models.resnet18, + "resnet34": torchvision.models.resnet34, + "resnet50": torchvision.models.resnet50, + "resnet101": torchvision.models.resnet101, + "resnet152": torchvision.models.resnet152, + "resnext50_32x4d": torchvision.models.resnext50_32x4d, + "resnext50_32x8d": torchvision.models.resnext101_32x8d, + "mobilenet_v2": torchvision.models.mobilenet_v2, + "densenet121": torchvision.models.densenet121, + "densenet169": torchvision.models.densenet169, + "densenet161": torchvision.models.densenet161, + "densenet201": torchvision.models.densenet201, +} diff --git a/flash/vision/classification/model.py b/flash/vision/classification/model.py index 66c38fa08c..545cf93a9d 100644 --- a/flash/vision/classification/model.py +++ b/flash/vision/classification/model.py @@ -20,19 +20,9 @@ from torch.nn import functional as F from flash.core.classification import ClassificationTask +from flash.vision.classification.components import torchvision_backbone_and_num_features, TORCHVISION_MODEL_ZOO from flash.vision.classification.data import ImageClassificationData, ImageClassificationDataPipeline -_resnet_backbone = lambda model: nn.Sequential(*list(model.children())[:-2]) # noqa: E731 -_resnet_feats = lambda model: model.fc.in_features # noqa: E731 - -_backbones = { - "resnet18": (torchvision.models.resnet18, _resnet_backbone, _resnet_feats), - "resnet34": (torchvision.models.resnet34, _resnet_backbone, _resnet_feats), - "resnet50": (torchvision.models.resnet50, _resnet_backbone, _resnet_feats), - "resnet101": (torchvision.models.resnet101, _resnet_backbone, _resnet_feats), - "resnet152": (torchvision.models.resnet152, _resnet_backbone, _resnet_feats), -} - class ImageClassifier(ClassificationTask): """Task that classifies images. @@ -67,13 +57,10 @@ def __init__( self.save_hyperparameters() - if backbone not in _backbones: + if backbone not in TORCHVISION_MODEL_ZOO: raise NotImplementedError(f"Backbone {backbone} is not yet supported") - backbone_fn, split, num_feats = _backbones[backbone] - backbone = backbone_fn(pretrained=pretrained) - self.backbone = split(backbone) - num_features = num_feats(backbone) + self.backbone, num_features = torchvision_backbone_and_num_features(backbone, pretrained) self.head = nn.Sequential( nn.AdaptiveAvgPool2d((1, 1)), From 55142ca44f8c9777bead102adaea2a471b3e93dc Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Tue, 2 Feb 2021 01:41:12 +0530 Subject: [PATCH 2/8] fix imports --- flash/vision/classification/components/backbones.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/flash/vision/classification/components/backbones.py b/flash/vision/classification/components/backbones.py index 7ee3c076a1..949879f66d 100644 --- a/flash/vision/classification/components/backbones.py +++ b/flash/vision/classification/components/backbones.py @@ -2,16 +2,11 @@ import torch.nn as nn -from flash.vision.classification.components.torchvision_model_zoo import TORCHVISION_MODEL_ZOO +from flash.vision.classification.components import TORCHVISION_MODEL_ZOO def torchvision_backbone_and_num_features(model_name: str, pretrained: bool = True) -> Tuple[nn.Module, int]: - """ - Returns CNN backbone & it's final num of features from Torchvision supported models. - Args: - model_name: Name of the model. E.g. resnet18 - pretrained: Pretrained weights on the ImageNet dataset - """ + model = TORCHVISION_MODEL_ZOO[model_name] model = model(pretrained=pretrained) if model_name in ["mobilenet_v2", "vgg11", "vgg13", "vgg16", "vgg19"]: From 6b552441a8f84c95fbab72003a49f974557c149a Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Tue, 2 Feb 2021 01:45:02 +0530 Subject: [PATCH 3/8] fix import paths --- flash/vision/classification/components/backbones.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flash/vision/classification/components/backbones.py b/flash/vision/classification/components/backbones.py index 949879f66d..ea142bc437 100644 --- a/flash/vision/classification/components/backbones.py +++ b/flash/vision/classification/components/backbones.py @@ -2,7 +2,7 @@ import torch.nn as nn -from flash.vision.classification.components import TORCHVISION_MODEL_ZOO +from flash.vision.classification.components.model_zoo import TORCHVISION_MODEL_ZOO def torchvision_backbone_and_num_features(model_name: str, pretrained: bool = True) -> Tuple[nn.Module, int]: From 4e479c4eeb7b8c4cbf5958ca366f87e58273770c Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Tue, 2 Feb 2021 01:49:49 +0530 Subject: [PATCH 4/8] fix densenet model name --- flash/vision/classification/components/backbones.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flash/vision/classification/components/backbones.py b/flash/vision/classification/components/backbones.py index ea142bc437..19fd501489 100644 --- a/flash/vision/classification/components/backbones.py +++ b/flash/vision/classification/components/backbones.py @@ -21,7 +21,7 @@ def torchvision_backbone_and_num_features(model_name: str, pretrained: bool = Tr num_features = model.fc.in_features return backbone, num_features - elif model_name in ["dense121", "densenet169", "densenet161", "densenet161"]: + elif model_name in ["densenet121", "densenet169", "densenet161", "densenet161"]: backbone = nn.Sequential(*model.features, nn.ReLU(inplace=True)) num_features = model.classifier.in_features return backbone, num_features From 4b0e410b5f166c19222ea6be1c882a35047c21a8 Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Tue, 2 Feb 2021 02:07:40 +0530 Subject: [PATCH 5/8] add comment for creating resnet backbone --- flash/vision/classification/components/backbones.py | 1 + 1 file changed, 1 insertion(+) diff --git a/flash/vision/classification/components/backbones.py b/flash/vision/classification/components/backbones.py index 19fd501489..bb8b62cac6 100644 --- a/flash/vision/classification/components/backbones.py +++ b/flash/vision/classification/components/backbones.py @@ -17,6 +17,7 @@ def torchvision_backbone_and_num_features(model_name: str, pretrained: bool = Tr elif model_name in [ "resnet18", "resnet34", "resnet50", "resnet101", "resnet152", "resnext50_32x4d", "resnext101_32x8d" ]: + # remove the last two layers & turn it into a Sequential model backbone = nn.Sequential(*list(model.children())[:-2]) num_features = model.fc.in_features return backbone, num_features From 8385631bdc14275f4c1e7a36e5ab19c6fdaaacf9 Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Tue, 2 Feb 2021 03:11:24 +0530 Subject: [PATCH 6/8] remove model zoo --- .../{components => }/backbones.py | 16 +++++++++++---- .../classification/components/__init__.py | 2 -- .../classification/components/model_zoo.py | 20 ------------------- flash/vision/classification/model.py | 5 +---- 4 files changed, 13 insertions(+), 30 deletions(-) rename flash/vision/classification/{components => }/backbones.py (67%) delete mode 100644 flash/vision/classification/components/__init__.py delete mode 100644 flash/vision/classification/components/model_zoo.py diff --git a/flash/vision/classification/components/backbones.py b/flash/vision/classification/backbones.py similarity index 67% rename from flash/vision/classification/components/backbones.py rename to flash/vision/classification/backbones.py index bb8b62cac6..071e49d3d5 100644 --- a/flash/vision/classification/components/backbones.py +++ b/flash/vision/classification/backbones.py @@ -1,15 +1,18 @@ from typing import Tuple import torch.nn as nn - -from flash.vision.classification.components.model_zoo import TORCHVISION_MODEL_ZOO +import torchvision +from pytorch_lightning.utilities.exceptions import MisconfigurationException def torchvision_backbone_and_num_features(model_name: str, pretrained: bool = True) -> Tuple[nn.Module, int]: - model = TORCHVISION_MODEL_ZOO[model_name] - model = model(pretrained=pretrained) + model = getattr(torchvision.models, model_name, None) + if model is None: + raise MisconfigurationException(f"{model_name} is not supported by torchvision") + if model_name in ["mobilenet_v2", "vgg11", "vgg13", "vgg16", "vgg19"]: + model = model(pretrained=pretrained) backbone = model.features num_features = model.classifier[-1].in_features return backbone, num_features @@ -17,12 +20,17 @@ def torchvision_backbone_and_num_features(model_name: str, pretrained: bool = Tr elif model_name in [ "resnet18", "resnet34", "resnet50", "resnet101", "resnet152", "resnext50_32x4d", "resnext101_32x8d" ]: + model = model(pretrained=pretrained) # remove the last two layers & turn it into a Sequential model backbone = nn.Sequential(*list(model.children())[:-2]) num_features = model.fc.in_features return backbone, num_features elif model_name in ["densenet121", "densenet169", "densenet161", "densenet161"]: + model = model(pretrained=pretrained) backbone = nn.Sequential(*model.features, nn.ReLU(inplace=True)) num_features = model.classifier.in_features return backbone, num_features + + else: + raise ValueError(f"{model_name} is not supported yet.") diff --git a/flash/vision/classification/components/__init__.py b/flash/vision/classification/components/__init__.py deleted file mode 100644 index 3190cd5a27..0000000000 --- a/flash/vision/classification/components/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from flash.vision.classification.components.backbones import torchvision_backbone_and_num_features -from flash.vision.classification.components.model_zoo import TORCHVISION_MODEL_ZOO diff --git a/flash/vision/classification/components/model_zoo.py b/flash/vision/classification/components/model_zoo.py deleted file mode 100644 index 46cfda516e..0000000000 --- a/flash/vision/classification/components/model_zoo.py +++ /dev/null @@ -1,20 +0,0 @@ -import torchvision - -TORCHVISION_MODEL_ZOO = { - "vgg11": torchvision.models.vgg11, - "vgg13": torchvision.models.vgg13, - "vgg16": torchvision.models.vgg16, - "vgg19": torchvision.models.vgg19, - "resnet18": torchvision.models.resnet18, - "resnet34": torchvision.models.resnet34, - "resnet50": torchvision.models.resnet50, - "resnet101": torchvision.models.resnet101, - "resnet152": torchvision.models.resnet152, - "resnext50_32x4d": torchvision.models.resnext50_32x4d, - "resnext50_32x8d": torchvision.models.resnext101_32x8d, - "mobilenet_v2": torchvision.models.mobilenet_v2, - "densenet121": torchvision.models.densenet121, - "densenet169": torchvision.models.densenet169, - "densenet161": torchvision.models.densenet161, - "densenet201": torchvision.models.densenet201, -} diff --git a/flash/vision/classification/model.py b/flash/vision/classification/model.py index 545cf93a9d..a899b9eb6f 100644 --- a/flash/vision/classification/model.py +++ b/flash/vision/classification/model.py @@ -20,7 +20,7 @@ from torch.nn import functional as F from flash.core.classification import ClassificationTask -from flash.vision.classification.components import torchvision_backbone_and_num_features, TORCHVISION_MODEL_ZOO +from flash.vision.classification.backbones import torchvision_backbone_and_num_features from flash.vision.classification.data import ImageClassificationData, ImageClassificationDataPipeline @@ -57,9 +57,6 @@ def __init__( self.save_hyperparameters() - if backbone not in TORCHVISION_MODEL_ZOO: - raise NotImplementedError(f"Backbone {backbone} is not yet supported") - self.backbone, num_features = torchvision_backbone_and_num_features(backbone, pretrained) self.head = nn.Sequential( From 1ff125a96723b77e87d9fd13c0f635375d30214a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Mon, 1 Feb 2021 22:48:25 +0100 Subject: [PATCH 7/8] Update flash/vision/classification/backbones.py --- flash/vision/classification/backbones.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/flash/vision/classification/backbones.py b/flash/vision/classification/backbones.py index 071e49d3d5..a6759ae84e 100644 --- a/flash/vision/classification/backbones.py +++ b/flash/vision/classification/backbones.py @@ -32,5 +32,4 @@ def torchvision_backbone_and_num_features(model_name: str, pretrained: bool = Tr num_features = model.classifier.in_features return backbone, num_features - else: - raise ValueError(f"{model_name} is not supported yet.") + raise ValueError(f"{model_name} is not supported yet.") From 18c89ad86e6dd4705b3ca03a086582238f1087df Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Tue, 2 Feb 2021 19:22:43 +0530 Subject: [PATCH 8/8] fix tests to raise the right exception --- tests/vision/classification/test_model.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/vision/classification/test_model.py b/tests/vision/classification/test_model.py index 84004118d3..d2e06623ea 100644 --- a/tests/vision/classification/test_model.py +++ b/tests/vision/classification/test_model.py @@ -1,5 +1,6 @@ import pytest import torch +from pytorch_lightning.utilities.exceptions import MisconfigurationException from flash import Trainer from flash.vision import ImageClassifier @@ -37,7 +38,7 @@ def test_init_train(tmpdir, backbone): def test_non_existent_backbone(): - with pytest.raises(NotImplementedError): + with pytest.raises(MisconfigurationException): ImageClassifier(2, "i am never going to implement this lol")