diff --git a/monai/networks/blocks/activation.py b/monai/networks/blocks/activation.py
index f6a04e830e..ef2d19b550 100644
--- a/monai/networks/blocks/activation.py
+++ b/monai/networks/blocks/activation.py
@@ -12,6 +12,19 @@
 import torch
 from torch import nn
 
+from monai.utils import optional_import
+
+if optional_import("torch.nn.functional", name="mish")[1]:
+
+    def monai_mish(x):
+        # inplace=False: keep out-of-place semantics so both branches behave
+        # identically and autograd works for leaf inputs requiring grad.
+        return torch.nn.functional.mish(x, inplace=False)
+
+
+else:
+
+    def monai_mish(x):
+        return x * torch.tanh(torch.nn.functional.softplus(x))
+
 
 class Swish(nn.Module):
     r"""Applies the element-wise function:
@@ -30,6 +43,8 @@ class Swish(nn.Module):
 
     Examples::
 
+        >>> import torch
+        >>> from monai.networks.layers.factories import Act
         >>> m = Act['swish']()
         >>> input = torch.randn(2)
         >>> output = m(input)
@@ -85,6 +100,8 @@ class MemoryEfficientSwish(nn.Module):
 
     Examples::
 
+        >>> import torch
+        >>> from monai.networks.layers.factories import Act
         >>> m = Act['memswish']()
         >>> input = torch.randn(2)
         >>> output = m(input)
@@ -111,10 +128,12 @@ class Mish(nn.Module):
 
     Examples::
 
+        >>> import torch
+        >>> from monai.networks.layers.factories import Act
         >>> m = Act['mish']()
         >>> input = torch.randn(2)
         >>> output = m(input)
     """
 
-    def forward(self, input: torch.Tensor) -> torch.Tensor:
-        return input * torch.tanh(torch.nn.functional.softplus(input))
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        return monai_mish(input)