2 changes: 1 addition & 1 deletion .github/workflows/build_on_pr.yml
@@ -199,7 +199,7 @@ jobs:
fi

- name: Upload test coverage artifact
- uses: actions/upload-artifact@v3
+ uses: actions/upload-artifact@v4
with:
name: report
path: report/
25 changes: 14 additions & 11 deletions colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -1188,6 +1188,15 @@ def __init__(
else:
self.sp_group = self.pg_mesh.get_group_along_axis(self.sp_axis)

+ # sync gradients across DP * SP ranks
+ # sync gradients across DP * SP ranks
+ # Apply Hybrid ZeRO across DP * SP ranks
+ if self.enable_sequence_parallelism and not is_share_sp_tp(self.sequence_parallelism_mode):
+     self.mixed_dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
+     self.dp_size = get_world_size(self.mixed_dp_group)
+ else:
+     self.mixed_dp_group = self.dp_group

self.shard_config = ShardConfig(
tensor_parallel_process_group=self.tp_group,
sequence_parallel_process_group=self.sp_group,
@@ -1298,19 +1307,11 @@ def configure(
use_ddp = (self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0) or (
self.dp_size == 1 and self.pp_size == 1
)
- # sync gradients across DP * SP ranks
- # sync gradients across DP * SP ranks
- # Apply Hybrid ZeRO across DP * SP ranks
- if self.enable_sequence_parallelism and not is_share_sp_tp(self.sequence_parallelism_mode):
-     dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
-     self.dp_size = get_world_size(dp_group)
- else:
-     dp_group = self.dp_group
model = HybridParallelModule(
model,
precision=self.precision,
shard_config=self.shard_config,
- dp_group=dp_group,
+ dp_group=self.mixed_dp_group,
tp_group=self.tp_group,
sp_group=self.sp_group,
use_ddp=use_ddp,
@@ -1359,7 +1360,7 @@ def configure(
model,
use_pipeline=self.enable_pipeline_parallelism,
param_info=param_info,
- dp_process_group=dp_group,
+ dp_process_group=self.mixed_dp_group,
tp_process_group=self.tp_group,
pp_process_group=self.pp_group,
verbose=True,
@@ -1488,7 +1489,9 @@ def seed_worker(worker_id):
)

def get_checkpoint_io(self) -> CheckpointIO:
- return HybridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.sp_group, self.zero_stage)
+ return HybridParallelCheckpointIO(
+     self.mixed_dp_group, self.pp_group, self.tp_group, self.sp_group, self.zero_stage
+ )

def no_sync(self, model: Module, optimizer: OptimizerWrapper) -> Iterator[None]:
assert (
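Note on the hunks above: creation of the merged DP * SP process group moves out of configure() and into __init__, where it is kept as self.mixed_dp_group. The wrapped HybridParallelModule, the optimizer built in configure() and HybridParallelCheckpointIO now all receive this one group, and self.dp_size reflects the merged ranks at construction time instead of being patched inside configure(). Below is a minimal sketch of which global ranks such a merged group spans, assuming a device mesh laid out as (pp, dp, sp, tp); the layout and the helper are illustrative only, not the ColossalAI ProcessGroupMesh API.

```python
# Illustrative only: which global ranks share one "mixed" DP * SP gradient-sync group
# when the dp and sp axes of a device mesh are merged (assumed layout: pp, dp, sp, tp).
import numpy as np

def mixed_dp_groups(pp: int, dp: int, sp: int, tp: int) -> list[list[int]]:
    mesh = np.arange(pp * dp * sp * tp).reshape(pp, dp, sp, tp)
    groups = []
    for p in range(pp):
        for t in range(tp):
            # for fixed (pp, tp) coordinates, merge every (dp, sp) position into one group
            groups.append(sorted(mesh[p, :, :, t].flatten().tolist()))
    return groups

if __name__ == "__main__":
    # 16 ranks with pp=1, dp=2, sp=2, tp=4 -> four groups of dp * sp = 4 ranks each
    for group in mixed_dp_groups(1, 2, 2, 4):
        print(group)
```

Creating the group once also removes the earlier mismatch where get_checkpoint_io() used self.dp_group while gradients were actually synchronized over the dp_group created locally in configure().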
22 changes: 12 additions & 10 deletions colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
@@ -351,6 +351,14 @@ def __init__(
self.sp_group = self.pg_mesh.get_group_along_axis(self.tp_axis)
else:
self.sp_group = self.pg_mesh.get_group_along_axis(self.sp_axis)

+ # sync gradients across DP * SP ranks
+ if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all":
+     self.mixed_dp_group = self.pg_mesh.create_group_along_axis([self.moe_dp_axis, self.ep_axis, self.sp_axis])
+     self.dp_size = dist.get_world_size(self.mixed_dp_group)
+ else:
+     self.mixed_dp_group = self.dp_group

self.use_fp8 = use_fp8

self.shard_config = ShardConfig(
@@ -404,7 +412,7 @@ def __init__(

def get_checkpoint_io(self) -> MoECheckpointIO:
return MoECheckpointIO(
- self.dp_group,
+ self.mixed_dp_group,
self.pp_group,
self.tp_group,
self.sp_group,
@@ -435,20 +443,14 @@ def configure(
and self.sequence_parallelism_mode == "all_to_all"
)

- # sync gradients across DP * SP ranks
- if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all":
-     dp_group = self.pg_mesh.create_group_along_axis([self.moe_dp_axis, self.ep_axis, self.sp_axis])
- else:
-     dp_group = self.dp_group

if use_ddp:
self.logger.warning(
f"Will have to check all params are used in pytorch DDP since not all experts are always activated",
ranks=[0],
)
self.ddp_config["find_unused_parameters"] = True

- if dist.get_process_group_ranks(dp_group) != dist.get_process_group_ranks(self.moe_dp_group):
+ if dist.get_process_group_ranks(self.mixed_dp_group) != dist.get_process_group_ranks(self.moe_dp_group):
raise ValueError(
f"if pytorch DDP is used, dp_group and moe_dp_group are expected to be the same since DDP can only reduce grad across a single group, but found dp_group {dist.get_process_group_ranks(dp_group)} and moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)}, you might want to modify your config to bypass DDP \nhint: check the above ddp condition to by pass this"
)
@@ -457,7 +459,7 @@
module=model,
precision=self.precision,
shard_config=self.shard_config,
- dp_group=dp_group,
+ dp_group=self.mixed_dp_group,
tp_group=self.tp_group,
sp_group=self.sp_group,
use_ddp=use_ddp,
@@ -507,7 +509,7 @@ def configure(
model,
use_pipeline=self.enable_pipeline_parallelism,
param_info=param_info,
- dp_process_group=dp_group,
+ dp_process_group=self.mixed_dp_group,
tp_process_group=self.tp_group,
pp_process_group=self.pp_group,
moe_dp_group=self.moe_dp_group,
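In the MoE plugin the merged group spans the moe_dp, ep and sp axes instead, and the pre-existing guard in configure() only permits plain PyTorch DDP when that merged group holds exactly the same ranks as moe_dp_group, since DDP can only reduce gradients over a single group. A toy illustration of that rank comparison follows; the mesh layout is an assumption made for the example and is not the ColossalAI API.

```python
# Illustrative only (not the ColossalAI API): the merged gradient-sync group of the MoE
# plugin spans the moe_dp, ep and sp mesh axes; plain DDP is only safe when it matches
# the moe_dp group rank-for-rank, which the configure() guard above checks.
import numpy as np

# assumed toy mesh layout: (pp, moe_dp, ep, sp, tp) with 16 ranks
mesh = np.arange(16).reshape(1, 2, 2, 2, 2)
rank = 0  # inspect the groups that contain global rank 0

p, _d, e, s, t = (int(i[0]) for i in np.where(mesh == rank))
mixed_group = set(mesh[p, :, :, :, t].flatten().tolist())   # merge moe_dp * ep * sp
moe_dp_group = set(mesh[p, :, e, s, t].flatten().tolist())  # moe_dp axis only

# mirrors: dist.get_process_group_ranks(self.mixed_dp_group) != dist.get_process_group_ranks(self.moe_dp_group)
print("mixed:", sorted(mixed_group), "moe_dp:", sorted(moe_dp_group))
print("DDP-compatible:", mixed_group == moe_dp_group)  # False here: ep and sp add extra ranks
```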
8 changes: 4 additions & 4 deletions tests/test_pipeline/test_schedule/test_zerobubble_pp.py
@@ -885,12 +885,12 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
parallel_optimizer.backward(parallel_output)
parallel_optimizer.step()
parallel_optimizer.zero_grad()
- dist.all_reduce(parallel_output, group=plugin.dp_group)
+ dist.all_reduce(parallel_output, group=plugin.mixed_dp_group)

# ===================================================================================
# run normal model with all dp(different) inputs
all_inputs = [input_embeddings.clone() for _ in range(dp_size)]
- dist.all_gather(all_inputs, input_embeddings, group=plugin.dp_group)
+ dist.all_gather(all_inputs, input_embeddings, group=plugin.mixed_dp_group)
torch_output_sum = 0
for input_data_ in all_inputs:
torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean()
@@ -1040,12 +1040,12 @@ def run_with_booster_hybridplugin(config: Tuple[int, ...]):
parallel_optimizer.backward(parallel_output)
parallel_optimizer.step()
parallel_optimizer.zero_grad()
- dist.all_reduce(parallel_output, group=plugin.dp_group)
+ dist.all_reduce(parallel_output, group=plugin.mixed_dp_group)

# ===================================================================================
# run normal model with all dp(different) inputs
all_inputs = [input_embeddings.clone() for _ in range(dp_size)]
- dist.all_gather(all_inputs, input_embeddings, group=plugin.dp_group)
+ dist.all_gather(all_inputs, input_embeddings, group=plugin.mixed_dp_group)
torch_output_sum = 0
for input_data_ in all_inputs:
torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean()
4 changes: 2 additions & 2 deletions tests/test_shardformer/test_model/test_shard_deepseek.py
@@ -125,12 +125,12 @@ def run_deepseek_commom(parallel_config: Tuple[int, ...]):
parallel_optimizer.backward(parallel_output)
parallel_optimizer.step()
parallel_optimizer.zero_grad()
- dist.all_reduce(parallel_output, group=plugin.dp_group)
+ dist.all_reduce(parallel_output, group=plugin.mixed_dp_group)

# ===================================================================================
# run normal model with all dp(different) inputs
all_inputs = [torch.empty_like(input_embeddings) for _ in range(dp_size)]
- dist.all_gather(all_inputs, input_embeddings, group=plugin.dp_group)
+ dist.all_gather(all_inputs, input_embeddings, group=plugin.mixed_dp_group)
torch_output_sum = 0
for input_data_ in all_inputs:
torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean()
4 changes: 2 additions & 2 deletions tests/test_shardformer/test_model/test_shard_mixtral.py
@@ -118,12 +118,12 @@ def run_mixtral_commom(config: Tuple[int, ...]):
parallel_optimizer.backward(parallel_output)
parallel_optimizer.step()
parallel_optimizer.zero_grad()
- dist.all_reduce(parallel_output, group=plugin.dp_group)
+ dist.all_reduce(parallel_output, group=plugin.mixed_dp_group)

# ===================================================================================
# run normal model with all dp(different) inputs
all_inputs = [torch.empty_like(input_embeddings) for _ in range(dp_size)]
- dist.all_gather(all_inputs, input_embeddings, group=plugin.dp_group)
+ dist.all_gather(all_inputs, input_embeddings, group=plugin.mixed_dp_group)
torch_output_sum = 0
for input_data_ in all_inputs:
torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean()
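The test updates are mechanical: the all_reduce of the parallel output and the all_gather of the per-rank inputs now run over plugin.mixed_dp_group, so the comparison against a single-process reference covers the merged data-parallel ranks. The sketch below shows that verification pattern in isolation; the one-process gloo group and the mean() reference are stand-ins that only keep the example self-contained and runnable.

```python
# Illustrative verification pattern used by these tests: sum the parallel output across
# the data-parallel group, gather every rank's inputs, replay them through a reference,
# and compare. A single-process gloo group keeps the sketch runnable on its own.
import os
import torch
import torch.distributed as dist

def verify(parallel_output: torch.Tensor, input_embeddings: torch.Tensor, group) -> None:
    dp_size = dist.get_world_size(group)
    dist.all_reduce(parallel_output, group=group)               # sum outputs over DP ranks
    all_inputs = [torch.empty_like(input_embeddings) for _ in range(dp_size)]
    dist.all_gather(all_inputs, input_embeddings, group=group)  # collect every rank's shard
    reference = sum(x.mean() for x in all_inputs)               # stand-in for torch_model(...)
    torch.testing.assert_close(parallel_output, reference)

if __name__ == "__main__":
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29512")
    dist.init_process_group("gloo", rank=0, world_size=1)
    x = torch.randn(4, 8)
    verify(x.mean().clone(), x, dist.group.WORLD)  # in the real tests: plugin.mixed_dp_group
    print("parallel output matches the replayed reference")
    dist.destroy_process_group()
```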