From 5acc96d6293d5d654b4c6a8b63d65d4189e5d94c Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Mon, 8 Feb 2021 19:58:26 +0530 Subject: [PATCH 1/2] remove num_features arg from Classification model --- flash/vision/classification/model.py | 1 - 1 file changed, 1 deletion(-) diff --git a/flash/vision/classification/model.py b/flash/vision/classification/model.py index 69a3fd8c85..5c8a7c990b 100644 --- a/flash/vision/classification/model.py +++ b/flash/vision/classification/model.py @@ -40,7 +40,6 @@ def __init__( self, num_classes, backbone="resnet18", - num_features: int = None, pretrained=True, loss_fn: Callable = F.cross_entropy, optimizer: Type[torch.optim.Optimizer] = torch.optim.SGD, From 5f2936be841d9e419b3555a0d8067980d91b14ab Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Mon, 8 Feb 2021 20:06:45 +0530 Subject: [PATCH 2/2] add annotations to args --- flash/vision/classification/model.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/flash/vision/classification/model.py b/flash/vision/classification/model.py index 5c8a7c990b..5528cfc5d6 100644 --- a/flash/vision/classification/model.py +++ b/flash/vision/classification/model.py @@ -38,9 +38,9 @@ class ImageClassifier(ClassificationTask): def __init__( self, - num_classes, - backbone="resnet18", - pretrained=True, + num_classes: int, + backbone: str = "resnet18", + pretrained: bool = True, loss_fn: Callable = F.cross_entropy, optimizer: Type[torch.optim.Optimizer] = torch.optim.SGD, metrics: Union[Callable, Mapping, Sequence, None] = (Accuracy()),