From 1b0c28a8d76ace8bd895bc74b70dfd1ac81a26b7 Mon Sep 17 00:00:00 2001
From: flybird11111 <1829166702@qq.com>
Date: Thu, 11 Jan 2024 19:07:45 +0800
Subject: [PATCH 1/2] [ci] fix shardformer tests. (#5255)

* fix ci

fix

* revert: revert p2p

* feat: add enable_metadata_cache option

* revert: enable t5 tests

---------

Co-authored-by: Wenhao Chen
---
 colossalai/booster/plugin/hybrid_parallel_plugin.py     | 8 +++++++-
 tests/test_shardformer/test_model/test_shard_gpt2.py    | 4 ++--
 tests/test_shardformer/test_model/test_shard_t5.py      | 6 ++++++
 tests/test_shardformer/test_model/test_shard_whisper.py | 5 +++++
 4 files changed, 20 insertions(+), 3 deletions(-)

diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index 205660f946e9..8ee1e97c6ce3 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -919,6 +919,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         custom_policy (Policy, optional): Custom policy for Shardformer. Defaults to None.
         pp_style (str, optional): The style for pipeline parallelism. Defaults to '1f1b'.
         num_model_chunks (int, optional): The number of model chunks for interleaved pipeline parallelism. Defaults to 1.
+        enable_metadata_cache (bool, optional): Whether to enable metadata cache for pipeline parallelism. Defaults to True.
     """

     def __init__(
@@ -956,6 +957,7 @@ def __init__(
         custom_policy: Policy = None,
         pp_style: str = "1f1b",
         num_model_chunks: int = 1,
+        enable_metadata_cache: bool = True,
     ) -> None:
         super().__init__()
         assert (
@@ -1002,10 +1004,14 @@ def __init__(
                 num_model_chunks=num_model_chunks,
                 num_microbatch=num_microbatches,
                 microbatch_size=microbatch_size,
+                enable_metadata_cache=enable_metadata_cache,
             )
         elif pp_style == "1f1b":
             self.schedule = OneForwardOneBackwardSchedule(
-                self.stage_manager, num_microbatches=num_microbatches, microbatch_size=microbatch_size
+                stage_manager=self.stage_manager,
+                num_microbatches=num_microbatches,
+                microbatch_size=microbatch_size,
+                enable_metadata_cache=enable_metadata_cache,
             )
         else:
             raise NotImplementedError()
diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py
index 66b30641acc8..3155420f1cf2 100644
--- a/tests/test_shardformer/test_model/test_shard_gpt2.py
+++ b/tests/test_shardformer/test_model/test_shard_gpt2.py
@@ -165,7 +165,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
 )
 @clear_cache_before_run()
 def run_gpt2_test(test_config):
-    sub_model_zoo = model_zoo.get_sub_registry("transformers_gpt")
+    sub_model_zoo = model_zoo.get_sub_registry("transformers_gpt", exclude="transformers_gptj")

     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
         check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
@@ -200,7 +200,7 @@ def run_gpt2_test(test_config):
 )
 @clear_cache_before_run()
 def run_gpt2_3d_test(test_config):
-    sub_model_zoo = model_zoo.get_sub_registry("transformers_gpt")
+    sub_model_zoo = model_zoo.get_sub_registry("transformers_gpt", exclude="transformers_gptj")

     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
         check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
diff --git a/tests/test_shardformer/test_model/test_shard_t5.py b/tests/test_shardformer/test_model/test_shard_t5.py
index 73f203d1f023..22c201458ad4 100644
--- a/tests/test_shardformer/test_model/test_shard_t5.py
+++ b/tests/test_shardformer/test_model/test_shard_t5.py
@@ -86,6 +86,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "tp_size": 2,
             "pp_size": 2,
             "num_microbatches": 2,
+            "enable_metadata_cache": False,
             "enable_all_optimization": True,
             "use_lazy_init": True,
             "precision": "fp16",
@@ -95,6 +96,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "tp_size": 1,
             "pp_size": 2,
             "num_microbatches": 4,
+            "enable_metadata_cache": False,
             "use_lazy_init": False,
             "precision": "fp16",
             "initial_scale": 1,
@@ -110,6 +112,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "tp_size": 1,
             "pp_size": 4,
             "num_microbatches": 4,
+            "enable_metadata_cache": False,
             "enable_all_optimization": False,
             "use_lazy_init": False,
             "precision": "fp32",
@@ -128,6 +131,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "tp_size": 1,
             "pp_size": 2,
             "num_microbatches": 2,
+            "enable_metadata_cache": False,
             "enable_all_optimization": True,
             "use_lazy_init": True,
             "zero_stage": 1,
@@ -159,6 +163,7 @@ def run_t5_test(test_config):
             "tp_size": 2,
             "pp_size": 2,
             "num_microbatches": 4,
+            "enable_metadata_cache": False,
             "enable_all_optimization": False,
             "use_lazy_init": False,
             "precision": "fp32",
@@ -168,6 +173,7 @@ def run_t5_test(test_config):
             "tp_size": 2,
             "pp_size": 2,
             "num_microbatches": 4,
+            "enable_metadata_cache": False,
             "enable_all_optimization": False,
             "use_lazy_init": False,
             "precision": "fp16",
diff --git a/tests/test_shardformer/test_model/test_shard_whisper.py b/tests/test_shardformer/test_model/test_shard_whisper.py
index f839bd84ab69..6efb8a922f85 100644
--- a/tests/test_shardformer/test_model/test_shard_whisper.py
+++ b/tests/test_shardformer/test_model/test_shard_whisper.py
@@ -114,6 +114,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "tp_size": 2,
             "pp_size": 2,
             "num_microbatches": 2,
+            "enable_metadata_cache": False,
             "enable_all_optimization": True,
             "use_lazy_init": True,
             "precision": "fp32",
@@ -123,6 +124,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "tp_size": 1,
             "pp_size": 2,
             "num_microbatches": 4,
+            "enable_metadata_cache": False,
             "use_lazy_init": False,
             "precision": "fp32",
             "initial_scale": 1,
@@ -138,6 +140,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "tp_size": 1,
             "pp_size": 4,
             "num_microbatches": 4,
+            "enable_metadata_cache": False,
             "use_lazy_init": False,
             "precision": "fp32",
         },
@@ -163,6 +166,7 @@ def run_whisper_test(test_config):
             "tp_size": 2,
             "pp_size": 2,
             "num_microbatches": 4,
+            "enable_metadata_cache": False,
             "enable_all_optimization": False,
             "use_lazy_init": False,
             "precision": "fp32",
@@ -172,6 +176,7 @@ def run_whisper_test(test_config):
             "tp_size": 2,
             "pp_size": 2,
             "num_microbatches": 2,
+            "enable_metadata_cache": False,
             "enable_all_optimization": False,
             "use_lazy_init": False,
             "precision": "fp32",

From 7b38faa51e41c2a26675eadb4e815ecf85b688ca Mon Sep 17 00:00:00 2001
From: flybird11111 <1829166702@qq.com>
Date: Mon, 29 Jan 2024 16:04:56 +0800
Subject: [PATCH 2/2] fix t5 test

---
 colossalai/shardformer/modeling/t5.py | 25 +++++++++++++------------
 1 file changed, 13 insertions(+), 12 deletions(-)

diff --git a/colossalai/shardformer/modeling/t5.py b/colossalai/shardformer/modeling/t5.py
index f67aa84e4e72..dcb1785207eb 100644
--- a/colossalai/shardformer/modeling/t5.py
+++ b/colossalai/shardformer/modeling/t5.py
@@ -593,10 +593,6 @@ def t5_encoder_model_forward(


 def get_t5_flash_attention_forward():
-    try:
-        from xformers.ops import memory_efficient_attention as me_attention
-    except:
-        raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.")
     from transformers.models.t5.modeling_t5 import T5Attention

     def forward(
@@ -632,11 +628,11 @@ def forward(

         def shape(states):
             """projection"""
-            return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim)
+            return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)

         def unshape(states):
             """reshape"""
-            return states.view(batch_size, -1, self.inner_dim)
+            return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)

         def project(hidden_states, proj_layer, key_value_states, past_key_value):
             """projects hidden states correctly to key/query states"""
@@ -653,8 +649,8 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value):
                 if key_value_states is None:
                     # self-attn
                     # (batch_size, n_heads, key_length, dim_per_head)
-                    hidden_states = torch.cat([past_key_value, hidden_states], dim=1)
-                elif past_key_value.shape[1] != key_value_states.shape[1]:
+                    hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
+                elif past_key_value.shape[2] != key_value_states.shape[1]:
                     # checking that the `sequence_length` of the `past_key_value` is the same as
                     # the provided `key_value_states` to support prefix tuning
                     # cross-attn
@@ -701,10 +697,15 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value):
         else:
             position_bias_masked = position_bias

-        position_bias_masked = position_bias_masked.contiguous()
-        attn_output = me_attention(
-            query_states, key_states, value_states, attn_bias=position_bias_masked, p=self.dropout, scale=1.0
-        )
+        with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_mem_efficient=True):
+            attn_output = torch.nn.functional.scaled_dot_product_attention(
+                query_states,
+                key_states,
+                value_states,
+                attn_mask=position_bias_masked,
+                dropout_p=self.dropout,
+                scale=1.0,
+            )
         attn_output = unshape(attn_output)
         attn_output = self.o(attn_output)
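
For reference, the attention pattern that patch 2 switches to can be exercised in isolation with a minimal sketch like the one below. It is illustrative only, not part of the patches: the tensor names and sizes are made up, and it assumes a PyTorch version recent enough to accept the scale keyword of scaled_dot_product_attention (2.1+) and the (batch, heads, seq, head_dim) layout produced by the updated shape() helper.

import torch
import torch.nn.functional as F

# Hypothetical sizes for illustration only.
batch, n_heads, seq_len, head_dim = 2, 8, 16, 64
query = torch.randn(batch, n_heads, seq_len, head_dim)
key = torch.randn(batch, n_heads, seq_len, head_dim)
value = torch.randn(batch, n_heads, seq_len, head_dim)

# T5-style additive position bias, broadcast over the batch dimension.
position_bias = torch.randn(1, n_heads, seq_len, seq_len)

attn_output = F.scaled_dot_product_attention(
    query,
    key,
    value,
    attn_mask=position_bias,  # additive float bias, as with xformers' attn_bias
    dropout_p=0.0,
    scale=1.0,  # T5 does not rescale scores by 1/sqrt(head_dim)
)
assert attn_output.shape == (batch, n_heads, seq_len, head_dim)

Passing the position bias as a float attn_mask reproduces what xformers' attn_bias argument did, and scale=1.0 preserves T5's convention of not rescaling attention scores, so the built-in kernel can stand in for memory_efficient_attention without an extra dependency.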