From 39a6f579538532c84ec704bb749a8191f1f17166 Mon Sep 17 00:00:00 2001
From: vgrau98
Date: Wed, 24 Jan 2024 23:58:56 +0100
Subject: [PATCH 1/3] causal self attention

Signed-off-by: vgrau98
---
 monai/networks/blocks/selfattention.py | 20 ++++++++++++++++++++
 1 file changed, 20 insertions(+)

diff --git a/monai/networks/blocks/selfattention.py b/monai/networks/blocks/selfattention.py
index 3bef24b4e8..306ac534db 100644
--- a/monai/networks/blocks/selfattention.py
+++ b/monai/networks/blocks/selfattention.py
@@ -38,6 +38,8 @@ def __init__(
         save_attn: bool = False,
         rel_pos_embedding: Optional[str] = None,
         input_size: Optional[Tuple] = None,
+        causal: bool = False,
+        sequence_length: int | None = None,
     ) -> None:
         """
         Args:
@@ -49,6 +51,8 @@ def __init__(
                 For now only "decomposed" is supported (see https://arxiv.org/abs/2112.01526). 2D and 3D are supported.
             input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative positional
                 parameter size.
+            causal (bool, optional): whether to use causal attention. If True, `sequence_length` has to be set. Defaults to False.
+            sequence_length (int, optional): if `causal` is True, the sequence length of the input must be specified.
             save_attn (bool, optional): to make accessible the attention matrix. Defaults to False.

         """
@@ -61,6 +65,9 @@ def __init__(
         if hidden_size % num_heads != 0:
             raise ValueError("hidden size should be divisible by num_heads.")

+        if causal and sequence_length is None:
+            raise ValueError("sequence_length is necessary for causal attention.")
+
         self.num_heads = num_heads
         self.out_proj = nn.Linear(hidden_size, hidden_size)
         self.qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias)
@@ -70,6 +77,8 @@ def __init__(
         self.drop_weights = nn.Dropout(dropout_rate)
         self.head_dim = hidden_size // num_heads
         self.scale = self.head_dim**-0.5
+        self.causal = causal
+        self.sequence_length = sequence_length
         self.save_attn = save_attn
         self.att_mat = torch.Tensor()
         self.rel_positional_embedding = (
@@ -79,6 +88,14 @@ def __init__(
         )
         self.input_size = input_size

+        if causal and sequence_length is not None:
+            # causal mask to ensure that attention is only applied to the left in the input sequence
+            self.register_buffer(
+                "causal_mask",
+                torch.tril(torch.ones(sequence_length, sequence_length)).view(1, 1, sequence_length, sequence_length),
+            )
+            self.causal_mask: torch.Tensor
+
     def forward(self, x: torch.Tensor):
         """
         Args:
             x: input tensor. B x (s_dim_1 * ... * s_dim_n) x C

         Return:
             torch.Tensor: B x (s_dim_1 * ... * s_dim_n) x C
         """
+        _, t, _ = x.size()
         output = self.input_rearrange(self.qkv(x))
         q, k, v = output[0], output[1], output[2]
         att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale

         # apply relative positional embedding if defined
         att_mat = self.rel_positional_embedding(x, att_mat, q) if self.rel_positional_embedding is not None else att_mat
+        # apply causal mask if set
+        att_mat = att_mat.masked_fill(self.causal_mask[:, :, :t, :t] == 0, float("-inf")) if self.causal else att_mat

         att_mat = att_mat.softmax(dim=-1)
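A minimal standalone sketch of the masking logic this patch adds, using only PyTorch; the tensor names and sizes below are illustrative and not part of the patch:

import torch

batch, heads, seq = 1, 2, 4
att_mat = torch.randn(batch, heads, seq, seq)  # scaled attention scores, as produced by the einsum above

# lower-triangular mask: position i may only attend to positions j <= i
causal_mask = torch.tril(torch.ones(seq, seq)).view(1, 1, seq, seq)

# masked positions are set to -inf so that softmax assigns them zero weight
att_mat = att_mat.masked_fill(causal_mask[:, :, :seq, :seq] == 0, float("-inf"))
att_mat = att_mat.softmax(dim=-1)

assert torch.allclose(att_mat[0, 0].triu(1), torch.zeros(seq, seq))  # strictly upper triangle is zero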
From 5b3c4f3c1bd1747b7b6a3551b0ac20affcd1b6fe Mon Sep 17 00:00:00 2001
From: vgrau98
Date: Thu, 25 Jan 2024 00:06:06 +0100
Subject: [PATCH 2/3] causal selfattention tests

Signed-off-by: vgrau98
---
 tests/test_selfattention.py | 27 +++++++++++++++------------
 1 file changed, 15 insertions(+), 12 deletions(-)

diff --git a/tests/test_selfattention.py b/tests/test_selfattention.py
index 0d0553ed2c..277ba7faf9 100644
--- a/tests/test_selfattention.py
+++ b/tests/test_selfattention.py
@@ -31,18 +31,21 @@
     for num_heads in [4, 6, 8, 12]:
         for rel_pos_embedding in [None, RelPosEmbedding.DECOMPOSED]:
             for input_size in [(16, 32), (8, 8, 8)]:
-                test_case = [
-                    {
-                        "hidden_size": hidden_size,
-                        "num_heads": num_heads,
-                        "dropout_rate": dropout_rate,
-                        "rel_pos_embedding": rel_pos_embedding,
-                        "input_size": input_size,
-                    },
-                    (2, 512, hidden_size),
-                    (2, 512, hidden_size),
-                ]
-                TEST_CASE_SABLOCK.append(test_case)
+                for causal in [False, True]:
+                    test_case = [
+                        {
+                            "hidden_size": hidden_size,
+                            "num_heads": num_heads,
+                            "dropout_rate": dropout_rate,
+                            "rel_pos_embedding": rel_pos_embedding,
+                            "input_size": input_size,
+                            "causal": causal,
+                            "sequence_length": 512,
+                        },
+                        (2, 512, hidden_size),
+                        (2, 512, hidden_size),
+                    ]
+                    TEST_CASE_SABLOCK.append(test_case)


 class TestResBlock(unittest.TestCase):
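The tests above exercise the causal configuration through the parameterized cases. For reference, a short usage sketch assuming a MONAI checkout that includes PATCH 1/3; the parameter values mirror the test cases:

import torch
from monai.networks.blocks.selfattention import SABlock

block = SABlock(hidden_size=128, num_heads=4, dropout_rate=0.0, causal=True, sequence_length=512)
x = torch.randn(2, 512, 128)  # B x (sequence length) x hidden_size
with torch.no_grad():
    y = block(x)
print(y.shape)  # torch.Size([2, 512, 128])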
From 440a89353e7ebffa01463902be0e9b8c4d690dcd Mon Sep 17 00:00:00 2001
From: vgrau98
Date: Sat, 27 Apr 2024 13:10:23 +0200
Subject: [PATCH 3/3] integrate flash attention usage

Signed-off-by: vgrau98
---
 monai/networks/blocks/selfattention.py | 56 ++++++++++++++++++++------
 tests/test_selfattention.py            | 29 +++++++++++++
 2 files changed, 72 insertions(+), 13 deletions(-)

diff --git a/monai/networks/blocks/selfattention.py b/monai/networks/blocks/selfattention.py
index 306ac534db..0a848e9ec5 100644
--- a/monai/networks/blocks/selfattention.py
+++ b/monai/networks/blocks/selfattention.py
@@ -12,6 +12,7 @@
 from __future__ import annotations

 from typing import Optional, Tuple
+import warnings

 import torch
 import torch.nn as nn
@@ -19,6 +20,7 @@
 from monai.networks.layers.utils import get_rel_pos_embedding_layer
 from monai.utils import optional_import

+xops, has_xformers = optional_import("xformers.ops")
 Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange")


@@ -40,6 +42,7 @@ def __init__(
         input_size: Optional[Tuple] = None,
         causal: bool = False,
         sequence_length: int | None = None,
+        use_flash_attention: bool = False,
     ) -> None:
         """
         Args:
@@ -68,11 +71,23 @@ def __init__(
         if causal and sequence_length is None:
             raise ValueError("sequence_length is necessary for causal attention.")

+        if use_flash_attention and rel_pos_embedding is not None:
+            self.use_flash_attention = False
+            warnings.warn(
+                "use_flash_attention set to `False`: flash attention cannot be used together with a relative position embedding. Set `rel_pos_embedding` to `None` to use flash attention."
+            )
+        else:
+            self.use_flash_attention = use_flash_attention
+
+        if use_flash_attention and not has_xformers:
+            raise ValueError("use_flash_attention is True but xformers is not installed.")
+
         self.num_heads = num_heads
         self.out_proj = nn.Linear(hidden_size, hidden_size)
         self.qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias)
         self.input_rearrange = Rearrange("b h (qkv l d) -> qkv b l h d", qkv=3, l=num_heads)
         self.out_rearrange = Rearrange("b h l d -> b l (h d)")
+        self.dropout_rate = dropout_rate
         self.drop_output = nn.Dropout(dropout_rate)
         self.drop_weights = nn.Dropout(dropout_rate)
         self.head_dim = hidden_size // num_heads
         self.scale = self.head_dim**-0.5
@@ -105,24 +120,39 @@ def forward(self, x: torch.Tensor):
             torch.Tensor: B x (s_dim_1 * ... * s_dim_n) x C
         """
         _, t, _ = x.size()
-        output = self.input_rearrange(self.qkv(x))
+        output = self.input_rearrange(self.qkv(x))  # 3 x B x (s_dim_1 * ... * s_dim_n) x h x C/h
         q, k, v = output[0], output[1], output[2]
-        att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale
-        # apply relative positional embedding if defined
-        att_mat = self.rel_positional_embedding(x, att_mat, q) if self.rel_positional_embedding is not None else att_mat
-        # apply causal mask if set
-        att_mat = att_mat.masked_fill(self.causal_mask[:, :, :t, :t] == 0, float("-inf")) if self.causal else att_mat
+        if self.use_flash_attention:
+            x = xops.memory_efficient_attention(
+                query=q.transpose(1, 2).contiguous(),  # xformers expects (B, seq, heads, head_dim)
+                key=k.transpose(1, 2).contiguous(),
+                value=v.transpose(1, 2).contiguous(),
+                scale=self.scale,
+                p=self.dropout_rate if self.training else 0.0,  # apply attention dropout only in training mode
+                attn_bias=xops.LowerTriangularMask() if self.causal else None,
+            ).transpose(1, 2)  # back to (B, heads, seq, head_dim) for out_rearrange
+        else:
+            att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale
+
+            # apply relative positional embedding if defined
+            att_mat = (
+                self.rel_positional_embedding(x, att_mat, q) if self.rel_positional_embedding is not None else att_mat
+            )
+            # apply causal mask if set
+            att_mat = (
+                att_mat.masked_fill(self.causal_mask[:, :, :t, :t] == 0, float("-inf")) if self.causal else att_mat
+            )

-        att_mat = att_mat.softmax(dim=-1)
+            att_mat = att_mat.softmax(dim=-1)

-        if self.save_attn:
-            # no gradients and new tensor;
-            # https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html
-            self.att_mat = att_mat.detach()
+            if self.save_attn:
+                # no gradients and new tensor;
+                # https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html
+                self.att_mat = att_mat.detach()

-        att_mat = self.drop_weights(att_mat)
-        x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v)
+            att_mat = self.drop_weights(att_mat)
+            x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v)
         x = self.out_rearrange(x)
         x = self.out_proj(x)
         x = self.drop_output(x)
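A small sanity check of the layout handling in the flash attention branch above: xops.memory_efficient_attention takes query/key/value as (batch, seq, heads, head_dim), while SABlock's input_rearrange yields (batch, heads, seq, head_dim), hence the transposes. This sketch assumes xformers is installed and a CUDA device is available; it is not part of the patch:

import torch
import xformers.ops as xops

b, h, s, d = 2, 4, 16, 8
q, k, v = (torch.randn(b, h, s, d, device="cuda", dtype=torch.float16) for _ in range(3))
scale = d**-0.5

# reference path, mirroring SABlock's einsum formulation on (B, heads, seq, head_dim)
att = (torch.einsum("bhxd,bhyd->bhxy", q, k) * scale).softmax(dim=-1)
ref = torch.einsum("bhxy,bhyd->bhxd", att, v)

# xformers path: move to (B, seq, heads, head_dim) and back
out = xops.memory_efficient_attention(
    q.transpose(1, 2).contiguous(), k.transpose(1, 2).contiguous(), v.transpose(1, 2).contiguous(), scale=scale
).transpose(1, 2)

print(torch.allclose(ref, out, atol=1e-2))  # True within fp16 tolerance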
diff --git a/tests/test_selfattention.py b/tests/test_selfattention.py
index 277ba7faf9..33622d63c8 100644
--- a/tests/test_selfattention.py
+++ b/tests/test_selfattention.py
@@ -24,6 +24,7 @@
 from monai.utils import optional_import

 einops, has_einops = optional_import("einops")
+xops, has_xformers = optional_import("xformers.ops")

 TEST_CASE_SABLOCK = []
 for dropout_rate in np.linspace(0, 1, 4):
@@ -57,6 +58,34 @@ def test_shape(self, input_param, input_shape, expected_shape):
             result = net(torch.randn(input_shape))
             self.assertEqual(result.shape, expected_shape)

+    @skipUnless(has_xformers, "Requires xformers")
+    def test_flash_attention(self):
+        hidden_size = 360
+        num_heads = 4
+        dropout_rate = 0
+        input_shape = (2, 512, hidden_size)
+        expected_shape = (2, 512, hidden_size)
+        flash_attention_block = SABlock(hidden_size, num_heads, dropout_rate, use_flash_attention=True)
+        # use_flash_attention falls back to False here because a relative position embedding is requested at the same time
+        no_flash_attention_block = SABlock(
+            hidden_size,
+            num_heads,
+            dropout_rate,
+            use_flash_attention=True,
+            rel_pos_embedding=RelPosEmbedding.DECOMPOSED,
+            sequence_length=512,
+            input_size=(16, 32),
+        )
+
+        self.assertFalse(no_flash_attention_block.use_flash_attention)
+
+        with eval_mode(flash_attention_block):
+            result = flash_attention_block(torch.randn(input_shape))
+            self.assertEqual(result.shape, expected_shape)
+        with eval_mode(no_flash_attention_block):
+            result = no_flash_attention_block(torch.randn(input_shape))
+            self.assertEqual(result.shape, expected_shape)
+
     def test_ill_arg(self):
         with self.assertRaises(ValueError):
             SABlock(hidden_size=128, num_heads=12, dropout_rate=6.0)
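For completeness, a short usage sketch of the flash attention path exercised by test_flash_attention, assuming xformers is installed and a CUDA device is available; the values mirror the test:

import torch
from monai.networks.blocks.selfattention import SABlock

block = SABlock(hidden_size=360, num_heads=4, dropout_rate=0.0, use_flash_attention=True).cuda().eval()
x = torch.randn(2, 512, 360, device="cuda")
with torch.no_grad():
    y = block(x)
print(y.shape)  # torch.Size([2, 512, 360])

# Requesting use_flash_attention together with rel_pos_embedding falls back to the standard
# attention path and emits a warning, as the test above asserts.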