Skip to content

Commit

Permalink
Add mean pooling layer (#187)
Browse files Browse the repository at this point in the history
  • Loading branch information
dccastro authored Feb 18, 2022
1 parent b8b5298 commit 2bc397b
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 14 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ the section headers (Added/Changed/...) and incrementing the package version.
the environment file since it is necessary for the augmentations.
- ([#178](https://github.com/microsoft/hi-ml/pull/178)) Add runner script for running ML experiments.
- ([#181](https://github.com/microsoft/hi-ml/pull/181)) Add computational pathology tools in hi-ml-histopathology folder.
- ([#187](https://github.com/microsoft/hi-ml/pull/187)) Add mean pooling layer for MIL.

### Changed

Expand Down
18 changes: 17 additions & 1 deletion hi-ml/src/health_ml/networks/layers/attention_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,27 @@
Created using the original DeepMIL paper and code from Ilse et al., 2018
https://github.com/AMLab-Amsterdam/AttentionDeepMIL (MIT License)
"""
from typing import Tuple
from typing import Any, Tuple
from torch import nn, Tensor, transpose, mm
import torch
import torch.nn.functional as F


class MeanPoolingLayer(nn.Module):
    """Mean pooling: returns uniform attention weights and the average feature vector over the first axis.

    Drop-in, parameter-free alternative to the attention-based pooling layers in this module:
    ``forward`` has the same ``(weights, pooled)`` return contract.
    """

    # args/kwargs added here for compatibility with parametrised pooling modules
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__()

    def forward(self, features: Tensor) -> Tuple[Tensor, Tensor]:
        """Pool a bag of instance features by plain averaging.

        :param features: Tensor whose first axis indexes the N instances in the bag
            (shape ``(N, ...)``); must be non-empty, otherwise 1/N divides by zero.
        :return: Tuple ``(A, M)`` where ``A`` has shape ``(1, N)`` with each entry ``1/N``,
            and ``M`` is the mean feature vector flattened to shape ``(1, -1)``.
        """
        num_instances = features.shape[0]
        # new_full keeps the weights on the same device and with the same dtype as the
        # input; a bare torch.full would allocate on the CPU with the default dtype and
        # mismatch GPU or float64 inputs.
        A = features.new_full((1, num_instances), 1. / num_instances)
        M = features.mean(dim=0).view(1, -1)
        return (A, M)


class AttentionLayer(nn.Module):
""" AttentionLayer: Simple attention layer
Requires size of input L, hidden D, and attention layers K (default K=1)
Expand Down
38 changes: 25 additions & 13 deletions hi-ml/testhiml/testhiml/test_attentionlayers.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,43 @@
import pytest
from typing import Type, Union

from torch import rand, sum, allclose, ones_like
from torch import nn, rand, sum, allclose, ones_like

from health_ml.networks.layers.attention_layers import AttentionLayer, GatedAttentionLayer
from health_ml.networks.layers.attention_layers import (AttentionLayer, GatedAttentionLayer,
MeanPoolingLayer)


def _test_attention_layer(attentionlayer: nn.Module, dim_in: int, dim_att: int,
                          batch_size: int,) -> None:
    """Shared checks for any pooling module returning (attention weights, pooled features)."""
    inputs = rand(batch_size, dim_in)  # bag of N instances, each of dimension L
    weights, pooled = attentionlayer(inputs)

    # Attention weights must be K x N, lie in [0, 1], and each row must sum to one
    assert weights.shape == (dim_att, batch_size)
    assert ((weights >= 0) & (weights <= 1)).all()
    totals = sum(weights, dim=1, keepdim=True)
    assert allclose(totals, ones_like(totals))

    # Pooled output must be K x L and equal the attention-weighted average of the inputs
    assert pooled.shape == (dim_att, dim_in)
    assert allclose(weights @ inputs.flatten(start_dim=1), pooled)


@pytest.mark.parametrize("dim_in", [1, 3])
@pytest.mark.parametrize("dim_hid", [1, 4])
@pytest.mark.parametrize("dim_att", [1, 5])
@pytest.mark.parametrize("batch_size", [1, 7])
@pytest.mark.parametrize('attention_layer_cls', [AttentionLayer, GatedAttentionLayer])
def test_attentionlayer(dim_in: int, dim_hid: int, dim_att: int, batch_size: int,
                        attention_layer_cls: Type[Union[AttentionLayer, GatedAttentionLayer]]) -> None:
    """Instantiate the parametrised attention layer and run the shared pooling checks."""
    layer = attention_layer_cls(input_dims=dim_in,
                                hidden_dims=dim_hid,
                                attention_dims=dim_att)
    _test_attention_layer(layer, dim_in, dim_att, batch_size)

@pytest.mark.parametrize("dim_in", [1, 3])
@pytest.mark.parametrize("batch_size", [1, 7])
def test_mean_pooling(dim_in: int, batch_size: int,) -> None:
    """Mean pooling is a single-head (dim_att=1) pooling module; reuse the shared checks."""
    layer = MeanPoolingLayer()
    _test_attention_layer(layer, dim_in=dim_in, dim_att=1, batch_size=batch_size)

1 comment on commit 2bc397b

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Filename Stmts Miss Cover Missing
/home/runner/work/hi-ml/hi-ml/hi-ml-histopathology/src/SSL/encoders.py 37 0 100.00%
/home/runner/work/hi-ml/hi-ml/hi-ml-histopathology/src/SSL/lightning_metrics.py 101 27 73.27% 44-45,53,58-59,67,72-73,81,95,120,134-135,143,153-165
/home/runner/work/hi-ml/hi-ml/hi-ml-histopathology/src/SSL/ssl_augmentation_config.py 43 0 100.00%
/home/runner/work/hi-ml/hi-ml/hi-ml-histopathology/src/SSL/utils.py 66 10 84.85% 55,59,90,111-121
/home/runner/work/hi-ml/hi-ml/hi-ml-histopathology/src/SSL/configs/CIFAR_SSL_configs.py 11 1 90.91% 33
/home/runner/work/hi-ml/hi-ml/hi-ml-histopathology/src/SSL/configs/CIFAR_classifier_configs.py 5 0 100.00%
/home/runner/work/hi-ml/hi-ml/hi-ml-histopathology/src/SSL/configs/CXR_SSL_configs.py 19 0 100.00%
/home/runner/work/hi-ml/hi-ml/hi-ml-histopathology/src/SSL/configs/CovidContainers.py 8 1 87.50% 20
/home/runner/work/hi-ml/hi-ml/hi-ml-histopathology/src/SSL/data/cifar_datasets.py 10 1 90.00% 30
/home/runner/work/hi-ml/hi-ml/hi-ml-histopathology/src/SSL/data/cxr_datasets.py 97 32 67.01% 42,49-51,70,141-143,163-164,167-179,190-197,201,210-218
/home/runner/work/hi-ml/hi-ml/hi-ml-histopathology/src/SSL/data/datamodules.py 75 3 96.00% 83-85,133
/home/runner/work/hi-ml/hi-ml/hi-ml-histopathology/src/SSL/data/dataset_cls_utils.py 15 1 93.33% 32
/home/runner/work/hi-ml/hi-ml/hi-ml-histopathology/src/SSL/data/image_transforms.py 54 2 96.30% 33,53
/home/runner/work/hi-ml/hi-ml/hi-ml-histopathology/src/SSL/data/io_util.py 29 6 79.31% 62-68
/home/runner/work/hi-ml/hi-ml/hi-ml-histopathology/src/SSL/data/transform_pipeline.py 60 8 86.67% 71-74,79-82,123
/home/runner/work/hi-ml/hi-ml/hi-ml-histopathology/src/SSL/data/transforms_utils.py 29 0 100.00%
/home/runner/work/hi-ml/hi-ml/hi-ml-histopathology/src/SSL/lightning_containers/ssl_container.py 120 6 95.00% 127,190,275-286
/home/runner/work/hi-ml/hi-ml/hi-ml-histopathology/src/SSL/lightning_containers/ssl_image_classifier.py 29 1 96.55% 39
/home/runner/work/hi-ml/hi-ml/hi-ml-histopathology/src/SSL/lightning_modules/simclr_module.py 47 0 100.00%
/home/runner/work/hi-ml/hi-ml/hi-ml-histopathology/src/SSL/lightning_modules/ssl_classifier_module.py 55 3 94.55% 50-51,58
/home/runner/work/hi-ml/hi-ml/hi-ml-histopathology/src/SSL/lightning_modules/ssl_online_evaluator.py 97 4 95.88% 62,100,105,111
/home/runner/work/hi-ml/hi-ml/hi-ml-histopathology/src/SSL/lightning_modules/byol/byol_models.py 26 0 100.00%
/home/runner/work/hi-ml/hi-ml/hi-ml-histopathology/src/SSL/lightning_modules/byol/byol_module.py 71 1 98.59% 143
/home/runner/work/hi-ml/hi-ml/hi-ml-histopathology/src/SSL/lightning_modules/byol/byol_moving_average.py 24 0 100.00%
/opt/hostedtoolcache/Python/3.7.12/x64/lib/python3.7/site-packages/health_azure/datasets.py 149 1 99.33% 282
/opt/hostedtoolcache/Python/3.7.12/x64/lib/python3.7/site-packages/health_azure/himl.py 180 2 98.89% 143,535
/opt/hostedtoolcache/Python/3.7.12/x64/lib/python3.7/site-packages/health_azure/himl_download.py 27 10 62.96% 42-56
/opt/hostedtoolcache/Python/3.7.12/x64/lib/python3.7/site-packages/health_azure/himl_tensorboard.py 56 4 92.86% 87-89,101
/opt/hostedtoolcache/Python/3.7.12/x64/lib/python3.7/site-packages/health_azure/utils.py 658 74 88.75% 110-124,130,199,285,305,313-314,377,409,425,428,435,479,494-501,520,526,542,546,579-584,589-591,595,599,606-616,639,889,967,1238,1388,1435-1450,1495,1506-1508,1608,1617,1627,1679
/opt/hostedtoolcache/Python/3.7.12/x64/lib/python3.7/site-packages/health_ml/deep_learning_config.py 170 20 88.24% 98,105-106,112-116,170,311,316-321,329-331,394
/opt/hostedtoolcache/Python/3.7.12/x64/lib/python3.7/site-packages/health_ml/experiment_config.py 7 0 100.00%
/opt/hostedtoolcache/Python/3.7.12/x64/lib/python3.7/site-packages/health_ml/lightning_container.py 76 13 82.89% 66,84,99,139,148-149,158,165-167,208-211
/opt/hostedtoolcache/Python/3.7.12/x64/lib/python3.7/site-packages/health_ml/model_trainer.py 103 15 85.44% 65-74,86-87,112,115-116,125,134,225-226
/opt/hostedtoolcache/Python/3.7.12/x64/lib/python3.7/site-packages/health_ml/run_ml.py 81 17 79.01% 37-41,74,79-85,132-133,140,152,163,167
/opt/hostedtoolcache/Python/3.7.12/x64/lib/python3.7/site-packages/health_ml/runner.py 126 29 76.98% 26,56,73-90,110,199,209-210,228-231,256,271-282,294-295,299,303
/opt/hostedtoolcache/Python/3.7.12/x64/lib/python3.7/site-packages/health_ml/configs/hello_container.py 88 43 51.14% 39-40,43,46,61-62,72-75,78,81,84,87,90,113,125-127,140-142,152-155,163-165,174-175,188-197,206-208,233-234
/opt/hostedtoolcache/Python/3.7.12/x64/lib/python3.7/site-packages/health_ml/networks/layers/attention_layers.py 46 0 100.00%
/opt/hostedtoolcache/Python/3.7.12/x64/lib/python3.7/site-packages/health_ml/utils/bag_utils.py 87 15 82.76% 117,143-144,147,150-155,166,169-175
/opt/hostedtoolcache/Python/3.7.12/x64/lib/python3.7/site-packages/health_ml/utils/box_utils.py 50 20 60.00% 24,26,35-37,48,59,67,75,88-95,104,114-115
/opt/hostedtoolcache/Python/3.7.12/x64/lib/python3.7/site-packages/health_ml/utils/checkpoint_utils.py 134 65 51.49% 54,72,86,94-113,120,127-137,141-142,152,172-190,203-211,225-226,234-242,254-265
/opt/hostedtoolcache/Python/3.7.12/x64/lib/python3.7/site-packages/health_ml/utils/common_utils.py 135 25 81.48% 56-58,83,99-101,136,138,165-167,213,215-217,235-239,280-283
/opt/hostedtoolcache/Python/3.7.12/x64/lib/python3.7/site-packages/health_ml/utils/config_loader.py 86 6 93.02% 109,135-139,146-147
/opt/hostedtoolcache/Python/3.7.12/x64/lib/python3.7/site-packages/health_ml/utils/data_augmentations.py 72 0 100.00%
/opt/hostedtoolcache/Python/3.7.12/x64/lib/python3.7/site-packages/health_ml/utils/diagnostics.py 123 0 100.00%
/opt/hostedtoolcache/Python/3.7.12/x64/lib/python3.7/site-packages/health_ml/utils/fixed_paths.py 35 16 54.29% 22-27,39-41,58-61,77,84-86
/opt/hostedtoolcache/Python/3.7.12/x64/lib/python3.7/site-packages/health_ml/utils/lightning_loggers.py 61 37 39.34% 30-54,58,61,64,67,74,87-98,110
/opt/hostedtoolcache/Python/3.7.12/x64/lib/python3.7/site-packages/health_ml/utils/logging.py 181 6 96.69% 77-79,111,149,274
/opt/hostedtoolcache/Python/3.7.12/x64/lib/python3.7/site-packages/health_ml/utils/lr_scheduler.py 80 60 25.00% 19,28-33,36-40,43,48-51,54-55,65-78,85-116,125-130,138-141,149-160,163
/opt/hostedtoolcache/Python/3.7.12/x64/lib/python3.7/site-packages/health_ml/utils/model_util.py 13 7 46.15% 16-36
/opt/hostedtoolcache/Python/3.7.12/x64/lib/python3.7/site-packages/health_ml/utils/reports.py 219 11 94.98% 79,81,254,288,343,345,420,454,466-467,475
/opt/hostedtoolcache/Python/3.7.12/x64/lib/python3.7/site-packages/health_ml/utils/split_dataset.py 94 19 79.79% 48,56,64-65,86-87,97-104,124,128,132,140-141,273
/opt/hostedtoolcache/Python/3.7.12/x64/lib/python3.7/site-packages/health_ml/utils/type_annotations.py 8 0 100.00%
TOTAL 4273 622 85.44%

Please sign in to comment.