164 changes: 164 additions & 0 deletions colossalai/shardformer/modeling/sam.py
@@ -1,5 +1,10 @@
import math
from typing import Tuple

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import Tensor, nn
from torch.distributed import ProcessGroup


@@ -39,3 +44,162 @@ def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor:
return outputs

return forward


def get_sam_flash_attention_forward():

from transformers.models.sam.modeling_sam import SamAttention
try:
    from xformers.ops import memory_efficient_attention as me_attention
except ImportError:
    raise ImportError("xformers is not installed. Please install it to use flash attention.")

def _separate_heads(hidden_states: Tensor, num_attention_heads: int) -> Tensor:
batch, point_batch_size, n_tokens, channel = hidden_states.shape
c_per_head = channel // num_attention_heads
hidden_states = hidden_states.reshape(batch * point_batch_size, n_tokens, num_attention_heads, c_per_head)
return hidden_states

def _recombine_heads(hidden_states: Tensor, point_batch_size: int) -> Tensor:
batch, n_tokens, n_heads, c_per_head = hidden_states.shape
return hidden_states.reshape(batch // point_batch_size, point_batch_size, n_tokens, n_heads * c_per_head)

def forward(self: SamAttention,
query: Tensor,
key: Tensor,
value: Tensor,
attention_similarity: Tensor = None) -> Tensor:
# Input projections
query = self.q_proj(query)
key = self.k_proj(key)
value = self.v_proj(value)

point_batch_size = query.shape[1]
# Separate into heads
query = _separate_heads(query, self.num_attention_heads)
key = _separate_heads(key, self.num_attention_heads)
value = _separate_heads(value, self.num_attention_heads)

# SamAttention
_, _, _, c_per_head = query.shape
bias = None
if attention_similarity is not None:
bias = attention_similarity

scale = 1.0 / math.sqrt(c_per_head)
out = me_attention(query, key, value, attn_bias=bias, scale=scale)

out = _recombine_heads(out, point_batch_size)
out = self.out_proj(out)

return out

return forward


def get_sam_vision_flash_attention_forward():

from transformers.models.sam.modeling_sam import SamVisionAttention
try:
    from xformers.ops import memory_efficient_attention as me_attention
except ImportError:
    raise ImportError("xformers is not installed. Please install it to use flash attention.")

def add_decomposed_rel_pos(
query: torch.Tensor,
rel_pos_h: torch.Tensor,
rel_pos_w: torch.Tensor,
q_size: Tuple[int, int],
k_size: Tuple[int, int],
) -> torch.Tensor:
"""
Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py

Args:
    query (`torch.Tensor`):
        query q in the attention layer with shape (batch_size, query_height * query_width, num_heads, channel).
    rel_pos_h (`torch.Tensor`):
        relative position embeddings (Lh, channel) for the height axis.
    rel_pos_w (`torch.Tensor`):
        relative position embeddings (Lw, channel) for the width axis.
    q_size (tuple):
        spatial sequence size of query q with (query_height, query_width).
    k_size (tuple):
        spatial sequence size of key k with (key_height, key_width).

Returns:
    rel_pos (`torch.Tensor`):
        decomposed relative positional bias with shape (batch_size, num_heads, query_height * query_width,
        key_height * key_width), suitable for use as an additive attention bias.
"""

query_height, query_width = q_size
key_height, key_width = k_size
relative_position_height = get_rel_pos(query_height, key_height, rel_pos_h)
relative_position_width = get_rel_pos(query_width, key_width, rel_pos_w)

batch_size, _, nHead, dim = query.shape
reshaped_query = query.transpose(1, 2).reshape(batch_size * nHead, query_height, query_width, dim)
rel_h = torch.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height)
rel_w = torch.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width)
rel_pos = rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
rel_pos = rel_pos.reshape(batch_size, nHead, query_height * query_width, key_height * key_width)
return rel_pos

def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
"""
Get relative positional embeddings according to the relative positions of
query and key sizes.

Args:
q_size (int):
size of the query.
k_size (int):
size of key k.
rel_pos (`torch.Tensor`):
relative position embeddings (L, channel).

Returns:
Extracted positional embeddings according to relative positions.
"""
max_rel_dist = int(2 * max(q_size, k_size) - 1)
# Interpolate rel pos.
rel_pos_resized = F.interpolate(
rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
size=max_rel_dist,
mode="linear",
)
rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)

# Scale the coords with short length if shapes for q and k are different.
q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)

return rel_pos_resized[relative_coords.long()]

def forward(self: SamVisionAttention, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor:
batch_size, height, width, _ = hidden_states.shape
# qkv with shape (3, batch_size, height * width, nHead, channel)
qkv = (self.qkv(hidden_states).reshape(batch_size, height * width, 3, self.num_attention_heads,
-1).permute(2, 0, 1, 3, 4))

query, key, value = qkv.reshape(3, batch_size, height * width, self.num_attention_heads, -1).unbind(0)

rel_pos = None
if self.use_rel_pos:
rel_pos = add_decomposed_rel_pos(query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width))

attn_output = me_attention(query, key, value, attn_bias=rel_pos, p=self.dropout, scale=self.scale)

attn_output = attn_output.reshape(batch_size, height, width, -1)

attn_output = self.proj(attn_output)

outputs = (attn_output, None)

return outputs

return forward
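
For readers unfamiliar with xformers, `memory_efficient_attention(query, key, value, attn_bias=..., scale=...)` computes ordinary scaled-dot-product attention with an optional additive bias over inputs laid out as (batch, seq_len, num_heads, head_dim) — the same layout `_separate_heads` produces above. The snippet below is a plain-PyTorch reference of that computation (not part of this PR and not a drop-in replacement); the toy shapes are made up for illustration.

```python
import math
import torch

def reference_attention(query, key, value, attn_bias=None, scale=None):
    # query/key/value: (batch, seq_len, num_heads, head_dim)
    if scale is None:
        scale = 1.0 / math.sqrt(query.shape[-1])
    q, k, v = (t.transpose(1, 2) for t in (query, key, value))  # -> (batch, heads, seq, dim)
    scores = torch.matmul(q, k.transpose(-2, -1)) * scale
    if attn_bias is not None:  # e.g. the decomposed relative position bias
        scores = scores + attn_bias
    out = torch.matmul(scores.softmax(dim=-1), v)
    return out.transpose(1, 2)  # back to (batch, seq_len, num_heads, head_dim)

# Toy shapes: batch 2, 16 tokens, 8 heads, 32 dims per head.
q = torch.randn(2, 16, 8, 32)
k = torch.randn(2, 16, 8, 32)
v = torch.randn(2, 16, 8, 32)
print(reference_attention(q, k, v).shape)  # torch.Size([2, 16, 8, 32])
```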
12 changes: 11 additions & 1 deletion colossalai/shardformer/policies/sam.py
@@ -3,7 +3,7 @@
import colossalai.shardformer.layer as col_nn

from .._utils import getattr_, setattr_
from ..modeling.sam import forward_fn
from ..modeling.sam import forward_fn, get_sam_flash_attention_forward, get_sam_vision_flash_attention_forward
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

__all__ = ['SamPolicy', 'SamModelPolicy']
@@ -19,6 +19,7 @@ def preprocess(self):

def module_policy(self):
from transformers.models.sam.modeling_sam import (
SamAttention,
SamFeedForward,
SamTwoWayAttentionBlock,
SamTwoWayTransformer,
@@ -196,6 +197,15 @@ def module_policy(self):
policy=policy,
target_key=SamTwoWayTransformer)

# use flash attention
if self.shard_config.enable_flash_attention:
policy[SamAttention] = ModulePolicyDescription(method_replacement={
'forward': get_sam_flash_attention_forward(),
})
policy[SamVisionAttention] = ModulePolicyDescription(method_replacement={
'forward': get_sam_vision_flash_attention_forward(),
})

return policy

def postprocess(self):
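
The `method_replacement` entries above are what route SAM's attention layers to the new forwards when `enable_flash_attention` is set on the shard config. As a rough mental model only (this is not how ShardFormer applies policies internally), the effect is equivalent to binding the returned functions onto every matching attention instance, as in the sketch below; it assumes `transformers` with SAM support and `xformers` are installed, and builds the model from a default config so no checkpoint download is needed.

```python
import types

from transformers import SamConfig, SamModel
from transformers.models.sam.modeling_sam import SamAttention, SamVisionAttention

from colossalai.shardformer.modeling.sam import (
    get_sam_flash_attention_forward,
    get_sam_vision_flash_attention_forward,
)

model = SamModel(SamConfig())  # randomly initialized SAM with default sizes

attn_forward = get_sam_flash_attention_forward()            # raises ImportError without xformers
vision_forward = get_sam_vision_flash_attention_forward()

# Bind the replacement forwards onto each attention instance.
for module in model.modules():
    if isinstance(module, SamVisionAttention):
        module.forward = types.MethodType(vision_forward, module)
    elif isinstance(module, SamAttention):
        module.forward = types.MethodType(attn_forward, module)
```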
6 changes: 4 additions & 2 deletions tests/test_shardformer/test_model/test_shard_sam.py
@@ -66,10 +66,12 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo

@parameterize('enable_fused_normalization', [True, False])
@parameterize('enable_tensor_parallelism', [True, False])
def run_sam_test(enable_fused_normalization, enable_tensor_parallelism):
@parameterize('enable_flash_attention', [True, False])
def run_sam_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention):
sub_model_zoo = model_zoo.get_sub_registry('transformers_sam')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism)
org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
enable_flash_attention)
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)

torch.cuda.empty_cache()