Merged
25 commits
ec217de  Feature/vit support (#4182)  (klhhhhh, Jul 7, 2023)
ddecf73  [shardformer] support SAM (#4231)  (FoolPlayer, Jul 14, 2023)
afcf4a0  [shardformer] support whisper (#4212)  (FoolPlayer, Jul 17, 2023)
77cc087  Feature/chatglm (#4240)  (klhhhhh, Jul 20, 2023)
6c2acf0  [shardformer] added tests  (klhhhhh, Jul 4, 2023)
7668b24  [shardformer] vit test finish and support  (klhhhhh, Jul 6, 2023)
b135b75  import chatglm  (klhhhhh, Jul 7, 2023)
e3cd5cb  [shardformer] add test kit in model zoo for chatglm  (klhhhhh, Jul 7, 2023)
30574a7  [sharformer] add first version of policy of chatglm  (klhhhhh, Jul 10, 2023)
28677d4  [shardformer] polish chatglm code  (klhhhhh, Jul 12, 2023)
28319c2  [shardformer] polish code  (klhhhhh, Jul 13, 2023)
3f19de9  [shardformer] support chatglm without layernorm  (klhhhhh, Jul 14, 2023)
2a4bbcf  [shardformer] delete some file  (klhhhhh, Jul 17, 2023)
32448e3  [shardformer] ChatGLM support layernorm sharding  (klhhhhh, Jul 17, 2023)
eb1c71a  [shardformer] register without auto policy  (klhhhhh, Jul 18, 2023)
127e385  [shardformer] pre-commit check files  (klhhhhh, Jul 19, 2023)
9d5b141  [shardformer] support ChatGLMForConditionalGeneration & add fusedlaye…  (klhhhhh, Jul 20, 2023)
805f342  Merge pull request #4297 from klhhhhh/feature/support_ChatGLMForCondi…  (klhhhhh, Jul 21, 2023)
f48a8bb  [shardformer] support Blip2 (#4243)  (FoolPlayer, Jul 25, 2023)
bf2beb0  [shardformer] merge blip2 from shard-models branch  (flybird11111, Jul 25, 2023)
bb3a589  [shardformer] blip2 support flash attention and jit operator  (flybird11111, Jul 25, 2023)
5c49c0b  Merge branch 'feature/flash-attention-shardformer' into update-blip2  (flybird11111, Jul 25, 2023)
ce1eccf  [shardformer] blip2 support flash attention and jit operator  (flybird11111, Jul 25, 2023)
07a1b14  Merge branch 'update-blip2' of https://github.com/flybird1111/Colossa…  (flybird11111, Jul 25, 2023)
537e005  [shardformer] blip2 support flash attention and jit operator  (flybird11111, Jul 25, 2023)
60 changes: 60 additions & 0 deletions colossalai/shardformer/modeling/blip2.py
@@ -1,3 +1,4 @@
+import math
 from typing import Optional, Tuple, Union
 
 import torch
@@ -58,3 +59,62 @@ def forward(
         return outputs
 
     return forward
+
+
+def get_blip2_flash_attention_forward():
+
+    from transformers.models.blip_2.modeling_blip_2 import Blip2Attention
+
+    from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention
+
+    def forward(
+        self: Blip2Attention,
+        hidden_states: torch.Tensor,
+        head_mask: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = False,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        """Input shape: Batch x Time x Channel"""
+
+        bsz, tgt_len, embed_dim = hidden_states.size()
+        mixed_qkv = self.qkv(hidden_states)
+        mixed_qkv = mixed_qkv.reshape(bsz, tgt_len, 3, self.num_heads, -1).permute(2, 0, 1, 3, 4)
+        query_states, key_states, value_states = mixed_qkv[0], mixed_qkv[1], mixed_qkv[2]
+
+        attention = ColoAttention(embed_dim=self.embed_dim,
+                                  num_heads=self.num_heads,
+                                  dropout=self.dropout.p,
+                                  scale=self.scale)
+        context_layer = attention(query_states, key_states, value_states)
+
+        output = self.projection(context_layer)
+        outputs = (output, None)
+
+        return outputs
+
+    return forward
+
+
+def get_jit_fused_blip2_QFormer_self_output_forward():
+
+    from transformers.models.blip_2.modeling_blip_2 import Blip2QFormerSelfOutput
+
+    def forward(self: Blip2QFormerSelfOutput, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.dropout_add(hidden_states, input_tensor, self.dropout.p, self.dropout.training)
+        hidden_states = self.LayerNorm(hidden_states)
+        return hidden_states
+
+    return forward
+
+
+def get_jit_fused_blip2_QFormer_output_forward():
+
+    from transformers.models.blip_2.modeling_blip_2 import Blip2QFormerOutput
+
+    def forward(self: Blip2QFormerOutput, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.dropout_add(hidden_states, input_tensor, self.dropout.p, self.dropout.training)
+        hidden_states = self.LayerNorm(hidden_states)
+        return hidden_states
+
+    return forward
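Editor's note, not part of the diff: the `dropout_add` method called by the two Q-Former forwards above is supplied separately through the policy's `method_replacement` entry `get_jit_fused_dropout_add_func()` (see the policy diff below). As a rough sketch of the behaviour that helper is expected to provide, assuming a plain, non-scripted implementation:

    import torch
    from torch.nn import functional as F

    # Illustrative stand-in only: the real callable returned by get_jit_fused_dropout_add_func()
    # is expected to have this signature and may be compiled with torch.jit.script; the point is
    # that dropout and the residual add run as one fused step instead of two separate module calls.
    def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
        # apply dropout to the projected hidden states, then add the residual input
        return F.dropout(x, p=prob, training=training) + residual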
28 changes: 27 additions & 1 deletion colossalai/shardformer/policies/blip2.py
@@ -3,7 +3,13 @@
 import colossalai.shardformer.layer as col_nn
 
 from .._utils import getattr_, setattr_
-from ..modeling.blip2 import forward_fn
+from ..modeling.blip2 import (
+    forward_fn,
+    get_blip2_flash_attention_forward,
+    get_jit_fused_blip2_QFormer_output_forward,
+    get_jit_fused_blip2_QFormer_self_output_forward,
+)
+from ..modeling.jit import get_jit_fused_dropout_add_func
 from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
 
 __all__ = ['BlipPolicy', 'BlipModelPolicy']
@@ -33,6 +39,8 @@ def module_policy(self):
             Blip2EncoderLayer,
             Blip2QFormerLayer,
             Blip2QFormerModel,
+            Blip2QFormerOutput,
+            Blip2QFormerSelfOutput,
             Blip2VisionModel,
         )
         from transformers.models.opt.modeling_opt import OPTDecoderLayer, OPTForCausalLM
@@ -275,6 +283,24 @@ def module_policy(self):
                                                     policy=policy,
                                                     target_key=OPTDecoderLayer)
 
+        # use flash attention
+        if self.shard_config.enable_flash_attention:
+            policy[Blip2Attention] = ModulePolicyDescription(method_replacement={
+                'forward': get_blip2_flash_attention_forward(),
+            })
+
+        # use jit operator
+        if self.shard_config.enable_jit_fused:
+            policy[Blip2QFormerSelfOutput] = ModulePolicyDescription(
+                method_replacement={
+                    'forward': get_jit_fused_blip2_QFormer_self_output_forward(),
+                    'dropout_add': get_jit_fused_dropout_add_func(),
+                })
+            policy[Blip2QFormerOutput] = ModulePolicyDescription(method_replacement={
+                'forward': get_jit_fused_blip2_QFormer_output_forward(),
+                'dropout_add': get_jit_fused_dropout_add_func(),
+            })
+
         return policy
 
     def postprocess(self):
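Editor's note, not part of the diff: these two shard_config flags are what switch the replacements on. A minimal usage sketch follows; the flag names are taken from the policy above, but the ShardConfig/ShardFormer constructor and return signatures shown here are assumptions and may differ between releases.

    from colossalai.shardformer import ShardConfig, ShardFormer

    # flags mirror the attributes read by BlipPolicy.module_policy above
    shard_config = ShardConfig(enable_flash_attention=True, enable_jit_fused=True)
    shard_former = ShardFormer(shard_config=shard_config)
    # optimize() is expected to apply the policy's method_replacement entries, binding the
    # flash-attention forward and the JIT-fused Q-Former forwards onto the matching modules
    sharded_model = shard_former.optimize(model)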
1 change: 1 addition & 0 deletions tests/kit/model_zoo/transformers/blip2.py
@@ -38,6 +38,7 @@ def data_gen():
 loss_fn_blip2_model = lambda x: x.loss
 
 config = transformers.Blip2Config()
+config.vision_config.patch_size = 14
 config.text_config.num_hidden_layers = 1
 config.qformer_config.num_hidden_layers = 1
 config.vision_config.num_hidden_layers = 1
7 changes: 5 additions & 2 deletions tests/test_shardformer/test_model/test_shard_blip2.py
@@ -81,10 +81,13 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
 
 @parameterize('enable_fused_normalization', [True, False])
 @parameterize('enable_tensor_parallelism', [True, False])
-def run_blip2_test(enable_fused_normalization, enable_tensor_parallelism):
+@parameterize('enable_flash_attention', [True, False])
+@parameterize('enable_jit_fused', [True, False])
+def run_blip2_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused):
     sub_model_zoo = model_zoo.get_sub_registry('transformers_blip2')
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism)
+        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
+                                               enable_flash_attention, enable_jit_fused)
         check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
 
     torch.cuda.empty_cache()
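Editor's note, not part of the diff: with the two new flags, the stacked @parameterize decorators sweep every combination of the four booleans, so each model in the transformers_blip2 registry is built and checked under all 16 flag configurations. The expansion behaves roughly like this sketch, where run_blip2_test_body stands in for the undecorated test body and is named here only for illustration:

    from itertools import product

    # each stacked @parameterize iterates one flag; stacking them yields the full cartesian product
    for fused_norm, tp, flash, jit in product([True, False], repeat=4):
        run_blip2_test_body(fused_norm, tp, flash, jit)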
2 changes: 1 addition & 1 deletion tests/test_shardformer/test_model/test_shard_vit.py
@@ -50,7 +50,7 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
             all_shard_grad = shard_grad
         assert torch.allclose(org_grad, all_shard_grad,
                               atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{shard_grad}"
-
+
 @parameterize('enable_fused_normalization', [True, False])
 @parameterize('enable_tensor_parallelism', [True, False])