diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index fac7e7b6799e..87dfcca8ef21 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -182,7 +182,7 @@ def sync_sp_grads(self, grads: Optional[List[torch.Tensor]] = None):
             # If sequence parallelism is enabled and mode is all_to_all, gradients are synchronized
             # across the sequence parallelism group.
             group = self.sp_group
-            only_sp_partial = False
+            only_sp_partial = True
         else:
             raise ValueError(f"Unknown sequence parallelism mode: {self.shard_config.sequence_parallelism_mode}")
 
@@ -1100,7 +1100,6 @@ def __init__(
             sequence_parallelism_mode=sequence_parallelism_mode,
             enable_sequence_overlap=enable_sequence_overlap,
             parallel_output=parallel_output,
-            zero_stage=zero_stage,
         )
         self.amp_config = dict(
             initial_scale=initial_scale,
@@ -1168,7 +1167,8 @@ def configure(
     ) -> Tuple[Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
         param_info = get_param_info(optimizer)
         if not isinstance(model, ModelWrapper):
-            use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0
+            use_ddp = (self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0) or \
+                (self.dp_size == 1 and self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all")
             if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all":
                 dp_group = self.pg_mesh.create_group_along_axis([DP_AXIS, SP_AXIS])
             else:
diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py
index a970cdec2fa3..a6de5948564a 100644
--- a/colossalai/shardformer/modeling/llama.py
+++ b/colossalai/shardformer/modeling/llama.py
@@ -1059,9 +1059,9 @@ def custom_forward(*inputs):
 
         hidden_states = self.norm(hidden_states)
 
-        if sp_mode == "ring" or sp_mode == "split_gather" or (sp_mode == "all_to_all" and zero_stage == 0):
+        if sp_mode == "ring" or sp_mode == "split_gather":
             hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group)
-        elif sp_mode == "all_to_all" and zero_stage in [1, 2]:
+        elif sp_mode == "all_to_all":
             hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size)
 
         # add hidden states from the last decoder layer
diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py
index ca9e74f8df63..55454b6f37c2 100644
--- a/colossalai/shardformer/policies/llama.py
+++ b/colossalai/shardformer/policies/llama.py
@@ -54,7 +54,6 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
             warnings.warn(
                 f"For llama, sequence parallelism is currently not compatible with pipeline parallelism, set to be False"
             )
-        zero_stage = self.shard_config.zero_stage
         sp_mode = self.shard_config.sequence_parallelism_mode if self.shard_config.enable_sequence_parallelism else None
         sp_size = self.shard_config.sequence_parallel_size if self.shard_config.enable_sequence_parallelism else None
         sp_group = (
@@ -126,7 +125,6 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
                         sp_mode=sp_mode,
                         sp_size=sp_size,
                         sp_group=sp_group,
-                        zero_stage=zero_stage,
                     ),
                 },
                 policy=policy,
diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py
index 5bf5964d25c9..07239b545229 100644
--- a/colossalai/shardformer/shard/shard_config.py
+++ b/colossalai/shardformer/shard/shard_config.py
@@ -37,7 +37,6 @@ class ShardConfig:
     enable_jit_fused: bool = False
     enable_sequence_parallelism: bool = False
     sequence_parallelism_mode: str = None
-    zero_stage: int = 0
     enable_sequence_overlap: bool = False
     parallel_output: bool = True
     extra_kwargs: Dict[str, Any] = field(default_factory=dict)
diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py
index 581a81ef59da..611f7864e834 100644
--- a/tests/test_shardformer/test_model/test_shard_llama.py
+++ b/tests/test_shardformer/test_model/test_shard_llama.py
@@ -163,10 +163,10 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
         {
             "tp_size": 1,
             "pp_size": 1,
+            "sp_size": 2,
             "num_microbatches": 1,
             "enable_sequence_parallelism": True,
             "sequence_parallelism_mode": "all_to_all",
-            "enable_flash_attention": False,
             "use_lazy_init": True,
             "precision": "fp16",
             "initial_scale": 1,
@@ -183,6 +183,17 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "precision": "fp16",
             "initial_scale": 1,
         },
+        {
+            "tp_size": 1,
+            "pp_size": 1,
+            "num_microbatches": 1,
+            "enable_sequence_parallelism": True,
+            "sequence_parallelism_mode": "all_to_all",
+            "enable_flash_attention": False,
+            "use_lazy_init": True,
+            "precision": "fp16",
+            "initial_scale": 1,
+        },
         {
             "tp_size": 2,
             "pp_size": 2,
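
For reference, a minimal usage sketch of the configuration this patch exercises: all_to_all sequence parallelism with the plugin's default zero_stage=0, which the widened use_ddp condition in configure() now routes through DDP when dp_size == 1. The sketch is an illustrative assumption assembled from the new test entry above, not code from the patch; the tiny Llama model, the optimizer choice, and the launch call are placeholders.

import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from transformers import LlamaConfig, LlamaForCausalLM

colossalai.launch_from_torch()  # older releases require a config argument, e.g. config={}

# Mirrors the new test config: sp_size=2 with all_to_all mode and zero_stage
# left at its default of 0, so gradient sync goes through the DDP path.
plugin = HybridParallelPlugin(
    tp_size=1,
    pp_size=1,
    sp_size=2,
    enable_sequence_parallelism=True,
    sequence_parallelism_mode="all_to_all",
    precision="fp16",
    initial_scale=1,
)
booster = Booster(plugin=plugin)

model = LlamaForCausalLM(LlamaConfig(num_hidden_layers=2))  # tiny config, illustrative only
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
model, optimizer, *_ = booster.boost(model, optimizer)

Run under torchrun with two processes so that the size-2 sequence-parallel group can be formed; with dp_size == 1 the model is still wrapped in DDP so the all_to_all gradients are reduced correctly.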