diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py
index 98db7b948954..6f4f835a8dbe 100644
--- a/colossalai/shardformer/policies/gpt2.py
+++ b/colossalai/shardformer/policies/gpt2.py
@@ -35,13 +35,20 @@ def preprocess(self):
         Reshape the Embedding layer to make the embedding dimension divisible by world_size
         """
         self.tie_weight = self.tie_weight_check()
+        self.origin_attn_implement = self.model.config._attn_implementation
         return self.model
 
     def module_policy(self):
         from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2Model
 
+        ATTN_IMPLEMENTATION = {
+            "eager": GPT2Attention,
+        }
+
         policy = {}
 
+        attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement]
+
         embedding_cls = None
         if self.shard_config.enable_tensor_parallelism:
             embedding_cls = col_nn.VocabParallelEmbedding1D
@@ -186,7 +193,7 @@ def module_policy(self):
                     "forward": get_gpt2_flash_attention_forward(),
                 },
                 policy=policy,
-                target_key=GPT2Attention,
+                target_key=attn_cls,
             )
             if not self.shard_config.pipeline_stage_manager:
                 policy[GPT2Model].method_replacement = {
diff --git a/colossalai/shardformer/policies/gptj.py b/colossalai/shardformer/policies/gptj.py
index 4b69137a6892..1280efaec921 100644
--- a/colossalai/shardformer/policies/gptj.py
+++ b/colossalai/shardformer/policies/gptj.py
@@ -30,13 +30,20 @@ def config_sanity_check(self):
 
     def preprocess(self):
         self.tie_weight = self.tie_weight_check()
+        self.origin_attn_implement = self.model.config._attn_implementation
         return self.model
 
     def module_policy(self):
         from transformers.models.gptj.modeling_gptj import GPTJAttention, GPTJBlock, GPTJModel
 
+        ATTN_IMPLEMENTATION = {
+            "eager": GPTJAttention,
+        }
+
         policy = {}
 
+        attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement]
+
         embedding_cls = None
         if self.shard_config.enable_tensor_parallelism:
             embedding_cls = col_nn.VocabParallelEmbedding1D
@@ -160,7 +167,7 @@ def module_policy(self):
                     "forward": get_gptj_flash_attention_forward(),
                 },
                 policy=policy,
-                target_key=GPTJAttention,
+                target_key=attn_cls,
             )
             if not self.shard_config.pipeline_stage_manager:
                 self.append_or_create_method_replacement(
diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py
index ff686a179553..1b30ae9c9f40 100644
--- a/colossalai/shardformer/policies/llama.py
+++ b/colossalai/shardformer/policies/llama.py
@@ -36,13 +36,27 @@ def config_sanity_check(self):
 
     def preprocess(self):
         self.tie_weight = self.tie_weight_check()
+        self.origin_attn_implement = self.model.config._attn_implementation
         return self.model
 
     def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
-        from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel
+        from transformers.models.llama.modeling_llama import (
+            LlamaAttention,
+            LlamaDecoderLayer,
+            LlamaFlashAttention2,
+            LlamaModel,
+            LlamaSdpaAttention,
+        )
 
+        ATTN_IMPLEMENTATION = {
+            "eager": LlamaAttention,
+            "flash_attention_2": LlamaFlashAttention2,
+            "sdpa": LlamaSdpaAttention,
+        }
         policy = {}
 
+        attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement]
+
         embedding_cls = None
         if self.shard_config.enable_tensor_parallelism:
             embedding_cls = VocabParallelEmbedding1D
@@ -93,7 +107,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
                     "forward": get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group),
                 },
                 policy=policy,
-                target_key=LlamaAttention,
+                target_key=attn_cls,
             )
         elif sp_mode == "all_to_all":
             decoder_attribute_replacement = {
@@ -102,7 +116,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
             if getattr(self.model.config, "num_key_value_heads", False):
                 decoder_attribute_replacement["num_key_value_heads"] = self.model.config.num_key_value_heads // sp_size
 
-            policy[LlamaAttention] = ModulePolicyDescription(
+            policy[attn_cls] = ModulePolicyDescription(
                 attribute_replacement=decoder_attribute_replacement,
             )
             self.append_or_create_method_replacement(
@@ -110,7 +124,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
                     "forward": get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group),
                 },
                 policy=policy,
-                target_key=LlamaAttention,
+                target_key=attn_cls,
             )
             self.append_or_create_method_replacement(
                 description={
@@ -221,7 +235,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
                     "forward": get_llama_flash_attention_forward(self.shard_config, sp_mode, sp_group, sp_size),
                 },
                 policy=policy,
-                target_key=LlamaAttention,
+                target_key=attn_cls,
             )
             if self.pipeline_stage_manager is None:
                 # replace llama model forward method
diff --git a/colossalai/shardformer/policies/mistral.py b/colossalai/shardformer/policies/mistral.py
index 61e1b5f9c7b4..b3f89b4042c1 100644
--- a/colossalai/shardformer/policies/mistral.py
+++ b/colossalai/shardformer/policies/mistral.py
@@ -26,13 +26,26 @@ def config_sanity_check(self):
 
     def preprocess(self):
         self.tie_weight = self.tie_weight_check()
+        self.origin_attn_implement = self.model.config._attn_implementation
         return self.model
 
     def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
-        from transformers.models.mistral.modeling_mistral import MistralAttention, MistralDecoderLayer, MistralModel
+        from transformers.models.mistral.modeling_mistral import (
+            MistralAttention,
+            MistralDecoderLayer,
+            MistralFlashAttention2,
+            MistralModel,
+        )
+
+        ATTN_IMPLEMENTATION = {
+            "eager": MistralAttention,
+            "flash_attention_2": MistralFlashAttention2,
+        }
 
         policy = {}
 
+        attn_cls = ATTN_IMPLEMENTATION[self.model.config._attn_implementation]
+
         embedding_cls = None
         if self.shard_config.enable_tensor_parallelism:
             embedding_cls = VocabParallelEmbedding1D
@@ -128,10 +141,10 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
         if self.shard_config.enable_flash_attention:
             self.append_or_create_method_replacement(
                 description={
-                    "forward": get_mistral_flash_attention_forward(),
+                    "forward": get_mistral_flash_attention_forward(self.shard_config),
                 },
                 policy=policy,
-                target_key=MistralAttention,
+                target_key=attn_cls,
             )
 
         return policy
@@ -143,10 +156,6 @@ def set_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict)
         method_replacement = {"forward": partial(new_forward)}
         self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)
 
-    def set_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
-        method_replacement = {"forward": partial(new_forward)}
-        self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)
-
 
 class MistralModelPolicy(MistralPolicy):
     def __init__(self) -> None:
diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py
index 2bb28b095114..2f6eabd5fef9 100644
--- a/colossalai/shardformer/policies/opt.py
+++ b/colossalai/shardformer/policies/opt.py
@@ -44,13 +44,21 @@ def config_sanity_check(self):
 
     def preprocess(self):
         self.tie_weight = self.tie_weight_check()
+        self.origin_attn_implement = self.model.config._attn_implementation
         return self.model
 
     def module_policy(self):
-        from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoder, OPTDecoderLayer
+        from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoder, OPTDecoderLayer, OptFlashAttention2
+
+        ATTN_IMPLEMENTATION = {
+            "eager": OPTAttention,
+            "flash_attention_2": OptFlashAttention2,
+        }
 
         policy = {}
 
+        attn_cls = ATTN_IMPLEMENTATION[self.model.config._attn_implementation]
+
         embedding_cls = None
         if self.shard_config.enable_tensor_parallelism:
             embedding_cls = VocabParallelEmbedding1D
@@ -81,7 +89,7 @@ def module_policy(self):
                 ]
             )
 
-            policy[OPTAttention] = ModulePolicyDescription(
+            policy[attn_cls] = ModulePolicyDescription(
                 attribute_replacement={
                     "embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                     "num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
@@ -151,7 +159,7 @@ def module_policy(self):
                     "forward": get_opt_flash_attention_forward(self.shard_config),
                 },
                 policy=policy,
-                target_key=OPTAttention,
+                target_key=attn_cls,
             )
             if not self.shard_config.pipeline_stage_manager:
                 self.append_or_create_method_replacement(
diff --git a/tests/kit/model_zoo/transformers/llama.py b/tests/kit/model_zoo/transformers/llama.py
index 08c05e9063bf..58b5b0487a82 100644
--- a/tests/kit/model_zoo/transformers/llama.py
+++ b/tests/kit/model_zoo/transformers/llama.py
@@ -65,7 +65,6 @@ def data_gen_for_casual_lm():
     num_attention_heads=4,
     max_position_embeddings=128,
     num_labels=16,
-    attn_implementation="eager",
 )
 
 if hasattr(config, "pad_token_id"):
diff --git a/tests/test_shardformer/test_model/test_shard_mistral.py b/tests/test_shardformer/test_model/test_shard_mistral.py
index 07bc91b33b72..f127472aee0b 100644
--- a/tests/test_shardformer/test_model/test_shard_mistral.py
+++ b/tests/test_shardformer/test_model/test_shard_mistral.py
@@ -156,7 +156,7 @@ def check_mistral(rank, world_size, port):
     run_mistral_test()
 
 
-@pytest.mark.skip("This test should be run on a version of transformers not less than 4.35.2.")
+@pytest.mark.skip("something wrong with pipeline parallelism")
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 @clear_cache_before_run()
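
The recurring change across these policies is that `target_key` is no longer hard-coded to the eager attention class; it is resolved from `config._attn_implementation`, so forward replacements land on whichever attention variant (eager, SDPA, FlashAttention2) transformers actually instantiated. Below is a minimal, self-contained sketch of that dispatch pattern; the class and config names are placeholders for illustration, not the real transformers or ColossalAI types.

```python
# Sketch of the attention-implementation dispatch introduced by the diff.
# EagerAttention / SdpaAttention / FlashAttention2 and DummyConfig are
# placeholder stand-ins, not the actual transformers classes.
class EagerAttention: ...
class SdpaAttention: ...
class FlashAttention2: ...


class DummyConfig:
    _attn_implementation = "sdpa"  # what HF records when it builds the model


class DummyPolicy:
    def __init__(self, config):
        # captured in preprocess() in the real policies
        self.origin_attn_implement = config._attn_implementation

    def module_policy(self):
        # Map the config flag to the concrete attention class so that
        # method replacements target the module class that actually exists.
        ATTN_IMPLEMENTATION = {
            "eager": EagerAttention,
            "sdpa": SdpaAttention,
            "flash_attention_2": FlashAttention2,
        }
        attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement]
        return {attn_cls: {"forward": "custom_flash_attention_forward"}}


print(DummyPolicy(DummyConfig()).module_policy())  # keyed on SdpaAttention
```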