From 51ea56f92dac3d66fdbc15c8d7e62e3b7a974848 Mon Sep 17 00:00:00 2001 From: eric8607242 Date: Tue, 5 Sep 2023 10:23:10 +0800 Subject: [PATCH 1/6] Enable policy assignment in HybridPlugin and enable llama policy for llamav2 --- .../booster/plugin/hybrid_parallel_plugin.py | 13 ++++++++----- colossalai/shardformer/__init__.py | 1 + colossalai/shardformer/policies/llama.py | 18 ++++++++++++------ 3 files changed, 21 insertions(+), 11 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index d33e3485c39c..9fbd4f123a15 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -21,7 +21,7 @@ from colossalai.interface import ModelWrapper, OptimizerWrapper from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.shardformer import ShardConfig, ShardFormer +from colossalai.shardformer import ShardConfig, ShardFormer, Policy from colossalai.zero.low_level import LowLevelZeroOptimizer from .pp_plugin_base import PipelinePluginBase @@ -38,13 +38,13 @@ def _convert_floating_point(x, dtype: torch.dtype = torch.float16): class HybridParallelModule(ModelWrapper): def __init__(self, module: Module, precision: str, shard_config: ShardConfig, dp_group: ProcessGroup, use_ddp: bool, - ddp_config: dict) -> None: + ddp_config: dict, policy: Optional[Policy]=None) -> None: self.stage_manager = shard_config.pipeline_stage_manager self.dp_group = dp_group shardformer = ShardFormer(shard_config) - module, self.shared_params = shardformer.optimize(module) + module, self.shared_params = shardformer.optimize(module, policy=policy) # setting process groups for shared parameters self.shared_param_process_groups = [] @@ -268,6 +268,7 @@ class HybridParallelPlugin(PipelinePluginBase): cpu_offload (bool, optional): Whether to open cpu_offload when using ZeRO. Defaults to False. communication_dtype (torch.dtype, optional): Communication dtype when using ZeRO. If not specified, the dtype of param will be used. Defaults to None. overlap_communication (bool, optional): Whether to overlap communication and computation when using ZeRO. Defaults to True. + policy (Policy, optional): Shardformer policy when using custom model. If not specified, ShardFormer will try to fetch the policy automatically. 
""" def __init__(self, @@ -300,7 +301,8 @@ def __init__(self, zero_bucket_size_in_m: int = 12, cpu_offload: bool = False, communication_dtype: Optional[torch.dtype] = None, - overlap_communication: bool = True) -> None: + overlap_communication: bool = True, + policy: Optional[Policy] = None) -> None: super().__init__() assert dist.get_world_size() % ( @@ -324,6 +326,7 @@ def __init__(self, self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size) self.stage_manager = None self.schedule = None + self.policy = policy assert zero_stage in (0, 1, 2) if self.pp_size > 1: assert num_microbatches is not None or microbatch_size is not None, 'num_microbatches or microbatch_size must be specified when using pipeline parallelism' @@ -403,7 +406,7 @@ def configure( if not isinstance(model, ModelWrapper): use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0 model = HybridParallelModule(model, self.precision, self.shard_config, self.dp_group, use_ddp, - self.ddp_config) + self.ddp_config, self.policy) if optimizer is not None and not isinstance(optimizer, OptimizerWrapper): if self.zero_stage == 0: if self.precision in ['fp16', 'bf16']: diff --git a/colossalai/shardformer/__init__.py b/colossalai/shardformer/__init__.py index 77c2af8d18f7..acedd0b96727 100644 --- a/colossalai/shardformer/__init__.py +++ b/colossalai/shardformer/__init__.py @@ -1 +1,2 @@ from .shard import ShardConfig, ShardFormer +from .policies.base_policy import Policy diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index c417e5d017bd..875c8747633d 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -40,14 +40,20 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: self.shard_config.enable_sequence_parallelism = False warnings.warn("Llama dosen't support sequence parallelism now, will ignore the sequence parallelism flag.") + if self.shard_config.enable_tensor_parallelism: + decoder_attribute_replacement = { + "self_attn.hidden_size": + self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "self_attn.num_heads": + self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, + } + if getattr(self.model.config, "num_key_value_heads", False): + decoder_attribute_replacement["self_attn.num_key_value_heads"] = \ + self.model.config.num_key_value_heads // self.shard_config.tensor_parallel_size + policy[LlamaDecoderLayer] = ModulePolicyDescription( - attribute_replacement={ - "self_attn.hidden_size": - self.model.config.hidden_size // self.shard_config.tensor_parallel_size, - "self_attn.num_heads": - self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, - }, + attribute_replacement=decoder_attribute_replacement, sub_module_replacement=[ SubModuleReplacementDescription( suffix="self_attn.q_proj", From 33ffb67ee630a223fbc60f4f90630f08f8499a14 Mon Sep 17 00:00:00 2001 From: eric8607242 Date: Tue, 5 Sep 2023 21:30:33 +0800 Subject: [PATCH 2/6] Remove Policy from Plugin --- colossalai/booster/plugin/hybrid_parallel_plugin.py | 8 ++++---- colossalai/shardformer/__init__.py | 1 - 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 9fbd4f123a15..5d0773a4e244 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -21,7 +21,7 @@ 
from colossalai.interface import ModelWrapper, OptimizerWrapper from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.shardformer import ShardConfig, ShardFormer, Policy +from colossalai.shardformer import ShardConfig, ShardFormer from colossalai.zero.low_level import LowLevelZeroOptimizer from .pp_plugin_base import PipelinePluginBase @@ -38,7 +38,7 @@ def _convert_floating_point(x, dtype: torch.dtype = torch.float16): class HybridParallelModule(ModelWrapper): def __init__(self, module: Module, precision: str, shard_config: ShardConfig, dp_group: ProcessGroup, use_ddp: bool, - ddp_config: dict, policy: Optional[Policy]=None) -> None: + ddp_config: dict, policy: Optional[object]=None) -> None: self.stage_manager = shard_config.pipeline_stage_manager self.dp_group = dp_group @@ -268,7 +268,7 @@ class HybridParallelPlugin(PipelinePluginBase): cpu_offload (bool, optional): Whether to open cpu_offload when using ZeRO. Defaults to False. communication_dtype (torch.dtype, optional): Communication dtype when using ZeRO. If not specified, the dtype of param will be used. Defaults to None. overlap_communication (bool, optional): Whether to overlap communication and computation when using ZeRO. Defaults to True. - policy (Policy, optional): Shardformer policy when using custom model. If not specified, ShardFormer will try to fetch the policy automatically. + policy (object, optional): Shardformer policy when using custom model. If not specified, ShardFormer will try to fetch the policy automatically. """ def __init__(self, @@ -302,7 +302,7 @@ def __init__(self, cpu_offload: bool = False, communication_dtype: Optional[torch.dtype] = None, overlap_communication: bool = True, - policy: Optional[Policy] = None) -> None: + policy: Optional[object] = None) -> None: super().__init__() assert dist.get_world_size() % ( diff --git a/colossalai/shardformer/__init__.py b/colossalai/shardformer/__init__.py index acedd0b96727..77c2af8d18f7 100644 --- a/colossalai/shardformer/__init__.py +++ b/colossalai/shardformer/__init__.py @@ -1,2 +1 @@ from .shard import ShardConfig, ShardFormer -from .policies.base_policy import Policy From c1112606a80669849afb8b0fb5c482b3a6291325 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Wed, 6 Sep 2023 11:18:39 +0800 Subject: [PATCH 3/6] revert changes of plugin HybridParallelModule --- colossalai/booster/plugin/hybrid_parallel_plugin.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 5d0773a4e244..fbfb906f409b 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -38,13 +38,13 @@ def _convert_floating_point(x, dtype: torch.dtype = torch.float16): class HybridParallelModule(ModelWrapper): def __init__(self, module: Module, precision: str, shard_config: ShardConfig, dp_group: ProcessGroup, use_ddp: bool, - ddp_config: dict, policy: Optional[object]=None) -> None: + ddp_config: dict) -> None: self.stage_manager = shard_config.pipeline_stage_manager self.dp_group = dp_group shardformer = ShardFormer(shard_config) - module, self.shared_params = shardformer.optimize(module, policy=policy) + module, self.shared_params = shardformer.optimize(module) # setting process groups for shared parameters self.shared_param_process_groups = [] @@ -268,7 +268,6 @@ class 
HybridParallelPlugin(PipelinePluginBase): cpu_offload (bool, optional): Whether to open cpu_offload when using ZeRO. Defaults to False. communication_dtype (torch.dtype, optional): Communication dtype when using ZeRO. If not specified, the dtype of param will be used. Defaults to None. overlap_communication (bool, optional): Whether to overlap communication and computation when using ZeRO. Defaults to True. - policy (object, optional): Shardformer policy when using custom model. If not specified, ShardFormer will try to fetch the policy automatically. """ def __init__(self, @@ -301,8 +300,7 @@ def __init__(self, zero_bucket_size_in_m: int = 12, cpu_offload: bool = False, communication_dtype: Optional[torch.dtype] = None, - overlap_communication: bool = True, - policy: Optional[object] = None) -> None: + overlap_communication: bool = True) -> None: super().__init__() assert dist.get_world_size() % ( @@ -326,7 +324,6 @@ def __init__(self, self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size) self.stage_manager = None self.schedule = None - self.policy = policy assert zero_stage in (0, 1, 2) if self.pp_size > 1: assert num_microbatches is not None or microbatch_size is not None, 'num_microbatches or microbatch_size must be specified when using pipeline parallelism' @@ -406,7 +403,7 @@ def configure( if not isinstance(model, ModelWrapper): use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0 model = HybridParallelModule(model, self.precision, self.shard_config, self.dp_group, use_ddp, - self.ddp_config, self.policy) + self.ddp_config) if optimizer is not None and not isinstance(optimizer, OptimizerWrapper): if self.zero_stage == 0: if self.precision in ['fp16', 'bf16']: From 9e690ffc87f6da627dd9a0736b5053487acc2ae1 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Wed, 6 Sep 2023 11:19:56 +0800 Subject: [PATCH 4/6] revert changes in plugin --- colossalai/booster/plugin/hybrid_parallel_plugin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index fbfb906f409b..d33e3485c39c 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -38,7 +38,7 @@ def _convert_floating_point(x, dtype: torch.dtype = torch.float16): class HybridParallelModule(ModelWrapper): def __init__(self, module: Module, precision: str, shard_config: ShardConfig, dp_group: ProcessGroup, use_ddp: bool, - ddp_config: dict) -> None: + ddp_config: dict) -> None: self.stage_manager = shard_config.pipeline_stage_manager self.dp_group = dp_group From 7a964c47f1c2e91490e7dadec0f2f342bb883b7f Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Wed, 6 Sep 2023 11:36:12 +0800 Subject: [PATCH 5/6] upgrade transformers --- requirements/requirements-test.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/requirements-test.txt b/requirements/requirements-test.txt index ba5ea0936010..2a6e59d32d50 100644 --- a/requirements/requirements-test.txt +++ b/requirements/requirements-test.txt @@ -4,7 +4,7 @@ pytest coverage==7.2.3 git+https://github.com/hpcaitech/pytest-testmon torchvision -transformers==4.30.2 +transformers timm titans torchaudio From 3ddffa4854a8f94251cf1b7698b58f7e87bd9604 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Wed, 6 Sep 2023 18:32:27 +0800 Subject: [PATCH 6/6] revert transformers version --- 
requirements/requirements-test.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/requirements-test.txt b/requirements/requirements-test.txt index 2a6e59d32d50..ba5ea0936010 100644 --- a/requirements/requirements-test.txt +++ b/requirements/requirements-test.txt @@ -4,7 +4,7 @@ pytest coverage==7.2.3 git+https://github.com/hpcaitech/pytest-testmon torchvision -transformers +transformers==4.30.2 timm titans torchaudio
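
---

Editor's note: after patches 3/6 and 4/6 revert the plugin-side `policy` plumbing and patch 6/6 restores the pinned transformers version, the net change of this series is the LLaMA policy update in patch 1/6, which builds the decoder attribute replacement dynamically so that grouped-query-attention configs (LLaMA-2 style, with `num_key_value_heads`) are split correctly under tensor parallelism. The following is a minimal, self-contained sketch of that logic only; `DummyLlamaConfig` and the concrete numbers are invented for illustration and are not part of the patch.

    # Illustrative sketch of the attribute-replacement logic added in patch 1/6.
    # DummyLlamaConfig is a stand-in for a HuggingFace LlamaConfig-like object.

    class DummyLlamaConfig:
        hidden_size = 4096
        num_attention_heads = 32
        num_key_value_heads = 8  # present on LLaMA-2 (grouped-query attention)

    def build_decoder_attribute_replacement(config, tensor_parallel_size):
        # Per-rank sizes after splitting attention across tensor-parallel ranks.
        replacement = {
            "self_attn.hidden_size": config.hidden_size // tensor_parallel_size,
            "self_attn.num_heads": config.num_attention_heads // tensor_parallel_size,
        }
        # Only LLaMA-2-style configs define num_key_value_heads; older configs
        # skip this key, matching the getattr(..., False) check in the patch.
        if getattr(config, "num_key_value_heads", False):
            replacement["self_attn.num_key_value_heads"] = (
                config.num_key_value_heads // tensor_parallel_size
            )
        return replacement

    print(build_decoder_attribute_replacement(DummyLlamaConfig(), tensor_parallel_size=4))
    # {'self_attn.hidden_size': 1024, 'self_attn.num_heads': 8, 'self_attn.num_key_value_heads': 2}

In the actual policy, this dictionary is passed as `attribute_replacement` to `ModulePolicyDescription` for `LlamaDecoderLayer`, replacing the previously hard-coded two-key dictionary.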