From d69f88fe6dd1a4df428733a564858c87a4fa9715 Mon Sep 17 00:00:00 2001 From: eric8607242 Date: Tue, 5 Sep 2023 10:23:10 +0800 Subject: [PATCH 1/2] Enable policy assignment in HybridPlugin and enable llama policy for llamav2 --- .../booster/plugin/hybrid_parallel_plugin.py | 13 ++++++++----- colossalai/shardformer/__init__.py | 1 + colossalai/shardformer/policies/llama.py | 18 ++++++++++++------ 3 files changed, 21 insertions(+), 11 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 8ad9b795692a..67027bdcb6d7 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -21,7 +21,7 @@ from colossalai.interface import ModelWrapper, OptimizerWrapper from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.shardformer import ShardConfig, ShardFormer +from colossalai.shardformer import ShardConfig, ShardFormer, Policy from colossalai.zero.low_level import LowLevelZeroOptimizer from .pp_plugin_base import PipelinePluginBase @@ -38,13 +38,13 @@ def _convert_floating_point(x, dtype: torch.dtype = torch.float16): class HybridParallelModule(ModelWrapper): def __init__(self, module: Module, precision: str, shard_config: ShardConfig, dp_group: ProcessGroup, use_ddp: bool, - ddp_config: dict) -> None: + ddp_config: dict, policy: Optional[Policy]=None) -> None: self.stage_manager = shard_config.pipeline_stage_manager self.dp_group = dp_group shardformer = ShardFormer(shard_config) - module, self.shared_params = shardformer.optimize(module) + module, self.shared_params = shardformer.optimize(module, policy=policy) # setting process groups for shared parameters self.shared_param_process_groups = [] @@ -268,6 +268,7 @@ class HybridParallelPlugin(PipelinePluginBase): cpu_offload (bool, optional): Whether to open cpu_offload when using ZeRO. Defaults to False. communication_dtype (torch.dtype, optional): Communication dtype when using ZeRO. If not specified, the dtype of param will be used. Defaults to None. overlap_communication (bool, optional): Whether to overlap communication and computation when using ZeRO. Defaults to True. + policy (Policy, optional): Shardformer policy when using custom model. If not specified, ShardFormer will try to fetch the policy automatically. """ def __init__(self, @@ -299,7 +300,8 @@ def __init__(self, zero_bucket_size_in_m: int = 12, cpu_offload: bool = False, communication_dtype: Optional[torch.dtype] = None, - overlap_communication: bool = True) -> None: + overlap_communication: bool = True, + policy: Optional[Policy] = None) -> None: super().__init__() assert dist.get_world_size() % ( @@ -323,6 +325,7 @@ def __init__(self, self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size) self.stage_manager = None self.schedule = None + self.policy = policy assert zero_stage in (0, 1, 2) if self.pp_size > 1: assert num_microbatches is not None or microbatch_size is not None, 'num_microbatches or microbatch_size must be specified when using pipeline parallelism' @@ -401,7 +404,7 @@ def configure( if not isinstance(model, ModelWrapper): use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0 model = HybridParallelModule(model, self.precision, self.shard_config, self.dp_group, use_ddp, - self.ddp_config) + self.ddp_config, self.policy) if optimizer is not None and not isinstance(optimizer, OptimizerWrapper): if self.zero_stage == 0: if self.precision in ['fp16', 'bf16']: diff --git a/colossalai/shardformer/__init__.py b/colossalai/shardformer/__init__.py index 77c2af8d18f7..acedd0b96727 100644 --- a/colossalai/shardformer/__init__.py +++ b/colossalai/shardformer/__init__.py @@ -1 +1,2 @@ from .shard import ShardConfig, ShardFormer +from .policies.base_policy import Policy diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index c417e5d017bd..875c8747633d 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -40,14 +40,20 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: self.shard_config.enable_sequence_parallelism = False warnings.warn("Llama dosen't support sequence parallelism now, will ignore the sequence parallelism flag.") + if self.shard_config.enable_tensor_parallelism: + decoder_attribute_replacement = { + "self_attn.hidden_size": + self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "self_attn.num_heads": + self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, + } + if getattr(self.model.config, "num_key_value_heads", False): + decoder_attribute_replacement["self_attn.num_key_value_heads"] = \ + self.model.config.num_key_value_heads // self.shard_config.tensor_parallel_size + policy[LlamaDecoderLayer] = ModulePolicyDescription( - attribute_replacement={ - "self_attn.hidden_size": - self.model.config.hidden_size // self.shard_config.tensor_parallel_size, - "self_attn.num_heads": - self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, - }, + attribute_replacement=decoder_attribute_replacement, sub_module_replacement=[ SubModuleReplacementDescription( suffix="self_attn.q_proj", From 028be031bbb89a111f2445f3f5ebd9a3ca5bbb09 Mon Sep 17 00:00:00 2001 From: eric8607242 Date: Tue, 5 Sep 2023 21:30:33 +0800 Subject: [PATCH 2/2] Remove Policy from Plugin --- colossalai/booster/plugin/hybrid_parallel_plugin.py | 8 ++++---- colossalai/shardformer/__init__.py | 1 - 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 67027bdcb6d7..76a36790728c 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -21,7 +21,7 @@ from colossalai.interface import ModelWrapper, OptimizerWrapper from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.shardformer import ShardConfig, ShardFormer, Policy +from colossalai.shardformer import ShardConfig, ShardFormer from colossalai.zero.low_level import LowLevelZeroOptimizer from .pp_plugin_base import PipelinePluginBase @@ -38,7 +38,7 @@ def _convert_floating_point(x, dtype: torch.dtype = torch.float16): class HybridParallelModule(ModelWrapper): def __init__(self, module: Module, precision: str, shard_config: ShardConfig, dp_group: ProcessGroup, use_ddp: bool, - ddp_config: dict, policy: Optional[Policy]=None) -> None: + ddp_config: dict, policy: Optional[object]=None) -> None: self.stage_manager = shard_config.pipeline_stage_manager self.dp_group = dp_group @@ -268,7 +268,7 @@ class HybridParallelPlugin(PipelinePluginBase): cpu_offload (bool, optional): Whether to open cpu_offload when using ZeRO. Defaults to False. communication_dtype (torch.dtype, optional): Communication dtype when using ZeRO. If not specified, the dtype of param will be used. Defaults to None. overlap_communication (bool, optional): Whether to overlap communication and computation when using ZeRO. Defaults to True. - policy (Policy, optional): Shardformer policy when using custom model. If not specified, ShardFormer will try to fetch the policy automatically. + policy (object, optional): Shardformer policy when using custom model. If not specified, ShardFormer will try to fetch the policy automatically. """ def __init__(self, @@ -301,7 +301,7 @@ def __init__(self, cpu_offload: bool = False, communication_dtype: Optional[torch.dtype] = None, overlap_communication: bool = True, - policy: Optional[Policy] = None) -> None: + policy: Optional[object] = None) -> None: super().__init__() assert dist.get_world_size() % ( diff --git a/colossalai/shardformer/__init__.py b/colossalai/shardformer/__init__.py index acedd0b96727..77c2af8d18f7 100644 --- a/colossalai/shardformer/__init__.py +++ b/colossalai/shardformer/__init__.py @@ -1,2 +1 @@ from .shard import ShardConfig, ShardFormer -from .policies.base_policy import Policy