diff --git a/fiftyone/utils/torch.py b/fiftyone/utils/torch.py index 4a9486ec89..a75d2498ef 100644 --- a/fiftyone/utils/torch.py +++ b/fiftyone/utils/torch.py @@ -18,6 +18,7 @@ import eta.core.geometry as etag import eta.core.learning as etal import eta.core.utils as etau +from torchvision.models.feature_extraction import create_feature_extractor import fiftyone.core.config as foc import fiftyone.core.labels as fol @@ -251,19 +252,31 @@ class TorchEmbeddingsMixin(fom.EmbeddingsMixin): layer_name (None): the name of the embeddings layer whose output to save, or ``None`` if this model instance should not expose embeddings. Prepend ``"<"`` to save the input tensor instead + as_feature_extractor (False): whether this model instance should + operate as a feature extractor, in which case ``layer_name`` if provided + is used to create a feature extractor. If ``layer_name`` is not provided, + the model is used as-is for feature extraction. """ - def __init__(self, model, layer_name=None): - if layer_name is not None: + def __init__(self, model, layer_name=None, as_feature_extractor=False): + if as_feature_extractor: + if layer_name: + # create a torchvision feature extractor + self._model = create_feature_extractor( + model, return_nodes=[layer_name] + ) + embeddings_layer = None + elif layer_name is not None: embeddings_layer = SaveLayerTensor(model, layer_name) else: embeddings_layer = None self._embeddings_layer = embeddings_layer + self._as_feature_extractor = as_feature_extractor @property def has_embeddings(self): - return self._embeddings_layer is not None + return self._embeddings_layer is not None or self._as_feature_extractor def embed(self, arg): if isinstance(arg, torch.Tensor): @@ -271,12 +284,18 @@ def embed(self, arg): else: args = [arg] - self._predict_all(args) - return self.get_embeddings()[0] + features = self._predict_all(args) + if self._as_feature_extractor: + return features + else: + return self.get_embeddings()[0] def embed_all(self, args): - 
self._predict_all(args) - return self.get_embeddings() + features = self._predict_all(args) + if self._as_feature_extractor: + return features + else: + return self.get_embeddings() def get_embeddings(self): if not self.has_embeddings: @@ -399,6 +418,10 @@ def predict_all(imgs): inputs that are lists of Tensors embeddings_layer (None): the name of a layer whose output to expose as embeddings. Prepend ``"<"`` to save the input tensor instead + as_feature_extractor (False): whether this model instance should be + treated as a feature extractor. If ``embeddings_layer`` is provided, + then a feature extractor is created using torchvision, otherwise + the model itself is treated as a feature extractor. use_half_precision (None): whether to use half precision (only supported when using GPU) cudnn_benchmark (None): a value to use for @@ -465,6 +488,9 @@ def __init__(self, d): self.embeddings_layer = self.parse_string( d, "embeddings_layer", default=None ) + self.as_feature_extractor = self.parse_bool( + d, "as_feature_extractor", default=False + ) self.use_half_precision = self.parse_bool( d, "use_half_precision", default=None ) @@ -519,7 +545,10 @@ def __init__(self, config): fom.LogitsMixin.__init__(self) TorchEmbeddingsMixin.__init__( - self, self._model, layer_name=self.config.embeddings_layer + self, + self._model, + layer_name=self.config.embeddings_layer, + as_feature_extractor=self.config.as_feature_extractor, ) def __enter__(self): @@ -1000,6 +1029,11 @@ def _setup(self, model, layer_name): if _layer is None: raise ValueError("No layer found with name %s" % layer_name) + elif isinstance(_layer, torch.nn.Identity): + raise ValueError( + "Layer '%s' is an Identity layer. Use previous layer." 
+ % layer_name + ) _layer.register_forward_hook(self) diff --git a/tests/intensive/model_tests.py b/tests/intensive/model_tests.py index 25b71a549a..1e070c70ee 100644 --- a/tests/intensive/model_tests.py +++ b/tests/intensive/model_tests.py @@ -12,9 +12,11 @@ import unittest import numpy as np +import pytest import fiftyone as fo import fiftyone.zoo as foz +import fiftyone.utils.torch as fout def test_apply_model(): @@ -49,6 +51,31 @@ def test_compute_embeddings(): _assert_embeddings_equal(embeddings2a, embeddings2b) +def test_torch_hub_feature_extractor(): + dataset = foz.load_zoo_dataset("quickstart") + view = dataset.take(5) + + model = fout.load_torch_hub_image_model( + "facebookresearch/dino:main", + "dino_vits16", + as_feature_extractor=True, + image_size=[224, 224], + ) + + embeddings1a = view.compute_embeddings(model, skip_failures=False) + view.compute_embeddings(model, embeddings_field="embeddings1") + embeddings1b = _load_embeddings(view, "embeddings1") + _assert_embeddings_equal(embeddings1a, embeddings1b) + + with pytest.raises(ValueError): + fout.load_torch_hub_image_model( + "facebookresearch/dino:main", + "dino_vits16", + embeddings_layer="head", + image_size=[224, 224], + ) + + def _load_embeddings(samples, path): return np.stack(samples.values(path))