Skip to content

Commit

Permalink
add a feature to treat torch models as feature extractors without for…
Browse files Browse the repository at this point in the history
…ward hooks
  • Loading branch information
Rustem Galiullin committed Sep 13, 2023
1 parent 26cfe85 commit e5008bb
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 8 deletions.
50 changes: 42 additions & 8 deletions fiftyone/utils/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import eta.core.geometry as etag
import eta.core.learning as etal
import eta.core.utils as etau
from torchvision.models.feature_extraction import create_feature_extractor

import fiftyone.core.config as foc
import fiftyone.core.labels as fol
Expand Down Expand Up @@ -251,32 +252,50 @@ class TorchEmbeddingsMixin(fom.EmbeddingsMixin):
layer_name (None): the name of the embeddings layer whose output to
save, or ``None`` if this model instance should not expose
embeddings. Prepend ``"<"`` to save the input tensor instead
feature_extractor (False): whether this model instance should
operate as a feature extractor, in which case ``layer_name`` if provided
is used to create a feature extractor. If ``layer_name`` is not provided,
the model is used as-is for feature extraction.
"""

def __init__(self, model, layer_name=None):
if layer_name is not None:
def __init__(self, model, layer_name=None, feature_extractor=False):
if feature_extractor:
if layer_name:
# create a torchvision feature extractor
self._model = create_feature_extractor(
model, return_nodes=[layer_name]
)
embeddings_layer = None
elif layer_name is not None:
embeddings_layer = SaveLayerTensor(model, layer_name)
else:
embeddings_layer = None

self._embeddings_layer = embeddings_layer
self._as_feature_extractor = feature_extractor

@property
def has_embeddings(self):
return self._embeddings_layer is not None
return self._embeddings_layer is not None or self._as_feature_extractor

def embed(self, arg):
if isinstance(arg, torch.Tensor):
args = arg.unsqueeze(0)
else:
args = [arg]

self._predict_all(args)
return self.get_embeddings()[0]
features = self._predict_all(args)
if self._as_feature_extractor:
return features
else:
return self.get_embeddings()[0]

def embed_all(self, args):
self._predict_all(args)
return self.get_embeddings()
features = self._predict_all(args)
if self._as_feature_extractor:
return features
else:
return self.get_embeddings()

def get_embeddings(self):
if not self.has_embeddings:
Expand Down Expand Up @@ -399,6 +418,10 @@ def predict_all(imgs):
inputs that are lists of Tensors
embeddings_layer (None): the name of a layer whose output to expose as
embeddings. Prepend ``"<"`` to save the input tensor instead
feature_extractor (False): whether this model instance should be
treated as a feature extractor. If embedding_layer is provided,
then a feature extractor is created using torchvision, otherwise
the model itself is treated as a feature extractor.
use_half_precision (None): whether to use half precision (only
supported when using GPU)
cudnn_benchmark (None): a value to use for
Expand Down Expand Up @@ -465,6 +488,9 @@ def __init__(self, d):
self.embeddings_layer = self.parse_string(
d, "embeddings_layer", default=None
)
self.feature_extractor = self.parse_bool(
d, "feature_extractor", default=False
)
self.use_half_precision = self.parse_bool(
d, "use_half_precision", default=None
)
Expand Down Expand Up @@ -519,7 +545,10 @@ def __init__(self, config):

fom.LogitsMixin.__init__(self)
TorchEmbeddingsMixin.__init__(
self, self._model, layer_name=self.config.embeddings_layer
self,
self._model,
layer_name=self.config.embeddings_layer,
feature_extractor=self.config.feature_extractor,
)

def __enter__(self):
Expand Down Expand Up @@ -1000,6 +1029,11 @@ def _setup(self, model, layer_name):

if _layer is None:
raise ValueError("No layer found with name %s" % layer_name)
elif isinstance(_layer, torch.nn.Identity):
raise ValueError(
"Layer '%s' is an Identity layer. Use previous layer."
% layer_name
)

_layer.register_forward_hook(self)

Expand Down
27 changes: 27 additions & 0 deletions tests/intensive/model_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@
import unittest

import numpy as np
import pytest

import fiftyone as fo
import fiftyone.zoo as foz
import fiftyone.utils.torch as fout


def test_apply_model():
Expand Down Expand Up @@ -49,6 +51,31 @@ def test_compute_embeddings():
_assert_embeddings_equal(embeddings2a, embeddings2b)


def test_torch_hub_feature_extractor():
dataset = foz.load_zoo_dataset("quickstart")
view = dataset.take(5)

model = fout.load_torch_hub_image_model(
"facebookresearch/dino:main",
"dino_vits16",
feature_extractor=True,
image_size=[224, 224],
)

embeddings1a = view.compute_embeddings(model, skip_failures=False)
view.compute_embeddings(model, embeddings_field="embeddings1")
embeddings1b = _load_embeddings(view, "embeddings1")
_assert_embeddings_equal(embeddings1a, embeddings1b)

with pytest.raises(ValueError):
fout.load_torch_hub_image_model(
"facebookresearch/dino:main",
"dino_vits16",
embeddings_layer="head",
image_size=[224, 224],
)


def _load_embeddings(samples, path):
return np.stack(samples.values(path))

Expand Down

0 comments on commit e5008bb

Please sign in to comment.