Add models to do classification (#106)
* Added ABMIL and TransMIL implementations
* Added utils necessary to use the models
* Simplified ABMIL and layers
* Added tests

Co-authored-by: JorenB <jorenb@gmail.com>
1 parent: f600d04
commit: d0ba0e9
Showing 6 changed files with 555 additions and 2 deletions.
@@ -0,0 +1,133 @@
from typing import List, Optional

import torch
from torch import nn

from ahcore.models.layers.attention import GatedAttention
from ahcore.models.layers.MLP import MLP


class ABMIL(nn.Module):
    """
    Attention-based MIL (Multiple Instance Learning) classification model (see [1]_).
    This model is adapted from
    https://github.com/owkin/HistoSSLscaling/blob/main/rl_benchmarks/models/slide_models/abmil.py.
    It uses an attention mechanism to aggregate features from multiple instances (tiles) into a single prediction.

    Methods
    -------
    forward(features: torch.Tensor, return_attention_weights: bool = False)
        -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]
        Forward pass of the ABMIL model.

    References
    ----------
    .. [1] Maximilian Ilse, Jakub Tomczak, and Max Welling. Attention-based
       deep multiple instance learning. In Jennifer Dy and Andreas Krause,
       editors, Proceedings of the 35th International Conference on Machine
       Learning, volume 80 of Proceedings of Machine Learning Research,
       pages 2127–2136. PMLR, 10–15 Jul 2018.
    """

    def __init__(
        self,
        in_features: int,
        num_classes: int = 1,
        attention_dimension: int = 128,
        temperature: float = 1.0,
        embed_mlp_hidden: Optional[List[int]] = None,
        embed_mlp_dropout: Optional[List[float]] = None,
        embed_mlp_activation: Optional[torch.nn.Module] = nn.ReLU(),
        embed_mlp_bias: bool = True,
        classifier_hidden: Optional[List[int]] = [128, 64],
        classifier_dropout: Optional[List[float]] = None,
        classifier_activation: Optional[torch.nn.Module] = nn.ReLU(),
        classifier_bias: bool = False,
    ) -> None:
        """
        Initializes the ABMIL model with embedding and classification layers.

        Parameters
        ----------
        in_features : int
            Number of input features for each tile.
        num_classes : int, optional
            Number of output classes (typically 1 for binary classification), by default 1.
        attention_dimension : int, optional
            Dimensionality of the attention mechanism, by default 128.
        temperature : float, optional
            Temperature parameter for scaling the attention scores, by default 1.0.
        embed_mlp_hidden : Optional[List[int]], optional
            List of hidden layer sizes for the embedding MLP, by default None.
        embed_mlp_dropout : Optional[List[float]], optional
            List of dropout rates for the embedding MLP, by default None.
        embed_mlp_activation : Optional[torch.nn.Module], optional
            Activation function for the embedding MLP, by default nn.ReLU().
        embed_mlp_bias : bool, optional
            Whether to include bias in the embedding MLP layers, by default True.
        classifier_hidden : Optional[List[int]], optional
            List of hidden layer sizes for the classifier MLP, by default [128, 64].
        classifier_dropout : Optional[List[float]], optional
            List of dropout rates for the classifier MLP, by default None.
        classifier_activation : Optional[torch.nn.Module], optional
            Activation function for the classifier MLP, by default nn.ReLU().
        classifier_bias : bool, optional
            Whether to include bias in the classifier MLP layers, by default False.
        """
        super(ABMIL, self).__init__()

        # Per-tile embedding: in_features -> attention_dimension.
        self.embed_mlp = MLP(
            in_features=in_features,
            hidden=embed_mlp_hidden,
            bias=embed_mlp_bias,
            out_features=attention_dimension,
            dropout=embed_mlp_dropout,
            activation=embed_mlp_activation,
        )

        # Gated attention pooling over tiles (Ilse et al., 2018).
        self.attention_layer = GatedAttention(dim=attention_dimension, temperature=temperature)

        # Slide-level classification head.
        self.classifier = MLP(
            in_features=attention_dimension,
            out_features=num_classes,
            bias=classifier_bias,
            hidden=classifier_hidden,
            dropout=classifier_dropout,
            activation=classifier_activation,
        )

    def forward(
        self,
        features: torch.Tensor,
        return_attention_weights: bool = False,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass of the ABMIL model.

        Parameters
        ----------
        features : torch.Tensor
            Input tensor of shape (batch_size, n_tiles, in_features) representing the features of tiles.
        return_attention_weights : bool, optional
            If True, also returns the attention weights, by default False.

        Returns
        -------
        torch.Tensor
            Logits representing the model's output.
        torch.Tensor, optional
            Attention weights, returned if return_attention_weights is True.
        """
        tiles_emb = self.embed_mlp(features)  # (B, n_tiles, in_features) --> (B, n_tiles, attention_dimension)
        scaled_tiles_emb, attention_weights = self.attention_layer(
            tiles_emb, return_attention_weights=True
        )  # (B, n_tiles, attention_dimension) --> (B, attention_dimension)
        logits: torch.Tensor = self.classifier(scaled_tiles_emb)  # (B, attention_dimension) --> (B, num_classes)

        if return_attention_weights:
            return logits, attention_weights

        return logits
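
A minimal usage sketch (not part of the commit), assuming the ahcore GatedAttention layer imported above pools the tile embeddings into a single slide-level vector and returns per-tile weights, as the forward pass relies on; the feature dimension and bag size below are arbitrary illustrative choices.

# Hedged example: smoke-test ABMIL on random tile features.
import torch

model = ABMIL(in_features=768, num_classes=1)   # one logit per slide for binary classification
bag = torch.randn(4, 1000, 768)                 # (batch_size, n_tiles, in_features)
logits = model(bag)                             # (4, 1)
logits, attention = model(bag, return_attention_weights=True)  # also get the per-tile attention weights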
@@ -0,0 +1,96 @@
# This file includes the original Nystrom attention and TransMIL models
# from https://github.com/lucidrains/nystrom-attention/blob/main/nystrom_attention/nystrom_attention.py
# and https://github.com/szc19990412/TransMIL/blob/main/models/TransMIL.py, respectively.

from typing import Any

import numpy as np
import torch
from torch import nn

from ahcore.models.layers.attention import NystromAttention


class TransLayer(nn.Module):
    """Transformer block: layer norm followed by Nystrom self-attention with a residual connection."""

    def __init__(self, norm_layer: type = nn.LayerNorm, dim: int = 512) -> None:
        super().__init__()
        self.norm = norm_layer(dim)
        self.attn = NystromAttention(
            dim=dim,
            dim_head=dim // 8,
            heads=8,
            num_landmarks=dim // 2,  # number of landmarks
            pinv_iterations=6,
            # number of Moore-Penrose iterations for approximating the pseudoinverse; 6 was recommended by the paper
            residual=True,
            # whether to do an extra residual with the value or not; supposedly faster convergence if turned on
            dropout=0.1,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(self.norm(x))

        return x


class PPEG(nn.Module):
    """Pyramid Position Encoding Generator: depthwise convolutions (kernels 7, 5, 3) over the tile grid."""

    def __init__(self, dim: int = 512) -> None:
        super(PPEG, self).__init__()
        self.proj = nn.Conv2d(dim, dim, 7, 1, 7 // 2, groups=dim)
        self.proj1 = nn.Conv2d(dim, dim, 5, 1, 5 // 2, groups=dim)
        self.proj2 = nn.Conv2d(dim, dim, 3, 1, 3 // 2, groups=dim)

    def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
        B, _, C = x.shape
        cls_token, feat_token = x[:, 0], x[:, 1:]
        # Reshape the tile tokens into an (H, W) grid so the convolutions act on spatial neighbours.
        cnn_feat = feat_token.transpose(1, 2).view(B, C, H, W)
        x = self.proj(cnn_feat) + cnn_feat + self.proj1(cnn_feat) + self.proj2(cnn_feat)
        x = x.flatten(2).transpose(1, 2)
        x = torch.cat((cls_token.unsqueeze(1), x), dim=1)
        return x


class TransMIL(nn.Module):
    def __init__(self, in_features: int = 1024, num_classes: int = 1, hidden_dimension: int = 512) -> None:
        super(TransMIL, self).__init__()
        self.pos_layer = PPEG(dim=hidden_dimension)
        self._fc1 = nn.Sequential(nn.Linear(in_features, hidden_dimension), nn.ReLU())
        self.cls_token = nn.Parameter(torch.randn(1, 1, hidden_dimension))
        self.n_classes = num_classes
        self.layer1 = TransLayer(dim=hidden_dimension)
        self.layer2 = TransLayer(dim=hidden_dimension)
        self.norm = nn.LayerNorm(hidden_dimension)
        self._fc2 = nn.Linear(hidden_dimension, self.n_classes)

    def forward(self, features: torch.Tensor, **kwargs: Any) -> torch.Tensor:
        h = features  # [B, n, in_features]

        h = self._fc1(h)  # [B, n, hidden_dimension]

        # ----> pad the sequence to the next square number so the tiles fit on an _H x _W grid for PPEG
        H = h.shape[1]
        _H, _W = int(np.ceil(np.sqrt(H))), int(np.ceil(np.sqrt(H)))
        add_length = _H * _W - H
        h = torch.cat([h, h[:, :add_length, :]], dim=1)  # [B, N, hidden_dimension]

        # ----> cls_token
        B = h.shape[0]
        cls_tokens = self.cls_token.expand(B, -1, -1).to(h.device)
        h = torch.cat((cls_tokens, h), dim=1)

        # ----> TransLayer x1
        h = self.layer1(h)  # [B, N, hidden_dimension]

        # ----> PPEG
        h = self.pos_layer(h, _H, _W)  # [B, N, hidden_dimension]

        # ----> TransLayer x2
        h = self.layer2(h)  # [B, N, hidden_dimension]

        # ----> cls_token
        h = self.norm(h)[:, 0]

        # ----> predict
        logits: torch.Tensor = self._fc2(h)  # [B, num_classes]

        return logits
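
A minimal usage sketch (not part of the commit), assuming the ahcore NystromAttention layer behaves like the lucidrains implementation referenced at the top of the file; the feature dimension, bag size and class count are arbitrary illustrative choices. The forward pass pads the bag with repeated tiles up to the next square number so PPEG can treat it as a grid.

# Hedged example: smoke-test TransMIL on random tile features.
import torch

model = TransMIL(in_features=1024, num_classes=2, hidden_dimension=512)
bag = torch.randn(2, 500, 1024)  # (batch_size, n_tiles, in_features); padded internally to 23 x 23 = 529 tiles
logits = model(bag)              # (2, 2)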
@@ -0,0 +1,65 @@
from typing import List, Optional

from torch import nn

"""Most of this is adapted from the utils in https://github.com/owkin/HistoSSLscaling/tree/main"""


class MLP(nn.Sequential):
    """MLP Module.

    Parameters
    ----------
    in_features: int
        Features (model input) dimension.
    out_features: int
        Prediction (model output) dimension.
    hidden: Optional[List[int]] = None
        Dimension of hidden layer(s).
    dropout: Optional[List[float]] = None
        Dropout rate(s).
    activation: Optional[torch.nn.Module] = torch.nn.ReLU()
        MLP activation.
    bias: bool = True
        Add bias to MLP hidden layers.

    Raises
    ------
    ValueError
        If ``dropout`` is given without ``hidden``; an assertion additionally checks that ``hidden``
        and ``dropout`` share the same length.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        hidden: Optional[List[int]] = None,
        dropout: Optional[List[float]] = None,
        activation: Optional[nn.Module] = nn.ReLU(),
        bias: bool = True,
    ):
        if dropout is not None:
            if hidden is not None:
                assert len(hidden) == len(dropout), "hidden and dropout must have the same length"
            else:
                raise ValueError("hidden must have a value and have the same length as dropout if dropout is given.")

        d_model = in_features
        layers: list[nn.Module] = []

        if hidden is not None:
            for i, h in enumerate(hidden):
                # Each hidden block is Linear -> (activation) -> (dropout).
                seq: list[nn.Module] = [nn.Linear(d_model, h, bias=bias)]
                d_model = h

                if activation is not None:
                    seq.append(activation)

                if dropout is not None:
                    seq.append(nn.Dropout(dropout[i]))

                layers.append(nn.Sequential(*seq))

        # Final projection to the output dimension.
        layers.append(nn.Linear(d_model, out_features))

        super(MLP, self).__init__(*layers)
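
A minimal usage sketch (not part of the commit) showing how ``hidden``, ``dropout`` and ``activation`` interact; the sizes are arbitrary illustrative choices. Each entry in ``hidden`` produces a Linear -> activation -> dropout block, followed by a final projection to ``out_features``.

# Hedged example: a 768 -> 256 -> 128 -> 1 head with dropout after each hidden layer.
import torch

mlp = MLP(
    in_features=768,
    out_features=1,
    hidden=[256, 128],
    dropout=[0.1, 0.1],  # must match len(hidden) when given
    activation=nn.ReLU(),
    bias=True,
)
x = torch.randn(32, 768)  # (batch_size, in_features)
y = mlp(x)                # (32, 1)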