Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 21 additions & 2 deletions monai/networks/blocks/activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,19 @@
import torch
from torch import nn

from monai.utils import optional_import

if optional_import("torch.nn.functional", name="mish")[1]:

def monai_mish(x):
return torch.nn.functional.mish(x, inplace=True)


else:

def monai_mish(x):
return x * torch.tanh(torch.nn.functional.softplus(x))


class Swish(nn.Module):
r"""Applies the element-wise function:
Expand All @@ -30,6 +43,8 @@ class Swish(nn.Module):

Examples::

>>> import torch
>>> from monai.networks.layers.factories import Act
>>> m = Act['swish']()
>>> input = torch.randn(2)
>>> output = m(input)
Expand Down Expand Up @@ -85,6 +100,8 @@ class MemoryEfficientSwish(nn.Module):

Examples::

>>> import torch
>>> from monai.networks.layers.factories import Act
>>> m = Act['memswish']()
>>> input = torch.randn(2)
>>> output = m(input)
Expand All @@ -111,10 +128,12 @@ class Mish(nn.Module):

Examples::

>>> import torch
>>> from monai.networks.layers.factories import Act
>>> m = Act['mish']()
>>> input = torch.randn(2)
>>> output = m(input)
"""

def forward(self, input: torch.Tensor) -> torch.Tensor:
return input * torch.tanh(torch.nn.functional.softplus(input))
def forward(self, input: torch.Tensor):
return monai_mish(input)