diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index b2f330548f56..23abf9d6d89f 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -359,7 +359,7 @@ def prepare_inputs_for_generation(self, *args, **kwargs): def _prepare_model_inputs( self, inputs: Optional[torch.Tensor] = None, - bos_token_id: Optional[int] = None, + bos_token_id: Optional[torch.Tensor] = None, model_kwargs: Optional[Dict[str, torch.Tensor]] = None, ) -> Tuple[torch.Tensor, Optional[str], Dict[str, torch.Tensor]]: """ @@ -423,7 +423,7 @@ def _prepare_model_inputs( def _maybe_initialize_input_ids_for_generation( self, inputs: Optional[torch.Tensor] = None, - bos_token_id: Optional[int] = None, + bos_token_id: Optional[torch.Tensor] = None, model_kwargs: Optional[Dict[str, torch.Tensor]] = None, ) -> torch.LongTensor: """Initializes input ids for generation, if necessary.""" @@ -454,20 +454,29 @@ def _maybe_initialize_input_ids_for_generation( def _prepare_attention_mask_for_generation( self, inputs: torch.Tensor, - pad_token_id: Optional[int], - eos_token_id: Optional[Union[int, List[int]]], + pad_token_id: Optional[torch.Tensor], + eos_token_id: Optional[torch.Tensor], ) -> torch.LongTensor: + # No information for attention mask inference -> return default attention mask + default_attention_mask = torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device) + if pad_token_id is None: + return default_attention_mask + + # Otherwise we may have information -> try to infer the attention mask is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [torch.int, torch.long] - is_pad_token_in_inputs = (pad_token_id is not None) and (pad_token_id in inputs) - if isinstance(eos_token_id, int): - eos_token_id = [eos_token_id] - is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (pad_token_id not in eos_token_id) + is_pad_token_in_inputs = (pad_token_id is not None) and ( + torch.isin(elements=inputs, test_elements=pad_token_id).any() + ) + is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or ~( + torch.isin(elements=eos_token_id, test_elements=pad_token_id).any() + ) - # Check if input is input_ids and padded -> only then is attention_mask defined - if is_input_ids and is_pad_token_in_inputs and is_pad_token_not_equal_to_eos_token_id: - return inputs.ne(pad_token_id).long() - else: - return torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device) + can_infer_attention_mask = is_input_ids * is_pad_token_in_inputs * is_pad_token_not_equal_to_eos_token_id + attention_mask_from_padding = inputs.ne(pad_token_id).long() + attention_mask = ( + attention_mask_from_padding * can_infer_attention_mask + default_attention_mask * ~can_infer_attention_mask + ) + return attention_mask def _prepare_encoder_decoder_kwargs_for_generation( self, inputs_tensor: torch.Tensor, model_kwargs, model_input_name: Optional[str] = None
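The rewritten `_prepare_attention_mask_for_generation` above replaces Python-level membership tests and a data-dependent `if/else` with `torch.isin` plus arithmetic masking, so the whole decision can stay inside the traced graph when `generate` is compiled. A minimal standalone sketch of the same pattern, with an illustrative helper name and token ids that are not part of the PR:

import torch

def infer_attention_mask(input_ids, pad_token_id, eos_token_id):
    # Default mask: attend to every position.
    default_mask = torch.ones(input_ids.shape[:2], dtype=torch.long, device=input_ids.device)
    # Tensor-native checks: no Python `in`, no branching on tensor values.
    pad_in_inputs = torch.isin(input_ids, pad_token_id).any()
    pad_is_not_eos = ~torch.isin(eos_token_id, pad_token_id).any()
    can_infer = pad_in_inputs & pad_is_not_eos
    mask_from_padding = input_ids.ne(pad_token_id).long()
    # Arithmetic select instead of `if can_infer:` keeps a single static graph.
    return mask_from_padding * can_infer + default_mask * ~can_infer

ids = torch.tensor([[0, 0, 5, 6], [7, 8, 9, 1]])  # hypothetical batch padded with id 0
mask = infer_attention_mask(ids, torch.tensor(0), torch.tensor(2))  # -> [[0, 0, 1, 1], [1, 1, 1, 1]]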
@@ -509,8 +518,7 @@ def _prepare_decoder_input_ids_for_generation( batch_size: int, model_input_name: str, model_kwargs: Dict[str, torch.Tensor], - decoder_start_token_id: Union[int, List[int]] = None, - bos_token_id: int = None, + decoder_start_token_id: Optional[torch.Tensor] = None, device: torch.device = None, ) -> Tuple[torch.LongTensor, Dict[str, torch.Tensor]]: """Prepares `decoder_input_ids` for generation with encoder-decoder models""" @@ -524,20 +532,14 @@ def _prepare_decoder_input_ids_for_generation( decoder_input_ids = None # 2. Encoder-decoder models expect the `decoder_input_ids` to start with a special token. Let's ensure that. - decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id) if device is None: device = self.device - if isinstance(decoder_start_token_id, list): - if len(decoder_start_token_id) != batch_size: + if decoder_start_token_id.ndim == 1: + if decoder_start_token_id.shape[0] != batch_size: raise ValueError( - f"`decoder_start_token_id` expcted to have length {batch_size} but got {len(decoder_start_token_id)}" + f"`decoder_start_token_id` expected to have length {batch_size} but got {decoder_start_token_id.shape[0]}" ) - decoder_input_ids_start = torch.tensor(decoder_start_token_id, dtype=torch.long, device=device) - decoder_input_ids_start = decoder_input_ids_start.view(-1, 1) - else: - decoder_input_ids_start = ( - torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id - ) + decoder_input_ids_start = decoder_start_token_id.view(-1, 1) + else: + decoder_input_ids_start = ( + torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id + ) # no user input -> use decoder_start_token_id as decoder_input_ids if decoder_input_ids is None: @@ -568,7 +570,7 @@ def _prepare_decoder_input_ids_for_generation( return decoder_input_ids, model_kwargs def _get_decoder_start_token_id( - self, decoder_start_token_id: Union[int, List[int]] = None, bos_token_id: int = None + self, decoder_start_token_id: Optional[torch.Tensor] = None, bos_token_id: Optional[torch.Tensor] = None ) -> int: decoder_start_token_id = ( decoder_start_token_id @@ -1207,9 +1209,7 @@ def _prepare_generation_config( # will mutate the object with `.update`. As such, passing these arguments through `kwargs` is disabled. if is_torchdynamo_compiling(): model_kwargs = kwargs - generate_attributes_in_kwargs = [ - key for key, value in kwargs.items() if getattr(generation_config, key, None) != value - ] + generate_attributes_in_kwargs = [key for key in kwargs.keys() if hasattr(generation_config, key)] if len(generate_attributes_in_kwargs) > 0: raise ValueError( "`torch.compile` exception: all generation configuration attributes must be passed within a " @@ -1221,6 +1221,38 @@ def _prepare_generation_config( return generation_config, model_kwargs + def _prepare_special_tokens( + self, generation_config: GenerationConfig, kwargs_has_attention_mask: bool + ) -> Tuple[Optional[torch.Tensor], ...]: + """Prepares the special tokens for generation.""" + + # Convert special tokens to tensors (if they exist) + def _tensor_or_none(token): + return torch.tensor(token, device=self.device, dtype=torch.long) if token is not None else None + + bos_token_id = _tensor_or_none(generation_config.bos_token_id) + eos_token_id = _tensor_or_none(generation_config.eos_token_id) + pad_token_id = _tensor_or_none(generation_config.pad_token_id) + decoder_start_token_id = _tensor_or_none(generation_config.decoder_start_token_id) + if decoder_start_token_id is None: + decoder_start_token_id = bos_token_id + + if self.config.is_encoder_decoder and decoder_start_token_id is None: + raise ValueError( + "`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation." + ) + + # Set pad token if unset (and there are conditions to do so) + if pad_token_id is None and eos_token_id is not None: + if not kwargs_has_attention_mask: + logger.warning( + "The attention mask and the pad token id were not set. As a consequence, you may observe " + "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results." + ) + if eos_token_id.ndim == 1: + pad_token_id = eos_token_id[0] + else: + pad_token_id = eos_token_id + logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{pad_token_id} for open-end generation.") + + return bos_token_id, eos_token_id, pad_token_id, decoder_start_token_id +
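`_prepare_special_tokens` moves the special token ids from the generation config into `torch.Tensor`s on the model device once, up front, so that later comparisons against generated tokens can stay tensor-native. A rough standalone illustration of the idea, with names and ids invented for the example:

import torch

def to_device_tensor(token, device):
    # None stays None; ints or lists of ints become long tensors on the target device.
    return torch.tensor(token, device=device, dtype=torch.long) if token is not None else None

device = torch.device("cpu")  # assumption for the example
eos_token_id = to_device_tensor([2, 3], device)  # 1-D tensor covering multiple EOS ids
pad_token_id = to_device_tensor(0, device)       # 0-D tensor

# Downstream checks such as "did this step produce an EOS token?" then avoid host round trips:
next_tokens = torch.tensor([3, 5], device=device)
finished = torch.isin(next_tokens, eos_token_id)  # tensor([ True, False])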
@torch.no_grad() def generate( self, @@ -1333,31 +1365,35 @@ def generate( logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() - if generation_config.pad_token_id is None and generation_config.eos_token_id is not None: - if model_kwargs.get("attention_mask", None) is None: - logger.warning( - "The attention mask and the pad token id were not set. As a consequence, you may observe " - "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results." - ) - eos_token_id = generation_config.eos_token_id - if isinstance(eos_token_id, list): - eos_token_id = eos_token_id[0] - logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.") - generation_config.pad_token_id = eos_token_id + accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys()) + requires_attention_mask = "encoder_outputs" not in model_kwargs + kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None + bos_token_id, eos_token_id, pad_token_id, decoder_start_token_id = self._prepare_special_tokens( + generation_config, kwargs_has_attention_mask + ) # 3. Define model inputs - # inputs_tensor has to be defined - # model_input_name is defined if model-specific keyword input is passed - # otherwise model_input_name is None - # all model-specific keyword inputs are removed from `model_kwargs` - inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs( - inputs, generation_config.bos_token_id, model_kwargs - ) + inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(inputs, bos_token_id, model_kwargs) batch_size = inputs_tensor.shape[0] + # decoder-only models must use left-padding for generation. + if not self.config.is_encoder_decoder and not is_torchdynamo_compiling(): + # If `input_ids` was given, check if the last id in any sequence is `pad_token_id` + # Note: If using `inputs_embeds`, this check does not work, because we want to be more hands-off. + if ( + pad_token_id is not None + and len(inputs_tensor.shape) == 2 + and torch.sum(inputs_tensor[:, -1] == pad_token_id) > 0 + ): + logger.warning( + "A decoder-only architecture is being used, but right-padding was detected! For correct " + "generation results, please set `padding_side='left'` when initializing the tokenizer." + ) + # 4.
Define other model kwargs model_kwargs["output_attentions"] = generation_config.output_attentions model_kwargs["output_hidden_states"] = generation_config.output_hidden_states + # decoder-only models with inputs_embeds forwarding must use caching (otherwise we can't detect whether we are # generating the first new token or not, and we only want to use the embeddings for the first new token) if not self.config.is_encoder_decoder and model_input_name == "inputs_embeds": @@ -1365,31 +1401,13 @@ def generate( else: model_kwargs["use_cache"] = generation_config.use_cache - accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys()) - requires_attention_mask = "encoder_outputs" not in model_kwargs - - if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask: + if not kwargs_has_attention_mask and requires_attention_mask and accepts_attention_mask: model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation( - inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id + inputs_tensor, pad_token_id, eos_token_id ) - # decoder-only models should use left-padding for generation - if not self.config.is_encoder_decoder: - # If `input_ids` was given, check if the last id in any sequence is `pad_token_id` - # Note: If using, `inputs_embeds` this check does not work, because we want to be more hands-off. - if ( - generation_config.pad_token_id is not None - and len(inputs_tensor.shape) == 2 - and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id) > 0 - ): - logger.warning( - "A decoder-only architecture is being used, but right-padding was detected! For correct " - "generation results, please set `padding_side='left'` when initializing the tokenizer." - ) - if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs: - # if model is encoder decoder encoder_outputs are created - # and added to `model_kwargs` + # if model is encoder decoder encoder_outputs are created and added to `model_kwargs` model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation( inputs_tensor, model_kwargs, model_input_name ) @@ -1400,8 +1418,7 @@ def generate( batch_size=batch_size, model_input_name=model_input_name, model_kwargs=model_kwargs, - decoder_start_token_id=generation_config.decoder_start_token_id, - bos_token_id=generation_config.bos_token_id, + decoder_start_token_id=decoder_start_token_id, device=inputs_tensor.device, ) else: @@ -1446,7 +1463,8 @@ def generate( ) self._setup_cache(cache_cls, max_batch_size=batch_size, max_cache_len=generation_config.max_length) - self._validate_generated_length(generation_config, input_ids_length, has_default_max_length) + if not is_torchdynamo_compiling(): + self._validate_generated_length(generation_config, input_ids_length, has_default_max_length) # 7. determine generation mode generation_mode = generation_config.get_generation_mode(assistant_model) @@ -1483,6 +1501,7 @@ def generate( prepared_stopping_criteria = self._get_stopping_criteria( generation_config=generation_config, stopping_criteria=stopping_criteria ) + # 10. 
go into different generation modes if generation_mode == GenerationMode.ASSISTED_GENERATION: if generation_config.num_return_sequences > 1: @@ -1778,23 +1797,30 @@ def typeerror(): return result - def _has_unfinished_sequences(self, this_peer_finished: bool, synced_gpus: bool, device: torch.device) -> bool: + def _has_unfinished_sequences( + self, this_peer_finished: bool, cur_len, max_length, synced_gpus: bool, device: torch.device + ) -> bool: """ Returns whether there are still unfinished sequences in the device. The existence of unfinished sequences is fed through `this_peer_finished`. ZeRO stage 3-friendly. """ - if synced_gpus: - # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. - # The following logic allows an early break if all peers finished generating their sequence - this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(device) - # send 0.0 if we finished, 1.0 otherwise - dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) - # did all peers finish? the reduced sum will be 0.0 then - if this_peer_finished_flag.item() == 0.0: + # torch.compile does not support data-dependent control flow. This is a workaround to allow torch.compile, + # although we lose the ability to stop when all sequences return an EOS token (and other stopping criteria) + if is_torchdynamo_compiling(): + return cur_len < max_length + else: + if synced_gpus: + # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. + # The following logic allows an early break if all peers finished generating their sequence + this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(device) + # send 0.0 if we finished, 1.0 otherwise + dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) + # did all peers finish? 
the reduced sum will be 0.0 then + if this_peer_finished_flag.item() == 0.0: + return False + elif this_peer_finished: return False - elif this_peer_finished: - return False - return True + return True def contrastive_search(self, *args, **kwargs): logger.warning_once( @@ -2403,7 +2429,15 @@ def _greedy_search( unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device) - while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device): + max_length = None + for criteria in stopping_criteria: + if isinstance(criteria, MaxLengthCriteria): + max_length = criteria.max_length + break + + while self._has_unfinished_sequences( + this_peer_finished, cur_len, max_length, synced_gpus, device=input_ids.device + ): # prepare model inputs model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) @@ -2470,6 +2504,7 @@ def _greedy_search( unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores) this_peer_finished = unfinished_sequences.max() == 0 + cur_len += 1 if streamer is not None: streamer.end() diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 8360b4080781..a54aeb24b053 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -879,14 +879,19 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) past_seen_tokens = 0 - if use_cache: # kept for BC (cache positions) - if not isinstance(past_key_values, StaticCache): - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - past_seen_tokens = past_key_values.get_seq_length() + if use_cache: + static_cache = getattr(self.layers[0].self_attn, "past_key_value", None) + if static_cache is not None: + past_seen_tokens = static_cache.get_seq_length() + else: + if not isinstance(past_key_values, Cache): + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_seen_tokens = past_key_values.get_seq_length() if cache_position is None: - cache_position = torch.arange( - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + # `torch.compile`-friendly `torch.arange` from a shape + cache_position = ( + torch.ones_like(inputs_embeds[0, :, 0], dtype=torch.int64).cumsum(0) + past_seen_tokens - 1 ) if position_ids is None: @@ -1202,12 +1207,6 @@ def prepare_inputs_for_generation( # TODO: use `next_tokens` directly instead. 
model_inputs = {"input_ids": input_ids.contiguous()} - input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1] - if cache_position is None: - cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device) - else: - cache_position = cache_position[-input_length:] - if has_static_cache: past_key_values = None diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index f5f3dc02ee9d..dc25d319d513 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -972,16 +972,19 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) past_seen_tokens = 0 - if use_cache: # kept for BC (cache positions) - if not isinstance(past_key_values, StaticCache): - past_key_values = DynamicCache.from_legacy_cache(past_key_values) + if use_cache: + static_cache = getattr(self.layers[0].self_attn, "past_key_value", None) + if static_cache is not None: + past_seen_tokens = static_cache.get_seq_length() + else: + if not isinstance(past_key_values, Cache): + past_key_values = DynamicCache.from_legacy_cache(past_key_values) past_seen_tokens = past_key_values.get_seq_length() if cache_position is None: - if isinstance(past_key_values, StaticCache): - raise ValueError("cache_position is a required argument when using StaticCache.") - cache_position = torch.arange( - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + # `torch.compile`-friendly `torch.arange` from a shape + cache_position = ( + torch.ones_like(inputs_embeds[0, :, 0], dtype=torch.int64).cumsum(0) + past_seen_tokens - 1 ) if position_ids is None: @@ -1258,25 +1261,35 @@ def prepare_inputs_for_generation( cache_length = past_length = past_key_values[0][0].shape[2] max_cache_length = None - # Keep only the unprocessed tokens: - # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where - # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as - # input) - if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: - input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] - # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard - # input_ids based on the past_length. - elif past_length < input_ids.shape[1]: - input_ids = input_ids[:, past_length:] - # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. - - # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. 
- if ( - max_cache_length is not None - and attention_mask is not None - and cache_length + input_ids.shape[1] > max_cache_length - ): - attention_mask = attention_mask[:, -max_cache_length:] + attention_based_slicing = attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1] + past_based_slicing = past_length < input_ids.shape[1] + # no_slicing = not attention_based_slicing and not past_based_slicing + input_ids_slice_index = (-(attention_mask.shape[1] - past_length) * attention_based_slicing) + ( + past_length * past_based_slicing + ) + # input_ids_slice_index = (-(attention_mask.shape[1] - past_length) * attention_based_slicing) + (past_length * past_based_slicing) + (0 * no_slicing) + # input_ids = input_ids[:, input_ids_slice_index:] + input_ids = input_ids[:, cache_position] + + # # Keep only the unprocessed tokens: + # # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as + # # input) + # if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: + # input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + # # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # # input_ids based on the past_length. + # elif past_length < input_ids.shape[1]: + # input_ids = input_ids[:, past_length:] + # # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. + + # # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. + # if ( + # max_cache_length is not None + # and attention_mask is not None + # and cache_length + input_ids.shape[1] > max_cache_length + # ): + # attention_mask = attention_mask[:, -max_cache_length:] position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: @@ -1295,12 +1308,6 @@ def prepare_inputs_for_generation( # TODO: use `next_tokens` directly instead. 
model_inputs = {"input_ids": input_ids.contiguous()} - input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1] - if cache_position is None: - cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device) - else: - cache_position = cache_position[-input_length:] - if has_static_cache: past_key_values = None diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 91e9c3876a56..0a4db70a27eb 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -27,6 +27,7 @@ is_flaky, require_accelerate, require_torch, + require_torch_gpu, require_torch_multi_accelerator, slow, torch_device, @@ -1653,6 +1654,31 @@ def test_new_cache_format(self, num_beams, do_sample): ) ) + @require_torch_gpu + @slow + def test_generate_compile_fullgraph(self): + """Tests that `.generate` is compatible with torch.compile without graph breaks, keeping the same results""" + for model_class in self.all_generative_model_classes: + if not hasattr(model_class, "_setup_cache"): + continue + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config).to(torch_device) + input_ids = inputs_dict["input_ids"].to(torch_device) + + # dynamic cache + output_dynamic = model.generate(input_ids) + + # eager static cache + model.generation_config.cache_implementation = "static" + output_static = model.generate(input_ids) + self.assertListEqual(output_dynamic.tolist(), output_static.tolist()) + + # compiled static cache + compiled_generate = torch.compile(model.generate, fullgraph=True, mode="reduce-overhead") + output_compiled = compiled_generate(input_ids) + self.assertListEqual(output_dynamic.tolist(), output_compiled.tolist()) + def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1): batch_size, seq_length = input_ids.shape num_sequences_in_output = batch_size * num_return_sequences @@ -2808,3 +2834,21 @@ def test_return_unprocessed_logit_scores(self): self.assertTrue(y_prob > 0.001 and n_prob > 0.001) self.assertTrue(y_prob <= 1.0 and n_prob <= 1.0) + + def test_bad_generate_compilation_flags(self): + """ + Tests that certain parameterization options in `generate` properly raise a custom exception (a `ValueError` + defined in `transformers` instead of general `torch._dynamo.exc.Unsupported`). 
+ """ + + model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device) + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2") + compiled_generate = torch.compile(model.generate, fullgraph=True, mode="reduce-overhead") + + text = "Hello world" + tokenized_inputs = tokenizer([text], return_tensors="pt") + input_ids = tokenized_inputs.input_ids.to(torch_device) + + # Passing generation_config parameters through kwargs is not supported + with self.assertRaises(ValueError): + compiled_generate(input_ids, max_length=10, do_sample=True, temperature=0.7) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index ca9ac62bf28b..bd5a7412cca4 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -104,7 +104,7 @@ from safetensors.torch import save_file as safe_save_file from torch import nn - from transformers import MODEL_MAPPING, AdaptiveEmbedding + from transformers import MODEL_MAPPING, AdaptiveEmbedding, StaticCache from transformers.modeling_utils import load_state_dict, no_init_weights from transformers.pytorch_utils import id_tensor_storage @@ -4027,6 +4027,78 @@ def test_flash_attn_2_from_config(self): self.assertFalse(fa2_correctly_converted) + @require_torch_gpu + @slow + def test_implicit_cache_position(self): + """ + Tests that passing the correct cache_position yields the same results as passing cache_position=None, i.e. that + inference with implicit cache_position is working. + """ + for model_class in self.all_generative_model_classes: + if not hasattr(model_class, "_setup_cache"): + continue + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config).to(torch_device) + + input_ids = inputs_dict["input_ids"].to(torch_device) + + def run_2_forward_passes_with_cache(model, input_ids, static_cache, compile): + # runs two generate-style forward passes, to ensure cudagraphs need two different values of implicit + # `cache_position` to work correctly + if static_cache: + model._setup_cache( + cache_cls=StaticCache, max_batch_size=input_ids.shape[0], max_cache_len=input_ids.shape[1] + 1 + ) + + if compile: + model = torch.compile(model, fullgraph=True, mode="reduce-overhead") + + # Implicit cache_positions + logits_implicit = [] + outputs = model(input_ids, cache_position=None) + if static_cache: + self.assertTrue(outputs.past_key_values is None) # sanity check -- it is None with static cache + logits_implicit.append(outputs.logits) + new_input_ids = torch.argmax(outputs.logits[:, -1, :], dim=-1).unsqueeze(1) + outputs = model(new_input_ids, cache_position=None, past_key_values=outputs.past_key_values) + logits_implicit.append(outputs.logits) + + if static_cache: + # Restart the cache + model._reset_cache() + model._setup_cache( + cache_cls=StaticCache, max_batch_size=input_ids.shape[0], max_cache_len=input_ids.shape[1] + 1 + ) + + # Explicit cache_positions + logits_explicit = [] + cache_positions = torch.arange(input_ids.shape[1], dtype=torch.long, device=torch_device) + outputs = model(input_ids, cache_position=cache_positions) + if static_cache: + self.assertTrue(outputs.past_key_values is None) # sanity check -- it is None with static cache + logits_explicit.append(outputs.logits) + cache_positions = cache_positions[-1:] + 1 + new_input_ids = torch.argmax(outputs.logits[:, -1, :], dim=-1).unsqueeze(1) + outputs = model(new_input_ids, cache_position=cache_positions, past_key_values=outputs.past_key_values) + 
logits_explicit.append(outputs.logits) + + if static_cache: + model._reset_cache() + + # Confirm that explicit and implicit cache_positions yield the same results + for idx in range(len(logits_implicit)): + self.assertTrue(torch.allclose(logits_implicit[idx], logits_explicit[idx])) + + # dynamic cache + run_2_forward_passes_with_cache(model, input_ids, static_cache=False, compile=False) + + # eager static cache + run_2_forward_passes_with_cache(model, input_ids, static_cache=True, compile=False) + + # compiled static cache [to confirm that it works with cuda graphs] + run_2_forward_passes_with_cache(model, input_ids, static_cache=True, compile=True) + global_rng = random.Random()
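A note on the decoding-loop change in `_has_unfinished_sequences`: when compiling, reading tensor values in the Python loop condition (for example "did every sequence emit EOS?") would introduce data-dependent control flow, so the compiled path falls back to a plain integer comparison against `max_length`. A toy stand-in for that condition, assuming a `compiling` flag and an `unfinished` tensor of per-sequence flags:

import torch

def has_unfinished(unfinished: torch.Tensor, cur_len: int, max_length: int, compiling: bool) -> bool:
    if compiling:
        # Python-int comparison only: the graph never branches on tensor values,
        # at the cost of always running until max_length is reached.
        return cur_len < max_length
    # Eager mode can read the tensor and stop as soon as every sequence is done.
    return bool(unfinished.any())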
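The Gemma and Llama changes derive `cache_position` from the shape of `inputs_embeds` via `torch.ones_like(...).cumsum(0)` instead of `torch.arange(past_seen_tokens, ...)`, keeping the construction shape-driven under `torch.compile`. A small sketch of the equivalence, with dummy tensor sizes:

import torch

inputs_embeds = torch.zeros(1, 4, 8)  # (batch, seq_len, hidden), dummy values
past_seen_tokens = 3

# Shape-driven variant used in the diff
cache_position = torch.ones_like(inputs_embeds[0, :, 0], dtype=torch.int64).cumsum(0) + past_seen_tokens - 1

# Plain eager construction it replaces
reference = torch.arange(past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1])
assert torch.equal(cache_position, reference)  # both are tensor([3, 4, 5, 6])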
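Taken together, the changes are meant to let `generate` run under `torch.compile(fullgraph=True)` with the static cache, as exercised by `test_generate_compile_fullgraph` above. A hedged sketch of that workflow, mirroring the test (the tiny checkpoint is just a stand-in):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")
inputs = tokenizer(["Hello world"], return_tensors="pt")

# Opt in to the static cache so shapes stay fixed, then compile generate itself
model.generation_config.cache_implementation = "static"
compiled_generate = torch.compile(model.generate, fullgraph=True, mode="reduce-overhead")

# Note: with the compiled path, generation options belong in the generation config,
# not in ad-hoc kwargs (see the check added in `_prepare_generation_config`).
output_ids = compiled_generate(**inputs)
print(tokenizer.batch_decode(output_ids, skip_special_tokens=True))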