15 changes: 13 additions & 2 deletions generative/networks/nets/diffusion_model_unet.py
@@ -90,8 +90,10 @@ def __init__(
         num_head_channels: int = 64,
         dropout: float = 0.0,
         upcast_attention: bool = False,
+        use_flash_attention: bool = False,
     ) -> None:
         super().__init__()
+        self.use_flash_attention = use_flash_attention
         inner_dim = num_head_channels * num_attention_heads
         cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim

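The attention width in this block is inner_dim = num_head_channels * num_attention_heads, and the forward pass below folds the head dimension into the batch dimension via reshape_heads_to_batch_dim before attention is computed. A minimal sketch of that reshaping, not copied from the repository:

import torch

def reshape_heads_to_batch_dim_sketch(x: torch.Tensor, num_heads: int) -> torch.Tensor:
    # Sketch only: (batch, seq_len, num_heads * head_dim) -> (batch * num_heads, seq_len, head_dim),
    # so attention can be computed per head with batched matrix multiplies.
    batch, seq_len, inner_dim = x.shape
    head_dim = inner_dim // num_heads
    x = x.reshape(batch, seq_len, num_heads, head_dim)
    return x.permute(0, 2, 1, 3).reshape(batch * num_heads, seq_len, head_dim)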
@@ -161,7 +163,7 @@ def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor:
         key = self.reshape_heads_to_batch_dim(key)
         value = self.reshape_heads_to_batch_dim(value)
 
-        if has_xformers:
+        if self.use_flash_attention:
             x = self._memory_efficient_attention_xformers(query, key, value)
         else:
             x = self._attention(query, key, value)
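The use_flash_attention branch above delegates to _memory_efficient_attention_xformers, which relies on the xformers library. A hedged sketch of what such a call typically looks like for (batch * heads, seq_len, head_dim) tensors; the repository's exact implementation may differ:

import torch
import xformers.ops

def memory_efficient_attention_sketch(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    # Sketch only: xformers fuses softmax(Q K^T / sqrt(d)) V so the full
    # attention matrix is never materialised, reducing peak memory use.
    query, key, value = query.contiguous(), key.contiguous(), value.contiguous()
    return xformers.ops.memory_efficient_attention(query, key, value, attn_bias=None)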
@@ -353,8 +355,10 @@ def __init__(
         num_head_channels: int | None = None,
         norm_num_groups: int = 32,
         norm_eps: float = 1e-6,
+        use_flash_attention: bool = False,
     ) -> None:
         super().__init__()
+        self.use_flash_attention = use_flash_attention
         self.spatial_dims = spatial_dims
         self.num_channels = num_channels

@@ -429,7 +433,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         key = self.reshape_heads_to_batch_dim(key)
         value = self.reshape_heads_to_batch_dim(value)
 
-        if has_xformers:
+        if self.use_flash_attention:
             x = self._memory_efficient_attention_xformers(query, key, value)
         else:
             x = self._attention(query, key, value)
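When use_flash_attention is False, both blocks fall back to _attention, i.e. standard scaled dot-product attention that materialises the full attention matrix. A minimal reference sketch, assuming a 1/sqrt(head_dim) scale:

import math
import torch

def standard_attention_sketch(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    # Sketch only: explicit softmax(Q K^T / sqrt(d)) V over
    # (batch * heads, seq_len, head_dim) tensors, for comparison with the
    # memory-efficient path above.
    scale = 1.0 / math.sqrt(query.shape[-1])
    attention_scores = torch.bmm(query, key.transpose(-1, -2)) * scale
    attention_probs = attention_scores.softmax(dim=-1)
    return torch.bmm(attention_probs, value)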
@@ -1604,6 +1608,7 @@ class DiffusionModelUNet(nn.Module):
         num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds`
             classes.
         upcast_attention: if True, upcast attention operations to full precision.
+        use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
     """
 
     def __init__(
@@ -1623,6 +1628,7 @@ def __init__(
         cross_attention_dim: int | None = None,
         num_class_embeds: int | None = None,
         upcast_attention: bool = False,
+        use_flash_attention: bool = False,
     ) -> None:
         super().__init__()
         if with_conditioning is True and cross_attention_dim is None:
@@ -1660,6 +1666,11 @@ def __init__(
                 "`num_channels`."
             )
 
+        if use_flash_attention is True and not torch.cuda.is_available():
+            raise ValueError(
+                "torch.cuda.is_available() should be True but is False. Flash attention is only available for GPU."
+            )
+
         self.in_channels = in_channels
         self.block_out_channels = num_channels
         self.out_channels = out_channels
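Taken together, the changes expose flash attention as an opt-in constructor argument that requires a CUDA device; otherwise the new check raises a ValueError. A hedged usage sketch, where the import path and all constructor arguments other than use_flash_attention are illustrative assumptions rather than values taken from this diff:

import torch
from generative.networks.nets import DiffusionModelUNet  # assumed import path

# Only enable the flag when a GPU is present, matching the new constructor check.
flash = torch.cuda.is_available()

model = DiffusionModelUNet(
    spatial_dims=2,                        # illustrative configuration
    in_channels=1,
    out_channels=1,
    num_res_blocks=1,
    num_channels=(32, 64, 64),
    attention_levels=(False, True, True),
    num_head_channels=64,
    use_flash_attention=flash,
)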
1 change: 1 addition & 0 deletions requirements-dev.txt
@@ -54,4 +54,5 @@ nni
 optuna
 git+https://github.com/Project-MONAI/MetricsReloaded@monai-support#egg=MetricsReloaded
 lpips==0.1.4
+xformers==0.0.16
 x-transformers==1.8.1