From 39384a1496972c087a4881218ee16c9cff703f9d Mon Sep 17 00:00:00 2001 From: Wenhao Chen Date: Mon, 25 Mar 2024 18:19:16 +0800 Subject: [PATCH 01/21] feat: add `gradient_checkpointing_ratio` in ShardConfig --- colossalai/booster/plugin/hybrid_parallel_plugin.py | 8 ++++++++ colossalai/shardformer/shard/shard_config.py | 5 +++++ 2 files changed, 13 insertions(+) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index f51cb060c356..ce1eee4d4d88 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -969,6 +969,7 @@ def __init__( custom_policy: Policy = None, pp_style: str = "1f1b", num_model_chunks: int = 1, + gradient_checkpointing_ratio: Optional[float] = None, enable_metadata_cache: bool = True, ) -> None: super().__init__() @@ -1032,6 +1033,12 @@ def __init__( self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS) self.pp_group = self.pg_mesh.get_group_along_axis(PP_AXIS) + if gradient_checkpointing_ratio is not None: + if gradient_checkpointing_ratio < 0 or gradient_checkpointing_ratio > 1: + raise ValueError("gradient_checkpointing_ratio should be in 0% to 100%") + warnings.warn( + "gradient_checkpointing_ratio is only used in PipelineParallelism, will be ignored in other parallelism" + ) self.shard_config = ShardConfig( tensor_parallel_process_group=self.tp_group, pipeline_stage_manager=self.stage_manager, @@ -1043,6 +1050,7 @@ def __init__( enable_sequence_parallelism=enable_sequence_parallelism, enable_sequence_overlap=enable_sequence_overlap, parallel_output=parallel_output, + gradient_checkpointing_ratio=gradient_checkpointing_ratio, ) self.amp_config = dict( initial_scale=initial_scale, diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index da27341d9c29..ccd25f004a8a 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -35,6 +35,8 @@ class ShardConfig: enable_sequence_parallelism: bool = False enable_sequence_overlap: bool = False parallel_output: bool = True + # NOTE: FIXME: gradient_checkpointing_ratio is only used in PipelineParallelism + gradient_checkpointing_ratio: Optional[float] = None extra_kwargs: Dict[str, Any] = field(default_factory=dict) # TODO padding vocab # make_vocab_size_divisible_by: int = 128 @@ -62,6 +64,9 @@ def __post_init__(self): if self.enable_all_optimization: self._turn_on_all_optimization() + if self.gradient_checkpointing_ratio is not None and not (0 <= self.gradient_checkpointing_ratio <= 1): + raise ValueError("gradient_checkpointing_ratio should be in 0% to 100%") + def _turn_on_all_optimization(self): """ Turn on all optimization. From 2f7e78e08a1b0987d167e9577f33230bdbd98f08 Mon Sep 17 00:00:00 2001 From: Wenhao Chen Date: Tue, 26 Mar 2024 11:37:48 +0800 Subject: [PATCH 02/21] feat: add `gradient_checkpointing_ratio` for LlamaPipelineForwards --- colossalai/shardformer/modeling/llama.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 29dc8200f338..71dd78f9f724 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -138,13 +138,19 @@ def llama_model_forward( next_decoder_cache = () if use_cache else None start_idx, end_idx = stage_index[0], stage_index[1] + num_ckpt_layers = 0 + if self.gradient_checkpointing and self.training: + num_ckpt_layers = end_idx - start_idx + if shard_config.gradient_checkpointing_ratio is not None: + num_ckpt_layers = int(shard_config.gradient_checkpointing_ratio * num_ckpt_layers) + for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx): if output_hidden_states: all_hidden_states += (hidden_states,) past_key_value = past_key_values[idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: + if idx - start_idx < num_ckpt_layers: def create_custom_forward(module): def custom_forward(*inputs): From da52eb388453a74f24835814d7df9dd3ceed3ea0 Mon Sep 17 00:00:00 2001 From: Wenhao Chen Date: Tue, 26 Mar 2024 14:53:01 +0800 Subject: [PATCH 03/21] chore: update comments and test --- colossalai/booster/plugin/hybrid_parallel_plugin.py | 1 + colossalai/shardformer/shard/shard_config.py | 1 + tests/test_shardformer/test_model/test_shard_llama.py | 1 + 3 files changed, 3 insertions(+) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index ce1eee4d4d88..1a674f7bdc79 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -930,6 +930,7 @@ class HybridParallelPlugin(PipelinePluginBase): custom_policy (Policy, optional): Custom policy for Shardformer. Defaults to None. pp_style (str, optional): The style for pipeline parallelism. Defaults to '1f1b'. num_model_chunks (int, optional): The number of model chunks for interleaved pipeline parallelism. Defaults to 1. + gradient_checkpointing_ratio (float, optional): The ratio [0, 1] of gradient checkpointing. Defaults to None. enable_metadata_cache (bool, optional): Whether to enable metadata cache for pipeline parallelism. Defaults to True. """ diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index ccd25f004a8a..a5219f97512f 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -23,6 +23,7 @@ class ShardConfig: enable_jit_fused (bool, optional): Whether to switch on JIT fused operators. Defaults to False. enable_sequence_parallelism (bool): Whether to turn on sequence parallelism, which partitions non-tensor-parallel regions along the sequence dimension. Defaults to False. enable_sequence_overlap (bool): Whether to turn on sequence overlap, which overlap the computation and communication in sequence parallelism. It can only be used when enable_sequence_parallelism is True. Defaults to False. + gradient_checkpointing_ratio (Optional[float]): The ratio of gradient checkpointing. It can only be used in pipeline parallelism. Defaults to None. enable_all_optimization (bool): Whether to turn on all optimization tools including 'fused normalization', 'flash attention', 'JIT fused operators', 'sequence parallelism' and 'sequence overlap'. Defaults to False. """ tensor_parallel_process_group: Optional[ProcessGroup] = None diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index 126ff23a9f25..289502837a33 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -101,6 +101,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "use_lazy_init": True, "precision": "fp16", "initial_scale": 1, + "gradient_checkpointing_ratio": 0.5, }, { "tp_size": 1, From 19c8407935c5b63f7c07b17b0014eb3a2e75c08a Mon Sep 17 00:00:00 2001 From: Wenhao Chen Date: Tue, 26 Mar 2024 16:15:08 +0800 Subject: [PATCH 04/21] feat: add `AdvancedPipelineConfig` class --- .../booster/plugin/hybrid_parallel_plugin.py | 14 +-- colossalai/shardformer/__init__.py | 2 +- .../shardformer/policies/base_policy.py | 5 + colossalai/shardformer/shard/__init__.py | 4 +- colossalai/shardformer/shard/shard_config.py | 117 ++++++++++++++++-- 5 files changed, 122 insertions(+), 20 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 1a674f7bdc79..e6c925e498af 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -26,7 +26,7 @@ from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.shardformer import ShardConfig, ShardFormer +from colossalai.shardformer import AdvancedPipelineConfig, ShardConfig, ShardFormer from colossalai.shardformer.layer.utils import SeqParallelUtils from colossalai.shardformer.policies.base_policy import Policy from colossalai.tensor.d_tensor.api import is_distributed_tensor @@ -930,7 +930,7 @@ class HybridParallelPlugin(PipelinePluginBase): custom_policy (Policy, optional): Custom policy for Shardformer. Defaults to None. pp_style (str, optional): The style for pipeline parallelism. Defaults to '1f1b'. num_model_chunks (int, optional): The number of model chunks for interleaved pipeline parallelism. Defaults to 1. - gradient_checkpointing_ratio (float, optional): The ratio [0, 1] of gradient checkpointing. Defaults to None. + advanced_pipeline_config (AdvancedPipelineConfig, optional): Advanced pipeline configuration for pipeline parallelism. Defaults to None. enable_metadata_cache (bool, optional): Whether to enable metadata cache for pipeline parallelism. Defaults to True. """ @@ -970,7 +970,7 @@ def __init__( custom_policy: Policy = None, pp_style: str = "1f1b", num_model_chunks: int = 1, - gradient_checkpointing_ratio: Optional[float] = None, + advanced_pipeline_config: Optional[AdvancedPipelineConfig] = None, enable_metadata_cache: bool = True, ) -> None: super().__init__() @@ -1034,12 +1034,6 @@ def __init__( self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS) self.pp_group = self.pg_mesh.get_group_along_axis(PP_AXIS) - if gradient_checkpointing_ratio is not None: - if gradient_checkpointing_ratio < 0 or gradient_checkpointing_ratio > 1: - raise ValueError("gradient_checkpointing_ratio should be in 0% to 100%") - warnings.warn( - "gradient_checkpointing_ratio is only used in PipelineParallelism, will be ignored in other parallelism" - ) self.shard_config = ShardConfig( tensor_parallel_process_group=self.tp_group, pipeline_stage_manager=self.stage_manager, @@ -1051,7 +1045,7 @@ def __init__( enable_sequence_parallelism=enable_sequence_parallelism, enable_sequence_overlap=enable_sequence_overlap, parallel_output=parallel_output, - gradient_checkpointing_ratio=gradient_checkpointing_ratio, + advanced_pipeline_config=advanced_pipeline_config, ) self.amp_config = dict( initial_scale=initial_scale, diff --git a/colossalai/shardformer/__init__.py b/colossalai/shardformer/__init__.py index 77c2af8d18f7..8143054f4236 100644 --- a/colossalai/shardformer/__init__.py +++ b/colossalai/shardformer/__init__.py @@ -1 +1 @@ -from .shard import ShardConfig, ShardFormer +from .shard import AdvancedPipelineConfig, ShardConfig, ShardFormer diff --git a/colossalai/shardformer/policies/base_policy.py b/colossalai/shardformer/policies/base_policy.py index 762e754816bf..0b076c3d6005 100644 --- a/colossalai/shardformer/policies/base_policy.py +++ b/colossalai/shardformer/policies/base_policy.py @@ -199,6 +199,11 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: def distribute_layers(self, num_layers: int, num_stages: int) -> List[int]: """Divide layers into stages""" + if self.shard_config.advanced_pipeline_config is not None: + advanced_pipeline_config = self.shard_config.advanced_pipeline_config + if advanced_pipeline_config.enable_customized_layers_per_stage: + return advanced_pipeline_config.distribute_layers(num_layers, num_stages) + quotient = num_layers // num_stages remainder = num_layers % num_stages diff --git a/colossalai/shardformer/shard/__init__.py b/colossalai/shardformer/shard/__init__.py index acf8a95a41ca..f12b3631f69c 100644 --- a/colossalai/shardformer/shard/__init__.py +++ b/colossalai/shardformer/shard/__init__.py @@ -1,5 +1,5 @@ -from .shard_config import ShardConfig +from .shard_config import AdvancedPipelineConfig, ShardConfig from .sharder import ModelSharder from .shardformer import ShardFormer -__all__ = ["ShardConfig", "ModelSharder", "ShardFormer"] +__all__ = ["ShardConfig", "ModelSharder", "ShardFormer", "AdvancedPipelineConfig"] diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index a5219f97512f..a92a00ca7b81 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -6,7 +6,114 @@ from colossalai.pipeline.stage_manager import PipelineStageManager -__all__ = ["ShardConfig"] +__all__ = ["ShardConfig", "AdvancedPipelineConfig"] + + +class AdvancedPipelineConfig: + r""" + The advanced pipeline config is designed to provide more flexibility for users to customize the pipeline parallelism. + Refer to https://github.com/hpcaitech/ColossalAI/issues/5509 for more details. + + It provides the following features: + 1. `gradient_checkpointing_ratio`: This is used to control gradient checkpointing more precisely, e.g., set 50% of the layers to use gradient checkpointing. + 2. Customize # layers and # ckpt layers assigned to each stage. This takes precedence over `gradient_checkpointing_ratio`. + + """ + + def __init__( + self, + gradient_checkpointing_ratio: Optional[float] = None, + num_stages: Optional[int] = None, + num_model_chunks: Optional[int] = None, + num_model_layers: Optional[int] = None, + num_layers_per_stage: Optional[List[int]] = None, + num_ckpt_layers_per_stage: Optional[List[int]] = None, + ) -> None: + """ + Args: + gradient_checkpointing_ratio (Optional[float]): The ratio of gradient checkpointing. It can only be used in pipeline parallelism. Defaults to None. + num_stages (Optional[int]): Number of stages in the pipeline. Defaults to None. For sanity check. + num_model_chunks (Optional[int]): Number of model chunks (1F1B or Interleaved). Defaults to None. For sanity check. + num_model_layers (Optional[int]): Number of model layers. Defaults to None. For sanity check. + num_layers_per_stage (Optional[List[int]]): Number of layers for each stage. Defaults to None. + num_ckpt_layers_per_stage (Optional[List[int]]): Number of checkpointed layers for each stage. Defaults to None. + + Example 1: + num_stages = 8 + num_layers = 80 + num_model_chunks = 1 + num_layers_per_stage = [9, 9, 9, 10, 11, 10, 11, 11] + num_ckpt_layers_per_stage = [4, 4, 2, 2, 0, 0, 0, 0] + + Example 2: + num_stages = 4 + num_layers = 80 + num_model_chunks = 2 + num_layers_per_stage = [9, 9, 9, 10, 11, 10, 11, 11] + # device 0 holds num_layers_per_stage[0] and num_layers_per_stage[4] layers + ... + + """ + self.gradient_checkpointing_ratio = gradient_checkpointing_ratio + self.num_stages = num_stages + self.num_model_chunks = num_model_chunks + self.num_model_layers = num_model_layers + self.num_layers_per_stage = num_layers_per_stage + self.num_ckpt_layers_per_stage = num_ckpt_layers_per_stage + self._sanity_check() + + @property + def enable_gradient_checkpointing_ratio(self) -> bool: + return self.gradient_checkpointing_ratio is not None + + @property + def enable_customized_layers_per_stage(self) -> bool: + return self.num_layers_per_stage is not None + + @property + def enable_customized_ckpt_layers_per_stage(self) -> bool: + return self.num_ckpt_layers_per_stage is not None + + def _sanity_check(self): + if self.gradient_checkpointing_ratio is not None: + if not (0 <= self.gradient_checkpointing_ratio <= 1): + raise ValueError("gradient_checkpointing_ratio should be in 0% to 100%") + + if self.num_layers_per_stage is not None: + assert ( + self.num_stages is not None and self.num_model_chunks is not None and self.num_model_layers is not None + ) + assert all([0 < num_layers < self.num_model_layers for num_layers in self.num_layers_per_stage]) + assert sum(self.num_layers_per_stage) == self.num_model_layers + assert len(self.num_layers_per_stage) == self.num_stages * self.num_model_chunks + + if self.num_ckpt_layers_per_stage is not None: + assert self.num_layers_per_stage is not None + assert len(self.num_layers_per_stage) == len(self.num_ckpt_layers_per_stage) + assert all( + [ + 0 <= num_ckpt_layers <= num_layers + for num_ckpt_layers, num_layers in zip(self.num_ckpt_layers_per_stage, self.num_layers_per_stage) + ] + ) + self.gradient_checkpointing_ratio = sum(self.num_ckpt_layers_per_stage) / sum(self.num_layers_per_stage) + + def distribute_layers(self, num_layers: int, num_stages: int) -> List[int]: + assert self.enable_customized_layers_per_stage + assert num_layers == self.num_model_layers and num_stages == self.num_stages + return self.num_layers_per_stage + + def get_num_ckpt_layers(self, stage: int, num_layers: int, model_chunk_id: int = 1) -> int: + if self.enable_customized_layers_per_stage: + assert stage <= self.num_stages and model_chunk_id <= self.num_model_chunks + assert num_layers == self.num_layers_per_stage[stage] + + if self.enable_customized_ckpt_layers_per_stage: + return self.num_ckpt_layers_per_stage[stage] + elif self.enable_gradient_checkpointing_ratio: + return int(self.gradient_checkpointing_ratio * num_layers) + else: + raise RuntimeError("No checkpointed layers information is provided") @dataclass @@ -23,7 +130,7 @@ class ShardConfig: enable_jit_fused (bool, optional): Whether to switch on JIT fused operators. Defaults to False. enable_sequence_parallelism (bool): Whether to turn on sequence parallelism, which partitions non-tensor-parallel regions along the sequence dimension. Defaults to False. enable_sequence_overlap (bool): Whether to turn on sequence overlap, which overlap the computation and communication in sequence parallelism. It can only be used when enable_sequence_parallelism is True. Defaults to False. - gradient_checkpointing_ratio (Optional[float]): The ratio of gradient checkpointing. It can only be used in pipeline parallelism. Defaults to None. + advanced_pipeline_config (Optional[AdvancedPipelineConfig]): The advanced pipeline config for more flexibility in pipeline parallelism. Defaults to None. enable_all_optimization (bool): Whether to turn on all optimization tools including 'fused normalization', 'flash attention', 'JIT fused operators', 'sequence parallelism' and 'sequence overlap'. Defaults to False. """ tensor_parallel_process_group: Optional[ProcessGroup] = None @@ -36,8 +143,7 @@ class ShardConfig: enable_sequence_parallelism: bool = False enable_sequence_overlap: bool = False parallel_output: bool = True - # NOTE: FIXME: gradient_checkpointing_ratio is only used in PipelineParallelism - gradient_checkpointing_ratio: Optional[float] = None + advanced_pipeline_config: Optional[AdvancedPipelineConfig] = None extra_kwargs: Dict[str, Any] = field(default_factory=dict) # TODO padding vocab # make_vocab_size_divisible_by: int = 128 @@ -65,9 +171,6 @@ def __post_init__(self): if self.enable_all_optimization: self._turn_on_all_optimization() - if self.gradient_checkpointing_ratio is not None and not (0 <= self.gradient_checkpointing_ratio <= 1): - raise ValueError("gradient_checkpointing_ratio should be in 0% to 100%") - def _turn_on_all_optimization(self): """ Turn on all optimization. From d40bd55d36cfedf4291368595334b74af066f453 Mon Sep 17 00:00:00 2001 From: Wenhao Chen Date: Wed, 27 Mar 2024 12:06:54 +0800 Subject: [PATCH 05/21] feat: apply `AdvancedPipelineConfig` to policy and llama_forward --- colossalai/shardformer/modeling/llama.py | 10 +++- .../shardformer/policies/base_policy.py | 2 +- colossalai/shardformer/shard/shard_config.py | 52 ++++++++++--------- 3 files changed, 37 insertions(+), 27 deletions(-) diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 71dd78f9f724..2ef2b5f0fc7d 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -141,8 +141,14 @@ def llama_model_forward( num_ckpt_layers = 0 if self.gradient_checkpointing and self.training: num_ckpt_layers = end_idx - start_idx - if shard_config.gradient_checkpointing_ratio is not None: - num_ckpt_layers = int(shard_config.gradient_checkpointing_ratio * num_ckpt_layers) + if shard_config.advanced_pipeline_config is not None: + advanced_pipeline_config = shard_config.advanced_pipeline_config + if advanced_pipeline_config.control_gradient_checkpointing: + num_ckpt_layers = advanced_pipeline_config.get_num_ckpt_layers( + stage_manager.stage, + end_idx - start_idx, + stage_manager.model_chunk_id if stage_manager.is_interleave else 0, + ) for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx): if output_hidden_states: diff --git a/colossalai/shardformer/policies/base_policy.py b/colossalai/shardformer/policies/base_policy.py index 0b076c3d6005..4d541ed6b027 100644 --- a/colossalai/shardformer/policies/base_policy.py +++ b/colossalai/shardformer/policies/base_policy.py @@ -201,7 +201,7 @@ def distribute_layers(self, num_layers: int, num_stages: int) -> List[int]: """Divide layers into stages""" if self.shard_config.advanced_pipeline_config is not None: advanced_pipeline_config = self.shard_config.advanced_pipeline_config - if advanced_pipeline_config.enable_customized_layers_per_stage: + if advanced_pipeline_config.control_distribute_layers: return advanced_pipeline_config.distribute_layers(num_layers, num_stages) quotient = num_layers // num_stages diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index a92a00ca7b81..a9468b4a515b 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional import torch.distributed as dist from torch.distributed import ProcessGroup @@ -62,24 +62,12 @@ def __init__( self.num_ckpt_layers_per_stage = num_ckpt_layers_per_stage self._sanity_check() - @property - def enable_gradient_checkpointing_ratio(self) -> bool: - return self.gradient_checkpointing_ratio is not None - - @property - def enable_customized_layers_per_stage(self) -> bool: - return self.num_layers_per_stage is not None - - @property - def enable_customized_ckpt_layers_per_stage(self) -> bool: - return self.num_ckpt_layers_per_stage is not None - def _sanity_check(self): - if self.gradient_checkpointing_ratio is not None: + if self._enable_gradient_checkpointing_ratio: if not (0 <= self.gradient_checkpointing_ratio <= 1): raise ValueError("gradient_checkpointing_ratio should be in 0% to 100%") - if self.num_layers_per_stage is not None: + if self.control_distribute_layers: assert ( self.num_stages is not None and self.num_model_chunks is not None and self.num_model_layers is not None ) @@ -87,7 +75,7 @@ def _sanity_check(self): assert sum(self.num_layers_per_stage) == self.num_model_layers assert len(self.num_layers_per_stage) == self.num_stages * self.num_model_chunks - if self.num_ckpt_layers_per_stage is not None: + if self._enable_customized_ckpt_layers_per_stage: assert self.num_layers_per_stage is not None assert len(self.num_layers_per_stage) == len(self.num_ckpt_layers_per_stage) assert all( @@ -98,19 +86,35 @@ def _sanity_check(self): ) self.gradient_checkpointing_ratio = sum(self.num_ckpt_layers_per_stage) / sum(self.num_layers_per_stage) + @property + def control_gradient_checkpointing(self) -> bool: + return self._enable_gradient_checkpointing_ratio or self._enable_customized_ckpt_layers_per_stage + + @property + def control_distribute_layers(self) -> bool: + return self.num_layers_per_stage is not None + + @property + def _enable_gradient_checkpointing_ratio(self) -> bool: + return self.gradient_checkpointing_ratio is not None + + @property + def _enable_customized_ckpt_layers_per_stage(self) -> bool: + return self.num_ckpt_layers_per_stage is not None + def distribute_layers(self, num_layers: int, num_stages: int) -> List[int]: - assert self.enable_customized_layers_per_stage - assert num_layers == self.num_model_layers and num_stages == self.num_stages + assert self.control_distribute_layers + assert num_layers == self.num_model_layers and num_stages == self.num_stages * self.num_model_chunks return self.num_layers_per_stage - def get_num_ckpt_layers(self, stage: int, num_layers: int, model_chunk_id: int = 1) -> int: - if self.enable_customized_layers_per_stage: + def get_num_ckpt_layers(self, stage: int, num_layers: int, model_chunk_id: int = 0) -> int: + if self.control_distribute_layers: assert stage <= self.num_stages and model_chunk_id <= self.num_model_chunks - assert num_layers == self.num_layers_per_stage[stage] + assert num_layers == self.num_layers_per_stage[stage + model_chunk_id * self.num_stages] - if self.enable_customized_ckpt_layers_per_stage: - return self.num_ckpt_layers_per_stage[stage] - elif self.enable_gradient_checkpointing_ratio: + if self._enable_customized_ckpt_layers_per_stage: + return self.num_ckpt_layers_per_stage[stage + model_chunk_id * self.num_stages] + elif self._enable_gradient_checkpointing_ratio: return int(self.gradient_checkpointing_ratio * num_layers) else: raise RuntimeError("No checkpointed layers information is provided") From 0ef9e6add3a38b0c009d835fff133fec8ca686a7 Mon Sep 17 00:00:00 2001 From: Wenhao Chen Date: Wed, 27 Mar 2024 12:07:37 +0800 Subject: [PATCH 06/21] test: update llama tests --- tests/kit/model_zoo/transformers/llama.py | 6 +++--- .../test_model/test_shard_llama.py | 20 ++++++++++++++++++- 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/tests/kit/model_zoo/transformers/llama.py b/tests/kit/model_zoo/transformers/llama.py index 4730642705ff..9f801e0cc732 100644 --- a/tests/kit/model_zoo/transformers/llama.py +++ b/tests/kit/model_zoo/transformers/llama.py @@ -49,9 +49,9 @@ def data_gen_for_casual_lm(): loss_fn_for_seq_classification = lambda output: output["logits"].mean() config = LlamaConfig( - num_hidden_layers=4, - hidden_size=128, - intermediate_size=256, + num_hidden_layers=8, + hidden_size=32, + intermediate_size=64, num_attention_heads=4, max_position_embeddings=128, num_labels=16, diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index 289502837a33..be4164e62d70 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -5,6 +5,7 @@ import colossalai from colossalai.logging import disable_existing_loggers +from colossalai.shardformer import AdvancedPipelineConfig from colossalai.shardformer.layer.utils import Randomizer from colossalai.tensor.d_tensor.api import clear_layout_converter from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn @@ -24,9 +25,13 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config): + enable_gradient_checkpointing = test_config.pop("enable_gradient_checkpointing", False) org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin( model_fn, loss_fn, test_config ) + if enable_gradient_checkpointing: + org_model.gradient_checkpointing_enable() + sharded_model.unwrap().gradient_checkpointing_enable() org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin( org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster @@ -101,7 +106,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "use_lazy_init": True, "precision": "fp16", "initial_scale": 1, - "gradient_checkpointing_ratio": 0.5, + "enable_gradient_checkpointing": True, + "advanced_pipeline_config": AdvancedPipelineConfig(gradient_checkpointing_ratio=0.5), }, { "tp_size": 1, @@ -109,6 +115,10 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "num_microbatches": 4, "use_lazy_init": False, "precision": "fp32", + "enable_gradient_checkpointing": True, + "advanced_pipeline_config": AdvancedPipelineConfig( + num_stages=2, num_model_chunks=1, num_model_layers=8, num_layers_per_stage=[5, 3] + ), }, { "tp_size": 4, @@ -190,6 +200,14 @@ def run_llama_test(test_config): "precision": "fp16", "zero_stage": 1, "initial_scale": 1, + "enable_gradient_checkpointing": True, + "advanced_pipeline_config": AdvancedPipelineConfig( + num_stages=2, + num_model_chunks=2, + num_model_layers=8, + num_layers_per_stage=[3, 3, 1, 1], + num_ckpt_layers_per_stage=[0, 0, 1, 1], + ), }, ], ) From 7156110bc6fdae7092b3d959615f9c3499f4babf Mon Sep 17 00:00:00 2001 From: Wenhao Chen Date: Wed, 27 Mar 2024 14:11:20 +0800 Subject: [PATCH 07/21] fix: fix typo --- colossalai/shardformer/policies/llama.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index daa7708c8fdf..e38c8af03560 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -185,9 +185,6 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config ) } - self.append_or_create_method_replacement( - description=method_replacement, policy=policy, target_key=model_cls - ) self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls) From 272e0a3df1c061dafe4f57b50a07ceacceadd189 Mon Sep 17 00:00:00 2001 From: Wenhao Chen Date: Wed, 27 Mar 2024 15:42:48 +0800 Subject: [PATCH 08/21] fix: fix test cases --- colossalai/shardformer/policies/base_policy.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/colossalai/shardformer/policies/base_policy.py b/colossalai/shardformer/policies/base_policy.py index 4d541ed6b027..6e8d9f22984f 100644 --- a/colossalai/shardformer/policies/base_policy.py +++ b/colossalai/shardformer/policies/base_policy.py @@ -199,9 +199,9 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: def distribute_layers(self, num_layers: int, num_stages: int) -> List[int]: """Divide layers into stages""" - if self.shard_config.advanced_pipeline_config is not None: + if self.shard_config is not None: advanced_pipeline_config = self.shard_config.advanced_pipeline_config - if advanced_pipeline_config.control_distribute_layers: + if advanced_pipeline_config is not None and advanced_pipeline_config.control_distribute_layers: return advanced_pipeline_config.distribute_layers(num_layers, num_stages) quotient = num_layers // num_stages From c2643486a6852b03d8ce60963e1394c2d334b273 Mon Sep 17 00:00:00 2001 From: Wenhao Chen Date: Thu, 28 Mar 2024 16:46:58 +0800 Subject: [PATCH 09/21] to: move `distribute_layer` and `get_stage_index` to PipelineStageManager --- colossalai/pipeline/stage_manager.py | 69 ++++++++++++++++++- .../shardformer/policies/base_policy.py | 54 +-------------- 2 files changed, 69 insertions(+), 54 deletions(-) diff --git a/colossalai/pipeline/stage_manager.py b/colossalai/pipeline/stage_manager.py index c8f9042084da..4751ee80ba07 100644 --- a/colossalai/pipeline/stage_manager.py +++ b/colossalai/pipeline/stage_manager.py @@ -1,6 +1,7 @@ import contextlib -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple, Union +import numpy as np import torch.distributed as dist from torch.distributed import ProcessGroup @@ -29,6 +30,8 @@ def __init__( ) -> None: assert enable_interleave or num_model_chunks == 1, "num_model_chunks must be 1 when enable_interleave is False" + self.num_layers_per_stage = None + self.pg_mesh = pg_mesh self.pipeline_axis = pipeline_axis self.prev_rank: Optional[Tuple[int, ...]] = None @@ -69,6 +72,70 @@ def __init__( # for shardformer, hold model chunk id self.model_chunk_id: Optional[int] = None + @property + def control_distribute_layers(self) -> bool: + return self.num_layers_per_stage is not None + + def set_distribution_config(self, num_model_layers: int, num_layers_per_stage: List[int]) -> None: + """Set the distribution configuration. + + Args: + num_model_layers (int): Number of layers in the model. + num_layers_per_stage (List[int]): Number of layers for each stage. + """ + assert all([0 < num_layers < num_model_layers for num_layers in num_layers_per_stage]) + assert sum(num_layers_per_stage) == num_model_layers + assert len(num_layers_per_stage) == self.num_stages * (self.num_model_chunks if self.is_interleave else 1) + self.num_model_layers = num_model_layers + self.num_layers_per_stage = num_layers_per_stage + + def distribute_layers(self, num_layers: int) -> List[int]: + """Divide layers into stages""" + if self.control_distribute_layers: + assert num_layers == self.num_model_layers + return self.num_layers_per_stage + + else: + quotient = num_layers // self.num_stages + remainder = num_layers % self.num_stages + + # calculate the num_layers per stage + layers_per_stage = [quotient] * self.num_stages + + # deal with the rest layers + if remainder > 0: + start_position = self.num_stages // 2 - remainder // 2 + for i in range(start_position, start_position + remainder): + layers_per_stage[i] += 1 + return layers_per_stage + + def get_stage_index( + self, + layers_per_stage: List[int], + ) -> Union[Tuple[int, int], List[Tuple[int, int]]]: + """ + Get the start index and end index of layers for each stage. + + Args: + layers_per_stage (List[int]): number of layers for each stage + stage (int): the stage index + + Returns: + - Tuple[int, int]: the start index and end index of this stage + - List[Tuple[int, int]]: the start index and end index of this stage for each model chunk + + """ + num_layers_per_stage_accumulated = np.insert(np.cumsum(layers_per_stage), 0, 0) + + stage_indices = [] + num_model_chunks = self.num_model_chunks if self.is_interleave else 1 + for model_chunk in range(num_model_chunks): + start_idx = num_layers_per_stage_accumulated[self.stage + model_chunk * self.num_stages] + end_idx = num_layers_per_stage_accumulated[self.stage + model_chunk * self.num_stages + 1] + stage_indices.append([start_idx, end_idx]) + + return stage_indices[0] if num_model_chunks == 1 else stage_indices + def is_first_stage(self, ignore_chunk: bool = False) -> bool: """Is the current stage the first stage. diff --git a/colossalai/shardformer/policies/base_policy.py b/colossalai/shardformer/policies/base_policy.py index 6e8d9f22984f..d67ab0a3c6bb 100644 --- a/colossalai/shardformer/policies/base_policy.py +++ b/colossalai/shardformer/policies/base_policy.py @@ -2,9 +2,8 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Union -import numpy as np import torch.nn as nn from torch import Tensor from torch.nn import Module @@ -196,54 +195,3 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: List[Dict[int, Tensor]]: List of parameters that should be shared across stages. E.g. [{0: module.model.embed_tokens.weight, 3: module.lm_head.weight}] """ return [] - - def distribute_layers(self, num_layers: int, num_stages: int) -> List[int]: - """Divide layers into stages""" - if self.shard_config is not None: - advanced_pipeline_config = self.shard_config.advanced_pipeline_config - if advanced_pipeline_config is not None and advanced_pipeline_config.control_distribute_layers: - return advanced_pipeline_config.distribute_layers(num_layers, num_stages) - - quotient = num_layers // num_stages - remainder = num_layers % num_stages - - # calculate the num_layers per stage - layers_per_stage = [quotient] * num_stages - - # deal with the rest layers - if remainder > 0: - start_position = num_stages // 2 - remainder // 2 - for i in range(start_position, start_position + remainder): - layers_per_stage[i] += 1 - return layers_per_stage - - def get_stage_index( - self, - layers_per_stage: List[int], - stage: int, - num_model_chunks: int = 1, - num_stages: int = 0, - ) -> Union[Tuple[int, int], List[Tuple[int, int]]]: - """ - Get the start index and end index of layers for each stage. - - Args: - layers_per_stage (List[int]): number of layers for each stage - stage (int): the stage index - num_stages (int): number of stages - num_model_chunks (int): number of model chunks - - Returns: - - Tuple[int, int]: the start index and end index of this stage - - List[Tuple[int, int]]: the start index and end index of this stage for each model chunk - - """ - num_layers_per_stage_accumulated = np.insert(np.cumsum(layers_per_stage), 0, 0) - - stage_indices = [] - for model_chunk in range(num_model_chunks): - start_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages] - end_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages + 1] - stage_indices.append([start_idx, end_idx]) - - return stage_indices[0] if num_model_chunks == 1 else stage_indices From 8d25b6b985d78cbd9e2135a31c80822de0796e2e Mon Sep 17 00:00:00 2001 From: Wenhao Chen Date: Thu, 28 Mar 2024 17:06:50 +0800 Subject: [PATCH 10/21] feat: change `AdvancedPipelineConfig` to `PipelineGradientConfig` --- .../booster/plugin/hybrid_parallel_plugin.py | 8 +-- colossalai/shardformer/__init__.py | 2 +- colossalai/shardformer/shard/__init__.py | 4 +- colossalai/shardformer/shard/shard_config.py | 54 +++++++------------ 4 files changed, 25 insertions(+), 43 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index e6c925e498af..fb9cf1a3e3cf 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -26,7 +26,7 @@ from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.shardformer import AdvancedPipelineConfig, ShardConfig, ShardFormer +from colossalai.shardformer import PipelineGradientConfig, ShardConfig, ShardFormer from colossalai.shardformer.layer.utils import SeqParallelUtils from colossalai.shardformer.policies.base_policy import Policy from colossalai.tensor.d_tensor.api import is_distributed_tensor @@ -930,7 +930,7 @@ class HybridParallelPlugin(PipelinePluginBase): custom_policy (Policy, optional): Custom policy for Shardformer. Defaults to None. pp_style (str, optional): The style for pipeline parallelism. Defaults to '1f1b'. num_model_chunks (int, optional): The number of model chunks for interleaved pipeline parallelism. Defaults to 1. - advanced_pipeline_config (AdvancedPipelineConfig, optional): Advanced pipeline configuration for pipeline parallelism. Defaults to None. + pipeline_gradient_config (AdvancedPipelineConfig, optional): The configuration for pipeline parallelism. Defaults to None. enable_metadata_cache (bool, optional): Whether to enable metadata cache for pipeline parallelism. Defaults to True. """ @@ -970,7 +970,7 @@ def __init__( custom_policy: Policy = None, pp_style: str = "1f1b", num_model_chunks: int = 1, - advanced_pipeline_config: Optional[AdvancedPipelineConfig] = None, + pipeline_gradient_config: Optional[PipelineGradientConfig] = None, enable_metadata_cache: bool = True, ) -> None: super().__init__() @@ -1045,7 +1045,7 @@ def __init__( enable_sequence_parallelism=enable_sequence_parallelism, enable_sequence_overlap=enable_sequence_overlap, parallel_output=parallel_output, - advanced_pipeline_config=advanced_pipeline_config, + pipeline_gradient_config=pipeline_gradient_config, ) self.amp_config = dict( initial_scale=initial_scale, diff --git a/colossalai/shardformer/__init__.py b/colossalai/shardformer/__init__.py index 8143054f4236..9f49eabaf8f9 100644 --- a/colossalai/shardformer/__init__.py +++ b/colossalai/shardformer/__init__.py @@ -1 +1 @@ -from .shard import AdvancedPipelineConfig, ShardConfig, ShardFormer +from .shard import PipelineGradientConfig, ShardConfig, ShardFormer diff --git a/colossalai/shardformer/shard/__init__.py b/colossalai/shardformer/shard/__init__.py index f12b3631f69c..b3782152e460 100644 --- a/colossalai/shardformer/shard/__init__.py +++ b/colossalai/shardformer/shard/__init__.py @@ -1,5 +1,5 @@ -from .shard_config import AdvancedPipelineConfig, ShardConfig +from .shard_config import PipelineGradientConfig, ShardConfig from .sharder import ModelSharder from .shardformer import ShardFormer -__all__ = ["ShardConfig", "ModelSharder", "ShardFormer", "AdvancedPipelineConfig"] +__all__ = ["ShardConfig", "ModelSharder", "ShardFormer", "PipelineGradientConfig"] diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index a9468b4a515b..4b96b7aa702c 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -6,17 +6,18 @@ from colossalai.pipeline.stage_manager import PipelineStageManager -__all__ = ["ShardConfig", "AdvancedPipelineConfig"] +__all__ = ["ShardConfig", "PipelineGradientConfig"] -class AdvancedPipelineConfig: +class PipelineGradientConfig: r""" - The advanced pipeline config is designed to provide more flexibility for users to customize the pipeline parallelism. + The pipeline gradient config is designed to provide more flexibility for users to control gradient checkpoint in pipeline parallelism. + Combined with PipelineStageManager.set_distribution_config, user can fully control the distribution of layers and checkpointed layers in pipeline parallelism. Refer to https://github.com/hpcaitech/ColossalAI/issues/5509 for more details. It provides the following features: 1. `gradient_checkpointing_ratio`: This is used to control gradient checkpointing more precisely, e.g., set 50% of the layers to use gradient checkpointing. - 2. Customize # layers and # ckpt layers assigned to each stage. This takes precedence over `gradient_checkpointing_ratio`. + 2. Customize # ckpt layers assigned to each stage. This takes precedence over `gradient_checkpointing_ratio`. """ @@ -26,7 +27,6 @@ def __init__( num_stages: Optional[int] = None, num_model_chunks: Optional[int] = None, num_model_layers: Optional[int] = None, - num_layers_per_stage: Optional[List[int]] = None, num_ckpt_layers_per_stage: Optional[List[int]] = None, ) -> None: """ @@ -35,7 +35,6 @@ def __init__( num_stages (Optional[int]): Number of stages in the pipeline. Defaults to None. For sanity check. num_model_chunks (Optional[int]): Number of model chunks (1F1B or Interleaved). Defaults to None. For sanity check. num_model_layers (Optional[int]): Number of model layers. Defaults to None. For sanity check. - num_layers_per_stage (Optional[List[int]]): Number of layers for each stage. Defaults to None. num_ckpt_layers_per_stage (Optional[List[int]]): Number of checkpointed layers for each stage. Defaults to None. Example 1: @@ -58,7 +57,6 @@ def __init__( self.num_stages = num_stages self.num_model_chunks = num_model_chunks self.num_model_layers = num_model_layers - self.num_layers_per_stage = num_layers_per_stage self.num_ckpt_layers_per_stage = num_ckpt_layers_per_stage self._sanity_check() @@ -67,33 +65,20 @@ def _sanity_check(self): if not (0 <= self.gradient_checkpointing_ratio <= 1): raise ValueError("gradient_checkpointing_ratio should be in 0% to 100%") - if self.control_distribute_layers: + if self._enable_customized_ckpt_layers_per_stage: assert ( self.num_stages is not None and self.num_model_chunks is not None and self.num_model_layers is not None ) - assert all([0 < num_layers < self.num_model_layers for num_layers in self.num_layers_per_stage]) - assert sum(self.num_layers_per_stage) == self.num_model_layers - assert len(self.num_layers_per_stage) == self.num_stages * self.num_model_chunks - - if self._enable_customized_ckpt_layers_per_stage: - assert self.num_layers_per_stage is not None - assert len(self.num_layers_per_stage) == len(self.num_ckpt_layers_per_stage) + assert len(self.num_ckpt_layers_per_stage) == self.num_stages * self.num_model_chunks assert all( - [ - 0 <= num_ckpt_layers <= num_layers - for num_ckpt_layers, num_layers in zip(self.num_ckpt_layers_per_stage, self.num_layers_per_stage) - ] + [0 <= num_ckpt_layers < self.num_model_layers for num_ckpt_layers in self.num_ckpt_layers_per_stage] ) - self.gradient_checkpointing_ratio = sum(self.num_ckpt_layers_per_stage) / sum(self.num_layers_per_stage) + self.gradient_checkpointing_ratio = sum(self.num_ckpt_layers_per_stage) / self.num_model_layers @property def control_gradient_checkpointing(self) -> bool: return self._enable_gradient_checkpointing_ratio or self._enable_customized_ckpt_layers_per_stage - @property - def control_distribute_layers(self) -> bool: - return self.num_layers_per_stage is not None - @property def _enable_gradient_checkpointing_ratio(self) -> bool: return self.gradient_checkpointing_ratio is not None @@ -102,22 +87,19 @@ def _enable_gradient_checkpointing_ratio(self) -> bool: def _enable_customized_ckpt_layers_per_stage(self) -> bool: return self.num_ckpt_layers_per_stage is not None - def distribute_layers(self, num_layers: int, num_stages: int) -> List[int]: - assert self.control_distribute_layers - assert num_layers == self.num_model_layers and num_stages == self.num_stages * self.num_model_chunks - return self.num_layers_per_stage - def get_num_ckpt_layers(self, stage: int, num_layers: int, model_chunk_id: int = 0) -> int: - if self.control_distribute_layers: - assert stage <= self.num_stages and model_chunk_id <= self.num_model_chunks - assert num_layers == self.num_layers_per_stage[stage + model_chunk_id * self.num_stages] + if not self.control_gradient_checkpointing: + raise RuntimeError("No checkpointed layers information is provided") if self._enable_customized_ckpt_layers_per_stage: - return self.num_ckpt_layers_per_stage[stage + model_chunk_id * self.num_stages] + assert stage <= self.num_stages and model_chunk_id <= self.num_model_chunks + num_ckpt_layers = self.num_ckpt_layers_per_stage[stage + model_chunk_id * self.num_stages] + assert num_ckpt_layers <= num_layers + return num_ckpt_layers elif self._enable_gradient_checkpointing_ratio: return int(self.gradient_checkpointing_ratio * num_layers) else: - raise RuntimeError("No checkpointed layers information is provided") + raise NotImplementedError() @dataclass @@ -134,7 +116,7 @@ class ShardConfig: enable_jit_fused (bool, optional): Whether to switch on JIT fused operators. Defaults to False. enable_sequence_parallelism (bool): Whether to turn on sequence parallelism, which partitions non-tensor-parallel regions along the sequence dimension. Defaults to False. enable_sequence_overlap (bool): Whether to turn on sequence overlap, which overlap the computation and communication in sequence parallelism. It can only be used when enable_sequence_parallelism is True. Defaults to False. - advanced_pipeline_config (Optional[AdvancedPipelineConfig]): The advanced pipeline config for more flexibility in pipeline parallelism. Defaults to None. + pipeline_gradient_config (Optional[PipelineGradientConfig]): The pipeline gradient config for more flexibility in pipeline parallelism. Defaults to None. enable_all_optimization (bool): Whether to turn on all optimization tools including 'fused normalization', 'flash attention', 'JIT fused operators', 'sequence parallelism' and 'sequence overlap'. Defaults to False. """ tensor_parallel_process_group: Optional[ProcessGroup] = None @@ -147,7 +129,7 @@ class ShardConfig: enable_sequence_parallelism: bool = False enable_sequence_overlap: bool = False parallel_output: bool = True - advanced_pipeline_config: Optional[AdvancedPipelineConfig] = None + pipeline_gradient_config: Optional[PipelineGradientConfig] = None extra_kwargs: Dict[str, Any] = field(default_factory=dict) # TODO padding vocab # make_vocab_size_divisible_by: int = 128 From 421f3b0bc04422f86e7194be8e7b07d67572f479 Mon Sep 17 00:00:00 2001 From: Wenhao Chen Date: Thu, 28 Mar 2024 17:31:56 +0800 Subject: [PATCH 11/21] fix: fix changed API calls --- colossalai/pipeline/stage_manager.py | 10 ++++---- colossalai/shardformer/modeling/llama.py | 9 +++---- colossalai/shardformer/policies/llama.py | 30 +++++++----------------- 3 files changed, 19 insertions(+), 30 deletions(-) diff --git a/colossalai/pipeline/stage_manager.py b/colossalai/pipeline/stage_manager.py index 4751ee80ba07..77bf6002294f 100644 --- a/colossalai/pipeline/stage_manager.py +++ b/colossalai/pipeline/stage_manager.py @@ -78,6 +78,7 @@ def control_distribute_layers(self) -> bool: def set_distribution_config(self, num_model_layers: int, num_layers_per_stage: List[int]) -> None: """Set the distribution configuration. + This allows user to customize the number of layers for each stage. Args: num_model_layers (int): Number of layers in the model. @@ -96,15 +97,16 @@ def distribute_layers(self, num_layers: int) -> List[int]: return self.num_layers_per_stage else: - quotient = num_layers // self.num_stages - remainder = num_layers % self.num_stages + num_model_chunk = self.num_model_chunks if self.is_interleave else 1 + quotient = num_layers // (self.num_stages * num_model_chunk) + remainder = num_layers % (self.num_stages * num_model_chunk) # calculate the num_layers per stage - layers_per_stage = [quotient] * self.num_stages + layers_per_stage = [quotient] * self.num_stages * num_model_chunk # deal with the rest layers if remainder > 0: - start_position = self.num_stages // 2 - remainder // 2 + start_position = (self.num_stages * num_model_chunk) // 2 - remainder // 2 for i in range(start_position, start_position + remainder): layers_per_stage[i] += 1 return layers_per_stage diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 2ef2b5f0fc7d..a55fa439c79f 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -141,14 +141,15 @@ def llama_model_forward( num_ckpt_layers = 0 if self.gradient_checkpointing and self.training: num_ckpt_layers = end_idx - start_idx - if shard_config.advanced_pipeline_config is not None: - advanced_pipeline_config = shard_config.advanced_pipeline_config - if advanced_pipeline_config.control_gradient_checkpointing: - num_ckpt_layers = advanced_pipeline_config.get_num_ckpt_layers( + if shard_config.pipeline_gradient_config is not None: + pipeline_gradient_config = shard_config.pipeline_gradient_config + if pipeline_gradient_config.control_gradient_checkpointing: + num_ckpt_layers = pipeline_gradient_config.get_num_ckpt_layers( stage_manager.stage, end_idx - start_idx, stage_manager.model_chunk_id if stage_manager.is_interleave else 0, ) + assert num_ckpt_layers <= end_idx - start_idx for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx): if output_hidden_states: diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index e38c8af03560..18d79f84a765 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -164,22 +164,15 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli module = self.model.model if stage_manager.is_interleave: - layers_per_stage = self.distribute_layers( - len(module.layers), stage_manager.num_stages * stage_manager.num_model_chunks - ) - stage_manager.stage_indices = self.get_stage_index( - layers_per_stage, - stage_manager.stage, - num_model_chunks=stage_manager.num_model_chunks, - num_stages=stage_manager.num_stages, - ) + layers_per_stage = stage_manager.distribute_layers(len(module.layers)) + stage_manager.stage_indices = stage_manager.get_stage_index(layers_per_stage) method_replacement = { "forward": partial(new_forward, stage_manager=stage_manager, shard_config=self.shard_config) } else: - layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages) - stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage) + layers_per_stage = stage_manager.distribute_layers(len(module.layers)) + stage_index = stage_manager.get_stage_index(layers_per_stage) method_replacement = { "forward": partial( new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config @@ -201,15 +194,8 @@ def get_held_layers(self) -> List[Module]: held_layers = [] if stage_manager.is_interleave: assert stage_manager.num_model_chunks is not None - layers_per_stage = self.distribute_layers( - len(module.layers), stage_manager.num_stages * stage_manager.num_model_chunks - ) - stage_indices = self.get_stage_index( - layers_per_stage, - stage_manager.stage, - num_model_chunks=stage_manager.num_model_chunks, - num_stages=stage_manager.num_stages, - ) + layers_per_stage = stage_manager.distribute_layers(len(module.layers)) + stage_indices = stage_manager.get_stage_index(layers_per_stage) if stage_manager.is_first_stage(ignore_chunk=True): held_layers.append(module.embed_tokens) for start_idx, end_idx in stage_indices: @@ -218,10 +204,10 @@ def get_held_layers(self) -> List[Module]: held_layers.append(module.norm) else: - layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages) + layers_per_stage = stage_manager.distribute_layers(len(module.layers)) if stage_manager.is_first_stage(): held_layers.append(module.embed_tokens) - start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage) held_layers.extend(module.layers[start_idx:end_idx]) if stage_manager.is_last_stage(): held_layers.append(module.norm) From 08c03cfbe0fa0f9843cb68dac168e5f39ffdc842 Mon Sep 17 00:00:00 2001 From: Wenhao Chen Date: Thu, 28 Mar 2024 17:32:34 +0800 Subject: [PATCH 12/21] test: update llama tests --- .../test_shardformer/test_model/test_shard_llama.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index be4164e62d70..cd4f5272cc7f 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -5,7 +5,7 @@ import colossalai from colossalai.logging import disable_existing_loggers -from colossalai.shardformer import AdvancedPipelineConfig +from colossalai.shardformer import PipelineGradientConfig from colossalai.shardformer.layer.utils import Randomizer from colossalai.tensor.d_tensor.api import clear_layout_converter from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn @@ -107,7 +107,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "precision": "fp16", "initial_scale": 1, "enable_gradient_checkpointing": True, - "advanced_pipeline_config": AdvancedPipelineConfig(gradient_checkpointing_ratio=0.5), + "pipeline_gradient_config": PipelineGradientConfig(gradient_checkpointing_ratio=0.5), }, { "tp_size": 1, @@ -116,8 +116,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "use_lazy_init": False, "precision": "fp32", "enable_gradient_checkpointing": True, - "advanced_pipeline_config": AdvancedPipelineConfig( - num_stages=2, num_model_chunks=1, num_model_layers=8, num_layers_per_stage=[5, 3] + "pipeline_gradient_config": PipelineGradientConfig( + num_stages=2, num_model_chunks=1, num_model_layers=8, num_ckpt_layers_per_stage=[4, 0] ), }, { @@ -201,12 +201,11 @@ def run_llama_test(test_config): "zero_stage": 1, "initial_scale": 1, "enable_gradient_checkpointing": True, - "advanced_pipeline_config": AdvancedPipelineConfig( + "pipeline_gradient_config": PipelineGradientConfig( num_stages=2, num_model_chunks=2, num_model_layers=8, - num_layers_per_stage=[3, 3, 1, 1], - num_ckpt_layers_per_stage=[0, 0, 1, 1], + num_ckpt_layers_per_stage=[0, 1, 2, 2], ), }, ], From b9de48e267b6f11b484894362e09fb5921c1a5d2 Mon Sep 17 00:00:00 2001 From: Wenhao Chen Date: Fri, 29 Mar 2024 10:55:08 +0800 Subject: [PATCH 13/21] feat: add `GradCkptConfig` and `GradCkptCollection` --- colossalai/shardformer/__init__.py | 2 +- colossalai/shardformer/shard/__init__.py | 5 +- .../shardformer/shard/grad_ckpt_config.py | 120 ++++++++++++++++++ colossalai/shardformer/shard/shard_config.py | 101 +-------------- 4 files changed, 129 insertions(+), 99 deletions(-) create mode 100644 colossalai/shardformer/shard/grad_ckpt_config.py diff --git a/colossalai/shardformer/__init__.py b/colossalai/shardformer/__init__.py index 9f49eabaf8f9..e2773b249175 100644 --- a/colossalai/shardformer/__init__.py +++ b/colossalai/shardformer/__init__.py @@ -1 +1 @@ -from .shard import PipelineGradientConfig, ShardConfig, ShardFormer +from .shard import GradCkptCollection, ModelSharder, PipelineGradCkptConfig, ShardConfig, ShardFormer diff --git a/colossalai/shardformer/shard/__init__.py b/colossalai/shardformer/shard/__init__.py index b3782152e460..547894327220 100644 --- a/colossalai/shardformer/shard/__init__.py +++ b/colossalai/shardformer/shard/__init__.py @@ -1,5 +1,6 @@ -from .shard_config import PipelineGradientConfig, ShardConfig +from .grad_ckpt_config import GradCkptCollection, PipelineGradCkptConfig +from .shard_config import ShardConfig from .sharder import ModelSharder from .shardformer import ShardFormer -__all__ = ["ShardConfig", "ModelSharder", "ShardFormer", "PipelineGradientConfig"] +__all__ = ["ShardConfig", "ModelSharder", "ShardFormer", "PipelineGradCkptConfig", "GradCkptCollection"] diff --git a/colossalai/shardformer/shard/grad_ckpt_config.py b/colossalai/shardformer/shard/grad_ckpt_config.py new file mode 100644 index 000000000000..dd809a2ae702 --- /dev/null +++ b/colossalai/shardformer/shard/grad_ckpt_config.py @@ -0,0 +1,120 @@ +from dataclasses import dataclass, field +from typing import List, Optional + + +@dataclass +class GradCkptConfig: + # TODO: for future use + _dummy_value: Optional[float] = None + + def __post_init__(self): + raise NotImplementedError() + + @property + def control_gradient_checkpointing(self) -> bool: + raise NotImplementedError() + + def get_num_ckpt_layers(self, *args, **kwargs) -> int: + raise NotImplementedError() + + +@dataclass +class GradCkptCollection: + gradient_ckpt_configs: List[GradCkptConfig] = field(default_factory=list) + + def __post_init__(self): + assert all([isinstance(config, GradCkptConfig) for config in self.gradient_ckpt_configs]) + + @property + def control_gradient_checkpointing(self) -> bool: + return any([config.control_gradient_checkpointing for config in self.gradient_ckpt_configs]) + + def get_num_ckpt_layers(self, *args, **kwargs) -> int: + for config in self.gradient_ckpt_configs: + if config.control_gradient_checkpointing: + return config.get_num_ckpt_layers(*args, **kwargs) + raise RuntimeError("No checkpointed layers information is provided") + + +@dataclass +class PipelineGradCkptConfig(GradCkptConfig): + r""" + The pipeline gradient config is designed to provide more flexibility for users to control gradient checkpoint in pipeline parallelism. + Combined with PipelineStageManager.set_distribution_config, user can fully control the distribution of layers and checkpointed layers in pipeline parallelism. + Refer to https://github.com/hpcaitech/ColossalAI/issues/5509 for more details. + + It provides the following features: + 1. `gradient_checkpointing_ratio`: This is used to control gradient checkpointing more precisely, e.g., set 50% of the layers to use gradient checkpointing. + 2. Customize # ckpt layers assigned to each stage. This takes precedence over `gradient_checkpointing_ratio`. + + """ + """ + Args: + gradient_checkpointing_ratio (Optional[float]): The ratio of gradient checkpointing. It can only be used in pipeline parallelism. Defaults to None. + num_stages (Optional[int]): Number of stages in the pipeline. Defaults to None. For sanity check. + num_model_chunks (Optional[int]): Number of model chunks (1F1B or Interleaved). Defaults to None. For sanity check. + num_model_layers (Optional[int]): Number of model layers. Defaults to None. For sanity check. + num_ckpt_layers_per_stage (Optional[List[int]]): Number of checkpointed layers for each stage. Defaults to None. + + Example 1: + num_stages = 8 + num_layers = 80 + num_model_chunks = 1 + num_layers_per_stage = [9, 9, 9, 10, 11, 10, 11, 11] + num_ckpt_layers_per_stage = [4, 4, 2, 2, 0, 0, 0, 0] + + Example 2: + num_stages = 4 + num_layers = 80 + num_model_chunks = 2 + num_layers_per_stage = [9, 9, 9, 10, 11, 10, 11, 11] + # device 0 holds num_layers_per_stage[0] and num_layers_per_stage[4] layers + ... + + """ + num_stages: Optional[int] = None + num_model_chunks: Optional[int] = None + num_model_layers: Optional[int] = None + num_ckpt_layers_per_stage: Optional[List[int]] = None + gradient_checkpointing_ratio: Optional[float] = None + + def __post_init__(self): + if self._enable_gradient_checkpointing_ratio: + if not (0 <= self.gradient_checkpointing_ratio <= 1): + raise ValueError("gradient_checkpointing_ratio should be in 0% to 100%") + + if self._enable_customized_ckpt_layers_per_stage: + assert ( + self.num_stages is not None and self.num_model_chunks is not None and self.num_model_layers is not None + ) + assert len(self.num_ckpt_layers_per_stage) == self.num_stages * self.num_model_chunks + assert all( + [0 <= num_ckpt_layers < self.num_model_layers for num_ckpt_layers in self.num_ckpt_layers_per_stage] + ) + self.gradient_checkpointing_ratio = sum(self.num_ckpt_layers_per_stage) / self.num_model_layers + + @property + def control_gradient_checkpointing(self) -> bool: + return self._enable_gradient_checkpointing_ratio or self._enable_customized_ckpt_layers_per_stage + + @property + def _enable_gradient_checkpointing_ratio(self) -> bool: + return self.gradient_checkpointing_ratio is not None + + @property + def _enable_customized_ckpt_layers_per_stage(self) -> bool: + return self.num_ckpt_layers_per_stage is not None + + def get_num_ckpt_layers(self, stage: int, num_layers: int, model_chunk_id: int = 0) -> int: + if not self.control_gradient_checkpointing: + raise RuntimeError("No checkpointed layers information is provided") + + if self._enable_customized_ckpt_layers_per_stage: + assert stage <= self.num_stages and model_chunk_id <= self.num_model_chunks + num_ckpt_layers = self.num_ckpt_layers_per_stage[stage + model_chunk_id * self.num_stages] + assert num_ckpt_layers <= num_layers + return num_ckpt_layers + elif self._enable_gradient_checkpointing_ratio: + return int(self.gradient_checkpointing_ratio * num_layers) + else: + raise NotImplementedError() diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index 4b96b7aa702c..c47a8aaa7c86 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -1,105 +1,14 @@ from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional +from typing import Any, Dict, Optional import torch.distributed as dist from torch.distributed import ProcessGroup from colossalai.pipeline.stage_manager import PipelineStageManager -__all__ = ["ShardConfig", "PipelineGradientConfig"] +from .grad_ckpt_config import GradCkptCollection - -class PipelineGradientConfig: - r""" - The pipeline gradient config is designed to provide more flexibility for users to control gradient checkpoint in pipeline parallelism. - Combined with PipelineStageManager.set_distribution_config, user can fully control the distribution of layers and checkpointed layers in pipeline parallelism. - Refer to https://github.com/hpcaitech/ColossalAI/issues/5509 for more details. - - It provides the following features: - 1. `gradient_checkpointing_ratio`: This is used to control gradient checkpointing more precisely, e.g., set 50% of the layers to use gradient checkpointing. - 2. Customize # ckpt layers assigned to each stage. This takes precedence over `gradient_checkpointing_ratio`. - - """ - - def __init__( - self, - gradient_checkpointing_ratio: Optional[float] = None, - num_stages: Optional[int] = None, - num_model_chunks: Optional[int] = None, - num_model_layers: Optional[int] = None, - num_ckpt_layers_per_stage: Optional[List[int]] = None, - ) -> None: - """ - Args: - gradient_checkpointing_ratio (Optional[float]): The ratio of gradient checkpointing. It can only be used in pipeline parallelism. Defaults to None. - num_stages (Optional[int]): Number of stages in the pipeline. Defaults to None. For sanity check. - num_model_chunks (Optional[int]): Number of model chunks (1F1B or Interleaved). Defaults to None. For sanity check. - num_model_layers (Optional[int]): Number of model layers. Defaults to None. For sanity check. - num_ckpt_layers_per_stage (Optional[List[int]]): Number of checkpointed layers for each stage. Defaults to None. - - Example 1: - num_stages = 8 - num_layers = 80 - num_model_chunks = 1 - num_layers_per_stage = [9, 9, 9, 10, 11, 10, 11, 11] - num_ckpt_layers_per_stage = [4, 4, 2, 2, 0, 0, 0, 0] - - Example 2: - num_stages = 4 - num_layers = 80 - num_model_chunks = 2 - num_layers_per_stage = [9, 9, 9, 10, 11, 10, 11, 11] - # device 0 holds num_layers_per_stage[0] and num_layers_per_stage[4] layers - ... - - """ - self.gradient_checkpointing_ratio = gradient_checkpointing_ratio - self.num_stages = num_stages - self.num_model_chunks = num_model_chunks - self.num_model_layers = num_model_layers - self.num_ckpt_layers_per_stage = num_ckpt_layers_per_stage - self._sanity_check() - - def _sanity_check(self): - if self._enable_gradient_checkpointing_ratio: - if not (0 <= self.gradient_checkpointing_ratio <= 1): - raise ValueError("gradient_checkpointing_ratio should be in 0% to 100%") - - if self._enable_customized_ckpt_layers_per_stage: - assert ( - self.num_stages is not None and self.num_model_chunks is not None and self.num_model_layers is not None - ) - assert len(self.num_ckpt_layers_per_stage) == self.num_stages * self.num_model_chunks - assert all( - [0 <= num_ckpt_layers < self.num_model_layers for num_ckpt_layers in self.num_ckpt_layers_per_stage] - ) - self.gradient_checkpointing_ratio = sum(self.num_ckpt_layers_per_stage) / self.num_model_layers - - @property - def control_gradient_checkpointing(self) -> bool: - return self._enable_gradient_checkpointing_ratio or self._enable_customized_ckpt_layers_per_stage - - @property - def _enable_gradient_checkpointing_ratio(self) -> bool: - return self.gradient_checkpointing_ratio is not None - - @property - def _enable_customized_ckpt_layers_per_stage(self) -> bool: - return self.num_ckpt_layers_per_stage is not None - - def get_num_ckpt_layers(self, stage: int, num_layers: int, model_chunk_id: int = 0) -> int: - if not self.control_gradient_checkpointing: - raise RuntimeError("No checkpointed layers information is provided") - - if self._enable_customized_ckpt_layers_per_stage: - assert stage <= self.num_stages and model_chunk_id <= self.num_model_chunks - num_ckpt_layers = self.num_ckpt_layers_per_stage[stage + model_chunk_id * self.num_stages] - assert num_ckpt_layers <= num_layers - return num_ckpt_layers - elif self._enable_gradient_checkpointing_ratio: - return int(self.gradient_checkpointing_ratio * num_layers) - else: - raise NotImplementedError() +__all__ = ["ShardConfig"] @dataclass @@ -116,7 +25,7 @@ class ShardConfig: enable_jit_fused (bool, optional): Whether to switch on JIT fused operators. Defaults to False. enable_sequence_parallelism (bool): Whether to turn on sequence parallelism, which partitions non-tensor-parallel regions along the sequence dimension. Defaults to False. enable_sequence_overlap (bool): Whether to turn on sequence overlap, which overlap the computation and communication in sequence parallelism. It can only be used when enable_sequence_parallelism is True. Defaults to False. - pipeline_gradient_config (Optional[PipelineGradientConfig]): The pipeline gradient config for more flexibility in pipeline parallelism. Defaults to None. + gradient_ckpt_collection (Optional[GradCkptCollection]): The gradient checkpointing configs. Defaults to None. enable_all_optimization (bool): Whether to turn on all optimization tools including 'fused normalization', 'flash attention', 'JIT fused operators', 'sequence parallelism' and 'sequence overlap'. Defaults to False. """ tensor_parallel_process_group: Optional[ProcessGroup] = None @@ -129,7 +38,7 @@ class ShardConfig: enable_sequence_parallelism: bool = False enable_sequence_overlap: bool = False parallel_output: bool = True - pipeline_gradient_config: Optional[PipelineGradientConfig] = None + gradient_ckpt_collection: Optional[GradCkptCollection] = None extra_kwargs: Dict[str, Any] = field(default_factory=dict) # TODO padding vocab # make_vocab_size_divisible_by: int = 128 From 29f1d5a426eeab225b51e994b185b60c7dfbc0a8 Mon Sep 17 00:00:00 2001 From: Wenhao Chen Date: Fri, 29 Mar 2024 11:11:07 +0800 Subject: [PATCH 14/21] fix: fix llama tests --- .../booster/plugin/hybrid_parallel_plugin.py | 8 +++--- colossalai/shardformer/modeling/llama.py | 14 +++++----- .../test_model/test_shard_llama.py | 26 ++++++++++++------- 3 files changed, 28 insertions(+), 20 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index fb9cf1a3e3cf..6781ea5458d2 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -26,7 +26,7 @@ from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.shardformer import PipelineGradientConfig, ShardConfig, ShardFormer +from colossalai.shardformer import GradCkptCollection, ShardConfig, ShardFormer from colossalai.shardformer.layer.utils import SeqParallelUtils from colossalai.shardformer.policies.base_policy import Policy from colossalai.tensor.d_tensor.api import is_distributed_tensor @@ -930,7 +930,7 @@ class HybridParallelPlugin(PipelinePluginBase): custom_policy (Policy, optional): Custom policy for Shardformer. Defaults to None. pp_style (str, optional): The style for pipeline parallelism. Defaults to '1f1b'. num_model_chunks (int, optional): The number of model chunks for interleaved pipeline parallelism. Defaults to 1. - pipeline_gradient_config (AdvancedPipelineConfig, optional): The configuration for pipeline parallelism. Defaults to None. + gradient_ckpt_collection (GradCkptCollection, optional): The configuration for gradient checkpointing. Defaults to None. enable_metadata_cache (bool, optional): Whether to enable metadata cache for pipeline parallelism. Defaults to True. """ @@ -970,7 +970,7 @@ def __init__( custom_policy: Policy = None, pp_style: str = "1f1b", num_model_chunks: int = 1, - pipeline_gradient_config: Optional[PipelineGradientConfig] = None, + gradient_ckpt_collection: Optional[GradCkptCollection] = None, enable_metadata_cache: bool = True, ) -> None: super().__init__() @@ -1045,7 +1045,7 @@ def __init__( enable_sequence_parallelism=enable_sequence_parallelism, enable_sequence_overlap=enable_sequence_overlap, parallel_output=parallel_output, - pipeline_gradient_config=pipeline_gradient_config, + gradient_ckpt_collection=gradient_ckpt_collection, ) self.amp_config = dict( initial_scale=initial_scale, diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index a55fa439c79f..b9db6b2129b8 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -141,13 +141,13 @@ def llama_model_forward( num_ckpt_layers = 0 if self.gradient_checkpointing and self.training: num_ckpt_layers = end_idx - start_idx - if shard_config.pipeline_gradient_config is not None: - pipeline_gradient_config = shard_config.pipeline_gradient_config - if pipeline_gradient_config.control_gradient_checkpointing: - num_ckpt_layers = pipeline_gradient_config.get_num_ckpt_layers( - stage_manager.stage, - end_idx - start_idx, - stage_manager.model_chunk_id if stage_manager.is_interleave else 0, + if shard_config.gradient_ckpt_collection is not None: + gradient_ckpt_collection = shard_config.gradient_ckpt_collection + if gradient_ckpt_collection.control_gradient_checkpointing: + num_ckpt_layers = gradient_ckpt_collection.get_num_ckpt_layers( + stage=stage_manager.stage, + num_layers=end_idx - start_idx, + model_chunk_id=stage_manager.model_chunk_id if stage_manager.is_interleave else 0, ) assert num_ckpt_layers <= end_idx - start_idx diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index cd4f5272cc7f..53525e48bcbc 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -5,7 +5,7 @@ import colossalai from colossalai.logging import disable_existing_loggers -from colossalai.shardformer import PipelineGradientConfig +from colossalai.shardformer import GradCkptCollection, PipelineGradCkptConfig from colossalai.shardformer.layer.utils import Randomizer from colossalai.tensor.d_tensor.api import clear_layout_converter from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn @@ -107,7 +107,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "precision": "fp16", "initial_scale": 1, "enable_gradient_checkpointing": True, - "pipeline_gradient_config": PipelineGradientConfig(gradient_checkpointing_ratio=0.5), + "gradient_ckpt_collection": GradCkptCollection([PipelineGradCkptConfig(gradient_checkpointing_ratio=0.5)]), }, { "tp_size": 1, @@ -116,8 +116,12 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "use_lazy_init": False, "precision": "fp32", "enable_gradient_checkpointing": True, - "pipeline_gradient_config": PipelineGradientConfig( - num_stages=2, num_model_chunks=1, num_model_layers=8, num_ckpt_layers_per_stage=[4, 0] + "gradient_ckpt_collection": GradCkptCollection( + [ + PipelineGradCkptConfig( + num_stages=2, num_model_chunks=1, num_model_layers=8, num_ckpt_layers_per_stage=[4, 0] + ) + ] ), }, { @@ -201,11 +205,15 @@ def run_llama_test(test_config): "zero_stage": 1, "initial_scale": 1, "enable_gradient_checkpointing": True, - "pipeline_gradient_config": PipelineGradientConfig( - num_stages=2, - num_model_chunks=2, - num_model_layers=8, - num_ckpt_layers_per_stage=[0, 1, 2, 2], + "gradient_ckpt_collection": GradCkptCollection( + [ + PipelineGradCkptConfig( + num_stages=2, + num_model_chunks=2, + num_model_layers=8, + num_ckpt_layers_per_stage=[0, 1, 2, 2], + ) + ] ), }, ], From 66673f577c239b9f4b0504f2f8e31ad1b27cc74d Mon Sep 17 00:00:00 2001 From: Wenhao Chen Date: Fri, 29 Mar 2024 13:43:46 +0800 Subject: [PATCH 15/21] fix: changed API calls --- .../colossal_moe/models/mixtral_policy.py | 8 +- colossalai/inference/engine/policies/bloom.py | 4 +- .../inference/engine/policies/chatglm2.py | 4 +- colossalai/inference/engine/policies/llama.py | 4 +- colossalai/shardformer/policies/bert.py | 32 ++----- colossalai/shardformer/policies/bloom.py | 8 +- colossalai/shardformer/policies/chatglm2.py | 8 +- colossalai/shardformer/policies/falcon.py | 8 +- colossalai/shardformer/policies/gpt2.py | 30 ++----- colossalai/shardformer/policies/gptj.py | 8 +- colossalai/shardformer/policies/opt.py | 8 +- colossalai/shardformer/policies/vit.py | 8 +- .../language/openmoe/model/openmoe_policy.py | 85 ++++++++++--------- 13 files changed, 95 insertions(+), 120 deletions(-) diff --git a/applications/ColossalMoE/colossal_moe/models/mixtral_policy.py b/applications/ColossalMoE/colossal_moe/models/mixtral_policy.py index 23ffbf5d317c..c01e02c49a60 100644 --- a/applications/ColossalMoE/colossal_moe/models/mixtral_policy.py +++ b/applications/ColossalMoE/colossal_moe/models/mixtral_policy.py @@ -109,8 +109,8 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli else: module = self.model.model - layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages) - stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage) + layers_per_stage = stage_manager.distribute_layers(len(module.layers)) + stage_index = stage_manager.get_stage_index(layers_per_stage) method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)} self.append_or_create_method_replacement( description=method_replacement, policy=policy, target_key=model_cls @@ -129,10 +129,10 @@ def get_held_layers(self) -> List[Module]: stage_manager = self.pipeline_stage_manager held_layers = [] - layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages) + layers_per_stage = stage_manager.distribute_layers(len(module.layers)) if stage_manager.is_first_stage(): held_layers.append(module.embed_tokens) - start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage) held_layers.extend(module.layers[start_idx:end_idx]) if stage_manager.is_last_stage(): held_layers.append(module.norm) diff --git a/colossalai/inference/engine/policies/bloom.py b/colossalai/inference/engine/policies/bloom.py index f35b50189e82..5bc47c3c1a49 100644 --- a/colossalai/inference/engine/policies/bloom.py +++ b/colossalai/inference/engine/policies/bloom.py @@ -114,12 +114,12 @@ def get_held_layers(self) -> List[Module]: stage_manager = self.pipeline_stage_manager held_layers = [] - layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages) + layers_per_stage = stage_manager.distribute_layers(len(module.h)) if stage_manager.is_first_stage(): held_layers.append(module.word_embeddings) held_layers.append(module.word_embeddings_layernorm) held_layers.append(self.model.lm_head) - start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage) held_layers.extend(module.h[start_idx:end_idx]) if stage_manager.is_last_stage(): held_layers.append(module.ln_f) diff --git a/colossalai/inference/engine/policies/chatglm2.py b/colossalai/inference/engine/policies/chatglm2.py index 3e1d94f4785c..c7c6f3b927e1 100644 --- a/colossalai/inference/engine/policies/chatglm2.py +++ b/colossalai/inference/engine/policies/chatglm2.py @@ -69,11 +69,11 @@ def get_held_layers(self) -> List[nn.Module]: stage_manager = self.pipeline_stage_manager held_layers = [] - layers_per_stage = self.distribute_layers(module.num_layers, stage_manager.num_stages) + layers_per_stage = stage_manager.distribute_layers(module.num_layers) if stage_manager.is_first_stage(): held_layers.append(module.embedding) held_layers.append(module.output_layer) - start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage) held_layers.extend(module.encoder.layers[start_idx:end_idx]) if stage_manager.is_last_stage(): if module.encoder.post_layer_norm: diff --git a/colossalai/inference/engine/policies/llama.py b/colossalai/inference/engine/policies/llama.py index 11517d7e8a13..a57a4e50cdb9 100644 --- a/colossalai/inference/engine/policies/llama.py +++ b/colossalai/inference/engine/policies/llama.py @@ -194,11 +194,11 @@ def get_held_layers(self) -> List[Module]: stage_manager = self.pipeline_stage_manager held_layers = [] - layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages) + layers_per_stage = stage_manager.distribute_layers(len(module.layers)) if stage_manager.is_first_stage(): held_layers.append(module.embed_tokens) held_layers.append(self.model.lm_head) - start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage) held_layers.extend(module.layers[start_idx:end_idx]) if stage_manager.is_last_stage(): held_layers.append(module.norm) diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index 4d50a3c9920c..cd7bdcdd6fb4 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -279,16 +279,8 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli module = self.model.bert if stage_manager.is_interleave: - layers_per_stage = self.distribute_layers( - len(module.encoder.layer), - stage_manager.num_stages * stage_manager.num_model_chunks, - ) - stage_manager.stage_indices = self.get_stage_index( - layers_per_stage, - stage_manager.stage, - num_model_chunks=stage_manager.num_model_chunks, - num_stages=stage_manager.num_stages, - ) + layers_per_stage = stage_manager.distribute_layers(len(module.encoder.layer)) + stage_manager.stage_indices = stage_manager.get_stage_index(layers_per_stage) method_replacement = { "forward": partial( new_forward, @@ -298,8 +290,8 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli } else: - layers_per_stage = self.distribute_layers(len(module.encoder.layer), stage_manager.num_stages) - stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage) + layers_per_stage = stage_manager.distribute_layers(len(module.encoder.layer)) + stage_index = stage_manager.get_stage_index(layers_per_stage) method_replacement = { "forward": partial( new_forward, @@ -324,16 +316,8 @@ def get_held_layers(self) -> List[Module]: held_layers = [] if stage_manager.is_interleave: assert stage_manager.num_model_chunks is not None - layers_per_stage = self.distribute_layers( - len(module.encoder.layer), - stage_manager.num_stages * stage_manager.num_model_chunks, - ) - stage_indices = self.get_stage_index( - layers_per_stage, - stage_manager.stage, - num_model_chunks=stage_manager.num_model_chunks, - num_stages=stage_manager.num_stages, - ) + layers_per_stage = stage_manager.distribute_layers(len(module.encoder.layer)) + stage_indices = stage_manager.get_stage_index(layers_per_stage) if stage_manager.is_first_stage(ignore_chunk=True): held_layers.append(module.embeddings) for start_idx, end_idx in stage_indices: @@ -342,10 +326,10 @@ def get_held_layers(self) -> List[Module]: held_layers.append(module.pooler) else: - layers_per_stage = self.distribute_layers(len(module.encoder.layer), stage_manager.num_stages) + layers_per_stage = stage_manager.distribute_layers(len(module.encoder.layer)) if stage_manager.is_first_stage(): held_layers.append(module.embeddings) - start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage) held_layers.extend(module.encoder.layer[start_idx:end_idx]) if stage_manager.is_last_stage(): held_layers.append(module.pooler) diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py index e4714c8c1b15..55b69d5f0d29 100644 --- a/colossalai/shardformer/policies/bloom.py +++ b/colossalai/shardformer/policies/bloom.py @@ -203,8 +203,8 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli else: module = self.model.transformer - layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages) - stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage) + layers_per_stage = stage_manager.distribute_layers(len(module.h)) + stage_index = stage_manager.get_stage_index(layers_per_stage) method_replacement = { "forward": partial( new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config @@ -226,11 +226,11 @@ def get_held_layers(self) -> List[Module]: stage_manager = self.pipeline_stage_manager held_layers = [] - layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages) + layers_per_stage = stage_manager.distribute_layers(len(module.h)) if stage_manager.is_first_stage(): held_layers.append(module.word_embeddings) held_layers.append(module.word_embeddings_layernorm) - start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage) held_layers.extend(module.h[start_idx:end_idx]) if stage_manager.is_last_stage(): held_layers.append(module.ln_f) diff --git a/colossalai/shardformer/policies/chatglm2.py b/colossalai/shardformer/policies/chatglm2.py index cbe6254d1561..0830d85f1073 100644 --- a/colossalai/shardformer/policies/chatglm2.py +++ b/colossalai/shardformer/policies/chatglm2.py @@ -179,10 +179,10 @@ def get_held_layers(self) -> List[nn.Module]: stage_manager = self.pipeline_stage_manager held_layers = [] - layers_per_stage = self.distribute_layers(module.num_layers, stage_manager.num_stages) + layers_per_stage = stage_manager.distribute_layers(module.num_layers) if stage_manager.is_first_stage(): held_layers.append(module.embedding) - start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage) held_layers.extend(module.encoder.layers[start_idx:end_idx]) if stage_manager.is_last_stage(): if module.encoder.post_layer_norm: @@ -204,8 +204,8 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli else: module = self.model.transformer - layers_per_stage = self.distribute_layers(module.num_layers, stage_manager.num_stages) - stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage) + layers_per_stage = stage_manager.distribute_layers(module.num_layers) + stage_index = stage_manager.get_stage_index(layers_per_stage) method_replacement = { "forward": partial( new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config diff --git a/colossalai/shardformer/policies/falcon.py b/colossalai/shardformer/policies/falcon.py index 16bbc3f23f81..fe61c406fae3 100644 --- a/colossalai/shardformer/policies/falcon.py +++ b/colossalai/shardformer/policies/falcon.py @@ -161,8 +161,8 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli else: module = self.model.transformer - layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages) - stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage) + layers_per_stage = stage_manager.distribute_layers(len(module.h)) + stage_index = stage_manager.get_stage_index(layers_per_stage) method_replacement = { "forward": partial( new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config @@ -181,10 +181,10 @@ def get_held_layers(self) -> List[Module]: module = self.model.transformer stage_manager = self.pipeline_stage_manager held_layers = [] - layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages) + layers_per_stage = stage_manager.distribute_layers(len(module.h)) if stage_manager.is_first_stage(): held_layers.append(module.word_embeddings) - start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage) held_layers.extend(module.h[start_idx:end_idx]) if stage_manager.is_last_stage(): held_layers.append(module.ln_f) diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index d1a8c9dce2c7..4bcac3951a6b 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -185,15 +185,8 @@ def get_held_layers(self) -> List[nn.Module]: held_layers = [] if stage_manager.is_interleave: assert stage_manager.num_model_chunks is not None - layers_per_stage = self.distribute_layers( - len(module.h), stage_manager.num_stages * stage_manager.num_model_chunks - ) - stage_indices = self.get_stage_index( - layers_per_stage, - stage_manager.stage, - num_model_chunks=stage_manager.num_model_chunks, - num_stages=stage_manager.num_stages, - ) + layers_per_stage = stage_manager.distribute_layers(len(module.h)) + stage_indices = stage_manager.get_stage_index(layers_per_stage) if stage_manager.is_first_stage(ignore_chunk=True): held_layers.append(module.wte) held_layers.append(module.wpe) @@ -203,12 +196,12 @@ def get_held_layers(self) -> List[nn.Module]: if stage_manager.is_last_stage(ignore_chunk=True): held_layers.append(module.ln_f) else: - layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages) + layers_per_stage = stage_manager.distribute_layers(len(module.h)) if stage_manager.is_first_stage(): held_layers.append(module.wte) held_layers.append(module.wpe) held_layers.append(module.drop) - start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage) held_layers.extend(module.h[start_idx:end_idx]) if stage_manager.is_last_stage(): held_layers.append(module.ln_f) @@ -226,15 +219,8 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli module = self.model.transformer if stage_manager.is_interleave: - layers_per_stage = self.distribute_layers( - len(module.h), stage_manager.num_stages * stage_manager.num_model_chunks - ) - stage_manager.stage_indices = self.get_stage_index( - layers_per_stage, - stage_manager.stage, - num_model_chunks=stage_manager.num_model_chunks, - num_stages=stage_manager.num_stages, - ) + layers_per_stage = stage_manager.distribute_layers(len(module.h)) + stage_manager.stage_indices = stage_manager.get_stage_index(layers_per_stage) method_replacement = { "forward": partial( new_forward, @@ -243,8 +229,8 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli ) } else: - layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages) - stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage) + layers_per_stage = stage_manager.distribute_layers(len(module.h)) + stage_index = stage_manager.get_stage_index(layers_per_stage) method_replacement = { "forward": partial( new_forward, diff --git a/colossalai/shardformer/policies/gptj.py b/colossalai/shardformer/policies/gptj.py index b24443298e07..eab4c214a41f 100644 --- a/colossalai/shardformer/policies/gptj.py +++ b/colossalai/shardformer/policies/gptj.py @@ -179,11 +179,11 @@ def get_held_layers(self) -> List[nn.Module]: stage_manager = self.pipeline_stage_manager held_layers = [] - layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages) + layers_per_stage = stage_manager.distribute_layers(len(module.h)) if stage_manager.is_first_stage(): held_layers.append(module.wte) held_layers.append(module.drop) - start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage) held_layers.extend(module.h[start_idx:end_idx]) if stage_manager.is_last_stage(): held_layers.append(module.ln_f) @@ -200,8 +200,8 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli else: module = self.model.transformer - layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages) - stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage) + layers_per_stage = stage_manager.distribute_layers(len(module.h)) + stage_index = stage_manager.get_stage_index(layers_per_stage) method_replacement = { "forward": partial( new_forward, diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py index 683f3a9d5a2d..98e584be861b 100644 --- a/colossalai/shardformer/policies/opt.py +++ b/colossalai/shardformer/policies/opt.py @@ -186,12 +186,12 @@ def get_held_layers(self) -> List[nn.Module]: stage_manager = self.pipeline_stage_manager held_layers = [] - layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages) + layers_per_stage = stage_manager.distribute_layers(len(module.layers)) if stage_manager.is_first_stage(): held_layers.append(module.embed_tokens) held_layers.append(module.embed_positions) held_layers.append(module.project_in) - start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage) held_layers.extend(module.layers[start_idx:end_idx]) if stage_manager.is_last_stage(): held_layers.append(module.final_layer_norm) @@ -208,8 +208,8 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli else: module = self.model.model.decoder - layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages) - stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage) + layers_per_stage = stage_manager.distribute_layers(len(module.layers)) + stage_index = stage_manager.get_stage_index(layers_per_stage) method_replacement = { "forward": partial( new_forward, diff --git a/colossalai/shardformer/policies/vit.py b/colossalai/shardformer/policies/vit.py index b0f224e22dc9..905398c4d51e 100644 --- a/colossalai/shardformer/policies/vit.py +++ b/colossalai/shardformer/policies/vit.py @@ -134,10 +134,10 @@ def get_held_layers(self) -> List[nn.Module]: stage_manager = self.pipeline_stage_manager held_layers = [] - layers_per_stage = self.distribute_layers(len(module.encoder.layer), stage_manager.num_stages) + layers_per_stage = stage_manager.distribute_layers(len(module.encoder.layer)) if stage_manager.is_first_stage(): held_layers.append(module.embeddings) - start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage) held_layers.extend(module.encoder.layer[start_idx:end_idx]) return held_layers @@ -149,8 +149,8 @@ def set_pipeline_forward(self, model_cls: nn.Module, pipeline_forward: Callable, else: module = self.model.vit - layers_per_stage = self.distribute_layers(len(module.encoder.layer), stage_manager.num_stages) - stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage) + layers_per_stage = stage_manager.distribute_layers(len(module.encoder.layer)) + stage_index = stage_manager.get_stage_index(layers_per_stage) method_replacement = {"forward": pipeline_forward(stage_manager=stage_manager, stage_index=stage_index)} self.append_or_create_method_replacement( description=method_replacement, policy=policy, target_key=model_cls diff --git a/examples/language/openmoe/model/openmoe_policy.py b/examples/language/openmoe/model/openmoe_policy.py index 66a42e0176e9..8ef07bdb91b5 100644 --- a/examples/language/openmoe/model/openmoe_policy.py +++ b/examples/language/openmoe/model/openmoe_policy.py @@ -1,4 +1,3 @@ -import warnings from functools import partial from typing import Callable, Dict, List, Optional, Union @@ -21,7 +20,6 @@ class OpenMoePolicy(Policy): - def config_sanity_check(self): pass @@ -43,7 +41,8 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: if self.shard_config.enable_sequence_parallelism: self.shard_config.enable_sequence_parallelism = False raise NotImplementedError( - "openmoe doesn't support sequence parallelism now, will ignore the sequence parallelism flag.") + "openmoe doesn't support sequence parallelism now, will ignore the sequence parallelism flag." + ) if self.shard_config.enable_tensor_parallelism: raise NotImplementedError("Tensor parallelism is not supported for openmoe model now.") @@ -97,8 +96,8 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli else: module = self.model.model - layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages) - stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage) + layers_per_stage = stage_manager.distribute_layers(len(module.layers)) + stage_index = stage_manager.get_stage_index(layers_per_stage) method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)} self.append_or_create_method_replacement( description=method_replacement, policy=policy, target_key=model_cls @@ -117,10 +116,10 @@ def get_held_layers(self) -> List[Module]: stage_manager = self.pipeline_stage_manager held_layers = [] - layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages) + layers_per_stage = stage_manager.distribute_layers(len(module.layers)) if stage_manager.is_first_stage(): held_layers.append(module.embed_tokens) - start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage) held_layers.extend(module.layers[start_idx:end_idx]) if stage_manager.is_last_stage(): held_layers.append(module.norm) @@ -143,7 +142,6 @@ def distribute_layers(self, num_layers: int, num_stages: int) -> List[int]: class OpenMoeModelPolicy(OpenMoePolicy): - def __init__(self) -> None: super().__init__() @@ -169,21 +167,21 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: class OpenMoeForCausalLMPolicy(OpenMoePolicy): - def module_policy(self): policy = super().module_policy() if self.shard_config.enable_tensor_parallelism: # add a new item for casual lm new_item = { - OpenMoeForCausalLM: - ModulePolicyDescription(sub_module_replacement=[ + OpenMoeForCausalLM: ModulePolicyDescription( + sub_module_replacement=[ SubModuleReplacementDescription( suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True), ) - ]) + ] + ) } policy.update(new_item) @@ -208,13 +206,17 @@ def get_held_layers(self) -> List[Module]: def get_shared_params(self) -> List[Dict[int, Tensor]]: llama_model = self.model.model if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1: - if (id(llama_model.embed_tokens.weight) == id(self.model.lm_head.weight) - and self.pipeline_stage_manager.num_stages > 1): + if ( + id(llama_model.embed_tokens.weight) == id(self.model.lm_head.weight) + and self.pipeline_stage_manager.num_stages > 1 + ): # tie weights - return [{ - 0: llama_model.embed_tokens.weight, - self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight, - }] + return [ + { + 0: llama_model.embed_tokens.weight, + self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight, + } + ] return [] @@ -247,12 +249,13 @@ def openmoe_model_forward( logger = logging.get_logger(__name__) - output_attentions = (output_attentions if output_attentions is not None else self.config.output_attentions) - output_hidden_states = (output_hidden_states - if output_hidden_states is not None else self.config.output_hidden_states) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = (return_dict if return_dict is not None else self.config.use_return_dict) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict # retrieve input_ids and inputs_embeds if stage_manager.is_first_stage(): @@ -320,7 +323,8 @@ def openmoe_model_forward( if self.gradient_checkpointing and self.training: if use_cache: logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) use_cache = False # decoder layers @@ -333,12 +337,11 @@ def openmoe_model_forward( if output_hidden_states: all_hidden_states += (hidden_states,) - past_key_value = (past_key_values[idx] if past_key_values is not None else None) + past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: def create_custom_forward(module): - def custom_forward(*inputs): # None for past_key_value return module(*inputs, output_attentions, None) @@ -384,14 +387,16 @@ def custom_forward(*inputs): router_z_loss = past_router_z_loss + router_z_loss if stage_manager.is_last_stage(): - return tuple([ - hidden_states, - next_cache, - all_hidden_states, - all_self_attns, - router_aux_loss, - router_z_loss, - ]) + return tuple( + [ + hidden_states, + next_cache, + all_hidden_states, + all_self_attns, + router_aux_loss, + router_z_loss, + ] + ) # always return dict for imediate stage return { "hidden_states": hidden_states, @@ -445,10 +450,11 @@ def llama_for_causal_lm_forward( "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you." ```""" logger = logging.get_logger(__name__) - output_attentions = (output_attentions if output_attentions is not None else self.config.output_attentions) - output_hidden_states = (output_hidden_states - if output_hidden_states is not None else self.config.output_hidden_states) - return_dict = (return_dict if return_dict is not None else self.config.use_return_dict) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future. if output_attentions: @@ -504,7 +510,6 @@ def llama_for_causal_lm_forward( if chunk_head == True: def create_custom_forward(module): - def custom_forward(*inputs): logits = module(inputs[0]) logits = logits.float() @@ -522,8 +527,8 @@ def custom_forward(*inputs): for batch_idx in range(hidden_states.shape[0]): loss = loss + torch.utils.checkpoint.checkpoint( create_custom_forward(self.lm_head), - hidden_states[batch_idx:batch_idx + 1, :], - labels[batch_idx:batch_idx + 1, :], + hidden_states[batch_idx : batch_idx + 1, :], + labels[batch_idx : batch_idx + 1, :], ) logits = None else: From e14c496020bf04c804cc3026af3cd24d92c1b94f Mon Sep 17 00:00:00 2001 From: Wenhao Chen Date: Fri, 29 Mar 2024 14:46:03 +0800 Subject: [PATCH 16/21] fix: add optional args for `distribute_layer` and `get_stage_index` --- colossalai/pipeline/stage_manager.py | 34 ++++++++++++++++------ colossalai/shardformer/policies/t5.py | 19 ++++++++---- colossalai/shardformer/policies/whisper.py | 17 +++++++---- 3 files changed, 49 insertions(+), 21 deletions(-) diff --git a/colossalai/pipeline/stage_manager.py b/colossalai/pipeline/stage_manager.py index 77bf6002294f..b0556669b2bc 100644 --- a/colossalai/pipeline/stage_manager.py +++ b/colossalai/pipeline/stage_manager.py @@ -90,23 +90,29 @@ def set_distribution_config(self, num_model_layers: int, num_layers_per_stage: L self.num_model_layers = num_model_layers self.num_layers_per_stage = num_layers_per_stage - def distribute_layers(self, num_layers: int) -> List[int]: + def distribute_layers( + self, num_layers: int, num_stages: Optional[int] = None, num_model_chunks: Optional[int] = None + ) -> List[int]: """Divide layers into stages""" + num_stages = self.num_stages if num_stages is None else num_stages + num_model_chunks = ( + (self.num_model_chunks if self.is_interleave else 1) if num_model_chunks is None else num_model_chunks + ) + if self.control_distribute_layers: assert num_layers == self.num_model_layers return self.num_layers_per_stage else: - num_model_chunk = self.num_model_chunks if self.is_interleave else 1 - quotient = num_layers // (self.num_stages * num_model_chunk) - remainder = num_layers % (self.num_stages * num_model_chunk) + quotient = num_layers // (num_stages * num_model_chunks) + remainder = num_layers % (num_stages * num_model_chunks) # calculate the num_layers per stage - layers_per_stage = [quotient] * self.num_stages * num_model_chunk + layers_per_stage = [quotient] * num_stages * num_model_chunks # deal with the rest layers if remainder > 0: - start_position = (self.num_stages * num_model_chunk) // 2 - remainder // 2 + start_position = (num_stages * num_model_chunks) // 2 - remainder // 2 for i in range(start_position, start_position + remainder): layers_per_stage[i] += 1 return layers_per_stage @@ -114,6 +120,9 @@ def distribute_layers(self, num_layers: int) -> List[int]: def get_stage_index( self, layers_per_stage: List[int], + stage: Optional[int] = None, + num_model_chunks: Optional[int] = None, + num_stages: Optional[int] = None, ) -> Union[Tuple[int, int], List[Tuple[int, int]]]: """ Get the start index and end index of layers for each stage. @@ -121,19 +130,26 @@ def get_stage_index( Args: layers_per_stage (List[int]): number of layers for each stage stage (int): the stage index + num_stages (int): number of stages + num_model_chunks (int): number of model chunks Returns: - Tuple[int, int]: the start index and end index of this stage - List[Tuple[int, int]]: the start index and end index of this stage for each model chunk """ + stage = self.stage if stage is None else stage + num_model_chunks = ( + (self.num_model_chunks if self.is_interleave else 1) if num_model_chunks is None else num_model_chunks + ) + num_stages = self.num_stages if num_stages is None else num_stages + num_layers_per_stage_accumulated = np.insert(np.cumsum(layers_per_stage), 0, 0) stage_indices = [] - num_model_chunks = self.num_model_chunks if self.is_interleave else 1 for model_chunk in range(num_model_chunks): - start_idx = num_layers_per_stage_accumulated[self.stage + model_chunk * self.num_stages] - end_idx = num_layers_per_stage_accumulated[self.stage + model_chunk * self.num_stages + 1] + start_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages] + end_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages + 1] stage_indices.append([start_idx, end_idx]) return stage_indices[0] if num_model_chunks == 1 else stage_indices diff --git a/colossalai/shardformer/policies/t5.py b/colossalai/shardformer/policies/t5.py index f5f701dc0972..0c8ec15fa0a9 100644 --- a/colossalai/shardformer/policies/t5.py +++ b/colossalai/shardformer/policies/t5.py @@ -251,6 +251,8 @@ def distribute_t5_layers( Return the layer distribution as a list and the starting stage of decoder. If decoder doesn't exist, returned decoder starting stage is set to num_encoder_layers. """ + stage_manager = self.pipeline_stage_manager + assert stage_manager is not None, "Pipeline stage manager is not set." # number of encoder layers must be a positive integer if num_encoder_layers <= 0: @@ -262,7 +264,7 @@ def distribute_t5_layers( # in the case of T5EncoderModel, set decoder starting stage to num_stages since it doesn't exist if num_decoder_layers == 0: - return self.distribute_layers(num_encoder_layers, num_stages), num_stages + return stage_manager.distribute_layers(num_encoder_layers, num_stages), num_stages # the number of stages distributed between encoder and decoder is optimized in this way: # num_encoder_stages = argmin(abs(num_encoder_layers / encoder_stages - num_decoder_layers / decoder_stages)) @@ -273,21 +275,26 @@ def objective(num_encoder_stages): num_encoder_stages = np.argmin([objective(i) for i in range(1, num_stages)]) + 1 num_decoder_stages = num_stages - num_encoder_stages - encoder_distribution = self.distribute_layers(num_encoder_layers, num_encoder_stages) - decoder_distribution = self.distribute_layers(num_decoder_layers, num_decoder_stages) + encoder_distribution = stage_manager.distribute_layers(num_encoder_layers, num_encoder_stages) + decoder_distribution = stage_manager.distribute_layers(num_decoder_layers, num_decoder_stages) return encoder_distribution + decoder_distribution, num_encoder_stages def get_t5_stage_index( self, layers_per_stage: List[int], stage: int, decoder_starting_stage: int - ) -> Tuple[bool, int, int]: + ) -> Tuple[int, int]: """ Input the distribution of layers among stages, the current stage and the first stage of decoder. Return the starting/ending idx of layers in encoder/decoder """ + stage_manager = self.pipeline_stage_manager + assert stage_manager is not None, "Pipeline stage manager is not set." + if stage < decoder_starting_stage: - return self.get_stage_index(layers_per_stage[:decoder_starting_stage], stage) + return stage_manager.get_stage_index(layers_per_stage[:decoder_starting_stage], stage) else: - return self.get_stage_index(layers_per_stage[decoder_starting_stage:], stage - decoder_starting_stage) + return stage_manager.get_stage_index( + layers_per_stage[decoder_starting_stage:], stage - decoder_starting_stage + ) def get_held_layers(self) -> List[nn.Module]: """Get pipeline layers for current stage.""" diff --git a/colossalai/shardformer/policies/whisper.py b/colossalai/shardformer/policies/whisper.py index 480a4beea581..c63f6d1cc549 100644 --- a/colossalai/shardformer/policies/whisper.py +++ b/colossalai/shardformer/policies/whisper.py @@ -300,6 +300,8 @@ def distribute_whisper_layers( Return the layer distribution as a list and the starting stage of decoder. If decoder doesn't exist, returned decoder starting stage is set to num_encoder_layers. """ + stage_manager = self.pipeline_stage_manager + assert stage_manager is not None, "pipeline_stage_manager is None" # number of encoder layers must be a positive integer if num_encoder_layers <= 0: @@ -311,7 +313,7 @@ def distribute_whisper_layers( # in the case of whisperEncoderModel, set decoder starting stage to num_stages since it doesn't exist if num_decoder_layers == 0: - return self.distribute_layers(num_encoder_layers, num_stages), num_stages + return stage_manager.distribute_layers(num_encoder_layers, num_stages), num_stages # the number of stages distributed between encoder and decoder is optimized in this way: # num_encoder_stages = argmin(abs(num_encoder_layers / encoder_stages - num_decoder_layers / decoder_stages)) @@ -322,21 +324,24 @@ def objective(num_encoder_stages): num_encoder_stages = np.argmin([objective(i) for i in range(1, num_stages)]) + 1 num_decoder_stages = num_stages - num_encoder_stages - encoder_distribution = self.distribute_layers(num_encoder_layers, num_encoder_stages) - decoder_distribution = self.distribute_layers(num_decoder_layers, num_decoder_stages) + encoder_distribution = stage_manager.distribute_layers(num_encoder_layers, num_encoder_stages) + decoder_distribution = stage_manager.distribute_layers(num_decoder_layers, num_decoder_stages) return encoder_distribution + decoder_distribution, num_encoder_stages def get_whisper_stage_index( self, layers_per_stage: List[int], stage: int, decoder_starting_stage: int - ) -> Tuple[bool, int, int]: + ) -> Tuple[int, int]: """ Input the distribution of layers among stages, the current stage and the first stage of decoder. Return the starting/ending idx of layers in encoder/decoder """ + stage_manager = self.pipeline_stage_manager + assert stage_manager is not None, "pipeline_stage_manager is None" + if stage < decoder_starting_stage: - return self.get_stage_index(layers_per_stage[:decoder_starting_stage], stage) + return stage_manager.get_stage_index(layers_per_stage[:decoder_starting_stage], stage) else: - return self.get_stage_index( + return stage_manager.get_stage_index( layers_per_stage[decoder_starting_stage:], stage - decoder_starting_stage, ) From 779c97e3ca780f9efa957dbaa00be5275a5224f3 Mon Sep 17 00:00:00 2001 From: Wenhao Chen Date: Fri, 29 Mar 2024 16:06:39 +0800 Subject: [PATCH 17/21] fix: remove `GradCkptCollection` --- .../booster/plugin/hybrid_parallel_plugin.py | 8 ++--- colossalai/shardformer/__init__.py | 2 +- colossalai/shardformer/shard/__init__.py | 4 +-- .../shardformer/shard/grad_ckpt_config.py | 29 +++---------------- colossalai/shardformer/shard/shard_config.py | 6 ++-- .../test_model/test_shard_llama.py | 26 ++++++----------- 6 files changed, 23 insertions(+), 52 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 6781ea5458d2..eba7d1c1f8b8 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -26,7 +26,7 @@ from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.shardformer import GradCkptCollection, ShardConfig, ShardFormer +from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer from colossalai.shardformer.layer.utils import SeqParallelUtils from colossalai.shardformer.policies.base_policy import Policy from colossalai.tensor.d_tensor.api import is_distributed_tensor @@ -930,7 +930,7 @@ class HybridParallelPlugin(PipelinePluginBase): custom_policy (Policy, optional): Custom policy for Shardformer. Defaults to None. pp_style (str, optional): The style for pipeline parallelism. Defaults to '1f1b'. num_model_chunks (int, optional): The number of model chunks for interleaved pipeline parallelism. Defaults to 1. - gradient_ckpt_collection (GradCkptCollection, optional): The configuration for gradient checkpointing. Defaults to None. + gradient_checkpoint_config (GradientCheckpointConfig, optional): Configuration for gradient checkpointing. Defaults to None. enable_metadata_cache (bool, optional): Whether to enable metadata cache for pipeline parallelism. Defaults to True. """ @@ -970,7 +970,7 @@ def __init__( custom_policy: Policy = None, pp_style: str = "1f1b", num_model_chunks: int = 1, - gradient_ckpt_collection: Optional[GradCkptCollection] = None, + gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None, enable_metadata_cache: bool = True, ) -> None: super().__init__() @@ -1045,7 +1045,7 @@ def __init__( enable_sequence_parallelism=enable_sequence_parallelism, enable_sequence_overlap=enable_sequence_overlap, parallel_output=parallel_output, - gradient_ckpt_collection=gradient_ckpt_collection, + gradient_checkpoint_config=gradient_checkpoint_config, ) self.amp_config = dict( initial_scale=initial_scale, diff --git a/colossalai/shardformer/__init__.py b/colossalai/shardformer/__init__.py index e2773b249175..234e7131728f 100644 --- a/colossalai/shardformer/__init__.py +++ b/colossalai/shardformer/__init__.py @@ -1 +1 @@ -from .shard import GradCkptCollection, ModelSharder, PipelineGradCkptConfig, ShardConfig, ShardFormer +from .shard import GradientCheckpointConfig, ModelSharder, PipelineGradientCheckpointConfig, ShardConfig, ShardFormer diff --git a/colossalai/shardformer/shard/__init__.py b/colossalai/shardformer/shard/__init__.py index 547894327220..dff2118c1c1a 100644 --- a/colossalai/shardformer/shard/__init__.py +++ b/colossalai/shardformer/shard/__init__.py @@ -1,6 +1,6 @@ -from .grad_ckpt_config import GradCkptCollection, PipelineGradCkptConfig +from .grad_ckpt_config import GradientCheckpointConfig, PipelineGradientCheckpointConfig from .shard_config import ShardConfig from .sharder import ModelSharder from .shardformer import ShardFormer -__all__ = ["ShardConfig", "ModelSharder", "ShardFormer", "PipelineGradCkptConfig", "GradCkptCollection"] +__all__ = ["ShardConfig", "ModelSharder", "ShardFormer", "PipelineGradientCheckpointConfig", "GradientCheckpointConfig"] diff --git a/colossalai/shardformer/shard/grad_ckpt_config.py b/colossalai/shardformer/shard/grad_ckpt_config.py index dd809a2ae702..12726806e8d5 100644 --- a/colossalai/shardformer/shard/grad_ckpt_config.py +++ b/colossalai/shardformer/shard/grad_ckpt_config.py @@ -1,43 +1,22 @@ -from dataclasses import dataclass, field +from dataclasses import dataclass from typing import List, Optional @dataclass -class GradCkptConfig: +class GradientCheckpointConfig: # TODO: for future use _dummy_value: Optional[float] = None - def __post_init__(self): - raise NotImplementedError() - @property def control_gradient_checkpointing(self) -> bool: - raise NotImplementedError() + return False def get_num_ckpt_layers(self, *args, **kwargs) -> int: raise NotImplementedError() @dataclass -class GradCkptCollection: - gradient_ckpt_configs: List[GradCkptConfig] = field(default_factory=list) - - def __post_init__(self): - assert all([isinstance(config, GradCkptConfig) for config in self.gradient_ckpt_configs]) - - @property - def control_gradient_checkpointing(self) -> bool: - return any([config.control_gradient_checkpointing for config in self.gradient_ckpt_configs]) - - def get_num_ckpt_layers(self, *args, **kwargs) -> int: - for config in self.gradient_ckpt_configs: - if config.control_gradient_checkpointing: - return config.get_num_ckpt_layers(*args, **kwargs) - raise RuntimeError("No checkpointed layers information is provided") - - -@dataclass -class PipelineGradCkptConfig(GradCkptConfig): +class PipelineGradientCheckpointConfig(GradientCheckpointConfig): r""" The pipeline gradient config is designed to provide more flexibility for users to control gradient checkpoint in pipeline parallelism. Combined with PipelineStageManager.set_distribution_config, user can fully control the distribution of layers and checkpointed layers in pipeline parallelism. diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index c47a8aaa7c86..646b611932b7 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -6,7 +6,7 @@ from colossalai.pipeline.stage_manager import PipelineStageManager -from .grad_ckpt_config import GradCkptCollection +from .grad_ckpt_config import GradientCheckpointConfig __all__ = ["ShardConfig"] @@ -25,7 +25,7 @@ class ShardConfig: enable_jit_fused (bool, optional): Whether to switch on JIT fused operators. Defaults to False. enable_sequence_parallelism (bool): Whether to turn on sequence parallelism, which partitions non-tensor-parallel regions along the sequence dimension. Defaults to False. enable_sequence_overlap (bool): Whether to turn on sequence overlap, which overlap the computation and communication in sequence parallelism. It can only be used when enable_sequence_parallelism is True. Defaults to False. - gradient_ckpt_collection (Optional[GradCkptCollection]): The gradient checkpointing configs. Defaults to None. + gradient_checkpoint_config (Optional[GradientCheckpointConfig]): The gradient checkpoint config. Defaults to None. enable_all_optimization (bool): Whether to turn on all optimization tools including 'fused normalization', 'flash attention', 'JIT fused operators', 'sequence parallelism' and 'sequence overlap'. Defaults to False. """ tensor_parallel_process_group: Optional[ProcessGroup] = None @@ -38,7 +38,7 @@ class ShardConfig: enable_sequence_parallelism: bool = False enable_sequence_overlap: bool = False parallel_output: bool = True - gradient_ckpt_collection: Optional[GradCkptCollection] = None + gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None extra_kwargs: Dict[str, Any] = field(default_factory=dict) # TODO padding vocab # make_vocab_size_divisible_by: int = 128 diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index 53525e48bcbc..55858cbd4960 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -5,7 +5,7 @@ import colossalai from colossalai.logging import disable_existing_loggers -from colossalai.shardformer import GradCkptCollection, PipelineGradCkptConfig +from colossalai.shardformer import PipelineGradientCheckpointConfig from colossalai.shardformer.layer.utils import Randomizer from colossalai.tensor.d_tensor.api import clear_layout_converter from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn @@ -107,7 +107,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "precision": "fp16", "initial_scale": 1, "enable_gradient_checkpointing": True, - "gradient_ckpt_collection": GradCkptCollection([PipelineGradCkptConfig(gradient_checkpointing_ratio=0.5)]), + "gradient_checkpoint_config": PipelineGradientCheckpointConfig(gradient_checkpointing_ratio=0.5), }, { "tp_size": 1, @@ -116,12 +116,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "use_lazy_init": False, "precision": "fp32", "enable_gradient_checkpointing": True, - "gradient_ckpt_collection": GradCkptCollection( - [ - PipelineGradCkptConfig( - num_stages=2, num_model_chunks=1, num_model_layers=8, num_ckpt_layers_per_stage=[4, 0] - ) - ] + "gradient_checkpoint_config": PipelineGradientCheckpointConfig( + num_stages=2, num_model_chunks=1, num_model_layers=8, num_ckpt_layers_per_stage=[4, 0] ), }, { @@ -205,15 +201,11 @@ def run_llama_test(test_config): "zero_stage": 1, "initial_scale": 1, "enable_gradient_checkpointing": True, - "gradient_ckpt_collection": GradCkptCollection( - [ - PipelineGradCkptConfig( - num_stages=2, - num_model_chunks=2, - num_model_layers=8, - num_ckpt_layers_per_stage=[0, 1, 2, 2], - ) - ] + "gradient_checkpoint_config": PipelineGradientCheckpointConfig( + num_stages=2, + num_model_chunks=2, + num_model_layers=8, + num_ckpt_layers_per_stage=[0, 1, 2, 2], ), }, ], From 68990924f700009198f55cb30bcc53bacc25ec23 Mon Sep 17 00:00:00 2001 From: Wenhao Chen Date: Fri, 29 Mar 2024 16:23:30 +0800 Subject: [PATCH 18/21] fix: fix llama tests --- colossalai/shardformer/modeling/llama.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index b9db6b2129b8..e65f1d25fc50 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -141,10 +141,11 @@ def llama_model_forward( num_ckpt_layers = 0 if self.gradient_checkpointing and self.training: num_ckpt_layers = end_idx - start_idx - if shard_config.gradient_ckpt_collection is not None: - gradient_ckpt_collection = shard_config.gradient_ckpt_collection - if gradient_ckpt_collection.control_gradient_checkpointing: - num_ckpt_layers = gradient_ckpt_collection.get_num_ckpt_layers( + # TODO: We can replace `gradient_checkpointing_enable` fn and initialize a gradient_checkpointing (List[bool]) for each layer + if shard_config.gradient_checkpoint_config is not None: + gradient_checkpoint_config = shard_config.gradient_checkpoint_config + if gradient_checkpoint_config.control_gradient_checkpointing: + num_ckpt_layers = gradient_checkpoint_config.get_num_ckpt_layers( stage=stage_manager.stage, num_layers=end_idx - start_idx, model_chunk_id=stage_manager.model_chunk_id if stage_manager.is_interleave else 0, From 4656f1ded640e4c214602510e8823faebf113a4d Mon Sep 17 00:00:00 2001 From: Wenhao Chen Date: Fri, 29 Mar 2024 17:06:10 +0800 Subject: [PATCH 19/21] style: polish `GradientCheckpointConfig` --- colossalai/shardformer/modeling/llama.py | 12 +++++------ .../shardformer/shard/grad_ckpt_config.py | 21 +++++-------------- 2 files changed, 10 insertions(+), 23 deletions(-) diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index e65f1d25fc50..eb421c92b82c 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -143,13 +143,11 @@ def llama_model_forward( num_ckpt_layers = end_idx - start_idx # TODO: We can replace `gradient_checkpointing_enable` fn and initialize a gradient_checkpointing (List[bool]) for each layer if shard_config.gradient_checkpoint_config is not None: - gradient_checkpoint_config = shard_config.gradient_checkpoint_config - if gradient_checkpoint_config.control_gradient_checkpointing: - num_ckpt_layers = gradient_checkpoint_config.get_num_ckpt_layers( - stage=stage_manager.stage, - num_layers=end_idx - start_idx, - model_chunk_id=stage_manager.model_chunk_id if stage_manager.is_interleave else 0, - ) + num_ckpt_layers = shard_config.gradient_checkpoint_config.get_num_ckpt_layers( + stage=stage_manager.stage, + num_layers=end_idx - start_idx, + model_chunk_id=stage_manager.model_chunk_id if stage_manager.is_interleave else 0, + ) assert num_ckpt_layers <= end_idx - start_idx for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx): diff --git a/colossalai/shardformer/shard/grad_ckpt_config.py b/colossalai/shardformer/shard/grad_ckpt_config.py index 12726806e8d5..03c99e64d686 100644 --- a/colossalai/shardformer/shard/grad_ckpt_config.py +++ b/colossalai/shardformer/shard/grad_ckpt_config.py @@ -4,15 +4,11 @@ @dataclass class GradientCheckpointConfig: - # TODO: for future use - _dummy_value: Optional[float] = None - - @property - def control_gradient_checkpointing(self) -> bool: - return False + gradient_checkpointing_ratio: float = 0.0 def get_num_ckpt_layers(self, *args, **kwargs) -> int: - raise NotImplementedError() + assert self.gradient_checkpointing_ratio == 0.0, "This function should be overridden in derived class" + return 0 @dataclass @@ -55,7 +51,6 @@ class PipelineGradientCheckpointConfig(GradientCheckpointConfig): num_model_chunks: Optional[int] = None num_model_layers: Optional[int] = None num_ckpt_layers_per_stage: Optional[List[int]] = None - gradient_checkpointing_ratio: Optional[float] = None def __post_init__(self): if self._enable_gradient_checkpointing_ratio: @@ -72,10 +67,6 @@ def __post_init__(self): ) self.gradient_checkpointing_ratio = sum(self.num_ckpt_layers_per_stage) / self.num_model_layers - @property - def control_gradient_checkpointing(self) -> bool: - return self._enable_gradient_checkpointing_ratio or self._enable_customized_ckpt_layers_per_stage - @property def _enable_gradient_checkpointing_ratio(self) -> bool: return self.gradient_checkpointing_ratio is not None @@ -85,7 +76,7 @@ def _enable_customized_ckpt_layers_per_stage(self) -> bool: return self.num_ckpt_layers_per_stage is not None def get_num_ckpt_layers(self, stage: int, num_layers: int, model_chunk_id: int = 0) -> int: - if not self.control_gradient_checkpointing: + if not self._enable_gradient_checkpointing_ratio and not self._enable_customized_ckpt_layers_per_stage: raise RuntimeError("No checkpointed layers information is provided") if self._enable_customized_ckpt_layers_per_stage: @@ -93,7 +84,5 @@ def get_num_ckpt_layers(self, stage: int, num_layers: int, model_chunk_id: int = num_ckpt_layers = self.num_ckpt_layers_per_stage[stage + model_chunk_id * self.num_stages] assert num_ckpt_layers <= num_layers return num_ckpt_layers - elif self._enable_gradient_checkpointing_ratio: - return int(self.gradient_checkpointing_ratio * num_layers) else: - raise NotImplementedError() + return int(self.gradient_checkpointing_ratio * num_layers) From 4455e01399011ddfbb618da7c68bed64b86b6eb9 Mon Sep 17 00:00:00 2001 From: Wenhao Chen Date: Fri, 29 Mar 2024 17:32:45 +0800 Subject: [PATCH 20/21] fix: fix pipeline utils tests --- .../test_t5_pipeline_utils.py | 25 +++++++++++++++++++ .../test_whisper_pipeline_utils.py | 25 +++++++++++++++++++ 2 files changed, 50 insertions(+) diff --git a/tests/test_pipeline/test_pipeline_utils/test_t5_pipeline_utils.py b/tests/test_pipeline/test_pipeline_utils/test_t5_pipeline_utils.py index 4ba67225f271..1b7b0073f62e 100644 --- a/tests/test_pipeline/test_pipeline_utils/test_t5_pipeline_utils.py +++ b/tests/test_pipeline/test_pipeline_utils/test_t5_pipeline_utils.py @@ -1,4 +1,23 @@ +import random + +from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer.policies.t5 import T5BasePolicy +from colossalai.shardformer.shard.shard_config import ShardConfig + + +class _ShardConfig(ShardConfig): + def __post_init__(self): + pass + + +class _PipelineStageManager(PipelineStageManager): + def __init__(self): + self.is_interleave = False + self.num_layers_per_stage = None + + @property + def num_stages(self): + return random.randint(5, 10) def test_t5_pipeline_distribution(): @@ -10,7 +29,10 @@ def test_t5_pipeline_distribution(): "decoder_starting_stage": [1, 1, 2, 2, 3, 1, 5, 2], } + stage_manager = _PipelineStageManager() + shard_config = _ShardConfig(pipeline_stage_manager=stage_manager) policy = T5BasePolicy() + policy.set_shard_config(shard_config) for i in range(num_test_cases): _, decoder_starting_stage = policy.distribute_t5_layers( test_dict["num_encoder_layers"][i], @@ -35,7 +57,10 @@ def test_t5_pipeline_layers(): } for i in range(num_test_cases): + stage_manager = _PipelineStageManager() + shard_config = _ShardConfig(pipeline_stage_manager=stage_manager) policy = T5BasePolicy() + policy.set_shard_config(shard_config) layers_per_stage, decoder_starting_stage = policy.distribute_t5_layers( test_dict["num_encoder_layers"][i], test_dict["num_decoder_layers"][i], diff --git a/tests/test_pipeline/test_pipeline_utils/test_whisper_pipeline_utils.py b/tests/test_pipeline/test_pipeline_utils/test_whisper_pipeline_utils.py index 0500e46e890a..9f8c1ad32d23 100644 --- a/tests/test_pipeline/test_pipeline_utils/test_whisper_pipeline_utils.py +++ b/tests/test_pipeline/test_pipeline_utils/test_whisper_pipeline_utils.py @@ -1,4 +1,23 @@ +import random + +from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer.policies.whisper import WhisperPolicy +from colossalai.shardformer.shard.shard_config import ShardConfig + + +class _ShardConfig(ShardConfig): + def __post_init__(self): + pass + + +class _PipelineStageManager(PipelineStageManager): + def __init__(self): + self.is_interleave = False + self.num_layers_per_stage = None + + @property + def num_stages(self): + return random.randint(5, 10) def test_whisper_pipeline_distribution(): @@ -10,7 +29,10 @@ def test_whisper_pipeline_distribution(): "decoder_starting_stage": [1, 1, 2, 2, 3, 1, 5, 2], } + stage_manager = _PipelineStageManager() + shard_config = _ShardConfig(pipeline_stage_manager=stage_manager) policy = WhisperPolicy() + policy.set_shard_config(shard_config) for i in range(num_test_cases): _, decoder_starting_stage = policy.distribute_whisper_layers( test_dict["num_encoder_layers"][i], @@ -34,7 +56,10 @@ def test_whisper_pipeline_layers(): ], } + stage_manager = _PipelineStageManager() + shard_config = _ShardConfig(pipeline_stage_manager=stage_manager) policy = WhisperPolicy() + policy.set_shard_config(shard_config) for i in range(num_test_cases): layers_per_stage, decoder_starting_stage = policy.distribute_whisper_layers( test_dict["num_encoder_layers"][i], From 75d56202a9b197be8fbaa3e809b34fe4e694c677 Mon Sep 17 00:00:00 2001 From: Wenhao Chen Date: Mon, 1 Apr 2024 11:23:14 +0800 Subject: [PATCH 21/21] fix: fix base `GradientCheckpointConfig` --- colossalai/shardformer/shard/grad_ckpt_config.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/colossalai/shardformer/shard/grad_ckpt_config.py b/colossalai/shardformer/shard/grad_ckpt_config.py index 03c99e64d686..9c6c2b54ea39 100644 --- a/colossalai/shardformer/shard/grad_ckpt_config.py +++ b/colossalai/shardformer/shard/grad_ckpt_config.py @@ -6,9 +6,8 @@ class GradientCheckpointConfig: gradient_checkpointing_ratio: float = 0.0 - def get_num_ckpt_layers(self, *args, **kwargs) -> int: - assert self.gradient_checkpointing_ratio == 0.0, "This function should be overridden in derived class" - return 0 + def get_num_ckpt_layers(self, num_layers: int) -> int: + return int(self.gradient_checkpointing_ratio * num_layers) @dataclass