From b31a052037104aad22928dd6f7f83a0d83a79409 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Thu, 21 Nov 2024 09:38:18 +0000 Subject: [PATCH 01/54] [feat] Sharderformer support zbv --- colossalai/shardformer/policies/blip2.py | 241 +++++++++++++++++++++++ colossalai/shardformer/policies/bloom.py | 156 +++++++++++++-- 2 files changed, 380 insertions(+), 17 deletions(-) diff --git a/colossalai/shardformer/policies/blip2.py b/colossalai/shardformer/policies/blip2.py index 2e73d5c2a637..4ca1cefc2815 100644 --- a/colossalai/shardformer/policies/blip2.py +++ b/colossalai/shardformer/policies/blip2.py @@ -51,6 +51,8 @@ def module_policy(self): else: norm_cls = col_nn.LayerNorm + use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv + if self.shard_config.enable_tensor_parallelism: assert ( self.model.config.vision_config.num_attention_heads % self.shard_config.tensor_parallel_size == 0 @@ -80,6 +82,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Row, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -88,6 +91,7 @@ def module_policy(self): kwargs={ "skip_bias_add": self.enable_bias_gelu_fused, "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -95,6 +99,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Row, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), ], @@ -126,6 +131,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -133,6 +139,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -140,6 +147,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -151,6 +159,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Row, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -162,6 +171,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -169,6 +179,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -176,6 +187,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -187,6 +199,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Row, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -198,6 +211,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -205,6 +219,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Row, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -227,6 +242,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -234,6 +250,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -241,6 +258,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -248,6 +266,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Row, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -255,6 +274,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -262,6 +282,227 @@ def module_policy(self): target_module=col_nn.Linear1D_Row, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + ], + ) + + policy[Blip2Attention] = ModulePolicyDescription(method_replacement={"forward": forward_fn()}) + if self.enable_bias_gelu_fused: + self.append_or_create_method_replacement( + description={ + "forward": get_jit_fused_blip2_mlp_forward(), + }, + policy=policy, + target_key=Blip2MLP, + ) + + elif use_zbv: + policy[Blip2EncoderLayer] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attn.dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="self_attn.qkv", + target_module=col_nn.FusedLinear1D_Col, + kwargs={ + "split_sizes": [self.model.config.vision_config.hidden_size] * 3, + "fp8_communication": self.shard_config.fp8_communication, + }, + ), + SubModuleReplacementDescription( + suffix="self_attn.projection", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="mlp.fc1", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "skip_bias_add": self.enable_bias_gelu_fused, + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="mlp.fc2", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + ], + ) + + policy[Blip2QFormerModel] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="dropout", + target_module=col_nn.DropoutForParallelInput, + ), + ] + ) + + policy[Blip2QFormerLayer] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="attention.attention.query", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="attention.attention.key", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="attention.attention.value", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="attention.attention.dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="attention.output.dense", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="attention.output.dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="crossattention.attention.query", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="crossattention.attention.key", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="crossattention.attention.value", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="crossattention.attention.dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="crossattention.output.dense", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="crossattention.output.dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="intermediate_query.dense", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="output_query.dense", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="output_query.dropout", + target_module=col_nn.DropoutForParallelInput, + ), + ], + ) + + policy[OPTDecoderLayer] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attn.q_proj", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="self_attn.k_proj", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="self_attn.v_proj", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="self_attn.out_proj", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="fc1", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="fc2", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), ], diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py index 7c6259e850c2..c7691698bed2 100644 --- a/colossalai/shardformer/policies/bloom.py +++ b/colossalai/shardformer/policies/bloom.py @@ -59,6 +59,8 @@ def module_policy(self): sp_partial_derived = sp_mode == "split_gather" + use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv + if self.shard_config.enable_tensor_parallelism: assert ( self.model.config.n_head % self.shard_config.tensor_parallel_size == 0 @@ -78,6 +80,7 @@ def module_policy(self): kwargs={ "seq_parallel_mode": sp_mode, "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -86,6 +89,7 @@ def module_policy(self): kwargs={ "seq_parallel_mode": sp_mode, "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -98,6 +102,7 @@ def module_policy(self): kwargs={ "seq_parallel_mode": sp_mode, "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -106,6 +111,7 @@ def module_policy(self): kwargs={ "seq_parallel_mode": sp_mode, "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), ], @@ -120,6 +126,52 @@ def module_policy(self): }, ) + if use_zbv: + policy[BloomBlock] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attention.query_key_value", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "seq_parallel_mode": sp_mode, + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="self_attention.dense", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "seq_parallel_mode": sp_mode, + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="self_attention.attention_dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="mlp.dense_h_to_4h", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "seq_parallel_mode": sp_mode, + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="mlp.dense_4h_to_h", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "seq_parallel_mode": sp_mode, + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + ], + ) + if embedding_cls is not None: self.append_or_create_submodule_replacement( description=[ @@ -247,14 +299,27 @@ def get_held_layers(self) -> List[Module]: stage_manager = self.pipeline_stage_manager held_layers = [] - layers_per_stage = stage_manager.distribute_layers(len(module.h)) - if stage_manager.is_first_stage(): - held_layers.append(module.word_embeddings) - held_layers.append(module.word_embeddings_layernorm) - start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage) - held_layers.extend(module.h[start_idx:end_idx]) - if stage_manager.is_last_stage(): - held_layers.append(module.ln_f) + if stage_manager.is_interleave: + layers_per_stage = stage_manager.distribute_layers(len(module.h)) + stage_indices = stage_manager.get_stage_index(layers_per_stage) + if stage_manager.is_first_stage(ignore_chunk=True): + held_layers.append(module.word_embeddings) + held_layers.append(module.word_embeddings_layernorm) + for start_idx, end_idx in stage_indices: + held_layers.extend(module.h[start_idx:end_idx]) + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) + ): + held_layers.append(module.ln_f) + else: + layers_per_stage = stage_manager.distribute_layers(len(module.h)) + if stage_manager.is_first_stage(): + held_layers.append(module.word_embeddings) + held_layers.append(module.word_embeddings_layernorm) + start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage) + held_layers.extend(module.h[start_idx:end_idx]) + if stage_manager.is_last_stage(): + held_layers.append(module.ln_f) return held_layers @@ -328,8 +393,14 @@ def get_held_layers(self) -> List[Module]: """Get pipeline layers for current stage.""" stage_manager = self.pipeline_stage_manager held_layers = super().get_held_layers() - if stage_manager.is_last_stage(): - held_layers.append(self.model.lm_head) + if stage_manager.is_interleave: + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) + ): + held_layers.append(self.model.lm_head) + else: + if stage_manager.is_last_stage(): + held_layers.append(self.model.lm_head) return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: @@ -351,6 +422,7 @@ def module_policy(self): from transformers.models.bloom.modeling_bloom import BloomForSequenceClassification policy = super().module_policy() + use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv # handle tensor parallelism if self.shard_config.enable_tensor_parallelism: @@ -363,6 +435,18 @@ def module_policy(self): policy=policy, target_key=BloomForSequenceClassification, ) + elif use_zbv: + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="score", + target_module=col_nn.LinearWithGradAccum, + kwargs=dict( + gather_output=True, fp8_communication=self.shard_config.fp8_communication, use_zbv=use_zbv + ), + ), + policy=policy, + target_key=BloomForSequenceClassification, + ) if self.pipeline_stage_manager: self.set_pipeline_forward( model_cls=BloomForSequenceClassification, @@ -375,8 +459,14 @@ def get_held_layers(self) -> List[Module]: """Get pipeline layers for current stage.""" stage_manager = self.pipeline_stage_manager held_layers = super().get_held_layers() - if stage_manager.is_last_stage(): - held_layers.append(self.model.score) + if stage_manager.is_interleave: + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) + ): + held_layers.append(self.model.score) + else: + if stage_manager.is_last_stage(): + held_layers.append(self.model.score) return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: @@ -389,6 +479,7 @@ def module_policy(self): from transformers.models.bloom.modeling_bloom import BloomForTokenClassification policy = super().module_policy() + use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv # handle tensor parallelism if self.shard_config.enable_tensor_parallelism: @@ -407,6 +498,24 @@ def module_policy(self): policy=policy, target_key=BloomForTokenClassification, ) + elif use_zbv: + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="classifier", + target_module=col_nn.LinearWithGradAccum, + kwargs=dict( + gather_output=True, fp8_communication=self.shard_config.fp8_communication, use_zbv=use_zbv + ), + ), + SubModuleReplacementDescription( + suffix="dropout", + target_module=col_nn.DropoutForReplicatedInput, + ), + ], + policy=policy, + target_key=BloomForTokenClassification, + ) if self.pipeline_stage_manager: self.set_pipeline_forward( model_cls=BloomForTokenClassification, @@ -420,9 +529,16 @@ def get_held_layers(self) -> List[Module]: """Get pipeline layers for current stage.""" stage_manager = self.pipeline_stage_manager held_layers = super().get_held_layers() - if stage_manager.is_last_stage(): - held_layers.append(self.model.dropout) - held_layers.append(self.model.classifier) + if stage_manager.is_interleave: + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) + ): + held_layers.append(self.model.dropout) + held_layers.append(self.model.classifier) + else: + if stage_manager.is_last_stage(): + held_layers.append(self.model.dropout) + held_layers.append(self.model.classifier) return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: @@ -448,8 +564,14 @@ def get_held_layers(self) -> List[Module]: """Get pipeline layers for current stage.""" held_layers = super().get_held_layers() stage_manager = self.pipeline_stage_manager - if stage_manager.is_last_stage(): - held_layers.append(self.model.qa_outputs) + if stage_manager.is_interleave: + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) + ): + held_layers.append(self.model.qa_outputs) + else: + if stage_manager.is_last_stage(): + held_layers.append(self.model.qa_outputs) return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: From 5f89e7f70eb7b9f4469c8f086772c312fbc02430 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Thu, 21 Nov 2024 10:15:38 +0000 Subject: [PATCH 02/54] [feat] support chatglm2, command, deepseek for zbv --- colossalai/shardformer/policies/chatglm2.py | 77 ++++++++++-- colossalai/shardformer/policies/command.py | 129 ++++++++++++++++++-- colossalai/shardformer/policies/deepseek.py | 111 ++++++++++++++--- 3 files changed, 277 insertions(+), 40 deletions(-) diff --git a/colossalai/shardformer/policies/chatglm2.py b/colossalai/shardformer/policies/chatglm2.py index c003570a0582..7603932c3d0c 100644 --- a/colossalai/shardformer/policies/chatglm2.py +++ b/colossalai/shardformer/policies/chatglm2.py @@ -82,6 +82,8 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: attribute_replacement=decoder_attribute_replacement, ) + use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv + if self.shard_config.enable_tensor_parallelism: assert ( self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0 @@ -144,6 +146,35 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: ), ], ) + elif use_zbv: + policy["GLMBlock"] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attention.query_key_value", + target_module=col_nn.Linear1D_Col, + kwargs={ + "seq_parallel_mode": sp_mode, + "seq_parallel_dim": 0, + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="self_attention.dense", + target_module=col_nn.Linear1D_Row, + kwargs={ + "seq_parallel_mode": sp_mode, + "seq_parallel_dim": 0, + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="self_attention.core_attention.attention_dropout", + target_module=col_nn.DropoutForParallelInput, + ), + ], + ) if embedding_cls is not None: self.append_or_create_submodule_replacement( @@ -253,17 +284,30 @@ def get_held_layers(self) -> List[nn.Module]: stage_manager = self.pipeline_stage_manager held_layers = [] - layers_per_stage = stage_manager.distribute_layers(module.num_layers) - if stage_manager.is_first_stage(): - held_layers.append(module.embedding) - start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage) - held_layers.extend(module.encoder.layers[start_idx:end_idx]) - if stage_manager.is_last_stage(): - if module.encoder.post_layer_norm: - held_layers.append(module.encoder.final_layernorm) - - # rotary_pos_emb is needed for all stages - held_layers.append(module.rotary_pos_emb) + if stage_manager.is_interleave: + layers_per_stage = stage_manager.distribute_layers(module.num_layers) + stage_indices = stage_manager.get_stage_index(layers_per_stage) + if stage_manager.is_first_stage(ignore_chunk=True): + held_layers.append(module.embed_tokens) + for start_idx, end_idx in stage_indices: + held_layers.extend(module.layers[start_idx:end_idx]) + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) + ): + if module.encoder.post_layer_norm: + held_layers.append(module.encoder.final_layernorm) + else: + layers_per_stage = stage_manager.distribute_layers(module.num_layers) + if stage_manager.is_first_stage(): + held_layers.append(module.embedding) + start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage) + held_layers.extend(module.encoder.layers[start_idx:end_idx]) + if stage_manager.is_last_stage(): + if module.encoder.post_layer_norm: + held_layers.append(module.encoder.final_layernorm) + + # rotary_pos_emb is needed for all stages + held_layers.append(module.rotary_pos_emb) return held_layers @@ -327,8 +371,15 @@ def module_policy(self): def get_held_layers(self) -> List[nn.Module]: held_layers = super().get_held_layers() - if self.pipeline_stage_manager.is_last_stage(): - held_layers.append(self.model.transformer.output_layer) + stage_manager = self.pipeline_stage_manager + if stage_manager.is_interleave: + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) + ): + held_layers.append(self.model.transformer.output_layer) + else: + if self.pipeline_stage_manager.is_last_stage(): + held_layers.append(self.model.transformer.output_layer) return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: diff --git a/colossalai/shardformer/policies/command.py b/colossalai/shardformer/policies/command.py index 323480d6d084..e6e741d34a3a 100644 --- a/colossalai/shardformer/policies/command.py +++ b/colossalai/shardformer/policies/command.py @@ -10,6 +10,7 @@ LayerNorm, Linear1D_Col, Linear1D_Row, + LinearWithGradAccum, PaddingEmbedding, PaddingLMHead, VocabParallelEmbedding1D, @@ -107,6 +108,8 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: target_key=CohereModel, ) + use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv + if self.shard_config.enable_tensor_parallelism: assert ( num_q_heads % tp_size == 0 @@ -128,41 +131,137 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: SubModuleReplacementDescription( suffix="self_attn.q_proj", target_module=Linear1D_Col, - kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), ), SubModuleReplacementDescription( suffix="self_attn.k_proj", target_module=Linear1D_Col, - kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), ), SubModuleReplacementDescription( suffix="self_attn.v_proj", target_module=Linear1D_Col, - kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), ), SubModuleReplacementDescription( suffix="self_attn.o_proj", target_module=Linear1D_Row, - kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), ), SubModuleReplacementDescription( suffix="mlp.gate_proj", target_module=Linear1D_Col, - kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), ), SubModuleReplacementDescription( suffix="mlp.up_proj", target_module=Linear1D_Col, - kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), ), SubModuleReplacementDescription( suffix="mlp.down_proj", target_module=Linear1D_Row, - kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), + ), + ], + ) + elif use_zbv: + policy[CohereDecoderLayer] = ModulePolicyDescription( + attribute_replacement=decoder_attribute_replacement, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attn.q_proj", + target_module=LinearWithGradAccum, + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), + ), + SubModuleReplacementDescription( + suffix="self_attn.k_proj", + target_module=LinearWithGradAccum, + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), + ), + SubModuleReplacementDescription( + suffix="self_attn.v_proj", + target_module=LinearWithGradAccum, + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), + ), + SubModuleReplacementDescription( + suffix="self_attn.o_proj", + target_module=LinearWithGradAccum, + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), + ), + SubModuleReplacementDescription( + suffix="mlp.gate_proj", + target_module=LinearWithGradAccum, + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), + ), + SubModuleReplacementDescription( + suffix="mlp.up_proj", + target_module=LinearWithGradAccum, + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), + ), + SubModuleReplacementDescription( + suffix="mlp.down_proj", + target_module=LinearWithGradAccum, + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), ), ], ) - if embedding_cls is not None: self.append_or_create_submodule_replacement( description=SubModuleReplacementDescription( @@ -258,7 +357,9 @@ def get_held_layers(self) -> List[Module]: held_layers.append(module.embed_tokens) for start_idx, end_idx in stage_indices: held_layers.extend(module.layers[start_idx:end_idx]) - if stage_manager.is_last_stage(ignore_chunk=True): + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) + ): held_layers.append(module.norm) else: @@ -351,8 +452,14 @@ def get_held_layers(self) -> List[Module]: """Get pipeline layers for current stage.""" stage_manager = self.pipeline_stage_manager held_layers = super().get_held_layers() - if stage_manager.is_last_stage(ignore_chunk=True): - held_layers.append(self.model.lm_head) + if stage_manager.is_interleave: + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) + ): + held_layers.append(self.model.lm_head) + else: + if stage_manager.is_last_stage(ignore_chunk=True): + held_layers.append(self.model.lm_head) return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: diff --git a/colossalai/shardformer/policies/deepseek.py b/colossalai/shardformer/policies/deepseek.py index bd54e6f2db9e..9baf068aec9f 100644 --- a/colossalai/shardformer/policies/deepseek.py +++ b/colossalai/shardformer/policies/deepseek.py @@ -6,7 +6,7 @@ from torch.nn import Module from transformers.utils import is_flash_attn_greater_or_equal_2_10 -from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col +from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, LinearWithGradAccum from colossalai.shardformer.layer.embedding import PaddingEmbedding, VocabParallelEmbedding1D from colossalai.shardformer.layer.linear import Linear1D_Row from colossalai.shardformer.modeling.deepseek import ( @@ -107,6 +107,8 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: if self.tie_weight: embedding_cls = PaddingEmbedding + use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv + if self.shard_config.enable_tensor_parallelism: # tensor parallelism for non-moe params assert ( @@ -133,22 +135,58 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: SubModuleReplacementDescription( suffix="self_attn.q_proj", target_module=Linear1D_Col, - kwargs={"fp8_communication": self.shard_config.fp8_communication}, + kwargs={"fp8_communication": self.shard_config.fp8_communication, "use_zbv": use_zbv}, ), SubModuleReplacementDescription( suffix="self_attn.k_proj", target_module=Linear1D_Col, - kwargs={"fp8_communication": self.shard_config.fp8_communication}, + kwargs={"fp8_communication": self.shard_config.fp8_communication, "use_zbv": use_zbv}, ), SubModuleReplacementDescription( suffix="self_attn.v_proj", target_module=Linear1D_Col, - kwargs={"fp8_communication": self.shard_config.fp8_communication}, + kwargs={"fp8_communication": self.shard_config.fp8_communication, "use_zbv": use_zbv}, ), SubModuleReplacementDescription( suffix="self_attn.o_proj", target_module=Linear1D_Row, - kwargs={"fp8_communication": self.shard_config.fp8_communication}, + kwargs={"fp8_communication": self.shard_config.fp8_communication, "use_zbv": use_zbv}, + ), + SubModuleReplacementDescription( + suffix="mlp.gate", + target_module=DeepseekMoEGate_Col, + kwargs={ + "gather_output": True, + "fp8_communication": self.shard_config.fp8_communication, + "config": self.model.config, + }, + ignore_if_not_exist=True, + ), + ], + ) + elif use_zbv: + policy["DeepseekDecoderLayer"] = ModulePolicyDescription( + attribute_replacement=decoder_attribute_replacement, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attn.q_proj", + target_module=LinearWithGradAccum, + kwargs={"fp8_communication": self.shard_config.fp8_communication, "use_zbv": use_zbv}, + ), + SubModuleReplacementDescription( + suffix="self_attn.k_proj", + target_module=LinearWithGradAccum, + kwargs={"fp8_communication": self.shard_config.fp8_communication, "use_zbv": use_zbv}, + ), + SubModuleReplacementDescription( + suffix="self_attn.v_proj", + target_module=LinearWithGradAccum, + kwargs={"fp8_communication": self.shard_config.fp8_communication, "use_zbv": use_zbv}, + ), + SubModuleReplacementDescription( + suffix="self_attn.o_proj", + target_module=LinearWithGradAccum, + kwargs={"fp8_communication": self.shard_config.fp8_communication, "use_zbv": use_zbv}, ), SubModuleReplacementDescription( suffix="mlp.gate", @@ -162,7 +200,6 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: ), ], ) - if embedding_cls is not None: self.append_or_create_submodule_replacement( description=SubModuleReplacementDescription( @@ -291,13 +328,26 @@ def get_held_layers(self) -> List[Module]: stage_manager = self.pipeline_stage_manager held_layers = [] - layers_per_stage = stage_manager.distribute_layers(len(module.layers)) - if stage_manager.is_first_stage(): - held_layers.append(module.embed_tokens) - start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage) - held_layers.extend(module.layers[start_idx:end_idx]) - if stage_manager.is_last_stage(): - held_layers.append(module.norm) + if stage_manager.is_interleave: + assert stage_manager.num_model_chunks is not None + layers_per_stage = stage_manager.distribute_layers(len(module.layers)) + stage_indices = stage_manager.get_stage_index(layers_per_stage) + if stage_manager.is_first_stage(ignore_chunk=True): + held_layers.append(module.embed_tokens) + for start_idx, end_idx in stage_indices: + held_layers.extend(module.layers[start_idx:end_idx]) + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) + ): + held_layers.append(module.norm) + else: + layers_per_stage = stage_manager.distribute_layers(len(module.layers)) + if stage_manager.is_first_stage(): + held_layers.append(module.embed_tokens) + start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage) + held_layers.extend(module.layers[start_idx:end_idx]) + if stage_manager.is_last_stage(): + held_layers.append(module.norm) return held_layers @@ -330,6 +380,7 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: class DeepseekForCausalLMPolicy(DeepseekPolicy): def module_policy(self): policy = super().module_policy() + use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv # TODO: assign pg mesh from plugin to all modules if self.shard_config.enable_tensor_parallelism: # add a new item for casual lm @@ -339,7 +390,29 @@ def module_policy(self): SubModuleReplacementDescription( suffix="lm_head", target_module=Linear1D_Col, - kwargs=dict(gather_output=True, fp8_communication=self.shard_config.fp8_communication), + kwargs=dict( + gather_output=True, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), + ) + ] + ) + } + policy.update(new_item) + elif use_zbv: + # add a new item for casual lm + new_item = { + "DeepseekForCausalLM": ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="lm_head", + target_module=LinearWithGradAccum, + kwargs=dict( + gather_output=True, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), ) ] ) @@ -360,8 +433,14 @@ def get_held_layers(self) -> List[Module]: """Get pipeline layers for current stage.""" stage_manager = self.pipeline_stage_manager held_layers = super().get_held_layers() - if stage_manager.is_last_stage(): - held_layers.append(self.model.lm_head) + if stage_manager.is_interleave: + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) + ): + held_layers.append(self.model.lm_head) + else: + if stage_manager.is_last_stage(): + held_layers.append(self.model.lm_head) return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: From 41e1972496c43515c9bfcdbca50ecaaaaf760056 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Fri, 22 Nov 2024 05:38:14 +0000 Subject: [PATCH 03/54] [feat] support zbv in shardformer policy: falcon,gptj,mistral,opt,qwen2,t5, vit, whisper --- colossalai/shardformer/policies/falcon.py | 179 ++++++++++++-- colossalai/shardformer/policies/gptj.py | 133 +++++++++-- colossalai/shardformer/policies/mistral.py | 25 +- colossalai/shardformer/policies/opt.py | 123 +++++++++- colossalai/shardformer/policies/qwen2.py | 180 ++++++++++++-- colossalai/shardformer/policies/t5.py | 249 ++++++++++++++++--- colossalai/shardformer/policies/vit.py | 178 ++++++++++++-- colossalai/shardformer/policies/whisper.py | 263 ++++++++++++++++++--- 8 files changed, 1190 insertions(+), 140 deletions(-) diff --git a/colossalai/shardformer/policies/falcon.py b/colossalai/shardformer/policies/falcon.py index e20fb1568505..68a548aee869 100644 --- a/colossalai/shardformer/policies/falcon.py +++ b/colossalai/shardformer/policies/falcon.py @@ -51,6 +51,8 @@ def module_policy(self): if self.tie_weight: embedding_cls = col_nn.PaddingEmbedding + use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv + if self.shard_config.enable_tensor_parallelism: assert ( self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0 @@ -73,10 +75,16 @@ def module_policy(self): SubModuleReplacementDescription( suffix="self_attention.query_key_value", target_module=col_nn.Linear1D_Col, + kwargs=dict( + use_zbv=use_zbv, + ), ), SubModuleReplacementDescription( suffix="self_attention.dense", target_module=col_nn.Linear1D_Row, + kwargs=dict( + use_zbv=use_zbv, + ), ), SubModuleReplacementDescription( suffix="self_attention.attention_dropout", @@ -85,8 +93,17 @@ def module_policy(self): SubModuleReplacementDescription( suffix="mlp.dense_h_to_4h", target_module=col_nn.Linear1D_Col, + kwargs=dict( + use_zbv=use_zbv, + ), + ), + SubModuleReplacementDescription( + suffix="mlp.dense_4h_to_h", + target_module=col_nn.Linear1D_Row, + kwargs=dict( + use_zbv=use_zbv, + ), ), - SubModuleReplacementDescription(suffix="mlp.dense_4h_to_h", target_module=col_nn.Linear1D_Row), ], ) @@ -98,6 +115,44 @@ def module_policy(self): "build_alibi_tensor": build_falcon_alibi_tensor_fn(self.shard_config.tensor_parallel_process_group) }, ) + elif use_zbv: + policy[FalconDecoderLayer] = ModulePolicyDescription( + method_replacement={"forward": get_tp_falcon_decoder_layer_forward()}, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attention.query_key_value", + target_module=col_nn.LinearWithGradAccum, + kwargs=dict( + use_zbv=use_zbv, + ), + ), + SubModuleReplacementDescription( + suffix="self_attention.dense", + target_module=col_nn.LinearWithGradAccum, + kwargs=dict( + use_zbv=use_zbv, + ), + ), + SubModuleReplacementDescription( + suffix="self_attention.attention_dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="mlp.dense_h_to_4h", + target_module=col_nn.LinearWithGradAccum, + kwargs=dict( + use_zbv=use_zbv, + ), + ), + SubModuleReplacementDescription( + suffix="mlp.dense_4h_to_h", + target_module=col_nn.LinearWithGradAccum, + kwargs=dict( + use_zbv=use_zbv, + ), + ), + ], + ) if embedding_cls is not None: self.append_or_create_submodule_replacement( @@ -191,13 +246,26 @@ def get_held_layers(self) -> List[Module]: module = self.model.transformer stage_manager = self.pipeline_stage_manager held_layers = [] - layers_per_stage = stage_manager.distribute_layers(len(module.h)) - if stage_manager.is_first_stage(): - held_layers.append(module.word_embeddings) - start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage) - held_layers.extend(module.h[start_idx:end_idx]) - if stage_manager.is_last_stage(): - held_layers.append(module.ln_f) + if stage_manager.is_interleave: + assert stage_manager.num_model_chunks is not None + layers_per_stage = stage_manager.distribute_layers(len(module.h)) + stage_indices = stage_manager.get_stage_index(layers_per_stage) + if stage_manager.is_first_stage(ignore_chunk=True): + held_layers.append(module.word_embeddings) + for start_idx, end_idx in stage_indices: + held_layers.extend(module.h[start_idx:end_idx]) + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) + ): + held_layers.append(module.ln_f) + else: + layers_per_stage = stage_manager.distribute_layers(len(module.h)) + if stage_manager.is_first_stage(): + held_layers.append(module.word_embeddings) + start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage) + held_layers.extend(module.h[start_idx:end_idx]) + if stage_manager.is_last_stage(): + held_layers.append(module.ln_f) return held_layers @@ -281,8 +349,14 @@ def get_held_layers(self) -> List[Module]: """Get pipeline layers for current stage.""" stage_manager = self.pipeline_stage_manager held_layers = super().get_held_layers() - if stage_manager.is_last_stage(): - held_layers.append(self.model.lm_head) + if stage_manager.is_interleave: + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) + ): + held_layers.append(self.model.lm_head) + else: + if stage_manager.is_last_stage(): + held_layers.append(self.model.lm_head) return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: @@ -308,11 +382,23 @@ def module_policy(self): policy = super().module_policy() + use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv + # handle tensor parallelism if self.shard_config.enable_tensor_parallelism: self.append_or_create_submodule_replacement( description=SubModuleReplacementDescription( - suffix="score", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True) + suffix="score", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True, use_zbv=use_zbv) + ), + policy=policy, + target_key=FalconForSequenceClassification, + ) + elif use_zbv: + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="score", + target_module=col_nn.LinearWithGradAccum, + kwargs=dict(gather_output=True, use_zbv=use_zbv), ), policy=policy, target_key=FalconForSequenceClassification, @@ -330,8 +416,14 @@ def get_held_layers(self) -> List[Module]: """Get pipeline layers for current stage.""" stage_manager = self.pipeline_stage_manager held_layers = super().get_held_layers() - if stage_manager.is_last_stage(): - held_layers.append(self.model.score) + if stage_manager.is_interleave: + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) + ): + held_layers.append(self.model.score) + else: + if stage_manager.is_last_stage(): + held_layers.append(self.model.score) return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: @@ -348,12 +440,32 @@ def module_policy(self): policy = super().module_policy() + use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv + # handle tensor parallelism if self.shard_config.enable_tensor_parallelism: self.append_or_create_submodule_replacement( description=[ SubModuleReplacementDescription( - suffix="classifier", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True) + suffix="classifier", + target_module=col_nn.Linear1D_Col, + kwargs=dict(gather_output=True, use_zbv=use_zbv), + ), + SubModuleReplacementDescription( + suffix="dropout", + target_module=col_nn.DropoutForReplicatedInput, + ), + ], + policy=policy, + target_key=FalconForTokenClassification, + ) + elif use_zbv: + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="classifier", + target_module=col_nn.LinearWithGradAccum, + kwargs=dict(gather_output=True, use_zbv=use_zbv), ), SubModuleReplacementDescription( suffix="dropout", @@ -375,9 +487,16 @@ def get_held_layers(self) -> List[Module]: """Get pipeline layers for current stage.""" stage_manager = self.pipeline_stage_manager held_layers = super().get_held_layers() - if stage_manager.is_last_stage(): - held_layers.append(self.model.dropout) - held_layers.append(self.model.classifier) + if stage_manager.is_interleave: + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) + ): + held_layers.append(self.model.dropout) + held_layers.append(self.model.classifier) + else: + if stage_manager.is_last_stage(): + held_layers.append(self.model.dropout) + held_layers.append(self.model.classifier) return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: @@ -394,11 +513,25 @@ def module_policy(self): policy = super().module_policy() + use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv + # handle tensor parallelism if self.shard_config.enable_tensor_parallelism: self.append_or_create_submodule_replacement( description=SubModuleReplacementDescription( - suffix="qa_outputs", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True) + suffix="qa_outputs", + target_module=col_nn.Linear1D_Col, + kwargs=dict(gather_output=True, use_zbv=use_zbv), + ), + policy=policy, + target_key=FalconForQuestionAnswering, + ) + elif use_zbv: + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="qa_outputs", + target_module=col_nn.Linear1D_Col, + kwargs=dict(gather_output=True, use_zbv=use_zbv), ), policy=policy, target_key=FalconForQuestionAnswering, @@ -415,8 +548,14 @@ def get_held_layers(self) -> List[Module]: """Get pipeline layers for current stage.""" held_layers = super().get_held_layers() stage_manager = self.pipeline_stage_manager - if stage_manager.is_last_stage(): - held_layers.append(self.model.qa_outputs) + if stage_manager.is_interleave: + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) + ): + held_layers.append(self.model.qa_outputs) + else: + if stage_manager.is_last_stage(): + held_layers.append(self.model.qa_outputs) return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: diff --git a/colossalai/shardformer/policies/gptj.py b/colossalai/shardformer/policies/gptj.py index 9fcca1385f79..891ebbdcc693 100644 --- a/colossalai/shardformer/policies/gptj.py +++ b/colossalai/shardformer/policies/gptj.py @@ -51,6 +51,8 @@ def module_policy(self): self.shard_config.enable_sequence_parallelism = False warnings.warn("GPTJ doesn't support sequence parallelism now, will ignore the sequence parallelism flag.") + use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv + if self.shard_config.enable_tensor_parallelism: assert ( self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0 @@ -76,6 +78,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -83,6 +86,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -90,6 +94,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -97,6 +102,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Row, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -104,6 +110,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -111,6 +118,72 @@ def module_policy(self): target_module=col_nn.Linear1D_Row, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="attn.attn_dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="attn.resid_dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="mlp.dropout", + target_module=col_nn.DropoutForParallelInput, + ), + ], + ) + elif use_zbv: + policy[GPTJBlock] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="attn.k_proj", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="attn.q_proj", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="attn.v_proj", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="attn.out_proj", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="mlp.fc_in", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="mlp.fc_out", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -127,7 +200,6 @@ def module_policy(self): ), ], ) - if embedding_cls is not None: self.append_or_create_submodule_replacement( description=SubModuleReplacementDescription( @@ -200,13 +272,25 @@ def get_held_layers(self) -> List[nn.Module]: held_layers = [] layers_per_stage = stage_manager.distribute_layers(len(module.h)) - if stage_manager.is_first_stage(): - held_layers.append(module.wte) - held_layers.append(module.drop) - start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage) - held_layers.extend(module.h[start_idx:end_idx]) - if stage_manager.is_last_stage(): - held_layers.append(module.ln_f) + if stage_manager.is_interleave: + if stage_manager.is_first_stage(ignore_chunk=True): + held_layers.append(module.wte) + held_layers.append(module.drop) + stage_indices = stage_manager.get_stage_index(layers_per_stage) + for start_idx, end_idx in stage_indices: + held_layers.extend(module.h[start_idx:end_idx]) + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) + ): + held_layers.append(module.ln_f) + else: + if stage_manager.is_first_stage(): + held_layers.append(module.wte) + held_layers.append(module.drop) + start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage) + held_layers.extend(module.h[start_idx:end_idx]) + if stage_manager.is_last_stage(): + held_layers.append(module.ln_f) return held_layers def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: @@ -309,8 +393,15 @@ def module_policy(self): def get_held_layers(self) -> List[nn.Module]: held_layers = super().get_held_layers() - if self.pipeline_stage_manager.is_last_stage(): - held_layers.append(self.model.lm_head) + stage_manager = self.pipeline_stage_manager + if stage_manager.is_interleave: + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) + ): + held_layers.append(self.model.lm_head) + else: + if self.pipeline_stage_manager.is_last_stage(): + held_layers.append(self.model.lm_head) return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: @@ -349,8 +440,15 @@ def module_policy(self): def get_held_layers(self) -> List[nn.Module]: held_layers = super().get_held_layers() - if self.pipeline_stage_manager.is_last_stage(): - held_layers.append(self.model.score) + stage_manager = self.pipeline_stage_manager + if stage_manager.is_interleave: + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) + ): + held_layers.append(self.model.score) + else: + if self.pipeline_stage_manager.is_last_stage(): + held_layers.append(self.model.score) return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: @@ -378,8 +476,15 @@ def module_policy(self): def get_held_layers(self) -> List[nn.Module]: held_layers = super().get_held_layers() - if self.pipeline_stage_manager.is_last_stage(): - held_layers.append(self.model.qa_outputs) + stage_manager = self.pipeline_stage_manager + if stage_manager.is_interleave: + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) + ): + held_layers.append(self.model.qa_outputs) + else: + if self.pipeline_stage_manager.is_last_stage(): + held_layers.append(self.model.qa_outputs) return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: diff --git a/colossalai/shardformer/policies/mistral.py b/colossalai/shardformer/policies/mistral.py index b4b87df923a3..f9c9a9404e72 100644 --- a/colossalai/shardformer/policies/mistral.py +++ b/colossalai/shardformer/policies/mistral.py @@ -324,9 +324,10 @@ def get_held_layers(self) -> List[Module]: held_layers.append(module.embed_tokens) for start_idx, end_idx in stage_indices: held_layers.extend(module.layers[start_idx:end_idx]) - if stage_manager.is_last_stage(ignore_chunk=True): + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) + ): held_layers.append(module.norm) - else: layers_per_stage = stage_manager.distribute_layers(len(module.layers)) if stage_manager.is_first_stage(): @@ -419,8 +420,14 @@ def get_held_layers(self) -> List[Module]: """Get pipeline layers for current stage.""" stage_manager = self.pipeline_stage_manager held_layers = super().get_held_layers() - if stage_manager.is_last_stage(ignore_chunk=True): - held_layers.append(self.model.lm_head) + if stage_manager.is_interleave: + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) + ): + held_layers.append(self.model.lm_head) + else: + if stage_manager.is_last_stage(): + held_layers.append(self.model.lm_head) return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: @@ -475,8 +482,14 @@ def get_held_layers(self) -> List[Module]: """Get pipeline layers for current stage.""" stage_manager = self.pipeline_stage_manager held_layers = super().get_held_layers() - if stage_manager.is_last_stage(ignore_chunk=True): - held_layers.append(self.model.score) + if stage_manager.is_interleave: + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) + ): + held_layers.append(self.model.score) + else: + if stage_manager.is_last_stage(): + held_layers.append(self.model.score) return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py index dd64ce652f86..50742b850b24 100644 --- a/colossalai/shardformer/policies/opt.py +++ b/colossalai/shardformer/policies/opt.py @@ -10,6 +10,7 @@ LayerNorm, Linear1D_Col, Linear1D_Row, + LinearWithGradAccum, PaddingEmbedding, PaddingLMHead, VocabParallelEmbedding1D, @@ -76,6 +77,8 @@ def module_policy(self): self.shard_config.enable_sequence_parallelism = False warnings.warn("OPT doesn't support sequence parallelism now, will ignore the sequence parallelism flag.") + use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv + if self.shard_config.enable_tensor_parallelism: assert ( self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0 @@ -85,10 +88,16 @@ def module_policy(self): SubModuleReplacementDescription( suffix="fc1", target_module=Linear1D_Col, + kwargs=dict( + use_zbv=use_zbv, + ), ), SubModuleReplacementDescription( suffix="fc2", target_module=Linear1D_Row, + kwargs=dict( + use_zbv=use_zbv, + ), ), ] ) @@ -104,6 +113,7 @@ def module_policy(self): target_module=Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -111,6 +121,7 @@ def module_policy(self): target_module=Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -118,6 +129,7 @@ def module_policy(self): target_module=Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -125,11 +137,67 @@ def module_policy(self): target_module=Linear1D_Row, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), ], ) + elif use_zbv: + policy[OPTDecoderLayer] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="fc1", + target_module=LinearWithGradAccum, + kwargs=dict( + use_zbv=use_zbv, + ), + ), + SubModuleReplacementDescription( + suffix="fc2", + target_module=LinearWithGradAccum, + kwargs=dict( + use_zbv=use_zbv, + ), + ), + ] + ) + policy[attn_cls] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="q_proj", + target_module=LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="k_proj", + target_module=LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="v_proj", + target_module=LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="out_proj", + target_module=LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + ], + ) if embedding_cls is not None: self.append_or_create_submodule_replacement( description=SubModuleReplacementDescription( @@ -221,15 +289,30 @@ def get_held_layers(self) -> List[nn.Module]: held_layers = [] layers_per_stage = stage_manager.distribute_layers(len(module.layers)) - if stage_manager.is_first_stage(): - held_layers.append(module.embed_tokens) - held_layers.append(module.embed_positions) - held_layers.append(module.project_in) - start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage) - held_layers.extend(module.layers[start_idx:end_idx]) - if stage_manager.is_last_stage(): - held_layers.append(module.final_layer_norm) - held_layers.append(module.project_out) + if stage_manager.is_interleave: + assert stage_manager.num_model_chunks is not None + if stage_manager.is_first_stage(ignore_chunk=True): + held_layers.append(module.embed_tokens) + held_layers.append(module.embed_positions) + held_layers.append(module.project_in) + stage_indices = stage_manager.get_stage_index(layers_per_stage) + for start_idx, end_idx in stage_indices: + held_layers.extend(module.layers[start_idx:end_idx]) + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) + ): + held_layers.append(module.final_layer_norm) + held_layers.append(module.project_out) + else: + if stage_manager.is_first_stage(): + held_layers.append(module.embed_tokens) + held_layers.append(module.embed_positions) + held_layers.append(module.project_in) + start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage) + held_layers.extend(module.layers[start_idx:end_idx]) + if stage_manager.is_last_stage(): + held_layers.append(module.final_layer_norm) + held_layers.append(module.project_out) return held_layers def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: @@ -323,8 +406,15 @@ def module_policy(self): def get_held_layers(self) -> List[nn.Module]: held_layers = super().get_held_layers() - if self.pipeline_stage_manager.is_last_stage(): - held_layers.append(self.model.lm_head) + stage_manager = self.pipeline_stage_manager + if stage_manager.is_interleave: + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) + ): + held_layers.append(self.model.lm_head) + else: + if self.pipeline_stage_manager.is_last_stage(): + held_layers.append(self.model.lm_head) return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: @@ -395,8 +485,15 @@ def module_policy(self): def get_held_layers(self) -> List[nn.Module]: held_layers = super().get_held_layers() - if self.pipeline_stage_manager.is_last_stage(): - held_layers.append(self.model.qa_outputs) + stage_manager = self.pipeline_stage_manager + if stage_manager.is_interleave: + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) + ): + held_layers.append(self.model.qa_outputs) + else: + if self.pipeline_stage_manager.is_last_stage(): + held_layers.append(self.model.qa_outputs) return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: diff --git a/colossalai/shardformer/policies/qwen2.py b/colossalai/shardformer/policies/qwen2.py index 1b066200de64..84d2b2fdbd99 100644 --- a/colossalai/shardformer/policies/qwen2.py +++ b/colossalai/shardformer/policies/qwen2.py @@ -9,6 +9,7 @@ FusedRMSNorm, Linear1D_Col, Linear1D_Row, + LinearWithGradAccum, PaddingEmbedding, RMSNorm, VocabParallelEmbedding1D, @@ -96,6 +97,8 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: attribute_replacement=decoder_attribute_replacement, ) + use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv + if self.shard_config.enable_tensor_parallelism: assert ( self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0 @@ -119,37 +122,134 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: SubModuleReplacementDescription( suffix="self_attn.q_proj", target_module=Linear1D_Col, - kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), ), SubModuleReplacementDescription( suffix="self_attn.k_proj", target_module=Linear1D_Col, - kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), ), SubModuleReplacementDescription( suffix="self_attn.v_proj", target_module=Linear1D_Col, - kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), ), SubModuleReplacementDescription( suffix="self_attn.o_proj", target_module=Linear1D_Row, - kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), ), SubModuleReplacementDescription( suffix="mlp.gate_proj", target_module=Linear1D_Col, - kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), ), SubModuleReplacementDescription( suffix="mlp.up_proj", target_module=Linear1D_Col, - kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), ), SubModuleReplacementDescription( suffix="mlp.down_proj", target_module=Linear1D_Row, - kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), + ), + ], + ) + elif use_zbv: + policy[Qwen2DecoderLayer] = ModulePolicyDescription( + attribute_replacement=decoder_attribute_replacement, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attn.q_proj", + target_module=LinearWithGradAccum, + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), + ), + SubModuleReplacementDescription( + suffix="self_attn.k_proj", + target_module=LinearWithGradAccum, + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), + ), + SubModuleReplacementDescription( + suffix="self_attn.v_proj", + target_module=LinearWithGradAccum, + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), + ), + SubModuleReplacementDescription( + suffix="self_attn.o_proj", + target_module=LinearWithGradAccum, + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), + ), + SubModuleReplacementDescription( + suffix="mlp.gate_proj", + target_module=LinearWithGradAccum, + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), + ), + SubModuleReplacementDescription( + suffix="mlp.up_proj", + target_module=LinearWithGradAccum, + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), + ), + SubModuleReplacementDescription( + suffix="mlp.down_proj", + target_module=LinearWithGradAccum, + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), ), ], ) @@ -278,7 +378,9 @@ def get_held_layers(self) -> List[Module]: held_layers.append(module.embed_tokens) for start_idx, end_idx in stage_indices: held_layers.extend(module.layers[start_idx:end_idx]) - if stage_manager.is_last_stage(ignore_chunk=True): + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) + ): held_layers.append(module.norm) else: @@ -318,6 +420,7 @@ class Qwen2ForCausalLMPolicy(Qwen2Policy): def module_policy(self): policy = super().module_policy() setattr(self.shard_config, "causal_lm", True) + use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv if self.shard_config.enable_tensor_parallelism: # add a new item for casual lm @@ -327,7 +430,22 @@ def module_policy(self): SubModuleReplacementDescription( suffix="lm_head", target_module=Linear1D_Col, - kwargs=dict(fp8_communication=self.shard_config.fp8_communication), + kwargs=dict(fp8_communication=self.shard_config.fp8_communication, use_zbv=use_zbv), + ) + ], + method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)}, + ) + } + policy.update(new_item) + elif use_zbv: + # add a new item for casual lm + new_item = { + Qwen2ForCausalLM: ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="lm_head", + target_module=LinearWithGradAccum, + kwargs=dict(fp8_communication=self.shard_config.fp8_communication, use_zbv=use_zbv), ) ], method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)}, @@ -347,8 +465,14 @@ def get_held_layers(self) -> List[Module]: """Get pipeline layers for current stage.""" stage_manager = self.pipeline_stage_manager held_layers = super().get_held_layers() - if stage_manager.is_last_stage(ignore_chunk=True): - held_layers.append(self.model.lm_head) + if stage_manager.is_interleave: + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) + ): + held_layers.append(self.model.lm_head) + else: + if stage_manager.is_last_stage(ignore_chunk=True): + held_layers.append(self.model.lm_head) return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: @@ -371,6 +495,7 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: class Qwen2ForSequenceClassificationPolicy(Qwen2Policy): def module_policy(self): policy = super().module_policy() + use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv if self.shard_config.enable_tensor_parallelism: # add a new item for sequence classification new_item = { @@ -379,7 +504,28 @@ def module_policy(self): SubModuleReplacementDescription( suffix="score", target_module=Linear1D_Col, - kwargs=dict(gather_output=True, fp8_communication=self.shard_config.fp8_communication), + kwargs=dict( + gather_output=True, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), + ) + ] + ) + } + policy.update(new_item) + elif use_zbv: + new_item = { + Qwen2ForSequenceClassification: ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="score", + target_module=LinearWithGradAccum, + kwargs=dict( + gather_output=True, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), ) ] ) @@ -399,8 +545,14 @@ def get_held_layers(self) -> List[Module]: """Get pipeline layers for current stage.""" stage_manager = self.pipeline_stage_manager held_layers = super().get_held_layers() - if stage_manager.is_last_stage(ignore_chunk=True): - held_layers.append(self.model.score) + if stage_manager.is_interleave: + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) + ): + held_layers.append(self.model.score) + else: + if stage_manager.is_last_stage(ignore_chunk=True): + held_layers.append(self.model.score) return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: diff --git a/colossalai/shardformer/policies/t5.py b/colossalai/shardformer/policies/t5.py index 84b5d95947f0..6320a1668b09 100644 --- a/colossalai/shardformer/policies/t5.py +++ b/colossalai/shardformer/policies/t5.py @@ -13,6 +13,7 @@ FusedRMSNorm, Linear1D_Col, Linear1D_Row, + LinearWithGradAccum, PaddingEmbedding, PaddingLMHead, RMSNorm, @@ -77,6 +78,8 @@ def module_policy(self): self.shard_config.enable_sequence_parallelism = False warnings.warn("T5 doesn't support sequence parallelism now, will ignore the sequence parallelism flag.") + use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv + if self.shard_config.enable_tensor_parallelism: assert ( self.model.config.num_heads % self.shard_config.tensor_parallel_size == 0 @@ -119,6 +122,7 @@ def module_policy(self): target_module=Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -126,6 +130,7 @@ def module_policy(self): target_module=Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -133,6 +138,7 @@ def module_policy(self): target_module=Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -140,6 +146,7 @@ def module_policy(self): target_module=Linear1D_Row, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -168,6 +175,7 @@ def module_policy(self): target_module=Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -175,6 +183,7 @@ def module_policy(self): target_module=Linear1D_Row, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -183,6 +192,7 @@ def module_policy(self): kwargs=dict( gather_output=True, fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, ), ), SubModuleReplacementDescription( @@ -198,6 +208,7 @@ def module_policy(self): target_module=Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -205,6 +216,142 @@ def module_policy(self): target_module=Linear1D_Row, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="dropout", + target_module=DropoutForParallelInput, + ), + ] + ) + elif use_zbv: + policy[T5Stack] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="dropout", + target_module=DropoutForParallelInput, + ), + ] + ) + policy[T5LayerSelfAttention] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="dropout", + target_module=DropoutForParallelInput, + ), + ] + ) + policy[T5LayerCrossAttention] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="dropout", + target_module=DropoutForParallelInput, + ) + ] + ) + policy[T5Attention] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="q", + target_module=LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="k", + target_module=LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="v", + target_module=LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="o", + target_module=LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="relative_attention_bias", + target_module=Embedding1D, + kwargs=dict( + gather_output=False, + fp8_communication=self.shard_config.fp8_communication, + ), + ignore_if_not_exist=True, + ), + ], + ) + policy[T5LayerFF] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="dropout", + target_module=DropoutForParallelInput, + ), + ] + ) + policy[T5DenseGatedActDense] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="wi_0 ", + target_module=LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="wi_1", + target_module=LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="wo", + target_module=LinearWithGradAccum, + kwargs=dict( + gather_output=True, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), + ), + SubModuleReplacementDescription( + suffix="dropout", + target_module=DropoutForParallelInput, + ), + ] + ) + policy[T5DenseActDense] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="wi", + target_module=LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="wo", + target_module=LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -213,7 +360,6 @@ def module_policy(self): ), ] ) - if embedding_cls is not None: self.append_or_create_submodule_replacement( description=SubModuleReplacementDescription( @@ -369,30 +515,61 @@ def get_held_layers(self) -> List[nn.Module]: num_decoder_layers = len(decoder.block) if decoder else 0 held_layers = [] - layers_per_stage, decoder_starting_stage = self.distribute_t5_layers( - num_encoder_layers, num_decoder_layers, stage_manager.num_stages - ) - start_idx, end_idx = self.get_t5_stage_index(layers_per_stage, stage_manager.stage, decoder_starting_stage) - - if stage_manager.stage < decoder_starting_stage: - # current stage is in t5's encoder - if stage_manager.is_first_stage(): - held_layers.append(model.shared) - held_layers.append(encoder.embed_tokens) - held_layers.append(encoder.dropout) - if stage_manager.stage == decoder_starting_stage - 1: - held_layers.append(encoder.final_layer_norm) - held_layers.append(encoder.dropout) - held_layers.extend(encoder.block[start_idx:end_idx]) + if stage_manager.is_interleave: + layers_per_stage, decoder_starting_stage = self.distribute_t5_layers( + num_encoder_layers, num_decoder_layers, stage_manager.num_stages + ) + stage_indices = self.get_t5_stage_index(layers_per_stage, stage_manager.stage, decoder_starting_stage) + if stage_manager.stage < decoder_starting_stage: + # current stage is in t5's encoder + if stage_manager.is_first_stage(): + held_layers.append(model.shared) + held_layers.append(encoder.embed_tokens) + held_layers.append(encoder.dropout) + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) + ): + held_layers.append(encoder.final_layer_norm) + held_layers.append(encoder.dropout) + for start_idx, end_idx in stage_indices: + held_layers.extend(encoder.block[start_idx:end_idx]) + else: + # current stage is in t5's decoder + if stage_manager.stage == decoder_starting_stage: + held_layers.append(decoder.embed_tokens) + held_layers.append(decoder.dropout) + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) + ): + held_layers.append(decoder.final_layer_norm) + held_layers.append(decoder.dropout) + for start_idx, end_idx in stage_indices: + held_layers.extend(decoder.block[start_idx:end_idx]) else: - # current stage is in t5's decoder - if stage_manager.stage == decoder_starting_stage: - held_layers.append(decoder.embed_tokens) - held_layers.append(decoder.dropout) - if stage_manager.is_last_stage(): - held_layers.append(decoder.final_layer_norm) - held_layers.append(decoder.dropout) - held_layers.extend(decoder.block[start_idx:end_idx]) + layers_per_stage, decoder_starting_stage = self.distribute_t5_layers( + num_encoder_layers, num_decoder_layers, stage_manager.num_stages + ) + start_idx, end_idx = self.get_t5_stage_index(layers_per_stage, stage_manager.stage, decoder_starting_stage) + + if stage_manager.stage < decoder_starting_stage: + # current stage is in t5's encoder + if stage_manager.is_first_stage(): + held_layers.append(model.shared) + held_layers.append(encoder.embed_tokens) + held_layers.append(encoder.dropout) + if stage_manager.stage == decoder_starting_stage - 1: + held_layers.append(encoder.final_layer_norm) + held_layers.append(encoder.dropout) + held_layers.extend(encoder.block[start_idx:end_idx]) + else: + # current stage is in t5's decoder + if stage_manager.stage == decoder_starting_stage: + held_layers.append(decoder.embed_tokens) + held_layers.append(decoder.dropout) + if stage_manager.is_last_stage(): + held_layers.append(decoder.final_layer_norm) + held_layers.append(decoder.dropout) + held_layers.extend(decoder.block[start_idx:end_idx]) return held_layers def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: @@ -545,8 +722,15 @@ def module_policy(self): def get_held_layers(self) -> List[nn.Module]: held_layers = super().get_held_layers() - if self.pipeline_stage_manager.is_last_stage(): - held_layers.append(self.model.lm_head) + stage_manager = self.pipeline_stage_manager + if stage_manager.is_interleave: + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) + ): + held_layers.append(self.model.lm_head) + else: + if stage_manager.is_last_stage(): + held_layers.append(self.model.lm_head) return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: @@ -652,9 +836,16 @@ def get_held_layers(self) -> List[nn.Module]: """ held_layers = super().get_held_layers() stage_manager = self.pipeline_stage_manager - if stage_manager.is_last_stage(ignore_chunk=True): - held_layers.append(self.model.dropout) - held_layers.append(self.model.classifier) + if stage_manager.is_interleave: + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) + ): + held_layers.append(self.model.dropout) + held_layers.append(self.model.classifier) + else: + if stage_manager.is_last_stage(ignore_chunk=True): + held_layers.append(self.model.dropout) + held_layers.append(self.model.classifier) return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: diff --git a/colossalai/shardformer/policies/vit.py b/colossalai/shardformer/policies/vit.py index 07202094f1f3..7b7dbf5557aa 100644 --- a/colossalai/shardformer/policies/vit.py +++ b/colossalai/shardformer/policies/vit.py @@ -43,6 +43,8 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: self.shard_config.enable_sequence_parallelism = False warnings.warn("Vit doesn't support sequence parallelism now, will ignore the sequence parallelism flag.") + use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv + if self.shard_config.enable_tensor_parallelism: assert ( self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0 @@ -72,6 +74,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: target_module=col_nn.Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -79,6 +82,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: target_module=col_nn.Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -86,6 +90,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: target_module=col_nn.Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -97,6 +102,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: target_module=col_nn.Linear1D_Row, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -109,6 +115,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: kwargs={ "skip_bias_add": self.enable_bias_gelu_fused, "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -116,6 +123,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: target_module=col_nn.Linear1D_Row, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -132,7 +140,92 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: policy=policy, target_key=ViTIntermediate, ) + elif use_zbv: + policy[ViTEmbeddings] = ModulePolicyDescription( + attribute_replacement={}, + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="dropout", + target_module=DropoutForReplicatedInput, + ) + ], + ) + policy[ViTLayer] = ModulePolicyDescription( + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="attention.attention.query", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="attention.attention.key", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="attention.attention.value", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="attention.attention.dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="attention.output.dense", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="attention.output.dropout", + target_module=col_nn.DropoutForReplicatedInput, + ), + SubModuleReplacementDescription( + suffix="intermediate.dense", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "skip_bias_add": self.enable_bias_gelu_fused, + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="output.dense", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="output.dropout", + target_module=col_nn.DropoutForReplicatedInput, + ), + ], + ) + if self.enable_bias_gelu_fused: + self.append_or_create_method_replacement( + description={ + "forward": get_jit_fused_vit_intermediate_forward(), + }, + policy=policy, + target_key=ViTIntermediate, + ) # use flash attention if self.shard_config.enable_flash_attention: self.append_or_create_method_replacement( @@ -173,11 +266,20 @@ def get_held_layers(self) -> List[nn.Module]: stage_manager = self.pipeline_stage_manager held_layers = [] - layers_per_stage = stage_manager.distribute_layers(len(module.encoder.layer)) - if stage_manager.is_first_stage(): - held_layers.append(module.embeddings) - start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage) - held_layers.extend(module.encoder.layer[start_idx:end_idx]) + if stage_manager.is_interleave: + assert stage_manager.num_model_chunks is not None + layers_per_stage = stage_manager.distribute_layers(len(module.encoder.layer)) + stage_indices = stage_manager.get_stage_index(layers_per_stage) + if stage_manager.is_first_stage(ignore_chunk=True): + held_layers.append(module.embeddings) + for start_idx, end_idx in stage_indices: + held_layers.extend(module.encoder.layer[start_idx:end_idx]) + else: + layers_per_stage = stage_manager.distribute_layers(len(module.encoder.layer)) + if stage_manager.is_first_stage(): + held_layers.append(module.embeddings) + start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage) + held_layers.extend(module.encoder.layer[start_idx:end_idx]) return held_layers def set_pipeline_forward(self, model_cls: nn.Module, pipeline_forward: Callable, policy: Dict): @@ -213,9 +315,16 @@ def get_held_layers(self) -> List[nn.Module]: module = self.model stage_manager = self.pipeline_stage_manager - if stage_manager.is_last_stage(): - held_layers.append(module.layernorm) - held_layers.append(module.pooler) + if stage_manager.is_interleave: + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) + ): + held_layers.append(module.layernorm) + held_layers.append(module.pooler) + else: + if stage_manager.is_last_stage(): + held_layers.append(module.layernorm) + held_layers.append(module.pooler) return held_layers @@ -226,6 +335,9 @@ def module_policy(self): from transformers.models.vit.modeling_vit import ViTForImageClassification, ViTModel policy = super().module_policy() + + use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv + if self.shard_config.enable_tensor_parallelism: new_item = { ViTForImageClassification: ModulePolicyDescription( @@ -233,13 +345,33 @@ def module_policy(self): SubModuleReplacementDescription( suffix="classifier", target_module=Linear1D_Col, - kwargs=dict(gather_output=True, fp8_communication=self.shard_config.fp8_communication), + kwargs=dict( + gather_output=True, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), + ) + ] + ) + } + policy.update(new_item) + elif use_zbv: + new_item = { + ViTForImageClassification: ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="classifier", + target_module=col_nn.LinearWithGradAccum, + kwargs=dict( + gather_output=True, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), ) ] ) } policy.update(new_item) - if self.shard_config.pipeline_stage_manager is not None: self.set_pipeline_forward(model_cls=ViTModel, pipeline_forward=ViTModel_pipeline_forward, policy=policy) self.set_pipeline_forward( @@ -256,9 +388,16 @@ def get_held_layers(self) -> List[nn.Module]: module = self.model.vit stage_manager = self.pipeline_stage_manager - if stage_manager.is_last_stage(): - held_layers.append(module.layernorm) - held_layers.append(self.model.classifier) + if stage_manager.is_interleave: + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) + ): + held_layers.append(module.layernorm) + held_layers.append(self.model.classifier) + else: + if stage_manager.is_last_stage(): + held_layers.append(module.layernorm) + held_layers.append(self.model.classifier) return held_layers @@ -285,8 +424,15 @@ def get_held_layers(self) -> List[nn.Module]: module = self.model.vit stage_manager = self.pipeline_stage_manager - if stage_manager.is_last_stage(): - held_layers.append(module.layernorm) - held_layers.append(self.model.decoder) + if stage_manager.is_interleave: + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) + ): + held_layers.append(module.layernorm) + held_layers.append(self.model.decoder) + else: + if stage_manager.is_last_stage(): + held_layers.append(module.layernorm) + held_layers.append(self.model.decoder) return held_layers diff --git a/colossalai/shardformer/policies/whisper.py b/colossalai/shardformer/policies/whisper.py index 7a1f146d5bb8..5d9b38e6fe98 100644 --- a/colossalai/shardformer/policies/whisper.py +++ b/colossalai/shardformer/policies/whisper.py @@ -72,6 +72,8 @@ def module_policy(self): "Whisper doesn't support sequence parallelism now, will ignore the sequence parallelism flag." ) + use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv + # TODO using the jit fused add_and_dropout affect the accuracy if self.shard_config.enable_jit_fused: self.shard_config.enable_jit_fused = False @@ -93,6 +95,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -100,6 +103,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -107,6 +111,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -114,6 +119,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Row, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -121,6 +127,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -128,6 +135,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Row, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), ], @@ -148,6 +156,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -155,6 +164,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -162,6 +172,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -169,6 +180,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Row, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -176,6 +188,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -183,6 +196,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -190,6 +204,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -197,6 +212,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Row, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -204,6 +220,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -211,6 +228,145 @@ def module_policy(self): target_module=col_nn.Linear1D_Row, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + ], + ) + elif use_zbv: + policy[WhisperEncoderLayer] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attn.q_proj", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="self_attn.k_proj", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="self_attn.v_proj", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="self_attn.out_proj", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="fc1", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="fc2", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + ], + ) + + policy[WhisperDecoderLayer] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attn.q_proj", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="self_attn.k_proj", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="self_attn.v_proj", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="self_attn.out_proj", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="encoder_attn.q_proj", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="encoder_attn.k_proj", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="encoder_attn.v_proj", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="encoder_attn.out_proj", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="fc1", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="fc2", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), ], @@ -460,30 +616,66 @@ def get_held_layers(self) -> List[nn.Module]: num_decoder_layers = 0 held_layers = [] - layers_per_stage, decoder_starting_stage = self.distribute_whisper_layers( - num_encoder_layers, num_decoder_layers, stage_manager.num_stages - ) - start_idx, end_idx = self.get_whisper_stage_index(layers_per_stage, stage_manager.stage, decoder_starting_stage) - - if stage_manager.stage < decoder_starting_stage: - # current stage is in whisper's encoder - if stage_manager.is_first_stage(): - held_layers.append(encoder.embed_positions) - held_layers.append(encoder.conv1) - held_layers.append(encoder.conv2) - if stage_manager.stage == decoder_starting_stage - 1: - held_layers.append(encoder.layer_norm) - held_layers.extend(encoder.layers[start_idx:end_idx]) + if stage_manager.is_interleave: + layers_per_stage, decoder_starting_stage = self.distribute_whisper_layers( + num_encoder_layers, num_decoder_layers, stage_manager.num_stages + ) + stage_indices = self.get_whisper_stage_index(layers_per_stage, stage_manager.stage, decoder_starting_stage) + + if stage_manager.stage < decoder_starting_stage: + # current stage is in whisper's encoder + if stage_manager.is_first_stage(ignore_chunk=True): + held_layers.append(encoder.embed_positions) + held_layers.append(encoder.conv1) + held_layers.append(encoder.conv2) + # interleaved: not use_zbv & stage_manager.stage == decoder_starting_stage - 1 + # zbv: use_zbv & stage_manager.stage == first stage + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and decoder_starting_stage - 1 + ): + held_layers.append(encoder.layer_norm) + for start_idx, end_idx in stage_indices: + held_layers.extend(encoder.layers[start_idx:end_idx]) + else: + # current stage is in whisper's decoder + # TODO:(Jianghai) We divide encoder and decoder layers into different parts here, + # the case encoder and decoder put in same stage should be add in the future. + if stage_manager.stage == decoder_starting_stage: + held_layers.append(decoder.embed_tokens) + held_layers.append(decoder.embed_positions) + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) + ): + held_layers.append(decoder.layer_norm) + for start_idx, end_idx in stage_indices: + held_layers.extend(encoder.layers[start_idx:end_idx]) else: - # current stage is in whisper's decoder - # TODO:(Jianghai) We divide encoder and decoder layers into different parts here, - # the case encoder and decoder put in same stage should be add in the future. - if stage_manager.stage == decoder_starting_stage: - held_layers.append(decoder.embed_tokens) - held_layers.append(decoder.embed_positions) - if stage_manager.is_last_stage(): - held_layers.append(decoder.layer_norm) - held_layers.extend(decoder.layers[start_idx:end_idx]) + layers_per_stage, decoder_starting_stage = self.distribute_whisper_layers( + num_encoder_layers, num_decoder_layers, stage_manager.num_stages + ) + start_idx, end_idx = self.get_whisper_stage_index( + layers_per_stage, stage_manager.stage, decoder_starting_stage + ) + + if stage_manager.stage < decoder_starting_stage: + # current stage is in whisper's encoder + if stage_manager.is_first_stage(): + held_layers.append(encoder.embed_positions) + held_layers.append(encoder.conv1) + held_layers.append(encoder.conv2) + if stage_manager.stage == decoder_starting_stage - 1: + held_layers.append(encoder.layer_norm) + held_layers.extend(encoder.layers[start_idx:end_idx]) + else: + # current stage is in whisper's decoder + # TODO:(Jianghai) We divide encoder and decoder layers into different parts here, + # the case encoder and decoder put in same stage should be add in the future. + if stage_manager.stage == decoder_starting_stage: + held_layers.append(decoder.embed_tokens) + held_layers.append(decoder.embed_positions) + if stage_manager.is_last_stage(): + held_layers.append(decoder.layer_norm) + held_layers.extend(decoder.layers[start_idx:end_idx]) return held_layers def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: @@ -575,8 +767,15 @@ def postprocess(self): def get_held_layers(self) -> List[nn.Module]: held_layers = super().get_held_layers() - if self.pipeline_stage_manager.is_last_stage(): - held_layers.append(self.model.proj_out) + stage_manager = self.pipeline_stage_manager + if stage_manager.is_interleave: + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) + ): + held_layers.append(self.model.proj_out) + else: + if self.pipeline_stage_manager.is_last_stage(): + held_layers.append(self.model.proj_out) return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: @@ -629,9 +828,17 @@ def module_policy(self): def get_held_layers(self) -> List[nn.Module]: held_layers = super().get_held_layers() - if self.pipeline_stage_manager.is_last_stage(): - held_layers.append(self.model.projector) - held_layers.append(self.model.classifier) + stage_manager = self.pipeline_stage_manager + if stage_manager.is_interleave: + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) + ): + held_layers.append(self.model.projector) + held_layers.append(self.model.classifier) + else: + if self.pipeline_stage_manager.is_last_stage(): + held_layers.append(self.model.projector) + held_layers.append(self.model.classifier) return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: From efffe6bb61e177c3aff3c7962812d3c5ba058964 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Mon, 9 Dec 2024 14:07:41 +0800 Subject: [PATCH 04/54] [feat] support GPT2FusedLinearConv1D --- colossalai/shardformer/layer/_operation.py | 105 ++++++++++++ .../shardformer/layer/qkv_fused_linear.py | 157 ++++++++++++++++++ 2 files changed, 262 insertions(+) diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index 8c2e6e7c5d92..fbdd5018e17e 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -123,6 +123,107 @@ def backward(ctx, grad_output): return grad_input, grad_weight, grad_bias, None, None, None, None +class MatmulWithGradAccum(torch.autograd.Function): + """ + Linear layer execution with grad accum in backprop. (no tp version) + """ + + @staticmethod + def forward(ctx, input_, weight, bias, async_grad_allreduce, use_zbv=False): + ctx.save_for_backward(input_, weight, bias) + ctx.use_bias = bias is not None + ctx.async_grad_allreduce = async_grad_allreduce + ctx.use_zbv = use_zbv + + output = torch.matmul(input_, weight) + + if bias is not None: + output = output + bias + + return output + + @staticmethod + def backward(ctx, grad_output): + input, weight, bias = ctx.saved_tensors + use_bias = ctx.use_bias + use_zbv = ctx.use_zbv + + def execute_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_, wgrad_gemm_accum_func=None): + wgrad_gemm_accum_func(_input_, _grad_output_, _weight_main_grad_) + + def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_func=None): + return wgrad_gemm_func(_grad_output_.t(), _input_) + + # In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. + weight = weight.view(weight.shape) + if bias is not None: + bias = bias.view(bias.shape) + + total_input = input + grad_input = grad_output.matmul(weight.T) + grad_output = grad_output.contiguous() + # Convert the tensor shapes to 2D for execution compatibility + if len(grad_output.shape) > 2: + grad_output = grad_output.view(-1, grad_output.shape[-1]) + total_input = total_input.view(-1, total_input.shape[-1]) + + # split dx & dw + if weight.grad is not None: + grad = weight.grad + if use_zbv: + if grad.dtype == torch.float32: + WeightGradStore.put( + total_input, + grad_output, + weight, + functools.partial( + execute_w_pass_grad_accum, + wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32, + ), + ) + grad_weight = None + elif grad.dtype in (torch.float16, torch.bfloat16): + WeightGradStore.put( + total_input, + grad_output, + weight, + functools.partial( + execute_w_pass_grad_accum, + wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16, + ), + ) + grad_weight = None + else: + raise RuntimeError("Unsupported gradient type for gradient accumulation fusion") + else: + if grad.dtype == torch.float32: + fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad) + grad_weight = None + elif grad.dtype == torch.float16: + fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, grad) + grad_weight = None + else: + grad_weight = grad_output.t().matmul(total_input) + else: + if use_zbv: + WeightGradStore.put( + total_input, + grad_output, + weight, + functools.partial( + execute_w_pass, + wgrad_gemm_func=torch.matmul, + ), + ) + grad_weight = None + else: + grad_weight = grad_output.t().matmul(total_input) + + grad_bias = grad_output.sum(dim=0) if use_bias else None + + return grad_input, grad_weight, grad_bias, None, None, None, None + + class LinearWithAsyncCommunication(torch.autograd.Function): """ Linear layer execution with asynchronous communication in backprop. @@ -1113,6 +1214,10 @@ def matmul_with_async_comm(input_, weight, bias, process_group, async_grad_allre input_, weight, bias, process_group, async_grad_allreduce, fp8_communication ) +def matmul_with_grad_comm(input_, weight, bias, async_grad_allreduce, use_zbv=False): + return MatmulWithGradAccum.apply( + input_, weight, bias, async_grad_allreduce, use_zbv + ) def linear_with_async_comm( input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False, use_zbv=False diff --git a/colossalai/shardformer/layer/qkv_fused_linear.py b/colossalai/shardformer/layer/qkv_fused_linear.py index 6e469686b403..57c5ce19f439 100644 --- a/colossalai/shardformer/layer/qkv_fused_linear.py +++ b/colossalai/shardformer/layer/qkv_fused_linear.py @@ -30,6 +30,7 @@ linear_with_async_comm, matmul_gather_forward_reducescatter_backward, matmul_with_async_comm, + matmul_with_grad_comm, reduce_forward, reducescatter_forward_gather_backward, split_forward_gather_backward, @@ -620,6 +621,162 @@ def forward(self, input_: Tensor) -> Tensor: return output, self.bias +class GPT2FusedLinearConv1D(ParallelModule): + r"""Linear layer without parallelism. + This layer is used to fit `Conv1D` layer (Fused QKV) in gpt2 of huggingface. + + Args: + in_features (int): size of each input sample. + out_features (int): size of each output sample. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (`torch.dtype`): The dtype of parameters, defaults to None. + parallel_input (bool): If set to ``True``, it's assumed that the input is split, defaults to False. + skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer, + seq_parallel_mode (str): If set to ``None``, it will not use sequence parallel, otherwise will use corresponding mode of sequence parallel, defaults to None. + which is preserved for kernel fusion, defaults to False + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + def __init__( + self, + in_features: int, + out_features: int, + split_sizes: List[int], + bias: bool = True, + dtype: torch.dtype = None, + device: torch.device = None, + seq_parallel_mode: str = None, + seq_parallel_dim: int = 1, + skip_bias_add: bool = False, + weight: Optional[Parameter] = None, + bias_: Optional[Parameter] = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + fp8_communication: bool = False, + use_zbv: bool = False, + ): + super().__init__() + # Keep input parameters + self.in_features = in_features + self.out_features = out_features + self.seq_parallel_mode = seq_parallel_mode + self.seq_parallel_dim = seq_parallel_dim + self.skip_bias_add = skip_bias_add + self.device = device + self.split_sizes = split_sizes + self.fp8_communication = fp8_communication + self.use_zbv = use_zbv + + assert ( + sum(split_sizes) == out_features + ), f"The sum of split_sizes({sum(split_sizes)}) should be equal to out_features({out_features})." + + if skip_bias_add and not bias: + raise ValueError("cannot skip bias addition if bias is None") + + # offset the seed with randomizer index and rank + seed = torch.random.initial_seed() + self.randomizer = create_randomizer_with_offset(seed, None) + + # sanity check + if weight is not None: + assert not bias or bias_ is not None, "bias_ must be provided if bias is True when weight is not None" + else: + assert bias_ is None, "bias_ must be None if weight is None" + + # Parameters. + if weight is None: + # Initialize weight. + factory_kwargs = {"device": device, "dtype": dtype} + self.weight = Parameter(torch.empty(self.out_features, self.in_features, **factory_kwargs)) + else: + weight.data = weight.data.to(device=device, dtype=dtype) + self.weight = weight + + if bias: + if bias_ is None: + self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs)) + else: + bias_.data = bias_.data.to(device=device, dtype=dtype) + self.bias = bias_ + else: + self.bias = None + + if weight is None: + # init weights + self.reset_parameters(weight_initializer, bias_initializer) + + @staticmethod + def from_native_module( + module: nn.Module, + split_sizes: List[int], + *args, + **kwargs, + ) -> ParallelModule: + r""" + Convert a huggingface layer `Conv1D` in gpt2 to a parallelized linear layer. + + Args: + module (`nn.Linear`): The module to be converted. + split_sizes (List[int]): The sizes of the split tensor. In GPT2, Q,K,V are fused in one weight. + """ + LazyInitContext.materialize(module) + # get the attributes + in_features = module.weight.shape[0] + out_features = module.weight.shape[1] + bias = module.bias is not None + device = module.weight.device + + linear_1d = GPT2FusedLinearConv1D( + in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + weight=module.weight, + bias_=module.bias, + split_sizes=split_sizes, + *args, + **kwargs, + ) + + return linear_1d + + def reset_parameters(self, weight_initializer, bias_initializer) -> None: + with self.randomizer.fork_rng(enable_cpu=True): + fan_in, fan_out = self.in_features, self.out_features + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + + def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: + # Matrix multiply. + bias = self.bias if not self.skip_bias_add else None + if self.seq_parallel_mode is None or self.seq_parallel_mode == "ring_attn": + # Set up backprop all-reduce. + input_parallel = input_ + output_parallel = matmul_with_grad_comm( + input_parallel, + self.weight, + bias, + False, + self.use_zbv, + ) + else: + raise NotImplementedError(f"seq_parallel_mode={self.seq_parallel_mode} is not supported!") + + output = output_parallel + + if self.skip_bias_add: + return output, self.bias + else: + return output + + # ==================================== # For Fused torch.nn.Linear # ==================================== From a84fc41de0389576510a5d46a170f4cd27c785ea Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Mon, 9 Dec 2024 18:37:39 +0800 Subject: [PATCH 05/54] [feat] support GPT2FusedLinear (without tp) --- colossalai/pipeline/weight_grad_store.py | 1 - colossalai/shardformer/layer/__init__.py | 3 +- colossalai/shardformer/layer/_operation.py | 11 ++- .../shardformer/layer/qkv_fused_linear.py | 11 +-- .../test_gpt2_qkv_fused_linear_1d.py | 77 ++++++++++++++++++- 5 files changed, 84 insertions(+), 19 deletions(-) diff --git a/colossalai/pipeline/weight_grad_store.py b/colossalai/pipeline/weight_grad_store.py index c51c45085ea2..66909317eba2 100644 --- a/colossalai/pipeline/weight_grad_store.py +++ b/colossalai/pipeline/weight_grad_store.py @@ -18,7 +18,6 @@ def flush(cls, chunk=0): @classmethod def pop(cls, chunk=0): - # print(f"chunk id {chunk} queue size {cls.weight_grad_queue[chunk].qsize()}") if cls.weight_grad_queue[chunk].qsize() > 0: stored_grads = cls.weight_grad_queue[chunk].get() for total_input, grad_output, weight, func in stored_grads: diff --git a/colossalai/shardformer/layer/__init__.py b/colossalai/shardformer/layer/__init__.py index da5363840848..800364003cef 100644 --- a/colossalai/shardformer/layer/__init__.py +++ b/colossalai/shardformer/layer/__init__.py @@ -6,7 +6,7 @@ from .loss import cross_entropy_1d, dist_cross_entropy from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm from .parallel_module import ParallelModule -from .qkv_fused_linear import FusedLinear1D_Col, FusedLinear1D_Row, GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row +from .qkv_fused_linear import FusedLinear1D_Col, FusedLinear1D_Row, GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row, GPT2FusedLinearConv1D __all__ = [ "Embedding1D", @@ -16,6 +16,7 @@ "Linear1D_Row", "GPT2FusedLinearConv1D_Col", "GPT2FusedLinearConv1D_Row", + "GPT2FusedLinearConv1D_Col", "DropoutForParallelInput", "DropoutForReplicatedInput", "cross_entropy_1d", diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index fbdd5018e17e..8d89b8f33869 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -136,7 +136,6 @@ def forward(ctx, input_, weight, bias, async_grad_allreduce, use_zbv=False): ctx.use_zbv = use_zbv output = torch.matmul(input_, weight) - if bias is not None: output = output + bias @@ -149,10 +148,10 @@ def backward(ctx, grad_output): use_zbv = ctx.use_zbv def execute_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_, wgrad_gemm_accum_func=None): - wgrad_gemm_accum_func(_input_, _grad_output_, _weight_main_grad_) + wgrad_gemm_accum_func(_grad_output_, _input_, _weight_main_grad_) def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_func=None): - return wgrad_gemm_func(_grad_output_.t(), _input_) + return wgrad_gemm_func(_input_.t(), _grad_output_) # In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. weight = weight.view(weight.shape) @@ -203,7 +202,7 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, grad) grad_weight = None else: - grad_weight = grad_output.t().matmul(total_input) + grad_weight = total_input.t().matmul(grad_output) else: if use_zbv: WeightGradStore.put( @@ -217,8 +216,8 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f ) grad_weight = None else: - grad_weight = grad_output.t().matmul(total_input) - + grad_weight = total_input.t().matmul(grad_output) + grad_bias = grad_output.sum(dim=0) if use_bias else None return grad_input, grad_weight, grad_bias, None, None, None, None diff --git a/colossalai/shardformer/layer/qkv_fused_linear.py b/colossalai/shardformer/layer/qkv_fused_linear.py index 57c5ce19f439..052c0e4fd8d9 100644 --- a/colossalai/shardformer/layer/qkv_fused_linear.py +++ b/colossalai/shardformer/layer/qkv_fused_linear.py @@ -38,7 +38,7 @@ from .parallel_module import ParallelModule from .utils import create_randomizer_with_offset, is_share_sp_tp -__all__ = ["FusedLinear1D_Col", "FusedLinear1D_Row", "GPT2FusedLinearConv1D_Col", "GPT2FusedLinearConv1D_Row"] +__all__ = ["FusedLinear1D_Col", "FusedLinear1D_Row", "GPT2FusedLinearConv1D_Col", "GPT2FusedLinearConv1D_Row", "GPT2FusedLinearConv1D"] # ==================================== # For GPT Only @@ -630,7 +630,6 @@ class GPT2FusedLinearConv1D(ParallelModule): out_features (int): size of each output sample. bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. dtype (`torch.dtype`): The dtype of parameters, defaults to None. - parallel_input (bool): If set to ``True``, it's assumed that the input is split, defaults to False. skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer, seq_parallel_mode (str): If set to ``None``, it will not use sequence parallel, otherwise will use corresponding mode of sequence parallel, defaults to None. which is preserved for kernel fusion, defaults to False @@ -646,7 +645,6 @@ def __init__( self, in_features: int, out_features: int, - split_sizes: List[int], bias: bool = True, dtype: torch.dtype = None, device: torch.device = None, @@ -668,14 +666,9 @@ def __init__( self.seq_parallel_dim = seq_parallel_dim self.skip_bias_add = skip_bias_add self.device = device - self.split_sizes = split_sizes self.fp8_communication = fp8_communication self.use_zbv = use_zbv - assert ( - sum(split_sizes) == out_features - ), f"The sum of split_sizes({sum(split_sizes)}) should be equal to out_features({out_features})." - if skip_bias_add and not bias: raise ValueError("cannot skip bias addition if bias is None") @@ -714,7 +707,6 @@ def __init__( @staticmethod def from_native_module( module: nn.Module, - split_sizes: List[int], *args, **kwargs, ) -> ParallelModule: @@ -739,7 +731,6 @@ def from_native_module( device=device, weight=module.weight, bias_=module.bias, - split_sizes=split_sizes, *args, **kwargs, ) diff --git a/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py b/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py index a45beb77108f..0f68f6c639f4 100644 --- a/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py +++ b/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py @@ -8,7 +8,8 @@ import colossalai from colossalai.lazy import LazyInitContext -from colossalai.shardformer.layer import GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row +from colossalai.pipeline.weight_grad_store import WeightGradStore +from colossalai.shardformer.layer import GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row, GPT2FusedLinearConv1D from colossalai.shardformer.layer.qkv_fused_linear import split_fused_qkv_in_gpt2_style from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn @@ -118,11 +119,85 @@ def check_linear_conv_1d_row(lazy_init: bool, seq_parallel_mode: bool): assert_close(target_grad, linear_row.weight.grad) +def check_linear_conv_1d_without_weight_grad_store(lazy_init: bool, seq_parallel_mode: str): + ctx = LazyInitContext() if lazy_init else nullcontext() + + linear = Conv1D(192, 48).cuda() + with ctx: + linear_copy = Conv1D(192, 48).cuda() + linear_base = GPT2FusedLinearConv1D.from_native_module( + linear_copy, seq_parallel_mode=seq_parallel_mode + ) + + assert linear.weight.shape == torch.Size([48, 192]) + assert linear_base.weight.shape == torch.Size([48, 192]) + assert linear_base.bias.shape == torch.Size([192]) + assert linear_copy.weight is linear_base.weight + assert linear_copy.bias is linear_base.bias + + # ensure weights are reversibly loadable + linear_base.load_state_dict(linear.state_dict()) + linear.load_state_dict(linear_base.state_dict()) + + # check computation correctness + x = torch.rand(1, 4, 48).cuda() + out = linear(x) + gather_out = linear_base(x) + assert_close(out, gather_out) + + # check backward correctness + out.sum().backward() + gather_out.sum().backward() + + # check the input gradients & weight gradients + assert_close(out.grad, gather_out.grad) + assert_close(linear.weight.grad, linear_base.weight.grad) + +def check_linear_conv_1d_with_weight_grad_store(lazy_init: bool, seq_parallel_mode: str): + ctx = LazyInitContext() if lazy_init else nullcontext() + + linear = Conv1D(192, 48).cuda() + with ctx: + linear_copy = Conv1D(192, 48).cuda() + linear_base = GPT2FusedLinearConv1D.from_native_module( + linear_copy, seq_parallel_mode=seq_parallel_mode, use_zbv=True + ) + + assert linear.weight.shape == torch.Size([48, 192]) + assert linear_base.weight.shape == torch.Size([48, 192]) + assert linear_base.bias.shape == torch.Size([192]) + assert linear_copy.weight is linear_base.weight + assert linear_copy.bias is linear_base.bias + + # ensure weights are reversibly loadable + linear_base.load_state_dict(linear.state_dict()) + linear.load_state_dict(linear_base.state_dict()) + + # check computation correctness + x = torch.rand(1, 4, 48).cuda() + out = linear(x) + gather_out = linear_base(x) + assert_close(out, gather_out) + + # check backward correctness + out.sum().backward() + gather_out.sum().backward() + + WeightGradStore.flush(chunk=0) # flush buffer to chunk 0 Queue + WeightGradStore.pop(chunk=0) + + # check the input gradients & weight gradients + assert_close(out.grad, gather_out.grad) + # TODO:linear_base.weight.grad is None; But not none in WeightGradStore + # assert_close(linear.weight.grad, linear_base.weight.grad) + @parameterize("lazy_init", [False, True]) @parameterize("seq_parallel_mode", ["split_gather", None]) def check_gpt2_qkv_fused_linear_1d(lazy_init: bool, seq_parallel_mode: bool): check_linear_conv_1d_col(lazy_init, seq_parallel_mode) check_linear_conv_1d_row(lazy_init, seq_parallel_mode) + check_linear_conv_1d_without_weight_grad_store(lazy_init, None) + check_linear_conv_1d_with_weight_grad_store(lazy_init, None) def run_dist(rank, world_size, port): From 014cc27846d0fd3922e70b1f7188e6684e84f02f Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Tue, 10 Dec 2024 11:28:23 +0800 Subject: [PATCH 06/54] [fix] debug FusedConvLinear --- colossalai/pipeline/weight_grad_store.py | 1 + .../test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py | 1 + 2 files changed, 2 insertions(+) diff --git a/colossalai/pipeline/weight_grad_store.py b/colossalai/pipeline/weight_grad_store.py index 66909317eba2..846cb242fd53 100644 --- a/colossalai/pipeline/weight_grad_store.py +++ b/colossalai/pipeline/weight_grad_store.py @@ -27,5 +27,6 @@ def pop(cls, chunk=0): else: grad_weight = func(total_input, grad_output) weight.grad = grad_weight + print(f"WeightGradStore {weight.grad}") else: raise Exception("Pop empty queue.") diff --git a/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py b/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py index 0f68f6c639f4..eb10462526af 100644 --- a/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py +++ b/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py @@ -189,6 +189,7 @@ def check_linear_conv_1d_with_weight_grad_store(lazy_init: bool, seq_parallel_mo # check the input gradients & weight gradients assert_close(out.grad, gather_out.grad) # TODO:linear_base.weight.grad is None; But not none in WeightGradStore + print(f"ZBV weight.grad {linear_base.weight.grad}") # assert_close(linear.weight.grad, linear_base.weight.grad) @parameterize("lazy_init", [False, True]) From 778d4dffeefa6ef8ffcc4b7d123c3d05b0bda32e Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Tue, 10 Dec 2024 15:04:54 +0800 Subject: [PATCH 07/54] [shardfromer] support gpt2 policy for zbv, support GPT2FusedLinearConv Col and Row. --- colossalai/pipeline/weight_grad_store.py | 2 - colossalai/shardformer/layer/_operation.py | 139 +++++++++++++-- .../shardformer/layer/qkv_fused_linear.py | 6 + colossalai/shardformer/policies/gpt2.py | 161 ++++++++++++++++-- .../test_gpt2_qkv_fused_linear_1d.py | 1 - 5 files changed, 278 insertions(+), 31 deletions(-) diff --git a/colossalai/pipeline/weight_grad_store.py b/colossalai/pipeline/weight_grad_store.py index 846cb242fd53..8c1b64e0ee57 100644 --- a/colossalai/pipeline/weight_grad_store.py +++ b/colossalai/pipeline/weight_grad_store.py @@ -8,7 +8,6 @@ class WeightGradStore: @classmethod def put(cls, total_input, grad_output, weight, func): - # func(total_input, grad_output, weight.main_grad) cls.cache.append((total_input, grad_output, weight, func)) @classmethod @@ -27,6 +26,5 @@ def pop(cls, chunk=0): else: grad_weight = func(total_input, grad_output) weight.grad = grad_weight - print(f"WeightGradStore {weight.grad}") else: raise Exception("Pop empty queue.") diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index 8d89b8f33869..e29961c85997 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -73,12 +73,13 @@ class MatmulWithAsyncCommunication(torch.autograd.Function): """ @staticmethod - def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False): + def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False, use_zbv=False): ctx.save_for_backward(input_, weight, bias) ctx.use_bias = bias is not None ctx.process_group = process_group ctx.async_grad_allreduce = async_grad_allreduce ctx.fp8_communication = fp8_communication + ctx.use_zbv=use_zbv output = torch.matmul(input_, weight) @@ -92,7 +93,8 @@ def backward(ctx, grad_output): input, weight, bias = ctx.saved_tensors use_bias = ctx.use_bias fp8_communication = ctx.fp8_communication - + use_zbv=ctx.use_zbv + # In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. weight = weight.view(weight.shape) if bias is not None: @@ -114,7 +116,64 @@ def backward(ctx, grad_output): # Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have # all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py - grad_weight = total_input.t().matmul(grad_output) + def execute_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_, wgrad_gemm_accum_func=None): + wgrad_gemm_accum_func(_grad_output_, _input_, _weight_main_grad_) + + def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_func=None): + return wgrad_gemm_func(_input_.t(), _grad_output_) + + # split dx & dw + if _grad_accum_fusion_available and weight.grad is not None: + grad = weight.grad + if use_zbv: + if grad.dtype == torch.float32: + WeightGradStore.put( + total_input, + grad_output, + weight, + functools.partial( + execute_w_pass_grad_accum, + wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32, + ), + ) + grad_weight = None + elif grad.dtype in (torch.float16, torch.bfloat16): + WeightGradStore.put( + total_input, + grad_output, + weight, + functools.partial( + execute_w_pass_grad_accum, + wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16, + ), + ) + grad_weight = None + else: + raise RuntimeError("Unsupported gradient type for gradient accumulation fusion") + else: + if grad.dtype == torch.float32: + fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad) + grad_weight = None + elif grad.dtype == torch.float16: + fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, grad) + grad_weight = None + else: + grad_weight = total_input.t().matmul(grad_output) + else: + if use_zbv: + WeightGradStore.put( + total_input, + grad_output, + weight, + functools.partial( + execute_w_pass, + wgrad_gemm_func=torch.matmul, + ), + ) + grad_weight = None + else: + grad_weight = total_input.t().matmul(grad_output) + grad_bias = grad_output.sum(dim=0) if use_bias else None if ctx.async_grad_allreduce and not fp8_communication: @@ -167,7 +226,7 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f total_input = total_input.view(-1, total_input.shape[-1]) # split dx & dw - if weight.grad is not None: + if _grad_accum_fusion_available and weight.grad is not None: grad = weight.grad if use_zbv: if grad.dtype == torch.float32: @@ -823,13 +882,14 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function): """ @staticmethod - def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring, fp8_communication): + def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring, fp8_communication, use_zbv=False): ctx.save_for_backward(input_, weight, bias) ctx.use_bias = bias is not None ctx.process_group = process_group ctx.async_grad_reduce_scatter = async_grad_reduce_scatter ctx.dim = dim ctx.fp8_communication = fp8_communication + ctx.use_zbv = use_zbv if ring is True: input_to_gather = {"input": input_} @@ -859,6 +919,7 @@ def backward(ctx, grad_output): use_bias = ctx.use_bias dim = ctx.dim process_group = ctx.process_group + use_zbv = ctx.use_zbv # In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm weight = weight.view(weight.shape) @@ -885,7 +946,65 @@ def backward(ctx, grad_output): # Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have # all-reduce scheduled first and have GPU resources allocated - grad_weight = total_input.t().matmul(grad_output) + + def execute_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_, wgrad_gemm_accum_func=None): + wgrad_gemm_accum_func(_grad_output_, _input_, _weight_main_grad_) + + def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_func=None): + return wgrad_gemm_func(_input_.t(), _grad_output_) + + # split dx & dw + if _grad_accum_fusion_available and weight.grad is not None: + grad = weight.grad + if use_zbv: + if grad.dtype == torch.float32: + WeightGradStore.put( + total_input, + grad_output, + weight, + functools.partial( + execute_w_pass_grad_accum, + wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32, + ), + ) + grad_weight = None + elif grad.dtype in (torch.float16, torch.bfloat16): + WeightGradStore.put( + total_input, + grad_output, + weight, + functools.partial( + execute_w_pass_grad_accum, + wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16, + ), + ) + grad_weight = None + else: + raise RuntimeError("Unsupported gradient type for gradient accumulation fusion") + else: + if grad.dtype == torch.float32: + fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad) + grad_weight = None + elif grad.dtype == torch.float16: + fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, grad) + grad_weight = None + else: + grad_weight = total_input.t().matmul(grad_output) + else: + if use_zbv: + WeightGradStore.put( + total_input, + grad_output, + weight, + functools.partial( + execute_w_pass, + wgrad_gemm_func=torch.matmul, + ), + ) + grad_weight = None + else: + grad_weight = total_input.t().matmul(grad_output) + grad_bias = grad_output.sum(dim=0) if use_bias else None if ctx.async_grad_reduce_scatter: @@ -1208,9 +1327,9 @@ def _all_to_all_single( ).contiguous() -def matmul_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False): +def matmul_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False, use_zbv=False): return MatmulWithAsyncCommunication.apply( - input_, weight, bias, process_group, async_grad_allreduce, fp8_communication + input_, weight, bias, process_group, async_grad_allreduce, fp8_communication, use_zbv ) def matmul_with_grad_comm(input_, weight, bias, async_grad_allreduce, use_zbv=False): @@ -1251,10 +1370,10 @@ def linear_reducescatter_forward_gather_backward(input_, weight, bias=None, proc def matmul_gather_forward_reducescatter_backward( - input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring=False, fp8_communication=False + input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring=False, fp8_communication=False, use_zbv=False ): return _MatmulWithGatherForwardReduceScatterBackward.apply( - input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring, fp8_communication + input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring, fp8_communication, use_zbv ) diff --git a/colossalai/shardformer/layer/qkv_fused_linear.py b/colossalai/shardformer/layer/qkv_fused_linear.py index 052c0e4fd8d9..aa83fb993ab4 100644 --- a/colossalai/shardformer/layer/qkv_fused_linear.py +++ b/colossalai/shardformer/layer/qkv_fused_linear.py @@ -229,6 +229,7 @@ def __init__( weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), fp8_communication: bool = False, + use_zbv: bool = False, ): super().__init__() @@ -242,6 +243,7 @@ def __init__( self.split_sizes = split_sizes self.process_group = process_group self.fp8_communication = fp8_communication + self.use_zbv = use_zbv assert ( sum(split_sizes) == out_features @@ -376,6 +378,7 @@ def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: 1, ring=self.seq_parallel_mode == "ring", fp8_communication=self.fp8_communication, + use_zbv=self.use_zbv, ) elif self.seq_parallel_mode is None or self.seq_parallel_mode == "ring_attn": # Set up backprop all-reduce. @@ -387,6 +390,7 @@ def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: self.process_group, True, fp8_communication=self.fp8_communication, + use_zbv=self.use_zbv, ) else: raise NotImplementedError(f"seq_parallel_mode={self.seq_parallel_mode} is not supported!") @@ -442,6 +446,7 @@ def __init__( bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), stream_chunk_num: int = 1, fp8_communication: bool = False, + use_zbv: bool = False, ): super().__init__() @@ -456,6 +461,7 @@ def __init__( self.seq_parallel_mode = seq_parallel_mode self.num_partitions = dist.get_world_size(self.process_group) self.fp8_communication = fp8_communication + self.use_zbv = use_zbv if skip_bias_add and not bias: raise ValueError("cannot skip bias addition if bias is None") diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index 08accaaea279..1ee0bb4cef4f 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -67,6 +67,8 @@ def module_policy(self): self.shard_config.sequence_parallelism_mode = sp_mode = "split_gather" sp_partial_derived = sp_mode in ["split_gather", "ring"] use_flash_attention = self.shard_config.enable_flash_attention + use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv + if self.shard_config.enable_tensor_parallelism: assert ( self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0 @@ -94,12 +96,13 @@ def module_policy(self): "split_sizes": [self.model.config.hidden_size] * 3, "seq_parallel_mode": sp_mode, "fp8_communication": self.shard_config.fp8_communication, + "use_zbv":use_zbv, }, ), SubModuleReplacementDescription( suffix="attn.c_proj", target_module=col_nn.GPT2FusedLinearConv1D_Row, - kwargs={"seq_parallel_mode": sp_mode, "fp8_communication": self.shard_config.fp8_communication}, + kwargs={"seq_parallel_mode": sp_mode, "fp8_communication": self.shard_config.fp8_communication, "use_zbv":use_zbv}, ), SubModuleReplacementDescription( suffix="mlp.c_fc", @@ -109,12 +112,13 @@ def module_policy(self): "seq_parallel_mode": sp_mode, "skip_bias_add": self.enable_bias_gelu_fused, "fp8_communication": self.shard_config.fp8_communication, + "use_zbv":use_zbv, }, ), SubModuleReplacementDescription( suffix="mlp.c_proj", target_module=col_nn.GPT2FusedLinearConv1D_Row, - kwargs={"seq_parallel_mode": sp_mode, "fp8_communication": self.shard_config.fp8_communication}, + kwargs={"seq_parallel_mode": sp_mode, "fp8_communication": self.shard_config.fp8_communication, "use_zbv":use_zbv,}, ), SubModuleReplacementDescription( suffix="attn.attn_dropout", @@ -138,6 +142,70 @@ def module_policy(self): policy=policy, target_key=GPT2MLP, ) + elif use_zbv: + policy[GPT2Model] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="drop", + target_module=col_nn.DropoutForParallelInput, + ), + ] + ) + + policy[GPT2Block] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="attn.c_attn", + target_module=col_nn.GPT2FusedLinearConv1D, + kwargs={ + "seq_parallel_mode": sp_mode, + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv":use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="attn.c_proj", + target_module=col_nn.GPT2FusedLinearConv1D, + kwargs={"seq_parallel_mode": sp_mode, "fp8_communication": self.shard_config.fp8_communication, "use_zbv":use_zbv}, + ), + SubModuleReplacementDescription( + suffix="mlp.c_fc", + target_module=col_nn.GPT2FusedLinearConv1D, + kwargs={ + "seq_parallel_mode": sp_mode, + "skip_bias_add": self.enable_bias_gelu_fused, + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv":use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="mlp.c_proj", + target_module=col_nn.GPT2FusedLinearConv1D, + kwargs={"seq_parallel_mode": sp_mode, "fp8_communication": self.shard_config.fp8_communication, "use_zbv":use_zbv,}, + ), + SubModuleReplacementDescription( + suffix="attn.attn_dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="attn.resid_dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="mlp.dropout", + target_module=col_nn.DropoutForParallelInput, + ), + ], + ) + if self.enable_bias_gelu_fused: + self.append_or_create_method_replacement( + description={ + "forward": get_jit_fused_gpt2_mlp_forward(), + }, + policy=policy, + target_key=GPT2MLP, + ) + if embedding_cls is not None: # padding vocabulary size when using pp to make it divisible by shard_config.make_vocab_size_divisible_by self.append_or_create_submodule_replacement( @@ -352,8 +420,17 @@ def module_policy(self): def get_held_layers(self) -> List[nn.Module]: held_layers = super().get_held_layers() - if self.pipeline_stage_manager.is_last_stage(ignore_chunk=True): - held_layers.append(self.model.lm_head) + stage_manager = self.pipeline_stage_manager + if stage_manager.is_interleave: + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) + ): + held_layers.append(self.model.lm_head) + else: + if self.pipeline_stage_manager.is_last_stage(ignore_chunk=True): + held_layers.append(self.model.lm_head) + # if self.pipeline_stage_manager.is_last_stage(ignore_chunk=True): + # held_layers.append(self.model.lm_head) return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: @@ -420,13 +497,31 @@ def module_policy(self): def get_held_layers(self) -> List[nn.Module]: held_layers = super().get_held_layers() - if self.pipeline_stage_manager.is_last_stage(): - multiple_choice_head = self.model.multiple_choice_head - held_layers.append(self.model.lm_head) - held_layers.append(multiple_choice_head.summary) - held_layers.append(multiple_choice_head.activation) - held_layers.append(multiple_choice_head.first_dropout) - held_layers.append(multiple_choice_head.last_dropout) + stage_manager = self.pipeline_stage_manager + if stage_manager.is_interleave: + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) + ): + held_layers.append(self.model.lm_head) + held_layers.append(multiple_choice_head.summary) + held_layers.append(multiple_choice_head.activation) + held_layers.append(multiple_choice_head.first_dropout) + held_layers.append(multiple_choice_head.last_dropout) + else: + if self.pipeline_stage_manager.is_last_stage(): + multiple_choice_head = self.model.multiple_choice_head + held_layers.append(self.model.lm_head) + held_layers.append(multiple_choice_head.summary) + held_layers.append(multiple_choice_head.activation) + held_layers.append(multiple_choice_head.first_dropout) + held_layers.append(multiple_choice_head.last_dropout) + # if self.pipeline_stage_manager.is_last_stage(): + # multiple_choice_head = self.model.multiple_choice_head + # held_layers.append(self.model.lm_head) + # held_layers.append(multiple_choice_head.summary) + # held_layers.append(multiple_choice_head.activation) + # held_layers.append(multiple_choice_head.first_dropout) + # held_layers.append(multiple_choice_head.last_dropout) return held_layers @@ -464,8 +559,17 @@ def module_policy(self): def get_held_layers(self) -> List[nn.Module]: held_layers = super().get_held_layers() - if self.pipeline_stage_manager.is_last_stage(): - held_layers.append(self.model.qa_outputs) + stage_manager = self.pipeline_stage_manager + if stage_manager.is_interleave: + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) + ): + held_layers.append(self.model.qa_outputs) + else: + if self.pipeline_stage_manager.is_last_stage(): + held_layers.append(self.model.qa_outputs) + # if self.pipeline_stage_manager.is_last_stage(): + # held_layers.append(self.model.qa_outputs) return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: @@ -503,9 +607,20 @@ def module_policy(self): def get_held_layers(self) -> List[nn.Module]: held_layers = super().get_held_layers() - if self.pipeline_stage_manager.is_last_stage(): - held_layers.append(self.model.dropout) - held_layers.append(self.model.classifier) + stage_manager = self.pipeline_stage_manager + if stage_manager.is_interleave: + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) + ): + held_layers.append(self.model.dropout) + held_layers.append(self.model.classifier) + else: + if self.pipeline_stage_manager.is_last_stage(): + held_layers.append(self.model.dropout) + held_layers.append(self.model.classifier) + # if self.pipeline_stage_manager.is_last_stage(): + # held_layers.append(self.model.dropout) + # held_layers.append(self.model.classifier) return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: @@ -530,8 +645,18 @@ def module_policy(self): def get_held_layers(self) -> List[nn.Module]: held_layers = super().get_held_layers() - if self.pipeline_stage_manager.is_last_stage(): - held_layers.append(self.model.score) + stage_manager = self.pipeline_stage_manager + if stage_manager.is_interleave: + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) + ): + held_layers.append(self.model.score) + else: + if self.pipeline_stage_manager.is_last_stage(): + held_layers.append(self.model.score) + + # if self.pipeline_stage_manager.is_last_stage(): + # held_layers.append(self.model.score) return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: diff --git a/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py b/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py index eb10462526af..0f68f6c639f4 100644 --- a/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py +++ b/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py @@ -189,7 +189,6 @@ def check_linear_conv_1d_with_weight_grad_store(lazy_init: bool, seq_parallel_mo # check the input gradients & weight gradients assert_close(out.grad, gather_out.grad) # TODO:linear_base.weight.grad is None; But not none in WeightGradStore - print(f"ZBV weight.grad {linear_base.weight.grad}") # assert_close(linear.weight.grad, linear_base.weight.grad) @parameterize("lazy_init", [False, True]) From d168b733b41d928b7e1556e53b3b1f873c9e850b Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Tue, 10 Dec 2024 15:30:55 +0800 Subject: [PATCH 08/54] [Shardformer] support FusedLinear1D base for zbv --- .../shardformer/layer/qkv_fused_linear.py | 38 ++++++++++++++++++- 1 file changed, 37 insertions(+), 1 deletion(-) diff --git a/colossalai/shardformer/layer/qkv_fused_linear.py b/colossalai/shardformer/layer/qkv_fused_linear.py index aa83fb993ab4..f164e2030d60 100644 --- a/colossalai/shardformer/layer/qkv_fused_linear.py +++ b/colossalai/shardformer/layer/qkv_fused_linear.py @@ -38,7 +38,13 @@ from .parallel_module import ParallelModule from .utils import create_randomizer_with_offset, is_share_sp_tp -__all__ = ["FusedLinear1D_Col", "FusedLinear1D_Row", "GPT2FusedLinearConv1D_Col", "GPT2FusedLinearConv1D_Row", "GPT2FusedLinearConv1D"] +__all__ = [ + "FusedLinear1D_Col", + "FusedLinear1D_Row", + "GPT2FusedLinearConv1D_Col", + "GPT2FusedLinearConv1D_Row", + "GPT2FusedLinearConv1D", +] # ==================================== # For GPT Only @@ -647,6 +653,7 @@ class GPT2FusedLinearConv1D(ParallelModule): More details about ``initializer`` please refer to `init `_. """ + def __init__( self, in_features: int, @@ -1174,3 +1181,32 @@ def forward(self, input_: Tensor) -> Tensor: return output else: return output, self.bias + + +class FusedLinear1D(ParallelModule): + r"""Fused Linear layer with column parallelism. + + The linear layer is defined as :math:`Y = XA + b`. A is parallelized along + its second dimension as :math:`A = [A_1, ..., A_p]`. This layer is used to fit `torch.nn.Linear` layer (Fused QKV) in normal torch layer of huggingface, like SAM. + + Args: + in_features (int): size of each input sample. + out_features (int): size of each output sample. + split_sizes (List[int]): The sizes of the split tensor. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (`torch.dtype`): The dtype of parameters, defaults to None. + device (`torch.device`): The device of parameters, defaults to None. + process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None. + gather_output (bool, optional): If true, call all-gather on output and make Y available + to all GPUs, otherwise, every GPU will have its output + which is :math:`Y_i = XA_i`, defaults to False + skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer, + which is preserved for kernel fusion, defaults to False + weight_initializer (`typing.Callable`): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (`typing.Callable`): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ From 01a9cb3e2d82275e3e3e495948da017e929b12b9 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Tue, 10 Dec 2024 16:10:22 +0800 Subject: [PATCH 09/54] [shardformer] support zbv in FusedLinear1D base, Col, Row --- colossalai/shardformer/layer/_operation.py | 180 ++++++++++++++---- .../shardformer/layer/qkv_fused_linear.py | 157 ++++++++++++++- 2 files changed, 302 insertions(+), 35 deletions(-) diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index e29961c85997..921f92b025df 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -79,7 +79,7 @@ def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce, fp8_ ctx.process_group = process_group ctx.async_grad_allreduce = async_grad_allreduce ctx.fp8_communication = fp8_communication - ctx.use_zbv=use_zbv + ctx.use_zbv = use_zbv output = torch.matmul(input_, weight) @@ -93,8 +93,8 @@ def backward(ctx, grad_output): input, weight, bias = ctx.saved_tensors use_bias = ctx.use_bias fp8_communication = ctx.fp8_communication - use_zbv=ctx.use_zbv - + use_zbv = ctx.use_zbv + # In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. weight = weight.view(weight.shape) if bias is not None: @@ -173,7 +173,7 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f grad_weight = None else: grad_weight = total_input.t().matmul(grad_output) - + grad_bias = grad_output.sum(dim=0) if use_bias else None if ctx.async_grad_allreduce and not fp8_communication: @@ -205,7 +205,7 @@ def backward(ctx, grad_output): input, weight, bias = ctx.saved_tensors use_bias = ctx.use_bias use_zbv = ctx.use_zbv - + def execute_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_, wgrad_gemm_accum_func=None): wgrad_gemm_accum_func(_grad_output_, _input_, _weight_main_grad_) @@ -224,7 +224,7 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f if len(grad_output.shape) > 2: grad_output = grad_output.view(-1, grad_output.shape[-1]) total_input = total_input.view(-1, total_input.shape[-1]) - + # split dx & dw if _grad_accum_fusion_available and weight.grad is not None: grad = weight.grad @@ -276,7 +276,7 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f grad_weight = None else: grad_weight = total_input.t().matmul(grad_output) - + grad_bias = grad_output.sum(dim=0) if use_bias else None return grad_input, grad_weight, grad_bias, None, None, None, None @@ -338,7 +338,6 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f if _grad_accum_fusion_available and weight.grad is not None: grad = weight.grad if use_zbv: - # TODO: append input, grad_output_, weight, grad func to WeightGradStore if grad.dtype == torch.float32: WeightGradStore.put( total_input, @@ -439,7 +438,6 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f if _grad_accum_fusion_available and weight.grad is not None: grad = weight.grad if use_zbv: - # TODO: append input, grad_output_, weight, grad func to WeightGradStore if grad.dtype == torch.float32: WeightGradStore.put( total_input, @@ -613,12 +611,13 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function): """ @staticmethod - def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring=False): + def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring=False, use_zbv=False): ctx.save_for_backward(input_, weight, bias) ctx.use_bias = bias is not None ctx.process_group = process_group ctx.async_grad_reduce_scatter = async_grad_reduce_scatter ctx.dim = dim + ctx.use_zbv = use_zbv if ring is True: input_to_gather = {"input": input_} @@ -650,6 +649,7 @@ def backward(ctx, grad_output): use_bias = ctx.use_bias dim = ctx.dim process_group = ctx.process_group + use_zbv = ctx.use_zbv # In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm if use_bias: @@ -675,18 +675,62 @@ def backward(ctx, grad_output): # Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have # all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py + def execute_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_, wgrad_gemm_accum_func=None): + wgrad_gemm_accum_func(_input_, _grad_output_, _weight_main_grad_) + + def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_func=None): + return wgrad_gemm_func(_grad_output_.t(), _input_) + if _grad_accum_fusion_available and weight.grad is not None: grad = weight.grad - if grad.dtype == torch.float32: - fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad) - grad_weight = None - elif grad.dtype == torch.float16: - fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, grad) + if use_zbv: + if grad.dtype == torch.float32: + WeightGradStore.put( + total_input, + grad_output, + weight, + functools.partial( + execute_w_pass_grad_accum, + wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32, + ), + ) + grad_weight = None + elif grad.dtype in (torch.float16, torch.bfloat16): + WeightGradStore.put( + total_input, + grad_output, + weight, + functools.partial( + execute_w_pass_grad_accum, + wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16, + ), + ) + grad_weight = None + else: + raise RuntimeError("Unsupported gradient type for gradient accumulation fusion") + else: + if grad.dtype == torch.float32: + fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad) + grad_weight = None + elif grad.dtype == torch.float16: + fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, grad) + grad_weight = None + else: + grad_weight = grad_output.t().matmul(total_input) + else: + if use_zbv: + WeightGradStore.put( + total_input, + grad_output, + weight, + functools.partial( + execute_w_pass, + wgrad_gemm_func=torch.matmul, + ), + ) grad_weight = None else: grad_weight = grad_output.t().matmul(total_input) - else: - grad_weight = grad_output.t().matmul(total_input) grad_bias = grad_output.sum(dim=0) if use_bias else None @@ -765,11 +809,12 @@ class _LinearWithReduceScatterForwardGatherBackward(torch.autograd.Function): """ @staticmethod - def forward(ctx, input_, weight, bias, process_group, dim, ring): + def forward(ctx, input_, weight, bias, process_group, dim, ring, use_zbv=False): ctx.save_for_backward(input_, weight, bias) ctx.use_bias = bias is not None ctx.process_group = process_group ctx.dim = dim + ctx.use_zbv = use_zbv if ring is True: input_to_reducescatter = {"input": input_} @@ -810,7 +855,7 @@ def backward(ctx, grad_output): use_bias = ctx.use_bias dim = ctx.dim process_group = ctx.process_group - + use_zbv = ctx.use_zbv # In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm if use_bias: bias = bias.view(bias.shape) @@ -825,7 +870,65 @@ def backward(ctx, grad_output): if len(grad_output.shape) > 2: grad_output = grad_output.view(-1, grad_output.shape[-1]) total_input = total_input.reshape(-1, total_input.shape[-1]) - grad_weight = grad_output.t().matmul(total_input) + + def execute_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_, wgrad_gemm_accum_func=None): + wgrad_gemm_accum_func(_input_, _grad_output_, _weight_main_grad_) + + def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_func=None): + return wgrad_gemm_func(_grad_output_.t(), _input_) + + if _grad_accum_fusion_available and weight.grad is not None: + grad = weight.grad + if use_zbv: + if grad.dtype == torch.float32: + WeightGradStore.put( + total_input, + grad_output, + weight, + functools.partial( + execute_w_pass_grad_accum, + wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32, + ), + ) + grad_weight = None + elif grad.dtype in (torch.float16, torch.bfloat16): + WeightGradStore.put( + total_input, + grad_output, + weight, + functools.partial( + execute_w_pass_grad_accum, + wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16, + ), + ) + grad_weight = None + else: + raise RuntimeError("Unsupported gradient type for gradient accumulation fusion") + else: + if grad.dtype == torch.float32: + fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad) + grad_weight = None + elif grad.dtype == torch.float16: + fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, grad) + grad_weight = None + else: + grad_weight = grad_output.t().matmul(total_input) + else: + if use_zbv: + WeightGradStore.put( + total_input, + grad_output, + weight, + functools.partial( + execute_w_pass, + wgrad_gemm_func=torch.matmul, + ), + ) + grad_weight = None + else: + grad_weight = grad_output.t().matmul(total_input) + + # grad_weight = grad_output.t().matmul(total_input) grad_bias = grad_output.sum(dim=0) if use_bias else None return grad_input, grad_weight, grad_bias, None, None, None @@ -882,7 +985,9 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function): """ @staticmethod - def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring, fp8_communication, use_zbv=False): + def forward( + ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring, fp8_communication, use_zbv=False + ): ctx.save_for_backward(input_, weight, bias) ctx.use_bias = bias is not None ctx.process_group = process_group @@ -946,7 +1051,6 @@ def backward(ctx, grad_output): # Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have # all-reduce scheduled first and have GPU resources allocated - def execute_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_, wgrad_gemm_accum_func=None): wgrad_gemm_accum_func(_grad_output_, _input_, _weight_main_grad_) @@ -1004,7 +1108,7 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f grad_weight = None else: grad_weight = total_input.t().matmul(grad_output) - + grad_bias = grad_output.sum(dim=0) if use_bias else None if ctx.async_grad_reduce_scatter: @@ -1327,15 +1431,17 @@ def _all_to_all_single( ).contiguous() -def matmul_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False, use_zbv=False): +def matmul_with_async_comm( + input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False, use_zbv=False +): return MatmulWithAsyncCommunication.apply( input_, weight, bias, process_group, async_grad_allreduce, fp8_communication, use_zbv ) + def matmul_with_grad_comm(input_, weight, bias, async_grad_allreduce, use_zbv=False): - return MatmulWithGradAccum.apply( - input_, weight, bias, async_grad_allreduce, use_zbv - ) + return MatmulWithGradAccum.apply(input_, weight, bias, async_grad_allreduce, use_zbv) + def linear_with_async_comm( input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False, use_zbv=False @@ -1350,10 +1456,10 @@ def linear_with_grad_accum(input_, weight, bias, async_grad_allreduce, use_zbv=F def linear_gather_forward_reducescatter_backward( - input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring=False + input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring=False, use_zbv=False ): return _LinearWithGatherForwardReduceScatterBackward.apply( - input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring + input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring, use_zbv ) @@ -1365,12 +1471,22 @@ def reducescatter_forward_gather_backward(input_, process_group, dim, fp8_commun return _ReduceScatterForwardGatherBackward.apply(input_, process_group, dim, fp8_communication) -def linear_reducescatter_forward_gather_backward(input_, weight, bias=None, process_group=None, dim=1, ring=False): - return _LinearWithReduceScatterForwardGatherBackward.apply(input_, weight, bias, process_group, dim, ring) +def linear_reducescatter_forward_gather_backward( + input_, weight, bias=None, process_group=None, dim=1, ring=False, use_zbv=False +): + return _LinearWithReduceScatterForwardGatherBackward.apply(input_, weight, bias, process_group, dim, ring, use_zbv) def matmul_gather_forward_reducescatter_backward( - input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring=False, fp8_communication=False, use_zbv=False + input_, + weight, + bias, + process_group, + async_grad_reduce_scatter, + dim, + ring=False, + fp8_communication=False, + use_zbv=False, ): return _MatmulWithGatherForwardReduceScatterBackward.apply( input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring, fp8_communication, use_zbv diff --git a/colossalai/shardformer/layer/qkv_fused_linear.py b/colossalai/shardformer/layer/qkv_fused_linear.py index f164e2030d60..c8a8a4f7f340 100644 --- a/colossalai/shardformer/layer/qkv_fused_linear.py +++ b/colossalai/shardformer/layer/qkv_fused_linear.py @@ -7,7 +7,6 @@ import torch import torch.distributed as dist import torch.nn as nn -import torch.nn.functional as F from torch import Tensor from torch.distributed import ProcessGroup from torch.nn.parameter import Parameter @@ -28,6 +27,7 @@ linear_gather_forward_reducescatter_backward, linear_reducescatter_forward_gather_backward, linear_with_async_comm, + linear_with_grad_accum, matmul_gather_forward_reducescatter_backward, matmul_with_async_comm, matmul_with_grad_comm, @@ -832,6 +832,7 @@ def __init__( weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), fp8_communication: bool = False, + use_zbv: bool = False, ): super().__init__() # Keep input parameters @@ -845,6 +846,7 @@ def __init__( self.split_sizes = split_sizes self.process_group = process_group self.fp8_communication = fp8_communication + self.use_zbv = use_zbv assert ( sum(split_sizes) == out_features @@ -972,10 +974,17 @@ def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: True, self.seq_parallel_dim, ring=self.seq_parallel_mode == "ring", + use_zbv=self.use_zbv, ) else: output_parallel = linear_with_async_comm( - input_parallel, self.weight, bias, self.process_group, True, fp8_communication=self.fp8_communication + input_parallel, + self.weight, + bias, + self.process_group, + True, + fp8_communication=self.fp8_communication, + use_zbv=self.use_zbv, ) if self.gather_output: @@ -1031,6 +1040,7 @@ def __init__( weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), fp8_communication: bool = False, + use_zbv: bool = False, ): super().__init__() # Keep input parameters @@ -1044,6 +1054,7 @@ def __init__( self.seq_parallel_dim = seq_parallel_dim self.num_partitions = dist.get_world_size(self.process_group) self.fp8_communication = fp8_communication + self.use_zbv = use_zbv assert ( sum(split_sizes) == in_features @@ -1170,9 +1181,18 @@ def forward(self, input_: Tensor) -> Tensor: process_group=self.process_group, dim=self.seq_parallel_dim, ring=self.seq_parallel_mode == "ring", + use_zbv=self.use_zbv, ) else: - output_parallel = F.linear(input_, self.weight) + # output_parallel = F.linear(input_, self.weight) # Replace to LinearWithGradAccum + output_parallel = linear_with_grad_accum( + input_, + self.weight, + None, + False, + use_zbv=self.use_zbv, + ) + output = reduce_forward(output_parallel, self.process_group, fp8_communication=self.fp8_communication) if not self.skip_bias_add: @@ -1210,3 +1230,134 @@ class FusedLinear1D(ParallelModule): More details about ``initializer`` please refer to `init `_. """ + + def __init__( + self, + in_features: int, + out_features: int, + split_sizes: List[int], + bias: bool = True, + dtype: torch.dtype = None, + device: torch.device = None, + process_group: ProcessGroup = None, + gather_output: bool = False, + seq_parallel_mode: str = None, + seq_parallel_dim: int = 1, + skip_bias_add: bool = False, + weight: Optional[Parameter] = None, + bias_: Optional[Parameter] = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + fp8_communication: bool = False, + use_zbv: bool = False, + ): + super().__init__() + # Keep input parameters + self.in_features = in_features + self.out_features = out_features + self.seq_parallel_mode = seq_parallel_mode + self.seq_parallel_dim = seq_parallel_dim + self.skip_bias_add = skip_bias_add + self.device = device + self.fp8_communication = fp8_communication + self.use_zbv = use_zbv + + if skip_bias_add and not bias: + raise ValueError("cannot skip bias addition if bias is None") + + # offset the seed with randomizer index and rank + seed = torch.random.initial_seed() + self.randomizer = create_randomizer_with_offset(seed, process_group=None) + + # sanity check + if weight is not None: + assert not bias or bias_ is not None, "bias_ must be provided if bias is True when weight is not None" + else: + assert bias_ is None, "bias_ must be None if weight is None" + + # Parameters. + if weight is None: + # Initialize weight. + factory_kwargs = {"device": device, "dtype": dtype} + self.weight = Parameter(torch.empty(self.out_features, self.in_features, **factory_kwargs)) + else: + weight.data = weight.data.to(device=device, dtype=dtype) + self.weight = weight + + if bias: + if bias_ is None: + self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs)) + else: + bias_.data = bias_.data.to(device=device, dtype=dtype) + self.bias = bias_ + else: + self.bias = None + + if weight is None: + # init weights + self.reset_parameters(weight_initializer, bias_initializer) + + @staticmethod + def from_native_module( + module: nn.Module, + *args, + **kwargs, + ) -> ParallelModule: + r""" + Convert a fused `torch.nn.linear` layer to a parallelized linear layer. + + Args: + module (`nn.Linear`): The module to be converted. + process_group (`Union[ProcessGroup, List[ProcessGroup]]`): The process group to be used for weight sharding and communication. + split_sizes (List[int]): The sizes of the split tensor. In common, Q,K,V are fused in one weight. + """ + LazyInitContext.materialize(module) + + # get the attributes + in_features = module.in_features + out_features = module.out_features + bias = module.bias is not None + device = module.weight.device + + linear_1d = FusedLinear1D( + in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + weight=module.weight, + bias_=module.bias, + *args, + **kwargs, + ) + + return linear_1d + + def reset_parameters(self, weight_initializer, bias_initializer) -> None: + with self.randomizer.fork_rng(enable_cpu=True): + fan_in, fan_out = self.in_features, self.out_features + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + + def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: + assert ( + input_.shape[-1] == self.weight.shape[-1] + ), "Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.".format( + input_.shape, self.weight.shape, self.weight.shape[-1] + ) + # Set up backprop all-reduce. + input_parallel = input_ + + # Matrix multiply. + bias = self.bias if not self.skip_bias_add else None + + output_parallel = linear_with_grad_accum( + input_parallel, self.weight, bias, True, fp8_communication=self.fp8_communication, use_zbv=self.use_zbv + ) + + output = output_parallel + + if self.skip_bias_add: + return output, self.bias + else: + return output From fc77b24c1a9e291d86a0b9a4fd121a5ef3365fc8 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Tue, 10 Dec 2024 16:23:59 +0800 Subject: [PATCH 10/54] [shardformer] support zbv in blip2 and sam policy --- colossalai/shardformer/layer/__init__.py | 10 +- .../shardformer/layer/qkv_fused_linear.py | 4 +- colossalai/shardformer/policies/blip2.py | 5 +- colossalai/shardformer/policies/sam.py | 226 ++++++++++++++++++ 4 files changed, 239 insertions(+), 6 deletions(-) diff --git a/colossalai/shardformer/layer/__init__.py b/colossalai/shardformer/layer/__init__.py index 800364003cef..e9f8e2cb5747 100644 --- a/colossalai/shardformer/layer/__init__.py +++ b/colossalai/shardformer/layer/__init__.py @@ -6,7 +6,14 @@ from .loss import cross_entropy_1d, dist_cross_entropy from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm from .parallel_module import ParallelModule -from .qkv_fused_linear import FusedLinear1D_Col, FusedLinear1D_Row, GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row, GPT2FusedLinearConv1D +from .qkv_fused_linear import ( + FusedLinear1D, + FusedLinear1D_Col, + FusedLinear1D_Row, + GPT2FusedLinearConv1D, + GPT2FusedLinearConv1D_Col, + GPT2FusedLinearConv1D_Row, +) __all__ = [ "Embedding1D", @@ -27,6 +34,7 @@ "FusedLayerNorm", "FusedRMSNorm", "FusedLinear1D_Col", + "FusedLinear1D", "ParallelModule", "PaddingEmbedding", "PaddingLMHead", diff --git a/colossalai/shardformer/layer/qkv_fused_linear.py b/colossalai/shardformer/layer/qkv_fused_linear.py index c8a8a4f7f340..149e3d66a57f 100644 --- a/colossalai/shardformer/layer/qkv_fused_linear.py +++ b/colossalai/shardformer/layer/qkv_fused_linear.py @@ -41,6 +41,7 @@ __all__ = [ "FusedLinear1D_Col", "FusedLinear1D_Row", + "FusedLinear1D", "GPT2FusedLinearConv1D_Col", "GPT2FusedLinearConv1D_Row", "GPT2FusedLinearConv1D", @@ -1235,12 +1236,9 @@ def __init__( self, in_features: int, out_features: int, - split_sizes: List[int], bias: bool = True, dtype: torch.dtype = None, device: torch.device = None, - process_group: ProcessGroup = None, - gather_output: bool = False, seq_parallel_mode: str = None, seq_parallel_dim: int = 1, skip_bias_add: bool = False, diff --git a/colossalai/shardformer/policies/blip2.py b/colossalai/shardformer/policies/blip2.py index 4ca1cefc2815..f674d84d0402 100644 --- a/colossalai/shardformer/policies/blip2.py +++ b/colossalai/shardformer/policies/blip2.py @@ -75,6 +75,7 @@ def module_policy(self): kwargs={ "split_sizes": [self.model.config.vision_config.hidden_size] * 3, "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -297,7 +298,6 @@ def module_policy(self): policy=policy, target_key=Blip2MLP, ) - elif use_zbv: policy[Blip2EncoderLayer] = ModulePolicyDescription( sub_module_replacement=[ @@ -307,10 +307,11 @@ def module_policy(self): ), SubModuleReplacementDescription( suffix="self_attn.qkv", - target_module=col_nn.FusedLinear1D_Col, + target_module=col_nn.FusedLinear1D, kwargs={ "split_sizes": [self.model.config.vision_config.hidden_size] * 3, "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( diff --git a/colossalai/shardformer/policies/sam.py b/colossalai/shardformer/policies/sam.py index a94cc9119356..8d41542dd721 100644 --- a/colossalai/shardformer/policies/sam.py +++ b/colossalai/shardformer/policies/sam.py @@ -27,6 +27,7 @@ def module_policy(self): norm_cls = col_nn.FusedLayerNorm else: norm_cls = col_nn.LayerNorm + use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv if self.shard_config.enable_tensor_parallelism: assert ( @@ -44,6 +45,7 @@ def module_policy(self): kwargs={ "split_sizes": [self.model.config.vision_config.hidden_size] * 3, "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -51,6 +53,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Row, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -58,6 +61,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -65,6 +69,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Row, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), ], @@ -80,6 +85,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -87,6 +93,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -94,6 +101,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -101,6 +109,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Row, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -108,6 +117,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -115,6 +125,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -122,6 +133,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -129,6 +141,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Row, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -136,6 +149,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -143,6 +157,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Row, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -150,6 +165,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -157,6 +173,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -164,6 +181,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -171,6 +189,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Row, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), ], @@ -186,6 +205,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -193,6 +213,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -200,6 +221,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -207,6 +229,210 @@ def module_policy(self): target_module=col_nn.Linear1D_Row, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + ], + ) + + # add `DropoutForParallelInput` layer to replace the useage of `nn.functional.dropout` + policy[SamVisionAttention] = ModulePolicyDescription( + attribute_replacement={ + "dropout_layer": col_nn.DropoutForParallelInput(self.model.config.vision_config.attention_dropout) + }, + method_replacement={"forward": forward_fn()}, + sub_module_replacement=[], + ) + elif use_zbv: + policy[SamVisionLayer] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="attn.qkv", + target_module=col_nn.FusedLinear1D, + kwargs={ + "split_sizes": [self.model.config.vision_config.hidden_size] * 3, + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="attn.proj", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="mlp.lin1", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="mlp.lin2", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + ], + ) + policy[SamTwoWayAttentionBlock] = ModulePolicyDescription( + attribute_replacement={ + "self_attn.num_attention_heads": self.model.config.mask_decoder_config.num_attention_heads + // self.shard_config.tensor_parallel_size, + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attn.q_proj", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="self_attn.k_proj", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="self_attn.v_proj", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="self_attn.out_proj", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="cross_attn_token_to_image.q_proj", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="cross_attn_token_to_image.k_proj", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="cross_attn_token_to_image.v_proj", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="cross_attn_token_to_image.out_proj", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="mlp.lin1", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="mlp.lin2", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="cross_attn_image_to_token.q_proj", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="cross_attn_image_to_token.k_proj", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="cross_attn_image_to_token.v_proj", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="cross_attn_image_to_token.out_proj", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + ], + ) + policy[SamTwoWayTransformer] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="final_attn_token_to_image.q_proj", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="final_attn_token_to_image.k_proj", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="final_attn_token_to_image.v_proj", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="final_attn_token_to_image.out_proj", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), ], From 70b0ae1e9de73638fa59fa4e68fc22a3f4da4c4e Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Tue, 10 Dec 2024 16:48:07 +0800 Subject: [PATCH 11/54] [shardformer] fix bug incorrect number of gradients; add fusedLinear base testcase; --- colossalai/shardformer/layer/_operation.py | 3 +- .../shardformer/layer/qkv_fused_linear.py | 4 +-- .../test_layer/test_qkv_fused_linear_1d.py | 35 ++++++++++++++++++- 3 files changed, 36 insertions(+), 6 deletions(-) diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index 921f92b025df..86970c641229 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -173,7 +173,6 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f grad_weight = None else: grad_weight = total_input.t().matmul(grad_output) - grad_bias = grad_output.sum(dim=0) if use_bias else None if ctx.async_grad_allreduce and not fp8_communication: @@ -1114,7 +1113,7 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f if ctx.async_grad_reduce_scatter: handle.wait() - return output, grad_weight, grad_bias, None, None, None, None, None + return output, grad_weight, grad_bias, None, None, None, None, None, None class _SplitForwardGatherBackward(torch.autograd.Function): diff --git a/colossalai/shardformer/layer/qkv_fused_linear.py b/colossalai/shardformer/layer/qkv_fused_linear.py index 149e3d66a57f..74470ace15d8 100644 --- a/colossalai/shardformer/layer/qkv_fused_linear.py +++ b/colossalai/shardformer/layer/qkv_fused_linear.py @@ -1349,9 +1349,7 @@ def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: # Matrix multiply. bias = self.bias if not self.skip_bias_add else None - output_parallel = linear_with_grad_accum( - input_parallel, self.weight, bias, True, fp8_communication=self.fp8_communication, use_zbv=self.use_zbv - ) + output_parallel = linear_with_grad_accum(input_parallel, self.weight, bias, True, use_zbv=self.use_zbv) output = output_parallel diff --git a/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py b/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py index fccba564f7c7..43fca1ce65cc 100644 --- a/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py +++ b/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py @@ -7,7 +7,7 @@ import colossalai from colossalai.lazy import LazyInitContext -from colossalai.shardformer.layer import FusedLinear1D_Col, FusedLinear1D_Row +from colossalai.shardformer.layer import FusedLinear1D, FusedLinear1D_Col, FusedLinear1D_Row from colossalai.shardformer.layer.qkv_fused_linear import split_fused_qkv_in_gpt2_style from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn @@ -120,12 +120,45 @@ def check_linear_1d_col_row(lazy_init: bool): assert_close(target_grad2, linear_row.weight.grad) +@parameterize("lazy_init", [False, True]) +def check_linear_1d_base(lazy_init: bool): + ctx = LazyInitContext() if lazy_init else nullcontext() + linear = nn.Linear(8, 80).cuda() + with ctx: + linear_copy = nn.Linear(8, 80).cuda() + linear_base = FusedLinear1D.from_native_module(linear_copy) + + assert linear.weight.shape == torch.Size([80, 8]) + assert linear.bias.shape == torch.Size([80]) + assert linear_base.weight.shape == torch.Size([80, 8]) + assert linear_base.bias.shape == torch.Size([80]) + assert linear_copy.weight is linear_base.weight + assert linear_copy.bias is linear_base.bias + + # ensure weights are reversibly loadable + linear_base.load_state_dict(linear.state_dict()) + linear.load_state_dict(linear_base.state_dict()) + + # check computation correctness + x = torch.rand(4, 8).cuda() + out = linear(x) + base_out = linear_base(x) + assert_close(out, base_out) + + # check backward correctness + out.sum().backward() + base_out.sum().backward() + + assert_close(linear.weight.grad, linear_base.weight.grad) + + def run_dist(rank, world_size, port): colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") check_linear_1d_col() check_linear_1d_row() check_linear_1d_col_row() + check_linear_1d_base() @rerun_if_address_is_in_use() From 37b670efb62f87d30b139c7396e49455e193e3c2 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Tue, 10 Dec 2024 19:25:49 +0800 Subject: [PATCH 12/54] [fix] fix incorrect number of gradients ; --- colossalai/shardformer/layer/_operation.py | 4 ++-- colossalai/shardformer/layer/linear.py | 2 ++ 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index 86970c641229..f2c02d249f04 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -736,7 +736,7 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f if ctx.async_grad_reduce_scatter: handle.wait() - return output, grad_weight, grad_bias, None, None, None, None + return output, grad_weight, grad_bias, None, None, None, None, None def _ring_as_reducescatter( @@ -930,7 +930,7 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f # grad_weight = grad_output.t().matmul(total_input) grad_bias = grad_output.sum(dim=0) if use_bias else None - return grad_input, grad_weight, grad_bias, None, None, None + return grad_input, grad_weight, grad_bias, None, None, None, None class _ReduceScatterForwardGatherBackward(torch.autograd.Function): diff --git a/colossalai/shardformer/layer/linear.py b/colossalai/shardformer/layer/linear.py index d39d6e997af8..fe195d6987da 100644 --- a/colossalai/shardformer/layer/linear.py +++ b/colossalai/shardformer/layer/linear.py @@ -350,6 +350,7 @@ def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: True, self.seq_parallel_dim, ring=self.seq_parallel_mode == "ring", + use_zbv=self.use_zbv, ) else: output_parallel = linear_with_async_comm( @@ -580,6 +581,7 @@ def forward(self, input_: Tensor) -> Tensor: process_group=self.process_group, dim=self.seq_parallel_dim, ring=self.seq_parallel_mode == "ring", + use_zbv=self.use_zbv, ) else: output_parallel = F.linear(input_, self.weight) From 94bb9ec1d628b89705d95c398023f88951b6862a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 10 Dec 2024 11:27:45 +0000 Subject: [PATCH 13/54] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- colossalai/shardformer/policies/gpt2.py | 40 +++++++++++++------ .../test_gpt2_qkv_fused_linear_1d.py | 10 ++--- 2 files changed, 33 insertions(+), 17 deletions(-) diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index 1ee0bb4cef4f..148e0ff2d33d 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -68,7 +68,7 @@ def module_policy(self): sp_partial_derived = sp_mode in ["split_gather", "ring"] use_flash_attention = self.shard_config.enable_flash_attention use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv - + if self.shard_config.enable_tensor_parallelism: assert ( self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0 @@ -96,13 +96,17 @@ def module_policy(self): "split_sizes": [self.model.config.hidden_size] * 3, "seq_parallel_mode": sp_mode, "fp8_communication": self.shard_config.fp8_communication, - "use_zbv":use_zbv, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( suffix="attn.c_proj", target_module=col_nn.GPT2FusedLinearConv1D_Row, - kwargs={"seq_parallel_mode": sp_mode, "fp8_communication": self.shard_config.fp8_communication, "use_zbv":use_zbv}, + kwargs={ + "seq_parallel_mode": sp_mode, + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, ), SubModuleReplacementDescription( suffix="mlp.c_fc", @@ -112,13 +116,17 @@ def module_policy(self): "seq_parallel_mode": sp_mode, "skip_bias_add": self.enable_bias_gelu_fused, "fp8_communication": self.shard_config.fp8_communication, - "use_zbv":use_zbv, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( suffix="mlp.c_proj", target_module=col_nn.GPT2FusedLinearConv1D_Row, - kwargs={"seq_parallel_mode": sp_mode, "fp8_communication": self.shard_config.fp8_communication, "use_zbv":use_zbv,}, + kwargs={ + "seq_parallel_mode": sp_mode, + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, ), SubModuleReplacementDescription( suffix="attn.attn_dropout", @@ -160,13 +168,17 @@ def module_policy(self): kwargs={ "seq_parallel_mode": sp_mode, "fp8_communication": self.shard_config.fp8_communication, - "use_zbv":use_zbv, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( suffix="attn.c_proj", target_module=col_nn.GPT2FusedLinearConv1D, - kwargs={"seq_parallel_mode": sp_mode, "fp8_communication": self.shard_config.fp8_communication, "use_zbv":use_zbv}, + kwargs={ + "seq_parallel_mode": sp_mode, + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, ), SubModuleReplacementDescription( suffix="mlp.c_fc", @@ -175,13 +187,17 @@ def module_policy(self): "seq_parallel_mode": sp_mode, "skip_bias_add": self.enable_bias_gelu_fused, "fp8_communication": self.shard_config.fp8_communication, - "use_zbv":use_zbv, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( suffix="mlp.c_proj", target_module=col_nn.GPT2FusedLinearConv1D, - kwargs={"seq_parallel_mode": sp_mode, "fp8_communication": self.shard_config.fp8_communication, "use_zbv":use_zbv,}, + kwargs={ + "seq_parallel_mode": sp_mode, + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, ), SubModuleReplacementDescription( suffix="attn.attn_dropout", @@ -205,7 +221,7 @@ def module_policy(self): policy=policy, target_key=GPT2MLP, ) - + if embedding_cls is not None: # padding vocabulary size when using pp to make it divisible by shard_config.make_vocab_size_divisible_by self.append_or_create_submodule_replacement( @@ -617,7 +633,7 @@ def get_held_layers(self) -> List[nn.Module]: else: if self.pipeline_stage_manager.is_last_stage(): held_layers.append(self.model.dropout) - held_layers.append(self.model.classifier) + held_layers.append(self.model.classifier) # if self.pipeline_stage_manager.is_last_stage(): # held_layers.append(self.model.dropout) # held_layers.append(self.model.classifier) @@ -654,7 +670,7 @@ def get_held_layers(self) -> List[nn.Module]: else: if self.pipeline_stage_manager.is_last_stage(): held_layers.append(self.model.score) - + # if self.pipeline_stage_manager.is_last_stage(): # held_layers.append(self.model.score) return held_layers diff --git a/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py b/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py index 0f68f6c639f4..34074642c58d 100644 --- a/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py +++ b/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py @@ -9,7 +9,7 @@ import colossalai from colossalai.lazy import LazyInitContext from colossalai.pipeline.weight_grad_store import WeightGradStore -from colossalai.shardformer.layer import GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row, GPT2FusedLinearConv1D +from colossalai.shardformer.layer import GPT2FusedLinearConv1D, GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row from colossalai.shardformer.layer.qkv_fused_linear import split_fused_qkv_in_gpt2_style from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn @@ -125,9 +125,7 @@ def check_linear_conv_1d_without_weight_grad_store(lazy_init: bool, seq_parallel linear = Conv1D(192, 48).cuda() with ctx: linear_copy = Conv1D(192, 48).cuda() - linear_base = GPT2FusedLinearConv1D.from_native_module( - linear_copy, seq_parallel_mode=seq_parallel_mode - ) + linear_base = GPT2FusedLinearConv1D.from_native_module(linear_copy, seq_parallel_mode=seq_parallel_mode) assert linear.weight.shape == torch.Size([48, 192]) assert linear_base.weight.shape == torch.Size([48, 192]) @@ -153,6 +151,7 @@ def check_linear_conv_1d_without_weight_grad_store(lazy_init: bool, seq_parallel assert_close(out.grad, gather_out.grad) assert_close(linear.weight.grad, linear_base.weight.grad) + def check_linear_conv_1d_with_weight_grad_store(lazy_init: bool, seq_parallel_mode: str): ctx = LazyInitContext() if lazy_init else nullcontext() @@ -182,7 +181,7 @@ def check_linear_conv_1d_with_weight_grad_store(lazy_init: bool, seq_parallel_mo # check backward correctness out.sum().backward() gather_out.sum().backward() - + WeightGradStore.flush(chunk=0) # flush buffer to chunk 0 Queue WeightGradStore.pop(chunk=0) @@ -191,6 +190,7 @@ def check_linear_conv_1d_with_weight_grad_store(lazy_init: bool, seq_parallel_mo # TODO:linear_base.weight.grad is None; But not none in WeightGradStore # assert_close(linear.weight.grad, linear_base.weight.grad) + @parameterize("lazy_init", [False, True]) @parameterize("seq_parallel_mode", ["split_gather", None]) def check_gpt2_qkv_fused_linear_1d(lazy_init: bool, seq_parallel_mode: bool): From dee187860aa0740a33c87f2401076f3a760f1dfa Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Mon, 16 Dec 2024 14:50:46 +0800 Subject: [PATCH 14/54] [Shardformer] add en doc for zbv; --- .../zerobubble_pipeline_parallelism.md | 231 ++++++++++++++++++ 1 file changed, 231 insertions(+) create mode 100644 docs/source/en/features/zerobubble_pipeline_parallelism.md diff --git a/docs/source/en/features/zerobubble_pipeline_parallelism.md b/docs/source/en/features/zerobubble_pipeline_parallelism.md new file mode 100644 index 000000000000..21a11f24b3bc --- /dev/null +++ b/docs/source/en/features/zerobubble_pipeline_parallelism.md @@ -0,0 +1,231 @@ +# ZeroBubble Pipeline Parallelism +Author: [Junwen Duan](https://github.com/duanjunwen), [Hongxin Liu](https://github.com/ver217) + +**Related Paper** +- [Zero Bubble Pipeline Parallelism](https://arxiv.org/abs/2401.10241) + +## Introduction +ZeroBubble (V Schedule): +Crucially, splitting B into two stages (also known as an activation gradient and a weight gradient) and a scheme like 1F1B1W can further reduce the bubble compared to the 1F1B scheme in earlier work. + +## Hands-On Practice +We now demonstrate how to use ZeroBubble with booster API with 4 GPUs. + +### step 1. Import libraries +```python +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.testing import assert_close +from transformers.models.llama.configuration_llama import LlamaConfig +from transformers.models.llama.modeling_llama import LlamaModel + +import colossalai +from colossalai.booster.booster import Booster +from colossalai.booster.plugin.moe_hybrid_parallel_plugin import HybridParallelPlugin +from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler +``` + +### step 2. Initialize Distributed Environment and Parallism Group +```python +colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") +``` + +### step 3. Initialize Module, Optimizer, and Pipeline Schedule +Build our model and Optimizer. We created a Llama with 8 Decoder-Layer. Then, inite the PipelineGraph and Pipeline schedule by get_v_schedule() function. +```python +# Global Param +NUM_BATCH = 8 +NUM_TOK_PER_BATCH, NUM_EXPERTS = 4, 4 +NUM_LAYERS = 8 +HIDDEN_SIZE_PER_HEAD = 4 +NUM_HEADS = 4 +TOP_K = 1 +# Init Llama from huggingface +configuration = LlamaConfig( + hidden_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS, + intermediate_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2, + num_hidden_layers=NUM_LAYERS, + num_attention_heads=NUM_HEADS, + num_key_value_heads=NUM_HEADS, + attn_implementation="flash_attention_2", +) +model = LlamaModel(configuration).cuda() +optimizer = torch.optim.Adam(torch_model.parameters(), lr=1) +# Init schedule +h, a, s = config.hidden_size, config.num_attention_heads, 1024 +mem_f = 34 * h + 5 * a * s +mem_w = -32 * h +mem_b = -mem_w - mem_f +graph = PipelineGraph( + n_stage=pp_size, + n_micro=num_microbatches, + f_cost=1, + b_cost=1, + w_cost=1, + c_cost=1, + f_mem=mem_f, + b_mem=mem_b, + w_mem=mem_w, +) +zbv_schedule = graph.get_v_schedule() +``` + +### step 4.Init Booster +Pass pp_style="zbv" when initialising the Plugin to use the ZeroBubble Pipeline. +```python +plugin = HybridParallelPlugin( + pp_size=4, + num_microbatches=4, + tp_size=1, + sp_size=1, + zero_stage=1, + initial_scale=1, + find_unused_parameters=True, + pp_style="zbv", + scheduler_nodes=zbv_schedule, + num_model_chunks=2, +) + +dp_size = plugin.dp_size +booster = Booster(plugin=plugin) +``` + +### step 5.Train Your Model +```python +steps = 10 +for step in range(steps): + input_embeddings = torch.rand( + NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True + ).cuda() + dist.all_reduce( + input_embeddings, group=plugin.pp_group + ) + data_iter = iter([{"inputs_embeds": input_embeddings}]) + output = booster.execute_pipeline( + data_iter, + model, + lambda x, y: x.last_hidden_state.mean(), + optimizer, + return_loss=True, + return_outputs=True, + ) + optimizer.step() + optimizer.zero_grad() +``` + +## Advanced Practice +In ColossalAI, you can get better training performance by using MetaCache and HybridParallel with ZeroBubble. +### 1.Use MetaCache with ZeroBubble +Pass "enable_metadata_cache=True" when initialising the Plugin to use the Meta Cache with ZeroBubble Pipeline. +```python +plugin = HybridParallelPlugin( + pp_size=2, + num_microbatches=4, + tp_size=2, + sp_size=2, + zero_stage=1, + initial_scale=1, + enable_metadata_cache=True, + find_unused_parameters=True, + pp_style="zbv", + scheduler_nodes=zbv_schedule, + num_model_chunks=2, +) +``` + +### 2.HybridParallel with ZeroBubble +Pass pp_size, tp_size, sp_size when initialising the Plugin to use the HybridParallel with ZeroBubble Pipeline. +```python +plugin = HybridParallelPlugin( + pp_size=2, + num_microbatches=2, + tp_size=2, + sp_size=2, + zero_stage=1, + initial_scale=1, + find_unused_parameters=True, + pp_style="zbv", + scheduler_nodes=zbv_schedule, + num_model_chunks=2, +) +``` +Performance Benchmark + + + + + + + + + + + + + + + + + + + + + + +
HybridParallel StrategyPipeline ParallelSequence Parallel + Pipeline ParallelData Parallel + Pipeline Parallel
With 1F1B15.27 samples/sec17.22 samples/sec14.06 samples/sec
With Zero Bubble17.36 samples/sec18.38 samples/sec14.44 samples/sec
+ +### 3.Fine-tuning Scheduler parameters + +```python +``` +## Model compatibility + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Shardformer/ModelBertBlip2BloomChatglm2CommandDeepseekFalconGPT2GptjLlamaMistralOptQwen2SamT5VitWhisper
ZeroBubble✔️✔️✔️✔️✔️✔️✔️✔️✔️✔️✔️✔️✔️✔️✔️✔️
+ +## API Reference +{{ autodoc:colossalai.pipeline.schedule.zero_bubble_pp.py.ZeroBubbleVPipeScheduler }} + + From 83e670e04a640d3127a2a59aba73b27b4b652ccb Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Mon, 16 Dec 2024 14:53:38 +0800 Subject: [PATCH 15/54] [fix] fix typo in Model compatibility table --- docs/source/en/features/zerobubble_pipeline_parallelism.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/en/features/zerobubble_pipeline_parallelism.md b/docs/source/en/features/zerobubble_pipeline_parallelism.md index 21a11f24b3bc..66bd65bae792 100644 --- a/docs/source/en/features/zerobubble_pipeline_parallelism.md +++ b/docs/source/en/features/zerobubble_pipeline_parallelism.md @@ -219,6 +219,7 @@ Performance Benchmark ✔️ ✔️ ✔️ + ✔️ From 2a55566ca8f2b0b7d3936b2222e609c27c6bbc70 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Mon, 16 Dec 2024 14:56:06 +0800 Subject: [PATCH 16/54] [fix] fix API Reference typo --- docs/source/en/features/zerobubble_pipeline_parallelism.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/features/zerobubble_pipeline_parallelism.md b/docs/source/en/features/zerobubble_pipeline_parallelism.md index 66bd65bae792..3fdfb8811228 100644 --- a/docs/source/en/features/zerobubble_pipeline_parallelism.md +++ b/docs/source/en/features/zerobubble_pipeline_parallelism.md @@ -227,6 +227,6 @@ Performance Benchmark ## API Reference -{{ autodoc:colossalai.pipeline.schedule.zero_bubble_pp.py.ZeroBubbleVPipeScheduler }} +{{ autodoc:colossalai.pipeline.schedule.zero_bubble_pp.ZeroBubbleVPipeScheduler }} From 5430eb0b1610ca76767b90ced31da70910dd8ab5 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Tue, 17 Dec 2024 11:44:22 +0800 Subject: [PATCH 17/54] [Shardformer] add zh-Han doc for zbv --- .../zerobubble_pipeline_parallelism.md | 235 ++++++++++++++++++ 1 file changed, 235 insertions(+) create mode 100644 docs/source/zh-Hans/features/zerobubble_pipeline_parallelism.md diff --git a/docs/source/zh-Hans/features/zerobubble_pipeline_parallelism.md b/docs/source/zh-Hans/features/zerobubble_pipeline_parallelism.md new file mode 100644 index 000000000000..20b175c1a578 --- /dev/null +++ b/docs/source/zh-Hans/features/zerobubble_pipeline_parallelism.md @@ -0,0 +1,235 @@ +# 零气泡流水线并行 +作者: [Junwen Duan](https://github.com/duanjunwen), [Hongxin Liu](https://github.com/ver217) + +**相关论文** +- [Zero Bubble Pipeline Parallelism](https://arxiv.org/abs/2401.10241) + +## 介绍 +零气泡(V Schedule): +与早期工作中的1F1B方案相比,零气泡流水线并行将B分成两个阶段(也称为激活梯度和权重梯度),形如1F1B1W这样的方案可以进一步减少气泡。 + +## 使用 +我们将演示如何在 4 个 GPU 上使用带有 booster API 的 ZeroBubble + +### step 1. 引用仓库 +```python +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.testing import assert_close +from transformers.models.llama.configuration_llama import LlamaConfig +from transformers.models.llama.modeling_llama import LlamaModel + +import colossalai +from colossalai.booster.booster import Booster +from colossalai.booster.plugin.moe_hybrid_parallel_plugin import HybridParallelPlugin +from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler +``` + +### step 2. 初始化分布式环境 +```python +colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") +``` + +### step 3. 初始化模型优化器和流水线Schedule +建立我们的模型和优化器 我们创建了一个带有8层Decoder-Layer的 Llama。然后,使用get_v_schedule()函数创建PipelineGraph和Pipeline schedule。 + +```python +# Global Param +NUM_BATCH = 8 +NUM_TOK_PER_BATCH, NUM_EXPERTS = 4, 4 +NUM_LAYERS = 8 +HIDDEN_SIZE_PER_HEAD = 4 +NUM_HEADS = 4 +TOP_K = 1 +# Init Llama from huggingface +configuration = LlamaConfig( + hidden_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS, + intermediate_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2, + num_hidden_layers=NUM_LAYERS, + num_attention_heads=NUM_HEADS, + num_key_value_heads=NUM_HEADS, + attn_implementation="flash_attention_2", +) +model = LlamaModel(configuration).cuda() +optimizer = torch.optim.Adam(torch_model.parameters(), lr=1) +# Init schedule +h, a, s = config.hidden_size, config.num_attention_heads, 1024 +mem_f = 34 * h + 5 * a * s +mem_w = -32 * h +mem_b = -mem_w - mem_f +graph = PipelineGraph( + n_stage=pp_size, + n_micro=num_microbatches, + f_cost=1, + b_cost=1, + w_cost=1, + c_cost=1, + f_mem=mem_f, + b_mem=mem_b, + w_mem=mem_w, +) +zbv_schedule = graph.get_v_schedule() +``` + +### step 4.初始化Booster +在初始化Plugin时输入pp_style="zbv",以使用ZeroBubble流水线并行。 +```python +plugin = HybridParallelPlugin( + pp_size=4, + num_microbatches=4, + tp_size=1, + sp_size=1, + zero_stage=1, + initial_scale=1, + find_unused_parameters=True, + pp_style="zbv", + scheduler_nodes=zbv_schedule, + num_model_chunks=2, +) + +dp_size = plugin.dp_size +booster = Booster(plugin=plugin) +``` + +### step 5.训练模型 +```python +steps = 10 +for step in range(steps): + input_embeddings = torch.rand( + NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True + ).cuda() + dist.all_reduce( + input_embeddings, group=plugin.pp_group + ) + data_iter = iter([{"inputs_embeds": input_embeddings}]) + output = booster.execute_pipeline( + data_iter, + model, + lambda x, y: x.last_hidden_state.mean(), + optimizer, + return_loss=True, + return_outputs=True, + ) + optimizer.step() + optimizer.zero_grad() +``` + +## 进阶使用技巧 +在 ColossalAI 中,通过使用MetaCache和混合并行的ZeroBubble,可以获得更好的训练性能。 + +### 1.在ZeroBubble中使用元数据缓存 +在初始化Plugin时输入 "enable_metadata_cache=True",以便在ZeroBubble管道中使用元数据缓存。 +Pass "enable_metadata_cache=True" when initialising the Plugin to use the Meta Cache with ZeroBubble Pipeline. +```python +plugin = HybridParallelPlugin( + pp_size=2, + num_microbatches=4, + tp_size=2, + sp_size=2, + zero_stage=1, + initial_scale=1, + enable_metadata_cache=True, + find_unused_parameters=True, + pp_style="zbv", + scheduler_nodes=zbv_schedule, + num_model_chunks=2, +) +``` + +### 2.同时使用ZeroBubble和混合并行 +在初始化插件时传递 pp_size, tp_size, sp_size, 以便使用零气泡混合并行管道(HybridParallel with ZeroBubble Pipeline)。 +```python +plugin = HybridParallelPlugin( + pp_size=2, + num_microbatches=2, + tp_size=2, + sp_size=2, + zero_stage=1, + initial_scale=1, + find_unused_parameters=True, + pp_style="zbv", + scheduler_nodes=zbv_schedule, + num_model_chunks=2, +) +``` +性能指标 + + + + + + + + + + + + + + + + + + + + + + +
HybridParallel StrategyPipeline ParallelSequence Parallel + Pipeline ParallelData Parallel + Pipeline Parallel
With 1F1B15.27 samples/sec17.22 samples/sec14.06 samples/sec
With Zero Bubble17.36 samples/sec18.38 samples/sec14.44 samples/sec
+ +### 3.Fine-tuning Scheduler parameters + +```python +``` +## 模型兼容性 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Shardformer/ModelBertBlip2BloomChatglm2CommandDeepseekFalconGPT2GptjLlamaMistralOptQwen2SamT5VitWhisper
ZeroBubble✔️✔️✔️✔️✔️✔️✔️✔️✔️✔️✔️✔️✔️✔️✔️✔️✔️
+ +## API 参考 +{{ autodoc:colossalai.pipeline.schedule.zero_bubble_pp.ZeroBubbleVPipeScheduler }} + + From 25da23de41cc2d7957c671016f7df15988312a63 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Wed, 18 Dec 2024 14:00:02 +0800 Subject: [PATCH 18/54] [fix] fix Linear name; update en & zh doc --- colossalai/shardformer/layer/__init__.py | 6 +++--- colossalai/shardformer/layer/qkv_fused_linear.py | 12 ++++++------ colossalai/shardformer/policies/blip2.py | 2 +- colossalai/shardformer/policies/gpt2.py | 15 ++++----------- colossalai/shardformer/policies/sam.py | 2 +- .../features/zerobubble_pipeline_parallelism.md | 12 ++++++++++-- .../features/zerobubble_pipeline_parallelism.md | 9 ++++++--- .../test_layer/test_gpt2_qkv_fused_linear_1d.py | 8 +++----- .../test_layer/test_qkv_fused_linear_1d.py | 4 ++-- 9 files changed, 36 insertions(+), 34 deletions(-) diff --git a/colossalai/shardformer/layer/__init__.py b/colossalai/shardformer/layer/__init__.py index e9f8e2cb5747..850001d04227 100644 --- a/colossalai/shardformer/layer/__init__.py +++ b/colossalai/shardformer/layer/__init__.py @@ -7,10 +7,10 @@ from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm from .parallel_module import ParallelModule from .qkv_fused_linear import ( - FusedLinear1D, + FusedLinear, FusedLinear1D_Col, FusedLinear1D_Row, - GPT2FusedLinearConv1D, + GPT2FusedLinearConv, GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row, ) @@ -34,7 +34,7 @@ "FusedLayerNorm", "FusedRMSNorm", "FusedLinear1D_Col", - "FusedLinear1D", + "FusedLinear", "ParallelModule", "PaddingEmbedding", "PaddingLMHead", diff --git a/colossalai/shardformer/layer/qkv_fused_linear.py b/colossalai/shardformer/layer/qkv_fused_linear.py index 74470ace15d8..5edcfd9b8a00 100644 --- a/colossalai/shardformer/layer/qkv_fused_linear.py +++ b/colossalai/shardformer/layer/qkv_fused_linear.py @@ -41,10 +41,10 @@ __all__ = [ "FusedLinear1D_Col", "FusedLinear1D_Row", - "FusedLinear1D", + "FusedLinear", "GPT2FusedLinearConv1D_Col", "GPT2FusedLinearConv1D_Row", - "GPT2FusedLinearConv1D", + "GPT2FusedLinearConv", ] # ==================================== @@ -634,7 +634,7 @@ def forward(self, input_: Tensor) -> Tensor: return output, self.bias -class GPT2FusedLinearConv1D(ParallelModule): +class GPT2FusedLinearConv(ParallelModule): r"""Linear layer without parallelism. This layer is used to fit `Conv1D` layer (Fused QKV) in gpt2 of huggingface. @@ -738,7 +738,7 @@ def from_native_module( bias = module.bias is not None device = module.weight.device - linear_1d = GPT2FusedLinearConv1D( + linear_1d = GPT2FusedLinearConv( in_features=in_features, out_features=out_features, bias=bias, @@ -1204,7 +1204,7 @@ def forward(self, input_: Tensor) -> Tensor: return output, self.bias -class FusedLinear1D(ParallelModule): +class FusedLinear(ParallelModule): r"""Fused Linear layer with column parallelism. The linear layer is defined as :math:`Y = XA + b`. A is parallelized along @@ -1317,7 +1317,7 @@ def from_native_module( bias = module.bias is not None device = module.weight.device - linear_1d = FusedLinear1D( + linear_1d = FusedLinear( in_features=in_features, out_features=out_features, bias=bias, diff --git a/colossalai/shardformer/policies/blip2.py b/colossalai/shardformer/policies/blip2.py index f674d84d0402..246c4616009e 100644 --- a/colossalai/shardformer/policies/blip2.py +++ b/colossalai/shardformer/policies/blip2.py @@ -307,7 +307,7 @@ def module_policy(self): ), SubModuleReplacementDescription( suffix="self_attn.qkv", - target_module=col_nn.FusedLinear1D, + target_module=col_nn.FusedLinear, kwargs={ "split_sizes": [self.model.config.vision_config.hidden_size] * 3, "fp8_communication": self.shard_config.fp8_communication, diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index 148e0ff2d33d..c57d33826a39 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -164,7 +164,7 @@ def module_policy(self): sub_module_replacement=[ SubModuleReplacementDescription( suffix="attn.c_attn", - target_module=col_nn.GPT2FusedLinearConv1D, + target_module=col_nn.GPT2FusedLinearConv, kwargs={ "seq_parallel_mode": sp_mode, "fp8_communication": self.shard_config.fp8_communication, @@ -173,7 +173,7 @@ def module_policy(self): ), SubModuleReplacementDescription( suffix="attn.c_proj", - target_module=col_nn.GPT2FusedLinearConv1D, + target_module=col_nn.GPT2FusedLinearConv, kwargs={ "seq_parallel_mode": sp_mode, "fp8_communication": self.shard_config.fp8_communication, @@ -182,7 +182,7 @@ def module_policy(self): ), SubModuleReplacementDescription( suffix="mlp.c_fc", - target_module=col_nn.GPT2FusedLinearConv1D, + target_module=col_nn.GPT2FusedLinearConv, kwargs={ "seq_parallel_mode": sp_mode, "skip_bias_add": self.enable_bias_gelu_fused, @@ -192,7 +192,7 @@ def module_policy(self): ), SubModuleReplacementDescription( suffix="mlp.c_proj", - target_module=col_nn.GPT2FusedLinearConv1D, + target_module=col_nn.GPT2FusedLinearConv, kwargs={ "seq_parallel_mode": sp_mode, "fp8_communication": self.shard_config.fp8_communication, @@ -531,13 +531,6 @@ def get_held_layers(self) -> List[nn.Module]: held_layers.append(multiple_choice_head.activation) held_layers.append(multiple_choice_head.first_dropout) held_layers.append(multiple_choice_head.last_dropout) - # if self.pipeline_stage_manager.is_last_stage(): - # multiple_choice_head = self.model.multiple_choice_head - # held_layers.append(self.model.lm_head) - # held_layers.append(multiple_choice_head.summary) - # held_layers.append(multiple_choice_head.activation) - # held_layers.append(multiple_choice_head.first_dropout) - # held_layers.append(multiple_choice_head.last_dropout) return held_layers diff --git a/colossalai/shardformer/policies/sam.py b/colossalai/shardformer/policies/sam.py index 8d41542dd721..237db386930f 100644 --- a/colossalai/shardformer/policies/sam.py +++ b/colossalai/shardformer/policies/sam.py @@ -248,7 +248,7 @@ def module_policy(self): sub_module_replacement=[ SubModuleReplacementDescription( suffix="attn.qkv", - target_module=col_nn.FusedLinear1D, + target_module=col_nn.FusedLinear, kwargs={ "split_sizes": [self.model.config.vision_config.hidden_size] * 3, "fp8_communication": self.shard_config.fp8_communication, diff --git a/docs/source/en/features/zerobubble_pipeline_parallelism.md b/docs/source/en/features/zerobubble_pipeline_parallelism.md index 3fdfb8811228..5cb13528b5bf 100644 --- a/docs/source/en/features/zerobubble_pipeline_parallelism.md +++ b/docs/source/en/features/zerobubble_pipeline_parallelism.md @@ -52,6 +52,14 @@ configuration = LlamaConfig( ) model = LlamaModel(configuration).cuda() optimizer = torch.optim.Adam(torch_model.parameters(), lr=1) +``` +### step 4. Initialize Pipeline Schedule +Then, we need to create the PipelineGraph and PipelineSchedule using the get_v_schedule() function. We need to initialise the PipelineGraph with the following parameters. +x_cost represents the runtime consumed by operation x of each model chunk. +x_mem represents the amount of memory consumed by the operation x of each model chunk. +These parameters are estimated and filled in before the pipeline starts. In fact, better results can be obtained based on the runtime and memory cost during the real computation of the model. +In the following example, we assume that the computation times for the model's forward, reverse B, and reverse W are 1, 1, 1, respectively, and the p2p communication time is 1 +```python # Init schedule h, a, s = config.hidden_size, config.num_attention_heads, 1024 mem_f = 34 * h + 5 * a * s @@ -71,7 +79,7 @@ graph = PipelineGraph( zbv_schedule = graph.get_v_schedule() ``` -### step 4.Init Booster +### step 5.Init Booster Pass pp_style="zbv" when initialising the Plugin to use the ZeroBubble Pipeline. ```python plugin = HybridParallelPlugin( @@ -91,7 +99,7 @@ dp_size = plugin.dp_size booster = Booster(plugin=plugin) ``` -### step 5.Train Your Model +### step 6.Train Your Model ```python steps = 10 for step in range(steps): diff --git a/docs/source/zh-Hans/features/zerobubble_pipeline_parallelism.md b/docs/source/zh-Hans/features/zerobubble_pipeline_parallelism.md index 20b175c1a578..9df095e50983 100644 --- a/docs/source/zh-Hans/features/zerobubble_pipeline_parallelism.md +++ b/docs/source/zh-Hans/features/zerobubble_pipeline_parallelism.md @@ -31,7 +31,7 @@ from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") ``` -### step 3. 初始化模型优化器和流水线Schedule +### step 3. 初始化模型优化器 建立我们的模型和优化器 我们创建了一个带有8层Decoder-Layer的 Llama。然后,使用get_v_schedule()函数创建PipelineGraph和Pipeline schedule。 ```python @@ -53,6 +53,9 @@ configuration = LlamaConfig( ) model = LlamaModel(configuration).cuda() optimizer = torch.optim.Adam(torch_model.parameters(), lr=1) +``` +### step 4.初始化流水线Schedule +```python # Init schedule h, a, s = config.hidden_size, config.num_attention_heads, 1024 mem_f = 34 * h + 5 * a * s @@ -72,7 +75,7 @@ graph = PipelineGraph( zbv_schedule = graph.get_v_schedule() ``` -### step 4.初始化Booster +### step 5.初始化Booster 在初始化Plugin时输入pp_style="zbv",以使用ZeroBubble流水线并行。 ```python plugin = HybridParallelPlugin( @@ -92,7 +95,7 @@ dp_size = plugin.dp_size booster = Booster(plugin=plugin) ``` -### step 5.训练模型 +### step 6.训练模型 ```python steps = 10 for step in range(steps): diff --git a/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py b/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py index 34074642c58d..53cd9721e6c9 100644 --- a/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py +++ b/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py @@ -9,7 +9,7 @@ import colossalai from colossalai.lazy import LazyInitContext from colossalai.pipeline.weight_grad_store import WeightGradStore -from colossalai.shardformer.layer import GPT2FusedLinearConv1D, GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row +from colossalai.shardformer.layer import GPT2FusedLinearConv, GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row from colossalai.shardformer.layer.qkv_fused_linear import split_fused_qkv_in_gpt2_style from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn @@ -125,7 +125,7 @@ def check_linear_conv_1d_without_weight_grad_store(lazy_init: bool, seq_parallel linear = Conv1D(192, 48).cuda() with ctx: linear_copy = Conv1D(192, 48).cuda() - linear_base = GPT2FusedLinearConv1D.from_native_module(linear_copy, seq_parallel_mode=seq_parallel_mode) + linear_base = GPT2FusedLinearConv.from_native_module(linear_copy, seq_parallel_mode=seq_parallel_mode) assert linear.weight.shape == torch.Size([48, 192]) assert linear_base.weight.shape == torch.Size([48, 192]) @@ -158,9 +158,7 @@ def check_linear_conv_1d_with_weight_grad_store(lazy_init: bool, seq_parallel_mo linear = Conv1D(192, 48).cuda() with ctx: linear_copy = Conv1D(192, 48).cuda() - linear_base = GPT2FusedLinearConv1D.from_native_module( - linear_copy, seq_parallel_mode=seq_parallel_mode, use_zbv=True - ) + linear_base = GPT2FusedLinearConv.from_native_module(linear_copy, seq_parallel_mode=seq_parallel_mode, use_zbv=True) assert linear.weight.shape == torch.Size([48, 192]) assert linear_base.weight.shape == torch.Size([48, 192]) diff --git a/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py b/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py index 43fca1ce65cc..b31342cb30af 100644 --- a/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py +++ b/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py @@ -7,7 +7,7 @@ import colossalai from colossalai.lazy import LazyInitContext -from colossalai.shardformer.layer import FusedLinear1D, FusedLinear1D_Col, FusedLinear1D_Row +from colossalai.shardformer.layer import FusedLinear, FusedLinear1D_Col, FusedLinear1D_Row from colossalai.shardformer.layer.qkv_fused_linear import split_fused_qkv_in_gpt2_style from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn @@ -126,7 +126,7 @@ def check_linear_1d_base(lazy_init: bool): linear = nn.Linear(8, 80).cuda() with ctx: linear_copy = nn.Linear(8, 80).cuda() - linear_base = FusedLinear1D.from_native_module(linear_copy) + linear_base = FusedLinear.from_native_module(linear_copy) assert linear.weight.shape == torch.Size([80, 8]) assert linear.bias.shape == torch.Size([80]) From fd5bd334b2bb88692b0b23e364d2298cad1d9cb8 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Wed, 18 Dec 2024 14:32:06 +0800 Subject: [PATCH 19/54] [fix] fix shardformer doc import err --- docs/source/en/features/shardformer.md | 2 +- docs/source/zh-Hans/features/shardformer.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/en/features/shardformer.md b/docs/source/en/features/shardformer.md index 40b8954b55b5..4ee98297c942 100644 --- a/docs/source/en/features/shardformer.md +++ b/docs/source/en/features/shardformer.md @@ -213,7 +213,7 @@ The support matrix will grow larger as more models and optimization tools emerge The configuration of Shardformer is controlled by class `ShardConfig`: -{{ autodoc:colossalai.shardformer.ShardConfig }} +{{ autodoc:colossalai.shardformer.shard.ShardConfig }} If you want to enable Apex Fused Layernorm, please install `apex`. If you want to enable the usage of flash attention, please install `flash_attn`. diff --git a/docs/source/zh-Hans/features/shardformer.md b/docs/source/zh-Hans/features/shardformer.md index 02290f3d6eae..38d1aeccbefb 100644 --- a/docs/source/zh-Hans/features/shardformer.md +++ b/docs/source/zh-Hans/features/shardformer.md @@ -209,7 +209,7 @@ Author: [Baizhou Zhang](https://github.com/Fridge003), [Bin Jia](https://github. Shardformer的配置由类`ShardConfig`的参数控制: -{{ autodoc:colossalai.shardformer.ShardConfig }} +{{ autodoc:colossalai.shardformer.shard.ShardConfig }} 如果您想启用 Apex Fused Layernorm,请安装 `apex`。如果您想启用 flash attention,请安装 `flash_attn`。此外,xFormers 的 `cutlass_op` 可以作为Flash Attention的补充优化方式。 From c749a7cc0acc7d411df6a460ef1fd66cb9e57c0d Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Wed, 18 Dec 2024 14:44:10 +0800 Subject: [PATCH 20/54] [fix] fix shardconfig import in doc --- docs/source/en/features/shardformer.md | 2 +- docs/source/zh-Hans/features/shardformer.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/en/features/shardformer.md b/docs/source/en/features/shardformer.md index 4ee98297c942..4bec1c57ca3e 100644 --- a/docs/source/en/features/shardformer.md +++ b/docs/source/en/features/shardformer.md @@ -213,7 +213,7 @@ The support matrix will grow larger as more models and optimization tools emerge The configuration of Shardformer is controlled by class `ShardConfig`: -{{ autodoc:colossalai.shardformer.shard.ShardConfig }} +{{ autodoc:colossalai.shardformer.shard }} If you want to enable Apex Fused Layernorm, please install `apex`. If you want to enable the usage of flash attention, please install `flash_attn`. diff --git a/docs/source/zh-Hans/features/shardformer.md b/docs/source/zh-Hans/features/shardformer.md index 38d1aeccbefb..06b8872ec378 100644 --- a/docs/source/zh-Hans/features/shardformer.md +++ b/docs/source/zh-Hans/features/shardformer.md @@ -209,7 +209,7 @@ Author: [Baizhou Zhang](https://github.com/Fridge003), [Bin Jia](https://github. Shardformer的配置由类`ShardConfig`的参数控制: -{{ autodoc:colossalai.shardformer.shard.ShardConfig }} +{{ autodoc:colossalai.shardformer.shard }} 如果您想启用 Apex Fused Layernorm,请安装 `apex`。如果您想启用 flash attention,请安装 `flash_attn`。此外,xFormers 的 `cutlass_op` 可以作为Flash Attention的补充优化方式。 From eba4e33f304a8c96905df55964a72db49dc01a6e Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Wed, 18 Dec 2024 14:50:25 +0800 Subject: [PATCH 21/54] [fix] fix shardformer doc --- docs/source/en/features/shardformer.md | 2 +- docs/source/zh-Hans/features/shardformer.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/en/features/shardformer.md b/docs/source/en/features/shardformer.md index 4bec1c57ca3e..40b8954b55b5 100644 --- a/docs/source/en/features/shardformer.md +++ b/docs/source/en/features/shardformer.md @@ -213,7 +213,7 @@ The support matrix will grow larger as more models and optimization tools emerge The configuration of Shardformer is controlled by class `ShardConfig`: -{{ autodoc:colossalai.shardformer.shard }} +{{ autodoc:colossalai.shardformer.ShardConfig }} If you want to enable Apex Fused Layernorm, please install `apex`. If you want to enable the usage of flash attention, please install `flash_attn`. diff --git a/docs/source/zh-Hans/features/shardformer.md b/docs/source/zh-Hans/features/shardformer.md index 06b8872ec378..02290f3d6eae 100644 --- a/docs/source/zh-Hans/features/shardformer.md +++ b/docs/source/zh-Hans/features/shardformer.md @@ -209,7 +209,7 @@ Author: [Baizhou Zhang](https://github.com/Fridge003), [Bin Jia](https://github. Shardformer的配置由类`ShardConfig`的参数控制: -{{ autodoc:colossalai.shardformer.shard }} +{{ autodoc:colossalai.shardformer.ShardConfig }} 如果您想启用 Apex Fused Layernorm,请安装 `apex`。如果您想启用 flash attention,请安装 `flash_attn`。此外,xFormers 的 `cutlass_op` 可以作为Flash Attention的补充优化方式。 From 3c5ce9e73a4d14235916ea29a026369eccd7a02c Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Wed, 18 Dec 2024 14:55:38 +0800 Subject: [PATCH 22/54] [fix] fix shardconfig doc --- docs/source/en/features/shardformer.md | 2 +- docs/source/zh-Hans/features/shardformer.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/en/features/shardformer.md b/docs/source/en/features/shardformer.md index 40b8954b55b5..03fd688bd041 100644 --- a/docs/source/en/features/shardformer.md +++ b/docs/source/en/features/shardformer.md @@ -213,7 +213,7 @@ The support matrix will grow larger as more models and optimization tools emerge The configuration of Shardformer is controlled by class `ShardConfig`: -{{ autodoc:colossalai.shardformer.ShardConfig }} +{{ autodoc:colossalai.shardformer.shard.shard_config.ShardConfig }} If you want to enable Apex Fused Layernorm, please install `apex`. If you want to enable the usage of flash attention, please install `flash_attn`. diff --git a/docs/source/zh-Hans/features/shardformer.md b/docs/source/zh-Hans/features/shardformer.md index 02290f3d6eae..a56ca4d581d0 100644 --- a/docs/source/zh-Hans/features/shardformer.md +++ b/docs/source/zh-Hans/features/shardformer.md @@ -209,7 +209,7 @@ Author: [Baizhou Zhang](https://github.com/Fridge003), [Bin Jia](https://github. Shardformer的配置由类`ShardConfig`的参数控制: -{{ autodoc:colossalai.shardformer.ShardConfig }} +{{ autodoc:colossalai.shardformer.shard.shard_config.ShardConfig }} 如果您想启用 Apex Fused Layernorm,请安装 `apex`。如果您想启用 flash attention,请安装 `flash_attn`。此外,xFormers 的 `cutlass_op` 可以作为Flash Attention的补充优化方式。 From 6bbe66630d0675231ba93748733bd705b7bea361 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Wed, 18 Dec 2024 14:59:36 +0800 Subject: [PATCH 23/54] [fix] fix config --- docs/source/en/features/shardformer.md | 2 +- docs/source/zh-Hans/features/shardformer.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/en/features/shardformer.md b/docs/source/en/features/shardformer.md index 03fd688bd041..89238dc550da 100644 --- a/docs/source/en/features/shardformer.md +++ b/docs/source/en/features/shardformer.md @@ -213,7 +213,7 @@ The support matrix will grow larger as more models and optimization tools emerge The configuration of Shardformer is controlled by class `ShardConfig`: -{{ autodoc:colossalai.shardformer.shard.shard_config.ShardConfig }} +{{ autodoc:colossalai.shardformer.shard.shard_config }} If you want to enable Apex Fused Layernorm, please install `apex`. If you want to enable the usage of flash attention, please install `flash_attn`. diff --git a/docs/source/zh-Hans/features/shardformer.md b/docs/source/zh-Hans/features/shardformer.md index a56ca4d581d0..b6e8b2657db4 100644 --- a/docs/source/zh-Hans/features/shardformer.md +++ b/docs/source/zh-Hans/features/shardformer.md @@ -209,7 +209,7 @@ Author: [Baizhou Zhang](https://github.com/Fridge003), [Bin Jia](https://github. Shardformer的配置由类`ShardConfig`的参数控制: -{{ autodoc:colossalai.shardformer.shard.shard_config.ShardConfig }} +{{ autodoc:colossalai.shardformer.shard.shard_config }} 如果您想启用 Apex Fused Layernorm,请安装 `apex`。如果您想启用 flash attention,请安装 `flash_attn`。此外,xFormers 的 `cutlass_op` 可以作为Flash Attention的补充优化方式。 From 394636648f7e28554e917c461d561fca3a19eef1 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Wed, 18 Dec 2024 15:04:34 +0800 Subject: [PATCH 24/54] [fix] remove shardconfig --- docs/source/en/features/shardformer.md | 2 -- docs/source/zh-Hans/features/shardformer.md | 2 -- 2 files changed, 4 deletions(-) diff --git a/docs/source/en/features/shardformer.md b/docs/source/en/features/shardformer.md index 89238dc550da..62391467330e 100644 --- a/docs/source/en/features/shardformer.md +++ b/docs/source/en/features/shardformer.md @@ -213,8 +213,6 @@ The support matrix will grow larger as more models and optimization tools emerge The configuration of Shardformer is controlled by class `ShardConfig`: -{{ autodoc:colossalai.shardformer.shard.shard_config }} - If you want to enable Apex Fused Layernorm, please install `apex`. If you want to enable the usage of flash attention, please install `flash_attn`. In addition, xFormers's `cutlass_op` can serve as a backup for flash attention. diff --git a/docs/source/zh-Hans/features/shardformer.md b/docs/source/zh-Hans/features/shardformer.md index b6e8b2657db4..124ae155c891 100644 --- a/docs/source/zh-Hans/features/shardformer.md +++ b/docs/source/zh-Hans/features/shardformer.md @@ -209,8 +209,6 @@ Author: [Baizhou Zhang](https://github.com/Fridge003), [Bin Jia](https://github. Shardformer的配置由类`ShardConfig`的参数控制: -{{ autodoc:colossalai.shardformer.shard.shard_config }} - 如果您想启用 Apex Fused Layernorm,请安装 `apex`。如果您想启用 flash attention,请安装 `flash_attn`。此外,xFormers 的 `cutlass_op` 可以作为Flash Attention的补充优化方式。 ### 启动Shardformer From b99c733dc49eaf70400e751da2e03c3e2f3ab1a5 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Wed, 18 Dec 2024 15:11:28 +0800 Subject: [PATCH 25/54] [fix] fix doc --- docs/source/en/features/shardformer.md | 2 ++ .../en/features/zerobubble_pipeline_parallelism.md | 6 +----- docs/source/zh-Hans/features/shardformer.md | 2 ++ .../features/zerobubble_pipeline_parallelism.md | 10 +++++----- 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/docs/source/en/features/shardformer.md b/docs/source/en/features/shardformer.md index 62391467330e..40b8954b55b5 100644 --- a/docs/source/en/features/shardformer.md +++ b/docs/source/en/features/shardformer.md @@ -213,6 +213,8 @@ The support matrix will grow larger as more models and optimization tools emerge The configuration of Shardformer is controlled by class `ShardConfig`: +{{ autodoc:colossalai.shardformer.ShardConfig }} + If you want to enable Apex Fused Layernorm, please install `apex`. If you want to enable the usage of flash attention, please install `flash_attn`. In addition, xFormers's `cutlass_op` can serve as a backup for flash attention. diff --git a/docs/source/en/features/zerobubble_pipeline_parallelism.md b/docs/source/en/features/zerobubble_pipeline_parallelism.md index 5cb13528b5bf..604c7c533a5a 100644 --- a/docs/source/en/features/zerobubble_pipeline_parallelism.md +++ b/docs/source/en/features/zerobubble_pipeline_parallelism.md @@ -58,7 +58,7 @@ Then, we need to create the PipelineGraph and PipelineSchedule using the get_v_s x_cost represents the runtime consumed by operation x of each model chunk. x_mem represents the amount of memory consumed by the operation x of each model chunk. These parameters are estimated and filled in before the pipeline starts. In fact, better results can be obtained based on the runtime and memory cost during the real computation of the model. -In the following example, we assume that the computation times for the model's forward, reverse B, and reverse W are 1, 1, 1, respectively, and the p2p communication time is 1 +In the following example, we assume that the computation times for the model's forward, reverse B, and reverse W are 1, 1, 1, respectively, and the p2p communication time is 1. ```python # Init schedule h, a, s = config.hidden_size, config.num_attention_heads, 1024 @@ -183,10 +183,6 @@ Performance Benchmark -### 3.Fine-tuning Scheduler parameters - -```python -``` ## Model compatibility diff --git a/docs/source/zh-Hans/features/shardformer.md b/docs/source/zh-Hans/features/shardformer.md index 124ae155c891..02290f3d6eae 100644 --- a/docs/source/zh-Hans/features/shardformer.md +++ b/docs/source/zh-Hans/features/shardformer.md @@ -209,6 +209,8 @@ Author: [Baizhou Zhang](https://github.com/Fridge003), [Bin Jia](https://github. Shardformer的配置由类`ShardConfig`的参数控制: +{{ autodoc:colossalai.shardformer.ShardConfig }} + 如果您想启用 Apex Fused Layernorm,请安装 `apex`。如果您想启用 flash attention,请安装 `flash_attn`。此外,xFormers 的 `cutlass_op` 可以作为Flash Attention的补充优化方式。 ### 启动Shardformer diff --git a/docs/source/zh-Hans/features/zerobubble_pipeline_parallelism.md b/docs/source/zh-Hans/features/zerobubble_pipeline_parallelism.md index 9df095e50983..fd361dee048f 100644 --- a/docs/source/zh-Hans/features/zerobubble_pipeline_parallelism.md +++ b/docs/source/zh-Hans/features/zerobubble_pipeline_parallelism.md @@ -55,6 +55,11 @@ model = LlamaModel(configuration).cuda() optimizer = torch.optim.Adam(torch_model.parameters(), lr=1) ``` ### step 4.初始化流水线Schedule +然后,我们需要使用 get_v_schedule() 函数创建 PipelineGraph 和 PipelineSchedule。我们需要用以下参数初始化 PipelineGraph。 +x_cost 表示每个模型块的操作 x 所消耗的运行时间。 +x_mem 表示每个模型块的操作 x 所消耗的内存量。 +这些参数都是在流水线启动前估算并填入的。事实上,在模型的实际计算过程中,根据运行时间和内存成本可以获得更好的结果。 +在下面的例子中,我们假设模型的正向、反向 B 和反向 W 的计算时间分别为 1、1、1,p2p 通信时间为 1。 ```python # Init schedule h, a, s = config.hidden_size, config.num_attention_heads, 1024 @@ -123,7 +128,6 @@ for step in range(steps): ### 1.在ZeroBubble中使用元数据缓存 在初始化Plugin时输入 "enable_metadata_cache=True",以便在ZeroBubble管道中使用元数据缓存。 -Pass "enable_metadata_cache=True" when initialising the Plugin to use the Meta Cache with ZeroBubble Pipeline. ```python plugin = HybridParallelPlugin( pp_size=2, @@ -181,10 +185,6 @@ plugin = HybridParallelPlugin(
-### 3.Fine-tuning Scheduler parameters - -```python -``` ## 模型兼容性 From 99a78292c6a174ec286e8d33b3701c49051156c2 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Wed, 18 Dec 2024 16:57:20 +0800 Subject: [PATCH 26/54] [feat] add zbv doc string --- colossalai/pipeline/schedule/zero_bubble_pp.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py index 7cec5f003bae..edbb7118aa1a 100644 --- a/colossalai/pipeline/schedule/zero_bubble_pp.py +++ b/colossalai/pipeline/schedule/zero_bubble_pp.py @@ -38,6 +38,19 @@ def _wait_p2p(wait_handles: List[torch.cuda.Event]) -> None: class ZeroBubbleVPipeScheduler(PipelineSchedule): + r""" + ZeroBubbleVPipeScheduler + + Args: + stage_manager (PipelineStageManager): If using pipeline parallelism, it's necessary to specify a pipeline stage manager for inter-process communication in pipeline parallelism. Defaults to None, which means not using pipeline parallelism. + schedule (List[ScheduledNode]): Schedule for ZeroBubbleVPipe. + num_model_chunks (int) : The number of model chunk in a device. + num_microbatch (Optional[int]): The number of microbatch. + microbatch_size (Optional[int]): The size per microbatch. + enable_metadata_cache (bool): whether to enable metadata cache to acclerate communication. + overlap_p2p (bool): whether to use overlap_p2p. + """ + def __init__( self, stage_manager: PipelineStageManager, From f67ce86c689051a2e00058470328fb7e9e214ad3 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Wed, 18 Dec 2024 17:03:36 +0800 Subject: [PATCH 27/54] [fix] rm doc --- .../zerobubble_pipeline_parallelism.md | 236 ----------------- .../zerobubble_pipeline_parallelism.md | 238 ------------------ 2 files changed, 474 deletions(-) delete mode 100644 docs/source/en/features/zerobubble_pipeline_parallelism.md delete mode 100644 docs/source/zh-Hans/features/zerobubble_pipeline_parallelism.md diff --git a/docs/source/en/features/zerobubble_pipeline_parallelism.md b/docs/source/en/features/zerobubble_pipeline_parallelism.md deleted file mode 100644 index 604c7c533a5a..000000000000 --- a/docs/source/en/features/zerobubble_pipeline_parallelism.md +++ /dev/null @@ -1,236 +0,0 @@ -# ZeroBubble Pipeline Parallelism -Author: [Junwen Duan](https://github.com/duanjunwen), [Hongxin Liu](https://github.com/ver217) - -**Related Paper** -- [Zero Bubble Pipeline Parallelism](https://arxiv.org/abs/2401.10241) - -## Introduction -ZeroBubble (V Schedule): -Crucially, splitting B into two stages (also known as an activation gradient and a weight gradient) and a scheme like 1F1B1W can further reduce the bubble compared to the 1F1B scheme in earlier work. - -## Hands-On Practice -We now demonstrate how to use ZeroBubble with booster API with 4 GPUs. - -### step 1. Import libraries -```python -import torch -import torch.distributed as dist -import torch.nn as nn -from torch.testing import assert_close -from transformers.models.llama.configuration_llama import LlamaConfig -from transformers.models.llama.modeling_llama import LlamaModel - -import colossalai -from colossalai.booster.booster import Booster -from colossalai.booster.plugin.moe_hybrid_parallel_plugin import HybridParallelPlugin -from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler -``` - -### step 2. Initialize Distributed Environment and Parallism Group -```python -colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") -``` - -### step 3. Initialize Module, Optimizer, and Pipeline Schedule -Build our model and Optimizer. We created a Llama with 8 Decoder-Layer. Then, inite the PipelineGraph and Pipeline schedule by get_v_schedule() function. -```python -# Global Param -NUM_BATCH = 8 -NUM_TOK_PER_BATCH, NUM_EXPERTS = 4, 4 -NUM_LAYERS = 8 -HIDDEN_SIZE_PER_HEAD = 4 -NUM_HEADS = 4 -TOP_K = 1 -# Init Llama from huggingface -configuration = LlamaConfig( - hidden_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS, - intermediate_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2, - num_hidden_layers=NUM_LAYERS, - num_attention_heads=NUM_HEADS, - num_key_value_heads=NUM_HEADS, - attn_implementation="flash_attention_2", -) -model = LlamaModel(configuration).cuda() -optimizer = torch.optim.Adam(torch_model.parameters(), lr=1) -``` -### step 4. Initialize Pipeline Schedule -Then, we need to create the PipelineGraph and PipelineSchedule using the get_v_schedule() function. We need to initialise the PipelineGraph with the following parameters. -x_cost represents the runtime consumed by operation x of each model chunk. -x_mem represents the amount of memory consumed by the operation x of each model chunk. -These parameters are estimated and filled in before the pipeline starts. In fact, better results can be obtained based on the runtime and memory cost during the real computation of the model. -In the following example, we assume that the computation times for the model's forward, reverse B, and reverse W are 1, 1, 1, respectively, and the p2p communication time is 1. -```python -# Init schedule -h, a, s = config.hidden_size, config.num_attention_heads, 1024 -mem_f = 34 * h + 5 * a * s -mem_w = -32 * h -mem_b = -mem_w - mem_f -graph = PipelineGraph( - n_stage=pp_size, - n_micro=num_microbatches, - f_cost=1, - b_cost=1, - w_cost=1, - c_cost=1, - f_mem=mem_f, - b_mem=mem_b, - w_mem=mem_w, -) -zbv_schedule = graph.get_v_schedule() -``` - -### step 5.Init Booster -Pass pp_style="zbv" when initialising the Plugin to use the ZeroBubble Pipeline. -```python -plugin = HybridParallelPlugin( - pp_size=4, - num_microbatches=4, - tp_size=1, - sp_size=1, - zero_stage=1, - initial_scale=1, - find_unused_parameters=True, - pp_style="zbv", - scheduler_nodes=zbv_schedule, - num_model_chunks=2, -) - -dp_size = plugin.dp_size -booster = Booster(plugin=plugin) -``` - -### step 6.Train Your Model -```python -steps = 10 -for step in range(steps): - input_embeddings = torch.rand( - NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True - ).cuda() - dist.all_reduce( - input_embeddings, group=plugin.pp_group - ) - data_iter = iter([{"inputs_embeds": input_embeddings}]) - output = booster.execute_pipeline( - data_iter, - model, - lambda x, y: x.last_hidden_state.mean(), - optimizer, - return_loss=True, - return_outputs=True, - ) - optimizer.step() - optimizer.zero_grad() -``` - -## Advanced Practice -In ColossalAI, you can get better training performance by using MetaCache and HybridParallel with ZeroBubble. -### 1.Use MetaCache with ZeroBubble -Pass "enable_metadata_cache=True" when initialising the Plugin to use the Meta Cache with ZeroBubble Pipeline. -```python -plugin = HybridParallelPlugin( - pp_size=2, - num_microbatches=4, - tp_size=2, - sp_size=2, - zero_stage=1, - initial_scale=1, - enable_metadata_cache=True, - find_unused_parameters=True, - pp_style="zbv", - scheduler_nodes=zbv_schedule, - num_model_chunks=2, -) -``` - -### 2.HybridParallel with ZeroBubble -Pass pp_size, tp_size, sp_size when initialising the Plugin to use the HybridParallel with ZeroBubble Pipeline. -```python -plugin = HybridParallelPlugin( - pp_size=2, - num_microbatches=2, - tp_size=2, - sp_size=2, - zero_stage=1, - initial_scale=1, - find_unused_parameters=True, - pp_style="zbv", - scheduler_nodes=zbv_schedule, - num_model_chunks=2, -) -``` -Performance Benchmark -
- - - - - - - - - - - - - - - - - - - - - -
HybridParallel StrategyPipeline ParallelSequence Parallel + Pipeline ParallelData Parallel + Pipeline Parallel
With 1F1B15.27 samples/sec17.22 samples/sec14.06 samples/sec
With Zero Bubble17.36 samples/sec18.38 samples/sec14.44 samples/sec
- -## Model compatibility - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
Shardformer/ModelBertBlip2BloomChatglm2CommandDeepseekFalconGPT2GptjLlamaMistralOptQwen2SamT5VitWhisper
ZeroBubble✔️✔️✔️✔️✔️✔️✔️✔️✔️✔️✔️✔️✔️✔️✔️✔️✔️
- -## API Reference -{{ autodoc:colossalai.pipeline.schedule.zero_bubble_pp.ZeroBubbleVPipeScheduler }} - - diff --git a/docs/source/zh-Hans/features/zerobubble_pipeline_parallelism.md b/docs/source/zh-Hans/features/zerobubble_pipeline_parallelism.md deleted file mode 100644 index fd361dee048f..000000000000 --- a/docs/source/zh-Hans/features/zerobubble_pipeline_parallelism.md +++ /dev/null @@ -1,238 +0,0 @@ -# 零气泡流水线并行 -作者: [Junwen Duan](https://github.com/duanjunwen), [Hongxin Liu](https://github.com/ver217) - -**相关论文** -- [Zero Bubble Pipeline Parallelism](https://arxiv.org/abs/2401.10241) - -## 介绍 -零气泡(V Schedule): -与早期工作中的1F1B方案相比,零气泡流水线并行将B分成两个阶段(也称为激活梯度和权重梯度),形如1F1B1W这样的方案可以进一步减少气泡。 - -## 使用 -我们将演示如何在 4 个 GPU 上使用带有 booster API 的 ZeroBubble - -### step 1. 引用仓库 -```python -import torch -import torch.distributed as dist -import torch.nn as nn -from torch.testing import assert_close -from transformers.models.llama.configuration_llama import LlamaConfig -from transformers.models.llama.modeling_llama import LlamaModel - -import colossalai -from colossalai.booster.booster import Booster -from colossalai.booster.plugin.moe_hybrid_parallel_plugin import HybridParallelPlugin -from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler -``` - -### step 2. 初始化分布式环境 -```python -colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") -``` - -### step 3. 初始化模型优化器 -建立我们的模型和优化器 我们创建了一个带有8层Decoder-Layer的 Llama。然后,使用get_v_schedule()函数创建PipelineGraph和Pipeline schedule。 - -```python -# Global Param -NUM_BATCH = 8 -NUM_TOK_PER_BATCH, NUM_EXPERTS = 4, 4 -NUM_LAYERS = 8 -HIDDEN_SIZE_PER_HEAD = 4 -NUM_HEADS = 4 -TOP_K = 1 -# Init Llama from huggingface -configuration = LlamaConfig( - hidden_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS, - intermediate_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2, - num_hidden_layers=NUM_LAYERS, - num_attention_heads=NUM_HEADS, - num_key_value_heads=NUM_HEADS, - attn_implementation="flash_attention_2", -) -model = LlamaModel(configuration).cuda() -optimizer = torch.optim.Adam(torch_model.parameters(), lr=1) -``` -### step 4.初始化流水线Schedule -然后,我们需要使用 get_v_schedule() 函数创建 PipelineGraph 和 PipelineSchedule。我们需要用以下参数初始化 PipelineGraph。 -x_cost 表示每个模型块的操作 x 所消耗的运行时间。 -x_mem 表示每个模型块的操作 x 所消耗的内存量。 -这些参数都是在流水线启动前估算并填入的。事实上,在模型的实际计算过程中,根据运行时间和内存成本可以获得更好的结果。 -在下面的例子中,我们假设模型的正向、反向 B 和反向 W 的计算时间分别为 1、1、1,p2p 通信时间为 1。 -```python -# Init schedule -h, a, s = config.hidden_size, config.num_attention_heads, 1024 -mem_f = 34 * h + 5 * a * s -mem_w = -32 * h -mem_b = -mem_w - mem_f -graph = PipelineGraph( - n_stage=pp_size, - n_micro=num_microbatches, - f_cost=1, - b_cost=1, - w_cost=1, - c_cost=1, - f_mem=mem_f, - b_mem=mem_b, - w_mem=mem_w, -) -zbv_schedule = graph.get_v_schedule() -``` - -### step 5.初始化Booster -在初始化Plugin时输入pp_style="zbv",以使用ZeroBubble流水线并行。 -```python -plugin = HybridParallelPlugin( - pp_size=4, - num_microbatches=4, - tp_size=1, - sp_size=1, - zero_stage=1, - initial_scale=1, - find_unused_parameters=True, - pp_style="zbv", - scheduler_nodes=zbv_schedule, - num_model_chunks=2, -) - -dp_size = plugin.dp_size -booster = Booster(plugin=plugin) -``` - -### step 6.训练模型 -```python -steps = 10 -for step in range(steps): - input_embeddings = torch.rand( - NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True - ).cuda() - dist.all_reduce( - input_embeddings, group=plugin.pp_group - ) - data_iter = iter([{"inputs_embeds": input_embeddings}]) - output = booster.execute_pipeline( - data_iter, - model, - lambda x, y: x.last_hidden_state.mean(), - optimizer, - return_loss=True, - return_outputs=True, - ) - optimizer.step() - optimizer.zero_grad() -``` - -## 进阶使用技巧 -在 ColossalAI 中,通过使用MetaCache和混合并行的ZeroBubble,可以获得更好的训练性能。 - -### 1.在ZeroBubble中使用元数据缓存 -在初始化Plugin时输入 "enable_metadata_cache=True",以便在ZeroBubble管道中使用元数据缓存。 -```python -plugin = HybridParallelPlugin( - pp_size=2, - num_microbatches=4, - tp_size=2, - sp_size=2, - zero_stage=1, - initial_scale=1, - enable_metadata_cache=True, - find_unused_parameters=True, - pp_style="zbv", - scheduler_nodes=zbv_schedule, - num_model_chunks=2, -) -``` - -### 2.同时使用ZeroBubble和混合并行 -在初始化插件时传递 pp_size, tp_size, sp_size, 以便使用零气泡混合并行管道(HybridParallel with ZeroBubble Pipeline)。 -```python -plugin = HybridParallelPlugin( - pp_size=2, - num_microbatches=2, - tp_size=2, - sp_size=2, - zero_stage=1, - initial_scale=1, - find_unused_parameters=True, - pp_style="zbv", - scheduler_nodes=zbv_schedule, - num_model_chunks=2, -) -``` -性能指标 - - - - - - - - - - - - - - - - - - - - - - -
HybridParallel StrategyPipeline ParallelSequence Parallel + Pipeline ParallelData Parallel + Pipeline Parallel
With 1F1B15.27 samples/sec17.22 samples/sec14.06 samples/sec
With Zero Bubble17.36 samples/sec18.38 samples/sec14.44 samples/sec
- -## 模型兼容性 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
Shardformer/ModelBertBlip2BloomChatglm2CommandDeepseekFalconGPT2GptjLlamaMistralOptQwen2SamT5VitWhisper
ZeroBubble✔️✔️✔️✔️✔️✔️✔️✔️✔️✔️✔️✔️✔️✔️✔️✔️✔️
- -## API 参考 -{{ autodoc:colossalai.pipeline.schedule.zero_bubble_pp.ZeroBubbleVPipeScheduler }} - - From bbdcca10b7734b74d9781a476ebed5396e559d73 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Wed, 18 Dec 2024 17:14:26 +0800 Subject: [PATCH 28/54] [fix] fix doc --- .../zerobubble_pipeline_parallelism.md | 238 ++++++++++++++++++ .../zerobubble_pipeline_parallelism.md | 238 ++++++++++++++++++ 2 files changed, 476 insertions(+) create mode 100644 docs/source/en/features/zerobubble_pipeline_parallelism.md create mode 100644 docs/source/zh-Hans/features/zerobubble_pipeline_parallelism.md diff --git a/docs/source/en/features/zerobubble_pipeline_parallelism.md b/docs/source/en/features/zerobubble_pipeline_parallelism.md new file mode 100644 index 000000000000..fd361dee048f --- /dev/null +++ b/docs/source/en/features/zerobubble_pipeline_parallelism.md @@ -0,0 +1,238 @@ +# 零气泡流水线并行 +作者: [Junwen Duan](https://github.com/duanjunwen), [Hongxin Liu](https://github.com/ver217) + +**相关论文** +- [Zero Bubble Pipeline Parallelism](https://arxiv.org/abs/2401.10241) + +## 介绍 +零气泡(V Schedule): +与早期工作中的1F1B方案相比,零气泡流水线并行将B分成两个阶段(也称为激活梯度和权重梯度),形如1F1B1W这样的方案可以进一步减少气泡。 + +## 使用 +我们将演示如何在 4 个 GPU 上使用带有 booster API 的 ZeroBubble + +### step 1. 引用仓库 +```python +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.testing import assert_close +from transformers.models.llama.configuration_llama import LlamaConfig +from transformers.models.llama.modeling_llama import LlamaModel + +import colossalai +from colossalai.booster.booster import Booster +from colossalai.booster.plugin.moe_hybrid_parallel_plugin import HybridParallelPlugin +from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler +``` + +### step 2. 初始化分布式环境 +```python +colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") +``` + +### step 3. 初始化模型优化器 +建立我们的模型和优化器 我们创建了一个带有8层Decoder-Layer的 Llama。然后,使用get_v_schedule()函数创建PipelineGraph和Pipeline schedule。 + +```python +# Global Param +NUM_BATCH = 8 +NUM_TOK_PER_BATCH, NUM_EXPERTS = 4, 4 +NUM_LAYERS = 8 +HIDDEN_SIZE_PER_HEAD = 4 +NUM_HEADS = 4 +TOP_K = 1 +# Init Llama from huggingface +configuration = LlamaConfig( + hidden_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS, + intermediate_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2, + num_hidden_layers=NUM_LAYERS, + num_attention_heads=NUM_HEADS, + num_key_value_heads=NUM_HEADS, + attn_implementation="flash_attention_2", +) +model = LlamaModel(configuration).cuda() +optimizer = torch.optim.Adam(torch_model.parameters(), lr=1) +``` +### step 4.初始化流水线Schedule +然后,我们需要使用 get_v_schedule() 函数创建 PipelineGraph 和 PipelineSchedule。我们需要用以下参数初始化 PipelineGraph。 +x_cost 表示每个模型块的操作 x 所消耗的运行时间。 +x_mem 表示每个模型块的操作 x 所消耗的内存量。 +这些参数都是在流水线启动前估算并填入的。事实上,在模型的实际计算过程中,根据运行时间和内存成本可以获得更好的结果。 +在下面的例子中,我们假设模型的正向、反向 B 和反向 W 的计算时间分别为 1、1、1,p2p 通信时间为 1。 +```python +# Init schedule +h, a, s = config.hidden_size, config.num_attention_heads, 1024 +mem_f = 34 * h + 5 * a * s +mem_w = -32 * h +mem_b = -mem_w - mem_f +graph = PipelineGraph( + n_stage=pp_size, + n_micro=num_microbatches, + f_cost=1, + b_cost=1, + w_cost=1, + c_cost=1, + f_mem=mem_f, + b_mem=mem_b, + w_mem=mem_w, +) +zbv_schedule = graph.get_v_schedule() +``` + +### step 5.初始化Booster +在初始化Plugin时输入pp_style="zbv",以使用ZeroBubble流水线并行。 +```python +plugin = HybridParallelPlugin( + pp_size=4, + num_microbatches=4, + tp_size=1, + sp_size=1, + zero_stage=1, + initial_scale=1, + find_unused_parameters=True, + pp_style="zbv", + scheduler_nodes=zbv_schedule, + num_model_chunks=2, +) + +dp_size = plugin.dp_size +booster = Booster(plugin=plugin) +``` + +### step 6.训练模型 +```python +steps = 10 +for step in range(steps): + input_embeddings = torch.rand( + NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True + ).cuda() + dist.all_reduce( + input_embeddings, group=plugin.pp_group + ) + data_iter = iter([{"inputs_embeds": input_embeddings}]) + output = booster.execute_pipeline( + data_iter, + model, + lambda x, y: x.last_hidden_state.mean(), + optimizer, + return_loss=True, + return_outputs=True, + ) + optimizer.step() + optimizer.zero_grad() +``` + +## 进阶使用技巧 +在 ColossalAI 中,通过使用MetaCache和混合并行的ZeroBubble,可以获得更好的训练性能。 + +### 1.在ZeroBubble中使用元数据缓存 +在初始化Plugin时输入 "enable_metadata_cache=True",以便在ZeroBubble管道中使用元数据缓存。 +```python +plugin = HybridParallelPlugin( + pp_size=2, + num_microbatches=4, + tp_size=2, + sp_size=2, + zero_stage=1, + initial_scale=1, + enable_metadata_cache=True, + find_unused_parameters=True, + pp_style="zbv", + scheduler_nodes=zbv_schedule, + num_model_chunks=2, +) +``` + +### 2.同时使用ZeroBubble和混合并行 +在初始化插件时传递 pp_size, tp_size, sp_size, 以便使用零气泡混合并行管道(HybridParallel with ZeroBubble Pipeline)。 +```python +plugin = HybridParallelPlugin( + pp_size=2, + num_microbatches=2, + tp_size=2, + sp_size=2, + zero_stage=1, + initial_scale=1, + find_unused_parameters=True, + pp_style="zbv", + scheduler_nodes=zbv_schedule, + num_model_chunks=2, +) +``` +性能指标 + + + + + + + + + + + + + + + + + + + + + + +
HybridParallel StrategyPipeline ParallelSequence Parallel + Pipeline ParallelData Parallel + Pipeline Parallel
With 1F1B15.27 samples/sec17.22 samples/sec14.06 samples/sec
With Zero Bubble17.36 samples/sec18.38 samples/sec14.44 samples/sec
+ +## 模型兼容性 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Shardformer/ModelBertBlip2BloomChatglm2CommandDeepseekFalconGPT2GptjLlamaMistralOptQwen2SamT5VitWhisper
ZeroBubble✔️✔️✔️✔️✔️✔️✔️✔️✔️✔️✔️✔️✔️✔️✔️✔️✔️
+ +## API 参考 +{{ autodoc:colossalai.pipeline.schedule.zero_bubble_pp.ZeroBubbleVPipeScheduler }} + + diff --git a/docs/source/zh-Hans/features/zerobubble_pipeline_parallelism.md b/docs/source/zh-Hans/features/zerobubble_pipeline_parallelism.md new file mode 100644 index 000000000000..fd361dee048f --- /dev/null +++ b/docs/source/zh-Hans/features/zerobubble_pipeline_parallelism.md @@ -0,0 +1,238 @@ +# 零气泡流水线并行 +作者: [Junwen Duan](https://github.com/duanjunwen), [Hongxin Liu](https://github.com/ver217) + +**相关论文** +- [Zero Bubble Pipeline Parallelism](https://arxiv.org/abs/2401.10241) + +## 介绍 +零气泡(V Schedule): +与早期工作中的1F1B方案相比,零气泡流水线并行将B分成两个阶段(也称为激活梯度和权重梯度),形如1F1B1W这样的方案可以进一步减少气泡。 + +## 使用 +我们将演示如何在 4 个 GPU 上使用带有 booster API 的 ZeroBubble + +### step 1. 引用仓库 +```python +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.testing import assert_close +from transformers.models.llama.configuration_llama import LlamaConfig +from transformers.models.llama.modeling_llama import LlamaModel + +import colossalai +from colossalai.booster.booster import Booster +from colossalai.booster.plugin.moe_hybrid_parallel_plugin import HybridParallelPlugin +from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler +``` + +### step 2. 初始化分布式环境 +```python +colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") +``` + +### step 3. 初始化模型优化器 +建立我们的模型和优化器 我们创建了一个带有8层Decoder-Layer的 Llama。然后,使用get_v_schedule()函数创建PipelineGraph和Pipeline schedule。 + +```python +# Global Param +NUM_BATCH = 8 +NUM_TOK_PER_BATCH, NUM_EXPERTS = 4, 4 +NUM_LAYERS = 8 +HIDDEN_SIZE_PER_HEAD = 4 +NUM_HEADS = 4 +TOP_K = 1 +# Init Llama from huggingface +configuration = LlamaConfig( + hidden_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS, + intermediate_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2, + num_hidden_layers=NUM_LAYERS, + num_attention_heads=NUM_HEADS, + num_key_value_heads=NUM_HEADS, + attn_implementation="flash_attention_2", +) +model = LlamaModel(configuration).cuda() +optimizer = torch.optim.Adam(torch_model.parameters(), lr=1) +``` +### step 4.初始化流水线Schedule +然后,我们需要使用 get_v_schedule() 函数创建 PipelineGraph 和 PipelineSchedule。我们需要用以下参数初始化 PipelineGraph。 +x_cost 表示每个模型块的操作 x 所消耗的运行时间。 +x_mem 表示每个模型块的操作 x 所消耗的内存量。 +这些参数都是在流水线启动前估算并填入的。事实上,在模型的实际计算过程中,根据运行时间和内存成本可以获得更好的结果。 +在下面的例子中,我们假设模型的正向、反向 B 和反向 W 的计算时间分别为 1、1、1,p2p 通信时间为 1。 +```python +# Init schedule +h, a, s = config.hidden_size, config.num_attention_heads, 1024 +mem_f = 34 * h + 5 * a * s +mem_w = -32 * h +mem_b = -mem_w - mem_f +graph = PipelineGraph( + n_stage=pp_size, + n_micro=num_microbatches, + f_cost=1, + b_cost=1, + w_cost=1, + c_cost=1, + f_mem=mem_f, + b_mem=mem_b, + w_mem=mem_w, +) +zbv_schedule = graph.get_v_schedule() +``` + +### step 5.初始化Booster +在初始化Plugin时输入pp_style="zbv",以使用ZeroBubble流水线并行。 +```python +plugin = HybridParallelPlugin( + pp_size=4, + num_microbatches=4, + tp_size=1, + sp_size=1, + zero_stage=1, + initial_scale=1, + find_unused_parameters=True, + pp_style="zbv", + scheduler_nodes=zbv_schedule, + num_model_chunks=2, +) + +dp_size = plugin.dp_size +booster = Booster(plugin=plugin) +``` + +### step 6.训练模型 +```python +steps = 10 +for step in range(steps): + input_embeddings = torch.rand( + NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True + ).cuda() + dist.all_reduce( + input_embeddings, group=plugin.pp_group + ) + data_iter = iter([{"inputs_embeds": input_embeddings}]) + output = booster.execute_pipeline( + data_iter, + model, + lambda x, y: x.last_hidden_state.mean(), + optimizer, + return_loss=True, + return_outputs=True, + ) + optimizer.step() + optimizer.zero_grad() +``` + +## 进阶使用技巧 +在 ColossalAI 中,通过使用MetaCache和混合并行的ZeroBubble,可以获得更好的训练性能。 + +### 1.在ZeroBubble中使用元数据缓存 +在初始化Plugin时输入 "enable_metadata_cache=True",以便在ZeroBubble管道中使用元数据缓存。 +```python +plugin = HybridParallelPlugin( + pp_size=2, + num_microbatches=4, + tp_size=2, + sp_size=2, + zero_stage=1, + initial_scale=1, + enable_metadata_cache=True, + find_unused_parameters=True, + pp_style="zbv", + scheduler_nodes=zbv_schedule, + num_model_chunks=2, +) +``` + +### 2.同时使用ZeroBubble和混合并行 +在初始化插件时传递 pp_size, tp_size, sp_size, 以便使用零气泡混合并行管道(HybridParallel with ZeroBubble Pipeline)。 +```python +plugin = HybridParallelPlugin( + pp_size=2, + num_microbatches=2, + tp_size=2, + sp_size=2, + zero_stage=1, + initial_scale=1, + find_unused_parameters=True, + pp_style="zbv", + scheduler_nodes=zbv_schedule, + num_model_chunks=2, +) +``` +性能指标 + + + + + + + + + + + + + + + + + + + + + + +
HybridParallel StrategyPipeline ParallelSequence Parallel + Pipeline ParallelData Parallel + Pipeline Parallel
With 1F1B15.27 samples/sec17.22 samples/sec14.06 samples/sec
With Zero Bubble17.36 samples/sec18.38 samples/sec14.44 samples/sec
+ +## 模型兼容性 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Shardformer/ModelBertBlip2BloomChatglm2CommandDeepseekFalconGPT2GptjLlamaMistralOptQwen2SamT5VitWhisper
ZeroBubble✔️✔️✔️✔️✔️✔️✔️✔️✔️✔️✔️✔️✔️✔️✔️✔️✔️
+ +## API 参考 +{{ autodoc:colossalai.pipeline.schedule.zero_bubble_pp.ZeroBubbleVPipeScheduler }} + + From 9665f661ea7b7977de4c03e08d80135943a0c7d5 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Wed, 18 Dec 2024 17:23:04 +0800 Subject: [PATCH 29/54] [fix] empty zbv doc --- .../zerobubble_pipeline_parallelism.md | 238 ------------------ .../zerobubble_pipeline_parallelism.md | 238 ------------------ 2 files changed, 476 deletions(-) diff --git a/docs/source/en/features/zerobubble_pipeline_parallelism.md b/docs/source/en/features/zerobubble_pipeline_parallelism.md index fd361dee048f..e69de29bb2d1 100644 --- a/docs/source/en/features/zerobubble_pipeline_parallelism.md +++ b/docs/source/en/features/zerobubble_pipeline_parallelism.md @@ -1,238 +0,0 @@ -# 零气泡流水线并行 -作者: [Junwen Duan](https://github.com/duanjunwen), [Hongxin Liu](https://github.com/ver217) - -**相关论文** -- [Zero Bubble Pipeline Parallelism](https://arxiv.org/abs/2401.10241) - -## 介绍 -零气泡(V Schedule): -与早期工作中的1F1B方案相比,零气泡流水线并行将B分成两个阶段(也称为激活梯度和权重梯度),形如1F1B1W这样的方案可以进一步减少气泡。 - -## 使用 -我们将演示如何在 4 个 GPU 上使用带有 booster API 的 ZeroBubble - -### step 1. 引用仓库 -```python -import torch -import torch.distributed as dist -import torch.nn as nn -from torch.testing import assert_close -from transformers.models.llama.configuration_llama import LlamaConfig -from transformers.models.llama.modeling_llama import LlamaModel - -import colossalai -from colossalai.booster.booster import Booster -from colossalai.booster.plugin.moe_hybrid_parallel_plugin import HybridParallelPlugin -from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler -``` - -### step 2. 初始化分布式环境 -```python -colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") -``` - -### step 3. 初始化模型优化器 -建立我们的模型和优化器 我们创建了一个带有8层Decoder-Layer的 Llama。然后,使用get_v_schedule()函数创建PipelineGraph和Pipeline schedule。 - -```python -# Global Param -NUM_BATCH = 8 -NUM_TOK_PER_BATCH, NUM_EXPERTS = 4, 4 -NUM_LAYERS = 8 -HIDDEN_SIZE_PER_HEAD = 4 -NUM_HEADS = 4 -TOP_K = 1 -# Init Llama from huggingface -configuration = LlamaConfig( - hidden_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS, - intermediate_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2, - num_hidden_layers=NUM_LAYERS, - num_attention_heads=NUM_HEADS, - num_key_value_heads=NUM_HEADS, - attn_implementation="flash_attention_2", -) -model = LlamaModel(configuration).cuda() -optimizer = torch.optim.Adam(torch_model.parameters(), lr=1) -``` -### step 4.初始化流水线Schedule -然后,我们需要使用 get_v_schedule() 函数创建 PipelineGraph 和 PipelineSchedule。我们需要用以下参数初始化 PipelineGraph。 -x_cost 表示每个模型块的操作 x 所消耗的运行时间。 -x_mem 表示每个模型块的操作 x 所消耗的内存量。 -这些参数都是在流水线启动前估算并填入的。事实上,在模型的实际计算过程中,根据运行时间和内存成本可以获得更好的结果。 -在下面的例子中,我们假设模型的正向、反向 B 和反向 W 的计算时间分别为 1、1、1,p2p 通信时间为 1。 -```python -# Init schedule -h, a, s = config.hidden_size, config.num_attention_heads, 1024 -mem_f = 34 * h + 5 * a * s -mem_w = -32 * h -mem_b = -mem_w - mem_f -graph = PipelineGraph( - n_stage=pp_size, - n_micro=num_microbatches, - f_cost=1, - b_cost=1, - w_cost=1, - c_cost=1, - f_mem=mem_f, - b_mem=mem_b, - w_mem=mem_w, -) -zbv_schedule = graph.get_v_schedule() -``` - -### step 5.初始化Booster -在初始化Plugin时输入pp_style="zbv",以使用ZeroBubble流水线并行。 -```python -plugin = HybridParallelPlugin( - pp_size=4, - num_microbatches=4, - tp_size=1, - sp_size=1, - zero_stage=1, - initial_scale=1, - find_unused_parameters=True, - pp_style="zbv", - scheduler_nodes=zbv_schedule, - num_model_chunks=2, -) - -dp_size = plugin.dp_size -booster = Booster(plugin=plugin) -``` - -### step 6.训练模型 -```python -steps = 10 -for step in range(steps): - input_embeddings = torch.rand( - NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True - ).cuda() - dist.all_reduce( - input_embeddings, group=plugin.pp_group - ) - data_iter = iter([{"inputs_embeds": input_embeddings}]) - output = booster.execute_pipeline( - data_iter, - model, - lambda x, y: x.last_hidden_state.mean(), - optimizer, - return_loss=True, - return_outputs=True, - ) - optimizer.step() - optimizer.zero_grad() -``` - -## 进阶使用技巧 -在 ColossalAI 中,通过使用MetaCache和混合并行的ZeroBubble,可以获得更好的训练性能。 - -### 1.在ZeroBubble中使用元数据缓存 -在初始化Plugin时输入 "enable_metadata_cache=True",以便在ZeroBubble管道中使用元数据缓存。 -```python -plugin = HybridParallelPlugin( - pp_size=2, - num_microbatches=4, - tp_size=2, - sp_size=2, - zero_stage=1, - initial_scale=1, - enable_metadata_cache=True, - find_unused_parameters=True, - pp_style="zbv", - scheduler_nodes=zbv_schedule, - num_model_chunks=2, -) -``` - -### 2.同时使用ZeroBubble和混合并行 -在初始化插件时传递 pp_size, tp_size, sp_size, 以便使用零气泡混合并行管道(HybridParallel with ZeroBubble Pipeline)。 -```python -plugin = HybridParallelPlugin( - pp_size=2, - num_microbatches=2, - tp_size=2, - sp_size=2, - zero_stage=1, - initial_scale=1, - find_unused_parameters=True, - pp_style="zbv", - scheduler_nodes=zbv_schedule, - num_model_chunks=2, -) -``` -性能指标 - - - - - - - - - - - - - - - - - - - - - - -
HybridParallel StrategyPipeline ParallelSequence Parallel + Pipeline ParallelData Parallel + Pipeline Parallel
With 1F1B15.27 samples/sec17.22 samples/sec14.06 samples/sec
With Zero Bubble17.36 samples/sec18.38 samples/sec14.44 samples/sec
- -## 模型兼容性 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
Shardformer/ModelBertBlip2BloomChatglm2CommandDeepseekFalconGPT2GptjLlamaMistralOptQwen2SamT5VitWhisper
ZeroBubble✔️✔️✔️✔️✔️✔️✔️✔️✔️✔️✔️✔️✔️✔️✔️✔️✔️
- -## API 参考 -{{ autodoc:colossalai.pipeline.schedule.zero_bubble_pp.ZeroBubbleVPipeScheduler }} - - diff --git a/docs/source/zh-Hans/features/zerobubble_pipeline_parallelism.md b/docs/source/zh-Hans/features/zerobubble_pipeline_parallelism.md index fd361dee048f..e69de29bb2d1 100644 --- a/docs/source/zh-Hans/features/zerobubble_pipeline_parallelism.md +++ b/docs/source/zh-Hans/features/zerobubble_pipeline_parallelism.md @@ -1,238 +0,0 @@ -# 零气泡流水线并行 -作者: [Junwen Duan](https://github.com/duanjunwen), [Hongxin Liu](https://github.com/ver217) - -**相关论文** -- [Zero Bubble Pipeline Parallelism](https://arxiv.org/abs/2401.10241) - -## 介绍 -零气泡(V Schedule): -与早期工作中的1F1B方案相比,零气泡流水线并行将B分成两个阶段(也称为激活梯度和权重梯度),形如1F1B1W这样的方案可以进一步减少气泡。 - -## 使用 -我们将演示如何在 4 个 GPU 上使用带有 booster API 的 ZeroBubble - -### step 1. 引用仓库 -```python -import torch -import torch.distributed as dist -import torch.nn as nn -from torch.testing import assert_close -from transformers.models.llama.configuration_llama import LlamaConfig -from transformers.models.llama.modeling_llama import LlamaModel - -import colossalai -from colossalai.booster.booster import Booster -from colossalai.booster.plugin.moe_hybrid_parallel_plugin import HybridParallelPlugin -from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler -``` - -### step 2. 初始化分布式环境 -```python -colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") -``` - -### step 3. 初始化模型优化器 -建立我们的模型和优化器 我们创建了一个带有8层Decoder-Layer的 Llama。然后,使用get_v_schedule()函数创建PipelineGraph和Pipeline schedule。 - -```python -# Global Param -NUM_BATCH = 8 -NUM_TOK_PER_BATCH, NUM_EXPERTS = 4, 4 -NUM_LAYERS = 8 -HIDDEN_SIZE_PER_HEAD = 4 -NUM_HEADS = 4 -TOP_K = 1 -# Init Llama from huggingface -configuration = LlamaConfig( - hidden_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS, - intermediate_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2, - num_hidden_layers=NUM_LAYERS, - num_attention_heads=NUM_HEADS, - num_key_value_heads=NUM_HEADS, - attn_implementation="flash_attention_2", -) -model = LlamaModel(configuration).cuda() -optimizer = torch.optim.Adam(torch_model.parameters(), lr=1) -``` -### step 4.初始化流水线Schedule -然后,我们需要使用 get_v_schedule() 函数创建 PipelineGraph 和 PipelineSchedule。我们需要用以下参数初始化 PipelineGraph。 -x_cost 表示每个模型块的操作 x 所消耗的运行时间。 -x_mem 表示每个模型块的操作 x 所消耗的内存量。 -这些参数都是在流水线启动前估算并填入的。事实上,在模型的实际计算过程中,根据运行时间和内存成本可以获得更好的结果。 -在下面的例子中,我们假设模型的正向、反向 B 和反向 W 的计算时间分别为 1、1、1,p2p 通信时间为 1。 -```python -# Init schedule -h, a, s = config.hidden_size, config.num_attention_heads, 1024 -mem_f = 34 * h + 5 * a * s -mem_w = -32 * h -mem_b = -mem_w - mem_f -graph = PipelineGraph( - n_stage=pp_size, - n_micro=num_microbatches, - f_cost=1, - b_cost=1, - w_cost=1, - c_cost=1, - f_mem=mem_f, - b_mem=mem_b, - w_mem=mem_w, -) -zbv_schedule = graph.get_v_schedule() -``` - -### step 5.初始化Booster -在初始化Plugin时输入pp_style="zbv",以使用ZeroBubble流水线并行。 -```python -plugin = HybridParallelPlugin( - pp_size=4, - num_microbatches=4, - tp_size=1, - sp_size=1, - zero_stage=1, - initial_scale=1, - find_unused_parameters=True, - pp_style="zbv", - scheduler_nodes=zbv_schedule, - num_model_chunks=2, -) - -dp_size = plugin.dp_size -booster = Booster(plugin=plugin) -``` - -### step 6.训练模型 -```python -steps = 10 -for step in range(steps): - input_embeddings = torch.rand( - NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True - ).cuda() - dist.all_reduce( - input_embeddings, group=plugin.pp_group - ) - data_iter = iter([{"inputs_embeds": input_embeddings}]) - output = booster.execute_pipeline( - data_iter, - model, - lambda x, y: x.last_hidden_state.mean(), - optimizer, - return_loss=True, - return_outputs=True, - ) - optimizer.step() - optimizer.zero_grad() -``` - -## 进阶使用技巧 -在 ColossalAI 中,通过使用MetaCache和混合并行的ZeroBubble,可以获得更好的训练性能。 - -### 1.在ZeroBubble中使用元数据缓存 -在初始化Plugin时输入 "enable_metadata_cache=True",以便在ZeroBubble管道中使用元数据缓存。 -```python -plugin = HybridParallelPlugin( - pp_size=2, - num_microbatches=4, - tp_size=2, - sp_size=2, - zero_stage=1, - initial_scale=1, - enable_metadata_cache=True, - find_unused_parameters=True, - pp_style="zbv", - scheduler_nodes=zbv_schedule, - num_model_chunks=2, -) -``` - -### 2.同时使用ZeroBubble和混合并行 -在初始化插件时传递 pp_size, tp_size, sp_size, 以便使用零气泡混合并行管道(HybridParallel with ZeroBubble Pipeline)。 -```python -plugin = HybridParallelPlugin( - pp_size=2, - num_microbatches=2, - tp_size=2, - sp_size=2, - zero_stage=1, - initial_scale=1, - find_unused_parameters=True, - pp_style="zbv", - scheduler_nodes=zbv_schedule, - num_model_chunks=2, -) -``` -性能指标 - - - - - - - - - - - - - - - - - - - - - - -
HybridParallel StrategyPipeline ParallelSequence Parallel + Pipeline ParallelData Parallel + Pipeline Parallel
With 1F1B15.27 samples/sec17.22 samples/sec14.06 samples/sec
With Zero Bubble17.36 samples/sec18.38 samples/sec14.44 samples/sec
- -## 模型兼容性 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
Shardformer/ModelBertBlip2BloomChatglm2CommandDeepseekFalconGPT2GptjLlamaMistralOptQwen2SamT5VitWhisper
ZeroBubble✔️✔️✔️✔️✔️✔️✔️✔️✔️✔️✔️✔️✔️✔️✔️✔️✔️
- -## API 参考 -{{ autodoc:colossalai.pipeline.schedule.zero_bubble_pp.ZeroBubbleVPipeScheduler }} - - From 568e2c56a9fbe7274111296622859c141348cbec Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Fri, 20 Dec 2024 10:53:39 +0800 Subject: [PATCH 30/54] [fix] ifx torch version --- .../zerobubble_pipeline_parallelism.md | 239 ++++++++++++++++++ .../zerobubble_pipeline_parallelism.md | 238 +++++++++++++++++ requirements/requirements.txt | 2 +- 3 files changed, 478 insertions(+), 1 deletion(-) diff --git a/docs/source/en/features/zerobubble_pipeline_parallelism.md b/docs/source/en/features/zerobubble_pipeline_parallelism.md index e69de29bb2d1..8f676de96459 100644 --- a/docs/source/en/features/zerobubble_pipeline_parallelism.md +++ b/docs/source/en/features/zerobubble_pipeline_parallelism.md @@ -0,0 +1,239 @@ +# ZeroBubble Pipeline Parallelism +Author: [Junwen Duan](https://github.com/duanjunwen), [Hongxin Liu](https://github.com/ver217) + +**Related Paper** +- [Zero Bubble Pipeline Parallelism](https://arxiv.org/abs/2401.10241) + +## Introduction +ZeroBubble (V Schedule): +Crucially, splitting B into two stages (also known as an activation gradient and a weight gradient) and a scheme like 1F1B1W can further reduce the bubble compared to the 1F1B scheme in earlier work. + +## Hands-On Practice +We now demonstrate how to use ZeroBubble with booster API with 4 GPUs. + +### step 1. Import libraries +```python +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.testing import assert_close +from transformers.models.llama.configuration_llama import LlamaConfig +from transformers.models.llama.modeling_llama import LlamaModel + +import colossalai +from colossalai.booster.booster import Booster +from colossalai.booster.plugin.moe_hybrid_parallel_plugin import HybridParallelPlugin +from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler +``` + +### step 2. Initialize Distributed Environment and Parallism Group +```python +colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") +``` + +### step 3. Initialize Module, Optimizer, and Pipeline Schedule +Build our model and Optimizer. We created a Llama with 8 Decoder-Layer. Then, inite the PipelineGraph and Pipeline schedule by get_v_schedule() function. +```python +# Global Param +NUM_BATCH = 8 +NUM_TOK_PER_BATCH, NUM_EXPERTS = 4, 4 +NUM_LAYERS = 8 +HIDDEN_SIZE_PER_HEAD = 4 +NUM_HEADS = 4 +TOP_K = 1 +# Init Llama from huggingface +configuration = LlamaConfig( + hidden_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS, + intermediate_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2, + num_hidden_layers=NUM_LAYERS, + num_attention_heads=NUM_HEADS, + num_key_value_heads=NUM_HEADS, + attn_implementation="flash_attention_2", +) +model = LlamaModel(configuration).cuda() +optimizer = torch.optim.Adam(torch_model.parameters(), lr=1) +``` +### step 4. Initialize Module, Optimizer, and Pipeline Schedul +Then, we need to create the PipelineGraph and PipelineSchedule using the get_v_schedule() function. We need to initialise the PipelineGraph with the following parameters. +x_cost represents the runtime consumed by operation x of each model chunk. +x_mem represents the amount of memory consumed by the operation x of each model chunk. +These parameters are estimated and filled in before the pipeline starts. In fact, better results can be obtained based on the runtime and memory cost during the real computation of the model. +In the following example, we assume that the computation times for the model's forward, reverse B, and reverse W are 1, 1, 1, respectively, and the p2p communication time is 1. +```python +# Init schedule +h, a, s = config.hidden_size, config.num_attention_heads, 1024 +mem_f = 34 * h + 5 * a * s +mem_w = -32 * h +mem_b = -mem_w - mem_f +graph = PipelineGraph( + n_stage=pp_size, + n_micro=num_microbatches, + f_cost=1, + b_cost=1, + w_cost=1, + c_cost=1, + f_mem=mem_f, + b_mem=mem_b, + w_mem=mem_w, +) +zbv_schedule = graph.get_v_schedule() +``` + +### step 5.Init Booster +Pass pp_style="zbv" when initialising the Plugin to use the ZeroBubble Pipeline. +```python +plugin = HybridParallelPlugin( + pp_size=4, + num_microbatches=4, + tp_size=1, + sp_size=1, + zero_stage=1, + initial_scale=1, + find_unused_parameters=True, + pp_style="zbv", + scheduler_nodes=zbv_schedule, + num_model_chunks=2, +) + +dp_size = plugin.dp_size +booster = Booster(plugin=plugin) +``` + +### step 6.Train Your Model +```python +steps = 10 +for step in range(steps): + input_embeddings = torch.rand( + NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True + ).cuda() + dist.all_reduce( + input_embeddings, group=plugin.pp_group + ) + data_iter = iter([{"inputs_embeds": input_embeddings}]) + output = booster.execute_pipeline( + data_iter, + model, + lambda x, y: x.last_hidden_state.mean(), + optimizer, + return_loss=True, + return_outputs=True, + ) + optimizer.step() + optimizer.zero_grad() +``` + +## Advanced Practice +In ColossalAI, you can get better training performance by using MetaCache and HybridParallel with ZeroBubble. +### 1.Use MetaCache with ZeroBubble +Pass "enable_metadata_cache=True" when initialising the Plugin to use the Meta Cache with ZeroBubble Pipeline. +```python +plugin = HybridParallelPlugin( + pp_size=2, + num_microbatches=4, + tp_size=2, + sp_size=2, + zero_stage=1, + initial_scale=1, + enable_metadata_cache=True, + find_unused_parameters=True, + pp_style="zbv", + scheduler_nodes=zbv_schedule, + num_model_chunks=2, +) +``` + +### 2.HybridParallel with ZeroBubble +Pass pp_size, tp_size, sp_size when initialising the Plugin to use the HybridParallel with ZeroBubble Pipeline. +```python +plugin = HybridParallelPlugin( + pp_size=2, + num_microbatches=2, + tp_size=2, + sp_size=2, + zero_stage=1, + initial_scale=1, + find_unused_parameters=True, + pp_style="zbv", + scheduler_nodes=zbv_schedule, + num_model_chunks=2, +) +``` +Performance Benchmark + + + + + + + + + + + + + + + + + + + + + + +
HybridParallel StrategyPipeline ParallelSequence Parallel + Pipeline ParallelData Parallel + Pipeline Parallel
With 1F1B15.27 samples/sec17.22 samples/sec14.06 samples/sec
With Zero Bubble17.36 samples/sec18.38 samples/sec14.44 samples/sec
+ +### 3.Fine-tuning Scheduler parameters + +```python +``` +## Model compatibility + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Shardformer/ModelBertBlip2BloomChatglm2CommandDeepseekFalconGPT2GptjLlamaMistralOptQwen2SamT5VitWhisper
ZeroBubble✔️✔️✔️✔️✔️✔️✔️✔️✔️✔️✔️✔️✔️✔️✔️✔️
+ +## API Reference +{{ autodoc:colossalai.pipeline.schedule.zero_bubble_pp.py.ZeroBubbleVPipeScheduler }} + + diff --git a/docs/source/zh-Hans/features/zerobubble_pipeline_parallelism.md b/docs/source/zh-Hans/features/zerobubble_pipeline_parallelism.md index e69de29bb2d1..fd361dee048f 100644 --- a/docs/source/zh-Hans/features/zerobubble_pipeline_parallelism.md +++ b/docs/source/zh-Hans/features/zerobubble_pipeline_parallelism.md @@ -0,0 +1,238 @@ +# 零气泡流水线并行 +作者: [Junwen Duan](https://github.com/duanjunwen), [Hongxin Liu](https://github.com/ver217) + +**相关论文** +- [Zero Bubble Pipeline Parallelism](https://arxiv.org/abs/2401.10241) + +## 介绍 +零气泡(V Schedule): +与早期工作中的1F1B方案相比,零气泡流水线并行将B分成两个阶段(也称为激活梯度和权重梯度),形如1F1B1W这样的方案可以进一步减少气泡。 + +## 使用 +我们将演示如何在 4 个 GPU 上使用带有 booster API 的 ZeroBubble + +### step 1. 引用仓库 +```python +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.testing import assert_close +from transformers.models.llama.configuration_llama import LlamaConfig +from transformers.models.llama.modeling_llama import LlamaModel + +import colossalai +from colossalai.booster.booster import Booster +from colossalai.booster.plugin.moe_hybrid_parallel_plugin import HybridParallelPlugin +from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler +``` + +### step 2. 初始化分布式环境 +```python +colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") +``` + +### step 3. 初始化模型优化器 +建立我们的模型和优化器 我们创建了一个带有8层Decoder-Layer的 Llama。然后,使用get_v_schedule()函数创建PipelineGraph和Pipeline schedule。 + +```python +# Global Param +NUM_BATCH = 8 +NUM_TOK_PER_BATCH, NUM_EXPERTS = 4, 4 +NUM_LAYERS = 8 +HIDDEN_SIZE_PER_HEAD = 4 +NUM_HEADS = 4 +TOP_K = 1 +# Init Llama from huggingface +configuration = LlamaConfig( + hidden_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS, + intermediate_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2, + num_hidden_layers=NUM_LAYERS, + num_attention_heads=NUM_HEADS, + num_key_value_heads=NUM_HEADS, + attn_implementation="flash_attention_2", +) +model = LlamaModel(configuration).cuda() +optimizer = torch.optim.Adam(torch_model.parameters(), lr=1) +``` +### step 4.初始化流水线Schedule +然后,我们需要使用 get_v_schedule() 函数创建 PipelineGraph 和 PipelineSchedule。我们需要用以下参数初始化 PipelineGraph。 +x_cost 表示每个模型块的操作 x 所消耗的运行时间。 +x_mem 表示每个模型块的操作 x 所消耗的内存量。 +这些参数都是在流水线启动前估算并填入的。事实上,在模型的实际计算过程中,根据运行时间和内存成本可以获得更好的结果。 +在下面的例子中,我们假设模型的正向、反向 B 和反向 W 的计算时间分别为 1、1、1,p2p 通信时间为 1。 +```python +# Init schedule +h, a, s = config.hidden_size, config.num_attention_heads, 1024 +mem_f = 34 * h + 5 * a * s +mem_w = -32 * h +mem_b = -mem_w - mem_f +graph = PipelineGraph( + n_stage=pp_size, + n_micro=num_microbatches, + f_cost=1, + b_cost=1, + w_cost=1, + c_cost=1, + f_mem=mem_f, + b_mem=mem_b, + w_mem=mem_w, +) +zbv_schedule = graph.get_v_schedule() +``` + +### step 5.初始化Booster +在初始化Plugin时输入pp_style="zbv",以使用ZeroBubble流水线并行。 +```python +plugin = HybridParallelPlugin( + pp_size=4, + num_microbatches=4, + tp_size=1, + sp_size=1, + zero_stage=1, + initial_scale=1, + find_unused_parameters=True, + pp_style="zbv", + scheduler_nodes=zbv_schedule, + num_model_chunks=2, +) + +dp_size = plugin.dp_size +booster = Booster(plugin=plugin) +``` + +### step 6.训练模型 +```python +steps = 10 +for step in range(steps): + input_embeddings = torch.rand( + NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True + ).cuda() + dist.all_reduce( + input_embeddings, group=plugin.pp_group + ) + data_iter = iter([{"inputs_embeds": input_embeddings}]) + output = booster.execute_pipeline( + data_iter, + model, + lambda x, y: x.last_hidden_state.mean(), + optimizer, + return_loss=True, + return_outputs=True, + ) + optimizer.step() + optimizer.zero_grad() +``` + +## 进阶使用技巧 +在 ColossalAI 中,通过使用MetaCache和混合并行的ZeroBubble,可以获得更好的训练性能。 + +### 1.在ZeroBubble中使用元数据缓存 +在初始化Plugin时输入 "enable_metadata_cache=True",以便在ZeroBubble管道中使用元数据缓存。 +```python +plugin = HybridParallelPlugin( + pp_size=2, + num_microbatches=4, + tp_size=2, + sp_size=2, + zero_stage=1, + initial_scale=1, + enable_metadata_cache=True, + find_unused_parameters=True, + pp_style="zbv", + scheduler_nodes=zbv_schedule, + num_model_chunks=2, +) +``` + +### 2.同时使用ZeroBubble和混合并行 +在初始化插件时传递 pp_size, tp_size, sp_size, 以便使用零气泡混合并行管道(HybridParallel with ZeroBubble Pipeline)。 +```python +plugin = HybridParallelPlugin( + pp_size=2, + num_microbatches=2, + tp_size=2, + sp_size=2, + zero_stage=1, + initial_scale=1, + find_unused_parameters=True, + pp_style="zbv", + scheduler_nodes=zbv_schedule, + num_model_chunks=2, +) +``` +性能指标 + + + + + + + + + + + + + + + + + + + + + + +
HybridParallel StrategyPipeline ParallelSequence Parallel + Pipeline ParallelData Parallel + Pipeline Parallel
With 1F1B15.27 samples/sec17.22 samples/sec14.06 samples/sec
With Zero Bubble17.36 samples/sec18.38 samples/sec14.44 samples/sec
+ +## 模型兼容性 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Shardformer/ModelBertBlip2BloomChatglm2CommandDeepseekFalconGPT2GptjLlamaMistralOptQwen2SamT5VitWhisper
ZeroBubble✔️✔️✔️✔️✔️✔️✔️✔️✔️✔️✔️✔️✔️✔️✔️✔️✔️
+ +## API 参考 +{{ autodoc:colossalai.pipeline.schedule.zero_bubble_pp.ZeroBubbleVPipeScheduler }} + + diff --git a/requirements/requirements.txt b/requirements/requirements.txt index f357c45fde64..82f76923d008 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -8,7 +8,7 @@ click fabric contexttimer ninja -torch>=2.2.0,<=2.4.1 +torch>=2.5.0 safetensors einops pydantic From f8dc1508422c30b6f9474566f5f17e7611bf166d Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Fri, 20 Dec 2024 10:57:02 +0800 Subject: [PATCH 31/54] [fix] fix torch version --- requirements/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 82f76923d008..ebd9a25767f7 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -8,7 +8,7 @@ click fabric contexttimer ninja -torch>=2.5.0 +torch<=2.5.0 safetensors einops pydantic From 1481b8d5abb0dd3518e1526070c68239f7f45a46 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Fri, 20 Dec 2024 11:00:02 +0800 Subject: [PATCH 32/54] [fix] fix torch versions --- requirements/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/requirements.txt b/requirements/requirements.txt index ebd9a25767f7..94b60a9aac9b 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -8,7 +8,7 @@ click fabric contexttimer ninja -torch<=2.5.0 +torch==2.5.0 safetensors einops pydantic From cb52e2826a67115f11935b79a394d827d3e9304b Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Fri, 20 Dec 2024 11:13:37 +0800 Subject: [PATCH 33/54] [fix] fix torch versions --- requirements/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 94b60a9aac9b..f357c45fde64 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -8,7 +8,7 @@ click fabric contexttimer ninja -torch==2.5.0 +torch>=2.2.0,<=2.4.1 safetensors einops pydantic From 30e65e7a614ee846aea64717795c78ca25c61666 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Mon, 23 Dec 2024 11:20:28 +0800 Subject: [PATCH 34/54] [fix] fix pyramid versions --- requirements/requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements/requirements.txt b/requirements/requirements.txt index f357c45fde64..5a3e158f5a92 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -24,3 +24,4 @@ fastapi uvicorn==0.29.0 galore_torch diffusers==0.29.0 +pyramid==1.5 From 541664a12b049ad6f13e35ef66c55b56982575b3 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Mon, 23 Dec 2024 13:40:09 +0800 Subject: [PATCH 35/54] [fix] fix pyramid, zope version --- requirements/requirements.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 5a3e158f5a92..50ff1b47e442 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -24,4 +24,5 @@ fastapi uvicorn==0.29.0 galore_torch diffusers==0.29.0 -pyramid==1.5 +pyramid<=1.10.7 +zope<=5.5.2 From e5928840da43369caa8f13aa2a078df81c1d90b1 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Mon, 23 Dec 2024 17:50:16 +0800 Subject: [PATCH 36/54] [fix] try fix workflow --- .github/workflows/doc_check_on_pr.yml | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/.github/workflows/doc_check_on_pr.yml b/.github/workflows/doc_check_on_pr.yml index ee8a82128dd7..4f0b57fd08bc 100644 --- a/.github/workflows/doc_check_on_pr.yml +++ b/.github/workflows/doc_check_on_pr.yml @@ -73,4 +73,14 @@ jobs: cd ColossalAI-Documentation pip install -v ./doc-build/third_party/hf-doc-builder pip install -v ./doc-build - bash ./scripts/build.sh + set -euo pipefail # fail early + SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )" + DOC_BUILD_DIR="${SCRIPT_DIR}/../doc-build" + DOCUSAURUS_DIR="${SCRIPT_DIR}/../docusaurus" + cd "${DOC_BUILD_DIR}" + docer extract -o hpcaitech -p ColossalAI + docer autodoc -o hpcaitech -p ColossalAI + docer docusaurus -d ../docusaurus + cd "${DOCUSAURUS_DIR}" + yarn install + yarn build From 3b0669a5620ba24e8212847a79a6e61804f5d35e Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Mon, 23 Dec 2024 17:56:33 +0800 Subject: [PATCH 37/54] [fix] try import ShardConfig in yml --- .github/workflows/doc_check_on_pr.yml | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/.github/workflows/doc_check_on_pr.yml b/.github/workflows/doc_check_on_pr.yml index 4f0b57fd08bc..ac22b5c676fa 100644 --- a/.github/workflows/doc_check_on_pr.yml +++ b/.github/workflows/doc_check_on_pr.yml @@ -73,14 +73,5 @@ jobs: cd ColossalAI-Documentation pip install -v ./doc-build/third_party/hf-doc-builder pip install -v ./doc-build - set -euo pipefail # fail early - SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )" - DOC_BUILD_DIR="${SCRIPT_DIR}/../doc-build" - DOCUSAURUS_DIR="${SCRIPT_DIR}/../docusaurus" - cd "${DOC_BUILD_DIR}" - docer extract -o hpcaitech -p ColossalAI - docer autodoc -o hpcaitech -p ColossalAI - docer docusaurus -d ../docusaurus - cd "${DOCUSAURUS_DIR}" - yarn install - yarn build + list: | + from colossalai.shardformer import ShardConfig From 1cd60a005daef9f2d3c685f153ad7a7ffc50c8f5 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Mon, 23 Dec 2024 18:03:29 +0800 Subject: [PATCH 38/54] [fix] fix workflow --- .github/workflows/doc_check_on_pr.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/doc_check_on_pr.yml b/.github/workflows/doc_check_on_pr.yml index ac22b5c676fa..fcf49c778d98 100644 --- a/.github/workflows/doc_check_on_pr.yml +++ b/.github/workflows/doc_check_on_pr.yml @@ -73,5 +73,6 @@ jobs: cd ColossalAI-Documentation pip install -v ./doc-build/third_party/hf-doc-builder pip install -v ./doc-build + bash ./scripts/build.sh list: | from colossalai.shardformer import ShardConfig From 573d5ced46939476f9767ad23ac1504f79bbb9d0 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Mon, 23 Dec 2024 18:13:14 +0800 Subject: [PATCH 39/54] [fix] fix workflow --- .github/workflows/doc_check_on_pr.yml | 2 -- 1 file changed, 2 deletions(-) diff --git a/.github/workflows/doc_check_on_pr.yml b/.github/workflows/doc_check_on_pr.yml index fcf49c778d98..ee8a82128dd7 100644 --- a/.github/workflows/doc_check_on_pr.yml +++ b/.github/workflows/doc_check_on_pr.yml @@ -74,5 +74,3 @@ jobs: pip install -v ./doc-build/third_party/hf-doc-builder pip install -v ./doc-build bash ./scripts/build.sh - list: | - from colossalai.shardformer import ShardConfig From 938bf6db0c29072f2400437aec0f0eba3a42b2f5 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Mon, 23 Dec 2024 18:18:03 +0800 Subject: [PATCH 40/54] [fix] fix workflow --- .github/workflows/doc_check_on_pr.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/doc_check_on_pr.yml b/.github/workflows/doc_check_on_pr.yml index ee8a82128dd7..68e13a971e7e 100644 --- a/.github/workflows/doc_check_on_pr.yml +++ b/.github/workflows/doc_check_on_pr.yml @@ -58,6 +58,7 @@ jobs: # there is no main branch, so it's safe to checkout the main branch from the merged branch # docer will rebase the remote main branch to the merged branch, so we have to config user - name: Make the merged branch main + run: | cd ColossalAI git checkout -b main From 90d1d538a656697e2d649a5780c4ce38bb804e6d Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Mon, 23 Dec 2024 19:04:49 +0800 Subject: [PATCH 41/54] [fix] fix workflow --- .github/workflows/doc_check_on_pr.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/doc_check_on_pr.yml b/.github/workflows/doc_check_on_pr.yml index 68e13a971e7e..33932bf62779 100644 --- a/.github/workflows/doc_check_on_pr.yml +++ b/.github/workflows/doc_check_on_pr.yml @@ -74,4 +74,3 @@ jobs: cd ColossalAI-Documentation pip install -v ./doc-build/third_party/hf-doc-builder pip install -v ./doc-build - bash ./scripts/build.sh From aab6275eb58b027014ac0a000899208c0d04a04f Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Mon, 23 Dec 2024 19:27:46 +0800 Subject: [PATCH 42/54] [fix] fix ci --- .github/workflows/doc_check_on_pr.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/doc_check_on_pr.yml b/.github/workflows/doc_check_on_pr.yml index 33932bf62779..68e13a971e7e 100644 --- a/.github/workflows/doc_check_on_pr.yml +++ b/.github/workflows/doc_check_on_pr.yml @@ -74,3 +74,4 @@ jobs: cd ColossalAI-Documentation pip install -v ./doc-build/third_party/hf-doc-builder pip install -v ./doc-build + bash ./scripts/build.sh From 63b7db59b6b6597ea4b9502719fe16e89ecf33d2 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Mon, 23 Dec 2024 19:32:49 +0800 Subject: [PATCH 43/54] [fix] fix zbv doc --- docs/source/en/features/zerobubble_pipeline_parallelism.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/features/zerobubble_pipeline_parallelism.md b/docs/source/en/features/zerobubble_pipeline_parallelism.md index 8f676de96459..ec67fc8f3bf5 100644 --- a/docs/source/en/features/zerobubble_pipeline_parallelism.md +++ b/docs/source/en/features/zerobubble_pipeline_parallelism.md @@ -234,6 +234,6 @@ Performance Benchmark ## API Reference -{{ autodoc:colossalai.pipeline.schedule.zero_bubble_pp.py.ZeroBubbleVPipeScheduler }} +{{ autodoc:colossalai.pipeline.schedule.zero_bubble_pp.ZeroBubbleVPipeScheduler }} From 7fb23a57d83cd9f876d9024079ab922f5e771852 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Tue, 24 Dec 2024 14:25:10 +0800 Subject: [PATCH 44/54] [fix] fix param for qkv linear, gpt2fused linear; fix requirments; --- colossalai/shardformer/layer/__init__.py | 2 +- colossalai/shardformer/layer/qkv_fused_linear.py | 4 ---- requirements/requirements.txt | 2 -- 3 files changed, 1 insertion(+), 7 deletions(-) diff --git a/colossalai/shardformer/layer/__init__.py b/colossalai/shardformer/layer/__init__.py index 850001d04227..0bd1b60923e9 100644 --- a/colossalai/shardformer/layer/__init__.py +++ b/colossalai/shardformer/layer/__init__.py @@ -21,7 +21,7 @@ "LinearWithGradAccum", "Linear1D_Col", "Linear1D_Row", - "GPT2FusedLinearConv1D_Col", + "GPT2FusedLinearConv", "GPT2FusedLinearConv1D_Row", "GPT2FusedLinearConv1D_Col", "DropoutForParallelInput", diff --git a/colossalai/shardformer/layer/qkv_fused_linear.py b/colossalai/shardformer/layer/qkv_fused_linear.py index 5edcfd9b8a00..b4c9c4e36868 100644 --- a/colossalai/shardformer/layer/qkv_fused_linear.py +++ b/colossalai/shardformer/layer/qkv_fused_linear.py @@ -832,7 +832,6 @@ def __init__( bias_: Optional[Parameter] = None, weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), - fp8_communication: bool = False, use_zbv: bool = False, ): super().__init__() @@ -846,7 +845,6 @@ def __init__( self.device = device self.split_sizes = split_sizes self.process_group = process_group - self.fp8_communication = fp8_communication self.use_zbv = use_zbv assert ( @@ -1246,7 +1244,6 @@ def __init__( bias_: Optional[Parameter] = None, weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), - fp8_communication: bool = False, use_zbv: bool = False, ): super().__init__() @@ -1257,7 +1254,6 @@ def __init__( self.seq_parallel_dim = seq_parallel_dim self.skip_bias_add = skip_bias_add self.device = device - self.fp8_communication = fp8_communication self.use_zbv = use_zbv if skip_bias_add and not bias: diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 50ff1b47e442..f357c45fde64 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -24,5 +24,3 @@ fastapi uvicorn==0.29.0 galore_torch diffusers==0.29.0 -pyramid<=1.10.7 -zope<=5.5.2 From f0a8d78edc04d4c957693f87cc1ba87d61a00307 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Tue, 24 Dec 2024 15:53:01 +0800 Subject: [PATCH 45/54] [fix] fix policy use fused_linear --- colossalai/shardformer/layer/qkv_fused_linear.py | 4 ++-- colossalai/shardformer/policies/blip2.py | 1 - colossalai/shardformer/policies/sam.py | 1 - 3 files changed, 2 insertions(+), 4 deletions(-) diff --git a/colossalai/shardformer/layer/qkv_fused_linear.py b/colossalai/shardformer/layer/qkv_fused_linear.py index b4c9c4e36868..e3aaa3a4635c 100644 --- a/colossalai/shardformer/layer/qkv_fused_linear.py +++ b/colossalai/shardformer/layer/qkv_fused_linear.py @@ -669,7 +669,6 @@ def __init__( bias_: Optional[Parameter] = None, weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), - fp8_communication: bool = False, use_zbv: bool = False, ): super().__init__() @@ -680,7 +679,6 @@ def __init__( self.seq_parallel_dim = seq_parallel_dim self.skip_bias_add = skip_bias_add self.device = device - self.fp8_communication = fp8_communication self.use_zbv = use_zbv if skip_bias_add and not bias: @@ -832,6 +830,7 @@ def __init__( bias_: Optional[Parameter] = None, weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + fp8_communication: bool = False, use_zbv: bool = False, ): super().__init__() @@ -845,6 +844,7 @@ def __init__( self.device = device self.split_sizes = split_sizes self.process_group = process_group + self.fp8_communication = fp8_communication self.use_zbv = use_zbv assert ( diff --git a/colossalai/shardformer/policies/blip2.py b/colossalai/shardformer/policies/blip2.py index 246c4616009e..a2f582e9daa6 100644 --- a/colossalai/shardformer/policies/blip2.py +++ b/colossalai/shardformer/policies/blip2.py @@ -310,7 +310,6 @@ def module_policy(self): target_module=col_nn.FusedLinear, kwargs={ "split_sizes": [self.model.config.vision_config.hidden_size] * 3, - "fp8_communication": self.shard_config.fp8_communication, "use_zbv": use_zbv, }, ), diff --git a/colossalai/shardformer/policies/sam.py b/colossalai/shardformer/policies/sam.py index 237db386930f..f37167afffff 100644 --- a/colossalai/shardformer/policies/sam.py +++ b/colossalai/shardformer/policies/sam.py @@ -251,7 +251,6 @@ def module_policy(self): target_module=col_nn.FusedLinear, kwargs={ "split_sizes": [self.model.config.vision_config.hidden_size] * 3, - "fp8_communication": self.shard_config.fp8_communication, "use_zbv": use_zbv, }, ), From ff316c9ddaf0369b578ed406db85791ca350cffd Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Tue, 24 Dec 2024 17:10:36 +0800 Subject: [PATCH 46/54] [fix] fix weight grad none, err caused by weight ptr change --- colossalai/pipeline/weight_grad_store.py | 22 ++++++++++++++----- colossalai/shardformer/layer/_operation.py | 21 ++++++++++-------- .../test_gpt2_qkv_fused_linear_1d.py | 5 ++--- 3 files changed, 31 insertions(+), 17 deletions(-) diff --git a/colossalai/pipeline/weight_grad_store.py b/colossalai/pipeline/weight_grad_store.py index 8c1b64e0ee57..3117621982eb 100644 --- a/colossalai/pipeline/weight_grad_store.py +++ b/colossalai/pipeline/weight_grad_store.py @@ -20,11 +20,23 @@ def pop(cls, chunk=0): if cls.weight_grad_queue[chunk].qsize() > 0: stored_grads = cls.weight_grad_queue[chunk].get() for total_input, grad_output, weight, func in stored_grads: - if weight.grad is not None: - func(total_input, grad_output, weight.grad) - # for first bwd; weight.grad is None, assign grad_weight to weight.grad + if isinstance(weight, tuple): + # In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. + # View will lead to weight ptr change + # weight_cal & weight_origin in tuple, weight_cal use to cal dw, weight_origin use to update + weight_cal, weight_origin = weight + if weight_origin.grad is not None: + func(total_input, grad_output, weight_origin) + # for first bwd; weight.grad is None, assign grad_weight to weight.grad + else: + grad_weight = func(total_input, grad_output) + weight_origin.grad = grad_weight else: - grad_weight = func(total_input, grad_output) - weight.grad = grad_weight + if weight.grad is not None: + func(total_input, grad_output, weight.grad) + # for first bwd; weight.grad is None, assign grad_weight to weight.grad + else: + grad_weight = func(total_input, grad_output) + weight.grad = grad_weight else: raise Exception("Pop empty queue.") diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index f2c02d249f04..9104126a43e1 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -96,6 +96,7 @@ def backward(ctx, grad_output): use_zbv = ctx.use_zbv # In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. + weight_origin = weight weight = weight.view(weight.shape) if bias is not None: bias = bias.view(bias.shape) @@ -130,7 +131,7 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f WeightGradStore.put( total_input, grad_output, - weight, + (weight, weight_origin), functools.partial( execute_w_pass_grad_accum, wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32, @@ -141,7 +142,7 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f WeightGradStore.put( total_input, grad_output, - weight, + (weight, weight_origin), functools.partial( execute_w_pass_grad_accum, wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16, @@ -164,7 +165,7 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f WeightGradStore.put( total_input, grad_output, - weight, + (weight, weight_origin), functools.partial( execute_w_pass, wgrad_gemm_func=torch.matmul, @@ -212,6 +213,7 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f return wgrad_gemm_func(_input_.t(), _grad_output_) # In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. + weight_origin = weight weight = weight.view(weight.shape) if bias is not None: bias = bias.view(bias.shape) @@ -232,7 +234,7 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f WeightGradStore.put( total_input, grad_output, - weight, + (weight, weight_origin), functools.partial( execute_w_pass_grad_accum, wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32, @@ -243,7 +245,7 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f WeightGradStore.put( total_input, grad_output, - weight, + (weight, weight_origin), functools.partial( execute_w_pass_grad_accum, wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16, @@ -266,7 +268,7 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f WeightGradStore.put( total_input, grad_output, - weight, + (weight, weight_origin), functools.partial( execute_w_pass, wgrad_gemm_func=torch.matmul, @@ -1026,6 +1028,7 @@ def backward(ctx, grad_output): use_zbv = ctx.use_zbv # In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm + weight_origin = weight weight = weight.view(weight.shape) if use_bias: bias = bias.view(bias.shape) @@ -1064,7 +1067,7 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f WeightGradStore.put( total_input, grad_output, - weight, + (weight, weight_origin), functools.partial( execute_w_pass_grad_accum, wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32, @@ -1075,7 +1078,7 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f WeightGradStore.put( total_input, grad_output, - weight, + (weight, weight_origin), functools.partial( execute_w_pass_grad_accum, wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16, @@ -1098,7 +1101,7 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f WeightGradStore.put( total_input, grad_output, - weight, + (weight, weight_origin), functools.partial( execute_w_pass, wgrad_gemm_func=torch.matmul, diff --git a/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py b/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py index 53cd9721e6c9..5868933b5925 100644 --- a/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py +++ b/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py @@ -185,11 +185,10 @@ def check_linear_conv_1d_with_weight_grad_store(lazy_init: bool, seq_parallel_mo # check the input gradients & weight gradients assert_close(out.grad, gather_out.grad) - # TODO:linear_base.weight.grad is None; But not none in WeightGradStore - # assert_close(linear.weight.grad, linear_base.weight.grad) + assert_close(linear.weight.grad, linear_base.weight.grad) -@parameterize("lazy_init", [False, True]) +@parameterize("lazy_init", [False]) @parameterize("seq_parallel_mode", ["split_gather", None]) def check_gpt2_qkv_fused_linear_1d(lazy_init: bool, seq_parallel_mode: bool): check_linear_conv_1d_col(lazy_init, seq_parallel_mode) From f52c36e9448c0617912e0a5dd0efcfab3de4080a Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Tue, 24 Dec 2024 17:18:24 +0800 Subject: [PATCH 47/54] [fix] fix comm in WeightGradStore --- colossalai/pipeline/weight_grad_store.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colossalai/pipeline/weight_grad_store.py b/colossalai/pipeline/weight_grad_store.py index 3117621982eb..e2da7c9de521 100644 --- a/colossalai/pipeline/weight_grad_store.py +++ b/colossalai/pipeline/weight_grad_store.py @@ -24,7 +24,7 @@ def pop(cls, chunk=0): # In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. # View will lead to weight ptr change # weight_cal & weight_origin in tuple, weight_cal use to cal dw, weight_origin use to update - weight_cal, weight_origin = weight + _, weight_origin = weight if weight_origin.grad is not None: func(total_input, grad_output, weight_origin) # for first bwd; weight.grad is None, assign grad_weight to weight.grad From feca06e1057b2e8e63fc5963e479aa313990e8e8 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Wed, 25 Dec 2024 11:36:16 +0800 Subject: [PATCH 48/54] [fix] fix WeightGradStore pop param --- colossalai/pipeline/weight_grad_store.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colossalai/pipeline/weight_grad_store.py b/colossalai/pipeline/weight_grad_store.py index e2da7c9de521..1a9ef142156d 100644 --- a/colossalai/pipeline/weight_grad_store.py +++ b/colossalai/pipeline/weight_grad_store.py @@ -26,7 +26,7 @@ def pop(cls, chunk=0): # weight_cal & weight_origin in tuple, weight_cal use to cal dw, weight_origin use to update _, weight_origin = weight if weight_origin.grad is not None: - func(total_input, grad_output, weight_origin) + func(total_input, grad_output, weight_origin.grad) # for first bwd; weight.grad is None, assign grad_weight to weight.grad else: grad_weight = func(total_input, grad_output) From d74071ae6d08a988aae609918016d5cc06e7b2b1 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Wed, 25 Dec 2024 13:38:14 +0800 Subject: [PATCH 49/54] [fix] remove useless param in doc; fix gpt2 qkv test; --- docs/source/en/features/zerobubble_pipeline_parallelism.md | 3 +-- .../source/zh-Hans/features/zerobubble_pipeline_parallelism.md | 3 +-- .../test_layer/test_gpt2_qkv_fused_linear_1d.py | 2 +- 3 files changed, 3 insertions(+), 5 deletions(-) diff --git a/docs/source/en/features/zerobubble_pipeline_parallelism.md b/docs/source/en/features/zerobubble_pipeline_parallelism.md index ec67fc8f3bf5..c66cf5f77376 100644 --- a/docs/source/en/features/zerobubble_pipeline_parallelism.md +++ b/docs/source/en/features/zerobubble_pipeline_parallelism.md @@ -36,11 +36,10 @@ Build our model and Optimizer. We created a Llama with 8 Decoder-Layer. Then, in ```python # Global Param NUM_BATCH = 8 -NUM_TOK_PER_BATCH, NUM_EXPERTS = 4, 4 +NUM_TOK_PER_BATCH = 4 NUM_LAYERS = 8 HIDDEN_SIZE_PER_HEAD = 4 NUM_HEADS = 4 -TOP_K = 1 # Init Llama from huggingface configuration = LlamaConfig( hidden_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS, diff --git a/docs/source/zh-Hans/features/zerobubble_pipeline_parallelism.md b/docs/source/zh-Hans/features/zerobubble_pipeline_parallelism.md index fd361dee048f..4ab07c9b91b0 100644 --- a/docs/source/zh-Hans/features/zerobubble_pipeline_parallelism.md +++ b/docs/source/zh-Hans/features/zerobubble_pipeline_parallelism.md @@ -37,11 +37,10 @@ colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, ```python # Global Param NUM_BATCH = 8 -NUM_TOK_PER_BATCH, NUM_EXPERTS = 4, 4 +NUM_TOK_PER_BATCH = 4 NUM_LAYERS = 8 HIDDEN_SIZE_PER_HEAD = 4 NUM_HEADS = 4 -TOP_K = 1 # Init Llama from huggingface configuration = LlamaConfig( hidden_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS, diff --git a/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py b/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py index 5868933b5925..be1c24818424 100644 --- a/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py +++ b/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py @@ -188,7 +188,7 @@ def check_linear_conv_1d_with_weight_grad_store(lazy_init: bool, seq_parallel_mo assert_close(linear.weight.grad, linear_base.weight.grad) -@parameterize("lazy_init", [False]) +@parameterize("lazy_init", [False, True]) @parameterize("seq_parallel_mode", ["split_gather", None]) def check_gpt2_qkv_fused_linear_1d(lazy_init: bool, seq_parallel_mode: bool): check_linear_conv_1d_col(lazy_init, seq_parallel_mode) From c0b6fbc3cd910db4fd9900ef6fe6cd2d33882ddd Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Wed, 25 Dec 2024 14:02:52 +0800 Subject: [PATCH 50/54] [shardformer] simplify execute_w_pass_grad_accum; --- colossalai/shardformer/layer/_operation.py | 456 +++++++++++++-------- 1 file changed, 281 insertions(+), 175 deletions(-) diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index 9104126a43e1..e1e2acf3b52e 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -117,7 +117,13 @@ def backward(ctx, grad_output): # Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have # all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py - def execute_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_, wgrad_gemm_accum_func=None): + def execute_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_): + if total_input.dtype == torch.float32: + wgrad_gemm_accum_func = fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32 + elif total_input.dtype in (torch.float16, torch.bfloat16): + wgrad_gemm_accum_func = fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16 + else: + raise RuntimeError("Unsupported gradient type for gradient accumulation fusion") wgrad_gemm_accum_func(_grad_output_, _input_, _weight_main_grad_) def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_func=None): @@ -127,30 +133,39 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f if _grad_accum_fusion_available and weight.grad is not None: grad = weight.grad if use_zbv: - if grad.dtype == torch.float32: - WeightGradStore.put( - total_input, - grad_output, - (weight, weight_origin), - functools.partial( - execute_w_pass_grad_accum, - wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32, - ), - ) - grad_weight = None - elif grad.dtype in (torch.float16, torch.bfloat16): - WeightGradStore.put( - total_input, - grad_output, - (weight, weight_origin), - functools.partial( - execute_w_pass_grad_accum, - wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16, - ), - ) - grad_weight = None - else: - raise RuntimeError("Unsupported gradient type for gradient accumulation fusion") + WeightGradStore.put( + total_input, + grad_output, + (weight, weight_origin), + functools.partial( + execute_w_pass_grad_accum, + ), + ) + grad_weight = None + # if grad.dtype == torch.float32: + # WeightGradStore.put( + # total_input, + # grad_output, + # (weight, weight_origin), + # functools.partial( + # execute_w_pass_grad_accum, + # wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32, + # ), + # ) + # grad_weight = None + # elif grad.dtype in (torch.float16, torch.bfloat16): + # WeightGradStore.put( + # total_input, + # grad_output, + # (weight, weight_origin), + # functools.partial( + # execute_w_pass_grad_accum, + # wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16, + # ), + # ) + # grad_weight = None + # else: + # raise RuntimeError("Unsupported gradient type for gradient accumulation fusion") else: if grad.dtype == torch.float32: fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad) @@ -206,7 +221,13 @@ def backward(ctx, grad_output): use_bias = ctx.use_bias use_zbv = ctx.use_zbv - def execute_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_, wgrad_gemm_accum_func=None): + def execute_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_): + if total_input.dtype == torch.float32: + wgrad_gemm_accum_func = fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32 + elif total_input.dtype in (torch.float16, torch.bfloat16): + wgrad_gemm_accum_func = fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16 + else: + raise RuntimeError("Unsupported gradient type for gradient accumulation fusion") wgrad_gemm_accum_func(_grad_output_, _input_, _weight_main_grad_) def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_func=None): @@ -229,31 +250,41 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f # split dx & dw if _grad_accum_fusion_available and weight.grad is not None: grad = weight.grad + if use_zbv: - if grad.dtype == torch.float32: - WeightGradStore.put( - total_input, - grad_output, - (weight, weight_origin), - functools.partial( - execute_w_pass_grad_accum, - wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32, - ), - ) - grad_weight = None - elif grad.dtype in (torch.float16, torch.bfloat16): - WeightGradStore.put( - total_input, - grad_output, - (weight, weight_origin), - functools.partial( - execute_w_pass_grad_accum, - wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16, - ), - ) - grad_weight = None - else: - raise RuntimeError("Unsupported gradient type for gradient accumulation fusion") + WeightGradStore.put( + total_input, + grad_output, + (weight, weight_origin), + functools.partial( + execute_w_pass_grad_accum, + ), + ) + grad_weight = None + # if grad.dtype == torch.float32: + # WeightGradStore.put( + # total_input, + # grad_output, + # (weight, weight_origin), + # functools.partial( + # execute_w_pass_grad_accum, + # wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32, + # ), + # ) + # grad_weight = None + # elif grad.dtype in (torch.float16, torch.bfloat16): + # WeightGradStore.put( + # total_input, + # grad_output, + # (weight, weight_origin), + # functools.partial( + # execute_w_pass_grad_accum, + # wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16, + # ), + # ) + # grad_weight = None + # else: + # raise RuntimeError("Unsupported gradient type for gradient accumulation fusion") else: if grad.dtype == torch.float32: fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad) @@ -310,7 +341,13 @@ def backward(ctx, grad_output): fp8_communication = ctx.fp8_communication use_zbv = ctx.use_zbv - def execute_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_, wgrad_gemm_accum_func=None): + def execute_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_): + if total_input.dtype == torch.float32: + wgrad_gemm_accum_func = fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32 + elif total_input.dtype in (torch.float16, torch.bfloat16): + wgrad_gemm_accum_func = fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16 + else: + raise RuntimeError("Unsupported gradient type for gradient accumulation fusion") wgrad_gemm_accum_func(_input_, _grad_output_, _weight_main_grad_) def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_func=None): @@ -339,30 +376,39 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f if _grad_accum_fusion_available and weight.grad is not None: grad = weight.grad if use_zbv: - if grad.dtype == torch.float32: - WeightGradStore.put( - total_input, - grad_output, - weight, - functools.partial( - execute_w_pass_grad_accum, - wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32, - ), - ) - grad_weight = None - elif grad.dtype in (torch.float16, torch.bfloat16): - WeightGradStore.put( - total_input, - grad_output, - weight, - functools.partial( - execute_w_pass_grad_accum, - wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16, - ), - ) - grad_weight = None - else: - raise RuntimeError("Unsupported gradient type for gradient accumulation fusion") + WeightGradStore.put( + total_input, + grad_output, + weight, + functools.partial( + execute_w_pass_grad_accum, + ), + ) + grad_weight = None + # if grad.dtype == torch.float32: + # WeightGradStore.put( + # total_input, + # grad_output, + # weight, + # functools.partial( + # execute_w_pass_grad_accum, + # wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32, + # ), + # ) + # grad_weight = None + # elif grad.dtype in (torch.float16, torch.bfloat16): + # WeightGradStore.put( + # total_input, + # grad_output, + # weight, + # functools.partial( + # execute_w_pass_grad_accum, + # wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16, + # ), + # ) + # grad_weight = None + # else: + # raise RuntimeError("Unsupported gradient type for gradient accumulation fusion") else: if grad.dtype == torch.float32: fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad) @@ -418,7 +464,13 @@ def backward(ctx, grad_output): use_bias = ctx.use_bias use_zbv = ctx.use_zbv - def execute_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_, wgrad_gemm_accum_func=None): + def execute_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_): + if total_input.dtype == torch.float32: + wgrad_gemm_accum_func = fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32 + elif total_input.dtype in (torch.float16, torch.bfloat16): + wgrad_gemm_accum_func = fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16 + else: + raise RuntimeError("Unsupported gradient type for gradient accumulation fusion") wgrad_gemm_accum_func(_input_, _grad_output_, _weight_main_grad_) def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_func=None): @@ -439,30 +491,39 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f if _grad_accum_fusion_available and weight.grad is not None: grad = weight.grad if use_zbv: - if grad.dtype == torch.float32: - WeightGradStore.put( - total_input, - grad_output, - weight, - functools.partial( - execute_w_pass_grad_accum, - wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32, - ), - ) - grad_weight = None - elif grad.dtype in (torch.float16, torch.bfloat16): - WeightGradStore.put( - total_input, - grad_output, - weight, - functools.partial( - execute_w_pass_grad_accum, - wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16, - ), - ) - grad_weight = None - else: - raise RuntimeError("Unsupported gradient type for gradient accumulation fusion") + WeightGradStore.put( + total_input, + grad_output, + weight, + functools.partial( + execute_w_pass_grad_accum, + ), + ) + grad_weight = None + # if grad.dtype == torch.float32: + # WeightGradStore.put( + # total_input, + # grad_output, + # weight, + # functools.partial( + # execute_w_pass_grad_accum, + # wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32, + # ), + # ) + # grad_weight = None + # elif grad.dtype in (torch.float16, torch.bfloat16): + # WeightGradStore.put( + # total_input, + # grad_output, + # weight, + # functools.partial( + # execute_w_pass_grad_accum, + # wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16, + # ), + # ) + # grad_weight = None + # else: + # raise RuntimeError("Unsupported gradient type for gradient accumulation fusion") else: if grad.dtype == torch.float32: fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad) @@ -676,7 +737,13 @@ def backward(ctx, grad_output): # Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have # all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py - def execute_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_, wgrad_gemm_accum_func=None): + def execute_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_): + if total_input.dtype == torch.float32: + wgrad_gemm_accum_func = fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32 + elif total_input.dtype in (torch.float16, torch.bfloat16): + wgrad_gemm_accum_func = fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16 + else: + raise RuntimeError("Unsupported gradient type for gradient accumulation fusion") wgrad_gemm_accum_func(_input_, _grad_output_, _weight_main_grad_) def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_func=None): @@ -685,30 +752,39 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f if _grad_accum_fusion_available and weight.grad is not None: grad = weight.grad if use_zbv: - if grad.dtype == torch.float32: - WeightGradStore.put( - total_input, - grad_output, - weight, - functools.partial( - execute_w_pass_grad_accum, - wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32, - ), - ) - grad_weight = None - elif grad.dtype in (torch.float16, torch.bfloat16): - WeightGradStore.put( - total_input, - grad_output, - weight, - functools.partial( - execute_w_pass_grad_accum, - wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16, - ), - ) - grad_weight = None - else: - raise RuntimeError("Unsupported gradient type for gradient accumulation fusion") + WeightGradStore.put( + total_input, + grad_output, + weight, + functools.partial( + execute_w_pass_grad_accum, + ), + ) + grad_weight = None + # if grad.dtype == torch.float32: + # WeightGradStore.put( + # total_input, + # grad_output, + # weight, + # functools.partial( + # execute_w_pass_grad_accum, + # wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32, + # ), + # ) + # grad_weight = None + # elif grad.dtype in (torch.float16, torch.bfloat16): + # WeightGradStore.put( + # total_input, + # grad_output, + # weight, + # functools.partial( + # execute_w_pass_grad_accum, + # wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16, + # ), + # ) + # grad_weight = None + # else: + # raise RuntimeError("Unsupported gradient type for gradient accumulation fusion") else: if grad.dtype == torch.float32: fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad) @@ -872,7 +948,13 @@ def backward(ctx, grad_output): grad_output = grad_output.view(-1, grad_output.shape[-1]) total_input = total_input.reshape(-1, total_input.shape[-1]) - def execute_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_, wgrad_gemm_accum_func=None): + def execute_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_): + if total_input.dtype == torch.float32: + wgrad_gemm_accum_func = fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32 + elif total_input.dtype in (torch.float16, torch.bfloat16): + wgrad_gemm_accum_func = fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16 + else: + raise RuntimeError("Unsupported gradient type for gradient accumulation fusion") wgrad_gemm_accum_func(_input_, _grad_output_, _weight_main_grad_) def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_func=None): @@ -881,30 +963,39 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f if _grad_accum_fusion_available and weight.grad is not None: grad = weight.grad if use_zbv: - if grad.dtype == torch.float32: - WeightGradStore.put( - total_input, - grad_output, - weight, - functools.partial( - execute_w_pass_grad_accum, - wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32, - ), - ) - grad_weight = None - elif grad.dtype in (torch.float16, torch.bfloat16): - WeightGradStore.put( - total_input, - grad_output, - weight, - functools.partial( - execute_w_pass_grad_accum, - wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16, - ), - ) - grad_weight = None - else: - raise RuntimeError("Unsupported gradient type for gradient accumulation fusion") + WeightGradStore.put( + total_input, + grad_output, + weight, + functools.partial( + execute_w_pass_grad_accum, + ), + ) + grad_weight = None + # if grad.dtype == torch.float32: + # WeightGradStore.put( + # total_input, + # grad_output, + # weight, + # functools.partial( + # execute_w_pass_grad_accum, + # wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32, + # ), + # ) + # grad_weight = None + # elif grad.dtype in (torch.float16, torch.bfloat16): + # WeightGradStore.put( + # total_input, + # grad_output, + # weight, + # functools.partial( + # execute_w_pass_grad_accum, + # wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16, + # ), + # ) + # grad_weight = None + # else: + # raise RuntimeError("Unsupported gradient type for gradient accumulation fusion") else: if grad.dtype == torch.float32: fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad) @@ -1053,7 +1144,13 @@ def backward(ctx, grad_output): # Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have # all-reduce scheduled first and have GPU resources allocated - def execute_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_, wgrad_gemm_accum_func=None): + def execute_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_): + if total_input.dtype == torch.float32: + wgrad_gemm_accum_func = fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32 + elif total_input.dtype in (torch.float16, torch.bfloat16): + wgrad_gemm_accum_func = fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16 + else: + raise RuntimeError("Unsupported gradient type for gradient accumulation fusion") wgrad_gemm_accum_func(_grad_output_, _input_, _weight_main_grad_) def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_func=None): @@ -1063,30 +1160,39 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f if _grad_accum_fusion_available and weight.grad is not None: grad = weight.grad if use_zbv: - if grad.dtype == torch.float32: - WeightGradStore.put( - total_input, - grad_output, - (weight, weight_origin), - functools.partial( - execute_w_pass_grad_accum, - wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32, - ), - ) - grad_weight = None - elif grad.dtype in (torch.float16, torch.bfloat16): - WeightGradStore.put( - total_input, - grad_output, - (weight, weight_origin), - functools.partial( - execute_w_pass_grad_accum, - wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16, - ), - ) - grad_weight = None - else: - raise RuntimeError("Unsupported gradient type for gradient accumulation fusion") + WeightGradStore.put( + total_input, + grad_output, + (weight, weight_origin), + functools.partial( + execute_w_pass_grad_accum, + ), + ) + grad_weight = None + # if grad.dtype == torch.float32: + # WeightGradStore.put( + # total_input, + # grad_output, + # (weight, weight_origin), + # functools.partial( + # execute_w_pass_grad_accum, + # wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32, + # ), + # ) + # grad_weight = None + # elif grad.dtype in (torch.float16, torch.bfloat16): + # WeightGradStore.put( + # total_input, + # grad_output, + # (weight, weight_origin), + # functools.partial( + # execute_w_pass_grad_accum, + # wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16, + # ), + # ) + # grad_weight = None + # else: + # raise RuntimeError("Unsupported gradient type for gradient accumulation fusion") else: if grad.dtype == torch.float32: fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad) From 130b50caac9ffa1111ac108381f82f6ad7293bb2 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Wed, 25 Dec 2024 14:06:00 +0800 Subject: [PATCH 51/54] [fix] rm useless comments --- colossalai/shardformer/layer/_operation.py | 168 --------------------- 1 file changed, 168 deletions(-) diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index e1e2acf3b52e..353c93cca370 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -142,30 +142,6 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f ), ) grad_weight = None - # if grad.dtype == torch.float32: - # WeightGradStore.put( - # total_input, - # grad_output, - # (weight, weight_origin), - # functools.partial( - # execute_w_pass_grad_accum, - # wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32, - # ), - # ) - # grad_weight = None - # elif grad.dtype in (torch.float16, torch.bfloat16): - # WeightGradStore.put( - # total_input, - # grad_output, - # (weight, weight_origin), - # functools.partial( - # execute_w_pass_grad_accum, - # wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16, - # ), - # ) - # grad_weight = None - # else: - # raise RuntimeError("Unsupported gradient type for gradient accumulation fusion") else: if grad.dtype == torch.float32: fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad) @@ -261,30 +237,6 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f ), ) grad_weight = None - # if grad.dtype == torch.float32: - # WeightGradStore.put( - # total_input, - # grad_output, - # (weight, weight_origin), - # functools.partial( - # execute_w_pass_grad_accum, - # wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32, - # ), - # ) - # grad_weight = None - # elif grad.dtype in (torch.float16, torch.bfloat16): - # WeightGradStore.put( - # total_input, - # grad_output, - # (weight, weight_origin), - # functools.partial( - # execute_w_pass_grad_accum, - # wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16, - # ), - # ) - # grad_weight = None - # else: - # raise RuntimeError("Unsupported gradient type for gradient accumulation fusion") else: if grad.dtype == torch.float32: fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad) @@ -385,30 +337,6 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f ), ) grad_weight = None - # if grad.dtype == torch.float32: - # WeightGradStore.put( - # total_input, - # grad_output, - # weight, - # functools.partial( - # execute_w_pass_grad_accum, - # wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32, - # ), - # ) - # grad_weight = None - # elif grad.dtype in (torch.float16, torch.bfloat16): - # WeightGradStore.put( - # total_input, - # grad_output, - # weight, - # functools.partial( - # execute_w_pass_grad_accum, - # wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16, - # ), - # ) - # grad_weight = None - # else: - # raise RuntimeError("Unsupported gradient type for gradient accumulation fusion") else: if grad.dtype == torch.float32: fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad) @@ -500,30 +428,6 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f ), ) grad_weight = None - # if grad.dtype == torch.float32: - # WeightGradStore.put( - # total_input, - # grad_output, - # weight, - # functools.partial( - # execute_w_pass_grad_accum, - # wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32, - # ), - # ) - # grad_weight = None - # elif grad.dtype in (torch.float16, torch.bfloat16): - # WeightGradStore.put( - # total_input, - # grad_output, - # weight, - # functools.partial( - # execute_w_pass_grad_accum, - # wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16, - # ), - # ) - # grad_weight = None - # else: - # raise RuntimeError("Unsupported gradient type for gradient accumulation fusion") else: if grad.dtype == torch.float32: fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad) @@ -761,30 +665,6 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f ), ) grad_weight = None - # if grad.dtype == torch.float32: - # WeightGradStore.put( - # total_input, - # grad_output, - # weight, - # functools.partial( - # execute_w_pass_grad_accum, - # wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32, - # ), - # ) - # grad_weight = None - # elif grad.dtype in (torch.float16, torch.bfloat16): - # WeightGradStore.put( - # total_input, - # grad_output, - # weight, - # functools.partial( - # execute_w_pass_grad_accum, - # wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16, - # ), - # ) - # grad_weight = None - # else: - # raise RuntimeError("Unsupported gradient type for gradient accumulation fusion") else: if grad.dtype == torch.float32: fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad) @@ -972,30 +852,6 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f ), ) grad_weight = None - # if grad.dtype == torch.float32: - # WeightGradStore.put( - # total_input, - # grad_output, - # weight, - # functools.partial( - # execute_w_pass_grad_accum, - # wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32, - # ), - # ) - # grad_weight = None - # elif grad.dtype in (torch.float16, torch.bfloat16): - # WeightGradStore.put( - # total_input, - # grad_output, - # weight, - # functools.partial( - # execute_w_pass_grad_accum, - # wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16, - # ), - # ) - # grad_weight = None - # else: - # raise RuntimeError("Unsupported gradient type for gradient accumulation fusion") else: if grad.dtype == torch.float32: fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad) @@ -1169,30 +1025,6 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f ), ) grad_weight = None - # if grad.dtype == torch.float32: - # WeightGradStore.put( - # total_input, - # grad_output, - # (weight, weight_origin), - # functools.partial( - # execute_w_pass_grad_accum, - # wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32, - # ), - # ) - # grad_weight = None - # elif grad.dtype in (torch.float16, torch.bfloat16): - # WeightGradStore.put( - # total_input, - # grad_output, - # (weight, weight_origin), - # functools.partial( - # execute_w_pass_grad_accum, - # wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16, - # ), - # ) - # grad_weight = None - # else: - # raise RuntimeError("Unsupported gradient type for gradient accumulation fusion") else: if grad.dtype == torch.float32: fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad) From c4df1ccb33f4700ecfd283c23615ecc57d2706a7 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Wed, 25 Dec 2024 17:12:12 +0800 Subject: [PATCH 52/54] [shardformer] simplify execute_w_pass_grad_accum & execute_w_pass --- colossalai/shardformer/layer/_operation.py | 104 +++------------------ colossalai/shardformer/layer/utils.py | 37 ++++++++ 2 files changed, 50 insertions(+), 91 deletions(-) diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index 353c93cca370..0252f90e1c27 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -6,7 +6,13 @@ from colossalai.pipeline.weight_grad_store import WeightGradStore -from .utils import is_share_sp_tp +from .utils import ( + execute_conv1d_w_pass, + execute_conv1d_w_pass_grad_accum, + execute_w_pass, + execute_w_pass_grad_accum, + is_share_sp_tp, +) try: import fused_mix_prec_layer_norm_cuda @@ -117,18 +123,6 @@ def backward(ctx, grad_output): # Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have # all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py - def execute_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_): - if total_input.dtype == torch.float32: - wgrad_gemm_accum_func = fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32 - elif total_input.dtype in (torch.float16, torch.bfloat16): - wgrad_gemm_accum_func = fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16 - else: - raise RuntimeError("Unsupported gradient type for gradient accumulation fusion") - wgrad_gemm_accum_func(_grad_output_, _input_, _weight_main_grad_) - - def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_func=None): - return wgrad_gemm_func(_input_.t(), _grad_output_) - # split dx & dw if _grad_accum_fusion_available and weight.grad is not None: grad = weight.grad @@ -138,7 +132,7 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f grad_output, (weight, weight_origin), functools.partial( - execute_w_pass_grad_accum, + execute_conv1d_w_pass_grad_accum, ), ) grad_weight = None @@ -158,7 +152,7 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f grad_output, (weight, weight_origin), functools.partial( - execute_w_pass, + execute_conv1d_w_pass, wgrad_gemm_func=torch.matmul, ), ) @@ -197,18 +191,6 @@ def backward(ctx, grad_output): use_bias = ctx.use_bias use_zbv = ctx.use_zbv - def execute_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_): - if total_input.dtype == torch.float32: - wgrad_gemm_accum_func = fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32 - elif total_input.dtype in (torch.float16, torch.bfloat16): - wgrad_gemm_accum_func = fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16 - else: - raise RuntimeError("Unsupported gradient type for gradient accumulation fusion") - wgrad_gemm_accum_func(_grad_output_, _input_, _weight_main_grad_) - - def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_func=None): - return wgrad_gemm_func(_input_.t(), _grad_output_) - # In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. weight_origin = weight weight = weight.view(weight.shape) @@ -233,7 +215,7 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f grad_output, (weight, weight_origin), functools.partial( - execute_w_pass_grad_accum, + execute_conv1d_w_pass_grad_accum, ), ) grad_weight = None @@ -253,7 +235,7 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f grad_output, (weight, weight_origin), functools.partial( - execute_w_pass, + execute_conv1d_w_pass, wgrad_gemm_func=torch.matmul, ), ) @@ -293,18 +275,6 @@ def backward(ctx, grad_output): fp8_communication = ctx.fp8_communication use_zbv = ctx.use_zbv - def execute_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_): - if total_input.dtype == torch.float32: - wgrad_gemm_accum_func = fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32 - elif total_input.dtype in (torch.float16, torch.bfloat16): - wgrad_gemm_accum_func = fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16 - else: - raise RuntimeError("Unsupported gradient type for gradient accumulation fusion") - wgrad_gemm_accum_func(_input_, _grad_output_, _weight_main_grad_) - - def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_func=None): - return wgrad_gemm_func(_grad_output_.t(), _input_) - # In order to be hooked into Gemini's '__torch_function__', adding a view operation to bias. if use_bias: bias.view(bias.shape) @@ -392,18 +362,6 @@ def backward(ctx, grad_output): use_bias = ctx.use_bias use_zbv = ctx.use_zbv - def execute_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_): - if total_input.dtype == torch.float32: - wgrad_gemm_accum_func = fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32 - elif total_input.dtype in (torch.float16, torch.bfloat16): - wgrad_gemm_accum_func = fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16 - else: - raise RuntimeError("Unsupported gradient type for gradient accumulation fusion") - wgrad_gemm_accum_func(_input_, _grad_output_, _weight_main_grad_) - - def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_func=None): - return wgrad_gemm_func(_grad_output_.t(), _input_) - # In order to be hooked into Gemini's '__torch_function__', adding a view operation to bias. if use_bias: bias.view(bias.shape) @@ -641,18 +599,6 @@ def backward(ctx, grad_output): # Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have # all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py - def execute_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_): - if total_input.dtype == torch.float32: - wgrad_gemm_accum_func = fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32 - elif total_input.dtype in (torch.float16, torch.bfloat16): - wgrad_gemm_accum_func = fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16 - else: - raise RuntimeError("Unsupported gradient type for gradient accumulation fusion") - wgrad_gemm_accum_func(_input_, _grad_output_, _weight_main_grad_) - - def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_func=None): - return wgrad_gemm_func(_grad_output_.t(), _input_) - if _grad_accum_fusion_available and weight.grad is not None: grad = weight.grad if use_zbv: @@ -828,18 +774,6 @@ def backward(ctx, grad_output): grad_output = grad_output.view(-1, grad_output.shape[-1]) total_input = total_input.reshape(-1, total_input.shape[-1]) - def execute_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_): - if total_input.dtype == torch.float32: - wgrad_gemm_accum_func = fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32 - elif total_input.dtype in (torch.float16, torch.bfloat16): - wgrad_gemm_accum_func = fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16 - else: - raise RuntimeError("Unsupported gradient type for gradient accumulation fusion") - wgrad_gemm_accum_func(_input_, _grad_output_, _weight_main_grad_) - - def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_func=None): - return wgrad_gemm_func(_grad_output_.t(), _input_) - if _grad_accum_fusion_available and weight.grad is not None: grad = weight.grad if use_zbv: @@ -1000,18 +934,6 @@ def backward(ctx, grad_output): # Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have # all-reduce scheduled first and have GPU resources allocated - def execute_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_): - if total_input.dtype == torch.float32: - wgrad_gemm_accum_func = fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32 - elif total_input.dtype in (torch.float16, torch.bfloat16): - wgrad_gemm_accum_func = fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16 - else: - raise RuntimeError("Unsupported gradient type for gradient accumulation fusion") - wgrad_gemm_accum_func(_grad_output_, _input_, _weight_main_grad_) - - def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_func=None): - return wgrad_gemm_func(_input_.t(), _grad_output_) - # split dx & dw if _grad_accum_fusion_available and weight.grad is not None: grad = weight.grad @@ -1021,7 +943,7 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f grad_output, (weight, weight_origin), functools.partial( - execute_w_pass_grad_accum, + execute_conv1d_w_pass_grad_accum, ), ) grad_weight = None @@ -1041,7 +963,7 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f grad_output, (weight, weight_origin), functools.partial( - execute_w_pass, + execute_conv1d_w_pass, wgrad_gemm_func=torch.matmul, ), ) diff --git a/colossalai/shardformer/layer/utils.py b/colossalai/shardformer/layer/utils.py index 2df68e18c64d..6ce1eb79df3f 100644 --- a/colossalai/shardformer/layer/utils.py +++ b/colossalai/shardformer/layer/utils.py @@ -9,6 +9,43 @@ from colossalai.accelerator import get_accelerator +try: + import fused_weight_gradient_mlp_cuda + + _grad_accum_fusion_available = True +except ImportError: + _grad_accum_fusion_available = False + + +# execute_w_pass_grad_accum & execute_conv1d_w_pass for GPT2FusedLinearConv1D +def execute_conv1d_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_): + if _input_.dtype == torch.float32: + wgrad_gemm_accum_func = fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32 + elif _input_.dtype in (torch.float16, torch.bfloat16): + wgrad_gemm_accum_func = fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16 + else: + raise RuntimeError("Unsupported gradient type for gradient accumulation fusion") + wgrad_gemm_accum_func(_grad_output_, _input_, _weight_main_grad_) + + +def execute_conv1d_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_func=None): + return wgrad_gemm_func(_input_.t(), _grad_output_) + + +# execute_w_pass_grad_accum & execute_w_pass for Linear (except GPT2FusedLinearConv1D) +def execute_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_): + if _input_.dtype == torch.float32: + wgrad_gemm_accum_func = fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32 + elif _input_.dtype in (torch.float16, torch.bfloat16): + wgrad_gemm_accum_func = fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16 + else: + raise RuntimeError("Unsupported gradient type for gradient accumulation fusion") + wgrad_gemm_accum_func(_input_, _grad_output_, _weight_main_grad_) + + +def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_func=None): + return wgrad_gemm_func(_grad_output_.t(), _input_) + class SeqParallelUtils: @staticmethod From 52a3b88978d470494af4e1083dad21bbf411b521 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Wed, 25 Dec 2024 17:16:06 +0800 Subject: [PATCH 53/54] [shardformer] Run meaningful doc test --- docs/source/en/features/zerobubble_pipeline_parallelism.md | 2 +- docs/source/zh-Hans/features/zerobubble_pipeline_parallelism.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/en/features/zerobubble_pipeline_parallelism.md b/docs/source/en/features/zerobubble_pipeline_parallelism.md index c66cf5f77376..4407b18e7e17 100644 --- a/docs/source/en/features/zerobubble_pipeline_parallelism.md +++ b/docs/source/en/features/zerobubble_pipeline_parallelism.md @@ -235,4 +235,4 @@ Performance Benchmark ## API Reference {{ autodoc:colossalai.pipeline.schedule.zero_bubble_pp.ZeroBubbleVPipeScheduler }} - + diff --git a/docs/source/zh-Hans/features/zerobubble_pipeline_parallelism.md b/docs/source/zh-Hans/features/zerobubble_pipeline_parallelism.md index 4ab07c9b91b0..f0d0790788e5 100644 --- a/docs/source/zh-Hans/features/zerobubble_pipeline_parallelism.md +++ b/docs/source/zh-Hans/features/zerobubble_pipeline_parallelism.md @@ -234,4 +234,4 @@ plugin = HybridParallelPlugin( ## API 参考 {{ autodoc:colossalai.pipeline.schedule.zero_bubble_pp.ZeroBubbleVPipeScheduler }} - + From ee6bba96fd30e1e300e95dab084b340db4dbb081 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Wed, 25 Dec 2024 17:49:41 +0800 Subject: [PATCH 54/54] [shadformer] fix doc test cmd; --- docs/source/en/features/zerobubble_pipeline_parallelism.md | 2 +- docs/source/zh-Hans/features/zerobubble_pipeline_parallelism.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/en/features/zerobubble_pipeline_parallelism.md b/docs/source/en/features/zerobubble_pipeline_parallelism.md index 4407b18e7e17..1f88815fcbb1 100644 --- a/docs/source/en/features/zerobubble_pipeline_parallelism.md +++ b/docs/source/en/features/zerobubble_pipeline_parallelism.md @@ -235,4 +235,4 @@ Performance Benchmark ## API Reference {{ autodoc:colossalai.pipeline.schedule.zero_bubble_pp.ZeroBubbleVPipeScheduler }} - + diff --git a/docs/source/zh-Hans/features/zerobubble_pipeline_parallelism.md b/docs/source/zh-Hans/features/zerobubble_pipeline_parallelism.md index f0d0790788e5..70e9e4c98631 100644 --- a/docs/source/zh-Hans/features/zerobubble_pipeline_parallelism.md +++ b/docs/source/zh-Hans/features/zerobubble_pipeline_parallelism.md @@ -234,4 +234,4 @@ plugin = HybridParallelPlugin( ## API 参考 {{ autodoc:colossalai.pipeline.schedule.zero_bubble_pp.ZeroBubbleVPipeScheduler }} - +