From 26c7fa9edb73b9f8ffd99f23208dd1f3c392a5bf Mon Sep 17 00:00:00 2001
From: oahzxl
Date: Thu, 14 Sep 2023 16:34:27 +0800
Subject: [PATCH 1/2] add custom policy

---
 colossalai/booster/plugin/hybrid_parallel_plugin.py | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index fc04f3ecd8e7..efe02a893510 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -22,6 +22,7 @@
 from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer import ShardConfig, ShardFormer
+from colossalai.shardformer.policies.base_policy import Policy
 from colossalai.zero.low_level import LowLevelZeroOptimizer
 
 from .pp_plugin_base import PipelinePluginBase
@@ -38,13 +39,13 @@ def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
 class HybridParallelModule(ModelWrapper):
 
     def __init__(self, module: Module, precision: str, shard_config: ShardConfig, dp_group: ProcessGroup, use_ddp: bool,
-                 ddp_config: dict) -> None:
+                 ddp_config: dict, custom_policy: Policy) -> None:
 
         self.stage_manager = shard_config.pipeline_stage_manager
         self.dp_group = dp_group
 
         shardformer = ShardFormer(shard_config)
-        module, self.shared_params = shardformer.optimize(module)
+        module, self.shared_params = shardformer.optimize(module, policy=custom_policy)
 
         # setting process groups for shared parameters
         self.shared_param_process_groups = []
@@ -268,6 +269,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         cpu_offload (bool, optional): Whether to open cpu_offload when using ZeRO. Defaults to False.
         communication_dtype (torch.dtype, optional): Communication dtype when using ZeRO. If not specified, the dtype of param will be used. Defaults to None.
         overlap_communication (bool, optional): Whether to overlap communication and computation when using ZeRO. Defaults to True.
+        custom_policy (Policy, optional): Custom policy for Shardformer. Defaults to None.
     """
 
     def __init__(self,
@@ -300,7 +302,8 @@ def __init__(self,
                  zero_bucket_size_in_m: int = 12,
                  cpu_offload: bool = False,
                  communication_dtype: Optional[torch.dtype] = None,
-                 overlap_communication: bool = True) -> None:
+                 overlap_communication: bool = True,
+                 custom_policy: Policy = None) -> None:
 
         super().__init__()
         assert dist.get_world_size() % (
@@ -324,6 +327,7 @@ def __init__(self,
         self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size)
         self.stage_manager = None
         self.schedule = None
+        self.custom_policy = custom_policy
         assert zero_stage in (0, 1, 2)
         if self.pp_size > 1:
             assert num_microbatches is not None or microbatch_size is not None, 'num_microbatches or microbatch_size must be specified when using pipeline parallelism'
@@ -403,7 +407,7 @@ def configure(
         if not isinstance(model, ModelWrapper):
             use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0
             model = HybridParallelModule(model, self.precision, self.shard_config, self.dp_group, use_ddp,
-                                         self.ddp_config)
+                                         self.ddp_config, self.custom_policy)
         if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
             if self.zero_stage == 0:
                 if self.precision in ['fp16', 'bf16']:

From e4c21ba867899a3acce316a11ac4909b258b43ff Mon Sep 17 00:00:00 2001
From: oahzxl
Date: Thu, 14 Sep 2023 16:56:45 +0800
Subject: [PATCH 2/2] update assert

---
 colossalai/booster/plugin/hybrid_parallel_plugin.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index efe02a893510..efa36b502481 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -45,6 +45,8 @@ def __init__(self, module: Module, precision: str, shard_config: ShardConfig, dp
         self.dp_group = dp_group
 
         shardformer = ShardFormer(shard_config)
+        if custom_policy is not None:
+            assert isinstance(custom_policy, object)
 module, self.shared_params = shardformer.optimize(module, policy=custom_policy)
 
         # setting process groups for shared parameters