Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 19 additions & 42 deletions colossalai/shardformer/modeling/whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ def _get_attention_mask(
hidden_states: torch.Tensor,
past_key_values_length: int,
attention_mask: Optional[torch.FloatTensor],
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
):
batch_size, seq_length = hidden_states.shape[:2]
mask_seq_length = past_key_values_length + seq_length
Expand All @@ -51,12 +53,20 @@ def _get_attention_mask(
is_causal=True,
)
else:
attention_mask = self._prepare_decoder_attention_mask(
attention_mask,
(batch_size, seq_length),
hidden_states,
past_key_values_length,
)
input_shape = (batch_size, seq_length)
if self._use_flash_attention_2:
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
elif self._use_sdpa and head_mask is None and not output_attentions:
# output_attentions=True & head_mask can not be supported when using SDPA.
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask, input_shape, hidden_states, past_key_values_length
)
else:
# 4d mask is passed through the layers
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, input_shape, hidden_states, past_key_values_length
)
return attention_mask


Expand Down Expand Up @@ -700,48 +710,16 @@ def whisper_decoder_forward(
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)

if self._use_flash_attention_2:
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
elif self._use_sdpa and head_mask is None and not output_attentions:
# output_attentions=True & head_mask can not be supported when using SDPA.
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask, input_shape, inputs_embeds, past_key_values_length
)
else:
# 4d mask is passed through the layers
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, input_shape, inputs_embeds, past_key_values_length
)

if self._use_flash_attention_2:
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
elif self._use_sdpa and head_mask is None and not output_attentions:
# output_attentions=True & head_mask can not be supported when using SDPA.
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask, input_shape, inputs_embeds, past_key_values_length
)
else:
# 4d mask is passed through the layers
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, input_shape, inputs_embeds, past_key_values_length
)
attention_mask = _get_attention_mask(
self, shard_config, inputs_embeds, past_key_values_length, attention_mask
)

# embed positions
if input_ids is not None:
positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length)
else:
positions = self.embed_positions(inputs_embeds, past_key_values_length=past_key_values_length)

attention_mask = _get_attention_mask(
self,
shard_config,
inputs_embeds,
past_key_values_length,
attention_mask,
)

hidden_states = inputs_embeds + positions
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

Expand All @@ -758,7 +736,6 @@ def whisper_decoder_forward(
"hidden_states shouldn't be None for stages other than the first stage of encoder/decoder."
)
input_shape = hidden_states.size()[:-1]

attention_mask = _get_attention_mask(
self,
shard_config,
Expand Down
1 change: 1 addition & 0 deletions tests/kit/model_zoo/transformers/whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def data_gen_for_audio_classification():
encoder_ffn_dim=1536,
encoder_layers=2,
vocab_size=51866,
_attn_implementation="eager",
)

# register the Whisper variants
Expand Down
2 changes: 1 addition & 1 deletion tests/test_shardformer/test_model/test_shard_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
"num_microbatches": 2,
"enable_metadata_cache": False,
"enable_all_optimization": True,
"use_lazy_init": True,
"use_lazy_init": False,
"precision": "fp32",
"initial_scale": 1,
},
Expand Down