From fa0b05e05278baa83767c91870a4c99761a69cb8 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 25 Apr 2024 09:02:30 +0000 Subject: [PATCH 01/12] working :D --- src/transformers/cache_utils.py | 74 +++++------ src/transformers/generation/utils.py | 41 +++--- .../models/llama/modeling_llama.py | 78 ++++-------- tests/models/llama/test_modeling_llama.py | 118 +++++++++++------- 4 files changed, 154 insertions(+), 157 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 2ed663b26256..5e3535a95136 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -6,7 +6,6 @@ from .configuration_utils import PretrainedConfig from .utils import logging - logger = logging.get_logger(__name__) @@ -61,6 +60,14 @@ def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) - return max_length - new_seq_length return previous_seq_length + def reorder_cache(self, beam_idx: torch.LongTensor): + """Reorders the cache for beam search, given the selected beam indices.""" + for layer_idx in range(len(self.key_cache)): + device = self.key_cache[layer_idx].device + self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) + device = self.value_cache[layer_idx].device + self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) + @property def seen_tokens(self): logger.warning_once( @@ -158,14 +165,6 @@ def get_max_length(self) -> Optional[int]: """Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length.""" return None - def reorder_cache(self, beam_idx: torch.LongTensor): - """Reorders the cache for beam search, given the selected beam indices.""" - for layer_idx in range(len(self.key_cache)): - device = self.key_cache[layer_idx].device - self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) - device = self.value_cache[layer_idx].device - self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) - def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]: """Converts the `DynamicCache` instance into the its equivalent in the legacy cache format.""" legacy_cache = () @@ -332,14 +331,6 @@ def update( return self.key_cache[layer_idx], self.value_cache[layer_idx] - def reorder_cache(self, beam_idx: torch.LongTensor): - """Reorders the cache for beam search, given the selected beam indices.""" - for layer_idx in range(len(self.key_cache)): - device = self.key_cache[layer_idx].device - self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) - device = self.value_cache[layer_idx].device - self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) - class StaticCache(Cache): """ @@ -347,8 +338,7 @@ class StaticCache(Cache): Parameters: config (`PretrainedConfig): - The configuration file defining the `max_position_embeddings`, `hidden_size` and `num_attention_heads` - required to initialize the static cache. + The configuration file defining the shape-related attributes required to initialize the static cache. max_batch_size (`int`): The maximum batch size with which the model will be used. max_cache_len (`int`): @@ -373,9 +363,18 @@ def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads ) + self.key_cache: List[torch.Tensor] = [] + self.value_cache: List[torch.Tensor] = [] cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim) - self.key_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device) - self.value_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device) + for _ in range(config.num_hidden_layers): + # Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph + # breaks when updating the cache. + new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device) + self.key_cache.append(new_layer_key_cache) + torch._dynamo.mark_static_address(new_layer_key_cache) + new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device) + self.value_cache.append(new_layer_value_cache) + torch._dynamo.mark_static_address(new_layer_value_cache) def update( self, @@ -394,42 +393,31 @@ def update( value_states (`torch.Tensor`): The new value states to cache. layer_idx (`int`): - The index of the layer to cache the states for. Kept for backward compatibility + The index of the layer to cache the states for. cache_kwargs (`Dict[str, Any]`, `optional`): - Additional arguments for the cache subclass. The `StaticCache` just needs the `q_len` - to know how much of the cache it should overwrite. + Additional arguments for the cache subclass. The `StaticCache` needs the `cache_position` input + to know how where to write in the cache. Return: A tuple containing the updated key and value states. """ - new_cache_positions = cache_kwargs.get("cache_position") - k_out = self.key_cache - v_out = self.value_cache + cache_position = cache_kwargs.get("cache_position") + k_out = self.key_cache[layer_idx] + v_out = self.value_cache[layer_idx] - k_out[:, :, new_cache_positions] = key_states - v_out[:, :, new_cache_positions] = value_states + k_out[:, :, cache_position] = key_states + v_out[:, :, cache_position] = value_states return k_out, v_out def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: - """Returns the sequence length of the cached states that were seen by the model. `layer_idx` kept for BC""" + """Returns the sequence length of the cached states that were seen by the model.""" # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's # limit the check to the first batch member and head dimension. # TODO: This is error prone, a filled cache may be `0.0`. Let's use a stateless integer instead, after # https://github.com/pytorch/pytorch/issues/120248 is fixed - return (self.key_cache[0, 0].any(dim=-1)).sum() + return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum() def get_max_length(self) -> Optional[int]: - """Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length.""" + """Returns the maximum sequence length of the cached states.""" return self.max_cache_len - - def reorder_cache(self, beam_idx: torch.LongTensor): - """Reorders the cache for beam search, given the selected beam indices.""" - device = self.key_cache.device - self.key_cache = self.key_cache.index_select(0, beam_idx.to(device)) - device = self.value_cache.device - self.value_cache = self.value_cache.index_select(0, beam_idx.to(device)) - - def to_legacy_cache(self): - """Dummy function for BC. We have to keep it because otherwise the call in the forward of models will break it""" - return None diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 9e6a58d3e5a5..62eebaf0b1f6 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1514,19 +1514,30 @@ def generate( input_ids_length=input_ids_length, ) - if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING: + if generation_config.cache_implementation is not None and model_kwargs.get("past_key_values") is not None: + raise ValueError( + "Passing both `cache_implementation` (used to initialize certain caches) and `past_key_values` (a " + "Cache object) is unsupported. Please use only one of the two." + ) + elif generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING: + if not self._supports_cache_class: + raise ValueError( + "This model does not support the `cache_implementation` argument. Please check the following " + "issue: https://github.com/huggingface/transformers/issues/28981." + ) if generation_config.cache_implementation == "static": - if model_kwargs.get("past_key_values", False) is not False: - raise ValueError( - "Using `past_key_values` argument with `generate()` when using a static KV cache is not supported. Please open an issue in Transformers GitHub repository." - ) cache_cls = NEED_SETUP_CACHE_CLASSES_MAPPING["static"] - if not callable(getattr(self, "_setup_cache", None)): - raise ValueError( - "The `generation_config` defines a `cache_implementation` that is not compatible with this model." - " Make sure it has a `_setup_cache` function." - ) - self._setup_cache(cache_cls, max_batch_size=batch_size, max_cache_len=generation_config.max_length) + if hasattr(self.config, "_pre_quantization_dtype"): + cache_dtype = self.config._pre_quantization_dtype + else: + cache_dtype = self.dtype + model_kwargs["past_key_values"] = cache_cls( + config=self.config, + max_batch_size=batch_size, + max_cache_len=generation_config.max_length, + device=self.device, + dtype=cache_dtype, + ) self._validate_generated_length(generation_config, input_ids_length, has_default_max_length) @@ -1844,14 +1855,6 @@ def typeerror(): **model_kwargs, ) - if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING: - if not callable(getattr(self, "_reset_cache", None)): - raise ValueError( - "A `static_cache` was used to generate but there was a failure when trying to release the cache. " - " Make sure this model implements a `_reset_cache` function." - ) - self._reset_cache() - return result def _has_unfinished_sequences(self, this_peer_finished: bool, synced_gpus: bool, device: torch.device) -> bool: diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 905edf5f71a6..4f607f315d5d 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -428,6 +428,13 @@ def forward( cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + + if isinstance(past_key_value, StaticCache): + raise ValueError( + "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " + "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" + ) + output_attentions = False bsz, q_len, _ = hidden_states.size() @@ -809,27 +816,6 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _setup_cache(self, cache_cls, max_batch_size, max_cache_len: Optional[int] = None): - if self.config._attn_implementation == "flash_attention_2" and cache_cls == StaticCache: - raise ValueError( - "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " - "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" - ) - - for layer in self.model.layers: - device = layer.input_layernorm.weight.device - if hasattr(self.config, "_pre_quantization_dtype"): - dtype = self.config._pre_quantization_dtype - else: - dtype = layer.self_attn.o_proj.weight.dtype - layer.self_attn.past_key_value = cache_cls( - self.config, max_batch_size, max_cache_len, device=device, dtype=dtype - ) - - def _reset_cache(self): - for layer in self.model.layers: - layer.self_attn.past_key_value = None - LLAMA_INPUTS_DOCSTRING = r""" Args: @@ -944,7 +930,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Cache | List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, @@ -973,23 +959,18 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - past_seen_tokens = 0 - if use_cache: # kept for BC (cache positions) - if not isinstance(past_key_values, StaticCache): - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - past_seen_tokens = past_key_values.get_seq_length() + if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs) + past_key_values = DynamicCache.from_legacy_cache(past_key_values) if cache_position is None: - if isinstance(past_key_values, StaticCache): - raise ValueError("cache_position is a required argument when using StaticCache.") + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) - if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_seen_tokens) + causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values) # embed positions hidden_states = inputs_embeds @@ -1042,7 +1023,7 @@ def forward( next_cache = None if use_cache: next_cache = ( - next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache + next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, DynamicCache) else next_decoder_cache ) if not return_dict: return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) @@ -1058,7 +1039,7 @@ def _update_causal_mask( attention_mask: torch.Tensor, input_tensor: torch.Tensor, cache_position: torch.Tensor, - past_seen_tokens: int, + past_key_values: Cache, ): # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. @@ -1070,9 +1051,12 @@ def _update_causal_mask( return attention_mask return None - if self.config._attn_implementation == "sdpa": - # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, - # in order to dispatch on Flash Attention 2. + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + if self.config._attn_implementation == "sdpa" and not using_static_cache: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens ): @@ -1081,9 +1065,9 @@ def _update_causal_mask( dtype, device = input_tensor.dtype, input_tensor.device min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] - if hasattr(getattr(self.layers[0], "self_attn", {}), "past_key_value"): # static cache - target_length = self.config.max_position_embeddings - else: # dynamic cache + if using_static_cache: + target_length = past_key_values.get_max_length() + else: target_length = ( attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) @@ -1164,7 +1148,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Cache | List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, @@ -1262,13 +1246,6 @@ def prepare_inputs_for_generation( use_cache=True, **kwargs, ): - # With static cache, the `past_key_values` is None - # TODO joao: standardize interface for the different Cache classes and remove of this if - has_static_cache = False - if past_key_values is None: - past_key_values = getattr(getattr(self.model.layers[0], "self_attn", {}), "past_key_value", None) - has_static_cache = past_key_values is not None - past_length = 0 if past_key_values is not None: if isinstance(past_key_values, Cache): @@ -1327,9 +1304,6 @@ def prepare_inputs_for_generation( elif use_cache: cache_position = cache_position[-input_length:] - if has_static_cache: - past_key_values = None - model_inputs.update( { "position_ids": position_ids, @@ -1388,7 +1362,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Cache | List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, @@ -1505,7 +1479,7 @@ def forward( input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Cache | List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, start_positions: Optional[torch.LongTensor] = None, end_positions: Optional[torch.LongTensor] = None, diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py index dc24fd848c81..22d6759fe411 100644 --- a/tests/models/llama/test_modeling_llama.py +++ b/tests/models/llama/test_modeling_llama.py @@ -684,21 +684,22 @@ def test_model_13b_greedy_generation(self): @require_torch_gpu @require_read_token def test_compile_static_cache(self): + # `torch==2.2` will throw an error on this test (as in other compilation tests), but torch==2.1.2 and torch>2.2 + # work as intended. See https://github.com/pytorch/pytorch/issues/121943 NUM_TOKENS_TO_GENERATE = 40 - EXPECTED_TEXT_COMPLETION = { - 7: [ - "Simply put, the theory of relativity states that 1) the speed of light is constant, 2) the speed of light is the same for all observers, and 3) the laws of physics are the same for all observers.", - "My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p", - ], - 8: [ - "Simply put, the theory of relativity states that 1) the speed of light is the same for all observers, and 2) the laws of physics are the same for all observers.\nThe first part of the theory of relativity", - "My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p", - ], - } + # Note on `EXPECTED_TEXT_COMPLETION`'s diff: the current value matches the original test if the original test + # was changed to have a cache of 53 tokens (as opposed to 4096). + EXPECTED_TEXT_COMPLETION = [ + "Simply put, the theory of relativity states that 1) the speed of light is constant in all inertial " + "reference frames, and 2) the laws of physics are the same for all inertial reference frames.\nThe theory " + "of relativ", + "My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, my " + "fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p", + ] prompts = [ "Simply put, the theory of relativity states that ", - "My favorite all time favorite condiment is ketchup.", + "My favorite all time favorite condiment is ketchup.", ] tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", pad_token="", padding_side="right") model = LlamaForCausalLM.from_pretrained( @@ -706,39 +707,70 @@ def test_compile_static_cache(self): ) inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device) - def decode_one_tokens(model, cur_token, input_pos, cache_position): - logits = model( - cur_token, position_ids=input_pos, cache_position=cache_position, return_dict=False, use_cache=True - )[0] - new_token = torch.argmax(logits[:, -1], dim=-1)[:, None] - return new_token - - batch_size, seq_length = inputs["input_ids"].shape - with torch.no_grad(): - model._setup_cache(StaticCache, 2, max_cache_len=4096) - cache_position = torch.arange(seq_length, device=torch_device) - generated_ids = torch.zeros( - batch_size, seq_length + NUM_TOKENS_TO_GENERATE + 1, dtype=torch.int, device=torch_device - ) - generated_ids[:, cache_position] = inputs["input_ids"].to(torch_device).to(torch.int) - - logits = model(**inputs, cache_position=cache_position, return_dict=False, use_cache=True)[0] - next_token = torch.argmax(logits[:, -1], dim=-1)[:, None] - generated_ids[:, seq_length] = next_token[:, 0] - - decode_one_tokens = torch.compile(decode_one_tokens, mode="reduce-overhead", fullgraph=True) - cache_position = torch.tensor([seq_length + 1], device=torch_device) - for _ in range(1, NUM_TOKENS_TO_GENERATE): - with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): - with CaptureLogger(logging.get_logger(__name__)) as cl: - next_token = decode_one_tokens(model, next_token.clone(), None, cache_position) - self.assertNotIn("skipping cudagraphs due to", cl.out) - generated_ids[:, cache_position] = next_token.int() - cache_position += 1 - - text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) - self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], text) + # Dynamic Cache + generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False) + dynamic_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, dynamic_text) + + # Static Cache + generated_ids = model.generate( + **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static" + ) + static_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, static_text) + # Static Cache + compile + model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True) + generated_ids = model.generate( + **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static" + ) + static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, static_compiled_text) + + @slow + @require_torch_gpu + @require_read_token + def test_compile_repeated_calls(self): + # `torch==2.2` will throw an error on this test (as in other compilation tests), but torch==2.1.2 and torch>2.2 + # work as intended. See https://github.com/pytorch/pytorch/issues/121943 + NUM_TOKENS_TO_GENERATE = 40 + EXPECTED_TEXT_COMPLETION = [ + "Simply put, the theory of relativity states that 1) the speed of light is constant in all inertial " + "reference frames, and 2) the laws of physics are the same for all inertial reference frames.\nThe theory " + "of relativ", + "My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, my " + "fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p", + ] + + prompts = [ + "Simply put, the theory of relativity states that ", + "My favorite all time favorite condiment is ketchup.", + ] + tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", pad_token="", padding_side="right") + model = LlamaForCausalLM.from_pretrained( + "meta-llama/Llama-2-7b-hf", device_map="sequential", torch_dtype=torch.float16 + ) + inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device) + + # Dynamic Cache + generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False) + dynamic_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, dynamic_text) + + # Static Cache + generated_ids = model.generate( + **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static" + ) + static_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, static_text) + + # Static Cache + compile + model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True) + generated_ids = model.generate( + **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static" + ) + static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, static_compiled_text) @require_torch class CodeLlamaIntegrationTest(unittest.TestCase): From 56cf73b2230c146d5257ae2771aff6ff9584e0e8 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Fri, 26 Apr 2024 10:18:09 +0000 Subject: [PATCH 02/12] tmp commit --- src/transformers/cache_utils.py | 7 +++--- .../models/llama/modeling_llama.py | 22 +++++++++++-------- tests/models/llama/test_modeling_llama.py | 8 +++---- 3 files changed, 21 insertions(+), 16 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 5e3535a95136..ca465123d5bd 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -6,6 +6,7 @@ from .configuration_utils import PretrainedConfig from .utils import logging + logger = logging.get_logger(__name__) @@ -370,11 +371,11 @@ def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: # Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph # breaks when updating the cache. new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device) - self.key_cache.append(new_layer_key_cache) - torch._dynamo.mark_static_address(new_layer_key_cache) new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device) - self.value_cache.append(new_layer_value_cache) + torch._dynamo.mark_static_address(new_layer_key_cache) torch._dynamo.mark_static_address(new_layer_value_cache) + self.key_cache.append(new_layer_key_cache) + self.value_cache.append(new_layer_value_cache) def update( self, diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 4f607f315d5d..2b75b2b9bfd7 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -428,7 +428,6 @@ def forward( cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if isinstance(past_key_value, StaticCache): raise ValueError( "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " @@ -930,7 +929,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Cache | List[torch.FloatTensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, @@ -959,7 +958,7 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs) + if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs) past_key_values = DynamicCache.from_legacy_cache(past_key_values) if cache_position is None: @@ -1023,7 +1022,9 @@ def forward( next_cache = None if use_cache: next_cache = ( - next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, DynamicCache) else next_decoder_cache + next_decoder_cache.to_legacy_cache() + if isinstance(next_decoder_cache, DynamicCache) + else next_decoder_cache ) if not return_dict: return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) @@ -1089,6 +1090,10 @@ def _update_causal_mask( # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with # cache. In that case, the 4D attention mask attends to the newest tokens only. if attention_mask.shape[-2] < cache_position[0] + sequence_length: + logger.warning_once( + "Passing a 4d mask shorter than the input length is deprecated and will be removed in " + "transformers v4.42.0" + ) offset = cache_position[0] else: offset = 0 @@ -1148,7 +1153,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Cache | List[torch.FloatTensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, @@ -1263,8 +1268,7 @@ def prepare_inputs_for_generation( # Keep only the unprocessed tokens: # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where - # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as - # input) + # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input) if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard @@ -1362,7 +1366,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Cache | List[torch.FloatTensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, @@ -1479,7 +1483,7 @@ def forward( input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Cache | List[torch.FloatTensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, start_positions: Optional[torch.LongTensor] = None, end_positions: Optional[torch.LongTensor] = None, diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py index 22d6759fe411..d5e5a90fd7bc 100644 --- a/tests/models/llama/test_modeling_llama.py +++ b/tests/models/llama/test_modeling_llama.py @@ -20,9 +20,8 @@ import pytest from parameterized import parameterized -from transformers import LlamaConfig, StaticCache, is_torch_available, logging, set_seed +from transformers import LlamaConfig, is_torch_available, set_seed from transformers.testing_utils import ( - CaptureLogger, require_bitsandbytes, require_flash_attn, require_read_token, @@ -699,7 +698,7 @@ def test_compile_static_cache(self): prompts = [ "Simply put, the theory of relativity states that ", - "My favorite all time favorite condiment is ketchup.", + "My favorite all time favorite condiment is ketchup.", ] tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", pad_token="", padding_side="right") model = LlamaForCausalLM.from_pretrained( @@ -744,7 +743,7 @@ def test_compile_repeated_calls(self): prompts = [ "Simply put, the theory of relativity states that ", - "My favorite all time favorite condiment is ketchup.", + "My favorite all time favorite condiment is ketchup.", ] tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", pad_token="", padding_side="right") model = LlamaForCausalLM.from_pretrained( @@ -772,6 +771,7 @@ def test_compile_repeated_calls(self): static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) self.assertEqual(EXPECTED_TEXT_COMPLETION, static_compiled_text) + @require_torch class CodeLlamaIntegrationTest(unittest.TestCase): PROMPTS = [ From 97aa785bac68c1bc01167b24420da8c9dec6c19d Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Fri, 26 Apr 2024 11:04:00 +0000 Subject: [PATCH 03/12] working on generate too; make fixup --- src/transformers/cache_utils.py | 7 ++ src/transformers/generation/utils.py | 41 +++++++--- .../models/cohere/modeling_cohere.py | 55 ++++---------- .../models/gemma/modeling_gemma.py | 36 ++++----- .../models/jamba/modeling_jamba.py | 4 +- .../models/mistral/modeling_mistral.py | 2 +- .../models/mixtral/modeling_mixtral.py | 2 +- src/transformers/models/olmo/modeling_olmo.py | 74 ++++++------------- .../models/persimmon/modeling_persimmon.py | 2 +- src/transformers/models/phi/modeling_phi.py | 2 +- .../models/qwen2_moe/modeling_qwen2_moe.py | 2 +- .../models/stablelm/modeling_stablelm.py | 2 +- .../models/starcoder2/modeling_starcoder2.py | 2 +- 13 files changed, 99 insertions(+), 132 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index ca465123d5bd..22b6a125ece3 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -422,3 +422,10 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: def get_max_length(self) -> Optional[int]: """Returns the maximum sequence length of the cached states.""" return self.max_cache_len + + def reset(self): + """Resets the cache values while preserving the objects""" + for layer_idx in range(len(self.key_cache)): + # In-place ops prevent breaking the static address + self.key_cache[layer_idx] *= 0.0 + self.value_cache[layer_idx] *= 0.0 diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 62eebaf0b1f6..1633e41021ae 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1310,6 +1310,34 @@ def _get_initial_cache_position(self, input_ids, model_kwargs): model_kwargs["cache_position"] = torch.arange(past_length, cur_len, device=input_ids.device) return model_kwargs + def _get_static_cache(self, max_batch_size: int, max_cache_len: int) -> StaticCache: + """ + Sets a static cache for `generate`, that will persist across calls. A new cache will only be initialized a + new `generate` call requires a larger cache. + + Returns the resulting static cache object. + """ + needs_new_cache = ( + not hasattr(self, "_static_cache") + or self._static_cache.max_batch_size < max_batch_size + or self._static_cache.max_cache_len < max_cache_len + ) + if needs_new_cache: + if hasattr(self.config, "_pre_quantization_dtype"): + cache_dtype = self.config._pre_quantization_dtype + else: + cache_dtype = self.dtype + self._static_cache = StaticCache( + config=self.config, + max_batch_size=max_batch_size, + max_cache_len=max_cache_len, + device=self.device, + dtype=cache_dtype, + ) + else: + self._static_cache.reset() # reset the cache for a new generation + return self._static_cache + @torch.no_grad() def generate( self, @@ -1526,18 +1554,7 @@ def generate( "issue: https://github.com/huggingface/transformers/issues/28981." ) if generation_config.cache_implementation == "static": - cache_cls = NEED_SETUP_CACHE_CLASSES_MAPPING["static"] - if hasattr(self.config, "_pre_quantization_dtype"): - cache_dtype = self.config._pre_quantization_dtype - else: - cache_dtype = self.dtype - model_kwargs["past_key_values"] = cache_cls( - config=self.config, - max_batch_size=batch_size, - max_cache_len=generation_config.max_length, - device=self.device, - dtype=cache_dtype, - ) + model_kwargs["past_key_values"] = self._get_static_cache(batch_size, generation_config.max_length) self._validate_generated_length(generation_config, input_ids_length, has_default_max_length) diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index 41bb4c051692..f1222a3081a2 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -732,27 +732,6 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _setup_cache(self, cache_cls, max_batch_size, max_cache_len: Optional[int] = None): - if self.config._attn_implementation == "flash_attention_2" and cache_cls == StaticCache: - raise ValueError( - "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " - "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" - ) - - for layer in self.model.layers: - device = layer.input_layernorm.weight.device - if hasattr(self.config, "_pre_quantization_dtype"): - dtype = self.config._pre_quantization_dtype - else: - dtype = layer.self_attn.o_proj.weight.dtype - layer.self_attn.past_key_value = cache_cls( - self.config, max_batch_size, max_cache_len, device=device, dtype=dtype - ) - - def _reset_cache(self): - for layer in self.model.layers: - layer.self_attn.past_key_value = None - COHERE_INPUTS_DOCSTRING = r""" Args: @@ -980,7 +959,7 @@ def _update_causal_mask( attention_mask: torch.Tensor, input_tensor: torch.Tensor, cache_position: torch.Tensor, - past_seen_tokens: int, + past_key_values: Cache, ): # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. @@ -992,9 +971,12 @@ def _update_causal_mask( return attention_mask return None - if self.config._attn_implementation == "sdpa": - # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, - # in order to dispatch on Flash Attention 2. + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + if self.config._attn_implementation == "sdpa" and not using_static_cache: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens ): @@ -1003,9 +985,9 @@ def _update_causal_mask( dtype, device = input_tensor.dtype, input_tensor.device min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] - if hasattr(getattr(self.layers[0], "self_attn", {}), "past_key_value"): # static cache - target_length = self.config.max_position_embeddings - else: # dynamic cache + if using_static_cache: + target_length = past_key_values.get_max_length() + else: target_length = ( attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) @@ -1027,6 +1009,10 @@ def _update_causal_mask( # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with # cache. In that case, the 4D attention mask attends to the newest tokens only. if attention_mask.shape[-2] < cache_position[0] + sequence_length: + logger.warning_once( + "Passing a 4d mask shorter than the input length is deprecated and will be removed in " + "transformers v4.42.0" + ) offset = cache_position[0] else: offset = 0 @@ -1184,13 +1170,6 @@ def prepare_inputs_for_generation( use_cache=True, **kwargs, ): - # With static cache, the `past_key_values` is None - # TODO joao: standardize interface for the different Cache classes and remove of this if - has_static_cache = False - if past_key_values is None: - past_key_values = getattr(getattr(self.model.layers[0], "self_attn", {}), "past_key_value", None) - has_static_cache = past_key_values is not None - past_length = 0 if past_key_values is not None: if isinstance(past_key_values, Cache): @@ -1208,8 +1187,7 @@ def prepare_inputs_for_generation( # Keep only the unprocessed tokens: # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where - # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as - # input) + # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input) if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard @@ -1249,9 +1227,6 @@ def prepare_inputs_for_generation( elif use_cache: cache_position = cache_position[-input_length:] - if has_static_cache: - past_key_values = None - model_inputs.update( { "position_ids": position_ids, diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index e5b6b207748a..ff0755794db4 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -966,7 +966,7 @@ def _update_causal_mask( attention_mask: torch.Tensor, input_tensor: torch.Tensor, cache_position: torch.Tensor, - past_seen_tokens: int, + past_key_values: Cache, ): # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. @@ -978,9 +978,12 @@ def _update_causal_mask( return attention_mask return None - if self.config._attn_implementation == "sdpa": - # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, - # in order to dispatch on Flash Attention 2. + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + if self.config._attn_implementation == "sdpa" and not using_static_cache: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens ): @@ -989,9 +992,9 @@ def _update_causal_mask( dtype, device = input_tensor.dtype, input_tensor.device min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] - if hasattr(getattr(self.layers[0], "self_attn", {}), "past_key_value"): # static cache - target_length = self.config.max_position_embeddings - else: # dynamic cache + if using_static_cache: + target_length = past_key_values.get_max_length() + else: target_length = ( attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) @@ -1013,6 +1016,10 @@ def _update_causal_mask( # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with # cache. In that case, the 4D attention mask attends to the newest tokens only. if attention_mask.shape[-2] < cache_position[0] + sequence_length: + logger.warning_once( + "Passing a 4d mask shorter than the input length is deprecated and will be removed in " + "transformers v4.42.0" + ) offset = cache_position[0] else: offset = 0 @@ -1166,13 +1173,6 @@ def prepare_inputs_for_generation( use_cache=True, **kwargs, ): - # With static cache, the `past_key_values` is None - # TODO joao: standardize interface for the different Cache classes and remove of this if - has_static_cache = False - if past_key_values is None: - past_key_values = getattr(getattr(self.model.layers[0], "self_attn", {}), "past_key_value", None) - has_static_cache = past_key_values is not None - past_length = 0 if past_key_values is not None: if isinstance(past_key_values, Cache): @@ -1190,8 +1190,7 @@ def prepare_inputs_for_generation( # Keep only the unprocessed tokens: # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where - # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as - # input) + # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input) if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard @@ -1231,9 +1230,6 @@ def prepare_inputs_for_generation( elif use_cache: cache_position = cache_position[-input_length:] - if has_static_cache: - past_key_values = None - model_inputs.update( { "position_ids": position_ids, @@ -1293,7 +1289,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py index 80d5dad3cbd8..1dbcbc76f3c2 100755 --- a/src/transformers/models/jamba/modeling_jamba.py +++ b/src/transformers/models/jamba/modeling_jamba.py @@ -29,7 +29,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...cache_utils import DynamicCache # we need __iter__ and __len__ of pkv +from ...cache_utils import Cache, DynamicCache # we need __iter__ and __len__ of pkv from ...modeling_attn_mask_utils import ( AttentionMaskConverter, ) @@ -1807,7 +1807,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index c013967c78f1..665e95a8fd78 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -1301,7 +1301,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index c78e907d5fdb..e5a81c4c9083 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -1525,7 +1525,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index e3b0e05127c5..c5f94695f19e 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -790,27 +790,6 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _setup_cache(self, cache_cls, max_batch_size, max_cache_len: Optional[int] = None): - if self.config._attn_implementation == "flash_attention_2" and cache_cls == StaticCache: - raise ValueError( - "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " - "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" - ) - - for layer in self.model.layers: - device = layer.input_layernorm.weight.device - if hasattr(self.config, "_pre_quantization_dtype"): - dtype = self.config._pre_quantization_dtype - else: - dtype = layer.self_attn.o_proj.weight.dtype - layer.self_attn.past_key_value = cache_cls( - self.config, max_batch_size, max_cache_len, device=device, dtype=dtype - ) - - def _reset_cache(self): - for layer in self.model.layers: - layer.self_attn.past_key_value = None - OLMO_INPUTS_DOCSTRING = r""" Args: @@ -926,7 +905,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, @@ -955,23 +934,18 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - past_seen_tokens = 0 - if use_cache: # kept for BC (cache positions) - if not isinstance(past_key_values, StaticCache): - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - past_seen_tokens = past_key_values.get_seq_length() + if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs) + past_key_values = DynamicCache.from_legacy_cache(past_key_values) if cache_position is None: - if isinstance(past_key_values, StaticCache): - raise ValueError("cache_position is a required argument when using StaticCache.") + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) - if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_seen_tokens) + causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values) # embed positions hidden_states = inputs_embeds @@ -1024,7 +998,9 @@ def forward( next_cache = None if use_cache: next_cache = ( - next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache + next_decoder_cache.to_legacy_cache() + if isinstance(next_decoder_cache, DynamicCache) + else next_decoder_cache ) if not return_dict: return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) @@ -1041,7 +1017,7 @@ def _update_causal_mask( attention_mask: torch.Tensor, input_tensor: torch.Tensor, cache_position: torch.Tensor, - past_seen_tokens: int, + past_key_values: Cache, ): # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. @@ -1053,9 +1029,12 @@ def _update_causal_mask( return attention_mask return None - if self.config._attn_implementation == "sdpa": - # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, - # in order to dispatch on Flash Attention 2. + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + if self.config._attn_implementation == "sdpa" and not using_static_cache: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens ): @@ -1064,9 +1043,9 @@ def _update_causal_mask( dtype, device = input_tensor.dtype, input_tensor.device min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] - if hasattr(getattr(self.layers[0], "self_attn", {}), "past_key_value"): # static cache - target_length = self.config.max_position_embeddings - else: # dynamic cache + if using_static_cache: + target_length = past_key_values.get_max_length() + else: target_length = ( attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) @@ -1088,6 +1067,10 @@ def _update_causal_mask( # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with # cache. In that case, the 4D attention mask attends to the newest tokens only. if attention_mask.shape[-2] < cache_position[0] + sequence_length: + logger.warning_once( + "Passing a 4d mask shorter than the input length is deprecated and will be removed in " + "transformers v4.42.0" + ) offset = cache_position[0] else: offset = 0 @@ -1243,13 +1226,6 @@ def prepare_inputs_for_generation( use_cache=True, **kwargs, ): - # With static cache, the `past_key_values` is None - # TODO joao: standardize interface for the different Cache classes and remove of this if - has_static_cache = False - if past_key_values is None: - past_key_values = getattr(getattr(self.model.layers[0], "self_attn", {}), "past_key_value", None) - has_static_cache = past_key_values is not None - past_length = 0 if past_key_values is not None: if isinstance(past_key_values, Cache): @@ -1267,8 +1243,7 @@ def prepare_inputs_for_generation( # Keep only the unprocessed tokens: # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where - # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as - # input) + # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input) if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard @@ -1308,9 +1283,6 @@ def prepare_inputs_for_generation( elif use_cache: cache_position = cache_position[-input_length:] - if has_static_cache: - past_key_values = None - model_inputs.update( { "position_ids": position_ids, diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index c83ba413952b..8d4ad532074f 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -927,7 +927,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 13719166edf9..b23073d332e4 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -1313,7 +1313,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index 70072c91720a..ca349dca1c1b 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -1509,7 +1509,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py index 3262f2cd3c61..bc133ffb3d73 100755 --- a/src/transformers/models/stablelm/modeling_stablelm.py +++ b/src/transformers/models/stablelm/modeling_stablelm.py @@ -1299,7 +1299,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index ca4c8af23304..61e8518d659c 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -1292,7 +1292,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, From a6853a192d8abda809f2ff57bfa6e7d19d6d0eba Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Fri, 26 Apr 2024 11:26:49 +0000 Subject: [PATCH 04/12] dbrx --- src/transformers/models/dbrx/modeling_dbrx.py | 101 +++++++++--------- .../models/gemma/modeling_gemma.py | 2 +- .../models/llama/modeling_llama.py | 2 +- src/transformers/models/olmo/modeling_olmo.py | 2 +- 4 files changed, 51 insertions(+), 56 deletions(-) diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py index 99b865c773f8..8100807bef7e 100644 --- a/src/transformers/models/dbrx/modeling_dbrx.py +++ b/src/transformers/models/dbrx/modeling_dbrx.py @@ -354,6 +354,11 @@ def forward( cache_position: Optional[torch.LongTensor] = None, **kwargs: Any, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if isinstance(past_key_value, StaticCache): + raise ValueError( + "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " + "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" + ) logger.info("Implicitly setting `output_attentions` to False as it is not supported in Flash Attention.") output_attentions = False @@ -957,28 +962,6 @@ def _init_weights(self, module: nn.Module): module.v1.data.normal_(mean=0.0, std=std) module.w2.data.normal_(mean=0.0, std=std) - def _setup_cache(self, cache_cls: Any, max_batch_size: int, max_cache_len: int): - if self.config._attn_implementation == "flash_attention_2" and cache_cls == StaticCache: - raise ValueError( - "`static` cache implementation is not compatible with " - + "`attn_implementation==flash_attention_2`. Make sure to use " - + "`spda` in the mean time and open an issue at https://github.com/huggingface/transformers." - ) - - for block in self.transformer.blocks: - device = block.norm_attn_norm.norm_1.weight.device - if hasattr(self.config, "_pre_quantization_dtype"): - dtype = self.config._pre_quantization_dtype - else: - dtype = block.norm_attn_norm.attn.out_proj.weight.dtype - block.norm_attn_norm.attn.past_key_value = cache_cls( - self.config, max_batch_size, max_cache_len, device=device, dtype=dtype - ) - - def _reset_cache(self): - for block in self.transformer.blocks: - block.norm_attn_norm.attn.past_key_value = None - DBRX_INPUTS_DOCSTRING = r""" Args: @@ -1131,22 +1114,18 @@ def forward( inputs_embeds = nn.functional.dropout(inputs_embeds, p=self.emb_pdrop, training=self.training) - past_seen_tokens = 0 - if use_cache: # kept for BC (cache positions) - if not isinstance(past_key_values, StaticCache): - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - past_seen_tokens = past_key_values.get_seq_length() + if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs) + past_key_values = DynamicCache.from_legacy_cache(past_key_values) if cache_position is None: - if isinstance(past_key_values, StaticCache): - raise ValueError("cache_position is a required argument when using StaticCache.") + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position) + causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values) # embed positions hidden_states = inputs_embeds @@ -1205,7 +1184,9 @@ def forward( next_cache = None if use_cache: next_cache = ( - next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache + next_decoder_cache.to_legacy_cache() + if isinstance(next_decoder_cache, DynamicCache) + else next_decoder_cache ) if not return_dict: return tuple( @@ -1221,28 +1202,45 @@ def forward( router_logits=all_router_logits, ) - # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static - # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. - # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using - # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 + # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask def _update_causal_mask( - self, attention_mask: Optional[torch.Tensor], input_tensor: torch.Tensor, cache_position: torch.Tensor - ) -> Optional[torch.Tensor]: + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + ): + # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static + # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. + # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using + # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + if self.config._attn_implementation == "sdpa" and not using_static_cache: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens + ): + return None + dtype, device = input_tensor.dtype, input_tensor.device min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] - if hasattr(self.blocks[0].norm_attn_norm.attn, "past_key_value"): # static cache - target_length = self.config.max_position_embeddings - else: # dynamic cache + if using_static_cache: + target_length = past_key_values.get_max_length() + else: target_length = ( - attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else cache_position[-1] + 1 + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 ) - target_length = int(target_length) causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) if sequence_length != 1: @@ -1259,6 +1257,10 @@ def _update_causal_mask( # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with # cache. In that case, the 4D attention mask attends to the newest tokens only. if attention_mask.shape[-2] < cache_position[0] + sequence_length: + logger.warning_once( + "Passing a 4d mask shorter than the input length is deprecated and will be removed in " + "transformers v4.42.0" + ) offset = cache_position[0] else: offset = 0 @@ -1273,17 +1275,10 @@ def _update_causal_mask( and attention_mask is not None and attention_mask.device.type == "cuda" ): - # TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400). - is_tracing = ( - torch.jit.is_tracing() - or isinstance(input_tensor, torch.fx.Proxy) - or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling()) - ) - if not is_tracing and torch.any(attention_mask != 1): - # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when - # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. - # Details: https://github.com/pytorch/pytorch/issues/110213 - causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) return causal_mask diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index ff0755794db4..9f8489289d70 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -613,7 +613,7 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 2b75b2b9bfd7..42000824ce34 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -714,7 +714,7 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index c5f94695f19e..7c3370ce7289 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -688,7 +688,7 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, From 326a1c74385006d5ae91b63dc1c88c7804a5df2f Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Fri, 26 Apr 2024 11:34:20 +0000 Subject: [PATCH 05/12] missing in dbrx --- src/transformers/models/dbrx/modeling_dbrx.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py index 8100807bef7e..da5e01ecc1d3 100644 --- a/src/transformers/models/dbrx/modeling_dbrx.py +++ b/src/transformers/models/dbrx/modeling_dbrx.py @@ -627,6 +627,7 @@ def forward( value_states, attn_mask=causal_mask, dropout_p=self.attn_pdrop if self.training else 0.0, + is_causal=causal_mask is None and q_len > 1, ) attn_output = attn_output.transpose(1, 2).contiguous() From 9de9556634b17eb3992ac5458e28cdb0e13b80e3 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Fri, 26 Apr 2024 12:07:08 +0000 Subject: [PATCH 06/12] more dbrx details; gemma --- src/transformers/models/dbrx/modeling_dbrx.py | 55 +++++++++---------- .../models/gemma/modeling_gemma.py | 40 +++++--------- 2 files changed, 40 insertions(+), 55 deletions(-) diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py index da5e01ecc1d3..7c2e6abbca97 100644 --- a/src/transformers/models/dbrx/modeling_dbrx.py +++ b/src/transformers/models/dbrx/modeling_dbrx.py @@ -15,7 +15,7 @@ """ PyTorch DBRX model. """ import math -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Optional, Tuple, Union import torch import torch.nn.functional as F @@ -1427,28 +1427,35 @@ def forward( router_logits=outputs.router_logits, ) + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation def prepare_inputs_for_generation( self, - input_ids: torch.Tensor, - past_key_values: Optional[Cache] = None, - attention_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - **kwargs: Any, - ) -> Dict[str, Any]: + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + use_cache=True, + **kwargs, + ): past_length = 0 if past_key_values is not None: if isinstance(past_key_values, Cache): - cache_length = past_key_values.get_seq_length() - past_length = past_key_values.seen_tokens - max_cache_length = past_key_values.get_max_length() + past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() + max_cache_length = ( + torch.tensor(past_key_values.get_max_length(), device=input_ids.device) + if past_key_values.get_max_length() is not None + else None + ) + cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length) + # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects else: cache_length = past_length = past_key_values[0][0].shape[2] max_cache_length = None # Keep only the unprocessed tokens: # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where - # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as - # input) + # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input) if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard @@ -1473,22 +1480,6 @@ def prepare_inputs_for_generation( if past_key_values: position_ids = position_ids[:, -input_ids.shape[1] :] - if self.generation_config.cache_implementation == "static": - # generation with static cache - cache_position = kwargs.get("cache_position", None) - if cache_position is None: - past_length = 0 - else: - past_length = cache_position[-1] + 1 - input_ids = input_ids[:, past_length:] - position_ids = position_ids[:, past_length:] if position_ids is not None else None - - # TODO @gante we should only keep a `cache_position` in generate, and do +=1. - # same goes for position ids. Could also help with continued generation. - input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1] - cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device) - position_ids = position_ids.contiguous() if position_ids is not None else None - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: model_inputs = {"inputs_embeds": inputs_embeds} @@ -1498,12 +1489,18 @@ def prepare_inputs_for_generation( # TODO: use `next_tokens` directly instead. model_inputs = {"input_ids": input_ids.contiguous()} + input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1] + if cache_position is None: + cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device) + elif use_cache: + cache_position = cache_position[-input_length:] + model_inputs.update( { "position_ids": position_ids, "cache_position": cache_position, "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), + "use_cache": use_cache, "attention_mask": attention_mask, } ) diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 9f8489289d70..6a5cb365212a 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -332,6 +332,11 @@ def forward( cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if isinstance(past_key_value, StaticCache): + raise ValueError( + "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " + "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" + ) output_attentions = False bsz, q_len, _ = hidden_states.size() @@ -715,23 +720,6 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _setup_cache(self, cache_cls, max_batch_size, max_cache_len: Optional[int] = None): - if self.config._attn_implementation == "flash_attention_2" and cache_cls == StaticCache: - raise ValueError( - "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " - "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" - ) - - for layer in self.model.layers: - weights = layer.self_attn.o_proj.weight - layer.self_attn.past_key_value = cache_cls( - self.config, max_batch_size, max_cache_len, device=weights.device, dtype=weights.dtype - ) - - def _reset_cache(self): - for layer in self.model.layers: - layer.self_attn.past_key_value = None - GEMMA_INPUTS_DOCSTRING = r""" Args: @@ -848,7 +836,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, @@ -877,13 +865,11 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - past_seen_tokens = 0 - if use_cache: # kept for BC (cache positions) - if not isinstance(past_key_values, StaticCache): - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - past_seen_tokens = past_key_values.get_seq_length() + if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs) + past_key_values = DynamicCache.from_legacy_cache(past_key_values) if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) @@ -891,7 +877,7 @@ def forward( if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_seen_tokens) + causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values) # embed positions hidden_states = inputs_embeds @@ -950,7 +936,9 @@ def forward( next_cache = None if use_cache: next_cache = ( - next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache + next_decoder_cache.to_legacy_cache() + if isinstance(next_decoder_cache, DynamicCache) + else next_decoder_cache ) if not return_dict: return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) @@ -1081,7 +1069,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, From 90771df1c061b9ab8981792158ea42014a9e0483 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Fri, 26 Apr 2024 12:47:48 +0000 Subject: [PATCH 07/12] tmp --- docs/source/en/llm_optims.md | 68 +++++++++++-------- .../models/cohere/modeling_cohere.py | 16 +++-- .../aqlm_integration/test_aqlm.py | 16 +++-- 3 files changed, 61 insertions(+), 39 deletions(-) diff --git a/docs/source/en/llm_optims.md b/docs/source/en/llm_optims.md index f1dc6d5f23ce..d56278c7720b 100644 --- a/docs/source/en/llm_optims.md +++ b/docs/source/en/llm_optims.md @@ -65,13 +65,12 @@ tokenizer.batch_decode(outputs, skip_special_tokens=True) ['The theory of special relativity states 1. The speed of light is constant in all inertial reference'] ``` - - +Under the hood, `generate` will attempt to reuse the same cache object, removing the need for re-compilation at each call. However, if the batch size or the maximum output length increase between calls, the cache will have to be reinitialized, triggering a new compilation. -> [!WARNING] -> The `_setup_cache` method is an internal and private method that is still under development. This means it may not be backward compatible and the API design may change in the future. + + -The `_setup_cache` method doesn't support [`~GenerationMixin.generate`] yet, so this method is a bit more involved. You'll need to write your own function to decode the next token given the current token and position and cache position of previously generated tokens. +A [`StaticCache`] object can be passed to the model's forward pass under the `past_key_values` argument, enabling the use of this object as a static kv-cache. Using this strategy, you'll need to write your own function to decode the next token given the current token and position and cache position of previously generated tokens. ```py from transformers import LlamaTokenizer, LlamaForCausalLM, StaticCache, logging @@ -90,17 +89,22 @@ tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", pad_token model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", device_map="sequential") inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device) -def decode_one_tokens(model, cur_token, input_pos, cache_position): +def decode_one_tokens(model, cur_token, input_pos, cache_position, past_key_values): logits = model( - cur_token, position_ids=input_pos, cache_position=cache_position, return_dict=False, use_cache=True + cur_token, + position_ids=input_pos, + cache_position=cache_position, + past_key_values=past_key_values, + return_dict=False, + use_cache=True )[0] new_token = torch.argmax(logits[:, -1], dim=-1)[:, None] return new_token ``` -There are a few important things you must do to enable static kv-cache and torch.compile with the `_setup_cache` method: +There are a few important things you must do to enable static kv-cache and torch.compile with the `StaticCache` method: -1. Access the model's `_setup_cache` method and pass it the [`StaticCache`] class. This is a more flexible method because it allows you to configure parameters like the maximum batch size and sequence length. +1. Initialize the [`StaticCache`] instance before using the model for inference. There you can configure parameters like the maximum batch size and sequence length. 2. Call torch.compile on the model to compile the forward pass with the static kv-cache. @@ -109,24 +113,28 @@ There are a few important things you must do to enable static kv-cache and torch ```py batch_size, seq_length = inputs["input_ids"].shape with torch.no_grad(): - model._setup_cache(StaticCache, 2, max_cache_len=4096) - cache_position = torch.arange(seq_length, device=torch_device) - generated_ids = torch.zeros( - batch_size, seq_length + NUM_TOKENS_TO_GENERATE + 1, dtype=torch.int, device=torch_device - ) - generated_ids[:, cache_position] = inputs["input_ids"].to(torch_device).to(torch.int) - - logits = model(**inputs, cache_position=cache_position, return_dict=False, use_cache=True)[0] - next_token = torch.argmax(logits[:, -1], dim=-1)[:, None] - generated_ids[:, seq_length] = next_token[:, 0] - - decode_one_tokens = torch.compile(decode_one_tokens, mode="reduce-overhead", fullgraph=True) - cache_position = torch.tensor([seq_length + 1], device=torch_device) - for _ in range(1, NUM_TOKENS_TO_GENERATE): - with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): - next_token = decode_one_tokens(model, next_token.clone(), None, cache_position) - generated_ids[:, cache_position] = next_token.int() - cache_position += 1 + past_key_values = StaticCache( + config=model.config, max_batch_size=2, max_cache_len=4096, device=torch_device, dtype=model.dtype + ) + cache_position = torch.arange(seq_length, device=torch_device) + generated_ids = torch.zeros( + batch_size, seq_length + NUM_TOKENS_TO_GENERATE + 1, dtype=torch.int, device=torch_device + ) + generated_ids[:, cache_position] = inputs["input_ids"].to(torch_device).to(torch.int) + + logits = model( + **inputs, cache_position=cache_position, past_key_values=past_key_values,return_dict=False, use_cache=True + )[0] + next_token = torch.argmax(logits[:, -1], dim=-1)[:, None] + generated_ids[:, seq_length] = next_token[:, 0] + + decode_one_tokens = torch.compile(decode_one_tokens, mode="reduce-overhead", fullgraph=True) + cache_position = torch.tensor([seq_length + 1], device=torch_device) + for _ in range(1, NUM_TOKENS_TO_GENERATE): + with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): + next_token = decode_one_tokens(model, next_token.clone(), None, cache_position, past_key_values) + generated_ids[:, cache_position] = next_token.int() + cache_position += 1 text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) text @@ -134,6 +142,12 @@ text 'My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p'] ``` +Please note that the cache has to be manually reset if you want to repeat this process multiple times reusing the same cache object. + +```py +past_key_values.reset() # Clears the cache's contents without destroying the objects +``` + diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index f1222a3081a2..f20edf17cf39 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -340,6 +340,11 @@ def forward( cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if isinstance(past_key_value, StaticCache): + raise ValueError( + "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " + "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" + ) output_attentions = False bsz, q_len, _ = hidden_states.size() @@ -875,14 +880,11 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) past_seen_tokens = 0 - if use_cache: # kept for BC (cache positions) - if not isinstance(past_key_values, StaticCache): - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - past_seen_tokens = past_key_values.get_seq_length() + if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs) + past_key_values = DynamicCache.from_legacy_cache(past_key_values) if cache_position is None: - if isinstance(past_key_values, StaticCache): - raise ValueError("cache_position is a required argument when using StaticCache.") + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) @@ -890,7 +892,7 @@ def forward( if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_seen_tokens) + causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values) # embed positions hidden_states = inputs_embeds diff --git a/tests/quantization/aqlm_integration/test_aqlm.py b/tests/quantization/aqlm_integration/test_aqlm.py index 46b64573b938..fbd12fd0a32a 100644 --- a/tests/quantization/aqlm_integration/test_aqlm.py +++ b/tests/quantization/aqlm_integration/test_aqlm.py @@ -196,9 +196,9 @@ def test_quantized_model_compile(self): """ # Sample tokens greedily - def decode_one_tokens(model, cur_token, input_pos, cache_position): + def decode_one_tokens(model, cur_token, input_pos, cache_position, past_key_values): logits = model( - cur_token, position_ids=input_pos, cache_position=cache_position, return_dict=False, use_cache=True + cur_token, position_ids=input_pos, cache_position=cache_position, past_key_values=past_key_values, return_dict=False, use_cache=True )[0] new_token = torch.argmax(logits[:, [-1]], dim=-1).to(torch.int) @@ -209,7 +209,13 @@ def decode_one_tokens(model, cur_token, input_pos, cache_position): seq_length = input_ids.shape[1] # Setup static KV cache for generation - self.quantized_model._setup_cache(StaticCache, 1, max_cache_len=seq_length + self.max_new_tokens + 1) + if hasattr(self.quantized_model.config, "_pre_quantization_dtype"): + cache_dtype = self.quantized_model.config._pre_quantization_dtype + else: + cache_dtype = self.quantized_model.dtype + past_key_values = StaticCache( + config=self.quantized_model.config, max_batch_size=2, max_cache_len=seq_length + self.max_new_tokens + 1, device=torch_device, dtype=cache_dtype + ) # Allocate token ids to be generated and copy prefix ids cache_position = torch.arange(seq_length, device=torch_device) @@ -217,7 +223,7 @@ def decode_one_tokens(model, cur_token, input_pos, cache_position): generated_ids[:, cache_position] = input_ids.to(torch_device).to(torch.int) # Do a forward pass to fill the prefix cache and compile the kernels if necessary - logits = self.quantized_model(input_ids, cache_position=cache_position, return_dict=False, use_cache=True)[0] + logits = self.quantized_model(input_ids, cache_position=cache_position, past_key_values=past_key_values, return_dict=False, use_cache=True)[0] next_token = torch.argmax(logits[:, [-1]], dim=-1).to(torch.int) generated_ids[:, [seq_length]] = next_token @@ -229,7 +235,7 @@ def decode_one_tokens(model, cur_token, input_pos, cache_position): cache_position = torch.tensor([seq_length + 1], device=torch_device) for _ in range(1, self.max_new_tokens): with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): - next_token = decode_one_tokens(self.quantized_model, next_token.clone(), None, cache_position) + next_token = decode_one_tokens(self.quantized_model, next_token.clone(), None, cache_position, past_key_values) generated_ids.index_copy_(1, cache_position, next_token) cache_position += 1 From cd6aecd7c288088c097b61c199e7bdd188dff185 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Fri, 26 Apr 2024 13:06:59 +0000 Subject: [PATCH 08/12] finalized mvp --- docs/source/en/llm_optims.md | 9 ++---- .../aqlm_integration/test_aqlm.py | 29 ++++++++++++++----- 2 files changed, 24 insertions(+), 14 deletions(-) diff --git a/docs/source/en/llm_optims.md b/docs/source/en/llm_optims.md index d56278c7720b..4b44c1d78c81 100644 --- a/docs/source/en/llm_optims.md +++ b/docs/source/en/llm_optims.md @@ -70,7 +70,7 @@ Under the hood, `generate` will attempt to reuse the same cache object, removing -A [`StaticCache`] object can be passed to the model's forward pass under the `past_key_values` argument, enabling the use of this object as a static kv-cache. Using this strategy, you'll need to write your own function to decode the next token given the current token and position and cache position of previously generated tokens. +A [`StaticCache`] object can be passed to the model's forward pass under the `past_key_values` argument, enabling the use of this object as a static kv-cache. Using this strategy, you can write your own function to decode the next token given the current token and position and cache position of previously generated tokens. You can also pass the [`StaticCache`] object to [`~GenerationMixin.generate`] and use it across calls, like you would do with a dynamic cache. ```py from transformers import LlamaTokenizer, LlamaForCausalLM, StaticCache, logging @@ -142,11 +142,8 @@ text 'My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p'] ``` -Please note that the cache has to be manually reset if you want to repeat this process multiple times reusing the same cache object. - -```py -past_key_values.reset() # Clears the cache's contents without destroying the objects -``` +> [!TIP] +> If you want to reuse the [`StaticCache`] object on a new prompt, be sure to reset its contents with the `.reset()` method diff --git a/tests/quantization/aqlm_integration/test_aqlm.py b/tests/quantization/aqlm_integration/test_aqlm.py index fbd12fd0a32a..3b0dd99adcd9 100644 --- a/tests/quantization/aqlm_integration/test_aqlm.py +++ b/tests/quantization/aqlm_integration/test_aqlm.py @@ -198,7 +198,12 @@ def test_quantized_model_compile(self): # Sample tokens greedily def decode_one_tokens(model, cur_token, input_pos, cache_position, past_key_values): logits = model( - cur_token, position_ids=input_pos, cache_position=cache_position, past_key_values=past_key_values, return_dict=False, use_cache=True + cur_token, + position_ids=input_pos, + cache_position=cache_position, + past_key_values=past_key_values, + return_dict=False, + use_cache=True, )[0] new_token = torch.argmax(logits[:, [-1]], dim=-1).to(torch.int) @@ -209,12 +214,12 @@ def decode_one_tokens(model, cur_token, input_pos, cache_position, past_key_valu seq_length = input_ids.shape[1] # Setup static KV cache for generation - if hasattr(self.quantized_model.config, "_pre_quantization_dtype"): - cache_dtype = self.quantized_model.config._pre_quantization_dtype - else: - cache_dtype = self.quantized_model.dtype past_key_values = StaticCache( - config=self.quantized_model.config, max_batch_size=2, max_cache_len=seq_length + self.max_new_tokens + 1, device=torch_device, dtype=cache_dtype + config=self.quantized_model.config, + max_batch_size=1, + max_cache_len=seq_length + self.max_new_tokens + 1, + device=torch_device, + dtype=self.quantized_model.config._pre_quantization_dtype, ) # Allocate token ids to be generated and copy prefix ids @@ -223,7 +228,13 @@ def decode_one_tokens(model, cur_token, input_pos, cache_position, past_key_valu generated_ids[:, cache_position] = input_ids.to(torch_device).to(torch.int) # Do a forward pass to fill the prefix cache and compile the kernels if necessary - logits = self.quantized_model(input_ids, cache_position=cache_position, past_key_values=past_key_values, return_dict=False, use_cache=True)[0] + logits = self.quantized_model( + input_ids, + cache_position=cache_position, + past_key_values=past_key_values, + return_dict=False, + use_cache=True, + )[0] next_token = torch.argmax(logits[:, [-1]], dim=-1).to(torch.int) generated_ids[:, [seq_length]] = next_token @@ -235,7 +246,9 @@ def decode_one_tokens(model, cur_token, input_pos, cache_position, past_key_valu cache_position = torch.tensor([seq_length + 1], device=torch_device) for _ in range(1, self.max_new_tokens): with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): - next_token = decode_one_tokens(self.quantized_model, next_token.clone(), None, cache_position, past_key_values) + next_token = decode_one_tokens( + self.quantized_model, next_token.clone(), None, cache_position, past_key_values + ) generated_ids.index_copy_(1, cache_position, next_token) cache_position += 1 From b8caa5e1d74f7f653ac5070d03858d9d965f4e6c Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Mon, 29 Apr 2024 11:45:36 +0000 Subject: [PATCH 09/12] test for different cuda compute capabilities --- src/transformers/models/phi3/modeling_phi3.py | 2 +- tests/models/llama/test_modeling_llama.py | 76 +++++-------------- 2 files changed, 21 insertions(+), 57 deletions(-) diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index f9364d130b7e..530c22a87449 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -1419,7 +1419,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py index d5e5a90fd7bc..b18031c912eb 100644 --- a/tests/models/llama/test_modeling_llama.py +++ b/tests/models/llama/test_modeling_llama.py @@ -687,59 +687,23 @@ def test_compile_static_cache(self): # work as intended. See https://github.com/pytorch/pytorch/issues/121943 NUM_TOKENS_TO_GENERATE = 40 # Note on `EXPECTED_TEXT_COMPLETION`'s diff: the current value matches the original test if the original test - # was changed to have a cache of 53 tokens (as opposed to 4096). - EXPECTED_TEXT_COMPLETION = [ - "Simply put, the theory of relativity states that 1) the speed of light is constant in all inertial " - "reference frames, and 2) the laws of physics are the same for all inertial reference frames.\nThe theory " - "of relativ", - "My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, my " - "fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p", - ] - - prompts = [ - "Simply put, the theory of relativity states that ", - "My favorite all time favorite condiment is ketchup.", - ] - tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", pad_token="", padding_side="right") - model = LlamaForCausalLM.from_pretrained( - "meta-llama/Llama-2-7b-hf", device_map="sequential", torch_dtype=torch.float16 - ) - inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device) - - # Dynamic Cache - generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False) - dynamic_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) - self.assertEqual(EXPECTED_TEXT_COMPLETION, dynamic_text) - - # Static Cache - generated_ids = model.generate( - **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static" - ) - static_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) - self.assertEqual(EXPECTED_TEXT_COMPLETION, static_text) - - # Static Cache + compile - model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True) - generated_ids = model.generate( - **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static" - ) - static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) - self.assertEqual(EXPECTED_TEXT_COMPLETION, static_compiled_text) - - @slow - @require_torch_gpu - @require_read_token - def test_compile_repeated_calls(self): - # `torch==2.2` will throw an error on this test (as in other compilation tests), but torch==2.1.2 and torch>2.2 - # work as intended. See https://github.com/pytorch/pytorch/issues/121943 - NUM_TOKENS_TO_GENERATE = 40 - EXPECTED_TEXT_COMPLETION = [ - "Simply put, the theory of relativity states that 1) the speed of light is constant in all inertial " - "reference frames, and 2) the laws of physics are the same for all inertial reference frames.\nThe theory " - "of relativ", - "My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, my " - "fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p", - ] + # was changed to have a cache of 53 tokens (as opposed to 4096), on Ampere GPUs. + EXPECTED_TEXT_COMPLETION = { + 8: [ + "Simply put, the theory of relativity states that 1) the speed of light is constant in all inertial " + "reference frames, and 2) the laws of physics are the same for all inertial reference frames.\nThe " + "theory of relativ", + "My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, " + "my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p", + ], + 7: [ + "Simply put, the theory of relativity states that 1. surely nothing is faster than light.\nThe theory " + "goes that nothing travels faster than light, but the faster you go, the slower everything else will " + "be.\nThe theory of relativity", + "My favorite all time favorite condiment is ketchup. I love it on hamburgers, hot dogs, fries, eggs, " + "and even on a good old fashioned cheeseburger. I love it on everything. I love it so", + ], + } prompts = [ "Simply put, the theory of relativity states that ", @@ -754,14 +718,14 @@ def test_compile_repeated_calls(self): # Dynamic Cache generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False) dynamic_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) - self.assertEqual(EXPECTED_TEXT_COMPLETION, dynamic_text) + self.assertEqual(EXPECTED_TEXT_COMPLETION[8], dynamic_text) # Both GPU architectures have the same output # Static Cache generated_ids = model.generate( **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static" ) static_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) - self.assertEqual(EXPECTED_TEXT_COMPLETION, static_text) + self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_text) # Static Cache + compile model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True) @@ -769,7 +733,7 @@ def test_compile_repeated_calls(self): **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static" ) static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) - self.assertEqual(EXPECTED_TEXT_COMPLETION, static_compiled_text) + self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_compiled_text) @require_torch From 16e01f2d90af91c9a541b8341a0eeb23eb207582 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Mon, 29 Apr 2024 11:49:43 +0000 Subject: [PATCH 10/12] missing doc entry --- docs/source/en/internal/generation_utils.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/en/internal/generation_utils.md b/docs/source/en/internal/generation_utils.md index 7270af049c32..e6872efe7308 100644 --- a/docs/source/en/internal/generation_utils.md +++ b/docs/source/en/internal/generation_utils.md @@ -362,3 +362,4 @@ A [`Constraint`] can be used to force the generation to include specific tokens [[autodoc]] StaticCache - update - get_seq_length + - reorder_cache From 2d1eab8371905e286bbea1734cc80b2681fea2ce Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Tue, 30 Apr 2024 14:47:39 +0000 Subject: [PATCH 11/12] tmp commit --- src/transformers/cache_utils.py | 10 ++++++---- src/transformers/models/llama/modeling_llama.py | 2 +- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 22b6a125ece3..ceca9d3eeb35 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -44,6 +44,7 @@ def update( def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: """Returns the sequence length of the cached states. A layer index can be optionally passed.""" + # TODO: deprecate this function in favor of `cache_position` raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.") def get_max_length(self) -> Optional[int]: @@ -158,6 +159,7 @@ def update( def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: """Returns the sequence length of the cached states. A layer index can be optionally passed.""" + # TODO: deprecate this function in favor of `cache_position` if len(self.key_cache) <= layer_idx: return 0 return self.key_cache[layer_idx].shape[-2] @@ -244,6 +246,7 @@ def _get_rerotation_cos_sin( def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: """Returns the sequence length of the cached states. A layer index can be optionally passed.""" + # TODO: deprecate this function in favor of `cache_position` # Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length if len(self.key_cache) <= layer_idx: return 0 @@ -415,8 +418,7 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: """Returns the sequence length of the cached states that were seen by the model.""" # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's # limit the check to the first batch member and head dimension. - # TODO: This is error prone, a filled cache may be `0.0`. Let's use a stateless integer instead, after - # https://github.com/pytorch/pytorch/issues/120248 is fixed + # TODO: deprecate this function in favor of `cache_position` return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum() def get_max_length(self) -> Optional[int]: @@ -427,5 +429,5 @@ def reset(self): """Resets the cache values while preserving the objects""" for layer_idx in range(len(self.key_cache)): # In-place ops prevent breaking the static address - self.key_cache[layer_idx] *= 0.0 - self.value_cache[layer_idx] *= 0.0 + self.key_cache[layer_idx].zero_() + self.value_cache[layer_idx].zero_() diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 42000824ce34..2592a9a1f7a7 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -1055,7 +1055,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = cache_position[0] if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) if self.config._attn_implementation == "sdpa" and not using_static_cache: if AttentionMaskConverter._ignore_causal_mask_sdpa( From bb159d293d15b49329271c851b54a8446c0ce5ca Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Tue, 30 Apr 2024 15:09:33 +0000 Subject: [PATCH 12/12] final nits --- src/transformers/models/llama/modeling_llama.py | 2 +- tests/models/llama/test_modeling_llama.py | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 2592a9a1f7a7..42000824ce34 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -1055,7 +1055,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = cache_position[0] if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) if self.config._attn_implementation == "sdpa" and not using_static_cache: if AttentionMaskConverter._ignore_causal_mask_sdpa( diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py index b18031c912eb..0592922e4470 100644 --- a/tests/models/llama/test_modeling_llama.py +++ b/tests/models/llama/test_modeling_llama.py @@ -18,6 +18,7 @@ import unittest import pytest +from packaging import version from parameterized import parameterized from transformers import LlamaConfig, is_torch_available, set_seed @@ -685,6 +686,9 @@ def test_model_13b_greedy_generation(self): def test_compile_static_cache(self): # `torch==2.2` will throw an error on this test (as in other compilation tests), but torch==2.1.2 and torch>2.2 # work as intended. See https://github.com/pytorch/pytorch/issues/121943 + if version.parse(torch.__version__) < version.parse("2.3.0"): + self.skipTest("This test requires torch >= 2.3 to run.") + NUM_TOKENS_TO_GENERATE = 40 # Note on `EXPECTED_TEXT_COMPLETION`'s diff: the current value matches the original test if the original test # was changed to have a cache of 53 tokens (as opposed to 4096), on Ampere GPUs.