diff --git a/fiftyone/utils/torch.py b/fiftyone/utils/torch.py index f49335282d..a75d2498ef 100644 --- a/fiftyone/utils/torch.py +++ b/fiftyone/utils/torch.py @@ -252,14 +252,14 @@ class TorchEmbeddingsMixin(fom.EmbeddingsMixin): layer_name (None): the name of the embeddings layer whose output to save, or ``None`` if this model instance should not expose embeddings. Prepend ``"<"`` to save the input tensor instead - feature_extractor (False): whether this model instance should + as_feature_extractor (False): whether this model instance should operate as a feature extractor, in which case ``layer_name`` if provided is used to create a feature extractor. If ``layer_name`` is not provided, the model is used as-is for feature extraction. """ - def __init__(self, model, layer_name=None, feature_extractor=False): - if feature_extractor: + def __init__(self, model, layer_name=None, as_feature_extractor=False): + if as_feature_extractor: if layer_name: # create a torchvision feature extractor self._model = create_feature_extractor( @@ -272,7 +272,7 @@ def __init__(self, model, layer_name=None, feature_extractor=False): embeddings_layer = None self._embeddings_layer = embeddings_layer - self._as_feature_extractor = feature_extractor + self._as_feature_extractor = as_feature_extractor @property def has_embeddings(self): @@ -418,7 +418,7 @@ def predict_all(imgs): inputs that are lists of Tensors embeddings_layer (None): the name of a layer whose output to expose as embeddings. Prepend ``"<"`` to save the input tensor instead - feature_extractor (False): whether this model instance should be + as_feature_extractor (False): whether this model instance should be treated as a feature extractor. If embedding_layer is provided, then a feature extractor is created using torchvision, otherwise the model itself is treated as a feature extractor. @@ -488,8 +488,8 @@ def __init__(self, d): self.embeddings_layer = self.parse_string( d, "embeddings_layer", default=None ) - self.feature_extractor = self.parse_bool( - d, "feature_extractor", default=False + self.as_feature_extractor = self.parse_bool( + d, "as_feature_extractor", default=False ) self.use_half_precision = self.parse_bool( d, "use_half_precision", default=None @@ -548,7 +548,7 @@ def __init__(self, config): self, self._model, layer_name=self.config.embeddings_layer, - feature_extractor=self.config.feature_extractor, + as_feature_extractor=self.config.as_feature_extractor, ) def __enter__(self): diff --git a/tests/intensive/model_tests.py b/tests/intensive/model_tests.py index 2336fb8385..1e070c70ee 100644 --- a/tests/intensive/model_tests.py +++ b/tests/intensive/model_tests.py @@ -58,7 +58,7 @@ def test_torch_hub_feature_extractor(): model = fout.load_torch_hub_image_model( "facebookresearch/dino:main", "dino_vits16", - feature_extractor=True, + as_feature_extractor=True, image_size=[224, 224], )