From 0804ae9a46202c3ff52c1a4ab17453e2150751fa Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Wed, 24 Apr 2024 05:27:56 +0000 Subject: [PATCH 1/2] update transformers update transformers fix fix --- colossalai/shardformer/modeling/mistral.py | 2 +- colossalai/shardformer/policies/gpt2.py | 9 ++++++++- colossalai/shardformer/policies/gptj.py | 9 ++++++++- colossalai/shardformer/policies/llama.py | 19 +++++++++++++------ colossalai/shardformer/policies/mistral.py | 18 +++++++++++------- colossalai/shardformer/policies/opt.py | 14 +++++++++++--- tests/kit/model_zoo/transformers/llama.py | 1 - .../test_model/test_shard_mistral.py | 2 +- 8 files changed, 53 insertions(+), 21 deletions(-) diff --git a/colossalai/shardformer/modeling/mistral.py b/colossalai/shardformer/modeling/mistral.py index 3b876bcab96a..068623047af1 100644 --- a/colossalai/shardformer/modeling/mistral.py +++ b/colossalai/shardformer/modeling/mistral.py @@ -220,4 +220,4 @@ def forward( return attn_output, attn_weights, past_key_value - return forward + return forward \ No newline at end of file diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index 98db7b948954..6f4f835a8dbe 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -35,13 +35,20 @@ def preprocess(self): Reshape the Embedding layer to make the embedding dimension divisible by world_size """ self.tie_weight = self.tie_weight_check() + self.origin_attn_implement = self.model.config._attn_implementation return self.model def module_policy(self): from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2Model + ATTN_IMPLEMENTATION = { + "eager": GPT2Attention, + } + policy = {} + attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement] + embedding_cls = None if self.shard_config.enable_tensor_parallelism: embedding_cls = col_nn.VocabParallelEmbedding1D @@ -186,7 +193,7 @@ def module_policy(self): "forward": get_gpt2_flash_attention_forward(), }, policy=policy, - target_key=GPT2Attention, + target_key=attn_cls, ) if not self.shard_config.pipeline_stage_manager: policy[GPT2Model].method_replacement = { diff --git a/colossalai/shardformer/policies/gptj.py b/colossalai/shardformer/policies/gptj.py index 4b69137a6892..1280efaec921 100644 --- a/colossalai/shardformer/policies/gptj.py +++ b/colossalai/shardformer/policies/gptj.py @@ -30,13 +30,20 @@ def config_sanity_check(self): def preprocess(self): self.tie_weight = self.tie_weight_check() + self.origin_attn_implement = self.model.config._attn_implementation return self.model def module_policy(self): from transformers.models.gptj.modeling_gptj import GPTJAttention, GPTJBlock, GPTJModel + ATTN_IMPLEMENTATION = { + "eager": GPTJAttention, + } + policy = {} + attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement] + embedding_cls = None if self.shard_config.enable_tensor_parallelism: embedding_cls = col_nn.VocabParallelEmbedding1D @@ -160,7 +167,7 @@ def module_policy(self): "forward": get_gptj_flash_attention_forward(), }, policy=policy, - target_key=GPTJAttention, + target_key=attn_cls, ) if not self.shard_config.pipeline_stage_manager: self.append_or_create_method_replacement( diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index ff686a179553..bc54dcfc8a77 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -36,13 +36,20 @@ def config_sanity_check(self): def preprocess(self): self.tie_weight = self.tie_weight_check() + self.origin_attn_implement = self.model.config._attn_implementation return self.model def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: - from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel - + from transformers.models.llama.modeling_llama import LlamaAttention, LlamaFlashAttention2, LlamaSdpaAttention, LlamaDecoderLayer, LlamaModel + ATTN_IMPLEMENTATION = { + "eager": LlamaAttention, + "flash_attention_2": LlamaFlashAttention2, + "sdpa": LlamaSdpaAttention, + } policy = {} + attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement] + embedding_cls = None if self.shard_config.enable_tensor_parallelism: embedding_cls = VocabParallelEmbedding1D @@ -93,7 +100,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: "forward": get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group), }, policy=policy, - target_key=LlamaAttention, + target_key=attn_cls, ) elif sp_mode == "all_to_all": decoder_attribute_replacement = { @@ -102,7 +109,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: if getattr(self.model.config, "num_key_value_heads", False): decoder_attribute_replacement["num_key_value_heads"] = self.model.config.num_key_value_heads // sp_size - policy[LlamaAttention] = ModulePolicyDescription( + policy[attn_cls] = ModulePolicyDescription( attribute_replacement=decoder_attribute_replacement, ) self.append_or_create_method_replacement( @@ -110,7 +117,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: "forward": get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group), }, policy=policy, - target_key=LlamaAttention, + target_key=attn_cls, ) self.append_or_create_method_replacement( description={ @@ -221,7 +228,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: "forward": get_llama_flash_attention_forward(self.shard_config, sp_mode, sp_group, sp_size), }, policy=policy, - target_key=LlamaAttention, + target_key=attn_cls, ) if self.pipeline_stage_manager is None: # replace llama model forward method diff --git a/colossalai/shardformer/policies/mistral.py b/colossalai/shardformer/policies/mistral.py index 61e1b5f9c7b4..9e599cf19750 100644 --- a/colossalai/shardformer/policies/mistral.py +++ b/colossalai/shardformer/policies/mistral.py @@ -26,13 +26,21 @@ def config_sanity_check(self): def preprocess(self): self.tie_weight = self.tie_weight_check() + self.origin_attn_implement = self.model.config._attn_implementation return self.model def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: - from transformers.models.mistral.modeling_mistral import MistralAttention, MistralDecoderLayer, MistralModel + from transformers.models.mistral.modeling_mistral import MistralAttention, MistralFlashAttention2, MistralDecoderLayer, MistralModel + + ATTN_IMPLEMENTATION = { + "eager": MistralAttention, + "flash_attention_2": MistralFlashAttention2, + } policy = {} + attn_cls = ATTN_IMPLEMENTATION[self.model.config._attn_implementation] + embedding_cls = None if self.shard_config.enable_tensor_parallelism: embedding_cls = VocabParallelEmbedding1D @@ -128,10 +136,10 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: if self.shard_config.enable_flash_attention: self.append_or_create_method_replacement( description={ - "forward": get_mistral_flash_attention_forward(), + "forward": get_mistral_flash_attention_forward(self.shard_config), }, policy=policy, - target_key=MistralAttention, + target_key=attn_cls, ) return policy @@ -143,9 +151,6 @@ def set_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) method_replacement = {"forward": partial(new_forward)} self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls) - def set_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: - method_replacement = {"forward": partial(new_forward)} - self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls) class MistralModelPolicy(MistralPolicy): @@ -155,7 +160,6 @@ def __init__(self) -> None: def module_policy(self): policy = super().module_policy() from transformers.models.mistral.modeling_mistral import MistralModel - self.set_forward(model_cls=MistralModel, new_forward=MistralForwards.mistral_model_forward, policy=policy) return policy diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py index 2bb28b095114..6e248d09f396 100644 --- a/colossalai/shardformer/policies/opt.py +++ b/colossalai/shardformer/policies/opt.py @@ -44,13 +44,21 @@ def config_sanity_check(self): def preprocess(self): self.tie_weight = self.tie_weight_check() + self.origin_attn_implement = self.model.config._attn_implementation return self.model def module_policy(self): - from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoder, OPTDecoderLayer + from transformers.models.opt.modeling_opt import OPTAttention, OptFlashAttention2, OPTDecoder, OPTDecoderLayer + + ATTN_IMPLEMENTATION = { + "eager": OPTAttention, + "flash_attention_2": OptFlashAttention2, + } policy = {} + attn_cls = ATTN_IMPLEMENTATION[self.model.config._attn_implementation] + embedding_cls = None if self.shard_config.enable_tensor_parallelism: embedding_cls = VocabParallelEmbedding1D @@ -81,7 +89,7 @@ def module_policy(self): ] ) - policy[OPTAttention] = ModulePolicyDescription( + policy[attn_cls] = ModulePolicyDescription( attribute_replacement={ "embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, "num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, @@ -151,7 +159,7 @@ def module_policy(self): "forward": get_opt_flash_attention_forward(self.shard_config), }, policy=policy, - target_key=OPTAttention, + target_key=attn_cls, ) if not self.shard_config.pipeline_stage_manager: self.append_or_create_method_replacement( diff --git a/tests/kit/model_zoo/transformers/llama.py b/tests/kit/model_zoo/transformers/llama.py index 08c05e9063bf..58b5b0487a82 100644 --- a/tests/kit/model_zoo/transformers/llama.py +++ b/tests/kit/model_zoo/transformers/llama.py @@ -65,7 +65,6 @@ def data_gen_for_casual_lm(): num_attention_heads=4, max_position_embeddings=128, num_labels=16, - attn_implementation="eager", ) if hasattr(config, "pad_token_id"): diff --git a/tests/test_shardformer/test_model/test_shard_mistral.py b/tests/test_shardformer/test_model/test_shard_mistral.py index 07bc91b33b72..f127472aee0b 100644 --- a/tests/test_shardformer/test_model/test_shard_mistral.py +++ b/tests/test_shardformer/test_model/test_shard_mistral.py @@ -156,7 +156,7 @@ def check_mistral(rank, world_size, port): run_mistral_test() -@pytest.mark.skip("This test should be run on a version of transformers not less than 4.35.2.") +@pytest.mark.skip("something wrong with pipeline parallelism") @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() From a55324831859692994bbb79e7f7edcbcb6708b0d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 24 Apr 2024 06:17:01 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- colossalai/shardformer/modeling/mistral.py | 2 +- colossalai/shardformer/policies/llama.py | 9 ++++++++- colossalai/shardformer/policies/mistral.py | 9 +++++++-- colossalai/shardformer/policies/opt.py | 2 +- 4 files changed, 17 insertions(+), 5 deletions(-) diff --git a/colossalai/shardformer/modeling/mistral.py b/colossalai/shardformer/modeling/mistral.py index 068623047af1..3b876bcab96a 100644 --- a/colossalai/shardformer/modeling/mistral.py +++ b/colossalai/shardformer/modeling/mistral.py @@ -220,4 +220,4 @@ def forward( return attn_output, attn_weights, past_key_value - return forward \ No newline at end of file + return forward diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index bc54dcfc8a77..1b30ae9c9f40 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -40,7 +40,14 @@ def preprocess(self): return self.model def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: - from transformers.models.llama.modeling_llama import LlamaAttention, LlamaFlashAttention2, LlamaSdpaAttention, LlamaDecoderLayer, LlamaModel + from transformers.models.llama.modeling_llama import ( + LlamaAttention, + LlamaDecoderLayer, + LlamaFlashAttention2, + LlamaModel, + LlamaSdpaAttention, + ) + ATTN_IMPLEMENTATION = { "eager": LlamaAttention, "flash_attention_2": LlamaFlashAttention2, diff --git a/colossalai/shardformer/policies/mistral.py b/colossalai/shardformer/policies/mistral.py index 9e599cf19750..b3f89b4042c1 100644 --- a/colossalai/shardformer/policies/mistral.py +++ b/colossalai/shardformer/policies/mistral.py @@ -30,7 +30,12 @@ def preprocess(self): return self.model def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: - from transformers.models.mistral.modeling_mistral import MistralAttention, MistralFlashAttention2, MistralDecoderLayer, MistralModel + from transformers.models.mistral.modeling_mistral import ( + MistralAttention, + MistralDecoderLayer, + MistralFlashAttention2, + MistralModel, + ) ATTN_IMPLEMENTATION = { "eager": MistralAttention, @@ -152,7 +157,6 @@ def set_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls) - class MistralModelPolicy(MistralPolicy): def __init__(self) -> None: super().__init__() @@ -160,6 +164,7 @@ def __init__(self) -> None: def module_policy(self): policy = super().module_policy() from transformers.models.mistral.modeling_mistral import MistralModel + self.set_forward(model_cls=MistralModel, new_forward=MistralForwards.mistral_model_forward, policy=policy) return policy diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py index 6e248d09f396..2f6eabd5fef9 100644 --- a/colossalai/shardformer/policies/opt.py +++ b/colossalai/shardformer/policies/opt.py @@ -48,7 +48,7 @@ def preprocess(self): return self.model def module_policy(self): - from transformers.models.opt.modeling_opt import OPTAttention, OptFlashAttention2, OPTDecoder, OPTDecoderLayer + from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoder, OPTDecoderLayer, OptFlashAttention2 ATTN_IMPLEMENTATION = { "eager": OPTAttention,