From 20fa9a65cb682d32f294004f7ea6852e93f8f4dd Mon Sep 17 00:00:00 2001
From: Insu Jang
Date: Fri, 15 Dec 2023 04:40:52 +0000
Subject: [PATCH 1/9] Use self.[distribute_layers|get_stage_index] to exploit
 custom layer distribution

---
 colossalai/shardformer/policies/bloom.py    |  4 ++--
 colossalai/shardformer/policies/chatglm2.py |  4 ++--
 colossalai/shardformer/policies/falcon.py   |  4 ++--
 colossalai/shardformer/policies/gpt2.py     |  2 +-
 colossalai/shardformer/policies/gptj.py     |  4 ++--
 colossalai/shardformer/policies/llama.py    |  6 +++---
 colossalai/shardformer/policies/opt.py      |  4 ++--
 colossalai/shardformer/policies/t5.py       | 10 +++++-----
 colossalai/shardformer/policies/vit.py      |  4 ++--
 colossalai/shardformer/policies/whisper.py  | 13 +++++--------
 10 files changed, 26 insertions(+), 29 deletions(-)

diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py
index eddfafdcbcdc..e4714c8c1b15 100644
--- a/colossalai/shardformer/policies/bloom.py
+++ b/colossalai/shardformer/policies/bloom.py
@@ -203,8 +203,8 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli
         else:
             module = self.model.transformer
 
-        layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages)
-        stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
+        layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages)
+        stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
         method_replacement = {
             "forward": partial(
                 new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config
diff --git a/colossalai/shardformer/policies/chatglm2.py b/colossalai/shardformer/policies/chatglm2.py
index d1ad9f91478b..cbe6254d1561 100644
--- a/colossalai/shardformer/policies/chatglm2.py
+++ b/colossalai/shardformer/policies/chatglm2.py
@@ -204,8 +204,8 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli
         else:
             module = self.model.transformer
 
-        layers_per_stage = Policy.distribute_layers(module.num_layers, stage_manager.num_stages)
-        stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
+        layers_per_stage = self.distribute_layers(module.num_layers, stage_manager.num_stages)
+        stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
         method_replacement = {
             "forward": partial(
                 new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config
diff --git a/colossalai/shardformer/policies/falcon.py b/colossalai/shardformer/policies/falcon.py
index 5c148880f980..16bbc3f23f81 100644
--- a/colossalai/shardformer/policies/falcon.py
+++ b/colossalai/shardformer/policies/falcon.py
@@ -161,8 +161,8 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli
         else:
             module = self.model.transformer
 
-        layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages)
-        stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
+        layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages)
+        stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
         method_replacement = {
             "forward": partial(
                 new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config
diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py
index 5b43ecaed0c7..13de6c1e4997 100644
--- a/colossalai/shardformer/policies/gpt2.py
+++ b/colossalai/shardformer/policies/gpt2.py
@@ -229,7 +229,7 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli
             layers_per_stage = self.distribute_layers(
                 len(module.h), stage_manager.num_stages * stage_manager.num_model_chunks
             )
-            stage_manager.stage_indices = Policy.get_stage_index(
+            stage_manager.stage_indices = self.get_stage_index(
                 layers_per_stage,
                 stage_manager.stage,
                 num_model_chunks=stage_manager.num_model_chunks,
diff --git a/colossalai/shardformer/policies/gptj.py b/colossalai/shardformer/policies/gptj.py
index b001a2009381..b24443298e07 100644
--- a/colossalai/shardformer/policies/gptj.py
+++ b/colossalai/shardformer/policies/gptj.py
@@ -200,8 +200,8 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli
         else:
             module = self.model.transformer
 
-        layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages)
-        stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
+        layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages)
+        stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
         method_replacement = {
             "forward": partial(
                 new_forward,
diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py
index db8468713f66..86b170009f7e 100644
--- a/colossalai/shardformer/policies/llama.py
+++ b/colossalai/shardformer/policies/llama.py
@@ -167,7 +167,7 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli
             layers_per_stage = self.distribute_layers(
                 len(module.layers), stage_manager.num_stages * stage_manager.num_model_chunks
             )
-            stage_manager.stage_indices = Policy.get_stage_index(
+            stage_manager.stage_indices = self.get_stage_index(
                 layers_per_stage,
                 stage_manager.stage,
                 num_model_chunks=stage_manager.num_model_chunks,
@@ -178,8 +178,8 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli
             }
 
         else:
-            layers_per_stage = Policy.distribute_layers(len(module.layers), stage_manager.num_stages)
-            stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
+            layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
+            stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
             method_replacement = {
                 "forward": partial(
                     new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config
diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py
index 9a74da0b8d46..683f3a9d5a2d 100644
--- a/colossalai/shardformer/policies/opt.py
+++ b/colossalai/shardformer/policies/opt.py
@@ -208,8 +208,8 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli
         else:
             module = self.model.model.decoder
 
-        layers_per_stage = Policy.distribute_layers(len(module.layers), stage_manager.num_stages)
-        stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
+        layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
+        stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
         method_replacement = {
             "forward": partial(
                 new_forward,
diff --git a/colossalai/shardformer/policies/t5.py b/colossalai/shardformer/policies/t5.py
index e183b0632f88..2100d955ed3f 100644
--- a/colossalai/shardformer/policies/t5.py
+++ b/colossalai/shardformer/policies/t5.py
@@ -261,7 +261,7 @@ def distribute_t5_layers(
 
         # in the case of T5EncoderModel, set decoder starting stage to num_stages since it doesn't exist
         if num_decoder_layers == 0:
-            return Policy.distribute_layers(num_encoder_layers, num_stages), num_stages
+            return self.distribute_layers(num_encoder_layers, num_stages), num_stages
 
         # the number of stages distributed between encoder and decoder is optimized in this way:
         # num_encoder_stages = argmin(abs(num_encoder_layers / encoder_stages - num_decoder_layers / decoder_stages))
@@ -272,8 +272,8 @@ def objective(num_encoder_stages):
 
         num_encoder_stages = np.argmin([objective(i) for i in range(1, num_stages)]) + 1
         num_decoder_stages = num_stages - num_encoder_stages
 
-        encoder_distribution = Policy.distribute_layers(num_encoder_layers, num_encoder_stages)
-        decoder_distribution = Policy.distribute_layers(num_decoder_layers, num_decoder_stages)
+        encoder_distribution = self.distribute_layers(num_encoder_layers, num_encoder_stages)
+        decoder_distribution = self.distribute_layers(num_decoder_layers, num_decoder_stages)
         return encoder_distribution + decoder_distribution, num_encoder_stages
 
     @staticmethod
@@ -285,9 +285,9 @@ def get_t5_stage_index(
         Return the starting/ending idx of layers in encoder/decoder
         """
         if stage < decoder_starting_stage:
-            return Policy.get_stage_index(layers_per_stage[:decoder_starting_stage], stage)
+            return self.get_stage_index(layers_per_stage[:decoder_starting_stage], stage)
         else:
-            return Policy.get_stage_index(layers_per_stage[decoder_starting_stage:], stage - decoder_starting_stage)
+            return self.get_stage_index(layers_per_stage[decoder_starting_stage:], stage - decoder_starting_stage)
 
     def get_held_layers(self) -> List[nn.Module]:
         """Get pipeline layers for current stage."""
diff --git a/colossalai/shardformer/policies/vit.py b/colossalai/shardformer/policies/vit.py
index 584d4e2652c0..b0f224e22dc9 100644
--- a/colossalai/shardformer/policies/vit.py
+++ b/colossalai/shardformer/policies/vit.py
@@ -149,8 +149,8 @@ def set_pipeline_forward(self, model_cls: nn.Module, pipeline_forward: Callable,
         else:
             module = self.model.vit
 
-        layers_per_stage = Policy.distribute_layers(len(module.encoder.layer), stage_manager.num_stages)
-        stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
+        layers_per_stage = self.distribute_layers(len(module.encoder.layer), stage_manager.num_stages)
+        stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
         method_replacement = {"forward": pipeline_forward(stage_manager=stage_manager, stage_index=stage_index)}
         self.append_or_create_method_replacement(
             description=method_replacement, policy=policy, target_key=model_cls
diff --git a/colossalai/shardformer/policies/whisper.py b/colossalai/shardformer/policies/whisper.py
index 14e1e3e0fda8..c9ab53312d21 100644
--- a/colossalai/shardformer/policies/whisper.py
+++ b/colossalai/shardformer/policies/whisper.py
@@ -312,7 +312,7 @@ def distribute_whisper_layers(
 
         # in the case of whisperEncoderModel, set decoder starting stage to num_stages since it doesn't exist
         if num_decoder_layers == 0:
-            return Policy.distribute_layers(num_encoder_layers, num_stages), num_stages
+            return self.distribute_layers(num_encoder_layers, num_stages), num_stages
 
         # the number of stages distributed between encoder and decoder is optimized in this way:
         # num_encoder_stages = argmin(abs(num_encoder_layers / encoder_stages - num_decoder_layers / decoder_stages))
@@ -323,8 +323,8 @@ def objective(num_encoder_stages):
 
         num_encoder_stages = np.argmin([objective(i) for i in range(1, num_stages)]) + 1
         num_decoder_stages = num_stages - num_encoder_stages
 
-        encoder_distribution = Policy.distribute_layers(num_encoder_layers, num_encoder_stages)
-        decoder_distribution = Policy.distribute_layers(num_decoder_layers, num_decoder_stages)
+        encoder_distribution = self.distribute_layers(num_encoder_layers, num_encoder_stages)
+        decoder_distribution = self.distribute_layers(num_decoder_layers, num_decoder_stages)
         return encoder_distribution + decoder_distribution, num_encoder_stages
 
     @staticmethod
@@ -336,12 +336,9 @@ def get_whisper_stage_index(
         Return the starting/ending idx of layers in encoder/decoder
         """
         if stage < decoder_starting_stage:
-            return Policy.get_stage_index(layers_per_stage[:decoder_starting_stage], stage)
+            return self.get_stage_index(layers_per_stage[:decoder_starting_stage], stage)
         else:
-            return Policy.get_stage_index(
-                layers_per_stage[decoder_starting_stage:],
-                stage - decoder_starting_stage,
-            )
+            return self.get_stage_index(layers_per_stage[decoder_starting_stage:], stage - decoder_starting_stage)
 
     def get_held_layers(self) -> List[nn.Module]:
         assert self.pipeline_stage_manager is not None, "pipeline_stage_manager is None"

From 0d9a70e190f483cb8a1aefef7909205546a61c50 Mon Sep 17 00:00:00 2001
From: Insu Jang
Date: Fri, 15 Dec 2023 07:45:35 +0000
Subject: [PATCH 2/9] Change static methods for t5 layer distribution to
 member functions

---
 colossalai/shardformer/policies/t5.py | 22 ++++++++++------------
 1 file changed, 10 insertions(+), 12 deletions(-)

diff --git a/colossalai/shardformer/policies/t5.py b/colossalai/shardformer/policies/t5.py
index 2100d955ed3f..f5f701dc0972 100644
--- a/colossalai/shardformer/policies/t5.py
+++ b/colossalai/shardformer/policies/t5.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import warnings
 from functools import partial
 from typing import Callable, Dict, List, Tuple
@@ -241,9 +243,8 @@ def module_policy(self):
     def postprocess(self):
         return self.model
 
-    @staticmethod
     def distribute_t5_layers(
-        num_encoder_layers: int, num_decoder_layers: int, num_stages: int
+        self, num_encoder_layers: int, num_decoder_layers: int, num_stages: int
     ) -> Tuple[List[int], int]:
         """
         Distribute t5 layers into stages when pipeline parallel is used.
@@ -276,9 +277,8 @@ def objective(num_encoder_stages):
         decoder_distribution = self.distribute_layers(num_decoder_layers, num_decoder_stages)
         return encoder_distribution + decoder_distribution, num_encoder_stages
 
-    @staticmethod
     def get_t5_stage_index(
-        layers_per_stage: List[int], stage: int, decoder_starting_stage: int
+        self, layers_per_stage: List[int], stage: int, decoder_starting_stage: int
    ) -> Tuple[bool, int, int]:
         """
         Input the distribution of layers among stages, the current stage and the first stage of decoder.
@@ -302,12 +302,10 @@ def get_held_layers(self) -> List[nn.Module]:
         num_decoder_layers = len(decoder.block) if decoder else 0
 
         held_layers = []
-        layers_per_stage, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(
+        layers_per_stage, decoder_starting_stage = self.distribute_t5_layers(
             num_encoder_layers, num_decoder_layers, stage_manager.num_stages
         )
-        start_idx, end_idx = T5BasePolicy.get_t5_stage_index(
-            layers_per_stage, stage_manager.stage, decoder_starting_stage
-        )
+        start_idx, end_idx = self.get_t5_stage_index(layers_per_stage, stage_manager.stage, decoder_starting_stage)
 
         if stage_manager.stage < decoder_starting_stage:
             # current stage is in t5's encoder
@@ -343,10 +341,10 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli
         num_encoder_layers = len(encoder.block)
         num_decoder_layers = len(decoder.block) if decoder else 0
 
-        layers_per_stage, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(
+        layers_per_stage, decoder_starting_stage = self.distribute_t5_layers(
             num_encoder_layers, num_decoder_layers, stage_manager.num_stages
         )
-        stage_index = T5BasePolicy.get_t5_stage_index(layers_per_stage, stage_manager.stage, decoder_starting_stage)
+        stage_index = self.get_t5_stage_index(layers_per_stage, stage_manager.stage, decoder_starting_stage)
 
         method_replacement = {
             "forward": partial(
@@ -386,7 +384,7 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]:
         module = self.model
         stage_manager = self.pipeline_stage_manager
         if stage_manager is not None and stage_manager.num_stages > 1:
-            _, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(
+            _, decoder_starting_stage = self.distribute_t5_layers(
                 len(module.encoder.block), len(module.decoder.block), stage_manager.num_stages
             )
 
@@ -434,7 +432,7 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]:
         module = self.model
         stage_manager = self.pipeline_stage_manager
         if stage_manager is not None and stage_manager.num_stages > 1:
-            _, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(
+            _, decoder_starting_stage = self.distribute_t5_layers(
                 len(module.encoder.block), len(module.decoder.block), stage_manager.num_stages
             )

From 31f8365b9409435f9267c2cecd81f84c18fc234d Mon Sep 17 00:00:00 2001
From: Insu Jang
Date: Fri, 15 Dec 2023 07:47:26 +0000
Subject: [PATCH 3/9] Change static methods for whisper layer distribution to
 member functions

---
 colossalai/shardformer/policies/whisper.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/colossalai/shardformer/policies/whisper.py b/colossalai/shardformer/policies/whisper.py
index c9ab53312d21..ecf27a88f474 100644
--- a/colossalai/shardformer/policies/whisper.py
+++ b/colossalai/shardformer/policies/whisper.py
@@ -292,9 +292,8 @@ def add_lm_head_policy(self, base_policy):
     def postprocess(self):
         return self.model
 
-    @staticmethod
     def distribute_whisper_layers(
-        num_encoder_layers: int, num_decoder_layers: int, num_stages: int
+        self, num_encoder_layers: int, num_decoder_layers: int, num_stages: int
     ) -> Tuple[List[int], int]:
         """
         Distribute whisper layers into stages when pipeline parallel is used.
@@ -327,9 +326,8 @@ def objective(num_encoder_stages):
         decoder_distribution = self.distribute_layers(num_decoder_layers, num_decoder_stages)
         return encoder_distribution + decoder_distribution, num_encoder_stages
 
-    @staticmethod
     def get_whisper_stage_index(
-        layers_per_stage: List[int], stage: int, decoder_starting_stage: int
+        self, layers_per_stage: List[int], stage: int, decoder_starting_stage: int
     ) -> Tuple[bool, int, int]:
         """
         Input the distribution of layers among stages, the current stage and the first stage of decoder.

From dd9d2c64b5e6d137b00e2e2df501e041d5e658f1 Mon Sep 17 00:00:00 2001
From: Insu Jang
Date: Fri, 15 Dec 2023 16:07:18 +0000
Subject: [PATCH 4/9] Replace whisper policy usage with self one

---
 colossalai/shardformer/policies/whisper.py | 19 +++++++++----------
 1 file changed, 9 insertions(+), 10 deletions(-)

diff --git a/colossalai/shardformer/policies/whisper.py b/colossalai/shardformer/policies/whisper.py
index ecf27a88f474..480a4beea581 100644
--- a/colossalai/shardformer/policies/whisper.py
+++ b/colossalai/shardformer/policies/whisper.py
@@ -336,7 +336,10 @@ def get_whisper_stage_index(
         if stage < decoder_starting_stage:
             return self.get_stage_index(layers_per_stage[:decoder_starting_stage], stage)
         else:
-            return self.get_stage_index(layers_per_stage[decoder_starting_stage:], stage - decoder_starting_stage)
+            return self.get_stage_index(
+                layers_per_stage[decoder_starting_stage:],
+                stage - decoder_starting_stage,
+            )
 
     def get_held_layers(self) -> List[nn.Module]:
         assert self.pipeline_stage_manager is not None, "pipeline_stage_manager is None"
@@ -364,12 +367,10 @@ def get_held_layers(self) -> List[nn.Module]:
             num_decoder_layers = 0
 
         held_layers = []
-        layers_per_stage, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers(
+        layers_per_stage, decoder_starting_stage = self.distribute_whisper_layers(
             num_encoder_layers, num_decoder_layers, stage_manager.num_stages
         )
-        start_idx, end_idx = WhisperPolicy.get_whisper_stage_index(
-            layers_per_stage, stage_manager.stage, decoder_starting_stage
-        )
+        start_idx, end_idx = self.get_whisper_stage_index(layers_per_stage, stage_manager.stage, decoder_starting_stage)
 
         if stage_manager.stage < decoder_starting_stage:
             # current stage is in whisper's encoder
@@ -419,12 +420,10 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli
         else:
             num_decoder_layers = 0
 
-        layers_per_stage, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers(
+        layers_per_stage, decoder_starting_stage = self.distribute_whisper_layers(
             num_encoder_layers, num_decoder_layers, stage_manager.num_stages
         )
-        stage_index = WhisperPolicy.get_whisper_stage_index(
-            layers_per_stage, stage_manager.stage, decoder_starting_stage
-        )
+        stage_index = self.get_whisper_stage_index(layers_per_stage, stage_manager.stage, decoder_starting_stage)
 
         method_replacement = {
             "forward": partial(
@@ -506,7 +505,7 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]:
 
         stage_manager = self.pipeline_stage_manager
         if stage_manager is not None and stage_manager.num_stages > 1:
-            _, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers(
+            _, decoder_starting_stage = self.distribute_whisper_layers(
                 num_encoder_layers, num_decoder_layers, stage_manager.num_stages
             )
             shared_params = []

From 485f874cb1de258b969f1d7b03a46c3e7d0c3d18 Mon Sep 17 00:00:00 2001
From: Insu Jang
Date: Sat, 16 Dec 2023 14:25:46 +0000
Subject: [PATCH 5/9] Fix test case to use non-static layer distribution
 methods

---
 .../test_t5_pipeline_utils.py      | 18 +++++++++++-------
 .../test_whisper_pipeline_utils.py | 16 +++++++++++-----
 2 files changed, 22 insertions(+), 12 deletions(-)

diff --git a/tests/test_pipeline/test_pipeline_utils/test_t5_pipeline_utils.py b/tests/test_pipeline/test_pipeline_utils/test_t5_pipeline_utils.py
index 3723c9c1014a..4ba67225f271 100644
--- a/tests/test_pipeline/test_pipeline_utils/test_t5_pipeline_utils.py
+++ b/tests/test_pipeline/test_pipeline_utils/test_t5_pipeline_utils.py
@@ -10,9 +10,12 @@ def test_t5_pipeline_distribution():
         "decoder_starting_stage": [1, 1, 2, 2, 3, 1, 5, 2],
     }
 
+    policy = T5BasePolicy()
     for i in range(num_test_cases):
-        _, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(
-            test_dict["num_encoder_layers"][i], test_dict["num_decoder_layers"][i], test_dict["num_stages"][i]
+        _, decoder_starting_stage = policy.distribute_t5_layers(
+            test_dict["num_encoder_layers"][i],
+            test_dict["num_decoder_layers"][i],
+            test_dict["num_stages"][i],
         )
         assert test_dict["decoder_starting_stage"][i] == decoder_starting_stage
 
@@ -32,14 +35,15 @@ def test_t5_pipeline_layers():
     }
 
     for i in range(num_test_cases):
-        layers_per_stage, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(
-            test_dict["num_encoder_layers"][i], test_dict["num_decoder_layers"][i], test_dict["num_stages"][i]
+        policy = T5BasePolicy()
+        layers_per_stage, decoder_starting_stage = policy.distribute_t5_layers(
+            test_dict["num_encoder_layers"][i],
+            test_dict["num_decoder_layers"][i],
+            test_dict["num_stages"][i],
         )
 
         for stage in range(test_dict["num_stages"][i]):
             start_idx, end_idx = test_dict["layers_per_stage"][i][stage]
-            predicted_start, predicted_end = T5BasePolicy.get_t5_stage_index(
-                layers_per_stage, stage, decoder_starting_stage
-            )
+            predicted_start, predicted_end = policy.get_t5_stage_index(layers_per_stage, stage, decoder_starting_stage)
             assert start_idx == predicted_start
             assert end_idx == predicted_end
diff --git a/tests/test_pipeline/test_pipeline_utils/test_whisper_pipeline_utils.py b/tests/test_pipeline/test_pipeline_utils/test_whisper_pipeline_utils.py
index f6be8f6feac2..0500e46e890a 100644
--- a/tests/test_pipeline/test_pipeline_utils/test_whisper_pipeline_utils.py
+++ b/tests/test_pipeline/test_pipeline_utils/test_whisper_pipeline_utils.py
@@ -10,9 +10,12 @@ def test_whisper_pipeline_distribution():
         "decoder_starting_stage": [1, 1, 2, 2, 3, 1, 5, 2],
     }
 
+    policy = WhisperPolicy()
     for i in range(num_test_cases):
-        _, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers(
-            test_dict["num_encoder_layers"][i], test_dict["num_decoder_layers"][i], test_dict["num_stages"][i]
+        _, decoder_starting_stage = policy.distribute_whisper_layers(
+            test_dict["num_encoder_layers"][i],
+            test_dict["num_decoder_layers"][i],
+            test_dict["num_stages"][i],
         )
         assert test_dict["decoder_starting_stage"][i] == decoder_starting_stage
 
@@ -31,14 +34,17 @@ def test_whisper_pipeline_layers():
         ],
     }
 
+    policy = WhisperPolicy()
     for i in range(num_test_cases):
-        layers_per_stage, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers(
-            test_dict["num_encoder_layers"][i], test_dict["num_decoder_layers"][i], test_dict["num_stages"][i]
+        layers_per_stage, decoder_starting_stage = policy.distribute_whisper_layers(
+            test_dict["num_encoder_layers"][i],
+            test_dict["num_decoder_layers"][i],
+            test_dict["num_stages"][i],
         )
 
         for stage in range(test_dict["num_stages"][i]):
             start_idx, end_idx = test_dict["layers_per_stage"][i][stage]
-            predicted_start, predicted_end = WhisperPolicy.get_whisper_stage_index(
+            predicted_start, predicted_end = policy.get_whisper_stage_index(
                 layers_per_stage, stage, decoder_starting_stage
             )
             assert start_idx == predicted_start

From e8d4b7f489b9f262f1c57a7912dcf495919c4cc7 Mon Sep 17 00:00:00 2001
From: Insu Jang
Date: Sat, 6 Jan 2024 15:34:06 +0000
Subject: [PATCH 6/9] Fix missing changes

---
 colossalai/shardformer/policies/bert.py  | 67 +++++++++++++++++-------
 colossalai/shardformer/policies/llama.py |  2 +-
 2 files changed, 50 insertions(+), 19 deletions(-)

diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py
index 0ab63b7650c1..4d50a3c9920c 100644
--- a/colossalai/shardformer/policies/bert.py
+++ b/colossalai/shardformer/policies/bert.py
@@ -84,17 +84,26 @@ def module_policy(self):
                 SubModuleReplacementDescription(
                     suffix="attention.self.query",
                     target_module=col_nn.Linear1D_Col,
-                    kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap},
+                    kwargs={
+                        "seq_parallel": use_sequence_parallel,
+                        "overlap": overlap,
+                    },
                 ),
                 SubModuleReplacementDescription(
                     suffix="attention.self.key",
                     target_module=col_nn.Linear1D_Col,
-                    kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap},
+                    kwargs={
+                        "seq_parallel": use_sequence_parallel,
+                        "overlap": overlap,
+                    },
                 ),
                 SubModuleReplacementDescription(
                     suffix="attention.self.value",
                     target_module=col_nn.Linear1D_Col,
-                    kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap},
+                    kwargs={
+                        "seq_parallel": use_sequence_parallel,
+                        "overlap": overlap,
+                    },
                 ),
                 SubModuleReplacementDescription(
                     suffix="attention.self.dropout",
@@ -112,7 +121,10 @@ def module_policy(self):
                 SubModuleReplacementDescription(
                     suffix="intermediate.dense",
                     target_module=col_nn.Linear1D_Col,
-                    kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap},
+                    kwargs={
+                        "seq_parallel": use_sequence_parallel,
+                        "overlap": overlap,
+                    },
                 ),
                 SubModuleReplacementDescription(
                     suffix="output.dense",
@@ -214,7 +226,9 @@ def add_lm_head_policy(self, base_policy):
         if self.shard_config.enable_tensor_parallelism:
             self.append_or_create_submodule_replacement(
                 description=SubModuleReplacementDescription(
-                    suffix="decoder", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}
+                    suffix="decoder",
+                    target_module=col_nn.Linear1D_Col,
+                    kwargs={"gather_output": True},
                 ),
                 policy=base_policy,
                 target_key=BertLMPredictionHead,
@@ -241,7 +255,9 @@ def add_lm_prediction_policy(self, base_policy):
             "_load_from_state_dict": col_nn.ParallelModule._load_from_state_dict,
         }
         self.append_or_create_method_replacement(
-            description=method_replacement, policy=base_policy, target_key=BertLMPredictionHead
+            description=method_replacement,
+            policy=base_policy,
+            target_key=BertLMPredictionHead,
         )
         return base_policy
 
@@ -264,24 +280,32 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli
 
         if stage_manager.is_interleave:
             layers_per_stage = self.distribute_layers(
-                len(module.encoder.layer), stage_manager.num_stages * stage_manager.num_model_chunks
+                len(module.encoder.layer),
+                stage_manager.num_stages * stage_manager.num_model_chunks,
             )
-            stage_manager.stage_indices = Policy.get_stage_index(
+            stage_manager.stage_indices = self.get_stage_index(
                 layers_per_stage,
                 stage_manager.stage,
                 num_model_chunks=stage_manager.num_model_chunks,
                 num_stages=stage_manager.num_stages,
             )
             method_replacement = {
-                "forward": partial(new_forward, stage_manager=stage_manager, shard_config=self.shard_config)
+                "forward": partial(
+                    new_forward,
+                    stage_manager=stage_manager,
+                    shard_config=self.shard_config,
+                )
             }
 
         else:
-            layers_per_stage = Policy.distribute_layers(len(module.encoder.layer), stage_manager.num_stages)
-            stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
+            layers_per_stage = self.distribute_layers(len(module.encoder.layer), stage_manager.num_stages)
+            stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
             method_replacement = {
                 "forward": partial(
-                    new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config
+                    new_forward,
+                    stage_manager=stage_manager,
+                    stage_index=stage_index,
+                    shard_config=self.shard_config,
                 )
             }
 
@@ -301,9 +325,10 @@ def get_held_layers(self) -> List[Module]:
         if stage_manager.is_interleave:
             assert stage_manager.num_model_chunks is not None
             layers_per_stage = self.distribute_layers(
-                len(module.encoder.layer), stage_manager.num_stages * stage_manager.num_model_chunks
+                len(module.encoder.layer),
+                stage_manager.num_stages * stage_manager.num_model_chunks,
             )
-            stage_indices = Policy.get_stage_index(
+            stage_indices = self.get_stage_index(
                 layers_per_stage,
                 stage_manager.stage,
                 num_model_chunks=stage_manager.num_model_chunks,
@@ -320,7 +345,7 @@ def get_held_layers(self) -> List[Module]:
             layers_per_stage = self.distribute_layers(len(module.encoder.layer), stage_manager.num_stages)
             if stage_manager.is_first_stage():
                 held_layers.append(module.embeddings)
-            start_idx, end_idx = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
+            start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
             held_layers.extend(module.encoder.layer[start_idx:end_idx])
             if stage_manager.is_last_stage():
                 held_layers.append(module.pooler)
@@ -336,7 +361,9 @@ def module_policy(self):
 
         if self.pipeline_stage_manager:
             self.set_pipeline_forward(
-                model_cls=BertModel, new_forward=BertPipelineForwards.bert_model_forward, policy=policy
+                model_cls=BertModel,
+                new_forward=BertPipelineForwards.bert_model_forward,
+                policy=policy,
             )
         return policy
 
@@ -399,7 +426,9 @@ def module_policy(self):
 
         if self.pipeline_stage_manager:
             self.set_pipeline_forward(
-                model_cls=BertLMHeadModel, new_forward=BertPipelineForwards.bert_lm_head_model_forward, policy=policy
+                model_cls=BertLMHeadModel,
+                new_forward=BertPipelineForwards.bert_lm_head_model_forward,
+                policy=policy,
             )
         return policy
 
@@ -437,7 +466,9 @@ def module_policy(self):
 
         if self.pipeline_stage_manager:
             self.set_pipeline_forward(
-                model_cls=BertForMaskedLM, new_forward=BertPipelineForwards.bert_for_masked_lm_forward, policy=policy
+                model_cls=BertForMaskedLM,
+                new_forward=BertPipelineForwards.bert_for_masked_lm_forward,
+                policy=policy,
             )
         return policy
 
diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py
index 86b170009f7e..daa7708c8fdf 100644
--- a/colossalai/shardformer/policies/llama.py
+++ b/colossalai/shardformer/policies/llama.py
@@ -207,7 +207,7 @@ def get_held_layers(self) -> List[Module]:
             layers_per_stage = self.distribute_layers(
                 len(module.layers), stage_manager.num_stages * stage_manager.num_model_chunks
             )
-            stage_indices = Policy.get_stage_index(
+            stage_indices = self.get_stage_index(
                 layers_per_stage,
                 stage_manager.stage,
                 num_model_chunks=stage_manager.num_model_chunks,

From 86c069e3d870df5bbe7d98e733cb732365793262 Mon Sep 17 00:00:00 2001
From: Wenhao Chen
Date: Mon, 25 Mar 2024 12:29:54 +0800
Subject: [PATCH 7/9] style: polish code

---
 colossalai/shardformer/policies/base_policy.py              | 2 +-
 tests/test_shardformer/test_layer/test_dist_crossentropy.py | 5 +++--
 2 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/colossalai/shardformer/policies/base_policy.py b/colossalai/shardformer/policies/base_policy.py
index 9a49b1ba6a14..1d2b7a570681 100644
--- a/colossalai/shardformer/policies/base_policy.py
+++ b/colossalai/shardformer/policies/base_policy.py
@@ -242,4 +242,4 @@ def get_stage_index(
             end_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages + 1]
             stage_indices.append([start_idx, end_idx])
 
-        return stage_indices[0] if num_model_chunks == 1 else stage_indices
\ No newline at end of file
+        return stage_indices[0] if num_model_chunks == 1 else stage_indices
diff --git a/tests/test_shardformer/test_layer/test_dist_crossentropy.py b/tests/test_shardformer/test_layer/test_dist_crossentropy.py
index f594a80a43e0..414157c2233d 100644
--- a/tests/test_shardformer/test_layer/test_dist_crossentropy.py
+++ b/tests/test_shardformer/test_layer/test_dist_crossentropy.py
@@ -38,9 +38,10 @@ def check_dist_crossentropy(rank, world_size, port, ignore_index):
         org_loss, dist_loss, atol=1e-5
     ), f"dist cross entropy loss is not equal to orgin loss\n{org_loss}\n{dist_loss}"
 
-
     target_grad = torch.chunk(pred.grad, world_size, dim=-1)[rank]
-    assert torch.allclose(target_grad, dist_pred.grad), f"dist grad is not equal to orgin grad\n{target_grad}\n{dist_pred.grad}"
+    assert torch.allclose(
+        target_grad, dist_pred.grad
+    ), f"dist grad is not equal to orgin grad\n{target_grad}\n{dist_pred.grad}"
 
 
 @pytest.mark.dist

From 00f94fa1e7cc3669de8d8222d65d169c075370ff Mon Sep 17 00:00:00 2001
From: Wenhao Chen
Date: Mon, 25 Mar 2024 14:51:53 +0800
Subject: [PATCH 8/9] feat: Use self.[distribute_layers|get_stage_index] to
 exploit custom layer distribution

---
 .../colossal_moe/models/mixtral_policy.py      |  2 +-
 colossalai/shardformer/policies/base_policy.py |  5 ++---
 colossalai/shardformer/policies/gpt2.py        |  6 +++---
 .../language/openmoe/model/openmoe_policy.py   | 17 +++++++----------
 4 files changed, 13 insertions(+), 17 deletions(-)

diff --git a/applications/ColossalMoE/colossal_moe/models/mixtral_policy.py b/applications/ColossalMoE/colossal_moe/models/mixtral_policy.py
index 218b05b27fad..23ffbf5d317c 100644
--- a/applications/ColossalMoE/colossal_moe/models/mixtral_policy.py
+++ b/applications/ColossalMoE/colossal_moe/models/mixtral_policy.py
@@ -110,7 +110,7 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli
             module = self.model.model
 
         layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
-        stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
+        stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
         method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
         self.append_or_create_method_replacement(
             description=method_replacement, policy=policy, target_key=model_cls
diff --git a/colossalai/shardformer/policies/base_policy.py b/colossalai/shardformer/policies/base_policy.py
index 1d2b7a570681..762e754816bf 100644
--- a/colossalai/shardformer/policies/base_policy.py
+++ b/colossalai/shardformer/policies/base_policy.py
@@ -197,8 +197,7 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]:
         """
         return []
 
-    @staticmethod
-    def distribute_layers(num_layers: int, num_stages: int) -> List[int]:
+    def distribute_layers(self, num_layers: int, num_stages: int) -> List[int]:
         """Divide layers into stages"""
         quotient = num_layers // num_stages
         remainder = num_layers % num_stages
@@ -213,8 +212,8 @@ def distribute_layers(num_layers: int, num_stages: int) -> List[int]:
             layers_per_stage[i] += 1
         return layers_per_stage
 
-    @staticmethod
     def get_stage_index(
+        self,
         layers_per_stage: List[int],
         stage: int,
         num_model_chunks: int = 1,
diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py
index 13de6c1e4997..d1a8c9dce2c7 100644
--- a/colossalai/shardformer/policies/gpt2.py
+++ b/colossalai/shardformer/policies/gpt2.py
@@ -188,7 +188,7 @@ def get_held_layers(self) -> List[nn.Module]:
             layers_per_stage = self.distribute_layers(
                 len(module.h), stage_manager.num_stages * stage_manager.num_model_chunks
             )
-            stage_indices = Policy.get_stage_index(
+            stage_indices = self.get_stage_index(
                 layers_per_stage,
                 stage_manager.stage,
                 num_model_chunks=stage_manager.num_model_chunks,
@@ -243,8 +243,8 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli
                 )
             }
         else:
-            layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages)
-            stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
+            layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages)
+            stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
             method_replacement = {
                 "forward": partial(
                     new_forward,
diff --git a/examples/language/openmoe/model/openmoe_policy.py b/examples/language/openmoe/model/openmoe_policy.py
index 17e7aa46ce85..66a42e0176e9 100644
--- a/examples/language/openmoe/model/openmoe_policy.py
+++ b/examples/language/openmoe/model/openmoe_policy.py
@@ -98,11 +98,11 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli
             module = self.model.model
 
         layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
-        stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
+        stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
         method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
-        self.append_or_create_method_replacement(description=method_replacement,
-                                                 policy=policy,
-                                                 target_key=model_cls)
+        self.append_or_create_method_replacement(
+            description=method_replacement, policy=policy, target_key=model_cls
+        )
 
         return
 
@@ -126,12 +126,9 @@ def get_held_layers(self) -> List[Module]:
             held_layers.append(module.norm)
         return held_layers
 
-
-    @staticmethod
-    def distribute_layers(num_layers: int, num_stages: int) -> List[int]:
-        """Divide layers into stages
-        """
+    def distribute_layers(self, num_layers: int, num_stages: int) -> List[int]:
+        """Divide layers into stages"""
         if num_layers == 24 and num_stages == 4:
             return [7, 7, 7, 3]
         elif num_layers == 24 and num_stages == 2:
@@ -142,7 +139,7 @@ def distribute_layers(num_layers: int, num_stages: int) -> List[int]:
             return [8, 4]
         else:
             print(f"num_layers: {num_layers}, num_stages: {num_stages} not optimized, use origin pp policy")
-            return Policy.distribute_layers(num_layers, num_stages)
+            return super().distribute_layers(num_layers, num_stages)
 
 
 class OpenMoeModelPolicy(OpenMoePolicy):

From a74c702af059a1425cb676edcfa1ef4935d9c460 Mon Sep 17 00:00:00 2001
From: Wenhao Chen
Date: Wed, 27 Mar 2024 12:18:15 +0800
Subject: [PATCH 9/9] fix: fix typo

---
 tests/test_booster/test_plugin/test_3d_plugin.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/tests/test_booster/test_plugin/test_3d_plugin.py b/tests/test_booster/test_plugin/test_3d_plugin.py
index 61558c003c91..52cb8c46ed41 100644
--- a/tests/test_booster/test_plugin/test_3d_plugin.py
+++ b/tests/test_booster/test_plugin/test_3d_plugin.py
@@ -83,7 +83,7 @@ def _criterion(outputs, inputs):
 
 @parameterize("init_method", ["none", "lazy"])
 def check_3d_plugin(init_method: str = "none", early_stop: bool = True):
-    """check gemini plugin over model zoo
+    """check hybrid plugin over model zoo
 
     Args:
         early_stop (bool, optional): Whether to stop when getting the first error. Defaults to True.
@@ -260,7 +260,7 @@ def run_grad_acc_test(test_args):
         origin_model, origin_optimizer, dataloader=dataloader
     )
     for p1, p2 in zip(model.unwrap().parameters(), origin_model.unwrap().parameters()):
-        assert_close(p1.to(p2.dtype), p2, atol=1e-2, rtol=1e-2)
+        assert_close(p1.to(p2.dtype), p2, atol=1e-2, rtol=1e-2)
 
 
 def run_dist(rank, world_size, port, early_stop: bool = True):
@@ -271,9 +271,9 @@ def run_dist(rank, world_size, port, early_stop: bool = True):
 
 
 @rerun_if_address_is_in_use()
-def test_gemini_plugin(early_stop: bool = True):
+def test_3d_plugin(early_stop: bool = True):
     spawn(run_dist, 4, early_stop=early_stop)
 
 
 if __name__ == "__main__":
-    test_gemini_plugin(early_stop=False)
\ No newline at end of file
+    test_3d_plugin(early_stop=False)
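
Taken together, the nine patches implement a single pattern: the layer-distribution helpers on the base Policy stop being @staticmethods and become instance methods, so a policy subclass can override them and every call site picks up the override through `self`. Below is a minimal, self-contained sketch of that pattern. The even-split body of distribute_layers, the two-argument get_stage_index signature, and the demo at the bottom are simplified assumptions made for illustration; the method names, the `self` dispatch, and the OpenMoE override are taken from the diffs above.

from typing import List, Tuple


class Policy:
    def distribute_layers(self, num_layers: int, num_stages: int) -> List[int]:
        # Simplified even split (illustrative, not the exact ColossalAI body):
        # each stage gets num_layers // num_stages layers, and the remainder
        # is spread over the middle stages.
        quotient, remainder = divmod(num_layers, num_stages)
        layers_per_stage = [quotient] * num_stages
        start = (num_stages - remainder) // 2
        for i in range(start, start + remainder):
            layers_per_stage[i] += 1
        return layers_per_stage

    def get_stage_index(self, layers_per_stage: List[int], stage: int) -> Tuple[int, int]:
        # [start, end) indices of the layers owned by `stage`.
        start = sum(layers_per_stage[:stage])
        return start, start + layers_per_stage[stage]


class OpenMoePolicy(Policy):
    def distribute_layers(self, num_layers: int, num_stages: int) -> List[int]:
        # Hand-tuned split from openmoe_policy.py in patch 8/9; any other
        # shape falls back to the default via super().
        if num_layers == 24 and num_stages == 4:
            return [7, 7, 7, 3]
        return super().distribute_layers(num_layers, num_stages)


if __name__ == "__main__":
    policy = OpenMoePolicy()
    layers = policy.distribute_layers(24, 4)  # -> [7, 7, 7, 3]
    print(policy.get_stage_index(layers, 2))  # -> (14, 21)

Because call sites such as set_pipeline_forward and get_held_layers now resolve self.distribute_layers and self.get_stage_index dynamically, a subclass overrides the split in one place and the stage assignment, the forward replacement, and the tests all observe the same custom distribution; this is what the Policy.*-to-self.* rewrites in patches 1, 6, and 8 enable.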