diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py
index 41bb4c051692..f5980472e107 100644
--- a/src/transformers/models/cohere/modeling_cohere.py
+++ b/src/transformers/models/cohere/modeling_cohere.py
@@ -533,10 +533,13 @@ def forward(
         cache_position: Optional[torch.LongTensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         if output_attentions:
-            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
+            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is
+            # implemented.
             logger.warning_once(
-                "CohereModel is using CohereSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
-                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+                "CohereModel is using CohereSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` "
+                "does not support `output_attentions=True`. Falling back to the manual attention implementation, "
+                "but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. "
+                'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
             )
             return super().forward(
                 hidden_states=hidden_states,
@@ -583,15 +586,19 @@ def forward(
         if attention_mask is not None:
             causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]

-        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
-        # Reference: https://github.com/pytorch/pytorch/issues/112577.
+        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom
+        # attn_mask. Reference: https://github.com/pytorch/pytorch/issues/112577.
         if query_states.device.type == "cuda" and causal_mask is not None:
             query_states = query_states.contiguous()
             key_states = key_states.contiguous()
             value_states = value_states.contiguous()

-        # In case we are not compiling, we may set `causal_mask` to None, which is required to dispatch to SDPA's Flash Attention 2 backend, rather
-        # relying on the `is_causal` argument.
+        # In case we are not compiling, we may set `causal_mask` to None, which is required to dispatch to SDPA's Flash
+        # Attention 2 backend, instead relying on the `is_causal` argument. In that case, if using static cache, we
+        # need to drop the empty KV entries.
+        if causal_mask is None and cache_position is not None and isinstance(past_key_value, StaticCache):
+            key_states = key_states[:, :, : cache_position[-1] + 1, :]
+            value_states = value_states[:, :, : cache_position[-1] + 1, :]
         attn_output = torch.nn.functional.scaled_dot_product_attention(
             query_states,
             key_states,
diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py
index e5b6b207748a..d5d5d4578f03 100644
--- a/src/transformers/models/gemma/modeling_gemma.py
+++ b/src/transformers/models/gemma/modeling_gemma.py
@@ -521,10 +521,13 @@ def forward(
         cache_position: Optional[torch.LongTensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         if output_attentions:
-            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
+            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is
+            # implemented.
             logger.warning_once(
-                "GemmaModel is using GemmaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
-                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+                "GemmaModel is using GemmaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does "
+                "not support `output_attentions=True`. Falling back to the manual attention implementation, "
+                "but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. "
+                'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
             )
             return super().forward(
                 hidden_states=hidden_states,
@@ -563,15 +566,19 @@ def forward(
         if attention_mask is not None:
             causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]

-        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
-        # Reference: https://github.com/pytorch/pytorch/issues/112577.
+        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom
+        # attn_mask. Reference: https://github.com/pytorch/pytorch/issues/112577.
         if query_states.device.type == "cuda" and causal_mask is not None:
             query_states = query_states.contiguous()
             key_states = key_states.contiguous()
             value_states = value_states.contiguous()

-        # In case we are not compiling, we may set `causal_mask` to None, which is required to dispatch to SDPA's Flash Attention 2 backend, rather
-        # relying on the `is_causal` argument.
+        # In case we are not compiling, we may set `causal_mask` to None, which is required to dispatch to SDPA's Flash
+        # Attention 2 backend, instead relying on the `is_causal` argument. In that case, if using static cache, we
+        # need to drop the empty KV entries.
+        if causal_mask is None and cache_position is not None and isinstance(past_key_value, StaticCache):
+            key_states = key_states[:, :, : cache_position[-1] + 1, :]
+            value_states = value_states[:, :, : cache_position[-1] + 1, :]
         attn_output = torch.nn.functional.scaled_dot_product_attention(
             query_states,
             key_states,
diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index 905edf5f71a6..82f864029807 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -616,10 +616,13 @@ def forward(
         cache_position: Optional[torch.LongTensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         if output_attentions:
-            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
+            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is
+            # implemented.
             logger.warning_once(
-                "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
-                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+                "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does "
+                "not support `output_attentions=True`. Falling back to the manual attention implementation, "
+                "but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. "
+                'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
             )
             return super().forward(
                 hidden_states=hidden_states,
@@ -659,15 +662,19 @@ def forward(
         if attention_mask is not None:
             causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]

-        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
-        # Reference: https://github.com/pytorch/pytorch/issues/112577.
+        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom
+        # attn_mask. Reference: https://github.com/pytorch/pytorch/issues/112577.
         if query_states.device.type == "cuda" and causal_mask is not None:
             query_states = query_states.contiguous()
             key_states = key_states.contiguous()
             value_states = value_states.contiguous()

-        # In case we are not compiling, we may set `causal_mask` to None, which is required to dispatch to SDPA's Flash Attention 2 backend, rather
-        # relying on the `is_causal` argument.
+        # In case we are not compiling, we may set `causal_mask` to None, which is required to dispatch to SDPA's Flash
+        # Attention 2 backend, instead relying on the `is_causal` argument. In that case, if using static cache, we
+        # need to drop the empty KV entries.
+        if causal_mask is None and cache_position is not None and isinstance(past_key_value, StaticCache):
+            key_states = key_states[:, :, : cache_position[-1] + 1, :]
+            value_states = value_states[:, :, : cache_position[-1] + 1, :]
         attn_output = torch.nn.functional.scaled_dot_product_attention(
             query_states,
             key_states,
diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py
index dc24fd848c81..a46194e87660 100644
--- a/tests/models/llama/test_modeling_llama.py
+++ b/tests/models/llama/test_modeling_llama.py
@@ -20,9 +20,8 @@
 import pytest
 from parameterized import parameterized

-from transformers import LlamaConfig, StaticCache, is_torch_available, logging, set_seed
+from transformers import LlamaConfig, is_torch_available, set_seed
 from transformers.testing_utils import (
-    CaptureLogger,
     require_bitsandbytes,
     require_flash_attn,
     require_read_token,
@@ -684,17 +683,18 @@ def test_model_13b_greedy_generation(self):
     @require_torch_gpu
     @require_read_token
     def test_compile_static_cache(self):
+        # `torch==2.2` will throw an error on this test (as in other compilation tests), but torch==2.1.2 and torch>2.2
+        # work as intended. See https://github.com/pytorch/pytorch/issues/121943
         NUM_TOKENS_TO_GENERATE = 40
-        EXPECTED_TEXT_COMPLETION = {
-            7: [
-                "Simply put, the theory of relativity states that 1) the speed of light is constant, 2) the speed of light is the same for all observers, and 3) the laws of physics are the same for all observers.",
-                "My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p",
-            ],
-            8: [
-                "Simply put, the theory of relativity states that 1) the speed of light is the same for all observers, and 2) the laws of physics are the same for all observers.\nThe first part of the theory of relativity",
-                "My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p",
-            ],
-        }
+        # Note on `EXPECTED_TEXT_COMPLETION`'s diff: the current value matches the original test if the original test
+        # was changed to have a cache of 53 tokens (as opposed to 4096).
+        EXPECTED_TEXT_COMPLETION = [
+            "Simply put, the theory of relativity states that 1) the speed of light is constant in all inertial "
+            "reference frames, and 2) the laws of physics are the same for all inertial reference frames.\nThe theory "
+            "of relativ",
+            "My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, my "
+            "fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p",
+        ]

         prompts = [
             "Simply put, the theory of relativity states that ",
@@ -706,38 +706,25 @@ def test_compile_static_cache(self):
         )
         inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)

-        def decode_one_tokens(model, cur_token, input_pos, cache_position):
-            logits = model(
-                cur_token, position_ids=input_pos, cache_position=cache_position, return_dict=False, use_cache=True
-            )[0]
-            new_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
-            return new_token
-
-        batch_size, seq_length = inputs["input_ids"].shape
-        with torch.no_grad():
-            model._setup_cache(StaticCache, 2, max_cache_len=4096)
-            cache_position = torch.arange(seq_length, device=torch_device)
-            generated_ids = torch.zeros(
-                batch_size, seq_length + NUM_TOKENS_TO_GENERATE + 1, dtype=torch.int, device=torch_device
-            )
-            generated_ids[:, cache_position] = inputs["input_ids"].to(torch_device).to(torch.int)
-
-            logits = model(**inputs, cache_position=cache_position, return_dict=False, use_cache=True)[0]
-            next_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
-            generated_ids[:, seq_length] = next_token[:, 0]
-
-            decode_one_tokens = torch.compile(decode_one_tokens, mode="reduce-overhead", fullgraph=True)
-            cache_position = torch.tensor([seq_length + 1], device=torch_device)
-            for _ in range(1, NUM_TOKENS_TO_GENERATE):
-                with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
-                    with CaptureLogger(logging.get_logger(__name__)) as cl:
-                        next_token = decode_one_tokens(model, next_token.clone(), None, cache_position)
-                    self.assertNotIn("skipping cudagraphs due to", cl.out)
-                generated_ids[:, cache_position] = next_token.int()
-                cache_position += 1
-
-        text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
-        self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], text)
+        # Dynamic Cache
+        generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False)
+        dynamic_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
+        self.assertEqual(EXPECTED_TEXT_COMPLETION, dynamic_text)
+
+        # Static Cache
+        generated_ids = model.generate(
+            **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static"
+        )
+        static_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
+        self.assertEqual(EXPECTED_TEXT_COMPLETION, static_text)
+
+        # Static Cache + compile
+        model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
+        generated_ids = model.generate(
+            **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static"
+        )
+        static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
+        self.assertEqual(EXPECTED_TEXT_COMPLETION, static_compiled_text)


 @require_torch
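
For illustration, the slicing added in the new `StaticCache` branch of the SDPA attention classes can be reproduced on toy tensors. The sketch below is not part of the patch; shapes and values are hypothetical, and it only shows why the key/value tensors are cut down to `cache_position[-1] + 1` entries before calling SDPA when no explicit mask is passed.

import torch

# A static cache pre-allocates KV tensors to `max_cache_len`; only the first `seen` slots hold real entries.
max_cache_len, seen, head_dim = 8, 5, 4
key_states = torch.zeros(1, 1, max_cache_len, head_dim)       # (batch, num_heads, max_cache_len, head_dim)
key_states[:, :, :seen] = torch.randn(1, 1, seen, head_dim)
cache_position = torch.arange(seen)                           # positions written so far; last one is `seen - 1`

# With `causal_mask=None`, SDPA would attend over the full pre-allocated length, including the zero-filled
# tail, so the empty slots are dropped first (mirroring the new branch above):
key_states = key_states[:, :, : cache_position[-1] + 1, :]
print(key_states.shape)  # torch.Size([1, 1, 5, 4])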
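The rewritten `test_compile_static_cache` reduces to a plain `generate` flow. A rough user-level sketch, assuming a CUDA GPU, a transformers version that supports `cache_implementation="static"`, and access to a Llama checkpoint (the checkpoint name below is an assumption, not taken from the patch):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "meta-llama/Llama-2-7b-hf"  # assumed checkpoint; other decoder-only SDPA models should behave similarly
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map="auto", torch_dtype=torch.float16)
inputs = tokenizer("Simply put, the theory of relativity states that ", return_tensors="pt").to(model.device)

# 1) Default dynamic cache
dynamic_ids = model.generate(**inputs, max_new_tokens=40, do_sample=False)

# 2) Static cache: KV tensors are pre-allocated to a fixed length, keeping shapes stable across decoding steps
static_ids = model.generate(**inputs, max_new_tokens=40, do_sample=False, cache_implementation="static")

# 3) Static cache + compiled forward, the combination the test exercises
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
compiled_ids = model.generate(**inputs, max_new_tokens=40, do_sample=False, cache_implementation="static")

# Greedy decoding should give the same completion in all three modes, which is what the test asserts.
print(tokenizer.batch_decode(compiled_ids, skip_special_tokens=True)[0])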