diff --git a/generative/networks/nets/autoencoderkl.py b/generative/networks/nets/autoencoderkl.py
index 23750839..276bf86f 100644
--- a/generative/networks/nets/autoencoderkl.py
+++ b/generative/networks/nets/autoencoderkl.py
@@ -187,6 +187,7 @@ class AttentionBlock(nn.Module):
         norm_num_groups: number of groups involved for the group normalisation layer. Ensure that your number of
             channels is divisible by this number.
         norm_eps: epsilon value to use for the normalisation.
+        use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
     """

     def __init__(
@@ -196,8 +197,10 @@ def __init__(
         num_head_channels: int | None = None,
         norm_num_groups: int = 32,
         norm_eps: float = 1e-6,
+        use_flash_attention: bool = False,
     ) -> None:
         super().__init__()
+        self.use_flash_attention = use_flash_attention

         self.spatial_dims = spatial_dims
         self.num_channels = num_channels
@@ -276,7 +279,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         key = self.reshape_heads_to_batch_dim(key)
         value = self.reshape_heads_to_batch_dim(value)

-        if has_xformers:
+        if self.use_flash_attention:
             x = self._memory_efficient_attention_xformers(query, key, value)
         else:
             x = self._attention(query, key, value)
@@ -306,6 +309,7 @@ class Encoder(nn.Module):
         norm_eps: epsilon for the normalization.
         attention_levels: indicate which level from num_channels contain an attention block.
         with_nonlocal_attn: if True use non-local attention block.
+        use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
     """

     def __init__(
@@ -319,6 +323,7 @@ def __init__(
         norm_eps: float,
         attention_levels: Sequence[bool],
         with_nonlocal_attn: bool = True,
+        use_flash_attention: bool = False,
     ) -> None:
         super().__init__()
         self.spatial_dims = spatial_dims
@@ -369,6 +374,7 @@ def __init__(
                             num_channels=input_channel,
                             norm_num_groups=norm_num_groups,
                             norm_eps=norm_eps,
+                            use_flash_attention=use_flash_attention,
                         )
                     )

@@ -393,6 +399,7 @@ def __init__(
                     num_channels=num_channels[-1],
                     norm_num_groups=norm_num_groups,
                     norm_eps=norm_eps,
+                    use_flash_attention=use_flash_attention,
                 )
             )
             blocks.append(
@@ -442,6 +449,7 @@ class Decoder(nn.Module):
         norm_eps: epsilon for the normalization.
         attention_levels: indicate which level from num_channels contain an attention block.
         with_nonlocal_attn: if True use non-local attention block.
+        use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
     """

     def __init__(
@@ -455,6 +463,7 @@ def __init__(
         norm_eps: float,
         attention_levels: Sequence[bool],
         with_nonlocal_attn: bool = True,
+        use_flash_attention: bool = False,
     ) -> None:
         super().__init__()
         self.spatial_dims = spatial_dims
@@ -499,6 +508,7 @@ def __init__(
                     num_channels=reversed_block_out_channels[0],
                     norm_num_groups=norm_num_groups,
                     norm_eps=norm_eps,
+                    use_flash_attention=use_flash_attention,
                 )
             )
             blocks.append(
@@ -538,6 +548,7 @@ def __init__(
                             num_channels=block_in_ch,
                             norm_num_groups=norm_num_groups,
                             norm_eps=norm_eps,
+                            use_flash_attention=use_flash_attention,
                         )
                     )

@@ -583,6 +594,7 @@ class AutoencoderKL(nn.Module):
         norm_eps: epsilon for the normalization.
         with_encoder_nonlocal_attn: if True use non-local attention block in the encoder.
         with_decoder_nonlocal_attn: if True use non-local attention block in the decoder.
+        use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
""" def __init__( @@ -598,6 +610,7 @@ def __init__( norm_eps: float = 1e-6, with_encoder_nonlocal_attn: bool = True, with_decoder_nonlocal_attn: bool = True, + use_flash_attention: bool = False, ) -> None: super().__init__() @@ -617,6 +630,11 @@ def __init__( "`num_channels`." ) + if use_flash_attention is True and not torch.cuda.is_available(): + raise ValueError( + "torch.cuda.is_available() should be True but is False. Flash attention is only available for GPU." + ) + self.encoder = Encoder( spatial_dims=spatial_dims, in_channels=in_channels, @@ -627,6 +645,7 @@ def __init__( norm_eps=norm_eps, attention_levels=attention_levels, with_nonlocal_attn=with_encoder_nonlocal_attn, + use_flash_attention=use_flash_attention, ) self.decoder = Decoder( spatial_dims=spatial_dims, @@ -638,6 +657,7 @@ def __init__( norm_eps=norm_eps, attention_levels=attention_levels, with_nonlocal_attn=with_decoder_nonlocal_attn, + use_flash_attention=use_flash_attention, ) self.quant_conv_mu = Convolution( spatial_dims=spatial_dims, diff --git a/generative/networks/nets/diffusion_model_unet.py b/generative/networks/nets/diffusion_model_unet.py index fba39153..fe8bf95c 100644 --- a/generative/networks/nets/diffusion_model_unet.py +++ b/generative/networks/nets/diffusion_model_unet.py @@ -80,6 +80,7 @@ class CrossAttention(nn.Module): num_head_channels: number of channels in each head. dropout: dropout probability to use. upcast_attention: if True, upcast attention operations to full precision. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. """ def __init__( @@ -185,6 +186,7 @@ class BasicTransformerBlock(nn.Module): dropout: dropout probability to use. cross_attention_dim: size of the context vector for cross attention. upcast_attention: if True, upcast attention operations to full precision. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. """ def __init__( @@ -195,6 +197,7 @@ def __init__( dropout: float = 0.0, cross_attention_dim: int | None = None, upcast_attention: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__() self.attn1 = CrossAttention( @@ -203,6 +206,7 @@ def __init__( num_head_channels=num_head_channels, dropout=dropout, upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, ) # is a self-attention self.ff = MLPBlock(hidden_size=num_channels, mlp_dim=num_channels * 4, act="GEGLU", dropout_rate=dropout) self.attn2 = CrossAttention( @@ -212,6 +216,7 @@ def __init__( num_head_channels=num_head_channels, dropout=dropout, upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, ) # is a self-attention if context is None self.norm1 = nn.LayerNorm(num_channels) self.norm2 = nn.LayerNorm(num_channels) @@ -245,6 +250,7 @@ class SpatialTransformer(nn.Module): norm_eps: epsilon for the normalization. cross_attention_dim: number of context dimensions to use. upcast_attention: if True, upcast attention operations to full precision. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. 
""" def __init__( @@ -259,6 +265,7 @@ def __init__( norm_eps: float = 1e-6, cross_attention_dim: int | None = None, upcast_attention: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__() self.spatial_dims = spatial_dims @@ -286,6 +293,7 @@ def __init__( dropout=dropout, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, ) for _ in range(num_layers) ] @@ -346,6 +354,7 @@ class AttentionBlock(nn.Module): norm_num_groups: number of groups involved for the group normalisation layer. Ensure that your number of channels is divisible by this number. norm_eps: epsilon value to use for the normalisation. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. """ def __init__( @@ -451,7 +460,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: def get_timestep_embedding(timesteps: torch.Tensor, embedding_dim: int, max_period: int = 10000) -> torch.Tensor: """ - Create sinusoidal timestep embeddingsfollowing the implementation in Ho et al. "Denoising Diffusion Probabilistic + Create sinusoidal timestep embeddings following the implementation in Ho et al. "Denoising Diffusion Probabilistic Models" https://arxiv.org/abs/2006.11239. Args: @@ -784,6 +793,7 @@ class AttnDownBlock(nn.Module): resblock_updown: if True use residual blocks for downsampling. downsample_padding: padding used in the downsampling block. num_head_channels: number of channels in each attention head. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. """ def __init__( @@ -799,6 +809,7 @@ def __init__( resblock_updown: bool = False, downsample_padding: int = 1, num_head_channels: int = 1, + use_flash_attention: bool = False, ) -> None: super().__init__() self.resblock_updown = resblock_updown @@ -825,6 +836,7 @@ def __init__( num_head_channels=num_head_channels, norm_num_groups=norm_num_groups, norm_eps=norm_eps, + use_flash_attention=use_flash_attention, ) ) @@ -890,6 +902,7 @@ class CrossAttnDownBlock(nn.Module): transformer_num_layers: number of layers of Transformer blocks to use. cross_attention_dim: number of context dimensions to use. upcast_attention: if True, upcast attention operations to full precision. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. """ def __init__( @@ -908,6 +921,7 @@ def __init__( transformer_num_layers: int = 1, cross_attention_dim: int | None = None, upcast_attention: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__() self.resblock_updown = resblock_updown @@ -939,6 +953,7 @@ def __init__( norm_eps=norm_eps, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, ) ) @@ -995,6 +1010,7 @@ class AttnMidBlock(nn.Module): norm_num_groups: number of groups for the group normalization. norm_eps: epsilon for the group normalization. num_head_channels: number of channels in each attention head. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. 
""" def __init__( @@ -1005,6 +1021,7 @@ def __init__( norm_num_groups: int = 32, norm_eps: float = 1e-6, num_head_channels: int = 1, + use_flash_attention: bool = False, ) -> None: super().__init__() self.attention = None @@ -1023,6 +1040,7 @@ def __init__( num_head_channels=num_head_channels, norm_num_groups=norm_num_groups, norm_eps=norm_eps, + use_flash_attention=use_flash_attention, ) self.resnet_2 = ResnetBlock( @@ -1059,6 +1077,7 @@ class CrossAttnMidBlock(nn.Module): transformer_num_layers: number of layers of Transformer blocks to use. cross_attention_dim: number of context dimensions to use. upcast_attention: if True, upcast attention operations to full precision. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. """ def __init__( @@ -1072,6 +1091,7 @@ def __init__( transformer_num_layers: int = 1, cross_attention_dim: int | None = None, upcast_attention: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__() self.attention = None @@ -1094,6 +1114,7 @@ def __init__( norm_eps=norm_eps, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, ) self.resnet_2 = ResnetBlock( spatial_dims=spatial_dims, @@ -1221,6 +1242,7 @@ class AttnUpBlock(nn.Module): add_upsample: if True add downsample block. resblock_updown: if True use residual blocks for upsampling. num_head_channels: number of channels in each attention head. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. """ def __init__( @@ -1236,6 +1258,7 @@ def __init__( add_upsample: bool = True, resblock_updown: bool = False, num_head_channels: int = 1, + use_flash_attention: bool = False, ) -> None: super().__init__() self.resblock_updown = resblock_updown @@ -1264,6 +1287,7 @@ def __init__( num_head_channels=num_head_channels, norm_num_groups=norm_num_groups, norm_eps=norm_eps, + use_flash_attention=use_flash_attention, ) ) @@ -1330,6 +1354,7 @@ class CrossAttnUpBlock(nn.Module): transformer_num_layers: number of layers of Transformer blocks to use. cross_attention_dim: number of context dimensions to use. upcast_attention: if True, upcast attention operations to full precision. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. 
""" def __init__( @@ -1348,6 +1373,7 @@ def __init__( transformer_num_layers: int = 1, cross_attention_dim: int | None = None, upcast_attention: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__() self.resblock_updown = resblock_updown @@ -1380,6 +1406,7 @@ def __init__( num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, ) ) @@ -1442,6 +1469,7 @@ def get_down_block( transformer_num_layers: int, cross_attention_dim: int | None, upcast_attention: bool = False, + use_flash_attention: bool = False, ) -> nn.Module: if with_attn: return AttnDownBlock( @@ -1455,6 +1483,7 @@ def get_down_block( add_downsample=add_downsample, resblock_updown=resblock_updown, num_head_channels=num_head_channels, + use_flash_attention=use_flash_attention, ) elif with_cross_attn: return CrossAttnDownBlock( @@ -1471,6 +1500,7 @@ def get_down_block( transformer_num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, ) else: return DownBlock( @@ -1497,6 +1527,7 @@ def get_mid_block( transformer_num_layers: int, cross_attention_dim: int | None, upcast_attention: bool = False, + use_flash_attention: bool = False, ) -> nn.Module: if with_conditioning: return CrossAttnMidBlock( @@ -1509,6 +1540,7 @@ def get_mid_block( transformer_num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, ) else: return AttnMidBlock( @@ -1518,6 +1550,7 @@ def get_mid_block( norm_num_groups=norm_num_groups, norm_eps=norm_eps, num_head_channels=num_head_channels, + use_flash_attention=use_flash_attention, ) @@ -1538,6 +1571,7 @@ def get_up_block( transformer_num_layers: int, cross_attention_dim: int | None, upcast_attention: bool = False, + use_flash_attention: bool = False, ) -> nn.Module: if with_attn: return AttnUpBlock( @@ -1552,6 +1586,7 @@ def get_up_block( add_upsample=add_upsample, resblock_updown=resblock_updown, num_head_channels=num_head_channels, + use_flash_attention=use_flash_attention, ) elif with_cross_attn: return CrossAttnUpBlock( @@ -1569,6 +1604,7 @@ def get_up_block( transformer_num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, ) else: return UpBlock( @@ -1725,6 +1761,7 @@ def __init__( transformer_num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, ) self.down_blocks.append(down_block) @@ -1741,6 +1778,7 @@ def __init__( transformer_num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, ) # up @@ -1774,6 +1812,7 @@ def __init__( transformer_num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, ) self.up_blocks.append(up_block) diff --git a/tests/min_tests.py b/tests/min_tests.py index b4373dd8..1bc9eed7 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -29,10 +29,11 @@ def run_testsuit(): :return: a test suite """ exclude_cases = [ # these cases use external dependencies - "test_perceptual_loss", + "test_autoencoderkl", + "test_diffusion_inferer", "test_integration_workflows_adversarial", 
"test_latent_diffusion_inferer", - "test_diffusion_inferer", + "test_perceptual_loss", "test_transformer", ] assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}" diff --git a/tests/test_autoencoderkl.py b/tests/test_autoencoderkl.py index e6280169..3c6ca000 100644 --- a/tests/test_autoencoderkl.py +++ b/tests/test_autoencoderkl.py @@ -18,7 +18,6 @@ from parameterized import parameterized from generative.networks.nets import AutoencoderKL -from tests.utils import test_script_save device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") @@ -144,11 +143,11 @@ def test_shape(self, input_param, input_shape, expected_shape, expected_latent_s self.assertEqual(result[1].shape, expected_latent_shape) self.assertEqual(result[2].shape, expected_latent_shape) - def test_script(self): - input_param, input_shape, _, _ = CASES[0] - net = AutoencoderKL(**input_param) - test_data = torch.randn(input_shape) - test_script_save(net, test_data) + # def test_script(self): + # input_param, input_shape, _, _ = CASES[0] + # net = AutoencoderKL(**input_param) + # test_data = torch.randn(input_shape) + # test_script_save(net, test_data) def test_model_channels_not_multiple_of_norm_num_group(self): with self.assertRaises(ValueError):