25 changes: 20 additions & 5 deletions monai/networks/blocks/selfattention.py
@@ -25,13 +25,21 @@ class SABlock(nn.Module):
An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>"
"""

def __init__(self, hidden_size: int, num_heads: int, dropout_rate: float = 0.0, qkv_bias: bool = False) -> None:
def __init__(
self,
hidden_size: int,
num_heads: int,
dropout_rate: float = 0.0,
qkv_bias: bool = False,
save_attn: bool = False,
) -> None:
"""
Args:
hidden_size: dimension of hidden layer.
num_heads: number of attention heads.
dropout_rate: faction of the input units to drop.
qkv_bias: bias term for the qkv linear layer.
hidden_size (int): dimension of hidden layer.
num_heads (int): number of attention heads.
dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0.
qkv_bias (bool, optional): bias term for the qkv linear layer. Defaults to False.
save_attn (bool, optional): whether to make the attention matrix accessible. Defaults to False.

"""

@@ -52,11 +60,18 @@ def __init__(self, hidden_size: int, num_heads: int, dropout_rate: float = 0.0,
self.drop_weights = nn.Dropout(dropout_rate)
self.head_dim = hidden_size // num_heads
self.scale = self.head_dim**-0.5
self.save_attn = save_attn
self.att_mat = torch.Tensor()

def forward(self, x):
output = self.input_rearrange(self.qkv(x))
q, k, v = output[0], output[1], output[2]
att_mat = (torch.einsum("blxd,blyd->blxy", q, k) * self.scale).softmax(dim=-1)
if self.save_attn:
# detach() returns a new tensor detached from the computation graph (no gradients);
# https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html
self.att_mat = att_mat.detach()

att_mat = self.drop_weights(att_mat)
x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v)
x = self.out_rearrange(x)
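A minimal usage sketch of the new `save_attn` flag at the block level (the sizes below are illustrative, not taken from this diff): after a forward pass, a detached copy of the softmax-normalised attention weights is available on `att_mat`.

import torch
from monai.networks.blocks.selfattention import SABlock

block = SABlock(hidden_size=128, num_heads=4, dropout_rate=0.0, save_attn=True)
x = torch.randn(2, 256, 128)  # (batch, sequence length, hidden_size)
_ = block(x)
# detached attention weights with shape (batch, num_heads, seq_len, seq_len)
print(block.att_mat.shape)          # torch.Size([2, 4, 256, 256])
print(block.att_mat.requires_grad)  # False, because the saved copy is detached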
21 changes: 14 additions & 7 deletions monai/networks/blocks/transformerblock.py
@@ -24,15 +24,22 @@ class TransformerBlock(nn.Module):
"""

def __init__(
self, hidden_size: int, mlp_dim: int, num_heads: int, dropout_rate: float = 0.0, qkv_bias: bool = False
self,
hidden_size: int,
mlp_dim: int,
num_heads: int,
dropout_rate: float = 0.0,
qkv_bias: bool = False,
save_attn: bool = False,
) -> None:
"""
Args:
hidden_size: dimension of hidden layer.
mlp_dim: dimension of feedforward layer.
num_heads: number of attention heads.
dropout_rate: faction of the input units to drop.
qkv_bias: apply bias term for the qkv linear layer
hidden_size (int): dimension of hidden layer.
mlp_dim (int): dimension of feedforward layer.
num_heads (int): number of attention heads.
dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0.
qkv_bias (bool, optional): apply bias term for the qkv linear layer. Defaults to False.
save_attn (bool, optional): whether to make the attention matrix accessible. Defaults to False.

"""

@@ -46,7 +53,7 @@ def __init__(

self.mlp = MLPBlock(hidden_size, mlp_dim, dropout_rate)
self.norm1 = nn.LayerNorm(hidden_size)
self.attn = SABlock(hidden_size, num_heads, dropout_rate, qkv_bias)
self.attn = SABlock(hidden_size, num_heads, dropout_rate, qkv_bias, save_attn)
self.norm2 = nn.LayerNorm(hidden_size)

def forward(self, x):
38 changes: 22 additions & 16 deletions monai/networks/nets/vit.py
@@ -46,24 +46,27 @@ def __init__(
spatial_dims: int = 3,
post_activation="Tanh",
qkv_bias: bool = False,
save_attn: bool = False,
) -> None:
"""
Args:
in_channels: dimension of input channels.
img_size: dimension of input image.
patch_size: dimension of patch size.
hidden_size: dimension of hidden layer.
mlp_dim: dimension of feedforward layer.
num_layers: number of transformer blocks.
num_heads: number of attention heads.
pos_embed: position embedding layer type.
classification: bool argument to determine if classification is used.
num_classes: number of classes if classification is used.
dropout_rate: faction of the input units to drop.
spatial_dims: number of spatial dimensions.
post_activation: add a final acivation function to the classification head when `classification` is True.
Default to "Tanh" for `nn.Tanh()`. Set to other values to remove this function.
qkv_bias: apply bias to the qkv linear layer in self attention block
in_channels (int): dimension of input channels.
img_size (Union[Sequence[int], int]): dimension of input image.
patch_size (Union[Sequence[int], int]): dimension of patch size.
hidden_size (int, optional): dimension of hidden layer. Defaults to 768.
mlp_dim (int, optional): dimension of feedforward layer. Defaults to 3072.
num_layers (int, optional): number of transformer blocks. Defaults to 12.
num_heads (int, optional): number of attention heads. Defaults to 12.
pos_embed (str, optional): position embedding layer type. Defaults to "conv".
classification (bool, optional): bool argument to determine if classification is used. Defaults to False.
num_classes (int, optional): number of classes if classification is used. Defaults to 2.
dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0.
spatial_dims (int, optional): number of spatial dimensions. Defaults to 3.
post_activation (str, optional): add a final activation function to the classification head
when `classification` is True. Defaults to "Tanh" for `nn.Tanh()`.
Set to other values to remove this function.
qkv_bias (bool, optional): apply bias to the qkv linear layer in self attention block. Defaults to False.
save_attn (bool, optional): whether to make the attention matrices of the self-attention blocks accessible. Defaults to False.

Examples::

@@ -98,7 +101,10 @@ def __init__(
spatial_dims=spatial_dims,
)
self.blocks = nn.ModuleList(
[TransformerBlock(hidden_size, mlp_dim, num_heads, dropout_rate, qkv_bias) for i in range(num_layers)]
[
TransformerBlock(hidden_size, mlp_dim, num_heads, dropout_rate, qkv_bias, save_attn)
for i in range(num_layers)
]
)
self.norm = nn.LayerNorm(hidden_size)
if self.classification:
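The same flag at the network level, as exercised by the tests below (a sketch only; the small volume size is chosen to keep the example light): each `TransformerBlock` stores its own detached attention matrix, so per-layer maps can be collected from `model.blocks`.

import torch
from monai.networks.nets.vit import ViT

# a 32^3 volume with 16^3 patches gives (32 / 16) ** 3 = 8 patch tokens per image
model = ViT(in_channels=1, img_size=(32, 32, 32), patch_size=(16, 16, 16), save_attn=True)
_ = model(torch.randn(1, 1, 32, 32, 32))
attn_maps = [blk.attn.att_mat for blk in model.blocks]
# 12 transformer blocks by default, each map shaped (batch, num_heads=12, 8, 8)
print(len(attn_maps), attn_maps[0].shape)  # 12 torch.Size([1, 12, 8, 8])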
21 changes: 21 additions & 0 deletions tests/test_selfattention.py
@@ -52,6 +52,27 @@ def test_ill_arg(self):
with self.assertRaises(ValueError):
SABlock(hidden_size=620, num_heads=8, dropout_rate=0.4)

def test_access_attn_matrix(self):
# input format
hidden_size = 128
num_heads = 2
dropout_rate = 0
input_shape = (2, 256, hidden_size)

# the attention matrix should not be accessible by default
no_matrix_access_blk = SABlock(hidden_size=hidden_size, num_heads=num_heads, dropout_rate=dropout_rate)
no_matrix_access_blk(torch.randn(input_shape))
assert isinstance(no_matrix_access_blk.att_mat, torch.Tensor)
# the placeholder tensor holds no elements
assert no_matrix_access_blk.att_mat.nelement() == 0

# with save_attn=True the attention matrix can be accessed after a forward pass
matrix_access_blk = SABlock(
hidden_size=hidden_size, num_heads=num_heads, dropout_rate=dropout_rate, save_attn=True
)
matrix_access_blk(torch.randn(input_shape))
assert matrix_access_blk.att_mat.shape == (input_shape[0], num_heads, input_shape[1], input_shape[1])


if __name__ == "__main__":
unittest.main()
24 changes: 24 additions & 0 deletions tests/test_transformerblock.py
@@ -53,6 +53,30 @@ def test_ill_arg(self):
with self.assertRaises(ValueError):
TransformerBlock(hidden_size=622, num_heads=8, mlp_dim=3072, dropout_rate=0.4)

def test_access_attn_matrix(self):
# input format
hidden_size = 128
mlp_dim = 12
num_heads = 2
dropout_rate = 0
input_shape = (2, 256, hidden_size)

# the attention matrix should not be accessible by default
no_matrix_access_blk = TransformerBlock(
hidden_size=hidden_size, mlp_dim=mlp_dim, num_heads=num_heads, dropout_rate=dropout_rate
)
no_matrix_access_blk(torch.randn(input_shape))
assert isinstance(no_matrix_access_blk.attn.att_mat, torch.Tensor)
# the placeholder tensor holds no elements
assert no_matrix_access_blk.attn.att_mat.nelement() == 0

# with save_attn=True the attention matrix can be accessed after a forward pass
matrix_access_blk = TransformerBlock(
hidden_size=hidden_size, mlp_dim=mlp_dim, num_heads=num_heads, dropout_rate=dropout_rate, save_attn=True
)
matrix_access_blk(torch.randn(input_shape))
assert matrix_access_blk.attn.att_mat.shape == (input_shape[0], num_heads, input_shape[1], input_shape[1])


if __name__ == "__main__":
unittest.main()
19 changes: 19 additions & 0 deletions tests/test_vit.py
@@ -150,6 +150,25 @@ def test_script(self, input_param, input_shape, _):
test_data = torch.randn(input_shape)
test_script_save(net, test_data)

def test_access_attn_matrix(self):
# input format
in_channels = 1
img_size = (96, 96, 96)
patch_size = (16, 16, 16)
in_shape = (1, in_channels, img_size[0], img_size[1], img_size[2])

# the attention matrices should not be accessible by default
no_matrix_access_blk = ViT(in_channels=in_channels, img_size=img_size, patch_size=patch_size)
no_matrix_access_blk(torch.randn(in_shape))
assert isinstance(no_matrix_access_blk.blocks[0].attn.att_mat, torch.Tensor)
# the placeholder tensor holds no elements
assert no_matrix_access_blk.blocks[0].attn.att_mat.nelement() == 0

# with save_attn=True the attention matrices can be accessed after a forward pass
matrix_access_blk = ViT(in_channels=in_channels, img_size=img_size, patch_size=patch_size, save_attn=True)
matrix_access_blk(torch.randn(in_shape))
# default ViT: 12 heads and (96 / 16) ** 3 = 216 patch tokens per image
assert matrix_access_blk.blocks[0].attn.att_mat.shape == (in_shape[0], 12, 216, 216)


if __name__ == "__main__":
unittest.main()