From 71c3f0cb1bc52a08afff84ba5c5a1998923ed79a Mon Sep 17 00:00:00 2001
From: ver217
Date: Wed, 24 Apr 2024 12:37:57 +0800
Subject: [PATCH] [shardformer] fix attn replacement

---
 colossalai/shardformer/policies/falcon.py   | 20 +++++-------
 colossalai/shardformer/policies/sam.py      | 34 +++++++++++----------
 colossalai/shardformer/policies/whisper.py  | 16 ++++++++++
 tests/kit/model_zoo/transformers/whisper.py |  1 -
 4 files changed, 42 insertions(+), 29 deletions(-)

diff --git a/colossalai/shardformer/policies/falcon.py b/colossalai/shardformer/policies/falcon.py
index 628e9fdc0d96..09d895843b61 100644
--- a/colossalai/shardformer/policies/falcon.py
+++ b/colossalai/shardformer/policies/falcon.py
@@ -7,12 +7,7 @@
 import colossalai.shardformer.layer as col_nn
 
-from ..modeling.falcon import (
-    FalconPipelineForwards,
-    build_falcon_alibi_tensor_fn,
-    get_falcon_flash_attention_forward,
-    get_tp_falcon_decoder_layer_forward,
-)
+from ..modeling.falcon import FalconPipelineForwards, build_falcon_alibi_tensor_fn, get_tp_falcon_decoder_layer_forward
 from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
 
 __all__ = ["FalconPolicy"]
@@ -30,7 +25,7 @@ def preprocess(self):
         return self.model
 
     def module_policy(self):
-        from transformers.models.falcon.modeling_falcon import FalconAttention, FalconDecoderLayer, FalconModel
+        from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconModel
 
         if not self.model.config.new_decoder_architecture and self.model.config.multi_query:
             warnings.warn(
@@ -141,11 +136,12 @@ def module_policy(self):
             )
 
         if self.shard_config.enable_flash_attention:
-            self.append_or_create_method_replacement(
-                description={"forward": get_falcon_flash_attention_forward()},
-                policy=policy,
-                target_key=FalconAttention,
-            )
+            warnings.warn("Falcon doesn't support flash attention now, fallback to transformers attention.")
+            # self.append_or_create_method_replacement(
+            #     description={"forward": get_falcon_flash_attention_forward()},
+            #     policy=policy,
+            #     target_key=FalconAttention,
+            # )
         return policy
 
     def postprocess(self):
diff --git a/colossalai/shardformer/policies/sam.py b/colossalai/shardformer/policies/sam.py
index 498e62164b09..ce33925ff82e 100644
--- a/colossalai/shardformer/policies/sam.py
+++ b/colossalai/shardformer/policies/sam.py
@@ -1,6 +1,8 @@
+import warnings
+
 import colossalai.shardformer.layer as col_nn
 
-from ..modeling.sam import forward_fn, get_sam_flash_attention_forward, get_sam_vision_flash_attention_forward
+from ..modeling.sam import forward_fn
 from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
 
 __all__ = ["SamPolicy", "SamModelPolicy"]
@@ -15,7 +17,6 @@ def preprocess(self):
 
     def module_policy(self):
         from transformers.models.sam.modeling_sam import (
-            SamAttention,
             SamTwoWayAttentionBlock,
             SamTwoWayTransformer,
             SamVisionAttention,
@@ -210,20 +211,21 @@ def module_policy(self):
 
         # use flash attention
         if self.shard_config.enable_flash_attention:
-            self.append_or_create_method_replacement(
-                description={
-                    "forward": get_sam_flash_attention_forward(),
-                },
-                policy=policy,
-                target_key=SamAttention,
-            )
-            self.append_or_create_method_replacement(
-                description={
-                    "forward": get_sam_vision_flash_attention_forward(),
-                },
-                policy=policy,
-                target_key=SamVisionAttention,
-            )
+            warnings.warn("Flash attention is not supported in SAM model. Fallback to normal attention.")
+            # self.append_or_create_method_replacement(
+            #     description={
+            #         "forward": get_sam_flash_attention_forward(),
+            #     },
+            #     policy=policy,
+            #     target_key=SamAttention,
+            # )
+            # self.append_or_create_method_replacement(
+            #     description={
+            #         "forward": get_sam_vision_flash_attention_forward(),
+            #     },
+            #     policy=policy,
+            #     target_key=SamVisionAttention,
+            # )
 
         return policy
 
diff --git a/colossalai/shardformer/policies/whisper.py b/colossalai/shardformer/policies/whisper.py
index 16ed2607c6f7..aeb6687971e5 100644
--- a/colossalai/shardformer/policies/whisper.py
+++ b/colossalai/shardformer/policies/whisper.py
@@ -48,6 +48,8 @@ def module_policy(self):
             WhisperDecoderLayer,
             WhisperEncoder,
             WhisperEncoderLayer,
+            WhisperFlashAttention2,
+            WhisperSdpaAttention,
         )
 
         policy = {}
@@ -242,6 +244,20 @@ def module_policy(self):
                 policy=policy,
                 target_key=WhisperAttention,
             )
+            self.append_or_create_method_replacement(
+                description={
+                    "forward": get_whisper_flash_attention_forward(),
+                },
+                policy=policy,
+                target_key=WhisperFlashAttention2,
+            )
+            self.append_or_create_method_replacement(
+                description={
+                    "forward": get_whisper_flash_attention_forward(),
+                },
+                policy=policy,
+                target_key=WhisperSdpaAttention,
+            )
         if not self.shard_config.pipeline_stage_manager:
             self.append_or_create_method_replacement(
                 description={
diff --git a/tests/kit/model_zoo/transformers/whisper.py b/tests/kit/model_zoo/transformers/whisper.py
index 0d9a581dfbe9..d69bebe6cc04 100644
--- a/tests/kit/model_zoo/transformers/whisper.py
+++ b/tests/kit/model_zoo/transformers/whisper.py
@@ -66,7 +66,6 @@ def data_gen_for_audio_classification():
         encoder_ffn_dim=1536,
         encoder_layers=2,
         vocab_size=51866,
-        _attn_implementation="eager",
     )
 
 # register the Whisper variants