From fc97eab714229b700c406ed403e6b6a61408763b Mon Sep 17 00:00:00 2001 From: raushan Date: Wed, 24 Jul 2024 10:54:46 +0200 Subject: [PATCH 1/4] enable flash-attn & static cache --- .../models/gemma2/modeling_gemma2.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 10d00fa460ba..e15ccde2bc1c 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -821,14 +821,9 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and 0.0 in attention_mask: - return attention_mask - return None - dtype, device = input_tensor.dtype, input_tensor.device min_dtype = torch.finfo(dtype).min - sequence_length = input_tensor.shape[1] + batch_size, sequence_length = input_tensor.shape[:2] if past_key_values is not None: target_length = past_key_values.get_max_length() else: @@ -838,7 +833,19 @@ def _update_causal_mask( # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing if attention_mask.max() != 0: raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") + if self.config._attn_implementation == "flash_attention_2": + raise ValueError("Custom 4D attention mask should not be passed when Flash_attention is used.") causal_mask = attention_mask + + elif self.config._attn_implementation == "flash_attention_2": + # Flash attention is a special case. We cannot skip this step and assign mask=None because the cache used + # by Gemma2 is a static cache which means that right-end values will all be zeros for kv. We will need to + # mask them out and prepare 2D mask for Flash-attention. + causal_mask = torch.ones((batch_size, target_length), dtype=torch.int64, device=device) + if attention_mask.shape[1] <= target_length: + mask_length = attention_mask.shape[-1] + causal_mask[:, :mask_length] = causal_mask[:, :mask_length] * attention_mask + return causal_mask else: causal_mask = torch.full( (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device From cad10d1195ac98f6a66c8667fa9751576e3414fc Mon Sep 17 00:00:00 2001 From: raushan Date: Wed, 24 Jul 2024 12:11:08 +0200 Subject: [PATCH 2/4] this works, not the prev --- .../models/gemma2/modeling_gemma2.py | 26 ++++++++--------- tests/models/gemma2/test_modeling_gemma2.py | 29 +++++++++++++++++++ 2 files changed, 42 insertions(+), 13 deletions(-) diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index e15ccde2bc1c..8a63ed488b94 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -324,6 +324,11 @@ def forward( } key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + if attention_mask is not None: + seq_len = attention_mask.shape[1] + key_states = key_states[:, :, :seq_len] + value_states = value_states[:, :, :seq_len] + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache # to be able to avoid many of these transpose/reshape/view. query_states = query_states.transpose(1, 2) @@ -821,9 +826,16 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): + # Flash Attention currently doesn't support static cache but Gemma2 work only with static cache. + # So we will pass in attention mask as is in any case, not only when ther's padding. Then we'll use its shape + # to cut out keys/values trailing 0 used in static cache. This workaround should be compile compatible + # as it doesn't cause dynamic control issues. + if self.config._attn_implementation == "flash_attention_2": + return attention_mask + dtype, device = input_tensor.dtype, input_tensor.device min_dtype = torch.finfo(dtype).min - batch_size, sequence_length = input_tensor.shape[:2] + sequence_length = input_tensor.shape[1] if past_key_values is not None: target_length = past_key_values.get_max_length() else: @@ -833,19 +845,7 @@ def _update_causal_mask( # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing if attention_mask.max() != 0: raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") - if self.config._attn_implementation == "flash_attention_2": - raise ValueError("Custom 4D attention mask should not be passed when Flash_attention is used.") causal_mask = attention_mask - - elif self.config._attn_implementation == "flash_attention_2": - # Flash attention is a special case. We cannot skip this step and assign mask=None because the cache used - # by Gemma2 is a static cache which means that right-end values will all be zeros for kv. We will need to - # mask them out and prepare 2D mask for Flash-attention. - causal_mask = torch.ones((batch_size, target_length), dtype=torch.int64, device=device) - if attention_mask.shape[1] <= target_length: - mask_length = attention_mask.shape[-1] - causal_mask[:, :mask_length] = causal_mask[:, :mask_length] * attention_mask - return causal_mask else: causal_mask = torch.full( (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device diff --git a/tests/models/gemma2/test_modeling_gemma2.py b/tests/models/gemma2/test_modeling_gemma2.py index 20b8ea3ec5c8..f51b3cb59963 100644 --- a/tests/models/gemma2/test_modeling_gemma2.py +++ b/tests/models/gemma2/test_modeling_gemma2.py @@ -16,8 +16,11 @@ import unittest +from pytest import mark + from transformers import AutoModelForCausalLM, AutoTokenizer, Gemma2Config, is_torch_available, pipeline from transformers.testing_utils import ( + require_flash_attn, require_read_token, require_torch, require_torch_gpu, @@ -161,3 +164,29 @@ def test_model_9b_pipeline_bf16(self): self.assertEqual(output[0][0]["generated_text"], EXPECTED_TEXTS[0]) self.assertEqual(output[1][0]["generated_text"], EXPECTED_TEXTS[1]) + + @require_read_token + @require_flash_attn + @require_torch_gpu + @mark.flash_attn_test + @slow + def test_model_9b_flash_attn(self): + # See https://github.com/huggingface/transformers/issues/31953 --- flash attn was generating garbage for gemma2, especially in long context + # NOTE: the quality is a lot better whan fp16 is used, and worse for bf16 + model_id = "google/gemma-2-9b" + EXPECTED_TEXTS = [ + 'Hello I am doing a project on the 1918 flu pandemic and I am trying to find out how many people died in the United States. I have found a few sites that say 500,000 but I am not sure if that is correct. I have also found a site that says 675,000 but I am not sure if that is correct either. I am trying to find out how many people died in the United States. I have found a few', + "Hi today I'm going to be talking about the history of the United States. The United States of America is a country in North America. It is the third largest country in the world by total area and the third most populous country with over 320 million people. The United States is a federal republic consisting of 50 states and a federal district. The 48 contiguous states and the district of Columbia are in central North America between Canada and Mexico. The state of Alaska is in the" + ] # fmt: skip + + model = AutoModelForCausalLM.from_pretrained( + model_id, attn_implementation="flash_attention_2", torch_dtype="float16" + ).to(torch_device) + tokenizer = AutoTokenizer.from_pretrained(model_id) + inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device) + + output = model.generate(**inputs, max_new_tokens=100, do_sample=False) + output_text = tokenizer.batch_decode(output, skip_special_tokens=False) + print(output_text) + + self.assertEqual(output_text, EXPECTED_TEXTS) From 9c0f447546c2f1763b02acc56829c22b5da0f527 Mon Sep 17 00:00:00 2001 From: raushan Date: Mon, 29 Jul 2024 10:45:24 +0200 Subject: [PATCH 3/4] fix for sliding window layers --- .../models/gemma2/modeling_gemma2.py | 22 ++++++++++--------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 8a63ed488b94..0793cef8e2d7 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -512,16 +512,18 @@ def forward( use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - if ( - self.config._attn_implementation != "flash_attention_2" and self.is_sliding and attention_mask is not None - ): # efficient SDPA and no padding - min_dtype = torch.finfo(hidden_states.dtype).min - sliding_window_mask = torch.tril( - torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window - ) - attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask) - if attention_mask.shape[-1] <= 1: # when decoding - attention_mask = attention_mask[:, :, :, -self.sliding_window :] + if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding + # Flash-attn is a 2D tensor + if self.config._attn_implementation == "flash_attention_2": + attention_mask = attention_mask[:, -self.sliding_window :] + else: + min_dtype = torch.finfo(hidden_states.dtype).min + sliding_window_mask = torch.tril( + torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window + ) + attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask) + if attention_mask.shape[-1] <= 1: # when decoding + attention_mask = attention_mask[:, :, :, -self.sliding_window :] residual = hidden_states From dc9266f927f4e433bafacadf39975a1050911907 Mon Sep 17 00:00:00 2001 From: raushan Date: Wed, 31 Jul 2024 07:11:39 +0200 Subject: [PATCH 4/4] not needed anymore --- tests/models/gemma2/test_modeling_gemma2.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/models/gemma2/test_modeling_gemma2.py b/tests/models/gemma2/test_modeling_gemma2.py index f51b3cb59963..1229ca47eb69 100644 --- a/tests/models/gemma2/test_modeling_gemma2.py +++ b/tests/models/gemma2/test_modeling_gemma2.py @@ -172,7 +172,6 @@ def test_model_9b_pipeline_bf16(self): @slow def test_model_9b_flash_attn(self): # See https://github.com/huggingface/transformers/issues/31953 --- flash attn was generating garbage for gemma2, especially in long context - # NOTE: the quality is a lot better whan fp16 is used, and worse for bf16 model_id = "google/gemma-2-9b" EXPECTED_TEXTS = [ 'Hello I am doing a project on the 1918 flu pandemic and I am trying to find out how many people died in the United States. I have found a few sites that say 500,000 but I am not sure if that is correct. I have also found a site that says 675,000 but I am not sure if that is correct either. I am trying to find out how many people died in the United States. I have found a few',