From 4d29fa6fdfda50f685283762c878f9a06a1a2b5d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xa9aX=20=E3=83=84?= Date: Tue, 10 Aug 2021 00:52:49 +0530 Subject: [PATCH 1/5] Update Mish to default PyTorch 1.9 version Signed-off-by: Wenqi Li --- monai/networks/blocks/activation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/networks/blocks/activation.py b/monai/networks/blocks/activation.py index f6a04e830e..387705821a 100644 --- a/monai/networks/blocks/activation.py +++ b/monai/networks/blocks/activation.py @@ -117,4 +117,4 @@ class Mish(nn.Module): """ def forward(self, input: torch.Tensor) -> torch.Tensor: - return input * torch.tanh(torch.nn.functional.softplus(input)) + return F.mish(input) if if torch.__version__ >= '1.9' else input * torch.tanh(torch.nn.functional.softplus(input)) From 6b07bf4aa43489004e647e63a857ea9152fa0105 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xa9aX=20=E3=83=84?= Date: Tue, 10 Aug 2021 01:22:19 +0530 Subject: [PATCH 2/5] Update activation.py Signed-off-by: Wenqi Li --- monai/networks/blocks/activation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/networks/blocks/activation.py b/monai/networks/blocks/activation.py index 387705821a..da946d65b7 100644 --- a/monai/networks/blocks/activation.py +++ b/monai/networks/blocks/activation.py @@ -117,4 +117,4 @@ class Mish(nn.Module): """ def forward(self, input: torch.Tensor) -> torch.Tensor: - return F.mish(input) if if torch.__version__ >= '1.9' else input * torch.tanh(torch.nn.functional.softplus(input)) + return F.mish(input) if torch.__version__ >= '1.9' else input * torch.tanh(torch.nn.functional.softplus(input)) From f275df7f3411a1a628da33520fb51bea3ec406cd Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Tue, 10 Aug 2021 19:44:31 +0100 Subject: [PATCH 3/5] update based on comments Signed-off-by: Wenqi Li --- monai/networks/blocks/activation.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/monai/networks/blocks/activation.py b/monai/networks/blocks/activation.py index da946d65b7..7372b12bf2 100644 --- a/monai/networks/blocks/activation.py +++ b/monai/networks/blocks/activation.py @@ -12,6 +12,8 @@ import torch from torch import nn +from monai.utils import get_torch_version_tuple + class Swish(nn.Module): r"""Applies the element-wise function: @@ -30,6 +32,8 @@ class Swish(nn.Module): Examples:: + >>> import torch + >>> from monai.networks.layers.factories import Act >>> m = Act['swish']() >>> input = torch.randn(2) >>> output = m(input) @@ -85,6 +89,8 @@ class MemoryEfficientSwish(nn.Module): Examples:: + >>> import torch + >>> from monai.networks.layers.factories import Act >>> m = Act['memswish']() >>> input = torch.randn(2) >>> output = m(input) @@ -111,10 +117,16 @@ class Mish(nn.Module): Examples:: + >>> import torch + >>> from monai.networks.layers.factories import Act >>> m = Act['mish']() >>> input = torch.randn(2) >>> output = m(input) """ def forward(self, input: torch.Tensor) -> torch.Tensor: - return F.mish(input) if torch.__version__ >= '1.9' else input * torch.tanh(torch.nn.functional.softplus(input)) + return ( + nn.functional.mish(input) + if get_torch_version_tuple() >= (1, 9) + else input * torch.tanh(torch.nn.functional.softplus(input)) + ) From 53843194a02ff1f8ed923ba6ca16ce2c45a1ab02 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Tue, 10 Aug 2021 21:18:23 +0100 Subject: [PATCH 4/5] impl. with optional import Signed-off-by: Wenqi Li --- monai/networks/blocks/activation.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/monai/networks/blocks/activation.py b/monai/networks/blocks/activation.py index 7372b12bf2..1f3d95f5ad 100644 --- a/monai/networks/blocks/activation.py +++ b/monai/networks/blocks/activation.py @@ -12,7 +12,18 @@ import torch from torch import nn -from monai.utils import get_torch_version_tuple +from monai.utils import optional_import + +if optional_import("torch.nn.functional", name="mish")[1]: + + def monai_mish(x): + return torch.nn.functional.mish(x, inplace=True) + + +else: + + def monai_mish(x): + return x * torch.tanh(torch.nn.functional.softplus(x)) class Swish(nn.Module): @@ -125,8 +136,4 @@ class Mish(nn.Module): """ def forward(self, input: torch.Tensor) -> torch.Tensor: - return ( - nn.functional.mish(input) - if get_torch_version_tuple() >= (1, 9) - else input * torch.tanh(torch.nn.functional.softplus(input)) - ) + return monai_mish(input) From 6569e73923824132e8772e348e4aedbd6c190c1f Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Wed, 11 Aug 2021 09:23:56 +0100 Subject: [PATCH 5/5] fixes mypy error Signed-off-by: Wenqi Li --- monai/networks/blocks/activation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/networks/blocks/activation.py b/monai/networks/blocks/activation.py index 1f3d95f5ad..ef2d19b550 100644 --- a/monai/networks/blocks/activation.py +++ b/monai/networks/blocks/activation.py @@ -135,5 +135,5 @@ class Mish(nn.Module): >>> output = m(input) """ - def forward(self, input: torch.Tensor) -> torch.Tensor: + def forward(self, input: torch.Tensor): return monai_mish(input)