This repository has been archived by the owner on Oct 19, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'feature/extend-jit-models' into feature/feature_embedding_dataset
- Loading branch information
Showing 8 changed files with 197 additions and 220 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
from abc import ABC, abstractmethod | ||
from enum import Enum | ||
from typing import Any, Callable | ||
|
||
import torch | ||
|
||
from ahcore.utils.types import OutputModeBase, ViTEmbedMode | ||
|
||
|
||
class _TokenNames(str, Enum):
    """String-valued enum base whose members print as their raw value."""

    def __str__(self) -> Any:
        # Show the underlying string instead of the default ``Class.MEMBER`` form.
        raw_name = self.value
        return raw_name


class ViTTokenNames(_TokenNames):
    """Names of the CLS-token and patch-token entries."""

    CLS_TOKEN_NAME = "x_norm_clstoken"
    PATCH_TOKEN_NAME = "x_norm_patchtokens"
|
||
|
||
class AhCoreFeatureEmbedding(ABC):
    """Abstract base for feature-embedding strategies.

    The constructor stores the embedding mode and declares (without assigning)
    the ``embed_fn`` attribute. Concrete subclasses must implement
    ``_set_embed_function``.
    """

    def __init__(self, embedding_mode: OutputModeBase):
        # Declaration only: no value is bound to ``embed_fn`` here.
        self.embed_fn: Callable[[dict[str, Any]], torch.Tensor]
        self._embedding_mode = embedding_mode

    @property
    def embedding_mode(self) -> Any:
        """Return the ``value`` of the stored embedding mode."""
        stored_mode = self._embedding_mode
        return stored_mode.value

    @property
    def token_names(self) -> Any:
        """Not provided by the base class; accessing it raises ``NotImplementedError``."""
        raise NotImplementedError

    @abstractmethod
    def _set_embed_function(self) -> None:
        """Hook that concrete subclasses must implement."""
        raise NotImplementedError
|
||
|
||
class ViTEmbed(AhCoreFeatureEmbedding):
    """Combine the CLS token and patch tokens of a model output dict into one tensor.

    The combination strategy is selected once, at construction time, from the
    embedding mode and is exposed as ``embed_fn``.

    NOTE(review): the tensor operations below look like they expect a CLS token
    of shape (batch, dim) and patch tokens of shape (batch, n_patches, dim) --
    confirm against the JIT model's actual output.
    """

    def __init__(self, embedding_mode: OutputModeBase):
        super().__init__(embedding_mode=embedding_mode)
        self._set_embed_function()

    @property
    def dim_factor(self) -> int:
        """
        Factor by which the output feature dimensionality grows for the
        configured embedding method: 2 for the concat method, 1 otherwise.
        """
        return 2 if self._embedding_mode == ViTEmbedMode.CONCAT else 1

    @property
    def token_names(self) -> tuple[ViTTokenNames, ViTTokenNames]:
        """Return the (CLS token name, patch token name) pair."""
        return ViTTokenNames.CLS_TOKEN_NAME, ViTTokenNames.PATCH_TOKEN_NAME

    def _set_embed_function(self) -> None:
        """Bind ``embed_fn`` to the method matching the embedding mode.

        Raises
        ------
        NotImplementedError
            If the embedding mode matches none of the supported modes.
        """
        # Checked in order with ``==``, first match wins.
        candidates = (
            (ViTEmbedMode.CLS_ONLY, self.embed_cls_only),
            (ViTEmbedMode.PATCH_ONLY, self.embed_patch_only),
            (ViTEmbedMode.MEAN, self.embed_mean),
            (ViTEmbedMode.CONCAT_MEAN, self.embed_concat_mean),
            (ViTEmbedMode.CONCAT, self.embed_concat),
        )
        for supported_mode, embed_method in candidates:
            if self._embedding_mode == supported_mode:
                self.embed_fn = embed_method
                return
        raise NotImplementedError(f"Embedding mode {self._embedding_mode} is not supported.")

    @staticmethod
    def get_output_tokens(jit_model_prediction_sample: dict[str, Any]) -> tuple[torch.Tensor, torch.Tensor]:
        """Look up the CLS token and the patch tokens in the model output dict."""
        return (
            jit_model_prediction_sample[ViTTokenNames.CLS_TOKEN_NAME],
            jit_model_prediction_sample[ViTTokenNames.PATCH_TOKEN_NAME],
        )

    def embed_cls_only(self, jit_model_prediction_sample: dict[str, Any]) -> torch.Tensor:
        """Return the CLS token unchanged."""
        return self.get_output_tokens(jit_model_prediction_sample)[0]

    def embed_patch_only(self, jit_model_prediction_sample: dict[str, Any]) -> torch.Tensor:
        """Return the patch tokens unchanged."""
        return self.get_output_tokens(jit_model_prediction_sample)[1]

    def embed_mean(self, jit_model_prediction_sample: dict[str, Any]) -> torch.Tensor:
        """Average the CLS token together with every patch token along dim 1."""
        cls_token, patch_tokens = self.get_output_tokens(jit_model_prediction_sample)
        # Give the CLS token an axis at dim 1 so it can be joined with the patch tokens.
        all_tokens = torch.cat([cls_token.unsqueeze(1), patch_tokens], dim=1)
        return all_tokens.mean(dim=1)

    def embed_concat_mean(self, jit_model_prediction_sample: dict[str, Any]) -> torch.Tensor:
        """Average the CLS token with the mean (over dim 1) of the patch tokens."""
        cls_token, patch_tokens = self.get_output_tokens(jit_model_prediction_sample)
        pooled_patches = patch_tokens.mean(dim=1)
        token_pair = torch.cat([cls_token.unsqueeze(1), pooled_patches.unsqueeze(1)], dim=1)
        return token_pair.mean(dim=1)

    def embed_concat(self, jit_model_prediction_sample: dict[str, Any]) -> torch.Tensor:
        """Concatenate, along dim 1, the CLS token and the mean (over dim 1) of the patch tokens."""
        cls_token, patch_tokens = self.get_output_tokens(jit_model_prediction_sample)
        return torch.cat([cls_token, patch_tokens.mean(dim=1)], dim=1)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
from typing import Any | ||
|
||
import torch | ||
from torch.jit import ScriptModule | ||
|
||
from ahcore.models.embedding_functions.vit_embed import ViTEmbed | ||
from ahcore.models.jit_model import BaseAhcoreJitModel | ||
from ahcore.utils.types import OutputModeBase, ViTEmbedMode | ||
|
||
|
||
class DinoV2JitModel(BaseAhcoreJitModel):
    """
    Ahcore wrapper around a JIT-compiled DinoV2 foundation model.
    """

    def __init__(self, model: ScriptModule, output_mode: OutputModeBase) -> None:
        """
        Build the wrapper and select how the raw model output is post-processed.

        Parameters
        ----------
        model : ScriptModule
            The jit compiled model.
        output_mode : OutputModeBase
            The output mode of the model; it determines the forward function.

        Returns
        -------
        None
        """
        super().__init__(model=model, output_mode=output_mode)
        self._set_forward_function()

    def _set_forward_function(self) -> None:
        """Choose the post-processing applied to the model output in ``forward``."""
        if self._output_mode == ViTEmbedMode.DEFAULT:
            # Default mode: the JIT model is assumed to return a tensor already,
            # so the output is passed through untouched.
            self._forward_function = lambda output: output
            return
        # Any other mode: delegate to the embedding function ViTEmbed selects.
        self._forward_function = ViTEmbed(self._output_mode).embed_fn

    def forward(self, x: torch.Tensor) -> Any:
        """Run the wrapped model on ``x`` and post-process its output."""
        return self._forward_function(self._model(x))
Oops, something went wrong.