diff --git a/docs/source/networks.rst b/docs/source/networks.rst
index 64ba0313c4..3a3c47afd5 100644
--- a/docs/source/networks.rst
+++ b/docs/source/networks.rst
@@ -58,6 +58,11 @@ Blocks
 .. autoclass:: Mish
     :members:
 
+`GEGLU`
+~~~~~~~
+.. autoclass:: GEGLU
+    :members:
+
 `GCN Module`
 ~~~~~~~~~~~~
 .. autoclass:: GCN
diff --git a/monai/networks/blocks/__init__.py b/monai/networks/blocks/__init__.py
index 5db8a2f8df..e67cb3376f 100644
--- a/monai/networks/blocks/__init__.py
+++ b/monai/networks/blocks/__init__.py
@@ -12,7 +12,7 @@
 from __future__ import annotations
 
 from .acti_norm import ADN
-from .activation import MemoryEfficientSwish, Mish, Swish
+from .activation import GEGLU, MemoryEfficientSwish, Mish, Swish
 from .aspp import SimpleASPP
 from .backbone_fpn_utils import BackboneWithFPN
 from .convolutions import Convolution, ResidualUnit
diff --git a/monai/networks/blocks/activation.py b/monai/networks/blocks/activation.py
index 22d3c74897..1e5e979dff 100644
--- a/monai/networks/blocks/activation.py
+++ b/monai/networks/blocks/activation.py
@@ -162,3 +162,23 @@ def __init__(self, inplace: bool = False):
 
     def forward(self, input: torch.Tensor):
         return monai_mish(input, self.inplace)
+
+
+class GEGLU(nn.Module):
+    r"""Applies the element-wise function:
+
+    .. math::
+        \text{GEGLU}(x) = x_1 * \text{GELU}(x_2)
+
+    where :math:`x_1` and :math:`x_2` are split from the input tensor along the last dimension.
+
+    Citation: GLU Variants Improve Transformer, Noam Shazeer, 2020, https://arxiv.org/abs/2002.05202.
+
+    Shape:
+        - Input: :math:`(N, *, 2 * D)`
+        - Output: :math:`(N, *, D)`, where `*` means any number of additional dimensions
+    """
+
+    def forward(self, input: torch.Tensor):
+        x, gate = input.chunk(2, dim=-1)
+        return x * nn.functional.gelu(gate)
diff --git a/monai/networks/blocks/mlp.py b/monai/networks/blocks/mlp.py
index e691112a1e..e3ab94b32a 100644
--- a/monai/networks/blocks/mlp.py
+++ b/monai/networks/blocks/mlp.py
@@ -33,7 +33,7 @@ def __init__(
             hidden_size: dimension of hidden layer.
             mlp_dim: dimension of feedforward layer. If 0, `hidden_size` will be used.
             dropout_rate: fraction of the input units to drop.
-            act: activation type and arguments. Defaults to GELU.
+            act: activation type and arguments. Defaults to GELU. Also supports "GEGLU" and others.
             dropout_mode: dropout mode, can be "vit" or "swin".
                 "vit" mode uses two dropout instances as implemented in
                 https://github.com/google-research/vision_transformer/blob/main/vit_jax/models.py#L87
@@ -48,7 +48,7 @@ def __init__(
         if not (0 <= dropout_rate <= 1):
             raise ValueError("dropout_rate should be between 0 and 1.")
         mlp_dim = mlp_dim or hidden_size
-        self.linear1 = nn.Linear(hidden_size, mlp_dim)
+        self.linear1 = nn.Linear(hidden_size, mlp_dim) if act != "GEGLU" else nn.Linear(hidden_size, mlp_dim * 2)
         self.linear2 = nn.Linear(mlp_dim, hidden_size)
         self.fn = get_act_layer(act)
         self.drop1 = nn.Dropout(dropout_rate)
diff --git a/monai/networks/layers/factories.py b/monai/networks/layers/factories.py
index 038b85dbca..37536a7cf6 100644
--- a/monai/networks/layers/factories.py
+++ b/monai/networks/layers/factories.py
@@ -320,6 +320,13 @@ def mish_factory():
     return Mish
 
 
+@Act.factory_function("geglu")
+def geglu_factory():
+    from monai.networks.blocks.activation import GEGLU
+
+    return GEGLU
+
+
 @Conv.factory_function("conv")
 def conv_factory(dim: int) -> type[nn.Conv1d | nn.Conv2d | nn.Conv3d]:
     types = (nn.Conv1d, nn.Conv2d, nn.Conv3d)
diff --git a/tests/test_activations.py b/tests/test_activations.py
index 476b44ab6f..0e83c73304 100644
--- a/tests/test_activations.py
+++ b/tests/test_activations.py
@@ -85,6 +85,13 @@
     (1, 2, 5),
 ]
 
+TEST_CASE_7 = [
+    "geglu",
+    torch.tensor([[[-10, -8, -6, -4, -2, 0], [0, 2, 4, 6, 8, 10]]], dtype=torch.float32),
+    torch.tensor([[[1.27e-03, 3.64e-01, 0.00e00], [0.00e00, 1.60e01, 4.00e01]]]),
+    (1, 2, 3),
+]
+
 
 class TestActivations(unittest.TestCase):
     @parameterized.expand(TEST_CASES)
@@ -101,7 +108,7 @@ def _compare(ret, out, shape):
             else:
                 _compare(result, out, expected_shape)
 
-    @parameterized.expand([TEST_CASE_4, TEST_CASE_5, TEST_CASE_6])
+    @parameterized.expand([TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7])
     def test_monai_activations_value_shape(self, input_param, img, out, expected_shape):
         act = Act[input_param]()
         result = act(img)