From 677e3be9e3c17ad4de6783f30c0646d116ea4da5 Mon Sep 17 00:00:00 2001 From: ver217 Date: Tue, 11 Jun 2024 17:47:05 +0800 Subject: [PATCH 1/3] [shardformer] fix modeling of gpt2 and gptj --- colossalai/shardformer/modeling/gpt2.py | 5 ++++- colossalai/shardformer/modeling/gptj.py | 20 +++++++++++++++----- colossalai/shardformer/policies/gptj.py | 8 ++------ 3 files changed, 21 insertions(+), 12 deletions(-) diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index c49458dbdf55..aa75bab115a7 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -738,7 +738,10 @@ def gpt2_for_sequence_classification_forward( sequence_lengths = -1 else: if input_ids is not None: - sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device) + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) else: sequence_lengths = -1 logger.warning_once( diff --git a/colossalai/shardformer/modeling/gptj.py b/colossalai/shardformer/modeling/gptj.py index 4f4cec8bc81f..facd2fcafbae 100644 --- a/colossalai/shardformer/modeling/gptj.py +++ b/colossalai/shardformer/modeling/gptj.py @@ -32,6 +32,7 @@ def _get_attention_mask( hidden_states: torch.Tensor, past_key_values: Optional[Tuple[Tuple[torch.Tensor]]], attention_mask: Optional[torch.FloatTensor], + use_flash_attention_2: bool = False, ) -> Optional[Union[torch.Tensor, dict]]: batch_size, seq_len = hidden_states.shape[:2] past_key_values_length = 0 @@ -47,7 +48,7 @@ def _get_attention_mask( attention_mask, is_causal=True, ) - elif attention_mask is not None: + elif use_flash_attention_2 and attention_mask is not None: if batch_size <= 0: raise ValueError("batch_size has to be defined and > 0") attention_mask = attention_mask.view(batch_size, -1) @@ -162,7 +163,9 @@ def gptj_model_forward( output_shape = input_shape + (hidden_states.size(-1),) - attention_mask = _get_attention_mask(self, shard_config, hidden_states, past_key_values, attention_mask) + attention_mask = _get_attention_mask( + self, shard_config, hidden_states, past_key_values, attention_mask, self._use_flash_attention_2 + ) if self.gradient_checkpointing and self.training: if use_cache: @@ -419,7 +422,10 @@ def gptj_for_sequence_classification_forward( sequence_lengths = -1 else: if input_ids is not None: - sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device) + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) else: sequence_lengths = -1 logger.warning_once( @@ -712,7 +718,9 @@ def forward( hidden_states = self.drop(hidden_states) - attention_mask = _get_attention_mask(self, shard_config, hidden_states, past_key_values, attention_mask) + attention_mask = _get_attention_mask( + self, shard_config, hidden_states, past_key_values, attention_mask, self._use_flash_attention_2 + ) output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),) @@ -886,7 +894,9 @@ def forward( hidden_states = self.drop(hidden_states) output_shape = input_shape + (hidden_states.size(-1),) - 
attention_mask = _get_attention_mask(self, shard_config, hidden_states, past_key_values, attention_mask) + attention_mask = _get_attention_mask( + self, shard_config, hidden_states, past_key_values, attention_mask, self._use_flash_attention_2 + ) if self.gradient_checkpointing and self.training: if use_cache: diff --git a/colossalai/shardformer/policies/gptj.py b/colossalai/shardformer/policies/gptj.py index 3315eb1e9256..c394d911e289 100644 --- a/colossalai/shardformer/policies/gptj.py +++ b/colossalai/shardformer/policies/gptj.py @@ -34,15 +34,11 @@ def preprocess(self): return self.model def module_policy(self): - from transformers.models.gptj.modeling_gptj import GPTJAttention, GPTJBlock, GPTJModel - - ATTN_IMPLEMENTATION = { - "eager": GPTJAttention, - } + from transformers.models.gptj.modeling_gptj import GPTJ_ATTENTION_CLASSES, GPTJBlock, GPTJModel policy = {} - attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement] + attn_cls = GPTJ_ATTENTION_CLASSES[self.origin_attn_implement] embedding_cls = None if self.shard_config.enable_tensor_parallelism: From 9a8d619c83fbefd53156d9245ca2052c5715a597 Mon Sep 17 00:00:00 2001 From: ver217 Date: Wed, 12 Jun 2024 11:56:18 +0800 Subject: [PATCH 2/3] [shardformer] fix whisper modeling --- colossalai/shardformer/modeling/whisper.py | 32 ++++++++++++++++++---- 1 file changed, 27 insertions(+), 5 deletions(-) diff --git a/colossalai/shardformer/modeling/whisper.py b/colossalai/shardformer/modeling/whisper.py index 6d7df963a3a0..cf925983be4e 100644 --- a/colossalai/shardformer/modeling/whisper.py +++ b/colossalai/shardformer/modeling/whisper.py @@ -17,6 +17,7 @@ SequenceClassifierOutput, ) from transformers.models.whisper.modeling_whisper import ( + _HIDDEN_STATES_START_POSITION, WhisperDecoder, WhisperEncoder, WhisperForAudioClassification, @@ -166,6 +167,7 @@ def forward( cross_attn_head_mask=None, past_key_values=None, inputs_embeds=None, + position_ids=None, use_cache=None, output_attentions=None, output_hidden_states=None, @@ -199,9 +201,13 @@ def forward( # embed positions if input_ids is not None: - positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length) + positions = self.embed_positions( + input_ids, past_key_values_length=past_key_values_length, position_ids=position_ids + ) else: - positions = self.embed_positions(inputs_embeds, past_key_values_length=past_key_values_length) + positions = self.embed_positions( + inputs_embeds, past_key_values_length=past_key_values_length, position_ids=position_ids + ) hidden_states = inputs_embeds + positions hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) @@ -599,6 +605,7 @@ def whisper_decoder_forward( cross_attn_head_mask=None, past_key_values=None, inputs_embeds=None, + position_ids=None, use_cache=None, output_attentions=None, output_hidden_states=None, @@ -716,9 +723,13 @@ def whisper_decoder_forward( # embed positions if input_ids is not None: - positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length) + positions = self.embed_positions( + input_ids, past_key_values_length=past_key_values_length, position_ids=position_ids + ) else: - positions = self.embed_positions(inputs_embeds, past_key_values_length=past_key_values_length) + positions = self.embed_positions( + inputs_embeds, past_key_values_length=past_key_values_length, position_ids=position_ids + ) hidden_states = inputs_embeds + positions hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, 
training=self.training) @@ -841,6 +852,7 @@ def whisper_model_forward( encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None, + decoder_position_ids: Optional[Tuple[torch.LongTensor]] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, @@ -944,6 +956,7 @@ def whisper_model_forward( cross_attn_head_mask=cross_attn_head_mask, past_key_values=past_key_values, inputs_embeds=decoder_inputs_embeds, + position_ids=decoder_position_ids, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, @@ -986,6 +999,7 @@ def whisper_for_conditional_generation_forward( encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None, + decoder_position_ids: Optional[Tuple[torch.LongTensor]] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, @@ -1048,6 +1062,7 @@ def whisper_for_conditional_generation_forward( cross_attn_head_mask=cross_attn_head_mask, past_key_values=past_key_values, decoder_inputs_embeds=decoder_inputs_embeds, + decoder_position_ids=decoder_position_ids, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, @@ -1118,6 +1133,12 @@ def whisper_for_audio_classification_forward( output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) + + if self.config.use_weighted_layer_sum: + output_hidden_states = True + elif output_hidden_states is None: + output_hidden_states = self.config.output_hidden_states + return_dict = return_dict if return_dict is not None else self.config.use_return_dict # audio_classification only holds encoder @@ -1138,7 +1159,8 @@ def whisper_for_audio_classification_forward( return encoder_outputs if self.config.use_weighted_layer_sum: - hidden_states = torch.stack(encoder_outputs, dim=1) + hidden_states = encoder_outputs[_HIDDEN_STATES_START_POSITION] + hidden_states = torch.stack(hidden_states, dim=1) norm_weights = nn.functional.softmax(self.layer_weights, dim=-1) hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1) else: From f6ba66b237b4f4fd944a221bcbfa750212a1b154 Mon Sep 17 00:00:00 2001 From: ver217 Date: Wed, 12 Jun 2024 14:04:11 +0800 Subject: [PATCH 3/3] [misc] update requirements --- requirements/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/requirements.txt b/requirements/requirements.txt index fa88501ef968..27bbc3769448 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -16,7 +16,7 @@ ray sentencepiece google protobuf -transformers>=4.36.2,<4.40.0 +transformers==4.39.3 peft>=0.7.1 bitsandbytes>=0.39.0 rpyc==6.0.0
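
For reference, a minimal standalone sketch of the ONNX-compatible sequence-length computation that the GPT-2/GPT-J change above switches to. The pad_token_id and input_ids below are illustrative, not taken from the patch; the point is that torch.eq(...).argmax(-1) finds the first pad position in each row, and the modulo keeps the index in range when a row contains no pad token, instead of the negative indexing that ONNX export cannot trace:

import torch

pad_token_id = 0  # illustrative; the real value comes from self.config.pad_token_id
input_ids = torch.tensor(
    [
        [5, 6, 7, 0, 0],  # padded after three real tokens
        [5, 6, 7, 8, 9],  # no padding at all
    ]
)

# index of the first pad token, minus one -> index of the last real token
sequence_lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
# argmax returns 0 for the unpadded row, so it becomes -1 above; the modulo
# maps -1 to the last position (4) instead of relying on reverse indexing
sequence_lengths = sequence_lengths % input_ids.shape[-1]
print(sequence_lengths)  # tensor([2, 4])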
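
The whisper change above threads a new decoder_position_ids argument down to self.embed_positions. As a hedged sketch of why that matters (the class and signature below are a rough stand-in modeled on transformers 4.39's WhisperPositionalEmbedding, not the upstream code; verify against your installed version), the lookup falls back to a contiguous slice starting at past_key_values_length when position_ids is absent, and uses the explicit positions when given:

import torch
from torch import nn

class PositionalEmbeddingSketch(nn.Embedding):
    # rough stand-in for WhisperPositionalEmbedding, not the upstream class
    def forward(self, input_ids, past_key_values_length=0, position_ids=None):
        if position_ids is None:
            # default: consecutive positions continuing after the KV cache
            return self.weight[past_key_values_length : past_key_values_length + input_ids.shape[1]]
        # explicit positions, e.g. for decoding schemes that skip or revisit slots
        return self.weight[position_ids]

emb = PositionalEmbeddingSketch(num_embeddings=448, embedding_dim=16)
tokens = torch.ones(1, 3, dtype=torch.long)
print(emb(tokens, past_key_values_length=5).shape)              # torch.Size([3, 16])
print(emb(tokens, position_ids=torch.tensor([5, 6, 7])).shape)  # torch.Size([3, 16])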
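
The same whisper commit fixes whisper_for_audio_classification_forward to stack the tuple of per-layer hidden states (encoder_outputs[_HIDDEN_STATES_START_POSITION]) rather than the whole encoder output tuple, and to force output_hidden_states on when use_weighted_layer_sum is set — without the per-layer states the sum has nothing to weight. A self-contained sketch of the pooling itself, with illustrative shapes (4 layers, batch 2, 10 frames, hidden size 8):

import torch
from torch import nn

num_layers, batch, frames, hidden = 4, 2, 10, 8
# stand-in for the tuple of per-layer encoder hidden states
hidden_states = tuple(torch.randn(batch, frames, hidden) for _ in range(num_layers))
layer_weights = torch.zeros(num_layers)  # a learned parameter in the real model

stacked = torch.stack(hidden_states, dim=1)                  # (batch, layers, frames, hidden)
norm_weights = nn.functional.softmax(layer_weights, dim=-1)  # uniform here: 1/4 each
pooled = (stacked * norm_weights.view(-1, 1, 1)).sum(dim=1)  # (batch, frames, hidden)
print(pooled.shape)  # torch.Size([2, 10, 8])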