This repository was archived by the owner on Feb 7, 2025. It is now read-only.
Merged
50 changes: 4 additions & 46 deletions generative/networks/nets/diffusion_model_unet.py
@@ -35,7 +35,7 @@
 
 import torch
 import torch.nn.functional as F
-from monai.networks.blocks import Convolution
+from monai.networks.blocks import Convolution, MLPBlock
 from monai.networks.layers.factories import Pool
 from torch import nn
 
@@ -66,46 +66,6 @@ def zero_module(module: nn.Module) -> nn.Module:
     return module
 
 
-class GEGLU(nn.Module):
-    """
-    A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
-
-    Args:
-        dim_in: number of channels in the input.
-        dim_out: number of channels in the output.
-    """
-
-    def __init__(self, dim_in: int, dim_out: int) -> None:
-        super().__init__()
-        self.proj = nn.Linear(dim_in, dim_out * 2)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x, gate = self.proj(x).chunk(2, dim=-1)
-        return x * F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
-
-
-class FeedForward(nn.Module):
-    """
-    A feed-forward layer.
-
-    Args:
-        num_channels: number of channels in the input.
-        dim_out: number of channels in the output. If not given, defaults to `dim`.
-        mult: multiplier to use for the hidden dimension.
-        dropout: dropout probability to use.
-    """
-
-    def __init__(self, num_channels: int, dim_out: Optional[int] = None, mult: int = 4, dropout: float = 0.0) -> None:
-        super().__init__()
-        inner_dim = int(num_channels * mult)
-        dim_out = dim_out if dim_out is not None else num_channels
-
-        self.net = nn.Sequential(GEGLU(num_channels, inner_dim), nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return self.net(x)
-
-
 class CrossAttention(nn.Module):
     """
     A cross attention layer.
@@ -239,7 +199,7 @@ def __init__(
             dropout=dropout,
             upcast_attention=upcast_attention,
         )  # is a self-attention
-        self.ff = FeedForward(num_channels, dropout=dropout)
+        self.ff = MLPBlock(hidden_size=num_channels, mlp_dim=num_channels * 4, act="GEGLU", dropout_rate=dropout)
         self.attn2 = CrossAttention(
             query_dim=num_channels,
             cross_attention_dim=cross_attention_dim,
@@ -1677,10 +1637,8 @@ def __init__(
         super().__init__()
         if with_conditioning is True and cross_attention_dim is None:
             raise ValueError(
-                (
-                    "DiffusionModelUNet expects dimension of the cross-attention conditioning (cross_attention_dim) "
-                    "when using with_conditioning."
-                )
+                "DiffusionModelUNet expects dimension of the cross-attention conditioning (cross_attention_dim) "
+                "when using with_conditioning."
             )
         if cross_attention_dim is not None and with_conditioning is False:
             raise ValueError(
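
For reference, a minimal sketch (not part of this PR) of how the new MLPBlock call above is exercised. It assumes the monai-weekly build pinned below, where MLPBlock accepts act="GEGLU" and widens its first projection accordingly, mirroring the GEGLU-gated FeedForward removed above (exact dropout placement may differ slightly).

import torch
from monai.networks.blocks import MLPBlock

# Minimal sketch, assuming a monai-weekly build where MLPBlock supports act="GEGLU".
num_channels = 64
ff = MLPBlock(hidden_size=num_channels, mlp_dim=num_channels * 4, act="GEGLU", dropout_rate=0.0)

# Token-major input, as used inside the transformer blocks: (batch, sequence, channels).
x = torch.randn(2, 16, num_channels)
out = ff(x)
print(out.shape)  # torch.Size([2, 16, 64]) -- same shape contract as the removed FeedForward(num_channels)
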
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,3 +1,3 @@
 numpy>=1.17
 torch>=1.8
-monai-weekly==1.1.dev2248
+monai-weekly==1.2.dev2304
2 changes: 1 addition & 1 deletion setup.py
@@ -17,6 +17,6 @@
     version="0.1.0",
     description="Installer to help to use the prototypes from MONAI generative models in other projects.",
     install_requires=[
-        "monai-weekly==1.1.dev2248",
+        "monai-weekly==1.2.dev2304",
     ],
 )
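
After bumping the pin, a quick environment check (not part of this PR; the behaviour on older builds is an assumption) can confirm that the installed monai-weekly exposes the GEGLU option used in the diff above.

import monai
from monai.networks.blocks import MLPBlock

# Hypothetical sanity check, assuming the pin from this diff (monai-weekly==1.2.dev2304).
print(monai.__version__)  # expected to report a 1.2.dev build
# Builds without GEGLU support are expected to fail here when resolving the activation.
MLPBlock(hidden_size=8, mlp_dim=32, act="GEGLU", dropout_rate=0.0)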