diff --git a/monai/networks/blocks/selfattention.py b/monai/networks/blocks/selfattention.py
index 519c8c7728..03a7bc7e08 100644
--- a/monai/networks/blocks/selfattention.py
+++ b/monai/networks/blocks/selfattention.py
@@ -25,13 +25,21 @@ class SABlock(nn.Module):
     An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale "
     """
 
-    def __init__(self, hidden_size: int, num_heads: int, dropout_rate: float = 0.0, qkv_bias: bool = False) -> None:
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        dropout_rate: float = 0.0,
+        qkv_bias: bool = False,
+        save_attn: bool = False,
+    ) -> None:
         """
         Args:
-            hidden_size: dimension of hidden layer.
-            num_heads: number of attention heads.
-            dropout_rate: faction of the input units to drop.
-            qkv_bias: bias term for the qkv linear layer.
+            hidden_size (int): dimension of hidden layer.
+            num_heads (int): number of attention heads.
+            dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0.
+            qkv_bias (bool, optional): bias term for the qkv linear layer. Defaults to False.
+            save_attn (bool, optional): whether to store the attention matrix in ``self.att_mat``. Defaults to False.
 
         """
 
@@ -52,11 +60,16 @@ def __init__(self, hidden_size: int, num_heads: int, dropout_rate: float = 0.0,
         self.drop_weights = nn.Dropout(dropout_rate)
         self.head_dim = hidden_size // num_heads
         self.scale = self.head_dim**-0.5
+        self.save_attn = save_attn
 
     def forward(self, x):
         output = self.input_rearrange(self.qkv(x))
         q, k, v = output[0], output[1], output[2]
         att_mat = (torch.einsum("blxd,blyd->blxy", q, k) * self.scale).softmax(dim=-1)
+        if self.save_attn:
+            # detach() returns a new tensor that does not track gradients;
+            # https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html
+            self.att_mat = att_mat.detach()
         att_mat = self.drop_weights(att_mat)
         x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v)
         x = self.out_rearrange(x)
diff --git a/monai/networks/blocks/transformerblock.py b/monai/networks/blocks/transformerblock.py
index 3a4b507d69..30f2c2756a 100644
--- a/monai/networks/blocks/transformerblock.py
+++ b/monai/networks/blocks/transformerblock.py
@@ -24,7 +24,13 @@ class TransformerBlock(nn.Module):
     """
 
     def __init__(
-        self, hidden_size: int, mlp_dim: int, num_heads: int, dropout_rate: float = 0.0, qkv_bias: bool = False
+        self,
+        hidden_size: int,
+        mlp_dim: int,
+        num_heads: int,
+        dropout_rate: float = 0.0,
+        qkv_bias: bool = False,
+        save_attn: bool = False,
     ) -> None:
         """
         Args:
@@ -32,7 +38,8 @@ def __init__(
             mlp_dim: dimension of feedforward layer.
             num_heads: number of attention heads.
             dropout_rate: faction of the input units to drop.
-            qkv_bias: apply bias term for the qkv linear layer
+            qkv_bias: apply bias term for the qkv linear layer.
+            save_attn: whether to store the attention matrix after the forward pass.
 
         """
 
@@ -46,7 +53,7 @@ def __init__(
 
         self.mlp = MLPBlock(hidden_size, mlp_dim, dropout_rate)
         self.norm1 = nn.LayerNorm(hidden_size)
-        self.attn = SABlock(hidden_size, num_heads, dropout_rate, qkv_bias)
+        self.attn = SABlock(hidden_size, num_heads, dropout_rate, qkv_bias, save_attn)
         self.norm2 = nn.LayerNorm(hidden_size)
 
     def forward(self, x):
diff --git a/tests/test_selfattention.py b/tests/test_selfattention.py
index 926ef7da55..a67f54f704 100644
--- a/tests/test_selfattention.py
+++ b/tests/test_selfattention.py
@@ -52,6 +52,26 @@ def test_ill_arg(self):
         with self.assertRaises(ValueError):
             SABlock(hidden_size=620, num_heads=8, dropout_rate=0.4)
 
+    def test_access_attn_matrix(self):
+        # test configuration
+        hidden_size = 128
+        num_heads = 2
+        dropout_rate = 0
+        input_shape = (2, 256, hidden_size)
+
+        # the attention matrix is not stored when save_attn is False (the default)
+        no_matrix_access_blk = SABlock(hidden_size=hidden_size, num_heads=num_heads, dropout_rate=dropout_rate)
+        no_matrix_access_blk(torch.randn(input_shape))
+        with self.assertRaises(AttributeError):
+            no_matrix_access_blk.att_mat
+
+        # the attention matrix is accessible when save_attn is True
+        matrix_access_blk = SABlock(
+            hidden_size=hidden_size, num_heads=num_heads, dropout_rate=dropout_rate, save_attn=True
+        )
+        matrix_access_blk(torch.randn(input_shape))
+        assert matrix_access_blk.att_mat.shape == (input_shape[0], num_heads, input_shape[1], input_shape[1])
+
 
 if __name__ == "__main__":
     unittest.main()