diff --git a/monai/networks/blocks/selfattention.py b/monai/networks/blocks/selfattention.py
index 7c81c1704f..7b410b1a7c 100644
--- a/monai/networks/blocks/selfattention.py
+++ b/monai/networks/blocks/selfattention.py
@@ -32,6 +32,7 @@ def __init__(
         dropout_rate: float = 0.0,
         qkv_bias: bool = False,
         save_attn: bool = False,
+        dim_head: int | None = None,
     ) -> None:
         """
         Args:
@@ -40,6 +41,7 @@ def __init__(
             dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0.
             qkv_bias (bool, optional): bias term for the qkv linear layer. Defaults to False.
             save_attn (bool, optional): to make accessible the attention matrix. Defaults to False.
+            dim_head (int, optional): dimension of each head. Defaults to hidden_size // num_heads.
 
         """
@@ -52,14 +54,16 @@ def __init__(
             raise ValueError("hidden size should be divisible by num_heads.")
 
         self.num_heads = num_heads
-        self.out_proj = nn.Linear(hidden_size, hidden_size)
-        self.qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias)
+        self.dim_head = hidden_size // num_heads if dim_head is None else dim_head
+        self.inner_dim = self.dim_head * num_heads
+
+        self.out_proj = nn.Linear(self.inner_dim, hidden_size)
+        self.qkv = nn.Linear(hidden_size, self.inner_dim * 3, bias=qkv_bias)
         self.input_rearrange = Rearrange("b h (qkv l d) -> qkv b l h d", qkv=3, l=num_heads)
         self.out_rearrange = Rearrange("b h l d -> b l (h d)")
         self.drop_output = nn.Dropout(dropout_rate)
         self.drop_weights = nn.Dropout(dropout_rate)
-        self.head_dim = hidden_size // num_heads
-        self.scale = self.head_dim**-0.5
+        self.scale = self.dim_head**-0.5
         self.save_attn = save_attn
         self.att_mat = torch.Tensor()
 
diff --git a/tests/test_selfattention.py b/tests/test_selfattention.py
index b8be4fd1b6..0ebed84159 100644
--- a/tests/test_selfattention.py
+++ b/tests/test_selfattention.py
@@ -74,6 +74,40 @@ def test_access_attn_matrix(self):
         matrix_acess_blk(torch.randn(input_shape))
         assert matrix_acess_blk.att_mat.shape == (input_shape[0], input_shape[0], input_shape[1], input_shape[1])
 
+    def test_number_of_parameters(self):
+
+        def count_sablock_params(*args, **kwargs):
+            """Count the number of parameters in a SABlock."""
+            sablock = SABlock(*args, **kwargs)
+            return sum([x.numel() for x in sablock.parameters() if x.requires_grad])
+
+        hidden_size = 128
+        num_heads = 8
+        default_dim_head = hidden_size // num_heads
+
+        # Default dim_head is hidden_size // num_heads
+        nparams_default = count_sablock_params(hidden_size=hidden_size, num_heads=num_heads)
+        nparams_like_default = count_sablock_params(
+            hidden_size=hidden_size, num_heads=num_heads, dim_head=default_dim_head
+        )
+        self.assertEqual(nparams_default, nparams_like_default)
+
+        # Increasing dim_head should increase the number of parameters
+        nparams_custom_large = count_sablock_params(
+            hidden_size=hidden_size, num_heads=num_heads, dim_head=default_dim_head * 2
+        )
+        self.assertGreater(nparams_custom_large, nparams_default)
+
+        # Decreasing dim_head should decrease the number of parameters
+        nparams_custom_small = count_sablock_params(
+            hidden_size=hidden_size, num_heads=num_heads, dim_head=default_dim_head // 2
+        )
+        self.assertGreater(nparams_default, nparams_custom_small)
+
+        # Increasing the number of heads with the default behaviour should not change the number of params.
+        nparams_default_more_heads = count_sablock_params(hidden_size=hidden_size, num_heads=num_heads * 2)
+        self.assertEqual(nparams_default, nparams_default_more_heads)
+
 
 if __name__ == "__main__":
     unittest.main()
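
A minimal usage sketch of the new `dim_head` argument, assuming `SABlock` is imported from the patched module and that its `forward` takes a single `(batch, sequence, hidden_size)` tensor as in the current MONAI code. The key point of the change is that `dim_head` widens (or narrows) the internal q/k/v projections to `num_heads * dim_head`, while `out_proj` still maps back to `hidden_size`, so the output shape is unchanged:

```python
import torch

from monai.networks.blocks.selfattention import SABlock

hidden_size, num_heads = 128, 8

# Default behaviour: each head gets hidden_size // num_heads == 16 channels.
default_blk = SABlock(hidden_size=hidden_size, num_heads=num_heads)

# Custom head width: the qkv projection now produces num_heads * 32 == 256
# channels internally, but out_proj still returns hidden_size channels.
wide_blk = SABlock(hidden_size=hidden_size, num_heads=num_heads, dim_head=32)

x = torch.randn(2, 16, hidden_size)  # (batch, sequence length, hidden_size)
print(default_blk(x).shape)  # torch.Size([2, 16, 128])
print(wide_blk(x).shape)     # torch.Size([2, 16, 128])
```

This also explains the last assertion in the new test: with the default `dim_head`, doubling `num_heads` halves the per-head width, so the total projection width and hence the parameter count stay the same.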