Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions colossalai/booster/plugin/hybrid_parallel_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def sync_sp_grads(self, grads: Optional[List[torch.Tensor]] = None):
# If sequence parallelism is enabled and mode is all_to_all, gradients are synchronized
# across the sequence parallelism group.
group = self.sp_group
only_sp_partial = False
only_sp_partial = True
else:
raise ValueError(f"Unknown sequence parallelism mode: {self.shard_config.sequence_parallelism_mode}")

Expand Down Expand Up @@ -1100,7 +1100,6 @@ def __init__(
sequence_parallelism_mode=sequence_parallelism_mode,
enable_sequence_overlap=enable_sequence_overlap,
parallel_output=parallel_output,
zero_stage=zero_stage,
)
self.amp_config = dict(
initial_scale=initial_scale,
Expand Down Expand Up @@ -1168,7 +1167,8 @@ def configure(
) -> Tuple[Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
param_info = get_param_info(optimizer)
if not isinstance(model, ModelWrapper):
use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0
use_ddp = (self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0) or \
(self.dp_size == 1 and self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all")
if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all":
dp_group = self.pg_mesh.create_group_along_axis([DP_AXIS, SP_AXIS])
else:
Expand Down
4 changes: 2 additions & 2 deletions colossalai/shardformer/modeling/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -1059,9 +1059,9 @@ def custom_forward(*inputs):

hidden_states = self.norm(hidden_states)

if sp_mode == "ring" or sp_mode == "split_gather" or (sp_mode == "all_to_all" and zero_stage == 0):
if sp_mode == "ring" or sp_mode == "split_gather":
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group)
elif sp_mode == "all_to_all" and zero_stage in [1, 2]:
elif sp_mode == "all_to_all":
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size)

# add hidden states from the last decoder layer
Expand Down
2 changes: 0 additions & 2 deletions colossalai/shardformer/policies/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
warnings.warn(
f"For llama, sequence parallelism is currently not compatible with pipeline parallelism, set to be False"
)
zero_stage = self.shard_config.zero_stage
sp_mode = self.shard_config.sequence_parallelism_mode if self.shard_config.enable_sequence_parallelism else None
sp_size = self.shard_config.sequence_parallel_size if self.shard_config.enable_sequence_parallelism else None
sp_group = (
Expand Down Expand Up @@ -126,7 +125,6 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
sp_mode=sp_mode,
sp_size=sp_size,
sp_group=sp_group,
zero_stage=zero_stage,
),
},
policy=policy,
Expand Down
1 change: 0 additions & 1 deletion colossalai/shardformer/shard/shard_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ class ShardConfig:
enable_jit_fused: bool = False
enable_sequence_parallelism: bool = False
sequence_parallelism_mode: str = None
zero_stage: int = 0
enable_sequence_overlap: bool = False
parallel_output: bool = True
extra_kwargs: Dict[str, Any] = field(default_factory=dict)
Expand Down
13 changes: 12 additions & 1 deletion tests/test_shardformer/test_model/test_shard_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,10 +163,10 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
{
"tp_size": 1,
"pp_size": 1,
"sp_size": 2,
"num_microbatches": 1,
"enable_sequence_parallelism": True,
"sequence_parallelism_mode": "all_to_all",
"enable_flash_attention": False,
"use_lazy_init": True,
"precision": "fp16",
"initial_scale": 1,
Expand All @@ -183,6 +183,17 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
"precision": "fp16",
"initial_scale": 1,
},
{
"tp_size": 1,
"pp_size": 1,
"num_microbatches": 1,
"enable_sequence_parallelism": True,
"sequence_parallelism_mode": "all_to_all",
"enable_flash_attention": False,
"use_lazy_init": True,
"precision": "fp16",
"initial_scale": 1,
},
{
"tp_size": 2,
"pp_size": 2,
Expand Down