diff --git a/.github/workflows/pythonapp.yml b/.github/workflows/pythonapp.yml index fe04f96a80..65f9a4dcf2 100644 --- a/.github/workflows/pythonapp.yml +++ b/.github/workflows/pythonapp.yml @@ -99,6 +99,7 @@ jobs: name: Install itk pre-release (Linux only) run: | python -m pip install --pre -U itk + find /opt/hostedtoolcache/* -maxdepth 0 ! -name 'Python' -exec rm -rf {} \; - name: Install the dependencies run: | python -m pip install --user --upgrade pip wheel diff --git a/monai/networks/blocks/mlp.py b/monai/networks/blocks/mlp.py index d3510b64d3..8771711d25 100644 --- a/monai/networks/blocks/mlp.py +++ b/monai/networks/blocks/mlp.py @@ -11,12 +11,15 @@ from __future__ import annotations +from typing import Union + import torch.nn as nn from monai.networks.layers import get_act_layer +from monai.networks.layers.factories import split_args from monai.utils import look_up_option -SUPPORTED_DROPOUT_MODE = {"vit", "swin"} +SUPPORTED_DROPOUT_MODE = {"vit", "swin", "vista3d"} class MLPBlock(nn.Module): @@ -39,7 +42,7 @@ def __init__( https://github.com/google-research/vision_transformer/blob/main/vit_jax/models.py#L87 "swin" corresponds to one instance as implemented in https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_mlp.py#L23 - + "vista3d" mode does not use dropout. """ @@ -48,15 +51,24 @@ def __init__( if not (0 <= dropout_rate <= 1): raise ValueError("dropout_rate should be between 0 and 1.") mlp_dim = mlp_dim or hidden_size - self.linear1 = nn.Linear(hidden_size, mlp_dim) if act != "GEGLU" else nn.Linear(hidden_size, mlp_dim * 2) + act_name, _ = split_args(act) + self.linear1 = nn.Linear(hidden_size, mlp_dim) if act_name != "GEGLU" else nn.Linear(hidden_size, mlp_dim * 2) self.linear2 = nn.Linear(mlp_dim, hidden_size) self.fn = get_act_layer(act) - self.drop1 = nn.Dropout(dropout_rate) + # Use Union[nn.Dropout, nn.Identity] for type annotations + self.drop1: Union[nn.Dropout, nn.Identity] + self.drop2: Union[nn.Dropout, nn.Identity] + dropout_opt = look_up_option(dropout_mode, SUPPORTED_DROPOUT_MODE) if dropout_opt == "vit": + self.drop1 = nn.Dropout(dropout_rate) self.drop2 = nn.Dropout(dropout_rate) elif dropout_opt == "swin": + self.drop1 = nn.Dropout(dropout_rate) self.drop2 = self.drop1 + elif dropout_opt == "vista3d": + self.drop1 = nn.Identity() + self.drop2 = nn.Identity() else: raise ValueError(f"dropout_mode should be one of {SUPPORTED_DROPOUT_MODE}") diff --git a/tests/test_mlp.py b/tests/test_mlp.py index 54f70d3318..2598d8877d 100644 --- a/tests/test_mlp.py +++ b/tests/test_mlp.py @@ -15,10 +15,12 @@ import numpy as np import torch +import torch.nn as nn from parameterized import parameterized from monai.networks import eval_mode from monai.networks.blocks.mlp import MLPBlock +from monai.networks.layers.factories import split_args TEST_CASE_MLP = [] for dropout_rate in np.linspace(0, 1, 4): @@ -31,6 +33,14 @@ ] TEST_CASE_MLP.append(test_case) +# test different activation layers +TEST_CASE_ACT = [] +for act in ["GELU", "GEGLU", ("GEGLU", {})]: # type: ignore + TEST_CASE_ACT.append([{"hidden_size": 128, "mlp_dim": 0, "act": act}, (2, 512, 128), (2, 512, 128)]) + +# test different dropout modes +TEST_CASE_DROP = [["vit", nn.Dropout], ["swin", nn.Dropout], ["vista3d", nn.Identity]] + class TestMLPBlock(unittest.TestCase): @@ -45,6 +55,24 @@ def test_ill_arg(self): with self.assertRaises(ValueError): MLPBlock(hidden_size=128, mlp_dim=512, dropout_rate=5.0) + @parameterized.expand(TEST_CASE_ACT) + def test_act(self, input_param, input_shape, expected_shape): + net = MLPBlock(**input_param) + with eval_mode(net): + result = net(torch.randn(input_shape)) + self.assertEqual(result.shape, expected_shape) + act_name, _ = split_args(input_param["act"]) + if act_name == "GEGLU": + self.assertEqual(net.linear1.in_features, net.linear1.out_features // 2) + else: + self.assertEqual(net.linear1.in_features, net.linear1.out_features) + + @parameterized.expand(TEST_CASE_DROP) + def test_dropout_mode(self, dropout_mode, dropout_layer): + net = MLPBlock(hidden_size=128, mlp_dim=512, dropout_rate=0.1, dropout_mode=dropout_mode) + self.assertTrue(isinstance(net.drop1, dropout_layer)) + self.assertTrue(isinstance(net.drop2, dropout_layer)) + if __name__ == "__main__": unittest.main()