Skip to content
This repository was archived by the owner on Feb 7, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion generative/networks/nets/autoencoderkl.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ class AttentionBlock(nn.Module):
norm_num_groups: number of groups involved for the group normalisation layer. Ensure that your number of
channels is divisible by this number.
norm_eps: epsilon value to use for the normalisation.
use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
"""

def __init__(
Expand All @@ -196,8 +197,10 @@ def __init__(
num_head_channels: int | None = None,
norm_num_groups: int = 32,
norm_eps: float = 1e-6,
use_flash_attention: bool = False,
) -> None:
super().__init__()
self.use_flash_attention = use_flash_attention
self.spatial_dims = spatial_dims
self.num_channels = num_channels

Expand Down Expand Up @@ -276,7 +279,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
key = self.reshape_heads_to_batch_dim(key)
value = self.reshape_heads_to_batch_dim(value)

if has_xformers:
if self.use_flash_attention:
x = self._memory_efficient_attention_xformers(query, key, value)
else:
x = self._attention(query, key, value)
Expand Down Expand Up @@ -306,6 +309,7 @@ class Encoder(nn.Module):
norm_eps: epsilon for the normalization.
attention_levels: indicate which level from num_channels contain an attention block.
with_nonlocal_attn: if True use non-local attention block.
use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
"""

def __init__(
Expand All @@ -319,6 +323,7 @@ def __init__(
norm_eps: float,
attention_levels: Sequence[bool],
with_nonlocal_attn: bool = True,
use_flash_attention: bool = False,
) -> None:
super().__init__()
self.spatial_dims = spatial_dims
Expand Down Expand Up @@ -369,6 +374,7 @@ def __init__(
num_channels=input_channel,
norm_num_groups=norm_num_groups,
norm_eps=norm_eps,
use_flash_attention=use_flash_attention,
)
)

Expand All @@ -393,6 +399,7 @@ def __init__(
num_channels=num_channels[-1],
norm_num_groups=norm_num_groups,
norm_eps=norm_eps,
use_flash_attention=use_flash_attention,
)
)
blocks.append(
Expand Down Expand Up @@ -442,6 +449,7 @@ class Decoder(nn.Module):
norm_eps: epsilon for the normalization.
attention_levels: indicate which level from num_channels contain an attention block.
with_nonlocal_attn: if True use non-local attention block.
use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
"""

def __init__(
Expand All @@ -455,6 +463,7 @@ def __init__(
norm_eps: float,
attention_levels: Sequence[bool],
with_nonlocal_attn: bool = True,
use_flash_attention: bool = False,
) -> None:
super().__init__()
self.spatial_dims = spatial_dims
Expand Down Expand Up @@ -499,6 +508,7 @@ def __init__(
num_channels=reversed_block_out_channels[0],
norm_num_groups=norm_num_groups,
norm_eps=norm_eps,
use_flash_attention=use_flash_attention,
)
)
blocks.append(
Expand Down Expand Up @@ -538,6 +548,7 @@ def __init__(
num_channels=block_in_ch,
norm_num_groups=norm_num_groups,
norm_eps=norm_eps,
use_flash_attention=use_flash_attention,
)
)

Expand Down Expand Up @@ -583,6 +594,7 @@ class AutoencoderKL(nn.Module):
norm_eps: epsilon for the normalization.
with_encoder_nonlocal_attn: if True use non-local attention block in the encoder.
with_decoder_nonlocal_attn: if True use non-local attention block in the decoder.
use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
"""

def __init__(
Expand All @@ -598,6 +610,7 @@ def __init__(
norm_eps: float = 1e-6,
with_encoder_nonlocal_attn: bool = True,
with_decoder_nonlocal_attn: bool = True,
use_flash_attention: bool = False,
) -> None:
super().__init__()

Expand All @@ -617,6 +630,11 @@ def __init__(
"`num_channels`."
)

if use_flash_attention is True and not torch.cuda.is_available():
raise ValueError(
"torch.cuda.is_available() should be True but is False. Flash attention is only available for GPU."
)

self.encoder = Encoder(
spatial_dims=spatial_dims,
in_channels=in_channels,
Expand All @@ -627,6 +645,7 @@ def __init__(
norm_eps=norm_eps,
attention_levels=attention_levels,
with_nonlocal_attn=with_encoder_nonlocal_attn,
use_flash_attention=use_flash_attention,
)
self.decoder = Decoder(
spatial_dims=spatial_dims,
Expand All @@ -638,6 +657,7 @@ def __init__(
norm_eps=norm_eps,
attention_levels=attention_levels,
with_nonlocal_attn=with_decoder_nonlocal_attn,
use_flash_attention=use_flash_attention,
)
self.quant_conv_mu = Convolution(
spatial_dims=spatial_dims,
Expand Down
Loading