Merged
6 changes: 6 additions & 0 deletions colossalai/shardformer/shard/grad_ckpt_config.py
@@ -22,6 +22,7 @@ class PipelineGradientCheckpointConfig(GradientCheckpointConfig):
         2. Customize # ckpt layers assigned to each stage. This takes precedence over `gradient_checkpointing_ratio`.
 
     """
+
     """
     Args:
         gradient_checkpointing_ratio (Optional[float]): The ratio of gradient checkpointing. It can only be used in pipeline parallelism. Defaults to None.
@@ -49,6 +50,7 @@ class PipelineGradientCheckpointConfig(GradientCheckpointConfig):
     num_stages: Optional[int] = None
     num_model_chunks: Optional[int] = None
     num_model_layers: Optional[int] = None
+    num_layers_per_stage: Optional[List[int]] = None
     num_ckpt_layers_per_stage: Optional[List[int]] = None
 
     def __post_init__(self):
@@ -70,6 +72,10 @@ def __post_init__(self):
     def _enable_gradient_checkpointing_ratio(self) -> bool:
         return self.gradient_checkpointing_ratio is not None
 
+    @property
+    def _customize_num_layers_per_stage(self) -> bool:
+        return self.num_layers_per_stage is not None and self.num_model_layers is not None
+
     @property
     def _enable_customized_ckpt_layers_per_stage(self) -> bool:
         return self.num_ckpt_layers_per_stage is not None
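Taken together, the new field and property let a caller pin the exact layer split: `num_layers_per_stage` lists how many layers each pipeline stage owns, and `_customize_num_layers_per_stage` only reports True once both it and `num_model_layers` are set. A minimal sketch of the intended usage (the 32-layer model and the per-stage splits are illustrative values, not taken from this PR):

    from colossalai.shardformer.shard.grad_ckpt_config import PipelineGradientCheckpointConfig

    # Illustrative: a 32-layer model on 4 pipeline stages, front-loading both
    # the layers and the checkpointed layers onto the early stages.
    ckpt_config = PipelineGradientCheckpointConfig(
        num_stages=4,
        num_model_chunks=1,
        num_model_layers=32,
        num_layers_per_stage=[10, 8, 8, 6],       # new field from this PR
        num_ckpt_layers_per_stage=[10, 6, 4, 0],  # takes precedence over the ratio
    )
    assert ckpt_config._customize_num_layers_per_stage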
13 changes: 12 additions & 1 deletion colossalai/shardformer/shard/shard_config.py
@@ -7,7 +7,7 @@
 
 from colossalai.pipeline.stage_manager import PipelineStageManager
 
-from .grad_ckpt_config import GradientCheckpointConfig
+from .grad_ckpt_config import GradientCheckpointConfig, PipelineGradientCheckpointConfig
 
 __all__ = ["ShardConfig"]
 SUPPORT_SP_MODE = ["split_gather", "ring", "all_to_all"]
@@ -30,6 +30,7 @@ class ShardConfig:
         gradient_checkpoint_config (Optional[GradientCheckpointConfig]): The gradient checkpoint config. Defaults to None.
         enable_all_optimization (bool): Whether to turn on all optimization tools including 'fused normalization', 'flash attention', 'JIT fused operators', 'sequence parallelism' and 'sequence overlap'. Defaults to False.
     """
+
     tensor_parallel_process_group: Optional[ProcessGroup] = None
     sequence_parallel_process_group: Optional[ProcessGroup] = None
     pipeline_stage_manager: Optional[PipelineStageManager] = None
@@ -104,6 +105,16 @@ def __post_init__(self):
         else:
             self._sequence_parallel_size = dist.get_world_size(self.sequence_parallel_process_group)
 
+        if (
+            self.pipeline_stage_manager is not None
+            and isinstance(self.gradient_checkpoint_config, PipelineGradientCheckpointConfig)
+            and self.gradient_checkpoint_config._customize_num_layers_per_stage
+        ):
+            self.pipeline_stage_manager.set_distribution_config(
+                self.gradient_checkpoint_config.num_model_layers,
+                self.gradient_checkpoint_config.num_layers_per_stage,
+            )
+
     def _turn_on_all_optimization(self):
         """
         Turn on all optimization.
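The new branch in `__post_init__` is what forwards a customized split to the `PipelineStageManager` before any policy partitions the model. A rough sketch of how this path is reached from the plugin level, assuming the `gradient_checkpoint_config` argument of `HybridParallelPlugin` from the same codebase (the sizes are illustrative):

    from colossalai.booster import Booster
    from colossalai.booster.plugin import HybridParallelPlugin
    from colossalai.shardformer.shard.grad_ckpt_config import PipelineGradientCheckpointConfig

    # A 4-stage pipeline over an assumed 32-layer model. The plugin builds a
    # ShardConfig internally, so its __post_init__ ends up calling
    # pipeline_stage_manager.set_distribution_config(32, [10, 8, 8, 6]).
    plugin = HybridParallelPlugin(
        tp_size=1,
        pp_size=4,
        gradient_checkpoint_config=PipelineGradientCheckpointConfig(
            num_stages=4,
            num_model_chunks=1,
            num_model_layers=32,
            num_layers_per_stage=[10, 8, 8, 6],
        ),
    )
    booster = Booster(plugin=plugin)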