From ef1ebd913cef7cb557278748adc98a75809c7ff1 Mon Sep 17 00:00:00 2001 From: raushan Date: Fri, 25 Jul 2025 16:21:49 +0200 Subject: [PATCH 1/7] fix --- src/transformers/generation/utils.py | 2 +- .../modeling_kyutai_speech_to_text.py | 2 +- src/transformers/models/mimi/modeling_mimi.py | 3 ++ .../modeling_musicgen_melody.py | 28 +++++++++++++++++-- src/transformers/models/rag/modeling_rag.py | 9 ------ .../models/roformer/modeling_roformer.py | 25 ++++++++++++----- .../models/superglue/modeling_superglue.py | 2 +- .../test_modeling_qwen2_5_omni.py | 2 +- .../test_pipelines_image_text_to_text.py | 8 +++--- 9 files changed, 54 insertions(+), 27 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 6f4adcfeb14e..509adaa8e2fe 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2055,7 +2055,7 @@ def _prepare_cache_for_generation( generation_config.cache_implementation = None generation_config.cache_implementation = generation_config.cache_implementation or getattr( - self.config.get_text_config(), "cache_implementation", None + self.config.get_text_config(decoder=True), "cache_implementation", None ) if generation_config.cache_implementation is not None: if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING: diff --git a/src/transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py b/src/transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py index d2e9d92e7877..08a8e1a1a5eb 100644 --- a/src/transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +++ b/src/transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py @@ -1221,7 +1221,7 @@ def _prepare_model_inputs( for method in cache_methods: setattr(self.codec_model, method, types.MethodType(getattr(self, method).__func__, self.codec_model)) - self.codec_model._prepare_cache_for_generation( + self._prepare_cache_for_generation( generation_config=self.codec_model.generation_config, model_kwargs=temporary_model_kwargs, assistant_model=None, diff --git a/src/transformers/models/mimi/modeling_mimi.py b/src/transformers/models/mimi/modeling_mimi.py index 260ea6f7ce2e..c04f273aeced 100644 --- a/src/transformers/models/mimi/modeling_mimi.py +++ b/src/transformers/models/mimi/modeling_mimi.py @@ -1376,6 +1376,9 @@ class MimiPreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True + # Mimi is non-generative model but uses cache, special case + _supports_default_dynamic_cache = True + _can_compile_fullgraph = True def _init_weights(self, module): diff --git a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py index a454d9fe24be..23c5314c5447 100644 --- a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py +++ b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py @@ -2162,6 +2162,28 @@ def generate( input_ids_length=input_ids_length, ) + self._validate_generated_length(generation_config, input_ids_length, has_default_max_length) + + # 7. Prepare the cache. + # - `model_kwargs` may be updated in place with a cache as defined by the parameters in `generation_config`. + # - different models have a different cache name expected by the model (default = "past_key_values") + # - `max_length`, prepared above, is used to determine the maximum cache length + max_cache_length = generation_config.max_length - 1 + if ( + inputs_tensor.shape[1] != input_ids_length + and model_input_name == "inputs_embeds" + and not self.config.is_encoder_decoder + ): + max_cache_length += inputs_tensor.shape[1] + self._prepare_cache_for_generation( + generation_config, + model_kwargs, + assistant_model=None, + batch_size=batch_size, + max_cache_length=max_cache_length, + device=inputs_tensor.device, + ) + # build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to MusicGen) input_ids, decoder_delay_pattern_mask = self.decoder.build_delay_pattern_mask( input_ids, @@ -2175,15 +2197,15 @@ def generate( if streamer is not None: streamer.put(input_ids.cpu()) - # 7. determine generation mode + # 8. determine generation mode generation_mode = generation_config.get_generation_mode() - # 8. prepare batched CFG externally (to enable coexistence with the unbatched CFG) + # 9. prepare batched CFG externally (to enable coexistence with the unbatched CFG) if generation_config.guidance_scale is not None and generation_config.guidance_scale > 1: logits_processor.append(ClassifierFreeGuidanceLogitsProcessor(generation_config.guidance_scale)) generation_config.guidance_scale = None - # 9. prepare distribution pre_processing samplers + # 10. prepare distribution pre_processing samplers logits_processor = self._get_logits_processor( generation_config=generation_config, input_ids_seq_length=input_ids_length, diff --git a/src/transformers/models/rag/modeling_rag.py b/src/transformers/models/rag/modeling_rag.py index 4cb08b1bc4c3..ffc1f60c8fe1 100644 --- a/src/transformers/models/rag/modeling_rag.py +++ b/src/transformers/models/rag/modeling_rag.py @@ -1204,8 +1204,6 @@ def _reorder_stacked(hidden_states, new_order): if isinstance(past_key_values, EncoderDecoderCache): reordered_past = EncoderDecoderCache.from_legacy_cache(reordered_past) - if isinstance(past_key_values, EncoderDecoderCache): - reordered_past = EncoderDecoderCache.from_legacy_cache(reordered_past) return reordered_past def marginalize(self, seq_logits, doc_scores, n_docs=None): @@ -1593,13 +1591,6 @@ def extend_enc_output(tensor, num_beams=None): if generation_config.num_return_sequences > generation_config.num_beams: raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.") - # 11. interleave input_ids with `num_beams` additional sequences per batch - input_ids, model_kwargs = self._expand_inputs_for_generation( - input_ids=input_ids, - expand_size=generation_config.num_beams, - is_encoder_decoder=self.config.is_encoder_decoder, - **model_kwargs, - ) return self._beam_search( input_ids, logits_processor=pre_processor, diff --git a/src/transformers/models/roformer/modeling_roformer.py b/src/transformers/models/roformer/modeling_roformer.py index 0fa8ac4e0c48..b880496aa0cf 100644 --- a/src/transformers/models/roformer/modeling_roformer.py +++ b/src/transformers/models/roformer/modeling_roformer.py @@ -261,6 +261,17 @@ def forward( .transpose(1, 2) ) + # Apply RoPE if self attention + if not is_cross_attention and sinusoidal_pos is not None: + if self.rotary_value: + query_layer, key_layer, value_layer = self.apply_rotary_position_embeddings( + sinusoidal_pos, query_layer, key_layer, value_layer + ) + else: + query_layer, key_layer = self.apply_rotary_position_embeddings( + sinusoidal_pos, query_layer, key_layer + ) + if past_key_value is not None: # save all key/value_layer to cache to be re-used for fast auto-regressive generation cache_position = cache_position if not is_cross_attention else None @@ -381,13 +392,13 @@ def forward( ): self_outputs = self.self( hidden_states, - attention_mask, - sinusoidal_pos, - head_mask, - encoder_hidden_states, - past_key_value, - output_attentions, - cache_position, + attention_mask=attention_mask, + sinusoidal_pos=sinusoidal_pos, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_value=past_key_value, + output_attentions=output_attentions, + cache_position=cache_position, ) attention_output = self.output(self_outputs[0], hidden_states) outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them diff --git a/src/transformers/models/superglue/modeling_superglue.py b/src/transformers/models/superglue/modeling_superglue.py index da0dcfac9245..3286b8912cd9 100644 --- a/src/transformers/models/superglue/modeling_superglue.py +++ b/src/transformers/models/superglue/modeling_superglue.py @@ -274,7 +274,7 @@ def forward( # such that the encoder's padding tokens are not attended to. is_cross_attention = encoder_hidden_states is not None current_states = encoder_hidden_states if is_cross_attention else hidden_states - attention_mask = encoder_attention_mask if is_cross_attention else encoder_attention_mask + attention_mask = encoder_attention_mask if is_cross_attention else attention_mask batch_size = hidden_states.shape[0] key_layer = ( diff --git a/tests/models/qwen2_5_omni/test_modeling_qwen2_5_omni.py b/tests/models/qwen2_5_omni/test_modeling_qwen2_5_omni.py index 0863d8def4e3..e5b1f1a06bab 100644 --- a/tests/models/qwen2_5_omni/test_modeling_qwen2_5_omni.py +++ b/tests/models/qwen2_5_omni/test_modeling_qwen2_5_omni.py @@ -698,7 +698,7 @@ def test_small_model_integration_test_batch(self): ) text = self.processor.apply_chat_template(self.messages, tokenize=False, add_generation_prompt=True) inputs = self.processor( - text=text * 2, + text=[text] * 2, audio=[self.raw_audio, self.raw_audio], images=[self.raw_image, self.raw_image], return_tensors="pt", diff --git a/tests/pipelines/test_pipelines_image_text_to_text.py b/tests/pipelines/test_pipelines_image_text_to_text.py index 781fbad8a907..eea544936268 100644 --- a/tests/pipelines/test_pipelines_image_text_to_text.py +++ b/tests/pipelines/test_pipelines_image_text_to_text.py @@ -119,7 +119,7 @@ def test_small_model_pt_token_text_only(self): }, { "role": "assistant", - "content": "Hugging Face, a company of minds\nWith tools and services that make our lives easier\nFrom", + "content": "Hugging Face, a company of minds\nWith tools and services that make our lives easier\nFrom natural language processing\nTo machine learning and more, they've got it all\n\nThey've made it possible for us to be more\nInformed and efficient, with their tools and services\nFrom image and speech recognition\nTo text and language translation, they've got it all\n\nThey've made it possible for us to be more\nInformed and efficient, with their tools and services\nFrom image and speech recognition\nTo text and language translation, they've got it all\n\nThey've made it possible for us to be more\nInformed and efficient, with their tools and services\nFrom image and speech recognition\nTo text and language translation, they've got it all\n\nThey've made it possible for us to be more\nInformed and efficient, with their tools and services\nFrom image and speech recognition\nTo text and language translation, they've got it all\n\nThey've made it possible for us to be more\nInformed and efficient, with their tools and services\nFrom image and speech recognition\nTo text and language translation, they've got it all\n\nThey've made it possible for us to be more\nInformed and efficient, with their tools and", }, ], } @@ -150,7 +150,7 @@ def test_small_model_pt_token(self): [ { "input_text": " What this is? Assistant: This is", - "generated_text": " What this is? Assistant: This is a photo of two cats lying on a pink blanket. The cats are sleeping and appear to be comfortable", + "generated_text": " What this is? Assistant: This is a photo of two cats lying on a pink blanket. The cats are sleeping and appear to be comfortable. The photo captures a moment of tranquility and companionship between the two feline friends.", } ], ) @@ -161,11 +161,11 @@ def test_small_model_pt_token(self): [ { "input_text": " What this is? Assistant: This is", - "generated_text": " What this is? Assistant: This is a photo of two cats lying on a pink blanket. The cats are facing the camera, and they", + "generated_text": " What this is? Assistant: This is a photo of two cats lying on a pink blanket. The cats are facing the camera, and they appear to be sleeping or resting. The blanket is placed on a couch, and the cats are positioned in such a way that they are facing the camera. The image captures a peaceful moment between the two cats, and it's a great way to showcase their cuteness and relaxed demeanor.", }, { "input_text": " What this is? Assistant: This is", - "generated_text": " What this is? Assistant: This is a photo of two cats lying on a pink blanket. The cats are facing the camera, and they", + "generated_text": " What this is? Assistant: This is a photo of two cats lying on a pink blanket. The cats are facing the camera, and they appear to be sleeping or resting. The blanket is placed on a couch, and the cats are positioned in such a way that they are facing the camera. The image captures a peaceful moment between the two cats, and it's a great way to showcase their cuteness and relaxed demeanor.", }, ], ) From 00d9b76d9b5c1b536c388a080e3a859b12f67526 Mon Sep 17 00:00:00 2001 From: raushan Date: Fri, 25 Jul 2025 17:31:39 +0200 Subject: [PATCH 2/7] fix kyutai at last --- .../modeling_kyutai_speech_to_text.py | 7 +++++-- src/transformers/models/mimi/modeling_mimi.py | 3 --- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py b/src/transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py index 08a8e1a1a5eb..bf0b55d5a19c 100644 --- a/src/transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +++ b/src/transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py @@ -1215,13 +1215,16 @@ def _prepare_model_inputs( cache_methods = [ "_prepare_cache_for_generation", "_get_cache", - "_supports_default_dynamic_cache", "_get_layer_device_map_for_cache_init", ] for method in cache_methods: setattr(self.codec_model, method, types.MethodType(getattr(self, method).__func__, self.codec_model)) - self._prepare_cache_for_generation( + setattr( + self.codec_model, "_supports_default_dynamic_cache", types.MethodType(lambda x: True, self.codec_model) + ) + + self.codec_model._prepare_cache_for_generation( generation_config=self.codec_model.generation_config, model_kwargs=temporary_model_kwargs, assistant_model=None, diff --git a/src/transformers/models/mimi/modeling_mimi.py b/src/transformers/models/mimi/modeling_mimi.py index c04f273aeced..260ea6f7ce2e 100644 --- a/src/transformers/models/mimi/modeling_mimi.py +++ b/src/transformers/models/mimi/modeling_mimi.py @@ -1376,9 +1376,6 @@ class MimiPreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True - # Mimi is non-generative model but uses cache, special case - _supports_default_dynamic_cache = True - _can_compile_fullgraph = True def _init_weights(self, module): From 3df3f951ce406c503397f3c03ea858f7481a07e0 Mon Sep 17 00:00:00 2001 From: raushan Date: Mon, 28 Jul 2025 11:41:16 +0200 Subject: [PATCH 3/7] fix unrelated tests and copies --- .../kyutai_speech_to_text/modular_kyutai_speech_to_text.py | 5 ++++- .../test_modeling_kyutai_speech_to_text.py | 6 +++--- tests/models/qwen2_5_omni/test_modeling_qwen2_5_omni.py | 1 - tests/models/qwen2_5_vl/test_modeling_qwen2_5_vl.py | 1 - tests/models/qwen2_vl/test_modeling_qwen2_vl.py | 1 - 5 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py b/src/transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py index 4929c9e4bae1..e0e424ac605e 100644 --- a/src/transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +++ b/src/transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py @@ -344,12 +344,15 @@ def _prepare_model_inputs( cache_methods = [ "_prepare_cache_for_generation", "_get_cache", - "_supports_default_dynamic_cache", "_get_layer_device_map_for_cache_init", ] for method in cache_methods: setattr(self.codec_model, method, types.MethodType(getattr(self, method).__func__, self.codec_model)) + setattr( + self.codec_model, "_supports_default_dynamic_cache", types.MethodType(lambda x: True, self.codec_model) + ) + self.codec_model._prepare_cache_for_generation( generation_config=self.codec_model.generation_config, model_kwargs=temporary_model_kwargs, diff --git a/tests/models/kyutai_speech_to_text/test_modeling_kyutai_speech_to_text.py b/tests/models/kyutai_speech_to_text/test_modeling_kyutai_speech_to_text.py index b5c92267edc7..3d9856207776 100644 --- a/tests/models/kyutai_speech_to_text/test_modeling_kyutai_speech_to_text.py +++ b/tests/models/kyutai_speech_to_text/test_modeling_kyutai_speech_to_text.py @@ -769,12 +769,12 @@ def test_generation_batched(self): out = model.generate(**inputs) # fmt: off - EXPECTED_TOKENS = torch.tensor([ + EXPECTED_TOKENS = [ [48000, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0, 1519, 263, 3, 3, 0, 3635, 428, 641, 0, 277, 3, 0, 265, 0, 267, 1162, 261, 274, 410, 0, 272, 3, 0, 265, 0, 260, 1621, 0, 1174, 371, 262, 3, 3, 3, 0, 269, 0, 281, 0, 304, 0, 2433, 3, 0, 266, 3, 0, 281, 1661, 3, 0, 376, 3, 3, 0, 350, 261, 401, 516, 263, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3], [48000, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0, 500, 334, 0, 277, 3, 0, 1519, 263, 3, 3, 0, 3635, 428, 641, 264, 261, 0, 511, 1109, 3, 0, 1138, 3, 3, 3, 0, 508, 827, 3, 3, 3, 3, 0, 468, 3, 3, 0, 376, 3, 3, 3, 0, 260, 978, 263, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3], [48000, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0, 414, 0, 527, 261, 3, 0, 409, 3, 3, 3, 0, 271, 3, 0, 309, 3, 0, 285, 3, 0, 521, 371, 609, 3, 3, 0, 260, 959, 3, 3, 3, 0, 272, 3, 0, 265, 0, 546, 262, 3, 3, 3, 3, 3, 3, 0, 291, 3, 0, 975, 2203, 3, 3, 3, 3, 0, 269, 3, 0, 260, 489, 651, 274, 279, 1870, 3, 0, 1084, 873, 273, 3, 0, 260, 531, 3, 3, 0, 409, 262, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0, 1502, 1005, 836, 3, 3, 0, 1666, 306, 3, 0, 340, 3, 0, 260, 3232, 3, 0, 269, 3, 3, 0, 275, 261, 0, 260, 1379, 261, 0, 3324, 3, 3, 3, 3, 0, 549, 3, 3, 0, 693, 405, 323, 3, 0, 266, 3, 3, 0, 265, 0, 699, 263, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3], [48000, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0, 414, 0, 392, 3, 3, 0, 1269, 314, 0, 2607, 261, 3, 3, 3, 0, 1098, 295, 3, 3, 3, 0, 446, 625, 3, 0, 496, 280, 1205, 485, 1071, 1627, 449, 264, 261, 3, 0, 400, 0, 277, 3, 3, 3, 0, 260, 342, 3, 0, 618, 280, 1866, 3, 3, 0, 554, 3, 3, 3, 3, 0, 317, 262, 3, 3, 3, 3, 3, 3, 3, 3, 0, 269, 0, 303, 3, 0, 573, 2615, 3, 3, 0, 276, 3, 0, 275, 0, 305, 3, 0, 260, 415, 3, 3, 0, 272, 3, 3, 3, 3, 0, 1631, 327, 3, 3, 0, 333, 739, 841, 263, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3], - ]) + ] # fmt: on - torch.testing.assert_close(out.cpu(), EXPECTED_TOKENS) + self.assertListEqual(out.cpu().tolist*(), EXPECTED_TOKENS) diff --git a/tests/models/qwen2_5_omni/test_modeling_qwen2_5_omni.py b/tests/models/qwen2_5_omni/test_modeling_qwen2_5_omni.py index e5b1f1a06bab..28be4eba3f85 100644 --- a/tests/models/qwen2_5_omni/test_modeling_qwen2_5_omni.py +++ b/tests/models/qwen2_5_omni/test_modeling_qwen2_5_omni.py @@ -413,7 +413,6 @@ def attention_mask_padding_matches_padding_free_with_position_ids( logits_padded = res_padded.logits[inputs_dict["attention_mask"].bool()] logits_padfree = res_padfree.logits[0] - torch.testing.assert_close(logits_padded.argmax(-1), logits_padfree.argmax(-1), rtol=0, atol=0) # acceptable numerical instability tol = torch.finfo(torch.bfloat16).eps torch.testing.assert_close(logits_padded, logits_padfree, rtol=tol, atol=tol) diff --git a/tests/models/qwen2_5_vl/test_modeling_qwen2_5_vl.py b/tests/models/qwen2_5_vl/test_modeling_qwen2_5_vl.py index 385a714f131c..1b73dd624d22 100644 --- a/tests/models/qwen2_5_vl/test_modeling_qwen2_5_vl.py +++ b/tests/models/qwen2_5_vl/test_modeling_qwen2_5_vl.py @@ -403,7 +403,6 @@ def attention_mask_padding_matches_padding_free_with_position_ids( logits_padded = res_padded.logits[inputs_dict["attention_mask"].bool()] logits_padfree = res_padfree.logits[0] - torch.testing.assert_close(logits_padded.argmax(-1), logits_padfree.argmax(-1), rtol=0, atol=0) # acceptable numerical instability tol = torch.finfo(torch.bfloat16).eps torch.testing.assert_close(logits_padded, logits_padfree, rtol=tol, atol=tol) diff --git a/tests/models/qwen2_vl/test_modeling_qwen2_vl.py b/tests/models/qwen2_vl/test_modeling_qwen2_vl.py index ea0a8992b434..bbe615d50954 100644 --- a/tests/models/qwen2_vl/test_modeling_qwen2_vl.py +++ b/tests/models/qwen2_vl/test_modeling_qwen2_vl.py @@ -362,7 +362,6 @@ def attention_mask_padding_matches_padding_free_with_position_ids( logits_padded = res_padded.logits[inputs_dict["attention_mask"].bool()] logits_padfree = res_padfree.logits[0] - torch.testing.assert_close(logits_padded.argmax(-1), logits_padfree.argmax(-1), rtol=0, atol=0) # acceptable numerical instability tol = torch.finfo(torch.bfloat16).eps torch.testing.assert_close(logits_padded, logits_padfree, rtol=tol, atol=tol) From daeb4d818d313220fc5cb164e10b00d600bf8663 Mon Sep 17 00:00:00 2001 From: raushan Date: Mon, 28 Jul 2025 11:51:14 +0200 Subject: [PATCH 4/7] update musicgen as well --- .../models/musicgen/modeling_musicgen.py | 30 ++++++++++++++++--- .../test_modeling_kyutai_speech_to_text.py | 2 +- 2 files changed, 27 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py index cd76c8716292..91c505c636d3 100644 --- a/src/transformers/models/musicgen/modeling_musicgen.py +++ b/src/transformers/models/musicgen/modeling_musicgen.py @@ -1246,7 +1246,29 @@ def generate( input_ids_length=input_ids_length, ) - # 6. Prepare `input_ids` which will be used for auto-regressive generation + self._validate_generated_length(generation_config, input_ids_length, has_default_max_length) + + # 6. Prepare the cache. + # - `model_kwargs` may be updated in place with a cache as defined by the parameters in `generation_config`. + # - different models have a different cache name expected by the model (default = "past_key_values") + # - `max_length`, prepared above, is used to determine the maximum cache length + max_cache_length = generation_config.max_length - 1 + if ( + input_ids_length.shape[1] != input_ids_length + and model_input_name == "inputs_embeds" + and not self.config.is_encoder_decoder + ): + max_cache_length += input_ids_length.shape[1] + self._prepare_cache_for_generation( + generation_config, + model_kwargs, + assistant_model=None, + batch_size=batch_size, + max_cache_length=max_cache_length, + device=input_ids_length.device, + ) + + # 7. Prepare `input_ids` which will be used for auto-regressive generation # Build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to MusicGen) input_ids, delay_pattern_mask = self.build_delay_pattern_mask( input_ids, @@ -1260,15 +1282,15 @@ def generate( # stash the delay mask so that we don't have to recompute it in each forward pass model_kwargs["delay_pattern_mask"] = delay_pattern_mask - # 7. determine generation mode + # 8. determine generation mode generation_mode = generation_config.get_generation_mode() - # 8. prepare batched CFG externally (to enable coexistence with the unbatched CFG) + # 9. prepare batched CFG externally (to enable coexistence with the unbatched CFG) if generation_config.guidance_scale is not None and generation_config.guidance_scale > 1: logits_processor.append(ClassifierFreeGuidanceLogitsProcessor(generation_config.guidance_scale)) generation_config.guidance_scale = None - # 9. prepare distribution pre_processing samplers + # 10. prepare distribution pre_processing samplers logits_processor = self._get_logits_processor( generation_config=generation_config, input_ids_seq_length=input_ids_length, diff --git a/tests/models/kyutai_speech_to_text/test_modeling_kyutai_speech_to_text.py b/tests/models/kyutai_speech_to_text/test_modeling_kyutai_speech_to_text.py index 3d9856207776..ce4100c38d09 100644 --- a/tests/models/kyutai_speech_to_text/test_modeling_kyutai_speech_to_text.py +++ b/tests/models/kyutai_speech_to_text/test_modeling_kyutai_speech_to_text.py @@ -777,4 +777,4 @@ def test_generation_batched(self): ] # fmt: on - self.assertListEqual(out.cpu().tolist*(), EXPECTED_TOKENS) + self.assertListEqual(out.cpu().tolist(), EXPECTED_TOKENS) From 56758c93e8be21c1ebfc2f6cd3fd40a2907489fb Mon Sep 17 00:00:00 2001 From: raushan Date: Mon, 28 Jul 2025 11:54:18 +0200 Subject: [PATCH 5/7] revert tensor --- .../test_modeling_kyutai_speech_to_text.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/models/kyutai_speech_to_text/test_modeling_kyutai_speech_to_text.py b/tests/models/kyutai_speech_to_text/test_modeling_kyutai_speech_to_text.py index 0d4ec2baa8a1..6be8879afbe4 100644 --- a/tests/models/kyutai_speech_to_text/test_modeling_kyutai_speech_to_text.py +++ b/tests/models/kyutai_speech_to_text/test_modeling_kyutai_speech_to_text.py @@ -769,12 +769,12 @@ def test_generation_batched(self): out = model.generate(**inputs) # fmt: off - EXPECTED_TOKENS = [ + EXPECTED_TOKENS = torch.tensor([ [48000, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0, 1519, 263, 3, 3, 0, 3635, 428, 641, 0, 277, 3, 0, 265, 0, 267, 1162, 261, 274, 410, 0, 272, 3, 0, 265, 0, 260, 1621, 0, 1174, 371, 262, 3, 3, 3, 0, 269, 0, 281, 0, 304, 0, 2433, 3, 0, 266, 3, 0, 281, 1661, 3, 0, 376, 3, 3, 0, 350, 261, 401, 516, 263, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3], [48000, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0, 500, 334, 0, 277, 3, 0, 1519, 263, 3, 3, 0, 3635, 428, 641, 264, 261, 0, 511, 1109, 3, 0, 1138, 3, 3, 3, 0, 508, 827, 3, 3, 3, 3, 0, 468, 3, 3, 0, 376, 3, 3, 3, 0, 260, 978, 263, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3], [48000, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0, 414, 0, 527, 261, 3, 0, 409, 3, 3, 3, 0, 271, 3, 0, 309, 3, 0, 285, 3, 0, 521, 371, 609, 3, 3, 0, 260, 959, 3, 3, 3, 0, 272, 3, 0, 265, 0, 546, 262, 3, 3, 3, 3, 3, 3, 0, 291, 3, 0, 975, 2203, 3, 3, 3, 3, 0, 269, 3, 0, 260, 489, 651, 274, 279, 1870, 3, 0, 1084, 873, 273, 3, 0, 260, 531, 3, 3, 0, 409, 262, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0, 1502, 1005, 836, 3, 3, 0, 1666, 306, 3, 0, 340, 3, 0, 260, 3232, 3, 0, 269, 3, 3, 0, 275, 261, 0, 260, 1379, 261, 0, 3324, 3, 3, 3, 3, 0, 549, 3, 3, 0, 693, 405, 323, 3, 0, 266, 3, 3, 0, 265, 0, 699, 263, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3], [48000, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0, 414, 0, 392, 3, 3, 0, 1269, 314, 0, 2607, 261, 3, 3, 3, 0, 1098, 295, 3, 3, 3, 0, 446, 625, 3, 0, 496, 280, 1205, 485, 1071, 1627, 449, 264, 261, 3, 0, 400, 0, 277, 3, 3, 3, 0, 260, 342, 3, 0, 618, 280, 1866, 3, 3, 0, 554, 3, 3, 3, 3, 0, 317, 262, 3, 3, 3, 3, 3, 3, 3, 3, 0, 269, 0, 303, 3, 0, 573, 2615, 3, 3, 0, 276, 3, 0, 275, 0, 305, 3, 0, 260, 415, 3, 3, 0, 272, 3, 3, 3, 3, 0, 1631, 327, 3, 3, 0, 333, 739, 841, 263, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3], - ] + ]) # fmt: on # See https://github.com/huggingface/transformers/pull/39416 From 86fafef2b7e3b30fa99dcf3ed4e720fb9bd39b8c Mon Sep 17 00:00:00 2001 From: raushan Date: Mon, 28 Jul 2025 13:08:35 +0200 Subject: [PATCH 6/7] fix old test failures --- tests/models/llava_next/test_modeling_llava_next.py | 4 ++-- .../models/llava_next_video/test_modeling_llava_next_video.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/models/llava_next/test_modeling_llava_next.py b/tests/models/llava_next/test_modeling_llava_next.py index 8c91176225f7..0c5c771b55c9 100644 --- a/tests/models/llava_next/test_modeling_llava_next.py +++ b/tests/models/llava_next/test_modeling_llava_next.py @@ -515,7 +515,7 @@ def test_small_model_integration_test_full_vision_state_selection(self): # test that changing `strategy` won't error out model.vision_feature_select_strategy = "full" - inputs = self.processor(self.prompt, self.image, return_tensors="pt").to(model.device) + inputs = self.processor(text=self.prompt, images=self.image, return_tensors="pt").to(model.device) # verify generation output = model.generate(**inputs, max_new_tokens=30) @@ -536,7 +536,7 @@ def test_granite_vision(self): model = LlavaNextForConditionalGeneration.from_pretrained(granite_model_path) self.processor = AutoProcessor.from_pretrained(granite_model_path) prompt = "<|user|>\n\nWhat is shown in this image?\n<|assistant|>\n" - inputs = self.processor(prompt, self.image, return_tensors="pt").to(model.device) + inputs = self.processor(text=prompt, images=self.image, return_tensors="pt").to(model.device) # verify generation output = model.generate(**inputs, max_new_tokens=30) diff --git a/tests/models/llava_next_video/test_modeling_llava_next_video.py b/tests/models/llava_next_video/test_modeling_llava_next_video.py index 2cddb1ecfd36..fa0432ce6e77 100644 --- a/tests/models/llava_next_video/test_modeling_llava_next_video.py +++ b/tests/models/llava_next_video/test_modeling_llava_next_video.py @@ -467,7 +467,7 @@ def test_small_model_integration_test_batch_matches_single(self): padding=True, ).to(torch_device) - inputs_single = self.processor(self.prompt_video, videos=[self.video], return_tensors="pt").to(torch_device) + inputs_single = self.processor(text=self.prompt_video, videos=[self.video], return_tensors="pt").to(torch_device) # verify generation output_batched = model.generate(**inputs_batched, do_sample=False, max_new_tokens=50) From 6d34e6f53a1ec3bb1ee9a809b500e903d6b734de Mon Sep 17 00:00:00 2001 From: raushan Date: Mon, 28 Jul 2025 13:09:55 +0200 Subject: [PATCH 7/7] why it wasn't added? --- .../models/llava_next_video/test_modeling_llava_next_video.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/models/llava_next_video/test_modeling_llava_next_video.py b/tests/models/llava_next_video/test_modeling_llava_next_video.py index fa0432ce6e77..3230b50e7299 100644 --- a/tests/models/llava_next_video/test_modeling_llava_next_video.py +++ b/tests/models/llava_next_video/test_modeling_llava_next_video.py @@ -467,7 +467,9 @@ def test_small_model_integration_test_batch_matches_single(self): padding=True, ).to(torch_device) - inputs_single = self.processor(text=self.prompt_video, videos=[self.video], return_tensors="pt").to(torch_device) + inputs_single = self.processor(text=self.prompt_video, videos=[self.video], return_tensors="pt").to( + torch_device + ) # verify generation output_batched = model.generate(**inputs_batched, do_sample=False, max_new_tokens=50)