From 37cd5cd37ac7d89e9aedc31370b0e913bcacc8ab Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Tue, 6 Aug 2024 16:23:41 +0800 Subject: [PATCH 01/28] fix #7991 Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/networks/blocks/selfattention.py | 11 +++++++---- tests/test_selfattention.py | 11 +++++++++++ 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/monai/networks/blocks/selfattention.py b/monai/networks/blocks/selfattention.py index 3ab1e1fd10..03f24bf766 100644 --- a/monai/networks/blocks/selfattention.py +++ b/monai/networks/blocks/selfattention.py @@ -39,9 +39,10 @@ def __init__( hidden_input_size: int | None = None, causal: bool = False, sequence_length: int | None = None, - rel_pos_embedding: Optional[str] = None, - input_size: Optional[Tuple] = None, - attention_dtype: Optional[torch.dtype] = None, + rel_pos_embedding: str | None = None, + input_size: Tuple | None = None, + attention_dtype: torch.dtype | None = None, + include_fc: bool = True, ) -> None: """ Args: @@ -97,6 +98,7 @@ def __init__( self.attention_dtype = attention_dtype self.causal = causal self.sequence_length = sequence_length + self.include_fc = include_fc if causal and sequence_length is not None: # causal mask to ensure that attention is only applied to the left in the input sequence @@ -148,6 +150,7 @@ def forward(self, x): att_mat = self.drop_weights(att_mat) x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v) x = self.out_rearrange(x) - x = self.out_proj(x) + if self.include_fc: + x = self.out_proj(x) x = self.drop_output(x) return x diff --git a/tests/test_selfattention.py b/tests/test_selfattention.py index d069d6aa30..91d2a553d4 100644 --- a/tests/test_selfattention.py +++ b/tests/test_selfattention.py @@ -22,6 +22,7 @@ from monai.networks.blocks.selfattention import SABlock from monai.networks.layers.factories import RelPosEmbedding from monai.utils import optional_import +from tests.utils import test_script_save einops, has_einops = optional_import("einops") @@ -138,6 +139,16 @@ def count_sablock_params(*args, **kwargs): nparams_default_more_heads = count_sablock_params(hidden_size=hidden_size, num_heads=num_heads * 2) self.assertEqual(nparams_default, nparams_default_more_heads) + @skipUnless(has_einops, "Requires einops") + def test_script(self): + for include_fc in [True, False]: + input_param = {'hidden_size': 360, 'num_heads': 4, 'dropout_rate': 0.0, 'rel_pos_embedding': None, 'input_size': (16, 32), "include_fc": include_fc} + net = SABlock(**input_param) + input_shape = (2, 512, 360) + test_data = torch.randn(input_shape) + test_script_save(net, test_data) + + if __name__ == "__main__": unittest.main() From 63ba16d7a0b5cd0426fe1333e8398d279c841e84 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 6 Aug 2024 08:25:34 +0000 Subject: [PATCH 02/28] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/networks/blocks/selfattention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/networks/blocks/selfattention.py b/monai/networks/blocks/selfattention.py index 03f24bf766..13e59236cd 100644 --- a/monai/networks/blocks/selfattention.py +++ b/monai/networks/blocks/selfattention.py @@ -11,7 +11,7 @@ from __future__ import annotations -from typing import Optional, Tuple +from typing import Tuple import torch import torch.nn as nn From 7255a903180017d2cd80ae4a0ff4e39d1912fcf1 Mon Sep 17 
00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Tue, 6 Aug 2024 16:25:45 +0800 Subject: [PATCH 03/28] add docstring Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/networks/blocks/selfattention.py | 1 + 1 file changed, 1 insertion(+) diff --git a/monai/networks/blocks/selfattention.py b/monai/networks/blocks/selfattention.py index 03f24bf766..40321b3e46 100644 --- a/monai/networks/blocks/selfattention.py +++ b/monai/networks/blocks/selfattention.py @@ -60,6 +60,7 @@ def __init__( input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative positional parameter size. attention_dtype: cast attention operations to this dtype. + include_fc: whether to include the final linear layer. Default to True. """ From 0337d458b6986c7140f4d7d3539f4424592f4cf1 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Tue, 6 Aug 2024 17:17:01 +0800 Subject: [PATCH 04/28] fix #7992 Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/networks/blocks/selfattention.py | 26 +++++++++++++++++++++----- 1 file changed, 21 insertions(+), 5 deletions(-) diff --git a/monai/networks/blocks/selfattention.py b/monai/networks/blocks/selfattention.py index 3ab1e1fd10..d9d6124797 100644 --- a/monai/networks/blocks/selfattention.py +++ b/monai/networks/blocks/selfattention.py @@ -42,6 +42,7 @@ def __init__( rel_pos_embedding: Optional[str] = None, input_size: Optional[Tuple] = None, attention_dtype: Optional[torch.dtype] = None, + use_combined_linear: bool = True, ) -> None: """ Args: @@ -59,6 +60,7 @@ def __init__( input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative positional parameter size. attention_dtype: cast attention operations to this dtype. + use_combined_linear: whether to use a single linear layer for qkv projection, default to True. """ @@ -86,9 +88,17 @@ def __init__( self.hidden_input_size = hidden_input_size if hidden_input_size else hidden_size self.out_proj = nn.Linear(self.inner_dim, self.hidden_input_size) - self.qkv = nn.Linear(self.hidden_input_size, self.inner_dim * 3, bias=qkv_bias) - self.input_rearrange = Rearrange("b h (qkv l d) -> qkv b l h d", qkv=3, l=num_heads) - self.out_rearrange = Rearrange("b h l d -> b l (h d)") + if use_combined_linear: + self.qkv = nn.Linear(self.hidden_input_size, self.inner_dim * 3, bias=qkv_bias) + self.to_q = self.to_k = self.to_v = nn.Identity() # add to enable torchscript + self.input_rearrange = Rearrange("b h (qkv l d) -> qkv b l h d", qkv=3, l=num_heads) + else: + self.to_q = nn.Linear(self.hidden_input_size, self.inner_dim, bias=qkv_bias) + self.to_k = nn.Linear(self.hidden_input_size, self.inner_dim, bias=qkv_bias) + self.to_v = nn.Linear(self.hidden_input_size, self.inner_dim, bias=qkv_bias) + self.qkv = nn.Identity() # add to enable torchscript + self.input_rearrange = Rearrange("b h (l d) -> b l h d", l=num_heads) + self.out_rearrange = Rearrange("b l h d -> b h (l d)") self.drop_output = nn.Dropout(dropout_rate) self.drop_weights = nn.Dropout(dropout_rate) self.scale = self.dim_head**-0.5 @@ -97,6 +107,7 @@ def __init__( self.attention_dtype = attention_dtype self.causal = causal self.sequence_length = sequence_length + self.use_combined_linear = use_combined_linear if causal and sequence_length is not None: # causal mask to ensure that attention is only applied to the left in the input sequence @@ -123,8 +134,13 @@ def forward(self, x): Return: torch.Tensor: B x (s_dim_1 * ... 
* s_dim_n) x C """ - output = self.input_rearrange(self.qkv(x)) - q, k, v = output[0], output[1], output[2] + if self.use_combined_linear: + output = self.input_rearrange(self.qkv(x)) + q, k, v = output[0], output[1], output[2] + else: + q = self.input_rearrange(self.to_q(x)) + k = self.input_rearrange(self.to_k(x)) + v = self.input_rearrange(self.to_v(x)) if self.attention_dtype is not None: q = q.to(self.attention_dtype) From 814e61a640d0862e4097c8b50a0b565b08322388 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Tue, 6 Aug 2024 17:22:59 +0800 Subject: [PATCH 05/28] add tests Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- tests/test_selfattention.py | 40 ++++++++++++++++++++++++------------- 1 file changed, 26 insertions(+), 14 deletions(-) diff --git a/tests/test_selfattention.py b/tests/test_selfattention.py index 91d2a553d4..432dfad2f0 100644 --- a/tests/test_selfattention.py +++ b/tests/test_selfattention.py @@ -32,18 +32,22 @@ for num_heads in [4, 6, 8, 12]: for rel_pos_embedding in [None, RelPosEmbedding.DECOMPOSED]: for input_size in [(16, 32), (8, 8, 8)]: - test_case = [ - { - "hidden_size": hidden_size, - "num_heads": num_heads, - "dropout_rate": dropout_rate, - "rel_pos_embedding": rel_pos_embedding, - "input_size": input_size, - }, - (2, 512, hidden_size), - (2, 512, hidden_size), - ] - TEST_CASE_SABLOCK.append(test_case) + for include_fc in [True, False]: + for use_combined_linear in [True, False]: + test_case = [ + { + "hidden_size": hidden_size, + "num_heads": num_heads, + "dropout_rate": dropout_rate, + "rel_pos_embedding": rel_pos_embedding, + "input_size": input_size, + "include_fc": include_fc, + "use_combined_linear": use_combined_linear, + }, + (2, 512, hidden_size), + (2, 512, hidden_size), + ] + TEST_CASE_SABLOCK.append(test_case) class TestResBlock(unittest.TestCase): @@ -142,13 +146,21 @@ def count_sablock_params(*args, **kwargs): @skipUnless(has_einops, "Requires einops") def test_script(self): for include_fc in [True, False]: - input_param = {'hidden_size': 360, 'num_heads': 4, 'dropout_rate': 0.0, 'rel_pos_embedding': None, 'input_size': (16, 32), "include_fc": include_fc} + for use_combined_linear in [True, False]: + input_param = { + "hidden_size": 360, + "num_heads": 4, + "dropout_rate": 0.0, + "rel_pos_embedding": None, + "input_size": (16, 32), + "include_fc": include_fc, + "use_combined_linear": use_combined_linear, + } net = SABlock(**input_param) input_shape = (2, 512, 360) test_data = torch.randn(input_shape) test_script_save(net, test_data) - if __name__ == "__main__": unittest.main() From de9eef0b580e9a7aad7059224dba5792af31f9b9 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Wed, 7 Aug 2024 11:45:56 +0800 Subject: [PATCH 06/28] remove transpose in sablock Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/networks/blocks/selfattention.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/monai/networks/blocks/selfattention.py b/monai/networks/blocks/selfattention.py index 3a8f1aba86..f2165a4cf7 100644 --- a/monai/networks/blocks/selfattention.py +++ b/monai/networks/blocks/selfattention.py @@ -172,13 +172,13 @@ def forward(self, x): if self.use_flash_attention: x = F.scaled_dot_product_attention( - query=q.transpose(1, 2), - key=k.transpose(1, 2), - value=v.transpose(1, 2), + query=q, + key=k, + value=v, scale=self.scale, dropout_p=self.dropout_rate, is_causal=self.causal, - ).transpose(1, 2) + ) 
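            # q, k and v already come out of the input_rearrange above with shape
            # (batch, num_heads, seq_len, head_dim), which is the layout
            # F.scaled_dot_product_attention expects, so the transpose(1, 2) calls
            # removed in this hunk are no longer needed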
else: att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale From 2333351945fdec7f3e6fb8e795affda8710f4410 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Wed, 7 Aug 2024 13:10:15 +0800 Subject: [PATCH 07/28] fix docstring Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/networks/blocks/crossattention.py | 27 +++++++++++-------------- 1 file changed, 12 insertions(+), 15 deletions(-) diff --git a/monai/networks/blocks/crossattention.py b/monai/networks/blocks/crossattention.py index daa5abdd56..8495406ff9 100644 --- a/monai/networks/blocks/crossattention.py +++ b/monai/networks/blocks/crossattention.py @@ -59,13 +59,13 @@ def __init__( causal (bool, optional): whether to use causal attention. sequence_length (int, optional): if causal is True, it is necessary to specify the sequence length. rel_pos_embedding (str, optional): Add relative positional embeddings to the attention map. For now only - "decomposed" is supported (see https://arxiv.org/abs/2112.01526). 2D and 3D are supported. + "decomposed" is supported (see https://arxiv.org/abs/2112.01526). 2D and 3D are supported. input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative positional - parameter size. + parameter size. attention_dtype: cast attention operations to this dtype. use_flash_attention: if True, use Pytorch's inbuilt - flash attention for a memory efficient attention mechanism (see - https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). + flash attention for a memory efficient attention mechanism (see + https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). """ super().__init__() @@ -107,9 +107,8 @@ def __init__( self.to_q = nn.Linear(self.hidden_input_size, inner_size, bias=qkv_bias) self.to_k = nn.Linear(self.context_input_size, inner_size, bias=qkv_bias) self.to_v = nn.Linear(self.context_input_size, inner_size, bias=qkv_bias) - self.input_rearrange = Rearrange("b h (l d) -> b l h d", l=num_heads) - self.out_rearrange = Rearrange("b h l d -> b l (h d)") + self.out_rearrange = Rearrange("b l h d -> b h (l d)") self.drop_output = nn.Dropout(dropout_rate) self.drop_weights = nn.Dropout(dropout_rate) self.dropout_rate = dropout_rate @@ -162,21 +161,19 @@ def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None): q = q.to(self.attention_dtype) k = k.to(self.attention_dtype) - q = q.view(b, t, self.num_heads, c // self.num_heads).transpose(1, 2) # (b, nh, t, hs) # - k = k.view(b, kv_t, self.num_heads, c // self.num_heads).transpose(1, 2) # (b, nh, kv_t, hs) - v = v.view(b, kv_t, self.num_heads, c // self.num_heads).transpose(1, 2) # (b, nh, kv_t, hs) + q = q.view(b, self.num_heads, t, c // self.num_heads) # (b, nh, t, hs) + k = k.view(b, self.num_heads, kv_t, c // self.num_heads) # (b, nh, kv_t, hs) + v = v.view(b, self.num_heads, kv_t, c // self.num_heads) # (b, nh, kv_t, hs) if self.use_flash_attention: x = torch.nn.functional.scaled_dot_product_attention( - query=q.transpose(1, 2), - key=k.transpose(1, 2), - value=v.transpose(1, 2), + query=q, + key=k, + value=v, scale=self.scale, dropout_p=self.dropout_rate, is_causal=self.causal, - ).transpose( - 1, 2 - ) # Back to (b, nh, t, hs) + ) else: att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale # apply relative positional embedding if defined From f9eb6d8cb9b1100dd5df3c67ae6042dbf0c95257 Mon Sep 17 00:00:00 2001 From: YunLiu 
<55491388+KumoLiu@users.noreply.github.com> Date: Wed, 7 Aug 2024 22:16:20 +0800 Subject: [PATCH 08/28] use rearange Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/networks/blocks/crossattention.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/monai/networks/blocks/crossattention.py b/monai/networks/blocks/crossattention.py index 8495406ff9..674b4f0b86 100644 --- a/monai/networks/blocks/crossattention.py +++ b/monai/networks/blocks/crossattention.py @@ -107,6 +107,7 @@ def __init__( self.to_q = nn.Linear(self.hidden_input_size, inner_size, bias=qkv_bias) self.to_k = nn.Linear(self.context_input_size, inner_size, bias=qkv_bias) self.to_v = nn.Linear(self.context_input_size, inner_size, bias=qkv_bias) + self.input_rearrange = Rearrange("b h (l d) -> b l h d", l=num_heads) self.out_rearrange = Rearrange("b l h d -> b h (l d)") self.drop_output = nn.Dropout(dropout_rate) @@ -151,20 +152,16 @@ def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None): # calculate query, key, values for all heads in batch and move head forward to be the batch dim b, t, c = x.size() # batch size, sequence length, embedding dimensionality (hidden_size) - q = self.to_q(x) + q = self.input_rearrange(self.to_q(x)) kv = context if context is not None else x _, kv_t, _ = kv.size() - k = self.to_k(kv) - v = self.to_v(kv) + k = self.input_rearrange(self.to_k(kv)) + v = self.input_rearrange(self.to_v(kv)) if self.attention_dtype is not None: q = q.to(self.attention_dtype) k = k.to(self.attention_dtype) - q = q.view(b, self.num_heads, t, c // self.num_heads) # (b, nh, t, hs) - k = k.view(b, self.num_heads, kv_t, c // self.num_heads) # (b, nh, kv_t, hs) - v = v.view(b, self.num_heads, kv_t, c // self.num_heads) # (b, nh, kv_t, hs) - if self.use_flash_attention: x = torch.nn.functional.scaled_dot_product_attention( query=q, @@ -174,6 +171,8 @@ def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None): dropout_p=self.dropout_rate, is_causal=self.causal, ) + + else: att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale # apply relative positional embedding if defined @@ -192,7 +191,9 @@ def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None): att_mat = self.drop_weights(att_mat) x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v) + x = self.out_rearrange(x) + print("reshape of monai 'b h l d -> b l (h d)':", x.shape) x = self.out_proj(x) x = self.drop_output(x) return x From 5aeccbeeac17333cdd1c4e2892ff5998fad36385 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 7 Aug 2024 14:17:03 +0000 Subject: [PATCH 09/28] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/networks/blocks/crossattention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/networks/blocks/crossattention.py b/monai/networks/blocks/crossattention.py index 674b4f0b86..df2e960edc 100644 --- a/monai/networks/blocks/crossattention.py +++ b/monai/networks/blocks/crossattention.py @@ -191,7 +191,7 @@ def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None): att_mat = self.drop_weights(att_mat) x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v) - + x = self.out_rearrange(x) print("reshape of monai 'b h l d -> b l (h d)':", x.shape) x = self.out_proj(x) From 3154c7cc25b824d9b5d5b1fff9c98badef477a42 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: 
Wed, 7 Aug 2024 22:19:13 +0800 Subject: [PATCH 10/28] minor fix Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/networks/blocks/crossattention.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/monai/networks/blocks/crossattention.py b/monai/networks/blocks/crossattention.py index df2e960edc..b330965d3f 100644 --- a/monai/networks/blocks/crossattention.py +++ b/monai/networks/blocks/crossattention.py @@ -171,8 +171,6 @@ def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None): dropout_p=self.dropout_rate, is_causal=self.causal, ) - - else: att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale # apply relative positional embedding if defined @@ -193,7 +191,6 @@ def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None): x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v) x = self.out_rearrange(x) - print("reshape of monai 'b h l d -> b l (h d)':", x.shape) x = self.out_proj(x) x = self.drop_output(x) return x From 81d3605816b903f9cbd0a55329c4371f032130b7 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Wed, 7 Aug 2024 22:54:43 +0800 Subject: [PATCH 11/28] add in SpatialAttentionBlock Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/networks/blocks/spatialattention.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/monai/networks/blocks/spatialattention.py b/monai/networks/blocks/spatialattention.py index 1cfafb1585..b1b6fc2961 100644 --- a/monai/networks/blocks/spatialattention.py +++ b/monai/networks/blocks/spatialattention.py @@ -46,6 +46,8 @@ def __init__( norm_eps: float = 1e-6, attention_dtype: Optional[torch.dtype] = None, use_flash_attention: bool = False, + include_fc: bool = True, + use_combined_linear: bool = True, ) -> None: super().__init__() @@ -61,6 +63,8 @@ def __init__( qkv_bias=True, attention_dtype=attention_dtype, use_flash_attention=use_flash_attention, + include_fc=include_fc, + use_combined_linear=use_combined_linear, ) def forward(self, x: torch.Tensor): From 754e7f2122624397c762cc80e0fd712454fc6cb6 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Wed, 7 Aug 2024 23:25:47 +0800 Subject: [PATCH 12/28] fix format Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/networks/blocks/crossattention.py | 7 +------ monai/networks/blocks/selfattention.py | 7 +------ tests/test_selfattention.py | 2 +- 3 files changed, 3 insertions(+), 13 deletions(-) diff --git a/monai/networks/blocks/crossattention.py b/monai/networks/blocks/crossattention.py index b330965d3f..42787cc770 100644 --- a/monai/networks/blocks/crossattention.py +++ b/monai/networks/blocks/crossattention.py @@ -164,12 +164,7 @@ def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None): if self.use_flash_attention: x = torch.nn.functional.scaled_dot_product_attention( - query=q, - key=k, - value=v, - scale=self.scale, - dropout_p=self.dropout_rate, - is_causal=self.causal, + query=q, key=k, value=v, scale=self.scale, dropout_p=self.dropout_rate, is_causal=self.causal ) else: att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale diff --git a/monai/networks/blocks/selfattention.py b/monai/networks/blocks/selfattention.py index f2165a4cf7..2adf8243f0 100644 --- a/monai/networks/blocks/selfattention.py +++ b/monai/networks/blocks/selfattention.py @@ -172,12 +172,7 @@ def forward(self, x): if self.use_flash_attention: x = F.scaled_dot_product_attention( - query=q, - key=k, - value=v, - 
scale=self.scale, - dropout_p=self.dropout_rate, - is_causal=self.causal, + query=q, key=k, value=v, scale=self.scale, dropout_p=self.dropout_rate, is_causal=self.causal ) else: att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale diff --git a/tests/test_selfattention.py b/tests/test_selfattention.py index dde7d72a76..cd0c23c0b3 100644 --- a/tests/test_selfattention.py +++ b/tests/test_selfattention.py @@ -22,7 +22,7 @@ from monai.networks.blocks.selfattention import SABlock from monai.networks.layers.factories import RelPosEmbedding from monai.utils import optional_import -from tests.utils import test_script_save, SkipIfBeforePyTorchVersion +from tests.utils import SkipIfBeforePyTorchVersion, test_script_save einops, has_einops = optional_import("einops") From 3cf212496dfb5144f05c9f994d809316cca89b0b Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Thu, 8 Aug 2024 00:04:36 +0800 Subject: [PATCH 13/28] add tests Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/networks/blocks/crossattention.py | 4 + monai/networks/blocks/selfattention.py | 7 +- tests/test_crossattention.py | 241 ++++++++++--------- tests/test_selfattention.py | 303 +++++++++++++----------- 4 files changed, 299 insertions(+), 256 deletions(-) diff --git a/monai/networks/blocks/crossattention.py b/monai/networks/blocks/crossattention.py index 42787cc770..a16893faca 100644 --- a/monai/networks/blocks/crossattention.py +++ b/monai/networks/blocks/crossattention.py @@ -185,6 +185,10 @@ def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None): att_mat = self.drop_weights(att_mat) x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v) + y = torch.nn.functional.scaled_dot_product_attention( + query=q, key=k, value=v, scale=self.scale, dropout_p=self.dropout_rate, is_causal=self.causal + ) + x = self.out_rearrange(x) x = self.out_proj(x) x = self.drop_output(x) diff --git a/monai/networks/blocks/selfattention.py b/monai/networks/blocks/selfattention.py index 2adf8243f0..0742e78e47 100644 --- a/monai/networks/blocks/selfattention.py +++ b/monai/networks/blocks/selfattention.py @@ -182,7 +182,7 @@ def forward(self, x): att_mat = self.rel_positional_embedding(x, att_mat, q) if self.causal: - att_mat = att_mat.masked_fill(self.causal_mask[:, :, : x.shape[-2], : x.shape[-2]] == 0, float("-inf")) + att_mat = att_mat.masked_fill(self.causal_mask[:, :, :x.shape[-2], :x.shape[-2]] == 0, float("-inf")) att_mat = att_mat.softmax(dim=-1) @@ -193,6 +193,11 @@ def forward(self, x): att_mat = self.drop_weights(att_mat) x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v) + + y = torch.nn.functional.scaled_dot_product_attention( + query=q, key=k, value=v, scale=self.scale, dropout_p=self.dropout_rate, is_causal=self.causal + ) + x = self.out_rearrange(x) if self.include_fc: x = self.out_proj(x) diff --git a/tests/test_crossattention.py b/tests/test_crossattention.py index 44458147d6..2af7b21d13 100644 --- a/tests/test_crossattention.py +++ b/tests/test_crossattention.py @@ -22,7 +22,7 @@ from monai.networks.blocks.crossattention import CrossAttentionBlock from monai.networks.layers.factories import RelPosEmbedding from monai.utils import optional_import -from tests.utils import SkipIfBeforePyTorchVersion +from tests.utils import SkipIfBeforePyTorchVersion, assert_allclose einops, has_einops = optional_import("einops") @@ -50,121 +50,136 @@ class TestResBlock(unittest.TestCase): - @parameterized.expand(TEST_CASE_CABLOCK) - @skipUnless(has_einops, "Requires einops") - 
@SkipIfBeforePyTorchVersion((2, 0)) - def test_shape(self, input_param, input_shape, expected_shape): - # Without flash attention - net = CrossAttentionBlock(**input_param) - with eval_mode(net): - result = net(torch.randn(input_shape), context=torch.randn(2, 512, input_param["hidden_size"])) - self.assertEqual(result.shape, expected_shape) - - def test_ill_arg(self): - with self.assertRaises(ValueError): - CrossAttentionBlock(hidden_size=128, num_heads=12, dropout_rate=6.0) - - with self.assertRaises(ValueError): - CrossAttentionBlock(hidden_size=620, num_heads=8, dropout_rate=0.4) - - @SkipIfBeforePyTorchVersion((2, 0)) - def test_save_attn_with_flash_attention(self): - with self.assertRaises(ValueError): - CrossAttentionBlock( - hidden_size=128, num_heads=3, dropout_rate=0.1, use_flash_attention=True, save_attn=True - ) - - @SkipIfBeforePyTorchVersion((2, 0)) - def test_rel_pos_embedding_with_flash_attention(self): - with self.assertRaises(ValueError): - CrossAttentionBlock( - hidden_size=128, - num_heads=3, - dropout_rate=0.1, - use_flash_attention=True, - save_attn=False, - rel_pos_embedding=RelPosEmbedding.DECOMPOSED, - ) - - @skipUnless(has_einops, "Requires einops") - def test_attention_dim_not_multiple_of_heads(self): - with self.assertRaises(ValueError): - CrossAttentionBlock(hidden_size=128, num_heads=3, dropout_rate=0.1) - - @skipUnless(has_einops, "Requires einops") - def test_inner_dim_different(self): - CrossAttentionBlock(hidden_size=128, num_heads=4, dropout_rate=0.1, dim_head=30) - - def test_causal_no_sequence_length(self): - with self.assertRaises(ValueError): - CrossAttentionBlock(hidden_size=128, num_heads=4, dropout_rate=0.1, causal=True) + # @parameterized.expand(TEST_CASE_CABLOCK) + # @skipUnless(has_einops, "Requires einops") + # @SkipIfBeforePyTorchVersion((2, 0)) + # def test_shape(self, input_param, input_shape, expected_shape): + # # Without flash attention + # net = CrossAttentionBlock(**input_param) + # with eval_mode(net): + # result = net(torch.randn(input_shape), context=torch.randn(2, 512, input_param["hidden_size"])) + # self.assertEqual(result.shape, expected_shape) + + # def test_ill_arg(self): + # with self.assertRaises(ValueError): + # CrossAttentionBlock(hidden_size=128, num_heads=12, dropout_rate=6.0) + + # with self.assertRaises(ValueError): + # CrossAttentionBlock(hidden_size=620, num_heads=8, dropout_rate=0.4) + + # @SkipIfBeforePyTorchVersion((2, 0)) + # def test_save_attn_with_flash_attention(self): + # with self.assertRaises(ValueError): + # CrossAttentionBlock( + # hidden_size=128, num_heads=3, dropout_rate=0.1, use_flash_attention=True, save_attn=True + # ) + + # @SkipIfBeforePyTorchVersion((2, 0)) + # def test_rel_pos_embedding_with_flash_attention(self): + # with self.assertRaises(ValueError): + # CrossAttentionBlock( + # hidden_size=128, + # num_heads=3, + # dropout_rate=0.1, + # use_flash_attention=True, + # save_attn=False, + # rel_pos_embedding=RelPosEmbedding.DECOMPOSED, + # ) + + # @skipUnless(has_einops, "Requires einops") + # def test_attention_dim_not_multiple_of_heads(self): + # with self.assertRaises(ValueError): + # CrossAttentionBlock(hidden_size=128, num_heads=3, dropout_rate=0.1) + + # @skipUnless(has_einops, "Requires einops") + # def test_inner_dim_different(self): + # CrossAttentionBlock(hidden_size=128, num_heads=4, dropout_rate=0.1, dim_head=30) + + # def test_causal_no_sequence_length(self): + # with self.assertRaises(ValueError): + # CrossAttentionBlock(hidden_size=128, num_heads=4, dropout_rate=0.1, causal=True) + 
+ # @skipUnless(has_einops, "Requires einops") + # @SkipIfBeforePyTorchVersion((2, 0)) + # def test_causal_flash_attention(self): + # block = CrossAttentionBlock( + # hidden_size=128, + # num_heads=1, + # dropout_rate=0.1, + # causal=True, + # sequence_length=16, + # save_attn=False, + # use_flash_attention=True, + # ) + # input_shape = (1, 16, 128) + # # Check it runs correctly + # block(torch.randn(input_shape)) + + # @skipUnless(has_einops, "Requires einops") + # def test_causal(self): + # block = CrossAttentionBlock( + # hidden_size=128, num_heads=1, dropout_rate=0.1, causal=True, sequence_length=16, save_attn=True + # ) + # input_shape = (1, 16, 128) + # block(torch.randn(input_shape)) + # # check upper triangular part of the attention matrix is zero + # assert torch.triu(block.att_mat, diagonal=1).sum() == 0 + + # @skipUnless(has_einops, "Requires einops") + # def test_context_input(self): + # block = CrossAttentionBlock( + # hidden_size=128, num_heads=1, dropout_rate=0.1, causal=True, sequence_length=16, context_input_size=12 + # ) + # input_shape = (1, 16, 128) + # block(torch.randn(input_shape), context=torch.randn(1, 3, 12)) + + # @skipUnless(has_einops, "Requires einops") + # def test_context_wrong_input_size(self): + # block = CrossAttentionBlock( + # hidden_size=128, num_heads=1, dropout_rate=0.1, causal=True, sequence_length=16, context_input_size=12 + # ) + # input_shape = (1, 16, 128) + # with self.assertRaises(RuntimeError): + # block(torch.randn(input_shape), context=torch.randn(1, 3, 24)) + + # @skipUnless(has_einops, "Requires einops") + # def test_access_attn_matrix(self): + # # input format + # hidden_size = 128 + # num_heads = 2 + # dropout_rate = 0 + # input_shape = (2, 256, hidden_size) + + # # be not able to access the matrix + # no_matrix_acess_blk = CrossAttentionBlock( + # hidden_size=hidden_size, num_heads=num_heads, dropout_rate=dropout_rate + # ) + # no_matrix_acess_blk(torch.randn(input_shape)) + # assert isinstance(no_matrix_acess_blk.att_mat, torch.Tensor) + # # no of elements is zero + # assert no_matrix_acess_blk.att_mat.nelement() == 0 + + # # be able to acess the attention matrix. 
+ # matrix_acess_blk = CrossAttentionBlock( + # hidden_size=hidden_size, num_heads=num_heads, dropout_rate=dropout_rate, save_attn=True + # ) + # matrix_acess_blk(torch.randn(input_shape)) + # assert matrix_acess_blk.att_mat.shape == (input_shape[0], input_shape[0], input_shape[1], input_shape[1]) @skipUnless(has_einops, "Requires einops") @SkipIfBeforePyTorchVersion((2, 0)) - def test_causal_flash_attention(self): - block = CrossAttentionBlock( - hidden_size=128, - num_heads=1, - dropout_rate=0.1, - causal=True, - sequence_length=16, - save_attn=False, - use_flash_attention=True, - ) - input_shape = (1, 16, 128) - # Check it runs correctly - block(torch.randn(input_shape)) - - @skipUnless(has_einops, "Requires einops") - def test_causal(self): - block = CrossAttentionBlock( - hidden_size=128, num_heads=1, dropout_rate=0.1, causal=True, sequence_length=16, save_attn=True - ) - input_shape = (1, 16, 128) - block(torch.randn(input_shape)) - # check upper triangular part of the attention matrix is zero - assert torch.triu(block.att_mat, diagonal=1).sum() == 0 - - @skipUnless(has_einops, "Requires einops") - def test_context_input(self): - block = CrossAttentionBlock( - hidden_size=128, num_heads=1, dropout_rate=0.1, causal=True, sequence_length=16, context_input_size=12 - ) - input_shape = (1, 16, 128) - block(torch.randn(input_shape), context=torch.randn(1, 3, 12)) - - @skipUnless(has_einops, "Requires einops") - def test_context_wrong_input_size(self): - block = CrossAttentionBlock( - hidden_size=128, num_heads=1, dropout_rate=0.1, causal=True, sequence_length=16, context_input_size=12 - ) - input_shape = (1, 16, 128) - with self.assertRaises(RuntimeError): - block(torch.randn(input_shape), context=torch.randn(1, 3, 24)) - - @skipUnless(has_einops, "Requires einops") - def test_access_attn_matrix(self): - # input format - hidden_size = 128 - num_heads = 2 - dropout_rate = 0 - input_shape = (2, 256, hidden_size) - - # be not able to access the matrix - no_matrix_acess_blk = CrossAttentionBlock( - hidden_size=hidden_size, num_heads=num_heads, dropout_rate=dropout_rate - ) - no_matrix_acess_blk(torch.randn(input_shape)) - assert isinstance(no_matrix_acess_blk.att_mat, torch.Tensor) - # no of elements is zero - assert no_matrix_acess_blk.att_mat.nelement() == 0 - - # be able to acess the attention matrix. 
- matrix_acess_blk = CrossAttentionBlock( - hidden_size=hidden_size, num_heads=num_heads, dropout_rate=dropout_rate, save_attn=True - ) - matrix_acess_blk(torch.randn(input_shape)) - assert matrix_acess_blk.att_mat.shape == (input_shape[0], input_shape[0], input_shape[1], input_shape[1]) + def test_flash_attention(self): + for causal in [True, False]: + input_param = {"hidden_size": 128, "num_heads": 1, 'causal': causal, 'sequence_length': 16 if causal else None} + device = "cuda:0" if torch.cuda.is_available() else "cpu" + block_w_flash_attention = CrossAttentionBlock(**input_param, use_flash_attention=True).to(device) + block_wo_flash_attention = CrossAttentionBlock(**input_param, use_flash_attention=False).to(device) + block_wo_flash_attention.load_state_dict(block_w_flash_attention.state_dict()) + test_data = torch.randn(1, 16, 128).to(device) + + out_1 = block_w_flash_attention(test_data) + out_2 = block_wo_flash_attention(test_data) + assert_allclose(out_1, out_2, atol=1e-4) if __name__ == "__main__": diff --git a/tests/test_selfattention.py b/tests/test_selfattention.py index cd0c23c0b3..db3e14efb5 100644 --- a/tests/test_selfattention.py +++ b/tests/test_selfattention.py @@ -22,7 +22,7 @@ from monai.networks.blocks.selfattention import SABlock from monai.networks.layers.factories import RelPosEmbedding from monai.utils import optional_import -from tests.utils import SkipIfBeforePyTorchVersion, test_script_save +from tests.utils import SkipIfBeforePyTorchVersion, test_script_save, assert_allclose einops, has_einops = optional_import("einops") @@ -54,149 +54,168 @@ class TestResBlock(unittest.TestCase): - @parameterized.expand(TEST_CASE_SABLOCK) + # @parameterized.expand(TEST_CASE_SABLOCK) + # @skipUnless(has_einops, "Requires einops") + # @SkipIfBeforePyTorchVersion((2, 0)) + # def test_shape(self, input_param, input_shape, expected_shape): + # net = SABlock(**input_param) + # with eval_mode(net): + # result = net(torch.randn(input_shape)) + # self.assertEqual(result.shape, expected_shape) + + # def test_ill_arg(self): + # with self.assertRaises(ValueError): + # SABlock(hidden_size=128, num_heads=12, dropout_rate=6.0) + + # with self.assertRaises(ValueError): + # SABlock(hidden_size=620, num_heads=8, dropout_rate=0.4) + + # @SkipIfBeforePyTorchVersion((2, 0)) + # def test_rel_pos_embedding_with_flash_attention(self): + # with self.assertRaises(ValueError): + # SABlock( + # hidden_size=128, + # num_heads=3, + # dropout_rate=0.1, + # use_flash_attention=True, + # save_attn=False, + # rel_pos_embedding=RelPosEmbedding.DECOMPOSED, + # ) + + # @SkipIfBeforePyTorchVersion((1, 13)) + # def test_save_attn_with_flash_attention(self): + # with self.assertRaises(ValueError): + # SABlock(hidden_size=128, num_heads=3, dropout_rate=0.1, use_flash_attention=True, save_attn=True) + + # def test_attention_dim_not_multiple_of_heads(self): + # with self.assertRaises(ValueError): + # SABlock(hidden_size=128, num_heads=3, dropout_rate=0.1) + + # @skipUnless(has_einops, "Requires einops") + # def test_inner_dim_different(self): + # SABlock(hidden_size=128, num_heads=4, dropout_rate=0.1, dim_head=30) + + # def test_causal_no_sequence_length(self): + # with self.assertRaises(ValueError): + # SABlock(hidden_size=128, num_heads=4, dropout_rate=0.1, causal=True) + + # @skipUnless(has_einops, "Requires einops") + # @SkipIfBeforePyTorchVersion((2, 0)) + # def test_causal_flash_attention(self): + # block = SABlock( + # hidden_size=128, + # num_heads=1, + # dropout_rate=0.1, + # causal=True, + # 
sequence_length=16, + # save_attn=False, + # use_flash_attention=True, + # ) + # input_shape = (1, 16, 128) + # # Check it runs correctly + # block(torch.randn(input_shape)) + + # @skipUnless(has_einops, "Requires einops") + # def test_causal(self): + # block = SABlock(hidden_size=128, num_heads=1, dropout_rate=0.1, causal=True, sequence_length=16, save_attn=True) + # input_shape = (1, 16, 128) + # block(torch.randn(input_shape)) + # # check upper triangular part of the attention matrix is zero + # assert torch.triu(block.att_mat, diagonal=1).sum() == 0 + + # @skipUnless(has_einops, "Requires einops") + # def test_access_attn_matrix(self): + # # input format + # hidden_size = 128 + # num_heads = 2 + # dropout_rate = 0 + # input_shape = (2, 256, hidden_size) + + # # be not able to access the matrix + # no_matrix_acess_blk = SABlock(hidden_size=hidden_size, num_heads=num_heads, dropout_rate=dropout_rate) + # no_matrix_acess_blk(torch.randn(input_shape)) + # assert isinstance(no_matrix_acess_blk.att_mat, torch.Tensor) + # # no of elements is zero + # assert no_matrix_acess_blk.att_mat.nelement() == 0 + + # # be able to acess the attention matrix + # matrix_acess_blk = SABlock( + # hidden_size=hidden_size, num_heads=num_heads, dropout_rate=dropout_rate, save_attn=True + # ) + # matrix_acess_blk(torch.randn(input_shape)) + # assert matrix_acess_blk.att_mat.shape == (input_shape[0], input_shape[0], input_shape[1], input_shape[1]) + + # def test_number_of_parameters(self): + + # def count_sablock_params(*args, **kwargs): + # """Count the number of parameters in a SABlock.""" + # sablock = SABlock(*args, **kwargs) + # return sum([x.numel() for x in sablock.parameters() if x.requires_grad]) + + # hidden_size = 128 + # num_heads = 8 + # default_dim_head = hidden_size // num_heads + + # # Default dim_head is hidden_size // num_heads + # nparams_default = count_sablock_params(hidden_size=hidden_size, num_heads=num_heads) + # nparams_like_default = count_sablock_params( + # hidden_size=hidden_size, num_heads=num_heads, dim_head=default_dim_head + # ) + # self.assertEqual(nparams_default, nparams_like_default) + + # # Increasing dim_head should increase the number of parameters + # nparams_custom_large = count_sablock_params( + # hidden_size=hidden_size, num_heads=num_heads, dim_head=default_dim_head * 2 + # ) + # self.assertGreater(nparams_custom_large, nparams_default) + + # # Decreasing dim_head should decrease the number of parameters + # nparams_custom_small = count_sablock_params( + # hidden_size=hidden_size, num_heads=num_heads, dim_head=default_dim_head // 2 + # ) + # self.assertGreater(nparams_default, nparams_custom_small) + + # # Increasing the number of heads with the default behaviour should not change the number of params. 
+ # nparams_default_more_heads = count_sablock_params(hidden_size=hidden_size, num_heads=num_heads * 2) + # self.assertEqual(nparams_default, nparams_default_more_heads) + + # @skipUnless(has_einops, "Requires einops") + # def test_script(self): + # for include_fc in [True, False]: + # for use_combined_linear in [True, False]: + # input_param = { + # "hidden_size": 360, + # "num_heads": 4, + # "dropout_rate": 0.0, + # "rel_pos_embedding": None, + # "input_size": (16, 32), + # "include_fc": include_fc, + # "use_combined_linear": use_combined_linear, + # } + # net = SABlock(**input_param) + # input_shape = (2, 512, 360) + # test_data = torch.randn(input_shape) + # test_script_save(net, test_data) + @skipUnless(has_einops, "Requires einops") @SkipIfBeforePyTorchVersion((2, 0)) - def test_shape(self, input_param, input_shape, expected_shape): - net = SABlock(**input_param) - with eval_mode(net): - result = net(torch.randn(input_shape)) - self.assertEqual(result.shape, expected_shape) - - def test_ill_arg(self): - with self.assertRaises(ValueError): - SABlock(hidden_size=128, num_heads=12, dropout_rate=6.0) - - with self.assertRaises(ValueError): - SABlock(hidden_size=620, num_heads=8, dropout_rate=0.4) - - @SkipIfBeforePyTorchVersion((2, 0)) - def test_rel_pos_embedding_with_flash_attention(self): - with self.assertRaises(ValueError): - SABlock( - hidden_size=128, - num_heads=3, - dropout_rate=0.1, - use_flash_attention=True, - save_attn=False, - rel_pos_embedding=RelPosEmbedding.DECOMPOSED, - ) - - @SkipIfBeforePyTorchVersion((1, 13)) - def test_save_attn_with_flash_attention(self): - with self.assertRaises(ValueError): - SABlock(hidden_size=128, num_heads=3, dropout_rate=0.1, use_flash_attention=True, save_attn=True) - - def test_attention_dim_not_multiple_of_heads(self): - with self.assertRaises(ValueError): - SABlock(hidden_size=128, num_heads=3, dropout_rate=0.1) - - @skipUnless(has_einops, "Requires einops") - def test_inner_dim_different(self): - SABlock(hidden_size=128, num_heads=4, dropout_rate=0.1, dim_head=30) - - def test_causal_no_sequence_length(self): - with self.assertRaises(ValueError): - SABlock(hidden_size=128, num_heads=4, dropout_rate=0.1, causal=True) - - @skipUnless(has_einops, "Requires einops") - @SkipIfBeforePyTorchVersion((2, 0)) - def test_causal_flash_attention(self): - block = SABlock( - hidden_size=128, - num_heads=1, - dropout_rate=0.1, - causal=True, - sequence_length=16, - save_attn=False, - use_flash_attention=True, - ) - input_shape = (1, 16, 128) - # Check it runs correctly - block(torch.randn(input_shape)) - - @skipUnless(has_einops, "Requires einops") - def test_causal(self): - block = SABlock(hidden_size=128, num_heads=1, dropout_rate=0.1, causal=True, sequence_length=16, save_attn=True) - input_shape = (1, 16, 128) - block(torch.randn(input_shape)) - # check upper triangular part of the attention matrix is zero - assert torch.triu(block.att_mat, diagonal=1).sum() == 0 - - @skipUnless(has_einops, "Requires einops") - def test_access_attn_matrix(self): - # input format - hidden_size = 128 - num_heads = 2 - dropout_rate = 0 - input_shape = (2, 256, hidden_size) - - # be not able to access the matrix - no_matrix_acess_blk = SABlock(hidden_size=hidden_size, num_heads=num_heads, dropout_rate=dropout_rate) - no_matrix_acess_blk(torch.randn(input_shape)) - assert isinstance(no_matrix_acess_blk.att_mat, torch.Tensor) - # no of elements is zero - assert no_matrix_acess_blk.att_mat.nelement() == 0 - - # be able to acess the attention matrix - matrix_acess_blk = 
SABlock( - hidden_size=hidden_size, num_heads=num_heads, dropout_rate=dropout_rate, save_attn=True - ) - matrix_acess_blk(torch.randn(input_shape)) - assert matrix_acess_blk.att_mat.shape == (input_shape[0], input_shape[0], input_shape[1], input_shape[1]) - - def test_number_of_parameters(self): - - def count_sablock_params(*args, **kwargs): - """Count the number of parameters in a SABlock.""" - sablock = SABlock(*args, **kwargs) - return sum([x.numel() for x in sablock.parameters() if x.requires_grad]) - - hidden_size = 128 - num_heads = 8 - default_dim_head = hidden_size // num_heads - - # Default dim_head is hidden_size // num_heads - nparams_default = count_sablock_params(hidden_size=hidden_size, num_heads=num_heads) - nparams_like_default = count_sablock_params( - hidden_size=hidden_size, num_heads=num_heads, dim_head=default_dim_head - ) - self.assertEqual(nparams_default, nparams_like_default) - - # Increasing dim_head should increase the number of parameters - nparams_custom_large = count_sablock_params( - hidden_size=hidden_size, num_heads=num_heads, dim_head=default_dim_head * 2 - ) - self.assertGreater(nparams_custom_large, nparams_default) - - # Decreasing dim_head should decrease the number of parameters - nparams_custom_small = count_sablock_params( - hidden_size=hidden_size, num_heads=num_heads, dim_head=default_dim_head // 2 - ) - self.assertGreater(nparams_default, nparams_custom_small) - - # Increasing the number of heads with the default behaviour should not change the number of params. - nparams_default_more_heads = count_sablock_params(hidden_size=hidden_size, num_heads=num_heads * 2) - self.assertEqual(nparams_default, nparams_default_more_heads) - - @skipUnless(has_einops, "Requires einops") - def test_script(self): - for include_fc in [True, False]: - for use_combined_linear in [True, False]: - input_param = { - "hidden_size": 360, - "num_heads": 4, - "dropout_rate": 0.0, - "rel_pos_embedding": None, - "input_size": (16, 32), - "include_fc": include_fc, - "use_combined_linear": use_combined_linear, - } - net = SABlock(**input_param) - input_shape = (2, 512, 360) - test_data = torch.randn(input_shape) - test_script_save(net, test_data) - + def test_flash_attention(self): + for causal in [True, False]: + input_param = { + "hidden_size": 360, + "num_heads": 4, + "input_size": (16, 32), + "causal": causal, + } + device = "cuda:0" if torch.cuda.is_available() else "cpu" + block_w_flash_attention = SABlock(**input_param, use_flash_attention=True).to(device) + block_wo_flash_attention = SABlock(**input_param, use_flash_attention=False).to(device) + block_wo_flash_attention.load_state_dict(block_w_flash_attention.state_dict()) + test_data = torch.randn(2, 512, 360).to(device) + + out_1 = block_w_flash_attention(test_data) + out_2 = block_wo_flash_attention(test_data) + assert_allclose(out_1, out_2, atol=1e-4) if __name__ == "__main__": unittest.main() From 8de91eb30428f71b354598440ca955e432a4b4bd Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 7 Aug 2024 16:05:15 +0000 Subject: [PATCH 14/28] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/networks/blocks/selfattention.py | 2 +- tests/test_crossattention.py | 2 -- tests/test_selfattention.py | 6 ++---- 3 files changed, 3 insertions(+), 7 deletions(-) diff --git a/monai/networks/blocks/selfattention.py b/monai/networks/blocks/selfattention.py index 0742e78e47..6548805747 100644 --- 
a/monai/networks/blocks/selfattention.py +++ b/monai/networks/blocks/selfattention.py @@ -193,7 +193,7 @@ def forward(self, x): att_mat = self.drop_weights(att_mat) x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v) - + y = torch.nn.functional.scaled_dot_product_attention( query=q, key=k, value=v, scale=self.scale, dropout_p=self.dropout_rate, is_causal=self.causal ) diff --git a/tests/test_crossattention.py b/tests/test_crossattention.py index 2af7b21d13..8a22d3bfaf 100644 --- a/tests/test_crossattention.py +++ b/tests/test_crossattention.py @@ -16,9 +16,7 @@ import numpy as np import torch -from parameterized import parameterized -from monai.networks import eval_mode from monai.networks.blocks.crossattention import CrossAttentionBlock from monai.networks.layers.factories import RelPosEmbedding from monai.utils import optional_import diff --git a/tests/test_selfattention.py b/tests/test_selfattention.py index db3e14efb5..c7851da055 100644 --- a/tests/test_selfattention.py +++ b/tests/test_selfattention.py @@ -16,13 +16,11 @@ import numpy as np import torch -from parameterized import parameterized -from monai.networks import eval_mode from monai.networks.blocks.selfattention import SABlock from monai.networks.layers.factories import RelPosEmbedding from monai.utils import optional_import -from tests.utils import SkipIfBeforePyTorchVersion, test_script_save, assert_allclose +from tests.utils import SkipIfBeforePyTorchVersion, assert_allclose einops, has_einops = optional_import("einops") @@ -196,7 +194,7 @@ class TestResBlock(unittest.TestCase): # input_shape = (2, 512, 360) # test_data = torch.randn(input_shape) # test_script_save(net, test_data) - + @skipUnless(has_einops, "Requires einops") @SkipIfBeforePyTorchVersion((2, 0)) def test_flash_attention(self): From 0b556a5546afa40ac1c68743a34dcd1715394661 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Thu, 8 Aug 2024 00:09:26 +0800 Subject: [PATCH 15/28] minor fix Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- tests/test_crossattention.py | 230 ++++++++++++++-------------- tests/test_selfattention.py | 284 +++++++++++++++++------------------ 2 files changed, 257 insertions(+), 257 deletions(-) diff --git a/tests/test_crossattention.py b/tests/test_crossattention.py index 2af7b21d13..33d62f7284 100644 --- a/tests/test_crossattention.py +++ b/tests/test_crossattention.py @@ -50,121 +50,121 @@ class TestResBlock(unittest.TestCase): - # @parameterized.expand(TEST_CASE_CABLOCK) - # @skipUnless(has_einops, "Requires einops") - # @SkipIfBeforePyTorchVersion((2, 0)) - # def test_shape(self, input_param, input_shape, expected_shape): - # # Without flash attention - # net = CrossAttentionBlock(**input_param) - # with eval_mode(net): - # result = net(torch.randn(input_shape), context=torch.randn(2, 512, input_param["hidden_size"])) - # self.assertEqual(result.shape, expected_shape) - - # def test_ill_arg(self): - # with self.assertRaises(ValueError): - # CrossAttentionBlock(hidden_size=128, num_heads=12, dropout_rate=6.0) - - # with self.assertRaises(ValueError): - # CrossAttentionBlock(hidden_size=620, num_heads=8, dropout_rate=0.4) - - # @SkipIfBeforePyTorchVersion((2, 0)) - # def test_save_attn_with_flash_attention(self): - # with self.assertRaises(ValueError): - # CrossAttentionBlock( - # hidden_size=128, num_heads=3, dropout_rate=0.1, use_flash_attention=True, save_attn=True - # ) - - # @SkipIfBeforePyTorchVersion((2, 0)) - # def test_rel_pos_embedding_with_flash_attention(self): - # 
with self.assertRaises(ValueError): - # CrossAttentionBlock( - # hidden_size=128, - # num_heads=3, - # dropout_rate=0.1, - # use_flash_attention=True, - # save_attn=False, - # rel_pos_embedding=RelPosEmbedding.DECOMPOSED, - # ) - - # @skipUnless(has_einops, "Requires einops") - # def test_attention_dim_not_multiple_of_heads(self): - # with self.assertRaises(ValueError): - # CrossAttentionBlock(hidden_size=128, num_heads=3, dropout_rate=0.1) - - # @skipUnless(has_einops, "Requires einops") - # def test_inner_dim_different(self): - # CrossAttentionBlock(hidden_size=128, num_heads=4, dropout_rate=0.1, dim_head=30) - - # def test_causal_no_sequence_length(self): - # with self.assertRaises(ValueError): - # CrossAttentionBlock(hidden_size=128, num_heads=4, dropout_rate=0.1, causal=True) - - # @skipUnless(has_einops, "Requires einops") - # @SkipIfBeforePyTorchVersion((2, 0)) - # def test_causal_flash_attention(self): - # block = CrossAttentionBlock( - # hidden_size=128, - # num_heads=1, - # dropout_rate=0.1, - # causal=True, - # sequence_length=16, - # save_attn=False, - # use_flash_attention=True, - # ) - # input_shape = (1, 16, 128) - # # Check it runs correctly - # block(torch.randn(input_shape)) - - # @skipUnless(has_einops, "Requires einops") - # def test_causal(self): - # block = CrossAttentionBlock( - # hidden_size=128, num_heads=1, dropout_rate=0.1, causal=True, sequence_length=16, save_attn=True - # ) - # input_shape = (1, 16, 128) - # block(torch.randn(input_shape)) - # # check upper triangular part of the attention matrix is zero - # assert torch.triu(block.att_mat, diagonal=1).sum() == 0 - - # @skipUnless(has_einops, "Requires einops") - # def test_context_input(self): - # block = CrossAttentionBlock( - # hidden_size=128, num_heads=1, dropout_rate=0.1, causal=True, sequence_length=16, context_input_size=12 - # ) - # input_shape = (1, 16, 128) - # block(torch.randn(input_shape), context=torch.randn(1, 3, 12)) - - # @skipUnless(has_einops, "Requires einops") - # def test_context_wrong_input_size(self): - # block = CrossAttentionBlock( - # hidden_size=128, num_heads=1, dropout_rate=0.1, causal=True, sequence_length=16, context_input_size=12 - # ) - # input_shape = (1, 16, 128) - # with self.assertRaises(RuntimeError): - # block(torch.randn(input_shape), context=torch.randn(1, 3, 24)) - - # @skipUnless(has_einops, "Requires einops") - # def test_access_attn_matrix(self): - # # input format - # hidden_size = 128 - # num_heads = 2 - # dropout_rate = 0 - # input_shape = (2, 256, hidden_size) - - # # be not able to access the matrix - # no_matrix_acess_blk = CrossAttentionBlock( - # hidden_size=hidden_size, num_heads=num_heads, dropout_rate=dropout_rate - # ) - # no_matrix_acess_blk(torch.randn(input_shape)) - # assert isinstance(no_matrix_acess_blk.att_mat, torch.Tensor) - # # no of elements is zero - # assert no_matrix_acess_blk.att_mat.nelement() == 0 - - # # be able to acess the attention matrix. 
- # matrix_acess_blk = CrossAttentionBlock( - # hidden_size=hidden_size, num_heads=num_heads, dropout_rate=dropout_rate, save_attn=True - # ) - # matrix_acess_blk(torch.randn(input_shape)) - # assert matrix_acess_blk.att_mat.shape == (input_shape[0], input_shape[0], input_shape[1], input_shape[1]) + @parameterized.expand(TEST_CASE_CABLOCK) + @skipUnless(has_einops, "Requires einops") + @SkipIfBeforePyTorchVersion((2, 0)) + def test_shape(self, input_param, input_shape, expected_shape): + # Without flash attention + net = CrossAttentionBlock(**input_param) + with eval_mode(net): + result = net(torch.randn(input_shape), context=torch.randn(2, 512, input_param["hidden_size"])) + self.assertEqual(result.shape, expected_shape) + + def test_ill_arg(self): + with self.assertRaises(ValueError): + CrossAttentionBlock(hidden_size=128, num_heads=12, dropout_rate=6.0) + + with self.assertRaises(ValueError): + CrossAttentionBlock(hidden_size=620, num_heads=8, dropout_rate=0.4) + + @SkipIfBeforePyTorchVersion((2, 0)) + def test_save_attn_with_flash_attention(self): + with self.assertRaises(ValueError): + CrossAttentionBlock( + hidden_size=128, num_heads=3, dropout_rate=0.1, use_flash_attention=True, save_attn=True + ) + + @SkipIfBeforePyTorchVersion((2, 0)) + def test_rel_pos_embedding_with_flash_attention(self): + with self.assertRaises(ValueError): + CrossAttentionBlock( + hidden_size=128, + num_heads=3, + dropout_rate=0.1, + use_flash_attention=True, + save_attn=False, + rel_pos_embedding=RelPosEmbedding.DECOMPOSED, + ) + + @skipUnless(has_einops, "Requires einops") + def test_attention_dim_not_multiple_of_heads(self): + with self.assertRaises(ValueError): + CrossAttentionBlock(hidden_size=128, num_heads=3, dropout_rate=0.1) + + @skipUnless(has_einops, "Requires einops") + def test_inner_dim_different(self): + CrossAttentionBlock(hidden_size=128, num_heads=4, dropout_rate=0.1, dim_head=30) + + def test_causal_no_sequence_length(self): + with self.assertRaises(ValueError): + CrossAttentionBlock(hidden_size=128, num_heads=4, dropout_rate=0.1, causal=True) + + @skipUnless(has_einops, "Requires einops") + @SkipIfBeforePyTorchVersion((2, 0)) + def test_causal_flash_attention(self): + block = CrossAttentionBlock( + hidden_size=128, + num_heads=1, + dropout_rate=0.1, + causal=True, + sequence_length=16, + save_attn=False, + use_flash_attention=True, + ) + input_shape = (1, 16, 128) + # Check it runs correctly + block(torch.randn(input_shape)) + + @skipUnless(has_einops, "Requires einops") + def test_causal(self): + block = CrossAttentionBlock( + hidden_size=128, num_heads=1, dropout_rate=0.1, causal=True, sequence_length=16, save_attn=True + ) + input_shape = (1, 16, 128) + block(torch.randn(input_shape)) + # check upper triangular part of the attention matrix is zero + assert torch.triu(block.att_mat, diagonal=1).sum() == 0 + + @skipUnless(has_einops, "Requires einops") + def test_context_input(self): + block = CrossAttentionBlock( + hidden_size=128, num_heads=1, dropout_rate=0.1, causal=True, sequence_length=16, context_input_size=12 + ) + input_shape = (1, 16, 128) + block(torch.randn(input_shape), context=torch.randn(1, 3, 12)) + + @skipUnless(has_einops, "Requires einops") + def test_context_wrong_input_size(self): + block = CrossAttentionBlock( + hidden_size=128, num_heads=1, dropout_rate=0.1, causal=True, sequence_length=16, context_input_size=12 + ) + input_shape = (1, 16, 128) + with self.assertRaises(RuntimeError): + block(torch.randn(input_shape), context=torch.randn(1, 3, 24)) + + 
@skipUnless(has_einops, "Requires einops") + def test_access_attn_matrix(self): + # input format + hidden_size = 128 + num_heads = 2 + dropout_rate = 0 + input_shape = (2, 256, hidden_size) + + # be not able to access the matrix + no_matrix_acess_blk = CrossAttentionBlock( + hidden_size=hidden_size, num_heads=num_heads, dropout_rate=dropout_rate + ) + no_matrix_acess_blk(torch.randn(input_shape)) + assert isinstance(no_matrix_acess_blk.att_mat, torch.Tensor) + # no of elements is zero + assert no_matrix_acess_blk.att_mat.nelement() == 0 + + # be able to acess the attention matrix. + matrix_acess_blk = CrossAttentionBlock( + hidden_size=hidden_size, num_heads=num_heads, dropout_rate=dropout_rate, save_attn=True + ) + matrix_acess_blk(torch.randn(input_shape)) + assert matrix_acess_blk.att_mat.shape == (input_shape[0], input_shape[0], input_shape[1], input_shape[1]) @skipUnless(has_einops, "Requires einops") @SkipIfBeforePyTorchVersion((2, 0)) diff --git a/tests/test_selfattention.py b/tests/test_selfattention.py index db3e14efb5..eaf67bf8d8 100644 --- a/tests/test_selfattention.py +++ b/tests/test_selfattention.py @@ -54,148 +54,148 @@ class TestResBlock(unittest.TestCase): - # @parameterized.expand(TEST_CASE_SABLOCK) - # @skipUnless(has_einops, "Requires einops") - # @SkipIfBeforePyTorchVersion((2, 0)) - # def test_shape(self, input_param, input_shape, expected_shape): - # net = SABlock(**input_param) - # with eval_mode(net): - # result = net(torch.randn(input_shape)) - # self.assertEqual(result.shape, expected_shape) - - # def test_ill_arg(self): - # with self.assertRaises(ValueError): - # SABlock(hidden_size=128, num_heads=12, dropout_rate=6.0) - - # with self.assertRaises(ValueError): - # SABlock(hidden_size=620, num_heads=8, dropout_rate=0.4) - - # @SkipIfBeforePyTorchVersion((2, 0)) - # def test_rel_pos_embedding_with_flash_attention(self): - # with self.assertRaises(ValueError): - # SABlock( - # hidden_size=128, - # num_heads=3, - # dropout_rate=0.1, - # use_flash_attention=True, - # save_attn=False, - # rel_pos_embedding=RelPosEmbedding.DECOMPOSED, - # ) - - # @SkipIfBeforePyTorchVersion((1, 13)) - # def test_save_attn_with_flash_attention(self): - # with self.assertRaises(ValueError): - # SABlock(hidden_size=128, num_heads=3, dropout_rate=0.1, use_flash_attention=True, save_attn=True) - - # def test_attention_dim_not_multiple_of_heads(self): - # with self.assertRaises(ValueError): - # SABlock(hidden_size=128, num_heads=3, dropout_rate=0.1) - - # @skipUnless(has_einops, "Requires einops") - # def test_inner_dim_different(self): - # SABlock(hidden_size=128, num_heads=4, dropout_rate=0.1, dim_head=30) - - # def test_causal_no_sequence_length(self): - # with self.assertRaises(ValueError): - # SABlock(hidden_size=128, num_heads=4, dropout_rate=0.1, causal=True) - - # @skipUnless(has_einops, "Requires einops") - # @SkipIfBeforePyTorchVersion((2, 0)) - # def test_causal_flash_attention(self): - # block = SABlock( - # hidden_size=128, - # num_heads=1, - # dropout_rate=0.1, - # causal=True, - # sequence_length=16, - # save_attn=False, - # use_flash_attention=True, - # ) - # input_shape = (1, 16, 128) - # # Check it runs correctly - # block(torch.randn(input_shape)) - - # @skipUnless(has_einops, "Requires einops") - # def test_causal(self): - # block = SABlock(hidden_size=128, num_heads=1, dropout_rate=0.1, causal=True, sequence_length=16, save_attn=True) - # input_shape = (1, 16, 128) - # block(torch.randn(input_shape)) - # # check upper triangular part of the attention matrix is zero - # 
assert torch.triu(block.att_mat, diagonal=1).sum() == 0 - - # @skipUnless(has_einops, "Requires einops") - # def test_access_attn_matrix(self): - # # input format - # hidden_size = 128 - # num_heads = 2 - # dropout_rate = 0 - # input_shape = (2, 256, hidden_size) - - # # be not able to access the matrix - # no_matrix_acess_blk = SABlock(hidden_size=hidden_size, num_heads=num_heads, dropout_rate=dropout_rate) - # no_matrix_acess_blk(torch.randn(input_shape)) - # assert isinstance(no_matrix_acess_blk.att_mat, torch.Tensor) - # # no of elements is zero - # assert no_matrix_acess_blk.att_mat.nelement() == 0 - - # # be able to acess the attention matrix - # matrix_acess_blk = SABlock( - # hidden_size=hidden_size, num_heads=num_heads, dropout_rate=dropout_rate, save_attn=True - # ) - # matrix_acess_blk(torch.randn(input_shape)) - # assert matrix_acess_blk.att_mat.shape == (input_shape[0], input_shape[0], input_shape[1], input_shape[1]) - - # def test_number_of_parameters(self): - - # def count_sablock_params(*args, **kwargs): - # """Count the number of parameters in a SABlock.""" - # sablock = SABlock(*args, **kwargs) - # return sum([x.numel() for x in sablock.parameters() if x.requires_grad]) - - # hidden_size = 128 - # num_heads = 8 - # default_dim_head = hidden_size // num_heads - - # # Default dim_head is hidden_size // num_heads - # nparams_default = count_sablock_params(hidden_size=hidden_size, num_heads=num_heads) - # nparams_like_default = count_sablock_params( - # hidden_size=hidden_size, num_heads=num_heads, dim_head=default_dim_head - # ) - # self.assertEqual(nparams_default, nparams_like_default) - - # # Increasing dim_head should increase the number of parameters - # nparams_custom_large = count_sablock_params( - # hidden_size=hidden_size, num_heads=num_heads, dim_head=default_dim_head * 2 - # ) - # self.assertGreater(nparams_custom_large, nparams_default) - - # # Decreasing dim_head should decrease the number of parameters - # nparams_custom_small = count_sablock_params( - # hidden_size=hidden_size, num_heads=num_heads, dim_head=default_dim_head // 2 - # ) - # self.assertGreater(nparams_default, nparams_custom_small) - - # # Increasing the number of heads with the default behaviour should not change the number of params. 
- # nparams_default_more_heads = count_sablock_params(hidden_size=hidden_size, num_heads=num_heads * 2) - # self.assertEqual(nparams_default, nparams_default_more_heads) - - # @skipUnless(has_einops, "Requires einops") - # def test_script(self): - # for include_fc in [True, False]: - # for use_combined_linear in [True, False]: - # input_param = { - # "hidden_size": 360, - # "num_heads": 4, - # "dropout_rate": 0.0, - # "rel_pos_embedding": None, - # "input_size": (16, 32), - # "include_fc": include_fc, - # "use_combined_linear": use_combined_linear, - # } - # net = SABlock(**input_param) - # input_shape = (2, 512, 360) - # test_data = torch.randn(input_shape) - # test_script_save(net, test_data) + @parameterized.expand(TEST_CASE_SABLOCK) + @skipUnless(has_einops, "Requires einops") + @SkipIfBeforePyTorchVersion((2, 0)) + def test_shape(self, input_param, input_shape, expected_shape): + net = SABlock(**input_param) + with eval_mode(net): + result = net(torch.randn(input_shape)) + self.assertEqual(result.shape, expected_shape) + + def test_ill_arg(self): + with self.assertRaises(ValueError): + SABlock(hidden_size=128, num_heads=12, dropout_rate=6.0) + + with self.assertRaises(ValueError): + SABlock(hidden_size=620, num_heads=8, dropout_rate=0.4) + + @SkipIfBeforePyTorchVersion((2, 0)) + def test_rel_pos_embedding_with_flash_attention(self): + with self.assertRaises(ValueError): + SABlock( + hidden_size=128, + num_heads=3, + dropout_rate=0.1, + use_flash_attention=True, + save_attn=False, + rel_pos_embedding=RelPosEmbedding.DECOMPOSED, + ) + + @SkipIfBeforePyTorchVersion((1, 13)) + def test_save_attn_with_flash_attention(self): + with self.assertRaises(ValueError): + SABlock(hidden_size=128, num_heads=3, dropout_rate=0.1, use_flash_attention=True, save_attn=True) + + def test_attention_dim_not_multiple_of_heads(self): + with self.assertRaises(ValueError): + SABlock(hidden_size=128, num_heads=3, dropout_rate=0.1) + + @skipUnless(has_einops, "Requires einops") + def test_inner_dim_different(self): + SABlock(hidden_size=128, num_heads=4, dropout_rate=0.1, dim_head=30) + + def test_causal_no_sequence_length(self): + with self.assertRaises(ValueError): + SABlock(hidden_size=128, num_heads=4, dropout_rate=0.1, causal=True) + + @skipUnless(has_einops, "Requires einops") + @SkipIfBeforePyTorchVersion((2, 0)) + def test_causal_flash_attention(self): + block = SABlock( + hidden_size=128, + num_heads=1, + dropout_rate=0.1, + causal=True, + sequence_length=16, + save_attn=False, + use_flash_attention=True, + ) + input_shape = (1, 16, 128) + # Check it runs correctly + block(torch.randn(input_shape)) + + @skipUnless(has_einops, "Requires einops") + def test_causal(self): + block = SABlock(hidden_size=128, num_heads=1, dropout_rate=0.1, causal=True, sequence_length=16, save_attn=True) + input_shape = (1, 16, 128) + block(torch.randn(input_shape)) + # check upper triangular part of the attention matrix is zero + assert torch.triu(block.att_mat, diagonal=1).sum() == 0 + + @skipUnless(has_einops, "Requires einops") + def test_access_attn_matrix(self): + # input format + hidden_size = 128 + num_heads = 2 + dropout_rate = 0 + input_shape = (2, 256, hidden_size) + + # be not able to access the matrix + no_matrix_acess_blk = SABlock(hidden_size=hidden_size, num_heads=num_heads, dropout_rate=dropout_rate) + no_matrix_acess_blk(torch.randn(input_shape)) + assert isinstance(no_matrix_acess_blk.att_mat, torch.Tensor) + # no of elements is zero + assert no_matrix_acess_blk.att_mat.nelement() == 0 + + # be able to acess 
the attention matrix + matrix_acess_blk = SABlock( + hidden_size=hidden_size, num_heads=num_heads, dropout_rate=dropout_rate, save_attn=True + ) + matrix_acess_blk(torch.randn(input_shape)) + assert matrix_acess_blk.att_mat.shape == (input_shape[0], input_shape[0], input_shape[1], input_shape[1]) + + def test_number_of_parameters(self): + + def count_sablock_params(*args, **kwargs): + """Count the number of parameters in a SABlock.""" + sablock = SABlock(*args, **kwargs) + return sum([x.numel() for x in sablock.parameters() if x.requires_grad]) + + hidden_size = 128 + num_heads = 8 + default_dim_head = hidden_size // num_heads + + # Default dim_head is hidden_size // num_heads + nparams_default = count_sablock_params(hidden_size=hidden_size, num_heads=num_heads) + nparams_like_default = count_sablock_params( + hidden_size=hidden_size, num_heads=num_heads, dim_head=default_dim_head + ) + self.assertEqual(nparams_default, nparams_like_default) + + # Increasing dim_head should increase the number of parameters + nparams_custom_large = count_sablock_params( + hidden_size=hidden_size, num_heads=num_heads, dim_head=default_dim_head * 2 + ) + self.assertGreater(nparams_custom_large, nparams_default) + + # Decreasing dim_head should decrease the number of parameters + nparams_custom_small = count_sablock_params( + hidden_size=hidden_size, num_heads=num_heads, dim_head=default_dim_head // 2 + ) + self.assertGreater(nparams_default, nparams_custom_small) + + # Increasing the number of heads with the default behaviour should not change the number of params. + nparams_default_more_heads = count_sablock_params(hidden_size=hidden_size, num_heads=num_heads * 2) + self.assertEqual(nparams_default, nparams_default_more_heads) + + @skipUnless(has_einops, "Requires einops") + def test_script(self): + for include_fc in [True, False]: + for use_combined_linear in [True, False]: + input_param = { + "hidden_size": 360, + "num_heads": 4, + "dropout_rate": 0.0, + "rel_pos_embedding": None, + "input_size": (16, 32), + "include_fc": include_fc, + "use_combined_linear": use_combined_linear, + } + net = SABlock(**input_param) + input_shape = (2, 512, 360) + test_data = torch.randn(input_shape) + test_script_save(net, test_data) @skipUnless(has_einops, "Requires einops") @SkipIfBeforePyTorchVersion((2, 0)) From 9a59a15c3176ccea3042ea194ac09150465f85e4 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Thu, 8 Aug 2024 00:10:34 +0800 Subject: [PATCH 16/28] minor fix Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/networks/blocks/crossattention.py | 3 --- monai/networks/blocks/selfattention.py | 7 ++----- tests/test_crossattention.py | 7 ++++++- tests/test_selfattention.py | 12 ++++-------- 4 files changed, 12 insertions(+), 17 deletions(-) diff --git a/monai/networks/blocks/crossattention.py b/monai/networks/blocks/crossattention.py index a16893faca..06e75c49e4 100644 --- a/monai/networks/blocks/crossattention.py +++ b/monai/networks/blocks/crossattention.py @@ -185,9 +185,6 @@ def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None): att_mat = self.drop_weights(att_mat) x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v) - y = torch.nn.functional.scaled_dot_product_attention( - query=q, key=k, value=v, scale=self.scale, dropout_p=self.dropout_rate, is_causal=self.causal - ) x = self.out_rearrange(x) x = self.out_proj(x) diff --git a/monai/networks/blocks/selfattention.py b/monai/networks/blocks/selfattention.py index 0742e78e47..abb31abe55 
100644 --- a/monai/networks/blocks/selfattention.py +++ b/monai/networks/blocks/selfattention.py @@ -182,7 +182,7 @@ def forward(self, x): att_mat = self.rel_positional_embedding(x, att_mat, q) if self.causal: - att_mat = att_mat.masked_fill(self.causal_mask[:, :, :x.shape[-2], :x.shape[-2]] == 0, float("-inf")) + att_mat = att_mat.masked_fill(self.causal_mask[:, :, : x.shape[-2], : x.shape[-2]] == 0, float("-inf")) att_mat = att_mat.softmax(dim=-1) @@ -193,10 +193,7 @@ def forward(self, x): att_mat = self.drop_weights(att_mat) x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v) - - y = torch.nn.functional.scaled_dot_product_attention( - query=q, key=k, value=v, scale=self.scale, dropout_p=self.dropout_rate, is_causal=self.causal - ) + x = self.out_rearrange(x) if self.include_fc: diff --git a/tests/test_crossattention.py b/tests/test_crossattention.py index 33d62f7284..5af09ed9b4 100644 --- a/tests/test_crossattention.py +++ b/tests/test_crossattention.py @@ -170,7 +170,12 @@ def test_access_attn_matrix(self): @SkipIfBeforePyTorchVersion((2, 0)) def test_flash_attention(self): for causal in [True, False]: - input_param = {"hidden_size": 128, "num_heads": 1, 'causal': causal, 'sequence_length': 16 if causal else None} + input_param = { + "hidden_size": 128, + "num_heads": 1, + "causal": causal, + "sequence_length": 16 if causal else None, + } device = "cuda:0" if torch.cuda.is_available() else "cpu" block_w_flash_attention = CrossAttentionBlock(**input_param, use_flash_attention=True).to(device) block_wo_flash_attention = CrossAttentionBlock(**input_param, use_flash_attention=False).to(device) diff --git a/tests/test_selfattention.py b/tests/test_selfattention.py index eaf67bf8d8..5fb9669098 100644 --- a/tests/test_selfattention.py +++ b/tests/test_selfattention.py @@ -22,7 +22,7 @@ from monai.networks.blocks.selfattention import SABlock from monai.networks.layers.factories import RelPosEmbedding from monai.utils import optional_import -from tests.utils import SkipIfBeforePyTorchVersion, test_script_save, assert_allclose +from tests.utils import SkipIfBeforePyTorchVersion, assert_allclose, test_script_save einops, has_einops = optional_import("einops") @@ -196,17 +196,12 @@ def test_script(self): input_shape = (2, 512, 360) test_data = torch.randn(input_shape) test_script_save(net, test_data) - + @skipUnless(has_einops, "Requires einops") @SkipIfBeforePyTorchVersion((2, 0)) def test_flash_attention(self): for causal in [True, False]: - input_param = { - "hidden_size": 360, - "num_heads": 4, - "input_size": (16, 32), - "causal": causal, - } + input_param = {"hidden_size": 360, "num_heads": 4, "input_size": (16, 32), "causal": causal} device = "cuda:0" if torch.cuda.is_available() else "cpu" block_w_flash_attention = SABlock(**input_param, use_flash_attention=True).to(device) block_wo_flash_attention = SABlock(**input_param, use_flash_attention=False).to(device) @@ -217,5 +212,6 @@ def test_flash_attention(self): out_2 = block_wo_flash_attention(test_data) assert_allclose(out_1, out_2, atol=1e-4) + if __name__ == "__main__": unittest.main() From 05e42ce93c98644cd6641c2c15fcf247067f9f3a Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Thu, 8 Aug 2024 00:10:52 +0800 Subject: [PATCH 17/28] format fix Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/networks/blocks/crossattention.py | 1 - monai/networks/blocks/selfattention.py | 1 - 2 files changed, 2 deletions(-) diff --git a/monai/networks/blocks/crossattention.py 
b/monai/networks/blocks/crossattention.py index 06e75c49e4..42787cc770 100644 --- a/monai/networks/blocks/crossattention.py +++ b/monai/networks/blocks/crossattention.py @@ -185,7 +185,6 @@ def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None): att_mat = self.drop_weights(att_mat) x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v) - x = self.out_rearrange(x) x = self.out_proj(x) x = self.drop_output(x) diff --git a/monai/networks/blocks/selfattention.py b/monai/networks/blocks/selfattention.py index abb31abe55..f07710941f 100644 --- a/monai/networks/blocks/selfattention.py +++ b/monai/networks/blocks/selfattention.py @@ -194,7 +194,6 @@ def forward(self, x): att_mat = self.drop_weights(att_mat) x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v) - x = self.out_rearrange(x) if self.include_fc: x = self.out_proj(x) From aae275de3a1e4baa4749be7980b1a649fd5ddc02 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Thu, 8 Aug 2024 00:14:42 +0800 Subject: [PATCH 18/28] minor fix Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- tests/test_crossattention.py | 2 ++ tests/test_selfattention.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/tests/test_crossattention.py b/tests/test_crossattention.py index ac0d79f6d2..5af09ed9b4 100644 --- a/tests/test_crossattention.py +++ b/tests/test_crossattention.py @@ -16,7 +16,9 @@ import numpy as np import torch +from parameterized import parameterized +from monai.networks import eval_mode from monai.networks.blocks.crossattention import CrossAttentionBlock from monai.networks.layers.factories import RelPosEmbedding from monai.utils import optional_import diff --git a/tests/test_selfattention.py b/tests/test_selfattention.py index 0c4ec1e616..5fb9669098 100644 --- a/tests/test_selfattention.py +++ b/tests/test_selfattention.py @@ -16,7 +16,9 @@ import numpy as np import torch +from parameterized import parameterized +from monai.networks import eval_mode from monai.networks.blocks.selfattention import SABlock from monai.networks.layers.factories import RelPosEmbedding from monai.utils import optional_import From 531a83166f10dd80f6eb9fbc970aaf8b4639a300 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Thu, 8 Aug 2024 00:28:14 +0800 Subject: [PATCH 19/28] fix mypy Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/networks/blocks/selfattention.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/monai/networks/blocks/selfattention.py b/monai/networks/blocks/selfattention.py index f07710941f..275add37f9 100644 --- a/monai/networks/blocks/selfattention.py +++ b/monai/networks/blocks/selfattention.py @@ -11,7 +11,7 @@ from __future__ import annotations -from typing import Tuple +from typing import Tuple, Union import torch import torch.nn as nn @@ -109,6 +109,11 @@ def __init__( self.hidden_input_size = hidden_input_size if hidden_input_size else hidden_size self.out_proj = nn.Linear(self.inner_dim, self.hidden_input_size) + self.qkv: Union[nn.Linear, nn.Identity] + self.to_q: Union[nn.Linear, nn.Identity] + self.to_k: Union[nn.Linear, nn.Identity] + self.to_v: Union[nn.Linear, nn.Identity] + if use_combined_linear: self.qkv = nn.Linear(self.hidden_input_size, self.inner_dim * 3, bias=qkv_bias) self.to_q = self.to_k = self.to_v = nn.Identity() # add to enable torchscript From 48319c093ec709f2e2c5370d5709f763cdaa2bee Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Thu, 
8 Aug 2024 01:05:35 +0800 Subject: [PATCH 20/28] fix ci Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- tests/test_selfattention.py | 36 ++++++++++++++++++------------------ 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/tests/test_selfattention.py b/tests/test_selfattention.py index 5fb9669098..75ed99d72e 100644 --- a/tests/test_selfattention.py +++ b/tests/test_selfattention.py @@ -32,24 +32,23 @@ for num_heads in [4, 6, 8, 12]: for rel_pos_embedding in [None, RelPosEmbedding.DECOMPOSED]: for input_size in [(16, 32), (8, 8, 8)]: - for flash_attn in [True, False]: - for include_fc in [True, False]: - for use_combined_linear in [True, False]: - test_case = [ - { - "hidden_size": hidden_size, - "num_heads": num_heads, - "dropout_rate": dropout_rate, - "rel_pos_embedding": rel_pos_embedding, - "input_size": input_size, - "include_fc": include_fc, - "use_combined_linear": use_combined_linear, - "use_flash_attention": flash_attn, - }, - (2, 512, hidden_size), - (2, 512, hidden_size), - ] - TEST_CASE_SABLOCK.append(test_case) + for include_fc in [True, False]: + for use_combined_linear in [True, False]: + test_case = [ + { + "hidden_size": hidden_size, + "num_heads": num_heads, + "dropout_rate": dropout_rate, + "rel_pos_embedding": rel_pos_embedding, + "input_size": input_size, + "include_fc": include_fc, + "use_combined_linear": use_combined_linear, + "use_flash_attention": True if rel_pos_embedding is None else False, + }, + (2, 512, hidden_size), + (2, 512, hidden_size), + ] + TEST_CASE_SABLOCK.append(test_case) class TestResBlock(unittest.TestCase): @@ -180,6 +179,7 @@ def count_sablock_params(*args, **kwargs): self.assertEqual(nparams_default, nparams_default_more_heads) @skipUnless(has_einops, "Requires einops") + @SkipIfBeforePyTorchVersion((2, 0)) def test_script(self): for include_fc in [True, False]: for use_combined_linear in [True, False]: From b854d7acc0b39e448b855679586e129d13fb7b09 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Thu, 8 Aug 2024 01:13:07 +0800 Subject: [PATCH 21/28] minor fix Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- tests/test_unetr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_unetr.py b/tests/test_unetr.py index 46018d2bc0..1217c9d85f 100644 --- a/tests/test_unetr.py +++ b/tests/test_unetr.py @@ -123,7 +123,7 @@ def test_ill_arg(self): ) @parameterized.expand(TEST_CASE_UNETR) - @SkipIfBeforePyTorchVersion((1, 9)) + @SkipIfBeforePyTorchVersion((2, 0)) def test_script(self, input_param, input_shape, _): net = UNETR(**(input_param)) net.eval() From 32d0a5d2916ca027359208baafcca841225a1a0b Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Thu, 8 Aug 2024 11:59:38 +0800 Subject: [PATCH 22/28] address comments Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/networks/blocks/selfattention.py | 5 ++--- monai/networks/blocks/spatialattention.py | 13 +++++++++---- monai/networks/blocks/transformerblock.py | 8 +++++++- monai/networks/nets/diffusion_model_unet.py | 8 +++++++- 4 files changed, 25 insertions(+), 9 deletions(-) diff --git a/monai/networks/blocks/selfattention.py b/monai/networks/blocks/selfattention.py index 275add37f9..ac96b077bd 100644 --- a/monai/networks/blocks/selfattention.py +++ b/monai/networks/blocks/selfattention.py @@ -65,9 +65,8 @@ def __init__( attention_dtype: cast attention operations to this dtype. 
include_fc: whether to include the final linear layer. Default to True. use_combined_linear: whether to use a single linear layer for qkv projection, default to True. - use_flash_attention: if True, use Pytorch's inbuilt - flash attention for a memory efficient attention mechanism (see - https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). + use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). """ diff --git a/monai/networks/blocks/spatialattention.py b/monai/networks/blocks/spatialattention.py index b1b6fc2961..665442b55e 100644 --- a/monai/networks/blocks/spatialattention.py +++ b/monai/networks/blocks/spatialattention.py @@ -32,8 +32,13 @@ class SpatialAttentionBlock(nn.Module): spatial_dims: number of spatial dimensions, could be 1, 2, or 3. num_channels: number of input channels. Must be divisible by num_head_channels. num_head_channels: number of channels per head. + norm_num_groups: Number of groups for the group norm layer. + norm_eps: Epsilon for the normalization. attention_dtype: cast attention operations to this dtype. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + include_fc: whether to include the final linear layer. Default to True. + use_combined_linear: whether to use a single linear layer for qkv projection, default to False. + use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). """ @@ -45,9 +50,9 @@ def __init__( norm_num_groups: int = 32, norm_eps: float = 1e-6, attention_dtype: Optional[torch.dtype] = None, - use_flash_attention: bool = False, include_fc: bool = True, - use_combined_linear: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__() @@ -62,9 +67,9 @@ def __init__( num_heads=num_heads, qkv_bias=True, attention_dtype=attention_dtype, - use_flash_attention=use_flash_attention, include_fc=include_fc, use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) def forward(self, x: torch.Tensor): diff --git a/monai/networks/blocks/transformerblock.py b/monai/networks/blocks/transformerblock.py index 28d9c563ac..05eb3b07ab 100644 --- a/monai/networks/blocks/transformerblock.py +++ b/monai/networks/blocks/transformerblock.py @@ -37,6 +37,8 @@ def __init__( sequence_length: int | None = None, with_cross_attention: bool = False, use_flash_attention: bool = False, + include_fc: bool = True, + use_combined_linear: bool = True, ) -> None: """ Args: @@ -47,7 +49,9 @@ def __init__( qkv_bias(bool, optional): apply bias term for the qkv linear layer. Defaults to False. save_attn (bool, optional): to make accessible the attention matrix. Defaults to False. use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism - (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). + include_fc: whether to include the final linear layer. Default to True. + use_combined_linear: whether to use a single linear layer for qkv projection, default to True. 
""" @@ -69,6 +73,8 @@ def __init__( save_attn=save_attn, causal=causal, sequence_length=sequence_length, + include_fc=include_fc, + use_combined_linear=use_combined_linear, use_flash_attention=use_flash_attention, ) self.norm2 = nn.LayerNorm(hidden_size) diff --git a/monai/networks/nets/diffusion_model_unet.py b/monai/networks/nets/diffusion_model_unet.py index a885339d0d..d69464682c 100644 --- a/monai/networks/nets/diffusion_model_unet.py +++ b/monai/networks/nets/diffusion_model_unet.py @@ -67,7 +67,9 @@ class DiffusionUNetTransformerBlock(nn.Module): cross_attention_dim: size of the context vector for cross attention. upcast_attention: if True, upcast attention operations to full precision. use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism - (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). + include_fc: whether to include the final linear layer. Default to True. + use_combined_linear: whether to use a single linear layer for qkv projection, default to False. """ @@ -80,6 +82,8 @@ def __init__( cross_attention_dim: int | None = None, upcast_attention: bool = False, use_flash_attention: bool = False, + include_fc: bool = True, + use_combined_linear: bool = False, ) -> None: super().__init__() self.attn1 = SABlock( @@ -89,6 +93,8 @@ def __init__( dim_head=num_head_channels, dropout_rate=dropout, attention_dtype=torch.float if upcast_attention else None, + include_fc=include_fc, + use_combined_linear=use_combined_linear, use_flash_attention=use_flash_attention, ) self.ff = MLPBlock(hidden_size=num_channels, mlp_dim=num_channels * 4, act="GEGLU", dropout_rate=dropout) From e5f2cb17fc2ca3094e1082daa3203aad06f01c3b Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Thu, 8 Aug 2024 12:05:25 +0800 Subject: [PATCH 23/28] minor fix Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/networks/blocks/crossattention.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/monai/networks/blocks/crossattention.py b/monai/networks/blocks/crossattention.py index 42787cc770..bdecf63168 100644 --- a/monai/networks/blocks/crossattention.py +++ b/monai/networks/blocks/crossattention.py @@ -63,9 +63,8 @@ def __init__( input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative positional parameter size. attention_dtype: cast attention operations to this dtype. - use_flash_attention: if True, use Pytorch's inbuilt - flash attention for a memory efficient attention mechanism (see - https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). + use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). 
""" super().__init__() From 818ba7e332a715fecca9bc874c8fed41cdd562f7 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Fri, 9 Aug 2024 11:26:13 +0800 Subject: [PATCH 24/28] Update tests/test_crossattention.py Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- tests/test_crossattention.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/test_crossattention.py b/tests/test_crossattention.py index 5af09ed9b4..4888397308 100644 --- a/tests/test_crossattention.py +++ b/tests/test_crossattention.py @@ -166,16 +166,16 @@ def test_access_attn_matrix(self): matrix_acess_blk(torch.randn(input_shape)) assert matrix_acess_blk.att_mat.shape == (input_shape[0], input_shape[0], input_shape[1], input_shape[1]) + @parameterized.expand([[True], [False]]) @skipUnless(has_einops, "Requires einops") @SkipIfBeforePyTorchVersion((2, 0)) - def test_flash_attention(self): - for causal in [True, False]: - input_param = { - "hidden_size": 128, - "num_heads": 1, - "causal": causal, - "sequence_length": 16 if causal else None, - } + def test_flash_attention(self, causal): + input_param = { + "hidden_size": 128, + "num_heads": 1, + "causal": causal, + "sequence_length": 16 if causal else None, + } device = "cuda:0" if torch.cuda.is_available() else "cpu" block_w_flash_attention = CrossAttentionBlock(**input_param, use_flash_attention=True).to(device) block_wo_flash_attention = CrossAttentionBlock(**input_param, use_flash_attention=False).to(device) From 4bef7f09f299005d3f6d18599a0951ce314ef629 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Fri, 9 Aug 2024 11:27:36 +0800 Subject: [PATCH 25/28] Update tests/test_selfattention.py Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- tests/test_selfattention.py | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/tests/test_selfattention.py b/tests/test_selfattention.py index 75ed99d72e..c35406d5fa 100644 --- a/tests/test_selfattention.py +++ b/tests/test_selfattention.py @@ -178,20 +178,19 @@ def count_sablock_params(*args, **kwargs): nparams_default_more_heads = count_sablock_params(hidden_size=hidden_size, num_heads=num_heads * 2) self.assertEqual(nparams_default, nparams_default_more_heads) + @parameterized.expand([[True,False], [True,True], [False, True],[False, False]]) @skipUnless(has_einops, "Requires einops") @SkipIfBeforePyTorchVersion((2, 0)) - def test_script(self): - for include_fc in [True, False]: - for use_combined_linear in [True, False]: - input_param = { - "hidden_size": 360, - "num_heads": 4, - "dropout_rate": 0.0, - "rel_pos_embedding": None, - "input_size": (16, 32), - "include_fc": include_fc, - "use_combined_linear": use_combined_linear, - } + def test_script(self, include_fc, use_combined_linear): + input_param = { + "hidden_size": 360, + "num_heads": 4, + "dropout_rate": 0.0, + "rel_pos_embedding": None, + "input_size": (16, 32), + "include_fc": include_fc, + "use_combined_linear": use_combined_linear, + } net = SABlock(**input_param) input_shape = (2, 512, 360) test_data = torch.randn(input_shape) From bfc8f29a40cc666224c05aa0ca92301dffedbb93 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Fri, 9 Aug 2024 11:30:06 +0800 Subject: [PATCH 26/28] minor fix Signed-off-by: 
YunLiu <55491388+KumoLiu@users.noreply.github.com> --- tests/test_selfattention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_selfattention.py b/tests/test_selfattention.py index c35406d5fa..eda049ea98 100644 --- a/tests/test_selfattention.py +++ b/tests/test_selfattention.py @@ -178,7 +178,7 @@ def count_sablock_params(*args, **kwargs): nparams_default_more_heads = count_sablock_params(hidden_size=hidden_size, num_heads=num_heads * 2) self.assertEqual(nparams_default, nparams_default_more_heads) - @parameterized.expand([[True,False], [True,True], [False, True],[False, False]]) + @parameterized.expand([[True, False], [True, True], [False, True],[False, False]]) @skipUnless(has_einops, "Requires einops") @SkipIfBeforePyTorchVersion((2, 0)) def test_script(self, include_fc, use_combined_linear): From 0da115a4d59eef1c0fb9e84fb3f103032845e3ea Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Fri, 9 Aug 2024 13:45:54 +0800 Subject: [PATCH 27/28] address comments Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/networks/nets/autoencoderkl.py | 39 ++++++ monai/networks/nets/controlnet.py | 13 ++ monai/networks/nets/diffusion_model_unet.py | 128 +++++++++++++++++- monai/networks/nets/spade_autoencoderkl.py | 22 +++ .../nets/spade_diffusion_model_unet.py | 41 +++++- monai/networks/nets/transformer.py | 10 ++ tests/test_crossattention.py | 7 +- tests/test_selfattention.py | 2 +- 8 files changed, 249 insertions(+), 13 deletions(-) diff --git a/monai/networks/nets/autoencoderkl.py b/monai/networks/nets/autoencoderkl.py index 35d80e0565..e6a9da9b9e 100644 --- a/monai/networks/nets/autoencoderkl.py +++ b/monai/networks/nets/autoencoderkl.py @@ -157,6 +157,10 @@ class Encoder(nn.Module): norm_eps: epsilon for the normalization. attention_levels: indicate which level from num_channels contain an attention block. with_nonlocal_attn: if True use non-local attention block. + include_fc: whether to include the final linear layer. Default to True. + use_combined_linear: whether to use a single linear layer for qkv projection, default to False. + use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). """ def __init__( @@ -170,6 +174,9 @@ def __init__( norm_eps: float, attention_levels: Sequence[bool], with_nonlocal_attn: bool = True, + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__() self.spatial_dims = spatial_dims @@ -220,6 +227,9 @@ def __init__( num_channels=input_channel, norm_num_groups=norm_num_groups, norm_eps=norm_eps, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) ) @@ -243,6 +253,9 @@ def __init__( num_channels=channels[-1], norm_num_groups=norm_num_groups, norm_eps=norm_eps, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) ) blocks.append( @@ -291,6 +304,10 @@ class Decoder(nn.Module): attention_levels: indicate which level from num_channels contain an attention block. with_nonlocal_attn: if True use non-local attention block. use_convtranspose: if True, use ConvTranspose to upsample feature maps in decoder. + include_fc: whether to include the final linear layer. Default to True. 
+ use_combined_linear: whether to use a single linear layer for qkv projection, default to False. + use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). """ def __init__( @@ -305,6 +322,9 @@ def __init__( attention_levels: Sequence[bool], with_nonlocal_attn: bool = True, use_convtranspose: bool = False, + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__() self.spatial_dims = spatial_dims @@ -350,6 +370,9 @@ def __init__( num_channels=reversed_block_out_channels[0], norm_num_groups=norm_num_groups, norm_eps=norm_eps, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) ) blocks.append( @@ -389,6 +412,9 @@ def __init__( num_channels=block_in_ch, norm_num_groups=norm_num_groups, norm_eps=norm_eps, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) ) @@ -463,6 +489,10 @@ class AutoencoderKL(nn.Module): with_decoder_nonlocal_attn: if True use non-local attention block in the decoder. use_checkpoint: if True, use activation checkpoint to save memory. use_convtranspose: if True, use ConvTranspose to upsample feature maps in decoder. + include_fc: whether to include the final linear layer in the attention block. Default to True. + use_combined_linear: whether to use a single linear layer for qkv projection in the attention block, default to False. + use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). """ def __init__( @@ -480,6 +510,9 @@ def __init__( with_decoder_nonlocal_attn: bool = True, use_checkpoint: bool = False, use_convtranspose: bool = False, + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__() @@ -509,6 +542,9 @@ def __init__( norm_eps=norm_eps, attention_levels=attention_levels, with_nonlocal_attn=with_encoder_nonlocal_attn, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) self.decoder = Decoder( spatial_dims=spatial_dims, @@ -521,6 +557,9 @@ def __init__( attention_levels=attention_levels, with_nonlocal_attn=with_decoder_nonlocal_attn, use_convtranspose=use_convtranspose, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) self.quant_conv_mu = Convolution( spatial_dims=spatial_dims, diff --git a/monai/networks/nets/controlnet.py b/monai/networks/nets/controlnet.py index ed3654733d..ca0977717c 100644 --- a/monai/networks/nets/controlnet.py +++ b/monai/networks/nets/controlnet.py @@ -143,6 +143,10 @@ class ControlNet(nn.Module): upcast_attention: if True, upcast attention operations to full precision. conditioning_embedding_in_channels: number of input channels for the conditioning embedding. conditioning_embedding_num_channels: number of channels for the blocks in the conditioning embedding. + include_fc: whether to include the final linear layer. Default to True. + use_combined_linear: whether to use a single linear layer for qkv projection, default to True. 
+ use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). """ def __init__( @@ -163,6 +167,9 @@ def __init__( upcast_attention: bool = False, conditioning_embedding_in_channels: int = 1, conditioning_embedding_num_channels: Sequence[int] = (16, 32, 96, 256), + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__() if with_conditioning is True and cross_attention_dim is None: @@ -282,6 +289,9 @@ def __init__( transformer_num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) self.down_blocks.append(down_block) @@ -326,6 +336,9 @@ def __init__( transformer_num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) controlnet_block = Convolution( diff --git a/monai/networks/nets/diffusion_model_unet.py b/monai/networks/nets/diffusion_model_unet.py index d69464682c..d924c9c800 100644 --- a/monai/networks/nets/diffusion_model_unet.py +++ b/monai/networks/nets/diffusion_model_unet.py @@ -140,6 +140,11 @@ class SpatialTransformer(nn.Module): norm_eps: epsilon for the normalization. cross_attention_dim: number of context dimensions to use. upcast_attention: if True, upcast attention operations to full precision. + include_fc: whether to include the final linear layer. Default to True. + use_combined_linear: whether to use a single linear layer for qkv projection, default to False. + use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). + """ def __init__( @@ -154,6 +159,9 @@ def __init__( norm_eps: float = 1e-6, cross_attention_dim: int | None = None, upcast_attention: bool = False, + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__() self.spatial_dims = spatial_dims @@ -181,6 +189,9 @@ def __init__( dropout=dropout, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) for _ in range(num_layers) ] @@ -535,6 +546,10 @@ class AttnDownBlock(nn.Module): resblock_updown: if True use residual blocks for downsampling. downsample_padding: padding used in the downsampling block. num_head_channels: number of channels in each attention head. + include_fc: whether to include the final linear layer. Default to True. + use_combined_linear: whether to use a single linear layer for qkv projection, default to False. + use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). 
""" def __init__( @@ -550,6 +565,9 @@ def __init__( resblock_updown: bool = False, downsample_padding: int = 1, num_head_channels: int = 1, + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__() self.resblock_updown = resblock_updown @@ -576,6 +594,9 @@ def __init__( num_head_channels=num_head_channels, norm_num_groups=norm_num_groups, norm_eps=norm_eps, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) ) @@ -642,7 +663,11 @@ class CrossAttnDownBlock(nn.Module): transformer_num_layers: number of layers of Transformer blocks to use. cross_attention_dim: number of context dimensions to use. upcast_attention: if True, upcast attention operations to full precision. - dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers + dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers. + include_fc: whether to include the final linear layer. Default to True. + use_combined_linear: whether to use a single linear layer for qkv projection, default to False. + use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). """ def __init__( @@ -662,6 +687,9 @@ def __init__( cross_attention_dim: int | None = None, upcast_attention: bool = False, dropout_cattn: float = 0.0, + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__() self.resblock_updown = resblock_updown @@ -694,6 +722,9 @@ def __init__( cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, dropout=dropout_cattn, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) ) @@ -751,6 +782,10 @@ class AttnMidBlock(nn.Module): norm_num_groups: number of groups for the group normalization. norm_eps: epsilon for the group normalization. num_head_channels: number of channels in each attention head. + include_fc: whether to include the final linear layer. Default to True. + use_combined_linear: whether to use a single linear layer for qkv projection, default to False. + use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). """ def __init__( @@ -761,6 +796,9 @@ def __init__( norm_num_groups: int = 32, norm_eps: float = 1e-6, num_head_channels: int = 1, + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__() @@ -778,6 +816,9 @@ def __init__( num_head_channels=num_head_channels, norm_num_groups=norm_num_groups, norm_eps=norm_eps, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) self.resnet_2 = DiffusionUNetResnetBlock( @@ -814,6 +855,10 @@ class CrossAttnMidBlock(nn.Module): transformer_num_layers: number of layers of Transformer blocks to use. cross_attention_dim: number of context dimensions to use. upcast_attention: if True, upcast attention operations to full precision. + include_fc: whether to include the final linear layer. Default to True. 
+ use_combined_linear: whether to use a single linear layer for qkv projection, default to False. + use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). """ def __init__( @@ -828,6 +873,9 @@ def __init__( cross_attention_dim: int | None = None, upcast_attention: bool = False, dropout_cattn: float = 0.0, + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__() @@ -850,6 +898,9 @@ def __init__( cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, dropout=dropout_cattn, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) self.resnet_2 = DiffusionUNetResnetBlock( spatial_dims=spatial_dims, @@ -995,6 +1046,10 @@ class AttnUpBlock(nn.Module): add_upsample: if True add downsample block. resblock_updown: if True use residual blocks for upsampling. num_head_channels: number of channels in each attention head. + include_fc: whether to include the final linear layer. Default to True. + use_combined_linear: whether to use a single linear layer for qkv projection, default to False. + use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). """ def __init__( @@ -1010,6 +1065,9 @@ def __init__( add_upsample: bool = True, resblock_updown: bool = False, num_head_channels: int = 1, + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__() self.resblock_updown = resblock_updown @@ -1038,6 +1096,9 @@ def __init__( num_head_channels=num_head_channels, norm_num_groups=norm_num_groups, norm_eps=norm_eps, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) ) @@ -1122,7 +1183,11 @@ class CrossAttnUpBlock(nn.Module): transformer_num_layers: number of layers of Transformer blocks to use. cross_attention_dim: number of context dimensions to use. upcast_attention: if True, upcast attention operations to full precision. - dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers + dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers. + include_fc: whether to include the final linear layer. Default to True. + use_combined_linear: whether to use a single linear layer for qkv projection, default to False. + use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). 
""" def __init__( @@ -1142,6 +1207,9 @@ def __init__( cross_attention_dim: int | None = None, upcast_attention: bool = False, dropout_cattn: float = 0.0, + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__() self.resblock_updown = resblock_updown @@ -1175,6 +1243,9 @@ def __init__( cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, dropout=dropout_cattn, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) ) @@ -1256,6 +1327,9 @@ def get_down_block( cross_attention_dim: int | None, upcast_attention: bool = False, dropout_cattn: float = 0.0, + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> nn.Module: if with_attn: return AttnDownBlock( @@ -1269,6 +1343,9 @@ def get_down_block( add_downsample=add_downsample, resblock_updown=resblock_updown, num_head_channels=num_head_channels, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) elif with_cross_attn: return CrossAttnDownBlock( @@ -1286,6 +1363,9 @@ def get_down_block( cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, dropout_cattn=dropout_cattn, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) else: return DownBlock( @@ -1313,6 +1393,9 @@ def get_mid_block( cross_attention_dim: int | None, upcast_attention: bool = False, dropout_cattn: float = 0.0, + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> nn.Module: if with_conditioning: return CrossAttnMidBlock( @@ -1326,6 +1409,9 @@ def get_mid_block( cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, dropout_cattn=dropout_cattn, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) else: return AttnMidBlock( @@ -1335,6 +1421,9 @@ def get_mid_block( norm_num_groups=norm_num_groups, norm_eps=norm_eps, num_head_channels=num_head_channels, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) @@ -1356,6 +1445,9 @@ def get_up_block( cross_attention_dim: int | None, upcast_attention: bool = False, dropout_cattn: float = 0.0, + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> nn.Module: if with_attn: return AttnUpBlock( @@ -1370,6 +1462,9 @@ def get_up_block( add_upsample=add_upsample, resblock_updown=resblock_updown, num_head_channels=num_head_channels, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) elif with_cross_attn: return CrossAttnUpBlock( @@ -1388,6 +1483,9 @@ def get_up_block( cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, dropout_cattn=dropout_cattn, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) else: return UpBlock( @@ -1425,9 +1523,13 @@ class DiffusionModelUNet(nn.Module): transformer_num_layers: number of layers of Transformer blocks to use. cross_attention_dim: number of context dimensions to use. num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds` - classes. + classes. upcast_attention: if True, upcast attention operations to full precision. 
- dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers + dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers. + include_fc: whether to include the final linear layer. Default to True. + use_combined_linear: whether to use a single linear layer for qkv projection, default to True. + use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). """ def __init__( @@ -1448,6 +1550,9 @@ def __init__( num_class_embeds: int | None = None, upcast_attention: bool = False, dropout_cattn: float = 0.0, + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__() if with_conditioning is True and cross_attention_dim is None: @@ -1542,6 +1647,9 @@ def __init__( cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, dropout_cattn=dropout_cattn, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) self.down_blocks.append(down_block) @@ -1559,6 +1667,9 @@ def __init__( cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, dropout_cattn=dropout_cattn, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) # up @@ -1593,6 +1704,9 @@ def __init__( cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, dropout_cattn=dropout_cattn, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) self.up_blocks.append(up_block) @@ -1788,6 +1902,9 @@ def __init__( cross_attention_dim: int | None = None, num_class_embeds: int | None = None, upcast_attention: bool = False, + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__() if with_conditioning is True and cross_attention_dim is None: @@ -1872,6 +1989,9 @@ def __init__( transformer_num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) self.down_blocks.append(down_block) diff --git a/monai/networks/nets/spade_autoencoderkl.py b/monai/networks/nets/spade_autoencoderkl.py index d5794a9227..cc8909194a 100644 --- a/monai/networks/nets/spade_autoencoderkl.py +++ b/monai/networks/nets/spade_autoencoderkl.py @@ -137,6 +137,10 @@ class SPADEDecoder(nn.Module): label_nc: number of semantic channels for SPADE normalisation. with_nonlocal_attn: if True use non-local attention block. spade_intermediate_channels: number of intermediate channels for SPADE block layer. + include_fc: whether to include the final linear layer. Default to True. + use_combined_linear: whether to use a single linear layer for qkv projection, default to False. + use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). 
""" def __init__( @@ -152,6 +156,9 @@ def __init__( label_nc: int, with_nonlocal_attn: bool = True, spade_intermediate_channels: int = 128, + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__() self.spatial_dims = spatial_dims @@ -200,6 +207,9 @@ def __init__( num_channels=reversed_block_out_channels[0], norm_num_groups=norm_num_groups, norm_eps=norm_eps, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) ) blocks.append( @@ -243,6 +253,9 @@ def __init__( num_channels=block_in_ch, norm_num_groups=norm_num_groups, norm_eps=norm_eps, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) ) @@ -331,6 +344,9 @@ def __init__( with_encoder_nonlocal_attn: bool = True, with_decoder_nonlocal_attn: bool = True, spade_intermediate_channels: int = 128, + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__() @@ -360,6 +376,9 @@ def __init__( norm_eps=norm_eps, attention_levels=attention_levels, with_nonlocal_attn=with_encoder_nonlocal_attn, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) self.decoder = SPADEDecoder( spatial_dims=spatial_dims, @@ -373,6 +392,9 @@ def __init__( label_nc=label_nc, with_nonlocal_attn=with_decoder_nonlocal_attn, spade_intermediate_channels=spade_intermediate_channels, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) self.quant_conv_mu = Convolution( spatial_dims=spatial_dims, diff --git a/monai/networks/nets/spade_diffusion_model_unet.py b/monai/networks/nets/spade_diffusion_model_unet.py index 75d1687df3..a9609b1d39 100644 --- a/monai/networks/nets/spade_diffusion_model_unet.py +++ b/monai/networks/nets/spade_diffusion_model_unet.py @@ -325,6 +325,10 @@ class SPADEAttnUpBlock(nn.Module): resblock_updown: if True use residual blocks for upsampling. num_head_channels: number of channels in each attention head. spade_intermediate_channels: number of intermediate channels for SPADE block layer + include_fc: whether to include the final linear layer. Default to True. + use_combined_linear: whether to use a single linear layer for qkv projection, default to False. + use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). """ def __init__( @@ -342,6 +346,9 @@ def __init__( resblock_updown: bool = False, num_head_channels: int = 1, spade_intermediate_channels: int = 128, + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__() self.resblock_updown = resblock_updown @@ -371,6 +378,9 @@ def __init__( num_head_channels=num_head_channels, norm_num_groups=norm_num_groups, norm_eps=norm_eps, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) ) @@ -457,6 +467,8 @@ class SPADECrossAttnUpBlock(nn.Module): cross_attention_dim: number of context dimensions to use. upcast_attention: if True, upcast attention operations to full precision. spade_intermediate_channels: number of intermediate channels for SPADE block layer. 
+ use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism. + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). """ def __init__( @@ -477,6 +489,9 @@ def __init__( cross_attention_dim: int | None = None, upcast_attention: bool = False, spade_intermediate_channels: int = 128, + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__() self.resblock_updown = resblock_updown @@ -510,6 +525,9 @@ def __init__( num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) ) @@ -592,6 +610,9 @@ def get_spade_up_block( cross_attention_dim: int | None, upcast_attention: bool = False, spade_intermediate_channels: int = 128, + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> nn.Module: if with_attn: return SPADEAttnUpBlock( @@ -608,6 +629,9 @@ def get_spade_up_block( resblock_updown=resblock_updown, num_head_channels=num_head_channels, spade_intermediate_channels=spade_intermediate_channels, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) elif with_cross_attn: return SPADECrossAttnUpBlock( @@ -627,6 +651,7 @@ def get_spade_up_block( cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, spade_intermediate_channels=spade_intermediate_channels, + use_flash_attention=use_flash_attention, ) else: return SPADEUpBlock( @@ -667,9 +692,11 @@ class SPADEDiffusionModelUNet(nn.Module): transformer_num_layers: number of layers of Transformer blocks to use. cross_attention_dim: number of context dimensions to use. num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds` - classes. + classes. upcast_attention: if True, upcast attention operations to full precision. - spade_intermediate_channels: number of intermediate channels for SPADE block layer + spade_intermediate_channels: number of intermediate channels for SPADE block layer. + use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). 
""" def __init__( @@ -691,6 +718,9 @@ def __init__( num_class_embeds: int | None = None, upcast_attention: bool = False, spade_intermediate_channels: int = 128, + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__() if with_conditioning is True and cross_attention_dim is None: @@ -783,6 +813,9 @@ def __init__( transformer_num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) self.down_blocks.append(down_block) @@ -799,6 +832,9 @@ def __init__( transformer_num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) # up @@ -834,6 +870,7 @@ def __init__( upcast_attention=upcast_attention, label_nc=label_nc, spade_intermediate_channels=spade_intermediate_channels, + use_flash_attention=use_flash_attention, ) self.up_blocks.append(up_block) diff --git a/monai/networks/nets/transformer.py b/monai/networks/nets/transformer.py index 1af725abda..cc51436f10 100644 --- a/monai/networks/nets/transformer.py +++ b/monai/networks/nets/transformer.py @@ -51,6 +51,10 @@ class DecoderOnlyTransformer(nn.Module): attn_layers_heads: Number of attention heads. with_cross_attention: Whether to use cross attention for conditioning. embedding_dropout_rate: Dropout rate for the embedding. + include_fc: whether to include the final linear layer. Default to True. + use_combined_linear: whether to use a single linear layer for qkv projection, default to True. + use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). 
""" def __init__( @@ -62,6 +66,9 @@ def __init__( attn_layers_heads: int, with_cross_attention: bool = False, embedding_dropout_rate: float = 0.0, + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__() self.num_tokens = num_tokens @@ -86,6 +93,9 @@ def __init__( causal=True, sequence_length=max_seq_len, with_cross_attention=with_cross_attention, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) for _ in range(attn_layers_depth) ] diff --git a/tests/test_crossattention.py b/tests/test_crossattention.py index 4888397308..e034e42290 100644 --- a/tests/test_crossattention.py +++ b/tests/test_crossattention.py @@ -170,12 +170,7 @@ def test_access_attn_matrix(self): @skipUnless(has_einops, "Requires einops") @SkipIfBeforePyTorchVersion((2, 0)) def test_flash_attention(self, causal): - input_param = { - "hidden_size": 128, - "num_heads": 1, - "causal": causal, - "sequence_length": 16 if causal else None, - } + input_param = {"hidden_size": 128, "num_heads": 1, "causal": causal, "sequence_length": 16 if causal else None} device = "cuda:0" if torch.cuda.is_available() else "cpu" block_w_flash_attention = CrossAttentionBlock(**input_param, use_flash_attention=True).to(device) block_wo_flash_attention = CrossAttentionBlock(**input_param, use_flash_attention=False).to(device) diff --git a/tests/test_selfattention.py b/tests/test_selfattention.py index eda049ea98..88919fd8b1 100644 --- a/tests/test_selfattention.py +++ b/tests/test_selfattention.py @@ -178,7 +178,7 @@ def count_sablock_params(*args, **kwargs): nparams_default_more_heads = count_sablock_params(hidden_size=hidden_size, num_heads=num_heads * 2) self.assertEqual(nparams_default, nparams_default_more_heads) - @parameterized.expand([[True, False], [True, True], [False, True],[False, False]]) + @parameterized.expand([[True, False], [True, True], [False, True], [False, False]]) @skipUnless(has_einops, "Requires einops") @SkipIfBeforePyTorchVersion((2, 0)) def test_script(self, include_fc, use_combined_linear): From 1c5599ddebb949dee6e165bc0f6259d9fde0d5dd Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Fri, 9 Aug 2024 14:52:15 +0800 Subject: [PATCH 28/28] fix state dict Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/networks/nets/autoencoderkl.py | 34 +++++++++------------ monai/networks/nets/controlnet.py | 23 +++++--------- monai/networks/nets/diffusion_model_unet.py | 24 +++++---------- monai/networks/nets/transformer.py | 24 +++++---------- 4 files changed, 36 insertions(+), 69 deletions(-) diff --git a/monai/networks/nets/autoencoderkl.py b/monai/networks/nets/autoencoderkl.py index e6a9da9b9e..836027796f 100644 --- a/monai/networks/nets/autoencoderkl.py +++ b/monai/networks/nets/autoencoderkl.py @@ -704,27 +704,18 @@ def load_old_state_dict(self, old_state_dict: dict, verbose=False) -> None: # copy over all matching keys for k in new_state_dict: if k in old_state_dict: - new_state_dict[k] = old_state_dict[k] + new_state_dict[k] = old_state_dict.pop(k) # fix the attention blocks - attention_blocks = [k.replace(".attn.qkv.weight", "") for k in new_state_dict if "attn.qkv.weight" in k] + attention_blocks = [k.replace(".attn.to_q.weight", "") for k in new_state_dict if "attn.to_q.weight" in k] for block in attention_blocks: - new_state_dict[f"{block}.attn.qkv.weight"] = torch.cat( - [ - old_state_dict[f"{block}.to_q.weight"], - 
old_state_dict[f"{block}.to_k.weight"], - old_state_dict[f"{block}.to_v.weight"], - ], - dim=0, - ) - new_state_dict[f"{block}.attn.qkv.bias"] = torch.cat( - [ - old_state_dict[f"{block}.to_q.bias"], - old_state_dict[f"{block}.to_k.bias"], - old_state_dict[f"{block}.to_v.bias"], - ], - dim=0, - ) + new_state_dict[f"{block}.attn.to_q.weight"] = old_state_dict.pop(f"{block}.to_q.weight") + new_state_dict[f"{block}.attn.to_k.weight"] = old_state_dict.pop(f"{block}.to_k.weight") + new_state_dict[f"{block}.attn.to_v.weight"] = old_state_dict.pop(f"{block}.to_v.weight") + new_state_dict[f"{block}.attn.to_q.bias"] = old_state_dict.pop(f"{block}.to_q.bias") + new_state_dict[f"{block}.attn.to_k.bias"] = old_state_dict.pop(f"{block}.to_k.bias") + new_state_dict[f"{block}.attn.to_v.bias"] = old_state_dict.pop(f"{block}.to_v.bias") + # old version did not have a projection so set these to the identity new_state_dict[f"{block}.attn.out_proj.weight"] = torch.eye( new_state_dict[f"{block}.attn.out_proj.weight"].shape[0] @@ -737,5 +728,8 @@ def load_old_state_dict(self, old_state_dict: dict, verbose=False) -> None: for k in new_state_dict: if "postconv" in k: old_name = k.replace("postconv", "conv") - new_state_dict[k] = old_state_dict[old_name] - self.load_state_dict(new_state_dict) + new_state_dict[k] = old_state_dict.pop(old_name) + if verbose: + # print all remaining keys in old_state_dict + print("remaining keys in old_state_dict:", old_state_dict.keys()) + self.load_state_dict(new_state_dict, strict=True) diff --git a/monai/networks/nets/controlnet.py b/monai/networks/nets/controlnet.py index ca0977717c..8b08eaae10 100644 --- a/monai/networks/nets/controlnet.py +++ b/monai/networks/nets/controlnet.py @@ -454,25 +454,16 @@ def load_old_state_dict(self, old_state_dict: dict, verbose=False) -> None: # copy over all matching keys for k in new_state_dict: if k in old_state_dict: - new_state_dict[k] = old_state_dict[k] + new_state_dict[k] = old_state_dict.pop(k) # fix the attention blocks - attention_blocks = [k.replace(".attn1.qkv.weight", "") for k in new_state_dict if "attn1.qkv.weight" in k] + attention_blocks = [k.replace(".out_proj.weight", "") for k in new_state_dict if "out_proj.weight" in k] for block in attention_blocks: - new_state_dict[f"{block}.attn1.qkv.weight"] = torch.cat( - [ - old_state_dict[f"{block}.attn1.to_q.weight"], - old_state_dict[f"{block}.attn1.to_k.weight"], - old_state_dict[f"{block}.attn1.to_v.weight"], - ], - dim=0, - ) - # projection - new_state_dict[f"{block}.attn1.out_proj.weight"] = old_state_dict[f"{block}.attn1.to_out.0.weight"] - new_state_dict[f"{block}.attn1.out_proj.bias"] = old_state_dict[f"{block}.attn1.to_out.0.bias"] - - new_state_dict[f"{block}.attn2.out_proj.weight"] = old_state_dict[f"{block}.attn2.to_out.0.weight"] - new_state_dict[f"{block}.attn2.out_proj.bias"] = old_state_dict[f"{block}.attn2.to_out.0.bias"] + new_state_dict[f"{block}.out_proj.weight"] = old_state_dict.pop(f"{block}.to_out.0.weight") + new_state_dict[f"{block}.out_proj.bias"] = old_state_dict.pop(f"{block}.to_out.0.bias") + if verbose: + # print all remaining keys in old_state_dict + print("remaining keys in old_state_dict:", old_state_dict.keys()) self.load_state_dict(new_state_dict) diff --git a/monai/networks/nets/diffusion_model_unet.py b/monai/networks/nets/diffusion_model_unet.py index d924c9c800..f57fe251d2 100644 --- a/monai/networks/nets/diffusion_model_unet.py +++ b/monai/networks/nets/diffusion_model_unet.py @@ -1834,31 +1834,23 @@ def load_old_state_dict(self, 
old_state_dict: dict, verbose=False) -> None: # copy over all matching keys for k in new_state_dict: if k in old_state_dict: - new_state_dict[k] = old_state_dict[k] + new_state_dict[k] = old_state_dict.pop(k) # fix the attention blocks - attention_blocks = [k.replace(".attn1.qkv.weight", "") for k in new_state_dict if "attn1.qkv.weight" in k] + attention_blocks = [k.replace(".out_proj.weight", "") for k in new_state_dict if "out_proj.weight" in k] for block in attention_blocks: - new_state_dict[f"{block}.attn1.qkv.weight"] = torch.cat( - [ - old_state_dict[f"{block}.attn1.to_q.weight"], - old_state_dict[f"{block}.attn1.to_k.weight"], - old_state_dict[f"{block}.attn1.to_v.weight"], - ], - dim=0, - ) - # projection - new_state_dict[f"{block}.attn1.out_proj.weight"] = old_state_dict[f"{block}.attn1.to_out.0.weight"] - new_state_dict[f"{block}.attn1.out_proj.bias"] = old_state_dict[f"{block}.attn1.to_out.0.bias"] + new_state_dict[f"{block}.out_proj.weight"] = old_state_dict.pop(f"{block}.to_out.0.weight") + new_state_dict[f"{block}.out_proj.bias"] = old_state_dict.pop(f"{block}.to_out.0.bias") - new_state_dict[f"{block}.attn2.out_proj.weight"] = old_state_dict[f"{block}.attn2.to_out.0.weight"] - new_state_dict[f"{block}.attn2.out_proj.bias"] = old_state_dict[f"{block}.attn2.to_out.0.bias"] # fix the upsample conv blocks which were renamed postconv for k in new_state_dict: if "postconv" in k: old_name = k.replace("postconv", "conv") - new_state_dict[k] = old_state_dict[old_name] + new_state_dict[k] = old_state_dict.pop(old_name) + if verbose: + # print all remaining keys in old_state_dict + print("remaining keys in old_state_dict:", old_state_dict.keys()) self.load_state_dict(new_state_dict) diff --git a/monai/networks/nets/transformer.py b/monai/networks/nets/transformer.py index cc51436f10..3a278c112a 100644 --- a/monai/networks/nets/transformer.py +++ b/monai/networks/nets/transformer.py @@ -143,25 +143,15 @@ def load_old_state_dict(self, old_state_dict: dict, verbose=False) -> None: # copy over all matching keys for k in new_state_dict: if k in old_state_dict: - new_state_dict[k] = old_state_dict[k] - - # fix the attention blocks - attention_blocks = [k.replace(".attn.qkv.weight", "") for k in new_state_dict if "attn.qkv.weight" in k] - for block in attention_blocks: - new_state_dict[f"{block}.attn.qkv.weight"] = torch.cat( - [ - old_state_dict[f"{block}.attn.to_q.weight"], - old_state_dict[f"{block}.attn.to_k.weight"], - old_state_dict[f"{block}.attn.to_v.weight"], - ], - dim=0, - ) + new_state_dict[k] = old_state_dict.pop(k) # fix the renamed norm blocks first norm2 -> norm_cross_attention , norm3 -> norm2 - for k in old_state_dict: + for k in list(old_state_dict.keys()): if "norm2" in k: - new_state_dict[k.replace("norm2", "norm_cross_attn")] = old_state_dict[k] + new_state_dict[k.replace("norm2", "norm_cross_attn")] = old_state_dict.pop(k) if "norm3" in k: - new_state_dict[k.replace("norm3", "norm2")] = old_state_dict[k] - + new_state_dict[k.replace("norm3", "norm2")] = old_state_dict.pop(k) + if verbose: + # print all remaining keys in old_state_dict + print("remaining keys in old_state_dict:", old_state_dict.keys()) self.load_state_dict(new_state_dict)
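
Usage sketch (illustrative only, not part of the patch series): the snippet below exercises the `include_fc` and `use_combined_linear` flags added in these patches, assuming a MONAI build that already contains them and that einops is installed. The parameter values mirror the unit tests in this series and are examples, not requirements.

    import torch
    from monai.networks.blocks.selfattention import SABlock

    # Build a self-attention block with the new flags:
    #   include_fc=False          -> skip the trailing output projection in forward()
    #   use_combined_linear=False -> separate to_q/to_k/to_v projections instead of one qkv layer
    block = SABlock(
        hidden_size=360,
        num_heads=4,
        dropout_rate=0.0,
        include_fc=False,
        use_combined_linear=False,
    )

    x = torch.randn(2, 512, 360)  # (batch, sequence length, hidden_size)
    out = block(x)
    print(out.shape)  # expected: torch.Size([2, 512, 360])

    # With use_combined_linear=False the projection weights are registered
    # under to_q/to_k/to_v rather than a single fused qkv tensor.
    print([k for k in block.state_dict() if k.startswith("to_")])

This separate to_q/to_k/to_v layout is the one the reworked load_old_state_dict in the final patch targets: it copies the old to_q/to_k/to_v keys over directly instead of concatenating them into a fused qkv weight.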