From 4795a896e8ee42fd47aabb61c3fcc879e270f847 Mon Sep 17 00:00:00 2001
From: flybird11111 <1829166702@qq.com>
Date: Thu, 11 Jan 2024 17:20:20 +0800
Subject: [PATCH 1/6] fix ci fix

---
 colossalai/pipeline/p2p.py                    |  5 +-
 .../test_model/test_shard_gpt2.py             |  4 +-
 .../test_model/test_shard_t5.py               | 72 ++++++++++---------
 3 files changed, 43 insertions(+), 38 deletions(-)

diff --git a/colossalai/pipeline/p2p.py b/colossalai/pipeline/p2p.py
index 5588aa5789a9..8d27e77b7eb5 100644
--- a/colossalai/pipeline/p2p.py
+++ b/colossalai/pipeline/p2p.py
@@ -162,7 +162,10 @@ def create_send_metadata(
         strict (bool, optional): whether to check if the object is supported for fast send
         return_tensor (bool, optional): whether to return tensor objects
     """
-    objs, tree_spec = tree_flatten(object)
+    filtered_object = (
+        {key: value for key, value in object.items() if value is not None} if isinstance(object, dict) else object
+    )
+    objs, tree_spec = tree_flatten(filtered_object)
     tensor_metadata, tensor_objs = [], []
     non_tensor_obj_idx, non_tensor_objs = [], []
     for idx, obj in enumerate(objs):
diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py
index 66b30641acc8..3155420f1cf2 100644
--- a/tests/test_shardformer/test_model/test_shard_gpt2.py
+++ b/tests/test_shardformer/test_model/test_shard_gpt2.py
@@ -165,7 +165,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
 )
 @clear_cache_before_run()
 def run_gpt2_test(test_config):
-    sub_model_zoo = model_zoo.get_sub_registry("transformers_gpt")
+    sub_model_zoo = model_zoo.get_sub_registry("transformers_gpt", exclude="transformers_gptj")
 
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
         check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
@@ -200,7 +200,7 @@ def run_gpt2_test(test_config):
 )
 @clear_cache_before_run()
 def run_gpt2_3d_test(test_config):
-    sub_model_zoo = model_zoo.get_sub_registry("transformers_gpt")
+    sub_model_zoo = model_zoo.get_sub_registry("transformers_gpt", exclude="transformers_gptj")
 
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
         check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
diff --git a/tests/test_shardformer/test_model/test_shard_t5.py b/tests/test_shardformer/test_model/test_shard_t5.py
index 73f203d1f023..31e5e543b5e9 100644
--- a/tests/test_shardformer/test_model/test_shard_t5.py
+++ b/tests/test_shardformer/test_model/test_shard_t5.py
@@ -79,26 +79,27 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     torch.cuda.empty_cache()
 
 
+# TODO t5 pipeline parallelism should be fixed, can't send non-tensor data
 @parameterize(
     "test_config",
     [
-        {
-            "tp_size": 2,
-            "pp_size": 2,
-            "num_microbatches": 2,
-            "enable_all_optimization": True,
-            "use_lazy_init": True,
-            "precision": "fp16",
-            "initial_scale": 1,
-        },
-        {
-            "tp_size": 1,
-            "pp_size": 2,
-            "num_microbatches": 4,
-            "use_lazy_init": False,
-            "precision": "fp16",
-            "initial_scale": 1,
-        },
+        # {
+        #     "tp_size": 2,
+        #     "pp_size": 2,
+        #     "num_microbatches": 2,
+        #     "enable_all_optimization": True,
+        #     "use_lazy_init": True,
+        #     "precision": "fp16",
+        #     "initial_scale": 1,
+        # },
+        # {
+        #     "tp_size": 1,
+        #     "pp_size": 2,
+        #     "num_microbatches": 4,
+        #     "use_lazy_init": False,
+        #     "precision": "fp16",
+        #     "initial_scale": 1,
+        # },
         {
             "tp_size": 4,
             "pp_size": 1,
@@ -106,14 +107,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "use_lazy_init": False,
             "precision": "fp32",
         },
-        {
-            "tp_size": 1,
-            "pp_size": 4,
-            "num_microbatches": 4,
-            "enable_all_optimization": False,
-            "use_lazy_init": False,
-            "precision": "fp32",
-        },
+        # {
+        #     "tp_size": 1,
+        #     "pp_size": 4,
+        #     "num_microbatches": 4,
+        #     "enable_all_optimization": False,
+        #     "use_lazy_init": False,
+        #     "precision": "fp32",
+        # },
         {"tp_size": 2, "pp_size": 1, "enable_all_optimization": True, "use_lazy_init": False, "precision": "fp32"},
         {
             "tp_size": 2,
@@ -124,16 +125,16 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "precision": "fp16",
             "initial_scale": 1,
         },
-        {
-            "tp_size": 1,
-            "pp_size": 2,
-            "num_microbatches": 2,
-            "enable_all_optimization": True,
-            "use_lazy_init": True,
-            "zero_stage": 1,
-            "precision": "fp16",
-            "initial_scale": 1,
-        },
+        # {
+        #     "tp_size": 1,
+        #     "pp_size": 2,
+        #     "num_microbatches": 2,
+        #     "enable_all_optimization": True,
+        #     "use_lazy_init": True,
+        #     "zero_stage": 1,
+        #     "precision": "fp16",
+        #     "initial_scale": 1,
+        # },
     ],
 )
 @clear_cache_before_run()
@@ -208,6 +209,7 @@ def test_t5():
 @pytest.mark.largedist
 @rerun_if_address_is_in_use()
 @clear_cache_before_run()
+@pytest.mark.skip(reason="t5 pipeline parallelism should be fixed, can't send non-tensor data")
 def test_t5_3d():
     spawn(check_t5_3d, 8)
 
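
[Illustrative note, not part of the series] PATCH 1 above filters None values out of dict outputs before flattening, and is reverted in PATCH 3. Assuming tree_flatten here is torch.utils._pytree.tree_flatten, as used by colossalai/pipeline/p2p.py, the sketch below shows why a None entry (e.g. an absent loss) is a problem: it survives flattening as a non-tensor leaf, which cannot be fast-sent and which changes the flattened TreeSpec, the metadata the TODO and skip reason above describe as "can't send non-tensor data":

    # Minimal sketch; assumes torch 2.1-era _pytree, where None is a leaf.
    import torch
    from torch.utils._pytree import tree_flatten

    with_none, spec_a = tree_flatten({"hidden_states": torch.ones(2), "loss": None})
    without_none, spec_b = tree_flatten({"hidden_states": torch.ones(2)})

    print(with_none)         # [tensor([1., 1.]), None]: None is kept as a non-tensor leaf
    print(spec_a == spec_b)  # False: cached send/recv metadata no longer matches

PATCH 3 reverts this filtering, and PATCH 4 instead makes the p2p metadata cache optional.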

From 9bc7cbf30616d9a0761a0b99fc217924b519cacc Mon Sep 17 00:00:00 2001
From: flybird11111 <1829166702@qq.com>
Date: Tue, 16 Jan 2024 19:12:42 +0800
Subject: [PATCH 2/6] fix test

---
 .../test_hybrid_parallel_plugin_checkpoint_io.py | 15 +++++++++++----
 1 file changed, 11 insertions(+), 4 deletions(-)

diff --git a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py
index c0bc2d2f5d0a..1bba09c003fe 100644
--- a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py
+++ b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py
@@ -104,17 +104,24 @@ def _preprocess_data(data):
     # Check whether the loaded model & optimizer works smoothly.
     model.train()
     new_model.train()
+    data_for_shard = data_gen_fn()
+    data_for_origin = data_gen_fn()
     if booster.plugin.stage_manager is not None:
         booster.execute_pipeline(
-            _preprocess_data(data), model, _criterion, optimizer, return_loss=True, return_outputs=False
+            _preprocess_data(data_for_shard), model, _criterion, optimizer, return_loss=True, return_outputs=False
         )
         booster.execute_pipeline(
-            _preprocess_data(data), new_model, _criterion, new_optimizer, return_loss=True, return_outputs=False
+            _preprocess_data(data_for_origin),
+            new_model,
+            _criterion,
+            new_optimizer,
+            return_loss=True,
+            return_outputs=False,
         )
     else:
-        old_model_loss = criterion(model(**_preprocess_data(data)))
+        old_model_loss = criterion(model(**_preprocess_data(data_for_shard)))
         optimizer.backward(old_model_loss)
-        new_model_loss = criterion(new_model(**_preprocess_data(data)))
+        new_model_loss = criterion(new_model(**_preprocess_data(data_for_origin)))
         new_optimizer.backward(new_model_loss)
 
     optimizer.step()

From afcd24f9bcc75207be05488670ab7c18605cf21b Mon Sep 17 00:00:00 2001
From: Wenhao Chen
Date: Thu, 11 Jan 2024 17:52:40 +0800
Subject: [PATCH 3/6] revert: revert p2p

---
 colossalai/pipeline/p2p.py | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/colossalai/pipeline/p2p.py b/colossalai/pipeline/p2p.py
index 8d27e77b7eb5..5588aa5789a9 100644
--- a/colossalai/pipeline/p2p.py
+++ b/colossalai/pipeline/p2p.py
@@ -162,10 +162,7 @@ def create_send_metadata(
         strict (bool, optional): whether to check if the object is supported for fast send
         return_tensor (bool, optional): whether to return tensor objects
     """
-    filtered_object = (
-        {key: value for key, value in object.items() if value is not None} if isinstance(object, dict) else object
-    )
-    objs, tree_spec = tree_flatten(filtered_object)
+    objs, tree_spec = tree_flatten(object)
     tensor_metadata, tensor_objs = [], []
     non_tensor_obj_idx, non_tensor_objs = [], []
     for idx, obj in enumerate(objs):
""" def __init__( @@ -956,6 +957,7 @@ def __init__( custom_policy: Policy = None, pp_style: str = "1f1b", num_model_chunks: int = 1, + enable_metadata_cache: bool = True, ) -> None: super().__init__() assert ( @@ -1002,10 +1004,14 @@ def __init__( num_model_chunks=num_model_chunks, num_microbatch=num_microbatches, microbatch_size=microbatch_size, + enable_metadata_cache=enable_metadata_cache, ) elif pp_style == "1f1b": self.schedule = OneForwardOneBackwardSchedule( - self.stage_manager, num_microbatches=num_microbatches, microbatch_size=microbatch_size + stage_manager=self.stage_manager, + num_microbatches=num_microbatches, + microbatch_size=microbatch_size, + enable_metadata_cache=enable_metadata_cache, ) else: raise NotImplementedError() diff --git a/tests/test_shardformer/test_model/test_shard_whisper.py b/tests/test_shardformer/test_model/test_shard_whisper.py index f839bd84ab69..6efb8a922f85 100644 --- a/tests/test_shardformer/test_model/test_shard_whisper.py +++ b/tests/test_shardformer/test_model/test_shard_whisper.py @@ -114,6 +114,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "tp_size": 2, "pp_size": 2, "num_microbatches": 2, + "enable_metadata_cache": False, "enable_all_optimization": True, "use_lazy_init": True, "precision": "fp32", @@ -123,6 +124,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "tp_size": 1, "pp_size": 2, "num_microbatches": 4, + "enable_metadata_cache": False, "use_lazy_init": False, "precision": "fp32", "initial_scale": 1, @@ -138,6 +140,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "tp_size": 1, "pp_size": 4, "num_microbatches": 4, + "enable_metadata_cache": False, "use_lazy_init": False, "precision": "fp32", }, @@ -163,6 +166,7 @@ def run_whisper_test(test_config): "tp_size": 2, "pp_size": 2, "num_microbatches": 4, + "enable_metadata_cache": False, "enable_all_optimization": False, "use_lazy_init": False, "precision": "fp32", @@ -172,6 +176,7 @@ def run_whisper_test(test_config): "tp_size": 2, "pp_size": 2, "num_microbatches": 2, + "enable_metadata_cache": False, "enable_all_optimization": False, "use_lazy_init": False, "precision": "fp32", From 2aa9d596bdc368d1849ec6b020dd3015e099c2b8 Mon Sep 17 00:00:00 2001 From: Wenhao Chen Date: Thu, 11 Jan 2024 18:14:25 +0800 Subject: [PATCH 5/6] revert: enable t5 tests --- .../test_model/test_shard_t5.py | 78 ++++++++++--------- 1 file changed, 41 insertions(+), 37 deletions(-) diff --git a/tests/test_shardformer/test_model/test_shard_t5.py b/tests/test_shardformer/test_model/test_shard_t5.py index 31e5e543b5e9..22c201458ad4 100644 --- a/tests/test_shardformer/test_model/test_shard_t5.py +++ b/tests/test_shardformer/test_model/test_shard_t5.py @@ -79,27 +79,28 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, torch.cuda.empty_cache() -# TODO t5 pipeline parallelism should be fixed, can't send non-tensor data @parameterize( "test_config", [ - # { - # "tp_size": 2, - # "pp_size": 2, - # "num_microbatches": 2, - # "enable_all_optimization": True, - # "use_lazy_init": True, - # "precision": "fp16", - # "initial_scale": 1, - # }, - # { - # "tp_size": 1, - # "pp_size": 2, - # "num_microbatches": 4, - # "use_lazy_init": False, - # "precision": "fp16", - # "initial_scale": 1, - # }, + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 2, + "enable_metadata_cache": False, + "enable_all_optimization": True, + "use_lazy_init": True, + "precision": "fp16", + "initial_scale": 1, + 

From 2aa9d596bdc368d1849ec6b020dd3015e099c2b8 Mon Sep 17 00:00:00 2001
From: Wenhao Chen
Date: Thu, 11 Jan 2024 18:14:25 +0800
Subject: [PATCH 5/6] revert: enable t5 tests

---
 .../test_model/test_shard_t5.py | 78 ++++++++++---------
 1 file changed, 41 insertions(+), 37 deletions(-)

diff --git a/tests/test_shardformer/test_model/test_shard_t5.py b/tests/test_shardformer/test_model/test_shard_t5.py
index 31e5e543b5e9..22c201458ad4 100644
--- a/tests/test_shardformer/test_model/test_shard_t5.py
+++ b/tests/test_shardformer/test_model/test_shard_t5.py
@@ -79,27 +79,28 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     torch.cuda.empty_cache()
 
 
-# TODO t5 pipeline parallelism should be fixed, can't send non-tensor data
 @parameterize(
     "test_config",
     [
-        # {
-        #     "tp_size": 2,
-        #     "pp_size": 2,
-        #     "num_microbatches": 2,
-        #     "enable_all_optimization": True,
-        #     "use_lazy_init": True,
-        #     "precision": "fp16",
-        #     "initial_scale": 1,
-        # },
-        # {
-        #     "tp_size": 1,
-        #     "pp_size": 2,
-        #     "num_microbatches": 4,
-        #     "use_lazy_init": False,
-        #     "precision": "fp16",
-        #     "initial_scale": 1,
-        # },
+        {
+            "tp_size": 2,
+            "pp_size": 2,
+            "num_microbatches": 2,
+            "enable_metadata_cache": False,
+            "enable_all_optimization": True,
+            "use_lazy_init": True,
+            "precision": "fp16",
+            "initial_scale": 1,
+        },
+        {
+            "tp_size": 1,
+            "pp_size": 2,
+            "num_microbatches": 4,
+            "enable_metadata_cache": False,
+            "use_lazy_init": False,
+            "precision": "fp16",
+            "initial_scale": 1,
+        },
         {
             "tp_size": 4,
             "pp_size": 1,
@@ -107,14 +108,15 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "use_lazy_init": False,
             "precision": "fp32",
         },
-        # {
-        #     "tp_size": 1,
-        #     "pp_size": 4,
-        #     "num_microbatches": 4,
-        #     "enable_all_optimization": False,
-        #     "use_lazy_init": False,
-        #     "precision": "fp32",
-        # },
+        {
+            "tp_size": 1,
+            "pp_size": 4,
+            "num_microbatches": 4,
+            "enable_metadata_cache": False,
+            "enable_all_optimization": False,
+            "use_lazy_init": False,
+            "precision": "fp32",
+        },
         {"tp_size": 2, "pp_size": 1, "enable_all_optimization": True, "use_lazy_init": False, "precision": "fp32"},
         {
             "tp_size": 2,
@@ -125,16 +127,17 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "precision": "fp16",
             "initial_scale": 1,
         },
-        # {
-        #     "tp_size": 1,
-        #     "pp_size": 2,
-        #     "num_microbatches": 2,
-        #     "enable_all_optimization": True,
-        #     "use_lazy_init": True,
-        #     "zero_stage": 1,
-        #     "precision": "fp16",
-        #     "initial_scale": 1,
-        # },
+        {
+            "tp_size": 1,
+            "pp_size": 2,
+            "num_microbatches": 2,
+            "enable_metadata_cache": False,
+            "enable_all_optimization": True,
+            "use_lazy_init": True,
+            "zero_stage": 1,
+            "precision": "fp16",
+            "initial_scale": 1,
+        },
     ],
 )
 @clear_cache_before_run()
@@ -160,6 +163,7 @@ def run_t5_test(test_config):
             "tp_size": 2,
             "pp_size": 2,
             "num_microbatches": 4,
+            "enable_metadata_cache": False,
             "enable_all_optimization": False,
             "use_lazy_init": False,
             "precision": "fp32",
@@ -169,6 +173,7 @@ def run_t5_test(test_config):
             "tp_size": 2,
             "pp_size": 2,
             "num_microbatches": 4,
+            "enable_metadata_cache": False,
             "enable_all_optimization": False,
             "use_lazy_init": False,
             "precision": "fp16",
@@ -209,7 +214,6 @@ def test_t5():
 @pytest.mark.largedist
 @rerun_if_address_is_in_use()
 @clear_cache_before_run()
-@pytest.mark.skip(reason="t5 pipeline parallelism should be fixed, can't send non-tensor data")
 def test_t5_3d():
     spawn(check_t5_3d, 8)
 

From e25207d41bd09c47a3386db0195f6020c3eefc8e Mon Sep 17 00:00:00 2001
From: flybird11111 <1829166702@qq.com>
Date: Wed, 17 Jan 2024 10:46:00 +0800
Subject: [PATCH 6/6] fix

---
 .../test_hybrid_parallel_plugin_checkpoint_io.py | 11 +++--------
 1 file changed, 3 insertions(+), 8 deletions(-)

diff --git a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py
index 1bba09c003fe..39f9aaf7140e 100644
--- a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py
+++ b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py
@@ -40,7 +40,7 @@
 
 @clear_cache_before_run()
 @parameterize("shard", [True, False])
-@parameterize("model_name", ["transformers_gpt"])
+@parameterize("model_name", ["transformers_llama_for_casual_lm"])
 @parameterize("size_per_shard", [32])
 @parameterize("test_config", TEST_CONFIGS)
 def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_config: dict):
@@ -128,13 +128,8 @@ def _preprocess_data(data):
         new_optimizer.step()
 
     # Check updated weights.
-    stage_manager = booster.plugin.stage_manager
-
-    if stage_manager is None or stage_manager.is_first_stage():
-        assert_close_loose(model.unwrap().wte.weight.data, new_model.unwrap().wte.weight.data, atol=5e-3, rtol=5e-3)
-        assert_close_loose(
-            model.unwrap().h[0].mlp.c_fc.weight.data, new_model.unwrap().h[0].mlp.c_fc.weight.data, atol=5e-3, rtol=5e-3
-        )
+    for p1, p2 in zip(model.unwrap().parameters(), new_model.unwrap().parameters()):
+        assert_close_loose(p1, p2, atol=5e-3, rtol=5e-3)
 
     dist.barrier()
     Randomizer.reset_index()
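
[Illustrative note, not part of the patches] The assertions removed in the final hunk referenced GPT-2-specific modules (wte, h[0].mlp.c_fc) and were guarded to run only on the first pipeline stage, so the other stages were never checked; they also break once model_name switches to the transformers_llama_for_casual_lm registry entry. Zipping over parameters() compares exactly the shards each rank holds and works for any model. A self-contained sketch of the same pattern:

    import torch
    import torch.nn as nn

    def assert_params_close(m1: nn.Module, m2: nn.Module, atol: float = 5e-3, rtol: float = 5e-3) -> None:
        # Each rank compares only the parameters it actually holds, so this runs
        # meaningfully on every pipeline stage, not just the stage that owns the
        # embedding layer.
        for p1, p2 in zip(m1.parameters(), m2.parameters()):
            torch.testing.assert_close(p1, p2, atol=atol, rtol=rtol)

    m1, m2 = nn.Linear(4, 4), nn.Linear(4, 4)
    m2.load_state_dict(m1.state_dict())
    assert_params_close(m1, m2)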