From 86650744eb9cf31ee3baeeaccd83c0bcd7311837 Mon Sep 17 00:00:00 2001
From: ver217
Date: Tue, 23 Apr 2024 15:19:40 +0800
Subject: [PATCH] [shardformer] fix whisper

---
 colossalai/shardformer/modeling/whisper.py   | 61 ++++++-------------
 tests/kit/model_zoo/transformers/whisper.py  |  1 +
 .../test_model/test_shard_whisper.py         |  2 +-
 3 files changed, 21 insertions(+), 43 deletions(-)

diff --git a/colossalai/shardformer/modeling/whisper.py b/colossalai/shardformer/modeling/whisper.py
index ae1772a66848..6d7df963a3a0 100644
--- a/colossalai/shardformer/modeling/whisper.py
+++ b/colossalai/shardformer/modeling/whisper.py
@@ -39,6 +39,8 @@ def _get_attention_mask(
     hidden_states: torch.Tensor,
     past_key_values_length: int,
     attention_mask: Optional[torch.FloatTensor],
+    head_mask: Optional[torch.Tensor] = None,
+    output_attentions: bool = False,
 ):
     batch_size, seq_length = hidden_states.shape[:2]
     mask_seq_length = past_key_values_length + seq_length
@@ -51,12 +53,20 @@ def _get_attention_mask(
             is_causal=True,
         )
     else:
-        attention_mask = self._prepare_decoder_attention_mask(
-            attention_mask,
-            (batch_size, seq_length),
-            hidden_states,
-            past_key_values_length,
-        )
+        input_shape = (batch_size, seq_length)
+        if self._use_flash_attention_2:
+            # 2d mask is passed through the layers
+            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+        elif self._use_sdpa and head_mask is None and not output_attentions:
+            # output_attentions=True & head_mask can not be supported when using SDPA.
+            attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+                attention_mask, input_shape, hidden_states, past_key_values_length
+            )
+        else:
+            # 4d mask is passed through the layers
+            attention_mask = _prepare_4d_causal_attention_mask(
+                attention_mask, input_shape, hidden_states, past_key_values_length
+            )
     return attention_mask
 
 
@@ -700,33 +710,9 @@ def whisper_decoder_forward(
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
 
-        if self._use_flash_attention_2:
-            # 2d mask is passed through the layers
-            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
-        elif self._use_sdpa and head_mask is None and not output_attentions:
-            # output_attentions=True & head_mask can not be supported when using SDPA.
-            attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
-                attention_mask, input_shape, inputs_embeds, past_key_values_length
-            )
-        else:
-            # 4d mask is passed through the layers
-            attention_mask = _prepare_4d_causal_attention_mask(
-                attention_mask, input_shape, inputs_embeds, past_key_values_length
-            )
-
-        if self._use_flash_attention_2:
-            # 2d mask is passed through the layers
-            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
-        elif self._use_sdpa and head_mask is None and not output_attentions:
-            # output_attentions=True & head_mask can not be supported when using SDPA.
-            attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
-                attention_mask, input_shape, inputs_embeds, past_key_values_length
-            )
-        else:
-            # 4d mask is passed through the layers
-            attention_mask = _prepare_4d_causal_attention_mask(
-                attention_mask, input_shape, inputs_embeds, past_key_values_length
-            )
+        attention_mask = _get_attention_mask(
+            self, shard_config, inputs_embeds, past_key_values_length, attention_mask
+        )
 
         # embed positions
         if input_ids is not None:
@@ -734,14 +720,6 @@ def whisper_decoder_forward(
             positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length)
         else:
             positions = self.embed_positions(inputs_embeds, past_key_values_length=past_key_values_length)
-        attention_mask = _get_attention_mask(
-            self,
-            shard_config,
-            inputs_embeds,
-            past_key_values_length,
-            attention_mask,
-        )
-
         hidden_states = inputs_embeds + positions
         hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
 
@@ -758,7 +736,6 @@ def whisper_decoder_forward(
                     "hidden_states shouldn't be None for stages other than the first stage of encoder/decoder."
                 )
             input_shape = hidden_states.size()[:-1]
-
             attention_mask = _get_attention_mask(
                 self,
                 shard_config,
diff --git a/tests/kit/model_zoo/transformers/whisper.py b/tests/kit/model_zoo/transformers/whisper.py
index d69bebe6cc04..0d9a581dfbe9 100644
--- a/tests/kit/model_zoo/transformers/whisper.py
+++ b/tests/kit/model_zoo/transformers/whisper.py
@@ -66,6 +66,7 @@ def data_gen_for_audio_classification():
     encoder_ffn_dim=1536,
     encoder_layers=2,
     vocab_size=51866,
+    _attn_implementation="eager",
 )
 
 # register the Whisper variants
diff --git a/tests/test_shardformer/test_model/test_shard_whisper.py b/tests/test_shardformer/test_model/test_shard_whisper.py
index 6efb8a922f85..af61e464014f 100644
--- a/tests/test_shardformer/test_model/test_shard_whisper.py
+++ b/tests/test_shardformer/test_model/test_shard_whisper.py
@@ -116,7 +116,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "num_microbatches": 2,
             "enable_metadata_cache": False,
             "enable_all_optimization": True,
-            "use_lazy_init": True,
+            "use_lazy_init": False,
             "precision": "fp32",
             "initial_scale": 1,
         },
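
A minimal standalone sketch of the mask-selection logic that the consolidated _get_attention_mask helper applies, assuming transformers >= 4.36 where the 4-D causal-mask helpers live in transformers.modeling_attn_mask_utils; select_decoder_attention_mask and attn_implementation are illustrative names standing in for the self._use_flash_attention_2 / self._use_sdpa flags read inside the real helper.

from typing import Optional

import torch
from transformers.modeling_attn_mask_utils import (
    _prepare_4d_causal_attention_mask,
    _prepare_4d_causal_attention_mask_for_sdpa,
)


def select_decoder_attention_mask(
    attention_mask: Optional[torch.Tensor],
    hidden_states: torch.Tensor,
    past_key_values_length: int,
    attn_implementation: str = "eager",  # illustrative stand-in for the module flags
    head_mask: Optional[torch.Tensor] = None,
    output_attentions: bool = False,
) -> Optional[torch.Tensor]:
    # Sketch only: mirrors the dispatch in the patched helper, not the ColossalAI code itself.
    batch_size, seq_length = hidden_states.shape[:2]
    input_shape = (batch_size, seq_length)
    if attn_implementation == "flash_attention_2":
        # Flash attention consumes the raw 2-D padding mask; drop it when nothing is masked.
        return attention_mask if (attention_mask is not None and 0 in attention_mask) else None
    if attn_implementation == "sdpa" and head_mask is None and not output_attentions:
        # SDPA cannot return attention weights or apply a head mask, so it is only used here.
        return _prepare_4d_causal_attention_mask_for_sdpa(
            attention_mask, input_shape, hidden_states, past_key_values_length
        )
    # Eager attention expects the full 4-D causal mask.
    return _prepare_4d_causal_attention_mask(
        attention_mask, input_shape, hidden_states, past_key_values_length
    )


if __name__ == "__main__":
    hidden = torch.zeros(2, 5, 16)
    pad_mask = torch.tensor([[1, 1, 1, 1, 1], [1, 1, 1, 0, 0]])
    mask = select_decoder_attention_mask(pad_mask, hidden, past_key_values_length=0)
    print(mask.shape)  # torch.Size([2, 1, 5, 5])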