From 3f91409efeda5c6d1713953040f219bfb9efc638 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Tue, 6 Aug 2024 14:48:44 +0800 Subject: [PATCH 1/6] update mlp block Signed-off-by: Yiheng Wang --- monai/networks/blocks/mlp.py | 14 ++++++++++---- tests/test_mlp.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+), 4 deletions(-) diff --git a/monai/networks/blocks/mlp.py b/monai/networks/blocks/mlp.py index d3510b64d3..5d207920b1 100644 --- a/monai/networks/blocks/mlp.py +++ b/monai/networks/blocks/mlp.py @@ -14,9 +14,10 @@ import torch.nn as nn from monai.networks.layers import get_act_layer +from monai.networks.layers.factories import split_args from monai.utils import look_up_option -SUPPORTED_DROPOUT_MODE = {"vit", "swin"} +SUPPORTED_DROPOUT_MODE = {"vit", "swin", "vista3d"} class MLPBlock(nn.Module): @@ -39,7 +40,7 @@ def __init__( https://github.com/google-research/vision_transformer/blob/main/vit_jax/models.py#L87 "swin" corresponds to one instance as implemented in https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_mlp.py#L23 - + "vista3d" mode does not use dropout. """ @@ -48,15 +49,20 @@ def __init__( if not (0 <= dropout_rate <= 1): raise ValueError("dropout_rate should be between 0 and 1.") mlp_dim = mlp_dim or hidden_size - self.linear1 = nn.Linear(hidden_size, mlp_dim) if act != "GEGLU" else nn.Linear(hidden_size, mlp_dim * 2) + act_name, _ = split_args(act) + self.linear1 = nn.Linear(hidden_size, mlp_dim) if act_name != "GEGLU" else nn.Linear(hidden_size, mlp_dim * 2) self.linear2 = nn.Linear(mlp_dim, hidden_size) self.fn = get_act_layer(act) - self.drop1 = nn.Dropout(dropout_rate) dropout_opt = look_up_option(dropout_mode, SUPPORTED_DROPOUT_MODE) if dropout_opt == "vit": + self.drop1 = nn.Dropout(dropout_rate) self.drop2 = nn.Dropout(dropout_rate) elif dropout_opt == "swin": + self.drop1 = nn.Dropout(dropout_rate) self.drop2 = self.drop1 + elif dropout_opt == "vista3d": + self.drop1 = nn.Identity() + self.drop2 = nn.Identity() else: raise ValueError(f"dropout_mode should be one of {SUPPORTED_DROPOUT_MODE}") diff --git a/tests/test_mlp.py b/tests/test_mlp.py index 54f70d3318..af6eb5b6b8 100644 --- a/tests/test_mlp.py +++ b/tests/test_mlp.py @@ -15,10 +15,12 @@ import numpy as np import torch +import torch.nn as nn from parameterized import parameterized from monai.networks import eval_mode from monai.networks.blocks.mlp import MLPBlock +from monai.networks.layers.factories import split_args TEST_CASE_MLP = [] for dropout_rate in np.linspace(0, 1, 4): @@ -31,6 +33,14 @@ ] TEST_CASE_MLP.append(test_case) +# test different activation layers +TEST_CASE_ACT = [] +for act in ["GELU", "GEGLU", ("GELU", {"approximate": "tanh"}), ("GEGLU", {})]: + TEST_CASE_ACT.append([{"hidden_size": 128, "mlp_dim": 0, "act": act}, (2, 512, 128), (2, 512, 128)]) + +# test different dropout modes +TEST_CASE_DROP = [["vit", nn.Dropout], ["swin", nn.Dropout], ["vista3d", nn.Identity]] + class TestMLPBlock(unittest.TestCase): @@ -45,6 +55,24 @@ def test_ill_arg(self): with self.assertRaises(ValueError): MLPBlock(hidden_size=128, mlp_dim=512, dropout_rate=5.0) + @parameterized.expand(TEST_CASE_ACT) + def test_act(self, input_param, input_shape, expected_shape): + net = MLPBlock(**input_param) + with eval_mode(net): + result = net(torch.randn(input_shape)) + self.assertEqual(result.shape, expected_shape) + act_name, _ = split_args(input_param["act"]) + if act_name == "GEGLU": + self.assertEqual(net.linear1.in_features, net.linear1.out_features // 2) + else: + self.assertEqual(net.linear1.in_features, net.linear1.out_features) + + @parameterized.expand(TEST_CASE_DROP) + def test_dropout_mode(self, dropout_mode, dropout_layer): + net = MLPBlock(hidden_size=128, mlp_dim=512, dropout_rate=0.1, dropout_mode=dropout_mode) + self.assertTrue(isinstance(net.drop1, dropout_layer)) + self.assertTrue(isinstance(net.drop2, dropout_layer)) + if __name__ == "__main__": unittest.main() From d241db5af6604a4b56029f43831cc10cc67ad2aa Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Tue, 6 Aug 2024 15:38:46 +0800 Subject: [PATCH 2/6] add mypy fix Signed-off-by: Yiheng Wang --- monai/networks/blocks/mlp.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/monai/networks/blocks/mlp.py b/monai/networks/blocks/mlp.py index 5d207920b1..8771711d25 100644 --- a/monai/networks/blocks/mlp.py +++ b/monai/networks/blocks/mlp.py @@ -11,6 +11,8 @@ from __future__ import annotations +from typing import Union + import torch.nn as nn from monai.networks.layers import get_act_layer @@ -53,6 +55,10 @@ def __init__( self.linear1 = nn.Linear(hidden_size, mlp_dim) if act_name != "GEGLU" else nn.Linear(hidden_size, mlp_dim * 2) self.linear2 = nn.Linear(mlp_dim, hidden_size) self.fn = get_act_layer(act) + # Use Union[nn.Dropout, nn.Identity] for type annotations + self.drop1: Union[nn.Dropout, nn.Identity] + self.drop2: Union[nn.Dropout, nn.Identity] + dropout_opt = look_up_option(dropout_mode, SUPPORTED_DROPOUT_MODE) if dropout_opt == "vit": self.drop1 = nn.Dropout(dropout_rate) From 1af03f6468de122bc32b69794d4b8f0c60b64460 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Wed, 7 Aug 2024 11:52:14 +0800 Subject: [PATCH 3/6] remove gelu approximate Signed-off-by: Yiheng Wang --- tests/test_mlp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_mlp.py b/tests/test_mlp.py index af6eb5b6b8..a34a3b89ac 100644 --- a/tests/test_mlp.py +++ b/tests/test_mlp.py @@ -35,7 +35,7 @@ # test different activation layers TEST_CASE_ACT = [] -for act in ["GELU", "GEGLU", ("GELU", {"approximate": "tanh"}), ("GEGLU", {})]: +for act in ["GELU", "GEGLU", ("GEGLU", {})]: TEST_CASE_ACT.append([{"hidden_size": 128, "mlp_dim": 0, "act": act}, (2, 512, 128), (2, 512, 128)]) # test different dropout modes From 4dc89e51d3ab700897e8c39fa409e7c260be8d29 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Wed, 7 Aug 2024 13:15:24 +0800 Subject: [PATCH 4/6] free space Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- .github/workflows/pythonapp.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/pythonapp.yml b/.github/workflows/pythonapp.yml index fe04f96a80..7040e32f14 100644 --- a/.github/workflows/pythonapp.yml +++ b/.github/workflows/pythonapp.yml @@ -101,6 +101,7 @@ jobs: python -m pip install --pre -U itk - name: Install the dependencies run: | + find /opt/hostedtoolcache/* -maxdepth 0 ! -name 'Python' -exec rm -rf {} \; python -m pip install --user --upgrade pip wheel python -m pip install torch==1.13.1 torchvision==0.14.1 cat "requirements-dev.txt" From b5400b9ca4796bed2b84ca1fb07dfd7c69c02393 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Wed, 7 Aug 2024 13:28:15 +0800 Subject: [PATCH 5/6] ignore test case type annotation Signed-off-by: Yiheng Wang --- tests/test_mlp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_mlp.py b/tests/test_mlp.py index a34a3b89ac..2598d8877d 100644 --- a/tests/test_mlp.py +++ b/tests/test_mlp.py @@ -35,7 +35,7 @@ # test different activation layers TEST_CASE_ACT = [] -for act in ["GELU", "GEGLU", ("GEGLU", {})]: +for act in ["GELU", "GEGLU", ("GEGLU", {})]: # type: ignore TEST_CASE_ACT.append([{"hidden_size": 128, "mlp_dim": 0, "act": act}, (2, 512, 128), (2, 512, 128)]) # test different dropout modes From c53f038f66111d7763245d3c76bdcab6a5ed747b Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Wed, 7 Aug 2024 13:41:19 +0800 Subject: [PATCH 6/6] try to fix Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- .github/workflows/pythonapp.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/pythonapp.yml b/.github/workflows/pythonapp.yml index 7040e32f14..65f9a4dcf2 100644 --- a/.github/workflows/pythonapp.yml +++ b/.github/workflows/pythonapp.yml @@ -99,9 +99,9 @@ jobs: name: Install itk pre-release (Linux only) run: | python -m pip install --pre -U itk + find /opt/hostedtoolcache/* -maxdepth 0 ! -name 'Python' -exec rm -rf {} \; - name: Install the dependencies run: | - find /opt/hostedtoolcache/* -maxdepth 0 ! -name 'Python' -exec rm -rf {} \; python -m pip install --user --upgrade pip wheel python -m pip install torch==1.13.1 torchvision==0.14.1 cat "requirements-dev.txt"