From 9311e59761b209e7f613285b53c7a512e93c34b3 Mon Sep 17 00:00:00 2001 From: Garrett Goon Date: Fri, 25 Jul 2025 15:10:59 +0000 Subject: [PATCH 1/7] start fixing kwarg handling --- .../models/granitemoe/modeling_granitemoe.py | 1 + .../modeling_granitemoehybrid.py | 38 +++++++++++++++++-- .../modular_granitemoehybrid.py | 11 ++++-- .../modeling_granitemoeshared.py | 35 +++++++++++++++-- .../modular_granitemoeshared.py | 32 ++++++++++++++-- 5 files changed, 102 insertions(+), 15 deletions(-) diff --git a/src/transformers/models/granitemoe/modeling_granitemoe.py b/src/transformers/models/granitemoe/modeling_granitemoe.py index 5ea293e5bfff..29ae0b2f7f86 100644 --- a/src/transformers/models/granitemoe/modeling_granitemoe.py +++ b/src/transformers/models/granitemoe/modeling_granitemoe.py @@ -947,6 +947,7 @@ def forward( output_router_logits=output_router_logits, return_dict=return_dict, cache_position=cache_position, + **kwargs, ) # Only compute necessary logits diff --git a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py index ab31709f3d5f..408cb861a142 100644 --- a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +++ b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py @@ -19,7 +19,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Optional, TypedDict, Union import torch import torch.nn.functional as F @@ -34,6 +34,7 @@ from ...modeling_outputs import BaseModelOutputWithPast, MoeCausalLMOutputWithPast, MoeModelOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack from ...utils import auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging from ...utils.import_utils import is_causal_conv1d_available, is_mamba_2_ssm_available from .configuration_granitemoehybrid import GraniteMoeHybridConfig @@ -918,6 +919,31 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return hidden_states +class GraniteFlashAttentionKwargs(TypedDict, total=False): + """ + Keyword arguments for advanced Flash Attention, causal-conv1d, and mamba_ssm kernel usage. + Use cases include padding-free training and fewer `torch.compile` graph breaks. + + Attributes: + cu_seq_lens_q (`torch.LongTensor`) + Gets cumulative sequence length for query state. + cu_seq_lens_k (`torch.LongTensor`) + Gets cumulative sequence length for key state. + max_length_q (`int`): + Maximum sequence length for query state. + max_length_k (`int`): + Maximum sequence length for key state. + seq_idx (`torch.IntTensor): + Index of each packed sequence. + """ + + cu_seq_lens_q: torch.LongTensor + cu_seq_lens_k: torch.LongTensor + max_length_q: int + max_length_k: int + seq_idx: torch.IntTensor + + class GraniteMoeHybridRMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): """ @@ -1125,7 +1151,7 @@ def forward( cache_position: Optional[torch.LongTensor] = None, output_router_logits: Optional[bool] = False, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, - **kwargs, + **kwargs: Unpack[GraniteFlashAttentionKwargs], ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: @@ -1149,8 +1175,8 @@ def forward( Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, with `head_dim` being the embedding dimension of each attention head. kwargs (`dict`, *optional*): - Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code - into the model + Arbitrary kwargs.Can be used to provide `GraniteFlashAttentionKwargs` for + padding-free training and/or improve torch.compile performance. """ residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -1161,6 +1187,7 @@ def forward( cache_position=cache_position, cache_params=past_key_value, attention_mask=attention_mask, + **kwargs, ) # No attention weights for state space layers self_attn_weights = None @@ -1303,6 +1330,7 @@ def forward( output_router_logits: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[GraniteFlashAttentionKwargs], ) -> Union[tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -1374,6 +1402,7 @@ def forward( cache_position=cache_position, output_router_logits=output_router_logits, position_embeddings=position_embeddings, + **kwargs, ) hidden_states = layer_outputs[0] @@ -1706,6 +1735,7 @@ def forward( output_router_logits=output_router_logits, return_dict=return_dict, cache_position=cache_position, + **kwargs, ) # Only compute necessary logits diff --git a/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py b/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py index 4274af1b1920..242c95b90767 100644 --- a/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py +++ b/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py @@ -20,10 +20,12 @@ from ...cache_utils import Cache from ...modeling_outputs import BaseModelOutputWithPast, MoeModelOutputWithPast +from ...processing_utils import Unpack from ...utils import auto_docstring, can_return_tuple, logging from ..bamba.configuration_bamba import BambaConfig from ..bamba.modeling_bamba import BambaMixer, BambaRMSNormGated, HybridMambaAttentionDynamicCache from ..granitemoeshared.modeling_granitemoeshared import ( + GraniteFlashAttentionKwargs, GraniteMoeSharedAttention, GraniteMoeSharedDecoderLayer, GraniteMoeSharedForCausalLM, @@ -84,7 +86,7 @@ def forward( cache_position: Optional[torch.LongTensor] = None, output_router_logits: Optional[bool] = False, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, - **kwargs, + **kwargs: Unpack[GraniteFlashAttentionKwargs], ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: @@ -108,8 +110,8 @@ def forward( Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, with `head_dim` being the embedding dimension of each attention head. kwargs (`dict`, *optional*): - Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code - into the model + Arbitrary kwargs.Can be used to provide `GraniteFlashAttentionKwargs` for + padding-free training and/or improve torch.compile performance. """ residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -120,6 +122,7 @@ def forward( cache_position=cache_position, cache_params=past_key_value, attention_mask=attention_mask, + **kwargs, ) # No attention weights for state space layers self_attn_weights = None @@ -198,6 +201,7 @@ def forward( output_router_logits: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[GraniteFlashAttentionKwargs], ) -> Union[tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -269,6 +273,7 @@ def forward( cache_position=cache_position, output_router_logits=output_router_logits, position_embeddings=position_embeddings, + **kwargs, ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py b/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py index b10369e767f3..d0d1923dbb4e 100644 --- a/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py +++ b/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py @@ -19,7 +19,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Optional, Union +from typing import Callable, Optional, TypedDict, Union import torch import torch.nn.functional as F @@ -33,6 +33,7 @@ from ...modeling_outputs import BaseModelOutputWithPast, MoeCausalLMOutputWithPast, MoeModelOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack from ...utils import auto_docstring, is_torch_flex_attn_available, logging from .configuration_granitemoeshared import GraniteMoeSharedConfig @@ -46,6 +47,31 @@ logger = logging.get_logger(__name__) +class GraniteFlashAttentionKwargs(TypedDict, total=False): + """ + Keyword arguments for advanced Flash Attention, causal-conv1d, and mamba_ssm kernel usage. + Use cases include padding-free training and fewer `torch.compile` graph breaks. + + Attributes: + cu_seq_lens_q (`torch.LongTensor`) + Gets cumulative sequence length for query state. + cu_seq_lens_k (`torch.LongTensor`) + Gets cumulative sequence length for key state. + max_length_q (`int`): + Maximum sequence length for query state. + max_length_k (`int`): + Maximum sequence length for key state. + seq_idx (`torch.IntTensor): + Index of each packed sequence. + """ + + cu_seq_lens_q: torch.LongTensor + cu_seq_lens_k: torch.LongTensor + max_length_q: int + max_length_k: int + seq_idx: torch.IntTensor + + class GraniteMoeSharedMLP(nn.Module): """ MLP layer for shared experts @@ -431,7 +457,7 @@ def forward( cache_position: Optional[torch.LongTensor] = None, output_router_logits: Optional[bool] = False, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, - **kwargs, + **kwargs: Unpack[GraniteFlashAttentionKwargs], ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: @@ -455,8 +481,8 @@ def forward( Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, with `head_dim` being the embedding dimension of each attention head. kwargs (`dict`, *optional*): - Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code - into the model + Arbitrary kwargs. Can be used to provide `GraniteFlashAttentionKwargs` for + padding-free training and/or improve torch.compile performance. """ residual = hidden_states @@ -979,6 +1005,7 @@ def forward( output_router_logits=output_router_logits, return_dict=return_dict, cache_position=cache_position, + **kwargs, ) # Only compute necessary logits diff --git a/src/transformers/models/granitemoeshared/modular_granitemoeshared.py b/src/transformers/models/granitemoeshared/modular_granitemoeshared.py index 29342cb6251d..6bd71dc116ce 100644 --- a/src/transformers/models/granitemoeshared/modular_granitemoeshared.py +++ b/src/transformers/models/granitemoeshared/modular_granitemoeshared.py @@ -13,7 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional +from typing import Optional, TypedDict import torch from torch import nn @@ -21,6 +21,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache from ...utils import logging +from ...processing_utils import Unpack from ..granitemoe.modeling_granitemoe import ( GraniteMoeDecoderLayer, GraniteMoeForCausalLM, @@ -32,6 +33,29 @@ logger = logging.get_logger(__name__) +class GraniteFlashAttentionKwargs(TypedDict, total=False): + """ + Keyword arguments for advanced Flash Attention, causal-conv1d, and mamba_ssm kernel usage. + Use cases include padding-free training and fewer `torch.compile` graph breaks. + + Attributes: + cu_seq_lens_q (`torch.LongTensor`) + Gets cumulative sequence length for query state. + cu_seq_lens_k (`torch.LongTensor`) + Gets cumulative sequence length for key state. + max_length_q (`int`): + Maximum sequence length for query state. + max_length_k (`int`): + Maximum sequence length for key state. + seq_idx (`torch.IntTensor): + Index of each packed sequence. + """ + + cu_seq_lens_q: torch.LongTensor + cu_seq_lens_k: torch.LongTensor + max_length_q: int + max_length_k: int + seq_idx: torch.IntTensor class GraniteMoeSharedMLP(nn.Module): """ @@ -75,7 +99,7 @@ def forward( cache_position: Optional[torch.LongTensor] = None, output_router_logits: Optional[bool] = False, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, - **kwargs, + **kwargs: Unpack[GraniteFlashAttentionKwargs], ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: @@ -99,8 +123,8 @@ def forward( Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, with `head_dim` being the embedding dimension of each attention head. kwargs (`dict`, *optional*): - Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code - into the model + Arbitrary kwargs. Can be used to provide `GraniteFlashAttentionKwargs` for + padding-free training and/or improve torch.compile performance. """ residual = hidden_states From 272315a5b8dbb68021681e1dcdb5ab7cda1b6158 Mon Sep 17 00:00:00 2001 From: Garrett Goon Date: Fri, 25 Jul 2025 15:45:45 +0000 Subject: [PATCH 2/7] fmt --- .../models/granitemoeshared/modular_granitemoeshared.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/granitemoeshared/modular_granitemoeshared.py b/src/transformers/models/granitemoeshared/modular_granitemoeshared.py index 6bd71dc116ce..630e5aa18439 100644 --- a/src/transformers/models/granitemoeshared/modular_granitemoeshared.py +++ b/src/transformers/models/granitemoeshared/modular_granitemoeshared.py @@ -20,8 +20,8 @@ from ...activations import ACT2FN from ...cache_utils import Cache -from ...utils import logging from ...processing_utils import Unpack +from ...utils import logging from ..granitemoe.modeling_granitemoe import ( GraniteMoeDecoderLayer, GraniteMoeForCausalLM, @@ -33,6 +33,7 @@ logger = logging.get_logger(__name__) + class GraniteFlashAttentionKwargs(TypedDict, total=False): """ Keyword arguments for advanced Flash Attention, causal-conv1d, and mamba_ssm kernel usage. @@ -57,6 +58,7 @@ class GraniteFlashAttentionKwargs(TypedDict, total=False): max_length_k: int seq_idx: torch.IntTensor + class GraniteMoeSharedMLP(nn.Module): """ MLP layer for shared experts From dd8f3d0fcdf12ab34218272b92f53024af05503b Mon Sep 17 00:00:00 2001 From: Garrett Goon Date: Fri, 25 Jul 2025 15:45:50 +0000 Subject: [PATCH 3/7] updates padding free tests --- tests/models/bamba/test_modeling_bamba.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/models/bamba/test_modeling_bamba.py b/tests/models/bamba/test_modeling_bamba.py index e1d8128a2c22..cf660b7fa09d 100644 --- a/tests/models/bamba/test_modeling_bamba.py +++ b/tests/models/bamba/test_modeling_bamba.py @@ -551,6 +551,15 @@ def test_flash_attention_2_padding_matches_padding_free_with_position_ids_seq_id inputs_dict["attention_mask"] = inputs_dict["attention_mask"].flip(1) dummy_attention_mask = inputs_dict["attention_mask"] inputs_dict["input_ids"][~dummy_attention_mask.bool()] = config.get_text_config().pad_token_id + # Ensure inputs_dict also has labels in it, as their presence/absence can induce + # dtype conversions. This also lets us compare losses. + labels = inputs_dict["input_ids"].clone() + # Mask padding tokens + labels[~dummy_attention_mask.bool()] = -100 + # Also need to mask the first non-trivial token to match the padding-free batch. + first_nonneg_idx = (labels >= 0).int().argmax(dim=1) + labels[torch.arange(labels.size(0), device=labels.device), first_nonneg_idx] = -100 + inputs_dict["labels"] = labels model = ( model_class.from_pretrained( @@ -586,6 +595,10 @@ def test_flash_attention_2_padding_matches_padding_free_with_position_ids_seq_id tol = torch.finfo(torch.float16).eps torch.testing.assert_close(logits_padded, logits_padfree, rtol=tol, atol=tol) + loss_padded = res_padded.loss + loss_padfree = res_padfree.loss + torch.testing.assert_close(loss_padded, loss_padfree) + @slow @require_torch From 2370c5af7470bf1d185823b03029569f15495546 Mon Sep 17 00:00:00 2001 From: Garrett Goon Date: Fri, 25 Jul 2025 16:30:08 +0000 Subject: [PATCH 4/7] docs --- docs/source/en/model_doc/granitemoehybrid.md | 28 +++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/docs/source/en/model_doc/granitemoehybrid.md b/docs/source/en/model_doc/granitemoehybrid.md index 49fe0b04ac21..92d6e3b70ac4 100644 --- a/docs/source/en/model_doc/granitemoehybrid.md +++ b/docs/source/en/model_doc/granitemoehybrid.md @@ -48,6 +48,32 @@ for i in output: This HF implementation is contributed by [Sukriti Sharma](https://huggingface.co/SukritiSharma) and [Alexander Brooks](https://huggingface.co/abrooks9944). +## Notes + +- `GraniteMoeHybridForCausalLM` supports padding-free training which concatenates distinct training examples while still processing inputs as separate batches. It can significantly accelerate inference by [~2x](https://github.com/huggingface/transformers/pull/35861#issue-2807873129) (depending on model and data distribution) and reduce memory-usage if there are examples of varying lengths by avoiding unnecessary compute and memory overhead from padding tokens. + + Padding-free training requires the `flash-attn`, `mamba-ssm`, and `causal-conv1d` packages and the following arguments must be passed to the model in addition to `input_ids` and `labels`. + + - `position_ids: torch.LongTensor`: the position index of each token in each sequence. + - `seq_idx: torch.IntTensor`: the index of each sequence in the batch. + - Each of the [`FlashAttentionKwargs`] + - `cu_seq_lens_q: torch.LongTensor`: the cumulative sequence lengths of all queries. + - `cu_seq_lens_k: torch.LongTensor`: the cumulative sequence lengths of all keys. + - `max_length_q: int`: the longest query length in the batch. + - `max_length_k: int`: the longest key length in the batch. + + The `attention_mask` inputs should not be provided. The [`DataCollatorWithFlattening`] programmatically generates the set of additional arguments above using `return_seq_idx=True` and `return_flash_attn_kwargs=True`. See the [Improving Hugging Face Training Efficiency Through Packing with Flash Attention](https://huggingface.co/blog/packing-with-FA2) blog post for additional information. + + ```python + from transformers import DataCollatorWithFlattening + + # Example of using padding-free training + data_collator = DataCollatorWithFlattening( + tokenizer=tokenizer, + return_seq_idx=True, + return_flash_attn_kwargs=True + ) + ``` ## GraniteMoeHybridConfig @@ -61,4 +87,4 @@ This HF implementation is contributed by [Sukriti Sharma](https://huggingface.co ## GraniteMoeHybridForCausalLM [[autodoc]] GraniteMoeHybridForCausalLM - - forward \ No newline at end of file + - forward From 472d37ff9dcec80ffb6e5105701f70d26ecfe2ea Mon Sep 17 00:00:00 2001 From: Garrett Goon Date: Fri, 25 Jul 2025 17:48:13 +0000 Subject: [PATCH 5/7] add missing kwargs modeling_granitemoe.py --- src/transformers/models/granitemoe/modeling_granitemoe.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/models/granitemoe/modeling_granitemoe.py b/src/transformers/models/granitemoe/modeling_granitemoe.py index 29ae0b2f7f86..fffe51d794bd 100644 --- a/src/transformers/models/granitemoe/modeling_granitemoe.py +++ b/src/transformers/models/granitemoe/modeling_granitemoe.py @@ -641,6 +641,7 @@ def forward( output_router_logits: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + **kwargs, ) -> Union[tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( From 7eb0852a1ed108db9186bc8b340a2abf2cf8edb2 Mon Sep 17 00:00:00 2001 From: Garrett Goon Date: Fri, 25 Jul 2025 17:56:40 +0000 Subject: [PATCH 6/7] run modular util --- src/transformers/models/deepseek_vl/modeling_deepseek_vl.py | 4 ++-- .../models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py | 4 ++-- src/transformers/models/evolla/modeling_evolla.py | 2 +- .../models/granitemoeshared/modeling_granitemoeshared.py | 1 + 4 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/deepseek_vl/modeling_deepseek_vl.py b/src/transformers/models/deepseek_vl/modeling_deepseek_vl.py index ce85d739bc68..60ca3394fa2a 100644 --- a/src/transformers/models/deepseek_vl/modeling_deepseek_vl.py +++ b/src/transformers/models/deepseek_vl/modeling_deepseek_vl.py @@ -139,7 +139,7 @@ class DeepseekVLPreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True - _supports_static_cache = True + _can_compile_fullgraph = True _supports_param_buffer_assignment = False def _init_weights(self, module): @@ -236,7 +236,7 @@ def forward( class DeepseekVLForConditionalGeneration(DeepseekVLPreTrainedModel, GenerationMixin): _tied_weights_keys = ["model.language_model.embed_tokens.weight", "lm_head.weight"] - _supports_static_cache = True + _can_compile_fullgraph = True def __init__(self, config: DeepseekVLConfig): super().__init__(config) diff --git a/src/transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py b/src/transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py index 67b67371f952..1910c5659c5a 100644 --- a/src/transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +++ b/src/transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py @@ -222,7 +222,7 @@ class DeepseekVLHybridPreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True - _supports_static_cache = True + _can_compile_fullgraph = True _supports_param_buffer_assignment = False def _init_weights(self, module): @@ -376,7 +376,7 @@ def get_high_res_image_features(self, pixel_values): class DeepseekVLHybridForConditionalGeneration(DeepseekVLHybridPreTrainedModel, GenerationMixin): _tied_weights_keys = ["model.language_model.embed_tokens.weight", "lm_head.weight"] - _supports_static_cache = True + _can_compile_fullgraph = True def __init__(self, config: DeepseekVLHybridConfig): super().__init__(config) diff --git a/src/transformers/models/evolla/modeling_evolla.py b/src/transformers/models/evolla/modeling_evolla.py index f51f27d6d342..8f91f6005628 100644 --- a/src/transformers/models/evolla/modeling_evolla.py +++ b/src/transformers/models/evolla/modeling_evolla.py @@ -1516,7 +1516,7 @@ class EvollaPreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = True - _supports_static_cache = True + _can_compile_fullgraph = True _supports_attention_backend = False _can_record_outputs = { "hidden_states": EvollaDecoderLayer, diff --git a/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py b/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py index d0d1923dbb4e..21e9d13f7195 100644 --- a/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py +++ b/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py @@ -619,6 +619,7 @@ def forward( output_router_logits: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + **kwargs, ) -> Union[tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( From dbccf7927ea8455dfc630b1965df029130bb78ba Mon Sep 17 00:00:00 2001 From: Garrett Goon Date: Fri, 25 Jul 2025 18:06:27 +0000 Subject: [PATCH 7/7] rm unrelated changes from modular util --- src/transformers/models/deepseek_vl/modeling_deepseek_vl.py | 4 ++-- .../models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py | 4 ++-- src/transformers/models/evolla/modeling_evolla.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/deepseek_vl/modeling_deepseek_vl.py b/src/transformers/models/deepseek_vl/modeling_deepseek_vl.py index 60ca3394fa2a..ce85d739bc68 100644 --- a/src/transformers/models/deepseek_vl/modeling_deepseek_vl.py +++ b/src/transformers/models/deepseek_vl/modeling_deepseek_vl.py @@ -139,7 +139,7 @@ class DeepseekVLPreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True - _can_compile_fullgraph = True + _supports_static_cache = True _supports_param_buffer_assignment = False def _init_weights(self, module): @@ -236,7 +236,7 @@ def forward( class DeepseekVLForConditionalGeneration(DeepseekVLPreTrainedModel, GenerationMixin): _tied_weights_keys = ["model.language_model.embed_tokens.weight", "lm_head.weight"] - _can_compile_fullgraph = True + _supports_static_cache = True def __init__(self, config: DeepseekVLConfig): super().__init__(config) diff --git a/src/transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py b/src/transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py index 1910c5659c5a..67b67371f952 100644 --- a/src/transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +++ b/src/transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py @@ -222,7 +222,7 @@ class DeepseekVLHybridPreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True - _can_compile_fullgraph = True + _supports_static_cache = True _supports_param_buffer_assignment = False def _init_weights(self, module): @@ -376,7 +376,7 @@ def get_high_res_image_features(self, pixel_values): class DeepseekVLHybridForConditionalGeneration(DeepseekVLHybridPreTrainedModel, GenerationMixin): _tied_weights_keys = ["model.language_model.embed_tokens.weight", "lm_head.weight"] - _can_compile_fullgraph = True + _supports_static_cache = True def __init__(self, config: DeepseekVLHybridConfig): super().__init__(config) diff --git a/src/transformers/models/evolla/modeling_evolla.py b/src/transformers/models/evolla/modeling_evolla.py index 8f91f6005628..f51f27d6d342 100644 --- a/src/transformers/models/evolla/modeling_evolla.py +++ b/src/transformers/models/evolla/modeling_evolla.py @@ -1516,7 +1516,7 @@ class EvollaPreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = True - _can_compile_fullgraph = True + _supports_static_cache = True _supports_attention_backend = False _can_record_outputs = { "hidden_states": EvollaDecoderLayer,