diff --git a/monai/networks/blocks/selfattention.py b/monai/networks/blocks/selfattention.py
index 3bef24b4e8..0a848e9ec5 100644
--- a/monai/networks/blocks/selfattention.py
+++ b/monai/networks/blocks/selfattention.py
@@ -12,6 +12,7 @@
 from __future__ import annotations
 
 from typing import Optional, Tuple
 
+import warnings
 import torch
 import torch.nn as nn
@@ -19,6 +20,7 @@
 from monai.networks.layers.utils import get_rel_pos_embedding_layer
 from monai.utils import optional_import
 
+xops, has_xformers = optional_import("xformers.ops")
 Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange")
 
 
@@ -38,6 +40,9 @@ def __init__(
         save_attn: bool = False,
         rel_pos_embedding: Optional[str] = None,
         input_size: Optional[Tuple] = None,
+        causal: bool = False,
+        sequence_length: int | None = None,
+        use_flash_attention: bool = False,
     ) -> None:
         """
         Args:
@@ -49,6 +54,8 @@ def __init__(
                 For now only "decomposed" is supported (see https://arxiv.org/abs/2112.01526). 2D and 3D are supported.
             input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative
                 positional parameter size.
+            causal (bool): whether to use causal attention. If True, `sequence_length` must be set.
+            sequence_length (int, optional): if causal is True, it is necessary to specify the sequence length.
             save_attn (bool, optional): to make accessible the attention matrix. Defaults to False.
 
         """
@@ -61,15 +68,32 @@ def __init__(
         if hidden_size % num_heads != 0:
             raise ValueError("hidden size should be divisible by num_heads.")
 
+        if causal and sequence_length is None:
+            raise ValueError("sequence_length is necessary for causal attention.")
+
+        if use_flash_attention and rel_pos_embedding is not None:
+            self.use_flash_attention = False
+            warnings.warn(
+                "flash attention set to `False`: flash attention can't be used with relative position embedding. Set `rel_pos_embedding` to `None` to use flash attention"
+            )
+        else:
+            self.use_flash_attention = use_flash_attention
+
+        if use_flash_attention and not has_xformers:
+            raise ValueError("use_flash_attention is True but xformers is not installed.")
+
         self.num_heads = num_heads
         self.out_proj = nn.Linear(hidden_size, hidden_size)
         self.qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias)
         self.input_rearrange = Rearrange("b h (qkv l d) -> qkv b l h d", qkv=3, l=num_heads)
         self.out_rearrange = Rearrange("b h l d -> b l (h d)")
+        self.dropout_rate = dropout_rate
         self.drop_output = nn.Dropout(dropout_rate)
         self.drop_weights = nn.Dropout(dropout_rate)
         self.head_dim = hidden_size // num_heads
         self.scale = self.head_dim**-0.5
+        self.causal = causal
+        self.sequence_length = sequence_length
         self.save_attn = save_attn
         self.att_mat = torch.Tensor()
         self.rel_positional_embedding = (
@@ -79,6 +103,14 @@ def __init__(
         )
         self.input_size = input_size
 
+        if causal and sequence_length is not None:
+            # causal mask to ensure that attention is only applied to the left in the input sequence
+            self.register_buffer(
+                "causal_mask",
+                torch.tril(torch.ones(sequence_length, sequence_length)).view(1, 1, sequence_length, sequence_length),
+            )
+            self.causal_mask: torch.Tensor
+
     def forward(self, x: torch.Tensor):
         """
         Args:
@@ -87,22 +119,40 @@ def forward(self, x: torch.Tensor):
 
         Return:
             torch.Tensor: B x (s_dim_1 * ... * s_dim_n) x C
         """
-        output = self.input_rearrange(self.qkv(x))
+        _, t, _ = x.size()
+        output = self.input_rearrange(self.qkv(x))  # 3 x B x (s_dim_1 * ... * s_dim_n) x h x C/h
         q, k, v = output[0], output[1], output[2]
-        att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale
-
-        # apply relative positional embedding if defined
-        att_mat = self.rel_positional_embedding(x, att_mat, q) if self.rel_positional_embedding is not None else att_mat
-
-        att_mat = att_mat.softmax(dim=-1)
-
-        if self.save_attn:
-            # no gradients and new tensor;
-            # https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html
-            self.att_mat = att_mat.detach()
-        att_mat = self.drop_weights(att_mat)
-        x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v)
+        if self.use_flash_attention:
+            x = xops.memory_efficient_attention(
+                query=q.contiguous(),
+                key=k.contiguous(),
+                value=v.contiguous(),
+                scale=self.scale,
+                p=self.dropout_rate,
+                attn_bias=xops.LowerTriangularMask() if self.causal else None,
+            )
+        else:
+            att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale
+
+            # apply relative positional embedding if defined
+            att_mat = (
+                self.rel_positional_embedding(x, att_mat, q) if self.rel_positional_embedding is not None else att_mat
+            )
+            # apply causal mask if set
+            att_mat = (
+                att_mat.masked_fill(self.causal_mask[:, :, :t, :t] == 0, float("-inf")) if self.causal else att_mat
+            )
+
+            att_mat = att_mat.softmax(dim=-1)
+
+            if self.save_attn:
+                # no gradients and new tensor;
+                # https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html
+                self.att_mat = att_mat.detach()
+
+            att_mat = self.drop_weights(att_mat)
+            x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v)
         x = self.out_rearrange(x)
         x = self.out_proj(x)
         x = self.drop_output(x)
diff --git a/tests/test_selfattention.py b/tests/test_selfattention.py
index 0d0553ed2c..33622d63c8 100644
--- a/tests/test_selfattention.py
+++ b/tests/test_selfattention.py
@@ -24,6 +24,7 @@
 from monai.utils import optional_import
 
 einops, has_einops = optional_import("einops")
+xops, has_xformers = optional_import("xformers.ops")
 
 TEST_CASE_SABLOCK = []
 for dropout_rate in np.linspace(0, 1, 4):
@@ -31,18 +32,21 @@
     for num_heads in [4, 6, 8, 12]:
         for rel_pos_embedding in [None, RelPosEmbedding.DECOMPOSED]:
             for input_size in [(16, 32), (8, 8, 8)]:
-                test_case = [
-                    {
-                        "hidden_size": hidden_size,
-                        "num_heads": num_heads,
-                        "dropout_rate": dropout_rate,
-                        "rel_pos_embedding": rel_pos_embedding,
-                        "input_size": input_size,
-                    },
-                    (2, 512, hidden_size),
-                    (2, 512, hidden_size),
-                ]
-                TEST_CASE_SABLOCK.append(test_case)
+                for causal in [False, True]:
+                    test_case = [
+                        {
+                            "hidden_size": hidden_size,
+                            "num_heads": num_heads,
+                            "dropout_rate": dropout_rate,
+                            "rel_pos_embedding": rel_pos_embedding,
+                            "input_size": input_size,
+                            "causal": causal,
+                            "sequence_length": 512,
+                        },
+                        (2, 512, hidden_size),
+                        (2, 512, hidden_size),
+                    ]
+                    TEST_CASE_SABLOCK.append(test_case)
 
 
 class TestResBlock(unittest.TestCase):
@@ -54,6 +58,34 @@ def test_shape(self, input_param, input_shape, expected_shape):
         result = net(torch.randn(input_shape))
         self.assertEqual(result.shape, expected_shape)
 
+    @skipUnless(has_xformers, "Requires xformers")
+    def test_flash_attention(self):
+        hidden_size = 360
+        num_heads = 4
+        dropout_rate = 0
+        input_shape = (2, 512, hidden_size)
+        expected_shape = (2, 512, hidden_size)
+        flash_attention_block = SABlock(hidden_size, num_heads, dropout_rate, use_flash_attention=True)
+        # flash attention set to false because of conflict using relative position embedding at the same time
+        no_flash_attention_block = SABlock(
+            hidden_size,
+            num_heads,
+            dropout_rate,
+            use_flash_attention=True,
+            rel_pos_embedding=RelPosEmbedding.DECOMPOSED,
+            sequence_length=512,
+            input_size=(16, 32),
+        )
+
+        self.assertFalse(no_flash_attention_block.use_flash_attention)
+
+        with eval_mode(flash_attention_block):
+            result = flash_attention_block(torch.randn(input_shape))
+            self.assertEqual(result.shape, expected_shape)
+        with eval_mode(no_flash_attention_block):
+            result = no_flash_attention_block(torch.randn(input_shape))
+            self.assertEqual(result.shape, expected_shape)
+
     def test_ill_arg(self):
         with self.assertRaises(ValueError):
             SABlock(hidden_size=128, num_heads=12, dropout_rate=6.0)
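
Usage sketch (not part of the patch): a minimal example of the new `causal`/`sequence_length` and `use_flash_attention` options that this diff adds to `SABlock`. It assumes MONAI with the patch applied and, for the flash-attention path, that `xformers` is installed; the tensor sizes are illustrative only.

    import torch

    from monai.networks.blocks.selfattention import SABlock

    # causal self-attention: each position attends only to itself and earlier positions
    causal_block = SABlock(hidden_size=360, num_heads=4, dropout_rate=0.0, causal=True, sequence_length=512)
    x = torch.randn(2, 512, 360)  # B x sequence_length x hidden_size
    out = causal_block(x)         # output keeps the input shape: 2 x 512 x 360

    # memory-efficient (flash) attention path: requires xformers and rel_pos_embedding=None
    flash_block = SABlock(hidden_size=360, num_heads=4, dropout_rate=0.0, use_flash_attention=True)
    out = flash_block(x)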