From 739af907d6770838e348831d286b521eca693f54 Mon Sep 17 00:00:00 2001
From: Wang Binluo <2538539015@qq.com>
Date: Thu, 21 Mar 2024 12:01:45 +0800
Subject: [PATCH 1/6] flash_attention forward upgrade

---
 colossalai/shardformer/modeling/llama.py | 29 ++++++++++++++----
 1 file changed, 17 insertions(+), 12 deletions(-)

diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py
index eb8e9f748527..533290538c9a 100644
--- a/colossalai/shardformer/modeling/llama.py
+++ b/colossalai/shardformer/modeling/llama.py
@@ -11,6 +11,7 @@
 )
 from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel
 from transformers.utils import logging
+from transformers.cache_utils import Cache
 
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.shard import ShardConfig
@@ -438,11 +439,15 @@ def forward(
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        past_key_value: Optional[Cache] = None,
         output_attentions: bool = False,
         use_cache: bool = False,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        if "padding_mask" in kwargs:
+            warnings.warn(
+                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
+            )
         bsz, q_len, _ = hidden_states.size()
 
         assert q_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4."
@@ -452,23 +457,23 @@ def forward(
 
         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
-            kv_seq_len += past_key_value[0].shape[-2]
+            if self.layer_idx is None:
+                raise ValueError(
+                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
+                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
+                    "with a layer index."
+                )
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
 
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
         if past_key_value is not None:
-            # reuse k, v, self_attention
-            key_states = torch.cat([past_key_value[0], key_states], dim=2)
-            value_states = torch.cat([past_key_value[1], value_states], dim=2)
-
-        past_key_value = (key_states, value_states) if use_cache else None
+            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
+            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
-        # repeat k/v heads if n_kv_heads < n_heads
-        if llama_version == 2:
-            key_states = repeat_kv(key_states, self.num_key_value_groups)
-            value_states = repeat_kv(value_states, self.num_key_value_groups)
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)
 
         me_input_shape = (bsz, q_len, self.num_heads, self.head_dim)
         query_states = query_states.transpose(1, 2).contiguous().view(*me_input_shape)
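Patch 1 switches the attention forward from indexing a (key, value) tuple to the Cache object that transformers 4.36 passes in as past_key_value. A minimal sketch of the two calls the new code relies on, using DynamicCache and illustrative tensor sizes (not part of the patch):

import torch
from transformers.cache_utils import DynamicCache  # available since transformers 4.36

bsz, num_heads, q_len, head_dim = 1, 8, 4, 16  # illustrative shapes
key_states = torch.randn(bsz, num_heads, q_len, head_dim)
value_states = torch.randn(bsz, num_heads, q_len, head_dim)

past_key_value = DynamicCache()
layer_idx = 0  # why the attention class now needs a layer index

# Replaces the old `past_key_value[0].shape[-2]` length lookup.
kv_seq_len = q_len + past_key_value.get_usable_length(q_len, layer_idx)

# Replaces the manual torch.cat over the (key, value) tuple; returns the
# concatenated key/value states for this layer and stores them in the cache.
key_states, value_states = past_key_value.update(key_states, value_states, layer_idx)

print(kv_seq_len, key_states.shape)  # 4 torch.Size([1, 8, 4, 16])

The cache_kwargs {"sin": sin, "cos": cos} passed in the patch are accepted by DynamicCache.update but are only needed by cache variants that re-apply rotary embeddings, such as SinkCache.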
From 976396c03ee5c3cb59b38ecf1ae7bb53ab729839 Mon Sep 17 00:00:00 2001
From: Wang Binluo <2538539015@qq.com>
Date: Mon, 25 Mar 2024 10:36:25 +0800
Subject: [PATCH 2/6] llama_model_forward

---
 colossalai/shardformer/modeling/llama.py | 65 ++++++++++++------------
 1 file changed, 32 insertions(+), 33 deletions(-)

diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py
index 533290538c9a..b9b1cc1424dc 100644
--- a/colossalai/shardformer/modeling/llama.py
+++ b/colossalai/shardformer/modeling/llama.py
@@ -20,7 +20,7 @@
 from ..layer._operation import gather_forward_split_backward
 
 try:
-    from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask
+    from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
 
     LATEST_VERSION = True
 except ImportError:
@@ -63,13 +63,13 @@ def llama_model_forward(
         # retrieve input_ids and inputs_embeds
         if stage_manager.is_first_stage():
             if input_ids is not None and inputs_embeds is not None:
-                raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+                raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
             elif input_ids is not None:
-                batch_size, seq_length = input_ids.shape
+                batch_size, seq_length = input_ids.shape[:2]
             elif inputs_embeds is not None:
-                batch_size, seq_length, _ = inputs_embeds.shape
+                batch_size, seq_length, _ = inputs_embeds.shape[:2]
             else:
-                raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+                raise ValueError("You have to specify either input_ids or inputs_embeds")
             device = input_ids.device if input_ids is not None else inputs_embeds.device
             if inputs_embeds is None:
                 inputs_embeds = self.embed_tokens(input_ids)
@@ -101,24 +101,28 @@ def llama_model_forward(
             position_ids = torch.arange(
                 past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
             )
-            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
-        else:
-            position_ids = position_ids.view(-1, seq_length).long()
+            position_ids = position_ids.unsqueeze(0)
+
+        if self._use_flash_attention_2:
+            # 2d mask is passed through the layers
+            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
 
         # embed positions, for the first stage, hidden_states is the input embeddings,
        # for the other stages, hidden_states is the output of the previous stage
-        if attention_mask is None:
-            attention_mask = torch.ones(
-                (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device
-            )
         if LATEST_VERSION:
-            attention_mask = _prepare_4d_causal_attention_mask(
-                attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
-            )
-        else:
-            attention_mask = self._prepare_decoder_attention_mask(
-                attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
-            )
+            if self._use_sdpa and not output_attentions:
+                # output_attentions=True can not be supported when using SDPA, and we fall back on
+                # the manual implementation that requires a 4D causal mask in all cases.
+                attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+                    attention_mask,
+                    (batch_size, seq_length),
+                    inputs_embeds,
+                    past_key_values_length,
+                )
+            else:
+                attention_mask = _prepare_4d_causal_attention_mask(
+                    attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
+                )
 
         if self.gradient_checkpointing and self.training:
             if use_cache:
@@ -130,37 +134,32 @@ def llama_model_forward(
         # decoder layers
         all_hidden_states = () if output_hidden_states else None
         all_self_attns = () if output_attentions else None
-        next_decoder_cache = () if use_cache else None
+        next_decoder_cache = None
 
         start_idx, end_idx = stage_index[0], stage_index[1]
         for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx):
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
 
-            past_key_value = past_key_values[idx] if past_key_values is not None else None
+            #past_key_value = past_key_values[idx] if past_key_values is not None else None
 
             if self.gradient_checkpointing and self.training:
 
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        # None for past_key_value
-                        return module(*inputs, output_attentions, None)
-
-                    return custom_forward
-
-                layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(decoder_layer),
+                layer_outputs = self._gradient_checkpointing_func(
+                    decoder_layer.__call__,
                     hidden_states,
                     attention_mask,
                     position_ids,
-                    None,
+                    past_key_values,
+                    output_attentions,
+                    use_cache,
                 )
             else:
                 layer_outputs = decoder_layer(
                     hidden_states,
                     attention_mask=attention_mask,
                     position_ids=position_ids,
-                    past_key_value=past_key_value,
+                    past_key_value=past_key_values,
                     output_attentions=output_attentions,
                     use_cache=use_cache,
                 )
@@ -168,7 +167,7 @@ def custom_forward(*inputs):
             hidden_states = layer_outputs[0]
 
             if use_cache:
-                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
+                next_decoder_cache = layer_outputs[2 if output_attentions else 1]
 
             if output_attentions:
                 all_self_attns += (layer_outputs[1],)
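The rewritten llama_model_forward mirrors the transformers 4.36 mask handling: a plain 2D mask (or None) when flash-attention-2 is active, and the 4D causal-mask builders otherwise. A small standalone sketch of the two helpers it calls, with made-up sizes (assumes transformers==4.36.0; the patch imports them via transformers.models.llama.modeling_llama, they are defined in transformers.modeling_attn_mask_utils):

import torch
from transformers.modeling_attn_mask_utils import (
    _prepare_4d_causal_attention_mask,
    _prepare_4d_causal_attention_mask_for_sdpa,
)

batch_size, seq_length, hidden_size = 2, 5, 32  # illustrative
attention_mask = torch.ones(batch_size, seq_length, dtype=torch.long)
inputs_embeds = torch.randn(batch_size, seq_length, hidden_size)
past_key_values_length = 0

# Eager path: always materializes a (batch, 1, q_len, kv_len) additive float mask.
eager_mask = _prepare_4d_causal_attention_mask(
    attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
)

# SDPA path: may return None when scaled_dot_product_attention can handle
# causality itself (is_causal=True), which is why output_attentions=True falls
# back to the eager mask in the patch.
sdpa_mask = _prepare_4d_causal_attention_mask_for_sdpa(
    attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
)

print(eager_mask.shape)  # torch.Size([2, 1, 5, 5])
print(sdpa_mask if sdpa_mask is None else sdpa_mask.shape)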
From 63ef374bdc49d54f3267c9c3af3fb1e23b202943 Mon Sep 17 00:00:00 2001
From: Wang Binluo <2538539015@qq.com>
Date: Mon, 25 Mar 2024 14:24:57 +0800
Subject: [PATCH 3/6] remove useless comment

---
 colossalai/shardformer/modeling/llama.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py
index b9b1cc1424dc..f88bbec27c43 100644
--- a/colossalai/shardformer/modeling/llama.py
+++ b/colossalai/shardformer/modeling/llama.py
@@ -141,8 +141,6 @@ def llama_model_forward(
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
 
-            #past_key_value = past_key_values[idx] if past_key_values is not None else None
-
             if self.gradient_checkpointing and self.training:
 
                 layer_outputs = self._gradient_checkpointing_func(
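The context above also shows the other 4.36 change this series depends on: the closure handed to torch.utils.checkpoint.checkpoint is gone, and layers are re-run through self._gradient_checkpointing_func, which PreTrainedModel.gradient_checkpointing_enable() binds to torch.utils.checkpoint.checkpoint. A rough, self-contained sketch of the equivalence (a simplified assumption, not the library source; use_reentrant is an explicit choice for this sketch):

import functools
import torch

# Roughly what gradient_checkpointing_enable() sets up in transformers 4.36.
_gradient_checkpointing_func = functools.partial(
    torch.utils.checkpoint.checkpoint, use_reentrant=False
)

def decoder_layer(hidden_states, attention_mask=None):
    # Stand-in for LlamaDecoderLayer.__call__: returns a tuple like the real layer.
    return (hidden_states * 2.0,)

hidden_states = torch.randn(2, 4, 8, requires_grad=True)

# Equivalent of: layer_outputs = self._gradient_checkpointing_func(decoder_layer.__call__, ...)
layer_outputs = _gradient_checkpointing_func(decoder_layer, hidden_states, None)
layer_outputs[0].sum().backward()
print(hidden_states.grad.shape)  # torch.Size([2, 4, 8])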
From b00f9ea2dbfb161f17a226a058ea0441dfaa049e Mon Sep 17 00:00:00 2001
From: Wang Binluo <2538539015@qq.com>
Date: Mon, 25 Mar 2024 17:12:37 +0800
Subject: [PATCH 4/6] update the requirements.txt

---
 requirements/requirements-test.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/requirements/requirements-test.txt b/requirements/requirements-test.txt
index 4136cefc3c37..d42cb08a792f 100644
--- a/requirements/requirements-test.txt
+++ b/requirements/requirements-test.txt
@@ -3,7 +3,7 @@ pytest
 coverage==7.2.3
 git+https://github.com/hpcaitech/pytest-testmon
 torchvision
-transformers==4.33.0
+transformers==4.36.0
 timm
 titans
 torchaudio

From dc8b9d46f1ea40c80c043953e848337c8025ad48 Mon Sep 17 00:00:00 2001
From: Wang Binluo <2538539015@qq.com>
Date: Mon, 25 Mar 2024 17:52:35 +0800
Subject: [PATCH 5/6] add the transformers version requirements

---
 requirements/requirements-test.txt | 1 -
 requirements/requirements.txt      | 1 +
 2 files changed, 1 insertion(+), 1 deletion(-)

diff --git a/requirements/requirements-test.txt b/requirements/requirements-test.txt
index d42cb08a792f..0b15b9311937 100644
--- a/requirements/requirements-test.txt
+++ b/requirements/requirements-test.txt
@@ -3,7 +3,6 @@ pytest
 coverage==7.2.3
 git+https://github.com/hpcaitech/pytest-testmon
 torchvision
-transformers==4.36.0
 timm
 titans
 torchaudio
diff --git a/requirements/requirements.txt b/requirements/requirements.txt
index 095617d76355..38b8f66a8f8b 100644
--- a/requirements/requirements.txt
+++ b/requirements/requirements.txt
@@ -16,3 +16,4 @@ ray
 sentencepiece
 google
 protobuf
+transformers==4.36.0
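Patches 4 and 5 first bump the test pin to transformers==4.36.0 and then move it into requirements/requirements.txt, since the forwards above use 4.36-only symbols (Cache, _use_sdpa, _gradient_checkpointing_func, the 4D mask helpers). An optional runtime guard in the same spirit (an illustrative sketch, not part of the patches):

import transformers
from packaging import version

if version.parse(transformers.__version__) < version.parse("4.36.0"):
    raise ImportError(
        "colossalai.shardformer's llama forwards need transformers>=4.36.0, "
        f"found {transformers.__version__}"
    )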
From 9206dd1c5b332f272f853da2a247757a5a2c8a06 Mon Sep 17 00:00:00 2001
From: Wang Binluo <2538539015@qq.com>
Date: Tue, 26 Mar 2024 15:50:36 +0800
Subject: [PATCH 6/6] remove the LATEST VERSION try

---
 colossalai/shardformer/modeling/llama.py | 34 +++++++++---------
 1 file changed, 13 insertions(+), 21 deletions(-)

diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py
index f88bbec27c43..5e7540b96598 100644
--- a/colossalai/shardformer/modeling/llama.py
+++ b/colossalai/shardformer/modeling/llama.py
@@ -19,12 +19,7 @@
 from ..layer import cross_entropy_1d
 from ..layer._operation import gather_forward_split_backward
 
-try:
-    from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
-
-    LATEST_VERSION = True
-except ImportError:
-    LATEST_VERSION = False
+from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
 
 
 class LlamaPipelineForwards:
@@ -106,23 +101,20 @@ def llama_model_forward(
         if self._use_flash_attention_2:
             # 2d mask is passed through the layers
             attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
-
-        # embed positions, for the first stage, hidden_states is the input embeddings,
-        # for the other stages, hidden_states is the output of the previous stage
-        if LATEST_VERSION:
-            if self._use_sdpa and not output_attentions:
+        elif self._use_sdpa and not output_attentions:
                 # output_attentions=True can not be supported when using SDPA, and we fall back on
                 # the manual implementation that requires a 4D causal mask in all cases.
-                attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
-                    attention_mask,
-                    (batch_size, seq_length),
-                    inputs_embeds,
-                    past_key_values_length,
-                )
-            else:
-                attention_mask = _prepare_4d_causal_attention_mask(
-                    attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
-                )
+            attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+                attention_mask,
+                (batch_size, seq_length),
+                inputs_embeds,
+                past_key_values_length,
+            )
+        else:
+            # 4d mask is passed through the layers
+            attention_mask = _prepare_4d_causal_attention_mask(
+                attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
+            )
 
         if self.gradient_checkpointing and self.training:
             if use_cache:
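With the LATEST_VERSION fallback gone, the mask dispatch keys entirely off the _use_flash_attention_2 / _use_sdpa flags that transformers 4.36 sets on LlamaModel from config._attn_implementation. A paraphrased sketch of where those flags come from, plus the usual user-facing switch (treat the model id as a placeholder):

class AttnFlagSketch:
    """Paraphrase of the flag assignment in transformers 4.36 LlamaModel.__init__."""

    def __init__(self, attn_implementation: str = "eager"):
        self._use_flash_attention_2 = attn_implementation == "flash_attention_2"
        self._use_sdpa = attn_implementation == "sdpa"

print(vars(AttnFlagSketch("sdpa")))  # {'_use_flash_attention_2': False, '_use_sdpa': True}

# Typical way to pick the backend when loading a checkpoint (needs the matching
# backend installed; "eager" always works):
# model = LlamaModel.from_pretrained("meta-llama/Llama-2-7b-hf", attn_implementation="sdpa")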