From 52bab19750c5a3cae28e02b1fb2b02a4a8711add Mon Sep 17 00:00:00 2001 From: raushan Date: Wed, 7 Aug 2024 07:24:51 +0200 Subject: [PATCH 1/4] fix check --- src/transformers/modeling_flash_attention_utils.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/transformers/modeling_flash_attention_utils.py b/src/transformers/modeling_flash_attention_utils.py index 120f60a31afe..f2caab76360f 100644 --- a/src/transformers/modeling_flash_attention_utils.py +++ b/src/transformers/modeling_flash_attention_utils.py @@ -264,11 +264,10 @@ def _flash_attention_forward( ) attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) - # if position_ids is provided and check not all examples (row) contain only 1 sequence, and is in pre-fill/training stage - # then use `flash_attn_varlen_func` to prevent cross-example attention and also allow padding free approach - elif ( - position_ids is not None and not (position_ids[:, -1] == position_ids.size(1) - 1).all() and query_length != 1 - ): + # If position_ids is provided and check all examples do not contain only 1 sequence, If tensor in increasing + # then we probably have one sequence, otherwise it is packed. Additionally check we are in pre-fill/training stage. + # Use `flash_attn_varlen_func` to prevent cross-example attention and also allow padding free approach + elif position_ids is not None and not (torch.diff(position_ids) >= 0).all() and query_length != 1: batch_size = query_states.size(0) query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids( query_states, key_states, value_states, position_ids From bfe11ad242d243e980a43a2ae7513287730a7b2a Mon Sep 17 00:00:00 2001 From: raushan Date: Thu, 8 Aug 2024 11:52:03 +0200 Subject: [PATCH 2/4] add tests --- .../modeling_flash_attention_utils.py | 2 +- src/transformers/models/phi3/modeling_phi3.py | 2 +- tests/test_modeling_common.py | 57 +++++++++++++++++++ 3 files changed, 59 insertions(+), 2 deletions(-) diff --git a/src/transformers/modeling_flash_attention_utils.py b/src/transformers/modeling_flash_attention_utils.py index f2caab76360f..44e61825dd9c 100644 --- a/src/transformers/modeling_flash_attention_utils.py +++ b/src/transformers/modeling_flash_attention_utils.py @@ -267,7 +267,7 @@ def _flash_attention_forward( # If position_ids is provided and check all examples do not contain only 1 sequence, If tensor in increasing # then we probably have one sequence, otherwise it is packed. Additionally check we are in pre-fill/training stage. # Use `flash_attn_varlen_func` to prevent cross-example attention and also allow padding free approach - elif position_ids is not None and not (torch.diff(position_ids) >= 0).all() and query_length != 1: + elif position_ids is not None and not (torch.diff(position_ids, dim=-1) >= 0).all() and query_length != 1: batch_size = query_states.size(0) query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids( query_states, key_states, value_states, position_ids diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index b9c544587986..40be45e812d0 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -540,7 +540,7 @@ def forward( max(kv_seq_len, position_ids[:, -1].max().item() + 1) if position_ids is not None else kv_seq_len ) - cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len) + cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len, position_ids=position_ids) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 4a29942641ea..c1f88cdf33be 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -4284,6 +4284,63 @@ def test_flash_attn_2_generate_use_cache(self): use_cache=True, ) + @require_flash_attn + @require_torch_gpu + @mark.flash_attn_test + @slow + def test_flash_attn_2_generate_reuse_cache(self): + if not self.has_attentions: + self.skipTest(reason="Model architecture does not support attentions") + + max_new_tokens = 2 + + for model_class in self.all_generative_model_classes: + if not model_class._supports_flash_attn_2: + self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + dummy_input = inputs_dict[model_class.main_input_name] + if dummy_input.dtype in [torch.float32, torch.bfloat16]: + dummy_input = dummy_input.to(torch.float16) + + # make sure that all models have enough positions for generation + if hasattr(config, "max_position_embeddings"): + config.max_position_embeddings = dummy_input.shape[1] * 2 + max_new_tokens * 2 + 1 + + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + + model = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch.float16, + attn_implementation="flash_attention_2", + low_cpu_mem_usage=True, + ).to(torch_device) + + # run generate once to get filled cache + output = model.generate( + dummy_input, + max_new_tokens=max_new_tokens, + do_sample=False, + use_cache=True, + return_dict_in_generate=True, + ) + past_key_values = output.past_key_values + + # Try to continue generation from where we left, given that we have more than 1 new token to process + # e.g. this can happen in speculative decoding when feeding candidate tokens back to target model + dummy_input_updated = torch.cat([dummy_input, output.sequences], dim=-1) + _ = model.generate( + dummy_input_updated, + max_new_tokens=max_new_tokens, + do_sample=False, + use_cache=True, + past_key_values=past_key_values, + ) + @require_flash_attn @require_torch_gpu @require_bitsandbytes From 4b49c85ecaebf979bda972a84b1f04066838b1e2 Mon Sep 17 00:00:00 2001 From: raushan Date: Thu, 8 Aug 2024 15:52:57 +0200 Subject: [PATCH 3/4] [run-slow] llama, gemma2 --- tests/models/whisper/test_modeling_whisper.py | 10 ++++++++++ tests/test_modeling_common.py | 1 - 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index f43f29f56510..1fcaad460606 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -1844,6 +1844,10 @@ def test_generate_output_type(self, return_dict_in_generate): ) assert isinstance(pred_ids, expected_output_type) + @unittest.skip("Reusing cache seems to be not working in encoder-decoder setting") + def test_flash_attn_2_generate_reuse_cache(self): + pass + @require_torch @require_torchaudio @@ -4066,6 +4070,12 @@ def test_retain_grad_hidden_states_attentions(self): def test_save_load_fast_init_from_base(self): pass + @unittest.skip( + reason="FA2 testing suite needs to be refactored to be compatible with WhisperDecoder for that test" + ) + def test_flash_attn_2_generate_reuse_cache(self): + pass + @unittest.skip( "Duplicated test with WhisperModelTest + the FA2 testing suite needs to be refactored to be compatible with WhisperDecoder for that test" ) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index c1f88cdf33be..7a6142813b54 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -4293,7 +4293,6 @@ def test_flash_attn_2_generate_reuse_cache(self): self.skipTest(reason="Model architecture does not support attentions") max_new_tokens = 2 - for model_class in self.all_generative_model_classes: if not model_class._supports_flash_attn_2: self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") From bdcf2cbe072fea920928e61f20f4091594a3ca2c Mon Sep 17 00:00:00 2001 From: raushan Date: Fri, 9 Aug 2024 09:19:48 +0200 Subject: [PATCH 4/4] oops, whisper actually runs but needed some special treatment --- tests/models/whisper/test_modeling_whisper.py | 53 ++++++++++++++++++- 1 file changed, 51 insertions(+), 2 deletions(-) diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index 1fcaad460606..8eb7262809c7 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -1844,9 +1844,58 @@ def test_generate_output_type(self, return_dict_in_generate): ) assert isinstance(pred_ids, expected_output_type) - @unittest.skip("Reusing cache seems to be not working in encoder-decoder setting") + @require_flash_attn + @require_torch_gpu + @pytest.mark.flash_attn_test + @slow def test_flash_attn_2_generate_reuse_cache(self): - pass + max_new_tokens = 2 + for model_class in self.all_generative_model_classes: + if not model_class._supports_flash_attn_2: + self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + dummy_input = inputs_dict[model_class.main_input_name][..., :10] + if dummy_input.dtype in [torch.float32, torch.bfloat16]: + dummy_input = dummy_input.to(torch.float16) + + # make sure that all models have enough positions for generation + if hasattr(config, "max_position_embeddings"): + config.max_position_embeddings = dummy_input.shape[1] * 2 + max_new_tokens * 2 + 1 + + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + + model = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch.float16, + attn_implementation="flash_attention_2", + low_cpu_mem_usage=True, + ).to(torch_device) + + # run generate once to get filled cache + output = model.generate( + dummy_input, + max_new_tokens=max_new_tokens, + do_sample=False, + use_cache=True, + return_dict_in_generate=True, + ) + past_key_values = output.past_key_values + + # Try to continue generation from where we left, given that we have more than 1 new token to process + # e.g. this can happen in speculative decoding when feeding candidate tokens back to target model + _ = model.generate( + dummy_input, + decoder_input_ids=output.sequences, + max_new_tokens=max_new_tokens, + do_sample=False, + use_cache=True, + past_key_values=past_key_values, + ) @require_torch