diff --git a/generative/networks/blocks/__init__.py b/generative/networks/blocks/__init__.py
new file mode 100644
index 00000000..9036f087
--- /dev/null
+++ b/generative/networks/blocks/__init__.py
@@ -0,0 +1,14 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from .selfattention import SABlock
diff --git a/generative/networks/blocks/selfattention.py b/generative/networks/blocks/selfattention.py
new file mode 100644
index 00000000..396bffcc
--- /dev/null
+++ b/generative/networks/blocks/selfattention.py
@@ -0,0 +1,103 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import math
+
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+
+
+class SABlock(nn.Module):
+    """
+    A self-attention block, based on: "Dosovitskiy et al.,
+    An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale"
+
+    Args:
+        hidden_size: dimension of hidden layer.
+        num_heads: number of attention heads.
+        dropout_rate: dropout ratio. Defaults to no dropout.
+        qkv_bias: bias term for the qkv linear layer.
+        causal: whether to use causal attention.
+        sequence_length: if causal is True, it is necessary to specify the sequence length.
+    """
+
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        dropout_rate: float = 0.0,
+        qkv_bias: bool = False,
+        causal: bool = False,
+        sequence_length: int | None = None,
+    ) -> None:
+        super().__init__()
+
+        if not (0 <= dropout_rate <= 1):
+            raise ValueError("dropout_rate should be between 0 and 1.")
+
+        if hidden_size % num_heads != 0:
+            raise ValueError("hidden size should be divisible by num_heads.")
+
+        if causal and sequence_length is None:
+            raise ValueError("sequence_length is necessary for causal attention.")
+
+        # output projection
+        self.out_proj = nn.Linear(hidden_size, hidden_size)
+        # key, query, value projections for all heads, but in a batch
+        self.qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias)
+        # regularization
+        self.drop_weights = nn.Dropout(dropout_rate)
+        self.drop_output = nn.Dropout(dropout_rate)
+
+        self.hidden_size = hidden_size
+        self.num_heads = num_heads
+        self.head_dim = hidden_size // num_heads
+        self.scale = 1.0 / math.sqrt(self.head_dim)
+        self.causal = causal
+        self.sequence_length = sequence_length
+
+        if causal and sequence_length is not None:
+            # causal mask to ensure that attention is only applied to the left in the input sequence
+            self.mask = torch.tril(torch.ones(sequence_length, sequence_length)).view(
+                1, 1, sequence_length, sequence_length
+            )
+        else:
+            self.mask = None
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        b, t, c = x.size()  # batch size, sequence length, embedding dimensionality (hidden_size)
+
+        if self.sequence_length is not None and t != self.sequence_length:
+            raise ValueError("sequence length should be equal to the one specified in the SABlock constructor.")
+
+        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
+        query, key, value = self.qkv(x).split(self.hidden_size, dim=2)
+        key = key.view(b, t, self.num_heads, c // self.num_heads).transpose(1, 2)  # (b, nh, t, hs)
+        query = query.view(b, t, self.num_heads, c // self.num_heads).transpose(1, 2)  # (b, nh, t, hs)
+        value = value.view(b, t, self.num_heads, c // self.num_heads).transpose(1, 2)  # (b, nh, t, hs)
+
+        # manual implementation of attention
+        attention_scores = (query @ key.transpose(-2, -1)) * self.scale
+
+        if self.causal:
+            attention_scores = attention_scores.masked_fill(self.mask[:, :, :t, :t] == 0, float("-inf"))
+
+        attention_probs = F.softmax(attention_scores, dim=-1)
+        attention_probs = self.drop_weights(attention_probs)
+        y = attention_probs @ value  # (b, nh, t, t) x (b, nh, t, hs) -> (b, nh, t, hs)
+        y = y.transpose(1, 2).contiguous().view(b, t, c)  # re-assemble all head outputs side by side
+
+        y = self.out_proj(y)
+        y = self.drop_output(y)
+        return y
diff --git a/tests/test_selfattention.py b/tests/test_selfattention.py
new file mode 100644
index 00000000..09f2feb6
--- /dev/null
+++ b/tests/test_selfattention.py
@@ -0,0 +1,66 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import unittest
+
+import torch
+from monai.networks import eval_mode
+from parameterized import parameterized
+
+from generative.networks.blocks.selfattention import SABlock
+
+TEST_CASE_SABLOCK = [
+    [
+        {"hidden_size": 16, "num_heads": 8, "dropout_rate": 0.2, "causal": False, "sequence_length": None},
+        (2, 4, 16),
+        (2, 4, 16),
+    ],
+    [
+        {"hidden_size": 16, "num_heads": 8, "dropout_rate": 0.2, "causal": True, "sequence_length": 4},
+        (2, 4, 16),
+        (2, 4, 16),
+    ],
+]
+
+
+class TestSABlock(unittest.TestCase):
+    @parameterized.expand(TEST_CASE_SABLOCK)
+    def test_shape(self, input_param, input_shape, expected_shape):
+        net = SABlock(**input_param)
+        with eval_mode(net):
+            result = net(torch.randn(input_shape))
+            self.assertEqual(result.shape, expected_shape)
+
+    def test_ill_arg(self):
+        with self.assertRaises(ValueError):
+            SABlock(hidden_size=12, num_heads=4, dropout_rate=6.0)
+
+        with self.assertRaises(ValueError):
+            SABlock(hidden_size=12, num_heads=4, dropout_rate=-6.0)
+
+        with self.assertRaises(ValueError):
+            SABlock(hidden_size=20, num_heads=8, dropout_rate=0.4)
+
+        with self.assertRaises(ValueError):
+            SABlock(hidden_size=12, num_heads=4, dropout_rate=0.4, causal=True, sequence_length=None)
+
+    def test_wrong_sequence_length(self):
+        net = SABlock(hidden_size=16, num_heads=4, dropout_rate=0.0, causal=True, sequence_length=6)
+        with self.assertRaises(ValueError):
+            with eval_mode(net):
+                result = net(torch.randn((2, 4, 16)))
+                self.assertEqual(result.shape, (2, 4, 16))
+
+
+if __name__ == "__main__":
+    unittest.main()
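
Note for reviewers: below is a minimal usage sketch of the new block, for illustration only and not part of the diff. It assumes the package introduced here is importable as generative.networks.blocks and shows a forward pass through a causal SABlock; the output keeps the input's (batch, sequence, hidden_size) shape.

    import torch

    from generative.networks.blocks import SABlock

    # causal self-attention; sequence_length must be given when causal=True
    block = SABlock(hidden_size=16, num_heads=8, dropout_rate=0.1, causal=True, sequence_length=4)
    block.eval()

    x = torch.randn(2, 4, 16)  # (batch size, sequence length, hidden_size)
    with torch.no_grad():
        y = block(x)

    print(y.shape)  # torch.Size([2, 4, 16]) -- same shape as the input

When causal=True, the input must have exactly sequence_length tokens; otherwise forward raises a ValueError, as exercised by test_wrong_sequence_length above.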