From c379dc92283b7040d65f4985093358dc4269384c Mon Sep 17 00:00:00 2001
From: Wenqi Li <831580+wyli@users.noreply.github.com>
Date: Tue, 4 Apr 2023 11:46:49 +0100
Subject: [PATCH] Revert "feat(SABlock): access to the attn matrix (#6271)"

This reverts commit 9e33ff25f34ebd11e026c49634743d9389564765.
---
 monai/networks/blocks/selfattention.py    | 23 +++++------------------
 monai/networks/blocks/transformerblock.py | 13 +++----------
 tests/test_selfattention.py               | 20 --------------------
 3 files changed, 8 insertions(+), 48 deletions(-)

diff --git a/monai/networks/blocks/selfattention.py b/monai/networks/blocks/selfattention.py
index 03a7bc7e08..519c8c7728 100644
--- a/monai/networks/blocks/selfattention.py
+++ b/monai/networks/blocks/selfattention.py
@@ -25,21 +25,13 @@ class SABlock(nn.Module):
     An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale "
     """
 
-    def __init__(
-        self,
-        hidden_size: int,
-        num_heads: int,
-        dropout_rate: float = 0.0,
-        qkv_bias: bool = False,
-        save_attn: bool = False,
-    ) -> None:
+    def __init__(self, hidden_size: int, num_heads: int, dropout_rate: float = 0.0, qkv_bias: bool = False) -> None:
         """
         Args:
-            hidden_size (int): dimension of hidden layer.
-            num_heads (int): number of attention heads.
-            dropout_rate (float, optional): faction of the input units to drop. Defaults to 0.0.
-            qkv_bias (bool, optional): bias term for the qkv linear layer. Defaults to False.
-            save_attn (bool, optional): to make accessible the attention matrix. Defaults to False.
+            hidden_size: dimension of hidden layer.
+            num_heads: number of attention heads.
+            dropout_rate: faction of the input units to drop.
+            qkv_bias: bias term for the qkv linear layer.
 
         """
 
@@ -60,16 +52,11 @@ def __init__(
         self.drop_weights = nn.Dropout(dropout_rate)
         self.head_dim = hidden_size // num_heads
         self.scale = self.head_dim**-0.5
-        self.save_attn = save_attn
 
     def forward(self, x):
         output = self.input_rearrange(self.qkv(x))
         q, k, v = output[0], output[1], output[2]
         att_mat = (torch.einsum("blxd,blyd->blxy", q, k) * self.scale).softmax(dim=-1)
-        if self.save_attn:
-            # no gradients and new tensor;
-            # https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html
-            self.att_mat = att_mat.detach()
         att_mat = self.drop_weights(att_mat)
         x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v)
         x = self.out_rearrange(x)
diff --git a/monai/networks/blocks/transformerblock.py b/monai/networks/blocks/transformerblock.py
index 30f2c2756a..3a4b507d69 100644
--- a/monai/networks/blocks/transformerblock.py
+++ b/monai/networks/blocks/transformerblock.py
@@ -24,13 +24,7 @@ class TransformerBlock(nn.Module):
     """
 
     def __init__(
-        self,
-        hidden_size: int,
-        mlp_dim: int,
-        num_heads: int,
-        dropout_rate: float = 0.0,
-        qkv_bias: bool = False,
-        save_attn: bool = False,
+        self, hidden_size: int, mlp_dim: int, num_heads: int, dropout_rate: float = 0.0, qkv_bias: bool = False
     ) -> None:
         """
         Args:
@@ -38,8 +32,7 @@ def __init__(
             mlp_dim: dimension of feedforward layer.
             num_heads: number of attention heads.
             dropout_rate: faction of the input units to drop.
-            qkv_bias: apply bias term for the qkv linear layer.
-            save_attn: to make accessible the attention matrix post training.
+            qkv_bias: apply bias term for the qkv linear layer
 
         """
 
@@ -53,7 +46,7 @@ def __init__(
 
         self.mlp = MLPBlock(hidden_size, mlp_dim, dropout_rate)
         self.norm1 = nn.LayerNorm(hidden_size)
-        self.attn = SABlock(hidden_size, num_heads, dropout_rate, qkv_bias, save_attn)
+        self.attn = SABlock(hidden_size, num_heads, dropout_rate, qkv_bias)
         self.norm2 = nn.LayerNorm(hidden_size)
 
     def forward(self, x):
diff --git a/tests/test_selfattention.py b/tests/test_selfattention.py
index a67f54f704..926ef7da55 100644
--- a/tests/test_selfattention.py
+++ b/tests/test_selfattention.py
@@ -52,26 +52,6 @@ def test_ill_arg(self):
         with self.assertRaises(ValueError):
             SABlock(hidden_size=620, num_heads=8, dropout_rate=0.4)
 
-    def test_access_attn_matrix(self):
-        # input format
-        hidden_size = 128
-        num_heads = 2
-        dropout_rate = 0
-        input_shape = (2, 256, hidden_size)
-
-        # be able to access the matrix
-        no_matrix_acess_blk = SABlock(hidden_size=hidden_size, num_heads=num_heads, dropout_rate=dropout_rate)
-        with self.assertRaises(AttributeError):
-            no_matrix_acess_blk(torch.randn(input_shape))
-            no_matrix_acess_blk.att_mat
-
-        # be not able to acess the attention matrix
-        matrix_acess_blk = SABlock(
-            hidden_size=hidden_size, num_heads=num_heads, dropout_rate=dropout_rate, save_attn=True
-        )
-        matrix_acess_blk(torch.randn(input_shape))
-        assert matrix_acess_blk.att_mat.shape == (input_shape[0], input_shape[0], input_shape[1], input_shape[1])
-
 
 if __name__ == "__main__":
     unittest.main()
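For reference, a minimal usage sketch of SABlock as it stands after this revert. The import path and tensor shapes follow the files touched above; the snippet is illustrative only and is not part of the patch:

    import torch

    from monai.networks.blocks.selfattention import SABlock

    # After the revert the constructor takes only hidden_size, num_heads,
    # dropout_rate and qkv_bias, and no attention matrix is kept on the module.
    block = SABlock(hidden_size=128, num_heads=2, dropout_rate=0.0, qkv_bias=False)
    x = torch.randn(2, 256, 128)  # (batch, sequence length, hidden_size)
    out = block(x)
    print(out.shape)  # torch.Size([2, 256, 128])
    # block.att_mat  # would now raise AttributeError, as exercised by the removed test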