From 6c3ef097e302d9733378579159b5c242356e4d1f Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Tue, 23 Apr 2024 15:55:35 +0000 Subject: [PATCH 1/4] tmp commit --- .../models/llama/modeling_llama.py | 11 ++- tests/models/llama/test_modeling_llama.py | 69 ++++++++----------- 2 files changed, 35 insertions(+), 45 deletions(-) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 905edf5f71a6..d6c1084e36c2 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -659,15 +659,20 @@ def forward( if attention_mask is not None: causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom + # attn_mask, # Reference: https://github.com/pytorch/pytorch/issues/112577. if query_states.device.type == "cuda" and causal_mask is not None: query_states = query_states.contiguous() key_states = key_states.contiguous() value_states = value_states.contiguous() - # In case we are not compiling, we may set `causal_mask` to None, which is required to dispatch to SDPA's Flash Attention 2 backend, rather - # relying on the `is_causal` argument. + # In case we are not compiling, we may set `causal_mask` to None, which is required to dispatch to SDPA's Flash + # Attention 2 backend, rather relying on the `is_causal` argument. 
If using static cache, we need to drop the + # empty KV entries + if causal_mask is None and cache_position is not None: + key_states = key_states[:, :, :cache_position[-1]+1, :] + value_states = value_states[:, :, :cache_position[-1]+1, :] attn_output = torch.nn.functional.scaled_dot_product_attention( query_states, key_states, diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py index dc24fd848c81..f839faaaa954 100644 --- a/tests/models/llama/test_modeling_llama.py +++ b/tests/models/llama/test_modeling_llama.py @@ -684,17 +684,18 @@ def test_model_13b_greedy_generation(self): @require_torch_gpu @require_read_token def test_compile_static_cache(self): + # `torch==2.2` will throw an error on this test (as in other compilation tests), but torch==2.1.2 works as + # intended. See https://github.com/pytorch/pytorch/issues/121943 NUM_TOKENS_TO_GENERATE = 40 - EXPECTED_TEXT_COMPLETION = { - 7: [ - "Simply put, the theory of relativity states that 1) the speed of light is constant, 2) the speed of light is the same for all observers, and 3) the laws of physics are the same for all observers.", - "My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p", - ], - 8: [ - "Simply put, the theory of relativity states that 1) the speed of light is the same for all observers, and 2) the laws of physics are the same for all observers.\nThe first part of the theory of relativity", - "My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p", - ], - } + # Note on `EXPECTED_TEXT_COMPLETION`'s diff: the current value matches the original test if the original test + # was changed to have a cache of 53 tokens (as opposed to 4096). 
+ EXPECTED_TEXT_COMPLETION = [ + 'Simply put, the theory of relativity states that 1) the speed of light is constant in all inertial ' + 'reference frames, and 2) the laws of physics are the same for all inertial reference frames.\nThe theory ' + 'of relativ', + 'My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, my ' + 'fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p' + ] prompts = [ "Simply put, the theory of relativity states that ", @@ -706,38 +707,22 @@ def test_compile_static_cache(self): ) inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device) - def decode_one_tokens(model, cur_token, input_pos, cache_position): - logits = model( - cur_token, position_ids=input_pos, cache_position=cache_position, return_dict=False, use_cache=True - )[0] - new_token = torch.argmax(logits[:, -1], dim=-1)[:, None] - return new_token - - batch_size, seq_length = inputs["input_ids"].shape - with torch.no_grad(): - model._setup_cache(StaticCache, 2, max_cache_len=4096) - cache_position = torch.arange(seq_length, device=torch_device) - generated_ids = torch.zeros( - batch_size, seq_length + NUM_TOKENS_TO_GENERATE + 1, dtype=torch.int, device=torch_device - ) - generated_ids[:, cache_position] = inputs["input_ids"].to(torch_device).to(torch.int) - - logits = model(**inputs, cache_position=cache_position, return_dict=False, use_cache=True)[0] - next_token = torch.argmax(logits[:, -1], dim=-1)[:, None] - generated_ids[:, seq_length] = next_token[:, 0] - - decode_one_tokens = torch.compile(decode_one_tokens, mode="reduce-overhead", fullgraph=True) - cache_position = torch.tensor([seq_length + 1], device=torch_device) - for _ in range(1, NUM_TOKENS_TO_GENERATE): - with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): - with CaptureLogger(logging.get_logger(__name__)) as cl: - next_token = decode_one_tokens(model, next_token.clone(), None, 
cache_position) - self.assertNotIn("skipping cudagraphs due to", cl.out) - generated_ids[:, cache_position] = next_token.int() - cache_position += 1 - - text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) - self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], text) + # Dynamic Cache + # with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): + generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False) + dynamic_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, dynamic_text) + + # Static Cache + generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static") + static_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, static_text) + + # Static Cache + compile + model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True) + generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static") + static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, static_compiled_text) @require_torch From 872411ca6140143eb6ab9e8a895f999e2c9133e1 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Tue, 23 Apr 2024 16:10:03 +0000 Subject: [PATCH 2/4] propagate changes to gemma and cohere --- .../models/cohere/modeling_cohere.py | 21 ++++++++++----- .../models/gemma/modeling_gemma.py | 21 ++++++++++----- .../models/llama/modeling_llama.py | 16 +++++++----- tests/models/llama/test_modeling_llama.py | 26 ++++++++++--------- 4 files changed, 51 insertions(+), 33 deletions(-) diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index 
41bb4c051692..84a7822c2ac7 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -533,10 +533,13 @@ def forward( cache_position: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if output_attentions: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is + # implemented. logger.warning_once( - "CohereModel is using CohereSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + "CohereModel is using CohereSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` " + "does not support `output_attentions=True`. Falling back to the manual attention implementation, " + "but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. " + 'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' ) return super().forward( hidden_states=hidden_states, @@ -583,15 +586,19 @@ def forward( if attention_mask is not None: causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom + # attn_mask. Reference: https://github.com/pytorch/pytorch/issues/112577. 
if query_states.device.type == "cuda" and causal_mask is not None: query_states = query_states.contiguous() key_states = key_states.contiguous() value_states = value_states.contiguous() - # In case we are not compiling, we may set `causal_mask` to None, which is required to dispatch to SDPA's Flash Attention 2 backend, rather - # relying on the `is_causal` argument. + # In case we are not compiling, we may set `causal_mask` to None, which is required to dispatch to SDPA's Flash + # Attention 2 backend, rather relying on the `is_causal` argument. If using static cache, we need to drop the + # empty KV entries + if causal_mask is None and cache_position is not None: + key_states = key_states[:, :, : cache_position[-1] + 1, :] + value_states = value_states[:, :, : cache_position[-1] + 1, :] attn_output = torch.nn.functional.scaled_dot_product_attention( query_states, key_states, diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index e5b6b207748a..1fd741971f32 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -521,10 +521,13 @@ def forward( cache_position: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if output_attentions: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is + # implemented. logger.warning_once( - "GemmaModel is using GemmaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
+ "GemmaModel is using GemmaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does " + "not support `output_attentions=True`. Falling back to the manual attention implementation, " + "but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. " + 'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' ) return super().forward( hidden_states=hidden_states, @@ -563,15 +566,19 @@ def forward( if attention_mask is not None: causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom + # attn_mask. Reference: https://github.com/pytorch/pytorch/issues/112577. if query_states.device.type == "cuda" and causal_mask is not None: query_states = query_states.contiguous() key_states = key_states.contiguous() value_states = value_states.contiguous() - # In case we are not compiling, we may set `causal_mask` to None, which is required to dispatch to SDPA's Flash Attention 2 backend, rather - # relying on the `is_causal` argument. + # In case we are not compiling, we may set `causal_mask` to None, which is required to dispatch to SDPA's Flash + # Attention 2 backend, rather relying on the `is_causal` argument. 
If using static cache, we need to drop the + # empty KV entries + if causal_mask is None and cache_position is not None: + key_states = key_states[:, :, : cache_position[-1] + 1, :] + value_states = value_states[:, :, : cache_position[-1] + 1, :] attn_output = torch.nn.functional.scaled_dot_product_attention( query_states, key_states, diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index d6c1084e36c2..34f7bba18de5 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -616,10 +616,13 @@ def forward( cache_position: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if output_attentions: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is + # implemented. logger.warning_once( - "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does " + "not support `output_attentions=True`. Falling back to the manual attention implementation, " + "but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. " + 'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
) return super().forward( hidden_states=hidden_states, @@ -660,8 +663,7 @@ def forward( causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom - # attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. + # attn_mask. Reference: https://github.com/pytorch/pytorch/issues/112577. if query_states.device.type == "cuda" and causal_mask is not None: query_states = query_states.contiguous() key_states = key_states.contiguous() @@ -671,8 +673,8 @@ def forward( # Attention 2 backend, rather relying on the `is_causal` argument. If using static cache, we need to drop the # empty KV entries if causal_mask is None and cache_position is not None: - key_states = key_states[:, :, :cache_position[-1]+1, :] - value_states = value_states[:, :, :cache_position[-1]+1, :] + key_states = key_states[:, :, : cache_position[-1] + 1, :] + value_states = value_states[:, :, : cache_position[-1] + 1, :] attn_output = torch.nn.functional.scaled_dot_product_attention( query_states, key_states, diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py index f839faaaa954..a46194e87660 100644 --- a/tests/models/llama/test_modeling_llama.py +++ b/tests/models/llama/test_modeling_llama.py @@ -20,9 +20,8 @@ import pytest from parameterized import parameterized -from transformers import LlamaConfig, StaticCache, is_torch_available, logging, set_seed +from transformers import LlamaConfig, is_torch_available, set_seed from transformers.testing_utils import ( - CaptureLogger, require_bitsandbytes, require_flash_attn, require_read_token, @@ -684,17 +683,17 @@ def test_model_13b_greedy_generation(self): @require_torch_gpu @require_read_token def test_compile_static_cache(self): - # `torch==2.2` will throw an error on this test (as in other compilation tests), but torch==2.1.2 works as - # intended. 
See https://github.com/pytorch/pytorch/issues/121943 + # `torch==2.2` will throw an error on this test (as in other compilation tests), but torch==2.1.2 and torch>2.2 + # work as intended. See https://github.com/pytorch/pytorch/issues/121943 NUM_TOKENS_TO_GENERATE = 40 # Note on `EXPECTED_TEXT_COMPLETION`'s diff: the current value matches the original test if the original test # was changed to have a cache of 53 tokens (as opposed to 4096). EXPECTED_TEXT_COMPLETION = [ - 'Simply put, the theory of relativity states that 1) the speed of light is constant in all inertial ' - 'reference frames, and 2) the laws of physics are the same for all inertial reference frames.\nThe theory ' - 'of relativ', - 'My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, my ' - 'fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p' + "Simply put, the theory of relativity states that 1) the speed of light is constant in all inertial " + "reference frames, and 2) the laws of physics are the same for all inertial reference frames.\nThe theory " + "of relativ", + "My favorite all time favorite condiment is ketchup. I love it on everything. 
I love it on my eggs, my " + "fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p", ] prompts = [ @@ -708,19 +707,22 @@ def test_compile_static_cache(self): inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device) # Dynamic Cache - # with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False) dynamic_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) self.assertEqual(EXPECTED_TEXT_COMPLETION, dynamic_text) # Static Cache - generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static") + generated_ids = model.generate( + **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static" + ) static_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) self.assertEqual(EXPECTED_TEXT_COMPLETION, static_text) # Static Cache + compile model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True) - generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static") + generated_ids = model.generate( + **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static" + ) static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) self.assertEqual(EXPECTED_TEXT_COMPLETION, static_compiled_text) From 631c2da17c827a954c36cde219fdd37452777dd9 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Wed, 24 Apr 2024 10:29:28 +0000 Subject: [PATCH 3/4] empty commit to test a thing in CI From 187bb56689d580867427915aedaecd2466edadb9 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Wed, 24 Apr 2024 14:48:06 +0000 Subject: [PATCH 4/4] tmp commit --- src/transformers/models/cohere/modeling_cohere.py | 6 +++--- 
src/transformers/models/gemma/modeling_gemma.py | 6 +++--- src/transformers/models/llama/modeling_llama.py | 11 ++++++----- 3 files changed, 12 insertions(+), 11 deletions(-) diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index 84a7822c2ac7..f5980472e107 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -594,9 +594,9 @@ def forward( value_states = value_states.contiguous() # In case we are not compiling, we may set `causal_mask` to None, which is required to dispatch to SDPA's Flash - # Attention 2 backend, rather relying on the `is_causal` argument. If using static cache, we need to drop the - # empty KV entries - if causal_mask is None and cache_position is not None: + # Attention 2 backend, rather relying on the `is_causal` argument. In that case, if using static cache, we need + # to drop the empty KV entries + if causal_mask is None and cache_position is not None and isinstance(past_key_value, StaticCache): key_states = key_states[:, :, : cache_position[-1] + 1, :] value_states = value_states[:, :, : cache_position[-1] + 1, :] attn_output = torch.nn.functional.scaled_dot_product_attention( diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 1fd741971f32..d5d5d4578f03 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -574,9 +574,9 @@ def forward( value_states = value_states.contiguous() # In case we are not compiling, we may set `causal_mask` to None, which is required to dispatch to SDPA's Flash - # Attention 2 backend, rather relying on the `is_causal` argument. If using static cache, we need to drop the - # empty KV entries - if causal_mask is None and cache_position is not None: + # Attention 2 backend, rather relying on the `is_causal` argument. 
In that case, if using static cache, we need + # to drop the empty KV entries + if causal_mask is None and cache_position is not None and isinstance(past_key_value, StaticCache): + key_states = key_states[:, :, : cache_position[-1] + 1, :] + value_states = value_states[:, :, : cache_position[-1] + 1, :] attn_output = torch.nn.functional.scaled_dot_product_attention( diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 34f7bba18de5..82f864029807 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -670,11 +670,11 @@ def forward( value_states = value_states.contiguous() # In case we are not compiling, we may set `causal_mask` to None, which is required to dispatch to SDPA's Flash - # Attention 2 backend, rather relying on the `is_causal` argument. If using static cache, we need to drop the - # empty KV entries - if causal_mask is None and cache_position is not None: - key_states = key_states[:, :, : cache_position[-1] + 1, :] - value_states = value_states[:, :, : cache_position[-1] + 1, :] + # Attention 2 backend, rather relying on the `is_causal` argument. In that case, if using static cache, we need + # to drop the empty KV entries + if causal_mask is None and cache_position is not None and isinstance(past_key_value, StaticCache): + key_states = key_states[:, :, : cache_position[-1] + 1, :] + value_states = value_states[:, :, : cache_position[-1] + 1, :] attn_output = torch.nn.functional.scaled_dot_product_attention( query_states, key_states,