diff --git a/generative/networks/nets/diffusion_model_unet.py b/generative/networks/nets/diffusion_model_unet.py
index b651f206..fba39153 100644
--- a/generative/networks/nets/diffusion_model_unet.py
+++ b/generative/networks/nets/diffusion_model_unet.py
@@ -90,8 +90,10 @@ def __init__(
         num_head_channels: int = 64,
         dropout: float = 0.0,
         upcast_attention: bool = False,
+        use_flash_attention: bool = False,
     ) -> None:
         super().__init__()
+        self.use_flash_attention = use_flash_attention
         inner_dim = num_head_channels * num_attention_heads
 
         cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
@@ -161,7 +163,7 @@ def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch
         key = self.reshape_heads_to_batch_dim(key)
         value = self.reshape_heads_to_batch_dim(value)
 
-        if has_xformers:
+        if self.use_flash_attention:
             x = self._memory_efficient_attention_xformers(query, key, value)
         else:
             x = self._attention(query, key, value)
@@ -353,8 +355,10 @@ def __init__(
         num_head_channels: int | None = None,
         norm_num_groups: int = 32,
         norm_eps: float = 1e-6,
+        use_flash_attention: bool = False,
     ) -> None:
         super().__init__()
+        self.use_flash_attention = use_flash_attention
         self.spatial_dims = spatial_dims
         self.num_channels = num_channels
 
@@ -429,7 +433,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         key = self.reshape_heads_to_batch_dim(key)
         value = self.reshape_heads_to_batch_dim(value)
 
-        if has_xformers:
+        if self.use_flash_attention:
             x = self._memory_efficient_attention_xformers(query, key, value)
         else:
             x = self._attention(query, key, value)
@@ -1604,6 +1608,7 @@ class DiffusionModelUNet(nn.Module):
         num_class_embeds: if specified (as an int), then this model will be class-conditional with
             `num_class_embeds` classes.
         upcast_attention: if True, upcast attention operations to full precision.
+        use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
     """
 
     def __init__(
@@ -1623,6 +1628,7 @@ def __init__(
         cross_attention_dim: int | None = None,
         num_class_embeds: int | None = None,
         upcast_attention: bool = False,
+        use_flash_attention: bool = False,
     ) -> None:
         super().__init__()
         if with_conditioning is True and cross_attention_dim is None:
@@ -1660,6 +1666,11 @@ def __init__(
                 "`num_channels`."
             )
 
+        if use_flash_attention is True and not torch.cuda.is_available():
+            raise ValueError(
+                "torch.cuda.is_available() should be True but is False. Flash attention is only available for GPU."
+            )
+
         self.in_channels = in_channels
         self.block_out_channels = num_channels
         self.out_channels = out_channels
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 2c8e5a6d..71ccc476 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -54,4 +54,5 @@ nni
 optuna
 git+https://github.com/Project-MONAI/MetricsReloaded@monai-support#egg=MetricsReloaded
 lpips==0.1.4
+xformers==0.0.16
 x-transformers==1.8.1
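
Usage note (not part of the patch): a minimal sketch of how the new flag is exercised. Only use_flash_attention is introduced by this diff; the import path, the other constructor arguments, and the forward(x, timesteps) call are assumptions based on the existing DiffusionModelUNet API and may differ in practice. The flag requires a CUDA device (the new ValueError check) and the xformers package pinned in requirements-dev.txt.

import torch

# Assumed import path; the class lives in generative/networks/nets/diffusion_model_unet.py.
from generative.networks.nets import DiffusionModelUNet

device = torch.device("cuda")  # use_flash_attention=True raises ValueError when CUDA is unavailable

model = DiffusionModelUNet(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    num_res_blocks=(1, 1, 1),
    num_channels=(32, 64, 64),
    attention_levels=(False, True, True),
    num_head_channels=64,
    use_flash_attention=True,  # routes attention through the xformers memory-efficient kernel
).to(device)

x = torch.randn(2, 1, 64, 64, device=device)
timesteps = torch.randint(0, 1000, (2,), device=device).long()
noise_pred = model(x, timesteps=timesteps)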