From 49b97ee9c1f88501e7362e428d70b74f6e998e0b Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Sun, 2 Feb 2025 23:11:09 +0800 Subject: [PATCH 01/10] Update hybrid_parallel_plugin.py --- .../booster/plugin/hybrid_parallel_plugin.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index bc9425a0b0cd..07d378028b77 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -1188,6 +1188,13 @@ def __init__( else: self.sp_group = self.pg_mesh.get_group_along_axis(self.sp_axis) + # sync gradients across DP * SP ranks + # sync gradients across DP * SP ranks + # Apply Hybrid ZeRO across DP * SP ranks + if self.enable_sequence_parallelism and not is_share_sp_tp(self.sequence_parallelism_mode): + self.dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis]) + self.dp_size = get_world_size(self.dp_group) + self.shard_config = ShardConfig( tensor_parallel_process_group=self.tp_group, sequence_parallel_process_group=self.sp_group, @@ -1298,19 +1305,11 @@ def configure( use_ddp = (self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0) or ( self.dp_size == 1 and self.pp_size == 1 ) - # sync gradients across DP * SP ranks - # sync gradients across DP * SP ranks - # Apply Hybrid ZeRO across DP * SP ranks - if self.enable_sequence_parallelism and not is_share_sp_tp(self.sequence_parallelism_mode): - dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis]) - self.dp_size = get_world_size(dp_group) - else: - dp_group = self.dp_group model = HybridParallelModule( model, precision=self.precision, shard_config=self.shard_config, - dp_group=dp_group, + dp_group=self.dp_group, tp_group=self.tp_group, sp_group=self.sp_group, use_ddp=use_ddp, @@ -1359,7 +1358,7 @@ def configure( model, use_pipeline=self.enable_pipeline_parallelism, param_info=param_info, - dp_process_group=dp_group, + dp_process_group=self.dp_group, tp_process_group=self.tp_group, pp_process_group=self.pp_group, verbose=True, From 1cfafb3a40cf6f58927269f3a6c775485bc06092 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Mon, 3 Feb 2025 15:34:15 +0800 Subject: [PATCH 02/10] Update hybrid_parallel_plugin.py --- colossalai/booster/plugin/hybrid_parallel_plugin.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 07d378028b77..a3b335ac6695 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -1192,8 +1192,10 @@ def __init__( # sync gradients across DP * SP ranks # Apply Hybrid ZeRO across DP * SP ranks if self.enable_sequence_parallelism and not is_share_sp_tp(self.sequence_parallelism_mode): - self.dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis]) - self.dp_size = get_world_size(self.dp_group) + self.mixed_dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis]) + self.dp_size = get_world_size(self.mixed_dp_group) + else: + self.mixed_dp_group = self.dp_group self.shard_config = ShardConfig( tensor_parallel_process_group=self.tp_group, @@ -1309,7 +1311,7 @@ def configure( model, precision=self.precision, shard_config=self.shard_config, - dp_group=self.dp_group, + dp_group=self.mixed_dp_group, tp_group=self.tp_group, sp_group=self.sp_group, use_ddp=use_ddp, @@ -1358,7 +1360,7 @@ def configure( model, use_pipeline=self.enable_pipeline_parallelism, param_info=param_info, - dp_process_group=self.dp_group, + dp_process_group=self.mixed_dp_group, tp_process_group=self.tp_group, pp_process_group=self.pp_group, verbose=True, From bc1879a5c25c286bd44d4dbe69c9a8065f46a54f Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Mon, 3 Feb 2025 16:07:00 +0800 Subject: [PATCH 03/10] Update hybrid_parallel_plugin.py --- colossalai/booster/plugin/hybrid_parallel_plugin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index a3b335ac6695..cd51f07a5409 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -1489,7 +1489,7 @@ def seed_worker(worker_id): ) def get_checkpoint_io(self) -> CheckpointIO: - return HybridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.sp_group, self.zero_stage) + return HybridParallelCheckpointIO(self.mixed_dp_group, self.pp_group, self.tp_group, self.sp_group, self.zero_stage) def no_sync(self, model: Module, optimizer: OptimizerWrapper) -> Iterator[None]: assert ( From c177b7fddf179d11080cd49ced2d809173789c0e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 3 Feb 2025 08:07:53 +0000 Subject: [PATCH 04/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- colossalai/booster/plugin/hybrid_parallel_plugin.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index cd51f07a5409..62046bc36af1 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -1489,7 +1489,9 @@ def seed_worker(worker_id): ) def get_checkpoint_io(self) -> CheckpointIO: - return HybridParallelCheckpointIO(self.mixed_dp_group, self.pp_group, self.tp_group, self.sp_group, self.zero_stage) + return HybridParallelCheckpointIO( + self.mixed_dp_group, self.pp_group, self.tp_group, self.sp_group, self.zero_stage + ) def no_sync(self, model: Module, optimizer: OptimizerWrapper) -> Iterator[None]: assert ( From c9c1429542f8f130fe2218f984518681aa9a3467 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Wed, 5 Feb 2025 10:26:54 +0800 Subject: [PATCH 05/10] Update build_on_pr.yml --- .github/workflows/build_on_pr.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build_on_pr.yml b/.github/workflows/build_on_pr.yml index 8d96ca1b90bc..b05cb660b25f 100644 --- a/.github/workflows/build_on_pr.yml +++ b/.github/workflows/build_on_pr.yml @@ -199,7 +199,7 @@ jobs: fi - name: Upload test coverage artifact - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: report path: report/ From 206a97f4808e7a4425969c26caa9000e137c002e Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Wed, 5 Feb 2025 14:16:50 +0800 Subject: [PATCH 06/10] Update test_zerobubble_pp.py --- tests/test_pipeline/test_schedule/test_zerobubble_pp.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py index a01b75eeebb7..67b05f0273dc 100644 --- a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py +++ b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py @@ -885,12 +885,12 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]): parallel_optimizer.backward(parallel_output) parallel_optimizer.step() parallel_optimizer.zero_grad() - dist.all_reduce(parallel_output, group=plugin.dp_group) + dist.all_reduce(parallel_output, group=plugin.mixed_dp_group) # =================================================================================== # run normal model with all dp(different) inputs all_inputs = [input_embeddings.clone() for _ in range(dp_size)] - dist.all_gather(all_inputs, input_embeddings, group=plugin.dp_group) + dist.all_gather(all_inputs, input_embeddings, group=plugin.mixed_dp_group) torch_output_sum = 0 for input_data_ in all_inputs: torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean() @@ -1040,12 +1040,12 @@ def run_with_booster_hybridplugin(config: Tuple[int, ...]): parallel_optimizer.backward(parallel_output) parallel_optimizer.step() parallel_optimizer.zero_grad() - dist.all_reduce(parallel_output, group=plugin.dp_group) + dist.all_reduce(parallel_output, group=plugin.mixed_dp_group) # =================================================================================== # run normal model with all dp(different) inputs all_inputs = [input_embeddings.clone() for _ in range(dp_size)] - dist.all_gather(all_inputs, input_embeddings, group=plugin.dp_group) + dist.all_gather(all_inputs, input_embeddings, group=plugin.mixed_dp_group) torch_output_sum = 0 for input_data_ in all_inputs: torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean() From 43f39bc1205657eee612a23641656977e0379cf2 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Wed, 5 Feb 2025 14:47:49 +0800 Subject: [PATCH 07/10] fix --- tests/test_pipeline/test_schedule/test_zerobubble_pp.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py index 67b05f0273dc..47c460690cc6 100644 --- a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py +++ b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py @@ -885,12 +885,12 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]): parallel_optimizer.backward(parallel_output) parallel_optimizer.step() parallel_optimizer.zero_grad() - dist.all_reduce(parallel_output, group=plugin.mixed_dp_group) + dist.all_reduce(parallel_output, group=plugin.dp_group) # =================================================================================== # run normal model with all dp(different) inputs all_inputs = [input_embeddings.clone() for _ in range(dp_size)] - dist.all_gather(all_inputs, input_embeddings, group=plugin.mixed_dp_group) + dist.all_gather(all_inputs, input_embeddings, group=plugin.dp_group) torch_output_sum = 0 for input_data_ in all_inputs: torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean() From 3f211f3d59af47856f716ea0205e368cc6406571 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Thu, 6 Feb 2025 10:46:11 +0800 Subject: [PATCH 08/10] fix --- .../plugin/moe_hybrid_parallel_plugin.py | 23 +++++++++++-------- .../test_schedule/test_zerobubble_pp.py | 4 ++-- 2 files changed, 15 insertions(+), 12 deletions(-) diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py index 6937b8d74ab9..8659e047a74f 100644 --- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py @@ -351,6 +351,15 @@ def __init__( self.sp_group = self.pg_mesh.get_group_along_axis(self.tp_axis) else: self.sp_group = self.pg_mesh.get_group_along_axis(self.sp_axis) + + + # sync gradients across DP * SP ranks + if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all": + self.mixed_dp_group = self.pg_mesh.create_group_along_axis([self.moe_dp_axis, self.ep_axis, self.sp_axis]) + self.dp_size = dist.get_world_size(self.mixed_dp_group) + else: + self.mixed_dp_group = self.dp_group + self.use_fp8 = use_fp8 self.shard_config = ShardConfig( @@ -404,7 +413,7 @@ def __init__( def get_checkpoint_io(self) -> MoECheckpointIO: return MoECheckpointIO( - self.dp_group, + self.mixed_dp_group, self.pp_group, self.tp_group, self.sp_group, @@ -435,12 +444,6 @@ def configure( and self.sequence_parallelism_mode == "all_to_all" ) - # sync gradients across DP * SP ranks - if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all": - dp_group = self.pg_mesh.create_group_along_axis([self.moe_dp_axis, self.ep_axis, self.sp_axis]) - else: - dp_group = self.dp_group - if use_ddp: self.logger.warning( f"Will have to check all params are used in pytorch DDP since not all experts are always activated", @@ -448,7 +451,7 @@ def configure( ) self.ddp_config["find_unused_parameters"] = True - if dist.get_process_group_ranks(dp_group) != dist.get_process_group_ranks(self.moe_dp_group): + if dist.get_process_group_ranks(self.mixed_dp_group) != dist.get_process_group_ranks(self.moe_dp_group): raise ValueError( f"if pytorch DDP is used, dp_group and moe_dp_group are expected to be the same since DDP can only reduce grad across a single group, but found dp_group {dist.get_process_group_ranks(dp_group)} and moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)}, you might want to modify your config to bypass DDP \nhint: check the above ddp condition to by pass this" ) @@ -457,7 +460,7 @@ def configure( module=model, precision=self.precision, shard_config=self.shard_config, - dp_group=dp_group, + dp_group=self.mixed_dp_group, tp_group=self.tp_group, sp_group=self.sp_group, use_ddp=use_ddp, @@ -507,7 +510,7 @@ def configure( model, use_pipeline=self.enable_pipeline_parallelism, param_info=param_info, - dp_process_group=dp_group, + dp_process_group=self.mixed_dp_group, tp_process_group=self.tp_group, pp_process_group=self.pp_group, moe_dp_group=self.moe_dp_group, diff --git a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py index 47c460690cc6..67b05f0273dc 100644 --- a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py +++ b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py @@ -885,12 +885,12 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]): parallel_optimizer.backward(parallel_output) parallel_optimizer.step() parallel_optimizer.zero_grad() - dist.all_reduce(parallel_output, group=plugin.dp_group) + dist.all_reduce(parallel_output, group=plugin.mixed_dp_group) # =================================================================================== # run normal model with all dp(different) inputs all_inputs = [input_embeddings.clone() for _ in range(dp_size)] - dist.all_gather(all_inputs, input_embeddings, group=plugin.dp_group) + dist.all_gather(all_inputs, input_embeddings, group=plugin.mixed_dp_group) torch_output_sum = 0 for input_data_ in all_inputs: torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean() From 1cb07b13c0611debdea6cede021aea60d066586f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 6 Feb 2025 03:05:15 +0000 Subject: [PATCH 09/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- colossalai/booster/plugin/moe_hybrid_parallel_plugin.py | 1 - 1 file changed, 1 deletion(-) diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py index 8659e047a74f..35f076e02008 100644 --- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py @@ -352,7 +352,6 @@ def __init__( else: self.sp_group = self.pg_mesh.get_group_along_axis(self.sp_axis) - # sync gradients across DP * SP ranks if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all": self.mixed_dp_group = self.pg_mesh.create_group_along_axis([self.moe_dp_axis, self.ep_axis, self.sp_axis]) From e92c0e87e98ef0a5423e52bd0164551b4a78c8e7 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Thu, 6 Feb 2025 14:39:00 +0800 Subject: [PATCH 10/10] fix --- tests/test_shardformer/test_model/test_shard_deepseek.py | 4 ++-- tests/test_shardformer/test_model/test_shard_mixtral.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_shardformer/test_model/test_shard_deepseek.py b/tests/test_shardformer/test_model/test_shard_deepseek.py index 4b92dbdee4bf..20dfa78c6aa6 100644 --- a/tests/test_shardformer/test_model/test_shard_deepseek.py +++ b/tests/test_shardformer/test_model/test_shard_deepseek.py @@ -125,12 +125,12 @@ def run_deepseek_commom(parallel_config: Tuple[int, ...]): parallel_optimizer.backward(parallel_output) parallel_optimizer.step() parallel_optimizer.zero_grad() - dist.all_reduce(parallel_output, group=plugin.dp_group) + dist.all_reduce(parallel_output, group=plugin.mixed_dp_group) # =================================================================================== # run normal model with all dp(different) inputs all_inputs = [torch.empty_like(input_embeddings) for _ in range(dp_size)] - dist.all_gather(all_inputs, input_embeddings, group=plugin.dp_group) + dist.all_gather(all_inputs, input_embeddings, group=plugin.mixed_dp_group) torch_output_sum = 0 for input_data_ in all_inputs: torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean() diff --git a/tests/test_shardformer/test_model/test_shard_mixtral.py b/tests/test_shardformer/test_model/test_shard_mixtral.py index 940c66cf637b..b691130720f7 100644 --- a/tests/test_shardformer/test_model/test_shard_mixtral.py +++ b/tests/test_shardformer/test_model/test_shard_mixtral.py @@ -118,12 +118,12 @@ def run_mixtral_commom(config: Tuple[int, ...]): parallel_optimizer.backward(parallel_output) parallel_optimizer.step() parallel_optimizer.zero_grad() - dist.all_reduce(parallel_output, group=plugin.dp_group) + dist.all_reduce(parallel_output, group=plugin.mixed_dp_group) # =================================================================================== # run normal model with all dp(different) inputs all_inputs = [torch.empty_like(input_embeddings) for _ in range(dp_size)] - dist.all_gather(all_inputs, input_embeddings, group=plugin.dp_group) + dist.all_gather(all_inputs, input_embeddings, group=plugin.mixed_dp_group) torch_output_sum = 0 for input_data_ in all_inputs: torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean()