diff --git a/monai/networks/blocks/selfattention.py b/monai/networks/blocks/selfattention.py
index 519c8c7728..71fb549db8 100644
--- a/monai/networks/blocks/selfattention.py
+++ b/monai/networks/blocks/selfattention.py
@@ -25,13 +25,21 @@ class SABlock(nn.Module):
     An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale "
     """
 
-    def __init__(self, hidden_size: int, num_heads: int, dropout_rate: float = 0.0, qkv_bias: bool = False) -> None:
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        dropout_rate: float = 0.0,
+        qkv_bias: bool = False,
+        save_attn: bool = False,
+    ) -> None:
         """
         Args:
-            hidden_size: dimension of hidden layer.
-            num_heads: number of attention heads.
-            dropout_rate: faction of the input units to drop.
-            qkv_bias: bias term for the qkv linear layer.
+            hidden_size (int): dimension of hidden layer.
+            num_heads (int): number of attention heads.
+            dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0.
+            qkv_bias (bool, optional): bias term for the qkv linear layer. Defaults to False.
+            save_attn (bool, optional): whether to make the attention matrix accessible. Defaults to False.
 
         """
 
@@ -52,11 +60,18 @@ def __init__(self, hidden_size: int, num_heads: int, dropout_rate: float = 0.0,
         self.drop_weights = nn.Dropout(dropout_rate)
         self.head_dim = hidden_size // num_heads
         self.scale = self.head_dim**-0.5
+        self.save_attn = save_attn
+        self.att_mat = torch.Tensor()
 
     def forward(self, x):
         output = self.input_rearrange(self.qkv(x))
         q, k, v = output[0], output[1], output[2]
         att_mat = (torch.einsum("blxd,blyd->blxy", q, k) * self.scale).softmax(dim=-1)
+        if self.save_attn:
+            # store a detached copy (no gradients, new tensor);
+            # https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html
+            self.att_mat = att_mat.detach()
+
         att_mat = self.drop_weights(att_mat)
         x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v)
         x = self.out_rearrange(x)
diff --git a/monai/networks/blocks/transformerblock.py b/monai/networks/blocks/transformerblock.py
index 3a4b507d69..f7d4e0e130 100644
--- a/monai/networks/blocks/transformerblock.py
+++ b/monai/networks/blocks/transformerblock.py
@@ -24,15 +24,22 @@ class TransformerBlock(nn.Module):
     """
 
     def __init__(
-        self, hidden_size: int, mlp_dim: int, num_heads: int, dropout_rate: float = 0.0, qkv_bias: bool = False
+        self,
+        hidden_size: int,
+        mlp_dim: int,
+        num_heads: int,
+        dropout_rate: float = 0.0,
+        qkv_bias: bool = False,
+        save_attn: bool = False,
     ) -> None:
         """
         Args:
-            hidden_size: dimension of hidden layer.
-            mlp_dim: dimension of feedforward layer.
-            num_heads: number of attention heads.
-            dropout_rate: faction of the input units to drop.
-            qkv_bias: apply bias term for the qkv linear layer
+            hidden_size (int): dimension of hidden layer.
+            mlp_dim (int): dimension of feedforward layer.
+            num_heads (int): number of attention heads.
+            dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0.
+            qkv_bias (bool, optional): apply bias term for the qkv linear layer. Defaults to False.
+            save_attn (bool, optional): whether to make the attention matrix accessible. Defaults to False.
 
         """
 
@@ -46,7 +53,7 @@ def __init__(
 
         self.mlp = MLPBlock(hidden_size, mlp_dim, dropout_rate)
         self.norm1 = nn.LayerNorm(hidden_size)
-        self.attn = SABlock(hidden_size, num_heads, dropout_rate, qkv_bias)
+        self.attn = SABlock(hidden_size, num_heads, dropout_rate, qkv_bias, save_attn)
         self.norm2 = nn.LayerNorm(hidden_size)
 
     def forward(self, x):
diff --git a/monai/networks/nets/vit.py b/monai/networks/nets/vit.py
index f3896d76c4..8cd42b54b1 100644
--- a/monai/networks/nets/vit.py
+++ b/monai/networks/nets/vit.py
@@ -46,24 +46,27 @@ def __init__(
         spatial_dims: int = 3,
         post_activation="Tanh",
         qkv_bias: bool = False,
+        save_attn: bool = False,
     ) -> None:
         """
         Args:
-            in_channels: dimension of input channels.
-            img_size: dimension of input image.
-            patch_size: dimension of patch size.
-            hidden_size: dimension of hidden layer.
-            mlp_dim: dimension of feedforward layer.
-            num_layers: number of transformer blocks.
-            num_heads: number of attention heads.
-            pos_embed: position embedding layer type.
-            classification: bool argument to determine if classification is used.
-            num_classes: number of classes if classification is used.
-            dropout_rate: faction of the input units to drop.
-            spatial_dims: number of spatial dimensions.
-            post_activation: add a final acivation function to the classification head when `classification` is True.
-                Default to "Tanh" for `nn.Tanh()`. Set to other values to remove this function.
-            qkv_bias: apply bias to the qkv linear layer in self attention block
+            in_channels (int): dimension of input channels.
+            img_size (Union[Sequence[int], int]): dimension of input image.
+            patch_size (Union[Sequence[int], int]): dimension of patch size.
+            hidden_size (int, optional): dimension of hidden layer. Defaults to 768.
+            mlp_dim (int, optional): dimension of feedforward layer. Defaults to 3072.
+            num_layers (int, optional): number of transformer blocks. Defaults to 12.
+            num_heads (int, optional): number of attention heads. Defaults to 12.
+            pos_embed (str, optional): position embedding layer type. Defaults to "conv".
+            classification (bool, optional): whether classification is used. Defaults to False.
+            num_classes (int, optional): number of classes if classification is used. Defaults to 2.
+            dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0.
+            spatial_dims (int, optional): number of spatial dimensions. Defaults to 3.
+            post_activation (str, optional): add a final activation function to the classification head
+                when `classification` is True. Defaults to "Tanh" for `nn.Tanh()`.
+                Set to other values to remove this function.
+            qkv_bias (bool, optional): apply bias to the qkv linear layer in the self attention block. Defaults to False.
+            save_attn (bool, optional): whether to make the attention matrix of the self attention block accessible. Defaults to False.
 
         Examples::
 
@@ -98,7 +101,10 @@ def __init__(
             spatial_dims=spatial_dims,
         )
         self.blocks = nn.ModuleList(
-            [TransformerBlock(hidden_size, mlp_dim, num_heads, dropout_rate, qkv_bias) for i in range(num_layers)]
+            [
+                TransformerBlock(hidden_size, mlp_dim, num_heads, dropout_rate, qkv_bias, save_attn)
+                for i in range(num_layers)
+            ]
         )
         self.norm = nn.LayerNorm(hidden_size)
         if self.classification:
diff --git a/tests/test_selfattention.py b/tests/test_selfattention.py
index 926ef7da55..13054bd561 100644
--- a/tests/test_selfattention.py
+++ b/tests/test_selfattention.py
@@ -52,6 +52,27 @@ def test_ill_arg(self):
         with self.assertRaises(ValueError):
             SABlock(hidden_size=620, num_heads=8, dropout_rate=0.4)
 
+    def test_access_attn_matrix(self):
+        # input format
+        hidden_size = 128
+        num_heads = 2
+        dropout_rate = 0
+        input_shape = (2, 256, hidden_size)
+
+        # the attention matrix is not saved by default
+        no_matrix_access_blk = SABlock(hidden_size=hidden_size, num_heads=num_heads, dropout_rate=dropout_rate)
+        no_matrix_access_blk(torch.randn(input_shape))
+        assert isinstance(no_matrix_access_blk.att_mat, torch.Tensor)
+        # number of elements is zero
+        assert no_matrix_access_blk.att_mat.nelement() == 0
+
+        # the attention matrix is accessible with save_attn=True
+        matrix_access_blk = SABlock(
+            hidden_size=hidden_size, num_heads=num_heads, dropout_rate=dropout_rate, save_attn=True
+        )
+        matrix_access_blk(torch.randn(input_shape))
+        assert matrix_access_blk.att_mat.shape == (input_shape[0], num_heads, input_shape[1], input_shape[1])
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/tests/test_transformerblock.py b/tests/test_transformerblock.py
index f1a20b842c..a966dcbfdc 100644
--- a/tests/test_transformerblock.py
+++ b/tests/test_transformerblock.py
@@ -53,6 +53,30 @@ def test_ill_arg(self):
         with self.assertRaises(ValueError):
             TransformerBlock(hidden_size=622, num_heads=8, mlp_dim=3072, dropout_rate=0.4)
 
+    def test_access_attn_matrix(self):
+        # input format
+        hidden_size = 128
+        mlp_dim = 12
+        num_heads = 2
+        dropout_rate = 0
+        input_shape = (2, 256, hidden_size)
+
+        # returns an empty attention matrix by default
+        no_matrix_access_blk = TransformerBlock(
+            hidden_size=hidden_size, mlp_dim=mlp_dim, num_heads=num_heads, dropout_rate=dropout_rate
+        )
+        no_matrix_access_blk(torch.randn(input_shape))
+        assert isinstance(no_matrix_access_blk.attn.att_mat, torch.Tensor)
+        # number of elements is zero
+        assert no_matrix_access_blk.attn.att_mat.nelement() == 0
+
+        # the attention matrix is accessible with save_attn=True
+        matrix_access_blk = TransformerBlock(
+            hidden_size=hidden_size, mlp_dim=mlp_dim, num_heads=num_heads, dropout_rate=dropout_rate, save_attn=True
+        )
+        matrix_access_blk(torch.randn(input_shape))
+        assert matrix_access_blk.attn.att_mat.shape == (input_shape[0], num_heads, input_shape[1], input_shape[1])
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/tests/test_vit.py b/tests/test_vit.py
index 504c1ccebd..d193e6d222 100644
--- a/tests/test_vit.py
+++ b/tests/test_vit.py
@@ -150,6 +150,25 @@ def test_script(self, input_param, input_shape, _):
         test_data = torch.randn(input_shape)
         test_script_save(net, test_data)
 
+    def test_access_attn_matrix(self):
+        # input format
+        in_channels = 1
+        img_size = (96, 96, 96)
+        patch_size = (16, 16, 16)
+        in_shape = (1, in_channels, img_size[0], img_size[1], img_size[2])
+
+        # no data in the attention matrix by default
+        no_matrix_access_blk = ViT(in_channels=in_channels, img_size=img_size, patch_size=patch_size)
+        no_matrix_access_blk(torch.randn(in_shape))
+        assert isinstance(no_matrix_access_blk.blocks[0].attn.att_mat, torch.Tensor)
+        # number of elements is zero
+        assert no_matrix_access_blk.blocks[0].attn.att_mat.nelement() == 0
+
+        # the attention matrix is accessible with save_attn=True
+        matrix_access_blk = ViT(in_channels=in_channels, img_size=img_size, patch_size=patch_size, save_attn=True)
+        matrix_access_blk(torch.randn(in_shape))
+        assert matrix_access_blk.blocks[0].attn.att_mat.shape == (in_shape[0], 12, 216, 216)
+
 
 if __name__ == "__main__":
     unittest.main()
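
Usage note (not part of the patch): a minimal sketch of how the new `save_attn` flag is meant to be used, assuming the ViT defaults exercised above (12 heads, 16x16x16 patches over a 96x96x96 volume); it only prints the shape of the attention matrix cached by each block.

```python
import torch

from monai.networks.nets import ViT

# build a ViT whose self-attention blocks cache their attention weights
model = ViT(in_channels=1, img_size=(96, 96, 96), patch_size=(16, 16, 16), save_attn=True)

with torch.no_grad():
    model(torch.randn(1, 1, 96, 96, 96))  # one forward pass populates the matrices

# each SABlock stores a detached copy of softmax(q @ k^T * scale), taken before dropout;
# expected shape: (batch, num_heads, num_patches, num_patches) -> (1, 12, 216, 216) here
for i, block in enumerate(model.blocks):
    print(i, block.attn.att_mat.shape)
```

Because the matrix is detached before `drop_weights` is applied, reading it afterwards neither keeps the autograd graph alive nor reflects attention dropout.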