From 1756a2ca2d05e9b861d174c83a663398df9b59e2 Mon Sep 17 00:00:00 2001
From: Walter Hugo Lopez Pinaya
Date: Sat, 4 Feb 2023 14:33:58 +0000
Subject: [PATCH 1/4] Add option to use flash attention

Signed-off-by: Walter Hugo Lopez Pinaya
---
 .../networks/nets/diffusion_model_unet.py | 17 +++++++++++++++--
 1 file changed, 15 insertions(+), 2 deletions(-)

diff --git a/generative/networks/nets/diffusion_model_unet.py b/generative/networks/nets/diffusion_model_unet.py
index 76caa08b..19f70f52 100644
--- a/generative/networks/nets/diffusion_model_unet.py
+++ b/generative/networks/nets/diffusion_model_unet.py
@@ -87,8 +87,10 @@ def __init__(
         num_head_channels: int = 64,
         dropout: float = 0.0,
         upcast_attention: bool = False,
+        use_flash_attention: bool = False,
     ) -> None:
         super().__init__()
+        self.use_flash_attention = use_flash_attention
         inner_dim = num_head_channels * num_attention_heads
 
         cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
@@ -158,7 +160,7 @@ def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None) -> torch.Tensor:
         key = self.reshape_heads_to_batch_dim(key)
         value = self.reshape_heads_to_batch_dim(value)
 
-        if has_xformers:
+        if self.use_flash_attention:
             x = self._memory_efficient_attention_xformers(query, key, value)
         else:
             x = self._attention(query, key, value)
@@ -350,8 +352,10 @@ def __init__(
         num_head_channels: Optional[int] = None,
         norm_num_groups: int = 32,
         norm_eps: float = 1e-6,
+        use_flash_attention: bool = False,
     ) -> None:
         super().__init__()
+        self.use_flash_attention = use_flash_attention
         self.spatial_dims = spatial_dims
         self.num_channels = num_channels
@@ -426,7 +430,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         key = self.reshape_heads_to_batch_dim(key)
         value = self.reshape_heads_to_batch_dim(value)
 
-        if has_xformers:
+        if self.use_flash_attention:
             x = self._memory_efficient_attention_xformers(query, key, value)
         else:
             x = self._attention(query, key, value)
@@ -1620,6 +1624,7 @@ def __init__(
         cross_attention_dim: Optional[int] = None,
         num_class_embeds: Optional[int] = None,
         upcast_attention: bool = False,
+        use_flash_attention: bool = False,
     ) -> None:
         super().__init__()
         if with_conditioning is True and cross_attention_dim is None:
@@ -1648,6 +1653,14 @@ def __init__(
                 " i.e. `attention_level[i]=False`, the num_head_channels[i] will be ignored."
             )
 
+        if has_xformers is False and use_flash_attention is True:
+            raise ValueError("DiffusionModelUNet expects xformers to be installed when using flash attention.")
+
+        if use_flash_attention is True and not torch.cuda.is_available():
+            raise ValueError(
+                "torch.cuda.is_available() should be True but is False. Flash attention is only available for GPU."
+            )
+
         self.in_channels = in_channels
         self.block_out_channels = num_channels
         self.out_channels = out_channels

From 258aa6668ab1432c561bef66f3eb5a9f8dc788d6 Mon Sep 17 00:00:00 2001
From: Walter Hugo Lopez Pinaya
Date: Sat, 4 Feb 2023 14:36:51 +0000
Subject: [PATCH 2/4] Fix docstring

Signed-off-by: Walter Hugo Lopez Pinaya
---
 generative/networks/nets/diffusion_model_unet.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/generative/networks/nets/diffusion_model_unet.py b/generative/networks/nets/diffusion_model_unet.py
index 19f70f52..9c1a70a6 100644
--- a/generative/networks/nets/diffusion_model_unet.py
+++ b/generative/networks/nets/diffusion_model_unet.py
@@ -1605,6 +1605,7 @@ class DiffusionModelUNet(nn.Module):
         num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds`
             classes.
         upcast_attention: if True, upcast attention operations to full precision.
+        use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
     """
 
     def __init__(

From 9b31ca82ad742b42d08e2c92dde3492fe87891ca Mon Sep 17 00:00:00 2001
From: Walter Hugo Lopez Pinaya
Date: Sat, 4 Feb 2023 14:48:37 +0000
Subject: [PATCH 3/4] Add xformers==0.0.16 to requirements-dev.txt

Signed-off-by: Walter Hugo Lopez Pinaya
---
 requirements-dev.txt | 1 +
 1 file changed, 1 insertion(+)

diff --git a/requirements-dev.txt b/requirements-dev.txt
index 1a9ddd18..5b19e5d4 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -54,3 +54,4 @@ nni
 optuna
 git+https://github.com/Project-MONAI/MetricsReloaded@monai-support#egg=MetricsReloaded
 lpips==0.1.4
+xformers==0.0.16

From 617fffae0c57230c1c0559851f9a73c7393d1948 Mon Sep 17 00:00:00 2001
From: Walter Hugo Lopez Pinaya
Date: Sun, 26 Feb 2023 19:18:25 +0000
Subject: [PATCH 4/4] Fixed torchscript error

Signed-off-by: Walter Hugo Lopez Pinaya
---
 generative/networks/nets/diffusion_model_unet.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/generative/networks/nets/diffusion_model_unet.py b/generative/networks/nets/diffusion_model_unet.py
index e7d1aead..fba39153 100644
--- a/generative/networks/nets/diffusion_model_unet.py
+++ b/generative/networks/nets/diffusion_model_unet.py
@@ -1666,9 +1666,6 @@ def __init__(
                 "`num_channels`."
             )
 
-        if has_xformers is False and use_flash_attention is True:
-            raise ValueError("DiffusionModelUNet expects xformers to be installed when using flash attention.")
-
         if use_flash_attention is True and not torch.cuda.is_available():
             raise ValueError(
                 "torch.cuda.is_available() should be True but is False. Flash attention is only available for GPU."
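
A minimal usage sketch of the option introduced by this series, not part of the patches themselves: constructor and forward argument names other than `use_flash_attention` are assumed from the DiffusionModelUNet signature in this version of the repository and may differ in other releases. Per the checks added in patch 1/4, flash attention needs xformers installed and a CUDA device.

# Hypothetical example: enable the new flash attention path in DiffusionModelUNet.
# Argument names besides `use_flash_attention` are assumptions based on this repo version.
import torch

from generative.networks.nets import DiffusionModelUNet

model = DiffusionModelUNet(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    num_res_blocks=1,
    num_channels=(32, 64, 64),
    attention_levels=(False, True, True),
    use_flash_attention=True,  # requires xformers and torch.cuda.is_available() == True
).to("cuda")

# Predict noise for a batch of two single-channel 64x64 images at random timesteps.
x = torch.randn(2, 1, 64, 64, device="cuda")
timesteps = torch.randint(0, 1000, (2,), device="cuda").long()
noise_pred = model(x=x, timesteps=timesteps)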