diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py
index e6f2645a766e..528e0b36f853 100644
--- a/src/transformers/cache_utils.py
+++ b/src/transformers/cache_utils.py
@@ -395,7 +395,12 @@ def update(
         if not self.is_initialized:
             self.lazy_initialization(key_states)
 
-        cache_position = cache_kwargs.get("cache_position")
+        # Some old models give None for `cache_position` or even omit passing `cache_kwargs` when used as cross-attention,
+        # in which case we should copy the whole Layer (key_states.shape[-2] == self.max_cache_len)
+        cache_position = cache_kwargs.get("cache_position") if cache_kwargs is not None else None
+        cache_position = (
+            cache_position if cache_position is not None else torch.arange(key_states.shape[-2], device=self.device)
+        )
         cumulative_length = self.cumulative_length
         is_full = cumulative_length >= self.max_cache_len
 
@@ -955,17 +960,19 @@ def __init__(
         layers = []
         # If a config is passed, use it to infer the layer types and initialize accordingly
         if config is not None:
-            config = config.get_text_config(decoder=True)
-            sliding_window = getattr(config, "sliding_window", None) or getattr(config, "attention_chunk_size", None)
-            layer_types = getattr(config, "layer_types", None)
+            decoder_config = config.get_text_config(decoder=True)
+            sliding_window = getattr(decoder_config, "sliding_window", None) or getattr(
+                decoder_config, "attention_chunk_size", None
+            )
+            layer_types = getattr(decoder_config, "layer_types", None)
             if layer_types is None:
                 layer_types = [
                     "sliding_attention" if sliding_window is not None else "full_attention"
-                    for _ in range(config.num_hidden_layers)
+                    for _ in range(decoder_config.num_hidden_layers)
                 ]
             # Some models have shared layers thus no cache is needed for them (e.g. Gemma3n)
-            if hasattr(config, "num_kv_shared_layers"):
-                layer_types = layer_types[: -config.num_kv_shared_layers]
+            if hasattr(decoder_config, "num_kv_shared_layers"):
+                layer_types = layer_types[: -decoder_config.num_kv_shared_layers]
 
             for layer_type in layer_types:
                 # From a cache point of view, both sliding and chunked are the same in how they should behave and how many
diff --git a/src/transformers/models/t5gemma/configuration_t5gemma.py b/src/transformers/models/t5gemma/configuration_t5gemma.py
index 217a24df0417..2085cc8aa517 100644
--- a/src/transformers/models/t5gemma/configuration_t5gemma.py
+++ b/src/transformers/models/t5gemma/configuration_t5gemma.py
@@ -323,9 +323,5 @@ def __setattr__(self, key, value):
             setattr(self.decoder, key, value)
         super().__setattr__(key, value)
 
-    def get_text_config(self, *args, **kwargs):
-        # Always return self, regardless of the decoder option.
-        return self
-
 
 __all__ = ["T5GemmaConfig", "T5GemmaModuleConfig"]
diff --git a/src/transformers/models/t5gemma/modular_t5gemma.py b/src/transformers/models/t5gemma/modular_t5gemma.py
index 4ac42d99239c..7e68245c235f 100644
--- a/src/transformers/models/t5gemma/modular_t5gemma.py
+++ b/src/transformers/models/t5gemma/modular_t5gemma.py
@@ -206,10 +206,6 @@ def __setattr__(self, key, value):
             setattr(self.decoder, key, value)
         super().__setattr__(key, value)
 
-    def get_text_config(self, *args, **kwargs):
-        # Always return self, regardless of the decoder option.
-        return self
-
 
 class T5GemmaRMSNorm(Gemma2RMSNorm):
     pass
diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py
index 680002d4600b..dfe6dfed355b 100644
--- a/tests/generation/test_utils.py
+++ b/tests/generation/test_utils.py
@@ -944,7 +944,7 @@ def test_left_padding_compatibility(self):
         decoder_only_classes = []
         for model_class in self.all_generative_model_classes:
             config, _ = self.prepare_config_and_inputs_for_generate()
-            if config.get_text_config(decoder=True).is_encoder_decoder:
+            if config.is_encoder_decoder:
                 continue
             else:
                 decoder_only_classes.append(model_class)
@@ -1192,7 +1192,7 @@ def test_generate_from_inputs_embeds(self, _, num_beams):
 
             # This test is for decoder-only models (encoder-decoder models have native input embeddings support in the
             # decoder)
-            if config.get_text_config(decoder=True).is_encoder_decoder:
+            if config.is_encoder_decoder:
                 continue
 
             config.is_decoder = True
@@ -1271,7 +1271,7 @@ def test_generate_from_inputs_embeds_with_static_cache(self):
 
             config, inputs_dict = self.prepare_config_and_inputs_for_generate()
 
-            if config.get_text_config(decoder=True).is_encoder_decoder:
+            if config.is_encoder_decoder:
                 self.skipTest(reason="This model is encoder-decoder and has Encoder-Decoder Cache")
 
             model = model_class(config).to(torch_device).eval()
@@ -1422,7 +1422,7 @@ def test_generate_continue_from_inputs_embeds(self):
             if "token_type_ids" in inputs_dict:
                 del inputs_dict["token_type_ids"]
 
-            if config.get_text_config(decoder=True).is_encoder_decoder:
+            if config.is_encoder_decoder:
                 self.skipTest(reason="This model is encoder-decoder")
             # TODO (joao, raushan): the correct line below is `if not hasattr(config.get_text_config(), "use_cache")`,
             # but it breaks a few models. Fix and then apply `has_similar_generate_outputs` pattern
@@ -1495,7 +1495,7 @@ def test_generate_with_static_cache(self):
             set_config_for_less_flaky_test(config)
             main_input = inputs_dict[model_class.main_input_name]
 
-            if config.get_text_config(decoder=True).is_encoder_decoder:
+            if config.is_encoder_decoder:
                 self.skipTest(reason="This model is encoder-decoder and has Encoder-Decoder Cache")
 
             config.is_decoder = True
@@ -1550,10 +1550,7 @@ def test_generate_with_quant_cache(self):
         for model_class in self.all_generative_model_classes:
             config, inputs_dict = self.prepare_config_and_inputs_for_generate()
 
-            if (
-                config.get_text_config(decoder=True).is_encoder_decoder
-                or not model_class._supports_default_dynamic_cache()
-            ):
+            if config.is_encoder_decoder or not model_class._supports_default_dynamic_cache():
                 self.skipTest(reason="This model does not support the quantized cache format")
 
             config.is_decoder = True
@@ -1653,7 +1650,7 @@ def test_generate_compile_model_forward_fullgraph(self):
             if not has_defined_cache_implementation:
                 decoder_cache = (
                     gen_out.past_key_values.self_attention_cache
-                    if config.get_text_config(decoder=True).is_encoder_decoder
+                    if config.is_encoder_decoder
                     else gen_out.past_key_values
                 )
                 self.assertTrue(isinstance(decoder_cache, DynamicCache))
@@ -1679,7 +1676,7 @@ def test_generate_compile_model_forward_fullgraph(self):
             # sanity checks
             decoder_cache = (
                 gen_out.past_key_values.self_attention_cache
-                if config.get_text_config(decoder=True).is_encoder_decoder
+                if config.is_encoder_decoder
                 else gen_out.past_key_values
             )
             self.assertFalse(isinstance(decoder_cache, DynamicCache))
diff --git a/tests/models/gemma3n/test_modeling_gemma3n.py b/tests/models/gemma3n/test_modeling_gemma3n.py
index 5e4b774a8bd0..630eb921d94e 100644
--- a/tests/models/gemma3n/test_modeling_gemma3n.py
+++ b/tests/models/gemma3n/test_modeling_gemma3n.py
@@ -461,7 +461,7 @@ def test_generate_from_inputs_embeds_with_static_cache(self):
 
             config, inputs_dict = self.prepare_config_and_inputs_for_generate()
 
-            if config.get_text_config(decoder=True).is_encoder_decoder:
+            if config.is_encoder_decoder:
                 self.skipTest(reason="This model is encoder-decoder and has Encoder-Decoder Cache")
 
             model = model_class(config).to(torch_device).eval()
@@ -522,7 +522,7 @@ def test_generate_with_static_cache(self):
             set_config_for_less_flaky_test(config)
             main_input = inputs_dict[model_class.main_input_name]
 
-            if config.get_text_config(decoder=True).is_encoder_decoder:
+            if config.is_encoder_decoder:
                 self.skipTest(reason="This model is encoder-decoder and has Encoder-Decoder Cache")
 
             config.is_decoder = True
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index 4e95b1f255a5..0965a3e2ea52 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -4391,7 +4391,7 @@ def test_flex_attention_with_grads(self):
             if key in inputs_dict:
                 dummy_inputs[key] = inputs_dict[key].to(torch_device)
 
-        if config.get_text_config(decoder=True).is_encoder_decoder:
+        if config.is_encoder_decoder:
             dummy_inputs["decoder_input_ids"] = inputs_dict["decoder_input_ids"].to(torch_device)
             dummy_inputs["decoder_attention_mask"] = inputs_dict["decoder_attention_mask"].to(torch_device)
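
Note on the cache_utils.py hunk above: `update()` may now be called with `cache_kwargs=None` or with `cache_position=None`, which per the inline comment happens when some older encoder-decoder models drive a static layer as cross-attention. The sketch below isolates that fallback for illustration only; `resolve_cache_position` is a hypothetical helper and not part of the patch.

    import torch

    def resolve_cache_position(cache_kwargs, key_states, device):
        # Previous behavior assumed `cache_kwargs["cache_position"]` was always provided.
        cache_position = cache_kwargs.get("cache_position") if cache_kwargs is not None else None
        # New fallback: default to positions covering the whole key/value tensor, which for
        # cross-attention means writing the entire layer (key_states.shape[-2] == max_cache_len).
        if cache_position is None:
            cache_position = torch.arange(key_states.shape[-2], device=device)
        return cache_position

In other words, when nothing is passed, a layer of length `max_cache_len` receives positions `0 .. max_cache_len - 1`, so the full encoder key/value states are copied in one shot.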