Skip to content

Commit

Permalink
rename feature_extractor param to as_feature_extractor
Browse files Browse the repository at this point in the history
  • Loading branch information
Rustem Galiullin committed Sep 26, 2023
1 parent e5008bb commit ec66f7e
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 9 deletions.
16 changes: 8 additions & 8 deletions fiftyone/utils/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,14 +252,14 @@ class TorchEmbeddingsMixin(fom.EmbeddingsMixin):
layer_name (None): the name of the embeddings layer whose output to
save, or ``None`` if this model instance should not expose
embeddings. Prepend ``"<"`` to save the input tensor instead
feature_extractor (False): whether this model instance should
as_feature_extractor (False): whether this model instance should
operate as a feature extractor, in which case ``layer_name`` if provided
is used to create a feature extractor. If ``layer_name`` is not provided,
the model is used as-is for feature extraction.
"""

def __init__(self, model, layer_name=None, feature_extractor=False):
if feature_extractor:
def __init__(self, model, layer_name=None, as_feature_extractor=False):
if as_feature_extractor:
if layer_name:
# create a torchvision feature extractor
self._model = create_feature_extractor(
Expand All @@ -272,7 +272,7 @@ def __init__(self, model, layer_name=None, feature_extractor=False):
embeddings_layer = None

self._embeddings_layer = embeddings_layer
self._as_feature_extractor = feature_extractor
self._as_feature_extractor = as_feature_extractor

@property
def has_embeddings(self):
Expand Down Expand Up @@ -418,7 +418,7 @@ def predict_all(imgs):
inputs that are lists of Tensors
embeddings_layer (None): the name of a layer whose output to expose as
embeddings. Prepend ``"<"`` to save the input tensor instead
feature_extractor (False): whether this model instance should be
as_feature_extractor (False): whether this model instance should be
treated as a feature extractor. If embedding_layer is provided,
then a feature extractor is created using torchvision, otherwise
the model itself is treated as a feature extractor.
Expand Down Expand Up @@ -488,8 +488,8 @@ def __init__(self, d):
self.embeddings_layer = self.parse_string(
d, "embeddings_layer", default=None
)
self.feature_extractor = self.parse_bool(
d, "feature_extractor", default=False
self.as_feature_extractor = self.parse_bool(
d, "as_feature_extractor", default=False
)
self.use_half_precision = self.parse_bool(
d, "use_half_precision", default=None
Expand Down Expand Up @@ -548,7 +548,7 @@ def __init__(self, config):
self,
self._model,
layer_name=self.config.embeddings_layer,
feature_extractor=self.config.feature_extractor,
as_feature_extractor=self.config.as_feature_extractor,
)

def __enter__(self):
Expand Down
2 changes: 1 addition & 1 deletion tests/intensive/model_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def test_torch_hub_feature_extractor():
model = fout.load_torch_hub_image_model(
"facebookresearch/dino:main",
"dino_vits16",
feature_extractor=True,
as_feature_extractor=True,
image_size=[224, 224],
)

Expand Down

0 comments on commit ec66f7e

Please sign in to comment.