diff --git a/colossalai/shardformer/modeling/blip2.py b/colossalai/shardformer/modeling/blip2.py
index b7945423ae83..c5c6b14ba993 100644
--- a/colossalai/shardformer/modeling/blip2.py
+++ b/colossalai/shardformer/modeling/blip2.py
@@ -1,3 +1,4 @@
+import math
 from typing import Optional, Tuple, Union
 
 import torch
@@ -58,3 +59,62 @@ def forward(
         return outputs
 
     return forward
+
+
+def get_blip2_flash_attention_forward():
+
+    from transformers.models.blip_2.modeling_blip_2 import Blip2Attention
+
+    from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention
+
+    def forward(
+        self: Blip2Attention,
+        hidden_states: torch.Tensor,
+        head_mask: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = False,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        """Input shape: Batch x Time x Channel"""
+
+        bsz, tgt_len, embed_dim = hidden_states.size()
+        mixed_qkv = self.qkv(hidden_states)
+        mixed_qkv = mixed_qkv.reshape(bsz, tgt_len, 3, self.num_heads, -1).permute(2, 0, 1, 3, 4)
+        query_states, key_states, value_states = mixed_qkv[0], mixed_qkv[1], mixed_qkv[2]
+
+        attention = ColoAttention(embed_dim=self.embed_dim,
+                                  num_heads=self.num_heads,
+                                  dropout=self.dropout.p,
+                                  scale=self.scale)
+        context_layer = attention(query_states, key_states, value_states)
+
+        output = self.projection(context_layer)
+        outputs = (output, None)
+
+        return outputs
+
+    return forward
+
+
+def get_jit_fused_blip2_QFormer_self_output_forward():
+
+    from transformers.models.blip_2.modeling_blip_2 import Blip2QFormerSelfOutput
+
+    def forward(self: Blip2QFormerSelfOutput, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.dropout_add(hidden_states, input_tensor, self.dropout.p, self.dropout.training)
+        hidden_states = self.LayerNorm(hidden_states)
+        return hidden_states
+
+    return forward
+
+
+def get_jit_fused_blip2_QFormer_output_forward():
+
+    from transformers.models.blip_2.modeling_blip_2 import Blip2QFormerOutput
+
+    def forward(self: Blip2QFormerOutput, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.dropout_add(hidden_states, input_tensor, self.dropout.p, self.dropout.training)
+        hidden_states = self.LayerNorm(hidden_states)
+        return hidden_states
+
+    return forward
diff --git a/colossalai/shardformer/policies/blip2.py b/colossalai/shardformer/policies/blip2.py
index 43aa1adc1c5b..888e209f9e7c 100644
--- a/colossalai/shardformer/policies/blip2.py
+++ b/colossalai/shardformer/policies/blip2.py
@@ -3,7 +3,13 @@
 import colossalai.shardformer.layer as col_nn
 
 from .._utils import getattr_, setattr_
-from ..modeling.blip2 import forward_fn
+from ..modeling.blip2 import (
+    forward_fn,
+    get_blip2_flash_attention_forward,
+    get_jit_fused_blip2_QFormer_output_forward,
+    get_jit_fused_blip2_QFormer_self_output_forward,
+)
+from ..modeling.jit import get_jit_fused_dropout_add_func
 from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
 
 __all__ = ['BlipPolicy', 'BlipModelPolicy']
@@ -33,6 +39,8 @@ def module_policy(self):
             Blip2EncoderLayer,
             Blip2QFormerLayer,
             Blip2QFormerModel,
+            Blip2QFormerOutput,
+            Blip2QFormerSelfOutput,
             Blip2VisionModel,
         )
         from transformers.models.opt.modeling_opt import OPTDecoderLayer, OPTForCausalLM
@@ -275,6 +283,24 @@ def module_policy(self):
                                                     policy=policy,
                                                     target_key=OPTDecoderLayer)
 
+        # use flash attention
+        if self.shard_config.enable_flash_attention:
+            policy[Blip2Attention] = ModulePolicyDescription(method_replacement={
+                'forward': get_blip2_flash_attention_forward(),
+            })
+
+        # use jit operator
+        if self.shard_config.enable_jit_fused:
+            policy[Blip2QFormerSelfOutput] = ModulePolicyDescription(
+                method_replacement={
+                    'forward': get_jit_fused_blip2_QFormer_self_output_forward(),
+                    'dropout_add': get_jit_fused_dropout_add_func(),
+                })
+            policy[Blip2QFormerOutput] = ModulePolicyDescription(method_replacement={
+                'forward': get_jit_fused_blip2_QFormer_output_forward(),
+                'dropout_add': get_jit_fused_dropout_add_func(),
+            })
+
         return policy
 
     def postprocess(self):
diff --git a/tests/kit/model_zoo/transformers/blip2.py b/tests/kit/model_zoo/transformers/blip2.py
index 7338f740be7f..984a6ffa920d 100644
--- a/tests/kit/model_zoo/transformers/blip2.py
+++ b/tests/kit/model_zoo/transformers/blip2.py
@@ -38,6 +38,7 @@ def data_gen():
 loss_fn_blip2_model = lambda x: x.loss
 
 config = transformers.Blip2Config()
+config.vision_config.patch_size = 14
 config.text_config.num_hidden_layers = 1
 config.qformer_config.num_hidden_layers = 1
 config.vision_config.num_hidden_layers = 1
diff --git a/tests/test_shardformer/test_model/test_shard_blip2.py b/tests/test_shardformer/test_model/test_shard_blip2.py
index f96299e55a49..0564af329d35 100644
--- a/tests/test_shardformer/test_model/test_shard_blip2.py
+++ b/tests/test_shardformer/test_model/test_shard_blip2.py
@@ -81,10 +81,13 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
 
 @parameterize('enable_fused_normalization', [True, False])
 @parameterize('enable_tensor_parallelism', [True, False])
-def run_blip2_test(enable_fused_normalization, enable_tensor_parallelism):
+@parameterize('enable_flash_attention', [True, False])
+@parameterize('enable_jit_fused', [True, False])
+def run_blip2_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused):
     sub_model_zoo = model_zoo.get_sub_registry('transformers_blip2')
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism)
+        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
+                                               enable_flash_attention, enable_jit_fused)
         check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
 
     torch.cuda.empty_cache()
diff --git a/tests/test_shardformer/test_model/test_shard_vit.py b/tests/test_shardformer/test_model/test_shard_vit.py
index a8048a9bdd12..13cde0e611ff 100644
--- a/tests/test_shardformer/test_model/test_shard_vit.py
+++ b/tests/test_shardformer/test_model/test_shard_vit.py
@@ -50,7 +50,7 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
             all_shard_grad = shard_grad
         assert torch.allclose(org_grad, all_shard_grad, atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{shard_grad}"
 
-    
+
 
 @parameterize('enable_fused_normalization', [True, False])
 @parameterize('enable_tensor_parallelism', [True, False])
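
Usage note (not part of the patch above, only a sketch): the new code paths are driven by the `enable_flash_attention` and `enable_jit_fused` flags on `ShardConfig`, so they could be exercised roughly as below. This assumes the existing `ShardConfig`/`ShardFormer` entry points and a tiny BLIP-2 config mirroring the test fixture changed above; the exact return value of `optimize()` varies between versions.

    # Hypothetical usage sketch (assumed API, not introduced by this patch).
    import transformers

    from colossalai.shardformer import ShardConfig, ShardFormer

    # A tiny BLIP-2 config, mirroring the test fixture in tests/kit/model_zoo/transformers/blip2.py.
    config = transformers.Blip2Config()
    config.vision_config.patch_size = 14
    config.text_config.num_hidden_layers = 1
    config.qformer_config.num_hidden_layers = 1
    config.vision_config.num_hidden_layers = 1
    model = transformers.Blip2ForConditionalGeneration(config)

    # Tensor parallelism is disabled here so no process group is required;
    # the flash-attention and JIT-fused forwards are switched on instead,
    # which makes BlipPolicy install the method replacements added in this patch.
    shard_config = ShardConfig(enable_tensor_parallelism=False,
                               enable_flash_attention=True,
                               enable_jit_fused=True)
    shard_former = ShardFormer(shard_config=shard_config)
    sharded_model = shard_former.optimize(model)    # some versions also return shared parameters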