From c9475ef428d0d13259147019b86581589c4ea5c7 Mon Sep 17 00:00:00 2001 From: Pengfei Guo Date: Sun, 4 Aug 2024 22:58:58 -0400 Subject: [PATCH 1/7] update Signed-off-by: Pengfei Guo --- .../networks/diffusion_model_unet_maisi.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py b/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py index d5f5f6136b..e71c96daf7 100644 --- a/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py +++ b/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py @@ -40,15 +40,16 @@ from monai.utils import ensure_tuple_rep, optional_import from monai.utils.type_conversion import convert_to_tensor -get_down_block, has_get_down_block = optional_import( - "generative.networks.nets.diffusion_model_unet", name="get_down_block" -) -get_mid_block, has_get_mid_block = optional_import( - "generative.networks.nets.diffusion_model_unet", name="get_mid_block" -) -get_timestep_embedding, has_get_timestep_embedding = optional_import( - "generative.networks.nets.diffusion_model_unet", name="get_timestep_embedding" -) +# get_down_block, has_get_down_block = optional_import( +# "generative.networks.nets.diffusion_model_unet", name="get_down_block" +# ) +# get_mid_block, has_get_mid_block = optional_import( +# "generative.networks.nets.diffusion_model_unet", name="get_mid_block" +# ) +# get_timestep_embedding, has_get_timestep_embedding = optional_import( +# "generative.networks.nets.diffusion_model_unet", name="get_timestep_embedding" +# ) +from monai.networks.nets.diffusion_model_unet import get_down_block, get_mid_block, get_timestep_embedding get_up_block, has_get_up_block = optional_import("generative.networks.nets.diffusion_model_unet", name="get_up_block") xformers, has_xformers = optional_import("xformers") zero_module, has_zero_module = optional_import("generative.networks.nets.diffusion_model_unet", name="zero_module") From 66b886ab137406cdf11578a883904dc5e419aa67 Mon Sep 17 00:00:00 2001 From: Pengfei Guo Date: Tue, 6 Aug 2024 02:19:28 +0000 Subject: [PATCH 2/7] update Signed-off-by: Pengfei Guo --- .../networks/diffusion_model_unet_maisi.py | 20 ++------ monai/networks/blocks/crossattention.py | 31 +++++++----- monai/networks/blocks/selfattention.py | 50 ++++++++++++------- monai/networks/blocks/spatialattention.py | 7 ++- monai/networks/nets/diffusion_model_unet.py | 26 ++++++++++ 5 files changed, 85 insertions(+), 49 deletions(-) diff --git a/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py b/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py index e71c96daf7..489c3dcf53 100644 --- a/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py +++ b/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py @@ -40,19 +40,7 @@ from monai.utils import ensure_tuple_rep, optional_import from monai.utils.type_conversion import convert_to_tensor -# get_down_block, has_get_down_block = optional_import( -# "generative.networks.nets.diffusion_model_unet", name="get_down_block" -# ) -# get_mid_block, has_get_mid_block = optional_import( -# "generative.networks.nets.diffusion_model_unet", name="get_mid_block" -# ) -# get_timestep_embedding, has_get_timestep_embedding = optional_import( -# "generative.networks.nets.diffusion_model_unet", name="get_timestep_embedding" -# ) -from monai.networks.nets.diffusion_model_unet import get_down_block, get_mid_block, get_timestep_embedding -get_up_block, 
has_get_up_block = optional_import("generative.networks.nets.diffusion_model_unet", name="get_up_block") -xformers, has_xformers = optional_import("xformers") -zero_module, has_zero_module = optional_import("generative.networks.nets.diffusion_model_unet", name="zero_module") +from monai.networks.nets.diffusion_model_unet import get_down_block, get_mid_block, get_up_block, get_timestep_embedding, zero_module __all__ = ["DiffusionModelUNetMaisi"] @@ -153,9 +141,6 @@ def __init__( "`num_channels`." ) - if use_flash_attention and not has_xformers: - raise ValueError("use_flash_attention is True but xformers is not installed.") - if use_flash_attention is True and not torch.cuda.is_available(): raise ValueError( "torch.cuda.is_available() should be True but is False. Flash attention is only available for GPU." @@ -211,7 +196,6 @@ def __init__( input_channel = output_channel output_channel = num_channels[i] is_final_block = i == len(num_channels) - 1 - down_block = get_down_block( spatial_dims=spatial_dims, in_channels=input_channel, @@ -409,3 +393,5 @@ def forward( h = self.out(h) h_tensor: torch.Tensor = convert_to_tensor(h) return h_tensor + + diff --git a/monai/networks/blocks/crossattention.py b/monai/networks/blocks/crossattention.py index b888ea3942..8c9f6f47f9 100644 --- a/monai/networks/blocks/crossattention.py +++ b/monai/networks/blocks/crossattention.py @@ -15,6 +15,7 @@ import torch import torch.nn as nn +import torch.nn.functional as F from monai.networks.layers.utils import get_rel_pos_embedding_layer from monai.utils import optional_import @@ -44,6 +45,7 @@ def __init__( rel_pos_embedding: Optional[str] = None, input_size: Optional[Tuple] = None, attention_dtype: Optional[torch.dtype] = None, + use_flash_attention: bool = False, ) -> None: """ Args: @@ -119,6 +121,7 @@ def __init__( else None ) self.input_size = input_size + self.use_flash_attention = use_flash_attention def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None): """ @@ -147,22 +150,26 @@ def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None): v = v.view(b, kv_t, self.num_heads, c // self.num_heads).transpose(1, 2) # (b, nh, kv_t, hs) att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale - # apply relative positional embedding if defined - att_mat = self.rel_positional_embedding(x, att_mat, q) if self.rel_positional_embedding is not None else att_mat + if self.use_flash_attention: + x = F.scaled_dot_product_attention(q, k, v) + else: + # apply relative positional embedding if defined + att_mat = self.rel_positional_embedding(x, att_mat, q) if self.rel_positional_embedding is not None else att_mat - if self.causal: - att_mat = att_mat.masked_fill(self.causal_mask[:, :, :t, :kv_t] == 0, float("-inf")) + if self.causal: + att_mat = att_mat.masked_fill(self.causal_mask[:, :, :t, :kv_t] == 0, float("-inf")) - att_mat = att_mat.softmax(dim=-1) + att_mat = att_mat.softmax(dim=-1) - if self.save_attn: - # no gradients and new tensor; - # https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html - self.att_mat = att_mat.detach() + if self.save_attn: + # no gradients and new tensor; + # https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html + self.att_mat = att_mat.detach() - att_mat = self.drop_weights(att_mat) - x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v) + att_mat = self.drop_weights(att_mat) + x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v) x = self.out_rearrange(x) - x = self.out_proj(x) + # x = self.out_proj(x) x = self.drop_output(x) return x + \ No 
newline at end of file diff --git a/monai/networks/blocks/selfattention.py b/monai/networks/blocks/selfattention.py index 3ab1e1fd10..c63efbe713 100644 --- a/monai/networks/blocks/selfattention.py +++ b/monai/networks/blocks/selfattention.py @@ -15,6 +15,7 @@ import torch import torch.nn as nn +import torch.nn.functional as F from monai.networks.layers.utils import get_rel_pos_embedding_layer from monai.utils import optional_import @@ -42,6 +43,7 @@ def __init__( rel_pos_embedding: Optional[str] = None, input_size: Optional[Tuple] = None, attention_dtype: Optional[torch.dtype] = None, + use_flash_attention: bool = False, ) -> None: """ Args: @@ -86,9 +88,13 @@ def __init__( self.hidden_input_size = hidden_input_size if hidden_input_size else hidden_size self.out_proj = nn.Linear(self.inner_dim, self.hidden_input_size) - self.qkv = nn.Linear(self.hidden_input_size, self.inner_dim * 3, bias=qkv_bias) - self.input_rearrange = Rearrange("b h (qkv l d) -> qkv b l h d", qkv=3, l=num_heads) - self.out_rearrange = Rearrange("b h l d -> b l (h d)") + # self.qkv = nn.Linear(self.hidden_input_size, self.inner_dim * 3, bias=qkv_bias) + self.to_q = nn.Linear(self.hidden_input_size, self.inner_dim, bias=qkv_bias) + self.to_k = nn.Linear(self.hidden_input_size, self.inner_dim, bias=qkv_bias) + self.to_v = nn.Linear(self.hidden_input_size, self.inner_dim, bias=qkv_bias) + # self.input_rearrange = Rearrange("b h (qkv l d) -> qkv b l h d", qkv=3, l=num_heads) + self.input_rearrange = Rearrange("b h (c l d) -> (b l) c h d", l=num_heads, c=1) + self.out_rearrange = Rearrange("(b l) h c d -> b h (c l d)", l=num_heads, c=1) self.drop_output = nn.Dropout(dropout_rate) self.drop_weights = nn.Dropout(dropout_rate) self.scale = self.dim_head**-0.5 @@ -97,6 +103,7 @@ def __init__( self.attention_dtype = attention_dtype self.causal = causal self.sequence_length = sequence_length + self.use_flash_attention = use_flash_attention if causal and sequence_length is not None: # causal mask to ensure that attention is only applied to the left in the input sequence @@ -123,31 +130,38 @@ def forward(self, x): Return: torch.Tensor: B x (s_dim_1 * ... 
* s_dim_n) x C """ - output = self.input_rearrange(self.qkv(x)) - q, k, v = output[0], output[1], output[2] + # output = self.input_rearrange(self.qkv(x)) + q = self.input_rearrange(self.to_q(x)) + k = self.input_rearrange(self.to_k(x)) + v = self.input_rearrange(self.to_v(x)) + # q, k, v = output[0], output[1], output[2] if self.attention_dtype is not None: q = q.to(self.attention_dtype) k = k.to(self.attention_dtype) - att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale + if self.use_flash_attention: + x = F.scaled_dot_product_attention(q, k, v).transpose(1,2) + else: + att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale - # apply relative positional embedding if defined - att_mat = self.rel_positional_embedding(x, att_mat, q) if self.rel_positional_embedding is not None else att_mat + # apply relative positional embedding if defined + att_mat = self.rel_positional_embedding(x, att_mat, q) if self.rel_positional_embedding is not None else att_mat - if self.causal: - att_mat = att_mat.masked_fill(self.causal_mask[:, :, : x.shape[1], : x.shape[1]] == 0, float("-inf")) + if self.causal: + att_mat = att_mat.masked_fill(self.causal_mask[:, :, : x.shape[1], : x.shape[1]] == 0, float("-inf")) - att_mat = att_mat.softmax(dim=-1) + att_mat = att_mat.softmax(dim=-1) - if self.save_attn: - # no gradients and new tensor; - # https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html - self.att_mat = att_mat.detach() + if self.save_attn: + # no gradients and new tensor; + # https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html + self.att_mat = att_mat.detach() - att_mat = self.drop_weights(att_mat) - x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v) + att_mat = self.drop_weights(att_mat) + x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v) x = self.out_rearrange(x) - x = self.out_proj(x) + # x = self.out_proj(x) x = self.drop_output(x) return x + \ No newline at end of file diff --git a/monai/networks/blocks/spatialattention.py b/monai/networks/blocks/spatialattention.py index 75319853d9..4a787e11df 100644 --- a/monai/networks/blocks/spatialattention.py +++ b/monai/networks/blocks/spatialattention.py @@ -15,6 +15,7 @@ import torch import torch.nn as nn +import torch.nn.functional as F from monai.networks.blocks import SABlock from monai.utils import optional_import @@ -44,6 +45,7 @@ def __init__( norm_num_groups: int = 32, norm_eps: float = 1e-6, attention_dtype: Optional[torch.dtype] = None, + use_flash_attention: bool = False, ) -> None: super().__init__() @@ -52,9 +54,9 @@ def __init__( # check num_head_channels is divisible by num_channels if num_head_channels is not None and num_channels % num_head_channels != 0: raise ValueError("num_channels must be divisible by num_head_channels") - num_heads = num_channels // num_head_channels if num_head_channels is not None else 1 + self.num_heads = num_channels // num_head_channels if num_head_channels is not None else 1 self.attn = SABlock( - hidden_size=num_channels, num_heads=num_heads, qkv_bias=True, attention_dtype=attention_dtype + hidden_size=num_channels, num_heads=self.num_heads, qkv_bias=True, attention_dtype=attention_dtype, use_flash_attention=use_flash_attention ) def forward(self, x: torch.Tensor): @@ -80,3 +82,4 @@ def forward(self, x: torch.Tensor): x = rearrange_output(x) # B x x C x x_dim * y_dim * [z_dim] x = x + residual return x + \ No newline at end of file diff --git a/monai/networks/nets/diffusion_model_unet.py b/monai/networks/nets/diffusion_model_unet.py index 8a9ac859a3..cda8c7fcd8 100644 --- 
a/monai/networks/nets/diffusion_model_unet.py +++ b/monai/networks/nets/diffusion_model_unet.py @@ -77,6 +77,7 @@ def __init__( dropout: float = 0.0, cross_attention_dim: int | None = None, upcast_attention: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__() self.attn1 = SABlock( @@ -86,6 +87,7 @@ def __init__( dim_head=num_head_channels, dropout_rate=dropout, attention_dtype=torch.float if upcast_attention else None, + use_flash_attention=use_flash_attention, ) self.ff = MLPBlock(hidden_size=num_channels, mlp_dim=num_channels * 4, act="GEGLU", dropout_rate=dropout) self.attn2 = CrossAttentionBlock( @@ -143,6 +145,7 @@ def __init__( norm_eps: float = 1e-6, cross_attention_dim: int | None = None, upcast_attention: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__() self.spatial_dims = spatial_dims @@ -170,6 +173,7 @@ def __init__( dropout=dropout, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, ) for _ in range(num_layers) ] @@ -539,6 +543,7 @@ def __init__( resblock_updown: bool = False, downsample_padding: int = 1, num_head_channels: int = 1, + use_flash_attention: bool = False, ) -> None: super().__init__() self.resblock_updown = resblock_updown @@ -565,6 +570,7 @@ def __init__( num_head_channels=num_head_channels, norm_num_groups=norm_num_groups, norm_eps=norm_eps, + use_flash_attention=use_flash_attention, ) ) @@ -651,6 +657,7 @@ def __init__( cross_attention_dim: int | None = None, upcast_attention: bool = False, dropout_cattn: float = 0.0, + use_flash_attention: bool = False, ) -> None: super().__init__() self.resblock_updown = resblock_updown @@ -683,6 +690,7 @@ def __init__( cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, dropout=dropout_cattn, + use_flash_attention=use_flash_attention, ) ) @@ -750,6 +758,7 @@ def __init__( norm_num_groups: int = 32, norm_eps: float = 1e-6, num_head_channels: int = 1, + use_flash_attention: bool = False, ) -> None: super().__init__() @@ -767,6 +776,7 @@ def __init__( num_head_channels=num_head_channels, norm_num_groups=norm_num_groups, norm_eps=norm_eps, + use_flash_attention=use_flash_attention, ) self.resnet_2 = DiffusionUNetResnetBlock( @@ -817,6 +827,7 @@ def __init__( cross_attention_dim: int | None = None, upcast_attention: bool = False, dropout_cattn: float = 0.0, + use_flash_attention: bool = False, ) -> None: super().__init__() @@ -839,6 +850,7 @@ def __init__( cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, dropout=dropout_cattn, + use_flash_attention=use_flash_attention, ) self.resnet_2 = DiffusionUNetResnetBlock( spatial_dims=spatial_dims, @@ -999,6 +1011,7 @@ def __init__( add_upsample: bool = True, resblock_updown: bool = False, num_head_channels: int = 1, + use_flash_attention: bool = False, ) -> None: super().__init__() self.resblock_updown = resblock_updown @@ -1027,6 +1040,7 @@ def __init__( num_head_channels=num_head_channels, norm_num_groups=norm_num_groups, norm_eps=norm_eps, + use_flash_attention=use_flash_attention, ) ) @@ -1131,6 +1145,7 @@ def __init__( cross_attention_dim: int | None = None, upcast_attention: bool = False, dropout_cattn: float = 0.0, + use_flash_attention: bool = False, ) -> None: super().__init__() self.resblock_updown = resblock_updown @@ -1164,6 +1179,7 @@ def __init__( cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, dropout=dropout_cattn, + use_flash_attention=use_flash_attention, ) ) @@ 
-1245,6 +1261,7 @@ def get_down_block( cross_attention_dim: int | None, upcast_attention: bool = False, dropout_cattn: float = 0.0, + use_flash_attention: bool = False, ) -> nn.Module: if with_attn: return AttnDownBlock( @@ -1258,6 +1275,7 @@ def get_down_block( add_downsample=add_downsample, resblock_updown=resblock_updown, num_head_channels=num_head_channels, + use_flash_attention=use_flash_attention, ) elif with_cross_attn: return CrossAttnDownBlock( @@ -1275,6 +1293,7 @@ def get_down_block( cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, dropout_cattn=dropout_cattn, + use_flash_attention=use_flash_attention, ) else: return DownBlock( @@ -1302,6 +1321,7 @@ def get_mid_block( cross_attention_dim: int | None, upcast_attention: bool = False, dropout_cattn: float = 0.0, + use_flash_attention: bool = False, ) -> nn.Module: if with_conditioning: return CrossAttnMidBlock( @@ -1315,6 +1335,7 @@ def get_mid_block( cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, dropout_cattn=dropout_cattn, + use_flash_attention=use_flash_attention, ) else: return AttnMidBlock( @@ -1324,6 +1345,7 @@ def get_mid_block( norm_num_groups=norm_num_groups, norm_eps=norm_eps, num_head_channels=num_head_channels, + use_flash_attention=use_flash_attention, ) @@ -1345,6 +1367,7 @@ def get_up_block( cross_attention_dim: int | None, upcast_attention: bool = False, dropout_cattn: float = 0.0, + use_flash_attention: bool = False, ) -> nn.Module: if with_attn: return AttnUpBlock( @@ -1359,6 +1382,7 @@ def get_up_block( add_upsample=add_upsample, resblock_updown=resblock_updown, num_head_channels=num_head_channels, + use_flash_attention=use_flash_attention, ) elif with_cross_attn: return CrossAttnUpBlock( @@ -1377,6 +1401,7 @@ def get_up_block( cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, dropout_cattn=dropout_cattn, + use_flash_attention=use_flash_attention, ) else: return UpBlock( @@ -1911,3 +1936,4 @@ def forward( output: torch.Tensor = self.out(h) return output + \ No newline at end of file From ac12c04a41b2b8a1e6d074feee17922ef18eb338 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 6 Aug 2024 02:21:35 +0000 Subject: [PATCH 3/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../generation/maisi/networks/diffusion_model_unet_maisi.py | 4 +--- monai/networks/blocks/crossattention.py | 1 - monai/networks/blocks/selfattention.py | 1 - monai/networks/blocks/spatialattention.py | 2 -- monai/networks/nets/diffusion_model_unet.py | 1 - 5 files changed, 1 insertion(+), 8 deletions(-) diff --git a/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py b/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py index 489c3dcf53..9730c7359b 100644 --- a/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py +++ b/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py @@ -37,7 +37,7 @@ from torch import nn from monai.networks.blocks import Convolution -from monai.utils import ensure_tuple_rep, optional_import +from monai.utils import ensure_tuple_rep from monai.utils.type_conversion import convert_to_tensor from monai.networks.nets.diffusion_model_unet import get_down_block, get_mid_block, get_up_block, get_timestep_embedding, zero_module @@ -393,5 +393,3 @@ def forward( h = self.out(h) h_tensor: torch.Tensor = convert_to_tensor(h) return h_tensor - - diff --git 
a/monai/networks/blocks/crossattention.py b/monai/networks/blocks/crossattention.py index 8c9f6f47f9..7c4621ca82 100644 --- a/monai/networks/blocks/crossattention.py +++ b/monai/networks/blocks/crossattention.py @@ -172,4 +172,3 @@ def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None): # x = self.out_proj(x) x = self.drop_output(x) return x - \ No newline at end of file diff --git a/monai/networks/blocks/selfattention.py b/monai/networks/blocks/selfattention.py index c63efbe713..511542c12e 100644 --- a/monai/networks/blocks/selfattention.py +++ b/monai/networks/blocks/selfattention.py @@ -164,4 +164,3 @@ def forward(self, x): # x = self.out_proj(x) x = self.drop_output(x) return x - \ No newline at end of file diff --git a/monai/networks/blocks/spatialattention.py b/monai/networks/blocks/spatialattention.py index 4a787e11df..7ab1816d8b 100644 --- a/monai/networks/blocks/spatialattention.py +++ b/monai/networks/blocks/spatialattention.py @@ -15,7 +15,6 @@ import torch import torch.nn as nn -import torch.nn.functional as F from monai.networks.blocks import SABlock from monai.utils import optional_import @@ -82,4 +81,3 @@ def forward(self, x: torch.Tensor): x = rearrange_output(x) # B x x C x x_dim * y_dim * [z_dim] x = x + residual return x - \ No newline at end of file diff --git a/monai/networks/nets/diffusion_model_unet.py b/monai/networks/nets/diffusion_model_unet.py index cda8c7fcd8..8e0c8dfdc5 100644 --- a/monai/networks/nets/diffusion_model_unet.py +++ b/monai/networks/nets/diffusion_model_unet.py @@ -1936,4 +1936,3 @@ def forward( output: torch.Tensor = self.out(h) return output - \ No newline at end of file From 70690cdf46223a2a4f20db31e6b3c89287caf3d5 Mon Sep 17 00:00:00 2001 From: Pengfei Guo Date: Fri, 9 Aug 2024 15:28:04 +0000 Subject: [PATCH 4/7] update Signed-off-by: Pengfei Guo --- monai/networks/blocks/crossattention.py | 58 +++++--- monai/networks/blocks/selfattention.py | 81 +++++++++--- monai/networks/blocks/spatialattention.py | 18 ++- monai/networks/nets/diffusion_model_unet.py | 138 +++++++++++++++++--- 4 files changed, 231 insertions(+), 64 deletions(-) diff --git a/monai/networks/blocks/crossattention.py b/monai/networks/blocks/crossattention.py index 7c4621ca82..bdecf63168 100644 --- a/monai/networks/blocks/crossattention.py +++ b/monai/networks/blocks/crossattention.py @@ -15,10 +15,9 @@ import torch import torch.nn as nn -import torch.nn.functional as F from monai.networks.layers.utils import get_rel_pos_embedding_layer -from monai.utils import optional_import +from monai.utils import optional_import, pytorch_after Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange") @@ -57,13 +56,15 @@ def __init__( dim_head (int, optional): dimension of each head. Defaults to hidden_size // num_heads. qkv_bias (bool, optional): bias term for the qkv linear layer. Defaults to False. save_attn (bool, optional): to make accessible the attention matrix. Defaults to False. - causal: whether to use causal attention. - sequence_length: if causal is True, it is necessary to specify the sequence length. - rel_pos_embedding (str, optional): Add relative positional embeddings to the attention map. - For now only "decomposed" is supported (see https://arxiv.org/abs/2112.01526). 2D and 3D are supported. - input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative - positional parameter size. + causal (bool, optional): whether to use causal attention. 
+ sequence_length (int, optional): if causal is True, it is necessary to specify the sequence length. + rel_pos_embedding (str, optional): Add relative positional embeddings to the attention map. For now only + "decomposed" is supported (see https://arxiv.org/abs/2112.01526). 2D and 3D are supported. + input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative positional + parameter size. attention_dtype: cast attention operations to this dtype. + use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). """ super().__init__() @@ -83,6 +84,20 @@ def __init__( if causal and sequence_length is None: raise ValueError("sequence_length is necessary for causal attention.") + if use_flash_attention and not pytorch_after(minor=13, major=1, patch=0): + raise ValueError( + "use_flash_attention is only supported for PyTorch versions >= 2.0." + "Upgrade your PyTorch or set the flag to False." + ) + if use_flash_attention and save_attn: + raise ValueError( + "save_attn has been set to True, but use_flash_attention is also set" + "to True. save_attn can only be used if use_flash_attention is False" + ) + + if use_flash_attention and rel_pos_embedding is not None: + raise ValueError("rel_pos_embedding must be None if you are using flash_attention.") + self.num_heads = num_heads self.hidden_input_size = hidden_input_size if hidden_input_size else hidden_size self.context_input_size = context_input_size if context_input_size else hidden_size @@ -93,9 +108,10 @@ def __init__( self.to_v = nn.Linear(self.context_input_size, inner_size, bias=qkv_bias) self.input_rearrange = Rearrange("b h (l d) -> b l h d", l=num_heads) - self.out_rearrange = Rearrange("b h l d -> b l (h d)") + self.out_rearrange = Rearrange("b l h d -> b h (l d)") self.drop_output = nn.Dropout(dropout_rate) self.drop_weights = nn.Dropout(dropout_rate) + self.dropout_rate = dropout_rate self.scale = self.head_dim**-0.5 self.save_attn = save_attn @@ -103,6 +119,7 @@ def __init__( self.causal = causal self.sequence_length = sequence_length + self.use_flash_attention = use_flash_attention if causal and sequence_length is not None: # causal mask to ensure that attention is only applied to the left in the input sequence @@ -121,7 +138,6 @@ def __init__( else None ) self.input_size = input_size - self.use_flash_attention = use_flash_attention def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None): """ @@ -135,26 +151,25 @@ def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None): # calculate query, key, values for all heads in batch and move head forward to be the batch dim b, t, c = x.size() # batch size, sequence length, embedding dimensionality (hidden_size) - q = self.to_q(x) + q = self.input_rearrange(self.to_q(x)) kv = context if context is not None else x _, kv_t, _ = kv.size() - k = self.to_k(kv) - v = self.to_v(kv) + k = self.input_rearrange(self.to_k(kv)) + v = self.input_rearrange(self.to_v(kv)) if self.attention_dtype is not None: q = q.to(self.attention_dtype) k = k.to(self.attention_dtype) - q = q.view(b, t, self.num_heads, c // self.num_heads).transpose(1, 2) # (b, nh, t, hs) - k = k.view(b, kv_t, self.num_heads, c // self.num_heads).transpose(1, 2) # (b, nh, kv_t, hs) - v = v.view(b, kv_t, self.num_heads, c // self.num_heads).transpose(1, 2) # (b, nh, kv_t, hs) - att_mat = torch.einsum("blxd,blyd->blxy", q, k) * 
self.scale - if self.use_flash_attention: - x = F.scaled_dot_product_attention(q, k, v) + x = torch.nn.functional.scaled_dot_product_attention( + query=q, key=k, value=v, scale=self.scale, dropout_p=self.dropout_rate, is_causal=self.causal + ) else: + att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale # apply relative positional embedding if defined - att_mat = self.rel_positional_embedding(x, att_mat, q) if self.rel_positional_embedding is not None else att_mat + if self.rel_positional_embedding is not None: + att_mat = self.rel_positional_embedding(x, att_mat, q) if self.causal: att_mat = att_mat.masked_fill(self.causal_mask[:, :, :t, :kv_t] == 0, float("-inf")) @@ -168,7 +183,8 @@ def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None): att_mat = self.drop_weights(att_mat) x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v) + x = self.out_rearrange(x) - # x = self.out_proj(x) + x = self.out_proj(x) x = self.drop_output(x) return x diff --git a/monai/networks/blocks/selfattention.py b/monai/networks/blocks/selfattention.py index 511542c12e..ac96b077bd 100644 --- a/monai/networks/blocks/selfattention.py +++ b/monai/networks/blocks/selfattention.py @@ -11,14 +11,14 @@ from __future__ import annotations -from typing import Optional, Tuple +from typing import Tuple, Union import torch import torch.nn as nn import torch.nn.functional as F from monai.networks.layers.utils import get_rel_pos_embedding_layer -from monai.utils import optional_import +from monai.utils import optional_import, pytorch_after Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange") @@ -40,9 +40,11 @@ def __init__( hidden_input_size: int | None = None, causal: bool = False, sequence_length: int | None = None, - rel_pos_embedding: Optional[str] = None, - input_size: Optional[Tuple] = None, - attention_dtype: Optional[torch.dtype] = None, + rel_pos_embedding: str | None = None, + input_size: Tuple | None = None, + attention_dtype: torch.dtype | None = None, + include_fc: bool = True, + use_combined_linear: bool = True, use_flash_attention: bool = False, ) -> None: """ @@ -61,6 +63,10 @@ def __init__( input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative positional parameter size. attention_dtype: cast attention operations to this dtype. + include_fc: whether to include the final linear layer. Default to True. + use_combined_linear: whether to use a single linear layer for qkv projection, default to True. + use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). """ @@ -84,25 +90,51 @@ def __init__( if causal and sequence_length is None: raise ValueError("sequence_length is necessary for causal attention.") + if use_flash_attention and not pytorch_after(minor=13, major=1, patch=0): + raise ValueError( + "use_flash_attention is only supported for PyTorch versions >= 2.0." + "Upgrade your PyTorch or set the flag to False." + ) + if use_flash_attention and save_attn: + raise ValueError( + "save_attn has been set to True, but use_flash_attention is also set" + "to True. save_attn can only be used if use_flash_attention is False." 
+ ) + + if use_flash_attention and rel_pos_embedding is not None: + raise ValueError("rel_pos_embedding must be None if you are using flash_attention.") + self.num_heads = num_heads self.hidden_input_size = hidden_input_size if hidden_input_size else hidden_size self.out_proj = nn.Linear(self.inner_dim, self.hidden_input_size) - # self.qkv = nn.Linear(self.hidden_input_size, self.inner_dim * 3, bias=qkv_bias) - self.to_q = nn.Linear(self.hidden_input_size, self.inner_dim, bias=qkv_bias) - self.to_k = nn.Linear(self.hidden_input_size, self.inner_dim, bias=qkv_bias) - self.to_v = nn.Linear(self.hidden_input_size, self.inner_dim, bias=qkv_bias) - # self.input_rearrange = Rearrange("b h (qkv l d) -> qkv b l h d", qkv=3, l=num_heads) - self.input_rearrange = Rearrange("b h (c l d) -> (b l) c h d", l=num_heads, c=1) - self.out_rearrange = Rearrange("(b l) h c d -> b h (c l d)", l=num_heads, c=1) + self.qkv: Union[nn.Linear, nn.Identity] + self.to_q: Union[nn.Linear, nn.Identity] + self.to_k: Union[nn.Linear, nn.Identity] + self.to_v: Union[nn.Linear, nn.Identity] + + if use_combined_linear: + self.qkv = nn.Linear(self.hidden_input_size, self.inner_dim * 3, bias=qkv_bias) + self.to_q = self.to_k = self.to_v = nn.Identity() # add to enable torchscript + self.input_rearrange = Rearrange("b h (qkv l d) -> qkv b l h d", qkv=3, l=num_heads) + else: + self.to_q = nn.Linear(self.hidden_input_size, self.inner_dim, bias=qkv_bias) + self.to_k = nn.Linear(self.hidden_input_size, self.inner_dim, bias=qkv_bias) + self.to_v = nn.Linear(self.hidden_input_size, self.inner_dim, bias=qkv_bias) + self.qkv = nn.Identity() # add to enable torchscript + self.input_rearrange = Rearrange("b h (l d) -> b l h d", l=num_heads) + self.out_rearrange = Rearrange("b l h d -> b h (l d)") self.drop_output = nn.Dropout(dropout_rate) self.drop_weights = nn.Dropout(dropout_rate) + self.dropout_rate = dropout_rate self.scale = self.dim_head**-0.5 self.save_attn = save_attn self.att_mat = torch.Tensor() self.attention_dtype = attention_dtype self.causal = causal self.sequence_length = sequence_length + self.include_fc = include_fc + self.use_combined_linear = use_combined_linear self.use_flash_attention = use_flash_attention if causal and sequence_length is not None: @@ -130,26 +162,31 @@ def forward(self, x): Return: torch.Tensor: B x (s_dim_1 * ... 
* s_dim_n) x C """ - # output = self.input_rearrange(self.qkv(x)) - q = self.input_rearrange(self.to_q(x)) - k = self.input_rearrange(self.to_k(x)) - v = self.input_rearrange(self.to_v(x)) - # q, k, v = output[0], output[1], output[2] + if self.use_combined_linear: + output = self.input_rearrange(self.qkv(x)) + q, k, v = output[0], output[1], output[2] + else: + q = self.input_rearrange(self.to_q(x)) + k = self.input_rearrange(self.to_k(x)) + v = self.input_rearrange(self.to_v(x)) if self.attention_dtype is not None: q = q.to(self.attention_dtype) k = k.to(self.attention_dtype) if self.use_flash_attention: - x = F.scaled_dot_product_attention(q, k, v).transpose(1,2) + x = F.scaled_dot_product_attention( + query=q, key=k, value=v, scale=self.scale, dropout_p=self.dropout_rate, is_causal=self.causal + ) else: att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale # apply relative positional embedding if defined - att_mat = self.rel_positional_embedding(x, att_mat, q) if self.rel_positional_embedding is not None else att_mat + if self.rel_positional_embedding is not None: + att_mat = self.rel_positional_embedding(x, att_mat, q) if self.causal: - att_mat = att_mat.masked_fill(self.causal_mask[:, :, : x.shape[1], : x.shape[1]] == 0, float("-inf")) + att_mat = att_mat.masked_fill(self.causal_mask[:, :, : x.shape[-2], : x.shape[-2]] == 0, float("-inf")) att_mat = att_mat.softmax(dim=-1) @@ -160,7 +197,9 @@ def forward(self, x): att_mat = self.drop_weights(att_mat) x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v) + x = self.out_rearrange(x) - # x = self.out_proj(x) + if self.include_fc: + x = self.out_proj(x) x = self.drop_output(x) return x diff --git a/monai/networks/blocks/spatialattention.py b/monai/networks/blocks/spatialattention.py index 7ab1816d8b..665442b55e 100644 --- a/monai/networks/blocks/spatialattention.py +++ b/monai/networks/blocks/spatialattention.py @@ -32,7 +32,13 @@ class SpatialAttentionBlock(nn.Module): spatial_dims: number of spatial dimensions, could be 1, 2, or 3. num_channels: number of input channels. Must be divisible by num_head_channels. num_head_channels: number of channels per head. + norm_num_groups: Number of groups for the group norm layer. + norm_eps: Epsilon for the normalization. attention_dtype: cast attention operations to this dtype. + include_fc: whether to include the final linear layer. Default to True. + use_combined_linear: whether to use a single linear layer for qkv projection, default to False. + use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). 
""" @@ -44,6 +50,8 @@ def __init__( norm_num_groups: int = 32, norm_eps: float = 1e-6, attention_dtype: Optional[torch.dtype] = None, + include_fc: bool = True, + use_combined_linear: bool = False, use_flash_attention: bool = False, ) -> None: super().__init__() @@ -53,9 +61,15 @@ def __init__( # check num_head_channels is divisible by num_channels if num_head_channels is not None and num_channels % num_head_channels != 0: raise ValueError("num_channels must be divisible by num_head_channels") - self.num_heads = num_channels // num_head_channels if num_head_channels is not None else 1 + num_heads = num_channels // num_head_channels if num_head_channels is not None else 1 self.attn = SABlock( - hidden_size=num_channels, num_heads=self.num_heads, qkv_bias=True, attention_dtype=attention_dtype, use_flash_attention=use_flash_attention + hidden_size=num_channels, + num_heads=num_heads, + qkv_bias=True, + attention_dtype=attention_dtype, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) def forward(self, x: torch.Tensor): diff --git a/monai/networks/nets/diffusion_model_unet.py b/monai/networks/nets/diffusion_model_unet.py index 8e0c8dfdc5..f57fe251d2 100644 --- a/monai/networks/nets/diffusion_model_unet.py +++ b/monai/networks/nets/diffusion_model_unet.py @@ -66,6 +66,10 @@ class DiffusionUNetTransformerBlock(nn.Module): dropout: dropout probability to use. cross_attention_dim: size of the context vector for cross attention. upcast_attention: if True, upcast attention operations to full precision. + use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). + include_fc: whether to include the final linear layer. Default to True. + use_combined_linear: whether to use a single linear layer for qkv projection, default to False. """ @@ -78,6 +82,8 @@ def __init__( cross_attention_dim: int | None = None, upcast_attention: bool = False, use_flash_attention: bool = False, + include_fc: bool = True, + use_combined_linear: bool = False, ) -> None: super().__init__() self.attn1 = SABlock( @@ -87,6 +93,8 @@ def __init__( dim_head=num_head_channels, dropout_rate=dropout, attention_dtype=torch.float if upcast_attention else None, + include_fc=include_fc, + use_combined_linear=use_combined_linear, use_flash_attention=use_flash_attention, ) self.ff = MLPBlock(hidden_size=num_channels, mlp_dim=num_channels * 4, act="GEGLU", dropout_rate=dropout) @@ -98,6 +106,7 @@ def __init__( dim_head=num_head_channels, dropout_rate=dropout, attention_dtype=torch.float if upcast_attention else None, + use_flash_attention=use_flash_attention, ) self.norm1 = nn.LayerNorm(num_channels) self.norm2 = nn.LayerNorm(num_channels) @@ -131,6 +140,11 @@ class SpatialTransformer(nn.Module): norm_eps: epsilon for the normalization. cross_attention_dim: number of context dimensions to use. upcast_attention: if True, upcast attention operations to full precision. + include_fc: whether to include the final linear layer. Default to True. + use_combined_linear: whether to use a single linear layer for qkv projection, default to False. + use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). 
+ """ def __init__( @@ -145,6 +159,8 @@ def __init__( norm_eps: float = 1e-6, cross_attention_dim: int | None = None, upcast_attention: bool = False, + include_fc: bool = True, + use_combined_linear: bool = False, use_flash_attention: bool = False, ) -> None: super().__init__() @@ -173,6 +189,8 @@ def __init__( dropout=dropout, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, + include_fc=include_fc, + use_combined_linear=use_combined_linear, use_flash_attention=use_flash_attention, ) for _ in range(num_layers) @@ -528,6 +546,10 @@ class AttnDownBlock(nn.Module): resblock_updown: if True use residual blocks for downsampling. downsample_padding: padding used in the downsampling block. num_head_channels: number of channels in each attention head. + include_fc: whether to include the final linear layer. Default to True. + use_combined_linear: whether to use a single linear layer for qkv projection, default to False. + use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). """ def __init__( @@ -543,6 +565,8 @@ def __init__( resblock_updown: bool = False, downsample_padding: int = 1, num_head_channels: int = 1, + include_fc: bool = True, + use_combined_linear: bool = False, use_flash_attention: bool = False, ) -> None: super().__init__() @@ -570,6 +594,8 @@ def __init__( num_head_channels=num_head_channels, norm_num_groups=norm_num_groups, norm_eps=norm_eps, + include_fc=include_fc, + use_combined_linear=use_combined_linear, use_flash_attention=use_flash_attention, ) ) @@ -637,7 +663,11 @@ class CrossAttnDownBlock(nn.Module): transformer_num_layers: number of layers of Transformer blocks to use. cross_attention_dim: number of context dimensions to use. upcast_attention: if True, upcast attention operations to full precision. - dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers + dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers. + include_fc: whether to include the final linear layer. Default to True. + use_combined_linear: whether to use a single linear layer for qkv projection, default to False. + use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). """ def __init__( @@ -657,6 +687,8 @@ def __init__( cross_attention_dim: int | None = None, upcast_attention: bool = False, dropout_cattn: float = 0.0, + include_fc: bool = True, + use_combined_linear: bool = False, use_flash_attention: bool = False, ) -> None: super().__init__() @@ -690,6 +722,8 @@ def __init__( cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, dropout=dropout_cattn, + include_fc=include_fc, + use_combined_linear=use_combined_linear, use_flash_attention=use_flash_attention, ) ) @@ -748,6 +782,10 @@ class AttnMidBlock(nn.Module): norm_num_groups: number of groups for the group normalization. norm_eps: epsilon for the group normalization. num_head_channels: number of channels in each attention head. + include_fc: whether to include the final linear layer. Default to True. + use_combined_linear: whether to use a single linear layer for qkv projection, default to False. 
+ use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). """ def __init__( @@ -758,6 +796,8 @@ def __init__( norm_num_groups: int = 32, norm_eps: float = 1e-6, num_head_channels: int = 1, + include_fc: bool = True, + use_combined_linear: bool = False, use_flash_attention: bool = False, ) -> None: super().__init__() @@ -776,6 +816,8 @@ def __init__( num_head_channels=num_head_channels, norm_num_groups=norm_num_groups, norm_eps=norm_eps, + include_fc=include_fc, + use_combined_linear=use_combined_linear, use_flash_attention=use_flash_attention, ) @@ -813,6 +855,10 @@ class CrossAttnMidBlock(nn.Module): transformer_num_layers: number of layers of Transformer blocks to use. cross_attention_dim: number of context dimensions to use. upcast_attention: if True, upcast attention operations to full precision. + include_fc: whether to include the final linear layer. Default to True. + use_combined_linear: whether to use a single linear layer for qkv projection, default to False. + use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). """ def __init__( @@ -827,6 +873,8 @@ def __init__( cross_attention_dim: int | None = None, upcast_attention: bool = False, dropout_cattn: float = 0.0, + include_fc: bool = True, + use_combined_linear: bool = False, use_flash_attention: bool = False, ) -> None: super().__init__() @@ -850,6 +898,8 @@ def __init__( cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, dropout=dropout_cattn, + include_fc=include_fc, + use_combined_linear=use_combined_linear, use_flash_attention=use_flash_attention, ) self.resnet_2 = DiffusionUNetResnetBlock( @@ -996,6 +1046,10 @@ class AttnUpBlock(nn.Module): add_upsample: if True add downsample block. resblock_updown: if True use residual blocks for upsampling. num_head_channels: number of channels in each attention head. + include_fc: whether to include the final linear layer. Default to True. + use_combined_linear: whether to use a single linear layer for qkv projection, default to False. + use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). """ def __init__( @@ -1011,6 +1065,8 @@ def __init__( add_upsample: bool = True, resblock_updown: bool = False, num_head_channels: int = 1, + include_fc: bool = True, + use_combined_linear: bool = False, use_flash_attention: bool = False, ) -> None: super().__init__() @@ -1040,6 +1096,8 @@ def __init__( num_head_channels=num_head_channels, norm_num_groups=norm_num_groups, norm_eps=norm_eps, + include_fc=include_fc, + use_combined_linear=use_combined_linear, use_flash_attention=use_flash_attention, ) ) @@ -1125,7 +1183,11 @@ class CrossAttnUpBlock(nn.Module): transformer_num_layers: number of layers of Transformer blocks to use. cross_attention_dim: number of context dimensions to use. upcast_attention: if True, upcast attention operations to full precision. - dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers + dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers. 
+ include_fc: whether to include the final linear layer. Default to True. + use_combined_linear: whether to use a single linear layer for qkv projection, default to False. + use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). """ def __init__( @@ -1145,6 +1207,8 @@ def __init__( cross_attention_dim: int | None = None, upcast_attention: bool = False, dropout_cattn: float = 0.0, + include_fc: bool = True, + use_combined_linear: bool = False, use_flash_attention: bool = False, ) -> None: super().__init__() @@ -1179,6 +1243,8 @@ def __init__( cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, dropout=dropout_cattn, + include_fc=include_fc, + use_combined_linear=use_combined_linear, use_flash_attention=use_flash_attention, ) ) @@ -1261,6 +1327,8 @@ def get_down_block( cross_attention_dim: int | None, upcast_attention: bool = False, dropout_cattn: float = 0.0, + include_fc: bool = True, + use_combined_linear: bool = False, use_flash_attention: bool = False, ) -> nn.Module: if with_attn: @@ -1275,6 +1343,8 @@ def get_down_block( add_downsample=add_downsample, resblock_updown=resblock_updown, num_head_channels=num_head_channels, + include_fc=include_fc, + use_combined_linear=use_combined_linear, use_flash_attention=use_flash_attention, ) elif with_cross_attn: @@ -1293,6 +1363,8 @@ def get_down_block( cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, dropout_cattn=dropout_cattn, + include_fc=include_fc, + use_combined_linear=use_combined_linear, use_flash_attention=use_flash_attention, ) else: @@ -1321,6 +1393,8 @@ def get_mid_block( cross_attention_dim: int | None, upcast_attention: bool = False, dropout_cattn: float = 0.0, + include_fc: bool = True, + use_combined_linear: bool = False, use_flash_attention: bool = False, ) -> nn.Module: if with_conditioning: @@ -1335,6 +1409,8 @@ def get_mid_block( cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, dropout_cattn=dropout_cattn, + include_fc=include_fc, + use_combined_linear=use_combined_linear, use_flash_attention=use_flash_attention, ) else: @@ -1345,6 +1421,8 @@ def get_mid_block( norm_num_groups=norm_num_groups, norm_eps=norm_eps, num_head_channels=num_head_channels, + include_fc=include_fc, + use_combined_linear=use_combined_linear, use_flash_attention=use_flash_attention, ) @@ -1367,6 +1445,8 @@ def get_up_block( cross_attention_dim: int | None, upcast_attention: bool = False, dropout_cattn: float = 0.0, + include_fc: bool = True, + use_combined_linear: bool = False, use_flash_attention: bool = False, ) -> nn.Module: if with_attn: @@ -1382,6 +1462,8 @@ def get_up_block( add_upsample=add_upsample, resblock_updown=resblock_updown, num_head_channels=num_head_channels, + include_fc=include_fc, + use_combined_linear=use_combined_linear, use_flash_attention=use_flash_attention, ) elif with_cross_attn: @@ -1401,6 +1483,8 @@ def get_up_block( cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, dropout_cattn=dropout_cattn, + include_fc=include_fc, + use_combined_linear=use_combined_linear, use_flash_attention=use_flash_attention, ) else: @@ -1439,9 +1523,13 @@ class DiffusionModelUNet(nn.Module): transformer_num_layers: number of layers of Transformer blocks to use. cross_attention_dim: number of context dimensions to use. 
num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds` - classes. + classes. upcast_attention: if True, upcast attention operations to full precision. - dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers + dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers. + include_fc: whether to include the final linear layer. Default to True. + use_combined_linear: whether to use a single linear layer for qkv projection, default to True. + use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). """ def __init__( @@ -1462,6 +1550,9 @@ def __init__( num_class_embeds: int | None = None, upcast_attention: bool = False, dropout_cattn: float = 0.0, + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__() if with_conditioning is True and cross_attention_dim is None: @@ -1556,6 +1647,9 @@ def __init__( cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, dropout_cattn=dropout_cattn, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) self.down_blocks.append(down_block) @@ -1573,6 +1667,9 @@ def __init__( cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, dropout_cattn=dropout_cattn, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) # up @@ -1607,6 +1704,9 @@ def __init__( cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, dropout_cattn=dropout_cattn, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) self.up_blocks.append(up_block) @@ -1734,31 +1834,23 @@ def load_old_state_dict(self, old_state_dict: dict, verbose=False) -> None: # copy over all matching keys for k in new_state_dict: if k in old_state_dict: - new_state_dict[k] = old_state_dict[k] + new_state_dict[k] = old_state_dict.pop(k) # fix the attention blocks - attention_blocks = [k.replace(".attn1.qkv.weight", "") for k in new_state_dict if "attn1.qkv.weight" in k] + attention_blocks = [k.replace(".out_proj.weight", "") for k in new_state_dict if "out_proj.weight" in k] for block in attention_blocks: - new_state_dict[f"{block}.attn1.qkv.weight"] = torch.cat( - [ - old_state_dict[f"{block}.attn1.to_q.weight"], - old_state_dict[f"{block}.attn1.to_k.weight"], - old_state_dict[f"{block}.attn1.to_v.weight"], - ], - dim=0, - ) - # projection - new_state_dict[f"{block}.attn1.out_proj.weight"] = old_state_dict[f"{block}.attn1.to_out.0.weight"] - new_state_dict[f"{block}.attn1.out_proj.bias"] = old_state_dict[f"{block}.attn1.to_out.0.bias"] + new_state_dict[f"{block}.out_proj.weight"] = old_state_dict.pop(f"{block}.to_out.0.weight") + new_state_dict[f"{block}.out_proj.bias"] = old_state_dict.pop(f"{block}.to_out.0.bias") - new_state_dict[f"{block}.attn2.out_proj.weight"] = old_state_dict[f"{block}.attn2.to_out.0.weight"] - new_state_dict[f"{block}.attn2.out_proj.bias"] = old_state_dict[f"{block}.attn2.to_out.0.bias"] # fix the upsample conv blocks which were renamed postconv for k in new_state_dict: if "postconv" in k: old_name = k.replace("postconv", "conv") - new_state_dict[k] = old_state_dict[old_name] + new_state_dict[k] = 
old_state_dict.pop(old_name) + if verbose: + # print all remaining keys in old_state_dict + print("remaining keys in old_state_dict:", old_state_dict.keys()) self.load_state_dict(new_state_dict) @@ -1802,6 +1894,9 @@ def __init__( cross_attention_dim: int | None = None, num_class_embeds: int | None = None, upcast_attention: bool = False, + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__() if with_conditioning is True and cross_attention_dim is None: @@ -1886,6 +1981,9 @@ def __init__( transformer_num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) self.down_blocks.append(down_block) From 5c726535edc26b99cbffa4390ba6f0c92b59a7e4 Mon Sep 17 00:00:00 2001 From: Pengfei Guo Date: Fri, 9 Aug 2024 22:11:41 +0000 Subject: [PATCH 5/7] update Signed-off-by: Pengfei Guo --- .../networks/diffusion_model_unet_maisi.py | 19 +++++++++++++++++-- tests/test_diffusion_model_unet_maisi.py | 7 +------ 2 files changed, 18 insertions(+), 8 deletions(-) diff --git a/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py b/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py index 9730c7359b..d8c75faae1 100644 --- a/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py +++ b/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py @@ -37,11 +37,16 @@ from torch import nn from monai.networks.blocks import Convolution +from monai.networks.nets.diffusion_model_unet import ( + get_down_block, + get_mid_block, + get_timestep_embedding, + get_up_block, + zero_module, +) from monai.utils import ensure_tuple_rep from monai.utils.type_conversion import convert_to_tensor -from monai.networks.nets.diffusion_model_unet import get_down_block, get_mid_block, get_up_block, get_timestep_embedding, zero_module - __all__ = ["DiffusionModelUNetMaisi"] @@ -67,6 +72,8 @@ class DiffusionModelUNetMaisi(nn.Module): cross_attention_dim: Number of context dimensions to use. num_class_embeds: If specified (as an int), then this model will be class-conditional with `num_class_embeds` classes. upcast_attention: If True, upcast attention operations to full precision. + include_fc: whether to include the final linear layer. Default to True. + use_combined_linear: whether to use a single linear layer for qkv projection, default to True. use_flash_attention: If True, use flash attention for a memory efficient attention mechanism. dropout_cattn: If different from zero, this will be the dropout value for the cross-attention layers. include_top_region_index_input: If True, use top region index input. 
@@ -91,6 +98,8 @@ def __init__( cross_attention_dim: int | None = None, num_class_embeds: int | None = None, upcast_attention: bool = False, + include_fc: bool = True, + use_combined_linear: bool = False, use_flash_attention: bool = False, dropout_cattn: float = 0.0, include_top_region_index_input: bool = False, @@ -212,6 +221,8 @@ def __init__( transformer_num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, + include_fc=include_fc, + use_combined_linear=use_combined_linear, use_flash_attention=use_flash_attention, dropout_cattn=dropout_cattn, ) @@ -230,6 +241,8 @@ def __init__( transformer_num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, + include_fc=include_fc, + use_combined_linear=use_combined_linear, use_flash_attention=use_flash_attention, dropout_cattn=dropout_cattn, ) @@ -265,6 +278,8 @@ def __init__( transformer_num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, + include_fc=include_fc, + use_combined_linear=use_combined_linear, use_flash_attention=use_flash_attention, dropout_cattn=dropout_cattn, ) diff --git a/tests/test_diffusion_model_unet_maisi.py b/tests/test_diffusion_model_unet_maisi.py index 059a4a4ba8..f9384e6d82 100644 --- a/tests/test_diffusion_model_unet_maisi.py +++ b/tests/test_diffusion_model_unet_maisi.py @@ -17,14 +17,11 @@ import torch from parameterized import parameterized +from monai.apps.generation.maisi.networks.diffusion_model_unet_maisi import DiffusionModelUNetMaisi from monai.networks import eval_mode from monai.utils import optional_import _, has_einops = optional_import("einops") -_, has_generative = optional_import("generative") - -if has_generative: - from monai.apps.generation.maisi.networks.diffusion_model_unet_maisi import DiffusionModelUNetMaisi UNCOND_CASES_2D = [ [ @@ -291,7 +288,6 @@ ] -@skipUnless(has_generative, "monai-generative required") class TestDiffusionModelUNetMaisi2D(unittest.TestCase): @parameterized.expand(UNCOND_CASES_2D) @@ -510,7 +506,6 @@ def test_shape_with_additional_inputs(self, input_param): self.assertEqual(result.shape, (1, 1, 16, 16)) -@skipUnless(has_generative, "monai-generative required") class TestDiffusionModelUNetMaisi3D(unittest.TestCase): @parameterized.expand(UNCOND_CASES_3D) From 147ec8499bf83f0fa50691e642044811111a3767 Mon Sep 17 00:00:00 2001 From: Pengfei Guo Date: Sun, 11 Aug 2024 18:57:08 +0000 Subject: [PATCH 6/7] update Signed-off-by: Pengfei Guo --- .../generation/maisi/networks/diffusion_model_unet_maisi.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py b/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py index d8c75faae1..f906267172 100644 --- a/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py +++ b/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py @@ -98,7 +98,7 @@ def __init__( cross_attention_dim: int | None = None, num_class_embeds: int | None = None, upcast_attention: bool = False, - include_fc: bool = True, + include_fc: bool = False, use_combined_linear: bool = False, use_flash_attention: bool = False, dropout_cattn: float = 0.0, From 0cc5a49ad70a1daac3f9a53a456478c48c630390 Mon Sep 17 00:00:00 2001 From: Pengfei Guo Date: Sun, 11 Aug 2024 19:58:43 +0000 Subject: [PATCH 7/7] update Signed-off-by: Pengfei Guo --- .../generation/maisi/networks/diffusion_model_unet_maisi.py | 4 ++-- 
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py b/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py
index f906267172..e990b5fc98 100644
--- a/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py
+++ b/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py
@@ -72,8 +72,8 @@ class DiffusionModelUNetMaisi(nn.Module):
         cross_attention_dim: Number of context dimensions to use.
         num_class_embeds: If specified (as an int), then this model will be class-conditional with `num_class_embeds` classes.
         upcast_attention: If True, upcast attention operations to full precision.
-        include_fc: whether to include the final linear layer. Default to True.
-        use_combined_linear: whether to use a single linear layer for qkv projection, default to True.
+        include_fc: whether to include the final linear layer. Default to False.
+        use_combined_linear: whether to use a single linear layer for qkv projection, default to False.
         use_flash_attention: If True, use flash attention for a memory efficient attention mechanism.
         dropout_cattn: If different from zero, this will be the dropout value for the cross-attention layers.
         include_top_region_index_input: If True, use top region index input.
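# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the patch series: the two SABlock attention
# paths introduced above, assuming a MONAI build containing these patches and
# PyTorch >= 2.1 (the flash path passes `scale=` to
# torch.nn.functional.scaled_dot_product_attention). Sizes are arbitrary.
import torch
from monai.networks.blocks.selfattention import SABlock

x = torch.randn(2, 16, 128)  # (batch, sequence length, hidden_size)

# Separate to_q/to_k/to_v projections routed through scaled_dot_product_attention.
flash_attn = SABlock(
    hidden_size=128, num_heads=8, qkv_bias=True,
    use_combined_linear=False, use_flash_attention=True,
)

# Single combined qkv projection with the original einsum attention and final linear layer.
combined_attn = SABlock(
    hidden_size=128, num_heads=8, qkv_bias=True,
    use_combined_linear=True, include_fc=True,
)

print(flash_attn(x).shape, combined_attn(x).shape)  # both torch.Size([2, 16, 128])
# The same flags are threaded through CrossAttentionBlock, SpatialAttentionBlock,
# DiffusionModelUNet and DiffusionModelUNetMaisi by the patches above.
# ---------------------------------------------------------------------------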