Description
Is your feature request related to a problem? Please describe.
I am training a vision transformer using MONAI and would like to carry out some interpretability analysis. However, the current model code does not save the self-attention matrix during training, and it is not straightforward to pass it from the self-attention block to the model output.
MONAI/monai/networks/blocks/selfattention.py
Lines 56 to 65 in 88fb0b1
def forward(self, x):
    output = self.input_rearrange(self.qkv(x))
    q, k, v = output[0], output[1], output[2]
    att_mat = (torch.einsum("blxd,blyd->blxy", q, k) * self.scale).softmax(dim=-1)
    att_mat = self.drop_weights(att_mat)
    x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v)
    x = self.out_rearrange(x)
    x = self.out_proj(x)
    x = self.drop_output(x)
    return x
Describe the solution you'd like
An option to return the att_mat from the self-attention block in the model forward pass (before it is multiplied with the value tensor), or to access it after the forward pass as a class attribute. A possible workaround sketch is below.
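As a minimal sketch of the second option (the class attribute), one interim workaround is to subclass the self-attention block and mirror the forward pass shown above, stashing the attention matrix before it is consumed. This assumes the block class in selfattention.py is SABlock and that the submodule names (qkv, input_rearrange, drop_weights, out_rearrange, out_proj, drop_output) are as in the snippet; the subclass name and the att_mat attribute are hypothetical.

```python
import torch

from monai.networks.blocks.selfattention import SABlock


class SABlockWithAttention(SABlock):
    """SABlock subclass that keeps the most recent attention matrix.

    The forward pass is copied from the snippet above; the only change is
    that the softmaxed attention matrix is saved to ``self.att_mat`` so it
    can be read out for interpretability analysis after each forward call.
    """

    def forward(self, x):
        output = self.input_rearrange(self.qkv(x))
        q, k, v = output[0], output[1], output[2]
        att_mat = (torch.einsum("blxd,blyd->blxy", q, k) * self.scale).softmax(dim=-1)
        # keep a detached copy so it can be inspected without growing the autograd graph
        self.att_mat = att_mat.detach()
        att_mat = self.drop_weights(att_mat)
        x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v)
        x = self.out_rearrange(x)
        x = self.out_proj(x)
        x = self.drop_output(x)
        return x
```

Swapping this subclass into the transformer blocks of a ViT (or, equivalently, adding an opt-in flag such as a save_attn option to SABlock itself) would expose the attention maps without changing the block's return signature, so existing checkpoints and downstream code would keep working.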