From 14d91f21ff326b55a83b3293794ab7b7469742a1 Mon Sep 17 00:00:00 2001
From: Felix Schnabel
Date: Mon, 16 Jan 2023 18:00:06 +0100
Subject: [PATCH 1/7] First draft of the GEGLU activation function for the
 MLP Block.

Signed-off-by: Felix Schnabel
---
 monai/networks/blocks/activation.py | 6 ++++++
 monai/networks/blocks/mlp.py        | 2 +-
 monai/networks/layers/factories.py  | 5 +++++
 3 files changed, 12 insertions(+), 1 deletion(-)

diff --git a/monai/networks/blocks/activation.py b/monai/networks/blocks/activation.py
index 22d3c74897..e98bc777b5 100644
--- a/monai/networks/blocks/activation.py
+++ b/monai/networks/blocks/activation.py
@@ -162,3 +162,9 @@ def __init__(self, inplace: bool = False):
 
     def forward(self, input: torch.Tensor):
         return monai_mish(input, self.inplace)
+
+class GEGLU(nn.Module):
+
+    def forward(self, input: torch.Tensor):
+        x, gate = input.chunk(2, dim=-1)
+        return x * nn.functional.gelu(gate)
diff --git a/monai/networks/blocks/mlp.py b/monai/networks/blocks/mlp.py
index e691112a1e..54a3b0d4ca 100644
--- a/monai/networks/blocks/mlp.py
+++ b/monai/networks/blocks/mlp.py
@@ -48,7 +48,7 @@ def __init__(
         if not (0 <= dropout_rate <= 1):
             raise ValueError("dropout_rate should be between 0 and 1.")
         mlp_dim = mlp_dim or hidden_size
-        self.linear1 = nn.Linear(hidden_size, mlp_dim)
+        self.linear1 = nn.Linear(hidden_size, mlp_dim) if act != "GEGLU" else nn.Linear(hidden_size, mlp_dim * 2)
         self.linear2 = nn.Linear(mlp_dim, hidden_size)
         self.fn = get_act_layer(act)
         self.drop1 = nn.Dropout(dropout_rate)
diff --git a/monai/networks/layers/factories.py b/monai/networks/layers/factories.py
index 038b85dbca..972e188e82 100644
--- a/monai/networks/layers/factories.py
+++ b/monai/networks/layers/factories.py
@@ -319,6 +319,11 @@ def mish_factory():
     return Mish
 
+@Act.factory_function("geglu")
+def geglu_factory():
+    from monai.networks.blocks.activation import GEGLU
+
+    return GEGLU
 
 @Conv.factory_function("conv")
 def conv_factory(dim: int) -> type[nn.Conv1d | nn.Conv2d | nn.Conv3d]:
     types = (nn.Conv1d, nn.Conv2d, nn.Conv3d)

From f686f39bb924ca5306eb17fd3678fd87d8522a5d Mon Sep 17 00:00:00 2001
From: Felix Schnabel
Date: Mon, 16 Jan 2023 18:31:19 +0100
Subject: [PATCH 2/7] Formatting

Signed-off-by: Felix Schnabel
---
 monai/networks/blocks/activation.py | 2 +-
 monai/networks/layers/factories.py  | 2 ++
 2 files changed, 3 insertions(+), 1 deletion(-)

diff --git a/monai/networks/blocks/activation.py b/monai/networks/blocks/activation.py
index e98bc777b5..d9683435c7 100644
--- a/monai/networks/blocks/activation.py
+++ b/monai/networks/blocks/activation.py
@@ -163,8 +163,8 @@ def __init__(self, inplace: bool = False):
     def forward(self, input: torch.Tensor):
         return monai_mish(input, self.inplace)
 
-class GEGLU(nn.Module):
 
+class GEGLU(nn.Module):
     def forward(self, input: torch.Tensor):
         x, gate = input.chunk(2, dim=-1)
         return x * nn.functional.gelu(gate)
diff --git a/monai/networks/layers/factories.py b/monai/networks/layers/factories.py
index 972e188e82..37536a7cf6 100644
--- a/monai/networks/layers/factories.py
+++ b/monai/networks/layers/factories.py
@@ -319,12 +319,14 @@ def mish_factory():
     return Mish
 
+
 @Act.factory_function("geglu")
 def geglu_factory():
     from monai.networks.blocks.activation import GEGLU
 
     return GEGLU
 
+
 @Conv.factory_function("conv")
 def conv_factory(dim: int) -> type[nn.Conv1d | nn.Conv2d | nn.Conv3d]:
     types = (nn.Conv1d, nn.Conv2d, nn.Conv3d)
     return types[dim - 1]

From 8d5e721119ae70f206c113880807b22b20032159 Mon Sep 17 00:00:00 2001
From: Felix Schnabel
Date: Thu, 19 Jan 2023 19:20:51 +0100
Subject: [PATCH 3/7] Add documentation

Signed-off-by: Felix Schnabel
---
 monai/networks/blocks/activation.py | 13 +++++++++++++
 monai/networks/blocks/mlp.py        |  2 +-
 2 files changed, 14 insertions(+), 1 deletion(-)

diff --git a/monai/networks/blocks/activation.py b/monai/networks/blocks/activation.py
index d9683435c7..35f72f9fc8 100644
--- a/monai/networks/blocks/activation.py
+++ b/monai/networks/blocks/activation.py
@@ -165,6 +165,19 @@ def forward(self, input: torch.Tensor):
 
 
 class GEGLU(nn.Module):
+    r"""Applies the element-wise function:
+
+    .. math::
+        \text{GEGLU}(x) = x_1 * \text{GELU}(x_2)
+
+    where :math:`x_1` and :math:`x_2` are split from the input tensor along the last dimension.
+
+    Citation: GLU Variants Improve Transformer, Noam Shazeer, 2020, https://arxiv.org/abs/2002.05202.
+
+    Shape:
+        - Input: :math:`(N, *, 2 * D)`
+        - Output: :math:`(N, *, D)`, where `*` means any number of additional dimensions
+    """
     def forward(self, input: torch.Tensor):
         x, gate = input.chunk(2, dim=-1)
         return x * nn.functional.gelu(gate)
diff --git a/monai/networks/blocks/mlp.py b/monai/networks/blocks/mlp.py
index 54a3b0d4ca..e3ab94b32a 100644
--- a/monai/networks/blocks/mlp.py
+++ b/monai/networks/blocks/mlp.py
@@ -33,7 +33,7 @@ def __init__(
             hidden_size: dimension of hidden layer.
             mlp_dim: dimension of feedforward layer. If 0, `hidden_size` will be used.
             dropout_rate: faction of the input units to drop.
-            act: activation type and arguments. Defaults to GELU.
+            act: activation type and arguments. Defaults to GELU. Also supports "GEGLU" and others.
             dropout_mode: dropout mode, can be "vit" or "swin".
                 "vit" mode uses two dropout instances as implemented in
                 https://github.com/google-research/vision_transformer/blob/main/vit_jax/models.py#L87
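
Note: the sketch below is not part of the patch series; it restates what
patches 1-3 implement as a self-contained snippet, assuming only PyTorch
(the tensor sizes are illustrative):

    import torch
    import torch.nn as nn

    class GEGLU(nn.Module):
        """GELU-gated linear unit (Shazeer, 2020): the last dimension is
        split in half and one half gates the other."""

        def forward(self, input: torch.Tensor) -> torch.Tensor:
            x, gate = input.chunk(2, dim=-1)  # (..., 2*D) -> two (..., D) halves
            return x * nn.functional.gelu(gate)

    t = torch.randn(1, 4, 8)  # last dimension is 2*D with D = 4
    print(GEGLU()(t).shape)   # torch.Size([1, 4, 4]): last dimension halved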
From 35cba63bc579041527d48ba4a40ced80aa559535 Mon Sep 17 00:00:00 2001
From: Felix Schnabel
Date: Thu, 19 Jan 2023 19:23:36 +0100
Subject: [PATCH 4/7] Add GEGLU to documentation

Signed-off-by: Felix Schnabel
---
 docs/source/networks.rst | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/docs/source/networks.rst b/docs/source/networks.rst
index 64ba0313c4..3a3c47afd5 100644
--- a/docs/source/networks.rst
+++ b/docs/source/networks.rst
@@ -58,6 +58,11 @@ Blocks
 .. autoclass:: Mish
     :members:
 
+`GEGLU`
+~~~~~~~
+.. autoclass:: GEGLU
+    :members:
+
 `GCN Module`
 ~~~~~~~~~~~~
 .. autoclass:: GCN
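
Note: once patches 1-4 apply, the new activation should be reachable through
MONAI's existing factory helpers; a minimal sketch (both import paths already
exist in MONAI, and the lowercase lookup mirrors the test added in patch 6):

    from monai.networks.layers.factories import Act
    from monai.networks.layers.utils import get_act_layer

    geglu_cls = Act["geglu"]        # resolves the factory registered in patch 1
    geglu = get_act_layer("GEGLU")  # the helper MLPBlock itself calls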
From 11059b02e47f27c09e8ef873e7a6b788f3c8d2dd Mon Sep 17 00:00:00 2001
From: Felix Schnabel
Date: Thu, 19 Jan 2023 20:10:58 +0100
Subject: [PATCH 5/7] Fix formatting and export GEGLU in __init__

Signed-off-by: Felix Schnabel
---
 monai/networks/blocks/__init__.py   | 2 +-
 monai/networks/blocks/activation.py | 1 +
 2 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/monai/networks/blocks/__init__.py b/monai/networks/blocks/__init__.py
index 5db8a2f8df..e67cb3376f 100644
--- a/monai/networks/blocks/__init__.py
+++ b/monai/networks/blocks/__init__.py
@@ -12,7 +12,7 @@
 from __future__ import annotations
 
 from .acti_norm import ADN
-from .activation import MemoryEfficientSwish, Mish, Swish
+from .activation import GEGLU, MemoryEfficientSwish, Mish, Swish
 from .aspp import SimpleASPP
 from .backbone_fpn_utils import BackboneWithFPN
 from .convolutions import Convolution, ResidualUnit
diff --git a/monai/networks/blocks/activation.py b/monai/networks/blocks/activation.py
index 35f72f9fc8..1e5e979dff 100644
--- a/monai/networks/blocks/activation.py
+++ b/monai/networks/blocks/activation.py
@@ -178,6 +178,7 @@ class GEGLU(nn.Module):
         - Input: :math:`(N, *, 2 * D)`
         - Output: :math:`(N, *, D)`, where `*` means any number of additional dimensions
     """
+
     def forward(self, input: torch.Tensor):
         x, gate = input.chunk(2, dim=-1)
         return x * nn.functional.gelu(gate)

From 2da7ad3ec54bb39959f81c03fe908b62a87f1e1d Mon Sep 17 00:00:00 2001
From: Felix Schnabel
Date: Thu, 19 Jan 2023 20:11:39 +0100
Subject: [PATCH 6/7] Add testcase

Signed-off-by: Felix Schnabel
---
 tests/test_activations.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/tests/test_activations.py b/tests/test_activations.py
index 476b44ab6f..494a64df0e 100644
--- a/tests/test_activations.py
+++ b/tests/test_activations.py
@@ -85,6 +85,13 @@
     (1, 2, 5),
 ]
 
+TEST_CASE_7 = [
+    "geglu",
+    torch.tensor([[[-10, -8, -6, -4, -2, 0], [0, 2, 4, 6, 8, 10]]], dtype=torch.float32),
+    torch.tensor([[[1.27e-03, 3.64e-01, 0.00e+00], [0.00e00, 1.60e+01, 4.00e01]]]),
+    (1, 2, 3),
+]
+
 
 class TestActivations(unittest.TestCase):
     @parameterized.expand(TEST_CASES)
@@ -101,7 +108,7 @@ def _compare(ret, out, shape):
         else:
             _compare(result, out, expected_shape)
 
-    @parameterized.expand([TEST_CASE_4, TEST_CASE_5, TEST_CASE_6])
+    @parameterized.expand([TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7])
     def test_monai_activations_value_shape(self, input_param, img, out, expected_shape):
         act = Act[input_param]()
         result = act(img)

From 9ca029d6f225802d2b1296a3bb8f42a4683939bf Mon Sep 17 00:00:00 2001
From: Felix Schnabel
Date: Thu, 19 Jan 2023 20:30:13 +0100
Subject: [PATCH 7/7] Fix formatting

Signed-off-by: Felix Schnabel
---
 tests/test_activations.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/test_activations.py b/tests/test_activations.py
index 494a64df0e..0e83c73304 100644
--- a/tests/test_activations.py
+++ b/tests/test_activations.py
@@ -88,7 +88,7 @@
 TEST_CASE_7 = [
     "geglu",
     torch.tensor([[[-10, -8, -6, -4, -2, 0], [0, 2, 4, 6, 8, 10]]], dtype=torch.float32),
-    torch.tensor([[[1.27e-03, 3.64e-01, 0.00e+00], [0.00e00, 1.60e+01, 4.00e01]]]),
+    torch.tensor([[[1.27e-03, 3.64e-01, 0.00e00], [0.00e00, 1.60e01, 4.00e01]]]),
     (1, 2, 3),
 ]
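
Note: taken together, the series lets MLPBlock be configured with the gated
activation; a quick shape check, assuming the whole series is applied (the
sizes below are arbitrary):

    import torch
    from monai.networks.blocks.mlp import MLPBlock

    # With act="GEGLU", linear1 projects to mlp_dim * 2 and the activation
    # halves the width again, so linear2 still receives mlp_dim features.
    block = MLPBlock(hidden_size=64, mlp_dim=128, act="GEGLU")
    x = torch.randn(2, 16, 64)
    print(block(x).shape)  # torch.Size([2, 16, 64])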