Access self-attention matrix of vision transformer #6032

@sophmrtn

Description

Is your feature request related to a problem? Please describe.
I am training a vision transformer using MONAI and would like to carry out some interpretability analysis. However, the current model code does not save the self-attention matrix during training, and it is not straightforward to pass it from the self-attention block to the model output. For reference, the self-attention block's forward pass looks like this:

def forward(self, x):
    # project the input to q, k, v and rearrange into per-head tensors
    output = self.input_rearrange(self.qkv(x))
    q, k, v = output[0], output[1], output[2]
    # scaled dot-product attention weights, softmaxed over the key dimension
    att_mat = (torch.einsum("blxd,blyd->blxy", q, k) * self.scale).softmax(dim=-1)
    att_mat = self.drop_weights(att_mat)
    # weighted sum of the values, then merge heads and project back out
    x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v)
    x = self.out_rearrange(x)
    x = self.out_proj(x)
    x = self.drop_output(x)
    return x
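
In the meantime, one possible workaround (a sketch, not an official MONAI mechanism): since att_mat is exactly the tensor fed through self.drop_weights, a standard PyTorch forward hook on that submodule can capture it without modifying the block, assuming drop_weights is an nn.Module such as nn.Dropout. The model.blocks[0].attn path below is an assumption about how the blocks are exposed on the model.

import torch

# Capture attention matrices via a forward hook on the dropout submodule;
# `model` and `images` are placeholders for the trained ViT and an input batch.
attention_maps = []

def save_attention(module, inputs, output):
    # `output` is att_mat after dropout (identity in eval mode),
    # shape (batch, heads, seq, seq); detach so no autograd graph is retained
    attention_maps.append(output.detach().cpu())

hook = model.blocks[0].attn.drop_weights.register_forward_hook(save_attention)
model.eval()
with torch.no_grad():
    model(images)
hook.remove()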

Describe the solution you'd like
An option to return att_mat from the self-attention block in the model forward pass (before the matrix multiplication with the values), or to access it after training as a class attribute.
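
A minimal sketch of the attribute variant, assuming a hypothetical save_attn flag set in the block's constructor (not an existing MONAI option):

def forward(self, x):
    output = self.input_rearrange(self.qkv(x))
    q, k, v = output[0], output[1], output[2]
    att_mat = (torch.einsum("blxd,blyd->blxy", q, k) * self.scale).softmax(dim=-1)
    if self.save_attn:
        # hypothetical flag: keep a copy for later inspection,
        # detached so it does not extend the autograd graph
        self.att_mat = att_mat.detach()
    att_mat = self.drop_weights(att_mat)
    x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v)
    x = self.out_rearrange(x)
    x = self.out_proj(x)
    x = self.drop_output(x)
    return x

After a forward pass, block.att_mat could then be read directly for interpretability analysis.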
