From 601968bb1a4724ba3c192f8baad4a13911b47e08 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Wed, 16 Aug 2023 15:15:32 +0800 Subject: [PATCH 1/2] support DDP for HybridPlugin/add tp+dp tests --- .../booster/plugin/hybrid_parallel_plugin.py | 81 +++++++++++++------ tests/test_shardformer/test_model/_utils.py | 14 ++++ .../test_model/test_shard_bert.py | 17 ++-- .../test_model/test_shard_bloom.py | 19 ++--- .../test_model/test_shard_chatglm.py | 19 ++--- .../test_model/test_shard_gpt2.py | 21 +++-- .../test_model/test_shard_llama.py | 20 ++--- .../test_model/test_shard_opt.py | 20 ++--- .../test_model/test_shard_t5.py | 20 +++-- .../test_model/test_shard_vit.py | 21 ++--- 10 files changed, 152 insertions(+), 100 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 3d45a9112fce..29ee55cf9fa3 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -6,7 +6,8 @@ import torch import torch.distributed as dist from torch.distributed import ProcessGroup -from torch.nn import Module +from torch.nn import Module, SyncBatchNorm +from torch.nn.parallel import DistributedDataParallel as DDP from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler as LRScheduler from torch.utils.data import DataLoader @@ -28,7 +29,8 @@ class HybridParallelModule(ModelWrapper): - def __init__(self, module: Module, precision: str, shard_config: ShardConfig, dp_group: ProcessGroup) -> None: + def __init__(self, module: Module, precision: str, shard_config: ShardConfig, dp_group: ProcessGroup, use_ddp: bool, + ddp_config: dict) -> None: self.stage_manager = shard_config.pipeline_stage_manager self.dp_group = dp_group shardformer = ShardFormer(shard_config) @@ -45,7 +47,15 @@ def __init__(self, module: Module, precision: str, shard_config: ShardConfig, dp module = module.to(dtype=torch.bfloat16).cuda() else: module = module.cuda() # train without AMP - # TODO(ver217): support TP+DP + + if use_ddp: + + # convert model to sync bn + module = SyncBatchNorm.convert_sync_batchnorm(module, dp_group) + + # wrap the model with PyTorch DDP + module = DDP(module, process_group=dp_group, **ddp_config) + super().__init__(module) def sync_shared_params(self): @@ -68,6 +78,12 @@ def sync_grads(self): dist.all_reduce(p.grad, group=self.dp_group) p.grad.div_(self.dp_group.size()) + def unwrap(self): + module = super().unwrap() + if isinstance(module, DDP): + module = module.module + return module + def init_pipeline_optimizer(optim: Optimizer, model: Module): params = set(model.parameters()) @@ -141,28 +157,32 @@ def __init__( class HybridParallelPlugin(PipelinePluginBase): - def __init__( - self, - tp_size: int, - pp_size: int, - precision: str = 'fp16', - zero_stage: int = 0, - cpu_offload: bool = False, - enable_all_optimization: bool = False, - enable_fused_normalization: bool = False, - enable_flash_attention: bool = False, - enable_jit_fused: bool = False, - enable_sequence_parallelism: bool = False, - num_microbatches: Optional[int] = None, - initial_scale: float = 2**16, - min_scale: float = 1, - growth_factor: float = 2, - backoff_factor: float = 0.5, - growth_interval: int = 1000, - hysteresis: int = 2, - max_scale: float = 2**32, - max_norm: float = 0, - ) -> None: + def __init__(self, + tp_size: int, + pp_size: int, + precision: str = 'fp16', + zero_stage: int = 0, + cpu_offload: bool = False, + enable_all_optimization: bool = False, + enable_fused_normalization: bool = False, + enable_flash_attention: bool = False, + enable_jit_fused: bool = False, + enable_sequence_parallelism: bool = False, + num_microbatches: Optional[int] = None, + initial_scale: float = 2**16, + min_scale: float = 1, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + hysteresis: int = 2, + max_scale: float = 2**32, + max_norm: float = 0, + broadcast_buffers=True, + bucket_cap_mb=25, + find_unused_parameters=False, + check_reduction=False, + gradient_as_bucket_view=False, + static_graph=False) -> None: super().__init__() assert dist.get_world_size() % ( tp_size * pp_size @@ -208,6 +228,13 @@ def __init__( min_scale=min_scale, max_scale=max_scale, ) + + self.ddp_config = dict(broadcast_buffers=broadcast_buffers, + bucket_cap_mb=bucket_cap_mb, + find_unused_parameters=find_unused_parameters, + check_reduction=check_reduction, + gradient_as_bucket_view=gradient_as_bucket_view, + static_graph=static_graph) self.max_norm = max_norm @property @@ -241,7 +268,9 @@ def configure( lr_scheduler: Optional[LRScheduler] = None, ) -> Tuple[Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]: if not isinstance(model, ModelWrapper): - model = HybridParallelModule(model, self.precision, self.shard_config, self.dp_group) + use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0 + model = HybridParallelModule(model, self.precision, self.shard_config, self.dp_group, use_ddp, + self.ddp_config) if optimizer is not None and not isinstance(optimizer, OptimizerWrapper): if self.zero_stage == 0: if self.precision in ['fp16', 'bf16']: diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index 7e1e6f2fe03a..8f43d4147131 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -9,10 +9,12 @@ from torch import distributed as dist from torch.distributed import ProcessGroup from torch.nn import Module +from torch.nn.parallel import DistributedDataParallel as DDP from torch.optim import Adam, Optimizer from colossalai.booster import Booster from colossalai.booster.plugin import HybridParallelPlugin +from colossalai.booster.plugin.hybrid_parallel_plugin import HybridParallelModule from colossalai.lazy import LazyInitContext from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer import ShardConfig, ShardFormer @@ -259,3 +261,15 @@ def check_grad(org_model: Module, assert torch.allclose( org_grad.float(), shard_grad.float(), rtol=rtol, atol=atol ), f"error attribute '{suffix}', orgin model grad is not equal to shard model grad\n{org_grad}\n{shard_grad}" + + +def unwrap_model(module: Module, + base_model_class_name: Optional[str] = None, + base_model_attribute_name: Optional[str] = None): + if isinstance(module, HybridParallelModule): + module = module.unwrap() + if base_model_class_name is None: + return module + if module.__class__.__name__ == base_model_class_name: + return module + return getattr(module, base_model_attribute_name, None) diff --git a/tests/test_shardformer/test_model/test_shard_bert.py b/tests/test_shardformer/test_model/test_shard_bert.py index 0a24e46d28f2..49de9cc0311c 100644 --- a/tests/test_shardformer/test_model/test_shard_bert.py +++ b/tests/test_shardformer/test_model/test_shard_bert.py @@ -15,6 +15,7 @@ check_output_hidden_state, check_weight, run_forward_backward_with_hybrid_plugin, + unwrap_model, ) @@ -44,13 +45,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) - # unwrap model - if org_model.__class__.__name__ == 'BertModel': - bert = org_model - sharded_bert = sharded_model.unwrap() - else: - bert = org_model.bert - sharded_bert = sharded_model.unwrap().bert + + bert = unwrap_model(org_model, 'BertModel', 'bert') + sharded_bert = unwrap_model(sharded_model, 'BertModel', 'bert') col_layer_for_check = ['encoder.layer[0].output.dense'] row_layer_for_check = ['embeddings.word_embeddings', 'encoder.layer[0].intermediate.dense'] @@ -98,6 +95,12 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, 'enable_all_optimization': True, 'use_lazy_init': False, 'precision': 'fp32', +}, { + 'tp_size': 2, + 'pp_size': 1, + 'enable_all_optimization': True, + 'use_lazy_init': False, + 'precision': 'fp32' }]) def run_bert_test(test_config): diff --git a/tests/test_shardformer/test_model/test_shard_bloom.py b/tests/test_shardformer/test_model/test_shard_bloom.py index ed0d1d8e401d..af014a8585b5 100644 --- a/tests/test_shardformer/test_model/test_shard_bloom.py +++ b/tests/test_shardformer/test_model/test_shard_bloom.py @@ -13,6 +13,7 @@ check_output_hidden_state, check_weight, run_forward_backward_with_hybrid_plugin, + unwrap_model, ) @@ -46,12 +47,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) # unwrap model - if org_model.__class__.__name__ == 'BloomModel': - bloom = org_model - sharded_bloom = sharded_model.unwrap() - else: - bloom = org_model.transformer - sharded_bloom = sharded_model.unwrap().transformer + bloom = unwrap_model(org_model, 'BloomModel', 'transformer') + sharded_bloom = unwrap_model(sharded_model, 'BloomModel', 'transformer') # check grad row_layer_for_check = ['h[0].self_attention.query_key_value', 'word_embeddings'] @@ -97,12 +94,16 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, 'pp_size': 1, 'enable_all_optimization': True, 'use_lazy_init': False, - 'precision': 'fp32', + 'precision': 'fp32' +}, { + 'tp_size': 2, + 'pp_size': 1, + 'enable_all_optimization': True, + 'use_lazy_init': False, + 'precision': 'fp32' }]) def run_bloom_test(test_config): - # TODO(baizhou): add test_config for TP+DP after supporting & debugging it - sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): diff --git a/tests/test_shardformer/test_model/test_shard_chatglm.py b/tests/test_shardformer/test_model/test_shard_chatglm.py index bb77759048b3..210f775b540d 100644 --- a/tests/test_shardformer/test_model/test_shard_chatglm.py +++ b/tests/test_shardformer/test_model/test_shard_chatglm.py @@ -14,6 +14,7 @@ check_output_hidden_state, check_weight, run_forward_backward_with_hybrid_plugin, + unwrap_model, ) @@ -48,12 +49,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) # unwrap model - if org_model.__class__.__name__ == 'ChatGLMModel': - chatglm_model = org_model - shard_chatglm_model = sharded_model.unwrap() - else: - chatglm_model = org_model.transformer - shard_chatglm_model = sharded_model.unwrap().transformer + chatglm_model = unwrap_model(org_model, 'ChatGLMModel', 'transformer') + shard_chatglm_model = unwrap_model(sharded_model, 'ChatGLMModel', 'transformer') # check grad row_layer_for_check = ['encoder.layers[0].self_attention.query_key_value', 'embedding.word_embeddings'] @@ -121,12 +118,16 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, 'pp_size': 1, 'enable_all_optimization': True, 'use_lazy_init': False, - 'precision': 'fp32', + 'precision': 'fp32' +}, { + 'tp_size': 2, + 'pp_size': 1, + 'enable_all_optimization': True, + 'use_lazy_init': False, + 'precision': 'fp32' }]) def run_chatglm_test(test_config): - # TODO(baizhou): add test_config for TP+DP after supporting & debugging it - sub_model_zoo = model_zoo.get_sub_registry('transformers_chatglm') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py index c97702cbb281..97295f72f4e1 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2.py @@ -3,7 +3,6 @@ from torch import distributed as dist import colossalai -from colossalai.booster.plugin.hybrid_parallel_plugin import HybridParallelModule from colossalai.logging import disable_existing_loggers from colossalai.tensor.d_tensor.api import clear_layout_converter from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn @@ -15,6 +14,7 @@ check_output_hidden_state, check_weight, run_forward_backward_with_hybrid_plugin, + unwrap_model, ) @@ -48,16 +48,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) - def unwrap(module): - if isinstance(module, HybridParallelModule): - module = module.unwrap() - if module.__class__.__name__ == 'GPT2Model': - return module - return module.transformer - # unwrap model - gpt2 = unwrap(org_model) - sharded_gpt2 = unwrap(sharded_model) + gpt2 = unwrap_model(org_model, 'GPT2Model', 'transformer') + sharded_gpt2 = unwrap_model(sharded_model, 'GPT2Model', 'transformer') col_layer_for_check = ['h[0].mlp.c_fc'] row_layer_for_check = ['wte', 'h[0].mlp.c_proj'] @@ -106,6 +99,12 @@ def unwrap(module): 'enable_all_optimization': True, 'use_lazy_init': False, 'precision': 'fp32', +}, { + 'tp_size': 2, + 'pp_size': 1, + 'enable_all_optimization': True, + 'use_lazy_init': False, + 'precision': 'fp32', }, { 'tp_size': 4, 'pp_size': 1, @@ -117,8 +116,6 @@ def unwrap(module): @clear_cache_before_run() def run_gpt2_test(test_config): - # TODO(baizhou): add test_config for TP+DP after supporting & debugging it - sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index 30ebdfbe5cd9..a433567b3702 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -16,6 +16,7 @@ check_output_hidden_state, check_weight, run_forward_backward_with_hybrid_plugin, + unwrap_model, ) os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' @@ -52,12 +53,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) # unwrap model - if org_model.__class__.__name__ == 'LlamaModel': - llama_model = org_model - shard_llama_model = sharded_model.unwrap() - else: - llama_model = org_model.model - shard_llama_model = sharded_model.unwrap().model + llama_model = unwrap_model(org_model, 'LlamaModel', 'model') + shard_llama_model = unwrap_model(sharded_model, 'LlamaModel', 'model') # check grad row_layer_for_check = ['layers[0].self_attn.q_proj', 'embed_tokens'] @@ -128,13 +125,18 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, 'tp_size': 1, 'pp_size': 4, 'num_microbatches': 4, + 'enable_all_optimization': False, 'use_lazy_init': False, - 'precision': 'fp32', + 'precision': 'fp32' +}, { + 'tp_size': 2, + 'pp_size': 1, + 'enable_all_optimization': True, + 'use_lazy_init': False, + 'precision': 'fp32' }]) def run_llama_test(test_config): - # TODO(baizhou): add test_config for TP+DP after supporting & debugging it - sub_model_zoo = model_zoo.get_sub_registry('transformers_llama') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): diff --git a/tests/test_shardformer/test_model/test_shard_opt.py b/tests/test_shardformer/test_model/test_shard_opt.py index 8d1154d82638..2fb14903b6a9 100644 --- a/tests/test_shardformer/test_model/test_shard_opt.py +++ b/tests/test_shardformer/test_model/test_shard_opt.py @@ -16,6 +16,7 @@ check_output_hidden_state, check_weight, run_forward_backward_with_hybrid_plugin, + unwrap_model, ) os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' @@ -51,12 +52,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) # unwrap model - if org_model.__class__.__name__ == 'OPTModel': - opt_model = org_model - shard_opt_model = sharded_model.unwrap() - else: - opt_model = org_model.model - shard_opt_model = sharded_model.unwrap().model + opt_model = unwrap_model(org_model, 'OPTModel', 'model') + shard_opt_model = unwrap_model(sharded_model, 'OPTModel', 'model') # check grad row_layer_for_check = ['decoder.layers[0].self_attn.q_proj', 'decoder.embed_tokens'] # 'decoder.embed_tokens' @@ -123,14 +120,17 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, 'pp_size': 1, 'enable_all_optimization': True, 'use_lazy_init': False, - 'precision': 'fp32', + 'precision': 'fp32' +}, { + 'tp_size': 2, + 'pp_size': 1, + 'enable_all_optimization': True, + 'use_lazy_init': False, + 'precision': 'fp32' }]) def run_opt_test(test_config): - # TODO(baizhou): add test_config for TP+DP after supporting & debugging it - sub_model_zoo = model_zoo.get_sub_registry('transformers_opt') - for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) diff --git a/tests/test_shardformer/test_model/test_shard_t5.py b/tests/test_shardformer/test_model/test_shard_t5.py index 066f7ee815b4..234ce812a08c 100644 --- a/tests/test_shardformer/test_model/test_shard_t5.py +++ b/tests/test_shardformer/test_model/test_shard_t5.py @@ -1,5 +1,6 @@ import pytest import torch +from torch.nn.parallel import DistributedDataParallel as DDP import colossalai from colossalai.logging import disable_existing_loggers @@ -14,6 +15,7 @@ check_output_hidden_state, check_weight, run_forward_backward_with_hybrid_plugin, + unwrap_model, ) @@ -48,8 +50,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) # unwrap model - t5 = org_model - sharded_t5 = sharded_model.unwrap() + t5 = unwrap_model(org_model) + sharded_t5 = unwrap_model(sharded_model) row_layer_for_check = ['shared', 'encoder.block[0].layer[0].SelfAttention.q'] @@ -99,17 +101,19 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, 'tp_size': 1, 'pp_size': 4, 'num_microbatches': 4, + 'enable_all_optimization': False, 'use_lazy_init': False, - 'precision': 'fp32', + 'precision': 'fp32' +}, { + 'tp_size': 2, + 'pp_size': 1, + 'enable_all_optimization': True, + 'use_lazy_init': False, + 'precision': 'fp32' }]) @clear_cache_before_run() def run_t5_test(test_config): - # TODO(baizhou): add plugin_config for TP+DP after supporting & debugging it - # {'tp_size': 2, 'pp_size': 1, 'enable_fused_normalization': True} - - # TODO(baizhou): add test_config for flash attention & jit operator after supporting - sub_model_zoo = model_zoo.get_sub_registry('transformers_t5') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): diff --git a/tests/test_shardformer/test_model/test_shard_vit.py b/tests/test_shardformer/test_model/test_shard_vit.py index 18df8ef555f2..b9d303841215 100644 --- a/tests/test_shardformer/test_model/test_shard_vit.py +++ b/tests/test_shardformer/test_model/test_shard_vit.py @@ -14,6 +14,7 @@ check_output_hidden_state, check_weight, run_forward_backward_with_hybrid_plugin, + unwrap_model, ) @@ -48,12 +49,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) # unwrap model - if org_model.__class__.__name__ == 'ViTModel': - vit_model = org_model - shard_vit_model = sharded_model.unwrap() - else: - vit_model = org_model.vit - shard_vit_model = sharded_model.unwrap().vit + vit_model = unwrap_model(org_model, 'ViTModel', 'vit') + shard_vit_model = unwrap_model(sharded_model, 'ViTModel', 'vit') # check grad row_layer_for_check = ['encoder.layer[0].attention.attention.query', 'embeddings.patch_embeddings.projection'] @@ -120,15 +117,19 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, 'pp_size': 1, 'enable_all_optimization': True, 'use_lazy_init': False, - 'precision': 'fp32', + 'precision': 'fp32' +}, { + 'tp_size': 2, + 'pp_size': 1, + 'enable_all_optimization': True, + 'use_lazy_init': False, + 'precision': 'fp32' }]) def run_vit_test(test_config): - # TODO(baizhou): add test_config for TP+DP after supporting & debugging it - # TODO(baizhou): fix bug when settign lazy_init for Conv2D Layers in ViT models + # TODO: fix bug when settign lazy_init for Conv2D Layers in ViT models sub_model_zoo = model_zoo.get_sub_registry('transformers_vit') - for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) From 358dcde0804d0c2b5eadc2eaa98c00d6a07ebc62 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Wed, 16 Aug 2023 15:45:13 +0800 Subject: [PATCH 2/2] add docstring for HybridParallelPlugin --- .../booster/plugin/hybrid_parallel_plugin.py | 48 +++++++++++++++++++ tests/test_shardformer/test_model/_utils.py | 1 - 2 files changed, 48 insertions(+), 1 deletion(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 29ee55cf9fa3..00c714fe4612 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -156,6 +156,53 @@ def __init__( class HybridParallelPlugin(PipelinePluginBase): + """ + Plugin for Hybrid Parallel Training. + Tensor parallel, pipeline parallel and data parallel(DDP/ZeRO) can be picked and combined in this plugin. + The size of tp and pp should be passed in by user, then the size of dp is automatically calculated from dp_size = world_size / (tp_size * pp_size). + + Example: + >>> from colossalai.booster import Booster + >>> from colossalai.booster.plugin import HybridParallelPlugin + + >>> model, train_dataset, optimizer, criterion = ... + >>> plugin = HybridParallelPlugin(tp_size=2, pp_size=2) + + >>> train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8) + >>> booster = Booster(plugin=plugin) + >>> model, optimizer, criterion, train_dataloader, _ = booster.boost(model, optimizer, criterion, train_dataloader) + + Args: + tp_size (int): The size of tensor parallelism. Tensor parallelism will not be used when tp_size is set to 1. + pp_size (int): The number of pipeline stages in pipeline parallelism. Pipeline parallelism will not be used when pp_size is set to 1. + precision (str, optional): Specifies the precision of parameters during training. + Auto-mixied precision will be used when this argument is set to 'fp16' or 'bf16', otherwise model is trained with 'fp32'. + Defaults to 'fp16'. + zero_stage (int, optional): The stage of ZeRO for data parallelism. Can only be choosed from [0, 1, 2]. + When set to 0, ZeRO will not be used. Defaults to 0. + cpu_offload (bool, optional): Whether to open cpu_offload when using ZeRO. Defaults to False. + enable_all_optimization (bool, optional): Whether to switch on all the optimizations supported by Shardformer. + Currently all the optimization methods include fused normalization, flash attention and JIT. + Defaults to False. + enable_fused_normalization (bool, optional): Whether to switch on fused normalization. Defaults to False. + enable_flash_attention (bool, optional): Whether to switch on flash attention. Defaults to False. + enable_jit_fused (bool, optional): Whether to switch on JIT. Default to Falase. + num_microbatches (int, optional): Number of microbatches when using pipeline parallelism. Defaults to None. + initial_scale (float, optional): The initial loss scale of AMP. Defaults to 2**16. + min_scale (float, optional): The minimum loss scale of AMP. Defaults to 1. + growth_factor (float, optional): The multiplication factor for increasing loss scale when using AMP. Defaults to 2. + backoff_factor (float, optional): The multiplication factor for decreasing loss scale when using AMP. Defaults to 0.5. + growth_interval (int, optional): The number of steps to increase loss scale when no overflow occurs when using AMP. Defaults to 1000. + hysteresis (int, optional): The number of overflows before decreasing loss scale when using AMP. Defaults to 2. + max_scale (float, optional): The maximum loss scale of AMP. Defaults to 2**32. + max_norm (float, optional): Maximum norm for gradient clipping. Defaults to 0. + broadcast_buffers (bool, optional): Whether to broadcast buffers in the beginning of training. Only for usage of DDP. Defaults to True. + bucket_cap_mb (int, optional): The bucket size in MB. Only for usage of DDP. Defaults to 25. + find_unused_parameters (bool, optional): Whether to find unused parameters. Only for usage of DDP. Defaults to False. + check_reduction (bool, optional): Whether to check reduction. Only for usage of DDP. Defaults to False. + gradient_as_bucket_view (bool, optional): Whether to use gradient as bucket view. Only for usage of DDP. Defaults to False. + static_graph (bool, optional): Whether to use static graph. Only for usage of DDP. Defaults to False. + """ def __init__(self, tp_size: int, @@ -183,6 +230,7 @@ def __init__(self, check_reduction=False, gradient_as_bucket_view=False, static_graph=False) -> None: + super().__init__() assert dist.get_world_size() % ( tp_size * pp_size diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index 8f43d4147131..789b3b24e696 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -9,7 +9,6 @@ from torch import distributed as dist from torch.distributed import ProcessGroup from torch.nn import Module -from torch.nn.parallel import DistributedDataParallel as DDP from torch.optim import Adam, Optimizer from colossalai.booster import Booster