From f3be14dfbf206fb417e307f28c7f299ea1c23870 Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Wed, 24 Apr 2024 14:53:21 +0800 Subject: [PATCH 01/11] [shardformer] fix attn replacement (#5636) --- colossalai/shardformer/policies/falcon.py | 20 +++++------- colossalai/shardformer/policies/sam.py | 34 +++++++++++---------- colossalai/shardformer/policies/whisper.py | 16 ++++++++++ tests/kit/model_zoo/transformers/whisper.py | 1 - 4 files changed, 42 insertions(+), 29 deletions(-) diff --git a/colossalai/shardformer/policies/falcon.py b/colossalai/shardformer/policies/falcon.py index 628e9fdc0d96..09d895843b61 100644 --- a/colossalai/shardformer/policies/falcon.py +++ b/colossalai/shardformer/policies/falcon.py @@ -7,12 +7,7 @@ import colossalai.shardformer.layer as col_nn -from ..modeling.falcon import ( - FalconPipelineForwards, - build_falcon_alibi_tensor_fn, - get_falcon_flash_attention_forward, - get_tp_falcon_decoder_layer_forward, -) +from ..modeling.falcon import FalconPipelineForwards, build_falcon_alibi_tensor_fn, get_tp_falcon_decoder_layer_forward from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = ["FalconPolicy"] @@ -30,7 +25,7 @@ def preprocess(self): return self.model def module_policy(self): - from transformers.models.falcon.modeling_falcon import FalconAttention, FalconDecoderLayer, FalconModel + from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconModel if not self.model.config.new_decoder_architecture and self.model.config.multi_query: warnings.warn( @@ -141,11 +136,12 @@ def module_policy(self): ) if self.shard_config.enable_flash_attention: - self.append_or_create_method_replacement( - description={"forward": get_falcon_flash_attention_forward()}, - policy=policy, - target_key=FalconAttention, - ) + warnings.warn("Falcon doesn't support flash attention now, fallback to transformers attention.") + # self.append_or_create_method_replacement( + # description={"forward": get_falcon_flash_attention_forward()}, + # policy=policy, + # target_key=FalconAttention, + # ) return policy def postprocess(self): diff --git a/colossalai/shardformer/policies/sam.py b/colossalai/shardformer/policies/sam.py index 498e62164b09..ce33925ff82e 100644 --- a/colossalai/shardformer/policies/sam.py +++ b/colossalai/shardformer/policies/sam.py @@ -1,6 +1,8 @@ +import warnings + import colossalai.shardformer.layer as col_nn -from ..modeling.sam import forward_fn, get_sam_flash_attention_forward, get_sam_vision_flash_attention_forward +from ..modeling.sam import forward_fn from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = ["SamPolicy", "SamModelPolicy"] @@ -15,7 +17,6 @@ def preprocess(self): def module_policy(self): from transformers.models.sam.modeling_sam import ( - SamAttention, SamTwoWayAttentionBlock, SamTwoWayTransformer, SamVisionAttention, @@ -210,20 +211,21 @@ def module_policy(self): # use flash attention if self.shard_config.enable_flash_attention: - self.append_or_create_method_replacement( - description={ - "forward": get_sam_flash_attention_forward(), - }, - policy=policy, - target_key=SamAttention, - ) - self.append_or_create_method_replacement( - description={ - "forward": get_sam_vision_flash_attention_forward(), - }, - policy=policy, - target_key=SamVisionAttention, - ) + warnings.warn("Flash attention is not supported in SAM model. 
Fallback to normal attention.") + # self.append_or_create_method_replacement( + # description={ + # "forward": get_sam_flash_attention_forward(), + # }, + # policy=policy, + # target_key=SamAttention, + # ) + # self.append_or_create_method_replacement( + # description={ + # "forward": get_sam_vision_flash_attention_forward(), + # }, + # policy=policy, + # target_key=SamVisionAttention, + # ) return policy diff --git a/colossalai/shardformer/policies/whisper.py b/colossalai/shardformer/policies/whisper.py index 16ed2607c6f7..aeb6687971e5 100644 --- a/colossalai/shardformer/policies/whisper.py +++ b/colossalai/shardformer/policies/whisper.py @@ -48,6 +48,8 @@ def module_policy(self): WhisperDecoderLayer, WhisperEncoder, WhisperEncoderLayer, + WhisperFlashAttention2, + WhisperSdpaAttention, ) policy = {} @@ -242,6 +244,20 @@ def module_policy(self): policy=policy, target_key=WhisperAttention, ) + self.append_or_create_method_replacement( + description={ + "forward": get_whisper_flash_attention_forward(), + }, + policy=policy, + target_key=WhisperFlashAttention2, + ) + self.append_or_create_method_replacement( + description={ + "forward": get_whisper_flash_attention_forward(), + }, + policy=policy, + target_key=WhisperSdpaAttention, + ) if not self.shard_config.pipeline_stage_manager: self.append_or_create_method_replacement( description={ diff --git a/tests/kit/model_zoo/transformers/whisper.py b/tests/kit/model_zoo/transformers/whisper.py index 0d9a581dfbe9..d69bebe6cc04 100644 --- a/tests/kit/model_zoo/transformers/whisper.py +++ b/tests/kit/model_zoo/transformers/whisper.py @@ -66,7 +66,6 @@ def data_gen_for_audio_classification(): encoder_ffn_dim=1536, encoder_layers=2, vocab_size=51866, - _attn_implementation="eager", ) # register the Whisper variants From 10a815ace329f2d04edda81fa50d0b90ab691f91 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Wed, 24 Apr 2024 14:53:51 +0800 Subject: [PATCH 02/11] [shardformer] update flashattention replacement (#5637) * update transformers update transformers fix fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- colossalai/shardformer/policies/gpt2.py | 9 ++++++- colossalai/shardformer/policies/gptj.py | 9 ++++++- colossalai/shardformer/policies/llama.py | 24 +++++++++++++++---- colossalai/shardformer/policies/mistral.py | 23 ++++++++++++------ colossalai/shardformer/policies/opt.py | 14 ++++++++--- tests/kit/model_zoo/transformers/llama.py | 1 - .../test_model/test_shard_mistral.py | 2 +- 7 files changed, 63 insertions(+), 19 deletions(-) diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index 98db7b948954..6f4f835a8dbe 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -35,13 +35,20 @@ def preprocess(self): Reshape the Embedding layer to make the embedding dimension divisible by world_size """ self.tie_weight = self.tie_weight_check() + self.origin_attn_implement = self.model.config._attn_implementation return self.model def module_policy(self): from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2Model + ATTN_IMPLEMENTATION = { + "eager": GPT2Attention, + } + policy = {} + attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement] + embedding_cls = None if self.shard_config.enable_tensor_parallelism: embedding_cls = col_nn.VocabParallelEmbedding1D 
@@ -186,7 +193,7 @@ def module_policy(self): "forward": get_gpt2_flash_attention_forward(), }, policy=policy, - target_key=GPT2Attention, + target_key=attn_cls, ) if not self.shard_config.pipeline_stage_manager: policy[GPT2Model].method_replacement = { diff --git a/colossalai/shardformer/policies/gptj.py b/colossalai/shardformer/policies/gptj.py index 4b69137a6892..1280efaec921 100644 --- a/colossalai/shardformer/policies/gptj.py +++ b/colossalai/shardformer/policies/gptj.py @@ -30,13 +30,20 @@ def config_sanity_check(self): def preprocess(self): self.tie_weight = self.tie_weight_check() + self.origin_attn_implement = self.model.config._attn_implementation return self.model def module_policy(self): from transformers.models.gptj.modeling_gptj import GPTJAttention, GPTJBlock, GPTJModel + ATTN_IMPLEMENTATION = { + "eager": GPTJAttention, + } + policy = {} + attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement] + embedding_cls = None if self.shard_config.enable_tensor_parallelism: embedding_cls = col_nn.VocabParallelEmbedding1D @@ -160,7 +167,7 @@ def module_policy(self): "forward": get_gptj_flash_attention_forward(), }, policy=policy, - target_key=GPTJAttention, + target_key=attn_cls, ) if not self.shard_config.pipeline_stage_manager: self.append_or_create_method_replacement( diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index ff686a179553..1b30ae9c9f40 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -36,13 +36,27 @@ def config_sanity_check(self): def preprocess(self): self.tie_weight = self.tie_weight_check() + self.origin_attn_implement = self.model.config._attn_implementation return self.model def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: - from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel + from transformers.models.llama.modeling_llama import ( + LlamaAttention, + LlamaDecoderLayer, + LlamaFlashAttention2, + LlamaModel, + LlamaSdpaAttention, + ) + ATTN_IMPLEMENTATION = { + "eager": LlamaAttention, + "flash_attention_2": LlamaFlashAttention2, + "sdpa": LlamaSdpaAttention, + } policy = {} + attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement] + embedding_cls = None if self.shard_config.enable_tensor_parallelism: embedding_cls = VocabParallelEmbedding1D @@ -93,7 +107,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: "forward": get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group), }, policy=policy, - target_key=LlamaAttention, + target_key=attn_cls, ) elif sp_mode == "all_to_all": decoder_attribute_replacement = { @@ -102,7 +116,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: if getattr(self.model.config, "num_key_value_heads", False): decoder_attribute_replacement["num_key_value_heads"] = self.model.config.num_key_value_heads // sp_size - policy[LlamaAttention] = ModulePolicyDescription( + policy[attn_cls] = ModulePolicyDescription( attribute_replacement=decoder_attribute_replacement, ) self.append_or_create_method_replacement( @@ -110,7 +124,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: "forward": get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group), }, policy=policy, - target_key=LlamaAttention, + target_key=attn_cls, ) self.append_or_create_method_replacement( description={ @@ -221,7 +235,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], 
ModulePolicyDescription]: "forward": get_llama_flash_attention_forward(self.shard_config, sp_mode, sp_group, sp_size), }, policy=policy, - target_key=LlamaAttention, + target_key=attn_cls, ) if self.pipeline_stage_manager is None: # replace llama model forward method diff --git a/colossalai/shardformer/policies/mistral.py b/colossalai/shardformer/policies/mistral.py index 61e1b5f9c7b4..b3f89b4042c1 100644 --- a/colossalai/shardformer/policies/mistral.py +++ b/colossalai/shardformer/policies/mistral.py @@ -26,13 +26,26 @@ def config_sanity_check(self): def preprocess(self): self.tie_weight = self.tie_weight_check() + self.origin_attn_implement = self.model.config._attn_implementation return self.model def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: - from transformers.models.mistral.modeling_mistral import MistralAttention, MistralDecoderLayer, MistralModel + from transformers.models.mistral.modeling_mistral import ( + MistralAttention, + MistralDecoderLayer, + MistralFlashAttention2, + MistralModel, + ) + + ATTN_IMPLEMENTATION = { + "eager": MistralAttention, + "flash_attention_2": MistralFlashAttention2, + } policy = {} + attn_cls = ATTN_IMPLEMENTATION[self.model.config._attn_implementation] + embedding_cls = None if self.shard_config.enable_tensor_parallelism: embedding_cls = VocabParallelEmbedding1D @@ -128,10 +141,10 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: if self.shard_config.enable_flash_attention: self.append_or_create_method_replacement( description={ - "forward": get_mistral_flash_attention_forward(), + "forward": get_mistral_flash_attention_forward(self.shard_config), }, policy=policy, - target_key=MistralAttention, + target_key=attn_cls, ) return policy @@ -143,10 +156,6 @@ def set_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) method_replacement = {"forward": partial(new_forward)} self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls) - def set_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: - method_replacement = {"forward": partial(new_forward)} - self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls) - class MistralModelPolicy(MistralPolicy): def __init__(self) -> None: diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py index 2bb28b095114..2f6eabd5fef9 100644 --- a/colossalai/shardformer/policies/opt.py +++ b/colossalai/shardformer/policies/opt.py @@ -44,13 +44,21 @@ def config_sanity_check(self): def preprocess(self): self.tie_weight = self.tie_weight_check() + self.origin_attn_implement = self.model.config._attn_implementation return self.model def module_policy(self): - from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoder, OPTDecoderLayer + from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoder, OPTDecoderLayer, OptFlashAttention2 + + ATTN_IMPLEMENTATION = { + "eager": OPTAttention, + "flash_attention_2": OptFlashAttention2, + } policy = {} + attn_cls = ATTN_IMPLEMENTATION[self.model.config._attn_implementation] + embedding_cls = None if self.shard_config.enable_tensor_parallelism: embedding_cls = VocabParallelEmbedding1D @@ -81,7 +89,7 @@ def module_policy(self): ] ) - policy[OPTAttention] = ModulePolicyDescription( + policy[attn_cls] = ModulePolicyDescription( attribute_replacement={ "embed_dim": self.model.config.hidden_size // 
self.shard_config.tensor_parallel_size, "num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, @@ -151,7 +159,7 @@ def module_policy(self): "forward": get_opt_flash_attention_forward(self.shard_config), }, policy=policy, - target_key=OPTAttention, + target_key=attn_cls, ) if not self.shard_config.pipeline_stage_manager: self.append_or_create_method_replacement( diff --git a/tests/kit/model_zoo/transformers/llama.py b/tests/kit/model_zoo/transformers/llama.py index 08c05e9063bf..58b5b0487a82 100644 --- a/tests/kit/model_zoo/transformers/llama.py +++ b/tests/kit/model_zoo/transformers/llama.py @@ -65,7 +65,6 @@ def data_gen_for_casual_lm(): num_attention_heads=4, max_position_embeddings=128, num_labels=16, - attn_implementation="eager", ) if hasattr(config, "pad_token_id"): diff --git a/tests/test_shardformer/test_model/test_shard_mistral.py b/tests/test_shardformer/test_model/test_shard_mistral.py index 07bc91b33b72..f127472aee0b 100644 --- a/tests/test_shardformer/test_model/test_shard_mistral.py +++ b/tests/test_shardformer/test_model/test_shard_mistral.py @@ -156,7 +156,7 @@ def check_mistral(rank, world_size, port): run_mistral_test() -@pytest.mark.skip("This test should be run on a version of transformers not less than 4.35.2.") +@pytest.mark.skip("something wrong with pipeline parallelism") @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() From 889940b7056a818a207056ab67b63030f0de3a26 Mon Sep 17 00:00:00 2001 From: Tong Li Date: Tue, 23 Apr 2024 13:54:05 +0800 Subject: [PATCH 03/11] [Feature] Support LLaMA-3 CPT and ST (#5619) * support LLaMA-3 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Run pre-commit --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- applications/Colossal-LLaMA-2/version.txt | 1 - .../README.md | 30 +++++----- .../colossal_llama}/__init__.py | 0 .../colossal_llama}/dataset/__init__.py | 0 .../colossal_llama}/dataset/conversation.py | 14 ++++- .../colossal_llama}/dataset/loader.py | 0 .../dataset/spliced_and_tokenized_dataset.py | 3 +- .../colossal_llama}/model/init_model.py | 0 .../tokenizer/init_tokenizer.py | 0 .../colossal_llama}/utils/__init__.py | 0 .../colossal_llama}/utils/ckpt_io.py | 0 .../utils/flash_attention_patch.py | 0 .../colossal_llama}/utils/froze.py | 0 .../colossal_llama}/utils/neftune_patch.py | 0 .../utils/stream_chat_patch.py | 0 .../docs/example_13b.md | 0 .../docs/example_7b.md | 0 .../hostfile.example | 0 .../inference_example.py | 2 +- .../prepare_pretrain_dataset.py | 41 +++++--------- .../prepare_sft_dataset.py | 55 +++++++++---------- .../requirements.txt | 9 +-- .../stream_chat_example.py | 2 +- .../train.example.sh | 0 .../train.py | 16 +++--- .../train_sft.example.sh | 0 applications/Colossal-LLaMA/version.txt | 1 + applications/README.md | 2 +- 28 files changed, 89 insertions(+), 87 deletions(-) delete mode 100644 applications/Colossal-LLaMA-2/version.txt rename applications/{Colossal-LLaMA-2 => Colossal-LLaMA}/README.md (97%) rename applications/{Colossal-LLaMA-2/colossal_llama2 => Colossal-LLaMA/colossal_llama}/__init__.py (100%) rename applications/{Colossal-LLaMA-2/colossal_llama2 => Colossal-LLaMA/colossal_llama}/dataset/__init__.py (100%) rename applications/{Colossal-LLaMA-2/colossal_llama2 => Colossal-LLaMA/colossal_llama}/dataset/conversation.py (86%) rename applications/{Colossal-LLaMA-2/colossal_llama2 => 
Colossal-LLaMA/colossal_llama}/dataset/loader.py (100%) rename applications/{Colossal-LLaMA-2/colossal_llama2 => Colossal-LLaMA/colossal_llama}/dataset/spliced_and_tokenized_dataset.py (99%) rename applications/{Colossal-LLaMA-2/colossal_llama2 => Colossal-LLaMA/colossal_llama}/model/init_model.py (100%) rename applications/{Colossal-LLaMA-2/colossal_llama2 => Colossal-LLaMA/colossal_llama}/tokenizer/init_tokenizer.py (100%) rename applications/{Colossal-LLaMA-2/colossal_llama2 => Colossal-LLaMA/colossal_llama}/utils/__init__.py (100%) rename applications/{Colossal-LLaMA-2/colossal_llama2 => Colossal-LLaMA/colossal_llama}/utils/ckpt_io.py (100%) rename applications/{Colossal-LLaMA-2/colossal_llama2 => Colossal-LLaMA/colossal_llama}/utils/flash_attention_patch.py (100%) rename applications/{Colossal-LLaMA-2/colossal_llama2 => Colossal-LLaMA/colossal_llama}/utils/froze.py (100%) rename applications/{Colossal-LLaMA-2/colossal_llama2 => Colossal-LLaMA/colossal_llama}/utils/neftune_patch.py (100%) rename applications/{Colossal-LLaMA-2/colossal_llama2 => Colossal-LLaMA/colossal_llama}/utils/stream_chat_patch.py (100%) rename applications/{Colossal-LLaMA-2 => Colossal-LLaMA}/docs/example_13b.md (100%) rename applications/{Colossal-LLaMA-2 => Colossal-LLaMA}/docs/example_7b.md (100%) rename applications/{Colossal-LLaMA-2 => Colossal-LLaMA}/hostfile.example (100%) rename applications/{Colossal-LLaMA-2 => Colossal-LLaMA}/inference_example.py (97%) rename applications/{Colossal-LLaMA-2 => Colossal-LLaMA}/prepare_pretrain_dataset.py (80%) rename applications/{Colossal-LLaMA-2 => Colossal-LLaMA}/prepare_sft_dataset.py (74%) rename applications/{Colossal-LLaMA-2 => Colossal-LLaMA}/requirements.txt (65%) rename applications/{Colossal-LLaMA-2 => Colossal-LLaMA}/stream_chat_example.py (97%) rename applications/{Colossal-LLaMA-2 => Colossal-LLaMA}/train.example.sh (100%) rename applications/{Colossal-LLaMA-2 => Colossal-LLaMA}/train.py (96%) rename applications/{Colossal-LLaMA-2 => Colossal-LLaMA}/train_sft.example.sh (100%) create mode 100644 applications/Colossal-LLaMA/version.txt diff --git a/applications/Colossal-LLaMA-2/version.txt b/applications/Colossal-LLaMA-2/version.txt deleted file mode 100644 index 8acdd82b765e..000000000000 --- a/applications/Colossal-LLaMA-2/version.txt +++ /dev/null @@ -1 +0,0 @@ -0.0.1 diff --git a/applications/Colossal-LLaMA-2/README.md b/applications/Colossal-LLaMA/README.md similarity index 97% rename from applications/Colossal-LLaMA-2/README.md rename to applications/Colossal-LLaMA/README.md index 1377e1facec0..93ba58ac5894 100644 --- a/applications/Colossal-LLaMA-2/README.md +++ b/applications/Colossal-LLaMA/README.md @@ -1,6 +1,6 @@
[README banner image hunk — heading image updated so the banner reads "Colossal-LLaMA"]
@@ -47,6 +47,7 @@ - [Citations](#citations) ## News +* [2024/4] Support continual pre-training and supervised fine-tuning of LLaMA-3. * [2024/01] [Construct Refined 13B Private Model With Just $5000 USD, Upgraded Colossal-AI Llama-2 Open Source](https://hpc-ai.com/blog/colossal-llama-2-13b). [[code]](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Colossal-LLaMA-2) [[blog]](https://hpc-ai.com/blog/colossal-llama-2-13b) @@ -289,7 +290,7 @@ Here is details about CLI arguments: #### 1. Install required packages ``` -cd Colossal-LLaMA-2 +cd Colossal-LLaMA pip install -r requirements.txt ``` #### 2. Install `xentropy`, `layer_norm` and `rotary` @@ -314,7 +315,7 @@ Initialize new tokenizer with additional Chinese tokens. Additional Chinese toke Command to initialize new tokenizer: ```bash export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION='python' -python colossal_llama2/tokenizer/init_tokenizer.py \ +python colossal_llama/tokenizer/init_tokenizer.py \ --source_tokenizer_dir "" \ --target_tokenizer_dir "" \ --expand_tokens_file ".jsonl" @@ -328,7 +329,7 @@ Here is details about CLI arguments: Initialize the new model checkpoint by calculating the mean values from the original model checkpoint. Command to initialize new model checkpoint: ```bash -python colossal_llama2/model/init_model.py \ +python colossal_llama/model/init_model.py \ --source_model_and_tokenizer_path "" \ --target_tokenizer_path "" \ --target_model_path "" @@ -362,18 +363,17 @@ Command to convert jsonl dataset to arrow format: python prepare_pretrain_dataset.py \ --data_input_dirs ",," \ --tokenizer_dir "" \ - --data_cache_dir "jsonl_to_arrow_cache" \ - --data_jsonl_output_dir "spliced_tokenized_output_jsonl" \ - --data_arrow_output_dir "spliced_tokenized_output_arrow" \ + --data_output_dirs "spliced tokenized output" \ --max_length 4096 \ --num_spliced_dataset_bins 10 ``` Here is details about CLI arguments: * Source data directory: `data_input_dirs`. Each `` can have multiple file in `jsonl` format. * Tokenizer directory: `tokenizer_dir`. Path to the tokenizer in Hugging Face format. -* Data cache directory: `data_cache_dir`. Directory to store Hugging Face data cache. Default case will create `cache` folder locally. -* Output directory for jsonl format: `data_jsonl_output_dir`. Output directory to store converted dataset in jsonl format. -* Output directory for arrow format: `data_arrow_output_dir`. Output directory to store converted dataset in arrow format, which can be used for training directly. +* Data output directory: `data_output_dirs`. Directory to store preprocessed output, including three sub-directories: + * `cache`: Directory to store Hugging Face data cache. + * `jsonl`: Output directory to store converted dataset in jsonl format. + * `arrow`: Output directory to store converted dataset in arrow format, which can be used for training directly. * Max length: `max_length`. Max length of spliced samples. Default value is 4096. * Number of bins for each category: `num_spliced_dataset_bins`. Number of bins for each category, used for bucket-based training. 
@@ -392,13 +392,15 @@ Command to convert jsonl dataset to arrow format is similar to the command in [3 python prepare_sft_dataset.py.py \ --data_input_dirs ",," \ --tokenizer_dir "" \ - --data_cache_dir "jsonl_to_arrow_cache" \ - --data_jsonl_output_dir "spliced_tokenized_output_jsonl" \ - --data_arrow_output_dir "spliced_tokenized_output_arrow" \ + --data_output_dirs "spliced tokenized output" \ --max_length 4096 \ - --num_spliced_dataset_bins 10 + --num_spliced_dataset_bins 10 \ + --llama_version 3 ``` +Additional CLI arguments: +* LLaMA verison: `llama_version`. Specify the LLaMA version. + #### 4. Command Line Arguments for Training ##### 4.1 Arguments for Pretraining diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/__init__.py b/applications/Colossal-LLaMA/colossal_llama/__init__.py similarity index 100% rename from applications/Colossal-LLaMA-2/colossal_llama2/__init__.py rename to applications/Colossal-LLaMA/colossal_llama/__init__.py diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/dataset/__init__.py b/applications/Colossal-LLaMA/colossal_llama/dataset/__init__.py similarity index 100% rename from applications/Colossal-LLaMA-2/colossal_llama2/dataset/__init__.py rename to applications/Colossal-LLaMA/colossal_llama/dataset/__init__.py diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/dataset/conversation.py b/applications/Colossal-LLaMA/colossal_llama/dataset/conversation.py similarity index 86% rename from applications/Colossal-LLaMA-2/colossal_llama2/dataset/conversation.py rename to applications/Colossal-LLaMA/colossal_llama/dataset/conversation.py index be27ff7bc817..8ec9c848b2c8 100644 --- a/applications/Colossal-LLaMA-2/colossal_llama2/dataset/conversation.py +++ b/applications/Colossal-LLaMA/colossal_llama/dataset/conversation.py @@ -83,7 +83,7 @@ def dict(self): } -conv = Conversation( +LLaMA2_Conv = Conversation( system="A chat between a curious human and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n", roles=("Human", "Assistant"), @@ -93,4 +93,14 @@ def dict(self): seps=["", ""], ) -default_conversation = conv +LLaMA3_Conv = Conversation( + system="A chat between a curious human and an artificial intelligence assistant. 
" + "The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n", + roles=("Human", "Assistant"), + messages=[], + offset=0, + sep_style=SeparatorStyle.ADD_BOS_EOS_TOKEN, + seps=["<|begin_of_text|>", "<|end_of_text|>"], +) + +default_conversation = LLaMA3_Conv diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/dataset/loader.py b/applications/Colossal-LLaMA/colossal_llama/dataset/loader.py similarity index 100% rename from applications/Colossal-LLaMA-2/colossal_llama2/dataset/loader.py rename to applications/Colossal-LLaMA/colossal_llama/dataset/loader.py diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/dataset/spliced_and_tokenized_dataset.py b/applications/Colossal-LLaMA/colossal_llama/dataset/spliced_and_tokenized_dataset.py similarity index 99% rename from applications/Colossal-LLaMA-2/colossal_llama2/dataset/spliced_and_tokenized_dataset.py rename to applications/Colossal-LLaMA/colossal_llama/dataset/spliced_and_tokenized_dataset.py index 8314941babb4..30122d2838f9 100644 --- a/applications/Colossal-LLaMA-2/colossal_llama2/dataset/spliced_and_tokenized_dataset.py +++ b/applications/Colossal-LLaMA/colossal_llama/dataset/spliced_and_tokenized_dataset.py @@ -12,6 +12,7 @@ from datasets import dataset_dict from torch.utils.data import ConcatDataset, Dataset, IterableDataset +from transformers import AutoTokenizer from transformers.models.llama.tokenization_llama import LlamaTokenizer from transformers.tokenization_utils import PreTrainedTokenizer @@ -71,7 +72,7 @@ def supervised_tokenize_pretrain( def supervised_tokenize_sft( data_point: Dict[str, str], - tokenizer: LlamaTokenizer, + tokenizer: AutoTokenizer, conversation_template: Conversation = default_conversation, ignore_index: int = None, max_length: int = 4096, diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/model/init_model.py b/applications/Colossal-LLaMA/colossal_llama/model/init_model.py similarity index 100% rename from applications/Colossal-LLaMA-2/colossal_llama2/model/init_model.py rename to applications/Colossal-LLaMA/colossal_llama/model/init_model.py diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/tokenizer/init_tokenizer.py b/applications/Colossal-LLaMA/colossal_llama/tokenizer/init_tokenizer.py similarity index 100% rename from applications/Colossal-LLaMA-2/colossal_llama2/tokenizer/init_tokenizer.py rename to applications/Colossal-LLaMA/colossal_llama/tokenizer/init_tokenizer.py diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/utils/__init__.py b/applications/Colossal-LLaMA/colossal_llama/utils/__init__.py similarity index 100% rename from applications/Colossal-LLaMA-2/colossal_llama2/utils/__init__.py rename to applications/Colossal-LLaMA/colossal_llama/utils/__init__.py diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/utils/ckpt_io.py b/applications/Colossal-LLaMA/colossal_llama/utils/ckpt_io.py similarity index 100% rename from applications/Colossal-LLaMA-2/colossal_llama2/utils/ckpt_io.py rename to applications/Colossal-LLaMA/colossal_llama/utils/ckpt_io.py diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/utils/flash_attention_patch.py b/applications/Colossal-LLaMA/colossal_llama/utils/flash_attention_patch.py similarity index 100% rename from applications/Colossal-LLaMA-2/colossal_llama2/utils/flash_attention_patch.py rename to applications/Colossal-LLaMA/colossal_llama/utils/flash_attention_patch.py diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/utils/froze.py 
b/applications/Colossal-LLaMA/colossal_llama/utils/froze.py similarity index 100% rename from applications/Colossal-LLaMA-2/colossal_llama2/utils/froze.py rename to applications/Colossal-LLaMA/colossal_llama/utils/froze.py diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/utils/neftune_patch.py b/applications/Colossal-LLaMA/colossal_llama/utils/neftune_patch.py similarity index 100% rename from applications/Colossal-LLaMA-2/colossal_llama2/utils/neftune_patch.py rename to applications/Colossal-LLaMA/colossal_llama/utils/neftune_patch.py diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/utils/stream_chat_patch.py b/applications/Colossal-LLaMA/colossal_llama/utils/stream_chat_patch.py similarity index 100% rename from applications/Colossal-LLaMA-2/colossal_llama2/utils/stream_chat_patch.py rename to applications/Colossal-LLaMA/colossal_llama/utils/stream_chat_patch.py diff --git a/applications/Colossal-LLaMA-2/docs/example_13b.md b/applications/Colossal-LLaMA/docs/example_13b.md similarity index 100% rename from applications/Colossal-LLaMA-2/docs/example_13b.md rename to applications/Colossal-LLaMA/docs/example_13b.md diff --git a/applications/Colossal-LLaMA-2/docs/example_7b.md b/applications/Colossal-LLaMA/docs/example_7b.md similarity index 100% rename from applications/Colossal-LLaMA-2/docs/example_7b.md rename to applications/Colossal-LLaMA/docs/example_7b.md diff --git a/applications/Colossal-LLaMA-2/hostfile.example b/applications/Colossal-LLaMA/hostfile.example similarity index 100% rename from applications/Colossal-LLaMA-2/hostfile.example rename to applications/Colossal-LLaMA/hostfile.example diff --git a/applications/Colossal-LLaMA-2/inference_example.py b/applications/Colossal-LLaMA/inference_example.py similarity index 97% rename from applications/Colossal-LLaMA-2/inference_example.py rename to applications/Colossal-LLaMA/inference_example.py index 8d301616d678..0369d9c0ab88 100644 --- a/applications/Colossal-LLaMA-2/inference_example.py +++ b/applications/Colossal-LLaMA/inference_example.py @@ -1,7 +1,7 @@ import argparse import torch -from colossal_llama2.dataset.conversation import default_conversation +from colossal_llama.dataset.conversation import default_conversation from transformers import AutoModelForCausalLM, AutoTokenizer from colossalai.logging import get_dist_logger diff --git a/applications/Colossal-LLaMA-2/prepare_pretrain_dataset.py b/applications/Colossal-LLaMA/prepare_pretrain_dataset.py similarity index 80% rename from applications/Colossal-LLaMA-2/prepare_pretrain_dataset.py rename to applications/Colossal-LLaMA/prepare_pretrain_dataset.py index cb578b5f6585..9642159aa0f6 100644 --- a/applications/Colossal-LLaMA-2/prepare_pretrain_dataset.py +++ b/applications/Colossal-LLaMA/prepare_pretrain_dataset.py @@ -11,12 +11,12 @@ import time from multiprocessing import cpu_count -from colossal_llama2.dataset.spliced_and_tokenized_dataset import ( +from colossal_llama.dataset.spliced_and_tokenized_dataset import ( ClosedToConstantLengthSplicedDataset, supervised_tokenize_pretrain, ) from datasets import dataset_dict, load_dataset -from transformers.models.llama.tokenization_llama import LlamaTokenizer +from transformers import AutoTokenizer from colossalai.logging import get_dist_logger @@ -35,35 +35,24 @@ def main(): parser.add_argument( "--tokenizer_dir", type=str, required=True, default=None, help="A directory containing the tokenizer" ) - parser.add_argument("--data_cache_dir", type=str, default="cache", help="Data cache directory") - 
parser.add_argument( - "--data_jsonl_output_dir", - type=str, - default="jsonl_output", - help="Output directory of spliced dataset with jsonl format", - ) - parser.add_argument( - "--data_arrow_output_dir", - type=str, - default="arrow_output", - help="Output directory of spliced dataset with arrow format", - ) - parser.add_argument("--max_length", type=int, default=4096, help="Max length of each spliced tokenized sequence") + parser.add_argument("--data_output_dirs", type=str, default="data_output_dirs", help="Data output directory") + parser.add_argument("--max_length", type=int, default=8192, help="Max length of each spliced tokenized sequence") parser.add_argument("--num_spliced_dataset_bins", type=int, default=10, help="Number of spliced dataset bins") args = parser.parse_args() if args.num_spliced_dataset_bins >= 100000: raise ValueError("Too many spliced divisions, must be smaller than 100000") - assert not os.path.exists(args.data_cache_dir), f"Find existed data cache dir {args.data_cache_dir}" - assert not os.path.exists( - args.data_jsonl_output_dir - ), f"Find existed jsonl data output dir {args.data_jsonl_output_dir}" - assert not os.path.exists( - args.data_arrow_output_dir - ), f"Find existed arrow data output dir {args.data_arrow_output_dir}" - os.makedirs(args.data_jsonl_output_dir) - os.makedirs(args.data_arrow_output_dir) + args.data_cache_dir = os.path.join(args.data_output_dirs, "cache") + args.data_jsonl_output_dir = os.path.join(args.data_output_dirs, "jsonl") + args.data_arrow_output_dir = os.path.join(args.data_output_dirs, "arrow") + + if not os.path.exists(args.data_cache_dir): + os.makedirs(args.data_cache_dir) + if not os.path.exists(args.data_jsonl_output_dir): + os.makedirs(args.data_jsonl_output_dir) + if not os.path.exists(args.data_arrow_output_dir): + os.makedirs(args.data_arrow_output_dir) # Prepare to all input datasets input_data_paths = [] @@ -86,7 +75,7 @@ def main(): train_splits.append(f"train[{start}%:{end}%]") # Prepare to the tokenizer. 
- tokenizer = LlamaTokenizer.from_pretrained(args.tokenizer_dir) + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_dir) tokenizer.add_bos_token = False tokenizer.add_eos_token = False if tokenizer.pad_token is None: diff --git a/applications/Colossal-LLaMA-2/prepare_sft_dataset.py b/applications/Colossal-LLaMA/prepare_sft_dataset.py similarity index 74% rename from applications/Colossal-LLaMA-2/prepare_sft_dataset.py rename to applications/Colossal-LLaMA/prepare_sft_dataset.py index 6d19cbd72372..be5f9bcca3df 100644 --- a/applications/Colossal-LLaMA-2/prepare_sft_dataset.py +++ b/applications/Colossal-LLaMA/prepare_sft_dataset.py @@ -10,10 +10,10 @@ import os from multiprocessing import cpu_count -from colossal_llama2.dataset.conversation import default_conversation -from colossal_llama2.dataset.spliced_and_tokenized_dataset import supervised_tokenize_sft +from colossal_llama.dataset.conversation import default_conversation +from colossal_llama.dataset.spliced_and_tokenized_dataset import supervised_tokenize_sft from datasets import dataset_dict, load_dataset -from transformers.models.llama.tokenization_llama import LlamaTokenizer +from transformers import AddedToken, AutoTokenizer from colossalai.logging import get_dist_logger @@ -32,35 +32,25 @@ def main(): parser.add_argument( "--tokenizer_dir", type=str, required=True, default=None, help="A directory containing the tokenizer" ) - parser.add_argument("--data_cache_dir", type=str, default="cache", help="Data cache directory") - parser.add_argument( - "--data_jsonl_output_dir", - type=str, - default="jsonl_output", - help="Output directory of spliced dataset with jsonl format", - ) - parser.add_argument( - "--data_arrow_output_dir", - type=str, - default="arrow_output", - help="Output directory of spliced dataset with arrow format", - ) - parser.add_argument("--max_length", type=int, default=4096, help="Max length of each spliced tokenized sequence") + parser.add_argument("--data_output_dirs", type=str, default="data_output_dirs", help="Data output directory") + parser.add_argument("--max_length", type=int, default=8192, help="Max length of each spliced tokenized sequence") parser.add_argument("--num_spliced_dataset_bins", type=int, default=10, help="Number of spliced dataset bins") + parser.add_argument("--llama_version", type=int, default=3, help="LLaMA version") args = parser.parse_args() if args.num_spliced_dataset_bins >= 100000: raise ValueError("Too many spliced divisions, must be smaller than 100000") - assert not os.path.exists(args.data_cache_dir), f"Find existed data cache dir {args.data_cache_dir}" - assert not os.path.exists( - args.data_jsonl_output_dir - ), f"Find existed jsonl data output dir {args.data_jsonl_output_dir}" - assert not os.path.exists( - args.data_arrow_output_dir - ), f"Find existed arrow data output dir {args.data_arrow_output_dir}" - os.makedirs(args.data_jsonl_output_dir) - os.makedirs(args.data_arrow_output_dir) + args.data_cache_dir = os.path.join(args.data_output_dirs, "cache") + args.data_jsonl_output_dir = os.path.join(args.data_output_dirs, "jsonl") + args.data_arrow_output_dir = os.path.join(args.data_output_dirs, "arrow") + + if not os.path.exists(args.data_cache_dir): + os.makedirs(args.data_cache_dir) + if not os.path.exists(args.data_jsonl_output_dir): + os.makedirs(args.data_jsonl_output_dir) + if not os.path.exists(args.data_arrow_output_dir): + os.makedirs(args.data_arrow_output_dir) # Prepare to all input datasets input_data_paths = [] @@ -83,11 +73,20 @@ def main(): 
train_splits.append(f"train[{start}%:{end}%]") # Prepare to the tokenizer. - tokenizer = LlamaTokenizer.from_pretrained(args.tokenizer_dir) + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_dir) + + # Fix split issue: https://github.com/huggingface/transformers/issues/23833 + if args.llama_version == 2: + tokenizer.add_tokens(AddedToken("", normalized=False, special=True), special_tokens=True) + tokenizer.add_bos_token = False tokenizer.add_eos_token = False if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.unk_token + if tokenizer.unk_token is not None: + tokenizer.pad_token = tokenizer.unk_token + else: + tokenizer.pad_token = tokenizer.eos_token + tokenizer.unk_token = tokenizer.eos_token list_dataset = load_dataset( path="json", diff --git a/applications/Colossal-LLaMA-2/requirements.txt b/applications/Colossal-LLaMA/requirements.txt similarity index 65% rename from applications/Colossal-LLaMA-2/requirements.txt rename to applications/Colossal-LLaMA/requirements.txt index 5cdb8e7f3348..809a942ac398 100644 --- a/applications/Colossal-LLaMA-2/requirements.txt +++ b/applications/Colossal-LLaMA/requirements.txt @@ -1,9 +1,10 @@ -torch<2.0.0, >=1.12.1 -packaging==23.1 -colossalai==0.3.5 +torch==2.1.2 +huggingface-hub +packaging==24.0 +colossalai==0.3.6 autoflake==2.2.1 black==23.9.1 -transformers==4.33.3 +transformers==4.34.1 tensorboard==2.14.0 six==1.16.0 datasets diff --git a/applications/Colossal-LLaMA-2/stream_chat_example.py b/applications/Colossal-LLaMA/stream_chat_example.py similarity index 97% rename from applications/Colossal-LLaMA-2/stream_chat_example.py rename to applications/Colossal-LLaMA/stream_chat_example.py index 4c0d1fe2a35f..9a353b473701 100644 --- a/applications/Colossal-LLaMA-2/stream_chat_example.py +++ b/applications/Colossal-LLaMA/stream_chat_example.py @@ -1,6 +1,6 @@ import argparse -from colossal_llama2.utils.stream_chat_patch import streaming_chat +from colossal_llama.utils.stream_chat_patch import streaming_chat from transformers import AutoModelForCausalLM, AutoTokenizer SYSTEM = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions." 
diff --git a/applications/Colossal-LLaMA-2/train.example.sh b/applications/Colossal-LLaMA/train.example.sh similarity index 100% rename from applications/Colossal-LLaMA-2/train.example.sh rename to applications/Colossal-LLaMA/train.example.sh diff --git a/applications/Colossal-LLaMA-2/train.py b/applications/Colossal-LLaMA/train.py similarity index 96% rename from applications/Colossal-LLaMA-2/train.py rename to applications/Colossal-LLaMA/train.py index d97da61e4dc8..dcd7be9f4e4c 100644 --- a/applications/Colossal-LLaMA-2/train.py +++ b/applications/Colossal-LLaMA/train.py @@ -12,18 +12,18 @@ import torch import torch.distributed as dist -from colossal_llama2.dataset.loader import ( +from colossal_llama.dataset.loader import ( DataCollatorForSupervisedDataset, StatefulDistributedSampler, load_tokenized_dataset, ) -from colossal_llama2.utils.ckpt_io import load_checkpoint, save_checkpoint -from colossal_llama2.utils.flash_attention_patch import replace_with_flash_attention -from colossal_llama2.utils.froze import freeze_non_embeds_parameters -from colossal_llama2.utils.neftune_patch import activate_neftune, deactivate_neftune +from colossal_llama.utils.ckpt_io import load_checkpoint, save_checkpoint +from colossal_llama.utils.flash_attention_patch import replace_with_flash_attention +from colossal_llama.utils.froze import freeze_non_embeds_parameters +from colossal_llama.utils.neftune_patch import activate_neftune, deactivate_neftune from torch.utils.tensorboard import SummaryWriter from tqdm import tqdm -from transformers import LlamaForCausalLM, LlamaTokenizer +from transformers import AutoTokenizer, LlamaForCausalLM import colossalai from colossalai.accelerator import get_accelerator @@ -89,7 +89,7 @@ def main() -> None: parser.add_argument("--accumulation_steps", type=int, default=1, help="Number of accumulation steps") parser.add_argument("--micro_batch_size", type=int, default=2, help="Batch size of each process") parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate") - parser.add_argument("--max_length", type=int, default=4096, help="Model max length") + parser.add_argument("--max_length", type=int, default=8192, help="Model max length") parser.add_argument( "--mixed_precision", type=str, @@ -196,7 +196,7 @@ def main() -> None: # ====================================================== # Initialize Tokenizer, Dataset, Collator and Dataloader # ====================================================== - tokenizer = LlamaTokenizer.from_pretrained(args.pretrained) + tokenizer = AutoTokenizer.from_pretrained(args.pretrained) if args.pad_token == "eos": tokenizer.pad_token = tokenizer.eos_token elif args.pad_token == "unk": diff --git a/applications/Colossal-LLaMA-2/train_sft.example.sh b/applications/Colossal-LLaMA/train_sft.example.sh similarity index 100% rename from applications/Colossal-LLaMA-2/train_sft.example.sh rename to applications/Colossal-LLaMA/train_sft.example.sh diff --git a/applications/Colossal-LLaMA/version.txt b/applications/Colossal-LLaMA/version.txt new file mode 100644 index 000000000000..3eefcb9dd5b3 --- /dev/null +++ b/applications/Colossal-LLaMA/version.txt @@ -0,0 +1 @@ +1.0.0 diff --git a/applications/README.md b/applications/README.md index 120767d5c9ea..e7c23c7e9b7b 100644 --- a/applications/README.md +++ b/applications/README.md @@ -5,7 +5,7 @@ This directory contains the applications that are powered by Colossal-AI. 
The list of applications include: - [X] [Open-Sora](https://github.com/hpcaitech/Open-Sora): Revealing Complete Model Parameters, Training Details, and Everything for Sora-like Video Generation Models -- [X] [Colossal-LLaMA-2](./Colossal-LLaMA-2/): Continual Pre-training of LLaMA-2. +- [X] [Colossal-LLaMA](./Colossal-LLaMA/): Continual Pre-training and Supervisied Fine-tuning of LLaMA2 / LLaMA3. - [X] [ColossalEval](./ColossalEval): Evaluation Pipeline for LLMs. - [X] [ColossalChat](./Chat/README.md): Replication of ChatGPT with RLHF. - [X] [FastFold](https://github.com/hpcaitech/FastFold): Optimizing AlphaFold (Biomedicine) Training and Inference on GPU Clusters. From 7152411f1a95b15e5008cd48634439cfacf18955 Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Tue, 23 Apr 2024 14:12:20 +0800 Subject: [PATCH 04/11] [exampe] update llama example (#5626) * [plugin] support dp inside for hybriad parallel * [example] update llama benchmark * [example] update llama benchmark * [example] update llama readme * [example] update llama readme --- colossalai/booster/plugin/gemini_plugin.py | 1 + .../booster/plugin/hybrid_parallel_plugin.py | 28 +- examples/language/llama2/README.md | 117 +------ examples/language/llama2/attn.py | 1 - examples/language/llama2/benchmark.py | 62 +++- examples/language/llama2/finetune.py | 313 ----------------- examples/language/llama2/pretrain.py | 328 ------------------ examples/language/llama2/requirements.txt | 5 +- 8 files changed, 72 insertions(+), 783 deletions(-) delete mode 120000 examples/language/llama2/attn.py delete mode 100644 examples/language/llama2/finetune.py delete mode 100644 examples/language/llama2/pretrain.py diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index 442ac4a8da06..a67ca18a3456 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -424,6 +424,7 @@ def __init__( ) self.extra_dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS) if self.extra_dp_size > 1 else None self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS) if self.tp_size > 1 else None + self.dp_size = self.zero_size * self.extra_dp_size self.shard_config = ShardConfig( tensor_parallel_process_group=self.tp_group, diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 8d12eb80621d..95fb2def10a4 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -34,7 +34,6 @@ from .pp_plugin_base import PipelinePluginBase -DP_AXIS, PP_AXIS, TP_AXIS, SP_AXIS = 0, 1, 2, 3 SUPPORT_SP_MODE = ["split_gather", "ring", "all_to_all"] PRECISION_TORCH_TYPE = {"fp16": torch.float16, "fp32": torch.float32, "bf16": torch.bfloat16} @@ -987,6 +986,7 @@ def __init__( gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None, enable_metadata_cache: bool = True, make_vocab_size_divisible_by: int = 64, + dp_outside: bool = True, ) -> None: super().__init__() assert ( @@ -1034,7 +1034,12 @@ def __init__( self.enable_flash_attention = enable_flash_attention self.enable_jit_fused = enable_jit_fused self.enable_sequence_parallelism = enable_sequence_parallelism - self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size) + if dp_outside: + self.dp_axis, self.pp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3 + self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size) + else: + self.pp_axis, 
self.dp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3 + self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size, self.sp_size) self.stage_manager = None self.schedule = None self.custom_policy = custom_policy @@ -1048,7 +1053,7 @@ def __init__( assert self.zero_stage <= 1, "zero stage must be 0 or 1 when using pipeline parallelism" self.stage_manager = PipelineStageManager( self.pg_mesh, - pipeline_axis=PP_AXIS, + pipeline_axis=self.pp_axis, enable_interleave=pp_style == "interleaved", num_model_chunks=num_model_chunks, ) @@ -1072,13 +1077,13 @@ def __init__( else: raise NotImplementedError() - self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS) - self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS) - self.pp_group = self.pg_mesh.get_group_along_axis(PP_AXIS) + self.tp_group = self.pg_mesh.get_group_along_axis(self.tp_axis) + self.dp_group = self.pg_mesh.get_group_along_axis(self.dp_axis) + self.pp_group = self.pg_mesh.get_group_along_axis(self.pp_axis) if self.enable_sequence_parallelism and self.sequence_parallelism_mode in ["split_gather", "ring"]: - self.sp_group = self.pg_mesh.get_group_along_axis(TP_AXIS) + self.sp_group = self.pg_mesh.get_group_along_axis(self.tp_axis) else: - self.sp_group = self.pg_mesh.get_group_along_axis(SP_AXIS) + self.sp_group = self.pg_mesh.get_group_along_axis(self.sp_axis) self.shard_config = ShardConfig( tensor_parallel_process_group=self.tp_group, @@ -1169,7 +1174,7 @@ def configure( and self.sequence_parallelism_mode == "all_to_all" ) if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all": - dp_group = self.pg_mesh.create_group_along_axis([DP_AXIS, SP_AXIS]) + dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis]) else: dp_group = self.dp_group model = HybridParallelModule( @@ -1317,7 +1322,10 @@ def prepare_dataloader( _kwargs = kwargs.copy() distributed_sampler_cls = distributed_sampler_cls or DistributedSampler sampler = distributed_sampler_cls( - dataset, num_replicas=self.pg_mesh.size(DP_AXIS), rank=self.pg_mesh.coordinate(DP_AXIS), shuffle=shuffle + dataset, + num_replicas=self.pg_mesh.size(self.dp_axis), + rank=self.pg_mesh.coordinate(self.dp_axis), + shuffle=shuffle, ) # Deterministic dataloader diff --git a/examples/language/llama2/README.md b/examples/language/llama2/README.md index 068f15cbb041..11b2ee511a6e 100644 --- a/examples/language/llama2/README.md +++ b/examples/language/llama2/README.md @@ -1,4 +1,4 @@ -# Pretraining LLaMA-1/2: best practices for building LLaMA-1/2-like base models +# Pretraining LLaMA-1/2/3: best practices for building LLaMA-1/2/3-like base models ### LLaMA2
@@ -16,38 +16,10 @@ - 65-billion-parameter large model pretraining accelerated by 38% [[blog]](https://www.hpc-ai.tech/blog/large-model-pretraining) -## Dataset - -Different from the original LLaMA, we use [RedPajama](https://www.together.xyz/blog/redpajama) dataset, which is a reproduction of the LLaMA training dataset containing over 1.2 trillion tokens. The full dataset is ~5TB unzipped on disk and ~3TB to download compressed. - -A smaller, more consumable random sample can be downloaded through [Hugging Face](https://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T). If you just want to try out the pretraining script, you can use a 1B-token sample subset of RedPajama, which is available at [Hugging Face](https://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T-Sample). - -RedPajama-Data-1T consists of seven data slices: - -| | RedPajama | LLaMA | -|---------------|--------------|---------------| -| CommonCrawl | 878 billion | 852 billion | -| C4 | 175 billion | 190 billion | -| Github | 59 billion | 100 billion | -| Books | 26 billion | 25 billion | -| ArXiv | 28 billion | 33 billion | -| Wikipedia | 24 billion | 25 billion | -| StackExchange | 20 billion | 27 billion | -| Total | 1.2 trillion | 1.25 trillion | - -## Training - -We follow the hyperparameter settings from the original LLaMA paper. We use AdamW with $beta1=0.9$ and $beta2=0.95$. We use a cosine learning rate schedule, such that the final learning rate is equal to 10% of the maximal learning rate. We use a weight decay of 0.1 and gradient clipping of 1.0. We use 2,000 warmup steps. - -| params | learning rate | batch size | -|--------|---------------|------------| -| 6.7B | 3.0e-4 | 4M | -| 13.0B | 3.0e-4 | 4M | -| 32.5B | 1.5e-4 | 4M | -| 65.2B | 1.5e-4 | 4M | - ## Usage +> ⚠ This example only has benchmarking script. For training/finetuning, please refer to the [applications/Colossal-LLaMA](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Colossal-LLaMA). + ### 1. Installation Please install the latest ColossalAI from source. @@ -62,52 +34,6 @@ Then install other dependencies. pip install -r requirements.txt ``` -Additionally, we recommend you to use torch 1.13.1. We've tested our code on torch 1.13.1 and found it's compatible with our code and flash attention. - -### 2. Download the dataset - -The dataset can be automatically downloaded by using `huggingface/datasets`. You can specify the dataset path by `-d` or `--dataset`. The default dataset is `togethercomputer/RedPajama-Data-1T-Sample`. - -### 3. Command line arguments - -Yon can use colossalai run to launch multi-nodes training: -```bash -colossalai run --nproc_per_node YOUR_GPU_PER_NODE --hostfile YOUR_HOST_FILE \ -pretrain.py --OTHER_CONFIGURATIONS -``` - -Here is a sample hostfile: - -```text -hostname1 -hostname2 -hostname3 -hostname4 -``` - -Make sure master node can access all nodes (including itself) by ssh without password. - -Here is details about CLI arguments: - -- Model configuration: `-c`, `--config`. `7b`, `13b`, `30b` and `65b` are supported for LLaMA-1, `7b`, `13b`, and `70b` are supported for LLaMA-2. -- Booster plugin: `-p`, `--plugin`. `gemini`, `gemini_auto`, `zero2`, `hybrid_parallel` and `zero2_cpu` are supported. For more details, please refer to [Booster plugins](https://colossalai.org/docs/basics/booster_plugins). -- Dataset path: `-d`, `--dataset`. The default dataset is `togethercomputer/RedPajama-Data-1T-Sample`. It support any dataset from `datasets` with the same data format as RedPajama. 
-- Number of epochs: `-e`, `--num_epochs`. The default value is 1. -- Local batch size: `-b`, `--batch_size`. Batch size per GPU. The default value is 2. -- Learning rate: `--lr`. The default value is 3e-4. -- Weight decay: `-w`, `--weight_decay`. The default value is 0.1. -- Warmup steps: `-s`, `--warmup_steps`. The default value is 2000. -- Gradient checkpointing: `-g`, `--gradient_checkpoint`. The default value is `False`. This saves memory at the cost of speed. You'd better enable this option when training with a large batch size. -- Max length: `-l`, `--max_length`. The default value is 4096. -- Mixed precision: `-x`, `--mixed_precision`. The default value is "fp16". "fp16" and "bf16" are supported. -- Save interval: `-i`, `--save_interval`. The interval (steps) of saving checkpoints. The default value is 1000. -- Checkpoint directory: `-o`, `--save_dir`. The directory path to save checkpoints. The default value is `checkpoint`. -- Checkpoint to load: `-f`, `--load`. The checkpoint path to load. The default value is `None`. -- Gradient clipping: `--gradient_clipping`. The default value is 1.0. -- Tensorboard log directory: `-t`, `--tensorboard_dir`. The directory path to save tensorboard logs. The default value is `tb_logs`. -- Flash attention: `-a`, `--flash_attention`. If you want to use flash attention, you must install `flash-attn`. The default value is `False`. This is helpful to accelerate training while saving memory. We recommend you always use flash attention. - - ### 4. Shell Script Examples For your convenience, we provide some shell scripts to run benchmark with various configurations. @@ -193,40 +119,3 @@ If you run the above command successfully, you will get the following results: year={2023} } ``` - - -# Fine-tune Llama2 - -We also provide a example to fine-tune llama2 in `finetune.py`, - -Make sure master node can access all nodes (including itself) by ssh without password. - -Here is details about CLI arguments: - -- Pretrained checkpoint path: `--model_path`, the path of your model checkpoint, it can be your local directory or a Hugging Face tag. -- Booster plugin: `-p`, `--plugin`. `gemini`, `gemini_auto`, `zero2`, `hybrid_parallel` and `zero2_cpu` are supported. For more details, please refer to [Booster plugins](https://colossalai.org/docs/basics/booster_plugins). -- Dataset path: `-d`, `--dataset`. The default dataset is `yizhongw/self_instruct`. It support any dataset from `datasets` with the same data format as `yizhongw/self_instruct`. -- task name: `--task_name`, the task to fine-tune, it's also related to the target of loading dataset, The default value is `super_natural_instructions`. -- Number of epochs: `-e`, `--num_epochs`. The default value is 1. -- Local batch size: `-b`, `--batch_size`. Batch size per GPU. The default value is 2. -- Learning rate: `--lr`. The default value is 3e-4. -- Weight decay: `-w`, `--weight_decay`. The default value is 0.1. -- Gradient checkpointing: `-g`, `--gradient_checkpoint`. The default value is `False`. This saves memory at the cost of speed. You'd better enable this option when training with a large batch size. -- Max length: `-l`, `--max_length`. The default value is 4096. -- Mixed precision: `-x`, `--mixed_precision`. The default value is "fp16". "fp16" and "bf16" are supported. -- Save interval: `-i`, `--save_interval`. The interval (steps) of saving checkpoints. The default value is 1000. -- Checkpoint directory: `-o`, `--save_dir`. The directory path to save checkpoints. The default value is `checkpoint`. 
-- Checkpoint to load: `-f`, `--load`. The checkpoint path to load. The default value is `None`. -- Gradient clipping: `--gradient_clipping`. The default value is 1.0. -- Tensorboard log directory: `-t`, `--tensorboard_dir`. The directory path to save tensorboard logs. The default value is `tb_logs`. -- Flash attention: `-a`, `--flash_attention`. If you want to use flash attention, you must install `flash-attn`. The default value is `False`. This is helpful to accelerate training while saving memory. We recommend you always use flash attention. - - -```shell -torchrun --standalone --nproc_per_node 8 finetune.py \ - --plugin "hybrid_parallel" \ - --dataset "yizhongw/self_instruct" \ - --model_path "/path/llama" \ - --task_name "super_natural_instructions" \ - --save_dir "/path/output" -``` diff --git a/examples/language/llama2/attn.py b/examples/language/llama2/attn.py deleted file mode 120000 index 4e95c7bfa519..000000000000 --- a/examples/language/llama2/attn.py +++ /dev/null @@ -1 +0,0 @@ -../../../applications/Colossal-LLaMA-2/colossal_llama2/utils/flash_attention_patch.py \ No newline at end of file diff --git a/examples/language/llama2/benchmark.py b/examples/language/llama2/benchmark.py index 832465490907..ff94891f50ec 100644 --- a/examples/language/llama2/benchmark.py +++ b/examples/language/llama2/benchmark.py @@ -3,14 +3,13 @@ from contextlib import nullcontext import torch -from attn import replace_with_flash_attention from data_utils import RandomDataset from model_utils import format_numel_str, get_model_numel from performance_evaluator import PerformanceEvaluator from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, MixedPrecision from tqdm import tqdm +from transformers import AutoConfig, AutoModelForCausalLM from transformers.models.llama.configuration_llama import LlamaConfig -from transformers.models.llama.modeling_llama import LlamaForCausalLM import colossalai from colossalai.accelerator import get_accelerator @@ -19,6 +18,7 @@ from colossalai.cluster import DistCoordinator from colossalai.lazy import LazyInitContext from colossalai.nn.optimizer import HybridAdam +from colossalai.shardformer import PipelineGradientCheckpointConfig from examples.language.data_utils import RandomDataset from examples.language.model_utils import format_numel_str, get_model_numel from examples.language.performance_evaluator import PerformanceEvaluator @@ -78,6 +78,7 @@ def main(): parser.add_argument("--pp", type=int, default=1, help="Pipeline parallel size") parser.add_argument("--mbs", type=int, default=1, help="Micro batch size of pipeline parallel") parser.add_argument("--zero", type=int, default=0, help="Zero Stage when hybrid plugin is enabled") + parser.add_argument("--custom-ckpt", action="store_true", help="Customize checkpoint", default=False) args = parser.parse_args() colossalai.launch_from_torch({}) @@ -86,6 +87,19 @@ def main(): def empty_init(): pass + # ckpt config for LLaMA3-70B on 64 H100 GPUs + ckpt_config = ( + PipelineGradientCheckpointConfig( + num_stages=args.pp, + num_model_chunks=1, + num_model_layers=80, + num_layers_per_stage=[19, 20, 20, 21], + num_ckpt_layers_per_stage=[19, 19, 19, 13], + ) + if args.custom_ckpt + else None + ) + # ============================== # Initialize Booster # ============================== @@ -98,6 +112,8 @@ def empty_init(): offload_param_frac=args.offload_param_frac, tp_size=args.tp, extra_dp_size=args.extra_dp, + enable_fused_normalization=torch.cuda.is_available(), + enable_flash_attention=args.xformers, ) elif 
args.plugin == "gemini_auto": plugin = GeminiPlugin( @@ -106,26 +122,34 @@ def empty_init(): warmup_non_model_data_ratio=args.warmup_ratio, tp_size=args.tp, extra_dp_size=args.extra_dp, + enable_fused_normalization=torch.cuda.is_available(), + enable_flash_attention=args.xformers, ) elif args.plugin == "fsdp": if use_empty_init: plugin = TorchFSDPPlugin( mixed_precision=MixedPrecision( - param_dtype=torch.float16, reduce_dtype=torch.float16, buffer_dtype=torch.float16 + param_dtype=torch.float16, + reduce_dtype=torch.float16, + buffer_dtype=torch.float16, ), param_init_fn=empty_init(), ) else: plugin = TorchFSDPPlugin( mixed_precision=MixedPrecision( - param_dtype=torch.float16, reduce_dtype=torch.float16, buffer_dtype=torch.float16 + param_dtype=torch.float16, + reduce_dtype=torch.float16, + buffer_dtype=torch.float16, ) ) elif args.plugin == "fsdp_cpu": if use_empty_init: plugin = TorchFSDPPlugin( mixed_precision=MixedPrecision( - param_dtype=torch.float16, reduce_dtype=torch.float16, buffer_dtype=torch.float16 + param_dtype=torch.float16, + reduce_dtype=torch.float16, + buffer_dtype=torch.float16, ), cpu_offload=CPUOffload(offload_params=True), param_init_fn=empty_init(), @@ -133,7 +157,9 @@ def empty_init(): else: plugin = TorchFSDPPlugin( mixed_precision=MixedPrecision( - param_dtype=torch.float16, reduce_dtype=torch.float16, buffer_dtype=torch.float16 + param_dtype=torch.float16, + reduce_dtype=torch.float16, + buffer_dtype=torch.float16, ), cpu_offload=CPUOffload(offload_params=True), ) @@ -141,12 +167,13 @@ def empty_init(): plugin = HybridParallelPlugin( tp_size=args.tp, pp_size=args.pp, - pp_style="interleaved", zero_stage=args.zero, - num_model_chunks=2, enable_fused_normalization=torch.cuda.is_available(), + enable_flash_attention=args.xformers, microbatch_size=args.mbs, precision="bf16", + dp_outside=False, + gradient_checkpoint_config=ckpt_config, ) elif args.plugin == "3d_cpu": plugin = HybridParallelPlugin( @@ -155,6 +182,7 @@ def empty_init(): zero_stage=args.zero, cpu_offload=True, enable_fused_normalization=torch.cuda.is_available(), + enable_flash_attention=args.xformers, microbatch_size=args.mbs, initial_scale=2**8, precision="bf16", @@ -167,9 +195,12 @@ def empty_init(): # ============================== # Initialize Dataset and Dataloader # ============================== - dp_size = plugin.dp_size if isinstance(plugin, HybridParallelPlugin) else coordinator.world_size + dp_size = getattr(plugin, "dp_size", coordinator.world_size) - config = MODEL_CONFIGS[args.config] + if args.config in MODEL_CONFIGS: + config = MODEL_CONFIGS[args.config] + else: + config = AutoConfig.from_pretrained(args.config, trust_remote_code=True) dataset = RandomDataset( num_samples=args.batch_size * args.num_steps * dp_size, max_length=args.max_length, vocab_size=config.vocab_size ) @@ -184,14 +215,17 @@ def empty_init(): else nullcontext() ) + init_kwargs = {} + if config.model_type == "chatglm": + init_kwargs["empty_init"] = False + with init_ctx: - model = LlamaForCausalLM(config) + model = AutoModelForCausalLM.from_config(config, trust_remote_code=True, **init_kwargs) if args.grad_checkpoint: model.gradient_checkpointing_enable() - - if args.xformers: - replace_with_flash_attention(model) + if config.model_type == "chatglm": + model.transformer.encoder.gradient_checkpointing = True model_numel = get_model_numel(model) coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}") diff --git a/examples/language/llama2/finetune.py b/examples/language/llama2/finetune.py deleted 
file mode 100644 index 69b4ebe42bf7..000000000000 --- a/examples/language/llama2/finetune.py +++ /dev/null @@ -1,313 +0,0 @@ -import argparse -import math -import os -import resource -from contextlib import nullcontext -from functools import partial -from typing import Optional, Tuple - -import torch -import torch.distributed as dist -import torch.nn as nn -from attn import replace_with_flash_attention -from data_utils import load_json, prepare_dataloader, save_json -from datasets import load_dataset -from torch.optim import Optimizer -from torch.optim.lr_scheduler import _LRScheduler -from torch.utils.tensorboard import SummaryWriter -from tqdm import tqdm -from transformers.models.llama.configuration_llama import LlamaConfig -from transformers.models.llama.modeling_llama import LlamaForCausalLM -from transformers.models.llama.tokenization_llama import LlamaTokenizer - -import colossalai -from colossalai.accelerator import get_accelerator -from colossalai.booster import Booster -from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin -from colossalai.cluster import DistCoordinator -from colossalai.lazy import LazyInitContext -from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR -from colossalai.nn.optimizer import HybridAdam - - -def get_model_numel(model: nn.Module) -> int: - return sum(p.numel() for p in model.parameters()) - - -def format_numel_str(numel: int) -> str: - B = 1024**3 - M = 1024**2 - K = 1024 - if numel >= B: - return f"{numel / B:.2f} B" - elif numel >= M: - return f"{numel / M:.2f} M" - elif numel >= K: - return f"{numel / K:.2f} K" - else: - return f"{numel}" - - -def tokenize_batch_for_finetune(batch, tokenizer: Optional[LlamaTokenizer] = None, max_length: int = 2048): - texts = [sample["prompt"] + sample["completion"] for sample in batch] - data = tokenizer(texts, return_tensors="pt", padding="max_length", truncation=True, max_length=max_length) - data = {k: v.cuda() for k, v in data.items()} - data["labels"] = data["input_ids"].clone() - return data - - -def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor: - dist.all_reduce(tensor, op=dist.ReduceOp.SUM) - tensor = tensor.data - tensor.div_(dist.get_world_size()) - return tensor - - -def save( - booster: Booster, - model: nn.Module, - optimizer: Optimizer, - lr_scheduler: _LRScheduler, - epoch: int, - step: int, - batch_size: int, - coordinator: DistCoordinator, - save_dir: str, -): - save_dir = os.path.join(save_dir, f"epoch{epoch}-step{step}") - os.makedirs(os.path.join(save_dir, "model"), exist_ok=True) - - booster.save_model(model, os.path.join(save_dir, "model"), shard=True) - booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True) - booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler")) - running_states = { - "epoch": epoch, - "step": step, - "sample_start_index": step * batch_size, - } - if coordinator.is_master(): - save_json(running_states, os.path.join(save_dir, "running_states.json")) - - -def load( - booster: Booster, model: nn.Module, optimizer: Optimizer, lr_scheduler: _LRScheduler, load_dir: str -) -> Tuple[int, int, int]: - booster.load_model(model, os.path.join(load_dir, "model")) - booster.load_optimizer(optimizer, os.path.join(load_dir, "optimizer")) - booster.load_lr_scheduler(lr_scheduler, os.path.join(load_dir, "lr_scheduler")) - running_states = load_json(os.path.join(load_dir, "running_states.json")) - return running_states["epoch"], running_states["step"], 
running_states["sample_start_index"] - - -def _criterion(outputs, inputs): - return outputs.loss - - -def main(): - # ============================== - # Parse Arguments - # ============================== - parser = argparse.ArgumentParser() - parser.add_argument("--model_path", type=str, help="pretrained checkpoint path, used with mode==finetune") - parser.add_argument( - "-p", - "--plugin", - choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "hybrid_parallel"], - default="gemini", - help="Choose which plugin to use", - ) - parser.add_argument("-d", "--dataset", type=str, default="yizhongw/self_instruct", help="Data set path") - parser.add_argument("--task_name", type=str, default="super_natural_instructions", help="task to run") - parser.add_argument("-e", "--num_epochs", type=int, default=1, help="Number of epochs") - parser.add_argument("-b", "--batch_size", type=int, default=2, help="Local batch size") - parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate") - parser.add_argument("-w", "--weigth_decay", type=float, default=0.1, help="Weight decay") - parser.add_argument("-g", "--grad_checkpoint", action="store_true", help="Use gradient checkpointing") - parser.add_argument("-l", "--max_length", type=int, default=4096, help="Max sequence length") - parser.add_argument("-x", "--mixed_precision", default="fp16", choices=["fp16", "bf16"], help="Mixed precision") - parser.add_argument("-i", "--save_interval", type=int, default=1000, help="Save interval") - parser.add_argument("-o", "--save_dir", type=str, default="checkpoint", help="Checkpoint directory") - parser.add_argument("-f", "--load", type=str, default=None, help="Load checkpoint") - parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping") - parser.add_argument("-t", "--tensorboard_dir", type=str, default="tb_logs", help="Tensorboard directory") - parser.add_argument("-a", "--flash_attention", action="store_true", help="Use Flash Attention") - args = parser.parse_args() - - # ============================== - # Initialize Distributed Training - # ============================== - colossalai.launch_from_torch({}) - coordinator = DistCoordinator() - - # ============================== - # Initialize Booster - # ============================== - if args.plugin == "gemini": - plugin = GeminiPlugin(precision=args.mixed_precision, initial_scale=2**16, max_norm=args.grad_clip) - elif args.plugin == "gemini_auto": - plugin = GeminiPlugin( - precision=args.mixed_precision, placement_policy="auto", initial_scale=2**16, max_norm=args.grad_clip - ) - elif args.plugin == "zero2": - plugin = LowLevelZeroPlugin( - stage=2, precision=args.mixed_precision, initial_scale=2**16, max_norm=args.grad_clip - ) - elif args.plugin == "zero2_cpu": - plugin = LowLevelZeroPlugin( - stage=2, precision=args.mixed_precision, initial_scale=2**16, cpu_offload=True, max_norm=args.grad_clip - ) - elif args.plugin == "hybrid_parallel": - # modify the param accordingly, default configuration is for llama2-7b - plugin = HybridParallelPlugin( - tp_size=4, - pp_size=2, - num_microbatches=None, - microbatch_size=1, - enable_jit_fused=False, - zero_stage=0, - precision="fp32", - initial_scale=1, - ) - else: - raise ValueError(f"Unknown plugin {args.plugin}") - - booster = Booster(plugin=plugin) - - use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1 - is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage() - print_flag = (not use_pipeline and 
coordinator.is_master()) or (use_pipeline and is_pp_last_stage) - - # ============================== - # Initialize Tensorboard - # ============================== - if print_flag: - os.makedirs(args.tensorboard_dir, exist_ok=True) - writer = SummaryWriter(args.tensorboard_dir) - - # ============================== - # Initialize Model, Optimizer and LR Scheduler - # ============================== - - config = LlamaConfig.from_pretrained(args.model_path) - # use lazy init when using GeminiPlugin - init_ctx = ( - LazyInitContext(default_device=get_accelerator().get_current_device()) - if isinstance(plugin, GeminiPlugin) - else nullcontext() - ) - - with init_ctx: - model = LlamaForCausalLM(config) - - # ============================== - # Initialize Tokenizer, Dataset and Dataloader - # ============================== - tokenizer = LlamaTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer") - # follows fast chat: https://github.com/lm-sys/FastChat/blob/main/fastchat/train/train.py#L257 - tokenizer.pad_token = tokenizer.unk_token - - dataset = load_dataset(args.dataset, args.task_name) - train_ds = dataset["train"] - dataloader = prepare_dataloader( - train_ds, - batch_size=args.batch_size, - shuffle=True, - drop_last=True, - collate_fn=partial(tokenize_batch_for_finetune, tokenizer=tokenizer, max_length=args.max_length), - ) - - if args.grad_checkpoint: - model.gradient_checkpointing_enable() - if args.flash_attention: - replace_with_flash_attention(model) - - model_numel = get_model_numel(model) - coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}") - - optimizer = HybridAdam(model.parameters(), lr=args.lr, betas=(0.9, 0.95), weight_decay=args.weigth_decay) - total_step = args.num_epochs * len(dataloader) - lr_scheduler = CosineAnnealingWarmupLR( - optimizer, total_steps=total_step, warmup_steps=math.ceil(total_step * 0.03), eta_min=0.1 * args.lr - ) - default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16 - torch.set_default_dtype(default_dtype) - model, optimizer, _, dataloader, lr_scheduler = booster.boost( - model, optimizer, dataloader=dataloader, lr_scheduler=lr_scheduler - ) - torch.set_default_dtype(torch.float) - - booster.load_model(model, args.model_path) - - coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB") - coordinator.print_on_master( - f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB" - ) - - # load checkpoint if specified - start_epoch = 0 - start_step = 0 - sampler_start_idx = 0 - if args.load is not None: - coordinator.print_on_master("Loading checkpoint") - start_epoch, start_step, sampler_start_idx = load(booster, model, optimizer, lr_scheduler, args.load) - coordinator.print_on_master(f"Loaded checkpoint {args.load} at epoch {start_epoch} step {start_step}") - - num_steps_per_epoch = len(dataloader) - - # if resume training, set the sampler start index to the correct value - dataloader.sampler.set_start_index(sampler_start_idx) - for epoch in range(start_epoch, args.num_epochs): - dataloader.sampler.set_epoch(epoch) - step_nums = num_steps_per_epoch - start_step - dataloader_iter = iter(dataloader) - - with tqdm( - range(step_nums), - desc=f"Epoch {epoch}", - disable=not print_flag, - total=num_steps_per_epoch, - initial=start_step, - ) as pbar: - for step in pbar: - if use_pipeline: - outputs = booster.execute_pipeline(dataloader_iter, model, _criterion, optimizer, return_loss=True) - loss = 
outputs["loss"] - else: - batch = next(dataloader_iter) - outputs = model(**batch) - loss = outputs[0] - booster.backward(loss, optimizer) - - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad() - - if not use_pipeline: - all_reduce_mean(loss) - if print_flag: - pbar.set_postfix({"loss": loss.item()}) - writer.add_scalar("loss", loss.item(), epoch * num_steps_per_epoch + step) - - if args.save_interval > 0 and (step + 1) % args.save_interval == 0: - coordinator.print_on_master(f"Saving checkpoint") - save( - booster, - model, - optimizer, - lr_scheduler, - epoch, - step + 1, - args.batch_size, - coordinator, - args.save_dir, - ) - coordinator.print_on_master(f"Saved checkpoint at epoch {epoch} step {step + 1}") - # the continue epochs are not resumed, so we need to reset the sampler start index and start step - dataloader.sampler.set_start_index(0) - start_step = 0 - - coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB") - - -if __name__ == "__main__": - main() diff --git a/examples/language/llama2/pretrain.py b/examples/language/llama2/pretrain.py deleted file mode 100644 index 970cd5290f9f..000000000000 --- a/examples/language/llama2/pretrain.py +++ /dev/null @@ -1,328 +0,0 @@ -import argparse -import os -import resource -from contextlib import nullcontext -from functools import partial -from typing import Optional, Tuple - -import torch -import torch.distributed as dist -import torch.nn as nn -from attn import replace_with_flash_attention -from data_utils import load_json, prepare_dataloader, save_json -from datasets import load_dataset -from torch.optim import Optimizer -from torch.optim.lr_scheduler import _LRScheduler -from torch.utils.tensorboard import SummaryWriter -from tqdm import tqdm -from transformers.models.llama.configuration_llama import LlamaConfig -from transformers.models.llama.modeling_llama import LlamaForCausalLM -from transformers.models.llama.tokenization_llama import LlamaTokenizer - -import colossalai -from colossalai.accelerator import get_accelerator -from colossalai.booster import Booster -from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin -from colossalai.cluster import DistCoordinator -from colossalai.lazy import LazyInitContext -from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR -from colossalai.nn.optimizer import HybridAdam - -MODEL_CONFIGS = { - "7b": LlamaConfig(max_position_embeddings=4096), - "13b": LlamaConfig( - hidden_size=5120, - intermediate_size=13824, - num_hidden_layers=40, - num_attention_heads=40, - max_position_embeddings=4096, - ), - "70b": LlamaConfig( - hidden_size=8192, - intermediate_size=28672, - num_hidden_layers=80, - num_attention_heads=64, - max_position_embeddings=4096, - num_key_value_heads=8, - ), -} - - -def get_model_numel(model: nn.Module) -> int: - return sum(p.numel() for p in model.parameters()) - - -def format_numel_str(numel: int) -> str: - B = 1024**3 - M = 1024**2 - K = 1024 - if numel >= B: - return f"{numel / B:.2f} B" - elif numel >= M: - return f"{numel / M:.2f} M" - elif numel >= K: - return f"{numel / K:.2f} K" - else: - return f"{numel}" - - -def tokenize_batch_for_pretrain(batch, tokenizer: Optional[LlamaTokenizer] = None, max_length: int = 2048): - texts = [sample["text"] for sample in batch] - data = tokenizer(texts, return_tensors="pt", padding="max_length", truncation=True, max_length=max_length) - data = {k: v.cuda() for k, v in data.items()} - data["labels"] = data["input_ids"].clone() - 
return data - - -def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor: - dist.all_reduce(tensor, op=dist.ReduceOp.SUM) - tensor = tensor.data - tensor.div_(dist.get_world_size()) - return tensor - - -def save( - booster: Booster, - model: nn.Module, - optimizer: Optimizer, - lr_scheduler: _LRScheduler, - epoch: int, - step: int, - batch_size: int, - coordinator: DistCoordinator, - save_dir: str, -): - save_dir = os.path.join(save_dir, f"epoch{epoch}-step{step}") - os.makedirs(os.path.join(save_dir, "model"), exist_ok=True) - - booster.save_model(model, os.path.join(save_dir, "model"), shard=True) - booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True) - booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler")) - running_states = { - "epoch": epoch, - "step": step, - "sample_start_index": step * batch_size, - } - if coordinator.is_master(): - save_json(running_states, os.path.join(save_dir, "running_states.json")) - - -def load( - booster: Booster, model: nn.Module, optimizer: Optimizer, lr_scheduler: _LRScheduler, load_dir: str -) -> Tuple[int, int, int]: - booster.load_model(model, os.path.join(load_dir, "model")) - booster.load_optimizer(optimizer, os.path.join(load_dir, "optimizer")) - booster.load_lr_scheduler(lr_scheduler, os.path.join(load_dir, "lr_scheduler")) - running_states = load_json(os.path.join(load_dir, "running_states.json")) - return running_states["epoch"], running_states["step"], running_states["sample_start_index"] - - -def _criterion(outputs, inputs): - return outputs.loss - - -def main(): - # ============================== - # Parse Arguments - # ============================== - parser = argparse.ArgumentParser() - parser.add_argument("-c", "--config", type=str, default="7b", help="Model configuration") - parser.add_argument( - "-p", - "--plugin", - choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "hybrid_parallel"], - default="gemini", - help="Choose which plugin to use", - ) - parser.add_argument( - "-d", "--dataset", type=str, default="togethercomputer/RedPajama-Data-1T-Sample", help="Data set path" - ) - parser.add_argument("-e", "--num_epochs", type=int, default=1, help="Number of epochs") - parser.add_argument("-b", "--batch_size", type=int, default=2, help="Local batch size") - parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate") - parser.add_argument("-w", "--weigth_decay", type=float, default=0.1, help="Weight decay") - parser.add_argument("-s", "--warmup_steps", type=int, default=2000, help="Warmup steps") - parser.add_argument("-g", "--grad_checkpoint", action="store_true", help="Use gradient checkpointing") - parser.add_argument("-l", "--max_length", type=int, default=4096, help="Max sequence length") - parser.add_argument("-x", "--mixed_precision", default="fp16", choices=["fp16", "bf16"], help="Mixed precision") - parser.add_argument("-i", "--save_interval", type=int, default=1000, help="Save interval") - parser.add_argument("-o", "--save_dir", type=str, default="checkpoint", help="Checkpoint directory") - parser.add_argument("-f", "--load", type=str, default=None, help="Load checkpoint") - parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping") - parser.add_argument("-t", "--tensorboard_dir", type=str, default="tb_logs", help="Tensorboard directory") - parser.add_argument("-a", "--flash_attention", action="store_true", help="Use Flash Attention") - args = parser.parse_args() - - # ============================== - # Initialize Distributed 
Training - # ============================== - colossalai.launch_from_torch({}) - coordinator = DistCoordinator() - - # ============================== - # Initialize Booster - # ============================== - if args.plugin == "gemini": - plugin = GeminiPlugin(precision=args.mixed_precision, initial_scale=2**16, max_norm=args.grad_clip) - elif args.plugin == "gemini_auto": - plugin = GeminiPlugin( - precision=args.mixed_precision, placement_policy="auto", initial_scale=2**16, max_norm=args.grad_clip - ) - elif args.plugin == "zero2": - plugin = LowLevelZeroPlugin( - stage=2, precision=args.mixed_precision, initial_scale=2**16, max_norm=args.grad_clip - ) - elif args.plugin == "zero2_cpu": - plugin = LowLevelZeroPlugin( - stage=2, precision=args.mixed_precision, initial_scale=2**16, cpu_offload=True, max_norm=args.grad_clip - ) - elif args.plugin == "hybrid_parallel": - # modify the param accordingly, default configuration is for llama2-7b - plugin = HybridParallelPlugin( - tp_size=4, - pp_size=2, - num_microbatches=None, - microbatch_size=1, - enable_jit_fused=False, - zero_stage=0, - precision=args.mixed_precision, - initial_scale=1, - ) - else: - raise ValueError(f"Unknown plugin {args.plugin}") - - booster = Booster(plugin=plugin) - - use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1 - is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage() - print_flag = (not use_pipeline and coordinator.is_master()) or (use_pipeline and is_pp_last_stage) - - # ============================== - # Initialize Tensorboard - # ============================== - if print_flag: - os.makedirs(args.tensorboard_dir, exist_ok=True) - writer = SummaryWriter(args.tensorboard_dir) - - # ============================== - # Initialize Tokenizer, Dataset and Dataloader - # ============================== - tokenizer = LlamaTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer") - # follows fast chat: https://github.com/lm-sys/FastChat/blob/main/fastchat/train/train.py#L257 - tokenizer.pad_token = tokenizer.unk_token - - dataset = load_dataset(args.dataset) - train_ds = dataset["train"] - dataloader = prepare_dataloader( - train_ds, - batch_size=args.batch_size, - shuffle=True, - drop_last=True, - collate_fn=partial(tokenize_batch_for_pretrain, tokenizer=tokenizer, max_length=args.max_length), - ) - - # ============================== - # Initialize Model, Optimizer and LR Scheduler - # ============================== - config = MODEL_CONFIGS[args.config] - # use lazy init when using GeminiPlugin - init_ctx = ( - LazyInitContext(default_device=get_accelerator().get_current_device()) - if isinstance(plugin, GeminiPlugin) - else nullcontext() - ) - - with init_ctx: - model = LlamaForCausalLM(config) - - if args.grad_checkpoint: - model.gradient_checkpointing_enable() - if args.flash_attention: - replace_with_flash_attention(model) - - model_numel = get_model_numel(model) - coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}") - - optimizer = HybridAdam(model.parameters(), lr=args.lr, betas=(0.9, 0.95), weight_decay=args.weigth_decay) - lr_scheduler = CosineAnnealingWarmupLR( - optimizer, total_steps=args.num_epochs * len(dataloader), warmup_steps=args.warmup_steps, eta_min=0.1 * args.lr - ) - default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16 - torch.set_default_dtype(default_dtype) - model, optimizer, _, dataloader, lr_scheduler = booster.boost( - model, optimizer, dataloader=dataloader, 
lr_scheduler=lr_scheduler - ) - torch.set_default_dtype(torch.float) - - coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB") - coordinator.print_on_master( - f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB" - ) - - # load checkpoint if specified - start_epoch = 0 - start_step = 0 - sampler_start_idx = 0 - if args.load is not None: - coordinator.print_on_master("Loading checkpoint") - start_epoch, start_step, sampler_start_idx = load(booster, model, optimizer, lr_scheduler, args.load) - coordinator.print_on_master(f"Loaded checkpoint {args.load} at epoch {start_epoch} step {start_step}") - - num_steps_per_epoch = len(dataloader) - - # if resume training, set the sampler start index to the correct value - dataloader.sampler.set_start_index(sampler_start_idx) - for epoch in range(start_epoch, args.num_epochs): - dataloader.sampler.set_epoch(epoch) - dataloader_iter = iter(dataloader) - - with tqdm( - range(start_step, num_steps_per_epoch), - desc=f"Epoch {epoch}", - disable=not print_flag, - total=num_steps_per_epoch, - initial=start_step, - ) as pbar: - for step in pbar: - if use_pipeline: - outputs = booster.execute_pipeline(dataloader_iter, model, _criterion, optimizer, return_loss=True) - loss = outputs["loss"] - else: - batch = next(dataloader_iter) - outputs = model(**batch) - loss = outputs[0] - booster.backward(loss, optimizer) - - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad() - - if not use_pipeline: - all_reduce_mean(loss) - if print_flag: - pbar.set_postfix({"loss": loss.item()}) - writer.add_scalar("loss", loss.item(), epoch * num_steps_per_epoch + step) - - if args.save_interval > 0 and (step + 1) % args.save_interval == 0: - coordinator.print_on_master(f"Saving checkpoint") - save( - booster, - model, - optimizer, - lr_scheduler, - epoch, - step + 1, - args.batch_size, - coordinator, - args.save_dir, - ) - coordinator.print_on_master(f"Saved checkpoint at epoch {epoch} step {step + 1}") - # the continue epochs are not resumed, so we need to reset the sampler start index and start step - dataloader.sampler.set_start_index(0) - start_step = 0 - - coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB") - - -if __name__ == "__main__": - main() diff --git a/examples/language/llama2/requirements.txt b/examples/language/llama2/requirements.txt index 6b475682dad0..438a4999a3fe 100644 --- a/examples/language/llama2/requirements.txt +++ b/examples/language/llama2/requirements.txt @@ -1,9 +1,8 @@ -colossalai>=0.3.2 +colossalai>=0.3.6 datasets numpy -torch>=1.12.0,<=2.0.0 tqdm transformers -flash-attn>=2.0.0,<=2.0.5 +flash-attn>=2.0.0 SentencePiece==0.1.99 tensorboard==2.14.0 From 4dccdc6955ad1a3d0e18b188ee7f8fd6dd80817b Mon Sep 17 00:00:00 2001 From: binmakeswell Date: Tue, 23 Apr 2024 18:48:07 +0800 Subject: [PATCH 05/11] [example] llama3 (#5631) * release llama3 * [release] llama3 * [release] llama3 * [release] llama3 * [release] llama3 --- README.md | 14 +++++++++++--- docs/README-zh-Hans.md | 10 +++++++++- examples/language/{llama2 => llama}/README.md | 6 ++++++ examples/language/{llama2 => llama}/benchmark.py | 0 .../language/{llama2 => llama}/requirements.txt | 0 .../{llama2 => llama}/scripts/benchmark_70B/3d.sh | 0 .../scripts/benchmark_70B/gemini.sh | 0 .../scripts/benchmark_70B/gemini_auto.sh | 0 .../scripts/benchmark_7B/gemini.sh | 0 .../scripts/benchmark_7B/gemini_auto.sh | 0 examples/language/{llama2 => 
llama}/test_ci.sh | 0 11 files changed, 26 insertions(+), 4 deletions(-) rename examples/language/{llama2 => llama}/README.md (94%) rename examples/language/{llama2 => llama}/benchmark.py (100%) rename examples/language/{llama2 => llama}/requirements.txt (100%) rename examples/language/{llama2 => llama}/scripts/benchmark_70B/3d.sh (100%) rename examples/language/{llama2 => llama}/scripts/benchmark_70B/gemini.sh (100%) rename examples/language/{llama2 => llama}/scripts/benchmark_70B/gemini_auto.sh (100%) rename examples/language/{llama2 => llama}/scripts/benchmark_7B/gemini.sh (100%) rename examples/language/{llama2 => llama}/scripts/benchmark_7B/gemini_auto.sh (100%) rename examples/language/{llama2 => llama}/test_ci.sh (100%) diff --git a/README.md b/README.md index 26776bdf6d9f..c1e2da0d406f 100644 --- a/README.md +++ b/README.md @@ -52,7 +52,7 @@

  • Parallel Training Demo
-     • LLaMA 1/2
+     • LLaMA 1/2/3
      • MoE
      • GPT-3
      • GPT-2
@@ -270,13 +270,21 @@ Acceleration of [AlphaFold Protein Structure](https://alphafold.ebi.ac.uk/)
 (back to top)

 ## Parallel Training Demo
+### LLaMA3
+
+- 70 billion parameter LLaMA3 model training accelerated by 18%
+[[code]](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/llama)
+
 ### LLaMA2
 - 70 billion parameter LLaMA2 model training accelerated by 195%
-[[code]](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/llama2)
+[[code]](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/llama)
 [[blog]](https://www.hpc-ai.tech/blog/70b-llama2-training)
 ### LLaMA1
@@ -285,7 +293,7 @@ Acceleration of [AlphaFold Protein Structure](https://alphafold.ebi.ac.uk/)
 - 65-billion-parameter large model pretraining accelerated by 38%
-[[code]](https://github.com/hpcaitech/ColossalAI/tree/example/llama/examples/language/llama)
+[[code]](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/llama)
 [[blog]](https://www.hpc-ai.tech/blog/large-model-pretraining)
 ### MoE
diff --git a/docs/README-zh-Hans.md b/docs/README-zh-Hans.md
index 6d243a80852d..7e0ed07fec16 100644
--- a/docs/README-zh-Hans.md
+++ b/docs/README-zh-Hans.md
@@ -51,7 +51,7 @@
    • 并行训练样例展示
-      • LLaMA 1/2
+      • LLaMA 1/2/3
       • MoE
       • GPT-3
       • GPT-2
@@ -261,6 +261,14 @@ Colossal-AI 为您提供了一系列并行组件。我们的目标是让您的
 (返回顶端)

 ## 并行训练样例展示
+### LLaMA3
+
+- 700亿参数LLaMA3训练加速18%
+[[code]](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/llama)
+
 ### LLaMA2
diff --git a/examples/language/llama2/README.md b/examples/language/llama/README.md
similarity index 94%
rename from examples/language/llama2/README.md
rename to examples/language/llama/README.md
index 11b2ee511a6e..fa0c6dc07156 100644
--- a/examples/language/llama2/README.md
+++ b/examples/language/llama/README.md
@@ -1,4 +1,10 @@
 # Pretraining LLaMA-1/2/3: best practices for building LLaMA-1/2/3-like base models
+### LLaMA3
+
+- 70 billion parameter LLaMA3 model training accelerated by 18%
 ### LLaMA2
        diff --git a/examples/language/llama2/benchmark.py b/examples/language/llama/benchmark.py similarity index 100% rename from examples/language/llama2/benchmark.py rename to examples/language/llama/benchmark.py diff --git a/examples/language/llama2/requirements.txt b/examples/language/llama/requirements.txt similarity index 100% rename from examples/language/llama2/requirements.txt rename to examples/language/llama/requirements.txt diff --git a/examples/language/llama2/scripts/benchmark_70B/3d.sh b/examples/language/llama/scripts/benchmark_70B/3d.sh similarity index 100% rename from examples/language/llama2/scripts/benchmark_70B/3d.sh rename to examples/language/llama/scripts/benchmark_70B/3d.sh diff --git a/examples/language/llama2/scripts/benchmark_70B/gemini.sh b/examples/language/llama/scripts/benchmark_70B/gemini.sh similarity index 100% rename from examples/language/llama2/scripts/benchmark_70B/gemini.sh rename to examples/language/llama/scripts/benchmark_70B/gemini.sh diff --git a/examples/language/llama2/scripts/benchmark_70B/gemini_auto.sh b/examples/language/llama/scripts/benchmark_70B/gemini_auto.sh similarity index 100% rename from examples/language/llama2/scripts/benchmark_70B/gemini_auto.sh rename to examples/language/llama/scripts/benchmark_70B/gemini_auto.sh diff --git a/examples/language/llama2/scripts/benchmark_7B/gemini.sh b/examples/language/llama/scripts/benchmark_7B/gemini.sh similarity index 100% rename from examples/language/llama2/scripts/benchmark_7B/gemini.sh rename to examples/language/llama/scripts/benchmark_7B/gemini.sh diff --git a/examples/language/llama2/scripts/benchmark_7B/gemini_auto.sh b/examples/language/llama/scripts/benchmark_7B/gemini_auto.sh similarity index 100% rename from examples/language/llama2/scripts/benchmark_7B/gemini_auto.sh rename to examples/language/llama/scripts/benchmark_7B/gemini_auto.sh diff --git a/examples/language/llama2/test_ci.sh b/examples/language/llama/test_ci.sh similarity index 100% rename from examples/language/llama2/test_ci.sh rename to examples/language/llama/test_ci.sh From 70c6f84faccdf90e7d01233cbc0a4b2e8ca8b3cd Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Wed, 24 Apr 2024 15:02:43 +0800 Subject: [PATCH 06/11] [test] fix llama test (#5638) --- tests/kit/model_zoo/transformers/llama.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/kit/model_zoo/transformers/llama.py b/tests/kit/model_zoo/transformers/llama.py index 58b5b0487a82..61fa560506c2 100644 --- a/tests/kit/model_zoo/transformers/llama.py +++ b/tests/kit/model_zoo/transformers/llama.py @@ -64,7 +64,6 @@ def data_gen_for_casual_lm(): intermediate_size=64, num_attention_heads=4, max_position_embeddings=128, - num_labels=16, ) if hasattr(config, "pad_token_id"): From 555c8bfc10c1e35cff7392b08c04053f73b52892 Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Wed, 24 Apr 2024 16:06:27 +0800 Subject: [PATCH 07/11] [gemini] fix buffer cast (#5639) --- colossalai/zero/gemini/gemini_ddp.py | 1 + 1 file changed, 1 insertion(+) diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index c79422171f1b..b25de1d68613 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -840,6 +840,7 @@ def _cast_buffers(self): for buffer in self.module.buffers(): if isinstance(buffer, LazyTensor): buffer.materialize() + for buffer in self.module.buffers(): buffer.data = buffer.to(get_accelerator().get_current_device()) if torch.is_floating_point(buffer): buffer.data = 
buffer.to(self.mixed_precision) From 255f3e4029e54e7ed4509873ce02a283948a5785 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Wed, 24 Apr 2024 12:09:14 +0000 Subject: [PATCH 08/11] support pp for mistral --- colossalai/shardformer/modeling/falcon.py | 141 ----- colossalai/shardformer/modeling/llama.py | 7 +- colossalai/shardformer/modeling/mistral.py | 485 ++++++++++++++++-- colossalai/shardformer/modeling/opt.py | 14 - colossalai/shardformer/policies/falcon.py | 6 +- colossalai/shardformer/policies/mistral.py | 158 +++++- tests/kit/model_zoo/transformers/llama.py | 32 +- tests/kit/model_zoo/transformers/mistral.py | 35 +- .../test_model/test_shard_llama.py | 244 ++++----- .../test_model/test_shard_mistral.py | 22 +- 10 files changed, 765 insertions(+), 379 deletions(-) diff --git a/colossalai/shardformer/modeling/falcon.py b/colossalai/shardformer/modeling/falcon.py index 34754ecdbac9..df3b09c71cbc 100644 --- a/colossalai/shardformer/modeling/falcon.py +++ b/colossalai/shardformer/modeling/falcon.py @@ -6,7 +6,6 @@ import torch.distributed as dist from torch.distributed import ProcessGroup from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss -from torch.nn import functional as F from transformers.modeling_attn_mask_utils import ( AttentionMaskConverter, _prepare_4d_causal_attention_mask, @@ -25,7 +24,6 @@ FalconForSequenceClassification, FalconForTokenClassification, FalconModel, - apply_rotary_pos_emb, build_alibi_tensor, ) from transformers.utils import logging @@ -171,145 +169,6 @@ def forward( return forward -def get_falcon_flash_attention_forward(): - try: - pass - except: - raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.") - from transformers.models.falcon.modeling_falcon import FalconAttention - - def forward( - self: FalconAttention, - hidden_states: torch.Tensor, - alibi: Optional[torch.Tensor], - attention_mask: torch.Tensor, - position_ids: Optional[torch.LongTensor] = None, - layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - head_mask: Optional[torch.Tensor] = None, - use_cache: bool = False, - output_attentions: bool = False, - **kwargs, - ): - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. 
Please make sure use `attention_mask` instead.`" - ) - fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] - num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads - # 3 x [batch_size, seq_length, num_heads, head_dim] - (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv) - - batch_size, query_length, _, _ = query_layer.shape - - query_layer = query_layer.transpose(1, 2).reshape(batch_size, self.num_heads, query_length, self.head_dim) - key_layer = key_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim) - value_layer = value_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim) - - kv_seq_len = key_layer.shape[-2] - if layer_past is not None: - kv_seq_len += layer_past[0].shape[-2] - if alibi is None: - cos, sin = self.rotary_emb(value_layer, seq_len=kv_seq_len) - query_layer, key_layer = apply_rotary_pos_emb(query_layer, key_layer, cos, sin, position_ids) - - if layer_past is not None: - past_key, past_value = layer_past - # concatenate along seq_length dimension: - # - key: [batch_size, self.num_heads, kv_length, head_dim] - # - value: [batch_size, self.num_heads, kv_length, head_dim] - key_layer = torch.cat((past_key, key_layer), dim=-2) - value_layer = torch.cat((past_value, value_layer), dim=-2) - - kv_length = key_layer.shape[-2] - if use_cache: - present = (key_layer, value_layer) - else: - present = None - - if alibi is None: - if self._use_sdpa and not output_attentions: - attn_output = F.scaled_dot_product_attention( - query_layer, - key_layer, - value_layer, - attention_mask, - 0.0, - # The query_length > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case query_length == 1. - is_causal=self.is_causal and attention_mask is None and query_length > 1, - ) - attention_scores = None - else: - attention_scores = query_layer @ key_layer.transpose(-1, -2) - attention_scores /= math.sqrt(self.head_dim) - - attention_scores = F.softmax(attention_scores + attention_mask, dim=-1, dtype=hidden_states.dtype) - # It is unclear why neither dropout nor head_mask is applied here (while it is with alibi). 
- attn_output = attention_scores @ value_layer - - attn_output = attn_output.view(batch_size, self.num_heads, query_length, self.head_dim) - attn_output = attn_output.permute(0, 2, 1, 3) - attn_output = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim) - - attn_output = self.dense(attn_output) - - if output_attentions: - return attn_output, present, attention_scores - else: - return attn_output, present - else: - if self._use_sdpa and not output_attentions and head_mask is None: - attn_output = F.scaled_dot_product_attention( - query_layer, - key_layer, - value_layer, - attn_mask=attention_mask, - dropout_p=self.attention_dropout.p if self.training else 0.0, - is_causal=self.is_causal and attention_mask is None and query_length > 1, - ) - attn_output = attn_output.transpose(1, 2) - attn_output = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim) - - attn_output = self.dense(attn_output) - else: - matmul_result = query_layer @ key_layer.transpose(-1, -2) - - # change view to [batch_size, num_heads, q_length, kv_length] - attention_scores = matmul_result.view(batch_size, self.num_heads, query_length, kv_length) - - # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length] - input_dtype = attention_scores.dtype - # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38` - if input_dtype == torch.float16 or input_dtype == torch.bfloat16: - attention_scores = attention_scores.to(torch.float32) - - attention_logits = attention_scores + alibi.view(batch_size, self.num_heads, 1, -1) - attention_logits *= self.inv_norm_factor - attention_probs = F.softmax(attention_logits + attention_mask, dim=-1, dtype=hidden_states.dtype) - # [batch_size, num_heads, q_length, kv_length] - attention_probs = self.attention_dropout(attention_probs) - - if head_mask is not None: - attention_probs = attention_probs * head_mask - - # change view [batch_size, num_heads, q_length, kv_length] - attention_probs_reshaped = attention_probs.view(batch_size, self.num_heads, query_length, kv_length) - - # matmul: [batch_size * num_heads, q_length, head_dim] - attn_output = (attention_probs_reshaped @ value_layer).flatten(0, 1) - - # change view [batch_size, q_length, num_heads * head_dim] - attn_output = self._merge_heads(attn_output) - - attn_output = self.dense(attn_output) - - if output_attentions: - return attn_output, present, attention_probs - else: - return attn_output, present - - return forward - - class FalconPipelineForwards: """ This class serves as a micro library for falcon pipeline forwards. 
diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index ac9baad5fdb9..4de57afb9444 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -115,10 +115,6 @@ def llama_model_forward( ) position_ids = position_ids.unsqueeze(0) - if self._use_flash_attention_2: - # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - # embed positions, for the first stage, hidden_states is the input embeddings, # for the other stages, hidden_states is the output of the previous stage if shard_config.enable_flash_attention: @@ -414,6 +410,9 @@ def llama_for_sequence_classification_forward( else: batch_size = hidden_states.shape[0] + print("batch_sizellama", batch_size) + print("self.config.pad_token_id", self.config.pad_token_id) + if stage_manager.is_last_stage(): hidden_states = transformer_outputs[0] logits = self.score(hidden_states) diff --git a/colossalai/shardformer/modeling/mistral.py b/colossalai/shardformer/modeling/mistral.py index 3b876bcab96a..db9d3afbc82b 100644 --- a/colossalai/shardformer/modeling/mistral.py +++ b/colossalai/shardformer/modeling/mistral.py @@ -2,12 +2,22 @@ from typing import List, Optional, Tuple, Union import torch -from transformers.cache_utils import Cache +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from transformers.cache_utils import Cache, DynamicCache from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask -from transformers.modeling_outputs import BaseModelOutputWithPast -from transformers.models.mistral.modeling_mistral import MistralModel +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, +) +from transformers.models.mistral.modeling_mistral import MistralForCausalLM, MistralModel from transformers.utils import logging +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer.shard import ShardConfig + +from ..layer import ColoAttention + logger = logging.get_logger(__name__) @@ -24,6 +34,10 @@ def mistral_model_forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, ) -> Union[Tuple, BaseModelOutputWithPast]: if use_cache: logger.warning_once("use_cache=True is not supported for Mistral models at the moment.") @@ -35,6 +49,377 @@ def mistral_model_forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict + # retrieve input_ids and inputs_embeds + if stage_manager.is_first_stage(): + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + inputs_embeds = self.embed_tokens(input_ids) + hidden_states = inputs_embeds + else: + input_shape = hidden_states.shape[:-1] + batch_size, seq_length = input_shape + device = hidden_states.device + + past_key_values_length = 0 + + if 
position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + if attention_mask is not None and self._use_flash_attention_2 and use_cache: + is_padding_right = attention_mask[:, -1].sum().item() != batch_size + if is_padding_right: + raise ValueError( + "You are attempting to perform batched generation with padding_side='right'" + " this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to " + " call `tokenizer.padding_side = 'left'` before tokenizing the input. " + ) + + if shard_config.enable_flash_attention: + # in this case, attention_mask is a dict rather than a tensor + mask_shape = (batch_size, 1, seq_length, seq_length) + attention_mask = ColoAttention.prepare_attn_kwargs( + mask_shape, + hidden_states.dtype, + hidden_states.device, + q_padding_mask=attention_mask, + is_causal=True, + ) + else: + if self._use_flash_attention_2: + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, + (batch_size, seq_length), + hidden_states, + past_key_values_length, + sliding_window=self.config.sliding_window, + ) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + start_idx, end_idx = stage_index[0], stage_index[1] + num_ckpt_layers = 0 + if self.gradient_checkpointing and self.training: + num_ckpt_layers = end_idx - start_idx + # TODO: We can replace `gradient_checkpointing_enable` fn and initialize a gradient_checkpointing (List[bool]) for each layer + if shard_config.gradient_checkpoint_config is not None: + num_ckpt_layers = shard_config.gradient_checkpoint_config.get_num_ckpt_layers( + stage=stage_manager.stage, + num_layers=end_idx - start_idx, + model_chunk_id=(stage_manager.model_chunk_id if stage_manager.is_interleave else 0), + ) + assert num_ckpt_layers <= end_idx - start_idx + + for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if idx - start_idx < num_ckpt_layers: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + attention_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if stage_manager.is_last_stage(): + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = None + if 
stage_manager.is_last_stage(): + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + else: + return {"hidden_states": hidden_states} + + @staticmethod + def mistral_for_causal_lm_forward( + self: MistralForCausalLM, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, MistralForCausalLM + + >>> model = MistralForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
+ ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = MistralForwards.mistral_model_forward( + self.model, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + shard_config=shard_config, + ) + + past_key_values = None + + if stage_manager.is_last_stage(): + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + else: + hidden_states = outputs.get("hidden_states") + return {"hidden_states": hidden_states} + + @staticmethod + def mistral_for_sequence_classification_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = MistralForwards.mistral_model_forward( + self.model, + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + shard_config=shard_config, + ) + + if input_ids is not None: + batch_size = input_ids.shape[0] + elif inputs_embeds is not None: + batch_size = inputs_embeds.shape[0] + else: + batch_size = hidden_states.shape[0] + print("batch_size", batch_size) + + if stage_manager.is_last_stage(): + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to( + logits.device + ) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + else: + hidden_states = transformer_outputs.get("hidden_states") + return {"hidden_states": hidden_states} + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +def get_mistral_model_forward_for_flash_attn(shard_config: ShardConfig): + logger = logging.get_logger(__name__) + assert shard_config.enable_flash_attention, "Flash Attention is not enabled." 
+ + def forward( + self: MistralModel, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + # retrieve input_ids and inputs_embeds if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") @@ -47,6 +432,12 @@ def mistral_model_forward( past_key_values_length = 0 + if use_cache: + use_legacy_cache = not isinstance(past_key_values, Cache) + if use_legacy_cache: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_key_values_length = past_key_values.get_usable_length(seq_length) + if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device position_ids = torch.arange( @@ -67,19 +458,29 @@ def mistral_model_forward( " this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to " " call `tokenizer.padding_side = 'left'` before tokenizing the input. " ) - - if self._use_flash_attention_2: - # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - else: - # 4d mask is passed through the layers - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, - sliding_window=self.config.sliding_window, + if shard_config.enable_flash_attention: + # in this case, attention_mask is a dict rather than a tensor + mask_shape = (batch_size, 1, seq_length, seq_length) + attention_mask = ColoAttention.prepare_attn_kwargs( + mask_shape, + inputs_embeds.dtype, + inputs_embeds.device, + q_padding_mask=attention_mask, + is_causal=True, ) + else: + if self._use_flash_attention_2: + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + sliding_window=self.config.sliding_window, + ) hidden_states = inputs_embeds @@ -93,6 +494,7 @@ def mistral_model_forward( # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None + next_decoder_cache = None for decoder_layer in self.layers: if output_hidden_states: @@ -121,7 +523,7 @@ def mistral_model_forward( hidden_states = layer_outputs[0] if use_cache: - layer_outputs[2 if output_attentions else 1] + next_decoder_cache = layer_outputs[2 if output_attentions else 1] if output_attentions: all_self_attns += (layer_outputs[1],) @@ -133,6 +535,8 @@ def mistral_model_forward( 
all_hidden_states += (hidden_states,) next_cache = None + if use_cache: + next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache if not return_dict: return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) @@ -143,11 +547,11 @@ def mistral_model_forward( attentions=all_self_attns, ) + return forward -def get_mistral_flash_attention_forward(): - from transformers.models.mistral.modeling_mistral import MistralAttention, apply_rotary_pos_emb, repeat_kv - from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention +def get_mistral_flash_attention_forward(shard_config: ShardConfig): + from transformers.models.mistral.modeling_mistral import MistralAttention, apply_rotary_pos_emb, repeat_kv def forward( self: MistralAttention, @@ -164,15 +568,14 @@ def forward( "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" ) bsz, q_len, _ = hidden_states.size() - assert q_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4." - query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = ( - self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - ) - value_states = ( - self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - ) + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) kv_seq_len = key_states.shape[-2] if past_key_value is not None: @@ -190,34 +593,18 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + # repeat k/v heads if n_kv_heads < n_heads key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) - me_input_shape = (bsz, q_len, self.num_heads, self.head_dim) - query_states = query_states.transpose(1, 2).contiguous().view(*me_input_shape) - key_states = key_states.transpose(1, 2).contiguous().view(*me_input_shape) - value_states = value_states.transpose(1, 2).contiguous().view(*me_input_shape) + assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict." 
+ attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask) - flash_attention_mask = None - attn_mask_type = AttnMaskType.causal - if attention_mask != None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) - flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous() - attn_mask_type = AttnMaskType.paddedcausal - - attention = ColoAttention(embed_dim=self.hidden_size, num_heads=self.num_heads) - attn_output = attention( - query_states, key_states, value_states, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type - ) + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) attn_output = self.o_proj(attn_output) - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value + return attn_output, None, past_key_value return forward diff --git a/colossalai/shardformer/modeling/opt.py b/colossalai/shardformer/modeling/opt.py index 76534b5d5d2e..8f841c8a6615 100644 --- a/colossalai/shardformer/modeling/opt.py +++ b/colossalai/shardformer/modeling/opt.py @@ -58,20 +58,6 @@ class OPTPipelineForwards: under pipeline setting. """ - @staticmethod - def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) - @staticmethod def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): """ diff --git a/colossalai/shardformer/policies/falcon.py b/colossalai/shardformer/policies/falcon.py index 09d895843b61..e72a97e4bfc0 100644 --- a/colossalai/shardformer/policies/falcon.py +++ b/colossalai/shardformer/policies/falcon.py @@ -137,11 +137,7 @@ def module_policy(self): if self.shard_config.enable_flash_attention: warnings.warn("Falcon doesn't support flash attention now, fallback to transformers attention.") - # self.append_or_create_method_replacement( - # description={"forward": get_falcon_flash_attention_forward()}, - # policy=policy, - # target_key=FalconAttention, - # ) + return policy def postprocess(self): diff --git a/colossalai/shardformer/policies/mistral.py b/colossalai/shardformer/policies/mistral.py index b3f89b4042c1..5345bbc90a83 100644 --- a/colossalai/shardformer/policies/mistral.py +++ b/colossalai/shardformer/policies/mistral.py @@ -1,8 +1,10 @@ import warnings from functools import partial -from typing import Callable, Dict, Union +from typing import Callable, Dict, List, Union import torch.nn as nn +from torch import Tensor +from torch.nn import Module from colossalai.shardformer.layer import ( FusedRMSNorm, @@ -14,7 +16,11 @@ VocabParallelLMHead1D, ) -from ..modeling.mistral import MistralForwards, get_mistral_flash_attention_forward +from ..modeling.mistral import ( + MistralForwards, + get_mistral_flash_attention_forward, + get_mistral_model_forward_for_flash_attn, +) from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = ["MistralPolicy", "MistralModelPolicy", 
"MistralForCausalLMPolicy", "MistralForSequenceClassificationPolicy"] @@ -146,16 +152,87 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: policy=policy, target_key=attn_cls, ) + if self.pipeline_stage_manager is None: + # replace llama model forward method + self.append_or_create_method_replacement( + description={ + "forward": get_mistral_model_forward_for_flash_attn(self.shard_config), + }, + policy=policy, + target_key=MistralModel, + ) return policy def postprocess(self): return self.model - def set_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: - method_replacement = {"forward": partial(new_forward)} + def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: + """If under pipeline parallel setting, replacing the original forward method of huggingface + to customized forward method, and add this changing to policy.""" + if self.pipeline_stage_manager is None: + return + + stage_manager = self.pipeline_stage_manager + if self.model.__class__.__name__ == "MistralModel": + module = self.model + else: + module = self.model.model + + if stage_manager.is_interleave: + layers_per_stage = stage_manager.distribute_layers(len(module.layers)) + stage_manager.stage_indices = stage_manager.get_stage_index(layers_per_stage) + method_replacement = { + "forward": partial(new_forward, stage_manager=stage_manager, shard_config=self.shard_config) + } + + else: + layers_per_stage = stage_manager.distribute_layers(len(module.layers)) + stage_index = stage_manager.get_stage_index(layers_per_stage) + method_replacement = { + "forward": partial( + new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config + ) + } + self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls) + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + assert self.pipeline_stage_manager is not None + + if self.model.__class__.__name__ == "MistralModel": + module = self.model + else: + module = self.model.model + stage_manager = self.pipeline_stage_manager + + held_layers = [] + if stage_manager.is_interleave: + assert stage_manager.num_model_chunks is not None + layers_per_stage = stage_manager.distribute_layers(len(module.layers)) + stage_indices = stage_manager.get_stage_index(layers_per_stage) + if stage_manager.is_first_stage(ignore_chunk=True): + held_layers.append(module.embed_tokens) + for start_idx, end_idx in stage_indices: + held_layers.extend(module.layers[start_idx:end_idx]) + if stage_manager.is_last_stage(ignore_chunk=True): + held_layers.append(module.norm) + + else: + layers_per_stage = stage_manager.distribute_layers(len(module.layers)) + print("layers_per_stage", layers_per_stage) + if stage_manager.is_first_stage(): + held_layers.append(module.embed_tokens) + start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage) + print("start_idx, end_idx", start_idx, end_idx) + print("input_layernorm", module.layers[start_idx].input_layernorm.weight) + held_layers.extend(module.layers[start_idx:end_idx]) + if stage_manager.is_last_stage(): + held_layers.append(module.norm) + print(held_layers) + return held_layers + class MistralModelPolicy(MistralPolicy): def __init__(self) -> None: @@ -165,17 +242,29 @@ def module_policy(self): policy = super().module_policy() from transformers.models.mistral.modeling_mistral import MistralModel - self.set_forward(model_cls=MistralModel, 
new_forward=MistralForwards.mistral_model_forward, policy=policy) + if self.pipeline_stage_manager: + self.set_pipeline_forward( + model_cls=MistralModel, new_forward=MistralForwards.mistral_model_forward, policy=policy + ) + print("policy", policy) + return policy + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + held_layers = super().get_held_layers() + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + """No shared params in mistral model""" + return [] + class MistralForCausalLMPolicy(MistralPolicy): def module_policy(self): from transformers import MistralForCausalLM policy = super().module_policy() - if self.pipeline_stage_manager: - warnings.warn("Mistral doesn't support pipeline parallelism now.") if self.shard_config.enable_tensor_parallelism: # add a new item for casual lm @@ -208,8 +297,38 @@ def module_policy(self): policy.update(new_item) + if self.pipeline_stage_manager: + # set None as default + self.set_pipeline_forward( + model_cls=MistralForCausalLM, new_forward=MistralForwards.mistral_for_causal_lm_forward, policy=policy + ) + return policy + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + stage_manager = self.pipeline_stage_manager + held_layers = super().get_held_layers() + if stage_manager.is_last_stage(ignore_chunk=True): + held_layers.append(self.model.lm_head) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + mistral_model = self.model.model + if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1: + if ( + id(mistral_model.embed_tokens.weight) == id(self.model.lm_head.weight) + and self.pipeline_stage_manager.num_stages > 1 + ): + # tie weights + return [ + { + 0: mistral_model.embed_tokens.weight, + self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight, + } + ] + return [] + class MistralForSequenceClassificationPolicy(MistralPolicy): def module_policy(self): @@ -228,9 +347,28 @@ def module_policy(self): ] ) } + policy.update(new_item) - if self.pipeline_stage_manager: - warnings.warn("Mistral doesn't support pipeline parallelism now.") + # if self.pipeline_stage_manager: + # warnings.warn("Mistral doesn't support pipeline parallelism now.") + if self.pipeline_stage_manager: + # set None as default + self.set_pipeline_forward( + model_cls=MistralForSequenceClassification, + new_forward=MistralForwards.mistral_for_sequence_classification_forward, + policy=policy, + ) - policy.update(new_item) return policy + + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + stage_manager = self.pipeline_stage_manager + held_layers = super().get_held_layers() + if stage_manager.is_last_stage(ignore_chunk=True): + held_layers.append(self.model.score) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + """No shared params in llama for sequence classification model""" + return [] diff --git a/tests/kit/model_zoo/transformers/llama.py b/tests/kit/model_zoo/transformers/llama.py index 61fa560506c2..5cbfe23938a1 100644 --- a/tests/kit/model_zoo/transformers/llama.py +++ b/tests/kit/model_zoo/transformers/llama.py @@ -73,22 +73,22 @@ def data_gen_for_casual_lm(): # transformers.LlamaModel, # transformers.LlamaForCausalLM, # transformers.LlamaForSequenceClassification, - model_zoo.register( - name="transformers_llama", - model_fn=lambda: transformers.LlamaModel(config), - data_gen_fn=data_gen, - 
output_transform_fn=output_transform_fn, - loss_fn=loss_fn, - model_attribute=ModelAttribute(has_control_flow=True), - ) - model_zoo.register( - name="transformers_llama_for_casual_lm", - model_fn=lambda: transformers.LlamaForCausalLM(config), - data_gen_fn=data_gen_for_casual_lm, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn_for_casual_lm, - model_attribute=ModelAttribute(has_control_flow=True), - ) + # model_zoo.register( + # name="transformers_llama", + # model_fn=lambda: transformers.LlamaModel(config), + # data_gen_fn=data_gen, + # output_transform_fn=output_transform_fn, + # loss_fn=loss_fn, + # model_attribute=ModelAttribute(has_control_flow=True), + # ) + # model_zoo.register( + # name="transformers_llama_for_casual_lm", + # model_fn=lambda: transformers.LlamaForCausalLM(config), + # data_gen_fn=data_gen_for_casual_lm, + # output_transform_fn=output_transform_fn, + # loss_fn=loss_fn_for_casual_lm, + # model_attribute=ModelAttribute(has_control_flow=True), + # ) model_zoo.register( name="transformers_llama_for_sequence_classification", model_fn=lambda: transformers.LlamaForSequenceClassification(config), diff --git a/tests/kit/model_zoo/transformers/mistral.py b/tests/kit/model_zoo/transformers/mistral.py index 37f87585759e..1d1ba157b3c9 100644 --- a/tests/kit/model_zoo/transformers/mistral.py +++ b/tests/kit/model_zoo/transformers/mistral.py @@ -52,22 +52,25 @@ def data_gen_for_sequence_classification(): hidden_size=256, intermediate_size=256, num_attention_heads=64, num_hidden_layers=2, vocab_size=50258 ) -model_zoo.register( - name="transformers_mistral", - model_fn=lambda: transformers.MistralModel(config), - data_gen_fn=data_gen, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn_for_mistral_model, - model_attribute=ModelAttribute(has_control_flow=True), -) -model_zoo.register( - name="transformers_mistral_for_casual_lm", - model_fn=lambda: transformers.MistralForCausalLM(config), - data_gen_fn=data_gen_for_lm, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn, - model_attribute=ModelAttribute(has_control_flow=True), -) +if hasattr(config, "pad_token_id"): + config.pad_token_id = config.eos_token_id + +# model_zoo.register( +# name="transformers_mistral", +# model_fn=lambda: transformers.MistralModel(config), +# data_gen_fn=data_gen, +# output_transform_fn=output_transform_fn, +# loss_fn=loss_fn_for_mistral_model, +# model_attribute=ModelAttribute(has_control_flow=True), +# ) +# model_zoo.register( +# name="transformers_mistral_for_casual_lm", +# model_fn=lambda: transformers.MistralForCausalLM(config), +# data_gen_fn=data_gen_for_lm, +# output_transform_fn=output_transform_fn, +# loss_fn=loss_fn, +# model_attribute=ModelAttribute(has_control_flow=True), +# ) model_zoo.register( name="transformers_mistral_for_sequence_classification", model_fn=lambda: transformers.MistralForSequenceClassification(config), diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index 2a10d86c79bb..06aa261ed51b 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -131,85 +131,85 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, @parameterize( "test_config", [ - { - "tp_size": 2, - "pp_size": 1, - "num_microbatches": 1, - "enable_sequence_parallelism": True, - "sequence_parallelism_mode": "ring", - "enable_flash_attention": True, - "use_lazy_init": True, - "zero_stage": 2, - "precision": 
"fp16", - "initial_scale": 1, - }, - { - "tp_size": 4, - "pp_size": 1, - "num_microbatches": 1, - "enable_sequence_parallelism": True, - "sequence_parallelism_mode": "ring", - "enable_flash_attention": False, - "use_lazy_init": True, - "precision": "fp32", - "initial_scale": 1, - }, - { - "tp_size": 4, - "pp_size": 1, - "num_microbatches": 1, - "enable_sequence_parallelism": True, - "sequence_parallelism_mode": "split_gather", - "enable_flash_attention": False, - "use_lazy_init": True, - "precision": "fp16", - "initial_scale": 1, - }, - { - "tp_size": 1, - "pp_size": 1, - "sp_size": 2, - "num_microbatches": 1, - "enable_sequence_parallelism": True, - "sequence_parallelism_mode": "all_to_all", - "use_lazy_init": True, - "precision": "fp16", - "initial_scale": 1, - }, - { - "tp_size": 1, - "pp_size": 1, - "sp_size": 2, - "num_microbatches": 1, - "enable_sequence_parallelism": True, - "sequence_parallelism_mode": "all_to_all", - "use_lazy_init": True, - "zero_stage": 2, - "precision": "fp16", - "initial_scale": 1, - }, - { - "tp_size": 1, - "pp_size": 1, - "num_microbatches": 1, - "enable_sequence_parallelism": True, - "sequence_parallelism_mode": "all_to_all", - "enable_flash_attention": False, - "use_lazy_init": True, - "precision": "fp16", - "initial_scale": 1, - }, - { - "tp_size": 2, - "pp_size": 2, - "num_microbatches": 2, - "enable_all_optimization": True, - "use_lazy_init": True, - "precision": "fp16", - "initial_scale": 1, - "enable_gradient_checkpointing": True, - "gradient_checkpoint_config": PipelineGradientCheckpointConfig(gradient_checkpointing_ratio=0.5), - }, + # { + # "tp_size": 2, + # "pp_size": 1, + # "num_microbatches": 1, + # "enable_sequence_parallelism": True, + # "sequence_parallelism_mode": "ring", + # "enable_flash_attention": True, + # "use_lazy_init": True, + # "zero_stage": 2, + # "precision": "fp16", + # "initial_scale": 1, + # }, + # { + # "tp_size": 4, + # "pp_size": 1, + # "num_microbatches": 1, + # "enable_sequence_parallelism": True, + # "sequence_parallelism_mode": "ring", + # "enable_flash_attention": False, + # "use_lazy_init": True, + # "precision": "fp32", + # "initial_scale": 1, + # }, + # { + # "tp_size": 4, + # "pp_size": 1, + # "num_microbatches": 1, + # "enable_sequence_parallelism": True, + # "sequence_parallelism_mode": "split_gather", + # "enable_flash_attention": False, + # "use_lazy_init": True, + # "precision": "fp16", + # "initial_scale": 1, + # }, + # { + # "tp_size": 1, + # "pp_size": 1, + # "sp_size": 2, + # "num_microbatches": 1, + # "enable_sequence_parallelism": True, + # "sequence_parallelism_mode": "all_to_all", + # "use_lazy_init": True, + # "precision": "fp16", + # "initial_scale": 1, + # }, + # { + # "tp_size": 1, + # "pp_size": 1, + # "sp_size": 2, + # "num_microbatches": 1, + # "enable_sequence_parallelism": True, + # "sequence_parallelism_mode": "all_to_all", + # "use_lazy_init": True, + # "zero_stage": 2, + # "precision": "fp16", + # "initial_scale": 1, + # }, + # { + # "tp_size": 1, + # "pp_size": 1, + # "num_microbatches": 1, + # "enable_sequence_parallelism": True, + # "sequence_parallelism_mode": "all_to_all", + # "enable_flash_attention": False, + # "use_lazy_init": True, + # "precision": "fp16", + # "initial_scale": 1, + # }, + # { + # "tp_size": 2, + # "pp_size": 2, + # "num_microbatches": 2, + # "enable_all_optimization": True, + # "use_lazy_init": True, + # "precision": "fp16", + # "initial_scale": 1, + # "enable_gradient_checkpointing": True, + # "gradient_checkpoint_config": 
PipelineGradientCheckpointConfig(gradient_checkpointing_ratio=0.5), + # }, { "tp_size": 1, "pp_size": 2, @@ -221,41 +221,41 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, num_stages=2, num_model_chunks=1, num_model_layers=8, num_ckpt_layers_per_stage=[4, 0] ), }, - { - "tp_size": 4, - "pp_size": 1, - "enable_all_optimization": False, - "use_lazy_init": False, - "precision": "fp32", - }, - { - "tp_size": 1, - "pp_size": 4, - "num_microbatches": 4, - "enable_all_optimization": False, - "use_lazy_init": False, - "precision": "fp32", - }, - {"tp_size": 2, "pp_size": 1, "enable_all_optimization": False, "use_lazy_init": False, "precision": "fp32"}, - { - "tp_size": 2, - "pp_size": 1, - "enable_all_optimization": True, - "use_lazy_init": True, - "zero_stage": 2, - "precision": "fp16", - "initial_scale": 1, - }, - { - "tp_size": 1, - "pp_size": 2, - "num_microbatches": 2, - "enable_all_optimization": True, - "use_lazy_init": True, - "zero_stage": 1, - "precision": "fp16", - "initial_scale": 1, - }, + # { + # "tp_size": 4, + # "pp_size": 1, + # "enable_all_optimization": False, + # "use_lazy_init": False, + # "precision": "fp32", + # }, + # { + # "tp_size": 1, + # "pp_size": 4, + # "num_microbatches": 4, + # "enable_all_optimization": False, + # "use_lazy_init": False, + # "precision": "fp32", + # }, + # {"tp_size": 2, "pp_size": 1, "enable_all_optimization": False, "use_lazy_init": False, "precision": "fp32"}, + # { + # "tp_size": 2, + # "pp_size": 1, + # "enable_all_optimization": True, + # "use_lazy_init": True, + # "zero_stage": 2, + # "precision": "fp16", + # "initial_scale": 1, + # }, + # { + # "tp_size": 1, + # "pp_size": 2, + # "num_microbatches": 2, + # "enable_all_optimization": True, + # "use_lazy_init": True, + # "zero_stage": 1, + # "precision": "fp16", + # "initial_scale": 1, + # }, ], ) def run_llama_test(test_config): @@ -341,13 +341,13 @@ def test_llama(): spawn(check_llama, 4) -@pytest.mark.largedist -@rerun_if_address_is_in_use() -@clear_cache_before_run() -def test_llama_3d(): - spawn(check_llama_3d, 8) +# @pytest.mark.largedist +# @rerun_if_address_is_in_use() +# @clear_cache_before_run() +# def test_llama_3d(): +# spawn(check_llama_3d, 8) -if __name__ == "__main__": - test_llama() - test_llama_3d() +# if __name__ == "__main__": +# test_llama() +# test_llama_3d() diff --git a/tests/test_shardformer/test_model/test_shard_mistral.py b/tests/test_shardformer/test_model/test_shard_mistral.py index f127472aee0b..2e7a4886494d 100644 --- a/tests/test_shardformer/test_model/test_shard_mistral.py +++ b/tests/test_shardformer/test_model/test_shard_mistral.py @@ -91,7 +91,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, # check weights if stage_manager is None or stage_manager.is_first_stage(): if test_config["precision"] == "fp32": - atol, rtol = 1e-4, 1e-3 + atol, rtol = 2e-4, 1e-3 else: atol, rtol = 5e-3, 5e-3 check_weight( @@ -114,6 +114,24 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, @parameterize( "test_config", [ + { + "tp_size": 1, + "pp_size": 2, + "num_microbatches": 2, + "enable_all_optimization": True, + "use_lazy_init": False, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 2, + "enable_all_optimization": True, + "use_lazy_init": True, + "precision": "fp16", + "initial_scale": 1, + }, { "tp_size": 4, "pp_size": 1, @@ -156,7 +174,7 @@ def check_mistral(rank, world_size, port): run_mistral_test() 
-@pytest.mark.skip("something wrong with pipeline parallelism") +# @pytest.mark.skip("something wrong with pipeline parallelism") @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() From ae4a4f0c30624c48389dea5d8d3d780c6a8b860a Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Wed, 24 Apr 2024 12:22:01 +0000 Subject: [PATCH 09/11] fix --- colossalai/shardformer/modeling/llama.py | 3 - colossalai/shardformer/modeling/mistral.py | 1 - colossalai/shardformer/policies/mistral.py | 7 - tests/kit/model_zoo/transformers/llama.py | 32 +-- tests/kit/model_zoo/transformers/mistral.py | 32 +-- .../test_model/test_shard_llama.py | 244 +++++++++--------- 6 files changed, 154 insertions(+), 165 deletions(-) diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 4de57afb9444..2f6cb73c01a8 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -410,9 +410,6 @@ def llama_for_sequence_classification_forward( else: batch_size = hidden_states.shape[0] - print("batch_sizellama", batch_size) - print("self.config.pad_token_id", self.config.pad_token_id) - if stage_manager.is_last_stage(): hidden_states = transformer_outputs[0] logits = self.score(hidden_states) diff --git a/colossalai/shardformer/modeling/mistral.py b/colossalai/shardformer/modeling/mistral.py index db9d3afbc82b..ac7845400d8d 100644 --- a/colossalai/shardformer/modeling/mistral.py +++ b/colossalai/shardformer/modeling/mistral.py @@ -338,7 +338,6 @@ def mistral_for_sequence_classification_forward( batch_size = inputs_embeds.shape[0] else: batch_size = hidden_states.shape[0] - print("batch_size", batch_size) if stage_manager.is_last_stage(): hidden_states = transformer_outputs[0] diff --git a/colossalai/shardformer/policies/mistral.py b/colossalai/shardformer/policies/mistral.py index 5345bbc90a83..b5018e47d65d 100644 --- a/colossalai/shardformer/policies/mistral.py +++ b/colossalai/shardformer/policies/mistral.py @@ -221,16 +221,12 @@ def get_held_layers(self) -> List[Module]: else: layers_per_stage = stage_manager.distribute_layers(len(module.layers)) - print("layers_per_stage", layers_per_stage) if stage_manager.is_first_stage(): held_layers.append(module.embed_tokens) start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage) - print("start_idx, end_idx", start_idx, end_idx) - print("input_layernorm", module.layers[start_idx].input_layernorm.weight) held_layers.extend(module.layers[start_idx:end_idx]) if stage_manager.is_last_stage(): held_layers.append(module.norm) - print(held_layers) return held_layers @@ -246,7 +242,6 @@ def module_policy(self): self.set_pipeline_forward( model_cls=MistralModel, new_forward=MistralForwards.mistral_model_forward, policy=policy ) - print("policy", policy) return policy @@ -349,8 +344,6 @@ def module_policy(self): } policy.update(new_item) - # if self.pipeline_stage_manager: - # warnings.warn("Mistral doesn't support pipeline parallelism now.") if self.pipeline_stage_manager: # set None as default self.set_pipeline_forward( diff --git a/tests/kit/model_zoo/transformers/llama.py b/tests/kit/model_zoo/transformers/llama.py index 5cbfe23938a1..61fa560506c2 100644 --- a/tests/kit/model_zoo/transformers/llama.py +++ b/tests/kit/model_zoo/transformers/llama.py @@ -73,22 +73,22 @@ def data_gen_for_casual_lm(): # transformers.LlamaModel, # transformers.LlamaForCausalLM, # transformers.LlamaForSequenceClassification, - # model_zoo.register( - # name="transformers_llama", 
- # model_fn=lambda: transformers.LlamaModel(config), - # data_gen_fn=data_gen, - # output_transform_fn=output_transform_fn, - # loss_fn=loss_fn, - # model_attribute=ModelAttribute(has_control_flow=True), - # ) - # model_zoo.register( - # name="transformers_llama_for_casual_lm", - # model_fn=lambda: transformers.LlamaForCausalLM(config), - # data_gen_fn=data_gen_for_casual_lm, - # output_transform_fn=output_transform_fn, - # loss_fn=loss_fn_for_casual_lm, - # model_attribute=ModelAttribute(has_control_flow=True), - # ) + model_zoo.register( + name="transformers_llama", + model_fn=lambda: transformers.LlamaModel(config), + data_gen_fn=data_gen, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn, + model_attribute=ModelAttribute(has_control_flow=True), + ) + model_zoo.register( + name="transformers_llama_for_casual_lm", + model_fn=lambda: transformers.LlamaForCausalLM(config), + data_gen_fn=data_gen_for_casual_lm, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_casual_lm, + model_attribute=ModelAttribute(has_control_flow=True), + ) model_zoo.register( name="transformers_llama_for_sequence_classification", model_fn=lambda: transformers.LlamaForSequenceClassification(config), diff --git a/tests/kit/model_zoo/transformers/mistral.py b/tests/kit/model_zoo/transformers/mistral.py index 1d1ba157b3c9..ae5a9700240a 100644 --- a/tests/kit/model_zoo/transformers/mistral.py +++ b/tests/kit/model_zoo/transformers/mistral.py @@ -55,22 +55,22 @@ def data_gen_for_sequence_classification(): if hasattr(config, "pad_token_id"): config.pad_token_id = config.eos_token_id -# model_zoo.register( -# name="transformers_mistral", -# model_fn=lambda: transformers.MistralModel(config), -# data_gen_fn=data_gen, -# output_transform_fn=output_transform_fn, -# loss_fn=loss_fn_for_mistral_model, -# model_attribute=ModelAttribute(has_control_flow=True), -# ) -# model_zoo.register( -# name="transformers_mistral_for_casual_lm", -# model_fn=lambda: transformers.MistralForCausalLM(config), -# data_gen_fn=data_gen_for_lm, -# output_transform_fn=output_transform_fn, -# loss_fn=loss_fn, -# model_attribute=ModelAttribute(has_control_flow=True), -# ) +model_zoo.register( + name="transformers_mistral", + model_fn=lambda: transformers.MistralModel(config), + data_gen_fn=data_gen, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_mistral_model, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_mistral_for_casual_lm", + model_fn=lambda: transformers.MistralForCausalLM(config), + data_gen_fn=data_gen_for_lm, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) model_zoo.register( name="transformers_mistral_for_sequence_classification", model_fn=lambda: transformers.MistralForSequenceClassification(config), diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index 06aa261ed51b..2a10d86c79bb 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -131,85 +131,85 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, @parameterize( "test_config", [ - # { - # "tp_size": 2, - # "pp_size": 1, - # "num_microbatches": 1, - # "enable_sequence_parallelism": True, - # "sequence_parallelism_mode": "ring", - # "enable_flash_attention": True, - # "use_lazy_init": True, - # "zero_stage": 2, - # "precision": "fp16", - # 
"initial_scale": 1, - # }, - # { - # "tp_size": 4, - # "pp_size": 1, - # "num_microbatches": 1, - # "enable_sequence_parallelism": True, - # "sequence_parallelism_mode": "ring", - # "enable_flash_attention": False, - # "use_lazy_init": True, - # "precision": "fp32", - # "initial_scale": 1, - # }, - # { - # "tp_size": 4, - # "pp_size": 1, - # "num_microbatches": 1, - # "enable_sequence_parallelism": True, - # "sequence_parallelism_mode": "split_gather", - # "enable_flash_attention": False, - # "use_lazy_init": True, - # "precision": "fp16", - # "initial_scale": 1, - # }, - # { - # "tp_size": 1, - # "pp_size": 1, - # "sp_size": 2, - # "num_microbatches": 1, - # "enable_sequence_parallelism": True, - # "sequence_parallelism_mode": "all_to_all", - # "use_lazy_init": True, - # "precision": "fp16", - # "initial_scale": 1, - # }, - # { - # "tp_size": 1, - # "pp_size": 1, - # "sp_size": 2, - # "num_microbatches": 1, - # "enable_sequence_parallelism": True, - # "sequence_parallelism_mode": "all_to_all", - # "use_lazy_init": True, - # "zero_stage": 2, - # "precision": "fp16", - # "initial_scale": 1, - # }, - # { - # "tp_size": 1, - # "pp_size": 1, - # "num_microbatches": 1, - # "enable_sequence_parallelism": True, - # "sequence_parallelism_mode": "all_to_all", - # "enable_flash_attention": False, - # "use_lazy_init": True, - # "precision": "fp16", - # "initial_scale": 1, - # }, - # { - # "tp_size": 2, - # "pp_size": 2, - # "num_microbatches": 2, - # "enable_all_optimization": True, - # "use_lazy_init": True, - # "precision": "fp16", - # "initial_scale": 1, - # "enable_gradient_checkpointing": True, - # "gradient_checkpoint_config": PipelineGradientCheckpointConfig(gradient_checkpointing_ratio=0.5), - # }, + { + "tp_size": 2, + "pp_size": 1, + "num_microbatches": 1, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "ring", + "enable_flash_attention": True, + "use_lazy_init": True, + "zero_stage": 2, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 4, + "pp_size": 1, + "num_microbatches": 1, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "ring", + "enable_flash_attention": False, + "use_lazy_init": True, + "precision": "fp32", + "initial_scale": 1, + }, + { + "tp_size": 4, + "pp_size": 1, + "num_microbatches": 1, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "split_gather", + "enable_flash_attention": False, + "use_lazy_init": True, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 1, + "pp_size": 1, + "sp_size": 2, + "num_microbatches": 1, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "all_to_all", + "use_lazy_init": True, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 1, + "pp_size": 1, + "sp_size": 2, + "num_microbatches": 1, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "all_to_all", + "use_lazy_init": True, + "zero_stage": 2, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 1, + "pp_size": 1, + "num_microbatches": 1, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "all_to_all", + "enable_flash_attention": False, + "use_lazy_init": True, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 2, + "enable_all_optimization": True, + "use_lazy_init": True, + "precision": "fp16", + "initial_scale": 1, + "enable_gradient_checkpointing": True, + "gradient_checkpoint_config": PipelineGradientCheckpointConfig(gradient_checkpointing_ratio=0.5), + 
}, { "tp_size": 1, "pp_size": 2, @@ -221,41 +221,41 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, num_stages=2, num_model_chunks=1, num_model_layers=8, num_ckpt_layers_per_stage=[4, 0] ), }, - # { - # "tp_size": 4, - # "pp_size": 1, - # "enable_all_optimization": False, - # "use_lazy_init": False, - # "precision": "fp32", - # }, - # { - # "tp_size": 1, - # "pp_size": 4, - # "num_microbatches": 4, - # "enable_all_optimization": False, - # "use_lazy_init": False, - # "precision": "fp32", - # }, - # {"tp_size": 2, "pp_size": 1, "enable_all_optimization": False, "use_lazy_init": False, "precision": "fp32"}, - # { - # "tp_size": 2, - # "pp_size": 1, - # "enable_all_optimization": True, - # "use_lazy_init": True, - # "zero_stage": 2, - # "precision": "fp16", - # "initial_scale": 1, - # }, - # { - # "tp_size": 1, - # "pp_size": 2, - # "num_microbatches": 2, - # "enable_all_optimization": True, - # "use_lazy_init": True, - # "zero_stage": 1, - # "precision": "fp16", - # "initial_scale": 1, - # }, + { + "tp_size": 4, + "pp_size": 1, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp32", + }, + { + "tp_size": 1, + "pp_size": 4, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp32", + }, + {"tp_size": 2, "pp_size": 1, "enable_all_optimization": False, "use_lazy_init": False, "precision": "fp32"}, + { + "tp_size": 2, + "pp_size": 1, + "enable_all_optimization": True, + "use_lazy_init": True, + "zero_stage": 2, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 1, + "pp_size": 2, + "num_microbatches": 2, + "enable_all_optimization": True, + "use_lazy_init": True, + "zero_stage": 1, + "precision": "fp16", + "initial_scale": 1, + }, ], ) def run_llama_test(test_config): @@ -341,13 +341,13 @@ def test_llama(): spawn(check_llama, 4) -# @pytest.mark.largedist -# @rerun_if_address_is_in_use() -# @clear_cache_before_run() -# def test_llama_3d(): -# spawn(check_llama_3d, 8) +@pytest.mark.largedist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_llama_3d(): + spawn(check_llama_3d, 8) -# if __name__ == "__main__": -# test_llama() -# test_llama_3d() +if __name__ == "__main__": + test_llama() + test_llama_3d() From 29589ffba4a238b15d478d03fd22871ea738b489 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Wed, 24 Apr 2024 12:25:06 +0000 Subject: [PATCH 10/11] fix fix fix --- requirements/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/requirements.txt b/requirements/requirements.txt index b0352230788a..d307312ded8e 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -16,4 +16,4 @@ ray sentencepiece google protobuf -transformers==4.36.0 +transformers==4.36.2 From d102d1d197dd583c30f40cb46b8d863852b221a4 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Wed, 24 Apr 2024 13:14:31 +0000 Subject: [PATCH 11/11] fix --- colossalai/shardformer/policies/llama.py | 1 - 1 file changed, 1 deletion(-) diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index 1b30ae9c9f40..0a95284bcfdf 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -56,7 +56,6 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: policy = {} attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement] - embedding_cls = None if 
self.shard_config.enable_tensor_parallelism: embedding_cls = VocabParallelEmbedding1D
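
The recurring pattern in these policy changes is per-implementation forward replacement: the policy maps the model's configured attention implementation (`self.origin_attn_implement`) to the matching attention class through an `ATTN_IMPLEMENTATION` lookup, and registers a replacement forward only for that class, warning and falling back where flash attention is unsupported. Below is a minimal, self-contained sketch of that dispatch idea; the toy attention classes and the `build_policy` / `apply_policy` helpers are hypothetical stand-ins for illustration only, not the actual ColossalAI `Policy` or transformers APIs.

```python
# Sketch of per-implementation forward replacement, assuming toy classes.
# All names below are illustrative stand-ins, not ColossalAI or transformers APIs.

class EagerAttention:
    def forward(self, x):
        return f"eager({x})"

class SdpaAttention(EagerAttention):
    def forward(self, x):
        return f"sdpa({x})"

class FlashAttention2(EagerAttention):
    def forward(self, x):
        return f"flash2({x})"

# Mirrors the ATTN_IMPLEMENTATION lookup keyed by the configured implementation.
ATTN_IMPLEMENTATION = {
    "eager": EagerAttention,
    "sdpa": SdpaAttention,
    "flash_attention_2": FlashAttention2,
}

def colo_flash_forward(self, x):
    # Stand-in for a generated replacement forward (e.g. a flash-attention forward).
    return f"colo_flash({x})"

def build_policy(attn_implementation: str, enable_flash_attention: bool) -> dict:
    """Select the attention class matching the configured implementation and
    register a forward replacement only for that class."""
    policy = {}
    attn_cls = ATTN_IMPLEMENTATION[attn_implementation]
    if enable_flash_attention:
        policy[attn_cls] = {"forward": colo_flash_forward}
    return policy

def apply_policy(policy: dict) -> None:
    # In spirit like append_or_create_method_replacement: patch the target
    # class's method with the replacement forward.
    for target_cls, methods in policy.items():
        for name, fn in methods.items():
            setattr(target_cls, name, fn)

if __name__ == "__main__":
    apply_policy(build_policy("sdpa", enable_flash_attention=True))
    print(SdpaAttention().forward("q,k,v"))   # colo_flash(q,k,v) -- replaced
    print(EagerAttention().forward("q,k,v"))  # eager(q,k,v) -- untouched
```

Only the class selected by the configured implementation is patched, which is why the policies above replace the eager, SDPA, and FlashAttention2 variants explicitly rather than a single attention class.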