Skip to content
This repository has been archived by the owner on Oct 19, 2024. It is now read-only.

Commit

Permalink
Merge branch 'feature/extend-jit-models' into feature/feature_embedding_dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
AjeyPaiK committed May 31, 2024
2 parents 82a1a73 + 87bff9c commit ca9cbc2
Show file tree
Hide file tree
Showing 8 changed files with 197 additions and 220 deletions.
6 changes: 3 additions & 3 deletions ahcore/lit_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from ahcore.exceptions import ConfigurationError
from ahcore.metrics import MetricFactory, WSIMetricFactory
from ahcore.models.jit_model import AhcoreJitModel
from ahcore.models.jit_model import BaseAhcoreJitModel
from ahcore.utils.data import DataDescription
from ahcore.utils.io import get_logger
from ahcore.utils.types import DlupDatasetSample
Expand All @@ -38,7 +38,7 @@ class AhCoreLightningModule(pl.LightningModule):

def __init__(
self,
model: nn.Module | AhcoreJitModel,
model: nn.Module | BaseAhcoreJitModel,
optimizer: torch.optim.Optimizer, # noqa
data_description: DataDescription,
loss: nn.Module | None = None,
Expand All @@ -58,7 +58,7 @@ def __init__(
"loss",
],
) # TODO: we should send the hyperparams to the logger elsewhere
if isinstance(model, AhcoreJitModel):
if isinstance(model, BaseAhcoreJitModel):
self._model = model
elif isinstance(model, functools.partial):
try:
Expand Down
Empty file.
106 changes: 106 additions & 0 deletions ahcore/models/embedding_functions/vit_embed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, Callable

import torch

from ahcore.utils.types import OutputModeBase, ViTEmbedMode


class _TokenNames(str, Enum):
    """Base class for string-valued enums whose members print as their raw value.

    The ``str`` mixin lets members be used wherever a plain string is expected
    (for example as dictionary keys). ``__str__`` is overridden so that
    ``str(member)`` yields the member's value instead of ``ClassName.MEMBER``.
    """

    def __str__(self) -> str:
        """Return the member's underlying string value."""
        # The value of a ``str``-mixin member is already a string; the explicit
        # ``str()`` call only makes the declared return type hold statically.
        return str(self.value)


class ViTTokenNames(_TokenNames):
    """Keys under which the ViT model output dictionary stores its tokens."""

    # NOTE(review): the "x_norm_" prefix presumably mirrors the key names of the
    # wrapped ViT's feature dictionary (e.g. DINOv2) — confirm against the model.

    # Key used to look up the class (CLS) token in the model's output dictionary.
    CLS_TOKEN_NAME = "x_norm_clstoken"
    # Key used to look up the patch tokens in the model's output dictionary.
    PATCH_TOKEN_NAME = "x_norm_patchtokens"


class AhCoreFeatureEmbedding(ABC):
    """Abstract base for objects that map a model's output dictionary to an embedding.

    A concrete subclass implements ``_set_embed_function`` to bind ``embed_fn``
    to the callable matching the configured embedding mode, and overrides
    ``token_names``.
    """

    def __init__(self, embedding_mode: OutputModeBase):
        """Store the embedding mode.

        Parameters
        ----------
        embedding_mode : OutputModeBase
            The mode that determines which embedding function is used.
        """
        # Annotation only: ``embed_fn`` is not assigned here. Subclasses bind it
        # in their ``_set_embed_function`` implementation.
        self.embed_fn: Callable[[dict[str, Any]], torch.Tensor]
        self._embedding_mode = embedding_mode

    @property
    def embedding_mode(self) -> Any:
        """The ``value`` of the configured embedding mode."""
        mode = self._embedding_mode
        return mode.value

    @property
    def token_names(self) -> Any:
        """Names of the tokens the embedding consumes; subclasses must override.

        Raises
        ------
        NotImplementedError
            Always, in this base implementation.
        """
        raise NotImplementedError

    @abstractmethod
    def _set_embed_function(self) -> None:
        """Bind ``self.embed_fn`` according to the configured embedding mode."""
        raise NotImplementedError


class ViTEmbed(AhCoreFeatureEmbedding):
    """Embedding strategies for a ViT output dictionary holding a CLS token and patch tokens.

    On construction, ``embed_fn`` is bound to the method matching the requested
    embedding mode. Each ``embed_*`` method takes the model's output dictionary
    and returns a tensor.

    NOTE(review): the methods below assume the CLS token is shaped
    (batch, dim) and the patch tokens (batch, n_patches, dim) — confirm against
    the wrapped model.
    """

    def __init__(self, embedding_mode: OutputModeBase):
        """Store the mode and bind the matching embedding function.

        Parameters
        ----------
        embedding_mode : OutputModeBase
            The mode that selects which ``embed_*`` method becomes ``embed_fn``.

        Raises
        ------
        NotImplementedError
            If the mode matches none of the supported modes.
        """
        super().__init__(embedding_mode=embedding_mode)
        self._set_embed_function()

    @property
    def dim_factor(self) -> int:
        """Factor by which the embedding method scales the output feature dimensionality.

        Only the concat mode enlarges the output (it doubles it); every other
        mode keeps the dimensionality unchanged.
        """
        return 2 if self._embedding_mode == ViTEmbedMode.CONCAT else 1

    @property
    def token_names(self) -> tuple[ViTTokenNames, ViTTokenNames]:
        """The (CLS token name, patch token name) pair, in that order."""
        names = (ViTTokenNames.CLS_TOKEN_NAME, ViTTokenNames.PATCH_TOKEN_NAME)
        return names

    def _set_embed_function(self) -> None:
        """Bind ``embed_fn`` to the method that implements the configured mode."""
        # Ordered (mode, method) pairs; compared with ``==`` in sequence so the
        # matching semantics are those of an if/elif chain.
        candidates = (
            (ViTEmbedMode.CLS_ONLY, self.embed_cls_only),
            (ViTEmbedMode.PATCH_ONLY, self.embed_patch_only),
            (ViTEmbedMode.MEAN, self.embed_mean),
            (ViTEmbedMode.CONCAT_MEAN, self.embed_concat_mean),
            (ViTEmbedMode.CONCAT, self.embed_concat),
        )
        for supported_mode, method in candidates:
            if self._embedding_mode == supported_mode:
                self.embed_fn = method
                return
        raise NotImplementedError(f"Embedding mode {self._embedding_mode} is not supported.")

    @staticmethod
    def get_output_tokens(jit_model_prediction_sample: dict[str, Any]) -> tuple[torch.Tensor, torch.Tensor]:
        """Look up the CLS token and the patch tokens in the model output dictionary.

        Returns
        -------
        tuple[torch.Tensor, torch.Tensor]
            The CLS token entry followed by the patch tokens entry.
        """
        return (
            jit_model_prediction_sample[ViTTokenNames.CLS_TOKEN_NAME],
            jit_model_prediction_sample[ViTTokenNames.PATCH_TOKEN_NAME],
        )

    def embed_cls_only(self, jit_model_prediction_sample: dict[str, Any]) -> torch.Tensor:
        """Return the CLS token unchanged."""
        return self.get_output_tokens(jit_model_prediction_sample)[0]

    def embed_patch_only(self, jit_model_prediction_sample: dict[str, Any]) -> torch.Tensor:
        """Return the patch tokens unchanged."""
        return self.get_output_tokens(jit_model_prediction_sample)[1]

    def embed_mean(self, jit_model_prediction_sample: dict[str, Any]) -> torch.Tensor:
        """Return the mean over dim 1 of the CLS token joined with all patch tokens."""
        cls, patches = self.get_output_tokens(jit_model_prediction_sample)
        # The CLS token gets a dim-1 axis so it can be joined with the patch tokens.
        return torch.cat([cls.unsqueeze(1), patches], dim=1).mean(dim=1)

    def embed_concat_mean(self, jit_model_prediction_sample: dict[str, Any]) -> torch.Tensor:
        """Return the average of the CLS token and the dim-1 mean of the patch tokens."""
        cls, patches = self.get_output_tokens(jit_model_prediction_sample)
        patch_average = patches.mean(dim=1)
        # Stacking on a new dim-1 axis and averaging over it weights the CLS
        # token and the averaged patch token equally.
        return torch.stack([cls, patch_average], dim=1).mean(dim=1)

    def embed_concat(self, jit_model_prediction_sample: dict[str, Any]) -> torch.Tensor:
        """Return the CLS token concatenated along dim 1 with the dim-1 mean of the patch tokens."""
        cls, patches = self.get_output_tokens(jit_model_prediction_sample)
        return torch.cat([cls, patches.mean(dim=1)], dim=1)
44 changes: 44 additions & 0 deletions ahcore/models/feature_extractors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from typing import Any

import torch
from torch.jit import ScriptModule

from ahcore.models.embedding_functions.vit_embed import ViTEmbed
from ahcore.models.jit_model import BaseAhcoreJitModel
from ahcore.utils.types import OutputModeBase, ViTEmbedMode


class DinoV2JitModel(BaseAhcoreJitModel):
    """
    Wrapper around the DinoV2 foundation model for use in Ahcore.
    """

    def __init__(self, model: ScriptModule, output_mode: OutputModeBase) -> None:
        """
        Construct the DinoV2 wrapper.
        This can serve as a feature extractor based on a very large pre-trained model.

        Parameters
        ----------
        model : ScriptModule
            The jit compiled model.
        output_mode : OutputModeBase
            The output mode of the model. It determines the function applied to the
            model's output in the forward pass.

        Returns
        -------
        None
        """
        super().__init__(model=model, output_mode=output_mode)
        self._set_forward_function()

    def _set_forward_function(self) -> None:
        """Select the post-processing applied to the raw model output."""
        if self._output_mode == ViTEmbedMode.DEFAULT:
            # Default mode: the JIT model is assumed to return a tensor already,
            # so its output is passed through untouched.
            self._forward_function = lambda raw_output: raw_output
            return
        # Any other mode: delegate to the embedding function chosen by ViTEmbed.
        self._forward_function = ViTEmbed(self._output_mode).embed_fn

    def forward(self, x: torch.Tensor) -> Any:
        """Run the wrapped model on ``x`` and post-process its output.

        Parameters
        ----------
        x : torch.Tensor
            The input passed to the wrapped model.

        Returns
        -------
        Any
            The result of applying the selected forward function to the model output.
        """
        return self._forward_function(self._model(x))
Loading

0 comments on commit ca9cbc2

Please sign in to comment.