From 35e8c0f41487b27ddbbe53b31613434c5993523e Mon Sep 17 00:00:00 2001 From: Walter Hugo Lopez Pinaya Date: Fri, 3 Feb 2023 13:28:40 +0000 Subject: [PATCH 1/5] Add original SABlock Signed-off-by: Walter Hugo Lopez Pinaya --- generative/networks/blocks/selfattention.py | 65 +++++++++++++++++++++ 1 file changed, 65 insertions(+) create mode 100644 generative/networks/blocks/selfattention.py diff --git a/generative/networks/blocks/selfattention.py b/generative/networks/blocks/selfattention.py new file mode 100644 index 00000000..519c8c77 --- /dev/null +++ b/generative/networks/blocks/selfattention.py @@ -0,0 +1,65 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import torch +import torch.nn as nn + +from monai.utils import optional_import + +Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange") + + +class SABlock(nn.Module): + """ + A self-attention block, based on: "Dosovitskiy et al., + An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale " + """ + + def __init__(self, hidden_size: int, num_heads: int, dropout_rate: float = 0.0, qkv_bias: bool = False) -> None: + """ + Args: + hidden_size: dimension of hidden layer. + num_heads: number of attention heads. + dropout_rate: faction of the input units to drop. + qkv_bias: bias term for the qkv linear layer. 
+ + """ + + super().__init__() + + if not (0 <= dropout_rate <= 1): + raise ValueError("dropout_rate should be between 0 and 1.") + + if hidden_size % num_heads != 0: + raise ValueError("hidden size should be divisible by num_heads.") + + self.num_heads = num_heads + self.out_proj = nn.Linear(hidden_size, hidden_size) + self.qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias) + self.input_rearrange = Rearrange("b h (qkv l d) -> qkv b l h d", qkv=3, l=num_heads) + self.out_rearrange = Rearrange("b h l d -> b l (h d)") + self.drop_output = nn.Dropout(dropout_rate) + self.drop_weights = nn.Dropout(dropout_rate) + self.head_dim = hidden_size // num_heads + self.scale = self.head_dim**-0.5 + + def forward(self, x): + output = self.input_rearrange(self.qkv(x)) + q, k, v = output[0], output[1], output[2] + att_mat = (torch.einsum("blxd,blyd->blxy", q, k) * self.scale).softmax(dim=-1) + att_mat = self.drop_weights(att_mat) + x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v) + x = self.out_rearrange(x) + x = self.out_proj(x) + x = self.drop_output(x) + return x From 06314aa393f3031f521b970625653925dd7e484f Mon Sep 17 00:00:00 2001 From: Walter Hugo Lopez Pinaya Date: Fri, 3 Feb 2023 16:17:09 +0000 Subject: [PATCH 2/5] Move towards implementation without eionps Signed-off-by: Walter Hugo Lopez Pinaya --- generative/networks/blocks/selfattention.py | 49 +++++++++++++-------- 1 file changed, 30 insertions(+), 19 deletions(-) diff --git a/generative/networks/blocks/selfattention.py b/generative/networks/blocks/selfattention.py index 519c8c77..83c9f85f 100644 --- a/generative/networks/blocks/selfattention.py +++ b/generative/networks/blocks/selfattention.py @@ -11,12 +11,9 @@ from __future__ import annotations -import torch import torch.nn as nn - -from monai.utils import optional_import - -Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange") +import math +from torch.nn import functional as F class SABlock(nn.Module): @@ -25,11 +22,12 @@ class SABlock(nn.Module): An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale " """ - def __init__(self, hidden_size: int, num_heads: int, dropout_rate: float = 0.0, qkv_bias: bool = False) -> None: + def __init__(self, hidden_size: int, num_heads: int, causal:bool=False, dropout_rate: float = 0.0, qkv_bias: bool = False) -> None: """ Args: hidden_size: dimension of hidden layer. num_heads: number of attention heads. + causal: whether to use causal attention. dropout_rate: faction of the input units to drop. qkv_bias: bias term for the qkv linear layer. 
@@ -43,23 +41,36 @@ def __init__(self, hidden_size: int, num_heads: int, dropout_rate: float = 0.0, if hidden_size % num_heads != 0: raise ValueError("hidden size should be divisible by num_heads.") - self.num_heads = num_heads + # output projection self.out_proj = nn.Linear(hidden_size, hidden_size) + # key, query, value projections for all heads, but in a batch self.qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias) - self.input_rearrange = Rearrange("b h (qkv l d) -> qkv b l h d", qkv=3, l=num_heads) - self.out_rearrange = Rearrange("b h l d -> b l (h d)") + # regularization self.drop_output = nn.Dropout(dropout_rate) self.drop_weights = nn.Dropout(dropout_rate) + + self.num_heads = num_heads self.head_dim = hidden_size // num_heads - self.scale = self.head_dim**-0.5 + self.scale = 1.0 / math.sqrt(self.head_dim) + self.causal = causal def forward(self, x): - output = self.input_rearrange(self.qkv(x)) - q, k, v = output[0], output[1], output[2] - att_mat = (torch.einsum("blxd,blyd->blxy", q, k) * self.scale).softmax(dim=-1) - att_mat = self.drop_weights(att_mat) - x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v) - x = self.out_rearrange(x) - x = self.out_proj(x) - x = self.drop_output(x) - return x + B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) + + # calculate query, key, values for all heads in batch and move head forward to be the batch dim + q, k ,v = self.qkv(x).split(self.n_embd, dim=2) + k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + + # manual implementation of attention + att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) + att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf')) + att = F.softmax(att, dim=-1) + att = self.drop_weights(att) + y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) + y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side + + y = self.out_proj(y) + y = self.drop_output(y) + return y From 02cd1204092df7ef2a1b9df66dd26b071ed2e4eb Mon Sep 17 00:00:00 2001 From: Walter Hugo Lopez Pinaya Date: Fri, 3 Feb 2023 16:19:13 +0000 Subject: [PATCH 3/5] Move towards implementation without eionps Signed-off-by: Walter Hugo Lopez Pinaya --- generative/networks/blocks/selfattention.py | 23 ++++++++++++--------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/generative/networks/blocks/selfattention.py b/generative/networks/blocks/selfattention.py index 83c9f85f..67a4530c 100644 --- a/generative/networks/blocks/selfattention.py +++ b/generative/networks/blocks/selfattention.py @@ -11,8 +11,9 @@ from __future__ import annotations -import torch.nn as nn import math + +import torch.nn as nn from torch.nn import functional as F @@ -22,7 +23,9 @@ class SABlock(nn.Module): An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale " """ - def __init__(self, hidden_size: int, num_heads: int, causal:bool=False, dropout_rate: float = 0.0, qkv_bias: bool = False) -> None: + def __init__( + self, hidden_size: int, num_heads: int, causal: bool = False, dropout_rate: float = 0.0, qkv_bias: bool = False + ) -> None: """ Args: hidden_size: dimension of hidden layer. 
@@ -55,21 +58,21 @@ def __init__(self, hidden_size: int, num_heads: int, causal:bool=False, dropout_ self.causal = causal def forward(self, x): - B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) + B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) # calculate query, key, values for all heads in batch and move head forward to be the batch dim - q, k ,v = self.qkv(x).split(self.n_embd, dim=2) - k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) - q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) - v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + q, k, v = self.qkv(x).split(self.n_embd, dim=2) + k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) # manual implementation of attention att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) - att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf')) + att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf")) att = F.softmax(att, dim=-1) att = self.drop_weights(att) - y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) - y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side + y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) + y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side y = self.out_proj(y) y = self.drop_output(y) From f50b9ee03ae9aec6d60aa5041c7aec16942b9a53 Mon Sep 17 00:00:00 2001 From: Walter Hugo Lopez Pinaya Date: Sun, 12 Mar 2023 11:58:34 +0000 Subject: [PATCH 4/5] Add causal operation to self attention block Signed-off-by: Walter Hugo Lopez Pinaya --- generative/networks/blocks/selfattention.py | 70 ++++++++++++++------- tests/test_selfattention.py | 66 +++++++++++++++++++ 2 files changed, 113 insertions(+), 23 deletions(-) create mode 100644 tests/test_selfattention.py diff --git a/generative/networks/blocks/selfattention.py b/generative/networks/blocks/selfattention.py index 67a4530c..4102a310 100644 --- a/generative/networks/blocks/selfattention.py +++ b/generative/networks/blocks/selfattention.py @@ -13,6 +13,7 @@ import math +import torch import torch.nn as nn from torch.nn import functional as F @@ -21,21 +22,25 @@ class SABlock(nn.Module): """ A self-attention block, based on: "Dosovitskiy et al., An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale " + + Args: + hidden_size: dimension of hidden layer. + num_heads: number of attention heads. + dropout_rate: dropout ratio. Defaults to no dropout. + qkv_bias: bias term for the qkv linear layer. + causal: whether to use causal attention. + sequence_length: if causal is True, it is necessary to specify the sequence length. """ def __init__( - self, hidden_size: int, num_heads: int, causal: bool = False, dropout_rate: float = 0.0, qkv_bias: bool = False + self, + hidden_size: int, + num_heads: int, + dropout_rate: float = 0.0, + qkv_bias: bool = False, + causal: bool = False, + sequence_length: int | None = None, ) -> None: - """ - Args: - hidden_size: dimension of hidden layer. - num_heads: number of attention heads. - causal: whether to use causal attention. - dropout_rate: faction of the input units to drop. - qkv_bias: bias term for the qkv linear layer. 
- - """ - super().__init__() if not (0 <= dropout_rate <= 1): @@ -44,35 +49,54 @@ def __init__( if hidden_size % num_heads != 0: raise ValueError("hidden size should be divisible by num_heads.") + if causal and sequence_length is None: + raise ValueError("sequence_length is necessary for causal attention.") + # output projection self.out_proj = nn.Linear(hidden_size, hidden_size) # key, query, value projections for all heads, but in a batch self.qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias) # regularization - self.drop_output = nn.Dropout(dropout_rate) self.drop_weights = nn.Dropout(dropout_rate) + self.drop_output = nn.Dropout(dropout_rate) + self.hidden_size = hidden_size self.num_heads = num_heads self.head_dim = hidden_size // num_heads self.scale = 1.0 / math.sqrt(self.head_dim) self.causal = causal + self.sequence_length = sequence_length + + if causal and sequence_length is not None: + # causal mask to ensure that attention is only applied to the left in the input sequence + self.mask = torch.tril(torch.ones(sequence_length, sequence_length)).view( + 1, 1, sequence_length, sequence_length + ) + else: + self.mask = None def forward(self, x): - B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) + b, t, c = x.size() # batch size, sequence length, embedding dimensionality (hidden_size) + + if self.sequence_length is not None and t != self.sequence_length: + raise ValueError("sequence length should be equal to the one specified in the SABlock constructor.") # calculate query, key, values for all heads in batch and move head forward to be the batch dim - q, k, v = self.qkv(x).split(self.n_embd, dim=2) - k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) - q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) - v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + query, key, value = self.qkv(x).split(self.hidden_size, dim=2) + key = key.view(b, t, self.num_heads, c // self.num_heads).transpose(1, 2) # (b, nh, t, hs) + query = query.view(b, t, self.num_heads, c // self.num_heads).transpose(1, 2) # (b, nh, t, hs) + value = value.view(b, t, self.num_heads, c // self.num_heads).transpose(1, 2) # (b, nh, t, hs) # manual implementation of attention - att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) - att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf")) - att = F.softmax(att, dim=-1) - att = self.drop_weights(att) - y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) - y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side + attention_scores = (query @ key.transpose(-2, -1)) * self.scale + + if self.causal: + attention_scores = attention_scores.masked_fill(self.mask[:, :, :t, :t] == 0, float("-inf")) + + attention_probs = F.softmax(attention_scores, dim=-1) + attention_probs = self.drop_weights(attention_probs) + y = attention_probs @ value # (b, nh, t, t) x (b, nh, t, hs) -> (b, nh, t, hs) + y = y.transpose(1, 2).contiguous().view(b, t, c) # re-assemble all head outputs side by side y = self.out_proj(y) y = self.drop_output(y) diff --git a/tests/test_selfattention.py b/tests/test_selfattention.py new file mode 100644 index 00000000..09f2feb6 --- /dev/null +++ b/tests/test_selfattention.py @@ -0,0 +1,66 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from monai.networks import eval_mode +from parameterized import parameterized + +from generative.networks.blocks.selfattention import SABlock + +TEST_CASE_SABLOCK = [ + [ + {"hidden_size": 16, "num_heads": 8, "dropout_rate": 0.2, "causal": False, "sequence_length": None}, + (2, 4, 16), + (2, 4, 16), + ], + [ + {"hidden_size": 16, "num_heads": 8, "dropout_rate": 0.2, "causal": True, "sequence_length": 4}, + (2, 4, 16), + (2, 4, 16), + ], +] + + +class TestResBlock(unittest.TestCase): + @parameterized.expand(TEST_CASE_SABLOCK) + def test_shape(self, input_param, input_shape, expected_shape): + net = SABlock(**input_param) + with eval_mode(net): + result = net(torch.randn(input_shape)) + self.assertEqual(result.shape, expected_shape) + + def test_ill_arg(self): + with self.assertRaises(ValueError): + SABlock(hidden_size=12, num_heads=4, dropout_rate=6.0) + + with self.assertRaises(ValueError): + SABlock(hidden_size=12, num_heads=4, dropout_rate=-6.0) + + with self.assertRaises(ValueError): + SABlock(hidden_size=20, num_heads=8, dropout_rate=0.4) + + with self.assertRaises(ValueError): + SABlock(hidden_size=12, num_heads=4, dropout_rate=0.4, causal=True, sequence_length=None) + + def test_wrong_sequence_length(self): + net = SABlock(hidden_size=16, num_heads=4, dropout_rate=0.0, causal=True, sequence_length=6) + with self.assertRaises(ValueError): + with eval_mode(net): + result = net(torch.randn((2, 4, 16))) + self.assertEqual(result.shape, (2, 4, 16)) + + +if __name__ == "__main__": + unittest.main() From 96e9c9d8fa161fb423889f5eded5902ec426b119 Mon Sep 17 00:00:00 2001 From: Walter Hugo Lopez Pinaya Date: Mon, 13 Mar 2023 17:57:03 +0000 Subject: [PATCH 5/5] Address comments Signed-off-by: Walter Hugo Lopez Pinaya --- generative/networks/blocks/__init__.py | 14 ++++++++++++++ generative/networks/blocks/selfattention.py | 2 +- 2 files changed, 15 insertions(+), 1 deletion(-) create mode 100644 generative/networks/blocks/__init__.py diff --git a/generative/networks/blocks/__init__.py b/generative/networks/blocks/__init__.py new file mode 100644 index 00000000..9036f087 --- /dev/null +++ b/generative/networks/blocks/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +from .selfattention import SABlock diff --git a/generative/networks/blocks/selfattention.py b/generative/networks/blocks/selfattention.py index 4102a310..396bffcc 100644 --- a/generative/networks/blocks/selfattention.py +++ b/generative/networks/blocks/selfattention.py @@ -75,7 +75,7 @@ def __init__( else: self.mask = None - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: b, t, c = x.size() # batch size, sequence length, embedding dimensionality (hidden_size) if self.sequence_length is not None and t != self.sequence_length:
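
Taken together, the five patches leave SABlock with the signature
SABlock(hidden_size, num_heads, dropout_rate=0.0, qkv_bias=False, causal=False, sequence_length=None),
re-exported from generative.networks.blocks by the new __init__.py. The snippet below is a
minimal usage sketch of that final state, mirroring the shapes exercised in
tests/test_selfattention.py; it is illustrative only and not part of any patch.

    import torch

    from generative.networks.blocks.selfattention import SABlock

    # Non-causal block: every position may attend to every other position.
    block = SABlock(hidden_size=16, num_heads=8, dropout_rate=0.2, qkv_bias=False)
    block.eval()  # disable dropout for a deterministic forward pass
    x = torch.randn(2, 4, 16)  # (batch, sequence length, hidden_size)
    print(block(x).shape)  # torch.Size([2, 4, 16])

    # Causal block: sequence_length is mandatory because the lower-triangular
    # mask is built once in the constructor.
    causal_block = SABlock(
        hidden_size=16, num_heads=8, dropout_rate=0.2, causal=True, sequence_length=4
    )
    causal_block.eval()
    print(causal_block(x).shape)  # torch.Size([2, 4, 16])

As the new tests also check, causal=True without sequence_length raises a ValueError at
construction time, and the forward pass raises a ValueError whenever the input sequence
length differs from the sequence_length the mask was built for.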