21 changes: 14 additions & 7 deletions src/transformers/cache_utils.py
@@ -395,7 +395,12 @@ def update(
         if not self.is_initialized:
             self.lazy_initialization(key_states)
 
-        cache_position = cache_kwargs.get("cache_position")
+        # Some old models give None for `cache_position` or even omit passing `cache_kwargs` when used as cross-attention,
+        # in which case we should copy the whole Layer (key_states.shape[-2] == self.max_cache_len)
+        cache_position = cache_kwargs.get("cache_position") if cache_kwargs is not None else None
+        cache_position = (
+            cache_position if cache_position is not None else torch.arange(key_states.shape[-2], device=self.device)
+        )
Comment on lines +398 to +403 (@gante, Contributor/Author, Sep 17, 2025):
Pattern copied from StaticLayer
(t5gemma may have static sliding-window encoder cache layers)


         cumulative_length = self.cumulative_length
         is_full = cumulative_length >= self.max_cache_len
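For reference, a minimal sketch of the fallback behavior (the helper name `resolve_cache_position` is illustrative, not part of the library; only `torch` is assumed): when `cache_kwargs` is missing or carries no `cache_position`, the update covers the full key length, which is what copying an entire cross-attention layer requires.

```python
import torch

def resolve_cache_position(cache_kwargs, key_states, device):
    # Mirrors the added logic: old models may pass cache_kwargs=None (or omit
    # "cache_position") when the layer is used for cross-attention.
    cache_position = cache_kwargs.get("cache_position") if cache_kwargs is not None else None
    if cache_position is None:
        # Fall back to indexing the whole key length, i.e. copy the entire layer.
        cache_position = torch.arange(key_states.shape[-2], device=device)
    return cache_position

key_states = torch.zeros(1, 4, 6, 8)  # (batch, num_heads, seq_len=6, head_dim)
print(resolve_cache_position(None, key_states, key_states.device))  # tensor([0, 1, 2, 3, 4, 5])
print(resolve_cache_position({"cache_position": torch.tensor([3])}, key_states, key_states.device))  # tensor([3])
```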
@@ -955,17 +960,19 @@ def __init__(
         layers = []
         # If a config is passed, use it to infer the layer types and initialize accordingly
         if config is not None:
-            config = config.get_text_config(decoder=True)
-            sliding_window = getattr(config, "sliding_window", None) or getattr(config, "attention_chunk_size", None)
-            layer_types = getattr(config, "layer_types", None)
+            decoder_config = config.get_text_config(decoder=True)
Comment (@gante, Contributor/Author):
(no actual change here, except config -> decoder_config for a more descriptive name and easier debugging)

+            sliding_window = getattr(decoder_config, "sliding_window", None) or getattr(
+                decoder_config, "attention_chunk_size", None
+            )
+            layer_types = getattr(decoder_config, "layer_types", None)
             if layer_types is None:
                 layer_types = [
                     "sliding_attention" if sliding_window is not None else "full_attention"
-                    for _ in range(config.num_hidden_layers)
+                    for _ in range(decoder_config.num_hidden_layers)
                 ]
             # Some models have shared layers thus no cache is needed for them (e.g. Gemma3n)
-            if hasattr(config, "num_kv_shared_layers"):
-                layer_types = layer_types[: -config.num_kv_shared_layers]
+            if hasattr(decoder_config, "num_kv_shared_layers"):
+                layer_types = layer_types[: -decoder_config.num_kv_shared_layers]

             for layer_type in layer_types:
                 # From a cache point of view, both sliding and chunked are the same in how they should behave and how many
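To make the layer-type inference above concrete, here is a small sketch with made-up decoder config values (the variable names mirror the snippet; none of the values come from a real checkpoint):

```python
# Hypothetical decoder config values, for illustration only.
sliding_window = 512        # e.g. getattr(decoder_config, "sliding_window", None)
layer_types = None          # the config defines no explicit layer_types
num_hidden_layers = 4
num_kv_shared_layers = 1    # Gemma3n-style shared layers need no cache entry

if layer_types is None:
    layer_types = [
        "sliding_attention" if sliding_window is not None else "full_attention"
        for _ in range(num_hidden_layers)
    ]
layer_types = layer_types[:-num_kv_shared_layers]
print(layer_types)  # ['sliding_attention', 'sliding_attention', 'sliding_attention']
```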
4 changes: 0 additions & 4 deletions src/transformers/models/t5gemma/configuration_t5gemma.py
@@ -323,9 +323,5 @@ def __setattr__(self, key, value):
             setattr(self.decoder, key, value)
         super().__setattr__(key, value)
 
-    def get_text_config(self, *args, **kwargs):
-        # Always return self, regardless of the decoder option.
-        return self
Comment on lines -326 to -328 (@gante, Contributor/Author):
This was preventing us from accessing the decoder config via get_text_config(decoder=True). In some places, like the KV cache initialization, we specifically need the decoder settings.



__all__ = ["T5GemmaConfig", "T5GemmaModuleConfig"]
4 changes: 0 additions & 4 deletions src/transformers/models/t5gemma/modular_t5gemma.py
@@ -206,10 +206,6 @@ def __setattr__(self, key, value):
             setattr(self.decoder, key, value)
         super().__setattr__(key, value)
 
-    def get_text_config(self, *args, **kwargs):
-        # Always return self, regardless of the decoder option.
-        return self
-
 
 class T5GemmaRMSNorm(Gemma2RMSNorm):
     pass
19 changes: 8 additions & 11 deletions tests/generation/test_utils.py
@@ -944,7 +944,7 @@ def test_left_padding_compatibility(self):
         decoder_only_classes = []
         for model_class in self.all_generative_model_classes:
             config, _ = self.prepare_config_and_inputs_for_generate()
-            if config.get_text_config(decoder=True).is_encoder_decoder:
+            if config.is_encoder_decoder:
                 continue
             else:
                 decoder_only_classes.append(model_class)

Comment (@gante, Contributor/Author, Sep 17, 2025):
This logic was fundamentally wrong 👀
(is_encoder_decoder is a root-level parameter, not a decoder-level parameter)
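A hedged illustration of the distinction (using `T5GemmaConfig` purely as an example of a composite encoder-decoder config; the decoder sub-config is assumed to keep the `PretrainedConfig` default of `False`):

```python
from transformers import T5GemmaConfig

config = T5GemmaConfig()  # example composite encoder-decoder config
# The flag the tests care about lives on the root config...
print(config.is_encoder_decoder)  # expected: True
# ...while the decoder sub-config typically keeps the PretrainedConfig default,
# so the old `get_text_config(decoder=True).is_encoder_decoder` check could miss it.
print(config.get_text_config(decoder=True).is_encoder_decoder)  # expected: False
```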
@@ -1192,7 +1192,7 @@ def test_generate_from_inputs_embeds(self, _, num_beams):

             # This test is for decoder-only models (encoder-decoder models have native input embeddings support in the
             # decoder)
-            if config.get_text_config(decoder=True).is_encoder_decoder:
+            if config.is_encoder_decoder:
                 continue
             config.is_decoder = True

@@ -1271,7 +1271,7 @@ def test_generate_from_inputs_embeds_with_static_cache(self):

             config, inputs_dict = self.prepare_config_and_inputs_for_generate()
 
-            if config.get_text_config(decoder=True).is_encoder_decoder:
+            if config.is_encoder_decoder:
                 self.skipTest(reason="This model is encoder-decoder and has Encoder-Decoder Cache")
 
             model = model_class(config).to(torch_device).eval()
@@ -1422,7 +1422,7 @@ def test_generate_continue_from_inputs_embeds(self):
if "token_type_ids" in inputs_dict:
del inputs_dict["token_type_ids"]

if config.get_text_config(decoder=True).is_encoder_decoder:
if config.is_encoder_decoder:
self.skipTest(reason="This model is encoder-decoder")
# TODO (joao, raushan): the correct line below is `if not hasattr(config.get_text_config(), "use_cache")`,
# but it breaks a few models. Fix and then apply `has_similar_generate_outputs` pattern
@@ -1495,7 +1495,7 @@ def test_generate_with_static_cache(self):
             set_config_for_less_flaky_test(config)
             main_input = inputs_dict[model_class.main_input_name]
 
-            if config.get_text_config(decoder=True).is_encoder_decoder:
+            if config.is_encoder_decoder:
                 self.skipTest(reason="This model is encoder-decoder and has Encoder-Decoder Cache")
 
             config.is_decoder = True
@@ -1550,10 +1550,7 @@ def test_generate_with_quant_cache(self):
         for model_class in self.all_generative_model_classes:
             config, inputs_dict = self.prepare_config_and_inputs_for_generate()
 
-            if (
-                config.get_text_config(decoder=True).is_encoder_decoder
-                or not model_class._supports_default_dynamic_cache()
-            ):
+            if config.is_encoder_decoder or not model_class._supports_default_dynamic_cache():
                 self.skipTest(reason="This model does not support the quantized cache format")
 
             config.is_decoder = True
@@ -1653,7 +1650,7 @@ def test_generate_compile_model_forward_fullgraph(self):
             if not has_defined_cache_implementation:
                 decoder_cache = (
                     gen_out.past_key_values.self_attention_cache
-                    if config.get_text_config(decoder=True).is_encoder_decoder
+                    if config.is_encoder_decoder
                     else gen_out.past_key_values
                 )
                 self.assertTrue(isinstance(decoder_cache, DynamicCache))
@@ -1679,7 +1676,7 @@ def test_generate_compile_model_forward_fullgraph(self):
             # sanity checks
             decoder_cache = (
                 gen_out.past_key_values.self_attention_cache
-                if config.get_text_config(decoder=True).is_encoder_decoder
+                if config.is_encoder_decoder
                 else gen_out.past_key_values
             )
             self.assertFalse(isinstance(decoder_cache, DynamicCache))
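For context, a hedged sketch of why the cache is picked this way (assuming the public `EncoderDecoderCache(self_attention_cache, cross_attention_cache)` constructor): encoder-decoder models return an `EncoderDecoderCache` from `generate`, so the decoder's own cache lives under `.self_attention_cache`.

```python
from transformers import DynamicCache, EncoderDecoderCache

past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())  # (self-attention, cross-attention)
is_encoder_decoder = True  # in the test this comes from config.is_encoder_decoder

decoder_cache = past_key_values.self_attention_cache if is_encoder_decoder else past_key_values
print(type(decoder_cache).__name__)  # DynamicCache
```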
4 changes: 2 additions & 2 deletions tests/models/gemma3n/test_modeling_gemma3n.py
@@ -461,7 +461,7 @@ def test_generate_from_inputs_embeds_with_static_cache(self):

             config, inputs_dict = self.prepare_config_and_inputs_for_generate()
 
-            if config.get_text_config(decoder=True).is_encoder_decoder:
+            if config.is_encoder_decoder:
                 self.skipTest(reason="This model is encoder-decoder and has Encoder-Decoder Cache")
 
             model = model_class(config).to(torch_device).eval()
@@ -522,7 +522,7 @@ def test_generate_with_static_cache(self):
             set_config_for_less_flaky_test(config)
             main_input = inputs_dict[model_class.main_input_name]
 
-            if config.get_text_config(decoder=True).is_encoder_decoder:
+            if config.is_encoder_decoder:
                 self.skipTest(reason="This model is encoder-decoder and has Encoder-Decoder Cache")
 
             config.is_decoder = True
2 changes: 1 addition & 1 deletion tests/test_modeling_common.py
@@ -4391,7 +4391,7 @@ def test_flex_attention_with_grads(self):
             if key in inputs_dict:
                 dummy_inputs[key] = inputs_dict[key].to(torch_device)
 
-        if config.get_text_config(decoder=True).is_encoder_decoder:
+        if config.is_encoder_decoder:
             dummy_inputs["decoder_input_ids"] = inputs_dict["decoder_input_ids"].to(torch_device)
             dummy_inputs["decoder_attention_mask"] = inputs_dict["decoder_attention_mask"].to(torch_device)