Add models to do classification (#106)
* Added ABmil and transmil implementations

* added utils necessary to use models

* simplified abmil and layers

* added tests

---------

Co-authored-by: JorenB <jorenb@gmail.com>
moerlemans and JorenB authored Sep 6, 2024
1 parent f600d04 commit d0ba0e9
Showing 6 changed files with 555 additions and 2 deletions.
133 changes: 133 additions & 0 deletions ahcore/models/MIL/ABmil.py
@@ -0,0 +1,133 @@
from typing import List, Optional

import torch
from torch import nn

from ahcore.models.layers.attention import GatedAttention
from ahcore.models.layers.MLP import MLP


class ABMIL(nn.Module):
    """
    Attention-based MIL (Multiple Instance Learning) classification model (see [1]_).
    This model is adapted from
    https://github.com/owkin/HistoSSLscaling/blob/main/rl_benchmarks/models/slide_models/abmil.py.
    It uses an attention mechanism to aggregate features from multiple instances (tiles) into a single prediction.

    Methods
    -------
    forward(features: torch.Tensor, return_attention_weights: bool = False)
        -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]
        Forward pass of the ABMIL model.

    References
    ----------
    .. [1] Maximilian Ilse, Jakub Tomczak, and Max Welling. Attention-based
       deep multiple instance learning. In Jennifer Dy and Andreas Krause,
       editors, Proceedings of the 35th International Conference on Machine
       Learning, volume 80 of Proceedings of Machine Learning Research,
       pages 2127–2136. PMLR, 10–15 Jul 2018.
    """

    def __init__(
        self,
        in_features: int,
        num_classes: int = 1,
        attention_dimension: int = 128,
        temperature: float = 1.0,
        embed_mlp_hidden: Optional[List[int]] = None,
        embed_mlp_dropout: Optional[List[float]] = None,
        embed_mlp_activation: Optional[torch.nn.Module] = nn.ReLU(),
        embed_mlp_bias: bool = True,
        classifier_hidden: Optional[List[int]] = [128, 64],
        classifier_dropout: Optional[List[float]] = None,
        classifier_activation: Optional[torch.nn.Module] = nn.ReLU(),
        classifier_bias: bool = False,
    ) -> None:
        """
        Initializes the ABMIL model with embedding and classification layers.

        Parameters
        ----------
        in_features : int
            Number of input features for each tile.
        num_classes : int, optional
            Number of output classes (typically 1 for binary classification), by default 1.
        attention_dimension : int, optional
            Dimensionality of the attention mechanism, by default 128.
        temperature : float, optional
            Temperature parameter for scaling the attention scores, by default 1.0.
        embed_mlp_hidden : Optional[List[int]], optional
            List of hidden layer sizes for the embedding MLP, by default None.
        embed_mlp_dropout : Optional[List[float]], optional
            List of dropout rates for the embedding MLP, by default None.
        embed_mlp_activation : Optional[torch.nn.Module], optional
            Activation function for the embedding MLP, by default nn.ReLU().
        embed_mlp_bias : bool, optional
            Whether to include bias in the embedding MLP layers, by default True.
        classifier_hidden : Optional[List[int]], optional
            List of hidden layer sizes for the classifier MLP, by default [128, 64].
        classifier_dropout : Optional[List[float]], optional
            List of dropout rates for the classifier MLP, by default None.
        classifier_activation : Optional[torch.nn.Module], optional
            Activation function for the classifier MLP, by default nn.ReLU().
        classifier_bias : bool, optional
            Whether to include bias in the classifier MLP layers, by default False.
        """
        super(ABMIL, self).__init__()

        self.embed_mlp = MLP(
            in_features=in_features,
            hidden=embed_mlp_hidden,
            bias=embed_mlp_bias,
            out_features=attention_dimension,
            dropout=embed_mlp_dropout,
            activation=embed_mlp_activation,
        )

        self.attention_layer = GatedAttention(dim=attention_dimension, temperature=temperature)

        self.classifier = MLP(
            in_features=attention_dimension,
            out_features=num_classes,
            bias=classifier_bias,
            hidden=classifier_hidden,
            dropout=classifier_dropout,
            activation=classifier_activation,
        )

    def forward(
        self,
        features: torch.Tensor,
        return_attention_weights: bool = False,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass of the ABMIL model.

        Parameters
        ----------
        features : torch.Tensor
            Input tensor of shape (batch_size, n_tiles, in_features) representing the features of tiles.
        return_attention_weights : bool, optional
            If True, also returns the attention weights, by default False.

        Returns
        -------
        torch.Tensor
            Logits representing the model's output.
        torch.Tensor, optional
            Attention weights, returned if return_attention_weights is True.
        """
        tiles_emb = self.embed_mlp(features)  # BxN_tilesxN_features --> BxN_tilesx128
        scaled_tiles_emb, attention_weights = self.attention_layer(
            tiles_emb, return_attention_weights=True
        )  # BxN_tilesx128 --> Bx128
        logits: torch.Tensor = self.classifier(scaled_tiles_emb)  # Bx128 --> Bxnum_classes

        if return_attention_weights:
            return logits, attention_weights

        return logits
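For orientation, a minimal usage sketch of the new ABMIL head (not part of the commit). The feature dimension of 768, the bag of 1000 tiles and the batch size of 4 are illustrative assumptions; the shape of the returned attention weights depends on the GatedAttention layer.

import torch

from ahcore.models.MIL.ABmil import ABMIL

# Illustrative sketch only, not part of the committed code.
model = ABMIL(in_features=768, num_classes=1)
features = torch.randn(4, 1000, 768)  # (batch_size, n_tiles, in_features)

logits = model(features)  # expected shape: (4, 1)
logits, attention_weights = model(features, return_attention_weights=True)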
96 changes: 96 additions & 0 deletions ahcore/models/MIL/transmil.py
@@ -0,0 +1,96 @@
# This file includes the original Nystrom attention and TransMIL model
# from https://github.com/lucidrains/nystrom-attention/blob/main/nystrom_attention/nystrom_attention.py
# and https://github.com/szc19990412/TransMIL/blob/main/models/TransMIL.py, respectively.

from typing import Any

import numpy as np
import torch
from torch import nn as nn

from ahcore.models.layers.attention import NystromAttention


class TransLayer(nn.Module):
    def __init__(self, norm_layer: type = nn.LayerNorm, dim: int = 512) -> None:
        super().__init__()
        self.norm = norm_layer(dim)
        self.attn = NystromAttention(
            dim=dim,
            dim_head=dim // 8,
            heads=8,
            num_landmarks=dim // 2,  # number of landmarks
            pinv_iterations=6,
            # number of Moore-Penrose iterations for approximating the pseudoinverse. 6 was recommended by the paper
            residual=True,
            # whether to do an extra residual with the value or not. supposedly faster convergence if turned on
            dropout=0.1,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(self.norm(x))

        return x


class PPEG(nn.Module):
    def __init__(self, dim: int = 512) -> None:
        super(PPEG, self).__init__()
        self.proj = nn.Conv2d(dim, dim, 7, 1, 7 // 2, groups=dim)
        self.proj1 = nn.Conv2d(dim, dim, 5, 1, 5 // 2, groups=dim)
        self.proj2 = nn.Conv2d(dim, dim, 3, 1, 3 // 2, groups=dim)

    def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
        B, _, C = x.shape
        cls_token, feat_token = x[:, 0], x[:, 1:]
        cnn_feat = feat_token.transpose(1, 2).view(B, C, H, W)
        x = self.proj(cnn_feat) + cnn_feat + self.proj1(cnn_feat) + self.proj2(cnn_feat)
        x = x.flatten(2).transpose(1, 2)
        x = torch.cat((cls_token.unsqueeze(1), x), dim=1)
        return x


class TransMIL(nn.Module):
    def __init__(self, in_features: int = 1024, num_classes: int = 1, hidden_dimension: int = 512) -> None:
        super(TransMIL, self).__init__()
        self.pos_layer = PPEG(dim=hidden_dimension)
        self._fc1 = nn.Sequential(nn.Linear(in_features, hidden_dimension), nn.ReLU())
        self.cls_token = nn.Parameter(torch.randn(1, 1, hidden_dimension))
        self.n_classes = num_classes
        self.layer1 = TransLayer(dim=hidden_dimension)
        self.layer2 = TransLayer(dim=hidden_dimension)
        self.norm = nn.LayerNorm(hidden_dimension)
        self._fc2 = nn.Linear(hidden_dimension, self.n_classes)

    def forward(self, features: torch.Tensor, **kwargs: Any) -> torch.Tensor:
        h = features  # [B, n, in_features]

        h = self._fc1(h)  # [B, n, hidden_dimension]

        # ----> pad to a square number of tiles
        H = h.shape[1]
        _H, _W = int(np.ceil(np.sqrt(H))), int(np.ceil(np.sqrt(H)))
        add_length = _H * _W - H
        h = torch.cat([h, h[:, :add_length, :]], dim=1)  # [B, N, hidden_dimension]

        # ----> cls_token
        B = h.shape[0]
        cls_tokens = self.cls_token.expand(B, -1, -1).to(h.device)
        h = torch.cat((cls_tokens, h), dim=1)

        # ----> TransLayer x1
        h = self.layer1(h)  # [B, N, hidden_dimension]

        # ----> PPEG
        h = self.pos_layer(h, _H, _W)  # [B, N, hidden_dimension]

        # ----> TransLayer x2
        h = self.layer2(h)  # [B, N, hidden_dimension]

        # ----> cls_token
        h = self.norm(h)[:, 0]

        # ----> predict
        logits: torch.Tensor = self._fc2(h)  # [B, num_classes]

        return logits
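Likewise, a minimal usage sketch for TransMIL (not part of the commit), assuming a bag of 500 tiles with 1024-dimensional features; internally the bag is padded to the next square number of tiles (529 here) before the PPEG positional encoding is applied.

import torch

from ahcore.models.MIL.transmil import TransMIL

# Illustrative sketch only, not part of the committed code.
model = TransMIL(in_features=1024, num_classes=2, hidden_dimension=512)
features = torch.randn(2, 500, 1024)  # (batch_size, n_tiles, in_features)

logits = model(features)  # expected shape: (2, 2)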
4 changes: 2 additions & 2 deletions ahcore/models/base_jit_model.py
@@ -1,8 +1,8 @@
 from pathlib import Path
 from typing import Any
 
+from torch import nn
 from torch.jit import ScriptModule, load
-from torch.nn import Module
 
 
 class BaseAhcoreJitModel(ScriptModule):
@@ -46,7 +46,7 @@ def from_jit_path(cls, jit_path: Path, output_mode: str) -> Any:
         model = load(jit_path)  # type: ignore
         return cls(model)
 
-    def extend_model(self, modules: dict[str, Module]) -> None:
+    def extend_model(self, modules: dict[str, nn.Module]) -> None:
         """
         Add modules to a jit compiled model.
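The change above only swaps the bare Module annotation for nn.Module; behaviour is unchanged. A hedged sketch of how extend_model might be used (the checkpoint path, the output_mode value and the layer sizes are illustrative assumptions, not values documented in this commit):

from pathlib import Path

from torch import nn

from ahcore.models.base_jit_model import BaseAhcoreJitModel

# Illustrative sketch only; path, output_mode and dimensions are hypothetical.
model = BaseAhcoreJitModel.from_jit_path(Path("encoder.pt"), output_mode="features")
model.extend_model({"classifier": nn.Linear(512, 2)})  # attach an extra head by name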
65 changes: 65 additions & 0 deletions ahcore/models/layers/MLP.py
@@ -0,0 +1,65 @@
from typing import List, Optional

from torch import nn

"""Most of this module is adapted from the utilities in https://github.com/owkin/HistoSSLscaling/tree/main"""


class MLP(nn.Sequential):
    """MLP Module.

    Parameters
    ----------
    in_features: int
        Features (model input) dimension.
    out_features: int = 1
        Prediction (model output) dimension.
    hidden: Optional[List[int]] = None
        Dimension of hidden layer(s).
    dropout: Optional[List[float]] = None
        Dropout rate(s).
    activation: Optional[torch.nn.Module] = torch.nn.ReLU()
        MLP activation.
    bias: bool = True
        Add bias to MLP hidden layers.

    Raises
    ------
    ValueError
        If ``hidden`` and ``dropout`` do not share the same length.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        hidden: Optional[List[int]] = None,
        dropout: Optional[List[float]] = None,
        activation: Optional[nn.Module] = nn.ReLU(),
        bias: bool = True,
    ):
        if dropout is not None:
            if hidden is not None:
                assert len(hidden) == len(dropout), "hidden and dropout must have the same length"
            else:
                raise ValueError("hidden must have a value and have the same length as dropout if dropout is given.")

        d_model = in_features
        layers: list[nn.Module] = []

        if hidden is not None:
            for i, h in enumerate(hidden):
                seq: list[nn.Module] = [nn.Linear(d_model, h, bias=bias)]
                d_model = h

                if activation is not None:
                    seq.append(activation)

                if dropout is not None:
                    seq.append(nn.Dropout(dropout[i]))

                layers.append(nn.Sequential(*seq))

        layers.append(nn.Linear(d_model, out_features))

        super(MLP, self).__init__(*layers)
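Finally, a minimal usage sketch of the MLP building block (not part of the commit), assuming a 768-dimensional input; note that dropout, when given, must have the same length as hidden.

import torch
from torch import nn

from ahcore.models.layers.MLP import MLP

# Illustrative sketch only, not part of the committed code.
mlp = MLP(
    in_features=768,
    out_features=1,
    hidden=[256, 128],
    dropout=[0.1, 0.1],
    activation=nn.ReLU(),
    bias=True,
)
out = mlp(torch.randn(16, 768))  # expected shape: (16, 1)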