23 changes: 5 additions & 18 deletions monai/networks/blocks/selfattention.py
@@ -25,21 +25,13 @@ class SABlock(nn.Module):
     An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>"
     """

-    def __init__(
-        self,
-        hidden_size: int,
-        num_heads: int,
-        dropout_rate: float = 0.0,
-        qkv_bias: bool = False,
-        save_attn: bool = False,
-    ) -> None:
+    def __init__(self, hidden_size: int, num_heads: int, dropout_rate: float = 0.0, qkv_bias: bool = False) -> None:
         """
         Args:
-            hidden_size (int): dimension of hidden layer.
-            num_heads (int): number of attention heads.
-            dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0.
-            qkv_bias (bool, optional): bias term for the qkv linear layer. Defaults to False.
-            save_attn (bool, optional): to make accessible the attention matrix. Defaults to False.
+            hidden_size: dimension of hidden layer.
+            num_heads: number of attention heads.
+            dropout_rate: fraction of the input units to drop.
+            qkv_bias: bias term for the qkv linear layer.

         """

@@ -60,16 +52,11 @@ def __init__(
         self.drop_weights = nn.Dropout(dropout_rate)
         self.head_dim = hidden_size // num_heads
         self.scale = self.head_dim**-0.5
-        self.save_attn = save_attn

     def forward(self, x):
         output = self.input_rearrange(self.qkv(x))
         q, k, v = output[0], output[1], output[2]
         att_mat = (torch.einsum("blxd,blyd->blxy", q, k) * self.scale).softmax(dim=-1)
-        if self.save_attn:
-            # no gradients and new tensor;
-            # https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html
-            self.att_mat = att_mat.detach()
         att_mat = self.drop_weights(att_mat)
         x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v)
         x = self.out_rearrange(x)
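For context, a minimal usage sketch of `SABlock` after this change (not part of the diff; it assumes the updated signature shown above and a standard MONAI/PyTorch install, with illustrative parameter values):

```python
# Hypothetical example, assuming the reverted SABlock signature shown in the diff.
import torch
from monai.networks.blocks.selfattention import SABlock

block = SABlock(hidden_size=128, num_heads=4, dropout_rate=0.1, qkv_bias=False)
x = torch.randn(2, 256, 128)  # (batch, sequence length, hidden_size)
out = block(x)                # output keeps the input shape: (2, 256, 128)
```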
13 changes: 3 additions & 10 deletions monai/networks/blocks/transformerblock.py
@@ -24,22 +24,15 @@ class TransformerBlock(nn.Module):
"""

def __init__(
self,
hidden_size: int,
mlp_dim: int,
num_heads: int,
dropout_rate: float = 0.0,
qkv_bias: bool = False,
save_attn: bool = False,
self, hidden_size: int, mlp_dim: int, num_heads: int, dropout_rate: float = 0.0, qkv_bias: bool = False
) -> None:
"""
Args:
hidden_size: dimension of hidden layer.
mlp_dim: dimension of feedforward layer.
num_heads: number of attention heads.
dropout_rate: faction of the input units to drop.
qkv_bias: apply bias term for the qkv linear layer.
save_attn: to make accessible the attention matrix post training.
qkv_bias: apply bias term for the qkv linear layer

"""

@@ -53,7 +46,7 @@ def __init__(

         self.mlp = MLPBlock(hidden_size, mlp_dim, dropout_rate)
         self.norm1 = nn.LayerNorm(hidden_size)
-        self.attn = SABlock(hidden_size, num_heads, dropout_rate, qkv_bias, save_attn)
+        self.attn = SABlock(hidden_size, num_heads, dropout_rate, qkv_bias)
         self.norm2 = nn.LayerNorm(hidden_size)

     def forward(self, x):
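Similarly, a sketch of constructing `TransformerBlock` with the reverted signature (not part of the diff; values are illustrative, and `hidden_size` must remain divisible by `num_heads`):

```python
# Hypothetical example, assuming the reverted TransformerBlock signature shown in the diff.
import torch
from monai.networks.blocks.transformerblock import TransformerBlock

blk = TransformerBlock(hidden_size=128, mlp_dim=256, num_heads=4, dropout_rate=0.1, qkv_bias=False)
x = torch.randn(2, 256, 128)  # (batch, sequence length, hidden_size)
out = blk(x)                  # output keeps the input shape: (2, 256, 128)
```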
20 changes: 0 additions & 20 deletions tests/test_selfattention.py
@@ -52,26 +52,6 @@ def test_ill_arg(self):
         with self.assertRaises(ValueError):
             SABlock(hidden_size=620, num_heads=8, dropout_rate=0.4)

-    def test_access_attn_matrix(self):
-        # input format
-        hidden_size = 128
-        num_heads = 2
-        dropout_rate = 0
-        input_shape = (2, 256, hidden_size)
-
-        # the attention matrix is not accessible without save_attn
-        no_matrix_acess_blk = SABlock(hidden_size=hidden_size, num_heads=num_heads, dropout_rate=dropout_rate)
-        with self.assertRaises(AttributeError):
-            no_matrix_acess_blk(torch.randn(input_shape))
-            no_matrix_acess_blk.att_mat
-
-        # the attention matrix is accessible with save_attn=True
-        matrix_acess_blk = SABlock(
-            hidden_size=hidden_size, num_heads=num_heads, dropout_rate=dropout_rate, save_attn=True
-        )
-        matrix_acess_blk(torch.randn(input_shape))
-        assert matrix_acess_blk.att_mat.shape == (input_shape[0], input_shape[0], input_shape[1], input_shape[1])


 if __name__ == "__main__":
     unittest.main()
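With `save_attn` removed, the attention matrix is no longer stored on the block. As an alternative (not part of this PR), the weights can still be captured with a standard PyTorch forward hook; the sketch below assumes the internal `drop_weights` dropout shown in the diff above, whose input is the softmax-normalized attention matrix:

```python
# Sketch only: capture attention weights via a forward hook instead of save_attn.
# Assumes SABlock keeps the internal attribute `drop_weights` shown in the diff.
import torch
from monai.networks.blocks.selfattention import SABlock

block = SABlock(hidden_size=128, num_heads=2, dropout_rate=0.0)
captured = {}

def grab_attn(module, inputs, output):
    # inputs[0] is att_mat with shape (batch, num_heads, seq_len, seq_len)
    captured["att_mat"] = inputs[0].detach()

handle = block.drop_weights.register_forward_hook(grab_attn)
block(torch.randn(2, 256, 128))
handle.remove()
print(captured["att_mat"].shape)  # torch.Size([2, 2, 256, 256])
```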