This repository was archived by the owner on Feb 7, 2025. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 105
Add causal self-attention block #218
Merged
Warvito
merged 6 commits into
main
from
196-add-self-attention-block-for-autoregressive-transformer
Mar 13, 2023
Merged
Changes from all commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
35e8c0f
Add original SABlock
Warvito 06314aa
Move towards implementation without eionps
Warvito 02cd120
Move towards implementation without eionps
Warvito b3797d6
Merge branch 'main' into 196-add-self-attention-block-for-autoregress…
Warvito f50b9ee
Add causal operation to self attention block
Warvito 96e9c9d
Address comments
Warvito File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,14 @@ | ||
| # Copyright (c) MONAI Consortium | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| from .selfattention import SABlock |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,103 @@ | ||
| # Copyright (c) MONAI Consortium | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| import math | ||
|
|
||
| import torch | ||
| import torch.nn as nn | ||
| from torch.nn import functional as F | ||
|
|
||
|
|
||
| class SABlock(nn.Module): | ||
| """ | ||
| A self-attention block, based on: "Dosovitskiy et al., | ||
| An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>" | ||
|
|
||
| Args: | ||
| hidden_size: dimension of hidden layer. | ||
| num_heads: number of attention heads. | ||
| dropout_rate: dropout ratio. Defaults to no dropout. | ||
| qkv_bias: bias term for the qkv linear layer. | ||
| causal: whether to use causal attention. | ||
| sequence_length: if causal is True, it is necessary to specify the sequence length. | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| hidden_size: int, | ||
| num_heads: int, | ||
| dropout_rate: float = 0.0, | ||
| qkv_bias: bool = False, | ||
| causal: bool = False, | ||
| sequence_length: int | None = None, | ||
| ) -> None: | ||
| super().__init__() | ||
|
|
||
| if not (0 <= dropout_rate <= 1): | ||
| raise ValueError("dropout_rate should be between 0 and 1.") | ||
|
|
||
| if hidden_size % num_heads != 0: | ||
| raise ValueError("hidden size should be divisible by num_heads.") | ||
|
|
||
| if causal and sequence_length is None: | ||
| raise ValueError("sequence_length is necessary for causal attention.") | ||
|
|
||
| # output projection | ||
| self.out_proj = nn.Linear(hidden_size, hidden_size) | ||
| # key, query, value projections for all heads, but in a batch | ||
| self.qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias) | ||
| # regularization | ||
| self.drop_weights = nn.Dropout(dropout_rate) | ||
| self.drop_output = nn.Dropout(dropout_rate) | ||
|
|
||
| self.hidden_size = hidden_size | ||
| self.num_heads = num_heads | ||
| self.head_dim = hidden_size // num_heads | ||
| self.scale = 1.0 / math.sqrt(self.head_dim) | ||
| self.causal = causal | ||
| self.sequence_length = sequence_length | ||
|
|
||
| if causal and sequence_length is not None: | ||
| # causal mask to ensure that attention is only applied to the left in the input sequence | ||
| self.mask = torch.tril(torch.ones(sequence_length, sequence_length)).view( | ||
| 1, 1, sequence_length, sequence_length | ||
| ) | ||
| else: | ||
| self.mask = None | ||
|
|
||
| def forward(self, x: torch.Tensor) -> torch.Tensor: | ||
| b, t, c = x.size() # batch size, sequence length, embedding dimensionality (hidden_size) | ||
|
|
||
| if self.sequence_length is not None and t != self.sequence_length: | ||
| raise ValueError("sequence length should be equal to the one specified in the SABlock constructor.") | ||
|
|
||
| # calculate query, key, values for all heads in batch and move head forward to be the batch dim | ||
| query, key, value = self.qkv(x).split(self.hidden_size, dim=2) | ||
| key = key.view(b, t, self.num_heads, c // self.num_heads).transpose(1, 2) # (b, nh, t, hs) | ||
| query = query.view(b, t, self.num_heads, c // self.num_heads).transpose(1, 2) # (b, nh, t, hs) | ||
| value = value.view(b, t, self.num_heads, c // self.num_heads).transpose(1, 2) # (b, nh, t, hs) | ||
|
|
||
| # manual implementation of attention | ||
| attention_scores = (query @ key.transpose(-2, -1)) * self.scale | ||
|
|
||
| if self.causal: | ||
| attention_scores = attention_scores.masked_fill(self.mask[:, :, :t, :t] == 0, float("-inf")) | ||
|
|
||
| attention_probs = F.softmax(attention_scores, dim=-1) | ||
| attention_probs = self.drop_weights(attention_probs) | ||
| y = attention_probs @ value # (b, nh, t, t) x (b, nh, t, hs) -> (b, nh, t, hs) | ||
| y = y.transpose(1, 2).contiguous().view(b, t, c) # re-assemble all head outputs side by side | ||
|
|
||
| y = self.out_proj(y) | ||
| y = self.drop_output(y) | ||
| return y |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,66 @@ | ||
| # Copyright (c) MONAI Consortium | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| import unittest | ||
|
|
||
| import torch | ||
| from monai.networks import eval_mode | ||
| from parameterized import parameterized | ||
|
|
||
| from generative.networks.blocks.selfattention import SABlock | ||
|
|
||
| TEST_CASE_SABLOCK = [ | ||
| [ | ||
| {"hidden_size": 16, "num_heads": 8, "dropout_rate": 0.2, "causal": False, "sequence_length": None}, | ||
| (2, 4, 16), | ||
| (2, 4, 16), | ||
| ], | ||
| [ | ||
| {"hidden_size": 16, "num_heads": 8, "dropout_rate": 0.2, "causal": True, "sequence_length": 4}, | ||
| (2, 4, 16), | ||
| (2, 4, 16), | ||
| ], | ||
| ] | ||
|
|
||
|
|
||
| class TestResBlock(unittest.TestCase): | ||
| @parameterized.expand(TEST_CASE_SABLOCK) | ||
| def test_shape(self, input_param, input_shape, expected_shape): | ||
| net = SABlock(**input_param) | ||
| with eval_mode(net): | ||
| result = net(torch.randn(input_shape)) | ||
| self.assertEqual(result.shape, expected_shape) | ||
|
|
||
| def test_ill_arg(self): | ||
| with self.assertRaises(ValueError): | ||
| SABlock(hidden_size=12, num_heads=4, dropout_rate=6.0) | ||
|
|
||
| with self.assertRaises(ValueError): | ||
| SABlock(hidden_size=12, num_heads=4, dropout_rate=-6.0) | ||
|
|
||
| with self.assertRaises(ValueError): | ||
| SABlock(hidden_size=20, num_heads=8, dropout_rate=0.4) | ||
|
|
||
| with self.assertRaises(ValueError): | ||
| SABlock(hidden_size=12, num_heads=4, dropout_rate=0.4, causal=True, sequence_length=None) | ||
|
|
||
| def test_wrong_sequence_length(self): | ||
| net = SABlock(hidden_size=16, num_heads=4, dropout_rate=0.0, causal=True, sequence_length=6) | ||
| with self.assertRaises(ValueError): | ||
| with eval_mode(net): | ||
| result = net(torch.randn((2, 4, 16))) | ||
| self.assertEqual(result.shape, (2, 4, 16)) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| unittest.main() | ||
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you think we should use an
__init__.pyfile so we can have the nicerfrom generative.networks.blocks import SABlock?