14 changes: 14 additions & 0 deletions generative/networks/blocks/__init__.py
@@ -0,0 +1,14 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from .selfattention import SABlock
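
With this package-level export in place, SABlock can be imported directly from generative.networks.blocks instead of from the selfattention submodule. A minimal usage sketch (the shapes and arguments below are illustrative, not taken from this PR):

import torch

from generative.networks.blocks import SABlock

block = SABlock(hidden_size=16, num_heads=4)  # non-causal, no dropout
x = torch.randn(2, 8, 16)  # (batch, sequence length, hidden_size)
y = block(x)  # output has the same shape as the input: (2, 8, 16)
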
103 changes: 103 additions & 0 deletions generative/networks/blocks/selfattention.py
@@ -0,0 +1,103 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import math

import torch
import torch.nn as nn
from torch.nn import functional as F


class SABlock(nn.Module):
"""
A self-attention block, based on: "Dosovitskiy et al.,
An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>"

Args:
hidden_size: dimension of hidden layer.
num_heads: number of attention heads.
dropout_rate: dropout ratio. Defaults to no dropout.
qkv_bias: bias term for the qkv linear layer.
causal: whether to use causal attention.
sequence_length: if causal is True, it is necessary to specify the sequence length.
"""

def __init__(
self,
hidden_size: int,
num_heads: int,
dropout_rate: float = 0.0,
qkv_bias: bool = False,
causal: bool = False,
sequence_length: int | None = None,
) -> None:
super().__init__()

if not (0 <= dropout_rate <= 1):
raise ValueError("dropout_rate should be between 0 and 1.")

if hidden_size % num_heads != 0:
raise ValueError("hidden size should be divisible by num_heads.")

if causal and sequence_length is None:
raise ValueError("sequence_length is necessary for causal attention.")

# output projection
self.out_proj = nn.Linear(hidden_size, hidden_size)
# key, query, value projections for all heads, but in a batch
self.qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias)
# regularization
self.drop_weights = nn.Dropout(dropout_rate)
self.drop_output = nn.Dropout(dropout_rate)

self.hidden_size = hidden_size
self.num_heads = num_heads
self.head_dim = hidden_size // num_heads
self.scale = 1.0 / math.sqrt(self.head_dim)
self.causal = causal
self.sequence_length = sequence_length

if causal and sequence_length is not None:
# causal mask to ensure that attention is only applied to the left in the input sequence
self.mask = torch.tril(torch.ones(sequence_length, sequence_length)).view(
1, 1, sequence_length, sequence_length
)
else:
self.mask = None

def forward(self, x: torch.Tensor) -> torch.Tensor:
b, t, c = x.size() # batch size, sequence length, embedding dimensionality (hidden_size)

if self.sequence_length is not None and t != self.sequence_length:
raise ValueError("sequence length should be equal to the one specified in the SABlock constructor.")

# calculate query, key, values for all heads in batch and move head forward to be the batch dim
query, key, value = self.qkv(x).split(self.hidden_size, dim=2)
key = key.view(b, t, self.num_heads, c // self.num_heads).transpose(1, 2) # (b, nh, t, hs)
query = query.view(b, t, self.num_heads, c // self.num_heads).transpose(1, 2) # (b, nh, t, hs)
value = value.view(b, t, self.num_heads, c // self.num_heads).transpose(1, 2) # (b, nh, t, hs)

# manual implementation of attention
attention_scores = (query @ key.transpose(-2, -1)) * self.scale

if self.causal:
attention_scores = attention_scores.masked_fill(self.mask[:, :, :t, :t] == 0, float("-inf"))

attention_probs = F.softmax(attention_scores, dim=-1)
attention_probs = self.drop_weights(attention_probs)
y = attention_probs @ value # (b, nh, t, t) x (b, nh, t, hs) -> (b, nh, t, hs)
y = y.transpose(1, 2).contiguous().view(b, t, c) # re-assemble all head outputs side by side

y = self.out_proj(y)
y = self.drop_output(y)
return y
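
Because the attention above is written out manually, a quick sanity check is to compare it against PyTorch's fused kernel. This is a minimal sketch, assuming PyTorch >= 2.0 (for torch.nn.functional.scaled_dot_product_attention) and dropout disabled so both paths are deterministic; it is not part of this PR:

import torch
import torch.nn.functional as F

from generative.networks.blocks.selfattention import SABlock

block = SABlock(hidden_size=16, num_heads=4, dropout_rate=0.0, causal=True, sequence_length=8)
block.eval()

x = torch.randn(2, 8, 16)
with torch.no_grad():
    manual = block(x)

    # Reuse the block's own projections, but let the fused kernel compute softmax(QK^T / sqrt(d)) V.
    q, k, v = block.qkv(x).split(block.hidden_size, dim=2)
    q = q.view(2, 8, block.num_heads, block.head_dim).transpose(1, 2)
    k = k.view(2, 8, block.num_heads, block.head_dim).transpose(1, 2)
    v = v.view(2, 8, block.num_heads, block.head_dim).transpose(1, 2)
    y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    fused = block.out_proj(y.transpose(1, 2).contiguous().view(2, 8, 16))

print(torch.allclose(manual, fused, atol=1e-5))  # expected: True

One thing to keep in mind when running the block on GPU: self.mask is stored as a plain tensor attribute rather than a registered buffer, so Module.to() does not move it and it may need to be relocated to the right device manually.
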
66 changes: 66 additions & 0 deletions tests/test_selfattention.py
@@ -0,0 +1,66 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import unittest

import torch
from monai.networks import eval_mode
from parameterized import parameterized

from generative.networks.blocks.selfattention import SABlock
Collaborator: Do you think we should use an __init__.py file so we can have the nicer from generative.networks.blocks import SABlock?


TEST_CASE_SABLOCK = [
[
{"hidden_size": 16, "num_heads": 8, "dropout_rate": 0.2, "causal": False, "sequence_length": None},
(2, 4, 16),
(2, 4, 16),
],
[
{"hidden_size": 16, "num_heads": 8, "dropout_rate": 0.2, "causal": True, "sequence_length": 4},
(2, 4, 16),
(2, 4, 16),
],
]


class TestSABlock(unittest.TestCase):
@parameterized.expand(TEST_CASE_SABLOCK)
def test_shape(self, input_param, input_shape, expected_shape):
net = SABlock(**input_param)
with eval_mode(net):
result = net(torch.randn(input_shape))
self.assertEqual(result.shape, expected_shape)

def test_ill_arg(self):
with self.assertRaises(ValueError):
SABlock(hidden_size=12, num_heads=4, dropout_rate=6.0)

with self.assertRaises(ValueError):
SABlock(hidden_size=12, num_heads=4, dropout_rate=-6.0)

with self.assertRaises(ValueError):
SABlock(hidden_size=20, num_heads=8, dropout_rate=0.4)

with self.assertRaises(ValueError):
SABlock(hidden_size=12, num_heads=4, dropout_rate=0.4, causal=True, sequence_length=None)

def test_wrong_sequence_length(self):
net = SABlock(hidden_size=16, num_heads=4, dropout_rate=0.0, causal=True, sequence_length=6)
with self.assertRaises(ValueError):
with eval_mode(net):
result = net(torch.randn((2, 4, 16)))
self.assertEqual(result.shape, (2, 4, 16))


if __name__ == "__main__":
unittest.main()
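
As written, the test module can be executed directly (for example, python tests/test_selfattention.py) or collected by a unittest/pytest runner, assuming monai, parameterized, and the generative package are importable.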