[t5gemma] fix get_text_config and related fixes (#40939)
Changes from all commits
Changes to the cache code:

```diff
@@ -395,7 +395,12 @@ def update(
         if not self.is_initialized:
             self.lazy_initialization(key_states)

-        cache_position = cache_kwargs.get("cache_position")
+        # Some old models give None for `cache_position` or even omit passing `cache_kwargs` when used as cross-attention,
+        # in which case we should copy the whole Layer (key_states.shape[-2] == self.max_cache_len)
+        cache_position = cache_kwargs.get("cache_position") if cache_kwargs is not None else None
+        cache_position = (
+            cache_position if cache_position is not None else torch.arange(key_states.shape[-2], device=self.device)
+        )

         cumulative_length = self.cumulative_length
         is_full = cumulative_length >= self.max_cache_len
```
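To illustrate the fallback this hunk introduces, here is a minimal standalone sketch; the `resolve_cache_position` helper and its signature are hypothetical, not part of transformers. When `cache_kwargs` is missing entirely or carries no `cache_position`, the update falls back to one slot per incoming key, i.e. the whole layer is written.

```python
import torch


def resolve_cache_position(key_states, cache_kwargs, device):
    # Hypothetical helper mirroring the fixed logic above: tolerate both a
    # missing `cache_kwargs` dict (older cross-attention call sites) and an
    # explicit `cache_position=None`.
    cache_position = cache_kwargs.get("cache_position") if cache_kwargs is not None else None
    if cache_position is None:
        # Fallback: copy the whole layer, one position per incoming key.
        cache_position = torch.arange(key_states.shape[-2], device=device)
    return cache_position


# A cross-attention style call that passes no cache_kwargs at all.
keys = torch.zeros(1, 8, 16, 64)  # (batch, num_heads, seq_len, head_dim)
positions = resolve_cache_position(keys, cache_kwargs=None, device="cpu")
assert positions.tolist() == list(range(16))
```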
```diff
@@ -955,17 +960,19 @@ def __init__(
         layers = []
         # If a config is passed, use it to infer the layer types and initialize accordingly
         if config is not None:
-            config = config.get_text_config(decoder=True)
-            sliding_window = getattr(config, "sliding_window", None) or getattr(config, "attention_chunk_size", None)
-            layer_types = getattr(config, "layer_types", None)
+            decoder_config = config.get_text_config(decoder=True)
+            sliding_window = getattr(decoder_config, "sliding_window", None) or getattr(
+                decoder_config, "attention_chunk_size", None
+            )
+            layer_types = getattr(decoder_config, "layer_types", None)
             if layer_types is None:
                 layer_types = [
                     "sliding_attention" if sliding_window is not None else "full_attention"
-                    for _ in range(config.num_hidden_layers)
+                    for _ in range(decoder_config.num_hidden_layers)
                 ]
             # Some models have shared layers thus no cache is needed for them (e.g. Gemma3n)
-            if hasattr(config, "num_kv_shared_layers"):
-                layer_types = layer_types[: -config.num_kv_shared_layers]
+            if hasattr(decoder_config, "num_kv_shared_layers"):
+                layer_types = layer_types[: -decoder_config.num_kv_shared_layers]

         for layer_type in layer_types:
             # From a cache point of view, both sliding and chunked are the same in how they should behave and how many
```

Author (Contributor) comment on this hunk: (no actual change here, except …)
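For context, a small sketch of the layer-type inference this hunk now performs on the decoder config; the `SimpleNamespace` stand-in is only illustrative, real configs are `PretrainedConfig` subclasses with these attributes.

```python
from types import SimpleNamespace

# Hypothetical stand-in for a decoder text config; attribute names follow the
# hunk above, but this is not an actual transformers config class.
decoder_config = SimpleNamespace(
    sliding_window=512,
    layer_types=None,
    num_hidden_layers=4,
)

sliding_window = getattr(decoder_config, "sliding_window", None) or getattr(
    decoder_config, "attention_chunk_size", None
)
layer_types = getattr(decoder_config, "layer_types", None)
if layer_types is None:
    # No explicit per-layer types: every layer follows the sliding-window flag.
    layer_types = [
        "sliding_attention" if sliding_window is not None else "full_attention"
        for _ in range(decoder_config.num_hidden_layers)
    ]

print(layer_types)  # ['sliding_attention', 'sliding_attention', 'sliding_attention', 'sliding_attention']
```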
Changes to the T5Gemma configuration:

```diff
@@ -323,9 +323,5 @@ def __setattr__(self, key, value):
                 setattr(self.decoder, key, value)
         super().__setattr__(key, value)

-    def get_text_config(self, *args, **kwargs):
-        # Always return self, regardless of the decoder option.
-        return self
-

 __all__ = ["T5GemmaConfig", "T5GemmaModuleConfig"]
```

Author (Contributor) comment on lines 326 to 328: This was preventing us from accessing the decoder config in …
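A rough sketch of the behavioral difference, using made-up class names rather than the actual transformers implementation: the removed override forced `get_text_config` to always return the composite T5Gemma config, whereas without it the lookup can hand back the nested decoder config that the cache initialization above relies on.

```python
class ModuleConfig:
    # Stand-in for a per-module (encoder or decoder) text config.
    def __init__(self, num_hidden_layers, sliding_window=None):
        self.num_hidden_layers = num_hidden_layers
        self.sliding_window = sliding_window


class CompositeConfig:
    # Stand-in for a composite encoder-decoder config such as T5GemmaConfig.
    def __init__(self, encoder, decoder):
        self.encoder = encoder
        self.decoder = decoder
        self.is_encoder_decoder = True

    def get_text_config(self, decoder=False):
        # Simplified sketch of the base-class behavior the override used to hide.
        if decoder and self.decoder is not None:
            return self.decoder
        return self


config = CompositeConfig(ModuleConfig(6), ModuleConfig(6, sliding_window=512))
assert config.get_text_config(decoder=True).sliding_window == 512  # nested decoder config
assert config.get_text_config(decoder=False) is config             # composite config itself
```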
Changes to the generation tests:

```diff
@@ -944,7 +944,7 @@ def test_left_padding_compatibility(self):
         decoder_only_classes = []
         for model_class in self.all_generative_model_classes:
             config, _ = self.prepare_config_and_inputs_for_generate()
-            if config.get_text_config(decoder=True).is_encoder_decoder:
+            if config.is_encoder_decoder:
                 continue
             else:
                 decoder_only_classes.append(model_class)
```

Author (Contributor) comment: This logic was fundamentally wrong 👀 (…)
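Why the old check misbehaves once `get_text_config(decoder=True)` resolves to the nested decoder config, shown with purely illustrative classes (not the real T5Gemma configs): the decoder sub-config describes a plain decoder stack and reports `is_encoder_decoder=False`, so encoder-decoder models the tests mean to skip would slip through; the top-level flag is the one to consult.

```python
class DecoderModuleConfig:
    # A decoder stack on its own is not an encoder-decoder model.
    is_encoder_decoder = False


class T5GemmaLikeConfig:
    is_encoder_decoder = True

    def get_text_config(self, decoder=False):
        # Illustrative: decoder=True now yields the nested decoder config.
        return DecoderModuleConfig() if decoder else self


config = T5GemmaLikeConfig()
print(config.get_text_config(decoder=True).is_encoder_decoder)  # False -> old check misses the model
print(config.is_encoder_decoder)                                 # True  -> new check skips correctly
```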
```diff
@@ -1192,7 +1192,7 @@ def test_generate_from_inputs_embeds(self, _, num_beams):
             # This test is for decoder-only models (encoder-decoder models have native input embeddings support in the
             # decoder)
-            if config.get_text_config(decoder=True).is_encoder_decoder:
+            if config.is_encoder_decoder:
                 continue
             config.is_decoder = True
```

```diff
@@ -1271,7 +1271,7 @@ def test_generate_from_inputs_embeds_with_static_cache(self):
             config, inputs_dict = self.prepare_config_and_inputs_for_generate()

-            if config.get_text_config(decoder=True).is_encoder_decoder:
+            if config.is_encoder_decoder:
                 self.skipTest(reason="This model is encoder-decoder and has Encoder-Decoder Cache")

             model = model_class(config).to(torch_device).eval()
```

```diff
@@ -1422,7 +1422,7 @@ def test_generate_continue_from_inputs_embeds(self):
             if "token_type_ids" in inputs_dict:
                 del inputs_dict["token_type_ids"]

-            if config.get_text_config(decoder=True).is_encoder_decoder:
+            if config.is_encoder_decoder:
                 self.skipTest(reason="This model is encoder-decoder")
             # TODO (joao, raushan): the correct line below is `if not hasattr(config.get_text_config(), "use_cache")`,
             # but it breaks a few models. Fix and then apply `has_similar_generate_outputs` pattern
```

```diff
@@ -1495,7 +1495,7 @@ def test_generate_with_static_cache(self):
             set_config_for_less_flaky_test(config)
             main_input = inputs_dict[model_class.main_input_name]

-            if config.get_text_config(decoder=True).is_encoder_decoder:
+            if config.is_encoder_decoder:
                 self.skipTest(reason="This model is encoder-decoder and has Encoder-Decoder Cache")

             config.is_decoder = True
```

```diff
@@ -1550,10 +1550,7 @@ def test_generate_with_quant_cache(self):
         for model_class in self.all_generative_model_classes:
             config, inputs_dict = self.prepare_config_and_inputs_for_generate()

-            if (
-                config.get_text_config(decoder=True).is_encoder_decoder
-                or not model_class._supports_default_dynamic_cache()
-            ):
+            if config.is_encoder_decoder or not model_class._supports_default_dynamic_cache():
                 self.skipTest(reason="This model does not support the quantized cache format")

             config.is_decoder = True
```

```diff
@@ -1653,7 +1650,7 @@ def test_generate_compile_model_forward_fullgraph(self):
             if not has_defined_cache_implementation:
                 decoder_cache = (
                     gen_out.past_key_values.self_attention_cache
-                    if config.get_text_config(decoder=True).is_encoder_decoder
+                    if config.is_encoder_decoder
                     else gen_out.past_key_values
                 )
                 self.assertTrue(isinstance(decoder_cache, DynamicCache))
```

```diff
@@ -1679,7 +1676,7 @@ def test_generate_compile_model_forward_fullgraph(self):
             # sanity checks
             decoder_cache = (
                 gen_out.past_key_values.self_attention_cache
-                if config.get_text_config(decoder=True).is_encoder_decoder
+                if config.is_encoder_decoder
                 else gen_out.past_key_values
             )
             self.assertFalse(isinstance(decoder_cache, DynamicCache))
```
Review comment: Pattern copied from `StaticLayer` (t5gemma may have static sliding-window encoder cache layers).