From 1a11528b20b7a13c98176817fa145e31896d756b Mon Sep 17 00:00:00 2001 From: vgrau98 Date: Fri, 29 Dec 2023 18:53:09 +0100 Subject: [PATCH 01/12] attention-rel-pos-embedd Signed-off-by: vgrau98 --- monai/networks/blocks/selfattention.py | 116 ++++++++++++++++++++++++- 1 file changed, 114 insertions(+), 2 deletions(-) diff --git a/monai/networks/blocks/selfattention.py b/monai/networks/blocks/selfattention.py index 7c81c1704f..9c62e78506 100644 --- a/monai/networks/blocks/selfattention.py +++ b/monai/networks/blocks/selfattention.py @@ -11,8 +11,11 @@ from __future__ import annotations +from typing import Optional, Tuple + import torch import torch.nn as nn +import torch.nn.functional as F from monai.utils import optional_import @@ -23,6 +26,7 @@ class SABlock(nn.Module): """ A self-attention block, based on: "Dosovitskiy et al., An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale " + One can setup relative positional embedding as described in """ def __init__( @@ -32,6 +36,8 @@ def __init__( dropout_rate: float = 0.0, qkv_bias: bool = False, save_attn: bool = False, + use_rel_pos: bool = False, + input_size: Optional[Tuple[int, int]] = None, ) -> None: """ Args: @@ -39,6 +45,9 @@ def __init__( num_heads (int): number of attention heads. dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0. qkv_bias (bool, optional): bias term for the qkv linear layer. Defaults to False. + rel_pos (bool): If True, add relative positional embeddings to the attention map. Only support 2D inputs. + input_size (tuple(int, int) or None): Input resolution for calculating the relative + positional parameter size. save_attn (bool, optional): to make accessible the attention matrix. Defaults to False. """ @@ -62,11 +71,43 @@ def __init__( self.scale = self.head_dim**-0.5 self.save_attn = save_attn self.att_mat = torch.Tensor() + self.use_rel_pos = use_rel_pos + self.input_size = input_size + + if self.use_rel_pos: + assert input_size is not None, "Input size must be provided if using relative positional encoding." + assert len(input_size) == 2, "Relative positional embedding is only supported for 2D" + # initialize relative positional embeddings + self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, self.head_dim)) + self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, self.head_dim)) + + def forward(self, x: torch.Tensor): + """ + Args: + x (torch.Tensor): input tensor. B x (s_dim_1 * ... * s_dim_n) x C - def forward(self, x): + Return: + torch.Tensor: B x (s_dim_1 * ... 
* s_dim_n) x C + """ output = self.input_rearrange(self.qkv(x)) q, k, v = output[0], output[1], output[2] - att_mat = (torch.einsum("blxd,blyd->blxy", q, k) * self.scale).softmax(dim=-1) + att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale + + if self.use_rel_pos: + batch = x.shape[0] + h, w = self.input_size if self.input_size is not None else (0, 0) + att_mat = add_decomposed_rel_pos( + att_mat.view(batch * self.num_heads, h * w, h * w), + q.view(batch * self.num_heads, h * w, -1), + self.rel_pos_h, + self.rel_pos_w, + (h, w), + (h, w), + ) + att_mat = att_mat.reshape(batch, self.num_heads, h * w, h * w) + + att_mat = att_mat.softmax(dim=-1) + if self.save_attn: # no gradients and new tensor; # https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html @@ -78,3 +119,74 @@ def forward(self, x): x = self.out_proj(x) x = self.drop_output(x) return x + + +def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: + """ + Get relative positional embeddings according to the relative positions of + query and key sizes. + Args: + q_size (int): size of query q. + k_size (int): size of key k. + rel_pos (Tensor): relative position embeddings (L, C). + + Returns: + Extracted positional embeddings according to relative positions. + """ + rel_pos_resized: torch.Tensor = torch.Tensor() + max_rel_dist = int(2 * max(q_size, k_size) - 1) + # Interpolate rel pos if needed. + if rel_pos.shape[0] != max_rel_dist: + # Interpolate rel pos. + rel_pos_resized = F.interpolate( + rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), size=max_rel_dist, mode="linear" + ) + rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) + else: + rel_pos_resized = rel_pos + + # Scale the coords with short length if shapes for q and k are different. + q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) + k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) + relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) + + return rel_pos_resized[relative_coords.long()] + + +def add_decomposed_rel_pos( + attn: torch.Tensor, + q: torch.Tensor, + rel_pos_h: torch.Tensor, + rel_pos_w: torch.Tensor, + q_size: Tuple[int, int], + k_size: Tuple[int, int], +) -> torch.Tensor: + """ + Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. + https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950 + Args: + attn (Tensor): attention map. + q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C). + rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis. + rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis. + q_size (Tuple): spatial sequence size of query q with (q_h, q_w). + k_size (Tuple): spatial sequence size of key k with (k_h, k_w). + + Returns: + attn (Tensor): attention map with added relative positional embeddings. 
+ """ + q_h, q_w = q_size + k_h, k_w = k_size + rh = get_rel_pos(q_h, k_h, rel_pos_h) + rw = get_rel_pos(q_w, k_w, rel_pos_w) + + batch, _, dim = q.shape + r_q = q.reshape(batch, q_h, q_w, dim) + rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, rh) + rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, rw) + + attn = (attn.view(batch, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]).view( + batch, q_h * q_w, k_h * k_w + ) + + return attn From cdaaee9014c12641153f4c91e8d3dc325c3cf768 Mon Sep 17 00:00:00 2001 From: vgrau98 Date: Mon, 1 Jan 2024 19:52:27 +0100 Subject: [PATCH 02/12] feat: 3D decomposed relative positional embeddings Signed-off-by: vgrau98 --- monai/networks/blocks/selfattention.py | 92 +++++++++++++++----------- 1 file changed, 53 insertions(+), 39 deletions(-) diff --git a/monai/networks/blocks/selfattention.py b/monai/networks/blocks/selfattention.py index 9c62e78506..cc96926aa8 100644 --- a/monai/networks/blocks/selfattention.py +++ b/monai/networks/blocks/selfattention.py @@ -36,8 +36,8 @@ def __init__( dropout_rate: float = 0.0, qkv_bias: bool = False, save_attn: bool = False, - use_rel_pos: bool = False, - input_size: Optional[Tuple[int, int]] = None, + use_rel_pos: Optional[str] = None, + input_size: Optional[Tuple] = None, ) -> None: """ Args: @@ -45,8 +45,9 @@ def __init__( num_heads (int): number of attention heads. dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0. qkv_bias (bool, optional): bias term for the qkv linear layer. Defaults to False. - rel_pos (bool): If True, add relative positional embeddings to the attention map. Only support 2D inputs. - input_size (tuple(int, int) or None): Input resolution for calculating the relative + rel_pos (str, optional): Add relative positional embeddings to the attention map. + For now only "decomposed" is supported (see https://arxiv.org/abs/2112.01526). 2D and 3D are supported. + input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative positional parameter size. save_attn (bool, optional): to make accessible the attention matrix. Defaults to False. @@ -74,12 +75,11 @@ def __init__( self.use_rel_pos = use_rel_pos self.input_size = input_size - if self.use_rel_pos: + if self.use_rel_pos == "decomposed": assert input_size is not None, "Input size must be provided if using relative positional encoding." 
- assert len(input_size) == 2, "Relative positional embedding is only supported for 2D" - # initialize relative positional embeddings - self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, self.head_dim)) - self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, self.head_dim)) + self.rel_pos_arr = nn.ParameterList( + [nn.Parameter(torch.zeros(2 * dim_input_size - 1, self.head_dim)) for dim_input_size in input_size] + ) def forward(self, x: torch.Tensor): """ @@ -93,18 +93,18 @@ def forward(self, x: torch.Tensor): q, k, v = output[0], output[1], output[2] att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale - if self.use_rel_pos: + if self.use_rel_pos == "decomposed": batch = x.shape[0] - h, w = self.input_size if self.input_size is not None else (0, 0) + h, w = self.input_size[:2] if self.input_size is not None else (0, 0) + d = self.input_size[2] if self.input_size is not None and len(self.input_size) > 2 else 1 att_mat = add_decomposed_rel_pos( - att_mat.view(batch * self.num_heads, h * w, h * w), - q.view(batch * self.num_heads, h * w, -1), - self.rel_pos_h, - self.rel_pos_w, - (h, w), - (h, w), + att_mat.view(batch * self.num_heads, h * w * d, h * w * d), + q.view(batch * self.num_heads, h * w * d, -1), + self.rel_pos_arr, + (h, w) if d == 1 else (h, w, d), + (h, w) if d == 1 else (h, w, d), ) - att_mat = att_mat.reshape(batch, self.num_heads, h * w, h * w) + att_mat = att_mat.reshape(batch, self.num_heads, h * w * d, h * w * d) att_mat = att_mat.softmax(dim=-1) @@ -154,39 +154,53 @@ def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor def add_decomposed_rel_pos( - attn: torch.Tensor, - q: torch.Tensor, - rel_pos_h: torch.Tensor, - rel_pos_w: torch.Tensor, - q_size: Tuple[int, int], - k_size: Tuple[int, int], + attn: torch.Tensor, q: torch.Tensor, rel_pos_lst: nn.ParameterList, q_size: Tuple, k_size: Tuple ) -> torch.Tensor: """ Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950 + Only 2D and 3D are supported. Args: attn (Tensor): attention map. - q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C). - rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis. - rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis. - q_size (Tuple): spatial sequence size of query q with (q_h, q_w). - k_size (Tuple): spatial sequence size of key k with (k_h, k_w). + q (Tensor): query q in the attention layer with shape (B, s_dim_1 * ... * s_dim_n, C). + rel_pos_lst (ParameterList): relative position embeddings for each axis: rel_pos_lst[n] for nth axis. + q_size (Tuple): spatial sequence size of query q with (q_dim_1, ..., q_dim_n). + k_size (Tuple): spatial sequence size of key k with (k_dim_1, ..., k_dim_n). Returns: attn (Tensor): attention map with added relative positional embeddings. 
""" - q_h, q_w = q_size - k_h, k_w = k_size - rh = get_rel_pos(q_h, k_h, rel_pos_h) - rw = get_rel_pos(q_w, k_w, rel_pos_w) + rh = get_rel_pos(q_size[0], k_size[0], rel_pos_lst[0]) + rw = get_rel_pos(q_size[1], k_size[1], rel_pos_lst[1]) batch, _, dim = q.shape - r_q = q.reshape(batch, q_h, q_w, dim) - rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, rh) - rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, rw) - attn = (attn.view(batch, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]).view( - batch, q_h * q_w, k_h * k_w - ) + if len(rel_pos_lst) == 2: + q_h, q_w = q_size[:2] + k_h, k_w = k_size[:2] + r_q = q.reshape(batch, q_h, q_w, dim) + rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, rh) + rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, rw) + + attn = (attn.view(batch, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]).view( + batch, q_h * q_w, k_h * k_w + ) + elif len(rel_pos_lst) == 3: + q_h, q_w, q_d = q_size[:3] + k_h, k_w, k_d = k_size[:3] + + rd = get_rel_pos(q_d, k_d, rel_pos_lst[2]) + + r_q = q.reshape(batch, q_h, q_w, q_d, dim) + rel_h = torch.einsum("bhwdc,hkc->bhwdk", r_q, rh) + rel_w = torch.einsum("bhwdc,wkc->bhwdk", r_q, rw) + rel_d = torch.einsum("bhwdc,wkc->bhwdk", r_q, rd) + + attn = ( + attn.view(batch, q_h, q_w, q_d, k_h, k_w, k_d) + + rel_h[:, :, :, :, None, None] + + rel_w[:, :, :, None, :, None] + + rel_d[:, :, :, None, None, :] + ).view(batch, q_h * q_w * q_d, k_h * k_w * k_d) return attn From bc66d68465cbeca4dfac316d196b0853636ec5f0 Mon Sep 17 00:00:00 2001 From: vgrau98 Date: Sun, 7 Jan 2024 14:36:30 +0100 Subject: [PATCH 03/12] refacto Signed-off-by: vgrau98 --- monai/networks/blocks/attention_utils.py | 101 ++++++++++++++++++ monai/networks/blocks/rel_pos_embedding.py | 61 +++++++++++ monai/networks/blocks/selfattention.py | 117 ++------------------- monai/networks/layers/factories.py | 13 ++- monai/networks/layers/utils.py | 17 ++- 5 files changed, 200 insertions(+), 109 deletions(-) create mode 100644 monai/networks/blocks/attention_utils.py create mode 100644 monai/networks/blocks/rel_pos_embedding.py diff --git a/monai/networks/blocks/attention_utils.py b/monai/networks/blocks/attention_utils.py new file mode 100644 index 0000000000..a9fd6c89ae --- /dev/null +++ b/monai/networks/blocks/attention_utils.py @@ -0,0 +1,101 @@ +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Tuple + +import torch +import torch.nn.functional as F +from torch import nn + + +def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: + """ + Get relative positional embeddings according to the relative positions of + query and key sizes. + Args: + q_size (int): size of query q. + k_size (int): size of key k. + rel_pos (Tensor): relative position embeddings (L, C). + + Returns: + Extracted positional embeddings according to relative positions. + """ + rel_pos_resized: torch.Tensor = torch.Tensor() + max_rel_dist = int(2 * max(q_size, k_size) - 1) + # Interpolate rel pos if needed. + if rel_pos.shape[0] != max_rel_dist: + # Interpolate rel pos. 
+        rel_pos_resized = F.interpolate(
+            rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), size=max_rel_dist, mode="linear"
+        )
+        rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
+    else:
+        rel_pos_resized = rel_pos
+
+    # Scale the coords with short length if shapes for q and k are different.
+    q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
+    k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
+    relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
+
+    return rel_pos_resized[relative_coords.long()]
+
+
+def add_decomposed_rel_pos(
+    attn: torch.Tensor, q: torch.Tensor, rel_pos_lst: nn.ParameterList, q_size: Tuple, k_size: Tuple
+) -> torch.Tensor:
+    """
+    Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
+    https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
+    Only 2D and 3D are supported.
+    Args:
+        attn (Tensor): attention map.
+        q (Tensor): query q in the attention layer with shape (B, s_dim_1 * ... * s_dim_n, C).
+        rel_pos_lst (ParameterList): relative position embeddings for each axis: rel_pos_lst[n] for nth axis.
+        q_size (Tuple): spatial sequence size of query q with (q_dim_1, ..., q_dim_n).
+        k_size (Tuple): spatial sequence size of key k with (k_dim_1, ..., k_dim_n).
+
+    Returns:
+        attn (Tensor): attention map with added relative positional embeddings.
+    """
+    rh = get_rel_pos(q_size[0], k_size[0], rel_pos_lst[0])
+    rw = get_rel_pos(q_size[1], k_size[1], rel_pos_lst[1])
+
+    batch, _, dim = q.shape
+
+    if len(rel_pos_lst) == 2:
+        q_h, q_w = q_size[:2]
+        k_h, k_w = k_size[:2]
+        r_q = q.reshape(batch, q_h, q_w, dim)
+        rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, rh)
+        rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, rw)
+
+        attn = (attn.view(batch, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]).view(
+            batch, q_h * q_w, k_h * k_w
+        )
+    elif len(rel_pos_lst) == 3:
+        q_h, q_w, q_d = q_size[:3]
+        k_h, k_w, k_d = k_size[:3]
+
+        rd = get_rel_pos(q_d, k_d, rel_pos_lst[2])
+
+        r_q = q.reshape(batch, q_h, q_w, q_d, dim)
+        rel_h = torch.einsum("bhwdc,hkc->bhwdk", r_q, rh)
+        rel_w = torch.einsum("bhwdc,wkc->bhwdk", r_q, rw)
+        rel_d = torch.einsum("bhwdc,dkc->bhwdk", r_q, rd)
+
+        attn = (
+            attn.view(batch, q_h, q_w, q_d, k_h, k_w, k_d)
+            + rel_h[:, :, :, :, :, None, None]
+            + rel_w[:, :, :, :, None, :, None]
+            + rel_d[:, :, :, :, None, None, :]
+        ).view(batch, q_h * q_w * q_d, k_h * k_w * k_d)
+
+    return attn
diff --git a/monai/networks/blocks/rel_pos_embedding.py b/monai/networks/blocks/rel_pos_embedding.py
new file mode 100644
index 0000000000..a96fcc9e21
--- /dev/null
+++ b/monai/networks/blocks/rel_pos_embedding.py
@@ -0,0 +1,61 @@
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +from __future__ import annotations + +from typing import Tuple + +import torch +from torch import nn + +from monai.networks.blocks.attention_utils import add_decomposed_rel_pos + + +class RelativePosEmbedding(nn.Module): + def __init__( + self, + ) -> None: + super().__init__() + + def forward(self, x: torch.Tensor, att_mat: torch.Tensor) -> torch.Tensor: + ... + + +class DecomposedRelativePosEmbedding(RelativePosEmbedding): + def __init__(self, s_input_dims: Tuple, c_dim: int, num_heads: int) -> None: + """ + Args: + s_input_dims (Tuple): input spatial dimension. (H, W) or (H, W, D) + c_dim (int): channel dimension + num_heads(int): number of attentio heads + """ + super().__init__() + self.s_input_dims = s_input_dims + self.c_dim = c_dim + self.num_heads = num_heads + self.rel_pos_arr = nn.ParameterList( + [nn.Parameter(torch.zeros(2 * dim_input_size - 1, c_dim)) for dim_input_size in s_input_dims] + ) + + def forward(self, x: torch.Tensor, att_mat: torch.Tensor, q: torch.Tensor) -> torch.Tensor: + """""" + batch = x.shape[0] + h, w = self.s_input_dims[:2] if self.s_input_dims is not None else (0, 0) + d = self.s_input_dims[2] if self.s_input_dims is not None and len(self.s_input_dims) > 2 else 1 + + att_mat = add_decomposed_rel_pos( + att_mat.contiguous().view(batch * self.num_heads, h * w * d, h * w * d), + q.contiguous().view(batch * self.num_heads, h * w * d, -1), + self.rel_pos_arr, + (h, w) if d == 1 else (h, w, d), + (h, w) if d == 1 else (h, w, d), + ) + + att_mat = att_mat.reshape(batch, self.num_heads, h * w * d, h * w * d) + return att_mat diff --git a/monai/networks/blocks/selfattention.py b/monai/networks/blocks/selfattention.py index cc96926aa8..3bef24b4e8 100644 --- a/monai/networks/blocks/selfattention.py +++ b/monai/networks/blocks/selfattention.py @@ -15,8 +15,8 @@ import torch import torch.nn as nn -import torch.nn.functional as F +from monai.networks.layers.utils import get_rel_pos_embedding_layer from monai.utils import optional_import Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange") @@ -36,7 +36,7 @@ def __init__( dropout_rate: float = 0.0, qkv_bias: bool = False, save_attn: bool = False, - use_rel_pos: Optional[str] = None, + rel_pos_embedding: Optional[str] = None, input_size: Optional[Tuple] = None, ) -> None: """ @@ -45,7 +45,7 @@ def __init__( num_heads (int): number of attention heads. dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0. qkv_bias (bool, optional): bias term for the qkv linear layer. Defaults to False. - rel_pos (str, optional): Add relative positional embeddings to the attention map. + rel_pos_embedding (str, optional): Add relative positional embeddings to the attention map. For now only "decomposed" is supported (see https://arxiv.org/abs/2112.01526). 2D and 3D are supported. input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative positional parameter size. @@ -72,15 +72,13 @@ def __init__( self.scale = self.head_dim**-0.5 self.save_attn = save_attn self.att_mat = torch.Tensor() - self.use_rel_pos = use_rel_pos + self.rel_positional_embedding = ( + get_rel_pos_embedding_layer(rel_pos_embedding, input_size, self.head_dim, self.num_heads) + if rel_pos_embedding is not None + else None + ) self.input_size = input_size - if self.use_rel_pos == "decomposed": - assert input_size is not None, "Input size must be provided if using relative positional encoding." 
- self.rel_pos_arr = nn.ParameterList( - [nn.Parameter(torch.zeros(2 * dim_input_size - 1, self.head_dim)) for dim_input_size in input_size] - ) - def forward(self, x: torch.Tensor): """ Args: @@ -93,18 +91,8 @@ def forward(self, x: torch.Tensor): q, k, v = output[0], output[1], output[2] att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale - if self.use_rel_pos == "decomposed": - batch = x.shape[0] - h, w = self.input_size[:2] if self.input_size is not None else (0, 0) - d = self.input_size[2] if self.input_size is not None and len(self.input_size) > 2 else 1 - att_mat = add_decomposed_rel_pos( - att_mat.view(batch * self.num_heads, h * w * d, h * w * d), - q.view(batch * self.num_heads, h * w * d, -1), - self.rel_pos_arr, - (h, w) if d == 1 else (h, w, d), - (h, w) if d == 1 else (h, w, d), - ) - att_mat = att_mat.reshape(batch, self.num_heads, h * w * d, h * w * d) + # apply relative positional embedding if defined + att_mat = self.rel_positional_embedding(x, att_mat, q) if self.rel_positional_embedding is not None else att_mat att_mat = att_mat.softmax(dim=-1) @@ -119,88 +107,3 @@ def forward(self, x: torch.Tensor): x = self.out_proj(x) x = self.drop_output(x) return x - - -def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: - """ - Get relative positional embeddings according to the relative positions of - query and key sizes. - Args: - q_size (int): size of query q. - k_size (int): size of key k. - rel_pos (Tensor): relative position embeddings (L, C). - - Returns: - Extracted positional embeddings according to relative positions. - """ - rel_pos_resized: torch.Tensor = torch.Tensor() - max_rel_dist = int(2 * max(q_size, k_size) - 1) - # Interpolate rel pos if needed. - if rel_pos.shape[0] != max_rel_dist: - # Interpolate rel pos. - rel_pos_resized = F.interpolate( - rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), size=max_rel_dist, mode="linear" - ) - rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) - else: - rel_pos_resized = rel_pos - - # Scale the coords with short length if shapes for q and k are different. - q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) - k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) - relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) - - return rel_pos_resized[relative_coords.long()] - - -def add_decomposed_rel_pos( - attn: torch.Tensor, q: torch.Tensor, rel_pos_lst: nn.ParameterList, q_size: Tuple, k_size: Tuple -) -> torch.Tensor: - """ - Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. - https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950 - Only 2D and 3D are supported. - Args: - attn (Tensor): attention map. - q (Tensor): query q in the attention layer with shape (B, s_dim_1 * ... * s_dim_n, C). - rel_pos_lst (ParameterList): relative position embeddings for each axis: rel_pos_lst[n] for nth axis. - q_size (Tuple): spatial sequence size of query q with (q_dim_1, ..., q_dim_n). - k_size (Tuple): spatial sequence size of key k with (k_dim_1, ..., k_dim_n). - - Returns: - attn (Tensor): attention map with added relative positional embeddings. 
- """ - rh = get_rel_pos(q_size[0], k_size[0], rel_pos_lst[0]) - rw = get_rel_pos(q_size[1], k_size[1], rel_pos_lst[1]) - - batch, _, dim = q.shape - - if len(rel_pos_lst) == 2: - q_h, q_w = q_size[:2] - k_h, k_w = k_size[:2] - r_q = q.reshape(batch, q_h, q_w, dim) - rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, rh) - rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, rw) - - attn = (attn.view(batch, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]).view( - batch, q_h * q_w, k_h * k_w - ) - elif len(rel_pos_lst) == 3: - q_h, q_w, q_d = q_size[:3] - k_h, k_w, k_d = k_size[:3] - - rd = get_rel_pos(q_d, k_d, rel_pos_lst[2]) - - r_q = q.reshape(batch, q_h, q_w, q_d, dim) - rel_h = torch.einsum("bhwdc,hkc->bhwdk", r_q, rh) - rel_w = torch.einsum("bhwdc,wkc->bhwdk", r_q, rw) - rel_d = torch.einsum("bhwdc,wkc->bhwdk", r_q, rd) - - attn = ( - attn.view(batch, q_h, q_w, q_d, k_h, k_w, k_d) - + rel_h[:, :, :, :, None, None] - + rel_w[:, :, :, None, :, None] - + rel_d[:, :, :, None, None, :] - ).view(batch, q_h * q_w * q_d, k_h * k_w * k_d) - - return attn diff --git a/monai/networks/layers/factories.py b/monai/networks/layers/factories.py index 4fc2c16f73..d57421132b 100644 --- a/monai/networks/layers/factories.py +++ b/monai/networks/layers/factories.py @@ -70,7 +70,7 @@ def use_factory(fact_args): from monai.networks.utils import has_nvfuser_instance_norm from monai.utils import ComponentStore, look_up_option, optional_import -__all__ = ["LayerFactory", "Dropout", "Norm", "Act", "Conv", "Pool", "Pad", "split_args"] +__all__ = ["LayerFactory", "Dropout", "Norm", "Act", "Conv", "Pool", "Pad", "RelPosEmbedding", "split_args"] class LayerFactory(ComponentStore): @@ -201,6 +201,10 @@ def split_args(args): Conv = LayerFactory(name="Convolution layers", description="Factory for creating convolution layers.") Pool = LayerFactory(name="Pooling layers", description="Factory for creating pooling layers.") Pad = LayerFactory(name="Padding layers", description="Factory for creating padding layers.") +RelPosEmbedding = LayerFactory( + name="Relative positional embedding layers", + description="Factory for creating relative positional embedding factory", +) @Dropout.factory_function("dropout") @@ -468,3 +472,10 @@ def constant_pad_factory(dim: int) -> type[nn.ConstantPad1d | nn.ConstantPad2d | """ types = (nn.ConstantPad1d, nn.ConstantPad2d, nn.ConstantPad3d) return types[dim - 1] + + +@RelPosEmbedding.factory_function("decomposed") +def decomposed_rel_pos_embedding() -> type[Any]: + from monai.networks.blocks.rel_pos_embedding import DecomposedRelativePosEmbedding + + return DecomposedRelativePosEmbedding diff --git a/monai/networks/layers/utils.py b/monai/networks/layers/utils.py index ace1af27b6..2e010f77c4 100644 --- a/monai/networks/layers/utils.py +++ b/monai/networks/layers/utils.py @@ -13,7 +13,7 @@ import torch.nn -from monai.networks.layers.factories import Act, Dropout, Norm, Pool, split_args +from monai.networks.layers.factories import Act, Dropout, Norm, Pool, RelPosEmbedding, split_args from monai.utils import has_option __all__ = ["get_norm_layer", "get_act_layer", "get_dropout_layer", "get_pool_layer"] @@ -124,3 +124,18 @@ def get_pool_layer(name: tuple | str, spatial_dims: int | None = 1): pool_name, pool_args = split_args(name) pool_type = Pool[pool_name, spatial_dims] return pool_type(**pool_args) + + +def get_rel_pos_embedding_layer(name: tuple | str, s_input_dims: tuple, c_dim: int, num_heads: int): + embedding_name, embedding_args = split_args(name) + kw_args = 
dict(embedding_args) + embedding_type = RelPosEmbedding[embedding_name] + + if has_option(embedding_type, "s_input_dims") and "s_input_dims" not in kw_args: + embedding_args["s_input_dims"] = s_input_dims + if has_option(embedding_type, "c_dim") and "c_dim" not in kw_args: + embedding_args["c_dim"] = c_dim + if has_option(embedding_type, "num_heads") and "num_heads" not in kw_args: + embedding_args["num_heads"] = num_heads + + return embedding_type(**embedding_args) From 188fdd7db017cf590758817df5ca43a09be4ac8a Mon Sep 17 00:00:00 2001 From: vgrau98 Date: Sun, 7 Jan 2024 15:27:33 +0100 Subject: [PATCH 04/12] add tests Signed-off-by: vgrau98 --- monai/networks/blocks/rel_pos_embedding.py | 12 +----------- tests/test_selfattention.py | 21 +++++++++++++++------ 2 files changed, 16 insertions(+), 17 deletions(-) diff --git a/monai/networks/blocks/rel_pos_embedding.py b/monai/networks/blocks/rel_pos_embedding.py index a96fcc9e21..4aee63d5f1 100644 --- a/monai/networks/blocks/rel_pos_embedding.py +++ b/monai/networks/blocks/rel_pos_embedding.py @@ -17,17 +17,7 @@ from monai.networks.blocks.attention_utils import add_decomposed_rel_pos -class RelativePosEmbedding(nn.Module): - def __init__( - self, - ) -> None: - super().__init__() - - def forward(self, x: torch.Tensor, att_mat: torch.Tensor) -> torch.Tensor: - ... - - -class DecomposedRelativePosEmbedding(RelativePosEmbedding): +class DecomposedRelativePosEmbedding(nn.Module): def __init__(self, s_input_dims: Tuple, c_dim: int, num_heads: int) -> None: """ Args: diff --git a/tests/test_selfattention.py b/tests/test_selfattention.py index 6062b5352f..0d0553ed2c 100644 --- a/tests/test_selfattention.py +++ b/tests/test_selfattention.py @@ -20,6 +20,7 @@ from monai.networks import eval_mode from monai.networks.blocks.selfattention import SABlock +from monai.networks.layers.factories import RelPosEmbedding from monai.utils import optional_import einops, has_einops = optional_import("einops") @@ -28,12 +29,20 @@ for dropout_rate in np.linspace(0, 1, 4): for hidden_size in [360, 480, 600, 768]: for num_heads in [4, 6, 8, 12]: - test_case = [ - {"hidden_size": hidden_size, "num_heads": num_heads, "dropout_rate": dropout_rate}, - (2, 512, hidden_size), - (2, 512, hidden_size), - ] - TEST_CASE_SABLOCK.append(test_case) + for rel_pos_embedding in [None, RelPosEmbedding.DECOMPOSED]: + for input_size in [(16, 32), (8, 8, 8)]: + test_case = [ + { + "hidden_size": hidden_size, + "num_heads": num_heads, + "dropout_rate": dropout_rate, + "rel_pos_embedding": rel_pos_embedding, + "input_size": input_size, + }, + (2, 512, hidden_size), + (2, 512, hidden_size), + ] + TEST_CASE_SABLOCK.append(test_case) class TestResBlock(unittest.TestCase): From 361e008f5e298160f1ff5b78914ef7d673a9c039 Mon Sep 17 00:00:00 2001 From: vgrau98 Date: Sun, 7 Jan 2024 23:53:58 +0100 Subject: [PATCH 05/12] mypy Signed-off-by: vgrau98 --- monai/networks/layers/utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/monai/networks/layers/utils.py b/monai/networks/layers/utils.py index 2e010f77c4..5ad8e244c3 100644 --- a/monai/networks/layers/utils.py +++ b/monai/networks/layers/utils.py @@ -11,6 +11,8 @@ from __future__ import annotations +from typing import Optional + import torch.nn from monai.networks.layers.factories import Act, Dropout, Norm, Pool, RelPosEmbedding, split_args @@ -126,7 +128,7 @@ def get_pool_layer(name: tuple | str, spatial_dims: int | None = 1): return pool_type(**pool_args) -def get_rel_pos_embedding_layer(name: tuple | str, s_input_dims: 
tuple, c_dim: int, num_heads: int): +def get_rel_pos_embedding_layer(name: tuple | str, s_input_dims: Optional[tuple], c_dim: int, num_heads: int): embedding_name, embedding_args = split_args(name) kw_args = dict(embedding_args) embedding_type = RelPosEmbedding[embedding_name] From b26206cd9d95589c9eecff9bf2391786473eaccd Mon Sep 17 00:00:00 2001 From: vgrau98 <35843843+vgrau98@users.noreply.github.com> Date: Sun, 14 Jan 2024 00:26:02 +0100 Subject: [PATCH 06/12] Fix rpe layer generation Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Signed-off-by: vgrau98 <35843843+vgrau98@users.noreply.github.com> --- monai/networks/layers/utils.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/monai/networks/layers/utils.py b/monai/networks/layers/utils.py index 5ad8e244c3..10b619cbe0 100644 --- a/monai/networks/layers/utils.py +++ b/monai/networks/layers/utils.py @@ -130,14 +130,10 @@ def get_pool_layer(name: tuple | str, spatial_dims: int | None = 1): def get_rel_pos_embedding_layer(name: tuple | str, s_input_dims: Optional[tuple], c_dim: int, num_heads: int): embedding_name, embedding_args = split_args(name) - kw_args = dict(embedding_args) embedding_type = RelPosEmbedding[embedding_name] - - if has_option(embedding_type, "s_input_dims") and "s_input_dims" not in kw_args: - embedding_args["s_input_dims"] = s_input_dims - if has_option(embedding_type, "c_dim") and "c_dim" not in kw_args: - embedding_args["c_dim"] = c_dim - if has_option(embedding_type, "num_heads") and "num_heads" not in kw_args: - embedding_args["num_heads"] = num_heads - - return embedding_type(**embedding_args) + # create a dictionary with the default values which can be overridden by embedding_args + kw_args = {"s_input_dims": s_input_dims, "c_dim": c_dim, "num_heads": num_heads, **embedding_args} + # filter out unused argument names + kw_args = {k: v for k, v in kw_args.items() if has_option(embedding_type, k)} + + return embedding_type(**kw_args) From 2bdd32203e709f9e833d4e4a4b56e3cda1210a9d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 13 Jan 2024 23:26:32 +0000 Subject: [PATCH 07/12] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/networks/layers/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/networks/layers/utils.py b/monai/networks/layers/utils.py index 10b619cbe0..8676f74638 100644 --- a/monai/networks/layers/utils.py +++ b/monai/networks/layers/utils.py @@ -135,5 +135,5 @@ def get_rel_pos_embedding_layer(name: tuple | str, s_input_dims: Optional[tuple] kw_args = {"s_input_dims": s_input_dims, "c_dim": c_dim, "num_heads": num_heads, **embedding_args} # filter out unused argument names kw_args = {k: v for k, v in kw_args.items() if has_option(embedding_type, k)} - + return embedding_type(**kw_args) From 80f39b892f25f54e6caaeb6e1db3f316339208c7 Mon Sep 17 00:00:00 2001 From: vgrau98 Date: Sun, 14 Jan 2024 00:52:51 +0100 Subject: [PATCH 08/12] check s_input_dims in DecomposedRelativePosEmbedding constructor Signed-off-by: vgrau98 --- monai/networks/blocks/rel_pos_embedding.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/monai/networks/blocks/rel_pos_embedding.py b/monai/networks/blocks/rel_pos_embedding.py index 4aee63d5f1..a2c6564ac7 100644 --- a/monai/networks/blocks/rel_pos_embedding.py +++ b/monai/networks/blocks/rel_pos_embedding.py @@ -9,7 +9,7 @@ from 
__future__ import annotations -from typing import Tuple +from typing import Iterable, Tuple import torch from torch import nn @@ -18,14 +18,19 @@ class DecomposedRelativePosEmbedding(nn.Module): - def __init__(self, s_input_dims: Tuple, c_dim: int, num_heads: int) -> None: + def __init__(self, s_input_dims: Tuple[int, int] | Tuple[int, int, int], c_dim: int, num_heads: int) -> None: """ Args: s_input_dims (Tuple): input spatial dimension. (H, W) or (H, W, D) c_dim (int): channel dimension - num_heads(int): number of attentio heads + num_heads(int): number of attention heads """ super().__init__() + + # validate inputs + if not isinstance(s_input_dims, Iterable) or len(s_input_dims) not in [2, 3]: + raise ValueError("s_input_dims must be set as follows: (H, W) or (H, W, D)") + self.s_input_dims = s_input_dims self.c_dim = c_dim self.num_heads = num_heads @@ -36,8 +41,8 @@ def __init__(self, s_input_dims: Tuple, c_dim: int, num_heads: int) -> None: def forward(self, x: torch.Tensor, att_mat: torch.Tensor, q: torch.Tensor) -> torch.Tensor: """""" batch = x.shape[0] - h, w = self.s_input_dims[:2] if self.s_input_dims is not None else (0, 0) - d = self.s_input_dims[2] if self.s_input_dims is not None and len(self.s_input_dims) > 2 else 1 + h, w = self.s_input_dims[:2] + d = self.s_input_dims[2] if len(self.s_input_dims) == 3 else 1 att_mat = add_decomposed_rel_pos( att_mat.contiguous().view(batch * self.num_heads, h * w * d, h * w * d), From 78883072b3c973395fa76c8c80ae6785227b2a66 Mon Sep 17 00:00:00 2001 From: vgrau98 Date: Sun, 14 Jan 2024 00:54:32 +0100 Subject: [PATCH 09/12] refacto DecomposedRelativePosEmbedding Signed-off-by: vgrau98 --- monai/networks/blocks/rel_pos_embedding.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/networks/blocks/rel_pos_embedding.py b/monai/networks/blocks/rel_pos_embedding.py index a2c6564ac7..e53e5841b0 100644 --- a/monai/networks/blocks/rel_pos_embedding.py +++ b/monai/networks/blocks/rel_pos_embedding.py @@ -15,6 +15,7 @@ from torch import nn from monai.networks.blocks.attention_utils import add_decomposed_rel_pos +from monai.utils.misc import ensure_tuple_size class DecomposedRelativePosEmbedding(nn.Module): @@ -41,8 +42,7 @@ def __init__(self, s_input_dims: Tuple[int, int] | Tuple[int, int, int], c_dim: def forward(self, x: torch.Tensor, att_mat: torch.Tensor, q: torch.Tensor) -> torch.Tensor: """""" batch = x.shape[0] - h, w = self.s_input_dims[:2] - d = self.s_input_dims[2] if len(self.s_input_dims) == 3 else 1 + h, w, d = ensure_tuple_size(self.s_input_dims, 3, 1) att_mat = add_decomposed_rel_pos( att_mat.contiguous().view(batch * self.num_heads, h * w * d, h * w * d), From 3cab918b8d5e8b9dfec70561adfb610d35faaa0b Mon Sep 17 00:00:00 2001 From: vgrau98 <35843843+vgrau98@users.noreply.github.com> Date: Sun, 14 Jan 2024 01:00:27 +0100 Subject: [PATCH 10/12] typing Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Signed-off-by: vgrau98 <35843843+vgrau98@users.noreply.github.com> --- monai/networks/layers/factories.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/networks/layers/factories.py b/monai/networks/layers/factories.py index d57421132b..f583ff61b6 100644 --- a/monai/networks/layers/factories.py +++ b/monai/networks/layers/factories.py @@ -475,7 +475,7 @@ def constant_pad_factory(dim: int) -> type[nn.ConstantPad1d | nn.ConstantPad2d | @RelPosEmbedding.factory_function("decomposed") -def decomposed_rel_pos_embedding() -> type[Any]: +def decomposed_rel_pos_embedding() 
-> type[nn.Module}: from monai.networks.blocks.rel_pos_embedding import DecomposedRelativePosEmbedding return DecomposedRelativePosEmbedding From 6d61b222a9d6e369f1f6139c8dbeee5003d63ae1 Mon Sep 17 00:00:00 2001 From: vgrau98 Date: Sun, 14 Jan 2024 01:02:44 +0100 Subject: [PATCH 11/12] typo Signed-off-by: vgrau98 --- monai/networks/layers/factories.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/networks/layers/factories.py b/monai/networks/layers/factories.py index f583ff61b6..29b72a4f37 100644 --- a/monai/networks/layers/factories.py +++ b/monai/networks/layers/factories.py @@ -475,7 +475,7 @@ def constant_pad_factory(dim: int) -> type[nn.ConstantPad1d | nn.ConstantPad2d | @RelPosEmbedding.factory_function("decomposed") -def decomposed_rel_pos_embedding() -> type[nn.Module}: +def decomposed_rel_pos_embedding() -> type[nn.Module]: from monai.networks.blocks.rel_pos_embedding import DecomposedRelativePosEmbedding return DecomposedRelativePosEmbedding From 6fb51f5d74697bd71710040ac8fb135ad1ff6c08 Mon Sep 17 00:00:00 2001 From: vgrau98 Date: Sun, 14 Jan 2024 18:33:39 +0100 Subject: [PATCH 12/12] doc Signed-off-by: vgrau98 --- docs/source/networks.rst | 6 ++++ monai/networks/blocks/attention_utils.py | 37 ++++++++++++++++++++---- 2 files changed, 38 insertions(+), 5 deletions(-) diff --git a/docs/source/networks.rst b/docs/source/networks.rst index f9375f1e97..556bf12d50 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -248,6 +248,12 @@ Blocks .. autoclass:: monai.apps.reconstruction.networks.blocks.varnetblock.VarNetBlock :members: +`Attention utilities` +~~~~~~~~~~~~~~~~~~~~~ +.. automodule:: monai.networks.blocks.attention_utils +.. autofunction:: monai.networks.blocks.attention_utils.get_rel_pos +.. autofunction:: monai.networks.blocks.attention_utils.add_decomposed_rel_pos + N-Dim Fourier Transform ~~~~~~~~~~~~~~~~~~~~~~~~ .. automodule:: monai.networks.blocks.fft_utils_t diff --git a/monai/networks/blocks/attention_utils.py b/monai/networks/blocks/attention_utils.py index a9fd6c89ae..8c9002a16e 100644 --- a/monai/networks/blocks/attention_utils.py +++ b/monai/networks/blocks/attention_utils.py @@ -19,7 +19,8 @@ def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: """ Get relative positional embeddings according to the relative positions of - query and key sizes. + query and key sizes. + Args: q_size (int): size of query q. k_size (int): size of key k. @@ -51,10 +52,36 @@ def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor def add_decomposed_rel_pos( attn: torch.Tensor, q: torch.Tensor, rel_pos_lst: nn.ParameterList, q_size: Tuple, k_size: Tuple ) -> torch.Tensor: - """ - Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. - https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950 + r""" + Calculate decomposed Relative Positional Embeddings from mvitv2 implementation: + https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py + Only 2D and 3D are supported. + + Encoding the relative position of tokens in the attention matrix: tokens spaced a distance + `d` apart will have the same embedding value (unlike absolute positional embedding). + + .. math:: + Attn_{logits}(Q, K) = (QK^{T} + E_{rel})*scale + + where + + .. 
math::
+        E_{ij}^{(rel)} = Q_{i} \cdot R_{p(i), p(j)}
+
+    with :math:`R_{p(i), p(j)} \in \mathbb{R}^{dim}` and :math:`p(i), p(j)` the spatial
+    positions of elements :math:`i` and :math:`j`, respectively.
+
+    When using "decomposed" relative positional embedding, the embedding is decomposed per spatial axis as follows:
+
+    .. math::
+        R_{p(i), p(j)} = R^{d1}_{d1(i), d1(j)} + ... + R^{dn}_{dn(i), dn(j)}
+
+    with :math:`n = 1, ..., dim`.
+
+    Decomposed relative positional embedding reduces the complexity from :math:`\mathcal{O}(d1*...*dn)` to
+    :math:`\mathcal{O}(d1+...+dn)` compared with classical relative positional embedding.
+
     Args:
         attn (Tensor): attention map.
         q (Tensor): query q in the attention layer with shape (B, s_dim_1 * ... * s_dim_n, C).
         rel_pos_lst (ParameterList): relative position embeddings for each axis: rel_pos_lst[n] for nth axis.
         q_size (Tuple): spatial sequence size of query q with (q_dim_1, ..., q_dim_n).
         k_size (Tuple): spatial sequence size of key k with (k_dim_1, ..., k_dim_n).
 
     Returns:
-        attn (Tensor): attention map with added relative positional embeddings.
+        attn (Tensor): attention logits with added relative positional embeddings.
     """
     rh = get_rel_pos(q_size[0], k_size[0], rel_pos_lst[0])
     rw = get_rel_pos(q_size[1], k_size[1], rel_pos_lst[1])
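A minimal usage sketch of the feature added by this series, written against the final state of the patches: the SABlock signature from selfattention.py, the "decomposed" name registered in factories.py, and the hidden sizes, head counts, and input sizes exercised in test_selfattention.py. It is an illustrative sketch rather than code taken from the patches.

    import torch

    from monai.networks.blocks.selfattention import SABlock

    # A 16 x 32 patch grid gives 16 * 32 = 512 tokens; hidden_size must be divisible by num_heads.
    block = SABlock(
        hidden_size=360,
        num_heads=4,
        dropout_rate=0.0,
        rel_pos_embedding="decomposed",  # name registered via RelPosEmbedding.factory_function("decomposed")
        input_size=(16, 32),
    )
    x = torch.randn(2, 16 * 32, 360)  # B x (s_dim_1 * s_dim_2) x C
    out = block(x)  # same shape as x; the attention logits are biased by the per-axis relative position tables

    # A 3D grid works the same way, e.g. input_size=(8, 8, 8) with 8 * 8 * 8 = 512 tokens.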