23 changes: 18 additions & 5 deletions monai/networks/blocks/selfattention.py
@@ -25,13 +25,21 @@ class SABlock(nn.Module):
An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>"
"""

def __init__(self, hidden_size: int, num_heads: int, dropout_rate: float = 0.0, qkv_bias: bool = False) -> None:
def __init__(
self,
hidden_size: int,
num_heads: int,
dropout_rate: float = 0.0,
qkv_bias: bool = False,
save_attn: bool = False,
) -> None:
"""
Args:
hidden_size: dimension of hidden layer.
num_heads: number of attention heads.
dropout_rate: fraction of the input units to drop.
qkv_bias: bias term for the qkv linear layer.
hidden_size (int): dimension of hidden layer.
num_heads (int): number of attention heads.
dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0.
qkv_bias (bool, optional): bias term for the qkv linear layer. Defaults to False.
save_attn (bool, optional): whether to make the attention matrix accessible via self.att_mat. Defaults to False.

"""

@@ -52,11 +52,16 @@ def __init__(self, hidden_size: int, num_heads: int, dropout_rate: float = 0.0,
self.drop_weights = nn.Dropout(dropout_rate)
self.head_dim = hidden_size // num_heads
self.scale = self.head_dim**-0.5
self.save_attn = save_attn

def forward(self, x):
output = self.input_rearrange(self.qkv(x))
q, k, v = output[0], output[1], output[2]
att_mat = (torch.einsum("blxd,blyd->blxy", q, k) * self.scale).softmax(dim=-1)
if self.save_attn:
# detach() returns a new tensor with no gradient tracking;
# https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html
self.att_mat = att_mat.detach()
att_mat = self.drop_weights(att_mat)
x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v)
x = self.out_rearrange(x)
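For context, a minimal usage sketch of the new save_attn flag (not part of the diff; the import path follows the file header above and the shapes mirror the unit test added below):

import torch

from monai.networks.blocks.selfattention import SABlock

# save_attn=True makes the block keep a detached copy of its attention weights.
blk = SABlock(hidden_size=128, num_heads=2, dropout_rate=0.0, save_attn=True)
x = torch.randn(2, 256, 128)  # (batch, sequence length, hidden_size)
_ = blk(x)

# att_mat is detached, so it carries no gradient history.
print(blk.att_mat.shape)  # torch.Size([2, 2, 256, 256]) -> (batch, num_heads, seq, seq)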
13 changes: 10 additions & 3 deletions monai/networks/blocks/transformerblock.py
@@ -24,15 +24,22 @@ class TransformerBlock(nn.Module):
"""

def __init__(
self, hidden_size: int, mlp_dim: int, num_heads: int, dropout_rate: float = 0.0, qkv_bias: bool = False
self,
hidden_size: int,
mlp_dim: int,
num_heads: int,
dropout_rate: float = 0.0,
qkv_bias: bool = False,
save_attn: bool = False,
) -> None:
"""
Args:
hidden_size: dimension of hidden layer.
mlp_dim: dimension of feedforward layer.
num_heads: number of attention heads.
dropout_rate: fraction of the input units to drop.
qkv_bias: apply bias term for the qkv linear layer
qkv_bias: apply bias term for the qkv linear layer.
save_attn: whether to make the attention matrix accessible after the forward pass (via self.attn.att_mat).

"""

@@ -46,7 +53,7 @@ def __init__(

self.mlp = MLPBlock(hidden_size, mlp_dim, dropout_rate)
self.norm1 = nn.LayerNorm(hidden_size)
self.attn = SABlock(hidden_size, num_heads, dropout_rate, qkv_bias)
self.attn = SABlock(hidden_size, num_heads, dropout_rate, qkv_bias, save_attn)
self.norm2 = nn.LayerNorm(hidden_size)

def forward(self, x):
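A similar sketch (again not part of the diff) for TransformerBlock, which simply forwards the flag to its inner SABlock, exposed as block.attn:

import torch

from monai.networks.blocks.transformerblock import TransformerBlock

block = TransformerBlock(hidden_size=128, mlp_dim=256, num_heads=2, dropout_rate=0.0, save_attn=True)
_ = block(torch.randn(2, 256, 128))

# The saved weights live on the inner self-attention block.
print(block.attn.att_mat.shape)  # torch.Size([2, 2, 256, 256])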
20 changes: 20 additions & 0 deletions tests/test_selfattention.py
@@ -52,6 +52,26 @@ def test_ill_arg(self):
with self.assertRaises(ValueError):
SABlock(hidden_size=620, num_heads=8, dropout_rate=0.4)

def test_access_attn_matrix(self):
# input format
hidden_size = 128
num_heads = 2
dropout_rate = 0
input_shape = (2, 256, hidden_size)

# without save_attn the attention matrix is not stored, so access must fail
no_matrix_access_blk = SABlock(hidden_size=hidden_size, num_heads=num_heads, dropout_rate=dropout_rate)
with self.assertRaises(AttributeError):
no_matrix_access_blk(torch.randn(input_shape))
no_matrix_access_blk.att_mat

# with save_attn=True the attention matrix is stored and can be accessed
matrix_access_blk = SABlock(
hidden_size=hidden_size, num_heads=num_heads, dropout_rate=dropout_rate, save_attn=True
)
matrix_access_blk(torch.randn(input_shape))
assert matrix_access_blk.att_mat.shape == (input_shape[0], num_heads, input_shape[1], input_shape[1])


if __name__ == "__main__":
unittest.main()