From 1cf0eaf165af51879df378a4ecc65f1ed40e202b Mon Sep 17 00:00:00 2001
From: Walter Hugo Lopez Pinaya
Date: Mon, 23 Jan 2023 22:00:30 +0000
Subject: [PATCH 1/2] Update monai-weekly prerelease and replace FeedForward
 with MLPBlock

Signed-off-by: Walter Hugo Lopez Pinaya
---
 .../networks/nets/diffusion_model_unet.py | 50 ++-----------------
 requirements.txt                          |  2 +-
 setup.py                                  |  2 +-
 3 files changed, 6 insertions(+), 48 deletions(-)

diff --git a/generative/networks/nets/diffusion_model_unet.py b/generative/networks/nets/diffusion_model_unet.py
index 7eb80b8c..de653c16 100644
--- a/generative/networks/nets/diffusion_model_unet.py
+++ b/generative/networks/nets/diffusion_model_unet.py
@@ -35,7 +35,7 @@
 import torch
 import torch.nn.functional as F
 
-from monai.networks.blocks import Convolution
+from monai.networks.blocks import Convolution, MLPBlock
 from monai.networks.layers.factories import Pool
 from torch import nn
 
@@ -66,46 +66,6 @@ def zero_module(module: nn.Module) -> nn.Module:
     return module
 
 
-class GEGLU(nn.Module):
-    """
-    A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
-
-    Args:
-        dim_in: number of channels in the input.
-        dim_out: number of channels in the output.
-    """
-
-    def __init__(self, dim_in: int, dim_out: int) -> None:
-        super().__init__()
-        self.proj = nn.Linear(dim_in, dim_out * 2)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x, gate = self.proj(x).chunk(2, dim=-1)
-        return x * F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
-
-
-class FeedForward(nn.Module):
-    """
-    A feed-forward layer.
-
-    Args:
-        num_channels: number of channels in the input.
-        dim_out: number of channels in the output. If not given, defaults to `dim`.
-        mult: multiplier to use for the hidden dimension.
-        dropout: dropout probability to use.
-    """
-
-    def __init__(self, num_channels: int, dim_out: Optional[int] = None, mult: int = 4, dropout: float = 0.0) -> None:
-        super().__init__()
-        inner_dim = int(num_channels * mult)
-        dim_out = dim_out if dim_out is not None else num_channels
-
-        self.net = nn.Sequential(GEGLU(num_channels, inner_dim), nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return self.net(x)
-
-
 class CrossAttention(nn.Module):
     """
     A cross attention layer.
@@ -239,7 +199,7 @@ def __init__(
             dropout=dropout,
             upcast_attention=upcast_attention,
         )  # is a self-attention
-        self.ff = FeedForward(num_channels, dropout=dropout)
+        self.ff = MLPBlock(hidden_size=num_channels, mlp_dim=num_channels * 4, act="GEGLU")
         self.attn2 = CrossAttention(
             query_dim=num_channels,
             cross_attention_dim=cross_attention_dim,
@@ -1677,10 +1637,8 @@ def __init__(
         super().__init__()
         if with_conditioning is True and cross_attention_dim is None:
             raise ValueError(
-                (
-                    "DiffusionModelUNet expects dimension of the cross-attention conditioning (cross_attention_dim) "
-                    "when using with_conditioning."
-                )
+                "DiffusionModelUNet expects dimension of the cross-attention conditioning (cross_attention_dim) "
+                "when using with_conditioning."
             )
         if cross_attention_dim is not None and with_conditioning is False:
             raise ValueError(
diff --git a/requirements.txt b/requirements.txt
index 4254669e..ec12d276 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,3 @@
 numpy>=1.17
 torch>=1.8
-monai-weekly==1.1.dev2248
+monai-weekly==1.2.dev2304
diff --git a/setup.py b/setup.py
index 7a74fb5c..03fd1b59 100644
--- a/setup.py
+++ b/setup.py
@@ -17,6 +17,6 @@
     version="0.1.0",
     description="Installer to help to use the prototypes from MONAI generative models in other projects.",
     install_requires=[
-        "monai-weekly==1.1.dev2248",
+        "monai-weekly==1.2.dev2304",
     ],
 )

From a8d57068808e7b22b7bc85a23c88f67568ef0ccd Mon Sep 17 00:00:00 2001
From: Walter Hugo Lopez Pinaya
Date: Mon, 23 Jan 2023 22:33:47 +0000
Subject: [PATCH 2/2] Add dropout

Signed-off-by: Walter Hugo Lopez Pinaya
---
 generative/networks/nets/diffusion_model_unet.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/generative/networks/nets/diffusion_model_unet.py b/generative/networks/nets/diffusion_model_unet.py
index de653c16..57ea7b3f 100644
--- a/generative/networks/nets/diffusion_model_unet.py
+++ b/generative/networks/nets/diffusion_model_unet.py
@@ -199,7 +199,7 @@ def __init__(
             dropout=dropout,
             upcast_attention=upcast_attention,
         )  # is a self-attention
-        self.ff = MLPBlock(hidden_size=num_channels, mlp_dim=num_channels * 4, act="GEGLU")
+        self.ff = MLPBlock(hidden_size=num_channels, mlp_dim=num_channels * 4, act="GEGLU", dropout_rate=dropout)
         self.attn2 = CrossAttention(
             query_dim=num_channels,
             cross_attention_dim=cross_attention_dim,
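
A minimal sketch of the block these patches switch to, in isolation. It assumes the
pinned monai-weekly==1.2.dev2304 (which provides MLPBlock with act="GEGLU"); the
channel count, dropout value, and input shape below are illustrative only:

    import torch
    from monai.networks.blocks import MLPBlock

    num_channels = 64  # illustrative transformer width
    dropout = 0.1      # illustrative dropout probability

    # Same role as the removed FeedForward: a GEGLU-gated hidden layer of size
    # num_channels * 4, with dropout, projected back to num_channels.
    ff = MLPBlock(hidden_size=num_channels, mlp_dim=num_channels * 4, act="GEGLU", dropout_rate=dropout)

    x = torch.randn(2, 16, num_channels)  # (batch, tokens, channels)
    print(ff(x).shape)  # torch.Size([2, 16, 64])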