From 702ffe5befe39f9dd51f362e75e55490416b7bf7 Mon Sep 17 00:00:00 2001 From: ver217 Date: Sun, 21 Apr 2024 21:07:20 +0800 Subject: [PATCH 1/3] [shardformer] fix pipeline grad ckpt --- colossalai/shardformer/shard/grad_ckpt_config.py | 6 ++++++ colossalai/shardformer/shard/shard_config.py | 16 ++++++++++++++-- 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/colossalai/shardformer/shard/grad_ckpt_config.py b/colossalai/shardformer/shard/grad_ckpt_config.py index 9c6c2b54ea39..9fc857d19dbc 100644 --- a/colossalai/shardformer/shard/grad_ckpt_config.py +++ b/colossalai/shardformer/shard/grad_ckpt_config.py @@ -22,6 +22,7 @@ class PipelineGradientCheckpointConfig(GradientCheckpointConfig): 2. Customize # ckpt layers assigned to each stage. This takes precedence over `gradient_checkpointing_ratio`. """ + """ Args: gradient_checkpointing_ratio (Optional[float]): The ratio of gradient checkpointing. It can only be used in pipeline parallelism. Defaults to None. @@ -49,6 +50,7 @@ class PipelineGradientCheckpointConfig(GradientCheckpointConfig): num_stages: Optional[int] = None num_model_chunks: Optional[int] = None num_model_layers: Optional[int] = None + num_layers_per_stage: Optional[List[int]] = None num_ckpt_layers_per_stage: Optional[List[int]] = None def __post_init__(self): @@ -70,6 +72,10 @@ def __post_init__(self): def _enable_gradient_checkpointing_ratio(self) -> bool: return self.gradient_checkpointing_ratio is not None + @property + def _customize_num_layers_per_stage(self) -> bool: + return self.num_layers_per_stage is not None and self.num_model_layers is not None + @property def _enable_customized_ckpt_layers_per_stage(self) -> bool: return self.num_ckpt_layers_per_stage is not None diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index 963732543f27..2af230947d9b 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -7,7 +7,7 @@ from colossalai.pipeline.stage_manager import PipelineStageManager -from .grad_ckpt_config import GradientCheckpointConfig +from .grad_ckpt_config import GradientCheckpointConfig, PipelineGradientCheckpointConfig __all__ = ["ShardConfig"] SUPPORT_SP_MODE = ["split_gather", "ring", "all_to_all"] @@ -30,6 +30,7 @@ class ShardConfig: gradient_checkpoint_config (Optional[GradientCheckpointConfig]): The gradient checkpoint config. Defaults to None. enable_all_optimization (bool): Whether to turn on all optimization tools including 'fused normalization', 'flash attention', 'JIT fused operators', 'sequence parallelism' and 'sequence overlap'. Defaults to False. """ + tensor_parallel_process_group: Optional[ProcessGroup] = None sequence_parallel_process_group: Optional[ProcessGroup] = None pipeline_stage_manager: Optional[PipelineStageManager] = None @@ -42,9 +43,10 @@ class ShardConfig: sequence_parallelism_mode: str = None enable_sequence_overlap: bool = False parallel_output: bool = True - make_vocab_size_divisible_by: int = 64 gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None extra_kwargs: Dict[str, Any] = field(default_factory=dict) + # TODO padding vocab + # make_vocab_size_divisible_by: int = 128 # pipeline_parallel_size: int # data_parallel_size: int # tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d'] @@ -104,6 +106,16 @@ def __post_init__(self): else: self._sequence_parallel_size = dist.get_world_size(self.sequence_parallel_process_group) + if ( + self.pipeline_stage_manager is not None + and isinstance(self.gradient_checkpoint_config, PipelineGradientCheckpointConfig) + and self.gradient_checkpoint_config._customize_num_layers_per_stage + ): + self.pipeline_stage_manager.set_distribution_config( + self.gradient_checkpoint_config.num_model_layers, + self.gradient_checkpoint_config.num_layers_per_stage, + ) + def _turn_on_all_optimization(self): """ Turn on all optimization. From e0f8b6855e178261cecdae3b8d22873e4722ce83 Mon Sep 17 00:00:00 2001 From: ver217 Date: Sun, 21 Apr 2024 21:09:18 +0800 Subject: [PATCH 2/3] [shardformer] fix pipeline grad ckpt --- colossalai/shardformer/shard/shard_config.py | 1 + 1 file changed, 1 insertion(+) diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index 2af230947d9b..6e34c955e5f0 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -43,6 +43,7 @@ class ShardConfig: sequence_parallelism_mode: str = None enable_sequence_overlap: bool = False parallel_output: bool = True + make_vocab_size_divisible_by: int = 64 gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None extra_kwargs: Dict[str, Any] = field(default_factory=dict) # TODO padding vocab From cb7bf94d71c8e266f3e1568dc273c64ed28cf3df Mon Sep 17 00:00:00 2001 From: ver217 Date: Sun, 21 Apr 2024 21:09:55 +0800 Subject: [PATCH 3/3] [shardformer] fix pipeline grad ckpt --- colossalai/shardformer/shard/shard_config.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index 6e34c955e5f0..597dd9c26354 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -46,8 +46,6 @@ class ShardConfig: make_vocab_size_divisible_by: int = 64 gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None extra_kwargs: Dict[str, Any] = field(default_factory=dict) - # TODO padding vocab - # make_vocab_size_divisible_by: int = 128 # pipeline_parallel_size: int # data_parallel_size: int # tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d']