From db8e4ffc2ccd984c44befb7e936d778a2339ae6b Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Mon, 23 Mar 2026 17:22:20 +0100 Subject: [PATCH 01/56] add Cache and test on Mamba --- src/transformers/cache_utils.py | 196 +++++++++++++++++- .../models/jamba/configuration_jamba.py | 6 + .../models/mamba/modeling_mamba.py | 168 ++------------- 3 files changed, 216 insertions(+), 154 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 7dede60a7b27..6f171ac0c5c9 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -669,6 +669,101 @@ def _dequantize(self, qtensor): return tensor +class MambaCacheLayerMixin(ABC): + """Base, abstract class for a mamba single layer's cache.""" + + is_compileable = False + + def __init__(self): + self.conv_states: torch.Tensor | None = None + self.ssm_states: torch.Tensor | None = None + self.is_initialized = False + self.has_previous_state = False + + def __repr__(self): + return f"{self.__class__.__name__}" + + @abstractmethod + def lazy_initialization(self, conv_states: torch.Tensor) -> None: ... + + @abstractmethod + def update_conv_state(self, conv_states: torch.Tensor) -> torch.Tensor: ... + + @abstractmethod + def update_ssm_state(self, ssm_states: torch.Tensor) -> torch.Tensor: ... + + def offload(self): + """Offload this layer's data to CPU device.""" + if self.is_initialized: + self.conv_states = self.conv_states.to("cpu", non_blocking=True) + self.ssm_states = self.ssm_states.to("cpu", non_blocking=True) + + def prefetch(self): + """In case of layer offloading, this allows to move the data back to the layer's device ahead of time.""" + if self.is_initialized and self.conv_states.device != self.device: + self.conv_states = self.conv_states.to(self.device, non_blocking=True) + self.ssm_states = self.ssm_states.to(self.device, non_blocking=True) + + def reset(self) -> None: + """Resets the cache values while preserving the objects""" + if self.is_initialized: + self.conv_states.zero_() + self.ssm_states.zero_() + self.has_previous_state = False + + +class MambaLayer(MambaCacheLayerMixin): + def lazy_initialization(self, conv_states: torch.Tensor, ssm_states: torch.Tensor) -> None: + self.dtype, self.device = conv_states.dtype, conv_states.device + self.conv_states = torch.tensor([], dtype=self.dtype, device=self.device) + self.is_initialized = True + + def update_conv_state(self, conv_states: torch.Tensor, **kwargs) -> torch.Tensor: + """ + Update the mamba cache in-place, and return the necessary conv states. + + Args: + conv_states (`torch.Tensor`): The new conv states to cache. + + Returns: + `torch.Tensor`: The updated conv states. + """ + # Lazy initialization + if not self.is_initialized: + self.lazy_initialization(conv_states, conv_states) + + # Technically, those update are not logically correct if the prefill is smaller than `conv_kernel_size`, + # as it will `roll` anyway in the first decoding step even though it should `roll` ONLY if the cache is already full. 
+ # But since `conv_kernel_size=4` in practice, it's almost impossible to have a smaller prefill so it's mostly fine for now + if not self.has_previous_state: + self.conv_states = conv_states + self.has_previous_state = True + else: + new_conv_states = self.conv_states.roll(shifts=-1, dims=-1) + new_conv_states[:, :, -1:] = conv_states + self.conv_states = new_conv_states + + return self.conv_states + + def update_ssm_state(self, ssm_states: torch.Tensor, **kwargs) -> torch.Tensor: + """ + Update the mamba cache in-place, and return the necessary ssm states. + + Args: + smm_states (`torch.Tensor`): The new ssm states to cache. + + Returns: + `torch.Tensor`: The updated ssm states. + """ + # Lazy initialization + if not self.is_initialized: + self.lazy_initialization(ssm_states, ssm_states) + + self.ssm_states = ssm_states + + return self.ssm_states + + class Cache: """ A `Cache` is mostly a list of `CacheLayerMixin` objects, one per model layer. It serves as a container for @@ -676,9 +771,9 @@ class Cache: Args: layers (`Optional`, *optional*): - A list of pre-created `CacheLayerMixin`. If omitted (`None`), then `layer_class_to_replicate` will - be used. - layer_class_to_replicate (`type[CacheLayerMixin]`, *optional*): + A list of pre-created `CacheLayerMixin` or `MambaCacheLayerMixin`. If omitted (`None`), then `layer_class_to_replicate` + will be used. + layer_class_to_replicate (`type[CacheLayerMixin | MambaCacheLayerMixin]`, *optional*): Only used if `layers` is omitted (`None`), in which case it will be used as the base class for each layer, and the layers will be added lazily as soon as `update` is called with a `layer_idx` greater than the current list of layers. @@ -691,8 +786,8 @@ class Cache: def __init__( self, - layers: list[CacheLayerMixin] | None = None, - layer_class_to_replicate: type[CacheLayerMixin] | None = None, + layers: list[CacheLayerMixin | MambaCacheLayerMixin] | None = None, + layer_class_to_replicate: type[CacheLayerMixin | MambaCacheLayerMixin] | None = None, offloading: bool = False, offload_only_non_sliding: bool = True, ): @@ -779,6 +874,46 @@ def update( return keys, values + def update_conv_state(self, conv_states: torch.Tensor, layer_idx: int, **kwargs) -> torch.Tensor: + """ + Updates the cache with the new `conv_states` for the layer `layer_idx`. + + Parameters: + conv_states (`torch.Tensor`): + The new conv states to cache. + layer_idx (`int`): + The index of the layer to cache the states for. + + Return: + `torch.Tensor`: The updated conv states. + """ + # NOTE: if we slightly break `update` arg order, we could combine this with it, and allow offloading support + # out of the box + if not isinstance(self.layers[layer_idx], MambaCacheLayerMixin): + raise ValueError("Cannot call `update_conv_state` on a non-Mamba layer!") + conv_states = self.layers[layer_idx].update_conv_state(conv_states, **kwargs) + return conv_states + + def update_ssm_state(self, ssm_states: torch.Tensor, layer_idx: int, **kwargs) -> torch.Tensor: + """ + Updates the cache with the new `ssm_states` for the layer `layer_idx`. + + Parameters: + smm_states (`torch.Tensor`): + The new ssm states to cache. + layer_idx (`int`): + The index of the layer to cache the states for. + + Return: + `torch.Tensor`: The updated ssm states. 
+ """ + # NOTE: if we slightly break `update` arg order, we could combine this with it, and allow offloading support + # out of the box + if not isinstance(self.layers[layer_idx], MambaCacheLayerMixin): + raise ValueError("Cannot call `update_conv_state` on a non-Mamba layer!") + ssm_states = self.layers[layer_idx].update_ssm_state(ssm_states, **kwargs) + return ssm_states + def early_initialization( self, batch_size: int, num_heads: int, head_dim: int, dtype: torch.dtype, device: torch.device ): @@ -798,6 +933,24 @@ def get_seq_length(self, layer_idx: int = 0) -> int: """Returns the sequence length of the cache for the given layer.""" if layer_idx >= len(self.layers): return 0 + + # For Hybrid attention-mamba caches, `get_seq_length` needs to use attention layer idx when called with default layer_idx + if isinstance(self.layers[layer_idx], MambaCacheLayerMixin): + # If this is called with non-default arg, raise + if layer_idx != 0: + raise ValueError( + f"You called `get_seq_length` on layer index {layer_idx}, but this layer is a Mamba layer, which " + "does not track sequence length." + ) + try: + # Use the first attention layer + layer_idx = next(idx for idx in range(len(self)) if isinstance(self.layers[idx], CacheLayerMixin)) + except StopIteration: + raise ValueError( + "`get_seq_length` can only be called on Attention layers, and the current Cache seem to only contain " + "Mamba layers." + ) + return self.layers[layer_idx].get_seq_length() def get_mask_sizes(self, query_length: int, layer_idx: int) -> tuple[int, int]: @@ -810,6 +963,24 @@ def get_mask_sizes(self, query_length: int, layer_idx: int) -> tuple[int, int]: # simply the query_length if layer_idx >= len(self.layers): return query_length, 0 + + # For Hybrid attention-mamba caches, `get_seq_length` needs to use attention layer idx when called with default layer_idx + if isinstance(self.layers[layer_idx], MambaCacheLayerMixin): + # If this is called with non-default arg, raise + if layer_idx != 0: + raise ValueError( + f"You called `get_mask_sizes` on layer index {layer_idx}, but this layer is a Mamba layer, which " + "does not track sequence length." + ) + try: + # Use the first attention layer + layer_idx = next(idx for idx in range(len(self)) if isinstance(self.layers[idx], CacheLayerMixin)) + except StopIteration: + raise ValueError( + "`get_mask_sizes` can only be called on Attention layers, and the current Cache seem to only contain " + "Mamba layers." + ) + return self.layers[layer_idx].get_mask_sizes(query_length) def get_max_cache_shape(self, layer_idx: int = 0) -> int: @@ -943,12 +1114,17 @@ def __init__( sliding_window = getattr(decoder_config, "sliding_window", None) or getattr( decoder_config, "attention_chunk_size", None ) + conv_kernel = getattr(decoder_config, "conv_kernel", None) layer_types = getattr(decoder_config, "layer_types", None) if layer_types is None: - layer_types = [ - "sliding_attention" if sliding_window is not None else "full_attention" - for _ in range(decoder_config.num_hidden_layers) - ] + layer_types = [] + for _ in range(decoder_config.num_hidden_layers): + if sliding_window is not None: + layer_types.append("sliding_attention") + elif conv_kernel is not None: + layer_types.append("mamba") + else: + layer_types.append("full_attention") # Some models have shared layers thus no cache is needed for them (e.g. 
Gemma3n) if hasattr(decoder_config, "num_kv_shared_layers"): layer_types = layer_types[: -decoder_config.num_kv_shared_layers] @@ -958,6 +1134,8 @@ def __init__( # states they should return - only the mask changes to make them different at the end! if layer_type in ("sliding_attention", "chunked_attention"): layers.append(DynamicSlidingWindowLayer(sliding_window=sliding_window)) + elif layer_type in ("mamba", "conv"): + layers.append(MambaLayer()) else: layers.append(DynamicLayer()) diff --git a/src/transformers/models/jamba/configuration_jamba.py b/src/transformers/models/jamba/configuration_jamba.py index 8fdb63d4af5b..4c940aa6a3fb 100644 --- a/src/transformers/models/jamba/configuration_jamba.py +++ b/src/transformers/models/jamba/configuration_jamba.py @@ -93,6 +93,12 @@ def layers_block_type(self): for i in range(self.num_hidden_layers) ] + @property + def layer_types(self): + # Follow the `layer_types` conventions + layer_types = self.layers_block_type + return ["full_attention" if x == "attention" else x for x in layer_types] + @property def layers_num_experts(self): return [ diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index 69002f50ab78..cfb5a19be930 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -22,7 +22,7 @@ from ... import initialization as init from ...activations import ACT2FN -from ...configuration_utils import PreTrainedConfig +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...integrations import lazy_load_kernel from ...modeling_layers import GradientCheckpointingLayer @@ -55,117 +55,6 @@ pscan = None -class MambaCache: - """ - Cache for mamba model which does not have attention mechanism and key value states. - - Arguments: - config (`PreTrainedConfig): - The configuration file defining the shape-related attributes required to initialize the static cache. - max_batch_size (`int`): - The maximum batch size with which the model will be used. Note that a new instance must be instantiated if - a smaller batch size is used. - dtype (`torch.dtype`, *optional*, defaults to `torch.float16`): - The default `dtype` to use when initializing the layer. - device (`torch.device` or `str`, *optional*): - The device on which the cache should be initialized. Should be the same as the layer. 
- - Example: - - ```python - >>> import torch - >>> from transformers import AutoTokenizer, MambaForCausalLM, MambaCache - - >>> model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf") - >>> tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf") - - >>> inputs = tokenizer(text="My name is Mamba", return_tensors="pt") - - >>> # Prepare a cache class and pass it to model's forward - >>> cache_params = MambaCache(config=model.config, max_batch_size=1, device=model.device, dtype=model.dtype) - >>> outputs = model(**inputs, cache_params=cache_params, use_cache=True) - >>> outputs.cache_params - ``` - """ - - is_compileable = True - - # TODO (joao): add layer_device_map arg and update code in `generate` accordingly - def __init__( - self, - config: PreTrainedConfig, - max_batch_size: int, - dtype: torch.dtype = torch.float16, - device: torch.device | str | None = None, - ): - self.max_batch_size = max_batch_size - self._dtype = dtype - self.intermediate_size = config.intermediate_size - self.ssm_state_size = config.state_size - self.conv_kernel_size = config.conv_kernel - self.has_previous_state = False - - self.conv_states: list[torch.Tensor] = [] - self.ssm_states: list[torch.Tensor] = [] - device = torch.device(device) if device is not None else None - for _ in range(config.num_hidden_layers): - conv_state: torch.Tensor = torch.zeros( - self.max_batch_size, - self.intermediate_size, - self.conv_kernel_size, - device=device, - dtype=self._dtype, - ) - ssm_state: torch.Tensor = torch.zeros( - self.max_batch_size, - self.intermediate_size, - self.ssm_state_size, - device=device, - dtype=self._dtype, - ) - - torch._dynamo.mark_static_address(conv_state) - torch._dynamo.mark_static_address(ssm_state) - self.conv_states.append(conv_state) - self.ssm_states.append(ssm_state) - - def update_conv_state( - self, layer_idx: int, new_conv_state: torch.Tensor, cache_init: bool = False - ) -> torch.Tensor: - # This `if` blocks is only reached in multigpu and if `layer_device_map` is not passed. It is used - # when the cache is initialized in the forward pass (e.g. Mamba) - if self.conv_states[layer_idx].device != new_conv_state.device: - self.conv_states[layer_idx] = self.conv_states[layer_idx].to(new_conv_state.device) - - # Technically, those update are not logically correct if the prefill is smaller than `conv_kernel_size`, - # as it will `roll` anyway in the first decoding step even though it should `roll` ONLY if the cache is already full. 
- # But since `conv_kernel_size=4` in practice, it's almost impossible to have a smaller prefill so it's mostly fine for now - if cache_init: - self.conv_states[layer_idx].copy_(new_conv_state) - else: - conv_state = self.conv_states[layer_idx].roll(shifts=-1, dims=-1) - conv_state[:, :, -1:] = new_conv_state - self.conv_states[layer_idx].copy_(conv_state) - - # If last layer is updated, set the flag - if layer_idx == len(self.conv_states) - 1: - self.has_previous_state = True - - return self.conv_states[layer_idx] - - def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor): - self.ssm_states[layer_idx].zero_() - self.ssm_states[layer_idx] += new_ssm_state.to(self.ssm_states[layer_idx].device) - return self.ssm_states[layer_idx] - - def reset(self): - self.has_previous_state = False - for layer_idx in range(len(self.conv_states)): - # In-place ops prevent breaking the static address - self.conv_states[layer_idx].zero_() - self.ssm_states[layer_idx].zero_() - - class MambaMixer(nn.Module): """ Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`. @@ -278,7 +167,7 @@ def warn_slow_implementation(self): def cuda_kernels_forward( self, hidden_states: torch.Tensor, - cache_params: MambaCache | None = None, + cache_params: Cache | None = None, attention_mask: torch.LongTensor | None = None, ): # 1. Gated MLP's linear projection @@ -325,7 +214,7 @@ def cuda_kernels_forward( conv_states = nn.functional.pad( hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0) ) - cache_params.update_conv_state(self.layer_idx, conv_states, cache_init=True) + cache_params.update_conv_state(conv_states, self.layer_idx) hidden_states = causal_conv1d_fn( hidden_states, conv_weights, self.conv1d.bias, activation=self.activation ) @@ -371,14 +260,14 @@ def cuda_kernels_forward( return_last_state=True, ) if ssm_state is not None and cache_params is not None: - cache_params.update_ssm_state(self.layer_idx, ssm_state) + cache_params.update_ssm_state(ssm_state, self.layer_idx) # 4. Final linear projection contextualized_states = self.out_proj(scan_outputs.transpose(1, 2)) return contextualized_states # fmt: off - def slow_forward(self, input_states, cache_params: MambaCache | None=None, attention_mask: torch.LongTensor | None = None): + def slow_forward(self, input_states, cache_params: Cache | None=None, attention_mask: torch.LongTensor | None = None): batch_size, seq_len, _ = input_states.shape dtype = input_states.dtype # 1. Gated MLP's linear projection @@ -390,7 +279,7 @@ def slow_forward(self, input_states, cache_params: MambaCache | None=None, atten # 2. 
Convolution sequence transformation
         if cache_params is not None:
-            ssm_state = cache_params.ssm_states[self.layer_idx].clone()
+            ssm_state = cache_params[self.layer_idx].ssm_states.clone()
             ssm_state = ssm_state.to(hidden_states.device)
             if not cache_params.has_previous_state:
                 conv_state = nn.functional.pad(
                     hidden_states,
                     (self.conv_kernel_size - hidden_states.shape[-1], 0)
                 )
-                cache_params.update_conv_state(self.layer_idx, conv_state, cache_init=True)
+                cache_params.update_conv_state(conv_state, self.layer_idx)
                 hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len])     # [batch, intermediate_size, seq_len]
             else:
-                conv_state = cache_params.update_conv_state(self.layer_idx, hidden_states, cache_init=False)
+                conv_state = cache_params.update_conv_state(hidden_states, self.layer_idx)
                 conv_state = conv_state.to(self.conv1d.weight.device)
                 hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1)
                 if self.use_conv_bias:
@@ -465,7 +354,7 @@ def combine_fn(left, right):
         scan_output = (scan_output * self.act(gate))

         if cache_params is not None:
-            cache_params.ssm_states[self.layer_idx].copy_(ssm_state)
+            cache_params.update_ssm_state(ssm_state, self.layer_idx)

         # 4. Final linear projection
         contextualized_states = self.out_proj(scan_output.transpose(1, 2))  # [batch, seq_len, hidden_size]
@@ -475,7 +364,7 @@
     def forward(
         self,
         hidden_states,
-        cache_params: MambaCache | None = None,
+        cache_params: Cache | None = None,
         attention_mask: torch.LongTensor | None = None,
         **kwargs,
     ):
@@ -519,7 +408,7 @@ def __init__(self, config, layer_idx):
     def forward(
         self,
         hidden_states,
-        cache_params: MambaCache | None = None,
+        cache_params: Cache | None = None,
         attention_mask: torch.LongTensor | None = None,
         **kwargs,
     ):
@@ -587,7 +476,7 @@ def _init_weights(self, module):
 )
 class MambaOutput(ModelOutput):
     r"""
-    cache_params (`MambaCache`):
+    cache_params (`Cache`):
         The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
         avoid providing the old `input_ids`.
@@ -595,7 +484,7 @@ class MambaOutput(ModelOutput):
     """

     last_hidden_state: torch.FloatTensor | None = None
-    cache_params: MambaCache | None = None
+    cache_params: Cache | None = None
     hidden_states: tuple[torch.FloatTensor] | None = None
@@ -611,7 +500,7 @@ class MambaCausalLMOutput(ModelOutput):
         Language modeling loss (for next-token prediction).
     logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
         Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
-    cache_params (`MambaCache`):
+    cache_params (`Cache`):
         The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
         avoid providing the old `input_ids`.
@@ -620,7 +509,7 @@ class MambaCausalLMOutput(ModelOutput): loss: torch.FloatTensor | None = None logits: torch.FloatTensor | None = None - cache_params: MambaCache | None = None + cache_params: Cache | None = None hidden_states: tuple[torch.FloatTensor] | None = None @@ -655,7 +544,7 @@ def forward( self, input_ids: torch.LongTensor | None = None, inputs_embeds: torch.LongTensor | None = None, - cache_params: MambaCache | None = None, + cache_params: Cache | None = None, use_cache: bool | None = None, output_hidden_states: bool | None = None, return_dict: bool | None = None, @@ -663,7 +552,7 @@ def forward( **kwargs, ) -> tuple | MambaOutput: r""" - cache_params (`MambaCache`, *optional*): + cache_params (`Cache`, *optional*): If passed along, the model uses the previous state in all the blocks (which will give the output for the `input_ids` provided as if the model add `state_input_ids + input_ids` as context). use_cache (`bool`, *optional*): @@ -685,9 +574,7 @@ def forward( use_cache = False if use_cache and cache_params is None: - cache_params = MambaCache( - self.config, inputs_embeds.size(0), device=inputs_embeds.device, dtype=inputs_embeds.dtype - ) + cache_params = DynamicCache(self.config) hidden_states = inputs_embeds all_hidden_states = () if output_hidden_states else None @@ -743,12 +630,11 @@ def prepare_inputs_for_generation( input_ids, inputs_embeds=None, use_cache=None, - cache_params: MambaCache | None = None, + cache_params: Cache | None = None, attention_mask: torch.LongTensor | None = None, is_first_iteration: bool | None = False, **kwargs, ): - # Overwritten -- has custom cache class `MambaCache` model_inputs = super().prepare_inputs_for_generation( input_ids, inputs_embeds=inputs_embeds, @@ -759,15 +645,7 @@ def prepare_inputs_for_generation( **kwargs, ) - if use_cache and cache_params is None: - if inputs_embeds is not None: - max_batch_size = inputs_embeds.size(0) - else: - max_batch_size = input_ids.size(0) - model_inputs["cache_params"] = MambaCache( - self.backbone.config, max_batch_size, device=self.device, dtype=self.dtype - ) - elif use_cache and not is_first_iteration: + if use_cache and not is_first_iteration: model_inputs["attention_mask"] = None return model_inputs @@ -778,7 +656,7 @@ def forward( input_ids: torch.LongTensor | None = None, attention_mask: torch.LongTensor | None = None, inputs_embeds: torch.FloatTensor | None = None, - cache_params: MambaCache | None = None, + cache_params: Cache | None = None, labels: torch.LongTensor | None = None, output_hidden_states: bool | None = None, return_dict: bool | None = None, @@ -787,7 +665,7 @@ def forward( **kwargs, # for now we need this for generation ) -> tuple | MambaCausalLMOutput: r""" - cache_params (`MambaCache`, *optional*): + cache_params (`Cache`, *optional*): If passed along, the model uses the previous state in all the blocks (which will give the output for the `input_ids` provided as if the model add `state_input_ids + input_ids` as context). 
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -837,4 +715,4 @@ def forward( ) -__all__ = ["MambaForCausalLM", "MambaModel", "MambaPreTrainedModel", "MambaCache"] +__all__ = ["MambaForCausalLM", "MambaModel", "MambaPreTrainedModel"] From 9d5259826631cc8ef37708c09470ef1ab20ef7fd Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Mon, 23 Mar 2026 17:37:43 +0100 Subject: [PATCH 02/56] fix --- docs/source/en/model_doc/falcon_mamba.md | 7 --- docs/source/en/model_doc/mamba.md | 7 --- src/transformers/__init__.py | 1 - .../falcon_mamba/modular_falcon_mamba.py | 59 ++++--------------- .../models/mamba/modeling_mamba.py | 6 +- .../test_modeling_falcon_mamba.py | 21 ++++--- tests/models/mamba/test_modeling_mamba.py | 46 ++++----------- 7 files changed, 35 insertions(+), 112 deletions(-) diff --git a/docs/source/en/model_doc/falcon_mamba.md b/docs/source/en/model_doc/falcon_mamba.md index 78b6e23a8127..bba9e95b63a2 100644 --- a/docs/source/en/model_doc/falcon_mamba.md +++ b/docs/source/en/model_doc/falcon_mamba.md @@ -111,13 +111,6 @@ outputs = model.generate(**inputs, max_new_tokens=100) print(tokenizer.decode(outputs[0], skip_special_tokens=True)) ``` -## FalconMambaCache - -[[autodoc]] FalconMambaCache - - update_conv_state - - update_ssm_state - - reset - ## FalconMambaConfig [[autodoc]] FalconMambaConfig diff --git a/docs/source/en/model_doc/mamba.md b/docs/source/en/model_doc/mamba.md index 7add263ab4fd..dd2bb2580a1a 100644 --- a/docs/source/en/model_doc/mamba.md +++ b/docs/source/en/model_doc/mamba.md @@ -110,13 +110,6 @@ print(tokenizer.decode(output[0], skip_special_tokens=True)) trainer.train() ``` -## MambaCache - -[[autodoc]] MambaCache - - update_conv_state - - update_ssm_state - - reset - ## MambaConfig [[autodoc]] MambaConfig diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index b6ba58e2e301..1f084cf413e3 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -632,7 +632,6 @@ from .modeling_utils import AttentionInterface as AttentionInterface from .modeling_utils import PreTrainedModel as PreTrainedModel from .models import * - from .models.mamba.modeling_mamba import MambaCache as MambaCache from .models.timm_wrapper import TimmWrapperImageProcessor as TimmWrapperImageProcessor # Optimization diff --git a/src/transformers/models/falcon_mamba/modular_falcon_mamba.py b/src/transformers/models/falcon_mamba/modular_falcon_mamba.py index 8aec2f95ccb0..b5bda5b31dde 100644 --- a/src/transformers/models/falcon_mamba/modular_falcon_mamba.py +++ b/src/transformers/models/falcon_mamba/modular_falcon_mamba.py @@ -18,12 +18,12 @@ from torch import nn from ... import initialization as init +from ...cache_utils import Cache from ...utils import auto_docstring, logging from ...utils.import_utils import is_mambapy_available, is_torch_greater_or_equal, is_torchdynamo_compiling, is_tracing from ..mamba.configuration_mamba import MambaConfig from ..mamba.modeling_mamba import ( MambaBlock, - MambaCache, MambaCausalLMOutput, MambaForCausalLM, MambaMixer, @@ -102,40 +102,6 @@ class FalconMambaConfig(MambaConfig): mixer_rms_eps: float = 1e-6 -class FalconMambaCache(MambaCache): - """ - Cache for falcon_mamba model which does not have attention mechanism and key value states. - - Arguments: - config (`PreTrainedConfig): - The configuration file defining the shape-related attributes required to initialize the static cache. - max_batch_size (`int`): - The maximum batch size with which the model will be used. 
Note that a new instance must be instantiated if - a smaller batch size is used. - dtype (`torch.dtype`, *optional*, defaults to `torch.float16`): - The default `dtype` to use when initializing the layer. - device (`torch.device` or `str`, *optional*): - The device on which the cache should be initialized. Should be the same as the layer. - - Example: - - ```python - >>> import torch - >>> from transformers import AutoTokenizer, FalconMambaForCausalLM, FalconMambaCache - - >>> model = FalconMambaForCausalLM.from_pretrained("tiiuae/falcon-mamba-7b") - >>> tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-mamba-7b") - - >>> inputs = tokenizer(text="My name is FalconMamba", return_tensors="pt") - - >>> # Prepare a cache class and pass it to model's forward - >>> cache_params = FalconMambaCache(config=model.config, max_batch_size=1, device=model.device, dtype=model.dtype) - >>> outputs = model(**inputs, cache_params=cache_params, use_cache=True) - >>> outputs.cache_params - ``` - """ - - def rms_forward(hidden_states, variance_epsilon=1e-6): """ Calculates simple RMSNorm with no learnable weights. `MambaRMSNorm` will @@ -194,7 +160,7 @@ def __init__(self, config: FalconMambaConfig, layer_idx: int, initialize_mixer_w def cuda_kernels_forward( self, hidden_states: torch.Tensor, - cache_params: FalconMambaCache | None = None, + cache_params: Cache | None = None, attention_mask: torch.LongTensor | None = None, ): # 1. Gated MLP's linear projection @@ -233,7 +199,7 @@ def cuda_kernels_forward( if is_decoding: hidden_states = causal_conv1d_update( hidden_states.squeeze(-1), - cache_params.conv_states[self.layer_idx], + cache_params[self.layer_idx].conv_states, conv_weights, self.conv1d.bias, self.activation, @@ -244,7 +210,7 @@ def cuda_kernels_forward( conv_states = nn.functional.pad( hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0) ) - cache_params.update_conv_state(self.layer_idx, conv_states, cache_init=True) + cache_params.update_conv_state(conv_states, self.layer_idx) hidden_states = causal_conv1d_fn( hidden_states, conv_weights, self.conv1d.bias, activation=self.activation ) @@ -275,7 +241,7 @@ def cuda_kernels_forward( time_proj_bias = self.dt_proj.bias.float() if hasattr(self.dt_proj, "bias") else None if is_decoding: scan_outputs = selective_state_update( - cache_params.ssm_states[self.layer_idx], + cache_params[self.layer_idx].ssm_states, hidden_states[..., 0], discrete_time_step[..., 0], A, @@ -300,7 +266,7 @@ def cuda_kernels_forward( return_last_state=True, ) if ssm_state is not None and cache_params is not None: - cache_params.update_ssm_state(self.layer_idx, ssm_state) + cache_params.update_ssm_state(ssm_state, self.layer_idx) # 4. Final linear projection contextualized_states = self.out_proj(scan_outputs.transpose(1, 2)) @@ -309,7 +275,7 @@ def cuda_kernels_forward( def slow_forward( self, input_states, - cache_params: FalconMambaCache | None = None, + cache_params: Cache | None = None, attention_mask: torch.LongTensor | None = None, ): batch_size, seq_len, _ = input_states.shape @@ -323,17 +289,17 @@ def slow_forward( # 2. 
Convolution sequence transformation if cache_params is not None: - ssm_state = cache_params.ssm_states[self.layer_idx].clone() + ssm_state = cache_params[self.layer_idx].ssm_states.clone() ssm_state = ssm_state.to(hidden_states.device) if not cache_params.has_previous_state: conv_state = nn.functional.pad(hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0)) - cache_params.update_conv_state(self.layer_idx, conv_state, cache_init=True) + cache_params.update_conv_state(conv_state, self.layer_idx) hidden_states = self.act( self.conv1d(hidden_states)[..., :seq_len] ) # [batch, intermediate_size, seq_len] else: - conv_state = cache_params.update_conv_state(self.layer_idx, hidden_states, cache_init=False) + conv_state = cache_params.update_conv_state(hidden_states, self.layer_idx) conv_state = conv_state.to(self.conv1d.weight.device) hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1) if self.use_conv_bias: @@ -422,7 +388,7 @@ def combine_fn(left, right): scan_output = scan_output * self.act(gate) if cache_params is not None: - cache_params.update_ssm_state(self.layer_idx, ssm_state) + cache_params.update_ssm_state(ssm_state, self.layer_idx) # 4. Final linear projection contextualized_states = self.out_proj(scan_output.transpose(1, 2)) # [batch, seq_len, hidden_size] @@ -431,7 +397,7 @@ def combine_fn(left, right): def forward( self, hidden_states, - cache_params: FalconMambaCache | None = None, + cache_params: Cache | None = None, attention_mask: torch.LongTensor | None = None, **kwargs, ): @@ -499,6 +465,5 @@ class FalconMambaForCausalLM(MambaForCausalLM): "FalconMambaForCausalLM", "FalconMambaModel", "FalconMambaPreTrainedModel", - "FalconMambaCache", "FalconMambaConfig", ] diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index cfb5a19be930..b74fb8419d10 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -203,7 +203,7 @@ def cuda_kernels_forward( if is_decoding: hidden_states = causal_conv1d_update( hidden_states.squeeze(-1), - cache_params.conv_states[self.layer_idx], + cache_params[self.layer_idx].conv_states, conv_weights, self.conv1d.bias, self.activation, @@ -235,7 +235,7 @@ def cuda_kernels_forward( time_proj_bias = self.dt_proj.bias.float() if hasattr(self.dt_proj, "bias") else None if is_decoding: scan_outputs = selective_state_update( - cache_params.ssm_states[self.layer_idx], + cache_params[self.layer_idx].ssm_states, hidden_states[..., 0], discrete_time_step[..., 0], A, @@ -574,7 +574,7 @@ def forward( use_cache = False if use_cache and cache_params is None: - cache_params = DynamicCache(self.config) + cache_params = DynamicCache(config=self.config) hidden_states = inputs_embeds all_hidden_states = () if output_hidden_states else None diff --git a/tests/models/falcon_mamba/test_modeling_falcon_mamba.py b/tests/models/falcon_mamba/test_modeling_falcon_mamba.py index 9e4ce3e5cb65..42cca6f8d447 100644 --- a/tests/models/falcon_mamba/test_modeling_falcon_mamba.py +++ b/tests/models/falcon_mamba/test_modeling_falcon_mamba.py @@ -41,8 +41,7 @@ if is_torch_available(): import torch - from transformers import FalconMambaForCausalLM, FalconMambaModel - from transformers.models.falcon_mamba.modeling_falcon_mamba import FalconMambaCache + from transformers import DynamicCache, FalconMambaForCausalLM, FalconMambaModel # Copied from transformers.tests.models.mamba.MambaModelTester with Mamba->FalconMamba,mamba->falcon_mamba @@ 
-255,7 +254,6 @@ def prepare_config_and_inputs_for_common(self):


 @require_torch
-# Copied from transformers.tests.models.mamba.MambaModelTest with Mamba->Falcon,mamba->falcon_mamba,FalconMambaCache->MambaCache
 class FalconMambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
     all_model_classes = (FalconMambaModel, FalconMambaForCausalLM) if is_torch_available() else ()
     has_attentions = False  # FalconMamba does not support attentions
@@ -277,16 +275,16 @@ def setUp(self):
         )

     def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config):
-        self.assertIsInstance(past_key_values, FalconMambaCache)
+        self.assertIsInstance(past_key_values, DynamicCache)
         conv_shape = (batch_size, config.intermediate_size, config.conv_kernel)
         ssm_shape = (batch_size, config.intermediate_size, config.state_size)

-        self.assertTrue(config.num_hidden_layers, len(past_key_values.conv_states))
+        self.assertTrue(config.num_hidden_layers, len(past_key_values))

-        for idx in range(len(past_key_values.conv_states)):
-            self.assertEqual(past_key_values.conv_states[idx].shape, conv_shape)
-            self.assertEqual(past_key_values.ssm_states[idx].shape, ssm_shape)
+        for idx in range(len(past_key_values)):
+            self.assertEqual(past_key_values[idx].conv_states.shape, conv_shape)
+            self.assertEqual(past_key_values[idx].ssm_states.shape, ssm_shape)

     def assertInterval(self, member, container, msg=None):
         r"""
@@ -348,9 +346,10 @@ def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}):
             dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs).to_tuple()

             def recursive_check(tuple_object, dict_object):
-                if isinstance(tuple_object, FalconMambaCache):  # MODIFIED PART START
-                    recursive_check(tuple_object.conv_states, dict_object.conv_states)
-                    recursive_check(tuple_object.ssm_states, dict_object.ssm_states)
+                if isinstance(tuple_object, DynamicCache):  # MODIFIED PART START
+                    for idx in range(len(tuple_object)):
+                        recursive_check(tuple_object[idx].conv_states, dict_object[idx].conv_states)
+                        recursive_check(tuple_object[idx].ssm_states, dict_object[idx].ssm_states)
                 elif isinstance(tuple_object, (list, tuple)):  # MODIFIED PART END
                     for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object):
                         recursive_check(tuple_iterable_value, dict_iterable_value)
diff --git a/tests/models/mamba/test_modeling_mamba.py b/tests/models/mamba/test_modeling_mamba.py
index 6430a014ec4f..88367447fddb 100644
--- a/tests/models/mamba/test_modeling_mamba.py
+++ b/tests/models/mamba/test_modeling_mamba.py
@@ -31,11 +31,7 @@
 if is_torch_available():
     import torch

-    from transformers import (
-        MambaForCausalLM,
-        MambaModel,
-    )
-    from transformers.models.mamba.modeling_mamba import MambaCache
+    from transformers import DynamicCache, MambaForCausalLM, MambaModel


 class MambaModelTester:
@@ -247,16 +243,16 @@ def test_enable_input_require_grads(self):
         self.skipTest("Mamba currently requires CUDA/Metal/XPU to run enable_input_require_grads.")

     def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config):
-        self.assertIsInstance(past_key_values, MambaCache)
+        self.assertIsInstance(past_key_values, DynamicCache)
         conv_shape = (batch_size, config.intermediate_size, config.conv_kernel)
         ssm_shape = (batch_size, config.intermediate_size, config.state_size)

-        self.assertTrue(config.num_hidden_layers, len(past_key_values.conv_states))
+        self.assertTrue(config.num_hidden_layers, len(past_key_values))

-        for idx in range(len(past_key_values.conv_states)):
-            self.assertEqual(past_key_values.conv_states[idx].shape, conv_shape)
-            self.assertEqual(past_key_values.ssm_states[idx].shape, ssm_shape)
+        for idx in range(len(past_key_values)):
+            self.assertEqual(past_key_values[idx].conv_states.shape, conv_shape)
+            self.assertEqual(past_key_values[idx].ssm_states.shape, ssm_shape)

     def assertInterval(self, member, container, msg=None):
         r"""
@@ -317,9 +313,10 @@ def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}):
             dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs).to_tuple()

             def recursive_check(tuple_object, dict_object):
-                if isinstance(tuple_object, MambaCache):  # MODIFIED PART START
-                    recursive_check(tuple_object.conv_states, dict_object.conv_states)
-                    recursive_check(tuple_object.ssm_states, dict_object.ssm_states)
+                if isinstance(tuple_object, DynamicCache):  # MODIFIED PART START
+                    for idx in range(len(tuple_object)):
+                        recursive_check(tuple_object[idx].conv_states, dict_object[idx].conv_states)
+                        recursive_check(tuple_object[idx].ssm_states, dict_object[idx].ssm_states)
                 elif isinstance(tuple_object, (list, tuple)):  # MODIFIED PART END
                     for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object):
                         recursive_check(tuple_iterable_value, dict_iterable_value)
@@ -368,29 +365,6 @@ def recursive_check(tuple_object, dict_object):
     def test_beam_sample_generate(self):
         pass

-    def test_dtype_mismatch_handled_in_cache(self):
-        config, input_ids, *args = self.model_tester.prepare_config_and_inputs()
-        model = MambaModel(config)
-        model.to(torch_device).to(torch.float16)
-        model.eval()
-
-        # Create cache with float32 dtype
-        cache_params = MambaCache(config, max_batch_size=input_ids.size(0), dtype=torch.float32, device=torch_device)
-
-        # If code is correct, no error occurs and test passes
-        outputs = model(
-            input_ids,
-            cache_params=cache_params,
-            use_cache=True,
-        )
-
-        self.assertIsNotNone(outputs)
-        self.assertIsNotNone(outputs.last_hidden_state)
-        self.assertEqual(
-            outputs.last_hidden_state.shape,
-            (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.hidden_size),
-        )
-
     @unittest.skip("Mamba models do not support DDP.")
     def test_multi_gpu_data_parallel_forward(self):
         pass

From 659beee9ba3a7af02abad67aab6bd373d1620528 Mon Sep 17 00:00:00 2001
From: Cyril Vallez
Date: Mon, 23 Mar 2026 17:52:23 +0100
Subject: [PATCH 03/56] fix

---
 .../falcon_mamba/modeling_falcon_mamba.py     | 172 +++---------------
 .../models/mamba2/modeling_mamba2.py          | 163 ++++-------------
 tests/models/mamba2/test_modeling_mamba2.py   |  24 ++-
 3 files changed, 67 insertions(+), 292 deletions(-)

diff --git a/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py b/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py
index f80d0f7ca06f..cb98905ea795 100644
--- a/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py
+++ b/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py
@@ -27,7 +27,7 @@

 from ... import initialization as init
 from ...activations import ACT2FN
-from ...configuration_utils import PreTrainedConfig
+from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
 from ...integrations import lazy_load_kernel
 from ...modeling_layers import GradientCheckpointingLayer
@@ -57,117 +57,6 @@

 logger = logging.get_logger(__name__)


-class FalconMambaCache:
-    """
-    Cache for falcon_mamba model which does not have attention mechanism and key value states.
- - Arguments: - config (`PreTrainedConfig): - The configuration file defining the shape-related attributes required to initialize the static cache. - max_batch_size (`int`): - The maximum batch size with which the model will be used. Note that a new instance must be instantiated if - a smaller batch size is used. - dtype (`torch.dtype`, *optional*, defaults to `torch.float16`): - The default `dtype` to use when initializing the layer. - device (`torch.device` or `str`, *optional*): - The device on which the cache should be initialized. Should be the same as the layer. - - Example: - - ```python - >>> import torch - >>> from transformers import AutoTokenizer, FalconMambaForCausalLM, FalconMambaCache - - >>> model = FalconMambaForCausalLM.from_pretrained("tiiuae/falcon-mamba-7b") - >>> tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-mamba-7b") - - >>> inputs = tokenizer(text="My name is FalconMamba", return_tensors="pt") - - >>> # Prepare a cache class and pass it to model's forward - >>> cache_params = FalconMambaCache(config=model.config, max_batch_size=1, device=model.device, dtype=model.dtype) - >>> outputs = model(**inputs, cache_params=cache_params, use_cache=True) - >>> outputs.cache_params - ``` - """ - - is_compileable = True - - # TODO (joao): add layer_device_map arg and update code in `generate` accordingly - def __init__( - self, - config: PreTrainedConfig, - max_batch_size: int, - dtype: torch.dtype = torch.float16, - device: torch.device | str | None = None, - ): - self.max_batch_size = max_batch_size - self._dtype = dtype - self.intermediate_size = config.intermediate_size - self.ssm_state_size = config.state_size - self.conv_kernel_size = config.conv_kernel - self.has_previous_state = False - - self.conv_states: list[torch.Tensor] = [] - self.ssm_states: list[torch.Tensor] = [] - device = torch.device(device) if device is not None else None - for _ in range(config.num_hidden_layers): - conv_state: torch.Tensor = torch.zeros( - self.max_batch_size, - self.intermediate_size, - self.conv_kernel_size, - device=device, - dtype=self._dtype, - ) - ssm_state: torch.Tensor = torch.zeros( - self.max_batch_size, - self.intermediate_size, - self.ssm_state_size, - device=device, - dtype=self._dtype, - ) - - torch._dynamo.mark_static_address(conv_state) - torch._dynamo.mark_static_address(ssm_state) - self.conv_states.append(conv_state) - self.ssm_states.append(ssm_state) - - def update_conv_state( - self, layer_idx: int, new_conv_state: torch.Tensor, cache_init: bool = False - ) -> torch.Tensor: - # This `if` blocks is only reached in multigpu and if `layer_device_map` is not passed. It is used - # when the cache is initialized in the forward pass (e.g. FalconMamba) - if self.conv_states[layer_idx].device != new_conv_state.device: - self.conv_states[layer_idx] = self.conv_states[layer_idx].to(new_conv_state.device) - - # Technically, those update are not logically correct if the prefill is smaller than `conv_kernel_size`, - # as it will `roll` anyway in the first decoding step even though it should `roll` ONLY if the cache is already full. 
- # But since `conv_kernel_size=4` in practice, it's almost impossible to have a smaller prefill so it's mostly fine for now - if cache_init: - self.conv_states[layer_idx].copy_(new_conv_state) - else: - conv_state = self.conv_states[layer_idx].roll(shifts=-1, dims=-1) - conv_state[:, :, -1:] = new_conv_state - self.conv_states[layer_idx].copy_(conv_state) - - # If last layer is updated, set the flag - if layer_idx == len(self.conv_states) - 1: - self.has_previous_state = True - - return self.conv_states[layer_idx] - - def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor): - self.ssm_states[layer_idx].zero_() - self.ssm_states[layer_idx] += new_ssm_state.to(self.ssm_states[layer_idx].device) - return self.ssm_states[layer_idx] - - def reset(self): - self.has_previous_state = False - for layer_idx in range(len(self.conv_states)): - # In-place ops prevent breaking the static address - self.conv_states[layer_idx].zero_() - self.ssm_states[layer_idx].zero_() - - def rms_forward(hidden_states, variance_epsilon=1e-6): """ Calculates simple RMSNorm with no learnable weights. `MambaRMSNorm` will @@ -310,7 +199,7 @@ def warn_slow_implementation(self): def cuda_kernels_forward( self, hidden_states: torch.Tensor, - cache_params: FalconMambaCache | None = None, + cache_params: Cache | None = None, attention_mask: torch.LongTensor | None = None, ): # 1. Gated MLP's linear projection @@ -349,7 +238,7 @@ def cuda_kernels_forward( if is_decoding: hidden_states = causal_conv1d_update( hidden_states.squeeze(-1), - cache_params.conv_states[self.layer_idx], + cache_params[self.layer_idx].conv_states, conv_weights, self.conv1d.bias, self.activation, @@ -360,7 +249,7 @@ def cuda_kernels_forward( conv_states = nn.functional.pad( hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0) ) - cache_params.update_conv_state(self.layer_idx, conv_states, cache_init=True) + cache_params.update_conv_state(conv_states, self.layer_idx) hidden_states = causal_conv1d_fn( hidden_states, conv_weights, self.conv1d.bias, activation=self.activation ) @@ -391,7 +280,7 @@ def cuda_kernels_forward( time_proj_bias = self.dt_proj.bias.float() if hasattr(self.dt_proj, "bias") else None if is_decoding: scan_outputs = selective_state_update( - cache_params.ssm_states[self.layer_idx], + cache_params[self.layer_idx].ssm_states, hidden_states[..., 0], discrete_time_step[..., 0], A, @@ -416,7 +305,7 @@ def cuda_kernels_forward( return_last_state=True, ) if ssm_state is not None and cache_params is not None: - cache_params.update_ssm_state(self.layer_idx, ssm_state) + cache_params.update_ssm_state(ssm_state, self.layer_idx) # 4. Final linear projection contextualized_states = self.out_proj(scan_outputs.transpose(1, 2)) @@ -425,7 +314,7 @@ def cuda_kernels_forward( # fmt: off def slow_forward(self, input_states, - cache_params: FalconMambaCache | None = None, + cache_params: Cache | None = None, attention_mask: torch.LongTensor | None = None, ): batch_size, seq_len, _ = input_states.shape @@ -439,17 +328,17 @@ def slow_forward(self, # 2. 
Convolution sequence transformation if cache_params is not None: - ssm_state = cache_params.ssm_states[self.layer_idx].clone() + ssm_state = cache_params[self.layer_idx].ssm_states.clone() ssm_state = ssm_state.to(hidden_states.device) if not cache_params.has_previous_state: conv_state = nn.functional.pad(hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0)) - cache_params.update_conv_state(self.layer_idx, conv_state, cache_init=True) + cache_params.update_conv_state(conv_state, self.layer_idx) hidden_states = self.act( self.conv1d(hidden_states)[..., :seq_len] ) # [batch, intermediate_size, seq_len] else: - conv_state = cache_params.update_conv_state(self.layer_idx, hidden_states, cache_init=False) + conv_state = cache_params.update_conv_state(hidden_states, self.layer_idx) conv_state = conv_state.to(self.conv1d.weight.device) hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1) if self.use_conv_bias: @@ -538,7 +427,7 @@ def combine_fn(left, right): scan_output = scan_output * self.act(gate) if cache_params is not None: - cache_params.update_ssm_state(self.layer_idx, ssm_state) + cache_params.update_ssm_state(ssm_state, self.layer_idx) # 4. Final linear projection contextualized_states = self.out_proj(scan_output.transpose(1, 2)) # [batch, seq_len, hidden_size] @@ -548,7 +437,7 @@ def combine_fn(left, right): def forward( self, hidden_states, - cache_params: FalconMambaCache | None = None, + cache_params: Cache | None = None, attention_mask: torch.LongTensor | None = None, **kwargs, ): @@ -590,7 +479,7 @@ def __init__(self, config, layer_idx): def forward( self, hidden_states, - cache_params: FalconMambaCache | None = None, + cache_params: Cache | None = None, attention_mask: torch.LongTensor | None = None, **kwargs, ): @@ -658,7 +547,7 @@ def _init_weights(self, module): ) class FalconMambaOutput(ModelOutput): r""" - cache_params (`FalconMambaCache`): + cache_params (`Cache`): The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to avoid providing the old `input_ids`. @@ -666,7 +555,7 @@ class FalconMambaOutput(ModelOutput): """ last_hidden_state: torch.FloatTensor | None = None - cache_params: FalconMambaCache | None = None + cache_params: Cache | None = None hidden_states: tuple[torch.FloatTensor] | None = None @@ -682,7 +571,7 @@ class FalconMambaCausalLMOutput(ModelOutput): Language modeling loss (for next-token prediction). logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - cache_params (`FalconMambaCache`): + cache_params (`Cache`): The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to avoid providing the old `input_ids`. 
@@ -691,7 +580,7 @@ class FalconMambaCausalLMOutput(ModelOutput): loss: torch.FloatTensor | None = None logits: torch.FloatTensor | None = None - cache_params: FalconMambaCache | None = None + cache_params: Cache | None = None hidden_states: tuple[torch.FloatTensor] | None = None @@ -721,7 +610,7 @@ def forward( self, input_ids: torch.LongTensor | None = None, inputs_embeds: torch.LongTensor | None = None, - cache_params: FalconMambaCache | None = None, + cache_params: Cache | None = None, use_cache: bool | None = None, output_hidden_states: bool | None = None, return_dict: bool | None = None, @@ -729,7 +618,7 @@ def forward( **kwargs, ) -> tuple | FalconMambaOutput: r""" - cache_params (`FalconMambaCache`, *optional*): + cache_params (`Cache`, *optional*): If passed along, the model uses the previous state in all the blocks (which will give the output for the `input_ids` provided as if the model add `state_input_ids + input_ids` as context). use_cache (`bool`, *optional*): @@ -751,9 +640,7 @@ def forward( use_cache = False if use_cache and cache_params is None: - cache_params = FalconMambaCache( - self.config, inputs_embeds.size(0), device=inputs_embeds.device, dtype=inputs_embeds.dtype - ) + cache_params = DynamicCache(config=self.config) hidden_states = inputs_embeds all_hidden_states = () if output_hidden_states else None @@ -809,12 +696,11 @@ def prepare_inputs_for_generation( input_ids, inputs_embeds=None, use_cache=None, - cache_params: FalconMambaCache | None = None, + cache_params: Cache | None = None, attention_mask: torch.LongTensor | None = None, is_first_iteration: bool | None = False, **kwargs, ): - # Overwritten -- has custom cache class `FalconMambaCache` model_inputs = super().prepare_inputs_for_generation( input_ids, inputs_embeds=inputs_embeds, @@ -825,15 +711,7 @@ def prepare_inputs_for_generation( **kwargs, ) - if use_cache and cache_params is None: - if inputs_embeds is not None: - max_batch_size = inputs_embeds.size(0) - else: - max_batch_size = input_ids.size(0) - model_inputs["cache_params"] = FalconMambaCache( - self.backbone.config, max_batch_size, device=self.device, dtype=self.dtype - ) - elif use_cache and not is_first_iteration: + if use_cache and not is_first_iteration: model_inputs["attention_mask"] = None return model_inputs @@ -844,7 +722,7 @@ def forward( input_ids: torch.LongTensor | None = None, attention_mask: torch.LongTensor | None = None, inputs_embeds: torch.FloatTensor | None = None, - cache_params: FalconMambaCache | None = None, + cache_params: Cache | None = None, labels: torch.LongTensor | None = None, output_hidden_states: bool | None = None, return_dict: bool | None = None, @@ -853,7 +731,7 @@ def forward( **kwargs, # for now we need this for generation ) -> tuple | FalconMambaCausalLMOutput: r""" - cache_params (`FalconMambaCache`, *optional*): + cache_params (`Cache`, *optional*): If passed along, the model uses the previous state in all the blocks (which will give the output for the `input_ids` provided as if the model add `state_input_ids + input_ids` as context). 
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -903,4 +781,4 @@ def forward( ) -__all__ = ["FalconMambaForCausalLM", "FalconMambaModel", "FalconMambaPreTrainedModel", "FalconMambaCache"] +__all__ = ["FalconMambaForCausalLM", "FalconMambaModel", "FalconMambaPreTrainedModel"] diff --git a/src/transformers/models/mamba2/modeling_mamba2.py b/src/transformers/models/mamba2/modeling_mamba2.py index 3a70ec35dfd9..0c3a1d14d12e 100644 --- a/src/transformers/models/mamba2/modeling_mamba2.py +++ b/src/transformers/models/mamba2/modeling_mamba2.py @@ -21,6 +21,7 @@ from ... import initialization as init from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...integrations import lazy_load_kernel from ...modeling_layers import GradientCheckpointingLayer @@ -99,94 +100,6 @@ def apply_mask_to_padding_states(hidden_states, attention_mask): return hidden_states -class Mamba2Cache: - """ - Arguments: - config: Mamba2Config - batch_size: int - dtype: torch.dtype - device: torch.device - - Attributes: - dtype: (`torch.dtype`): - The default `dtype` used to initializing the cache. - conv_kernel_size: (`int`): - Model's convolution kernel size taken from config. - n_groups: (`int`): - Model's number of groups taken from the config - similar to tensor parallel in Transformer. - state_size: (`int`): - Model's SSM state size taken from config. - num_heads: (`int`): - The number of heads used in the linear attention / SSM. - head_dim: (`int`): - The respective dimension of the heads used in the linear attention / SSM. - intermediate_size: (`int`): - Model's intermediate_size based on (expand * hidden_dim) from config. - conv_states: (`torch.Tensor`): - A tensor of shape `[num_layers, batch_size, conv_kernel_size, intermediate_size + 2 * n_groups * state_size]` that holds convolutional states. - ssm_states: (`torch.Tensor`): - A tensor of shape `[num_layers, batch_size, num_heads, head_dim, state_size]` that holds ssm states. - """ - - def __init__( - self, config: Mamba2Config, batch_size: int, dtype: torch.dtype = torch.float16, device: str | None = None - ): - self.dtype = dtype - self.conv_kernel_size = config.conv_kernel - self.n_groups = config.n_groups - self.state_size = config.state_size - self.num_heads = config.num_heads - self.head_dim = config.head_dim - self.intermediate_size = int(config.expand * config.hidden_size) - - self.conv_states = torch.zeros( - config.num_hidden_layers, - batch_size, - self.intermediate_size + 2 * self.n_groups * self.state_size, - self.conv_kernel_size, - device=device, - dtype=dtype, - ) - self.ssm_states = torch.zeros( - config.num_hidden_layers, - batch_size, - self.num_heads, - self.head_dim, - self.state_size, - device=device, - dtype=dtype, - ) - self.has_previous_state = False - - def update_conv_state( - self, layer_idx: int, new_conv_state: torch.Tensor, cache_init: bool = False - ) -> torch.Tensor: - # Technically, those update are not logically correct if the prefill is smaller than `conv_kernel_size`, - # as it will `roll` anyway in the first decoding step even though it should `roll` ONLY if the cache is already full. 
- # But since `conv_kernel_size=4` in practice, it's almost impossible to have a smaller prefill so it's mostly fine for now - if cache_init: - self.conv_states[layer_idx] = new_conv_state.to(self.conv_states.device) - else: - self.conv_states[layer_idx] = self.conv_states[layer_idx].roll(shifts=-1, dims=-1) - self.conv_states[layer_idx][:, :, -1] = new_conv_state[:, 0, :].to(self.conv_states.device) - - return self.conv_states[layer_idx] - - def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor): - self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states.device) - - # If last layer is updated, set the flag - if layer_idx == self.conv_states.shape[0] - 1: - self.has_previous_state = True - - return self.ssm_states[layer_idx] - - def reset(self): - self.has_previous_state = False - self.conv_states.zero_() - self.ssm_states.zero_() - - class MambaRMSNormGated(torch.nn.Module): def __init__(self, hidden_size, eps=1e-6): super().__init__() @@ -325,7 +238,7 @@ def init_mamba2_weights(self): def cuda_kernels_forward( self, hidden_states: torch.Tensor, - cache_params: Mamba2Cache | None = None, + cache_params: Cache | None = None, attention_mask: torch.Tensor | None = None, ): # 1. Gated MLP's linear projection @@ -351,7 +264,7 @@ def cuda_kernels_forward( # 2. Convolution sequence transformation hidden_states_B_C = causal_conv1d_update( hidden_states_B_C, - cache_params.conv_states[self.layer_idx], + cache_params[self.layer_idx].conv_states, self.conv1d.weight.squeeze(1), self.conv1d.bias, self.activation, @@ -373,7 +286,7 @@ def cuda_kernels_forward( C = C.view(batch_size, self.n_groups, C.shape[1] // self.n_groups) hidden_states_reshaped = hidden_states.view(batch_size, self.num_heads, self.head_dim) hidden_states = selective_state_update( - cache_params.ssm_states[self.layer_idx], + cache_params[self.layer_idx].ssm_states, hidden_states_reshaped, dt, A, @@ -429,11 +342,9 @@ def cuda_kernels_forward( hidden_states_B_C_transposed = hidden_states_B_C.transpose(1, 2) conv_states = nn.functional.pad( hidden_states_B_C_transposed, - (cache_params.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0), - ) - cache_params.update_conv_state( - layer_idx=self.layer_idx, new_conv_state=conv_states, cache_init=True + (self.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0), ) + conv_states = cache_params.update_conv_state(conv_states, layer_idx=self.layer_idx) if self.activation not in ["silu", "swish"]: hidden_states_B_C = self.act( @@ -473,7 +384,7 @@ def cuda_kernels_forward( # Init cache if ssm_state is not None and cache_params is not None: - cache_params.update_ssm_state(layer_idx=self.layer_idx, new_ssm_state=ssm_state) + cache_params.update_ssm_state(ssm_state, layer_idx=self.layer_idx) scan_output = scan_output.view(batch_size, seq_len, -1) # Multiply "gate" branch and apply extra normalization layer @@ -487,8 +398,8 @@ def cuda_kernels_forward( def torch_forward( self, hidden_states: torch.Tensor, - cache_params: Mamba2Cache | None=None, - attention_mask: torch.Tensor | None=None + cache_params: Cache | None = None, + attention_mask: torch.Tensor | None = None ): batch_size, seq_len, _ = hidden_states.shape dtype = hidden_states.dtype @@ -503,10 +414,10 @@ def torch_forward( # 2. 
Convolution sequence transformation if cache_params is not None and cache_params.has_previous_state: - cache_params.update_conv_state(layer_idx=self.layer_idx, new_conv_state=hidden_states_B_C, cache_init=False) + cache_params.update_conv_state(hidden_states_B_C[:, 0:1, :], layer_idx=self.layer_idx) # We need to guarantee that anything regarding the cache is on the same device - conv_states = cache_params.conv_states[self.layer_idx].to(device=self.conv1d.weight.device) + conv_states = cache_params[self.layer_idx].conv_states.to(device=self.conv1d.weight.device) hidden_states_B_C = torch.sum( conv_states * self.conv1d.weight.squeeze(1), dim=-1 @@ -519,9 +430,9 @@ def torch_forward( if cache_params is not None: hidden_states_B_C_transposed = hidden_states_B_C.transpose(1, 2) conv_states = nn.functional.pad( - hidden_states_B_C_transposed, (cache_params.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0) + hidden_states_B_C_transposed, (self.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0) ) - cache_params.update_conv_state(layer_idx=self.layer_idx, new_conv_state=conv_states, cache_init=True) + cache_params.update_conv_state(conv_states, layer_idx=self.layer_idx) hidden_states_B_C = self.act(self.conv1d(hidden_states_B_C.transpose(1, 2))[..., :seq_len].transpose(1, 2)) @@ -536,7 +447,7 @@ def torch_forward( A = -torch.exp(self.A_log.float()) # [num_heads] if cache_params is not None and cache_params.has_previous_state: # We need to guarantee that anything regarding the cache is on the same device - cache_device = cache_params.ssm_states.device + cache_device = cache_params[self.layer_idx].device # Note: there is no need to pad parameter matrices here, as there is just one new token # for batched generation @@ -567,8 +478,8 @@ def torch_forward( # State calculation cache_params.update_ssm_state( + cache_params[self.layer_idx].ssm_states * dA + dBx, layer_idx=self.layer_idx, - new_ssm_state=cache_params.ssm_states[self.layer_idx] * dA + dBx ) # Subsequent output @@ -578,7 +489,7 @@ def torch_forward( C = C.reshape(batch_size, -1, C.shape[-1]) # [bsz, num_heads, head_dim] - ssm_states = cache_params.ssm_states[self.layer_idx].to(device=C.device, dtype=C.dtype) # Shape: [b, h, d, n] + ssm_states = cache_params[self.layer_idx].ssm_states.to(device=C.device, dtype=C.dtype) # Shape: [b, h, d, n] # Reshape ssm_states to merge the first two dimensions ssm_states_reshaped = ssm_states.view(batch_size * self.num_heads, self.head_dim, self.ssm_state_size) # Shape: [b*h, d, n] C_reshaped = C.view(batch_size * self.num_heads, self.ssm_state_size, 1) # Shape: [b*h, n, 1] @@ -640,7 +551,7 @@ def torch_forward( # 3. 
Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries # (middle term of factorization of off-diag blocks; A terms) if cache_params is not None and cache_params.has_previous_state: - previous_states = cache_params.ssm_states[self.layer_idx][:, None, ...].to(device=states.device) + previous_states = cache_params[self.layer_idx].ssm_states[:, None, ...].to(device=states.device) else: previous_states = torch.zeros_like(states[:, :1]) states = torch.cat([previous_states, states], dim=1) @@ -669,7 +580,7 @@ def torch_forward( # Init cache if ssm_state is not None and cache_params is not None: - cache_params.update_ssm_state(layer_idx=self.layer_idx, new_ssm_state=ssm_state) + cache_params.update_ssm_state(ssm_state, layer_idx=self.layer_idx) scan_output = self.norm(y, gate) @@ -683,7 +594,7 @@ def torch_forward( def forward( self, hidden_states, - cache_params: Mamba2Cache | None = None, + cache_params: Cache | None = None, attention_mask: torch.Tensor | None = None, **kwargs, ): @@ -721,7 +632,7 @@ def __init__(self, config, layer_idx): def forward( self, hidden_states, - cache_params: Mamba2Cache | None = None, + cache_params: Cache | None = None, attention_mask: torch.Tensor | None = None, **kwargs, ): @@ -790,7 +701,7 @@ def _init_weights(self, module): # Copied from transformers.models.mamba.modeling_mamba.MambaOutput with MAMBA->MAMBA2,Mamba->Mamba2 class Mamba2Output(ModelOutput): r""" - cache_params (`Mamba2Cache`): + cache_params (`Cache`): The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to avoid providing the old `input_ids`. @@ -798,7 +709,7 @@ class Mamba2Output(ModelOutput): """ last_hidden_state: torch.FloatTensor | None = None - cache_params: Mamba2Cache | None = None + cache_params: Cache | None = None hidden_states: tuple[torch.FloatTensor] | None = None @@ -815,7 +726,7 @@ class Mamba2CausalLMOutput(ModelOutput): Language modeling loss (for next-token prediction). logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - cache_params (`Mamba2Cache`): + cache_params (`Cache`): The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to avoid providing the old `input_ids`. @@ -824,7 +735,7 @@ class Mamba2CausalLMOutput(ModelOutput): loss: torch.FloatTensor | None = None logits: torch.FloatTensor | None = None - cache_params: Mamba2Cache | None = None + cache_params: Cache | None = None hidden_states: tuple[torch.FloatTensor] | None = None @@ -859,7 +770,7 @@ def forward( self, input_ids: torch.LongTensor | None = None, inputs_embeds: torch.LongTensor | None = None, - cache_params: Mamba2Cache | None = None, + cache_params: Cache | None = None, use_cache: bool | None = None, output_hidden_states: bool | None = None, return_dict: bool | None = None, @@ -867,7 +778,7 @@ def forward( **kwargs, ) -> tuple | Mamba2Output: r""" - cache_params (`Mamba2Cache`, *optional*): + cache_params (`Cache`, *optional*): If passed along, the model uses the previous state in all the blocks (which will give the output for the `input_ids` provided as if the model add `state_input_ids + input_ids` as context). 
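As a rough illustration of the single-token recurrence that the decoding branch above implements (a sketch, not part of the patch; shapes are simplified to one group and arbitrary sizes), the cached state is advanced as h_t = dA * h_{t-1} + dB * x_t and the output is read out with C:

import torch

batch, head_dim, state_size = 2, 8, 16

ssm_state = torch.zeros(batch, head_dim, state_size)  # cached state from the previous step
dA = torch.rand(batch, 1, 1)                          # discretized decay for the new token
dB = torch.rand(batch, 1, state_size)                 # discretized input projection for the new token
x = torch.rand(batch, head_dim, 1)                    # the new token's hidden states
C = torch.rand(batch, state_size, 1)                  # readout vector for the new token

# State update: h_t = dA * h_{t-1} + dB * x_t (broadcast over head_dim and state_size)
ssm_state = ssm_state * dA + dB * x

# Readout: y_t = h_t @ C_t -> [batch, head_dim, 1]
y = torch.bmm(ssm_state, C)
print(y.shape)  # torch.Size([2, 8, 1])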
use_cache (`bool`, *optional*): @@ -889,9 +800,7 @@ def forward( use_cache = False if use_cache and cache_params is None: - cache_params = Mamba2Cache( - self.config, inputs_embeds.size(0), device=inputs_embeds.device, dtype=inputs_embeds.dtype - ) + cache_params = DynamicCache(config=self.config) hidden_states = inputs_embeds all_hidden_states = () if output_hidden_states else None @@ -947,13 +856,11 @@ def prepare_inputs_for_generation( input_ids, inputs_embeds=None, use_cache=None, - cache_params: Mamba2Cache | None = None, + cache_params: Cache | None = None, attention_mask: torch.Tensor | None = None, is_first_iteration: bool | None = False, **kwargs, ): - # Overwritten -- has custom cache class `Mamba2Cache` - model_inputs = super().prepare_inputs_for_generation( input_ids, inputs_embeds=inputs_embeds, @@ -964,15 +871,7 @@ def prepare_inputs_for_generation( **kwargs, ) - if use_cache and cache_params is None: - if inputs_embeds is not None: - max_batch_size = inputs_embeds.size(0) - else: - max_batch_size = input_ids.size(0) - model_inputs["cache_params"] = Mamba2Cache( - self.backbone.config, max_batch_size, device=self.device, dtype=self.dtype - ) - elif use_cache and not is_first_iteration: + if use_cache and not is_first_iteration: model_inputs["attention_mask"] = None return model_inputs @@ -982,7 +881,7 @@ def forward( self, input_ids: torch.LongTensor | None = None, inputs_embeds: torch.FloatTensor | None = None, - cache_params: Mamba2Cache | None = None, + cache_params: Cache | None = None, labels: torch.LongTensor | None = None, output_hidden_states: bool | None = None, return_dict: bool | None = None, @@ -992,7 +891,7 @@ def forward( **kwargs, # for now we need this for generation and loss_function ) -> tuple | Mamba2CausalLMOutput: r""" - cache_params (`Mamba2Cache`, *optional*): + cache_params (`Cache`, *optional*): If passed along, the model uses the previous state in all the blocks (which will give the output for the `input_ids` provided as if the model add `state_input_ids + input_ids` as context). 
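Since the model now builds a plain DynamicCache on the fly, the caller-facing flow is simply "run once, keep cache_params, feed it back". A minimal usage sketch, not taken from the patch, with "model-id" as a placeholder checkpoint name (in practice generate() drives these steps):

from transformers import AutoTokenizer, Mamba2ForCausalLM

tokenizer = AutoTokenizer.from_pretrained("model-id")  # placeholder checkpoint
model = Mamba2ForCausalLM.from_pretrained("model-id")

inputs = tokenizer("Hello", return_tensors="pt")

# Prefill: no cache_params passed, so the model allocates a DynamicCache itself.
out = model(input_ids=inputs.input_ids, use_cache=True)
cache = out.cache_params

# Decode: feed only the newly chosen token together with the returned cache.
next_token = out.logits[:, -1:].argmax(-1)
out = model(input_ids=next_token, cache_params=cache, use_cache=True)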
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): diff --git a/tests/models/mamba2/test_modeling_mamba2.py b/tests/models/mamba2/test_modeling_mamba2.py index aef487d46351..bfbfa178f8d1 100644 --- a/tests/models/mamba2/test_modeling_mamba2.py +++ b/tests/models/mamba2/test_modeling_mamba2.py @@ -34,11 +34,8 @@ if is_torch_available(): import torch - from transformers import ( - Mamba2ForCausalLM, - Mamba2Model, - ) - from transformers.models.mamba2.modeling_mamba2 import Mamba2Cache, Mamba2Mixer + from transformers import DynamicCache, Mamba2ForCausalLM, Mamba2Model + from transformers.models.mamba2.modeling_mamba2 import Mamba2Mixer class Mamba2ConfigTester(ConfigTester): @@ -248,19 +245,19 @@ def setUp(self): ) def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config): - self.assertIsInstance(past_key_values, Mamba2Cache) + self.assertIsInstance(past_key_values, DynamicCache) intermediate_size = config.expand * config.hidden_size conv_shape = ( - config.num_hidden_layers, batch_size, intermediate_size + 2 * config.n_groups * config.state_size, config.conv_kernel, ) - ssm_shape = (config.num_hidden_layers, batch_size, config.num_heads, config.head_dim, config.state_size) + ssm_shape = (batch_size, config.num_heads, config.head_dim, config.state_size) - self.assertEqual(past_key_values.conv_states.shape, conv_shape) - self.assertEqual(past_key_values.ssm_states.shape, ssm_shape) + for idx in range(len(past_key_values)): + self.assertEqual(past_key_values[idx].conv_states.shape, conv_shape) + self.assertEqual(past_key_values[idx].ssm_states.shape, ssm_shape) def test_mamba2_caching(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() @@ -291,9 +288,10 @@ def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}): dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs).to_tuple() def recursive_check(tuple_object, dict_object): - if isinstance(tuple_object, Mamba2Cache): # MODIFIED PART START - recursive_check(tuple_object.conv_states, dict_object.conv_states) - recursive_check(tuple_object.ssm_states, dict_object.ssm_states) + if isinstance(tuple_object, DynamicCache): # MODIFIED PART START + for idx in len(tuple_object): + recursive_check(tuple_object[idx].conv_states, dict_object[idx].conv_states) + recursive_check(tuple_object[idx].ssm_states, dict_object[idx].ssm_states) elif isinstance(tuple_object, (list, tuple)): # MODIFIED PART END for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object): recursive_check(tuple_iterable_value, dict_iterable_value) From 29b91abf0a8d802d076facfef900476fb181eecf Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Mon, 23 Mar 2026 18:08:54 +0100 Subject: [PATCH 04/56] fix --- .../models/falcon_mamba/modeling_falcon_mamba.py | 4 ++-- .../models/falcon_mamba/modular_falcon_mamba.py | 4 ++-- src/transformers/models/mamba/modeling_mamba.py | 4 ++-- src/transformers/models/mamba2/modeling_mamba2.py | 13 ++++++------- 4 files changed, 12 insertions(+), 13 deletions(-) diff --git a/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py b/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py index cb98905ea795..6bb8c9725721 100644 --- a/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py +++ b/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py @@ -231,7 +231,7 @@ def cuda_kernels_forward( if attention_mask is not None: hidden_states = hidden_states * attention_mask.unsqueeze(1) - 
is_decoding = cache_params is not None and cache_params.has_previous_state + is_decoding = cache_params is not None and cache_params[self.layer_idx].has_previous_state # 2. Convolution sequence transformation conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)) @@ -330,7 +330,7 @@ def slow_forward(self, if cache_params is not None: ssm_state = cache_params[self.layer_idx].ssm_states.clone() ssm_state = ssm_state.to(hidden_states.device) - if not cache_params.has_previous_state: + if not cache_params[self.layer_idx].has_previous_state: conv_state = nn.functional.pad(hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0)) cache_params.update_conv_state(conv_state, self.layer_idx) diff --git a/src/transformers/models/falcon_mamba/modular_falcon_mamba.py b/src/transformers/models/falcon_mamba/modular_falcon_mamba.py index b5bda5b31dde..0c2dc6ab5541 100644 --- a/src/transformers/models/falcon_mamba/modular_falcon_mamba.py +++ b/src/transformers/models/falcon_mamba/modular_falcon_mamba.py @@ -192,7 +192,7 @@ def cuda_kernels_forward( if attention_mask is not None: hidden_states = hidden_states * attention_mask.unsqueeze(1) - is_decoding = cache_params is not None and cache_params.has_previous_state + is_decoding = cache_params is not None and cache_params[self.layer_idx].has_previous_state # 2. Convolution sequence transformation conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)) @@ -291,7 +291,7 @@ def slow_forward( if cache_params is not None: ssm_state = cache_params[self.layer_idx].ssm_states.clone() ssm_state = ssm_state.to(hidden_states.device) - if not cache_params.has_previous_state: + if not cache_params[self.layer_idx].has_previous_state: conv_state = nn.functional.pad(hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0)) cache_params.update_conv_state(conv_state, self.layer_idx) diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index b74fb8419d10..a8ad08aba889 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -196,7 +196,7 @@ def cuda_kernels_forward( if attention_mask is not None: hidden_states = hidden_states * attention_mask.unsqueeze(1) - is_decoding = cache_params is not None and cache_params.has_previous_state + is_decoding = cache_params is not None and cache_params[self.layer_idx].has_previous_state # 2. 
Convolution sequence transformation conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)) @@ -281,7 +281,7 @@ def slow_forward(self, input_states, cache_params: Cache | None=None, attention_ if cache_params is not None: ssm_state = cache_params[self.layer_idx].ssm_states.clone() ssm_state = ssm_state.to(hidden_states.device) - if not cache_params.has_previous_state: + if not cache_params[self.layer_idx].has_previous_state: conv_state = nn.functional.pad( hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0) diff --git a/src/transformers/models/mamba2/modeling_mamba2.py b/src/transformers/models/mamba2/modeling_mamba2.py index 0c3a1d14d12e..4684bf7f06f9 100644 --- a/src/transformers/models/mamba2/modeling_mamba2.py +++ b/src/transformers/models/mamba2/modeling_mamba2.py @@ -256,7 +256,7 @@ def cuda_kernels_forward( ) // 2 # Single step calculations via cache - if cache_params is not None and cache_params.has_previous_state: + if cache_params is not None and cache_params[self.layer_idx].has_previous_state: _, _, gate, hidden_states_B_C, dt = projected_states.squeeze(1).split( [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1 ) @@ -412,8 +412,10 @@ def torch_forward( [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1 ) + is_decoding = cache_params is not None and cache_params[self.layer_idx].has_previous_state + # 2. Convolution sequence transformation - if cache_params is not None and cache_params.has_previous_state: + if is_decoding: cache_params.update_conv_state(hidden_states_B_C[:, 0:1, :], layer_idx=self.layer_idx) # We need to guarantee that anything regarding the cache is on the same device @@ -445,7 +447,7 @@ def torch_forward( # 3. SSM transformation A = -torch.exp(self.A_log.float()) # [num_heads] - if cache_params is not None and cache_params.has_previous_state: + if is_decoding: # We need to guarantee that anything regarding the cache is on the same device cache_device = cache_params[self.layer_idx].device @@ -550,10 +552,7 @@ def torch_forward( # 3. 
Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries # (middle term of factorization of off-diag blocks; A terms) - if cache_params is not None and cache_params.has_previous_state: - previous_states = cache_params[self.layer_idx].ssm_states[:, None, ...].to(device=states.device) - else: - previous_states = torch.zeros_like(states[:, :1]) + previous_states = torch.zeros_like(states[:, :1]) states = torch.cat([previous_states, states], dim=1) decay_chunk = torch.exp(segment_sum(nn.functional.pad(A_cumsum[:, :, :, -1], (1, 0)))) decay_chunk = decay_chunk.transpose(1, 3) From 3e02650b77e50ad300ae843890f0a968a979f21f Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Mon, 23 Mar 2026 18:17:42 +0100 Subject: [PATCH 05/56] fix --- src/transformers/cache_utils.py | 19 +++++++++++++++++++ .../falcon_mamba/modeling_falcon_mamba.py | 4 ++-- .../falcon_mamba/modular_falcon_mamba.py | 4 ++-- .../models/mamba/modeling_mamba.py | 4 ++-- .../models/mamba2/modeling_mamba2.py | 4 ++-- 5 files changed, 27 insertions(+), 8 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 6f171ac0c5c9..aa4380a8c330 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -953,6 +953,25 @@ def get_seq_length(self, layer_idx: int = 0) -> int: return self.layers[layer_idx].get_seq_length() + def has_previous_state(self, layer_idx: int = -1) -> bool: + """Returns whether the Mamba layer at index `layer_idx` has previous state or not.""" + if layer_idx >= len(self.layers): + return False + + # In this case, use last Mamba layer + if layer_idx == -1: + try: + layer_idx = next( + idx for idx in range(len(self) - 1, -1, -1) if isinstance(self.layers[idx], CacheLayerMixin) + ) + except StopIteration: + raise ValueError( + "`has_previous_state` can only be called on Mamba layers, and the current Cache seem to only contain " + "Attention layers." + ) + + return self.layers[layer_idx].has_previous_state + def get_mask_sizes(self, query_length: int, layer_idx: int) -> tuple[int, int]: """ Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for diff --git a/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py b/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py index 6bb8c9725721..6dbf3017b6ff 100644 --- a/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py +++ b/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py @@ -231,7 +231,7 @@ def cuda_kernels_forward( if attention_mask is not None: hidden_states = hidden_states * attention_mask.unsqueeze(1) - is_decoding = cache_params is not None and cache_params[self.layer_idx].has_previous_state + is_decoding = cache_params is not None and cache_params.has_previous_state(self.layer_idx) # 2. 
Convolution sequence transformation conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)) @@ -330,7 +330,7 @@ def slow_forward(self, if cache_params is not None: ssm_state = cache_params[self.layer_idx].ssm_states.clone() ssm_state = ssm_state.to(hidden_states.device) - if not cache_params[self.layer_idx].has_previous_state: + if not cache_params.has_previous_state(self.layer_idx): conv_state = nn.functional.pad(hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0)) cache_params.update_conv_state(conv_state, self.layer_idx) diff --git a/src/transformers/models/falcon_mamba/modular_falcon_mamba.py b/src/transformers/models/falcon_mamba/modular_falcon_mamba.py index 0c2dc6ab5541..6dc29117a521 100644 --- a/src/transformers/models/falcon_mamba/modular_falcon_mamba.py +++ b/src/transformers/models/falcon_mamba/modular_falcon_mamba.py @@ -192,7 +192,7 @@ def cuda_kernels_forward( if attention_mask is not None: hidden_states = hidden_states * attention_mask.unsqueeze(1) - is_decoding = cache_params is not None and cache_params[self.layer_idx].has_previous_state + is_decoding = cache_params is not None and cache_params.has_previous_state(self.layer_idx) # 2. Convolution sequence transformation conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)) @@ -291,7 +291,7 @@ def slow_forward( if cache_params is not None: ssm_state = cache_params[self.layer_idx].ssm_states.clone() ssm_state = ssm_state.to(hidden_states.device) - if not cache_params[self.layer_idx].has_previous_state: + if not cache_params.has_previous_state(self.layer_idx): conv_state = nn.functional.pad(hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0)) cache_params.update_conv_state(conv_state, self.layer_idx) diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index a8ad08aba889..9791066b79ea 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -196,7 +196,7 @@ def cuda_kernels_forward( if attention_mask is not None: hidden_states = hidden_states * attention_mask.unsqueeze(1) - is_decoding = cache_params is not None and cache_params[self.layer_idx].has_previous_state + is_decoding = cache_params is not None and cache_params.has_previous_state(self.layer_idx) # 2. 
Convolution sequence transformation conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)) @@ -281,7 +281,7 @@ def slow_forward(self, input_states, cache_params: Cache | None=None, attention_ if cache_params is not None: ssm_state = cache_params[self.layer_idx].ssm_states.clone() ssm_state = ssm_state.to(hidden_states.device) - if not cache_params[self.layer_idx].has_previous_state: + if not cache_params.has_previous_state(self.layer_idx): conv_state = nn.functional.pad( hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0) diff --git a/src/transformers/models/mamba2/modeling_mamba2.py b/src/transformers/models/mamba2/modeling_mamba2.py index 4684bf7f06f9..77a83b5aa89a 100644 --- a/src/transformers/models/mamba2/modeling_mamba2.py +++ b/src/transformers/models/mamba2/modeling_mamba2.py @@ -256,7 +256,7 @@ def cuda_kernels_forward( ) // 2 # Single step calculations via cache - if cache_params is not None and cache_params[self.layer_idx].has_previous_state: + if cache_params is not None and cache_params.has_previous_state(self.layer_idx): _, _, gate, hidden_states_B_C, dt = projected_states.squeeze(1).split( [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1 ) @@ -412,7 +412,7 @@ def torch_forward( [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1 ) - is_decoding = cache_params is not None and cache_params[self.layer_idx].has_previous_state + is_decoding = cache_params is not None and cache_params.has_previous_state(self.layer_idx) # 2. Convolution sequence transformation if is_decoding: From fb883452ff6453955e23490276e4ddb9c6a649f7 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Mon, 23 Mar 2026 18:25:26 +0100 Subject: [PATCH 06/56] fix --- src/transformers/cache_utils.py | 4 ++-- .../falcon_mamba/modeling_falcon_mamba.py | 6 +++--- .../falcon_mamba/modular_falcon_mamba.py | 6 +++--- .../models/mamba/modeling_mamba.py | 6 +++--- .../models/mamba2/modeling_mamba2.py | 18 ++++++------------ 5 files changed, 17 insertions(+), 23 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index aa4380a8c330..5079f81117a0 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -953,13 +953,13 @@ def get_seq_length(self, layer_idx: int = 0) -> int: return self.layers[layer_idx].get_seq_length() - def has_previous_state(self, layer_idx: int = -1) -> bool: + def has_previous_state(self, layer_idx: int | None = None) -> bool: """Returns whether the Mamba layer at index `layer_idx` has previous state or not.""" if layer_idx >= len(self.layers): return False # In this case, use last Mamba layer - if layer_idx == -1: + if layer_idx is None: try: layer_idx = next( idx for idx in range(len(self) - 1, -1, -1) if isinstance(self.layers[idx], CacheLayerMixin) diff --git a/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py b/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py index 6dbf3017b6ff..1886c1393960 100644 --- a/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py +++ b/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py @@ -238,7 +238,7 @@ def cuda_kernels_forward( if is_decoding: hidden_states = causal_conv1d_update( hidden_states.squeeze(-1), - cache_params[self.layer_idx].conv_states, + cache_params.layers[self.layer_idx].conv_states, conv_weights, self.conv1d.bias, self.activation, @@ -280,7 +280,7 @@ def cuda_kernels_forward( time_proj_bias = self.dt_proj.bias.float() if 
hasattr(self.dt_proj, "bias") else None if is_decoding: scan_outputs = selective_state_update( - cache_params[self.layer_idx].ssm_states, + cache_params.layers[self.layer_idx].ssm_states, hidden_states[..., 0], discrete_time_step[..., 0], A, @@ -328,7 +328,7 @@ def slow_forward(self, # 2. Convolution sequence transformation if cache_params is not None: - ssm_state = cache_params[self.layer_idx].ssm_states.clone() + ssm_state = cache_params.layers[self.layer_idx].ssm_states.clone() ssm_state = ssm_state.to(hidden_states.device) if not cache_params.has_previous_state(self.layer_idx): conv_state = nn.functional.pad(hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0)) diff --git a/src/transformers/models/falcon_mamba/modular_falcon_mamba.py b/src/transformers/models/falcon_mamba/modular_falcon_mamba.py index 6dc29117a521..e937112f2988 100644 --- a/src/transformers/models/falcon_mamba/modular_falcon_mamba.py +++ b/src/transformers/models/falcon_mamba/modular_falcon_mamba.py @@ -199,7 +199,7 @@ def cuda_kernels_forward( if is_decoding: hidden_states = causal_conv1d_update( hidden_states.squeeze(-1), - cache_params[self.layer_idx].conv_states, + cache_params.layers[self.layer_idx].conv_states, conv_weights, self.conv1d.bias, self.activation, @@ -241,7 +241,7 @@ def cuda_kernels_forward( time_proj_bias = self.dt_proj.bias.float() if hasattr(self.dt_proj, "bias") else None if is_decoding: scan_outputs = selective_state_update( - cache_params[self.layer_idx].ssm_states, + cache_params.layers[self.layer_idx].ssm_states, hidden_states[..., 0], discrete_time_step[..., 0], A, @@ -289,7 +289,7 @@ def slow_forward( # 2. Convolution sequence transformation if cache_params is not None: - ssm_state = cache_params[self.layer_idx].ssm_states.clone() + ssm_state = cache_params.layers[self.layer_idx].ssm_states.clone() ssm_state = ssm_state.to(hidden_states.device) if not cache_params.has_previous_state(self.layer_idx): conv_state = nn.functional.pad(hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0)) diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index 9791066b79ea..0f5a45ca3d57 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -203,7 +203,7 @@ def cuda_kernels_forward( if is_decoding: hidden_states = causal_conv1d_update( hidden_states.squeeze(-1), - cache_params[self.layer_idx].conv_states, + cache_params.layers[self.layer_idx].conv_states, conv_weights, self.conv1d.bias, self.activation, @@ -235,7 +235,7 @@ def cuda_kernels_forward( time_proj_bias = self.dt_proj.bias.float() if hasattr(self.dt_proj, "bias") else None if is_decoding: scan_outputs = selective_state_update( - cache_params[self.layer_idx].ssm_states, + cache_params.layers[self.layer_idx].ssm_states, hidden_states[..., 0], discrete_time_step[..., 0], A, @@ -279,7 +279,7 @@ def slow_forward(self, input_states, cache_params: Cache | None=None, attention_ # 2. 
Convolution sequence transformation if cache_params is not None: - ssm_state = cache_params[self.layer_idx].ssm_states.clone() + ssm_state = cache_params.layers[self.layer_idx].ssm_states.clone() ssm_state = ssm_state.to(hidden_states.device) if not cache_params.has_previous_state(self.layer_idx): conv_state = nn.functional.pad( diff --git a/src/transformers/models/mamba2/modeling_mamba2.py b/src/transformers/models/mamba2/modeling_mamba2.py index 77a83b5aa89a..41d97923c120 100644 --- a/src/transformers/models/mamba2/modeling_mamba2.py +++ b/src/transformers/models/mamba2/modeling_mamba2.py @@ -264,7 +264,7 @@ def cuda_kernels_forward( # 2. Convolution sequence transformation hidden_states_B_C = causal_conv1d_update( hidden_states_B_C, - cache_params[self.layer_idx].conv_states, + cache_params.layers[self.layer_idx].conv_states, self.conv1d.weight.squeeze(1), self.conv1d.bias, self.activation, @@ -286,7 +286,7 @@ def cuda_kernels_forward( C = C.view(batch_size, self.n_groups, C.shape[1] // self.n_groups) hidden_states_reshaped = hidden_states.view(batch_size, self.num_heads, self.head_dim) hidden_states = selective_state_update( - cache_params[self.layer_idx].ssm_states, + cache_params.layers[self.layer_idx].ssm_states, hidden_states_reshaped, dt, A, @@ -416,10 +416,7 @@ def torch_forward( # 2. Convolution sequence transformation if is_decoding: - cache_params.update_conv_state(hidden_states_B_C[:, 0:1, :], layer_idx=self.layer_idx) - - # We need to guarantee that anything regarding the cache is on the same device - conv_states = cache_params[self.layer_idx].conv_states.to(device=self.conv1d.weight.device) + conv_states = cache_params.update_conv_state(hidden_states_B_C[:, 0:1, :], layer_idx=self.layer_idx) hidden_states_B_C = torch.sum( conv_states * self.conv1d.weight.squeeze(1), dim=-1 @@ -449,7 +446,7 @@ def torch_forward( A = -torch.exp(self.A_log.float()) # [num_heads] if is_decoding: # We need to guarantee that anything regarding the cache is on the same device - cache_device = cache_params[self.layer_idx].device + cache_device = cache_params.layers[self.layer_idx].device # Note: there is no need to pad parameter matrices here, as there is just one new token # for batched generation @@ -479,10 +476,8 @@ def torch_forward( dBx = (dB * hidden_states[..., None]).to(device=cache_device) # State calculation - cache_params.update_ssm_state( - cache_params[self.layer_idx].ssm_states * dA + dBx, - layer_idx=self.layer_idx, - ) + ssm_states = cache_params.layers[self.layer_idx].ssm_states * dA + dBx + ssm_states = cache_params.update_ssm_state(ssm_states, layer_idx=self.layer_idx) # Subsequent output # [bsz, n_groups * state_size] -> [bsz, num_heads, state_size] @@ -491,7 +486,6 @@ def torch_forward( C = C.reshape(batch_size, -1, C.shape[-1]) # [bsz, num_heads, head_dim] - ssm_states = cache_params[self.layer_idx].ssm_states.to(device=C.device, dtype=C.dtype) # Shape: [b, h, d, n] # Reshape ssm_states to merge the first two dimensions ssm_states_reshaped = ssm_states.view(batch_size * self.num_heads, self.head_dim, self.ssm_state_size) # Shape: [b*h, d, n] C_reshaped = C.view(batch_size * self.num_heads, self.ssm_state_size, 1) # Shape: [b*h, n, 1] From 1aeddfaefab66ade4a28dcf0f2554262e5ba3ab7 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Mon, 23 Mar 2026 18:44:19 +0100 Subject: [PATCH 07/56] final fix --- .../models/falcon_mamba/modeling_falcon_mamba.py | 12 +++++++----- .../models/falcon_mamba/modular_falcon_mamba.py | 12 +++++++----- src/transformers/models/mamba/modeling_mamba.py | 
16 +++++++++------- .../falcon_mamba/test_modeling_falcon_mamba.py | 10 +++++----- tests/models/mamba/test_modeling_mamba.py | 10 +++++----- tests/models/mamba2/test_modeling_mamba2.py | 10 +++++----- 6 files changed, 38 insertions(+), 32 deletions(-) diff --git a/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py b/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py index 1886c1393960..b2fc8afdae75 100644 --- a/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py +++ b/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py @@ -326,10 +326,15 @@ def slow_forward(self, if attention_mask is not None: hidden_states = hidden_states * attention_mask.unsqueeze(1) + if cache_params is not None and cache_params.has_previous_state(self.layer_idx): + ssm_state = cache_params.layers[self.layer_idx].ssm_states.clone() + else: + ssm_state = torch.zeros( + (batch_size, self.intermediate_size, self.ssm_state_size), device=hidden_states.device, dtype=dtype + ) + # 2. Convolution sequence transformation if cache_params is not None: - ssm_state = cache_params.layers[self.layer_idx].ssm_states.clone() - ssm_state = ssm_state.to(hidden_states.device) if not cache_params.has_previous_state(self.layer_idx): conv_state = nn.functional.pad(hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0)) @@ -347,9 +352,6 @@ def slow_forward(self, self.act(hidden_states).to(dtype).unsqueeze(-1) ) # [batch, intermediate_size, 1] : decoding else: - ssm_state = torch.zeros( - (batch_size, self.intermediate_size, self.ssm_state_size), device=hidden_states.device, dtype=dtype - ) hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # [batch, intermediate_size, seq_len] if attention_mask is not None: diff --git a/src/transformers/models/falcon_mamba/modular_falcon_mamba.py b/src/transformers/models/falcon_mamba/modular_falcon_mamba.py index e937112f2988..b53b85465cc2 100644 --- a/src/transformers/models/falcon_mamba/modular_falcon_mamba.py +++ b/src/transformers/models/falcon_mamba/modular_falcon_mamba.py @@ -287,10 +287,15 @@ def slow_forward( if attention_mask is not None: hidden_states = hidden_states * attention_mask.unsqueeze(1) + if cache_params is not None and cache_params.has_previous_state(self.layer_idx): + ssm_state = cache_params.layers[self.layer_idx].ssm_states.clone() + else: + ssm_state = torch.zeros( + (batch_size, self.intermediate_size, self.ssm_state_size), device=hidden_states.device, dtype=dtype + ) + # 2. 
Convolution sequence transformation if cache_params is not None: - ssm_state = cache_params.layers[self.layer_idx].ssm_states.clone() - ssm_state = ssm_state.to(hidden_states.device) if not cache_params.has_previous_state(self.layer_idx): conv_state = nn.functional.pad(hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0)) @@ -308,9 +313,6 @@ def slow_forward( self.act(hidden_states).to(dtype).unsqueeze(-1) ) # [batch, intermediate_size, 1] : decoding else: - ssm_state = torch.zeros( - (batch_size, self.intermediate_size, self.ssm_state_size), device=hidden_states.device, dtype=dtype - ) hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # [batch, intermediate_size, seq_len] if attention_mask is not None: diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index 0f5a45ca3d57..1591583112dc 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -277,10 +277,16 @@ def slow_forward(self, input_states, cache_params: Cache | None=None, attention_ if attention_mask is not None: hidden_states = hidden_states * attention_mask.unsqueeze(1) + if cache_params is not None and cache_params.has_previous_state(self.layer_idx): + ssm_state = cache_params.layers[self.layer_idx].ssm_states.clone() + else: + ssm_state = torch.zeros( + (batch_size, self.intermediate_size, self.ssm_state_size), + device=hidden_states.device, dtype=dtype + ) + # 2. Convolution sequence transformation if cache_params is not None: - ssm_state = cache_params.layers[self.layer_idx].ssm_states.clone() - ssm_state = ssm_state.to(hidden_states.device) if not cache_params.has_previous_state(self.layer_idx): conv_state = nn.functional.pad( hidden_states, @@ -297,10 +303,6 @@ def slow_forward(self, input_states, cache_params: Cache | None=None, attention_ hidden_states += self.conv1d.bias hidden_states = self.act(hidden_states).to(dtype).unsqueeze(-1) # [batch, intermediate_size, 1] : decoding else: - ssm_state = torch.zeros( - (batch_size, self.intermediate_size, self.ssm_state_size), - device=hidden_states.device, dtype=dtype - ) hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # [batch, intermediate_size, seq_len] if attention_mask is not None: @@ -354,7 +356,7 @@ def combine_fn(left, right): scan_output = (scan_output * self.act(gate)) if cache_params is not None: - cache_params.update_ssm_states(ssm_state, self.layer_idx) + cache_params.update_ssm_state(ssm_state, self.layer_idx) # 4. 
Final linear projection contextualized_states = self.out_proj(scan_output.transpose(1, 2)) # [batch, seq_len, hidden_size] diff --git a/tests/models/falcon_mamba/test_modeling_falcon_mamba.py b/tests/models/falcon_mamba/test_modeling_falcon_mamba.py index 42cca6f8d447..8aea217bfbc1 100644 --- a/tests/models/falcon_mamba/test_modeling_falcon_mamba.py +++ b/tests/models/falcon_mamba/test_modeling_falcon_mamba.py @@ -283,8 +283,8 @@ def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_l self.assertTrue(config.num_hidden_layers, len(past_key_values)) for idx in range(len(past_key_values)): - self.assertEqual(past_key_values[idx].conv_states.shape, conv_shape) - self.assertEqual(past_key_values[idx].ssm_states.shape, ssm_shape) + self.assertEqual(past_key_values.layers[idx].conv_states.shape, conv_shape) + self.assertEqual(past_key_values.layers[idx].ssm_states.shape, ssm_shape) def assertInterval(self, member, container, msg=None): r""" @@ -347,9 +347,9 @@ def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}): def recursive_check(tuple_object, dict_object): if isinstance(tuple_object, DynamicCache): # MODIFIED PART START - for idx in len(tuple_object): - recursive_check(tuple_object[idx].conv_states, dict_object[idx].conv_states) - recursive_check(tuple_object[idx].ssm_states, dict_object[idx].ssm_states) + for idx in range(len(tuple_object)): + recursive_check(tuple_object.layers[idx].conv_states, dict_object.layers[idx].conv_states) + recursive_check(tuple_object.layers[idx].ssm_states, dict_object.layers[idx].ssm_states) elif isinstance(tuple_object, (list, tuple)): # MODIFIED PART END for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object): recursive_check(tuple_iterable_value, dict_iterable_value) diff --git a/tests/models/mamba/test_modeling_mamba.py b/tests/models/mamba/test_modeling_mamba.py index 88367447fddb..f68a17491590 100644 --- a/tests/models/mamba/test_modeling_mamba.py +++ b/tests/models/mamba/test_modeling_mamba.py @@ -251,8 +251,8 @@ def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_l self.assertTrue(config.num_hidden_layers, len(past_key_values)) for idx in range(len(past_key_values)): - self.assertEqual(past_key_values[idx].conv_states.shape, conv_shape) - self.assertEqual(past_key_values[idx].ssm_states.shape, ssm_shape) + self.assertEqual(past_key_values.layers[idx].conv_states.shape, conv_shape) + self.assertEqual(past_key_values.layers[idx].ssm_states.shape, ssm_shape) def assertInterval(self, member, container, msg=None): r""" @@ -314,9 +314,9 @@ def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}): def recursive_check(tuple_object, dict_object): if isinstance(tuple_object, DynamicCache): # MODIFIED PART START - for idx in len(tuple_object): - recursive_check(tuple_object[idx].conv_states, dict_object[idx].conv_states) - recursive_check(tuple_object[idx].ssm_states, dict_object[idx].ssm_states) + for idx in range(len(tuple_object)): + recursive_check(tuple_object.layers[idx].conv_states, dict_object.layers[idx].conv_states) + recursive_check(tuple_object.layers[idx].ssm_states, dict_object.layers[idx].ssm_states) elif isinstance(tuple_object, (list, tuple)): # MODIFIED PART END for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object): recursive_check(tuple_iterable_value, dict_iterable_value) diff --git a/tests/models/mamba2/test_modeling_mamba2.py b/tests/models/mamba2/test_modeling_mamba2.py index 
bfbfa178f8d1..6aa16ddeecd2 100644 --- a/tests/models/mamba2/test_modeling_mamba2.py +++ b/tests/models/mamba2/test_modeling_mamba2.py @@ -256,8 +256,8 @@ def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_l ssm_shape = (batch_size, config.num_heads, config.head_dim, config.state_size) for idx in range(len(past_key_values)): - self.assertEqual(past_key_values[idx].conv_states.shape, conv_shape) - self.assertEqual(past_key_values[idx].ssm_states.shape, ssm_shape) + self.assertEqual(past_key_values.layers[idx].conv_states.shape, conv_shape) + self.assertEqual(past_key_values.layers[idx].ssm_states.shape, ssm_shape) def test_mamba2_caching(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() @@ -289,9 +289,9 @@ def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}): def recursive_check(tuple_object, dict_object): if isinstance(tuple_object, DynamicCache): # MODIFIED PART START - for idx in len(tuple_object): - recursive_check(tuple_object[idx].conv_states, dict_object[idx].conv_states) - recursive_check(tuple_object[idx].ssm_states, dict_object[idx].ssm_states) + for idx in range(len(tuple_object)): + recursive_check(tuple_object.layers[idx].conv_states, dict_object.layers[idx].conv_states) + recursive_check(tuple_object.layers[idx].ssm_states, dict_object.layers[idx].ssm_states) elif isinstance(tuple_object, (list, tuple)): # MODIFIED PART END for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object): recursive_check(tuple_iterable_value, dict_iterable_value) From 35db152bde81c10efa978430fa2b971e5720d62d Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Mon, 23 Mar 2026 18:59:51 +0100 Subject: [PATCH 08/56] test hybrid with jamba --- .../models/jamba/modeling_jamba.py | 174 ++++-------------- .../models/jamba/modular_jamba.py | 174 ++++-------------- 2 files changed, 62 insertions(+), 286 deletions(-) diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py index 3d7986933395..ef0b15b6856b 100755 --- a/src/transformers/models/jamba/modeling_jamba.py +++ b/src/transformers/models/jamba/modeling_jamba.py @@ -23,13 +23,13 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Callable -from typing import Any import torch from torch import nn from ... import initialization as init from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...integrations import ( lazy_load_kernel, @@ -74,100 +74,6 @@ def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" -class HybridMambaAttentionDynamicCache: - """ - A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache - (which has a constant shape regardless of seq_len). - - This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states` - and `ssm_states` for mamba cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor - For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`, - while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors). 
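Under the layered cache that replaces this class, the same per-layer split is visible directly on the cache object. A small inspection sketch, not part of the patch, assuming `cache` is the DynamicCache returned by a forward pass with use_cache=True on a hybrid model:

for idx, layer in enumerate(cache.layers):
    conv = getattr(layer, "conv_states", None)
    ssm = getattr(layer, "ssm_states", None)
    if conv is not None and ssm is not None:
        # Mamba-style layer: fixed-size conv/ssm states, independent of sequence length.
        print(idx, tuple(conv.shape), tuple(ssm.shape))
    else:
        # Attention-style layer: key/value tensors that grow with the sequence instead.
        print(idx, "attention layer")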
- For mamba layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors), - while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`, - and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`. - """ - - is_compileable = False - - def __init__(self, config, batch_size, dtype=torch.float16, device=None): - self.dtype = dtype - self.layers_block_type = config.layers_block_type - self.has_previous_state = False # only used by mamba - intermediate_size = config.mamba_expand * config.hidden_size - ssm_state_size = config.mamba_d_state - conv_kernel_size = config.mamba_d_conv - self.conv_states = [] - self.ssm_states = [] - self.transformer_layers = [] - for i in range(config.num_hidden_layers): - if self.layers_block_type[i] == "mamba": - self.conv_states += [ - torch.zeros(batch_size, intermediate_size, conv_kernel_size, device=device, dtype=dtype) - ] - self.ssm_states += [ - torch.zeros(batch_size, intermediate_size, ssm_state_size, device=device, dtype=dtype) - ] - else: - self.conv_states += [torch.tensor([[]] * batch_size, device=device)] - self.ssm_states += [torch.tensor([[]] * batch_size, device=device)] - self.transformer_layers.append(i) - - self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] - self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] - - def __len__(self): - return len(self.key_cache) - - def __getitem__(self, layer_idx): - return self.key_cache[layer_idx], self.value_cache[layer_idx] - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: dict[str, Any] | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - # Update the cache - if self.key_cache[layer_idx].shape[-1] == 0: - self.key_cache[layer_idx] = key_states - self.value_cache[layer_idx] = value_states - else: - self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2) - self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2) - - return self.key_cache[layer_idx], self.value_cache[layer_idx] - - def reorder_cache(self, beam_idx: torch.LongTensor): - """Reorders the cache for beam search, given the selected beam indices.""" - if self.get_seq_length() > 0: - for layer_idx in range(len(self.key_cache)): - device = self.key_cache[layer_idx].device - self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) - device = self.value_cache[layer_idx].device - self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) - - device = self.conv_states[layer_idx].device - self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx.to(device)) - device = self.ssm_states[layer_idx].device - self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device)) - - def get_mask_sizes(self, query_length: int, layer_idx: int) -> tuple[int, int]: - """Return the length and offset of the cache, used to generate the mask""" - kv_offset = 0 - kv_length = self.get_seq_length(layer_idx) + query_length - return kv_length, kv_offset - - def get_seq_length(self, layer_idx: int | None = 0) -> int: - """Returns the sequence length of the cached states. 
A layer index can be optionally passed.""" - # take any layer that contains cache and not empty tensor - layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx - if len(self.key_cache) <= layer_idx or self.key_cache[layer_idx].shape[-1] == 0: - return 0 - return self.key_cache[layer_idx].shape[-2] - - def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] @@ -260,7 +166,7 @@ def forward( self, hidden_states: torch.Tensor, attention_mask: torch.Tensor | None = None, - past_key_values: HybridMambaAttentionDynamicCache | None = None, + past_key_values: Cache | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.Tensor, torch.Tensor | None]: input_shape = hidden_states.shape[:-1] @@ -371,17 +277,12 @@ def __init__(self, config: JambaConfig, layer_idx): def cuda_kernels_forward( self, hidden_states: torch.Tensor, - cache_params: HybridMambaAttentionDynamicCache | None = None, + cache_params: Cache | None = None, attention_mask: torch.LongTensor | None = None, ): batch_size, seq_len, _ = hidden_states.shape use_precomputed_states = ( - cache_params is not None - and cache_params.has_previous_state - and seq_len == 1 - and cache_params.conv_states[self.layer_idx].shape[0] - == cache_params.ssm_states[self.layer_idx].shape[0] - == batch_size + cache_params is not None and cache_params.has_previous_state(self.layer_idx) and seq_len == 1 ) # 1. Gated MLP's linear projection projected_states = self.in_proj(hidden_states).transpose(1, 2) @@ -398,7 +299,7 @@ def cuda_kernels_forward( if use_precomputed_states: hidden_states = causal_conv1d_update( hidden_states.squeeze(-1), - cache_params.conv_states[self.layer_idx], + cache_params.layers[self.layer_idx].conv_states, conv_weights, self.conv1d.bias, self.activation, @@ -407,7 +308,7 @@ def cuda_kernels_forward( else: if cache_params is not None: conv_states = nn.functional.pad(hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0)) - cache_params.conv_states[self.layer_idx].copy_(conv_states) + cache_params.update_conv_state(conv_states, self.layer_idx) hidden_states = causal_conv1d_fn(hidden_states, conv_weights, self.conv1d.bias, activation=self.activation) if attention_mask is not None: @@ -442,7 +343,7 @@ def cuda_kernels_forward( time_proj_bias = time_proj_bias.float() if time_proj_bias is not None else None if use_precomputed_states: scan_outputs = selective_state_update( - cache_params.ssm_states[self.layer_idx], + cache_params.layers[self.layer_idx].ssm_states, hidden_states[..., 0], discrete_time_step[..., 0], A, @@ -467,7 +368,7 @@ def cuda_kernels_forward( return_last_state=True, ) if ssm_state is not None and cache_params is not None: - cache_params.ssm_states[self.layer_idx].copy_(ssm_state) + cache_params.update_ssm_state(ssm_state, self.layer_idx) # 4. Final linear projection contextualized_states = self.out_proj(scan_outputs.transpose(1, 2)) @@ -475,7 +376,7 @@ def cuda_kernels_forward( return contextualized_states # fmt: off - def slow_forward(self, input_states, cache_params: HybridMambaAttentionDynamicCache | None = None, attention_mask: torch.LongTensor | None = None): + def slow_forward(self, input_states, cache_params: Cache | None = None, attention_mask: torch.LongTensor | None = None): batch_size, seq_len, _ = input_states.shape dtype = input_states.dtype # 1. 
Gated MLP's linear projection @@ -485,23 +386,19 @@ def slow_forward(self, input_states, cache_params: HybridMambaAttentionDynamicCa if attention_mask is not None: hidden_states = hidden_states * attention_mask.unsqueeze(1) - use_cache = isinstance(cache_params, HybridMambaAttentionDynamicCache) - # 2. Convolution sequence transformation - if use_cache and cache_params.ssm_states[self.layer_idx].shape[0] == batch_size: - if self.training: - # In training mode, we don't want to perform in-place operations on ssm_state so we can compute the backwards pass - ssm_state = cache_params.ssm_states[self.layer_idx].clone() - else: - ssm_state = cache_params.ssm_states[self.layer_idx] - - ssm_state = ssm_state.to(hidden_states.device) + if cache_params is not None and cache_params.has_previous_state(self.layer_idx): + # In training mode, we don't want to perform in-place operations on ssm_state so we can compute the backwards pass + ssm_state = cache_params.ssm_states[self.layer_idx].clone() + else: + ssm_state = torch.zeros( + (batch_size, self.intermediate_size, self.ssm_state_size), + device=hidden_states.device, dtype=dtype + ) - if cache_params.has_previous_state and seq_len == 1 and \ - cache_params.conv_states[self.layer_idx].shape[0] == batch_size: - conv_state = cache_params.conv_states[self.layer_idx] - conv_state = torch.roll(conv_state, shifts=-1, dims=-1) - conv_state[:, :, -1] = hidden_states[:, :, 0] - cache_params.conv_states[self.layer_idx] = conv_state + # 2. Convolution sequence transformation + if cache_params is not None: + if cache_params.has_previous_state and seq_len == 1: + conv_state = cache_params.update_conv_state(hidden_states, self.layer_idx) hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1) if self.use_conv_bias: hidden_states += self.conv1d.bias @@ -511,13 +408,9 @@ def slow_forward(self, input_states, cache_params: HybridMambaAttentionDynamicCa hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0) ) - cache_params.conv_states[self.layer_idx] = conv_state + conv_state = cache_params.update_conv_state(hidden_states, self.layer_idx) hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) else: - ssm_state = torch.zeros( - (batch_size, self.intermediate_size, self.ssm_state_size), - device=hidden_states.device, dtype=dtype - ) hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) if attention_mask is not None: @@ -552,8 +445,8 @@ def slow_forward(self, input_states, cache_params: HybridMambaAttentionDynamicCa scan_output = scan_output + (hidden_states * self.D[None, :, None]) scan_output = (scan_output * self.act(gate)) - if use_cache: - cache_params.ssm_states[self.layer_idx] = ssm_state + if cache_params is not None: + cache_params.update_ssm_state(ssm_state, self.layer_idx) # 4. 
Final linear projection contextualized_states = self.out_proj(scan_output.transpose(1, 2)) @@ -563,7 +456,7 @@ def slow_forward(self, input_states, cache_params: HybridMambaAttentionDynamicCa def forward( self, hidden_states, - cache_params: HybridMambaAttentionDynamicCache | None = None, + cache_params: Cache | None = None, attention_mask: torch.LongTensor | None = None, ): if self.config.use_mamba_kernels and ( @@ -690,7 +583,7 @@ def forward( hidden_states: torch.Tensor, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, - past_key_values: HybridMambaAttentionDynamicCache | None = None, + past_key_values: Cache | None = None, use_cache: bool | None = False, **kwargs: Unpack[TransformersKwargs], ) -> torch.FloatTensor: @@ -727,7 +620,7 @@ def forward( hidden_states: torch.Tensor, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, - past_key_values: HybridMambaAttentionDynamicCache | None = None, + past_key_values: Cache | None = None, **kwargs: Unpack[TransformersKwargs], ) -> torch.FloatTensor: residual = hidden_states @@ -804,7 +697,7 @@ def forward( input_ids: torch.LongTensor | None = None, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, - past_key_values: HybridMambaAttentionDynamicCache | None = None, + past_key_values: Cache | None = None, inputs_embeds: torch.FloatTensor | None = None, use_cache: bool | None = None, **kwargs: Unpack[TransformersKwargs], @@ -816,12 +709,7 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) if use_cache and past_key_values is None: - past_key_values = HybridMambaAttentionDynamicCache( - config=self.config, - batch_size=inputs_embeds.shape[0], - dtype=inputs_embeds.dtype, - device=inputs_embeds.device, - ) + past_key_values = DynamicCache(config=self.config) if position_ids is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 @@ -866,7 +754,7 @@ def _update_mamba_mask(self, attention_mask, past_key_values): 2. Attending to all inputs """ mamba_mask = attention_mask - if (past_key_values is not None and past_key_values.has_previous_state) or ( + if (past_key_values is not None and past_key_values.has_previous_state()) or ( attention_mask is not None and torch.all(attention_mask == 1) ): mamba_mask = None @@ -980,7 +868,7 @@ def forward( input_ids: torch.LongTensor | None = None, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, - past_key_values: HybridMambaAttentionDynamicCache | None = None, + past_key_values: Cache | None = None, inputs_embeds: torch.FloatTensor | None = None, labels: torch.LongTensor | None = None, use_cache: bool | None = None, diff --git a/src/transformers/models/jamba/modular_jamba.py b/src/transformers/models/jamba/modular_jamba.py index da80fbd7187e..45d529d4ff01 100644 --- a/src/transformers/models/jamba/modular_jamba.py +++ b/src/transformers/models/jamba/modular_jamba.py @@ -17,13 +17,13 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Callable -from typing import Any import torch from torch import nn from ... 
import initialization as init from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache from ...integrations import lazy_load_kernel from ...masking_utils import create_causal_mask from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer @@ -47,100 +47,6 @@ class JambaRMSNorm(LlamaRMSNorm): pass -class HybridMambaAttentionDynamicCache: - """ - A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache - (which has a constant shape regardless of seq_len). - - This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states` - and `ssm_states` for mamba cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor - For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`, - while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors). - For mamba layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors), - while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`, - and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`. - """ - - is_compileable = False - - def __init__(self, config, batch_size, dtype=torch.float16, device=None): - self.dtype = dtype - self.layers_block_type = config.layers_block_type - self.has_previous_state = False # only used by mamba - intermediate_size = config.mamba_expand * config.hidden_size - ssm_state_size = config.mamba_d_state - conv_kernel_size = config.mamba_d_conv - self.conv_states = [] - self.ssm_states = [] - self.transformer_layers = [] - for i in range(config.num_hidden_layers): - if self.layers_block_type[i] == "mamba": - self.conv_states += [ - torch.zeros(batch_size, intermediate_size, conv_kernel_size, device=device, dtype=dtype) - ] - self.ssm_states += [ - torch.zeros(batch_size, intermediate_size, ssm_state_size, device=device, dtype=dtype) - ] - else: - self.conv_states += [torch.tensor([[]] * batch_size, device=device)] - self.ssm_states += [torch.tensor([[]] * batch_size, device=device)] - self.transformer_layers.append(i) - - self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] - self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] - - def __len__(self): - return len(self.key_cache) - - def __getitem__(self, layer_idx): - return self.key_cache[layer_idx], self.value_cache[layer_idx] - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: dict[str, Any] | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - # Update the cache - if self.key_cache[layer_idx].shape[-1] == 0: - self.key_cache[layer_idx] = key_states - self.value_cache[layer_idx] = value_states - else: - self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2) - self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2) - - return self.key_cache[layer_idx], self.value_cache[layer_idx] - - def reorder_cache(self, beam_idx: torch.LongTensor): - """Reorders the cache for beam search, given the selected beam indices.""" - if self.get_seq_length() > 0: - for layer_idx in range(len(self.key_cache)): - device = self.key_cache[layer_idx].device - self.key_cache[layer_idx] = 
self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) - device = self.value_cache[layer_idx].device - self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) - - device = self.conv_states[layer_idx].device - self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx.to(device)) - device = self.ssm_states[layer_idx].device - self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device)) - - def get_mask_sizes(self, query_length: int, layer_idx: int) -> tuple[int, int]: - """Return the length and offset of the cache, used to generate the mask""" - kv_offset = 0 - kv_length = self.get_seq_length(layer_idx) + query_length - return kv_length, kv_offset - - def get_seq_length(self, layer_idx: int | None = 0) -> int: - """Returns the sequence length of the cached states. A layer index can be optionally passed.""" - # take any layer that contains cache and not empty tensor - layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx - if len(self.key_cache) <= layer_idx or self.key_cache[layer_idx].shape[-1] == 0: - return 0 - return self.key_cache[layer_idx].shape[-2] - - class JambaAttention(LlamaAttention): def __init__(self, config: JambaConfig, layer_idx: int): super().__init__(config, layer_idx) @@ -153,7 +59,7 @@ def forward( self, hidden_states: torch.Tensor, attention_mask: torch.Tensor | None = None, - past_key_values: HybridMambaAttentionDynamicCache | None = None, + past_key_values: Cache | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.Tensor, torch.Tensor | None]: input_shape = hidden_states.shape[:-1] @@ -264,17 +170,12 @@ def __init__(self, config: JambaConfig, layer_idx): def cuda_kernels_forward( self, hidden_states: torch.Tensor, - cache_params: HybridMambaAttentionDynamicCache | None = None, + cache_params: Cache | None = None, attention_mask: torch.LongTensor | None = None, ): batch_size, seq_len, _ = hidden_states.shape use_precomputed_states = ( - cache_params is not None - and cache_params.has_previous_state - and seq_len == 1 - and cache_params.conv_states[self.layer_idx].shape[0] - == cache_params.ssm_states[self.layer_idx].shape[0] - == batch_size + cache_params is not None and cache_params.has_previous_state(self.layer_idx) and seq_len == 1 ) # 1. 
Gated MLP's linear projection projected_states = self.in_proj(hidden_states).transpose(1, 2) @@ -291,7 +192,7 @@ def cuda_kernels_forward( if use_precomputed_states: hidden_states = causal_conv1d_update( hidden_states.squeeze(-1), - cache_params.conv_states[self.layer_idx], + cache_params.layers[self.layer_idx].conv_states, conv_weights, self.conv1d.bias, self.activation, @@ -300,7 +201,7 @@ def cuda_kernels_forward( else: if cache_params is not None: conv_states = nn.functional.pad(hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0)) - cache_params.conv_states[self.layer_idx].copy_(conv_states) + cache_params.update_conv_state(conv_states, self.layer_idx) hidden_states = causal_conv1d_fn(hidden_states, conv_weights, self.conv1d.bias, activation=self.activation) if attention_mask is not None: @@ -335,7 +236,7 @@ def cuda_kernels_forward( time_proj_bias = time_proj_bias.float() if time_proj_bias is not None else None if use_precomputed_states: scan_outputs = selective_state_update( - cache_params.ssm_states[self.layer_idx], + cache_params.layers[self.layer_idx].ssm_states, hidden_states[..., 0], discrete_time_step[..., 0], A, @@ -360,7 +261,7 @@ def cuda_kernels_forward( return_last_state=True, ) if ssm_state is not None and cache_params is not None: - cache_params.ssm_states[self.layer_idx].copy_(ssm_state) + cache_params.update_ssm_state(ssm_state, self.layer_idx) # 4. Final linear projection contextualized_states = self.out_proj(scan_outputs.transpose(1, 2)) @@ -368,7 +269,7 @@ def cuda_kernels_forward( return contextualized_states # fmt: off - def slow_forward(self, input_states, cache_params: HybridMambaAttentionDynamicCache | None = None, attention_mask: torch.LongTensor | None = None): + def slow_forward(self, input_states, cache_params: Cache | None = None, attention_mask: torch.LongTensor | None = None): batch_size, seq_len, _ = input_states.shape dtype = input_states.dtype # 1. Gated MLP's linear projection @@ -378,23 +279,19 @@ def slow_forward(self, input_states, cache_params: HybridMambaAttentionDynamicCa if attention_mask is not None: hidden_states = hidden_states * attention_mask.unsqueeze(1) - use_cache = isinstance(cache_params, HybridMambaAttentionDynamicCache) - # 2. Convolution sequence transformation - if use_cache and cache_params.ssm_states[self.layer_idx].shape[0] == batch_size: - if self.training: - # In training mode, we don't want to perform in-place operations on ssm_state so we can compute the backwards pass - ssm_state = cache_params.ssm_states[self.layer_idx].clone() - else: - ssm_state = cache_params.ssm_states[self.layer_idx] - - ssm_state = ssm_state.to(hidden_states.device) + if cache_params is not None and cache_params.has_previous_state(self.layer_idx): + # In training mode, we don't want to perform in-place operations on ssm_state so we can compute the backwards pass + ssm_state = cache_params.ssm_states[self.layer_idx].clone() + else: + ssm_state = torch.zeros( + (batch_size, self.intermediate_size, self.ssm_state_size), + device=hidden_states.device, dtype=dtype + ) - if cache_params.has_previous_state and seq_len == 1 and \ - cache_params.conv_states[self.layer_idx].shape[0] == batch_size: - conv_state = cache_params.conv_states[self.layer_idx] - conv_state = torch.roll(conv_state, shifts=-1, dims=-1) - conv_state[:, :, -1] = hidden_states[:, :, 0] - cache_params.conv_states[self.layer_idx] = conv_state + # 2. 
Convolution sequence transformation + if cache_params is not None: + if cache_params.has_previous_state and seq_len == 1: + conv_state = cache_params.update_conv_state(hidden_states, self.layer_idx) hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1) if self.use_conv_bias: hidden_states += self.conv1d.bias @@ -404,13 +301,9 @@ def slow_forward(self, input_states, cache_params: HybridMambaAttentionDynamicCa hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0) ) - cache_params.conv_states[self.layer_idx] = conv_state + conv_state = cache_params.update_conv_state(hidden_states, self.layer_idx) hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) else: - ssm_state = torch.zeros( - (batch_size, self.intermediate_size, self.ssm_state_size), - device=hidden_states.device, dtype=dtype - ) hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) if attention_mask is not None: @@ -445,8 +338,8 @@ def slow_forward(self, input_states, cache_params: HybridMambaAttentionDynamicCa scan_output = scan_output + (hidden_states * self.D[None, :, None]) scan_output = (scan_output * self.act(gate)) - if use_cache: - cache_params.ssm_states[self.layer_idx] = ssm_state + if cache_params is not None: + cache_params.update_ssm_state(ssm_state, self.layer_idx) # 4. Final linear projection contextualized_states = self.out_proj(scan_output.transpose(1, 2)) @@ -456,7 +349,7 @@ def slow_forward(self, input_states, cache_params: HybridMambaAttentionDynamicCa def forward( self, hidden_states, - cache_params: HybridMambaAttentionDynamicCache | None = None, + cache_params: Cache | None = None, attention_mask: torch.LongTensor | None = None, ): if self.config.use_mamba_kernels and ( @@ -535,7 +428,7 @@ def forward( hidden_states: torch.Tensor, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, - past_key_values: HybridMambaAttentionDynamicCache | None = None, + past_key_values: Cache | None = None, use_cache: bool | None = False, **kwargs: Unpack[TransformersKwargs], ) -> torch.FloatTensor: @@ -572,7 +465,7 @@ def forward( hidden_states: torch.Tensor, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, - past_key_values: HybridMambaAttentionDynamicCache | None = None, + past_key_values: Cache | None = None, **kwargs: Unpack[TransformersKwargs], ) -> torch.FloatTensor: residual = hidden_states @@ -649,7 +542,7 @@ def forward( input_ids: torch.LongTensor | None = None, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, - past_key_values: HybridMambaAttentionDynamicCache | None = None, + past_key_values: Cache | None = None, inputs_embeds: torch.FloatTensor | None = None, use_cache: bool | None = None, **kwargs: Unpack[TransformersKwargs], @@ -661,12 +554,7 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) if use_cache and past_key_values is None: - past_key_values = HybridMambaAttentionDynamicCache( - config=self.config, - batch_size=inputs_embeds.shape[0], - dtype=inputs_embeds.dtype, - device=inputs_embeds.device, - ) + past_key_values = DynamicCache(config=self.config) if position_ids is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 @@ -711,7 +599,7 @@ def _update_mamba_mask(self, attention_mask, past_key_values): 2. 
Attending to all inputs """ mamba_mask = attention_mask - if (past_key_values is not None and past_key_values.has_previous_state) or ( + if (past_key_values is not None and past_key_values.has_previous_state()) or ( attention_mask is not None and torch.all(attention_mask == 1) ): mamba_mask = None @@ -728,7 +616,7 @@ def forward( input_ids: torch.LongTensor | None = None, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, - past_key_values: HybridMambaAttentionDynamicCache | None = None, + past_key_values: Cache | None = None, inputs_embeds: torch.FloatTensor | None = None, labels: torch.LongTensor | None = None, use_cache: bool | None = None, From a50293ca0498dc0e79663773f429ac5461a807d7 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Mon, 23 Mar 2026 19:06:55 +0100 Subject: [PATCH 09/56] fix tests --- tests/models/jamba/test_modeling_jamba.py | 49 +++++++---------------- 1 file changed, 15 insertions(+), 34 deletions(-) diff --git a/tests/models/jamba/test_modeling_jamba.py b/tests/models/jamba/test_modeling_jamba.py index a71271dd3cbe..b562fe477ab9 100644 --- a/tests/models/jamba/test_modeling_jamba.py +++ b/tests/models/jamba/test_modeling_jamba.py @@ -43,14 +43,7 @@ if is_torch_available(): import torch - from transformers import ( - JambaForCausalLM, - JambaForSequenceClassification, - JambaModel, - ) - from transformers.models.jamba.modeling_jamba import ( - HybridMambaAttentionDynamicCache, - ) + from transformers import JambaForCausalLM, JambaForSequenceClassification, JambaModel, DynamicCache class JambaConfigTester(ConfigTester): @@ -250,17 +243,7 @@ def create_and_check_decoder_model_past_large_inputs( model.to(torch_device) model.eval() - # first forward pass - # Attention: Jamba needs the cache to be initialized to return a cache! 
- past_key_values = HybridMambaAttentionDynamicCache( - config, input_ids.shape[0], model.dtype, device=model.device - ) - outputs = model( - input_ids, - attention_mask=input_mask, - past_key_values=past_key_values, - use_cache=True, - ) + outputs = model(input_ids, attention_mask=input_mask, use_cache=True) past_key_values = outputs.past_key_values # create hypothetical multiple next token and extent to next_input_ids @@ -340,7 +323,7 @@ class JambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi ) def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config): - self.assertIsInstance(past_key_values, HybridMambaAttentionDynamicCache) + self.assertIsInstance(past_key_values, DynamicCache) # (batch, kv heads, seq_length, head_dim) num_heads = getattr(config, "num_key_value_heads", config.num_attention_heads) @@ -353,18 +336,14 @@ def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_l for idx in range(len(past_key_values)): if config.layers_block_type[idx] == "mamba": - self.assertEqual(past_key_values.conv_states[idx].shape, conv_shape) - self.assertEqual(past_key_values.ssm_states[idx].shape, ssm_shape) + self.assertEqual(past_key_values.layers[idx].conv_states.shape, conv_shape) + self.assertEqual(past_key_values.layers[idx].ssm_states.shape, ssm_shape) else: - self.assertEqual(past_key_values.key_cache[idx].shape, attention_shape) - self.assertEqual(past_key_values.value_cache[idx].shape, attention_shape) + self.assertEqual(past_key_values.layers[idx].keys.shape, attention_shape) + self.assertEqual(past_key_values.layers[idx].values.shape, attention_shape) - def _check_caches_are_equal( - self, cache1: HybridMambaAttentionDynamicCache, cache2: HybridMambaAttentionDynamicCache - ): - if not isinstance(cache1, HybridMambaAttentionDynamicCache) or not isinstance( - cache2, HybridMambaAttentionDynamicCache - ): + def _check_caches_are_equal(self, cache1: DynamicCache, cache2: DynamicCache): + if not isinstance(cache1, DynamicCache) or not isinstance(cache2, DynamicCache): raise ValueError("The wrong cache is being used!") if not len(cache1) == len(cache2): @@ -372,10 +351,12 @@ def _check_caches_are_equal( num_layers = len(cache1) for idx in range(num_layers): - torch.testing.assert_close(cache1.key_cache[idx], cache2.key_cache[idx]) - torch.testing.assert_close(cache1.value_cache[idx], cache2.value_cache[idx]) - torch.testing.assert_close(cache1.conv_states[idx], cache2.conv_states[idx]) - torch.testing.assert_close(cache1.ssm_states[idx], cache2.ssm_states[idx]) + if config.layers_block_type[idx] == "mamba": + torch.testing.assert_close(cache1.layers[idx].conv_states, cache2.layers[idx].conv_states) + torch.testing.assert_close(cache1.layers[idx].ssm_states, cache2.layers[idx].ssm_states) + else: + torch.testing.assert_close(cache1.layers[idx].keys, cache2.layers[idx].keys) + torch.testing.assert_close(cache1.layers[idx].values, cache2.layers[idx].values) def setUp(self): self.model_tester = JambaModelTester(self) From 1607fe29487ffad3e389612d523dd6b2a10ccada Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Mon, 23 Mar 2026 19:16:40 +0100 Subject: [PATCH 10/56] fixes --- src/transformers/cache_utils.py | 4 ++-- src/transformers/models/jamba/modeling_jamba.py | 2 +- src/transformers/models/jamba/modular_jamba.py | 2 +- tests/models/jamba/test_modeling_jamba.py | 5 +++-- 4 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 
5079f81117a0..2b5d921b2a7e 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -955,14 +955,14 @@ def get_seq_length(self, layer_idx: int = 0) -> int: def has_previous_state(self, layer_idx: int | None = None) -> bool: """Returns whether the Mamba layer at index `layer_idx` has previous state or not.""" - if layer_idx >= len(self.layers): + if layer_idx is not None and layer_idx >= len(self.layers): return False # In this case, use last Mamba layer if layer_idx is None: try: layer_idx = next( - idx for idx in range(len(self) - 1, -1, -1) if isinstance(self.layers[idx], CacheLayerMixin) + idx for idx in range(len(self) - 1, -1, -1) if isinstance(self.layers[idx], MambaCacheLayerMixin) ) except StopIteration: raise ValueError( diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py index ef0b15b6856b..7a12072e6bdd 100755 --- a/src/transformers/models/jamba/modeling_jamba.py +++ b/src/transformers/models/jamba/modeling_jamba.py @@ -388,7 +388,7 @@ def slow_forward(self, input_states, cache_params: Cache | None = None, attentio if cache_params is not None and cache_params.has_previous_state(self.layer_idx): # In training mode, we don't want to perform in-place operations on ssm_state so we can compute the backwards pass - ssm_state = cache_params.ssm_states[self.layer_idx].clone() + ssm_state = cache_params.layers[self.layer_idx].ssm_states.clone() else: ssm_state = torch.zeros( (batch_size, self.intermediate_size, self.ssm_state_size), diff --git a/src/transformers/models/jamba/modular_jamba.py b/src/transformers/models/jamba/modular_jamba.py index 45d529d4ff01..40e2188175a8 100644 --- a/src/transformers/models/jamba/modular_jamba.py +++ b/src/transformers/models/jamba/modular_jamba.py @@ -281,7 +281,7 @@ def slow_forward(self, input_states, cache_params: Cache | None = None, attentio if cache_params is not None and cache_params.has_previous_state(self.layer_idx): # In training mode, we don't want to perform in-place operations on ssm_state so we can compute the backwards pass - ssm_state = cache_params.ssm_states[self.layer_idx].clone() + ssm_state = cache_params.layers[self.layer_idx].ssm_states.clone() else: ssm_state = torch.zeros( (batch_size, self.intermediate_size, self.ssm_state_size), diff --git a/tests/models/jamba/test_modeling_jamba.py b/tests/models/jamba/test_modeling_jamba.py index b562fe477ab9..499941fcbe81 100644 --- a/tests/models/jamba/test_modeling_jamba.py +++ b/tests/models/jamba/test_modeling_jamba.py @@ -43,7 +43,8 @@ if is_torch_available(): import torch - from transformers import JambaForCausalLM, JambaForSequenceClassification, JambaModel, DynamicCache + from transformers import DynamicCache, JambaForCausalLM, JambaForSequenceClassification, JambaModel + from transformers.cache_utils import MambaLayer class JambaConfigTester(ConfigTester): @@ -351,7 +352,7 @@ def _check_caches_are_equal(self, cache1: DynamicCache, cache2: DynamicCache): num_layers = len(cache1) for idx in range(num_layers): - if config.layers_block_type[idx] == "mamba": + if isinstance(cache1.layers[idx], MambaLayer): torch.testing.assert_close(cache1.layers[idx].conv_states, cache2.layers[idx].conv_states) torch.testing.assert_close(cache1.layers[idx].ssm_states, cache2.layers[idx].ssm_states) else: From ddc198af68e4836537ec654ff1a7a32b8d670966 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Mon, 23 Mar 2026 19:27:43 +0100 Subject: [PATCH 11/56] fix --- src/transformers/cache_utils.py | 7 +++++++ 
src/transformers/models/jamba/modeling_jamba.py | 4 ++-- src/transformers/models/jamba/modular_jamba.py | 4 ++-- 3 files changed, 11 insertions(+), 4 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 2b5d921b2a7e..3229141f67c8 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -711,11 +711,18 @@ def reset(self) -> None: self.ssm_states.zero_() self.has_previous_state = False + def reorder_cache(self, beam_idx: torch.LongTensor): + """Reorders the cache for beam search, given the selected beam indices.""" + if self.has_previous_state: + self.conv_states = self.conv_states.index_select(0, beam_idx.to(self.device)) + self.ssm_states = self.ssm_states.index_select(0, beam_idx.to(self.device)) + class MambaLayer(MambaCacheLayerMixin): def lazy_initialization(self, conv_states: torch.Tensor, ssm_states: torch.Tensor) -> None: self.dtype, self.device = conv_states.dtype, conv_states.device self.conv_states = torch.tensor([], dtype=self.dtype, device=self.device) + self.ssm_states = torch.tensor([], dtype=self.dtype, device=self.device) self.is_initialized = True def update_conv_state(self, conv_states: torch.Tensor, **kwargs) -> torch.Tensor: diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py index 7a12072e6bdd..753e703ff715 100755 --- a/src/transformers/models/jamba/modeling_jamba.py +++ b/src/transformers/models/jamba/modeling_jamba.py @@ -397,7 +397,7 @@ def slow_forward(self, input_states, cache_params: Cache | None = None, attentio # 2. Convolution sequence transformation if cache_params is not None: - if cache_params.has_previous_state and seq_len == 1: + if cache_params.has_previous_state(self.layer_idx) and seq_len == 1: conv_state = cache_params.update_conv_state(hidden_states, self.layer_idx) hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1) if self.use_conv_bias: @@ -408,7 +408,7 @@ def slow_forward(self, input_states, cache_params: Cache | None = None, attentio hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0) ) - conv_state = cache_params.update_conv_state(hidden_states, self.layer_idx) + conv_state = cache_params.update_conv_state(conv_state, self.layer_idx) hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) else: hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) diff --git a/src/transformers/models/jamba/modular_jamba.py b/src/transformers/models/jamba/modular_jamba.py index 40e2188175a8..2a5b21aad005 100644 --- a/src/transformers/models/jamba/modular_jamba.py +++ b/src/transformers/models/jamba/modular_jamba.py @@ -290,7 +290,7 @@ def slow_forward(self, input_states, cache_params: Cache | None = None, attentio # 2. 
Convolution sequence transformation if cache_params is not None: - if cache_params.has_previous_state and seq_len == 1: + if cache_params.has_previous_state(self.layer_idx) and seq_len == 1: conv_state = cache_params.update_conv_state(hidden_states, self.layer_idx) hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1) if self.use_conv_bias: @@ -301,7 +301,7 @@ def slow_forward(self, input_states, cache_params: Cache | None = None, attentio hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0) ) - conv_state = cache_params.update_conv_state(hidden_states, self.layer_idx) + conv_state = cache_params.update_conv_state(conv_state, self.layer_idx) hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) else: hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) From bae4a78d64d2e8a3b953c58ea83c81b045906f71 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 24 Mar 2026 11:16:48 +0100 Subject: [PATCH 12/56] fix --- src/transformers/cache_utils.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 3229141f67c8..dfb03d4eb311 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -719,8 +719,10 @@ def reorder_cache(self, beam_idx: torch.LongTensor): class MambaLayer(MambaCacheLayerMixin): - def lazy_initialization(self, conv_states: torch.Tensor, ssm_states: torch.Tensor) -> None: + def lazy_initialization(self, conv_states: torch.Tensor) -> None: self.dtype, self.device = conv_states.dtype, conv_states.device + # Even if prefill is larfer/shorter than the conv_size, the tensor is always either padded or truncated + self.conv_kernel_size = conv_states.shape[-1] self.conv_states = torch.tensor([], dtype=self.dtype, device=self.device) self.ssm_states = torch.tensor([], dtype=self.dtype, device=self.device) self.is_initialized = True @@ -737,7 +739,7 @@ def update_conv_state(self, conv_states: torch.Tensor, **kwargs) -> torch.Tensor """ # Lazy initialization if not self.is_initialized: - self.lazy_initialization(conv_states, conv_states) + self.lazy_initialization(conv_states) # Technically, those update are not logically correct if the prefill is smaller than `conv_kernel_size`, # as it will `roll` anyway in the first decoding step even though it should `roll` ONLY if the cache is already full. @@ -746,9 +748,13 @@ def update_conv_state(self, conv_states: torch.Tensor, **kwargs) -> torch.Tensor self.conv_states = conv_states self.has_previous_state = True else: - new_conv_states = self.conv_states.roll(shifts=-1, dims=-1) - new_conv_states[:, :, -1:] = conv_states - self.conv_states = new_conv_states + num_new_tokens = conv_states.shape[-1] + if num_new_tokens >= self.conv_kernel_size: + self.conv_states = conv_states[..., -self.conv_kernel_size :] + else: + new_conv_states = self.conv_states.roll(shifts=-num_new_tokens, dims=-1) + new_conv_states[:, :, -num_new_tokens:] = conv_states + self.conv_states = new_conv_states return self.conv_states @@ -762,12 +768,7 @@ def update_ssm_state(self, ssm_states: torch.Tensor, **kwargs) -> torch.Tensor: Returns: `torch.Tensor`: The updated ssm states. 
""" - # Lazy initialization - if not self.is_initialized: - self.lazy_initialization(ssm_states, ssm_states) - self.ssm_states = ssm_states - return self.ssm_states From 984b578a8e807e7e2dae077d118c7f5d446e56c6 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 24 Mar 2026 11:54:57 +0100 Subject: [PATCH 13/56] fix --- src/transformers/models/mamba2/modeling_mamba2.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/mamba2/modeling_mamba2.py b/src/transformers/models/mamba2/modeling_mamba2.py index 41d97923c120..43b70dab9185 100644 --- a/src/transformers/models/mamba2/modeling_mamba2.py +++ b/src/transformers/models/mamba2/modeling_mamba2.py @@ -411,12 +411,13 @@ def torch_forward( _, _, gate, hidden_states_B_C, dt = projected_states.split( [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1 ) + hidden_states_B_C = hidden_states_B_C.transpose(1,2) is_decoding = cache_params is not None and cache_params.has_previous_state(self.layer_idx) # 2. Convolution sequence transformation if is_decoding: - conv_states = cache_params.update_conv_state(hidden_states_B_C[:, 0:1, :], layer_idx=self.layer_idx) + conv_states = cache_params.update_conv_state(hidden_states_B_C, layer_idx=self.layer_idx) hidden_states_B_C = torch.sum( conv_states * self.conv1d.weight.squeeze(1), dim=-1 @@ -427,13 +428,12 @@ def torch_forward( else: # Init cache if cache_params is not None: - hidden_states_B_C_transposed = hidden_states_B_C.transpose(1, 2) conv_states = nn.functional.pad( - hidden_states_B_C_transposed, (self.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0) + hidden_states_B_C, (self.conv_kernel_size - hidden_states_B_C.shape[-1], 0) ) cache_params.update_conv_state(conv_states, layer_idx=self.layer_idx) - hidden_states_B_C = self.act(self.conv1d(hidden_states_B_C.transpose(1, 2))[..., :seq_len].transpose(1, 2)) + hidden_states_B_C = self.act(self.conv1d(hidden_states_B_C)[..., :seq_len].transpose(1, 2)) hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask) hidden_states, B, C = torch.split( From cac5d17c0b03c0bb68aa7f57fad204b99df63587 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 24 Mar 2026 13:00:14 +0100 Subject: [PATCH 14/56] combine both types + zambas --- src/transformers/cache_utils.py | 35 +++- .../models/zamba/modeling_zamba.py | 190 +++--------------- .../models/zamba2/modeling_zamba2.py | 182 ++++------------- .../models/zamba2/modular_zamba2.py | 132 +++--------- 4 files changed, 127 insertions(+), 412 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index dfb03d4eb311..42c1d8acf849 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -772,6 +772,26 @@ def update_ssm_state(self, ssm_states: torch.Tensor, **kwargs) -> torch.Tensor: return self.ssm_states +class MambaAndAttentionLayer(MambaLayer, DynamicLayer): + def __init__(self): + DynamicLayer.__init__(self) + MambaLayer.__init__(self) + + def lazy_initialization(self, states_1: torch.Tensor, states_2: torch.Tensor | None = None) -> None: + MambaLayer.lazy_initialization(self, states_1) + self.keys = torch.tensor([], dtype=self.dtype, device=self.device) + self.values = torch.tensor([], dtype=self.dtype, device=self.device) + + def reset(self) -> None: + MambaLayer.reset(self) + DynamicLayer.reset(self) + + def reorder_cache(self, beam_idx: torch.LongTensor): + """Reorders the cache for beam search, given the selected beam indices.""" + 
MambaLayer.reorder_cache(self, beam_idx) + DynamicLayer.reorder_cache(self, beam_idx) + + class Cache: """ A `Cache` is mostly a list of `CacheLayerMixin` objects, one per model layer. It serves as a container for @@ -942,8 +962,8 @@ def get_seq_length(self, layer_idx: int = 0) -> int: if layer_idx >= len(self.layers): return 0 - # For Hybrid attention-mamba caches, `get_seq_length` needs to use attention layer idx when called with default layer_idx - if isinstance(self.layers[layer_idx], MambaCacheLayerMixin): + # For alternating attention-mamba caches, `get_seq_length` needs to use attention layer idx when called with default layer_idx + if not isinstance(self.layers[layer_idx], CacheLayerMixin): # If this is called with non-default arg, raise if layer_idx != 0: raise ValueError( @@ -977,6 +997,11 @@ def has_previous_state(self, layer_idx: int | None = None) -> bool: "`has_previous_state` can only be called on Mamba layers, and the current Cache seem to only contain " "Attention layers." ) + elif not isinstance(self.layers[layer_idx], MambaCacheLayerMixin): + raise ValueError( + f"You called `has_previous_state` on layer index {layer_idx}, but this layer is an Attention layer, which " + "does not support calling it." + ) return self.layers[layer_idx].has_previous_state @@ -991,8 +1016,8 @@ def get_mask_sizes(self, query_length: int, layer_idx: int) -> tuple[int, int]: if layer_idx >= len(self.layers): return query_length, 0 - # For Hybrid attention-mamba caches, `get_seq_length` needs to use attention layer idx when called with default layer_idx - if isinstance(self.layers[layer_idx], MambaCacheLayerMixin): + # For alternating attention-mamba caches, `get_mask_sizes` needs to use attention layer idx when called with default layer_idx + if not isinstance(self.layers[layer_idx], CacheLayerMixin): # If this is called with non-default arg, raise if layer_idx != 0: raise ValueError( @@ -1163,6 +1188,8 @@ def __init__( layers.append(DynamicSlidingWindowLayer(sliding_window=sliding_window)) elif layer_type in ("mamba", "conv"): layers.append(MambaLayer()) + elif layer_type == "hybrid": + layers.append(MambaAndAttentionLayer()) else: layers.append(DynamicLayer()) diff --git a/src/transformers/models/zamba/modeling_zamba.py b/src/transformers/models/zamba/modeling_zamba.py index c985236ff0f7..dd6bf0ed5414 100644 --- a/src/transformers/models/zamba/modeling_zamba.py +++ b/src/transformers/models/zamba/modeling_zamba.py @@ -20,7 +20,6 @@ import math from collections.abc import Callable -from typing import Any import torch from torch import nn @@ -28,7 +27,7 @@ from ... import initialization as init from ...activations import ACT2FN -from ...cache_utils import Cache +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...integrations.hub_kernels import lazy_load_kernel from ...masking_utils import create_causal_mask @@ -80,107 +79,6 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) -class ZambaHybridDynamicCache: - """ - A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache - (which has a constant shape regardless of seq_len). - - This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states` - and `ssm_states` for mamba cache. Each of these lists has `num_layers` tensors. 
The expected shape for each tensor - For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`, - while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors). - For mamba layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors), - while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`, - and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`. - """ - - is_compileable = False - - def __init__(self, config, batch_size, dtype=torch.float16, device=None): - self.dtype = dtype - self.is_compileable = False - self.layers_block_type = config.layers_block_type - self.has_previous_state = False # only used by mamba - self.intermediate_size = config.mamba_expand * config.hidden_size - self.ssm_state_size = config.mamba_d_state - self.conv_kernel_size = config.mamba_d_conv - self.n_mamba_heads = config.n_mamba_heads - self.conv_states = [] - self.ssm_states = [] - self.transformer_layers = [] - self._modules = {} - self._parameters = {} - self._buffers = {} - for i in range(config.num_hidden_layers): - self.conv_states += [ - torch.zeros(batch_size, self.intermediate_size, self.conv_kernel_size, device=device, dtype=dtype) - ] - cache_shape = ( - batch_size, - self.n_mamba_heads, - self.intermediate_size // self.n_mamba_heads, - self.ssm_state_size, - ) - self.ssm_states += [torch.zeros(cache_shape, device=device, dtype=dtype)] - if self.layers_block_type[i] == "hybrid": - self.transformer_layers.append(i) - - self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] - self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] - - def __len__(self): - return len(self.key_cache) - - # Copied from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache.update - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: dict[str, Any] | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - # Update the cache - if self.key_cache[layer_idx].shape[-1] == 0: - self.key_cache[layer_idx] = key_states - self.value_cache[layer_idx] = value_states - else: - self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2) - self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2) - - return self.key_cache[layer_idx], self.value_cache[layer_idx] - - # Copied from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache.reorder_cache - def reorder_cache(self, beam_idx: torch.LongTensor): - """Reorders the cache for beam search, given the selected beam indices.""" - if self.get_seq_length() > 0: - for layer_idx in range(len(self.key_cache)): - device = self.key_cache[layer_idx].device - self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) - device = self.value_cache[layer_idx].device - self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) - - device = self.conv_states[layer_idx].device - self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx.to(device)) - device = self.ssm_states[layer_idx].device - self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device)) - - # Copied from 
transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache.get_seq_length - def get_seq_length(self, layer_idx: int | None = 0) -> int: - """Returns the sequence length of the cached states. A layer index can be optionally passed.""" - # take any layer that contains cache and not empty tensor - layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx - if len(self.key_cache) <= layer_idx or self.key_cache[layer_idx].shape[-1] == 0: - return 0 - return self.key_cache[layer_idx].shape[-2] - - # Copied from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache.get_mask_sizes - def get_mask_sizes(self, query_length: int, layer_idx: int) -> tuple[int, int]: - """Return the length and offset of the cache, used to generate the mask""" - kv_offset = 0 - kv_length = self.get_seq_length(layer_idx) + query_length - return kv_length, kv_offset - - def eager_attention_forward( module: nn.Module, query: torch.Tensor, @@ -243,7 +141,7 @@ def forward( hidden_states: torch.Tensor, layer_idx: int, attention_mask: torch.Tensor | None, - past_key_values: ZambaHybridDynamicCache | None = None, + past_key_values: Cache | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]: input_shape = hidden_states.shape[:-1] @@ -369,7 +267,7 @@ def __init__(self, config: ZambaConfig, layer_idx): ) def cuda_kernels_forward( - self, hidden_states: torch.Tensor, cache_params: ZambaHybridDynamicCache = None, attention_mask=None + self, hidden_states: torch.Tensor, cache_params: Cache | None = None, attention_mask=None ): batch_size, seq_len, _ = hidden_states.shape use_precomputed_states = cache_params is not None and cache_params.has_previous_state and seq_len == 1 @@ -387,7 +285,7 @@ def cuda_kernels_forward( if use_precomputed_states: hidden_states = causal_conv1d_update( hidden_states.squeeze(-1), - cache_params.conv_states[self.layer_idx], + cache_params.layers[self.layer_idx].conv_states, conv_weights, self.conv1d.bias, self.activation, @@ -398,7 +296,7 @@ def cuda_kernels_forward( hidden_states = hidden_states * attention_mask.unsqueeze(1) if cache_params is not None: conv_states = nn.functional.pad(hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0)) - cache_params.conv_states[self.layer_idx].copy_(conv_states) + conv_states = cache_params.update_conv_state(conv_states, self.layer_idx) hidden_states = causal_conv1d_fn(hidden_states, conv_weights, self.conv1d.bias, activation=self.activation) if attention_mask is not None and not torch.all(attention_mask == 1): hidden_states = hidden_states * attention_mask.unsqueeze(1) @@ -424,7 +322,7 @@ def cuda_kernels_forward( if use_precomputed_states: for n in range(self.n_mamba_heads): scan_outputs_ = selective_state_update( - cache_params.ssm_states[self.layer_idx][:, n], + cache_params.layers[self.layer_idx].ssm_states[:, n], hidden_states[n, ..., 0], discrete_time_step[n, ..., 0], A[n], @@ -459,13 +357,13 @@ def cuda_kernels_forward( scan_outputs = torch.cat((scan_outputs, scan_outputs_), dim=1).contiguous() ssm_state = torch.cat((ssm_state, ssm_state_.unsqueeze(1)), dim=1) if ssm_state is not None and cache_params is not None: - cache_params.ssm_states[self.layer_idx].copy_(ssm_state) + cache_params.update_ssm_state(ssm_state, self.layer_idx) # 4. 
Final linear projection contextualized_states = self.out_proj(scan_outputs.transpose(1, 2)) return contextualized_states - def slow_forward(self, input_states, cache_params: ZambaHybridDynamicCache = None, attention_mask=None): + def slow_forward(self, input_states, cache_params: Cache | None = None, attention_mask=None): batch_size, seq_len, _ = input_states.shape dtype = input_states.dtype # 1. Gated linear projection @@ -476,26 +374,18 @@ def slow_forward(self, input_states, cache_params: ZambaHybridDynamicCache = Non gate = gate.squeeze(2) gate = gate.reshape(batch_size, self.n_mamba_heads, -1, seq_len).transpose(0, 1) - use_cache = isinstance(cache_params, ZambaHybridDynamicCache) + if cache_params is not None and cache_params.has_previous_state(self.layer_idx): + # In training mode, we don't want to perform in-place operations on ssm_state so we can compute the backwards pass + ssm_state = cache_params.layers[self.layer_idx].ssm_states.clone() + else: + ssm_state = torch.zeros( + (batch_size, self.intermediate_size, self.ssm_state_size), device=hidden_states.device, dtype=dtype + ) + # 2. Convolution sequence transformation - if use_cache and cache_params.ssm_states[self.layer_idx].shape[0] == batch_size: - if self.training: - # In training mode, we don't want to perform in-place operations on ssm_state so we can compute the backwards pass - ssm_state = cache_params.ssm_states[self.layer_idx].clone() - else: - ssm_state = cache_params.ssm_states[self.layer_idx] - - ssm_state = ssm_state.to(hidden_states.device) - - if ( - cache_params.has_previous_state - and seq_len == 1 - and cache_params.conv_states[self.layer_idx].shape[0] == batch_size - ): - conv_state = cache_params.conv_states[self.layer_idx] - conv_state = torch.roll(conv_state, shifts=-1, dims=-1) - conv_state[:, :, -1] = hidden_states[:, :, 0] - cache_params.conv_states[self.layer_idx] = conv_state + if cache_params is not None: + if cache_params.has_previous_state(self.layer_idx) and seq_len == 1: + conv_state = cache_params.update_conv_state(hidden_states, self.layer_idx) hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1) if self.use_conv_bias: hidden_states += self.conv1d.bias @@ -504,16 +394,11 @@ def slow_forward(self, input_states, cache_params: ZambaHybridDynamicCache = Non if attention_mask is not None: hidden_states = hidden_states * attention_mask[:, -hidden_states.shape[-1] :].unsqueeze(1) conv_state = nn.functional.pad(hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0)) - cache_params.conv_states[self.layer_idx] = conv_state + conv_state = cache_params.update_conv_state(conv_state, self.layer_idx) hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) if attention_mask is not None: hidden_states = hidden_states * attention_mask[:, -hidden_states.shape[-1] :].unsqueeze(1) else: - ssm_state = torch.zeros( - (batch_size, self.n_mamba_heads, self.mamba_head_dim, self.ssm_state_size), - device=hidden_states.device, - dtype=dtype, - ) if attention_mask is not None: hidden_states = hidden_states * attention_mask.unsqueeze(1) hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) @@ -549,8 +434,8 @@ def slow_forward(self, input_states, cache_params: ZambaHybridDynamicCache = Non scan_output = scan_output + (hidden_states * self.D[:, None, :, None]) scan_output = scan_output * self.act(gate) - if use_cache: - cache_params.ssm_states[self.layer_idx] = ssm_state + if cache_params is not None: + cache_params.update_ssm_state(ssm_state, self.layer_idx) # 4. 
Final linear projection contextualized_states = self.out_proj( @@ -558,7 +443,7 @@ def slow_forward(self, input_states, cache_params: ZambaHybridDynamicCache = Non ) return contextualized_states - def forward(self, hidden_states, cache_params: ZambaHybridDynamicCache = None, attention_mask=None, **kwargs): + def forward(self, hidden_states, cache_params: Cache | None = None, attention_mask=None, **kwargs): is_fast_path_available = all( (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn) ) @@ -606,7 +491,7 @@ def forward( original_hidden_states: torch.Tensor, layer_idx: int, attention_mask: torch.Tensor | None = None, - past_key_values: ZambaHybridDynamicCache | None = None, + past_key_values: Cache | None = None, use_cache: bool | None = False, **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]: @@ -620,7 +505,7 @@ def forward( layer_idx (`int`): layer_idx in the forward pass. Used to distinguish Zamba's tied transformer layers. attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch, sequence_length)` where padding elements are indicated by 0. - past_key_values (`ZambaHybridDynamicCache`, *optional*): cached past key and value projection states + past_key_values (`Cache`, *optional*): cached past key and value projection states use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). @@ -656,7 +541,7 @@ def forward( layer_idx: int | None = None, attention_mask: torch.Tensor | None = None, causal_mask: torch.Tensor | None = None, - past_key_values: ZambaHybridDynamicCache | None = None, + past_key_values: Cache | None = None, use_cache: bool | None = False, position_ids: torch.LongTensor | None = None, transformer_hidden_states: torch.Tensor | None = None, @@ -667,7 +552,7 @@ def forward( hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch, sequence_length)` where padding elements are indicated by 0. - past_key_values (`ZambaHybridDynamicCache`, *optional*): cached past key and value projection states + past_key_values (`Cache`, *optional*): cached past key and value projection states use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). @@ -708,7 +593,7 @@ def forward( layer_idx: int | None = None, attention_mask: torch.Tensor | None = None, causal_mask: torch.Tensor | None = None, - past_key_values: ZambaHybridDynamicCache | None = None, + past_key_values: Cache | None = None, use_cache: bool | None = False, **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]: @@ -720,7 +605,7 @@ def forward( layer_idx (`int`): layer number. attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch, sequence_length)` where padding elements are indicated by 0. - past_key_values (`ZambaHybridDynamicCache`, *optional*): cached past key and value projection states + past_key_values (`Cache`, *optional*): cached past key and value projection states use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). 
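For reference, the conv-state bookkeeping that the mixers above now delegate to the cache can be reproduced standalone. The sketch below uses illustrative sizes only (none of them come from a config touched in this series) and mirrors the pad-on-prefill / roll-on-decode behaviour of the new `update_conv_state` path:

import torch

# Illustrative sizes, not taken from any config in this series.
batch, d_inner, d_conv = 2, 8, 4

# Prefill: the mixer pads (or truncates) the activations to the kernel size before
# handing them to the cache, so the stored window is always (batch, d_inner, d_conv).
prefill = torch.randn(batch, d_inner, 3)
conv_window = torch.nn.functional.pad(prefill, (d_conv - prefill.shape[-1], 0))

# Decode: the cache layer rolls the window left by one slot and writes the newest
# column last, once has_previous_state is set.
new_column = torch.randn(batch, d_inner, 1)
conv_window = conv_window.roll(shifts=-1, dims=-1)
conv_window[:, :, -1:] = new_column
assert conv_window.shape == (batch, d_inner, d_conv)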
@@ -758,7 +643,6 @@ class ZambaPreTrainedModel(PreTrainedModel): _skip_keys_device_placement = "past_key_values" _supports_flash_attn = False _supports_sdpa = False - # Note: only supports ZambaHybridDynamicCache _is_stateful = True _can_record_outputs = { "hidden_states": ZambaMambaDecoderLayer, @@ -835,7 +719,7 @@ def forward( input_ids: torch.LongTensor | None = None, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, - past_key_values: ZambaHybridDynamicCache | None = None, + past_key_values: Cache | None = None, inputs_embeds: torch.FloatTensor | None = None, use_cache: bool | None = None, **kwargs: Unpack[TransformersKwargs], @@ -854,10 +738,7 @@ def forward( # original_hidden_states: word embedding output that will be concatenated with hidden activations to form the input of the shared transformer layer if use_cache and past_key_values is None: - logger.warning_once( - "Zamba requires an initialized `ZambaHybridDynamicCache` to return a cache. None was " - "provided, so no cache will be returned." - ) + past_key_values = DynamicCache(config=self.config) if position_ids is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 @@ -915,7 +796,7 @@ def forward( input_ids: torch.LongTensor | None = None, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, - past_key_values: ZambaHybridDynamicCache | None = None, + past_key_values: Cache | None = None, inputs_embeds: torch.FloatTensor | None = None, labels: torch.LongTensor | None = None, use_cache: bool | None = None, @@ -987,13 +868,6 @@ def prepare_inputs_for_generation( is_first_iteration=False, **kwargs, ): - # Overwritten -- has a unique cache type, `ZambaHybridDynamicCache` - - if past_key_values is None: - past_key_values = ZambaHybridDynamicCache( - self.config, input_ids.shape[0], dtype=self.dtype, device=self.device - ) - kwargs["logits_to_keep"] = self.config.num_logits_to_keep model_inputs = super().prepare_inputs_for_generation( input_ids, diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index 775d0d45d009..c882353a8639 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -21,13 +21,14 @@ import math from collections.abc import Callable from itertools import cycle -from typing import Any, Optional +from typing import Optional import torch from torch import nn from ... import initialization as init from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...integrations import use_kernel_func_from_hub from ...integrations.hub_kernels import lazy_load_kernel @@ -88,107 +89,6 @@ def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" -class Zamba2HybridDynamicCache: - """ - A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache - (which has a constant shape regardless of seq_len). - - This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states` - and `ssm_states` for mamba cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor - For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`, - while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors). 
- For mamba layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors), - while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`, - and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`. - """ - - is_compileable = False - - def __init__( - self, config: Zamba2Config, batch_size: int, dtype: torch.dtype = torch.float16, device: str | None = None - ): - self.dtype = dtype - self.layers_block_type = config.layers_block_type - self.has_previous_state = False - self.intermediate_size = int(config.mamba_expand * config.hidden_size) - self.ssm_state_size = config.mamba_d_state - self.conv_kernel_size = config.mamba_d_conv - self.n_mamba_heads = config.n_mamba_heads - self.transformer_layers = [] - self._modules = {} - self._parameters = {} - self._buffers = {} - self.conv_states = {} - self.ssm_states = {} - for i in range(config.num_hidden_layers): - self.conv_states[i] = torch.zeros( - batch_size, - self.intermediate_size + 2 * config.mamba_ngroups * config.mamba_d_state, - self.conv_kernel_size, - device=device, - dtype=dtype, - ) - self.ssm_states[i] = torch.zeros( - batch_size, self.n_mamba_heads, config.mamba_headdim, self.ssm_state_size, device=device, dtype=dtype - ) - if self.layers_block_type[i] == "hybrid": - self.transformer_layers.append(i) - self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] - self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] - - def __len__(self): - return len(self.key_cache) - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: dict[str, Any] | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - # Update the cache - if self.key_cache[layer_idx].shape[-1] == 0: - self.key_cache[layer_idx] = key_states - self.value_cache[layer_idx] = value_states - else: - self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2) - self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2) - - return self.key_cache[layer_idx], self.value_cache[layer_idx] - - def reorder_cache(self, beam_idx: torch.LongTensor): - """Reorders the cache for beam search, given the selected beam indices.""" - if self.get_seq_length() > 0: - for layer_idx in range(len(self.key_cache)): - device = self.key_cache[layer_idx].device - self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) - device = self.value_cache[layer_idx].device - self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) - - device = self.conv_states[layer_idx].device - self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx.to(device)) - device = self.ssm_states[layer_idx].device - self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device)) - - def get_seq_length(self, layer_idx: int | None = 0) -> int: - """Returns the sequence length of the cached states. 
A layer index can be optionally passed.""" - # take any layer that contains cache and not empty tensor - layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx - if len(self.key_cache) <= layer_idx or self.key_cache[layer_idx].numel() == 0: - return 0 - return self.key_cache[layer_idx].shape[-2] - - def get_mask_sizes(self, query_length: int, layer_idx: int) -> tuple[int, int]: - """Return the length and offset of the cache, used to generate the mask""" - kv_offset = 0 - kv_length = self.get_seq_length(layer_idx) + query_length - return kv_length, kv_offset - - def reset(self): - self.conv_states.zero_() - self.ssm_states.zero_() - - class Zamba2RotaryEmbedding(nn.Module): inv_freq: torch.Tensor # fix linting for `register_buffer` @@ -401,7 +301,7 @@ def forward( hidden_states: torch.Tensor, layer_idx: int, attention_mask: torch.Tensor | None = None, - past_key_values: Zamba2HybridDynamicCache | None = None, + past_key_values: Cache | None = None, position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]: @@ -604,7 +504,7 @@ def __init__(self, config: Zamba2Config, layer_idx: int | None = None): def cuda_kernels_forward( self, hidden_states: torch.Tensor, - cache_params: Zamba2HybridDynamicCache | None = None, + cache_params: Cache | None = None, attention_mask: torch.Tensor | None = None, ): # set up dimensions for reshapes later @@ -614,7 +514,7 @@ def cuda_kernels_forward( d_to_remove = 2 * self.intermediate_size + 2 * self.n_groups * self.ssm_state_size + self.num_heads # getting projected states from cache if it exists - if cache_params is not None and cache_params.has_previous_state: + if cache_params is not None and cache_params.has_previous_state(self.layer_idx): in_projected_states = self.in_proj(hidden_states.squeeze(1)) # (B 2D) d_mlp = (in_projected_states.shape[-1] - d_to_remove) // 2 split_projection_dim = [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads] @@ -622,7 +522,7 @@ def cuda_kernels_forward( hidden_states_B_C = causal_conv1d_update( hidden_states_B_C, - cache_params.conv_states[self.layer_idx], + cache_params.layers[self.layer_idx].conv_states, self.conv1d.weight.squeeze(1), self.conv1d.bias, self.activation, @@ -643,7 +543,7 @@ def cuda_kernels_forward( C = C.view(batch_size, self.n_groups, C.shape[1] // self.n_groups) hidden_states_reshaped = hidden_states.view(batch_size, self.num_heads, self.head_dim) hidden_states = selective_state_update( - cache_params.ssm_states[self.layer_idx], + cache_params.layers[self.layer_idx].ssm_states, hidden_states_reshaped, dt, A, @@ -707,7 +607,7 @@ def cuda_kernels_forward( conv_state = nn.functional.pad( hidden_states_B_C_t, (self.conv_kernel_size - hidden_states_B_C_t.shape[-1], 0) ) - cache_params.conv_states[self.layer_idx].copy_(conv_state) + conv_state = cache_params.update_conv_state(conv_state, self.layer_idx) if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]: hidden_states_B_C = self.act( self.conv1d(hidden_states_B_C.transpose(1, 2)).transpose(1, 2)[:, :seq_len] @@ -744,7 +644,7 @@ def cuda_kernels_forward( **dt_limit_kwargs, ) if ssm_state is not None and cache_params is not None: - cache_params.ssm_states[self.layer_idx].copy_(ssm_state) + cache_params.update_ssm_state(ssm_state, self.layer_idx) scan_output = scan_output.view(batch_size, seq_len, -1) # Multiply "gate" branch and apply extra normalization layer 
scan_output = self.norm(scan_output, gate) @@ -752,11 +652,11 @@ def cuda_kernels_forward( return out # fmt: off - def torch_forward(self, input_states, cache_params: Zamba2HybridDynamicCache | None=None, attention_mask: torch.Tensor | None=None): + def torch_forward(self, input_states, cache_params: Cache | None=None, attention_mask: torch.Tensor | None=None): batch_size, seq_len, _ = input_states.shape dtype = input_states.dtype # Gated MLP's linear projection - if cache_params is not None and cache_params.has_previous_state: + if cache_params is not None and cache_params.has_previous_state(self.layer_idx): projected_states = self.in_proj(input_states.squeeze(1)) else: if attention_mask is not None: @@ -768,17 +668,15 @@ def torch_forward(self, input_states, cache_params: Zamba2HybridDynamicCache | N [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1 ) + has_previous_state = cache_params.has_previous_state(self.layer_idx) + # Convolution sequence transformation if cache_params is not None: - ssm_state = cache_params.ssm_states[self.layer_idx].clone() + ssm_state = cache_params.layers[self.layer_idx].ssm_states.clone() ssm_state = ssm_state.to(hidden_states.device) - if cache_params.has_previous_state: + if has_previous_state: gate = gate.unsqueeze(1) - conv_state = cache_params.conv_states[self.layer_idx] # [batch, intermediate_size, conv_kernel_size] - conv_state = torch.roll(conv_state, shifts=-1, dims=-1) - # handle batched generation - states are copied through - conv_state[:, :, -1] = hidden_states[:, 0, :] if hidden_states.ndim == 3 else hidden_states - cache_params.conv_states[self.layer_idx].copy_(conv_state) + conv_state = cache_params.update_conv_state(hidden_states, self.layer_idx) hidden_states = torch.sum(conv_state.to(projected_states.device) * self.conv1d.weight[:, 0, :], dim=-1) if self.use_conv_bias: hidden_states += self.conv1d.bias @@ -789,7 +687,7 @@ def torch_forward(self, input_states, cache_params: Zamba2HybridDynamicCache | N hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0) ) - cache_params.conv_states[self.layer_idx].copy_(conv_state) + conv_state = cache_params.update_conv_state(conv_state, self.layer_idx) hidden_states = self.act(self.conv1d(hidden_states).transpose(1,2))[:, :seq_len, :] # [batch, intermediate_size, seq_len] if attention_mask is not None: dtype = hidden_states.dtype @@ -803,7 +701,7 @@ def torch_forward(self, input_states, cache_params: Zamba2HybridDynamicCache | N hidden_states = self.act(self.conv1d(hidden_states.transpose(1, 2))[..., :seq_len].transpose(1, 2)) hidden_states, B, C = torch.split(hidden_states, [self.intermediate_size, self.n_groups * self.ssm_state_size, self.n_groups * self.ssm_state_size], dim=-1) A = -torch.exp(self.A_log.float()) # [num_heads] - if cache_params is not None and cache_params.has_previous_state: + if cache_params is not None and has_previous_state: # Note: there is no need to pad parameter matrices here, as there is just one new token # for batched generation dt = dt[:, None, ...] if dt.ndim == 2 else dt[:, 0, :][:, None, ...] 
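For the single-token decoding path being reworked in this file, the state update reduces to a first-order recurrence on the SSM state followed by a contraction with C. A toy version with made-up shapes (not the model's real head or state sizes):

import torch

# Made-up shapes for illustration only.
batch, heads, head_dim, state = 1, 2, 3, 4

ssm_state = torch.zeros(batch, heads, head_dim, state)
dA = torch.rand(batch, heads, 1, 1)              # discretized per-head decay
dBx = torch.rand(batch, heads, head_dim, state)  # input contribution of the new token

# One decoding step: decay the previous state and add the new contribution,
# which is what `ssm_state = ssm_state * dA + dBx` expresses before the cache update.
ssm_state = ssm_state * dA + dBx

# Read the state out through C ([batch, heads, state]) to get the per-head output.
C = torch.rand(batch, heads, state)
y = torch.einsum("bhds,bhs->bhd", ssm_state, C)
assert y.shape == (batch, heads, head_dim)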
@@ -832,9 +730,8 @@ def torch_forward(self, input_states, cache_params: Zamba2HybridDynamicCache | N dBx = dB * hidden_states[..., None] # State calculation - cache_params.ssm_states[self.layer_idx].copy_( - cache_params.ssm_states[self.layer_idx] * dA + dBx - ) + ssm_state = ssm_state * dA + dBx + ssm_state = cache_params.update_ssm_state(ssm_state, self.layer_idx) # Subsequent output # [bsz, n_groups * state_size] -> [bsz, num_heads, state_size] @@ -843,7 +740,7 @@ def torch_forward(self, input_states, cache_params: Zamba2HybridDynamicCache | N C = C.reshape(batch_size, -1, C.shape[-1]) # [bsz, num_heads, head_dim] - ssm_states = cache_params.ssm_states[self.layer_idx].to(C.dtype) # Shape: [b, h, d, n] + ssm_states = ssm_state.to(C.dtype) # Shape: [b, h, d, n] # Reshape ssm_states to merge the first two dimensions ssm_states_reshaped = ssm_states.view(batch_size * self.num_heads, self.head_dim, self.ssm_state_size) # Shape: [b*h, d, n] C_reshaped = C.view(batch_size * self.num_heads, self.ssm_state_size, 1) # Shape: [b*h, n, 1] @@ -904,10 +801,7 @@ def torch_forward(self, input_states, cache_params: Zamba2HybridDynamicCache | N B_decay_contraction = B * decay_states.permute(0, 2, 3, 1)[..., None] # permute back B * decay states states = (B_decay_contraction.permute(0, 1, 3, 2, 4)[..., None] * hidden_states.permute(0, 1, 3, 2, 4)[..., None, :]).sum(dim=3).permute(0, 1, 2, 4, 3) - if cache_params is not None and cache_params.has_previous_state: - previous_states = cache_params.ssm_states[self.layer_idx][:, None, ...] - else: - previous_states = torch.zeros_like(states[:, :1]) + previous_states = torch.zeros_like(states[:, :1]) states = torch.cat([previous_states, states], dim=1) decay_chunk = torch.exp(segment_sum(nn.functional.pad(A_cumsum[:, :, :, -1], (1, 0)))) @@ -935,7 +829,7 @@ def torch_forward(self, input_states, cache_params: Zamba2HybridDynamicCache | N y = y[:, :seq_len, :, :] y = y.reshape(batch_size, seq_len, -1) if ssm_state is not None and cache_params is not None: - cache_params.ssm_states[self.layer_idx].copy_(ssm_state) + cache_params.update_ssm_state(ssm_state, self.layer_idx) scan_output = self.norm(y, gate) @@ -949,7 +843,7 @@ def torch_forward(self, input_states, cache_params: Zamba2HybridDynamicCache | N def forward( self, hidden_states, - cache_params: Zamba2HybridDynamicCache | None = None, + cache_params: Cache | None = None, attention_mask: torch.Tensor | None = None, **kwargs, ): @@ -1017,7 +911,7 @@ def forward( original_hidden_states: torch.Tensor, layer_idx: int, attention_mask: torch.Tensor | None = None, - past_key_values: Zamba2HybridDynamicCache | None = None, + past_key_values: Cache | None = None, position_embeddings: torch.LongTensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.FloatTensor]: @@ -1030,7 +924,7 @@ def forward( (see fig. 2 in https://huggingface.co/papers/2405.16712). attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch, sequence_length)` where padding elements are indicated by 0. - past_key_values (`Zamba2HybridDynamicCache`, *optional*): cached past key and value projection states + past_key_values (`Cache`, *optional*): cached past key and value projection states use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). 
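On the decoding path, the in-place `copy_` into the cached tensor is replaced by computing the next state locally and handing it to `update_ssm_state`. Below is a small self-contained sketch of that one-step recurrence, S <- S * dA + dB * x, followed by the read-out against C, with invented shapes (illustrative only):

    import torch

    # Toy shapes: (batch, num_heads, head_dim, state_size)
    batch, num_heads, head_dim, state_size = 2, 4, 8, 16
    ssm_state = torch.zeros(batch, num_heads, head_dim, state_size)

    # One decoding step of the diagonal SSM recurrence.
    dA = torch.rand(batch, num_heads, 1, 1)             # per-head decay in (0, 1)
    dB = torch.randn(batch, num_heads, 1, state_size)   # discretized input matrix
    x = torch.randn(batch, num_heads, head_dim, 1)      # newest token, per head/dim
    ssm_state = ssm_state * dA + dB * x                 # what update_ssm_state then stores

    # Read-out against C, analogous to the batched matmul in the surrounding code.
    C = torch.randn(batch, num_heads, state_size, 1)
    y = torch.matmul(ssm_state.to(C.dtype), C).squeeze(-1)  # (batch, num_heads, head_dim)
    print(y.shape)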
@@ -1069,7 +963,7 @@ def forward( layer_idx: int | None = None, attention_mask: torch.Tensor | None = None, causal_mask: torch.Tensor | None = None, - past_key_values: Zamba2HybridDynamicCache | None = None, + past_key_values: Cache | None = None, use_cache: bool | None = False, position_ids: torch.LongTensor | None = None, transformer_hidden_states: torch.Tensor | None = None, @@ -1080,7 +974,7 @@ def forward( hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch, sequence_length)` where padding elements are indicated by 0. - past_key_values (`Zamba2HybridDynamicCache`, *optional*): cached past key and value projection states + past_key_values (`Cache`, *optional*): cached past key and value projection states use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). @@ -1123,7 +1017,7 @@ def forward( layer_idx: int | None = None, attention_mask: torch.Tensor | None = None, causal_mask: torch.Tensor | None = None, - past_key_values: Zamba2HybridDynamicCache | None = None, + past_key_values: Cache | None = None, use_cache: bool | None = False, position_embeddings: torch.LongTensor | None = None, position_ids: torch.LongTensor | None = None, @@ -1137,7 +1031,7 @@ def forward( layer_idx (`int`): layer number. attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch, sequence_length)` where padding elements are indicated by 0. - past_key_values (`Zamba2HybridDynamicCache`, *optional*): cached past key and value projection states + past_key_values (`Cache`, *optional*): cached past key and value projection states use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). 
@@ -1245,7 +1139,7 @@ def forward( input_ids: torch.LongTensor | None = None, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, - past_key_values: Zamba2HybridDynamicCache | None = None, + past_key_values: Cache | None = None, inputs_embeds: torch.FloatTensor | None = None, use_cache: bool | None = None, **kwargs: Unpack[TransformersKwargs], @@ -1264,8 +1158,7 @@ def forward( # original_hidden_states: word embedding output that will be concatenated with hidden activations to form the input of the shared transformer layer if use_cache and past_key_values is None: - batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0] - past_key_values = Zamba2HybridDynamicCache(self.config, batch_size, dtype=self.dtype, device=self.device) + past_key_values = DynamicCache(config=self.config) if position_ids is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 @@ -1365,7 +1258,7 @@ def forward( input_ids: torch.LongTensor | None = None, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, - past_key_values: Zamba2HybridDynamicCache | None = None, + past_key_values: Cache | None = None, inputs_embeds: torch.FloatTensor | None = None, labels: torch.LongTensor | None = None, use_cache: bool | None = None, @@ -1437,13 +1330,6 @@ def prepare_inputs_for_generation( is_first_iteration=False, **kwargs, ): - # Overwritten -- has a unique cache type, `Zamba2HybridDynamicCache` - - if past_key_values is None: - past_key_values = Zamba2HybridDynamicCache( - self.config, input_ids.shape[0], dtype=self.dtype, device=self.device - ) - kwargs["logits_to_keep"] = self.config.num_logits_to_keep model_inputs = super().prepare_inputs_for_generation( input_ids, @@ -1490,7 +1376,7 @@ def forward( input_ids: torch.LongTensor | None = None, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, - past_key_values: Zamba2HybridDynamicCache | None = None, + past_key_values: Cache | None = None, inputs_embeds: torch.FloatTensor | None = None, labels: torch.LongTensor | None = None, use_cache: bool | None = None, diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py index 2d57dc94046c..fd32a7d217a8 100644 --- a/src/transformers/models/zamba2/modular_zamba2.py +++ b/src/transformers/models/zamba2/modular_zamba2.py @@ -21,6 +21,7 @@ from ... import initialization as init from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache from ...integrations.hub_kernels import lazy_load_kernel from ...masking_utils import create_causal_mask from ...modeling_outputs import BaseModelOutputWithPast, SequenceClassifierOutputWithPast @@ -37,7 +38,6 @@ ZambaAttentionDecoderLayer, ZambaForCausalLM, ZambaForSequenceClassification, - ZambaHybridDynamicCache, ZambaHybridLayer, ZambaMambaDecoderLayer, ZambaModel, @@ -77,71 +77,6 @@ class Zamba2RMSNorm(ZambaRMSNorm): pass -class Zamba2HybridDynamicCache(ZambaHybridDynamicCache): - """ - A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache - (which has a constant shape regardless of seq_len). - - This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states` - and `ssm_states` for mamba cache. Each of these lists has `num_layers` tensors. 
The expected shape for each tensor - For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`, - while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors). - For mamba layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors), - while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`, - and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`. - """ - - def __init__( - self, config: Zamba2Config, batch_size: int, dtype: torch.dtype = torch.float16, device: str | None = None - ): - self.dtype = dtype - self.layers_block_type = config.layers_block_type - self.has_previous_state = False - self.intermediate_size = int(config.mamba_expand * config.hidden_size) - self.ssm_state_size = config.mamba_d_state - self.conv_kernel_size = config.mamba_d_conv - self.n_mamba_heads = config.n_mamba_heads - self.transformer_layers = [] - self._modules = {} - self._parameters = {} - self._buffers = {} - self.conv_states = {} - self.ssm_states = {} - for i in range(config.num_hidden_layers): - self.conv_states[i] = torch.zeros( - batch_size, - self.intermediate_size + 2 * config.mamba_ngroups * config.mamba_d_state, - self.conv_kernel_size, - device=device, - dtype=dtype, - ) - self.ssm_states[i] = torch.zeros( - batch_size, self.n_mamba_heads, config.mamba_headdim, self.ssm_state_size, device=device, dtype=dtype - ) - if self.layers_block_type[i] == "hybrid": - self.transformer_layers.append(i) - self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] - self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] - - def reset(self): - self.conv_states.zero_() - self.ssm_states.zero_() - - def get_seq_length(self, layer_idx: int | None = 0) -> int: - """Returns the sequence length of the cached states. 
A layer index can be optionally passed.""" - # take any layer that contains cache and not empty tensor - layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx - if len(self.key_cache) <= layer_idx or self.key_cache[layer_idx].numel() == 0: - return 0 - return self.key_cache[layer_idx].shape[-2] - - def get_mask_sizes(self, query_length: int, layer_idx: int) -> tuple[int, int]: - """Return the length and offset of the cache, used to generate the mask""" - kv_offset = 0 - kv_length = self.get_seq_length(layer_idx) + query_length - return kv_length, kv_offset - - class Zamba2RotaryEmbedding(LlamaRotaryEmbedding): pass @@ -208,7 +143,7 @@ def forward( hidden_states: torch.Tensor, layer_idx: int, attention_mask: torch.Tensor | None = None, - past_key_values: Zamba2HybridDynamicCache | None = None, + past_key_values: Cache | None = None, position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]: @@ -357,7 +292,7 @@ def __init__(self, config: Zamba2Config, layer_idx: int | None = None): def cuda_kernels_forward( self, hidden_states: torch.Tensor, - cache_params: Zamba2HybridDynamicCache | None = None, + cache_params: Cache | None = None, attention_mask: torch.Tensor | None = None, ): # set up dimensions for reshapes later @@ -367,7 +302,7 @@ def cuda_kernels_forward( d_to_remove = 2 * self.intermediate_size + 2 * self.n_groups * self.ssm_state_size + self.num_heads # getting projected states from cache if it exists - if cache_params is not None and cache_params.has_previous_state: + if cache_params is not None and cache_params.has_previous_state(self.layer_idx): in_projected_states = self.in_proj(hidden_states.squeeze(1)) # (B 2D) d_mlp = (in_projected_states.shape[-1] - d_to_remove) // 2 split_projection_dim = [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads] @@ -375,7 +310,7 @@ def cuda_kernels_forward( hidden_states_B_C = causal_conv1d_update( hidden_states_B_C, - cache_params.conv_states[self.layer_idx], + cache_params.layers[self.layer_idx].conv_states, self.conv1d.weight.squeeze(1), self.conv1d.bias, self.activation, @@ -396,7 +331,7 @@ def cuda_kernels_forward( C = C.view(batch_size, self.n_groups, C.shape[1] // self.n_groups) hidden_states_reshaped = hidden_states.view(batch_size, self.num_heads, self.head_dim) hidden_states = selective_state_update( - cache_params.ssm_states[self.layer_idx], + cache_params.layers[self.layer_idx].ssm_states, hidden_states_reshaped, dt, A, @@ -460,7 +395,7 @@ def cuda_kernels_forward( conv_state = nn.functional.pad( hidden_states_B_C_t, (self.conv_kernel_size - hidden_states_B_C_t.shape[-1], 0) ) - cache_params.conv_states[self.layer_idx].copy_(conv_state) + conv_state = cache_params.update_conv_state(conv_state, self.layer_idx) if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]: hidden_states_B_C = self.act( self.conv1d(hidden_states_B_C.transpose(1, 2)).transpose(1, 2)[:, :seq_len] @@ -497,7 +432,7 @@ def cuda_kernels_forward( **dt_limit_kwargs, ) if ssm_state is not None and cache_params is not None: - cache_params.ssm_states[self.layer_idx].copy_(ssm_state) + cache_params.update_ssm_state(ssm_state, self.layer_idx) scan_output = scan_output.view(batch_size, seq_len, -1) # Multiply "gate" branch and apply extra normalization layer scan_output = self.norm(scan_output, gate) @@ -505,11 +440,11 @@ def cuda_kernels_forward( return out # fmt: off - def 
torch_forward(self, input_states, cache_params: Zamba2HybridDynamicCache | None=None, attention_mask: torch.Tensor | None=None): + def torch_forward(self, input_states, cache_params: Cache | None=None, attention_mask: torch.Tensor | None=None): batch_size, seq_len, _ = input_states.shape dtype = input_states.dtype # Gated MLP's linear projection - if cache_params is not None and cache_params.has_previous_state: + if cache_params is not None and cache_params.has_previous_state(self.layer_idx): projected_states = self.in_proj(input_states.squeeze(1)) else: if attention_mask is not None: @@ -521,17 +456,15 @@ def torch_forward(self, input_states, cache_params: Zamba2HybridDynamicCache | N [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1 ) + has_previous_state = cache_params.has_previous_state(self.layer_idx) + # Convolution sequence transformation if cache_params is not None: - ssm_state = cache_params.ssm_states[self.layer_idx].clone() + ssm_state = cache_params.layers[self.layer_idx].ssm_states.clone() ssm_state = ssm_state.to(hidden_states.device) - if cache_params.has_previous_state: + if has_previous_state: gate = gate.unsqueeze(1) - conv_state = cache_params.conv_states[self.layer_idx] # [batch, intermediate_size, conv_kernel_size] - conv_state = torch.roll(conv_state, shifts=-1, dims=-1) - # handle batched generation - states are copied through - conv_state[:, :, -1] = hidden_states[:, 0, :] if hidden_states.ndim == 3 else hidden_states - cache_params.conv_states[self.layer_idx].copy_(conv_state) + conv_state = cache_params.update_conv_state(hidden_states, self.layer_idx) hidden_states = torch.sum(conv_state.to(projected_states.device) * self.conv1d.weight[:, 0, :], dim=-1) if self.use_conv_bias: hidden_states += self.conv1d.bias @@ -542,7 +475,7 @@ def torch_forward(self, input_states, cache_params: Zamba2HybridDynamicCache | N hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0) ) - cache_params.conv_states[self.layer_idx].copy_(conv_state) + conv_state = cache_params.update_conv_state(conv_state, self.layer_idx) hidden_states = self.act(self.conv1d(hidden_states).transpose(1,2))[:, :seq_len, :] # [batch, intermediate_size, seq_len] if attention_mask is not None: dtype = hidden_states.dtype @@ -556,7 +489,7 @@ def torch_forward(self, input_states, cache_params: Zamba2HybridDynamicCache | N hidden_states = self.act(self.conv1d(hidden_states.transpose(1, 2))[..., :seq_len].transpose(1, 2)) hidden_states, B, C = torch.split(hidden_states, [self.intermediate_size, self.n_groups * self.ssm_state_size, self.n_groups * self.ssm_state_size], dim=-1) A = -torch.exp(self.A_log.float()) # [num_heads] - if cache_params is not None and cache_params.has_previous_state: + if cache_params is not None and has_previous_state: # Note: there is no need to pad parameter matrices here, as there is just one new token # for batched generation dt = dt[:, None, ...] if dt.ndim == 2 else dt[:, 0, :][:, None, ...] 
@@ -585,9 +518,8 @@ def torch_forward(self, input_states, cache_params: Zamba2HybridDynamicCache | N dBx = dB * hidden_states[..., None] # State calculation - cache_params.ssm_states[self.layer_idx].copy_( - cache_params.ssm_states[self.layer_idx] * dA + dBx - ) + ssm_state = ssm_state * dA + dBx + ssm_state = cache_params.update_ssm_state(ssm_state, self.layer_idx) # Subsequent output # [bsz, n_groups * state_size] -> [bsz, num_heads, state_size] @@ -596,7 +528,7 @@ def torch_forward(self, input_states, cache_params: Zamba2HybridDynamicCache | N C = C.reshape(batch_size, -1, C.shape[-1]) # [bsz, num_heads, head_dim] - ssm_states = cache_params.ssm_states[self.layer_idx].to(C.dtype) # Shape: [b, h, d, n] + ssm_states = ssm_state.to(C.dtype) # Shape: [b, h, d, n] # Reshape ssm_states to merge the first two dimensions ssm_states_reshaped = ssm_states.view(batch_size * self.num_heads, self.head_dim, self.ssm_state_size) # Shape: [b*h, d, n] C_reshaped = C.view(batch_size * self.num_heads, self.ssm_state_size, 1) # Shape: [b*h, n, 1] @@ -657,10 +589,7 @@ def torch_forward(self, input_states, cache_params: Zamba2HybridDynamicCache | N B_decay_contraction = B * decay_states.permute(0, 2, 3, 1)[..., None] # permute back B * decay states states = (B_decay_contraction.permute(0, 1, 3, 2, 4)[..., None] * hidden_states.permute(0, 1, 3, 2, 4)[..., None, :]).sum(dim=3).permute(0, 1, 2, 4, 3) - if cache_params is not None and cache_params.has_previous_state: - previous_states = cache_params.ssm_states[self.layer_idx][:, None, ...] - else: - previous_states = torch.zeros_like(states[:, :1]) + previous_states = torch.zeros_like(states[:, :1]) states = torch.cat([previous_states, states], dim=1) decay_chunk = torch.exp(segment_sum(nn.functional.pad(A_cumsum[:, :, :, -1], (1, 0)))) @@ -688,7 +617,7 @@ def torch_forward(self, input_states, cache_params: Zamba2HybridDynamicCache | N y = y[:, :seq_len, :, :] y = y.reshape(batch_size, seq_len, -1) if ssm_state is not None and cache_params is not None: - cache_params.ssm_states[self.layer_idx].copy_(ssm_state) + cache_params.update_ssm_state(ssm_state, self.layer_idx) scan_output = self.norm(y, gate) @@ -702,7 +631,7 @@ def torch_forward(self, input_states, cache_params: Zamba2HybridDynamicCache | N def forward( self, hidden_states, - cache_params: Zamba2HybridDynamicCache | None = None, + cache_params: Cache | None = None, attention_mask: torch.Tensor | None = None, **kwargs, ): @@ -768,7 +697,7 @@ def forward( original_hidden_states: torch.Tensor, layer_idx: int, attention_mask: torch.Tensor | None = None, - past_key_values: Zamba2HybridDynamicCache | None = None, + past_key_values: Cache | None = None, position_embeddings: torch.LongTensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.FloatTensor]: @@ -781,7 +710,7 @@ def forward( (see fig. 2 in https://huggingface.co/papers/2405.16712). attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch, sequence_length)` where padding elements are indicated by 0. - past_key_values (`Zamba2HybridDynamicCache`, *optional*): cached past key and value projection states + past_key_values (`Cache`, *optional*): cached past key and value projection states use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). 
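Both the modeling and modular files now ask `cache_params.has_previous_state(self.layer_idx)` instead of reading a single flag on the cache object. The Cache-side helper is not part of these hunks; presumably it just forwards to the per-layer flag, roughly as in this hypothetical sketch (class and attribute names here are assumptions for illustration, not the merged API):

    # Hypothetical sketch of the per-layer dispatch the mixers rely on.
    class CacheSketch:
        def __init__(self, layers):
            self.layers = layers  # one cache object per model layer

        def has_previous_state(self, layer_idx: int) -> bool:
            # Mamba-style layers carry a boolean flag once a prefill has been cached;
            # attention-only layers would simply report False.
            return bool(getattr(self.layers[layer_idx], "has_previous_state", False))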
@@ -828,7 +757,7 @@ def forward( layer_idx: int | None = None, attention_mask: torch.Tensor | None = None, causal_mask: torch.Tensor | None = None, - past_key_values: Zamba2HybridDynamicCache | None = None, + past_key_values: Cache | None = None, use_cache: bool | None = False, position_embeddings: torch.LongTensor | None = None, position_ids: torch.LongTensor | None = None, @@ -842,7 +771,7 @@ def forward( layer_idx (`int`): layer number. attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch, sequence_length)` where padding elements are indicated by 0. - past_key_values (`Zamba2HybridDynamicCache`, *optional*): cached past key and value projection states + past_key_values (`Cache`, *optional*): cached past key and value projection states use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). @@ -983,7 +912,7 @@ def forward( input_ids: torch.LongTensor | None = None, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, - past_key_values: Zamba2HybridDynamicCache | None = None, + past_key_values: Cache | None = None, inputs_embeds: torch.FloatTensor | None = None, use_cache: bool | None = None, **kwargs: Unpack[TransformersKwargs], @@ -1002,8 +931,7 @@ def forward( # original_hidden_states: word embedding output that will be concatenated with hidden activations to form the input of the shared transformer layer if use_cache and past_key_values is None: - batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0] - past_key_values = Zamba2HybridDynamicCache(self.config, batch_size, dtype=self.dtype, device=self.device) + past_key_values = DynamicCache(config=self.config) if position_ids is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 @@ -1069,7 +997,7 @@ def forward( input_ids: torch.LongTensor | None = None, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, - past_key_values: Zamba2HybridDynamicCache | None = None, + past_key_values: Cache | None = None, inputs_embeds: torch.FloatTensor | None = None, labels: torch.LongTensor | None = None, use_cache: bool | None = None, From bd8f9e9a75997866c54dc086e2d70e6ee0ca25e8 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 24 Mar 2026 13:03:13 +0100 Subject: [PATCH 15/56] =?UTF-8?q?add=20config=20map=C3=A8ping?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/transformers/models/zamba/configuration_zamba.py | 1 + src/transformers/models/zamba2/configuration_zamba2.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/zamba/configuration_zamba.py b/src/transformers/models/zamba/configuration_zamba.py index f0793f9c165e..2fd99acf2207 100644 --- a/src/transformers/models/zamba/configuration_zamba.py +++ b/src/transformers/models/zamba/configuration_zamba.py @@ -53,6 +53,7 @@ class ZambaConfig(PreTrainedConfig): model_type = "zamba" keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = {"layer_types": "layers_block_type"} vocab_size: int = 32000 tie_word_embeddings: bool = True diff --git a/src/transformers/models/zamba2/configuration_zamba2.py b/src/transformers/models/zamba2/configuration_zamba2.py index d5f2673dc08d..f9b56fabf6d4 100644 --- a/src/transformers/models/zamba2/configuration_zamba2.py +++ b/src/transformers/models/zamba2/configuration_zamba2.py @@ 
-66,7 +66,7 @@ class Zamba2Config(PreTrainedConfig): ```""" model_type = "zamba2" - attribute_map = {"head_dim": "attention_head_dim"} + attribute_map = {"head_dim": "attention_head_dim", "layer_types": "layers_block_type"} keys_to_ignore_at_inference = ["past_key_values"] vocab_size: int = 32000 From b2f1bb8edd26dc362093a1d23dd89b7fd3ad5a37 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 24 Mar 2026 13:12:55 +0100 Subject: [PATCH 16/56] adjust tests --- tests/models/zamba/test_modeling_zamba.py | 43 ++++++++++----------- tests/models/zamba2/test_modeling_zamba2.py | 43 ++++++++++----------- 2 files changed, 42 insertions(+), 44 deletions(-) diff --git a/tests/models/zamba/test_modeling_zamba.py b/tests/models/zamba/test_modeling_zamba.py index c5d0c81be98c..2bd7dd2c3bc9 100644 --- a/tests/models/zamba/test_modeling_zamba.py +++ b/tests/models/zamba/test_modeling_zamba.py @@ -38,14 +38,8 @@ if is_torch_available(): import torch - from transformers import ( - ZambaForCausalLM, - ZambaForSequenceClassification, - ZambaModel, - ) - from transformers.models.zamba.modeling_zamba import ( - ZambaHybridDynamicCache, - ) + from transformers import DynamicCache, ZambaForCausalLM, ZambaForSequenceClassification, ZambaModel + from transformers.cache_utils import MambaLayer class ZambaModelTester: @@ -212,12 +206,9 @@ def create_and_check_decoder_model_past_large_inputs( model.eval() # first forward pass - # Attention: Zamba needs the cache to be initialized to return a cache! - past_key_values = ZambaHybridDynamicCache(config, input_ids.shape[0], model.dtype, device=model.device) outputs = model( input_ids, attention_mask=input_mask, - past_key_values=past_key_values, use_cache=True, ) past_key_values = outputs.past_key_values @@ -299,7 +290,7 @@ class ZambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi ) def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config): - self.assertIsInstance(past_key_values, ZambaHybridDynamicCache) + self.assertIsInstance(past_key_values, DynamicCache) # (batch, kv heads, seq_length, head_dim) num_heads = getattr(config, "num_key_value_heads", config.num_attention_heads) @@ -314,14 +305,16 @@ def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_l for idx in range(len(past_key_values)): if config.layers_block_type[idx] == "mamba": - self.assertEqual(past_key_values.conv_states[idx].shape, conv_shape) - self.assertEqual(past_key_values.ssm_states[idx].shape, ssm_shape) + self.assertEqual(past_key_values.layers[idx].conv_states.shape, conv_shape) + self.assertEqual(past_key_values.layers[idx].ssm_states.shape, ssm_shape) else: - self.assertEqual(past_key_values.key_cache[idx].shape, attention_shape) - self.assertEqual(past_key_values.value_cache[idx].shape, attention_shape) + self.assertEqual(past_key_values.layers[idx].conv_states.shape, conv_shape) + self.assertEqual(past_key_values.layers[idx].ssm_states.shape, ssm_shape) + self.assertEqual(past_key_values.layers[idx].key_cache.shape, attention_shape) + self.assertEqual(past_key_values.layers[idx].value_cache.shape, attention_shape) - def _check_caches_are_equal(self, cache1: ZambaHybridDynamicCache, cache2: ZambaHybridDynamicCache): - if not isinstance(cache1, ZambaHybridDynamicCache) or not isinstance(cache2, ZambaHybridDynamicCache): + def _check_caches_are_equal(self, cache1: DynamicCache, cache2: DynamicCache): + if not isinstance(cache1, DynamicCache) or not isinstance(cache2, DynamicCache): raise ValueError("The wrong 
cache is being used!") if not len(cache1) == len(cache2): @@ -329,10 +322,16 @@ def _check_caches_are_equal(self, cache1: ZambaHybridDynamicCache, cache2: Zamba num_layers = len(cache1) for idx in range(num_layers): - torch.testing.assert_close(cache1.key_cache[idx], cache2.key_cache[idx]) - torch.testing.assert_close(cache1.value_cache[idx], cache2.value_cache[idx]) - torch.testing.assert_close(cache1.conv_states[idx], cache2.conv_states[idx]) - torch.testing.assert_close(cache1.ssm_states[idx], cache2.ssm_states[idx]) + # Mamba layer + if type(cache1.layers[idx]) is MambaLayer: + torch.testing.assert_close(cache1.layers[idx].conv_states, cache2.layers[idx].conv_states) + torch.testing.assert_close(cache1.layers[idx].ssm_states, cache2.layers[idx].ssm_states) + # Hybrid mamba + attention layer + else: + torch.testing.assert_close(cache1.layers[idx].key_cache, cache2.layers[idx].key_cache) + torch.testing.assert_close(cache1.layers[idx].value_cache, cache2.layers[idx].value_cache) + torch.testing.assert_close(cache1.layers[idx].conv_states, cache2.layers[idx].conv_states) + torch.testing.assert_close(cache1.layers[idx].ssm_states, cache2.layers[idx].ssm_states) def setUp(self): self.model_tester = ZambaModelTester(self) diff --git a/tests/models/zamba2/test_modeling_zamba2.py b/tests/models/zamba2/test_modeling_zamba2.py index 8a9d168fe0c5..632855f2138d 100644 --- a/tests/models/zamba2/test_modeling_zamba2.py +++ b/tests/models/zamba2/test_modeling_zamba2.py @@ -39,14 +39,8 @@ if is_torch_available(): import torch - from transformers import ( - Zamba2ForCausalLM, - Zamba2ForSequenceClassification, - Zamba2Model, - ) - from transformers.models.zamba2.modeling_zamba2 import ( - Zamba2HybridDynamicCache, - ) + from transformers import DynamicCache, Zamba2ForCausalLM, Zamba2ForSequenceClassification, Zamba2Model + from transformers.cache_utils import MambaLayer class Zamba2ModelTester: @@ -222,12 +216,9 @@ def create_and_check_decoder_model_past_large_inputs( model.eval() # first forward pass - # Attention: Zamba2 needs the cache to be initialized to return a cache! 
- past_key_values = Zamba2HybridDynamicCache(config, input_ids.shape[0], model.dtype, device=model.device) outputs = model( input_ids, attention_mask=input_mask, - past_key_values=past_key_values, use_cache=True, ) past_key_values = outputs.past_key_values @@ -309,7 +300,7 @@ class Zamba2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix ) def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config): - self.assertIsInstance(past_key_values, Zamba2HybridDynamicCache) + self.assertIsInstance(past_key_values, DynamicCache) # (batch, kv heads, seq_length, head_dim) num_heads = getattr(config, "num_key_value_heads", config.num_attention_heads) @@ -328,14 +319,16 @@ def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_l for idx in range(len(past_key_values)): if config.layers_block_type[idx] == "mamba": - self.assertEqual(past_key_values.conv_states[idx].shape, conv_shape) - self.assertEqual(past_key_values.ssm_states[idx].shape, ssm_shape) + self.assertEqual(past_key_values.layers[idx].conv_states.shape, conv_shape) + self.assertEqual(past_key_values.layers[idx].ssm_states.shape, ssm_shape) else: - self.assertEqual(past_key_values.key_cache[idx].shape, attention_shape) - self.assertEqual(past_key_values.value_cache[idx].shape, attention_shape) + self.assertEqual(past_key_values.layers[idx].conv_states.shape, conv_shape) + self.assertEqual(past_key_values.layers[idx].ssm_states.shape, ssm_shape) + self.assertEqual(past_key_values.layers[idx].key_cache.shape, attention_shape) + self.assertEqual(past_key_values.layers[idx].value_cache.shape, attention_shape) - def _check_caches_are_equal(self, cache1: Zamba2HybridDynamicCache, cache2: Zamba2HybridDynamicCache): - if not isinstance(cache1, Zamba2HybridDynamicCache) or not isinstance(cache2, Zamba2HybridDynamicCache): + def _check_caches_are_equal(self, cache1: DynamicCache, cache2: DynamicCache): + if not isinstance(cache1, DynamicCache) or not isinstance(cache2, DynamicCache): raise ValueError("The wrong cache is being used!") if not len(cache1) == len(cache2): @@ -343,10 +336,16 @@ def _check_caches_are_equal(self, cache1: Zamba2HybridDynamicCache, cache2: Zamb num_layers = len(cache1) for idx in range(num_layers): - torch.testing.assert_close(cache1.key_cache[idx], cache2.key_cache[idx]) - torch.testing.assert_close(cache1.value_cache[idx], cache2.value_cache[idx]) - torch.testing.assert_close(cache1.conv_states[idx], cache2.conv_states[idx]) - torch.testing.assert_close(cache1.ssm_states[idx], cache2.ssm_states[idx]) + # Mamba layer + if type(cache1.layers[idx]) is MambaLayer: + torch.testing.assert_close(cache1.layers[idx].conv_states, cache2.layers[idx].conv_states) + torch.testing.assert_close(cache1.layers[idx].ssm_states, cache2.layers[idx].ssm_states) + # Hybrid mamba + attention layer + else: + torch.testing.assert_close(cache1.layers[idx].key_cache, cache2.layers[idx].key_cache) + torch.testing.assert_close(cache1.layers[idx].value_cache, cache2.layers[idx].value_cache) + torch.testing.assert_close(cache1.layers[idx].conv_states, cache2.layers[idx].conv_states) + torch.testing.assert_close(cache1.layers[idx].ssm_states, cache2.layers[idx].ssm_states) def setUp(self): self.model_tester = Zamba2ModelTester(self) From 7795808775ad7a68b88a708355597794c213bb2e Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 24 Mar 2026 13:27:51 +0100 Subject: [PATCH 17/56] fix --- src/transformers/configuration_utils.py | 1 + src/transformers/models/zamba/modeling_zamba.py | 
4 +++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index 5ff548c523f6..4d7034f9b1d3 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -68,6 +68,7 @@ "attention", "sparse", "dense", + "hybrid", # for layers that have both mamba and attention in zamba and zamba2 ) diff --git a/src/transformers/models/zamba/modeling_zamba.py b/src/transformers/models/zamba/modeling_zamba.py index dd6bf0ed5414..feefa4d7d55c 100644 --- a/src/transformers/models/zamba/modeling_zamba.py +++ b/src/transformers/models/zamba/modeling_zamba.py @@ -379,7 +379,9 @@ def slow_forward(self, input_states, cache_params: Cache | None = None, attentio ssm_state = cache_params.layers[self.layer_idx].ssm_states.clone() else: ssm_state = torch.zeros( - (batch_size, self.intermediate_size, self.ssm_state_size), device=hidden_states.device, dtype=dtype + (batch_size, self.n_mamba_heads, self.mamba_head_dim, self.ssm_state_size), + device=hidden_states.device, + dtype=dtype, ) # 2. Convolution sequence transformation From 18685c605bc5ae83e56b0fcde1faf4b4f57b4bb8 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 24 Mar 2026 14:28:43 +0100 Subject: [PATCH 18/56] fix --- tests/models/zamba/test_modeling_zamba.py | 8 ++++---- tests/models/zamba2/test_modeling_zamba2.py | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/models/zamba/test_modeling_zamba.py b/tests/models/zamba/test_modeling_zamba.py index 2bd7dd2c3bc9..4ce37dd638a7 100644 --- a/tests/models/zamba/test_modeling_zamba.py +++ b/tests/models/zamba/test_modeling_zamba.py @@ -310,8 +310,8 @@ def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_l else: self.assertEqual(past_key_values.layers[idx].conv_states.shape, conv_shape) self.assertEqual(past_key_values.layers[idx].ssm_states.shape, ssm_shape) - self.assertEqual(past_key_values.layers[idx].key_cache.shape, attention_shape) - self.assertEqual(past_key_values.layers[idx].value_cache.shape, attention_shape) + self.assertEqual(past_key_values.layers[idx].keys.shape, attention_shape) + self.assertEqual(past_key_values.layers[idx].values.shape, attention_shape) def _check_caches_are_equal(self, cache1: DynamicCache, cache2: DynamicCache): if not isinstance(cache1, DynamicCache) or not isinstance(cache2, DynamicCache): @@ -328,8 +328,8 @@ def _check_caches_are_equal(self, cache1: DynamicCache, cache2: DynamicCache): torch.testing.assert_close(cache1.layers[idx].ssm_states, cache2.layers[idx].ssm_states) # Hybrid mamba + attention layer else: - torch.testing.assert_close(cache1.layers[idx].key_cache, cache2.layers[idx].key_cache) - torch.testing.assert_close(cache1.layers[idx].value_cache, cache2.layers[idx].value_cache) + torch.testing.assert_close(cache1.layers[idx].keys, cache2.layers[idx].keys) + torch.testing.assert_close(cache1.layers[idx].values, cache2.layers[idx].values) torch.testing.assert_close(cache1.layers[idx].conv_states, cache2.layers[idx].conv_states) torch.testing.assert_close(cache1.layers[idx].ssm_states, cache2.layers[idx].ssm_states) diff --git a/tests/models/zamba2/test_modeling_zamba2.py b/tests/models/zamba2/test_modeling_zamba2.py index 632855f2138d..0760c5e60789 100644 --- a/tests/models/zamba2/test_modeling_zamba2.py +++ b/tests/models/zamba2/test_modeling_zamba2.py @@ -324,8 +324,8 @@ def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_l else: 
self.assertEqual(past_key_values.layers[idx].conv_states.shape, conv_shape) self.assertEqual(past_key_values.layers[idx].ssm_states.shape, ssm_shape) - self.assertEqual(past_key_values.layers[idx].key_cache.shape, attention_shape) - self.assertEqual(past_key_values.layers[idx].value_cache.shape, attention_shape) + self.assertEqual(past_key_values.layers[idx].keys.shape, attention_shape) + self.assertEqual(past_key_values.layers[idx].values.shape, attention_shape) def _check_caches_are_equal(self, cache1: DynamicCache, cache2: DynamicCache): if not isinstance(cache1, DynamicCache) or not isinstance(cache2, DynamicCache): @@ -342,8 +342,8 @@ def _check_caches_are_equal(self, cache1: DynamicCache, cache2: DynamicCache): torch.testing.assert_close(cache1.layers[idx].ssm_states, cache2.layers[idx].ssm_states) # Hybrid mamba + attention layer else: - torch.testing.assert_close(cache1.layers[idx].key_cache, cache2.layers[idx].key_cache) - torch.testing.assert_close(cache1.layers[idx].value_cache, cache2.layers[idx].value_cache) + torch.testing.assert_close(cache1.layers[idx].keys, cache2.layers[idx].keys) + torch.testing.assert_close(cache1.layers[idx].values, cache2.layers[idx].values) torch.testing.assert_close(cache1.layers[idx].conv_states, cache2.layers[idx].conv_states) torch.testing.assert_close(cache1.layers[idx].ssm_states, cache2.layers[idx].ssm_states) From fcec6bc0f9f9e3337d5cc1c1756bef9890b0380f Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 24 Mar 2026 14:45:50 +0100 Subject: [PATCH 19/56] fix --- src/transformers/models/zamba2/modeling_zamba2.py | 14 ++++++++------ src/transformers/models/zamba2/modular_zamba2.py | 14 ++++++++------ 2 files changed, 16 insertions(+), 12 deletions(-) diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index c882353a8639..2feaf7f7bcba 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -670,10 +670,16 @@ def torch_forward(self, input_states, cache_params: Cache | None=None, attention has_previous_state = cache_params.has_previous_state(self.layer_idx) + if cache_params is not None and has_previous_state: + ssm_state = cache_params.layers[self.layer_idx].ssm_states.clone() + else: + ssm_state = torch.zeros( + (batch_size, self.num_heads, self.head_dim, self.ssm_state_size), + device=hidden_states.device, dtype=dtype + ) + # Convolution sequence transformation if cache_params is not None: - ssm_state = cache_params.layers[self.layer_idx].ssm_states.clone() - ssm_state = ssm_state.to(hidden_states.device) if has_previous_state: gate = gate.unsqueeze(1) conv_state = cache_params.update_conv_state(hidden_states, self.layer_idx) @@ -694,10 +700,6 @@ def torch_forward(self, input_states, cache_params: Cache | None=None, attention # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66 hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) else: - ssm_state = torch.zeros( - (batch_size, self.num_heads, self.head_dim, self.ssm_state_size), - device=hidden_states.device, dtype=dtype - ) hidden_states = self.act(self.conv1d(hidden_states.transpose(1, 2))[..., :seq_len].transpose(1, 2)) hidden_states, B, C = torch.split(hidden_states, [self.intermediate_size, self.n_groups * self.ssm_state_size, self.n_groups * self.ssm_state_size], dim=-1) A = -torch.exp(self.A_log.float()) # [num_heads] diff --git a/src/transformers/models/zamba2/modular_zamba2.py 
b/src/transformers/models/zamba2/modular_zamba2.py index fd32a7d217a8..6a4ab477d731 100644 --- a/src/transformers/models/zamba2/modular_zamba2.py +++ b/src/transformers/models/zamba2/modular_zamba2.py @@ -458,10 +458,16 @@ def torch_forward(self, input_states, cache_params: Cache | None=None, attention has_previous_state = cache_params.has_previous_state(self.layer_idx) + if cache_params is not None and has_previous_state: + ssm_state = cache_params.layers[self.layer_idx].ssm_states.clone() + else: + ssm_state = torch.zeros( + (batch_size, self.num_heads, self.head_dim, self.ssm_state_size), + device=hidden_states.device, dtype=dtype + ) + # Convolution sequence transformation if cache_params is not None: - ssm_state = cache_params.layers[self.layer_idx].ssm_states.clone() - ssm_state = ssm_state.to(hidden_states.device) if has_previous_state: gate = gate.unsqueeze(1) conv_state = cache_params.update_conv_state(hidden_states, self.layer_idx) @@ -482,10 +488,6 @@ def torch_forward(self, input_states, cache_params: Cache | None=None, attention # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66 hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) else: - ssm_state = torch.zeros( - (batch_size, self.num_heads, self.head_dim, self.ssm_state_size), - device=hidden_states.device, dtype=dtype - ) hidden_states = self.act(self.conv1d(hidden_states.transpose(1, 2))[..., :seq_len].transpose(1, 2)) hidden_states, B, C = torch.split(hidden_states, [self.intermediate_size, self.n_groups * self.ssm_state_size, self.n_groups * self.ssm_state_size], dim=-1) A = -torch.exp(self.A_log.float()) # [num_heads] From b1df43f64fe4d9821c1fc541eb14235edd82b1ba Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 24 Mar 2026 15:19:42 +0100 Subject: [PATCH 20/56] more models --- .../models/bamba/configuration_bamba.py | 1 + .../models/bamba/modeling_bamba.py | 191 +++------------- .../models/bamba/modular_bamba.py | 162 ++++---------- .../configuration_granitemoehybrid.py | 9 +- .../modeling_granitemoehybrid.py | 207 +++--------------- .../modular_granitemoehybrid.py | 38 +--- .../models/zamba2/modeling_zamba2.py | 49 ++--- .../models/zamba2/modular_zamba2.py | 49 ++--- tests/models/bamba/test_modeling_bamba.py | 42 ++-- .../test_modeling_granitemoehybrid.py | 37 ++-- 10 files changed, 175 insertions(+), 610 deletions(-) diff --git a/src/transformers/models/bamba/configuration_bamba.py b/src/transformers/models/bamba/configuration_bamba.py index d9b88cb8c95d..5755ff3d1d55 100644 --- a/src/transformers/models/bamba/configuration_bamba.py +++ b/src/transformers/models/bamba/configuration_bamba.py @@ -43,6 +43,7 @@ class BambaConfig(PreTrainedConfig): """ model_type = "bamba" + attribute_map = {"layer_types": "layers_block_type"} keys_to_ignore_at_inference = ["past_key_values"] vocab_size: int = 128000 diff --git a/src/transformers/models/bamba/modeling_bamba.py b/src/transformers/models/bamba/modeling_bamba.py index 8888fa5ddbdf..d9a8ca68576d 100644 --- a/src/transformers/models/bamba/modeling_bamba.py +++ b/src/transformers/models/bamba/modeling_bamba.py @@ -24,15 +24,14 @@ # limitations under the License. from collections.abc import Callable -from typing import Any, Optional, TypedDict +from typing import Optional, TypedDict import torch from torch import nn -from transformers.activations import ACT2FN - from ... 
import initialization as init -from ...cache_utils import Cache +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub, use_kernelized_func from ...integrations.hub_kernels import lazy_load_kernel @@ -76,112 +75,6 @@ class BambaFlashAttentionKwargs(TypedDict, total=False): seq_idx: torch.IntTensor -class HybridMambaAttentionDynamicCache: - """ - A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache - (which has a constant shape regardless of seq_len). - - This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states` - and `ssm_states` for mamba cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor - For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`, - while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors). - For mamba layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors), - while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`, - and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`. - """ - - is_compileable = False - - def __init__(self, config: BambaConfig, batch_size, dtype=torch.float16, device=None): - self.layers_block_type = config.layers_block_type - self.has_previous_state = False # only used by mamba - conv_kernel_size = config.mamba_d_conv - ssm_state_size = config.mamba_d_state - - self.conv_states = [] - self.ssm_states = [] - self.transformer_layers = [] - for i in range(config.num_hidden_layers): - if self.layers_block_type[i] == "mamba": - self.conv_states += [ - torch.zeros( - batch_size, - (config.mamba_expand * config.hidden_size + 2 * config.mamba_n_groups * ssm_state_size), - conv_kernel_size, - device=device, - dtype=dtype, - ) - ] - self.ssm_states += [ - torch.zeros( - batch_size, - config.mamba_n_heads, - config.mamba_d_head, - ssm_state_size, - device=device, - dtype=dtype, - ) - ] - else: - self.conv_states += [torch.tensor([[]] * batch_size, device=device)] - self.ssm_states += [torch.tensor([[]] * batch_size, device=device)] - self.transformer_layers.append(i) - - self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] - self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] - - def __len__(self): - return len(self.key_cache) - - def __getitem__(self, layer_idx): - return self.key_cache[layer_idx], self.value_cache[layer_idx] - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: dict[str, Any] | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - # Update the cache - if self.key_cache[layer_idx].shape[-1] == 0: - self.key_cache[layer_idx] = key_states - self.value_cache[layer_idx] = value_states - else: - self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2) - self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2) - - return self.key_cache[layer_idx], self.value_cache[layer_idx] - - def reorder_cache(self, beam_idx: torch.LongTensor): - """Reorders the cache for beam search, given the selected beam indices.""" - if self.get_seq_length() > 0: - for 
layer_idx in range(len(self.key_cache)): - device = self.key_cache[layer_idx].device - self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) - device = self.value_cache[layer_idx].device - self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) - - device = self.conv_states[layer_idx].device - self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx.to(device)) - device = self.ssm_states[layer_idx].device - self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device)) - - def get_mask_sizes(self, query_length: int, layer_idx: int) -> tuple[int, int]: - """Return the length and offset of the cache, used to generate the mask""" - kv_offset = 0 - kv_length = self.get_seq_length(layer_idx) + query_length - return kv_length, kv_offset - - def get_seq_length(self, layer_idx: int | None = 0) -> int: - """Returns the sequence length of the cached states. A layer index can be optionally passed.""" - # take any layer that contains cache and not empty tensor - layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx - if len(self.key_cache) <= layer_idx or self.key_cache[layer_idx].shape[-1] == 0: - return 0 - return self.key_cache[layer_idx].shape[-2] - - class BambaRotaryEmbedding(nn.Module): inv_freq: torch.Tensor # fix linting for `register_buffer` @@ -592,7 +485,7 @@ def __init__(self, config: BambaConfig, layer_idx: int): def cuda_kernels_forward( self, hidden_states: torch.Tensor, - cache_params: HybridMambaAttentionDynamicCache | None = None, + cache_params: Cache | None = None, attention_mask: torch.Tensor | None = None, seq_idx: torch.IntTensor | None = None, ): @@ -605,12 +498,7 @@ def cuda_kernels_forward( groups_time_state_size = self.n_groups * self.ssm_state_size use_precomputed_states = ( - cache_params is not None - and cache_params.has_previous_state - and seq_len == 1 - and cache_params.conv_states[self.layer_idx].shape[0] - == cache_params.ssm_states[self.layer_idx].shape[0] - == batch_size + cache_params is not None and cache_params.has_previous_state(self.layer_idx) and seq_len == 1 ) # getting projected states from cache if it exists @@ -622,7 +510,7 @@ def cuda_kernels_forward( # 2. 
Convolution sequence transformation hidden_states_B_C = causal_conv1d_update( hidden_states_B_C, - cache_params.conv_states[self.layer_idx], + cache_params.layers[self.layer_idx].conv_states, self.conv1d.weight.squeeze(1), self.conv1d.bias, self.activation, ) @@ -644,7 +532,7 @@ def cuda_kernels_forward( C = C.view(batch_size, self.n_groups, C.shape[1] // self.n_groups) hidden_states_reshaped = hidden_states.view(batch_size, self.num_heads, self.head_dim) hidden_states = selective_state_update( - cache_params.ssm_states[self.layer_idx], + cache_params.layers[self.layer_idx].ssm_states, hidden_states_reshaped, dt, A, @@ -704,7 +592,7 @@ def cuda_kernels_forward( hidden_states_B_C_transposed, (self.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0), ) - cache_params.conv_states[self.layer_idx].copy_(conv_states) + conv_states = cache_params.update_conv_state(conv_states, self.layer_idx) if self.activation not in ["silu", "swish"]: hidden_states_B_C = self.act( @@ -745,7 +633,7 @@ def cuda_kernels_forward( # Init cache if ssm_state is not None and cache_params is not None: - cache_params.ssm_states[self.layer_idx].copy_(ssm_state) + ssm_state = cache_params.update_ssm_state(ssm_state, self.layer_idx) scan_output = scan_output.view(batch_size, seq_len, -1) # Multiply "gate" branch and apply extra normalization layer @@ -759,7 +647,7 @@ def cuda_kernels_forward( def torch_forward( self, input_states, - cache_params: HybridMambaAttentionDynamicCache | None = None, + cache_params: Cache | None = None, attention_mask: torch.Tensor | None = None, ): batch_size, seq_len, _ = input_states.shape @@ -771,23 +659,13 @@ def torch_forward( gate, hidden_states_B_C, dt = projected_states.split( [self.intermediate_size, self.conv_dim, self.num_heads], dim=-1 ) + hidden_states_B_C = hidden_states_B_C.transpose(1,2) - use_precomputed_states = ( - cache_params is not None - and cache_params.has_previous_state - and seq_len == 1 - and cache_params.conv_states[self.layer_idx].shape[0] - == cache_params.ssm_states[self.layer_idx].shape[0] - == batch_size - ) + use_precomputed_states = cache_params is not None and cache_params.has_previous_state(self.layer_idx) and seq_len == 1 # 2.
Convolution sequence transformation if use_precomputed_states: - cache_params.conv_states[self.layer_idx] = cache_params.conv_states[self.layer_idx].roll(shifts=-1, dims=-1) - cache_params.conv_states[self.layer_idx][:, :, -1] = hidden_states_B_C[:, 0, :].to(cache_params.conv_states[self.layer_idx].device) - - # We need to guarantee that anything regarding the cache is on the same device - conv_states = cache_params.conv_states[self.layer_idx].to(device=self.conv1d.weight.device) + conv_states = cache_params.update_conv_state(hidden_states_B_C, self.layer_idx) hidden_states_B_C = torch.sum( conv_states * self.conv1d.weight.squeeze(1), dim=-1 @@ -798,13 +676,12 @@ def torch_forward( else: # Init cache if cache_params is not None: - hidden_states_B_C_transposed = hidden_states_B_C.transpose(1, 2) conv_states = nn.functional.pad( - hidden_states_B_C_transposed, (self.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0) + hidden_states_B_C, (self.conv_kernel_size - hidden_states_B_C.shape[-1], 0) ) - cache_params.conv_states[self.layer_idx].copy_(conv_states) + conv_states = cache_params.update_conv_state(conv_states, self.layer_idx) - hidden_states_B_C = self.act(self.conv1d(hidden_states_B_C.transpose(1, 2))[..., :seq_len].transpose(1, 2)) + hidden_states_B_C = self.act(self.conv1d(hidden_states_B_C)[..., :seq_len].transpose(1, 2)) hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask) hidden_states, B, C = torch.split( @@ -817,7 +694,7 @@ def torch_forward( A = -torch.exp(self.A_log.float()) # [num_heads] if use_precomputed_states: # We need to guarantee that anything regarding the cache is on the same device - cache_device = cache_params.ssm_states[self.layer_idx].device + cache_device = cache_params.layers[self.layer_idx].ssm_states.device # Note: there is no need to pad parameter matrices here, as there is just one new token # for batched generation @@ -847,9 +724,8 @@ def torch_forward( dBx = (dB * hidden_states[..., None]).to(device=cache_device) # State calculation - cache_params.ssm_states[self.layer_idx].copy_( - cache_params.ssm_states[self.layer_idx] * dA + dBx - ) + ssm_states = cache_params.layers[self.layer_idx].ssm_states * dA + dBx + ssm_states = cache_params.update_ssm_state(ssm_states, self.layer_idx) # Subsequent output # [bsz, n_groups * state_size] -> [bsz, num_heads, state_size] @@ -858,7 +734,7 @@ def torch_forward( C = C.reshape(batch_size, -1, C.shape[-1]) # [bsz, num_heads, head_dim] - ssm_states = cache_params.ssm_states[self.layer_idx].to(device=C.device, dtype=C.dtype) # Shape: [b, h, d, n] + ssm_states = ssm_states.to(device=C.device, dtype=C.dtype) # Shape: [b, h, d, n] # Reshape ssm_states to merge the first two dimensions ssm_states_reshaped = ssm_states.view(batch_size * self.num_heads, self.head_dim, self.ssm_state_size) # Shape: [b*h, d, n] C_reshaped = C.view(batch_size * self.num_heads, self.ssm_state_size, 1) # Shape: [b*h, n, 1] @@ -919,10 +795,7 @@ def torch_forward( # 3. 
Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries # (middle term of factorization of off-diag blocks; A terms) - if use_precomputed_states: - previous_states = cache_params.ssm_states[self.layer_idx][:, None, ...].to(device=states.device) - else: - previous_states = torch.zeros_like(states[:, :1]) + previous_states = torch.zeros_like(states[:, :1]) states = torch.cat([previous_states, states], dim=1) decay_chunk = torch.exp(segment_sum(nn.functional.pad(A_cumsum[:, :, :, -1], (1, 0)))) decay_chunk = decay_chunk.transpose(1, 3) @@ -949,7 +822,7 @@ def torch_forward( # Init cache if ssm_state is not None and cache_params is not None: - cache_params.ssm_states[self.layer_idx].copy_(ssm_state) + ssm_state = cache_params.update_ssm_state(ssm_state, self.layer_idx) scan_output = self.norm(y, gate) @@ -963,7 +836,7 @@ def torch_forward( def forward( self, hidden_states, - cache_params: HybridMambaAttentionDynamicCache | None = None, + cache_params: Cache | None = None, attention_mask: torch.Tensor | None = None, seq_idx: torch.IntTensor | None = None, **kwargs, @@ -1042,7 +915,7 @@ def forward( hidden_states: torch.Tensor, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, - past_key_values: HybridMambaAttentionDynamicCache | None = None, + past_key_values: Cache | None = None, use_cache: bool | None = False, position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, **kwargs: Unpack[BambaFlashAttentionKwargs], @@ -1133,7 +1006,7 @@ def forward( input_ids: torch.LongTensor | None = None, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, - past_key_values: HybridMambaAttentionDynamicCache | None = None, + past_key_values: Cache | None = None, inputs_embeds: torch.FloatTensor | None = None, use_cache: bool | None = None, **kwargs: Unpack[BambaFlashAttentionKwargs], @@ -1146,10 +1019,7 @@ def forward( hidden_states = inputs_embeds if use_cache and past_key_values is None: - logger.warning_once( - "Bamba requires an initialized `HybridMambaAttentionDynamicCache` to return a cache. None was " - "provided, so no cache will be returned." 
- ) + past_key_values = DynamicCache(config=self.config) if position_ids is None: position_ids = torch.arange(hidden_states.shape[1], device=hidden_states.device).unsqueeze(0) @@ -1226,7 +1096,7 @@ def forward( input_ids: torch.LongTensor | None = None, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, - past_key_values: HybridMambaAttentionDynamicCache | None = None, + past_key_values: Cache | None = None, inputs_embeds: torch.FloatTensor | None = None, labels: torch.LongTensor | None = None, use_cache: bool | None = None, @@ -1295,13 +1165,6 @@ def prepare_inputs_for_generation( is_first_iteration=False, **kwargs, ): - # Overwritten -- has a unique cache type, `HybridMambaAttentionDynamicCache` - - if past_key_values is None: - past_key_values = HybridMambaAttentionDynamicCache( - self.config, input_ids.shape[0], self.dtype, device=self.device - ) - kwargs["logits_to_keep"] = self.config.num_logits_to_keep model_inputs = super().prepare_inputs_for_generation( input_ids, diff --git a/src/transformers/models/bamba/modular_bamba.py b/src/transformers/models/bamba/modular_bamba.py index 4961025f1743..a4ba833f5565 100644 --- a/src/transformers/models/bamba/modular_bamba.py +++ b/src/transformers/models/bamba/modular_bamba.py @@ -23,9 +23,20 @@ import torch from torch import nn -from transformers.activations import ACT2FN -from transformers.models.jamba.modeling_jamba import HybridMambaAttentionDynamicCache, JambaAttentionDecoderLayer -from transformers.models.llama.modeling_llama import ( +from ... import initialization as init +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache +from ...integrations.hub_kernels import lazy_load_kernel +from ...masking_utils import create_causal_mask +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_utils import PreTrainedModel +from ...processing_utils import Unpack +from ...utils import auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging +from ...utils.generic import merge_with_config_defaults +from ...utils.import_utils import resolve_internal_import +from ...utils.output_capturing import capture_outputs +from ..jamba.modeling_jamba import JambaAttentionDecoderLayer +from ..llama.modeling_llama import ( LlamaAttention, LlamaForCausalLM, LlamaMLP, @@ -33,24 +44,13 @@ LlamaRotaryEmbedding, rotate_half, ) -from transformers.models.mamba2.modeling_mamba2 import ( +from ..mamba2.modeling_mamba2 import ( MambaRMSNormGated, apply_mask_to_padding_states, pad_tensor_by_size, reshape_into_chunks, segment_sum, ) - -from ... 
import initialization as init -from ...integrations.hub_kernels import lazy_load_kernel -from ...masking_utils import create_causal_mask -from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast -from ...modeling_utils import PreTrainedModel -from ...processing_utils import Unpack -from ...utils import auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging -from ...utils.generic import merge_with_config_defaults -from ...utils.import_utils import resolve_internal_import -from ...utils.output_capturing import capture_outputs from .configuration_bamba import BambaConfig @@ -81,60 +81,6 @@ class BambaFlashAttentionKwargs(TypedDict, total=False): seq_idx: torch.IntTensor -# Adapted from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache for the v2 mixer -class HybridMambaAttentionDynamicCache(HybridMambaAttentionDynamicCache): - """ - A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache - (which has a constant shape regardless of seq_len). - - This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states` - and `ssm_states` for mamba cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor - For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`, - while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors). - For mamba layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors), - while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`, - and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`. 
- """ - - def __init__(self, config: BambaConfig, batch_size, dtype=torch.float16, device=None): - self.layers_block_type = config.layers_block_type - self.has_previous_state = False # only used by mamba - conv_kernel_size = config.mamba_d_conv - ssm_state_size = config.mamba_d_state - - self.conv_states = [] - self.ssm_states = [] - self.transformer_layers = [] - for i in range(config.num_hidden_layers): - if self.layers_block_type[i] == "mamba": - self.conv_states += [ - torch.zeros( - batch_size, - (config.mamba_expand * config.hidden_size + 2 * config.mamba_n_groups * ssm_state_size), - conv_kernel_size, - device=device, - dtype=dtype, - ) - ] - self.ssm_states += [ - torch.zeros( - batch_size, - config.mamba_n_heads, - config.mamba_d_head, - ssm_state_size, - device=device, - dtype=dtype, - ) - ] - else: - self.conv_states += [torch.tensor([[]] * batch_size, device=device)] - self.ssm_states += [torch.tensor([[]] * batch_size, device=device)] - self.transformer_layers.append(i) - - self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] - self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] - - class BambaRotaryEmbedding(LlamaRotaryEmbedding): pass @@ -296,7 +242,7 @@ def __init__(self, config: BambaConfig, layer_idx: int): def cuda_kernels_forward( self, hidden_states: torch.Tensor, - cache_params: HybridMambaAttentionDynamicCache | None = None, + cache_params: Cache | None = None, attention_mask: torch.Tensor | None = None, seq_idx: torch.IntTensor | None = None, ): @@ -309,12 +255,7 @@ def cuda_kernels_forward( groups_time_state_size = self.n_groups * self.ssm_state_size use_precomputed_states = ( - cache_params is not None - and cache_params.has_previous_state - and seq_len == 1 - and cache_params.conv_states[self.layer_idx].shape[0] - == cache_params.ssm_states[self.layer_idx].shape[0] - == batch_size + cache_params is not None and cache_params.has_previous_state(self.layer_idx) and seq_len == 1 ) # getting projected states from cache if it exists @@ -326,7 +267,7 @@ def cuda_kernels_forward( # 2. 
Convolution sequence transformation hidden_states_B_C = causal_conv1d_update( hidden_states_B_C, - cache_params.conv_states[self.layer_idx], + cache_params.layers[self.layer_idx].conv_states, self.conv1d.weight.squeeze(1), self.conv1d.bias, self.activation, ) @@ -348,7 +289,7 @@ def cuda_kernels_forward( C = C.view(batch_size, self.n_groups, C.shape[1] // self.n_groups) hidden_states_reshaped = hidden_states.view(batch_size, self.num_heads, self.head_dim) hidden_states = selective_state_update( - cache_params.ssm_states[self.layer_idx], + cache_params.layers[self.layer_idx].ssm_states, hidden_states_reshaped, dt, A, @@ -408,7 +349,7 @@ def cuda_kernels_forward( hidden_states_B_C_transposed, (self.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0), ) - cache_params.conv_states[self.layer_idx].copy_(conv_states) + conv_states = cache_params.update_conv_state(conv_states, self.layer_idx) if self.activation not in ["silu", "swish"]: hidden_states_B_C = self.act( @@ -449,7 +390,7 @@ def cuda_kernels_forward( # Init cache if ssm_state is not None and cache_params is not None: - cache_params.ssm_states[self.layer_idx].copy_(ssm_state) + ssm_state = cache_params.update_ssm_state(ssm_state, self.layer_idx) scan_output = scan_output.view(batch_size, seq_len, -1) # Multiply "gate" branch and apply extra normalization layer @@ -463,7 +404,7 @@ def cuda_kernels_forward( def torch_forward( self, input_states, - cache_params: HybridMambaAttentionDynamicCache | None = None, + cache_params: Cache | None = None, attention_mask: torch.Tensor | None = None, ): batch_size, seq_len, _ = input_states.shape @@ -475,23 +416,13 @@ def torch_forward( gate, hidden_states_B_C, dt = projected_states.split( [self.intermediate_size, self.conv_dim, self.num_heads], dim=-1 ) + hidden_states_B_C = hidden_states_B_C.transpose(1, 2) - use_precomputed_states = ( - cache_params is not None - and cache_params.has_previous_state - and seq_len == 1 - and cache_params.conv_states[self.layer_idx].shape[0] - == cache_params.ssm_states[self.layer_idx].shape[0] - == batch_size - ) + use_precomputed_states = cache_params is not None and cache_params.has_previous_state(self.layer_idx) and seq_len == 1 # 2.
Convolution sequence transformation if use_precomputed_states: - cache_params.conv_states[self.layer_idx] = cache_params.conv_states[self.layer_idx].roll(shifts=-1, dims=-1) - cache_params.conv_states[self.layer_idx][:, :, -1] = hidden_states_B_C[:, 0, :].to(cache_params.conv_states[self.layer_idx].device) - - # We need to guarantee that anything regarding the cache is on the same device - conv_states = cache_params.conv_states[self.layer_idx].to(device=self.conv1d.weight.device) + conv_states = cache_params.update_conv_state(hidden_states_B_C, self.layer_idx) hidden_states_B_C = torch.sum( conv_states * self.conv1d.weight.squeeze(1), dim=-1 @@ -502,13 +433,12 @@ def torch_forward( else: # Init cache if cache_params is not None: - hidden_states_B_C_transposed = hidden_states_B_C.transpose(1, 2) conv_states = nn.functional.pad( - hidden_states_B_C_transposed, (self.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0) + hidden_states_B_C, (self.conv_kernel_size - hidden_states_B_C.shape[-1], 0) ) - cache_params.conv_states[self.layer_idx].copy_(conv_states) + conv_states = cache_params.update_conv_state(conv_states, self.layer_idx) - hidden_states_B_C = self.act(self.conv1d(hidden_states_B_C.transpose(1, 2))[..., :seq_len].transpose(1, 2)) + hidden_states_B_C = self.act(self.conv1d(hidden_states_B_C)[..., :seq_len].transpose(1, 2)) hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask) hidden_states, B, C = torch.split( @@ -521,7 +451,7 @@ def torch_forward( A = -torch.exp(self.A_log.float()) # [num_heads] if use_precomputed_states: # We need to guarantee that anything regarding the cache is on the same device - cache_device = cache_params.ssm_states[self.layer_idx].device + cache_device = cache_params.layers[self.layer_idx].ssm_states.device # Note: there is no need to pad parameter matrices here, as there is just one new token # for batched generation @@ -551,9 +481,8 @@ def torch_forward( dBx = (dB * hidden_states[..., None]).to(device=cache_device) # State calculation - cache_params.ssm_states[self.layer_idx].copy_( - cache_params.ssm_states[self.layer_idx] * dA + dBx - ) + ssm_states = cache_params.layers[self.layer_idx].ssm_states * dA + dBx + ssm_states = cache_params.update_ssm_state(ssm_states, self.layer_idx) # Subsequent output # [bsz, n_groups * state_size] -> [bsz, num_heads, state_size] @@ -562,7 +491,7 @@ def torch_forward( C = C.reshape(batch_size, -1, C.shape[-1]) # [bsz, num_heads, head_dim] - ssm_states = cache_params.ssm_states[self.layer_idx].to(device=C.device, dtype=C.dtype) # Shape: [b, h, d, n] + ssm_states = ssm_states.to(device=C.device, dtype=C.dtype) # Shape: [b, h, d, n] # Reshape ssm_states to merge the first two dimensions ssm_states_reshaped = ssm_states.view(batch_size * self.num_heads, self.head_dim, self.ssm_state_size) # Shape: [b*h, d, n] C_reshaped = C.view(batch_size * self.num_heads, self.ssm_state_size, 1) # Shape: [b*h, n, 1] @@ -623,10 +552,7 @@ def torch_forward( # 3. 
Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries # (middle term of factorization of off-diag blocks; A terms) - if use_precomputed_states: - previous_states = cache_params.ssm_states[self.layer_idx][:, None, ...].to(device=states.device) - else: - previous_states = torch.zeros_like(states[:, :1]) + previous_states = torch.zeros_like(states[:, :1]) states = torch.cat([previous_states, states], dim=1) decay_chunk = torch.exp(segment_sum(nn.functional.pad(A_cumsum[:, :, :, -1], (1, 0)))) decay_chunk = decay_chunk.transpose(1, 3) @@ -653,7 +579,7 @@ def torch_forward( # Init cache if ssm_state is not None and cache_params is not None: - cache_params.ssm_states[self.layer_idx].copy_(ssm_state) + ssm_state = cache_params.update_ssm_state(ssm_state, self.layer_idx) scan_output = self.norm(y, gate) @@ -667,7 +593,7 @@ def torch_forward( def forward( self, hidden_states, - cache_params: HybridMambaAttentionDynamicCache | None = None, + cache_params: Cache | None = None, attention_mask: torch.Tensor | None = None, seq_idx: torch.IntTensor | None = None, **kwargs, @@ -717,7 +643,7 @@ def forward( hidden_states: torch.Tensor, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, - past_key_values: HybridMambaAttentionDynamicCache | None = None, + past_key_values: Cache | None = None, use_cache: bool | None = False, position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, **kwargs: Unpack[BambaFlashAttentionKwargs], @@ -808,7 +734,7 @@ def forward( input_ids: torch.LongTensor | None = None, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, - past_key_values: HybridMambaAttentionDynamicCache | None = None, + past_key_values: Cache | None = None, inputs_embeds: torch.FloatTensor | None = None, use_cache: bool | None = None, **kwargs: Unpack[BambaFlashAttentionKwargs], @@ -821,10 +747,7 @@ def forward( hidden_states = inputs_embeds if use_cache and past_key_values is None: - logger.warning_once( - "Bamba requires an initialized `HybridMambaAttentionDynamicCache` to return a cache. None was " - "provided, so no cache will be returned." 
- ) + past_key_values = DynamicCache(config=self.config) if position_ids is None: position_ids = torch.arange(hidden_states.shape[1], device=hidden_states.device).unsqueeze(0) @@ -893,7 +816,7 @@ def forward( input_ids: torch.LongTensor | None = None, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, - past_key_values: HybridMambaAttentionDynamicCache | None = None, + past_key_values: Cache | None = None, inputs_embeds: torch.FloatTensor | None = None, labels: torch.LongTensor | None = None, use_cache: bool | None = None, @@ -962,13 +885,6 @@ def prepare_inputs_for_generation( is_first_iteration=False, **kwargs, ): - # Overwritten -- has a unique cache type, `HybridMambaAttentionDynamicCache` - - if past_key_values is None: - past_key_values = HybridMambaAttentionDynamicCache( - self.config, input_ids.shape[0], self.dtype, device=self.device - ) - kwargs["logits_to_keep"] = self.config.num_logits_to_keep model_inputs = super().prepare_inputs_for_generation( input_ids, diff --git a/src/transformers/models/granitemoehybrid/configuration_granitemoehybrid.py b/src/transformers/models/granitemoehybrid/configuration_granitemoehybrid.py index 48278a265572..1277b107d09e 100644 --- a/src/transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +++ b/src/transformers/models/granitemoehybrid/configuration_granitemoehybrid.py @@ -45,9 +45,7 @@ class GraniteMoeHybridConfig(PreTrainedConfig): ```""" model_type = "granitemoehybrid" - attribute_map = { - "layers_block_type": "layer_types", - } + attribute_map = {"layers_block_type": "layer_types"} keys_to_ignore_at_inference = ["past_key_values"] vocab_size: int = 32000 @@ -116,10 +114,5 @@ def validate_architecture(self): if self.mamba_d_head * self.mamba_n_heads != mamba_intermediate: raise ValueError("The dimensions for the Mamba head state do not match the model intermediate_size") - # overwrite the function to use in `HybridMambaAttentionDynamicCache` - @property - def layers_block_type(self): - return self.layer_types - __all__ = ["GraniteMoeHybridConfig"] diff --git a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py index 676ec8f93773..63ba0b0a9dcc 100644 --- a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +++ b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py @@ -19,16 +19,15 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Callable -from typing import Any, Optional, TypedDict +from typing import Optional, TypedDict import torch from torch import nn from torch.nn import functional as F -from transformers.activations import ACT2FN - from ... import initialization as init -from ...cache_utils import Cache +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func from ...integrations.hub_kernels import lazy_load_kernel @@ -187,112 +186,6 @@ def forward( return attn_output, attn_weights -class HybridMambaAttentionDynamicCache: - """ - A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache - (which has a constant shape regardless of seq_len). 
- - This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states` - and `ssm_states` for mamba cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor - For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`, - while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors). - For mamba layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors), - while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`, - and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`. - """ - - is_compileable = False - - def __init__(self, config: GraniteMoeHybridConfig, batch_size, dtype=torch.float16, device=None): - self.layers_block_type = config.layers_block_type - self.has_previous_state = False # only used by mamba - conv_kernel_size = config.mamba_d_conv - ssm_state_size = config.mamba_d_state - - self.conv_states = [] - self.ssm_states = [] - self.transformer_layers = [] - for i in range(config.num_hidden_layers): - if self.layers_block_type[i] == "mamba": - self.conv_states += [ - torch.zeros( - batch_size, - (config.mamba_expand * config.hidden_size + 2 * config.mamba_n_groups * ssm_state_size), - conv_kernel_size, - device=device, - dtype=dtype, - ) - ] - self.ssm_states += [ - torch.zeros( - batch_size, - config.mamba_n_heads, - config.mamba_d_head, - ssm_state_size, - device=device, - dtype=dtype, - ) - ] - else: - self.conv_states += [torch.tensor([[]] * batch_size, device=device)] - self.ssm_states += [torch.tensor([[]] * batch_size, device=device)] - self.transformer_layers.append(i) - - self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] - self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] - - def __len__(self): - return len(self.key_cache) - - def __getitem__(self, layer_idx): - return self.key_cache[layer_idx], self.value_cache[layer_idx] - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: dict[str, Any] | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - # Update the cache - if self.key_cache[layer_idx].shape[-1] == 0: - self.key_cache[layer_idx] = key_states - self.value_cache[layer_idx] = value_states - else: - self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2) - self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2) - - return self.key_cache[layer_idx], self.value_cache[layer_idx] - - def reorder_cache(self, beam_idx: torch.LongTensor): - """Reorders the cache for beam search, given the selected beam indices.""" - if self.get_seq_length() > 0: - for layer_idx in range(len(self.key_cache)): - device = self.key_cache[layer_idx].device - self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) - device = self.value_cache[layer_idx].device - self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) - - device = self.conv_states[layer_idx].device - self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx.to(device)) - device = self.ssm_states[layer_idx].device - self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device)) - - def 
get_mask_sizes(self, query_length: int, layer_idx: int) -> tuple[int, int]: - """Return the length and offset of the cache, used to generate the mask""" - kv_offset = 0 - kv_length = self.get_seq_length(layer_idx) + query_length - return kv_length, kv_offset - - def get_seq_length(self, layer_idx: int | None = 0) -> int: - """Returns the sequence length of the cached states. A layer index can be optionally passed.""" - # take any layer that contains cache and not empty tensor - layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx - if len(self.key_cache) <= layer_idx or self.key_cache[layer_idx].shape[-1] == 0: - return 0 - return self.key_cache[layer_idx].shape[-2] - - # Helper methods for segment sum computation @@ -469,7 +362,7 @@ def __init__(self, config: GraniteMoeHybridConfig, layer_idx: int): def cuda_kernels_forward( self, hidden_states: torch.Tensor, - cache_params: HybridMambaAttentionDynamicCache | None = None, + cache_params: Cache | None = None, attention_mask: torch.Tensor | None = None, seq_idx: torch.IntTensor | None = None, ): @@ -482,12 +375,7 @@ def cuda_kernels_forward( groups_time_state_size = self.n_groups * self.ssm_state_size use_precomputed_states = ( - cache_params is not None - and cache_params.has_previous_state - and seq_len == 1 - and cache_params.conv_states[self.layer_idx].shape[0] - == cache_params.ssm_states[self.layer_idx].shape[0] - == batch_size + cache_params is not None and cache_params.has_previous_state(self.layer_idx) and seq_len == 1 ) # getting projected states from cache if it exists @@ -499,7 +387,7 @@ def cuda_kernels_forward( # 2. Convolution sequence transformation hidden_states_B_C = causal_conv1d_update( hidden_states_B_C, - cache_params.conv_states[self.layer_idx], + cache_params.layers[self.layer_idx].conv_states, self.conv1d.weight.squeeze(1), self.conv1d.bias, self.activation, ) @@ -521,7 +409,7 @@ def cuda_kernels_forward( C = C.view(batch_size, self.n_groups, C.shape[1] // self.n_groups) hidden_states_reshaped = hidden_states.view(batch_size, self.num_heads, self.head_dim) hidden_states = selective_state_update( - cache_params.ssm_states[self.layer_idx], + cache_params.layers[self.layer_idx].ssm_states, hidden_states_reshaped, dt, A, @@ -581,7 +469,7 @@ def cuda_kernels_forward( hidden_states_B_C_transposed, (self.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0), ) - cache_params.conv_states[self.layer_idx].copy_(conv_states) + conv_states = cache_params.update_conv_state(conv_states, self.layer_idx) if self.activation not in ["silu", "swish"]: hidden_states_B_C = self.act( @@ -622,7 +510,7 @@ def cuda_kernels_forward( # Init cache if ssm_state is not None and cache_params is not None: - cache_params.ssm_states[self.layer_idx].copy_(ssm_state) + ssm_state = cache_params.update_ssm_state(ssm_state, self.layer_idx) scan_output = scan_output.view(batch_size, seq_len, -1) # Multiply "gate" branch and apply extra normalization layer @@ -636,7 +524,7 @@ def cuda_kernels_forward( def torch_forward( self, input_states, - cache_params: HybridMambaAttentionDynamicCache | None = None, + cache_params: Cache | None = None, attention_mask: torch.Tensor | None = None, ): batch_size, seq_len, _ = input_states.shape @@ -648,23 +536,13 @@ def torch_forward( gate, hidden_states_B_C, dt = projected_states.split( [self.intermediate_size, self.conv_dim, self.num_heads], dim=-1 ) + hidden_states_B_C = hidden_states_B_C.transpose(1, 2) - use_precomputed_states = ( - cache_params is not None - and 
cache_params.has_previous_state - and seq_len == 1 - and cache_params.conv_states[self.layer_idx].shape[0] - == cache_params.ssm_states[self.layer_idx].shape[0] - == batch_size - ) + use_precomputed_states = cache_params is not None and cache_params.has_previous_state(self.layer_idx) and seq_len == 1 # 2. Convolution sequence transformation if use_precomputed_states: - cache_params.conv_states[self.layer_idx] = cache_params.conv_states[self.layer_idx].roll(shifts=-1, dims=-1) - cache_params.conv_states[self.layer_idx][:, :, -1] = hidden_states_B_C[:, 0, :].to(cache_params.conv_states[self.layer_idx].device) - - # We need to guarantee that anything regarding the cache is on the same device - conv_states = cache_params.conv_states[self.layer_idx].to(device=self.conv1d.weight.device) + conv_states = cache_params.update_conv_state(hidden_states_B_C, self.layer_idx) hidden_states_B_C = torch.sum( conv_states * self.conv1d.weight.squeeze(1), dim=-1 @@ -675,13 +553,12 @@ def torch_forward( else: # Init cache if cache_params is not None: - hidden_states_B_C_transposed = hidden_states_B_C.transpose(1, 2) conv_states = nn.functional.pad( - hidden_states_B_C_transposed, (self.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0) + hidden_states_B_C, (self.conv_kernel_size - hidden_states_B_C.shape[-1], 0) ) - cache_params.conv_states[self.layer_idx].copy_(conv_states) + conv_states = cache_params.update_conv_state(conv_states, self.layer_idx) - hidden_states_B_C = self.act(self.conv1d(hidden_states_B_C.transpose(1, 2))[..., :seq_len].transpose(1, 2)) + hidden_states_B_C = self.act(self.conv1d(hidden_states_B_C)[..., :seq_len].transpose(1, 2)) hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask) hidden_states, B, C = torch.split( @@ -694,7 +571,7 @@ def torch_forward( A = -torch.exp(self.A_log.float()) # [num_heads] if use_precomputed_states: # We need to guarantee that anything regarding the cache is on the same device - cache_device = cache_params.ssm_states[self.layer_idx].device + cache_device = cache_params.layers[self.layer_idx].ssm_states.device # Note: there is no need to pad parameter matrices here, as there is just one new token # for batched generation @@ -724,9 +601,8 @@ def torch_forward( dBx = (dB * hidden_states[..., None]).to(device=cache_device) # State calculation - cache_params.ssm_states[self.layer_idx].copy_( - cache_params.ssm_states[self.layer_idx] * dA + dBx - ) + ssm_states = cache_params.layers[self.layer_idx].ssm_states * dA + dBx + ssm_states = cache_params.update_ssm_state(ssm_states, self.layer_idx) # Subsequent output # [bsz, n_groups * state_size] -> [bsz, num_heads, state_size] @@ -735,7 +611,7 @@ def torch_forward( C = C.reshape(batch_size, -1, C.shape[-1]) # [bsz, num_heads, head_dim] - ssm_states = cache_params.ssm_states[self.layer_idx].to(device=C.device, dtype=C.dtype) # Shape: [b, h, d, n] + ssm_states = ssm_states.to(device=C.device, dtype=C.dtype) # Shape: [b, h, d, n] # Reshape ssm_states to merge the first two dimensions ssm_states_reshaped = ssm_states.view(batch_size * self.num_heads, self.head_dim, self.ssm_state_size) # Shape: [b*h, d, n] C_reshaped = C.view(batch_size * self.num_heads, self.ssm_state_size, 1) # Shape: [b*h, n, 1] @@ -796,10 +672,7 @@ def torch_forward( # 3. 
Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries # (middle term of factorization of off-diag blocks; A terms) - if use_precomputed_states: - previous_states = cache_params.ssm_states[self.layer_idx][:, None, ...].to(device=states.device) - else: - previous_states = torch.zeros_like(states[:, :1]) + previous_states = torch.zeros_like(states[:, :1]) states = torch.cat([previous_states, states], dim=1) decay_chunk = torch.exp(segment_sum(nn.functional.pad(A_cumsum[:, :, :, -1], (1, 0)))) decay_chunk = decay_chunk.transpose(1, 3) @@ -826,7 +699,7 @@ def torch_forward( # Init cache if ssm_state is not None and cache_params is not None: - cache_params.ssm_states[self.layer_idx].copy_(ssm_state) + ssm_state = cache_params.update_ssm_state(ssm_state, self.layer_idx) scan_output = self.norm(y, gate) @@ -840,7 +713,7 @@ def torch_forward( def forward( self, hidden_states, - cache_params: HybridMambaAttentionDynamicCache | None = None, + cache_params: Cache | None = None, attention_mask: torch.Tensor | None = None, seq_idx: torch.IntTensor | None = None, **kwargs, @@ -1293,6 +1166,9 @@ def forward( inputs_embeds = inputs_embeds * self.embedding_multiplier + if use_cache and past_key_values is None: + past_key_values = DynamicCache(config=self.config) + if position_ids is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens @@ -1534,36 +1410,5 @@ def forward( router_logits=outputs.router_logits, ) - def prepare_inputs_for_generation( - self, - input_ids, - past_key_values=None, - attention_mask=None, - inputs_embeds=None, - position_ids=None, - use_cache=True, - is_first_iteration=False, - **kwargs, - ): - # Overwritten -- has a unique cache type, `HybridMambaAttentionDynamicCache` - - if past_key_values is None and use_cache: - past_key_values = HybridMambaAttentionDynamicCache( - self.config, input_ids.shape[0], self.dtype, device=self.device - ) - - model_inputs = super().prepare_inputs_for_generation( - input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - inputs_embeds=inputs_embeds, - position_ids=position_ids, - use_cache=use_cache, - is_first_iteration=is_first_iteration, - **kwargs, - ) - - return model_inputs - __all__ = ["GraniteMoeHybridForCausalLM", "GraniteMoeHybridModel", "GraniteMoeHybridPreTrainedModel"] diff --git a/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py b/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py index 4c72531bddb5..ccf8c58bfb51 100644 --- a/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py +++ b/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py @@ -18,7 +18,7 @@ from torch import nn from ... 
import initialization as init -from ...cache_utils import Cache +from ...cache_utils import Cache, DynamicCache from ...masking_utils import create_causal_mask from ...modeling_outputs import BaseModelOutputWithPast, MoeModelOutputWithPast from ...modeling_utils import ALL_ATTENTION_FUNCTIONS @@ -27,7 +27,7 @@ from ...utils.generic import merge_with_config_defaults from ...utils.output_capturing import capture_outputs from ..bamba.configuration_bamba import BambaConfig -from ..bamba.modeling_bamba import BambaMixer, BambaRMSNormGated, HybridMambaAttentionDynamicCache +from ..bamba.modeling_bamba import BambaMixer, BambaRMSNormGated from ..gemma2.modeling_gemma2 import Gemma2RotaryEmbedding from ..granitemoeshared.modeling_granitemoeshared import ( GraniteFlashAttentionKwargs, @@ -226,6 +226,9 @@ def forward( inputs_embeds = inputs_embeds * self.embedding_multiplier + if use_cache and past_key_values is None: + past_key_values = DynamicCache(config=self.config) + if position_ids is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens @@ -315,36 +318,5 @@ def forward(self, **super_kwargs): ```""" return super().forward(**super_kwargs) - def prepare_inputs_for_generation( - self, - input_ids, - past_key_values=None, - attention_mask=None, - inputs_embeds=None, - position_ids=None, - use_cache=True, - is_first_iteration=False, - **kwargs, - ): - # Overwritten -- has a unique cache type, `HybridMambaAttentionDynamicCache` - - if past_key_values is None and use_cache: - past_key_values = HybridMambaAttentionDynamicCache( - self.config, input_ids.shape[0], self.dtype, device=self.device - ) - - model_inputs = super().prepare_inputs_for_generation( - input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - inputs_embeds=inputs_embeds, - position_ids=position_ids, - use_cache=use_cache, - is_first_iteration=is_first_iteration, - **kwargs, - ) - - return model_inputs - __all__ = ["GraniteMoeHybridForCausalLM", "GraniteMoeHybridModel", "GraniteMoeHybridPreTrainedModel"] diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index 2feaf7f7bcba..899bd737d3c5 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -652,7 +652,7 @@ def cuda_kernels_forward( return out # fmt: off - def torch_forward(self, input_states, cache_params: Cache | None=None, attention_mask: torch.Tensor | None=None): + def torch_forward(self, input_states, cache_params: Cache | None=None, attention_mask: torch.Tensor | None = None): batch_size, seq_len, _ = input_states.shape dtype = input_states.dtype # Gated MLP's linear projection @@ -667,43 +667,35 @@ def torch_forward(self, input_states, cache_params: Cache | None=None, attention _, _, gate, hidden_states, dt = projected_states.split( [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1 ) + hidden_states = hidden_states.transpose(1, 2) - has_previous_state = cache_params.has_previous_state(self.layer_idx) - - if cache_params is not None and has_previous_state: - ssm_state = cache_params.layers[self.layer_idx].ssm_states.clone() - else: - ssm_state = torch.zeros( - (batch_size, self.num_heads, self.head_dim, self.ssm_state_size), - device=hidden_states.device, dtype=dtype - ) + use_precomputed_state = cache_params is not None and 
cache_params.has_previous_state(self.layer_idx) # Convolution sequence transformation - if cache_params is not None: - if has_previous_state: - gate = gate.unsqueeze(1) - conv_state = cache_params.update_conv_state(hidden_states, self.layer_idx) - hidden_states = torch.sum(conv_state.to(projected_states.device) * self.conv1d.weight[:, 0, :], dim=-1) - if self.use_conv_bias: - hidden_states += self.conv1d.bias - hidden_states = self.act(hidden_states).to(dtype)[:, None, ...] # [batch, 1, intermediate_size] : decoding - else: - hidden_states = hidden_states.transpose(1,2) + if use_precomputed_state: + gate = gate.unsqueeze(1) + conv_state = cache_params.update_conv_state(hidden_states, self.layer_idx) + hidden_states = torch.sum(conv_state.to(projected_states.device) * self.conv1d.weight[:, 0, :], dim=-1) + if self.use_conv_bias: + hidden_states += self.conv1d.bias + hidden_states = self.act(hidden_states).to(dtype)[:, None, ...] # [batch, 1, intermediate_size] : decoding + else: + if cache_params is not None: conv_state = nn.functional.pad( hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0) ) conv_state = cache_params.update_conv_state(conv_state, self.layer_idx) - hidden_states = self.act(self.conv1d(hidden_states).transpose(1,2))[:, :seq_len, :] # [batch, intermediate_size, seq_len] - if attention_mask is not None: - dtype = hidden_states.dtype - # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66 - hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) - else: - hidden_states = self.act(self.conv1d(hidden_states.transpose(1, 2))[..., :seq_len].transpose(1, 2)) + + hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len].transpose(1, 2)) + if attention_mask is not None: + dtype = hidden_states.dtype + # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66 + hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) + hidden_states, B, C = torch.split(hidden_states, [self.intermediate_size, self.n_groups * self.ssm_state_size, self.n_groups * self.ssm_state_size], dim=-1) A = -torch.exp(self.A_log.float()) # [num_heads] - if cache_params is not None and has_previous_state: + if use_precomputed_state: # Note: there is no need to pad parameter matrices here, as there is just one new token # for batched generation dt = dt[:, None, ...] if dt.ndim == 2 else dt[:, 0, :][:, None, ...] 
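The decode branch above relies on the per-layer Mamba cache keeping a fixed-size convolution window that is shifted by one slot per generated token before the weighted sum with the conv1d kernel. A minimal, self-contained sketch of that rolling update on plain tensors (illustrative sizes only, not the actual cache class):

import torch

# illustrative sizes: (batch, conv_dim, conv_kernel_size)
conv_window = torch.zeros(2, 8, 4)
# single decoded token state, already transposed to (batch, conv_dim, 1)
new_token_state = torch.randn(2, 8, 1)

# shift the window left by one position and write the newest state into the last slot,
# mirroring what the conv-state update does once a prefill has already been cached
conv_window = conv_window.roll(shifts=-1, dims=-1)
conv_window[:, :, -1:] = new_token_state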
@@ -732,6 +724,7 @@ def torch_forward(self, input_states, cache_params: Cache | None=None, attention dBx = dB * hidden_states[..., None] # State calculation + ssm_state = cache_params.layers[self.layer_idx].ssm_states.clone() ssm_state = ssm_state * dA + dBx ssm_state = cache_params.update_ssm_state(ssm_state, self.layer_idx) diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py index 6a4ab477d731..c6e7af8c3cba 100644 --- a/src/transformers/models/zamba2/modular_zamba2.py +++ b/src/transformers/models/zamba2/modular_zamba2.py @@ -440,7 +440,7 @@ def cuda_kernels_forward( return out # fmt: off - def torch_forward(self, input_states, cache_params: Cache | None=None, attention_mask: torch.Tensor | None=None): + def torch_forward(self, input_states, cache_params: Cache | None=None, attention_mask: torch.Tensor | None = None): batch_size, seq_len, _ = input_states.shape dtype = input_states.dtype # Gated MLP's linear projection @@ -455,43 +455,35 @@ def torch_forward(self, input_states, cache_params: Cache | None=None, attention _, _, gate, hidden_states, dt = projected_states.split( [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1 ) + hidden_states = hidden_states.transpose(1, 2) - has_previous_state = cache_params.has_previous_state(self.layer_idx) - - if cache_params is not None and has_previous_state: - ssm_state = cache_params.layers[self.layer_idx].ssm_states.clone() - else: - ssm_state = torch.zeros( - (batch_size, self.num_heads, self.head_dim, self.ssm_state_size), - device=hidden_states.device, dtype=dtype - ) + use_precomputed_state = cache_params is not None and cache_params.has_previous_state(self.layer_idx) # Convolution sequence transformation - if cache_params is not None: - if has_previous_state: - gate = gate.unsqueeze(1) - conv_state = cache_params.update_conv_state(hidden_states, self.layer_idx) - hidden_states = torch.sum(conv_state.to(projected_states.device) * self.conv1d.weight[:, 0, :], dim=-1) - if self.use_conv_bias: - hidden_states += self.conv1d.bias - hidden_states = self.act(hidden_states).to(dtype)[:, None, ...] # [batch, 1, intermediate_size] : decoding - else: - hidden_states = hidden_states.transpose(1,2) + if use_precomputed_state: + gate = gate.unsqueeze(1) + conv_state = cache_params.update_conv_state(hidden_states, self.layer_idx) + hidden_states = torch.sum(conv_state.to(projected_states.device) * self.conv1d.weight[:, 0, :], dim=-1) + if self.use_conv_bias: + hidden_states += self.conv1d.bias + hidden_states = self.act(hidden_states).to(dtype)[:, None, ...] 
# [batch, 1, intermediate_size] : decoding + else: + if cache_params is not None: conv_state = nn.functional.pad( hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0) ) conv_state = cache_params.update_conv_state(conv_state, self.layer_idx) - hidden_states = self.act(self.conv1d(hidden_states).transpose(1,2))[:, :seq_len, :] # [batch, intermediate_size, seq_len] - if attention_mask is not None: - dtype = hidden_states.dtype - # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66 - hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) - else: - hidden_states = self.act(self.conv1d(hidden_states.transpose(1, 2))[..., :seq_len].transpose(1, 2)) + + hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len].transpose(1, 2)) + if attention_mask is not None: + dtype = hidden_states.dtype + # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66 + hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) + hidden_states, B, C = torch.split(hidden_states, [self.intermediate_size, self.n_groups * self.ssm_state_size, self.n_groups * self.ssm_state_size], dim=-1) A = -torch.exp(self.A_log.float()) # [num_heads] - if cache_params is not None and has_previous_state: + if use_precomputed_state: # Note: there is no need to pad parameter matrices here, as there is just one new token # for batched generation dt = dt[:, None, ...] if dt.ndim == 2 else dt[:, 0, :][:, None, ...] @@ -520,6 +512,7 @@ def torch_forward(self, input_states, cache_params: Cache | None=None, attention dBx = dB * hidden_states[..., None] # State calculation + ssm_state = cache_params.layers[self.layer_idx].ssm_states.clone() ssm_state = ssm_state * dA + dBx ssm_state = cache_params.update_ssm_state(ssm_state, self.layer_idx) diff --git a/tests/models/bamba/test_modeling_bamba.py b/tests/models/bamba/test_modeling_bamba.py index ee8143f31c14..da236c72790c 100644 --- a/tests/models/bamba/test_modeling_bamba.py +++ b/tests/models/bamba/test_modeling_bamba.py @@ -47,11 +47,8 @@ if is_torch_available(): import torch - from transformers import ( - BambaForCausalLM, - BambaModel, - ) - from transformers.models.bamba.modeling_bamba import HybridMambaAttentionDynamicCache + from transformers import BambaForCausalLM, BambaModel, DynamicCache + from transformers.cache_utils import MambaLayer class BambaModelTester: @@ -228,14 +225,9 @@ def create_and_check_decoder_model_past_large_inputs( model.eval() # first forward pass - # Attention: Jamba needs the cache to be initialized to return a cache! 
- past_key_values = HybridMambaAttentionDynamicCache( - config, input_ids.shape[0], model.dtype, device=model.device - ) outputs = model( input_ids, attention_mask=input_mask, - past_key_values=past_key_values, use_cache=True, ) past_key_values = outputs.past_key_values @@ -289,7 +281,7 @@ class BambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi model_split_percents = [0.5, 0.7, 0.8] def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config): - self.assertIsInstance(past_key_values, HybridMambaAttentionDynamicCache) + self.assertIsInstance(past_key_values, DynamicCache) # (batch, kv heads, seq_length, head_dim) num_heads = getattr(config, "num_key_value_heads", config.num_attention_heads) @@ -307,18 +299,14 @@ def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_l for idx in range(len(past_key_values)): if config.layers_block_type[idx] == "mamba": - self.assertEqual(past_key_values.conv_states[idx].shape, conv_shape) - self.assertEqual(past_key_values.ssm_states[idx].shape, ssm_shape) + self.assertEqual(past_key_values.layers[idx].conv_states.shape, conv_shape) + self.assertEqual(past_key_values.layers[idx].ssm_states.shape, ssm_shape) else: - self.assertEqual(past_key_values.key_cache[idx].shape, attention_shape) - self.assertEqual(past_key_values.value_cache[idx].shape, attention_shape) + self.assertEqual(past_key_values.layers[idx].keys.shape, attention_shape) + self.assertEqual(past_key_values.layers[idx].values.shape, attention_shape) - def _check_caches_are_equal( - self, cache1: HybridMambaAttentionDynamicCache, cache2: HybridMambaAttentionDynamicCache - ): - if not isinstance(cache1, HybridMambaAttentionDynamicCache) or not isinstance( - cache2, HybridMambaAttentionDynamicCache - ): + def _check_caches_are_equal(self, cache1: DynamicCache, cache2: DynamicCache): + if not isinstance(cache1, DynamicCache) or not isinstance(cache2, DynamicCache): raise ValueError("The wrong cache is being used!") if not len(cache1) == len(cache2): @@ -326,10 +314,14 @@ def _check_caches_are_equal( num_layers = len(cache1) for idx in range(num_layers): - torch.testing.assert_close(cache1.key_cache[idx], cache2.key_cache[idx]) - torch.testing.assert_close(cache1.value_cache[idx], cache2.value_cache[idx]) - torch.testing.assert_close(cache1.conv_states[idx], cache2.conv_states[idx]) - torch.testing.assert_close(cache1.ssm_states[idx], cache2.ssm_states[idx]) + # Mamba layer + if type(cache1.layers[idx]) is MambaLayer: + torch.testing.assert_close(cache1.layers[idx].conv_states, cache2.layers[idx].conv_states) + torch.testing.assert_close(cache1.layers[idx].ssm_states, cache2.layers[idx].ssm_states) + # Attention layer + else: + torch.testing.assert_close(cache1.layers[idx].keys, cache2.layers[idx].keys) + torch.testing.assert_close(cache1.layers[idx].values, cache2.layers[idx].values) def setUp(self): self.model_tester = self.model_tester_class(self) diff --git a/tests/models/granitemoehybrid/test_modeling_granitemoehybrid.py b/tests/models/granitemoehybrid/test_modeling_granitemoehybrid.py index 8cb946d0aa2e..e102ff9c93ea 100644 --- a/tests/models/granitemoehybrid/test_modeling_granitemoehybrid.py +++ b/tests/models/granitemoehybrid/test_modeling_granitemoehybrid.py @@ -46,11 +46,8 @@ if is_torch_available(): import torch - from transformers import ( - GraniteMoeHybridForCausalLM, - GraniteMoeHybridModel, - ) - from transformers.models.granitemoehybrid.modeling_granitemoehybrid import HybridMambaAttentionDynamicCache + 
from transformers import DynamicCache, GraniteMoeHybridForCausalLM, GraniteMoeHybridModel + from transformers.cache_utils import MambaLayer class GraniteMoeHybridModelTester(BambaModelTester): @@ -109,12 +106,8 @@ class GraniteMoeHybridModelTest(ModelTesterMixin, GenerationTesterMixin, Pipelin # This is because we are hitting edge cases with the causal_mask buffer model_split_percents = [0.5, 0.7, 0.8] - def _check_caches_are_equal( - self, cache1: HybridMambaAttentionDynamicCache, cache2: HybridMambaAttentionDynamicCache - ): - if not isinstance(cache1, HybridMambaAttentionDynamicCache) or not isinstance( - cache2, HybridMambaAttentionDynamicCache - ): + def _check_caches_are_equal(self, cache1: DynamicCache, cache2: DynamicCache): + if not isinstance(cache1, DynamicCache) or not isinstance(cache2, DynamicCache): raise ValueError("The wrong cache is being used!") if not len(cache1) == len(cache2): @@ -122,10 +115,14 @@ def _check_caches_are_equal( num_layers = len(cache1) for idx in range(num_layers): - torch.testing.assert_close(cache1.key_cache[idx], cache2.key_cache[idx]) - torch.testing.assert_close(cache1.value_cache[idx], cache2.value_cache[idx]) - torch.testing.assert_close(cache1.conv_states[idx], cache2.conv_states[idx]) - torch.testing.assert_close(cache1.ssm_states[idx], cache2.ssm_states[idx]) + # Mamba layer + if type(cache1.layers[idx]) is MambaLayer: + torch.testing.assert_close(cache1.layers[idx].conv_states, cache2.layers[idx].conv_states) + torch.testing.assert_close(cache1.layers[idx].ssm_states, cache2.layers[idx].ssm_states) + # Attention layer + else: + torch.testing.assert_close(cache1.layers[idx].keys, cache2.layers[idx].keys) + torch.testing.assert_close(cache1.layers[idx].values, cache2.layers[idx].values) def setUp(self): self.model_tester = self.model_tester_class(self) @@ -325,7 +322,7 @@ def test_flash_attention_2_padding_matches_padding_free_with_position_ids_seq_id torch.testing.assert_close(loss_padded, loss_padfree) def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config): - self.assertIsInstance(past_key_values, HybridMambaAttentionDynamicCache) + self.assertIsInstance(past_key_values, DynamicCache) # (batch, kv heads, seq_length, head_dim) num_heads = getattr(config, "num_key_value_heads", config.num_attention_heads) @@ -343,11 +340,11 @@ def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_l for idx in range(len(past_key_values)): if config.layers_block_type[idx] == "mamba": - self.assertEqual(past_key_values.conv_states[idx].shape, conv_shape) - self.assertEqual(past_key_values.ssm_states[idx].shape, ssm_shape) + self.assertEqual(past_key_values.layers[idx].conv_states.shape, conv_shape) + self.assertEqual(past_key_values.layers[idx].ssm_states.shape, ssm_shape) else: - self.assertEqual(past_key_values.key_cache[idx].shape, attention_shape) - self.assertEqual(past_key_values.value_cache[idx].shape, attention_shape) + self.assertEqual(past_key_values.layers[idx].keys.shape, attention_shape) + self.assertEqual(past_key_values.layers[idx].values.shape, attention_shape) def test_config_requires_mamba_or_attention_layers(self): """Ensure we can't create a config with disallowed layers.""" From fdb1579124cc9dfebe2c33d46ac2862d9fc58ff2 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 24 Mar 2026 15:46:00 +0100 Subject: [PATCH 21/56] final mambas --- .../falcon_h1/configuration_falcon_h1.py | 1 + .../models/falcon_h1/modeling_falcon_h1.py | 243 ++---------------- 
.../models/falcon_h1/modular_falcon_h1.py | 235 +++-------------- .../models/nemotron_h/modeling_nemotron_h.py | 220 +++------------- .../models/nemotron_h/modular_nemotron_h.py | 66 +---- .../falcon_h1/test_modeling_falcon_h1.py | 54 ++-- .../nemotron_h/test_modeling_nemotron_h.py | 52 ++-- 7 files changed, 168 insertions(+), 703 deletions(-) diff --git a/src/transformers/models/falcon_h1/configuration_falcon_h1.py b/src/transformers/models/falcon_h1/configuration_falcon_h1.py index caa0e6288528..19401fbbc632 100644 --- a/src/transformers/models/falcon_h1/configuration_falcon_h1.py +++ b/src/transformers/models/falcon_h1/configuration_falcon_h1.py @@ -54,6 +54,7 @@ class FalconH1Config(PreTrainedConfig): """ model_type = "falcon_h1" + attribute_map = {"layer_types": "layers_block_type"} keys_to_ignore_at_inference = ["past_key_values"] vocab_size: int = 128000 diff --git a/src/transformers/models/falcon_h1/modeling_falcon_h1.py b/src/transformers/models/falcon_h1/modeling_falcon_h1.py index 237c04c8d28d..2911f76a15cb 100644 --- a/src/transformers/models/falcon_h1/modeling_falcon_h1.py +++ b/src/transformers/models/falcon_h1/modeling_falcon_h1.py @@ -24,16 +24,15 @@ # limitations under the License. from collections.abc import Callable -from typing import Any, Optional +from typing import Optional import torch import torch.nn.functional as F from torch import nn -from transformers.activations import ACT2FN - from ... import initialization as init -from ...cache_utils import Cache +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func from ...integrations.hub_kernels import lazy_load_kernel @@ -54,161 +53,6 @@ logger = logging.get_logger(__name__) -class FalconHybridMambaAttentionDynamicCache: - """ - A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache - (which has a constant shape regardless of seq_len). - - This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states` - and `ssm_states` for mamba cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor - For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`, - while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors). - For mamba layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors), - while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`, - and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`. 
- """ - - is_compileable = False - - def __init__( - self, - config: FalconH1Config, - batch_size: int, - dtype: torch.dtype = torch.float16, - devices: list[str] | None = None, - ): - self.seqlen_offset = 0 - self.dtype = dtype - self.has_previous_state = False - self.conv_kernel_size = config.mamba_d_conv - - self.intermediate_size = ( - config.mamba_d_ssm if config.mamba_d_ssm is not None else int(config.mamba_expand * config.hidden_size) - ) - - self.conv_states = { - i: torch.zeros( - batch_size, - self.intermediate_size + 2 * config.mamba_n_groups * config.mamba_d_state, - self.conv_kernel_size, - device=devices[i], - dtype=dtype, - ) - for i in range(config.num_hidden_layers) - } - self.ssm_states = { - i: torch.zeros( - batch_size, - config.mamba_n_heads, - config.mamba_d_head, - config.mamba_d_state, - device=devices[i], - dtype=dtype, - ) - for i in range(config.num_hidden_layers) - } - - self.transformer_layers = [] - for i in range(config.num_hidden_layers): - self.transformer_layers.append(i) - - self.key_cache: list[torch.Tensor] = [] - self.value_cache: list[torch.Tensor] = [] - - def __len__(self): - return len(self.key_cache) - - def __getitem__(self, layer_idx): - return self.key_cache[layer_idx], self.value_cache[layer_idx] - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: dict[str, Any] | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. - - Parameters: - key_states (`torch.Tensor`): - The new key states to cache. - value_states (`torch.Tensor`): - The new value states to cache. - layer_idx (`int`): - The index of the layer to cache the states for. - cache_kwargs (`dict[str, Any]`, `optional`): - Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`. - - Return: - A tuple containing the updated key and value states. 
- """ - # Update the cache - if len(self.key_cache) <= layer_idx: - # There may be skipped layers, fill them with empty lists - for _ in range(len(self.key_cache), layer_idx): - self.key_cache.append([]) - self.value_cache.append([]) - self.key_cache.append(key_states) - self.value_cache.append(value_states) - elif len(self.key_cache[layer_idx]) == 0: # fills previously skipped layers; checking for tensor causes errors - self.key_cache[layer_idx] = key_states - self.value_cache[layer_idx] = value_states - else: - self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) - self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) - - return self.key_cache[layer_idx], self.value_cache[layer_idx] - - def reorder_cache(self, beam_idx: torch.LongTensor): - """Reorders the cache for beam search, given the selected beam indices.""" - if self.get_seq_length() > 0: - for layer_idx in range(len(self.key_cache)): - device = self.key_cache[layer_idx].device - self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) - device = self.value_cache[layer_idx].device - self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) - - device = self.conv_states[layer_idx].device - self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx.to(device)) - device = self.ssm_states[layer_idx].device - self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device)) - - def get_mask_sizes(self, query_length: int, layer_idx: int) -> tuple[int, int]: - """Return the length and offset of the cache, used to generate the mask""" - kv_offset = 0 - kv_length = self.get_seq_length(layer_idx) + query_length - return kv_length, kv_offset - - def get_seq_length(self, layer_idx: int | None = 0) -> int: - """Returns the sequence length of the cached states. A layer index can be optionally passed.""" - # take any layer that contains cache and not empty tensor - layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx - if len(self.key_cache) <= layer_idx or self.key_cache[layer_idx].shape[-1] == 0: - return 0 - return self.key_cache[layer_idx].shape[-2] - - def update_conv_state( - self, layer_idx: int, new_conv_state: torch.Tensor, cache_init: bool = False - ) -> torch.Tensor: - # Technically, those update are not logically correct if the prefill is smaller than `conv_kernel_size`, - # as it will `roll` anyway in the first decoding step even though it should `roll` ONLY if the cache is already full. 
- # But since `conv_kernel_size=4` in practice, it's almost impossible to have a smaller prefill so it's mostly fine for now - if cache_init: - self.conv_states[layer_idx] = new_conv_state.to(self.conv_states[layer_idx].device) - else: - self.conv_states[layer_idx] = self.conv_states[layer_idx].roll(shifts=-1, dims=-1) - self.conv_states[layer_idx][:, :, -1] = new_conv_state[:, 0, :].to(self.conv_states[layer_idx].device) - - return self.conv_states[layer_idx] - - def reset(self): - self.has_previous_state = False - self.conv_states.zero_() - self.ssm_states.zero_() - - class FalconH1RotaryEmbedding(nn.Module): inv_freq: torch.Tensor # fix linting for `register_buffer` @@ -633,7 +477,7 @@ def __init__(self, config: FalconH1Config, layer_idx: int): def cuda_kernels_forward( self, hidden_states: torch.Tensor, - cache_params: FalconHybridMambaAttentionDynamicCache | None = None, + cache_params: Cache | None = None, attention_mask: torch.Tensor | None = None, ): # 1. Gated MLP's linear projection @@ -649,12 +493,7 @@ def cuda_kernels_forward( groups_time_state_size = self.n_groups * self.ssm_state_size use_precomputed_states = ( - cache_params is not None - and cache_params.has_previous_state - and seq_len == 1 - and cache_params.conv_states[self.layer_idx].shape[0] - == cache_params.ssm_states[self.layer_idx].shape[0] - == batch_size + cache_params is not None and cache_params.has_previous_state(self.layer_idx) and seq_len == 1 ) # getting projected states from cache if it exists @@ -668,7 +507,7 @@ def cuda_kernels_forward( # 2. Convolution sequence transformation hidden_states_B_C = causal_conv1d_update( hidden_states_B_C, - cache_params.conv_states[self.layer_idx], + cache_params.layers[self.layer_idx].conv_states, self.conv1d.weight.squeeze(1), self.conv1d.bias, self.activation, @@ -690,7 +529,7 @@ def cuda_kernels_forward( C = C.view(batch_size, self.n_groups, C.shape[1] // self.n_groups) hidden_states_reshaped = hidden_states.view(batch_size, self.num_heads, self.head_dim) hidden_states = selective_state_update( - cache_params.ssm_states[self.layer_idx], + cache_params.layers[self.layer_idx].ssm_states, hidden_states_reshaped, dt, A, @@ -763,7 +602,7 @@ def cuda_kernels_forward( hidden_states_B_C.permute(0, 2, 1), (self.conv_kernel_size - hidden_states_B_C.shape[-2], 0), ) - cache_params.update_conv_state(self.layer_idx, conv_states, cache_init=True) + conv_states = cache_params.update_conv_state(conv_states, self.layer_idx) time_step = nn.functional.softplus(dt + self.dt_bias) # 1D Convolution @@ -810,7 +649,7 @@ def cuda_kernels_forward( **dt_limit_kwargs, ) if ssm_state is not None and cache_params is not None: - cache_params.ssm_states[self.layer_idx].copy_(ssm_state) + ssm_state = cache_params.update_ssm_state(ssm_state, self.layer_idx) scan_output = scan_output.view(batch_size, seq_len, -1) # Multiply "gate" branch and apply extra normalization layer if self.mamba_rms_norm: @@ -824,7 +663,7 @@ def cuda_kernels_forward( def torch_forward( self, input_states, - cache_params: FalconHybridMambaAttentionDynamicCache | None = None, + cache_params: Cache | None = None, attention_mask: torch.Tensor | None = None, ): batch_size, seq_len, _ = input_states.shape @@ -839,19 +678,13 @@ def torch_forward( gate, hidden_states_B_C, dt = projected_states.split([ self.intermediate_size, self.conv_dim, self.num_heads ], dim=-1) + hidden_states_B_C = hidden_states_B_C.transpose(1,2) - use_precomputed_states = ( - cache_params is not None - and cache_params.has_previous_state - and seq_len == 1 - and 
cache_params.conv_states[self.layer_idx].shape[0] - == cache_params.ssm_states[self.layer_idx].shape[0] - == batch_size - ) + use_precomputed_states = cache_params is not None and cache_params.has_previous_state(self.layer_idx) and seq_len == 1 # 2. Convolution sequence transformation if use_precomputed_states: - conv_states = cache_params.update_conv_state(self.layer_idx, hidden_states_B_C, cache_init=False) + conv_states = cache_params.update_conv_state(hidden_states_B_C, self.layer_idx) # We need to guarantee that anything regarding the cache is on the same device conv_states = conv_states.to(device=self.conv1d.weight.device) @@ -864,13 +697,12 @@ def torch_forward( else: # Init cache if cache_params is not None: - hidden_states_B_C_transposed = hidden_states_B_C.transpose(1, 2) conv_states = nn.functional.pad( - hidden_states_B_C_transposed, (self.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0) + hidden_states_B_C, (self.conv_kernel_size - hidden_states_B_C.shape[-1], 0) ) - conv_states = cache_params.update_conv_state(self.layer_idx, conv_states, cache_init=True) + conv_states = cache_params.update_conv_state(conv_states, self.layer_idx) - hidden_states_B_C = self.act(self.conv1d(hidden_states_B_C.transpose(1, 2))[..., :seq_len].transpose(1, 2)) + hidden_states_B_C = self.act(self.conv1d(hidden_states_B_C)[..., :seq_len].transpose(1, 2)) hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask) hidden_states, B, C = torch.split( @@ -883,7 +715,7 @@ def torch_forward( A = -torch.exp(self.A_log.float()) # [num_heads] if use_precomputed_states: # We need to guarantee that anything regarding the cache is on the same device - cache_device = cache_params.ssm_states[self.layer_idx].device + cache_device = cache_params.layers[self.layer_idx].ssm_states.device # Note: there is no need to pad parameter matrices here, as there is just one new token # for batched generation @@ -913,9 +745,8 @@ def torch_forward( dBx = (dB * hidden_states[..., None]).to(device=cache_device) # State calculation - cache_params.ssm_states[self.layer_idx].copy_( - cache_params.ssm_states[self.layer_idx] * dA + dBx - ) + ssm_states = cache_params.layers[self.layer_idx].ssm_states * dA + dBx + ssm_states = cache_params.update_ssm_state(ssm_states, self.layer_idx) # Subsequent output # [bsz, n_groups * state_size] -> [bsz, num_heads, state_size] @@ -924,7 +755,7 @@ def torch_forward( C = C.reshape(batch_size, -1, C.shape[-1]) # [bsz, num_heads, head_dim] - ssm_states = cache_params.ssm_states[self.layer_idx].to(device=C.device, dtype=C.dtype) # Shape: [b, h, d, n] + ssm_states = ssm_states.to(device=C.device, dtype=C.dtype) # Shape: [b, h, d, n] # Reshape ssm_states to merge the first two dimensions ssm_states_reshaped = ssm_states.view(batch_size * self.num_heads, self.head_dim, self.ssm_state_size) # Shape: [b*h, d, n] C_reshaped = C.view(batch_size * self.num_heads, self.ssm_state_size, 1) # Shape: [b*h, n, 1] @@ -985,10 +816,7 @@ def torch_forward( # 3. 
Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
             # (middle term of factorization of off-diag blocks; A terms)
-            if use_precomputed_states:
-                previous_states = cache_params.ssm_states[self.layer_idx][:, None, ...].to(device=states.device)
-            else:
-                previous_states = torch.zeros_like(states[:, :1])
+            previous_states = torch.zeros_like(states[:, :1])
             states = torch.cat([previous_states, states], dim=1)
             decay_chunk = torch.exp(segment_sum(nn.functional.pad(A_cumsum[:, :, :, -1], (1, 0))))
             decay_chunk = decay_chunk.transpose(1, 3)
@@ -1015,7 +843,7 @@ def torch_forward(

             # Init cache
             if ssm_state is not None and cache_params is not None:
-                cache_params.ssm_states[self.layer_idx].copy_(ssm_state)
+                ssm_state = cache_params.update_ssm_state(ssm_state, self.layer_idx)

         if self.mamba_rms_norm:
             scan_output = self.norm(y, gate)
@@ -1032,7 +860,7 @@ def torch_forward(
     def forward(
         self,
         hidden_states,
-        cache_params: FalconHybridMambaAttentionDynamicCache | None = None,
+        cache_params: Cache | None = None,
         attention_mask: torch.Tensor | None = None,
         **kwargs,
     ):
@@ -1110,7 +938,7 @@ def forward(
         attention_mask: torch.Tensor | None = None,
         mamba_attention_mask: torch.Tensor | None = None,
         position_ids: torch.LongTensor | None = None,
-        past_key_values: FalconHybridMambaAttentionDynamicCache | None = None,
+        past_key_values: Cache | None = None,
         use_cache: bool | None = False,
         position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
         **kwargs,
@@ -1120,7 +948,7 @@ def forward(
             hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
             attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                 `(batch, sequence_length)` where padding elements are indicated by 0.
-            past_key_values (`FalconHybridMambaAttentionDynamicCache`, *optional*): cached past key and value projection states
+            past_key_values (`Cache`, *optional*): cached past key and value projection states
             use_cache (`bool`, *optional*):
                 If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                 (see `past_key_values`).
@@ -1271,7 +1099,7 @@ def forward(
         input_ids: torch.LongTensor | None = None,
         attention_mask: torch.Tensor | None = None,
         position_ids: torch.LongTensor | None = None,
-        past_key_values: FalconHybridMambaAttentionDynamicCache | None = None,
+        past_key_values: Cache | None = None,
         inputs_embeds: torch.FloatTensor | None = None,
         use_cache: bool | None = None,
         **kwargs: Unpack[TransformersKwargs],
@@ -1284,10 +1112,7 @@ def forward(
         hidden_states = inputs_embeds

         if use_cache and past_key_values is None:
-            logger.warning_once(
-                "FalconH1 requires an initialized `FalconHybridMambaAttentionDynamicCache` to return a cache. None was "
-                "provided, so no cache will be returned."
- ) + past_key_values = DynamicCache(config=self.config) if position_ids is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 @@ -1365,7 +1190,7 @@ def forward( input_ids: torch.LongTensor | None = None, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, - past_key_values: FalconHybridMambaAttentionDynamicCache | None = None, + past_key_values: Cache | None = None, inputs_embeds: torch.FloatTensor | None = None, labels: torch.LongTensor | None = None, use_cache: bool | None = None, @@ -1427,18 +1252,6 @@ def prepare_inputs_for_generation( is_first_iteration=False, **kwargs, ): - # Overwritten -- has a unique cache type, `FalconHybridMambaAttentionDynamicCache` - - if past_key_values is None: - past_key_values = FalconHybridMambaAttentionDynamicCache( - self.config, - input_ids.shape[0], - self.dtype, - devices=[ - self.model.layers[i].mamba.conv1d.weight.device for i in range(self.config.num_hidden_layers) - ], - ) - kwargs["logits_to_keep"] = self.config.num_logits_to_keep model_inputs = super().prepare_inputs_for_generation( input_ids, diff --git a/src/transformers/models/falcon_h1/modular_falcon_h1.py b/src/transformers/models/falcon_h1/modular_falcon_h1.py index 4bd7301cd3fd..985e14dda2c4 100644 --- a/src/transformers/models/falcon_h1/modular_falcon_h1.py +++ b/src/transformers/models/falcon_h1/modular_falcon_h1.py @@ -19,15 +19,26 @@ """PyTorch FalconH1 model.""" from collections.abc import Callable -from typing import Any import torch import torch.nn.functional as F from torch import nn -from transformers.activations import ACT2FN -from transformers.models.jamba.modeling_jamba import HybridMambaAttentionDynamicCache -from transformers.models.llama.modeling_llama import ( +from ... import initialization as init +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache +from ...integrations.hub_kernels import lazy_load_kernel +from ...masking_utils import create_causal_mask +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging +from ...utils.generic import merge_with_config_defaults +from ...utils.import_utils import resolve_internal_import +from ...utils.output_capturing import capture_outputs +from ..llama.modeling_llama import ( LlamaAttention, LlamaForCausalLM, LlamaMLP, @@ -36,152 +47,19 @@ apply_rotary_pos_emb, eager_attention_forward, ) -from transformers.models.mamba2.modeling_mamba2 import ( +from ..mamba2.modeling_mamba2 import ( MambaRMSNormGated, apply_mask_to_padding_states, pad_tensor_by_size, reshape_into_chunks, segment_sum, ) - -from ... 
import initialization as init -from ...cache_utils import Cache -from ...integrations.hub_kernels import lazy_load_kernel -from ...masking_utils import create_causal_mask -from ...modeling_flash_attention_utils import FlashAttentionKwargs -from ...modeling_layers import GradientCheckpointingLayer -from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast -from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel -from ...processing_utils import Unpack -from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging -from ...utils.generic import merge_with_config_defaults -from ...utils.import_utils import resolve_internal_import -from ...utils.output_capturing import capture_outputs from .configuration_falcon_h1 import FalconH1Config logger = logging.get_logger(__name__) -class FalconHybridMambaAttentionDynamicCache(HybridMambaAttentionDynamicCache): - """ - A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache - (which has a constant shape regardless of seq_len). - - This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states` - and `ssm_states` for mamba cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor - For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`, - while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors). - For mamba layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors), - while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`, - and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`. - """ - - def __init__( - self, - config: FalconH1Config, - batch_size: int, - dtype: torch.dtype = torch.float16, - devices: list[str] | None = None, - ): - self.seqlen_offset = 0 - self.dtype = dtype - self.has_previous_state = False - self.conv_kernel_size = config.mamba_d_conv - - self.intermediate_size = ( - config.mamba_d_ssm if config.mamba_d_ssm is not None else int(config.mamba_expand * config.hidden_size) - ) - - self.conv_states = { - i: torch.zeros( - batch_size, - self.intermediate_size + 2 * config.mamba_n_groups * config.mamba_d_state, - self.conv_kernel_size, - device=devices[i], - dtype=dtype, - ) - for i in range(config.num_hidden_layers) - } - self.ssm_states = { - i: torch.zeros( - batch_size, - config.mamba_n_heads, - config.mamba_d_head, - config.mamba_d_state, - device=devices[i], - dtype=dtype, - ) - for i in range(config.num_hidden_layers) - } - - self.transformer_layers = [] - for i in range(config.num_hidden_layers): - self.transformer_layers.append(i) - - self.key_cache: list[torch.Tensor] = [] - self.value_cache: list[torch.Tensor] = [] - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: dict[str, Any] | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. - - Parameters: - key_states (`torch.Tensor`): - The new key states to cache. - value_states (`torch.Tensor`): - The new value states to cache. - layer_idx (`int`): - The index of the layer to cache the states for. - cache_kwargs (`dict[str, Any]`, `optional`): - Additional arguments for the cache subclass. 
No additional arguments are used in `DynamicCache`. - - Return: - A tuple containing the updated key and value states. - """ - # Update the cache - if len(self.key_cache) <= layer_idx: - # There may be skipped layers, fill them with empty lists - for _ in range(len(self.key_cache), layer_idx): - self.key_cache.append([]) - self.value_cache.append([]) - self.key_cache.append(key_states) - self.value_cache.append(value_states) - elif len(self.key_cache[layer_idx]) == 0: # fills previously skipped layers; checking for tensor causes errors - self.key_cache[layer_idx] = key_states - self.value_cache[layer_idx] = value_states - else: - self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) - self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) - - return self.key_cache[layer_idx], self.value_cache[layer_idx] - - def update_conv_state( - self, layer_idx: int, new_conv_state: torch.Tensor, cache_init: bool = False - ) -> torch.Tensor: - # Technically, those update are not logically correct if the prefill is smaller than `conv_kernel_size`, - # as it will `roll` anyway in the first decoding step even though it should `roll` ONLY if the cache is already full. - # But since `conv_kernel_size=4` in practice, it's almost impossible to have a smaller prefill so it's mostly fine for now - if cache_init: - self.conv_states[layer_idx] = new_conv_state.to(self.conv_states[layer_idx].device) - else: - self.conv_states[layer_idx] = self.conv_states[layer_idx].roll(shifts=-1, dims=-1) - self.conv_states[layer_idx][:, :, -1] = new_conv_state[:, 0, :].to(self.conv_states[layer_idx].device) - - return self.conv_states[layer_idx] - - def reset(self): - self.has_previous_state = False - self.conv_states.zero_() - self.ssm_states.zero_() - - class FalconH1RotaryEmbedding(LlamaRotaryEmbedding): pass @@ -386,7 +264,7 @@ def __init__(self, config: FalconH1Config, layer_idx: int): def cuda_kernels_forward( self, hidden_states: torch.Tensor, - cache_params: FalconHybridMambaAttentionDynamicCache | None = None, + cache_params: Cache | None = None, attention_mask: torch.Tensor | None = None, ): # 1. Gated MLP's linear projection @@ -402,12 +280,7 @@ def cuda_kernels_forward( groups_time_state_size = self.n_groups * self.ssm_state_size use_precomputed_states = ( - cache_params is not None - and cache_params.has_previous_state - and seq_len == 1 - and cache_params.conv_states[self.layer_idx].shape[0] - == cache_params.ssm_states[self.layer_idx].shape[0] - == batch_size + cache_params is not None and cache_params.has_previous_state(self.layer_idx) and seq_len == 1 ) # getting projected states from cache if it exists @@ -421,7 +294,7 @@ def cuda_kernels_forward( # 2. 
Convolution sequence transformation hidden_states_B_C = causal_conv1d_update( hidden_states_B_C, - cache_params.conv_states[self.layer_idx], + cache_params.layers[self.layer_idx].conv_states, self.conv1d.weight.squeeze(1), self.conv1d.bias, self.activation, @@ -443,7 +316,7 @@ def cuda_kernels_forward( C = C.view(batch_size, self.n_groups, C.shape[1] // self.n_groups) hidden_states_reshaped = hidden_states.view(batch_size, self.num_heads, self.head_dim) hidden_states = selective_state_update( - cache_params.ssm_states[self.layer_idx], + cache_params.layers[self.layer_idx].ssm_states, hidden_states_reshaped, dt, A, @@ -516,7 +389,7 @@ def cuda_kernels_forward( hidden_states_B_C.permute(0, 2, 1), (self.conv_kernel_size - hidden_states_B_C.shape[-2], 0), ) - cache_params.update_conv_state(self.layer_idx, conv_states, cache_init=True) + conv_states = cache_params.update_conv_state(conv_states, self.layer_idx) time_step = nn.functional.softplus(dt + self.dt_bias) # 1D Convolution @@ -563,7 +436,7 @@ def cuda_kernels_forward( **dt_limit_kwargs, ) if ssm_state is not None and cache_params is not None: - cache_params.ssm_states[self.layer_idx].copy_(ssm_state) + ssm_state = cache_params.update_ssm_state(ssm_state, self.layer_idx) scan_output = scan_output.view(batch_size, seq_len, -1) # Multiply "gate" branch and apply extra normalization layer if self.mamba_rms_norm: @@ -577,7 +450,7 @@ def cuda_kernels_forward( def torch_forward( self, input_states, - cache_params: FalconHybridMambaAttentionDynamicCache | None = None, + cache_params: Cache | None = None, attention_mask: torch.Tensor | None = None, ): batch_size, seq_len, _ = input_states.shape @@ -592,19 +465,13 @@ def torch_forward( gate, hidden_states_B_C, dt = projected_states.split([ self.intermediate_size, self.conv_dim, self.num_heads ], dim=-1) + hidden_states_B_C = hidden_states_B_C.transpose(1,2) - use_precomputed_states = ( - cache_params is not None - and cache_params.has_previous_state - and seq_len == 1 - and cache_params.conv_states[self.layer_idx].shape[0] - == cache_params.ssm_states[self.layer_idx].shape[0] - == batch_size - ) + use_precomputed_states = cache_params is not None and cache_params.has_previous_state(self.layer_idx) and seq_len == 1 # 2. 
Convolution sequence transformation if use_precomputed_states: - conv_states = cache_params.update_conv_state(self.layer_idx, hidden_states_B_C, cache_init=False) + conv_states = cache_params.update_conv_state(hidden_states_B_C, self.layer_idx) # We need to guarantee that anything regarding the cache is on the same device conv_states = conv_states.to(device=self.conv1d.weight.device) @@ -617,13 +484,12 @@ def torch_forward( else: # Init cache if cache_params is not None: - hidden_states_B_C_transposed = hidden_states_B_C.transpose(1, 2) conv_states = nn.functional.pad( - hidden_states_B_C_transposed, (self.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0) + hidden_states_B_C, (self.conv_kernel_size - hidden_states_B_C.shape[-1], 0) ) - conv_states = cache_params.update_conv_state(self.layer_idx, conv_states, cache_init=True) + conv_states = cache_params.update_conv_state(conv_states, self.layer_idx) - hidden_states_B_C = self.act(self.conv1d(hidden_states_B_C.transpose(1, 2))[..., :seq_len].transpose(1, 2)) + hidden_states_B_C = self.act(self.conv1d(hidden_states_B_C)[..., :seq_len].transpose(1, 2)) hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask) hidden_states, B, C = torch.split( @@ -636,7 +502,7 @@ def torch_forward( A = -torch.exp(self.A_log.float()) # [num_heads] if use_precomputed_states: # We need to guarantee that anything regarding the cache is on the same device - cache_device = cache_params.ssm_states[self.layer_idx].device + cache_device = cache_params.layers[self.layer_idx].ssm_states.device # Note: there is no need to pad parameter matrices here, as there is just one new token # for batched generation @@ -666,9 +532,8 @@ def torch_forward( dBx = (dB * hidden_states[..., None]).to(device=cache_device) # State calculation - cache_params.ssm_states[self.layer_idx].copy_( - cache_params.ssm_states[self.layer_idx] * dA + dBx - ) + ssm_states = cache_params.layers[self.layer_idx].ssm_states * dA + dBx + ssm_states = cache_params.update_ssm_state(ssm_states, self.layer_idx) # Subsequent output # [bsz, n_groups * state_size] -> [bsz, num_heads, state_size] @@ -677,7 +542,7 @@ def torch_forward( C = C.reshape(batch_size, -1, C.shape[-1]) # [bsz, num_heads, head_dim] - ssm_states = cache_params.ssm_states[self.layer_idx].to(device=C.device, dtype=C.dtype) # Shape: [b, h, d, n] + ssm_states = ssm_states.to(device=C.device, dtype=C.dtype) # Shape: [b, h, d, n] # Reshape ssm_states to merge the first two dimensions ssm_states_reshaped = ssm_states.view(batch_size * self.num_heads, self.head_dim, self.ssm_state_size) # Shape: [b*h, d, n] C_reshaped = C.view(batch_size * self.num_heads, self.ssm_state_size, 1) # Shape: [b*h, n, 1] @@ -738,10 +603,7 @@ def torch_forward( # 3. 
Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
             # (middle term of factorization of off-diag blocks; A terms)
-            if use_precomputed_states:
-                previous_states = cache_params.ssm_states[self.layer_idx][:, None, ...].to(device=states.device)
-            else:
-                previous_states = torch.zeros_like(states[:, :1])
+            previous_states = torch.zeros_like(states[:, :1])
             states = torch.cat([previous_states, states], dim=1)
             decay_chunk = torch.exp(segment_sum(nn.functional.pad(A_cumsum[:, :, :, -1], (1, 0))))
             decay_chunk = decay_chunk.transpose(1, 3)
@@ -768,7 +630,7 @@ def torch_forward(

             # Init cache
             if ssm_state is not None and cache_params is not None:
-                cache_params.ssm_states[self.layer_idx].copy_(ssm_state)
+                ssm_state = cache_params.update_ssm_state(ssm_state, self.layer_idx)

         if self.mamba_rms_norm:
             scan_output = self.norm(y, gate)
@@ -785,7 +647,7 @@ def torch_forward(
     def forward(
         self,
         hidden_states,
-        cache_params: FalconHybridMambaAttentionDynamicCache | None = None,
+        cache_params: Cache | None = None,
         attention_mask: torch.Tensor | None = None,
         **kwargs,
     ):
@@ -839,7 +701,7 @@ def forward(
         attention_mask: torch.Tensor | None = None,
         mamba_attention_mask: torch.Tensor | None = None,
         position_ids: torch.LongTensor | None = None,
-        past_key_values: FalconHybridMambaAttentionDynamicCache | None = None,
+        past_key_values: Cache | None = None,
         use_cache: bool | None = False,
         position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
         **kwargs,
@@ -849,7 +711,7 @@ def forward(
             hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
             attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                 `(batch, sequence_length)` where padding elements are indicated by 0.
-            past_key_values (`FalconHybridMambaAttentionDynamicCache`, *optional*): cached past key and value projection states
+            past_key_values (`Cache`, *optional*): cached past key and value projection states
             use_cache (`bool`, *optional*):
                 If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                 (see `past_key_values`).
@@ -1000,7 +862,7 @@ def forward(
         input_ids: torch.LongTensor | None = None,
         attention_mask: torch.Tensor | None = None,
         position_ids: torch.LongTensor | None = None,
-        past_key_values: FalconHybridMambaAttentionDynamicCache | None = None,
+        past_key_values: Cache | None = None,
         inputs_embeds: torch.FloatTensor | None = None,
         use_cache: bool | None = None,
         **kwargs: Unpack[TransformersKwargs],
@@ -1013,10 +875,7 @@ def forward(
         hidden_states = inputs_embeds

         if use_cache and past_key_values is None:
-            logger.warning_once(
-                "FalconH1 requires an initialized `FalconHybridMambaAttentionDynamicCache` to return a cache. None was "
-                "provided, so no cache will be returned."
- ) + past_key_values = DynamicCache(config=self.config) if position_ids is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 @@ -1080,7 +939,7 @@ def forward( input_ids: torch.LongTensor | None = None, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, - past_key_values: FalconHybridMambaAttentionDynamicCache | None = None, + past_key_values: Cache | None = None, inputs_embeds: torch.FloatTensor | None = None, labels: torch.LongTensor | None = None, use_cache: bool | None = None, @@ -1142,18 +1001,6 @@ def prepare_inputs_for_generation( is_first_iteration=False, **kwargs, ): - # Overwritten -- has a unique cache type, `FalconHybridMambaAttentionDynamicCache` - - if past_key_values is None: - past_key_values = FalconHybridMambaAttentionDynamicCache( - self.config, - input_ids.shape[0], - self.dtype, - devices=[ - self.model.layers[i].mamba.conv1d.weight.device for i in range(self.config.num_hidden_layers) - ], - ) - kwargs["logits_to_keep"] = self.config.num_logits_to_keep model_inputs = super().prepare_inputs_for_generation( input_ids, diff --git a/src/transformers/models/nemotron_h/modeling_nemotron_h.py b/src/transformers/models/nemotron_h/modeling_nemotron_h.py index 4455c9ba49b7..87c456cf4ecf 100644 --- a/src/transformers/models/nemotron_h/modeling_nemotron_h.py +++ b/src/transformers/models/nemotron_h/modeling_nemotron_h.py @@ -22,7 +22,6 @@ import math from collections.abc import Callable -from typing import Any import torch import torch.nn.functional as F @@ -30,6 +29,7 @@ from ... import initialization as init from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...integrations import ( lazy_load_kernel, @@ -54,119 +54,6 @@ logger = logging.get_logger(__name__) -class NemotronHHybridDynamicCache: - """ - A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache - (which has a constant shape regardless of seq_len). - - This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states` - and `ssm_states` for mamba cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor - For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`, - while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors). - For mamba layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors), - while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`, - and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`. 
- """ - - is_compileable = False - - def __init__( - self, config: NemotronHConfig, batch_size: int, dtype: torch.dtype = torch.float16, device: str | None = None - ): - self.dtype = dtype - self.layers_block_type = config.layers_block_type - self.has_previous_state = False - self.intermediate_size = int(config.mamba_num_heads * config.mamba_head_dim) - self.ssm_state_size = config.ssm_state_size - self.conv_kernel_size = config.conv_kernel - self.n_mamba_heads = config.mamba_num_heads - self.transformer_layers = [] - self._modules = {} - self._parameters = {} - self._buffers = {} - self.conv_states = {} - self.ssm_states = {} - for i in range(config.num_hidden_layers): - if self.layers_block_type[i] == "mamba": - # Only allocate mamba cache for mamba layers - self.conv_states[i] = torch.zeros( - batch_size, - self.intermediate_size + 2 * config.n_groups * self.ssm_state_size, - self.conv_kernel_size, - device=device, - dtype=dtype, - ) - self.ssm_states[i] = torch.zeros( - batch_size, - self.n_mamba_heads, - config.mamba_head_dim, - self.ssm_state_size, - device=device, - dtype=dtype, - ) - else: - # For attention and moe layers, use empty tensors - self.conv_states[i] = torch.tensor([[]] * batch_size, device=device) - self.ssm_states[i] = torch.tensor([[]] * batch_size, device=device) - - if self.layers_block_type[i] == "attention": - self.transformer_layers.append(i) - self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] - self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] - - def __len__(self): - return len(self.key_cache) - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: dict[str, Any] | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - # Update the cache - if self.key_cache[layer_idx].shape[-1] == 0: - self.key_cache[layer_idx] = key_states - self.value_cache[layer_idx] = value_states - else: - self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2) - self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2) - - return self.key_cache[layer_idx], self.value_cache[layer_idx] - - def reorder_cache(self, beam_idx: torch.LongTensor): - """Reorders the cache for beam search, given the selected beam indices.""" - if self.get_seq_length() > 0: - for layer_idx in range(len(self.key_cache)): - device = self.key_cache[layer_idx].device - self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) - device = self.value_cache[layer_idx].device - self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) - - device = self.conv_states[layer_idx].device - self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx.to(device)) - device = self.ssm_states[layer_idx].device - self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device)) - - def get_seq_length(self, layer_idx: int | None = 0) -> int: - """Returns the sequence length of the cached states. 
A layer index can be optionally passed.""" - # take any layer that contains cache and not empty tensor - layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx - if len(self.key_cache) <= layer_idx or self.key_cache[layer_idx].numel() == 0: - return 0 - return self.key_cache[layer_idx].shape[-2] - - def get_mask_sizes(self, query_length: int, layer_idx: int) -> tuple[int, int]: - """Return the length and offset of the cache, used to generate the mask""" - kv_offset = 0 - kv_length = self.get_seq_length(layer_idx) + query_length - return kv_length, kv_offset - - def reset(self): - self.conv_states.zero_() - self.ssm_states.zero_() - - # Helper methods for segment sum computation @@ -329,7 +216,7 @@ def __init__(self, config: NemotronHConfig, layer_idx: int | None = None): def cuda_kernels_forward( self, hidden_states: torch.Tensor, - cache_params: NemotronHHybridDynamicCache | None = None, + cache_params: Cache | None = None, attention_mask: torch.Tensor | None = None, ): # set up dimensions for reshapes later @@ -339,7 +226,7 @@ def cuda_kernels_forward( d_to_remove = 2 * self.intermediate_size + 2 * self.n_groups * self.ssm_state_size + self.num_heads # getting projected states from cache if it exists - if cache_params is not None and cache_params.has_previous_state: + if cache_params is not None and cache_params.has_previous_state(self.layer_idx): in_projected_states = self.in_proj(hidden_states.squeeze(1)) # (B 2D) d_mlp = (in_projected_states.shape[-1] - d_to_remove) // 2 split_projection_dim = [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads] @@ -347,7 +234,7 @@ def cuda_kernels_forward( hidden_states_B_C = causal_conv1d_update( hidden_states_B_C, - cache_params.conv_states[self.layer_idx], + cache_params.layers[self.layer_idx].conv_states, self.conv1d.weight.squeeze(1), self.conv1d.bias, self.activation, @@ -368,7 +255,7 @@ def cuda_kernels_forward( C = C.view(batch_size, self.n_groups, C.shape[1] // self.n_groups) hidden_states_reshaped = hidden_states.view(batch_size, self.num_heads, self.head_dim) hidden_states = selective_state_update( - cache_params.ssm_states[self.layer_idx], + cache_params.layers[self.layer_idx].ssm_states, hidden_states_reshaped, dt, A, @@ -432,7 +319,7 @@ def cuda_kernels_forward( conv_state = nn.functional.pad( hidden_states_B_C_t, (self.conv_kernel_size - hidden_states_B_C_t.shape[-1], 0) ) - cache_params.conv_states[self.layer_idx].copy_(conv_state) + conv_state = cache_params.update_conv_state(conv_state, self.layer_idx) if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]: hidden_states_B_C = self.act( self.conv1d(hidden_states_B_C.transpose(1, 2)).transpose(1, 2)[:, :seq_len] @@ -469,7 +356,7 @@ def cuda_kernels_forward( **dt_limit_kwargs, ) if ssm_state is not None and cache_params is not None: - cache_params.ssm_states[self.layer_idx].copy_(ssm_state) + cache_params.update_ssm_state(ssm_state, self.layer_idx) scan_output = scan_output.view(batch_size, seq_len, -1) # Multiply "gate" branch and apply extra normalization layer scan_output = self.norm(scan_output, gate) @@ -477,11 +364,11 @@ def cuda_kernels_forward( return out # fmt: off - def torch_forward(self, input_states, cache_params: NemotronHHybridDynamicCache | None=None, attention_mask: torch.Tensor | None=None): + def torch_forward(self, input_states, cache_params: Cache | None=None, attention_mask: torch.Tensor | None = None): batch_size, seq_len, _ = input_states.shape dtype = input_states.dtype # Gated MLP's 
linear projection - if cache_params is not None and cache_params.has_previous_state: + if cache_params is not None and cache_params.has_previous_state(self.layer_idx): projected_states = self.in_proj(input_states.squeeze(1)) else: if attention_mask is not None: @@ -492,43 +379,35 @@ def torch_forward(self, input_states, cache_params: NemotronHHybridDynamicCache _, _, gate, hidden_states, dt = projected_states.split( [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1 ) + hidden_states = hidden_states.transpose(1, 2) + + use_precomputed_state = cache_params is not None and cache_params.has_previous_state(self.layer_idx) # Convolution sequence transformation - if cache_params is not None: - ssm_state = cache_params.ssm_states[self.layer_idx].clone() - ssm_state = ssm_state.to(hidden_states.device) - if cache_params.has_previous_state: - gate = gate.unsqueeze(1) - conv_state = cache_params.conv_states[self.layer_idx] # [batch, intermediate_size, conv_kernel_size] - conv_state = torch.roll(conv_state, shifts=-1, dims=-1) - # handle batched generation - states are copied through - conv_state[:, :, -1] = hidden_states[:, 0, :] if hidden_states.ndim == 3 else hidden_states - cache_params.conv_states[self.layer_idx].copy_(conv_state) - hidden_states = torch.sum(conv_state.to(projected_states.device) * self.conv1d.weight[:, 0, :], dim=-1) - if self.use_conv_bias: - hidden_states += self.conv1d.bias - hidden_states = self.act(hidden_states).to(dtype)[:, None, ...] # [batch, 1, intermediate_size] : decoding - else: - hidden_states = hidden_states.transpose(1,2) + if use_precomputed_state: + gate = gate.unsqueeze(1) + conv_state = cache_params.update_conv_state(hidden_states, self.layer_idx) + hidden_states = torch.sum(conv_state.to(projected_states.device) * self.conv1d.weight[:, 0, :], dim=-1) + if self.use_conv_bias: + hidden_states += self.conv1d.bias + hidden_states = self.act(hidden_states).to(dtype)[:, None, ...] 
# [batch, 1, intermediate_size] : decoding + else: + if cache_params is not None: conv_state = nn.functional.pad( hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0) ) - cache_params.conv_states[self.layer_idx].copy_(conv_state) - hidden_states = self.act(self.conv1d(hidden_states).transpose(1,2))[:, :seq_len, :] # [batch, intermediate_size, seq_len] - if attention_mask is not None: - dtype = hidden_states.dtype - # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66 - hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) - else: - ssm_state = torch.zeros( - (batch_size, self.num_heads, self.head_dim, self.ssm_state_size), - device=hidden_states.device, dtype=dtype - ) - hidden_states = self.act(self.conv1d(hidden_states.transpose(1, 2))[..., :seq_len].transpose(1, 2)) + conv_state = cache_params.update_conv_state(conv_state, self.layer_idx) + + hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len].transpose(1, 2)) + if attention_mask is not None: + dtype = hidden_states.dtype + # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66 + hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) + hidden_states, B, C = torch.split(hidden_states, [self.intermediate_size, self.n_groups * self.ssm_state_size, self.n_groups * self.ssm_state_size], dim=-1) A = -torch.exp(self.A_log.float()) # [num_heads] - if cache_params is not None and cache_params.has_previous_state: + if use_precomputed_state: # Note: there is no need to pad parameter matrices here, as there is just one new token # for batched generation dt = dt[:, None, ...] if dt.ndim == 2 else dt[:, 0, :][:, None, ...] @@ -557,9 +436,9 @@ def torch_forward(self, input_states, cache_params: NemotronHHybridDynamicCache dBx = dB * hidden_states[..., None] # State calculation - cache_params.ssm_states[self.layer_idx].copy_( - cache_params.ssm_states[self.layer_idx] * dA + dBx - ) + ssm_state = cache_params.layers[self.layer_idx].ssm_states.clone() + ssm_state = ssm_state * dA + dBx + ssm_state = cache_params.update_ssm_state(ssm_state, self.layer_idx) # Subsequent output # [bsz, n_groups * state_size] -> [bsz, num_heads, state_size] @@ -568,7 +447,7 @@ def torch_forward(self, input_states, cache_params: NemotronHHybridDynamicCache C = C.reshape(batch_size, -1, C.shape[-1]) # [bsz, num_heads, head_dim] - ssm_states = cache_params.ssm_states[self.layer_idx].to(C.dtype) # Shape: [b, h, d, n] + ssm_states = ssm_state.to(C.dtype) # Shape: [b, h, d, n] # Reshape ssm_states to merge the first two dimensions ssm_states_reshaped = ssm_states.view(batch_size * self.num_heads, self.head_dim, self.ssm_state_size) # Shape: [b*h, d, n] C_reshaped = C.view(batch_size * self.num_heads, self.ssm_state_size, 1) # Shape: [b*h, n, 1] @@ -629,10 +508,7 @@ def torch_forward(self, input_states, cache_params: NemotronHHybridDynamicCache B_decay_contraction = B * decay_states.permute(0, 2, 3, 1)[..., None] # permute back B * decay states states = (B_decay_contraction.permute(0, 1, 3, 2, 4)[..., None] * hidden_states.permute(0, 1, 3, 2, 4)[..., None, :]).sum(dim=3).permute(0, 1, 2, 4, 3) - if cache_params is not None and cache_params.has_previous_state: - previous_states = cache_params.ssm_states[self.layer_idx][:, None, ...] 
- else: - previous_states = torch.zeros_like(states[:, :1]) + previous_states = torch.zeros_like(states[:, :1]) states = torch.cat([previous_states, states], dim=1) decay_chunk = torch.exp(segment_sum(nn.functional.pad(A_cumsum[:, :, :, -1], (1, 0)))) @@ -660,7 +536,7 @@ def torch_forward(self, input_states, cache_params: NemotronHHybridDynamicCache y = y[:, :seq_len, :, :] y = y.reshape(batch_size, seq_len, -1) if ssm_state is not None and cache_params is not None: - cache_params.ssm_states[self.layer_idx].copy_(ssm_state) + cache_params.update_ssm_state(ssm_state, self.layer_idx) scan_output = self.norm(y, gate) @@ -674,7 +550,7 @@ def torch_forward(self, input_states, cache_params: NemotronHHybridDynamicCache def forward( self, hidden_states, - cache_params: NemotronHHybridDynamicCache | None = None, + cache_params: Cache | None = None, attention_mask: torch.Tensor | None = None, **kwargs, ): @@ -968,7 +844,7 @@ def forward( self, hidden_states: torch.Tensor, attention_mask: torch.Tensor | None = None, - past_key_values: NemotronHHybridDynamicCache | None = None, + past_key_values: Cache | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.Tensor, torch.Tensor | None]: input_shape = hidden_states.shape[:-1] @@ -1036,7 +912,7 @@ def __init__(self, config, layer_idx): def forward( self, hidden_states, - past_key_values: NemotronHHybridDynamicCache | None = None, + past_key_values: Cache | None = None, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, use_cache: bool | None = False, @@ -1162,7 +1038,7 @@ def forward( input_ids: torch.LongTensor | None = None, inputs_embeds: torch.LongTensor | None = None, position_ids: torch.LongTensor | None = None, - past_key_values: NemotronHHybridDynamicCache | None = None, + past_key_values: Cache | None = None, use_cache: bool | None = None, attention_mask: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], @@ -1174,12 +1050,7 @@ def forward( inputs_embeds = self.embeddings(input_ids) if use_cache and past_key_values is None: - past_key_values = NemotronHHybridDynamicCache( - config=self.config, - batch_size=inputs_embeds.shape[0], - dtype=inputs_embeds.dtype, - device=inputs_embeds.device, - ) + past_key_values = DynamicCache(config=self.config) hidden_states = inputs_embeds @@ -1260,7 +1131,7 @@ def forward( input_ids: torch.LongTensor | None = None, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, - past_key_values: NemotronHHybridDynamicCache | None = None, + past_key_values: Cache | None = None, inputs_embeds: torch.FloatTensor | None = None, labels: torch.LongTensor | None = None, use_cache: bool | None = None, @@ -1327,13 +1198,6 @@ def prepare_inputs_for_generation( is_first_iteration=False, **kwargs, ): - # Overwritten -- has a unique cache type, `NemotronHHybridDynamicCache` - - if past_key_values is None: - past_key_values = NemotronHHybridDynamicCache( - self.config, input_ids.shape[0], dtype=self.dtype, device=self.device - ) - kwargs["logits_to_keep"] = self.config.num_logits_to_keep model_inputs = super().prepare_inputs_for_generation( input_ids, diff --git a/src/transformers/models/nemotron_h/modular_nemotron_h.py b/src/transformers/models/nemotron_h/modular_nemotron_h.py index 8412be1dc5f9..5a849833cc89 100644 --- a/src/transformers/models/nemotron_h/modular_nemotron_h.py +++ b/src/transformers/models/nemotron_h/modular_nemotron_h.py @@ -22,6 +22,7 @@ from ... 
import initialization as init from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache from ...integrations import use_experts_implementation from ...masking_utils import create_causal_mask from ...modeling_layers import GradientCheckpointingLayer @@ -32,7 +33,7 @@ from ...models.llama.modeling_llama import LlamaRMSNorm from ...models.nemotron.modeling_nemotron import NemotronMLP from ...models.zamba.modeling_zamba import ZambaForCausalLM -from ...models.zamba2.modeling_zamba2 import Zamba2HybridDynamicCache, Zamba2MambaMixer, Zamba2RMSNormGated +from ...models.zamba2.modeling_zamba2 import Zamba2MambaMixer, Zamba2RMSNormGated from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging from ...utils.generic import merge_with_config_defaults @@ -45,52 +46,6 @@ is_fast_path_available = False -class NemotronHHybridDynamicCache(Zamba2HybridDynamicCache): - def __init__( - self, config: NemotronHConfig, batch_size: int, dtype: torch.dtype = torch.float16, device: str | None = None - ): - self.dtype = dtype - self.layers_block_type = config.layers_block_type - self.has_previous_state = False - self.intermediate_size = int(config.mamba_num_heads * config.mamba_head_dim) - self.ssm_state_size = config.ssm_state_size - self.conv_kernel_size = config.conv_kernel - self.n_mamba_heads = config.mamba_num_heads - self.transformer_layers = [] - self._modules = {} - self._parameters = {} - self._buffers = {} - self.conv_states = {} - self.ssm_states = {} - for i in range(config.num_hidden_layers): - if self.layers_block_type[i] == "mamba": - # Only allocate mamba cache for mamba layers - self.conv_states[i] = torch.zeros( - batch_size, - self.intermediate_size + 2 * config.n_groups * self.ssm_state_size, - self.conv_kernel_size, - device=device, - dtype=dtype, - ) - self.ssm_states[i] = torch.zeros( - batch_size, - self.n_mamba_heads, - config.mamba_head_dim, - self.ssm_state_size, - device=device, - dtype=dtype, - ) - else: - # For attention and moe layers, use empty tensors - self.conv_states[i] = torch.tensor([[]] * batch_size, device=device) - self.ssm_states[i] = torch.tensor([[]] * batch_size, device=device) - - if self.layers_block_type[i] == "attention": - self.transformer_layers.append(i) - self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] - self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] - - class NemotronHMamba2Mixer(Zamba2MambaMixer): def __init__(self, config: NemotronHConfig, layer_idx: int | None = None): super().__init__(config, layer_idx) @@ -133,7 +88,7 @@ def __init__(self, config: NemotronHConfig, layer_idx: int | None = None): def forward( self, hidden_states, - cache_params: NemotronHHybridDynamicCache | None = None, + cache_params: Cache | None = None, attention_mask: torch.Tensor | None = None, **kwargs, ): @@ -277,7 +232,7 @@ def forward( self, hidden_states: torch.Tensor, attention_mask: torch.Tensor | None = None, - past_key_values: NemotronHHybridDynamicCache | None = None, + past_key_values: Cache | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.Tensor, torch.Tensor | None]: return super().forward(hidden_states, attention_mask, past_key_values, **kwargs) @@ -318,7 +273,7 @@ def __init__(self, config, layer_idx): def forward( self, hidden_states, - past_key_values: NemotronHHybridDynamicCache | None = None, + past_key_values: 
Cache | None = None, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, use_cache: bool | None = False, @@ -444,7 +399,7 @@ def forward( input_ids: torch.LongTensor | None = None, inputs_embeds: torch.LongTensor | None = None, position_ids: torch.LongTensor | None = None, - past_key_values: NemotronHHybridDynamicCache | None = None, + past_key_values: Cache | None = None, use_cache: bool | None = None, attention_mask: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], @@ -456,12 +411,7 @@ def forward( inputs_embeds = self.embeddings(input_ids) if use_cache and past_key_values is None: - past_key_values = NemotronHHybridDynamicCache( - config=self.config, - batch_size=inputs_embeds.shape[0], - dtype=inputs_embeds.dtype, - device=inputs_embeds.device, - ) + past_key_values = DynamicCache(config=self.config) hidden_states = inputs_embeds @@ -532,7 +482,7 @@ def forward( input_ids: torch.LongTensor | None = None, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, - past_key_values: NemotronHHybridDynamicCache | None = None, + past_key_values: Cache | None = None, inputs_embeds: torch.FloatTensor | None = None, labels: torch.LongTensor | None = None, use_cache: bool | None = None, diff --git a/tests/models/falcon_h1/test_modeling_falcon_h1.py b/tests/models/falcon_h1/test_modeling_falcon_h1.py index 88c656dce1e0..e236a818835b 100644 --- a/tests/models/falcon_h1/test_modeling_falcon_h1.py +++ b/tests/models/falcon_h1/test_modeling_falcon_h1.py @@ -36,10 +36,8 @@ if is_torch_available(): import torch - from transformers import AutoTokenizer, FalconH1ForCausalLM, FalconH1Model - from transformers.models.falcon_h1.modeling_falcon_h1 import ( - FalconHybridMambaAttentionDynamicCache, - ) + from transformers import AutoTokenizer, DynamicCache, FalconH1ForCausalLM, FalconH1Model + from transformers.cache_utils import MambaLayer class FalconH1ModelTester: @@ -206,17 +204,9 @@ def create_and_check_decoder_model_past_large_inputs( model.eval() # first forward pass - # Attention: Jamba needs the cache to be initialized to return a cache! 
- past_key_values = FalconHybridMambaAttentionDynamicCache( - config, - input_ids.shape[0], - model.dtype, - devices=[model.device for _ in range(model.config.num_hidden_layers)], - ) outputs = model( input_ids, attention_mask=input_mask, - past_key_values=past_key_values, use_cache=True, ) past_key_values = outputs.past_key_values @@ -265,26 +255,34 @@ class FalconH1ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM ) def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config): - self.assertIsInstance(past_key_values, FalconHybridMambaAttentionDynamicCache) + self.assertIsInstance(past_key_values, DynamicCache) # (batch, kv heads, seq_length, head_dim) num_heads = getattr(config, "num_key_value_heads", config.num_attention_heads) head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) - expected_shape = (batch_size, num_heads, seq_length, head_dim) + attention_shape = (batch_size, num_heads, seq_length, head_dim) - self.assertListEqual( - [key_tensor.shape for key_tensor in past_key_values.key_cache], - [expected_shape] * len(past_key_values.key_cache), + conv_kernel_size = config.mamba_d_conv + intermediate_size = ( + config.mamba_d_ssm if config.mamba_d_ssm is not None else int(config.mamba_expand * config.hidden_size) ) - self.assertListEqual( - [value_cache.shape for value_cache in past_key_values.value_cache], - [expected_shape] * len(past_key_values.value_cache), + conv_shape = ( + batch_size, + intermediate_size + 2 * config.mamba_n_groups * config.mamba_d_state, + conv_kernel_size, ) + ssm_shape = (batch_size, config.mamba_n_heads, config.mamba_d_head, config.mamba_d_state) + + for idx in range(len(past_key_values)): + if config.layers_block_type[idx] == "mamba": + self.assertEqual(past_key_values.layers[idx].conv_states.shape, conv_shape) + self.assertEqual(past_key_values.layers[idx].ssm_states.shape, ssm_shape) + else: + self.assertEqual(past_key_values.layers[idx].keys.shape, attention_shape) + self.assertEqual(past_key_values.layers[idx].values.shape, attention_shape) def _check_caches_are_equal(self, cache1, cache2): - if not isinstance(cache1, FalconHybridMambaAttentionDynamicCache) or not isinstance( - cache2, FalconHybridMambaAttentionDynamicCache - ): + if not isinstance(cache1, DynamicCache) or not isinstance(cache2, DynamicCache): raise ValueError("The wrong cache is being used!") if not len(cache1) == len(cache2): @@ -292,10 +290,12 @@ def _check_caches_are_equal(self, cache1, cache2): num_layers = len(cache1) for idx in range(num_layers): - torch.testing.assert_close(cache1.key_cache[idx], cache2.key_cache[idx]) - torch.testing.assert_close(cache1.value_cache[idx], cache2.value_cache[idx]) - torch.testing.assert_close(cache1.conv_states[idx], cache2.conv_states[idx]) - torch.testing.assert_close(cache1.ssm_states[idx], cache2.ssm_states[idx]) + if isinstance(cache1.layers[idx], MambaLayer): + torch.testing.assert_close(cache1.layers[idx].conv_states, cache2.layers[idx].conv_states) + torch.testing.assert_close(cache1.layers[idx].ssm_states, cache2.layers[idx].ssm_states) + else: + torch.testing.assert_close(cache1.layers[idx].keys, cache2.layers[idx].keys) + torch.testing.assert_close(cache1.layers[idx].values, cache2.layers[idx].values) def setUp(self): self.model_tester = FalconH1ModelTester(self) diff --git a/tests/models/nemotron_h/test_modeling_nemotron_h.py b/tests/models/nemotron_h/test_modeling_nemotron_h.py index 467971673065..43a3ea8af9b1 100644 --- 
a/tests/models/nemotron_h/test_modeling_nemotron_h.py +++ b/tests/models/nemotron_h/test_modeling_nemotron_h.py @@ -40,13 +40,8 @@ if is_torch_available(): import torch - from transformers import ( - NemotronHForCausalLM, - NemotronHModel, - ) - from transformers.models.nemotron_h.modeling_nemotron_h import ( - NemotronHHybridDynamicCache, - ) + from transformers import DynamicCache, NemotronHForCausalLM, NemotronHModel + from transformers.cache_utils import MambaLayer class NemotronHModelTester: @@ -237,12 +232,9 @@ def create_and_check_decoder_model_past_large_inputs( model.eval() # first forward pass - # Attention: NemotronH needs the cache to be initialized to return a cache! - past_key_values = NemotronHHybridDynamicCache(config, input_ids.shape[0], model.dtype, device=model.device) outputs = model( input_ids, attention_mask=input_mask, - past_key_values=past_key_values, use_cache=True, ) past_key_values = outputs.past_key_values @@ -319,17 +311,11 @@ def create_and_check_mamba2_slow_vs_fast_forward(self, config, input_ids, *args) self.parent.assertTrue(torch.allclose(outputs_fast, outputs_slow, atol=1e-3, rtol=1e-3)) # Test with cache - batch_size = input_ids.shape[0] - cache_params = NemotronHHybridDynamicCache( - config=config, batch_size=batch_size, dtype=token_emb.dtype, device=torch_device - ) - + cache_params = DynamicCache(config=config) outputs_fast_cached = mamba_mixer.cuda_kernels_forward(token_emb, cache_params=cache_params) # Reset cache for fair comparison - cache_params_slow = NemotronHHybridDynamicCache( - config=config, batch_size=batch_size, dtype=token_emb.dtype, device=torch_device - ) + cache_params_slow = DynamicCache(config=config) outputs_slow_cached = mamba_mixer.torch_forward(token_emb, cache_params=cache_params_slow) self.parent.assertTrue(torch.allclose(outputs_fast_cached, outputs_slow_cached, atol=1e-3, rtol=1e-3)) @@ -368,7 +354,7 @@ class NemotronHModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester ) def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config): - self.assertIsInstance(past_key_values, NemotronHHybridDynamicCache) + self.assertIsInstance(past_key_values, DynamicCache) # (batch, kv heads, seq_length, head_dim) num_heads = getattr(config, "num_key_value_heads", config.num_attention_heads) @@ -388,14 +374,14 @@ def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_l for idx in range(len(past_key_values)): if config.layers_block_type[idx] == "mamba": - self.assertEqual(past_key_values.conv_states[idx].shape, conv_shape) - self.assertEqual(past_key_values.ssm_states[idx].shape, ssm_shape) - elif config.layers_block_type[idx] == "attention": - self.assertEqual(past_key_values.key_cache[idx].shape, attention_shape) - self.assertEqual(past_key_values.value_cache[idx].shape, attention_shape) - - def _check_caches_are_equal(self, cache1: NemotronHHybridDynamicCache, cache2: NemotronHHybridDynamicCache): - if not isinstance(cache1, NemotronHHybridDynamicCache) or not isinstance(cache2, NemotronHHybridDynamicCache): + self.assertEqual(past_key_values.layers[idx].conv_states.shape, conv_shape) + self.assertEqual(past_key_values.layers[idx].ssm_states.shape, ssm_shape) + else: + self.assertEqual(past_key_values.layers[idx].keys.shape, attention_shape) + self.assertEqual(past_key_values.layers[idx].values.shape, attention_shape) + + def _check_caches_are_equal(self, cache1: DynamicCache, cache2: DynamicCache): + if not isinstance(cache1, DynamicCache) or not isinstance(cache2, 
DynamicCache): raise ValueError("The wrong cache is being used!") if not len(cache1) == len(cache2): @@ -403,10 +389,14 @@ def _check_caches_are_equal(self, cache1: NemotronHHybridDynamicCache, cache2: N num_layers = len(cache1) for idx in range(num_layers): - torch.testing.assert_close(cache1.key_cache[idx], cache2.key_cache[idx]) - torch.testing.assert_close(cache1.value_cache[idx], cache2.value_cache[idx]) - torch.testing.assert_close(cache1.conv_states[idx], cache2.conv_states[idx]) - torch.testing.assert_close(cache1.ssm_states[idx], cache2.ssm_states[idx]) + # Mamba layer + if type(cache1.layers[idx]) is MambaLayer: + torch.testing.assert_close(cache1.layers[idx].conv_states, cache2.layers[idx].conv_states) + torch.testing.assert_close(cache1.layers[idx].ssm_states, cache2.layers[idx].ssm_states) + # Attention layer + else: + torch.testing.assert_close(cache1.layers[idx].keys, cache2.layers[idx].keys) + torch.testing.assert_close(cache1.layers[idx].values, cache2.layers[idx].values) def setUp(self): self.model_tester = NemotronHModelTester(self) From b156adeb278eb94059e202a4fcdc74f6db1b6107 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 24 Mar 2026 15:51:11 +0100 Subject: [PATCH 22/56] config --- src/transformers/models/nemotron_h/configuration_nemotron_h.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/models/nemotron_h/configuration_nemotron_h.py b/src/transformers/models/nemotron_h/configuration_nemotron_h.py index f2ab3ed25ebd..cf6d51bc2c45 100644 --- a/src/transformers/models/nemotron_h/configuration_nemotron_h.py +++ b/src/transformers/models/nemotron_h/configuration_nemotron_h.py @@ -80,6 +80,7 @@ class NemotronHConfig(PreTrainedConfig): ```""" model_type = "nemotron_h" + attribute_map = {"layer_types": "layers_block_type"} keys_to_ignore_at_inference = ["past_key_values"] vocab_size: int = 131072 From 330e397d5aa4461141a95177e76c12644edcb78a Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 24 Mar 2026 16:03:03 +0100 Subject: [PATCH 23/56] finalize almost everything --- src/transformers/models/lfm2/modeling_lfm2.py | 187 ++--------------- src/transformers/models/lfm2/modular_lfm2.py | 185 ++--------------- .../models/lfm2_moe/modeling_lfm2_moe.py | 189 ++---------------- .../models/lfm2_moe/modular_lfm2_moe.py | 15 +- tests/models/lfm2/test_modeling_lfm2.py | 32 +-- .../models/lfm2_moe/test_modeling_lfm2_moe.py | 32 +-- 6 files changed, 87 insertions(+), 553 deletions(-) diff --git a/src/transformers/models/lfm2/modeling_lfm2.py b/src/transformers/models/lfm2/modeling_lfm2.py index ad3d154c2d06..ef753e3b2893 100644 --- a/src/transformers/models/lfm2/modeling_lfm2.py +++ b/src/transformers/models/lfm2/modeling_lfm2.py @@ -18,13 +18,13 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Callable -from typing import Any, Optional +from typing import Optional import torch import torch.nn.functional as F from torch import nn -from ...cache_utils import Cache +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func from ...masking_utils import create_causal_mask @@ -152,160 +152,6 @@ def forward(self, x): return self.w2(F.silu(self.w1(x)) * self.w3(x)) -class Lfm2HybridConvCache: - """ - Attention and conv cache for Lfm2. - - It stores the Key and Value states as a list of tensors, one for each layer. 
- Attention layer cache shape: `[batch_size, num_heads, seq_len, head_dim]`. - Conv layer cache shape: `[batch_size, hidden_size, L_cache-1]`. - """ - - # Override @property existing in Cache - max_batch_size = None - is_compileable = False - key_cache = None - value_cache = None - - def __init__( - self, - config: Lfm2Config, - max_batch_size: int, - dtype: torch.dtype = torch.float32, - device: torch.device | str | None = None, - ): - self.key_cache = [] - self.value_cache = [] - self.max_batch_size = max_batch_size - self.layer_types = config.layer_types - self.first_attention_layer = self.layer_types.index("full_attention") - self.last_conv_layer = len(self.layer_types) - self.layer_types[::-1].index("conv") - 1 - self.conv_L_cache = config.conv_L_cache - self._dtype = dtype - self.has_previous_state = False - - self.conv_cache: list[torch.Tensor] = [] - device = torch.device(device) if device is not None else None - - for _ in range(config.num_hidden_layers): - conv_state = torch.zeros( - self.max_batch_size, - config.hidden_size, - self.conv_L_cache, - dtype=self._dtype, - device=device, - ) - self.conv_cache.append(conv_state) - self.key_cache.append(torch.tensor([])) - self.value_cache.append(torch.tensor([])) - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: dict[str, Any] | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. - - Parameters: - key_states (`torch.Tensor`): - The new key states to cache. - value_states (`torch.Tensor`): - The new value states to cache. - layer_idx (`int`): - The index of the layer to cache the states for. - cache_kwargs (`Dict[str, Any]`, `optional`): - Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`. - - Return: - A tuple containing the updated key and value states. - """ - # Update the cache - if self.key_cache[layer_idx].numel() == 0: - self.key_cache[layer_idx] = key_states - self.value_cache[layer_idx] = value_states - else: - self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) - self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) - - return self.key_cache[layer_idx], self.value_cache[layer_idx] - - def update_conv_state( - self, layer_idx: int, new_conv_state: torch.Tensor, cache_init: bool = False - ) -> torch.Tensor: - # Technically, those update are not logically correct if the prefill is smaller than `conv_kernel_size`, - # as it will `roll` anyway in the first decoding step even though it should `roll` ONLY if the cache is already full. 
- # But since `conv_kernel_size=4` in practice, it's almost impossible to have a smaller prefill so it's mostly fine for now - if cache_init: - self.conv_cache[layer_idx] = new_conv_state.to(self.conv_cache[layer_idx].device) - else: - self.conv_cache[layer_idx] = self.conv_cache[layer_idx].roll(shifts=-1, dims=-1) - self.conv_cache[layer_idx][:, :, -1] = new_conv_state[:, :, -1].to(self.conv_cache[layer_idx].device) - - # If last layer is updated, set the flag - if layer_idx == self.last_conv_layer: - self.has_previous_state = True - - return self.conv_cache[layer_idx] - - def reorder_cache(self, beam_idx: torch.LongTensor): - """Reorders the cache for beam search, given the selected beam indices.""" - for layer_idx in range(len(self.key_cache)): - if self.key_cache[layer_idx].numel(): - device = self.key_cache[layer_idx].device - self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) - device = self.value_cache[layer_idx].device - self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) - - if self.conv_cache[layer_idx].numel(): - device = self.conv_cache[layer_idx].device - self.conv_cache[layer_idx] = self.conv_cache[layer_idx].index_select(0, beam_idx.to(device)) - - def get_seq_length(self, layer_idx: int | None = 0) -> int: - """Returns the sequence length of the cached states. A layer index can be optionally passed.""" - # take any layer that contains cache and not empty tensor - layer_idx = self.first_attention_layer if self.layer_types[layer_idx] != "full_attention" else layer_idx - if len(self.key_cache) <= layer_idx or self.key_cache[layer_idx].numel() == 0: - return 0 - return self.key_cache[layer_idx].shape[-2] - - def get_mask_sizes(self, query_length: int, layer_idx: int) -> tuple[int, int]: - """ - Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for - the given layer at `layer_idx`. - The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size), - for each layer. 
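As a minimal illustration of the roll-and-overwrite decode-step update implemented in `update_conv_state` above (toy tensor sizes, not taken from the model):

    import torch

    # Toy conv window of shape [batch=1, hidden=2, conv_kernel_size=4]
    conv_states = torch.arange(8.0).reshape(1, 2, 4)
    new_column = torch.tensor([[[100.0], [200.0]]])  # newest conv input, shape [1, 2, 1]

    # Decode step: shift the window one slot to the left and write the new column last
    conv_states = conv_states.roll(shifts=-1, dims=-1)
    conv_states[:, :, -1:] = new_column
    # conv_states[0, 0] is now [1., 2., 3., 100.]; the oldest entry dropped off the left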
- """ - full_mask_kv_offset = 0 - past_seen_tokens = self.get_seq_length() - kv_length = query_length + past_seen_tokens - return kv_length, full_mask_kv_offset - - def crop(self, max_length: int): - """Crop the cache to the given length""" - if max_length < 0: - max_length = self.get_seq_length() - abs(max_length) - - if self.get_seq_length() <= max_length: - return - - for idx in range(len(self.key_cache)): - if self.key_cache[idx].numel(): - self.key_cache[idx] = self.key_cache[idx][..., :max_length, :] - self.value_cache[idx] = self.value_cache[idx][..., :max_length, :] - - def __len__(self) -> int: - return len(self.key_cache) - - def reset(self): - self.has_previous_state = False - for layer_idx in range(len(self.conv_cache)): - # In-place ops prevent breaking the static address - self.conv_cache[layer_idx].zero_() - - def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] @@ -400,7 +246,7 @@ def forward( hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: torch.Tensor | None, - past_key_values: Lfm2HybridConvCache | None = None, + past_key_values: Cache | None = None, **kwargs, ) -> tuple[torch.Tensor, torch.Tensor | None]: input_shape = hidden_states.shape[:-1] @@ -477,7 +323,7 @@ def __init__( def cuda_kernels_forward( self, x: torch.Tensor, - past_key_values: Lfm2HybridConvCache | None = None, + past_key_values: Cache | None = None, attention_mask: torch.Tensor | None = None, ): x = apply_mask_to_padding_states(x, attention_mask) @@ -487,10 +333,10 @@ def cuda_kernels_forward( Bx = B * x conv_weights = self.conv.weight.view(self.conv.weight.size(0), self.conv.weight.size(2)) - if past_key_values is not None and past_key_values.has_previous_state: + if past_key_values is not None and past_key_values.has_previous_state(self.layer_idx): conv_out = causal_conv1d_update( Bx.squeeze(-1), - past_key_values.conv_cache[self.layer_idx], + past_key_values.layers[self.layer_idx].conv_states, conv_weights, self.conv.bias, None, @@ -499,7 +345,7 @@ def cuda_kernels_forward( else: if past_key_values is not None: conv_state = nn.functional.pad(Bx, (self.L_cache - Bx.shape[-1], 0)) - past_key_values.update_conv_state(self.layer_idx, conv_state, cache_init=True) + conv_state = past_key_values.update_conv_state(conv_state, self.layer_idx) conv_out = causal_conv1d_fn(Bx, conv_weights, self.conv.bias, activation=None) @@ -510,7 +356,7 @@ def cuda_kernels_forward( def slow_forward( self, x: torch.Tensor, - past_key_values: Lfm2HybridConvCache | None = None, + past_key_values: Cache | None = None, attention_mask: torch.Tensor | None = None, ): seqlen = x.shape[1] @@ -521,8 +367,8 @@ def slow_forward( Bx = B * x - if past_key_values is not None and past_key_values.has_previous_state: - conv_state = past_key_values.update_conv_state(self.layer_idx, Bx, cache_init=False) + if past_key_values is not None and past_key_values.has_previous_state(self.layer_idx): + conv_state = past_key_values.update_conv_state(Bx, self.layer_idx) conv_out = torch.sum(conv_state.to(Bx.device) * self.conv.weight[:, 0, :], dim=-1) if self.bias: conv_out += self.conv.bias @@ -531,7 +377,7 @@ def slow_forward( else: if past_key_values is not None: conv_state = nn.functional.pad(Bx, (self.L_cache - Bx.shape[-1], 0)) - conv_state = past_key_values.update_conv_state(self.layer_idx, conv_state, cache_init=True) + conv_state = past_key_values.update_conv_state(conv_state, self.layer_idx) conv_out = self.conv(Bx)[..., :seqlen] @@ -543,7 +389,7 
@@ def slow_forward( def forward( self, hidden_states: torch.Tensor, - past_key_values: Lfm2HybridConvCache | None = None, + past_key_values: Cache | None = None, attention_mask: torch.Tensor | None = None, ): if is_fast_path_available and "cuda" in hidden_states.device.type and not is_torchdynamo_compiling(): @@ -570,7 +416,7 @@ def forward( position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, - past_key_values: Lfm2HybridConvCache | None = None, + past_key_values: Cache | None = None, **kwargs, ) -> torch.Tensor: residual = hidden_states @@ -639,7 +485,7 @@ def forward( input_ids: torch.LongTensor | None = None, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, - past_key_values: Lfm2HybridConvCache | None = None, + past_key_values: Cache | None = None, inputs_embeds: torch.FloatTensor | None = None, use_cache: bool | None = None, **kwargs: Unpack[TransformersKwargs], @@ -651,10 +497,7 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) if use_cache and past_key_values is None: - batch_size = inputs_embeds.shape[0] - past_key_values = Lfm2HybridConvCache( - config=self.config, max_batch_size=batch_size, dtype=self.dtype, device=self.device - ) + past_key_values = DynamicCache(config=self.config) if position_ids is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 diff --git a/src/transformers/models/lfm2/modular_lfm2.py b/src/transformers/models/lfm2/modular_lfm2.py index 6f3a754d69ae..c08b61956081 100644 --- a/src/transformers/models/lfm2/modular_lfm2.py +++ b/src/transformers/models/lfm2/modular_lfm2.py @@ -12,12 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Callable -from typing import Any import torch import torch.nn.functional as F from torch import nn +from ...cache_utils import Cache, DynamicCache from ...masking_utils import create_causal_mask from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast @@ -80,160 +80,6 @@ def forward(self, x): return self.w2(F.silu(self.w1(x)) * self.w3(x)) -class Lfm2HybridConvCache: - """ - Attention and conv cache for Lfm2. - - It stores the Key and Value states as a list of tensors, one for each layer. - Attention layer cache shape: `[batch_size, num_heads, seq_len, head_dim]`. - Conv layer cache shape: `[batch_size, hidden_size, L_cache-1]`. 
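Since the model now falls back to `DynamicCache(config=self.config)` when no cache is passed, callers no longer need the dedicated hybrid cache class. A rough usage sketch (the checkpoint id is a placeholder, not taken from this patch):

    from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache

    model_id = "org/lfm2-checkpoint"  # placeholder id for any Lfm2-style checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id)

    inputs = tokenizer("The capital of France is", return_tensors="pt")

    # Let the model allocate its own cache lazily ...
    out = model(**inputs, use_cache=True)
    assert isinstance(out.past_key_values, DynamicCache)

    # ... or build one from the config up front, as the updated tests do
    cache = DynamicCache(config=model.config)
    out = model(**inputs, past_key_values=cache, use_cache=True)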
- """ - - # Override @property existing in Cache - max_batch_size = None - is_compileable = False - key_cache = None - value_cache = None - - def __init__( - self, - config: Lfm2Config, - max_batch_size: int, - dtype: torch.dtype = torch.float32, - device: torch.device | str | None = None, - ): - self.key_cache = [] - self.value_cache = [] - self.max_batch_size = max_batch_size - self.layer_types = config.layer_types - self.first_attention_layer = self.layer_types.index("full_attention") - self.last_conv_layer = len(self.layer_types) - self.layer_types[::-1].index("conv") - 1 - self.conv_L_cache = config.conv_L_cache - self._dtype = dtype - self.has_previous_state = False - - self.conv_cache: list[torch.Tensor] = [] - device = torch.device(device) if device is not None else None - - for _ in range(config.num_hidden_layers): - conv_state = torch.zeros( - self.max_batch_size, - config.hidden_size, - self.conv_L_cache, - dtype=self._dtype, - device=device, - ) - self.conv_cache.append(conv_state) - self.key_cache.append(torch.tensor([])) - self.value_cache.append(torch.tensor([])) - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: dict[str, Any] | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. - - Parameters: - key_states (`torch.Tensor`): - The new key states to cache. - value_states (`torch.Tensor`): - The new value states to cache. - layer_idx (`int`): - The index of the layer to cache the states for. - cache_kwargs (`Dict[str, Any]`, `optional`): - Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`. - - Return: - A tuple containing the updated key and value states. - """ - # Update the cache - if self.key_cache[layer_idx].numel() == 0: - self.key_cache[layer_idx] = key_states - self.value_cache[layer_idx] = value_states - else: - self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) - self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) - - return self.key_cache[layer_idx], self.value_cache[layer_idx] - - def update_conv_state( - self, layer_idx: int, new_conv_state: torch.Tensor, cache_init: bool = False - ) -> torch.Tensor: - # Technically, those update are not logically correct if the prefill is smaller than `conv_kernel_size`, - # as it will `roll` anyway in the first decoding step even though it should `roll` ONLY if the cache is already full. 
- # But since `conv_kernel_size=4` in practice, it's almost impossible to have a smaller prefill so it's mostly fine for now - if cache_init: - self.conv_cache[layer_idx] = new_conv_state.to(self.conv_cache[layer_idx].device) - else: - self.conv_cache[layer_idx] = self.conv_cache[layer_idx].roll(shifts=-1, dims=-1) - self.conv_cache[layer_idx][:, :, -1] = new_conv_state[:, :, -1].to(self.conv_cache[layer_idx].device) - - # If last layer is updated, set the flag - if layer_idx == self.last_conv_layer: - self.has_previous_state = True - - return self.conv_cache[layer_idx] - - def reorder_cache(self, beam_idx: torch.LongTensor): - """Reorders the cache for beam search, given the selected beam indices.""" - for layer_idx in range(len(self.key_cache)): - if self.key_cache[layer_idx].numel(): - device = self.key_cache[layer_idx].device - self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) - device = self.value_cache[layer_idx].device - self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) - - if self.conv_cache[layer_idx].numel(): - device = self.conv_cache[layer_idx].device - self.conv_cache[layer_idx] = self.conv_cache[layer_idx].index_select(0, beam_idx.to(device)) - - def get_seq_length(self, layer_idx: int | None = 0) -> int: - """Returns the sequence length of the cached states. A layer index can be optionally passed.""" - # take any layer that contains cache and not empty tensor - layer_idx = self.first_attention_layer if self.layer_types[layer_idx] != "full_attention" else layer_idx - if len(self.key_cache) <= layer_idx or self.key_cache[layer_idx].numel() == 0: - return 0 - return self.key_cache[layer_idx].shape[-2] - - def get_mask_sizes(self, query_length: int, layer_idx: int) -> tuple[int, int]: - """ - Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for - the given layer at `layer_idx`. - The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size), - for each layer. 
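The conv path that replaces this class boils down to: on prefill, left-pad the inputs to the conv window length and store them; on later steps, hand only the newest column to `update_conv_state`. A simplified mirror of that logic (the helper name is illustrative; it assumes a cache whose layer at `layer_idx` is a Mamba-style layer):

    import torch.nn.functional as F

    def cache_conv_window(cache, layer_idx, Bx, L_cache):
        # Illustrative sketch only, mirroring the updated slow_forward cache handling
        if cache.has_previous_state(layer_idx):
            # Decode step: the layer rolls its stored window and appends the new column
            return cache.update_conv_state(Bx, layer_idx)
        # Prefill: left-pad the sequence to the conv window length before storing it
        conv_state = F.pad(Bx, (L_cache - Bx.shape[-1], 0))
        return cache.update_conv_state(conv_state, layer_idx)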
- """ - full_mask_kv_offset = 0 - past_seen_tokens = self.get_seq_length() - kv_length = query_length + past_seen_tokens - return kv_length, full_mask_kv_offset - - def crop(self, max_length: int): - """Crop the cache to the given length""" - if max_length < 0: - max_length = self.get_seq_length() - abs(max_length) - - if self.get_seq_length() <= max_length: - return - - for idx in range(len(self.key_cache)): - if self.key_cache[idx].numel(): - self.key_cache[idx] = self.key_cache[idx][..., :max_length, :] - self.value_cache[idx] = self.value_cache[idx][..., :max_length, :] - - def __len__(self) -> int: - return len(self.key_cache) - - def reset(self): - self.has_previous_state = False - for layer_idx in range(len(self.conv_cache)): - # In-place ops prevent breaking the static address - self.conv_cache[layer_idx].zero_() - - class Lfm2Attention(LlamaAttention): def __init__(self, config: Lfm2Config, layer_idx: int): super().__init__(config, layer_idx) @@ -251,7 +97,7 @@ def forward( hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: torch.Tensor | None, - past_key_values: Lfm2HybridConvCache | None = None, + past_key_values: Cache | None = None, **kwargs, ) -> tuple[torch.Tensor, torch.Tensor | None]: input_shape = hidden_states.shape[:-1] @@ -312,7 +158,7 @@ def __init__( def cuda_kernels_forward( self, x: torch.Tensor, - past_key_values: Lfm2HybridConvCache | None = None, + past_key_values: Cache | None = None, attention_mask: torch.Tensor | None = None, ): x = apply_mask_to_padding_states(x, attention_mask) @@ -322,10 +168,10 @@ def cuda_kernels_forward( Bx = B * x conv_weights = self.conv.weight.view(self.conv.weight.size(0), self.conv.weight.size(2)) - if past_key_values is not None and past_key_values.has_previous_state: + if past_key_values is not None and past_key_values.has_previous_state(self.layer_idx): conv_out = causal_conv1d_update( Bx.squeeze(-1), - past_key_values.conv_cache[self.layer_idx], + past_key_values.layers[self.layer_idx].conv_states, conv_weights, self.conv.bias, None, @@ -334,7 +180,7 @@ def cuda_kernels_forward( else: if past_key_values is not None: conv_state = nn.functional.pad(Bx, (self.L_cache - Bx.shape[-1], 0)) - past_key_values.update_conv_state(self.layer_idx, conv_state, cache_init=True) + conv_state = past_key_values.update_conv_state(conv_state, self.layer_idx) conv_out = causal_conv1d_fn(Bx, conv_weights, self.conv.bias, activation=None) @@ -345,7 +191,7 @@ def cuda_kernels_forward( def slow_forward( self, x: torch.Tensor, - past_key_values: Lfm2HybridConvCache | None = None, + past_key_values: Cache | None = None, attention_mask: torch.Tensor | None = None, ): seqlen = x.shape[1] @@ -356,8 +202,8 @@ def slow_forward( Bx = B * x - if past_key_values is not None and past_key_values.has_previous_state: - conv_state = past_key_values.update_conv_state(self.layer_idx, Bx, cache_init=False) + if past_key_values is not None and past_key_values.has_previous_state(self.layer_idx): + conv_state = past_key_values.update_conv_state(Bx, self.layer_idx) conv_out = torch.sum(conv_state.to(Bx.device) * self.conv.weight[:, 0, :], dim=-1) if self.bias: conv_out += self.conv.bias @@ -366,7 +212,7 @@ def slow_forward( else: if past_key_values is not None: conv_state = nn.functional.pad(Bx, (self.L_cache - Bx.shape[-1], 0)) - conv_state = past_key_values.update_conv_state(self.layer_idx, conv_state, cache_init=True) + conv_state = past_key_values.update_conv_state(conv_state, self.layer_idx) conv_out = 
self.conv(Bx)[..., :seqlen] @@ -378,7 +224,7 @@ def slow_forward( def forward( self, hidden_states: torch.Tensor, - past_key_values: Lfm2HybridConvCache | None = None, + past_key_values: Cache | None = None, attention_mask: torch.Tensor | None = None, ): if is_fast_path_available and "cuda" in hidden_states.device.type and not is_torchdynamo_compiling(): @@ -405,7 +251,7 @@ def forward( position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, - past_key_values: Lfm2HybridConvCache | None = None, + past_key_values: Cache | None = None, **kwargs, ) -> torch.Tensor: residual = hidden_states @@ -445,7 +291,7 @@ def forward( input_ids: torch.LongTensor | None = None, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, - past_key_values: Lfm2HybridConvCache | None = None, + past_key_values: Cache | None = None, inputs_embeds: torch.FloatTensor | None = None, use_cache: bool | None = None, **kwargs: Unpack[TransformersKwargs], @@ -457,10 +303,7 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) if use_cache and past_key_values is None: - batch_size = inputs_embeds.shape[0] - past_key_values = Lfm2HybridConvCache( - config=self.config, max_batch_size=batch_size, dtype=self.dtype, device=self.device - ) + past_key_values = DynamicCache(config=self.config) if position_ids is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 diff --git a/src/transformers/models/lfm2_moe/modeling_lfm2_moe.py b/src/transformers/models/lfm2_moe/modeling_lfm2_moe.py index 2a5d4564a1e1..0369ae31b8ae 100644 --- a/src/transformers/models/lfm2_moe/modeling_lfm2_moe.py +++ b/src/transformers/models/lfm2_moe/modeling_lfm2_moe.py @@ -19,14 +19,14 @@ # limitations under the License. from collections.abc import Callable -from typing import Any, Optional +from typing import Optional import torch import torch.nn.functional as F from torch import nn from ... import initialization as init -from ...cache_utils import Cache +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...integrations import ( use_experts_implementation, @@ -228,160 +228,6 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens return final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) -class Lfm2MoeHybridConvCache: - """ - Attention and conv cache for Lfm2Moe. - - It stores the Key and Value states as a list of tensors, one for each layer. - Attention layer cache shape: `[batch_size, num_heads, seq_len, head_dim]`. - Conv layer cache shape: `[batch_size, hidden_size, L_cache-1]`. 
- """ - - # Override @property existing in Cache - max_batch_size = None - is_compileable = False - key_cache = None - value_cache = None - - def __init__( - self, - config: Lfm2MoeConfig, - max_batch_size: int, - dtype: torch.dtype = torch.float32, - device: torch.device | str | None = None, - ): - self.key_cache = [] - self.value_cache = [] - self.max_batch_size = max_batch_size - self.layer_types = config.layer_types - self.first_attention_layer = self.layer_types.index("full_attention") - self.last_conv_layer = len(self.layer_types) - self.layer_types[::-1].index("conv") - 1 - self.conv_L_cache = config.conv_L_cache - self._dtype = dtype - self.has_previous_state = False - - self.conv_cache: list[torch.Tensor] = [] - device = torch.device(device) if device is not None else None - - for _ in range(config.num_hidden_layers): - conv_state = torch.zeros( - self.max_batch_size, - config.hidden_size, - self.conv_L_cache, - dtype=self._dtype, - device=device, - ) - self.conv_cache.append(conv_state) - self.key_cache.append(torch.tensor([])) - self.value_cache.append(torch.tensor([])) - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: dict[str, Any] | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. - - Parameters: - key_states (`torch.Tensor`): - The new key states to cache. - value_states (`torch.Tensor`): - The new value states to cache. - layer_idx (`int`): - The index of the layer to cache the states for. - cache_kwargs (`Dict[str, Any]`, `optional`): - Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`. - - Return: - A tuple containing the updated key and value states. - """ - # Update the cache - if self.key_cache[layer_idx].numel() == 0: - self.key_cache[layer_idx] = key_states - self.value_cache[layer_idx] = value_states - else: - self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) - self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) - - return self.key_cache[layer_idx], self.value_cache[layer_idx] - - def update_conv_state( - self, layer_idx: int, new_conv_state: torch.Tensor, cache_init: bool = False - ) -> torch.Tensor: - # Technically, those update are not logically correct if the prefill is smaller than `conv_kernel_size`, - # as it will `roll` anyway in the first decoding step even though it should `roll` ONLY if the cache is already full. 
- # But since `conv_kernel_size=4` in practice, it's almost impossible to have a smaller prefill so it's mostly fine for now - if cache_init: - self.conv_cache[layer_idx] = new_conv_state.to(self.conv_cache[layer_idx].device) - else: - self.conv_cache[layer_idx] = self.conv_cache[layer_idx].roll(shifts=-1, dims=-1) - self.conv_cache[layer_idx][:, :, -1] = new_conv_state[:, :, -1].to(self.conv_cache[layer_idx].device) - - # If last layer is updated, set the flag - if layer_idx == self.last_conv_layer: - self.has_previous_state = True - - return self.conv_cache[layer_idx] - - def reorder_cache(self, beam_idx: torch.LongTensor): - """Reorders the cache for beam search, given the selected beam indices.""" - for layer_idx in range(len(self.key_cache)): - if self.key_cache[layer_idx].numel(): - device = self.key_cache[layer_idx].device - self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) - device = self.value_cache[layer_idx].device - self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) - - if self.conv_cache[layer_idx].numel(): - device = self.conv_cache[layer_idx].device - self.conv_cache[layer_idx] = self.conv_cache[layer_idx].index_select(0, beam_idx.to(device)) - - def get_seq_length(self, layer_idx: int | None = 0) -> int: - """Returns the sequence length of the cached states. A layer index can be optionally passed.""" - # take any layer that contains cache and not empty tensor - layer_idx = self.first_attention_layer if self.layer_types[layer_idx] != "full_attention" else layer_idx - if len(self.key_cache) <= layer_idx or self.key_cache[layer_idx].numel() == 0: - return 0 - return self.key_cache[layer_idx].shape[-2] - - def get_mask_sizes(self, query_length: int, layer_idx: int) -> tuple[int, int]: - """ - Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for - the given layer at `layer_idx`. - The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size), - for each layer. 
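For the full-attention pattern handled here, the mask sizes described in the docstring above reduce to a simple sum; a minimal sketch of that contract (standalone function, not the cache method itself):

    def get_mask_sizes(query_length: int, past_seen_tokens: int) -> tuple[int, int]:
        # Full-attention layers attend over everything seen so far, with no offset
        kv_length = query_length + past_seen_tokens
        kv_offset = 0
        return kv_length, kv_offset

    # Example: a 1-token decode step after a 7-token prefill attends over 8 key/value positions
    assert get_mask_sizes(1, 7) == (8, 0)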
- """ - full_mask_kv_offset = 0 - past_seen_tokens = self.get_seq_length() - kv_length = query_length + past_seen_tokens - return kv_length, full_mask_kv_offset - - def crop(self, max_length: int): - """Crop the cache to the given length""" - if max_length < 0: - max_length = self.get_seq_length() - abs(max_length) - - if self.get_seq_length() <= max_length: - return - - for idx in range(len(self.key_cache)): - if self.key_cache[idx].numel(): - self.key_cache[idx] = self.key_cache[idx][..., :max_length, :] - self.value_cache[idx] = self.value_cache[idx][..., :max_length, :] - - def __len__(self) -> int: - return len(self.key_cache) - - def reset(self): - self.has_previous_state = False - for layer_idx in range(len(self.conv_cache)): - # In-place ops prevent breaking the static address - self.conv_cache[layer_idx].zero_() - - def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] @@ -476,7 +322,7 @@ def forward( hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: torch.Tensor | None, - past_key_values: Lfm2MoeHybridConvCache | None = None, + past_key_values: Cache | None = None, **kwargs, ) -> tuple[torch.Tensor, torch.Tensor | None]: input_shape = hidden_states.shape[:-1] @@ -553,7 +399,7 @@ def __init__( def cuda_kernels_forward( self, x: torch.Tensor, - past_key_values: Lfm2MoeHybridConvCache | None = None, + past_key_values: Cache | None = None, attention_mask: torch.Tensor | None = None, ): x = apply_mask_to_padding_states(x, attention_mask) @@ -563,10 +409,10 @@ def cuda_kernels_forward( Bx = B * x conv_weights = self.conv.weight.view(self.conv.weight.size(0), self.conv.weight.size(2)) - if past_key_values is not None and past_key_values.has_previous_state: + if past_key_values is not None and past_key_values.has_previous_state(self.layer_idx): conv_out = causal_conv1d_update( Bx.squeeze(-1), - past_key_values.conv_cache[self.layer_idx], + past_key_values.layers[self.layer_idx].conv_states, conv_weights, self.conv.bias, None, @@ -575,7 +421,7 @@ def cuda_kernels_forward( else: if past_key_values is not None: conv_state = nn.functional.pad(Bx, (self.L_cache - Bx.shape[-1], 0)) - past_key_values.update_conv_state(self.layer_idx, conv_state, cache_init=True) + conv_state = past_key_values.update_conv_state(conv_state, self.layer_idx) conv_out = causal_conv1d_fn(Bx, conv_weights, self.conv.bias, activation=None) @@ -586,7 +432,7 @@ def cuda_kernels_forward( def slow_forward( self, x: torch.Tensor, - past_key_values: Lfm2MoeHybridConvCache | None = None, + past_key_values: Cache | None = None, attention_mask: torch.Tensor | None = None, ): seqlen = x.shape[1] @@ -597,8 +443,8 @@ def slow_forward( Bx = B * x - if past_key_values is not None and past_key_values.has_previous_state: - conv_state = past_key_values.update_conv_state(self.layer_idx, Bx, cache_init=False) + if past_key_values is not None and past_key_values.has_previous_state(self.layer_idx): + conv_state = past_key_values.update_conv_state(Bx, self.layer_idx) conv_out = torch.sum(conv_state.to(Bx.device) * self.conv.weight[:, 0, :], dim=-1) if self.bias: conv_out += self.conv.bias @@ -607,7 +453,7 @@ def slow_forward( else: if past_key_values is not None: conv_state = nn.functional.pad(Bx, (self.L_cache - Bx.shape[-1], 0)) - conv_state = past_key_values.update_conv_state(self.layer_idx, conv_state, cache_init=True) + conv_state = past_key_values.update_conv_state(conv_state, self.layer_idx) conv_out = self.conv(Bx)[..., :seqlen] @@ 
-619,7 +465,7 @@ def slow_forward( def forward( self, hidden_states: torch.Tensor, - past_key_values: Lfm2MoeHybridConvCache | None = None, + past_key_values: Cache | None = None, attention_mask: torch.Tensor | None = None, ): if is_fast_path_available and "cuda" in hidden_states.device.type and not is_torchdynamo_compiling(): @@ -650,7 +496,7 @@ def forward( position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, - past_key_values: Lfm2MoeHybridConvCache | None = None, + past_key_values: Cache | None = None, **kwargs, ) -> torch.Tensor: residual = hidden_states @@ -685,7 +531,7 @@ class Lfm2MoePreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True - _can_compile_fullgraph = False # uses a non-compilable custom cache class Lfm2MoeHybridConvCache + _can_compile_fullgraph = False # uses a non-compilable cache class _supports_attention_backend = True _can_record_outputs = { "hidden_states": Lfm2MoeDecoderLayer, @@ -729,7 +575,7 @@ def forward( input_ids: torch.LongTensor | None = None, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, - past_key_values: Lfm2MoeHybridConvCache | None = None, + past_key_values: Cache | None = None, inputs_embeds: torch.FloatTensor | None = None, use_cache: bool | None = None, **kwargs: Unpack[TransformersKwargs], @@ -741,10 +587,7 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) if use_cache and past_key_values is None: - batch_size = inputs_embeds.shape[0] - past_key_values = Lfm2MoeHybridConvCache( - config=self.config, max_batch_size=batch_size, dtype=self.dtype, device=self.device - ) + past_key_values = DynamicCache(config=self.config) if position_ids is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 diff --git a/src/transformers/models/lfm2_moe/modular_lfm2_moe.py b/src/transformers/models/lfm2_moe/modular_lfm2_moe.py index 7d97d8b70dd5..a1b2799d2bae 100644 --- a/src/transformers/models/lfm2_moe/modular_lfm2_moe.py +++ b/src/transformers/models/lfm2_moe/modular_lfm2_moe.py @@ -17,6 +17,7 @@ from torch import nn from ... 
import initialization as init +from ...cache_utils import Cache, DynamicCache from ...masking_utils import create_causal_mask from ...modeling_outputs import MoeModelOutputWithPast from ...modeling_utils import PreTrainedModel @@ -26,7 +27,6 @@ from ..lfm2.modeling_lfm2 import ( Lfm2Attention, Lfm2DecoderLayer, - Lfm2HybridConvCache, Lfm2MLP, Lfm2RotaryEmbedding, Lfm2ShortConv, @@ -110,10 +110,6 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens return final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) -class Lfm2MoeHybridConvCache(Lfm2HybridConvCache): - pass - - class Lfm2MoeAttention(Lfm2Attention): pass @@ -133,7 +129,7 @@ def __init__(self, config: Lfm2MoeConfig, layer_idx: int): class Lfm2MoePreTrainedModel(LlamaPreTrainedModel): - _can_compile_fullgraph = False # uses a non-compilable custom cache class Lfm2MoeHybridConvCache + _can_compile_fullgraph = False # uses a non-compilable cache class @torch.no_grad() def _init_weights(self, module): @@ -159,7 +155,7 @@ def forward( input_ids: torch.LongTensor | None = None, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, - past_key_values: Lfm2MoeHybridConvCache | None = None, + past_key_values: Cache | None = None, inputs_embeds: torch.FloatTensor | None = None, use_cache: bool | None = None, **kwargs: Unpack[TransformersKwargs], @@ -171,10 +167,7 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) if use_cache and past_key_values is None: - batch_size = inputs_embeds.shape[0] - past_key_values = Lfm2MoeHybridConvCache( - config=self.config, max_batch_size=batch_size, dtype=self.dtype, device=self.device - ) + past_key_values = DynamicCache(config=self.config) if position_ids is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 diff --git a/tests/models/lfm2/test_modeling_lfm2.py b/tests/models/lfm2/test_modeling_lfm2.py index 13afd1c2726b..f44ae5632641 100644 --- a/tests/models/lfm2/test_modeling_lfm2.py +++ b/tests/models/lfm2/test_modeling_lfm2.py @@ -29,8 +29,8 @@ if is_torch_available(): import torch - from transformers import Lfm2ForCausalLM, Lfm2Model - from transformers.models.lfm2.modeling_lfm2 import Lfm2HybridConvCache + from transformers import DynamicCache, Lfm2ForCausalLM, Lfm2Model + from transformers.cache_utils import MambaLayer class Lfm2ModelTester(CausalLMModelTester): @@ -53,7 +53,7 @@ class Lfm2ModelTest(CausalLMModelTest, unittest.TestCase): _torch_compile_train_cls = Lfm2ForCausalLM if is_torch_available() else None def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config): - self.assertIsInstance(past_key_values, Lfm2HybridConvCache) + self.assertIsInstance(past_key_values, DynamicCache) # (batch, kv heads, seq_length, head_dim) num_heads = getattr(config, "num_key_value_heads", config.num_attention_heads) @@ -61,15 +61,16 @@ def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_l attention_shape = (batch_size, num_heads, seq_length, head_dim) conv_shape = (batch_size, config.hidden_size, config.conv_L_cache) - for i in range(config.num_hidden_layers): - if config.layer_types[i] == "full_attention": - self.assertEqual(past_key_values.key_cache[i].shape, attention_shape) - self.assertEqual(past_key_values.value_cache[i].shape, attention_shape) + for idx in range(config.num_hidden_layers): + if config.layer_types[idx] == "full_attention": + self.assertEqual(past_key_values.layers[idx].keys.shape, 
attention_shape) + self.assertEqual(past_key_values.layers[idx].values.shape, attention_shape) else: - self.assertEqual(past_key_values.conv_cache[i].shape, conv_shape) + self.assertEqual(past_key_values.layers[idx].conv_states.shape, conv_shape) + self.assertEqual(past_key_values.layers[idx].ssm_states.shape, (1,)) - def _check_caches_are_equal(self, cache1: Lfm2HybridConvCache, cache2: Lfm2HybridConvCache): - if not isinstance(cache1, Lfm2HybridConvCache) or not isinstance(cache2, Lfm2HybridConvCache): + def _check_caches_are_equal(self, cache1: DynamicCache, cache2: DynamicCache): + if not isinstance(cache1, DynamicCache) or not isinstance(cache2, DynamicCache): raise ValueError("The wrong cache is being used!") if not len(cache1) == len(cache2): @@ -77,9 +78,14 @@ def _check_caches_are_equal(self, cache1: Lfm2HybridConvCache, cache2: Lfm2Hybri num_layers = len(cache1) for idx in range(num_layers): - torch.testing.assert_close(cache1.key_cache[idx], cache2.key_cache[idx]) - torch.testing.assert_close(cache1.value_cache[idx], cache2.value_cache[idx]) - torch.testing.assert_close(cache1.conv_cache[idx], cache2.conv_cache[idx]) + # Mamba layer + if type(cache1.layers[idx]) is MambaLayer: + torch.testing.assert_close(cache1.layers[idx].conv_states, cache2.layers[idx].conv_states) + torch.testing.assert_close(cache1.layers[idx].ssm_states, cache2.layers[idx].ssm_states) + # Attention layer + else: + torch.testing.assert_close(cache1.layers[idx].keys, cache2.layers[idx].keys) + torch.testing.assert_close(cache1.layers[idx].values, cache2.layers[idx].values) def test_attention_outputs(self): """Lfm2Moe alternates between attention and short-conv layers.""" diff --git a/tests/models/lfm2_moe/test_modeling_lfm2_moe.py b/tests/models/lfm2_moe/test_modeling_lfm2_moe.py index fa8aecc99707..a5ed1b9fb9b9 100644 --- a/tests/models/lfm2_moe/test_modeling_lfm2_moe.py +++ b/tests/models/lfm2_moe/test_modeling_lfm2_moe.py @@ -32,8 +32,8 @@ if is_torch_available(): import torch - from transformers import Lfm2MoeConfig, Lfm2MoeForCausalLM, Lfm2MoeModel - from transformers.models.lfm2_moe.modeling_lfm2_moe import Lfm2MoeHybridConvCache + from transformers import DynamicCache, Lfm2MoeConfig, Lfm2MoeForCausalLM, Lfm2MoeModel + from transformers.cache_utils import MambaLayer class Lfm2MoeModelTester(CausalLMModelTester): @@ -71,7 +71,7 @@ class Lfm2MoeModelTest(CausalLMModelTest, unittest.TestCase): _torch_compile_train_cls = Lfm2MoeForCausalLM if is_torch_available() else None def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config): - self.assertIsInstance(past_key_values, Lfm2MoeHybridConvCache) + self.assertIsInstance(past_key_values, DynamicCache) # (batch, kv heads, seq_length, head_dim) num_heads = getattr(config, "num_key_value_heads", config.num_attention_heads) @@ -79,15 +79,16 @@ def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_l attention_shape = (batch_size, num_heads, seq_length, head_dim) conv_shape = (batch_size, config.hidden_size, config.conv_L_cache) - for i in range(config.num_hidden_layers): - if config.layer_types[i] == "full_attention": - self.assertEqual(past_key_values.key_cache[i].shape, attention_shape) - self.assertEqual(past_key_values.value_cache[i].shape, attention_shape) + for idx in range(config.num_hidden_layers): + if config.layer_types[idx] == "full_attention": + self.assertEqual(past_key_values.layers[idx].keys.shape, attention_shape) + self.assertEqual(past_key_values.layers[idx].values.shape, 
attention_shape)
             else:
-                self.assertEqual(past_key_values.conv_cache[i].shape, conv_shape)
+                self.assertEqual(past_key_values.layers[idx].conv_states.shape, conv_shape)
+                self.assertEqual(past_key_values.layers[idx].ssm_states.shape, (1,))
 
-    def _check_caches_are_equal(self, cache1: Lfm2MoeHybridConvCache, cache2: Lfm2MoeHybridConvCache):
-        if not isinstance(cache1, Lfm2MoeHybridConvCache) or not isinstance(cache2, Lfm2MoeHybridConvCache):
+    def _check_caches_are_equal(self, cache1: DynamicCache, cache2: DynamicCache):
+        if not isinstance(cache1, DynamicCache) or not isinstance(cache2, DynamicCache):
             raise ValueError("The wrong cache is being used!")
 
         if not len(cache1) == len(cache2):
@@ -95,9 +96,14 @@ def _check_caches_are_equal(self, cache1: Lfm2MoeHybridConvCache, cache2: Lfm2Mo
 
         num_layers = len(cache1)
         for idx in range(num_layers):
-            torch.testing.assert_close(cache1.key_cache[idx], cache2.key_cache[idx])
-            torch.testing.assert_close(cache1.value_cache[idx], cache2.value_cache[idx])
-            torch.testing.assert_close(cache1.conv_cache[idx], cache2.conv_cache[idx])
+            # Mamba layer
+            if type(cache1.layers[idx]) is MambaLayer:
+                torch.testing.assert_close(cache1.layers[idx].conv_states, cache2.layers[idx].conv_states)
+                torch.testing.assert_close(cache1.layers[idx].ssm_states, cache2.layers[idx].ssm_states)
+            # Attention layer
+            else:
+                torch.testing.assert_close(cache1.layers[idx].keys, cache2.layers[idx].keys)
+                torch.testing.assert_close(cache1.layers[idx].values, cache2.layers[idx].values)
 
     def test_attention_outputs(self):
         """Lfm2Moe alternates between attention and short-conv layers."""

From b60c6f59d102a16490d355ce6ac7519b34521b7f Mon Sep 17 00:00:00 2001
From: Cyril Vallez
Date: Tue, 24 Mar 2026 16:08:56 +0100
Subject: [PATCH 24/56] simplify tests

---
 tests/generation/test_utils.py | 19 +++++++++++++++--
 tests/models/bamba/test_modeling_bamba.py | 19 -----------------
 .../falcon_h1/test_modeling_falcon_h1.py | 17 ---------------
 .../test_modeling_granitemoehybrid.py | 19 -----------------
 tests/models/jamba/test_modeling_jamba.py | 17 ---------------
 tests/models/lfm2/test_modeling_lfm2.py | 19 -----------------
 .../models/lfm2_moe/test_modeling_lfm2_moe.py | 19 -----------------
 tests/models/lfm2_vl/test_modeling_lfm2_vl.py | 14 -------------
 .../nemotron_h/test_modeling_nemotron_h.py | 19 -----------------
 tests/models/zamba/test_modeling_zamba.py | 21 -------------------
 tests/models/zamba2/test_modeling_zamba2.py | 21 -------------------
 11 files changed, 17 insertions(+), 187 deletions(-)

diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py
index 7d764b55d008..149ef8911176 100644
--- a/tests/generation/test_utils.py
+++ b/tests/generation/test_utils.py
@@ -81,6 +81,8 @@
     Cache,
     DynamicCache,
     EncoderDecoderCache,
+    MambaAndAttentionLayer,
+    MambaLayer,
     QuantoQuantizedLayer,
     StaticCache,
 )
@@ -2616,8 +2618,21 @@ def _check_caches_are_equal(self, cache1: Cache, cache2: Cache):
 
         num_layers = len(cache1)
         for idx in range(num_layers):
-            torch.testing.assert_close(cache1.layers[idx].keys, cache2.layers[idx].keys)
-            torch.testing.assert_close(cache1.layers[idx].values, cache2.layers[idx].values)
+            self.assertEqual(type(cache1.layers[idx]), type(cache2.layers[idx]))
+            # Mamba layer
+            if isinstance(cache1.layers[idx], MambaLayer):
+                torch.testing.assert_close(cache1.layers[idx].conv_states, cache2.layers[idx].conv_states)
+                torch.testing.assert_close(cache1.layers[idx].ssm_states, cache2.layers[idx].ssm_states)
+            # Mamba + Attention layer
+            elif isinstance(cache1.layers[idx],
MambaAndAttentionLayer): + torch.testing.assert_close(cache1.layers[idx].keys, cache2.layers[idx].keys) + torch.testing.assert_close(cache1.layers[idx].values, cache2.layers[idx].values) + torch.testing.assert_close(cache1.layers[idx].conv_states, cache2.layers[idx].conv_states) + torch.testing.assert_close(cache1.layers[idx].ssm_states, cache2.layers[idx].ssm_states) + # Attention layer + else: + torch.testing.assert_close(cache1.layers[idx].keys, cache2.layers[idx].keys) + torch.testing.assert_close(cache1.layers[idx].values, cache2.layers[idx].values) @require_torch diff --git a/tests/models/bamba/test_modeling_bamba.py b/tests/models/bamba/test_modeling_bamba.py index da236c72790c..6b9eb77dd64a 100644 --- a/tests/models/bamba/test_modeling_bamba.py +++ b/tests/models/bamba/test_modeling_bamba.py @@ -48,7 +48,6 @@ import torch from transformers import BambaForCausalLM, BambaModel, DynamicCache - from transformers.cache_utils import MambaLayer class BambaModelTester: @@ -305,24 +304,6 @@ def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_l self.assertEqual(past_key_values.layers[idx].keys.shape, attention_shape) self.assertEqual(past_key_values.layers[idx].values.shape, attention_shape) - def _check_caches_are_equal(self, cache1: DynamicCache, cache2: DynamicCache): - if not isinstance(cache1, DynamicCache) or not isinstance(cache2, DynamicCache): - raise ValueError("The wrong cache is being used!") - - if not len(cache1) == len(cache2): - raise ValueError("Both caches do not have the same number of layers.") - - num_layers = len(cache1) - for idx in range(num_layers): - # Mamba layer - if type(cache1.layers[idx]) is MambaLayer: - torch.testing.assert_close(cache1.layers[idx].conv_states, cache2.layers[idx].conv_states) - torch.testing.assert_close(cache1.layers[idx].ssm_states, cache2.layers[idx].ssm_states) - # Attention layer - else: - torch.testing.assert_close(cache1.layers[idx].keys, cache2.layers[idx].keys) - torch.testing.assert_close(cache1.layers[idx].values, cache2.layers[idx].values) - def setUp(self): self.model_tester = self.model_tester_class(self) self.config_tester = ConfigTester(self, config_class=self.model_tester.config_class, hidden_size=64) diff --git a/tests/models/falcon_h1/test_modeling_falcon_h1.py b/tests/models/falcon_h1/test_modeling_falcon_h1.py index e236a818835b..0cb54ff547f9 100644 --- a/tests/models/falcon_h1/test_modeling_falcon_h1.py +++ b/tests/models/falcon_h1/test_modeling_falcon_h1.py @@ -37,7 +37,6 @@ import torch from transformers import AutoTokenizer, DynamicCache, FalconH1ForCausalLM, FalconH1Model - from transformers.cache_utils import MambaLayer class FalconH1ModelTester: @@ -281,22 +280,6 @@ def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_l self.assertEqual(past_key_values.layers[idx].keys.shape, attention_shape) self.assertEqual(past_key_values.layers[idx].values.shape, attention_shape) - def _check_caches_are_equal(self, cache1, cache2): - if not isinstance(cache1, DynamicCache) or not isinstance(cache2, DynamicCache): - raise ValueError("The wrong cache is being used!") - - if not len(cache1) == len(cache2): - raise ValueError("Both caches do not have the same number of layers.") - - num_layers = len(cache1) - for idx in range(num_layers): - if isinstance(cache1.layers[idx], MambaLayer): - torch.testing.assert_close(cache1.layers[idx].conv_states, cache2.layers[idx].conv_states) - torch.testing.assert_close(cache1.layers[idx].ssm_states, cache2.layers[idx].ssm_states) - else: - 
torch.testing.assert_close(cache1.layers[idx].keys, cache2.layers[idx].keys) - torch.testing.assert_close(cache1.layers[idx].values, cache2.layers[idx].values) - def setUp(self): self.model_tester = FalconH1ModelTester(self) self.config_tester = ConfigTester(self, config_class=FalconH1Config, hidden_size=64) diff --git a/tests/models/granitemoehybrid/test_modeling_granitemoehybrid.py b/tests/models/granitemoehybrid/test_modeling_granitemoehybrid.py index e102ff9c93ea..c05e496808c9 100644 --- a/tests/models/granitemoehybrid/test_modeling_granitemoehybrid.py +++ b/tests/models/granitemoehybrid/test_modeling_granitemoehybrid.py @@ -47,7 +47,6 @@ import torch from transformers import DynamicCache, GraniteMoeHybridForCausalLM, GraniteMoeHybridModel - from transformers.cache_utils import MambaLayer class GraniteMoeHybridModelTester(BambaModelTester): @@ -106,24 +105,6 @@ class GraniteMoeHybridModelTest(ModelTesterMixin, GenerationTesterMixin, Pipelin # This is because we are hitting edge cases with the causal_mask buffer model_split_percents = [0.5, 0.7, 0.8] - def _check_caches_are_equal(self, cache1: DynamicCache, cache2: DynamicCache): - if not isinstance(cache1, DynamicCache) or not isinstance(cache2, DynamicCache): - raise ValueError("The wrong cache is being used!") - - if not len(cache1) == len(cache2): - raise ValueError("Both caches do not have the same number of layers.") - - num_layers = len(cache1) - for idx in range(num_layers): - # Mamba layer - if type(cache1.layers[idx]) is MambaLayer: - torch.testing.assert_close(cache1.layers[idx].conv_states, cache2.layers[idx].conv_states) - torch.testing.assert_close(cache1.layers[idx].ssm_states, cache2.layers[idx].ssm_states) - # Attention layer - else: - torch.testing.assert_close(cache1.layers[idx].keys, cache2.layers[idx].keys) - torch.testing.assert_close(cache1.layers[idx].values, cache2.layers[idx].values) - def setUp(self): self.model_tester = self.model_tester_class(self) self.config_tester = ConfigTester(self, config_class=self.model_tester.config_class, hidden_size=64) diff --git a/tests/models/jamba/test_modeling_jamba.py b/tests/models/jamba/test_modeling_jamba.py index 499941fcbe81..2b3853d4b051 100644 --- a/tests/models/jamba/test_modeling_jamba.py +++ b/tests/models/jamba/test_modeling_jamba.py @@ -44,7 +44,6 @@ import torch from transformers import DynamicCache, JambaForCausalLM, JambaForSequenceClassification, JambaModel - from transformers.cache_utils import MambaLayer class JambaConfigTester(ConfigTester): @@ -343,22 +342,6 @@ def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_l self.assertEqual(past_key_values.layers[idx].keys.shape, attention_shape) self.assertEqual(past_key_values.layers[idx].values.shape, attention_shape) - def _check_caches_are_equal(self, cache1: DynamicCache, cache2: DynamicCache): - if not isinstance(cache1, DynamicCache) or not isinstance(cache2, DynamicCache): - raise ValueError("The wrong cache is being used!") - - if not len(cache1) == len(cache2): - raise ValueError("Both caches do not have the same number of layers.") - - num_layers = len(cache1) - for idx in range(num_layers): - if isinstance(cache1.layers[idx], MambaLayer): - torch.testing.assert_close(cache1.layers[idx].conv_states, cache2.layers[idx].conv_states) - torch.testing.assert_close(cache1.layers[idx].ssm_states, cache2.layers[idx].ssm_states) - else: - torch.testing.assert_close(cache1.layers[idx].keys, cache2.layers[idx].keys) - torch.testing.assert_close(cache1.layers[idx].values, 
cache2.layers[idx].values) - def setUp(self): self.model_tester = JambaModelTester(self) self.config_tester = JambaConfigTester(self, config_class=JambaConfig, hidden_size=32) diff --git a/tests/models/lfm2/test_modeling_lfm2.py b/tests/models/lfm2/test_modeling_lfm2.py index f44ae5632641..12cc3fdb38f9 100644 --- a/tests/models/lfm2/test_modeling_lfm2.py +++ b/tests/models/lfm2/test_modeling_lfm2.py @@ -30,7 +30,6 @@ import torch from transformers import DynamicCache, Lfm2ForCausalLM, Lfm2Model - from transformers.cache_utils import MambaLayer class Lfm2ModelTester(CausalLMModelTester): @@ -69,24 +68,6 @@ def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_l self.assertEqual(past_key_values.layers[idx].conv_states.shape, conv_shape) self.assertEqual(past_key_values.layers[idx].ssm_states.shape, (1,)) - def _check_caches_are_equal(self, cache1: DynamicCache, cache2: DynamicCache): - if not isinstance(cache1, DynamicCache) or not isinstance(cache2, DynamicCache): - raise ValueError("The wrong cache is being used!") - - if not len(cache1) == len(cache2): - raise ValueError("Both caches do not have the same number of layers.") - - num_layers = len(cache1) - for idx in range(num_layers): - # Mamba layer - if type(cache1.layers[idx]) is MambaLayer: - torch.testing.assert_close(cache1.layers[idx].conv_states, cache2.layers[idx].conv_states) - torch.testing.assert_close(cache1.layers[idx].ssm_states, cache2.layers[idx].ssm_states) - # Attention layer - else: - torch.testing.assert_close(cache1.layers[idx].keys, cache2.layers[idx].keys) - torch.testing.assert_close(cache1.layers[idx].values, cache2.layers[idx].values) - def test_attention_outputs(self): """Lfm2Moe alternates between attention and short-conv layers.""" config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() diff --git a/tests/models/lfm2_moe/test_modeling_lfm2_moe.py b/tests/models/lfm2_moe/test_modeling_lfm2_moe.py index a5ed1b9fb9b9..8b7c2a55004f 100644 --- a/tests/models/lfm2_moe/test_modeling_lfm2_moe.py +++ b/tests/models/lfm2_moe/test_modeling_lfm2_moe.py @@ -33,7 +33,6 @@ import torch from transformers import DynamicCache, Lfm2MoeConfig, Lfm2MoeForCausalLM, Lfm2MoeModel - from transformers.cache_utils import MambaLayer class Lfm2MoeModelTester(CausalLMModelTester): @@ -87,24 +86,6 @@ def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_l self.assertEqual(past_key_values.layers[idx].conv_states.shape, conv_shape) self.assertEqual(past_key_values.layers[idx].ssm_states.shape, (1,)) - def _check_caches_are_equal(self, cache1: DynamicCache, cache2: DynamicCache): - if not isinstance(cache1, DynamicCache) or not isinstance(cache2, DynamicCache): - raise ValueError("The wrong cache is being used!") - - if not len(cache1) == len(cache2): - raise ValueError("Both caches do not have the same number of layers.") - - num_layers = len(cache1) - for idx in range(num_layers): - # Mamba layer - if type(cache1.layers[idx]) is MambaLayer: - torch.testing.assert_close(cache1.layers[idx].conv_states, cache2.layers[idx].conv_states) - torch.testing.assert_close(cache1.layers[idx].ssm_states, cache2.layers[idx].ssm_states) - # Attention layer - else: - torch.testing.assert_close(cache1.layers[idx].keys, cache2.layers[idx].keys) - torch.testing.assert_close(cache1.layers[idx].values, cache2.layers[idx].values) - def test_attention_outputs(self): """Lfm2Moe alternates between attention and short-conv layers.""" config, inputs_dict = 
self.model_tester.prepare_config_and_inputs_for_common() diff --git a/tests/models/lfm2_vl/test_modeling_lfm2_vl.py b/tests/models/lfm2_vl/test_modeling_lfm2_vl.py index aa153563e4f8..4ccdf539f453 100644 --- a/tests/models/lfm2_vl/test_modeling_lfm2_vl.py +++ b/tests/models/lfm2_vl/test_modeling_lfm2_vl.py @@ -188,20 +188,6 @@ def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_l else: self.assertEqual(past_key_values.conv_cache[i].shape, conv_shape) - def _check_caches_are_equal(self, cache1: Lfm2HybridConvCache, cache2: Lfm2HybridConvCache): - """Text model uses lfm2, which has non-standard cache""" - if not isinstance(cache1, Lfm2HybridConvCache) or not isinstance(cache2, Lfm2HybridConvCache): - raise ValueError("The wrong cache is being used!") - - if not len(cache1) == len(cache2): - raise ValueError("Both caches do not have the same number of layers.") - - num_layers = len(cache1) - for idx in range(num_layers): - torch.testing.assert_close(cache1.key_cache[idx], cache2.key_cache[idx]) - torch.testing.assert_close(cache1.value_cache[idx], cache2.value_cache[idx]) - torch.testing.assert_close(cache1.conv_cache[idx], cache2.conv_cache[idx]) - def test_config(self): self.config_tester.run_common_tests() diff --git a/tests/models/nemotron_h/test_modeling_nemotron_h.py b/tests/models/nemotron_h/test_modeling_nemotron_h.py index 43a3ea8af9b1..9ea2dffc5f74 100644 --- a/tests/models/nemotron_h/test_modeling_nemotron_h.py +++ b/tests/models/nemotron_h/test_modeling_nemotron_h.py @@ -41,7 +41,6 @@ import torch from transformers import DynamicCache, NemotronHForCausalLM, NemotronHModel - from transformers.cache_utils import MambaLayer class NemotronHModelTester: @@ -380,24 +379,6 @@ def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_l self.assertEqual(past_key_values.layers[idx].keys.shape, attention_shape) self.assertEqual(past_key_values.layers[idx].values.shape, attention_shape) - def _check_caches_are_equal(self, cache1: DynamicCache, cache2: DynamicCache): - if not isinstance(cache1, DynamicCache) or not isinstance(cache2, DynamicCache): - raise ValueError("The wrong cache is being used!") - - if not len(cache1) == len(cache2): - raise ValueError("Both caches do not have the same number of layers.") - - num_layers = len(cache1) - for idx in range(num_layers): - # Mamba layer - if type(cache1.layers[idx]) is MambaLayer: - torch.testing.assert_close(cache1.layers[idx].conv_states, cache2.layers[idx].conv_states) - torch.testing.assert_close(cache1.layers[idx].ssm_states, cache2.layers[idx].ssm_states) - # Attention layer - else: - torch.testing.assert_close(cache1.layers[idx].keys, cache2.layers[idx].keys) - torch.testing.assert_close(cache1.layers[idx].values, cache2.layers[idx].values) - def setUp(self): self.model_tester = NemotronHModelTester(self) self.config_tester = ConfigTester( diff --git a/tests/models/zamba/test_modeling_zamba.py b/tests/models/zamba/test_modeling_zamba.py index 4ce37dd638a7..e59205417458 100644 --- a/tests/models/zamba/test_modeling_zamba.py +++ b/tests/models/zamba/test_modeling_zamba.py @@ -39,7 +39,6 @@ import torch from transformers import DynamicCache, ZambaForCausalLM, ZambaForSequenceClassification, ZambaModel - from transformers.cache_utils import MambaLayer class ZambaModelTester: @@ -313,26 +312,6 @@ def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_l self.assertEqual(past_key_values.layers[idx].keys.shape, attention_shape) 
self.assertEqual(past_key_values.layers[idx].values.shape, attention_shape) - def _check_caches_are_equal(self, cache1: DynamicCache, cache2: DynamicCache): - if not isinstance(cache1, DynamicCache) or not isinstance(cache2, DynamicCache): - raise ValueError("The wrong cache is being used!") - - if not len(cache1) == len(cache2): - raise ValueError("Both caches do not have the same number of layers.") - - num_layers = len(cache1) - for idx in range(num_layers): - # Mamba layer - if type(cache1.layers[idx]) is MambaLayer: - torch.testing.assert_close(cache1.layers[idx].conv_states, cache2.layers[idx].conv_states) - torch.testing.assert_close(cache1.layers[idx].ssm_states, cache2.layers[idx].ssm_states) - # Hybrid mamba + attention layer - else: - torch.testing.assert_close(cache1.layers[idx].keys, cache2.layers[idx].keys) - torch.testing.assert_close(cache1.layers[idx].values, cache2.layers[idx].values) - torch.testing.assert_close(cache1.layers[idx].conv_states, cache2.layers[idx].conv_states) - torch.testing.assert_close(cache1.layers[idx].ssm_states, cache2.layers[idx].ssm_states) - def setUp(self): self.model_tester = ZambaModelTester(self) self.config_tester = ConfigTester(self, config_class=ZambaConfig, hidden_size=32) diff --git a/tests/models/zamba2/test_modeling_zamba2.py b/tests/models/zamba2/test_modeling_zamba2.py index 0760c5e60789..545f94e97c02 100644 --- a/tests/models/zamba2/test_modeling_zamba2.py +++ b/tests/models/zamba2/test_modeling_zamba2.py @@ -40,7 +40,6 @@ import torch from transformers import DynamicCache, Zamba2ForCausalLM, Zamba2ForSequenceClassification, Zamba2Model - from transformers.cache_utils import MambaLayer class Zamba2ModelTester: @@ -327,26 +326,6 @@ def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_l self.assertEqual(past_key_values.layers[idx].keys.shape, attention_shape) self.assertEqual(past_key_values.layers[idx].values.shape, attention_shape) - def _check_caches_are_equal(self, cache1: DynamicCache, cache2: DynamicCache): - if not isinstance(cache1, DynamicCache) or not isinstance(cache2, DynamicCache): - raise ValueError("The wrong cache is being used!") - - if not len(cache1) == len(cache2): - raise ValueError("Both caches do not have the same number of layers.") - - num_layers = len(cache1) - for idx in range(num_layers): - # Mamba layer - if type(cache1.layers[idx]) is MambaLayer: - torch.testing.assert_close(cache1.layers[idx].conv_states, cache2.layers[idx].conv_states) - torch.testing.assert_close(cache1.layers[idx].ssm_states, cache2.layers[idx].ssm_states) - # Hybrid mamba + attention layer - else: - torch.testing.assert_close(cache1.layers[idx].keys, cache2.layers[idx].keys) - torch.testing.assert_close(cache1.layers[idx].values, cache2.layers[idx].values) - torch.testing.assert_close(cache1.layers[idx].conv_states, cache2.layers[idx].conv_states) - torch.testing.assert_close(cache1.layers[idx].ssm_states, cache2.layers[idx].ssm_states) - def setUp(self): self.model_tester = Zamba2ModelTester(self) self.config_tester = ConfigTester(self, config_class=Zamba2Config, hidden_size=32) From 0e8ca28529438f77e559925032152b4f7d991ba8 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 24 Mar 2026 16:29:59 +0100 Subject: [PATCH 25/56] simplify tests further --- tests/generation/test_utils.py | 33 +++++++++++++++---- tests/models/bamba/test_modeling_bamba.py | 23 +++---------- .../falcon_h1/test_modeling_falcon_h1.py | 20 ++--------- .../test_modeling_falcon_mamba.py | 12 ------- 
.../test_modeling_granitemoehybrid.py | 22 ++----------- tests/models/jamba/test_modeling_jamba.py | 21 ++---------- tests/models/lfm2/test_modeling_lfm2.py | 20 +++-------- .../models/lfm2_moe/test_modeling_lfm2_moe.py | 20 +++-------- tests/models/lfm2_vl/test_modeling_lfm2_vl.py | 18 ++-------- tests/models/mamba/test_modeling_mamba.py | 12 ------- tests/models/mamba2/test_modeling_mamba2.py | 9 ++--- .../nemotron_h/test_modeling_nemotron_h.py | 21 ++---------- tests/models/zamba/test_modeling_zamba.py | 24 ++------------ tests/models/zamba2/test_modeling_zamba2.py | 24 ++------------ 14 files changed, 61 insertions(+), 218 deletions(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 149ef8911176..b086117b0549 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -2537,6 +2537,12 @@ def _check_encoder_hidden_states_for_generate(self, hidden_states, batch_size, c [encoder_expected_shape] * len(hidden_states), ) + def _get_mamba_cache_shapes(batch_size: int, config): + # Default mamba cache shape - can vary based on models so this function is convenient to easily check caches + conv_shape = (batch_size, config.intermediate_size, config.conv_kernel) + ssm_shape = (batch_size, config.intermediate_size, config.state_size) + return conv_shape, ssm_shape + def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config): # Raise a useful error, asking to explicitly override the method if not isinstance(past_key_values, Cache): @@ -2560,12 +2566,15 @@ def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_l head_dim = getattr(config, "head_dim", hidden_size // config.num_attention_heads) # For cross attention cache, the seq_length depends on the model, so we remove that dim - expected_shape = ( + attention_shape = ( (batch_size, num_heads, seq_length, head_dim) if seq_length is not None else (batch_size, num_heads, head_dim) ) + # For mamba layers + conv_shape, ssm_shape = self._get_mamba_cache_shapes(batch_size, config) + # Check the size is coherent num_hidden_layers = config.num_hidden_layers if getattr(config, "num_kv_shared_layers", None) is not None: @@ -2574,11 +2583,23 @@ def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_l # Check each layer has the correct shape for layer in past_key_values.layers: - # Remove the seq_length dim for cross-attention cache (it changes based on the model) - keys = layer.keys if seq_length is not None else layer.keys[:, :, 0, :] - values = layer.values if seq_length is not None else layer.values[:, :, 0, :] - self.assertEqual(keys.shape, expected_shape) - self.assertEqual(values.shape, expected_shape) + if isinstance(layer, MambaLayer): + self.assertEqual(layer.conv_states.shape, conv_shape) + self.assertEqual(layer.ssm_states.shape, ssm_shape) + elif isinstance(layer, MambaAndAttentionLayer): + # Remove the seq_length dim for cross-attention cache (it changes based on the model) + keys = layer.keys if seq_length is not None else layer.keys[:, :, 0, :] + values = layer.values if seq_length is not None else layer.values[:, :, 0, :] + self.assertEqual(keys.shape, attention_shape) + self.assertEqual(values.shape, attention_shape) + self.assertEqual(layer.conv_states.shape, conv_shape) + self.assertEqual(layer.ssm_states.shape, ssm_shape) + else: + # Remove the seq_length dim for cross-attention cache (it changes based on the model) + keys = layer.keys if seq_length is not None else layer.keys[:, :, 0, :] + values = 
layer.values if seq_length is not None else layer.values[:, :, 0, :] + self.assertEqual(keys.shape, attention_shape) + self.assertEqual(values.shape, attention_shape) def _check_sequence_inside_sequence(self, tensor_1, tensor_2): # check if tensor_1 inside tensor_2 or tensor_2 inside tensor_1. diff --git a/tests/models/bamba/test_modeling_bamba.py b/tests/models/bamba/test_modeling_bamba.py index 6b9eb77dd64a..9ca8127b5616 100644 --- a/tests/models/bamba/test_modeling_bamba.py +++ b/tests/models/bamba/test_modeling_bamba.py @@ -47,7 +47,7 @@ if is_torch_available(): import torch - from transformers import BambaForCausalLM, BambaModel, DynamicCache + from transformers import BambaForCausalLM, BambaModel class BambaModelTester: @@ -279,30 +279,15 @@ class BambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi # This is because we are hitting edge cases with the causal_mask buffer model_split_percents = [0.5, 0.7, 0.8] - def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config): - self.assertIsInstance(past_key_values, DynamicCache) - - # (batch, kv heads, seq_length, head_dim) - num_heads = getattr(config, "num_key_value_heads", config.num_attention_heads) - head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) - attention_shape = (batch_size, num_heads, seq_length, head_dim) - + def _get_mamba_cache_shapes(batch_size: int, config): + # For mamba layers conv_shape = ( batch_size, config.mamba_expand * config.hidden_size + 2 * config.mamba_n_groups * config.mamba_d_state, config.mamba_d_conv, ) ssm_shape = (batch_size, config.mamba_n_heads, config.mamba_d_head, config.mamba_d_state) - - self.assertTrue(config.num_hidden_layers, len(past_key_values)) - - for idx in range(len(past_key_values)): - if config.layers_block_type[idx] == "mamba": - self.assertEqual(past_key_values.layers[idx].conv_states.shape, conv_shape) - self.assertEqual(past_key_values.layers[idx].ssm_states.shape, ssm_shape) - else: - self.assertEqual(past_key_values.layers[idx].keys.shape, attention_shape) - self.assertEqual(past_key_values.layers[idx].values.shape, attention_shape) + return conv_shape, ssm_shape def setUp(self): self.model_tester = self.model_tester_class(self) diff --git a/tests/models/falcon_h1/test_modeling_falcon_h1.py b/tests/models/falcon_h1/test_modeling_falcon_h1.py index 0cb54ff547f9..a562ff404c59 100644 --- a/tests/models/falcon_h1/test_modeling_falcon_h1.py +++ b/tests/models/falcon_h1/test_modeling_falcon_h1.py @@ -36,7 +36,7 @@ if is_torch_available(): import torch - from transformers import AutoTokenizer, DynamicCache, FalconH1ForCausalLM, FalconH1Model + from transformers import AutoTokenizer, FalconH1ForCausalLM, FalconH1Model class FalconH1ModelTester: @@ -253,14 +253,7 @@ class FalconH1ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM {"feature-extraction": FalconH1Model, "text-generation": FalconH1ForCausalLM} if is_torch_available() else {} ) - def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config): - self.assertIsInstance(past_key_values, DynamicCache) - - # (batch, kv heads, seq_length, head_dim) - num_heads = getattr(config, "num_key_value_heads", config.num_attention_heads) - head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) - attention_shape = (batch_size, num_heads, seq_length, head_dim) - + def _get_mamba_cache_shapes(batch_size: int, config): conv_kernel_size = config.mamba_d_conv 
intermediate_size = ( config.mamba_d_ssm if config.mamba_d_ssm is not None else int(config.mamba_expand * config.hidden_size) @@ -271,14 +264,7 @@ def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_l conv_kernel_size, ) ssm_shape = (batch_size, config.mamba_n_heads, config.mamba_d_head, config.mamba_d_state) - - for idx in range(len(past_key_values)): - if config.layers_block_type[idx] == "mamba": - self.assertEqual(past_key_values.layers[idx].conv_states.shape, conv_shape) - self.assertEqual(past_key_values.layers[idx].ssm_states.shape, ssm_shape) - else: - self.assertEqual(past_key_values.layers[idx].keys.shape, attention_shape) - self.assertEqual(past_key_values.layers[idx].values.shape, attention_shape) + return conv_shape, ssm_shape def setUp(self): self.model_tester = FalconH1ModelTester(self) diff --git a/tests/models/falcon_mamba/test_modeling_falcon_mamba.py b/tests/models/falcon_mamba/test_modeling_falcon_mamba.py index 8aea217bfbc1..3c674b1302eb 100644 --- a/tests/models/falcon_mamba/test_modeling_falcon_mamba.py +++ b/tests/models/falcon_mamba/test_modeling_falcon_mamba.py @@ -274,18 +274,6 @@ def setUp(self): self, config_class=FalconMambaConfig, n_embd=37, common_properties=["hidden_size", "num_hidden_layers"] ) - def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config): - self.assertIsInstance(past_key_values, DynamicCache) - - conv_shape = (batch_size, config.intermediate_size, config.conv_kernel) - ssm_shape = (batch_size, config.intermediate_size, config.state_size) - - self.assertTrue(config.num_hidden_layers, len(past_key_values)) - - for idx in range(len(past_key_values)): - self.assertEqual(past_key_values.layers[idx].conv_states.shape, conv_shape) - self.assertEqual(past_key_values.layers[idx].ssm_states.shape, ssm_shape) - def assertInterval(self, member, container, msg=None): r""" Simple utility function to check if a member is inside an interval. 
diff --git a/tests/models/granitemoehybrid/test_modeling_granitemoehybrid.py b/tests/models/granitemoehybrid/test_modeling_granitemoehybrid.py index c05e496808c9..41faac093dd4 100644 --- a/tests/models/granitemoehybrid/test_modeling_granitemoehybrid.py +++ b/tests/models/granitemoehybrid/test_modeling_granitemoehybrid.py @@ -46,7 +46,7 @@ if is_torch_available(): import torch - from transformers import DynamicCache, GraniteMoeHybridForCausalLM, GraniteMoeHybridModel + from transformers import GraniteMoeHybridForCausalLM, GraniteMoeHybridModel class GraniteMoeHybridModelTester(BambaModelTester): @@ -302,30 +302,14 @@ def test_flash_attention_2_padding_matches_padding_free_with_position_ids_seq_id loss_padfree = res_padfree.loss torch.testing.assert_close(loss_padded, loss_padfree) - def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config): - self.assertIsInstance(past_key_values, DynamicCache) - - # (batch, kv heads, seq_length, head_dim) - num_heads = getattr(config, "num_key_value_heads", config.num_attention_heads) - head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) - attention_shape = (batch_size, num_heads, seq_length, head_dim) - + def _get_mamba_cache_shapes(batch_size: int, config): conv_shape = ( batch_size, config.mamba_expand * config.hidden_size + 2 * config.mamba_n_groups * config.mamba_d_state, config.mamba_d_conv, ) ssm_shape = (batch_size, config.mamba_n_heads, config.mamba_d_head, config.mamba_d_state) - - self.assertTrue(config.num_hidden_layers, len(past_key_values)) - - for idx in range(len(past_key_values)): - if config.layers_block_type[idx] == "mamba": - self.assertEqual(past_key_values.layers[idx].conv_states.shape, conv_shape) - self.assertEqual(past_key_values.layers[idx].ssm_states.shape, ssm_shape) - else: - self.assertEqual(past_key_values.layers[idx].keys.shape, attention_shape) - self.assertEqual(past_key_values.layers[idx].values.shape, attention_shape) + return conv_shape, ssm_shape def test_config_requires_mamba_or_attention_layers(self): """Ensure we can't create a config with disallowed layers.""" diff --git a/tests/models/jamba/test_modeling_jamba.py b/tests/models/jamba/test_modeling_jamba.py index 2b3853d4b051..3f0b2548cea9 100644 --- a/tests/models/jamba/test_modeling_jamba.py +++ b/tests/models/jamba/test_modeling_jamba.py @@ -43,7 +43,7 @@ if is_torch_available(): import torch - from transformers import DynamicCache, JambaForCausalLM, JambaForSequenceClassification, JambaModel + from transformers import JambaForCausalLM, JambaForSequenceClassification, JambaModel class JambaConfigTester(ConfigTester): @@ -322,25 +322,10 @@ class JambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi else {} ) - def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config): - self.assertIsInstance(past_key_values, DynamicCache) - - # (batch, kv heads, seq_length, head_dim) - num_heads = getattr(config, "num_key_value_heads", config.num_attention_heads) - head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) - attention_shape = (batch_size, num_heads, seq_length, head_dim) + def _get_mamba_cache_shapes(batch_size: int, config): conv_shape = (batch_size, config.mamba_expand * config.hidden_size, config.mamba_d_conv) ssm_shape = (batch_size, config.mamba_expand * config.hidden_size, config.mamba_d_state) - - self.assertTrue(config.num_hidden_layers, len(past_key_values)) - - for idx in 
range(len(past_key_values)): - if config.layers_block_type[idx] == "mamba": - self.assertEqual(past_key_values.layers[idx].conv_states.shape, conv_shape) - self.assertEqual(past_key_values.layers[idx].ssm_states.shape, ssm_shape) - else: - self.assertEqual(past_key_values.layers[idx].keys.shape, attention_shape) - self.assertEqual(past_key_values.layers[idx].values.shape, attention_shape) + return conv_shape, ssm_shape def setUp(self): self.model_tester = JambaModelTester(self) diff --git a/tests/models/lfm2/test_modeling_lfm2.py b/tests/models/lfm2/test_modeling_lfm2.py index 12cc3fdb38f9..ad3425ab4512 100644 --- a/tests/models/lfm2/test_modeling_lfm2.py +++ b/tests/models/lfm2/test_modeling_lfm2.py @@ -29,7 +29,7 @@ if is_torch_available(): import torch - from transformers import DynamicCache, Lfm2ForCausalLM, Lfm2Model + from transformers import Lfm2ForCausalLM, Lfm2Model class Lfm2ModelTester(CausalLMModelTester): @@ -51,22 +51,10 @@ class Lfm2ModelTest(CausalLMModelTest, unittest.TestCase): # used in `test_torch_compile_for_training` _torch_compile_train_cls = Lfm2ForCausalLM if is_torch_available() else None - def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config): - self.assertIsInstance(past_key_values, DynamicCache) - - # (batch, kv heads, seq_length, head_dim) - num_heads = getattr(config, "num_key_value_heads", config.num_attention_heads) - head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) - attention_shape = (batch_size, num_heads, seq_length, head_dim) + def _get_mamba_cache_shapes(batch_size: int, config): conv_shape = (batch_size, config.hidden_size, config.conv_L_cache) - - for idx in range(config.num_hidden_layers): - if config.layer_types[idx] == "full_attention": - self.assertEqual(past_key_values.layers[idx].keys.shape, attention_shape) - self.assertEqual(past_key_values.layers[idx].values.shape, attention_shape) - else: - self.assertEqual(past_key_values.layers[idx].conv_states.shape, conv_shape) - self.assertEqual(past_key_values.layers[idx].ssm_states.shape, (1,)) + ssm_shape = (1,) + return conv_shape, ssm_shape def test_attention_outputs(self): """Lfm2Moe alternates between attention and short-conv layers.""" diff --git a/tests/models/lfm2_moe/test_modeling_lfm2_moe.py b/tests/models/lfm2_moe/test_modeling_lfm2_moe.py index 8b7c2a55004f..ed8a2b8366ee 100644 --- a/tests/models/lfm2_moe/test_modeling_lfm2_moe.py +++ b/tests/models/lfm2_moe/test_modeling_lfm2_moe.py @@ -32,7 +32,7 @@ if is_torch_available(): import torch - from transformers import DynamicCache, Lfm2MoeConfig, Lfm2MoeForCausalLM, Lfm2MoeModel + from transformers import Lfm2MoeConfig, Lfm2MoeForCausalLM, Lfm2MoeModel class Lfm2MoeModelTester(CausalLMModelTester): @@ -69,22 +69,10 @@ class Lfm2MoeModelTest(CausalLMModelTest, unittest.TestCase): # used in `test_torch_compile_for_training` _torch_compile_train_cls = Lfm2MoeForCausalLM if is_torch_available() else None - def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config): - self.assertIsInstance(past_key_values, DynamicCache) - - # (batch, kv heads, seq_length, head_dim) - num_heads = getattr(config, "num_key_value_heads", config.num_attention_heads) - head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) - attention_shape = (batch_size, num_heads, seq_length, head_dim) + def _get_mamba_cache_shapes(batch_size: int, config): conv_shape = (batch_size, config.hidden_size, config.conv_L_cache) - - for idx 
in range(config.num_hidden_layers): - if config.layer_types[idx] == "full_attention": - self.assertEqual(past_key_values.layers[idx].keys.shape, attention_shape) - self.assertEqual(past_key_values.layers[idx].values.shape, attention_shape) - else: - self.assertEqual(past_key_values.layers[idx].conv_states.shape, conv_shape) - self.assertEqual(past_key_values.layers[idx].ssm_states.shape, (1,)) + ssm_shape = (1,) + return conv_shape, ssm_shape def test_attention_outputs(self): """Lfm2Moe alternates between attention and short-conv layers.""" diff --git a/tests/models/lfm2_vl/test_modeling_lfm2_vl.py b/tests/models/lfm2_vl/test_modeling_lfm2_vl.py index 4ccdf539f453..9415b3a55ef7 100644 --- a/tests/models/lfm2_vl/test_modeling_lfm2_vl.py +++ b/tests/models/lfm2_vl/test_modeling_lfm2_vl.py @@ -44,7 +44,6 @@ import torch from transformers import Lfm2VlConfig, Lfm2VlForConditionalGeneration, Lfm2VlModel - from transformers.models.lfm2.modeling_lfm2 import Lfm2HybridConvCache class Lfm2VlModelTester(CausalLMModelTester): @@ -172,21 +171,10 @@ def setUp(self): self, config_class=Lfm2VlConfig, has_text_modality=False, common_properties=common_properties ) - def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config): - self.assertIsInstance(past_key_values, Lfm2HybridConvCache) - - # (batch, kv heads, seq_length, head_dim) - num_heads = getattr(config, "num_key_value_heads", config.num_attention_heads) - head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) - attention_shape = (batch_size, num_heads, seq_length, head_dim) + def _get_mamba_cache_shapes(batch_size: int, config): conv_shape = (batch_size, config.hidden_size, config.conv_L_cache) - - for i in range(config.num_hidden_layers): - if config.layer_types[i] == "full_attention": - self.assertEqual(past_key_values.key_cache[i].shape, attention_shape) - self.assertEqual(past_key_values.value_cache[i].shape, attention_shape) - else: - self.assertEqual(past_key_values.conv_cache[i].shape, conv_shape) + ssm_shape = (1,) + return conv_shape, ssm_shape def test_config(self): self.config_tester.run_common_tests() diff --git a/tests/models/mamba/test_modeling_mamba.py b/tests/models/mamba/test_modeling_mamba.py index f68a17491590..5d56d1e724b2 100644 --- a/tests/models/mamba/test_modeling_mamba.py +++ b/tests/models/mamba/test_modeling_mamba.py @@ -242,18 +242,6 @@ def setUp(self): def test_enable_input_require_grads(self): self.skipTest("Mamba currently requires CUDA/Metal/XPU to run enable_input_require_grads.") - def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config): - self.assertIsInstance(past_key_values, DynamicCache) - - conv_shape = (batch_size, config.intermediate_size, config.conv_kernel) - ssm_shape = (batch_size, config.intermediate_size, config.state_size) - - self.assertTrue(config.num_hidden_layers, len(past_key_values)) - - for idx in range(len(past_key_values)): - self.assertEqual(past_key_values.layers[idx].conv_states.shape, conv_shape) - self.assertEqual(past_key_values.layers[idx].ssm_states.shape, ssm_shape) - def assertInterval(self, member, container, msg=None): r""" Simple utility function to check if a member is inside an interval. 
diff --git a/tests/models/mamba2/test_modeling_mamba2.py b/tests/models/mamba2/test_modeling_mamba2.py index 6aa16ddeecd2..e9704a2bd2dc 100644 --- a/tests/models/mamba2/test_modeling_mamba2.py +++ b/tests/models/mamba2/test_modeling_mamba2.py @@ -244,9 +244,7 @@ def setUp(self): self, config_class=Mamba2Config, n_embd=37, common_properties=["hidden_size", "num_hidden_layers"] ) - def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config): - self.assertIsInstance(past_key_values, DynamicCache) - + def _get_mamba_cache_shapes(batch_size: int, config): intermediate_size = config.expand * config.hidden_size conv_shape = ( batch_size, @@ -254,10 +252,7 @@ def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_l config.conv_kernel, ) ssm_shape = (batch_size, config.num_heads, config.head_dim, config.state_size) - - for idx in range(len(past_key_values)): - self.assertEqual(past_key_values.layers[idx].conv_states.shape, conv_shape) - self.assertEqual(past_key_values.layers[idx].ssm_states.shape, ssm_shape) + return conv_shape, ssm_shape def test_mamba2_caching(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() diff --git a/tests/models/nemotron_h/test_modeling_nemotron_h.py b/tests/models/nemotron_h/test_modeling_nemotron_h.py index 9ea2dffc5f74..f90bf0276fbf 100644 --- a/tests/models/nemotron_h/test_modeling_nemotron_h.py +++ b/tests/models/nemotron_h/test_modeling_nemotron_h.py @@ -352,15 +352,7 @@ class NemotronHModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester else {} ) - def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config): - self.assertIsInstance(past_key_values, DynamicCache) - - # (batch, kv heads, seq_length, head_dim) - num_heads = getattr(config, "num_key_value_heads", config.num_attention_heads) - head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) - attention_shape = (batch_size, num_heads, seq_length, head_dim) - - # Mamba cache shapes + def _get_mamba_cache_shapes(batch_size: int, config): intermediate_size = config.mamba_num_heads * config.mamba_head_dim conv_shape = ( batch_size, @@ -368,16 +360,7 @@ def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_l config.conv_kernel, ) ssm_shape = (batch_size, config.mamba_num_heads, config.mamba_head_dim, config.ssm_state_size) - - self.assertTrue(config.num_hidden_layers, len(past_key_values)) - - for idx in range(len(past_key_values)): - if config.layers_block_type[idx] == "mamba": - self.assertEqual(past_key_values.layers[idx].conv_states.shape, conv_shape) - self.assertEqual(past_key_values.layers[idx].ssm_states.shape, ssm_shape) - else: - self.assertEqual(past_key_values.layers[idx].keys.shape, attention_shape) - self.assertEqual(past_key_values.layers[idx].values.shape, attention_shape) + return conv_shape, ssm_shape def setUp(self): self.model_tester = NemotronHModelTester(self) diff --git a/tests/models/zamba/test_modeling_zamba.py b/tests/models/zamba/test_modeling_zamba.py index e59205417458..34b5eb9d611f 100644 --- a/tests/models/zamba/test_modeling_zamba.py +++ b/tests/models/zamba/test_modeling_zamba.py @@ -38,7 +38,7 @@ if is_torch_available(): import torch - from transformers import DynamicCache, ZambaForCausalLM, ZambaForSequenceClassification, ZambaModel + from transformers import ZambaForCausalLM, ZambaForSequenceClassification, ZambaModel class ZambaModelTester: @@ -288,29 +288,11 @@ class 
ZambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi else {} ) - def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config): - self.assertIsInstance(past_key_values, DynamicCache) - - # (batch, kv heads, seq_length, head_dim) - num_heads = getattr(config, "num_key_value_heads", config.num_attention_heads) - head_dim = getattr(config, "attention_head_dim") - attention_shape = (batch_size, num_heads, seq_length, head_dim) - + def _get_mamba_cache_shapes(batch_size: int, config): intermediate_size = config.mamba_expand * config.hidden_size conv_shape = (batch_size, intermediate_size, config.mamba_d_conv) ssm_shape = (batch_size, config.n_mamba_heads, intermediate_size // config.n_mamba_heads, config.mamba_d_state) - - self.assertTrue(config.num_hidden_layers, len(past_key_values)) - - for idx in range(len(past_key_values)): - if config.layers_block_type[idx] == "mamba": - self.assertEqual(past_key_values.layers[idx].conv_states.shape, conv_shape) - self.assertEqual(past_key_values.layers[idx].ssm_states.shape, ssm_shape) - else: - self.assertEqual(past_key_values.layers[idx].conv_states.shape, conv_shape) - self.assertEqual(past_key_values.layers[idx].ssm_states.shape, ssm_shape) - self.assertEqual(past_key_values.layers[idx].keys.shape, attention_shape) - self.assertEqual(past_key_values.layers[idx].values.shape, attention_shape) + return conv_shape, ssm_shape def setUp(self): self.model_tester = ZambaModelTester(self) diff --git a/tests/models/zamba2/test_modeling_zamba2.py b/tests/models/zamba2/test_modeling_zamba2.py index 545f94e97c02..03530ad7db83 100644 --- a/tests/models/zamba2/test_modeling_zamba2.py +++ b/tests/models/zamba2/test_modeling_zamba2.py @@ -39,7 +39,7 @@ if is_torch_available(): import torch - from transformers import DynamicCache, Zamba2ForCausalLM, Zamba2ForSequenceClassification, Zamba2Model + from transformers import Zamba2ForCausalLM, Zamba2ForSequenceClassification, Zamba2Model class Zamba2ModelTester: @@ -298,14 +298,7 @@ class Zamba2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix else {} ) - def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config): - self.assertIsInstance(past_key_values, DynamicCache) - - # (batch, kv heads, seq_length, head_dim) - num_heads = getattr(config, "num_key_value_heads", config.num_attention_heads) - head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) - attention_shape = (batch_size, num_heads, seq_length, head_dim) - + def _get_mamba_cache_shapes(batch_size: int, config): intermediate_size = config.mamba_expand * config.hidden_size conv_shape = ( batch_size, @@ -313,18 +306,7 @@ def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_l config.mamba_d_conv, ) ssm_shape = (batch_size, config.n_mamba_heads, config.mamba_headdim, config.mamba_d_state) - - self.assertTrue(config.num_hidden_layers, len(past_key_values)) - - for idx in range(len(past_key_values)): - if config.layers_block_type[idx] == "mamba": - self.assertEqual(past_key_values.layers[idx].conv_states.shape, conv_shape) - self.assertEqual(past_key_values.layers[idx].ssm_states.shape, ssm_shape) - else: - self.assertEqual(past_key_values.layers[idx].conv_states.shape, conv_shape) - self.assertEqual(past_key_values.layers[idx].ssm_states.shape, ssm_shape) - self.assertEqual(past_key_values.layers[idx].keys.shape, attention_shape) - self.assertEqual(past_key_values.layers[idx].values.shape, 
attention_shape) + return conv_shape, ssm_shape def setUp(self): self.model_tester = Zamba2ModelTester(self) From c2ddcf9b1e73d1668eb86cf0e1e5a8641f2cd03e Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 24 Mar 2026 16:37:07 +0100 Subject: [PATCH 26/56] fix tests --- tests/generation/test_utils.py | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index b086117b0549..9fab4d6a63dd 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -2583,10 +2583,8 @@ def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_l # Check each layer has the correct shape for layer in past_key_values.layers: - if isinstance(layer, MambaLayer): - self.assertEqual(layer.conv_states.shape, conv_shape) - self.assertEqual(layer.ssm_states.shape, ssm_shape) - elif isinstance(layer, MambaAndAttentionLayer): + # Mamba + Attention layer cache + if type(layer) is MambaAndAttentionLayer: # Remove the seq_length dim for cross-attention cache (it changes based on the model) keys = layer.keys if seq_length is not None else layer.keys[:, :, 0, :] values = layer.values if seq_length is not None else layer.values[:, :, 0, :] @@ -2594,6 +2592,11 @@ def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_l self.assertEqual(values.shape, attention_shape) self.assertEqual(layer.conv_states.shape, conv_shape) self.assertEqual(layer.ssm_states.shape, ssm_shape) + # Mamba only layer cache + elif type(layer) is MambaLayer: + self.assertEqual(layer.conv_states.shape, conv_shape) + self.assertEqual(layer.ssm_states.shape, ssm_shape) + # Attention only layer type else: # Remove the seq_length dim for cross-attention cache (it changes based on the model) keys = layer.keys if seq_length is not None else layer.keys[:, :, 0, :] @@ -2639,17 +2642,18 @@ def _check_caches_are_equal(self, cache1: Cache, cache2: Cache): num_layers = len(cache1) for idx in range(num_layers): - self.assertEqual(type(cache1.layers[idx], cache2.layers[idx])) - # Mamba layer - if isinstance(cache1.layers[idx], MambaLayer): - torch.testing.assert_close(cache1.layers[idx].conv_states, cache2.layers[idx].conv_states) - torch.testing.assert_close(cache1.layers[idx].ssm_states, cache2.layers[idx].ssm_states) + self.assertEqual(type(cache1.layers[idx]), type(cache2.layers[idx])) + # Mamba + Attention layer - elif isinstance(cache1.layers[idx], MambaAndAttentionLayer): + if type(cache1.layers[idx]) is MambaAndAttentionLayer: torch.testing.assert_close(cache1.layers[idx].keys, cache2.layers[idx].keys) torch.testing.assert_close(cache1.layers[idx].values, cache2.layers[idx].values) torch.testing.assert_close(cache1.layers[idx].conv_states, cache2.layers[idx].conv_states) torch.testing.assert_close(cache1.layers[idx].ssm_states, cache2.layers[idx].ssm_states) + # Mamba layer + elif type(cache1.layers[idx]) is MambaLayer: + torch.testing.assert_close(cache1.layers[idx].conv_states, cache2.layers[idx].conv_states) + torch.testing.assert_close(cache1.layers[idx].ssm_states, cache2.layers[idx].ssm_states) # Attention layer else: torch.testing.assert_close(cache1.layers[idx].keys, cache2.layers[idx].keys) From b23708fa3580d85e6b374accc5b5a06319183d8d Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 24 Mar 2026 16:45:40 +0100 Subject: [PATCH 27/56] oupsi --- tests/generation/test_utils.py | 7 ++++--- tests/models/bamba/test_modeling_bamba.py | 2 +- 
tests/models/falcon_h1/test_modeling_falcon_h1.py | 2 +- .../granitemoehybrid/test_modeling_granitemoehybrid.py | 2 +- tests/models/jamba/test_modeling_jamba.py | 2 +- tests/models/lfm2/test_modeling_lfm2.py | 2 +- tests/models/lfm2_moe/test_modeling_lfm2_moe.py | 2 +- tests/models/lfm2_vl/test_modeling_lfm2_vl.py | 2 +- tests/models/mamba2/test_modeling_mamba2.py | 2 +- tests/models/nemotron_h/test_modeling_nemotron_h.py | 2 +- tests/models/zamba/test_modeling_zamba.py | 2 +- tests/models/zamba2/test_modeling_zamba2.py | 2 +- 12 files changed, 15 insertions(+), 14 deletions(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 9fab4d6a63dd..906380997763 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -2537,7 +2537,7 @@ def _check_encoder_hidden_states_for_generate(self, hidden_states, batch_size, c [encoder_expected_shape] * len(hidden_states), ) - def _get_mamba_cache_shapes(batch_size: int, config): + def _get_mamba_cache_shapes(self, batch_size: int, config): # Default mamba cache shape - can vary based on models so this function is convenient to easily check caches conv_shape = (batch_size, config.intermediate_size, config.conv_kernel) ssm_shape = (batch_size, config.intermediate_size, config.state_size) @@ -2561,9 +2561,10 @@ def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_l config = config.get_text_config(decoder=True) # (batch, kv heads, seq_length, head_dim) - num_heads = getattr(config, "num_key_value_heads", config.num_attention_heads) + # Only pure mamba models do not have num_attention_heads defined in config, so it can never be 1 in practice for attention models + num_heads = getattr(config, "num_key_value_heads", getattr(config, "num_attention_heads", 1)) hidden_size = getattr(config, "d_model", config.hidden_size) - head_dim = getattr(config, "head_dim", hidden_size // config.num_attention_heads) + head_dim = getattr(config, "head_dim", hidden_size // num_heads) # For cross attention cache, the seq_length depends on the model, so we remove that dim attention_shape = ( diff --git a/tests/models/bamba/test_modeling_bamba.py b/tests/models/bamba/test_modeling_bamba.py index 9ca8127b5616..bf180d0fedf0 100644 --- a/tests/models/bamba/test_modeling_bamba.py +++ b/tests/models/bamba/test_modeling_bamba.py @@ -279,7 +279,7 @@ class BambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi # This is because we are hitting edge cases with the causal_mask buffer model_split_percents = [0.5, 0.7, 0.8] - def _get_mamba_cache_shapes(batch_size: int, config): + def _get_mamba_cache_shapes(self, batch_size: int, config): # For mamba layers conv_shape = ( batch_size, diff --git a/tests/models/falcon_h1/test_modeling_falcon_h1.py b/tests/models/falcon_h1/test_modeling_falcon_h1.py index a562ff404c59..2537a76b0847 100644 --- a/tests/models/falcon_h1/test_modeling_falcon_h1.py +++ b/tests/models/falcon_h1/test_modeling_falcon_h1.py @@ -253,7 +253,7 @@ class FalconH1ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM {"feature-extraction": FalconH1Model, "text-generation": FalconH1ForCausalLM} if is_torch_available() else {} ) - def _get_mamba_cache_shapes(batch_size: int, config): + def _get_mamba_cache_shapes(self, batch_size: int, config): conv_kernel_size = config.mamba_d_conv intermediate_size = ( config.mamba_d_ssm if config.mamba_d_ssm is not None else int(config.mamba_expand * config.hidden_size) diff --git 
a/tests/models/granitemoehybrid/test_modeling_granitemoehybrid.py b/tests/models/granitemoehybrid/test_modeling_granitemoehybrid.py index 41faac093dd4..83fdc436b9a9 100644 --- a/tests/models/granitemoehybrid/test_modeling_granitemoehybrid.py +++ b/tests/models/granitemoehybrid/test_modeling_granitemoehybrid.py @@ -302,7 +302,7 @@ def test_flash_attention_2_padding_matches_padding_free_with_position_ids_seq_id loss_padfree = res_padfree.loss torch.testing.assert_close(loss_padded, loss_padfree) - def _get_mamba_cache_shapes(batch_size: int, config): + def _get_mamba_cache_shapes(self, batch_size: int, config): conv_shape = ( batch_size, config.mamba_expand * config.hidden_size + 2 * config.mamba_n_groups * config.mamba_d_state, diff --git a/tests/models/jamba/test_modeling_jamba.py b/tests/models/jamba/test_modeling_jamba.py index 3f0b2548cea9..77500ad06251 100644 --- a/tests/models/jamba/test_modeling_jamba.py +++ b/tests/models/jamba/test_modeling_jamba.py @@ -322,7 +322,7 @@ class JambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi else {} ) - def _get_mamba_cache_shapes(batch_size: int, config): + def _get_mamba_cache_shapes(self, batch_size: int, config): conv_shape = (batch_size, config.mamba_expand * config.hidden_size, config.mamba_d_conv) ssm_shape = (batch_size, config.mamba_expand * config.hidden_size, config.mamba_d_state) return conv_shape, ssm_shape diff --git a/tests/models/lfm2/test_modeling_lfm2.py b/tests/models/lfm2/test_modeling_lfm2.py index ad3425ab4512..88d8583a8aec 100644 --- a/tests/models/lfm2/test_modeling_lfm2.py +++ b/tests/models/lfm2/test_modeling_lfm2.py @@ -51,7 +51,7 @@ class Lfm2ModelTest(CausalLMModelTest, unittest.TestCase): # used in `test_torch_compile_for_training` _torch_compile_train_cls = Lfm2ForCausalLM if is_torch_available() else None - def _get_mamba_cache_shapes(batch_size: int, config): + def _get_mamba_cache_shapes(self, batch_size: int, config): conv_shape = (batch_size, config.hidden_size, config.conv_L_cache) ssm_shape = (1,) return conv_shape, ssm_shape diff --git a/tests/models/lfm2_moe/test_modeling_lfm2_moe.py b/tests/models/lfm2_moe/test_modeling_lfm2_moe.py index ed8a2b8366ee..7068d62f01fc 100644 --- a/tests/models/lfm2_moe/test_modeling_lfm2_moe.py +++ b/tests/models/lfm2_moe/test_modeling_lfm2_moe.py @@ -69,7 +69,7 @@ class Lfm2MoeModelTest(CausalLMModelTest, unittest.TestCase): # used in `test_torch_compile_for_training` _torch_compile_train_cls = Lfm2MoeForCausalLM if is_torch_available() else None - def _get_mamba_cache_shapes(batch_size: int, config): + def _get_mamba_cache_shapes(self, batch_size: int, config): conv_shape = (batch_size, config.hidden_size, config.conv_L_cache) ssm_shape = (1,) return conv_shape, ssm_shape diff --git a/tests/models/lfm2_vl/test_modeling_lfm2_vl.py b/tests/models/lfm2_vl/test_modeling_lfm2_vl.py index 9415b3a55ef7..3c3fd4d844dc 100644 --- a/tests/models/lfm2_vl/test_modeling_lfm2_vl.py +++ b/tests/models/lfm2_vl/test_modeling_lfm2_vl.py @@ -171,7 +171,7 @@ def setUp(self): self, config_class=Lfm2VlConfig, has_text_modality=False, common_properties=common_properties ) - def _get_mamba_cache_shapes(batch_size: int, config): + def _get_mamba_cache_shapes(self, batch_size: int, config): conv_shape = (batch_size, config.hidden_size, config.conv_L_cache) ssm_shape = (1,) return conv_shape, ssm_shape diff --git a/tests/models/mamba2/test_modeling_mamba2.py b/tests/models/mamba2/test_modeling_mamba2.py index e9704a2bd2dc..6c0b3b865dd9 100644 --- 
a/tests/models/mamba2/test_modeling_mamba2.py +++ b/tests/models/mamba2/test_modeling_mamba2.py @@ -244,7 +244,7 @@ def setUp(self): self, config_class=Mamba2Config, n_embd=37, common_properties=["hidden_size", "num_hidden_layers"] ) - def _get_mamba_cache_shapes(batch_size: int, config): + def _get_mamba_cache_shapes(self, batch_size: int, config): intermediate_size = config.expand * config.hidden_size conv_shape = ( batch_size, diff --git a/tests/models/nemotron_h/test_modeling_nemotron_h.py b/tests/models/nemotron_h/test_modeling_nemotron_h.py index f90bf0276fbf..7d06790a2016 100644 --- a/tests/models/nemotron_h/test_modeling_nemotron_h.py +++ b/tests/models/nemotron_h/test_modeling_nemotron_h.py @@ -352,7 +352,7 @@ class NemotronHModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester else {} ) - def _get_mamba_cache_shapes(batch_size: int, config): + def _get_mamba_cache_shapes(self, batch_size: int, config): intermediate_size = config.mamba_num_heads * config.mamba_head_dim conv_shape = ( batch_size, diff --git a/tests/models/zamba/test_modeling_zamba.py b/tests/models/zamba/test_modeling_zamba.py index 34b5eb9d611f..899c67252583 100644 --- a/tests/models/zamba/test_modeling_zamba.py +++ b/tests/models/zamba/test_modeling_zamba.py @@ -288,7 +288,7 @@ class ZambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi else {} ) - def _get_mamba_cache_shapes(batch_size: int, config): + def _get_mamba_cache_shapes(self, batch_size: int, config): intermediate_size = config.mamba_expand * config.hidden_size conv_shape = (batch_size, intermediate_size, config.mamba_d_conv) ssm_shape = (batch_size, config.n_mamba_heads, intermediate_size // config.n_mamba_heads, config.mamba_d_state) diff --git a/tests/models/zamba2/test_modeling_zamba2.py b/tests/models/zamba2/test_modeling_zamba2.py index 03530ad7db83..d8d2d2cb41b6 100644 --- a/tests/models/zamba2/test_modeling_zamba2.py +++ b/tests/models/zamba2/test_modeling_zamba2.py @@ -298,7 +298,7 @@ class Zamba2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix else {} ) - def _get_mamba_cache_shapes(batch_size: int, config): + def _get_mamba_cache_shapes(self, batch_size: int, config): intermediate_size = config.mamba_expand * config.hidden_size conv_shape = ( batch_size, From 18feef288e0c8c4c1fb1ed90e62e18dea92af7e9 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 24 Mar 2026 17:00:19 +0100 Subject: [PATCH 28/56] fix --- tests/generation/test_utils.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 906380997763..fbb133896618 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -2539,8 +2539,13 @@ def _check_encoder_hidden_states_for_generate(self, hidden_states, batch_size, c def _get_mamba_cache_shapes(self, batch_size: int, config): # Default mamba cache shape - can vary based on models so this function is convenient to easily check caches - conv_shape = (batch_size, config.intermediate_size, config.conv_kernel) - ssm_shape = (batch_size, config.intermediate_size, config.state_size) + # Note that non-mamba models will NOT have the config fields - it does not matter as they will only use attention cache layers + # so the None default values will not be used + intermediate_size = getattr(config, "intermediate_size", None) + conv_kernel = getattr(config, "conv_kernel", None) + state_size = getattr(config, "state_size", None) + conv_shape = (batch_size, intermediate_size, 
conv_kernel) + ssm_shape = (batch_size, intermediate_size, state_size) return conv_shape, ssm_shape def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config): From ce92f3dc0690e92abcf6f21dcdf8f7096bcb7fb7 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 24 Mar 2026 17:05:21 +0100 Subject: [PATCH 29/56] fix broken no_split_modules --- src/transformers/models/zamba/modeling_zamba.py | 2 +- src/transformers/models/zamba2/modeling_zamba2.py | 2 +- src/transformers/models/zamba2/modular_zamba2.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/zamba/modeling_zamba.py b/src/transformers/models/zamba/modeling_zamba.py index feefa4d7d55c..88174e03b861 100644 --- a/src/transformers/models/zamba/modeling_zamba.py +++ b/src/transformers/models/zamba/modeling_zamba.py @@ -641,7 +641,7 @@ class ZambaPreTrainedModel(PreTrainedModel): config: ZambaConfig base_model_prefix = "model" supports_gradient_checkpointing = True - _no_split_modules = ["ZambaAttentionDecoderLayer", "ZambaMambaDecoderLayer"] + _no_split_modules = ["ZambaHybridLayer", "ZambaMambaDecoderLayer"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn = False _supports_sdpa = False diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index 899bd737d3c5..42dba140204d 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -1065,7 +1065,7 @@ class Zamba2PreTrainedModel(PreTrainedModel): config: Zamba2Config base_model_prefix = "model" supports_gradient_checkpointing = True - _no_split_modules = ["Zamba2AttentionDecoderLayer", "Zamba2MambaDecoderLayer"] + _no_split_modules = ["Zamba2HybridLayer", "Zamba2MambaDecoderLayer"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn = True _supports_flex_attn = True diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py index c6e7af8c3cba..d2d44be47721 100644 --- a/src/transformers/models/zamba2/modular_zamba2.py +++ b/src/transformers/models/zamba2/modular_zamba2.py @@ -805,7 +805,7 @@ class Zamba2PreTrainedModel(PreTrainedModel): config: Zamba2Config base_model_prefix = "model" supports_gradient_checkpointing = True - _no_split_modules = ["Zamba2AttentionDecoderLayer", "Zamba2MambaDecoderLayer"] + _no_split_modules = ["Zamba2HybridLayer", "Zamba2MambaDecoderLayer"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn = True _supports_flex_attn = True From ab4472bf9ae8316019ac09cb77f322f52c626431 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 24 Mar 2026 17:16:12 +0100 Subject: [PATCH 30/56] fix --- tests/generation/test_utils.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index fbb133896618..b23b004c75ea 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -2567,15 +2567,16 @@ def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_l # (batch, kv heads, seq_length, head_dim) # Only pure mamba models do not have num_attention_heads defined in config, so it can never be 1 in practice for attention models - num_heads = getattr(config, "num_key_value_heads", getattr(config, "num_attention_heads", 1)) + num_attention_heads = getattr(config, "num_attention_heads", 1) + num_kv_heads = getattr(config, "num_key_value_heads", 
num_attention_heads) hidden_size = getattr(config, "d_model", config.hidden_size) - head_dim = getattr(config, "head_dim", hidden_size // num_heads) + head_dim = getattr(config, "head_dim", hidden_size // num_attention_heads) # For cross attention cache, the seq_length depends on the model, so we remove that dim attention_shape = ( - (batch_size, num_heads, seq_length, head_dim) + (batch_size, num_kv_heads, seq_length, head_dim) if seq_length is not None - else (batch_size, num_heads, head_dim) + else (batch_size, num_kv_heads, head_dim) ) # For mamba layers From 08e62658a4f9c0865f5697320aa033778ad423c1 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 24 Mar 2026 17:28:37 +0100 Subject: [PATCH 31/56] fixes --- src/transformers/models/bamba/modeling_bamba.py | 2 +- src/transformers/models/bamba/modular_bamba.py | 2 +- src/transformers/models/falcon_h1/configuration_falcon_h1.py | 2 +- .../models/granitemoehybrid/modeling_granitemoehybrid.py | 2 +- tests/models/lfm2/test_modeling_lfm2.py | 2 +- tests/models/lfm2_moe/test_modeling_lfm2_moe.py | 2 +- 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/bamba/modeling_bamba.py b/src/transformers/models/bamba/modeling_bamba.py index d9a8ca68576d..4491f64df39b 100644 --- a/src/transformers/models/bamba/modeling_bamba.py +++ b/src/transformers/models/bamba/modeling_bamba.py @@ -659,7 +659,7 @@ def torch_forward( gate, hidden_states_B_C, dt = projected_states.split( [self.intermediate_size, self.conv_dim, self.num_heads], dim=-1 ) - hidden_states_B_C = hidden_states_B_C.transpsose(1,2) + hidden_states_B_C = hidden_states_B_C.transpose(1,2) use_precomputed_states = cache_params is not None and cache_params.has_previous_state(self.layer_idx) and seq_len == 1 diff --git a/src/transformers/models/bamba/modular_bamba.py b/src/transformers/models/bamba/modular_bamba.py index a4ba833f5565..6af5cb7bd5b5 100644 --- a/src/transformers/models/bamba/modular_bamba.py +++ b/src/transformers/models/bamba/modular_bamba.py @@ -416,7 +416,7 @@ def torch_forward( gate, hidden_states_B_C, dt = projected_states.split( [self.intermediate_size, self.conv_dim, self.num_heads], dim=-1 ) - hidden_states_B_C = hidden_states_B_C.transpsose(1,2) + hidden_states_B_C = hidden_states_B_C.transpose(1,2) use_precomputed_states = cache_params is not None and cache_params.has_previous_state(self.layer_idx) and seq_len == 1 diff --git a/src/transformers/models/falcon_h1/configuration_falcon_h1.py b/src/transformers/models/falcon_h1/configuration_falcon_h1.py index 19401fbbc632..64cd746e0c93 100644 --- a/src/transformers/models/falcon_h1/configuration_falcon_h1.py +++ b/src/transformers/models/falcon_h1/configuration_falcon_h1.py @@ -133,7 +133,7 @@ def validate_architecture(self): @property def layers_block_type(self): - return ["attention" for i in range(self.num_hidden_layers)] + return ["hybrid" for i in range(self.num_hidden_layers)] __all__ = ["FalconH1Config"] diff --git a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py index 63ba0b0a9dcc..ad1b7776dbd9 100644 --- a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +++ b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py @@ -536,7 +536,7 @@ def torch_forward( gate, hidden_states_B_C, dt = projected_states.split( [self.intermediate_size, self.conv_dim, self.num_heads], dim=-1 ) - hidden_states_B_C = hidden_states_B_C.transpsose(1,2) + hidden_states_B_C = 
hidden_states_B_C.transpose(1,2) use_precomputed_states = cache_params is not None and cache_params.has_previous_state(self.layer_idx) and seq_len == 1 diff --git a/tests/models/lfm2/test_modeling_lfm2.py b/tests/models/lfm2/test_modeling_lfm2.py index 88d8583a8aec..a44c303c6c97 100644 --- a/tests/models/lfm2/test_modeling_lfm2.py +++ b/tests/models/lfm2/test_modeling_lfm2.py @@ -53,7 +53,7 @@ class Lfm2ModelTest(CausalLMModelTest, unittest.TestCase): def _get_mamba_cache_shapes(self, batch_size: int, config): conv_shape = (batch_size, config.hidden_size, config.conv_L_cache) - ssm_shape = (1,) + ssm_shape = (0,) return conv_shape, ssm_shape def test_attention_outputs(self): diff --git a/tests/models/lfm2_moe/test_modeling_lfm2_moe.py b/tests/models/lfm2_moe/test_modeling_lfm2_moe.py index 7068d62f01fc..8264cf4f6118 100644 --- a/tests/models/lfm2_moe/test_modeling_lfm2_moe.py +++ b/tests/models/lfm2_moe/test_modeling_lfm2_moe.py @@ -71,7 +71,7 @@ class Lfm2MoeModelTest(CausalLMModelTest, unittest.TestCase): def _get_mamba_cache_shapes(self, batch_size: int, config): conv_shape = (batch_size, config.hidden_size, config.conv_L_cache) - ssm_shape = (1,) + ssm_shape = (0,) return conv_shape, ssm_shape def test_attention_outputs(self): From 66d071646045693af6d6920a051b84a0bbf2503a Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 24 Mar 2026 17:39:26 +0100 Subject: [PATCH 32/56] fix --- src/transformers/cache_utils.py | 3 --- .../models/falcon_mamba/configuration_falcon_mamba.py | 4 ++++ src/transformers/models/falcon_mamba/modular_falcon_mamba.py | 4 ++++ src/transformers/models/mamba/configuration_mamba.py | 4 ++++ src/transformers/models/mamba2/configuration_mamba2.py | 4 ++++ 5 files changed, 16 insertions(+), 3 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 42c1d8acf849..7ef92d2e7ef6 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -1166,15 +1166,12 @@ def __init__( sliding_window = getattr(decoder_config, "sliding_window", None) or getattr( decoder_config, "attention_chunk_size", None ) - conv_kernel = getattr(decoder_config, "conv_kernel", None) layer_types = getattr(decoder_config, "layer_types", None) if layer_types is None: layer_types = [] for _ in range(decoder_config.num_hidden_layers): if sliding_window is not None: layer_types.append("sliding_attention") - elif conv_kernel is not None: - layer_types.append("mamba") else: layer_types.append("full_attention") # Some models have shared layers thus no cache is needed for them (e.g. 
Gemma3n) diff --git a/src/transformers/models/falcon_mamba/configuration_falcon_mamba.py b/src/transformers/models/falcon_mamba/configuration_falcon_mamba.py index 0da599e463ad..ba8f96d3ec44 100644 --- a/src/transformers/models/falcon_mamba/configuration_falcon_mamba.py +++ b/src/transformers/models/falcon_mamba/configuration_falcon_mamba.py @@ -105,5 +105,9 @@ def __post_init__(self, **kwargs): ) super().__post_init__(**kwargs) + @property + def layer_types(self): + return ["mamba"] * self.num_hidden_layers + __all__ = ["FalconMambaConfig"] diff --git a/src/transformers/models/falcon_mamba/modular_falcon_mamba.py b/src/transformers/models/falcon_mamba/modular_falcon_mamba.py index b53b85465cc2..fb90a60a4a8f 100644 --- a/src/transformers/models/falcon_mamba/modular_falcon_mamba.py +++ b/src/transformers/models/falcon_mamba/modular_falcon_mamba.py @@ -101,6 +101,10 @@ class FalconMambaConfig(MambaConfig): use_associative_scan: bool = True mixer_rms_eps: float = 1e-6 + @property + def layer_types(self): + return ["mamba"] * self.num_hidden_layers + def rms_forward(hidden_states, variance_epsilon=1e-6): """ diff --git a/src/transformers/models/mamba/configuration_mamba.py b/src/transformers/models/mamba/configuration_mamba.py index 61bb91416b0f..8b3e87d32ae1 100644 --- a/src/transformers/models/mamba/configuration_mamba.py +++ b/src/transformers/models/mamba/configuration_mamba.py @@ -98,5 +98,9 @@ def __post_init__(self, **kwargs): ) super().__post_init__(**kwargs) + @property + def layer_types(self): + return ["mamba"] * self.num_hidden_layers + __all__ = ["MambaConfig"] diff --git a/src/transformers/models/mamba2/configuration_mamba2.py b/src/transformers/models/mamba2/configuration_mamba2.py index d692c1b90c9b..7e513c6a9402 100644 --- a/src/transformers/models/mamba2/configuration_mamba2.py +++ b/src/transformers/models/mamba2/configuration_mamba2.py @@ -103,5 +103,9 @@ def validate_architecture(self): f"({self.num_heads * self.head_dim})." 
) + @property + def layer_types(self): + return ["mamba"] * self.num_hidden_layers + __all__ = ["Mamba2Config"] From c86f9bb896eab8e8ef29bc783df6637f1791f0fe Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 24 Mar 2026 18:27:46 +0100 Subject: [PATCH 33/56] fix --- src/transformers/models/falcon_h1/modeling_falcon_h1.py | 2 +- src/transformers/models/falcon_h1/modular_falcon_h1.py | 2 +- src/transformers/models/nemotron_h/modeling_nemotron_h.py | 4 ++-- src/transformers/models/zamba/configuration_zamba.py | 2 +- src/transformers/models/zamba2/configuration_zamba2.py | 2 +- src/transformers/models/zamba2/modeling_zamba2.py | 4 ++-- src/transformers/models/zamba2/modular_zamba2.py | 4 ++-- 7 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/transformers/models/falcon_h1/modeling_falcon_h1.py b/src/transformers/models/falcon_h1/modeling_falcon_h1.py index 2911f76a15cb..a15d59717c5d 100644 --- a/src/transformers/models/falcon_h1/modeling_falcon_h1.py +++ b/src/transformers/models/falcon_h1/modeling_falcon_h1.py @@ -843,7 +843,7 @@ def torch_forward( # Init cache if ssm_state is not None and cache_params is not None: - ssm_state = cache_params.update_ssm_states(ssm_state, self.layer_idx) + ssm_state = cache_params.update_ssm_state(ssm_state, self.layer_idx) if self.mamba_rms_norm: scan_output = self.norm(y, gate) diff --git a/src/transformers/models/falcon_h1/modular_falcon_h1.py b/src/transformers/models/falcon_h1/modular_falcon_h1.py index 985e14dda2c4..bef05f97649e 100644 --- a/src/transformers/models/falcon_h1/modular_falcon_h1.py +++ b/src/transformers/models/falcon_h1/modular_falcon_h1.py @@ -630,7 +630,7 @@ def torch_forward( # Init cache if ssm_state is not None and cache_params is not None: - ssm_state = cache_params.update_ssm_states(ssm_state, self.layer_idx) + ssm_state = cache_params.update_ssm_state(ssm_state, self.layer_idx) if self.mamba_rms_norm: scan_output = self.norm(y, gate) diff --git a/src/transformers/models/nemotron_h/modeling_nemotron_h.py b/src/transformers/models/nemotron_h/modeling_nemotron_h.py index 87c456cf4ecf..aa8e2aee597d 100644 --- a/src/transformers/models/nemotron_h/modeling_nemotron_h.py +++ b/src/transformers/models/nemotron_h/modeling_nemotron_h.py @@ -369,7 +369,7 @@ def torch_forward(self, input_states, cache_params: Cache | None=None, attention dtype = input_states.dtype # Gated MLP's linear projection if cache_params is not None and cache_params.has_previous_state(self.layer_idx): - projected_states = self.in_proj(input_states.squeeze(1)) + projected_states = self.in_proj(input_states) else: if attention_mask is not None: # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66 @@ -387,7 +387,7 @@ def torch_forward(self, input_states, cache_params: Cache | None=None, attention if use_precomputed_state: gate = gate.unsqueeze(1) conv_state = cache_params.update_conv_state(hidden_states, self.layer_idx) - hidden_states = torch.sum(conv_state.to(projected_states.device) * self.conv1d.weight[:, 0, :], dim=-1) + hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1) if self.use_conv_bias: hidden_states += self.conv1d.bias hidden_states = self.act(hidden_states).to(dtype)[:, None, ...] 
# [batch, 1, intermediate_size] : decoding diff --git a/src/transformers/models/zamba/configuration_zamba.py b/src/transformers/models/zamba/configuration_zamba.py index 2fd99acf2207..588d2efc2cc3 100644 --- a/src/transformers/models/zamba/configuration_zamba.py +++ b/src/transformers/models/zamba/configuration_zamba.py @@ -53,7 +53,7 @@ class ZambaConfig(PreTrainedConfig): model_type = "zamba" keys_to_ignore_at_inference = ["past_key_values"] - attribute_map = {"layer_types": "layers_block_type"} + attribute_map = {"layer_types": "layers_block_type", "head_dim": "attention_head_dim"} vocab_size: int = 32000 tie_word_embeddings: bool = True diff --git a/src/transformers/models/zamba2/configuration_zamba2.py b/src/transformers/models/zamba2/configuration_zamba2.py index f9b56fabf6d4..d77491bd984a 100644 --- a/src/transformers/models/zamba2/configuration_zamba2.py +++ b/src/transformers/models/zamba2/configuration_zamba2.py @@ -66,7 +66,7 @@ class Zamba2Config(PreTrainedConfig): ```""" model_type = "zamba2" - attribute_map = {"head_dim": "attention_head_dim", "layer_types": "layers_block_type"} + attribute_map = {"layer_types": "layers_block_type", "head_dim": "attention_head_dim"} keys_to_ignore_at_inference = ["past_key_values"] vocab_size: int = 32000 diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index 42dba140204d..576dfc71d016 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -657,7 +657,7 @@ def torch_forward(self, input_states, cache_params: Cache | None=None, attention dtype = input_states.dtype # Gated MLP's linear projection if cache_params is not None and cache_params.has_previous_state(self.layer_idx): - projected_states = self.in_proj(input_states.squeeze(1)) + projected_states = self.in_proj(input_states) else: if attention_mask is not None: # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66 @@ -675,7 +675,7 @@ def torch_forward(self, input_states, cache_params: Cache | None=None, attention if use_precomputed_state: gate = gate.unsqueeze(1) conv_state = cache_params.update_conv_state(hidden_states, self.layer_idx) - hidden_states = torch.sum(conv_state.to(projected_states.device) * self.conv1d.weight[:, 0, :], dim=-1) + hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1) if self.use_conv_bias: hidden_states += self.conv1d.bias hidden_states = self.act(hidden_states).to(dtype)[:, None, ...] 
# [batch, 1, intermediate_size] : decoding diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py index d2d44be47721..e61455c4fc5b 100644 --- a/src/transformers/models/zamba2/modular_zamba2.py +++ b/src/transformers/models/zamba2/modular_zamba2.py @@ -445,7 +445,7 @@ def torch_forward(self, input_states, cache_params: Cache | None=None, attention dtype = input_states.dtype # Gated MLP's linear projection if cache_params is not None and cache_params.has_previous_state(self.layer_idx): - projected_states = self.in_proj(input_states.squeeze(1)) + projected_states = self.in_proj(input_states) else: if attention_mask is not None: # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66 @@ -463,7 +463,7 @@ def torch_forward(self, input_states, cache_params: Cache | None=None, attention if use_precomputed_state: gate = gate.unsqueeze(1) conv_state = cache_params.update_conv_state(hidden_states, self.layer_idx) - hidden_states = torch.sum(conv_state.to(projected_states.device) * self.conv1d.weight[:, 0, :], dim=-1) + hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1) if self.use_conv_bias: hidden_states += self.conv1d.bias hidden_states = self.act(hidden_states).to(dtype)[:, None, ...] # [batch, 1, intermediate_size] : decoding From ba1b7d688f1a0a2338c82318c5af057c5d78d719 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 24 Mar 2026 18:54:09 +0100 Subject: [PATCH 34/56] fixes --- src/transformers/cache_utils.py | 22 ++++++++++++++----- .../models/nemotron_h/modeling_nemotron_h.py | 1 - .../models/zamba2/modeling_zamba2.py | 1 - .../models/zamba2/modular_zamba2.py | 1 - 4 files changed, 17 insertions(+), 8 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 7ef92d2e7ef6..d6c5c5ad0f3a 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -715,7 +715,13 @@ def reorder_cache(self, beam_idx: torch.LongTensor): """Reorders the cache for beam search, given the selected beam indices.""" if self.has_previous_state: self.conv_states = self.conv_states.index_select(0, beam_idx.to(self.device)) - self.ssm_states = self.ssm_states.index_select(0, beam_idx.to(self.device)) + # ssm_states can stay empty sometimes, see e.g. 
lfm2 which only uses the conv_states + if self.ssm_states.numel() > 0: + self.ssm_states = self.ssm_states.index_select(0, beam_idx.to(self.device)) + + def crop(self, max_length: int): + # We don't crop the mamba cache, so simply do nothing + pass class MambaLayer(MambaCacheLayerMixin): @@ -777,10 +783,16 @@ def __init__(self): DynamicLayer.__init__(self) MambaLayer.__init__(self) - def lazy_initialization(self, states_1: torch.Tensor, states_2: torch.Tensor | None = None) -> None: - MambaLayer.lazy_initialization(self, states_1) - self.keys = torch.tensor([], dtype=self.dtype, device=self.device) - self.values = torch.tensor([], dtype=self.dtype, device=self.device) + def lazy_initialization(self, states1: torch.Tensor, states2: torch.Tensor | None = None) -> None: + MambaLayer.lazy_initialization(self, states1) + DynamicLayer.lazy_initialization(self, states1, states2) + + def update_conv_state(self, conv_states: torch.Tensor, **kwargs) -> torch.Tensor: + # We need this as lazy initialization may be called first from `update` from the attention part, so the inferred + # conv_kernel_size may not be correct - make sure we grab the correct one during the first call to `update_conv_state` + if not self.has_previous_state: + self.conv_kernel_size = conv_states.shape[-1] + super().update_conv_state(conv_states, **kwargs) def reset(self) -> None: MambaLayer.reset(self) diff --git a/src/transformers/models/nemotron_h/modeling_nemotron_h.py b/src/transformers/models/nemotron_h/modeling_nemotron_h.py index aa8e2aee597d..6aa8fa6a7332 100644 --- a/src/transformers/models/nemotron_h/modeling_nemotron_h.py +++ b/src/transformers/models/nemotron_h/modeling_nemotron_h.py @@ -385,7 +385,6 @@ def torch_forward(self, input_states, cache_params: Cache | None=None, attention # Convolution sequence transformation if use_precomputed_state: - gate = gate.unsqueeze(1) conv_state = cache_params.update_conv_state(hidden_states, self.layer_idx) hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1) if self.use_conv_bias: diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index 576dfc71d016..c1c8ebca01f4 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -673,7 +673,6 @@ def torch_forward(self, input_states, cache_params: Cache | None=None, attention # Convolution sequence transformation if use_precomputed_state: - gate = gate.unsqueeze(1) conv_state = cache_params.update_conv_state(hidden_states, self.layer_idx) hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1) if self.use_conv_bias: diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py index e61455c4fc5b..fd4f25a5196d 100644 --- a/src/transformers/models/zamba2/modular_zamba2.py +++ b/src/transformers/models/zamba2/modular_zamba2.py @@ -461,7 +461,6 @@ def torch_forward(self, input_states, cache_params: Cache | None=None, attention # Convolution sequence transformation if use_precomputed_state: - gate = gate.unsqueeze(1) conv_state = cache_params.update_conv_state(hidden_states, self.layer_idx) hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1) if self.use_conv_bias: From 17856217ad23d3f9f2b839c8e940beb4a2edf1b9 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 24 Mar 2026 18:57:36 +0100 Subject: [PATCH 35/56] add layer type --- src/transformers/configuration_utils.py | 1 + 1 file changed, 1 
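Note on the decode path touched above: once the conv window is cached, the single-token convolution reduces to a per-channel dot product with the depthwise kernel. A minimal, self-contained sketch (illustrative sizes only, not taken from the patch) checking that `torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1)` matches a depthwise Conv1d evaluated on the cached window:

import torch
import torch.nn as nn
import torch.nn.functional as F

batch, channels, kernel = 2, 8, 4                      # illustrative sizes
conv1d = nn.Conv1d(channels, channels, kernel_size=kernel, groups=channels, bias=True)
conv_state = torch.randn(batch, channels, kernel)      # cached window: (batch, channels, kernel)

# Shortcut used in the mixers above: per-channel dot product over the cached window.
decode_out = torch.sum(conv_state * conv1d.weight[:, 0, :], dim=-1) + conv1d.bias

# Reference: run the depthwise convolution over the same window (output length 1).
reference = F.conv1d(conv_state, conv1d.weight, conv1d.bias, groups=channels)[..., 0]
torch.testing.assert_close(decode_out, reference)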
insertion(+) diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index 4d7034f9b1d3..56c519a6e92b 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -69,6 +69,7 @@ "sparse", "dense", "hybrid", # for layers that have both mamba and attention in zamba and zamba2 + "moe", # for nemotron_h, which uses either attention, mamba or moe ) From f684133e0c52fec126edd8c8379c900bde43bfc1 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 24 Mar 2026 19:12:46 +0100 Subject: [PATCH 36/56] oupsi --- src/transformers/cache_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index d6c5c5ad0f3a..4e8ca7bf1401 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -792,7 +792,7 @@ def update_conv_state(self, conv_states: torch.Tensor, **kwargs) -> torch.Tensor # conv_kernel_size may not be correct - make sure we grab the correct one during the first call to `update_conv_state` if not self.has_previous_state: self.conv_kernel_size = conv_states.shape[-1] - super().update_conv_state(conv_states, **kwargs) + return super().update_conv_state(conv_states, **kwargs) def reset(self) -> None: MambaLayer.reset(self) From 8ca92a96f511eb4bc5b9d30e2c58565a1506e346 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 24 Mar 2026 19:25:20 +0100 Subject: [PATCH 37/56] fix --- tests/models/lfm2_vl/test_modeling_lfm2_vl.py | 2 +- tests/models/nemotron_h/test_modeling_nemotron_h.py | 5 ----- tests/models/zamba2/test_modeling_zamba2.py | 5 ----- 3 files changed, 1 insertion(+), 11 deletions(-) diff --git a/tests/models/lfm2_vl/test_modeling_lfm2_vl.py b/tests/models/lfm2_vl/test_modeling_lfm2_vl.py index 3c3fd4d844dc..f4ef47a97402 100644 --- a/tests/models/lfm2_vl/test_modeling_lfm2_vl.py +++ b/tests/models/lfm2_vl/test_modeling_lfm2_vl.py @@ -173,7 +173,7 @@ def setUp(self): def _get_mamba_cache_shapes(self, batch_size: int, config): conv_shape = (batch_size, config.hidden_size, config.conv_L_cache) - ssm_shape = (1,) + ssm_shape = (0,) return conv_shape, ssm_shape def test_config(self): diff --git a/tests/models/nemotron_h/test_modeling_nemotron_h.py b/tests/models/nemotron_h/test_modeling_nemotron_h.py index 7d06790a2016..0475d3f061c6 100644 --- a/tests/models/nemotron_h/test_modeling_nemotron_h.py +++ b/tests/models/nemotron_h/test_modeling_nemotron_h.py @@ -540,11 +540,6 @@ def test_flash_attn_2_fp32_ln(self): # with attention mask _ = model(dummy_input, attention_mask=dummy_attention_mask) - @unittest.skip(reason="NemotronH has its own special cache type") - @parameterized.expand([(1, False), (1, True), (4, False)]) - def test_new_cache_format(self, num_beams, do_sample): - pass - @require_torch_accelerator def test_flex_attention_with_grads(self): """ diff --git a/tests/models/zamba2/test_modeling_zamba2.py b/tests/models/zamba2/test_modeling_zamba2.py index d8d2d2cb41b6..18cf3746690d 100644 --- a/tests/models/zamba2/test_modeling_zamba2.py +++ b/tests/models/zamba2/test_modeling_zamba2.py @@ -468,11 +468,6 @@ def test_flash_attn_2_fp32_ln(self): # with attention mask _ = model(dummy_input, attention_mask=dummy_attention_mask) - @unittest.skip(reason="Zamba2 has its own special cache type") - @parameterized.expand([(1, False), (1, True), (4, False)]) - def test_new_cache_format(self, num_beams, do_sample): - pass - @require_torch_accelerator def test_flex_attention_with_grads(self): """ From 
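Related to the `ssm_states.numel() > 0` guard above and the `(0,)` shape now expected by the lfm2_vl test: conv-only layers can leave the ssm state as an empty tensor, and reordering it with beam indices would fail without the guard. A small illustrative check with dummy tensors (sizes are not from the patch):

import torch

beam_idx = torch.tensor([0, 0, 1, 1])
conv_states = torch.randn(2, 8, 4)   # (batch, channels, kernel), dummy values
ssm_states = torch.tensor([])        # conv-only layers keep this empty

conv_states = conv_states.index_select(0, beam_idx)    # ok: becomes (4, 8, 4)
if ssm_states.numel() > 0:                             # skipped here; index_select on an empty dim would raise
    ssm_states = ssm_states.index_select(0, beam_idx)
print(conv_states.shape, ssm_states.shape)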
0d991d74107c72bb607f85e35df7f87e757c25f9 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 24 Mar 2026 19:26:30 +0100 Subject: [PATCH 38/56] style --- tests/models/nemotron_h/test_modeling_nemotron_h.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/models/nemotron_h/test_modeling_nemotron_h.py b/tests/models/nemotron_h/test_modeling_nemotron_h.py index 0475d3f061c6..ffda5868abb5 100644 --- a/tests/models/nemotron_h/test_modeling_nemotron_h.py +++ b/tests/models/nemotron_h/test_modeling_nemotron_h.py @@ -18,7 +18,6 @@ import pytest from huggingface_hub.errors import StrictDataclassClassValidationError -from parameterized import parameterized from transformers import AutoTokenizer, NemotronHConfig, NemotronHForCausalLM, is_torch_available from transformers.testing_utils import ( From 670d09a3578ae1fd2d6bc4d8a8d77308fb0607d8 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 24 Mar 2026 19:41:43 +0100 Subject: [PATCH 39/56] fix --- .../nemotron_h/test_modeling_nemotron_h.py | 37 +++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/tests/models/nemotron_h/test_modeling_nemotron_h.py b/tests/models/nemotron_h/test_modeling_nemotron_h.py index ffda5868abb5..b27f781a3db4 100644 --- a/tests/models/nemotron_h/test_modeling_nemotron_h.py +++ b/tests/models/nemotron_h/test_modeling_nemotron_h.py @@ -361,6 +361,43 @@ def _get_mamba_cache_shapes(self, batch_size: int, config): ssm_shape = (batch_size, config.mamba_num_heads, config.mamba_head_dim, config.ssm_state_size) return conv_shape, ssm_shape + def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config): + # Raise a useful error, asking to explicitly override the method + if not isinstance(past_key_values, DynamicCache): + raise ValueError("The cache does not use the correct Cache") + + # Use the correct config + config = config.get_text_config(decoder=True) + + # (batch, kv heads, seq_length, head_dim) + # Only pure mamba models do not have num_attention_heads defined in config, so it can never be 1 in practice for attention models + num_attention_heads = getattr(config, "num_attention_heads", 1) + num_kv_heads = getattr(config, "num_key_value_heads", num_attention_heads) + hidden_size = getattr(config, "d_model", config.hidden_size) + head_dim = getattr(config, "head_dim", hidden_size // num_attention_heads) + + # For cross attention cache, the seq_length depends on the model, so we remove that dim + attention_shape = (batch_size, num_kv_heads, seq_length, head_dim) + # For mamba layers + conv_shape, ssm_shape = self._get_mamba_cache_shapes(batch_size, config) + + # Check each layer has the correct shape + for layer, layer_type in zip(past_key_values.layers, config.layer_types): + # Moe layers have a default attention cache instantiated, but it stays empty as the layer does not use it + if layer_type == "moe": + self.assertEqual(layer.keys, None) + self.assertEqual(layer.values, None) + # Attention layer cache + elif layer_type == "attention": + self.assertEqual(layer.keys.shape, attention_shape) + self.assertEqual(layer.values.shape, attention_shape) + # Mamba layer cache + elif layer_type == "mamba": + self.assertEqual(layer.conv_states.shape, conv_shape) + self.assertEqual(layer.ssm_states.shape, ssm_shape) + else: + raise ValueError("Unknown layer type.") + def setUp(self): self.model_tester = NemotronHModelTester(self) self.config_tester = ConfigTester( From 63e0b93987bd62e7a7188fbc5b23b744f3451f48 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 24 Mar 2026 20:29:04 
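The next commit switches `has_previous_state` from an attribute access to a call at the mask-update sites, while the mixer forwards elsewhere in the series pass a layer index. The Cache-side helper itself is not part of this excerpt; purely as an assumption, a sketch of a signature compatible with both call styles:

def has_previous_state(self, layer_idx: int | None = None) -> bool:
    # Hypothetical: with no index, report whether any layer already holds cached state (used for
    # the mamba/linear-attention mask shortcut); with an index, report it for that layer only.
    layers = self.layers if layer_idx is None else [self.layers[layer_idx]]
    return any(getattr(layer, "has_previous_state", False) for layer in layers)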
+0100 Subject: [PATCH 40/56] fixes --- src/transformers/cache_utils.py | 6 +++++- src/transformers/models/bamba/modeling_bamba.py | 2 +- src/transformers/models/bamba/modular_bamba.py | 2 +- src/transformers/models/falcon_h1/modeling_falcon_h1.py | 2 +- src/transformers/models/falcon_h1/modular_falcon_h1.py | 2 +- .../models/granitemoehybrid/modeling_granitemoehybrid.py | 2 +- .../models/granitemoehybrid/modular_granitemoehybrid.py | 2 +- src/transformers/models/nemotron_h/modeling_nemotron_h.py | 2 +- src/transformers/models/nemotron_h/modular_nemotron_h.py | 2 +- src/transformers/models/olmo_hybrid/modeling_olmo_hybrid.py | 2 +- src/transformers/models/qwen3_5/modeling_qwen3_5.py | 2 +- src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py | 2 +- src/transformers/models/qwen3_next/modeling_qwen3_next.py | 2 +- src/transformers/models/qwen3_next/modular_qwen3_next.py | 2 +- tests/models/nemotron_h/test_modeling_nemotron_h.py | 6 +++--- 15 files changed, 21 insertions(+), 17 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 4e8ca7bf1401..54cd06b42079 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -1195,7 +1195,11 @@ def __init__( # states they should return - only the mask changes to make them different at the end! if layer_type in ("sliding_attention", "chunked_attention"): layers.append(DynamicSlidingWindowLayer(sliding_window=sliding_window)) - elif layer_type in ("mamba", "conv"): + # Note: we want moe layers to be MambaLayer, so that we can correctly grab sequence length etc from attention layers. + # Since moe layers will stay empty (they don't need any cache), we don't want them to collide for mask creation etc + # TODO: maybe use a dummy layer in those cases, or a dictionary {idx: Layer} for self.layers, so that we can skip + # the indices we don't need + elif layer_type in ("mamba", "conv", "moe"): layers.append(MambaLayer()) elif layer_type == "hybrid": layers.append(MambaAndAttentionLayer()) diff --git a/src/transformers/models/bamba/modeling_bamba.py b/src/transformers/models/bamba/modeling_bamba.py index 4491f64df39b..caf7d08cc56a 100644 --- a/src/transformers/models/bamba/modeling_bamba.py +++ b/src/transformers/models/bamba/modeling_bamba.py @@ -1066,7 +1066,7 @@ def _update_mamba_mask(self, attention_mask, past_key_values): 2. Attending to all inputs """ mamba_mask = attention_mask - if (past_key_values is not None and past_key_values.has_previous_state) or ( + if (past_key_values is not None and past_key_values.has_previous_state()) or ( attention_mask is not None and torch.all(attention_mask == 1) ): mamba_mask = None diff --git a/src/transformers/models/bamba/modular_bamba.py b/src/transformers/models/bamba/modular_bamba.py index 6af5cb7bd5b5..a1a440be3597 100644 --- a/src/transformers/models/bamba/modular_bamba.py +++ b/src/transformers/models/bamba/modular_bamba.py @@ -794,7 +794,7 @@ def _update_mamba_mask(self, attention_mask, past_key_values): 2. 
Attending to all inputs """ mamba_mask = attention_mask - if (past_key_values is not None and past_key_values.has_previous_state) or ( + if (past_key_values is not None and past_key_values.has_previous_state()) or ( attention_mask is not None and torch.all(attention_mask == 1) ): mamba_mask = None diff --git a/src/transformers/models/falcon_h1/modeling_falcon_h1.py b/src/transformers/models/falcon_h1/modeling_falcon_h1.py index a15d59717c5d..b85d02b00b81 100644 --- a/src/transformers/models/falcon_h1/modeling_falcon_h1.py +++ b/src/transformers/models/falcon_h1/modeling_falcon_h1.py @@ -1161,7 +1161,7 @@ def _update_mamba_mask(self, attention_mask, past_key_values): 2. Attending to all inputs """ mamba_mask = attention_mask - if (past_key_values is not None and past_key_values.has_previous_state) or ( + if (past_key_values is not None and past_key_values.has_previous_state()) or ( attention_mask is not None and torch.all(attention_mask == 1) ): mamba_mask = None diff --git a/src/transformers/models/falcon_h1/modular_falcon_h1.py b/src/transformers/models/falcon_h1/modular_falcon_h1.py index bef05f97649e..6408200a081e 100644 --- a/src/transformers/models/falcon_h1/modular_falcon_h1.py +++ b/src/transformers/models/falcon_h1/modular_falcon_h1.py @@ -924,7 +924,7 @@ def _update_mamba_mask(self, attention_mask, past_key_values): 2. Attending to all inputs """ mamba_mask = attention_mask - if (past_key_values is not None and past_key_values.has_previous_state) or ( + if (past_key_values is not None and past_key_values.has_previous_state()) or ( attention_mask is not None and torch.all(attention_mask == 1) ): mamba_mask = None diff --git a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py index ad1b7776dbd9..11dc138d17dc 100644 --- a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +++ b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py @@ -1217,7 +1217,7 @@ def _update_mamba_mask(self, attention_mask, past_key_values): 2. Attending to all inputs """ mamba_mask = attention_mask - if (past_key_values is not None and past_key_values.has_previous_state) or ( + if (past_key_values is not None and past_key_values.has_previous_state()) or ( attention_mask is not None and torch.all(attention_mask == 1) ): mamba_mask = None diff --git a/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py b/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py index ccf8c58bfb51..d0d2329ec202 100644 --- a/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py +++ b/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py @@ -277,7 +277,7 @@ def _update_mamba_mask(self, attention_mask, past_key_values): 2. Attending to all inputs """ mamba_mask = attention_mask - if (past_key_values is not None and past_key_values.has_previous_state) or ( + if (past_key_values is not None and past_key_values.has_previous_state()) or ( attention_mask is not None and torch.all(attention_mask == 1) ): mamba_mask = None diff --git a/src/transformers/models/nemotron_h/modeling_nemotron_h.py b/src/transformers/models/nemotron_h/modeling_nemotron_h.py index 6aa8fa6a7332..fe85ab4b8065 100644 --- a/src/transformers/models/nemotron_h/modeling_nemotron_h.py +++ b/src/transformers/models/nemotron_h/modeling_nemotron_h.py @@ -1103,7 +1103,7 @@ def _update_mamba_mask(self, attention_mask, past_key_values): 2. 
Attending to all inputs """ mamba_mask = attention_mask - if (past_key_values is not None and past_key_values.has_previous_state) or ( + if (past_key_values is not None and past_key_values.has_previous_state()) or ( attention_mask is not None and torch.all(attention_mask == 1) ): mamba_mask = None diff --git a/src/transformers/models/nemotron_h/modular_nemotron_h.py b/src/transformers/models/nemotron_h/modular_nemotron_h.py index 5a849833cc89..05a0b25de2b6 100644 --- a/src/transformers/models/nemotron_h/modular_nemotron_h.py +++ b/src/transformers/models/nemotron_h/modular_nemotron_h.py @@ -465,7 +465,7 @@ def _update_mamba_mask(self, attention_mask, past_key_values): 2. Attending to all inputs """ mamba_mask = attention_mask - if (past_key_values is not None and past_key_values.has_previous_state) or ( + if (past_key_values is not None and past_key_values.has_previous_state()) or ( attention_mask is not None and torch.all(attention_mask == 1) ): mamba_mask = None diff --git a/src/transformers/models/olmo_hybrid/modeling_olmo_hybrid.py b/src/transformers/models/olmo_hybrid/modeling_olmo_hybrid.py index e3b83bff770a..58ac37ced520 100644 --- a/src/transformers/models/olmo_hybrid/modeling_olmo_hybrid.py +++ b/src/transformers/models/olmo_hybrid/modeling_olmo_hybrid.py @@ -1024,7 +1024,7 @@ def _update_linear_attn_mask(self, attention_mask, past_key_values): 2. Attending to all inputs """ linear_attn_mask = attention_mask - if (past_key_values is not None and past_key_values.has_previous_state) or ( + if (past_key_values is not None and past_key_values.has_previous_state()) or ( attention_mask is not None and torch.all(attention_mask == 1) ): linear_attn_mask = None diff --git a/src/transformers/models/qwen3_5/modeling_qwen3_5.py b/src/transformers/models/qwen3_5/modeling_qwen3_5.py index 57589b70b94f..efa661e4373e 100644 --- a/src/transformers/models/qwen3_5/modeling_qwen3_5.py +++ b/src/transformers/models/qwen3_5/modeling_qwen3_5.py @@ -1384,7 +1384,7 @@ def _update_linear_attn_mask(self, attention_mask, past_key_values): 2. Attending to all inputs """ linear_attn_mask = attention_mask - if (past_key_values is not None and past_key_values.has_previous_state) or ( + if (past_key_values is not None and past_key_values.has_previous_state()) or ( attention_mask is not None and torch.all(attention_mask == 1) ): linear_attn_mask = None diff --git a/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py b/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py index 801156d236c3..243d7bee0fe3 100644 --- a/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py +++ b/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py @@ -1509,7 +1509,7 @@ def _update_linear_attn_mask(self, attention_mask, past_key_values): 2. Attending to all inputs """ linear_attn_mask = attention_mask - if (past_key_values is not None and past_key_values.has_previous_state) or ( + if (past_key_values is not None and past_key_values.has_previous_state()) or ( attention_mask is not None and torch.all(attention_mask == 1) ): linear_attn_mask = None diff --git a/src/transformers/models/qwen3_next/modeling_qwen3_next.py b/src/transformers/models/qwen3_next/modeling_qwen3_next.py index 7b45f0ea4838..002d879f8135 100644 --- a/src/transformers/models/qwen3_next/modeling_qwen3_next.py +++ b/src/transformers/models/qwen3_next/modeling_qwen3_next.py @@ -1068,7 +1068,7 @@ def _update_linear_attn_mask(self, attention_mask, past_key_values): 2. 
Attending to all inputs """ linear_attn_mask = attention_mask - if (past_key_values is not None and past_key_values.has_previous_state) or ( + if (past_key_values is not None and past_key_values.has_previous_state()) or ( attention_mask is not None and torch.all(attention_mask == 1) ): linear_attn_mask = None diff --git a/src/transformers/models/qwen3_next/modular_qwen3_next.py b/src/transformers/models/qwen3_next/modular_qwen3_next.py index a22b85bf9278..7dbd15e6d91b 100644 --- a/src/transformers/models/qwen3_next/modular_qwen3_next.py +++ b/src/transformers/models/qwen3_next/modular_qwen3_next.py @@ -824,7 +824,7 @@ def _update_linear_attn_mask(self, attention_mask, past_key_values): 2. Attending to all inputs """ linear_attn_mask = attention_mask - if (past_key_values is not None and past_key_values.has_previous_state) or ( + if (past_key_values is not None and past_key_values.has_previous_state()) or ( attention_mask is not None and torch.all(attention_mask == 1) ): linear_attn_mask = None diff --git a/tests/models/nemotron_h/test_modeling_nemotron_h.py b/tests/models/nemotron_h/test_modeling_nemotron_h.py index b27f781a3db4..1ca4225cb7a9 100644 --- a/tests/models/nemotron_h/test_modeling_nemotron_h.py +++ b/tests/models/nemotron_h/test_modeling_nemotron_h.py @@ -383,10 +383,10 @@ def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_l # Check each layer has the correct shape for layer, layer_type in zip(past_key_values.layers, config.layer_types): - # Moe layers have a default attention cache instantiated, but it stays empty as the layer does not use it + # Moe layers have a default mamba cache instantiated, but it stays empty as the layer does not use it if layer_type == "moe": - self.assertEqual(layer.keys, None) - self.assertEqual(layer.values, None) + self.assertEqual(layer.conv_states, None) + self.assertEqual(layer.ssm_states, None) # Attention layer cache elif layer_type == "attention": self.assertEqual(layer.keys.shape, attention_shape) From eb018e7e519fc776bdc37226c7d54530ef84b64e Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 24 Mar 2026 20:34:09 +0100 Subject: [PATCH 41/56] final fix --- src/transformers/models/bamba/modeling_bamba.py | 7 +------ src/transformers/models/bamba/modular_bamba.py | 7 +------ src/transformers/models/falcon_h1/modeling_falcon_h1.py | 7 +------ src/transformers/models/falcon_h1/modular_falcon_h1.py | 7 +------ .../models/granitemoehybrid/modeling_granitemoehybrid.py | 3 --- .../models/granitemoehybrid/modular_granitemoehybrid.py | 3 --- src/transformers/models/jamba/modeling_jamba.py | 3 --- src/transformers/models/jamba/modular_jamba.py | 3 --- src/transformers/models/nemotron_h/modeling_nemotron_h.py | 3 --- src/transformers/models/nemotron_h/modular_nemotron_h.py | 3 --- src/transformers/models/zamba/modeling_zamba.py | 3 --- src/transformers/models/zamba2/modeling_zamba2.py | 3 --- src/transformers/models/zamba2/modular_zamba2.py | 3 --- 13 files changed, 4 insertions(+), 51 deletions(-) diff --git a/src/transformers/models/bamba/modeling_bamba.py b/src/transformers/models/bamba/modeling_bamba.py index caf7d08cc56a..0b599490c5ff 100644 --- a/src/transformers/models/bamba/modeling_bamba.py +++ b/src/transformers/models/bamba/modeling_bamba.py @@ -1049,14 +1049,9 @@ def forward( hidden_states = self.final_layernorm(hidden_states) - if past_key_values and not past_key_values.has_previous_state: - past_key_values.has_previous_state = True - - next_cache = None if not use_cache else past_key_values - return 
BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, ) def _update_mamba_mask(self, attention_mask, past_key_values): diff --git a/src/transformers/models/bamba/modular_bamba.py b/src/transformers/models/bamba/modular_bamba.py index a1a440be3597..f60d8c405b4d 100644 --- a/src/transformers/models/bamba/modular_bamba.py +++ b/src/transformers/models/bamba/modular_bamba.py @@ -777,14 +777,9 @@ def forward( hidden_states = self.final_layernorm(hidden_states) - if past_key_values and not past_key_values.has_previous_state: - past_key_values.has_previous_state = True - - next_cache = None if not use_cache else past_key_values - return BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, ) def _update_mamba_mask(self, attention_mask, past_key_values): diff --git a/src/transformers/models/falcon_h1/modeling_falcon_h1.py b/src/transformers/models/falcon_h1/modeling_falcon_h1.py index b85d02b00b81..7aaba6aacb4d 100644 --- a/src/transformers/models/falcon_h1/modeling_falcon_h1.py +++ b/src/transformers/models/falcon_h1/modeling_falcon_h1.py @@ -1144,14 +1144,9 @@ def forward( hidden_states = self.final_layernorm(hidden_states) - if past_key_values and not past_key_values.has_previous_state: - past_key_values.has_previous_state = True - - next_cache = None if not use_cache else past_key_values - return BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, ) def _update_mamba_mask(self, attention_mask, past_key_values): diff --git a/src/transformers/models/falcon_h1/modular_falcon_h1.py b/src/transformers/models/falcon_h1/modular_falcon_h1.py index 6408200a081e..61fab3117d96 100644 --- a/src/transformers/models/falcon_h1/modular_falcon_h1.py +++ b/src/transformers/models/falcon_h1/modular_falcon_h1.py @@ -907,14 +907,9 @@ def forward( hidden_states = self.final_layernorm(hidden_states) - if past_key_values and not past_key_values.has_previous_state: - past_key_values.has_previous_state = True - - next_cache = None if not use_cache else past_key_values - return BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, ) def _update_mamba_mask(self, attention_mask, past_key_values): diff --git a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py index 11dc138d17dc..baeaf65bbd45 100644 --- a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +++ b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py @@ -1202,9 +1202,6 @@ def forward( ) hidden_states = self.norm(hidden_states) - if past_key_values and not past_key_values.has_previous_state: - past_key_values.has_previous_state = True - return MoeModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values, diff --git a/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py b/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py index d0d2329ec202..93291275b5b7 100644 --- a/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py +++ b/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py @@ -262,9 +262,6 @@ def forward( ) hidden_states = self.norm(hidden_states) - if past_key_values and not past_key_values.has_previous_state: - past_key_values.has_previous_state = True - return 
MoeModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values, diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py index 753e703ff715..5ffd16302e06 100755 --- a/src/transformers/models/jamba/modeling_jamba.py +++ b/src/transformers/models/jamba/modeling_jamba.py @@ -739,9 +739,6 @@ def forward( hidden_states = self.final_layernorm(hidden_states) - if past_key_values and not past_key_values.has_previous_state: - past_key_values.has_previous_state = True - return MoeModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values, diff --git a/src/transformers/models/jamba/modular_jamba.py b/src/transformers/models/jamba/modular_jamba.py index 2a5b21aad005..e24be7640f40 100644 --- a/src/transformers/models/jamba/modular_jamba.py +++ b/src/transformers/models/jamba/modular_jamba.py @@ -584,9 +584,6 @@ def forward( hidden_states = self.final_layernorm(hidden_states) - if past_key_values and not past_key_values.has_previous_state: - past_key_values.has_previous_state = True - return MoeModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values, diff --git a/src/transformers/models/nemotron_h/modeling_nemotron_h.py b/src/transformers/models/nemotron_h/modeling_nemotron_h.py index fe85ab4b8065..41aba6122448 100644 --- a/src/transformers/models/nemotron_h/modeling_nemotron_h.py +++ b/src/transformers/models/nemotron_h/modeling_nemotron_h.py @@ -1088,9 +1088,6 @@ def forward( hidden_states = self.norm_f(hidden_states) - if past_key_values is not None and not past_key_values.has_previous_state: - past_key_values.has_previous_state = True - return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values if use_cache else None, diff --git a/src/transformers/models/nemotron_h/modular_nemotron_h.py b/src/transformers/models/nemotron_h/modular_nemotron_h.py index 05a0b25de2b6..a7433a982f1c 100644 --- a/src/transformers/models/nemotron_h/modular_nemotron_h.py +++ b/src/transformers/models/nemotron_h/modular_nemotron_h.py @@ -450,9 +450,6 @@ def forward( hidden_states = self.norm_f(hidden_states) - if past_key_values is not None and not past_key_values.has_previous_state: - past_key_values.has_previous_state = True - return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values if use_cache else None, diff --git a/src/transformers/models/zamba/modeling_zamba.py b/src/transformers/models/zamba/modeling_zamba.py index 88174e03b861..405434004ca9 100644 --- a/src/transformers/models/zamba/modeling_zamba.py +++ b/src/transformers/models/zamba/modeling_zamba.py @@ -769,9 +769,6 @@ def forward( hidden_states = self.final_layernorm(hidden_states) - if past_key_values and not past_key_values.has_previous_state: - past_key_values.has_previous_state = True - return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values if use_cache else None, diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index c1c8ebca01f4..4de0979a71c0 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -1189,9 +1189,6 @@ def forward( hidden_states = self.final_layernorm(hidden_states) - if past_key_values is not None and not past_key_values.has_previous_state: - past_key_values.has_previous_state = True - return BaseModelOutputWithPast( last_hidden_state=hidden_states, 
past_key_values=past_key_values if use_cache else None, diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py index fd4f25a5196d..2f858fc2eeb8 100644 --- a/src/transformers/models/zamba2/modular_zamba2.py +++ b/src/transformers/models/zamba2/modular_zamba2.py @@ -962,9 +962,6 @@ def forward( hidden_states = self.final_layernorm(hidden_states) - if past_key_values is not None and not past_key_values.has_previous_state: - past_key_values.has_previous_state = True - return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values if use_cache else None, From f8a070292843da2bb13ce20a4f17c0cd54dc928b Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 24 Mar 2026 22:07:39 +0100 Subject: [PATCH 42/56] forgot those qwens --- src/transformers/cache_utils.py | 2 +- .../olmo_hybrid/modeling_olmo_hybrid.py | 6 +- .../models/olmo_hybrid/modular_olmo_hybrid.py | 62 ++++++++-- .../models/qwen3_5/modeling_qwen3_5.py | 109 ++--------------- .../models/qwen3_5/modular_qwen3_5.py | 25 ++-- .../qwen3_5_moe/modeling_qwen3_5_moe.py | 111 ++--------------- .../models/qwen3_5_moe/modular_qwen3_5_moe.py | 5 - .../models/qwen3_next/modeling_qwen3_next.py | 114 ++---------------- .../models/qwen3_next/modular_qwen3_next.py | 114 ++---------------- tests/models/qwen3_5/test_modeling_qwen3_5.py | 39 ++---- .../qwen3_5_moe/test_modeling_qwen3_5_moe.py | 39 ++---- .../qwen3_next/test_modeling_qwen3_next.py | 39 ++---- 12 files changed, 146 insertions(+), 519 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 54cd06b42079..1c4d7ed22513 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -1199,7 +1199,7 @@ def __init__( # Since moe layers will stay empty (they don't need any cache), we don't want them to collide for mask creation etc # TODO: maybe use a dummy layer in those cases, or a dictionary {idx: Layer} for self.layers, so that we can skip # the indices we don't need - elif layer_type in ("mamba", "conv", "moe"): + elif layer_type in ("mamba", "conv", "linear_attention", "moe"): layers.append(MambaLayer()) elif layer_type == "hybrid": layers.append(MambaAndAttentionLayer()) diff --git a/src/transformers/models/olmo_hybrid/modeling_olmo_hybrid.py b/src/transformers/models/olmo_hybrid/modeling_olmo_hybrid.py index 58ac37ced520..bf72e06d94eb 100644 --- a/src/transformers/models/olmo_hybrid/modeling_olmo_hybrid.py +++ b/src/transformers/models/olmo_hybrid/modeling_olmo_hybrid.py @@ -61,8 +61,7 @@ class OlmoHybridDynamicCache: """ Cache for hybrid model supporting both attention KV cache and linear attention state. - Inherits from Qwen3NextDynamicCache. The main difference is that this cache - stores separate conv states for q, k, v (instead of a single conv_states list). + The main difference is that this cache stores separate conv states for q, k, v (instead of a single conv_states). 
""" is_compileable = False @@ -155,7 +154,6 @@ def get_mask_sizes(self, query_length: int, layer_idx: int) -> tuple[int, int]: kv_length = query_length + past_seen_tokens return kv_length, kv_offset - @property def has_previous_state(self): """We have a previous state if the last linear (conv) layer was already updated.""" return self.conv_states_q[self.last_linear_layer] is not None @@ -725,7 +723,7 @@ def forward( batch_size, seq_len, _ = hidden_states.shape use_cache = cache_params is not None - use_precomputed = use_cache and getattr(cache_params, "has_previous_state", False) and seq_len == 1 + use_precomputed = use_cache and cache_params.has_previous_state() and seq_len == 1 conv_state_q = cache_params.conv_states_q[self.layer_idx] if cache_params else None conv_state_k = cache_params.conv_states_k[self.layer_idx] if cache_params else None diff --git a/src/transformers/models/olmo_hybrid/modular_olmo_hybrid.py b/src/transformers/models/olmo_hybrid/modular_olmo_hybrid.py index 74f6a70ce6de..089f29309007 100644 --- a/src/transformers/models/olmo_hybrid/modular_olmo_hybrid.py +++ b/src/transformers/models/olmo_hybrid/modular_olmo_hybrid.py @@ -16,6 +16,7 @@ import math from collections.abc import Callable +from typing import Any import torch import torch.nn as nn @@ -48,7 +49,6 @@ eager_attention_forward, ) from ..qwen3_next.modeling_qwen3_next import ( - Qwen3NextDynamicCache, Qwen3NextModel, Qwen3NextPreTrainedModel, Qwen3NextRMSNormGated, @@ -189,22 +189,49 @@ def validate_architecture(self): raise ValueError("OLMoHybrid expects at least one attention layer.") -class OlmoHybridDynamicCache(Qwen3NextDynamicCache): +class OlmoHybridDynamicCache: """ Cache for hybrid model supporting both attention KV cache and linear attention state. - Inherits from Qwen3NextDynamicCache. The main difference is that this cache - stores separate conv states for q, k, v (instead of a single conv_states list). + The main difference is that this cache stores separate conv states for q, k, v (instead of a single conv_states). 
""" + is_compileable = False + def __init__(self, config: OlmoHybridConfig): - super().__init__(config) - del self.conv_states + super().__init__() + self.layer_types = config.layer_types + self.transformer_layers = [ + i for i in range(config.num_hidden_layers) if self.layer_types[i] == "full_attention" + ] + self.last_linear_layer = len(self.layer_types) - 1 - self.layer_types[::-1].index("linear_attention") + self.recurrent_states = [None for _ in range(config.num_hidden_layers)] + self.key_cache = [None for _ in range(config.num_hidden_layers)] + self.value_cache = [None for _ in range(config.num_hidden_layers)] # Replace single conv_states with separate q, k, v conv states self.conv_states_q = [None for _ in range(config.num_hidden_layers)] self.conv_states_k = [None for _ in range(config.num_hidden_layers)] self.conv_states_v = [None for _ in range(config.num_hidden_layers)] + def __len__(self): + return len(self.layer_types) + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: dict[str, Any] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + if self.key_cache[layer_idx] is None: + self.key_cache[layer_idx] = key_states + self.value_cache[layer_idx] = value_states + else: + self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2) + self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2) + + return self.key_cache[layer_idx], self.value_cache[layer_idx] + def reorder_cache(self, beam_idx: torch.LongTensor): """Reorders the cache for beam search, given the selected beam indices.""" batch_size = beam_idx.shape[0] @@ -240,8 +267,27 @@ def reorder_cache(self, beam_idx: torch.LongTensor): 0, beam_idx.to(device) ) - @property + def get_seq_length(self, layer_idx: int | None = 0) -> int: + """Returns the sequence length of the cached states. A layer index can be optionally passed.""" + # take any layer that contains cache and not empty tensor + layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx + if len(self.key_cache) <= layer_idx or self.key_cache[layer_idx] is None: + return 0 + return self.key_cache[layer_idx].shape[-2] + + def get_mask_sizes(self, query_length: int, layer_idx: int) -> tuple[int, int]: + """ + Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for + the given layer at `layer_idx`. + The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns for each layer. 
+ """ + kv_offset = 0 + past_seen_tokens = self.get_seq_length(layer_idx) + kv_length = query_length + past_seen_tokens + return kv_length, kv_offset + def has_previous_state(self): + """We have a previous state if the last linear (conv) layer was already updated.""" return self.conv_states_q[self.last_linear_layer] is not None @@ -495,7 +541,7 @@ def forward( batch_size, seq_len, _ = hidden_states.shape use_cache = cache_params is not None - use_precomputed = use_cache and getattr(cache_params, "has_previous_state", False) and seq_len == 1 + use_precomputed = use_cache and cache_params.has_previous_state() and seq_len == 1 conv_state_q = cache_params.conv_states_q[self.layer_idx] if cache_params else None conv_state_k = cache_params.conv_states_k[self.layer_idx] if cache_params else None diff --git a/src/transformers/models/qwen3_5/modeling_qwen3_5.py b/src/transformers/models/qwen3_5/modeling_qwen3_5.py index efa661e4373e..feda0bd67a2b 100644 --- a/src/transformers/models/qwen3_5/modeling_qwen3_5.py +++ b/src/transformers/models/qwen3_5/modeling_qwen3_5.py @@ -29,7 +29,7 @@ from ... import initialization as init from ...activations import ACT2FN -from ...cache_utils import Cache +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...integrations import use_kernelized_func from ...masking_utils import create_causal_mask @@ -66,95 +66,6 @@ logger = logging.get_logger(__name__) -class Qwen3_5DynamicCache: - """ - A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the linear attention - cache (which has a constant shape regardless of seq_len). - - This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states` - and `ssm_states` for gated deltanet cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor - For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`, - while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors). - For linear attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors), - while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`, - and `recurrent_states` represents the recurrent state and has a shape of `(batch_size, d_inner, d_state)`. 
- """ - - is_compileable = False - - def __init__(self, config: Qwen3_5Config): - super().__init__() - self.layer_types = config.layer_types - self.transformer_layers = [ - i for i in range(config.num_hidden_layers) if self.layer_types[i] == "full_attention" - ] - self.last_linear_layer = len(self.layer_types) - 1 - self.layer_types[::-1].index("linear_attention") - - # Initialize everything to None -> will be lazy initialized to allow multi-gpu (device_map) inference - self.conv_states = [None for _ in range(config.num_hidden_layers)] - self.recurrent_states = [None for _ in range(config.num_hidden_layers)] - self.key_cache = [None for _ in range(config.num_hidden_layers)] - self.value_cache = [None for _ in range(config.num_hidden_layers)] - - def __len__(self): - return len(self.layer_types) - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: dict[str, Any] | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - if self.key_cache[layer_idx] is None: - self.key_cache[layer_idx] = key_states - self.value_cache[layer_idx] = value_states - else: - self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2) - self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2) - - return self.key_cache[layer_idx], self.value_cache[layer_idx] - - def reorder_cache(self, beam_idx: torch.LongTensor): - """Reorders the cache for beam search, given the selected beam indices.""" - for layer_idx in range(len(self.key_cache)): - if self.key_cache[layer_idx] is not None: - device = self.key_cache[layer_idx].device - beam_idx = beam_idx.to(device) - self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx) - self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx) - - if self.conv_states[layer_idx] is not None: - device = self.conv_states[layer_idx].device - beam_idx = beam_idx.to(device) - self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx) - self.recurrent_states[layer_idx] = self.recurrent_states[layer_idx].index_select(0, beam_idx) - - def get_seq_length(self, layer_idx: int | None = 0) -> int: - """Returns the sequence length of the cached states. A layer index can be optionally passed.""" - # take any layer that contains cache and not empty tensor - layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx - if len(self.key_cache) <= layer_idx or self.key_cache[layer_idx] is None: - return 0 - return self.key_cache[layer_idx].shape[-2] - - def get_mask_sizes(self, query_length: int, layer_idx: int) -> tuple[int, int]: - """ - Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for - the given layer at `layer_idx`. - The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns for each layer. 
- """ - kv_offset = 0 - past_seen_tokens = self.get_seq_length(layer_idx) - kv_length = query_length + past_seen_tokens - return kv_length, kv_offset - - @property - def has_previous_state(self): - """We have a previous state if the last linear (conv) layer was already updated.""" - return self.conv_states[self.last_linear_layer] is not None - - class Qwen3_5VisionRotaryEmbedding(nn.Module): inv_freq: torch.Tensor # fix linting for `register_buffer` @@ -511,7 +422,7 @@ def __init__(self, config: Qwen3_5Config, layer_idx: int): def forward( self, hidden_states: torch.Tensor, - cache_params: Qwen3_5DynamicCache | None = None, + cache_params: Cache | None = None, attention_mask: torch.Tensor | None = None, ): hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask) @@ -519,12 +430,14 @@ def forward( # Set up dimensions for reshapes later batch_size, seq_len, _ = hidden_states.shape - use_precomputed_states = cache_params is not None and cache_params.has_previous_state and seq_len == 1 + use_precomputed_states = ( + cache_params is not None and cache_params.has_previous_state(self.layer_idx) and seq_len == 1 + ) # getting projected states from cache if it exists - if cache_params is not None: - conv_state = cache_params.conv_states[self.layer_idx] - recurrent_state = cache_params.recurrent_states[self.layer_idx] + if use_precomputed_states: + conv_state = cache_params.layers[self.layer_idx].conv_states + recurrent_state = cache_params.layers[self.layer_idx].ssm_states mixed_qkv = self.in_proj_qkv(hidden_states) mixed_qkv = mixed_qkv.transpose(1, 2) @@ -548,7 +461,7 @@ def forward( else: if cache_params is not None: conv_state = F.pad(mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0)) - cache_params.conv_states[self.layer_idx] = conv_state + conv_state = cache_params.update_conv_state(conv_state, self.layer_idx) if self.causal_conv1d_fn is not None: mixed_qkv = self.causal_conv1d_fn( x=mixed_qkv, @@ -608,7 +521,7 @@ def forward( # Update cache if cache_params is not None: - cache_params.recurrent_states[self.layer_idx] = last_recurrent_state + cache_params.update_ssm_state(last_recurrent_state, self.layer_idx) # reshape input data into 2D tensor core_attn_out = core_attn_out.reshape(-1, self.head_v_dim) @@ -1328,7 +1241,7 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) if use_cache and past_key_values is None: - past_key_values = Qwen3_5DynamicCache(config=self.config) + past_key_values = DynamicCache(config=self.config) # the hard coded `4` is for text, temporal, height and width. if position_ids is None: diff --git a/src/transformers/models/qwen3_5/modular_qwen3_5.py b/src/transformers/models/qwen3_5/modular_qwen3_5.py index b76991426f13..69e9bdb0faf2 100644 --- a/src/transformers/models/qwen3_5/modular_qwen3_5.py +++ b/src/transformers/models/qwen3_5/modular_qwen3_5.py @@ -21,7 +21,7 @@ from torch import nn from ... 
import initialization as init -from ...cache_utils import Cache +from ...cache_utils import Cache, DynamicCache from ...masking_utils import create_causal_mask from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling @@ -34,7 +34,6 @@ from ..qwen3_next.configuration_qwen3_next import Qwen3NextConfig from ..qwen3_next.modeling_qwen3_next import ( Qwen3NextAttention, - Qwen3NextDynamicCache, Qwen3NextGatedDeltaNet, Qwen3NextMLP, Qwen3NextModel, @@ -160,10 +159,6 @@ class Qwen3_5Config(Qwen3VLConfig): vision_end_token_id: int = 248054 -class Qwen3_5DynamicCache(Qwen3NextDynamicCache): - pass - - class Qwen3_5VisionRotaryEmbedding(Qwen3VLVisionRotaryEmbedding): pass @@ -212,7 +207,7 @@ def fix_query_key_value_ordering(self): def forward( self, hidden_states: torch.Tensor, - cache_params: Qwen3_5DynamicCache | None = None, + cache_params: Cache | None = None, attention_mask: torch.Tensor | None = None, ): hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask) @@ -220,12 +215,14 @@ def forward( # Set up dimensions for reshapes later batch_size, seq_len, _ = hidden_states.shape - use_precomputed_states = cache_params is not None and cache_params.has_previous_state and seq_len == 1 + use_precomputed_states = ( + cache_params is not None and cache_params.has_previous_state(self.layer_idx) and seq_len == 1 + ) # getting projected states from cache if it exists - if cache_params is not None: - conv_state = cache_params.conv_states[self.layer_idx] - recurrent_state = cache_params.recurrent_states[self.layer_idx] + if use_precomputed_states: + conv_state = cache_params.layers[self.layer_idx].conv_states + recurrent_state = cache_params.layers[self.layer_idx].ssm_states mixed_qkv = self.in_proj_qkv(hidden_states) mixed_qkv = mixed_qkv.transpose(1, 2) @@ -249,7 +246,7 @@ def forward( else: if cache_params is not None: conv_state = F.pad(mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0)) - cache_params.conv_states[self.layer_idx] = conv_state + conv_state = cache_params.update_conv_state(conv_state, self.layer_idx) if self.causal_conv1d_fn is not None: mixed_qkv = self.causal_conv1d_fn( x=mixed_qkv, @@ -309,7 +306,7 @@ def forward( # Update cache if cache_params is not None: - cache_params.recurrent_states[self.layer_idx] = last_recurrent_state + cache_params.update_ssm_state(last_recurrent_state, self.layer_idx) # reshape input data into 2D tensor core_attn_out = core_attn_out.reshape(-1, self.head_v_dim) @@ -501,7 +498,7 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) if use_cache and past_key_values is None: - past_key_values = Qwen3_5DynamicCache(config=self.config) + past_key_values = DynamicCache(config=self.config) # the hard coded `4` is for text, temporal, height and width. if position_ids is None: diff --git a/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py b/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py index 243d7bee0fe3..3c578e6fe836 100644 --- a/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py +++ b/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py @@ -29,7 +29,7 @@ from ... 
import initialization as init from ...activations import ACT2FN -from ...cache_utils import Cache +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...integrations import use_experts_implementation, use_kernelized_func from ...masking_utils import create_causal_mask @@ -173,95 +173,6 @@ def apply_interleaved_mrope(self, freqs, mrope_section): return freqs_t -class Qwen3_5MoeDynamicCache: - """ - A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the linear attention - cache (which has a constant shape regardless of seq_len). - - This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states` - and `ssm_states` for gated deltanet cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor - For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`, - while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors). - For linear attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors), - while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`, - and `recurrent_states` represents the recurrent state and has a shape of `(batch_size, d_inner, d_state)`. - """ - - is_compileable = False - - def __init__(self, config: Qwen3_5MoeConfig): - super().__init__() - self.layer_types = config.layer_types - self.transformer_layers = [ - i for i in range(config.num_hidden_layers) if self.layer_types[i] == "full_attention" - ] - self.last_linear_layer = len(self.layer_types) - 1 - self.layer_types[::-1].index("linear_attention") - - # Initialize everything to None -> will be lazy initialized to allow multi-gpu (device_map) inference - self.conv_states = [None for _ in range(config.num_hidden_layers)] - self.recurrent_states = [None for _ in range(config.num_hidden_layers)] - self.key_cache = [None for _ in range(config.num_hidden_layers)] - self.value_cache = [None for _ in range(config.num_hidden_layers)] - - def __len__(self): - return len(self.layer_types) - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: dict[str, Any] | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - if self.key_cache[layer_idx] is None: - self.key_cache[layer_idx] = key_states - self.value_cache[layer_idx] = value_states - else: - self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2) - self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2) - - return self.key_cache[layer_idx], self.value_cache[layer_idx] - - def reorder_cache(self, beam_idx: torch.LongTensor): - """Reorders the cache for beam search, given the selected beam indices.""" - for layer_idx in range(len(self.key_cache)): - if self.key_cache[layer_idx] is not None: - device = self.key_cache[layer_idx].device - beam_idx = beam_idx.to(device) - self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx) - self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx) - - if self.conv_states[layer_idx] is not None: - device = self.conv_states[layer_idx].device - beam_idx = beam_idx.to(device) - self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx) - self.recurrent_states[layer_idx] = self.recurrent_states[layer_idx].index_select(0, beam_idx) - 
- def get_seq_length(self, layer_idx: int | None = 0) -> int: - """Returns the sequence length of the cached states. A layer index can be optionally passed.""" - # take any layer that contains cache and not empty tensor - layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx - if len(self.key_cache) <= layer_idx or self.key_cache[layer_idx] is None: - return 0 - return self.key_cache[layer_idx].shape[-2] - - def get_mask_sizes(self, query_length: int, layer_idx: int) -> tuple[int, int]: - """ - Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for - the given layer at `layer_idx`. - The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns for each layer. - """ - kv_offset = 0 - past_seen_tokens = self.get_seq_length(layer_idx) - kv_length = query_length + past_seen_tokens - return kv_length, kv_offset - - @property - def has_previous_state(self): - """We have a previous state if the last linear (conv) layer was already updated.""" - return self.conv_states[self.last_linear_layer] is not None - - class Qwen3_5MoeRMSNormGated(nn.Module): def __init__(self, hidden_size, eps=1e-6, **kwargs): super().__init__() @@ -512,7 +423,7 @@ def __init__(self, config: Qwen3_5MoeConfig, layer_idx: int): def forward( self, hidden_states: torch.Tensor, - cache_params: Qwen3_5MoeDynamicCache | None = None, + cache_params: Cache | None = None, attention_mask: torch.Tensor | None = None, ): hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask) @@ -520,12 +431,14 @@ def forward( # Set up dimensions for reshapes later batch_size, seq_len, _ = hidden_states.shape - use_precomputed_states = cache_params is not None and cache_params.has_previous_state and seq_len == 1 + use_precomputed_states = ( + cache_params is not None and cache_params.has_previous_state(self.layer_idx) and seq_len == 1 + ) # getting projected states from cache if it exists - if cache_params is not None: - conv_state = cache_params.conv_states[self.layer_idx] - recurrent_state = cache_params.recurrent_states[self.layer_idx] + if use_precomputed_states: + conv_state = cache_params.layers[self.layer_idx].conv_states + recurrent_state = cache_params.layers[self.layer_idx].ssm_states mixed_qkv = self.in_proj_qkv(hidden_states) mixed_qkv = mixed_qkv.transpose(1, 2) @@ -549,7 +462,7 @@ def forward( else: if cache_params is not None: conv_state = F.pad(mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0)) - cache_params.conv_states[self.layer_idx] = conv_state + conv_state = cache_params.update_conv_state(conv_state, self.layer_idx) if self.causal_conv1d_fn is not None: mixed_qkv = self.causal_conv1d_fn( x=mixed_qkv, @@ -609,7 +522,7 @@ def forward( # Update cache if cache_params is not None: - cache_params.recurrent_states[self.layer_idx] = last_recurrent_state + cache_params.update_ssm_state(last_recurrent_state, self.layer_idx) # reshape input data into 2D tensor core_attn_out = core_attn_out.reshape(-1, self.head_v_dim) @@ -1453,7 +1366,7 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) if use_cache and past_key_values is None: - past_key_values = Qwen3_5MoeDynamicCache(config=self.config) + past_key_values = DynamicCache(config=self.config) # the hard coded `4` is for text, temporal, height and width. 
if position_ids is None: @@ -2005,7 +1918,7 @@ def forward( input_ids: torch.LongTensor | None = None, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, - past_key_values: Qwen3_5MoeDynamicCache | None = None, + past_key_values: Cache | None = None, inputs_embeds: torch.FloatTensor | None = None, labels: torch.LongTensor | None = None, use_cache: bool | None = None, diff --git a/src/transformers/models/qwen3_5_moe/modular_qwen3_5_moe.py b/src/transformers/models/qwen3_5_moe/modular_qwen3_5_moe.py index f8684ddd83db..312b22bc88ed 100644 --- a/src/transformers/models/qwen3_5_moe/modular_qwen3_5_moe.py +++ b/src/transformers/models/qwen3_5_moe/modular_qwen3_5_moe.py @@ -35,7 +35,6 @@ from ..qwen3_next.modeling_qwen3_next import ( Qwen3NextAttention, Qwen3NextDecoderLayer, - Qwen3NextDynamicCache, Qwen3NextExperts, Qwen3NextForCausalLM, Qwen3NextPreTrainedModel, @@ -156,10 +155,6 @@ class Qwen3_5MoeTextRotaryEmbedding(Qwen3_5TextRotaryEmbedding): pass -class Qwen3_5MoeDynamicCache(Qwen3NextDynamicCache): - pass - - class Qwen3_5MoeGatedDeltaNet(Qwen3_5GatedDeltaNet): pass diff --git a/src/transformers/models/qwen3_next/modeling_qwen3_next.py b/src/transformers/models/qwen3_next/modeling_qwen3_next.py index 002d879f8135..da31a5eb8e7d 100644 --- a/src/transformers/models/qwen3_next/modeling_qwen3_next.py +++ b/src/transformers/models/qwen3_next/modeling_qwen3_next.py @@ -19,7 +19,7 @@ # limitations under the License. from collections.abc import Callable -from typing import Any, Optional +from typing import Optional import torch import torch.nn.functional as F @@ -27,7 +27,7 @@ from ... import initialization as init from ...activations import ACT2FN -from ...cache_utils import Cache +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...integrations import use_experts_implementation, use_kernelized_func from ...masking_utils import create_causal_mask @@ -82,95 +82,6 @@ def forward(self, hidden_states, gate=None): return hidden_states.to(input_dtype) -class Qwen3NextDynamicCache: - """ - A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the linear attention - cache (which has a constant shape regardless of seq_len). - - This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states` - and `ssm_states` for gated deltanet cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor - For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`, - while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors). - For linear attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors), - while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`, - and `recurrent_states` represents the recurrent state and has a shape of `(batch_size, d_inner, d_state)`. 
- """ - - is_compileable = False - - def __init__(self, config: Qwen3NextConfig): - super().__init__() - self.layer_types = config.layer_types - self.transformer_layers = [ - i for i in range(config.num_hidden_layers) if self.layer_types[i] == "full_attention" - ] - self.last_linear_layer = len(self.layer_types) - 1 - self.layer_types[::-1].index("linear_attention") - - # Initialize everything to None -> will be lazy initialized to allow multi-gpu (device_map) inference - self.conv_states = [None for _ in range(config.num_hidden_layers)] - self.recurrent_states = [None for _ in range(config.num_hidden_layers)] - self.key_cache = [None for _ in range(config.num_hidden_layers)] - self.value_cache = [None for _ in range(config.num_hidden_layers)] - - def __len__(self): - return len(self.layer_types) - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: dict[str, Any] | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - if self.key_cache[layer_idx] is None: - self.key_cache[layer_idx] = key_states - self.value_cache[layer_idx] = value_states - else: - self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2) - self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2) - - return self.key_cache[layer_idx], self.value_cache[layer_idx] - - def reorder_cache(self, beam_idx: torch.LongTensor): - """Reorders the cache for beam search, given the selected beam indices.""" - for layer_idx in range(len(self.key_cache)): - if self.key_cache[layer_idx] is not None: - device = self.key_cache[layer_idx].device - beam_idx = beam_idx.to(device) - self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx) - self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx) - - if self.conv_states[layer_idx] is not None: - device = self.conv_states[layer_idx].device - beam_idx = beam_idx.to(device) - self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx) - self.recurrent_states[layer_idx] = self.recurrent_states[layer_idx].index_select(0, beam_idx) - - def get_seq_length(self, layer_idx: int | None = 0) -> int: - """Returns the sequence length of the cached states. A layer index can be optionally passed.""" - # take any layer that contains cache and not empty tensor - layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx - if len(self.key_cache) <= layer_idx or self.key_cache[layer_idx] is None: - return 0 - return self.key_cache[layer_idx].shape[-2] - - def get_mask_sizes(self, query_length: int, layer_idx: int) -> tuple[int, int]: - """ - Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for - the given layer at `layer_idx`. - The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns for each layer. 
- """ - kv_offset = 0 - past_seen_tokens = self.get_seq_length(layer_idx) - kv_length = query_length + past_seen_tokens - return kv_length, kv_offset - - @property - def has_previous_state(self): - """We have a previous state if the last linear (conv) layer was already updated.""" - return self.conv_states[self.last_linear_layer] is not None - - class Qwen3NextRotaryEmbedding(nn.Module): inv_freq: torch.Tensor # fix linting for `register_buffer` @@ -681,7 +592,7 @@ def fix_query_key_value_ordering(self, mixed_qkvz, mixed_ba): def forward( self, hidden_states: torch.Tensor, - cache_params: Qwen3NextDynamicCache | None = None, + cache_params: Cache | None = None, attention_mask: torch.Tensor | None = None, ): hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask) @@ -689,12 +600,14 @@ def forward( # Set up dimensions for reshapes later batch_size, seq_len, _ = hidden_states.shape - use_precomputed_states = cache_params is not None and cache_params.has_previous_state and seq_len == 1 + use_precomputed_states = ( + cache_params is not None and cache_params.has_previous_state(self.layer_idx) and seq_len == 1 + ) # getting projected states from cache if it exists - if cache_params is not None: - conv_state = cache_params.conv_states[self.layer_idx] - recurrent_state = cache_params.recurrent_states[self.layer_idx] + if use_precomputed_states: + conv_state = cache_params.layers[self.layer_idx].conv_states + recurrent_state = cache_params.layers[self.layer_idx].ssm_states projected_states_qkvz = self.in_proj_qkvz(hidden_states) projected_states_ba = self.in_proj_ba(hidden_states) @@ -717,7 +630,7 @@ def forward( else: if cache_params is not None: conv_state = F.pad(mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0)) - cache_params.conv_states[self.layer_idx] = conv_state + conv_state = cache_params.update_conv_state(conv_state, self.layer_idx) if self.causal_conv1d_fn is not None: mixed_qkv = self.causal_conv1d_fn( x=mixed_qkv, @@ -761,7 +674,6 @@ def forward( output_final_state=cache_params is not None, use_qk_l2norm_in_kernel=True, ) - else: core_attn_out, last_recurrent_state = self.recurrent_gated_delta_rule( query, @@ -776,7 +688,7 @@ def forward( # Update cache if cache_params is not None: - cache_params.recurrent_states[self.layer_idx] = last_recurrent_state + cache_params.update_ssm_state(last_recurrent_state, self.layer_idx) z_shape_og = z.shape # reshape input data into 2D tensor @@ -1021,7 +933,7 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) if use_cache and past_key_values is None: - past_key_values = Qwen3NextDynamicCache(config=self.config) + past_key_values = DynamicCache(config=self.config) if position_ids is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 @@ -1182,7 +1094,7 @@ def forward( input_ids: torch.LongTensor | None = None, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, - past_key_values: Qwen3NextDynamicCache | None = None, + past_key_values: Cache | None = None, inputs_embeds: torch.FloatTensor | None = None, labels: torch.LongTensor | None = None, use_cache: bool | None = None, diff --git a/src/transformers/models/qwen3_next/modular_qwen3_next.py b/src/transformers/models/qwen3_next/modular_qwen3_next.py index 7dbd15e6d91b..80fa61c3229b 100644 --- a/src/transformers/models/qwen3_next/modular_qwen3_next.py +++ b/src/transformers/models/qwen3_next/modular_qwen3_next.py @@ -14,7 +14,7 @@ """PyTorch Qwen3-Next model.""" from collections.abc 
import Callable -from typing import Any, Optional +from typing import Optional import torch import torch.nn.functional as F @@ -22,7 +22,7 @@ from ... import initialization as init from ...activations import ACT2FN -from ...cache_utils import Cache +from ...cache_utils import Cache, DynamicCache from ...masking_utils import create_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast @@ -92,95 +92,6 @@ def forward(self, hidden_states, gate=None): return hidden_states.to(input_dtype) -class Qwen3NextDynamicCache: - """ - A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the linear attention - cache (which has a constant shape regardless of seq_len). - - This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states` - and `ssm_states` for gated deltanet cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor - For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`, - while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors). - For linear attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors), - while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`, - and `recurrent_states` represents the recurrent state and has a shape of `(batch_size, d_inner, d_state)`. - """ - - is_compileable = False - - def __init__(self, config: Qwen3NextConfig): - super().__init__() - self.layer_types = config.layer_types - self.transformer_layers = [ - i for i in range(config.num_hidden_layers) if self.layer_types[i] == "full_attention" - ] - self.last_linear_layer = len(self.layer_types) - 1 - self.layer_types[::-1].index("linear_attention") - - # Initialize everything to None -> will be lazy initialized to allow multi-gpu (device_map) inference - self.conv_states = [None for _ in range(config.num_hidden_layers)] - self.recurrent_states = [None for _ in range(config.num_hidden_layers)] - self.key_cache = [None for _ in range(config.num_hidden_layers)] - self.value_cache = [None for _ in range(config.num_hidden_layers)] - - def __len__(self): - return len(self.layer_types) - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: dict[str, Any] | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - if self.key_cache[layer_idx] is None: - self.key_cache[layer_idx] = key_states - self.value_cache[layer_idx] = value_states - else: - self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2) - self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2) - - return self.key_cache[layer_idx], self.value_cache[layer_idx] - - def reorder_cache(self, beam_idx: torch.LongTensor): - """Reorders the cache for beam search, given the selected beam indices.""" - for layer_idx in range(len(self.key_cache)): - if self.key_cache[layer_idx] is not None: - device = self.key_cache[layer_idx].device - beam_idx = beam_idx.to(device) - self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx) - self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx) - - if self.conv_states[layer_idx] is not None: - device = self.conv_states[layer_idx].device - beam_idx = 
beam_idx.to(device) - self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx) - self.recurrent_states[layer_idx] = self.recurrent_states[layer_idx].index_select(0, beam_idx) - - def get_seq_length(self, layer_idx: int | None = 0) -> int: - """Returns the sequence length of the cached states. A layer index can be optionally passed.""" - # take any layer that contains cache and not empty tensor - layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx - if len(self.key_cache) <= layer_idx or self.key_cache[layer_idx] is None: - return 0 - return self.key_cache[layer_idx].shape[-2] - - def get_mask_sizes(self, query_length: int, layer_idx: int) -> tuple[int, int]: - """ - Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for - the given layer at `layer_idx`. - The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns for each layer. - """ - kv_offset = 0 - past_seen_tokens = self.get_seq_length(layer_idx) - kv_length = query_length + past_seen_tokens - return kv_length, kv_offset - - @property - def has_previous_state(self): - """We have a previous state if the last linear (conv) layer was already updated.""" - return self.conv_states[self.last_linear_layer] is not None - - class Qwen3NextRotaryEmbedding(Gemma2RotaryEmbedding): @staticmethod def compute_default_rope_parameters( @@ -520,7 +431,7 @@ def fix_query_key_value_ordering(self, mixed_qkvz, mixed_ba): def forward( self, hidden_states: torch.Tensor, - cache_params: Qwen3NextDynamicCache | None = None, + cache_params: Cache | None = None, attention_mask: torch.Tensor | None = None, ): hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask) @@ -528,12 +439,14 @@ def forward( # Set up dimensions for reshapes later batch_size, seq_len, _ = hidden_states.shape - use_precomputed_states = cache_params is not None and cache_params.has_previous_state and seq_len == 1 + use_precomputed_states = ( + cache_params is not None and cache_params.has_previous_state(self.layer_idx) and seq_len == 1 + ) # getting projected states from cache if it exists - if cache_params is not None: - conv_state = cache_params.conv_states[self.layer_idx] - recurrent_state = cache_params.recurrent_states[self.layer_idx] + if use_precomputed_states: + conv_state = cache_params.layers[self.layer_idx].conv_states + recurrent_state = cache_params.layers[self.layer_idx].ssm_states projected_states_qkvz = self.in_proj_qkvz(hidden_states) projected_states_ba = self.in_proj_ba(hidden_states) @@ -556,7 +469,7 @@ def forward( else: if cache_params is not None: conv_state = F.pad(mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0)) - cache_params.conv_states[self.layer_idx] = conv_state + conv_state = cache_params.update_conv_state(conv_state, self.layer_idx) if self.causal_conv1d_fn is not None: mixed_qkv = self.causal_conv1d_fn( x=mixed_qkv, @@ -600,7 +513,6 @@ def forward( output_final_state=cache_params is not None, use_qk_l2norm_in_kernel=True, ) - else: core_attn_out, last_recurrent_state = self.recurrent_gated_delta_rule( query, @@ -615,7 +527,7 @@ def forward( # Update cache if cache_params is not None: - cache_params.recurrent_states[self.layer_idx] = last_recurrent_state + cache_params.update_ssm_state(last_recurrent_state, self.layer_idx) z_shape_og = z.shape # reshape input data into 2D tensor @@ -777,7 +689,7 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) if use_cache and 
past_key_values is None: - past_key_values = Qwen3NextDynamicCache(config=self.config) + past_key_values = DynamicCache(config=self.config) if position_ids is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 @@ -841,7 +753,7 @@ def forward( input_ids: torch.LongTensor | None = None, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, - past_key_values: Qwen3NextDynamicCache | None = None, + past_key_values: Cache | None = None, inputs_embeds: torch.FloatTensor | None = None, labels: torch.LongTensor | None = None, use_cache: bool | None = None, diff --git a/tests/models/qwen3_5/test_modeling_qwen3_5.py b/tests/models/qwen3_5/test_modeling_qwen3_5.py index 191a8cf788e4..04b815c04422 100644 --- a/tests/models/qwen3_5/test_modeling_qwen3_5.py +++ b/tests/models/qwen3_5/test_modeling_qwen3_5.py @@ -71,35 +71,16 @@ class Qwen3_5TextModelTest(CausalLMModelTest, unittest.TestCase): config_class = Qwen3_5TextConfig model_split_percents = [0.5, 0.8, 0.9] - def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config): - "Qwen3.5 has a special Cache as it alternates with gated deltanet layers" - self.assertIsInstance(past_key_values, Qwen3_5DynamicCache) - - # (batch, kv heads, seq_length, head_dim) - num_heads = getattr(config, "num_key_value_heads", config.num_attention_heads) - head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) - expected_shape = (batch_size, num_heads, seq_length, head_dim) - - attention_layer_indices = past_key_values.transformer_layers - self.assertListEqual( - [past_key_values.key_cache[idx].shape for idx in attention_layer_indices], - [expected_shape] * len(attention_layer_indices), - ) - self.assertListEqual( - [past_key_values.value_cache[idx].shape for idx in attention_layer_indices], - [expected_shape] * len(attention_layer_indices), - ) - - def _check_caches_are_equal(self, cache1, cache2): - "Qwen3.5 has a special Cache as it alternates with gated deltanet layers" - if not len(cache1) == len(cache2): - raise ValueError("Both caches do not have the same number of layers.") - - num_layers = len(cache1) - for idx in range(num_layers): - if cache1.key_cache[idx] is not None: - torch.testing.assert_close(cache1.key_cache[idx], cache2.key_cache[idx]) - torch.testing.assert_close(cache1.value_cache[idx], cache2.value_cache[idx]) + def _get_mamba_cache_shapes(self, batch_size: int, config): + num_v_heads = config.linear_num_value_heads + num_k_heads = config.linear_num_key_heads + head_k_dim = config.linear_key_head_dim + head_v_dim = config.linear_value_head_dim + intermediate_size = 2 * num_k_heads * head_k_dim + num_v_heads * head_v_dim + + conv_shape = (batch_size, intermediate_size, config.linear_conv_kernel_dim) + ssm_shape = (batch_size, num_v_heads, head_k_dim, head_v_dim) + return conv_shape, ssm_shape def test_attention_outputs(self): "Needs to be overwritten as Qwen3.5 alternates between attention layers and gated deltanet layers." 
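The `_get_mamba_cache_shapes` helper above derives the expected linear-attention cache shapes purely from config attributes. A small worked example, with invented numbers (only the formula comes from the helper itself):

```python
# Hypothetical config values -- only the formula below is taken from the test helper.
batch_size = 2
linear_num_value_heads, linear_num_key_heads = 4, 2
linear_key_head_dim, linear_value_head_dim = 16, 32
linear_conv_kernel_dim = 4

# 2 * 2 * 16 + 4 * 32 = 192
intermediate_size = 2 * linear_num_key_heads * linear_key_head_dim + linear_num_value_heads * linear_value_head_dim

conv_shape = (batch_size, intermediate_size, linear_conv_kernel_dim)                          # (2, 192, 4)
ssm_shape = (batch_size, linear_num_value_heads, linear_key_head_dim, linear_value_head_dim)  # (2, 4, 16, 32)
```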
diff --git a/tests/models/qwen3_5_moe/test_modeling_qwen3_5_moe.py b/tests/models/qwen3_5_moe/test_modeling_qwen3_5_moe.py index 27cef6196313..c249d77d045a 100644 --- a/tests/models/qwen3_5_moe/test_modeling_qwen3_5_moe.py +++ b/tests/models/qwen3_5_moe/test_modeling_qwen3_5_moe.py @@ -74,35 +74,16 @@ class Qwen3_5MoeTextModelTest(CausalLMModelTest, unittest.TestCase): model_tester_class = Qwen3_5MoeTextModelTester config_class = Qwen3_5MoeTextConfig - def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config): - "Qwen3.5 Moe has a special Cache as it alternates with gated deltanet layers" - self.assertIsInstance(past_key_values, Qwen3_5MoeDynamicCache) - - # (batch, kv heads, seq_length, head_dim) - num_heads = getattr(config, "num_key_value_heads", config.num_attention_heads) - head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) - expected_shape = (batch_size, num_heads, seq_length, head_dim) - - attention_layer_indices = past_key_values.transformer_layers - self.assertListEqual( - [past_key_values.key_cache[idx].shape for idx in attention_layer_indices], - [expected_shape] * len(attention_layer_indices), - ) - self.assertListEqual( - [past_key_values.value_cache[idx].shape for idx in attention_layer_indices], - [expected_shape] * len(attention_layer_indices), - ) - - def _check_caches_are_equal(self, cache1, cache2): - "Qwen3.5 Moe has a special Cache as it alternates with gated deltanet layers" - if not len(cache1) == len(cache2): - raise ValueError("Both caches do not have the same number of layers.") - - num_layers = len(cache1) - for idx in range(num_layers): - if cache1.key_cache[idx] is not None: - torch.testing.assert_close(cache1.key_cache[idx], cache2.key_cache[idx]) - torch.testing.assert_close(cache1.value_cache[idx], cache2.value_cache[idx]) + def _get_mamba_cache_shapes(self, batch_size: int, config): + num_v_heads = config.linear_num_value_heads + num_k_heads = config.linear_num_key_heads + head_k_dim = config.linear_key_head_dim + head_v_dim = config.linear_value_head_dim + intermediate_size = 2 * num_k_heads * head_k_dim + num_v_heads * head_v_dim + + conv_shape = (batch_size, intermediate_size, config.linear_conv_kernel_dim) + ssm_shape = (batch_size, num_v_heads, head_k_dim, head_v_dim) + return conv_shape, ssm_shape def test_attention_outputs(self): "Needs to be overwritten as Qwen3.5 Moe alternates between attention layers and gated deltanet layers." 
diff --git a/tests/models/qwen3_next/test_modeling_qwen3_next.py b/tests/models/qwen3_next/test_modeling_qwen3_next.py index 29e5f51705de..f597b8e28192 100644 --- a/tests/models/qwen3_next/test_modeling_qwen3_next.py +++ b/tests/models/qwen3_next/test_modeling_qwen3_next.py @@ -25,10 +25,8 @@ import torch from transformers import ( - Cache, Qwen3NextModel, ) - from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextDynamicCache from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester from ...test_modeling_common import ( @@ -59,35 +57,16 @@ def __init__(self, parent): class Qwen3NextModelTest(CausalLMModelTest, unittest.TestCase): model_tester_class = Qwen3NextModelTester - def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config): - "Qwen3-Next has a special Cache as it alternates with gated deltanet layers" - self.assertIsInstance(past_key_values, Qwen3NextDynamicCache) + def _get_mamba_cache_shapes(self, batch_size: int, config): + num_v_heads = config.linear_num_value_heads + num_k_heads = config.linear_num_key_heads + head_k_dim = config.linear_key_head_dim + head_v_dim = config.linear_value_head_dim + intermediate_size = 2 * num_k_heads * head_k_dim + num_v_heads * head_v_dim - # (batch, kv heads, seq_length, head_dim) - num_heads = getattr(config, "num_key_value_heads", config.num_attention_heads) - head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) - expected_shape = (batch_size, num_heads, seq_length, head_dim) - - attention_layer_indices = past_key_values.transformer_layers - self.assertListEqual( - [past_key_values.key_cache[idx].shape for idx in attention_layer_indices], - [expected_shape] * len(attention_layer_indices), - ) - self.assertListEqual( - [past_key_values.value_cache[idx].shape for idx in attention_layer_indices], - [expected_shape] * len(attention_layer_indices), - ) - - def _check_caches_are_equal(self, cache1: Cache, cache2: Cache): - "Qwen3-Next has a special Cache as it alternates with gated deltanet layers" - if not len(cache1) == len(cache2): - raise ValueError("Both caches do not have the same number of layers.") - - num_layers = len(cache1) - for idx in range(num_layers): - if cache1.key_cache[idx] is not None: - torch.testing.assert_close(cache1.key_cache[idx], cache2.key_cache[idx]) - torch.testing.assert_close(cache1.value_cache[idx], cache2.value_cache[idx]) + conv_shape = (batch_size, intermediate_size, config.linear_conv_kernel_dim) + ssm_shape = (batch_size, num_v_heads, head_k_dim, head_v_dim) + return conv_shape, ssm_shape def test_attention_outputs(self): "Needs to be overwritten as Qwen3-Next alternates between attention layers and gated deltanet layers." 
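All three test files drop their model-specific cache assertions because the gated deltanet layers now go through the shared cache layers in `cache_utils`. A minimal sketch of that per-layer state flow, mirroring the `update_conv_state`/`update_ssm_state` calls the modeling diffs switch to (the tensor sizes are illustrative only, and later patches in this series keep reworking these classes):

```python
import torch

from transformers.cache_utils import MambaLayer  # introduced by this series

layer = MambaLayer()

# Prefill: the modeling code pads/truncates the conv input to the kernel size before caching it.
batch_size, d_inner, d_conv, d_state = 2, 8, 4, 16
layer.update_conv_state(torch.randn(batch_size, d_inner, d_conv))   # lazily allocates, then stores
layer.update_ssm_state(torch.randn(batch_size, d_inner, d_state))

# Decode: each new timestep is rolled into the fixed-size conv window, so the shape never grows.
rolled = layer.update_conv_state(torch.randn(batch_size, d_inner, 1))
assert rolled.shape == (batch_size, d_inner, d_conv)
```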
From fc27c37fab09f412575f0defc329652d4e2cadf1 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 24 Mar 2026 22:09:08 +0100 Subject: [PATCH 43/56] tests --- tests/models/qwen3_5/test_modeling_qwen3_5.py | 38 +++++-------------- .../qwen3_5_moe/test_modeling_qwen3_5_moe.py | 38 +++++-------------- 2 files changed, 18 insertions(+), 58 deletions(-) diff --git a/tests/models/qwen3_5/test_modeling_qwen3_5.py b/tests/models/qwen3_5/test_modeling_qwen3_5.py index 04b815c04422..839b4f6c7fc4 100644 --- a/tests/models/qwen3_5/test_modeling_qwen3_5.py +++ b/tests/models/qwen3_5/test_modeling_qwen3_5.py @@ -46,7 +46,6 @@ Qwen3_5TextConfig, Qwen3_5TextModel, ) - from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5DynamicCache class Qwen3_5TextModelTester(CausalLMModelTester): @@ -300,35 +299,16 @@ def setUp(self): def test_config(self): self.config_tester.run_common_tests() - def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config): - "Qwen3.5 has a special Cache as it alternates with gated deltanet layers" - self.assertIsInstance(past_key_values, Qwen3_5DynamicCache) - - # (batch, kv heads, seq_length, head_dim) - num_heads = getattr(config, "num_key_value_heads", config.num_attention_heads) - head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) - expected_shape = (batch_size, num_heads, seq_length, head_dim) - - attention_layer_indices = past_key_values.transformer_layers - self.assertListEqual( - [past_key_values.key_cache[idx].shape for idx in attention_layer_indices], - [expected_shape] * len(attention_layer_indices), - ) - self.assertListEqual( - [past_key_values.value_cache[idx].shape for idx in attention_layer_indices], - [expected_shape] * len(attention_layer_indices), - ) - - def _check_caches_are_equal(self, cache1, cache2): - "Qwen3.5 has a special Cache as it alternates with gated deltanet layers" - if not len(cache1) == len(cache2): - raise ValueError("Both caches do not have the same number of layers.") + def _get_mamba_cache_shapes(self, batch_size: int, config): + num_v_heads = config.linear_num_value_heads + num_k_heads = config.linear_num_key_heads + head_k_dim = config.linear_key_head_dim + head_v_dim = config.linear_value_head_dim + intermediate_size = 2 * num_k_heads * head_k_dim + num_v_heads * head_v_dim - num_layers = len(cache1) - for idx in range(num_layers): - if cache1.key_cache[idx] is not None: - torch.testing.assert_close(cache1.key_cache[idx], cache2.key_cache[idx]) - torch.testing.assert_close(cache1.value_cache[idx], cache2.value_cache[idx]) + conv_shape = (batch_size, intermediate_size, config.linear_conv_kernel_dim) + ssm_shape = (batch_size, num_v_heads, head_k_dim, head_v_dim) + return conv_shape, ssm_shape def test_attention_outputs(self): "Needs to be overwritten as Qwen3.5 alternates between attention layers and gated deltanet layers." 
diff --git a/tests/models/qwen3_5_moe/test_modeling_qwen3_5_moe.py b/tests/models/qwen3_5_moe/test_modeling_qwen3_5_moe.py index c249d77d045a..55b58a1aea34 100644 --- a/tests/models/qwen3_5_moe/test_modeling_qwen3_5_moe.py +++ b/tests/models/qwen3_5_moe/test_modeling_qwen3_5_moe.py @@ -51,7 +51,6 @@ Qwen3_5MoeTextConfig, Qwen3_5MoeTextModel, ) - from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import Qwen3_5MoeDynamicCache class Qwen3_5MoeTextModelTester(CausalLMModelTester): @@ -382,35 +381,16 @@ def setUp(self): def test_config(self): self.config_tester.run_common_tests() - def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config): - "Qwen3.5 Moe has a special Cache as it alternates with gated deltanet layers" - self.assertIsInstance(past_key_values, Qwen3_5MoeDynamicCache) - - # (batch, kv heads, seq_length, head_dim) - num_heads = getattr(config, "num_key_value_heads", config.num_attention_heads) - head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) - expected_shape = (batch_size, num_heads, seq_length, head_dim) - - attention_layer_indices = past_key_values.transformer_layers - self.assertListEqual( - [past_key_values.key_cache[idx].shape for idx in attention_layer_indices], - [expected_shape] * len(attention_layer_indices), - ) - self.assertListEqual( - [past_key_values.value_cache[idx].shape for idx in attention_layer_indices], - [expected_shape] * len(attention_layer_indices), - ) - - def _check_caches_are_equal(self, cache1, cache2): - "Qwen3.5 Moe has a special Cache as it alternates with gated deltanet layers" - if not len(cache1) == len(cache2): - raise ValueError("Both caches do not have the same number of layers.") + def _get_mamba_cache_shapes(self, batch_size: int, config): + num_v_heads = config.linear_num_value_heads + num_k_heads = config.linear_num_key_heads + head_k_dim = config.linear_key_head_dim + head_v_dim = config.linear_value_head_dim + intermediate_size = 2 * num_k_heads * head_k_dim + num_v_heads * head_v_dim - num_layers = len(cache1) - for idx in range(num_layers): - if cache1.key_cache[idx] is not None: - torch.testing.assert_close(cache1.key_cache[idx], cache2.key_cache[idx]) - torch.testing.assert_close(cache1.value_cache[idx], cache2.value_cache[idx]) + conv_shape = (batch_size, intermediate_size, config.linear_conv_kernel_dim) + ssm_shape = (batch_size, num_v_heads, head_k_dim, head_v_dim) + return conv_shape, ssm_shape def test_attention_outputs(self): "Needs to be overwritten as Qwen3.5 Moe alternates between attention layers and gated deltanet layers." 
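The MoE test gets the same treatment as the dense one. A rough sketch of how a shared checker can then consume these shapes for the linear-attention layers (the real version lands in `tests/generation/test_utils.py` later in this series; the helper name below is made up and the unittest plumbing is omitted):

```python
from transformers.cache_utils import MambaLayer

def check_mamba_layer_shapes(past_key_values, conv_shape, ssm_shape):
    for layer in past_key_values.layers:
        if isinstance(layer, MambaLayer):
            assert layer.conv_states.shape == conv_shape
            # The ssm states may stay uninitialized for models that only use the conv state (e.g. lfm2).
            if layer.is_ssm_states_initialized:
                assert layer.ssm_states.shape == ssm_shape
```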
From 9c616dd759260940dd407fd17814b0b26bcd4656 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 25 Mar 2026 11:35:33 +0100 Subject: [PATCH 44/56] offloading --- tests/models/zamba/test_modeling_zamba.py | 13 +++++++++++++ tests/models/zamba2/test_modeling_zamba2.py | 1 + tests/test_modeling_common.py | 2 +- 3 files changed, 15 insertions(+), 1 deletion(-) diff --git a/tests/models/zamba/test_modeling_zamba.py b/tests/models/zamba/test_modeling_zamba.py index 899c67252583..037f157f79db 100644 --- a/tests/models/zamba/test_modeling_zamba.py +++ b/tests/models/zamba/test_modeling_zamba.py @@ -287,6 +287,7 @@ class ZambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi if is_torch_available() else {} ) + model_split_percents = [0.5, 0.8, 0.9] def _get_mamba_cache_shapes(self, batch_size: int, config): intermediate_size = config.mamba_expand * config.hidden_size @@ -298,6 +299,18 @@ def setUp(self): self.model_tester = ZambaModelTester(self) self.config_tester = ConfigTester(self, config_class=ZambaConfig, hidden_size=32) + @unittest.skip( + "Same as zamba2 -> investigate, it's probably due to their tied weights that accelerate does not work" + ) + def test_disk_offload_bin(self): + pass + + @unittest.skip( + "Same as zamba2 -> investigate, it's probably due to their tied weights that accelerate does not work" + ) + def test_disk_offload_safetensors(self): + pass + def test_config(self): self.config_tester.run_common_tests() diff --git a/tests/models/zamba2/test_modeling_zamba2.py b/tests/models/zamba2/test_modeling_zamba2.py index 18cf3746690d..1695cc9d5556 100644 --- a/tests/models/zamba2/test_modeling_zamba2.py +++ b/tests/models/zamba2/test_modeling_zamba2.py @@ -297,6 +297,7 @@ class Zamba2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix if is_torch_available() else {} ) + model_split_percents = [0.5, 0.8, 0.9] def _get_mamba_cache_shapes(self, batch_size: int, config): intermediate_size = config.mamba_expand * config.hidden_size diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 13b81855aaa6..5ae3a88fb3b8 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -2945,7 +2945,7 @@ def test_cpu_offload(self): model.cpu().save_pretrained(tmp_dir) for max_size in max_gpu_sizes: - max_memory = {0: max_size, "cpu": model_size * 2} + max_memory = {0: max_size, "cpu": model_size * 3} new_model = model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory) # Making sure part of the model will actually end up offloaded self.assertSetEqual(set(new_model.hf_device_map.values()), {0, "cpu"}) From 6f85f54912665d89fb49816ff7000ff1dfa89654 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 25 Mar 2026 14:46:11 +0100 Subject: [PATCH 45/56] much better static shape native design --- src/transformers/cache_utils.py | 81 ++++++++++++------- src/transformers/generation/utils.py | 42 +++++----- .../models/mamba2/modeling_mamba2.py | 1 + tests/models/mamba/test_modeling_mamba.py | 8 +- 4 files changed, 78 insertions(+), 54 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 1c4d7ed22513..df9057d57723 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -672,19 +672,23 @@ def _dequantize(self, qtensor): class MambaCacheLayerMixin(ABC): """Base, abstract class for a mamba single layer's cache.""" - is_compileable = False + # All shapes are static by essence in a Mamba layer, so it is compileable + is_compileable = 
True def __init__(self): self.conv_states: torch.Tensor | None = None self.ssm_states: torch.Tensor | None = None - self.is_initialized = False + self.is_conv_states_initialized = False + self.is_ssm_states_initialized = False self.has_previous_state = False def __repr__(self): return f"{self.__class__.__name__}" @abstractmethod - def lazy_initialization(self, conv_states: torch.Tensor) -> None: ... + def lazy_initialization( + self, conv_states: torch.Tensor | None = None, ssm_states: torch.Tensor | None = None + ) -> None: ... @abstractmethod def update_conv_state(self, conv_states: torch.Tensor) -> torch.Tensor: ... @@ -694,44 +698,55 @@ def update_ssm_state(self, ssm_states: torch.Tensor) -> torch.Tensor: ... def offload(self): """Offload this layer's data to CPU device.""" - if self.is_initialized: + if self.is_conv_states_initialized: self.conv_states = self.conv_states.to("cpu", non_blocking=True) + if self.is_ssm_states_initialized: self.ssm_states = self.ssm_states.to("cpu", non_blocking=True) def prefetch(self): """In case of layer offloading, this allows to move the data back to the layer's device ahead of time.""" - if self.is_initialized and self.conv_states.device != self.device: + if self.is_conv_states_initialized and self.conv_states.device != self.device: self.conv_states = self.conv_states.to(self.device, non_blocking=True) + if self.is_ssm_states_initialized and self.ssm_states.device != self.device: self.ssm_states = self.ssm_states.to(self.device, non_blocking=True) def reset(self) -> None: """Resets the cache values while preserving the objects""" - if self.is_initialized: + if self.is_conv_states_initialized: self.conv_states.zero_() + if self.is_ssm_states_initialized: self.ssm_states.zero_() self.has_previous_state = False def reorder_cache(self, beam_idx: torch.LongTensor): """Reorders the cache for beam search, given the selected beam indices.""" - if self.has_previous_state: + if self.is_conv_states_initialized: self.conv_states = self.conv_states.index_select(0, beam_idx.to(self.device)) - # ssm_states can stay empty sometimes, see e.g. lfm2 which only uses the conv_states - if self.ssm_states.numel() > 0: - self.ssm_states = self.ssm_states.index_select(0, beam_idx.to(self.device)) + # ssm_states can stay empty sometimes, see e.g. 
lfm2 which only uses the conv_states + if self.is_ssm_states_initialized: + self.ssm_states = self.ssm_states.index_select(0, beam_idx.to(self.device)) def crop(self, max_length: int): - # We don't crop the mamba cache, so simply do nothing + # We don't crop the mamba cache, so simply do nothing here pass class MambaLayer(MambaCacheLayerMixin): - def lazy_initialization(self, conv_states: torch.Tensor) -> None: - self.dtype, self.device = conv_states.dtype, conv_states.device - # Even if prefill is larfer/shorter than the conv_size, the tensor is always either padded or truncated - self.conv_kernel_size = conv_states.shape[-1] - self.conv_states = torch.tensor([], dtype=self.dtype, device=self.device) - self.ssm_states = torch.tensor([], dtype=self.dtype, device=self.device) - self.is_initialized = True + def lazy_initialization( + self, conv_states: torch.Tensor | None = None, ssm_states: torch.Tensor | None = None + ) -> None: + # Here, we will lazy init both states separately, each in their own update function + if conv_states is not None: + self.dtype, self.device = conv_states.dtype, conv_states.device + # Even if prefill is larfer/shorter than the conv_size, the tensor is always either padded or truncated + self.max_batch_size, self.conv_kernel_size = conv_states.shape[0], conv_states.shape[-1] + # The shape is always static, so we init as such + self.conv_states = torch.zeros_like(conv_states, dtype=self.dtype, device=self.device) + self.is_conv_states_initialized = True + if ssm_states is not None: + # The shape is always static, so we init as such + self.ssm_states = torch.zeros_like(ssm_states, dtype=self.dtype, device=self.device) + self.is_ssm_states_initialized = True def update_conv_state(self, conv_states: torch.Tensor, **kwargs) -> torch.Tensor: """ @@ -744,8 +759,8 @@ def update_conv_state(self, conv_states: torch.Tensor, **kwargs) -> torch.Tensor `torch.Tensor`: The updated conv states. """ # Lazy initialization - if not self.is_initialized: - self.lazy_initialization(conv_states) + if not self.is_conv_states_initialized: + self.lazy_initialization(conv_states=conv_states) # Technically, those update are not logically correct if the prefill is smaller than `conv_kernel_size`, # as it will `roll` anyway in the first decoding step even though it should `roll` ONLY if the cache is already full. @@ -774,25 +789,28 @@ def update_ssm_state(self, ssm_states: torch.Tensor, **kwargs) -> torch.Tensor: Returns: `torch.Tensor`: The updated ssm states. 
""" + if not self.is_ssm_states_initialized: + self.lazy_initialization(ssm_states=ssm_states) self.ssm_states = ssm_states return self.ssm_states class MambaAndAttentionLayer(MambaLayer, DynamicLayer): + # The dynamic Attention part makes it non-compileable + is_compileable = False + def __init__(self): DynamicLayer.__init__(self) MambaLayer.__init__(self) - def lazy_initialization(self, states1: torch.Tensor, states2: torch.Tensor | None = None) -> None: - MambaLayer.lazy_initialization(self, states1) - DynamicLayer.lazy_initialization(self, states1, states2) - - def update_conv_state(self, conv_states: torch.Tensor, **kwargs) -> torch.Tensor: - # We need this as lazy initialization may be called first from `update` from the attention part, so the inferred - # conv_kernel_size may not be correct - make sure we grab the correct one during the first call to `update_conv_state` - if not self.has_previous_state: - self.conv_kernel_size = conv_states.shape[-1] - return super().update_conv_state(conv_states, **kwargs) + def lazy_initialization(self, *args, **kwargs) -> None: + # When the Attention cache is used with `update`, `lazy_initialization` is called with 2 positional args + if len(args) == 2 and len(kwargs) == 0: + DynamicLayer.lazy_initialization(self, *args) + # Otherwise, for the Mamba cache, when it's called in `update_conv_state` or `update_ssm_state`, it's + # always called with 1 single kwarg (cause it needs to know if it's for the conv or ssm states) + if len(args) == 0 and len(kwargs) == 1: + MambaLayer.lazy_initialization(self, **kwargs) def reset(self) -> None: MambaLayer.reset(self) @@ -1312,6 +1330,9 @@ def __init__( layer = StaticSlidingWindowLayer( max_cache_len=max_cache_len, sliding_window=config.attention_chunk_size ) + # Mamba layers are static by essence - using `"moe"` as well is a trick, see the comment about it on DynamicCache + elif layer_type in ("mamba", "conv", "linear_attention", "moe"): + layers.append(MambaLayer()) else: layer = StaticLayer(max_cache_len=max_cache_len) layers.append(layer) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 8a55c184b0f0..d00aaee37db8 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1768,19 +1768,19 @@ def _prepare_static_cache( def _supports_default_dynamic_cache(cls: type["GenerativePreTrainedModel"]) -> bool: """ Return `True` if current model can use a `DynamicCache` instance when initializing the `past_key_values`. - This adds exception for some models like `Mamba` models which use their own caches. """ # NOTE: remove xlnet/reformer when the models are deprecated, non-standard model architecture/cache name - return not cls._is_stateful and all( - special_model_name not in cls.__name__.lower() - or "minimaxm2" in cls.__name__.lower() # name clash between minimax and minimax m2 - for special_model_name in [ - "reformer", - "minimax", - "xlnet", - "lfm2", - "lfm2_vl", - ] + unsupported_model_names = ( + "reformer", + "minimax", + "xlnet", + "olmo_hybrid", + "rwkv", + "xlstm", + ) + # name clash between minimax and minimax m2, so we add this "or" + return "minimaxm2" in cls.__name__.lower() or all( + unsupported_name not in cls.__name__.lower() for unsupported_name in unsupported_model_names ) def _prepare_cache_for_generation( @@ -1855,7 +1855,7 @@ def _prepare_cache_for_generation( f"and will be removed in v5.13. Please only use one of {STATIC_CACHE_IMPLEMENTATIONS}, " "and the layer structure will be inferred automatically." 
) - model_kwargs["past_key_values"] = self._prepare_static_cache( + model_kwargs[cache_name] = self._prepare_static_cache( cache_implementation=generation_config.cache_implementation, batch_size=max(generation_config.num_beams, generation_config.num_return_sequences) * batch_size, max_cache_len=max_cache_length, @@ -1871,19 +1871,19 @@ def _prepare_cache_for_generation( cache_config = generation_config.cache_config if generation_config.cache_config is not None else {} cache_config.setdefault("config", self.config.get_text_config(decoder=True)) backend = cache_config.pop("backend", "quanto") - model_kwargs["past_key_values"] = QuantizedCache(backend=backend, **cache_config) + model_kwargs[cache_name] = QuantizedCache(backend=backend, **cache_config) # i.e. `cache_implementation` in [None, "dynamic", "offloaded", "dynamic_full"] # TODO: prepare linear cache from a single API, instead of creating in modeling code else: - model_kwargs["past_key_values"] = DynamicCache(**dynamic_cache_kwargs) + model_kwargs[cache_name] = DynamicCache(**dynamic_cache_kwargs) if ( self.config.is_encoder_decoder - and "past_key_values" in model_kwargs - and not isinstance(model_kwargs["past_key_values"], EncoderDecoderCache) + and cache_name in model_kwargs + and not isinstance(model_kwargs[cache_name], EncoderDecoderCache) ): - model_kwargs["past_key_values"] = EncoderDecoderCache( - model_kwargs["past_key_values"], # self-attention cache + model_kwargs[cache_name] = EncoderDecoderCache( + model_kwargs[cache_name], # self-attention cache DynamicCache(**dynamic_cache_kwargs), # cross-attention cache ) @@ -1983,13 +1983,13 @@ def _valid_auto_compile_criteria( if generation_config.disable_compile: return False + cache = model_kwargs.get("past_key_values", model_kwargs.get("cache_params")) + # Base logic valid_hardware = self.device.type in ["cuda", "xpu"] or bool( generation_config.compile_config is not None and generation_config.compile_config._compile_all_devices ) - using_compilable_cache = ( - isinstance(model_kwargs.get("past_key_values"), Cache) and model_kwargs["past_key_values"].is_compileable - ) + using_compilable_cache = cache is not None and cache.is_compileable can_compile = valid_hardware and using_compilable_cache # Exception 1: Some quantization methods do not support compilation diff --git a/src/transformers/models/mamba2/modeling_mamba2.py b/src/transformers/models/mamba2/modeling_mamba2.py index 43b70dab9185..32bda17d976e 100644 --- a/src/transformers/models/mamba2/modeling_mamba2.py +++ b/src/transformers/models/mamba2/modeling_mamba2.py @@ -487,6 +487,7 @@ def torch_forward( # [bsz, num_heads, head_dim] # Reshape ssm_states to merge the first two dimensions + ssm_states = ssm_states.to(device=C.device, dtype=C.dtype) ssm_states_reshaped = ssm_states.view(batch_size * self.num_heads, self.head_dim, self.ssm_state_size) # Shape: [b*h, d, n] C_reshaped = C.view(batch_size * self.num_heads, self.ssm_state_size, 1) # Shape: [b*h, n, 1] y = torch.bmm(ssm_states_reshaped, C_reshaped) diff --git a/tests/models/mamba/test_modeling_mamba.py b/tests/models/mamba/test_modeling_mamba.py index 5d56d1e724b2..78ec5789085d 100644 --- a/tests/models/mamba/test_modeling_mamba.py +++ b/tests/models/mamba/test_modeling_mamba.py @@ -31,7 +31,7 @@ if is_torch_available(): import torch - from transformers import DynamicCache, MambaForCausalLM, MambaModel + from transformers import CompileConfig, DynamicCache, MambaForCausalLM, MambaModel class MambaModelTester: @@ -452,8 +452,10 @@ def test_compile_mamba_cache(self): 
output_sentence = self.tokenizer.decode(output[0].tolist()) self.assertEqual(output_sentence, expected_output) - model.forward = torch.compile(model.forward, fullgraph=True, mode="reduce-overhead") - output = model.generate(input_ids, max_new_tokens=20) + compile_config = CompileConfig(fullgraph=True, mode="reduce-overhead") + output = model.generate( + input_ids, max_new_tokens=20, cache_implementation="static", compile_config=compile_config + ) output_sentence = self.tokenizer.decode(output[0].tolist()) self.assertEqual(output_sentence, expected_output) From f4fc80170aa90e97b91b1e19440a2881d0293d0e Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 25 Mar 2026 14:53:25 +0100 Subject: [PATCH 46/56] oupsi --- src/transformers/generation/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index d00aaee37db8..49b864fc48b3 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1774,7 +1774,7 @@ def _supports_default_dynamic_cache(cls: type["GenerativePreTrainedModel"]) -> b "reformer", "minimax", "xlnet", - "olmo_hybrid", + "olmohybrid", "rwkv", "xlstm", ) From 6aca24ed518acf1d4ae7220108aa36aa0a0ad163 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 25 Mar 2026 16:59:29 +0100 Subject: [PATCH 47/56] adjustments in generate --- src/transformers/generation/utils.py | 15 +++++++++------ tests/generation/test_utils.py | 16 ++++++++++++---- 2 files changed, 21 insertions(+), 10 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 49b864fc48b3..c02448b51331 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1842,7 +1842,9 @@ def _prepare_cache_for_generation( generation_config.cache_implementation = "dynamic_full" dynamic_cache_kwargs = {} - if generation_config.cache_implementation != "dynamic_full": + # mamba models always need to pass the config, otherwise it will use an Attention cache + is_mamba = any(x in ("mamba", "conv", "linear_attention") for x in getattr(self.config, "layer_types", [])) + if generation_config.cache_implementation != "dynamic_full" or is_mamba: dynamic_cache_kwargs["config"] = self.config.get_text_config(decoder=True) if generation_config.cache_implementation == "offloaded": @@ -1989,7 +1991,9 @@ def _valid_auto_compile_criteria( valid_hardware = self.device.type in ["cuda", "xpu"] or bool( generation_config.compile_config is not None and generation_config.compile_config._compile_all_devices ) - using_compilable_cache = cache is not None and cache.is_compileable + # Note: for full mamba models, even a DynamicCache is compileable since all layers are mamba, but we don't want + # to ALWAYS compile when calling `generate`, so we check the type + using_compilable_cache = cache is not None and cache.is_compileable and type(cache) is not DynamicCache can_compile = valid_hardware and using_compilable_cache # Exception 1: Some quantization methods do not support compilation @@ -3460,10 +3464,9 @@ def _assisted_decoding( # The cache must be dynamic for assisted generation, and the check must happen AFTER preparing cache if not model_kwargs["use_cache"]: raise ValueError("assisted generate requires `use_cache=True`") - if generation_config.cache_implementation in ["static", "hybrid", "sliding_window"] or ( - "past_key_values" in model_kwargs - and hasattr(model_kwargs["past_key_values"], "layers") - and any(getattr(l, "is_compileable", 
False) for l in model_kwargs["past_key_values"].layers) + if ( + generation_config.cache_implementation in ["static", "hybrid", "sliding_window"] + or type(model_kwargs.get("past_key_values")) is StaticCache ): raise ValueError("assisted generate is not supported with Static cache classes`") # Get the candidate generator, given the parameterization diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index b23b004c75ea..e91b5dc3aeb3 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -2598,11 +2598,15 @@ def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_l self.assertEqual(keys.shape, attention_shape) self.assertEqual(values.shape, attention_shape) self.assertEqual(layer.conv_states.shape, conv_shape) - self.assertEqual(layer.ssm_states.shape, ssm_shape) + # May not be used (e.g. lfm2) + if layer.is_ssm_states_initialized: + self.assertEqual(layer.ssm_states.shape, ssm_shape) # Mamba only layer cache elif type(layer) is MambaLayer: self.assertEqual(layer.conv_states.shape, conv_shape) - self.assertEqual(layer.ssm_states.shape, ssm_shape) + # May not be used (e.g. lfm2) + if layer.is_ssm_states_initialized: + self.assertEqual(layer.ssm_states.shape, ssm_shape) # Attention only layer type else: # Remove the seq_length dim for cross-attention cache (it changes based on the model) @@ -2656,11 +2660,15 @@ def _check_caches_are_equal(self, cache1: Cache, cache2: Cache): torch.testing.assert_close(cache1.layers[idx].keys, cache2.layers[idx].keys) torch.testing.assert_close(cache1.layers[idx].values, cache2.layers[idx].values) torch.testing.assert_close(cache1.layers[idx].conv_states, cache2.layers[idx].conv_states) - torch.testing.assert_close(cache1.layers[idx].ssm_states, cache2.layers[idx].ssm_states) + # May not be used (e.g. lfm2) + if cache1.layers[idx].is_ssm_states_initialized: + torch.testing.assert_close(cache1.layers[idx].ssm_states, cache2.layers[idx].ssm_states) # Mamba layer elif type(cache1.layers[idx]) is MambaLayer: torch.testing.assert_close(cache1.layers[idx].conv_states, cache2.layers[idx].conv_states) - torch.testing.assert_close(cache1.layers[idx].ssm_states, cache2.layers[idx].ssm_states) + # May not be used (e.g. 
lfm2) + if cache1.layers[idx].is_ssm_states_initialized: + torch.testing.assert_close(cache1.layers[idx].ssm_states, cache2.layers[idx].ssm_states) # Attention layer else: torch.testing.assert_close(cache1.layers[idx].keys, cache2.layers[idx].keys) From 13781f17a18ef2bd3d0c245d126e267d2ce48658 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 25 Mar 2026 17:04:03 +0100 Subject: [PATCH 48/56] allow cudagraphs --- src/transformers/cache_utils.py | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index df9057d57723..21fbaefcba29 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -742,10 +742,16 @@ def lazy_initialization( self.max_batch_size, self.conv_kernel_size = conv_states.shape[0], conv_states.shape[-1] # The shape is always static, so we init as such self.conv_states = torch.zeros_like(conv_states, dtype=self.dtype, device=self.device) + # Mark as static address to be able to use cudagraphs + if not is_torchdynamo_compiling(): + torch._dynamo.mark_static_address(self.conv_states) self.is_conv_states_initialized = True if ssm_states is not None: # The shape is always static, so we init as such self.ssm_states = torch.zeros_like(ssm_states, dtype=self.dtype, device=self.device) + # Mark as static address to be able to use cudagraphs + if not is_torchdynamo_compiling(): + torch._dynamo.mark_static_address(self.ssm_states) self.is_ssm_states_initialized = True def update_conv_state(self, conv_states: torch.Tensor, **kwargs) -> torch.Tensor: @@ -762,20 +768,22 @@ def update_conv_state(self, conv_states: torch.Tensor, **kwargs) -> torch.Tensor if not self.is_conv_states_initialized: self.lazy_initialization(conv_states=conv_states) - # Technically, those update are not logically correct if the prefill is smaller than `conv_kernel_size`, - # as it will `roll` anyway in the first decoding step even though it should `roll` ONLY if the cache is already full. - # But since `conv_kernel_size=4` in practice, it's almost impossible to have a smaller prefill so it's mostly fine for now if not self.has_previous_state: - self.conv_states = conv_states + # Note that we copy instead of assigning, to preserve the static address for cudagraphs + self.conv_states.copy_(conv_states) self.has_previous_state = True + # Technically, this update is not logically correct if the prefill is smaller than `conv_kernel_size`, + # as it will `roll` anyway in the first decoding step, even though it should `roll` ONLY if the cache is already full. 
+ # But since `conv_kernel_size=4` in practice, it's almost impossible to have a smaller prefill so it's mostly fine for now else: + # Note that we copy instead of assigning, to preserve the static address for cudagraphs num_new_tokens = conv_states.shape[-1] if num_new_tokens >= self.conv_kernel_size: - self.conv_states = conv_states[..., -self.conv_kernel_size :] + self.conv_states.copy_(conv_states[..., -self.conv_kernel_size :]) else: new_conv_states = self.conv_states.roll(shifts=-num_new_tokens, dims=-1) new_conv_states[:, :, -num_new_tokens:] = conv_states - self.conv_states = new_conv_states + self.conv_states.copy_(new_conv_states) return self.conv_states @@ -791,7 +799,8 @@ def update_ssm_state(self, ssm_states: torch.Tensor, **kwargs) -> torch.Tensor: """ if not self.is_ssm_states_initialized: self.lazy_initialization(ssm_states=ssm_states) - self.ssm_states = ssm_states + # Note that we copy instead of assigning, to preserve the static address for cudagraphs + self.ssm_states.copy_(ssm_states) return self.ssm_states @@ -1332,7 +1341,7 @@ def __init__( ) # Mamba layers are static by essence - using `"moe"` as well is a trick, see the comment about it on DynamicCache elif layer_type in ("mamba", "conv", "linear_attention", "moe"): - layers.append(MambaLayer()) + layer = MambaLayer() else: layer = StaticLayer(max_cache_len=max_cache_len) layers.append(layer) From 3df0d85ae30889ae50f484d284bc0c2b73d35846 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 25 Mar 2026 17:12:47 +0100 Subject: [PATCH 49/56] small oupsi --- src/transformers/generation/utils.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index c02448b51331..1f6125d52957 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1842,8 +1842,11 @@ def _prepare_cache_for_generation( generation_config.cache_implementation = "dynamic_full" dynamic_cache_kwargs = {} - # mamba models always need to pass the config, otherwise it will use an Attention cache - is_mamba = any(x in ("mamba", "conv", "linear_attention") for x in getattr(self.config, "layer_types", [])) + # mamba models always need to pass the config, otherwise it will use an Attention cache for the Mamba layers + is_mamba = any( + x in ("mamba", "conv", "linear_attention") + for x in getattr(self.config.get_text_config(decoder=True), "layer_types", []) + ) if generation_config.cache_implementation != "dynamic_full" or is_mamba: dynamic_cache_kwargs["config"] = self.config.get_text_config(decoder=True) From eadcfa47d04ebdaf0b4aadec46ebc9062ed8dc39 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 31 Mar 2026 10:58:58 +0200 Subject: [PATCH 50/56] start renaming --- src/transformers/cache_utils.py | 96 +- .../models/align/configuration_align.py | 1 + .../models/bamba/modeling_bamba.py | 12 +- .../models/bamba/modular_bamba.py | 12 +- .../models/falcon_h1/modeling_falcon_h1.py | 12 +- .../models/falcon_h1/modular_falcon_h1.py | 12 +- .../falcon_mamba/modeling_falcon_mamba.py | 8 +- .../falcon_mamba/modular_falcon_mamba.py | 8 +- .../modeling_granitemoehybrid.py | 12 +- .../models/jamba/modeling_jamba.py | 8 +- .../models/jamba/modular_jamba.py | 8 +- .../models/mamba/modeling_mamba.py | 8 +- .../models/mamba2/modeling_mamba2.py | 10 +- .../models/nemotron_h/modeling_nemotron_h.py | 14 +- .../models/qwen3_5/modeling_qwen3_5.py | 4 +- .../models/qwen3_5/modular_qwen3_5.py | 4 +- .../qwen3_5_moe/modeling_qwen3_5_moe.py | 4 
+- .../models/qwen3_next/modeling_qwen3_next.py | 4 +- .../models/qwen3_next/modular_qwen3_next.py | 4 +- .../models/zamba/modeling_zamba.py | 8 +- .../models/zamba2/modeling_zamba2.py | 14 +- .../models/zamba2/modular_zamba2.py | 14 +- tests/generation/test_utils.py | 32 +- .../test_modeling_falcon_mamba.py | 4 +- tests/models/mamba/test_modeling_mamba.py | 4 +- tests/models/mamba2/test_modeling_mamba2.py | 4 +- .../nemotron_h/test_modeling_nemotron_h.py | 4 +- utils/mlinter/.mlinter_cache.json | 1073 +++++++++++++++++ 28 files changed, 1242 insertions(+), 156 deletions(-) create mode 100644 utils/mlinter/.mlinter_cache.json diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 21fbaefcba29..297059a5b1e2 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -669,7 +669,7 @@ def _dequantize(self, qtensor): return tensor -class MambaCacheLayerMixin(ABC): +class LinearAttentionCacheLayerMixin(ABC): """Base, abstract class for a mamba single layer's cache.""" # All shapes are static by essence in a Mamba layer, so it is compileable @@ -677,9 +677,9 @@ class MambaCacheLayerMixin(ABC): def __init__(self): self.conv_states: torch.Tensor | None = None - self.ssm_states: torch.Tensor | None = None + self.recurrent_states: torch.Tensor | None = None self.is_conv_states_initialized = False - self.is_ssm_states_initialized = False + self.is_recurrent_states_initialized = False self.has_previous_state = False def __repr__(self): @@ -687,53 +687,53 @@ def __repr__(self): @abstractmethod def lazy_initialization( - self, conv_states: torch.Tensor | None = None, ssm_states: torch.Tensor | None = None + self, conv_states: torch.Tensor | None = None, recurrent_states: torch.Tensor | None = None ) -> None: ... @abstractmethod def update_conv_state(self, conv_states: torch.Tensor) -> torch.Tensor: ... @abstractmethod - def update_ssm_state(self, ssm_states: torch.Tensor) -> torch.Tensor: ... + def update_recurrent_state(self, recurrent_states: torch.Tensor) -> torch.Tensor: ... 
def offload(self): """Offload this layer's data to CPU device.""" if self.is_conv_states_initialized: self.conv_states = self.conv_states.to("cpu", non_blocking=True) - if self.is_ssm_states_initialized: - self.ssm_states = self.ssm_states.to("cpu", non_blocking=True) + if self.is_recurrent_states_initialized: + self.recurrent_states = self.recurrent_states.to("cpu", non_blocking=True) def prefetch(self): """In case of layer offloading, this allows to move the data back to the layer's device ahead of time.""" if self.is_conv_states_initialized and self.conv_states.device != self.device: self.conv_states = self.conv_states.to(self.device, non_blocking=True) - if self.is_ssm_states_initialized and self.ssm_states.device != self.device: - self.ssm_states = self.ssm_states.to(self.device, non_blocking=True) + if self.is_recurrent_states_initialized and self.recurrent_states.device != self.device: + self.recurrent_states = self.recurrent_states.to(self.device, non_blocking=True) def reset(self) -> None: """Resets the cache values while preserving the objects""" if self.is_conv_states_initialized: self.conv_states.zero_() - if self.is_ssm_states_initialized: - self.ssm_states.zero_() + if self.is_recurrent_states_initialized: + self.recurrent_states.zero_() self.has_previous_state = False def reorder_cache(self, beam_idx: torch.LongTensor): """Reorders the cache for beam search, given the selected beam indices.""" if self.is_conv_states_initialized: self.conv_states = self.conv_states.index_select(0, beam_idx.to(self.device)) - # ssm_states can stay empty sometimes, see e.g. lfm2 which only uses the conv_states - if self.is_ssm_states_initialized: - self.ssm_states = self.ssm_states.index_select(0, beam_idx.to(self.device)) + # recurrent_states can stay empty sometimes, see e.g. 
lfm2 which only uses the conv_states + if self.is_recurrent_states_initialized: + self.recurrent_states = self.recurrent_states.index_select(0, beam_idx.to(self.device)) def crop(self, max_length: int): # We don't crop the mamba cache, so simply do nothing here pass -class MambaLayer(MambaCacheLayerMixin): +class LinearAttentionLayer(LinearAttentionCacheLayerMixin): def lazy_initialization( - self, conv_states: torch.Tensor | None = None, ssm_states: torch.Tensor | None = None + self, conv_states: torch.Tensor | None = None, recurrent_states: torch.Tensor | None = None ) -> None: # Here, we will lazy init both states separately, each in their own update function if conv_states is not None: @@ -746,13 +746,13 @@ def lazy_initialization( if not is_torchdynamo_compiling(): torch._dynamo.mark_static_address(self.conv_states) self.is_conv_states_initialized = True - if ssm_states is not None: + if recurrent_states is not None: # The shape is always static, so we init as such - self.ssm_states = torch.zeros_like(ssm_states, dtype=self.dtype, device=self.device) + self.recurrent_states = torch.zeros_like(recurrent_states, dtype=self.dtype, device=self.device) # Mark as static address to be able to use cudagraphs if not is_torchdynamo_compiling(): - torch._dynamo.mark_static_address(self.ssm_states) - self.is_ssm_states_initialized = True + torch._dynamo.mark_static_address(self.recurrent_states) + self.is_recurrent_states_initialized = True def update_conv_state(self, conv_states: torch.Tensor, **kwargs) -> torch.Tensor: """ @@ -787,7 +787,7 @@ def update_conv_state(self, conv_states: torch.Tensor, **kwargs) -> torch.Tensor return self.conv_states - def update_ssm_state(self, ssm_states: torch.Tensor, **kwargs) -> torch.Tensor: + def update_recurrent_state(self, recurrent_states: torch.Tensor, **kwargs) -> torch.Tensor: """ Update the mamba cache in-place, and return the necessary ssm states. @@ -797,37 +797,37 @@ def update_ssm_state(self, ssm_states: torch.Tensor, **kwargs) -> torch.Tensor: Returns: `torch.Tensor`: The updated ssm states. 
""" - if not self.is_ssm_states_initialized: - self.lazy_initialization(ssm_states=ssm_states) + if not self.is_recurrent_states_initialized: + self.lazy_initialization(recurrent_states=recurrent_states) # Note that we copy instead of assigning, to preserve the static address for cudagraphs - self.ssm_states.copy_(ssm_states) - return self.ssm_states + self.recurrent_states.copy_(recurrent_states) + return self.recurrent_states -class MambaAndAttentionLayer(MambaLayer, DynamicLayer): +class LinearAttentionAndAttentionLayer(LinearAttentionLayer, DynamicLayer): # The dynamic Attention part makes it non-compileable is_compileable = False def __init__(self): DynamicLayer.__init__(self) - MambaLayer.__init__(self) + LinearAttentionLayer.__init__(self) def lazy_initialization(self, *args, **kwargs) -> None: # When the Attention cache is used with `update`, `lazy_initialization` is called with 2 positional args if len(args) == 2 and len(kwargs) == 0: DynamicLayer.lazy_initialization(self, *args) - # Otherwise, for the Mamba cache, when it's called in `update_conv_state` or `update_ssm_state`, it's + # Otherwise, for the Mamba cache, when it's called in `update_conv_state` or `update_recurrent_state`, it's # always called with 1 single kwarg (cause it needs to know if it's for the conv or ssm states) if len(args) == 0 and len(kwargs) == 1: - MambaLayer.lazy_initialization(self, **kwargs) + LinearAttentionLayer.lazy_initialization(self, **kwargs) def reset(self) -> None: - MambaLayer.reset(self) + LinearAttentionLayer.reset(self) DynamicLayer.reset(self) def reorder_cache(self, beam_idx: torch.LongTensor): """Reorders the cache for beam search, given the selected beam indices.""" - MambaLayer.reorder_cache(self, beam_idx) + LinearAttentionLayer.reorder_cache(self, beam_idx) DynamicLayer.reorder_cache(self, beam_idx) @@ -838,9 +838,9 @@ class Cache: Args: layers (`Optional`, *optional*): - A list of pre-created `CacheLayerMixin` or `MambaCacheLayerMixin`. If omitted (`None`), then `layer_class_to_replicate` + A list of pre-created `CacheLayerMixin` or `LinearAttentionCacheLayerMixin`. If omitted (`None`), then `layer_class_to_replicate` will be used. - layer_class_to_replicate (`type[CacheLayerMixin | MambaCacheLayerMixin]`, *optional*): + layer_class_to_replicate (`type[CacheLayerMixin | LinearAttentionCacheLayerMixin]`, *optional*): Only used if `layers` is omitted (`None`), in which case it will be used as the base class for each layer, and the layers will be added lazily as soon as `update` is called with a `layer_idx` greater than the current list of layers. 
@@ -853,8 +853,8 @@ class Cache: def __init__( self, - layers: list[CacheLayerMixin | MambaCacheLayerMixin] | None = None, - layer_class_to_replicate: type[CacheLayerMixin | MambaCacheLayerMixin] | None = None, + layers: list[CacheLayerMixin | LinearAttentionCacheLayerMixin] | None = None, + layer_class_to_replicate: type[CacheLayerMixin | LinearAttentionCacheLayerMixin] | None = None, offloading: bool = False, offload_only_non_sliding: bool = True, ): @@ -956,14 +956,14 @@ def update_conv_state(self, conv_states: torch.Tensor, layer_idx: int, **kwargs) """ # NOTE: if we slightly break `update` arg order, we could combine this with it, and allow offloading support # out of the box - if not isinstance(self.layers[layer_idx], MambaCacheLayerMixin): + if not isinstance(self.layers[layer_idx], LinearAttentionCacheLayerMixin): raise ValueError("Cannot call `update_conv_state` on a non-Mamba layer!") conv_states = self.layers[layer_idx].update_conv_state(conv_states, **kwargs) return conv_states - def update_ssm_state(self, ssm_states: torch.Tensor, layer_idx: int, **kwargs) -> torch.Tensor: + def update_recurrent_state(self, recurrent_states: torch.Tensor, layer_idx: int, **kwargs) -> torch.Tensor: """ - Updates the cache with the new `ssm_states` for the layer `layer_idx`. + Updates the cache with the new `recurrent_states` for the layer `layer_idx`. Parameters: smm_states (`torch.Tensor`): @@ -976,10 +976,10 @@ def update_ssm_state(self, ssm_states: torch.Tensor, layer_idx: int, **kwargs) - """ # NOTE: if we slightly break `update` arg order, we could combine this with it, and allow offloading support # out of the box - if not isinstance(self.layers[layer_idx], MambaCacheLayerMixin): + if not isinstance(self.layers[layer_idx], LinearAttentionCacheLayerMixin): raise ValueError("Cannot call `update_conv_state` on a non-Mamba layer!") - ssm_states = self.layers[layer_idx].update_ssm_state(ssm_states, **kwargs) - return ssm_states + recurrent_states = self.layers[layer_idx].update_recurrent_state(recurrent_states, **kwargs) + return recurrent_states def early_initialization( self, batch_size: int, num_heads: int, head_dim: int, dtype: torch.dtype, device: torch.device @@ -1029,14 +1029,16 @@ def has_previous_state(self, layer_idx: int | None = None) -> bool: if layer_idx is None: try: layer_idx = next( - idx for idx in range(len(self) - 1, -1, -1) if isinstance(self.layers[idx], MambaCacheLayerMixin) + idx + for idx in range(len(self) - 1, -1, -1) + if isinstance(self.layers[idx], LinearAttentionCacheLayerMixin) ) except StopIteration: raise ValueError( "`has_previous_state` can only be called on Mamba layers, and the current Cache seem to only contain " "Attention layers." ) - elif not isinstance(self.layers[layer_idx], MambaCacheLayerMixin): + elif not isinstance(self.layers[layer_idx], LinearAttentionCacheLayerMixin): raise ValueError( f"You called `has_previous_state` on layer index {layer_idx}, but this layer is an Attention layer, which " "does not support calling it." @@ -1222,14 +1224,14 @@ def __init__( # states they should return - only the mask changes to make them different at the end! if layer_type in ("sliding_attention", "chunked_attention"): layers.append(DynamicSlidingWindowLayer(sliding_window=sliding_window)) - # Note: we want moe layers to be MambaLayer, so that we can correctly grab sequence length etc from attention layers. + # Note: we want moe layers to be LinearAttentionLayer, so that we can correctly grab sequence length etc from attention layers. 
# Since moe layers will stay empty (they don't need any cache), we don't want them to collide for mask creation etc # TODO: maybe use a dummy layer in those cases, or a dictionary {idx: Layer} for self.layers, so that we can skip # the indices we don't need elif layer_type in ("mamba", "conv", "linear_attention", "moe"): - layers.append(MambaLayer()) + layers.append(LinearAttentionLayer()) elif layer_type == "hybrid": - layers.append(MambaAndAttentionLayer()) + layers.append(LinearAttentionAndAttentionLayer()) else: layers.append(DynamicLayer()) @@ -1341,7 +1343,7 @@ def __init__( ) # Mamba layers are static by essence - using `"moe"` as well is a trick, see the comment about it on DynamicCache elif layer_type in ("mamba", "conv", "linear_attention", "moe"): - layer = MambaLayer() + layer = LinearAttentionLayer() else: layer = StaticLayer(max_cache_len=max_cache_len) layers.append(layer) diff --git a/src/transformers/models/align/configuration_align.py b/src/transformers/models/align/configuration_align.py index cde6445cf62f..babf97d4572a 100644 --- a/src/transformers/models/align/configuration_align.py +++ b/src/transformers/models/align/configuration_align.py @@ -59,6 +59,7 @@ class AlignTextConfig(PreTrainedConfig): pad_token_id: int | None = 0 bos_token_id: int | None = None eos_token_id: int | list[int] | None = None + tie_word_embeddings: True @auto_docstring(checkpoint="kakaobrain/align-base") diff --git a/src/transformers/models/bamba/modeling_bamba.py b/src/transformers/models/bamba/modeling_bamba.py index 0b599490c5ff..90129fc998b1 100644 --- a/src/transformers/models/bamba/modeling_bamba.py +++ b/src/transformers/models/bamba/modeling_bamba.py @@ -532,7 +532,7 @@ def cuda_kernels_forward( C = C.view(batch_size, self.n_groups, C.shape[1] // self.n_groups) hidden_states_reshaped = hidden_states.view(batch_size, self.num_heads, self.head_dim) hidden_states = selective_state_update( - cache_params.layers[self.layer_idx].ssm_states, + cache_params.layers[self.layer_idx].recurrent_states, hidden_states_reshaped, dt, A, @@ -633,7 +633,7 @@ def cuda_kernels_forward( # Init cache if ssm_state is not None and cache_params is not None: - ssm_state = cache_params.update_ssm_state(ssm_state, self.layer_idx) + ssm_state = cache_params.update_recurrent_state(ssm_state, self.layer_idx) scan_output = scan_output.view(batch_size, seq_len, -1) # Multiply "gate" branch and apply extra normalization layer @@ -694,7 +694,7 @@ def torch_forward( A = -torch.exp(self.A_log.float()) # [num_heads] if use_precomputed_states: # We need to guarantee that anything regarding the cache is on the same device - cache_device = cache_params.layers[self.layer_idx].ssm_states.device + cache_device = cache_params.layers[self.layer_idx].recurrent_states.device # Note: there is no need to pad parameter matrices here, as there is just one new token # for batched generation @@ -724,8 +724,8 @@ def torch_forward( dBx = (dB * hidden_states[..., None]).to(device=cache_device) # State calculation - ssm_states = cache_params.layers[self.layer_idx].ssm_states * dA + dBx - ssm_states = cache_params.update_ssm_state(ssm_states, self.layer_idx) + ssm_states = cache_params.layers[self.layer_idx].recurrent_states * dA + dBx + ssm_states = cache_params.update_recurrent_state(ssm_states, self.layer_idx) # Subsequent output # [bsz, n_groups * state_size] -> [bsz, num_heads, state_size] @@ -822,7 +822,7 @@ def torch_forward( # Init cache if ssm_state is not None and cache_params is not None: - ssm_state = 
cache_params.update_ssm_state(ssm_state, self.layer_idx) + ssm_state = cache_params.update_recurrent_state(ssm_state, self.layer_idx) scan_output = self.norm(y, gate) diff --git a/src/transformers/models/bamba/modular_bamba.py b/src/transformers/models/bamba/modular_bamba.py index f60d8c405b4d..a79b26ff8fe9 100644 --- a/src/transformers/models/bamba/modular_bamba.py +++ b/src/transformers/models/bamba/modular_bamba.py @@ -289,7 +289,7 @@ def cuda_kernels_forward( C = C.view(batch_size, self.n_groups, C.shape[1] // self.n_groups) hidden_states_reshaped = hidden_states.view(batch_size, self.num_heads, self.head_dim) hidden_states = selective_state_update( - cache_params.layers[self.layer_idx].ssm_states, + cache_params.layers[self.layer_idx].recurrent_states, hidden_states_reshaped, dt, A, @@ -390,7 +390,7 @@ def cuda_kernels_forward( # Init cache if ssm_state is not None and cache_params is not None: - ssm_state = cache_params.update_ssm_state(ssm_state, self.layer_idx) + ssm_state = cache_params.update_recurrent_state(ssm_state, self.layer_idx) scan_output = scan_output.view(batch_size, seq_len, -1) # Multiply "gate" branch and apply extra normalization layer @@ -451,7 +451,7 @@ def torch_forward( A = -torch.exp(self.A_log.float()) # [num_heads] if use_precomputed_states: # We need to guarantee that anything regarding the cache is on the same device - cache_device = cache_params.layers[self.layer_idx].ssm_states.device + cache_device = cache_params.layers[self.layer_idx].recurrent_states.device # Note: there is no need to pad parameter matrices here, as there is just one new token # for batched generation @@ -481,8 +481,8 @@ def torch_forward( dBx = (dB * hidden_states[..., None]).to(device=cache_device) # State calculation - ssm_states = cache_params.layers[self.layer_idx].ssm_states * dA + dBx - ssm_states = cache_params.update_ssm_state(ssm_states, self.layer_idx) + ssm_states = cache_params.layers[self.layer_idx].recurrent_states * dA + dBx + ssm_states = cache_params.update_recurrent_state(ssm_states, self.layer_idx) # Subsequent output # [bsz, n_groups * state_size] -> [bsz, num_heads, state_size] @@ -579,7 +579,7 @@ def torch_forward( # Init cache if ssm_state is not None and cache_params is not None: - ssm_state = cache_params.update_ssm_state(ssm_state, self.layer_idx) + ssm_state = cache_params.update_recurrent_state(ssm_state, self.layer_idx) scan_output = self.norm(y, gate) diff --git a/src/transformers/models/falcon_h1/modeling_falcon_h1.py b/src/transformers/models/falcon_h1/modeling_falcon_h1.py index 7aaba6aacb4d..37b5da9df4b3 100644 --- a/src/transformers/models/falcon_h1/modeling_falcon_h1.py +++ b/src/transformers/models/falcon_h1/modeling_falcon_h1.py @@ -529,7 +529,7 @@ def cuda_kernels_forward( C = C.view(batch_size, self.n_groups, C.shape[1] // self.n_groups) hidden_states_reshaped = hidden_states.view(batch_size, self.num_heads, self.head_dim) hidden_states = selective_state_update( - cache_params.layers[self.layer_idx].ssm_states, + cache_params.layers[self.layer_idx].recurrent_states, hidden_states_reshaped, dt, A, @@ -649,7 +649,7 @@ def cuda_kernels_forward( **dt_limit_kwargs, ) if ssm_state is not None and cache_params is not None: - ssm_state = cache_params.update_ssm_state(ssm_state, self.layer_idx) + ssm_state = cache_params.update_recurrent_state(ssm_state, self.layer_idx) scan_output = scan_output.view(batch_size, seq_len, -1) # Multiply "gate" branch and apply extra normalization layer if self.mamba_rms_norm: @@ -715,7 +715,7 @@ def torch_forward( A = 
-torch.exp(self.A_log.float()) # [num_heads] if use_precomputed_states: # We need to guarantee that anything regarding the cache is on the same device - cache_device = cache_params.layers[self.layer_idx].ssm_states.device + cache_device = cache_params.layers[self.layer_idx].recurrent_states.device # Note: there is no need to pad parameter matrices here, as there is just one new token # for batched generation @@ -745,8 +745,8 @@ def torch_forward( dBx = (dB * hidden_states[..., None]).to(device=cache_device) # State calculation - ssm_states = cache_params.layers[self.layer_idx].ssm_states * dA + dBx - ssm_states = cache_params.update_ssm_state(ssm_states, self.layer_idx) + ssm_states = cache_params.layers[self.layer_idx].recurrent_states * dA + dBx + ssm_states = cache_params.update_recurrent_state(ssm_states, self.layer_idx) # Subsequent output # [bsz, n_groups * state_size] -> [bsz, num_heads, state_size] @@ -843,7 +843,7 @@ def torch_forward( # Init cache if ssm_state is not None and cache_params is not None: - ssm_state = cache_params.update_ssm_state(ssm_state, self.layer_idx) + ssm_state = cache_params.update_recurrent_state(ssm_state, self.layer_idx) if self.mamba_rms_norm: scan_output = self.norm(y, gate) diff --git a/src/transformers/models/falcon_h1/modular_falcon_h1.py b/src/transformers/models/falcon_h1/modular_falcon_h1.py index 61fab3117d96..75cbed28a646 100644 --- a/src/transformers/models/falcon_h1/modular_falcon_h1.py +++ b/src/transformers/models/falcon_h1/modular_falcon_h1.py @@ -316,7 +316,7 @@ def cuda_kernels_forward( C = C.view(batch_size, self.n_groups, C.shape[1] // self.n_groups) hidden_states_reshaped = hidden_states.view(batch_size, self.num_heads, self.head_dim) hidden_states = selective_state_update( - cache_params.layers[self.layer_idx].ssm_states, + cache_params.layers[self.layer_idx].recurrent_states, hidden_states_reshaped, dt, A, @@ -436,7 +436,7 @@ def cuda_kernels_forward( **dt_limit_kwargs, ) if ssm_state is not None and cache_params is not None: - ssm_state = cache_params.update_ssm_state(ssm_state, self.layer_idx) + ssm_state = cache_params.update_recurrent_state(ssm_state, self.layer_idx) scan_output = scan_output.view(batch_size, seq_len, -1) # Multiply "gate" branch and apply extra normalization layer if self.mamba_rms_norm: @@ -502,7 +502,7 @@ def torch_forward( A = -torch.exp(self.A_log.float()) # [num_heads] if use_precomputed_states: # We need to guarantee that anything regarding the cache is on the same device - cache_device = cache_params.layers[self.layer_idx].ssm_states.device + cache_device = cache_params.layers[self.layer_idx].recurrent_states.device # Note: there is no need to pad parameter matrices here, as there is just one new token # for batched generation @@ -532,8 +532,8 @@ def torch_forward( dBx = (dB * hidden_states[..., None]).to(device=cache_device) # State calculation - ssm_states = cache_params.layers[self.layer_idx].ssm_states * dA + dBx - ssm_states = cache_params.update_ssm_state(ssm_states, self.layer_idx) + ssm_states = cache_params.layers[self.layer_idx].recurrent_states * dA + dBx + ssm_states = cache_params.update_recurrent_state(ssm_states, self.layer_idx) # Subsequent output # [bsz, n_groups * state_size] -> [bsz, num_heads, state_size] @@ -630,7 +630,7 @@ def torch_forward( # Init cache if ssm_state is not None and cache_params is not None: - ssm_state = cache_params.update_ssm_state(ssm_state, self.layer_idx) + ssm_state = cache_params.update_recurrent_state(ssm_state, self.layer_idx) if self.mamba_rms_norm: 
scan_output = self.norm(y, gate) diff --git a/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py b/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py index b2fc8afdae75..32fd1bf4a358 100644 --- a/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py +++ b/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py @@ -280,7 +280,7 @@ def cuda_kernels_forward( time_proj_bias = self.dt_proj.bias.float() if hasattr(self.dt_proj, "bias") else None if is_decoding: scan_outputs = selective_state_update( - cache_params.layers[self.layer_idx].ssm_states, + cache_params.layers[self.layer_idx].recurrent_states, hidden_states[..., 0], discrete_time_step[..., 0], A, @@ -305,7 +305,7 @@ def cuda_kernels_forward( return_last_state=True, ) if ssm_state is not None and cache_params is not None: - cache_params.update_ssm_state(ssm_state, self.layer_idx) + cache_params.update_recurrent_state(ssm_state, self.layer_idx) # 4. Final linear projection contextualized_states = self.out_proj(scan_outputs.transpose(1, 2)) @@ -327,7 +327,7 @@ def slow_forward(self, hidden_states = hidden_states * attention_mask.unsqueeze(1) if cache_params is not None and cache_params.has_previous_state(self.layer_idx): - ssm_state = cache_params.layers[self.layer_idx].ssm_states.clone() + ssm_state = cache_params.layers[self.layer_idx].recurrent_states.clone() else: ssm_state = torch.zeros( (batch_size, self.intermediate_size, self.ssm_state_size), device=hidden_states.device, dtype=dtype @@ -429,7 +429,7 @@ def combine_fn(left, right): scan_output = scan_output * self.act(gate) if cache_params is not None: - cache_params.update_ssm_state(ssm_state, self.layer_idx) + cache_params.update_recurrent_state(ssm_state, self.layer_idx) # 4. Final linear projection contextualized_states = self.out_proj(scan_output.transpose(1, 2)) # [batch, seq_len, hidden_size] diff --git a/src/transformers/models/falcon_mamba/modular_falcon_mamba.py b/src/transformers/models/falcon_mamba/modular_falcon_mamba.py index 2e1b940cf48d..6962c66c2973 100644 --- a/src/transformers/models/falcon_mamba/modular_falcon_mamba.py +++ b/src/transformers/models/falcon_mamba/modular_falcon_mamba.py @@ -245,7 +245,7 @@ def cuda_kernels_forward( time_proj_bias = self.dt_proj.bias.float() if hasattr(self.dt_proj, "bias") else None if is_decoding: scan_outputs = selective_state_update( - cache_params.layers[self.layer_idx].ssm_states, + cache_params.layers[self.layer_idx].recurrent_states, hidden_states[..., 0], discrete_time_step[..., 0], A, @@ -270,7 +270,7 @@ def cuda_kernels_forward( return_last_state=True, ) if ssm_state is not None and cache_params is not None: - cache_params.update_ssm_state(ssm_state, self.layer_idx) + cache_params.update_recurrent_state(ssm_state, self.layer_idx) # 4. 
Final linear projection contextualized_states = self.out_proj(scan_outputs.transpose(1, 2)) @@ -292,7 +292,7 @@ def slow_forward( hidden_states = hidden_states * attention_mask.unsqueeze(1) if cache_params is not None and cache_params.has_previous_state(self.layer_idx): - ssm_state = cache_params.layers[self.layer_idx].ssm_states.clone() + ssm_state = cache_params.layers[self.layer_idx].recurrent_states.clone() else: ssm_state = torch.zeros( (batch_size, self.intermediate_size, self.ssm_state_size), device=hidden_states.device, dtype=dtype @@ -394,7 +394,7 @@ def combine_fn(left, right): scan_output = scan_output * self.act(gate) if cache_params is not None: - cache_params.update_ssm_state(ssm_state, self.layer_idx) + cache_params.update_recurrent_state(ssm_state, self.layer_idx) # 4. Final linear projection contextualized_states = self.out_proj(scan_output.transpose(1, 2)) # [batch, seq_len, hidden_size] diff --git a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py index baeaf65bbd45..dadffaea0072 100644 --- a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +++ b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py @@ -409,7 +409,7 @@ def cuda_kernels_forward( C = C.view(batch_size, self.n_groups, C.shape[1] // self.n_groups) hidden_states_reshaped = hidden_states.view(batch_size, self.num_heads, self.head_dim) hidden_states = selective_state_update( - cache_params.layers[self.layer_idx].ssm_states, + cache_params.layers[self.layer_idx].recurrent_states, hidden_states_reshaped, dt, A, @@ -510,7 +510,7 @@ def cuda_kernels_forward( # Init cache if ssm_state is not None and cache_params is not None: - ssm_state = cache_params.update_ssm_state(ssm_state, self.layer_idx) + ssm_state = cache_params.update_recurrent_state(ssm_state, self.layer_idx) scan_output = scan_output.view(batch_size, seq_len, -1) # Multiply "gate" branch and apply extra normalization layer @@ -571,7 +571,7 @@ def torch_forward( A = -torch.exp(self.A_log.float()) # [num_heads] if use_precomputed_states: # We need to guarantee that anything regarding the cache is on the same device - cache_device = cache_params.layers[self.layer_idx].ssm_states.device + cache_device = cache_params.layers[self.layer_idx].recurrent_states.device # Note: there is no need to pad parameter matrices here, as there is just one new token # for batched generation @@ -601,8 +601,8 @@ def torch_forward( dBx = (dB * hidden_states[..., None]).to(device=cache_device) # State calculation - ssm_states = cache_params.layers[self.layer_idx].ssm_states * dA + dBx - ssm_states = cache_params.update_ssm_state(ssm_states, self.layer_idx) + ssm_states = cache_params.layers[self.layer_idx].recurrent_states * dA + dBx + ssm_states = cache_params.update_recurrent_state(ssm_states, self.layer_idx) # Subsequent output # [bsz, n_groups * state_size] -> [bsz, num_heads, state_size] @@ -699,7 +699,7 @@ def torch_forward( # Init cache if ssm_state is not None and cache_params is not None: - ssm_state = cache_params.update_ssm_state(ssm_state, self.layer_idx) + ssm_state = cache_params.update_recurrent_state(ssm_state, self.layer_idx) scan_output = self.norm(y, gate) diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py index 5ffd16302e06..ae618fb4a2b3 100755 --- a/src/transformers/models/jamba/modeling_jamba.py +++ b/src/transformers/models/jamba/modeling_jamba.py @@ -343,7 +343,7 @@ 
def cuda_kernels_forward( time_proj_bias = time_proj_bias.float() if time_proj_bias is not None else None if use_precomputed_states: scan_outputs = selective_state_update( - cache_params.layers[self.layer_idx].ssm_states, + cache_params.layers[self.layer_idx].recurrent_states, hidden_states[..., 0], discrete_time_step[..., 0], A, @@ -368,7 +368,7 @@ def cuda_kernels_forward( return_last_state=True, ) if ssm_state is not None and cache_params is not None: - cache_params.update_ssm_state(ssm_state, self.layer_idx) + cache_params.update_recurrent_state(ssm_state, self.layer_idx) # 4. Final linear projection contextualized_states = self.out_proj(scan_outputs.transpose(1, 2)) @@ -388,7 +388,7 @@ def slow_forward(self, input_states, cache_params: Cache | None = None, attentio if cache_params is not None and cache_params.has_previous_state(self.layer_idx): # In training mode, we don't want to perform in-place operations on ssm_state so we can compute the backwards pass - ssm_state = cache_params.layers[self.layer_idx].ssm_states.clone() + ssm_state = cache_params.layers[self.layer_idx].recurrent_states.clone() else: ssm_state = torch.zeros( (batch_size, self.intermediate_size, self.ssm_state_size), @@ -446,7 +446,7 @@ def slow_forward(self, input_states, cache_params: Cache | None = None, attentio scan_output = (scan_output * self.act(gate)) if cache_params is not None: - cache_params.update_ssm_state(ssm_state, self.layer_idx) + cache_params.update_recurrent_state(ssm_state, self.layer_idx) # 4. Final linear projection contextualized_states = self.out_proj(scan_output.transpose(1, 2)) diff --git a/src/transformers/models/jamba/modular_jamba.py b/src/transformers/models/jamba/modular_jamba.py index e24be7640f40..21e6623d3296 100644 --- a/src/transformers/models/jamba/modular_jamba.py +++ b/src/transformers/models/jamba/modular_jamba.py @@ -236,7 +236,7 @@ def cuda_kernels_forward( time_proj_bias = time_proj_bias.float() if time_proj_bias is not None else None if use_precomputed_states: scan_outputs = selective_state_update( - cache_params.layers[self.layer_idx].ssm_states, + cache_params.layers[self.layer_idx].recurrent_states, hidden_states[..., 0], discrete_time_step[..., 0], A, @@ -261,7 +261,7 @@ def cuda_kernels_forward( return_last_state=True, ) if ssm_state is not None and cache_params is not None: - cache_params.update_ssm_state(ssm_state, self.layer_idx) + cache_params.update_recurrent_state(ssm_state, self.layer_idx) # 4. Final linear projection contextualized_states = self.out_proj(scan_outputs.transpose(1, 2)) @@ -281,7 +281,7 @@ def slow_forward(self, input_states, cache_params: Cache | None = None, attentio if cache_params is not None and cache_params.has_previous_state(self.layer_idx): # In training mode, we don't want to perform in-place operations on ssm_state so we can compute the backwards pass - ssm_state = cache_params.layers[self.layer_idx].ssm_states.clone() + ssm_state = cache_params.layers[self.layer_idx].recurrent_states.clone() else: ssm_state = torch.zeros( (batch_size, self.intermediate_size, self.ssm_state_size), @@ -339,7 +339,7 @@ def slow_forward(self, input_states, cache_params: Cache | None = None, attentio scan_output = (scan_output * self.act(gate)) if cache_params is not None: - cache_params.update_ssm_state(ssm_state, self.layer_idx) + cache_params.update_recurrent_state(ssm_state, self.layer_idx) # 4. 
Final linear projection contextualized_states = self.out_proj(scan_output.transpose(1, 2)) diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index 1591583112dc..2e5695c1d4fa 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -235,7 +235,7 @@ def cuda_kernels_forward( time_proj_bias = self.dt_proj.bias.float() if hasattr(self.dt_proj, "bias") else None if is_decoding: scan_outputs = selective_state_update( - cache_params.layers[self.layer_idx].ssm_states, + cache_params.layers[self.layer_idx].recurrent_states, hidden_states[..., 0], discrete_time_step[..., 0], A, @@ -260,7 +260,7 @@ def cuda_kernels_forward( return_last_state=True, ) if ssm_state is not None and cache_params is not None: - cache_params.update_ssm_state(ssm_state, self.layer_idx) + cache_params.update_recurrent_state(ssm_state, self.layer_idx) # 4. Final linear projection contextualized_states = self.out_proj(scan_outputs.transpose(1, 2)) @@ -278,7 +278,7 @@ def slow_forward(self, input_states, cache_params: Cache | None=None, attention_ hidden_states = hidden_states * attention_mask.unsqueeze(1) if cache_params is not None and cache_params.has_previous_state(self.layer_idx): - ssm_state = cache_params.layers[self.layer_idx].ssm_states.clone() + ssm_state = cache_params.layers[self.layer_idx].recurrent_states.clone() else: ssm_state = torch.zeros( (batch_size, self.intermediate_size, self.ssm_state_size), @@ -356,7 +356,7 @@ def combine_fn(left, right): scan_output = (scan_output * self.act(gate)) if cache_params is not None: - cache_params.update_ssm_state(ssm_state, self.layer_idx) + cache_params.update_recurrent_state(ssm_state, self.layer_idx) # 4. Final linear projection contextualized_states = self.out_proj(scan_output.transpose(1, 2)) # [batch, seq_len, hidden_size] diff --git a/src/transformers/models/mamba2/modeling_mamba2.py b/src/transformers/models/mamba2/modeling_mamba2.py index 32bda17d976e..d0a47ef9dc63 100644 --- a/src/transformers/models/mamba2/modeling_mamba2.py +++ b/src/transformers/models/mamba2/modeling_mamba2.py @@ -286,7 +286,7 @@ def cuda_kernels_forward( C = C.view(batch_size, self.n_groups, C.shape[1] // self.n_groups) hidden_states_reshaped = hidden_states.view(batch_size, self.num_heads, self.head_dim) hidden_states = selective_state_update( - cache_params.layers[self.layer_idx].ssm_states, + cache_params.layers[self.layer_idx].recurrent_states, hidden_states_reshaped, dt, A, @@ -384,7 +384,7 @@ def cuda_kernels_forward( # Init cache if ssm_state is not None and cache_params is not None: - cache_params.update_ssm_state(ssm_state, layer_idx=self.layer_idx) + cache_params.update_recurrent_state(ssm_state, layer_idx=self.layer_idx) scan_output = scan_output.view(batch_size, seq_len, -1) # Multiply "gate" branch and apply extra normalization layer @@ -476,8 +476,8 @@ def torch_forward( dBx = (dB * hidden_states[..., None]).to(device=cache_device) # State calculation - ssm_states = cache_params.layers[self.layer_idx].ssm_states * dA + dBx - ssm_states = cache_params.update_ssm_state(ssm_states, layer_idx=self.layer_idx) + ssm_states = cache_params.layers[self.layer_idx].recurrent_states * dA + dBx + ssm_states = cache_params.update_recurrent_state(ssm_states, layer_idx=self.layer_idx) # Subsequent output # [bsz, n_groups * state_size] -> [bsz, num_heads, state_size] @@ -574,7 +574,7 @@ def torch_forward( # Init cache if ssm_state is not None and cache_params is not None: - 
cache_params.update_ssm_state(ssm_state, layer_idx=self.layer_idx) + cache_params.update_recurrent_state(ssm_state, layer_idx=self.layer_idx) scan_output = self.norm(y, gate) diff --git a/src/transformers/models/nemotron_h/modeling_nemotron_h.py b/src/transformers/models/nemotron_h/modeling_nemotron_h.py index 41aba6122448..af09fdfaf36f 100644 --- a/src/transformers/models/nemotron_h/modeling_nemotron_h.py +++ b/src/transformers/models/nemotron_h/modeling_nemotron_h.py @@ -255,7 +255,7 @@ def cuda_kernels_forward( C = C.view(batch_size, self.n_groups, C.shape[1] // self.n_groups) hidden_states_reshaped = hidden_states.view(batch_size, self.num_heads, self.head_dim) hidden_states = selective_state_update( - cache_params.layers[self.layer_idx].ssm_states, + cache_params.layers[self.layer_idx].recurrent_states, hidden_states_reshaped, dt, A, @@ -356,7 +356,7 @@ def cuda_kernels_forward( **dt_limit_kwargs, ) if ssm_state is not None and cache_params is not None: - cache_params.update_ssm_state(ssm_state, self.layer_idx) + cache_params.update_recurrent_state(ssm_state, self.layer_idx) scan_output = scan_output.view(batch_size, seq_len, -1) # Multiply "gate" branch and apply extra normalization layer scan_output = self.norm(scan_output, gate) @@ -435,9 +435,9 @@ def torch_forward(self, input_states, cache_params: Cache | None=None, attention dBx = dB * hidden_states[..., None] # State calculation - ssm_state = cache_params.layers[self.layer_idx].ssm_states.clone() - ssm_state = ssm_state * dA + dBx - ssm_state = cache_params.update_ssm_state(ssm_state, self.layer_idx) + recurrent_states = cache_params.layers[self.layer_idx].recurrent_states.clone() + recurrent_states = recurrent_states * dA + dBx + recurrent_states = cache_params.update_recurrent_state(recurrent_states, self.layer_idx) # Subsequent output # [bsz, n_groups * state_size] -> [bsz, num_heads, state_size] @@ -446,7 +446,7 @@ def torch_forward(self, input_states, cache_params: Cache | None=None, attention C = C.reshape(batch_size, -1, C.shape[-1]) # [bsz, num_heads, head_dim] - ssm_states = ssm_state.to(C.dtype) # Shape: [b, h, d, n] + ssm_states = recurrent_states.to(C.dtype) # Shape: [b, h, d, n] # Reshape ssm_states to merge the first two dimensions ssm_states_reshaped = ssm_states.view(batch_size * self.num_heads, self.head_dim, self.ssm_state_size) # Shape: [b*h, d, n] C_reshaped = C.view(batch_size * self.num_heads, self.ssm_state_size, 1) # Shape: [b*h, n, 1] @@ -535,7 +535,7 @@ def torch_forward(self, input_states, cache_params: Cache | None=None, attention y = y[:, :seq_len, :, :] y = y.reshape(batch_size, seq_len, -1) if ssm_state is not None and cache_params is not None: - cache_params.update_ssm_state(ssm_state, self.layer_idx) + cache_params.update_recurrent_state(ssm_state, self.layer_idx) scan_output = self.norm(y, gate) diff --git a/src/transformers/models/qwen3_5/modeling_qwen3_5.py b/src/transformers/models/qwen3_5/modeling_qwen3_5.py index feda0bd67a2b..eba3eec02fdd 100644 --- a/src/transformers/models/qwen3_5/modeling_qwen3_5.py +++ b/src/transformers/models/qwen3_5/modeling_qwen3_5.py @@ -437,7 +437,7 @@ def forward( # getting projected states from cache if it exists if use_precomputed_states: conv_state = cache_params.layers[self.layer_idx].conv_states - recurrent_state = cache_params.layers[self.layer_idx].ssm_states + recurrent_state = cache_params.layers[self.layer_idx].recurrent_states mixed_qkv = self.in_proj_qkv(hidden_states) mixed_qkv = mixed_qkv.transpose(1, 2) @@ -521,7 +521,7 @@ def forward( # Update 
cache if cache_params is not None: - cache_params.update_ssm_state(last_recurrent_state, self.layer_idx) + cache_params.update_recurrent_state(last_recurrent_state, self.layer_idx) # reshape input data into 2D tensor core_attn_out = core_attn_out.reshape(-1, self.head_v_dim) diff --git a/src/transformers/models/qwen3_5/modular_qwen3_5.py b/src/transformers/models/qwen3_5/modular_qwen3_5.py index 69e9bdb0faf2..8fddbc6115c1 100644 --- a/src/transformers/models/qwen3_5/modular_qwen3_5.py +++ b/src/transformers/models/qwen3_5/modular_qwen3_5.py @@ -222,7 +222,7 @@ def forward( # getting projected states from cache if it exists if use_precomputed_states: conv_state = cache_params.layers[self.layer_idx].conv_states - recurrent_state = cache_params.layers[self.layer_idx].ssm_states + recurrent_state = cache_params.layers[self.layer_idx].recurrent_states mixed_qkv = self.in_proj_qkv(hidden_states) mixed_qkv = mixed_qkv.transpose(1, 2) @@ -306,7 +306,7 @@ def forward( # Update cache if cache_params is not None: - cache_params.update_ssm_state(last_recurrent_state, self.layer_idx) + cache_params.update_recurrent_state(last_recurrent_state, self.layer_idx) # reshape input data into 2D tensor core_attn_out = core_attn_out.reshape(-1, self.head_v_dim) diff --git a/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py b/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py index 3c578e6fe836..be4501d34903 100644 --- a/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py +++ b/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py @@ -438,7 +438,7 @@ def forward( # getting projected states from cache if it exists if use_precomputed_states: conv_state = cache_params.layers[self.layer_idx].conv_states - recurrent_state = cache_params.layers[self.layer_idx].ssm_states + recurrent_state = cache_params.layers[self.layer_idx].recurrent_states mixed_qkv = self.in_proj_qkv(hidden_states) mixed_qkv = mixed_qkv.transpose(1, 2) @@ -522,7 +522,7 @@ def forward( # Update cache if cache_params is not None: - cache_params.update_ssm_state(last_recurrent_state, self.layer_idx) + cache_params.update_recurrent_state(last_recurrent_state, self.layer_idx) # reshape input data into 2D tensor core_attn_out = core_attn_out.reshape(-1, self.head_v_dim) diff --git a/src/transformers/models/qwen3_next/modeling_qwen3_next.py b/src/transformers/models/qwen3_next/modeling_qwen3_next.py index da31a5eb8e7d..9e7fa7e01c69 100644 --- a/src/transformers/models/qwen3_next/modeling_qwen3_next.py +++ b/src/transformers/models/qwen3_next/modeling_qwen3_next.py @@ -607,7 +607,7 @@ def forward( # getting projected states from cache if it exists if use_precomputed_states: conv_state = cache_params.layers[self.layer_idx].conv_states - recurrent_state = cache_params.layers[self.layer_idx].ssm_states + recurrent_state = cache_params.layers[self.layer_idx].recurrent_states projected_states_qkvz = self.in_proj_qkvz(hidden_states) projected_states_ba = self.in_proj_ba(hidden_states) @@ -688,7 +688,7 @@ def forward( # Update cache if cache_params is not None: - cache_params.update_ssm_state(last_recurrent_state, self.layer_idx) + cache_params.update_recurrent_state(last_recurrent_state, self.layer_idx) z_shape_og = z.shape # reshape input data into 2D tensor diff --git a/src/transformers/models/qwen3_next/modular_qwen3_next.py b/src/transformers/models/qwen3_next/modular_qwen3_next.py index 80fa61c3229b..417a9a59cf8b 100644 --- a/src/transformers/models/qwen3_next/modular_qwen3_next.py +++ 
b/src/transformers/models/qwen3_next/modular_qwen3_next.py @@ -446,7 +446,7 @@ def forward( # getting projected states from cache if it exists if use_precomputed_states: conv_state = cache_params.layers[self.layer_idx].conv_states - recurrent_state = cache_params.layers[self.layer_idx].ssm_states + recurrent_state = cache_params.layers[self.layer_idx].recurrent_states projected_states_qkvz = self.in_proj_qkvz(hidden_states) projected_states_ba = self.in_proj_ba(hidden_states) @@ -527,7 +527,7 @@ def forward( # Update cache if cache_params is not None: - cache_params.update_ssm_state(last_recurrent_state, self.layer_idx) + cache_params.update_recurrent_state(last_recurrent_state, self.layer_idx) z_shape_og = z.shape # reshape input data into 2D tensor diff --git a/src/transformers/models/zamba/modeling_zamba.py b/src/transformers/models/zamba/modeling_zamba.py index 405434004ca9..3e891f9a3baf 100644 --- a/src/transformers/models/zamba/modeling_zamba.py +++ b/src/transformers/models/zamba/modeling_zamba.py @@ -322,7 +322,7 @@ def cuda_kernels_forward( if use_precomputed_states: for n in range(self.n_mamba_heads): scan_outputs_ = selective_state_update( - cache_params.layers[self.layer_idx].ssm_states[:, n], + cache_params.layers[self.layer_idx].recurrent_states[:, n], hidden_states[n, ..., 0], discrete_time_step[n, ..., 0], A[n], @@ -357,7 +357,7 @@ def cuda_kernels_forward( scan_outputs = torch.cat((scan_outputs, scan_outputs_), dim=1).contiguous() ssm_state = torch.cat((ssm_state, ssm_state_.unsqueeze(1)), dim=1) if ssm_state is not None and cache_params is not None: - cache_params.update_ssm_state(ssm_state, self.layer_idx) + cache_params.update_recurrent_state(ssm_state, self.layer_idx) # 4. Final linear projection contextualized_states = self.out_proj(scan_outputs.transpose(1, 2)) @@ -376,7 +376,7 @@ def slow_forward(self, input_states, cache_params: Cache | None = None, attentio if cache_params is not None and cache_params.has_previous_state(self.layer_idx): # In training mode, we don't want to perform in-place operations on ssm_state so we can compute the backwards pass - ssm_state = cache_params.layers[self.layer_idx].ssm_states.clone() + ssm_state = cache_params.layers[self.layer_idx].recurrent_states.clone() else: ssm_state = torch.zeros( (batch_size, self.n_mamba_heads, self.mamba_head_dim, self.ssm_state_size), @@ -437,7 +437,7 @@ def slow_forward(self, input_states, cache_params: Cache | None = None, attentio scan_output = scan_output * self.act(gate) if cache_params is not None: - cache_params.update_ssm_state(ssm_state, self.layer_idx) + cache_params.update_recurrent_state(ssm_state, self.layer_idx) # 4. 
Final linear projection contextualized_states = self.out_proj( diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index 4de0979a71c0..ea5054194de6 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -543,7 +543,7 @@ def cuda_kernels_forward( C = C.view(batch_size, self.n_groups, C.shape[1] // self.n_groups) hidden_states_reshaped = hidden_states.view(batch_size, self.num_heads, self.head_dim) hidden_states = selective_state_update( - cache_params.layers[self.layer_idx].ssm_states, + cache_params.layers[self.layer_idx].recurrent_states, hidden_states_reshaped, dt, A, @@ -644,7 +644,7 @@ def cuda_kernels_forward( **dt_limit_kwargs, ) if ssm_state is not None and cache_params is not None: - cache_params.update_ssm_state(ssm_state, self.layer_idx) + cache_params.update_recurrent_state(ssm_state, self.layer_idx) scan_output = scan_output.view(batch_size, seq_len, -1) # Multiply "gate" branch and apply extra normalization layer scan_output = self.norm(scan_output, gate) @@ -723,9 +723,9 @@ def torch_forward(self, input_states, cache_params: Cache | None=None, attention dBx = dB * hidden_states[..., None] # State calculation - ssm_state = cache_params.layers[self.layer_idx].ssm_states.clone() - ssm_state = ssm_state * dA + dBx - ssm_state = cache_params.update_ssm_state(ssm_state, self.layer_idx) + recurrent_states = cache_params.layers[self.layer_idx].recurrent_states.clone() + recurrent_states = recurrent_states * dA + dBx + recurrent_states = cache_params.update_recurrent_state(recurrent_states, self.layer_idx) # Subsequent output # [bsz, n_groups * state_size] -> [bsz, num_heads, state_size] @@ -734,7 +734,7 @@ def torch_forward(self, input_states, cache_params: Cache | None=None, attention C = C.reshape(batch_size, -1, C.shape[-1]) # [bsz, num_heads, head_dim] - ssm_states = ssm_state.to(C.dtype) # Shape: [b, h, d, n] + ssm_states = recurrent_states.to(C.dtype) # Shape: [b, h, d, n] # Reshape ssm_states to merge the first two dimensions ssm_states_reshaped = ssm_states.view(batch_size * self.num_heads, self.head_dim, self.ssm_state_size) # Shape: [b*h, d, n] C_reshaped = C.view(batch_size * self.num_heads, self.ssm_state_size, 1) # Shape: [b*h, n, 1] @@ -823,7 +823,7 @@ def torch_forward(self, input_states, cache_params: Cache | None=None, attention y = y[:, :seq_len, :, :] y = y.reshape(batch_size, seq_len, -1) if ssm_state is not None and cache_params is not None: - cache_params.update_ssm_state(ssm_state, self.layer_idx) + cache_params.update_recurrent_state(ssm_state, self.layer_idx) scan_output = self.norm(y, gate) diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py index 2f858fc2eeb8..f39b6de31ff6 100644 --- a/src/transformers/models/zamba2/modular_zamba2.py +++ b/src/transformers/models/zamba2/modular_zamba2.py @@ -331,7 +331,7 @@ def cuda_kernels_forward( C = C.view(batch_size, self.n_groups, C.shape[1] // self.n_groups) hidden_states_reshaped = hidden_states.view(batch_size, self.num_heads, self.head_dim) hidden_states = selective_state_update( - cache_params.layers[self.layer_idx].ssm_states, + cache_params.layers[self.layer_idx].recurrent_states, hidden_states_reshaped, dt, A, @@ -432,7 +432,7 @@ def cuda_kernels_forward( **dt_limit_kwargs, ) if ssm_state is not None and cache_params is not None: - cache_params.update_ssm_state(ssm_state, self.layer_idx) + 
cache_params.update_recurrent_state(ssm_state, self.layer_idx) scan_output = scan_output.view(batch_size, seq_len, -1) # Multiply "gate" branch and apply extra normalization layer scan_output = self.norm(scan_output, gate) @@ -511,9 +511,9 @@ def torch_forward(self, input_states, cache_params: Cache | None=None, attention dBx = dB * hidden_states[..., None] # State calculation - ssm_state = cache_params.layers[self.layer_idx].ssm_states.clone() - ssm_state = ssm_state * dA + dBx - ssm_state = cache_params.update_ssm_state(ssm_state, self.layer_idx) + recurrent_states = cache_params.layers[self.layer_idx].recurrent_states.clone() + recurrent_states = recurrent_states * dA + dBx + recurrent_states = cache_params.update_recurrent_state(recurrent_states, self.layer_idx) # Subsequent output # [bsz, n_groups * state_size] -> [bsz, num_heads, state_size] @@ -522,7 +522,7 @@ def torch_forward(self, input_states, cache_params: Cache | None=None, attention C = C.reshape(batch_size, -1, C.shape[-1]) # [bsz, num_heads, head_dim] - ssm_states = ssm_state.to(C.dtype) # Shape: [b, h, d, n] + ssm_states = recurrent_states.to(C.dtype) # Shape: [b, h, d, n] # Reshape ssm_states to merge the first two dimensions ssm_states_reshaped = ssm_states.view(batch_size * self.num_heads, self.head_dim, self.ssm_state_size) # Shape: [b*h, d, n] C_reshaped = C.view(batch_size * self.num_heads, self.ssm_state_size, 1) # Shape: [b*h, n, 1] @@ -611,7 +611,7 @@ def torch_forward(self, input_states, cache_params: Cache | None=None, attention y = y[:, :seq_len, :, :] y = y.reshape(batch_size, seq_len, -1) if ssm_state is not None and cache_params is not None: - cache_params.update_ssm_state(ssm_state, self.layer_idx) + cache_params.update_recurrent_state(ssm_state, self.layer_idx) scan_output = self.norm(y, gate) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index e91b5dc3aeb3..c0d24abda5c7 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -81,8 +81,8 @@ Cache, DynamicCache, EncoderDecoderCache, - MambaAndAttentionLayer, - MambaLayer, + LinearAttentionAndAttentionLayer, + LinearAttentionLayer, QuantoQuantizedLayer, StaticCache, ) @@ -2591,7 +2591,7 @@ def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_l # Check each layer has the correct shape for layer in past_key_values.layers: # Mamba + Attention layer cache - if type(layer) is MambaAndAttentionLayer: + if type(layer) is LinearAttentionAndAttentionLayer: # Remove the seq_length dim for cross-attention cache (it changes based on the model) keys = layer.keys if seq_length is not None else layer.keys[:, :, 0, :] values = layer.values if seq_length is not None else layer.values[:, :, 0, :] @@ -2599,14 +2599,14 @@ def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_l self.assertEqual(values.shape, attention_shape) self.assertEqual(layer.conv_states.shape, conv_shape) # May not be used (e.g. lfm2) - if layer.is_ssm_states_initialized: - self.assertEqual(layer.ssm_states.shape, ssm_shape) + if layer.is_recurrent_states_initialized: + self.assertEqual(layer.recurrent_states.shape, ssm_shape) # Mamba only layer cache - elif type(layer) is MambaLayer: + elif type(layer) is LinearAttentionLayer: self.assertEqual(layer.conv_states.shape, conv_shape) # May not be used (e.g. 
lfm2) - if layer.is_ssm_states_initialized: - self.assertEqual(layer.ssm_states.shape, ssm_shape) + if layer.is_recurrent_states_initialized: + self.assertEqual(layer.recurrent_states.shape, ssm_shape) # Attention only layer type else: # Remove the seq_length dim for cross-attention cache (it changes based on the model) @@ -2656,19 +2656,23 @@ def _check_caches_are_equal(self, cache1: Cache, cache2: Cache): self.assertEqual(type(cache1.layers[idx]), type(cache2.layers[idx])) # Mamba + Attention layer - if type(cache1.layers[idx]) is MambaAndAttentionLayer: + if type(cache1.layers[idx]) is LinearAttentionAndAttentionLayer: torch.testing.assert_close(cache1.layers[idx].keys, cache2.layers[idx].keys) torch.testing.assert_close(cache1.layers[idx].values, cache2.layers[idx].values) torch.testing.assert_close(cache1.layers[idx].conv_states, cache2.layers[idx].conv_states) # May not be used (e.g. lfm2) - if cache1.layers[idx].is_ssm_states_initialized: - torch.testing.assert_close(cache1.layers[idx].ssm_states, cache2.layers[idx].ssm_states) + if cache1.layers[idx].is_recurrent_states_initialized: + torch.testing.assert_close( + cache1.layers[idx].recurrent_states, cache2.layers[idx].recurrent_states + ) # Mamba layer - elif type(cache1.layers[idx]) is MambaLayer: + elif type(cache1.layers[idx]) is LinearAttentionLayer: torch.testing.assert_close(cache1.layers[idx].conv_states, cache2.layers[idx].conv_states) # May not be used (e.g. lfm2) - if cache1.layers[idx].is_ssm_states_initialized: - torch.testing.assert_close(cache1.layers[idx].ssm_states, cache2.layers[idx].ssm_states) + if cache1.layers[idx].is_recurrent_states_initialized: + torch.testing.assert_close( + cache1.layers[idx].recurrent_states, cache2.layers[idx].recurrent_states + ) # Attention layer else: torch.testing.assert_close(cache1.layers[idx].keys, cache2.layers[idx].keys) diff --git a/tests/models/falcon_mamba/test_modeling_falcon_mamba.py b/tests/models/falcon_mamba/test_modeling_falcon_mamba.py index 3c674b1302eb..b14bc4227efe 100644 --- a/tests/models/falcon_mamba/test_modeling_falcon_mamba.py +++ b/tests/models/falcon_mamba/test_modeling_falcon_mamba.py @@ -337,7 +337,9 @@ def recursive_check(tuple_object, dict_object): if isinstance(tuple_object, DynamicCache): # MODIFIED PART START for idx in range(len(tuple_object)): recursive_check(tuple_object.layers[idx].conv_states, dict_object.layers[idx].conv_states) - recursive_check(tuple_object.layers[idx].ssm_states, dict_object.layers[idx].ssm_states) + recursive_check( + tuple_object.layers[idx].recurrent_states, dict_object.layers[idx].recurrent_states + ) elif isinstance(tuple_object, (list, tuple)): # MODIFIED PART END for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object): recursive_check(tuple_iterable_value, dict_iterable_value) diff --git a/tests/models/mamba/test_modeling_mamba.py b/tests/models/mamba/test_modeling_mamba.py index 78ec5789085d..276c03d65099 100644 --- a/tests/models/mamba/test_modeling_mamba.py +++ b/tests/models/mamba/test_modeling_mamba.py @@ -304,7 +304,9 @@ def recursive_check(tuple_object, dict_object): if isinstance(tuple_object, DynamicCache): # MODIFIED PART START for idx in range(len(tuple_object)): recursive_check(tuple_object.layers[idx].conv_states, dict_object.layers[idx].conv_states) - recursive_check(tuple_object.layers[idx].ssm_states, dict_object.layers[idx].ssm_states) + recursive_check( + tuple_object.layers[idx].recurrent_states, dict_object.layers[idx].recurrent_states + ) elif isinstance(tuple_object, 
                    for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object):
                         recursive_check(tuple_iterable_value, dict_iterable_value)
diff --git a/tests/models/mamba2/test_modeling_mamba2.py b/tests/models/mamba2/test_modeling_mamba2.py
index 6c0b3b865dd9..f64047907acc 100644
--- a/tests/models/mamba2/test_modeling_mamba2.py
+++ b/tests/models/mamba2/test_modeling_mamba2.py
@@ -286,7 +286,9 @@ def recursive_check(tuple_object, dict_object):
                 if isinstance(tuple_object, DynamicCache): # MODIFIED PART START
                     for idx in range(len(tuple_object)):
                         recursive_check(tuple_object.layers[idx].conv_states, dict_object.layers[idx].conv_states)
-                        recursive_check(tuple_object.layers[idx].ssm_states, dict_object.layers[idx].ssm_states)
+                        recursive_check(
+                            tuple_object.layers[idx].recurrent_states, dict_object.layers[idx].recurrent_states
+                        )
                 elif isinstance(tuple_object, (list, tuple)): # MODIFIED PART END
                     for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object):
                         recursive_check(tuple_iterable_value, dict_iterable_value)
diff --git a/tests/models/nemotron_h/test_modeling_nemotron_h.py b/tests/models/nemotron_h/test_modeling_nemotron_h.py
index 1ca4225cb7a9..2c45f4c5485c 100644
--- a/tests/models/nemotron_h/test_modeling_nemotron_h.py
+++ b/tests/models/nemotron_h/test_modeling_nemotron_h.py
@@ -386,7 +386,7 @@ def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_l
             # Moe layers have a default mamba cache instantiated, but it stays empty as the layer does not use it
             if layer_type == "moe":
                 self.assertEqual(layer.conv_states, None)
-                self.assertEqual(layer.ssm_states, None)
+                self.assertEqual(layer.recurrent_states, None)
             # Attention layer cache
             elif layer_type == "attention":
                 self.assertEqual(layer.keys.shape, attention_shape)
@@ -394,7 +394,7 @@ def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_l
             # Mamba layer cache
             elif layer_type == "mamba":
                 self.assertEqual(layer.conv_states.shape, conv_shape)
-                self.assertEqual(layer.ssm_states.shape, ssm_shape)
+                self.assertEqual(layer.recurrent_states.shape, ssm_shape)
             else:
                 raise ValueError("Unknown layer type.")
diff --git a/utils/mlinter/.mlinter_cache.json b/utils/mlinter/.mlinter_cache.json
new file mode 100644
index 000000000000..39154b73c1d7
--- /dev/null
+++ b/utils/mlinter/.mlinter_cache.json
@@ -0,0 +1,1073 @@
+{
+    "src/transformers/models/afmoe/configuration_afmoe.py": "cb18784c5d578c0352c0f2958afb473b5aadd3ef4ecdc5216150204ab553493d",
+    "src/transformers/models/afmoe/modeling_afmoe.py": "a5681d5e3ed6e4c25c8cc5d207a40b1757c7f21e342c982380ec1119ac44861c",
+    "src/transformers/models/afmoe/modular_afmoe.py": "ae6dfcbc3fcf34c2fd3b217128d1e33db0d70db9045a19fb6da66e40f1f50a2c",
+    "src/transformers/models/aimv2/configuration_aimv2.py": "89696ddda44298d16a3c5c5e40741588d5d66a0679339479099a5ff42f6a23f1",
+    "src/transformers/models/aimv2/modeling_aimv2.py": "cb9287ca4946a51b16c99eb49bb4eafbad9df2a81fab1aa61d3beb7cab8c0a80",
+    "src/transformers/models/aimv2/modular_aimv2.py": "97d66d6d0756d07e7e0d1d730b4c89da394412d4507d92ede09fd7f5b11623f3",
+    "src/transformers/models/albert/configuration_albert.py": "e3c4d4e4c87111b669cc285cadb206d58731e0c88d166a1113a0f096f4a8909f",
+    "src/transformers/models/albert/modeling_albert.py": "fc64e43d93f5bbbe8734c663ef6740a559799cabf441e436cb2901cac9b27a38",
+    "src/transformers/models/align/configuration_align.py": "c4dc28a3ba2be74752f6ffecbb75c425332f2670a9de6d13584f6779a5ee1058",
+
"src/transformers/models/align/modeling_align.py": "73ddf7860acecd9ba61dc269ab54bb5e949c227016eefc3a66d823dcab8ef94f", + "src/transformers/models/altclip/configuration_altclip.py": "9487f951824ef5c5f60eb001604ead445b48968d8738ba674b4823cf5f7e298f", + "src/transformers/models/altclip/modeling_altclip.py": "b2d926d1f63e86913b12cdeaf3cbee1b28c95b6c80765a30a63ac910f9dd02b9", + "src/transformers/models/apertus/configuration_apertus.py": "ae34c92ef6630fdf3b4875f5b4b6fc08fb67a6b11533515411ae37b7a4dc4ed5", + "src/transformers/models/apertus/modeling_apertus.py": "f6e3cb98e5dfd454dafe50c36b52734e554700dce7cd0dced4ded6791c586b26", + "src/transformers/models/apertus/modular_apertus.py": "7c0c5f12ce3e6bbe163501d0a476eeb52c88255de081e5a03cebcfb42f7a5e67", + "src/transformers/models/arcee/configuration_arcee.py": "c1bdb413f20fda66604002b0c1d2fd06cb67c1e6bc3090b0282e34a77a52e387", + "src/transformers/models/arcee/modeling_arcee.py": "1fdfa8f1d32d2f2193d3146ed6baceafb3122e72f6888c4b471f9cd3be44a087", + "src/transformers/models/arcee/modular_arcee.py": "bb3829636df7ae7428960147106cbbc23f73503f57919371c8d79ca75c7fa45d", + "src/transformers/models/aria/configuration_aria.py": "8cabf127bb9ead8ae62278279aebac74aaecf8eb78a28be1e19af8e984aec4e6", + "src/transformers/models/aria/modeling_aria.py": "9828aae955a915bfc59b9714d6ae9c560f0791131d1a15d95218564fadffb66b", + "src/transformers/models/aria/modular_aria.py": "11784c387433a2ffa27e5eefadc8eda4c3bb75191a3077fc4d63a01421d80203", + "src/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py": "fd56d62a882a8c4baaa6e07b437635c24360a1e6f782125480baa6fbec35f18f", + "src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py": "0c744ebdb41884f72e9fd983da1f5b4a62ead979ac0d3f0010edeb968056401a", + "src/transformers/models/audioflamingo3/configuration_audioflamingo3.py": "23203bcec4b5df4af65faf6dc95d2ea6e9a3c36f38bdb4ac200a5b311674a6e2", + "src/transformers/models/audioflamingo3/modeling_audioflamingo3.py": "14b9c7e3a43c55ea507b41e815b3acb6d02fb0ff917b55a616e5451b70dc42d0", + "src/transformers/models/audioflamingo3/modular_audioflamingo3.py": "ba453855d1764eada7316808882e1bcca9a54505d58d97fbe701730d84485eba", + "src/transformers/models/auto/configuration_auto.py": "9bc2674a59eac35c771a46cc4059aaeb37e8e6602f17d36f8b10c3d18babc502", + "src/transformers/models/auto/modeling_auto.py": "b01d5c96f8d7125b2f70b630032dab0ce24a42f45e4f46b963ba34492bf5ef16", + "src/transformers/models/autoformer/configuration_autoformer.py": "243ec867b1578384d0860237d85a86451dce375b6b013f209a06c74411bfc810", + "src/transformers/models/autoformer/modeling_autoformer.py": "4d2f8c8cedd2a220a9d48dfcbe8dff1b8605390ced03a860293e3b81a3524d7e", + "src/transformers/models/aya_vision/configuration_aya_vision.py": "e8039bc9df5ac44693f533fd50b53b0b30dd3cd3f17f66a9df156b139e1f1b1f", + "src/transformers/models/aya_vision/modeling_aya_vision.py": "bfebfb56427cf3428c25fd5c2869cd3d5f95061220f67abca10fd102b9330f5d", + "src/transformers/models/aya_vision/modular_aya_vision.py": "68f24cffef72590c7f59074803d806942e2e290bade6737387c1e36faddbc9bb", + "src/transformers/models/bamba/configuration_bamba.py": "cf1b9f81f03a825255dc9546022356a6eb3b1e6f4cd664c3278759fe7d3f3f08", + "src/transformers/models/bamba/modeling_bamba.py": "420ea53746e558276e350de794ff5b234a79e1ec9170eba5fafa1f3f2a82b9c9", + "src/transformers/models/bamba/modular_bamba.py": "23adb0e8edbc56ba4bd0bea5ad53904d5963a7e36f42d6fcbc66d17691cac65f", + 
"src/transformers/models/bark/configuration_bark.py": "237c42af22103f191e551a5a1192a9bf4794e2b7861189a58cf6a2bbb75daa7b", + "src/transformers/models/bark/modeling_bark.py": "3bbfc82ff3f700a4b5deb860a423bf644ea26e7c3bc469407597351dda4533e0", + "src/transformers/models/bart/configuration_bart.py": "5f56eecabf2ff9bbf8f53c5df4b7368287e4626ea526bc1e7dc1776fc6922e5c", + "src/transformers/models/bart/modeling_bart.py": "8e8d6b713dc94cbced6cc05c54b0fc9217dcede141e79891def2b4c65cf205b6", + "src/transformers/models/beit/configuration_beit.py": "cc1c33ab0e97b5a6b1c7274d96b185d0484b7d18989720f18294dc656dac67f2", + "src/transformers/models/beit/modeling_beit.py": "e5e3be58febade51052456042defdb04f8bac95bc00ccaa7c299c2f847947a71", + "src/transformers/models/bert/configuration_bert.py": "924e13540f603e40b2b4ed51d7139eb95ed646a851d7cbb2ee6186da7c9829b9", + "src/transformers/models/bert/modeling_bert.py": "bc1c375d781fdbeb424e4592ea82be8ec25c4841a057f5f85e08f20f50d2068a", + "src/transformers/models/bert_generation/configuration_bert_generation.py": "ee10f065b884880ba1f6c9d08a966b2ed26812ee7602a7f565e7d87e683c1c6a", + "src/transformers/models/bert_generation/modeling_bert_generation.py": "5a0cfaf598a7be00806d9e3aa2a59c31aa7d4c671652e2e15cad600b14fe5080", + "src/transformers/models/big_bird/configuration_big_bird.py": "c328fdd36de02bef93bc9c02c34aac7d61e23ce32cd23932c894dc1b8ce5d35a", + "src/transformers/models/big_bird/modeling_big_bird.py": "1700b3b52a5ae85e5321367be5edfa1e3a4cc8e0709fc5c19b88ff4eb9be5523", + "src/transformers/models/bigbird_pegasus/configuration_bigbird_pegasus.py": "dc782ebd96e20824279e1c60c9dccdabfb4569da461a1054e6f46c1c15e5e8dd", + "src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py": "78b36af511d3a2e4df5ed526bb609638488ea5574d10b75f4d349b907fcf3077", + "src/transformers/models/biogpt/configuration_biogpt.py": "a34f6a9c53d25bea7d3df6b8a73612b287e754157e557c7061d45e4b367d36fb", + "src/transformers/models/biogpt/modeling_biogpt.py": "47a87ee17f7d5cca6221d5c87f1836af8c15da04ab5a16c0d3c5c67937d6b154", + "src/transformers/models/biogpt/modular_biogpt.py": "82a58e3790d5d236add39b6d7233da5f2d98108e1150e640e43f213952936e5b", + "src/transformers/models/bit/configuration_bit.py": "db00bd41886ba99305fde95b0b3ebad06ebfb143438441bbdf496697db50a78f", + "src/transformers/models/bit/modeling_bit.py": "44724c5dc21875fb438691a8b86039d60a34c903c6b0c92d88fc543bd1a57f8c", + "src/transformers/models/bitnet/configuration_bitnet.py": "7d79cd088e7cf06e5f27084c2aa98971ae146dc51623fd22611c3b29e23fb684", + "src/transformers/models/bitnet/modeling_bitnet.py": "1f589083a3b6d7836763de0b641b2555eafddb20c89aa7c985d7185aa30a7ee2", + "src/transformers/models/bitnet/modular_bitnet.py": "1bde132b6ae0857152632e990006a8872eebbadd76ac975c53194fe051896df1", + "src/transformers/models/blenderbot/configuration_blenderbot.py": "170ced3b7b82617975dfc05f34bb64d3c82dbcf1ef6a78b76794d554e7466f4e", + "src/transformers/models/blenderbot/modeling_blenderbot.py": "14505b502d4ec2d94bf15934752ba5b896caf478ba775a66f2e6d00ee51189ab", + "src/transformers/models/blenderbot_small/configuration_blenderbot_small.py": "03c1529efa869494c1293229409c9903184af564d857465d6c50eb8196433653", + "src/transformers/models/blenderbot_small/modeling_blenderbot_small.py": "64f35af1e75ac11f88bb294ec370ccde54782b1594bd96597870b9d7d2555508", + "src/transformers/models/blip/configuration_blip.py": "ac921b6675d31bae5a879571d49ac1fc60ebb5da0536fe636b7718d2f72304f3", + "src/transformers/models/blip/modeling_blip.py": 
"015f1b98b382a666db8c31711c6f295acd3c89254292401ff937f8c877f660b7", + "src/transformers/models/blip/modeling_blip_text.py": "eb0378f5c76c4a9ba5c1ea5ad3ad5287bf40e492afdce059bd84ddd4c1856087", + "src/transformers/models/blip_2/configuration_blip_2.py": "03e6a77499c9e5d172d6e04a69c1f126772bbd67a5217c41dcdc0711420063e3", + "src/transformers/models/blip_2/modeling_blip_2.py": "ed4e5d9fb1ad6557076af88f7679b91434d906b366acb4930887abd2bddef136", + "src/transformers/models/bloom/configuration_bloom.py": "47d054ceb7654b1328441e6f59d1afdffbd8f80685f708e5d1d30c5e53974324", + "src/transformers/models/bloom/modeling_bloom.py": "798a7dd43aa1e9584219255c36740873ea4d1f204c002ddb8ec7296bcf70128e", + "src/transformers/models/blt/configuration_blt.py": "9cfb74b1e0eb9ce9611a19942b6e82c32573470ecb8195c043e312e7b23abbce", + "src/transformers/models/blt/modeling_blt.py": "81e4bf6e0737688e8aa552d802c56362af7a171ab3464c546f5e73adc8505d9c", + "src/transformers/models/blt/modular_blt.py": "529435735e3413bcfc5c3a8096b56e58d57aa176a6628989c5d64039f8c2d0bb", + "src/transformers/models/bridgetower/configuration_bridgetower.py": "dabf6a2d4fdd1d38ae9f7b32c620ec7da52730e9533e668861513d7adfc7f7d7", + "src/transformers/models/bridgetower/modeling_bridgetower.py": "7e99550d2f015f6ce304bc73a461211080adfa52433298efb821ac81a75db03a", + "src/transformers/models/bros/configuration_bros.py": "af1f360c52851af9e08aa6767c39274d99f224e4b559fda04599461a9d6f50df", + "src/transformers/models/bros/modeling_bros.py": "0efca639237d5be48c6d28148d268c1728a5062f4422493fea5660e72c4a0cbf", + "src/transformers/models/camembert/configuration_camembert.py": "65bbc35964cbb42e7d06ea05128302237b80c16f8fb38b794e48299831bb0d31", + "src/transformers/models/camembert/modeling_camembert.py": "463f946576635f1db3013d5b5c77a98d99aea36c31b4e472f6a8efbbc47b24d3", + "src/transformers/models/camembert/modular_camembert.py": "12b8cfeeac5270dd43827aad7f3f3f61d67d781fd458b7d033f4daad44a386f6", + "src/transformers/models/canine/configuration_canine.py": "660d924669bd0ed8ecdea59870c8a351d3eda6d5b4850faa3a744db4a41a64a1", + "src/transformers/models/canine/modeling_canine.py": "5d180ff1dcfe4e264284338cfe74806392f821bc6abceb3c807249807339e911", + "src/transformers/models/chameleon/configuration_chameleon.py": "c180ffea27d9f06fcedccff1507367bf1e90aa60e933acaa7f5cce80162160d5", + "src/transformers/models/chameleon/modeling_chameleon.py": "5a234980db08fbfef75e932b3150f583f93439ca63939967a4d2395e59c153ec", + "src/transformers/models/chinese_clip/configuration_chinese_clip.py": "8b80e122141da04367b1d44fdee1ea2e2ca3d0e11f65be2fcc478870758f595b", + "src/transformers/models/chinese_clip/modeling_chinese_clip.py": "17c95c09d63f1a0a9304b55e6993e3b2c7edfb32b740b1a462171d20dfc7e310", + "src/transformers/models/chmv2/configuration_chmv2.py": "a6b4ecdfd6d5f728ba49e2fa06a97c469d32802ca4910eae1dd5b0c23c6dcf70", + "src/transformers/models/chmv2/modeling_chmv2.py": "2124cf933637bb60bb831bd3a632fc00e0d16c314b0e0a43e051370a738e58fa", + "src/transformers/models/chmv2/modular_chmv2.py": "27d4a6c27e4625bb9c881bad275cfc486db675a9581aed8a026e0e6e9db934cb", + "src/transformers/models/clap/configuration_clap.py": "0a4e35390b3a48ed865f342772e88f92d2338730f03bb1b95a3e66de3d7cfeba", + "src/transformers/models/clap/modeling_clap.py": "fecade298c566bc1656c64392c74ec6c7964a88c79d70888f48d746deef270eb", + "src/transformers/models/clip/configuration_clip.py": "50f7854da6572cf58f4c7087113a1cb5c0e5a37be413cdd33d820640c1cdfe45", + "src/transformers/models/clip/modeling_clip.py": 
"4f6dfbdd988328fc585c719ff88a516ceb90984618b82ab3acdf828becb741dd", + "src/transformers/models/clipseg/configuration_clipseg.py": "7676e7950a04a2e9a0e225f73fa1229362c2388991deedf4f27c639323badfc0", + "src/transformers/models/clipseg/modeling_clipseg.py": "20482c86460de66145ae112edc25ff58561765309604be426df2b20c5bf5eec8", + "src/transformers/models/clvp/configuration_clvp.py": "aa0fa5ac98e1a2e8dbac9b1c32e1bb2d6c8824d81853d1b40717cca15d46ade5", + "src/transformers/models/clvp/modeling_clvp.py": "2979ab9758d37b4b555db31b686b073c0b9e444253cdbfe912f15820fc1d7f46", + "src/transformers/models/codegen/configuration_codegen.py": "14cf06cc4237fbab5836d8607634e98db62ff778eed49b0060705bd7d900d99d", + "src/transformers/models/codegen/modeling_codegen.py": "410f4e458ba972a4431f19833397013ba311933b6f726d3280157269978fa5d1", + "src/transformers/models/cohere/configuration_cohere.py": "eb90b47a9977d1a201ffd823ac25c838e467412ee98c3d6aef81fc538c3975d9", + "src/transformers/models/cohere/modeling_cohere.py": "64fb21af826bf06f77b3e5cd3a41bb5f696aabce3afbd8f8ef73860cd37dc105", + "src/transformers/models/cohere/modular_cohere.py": "1dc15708e9c61f30dfe4cf43224017771355d0e7659f23ca8dde06906163e1e8", + "src/transformers/models/cohere2/configuration_cohere2.py": "c6a6e4e4cac03ed563f6407709cae4486ca131976a0152c4bfa58587a8a04fc0", + "src/transformers/models/cohere2/modeling_cohere2.py": "59e2f99fb0a33b3b74b63e6ba3869e8cecfcd70cda7f2131004699d7b7843949", + "src/transformers/models/cohere2/modular_cohere2.py": "32823f88142efa0826575ede028363cd721d92953117930b1033665f0a796e39", + "src/transformers/models/cohere2_vision/configuration_cohere2_vision.py": "cfc86d90cc8f9b71e8ddd4872326dd00e442a0ed5356ce21b0faba0590e090fe", + "src/transformers/models/cohere2_vision/modeling_cohere2_vision.py": "1aeeeaaa5b3a1741857a3559df0c919417255f4d22b10d811bb79d44cc22d03c", + "src/transformers/models/cohere2_vision/modular_cohere2_vision.py": "d889425af19565875b2f69af5410e02686e4e9f2545975b21f41776956c5f255", + "src/transformers/models/colmodernvbert/configuration_colmodernvbert.py": "2273072e47fae401bdd8a8e0bc3eced93033761a442cc5831357a519b125b9f7", + "src/transformers/models/colmodernvbert/modeling_colmodernvbert.py": "03ccbe74c30b366bd75728b749ecdd509f3324363a4a8ef605cf91acd40ea869", + "src/transformers/models/colmodernvbert/modular_colmodernvbert.py": "46ee2364642acd94bf2104593fbf9bb05f7a3ad19da173ec7cd2e84321cd9ba5", + "src/transformers/models/colpali/configuration_colpali.py": "769aabe840281c95bb22b9961e94fcbc61449b05149835c2f1b3b8c258ca41ea", + "src/transformers/models/colpali/modeling_colpali.py": "4c95c77f45d0fd414560068a4046d8323e2396c4f6b834945035e5677adba8b8", + "src/transformers/models/colpali/modular_colpali.py": "36fce66ab94350016138dbf44c63b248d51786a615548c1c73842a7dde56ec8c", + "src/transformers/models/colqwen2/configuration_colqwen2.py": "3d04aef7a93e9daf9aecba1547b74e9822d73b93b09e4ef7061425c94c6e7ffb", + "src/transformers/models/colqwen2/modeling_colqwen2.py": "b82dfb4975f1699fc2a79737e65435f04fee75a06f34e02b6ef2e68207e8c033", + "src/transformers/models/colqwen2/modular_colqwen2.py": "8ed503c2994674cc48897be4c2721723e7d617e3cd196a917dbd93b1c10f991f", + "src/transformers/models/conditional_detr/configuration_conditional_detr.py": "457c3bcfa4fbc6be338378385016fd756b8a4b0e486e3e7ebf85564b21d21ee3", + "src/transformers/models/conditional_detr/modeling_conditional_detr.py": "7dc0aee54e6404bfcdbae3a892a9b790e1d449d6be04c485b596e0773f54017a", + "src/transformers/models/conditional_detr/modular_conditional_detr.py": 
"3eb3d1dbf37cdf1758ad78a4b982fd8bc957c4a42176dd2ce37cea9c1f48ae3f", + "src/transformers/models/convbert/configuration_convbert.py": "fef352f0db34a64ffcb041cf21dbbc20e900451356b1fa253719ede911690239", + "src/transformers/models/convbert/modeling_convbert.py": "e0f1d3ae8512bbc6dde76c6bf873020fabed3597bbd95c41efba2b5e63763f8b", + "src/transformers/models/convnext/configuration_convnext.py": "df583a9c6a371c99bc30ad8f85db06cee883583e076531906d7e46624630bec3", + "src/transformers/models/convnext/modeling_convnext.py": "a18de5845a48b4d8b534973bf0401b6478eb16dad588287fa4ea794e0d9f0a19", + "src/transformers/models/convnextv2/configuration_convnextv2.py": "5a90b69b59adf695fc52841fe6a133091e69296113432b5f76728ef4b9cbcdbf", + "src/transformers/models/convnextv2/modeling_convnextv2.py": "ca087067320d3172df9b8f8c51264e8604a5f2057a81c0bc83c0dfca811d00c1", + "src/transformers/models/cpmant/configuration_cpmant.py": "e36d39c95f9a0359b69cbe16b1bc8d4cbf9ef78df7c65f10db52923ea5227140", + "src/transformers/models/cpmant/modeling_cpmant.py": "b67849add8536b34b2ec1fea4d34a0eb7ce3c3ed12c368ebc7eddf83e8a2c151", + "src/transformers/models/csm/configuration_csm.py": "580c7b5e4f04cf18685df855576c732b4a62b0ac5122ef07a98da2c17b0ab573", + "src/transformers/models/csm/modeling_csm.py": "e36148fc0c7785ec8f6e951e441cded0d2cd45611ebcdca5980dee2f3316703e", + "src/transformers/models/csm/modular_csm.py": "caac353646e7891bd3715ec04b9156e967e7dd81a758dadfa919d3a160719224", + "src/transformers/models/ctrl/configuration_ctrl.py": "4f1be8b7f1d941bce79fed85f1c353aa0b4c75966ff9350781bea78298d84c2a", + "src/transformers/models/ctrl/modeling_ctrl.py": "3d2185a273fbb35959bee234c1a2b320b94cb484e11fc35e4346f4b056f2f966", + "src/transformers/models/cvt/configuration_cvt.py": "33ad93c650aa04394ff3a4c4012c3bfd4efc4dd3b9dea64635f04abd966d064c", + "src/transformers/models/cvt/modeling_cvt.py": "7b71afc92f44097c0db3d622ffdfa95c858a13bfe8e0a27f25600e4d61f4a1c2", + "src/transformers/models/cwm/configuration_cwm.py": "b0736e8ee6ac08c8559cd9ba3c4613a9780914d8bfaac1f34c09b7d683e7465a", + "src/transformers/models/cwm/modeling_cwm.py": "b4fe43d614b9c21cda5d047b392d1527ee89d9f1582332df7d452565d72b5a22", + "src/transformers/models/cwm/modular_cwm.py": "eaabf54d8f7f97684a4f4cf1a3ab9ca8bb9f7f098ae2b98cc006f02aaa07328b", + "src/transformers/models/d_fine/configuration_d_fine.py": "57f8dbc1a5b0d8d8bc55b00cfcc729ea2c090489f71d0d7402433c7c7a4ff06d", + "src/transformers/models/d_fine/modeling_d_fine.py": "f4d4149af92a2d992aac90d2e9a9663d724d9df711ea319ac0e45320ce2b6849", + "src/transformers/models/d_fine/modular_d_fine.py": "1987db0582e024b9cebc92092a1db6e1c9401a8df5a30b5baed20d59dd8c3bc8", + "src/transformers/models/dab_detr/configuration_dab_detr.py": "044aa0b1f0972e9fda615bcc38a65b7a140b29bac2bfb105abf3494c0065a39f", + "src/transformers/models/dab_detr/modeling_dab_detr.py": "64be492568f9343e130a8e2be45d7f2559cd26d4a033cc904989f0169544b508", + "src/transformers/models/dac/configuration_dac.py": "12e72cf9357385a432466d6d6e91718c645db7e2095007feda71d653862f7768", + "src/transformers/models/dac/modeling_dac.py": "6e2837eb3c4dc681aaf68d7290beb5c6b095eff0cfcabd71e185815178fdaa3c", + "src/transformers/models/data2vec/configuration_data2vec_audio.py": "c79ee10106d926c90bd0d949feba8f10a9f0e8ef3305e17c00123f9ea5b1ec02", + "src/transformers/models/data2vec/configuration_data2vec_text.py": "543a47ac859d73c09194d2c5523da8b4121179af422b0d49d8ce2ae2dd75eac4", + "src/transformers/models/data2vec/configuration_data2vec_vision.py": 
"53fc204a7441a7d850353d0844b529e890cfd943760998099f738fe4c776ebc4", + "src/transformers/models/data2vec/modeling_data2vec_audio.py": "32326098589c9157b17601dd6e97c7b8318a4521d8fcf9dfbe8261631b5be734", + "src/transformers/models/data2vec/modeling_data2vec_text.py": "e7fe8a8ae658cae940421ce375ded83c2f7332e2187c15d2179e15baaf80106c", + "src/transformers/models/data2vec/modeling_data2vec_vision.py": "1a5fefca4e9018ad8d18142645ba4421d8311b393b2a91c6caf94ee25cd256e9", + "src/transformers/models/data2vec/modular_data2vec_audio.py": "b0a3e99d2981097dbfedd8460a9506d618179475beb0c0c2150d1632dec229fa", + "src/transformers/models/data2vec/modular_data2vec_text.py": "750785d49fda1170e291a5a3c5ad73c8f0211579e081e47c4dd1eafbc946d3d7", + "src/transformers/models/dbrx/configuration_dbrx.py": "beb361d07bf94f9a2d469cb9725b1a6988e6a819770fd7671743e5e84b8049fa", + "src/transformers/models/dbrx/modeling_dbrx.py": "2a2b9a881ac1f1f3ec41af08141bd2bd2b5d0fbf734034c43f35ad88f7b35018", + "src/transformers/models/dbrx/modular_dbrx.py": "e2602ff5b5deea70d6049a93926cd111fb5ffd2a658ece949828dc1c2185b97c", + "src/transformers/models/deberta/configuration_deberta.py": "8913b60a0a2f8b352f89db50123714430a0f5ebf84bedce8261fe8903ebf233d", + "src/transformers/models/deberta/modeling_deberta.py": "538de8ff25de19dade5b45b319eed2217cacd4e384a5a5d169fcdd609adc2419", + "src/transformers/models/deberta_v2/configuration_deberta_v2.py": "dd8976631d74b432d7f8ad042be77c978c45e2ec34a21275c0e49b616b9cbe98", + "src/transformers/models/deberta_v2/modeling_deberta_v2.py": "c12c852843921037e454dc0cdd41733c67eddb8d746e611334deb300f00e1e84", + "src/transformers/models/decision_transformer/configuration_decision_transformer.py": "263b4494b5aacff8b786af5b6eddf4b55561c1219db9302c653947eccbeb79a2", + "src/transformers/models/decision_transformer/modeling_decision_transformer.py": "ed48dce33a36564ed4870d6245d0ad2048845909ca8271ab04d18c1dea478400", + "src/transformers/models/deepseek_v2/configuration_deepseek_v2.py": "b0798b62092d04d0a55d0e0e8c2ea8623d436ed0324da31f7df29b5f9d2d8ffa", + "src/transformers/models/deepseek_v2/modeling_deepseek_v2.py": "215f762dc3c2cf50046d15905a00f2fb3a2c5441da01cccc657ffc3054f40757", + "src/transformers/models/deepseek_v2/modular_deepseek_v2.py": "37df6ba10372a53146a3beda9e4a57f85cad362a311151f78d271660bd583fcd", + "src/transformers/models/deepseek_v3/configuration_deepseek_v3.py": "d5b1a9f1d1eed3215ea4bc85f68a17fdc1cdba508f636c799d45532341698603", + "src/transformers/models/deepseek_v3/modeling_deepseek_v3.py": "4d71f61941533b2d9b7c9271195b2a0f50ec422b7ca58230de716b6fb8fa3212", + "src/transformers/models/deepseek_v3/modular_deepseek_v3.py": "66c80357dc4ad27a4fe5f059178ff44a81d0e9e5201e36c783cccede4fcdc7a1", + "src/transformers/models/deepseek_vl/configuration_deepseek_vl.py": "5455bfcffa4ae5bb841e37ddfdfb0965293963354b8cc05b2786ee0ef301ddac", + "src/transformers/models/deepseek_vl/modeling_deepseek_vl.py": "b23d11250b155e480a48fe9417bc4e1be3f8ac4d08cc93535375b5b0dc1d4ff9", + "src/transformers/models/deepseek_vl/modular_deepseek_vl.py": "635e1ee771e5fda88db6759ff81222685523ea0a630fe74532a0f63b318cd5e0", + "src/transformers/models/deepseek_vl_hybrid/configuration_deepseek_vl_hybrid.py": "54c67e75d0923bd643e0dc7afea40880f7a13b7c8f96120169f5e9af44555bca", + "src/transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py": "e72d6f404b762711c9f09f140d2a9620d1ae760a6fbe411979a5402555b63bc7", + "src/transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py": 
"47949abf6b64854955f64bcd62584b07686b0f30afe549319f91b31d253d8962", + "src/transformers/models/deformable_detr/configuration_deformable_detr.py": "139c7c496b0fae37c5589e0d3a4727614afbfdff762a1046d43355da39f2d17e", + "src/transformers/models/deformable_detr/modeling_deformable_detr.py": "6748c8f13c078922dda79652cc73d89f199d180f4c34308aa933a0c30d79c1a4", + "src/transformers/models/deformable_detr/modular_deformable_detr.py": "c842574153dc2b512244bd42430f5f8692872b5d115302c063988bd3e903e802", + "src/transformers/models/deit/configuration_deit.py": "14b0f4b247d0c88c84d6b2392ba8fe8455100b5b406867c54f0f468362ff9d71", + "src/transformers/models/deit/modeling_deit.py": "4652df841649330c2c647797b8bde19aa1a83074a29eba32c1ae662807460885", + "src/transformers/models/depth_anything/configuration_depth_anything.py": "2984b681dd707e858540686e8899d802b656e108f63b15cb33df88c16a46a1ff", + "src/transformers/models/depth_anything/modeling_depth_anything.py": "9fdf71845fefaa3f03ac3d8eed95131145bf03918f48c4406bc8d8dd3caab4cd", + "src/transformers/models/depth_pro/configuration_depth_pro.py": "1944a304cde013988e18f6caebedef414871bdf988be72ac4094c63f2ea59301", + "src/transformers/models/depth_pro/modeling_depth_pro.py": "3a4fd7a2f9cb24970f248f7e2ccc71a0f002752765514c93cc4e250deb5fd657", + "src/transformers/models/detr/configuration_detr.py": "ff234ca95fd31187b559b12c52fb508d36196667e7f276fd65e7b9963ec0b645", + "src/transformers/models/detr/modeling_detr.py": "a37df6a1ab87689ca259a31b61064d3c59cf5604748b5560fc1d131a505dafc2", + "src/transformers/models/dia/configuration_dia.py": "d080c34197d6b5083a3911f2975b262b86422e0e41c1bdb0a53cabb03d008a88", + "src/transformers/models/dia/modeling_dia.py": "f713fa943dc5f20bfa48db7afcacbcda34c0d92d01c7756c82c1409531148445", + "src/transformers/models/dia/modular_dia.py": "ea6fd0cf5b66505b1c08a1be2548ad86421cd0ea555c6e44082de4d37bb37320", + "src/transformers/models/diffllama/configuration_diffllama.py": "2eace633c0f28f912853ed476d30bd53e2f4a4cf660bf6fc89868d228df15d82", + "src/transformers/models/diffllama/modeling_diffllama.py": "bb977358857f1a378fd66aed97fcfbc18e382ca2a04c4fe523d78e7d792514fb", + "src/transformers/models/diffllama/modular_diffllama.py": "92c69dc0f0a515bddd6f94443ad2b10a82a8225eaeaaf95e8de6080764e2f16b", + "src/transformers/models/dinat/configuration_dinat.py": "601f91c7e337648c205244b9adb7139ccb93254b67259ab0bc211d16bb723150", + "src/transformers/models/dinat/modeling_dinat.py": "1283c47a7b334529bc4bfe23bc464495f0825054890cf95321c9638f6e2f442f", + "src/transformers/models/dinov2/configuration_dinov2.py": "7131b13fd760b969cb6b67428afd70821fc1fc1b9aafe5b3dc275fa05bbff420", + "src/transformers/models/dinov2/modeling_dinov2.py": "7fc6541b122d81a8026ab11a0b17a15bc960855c4df60b3d0e6bf311e1a9ef31", + "src/transformers/models/dinov2_with_registers/configuration_dinov2_with_registers.py": "2de0c55dae17da749fe402f29c48e15d3d8acba62cf1ccf2ece307b7353a4a38", + "src/transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py": "5423f67c492393a598be37955b2bee1001a76c067c5c12dc82d526ad044a25bf", + "src/transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py": "ab18b1abaac08759c8e69ae690677bd32574d68e31d01be1d057841fcc718e25", + "src/transformers/models/dinov3_convnext/configuration_dinov3_convnext.py": "1a4e095afabe1bcc6bdffbc741f6b5c9b5fa51aea352fcf6559f18913bea6757", + "src/transformers/models/dinov3_convnext/modeling_dinov3_convnext.py": "bdcad82600f1a10b8a440a9e8978060d2da645f4391a88c01f749322d76146a2", + 
"src/transformers/models/dinov3_vit/configuration_dinov3_vit.py": "e39f3b6e93bba2aac3acf7b3925c461be8ba1a6fe188e64c3f3a48fe26f65129", + "src/transformers/models/dinov3_vit/modeling_dinov3_vit.py": "5f69e2640c8e62c1407663f3d48e823be098adaac4af0beabcbe75caf578773c", + "src/transformers/models/dinov3_vit/modular_dinov3_vit.py": "7e07ea03c3a2cab3eb1563b6e64dd93a1e4ded7c2589febfca0a605535cab983", + "src/transformers/models/distilbert/configuration_distilbert.py": "3fec20ef91059e5f8da778748b1dd02785ffbfd5535bb7558d0f80649f6087a1", + "src/transformers/models/distilbert/modeling_distilbert.py": "c613d45b2e6641e57b14f7b49d65fd8757fa1b0cef9ef2bf632642e1a67161fa", + "src/transformers/models/doge/configuration_doge.py": "e07aabcf9d9011a000c98dae9cd4e30d3b56935ecb69dc439fa55dbb9e42ddaf", + "src/transformers/models/doge/modeling_doge.py": "412d27f4bc2899df2b0557b8850098cb67a7f1227aba90fd6870e225a0f0f82d", + "src/transformers/models/doge/modular_doge.py": "b7fce8eedadb1dcf2ad1cd9995eed417c3c43a0c3a82033cb4d5730883fd3194", + "src/transformers/models/donut/configuration_donut_swin.py": "dcb89fcea06eeb3cd42ebb4082a1801ab329a788c6506c8769f3f30546002c77", + "src/transformers/models/donut/modeling_donut_swin.py": "6eab5e6afa15cd5a63f6d7b9b2907038071ddfaab35fa53ca636fc46359e1caf", + "src/transformers/models/dots1/configuration_dots1.py": "a8d761c4a70d80581311a7ba7c856ee605329cf460d908661f902bfaec07d0d7", + "src/transformers/models/dots1/modeling_dots1.py": "05d5c96c50096e4681bf30827b0eb656f56c511a883d912166bda19d7ade7903", + "src/transformers/models/dots1/modular_dots1.py": "381cd5bc6bbb4724df36a7c455c75ea0a6ce5149852ba298d47b6546b2c4c1b8", + "src/transformers/models/dpr/configuration_dpr.py": "59d713cbdd6f1be276c8cd515bd03653f9354b6fce9900d85758ab338f74a02d", + "src/transformers/models/dpr/modeling_dpr.py": "b8d5856d4a1d1e7abfd5dc55b4c7a5da5b8f7bd365fe19ce832b944093bb1566", + "src/transformers/models/dpt/configuration_dpt.py": "d90a7848c48c97f93cef4de36640b32ad9c5c9ef2c57809dfacc0cbe4508d5b6", + "src/transformers/models/dpt/modeling_dpt.py": "f9e205ccfe355d6623e3abb6369590ac121f9a41af4bc8abb09320f76a743114", + "src/transformers/models/dpt/modular_dpt.py": "6e7c3269fd38f84c7a03c1a67b3092ce36990eb6744ff559ab450bd08a583e71", + "src/transformers/models/edgetam/configuration_edgetam.py": "b2459a794266c3e394aaca7b5f4c17992f3eacf615db7a8611b096bbed4fceba", + "src/transformers/models/edgetam/modeling_edgetam.py": "5ac7d7798397ecbd8cf337d4e3e3f56e67899dba0f6d5bf546251f0cec990028", + "src/transformers/models/edgetam/modular_edgetam.py": "34b875d10c8a310b1bf734d771f9161c756606e4774f4de4826d70ba131cf391", + "src/transformers/models/edgetam_video/configuration_edgetam_video.py": "86f8eed2cf7a2ff4c6aae00529e056408c0f335df2930366ce95f37888c9f50e", + "src/transformers/models/edgetam_video/modeling_edgetam_video.py": "9fd826a6751ffc942f865f2b39dadd2b10a89f6751a6186cd7ab51a828d5e2f1", + "src/transformers/models/edgetam_video/modular_edgetam_video.py": "bf2abc99e6926345aefd7a622ff58d074cbe0250f7603df5ca3ab372357316fa", + "src/transformers/models/efficientloftr/configuration_efficientloftr.py": "268e037030682dfbd74f9394e87869ec75a2c4b86b4d995f3fc1bc894a3a1fbd", + "src/transformers/models/efficientloftr/modeling_efficientloftr.py": "d4f097fa7ffc0199f43c244dbbfdb042202be58ab8615495768350ed7c7cb3dd", + "src/transformers/models/efficientloftr/modular_efficientloftr.py": "10ba603a0542d60b3c12775862cd5bfa25bea57d42dd83d726b18f33f467bc3c", + "src/transformers/models/efficientnet/configuration_efficientnet.py": 
"3ea809766f7a479717d8d90896a29c1e7468c85e1712a9f3ba2fd84ef0547057", + "src/transformers/models/efficientnet/modeling_efficientnet.py": "31a5e2555e9784a5be395ec10755131699ed236aa37aebe12d9c7c3ca8d1f964", + "src/transformers/models/electra/configuration_electra.py": "53e603b02d3def048b998ea7dff4dc89505aa1d8af82f7832d09d0daad48b438", + "src/transformers/models/electra/modeling_electra.py": "d8bc54ad919a0d3e2cb31d765d8a501811b28b0b4f97502ee3171eb5a6b4ecfd", + "src/transformers/models/emu3/configuration_emu3.py": "edca6bff04b5bc3f4efdf51edee84b673f1ae5b2bd25048af61693e6e3015055", + "src/transformers/models/emu3/modeling_emu3.py": "6d1b9d472d94aaa1ec4bc36a7ad7bad41f6e59a10137612c93f390efb293da1b", + "src/transformers/models/emu3/modular_emu3.py": "4e6dcd4810d409445bebed5531e6812cb484619427f046a4d324e14d3140e20f", + "src/transformers/models/encodec/configuration_encodec.py": "cff01a374c1b9f206d3317014574d5715b211c77914a75cf2ca9c641d33b5f40", + "src/transformers/models/encodec/modeling_encodec.py": "39ccce4266ceecf45ff2aaba2743e1b9aa5937914214ef6bea18ff136d30973c", + "src/transformers/models/encoder_decoder/configuration_encoder_decoder.py": "585faa9249a5f69b2a512e8ebe042115d20c9aaf8019653ae9cfa1adbd0a8dc0", + "src/transformers/models/encoder_decoder/modeling_encoder_decoder.py": "fbf66b3d8909ffb835022cc7ad6f53bb6d2476a0c71291402dd5ec393e01ba96", + "src/transformers/models/eomt/configuration_eomt.py": "6608061c3f3d8c7f5bb866bba9b77c31c123c18be7ee67ab7dc46bb44666ce7f", + "src/transformers/models/eomt/modeling_eomt.py": "2e931b265494c224a0898f5f9e10549996b0bce879052f09038a1146ee7ca2cb", + "src/transformers/models/eomt/modular_eomt.py": "aaab107c5f91c5b67a35c7f885edfb84feeebf9c97b65aa5d1161df29f405451", + "src/transformers/models/eomt_dinov3/configuration_eomt_dinov3.py": "5ad8a03cb1d95bd9927faa261084ba36d5663edd37216537f5f5bb1e196d5349", + "src/transformers/models/eomt_dinov3/modeling_eomt_dinov3.py": "09061ee86dfe85ed20d1105565215a7dabe4cd50fedd3ae8ec57c9075b9a8684", + "src/transformers/models/eomt_dinov3/modular_eomt_dinov3.py": "abe6347307a7896f7edbc918ba4102dba19adc404056d6b0e3abd1272e5ccc5d", + "src/transformers/models/ernie/configuration_ernie.py": "65957de9d59960da78f00c5e1b2d436f574e4034099b93292567b656e71d2e8a", + "src/transformers/models/ernie/modeling_ernie.py": "b3d71e2b93ddbfdc38272f600cf1b57b7bea8f02db7751b13c1aa0ce7fb59d7b", + "src/transformers/models/ernie/modular_ernie.py": "5807a39769ea643f3e2f92d6fff5cf68888cf0d0387aa6002f6d7fdc2d2a2366", + "src/transformers/models/ernie4_5/configuration_ernie4_5.py": "c91ab7b3c45023b7ad24f2334be76d39bfd0df25f776366a15cc951513206587", + "src/transformers/models/ernie4_5/modeling_ernie4_5.py": "c50eea1ca81685f6933be53c76aa2e4197d4d7c9ca1f527a0480dc22b01f5eab", + "src/transformers/models/ernie4_5/modular_ernie4_5.py": "974be5f0a88e34664fc04e09a8131b28cf8fa31bfbef2ac5bf3e88ef27c7eb94", + "src/transformers/models/ernie4_5_moe/configuration_ernie4_5_moe.py": "d68328c29f38f1dc24caa1e4bc68a1b09d29797c54cecd74c4f48de365b52684", + "src/transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py": "394cd12396d84aa2bc4c1b48dee65d19ee445cd1bc9acb3aed2bfca2c623ee3f", + "src/transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py": "d581c3267f586a3ac16e55000dfba777b823f5f97742a403c32262e3501d102c", + "src/transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py": "1ec22399dd386689e7bd4d027c044e6da74bd21e33f2c7abbbceebd5526ab924", + "src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py": 
"27fbdd60aa982ef09ff8990a7342033b9be97b77d9027a48332ef7134144d79d", + "src/transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py": "4a8dc470910160bb2363915f49da199ee4026e40936a342f7731ee36803139c6", + "src/transformers/models/esm/configuration_esm.py": "581936d5f9791a3024d9c7776ecdf6aaaefa6603abc51dfeea9c01c9c20cc4b9", + "src/transformers/models/esm/modeling_esm.py": "9c83f87f554820e8bb5a2b7b73a455dca54ac48b21e7f3a5c34b7e042cc332aa", + "src/transformers/models/esm/modeling_esmfold.py": "6bced2ad7502f73bdb8026d817cb0bcc7e7d2c787c4bd6d25055bb0f8afe25c9", + "src/transformers/models/eurobert/configuration_eurobert.py": "3686f8d2537d3c0e7a4357f43ab56efcacdd302e0f1162d481b39412617e1423", + "src/transformers/models/eurobert/modeling_eurobert.py": "aab3bf203d1d5ef3e89299a8b62553ba0aec80e994a7c875cdc952df4c325b14", + "src/transformers/models/eurobert/modular_eurobert.py": "30f8134a3974a4cbba52e9c13dad659fb80986ee9ce4768e23b7eb266aca5ed1", + "src/transformers/models/evolla/configuration_evolla.py": "1480f8d162dd26f21a56fc83b0e24093d5bb41e88747dc5d86cd13341c1a3c66", + "src/transformers/models/evolla/modeling_evolla.py": "b75f66f3e7cf6ce873f8019862971a671b5f4e17b0b7c88d30685e59b3db4c9c", + "src/transformers/models/evolla/modular_evolla.py": "cba489702e1543131603015b805a84982e94192041acb80f7f1ab1a6582deea6", + "src/transformers/models/exaone4/configuration_exaone4.py": "6e5c461f84881ec9835633d506815ef2926e20581545367c1d3af5f3c8964cc5", + "src/transformers/models/exaone4/modeling_exaone4.py": "c9992e4049d7bf4022392e0c6c9391d6fcfab3cb486b7a89455324ba9f5f9562", + "src/transformers/models/exaone4/modular_exaone4.py": "60ae4bd5c733f8769d8cf466a7af1efb265e89570ef25ab46162b16b3ecfb8f0", + "src/transformers/models/exaone_moe/configuration_exaone_moe.py": "9dd01d8361be9d8c71a85c8746fecf6664ca1e703c2ba61c657439aafce576cf", + "src/transformers/models/exaone_moe/modeling_exaone_moe.py": "a56541292aa6e82de28b405bdc62272be44ef73c83ce56a17cd06fb43e597bd8", + "src/transformers/models/exaone_moe/modular_exaone_moe.py": "1749ec26996b110a47d61fcbdaa4e738e094c48eaf1d7d5de02974324c3044cd", + "src/transformers/models/falcon/configuration_falcon.py": "56c34229dfdf286a23f571db23114247e1cbfd2761234b46e4496b15667bf3df", + "src/transformers/models/falcon/modeling_falcon.py": "859b28338f403d6b8ff70c713006712fd1d4b38c7548109d459c253b4d557a5e", + "src/transformers/models/falcon_h1/configuration_falcon_h1.py": "009a43ec8826b15120d456f01ac9f03f31f8e419f3fc2ce415afdbf3b0872c11", + "src/transformers/models/falcon_h1/modeling_falcon_h1.py": "e316f6397e9048fd91650aeb9c0365be1f274af9e4a2a939c92296e7b19f3962", + "src/transformers/models/falcon_h1/modular_falcon_h1.py": "9417d5f47849a74bc76a1bfc4143488e1edfd39a7839ed1a2c37a521af6ea682", + "src/transformers/models/falcon_mamba/configuration_falcon_mamba.py": "0188f365f950e47ea2740c4c5c7cc88c57a1d5f9ad593624b891f81bdae1bf34", + "src/transformers/models/falcon_mamba/modeling_falcon_mamba.py": "794dc45783ebd5b09e058eeccff800ab551df1a6da29a1b4cae45d0ded8229ad", + "src/transformers/models/falcon_mamba/modular_falcon_mamba.py": "3c724cc5ed6eb8111a67c72970f401f4b976a77a0014956759da4d7ae2995655", + "src/transformers/models/fast_vlm/configuration_fast_vlm.py": "54da3584a0e82d9d7273832f31eaaeec87ab945e32b630059573d44354610523", + "src/transformers/models/fast_vlm/modeling_fast_vlm.py": "3cf26be59ca4d52e3775dd7bd11c397d898fec1a89642fc5a05fae4e9e1bbc3d", + "src/transformers/models/fast_vlm/modular_fast_vlm.py": "45098413ce521bd5879c9069a153eeaabe1fe2921257cc6c2c28df388771bf50", + 
"src/transformers/models/fastspeech2_conformer/configuration_fastspeech2_conformer.py": "ee554313d2db445a950229a8b40da219d89bd7c554b4dc899f4b04c7ae94de6d", + "src/transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py": "ffa5010ca7b592883d770e3cf74fbc1b047ed35c2640653bc39fbb999ed75a8e", + "src/transformers/models/flaubert/configuration_flaubert.py": "9633a95754e1b5d94e8ef39247df3e5c66f9685172eeb74471dac554c9229109", + "src/transformers/models/flaubert/modeling_flaubert.py": "6e41683ead49a055b55c4f0b5da8c033003c7361bcddda9944d0c4c606b65348", + "src/transformers/models/flava/configuration_flava.py": "5cc1f71e11d317a702f1f780705b972730fca26f9fba4be31a245413765f5a58", + "src/transformers/models/flava/modeling_flava.py": "9aceeab0106a822a504116997c1f651657f049e97a15716016643751088746d8", + "src/transformers/models/flex_olmo/configuration_flex_olmo.py": "8f14425e503c491709d19e28119d298ccb2f3370cfe8f1fc416d309f9582b01e", + "src/transformers/models/flex_olmo/modeling_flex_olmo.py": "c5a17d5d9b396d4f478a60668dca5dc3821b59716912962fc593e29a273c722e", + "src/transformers/models/flex_olmo/modular_flex_olmo.py": "45f30b06117144c38845c3cb1bc48c6c72f3bb3765da2a815fe8e9cc205d275d", + "src/transformers/models/florence2/configuration_florence2.py": "f36a93ffc5098b805c4c84aa0ceda1b74bc89eade5f52d68937b641a7b24d3f9", + "src/transformers/models/florence2/modeling_florence2.py": "26cd42ece585ab1e3a0a897b39af68c3a1c1939844121d432595c7b7b106f2fa", + "src/transformers/models/florence2/modular_florence2.py": "773d24b542cb6127230dabdb687c2691b9baef8c4db2d87e9ca7fd626a3780a2", + "src/transformers/models/fnet/configuration_fnet.py": "ab38cf5111ead363a8d78cb0569e9e697dee4e59eb58104664fc0075e04b7b20", + "src/transformers/models/fnet/modeling_fnet.py": "83ce4b3f477d3d86d17b2db152b7ae4b7de48d5b55143c025045a3ad0e70f4eb", + "src/transformers/models/focalnet/configuration_focalnet.py": "fa915e41fd54d84fc00a7b9a57d78e280ce6cbe26e11a4548d7ff5084239b457", + "src/transformers/models/focalnet/modeling_focalnet.py": "d44f280d6e8ad9073250283f115d31e5e6a88b1135fda8c8943a7576edb129ee", + "src/transformers/models/fsmt/configuration_fsmt.py": "1df5942db321624116b9701d6b88b1b54b3468adcd56b2314bccbb945f5d2f11", + "src/transformers/models/fsmt/modeling_fsmt.py": "cebba339f509a01042df22043223fa79f31f915507a9dc394b2386a26ed7eb18", + "src/transformers/models/funnel/configuration_funnel.py": "2452d83f9e33aa1b985b1de72023900fa26d3acbe17c4339aa6c7785888e30af", + "src/transformers/models/funnel/modeling_funnel.py": "855d3a21b0ecc1d9bb2be7391cb052813231ac83e6343fad7575ca51062f5f28", + "src/transformers/models/fuyu/configuration_fuyu.py": "9eedd6492e19520972b70a720da122f2d7f249a2330bb0e2bd5d1f5845236f4e", + "src/transformers/models/fuyu/modeling_fuyu.py": "6e31a1c3cccb22b7f76ec83a5c784b547bd46277dde5949fe55dc71ff8c89fa7", + "src/transformers/models/gemma/configuration_gemma.py": "9117abb75749da1246cf7faa985c46411b2a91a1e93a1b0e4dc977d1654866f2", + "src/transformers/models/gemma/modeling_gemma.py": "f6f0cd5a7b7a1b895476abf3e5c2c88246de2f6e14d6a38863333452a25fab83", + "src/transformers/models/gemma/modular_gemma.py": "4b7d117bb6e128c2588bf7525e19c5df437a93f2265a1271d3dbbee683420d46", + "src/transformers/models/gemma2/configuration_gemma2.py": "8c643de0c94443194f8de3eb6f37ab6156ecc5889da7c7ab20241d5b78036e2e", + "src/transformers/models/gemma2/modeling_gemma2.py": "153226bb60dd43429842901b1b6ce079994e4662786afcafe02a02f5cab79c67", + "src/transformers/models/gemma2/modular_gemma2.py": 
"6fea0d4fcbbaa6d1a42e0d846beb907ba957f9a5b7d82abdf578db4775d7aafa", + "src/transformers/models/gemma3/configuration_gemma3.py": "069aeec9a2bd0501628122f515eea37b20873c13d031ff4374004c99cd939406", + "src/transformers/models/gemma3/modeling_gemma3.py": "9ae9fa36006078aa84b03832812ce92ab8628d9683f33db59a8c55a8b0277334", + "src/transformers/models/gemma3/modular_gemma3.py": "d0acd36c524bcc2ac352f5af8aaf0ab29446755799e33e7ba9e5d08d757f8e9b", + "src/transformers/models/gemma3n/configuration_gemma3n.py": "6d4a70a2fbb10aa6a822120e978b70708c38b6c699c48110dfecbb5f723e80ee", + "src/transformers/models/gemma3n/modeling_gemma3n.py": "1421dc604bb27f4e133167f547bc3e22fdbb18ff2a08a3f382bcb97ecfa46585", + "src/transformers/models/gemma3n/modular_gemma3n.py": "56543f8ec5ec59b503ebd898f4c87d2ae908d808fd2c443fe1865800d8d92054", + "src/transformers/models/git/configuration_git.py": "39456fecfd996b4fee61ab70b251746127c39f2b7aac152a0153289dd899427d", + "src/transformers/models/git/modeling_git.py": "2cdaad67d7f3807b35d84fa126e071e5bd460c55d6c3a6c5ccbddc3972273155", + "src/transformers/models/glm/configuration_glm.py": "43fa4d49845faffa5175d7450a3a50f4308d35aff20391f4a2119b5dd881496a", + "src/transformers/models/glm/modeling_glm.py": "c0af56314d99fc4423a5b10b0cf8251eca4582f2395748fdf229c800d0b7066b", + "src/transformers/models/glm/modular_glm.py": "cb95d3613d792ec0d2ef30a4959912bd63336b00810df4edf36da203209c8dfc", + "src/transformers/models/glm4/configuration_glm4.py": "599250f9d3f895eaeff4c0b70426e14bacbf57f254198a9792cf9b1e16aa1966", + "src/transformers/models/glm4/modeling_glm4.py": "a8f1ff1fbdfee288a6748d837b914f881500e79ee3f0609f992a697f51328132", + "src/transformers/models/glm4/modular_glm4.py": "6bb30f16155e5fec75a99be4d381b39219de5268176978cacac6b1e1feb6d1f0", + "src/transformers/models/glm46v/configuration_glm46v.py": "75c8efede9ebdacb35b26c2824dd19e9768c50a56570afb2f0f1d5afd72c900f", + "src/transformers/models/glm46v/modeling_glm46v.py": "1f85962df2415a7ff7e070bd4796b8de5c378bc15a6d3b0fb131cb4b81133a4f", + "src/transformers/models/glm46v/modular_glm46v.py": "d2020a8e457d63fcd87efc555464c3aa5b6352793aaec5b31b768de2b552a091", + "src/transformers/models/glm4_moe/configuration_glm4_moe.py": "b32ed9b4f63980c565b92874337e00635038bfb33dc0d455d8005aa1e48c178b", + "src/transformers/models/glm4_moe/modeling_glm4_moe.py": "5d5cdc7788034df7d8926696ac46b0647cde64dc258f0a684a26da7ebc2952e6", + "src/transformers/models/glm4_moe/modular_glm4_moe.py": "a993ce9117f32e7130f99df69950972fd21bc75a9bfdef61a15a19252290957a", + "src/transformers/models/glm4_moe_lite/configuration_glm4_moe_lite.py": "6d1dc311c950f2a776f28a0c3f02bf17401b03a51fb8bc3bfe65fce1b3edbf73", + "src/transformers/models/glm4_moe_lite/modeling_glm4_moe_lite.py": "ed6194c48cb2b028289df3832638291bdc17a7aeab85b0150c3881b9390629b9", + "src/transformers/models/glm4_moe_lite/modular_glm4_moe_lite.py": "5c3d92595f33af03664adf236724ef644378db6d0f72693d0acc86507c9ae4be", + "src/transformers/models/glm4v/configuration_glm4v.py": "e049bf3a2b10dd119d992a0a609c143ed4e5693b44b01447dda731c9fd7c0844", + "src/transformers/models/glm4v/modeling_glm4v.py": "5d054840dd92fade410bd45326b096ae93e8813cd7a24ce346dce885b09578e6", + "src/transformers/models/glm4v/modular_glm4v.py": "3b811dedcc2d858c68656606667a3b14a34d2013aba8d11f5f8ee179caa35568", + "src/transformers/models/glm4v_moe/configuration_glm4v_moe.py": "61f35c68f19a58c0766fd53a0b5d001bcde798fd72c29d3169fe0f6212272c23", + "src/transformers/models/glm4v_moe/modeling_glm4v_moe.py": 
"2911f91ba04e0d2c442a490d812660bfa1f4009b1ab2af82b388ab1efe705a9b", + "src/transformers/models/glm4v_moe/modular_glm4v_moe.py": "54fbc2552829948bcd8c5efe33f068c11a998b1a6cd3fdf4b9c9a5b7535d654a", + "src/transformers/models/glm_image/configuration_glm_image.py": "64177a6c429bd6ea6014b3840528c565b2dcdb7572f1e617803940958b90c1c1", + "src/transformers/models/glm_image/modeling_glm_image.py": "09be965e23768671416f9f977412f7585e929de91e2079e817b2bb356cd155b7", + "src/transformers/models/glm_image/modular_glm_image.py": "e0c45ff71eb3ed9313a702c485967a1891f6db5c67c7b7cb75808803ff931431", + "src/transformers/models/glm_moe_dsa/configuration_glm_moe_dsa.py": "4040ac9d39fc01e13d3007790a38ee2fdc0ced1b984b47554121ee1247044cb9", + "src/transformers/models/glm_moe_dsa/modeling_glm_moe_dsa.py": "5767b3af44aa97de4ca3fb3a63ca324e95dbf959c96f66c138e2d1391367089f", + "src/transformers/models/glm_moe_dsa/modular_glm_moe_dsa.py": "5c7142fabfe08a6c30373b8cda0064f3ed892cadce2b8c012802db52cf2c58d2", + "src/transformers/models/glm_ocr/configuration_glm_ocr.py": "b66cb6953cb6cac27eb554213498b8067c01eb37b8164947eeb472b79cab937e", + "src/transformers/models/glm_ocr/modeling_glm_ocr.py": "4b9c8adb4237c7869e14e1d6205cbf8ee38ee83b2d2d4c9c893e2af85f60676f", + "src/transformers/models/glm_ocr/modular_glm_ocr.py": "e1c85fc6db323fa4e1817c1f5f3c87806f57a11671cc86df5afd3bcb28725336", + "src/transformers/models/glmasr/configuration_glmasr.py": "ede1c1daf302302396b7f15bc146057dfbada127e61badc63991aa94395d9c3c", + "src/transformers/models/glmasr/modeling_glmasr.py": "9cb6b61f223e43c83df9e0c7db1bc6069746d258d50a90ce30a0cf7f60a9e33e", + "src/transformers/models/glmasr/modular_glmasr.py": "e7ee360bc3da3d3076507dc49a8ec2d973f975b08a1f45521436656a7f04ef74", + "src/transformers/models/glpn/configuration_glpn.py": "da8ce2caa9107313330a12a5c27db142a5a231524e31d7ac0c48b7c5e7ab47a6", + "src/transformers/models/glpn/modeling_glpn.py": "031f7211ef5cc993828f15c2bef5645f02d056d12855ab8a05e5aa2984649bf4", + "src/transformers/models/got_ocr2/configuration_got_ocr2.py": "9eadbf6daf60d57c549ed64a42232a8fbf7396fea8894337ba85187974fb8cc3", + "src/transformers/models/got_ocr2/modeling_got_ocr2.py": "91402025896744aee614d90a0dd2f1e6bdd201286e9a87a1ef2a12ae73b0217d", + "src/transformers/models/got_ocr2/modular_got_ocr2.py": "ba936274b071e7a2575863a22cadb6a18c3d39a19b6b0ea96801b8986dd4e130", + "src/transformers/models/gpt2/configuration_gpt2.py": "6877fb4366d14077721637986c3bdff4f9b4e32d02330a8a892f9843b8573ff8", + "src/transformers/models/gpt2/modeling_gpt2.py": "7bd1b14bda833ae55d6ed7767f323d2ee5888c79651727beaa76a58f1537b91a", + "src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py": "bf6cbf8fb239384a9d84a33c8a6361bb0dd74b3568f0d1f8817632134f6e34ce", + "src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py": "c0f3214356a6b0b9277fbb0b0ad95295a62dd9570af424df962962dfe69c5374", + "src/transformers/models/gpt_neo/configuration_gpt_neo.py": "0a93099ec341ce55c7d071e8d55d52ac586d3f3166056de1cf627e2e63721b36", + "src/transformers/models/gpt_neo/modeling_gpt_neo.py": "8602fd2a7d120eb90126a62711f52588c24a04896b4824c40499ef6f167c4acb", + "src/transformers/models/gpt_neox/configuration_gpt_neox.py": "178ba1009107a8f7f49c7bbdab403dcbb6e960328d937b2318f9edeb75938333", + "src/transformers/models/gpt_neox/modeling_gpt_neox.py": "a1784e2ee3ff4081e2b74ad1386d479c319565f9cf3e3d3e2863d49c021a28d8", + "src/transformers/models/gpt_neox/modular_gpt_neox.py": "827b035d2a824dbe6a112026b672341e555284b79a33fe44fe72015f0e36d5b8", + 
"src/transformers/models/gpt_neox_japanese/configuration_gpt_neox_japanese.py": "7544c340ea49ec3b6c1dddf0aafff3583cc120bd644b789e4f551a0a495598ef", + "src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py": "248f9f143d93e0e0177b2a42d118b56e838bcd257115098e4c18d664a4791d99", + "src/transformers/models/gpt_oss/configuration_gpt_oss.py": "c2684f4858e2f52040c644be163c4372660a963064cc92c3cc8f32d3bb3ed894", + "src/transformers/models/gpt_oss/modeling_gpt_oss.py": "a394c9335c13abdf370acbd93b97f1463093c731b0d77727ad7249d26ace67a5", + "src/transformers/models/gpt_oss/modular_gpt_oss.py": "0f81b6ea5c48d5105afb942db63d2425b93ed364d4f0db31746a07e471e88a2f", + "src/transformers/models/gptj/configuration_gptj.py": "45715867e35816d1d9cdfd36e3633cf50ad789abb7c130ea3a012849ec0fceaa", + "src/transformers/models/gptj/modeling_gptj.py": "c406949259ca0a566ac50d9ded6ad37124736fb5815352b842b2ee8354599a30", + "src/transformers/models/granite/configuration_granite.py": "ee91ee043646a0463e5726aa45380f81872773303b04b62112423dd15b003daa", + "src/transformers/models/granite/modeling_granite.py": "75dc96426b1fbb8409aad24d03705e6182c4a03a07b6b1f59684931bfb1b278c", + "src/transformers/models/granite/modular_granite.py": "0fcf9f2dc71096dffff2dd601c8675c04759a115e0720899618e492e17d7d880", + "src/transformers/models/granite_speech/configuration_granite_speech.py": "2f11afb83755fcf2c29e252b455d3dc1e67dc63d930cebc0c6f13f717ca8a987", + "src/transformers/models/granite_speech/modeling_granite_speech.py": "71f83aa1be0f2180097db8ae51aa5e70148b56b71cf70cf00ca6caad32b27ad5", + "src/transformers/models/granitemoe/configuration_granitemoe.py": "59e4336121378c1f3348ed7c6f625581eddc65263dc80b9abef39ec5ac863e48", + "src/transformers/models/granitemoe/modeling_granitemoe.py": "167e417c8452915b3b35479e305be9112b3a0d214f7f26612c7d518ef3b158cc", + "src/transformers/models/granitemoe/modular_granitemoe.py": "4efac9f589ad00225d6fc7aab43458a3641f4d885311e9f9171526aa9c765846", + "src/transformers/models/granitemoehybrid/configuration_granitemoehybrid.py": "03b66a5fdea23501a672728210ab208664b2106aedec178e325fac19859b3b28", + "src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py": "c9de882cd503aec8b4f263af764804c5186a5a7ef710f2b8f5e8659486ce7556", + "src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py": "179f76acd59c28c2a8cdd46168271aab6f680f3c036a19a987d73e405cc2b60e", + "src/transformers/models/granitemoeshared/configuration_granitemoeshared.py": "a5a1a9bfb0397a23b38fa3e0a820a7f081bb6d6eda06092fc8e0009ad35d84ab", + "src/transformers/models/granitemoeshared/modeling_granitemoeshared.py": "e6f38dca72be7bd1fa03dbfa6424f0ef3ba3ad55d67f139983751c51cbce90c6", + "src/transformers/models/granitemoeshared/modular_granitemoeshared.py": "2cd68110a89053c2238d528c1c55f93b236fd46c5f62b63594eaa4ee14ef9f8c", + "src/transformers/models/grounding_dino/configuration_grounding_dino.py": "ed1312506a65973154683a634e0b9cd6ab29c7719ddf5a9de066fa48a29c365b", + "src/transformers/models/grounding_dino/modeling_grounding_dino.py": "e4ccc8846842d1ffef1610f92480223e95e37ff962123ba1fff253f4743a7cf5", + "src/transformers/models/grounding_dino/modular_grounding_dino.py": "d15ce3d472bbe9381403eb608dd310ae70fba2e7514024d90f25d5d98234e332", + "src/transformers/models/groupvit/configuration_groupvit.py": "b2de28a065806695ec39880677be9e10ebcdb42e949783b4a6ace06b9a1cac50", + "src/transformers/models/groupvit/modeling_groupvit.py": "c840cf7631f2a913011ffe059b8ebd3d42813f7900f2c02506016ae53e7319ae", + 
"src/transformers/models/helium/configuration_helium.py": "39dcae798b9ef0bd667f371d24de04d82e3079b74073993c01751a7d4228e236", + "src/transformers/models/helium/modeling_helium.py": "80375332e3e29f8c795f300cffc8b1be7082cab18b005e2db7e247bae69937ec", + "src/transformers/models/helium/modular_helium.py": "2b169490ab056767e54c5f92a7d0a5a326a82151d4a6964dc45160fd33e07c8f", + "src/transformers/models/hgnet_v2/configuration_hgnet_v2.py": "d4dc9093f63c1cf87ac7e38e5ca453bc2ac76c9c06642bbf535e4ee05f1450f8", + "src/transformers/models/hgnet_v2/modeling_hgnet_v2.py": "33a29deb272c1527d72040246ef517e5ca8e5e87a9df41cd209fb05c768c2887", + "src/transformers/models/hgnet_v2/modular_hgnet_v2.py": "f07a228e327f98e19d5d24ae1f294741c495d3bac9df6a0a755b5a626ded5bd6", + "src/transformers/models/hiera/configuration_hiera.py": "4da484da3929b8e91d64cec6f31fe041b7f80ecc624d86d31c66203afe1349f6", + "src/transformers/models/hiera/modeling_hiera.py": "5d2c1965290d053b61d5686a3e486d26a3e66e6d56c60debb25cf61795940271", + "src/transformers/models/higgs_audio_v2/configuration_higgs_audio_v2.py": "3e4cb385300016380d1a38ac19994306e8349475881ef13e2499eb95a14b6937", + "src/transformers/models/higgs_audio_v2/modeling_higgs_audio_v2.py": "db470770386fec72e71ea87736b7624af445947b7ea565b9cc15876332e809f3", + "src/transformers/models/higgs_audio_v2/modular_higgs_audio_v2.py": "9041aa6c57b719f7bd5b3ee2e7fd0ad485abd554eb7945382bf823124059ee4b", + "src/transformers/models/higgs_audio_v2_tokenizer/configuration_higgs_audio_v2_tokenizer.py": "56ef66a647e8f1ea45d16eec435f941f55470fc5ebd0ec4923e63310456ed2ff", + "src/transformers/models/higgs_audio_v2_tokenizer/modeling_higgs_audio_v2_tokenizer.py": "ffabafe79e99b48168bd04e8932ee5ab6aeba07f59f39c2bde8c1c1ed25ed2dd", + "src/transformers/models/higgs_audio_v2_tokenizer/modular_higgs_audio_v2_tokenizer.py": "eb05be1a4399904bd563f928fd33bc8e373229fe585b80877bab131aa0508158", + "src/transformers/models/hubert/configuration_hubert.py": "3be790fcaa707e9aa6d3e4cd7a45ee1c3c1965852309010788b5c867cb147bb7", + "src/transformers/models/hubert/modeling_hubert.py": "6aa0010232d6151de7f964365c318a65f74ecdb08628dfd76e0aeadc6e1bc509", + "src/transformers/models/hubert/modular_hubert.py": "3ed97905e0cdc784ca21859426b6a63e5a05362971fbdb0e582d5888fe25e49d", + "src/transformers/models/hunyuan_v1_dense/configuration_hunyuan_v1_dense.py": "3956bfaa567b8abb38062046c7e8ff7ca2161c1cf41218e4886ca7b5cc318c41", + "src/transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py": "95935d73eed9c81708d31ba507944bda5d559f6ce2fb8aab6a6d878b00a1c30e", + "src/transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py": "a0816aeb1cd148fd41d55a65aa46c701530af5574899904ac0017b7165bbcdee", + "src/transformers/models/hunyuan_v1_moe/configuration_hunyuan_v1_moe.py": "eb7d3240fa30d375894e39b2d18dda07d793ff4b5f576497f573c55b452ea9cf", + "src/transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py": "fe06f8b9a8babdecabfc25f5eba1b78060abedbd8faef303b4d10867537f88b4", + "src/transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py": "425138d046338473b9a3f8d145b3b44dd229e6f753b22e5cb773576a68bb92c9", + "src/transformers/models/ibert/configuration_ibert.py": "d26257a9642b11986ee39406468f746c88275909be34fa083fb18ef5da620187", + "src/transformers/models/ibert/modeling_ibert.py": "36ae71d50cc253a7773a4e22c37a266acb7f14ffad2b2b65b7ffb6715d599ec1", + "src/transformers/models/idefics/configuration_idefics.py": "8466386bfbec3793430335d19d2d788c25601d7018282e4bf4bd8964bb5a9b30", + 
"src/transformers/models/idefics/modeling_idefics.py": "dd4cb635a970451ea62c4c69a45c09bf9040f8996fa541772f32f16d04e73b1d", + "src/transformers/models/idefics2/configuration_idefics2.py": "a2cfb4c58668cb987fe8cb4407a4c7b2a5fe59faff89bfa599d889416efc3208", + "src/transformers/models/idefics2/modeling_idefics2.py": "3f859a770d44b5e0ec42b2efc635ca5a5dde35ceb4965480d2de008fcba6b4fc", + "src/transformers/models/idefics3/configuration_idefics3.py": "ac888b8e831910bda16a217a6620250d7d7df5e2bbfc1146dbaf0c86ff35592d", + "src/transformers/models/idefics3/modeling_idefics3.py": "d38fa8c7023a990ca1434209be664055bbc611e8b05c8ede9a655045be5d8186", + "src/transformers/models/ijepa/configuration_ijepa.py": "50eeee3e85784b43daed89d29985934b2681552fd9ef728ff447098deeb955ff", + "src/transformers/models/ijepa/modeling_ijepa.py": "f2045a773328e1d5112d24e1534d29732cb8fb35d783bca415dd14ed51de5c54", + "src/transformers/models/ijepa/modular_ijepa.py": "0d794538acdf5a079c34de2db17f2bbd01bedae6baa3ad0d4369cb528b720a6c", + "src/transformers/models/imagegpt/configuration_imagegpt.py": "69a64e7bda20c710ad3353887cc4069a1beb6b0566c389a24baf9fb7b19df537", + "src/transformers/models/imagegpt/modeling_imagegpt.py": "ae5b5e78c838af514ca7361c5dfc555447958a7a2ac21df8728431a32499130c", + "src/transformers/models/informer/configuration_informer.py": "6b58f8b04970545e81b461c37c983092d29c81088eba908b2f9a39841916a173", + "src/transformers/models/informer/modeling_informer.py": "fa8e18099395fc6f9589539b9a364db86931d0e10c0664ccc4975ba374c243a9", + "src/transformers/models/informer/modular_informer.py": "c7fdddb9fc0458d326133730f3abd6c2828ce3ccb7d8a536ff27dad401a703fd", + "src/transformers/models/instructblip/configuration_instructblip.py": "5e8b06439714e21b7b53d417bd9394c1033a4d625084f7efa608ececfcc707ce", + "src/transformers/models/instructblip/modeling_instructblip.py": "e90d855bbe23d0398d21e421c46d8f65805f9ef2ffa3f37b8236eac9fafd482a", + "src/transformers/models/instructblipvideo/configuration_instructblipvideo.py": "af943c65075c0583b5ef4bc10585c6dc77fe0da286c238af38a9e1181885680f", + "src/transformers/models/instructblipvideo/modeling_instructblipvideo.py": "dbd483108ec0fdf326b0ac22982cf698b30aeea0e793fd3645d78c794f4fd05d", + "src/transformers/models/instructblipvideo/modular_instructblipvideo.py": "200544979f197cfb23af923da3ba15ce73e92488cc098c0b4dd1821cebdf231a", + "src/transformers/models/internvl/configuration_internvl.py": "fe8195d8de48dca267db1babbbf27b68e7ccf0e12619ac6cc94d381c0e1d7fc6", + "src/transformers/models/internvl/modeling_internvl.py": "8e88aa0b4f106d8ae1db4295e9dd9852579f6961388b2e68a0a27b3bb42e35aa", + "src/transformers/models/internvl/modular_internvl.py": "9a6300baa9b84e5d9c006688f3a1d539d306d71f75c762fd7c060d8c30ef7a42", + "src/transformers/models/jais2/configuration_jais2.py": "c8fe3e6ed41c7b2789cfa40e98708c8cf6843072bbaf90e78ed385784dbf3d24", + "src/transformers/models/jais2/modeling_jais2.py": "70332ec1796d7b35605e316dc623bb3b575a9f62df06c845df59c7e797af3236", + "src/transformers/models/jais2/modular_jais2.py": "eda1a533e283498a7396c5cbda569b05a2da0874b4551620fa429312c2aec94e", + "src/transformers/models/jamba/configuration_jamba.py": "8af6cc52904b6d29d57127adbac7a9565d18c9568790cf7fe1c285dfb6b254db", + "src/transformers/models/jamba/modeling_jamba.py": "88574d672b75849e727c11805efb7db299088ae7b2ea47420e6b4f215407cbc0", + "src/transformers/models/jamba/modular_jamba.py": "91b43b0c16d3550eb9f03bf611e6d1185055fb485e79085c7b500067351b7ada", + "src/transformers/models/janus/configuration_janus.py": 
"27f32d804c63f98ab8fe13784d0e4deaf93d915e874026b6fda927bc66246d9c", + "src/transformers/models/janus/modeling_janus.py": "d0c0405fffc61583dfc5c30bdbca09ed5237a9f7648fc3082fcc4f7a4893fe5f", + "src/transformers/models/janus/modular_janus.py": "124315d13a7335ed47132e0ee9d00a1681d33e3c6fcb192a51490b617cdc1d5b", + "src/transformers/models/jetmoe/configuration_jetmoe.py": "0a94120f38233320eb1795f65aa1ea48a69c57e88a3496fb37691387da0c0852", + "src/transformers/models/jetmoe/modeling_jetmoe.py": "87aa01c2f7744028b0fde50649ac11ef0cd5ed5ac8b6e11a04f68bc5c4944e97", + "src/transformers/models/jetmoe/modular_jetmoe.py": "1611c058af53e7d12ec04237eed49d997e5a70b734728a821a3d8ea0008d35c8", + "src/transformers/models/jina_embeddings_v3/configuration_jina_embeddings_v3.py": "f92c62ba245f43a8dcbbf4c4aeb0f8ccc0873eab488aba2beea98dfb6c745b87", + "src/transformers/models/jina_embeddings_v3/modeling_jina_embeddings_v3.py": "85bded3925faa2af7c5cea61c4a0a5241d20fcf70d6bc76c1e1072961abc6c8a", + "src/transformers/models/jina_embeddings_v3/modular_jina_embeddings_v3.py": "e010ad4570dbc616fdbebfe21294f2a62c61497c3cdcb96d843e9736cd7ae7bc", + "src/transformers/models/kosmos2/configuration_kosmos2.py": "e608ada314417fe3596f7ce6c17de4e176a28d753d70784987ac76647e33dcd2", + "src/transformers/models/kosmos2/modeling_kosmos2.py": "f6cc4204fadb19893456754b9d4d9199b63938da181ffae00041e2d29c8fd102", + "src/transformers/models/kosmos2_5/configuration_kosmos2_5.py": "6c7a969abe650340f4a4c7b3f4612dd8145186a09e313de5e3251b9184a6e664", + "src/transformers/models/kosmos2_5/modeling_kosmos2_5.py": "8e9d18816d4f3180b1178eec20286b23fe7146116ea6fa3110fe770f4d188c00", + "src/transformers/models/kyutai_speech_to_text/configuration_kyutai_speech_to_text.py": "df3b14a1a047353d7df2ea4be7a7e0b7415d9a9107df00ddffd7c83c80a490ce", + "src/transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py": "d79f6d1622e316f543de0cdf140fb7be97f2f0c2e14f1328e8e1d60e0cdadf25", + "src/transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py": "54ada8c6cb8c30f17835391eccfb8dc4de40189ab77c8575d64de026aea34897", + "src/transformers/models/lasr/configuration_lasr.py": "f3bc15706ca88fc7276cf1fc34e000e5d15af0b760ca038ad03c91748fffdc9d", + "src/transformers/models/lasr/modeling_lasr.py": "a0826f3ccc742883aed4379c5781761cb86d5c051bdbd88b19c96ac785c99661", + "src/transformers/models/lasr/modular_lasr.py": "3fa07714ba48c5c75a0acccfc52565997942776375630e14f7afc2f7dcc4764c", + "src/transformers/models/layoutlm/configuration_layoutlm.py": "0c620323bb7283e81d725dec645e59b9c19e6d29b5cf3ae4fe727f28fb673a0b", + "src/transformers/models/layoutlm/modeling_layoutlm.py": "3a4fff5c809d5391b3e9e0a9a91854e415630d49f8f9a3b9cc3c3decbb9f0d03", + "src/transformers/models/layoutlmv2/configuration_layoutlmv2.py": "cc42120fe754ce9d70fba1ad4ade586bbc498bc376ad90718e81cb95fa96d198", + "src/transformers/models/layoutlmv2/modeling_layoutlmv2.py": "bff2c6e07df63c065f856b111e526ddddeaa2592fb8448163f94bba5e209791a", + "src/transformers/models/layoutlmv3/configuration_layoutlmv3.py": "6b686db372a201d892a5012a3c01d1489d3c770bfcd2b1af6e6a79c53cfaf7ba", + "src/transformers/models/layoutlmv3/modeling_layoutlmv3.py": "eced42508a051419f58728f67e87285103c50f425fff52e081819527f1d804e6", + "src/transformers/models/layoutxlm/configuration_layoutxlm.py": "1f9868377dc60967210d0e210afa763b8bbb548fb587a08a1ee460f9ca382d71", + "src/transformers/models/layoutxlm/modular_layoutxlm.py": "72b04dc0238c3994ffeab3793b01fdd13490c86eb6d0455e306f856c53698bec", + 
"src/transformers/models/led/configuration_led.py": "891c9f1ab5df5a930ad244ddf4895793b04d1719e5e4c4fb17a075f7d6815587", + "src/transformers/models/led/modeling_led.py": "d9c7c05f82cad4a0b7768d46bf03897a5e744c6680e247a297725415cd7c9f60", + "src/transformers/models/levit/configuration_levit.py": "30fabb6bec15d4eca98924d47c6e7ee373003d8adc1807b1d11a6019ed947059", + "src/transformers/models/levit/modeling_levit.py": "d9cbbbbf990729ea822985e5bf4ab5b3470e768b6cfea5f9e34770e4cf367ee2", + "src/transformers/models/lfm2/configuration_lfm2.py": "abd75f2d60437fb897c7fe18a29f6b46b0f1bb9199ed99fa333198ec71b6182e", + "src/transformers/models/lfm2/modeling_lfm2.py": "aad0ca1d374f389844dfbdc7d2b5d203ab7b65f729378ecb775a5ad0d8e3186b", + "src/transformers/models/lfm2/modular_lfm2.py": "20547db9f5818e125204c7a5cde1dcf10ccf2e355e6c016da76b56c064d26040", + "src/transformers/models/lfm2_moe/configuration_lfm2_moe.py": "3b7d1ae2fbff6f852535ba2f77bb56797ed15960a968067907cfd6d4af6b3e49", + "src/transformers/models/lfm2_moe/modeling_lfm2_moe.py": "86f07f9c909f44433383a1fd3d4de55836780eb655c4a3a551d99260f7ca1312", + "src/transformers/models/lfm2_moe/modular_lfm2_moe.py": "a729afc0cea7e95544bcfbf3c9e26df6c3fcd50217e9c2b5503420b8eb6662cc", + "src/transformers/models/lfm2_vl/configuration_lfm2_vl.py": "b4ad74bb99c172969ce90d0f891216586540ae56ae4d00e464fcc6602d5839d5", + "src/transformers/models/lfm2_vl/modeling_lfm2_vl.py": "c47ba12db8b9d2799d84852ac9740d7ef1ac989e72675a59072ad5e6bddd5c81", + "src/transformers/models/lfm2_vl/modular_lfm2_vl.py": "9d049b640f900fc1d20a6fbb89d4675cc21965d34d68f4b7a01385f4be48288e", + "src/transformers/models/lightglue/configuration_lightglue.py": "92c1257971fe300dde6f7d1027273137a179080880fc1767b363032b001314d9", + "src/transformers/models/lightglue/modeling_lightglue.py": "a1b135fd594b6c6d93b84f35a8b3145c121add4a8d2cee1d729430e5f8051f89", + "src/transformers/models/lightglue/modular_lightglue.py": "4dab9bc19ee0f29d060ad3f45c17deebba8c523d7adae6aee92b30ed8f41e8fa", + "src/transformers/models/lighton_ocr/configuration_lighton_ocr.py": "e09b1d7c2ddc07a32e1fc439cf6be058ed04c56b01f9b33eb8aafa52ad25078b", + "src/transformers/models/lighton_ocr/modeling_lighton_ocr.py": "1365e4a5fbd040df4fee3ace36c26707abb96f1cc08ad3d2ca76abe004495dd2", + "src/transformers/models/lighton_ocr/modular_lighton_ocr.py": "66674d1276e11713232afec8c1fe0afd8df6b12de4a0ea995d9d95ff3abf5359", + "src/transformers/models/lilt/configuration_lilt.py": "b559fbbf3b421aea2859ed0a4ee08ae783c4e15cdd86f8a8d0a3806255c28309", + "src/transformers/models/lilt/modeling_lilt.py": "5df3bcb7e24a15171235b7faaa14dd33c35b42807cdad1f987eb7656ad871e97", + "src/transformers/models/llama/configuration_llama.py": "c49ff892b7f62f8ebca48874e73cbc6dccb0f46be96fb60dc1f806283807cac8", + "src/transformers/models/llama/modeling_llama.py": "6f2ac1eef350e2156a2ba10c7eea78b86afe1b66c5f5ce8a6df76a285afc8fc4", + "src/transformers/models/llama4/configuration_llama4.py": "0eb9d8e88b4129e9e06b0a52b4622e71d19f1b249ece07ec3e54f8c35ef9a754", + "src/transformers/models/llama4/modeling_llama4.py": "0a8dae2ec943c843adf4dc6d578e19a8511dc23b9aad805fbe5d3730c2c9a1d7", + "src/transformers/models/llava/configuration_llava.py": "9427ffab42ee85abcd7c4773016852188c6c176ad6ca6d0f3b0fb35bcdc5a5d2", + "src/transformers/models/llava/modeling_llava.py": "83fbe20a69de34944005827149fa65090bfa1d97d94ce4e6725987660f989c6b", + "src/transformers/models/llava_next/configuration_llava_next.py": "1265f66ca602fb24ddb51d72cb33b35a510b674da85a37c3ee09a8412f97a4ea", + 
"src/transformers/models/llava_next/modeling_llava_next.py": "babe08592781923894dfa316306413604931a0389f5a10a8098ff7fae4a6ed3e", + "src/transformers/models/llava_next_video/configuration_llava_next_video.py": "800aa64d7a4558ecfc5a98b12229f15dc5c9e1384dbc15492c6402f9c7b0a349", + "src/transformers/models/llava_next_video/modeling_llava_next_video.py": "ef9818cd2fb50b5199450f565fa4d5f422cb39d9bc0a9334ef28a7bc30d62d2d", + "src/transformers/models/llava_next_video/modular_llava_next_video.py": "b883d36c20f9342a7a751c54f11f7a4584d7419600bf7a74da164ca1b73c9172", + "src/transformers/models/llava_onevision/configuration_llava_onevision.py": "24cbcac0f167945b6776f20e16595525f4e3aea47b76af884ae6fe73677596fc", + "src/transformers/models/llava_onevision/modeling_llava_onevision.py": "ba68552eb3b4a59220ca7ca2a372201361600afa33a63d791159296127e8b930", + "src/transformers/models/llava_onevision/modular_llava_onevision.py": "42d1b62db35c035f2c88437849138d40eb735e22036388b2e0a96a42e44cb866", + "src/transformers/models/longcat_flash/configuration_longcat_flash.py": "18aa599f9018a922bf6e85e9fc76d2fa4bac1b9503f49bcd72c65d96ac855ee3", + "src/transformers/models/longcat_flash/modeling_longcat_flash.py": "b45993d5b8095d6c94938898c0c5fbaa8b8e91c991df41196052ce4e2992cf2a", + "src/transformers/models/longcat_flash/modular_longcat_flash.py": "45ec389a7d08eda058b3ee5eef7e0a89c29735f7050e5db9602ae1864f8538c5", + "src/transformers/models/longformer/configuration_longformer.py": "bf48a7fc96c06ceacfbd0dbaa9693d4cc5ad8e6d4795d035706feef5bfe7103f", + "src/transformers/models/longformer/modeling_longformer.py": "f65e93d58c0269ec3de56e1ebac31ee3707f37f87ffd80a80e26f655a8f85570", + "src/transformers/models/longt5/configuration_longt5.py": "75e4053d4a5324079fcb2283cb2958c388d4c25436d8ee95668e80e79d0f0906", + "src/transformers/models/longt5/modeling_longt5.py": "a7648386915ed48297193e0da3e6f2cd7857377d700945cf75ab3bee0d4113a7", + "src/transformers/models/luke/configuration_luke.py": "4f93e179631aa338ef86d0281f032afc6b4c7069114f23bd6c4c9c46aae44bb9", + "src/transformers/models/luke/modeling_luke.py": "06c189416b7ee199f6cb843737824ff2a136e616b20e837c7c293b2c2950d228", + "src/transformers/models/lw_detr/configuration_lw_detr.py": "dd9bfc1eb2afb3503961ed1176635329b522a6621e38a9fd6723a22823babe1d", + "src/transformers/models/lw_detr/modeling_lw_detr.py": "33e966f3167c8d670a458d45cd45cf1f715d242401a9ff851ddc3c611556a7de", + "src/transformers/models/lw_detr/modular_lw_detr.py": "44073db28bb0e289251fabcd360277dd7ebed74bfd8c007100563cb6596d6e10", + "src/transformers/models/lxmert/configuration_lxmert.py": "192108fa56c4d3caf41c38fa250103ade45f1ee1184a31a6768466202f5baf7d", + "src/transformers/models/lxmert/modeling_lxmert.py": "2d4707b9eb872f00699110aa787339a4f563139a76c9a9ae13f6054cbe77db27", + "src/transformers/models/m2m_100/configuration_m2m_100.py": "3641f24e901972ce3ee59f696c2ae76df7f260e8830e2b61f66b33f48c460a25", + "src/transformers/models/m2m_100/modeling_m2m_100.py": "1d0a84904bd69f01ce72761e7d3853e47609217be59aee24cced870196ef9e5c", + "src/transformers/models/mamba/configuration_mamba.py": "676b6ac4ee81df3d931d3922d31500417434a519bdd2b1740df7eedaacbf2ba2", + "src/transformers/models/mamba/modeling_mamba.py": "1f53c50107dc26f76ad8f69be9ccc35400caab9a76c56349e98b484f3fa0ced7", + "src/transformers/models/mamba2/configuration_mamba2.py": "eda2e5732ff4a0936d303dd19fa93c752c2f836a0296a0f65ada38bb4405c820", + "src/transformers/models/mamba2/modeling_mamba2.py": "c79f65c420c4e4ecae40faf8e245944e00e72864705db7b8a62da398b370afe5", + 
"src/transformers/models/marian/configuration_marian.py": "04687c7d08b1bf00c3aa09ee421526ede6f7f9c488b578c10df52547d76c63f1", + "src/transformers/models/marian/modeling_marian.py": "dd7958f8d3161f6284b4e0e9f07acbe4067fdcc98645a3c20c405c88018bf519", + "src/transformers/models/markuplm/configuration_markuplm.py": "37c63f80ad5ab0822962659e0f2c7af1ce0ab9975737ffaad4f237962a7da19b", + "src/transformers/models/markuplm/modeling_markuplm.py": "9469040f6c9dca224f514b11d49077eb5551e389b36755867ff659013b16cbfc", + "src/transformers/models/mask2former/configuration_mask2former.py": "dff0731946cb303a264b71097ffc5163774be9ff3aac5944e6f70a69d68f8cd3", + "src/transformers/models/mask2former/modeling_mask2former.py": "bc1988c9b84d46a53d620f7e433342b0086632b405b3b1e74407a068721141ac", + "src/transformers/models/mask2former/modular_mask2former.py": "05d3b7dd795d94b1c2002a3bdccfaf4817e7e1d0274d51f3f19f167245c369d6", + "src/transformers/models/maskformer/configuration_maskformer.py": "382332dd9b4b14a368209ae023a5b6c7ca68793790509eedb78dde8978a73b6b", + "src/transformers/models/maskformer/configuration_maskformer_swin.py": "7eea22c3a809fb73e5dca74dd5d89d41caed8da4690516adbfbe7763f3447bd4", + "src/transformers/models/maskformer/modeling_maskformer.py": "4c556a0123547c493bf7be9fb5282e3592cf59c5c58504dc09fac4a32291f7cc", + "src/transformers/models/maskformer/modeling_maskformer_swin.py": "f0c7ca9256a19892c0aa43dfba0b9331700f10a2aad9e432d6666df83329580c", + "src/transformers/models/mbart/configuration_mbart.py": "215d49bba723ebef40e86935cb0715e54a724baf8f44643d4c342dde7aa1dc9e", + "src/transformers/models/mbart/modeling_mbart.py": "f342acab078db6e9d8fb9b3ec1619788e999a1406d80f9adf9899385118a226d", + "src/transformers/models/megatron_bert/configuration_megatron_bert.py": "820833f9e3ce92b3255d57e6e6a973de583e4506517144d28fedf21c3faf9cf7", + "src/transformers/models/megatron_bert/modeling_megatron_bert.py": "cdfd719bda6fb9e8d0164fd51daeed371c0f696ffb8c25cb00eb5fcb71ec6d60", + "src/transformers/models/metaclip_2/configuration_metaclip_2.py": "f39fe8ef1e1e2e7bb9a9d146a57aa33cee63e6cc38be477dfb8b064e47fedadc", + "src/transformers/models/metaclip_2/modeling_metaclip_2.py": "946f70f082f2175ce67aff547139b2c6d512bface7e618e93054376dcd6fc085", + "src/transformers/models/metaclip_2/modular_metaclip_2.py": "62c81588bc32791c7972391883410df250349f9acf6dd93a707a5ec2da415668", + "src/transformers/models/mgp_str/configuration_mgp_str.py": "c2e885be62ad4f06543a0adf0b9496298f9b58e23ba1f6810293fc0959fe0f0c", + "src/transformers/models/mgp_str/modeling_mgp_str.py": "6b8b54afe5067ec5b428693901da9741952b903b19344497fbea665e38207bed", + "src/transformers/models/mimi/configuration_mimi.py": "ee63e6b5311ea002b707c3af7dffda26d23525bffd3f98e3fba6a6afa8fc6870", + "src/transformers/models/mimi/modeling_mimi.py": "439ff84320e62279b7e462679f3fe714e7c677b880b2561f4e34cf2800409294", + "src/transformers/models/minimax/configuration_minimax.py": "70bc96d45ebd466cb3c841df0f14a7aa1d18c82a29fad9b2476252dbf7240ab1", + "src/transformers/models/minimax/modeling_minimax.py": "ef591877e3be0c67a4beb218e64111df54a6e3c261520f8941b099f1973fa195", + "src/transformers/models/minimax/modular_minimax.py": "f8ebc1fd0c77dc2e4a85b86ca9a7505e81ddec189f7e1cd419dab3a0621a5fb2", + "src/transformers/models/minimax_m2/configuration_minimax_m2.py": "81b3c38a8376fc54c0783ec9eb456e38d2957df89a44cff9a579899e04cf14d4", + "src/transformers/models/minimax_m2/modeling_minimax_m2.py": "4eca7507d3398335c23cfc85dd51a6b2957453ead26baf78a7e791643e3a4ac9", + 
"src/transformers/models/minimax_m2/modular_minimax_m2.py": "d4ade009e1a324a659d34d7f71d9f1a90dcf81fdd984a2bffaa7edd2708e22ae", + "src/transformers/models/ministral/configuration_ministral.py": "317a30b98b70d1b6d6455b72ac01b35953ea66c4dced8a929e5710f94dfbf60d", + "src/transformers/models/ministral/modeling_ministral.py": "b15911a156abc1c8b26ba6723946b93fc871d6cbc25467ed9ce1baef3887a7d2", + "src/transformers/models/ministral/modular_ministral.py": "20608822d41512fffcedc77477a9a189b821ab440db56a37df4cbe313a1b753f", + "src/transformers/models/ministral3/configuration_ministral3.py": "7fe0f928d9726f384a6ed80fd887c957b3534f6f7b468cf66ac4f5866f42435b", + "src/transformers/models/ministral3/modeling_ministral3.py": "cee5f45376a45e7701eae72637ff91cd20c22bc905a2bd905175afbb23644c2d", + "src/transformers/models/ministral3/modular_ministral3.py": "b0dcb37dabf1c18144be3a0b0ba46469c91d79b7bf2b650df8dd6e9ed9a0d926", + "src/transformers/models/mistral/configuration_mistral.py": "e335825263bb22ef018c6687fbf4dcb7a7b80dc5efbcb60946900a176a8bcc9c", + "src/transformers/models/mistral/modeling_mistral.py": "231f9d5f74e338780d6b5b90e6b1d34a4332778e075de8f03dec09416dbbf491", + "src/transformers/models/mistral/modular_mistral.py": "b52e48055ab011fb41cbd700967966c970d083ad77ea1b22a92907dc861f19d7", + "src/transformers/models/mistral3/configuration_mistral3.py": "d5c54eba59c720bb8f4c8657a429d15149c144d4ba5d4374de8b8ae5c63d13a7", + "src/transformers/models/mistral3/modeling_mistral3.py": "9288fe03a3876c45d5601a313ddee0408d878ae141ac6038b635fbad41f8289c", + "src/transformers/models/mistral3/modular_mistral3.py": "57303c2182e09daba7c0b006ce6b037a4b03fb5b606e769b095034645ffc9fad", + "src/transformers/models/mistral4/configuration_mistral4.py": "e0e0cb2724b64f26c6683b2153f9d4c48cd58cd68fa336fdec0dc6238932a5f3", + "src/transformers/models/mistral4/modeling_mistral4.py": "2d590ea9f7a781d38844e7568cc5535e6a89feedf33074c743b020f248567bd2", + "src/transformers/models/mistral4/modular_mistral4.py": "cb3984a3ab3ef7bc448c0b5850d2aa7ce1533d01ce46cd7fb0ff879f6a66954b", + "src/transformers/models/mixtral/configuration_mixtral.py": "bdd51edcd6f03fe0bb6eb8473eeb91668fe87e64bf8f66cf5bf77516e0e89647", + "src/transformers/models/mixtral/modeling_mixtral.py": "7eb898ca25bfa911c56779b90e6750ab59a1b0bf772ffce688fef53a0850ccd9", + "src/transformers/models/mixtral/modular_mixtral.py": "59b4473a0b3951c735b59571b0cb8268bf7f55e90a6a0a163ae5d6229c7ac262", + "src/transformers/models/mlcd/configuration_mlcd.py": "28d23bafe62c03828e6d676f650ee05f23721581a955b04302dc5b98e99bb3d0", + "src/transformers/models/mlcd/modeling_mlcd.py": "45850c7798bbfb5f71472b224f071c1aec5cc8677e7ba284bff2c61e1a08d764", + "src/transformers/models/mlcd/modular_mlcd.py": "85c33ee9c21d3fe91e3f9ab8c62c17cf46c499975f5f5a3b4c493b4fd77108df", + "src/transformers/models/mllama/configuration_mllama.py": "d99751ab304e369dc41d3e5c8c9e5c310871c9a7be20a30699227f9c8267f64a", + "src/transformers/models/mllama/modeling_mllama.py": "c29ff406e8532ff0aff483bf5bd9e90696a51bce908f2c924e48f4e839dff3ca", + "src/transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py": "258ffe3cbe483b95964cfd4fb40f21ead89fecf8cc52033a443a95d302f815bb", + "src/transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py": "b2959f3505f412bdfb61bc51caccf8114faebeb551a5949ced06205134fd3895", + "src/transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py": "68f2e10d57715841e0c20bcd0ecd5ccc2df2f860e722b9e244afb12417ec06b9", + 
"src/transformers/models/mobilebert/configuration_mobilebert.py": "7820022564eaf93cf1a95ebb1ce29b9e50e5374a0b509c560a6f298164879123", + "src/transformers/models/mobilebert/modeling_mobilebert.py": "9ba449503ffb3200cacd50044ace9a6a3dd0b00e27e47dd5daad2e2d9bedb827", + "src/transformers/models/mobilenet_v1/configuration_mobilenet_v1.py": "e87b08a9dc3aa2501de203871e021cac15e912e023776de9daad1b07858ee1c4", + "src/transformers/models/mobilenet_v1/modeling_mobilenet_v1.py": "7cd6e4ae6c9932449df39dec0c47e567f66af6243ceb2a3cf41d9aafe0bf2a38", + "src/transformers/models/mobilenet_v2/configuration_mobilenet_v2.py": "996dde5c9127cd8f2c4b0a36a35f4301b99aeb2fc0943963a69dd26b221a21b2", + "src/transformers/models/mobilenet_v2/modeling_mobilenet_v2.py": "a625288899fad7c83c64a77292b206b8ecdf0d549a93ede281540dcd6bb1d6fa", + "src/transformers/models/mobilevit/configuration_mobilevit.py": "8279aa0f13a8dc035e82662267e9eaa1e86d266c8a19dcee80bcfc0bdeeb6326", + "src/transformers/models/mobilevit/modeling_mobilevit.py": "a1693df066c75d43f283931e5f04168c7dc90d1b3483baa68f771c6c0727fd0c", + "src/transformers/models/mobilevitv2/configuration_mobilevitv2.py": "01d6cafa2d65d8c63a0757d8e118eab165442a4bd465170aa9cdef32ea92bd90", + "src/transformers/models/mobilevitv2/modeling_mobilevitv2.py": "f1f88531fb683fc656aea2ce5f07ce41f3396e57696a96c14a12c08d751cb41c", + "src/transformers/models/modernbert/configuration_modernbert.py": "7b38bbc365dc81ae44cf8216a78bdf877a82a07a90d156b016f2e8bb12656139", + "src/transformers/models/modernbert/modeling_modernbert.py": "17d1c8ebdd4b38cc1f8168f1ec0f6e5ce1555133d2838283c99cc353a87ea479", + "src/transformers/models/modernbert/modular_modernbert.py": "1041c9ca8875fe10727e24c72f739258757878aeab9f081eb7af0634449ea777", + "src/transformers/models/modernbert_decoder/configuration_modernbert_decoder.py": "06aff47e32cd549d515b94e53dd2005c098d01ccbd5ac58551eb9753828b3589", + "src/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py": "ad57a59b680136964adc32dc1b46f829d7182cbb1049c06e7bd09de4a455d470", + "src/transformers/models/modernbert_decoder/modular_modernbert_decoder.py": "79a99d35d0bd37c0dc115e41302ecb1226da5c11f09e990f481039ca236cd95d", + "src/transformers/models/modernvbert/configuration_modernvbert.py": "c11fc6e15ba2056bb1e69d0f3491f74b7de582ed9012158bad15facc77b43758", + "src/transformers/models/modernvbert/modeling_modernvbert.py": "e9e17b061e32b17b578fd74494189b0bdfcebbd6f6d2bdb66a002434f17602fa", + "src/transformers/models/modernvbert/modular_modernvbert.py": "1ce0733169ba27520fe04e64b5f8063dca1dacbd484108356dc11004142542e9", + "src/transformers/models/moonshine/configuration_moonshine.py": "cd39122f595a74659fc10c186a37f8e32fae878fa1f76fc79fa70138cf6cc3ce", + "src/transformers/models/moonshine/modeling_moonshine.py": "c304fc069c70d0b2f0b65ebf130dfc7e34a99a42990df5342c03abefdb6c616b", + "src/transformers/models/moonshine/modular_moonshine.py": "814a81b2ba484d29d91370286a8864712f486bff11409671d569c3dcdb06e2cc", + "src/transformers/models/moonshine_streaming/configuration_moonshine_streaming.py": "56b3543ac4b5248b748cb74ce06b58e52e6d9f28e984876dac72d30d548fc8d0", + "src/transformers/models/moonshine_streaming/modeling_moonshine_streaming.py": "0f655344aa9b5175bfd008934c3212943c182f62a3289968e80ea4091aa6c1c7", + "src/transformers/models/moonshine_streaming/modular_moonshine_streaming.py": "ca91e325a2615d738b6ce758ff7c260bed41ae8bd3ab5bce7a0c2ce0e0cec368", + "src/transformers/models/moshi/configuration_moshi.py": 
"e33d29255de8681d404faca02cb78233913a287042f67b8bf5a1d1c01ca74fa5", + "src/transformers/models/moshi/modeling_moshi.py": "fe5d4f845dbeb52a6f1876c6595d3f196fa638d8389250a757fbe4460b08e044", + "src/transformers/models/mpnet/configuration_mpnet.py": "58f1cc39d77c499402cfd50b5fac9f2171704791bf24f9b4cd2a1dec8a4acb7f", + "src/transformers/models/mpnet/modeling_mpnet.py": "bddb7d4f855b8bb9ef74fd0b1050adacdb9e9533e5b95462acca7677328a283d", + "src/transformers/models/mpt/configuration_mpt.py": "860fe86ffe0003c7c1cd2a8457329eefbaa0bb63a1ae8d632a28a36f4a455925", + "src/transformers/models/mpt/modeling_mpt.py": "306dcda2d0ed54f8368ebd8cad7d9ea0b4485dbd007b6f0cfa516c3943f88711", + "src/transformers/models/mra/configuration_mra.py": "3fde853d0de420d505f774e989e8e2aba9586e32850bb7f7ca28d6c3cdf0ca60", + "src/transformers/models/mra/modeling_mra.py": "782c9ae454349cfbbd5928606bd9584838b9052ecd6e4bbb33129e51693fe844", + "src/transformers/models/mt5/configuration_mt5.py": "f8eb09b0441e4cb8c2346517c9f146fe3b9f1f409530c62849178d878f2e2d1a", + "src/transformers/models/mt5/modeling_mt5.py": "d150abf4fabac1413f0ef88e1b6ea26b5f9d3fa7ef7bcfad11d987a55db1c695", + "src/transformers/models/musicgen/configuration_musicgen.py": "03f0ff57e1caa4514f6db3da5d4acd7aeea0f0555a6a7fa42bb4969715baad52", + "src/transformers/models/musicgen/modeling_musicgen.py": "d2683d095e73786ad71107b7860cfb27ea7583bda9930e1b8424541c4041e668", + "src/transformers/models/musicgen_melody/configuration_musicgen_melody.py": "ae9db8975eb6178dc5995e9a91a0a32ed0354fe858fee601a8baf034570322c4", + "src/transformers/models/musicgen_melody/modeling_musicgen_melody.py": "9fd2134aaa583b4d6df5280c5bf8cd6586ec84777515946eec0e8bc09aa6ad2d", + "src/transformers/models/mvp/configuration_mvp.py": "95b7acd60185b7eb5c47564b4168dc02837fa8f39eada44afa94c43ae90bc975", + "src/transformers/models/mvp/modeling_mvp.py": "5645e5720ecbc74588c864f6557a3c712da2f6453e383f3447b3ac63ddf632b7", + "src/transformers/models/nanochat/configuration_nanochat.py": "bb4059ee31e86009cc1c710c7ff6250c321ea0934c5bdc18d87cf45285fe6414", + "src/transformers/models/nanochat/modeling_nanochat.py": "5bb913d68b0d344f7742d9e1d3d11952907a8f9a47ea1c2cbf63a4c2043fe81a", + "src/transformers/models/nanochat/modular_nanochat.py": "d4a0ff59e9be383702d049d10c54adcd823b6fa2485cffb92b33a812fa8acc4a", + "src/transformers/models/nemotron/configuration_nemotron.py": "664c096eb39b194c8de8ba3f97f81aed660d3fe66fdcf8637a192ea27f358c8b", + "src/transformers/models/nemotron/modeling_nemotron.py": "6a213b0c5ba853eeeb5947245bf3e11483616d51c1af03d62f84ab6259f9960c", + "src/transformers/models/nemotron_h/configuration_nemotron_h.py": "f56b78f77870c390191d85fd98bfa2d424a528a16d5b8b1435b87735c5866935", + "src/transformers/models/nemotron_h/modeling_nemotron_h.py": "881d820a52e35d060febcdb6ef453763a68665ad4fca317a4dd04b27f7d26594", + "src/transformers/models/nemotron_h/modular_nemotron_h.py": "28c97f5c787aea019f3dd1509f64d8f72373c1e62822feed16f430d59c655346", + "src/transformers/models/nllb_moe/configuration_nllb_moe.py": "5ebf7510398cd6012008f7560892247e1a9c8769e85d780202d0974b7de4df36", + "src/transformers/models/nllb_moe/modeling_nllb_moe.py": "fb350865097d11defbfe0f3d399a44f71811d055e941779482742d4b1188366d", + "src/transformers/models/nystromformer/configuration_nystromformer.py": "b5116d57dce7637a8f77167e5eb46eaefdd3af0c4a96286af4a1b5fda6b5e029", + "src/transformers/models/nystromformer/modeling_nystromformer.py": "e691270fc1879dd1e40c5352d2dc01e35b13b665ec7224153cbd06449ae56709", + 
"src/transformers/models/olmo/configuration_olmo.py": "da39b2a0733385a0beee7baa7b21c4da2700aad12475d87cf90808d65ace1ca4", + "src/transformers/models/olmo/modeling_olmo.py": "a48a2e94daf3861a5649a54ad7bce567f6bda0ceb33062536a179d7534f81fcb", + "src/transformers/models/olmo/modular_olmo.py": "ee0d4ab25072687fad9241cbeb6bcd4b36981bf96622f65f1e91f3536884a977", + "src/transformers/models/olmo2/configuration_olmo2.py": "896ab23279155cf5dac64ee692ee8d8b2ab9182a35d02ce93e6aa44cb4f6f9c9", + "src/transformers/models/olmo2/modeling_olmo2.py": "2c99da55fe6597cb79272971da5112e89c638d5f5d35657e05ab730ad2b9a1a5", + "src/transformers/models/olmo2/modular_olmo2.py": "4577c969451e057972ee829e62cdb7ebd00169026e95d26bb941562f764f1ea6", + "src/transformers/models/olmo3/configuration_olmo3.py": "f518f4076eb958249d903d67762a86b26993988a595fe217f43971e513e10ea5", + "src/transformers/models/olmo3/modeling_olmo3.py": "ebc6fd816956bbd319601e0e6d8b4ee11e6577d3e925ebb0a65a22600ff79bc2", + "src/transformers/models/olmo3/modular_olmo3.py": "c39bed4953bead76fc789d1bb09fd4af17d62ff60997eaaf6b535c6a715e3712", + "src/transformers/models/olmo_hybrid/configuration_olmo_hybrid.py": "26fb338aceb7e1a6d306320c64cdfaa9036312ae7e0b3719cd0509e26f82a1e7", + "src/transformers/models/olmo_hybrid/modeling_olmo_hybrid.py": "68f2dfe750d0c2dbd9833a4ca81d114935c62ac2a870f33462e655e7b14b0a24", + "src/transformers/models/olmo_hybrid/modular_olmo_hybrid.py": "85f7615f191f064ca5f24cee23a9af1ee012501e3133916ca043b86de4af823e", + "src/transformers/models/olmoe/configuration_olmoe.py": "a096aa14dc0fbad1481ce7f62f8f03d0dee860fa66a63fb38d560a43fa510815", + "src/transformers/models/olmoe/modeling_olmoe.py": "c398f44146b562fe71dae1ad6665e77faffc92365f3f8623ada885385d7fd806", + "src/transformers/models/olmoe/modular_olmoe.py": "d941507306e149215537a33aa161727245c0798230a0bde01a5540ce75f4b248", + "src/transformers/models/omdet_turbo/configuration_omdet_turbo.py": "d2d9ab0b744e8f04fc72bc176064a690868696e62e16da3a05a93330561faa81", + "src/transformers/models/omdet_turbo/modeling_omdet_turbo.py": "160cf15a5683cf746125b4b0c3c098b3b441fe2539b1931d89fd96d9ca7cde0d", + "src/transformers/models/oneformer/configuration_oneformer.py": "d7625287c165ee8e3c5ee803fefbf32916f3c58332c75a11732cfac9d7a923ed", + "src/transformers/models/oneformer/modeling_oneformer.py": "26a023e08b66402907e98abd1363fd19e26d204e05cadde570bbea4d90cce089", + "src/transformers/models/openai/configuration_openai.py": "5f480b0293fd3427bda76d56c660dea875983b74c29dc6adb670e42a96b03ce0", + "src/transformers/models/openai/modeling_openai.py": "669c67f1a5b18685c8b7a005a5add007a9124c4d00dd1d3d91402dd645033c3a", + "src/transformers/models/opt/configuration_opt.py": "6d4860fc3f7eb75ff719d8d8b46fd317f7e134b6c1de96345ff2d55c0d82aa6d", + "src/transformers/models/opt/modeling_opt.py": "9c39f2a057968e341099236bd49123366f5eeffb30644311d93d5badbf6a9da4", + "src/transformers/models/ovis2/configuration_ovis2.py": "787b8ffe5158bd2fdcd5d9efbed99db4109ed4c94b7e495293e3b3d2661a9455", + "src/transformers/models/ovis2/modeling_ovis2.py": "cb48dd26620072d278ca99652d4fe00ad94c2e056ef83bf122c05bca5f028ab5", + "src/transformers/models/ovis2/modular_ovis2.py": "367abc98dddded798f37b46dc7f643a56bd4554574b2480c88cdb75146135829", + "src/transformers/models/owlv2/configuration_owlv2.py": "6f108c6c5b853db5bca98b1997aee26d89c2b7fed5e5ea65c98334637e8ce71f", + "src/transformers/models/owlv2/modeling_owlv2.py": "a1f37fd7dfc7a97dd318eac8e3084200ac1d7661f7eefa5e5b3468deaf2fe5b0", + "src/transformers/models/owlv2/modular_owlv2.py": 
"02d5c9649b7011c53ae7abbe85ee3e9c8dc1c9098623798d6afb6d82c267a9ca", + "src/transformers/models/owlvit/configuration_owlvit.py": "05e1b308cecd557782216df969a62e3348bf17be07f83a4422422ac931d9fd8a", + "src/transformers/models/owlvit/modeling_owlvit.py": "5e1c5e6124c328c624f22c0209c874476a38f8377e61455736ea1ee6a434de35", + "src/transformers/models/paddleocr_vl/configuration_paddleocr_vl.py": "a367d2457470c4dca9c970a0aafd078e86ea3d4c143b214eb2d355f1c3b69342", + "src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py": "f6e7405da8639b77ba48d2969d5269b2bb1e690a0de076a0cb4b02cb714d7245", + "src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py": "26c67ed5c926805b101f84eb644e4bb95155f76ab083f83ca586c4aeb8c1a341", + "src/transformers/models/paligemma/configuration_paligemma.py": "e105c7d11088815d0e0f81c87a6027d3fa19d0f94408d88b0d526b7f4cf5a314", + "src/transformers/models/paligemma/modeling_paligemma.py": "8041ed488de7d6e4e729c39dc8e6447b00bed5b133a1bc0ec4fce4efe2a4c21a", + "src/transformers/models/parakeet/configuration_parakeet.py": "8094d2cb6356c9c8d0a750e90817cae05540a593e1b965f380ed1a6582f4debc", + "src/transformers/models/parakeet/modeling_parakeet.py": "d31c77fe4ab12f52aa51d67b6afa87b81eeebc4c44db7da4fd7f79e5faab3eef", + "src/transformers/models/parakeet/modular_parakeet.py": "aac30271ca7db42c545af0331926cc3a6c53cac32a5854af100c8f72da0f8017", + "src/transformers/models/patchtsmixer/configuration_patchtsmixer.py": "a05421a49697b8eeac5f72338f0104039c3d3cd0428c6b47f258d45b60ba3172", + "src/transformers/models/patchtsmixer/modeling_patchtsmixer.py": "4370829b090dae6c11b38723d8d7cdb5e6094f04938a97c846542b2389805c75", + "src/transformers/models/patchtst/configuration_patchtst.py": "aa16fe260718f90bf49cb5acac459299bbfd3c65d016771d81ed19f36c746be0", + "src/transformers/models/patchtst/modeling_patchtst.py": "0905c675fa093ea494e042a653de4232ad752faf94d90db85c365181a14db19d", + "src/transformers/models/pe_audio/configuration_pe_audio.py": "9d64dee70020704520d90a52751a869e1f11ce8a93f3c414663716a713146243", + "src/transformers/models/pe_audio/modeling_pe_audio.py": "f6908fa1e29e1417840cca3218333816cef268919b2430b8b09141ad2aa35ff2", + "src/transformers/models/pe_audio/modular_pe_audio.py": "354a8dfb56cb686f34ef7c63262435be9c9ec67c2dd6ca2b5df5c2e91b2024e9", + "src/transformers/models/pe_audio_video/configuration_pe_audio_video.py": "c70bb7a67921801c0ba6d938d83f9a6d00f90acccc70b031d4bd5702667fc4eb", + "src/transformers/models/pe_audio_video/modeling_pe_audio_video.py": "878197f405a2053d206c9d1fe0c99418e24130ca2e6694b478c738ad0e4cc7a9", + "src/transformers/models/pe_audio_video/modular_pe_audio_video.py": "adaa557a9b9685e9c59bbd837c095f4f3e3746613241e7c37ec631e0f0c15295", + "src/transformers/models/pe_video/configuration_pe_video.py": "3d0d8bd4ff4bc9506f8c4c308dcfb8c85436a95662a784a28a6dd72829e2c33e", + "src/transformers/models/pe_video/modeling_pe_video.py": "7d57f7a2e20a76ef52a2508ead16d0bca7f46ca74938c3d903e6f60962f40f2f", + "src/transformers/models/pe_video/modular_pe_video.py": "9ce5def083ff4b6a5fc87006c08f4085500a64fe29beb23c4b5ffcb6de655497", + "src/transformers/models/pegasus/configuration_pegasus.py": "ce8f541ccc72fca3bc9040180ff15cc378985220aab8bf40c423ac2abf322529", + "src/transformers/models/pegasus/modeling_pegasus.py": "d6a69b437518ef2df314fb4911dd222131d03da31833632b76c882af3dc6a1a5", + "src/transformers/models/pegasus_x/configuration_pegasus_x.py": "147586c17c54d9c482e1904f89d26c0855da2882ed173a2200ae1f941b9b8abd", + "src/transformers/models/pegasus_x/modeling_pegasus_x.py": 
"f553eef33c8a63cee60d20553c6dbcd11e26f87e66d24162b37e7261df3f1934", + "src/transformers/models/perceiver/configuration_perceiver.py": "efa38796d38c916de83b173ec48fdb889b15172e871840f8d29dfcd2f3505bbd", + "src/transformers/models/perceiver/modeling_perceiver.py": "9067785d1003797ab679c7110ecb89be37601d1671430b2191928744012f61c2", + "src/transformers/models/perception_lm/configuration_perception_lm.py": "a62b9baee1e9f6b2419c1fc30d8a3fbb0adb4e8979f7c34cf0663f10dac1ed0c", + "src/transformers/models/perception_lm/modeling_perception_lm.py": "5e265f4aea10581c38261d9243e3271daf9ae5d176e7e80053b2cd8fa9cf3e65", + "src/transformers/models/perception_lm/modular_perception_lm.py": "cb68eb17a00343f049359e5a8edf06771b52e0a1d3e744b9954c482cd63312b5", + "src/transformers/models/persimmon/configuration_persimmon.py": "2a66073d47c5932c85635266c9c00ef032102442490ea430e6c49cc4a0f9e994", + "src/transformers/models/persimmon/modeling_persimmon.py": "be2f54ec07de6cbfb3b33c331081347ec81a000274909393d8e9ec39ad2bf133", + "src/transformers/models/phi/configuration_phi.py": "263005162ce6a676688d54863a31db23503899586b5222772c915eb46b17e65c", + "src/transformers/models/phi/modeling_phi.py": "871d7480a128838ef3acfe6a2ad04393b1c843a596ebdfb5e4a6df8f53f6fbee", + "src/transformers/models/phi/modular_phi.py": "7f211ccc439d8f73e2f18969c8b3df0f5c64f1df010b58a33412d9564f61d9ef", + "src/transformers/models/phi3/configuration_phi3.py": "c328e5f5c8a05cb7d54b16aad04d4614f9cb5cbc6350c14d51efe3f3ef52e3dc", + "src/transformers/models/phi3/modeling_phi3.py": "61b1796c27a4dd42e3090b250ab72d7566c92cb8329ab215da2182fb08f353d4", + "src/transformers/models/phi3/modular_phi3.py": "ae530b68bedcf03f26999c39b86b80b2c6e8e4890bd3a1247509af9f647fd169", + "src/transformers/models/phi4_multimodal/configuration_phi4_multimodal.py": "ee2b699d94191317bc645b4042d8cdb99548a73823fa463c824f2ae90f861f5b", + "src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py": "7c95d1468de82c00aed6e53aa0e4f8af78d86dd345fcb533434621433461c6aa", + "src/transformers/models/phi4_multimodal/modular_phi4_multimodal.py": "97b63063b61266813513266b394afb24f06c6a5ccd76eaaeb34fffba3218a6c3", + "src/transformers/models/phimoe/configuration_phimoe.py": "02a58ada430ce5fd55bf0a304a218d7b99856a40d01db4a459124e427a0dafb9", + "src/transformers/models/phimoe/modeling_phimoe.py": "8d1a92c9415331e6dce0d252491f50ff78ef80b14c54402af79d05fd3dbc7fa9", + "src/transformers/models/phimoe/modular_phimoe.py": "82e8ba6b83af63f74a16c12e4b6b15f04019beefdd0567c3be23daa8728297b5", + "src/transformers/models/pi0/configuration_pi0.py": "4f016c235747a8302ac8b63f3ec18044beab4dd5c70a337925682501ef68d719", + "src/transformers/models/pi0/modeling_pi0.py": "94c58c3eba1eeb135cef63b5c6fd34bce3cd442b4b2bcc31c21e31284e694212", + "src/transformers/models/pi0/modular_pi0.py": "682003320af135f1c5f09fbf6960f5d26360e35f3b4f2e930ad678e4ec588e4f", + "src/transformers/models/pix2struct/configuration_pix2struct.py": "bf78a49afb4f65fc80cd18816fd80b72acd409a0531b6f66614546748b98ed4f", + "src/transformers/models/pix2struct/modeling_pix2struct.py": "589e4a4ed8a4a09094e66894a046e63a8cb596552f7269bcf62e87775e1ab1ef", + "src/transformers/models/pixio/configuration_pixio.py": "72cef8841ed4991cd2a5af27211676923bb64bae7fa5a400e159d469edcbee7d", + "src/transformers/models/pixio/modeling_pixio.py": "f5b640a57d463b54fd192dbb9709dfe2eb7812426d77c247af534e41c737db22", + "src/transformers/models/pixio/modular_pixio.py": "eef6f9cc338334c3537d269debf19c85f89c20a004c8fd67a923ef7a5255a8e2", + 
"src/transformers/models/pixtral/configuration_pixtral.py": "5f52fde8b2e352a2d08c6affc6209d76a67a651f7e8f91bbc970f9e11fb7b80a", + "src/transformers/models/pixtral/modeling_pixtral.py": "1a654b23cc3cd8cd3c4a9e2c5c4f3b323d7cf1f66c45681910ce37cbbd42aa1f", + "src/transformers/models/plbart/configuration_plbart.py": "e8fba9af7c472ec130d6ea35a9947dc1f470510c465f4e26fc4a0e83f67a1c1d", + "src/transformers/models/plbart/modeling_plbart.py": "42d5d0f5cc1c028352f416ba8dde36dc1b5f518e6062854e1a8106134fe8732e", + "src/transformers/models/plbart/modular_plbart.py": "966c40e76a8359136d30ed3c423f03a0505a3946af63b396a7cbbd524963c98b", + "src/transformers/models/poolformer/configuration_poolformer.py": "03f850ecb729e481275878b8f00326c27e40ab42504489ecb842bd2468a6a5d8", + "src/transformers/models/poolformer/modeling_poolformer.py": "1a2d2cba37f9d6814b973d1cb1ae5af562716be0ccfb93b16273cfb2850f4402", + "src/transformers/models/pop2piano/configuration_pop2piano.py": "638cacdca584c45b729005badfb0548911e95a1530d34625a117a6af770d5a52", + "src/transformers/models/pop2piano/modeling_pop2piano.py": "2336faedf16d81f2107053866eb1dea174b09f901def9d028bb9b14d744b351c", + "src/transformers/models/pp_chart2table/configuration_pp_chart2table.py": "0d6ef41dd9027aa167390534426a88996bd2b99821e90f2ddae726cd74fee32b", + "src/transformers/models/pp_chart2table/modular_pp_chart2table.py": "f07f36144fe0c1fc137ec92ed7ae500d76bb13fb7a880b1b386b47bf9dcd4f4b", + "src/transformers/models/pp_doclayout_v2/configuration_pp_doclayout_v2.py": "1d8e6147d0ed203e53b23d6c127525a1a8042b292313e1d9b7452f83b24bf8a8", + "src/transformers/models/pp_doclayout_v2/modeling_pp_doclayout_v2.py": "ec06e14eb3c3810a77d2ae1ea2650edf4bc1e64b1a110071860e22f30e87e06e", + "src/transformers/models/pp_doclayout_v2/modular_pp_doclayout_v2.py": "c8b62e3792bb6e00f63b65756ad57f3de637fee412bbdf6bb048923cbeac3402", + "src/transformers/models/pp_doclayout_v3/configuration_pp_doclayout_v3.py": "5bd3acaa93534c0f7a2e2a39e4ba7f6b9bfc2fcf6fadbc33afcc4384b4e6d3c0", + "src/transformers/models/pp_doclayout_v3/modeling_pp_doclayout_v3.py": "064342c448a177f479bc07df673f7c3924301c0ba754d1ac6660d8e99bd9ee24", + "src/transformers/models/pp_doclayout_v3/modular_pp_doclayout_v3.py": "f29b2707b642c4a07bcccb3dd7a95251d49d21c35136ff8515b60108d71f90d6", + "src/transformers/models/pp_lcnet/configuration_pp_lcnet.py": "3b04bf71a4ca210ef0268c2147421d5e54bbe5b9ac770b558daefe46b9043383", + "src/transformers/models/pp_lcnet/modeling_pp_lcnet.py": "c1624c2f4977d9821d859d9a94f3fe52c1ea00f1dee7ec040a53a2cd6fb8ae60", + "src/transformers/models/pp_lcnet/modular_pp_lcnet.py": "ea882bc49fca8720dd370df0f0c181629db31e34692f1aaabad608f4660d70ac", + "src/transformers/models/pp_lcnet_v3/configuration_pp_lcnet_v3.py": "8484e95b379186f3c4ddb791a3f7a78e7cd9dc98b5c7e98254233a99a7551ec6", + "src/transformers/models/pp_lcnet_v3/modeling_pp_lcnet_v3.py": "e0a9c99e026db17b99a57c2e2789f618eebc3c0d7adf9f513f2d56576896c8c4", + "src/transformers/models/pp_lcnet_v3/modular_pp_lcnet_v3.py": "b0e49b53962f43be56ab1a609d1460760daab37ef2e7942e0a357a949d8cf249", + "src/transformers/models/pp_ocrv5_mobile_det/configuration_pp_ocrv5_mobile_det.py": "e7ddce0a384dfe27e30cec3863c4546fa007bed800a72724970303d68742c305", + "src/transformers/models/pp_ocrv5_mobile_det/modeling_pp_ocrv5_mobile_det.py": "b248be15935132fbc75b112a31aa31891e567a9f26b557c068caaa5312dd3a67", + "src/transformers/models/pp_ocrv5_mobile_det/modular_pp_ocrv5_mobile_det.py": "d3d4953fa00063cbb8f5964c3174ba99bb20805d3f8d5934ec035bcc9f54bfd7", + 
"src/transformers/models/pp_ocrv5_mobile_rec/configuration_pp_ocrv5_mobile_rec.py": "b7b298ba3094690d9e5ce974669622c5397b9ef62a6287d2b2fb80792ccf8bfb", + "src/transformers/models/pp_ocrv5_mobile_rec/modeling_pp_ocrv5_mobile_rec.py": "4f7616efe95db7e86f525b6b91892611f82187ae4f177d4e19f4af75e54fcaaf", + "src/transformers/models/pp_ocrv5_mobile_rec/modular_pp_ocrv5_mobile_rec.py": "99e04cd8d97bc71ae085f8ca00b0a8a953e24f3e6db67bbf222bacda9c6be98e", + "src/transformers/models/pp_ocrv5_server_det/configuration_pp_ocrv5_server_det.py": "c6d42d4abdfdb2a4009d79ed9268dec9377f475337f76cf4d6df2fa9140ad96e", + "src/transformers/models/pp_ocrv5_server_det/modeling_pp_ocrv5_server_det.py": "94bc6c025512d0bf38ba4965f2749351fbc795c47897bfebf3a931581ac34dc5", + "src/transformers/models/pp_ocrv5_server_det/modular_pp_ocrv5_server_det.py": "b52b34ce33108a8ac665a46b8e76146d8b11b72de15411d47071ad6eac9ee162", + "src/transformers/models/pp_ocrv5_server_rec/configuration_pp_ocrv5_server_rec.py": "dd4be03e17cbe2335fdfe7eda64b140c93277e24553dccbaadf6ab18ea143fdb", + "src/transformers/models/pp_ocrv5_server_rec/modeling_pp_ocrv5_server_rec.py": "58432c2ece27e372ab4a5a8e43f76269ab878d5a711ba26f96e542c80f41b4c7", + "src/transformers/models/pp_ocrv5_server_rec/modular_pp_ocrv5_server_rec.py": "9cbde65d7536228eb11fc6a94f96b9a5c4596a95eadcaf94863192dc071ddab3", + "src/transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py": "ff0de4853fd45f7100c742e9656af40443aa69a077b94191f1dab34fb51ce966", + "src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py": "6d687bcbc80bc6bf0829520eef70e6534a655be798d07e3babbfa23454ce8dfe", + "src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py": "a31e08303c06b9c67ffba757c18ee8e96b94b07ad8da8da6489254d65c496155", + "src/transformers/models/prophetnet/configuration_prophetnet.py": "af504369210782aea21d35e793d161073bec8943cb78923dccc04440aa7764ba", + "src/transformers/models/prophetnet/modeling_prophetnet.py": "65c8a552e463a2c114d471bfbfa6bd92e196f4594dce0aa4c2863f6943bf9efb", + "src/transformers/models/pvt/configuration_pvt.py": "ca034732e3305db37178f18aeb142a2ade19f043e1ca9871d9620b2903711de5", + "src/transformers/models/pvt/modeling_pvt.py": "489262f1b8dc6a946d820b4b48b70cdb072a8dabedeaaafbc90aef506ba0923b", + "src/transformers/models/pvt_v2/configuration_pvt_v2.py": "519655c4ba46a21a1ed03d3bb473193c8ba11e051889dd7bec607c87ca6f2d55", + "src/transformers/models/pvt_v2/modeling_pvt_v2.py": "784dc6eb1b7f5b3b01e3b9478bf84a89f26eba1a6300fbdf0f832addd8ac9027", + "src/transformers/models/qwen2/configuration_qwen2.py": "a88024a2ff48e088ae76b53fd7eb540b45c992edfb41d332181c2deb1f47809d", + "src/transformers/models/qwen2/modeling_qwen2.py": "611e61ea4f0881a4b4b3cf77355b0a2ce1230d8a8674d1aa4bac7f2d8d76fc0d", + "src/transformers/models/qwen2/modular_qwen2.py": "0b9da277f60da80c7372ab3004d60a9fe4468d2c17d39e4366ac0901a09d3e2d", + "src/transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py": "18a218ced4ced4bab28f3466f5a7c1237c9afccc44c801387b92322fe16f2f25", + "src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py": "286589c8fe5e8a3b8441371c4adfaf21659411b4c6d1862169742928a2d405f3", + "src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py": "191e242c1c73018b7e5b4004ce29d87e72b6651262441a6183ccc463697d7137", + "src/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py": "ad6592cb9138074e4370c74d57ab4d44afaf05fc93f159c34e154e019e5272f8", + "src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py": 
"a5aeca9101d862571dbb83b52ba2bd885ccc6c79e11bf904149ceb7aa4b7219e", + "src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py": "47322127e2a622f0738dfcedb0a5eb3189cc819b532721b842c2577454d8da90", + "src/transformers/models/qwen2_audio/configuration_qwen2_audio.py": "07a42ed953b74fbf4561ac08625da711f5705efb5fa0ff7ec3f6b399bc13e476", + "src/transformers/models/qwen2_audio/modeling_qwen2_audio.py": "592b2a7584a988ac564dea3d04f35d4e908167db523e105e2b7133a43601db6e", + "src/transformers/models/qwen2_moe/configuration_qwen2_moe.py": "78f97945f5605d79937bc923d32f5c2aa029cd7241a30daa2b24fc7d89e69380", + "src/transformers/models/qwen2_moe/modeling_qwen2_moe.py": "7e0f78b5e4b8e704f5520aa38ace6c71c1ba363aa73376e42ee1c5650457fc3c", + "src/transformers/models/qwen2_moe/modular_qwen2_moe.py": "30f8156bf6dc9d17a88cee36305f6faab485349e5f04cd49a51c72ac2c81559d", + "src/transformers/models/qwen2_vl/configuration_qwen2_vl.py": "2c62d3360f08cea9fa805165f101bed8b0f07b5eda7f06e4213362b2a4f399c8", + "src/transformers/models/qwen2_vl/modeling_qwen2_vl.py": "cabb30fe2fd1937eb3c13bdd05b9e58cd0c6487d3d5d9188b257e049b760e5da", + "src/transformers/models/qwen3/configuration_qwen3.py": "7b902535b5700dacc4ffefc84789809b89655c7b713936d458cf626cc055be8a", + "src/transformers/models/qwen3/modeling_qwen3.py": "94e34d4fab68a3a30072d4772894745f7d621c87e643372da51c279e1538b80a", + "src/transformers/models/qwen3/modular_qwen3.py": "3687b8bf105a41fe81d8ba4825afc25dd5f1475bdf35b64ccd558dc829a8e1b6", + "src/transformers/models/qwen3_5/configuration_qwen3_5.py": "83c4c4000c1ca0af602ee729a392d99b31942d48d2727461592403172289c146", + "src/transformers/models/qwen3_5/modeling_qwen3_5.py": "97e6f3850aac8713d77304ad07fbe9c97ee8226b97639f877511eabf74c37993", + "src/transformers/models/qwen3_5/modular_qwen3_5.py": "56dcd3cbe72f8733c5b8870f0520e6cd2a88a2448cf1dee66943c7eadba6edb8", + "src/transformers/models/qwen3_5_moe/configuration_qwen3_5_moe.py": "f46a65317ab75563341d76cd464a6fd511cc286851bc13c6af9fcc7ae577bd4d", + "src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py": "83e8ea2c8e331ae5ce81abe5cfd13b82694eac0eb7a99e4c7be6f80d1961ecbe", + "src/transformers/models/qwen3_5_moe/modular_qwen3_5_moe.py": "9c13121107af598758ddbe05e226a907e30d23883a040b29ed86d485f87c873d", + "src/transformers/models/qwen3_moe/configuration_qwen3_moe.py": "d5cb43531bde34cde952e440055c1684dd9f71d92cfbfbff3ffe5a8a4f315261", + "src/transformers/models/qwen3_moe/modeling_qwen3_moe.py": "d2fcafdb7f784349d901345fd8f43cdae16215e8836197d085227615e209165c", + "src/transformers/models/qwen3_moe/modular_qwen3_moe.py": "43f488cf6c2dc6d74eea85beaf546556120db6a03150958eab50d14bcafa31c5", + "src/transformers/models/qwen3_next/configuration_qwen3_next.py": "7ec2aaf1964c5c563dae05e8caedad76db5b78b8cd7644bfb4bfb0919087d5ca", + "src/transformers/models/qwen3_next/modeling_qwen3_next.py": "12d66b39b73b040ee7cac784598d92de61b55ed9f424679579be1d881c6d14be", + "src/transformers/models/qwen3_next/modular_qwen3_next.py": "945db2183862039da809c85e0617d31b2d3dfb83dafe6ccedc6e94f12f479ddf", + "src/transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py": "6f28a3eb77c4879a061327e45bbbe9c18ebda3365b09ea7c0710239fbdd5e4cf", + "src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py": "9f8d8e5b25d17216e06e4f98bae534ab462b33c5c0d15777ed6a45f1516dac81", + "src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py": "ed35683349628e447cfbbf003cdb2d413a80d8733ecdfb6bf7715a4ced22e0f5", + "src/transformers/models/qwen3_vl/configuration_qwen3_vl.py": 
"bda6053535dfd5afb6dc5f861d3b0ae60a3f3bdfb0c5a1511baec1f99e38ba77", + "src/transformers/models/qwen3_vl/modeling_qwen3_vl.py": "c6b393ef254833a5a74bb0d3284d374e0b9bf3cc06bf0f9ba0ede6e0d34c108a", + "src/transformers/models/qwen3_vl/modular_qwen3_vl.py": "2d54c8312cf298eda35d97b91a2d1dd598de7f7fcb9de5e7cdf31fb36e45cf91", + "src/transformers/models/qwen3_vl_moe/configuration_qwen3_vl_moe.py": "c296ebf5b2f5c64dcf99f505e2a76d58513b788542cbf3840b41340bc85a2a7a", + "src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py": "463c8ef897fd324fd4d1acb8f3664b83f846125292e4e5f40c975677821465ed", + "src/transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py": "c8203f36a93399cad0a2d0e8fe1e3c2197e63626408e8df963c01e92acc78d64", + "src/transformers/models/rag/configuration_rag.py": "a730529a627ed83af4c439b672e6243a868a6650a8b8d0d26bea5fd824a3db89", + "src/transformers/models/rag/modeling_rag.py": "b98c3ad8895d610274f4f371959c0eafbfb27281b102d783af42fda3af9121c0", + "src/transformers/models/recurrent_gemma/configuration_recurrent_gemma.py": "f43e56390a9025a8e9cd976833f5a2d30a27c67adbf051e6978bb7688ca114dc", + "src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py": "171ae9d240d6aa199d11b03675d0cf236675140a01b7f117001d7b0ae12dc3f6", + "src/transformers/models/reformer/configuration_reformer.py": "b43964a765ff07bd99009106a3cc80d7d4333649dfe4fb1a6f87fcabea16acb5", + "src/transformers/models/reformer/modeling_reformer.py": "2b42c2c13a1025c61ab375960fe6ad277a9adb5febd89f71f7b1707616418b0e", + "src/transformers/models/regnet/configuration_regnet.py": "438bba88e2093027e77456f87369c9510ec905910776494d566f06a2cbd85032", + "src/transformers/models/regnet/modeling_regnet.py": "6bae13c5488e15f2ab600baca82409b683a5b13a7742bd1dd079cdbb8733a123", + "src/transformers/models/rembert/configuration_rembert.py": "a237a0c90c7e307f7405d778f0f2b2b1fb89d7cee9632a3249389daedd3cacf4", + "src/transformers/models/rembert/modeling_rembert.py": "346a04cbd7b81ee907680eed370d92a1b83a186816e426dbaaffd6ca829f3eb9", + "src/transformers/models/resnet/configuration_resnet.py": "833b96ca9efdd457ff5771c68d2f6de3f08ecb61bd8c4d169c6cc905ed5dcfb9", + "src/transformers/models/resnet/modeling_resnet.py": "78771bb35f85d3853e73097244ec11c3b1addc4a4ff251621cac9c36c12b0b94", + "src/transformers/models/roberta/configuration_roberta.py": "e579a667c856b6db5adafe8c9984baa14e587bc5da0dda15ea3a4c8a1e2bfb29", + "src/transformers/models/roberta/modeling_roberta.py": "a8f6586849f29b3b74dc0a37c732ef5b414ad4e23f6cc90f28b913ad2659a86d", + "src/transformers/models/roberta/modular_roberta.py": "919750780b4936c2ba5d1d411ee7b92240e2de7643718d7fdf59658b70128a56", + "src/transformers/models/roberta_prelayernorm/configuration_roberta_prelayernorm.py": "10606a5251c4fb27d2b2eaa127f78676601d30be6c2360e22db80823224419eb", + "src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py": "78a61603110029577db51e369badcd141a83194b816ef7d065096983deb01a52", + "src/transformers/models/roc_bert/configuration_roc_bert.py": "3f0b03a86dec3eb2a110c6a4378a3b87d9c909bcd8c89e521203d7ec5f503cda", + "src/transformers/models/roc_bert/modeling_roc_bert.py": "6a683a01a9490b5a0a0679cc105cf2bf8dbcc3c06933befab9c83fdcd9efe99f", + "src/transformers/models/roformer/configuration_roformer.py": "f74e96078410c8a25aa49f667a6e96116da7a46c4a981f464b2a05c7cb674cdb", + "src/transformers/models/roformer/modeling_roformer.py": "7a4f6b71e1064702272ed49901238ff07ce8207eda90f0e15b910550c0953a25", + "src/transformers/models/rt_detr/configuration_rt_detr.py": 
"b86a17451e5f75d96848ba694ea40491526d2ecc2e8a147688dc7b11a300b6b2", + "src/transformers/models/rt_detr/configuration_rt_detr_resnet.py": "55188876b979401c0ece58f3e5fb8f0c670f7af2acd5575d86750bdd569bf8b7", + "src/transformers/models/rt_detr/modeling_rt_detr.py": "82d696cd7776f1b76ea6b6239640c3f58e073c34df77b6479014d2f7b2c4c619", + "src/transformers/models/rt_detr/modeling_rt_detr_resnet.py": "cda38329575133e986d93788b84440f74a5d7982d57189be056fcfd87b5445c8", + "src/transformers/models/rt_detr/modular_rt_detr.py": "a2ed2bdf2925b480683b969634328bada856f30ab570f1551f4b10a5bea066c2", + "src/transformers/models/rt_detr_v2/configuration_rt_detr_v2.py": "175bf7f1afcd08eadde3102d28b00e64d6217cec59ad2388fc486eb7f90b4523", + "src/transformers/models/rt_detr_v2/modeling_rt_detr_v2.py": "7b2ccea0aaa432d654d30d79a727257e9be356e2868780882d891f978cf97281", + "src/transformers/models/rt_detr_v2/modular_rt_detr_v2.py": "c0ce70c8333e68a8945d1ed4ae81b355c5bcea17916d2126998deec4a5058704", + "src/transformers/models/rwkv/configuration_rwkv.py": "963e5e4595dfa25fcb9cfcfabdf1ac81c1402cdcd23b1b0e14c94f9dbc6c9ccd", + "src/transformers/models/rwkv/modeling_rwkv.py": "ffdce2de84d38fe4478be43aca0220e5ecb129c51155af72a632d50bb5087058", + "src/transformers/models/sam/configuration_sam.py": "de4813d26f76e04824ec7fc50491094d5bad507f71a6276cd4e2e2f423d5fcfa", + "src/transformers/models/sam/modeling_sam.py": "96bd35c10a1d5b66530e745b45129aef67f954078b6ccc43fe54d0cbc06d1920", + "src/transformers/models/sam2/configuration_sam2.py": "b042b6ac1ce51e87ce4dd2f1b62e152cb2595f0c5d88f8d9deb8fcc274a1ca93", + "src/transformers/models/sam2/modeling_sam2.py": "6254a79f9f06bf3eaa5f382ad7499fd3978710b2a73750a740c5e937139add9d", + "src/transformers/models/sam2/modular_sam2.py": "b57ecc1a0bf268f0e0b30e876502860db3b1f854a012b8d8ecaad15f47f54af9", + "src/transformers/models/sam2_video/configuration_sam2_video.py": "abdf32de02ce10a3ff1de2fbe850b9e7dd5bf6f8003b3dbff7500e66485083e0", + "src/transformers/models/sam2_video/modeling_sam2_video.py": "22a2d8c4d99490f4b193b918c5ce90d427769d328d180b42e8cffe79532814f9", + "src/transformers/models/sam2_video/modular_sam2_video.py": "95d9e8875165e0d4fe68a520c6d4e93b9bfbeb7f3a4bca01bdd3777b302e45d2", + "src/transformers/models/sam3/configuration_sam3.py": "23bee427162c2c33833150320f65a96d42006f20feb3791e52d4c8536323fd4d", + "src/transformers/models/sam3/modeling_sam3.py": "d450474b0c9747987cf7f723d3abb728c5967fbe80b793e48fe6178a4fd4cf03", + "src/transformers/models/sam3/modular_sam3.py": "58d1654c4c7a6089ee882f83e8d285e79a151ba51a93fdf45272f3327fbcac0a", + "src/transformers/models/sam3_tracker/configuration_sam3_tracker.py": "a82deb999c8374dc96c34139c5175b6c71dcc7ff2c20a6c084154a4402e1fbcc", + "src/transformers/models/sam3_tracker/modeling_sam3_tracker.py": "0c799adf42bb9c54b36c5eb99442c180981d328af3555b4b526497d097a4333a", + "src/transformers/models/sam3_tracker/modular_sam3_tracker.py": "0b039a6b18b87b9eb347a5396ac1a7f1c02b109cb4d184d579b6866d166caa1d", + "src/transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py": "cd2241e9f42767bf711d6e5a1e2cb315fffb0c0becd6cfefd26ab09fb02914ce", + "src/transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py": "5afc119915836d4369dcf53e3e074555b488beb0333dadd9b3206dd917942c03", + "src/transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py": "8dc65910cb8d09ccf9818d46a63e490303867dad11553c6388b9402159037ca9", + "src/transformers/models/sam3_video/configuration_sam3_video.py": 
"17b7a40f002bee963677665ed7c23b93590ea7af17111753db9a724ae3253f15", + "src/transformers/models/sam3_video/modeling_sam3_video.py": "c573ceb732dea819db1dcff8f242474159bfda7b2484eed77110396adfd10ed3", + "src/transformers/models/sam_hq/configuration_sam_hq.py": "cda2c3b8a62fb757e01c1a555cec49028a9a45db685c76cd4779ea0154e76940", + "src/transformers/models/sam_hq/modeling_sam_hq.py": "75fc1413ab47473b9502179f7dd281efa4b7ef8ec4e9847daa3522844f1cd784", + "src/transformers/models/sam_hq/modular_sam_hq.py": "55f90de4ba901867f96e593525648d1aa5f383e08224bce8662faaec23b214f4", + "src/transformers/models/seamless_m4t/configuration_seamless_m4t.py": "b6166c4f198ed9b7eebad18d3b1024ee8535d1620980ff00fbda43f568a4eb63", + "src/transformers/models/seamless_m4t/modeling_seamless_m4t.py": "b40a4939c4c2ed733fa7e6727641ee300eb935326024b9d672ee314eeb206641", + "src/transformers/models/seamless_m4t_v2/configuration_seamless_m4t_v2.py": "db2835bc60fa9e9fa94a8c2f0b2909dd01603adc4ef8bdf9f14c19b299b614b7", + "src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py": "cb83cac78a10a30313e32c6703f6c2e1bc13050c6951f7ed8c5e205f54729504", + "src/transformers/models/seed_oss/configuration_seed_oss.py": "0a908f8903f79fa3d7bf2857607133d1b86f44d8fd736faaf82fc8270487c932", + "src/transformers/models/seed_oss/modeling_seed_oss.py": "b48bd86899465b19663360ddcbfef1aa7ab05ed73ca3a82cba71c6693674fea3", + "src/transformers/models/seed_oss/modular_seed_oss.py": "5a8f3b555035b7ea87d70d7618c29269901866266acfd83818abfcb5c2e24c44", + "src/transformers/models/segformer/configuration_segformer.py": "7c9ef3137748d7d1afa3297d5cccd51e573888ed9aae67d218f07e7544161bf0", + "src/transformers/models/segformer/modeling_segformer.py": "50b5409817ab3152c8c5b7ef11db97647f8b5491ecd44ae510b91445e5ce7485", + "src/transformers/models/segformer/modular_segformer.py": "1af15e94abfe23da08252b783bdee98501ecaaafdc1ac3d4d58f29ed300fe54f", + "src/transformers/models/seggpt/configuration_seggpt.py": "f153230dea96743f50ac4d4ec5d1ea409b7840d6a32ebbfc21c855d70bf577a1", + "src/transformers/models/seggpt/modeling_seggpt.py": "cbe63676fd866d880d943f3d0df79fdb4527851e37be29482019c82a32d6d77f", + "src/transformers/models/sew/configuration_sew.py": "f7facf01cfb83060307adef7df69592eef44e04ffc6b26a50ac9aceec5819092", + "src/transformers/models/sew/modeling_sew.py": "883b8146429e7aa62a8155bffac12c1b16cdf0f99b6f75ecd1a62d3e843d0944", + "src/transformers/models/sew/modular_sew.py": "386f03666184106d140560663edf18f45ee58db68563b9e6d0d2da5abe0f87e4", + "src/transformers/models/sew_d/configuration_sew_d.py": "6c5696296ca4249646524c434ed6b746bc4bead5bd5210a067c8b1e0359b59ce", + "src/transformers/models/sew_d/modeling_sew_d.py": "600f58b34fc0a25c9ffffc17d7833a8095debd3b607a5511ba10ebc4c6ba23a8", + "src/transformers/models/shieldgemma2/configuration_shieldgemma2.py": "5973607a32482e769a38e675bf705ea7c7892a570ffa9625a055529f32513e90", + "src/transformers/models/shieldgemma2/modeling_shieldgemma2.py": "0dee8b7f44496b9e9454c77f191beab71d830ac126c3b9cfadcc55dbe9343b61", + "src/transformers/models/siglip/configuration_siglip.py": "3bff9508fcf4c2ba1a8379a168fe9874e1a93402b211b4f456165a6e53c2c666", + "src/transformers/models/siglip/modeling_siglip.py": "a541f04b950ca33106c52f34f5303adc2882dacd27b77284f61885324ed4ea10", + "src/transformers/models/siglip2/configuration_siglip2.py": "c924e37822921fe7a51d85fc8ba8668fa04d4815cd3e76d53a04b1a3c42b548a", + "src/transformers/models/siglip2/modeling_siglip2.py": "b7ba2b0af87cb31f02ab25cd010009abbd4efd8e4a90dee467c81b35d2156c1b", + 
"src/transformers/models/siglip2/modular_siglip2.py": "472fae36bb4e62c2ec9e82d0ca39cc7a6622f39f55d52dff757d5bb1b28845da", + "src/transformers/models/slanext/configuration_slanext.py": "87486fa6a600498d14b37b424db7b9967dda52a3e12aab4f1c506bb26393ae9d", + "src/transformers/models/slanext/modeling_slanext.py": "00ebcf239d8e8c82bd8fbdd73035912ecf41bc914db95dd049dec0e1974ec8e7", + "src/transformers/models/slanext/modular_slanext.py": "66528eda541815e67d196bcd5ed2b21129d1371f281e36009f1e9a6e6e368049", + "src/transformers/models/smollm3/configuration_smollm3.py": "2b17fa65fe25ee851963e8530ff36880bf4e29b315706f68856d8fb4c90806e6", + "src/transformers/models/smollm3/modeling_smollm3.py": "451eea89bf67518f9fd6589dfc771ade3adc0572d5a9454f8cedd8e190e04549", + "src/transformers/models/smollm3/modular_smollm3.py": "ced515a3ddbf6a898be9dba866a86aa49ae642b1afead61f2231a3dba09b45af", + "src/transformers/models/smolvlm/configuration_smolvlm.py": "ff218dcef1f1fc3075957b94167d39578a51081fb083a74f47b3fd6b947a87cb", + "src/transformers/models/smolvlm/modeling_smolvlm.py": "de36dee70698809f275aeb7173cf94812b248f1915f6139eb9dd38ba7eddf424", + "src/transformers/models/smolvlm/modular_smolvlm.py": "6d8029d85d12e630a37c0222338cc069745a698ddad7ab4172b1b66cb3cc8d20", + "src/transformers/models/solar_open/configuration_solar_open.py": "ce5ac3768b65df6e89528a903943a7ddae3b77b33e86e2895f841b89d537bb45", + "src/transformers/models/solar_open/modeling_solar_open.py": "0f8ece99c58443212500b47b8402318fcf7923d1cf57dc706023e81d718d0c29", + "src/transformers/models/solar_open/modular_solar_open.py": "7ab97d3960f83c74a3f8b61d78fbe7e176253041cb7edb886e490fd58f6f15a3", + "src/transformers/models/speech_encoder_decoder/configuration_speech_encoder_decoder.py": "d5ae14b909f586af37a1a7b72b4164550bbbc6c28732706c27571c800eb9d001", + "src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py": "ab74eb94be4bfdf72f4ff2c3eb5127fc21269619ddb322a34ec3112b2f18d9de", + "src/transformers/models/speech_to_text/configuration_speech_to_text.py": "531eb38b46bd8e2f31807aae7be4af1cf09ab9f9271e033f31adab6960552718", + "src/transformers/models/speech_to_text/modeling_speech_to_text.py": "044b29b22f5c6bacc133eaaf343306eb3a44373f2071c77a32b2ac7554e4d516", + "src/transformers/models/speecht5/configuration_speecht5.py": "203e0b0951f1948e7d918ab157749effa7b2675d03883fe5c90f69bf9b322e98", + "src/transformers/models/speecht5/modeling_speecht5.py": "d61e0877206480e4ed436072956927f2c4beaf3b036adfca577f68885562659c", + "src/transformers/models/splinter/configuration_splinter.py": "86b5b1a97377773cb22bfd16e8fcc1ad7b5e70ee888781fffc9420a7f5e78020", + "src/transformers/models/splinter/modeling_splinter.py": "bb84c61c46388ab1784edc8767aa06117efd6b1c2a6847b6f36a69f6d9b6292c", + "src/transformers/models/squeezebert/configuration_squeezebert.py": "2b177257c3341d8a98c545fd37559229b6690db823649f37cdd9fb919ccc14d8", + "src/transformers/models/squeezebert/modeling_squeezebert.py": "a6cb8f6ffac27cea1a56a91679958fd76b118d3cc5c52572dc2bb499a386e649", + "src/transformers/models/stablelm/configuration_stablelm.py": "6c9dff2cc1ad0880471223039f6bead226b1e9f75e7d9c65dc95eb717910ca9f", + "src/transformers/models/stablelm/modeling_stablelm.py": "020c276ebb0d8912fdeb4d66509c2b25526c9839c582158ccf29057368bfd8da", + "src/transformers/models/starcoder2/configuration_starcoder2.py": "df2b41071c6029b2cb4a04df67aac703d97ac7833e1b96bd099e9a87e2a515f5", + "src/transformers/models/starcoder2/modeling_starcoder2.py": 
"ac13a5196a749b9019590950f5250211f026850c46184e607774dc76d9a2b24b", + "src/transformers/models/starcoder2/modular_starcoder2.py": "45ce614a43276fb62ebd5f2a55aab6f8a978d446fecc935ad9d4ec7450a7ae7b", + "src/transformers/models/superglue/configuration_superglue.py": "7d3a99fd9dd299bd10912700bb798d09aa552406eb4cf1d12fa082ceda202ff3", + "src/transformers/models/superglue/modeling_superglue.py": "b39a02dab86a28513aa3c5c2bca8c10f1ce2555d37d9a0a0bf7168c63e8647b8", + "src/transformers/models/superpoint/configuration_superpoint.py": "c6ab918b4eb62ad04b2308cc753f4fcf28e13dd91803c0620458a610fdf1e155", + "src/transformers/models/superpoint/modeling_superpoint.py": "1a4dcfcd24292ddb49d0f829071296182c119ef93003b606ecd6eaeccda3752e", + "src/transformers/models/swiftformer/configuration_swiftformer.py": "30b4493452f1076f009043bf284dfba697887f929018e8f31256398a18f51f2f", + "src/transformers/models/swiftformer/modeling_swiftformer.py": "b908e5359b1de228046dd369a553bec6e96ab89376cf7cab00d757d03c9af4a9", + "src/transformers/models/swin/configuration_swin.py": "f644315ba345e8a2ddc1c3bb094f612a3250e8b87d26c1f737238c623fe16378", + "src/transformers/models/swin/modeling_swin.py": "7f1a074e152692bb781439dde56a0cdd77afa2d8292192168b51348051d0ce3c", + "src/transformers/models/swin2sr/configuration_swin2sr.py": "0da941a458fe11d3d558335ca75493adeac53bf2b14564b627696e1e2306e0f3", + "src/transformers/models/swin2sr/modeling_swin2sr.py": "8f9d31c11511119759ca72a1bfde5aaa1d8c19b8a2b3f3890a9aeafccc9ee193", + "src/transformers/models/swinv2/configuration_swinv2.py": "85ff6596e90a2fb2c4eacbecf949dcaef95c0003fc41e728472d48d14a74226b", + "src/transformers/models/swinv2/modeling_swinv2.py": "aef42ed7702fec6ae5ca86f766444a5e9493fa41c4181a3ec868e1ab7fe27256", + "src/transformers/models/switch_transformers/configuration_switch_transformers.py": "e08699e3ea2efe898a5b55c91baa6d23a20509020491c88dc3e491f04e284e00", + "src/transformers/models/switch_transformers/modeling_switch_transformers.py": "e6e26d6f031962ae376df836ca1c1a978f22768b6c8daed1be55a32fb3bba21d", + "src/transformers/models/switch_transformers/modular_switch_transformers.py": "4ab72c726d4a7e1eb762c194c78ce935c3daa62f8510f86166d64c0ea4a172ef", + "src/transformers/models/t5/configuration_t5.py": "258f1fb36f3116f2cb8711eddd68076956ae9455e7ff78ca5b1de695f6d89889", + "src/transformers/models/t5/modeling_t5.py": "4c0373817c224fd38a60337829cd10e31be0cc43aacb958134618a394b675e61", + "src/transformers/models/t5gemma/configuration_t5gemma.py": "6aff9e558005d4696849f4f66f21521e622d03bfc474360c7ccb0df9eeb874cf", + "src/transformers/models/t5gemma/modeling_t5gemma.py": "1830bc343bbf842f42d955e8ee7dfebe2ae4468a680de73efd3a2b17f122c966", + "src/transformers/models/t5gemma/modular_t5gemma.py": "b44b54873a282bcd159c5552d3dbd5a0163540315f5e378235b07658e24d648f", + "src/transformers/models/t5gemma2/configuration_t5gemma2.py": "a66102c62f5f7bf0e51c29e173f41b9db6a6b695f5dceb823dc6cb1ca6751a2d", + "src/transformers/models/t5gemma2/modeling_t5gemma2.py": "7865c70886cc526235b2915cb0814962b0eb27df40b15d24ed3435f46c22b6c5", + "src/transformers/models/t5gemma2/modular_t5gemma2.py": "5feee8c51df3513aa3f75f0c7d6c9404f5cb50ba03ca5dd8f4fc92c26cf250ff", + "src/transformers/models/table_transformer/configuration_table_transformer.py": "ee2f3cfa80d7cb195a2392f22ceed3e969b87268ba73266eb18d393e911a4e13", + "src/transformers/models/table_transformer/modeling_table_transformer.py": "eb6be8eb35da0216ac918b714791a11bd8499b4eb301e05529733bdb531a11d3", + "src/transformers/models/tapas/configuration_tapas.py": 
"a2b106513778866670262009862e55c263b2072e1c4bc16db75761493cdbe880", + "src/transformers/models/tapas/modeling_tapas.py": "24f9efcf6eaf9c1f26c40f0fdc185b4c964adbbf6d1b1d150253b7b6f6f7f29a", + "src/transformers/models/textnet/configuration_textnet.py": "059ba9112d3a8bfc909ffd1ad8f71d6564e43636e01c5e8c5170136605cbff19", + "src/transformers/models/textnet/modeling_textnet.py": "33a6880b52c6155694659d000433eb01c2e495001b0cf02037fd1d0581c11591", + "src/transformers/models/time_series_transformer/configuration_time_series_transformer.py": "7d0fa2ba71464b3b76645cbffc061f8e0805cafe22c81d652ef1963e204709a1", + "src/transformers/models/time_series_transformer/modeling_time_series_transformer.py": "1678789e34a1aeaca9b3eb7c8289b3b0e5c586355c32b1742910acb441da5d18", + "src/transformers/models/timesfm/configuration_timesfm.py": "2abf26d41f2c8be3ca7fe814a65291b925afcd2983b1bb11f2c726dbd5325a62", + "src/transformers/models/timesfm/modeling_timesfm.py": "867ec94db7ff94558a6cf892c9c7909729c686f30a010a04b9fc328f7068247a", + "src/transformers/models/timesfm/modular_timesfm.py": "2f89a6f2b18a615db38e1a95af8db7652ac7590be6e75ee4def2e19ff2c1fd8a", + "src/transformers/models/timesfm2_5/configuration_timesfm2_5.py": "12dfb4fc3a4c51930f9e8eaa9bed7fc4c8c22ce46635d0eee4525cabc24c2068", + "src/transformers/models/timesfm2_5/modeling_timesfm2_5.py": "db9ac2128206101e0e925eefbb054652745d2bb4256b25f4b2bae2f70d8bee19", + "src/transformers/models/timesfm2_5/modular_timesfm2_5.py": "e007a0eeb98cfbf2f31a6439eab64e364af8518e87ce14e6e67b22d5c0c1719a", + "src/transformers/models/timesformer/configuration_timesformer.py": "7e95903fcde64369ff098dc5884b782dc71e48d118518feefb129e1eb3a4bc5c", + "src/transformers/models/timesformer/modeling_timesformer.py": "c4df75c7a4231c3f30be58ce69dcdbaa70ef657e610332f44d4df664b8c59839", + "src/transformers/models/timm_backbone/configuration_timm_backbone.py": "c24ef6bbc33e5f1019badf9e6acd52806789ced26b57c7438c30477ba9791221", + "src/transformers/models/timm_backbone/modeling_timm_backbone.py": "2b1e156de437cd34efcbc6792f3fcb4ac9d8504e8a6fcd3b4cc5e1e42074d3f1", + "src/transformers/models/timm_wrapper/configuration_timm_wrapper.py": "d454a21d942d854f509fc48c805a84d2c7c73f52529bc90a94b0db893f5b9d8a", + "src/transformers/models/timm_wrapper/modeling_timm_wrapper.py": "7e75b6441add311d7fd017843cd0695706ce36831cf6dd60202f670ade87769a", + "src/transformers/models/trocr/configuration_trocr.py": "671de0e15378200e18534270167dbce96c2f0d991145a8a317ac0851d8c192dd", + "src/transformers/models/trocr/modeling_trocr.py": "a229d1e2ee8882e436eb4d1f85f4558b1752e092c37c51d004d83f24bb00c380", + "src/transformers/models/tvp/configuration_tvp.py": "7df7b40ce53e32d24db7f6ac6fb22b6883f6889ad7c48c6bd7df8e6c89dd67b9", + "src/transformers/models/tvp/modeling_tvp.py": "fb02e73fad52996f171a2e21c5bcfa72557d17e5bc5b13b17a7ebe02a44f688b", + "src/transformers/models/udop/configuration_udop.py": "3daa483ad3f04060b7594f4f253f07858324b09b226321801d0d98cd09931811", + "src/transformers/models/udop/modeling_udop.py": "7c3380847038a220cc7cf254ae9791e40630c2b00c5d3c0352212806204d9dde", + "src/transformers/models/umt5/configuration_umt5.py": "d3fa5f25a6c79c5937d706887f5c75638566ee4b29c76f58e904b218b0df558b", + "src/transformers/models/umt5/modeling_umt5.py": "659a167590580f7462b78e4eea17967ae22111c956dae9d6073ba0b9b2fd53f3", + "src/transformers/models/unispeech/configuration_unispeech.py": "fc6690fbe0c6a3bea6a274e2f8c5e64a34fc4bf900bcbeb6b6660af8f37a48a6", + "src/transformers/models/unispeech/modeling_unispeech.py": 
"243489cac13bcb2b311b4aa5b7bdf757ca68bb0b6a9f14a19fd07fdc3ae996ca", + "src/transformers/models/unispeech/modular_unispeech.py": "a6724ad3c95e2902d3f35efe618a9c3f22b70f33d584d8471fc8237662f1debd", + "src/transformers/models/unispeech_sat/configuration_unispeech_sat.py": "f65db34f18861f13c083d6016256c0a3dd0da817504ff6b433d2292c09005a55", + "src/transformers/models/unispeech_sat/modeling_unispeech_sat.py": "80499bea64eab21a1c70d4f5f3444f1850bee15d7aa2c88e0a1d08ad8092c5d4", + "src/transformers/models/unispeech_sat/modular_unispeech_sat.py": "b2185429705c5096cf88d3b2670839a43418a90697aaaef93638cdfabf0a7304", + "src/transformers/models/univnet/configuration_univnet.py": "d1f4f6a3d3b8fc03c961db1bc0312c06d7016af445c51fe0feeb2434fbdc4457", + "src/transformers/models/univnet/modeling_univnet.py": "d9242eaf486d85b4df109f23bca8b0d105b8513be1fb3eb5f2147ed5e0fe28eb", + "src/transformers/models/upernet/configuration_upernet.py": "9dc1c0f1b141826c363a47b38a1d046af86788c3b2fd4d321256ea3501f1d89c", + "src/transformers/models/upernet/modeling_upernet.py": "0b84a17b6e5d141a1b8939e0cdfaac10fcc3fb227ec8876ef6f17db36ec31b01", + "src/transformers/models/uvdoc/configuration_uvdoc.py": "5e220d0899ffdbd8e20f9177524d21d78d189ab90f25f0bceb846bc1a1789531", + "src/transformers/models/uvdoc/modeling_uvdoc.py": "3c68b2f366e3edbc0ac0cc9250e2f4e794ba0f57a40d3ad47e77a25503805792", + "src/transformers/models/uvdoc/modular_uvdoc.py": "d0d619637b45f14892312e6c9856438ce3be718a533381affafc524b0a156ab2", + "src/transformers/models/vaultgemma/configuration_vaultgemma.py": "a1b2e84ab96dc80be8bbbdc511cc95df256d1c26872db3d77ef229a658505a8f", + "src/transformers/models/vaultgemma/modeling_vaultgemma.py": "333d95f12385509b301cefbbb6a6b2c6600c1d749a67c0ff2f0a3c895f2e1a06", + "src/transformers/models/vaultgemma/modular_vaultgemma.py": "1d01e307790ef3d93ba5b72f0ad8cab08b5d7ee91f61be445e700898884a59b2", + "src/transformers/models/vibevoice_acoustic_tokenizer/configuration_vibevoice_acoustic_tokenizer.py": "40079dbafe11652dacb56c028c7e7aaec698f12c863900e0a16f1ef559a6be75", + "src/transformers/models/vibevoice_acoustic_tokenizer/modeling_vibevoice_acoustic_tokenizer.py": "5d69a3c4a17093a3fa18161f49908c445f76f5c22ced1b644686071a295baa17", + "src/transformers/models/vibevoice_acoustic_tokenizer/modular_vibevoice_acoustic_tokenizer.py": "ad785b8322d2957b00a29073ac435a15fa94150499e18ad7ae6ded9a4f06863e", + "src/transformers/models/vibevoice_asr/configuration_vibevoice_asr.py": "979315dcc48aa47921ff6e885af86be6ee65ab2daeaf9f1de41dc44b410efcad", + "src/transformers/models/vibevoice_asr/modeling_vibevoice_asr.py": "d5b8cf85c5c6fa07b515c45b95ea9c58760707190912b04760d9eceab66dc9a2", + "src/transformers/models/vibevoice_asr/modular_vibevoice_asr.py": "935029559355546c9dd5064a9b3d3e27580f21047ed3aa1d3d5c7371a6930e84", + "src/transformers/models/video_llama_3/configuration_video_llama_3.py": "53f73bb2c36456c760662d56d8c022478b193dba6e1e68494bb2fad2db0b3a12", + "src/transformers/models/video_llama_3/modeling_video_llama_3.py": "1cfbb936132865b1baf843498d9899667513d305f12819a7a904b989e0a89d50", + "src/transformers/models/video_llama_3/modular_video_llama_3.py": "22d197d63a231e20dd6b397de29f4f901a0b1076cbc57e8085f4f0eb3377a321", + "src/transformers/models/video_llava/configuration_video_llava.py": "4b106d56292d32b5bf7a70ee7913feca24f03519ea44709d57ad310f47e0889a", + "src/transformers/models/video_llava/modeling_video_llava.py": "0c267dc405cdb41e5c8e0a97b21f4f4c70db6eb768fa911fb115c86a6ef6857b", + "src/transformers/models/videomae/configuration_videomae.py": 
"1adb23d83df6bba5340d2bbf8c0163b982ab3bbe07651ec02245e9f38739a02d", + "src/transformers/models/videomae/modeling_videomae.py": "8f187dca972d3db748149ee5a6ef4da8b9f182f600fe2dea5a2dee1a47aa908c", + "src/transformers/models/videomt/configuration_videomt.py": "149df79fceec422db1ac0ae40a7c94c3c3aa12e96607246098aa2cb5a0c920ad", + "src/transformers/models/videomt/modeling_videomt.py": "3a5a66df21962e758b6d8f3e1c2c20a5c2ee5828d0bdc5d409b2f5593ca44146", + "src/transformers/models/videomt/modular_videomt.py": "3672dd7e37e1a524ab877fb528c17f03d5a209f9700701b2cc2321530ee80399", + "src/transformers/models/vilt/configuration_vilt.py": "59a15b92ee78ab5fc85a88717cf05585fda9569f9da7fa5c11eedbe6bf39a66d", + "src/transformers/models/vilt/modeling_vilt.py": "5a5fd75f22d9182c2673de4f93468f6437cbcc0e342b4a6eeb0ae54f4d2e88b1", + "src/transformers/models/vipllava/configuration_vipllava.py": "91384e23e7c8ce045fe2d96c1cf651c3bdd8393f632c6afbfe8e4bbb576ee026", + "src/transformers/models/vipllava/modeling_vipllava.py": "f2906e27064ff719d3ede8fc8176ee2ca1bfddece5068772a06ad60ba51ba8a9", + "src/transformers/models/vipllava/modular_vipllava.py": "dd364d79983d80bddae2507fa0b2b81df41e1a7f0db1593a259d24de44aeed1a", + "src/transformers/models/vision_encoder_decoder/configuration_vision_encoder_decoder.py": "0a81582a774155b504094aae6e6a44579a1f56a2703edb9bf82e7f3b3ec64dd1", + "src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py": "1d71b5e756458ab5c261701416e14dea7bb9602cb2e1dbba3b6571878df49285", + "src/transformers/models/vision_text_dual_encoder/configuration_vision_text_dual_encoder.py": "a1905873800d646ac3579c6122d81a8632fa58e1f933987a653ee3f113f380fd", + "src/transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py": "c312c21577612a61b88eacc8f130428ec29cdf5a4475817ded3b5aae1092951e", + "src/transformers/models/visual_bert/configuration_visual_bert.py": "4091adcb8c45a5e25293ee7d9df08ab66d3818f92895d144646d372b320a4fec", + "src/transformers/models/visual_bert/modeling_visual_bert.py": "bb0408273c514ba56bf5d3a1246d68972899f2ba77dd7fba2758d7ae7f9e8eb9", + "src/transformers/models/vit/configuration_vit.py": "f15ab198fe308b52fc539ecf3e055de774719dc00c86c833b4aa51d23987179c", + "src/transformers/models/vit/modeling_vit.py": "900b7703dd256170cbb0a8c59a2ef3f54c93768e8bc1b5d6043c957c6771176d", + "src/transformers/models/vit_mae/configuration_vit_mae.py": "5966821656939a20443efcec6c16415be865fee7d12e4423ebb2d85bc8c16953", + "src/transformers/models/vit_mae/modeling_vit_mae.py": "ecec97ec8440d4109d13e2161cad6dd76ebc87d96855627217ee4fbdb5929f38", + "src/transformers/models/vit_msn/configuration_vit_msn.py": "445d7ccabcc998562fead5adb02fc8cae615f4506d7188d98d7c32d5952f72fa", + "src/transformers/models/vit_msn/modeling_vit_msn.py": "2bd91b70265ecd0a6ed4dafa8ad90b3e8bd07aa0ffa2c869dc3221067fb586e1", + "src/transformers/models/vitdet/configuration_vitdet.py": "99dfc5960e5f8d3d2f9e4f087b629291260ce622e518735691a3a044e5d0c460", + "src/transformers/models/vitdet/modeling_vitdet.py": "359ba606a369817b4c23449e5f4cc4dea27f565ee72dd925f481ec2655e63648", + "src/transformers/models/vitmatte/configuration_vitmatte.py": "a635ce922f5052037ae7c25ef8292805b6e775a7173607ad633e15f670caeeaf", + "src/transformers/models/vitmatte/modeling_vitmatte.py": "5191f4f733b30c082af4f9285a302449ca12d322967eabe8f1b667568e2279cb", + "src/transformers/models/vitpose/configuration_vitpose.py": "1933e851afacf6640480d623dfa6db57b2028f97c480c0a73a00d26c54229629", + 
"src/transformers/models/vitpose/modeling_vitpose.py": "4c43b51df9ea97e3cc0f6e8ac1d7a71e2a4b04e8c5b38f413edebcd81a81675c", + "src/transformers/models/vitpose_backbone/configuration_vitpose_backbone.py": "42a054ed8ee6e4376dc2db96bede782f3c8addea99e11967b5ae325d5b423270", + "src/transformers/models/vitpose_backbone/modeling_vitpose_backbone.py": "c78c5dd6fff44ad7b3c912094344d4c25bf436460037c7cc38fd9149a52e4f54", + "src/transformers/models/vits/configuration_vits.py": "85b5c87a878a9a9c80f91cd5034a6a505ef52aaf5a405b2d0bbbde4a92894bc4", + "src/transformers/models/vits/modeling_vits.py": "58cd039119dc3586a70cecf7e642233681ccee047ba240bd6a537c75b34b5cee", + "src/transformers/models/vivit/configuration_vivit.py": "767519befa6bc580f71bb222a1e8c87772c97252a0c345bbe0e86cc122e62f36", + "src/transformers/models/vivit/modeling_vivit.py": "dcee9e2418de62a4566706267e5e10fe17826e52f6dff01da74508b90426d1eb", + "src/transformers/models/vjepa2/configuration_vjepa2.py": "6dbbcb9e885488a99498d1db0a48950cb62139d0cad0d4c7cd09d28e949a6b7d", + "src/transformers/models/vjepa2/modeling_vjepa2.py": "11ec392039a0c290b820956948b35bdd6872878119ce95d7e8daef97578fa098", + "src/transformers/models/voxtral/configuration_voxtral.py": "beefbe93584882eef99d9c7972127108c70fc7d048cd2d88657aedb5a9614995", + "src/transformers/models/voxtral/modeling_voxtral.py": "d3fa5ee2975383473797d79e6f3ad467d3c8d9cd9a5c825de295bce7ac3fff76", + "src/transformers/models/voxtral/modular_voxtral.py": "a124af843b2fe87f44ce0dfe3cc671010b316ed84150b0109b902dff7864d124", + "src/transformers/models/voxtral_realtime/configuration_voxtral_realtime.py": "191b686759d8ad262648c00896899d993d828a5d0833d9bae83bcebeaee22d89", + "src/transformers/models/voxtral_realtime/modeling_voxtral_realtime.py": "9e1f26cc719bca3d76068ea6ab7a9f83763f8b0c38b2bc0ac1191602e314c497", + "src/transformers/models/voxtral_realtime/modular_voxtral_realtime.py": "d6c69d7d6af9e774ea92485faf3cba4813d42a9471f01a38096c8d722e6ba224", + "src/transformers/models/wav2vec2/configuration_wav2vec2.py": "8650bebeea0e71263c64bb82d5c2a52d450247d554b6da18fba8a8fff11efd93", + "src/transformers/models/wav2vec2/modeling_wav2vec2.py": "053cd02abe8cbb63a2f23d2eb486bd33723dab884b1718aa6d0923aa215d6424", + "src/transformers/models/wav2vec2_bert/configuration_wav2vec2_bert.py": "bb037d4710bf7117c8cd9e8aabef59143ad39a9de315d41b4f13d64a76453ca6", + "src/transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py": "d89cb5a2532f9986cbdf96663987b8c71b79cefb6501a701189dbb657a33c48d", + "src/transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py": "651a7c381d07e473f8dab7e0598a306650553230868770bfcfa5221fd2be2aa1", + "src/transformers/models/wav2vec2_conformer/configuration_wav2vec2_conformer.py": "2e138af1532aca4c5a1eb922203933309f1f4c12d7ee5007be36d68d593f37db", + "src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py": "1baeb85979dc5237962b8d2689033647c5e56ad76d8a27603edfd615175e1c26", + "src/transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py": "d8d459a537db62a26ad90dcdfad25160f17ebde81269e21db64d310c89fb66a1", + "src/transformers/models/wavlm/configuration_wavlm.py": "6c07717bea00a8c6b622b9cd97123c19d3e0744fa76ab3791676ae9a97d0324d", + "src/transformers/models/wavlm/modeling_wavlm.py": "43cd2c54edadfd5f697f68ad64c462d45707ccfabd9d15811213bd5e409340a3", + "src/transformers/models/wavlm/modular_wavlm.py": "146c101d45ef4442ce295f6ca1e0e2cbaf031038f918e7c54b379eef23c10b35", + "src/transformers/models/whisper/configuration_whisper.py": 
"91a5cc9d8e2284490628c48f5470d1b89a233b16ab219ad9d300b3ffc978de7c", + "src/transformers/models/whisper/modeling_whisper.py": "0d711e4623341a3da969d5fe90841a9ee28a4ac464ca42a7c8ea8f3e1d1d54da", + "src/transformers/models/x_clip/configuration_x_clip.py": "d39e1d1d73090f322369c5679d18f65ac0b897e76d66620fcf66390e3dffe346", + "src/transformers/models/x_clip/modeling_x_clip.py": "3cea422f2a284b6135e7e0efa09a5c10509bb142861c59bd66ad37ef8f241b4b", + "src/transformers/models/xcodec/configuration_xcodec.py": "d2cd3b2c86368b597476825ad2eac62fe8f0409a2c8b1f12cef97784ca733500", + "src/transformers/models/xcodec/modeling_xcodec.py": "03be2a880f1429d8722f700cdd5479f2da555ee4d241be774a122779c71b734c", + "src/transformers/models/xglm/configuration_xglm.py": "2dc8cea98578cb05cbeaddf4ca1e016860f7365c785d1c1a1af6f3f3eb3fa9d4", + "src/transformers/models/xglm/modeling_xglm.py": "7119c77597c7966720bf8c04228c41755a3e225a7b68f74e5d15f02964e9a023", + "src/transformers/models/xlm/configuration_xlm.py": "f0b5d2b6b9669d845540b2b9df5a2b7951354796a47d1bc985f14b23d631daa2", + "src/transformers/models/xlm/modeling_xlm.py": "baec462f1e4308c31b084ade8618fb2041a6a760a6090a4bd57ceaf4a8ef7dee", + "src/transformers/models/xlm_roberta/configuration_xlm_roberta.py": "7569a207cea00ea0eb1e22bf725e071d409f705c512ff1e33cc30a798665a193", + "src/transformers/models/xlm_roberta/modeling_xlm_roberta.py": "c8394866e7f785a6c2652c8341ff0d649dbe76ca619bf4078f66e2ec689b142b", + "src/transformers/models/xlm_roberta/modular_xlm_roberta.py": "8c8469c3867eb85c2474eec51406d32187b1f8169c6517d77b79cdd8ac41dde0", + "src/transformers/models/xlm_roberta_xl/configuration_xlm_roberta_xl.py": "fff0d189298232584d020e3b0b12184855bd49c3c2821c7385d29ae564be0664", + "src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py": "d9e4579799bfce9a55ea16ca5915033a1e463a6e1f2aa37d6fb852b780e1b7c3", + "src/transformers/models/xlm_roberta_xl/modular_xlm_roberta_xl.py": "0db7040a77c0db9c71adb37374a6fe11655698bdd511c744327a5f90584b4f66", + "src/transformers/models/xlnet/configuration_xlnet.py": "b23ca312b54597df3c9515bd14e2bb8b0bccf7d71722c2030900eeb05f06ac2c", + "src/transformers/models/xlnet/modeling_xlnet.py": "26ef15cc9e19c5f1ad78e7beac1d02620659b4b7811ee07944b082cea346d827", + "src/transformers/models/xlstm/configuration_xlstm.py": "2fa1e1de35a743aaf4faa4399be2f6669990ab5370abf6a5670278ddacbd86ad", + "src/transformers/models/xlstm/modeling_xlstm.py": "658246eb2c90671f0912bc2bae93385a32f8a74b87c21ec334e34a1d6ccbff22", + "src/transformers/models/xmod/configuration_xmod.py": "fc8326fb791bf9165dad3af262e172b8c5e1985f4eb96ec6369a47de27ff3ebd", + "src/transformers/models/xmod/modeling_xmod.py": "019b42bd4b6373dddf2b15a16b9c20b3f0915dae92a7633bb5ef9f2ee1bb5ef9", + "src/transformers/models/yolos/configuration_yolos.py": "8970bd7e2ef458e6063b5fccf136c237b063465b4f7fd7be41c0f0b4e1fc1aa4", + "src/transformers/models/yolos/modeling_yolos.py": "32933319a7ef76a7d2d026abb2def66f2576349759836b23ebbe95682024c56f", + "src/transformers/models/yolos/modular_yolos.py": "67b24ac8dd457ab763f88b08a368c43fe32fb1487989dac63cd60c56e98badd1", + "src/transformers/models/yoso/configuration_yoso.py": "18fcac752ecd69aac9cda2ceadb444204a1b8b2536bb25030d2b34a34571279a", + "src/transformers/models/yoso/modeling_yoso.py": "ee3856557776db9294841314d24f39db87253677984766fcad6c78139a3db423", + "src/transformers/models/youtu/configuration_youtu.py": "72981b4a97b2f39069e4d3a74a008ab8acfc600c898c494662ee2f0c7683315d", + "src/transformers/models/youtu/modeling_youtu.py": 
"fbc8374c50df05da935a1b040aaf0a90e93689804c7b4e35e8d6b8c0813f66ae", + "src/transformers/models/youtu/modular_youtu.py": "ddbfc95d009cf94a6f3487c0d064c7a3aa2ef6c6015898c51988b5dab9c452c8", + "src/transformers/models/zamba/configuration_zamba.py": "23566c8ef76743f6be134cd1d5853d323c87677161ba2b477a57722e0215f376", + "src/transformers/models/zamba/modeling_zamba.py": "205a0b6cc541fe9d05e811226cfa8e0649cdb6fa634cfa888dd79f6b4f5263a6", + "src/transformers/models/zamba2/configuration_zamba2.py": "c98074770b92db6d888aaf530279b2b4f156357a6d4365cfce7298e88452ab12", + "src/transformers/models/zamba2/modeling_zamba2.py": "c4dbed4799ed9a2374a7eb6c55ebda6c7710310ed3b6d8508469a32c7babb999", + "src/transformers/models/zamba2/modular_zamba2.py": "546cc01d3910e159e566a41a8d86ed1c955eac41391871cf3e3cb2aef7ef24be", + "src/transformers/models/zoedepth/configuration_zoedepth.py": "1a2474f62a1d0e91bb183c1c2845dbbd4af92e5aa5ac2e7d1db96499c5b39907", + "src/transformers/models/zoedepth/modeling_zoedepth.py": "97e08f50416ab8ac3b6c52e6b3471aa3573b8619e85765acc6d0001b1bb8f657" +} From 908f0da1dbea78b0c02e03e5d25fb3318c09c5cd Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 31 Mar 2026 11:06:14 +0200 Subject: [PATCH 51/56] revert unrelated what are they doing here --- .../models/align/configuration_align.py | 1 - .../models/nemotron_h/modeling_nemotron_h.py | 8 +- .../models/zamba2/modeling_zamba2.py | 8 +- .../models/zamba2/modular_zamba2.py | 8 +- utils/mlinter/.mlinter_cache.json | 1073 ----------------- 5 files changed, 12 insertions(+), 1086 deletions(-) delete mode 100644 utils/mlinter/.mlinter_cache.json diff --git a/src/transformers/models/align/configuration_align.py b/src/transformers/models/align/configuration_align.py index babf97d4572a..cde6445cf62f 100644 --- a/src/transformers/models/align/configuration_align.py +++ b/src/transformers/models/align/configuration_align.py @@ -59,7 +59,6 @@ class AlignTextConfig(PreTrainedConfig): pad_token_id: int | None = 0 bos_token_id: int | None = None eos_token_id: int | list[int] | None = None - tie_word_embeddings: True @auto_docstring(checkpoint="kakaobrain/align-base") diff --git a/src/transformers/models/nemotron_h/modeling_nemotron_h.py b/src/transformers/models/nemotron_h/modeling_nemotron_h.py index af09fdfaf36f..9e264e5cfdcc 100644 --- a/src/transformers/models/nemotron_h/modeling_nemotron_h.py +++ b/src/transformers/models/nemotron_h/modeling_nemotron_h.py @@ -435,9 +435,9 @@ def torch_forward(self, input_states, cache_params: Cache | None=None, attention dBx = dB * hidden_states[..., None] # State calculation - recurrent_states = cache_params.layers[self.layer_idx].recurrent_states.clone() - recurrent_states = recurrent_states * dA + dBx - recurrent_states = cache_params.update_recurrent_state(recurrent_states, self.layer_idx) + ssm_states = cache_params.layers[self.layer_idx].recurrent_states.clone() + ssm_states = ssm_states * dA + dBx + ssm_states = cache_params.update_recurrent_state(ssm_states, self.layer_idx) # Subsequent output # [bsz, n_groups * state_size] -> [bsz, num_heads, state_size] @@ -446,7 +446,7 @@ def torch_forward(self, input_states, cache_params: Cache | None=None, attention C = C.reshape(batch_size, -1, C.shape[-1]) # [bsz, num_heads, head_dim] - ssm_states = recurrent_states.to(C.dtype) # Shape: [b, h, d, n] + ssm_states = ssm_states.to(C.dtype) # Shape: [b, h, d, n] # Reshape ssm_states to merge the first two dimensions ssm_states_reshaped = ssm_states.view(batch_size * self.num_heads, self.head_dim, self.ssm_state_size) # 
Shape: [b*h, d, n] C_reshaped = C.view(batch_size * self.num_heads, self.ssm_state_size, 1) # Shape: [b*h, n, 1] diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index ea5054194de6..6e4ea7dcf2d8 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -723,9 +723,9 @@ def torch_forward(self, input_states, cache_params: Cache | None=None, attention dBx = dB * hidden_states[..., None] # State calculation - recurrent_states = cache_params.layers[self.layer_idx].recurrent_states.clone() - recurrent_states = recurrent_states * dA + dBx - recurrent_states = cache_params.update_recurrent_state(recurrent_states, self.layer_idx) + ssm_states = cache_params.layers[self.layer_idx].recurrent_states.clone() + ssm_states = ssm_states * dA + dBx + ssm_states = cache_params.update_recurrent_state(ssm_states, self.layer_idx) # Subsequent output # [bsz, n_groups * state_size] -> [bsz, num_heads, state_size] @@ -734,7 +734,7 @@ def torch_forward(self, input_states, cache_params: Cache | None=None, attention C = C.reshape(batch_size, -1, C.shape[-1]) # [bsz, num_heads, head_dim] - ssm_states = recurrent_states.to(C.dtype) # Shape: [b, h, d, n] + ssm_states = ssm_states.to(C.dtype) # Shape: [b, h, d, n] # Reshape ssm_states to merge the first two dimensions ssm_states_reshaped = ssm_states.view(batch_size * self.num_heads, self.head_dim, self.ssm_state_size) # Shape: [b*h, d, n] C_reshaped = C.view(batch_size * self.num_heads, self.ssm_state_size, 1) # Shape: [b*h, n, 1] diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py index f39b6de31ff6..d7716301ad4a 100644 --- a/src/transformers/models/zamba2/modular_zamba2.py +++ b/src/transformers/models/zamba2/modular_zamba2.py @@ -511,9 +511,9 @@ def torch_forward(self, input_states, cache_params: Cache | None=None, attention dBx = dB * hidden_states[..., None] # State calculation - recurrent_states = cache_params.layers[self.layer_idx].recurrent_states.clone() - recurrent_states = recurrent_states * dA + dBx - recurrent_states = cache_params.update_recurrent_state(recurrent_states, self.layer_idx) + ssm_states = cache_params.layers[self.layer_idx].recurrent_states.clone() + ssm_states = ssm_states * dA + dBx + ssm_states = cache_params.update_recurrent_state(ssm_states, self.layer_idx) # Subsequent output # [bsz, n_groups * state_size] -> [bsz, num_heads, state_size] @@ -522,7 +522,7 @@ def torch_forward(self, input_states, cache_params: Cache | None=None, attention C = C.reshape(batch_size, -1, C.shape[-1]) # [bsz, num_heads, head_dim] - ssm_states = recurrent_states.to(C.dtype) # Shape: [b, h, d, n] + ssm_states = ssm_states.to(C.dtype) # Shape: [b, h, d, n] # Reshape ssm_states to merge the first two dimensions ssm_states_reshaped = ssm_states.view(batch_size * self.num_heads, self.head_dim, self.ssm_state_size) # Shape: [b*h, d, n] C_reshaped = C.view(batch_size * self.num_heads, self.ssm_state_size, 1) # Shape: [b*h, n, 1] diff --git a/utils/mlinter/.mlinter_cache.json b/utils/mlinter/.mlinter_cache.json deleted file mode 100644 index 39154b73c1d7..000000000000 --- a/utils/mlinter/.mlinter_cache.json +++ /dev/null @@ -1,1073 +0,0 @@ -{ - "src/transformers/models/afmoe/configuration_afmoe.py": "cb18784c5d578c0352c0f2958afb473b5aadd3ef4ecdc5216150204ab553493d", - "src/transformers/models/afmoe/modeling_afmoe.py": 
"a5681d5e3ed6e4c25c8cc5d207a40b1757c7f21e342c982380ec1119ac44861c", - "src/transformers/models/afmoe/modular_afmoe.py": "ae6dfcbc3fcf34c2fd3b217128d1e33db0d70db9045a19fb6da66e40f1f50a2c", - "src/transformers/models/aimv2/configuration_aimv2.py": "89696ddda44298d16a3c5c5e40741588d5d66a0679339479099a5ff42f6a23f1", - "src/transformers/models/aimv2/modeling_aimv2.py": "cb9287ca4946a51b16c99eb49bb4eafbad9df2a81fab1aa61d3beb7cab8c0a80", - "src/transformers/models/aimv2/modular_aimv2.py": "97d66d6d0756d07e7e0d1d730b4c89da394412d4507d92ede09fd7f5b11623f3", - "src/transformers/models/albert/configuration_albert.py": "e3c4d4e4c87111b669cc285cadb206d58731e0c88d166a1113a0f096f4a8909f", - "src/transformers/models/albert/modeling_albert.py": "fc64e43d93f5bbbe8734c663ef6740a559799cabf441e436cb2901cac9b27a38", - "src/transformers/models/align/configuration_align.py": "c4dc28a3ba2be74752f6ffecbb75c425332f2670a9de6d13584f6779a5ee1058", - "src/transformers/models/align/modeling_align.py": "73ddf7860acecd9ba61dc269ab54bb5e949c227016eefc3a66d823dcab8ef94f", - "src/transformers/models/altclip/configuration_altclip.py": "9487f951824ef5c5f60eb001604ead445b48968d8738ba674b4823cf5f7e298f", - "src/transformers/models/altclip/modeling_altclip.py": "b2d926d1f63e86913b12cdeaf3cbee1b28c95b6c80765a30a63ac910f9dd02b9", - "src/transformers/models/apertus/configuration_apertus.py": "ae34c92ef6630fdf3b4875f5b4b6fc08fb67a6b11533515411ae37b7a4dc4ed5", - "src/transformers/models/apertus/modeling_apertus.py": "f6e3cb98e5dfd454dafe50c36b52734e554700dce7cd0dced4ded6791c586b26", - "src/transformers/models/apertus/modular_apertus.py": "7c0c5f12ce3e6bbe163501d0a476eeb52c88255de081e5a03cebcfb42f7a5e67", - "src/transformers/models/arcee/configuration_arcee.py": "c1bdb413f20fda66604002b0c1d2fd06cb67c1e6bc3090b0282e34a77a52e387", - "src/transformers/models/arcee/modeling_arcee.py": "1fdfa8f1d32d2f2193d3146ed6baceafb3122e72f6888c4b471f9cd3be44a087", - "src/transformers/models/arcee/modular_arcee.py": "bb3829636df7ae7428960147106cbbc23f73503f57919371c8d79ca75c7fa45d", - "src/transformers/models/aria/configuration_aria.py": "8cabf127bb9ead8ae62278279aebac74aaecf8eb78a28be1e19af8e984aec4e6", - "src/transformers/models/aria/modeling_aria.py": "9828aae955a915bfc59b9714d6ae9c560f0791131d1a15d95218564fadffb66b", - "src/transformers/models/aria/modular_aria.py": "11784c387433a2ffa27e5eefadc8eda4c3bb75191a3077fc4d63a01421d80203", - "src/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py": "fd56d62a882a8c4baaa6e07b437635c24360a1e6f782125480baa6fbec35f18f", - "src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py": "0c744ebdb41884f72e9fd983da1f5b4a62ead979ac0d3f0010edeb968056401a", - "src/transformers/models/audioflamingo3/configuration_audioflamingo3.py": "23203bcec4b5df4af65faf6dc95d2ea6e9a3c36f38bdb4ac200a5b311674a6e2", - "src/transformers/models/audioflamingo3/modeling_audioflamingo3.py": "14b9c7e3a43c55ea507b41e815b3acb6d02fb0ff917b55a616e5451b70dc42d0", - "src/transformers/models/audioflamingo3/modular_audioflamingo3.py": "ba453855d1764eada7316808882e1bcca9a54505d58d97fbe701730d84485eba", - "src/transformers/models/auto/configuration_auto.py": "9bc2674a59eac35c771a46cc4059aaeb37e8e6602f17d36f8b10c3d18babc502", - "src/transformers/models/auto/modeling_auto.py": "b01d5c96f8d7125b2f70b630032dab0ce24a42f45e4f46b963ba34492bf5ef16", - "src/transformers/models/autoformer/configuration_autoformer.py": 
"243ec867b1578384d0860237d85a86451dce375b6b013f209a06c74411bfc810", - "src/transformers/models/autoformer/modeling_autoformer.py": "4d2f8c8cedd2a220a9d48dfcbe8dff1b8605390ced03a860293e3b81a3524d7e", - "src/transformers/models/aya_vision/configuration_aya_vision.py": "e8039bc9df5ac44693f533fd50b53b0b30dd3cd3f17f66a9df156b139e1f1b1f", - "src/transformers/models/aya_vision/modeling_aya_vision.py": "bfebfb56427cf3428c25fd5c2869cd3d5f95061220f67abca10fd102b9330f5d", - "src/transformers/models/aya_vision/modular_aya_vision.py": "68f24cffef72590c7f59074803d806942e2e290bade6737387c1e36faddbc9bb", - "src/transformers/models/bamba/configuration_bamba.py": "cf1b9f81f03a825255dc9546022356a6eb3b1e6f4cd664c3278759fe7d3f3f08", - "src/transformers/models/bamba/modeling_bamba.py": "420ea53746e558276e350de794ff5b234a79e1ec9170eba5fafa1f3f2a82b9c9", - "src/transformers/models/bamba/modular_bamba.py": "23adb0e8edbc56ba4bd0bea5ad53904d5963a7e36f42d6fcbc66d17691cac65f", - "src/transformers/models/bark/configuration_bark.py": "237c42af22103f191e551a5a1192a9bf4794e2b7861189a58cf6a2bbb75daa7b", - "src/transformers/models/bark/modeling_bark.py": "3bbfc82ff3f700a4b5deb860a423bf644ea26e7c3bc469407597351dda4533e0", - "src/transformers/models/bart/configuration_bart.py": "5f56eecabf2ff9bbf8f53c5df4b7368287e4626ea526bc1e7dc1776fc6922e5c", - "src/transformers/models/bart/modeling_bart.py": "8e8d6b713dc94cbced6cc05c54b0fc9217dcede141e79891def2b4c65cf205b6", - "src/transformers/models/beit/configuration_beit.py": "cc1c33ab0e97b5a6b1c7274d96b185d0484b7d18989720f18294dc656dac67f2", - "src/transformers/models/beit/modeling_beit.py": "e5e3be58febade51052456042defdb04f8bac95bc00ccaa7c299c2f847947a71", - "src/transformers/models/bert/configuration_bert.py": "924e13540f603e40b2b4ed51d7139eb95ed646a851d7cbb2ee6186da7c9829b9", - "src/transformers/models/bert/modeling_bert.py": "bc1c375d781fdbeb424e4592ea82be8ec25c4841a057f5f85e08f20f50d2068a", - "src/transformers/models/bert_generation/configuration_bert_generation.py": "ee10f065b884880ba1f6c9d08a966b2ed26812ee7602a7f565e7d87e683c1c6a", - "src/transformers/models/bert_generation/modeling_bert_generation.py": "5a0cfaf598a7be00806d9e3aa2a59c31aa7d4c671652e2e15cad600b14fe5080", - "src/transformers/models/big_bird/configuration_big_bird.py": "c328fdd36de02bef93bc9c02c34aac7d61e23ce32cd23932c894dc1b8ce5d35a", - "src/transformers/models/big_bird/modeling_big_bird.py": "1700b3b52a5ae85e5321367be5edfa1e3a4cc8e0709fc5c19b88ff4eb9be5523", - "src/transformers/models/bigbird_pegasus/configuration_bigbird_pegasus.py": "dc782ebd96e20824279e1c60c9dccdabfb4569da461a1054e6f46c1c15e5e8dd", - "src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py": "78b36af511d3a2e4df5ed526bb609638488ea5574d10b75f4d349b907fcf3077", - "src/transformers/models/biogpt/configuration_biogpt.py": "a34f6a9c53d25bea7d3df6b8a73612b287e754157e557c7061d45e4b367d36fb", - "src/transformers/models/biogpt/modeling_biogpt.py": "47a87ee17f7d5cca6221d5c87f1836af8c15da04ab5a16c0d3c5c67937d6b154", - "src/transformers/models/biogpt/modular_biogpt.py": "82a58e3790d5d236add39b6d7233da5f2d98108e1150e640e43f213952936e5b", - "src/transformers/models/bit/configuration_bit.py": "db00bd41886ba99305fde95b0b3ebad06ebfb143438441bbdf496697db50a78f", - "src/transformers/models/bit/modeling_bit.py": "44724c5dc21875fb438691a8b86039d60a34c903c6b0c92d88fc543bd1a57f8c", - "src/transformers/models/bitnet/configuration_bitnet.py": "7d79cd088e7cf06e5f27084c2aa98971ae146dc51623fd22611c3b29e23fb684", - 
"src/transformers/models/bitnet/modeling_bitnet.py": "1f589083a3b6d7836763de0b641b2555eafddb20c89aa7c985d7185aa30a7ee2", - "src/transformers/models/bitnet/modular_bitnet.py": "1bde132b6ae0857152632e990006a8872eebbadd76ac975c53194fe051896df1", - "src/transformers/models/blenderbot/configuration_blenderbot.py": "170ced3b7b82617975dfc05f34bb64d3c82dbcf1ef6a78b76794d554e7466f4e", - "src/transformers/models/blenderbot/modeling_blenderbot.py": "14505b502d4ec2d94bf15934752ba5b896caf478ba775a66f2e6d00ee51189ab", - "src/transformers/models/blenderbot_small/configuration_blenderbot_small.py": "03c1529efa869494c1293229409c9903184af564d857465d6c50eb8196433653", - "src/transformers/models/blenderbot_small/modeling_blenderbot_small.py": "64f35af1e75ac11f88bb294ec370ccde54782b1594bd96597870b9d7d2555508", - "src/transformers/models/blip/configuration_blip.py": "ac921b6675d31bae5a879571d49ac1fc60ebb5da0536fe636b7718d2f72304f3", - "src/transformers/models/blip/modeling_blip.py": "015f1b98b382a666db8c31711c6f295acd3c89254292401ff937f8c877f660b7", - "src/transformers/models/blip/modeling_blip_text.py": "eb0378f5c76c4a9ba5c1ea5ad3ad5287bf40e492afdce059bd84ddd4c1856087", - "src/transformers/models/blip_2/configuration_blip_2.py": "03e6a77499c9e5d172d6e04a69c1f126772bbd67a5217c41dcdc0711420063e3", - "src/transformers/models/blip_2/modeling_blip_2.py": "ed4e5d9fb1ad6557076af88f7679b91434d906b366acb4930887abd2bddef136", - "src/transformers/models/bloom/configuration_bloom.py": "47d054ceb7654b1328441e6f59d1afdffbd8f80685f708e5d1d30c5e53974324", - "src/transformers/models/bloom/modeling_bloom.py": "798a7dd43aa1e9584219255c36740873ea4d1f204c002ddb8ec7296bcf70128e", - "src/transformers/models/blt/configuration_blt.py": "9cfb74b1e0eb9ce9611a19942b6e82c32573470ecb8195c043e312e7b23abbce", - "src/transformers/models/blt/modeling_blt.py": "81e4bf6e0737688e8aa552d802c56362af7a171ab3464c546f5e73adc8505d9c", - "src/transformers/models/blt/modular_blt.py": "529435735e3413bcfc5c3a8096b56e58d57aa176a6628989c5d64039f8c2d0bb", - "src/transformers/models/bridgetower/configuration_bridgetower.py": "dabf6a2d4fdd1d38ae9f7b32c620ec7da52730e9533e668861513d7adfc7f7d7", - "src/transformers/models/bridgetower/modeling_bridgetower.py": "7e99550d2f015f6ce304bc73a461211080adfa52433298efb821ac81a75db03a", - "src/transformers/models/bros/configuration_bros.py": "af1f360c52851af9e08aa6767c39274d99f224e4b559fda04599461a9d6f50df", - "src/transformers/models/bros/modeling_bros.py": "0efca639237d5be48c6d28148d268c1728a5062f4422493fea5660e72c4a0cbf", - "src/transformers/models/camembert/configuration_camembert.py": "65bbc35964cbb42e7d06ea05128302237b80c16f8fb38b794e48299831bb0d31", - "src/transformers/models/camembert/modeling_camembert.py": "463f946576635f1db3013d5b5c77a98d99aea36c31b4e472f6a8efbbc47b24d3", - "src/transformers/models/camembert/modular_camembert.py": "12b8cfeeac5270dd43827aad7f3f3f61d67d781fd458b7d033f4daad44a386f6", - "src/transformers/models/canine/configuration_canine.py": "660d924669bd0ed8ecdea59870c8a351d3eda6d5b4850faa3a744db4a41a64a1", - "src/transformers/models/canine/modeling_canine.py": "5d180ff1dcfe4e264284338cfe74806392f821bc6abceb3c807249807339e911", - "src/transformers/models/chameleon/configuration_chameleon.py": "c180ffea27d9f06fcedccff1507367bf1e90aa60e933acaa7f5cce80162160d5", - "src/transformers/models/chameleon/modeling_chameleon.py": "5a234980db08fbfef75e932b3150f583f93439ca63939967a4d2395e59c153ec", - "src/transformers/models/chinese_clip/configuration_chinese_clip.py": 
"8b80e122141da04367b1d44fdee1ea2e2ca3d0e11f65be2fcc478870758f595b", - "src/transformers/models/chinese_clip/modeling_chinese_clip.py": "17c95c09d63f1a0a9304b55e6993e3b2c7edfb32b740b1a462171d20dfc7e310", - "src/transformers/models/chmv2/configuration_chmv2.py": "a6b4ecdfd6d5f728ba49e2fa06a97c469d32802ca4910eae1dd5b0c23c6dcf70", - "src/transformers/models/chmv2/modeling_chmv2.py": "2124cf933637bb60bb831bd3a632fc00e0d16c314b0e0a43e051370a738e58fa", - "src/transformers/models/chmv2/modular_chmv2.py": "27d4a6c27e4625bb9c881bad275cfc486db675a9581aed8a026e0e6e9db934cb", - "src/transformers/models/clap/configuration_clap.py": "0a4e35390b3a48ed865f342772e88f92d2338730f03bb1b95a3e66de3d7cfeba", - "src/transformers/models/clap/modeling_clap.py": "fecade298c566bc1656c64392c74ec6c7964a88c79d70888f48d746deef270eb", - "src/transformers/models/clip/configuration_clip.py": "50f7854da6572cf58f4c7087113a1cb5c0e5a37be413cdd33d820640c1cdfe45", - "src/transformers/models/clip/modeling_clip.py": "4f6dfbdd988328fc585c719ff88a516ceb90984618b82ab3acdf828becb741dd", - "src/transformers/models/clipseg/configuration_clipseg.py": "7676e7950a04a2e9a0e225f73fa1229362c2388991deedf4f27c639323badfc0", - "src/transformers/models/clipseg/modeling_clipseg.py": "20482c86460de66145ae112edc25ff58561765309604be426df2b20c5bf5eec8", - "src/transformers/models/clvp/configuration_clvp.py": "aa0fa5ac98e1a2e8dbac9b1c32e1bb2d6c8824d81853d1b40717cca15d46ade5", - "src/transformers/models/clvp/modeling_clvp.py": "2979ab9758d37b4b555db31b686b073c0b9e444253cdbfe912f15820fc1d7f46", - "src/transformers/models/codegen/configuration_codegen.py": "14cf06cc4237fbab5836d8607634e98db62ff778eed49b0060705bd7d900d99d", - "src/transformers/models/codegen/modeling_codegen.py": "410f4e458ba972a4431f19833397013ba311933b6f726d3280157269978fa5d1", - "src/transformers/models/cohere/configuration_cohere.py": "eb90b47a9977d1a201ffd823ac25c838e467412ee98c3d6aef81fc538c3975d9", - "src/transformers/models/cohere/modeling_cohere.py": "64fb21af826bf06f77b3e5cd3a41bb5f696aabce3afbd8f8ef73860cd37dc105", - "src/transformers/models/cohere/modular_cohere.py": "1dc15708e9c61f30dfe4cf43224017771355d0e7659f23ca8dde06906163e1e8", - "src/transformers/models/cohere2/configuration_cohere2.py": "c6a6e4e4cac03ed563f6407709cae4486ca131976a0152c4bfa58587a8a04fc0", - "src/transformers/models/cohere2/modeling_cohere2.py": "59e2f99fb0a33b3b74b63e6ba3869e8cecfcd70cda7f2131004699d7b7843949", - "src/transformers/models/cohere2/modular_cohere2.py": "32823f88142efa0826575ede028363cd721d92953117930b1033665f0a796e39", - "src/transformers/models/cohere2_vision/configuration_cohere2_vision.py": "cfc86d90cc8f9b71e8ddd4872326dd00e442a0ed5356ce21b0faba0590e090fe", - "src/transformers/models/cohere2_vision/modeling_cohere2_vision.py": "1aeeeaaa5b3a1741857a3559df0c919417255f4d22b10d811bb79d44cc22d03c", - "src/transformers/models/cohere2_vision/modular_cohere2_vision.py": "d889425af19565875b2f69af5410e02686e4e9f2545975b21f41776956c5f255", - "src/transformers/models/colmodernvbert/configuration_colmodernvbert.py": "2273072e47fae401bdd8a8e0bc3eced93033761a442cc5831357a519b125b9f7", - "src/transformers/models/colmodernvbert/modeling_colmodernvbert.py": "03ccbe74c30b366bd75728b749ecdd509f3324363a4a8ef605cf91acd40ea869", - "src/transformers/models/colmodernvbert/modular_colmodernvbert.py": "46ee2364642acd94bf2104593fbf9bb05f7a3ad19da173ec7cd2e84321cd9ba5", - "src/transformers/models/colpali/configuration_colpali.py": "769aabe840281c95bb22b9961e94fcbc61449b05149835c2f1b3b8c258ca41ea", - 
"src/transformers/models/colpali/modeling_colpali.py": "4c95c77f45d0fd414560068a4046d8323e2396c4f6b834945035e5677adba8b8", - "src/transformers/models/colpali/modular_colpali.py": "36fce66ab94350016138dbf44c63b248d51786a615548c1c73842a7dde56ec8c", - "src/transformers/models/colqwen2/configuration_colqwen2.py": "3d04aef7a93e9daf9aecba1547b74e9822d73b93b09e4ef7061425c94c6e7ffb", - "src/transformers/models/colqwen2/modeling_colqwen2.py": "b82dfb4975f1699fc2a79737e65435f04fee75a06f34e02b6ef2e68207e8c033", - "src/transformers/models/colqwen2/modular_colqwen2.py": "8ed503c2994674cc48897be4c2721723e7d617e3cd196a917dbd93b1c10f991f", - "src/transformers/models/conditional_detr/configuration_conditional_detr.py": "457c3bcfa4fbc6be338378385016fd756b8a4b0e486e3e7ebf85564b21d21ee3", - "src/transformers/models/conditional_detr/modeling_conditional_detr.py": "7dc0aee54e6404bfcdbae3a892a9b790e1d449d6be04c485b596e0773f54017a", - "src/transformers/models/conditional_detr/modular_conditional_detr.py": "3eb3d1dbf37cdf1758ad78a4b982fd8bc957c4a42176dd2ce37cea9c1f48ae3f", - "src/transformers/models/convbert/configuration_convbert.py": "fef352f0db34a64ffcb041cf21dbbc20e900451356b1fa253719ede911690239", - "src/transformers/models/convbert/modeling_convbert.py": "e0f1d3ae8512bbc6dde76c6bf873020fabed3597bbd95c41efba2b5e63763f8b", - "src/transformers/models/convnext/configuration_convnext.py": "df583a9c6a371c99bc30ad8f85db06cee883583e076531906d7e46624630bec3", - "src/transformers/models/convnext/modeling_convnext.py": "a18de5845a48b4d8b534973bf0401b6478eb16dad588287fa4ea794e0d9f0a19", - "src/transformers/models/convnextv2/configuration_convnextv2.py": "5a90b69b59adf695fc52841fe6a133091e69296113432b5f76728ef4b9cbcdbf", - "src/transformers/models/convnextv2/modeling_convnextv2.py": "ca087067320d3172df9b8f8c51264e8604a5f2057a81c0bc83c0dfca811d00c1", - "src/transformers/models/cpmant/configuration_cpmant.py": "e36d39c95f9a0359b69cbe16b1bc8d4cbf9ef78df7c65f10db52923ea5227140", - "src/transformers/models/cpmant/modeling_cpmant.py": "b67849add8536b34b2ec1fea4d34a0eb7ce3c3ed12c368ebc7eddf83e8a2c151", - "src/transformers/models/csm/configuration_csm.py": "580c7b5e4f04cf18685df855576c732b4a62b0ac5122ef07a98da2c17b0ab573", - "src/transformers/models/csm/modeling_csm.py": "e36148fc0c7785ec8f6e951e441cded0d2cd45611ebcdca5980dee2f3316703e", - "src/transformers/models/csm/modular_csm.py": "caac353646e7891bd3715ec04b9156e967e7dd81a758dadfa919d3a160719224", - "src/transformers/models/ctrl/configuration_ctrl.py": "4f1be8b7f1d941bce79fed85f1c353aa0b4c75966ff9350781bea78298d84c2a", - "src/transformers/models/ctrl/modeling_ctrl.py": "3d2185a273fbb35959bee234c1a2b320b94cb484e11fc35e4346f4b056f2f966", - "src/transformers/models/cvt/configuration_cvt.py": "33ad93c650aa04394ff3a4c4012c3bfd4efc4dd3b9dea64635f04abd966d064c", - "src/transformers/models/cvt/modeling_cvt.py": "7b71afc92f44097c0db3d622ffdfa95c858a13bfe8e0a27f25600e4d61f4a1c2", - "src/transformers/models/cwm/configuration_cwm.py": "b0736e8ee6ac08c8559cd9ba3c4613a9780914d8bfaac1f34c09b7d683e7465a", - "src/transformers/models/cwm/modeling_cwm.py": "b4fe43d614b9c21cda5d047b392d1527ee89d9f1582332df7d452565d72b5a22", - "src/transformers/models/cwm/modular_cwm.py": "eaabf54d8f7f97684a4f4cf1a3ab9ca8bb9f7f098ae2b98cc006f02aaa07328b", - "src/transformers/models/d_fine/configuration_d_fine.py": "57f8dbc1a5b0d8d8bc55b00cfcc729ea2c090489f71d0d7402433c7c7a4ff06d", - "src/transformers/models/d_fine/modeling_d_fine.py": "f4d4149af92a2d992aac90d2e9a9663d724d9df711ea319ac0e45320ce2b6849", - 
"src/transformers/models/d_fine/modular_d_fine.py": "1987db0582e024b9cebc92092a1db6e1c9401a8df5a30b5baed20d59dd8c3bc8", - "src/transformers/models/dab_detr/configuration_dab_detr.py": "044aa0b1f0972e9fda615bcc38a65b7a140b29bac2bfb105abf3494c0065a39f", - "src/transformers/models/dab_detr/modeling_dab_detr.py": "64be492568f9343e130a8e2be45d7f2559cd26d4a033cc904989f0169544b508", - "src/transformers/models/dac/configuration_dac.py": "12e72cf9357385a432466d6d6e91718c645db7e2095007feda71d653862f7768", - "src/transformers/models/dac/modeling_dac.py": "6e2837eb3c4dc681aaf68d7290beb5c6b095eff0cfcabd71e185815178fdaa3c", - "src/transformers/models/data2vec/configuration_data2vec_audio.py": "c79ee10106d926c90bd0d949feba8f10a9f0e8ef3305e17c00123f9ea5b1ec02", - "src/transformers/models/data2vec/configuration_data2vec_text.py": "543a47ac859d73c09194d2c5523da8b4121179af422b0d49d8ce2ae2dd75eac4", - "src/transformers/models/data2vec/configuration_data2vec_vision.py": "53fc204a7441a7d850353d0844b529e890cfd943760998099f738fe4c776ebc4", - "src/transformers/models/data2vec/modeling_data2vec_audio.py": "32326098589c9157b17601dd6e97c7b8318a4521d8fcf9dfbe8261631b5be734", - "src/transformers/models/data2vec/modeling_data2vec_text.py": "e7fe8a8ae658cae940421ce375ded83c2f7332e2187c15d2179e15baaf80106c", - "src/transformers/models/data2vec/modeling_data2vec_vision.py": "1a5fefca4e9018ad8d18142645ba4421d8311b393b2a91c6caf94ee25cd256e9", - "src/transformers/models/data2vec/modular_data2vec_audio.py": "b0a3e99d2981097dbfedd8460a9506d618179475beb0c0c2150d1632dec229fa", - "src/transformers/models/data2vec/modular_data2vec_text.py": "750785d49fda1170e291a5a3c5ad73c8f0211579e081e47c4dd1eafbc946d3d7", - "src/transformers/models/dbrx/configuration_dbrx.py": "beb361d07bf94f9a2d469cb9725b1a6988e6a819770fd7671743e5e84b8049fa", - "src/transformers/models/dbrx/modeling_dbrx.py": "2a2b9a881ac1f1f3ec41af08141bd2bd2b5d0fbf734034c43f35ad88f7b35018", - "src/transformers/models/dbrx/modular_dbrx.py": "e2602ff5b5deea70d6049a93926cd111fb5ffd2a658ece949828dc1c2185b97c", - "src/transformers/models/deberta/configuration_deberta.py": "8913b60a0a2f8b352f89db50123714430a0f5ebf84bedce8261fe8903ebf233d", - "src/transformers/models/deberta/modeling_deberta.py": "538de8ff25de19dade5b45b319eed2217cacd4e384a5a5d169fcdd609adc2419", - "src/transformers/models/deberta_v2/configuration_deberta_v2.py": "dd8976631d74b432d7f8ad042be77c978c45e2ec34a21275c0e49b616b9cbe98", - "src/transformers/models/deberta_v2/modeling_deberta_v2.py": "c12c852843921037e454dc0cdd41733c67eddb8d746e611334deb300f00e1e84", - "src/transformers/models/decision_transformer/configuration_decision_transformer.py": "263b4494b5aacff8b786af5b6eddf4b55561c1219db9302c653947eccbeb79a2", - "src/transformers/models/decision_transformer/modeling_decision_transformer.py": "ed48dce33a36564ed4870d6245d0ad2048845909ca8271ab04d18c1dea478400", - "src/transformers/models/deepseek_v2/configuration_deepseek_v2.py": "b0798b62092d04d0a55d0e0e8c2ea8623d436ed0324da31f7df29b5f9d2d8ffa", - "src/transformers/models/deepseek_v2/modeling_deepseek_v2.py": "215f762dc3c2cf50046d15905a00f2fb3a2c5441da01cccc657ffc3054f40757", - "src/transformers/models/deepseek_v2/modular_deepseek_v2.py": "37df6ba10372a53146a3beda9e4a57f85cad362a311151f78d271660bd583fcd", - "src/transformers/models/deepseek_v3/configuration_deepseek_v3.py": "d5b1a9f1d1eed3215ea4bc85f68a17fdc1cdba508f636c799d45532341698603", - "src/transformers/models/deepseek_v3/modeling_deepseek_v3.py": 
"4d71f61941533b2d9b7c9271195b2a0f50ec422b7ca58230de716b6fb8fa3212", - "src/transformers/models/deepseek_v3/modular_deepseek_v3.py": "66c80357dc4ad27a4fe5f059178ff44a81d0e9e5201e36c783cccede4fcdc7a1", - "src/transformers/models/deepseek_vl/configuration_deepseek_vl.py": "5455bfcffa4ae5bb841e37ddfdfb0965293963354b8cc05b2786ee0ef301ddac", - "src/transformers/models/deepseek_vl/modeling_deepseek_vl.py": "b23d11250b155e480a48fe9417bc4e1be3f8ac4d08cc93535375b5b0dc1d4ff9", - "src/transformers/models/deepseek_vl/modular_deepseek_vl.py": "635e1ee771e5fda88db6759ff81222685523ea0a630fe74532a0f63b318cd5e0", - "src/transformers/models/deepseek_vl_hybrid/configuration_deepseek_vl_hybrid.py": "54c67e75d0923bd643e0dc7afea40880f7a13b7c8f96120169f5e9af44555bca", - "src/transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py": "e72d6f404b762711c9f09f140d2a9620d1ae760a6fbe411979a5402555b63bc7", - "src/transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py": "47949abf6b64854955f64bcd62584b07686b0f30afe549319f91b31d253d8962", - "src/transformers/models/deformable_detr/configuration_deformable_detr.py": "139c7c496b0fae37c5589e0d3a4727614afbfdff762a1046d43355da39f2d17e", - "src/transformers/models/deformable_detr/modeling_deformable_detr.py": "6748c8f13c078922dda79652cc73d89f199d180f4c34308aa933a0c30d79c1a4", - "src/transformers/models/deformable_detr/modular_deformable_detr.py": "c842574153dc2b512244bd42430f5f8692872b5d115302c063988bd3e903e802", - "src/transformers/models/deit/configuration_deit.py": "14b0f4b247d0c88c84d6b2392ba8fe8455100b5b406867c54f0f468362ff9d71", - "src/transformers/models/deit/modeling_deit.py": "4652df841649330c2c647797b8bde19aa1a83074a29eba32c1ae662807460885", - "src/transformers/models/depth_anything/configuration_depth_anything.py": "2984b681dd707e858540686e8899d802b656e108f63b15cb33df88c16a46a1ff", - "src/transformers/models/depth_anything/modeling_depth_anything.py": "9fdf71845fefaa3f03ac3d8eed95131145bf03918f48c4406bc8d8dd3caab4cd", - "src/transformers/models/depth_pro/configuration_depth_pro.py": "1944a304cde013988e18f6caebedef414871bdf988be72ac4094c63f2ea59301", - "src/transformers/models/depth_pro/modeling_depth_pro.py": "3a4fd7a2f9cb24970f248f7e2ccc71a0f002752765514c93cc4e250deb5fd657", - "src/transformers/models/detr/configuration_detr.py": "ff234ca95fd31187b559b12c52fb508d36196667e7f276fd65e7b9963ec0b645", - "src/transformers/models/detr/modeling_detr.py": "a37df6a1ab87689ca259a31b61064d3c59cf5604748b5560fc1d131a505dafc2", - "src/transformers/models/dia/configuration_dia.py": "d080c34197d6b5083a3911f2975b262b86422e0e41c1bdb0a53cabb03d008a88", - "src/transformers/models/dia/modeling_dia.py": "f713fa943dc5f20bfa48db7afcacbcda34c0d92d01c7756c82c1409531148445", - "src/transformers/models/dia/modular_dia.py": "ea6fd0cf5b66505b1c08a1be2548ad86421cd0ea555c6e44082de4d37bb37320", - "src/transformers/models/diffllama/configuration_diffllama.py": "2eace633c0f28f912853ed476d30bd53e2f4a4cf660bf6fc89868d228df15d82", - "src/transformers/models/diffllama/modeling_diffllama.py": "bb977358857f1a378fd66aed97fcfbc18e382ca2a04c4fe523d78e7d792514fb", - "src/transformers/models/diffllama/modular_diffllama.py": "92c69dc0f0a515bddd6f94443ad2b10a82a8225eaeaaf95e8de6080764e2f16b", - "src/transformers/models/dinat/configuration_dinat.py": "601f91c7e337648c205244b9adb7139ccb93254b67259ab0bc211d16bb723150", - "src/transformers/models/dinat/modeling_dinat.py": "1283c47a7b334529bc4bfe23bc464495f0825054890cf95321c9638f6e2f442f", - 
"src/transformers/models/dinov2/configuration_dinov2.py": "7131b13fd760b969cb6b67428afd70821fc1fc1b9aafe5b3dc275fa05bbff420", - "src/transformers/models/dinov2/modeling_dinov2.py": "7fc6541b122d81a8026ab11a0b17a15bc960855c4df60b3d0e6bf311e1a9ef31", - "src/transformers/models/dinov2_with_registers/configuration_dinov2_with_registers.py": "2de0c55dae17da749fe402f29c48e15d3d8acba62cf1ccf2ece307b7353a4a38", - "src/transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py": "5423f67c492393a598be37955b2bee1001a76c067c5c12dc82d526ad044a25bf", - "src/transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py": "ab18b1abaac08759c8e69ae690677bd32574d68e31d01be1d057841fcc718e25", - "src/transformers/models/dinov3_convnext/configuration_dinov3_convnext.py": "1a4e095afabe1bcc6bdffbc741f6b5c9b5fa51aea352fcf6559f18913bea6757", - "src/transformers/models/dinov3_convnext/modeling_dinov3_convnext.py": "bdcad82600f1a10b8a440a9e8978060d2da645f4391a88c01f749322d76146a2", - "src/transformers/models/dinov3_vit/configuration_dinov3_vit.py": "e39f3b6e93bba2aac3acf7b3925c461be8ba1a6fe188e64c3f3a48fe26f65129", - "src/transformers/models/dinov3_vit/modeling_dinov3_vit.py": "5f69e2640c8e62c1407663f3d48e823be098adaac4af0beabcbe75caf578773c", - "src/transformers/models/dinov3_vit/modular_dinov3_vit.py": "7e07ea03c3a2cab3eb1563b6e64dd93a1e4ded7c2589febfca0a605535cab983", - "src/transformers/models/distilbert/configuration_distilbert.py": "3fec20ef91059e5f8da778748b1dd02785ffbfd5535bb7558d0f80649f6087a1", - "src/transformers/models/distilbert/modeling_distilbert.py": "c613d45b2e6641e57b14f7b49d65fd8757fa1b0cef9ef2bf632642e1a67161fa", - "src/transformers/models/doge/configuration_doge.py": "e07aabcf9d9011a000c98dae9cd4e30d3b56935ecb69dc439fa55dbb9e42ddaf", - "src/transformers/models/doge/modeling_doge.py": "412d27f4bc2899df2b0557b8850098cb67a7f1227aba90fd6870e225a0f0f82d", - "src/transformers/models/doge/modular_doge.py": "b7fce8eedadb1dcf2ad1cd9995eed417c3c43a0c3a82033cb4d5730883fd3194", - "src/transformers/models/donut/configuration_donut_swin.py": "dcb89fcea06eeb3cd42ebb4082a1801ab329a788c6506c8769f3f30546002c77", - "src/transformers/models/donut/modeling_donut_swin.py": "6eab5e6afa15cd5a63f6d7b9b2907038071ddfaab35fa53ca636fc46359e1caf", - "src/transformers/models/dots1/configuration_dots1.py": "a8d761c4a70d80581311a7ba7c856ee605329cf460d908661f902bfaec07d0d7", - "src/transformers/models/dots1/modeling_dots1.py": "05d5c96c50096e4681bf30827b0eb656f56c511a883d912166bda19d7ade7903", - "src/transformers/models/dots1/modular_dots1.py": "381cd5bc6bbb4724df36a7c455c75ea0a6ce5149852ba298d47b6546b2c4c1b8", - "src/transformers/models/dpr/configuration_dpr.py": "59d713cbdd6f1be276c8cd515bd03653f9354b6fce9900d85758ab338f74a02d", - "src/transformers/models/dpr/modeling_dpr.py": "b8d5856d4a1d1e7abfd5dc55b4c7a5da5b8f7bd365fe19ce832b944093bb1566", - "src/transformers/models/dpt/configuration_dpt.py": "d90a7848c48c97f93cef4de36640b32ad9c5c9ef2c57809dfacc0cbe4508d5b6", - "src/transformers/models/dpt/modeling_dpt.py": "f9e205ccfe355d6623e3abb6369590ac121f9a41af4bc8abb09320f76a743114", - "src/transformers/models/dpt/modular_dpt.py": "6e7c3269fd38f84c7a03c1a67b3092ce36990eb6744ff559ab450bd08a583e71", - "src/transformers/models/edgetam/configuration_edgetam.py": "b2459a794266c3e394aaca7b5f4c17992f3eacf615db7a8611b096bbed4fceba", - "src/transformers/models/edgetam/modeling_edgetam.py": "5ac7d7798397ecbd8cf337d4e3e3f56e67899dba0f6d5bf546251f0cec990028", - "src/transformers/models/edgetam/modular_edgetam.py": 
"34b875d10c8a310b1bf734d771f9161c756606e4774f4de4826d70ba131cf391", - "src/transformers/models/edgetam_video/configuration_edgetam_video.py": "86f8eed2cf7a2ff4c6aae00529e056408c0f335df2930366ce95f37888c9f50e", - "src/transformers/models/edgetam_video/modeling_edgetam_video.py": "9fd826a6751ffc942f865f2b39dadd2b10a89f6751a6186cd7ab51a828d5e2f1", - "src/transformers/models/edgetam_video/modular_edgetam_video.py": "bf2abc99e6926345aefd7a622ff58d074cbe0250f7603df5ca3ab372357316fa", - "src/transformers/models/efficientloftr/configuration_efficientloftr.py": "268e037030682dfbd74f9394e87869ec75a2c4b86b4d995f3fc1bc894a3a1fbd", - "src/transformers/models/efficientloftr/modeling_efficientloftr.py": "d4f097fa7ffc0199f43c244dbbfdb042202be58ab8615495768350ed7c7cb3dd", - "src/transformers/models/efficientloftr/modular_efficientloftr.py": "10ba603a0542d60b3c12775862cd5bfa25bea57d42dd83d726b18f33f467bc3c", - "src/transformers/models/efficientnet/configuration_efficientnet.py": "3ea809766f7a479717d8d90896a29c1e7468c85e1712a9f3ba2fd84ef0547057", - "src/transformers/models/efficientnet/modeling_efficientnet.py": "31a5e2555e9784a5be395ec10755131699ed236aa37aebe12d9c7c3ca8d1f964", - "src/transformers/models/electra/configuration_electra.py": "53e603b02d3def048b998ea7dff4dc89505aa1d8af82f7832d09d0daad48b438", - "src/transformers/models/electra/modeling_electra.py": "d8bc54ad919a0d3e2cb31d765d8a501811b28b0b4f97502ee3171eb5a6b4ecfd", - "src/transformers/models/emu3/configuration_emu3.py": "edca6bff04b5bc3f4efdf51edee84b673f1ae5b2bd25048af61693e6e3015055", - "src/transformers/models/emu3/modeling_emu3.py": "6d1b9d472d94aaa1ec4bc36a7ad7bad41f6e59a10137612c93f390efb293da1b", - "src/transformers/models/emu3/modular_emu3.py": "4e6dcd4810d409445bebed5531e6812cb484619427f046a4d324e14d3140e20f", - "src/transformers/models/encodec/configuration_encodec.py": "cff01a374c1b9f206d3317014574d5715b211c77914a75cf2ca9c641d33b5f40", - "src/transformers/models/encodec/modeling_encodec.py": "39ccce4266ceecf45ff2aaba2743e1b9aa5937914214ef6bea18ff136d30973c", - "src/transformers/models/encoder_decoder/configuration_encoder_decoder.py": "585faa9249a5f69b2a512e8ebe042115d20c9aaf8019653ae9cfa1adbd0a8dc0", - "src/transformers/models/encoder_decoder/modeling_encoder_decoder.py": "fbf66b3d8909ffb835022cc7ad6f53bb6d2476a0c71291402dd5ec393e01ba96", - "src/transformers/models/eomt/configuration_eomt.py": "6608061c3f3d8c7f5bb866bba9b77c31c123c18be7ee67ab7dc46bb44666ce7f", - "src/transformers/models/eomt/modeling_eomt.py": "2e931b265494c224a0898f5f9e10549996b0bce879052f09038a1146ee7ca2cb", - "src/transformers/models/eomt/modular_eomt.py": "aaab107c5f91c5b67a35c7f885edfb84feeebf9c97b65aa5d1161df29f405451", - "src/transformers/models/eomt_dinov3/configuration_eomt_dinov3.py": "5ad8a03cb1d95bd9927faa261084ba36d5663edd37216537f5f5bb1e196d5349", - "src/transformers/models/eomt_dinov3/modeling_eomt_dinov3.py": "09061ee86dfe85ed20d1105565215a7dabe4cd50fedd3ae8ec57c9075b9a8684", - "src/transformers/models/eomt_dinov3/modular_eomt_dinov3.py": "abe6347307a7896f7edbc918ba4102dba19adc404056d6b0e3abd1272e5ccc5d", - "src/transformers/models/ernie/configuration_ernie.py": "65957de9d59960da78f00c5e1b2d436f574e4034099b93292567b656e71d2e8a", - "src/transformers/models/ernie/modeling_ernie.py": "b3d71e2b93ddbfdc38272f600cf1b57b7bea8f02db7751b13c1aa0ce7fb59d7b", - "src/transformers/models/ernie/modular_ernie.py": "5807a39769ea643f3e2f92d6fff5cf68888cf0d0387aa6002f6d7fdc2d2a2366", - "src/transformers/models/ernie4_5/configuration_ernie4_5.py": 
"c91ab7b3c45023b7ad24f2334be76d39bfd0df25f776366a15cc951513206587", - "src/transformers/models/ernie4_5/modeling_ernie4_5.py": "c50eea1ca81685f6933be53c76aa2e4197d4d7c9ca1f527a0480dc22b01f5eab", - "src/transformers/models/ernie4_5/modular_ernie4_5.py": "974be5f0a88e34664fc04e09a8131b28cf8fa31bfbef2ac5bf3e88ef27c7eb94", - "src/transformers/models/ernie4_5_moe/configuration_ernie4_5_moe.py": "d68328c29f38f1dc24caa1e4bc68a1b09d29797c54cecd74c4f48de365b52684", - "src/transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py": "394cd12396d84aa2bc4c1b48dee65d19ee445cd1bc9acb3aed2bfca2c623ee3f", - "src/transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py": "d581c3267f586a3ac16e55000dfba777b823f5f97742a403c32262e3501d102c", - "src/transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py": "1ec22399dd386689e7bd4d027c044e6da74bd21e33f2c7abbbceebd5526ab924", - "src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py": "27fbdd60aa982ef09ff8990a7342033b9be97b77d9027a48332ef7134144d79d", - "src/transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py": "4a8dc470910160bb2363915f49da199ee4026e40936a342f7731ee36803139c6", - "src/transformers/models/esm/configuration_esm.py": "581936d5f9791a3024d9c7776ecdf6aaaefa6603abc51dfeea9c01c9c20cc4b9", - "src/transformers/models/esm/modeling_esm.py": "9c83f87f554820e8bb5a2b7b73a455dca54ac48b21e7f3a5c34b7e042cc332aa", - "src/transformers/models/esm/modeling_esmfold.py": "6bced2ad7502f73bdb8026d817cb0bcc7e7d2c787c4bd6d25055bb0f8afe25c9", - "src/transformers/models/eurobert/configuration_eurobert.py": "3686f8d2537d3c0e7a4357f43ab56efcacdd302e0f1162d481b39412617e1423", - "src/transformers/models/eurobert/modeling_eurobert.py": "aab3bf203d1d5ef3e89299a8b62553ba0aec80e994a7c875cdc952df4c325b14", - "src/transformers/models/eurobert/modular_eurobert.py": "30f8134a3974a4cbba52e9c13dad659fb80986ee9ce4768e23b7eb266aca5ed1", - "src/transformers/models/evolla/configuration_evolla.py": "1480f8d162dd26f21a56fc83b0e24093d5bb41e88747dc5d86cd13341c1a3c66", - "src/transformers/models/evolla/modeling_evolla.py": "b75f66f3e7cf6ce873f8019862971a671b5f4e17b0b7c88d30685e59b3db4c9c", - "src/transformers/models/evolla/modular_evolla.py": "cba489702e1543131603015b805a84982e94192041acb80f7f1ab1a6582deea6", - "src/transformers/models/exaone4/configuration_exaone4.py": "6e5c461f84881ec9835633d506815ef2926e20581545367c1d3af5f3c8964cc5", - "src/transformers/models/exaone4/modeling_exaone4.py": "c9992e4049d7bf4022392e0c6c9391d6fcfab3cb486b7a89455324ba9f5f9562", - "src/transformers/models/exaone4/modular_exaone4.py": "60ae4bd5c733f8769d8cf466a7af1efb265e89570ef25ab46162b16b3ecfb8f0", - "src/transformers/models/exaone_moe/configuration_exaone_moe.py": "9dd01d8361be9d8c71a85c8746fecf6664ca1e703c2ba61c657439aafce576cf", - "src/transformers/models/exaone_moe/modeling_exaone_moe.py": "a56541292aa6e82de28b405bdc62272be44ef73c83ce56a17cd06fb43e597bd8", - "src/transformers/models/exaone_moe/modular_exaone_moe.py": "1749ec26996b110a47d61fcbdaa4e738e094c48eaf1d7d5de02974324c3044cd", - "src/transformers/models/falcon/configuration_falcon.py": "56c34229dfdf286a23f571db23114247e1cbfd2761234b46e4496b15667bf3df", - "src/transformers/models/falcon/modeling_falcon.py": "859b28338f403d6b8ff70c713006712fd1d4b38c7548109d459c253b4d557a5e", - "src/transformers/models/falcon_h1/configuration_falcon_h1.py": "009a43ec8826b15120d456f01ac9f03f31f8e419f3fc2ce415afdbf3b0872c11", - "src/transformers/models/falcon_h1/modeling_falcon_h1.py": 
"e316f6397e9048fd91650aeb9c0365be1f274af9e4a2a939c92296e7b19f3962", - "src/transformers/models/falcon_h1/modular_falcon_h1.py": "9417d5f47849a74bc76a1bfc4143488e1edfd39a7839ed1a2c37a521af6ea682", - "src/transformers/models/falcon_mamba/configuration_falcon_mamba.py": "0188f365f950e47ea2740c4c5c7cc88c57a1d5f9ad593624b891f81bdae1bf34", - "src/transformers/models/falcon_mamba/modeling_falcon_mamba.py": "794dc45783ebd5b09e058eeccff800ab551df1a6da29a1b4cae45d0ded8229ad", - "src/transformers/models/falcon_mamba/modular_falcon_mamba.py": "3c724cc5ed6eb8111a67c72970f401f4b976a77a0014956759da4d7ae2995655", - "src/transformers/models/fast_vlm/configuration_fast_vlm.py": "54da3584a0e82d9d7273832f31eaaeec87ab945e32b630059573d44354610523", - "src/transformers/models/fast_vlm/modeling_fast_vlm.py": "3cf26be59ca4d52e3775dd7bd11c397d898fec1a89642fc5a05fae4e9e1bbc3d", - "src/transformers/models/fast_vlm/modular_fast_vlm.py": "45098413ce521bd5879c9069a153eeaabe1fe2921257cc6c2c28df388771bf50", - "src/transformers/models/fastspeech2_conformer/configuration_fastspeech2_conformer.py": "ee554313d2db445a950229a8b40da219d89bd7c554b4dc899f4b04c7ae94de6d", - "src/transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py": "ffa5010ca7b592883d770e3cf74fbc1b047ed35c2640653bc39fbb999ed75a8e", - "src/transformers/models/flaubert/configuration_flaubert.py": "9633a95754e1b5d94e8ef39247df3e5c66f9685172eeb74471dac554c9229109", - "src/transformers/models/flaubert/modeling_flaubert.py": "6e41683ead49a055b55c4f0b5da8c033003c7361bcddda9944d0c4c606b65348", - "src/transformers/models/flava/configuration_flava.py": "5cc1f71e11d317a702f1f780705b972730fca26f9fba4be31a245413765f5a58", - "src/transformers/models/flava/modeling_flava.py": "9aceeab0106a822a504116997c1f651657f049e97a15716016643751088746d8", - "src/transformers/models/flex_olmo/configuration_flex_olmo.py": "8f14425e503c491709d19e28119d298ccb2f3370cfe8f1fc416d309f9582b01e", - "src/transformers/models/flex_olmo/modeling_flex_olmo.py": "c5a17d5d9b396d4f478a60668dca5dc3821b59716912962fc593e29a273c722e", - "src/transformers/models/flex_olmo/modular_flex_olmo.py": "45f30b06117144c38845c3cb1bc48c6c72f3bb3765da2a815fe8e9cc205d275d", - "src/transformers/models/florence2/configuration_florence2.py": "f36a93ffc5098b805c4c84aa0ceda1b74bc89eade5f52d68937b641a7b24d3f9", - "src/transformers/models/florence2/modeling_florence2.py": "26cd42ece585ab1e3a0a897b39af68c3a1c1939844121d432595c7b7b106f2fa", - "src/transformers/models/florence2/modular_florence2.py": "773d24b542cb6127230dabdb687c2691b9baef8c4db2d87e9ca7fd626a3780a2", - "src/transformers/models/fnet/configuration_fnet.py": "ab38cf5111ead363a8d78cb0569e9e697dee4e59eb58104664fc0075e04b7b20", - "src/transformers/models/fnet/modeling_fnet.py": "83ce4b3f477d3d86d17b2db152b7ae4b7de48d5b55143c025045a3ad0e70f4eb", - "src/transformers/models/focalnet/configuration_focalnet.py": "fa915e41fd54d84fc00a7b9a57d78e280ce6cbe26e11a4548d7ff5084239b457", - "src/transformers/models/focalnet/modeling_focalnet.py": "d44f280d6e8ad9073250283f115d31e5e6a88b1135fda8c8943a7576edb129ee", - "src/transformers/models/fsmt/configuration_fsmt.py": "1df5942db321624116b9701d6b88b1b54b3468adcd56b2314bccbb945f5d2f11", - "src/transformers/models/fsmt/modeling_fsmt.py": "cebba339f509a01042df22043223fa79f31f915507a9dc394b2386a26ed7eb18", - "src/transformers/models/funnel/configuration_funnel.py": "2452d83f9e33aa1b985b1de72023900fa26d3acbe17c4339aa6c7785888e30af", - "src/transformers/models/funnel/modeling_funnel.py": 
"855d3a21b0ecc1d9bb2be7391cb052813231ac83e6343fad7575ca51062f5f28", - "src/transformers/models/fuyu/configuration_fuyu.py": "9eedd6492e19520972b70a720da122f2d7f249a2330bb0e2bd5d1f5845236f4e", - "src/transformers/models/fuyu/modeling_fuyu.py": "6e31a1c3cccb22b7f76ec83a5c784b547bd46277dde5949fe55dc71ff8c89fa7", - "src/transformers/models/gemma/configuration_gemma.py": "9117abb75749da1246cf7faa985c46411b2a91a1e93a1b0e4dc977d1654866f2", - "src/transformers/models/gemma/modeling_gemma.py": "f6f0cd5a7b7a1b895476abf3e5c2c88246de2f6e14d6a38863333452a25fab83", - "src/transformers/models/gemma/modular_gemma.py": "4b7d117bb6e128c2588bf7525e19c5df437a93f2265a1271d3dbbee683420d46", - "src/transformers/models/gemma2/configuration_gemma2.py": "8c643de0c94443194f8de3eb6f37ab6156ecc5889da7c7ab20241d5b78036e2e", - "src/transformers/models/gemma2/modeling_gemma2.py": "153226bb60dd43429842901b1b6ce079994e4662786afcafe02a02f5cab79c67", - "src/transformers/models/gemma2/modular_gemma2.py": "6fea0d4fcbbaa6d1a42e0d846beb907ba957f9a5b7d82abdf578db4775d7aafa", - "src/transformers/models/gemma3/configuration_gemma3.py": "069aeec9a2bd0501628122f515eea37b20873c13d031ff4374004c99cd939406", - "src/transformers/models/gemma3/modeling_gemma3.py": "9ae9fa36006078aa84b03832812ce92ab8628d9683f33db59a8c55a8b0277334", - "src/transformers/models/gemma3/modular_gemma3.py": "d0acd36c524bcc2ac352f5af8aaf0ab29446755799e33e7ba9e5d08d757f8e9b", - "src/transformers/models/gemma3n/configuration_gemma3n.py": "6d4a70a2fbb10aa6a822120e978b70708c38b6c699c48110dfecbb5f723e80ee", - "src/transformers/models/gemma3n/modeling_gemma3n.py": "1421dc604bb27f4e133167f547bc3e22fdbb18ff2a08a3f382bcb97ecfa46585", - "src/transformers/models/gemma3n/modular_gemma3n.py": "56543f8ec5ec59b503ebd898f4c87d2ae908d808fd2c443fe1865800d8d92054", - "src/transformers/models/git/configuration_git.py": "39456fecfd996b4fee61ab70b251746127c39f2b7aac152a0153289dd899427d", - "src/transformers/models/git/modeling_git.py": "2cdaad67d7f3807b35d84fa126e071e5bd460c55d6c3a6c5ccbddc3972273155", - "src/transformers/models/glm/configuration_glm.py": "43fa4d49845faffa5175d7450a3a50f4308d35aff20391f4a2119b5dd881496a", - "src/transformers/models/glm/modeling_glm.py": "c0af56314d99fc4423a5b10b0cf8251eca4582f2395748fdf229c800d0b7066b", - "src/transformers/models/glm/modular_glm.py": "cb95d3613d792ec0d2ef30a4959912bd63336b00810df4edf36da203209c8dfc", - "src/transformers/models/glm4/configuration_glm4.py": "599250f9d3f895eaeff4c0b70426e14bacbf57f254198a9792cf9b1e16aa1966", - "src/transformers/models/glm4/modeling_glm4.py": "a8f1ff1fbdfee288a6748d837b914f881500e79ee3f0609f992a697f51328132", - "src/transformers/models/glm4/modular_glm4.py": "6bb30f16155e5fec75a99be4d381b39219de5268176978cacac6b1e1feb6d1f0", - "src/transformers/models/glm46v/configuration_glm46v.py": "75c8efede9ebdacb35b26c2824dd19e9768c50a56570afb2f0f1d5afd72c900f", - "src/transformers/models/glm46v/modeling_glm46v.py": "1f85962df2415a7ff7e070bd4796b8de5c378bc15a6d3b0fb131cb4b81133a4f", - "src/transformers/models/glm46v/modular_glm46v.py": "d2020a8e457d63fcd87efc555464c3aa5b6352793aaec5b31b768de2b552a091", - "src/transformers/models/glm4_moe/configuration_glm4_moe.py": "b32ed9b4f63980c565b92874337e00635038bfb33dc0d455d8005aa1e48c178b", - "src/transformers/models/glm4_moe/modeling_glm4_moe.py": "5d5cdc7788034df7d8926696ac46b0647cde64dc258f0a684a26da7ebc2952e6", - "src/transformers/models/glm4_moe/modular_glm4_moe.py": "a993ce9117f32e7130f99df69950972fd21bc75a9bfdef61a15a19252290957a", - 
"src/transformers/models/glm4_moe_lite/configuration_glm4_moe_lite.py": "6d1dc311c950f2a776f28a0c3f02bf17401b03a51fb8bc3bfe65fce1b3edbf73", - "src/transformers/models/glm4_moe_lite/modeling_glm4_moe_lite.py": "ed6194c48cb2b028289df3832638291bdc17a7aeab85b0150c3881b9390629b9", - "src/transformers/models/glm4_moe_lite/modular_glm4_moe_lite.py": "5c3d92595f33af03664adf236724ef644378db6d0f72693d0acc86507c9ae4be", - "src/transformers/models/glm4v/configuration_glm4v.py": "e049bf3a2b10dd119d992a0a609c143ed4e5693b44b01447dda731c9fd7c0844", - "src/transformers/models/glm4v/modeling_glm4v.py": "5d054840dd92fade410bd45326b096ae93e8813cd7a24ce346dce885b09578e6", - "src/transformers/models/glm4v/modular_glm4v.py": "3b811dedcc2d858c68656606667a3b14a34d2013aba8d11f5f8ee179caa35568", - "src/transformers/models/glm4v_moe/configuration_glm4v_moe.py": "61f35c68f19a58c0766fd53a0b5d001bcde798fd72c29d3169fe0f6212272c23", - "src/transformers/models/glm4v_moe/modeling_glm4v_moe.py": "2911f91ba04e0d2c442a490d812660bfa1f4009b1ab2af82b388ab1efe705a9b", - "src/transformers/models/glm4v_moe/modular_glm4v_moe.py": "54fbc2552829948bcd8c5efe33f068c11a998b1a6cd3fdf4b9c9a5b7535d654a", - "src/transformers/models/glm_image/configuration_glm_image.py": "64177a6c429bd6ea6014b3840528c565b2dcdb7572f1e617803940958b90c1c1", - "src/transformers/models/glm_image/modeling_glm_image.py": "09be965e23768671416f9f977412f7585e929de91e2079e817b2bb356cd155b7", - "src/transformers/models/glm_image/modular_glm_image.py": "e0c45ff71eb3ed9313a702c485967a1891f6db5c67c7b7cb75808803ff931431", - "src/transformers/models/glm_moe_dsa/configuration_glm_moe_dsa.py": "4040ac9d39fc01e13d3007790a38ee2fdc0ced1b984b47554121ee1247044cb9", - "src/transformers/models/glm_moe_dsa/modeling_glm_moe_dsa.py": "5767b3af44aa97de4ca3fb3a63ca324e95dbf959c96f66c138e2d1391367089f", - "src/transformers/models/glm_moe_dsa/modular_glm_moe_dsa.py": "5c7142fabfe08a6c30373b8cda0064f3ed892cadce2b8c012802db52cf2c58d2", - "src/transformers/models/glm_ocr/configuration_glm_ocr.py": "b66cb6953cb6cac27eb554213498b8067c01eb37b8164947eeb472b79cab937e", - "src/transformers/models/glm_ocr/modeling_glm_ocr.py": "4b9c8adb4237c7869e14e1d6205cbf8ee38ee83b2d2d4c9c893e2af85f60676f", - "src/transformers/models/glm_ocr/modular_glm_ocr.py": "e1c85fc6db323fa4e1817c1f5f3c87806f57a11671cc86df5afd3bcb28725336", - "src/transformers/models/glmasr/configuration_glmasr.py": "ede1c1daf302302396b7f15bc146057dfbada127e61badc63991aa94395d9c3c", - "src/transformers/models/glmasr/modeling_glmasr.py": "9cb6b61f223e43c83df9e0c7db1bc6069746d258d50a90ce30a0cf7f60a9e33e", - "src/transformers/models/glmasr/modular_glmasr.py": "e7ee360bc3da3d3076507dc49a8ec2d973f975b08a1f45521436656a7f04ef74", - "src/transformers/models/glpn/configuration_glpn.py": "da8ce2caa9107313330a12a5c27db142a5a231524e31d7ac0c48b7c5e7ab47a6", - "src/transformers/models/glpn/modeling_glpn.py": "031f7211ef5cc993828f15c2bef5645f02d056d12855ab8a05e5aa2984649bf4", - "src/transformers/models/got_ocr2/configuration_got_ocr2.py": "9eadbf6daf60d57c549ed64a42232a8fbf7396fea8894337ba85187974fb8cc3", - "src/transformers/models/got_ocr2/modeling_got_ocr2.py": "91402025896744aee614d90a0dd2f1e6bdd201286e9a87a1ef2a12ae73b0217d", - "src/transformers/models/got_ocr2/modular_got_ocr2.py": "ba936274b071e7a2575863a22cadb6a18c3d39a19b6b0ea96801b8986dd4e130", - "src/transformers/models/gpt2/configuration_gpt2.py": "6877fb4366d14077721637986c3bdff4f9b4e32d02330a8a892f9843b8573ff8", - "src/transformers/models/gpt2/modeling_gpt2.py": 
"7bd1b14bda833ae55d6ed7767f323d2ee5888c79651727beaa76a58f1537b91a", - "src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py": "bf6cbf8fb239384a9d84a33c8a6361bb0dd74b3568f0d1f8817632134f6e34ce", - "src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py": "c0f3214356a6b0b9277fbb0b0ad95295a62dd9570af424df962962dfe69c5374", - "src/transformers/models/gpt_neo/configuration_gpt_neo.py": "0a93099ec341ce55c7d071e8d55d52ac586d3f3166056de1cf627e2e63721b36", - "src/transformers/models/gpt_neo/modeling_gpt_neo.py": "8602fd2a7d120eb90126a62711f52588c24a04896b4824c40499ef6f167c4acb", - "src/transformers/models/gpt_neox/configuration_gpt_neox.py": "178ba1009107a8f7f49c7bbdab403dcbb6e960328d937b2318f9edeb75938333", - "src/transformers/models/gpt_neox/modeling_gpt_neox.py": "a1784e2ee3ff4081e2b74ad1386d479c319565f9cf3e3d3e2863d49c021a28d8", - "src/transformers/models/gpt_neox/modular_gpt_neox.py": "827b035d2a824dbe6a112026b672341e555284b79a33fe44fe72015f0e36d5b8", - "src/transformers/models/gpt_neox_japanese/configuration_gpt_neox_japanese.py": "7544c340ea49ec3b6c1dddf0aafff3583cc120bd644b789e4f551a0a495598ef", - "src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py": "248f9f143d93e0e0177b2a42d118b56e838bcd257115098e4c18d664a4791d99", - "src/transformers/models/gpt_oss/configuration_gpt_oss.py": "c2684f4858e2f52040c644be163c4372660a963064cc92c3cc8f32d3bb3ed894", - "src/transformers/models/gpt_oss/modeling_gpt_oss.py": "a394c9335c13abdf370acbd93b97f1463093c731b0d77727ad7249d26ace67a5", - "src/transformers/models/gpt_oss/modular_gpt_oss.py": "0f81b6ea5c48d5105afb942db63d2425b93ed364d4f0db31746a07e471e88a2f", - "src/transformers/models/gptj/configuration_gptj.py": "45715867e35816d1d9cdfd36e3633cf50ad789abb7c130ea3a012849ec0fceaa", - "src/transformers/models/gptj/modeling_gptj.py": "c406949259ca0a566ac50d9ded6ad37124736fb5815352b842b2ee8354599a30", - "src/transformers/models/granite/configuration_granite.py": "ee91ee043646a0463e5726aa45380f81872773303b04b62112423dd15b003daa", - "src/transformers/models/granite/modeling_granite.py": "75dc96426b1fbb8409aad24d03705e6182c4a03a07b6b1f59684931bfb1b278c", - "src/transformers/models/granite/modular_granite.py": "0fcf9f2dc71096dffff2dd601c8675c04759a115e0720899618e492e17d7d880", - "src/transformers/models/granite_speech/configuration_granite_speech.py": "2f11afb83755fcf2c29e252b455d3dc1e67dc63d930cebc0c6f13f717ca8a987", - "src/transformers/models/granite_speech/modeling_granite_speech.py": "71f83aa1be0f2180097db8ae51aa5e70148b56b71cf70cf00ca6caad32b27ad5", - "src/transformers/models/granitemoe/configuration_granitemoe.py": "59e4336121378c1f3348ed7c6f625581eddc65263dc80b9abef39ec5ac863e48", - "src/transformers/models/granitemoe/modeling_granitemoe.py": "167e417c8452915b3b35479e305be9112b3a0d214f7f26612c7d518ef3b158cc", - "src/transformers/models/granitemoe/modular_granitemoe.py": "4efac9f589ad00225d6fc7aab43458a3641f4d885311e9f9171526aa9c765846", - "src/transformers/models/granitemoehybrid/configuration_granitemoehybrid.py": "03b66a5fdea23501a672728210ab208664b2106aedec178e325fac19859b3b28", - "src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py": "c9de882cd503aec8b4f263af764804c5186a5a7ef710f2b8f5e8659486ce7556", - "src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py": "179f76acd59c28c2a8cdd46168271aab6f680f3c036a19a987d73e405cc2b60e", - "src/transformers/models/granitemoeshared/configuration_granitemoeshared.py": "a5a1a9bfb0397a23b38fa3e0a820a7f081bb6d6eda06092fc8e0009ad35d84ab", - 
"src/transformers/models/granitemoeshared/modeling_granitemoeshared.py": "e6f38dca72be7bd1fa03dbfa6424f0ef3ba3ad55d67f139983751c51cbce90c6", - "src/transformers/models/granitemoeshared/modular_granitemoeshared.py": "2cd68110a89053c2238d528c1c55f93b236fd46c5f62b63594eaa4ee14ef9f8c", - "src/transformers/models/grounding_dino/configuration_grounding_dino.py": "ed1312506a65973154683a634e0b9cd6ab29c7719ddf5a9de066fa48a29c365b", - "src/transformers/models/grounding_dino/modeling_grounding_dino.py": "e4ccc8846842d1ffef1610f92480223e95e37ff962123ba1fff253f4743a7cf5", - "src/transformers/models/grounding_dino/modular_grounding_dino.py": "d15ce3d472bbe9381403eb608dd310ae70fba2e7514024d90f25d5d98234e332", - "src/transformers/models/groupvit/configuration_groupvit.py": "b2de28a065806695ec39880677be9e10ebcdb42e949783b4a6ace06b9a1cac50", - "src/transformers/models/groupvit/modeling_groupvit.py": "c840cf7631f2a913011ffe059b8ebd3d42813f7900f2c02506016ae53e7319ae", - "src/transformers/models/helium/configuration_helium.py": "39dcae798b9ef0bd667f371d24de04d82e3079b74073993c01751a7d4228e236", - "src/transformers/models/helium/modeling_helium.py": "80375332e3e29f8c795f300cffc8b1be7082cab18b005e2db7e247bae69937ec", - "src/transformers/models/helium/modular_helium.py": "2b169490ab056767e54c5f92a7d0a5a326a82151d4a6964dc45160fd33e07c8f", - "src/transformers/models/hgnet_v2/configuration_hgnet_v2.py": "d4dc9093f63c1cf87ac7e38e5ca453bc2ac76c9c06642bbf535e4ee05f1450f8", - "src/transformers/models/hgnet_v2/modeling_hgnet_v2.py": "33a29deb272c1527d72040246ef517e5ca8e5e87a9df41cd209fb05c768c2887", - "src/transformers/models/hgnet_v2/modular_hgnet_v2.py": "f07a228e327f98e19d5d24ae1f294741c495d3bac9df6a0a755b5a626ded5bd6", - "src/transformers/models/hiera/configuration_hiera.py": "4da484da3929b8e91d64cec6f31fe041b7f80ecc624d86d31c66203afe1349f6", - "src/transformers/models/hiera/modeling_hiera.py": "5d2c1965290d053b61d5686a3e486d26a3e66e6d56c60debb25cf61795940271", - "src/transformers/models/higgs_audio_v2/configuration_higgs_audio_v2.py": "3e4cb385300016380d1a38ac19994306e8349475881ef13e2499eb95a14b6937", - "src/transformers/models/higgs_audio_v2/modeling_higgs_audio_v2.py": "db470770386fec72e71ea87736b7624af445947b7ea565b9cc15876332e809f3", - "src/transformers/models/higgs_audio_v2/modular_higgs_audio_v2.py": "9041aa6c57b719f7bd5b3ee2e7fd0ad485abd554eb7945382bf823124059ee4b", - "src/transformers/models/higgs_audio_v2_tokenizer/configuration_higgs_audio_v2_tokenizer.py": "56ef66a647e8f1ea45d16eec435f941f55470fc5ebd0ec4923e63310456ed2ff", - "src/transformers/models/higgs_audio_v2_tokenizer/modeling_higgs_audio_v2_tokenizer.py": "ffabafe79e99b48168bd04e8932ee5ab6aeba07f59f39c2bde8c1c1ed25ed2dd", - "src/transformers/models/higgs_audio_v2_tokenizer/modular_higgs_audio_v2_tokenizer.py": "eb05be1a4399904bd563f928fd33bc8e373229fe585b80877bab131aa0508158", - "src/transformers/models/hubert/configuration_hubert.py": "3be790fcaa707e9aa6d3e4cd7a45ee1c3c1965852309010788b5c867cb147bb7", - "src/transformers/models/hubert/modeling_hubert.py": "6aa0010232d6151de7f964365c318a65f74ecdb08628dfd76e0aeadc6e1bc509", - "src/transformers/models/hubert/modular_hubert.py": "3ed97905e0cdc784ca21859426b6a63e5a05362971fbdb0e582d5888fe25e49d", - "src/transformers/models/hunyuan_v1_dense/configuration_hunyuan_v1_dense.py": "3956bfaa567b8abb38062046c7e8ff7ca2161c1cf41218e4886ca7b5cc318c41", - "src/transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py": "95935d73eed9c81708d31ba507944bda5d559f6ce2fb8aab6a6d878b00a1c30e", - 
"src/transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py": "a0816aeb1cd148fd41d55a65aa46c701530af5574899904ac0017b7165bbcdee", - "src/transformers/models/hunyuan_v1_moe/configuration_hunyuan_v1_moe.py": "eb7d3240fa30d375894e39b2d18dda07d793ff4b5f576497f573c55b452ea9cf", - "src/transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py": "fe06f8b9a8babdecabfc25f5eba1b78060abedbd8faef303b4d10867537f88b4", - "src/transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py": "425138d046338473b9a3f8d145b3b44dd229e6f753b22e5cb773576a68bb92c9", - "src/transformers/models/ibert/configuration_ibert.py": "d26257a9642b11986ee39406468f746c88275909be34fa083fb18ef5da620187", - "src/transformers/models/ibert/modeling_ibert.py": "36ae71d50cc253a7773a4e22c37a266acb7f14ffad2b2b65b7ffb6715d599ec1", - "src/transformers/models/idefics/configuration_idefics.py": "8466386bfbec3793430335d19d2d788c25601d7018282e4bf4bd8964bb5a9b30", - "src/transformers/models/idefics/modeling_idefics.py": "dd4cb635a970451ea62c4c69a45c09bf9040f8996fa541772f32f16d04e73b1d", - "src/transformers/models/idefics2/configuration_idefics2.py": "a2cfb4c58668cb987fe8cb4407a4c7b2a5fe59faff89bfa599d889416efc3208", - "src/transformers/models/idefics2/modeling_idefics2.py": "3f859a770d44b5e0ec42b2efc635ca5a5dde35ceb4965480d2de008fcba6b4fc", - "src/transformers/models/idefics3/configuration_idefics3.py": "ac888b8e831910bda16a217a6620250d7d7df5e2bbfc1146dbaf0c86ff35592d", - "src/transformers/models/idefics3/modeling_idefics3.py": "d38fa8c7023a990ca1434209be664055bbc611e8b05c8ede9a655045be5d8186", - "src/transformers/models/ijepa/configuration_ijepa.py": "50eeee3e85784b43daed89d29985934b2681552fd9ef728ff447098deeb955ff", - "src/transformers/models/ijepa/modeling_ijepa.py": "f2045a773328e1d5112d24e1534d29732cb8fb35d783bca415dd14ed51de5c54", - "src/transformers/models/ijepa/modular_ijepa.py": "0d794538acdf5a079c34de2db17f2bbd01bedae6baa3ad0d4369cb528b720a6c", - "src/transformers/models/imagegpt/configuration_imagegpt.py": "69a64e7bda20c710ad3353887cc4069a1beb6b0566c389a24baf9fb7b19df537", - "src/transformers/models/imagegpt/modeling_imagegpt.py": "ae5b5e78c838af514ca7361c5dfc555447958a7a2ac21df8728431a32499130c", - "src/transformers/models/informer/configuration_informer.py": "6b58f8b04970545e81b461c37c983092d29c81088eba908b2f9a39841916a173", - "src/transformers/models/informer/modeling_informer.py": "fa8e18099395fc6f9589539b9a364db86931d0e10c0664ccc4975ba374c243a9", - "src/transformers/models/informer/modular_informer.py": "c7fdddb9fc0458d326133730f3abd6c2828ce3ccb7d8a536ff27dad401a703fd", - "src/transformers/models/instructblip/configuration_instructblip.py": "5e8b06439714e21b7b53d417bd9394c1033a4d625084f7efa608ececfcc707ce", - "src/transformers/models/instructblip/modeling_instructblip.py": "e90d855bbe23d0398d21e421c46d8f65805f9ef2ffa3f37b8236eac9fafd482a", - "src/transformers/models/instructblipvideo/configuration_instructblipvideo.py": "af943c65075c0583b5ef4bc10585c6dc77fe0da286c238af38a9e1181885680f", - "src/transformers/models/instructblipvideo/modeling_instructblipvideo.py": "dbd483108ec0fdf326b0ac22982cf698b30aeea0e793fd3645d78c794f4fd05d", - "src/transformers/models/instructblipvideo/modular_instructblipvideo.py": "200544979f197cfb23af923da3ba15ce73e92488cc098c0b4dd1821cebdf231a", - "src/transformers/models/internvl/configuration_internvl.py": "fe8195d8de48dca267db1babbbf27b68e7ccf0e12619ac6cc94d381c0e1d7fc6", - "src/transformers/models/internvl/modeling_internvl.py": 
"8e88aa0b4f106d8ae1db4295e9dd9852579f6961388b2e68a0a27b3bb42e35aa", - "src/transformers/models/internvl/modular_internvl.py": "9a6300baa9b84e5d9c006688f3a1d539d306d71f75c762fd7c060d8c30ef7a42", - "src/transformers/models/jais2/configuration_jais2.py": "c8fe3e6ed41c7b2789cfa40e98708c8cf6843072bbaf90e78ed385784dbf3d24", - "src/transformers/models/jais2/modeling_jais2.py": "70332ec1796d7b35605e316dc623bb3b575a9f62df06c845df59c7e797af3236", - "src/transformers/models/jais2/modular_jais2.py": "eda1a533e283498a7396c5cbda569b05a2da0874b4551620fa429312c2aec94e", - "src/transformers/models/jamba/configuration_jamba.py": "8af6cc52904b6d29d57127adbac7a9565d18c9568790cf7fe1c285dfb6b254db", - "src/transformers/models/jamba/modeling_jamba.py": "88574d672b75849e727c11805efb7db299088ae7b2ea47420e6b4f215407cbc0", - "src/transformers/models/jamba/modular_jamba.py": "91b43b0c16d3550eb9f03bf611e6d1185055fb485e79085c7b500067351b7ada", - "src/transformers/models/janus/configuration_janus.py": "27f32d804c63f98ab8fe13784d0e4deaf93d915e874026b6fda927bc66246d9c", - "src/transformers/models/janus/modeling_janus.py": "d0c0405fffc61583dfc5c30bdbca09ed5237a9f7648fc3082fcc4f7a4893fe5f", - "src/transformers/models/janus/modular_janus.py": "124315d13a7335ed47132e0ee9d00a1681d33e3c6fcb192a51490b617cdc1d5b", - "src/transformers/models/jetmoe/configuration_jetmoe.py": "0a94120f38233320eb1795f65aa1ea48a69c57e88a3496fb37691387da0c0852", - "src/transformers/models/jetmoe/modeling_jetmoe.py": "87aa01c2f7744028b0fde50649ac11ef0cd5ed5ac8b6e11a04f68bc5c4944e97", - "src/transformers/models/jetmoe/modular_jetmoe.py": "1611c058af53e7d12ec04237eed49d997e5a70b734728a821a3d8ea0008d35c8", - "src/transformers/models/jina_embeddings_v3/configuration_jina_embeddings_v3.py": "f92c62ba245f43a8dcbbf4c4aeb0f8ccc0873eab488aba2beea98dfb6c745b87", - "src/transformers/models/jina_embeddings_v3/modeling_jina_embeddings_v3.py": "85bded3925faa2af7c5cea61c4a0a5241d20fcf70d6bc76c1e1072961abc6c8a", - "src/transformers/models/jina_embeddings_v3/modular_jina_embeddings_v3.py": "e010ad4570dbc616fdbebfe21294f2a62c61497c3cdcb96d843e9736cd7ae7bc", - "src/transformers/models/kosmos2/configuration_kosmos2.py": "e608ada314417fe3596f7ce6c17de4e176a28d753d70784987ac76647e33dcd2", - "src/transformers/models/kosmos2/modeling_kosmos2.py": "f6cc4204fadb19893456754b9d4d9199b63938da181ffae00041e2d29c8fd102", - "src/transformers/models/kosmos2_5/configuration_kosmos2_5.py": "6c7a969abe650340f4a4c7b3f4612dd8145186a09e313de5e3251b9184a6e664", - "src/transformers/models/kosmos2_5/modeling_kosmos2_5.py": "8e9d18816d4f3180b1178eec20286b23fe7146116ea6fa3110fe770f4d188c00", - "src/transformers/models/kyutai_speech_to_text/configuration_kyutai_speech_to_text.py": "df3b14a1a047353d7df2ea4be7a7e0b7415d9a9107df00ddffd7c83c80a490ce", - "src/transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py": "d79f6d1622e316f543de0cdf140fb7be97f2f0c2e14f1328e8e1d60e0cdadf25", - "src/transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py": "54ada8c6cb8c30f17835391eccfb8dc4de40189ab77c8575d64de026aea34897", - "src/transformers/models/lasr/configuration_lasr.py": "f3bc15706ca88fc7276cf1fc34e000e5d15af0b760ca038ad03c91748fffdc9d", - "src/transformers/models/lasr/modeling_lasr.py": "a0826f3ccc742883aed4379c5781761cb86d5c051bdbd88b19c96ac785c99661", - "src/transformers/models/lasr/modular_lasr.py": "3fa07714ba48c5c75a0acccfc52565997942776375630e14f7afc2f7dcc4764c", - "src/transformers/models/layoutlm/configuration_layoutlm.py": 
"0c620323bb7283e81d725dec645e59b9c19e6d29b5cf3ae4fe727f28fb673a0b", - "src/transformers/models/layoutlm/modeling_layoutlm.py": "3a4fff5c809d5391b3e9e0a9a91854e415630d49f8f9a3b9cc3c3decbb9f0d03", - "src/transformers/models/layoutlmv2/configuration_layoutlmv2.py": "cc42120fe754ce9d70fba1ad4ade586bbc498bc376ad90718e81cb95fa96d198", - "src/transformers/models/layoutlmv2/modeling_layoutlmv2.py": "bff2c6e07df63c065f856b111e526ddddeaa2592fb8448163f94bba5e209791a", - "src/transformers/models/layoutlmv3/configuration_layoutlmv3.py": "6b686db372a201d892a5012a3c01d1489d3c770bfcd2b1af6e6a79c53cfaf7ba", - "src/transformers/models/layoutlmv3/modeling_layoutlmv3.py": "eced42508a051419f58728f67e87285103c50f425fff52e081819527f1d804e6", - "src/transformers/models/layoutxlm/configuration_layoutxlm.py": "1f9868377dc60967210d0e210afa763b8bbb548fb587a08a1ee460f9ca382d71", - "src/transformers/models/layoutxlm/modular_layoutxlm.py": "72b04dc0238c3994ffeab3793b01fdd13490c86eb6d0455e306f856c53698bec", - "src/transformers/models/led/configuration_led.py": "891c9f1ab5df5a930ad244ddf4895793b04d1719e5e4c4fb17a075f7d6815587", - "src/transformers/models/led/modeling_led.py": "d9c7c05f82cad4a0b7768d46bf03897a5e744c6680e247a297725415cd7c9f60", - "src/transformers/models/levit/configuration_levit.py": "30fabb6bec15d4eca98924d47c6e7ee373003d8adc1807b1d11a6019ed947059", - "src/transformers/models/levit/modeling_levit.py": "d9cbbbbf990729ea822985e5bf4ab5b3470e768b6cfea5f9e34770e4cf367ee2", - "src/transformers/models/lfm2/configuration_lfm2.py": "abd75f2d60437fb897c7fe18a29f6b46b0f1bb9199ed99fa333198ec71b6182e", - "src/transformers/models/lfm2/modeling_lfm2.py": "aad0ca1d374f389844dfbdc7d2b5d203ab7b65f729378ecb775a5ad0d8e3186b", - "src/transformers/models/lfm2/modular_lfm2.py": "20547db9f5818e125204c7a5cde1dcf10ccf2e355e6c016da76b56c064d26040", - "src/transformers/models/lfm2_moe/configuration_lfm2_moe.py": "3b7d1ae2fbff6f852535ba2f77bb56797ed15960a968067907cfd6d4af6b3e49", - "src/transformers/models/lfm2_moe/modeling_lfm2_moe.py": "86f07f9c909f44433383a1fd3d4de55836780eb655c4a3a551d99260f7ca1312", - "src/transformers/models/lfm2_moe/modular_lfm2_moe.py": "a729afc0cea7e95544bcfbf3c9e26df6c3fcd50217e9c2b5503420b8eb6662cc", - "src/transformers/models/lfm2_vl/configuration_lfm2_vl.py": "b4ad74bb99c172969ce90d0f891216586540ae56ae4d00e464fcc6602d5839d5", - "src/transformers/models/lfm2_vl/modeling_lfm2_vl.py": "c47ba12db8b9d2799d84852ac9740d7ef1ac989e72675a59072ad5e6bddd5c81", - "src/transformers/models/lfm2_vl/modular_lfm2_vl.py": "9d049b640f900fc1d20a6fbb89d4675cc21965d34d68f4b7a01385f4be48288e", - "src/transformers/models/lightglue/configuration_lightglue.py": "92c1257971fe300dde6f7d1027273137a179080880fc1767b363032b001314d9", - "src/transformers/models/lightglue/modeling_lightglue.py": "a1b135fd594b6c6d93b84f35a8b3145c121add4a8d2cee1d729430e5f8051f89", - "src/transformers/models/lightglue/modular_lightglue.py": "4dab9bc19ee0f29d060ad3f45c17deebba8c523d7adae6aee92b30ed8f41e8fa", - "src/transformers/models/lighton_ocr/configuration_lighton_ocr.py": "e09b1d7c2ddc07a32e1fc439cf6be058ed04c56b01f9b33eb8aafa52ad25078b", - "src/transformers/models/lighton_ocr/modeling_lighton_ocr.py": "1365e4a5fbd040df4fee3ace36c26707abb96f1cc08ad3d2ca76abe004495dd2", - "src/transformers/models/lighton_ocr/modular_lighton_ocr.py": "66674d1276e11713232afec8c1fe0afd8df6b12de4a0ea995d9d95ff3abf5359", - "src/transformers/models/lilt/configuration_lilt.py": "b559fbbf3b421aea2859ed0a4ee08ae783c4e15cdd86f8a8d0a3806255c28309", - 
"src/transformers/models/lilt/modeling_lilt.py": "5df3bcb7e24a15171235b7faaa14dd33c35b42807cdad1f987eb7656ad871e97", - "src/transformers/models/llama/configuration_llama.py": "c49ff892b7f62f8ebca48874e73cbc6dccb0f46be96fb60dc1f806283807cac8", - "src/transformers/models/llama/modeling_llama.py": "6f2ac1eef350e2156a2ba10c7eea78b86afe1b66c5f5ce8a6df76a285afc8fc4", - "src/transformers/models/llama4/configuration_llama4.py": "0eb9d8e88b4129e9e06b0a52b4622e71d19f1b249ece07ec3e54f8c35ef9a754", - "src/transformers/models/llama4/modeling_llama4.py": "0a8dae2ec943c843adf4dc6d578e19a8511dc23b9aad805fbe5d3730c2c9a1d7", - "src/transformers/models/llava/configuration_llava.py": "9427ffab42ee85abcd7c4773016852188c6c176ad6ca6d0f3b0fb35bcdc5a5d2", - "src/transformers/models/llava/modeling_llava.py": "83fbe20a69de34944005827149fa65090bfa1d97d94ce4e6725987660f989c6b", - "src/transformers/models/llava_next/configuration_llava_next.py": "1265f66ca602fb24ddb51d72cb33b35a510b674da85a37c3ee09a8412f97a4ea", - "src/transformers/models/llava_next/modeling_llava_next.py": "babe08592781923894dfa316306413604931a0389f5a10a8098ff7fae4a6ed3e", - "src/transformers/models/llava_next_video/configuration_llava_next_video.py": "800aa64d7a4558ecfc5a98b12229f15dc5c9e1384dbc15492c6402f9c7b0a349", - "src/transformers/models/llava_next_video/modeling_llava_next_video.py": "ef9818cd2fb50b5199450f565fa4d5f422cb39d9bc0a9334ef28a7bc30d62d2d", - "src/transformers/models/llava_next_video/modular_llava_next_video.py": "b883d36c20f9342a7a751c54f11f7a4584d7419600bf7a74da164ca1b73c9172", - "src/transformers/models/llava_onevision/configuration_llava_onevision.py": "24cbcac0f167945b6776f20e16595525f4e3aea47b76af884ae6fe73677596fc", - "src/transformers/models/llava_onevision/modeling_llava_onevision.py": "ba68552eb3b4a59220ca7ca2a372201361600afa33a63d791159296127e8b930", - "src/transformers/models/llava_onevision/modular_llava_onevision.py": "42d1b62db35c035f2c88437849138d40eb735e22036388b2e0a96a42e44cb866", - "src/transformers/models/longcat_flash/configuration_longcat_flash.py": "18aa599f9018a922bf6e85e9fc76d2fa4bac1b9503f49bcd72c65d96ac855ee3", - "src/transformers/models/longcat_flash/modeling_longcat_flash.py": "b45993d5b8095d6c94938898c0c5fbaa8b8e91c991df41196052ce4e2992cf2a", - "src/transformers/models/longcat_flash/modular_longcat_flash.py": "45ec389a7d08eda058b3ee5eef7e0a89c29735f7050e5db9602ae1864f8538c5", - "src/transformers/models/longformer/configuration_longformer.py": "bf48a7fc96c06ceacfbd0dbaa9693d4cc5ad8e6d4795d035706feef5bfe7103f", - "src/transformers/models/longformer/modeling_longformer.py": "f65e93d58c0269ec3de56e1ebac31ee3707f37f87ffd80a80e26f655a8f85570", - "src/transformers/models/longt5/configuration_longt5.py": "75e4053d4a5324079fcb2283cb2958c388d4c25436d8ee95668e80e79d0f0906", - "src/transformers/models/longt5/modeling_longt5.py": "a7648386915ed48297193e0da3e6f2cd7857377d700945cf75ab3bee0d4113a7", - "src/transformers/models/luke/configuration_luke.py": "4f93e179631aa338ef86d0281f032afc6b4c7069114f23bd6c4c9c46aae44bb9", - "src/transformers/models/luke/modeling_luke.py": "06c189416b7ee199f6cb843737824ff2a136e616b20e837c7c293b2c2950d228", - "src/transformers/models/lw_detr/configuration_lw_detr.py": "dd9bfc1eb2afb3503961ed1176635329b522a6621e38a9fd6723a22823babe1d", - "src/transformers/models/lw_detr/modeling_lw_detr.py": "33e966f3167c8d670a458d45cd45cf1f715d242401a9ff851ddc3c611556a7de", - "src/transformers/models/lw_detr/modular_lw_detr.py": "44073db28bb0e289251fabcd360277dd7ebed74bfd8c007100563cb6596d6e10", - 
"src/transformers/models/lxmert/configuration_lxmert.py": "192108fa56c4d3caf41c38fa250103ade45f1ee1184a31a6768466202f5baf7d", - "src/transformers/models/lxmert/modeling_lxmert.py": "2d4707b9eb872f00699110aa787339a4f563139a76c9a9ae13f6054cbe77db27", - "src/transformers/models/m2m_100/configuration_m2m_100.py": "3641f24e901972ce3ee59f696c2ae76df7f260e8830e2b61f66b33f48c460a25", - "src/transformers/models/m2m_100/modeling_m2m_100.py": "1d0a84904bd69f01ce72761e7d3853e47609217be59aee24cced870196ef9e5c", - "src/transformers/models/mamba/configuration_mamba.py": "676b6ac4ee81df3d931d3922d31500417434a519bdd2b1740df7eedaacbf2ba2", - "src/transformers/models/mamba/modeling_mamba.py": "1f53c50107dc26f76ad8f69be9ccc35400caab9a76c56349e98b484f3fa0ced7", - "src/transformers/models/mamba2/configuration_mamba2.py": "eda2e5732ff4a0936d303dd19fa93c752c2f836a0296a0f65ada38bb4405c820", - "src/transformers/models/mamba2/modeling_mamba2.py": "c79f65c420c4e4ecae40faf8e245944e00e72864705db7b8a62da398b370afe5", - "src/transformers/models/marian/configuration_marian.py": "04687c7d08b1bf00c3aa09ee421526ede6f7f9c488b578c10df52547d76c63f1", - "src/transformers/models/marian/modeling_marian.py": "dd7958f8d3161f6284b4e0e9f07acbe4067fdcc98645a3c20c405c88018bf519", - "src/transformers/models/markuplm/configuration_markuplm.py": "37c63f80ad5ab0822962659e0f2c7af1ce0ab9975737ffaad4f237962a7da19b", - "src/transformers/models/markuplm/modeling_markuplm.py": "9469040f6c9dca224f514b11d49077eb5551e389b36755867ff659013b16cbfc", - "src/transformers/models/mask2former/configuration_mask2former.py": "dff0731946cb303a264b71097ffc5163774be9ff3aac5944e6f70a69d68f8cd3", - "src/transformers/models/mask2former/modeling_mask2former.py": "bc1988c9b84d46a53d620f7e433342b0086632b405b3b1e74407a068721141ac", - "src/transformers/models/mask2former/modular_mask2former.py": "05d3b7dd795d94b1c2002a3bdccfaf4817e7e1d0274d51f3f19f167245c369d6", - "src/transformers/models/maskformer/configuration_maskformer.py": "382332dd9b4b14a368209ae023a5b6c7ca68793790509eedb78dde8978a73b6b", - "src/transformers/models/maskformer/configuration_maskformer_swin.py": "7eea22c3a809fb73e5dca74dd5d89d41caed8da4690516adbfbe7763f3447bd4", - "src/transformers/models/maskformer/modeling_maskformer.py": "4c556a0123547c493bf7be9fb5282e3592cf59c5c58504dc09fac4a32291f7cc", - "src/transformers/models/maskformer/modeling_maskformer_swin.py": "f0c7ca9256a19892c0aa43dfba0b9331700f10a2aad9e432d6666df83329580c", - "src/transformers/models/mbart/configuration_mbart.py": "215d49bba723ebef40e86935cb0715e54a724baf8f44643d4c342dde7aa1dc9e", - "src/transformers/models/mbart/modeling_mbart.py": "f342acab078db6e9d8fb9b3ec1619788e999a1406d80f9adf9899385118a226d", - "src/transformers/models/megatron_bert/configuration_megatron_bert.py": "820833f9e3ce92b3255d57e6e6a973de583e4506517144d28fedf21c3faf9cf7", - "src/transformers/models/megatron_bert/modeling_megatron_bert.py": "cdfd719bda6fb9e8d0164fd51daeed371c0f696ffb8c25cb00eb5fcb71ec6d60", - "src/transformers/models/metaclip_2/configuration_metaclip_2.py": "f39fe8ef1e1e2e7bb9a9d146a57aa33cee63e6cc38be477dfb8b064e47fedadc", - "src/transformers/models/metaclip_2/modeling_metaclip_2.py": "946f70f082f2175ce67aff547139b2c6d512bface7e618e93054376dcd6fc085", - "src/transformers/models/metaclip_2/modular_metaclip_2.py": "62c81588bc32791c7972391883410df250349f9acf6dd93a707a5ec2da415668", - "src/transformers/models/mgp_str/configuration_mgp_str.py": "c2e885be62ad4f06543a0adf0b9496298f9b58e23ba1f6810293fc0959fe0f0c", - 
"src/transformers/models/mgp_str/modeling_mgp_str.py": "6b8b54afe5067ec5b428693901da9741952b903b19344497fbea665e38207bed", - "src/transformers/models/mimi/configuration_mimi.py": "ee63e6b5311ea002b707c3af7dffda26d23525bffd3f98e3fba6a6afa8fc6870", - "src/transformers/models/mimi/modeling_mimi.py": "439ff84320e62279b7e462679f3fe714e7c677b880b2561f4e34cf2800409294", - "src/transformers/models/minimax/configuration_minimax.py": "70bc96d45ebd466cb3c841df0f14a7aa1d18c82a29fad9b2476252dbf7240ab1", - "src/transformers/models/minimax/modeling_minimax.py": "ef591877e3be0c67a4beb218e64111df54a6e3c261520f8941b099f1973fa195", - "src/transformers/models/minimax/modular_minimax.py": "f8ebc1fd0c77dc2e4a85b86ca9a7505e81ddec189f7e1cd419dab3a0621a5fb2", - "src/transformers/models/minimax_m2/configuration_minimax_m2.py": "81b3c38a8376fc54c0783ec9eb456e38d2957df89a44cff9a579899e04cf14d4", - "src/transformers/models/minimax_m2/modeling_minimax_m2.py": "4eca7507d3398335c23cfc85dd51a6b2957453ead26baf78a7e791643e3a4ac9", - "src/transformers/models/minimax_m2/modular_minimax_m2.py": "d4ade009e1a324a659d34d7f71d9f1a90dcf81fdd984a2bffaa7edd2708e22ae", - "src/transformers/models/ministral/configuration_ministral.py": "317a30b98b70d1b6d6455b72ac01b35953ea66c4dced8a929e5710f94dfbf60d", - "src/transformers/models/ministral/modeling_ministral.py": "b15911a156abc1c8b26ba6723946b93fc871d6cbc25467ed9ce1baef3887a7d2", - "src/transformers/models/ministral/modular_ministral.py": "20608822d41512fffcedc77477a9a189b821ab440db56a37df4cbe313a1b753f", - "src/transformers/models/ministral3/configuration_ministral3.py": "7fe0f928d9726f384a6ed80fd887c957b3534f6f7b468cf66ac4f5866f42435b", - "src/transformers/models/ministral3/modeling_ministral3.py": "cee5f45376a45e7701eae72637ff91cd20c22bc905a2bd905175afbb23644c2d", - "src/transformers/models/ministral3/modular_ministral3.py": "b0dcb37dabf1c18144be3a0b0ba46469c91d79b7bf2b650df8dd6e9ed9a0d926", - "src/transformers/models/mistral/configuration_mistral.py": "e335825263bb22ef018c6687fbf4dcb7a7b80dc5efbcb60946900a176a8bcc9c", - "src/transformers/models/mistral/modeling_mistral.py": "231f9d5f74e338780d6b5b90e6b1d34a4332778e075de8f03dec09416dbbf491", - "src/transformers/models/mistral/modular_mistral.py": "b52e48055ab011fb41cbd700967966c970d083ad77ea1b22a92907dc861f19d7", - "src/transformers/models/mistral3/configuration_mistral3.py": "d5c54eba59c720bb8f4c8657a429d15149c144d4ba5d4374de8b8ae5c63d13a7", - "src/transformers/models/mistral3/modeling_mistral3.py": "9288fe03a3876c45d5601a313ddee0408d878ae141ac6038b635fbad41f8289c", - "src/transformers/models/mistral3/modular_mistral3.py": "57303c2182e09daba7c0b006ce6b037a4b03fb5b606e769b095034645ffc9fad", - "src/transformers/models/mistral4/configuration_mistral4.py": "e0e0cb2724b64f26c6683b2153f9d4c48cd58cd68fa336fdec0dc6238932a5f3", - "src/transformers/models/mistral4/modeling_mistral4.py": "2d590ea9f7a781d38844e7568cc5535e6a89feedf33074c743b020f248567bd2", - "src/transformers/models/mistral4/modular_mistral4.py": "cb3984a3ab3ef7bc448c0b5850d2aa7ce1533d01ce46cd7fb0ff879f6a66954b", - "src/transformers/models/mixtral/configuration_mixtral.py": "bdd51edcd6f03fe0bb6eb8473eeb91668fe87e64bf8f66cf5bf77516e0e89647", - "src/transformers/models/mixtral/modeling_mixtral.py": "7eb898ca25bfa911c56779b90e6750ab59a1b0bf772ffce688fef53a0850ccd9", - "src/transformers/models/mixtral/modular_mixtral.py": "59b4473a0b3951c735b59571b0cb8268bf7f55e90a6a0a163ae5d6229c7ac262", - "src/transformers/models/mlcd/configuration_mlcd.py": 
"28d23bafe62c03828e6d676f650ee05f23721581a955b04302dc5b98e99bb3d0", - "src/transformers/models/mlcd/modeling_mlcd.py": "45850c7798bbfb5f71472b224f071c1aec5cc8677e7ba284bff2c61e1a08d764", - "src/transformers/models/mlcd/modular_mlcd.py": "85c33ee9c21d3fe91e3f9ab8c62c17cf46c499975f5f5a3b4c493b4fd77108df", - "src/transformers/models/mllama/configuration_mllama.py": "d99751ab304e369dc41d3e5c8c9e5c310871c9a7be20a30699227f9c8267f64a", - "src/transformers/models/mllama/modeling_mllama.py": "c29ff406e8532ff0aff483bf5bd9e90696a51bce908f2c924e48f4e839dff3ca", - "src/transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py": "258ffe3cbe483b95964cfd4fb40f21ead89fecf8cc52033a443a95d302f815bb", - "src/transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py": "b2959f3505f412bdfb61bc51caccf8114faebeb551a5949ced06205134fd3895", - "src/transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py": "68f2e10d57715841e0c20bcd0ecd5ccc2df2f860e722b9e244afb12417ec06b9", - "src/transformers/models/mobilebert/configuration_mobilebert.py": "7820022564eaf93cf1a95ebb1ce29b9e50e5374a0b509c560a6f298164879123", - "src/transformers/models/mobilebert/modeling_mobilebert.py": "9ba449503ffb3200cacd50044ace9a6a3dd0b00e27e47dd5daad2e2d9bedb827", - "src/transformers/models/mobilenet_v1/configuration_mobilenet_v1.py": "e87b08a9dc3aa2501de203871e021cac15e912e023776de9daad1b07858ee1c4", - "src/transformers/models/mobilenet_v1/modeling_mobilenet_v1.py": "7cd6e4ae6c9932449df39dec0c47e567f66af6243ceb2a3cf41d9aafe0bf2a38", - "src/transformers/models/mobilenet_v2/configuration_mobilenet_v2.py": "996dde5c9127cd8f2c4b0a36a35f4301b99aeb2fc0943963a69dd26b221a21b2", - "src/transformers/models/mobilenet_v2/modeling_mobilenet_v2.py": "a625288899fad7c83c64a77292b206b8ecdf0d549a93ede281540dcd6bb1d6fa", - "src/transformers/models/mobilevit/configuration_mobilevit.py": "8279aa0f13a8dc035e82662267e9eaa1e86d266c8a19dcee80bcfc0bdeeb6326", - "src/transformers/models/mobilevit/modeling_mobilevit.py": "a1693df066c75d43f283931e5f04168c7dc90d1b3483baa68f771c6c0727fd0c", - "src/transformers/models/mobilevitv2/configuration_mobilevitv2.py": "01d6cafa2d65d8c63a0757d8e118eab165442a4bd465170aa9cdef32ea92bd90", - "src/transformers/models/mobilevitv2/modeling_mobilevitv2.py": "f1f88531fb683fc656aea2ce5f07ce41f3396e57696a96c14a12c08d751cb41c", - "src/transformers/models/modernbert/configuration_modernbert.py": "7b38bbc365dc81ae44cf8216a78bdf877a82a07a90d156b016f2e8bb12656139", - "src/transformers/models/modernbert/modeling_modernbert.py": "17d1c8ebdd4b38cc1f8168f1ec0f6e5ce1555133d2838283c99cc353a87ea479", - "src/transformers/models/modernbert/modular_modernbert.py": "1041c9ca8875fe10727e24c72f739258757878aeab9f081eb7af0634449ea777", - "src/transformers/models/modernbert_decoder/configuration_modernbert_decoder.py": "06aff47e32cd549d515b94e53dd2005c098d01ccbd5ac58551eb9753828b3589", - "src/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py": "ad57a59b680136964adc32dc1b46f829d7182cbb1049c06e7bd09de4a455d470", - "src/transformers/models/modernbert_decoder/modular_modernbert_decoder.py": "79a99d35d0bd37c0dc115e41302ecb1226da5c11f09e990f481039ca236cd95d", - "src/transformers/models/modernvbert/configuration_modernvbert.py": "c11fc6e15ba2056bb1e69d0f3491f74b7de582ed9012158bad15facc77b43758", - "src/transformers/models/modernvbert/modeling_modernvbert.py": "e9e17b061e32b17b578fd74494189b0bdfcebbd6f6d2bdb66a002434f17602fa", - "src/transformers/models/modernvbert/modular_modernvbert.py": 
"1ce0733169ba27520fe04e64b5f8063dca1dacbd484108356dc11004142542e9", - "src/transformers/models/moonshine/configuration_moonshine.py": "cd39122f595a74659fc10c186a37f8e32fae878fa1f76fc79fa70138cf6cc3ce", - "src/transformers/models/moonshine/modeling_moonshine.py": "c304fc069c70d0b2f0b65ebf130dfc7e34a99a42990df5342c03abefdb6c616b", - "src/transformers/models/moonshine/modular_moonshine.py": "814a81b2ba484d29d91370286a8864712f486bff11409671d569c3dcdb06e2cc", - "src/transformers/models/moonshine_streaming/configuration_moonshine_streaming.py": "56b3543ac4b5248b748cb74ce06b58e52e6d9f28e984876dac72d30d548fc8d0", - "src/transformers/models/moonshine_streaming/modeling_moonshine_streaming.py": "0f655344aa9b5175bfd008934c3212943c182f62a3289968e80ea4091aa6c1c7", - "src/transformers/models/moonshine_streaming/modular_moonshine_streaming.py": "ca91e325a2615d738b6ce758ff7c260bed41ae8bd3ab5bce7a0c2ce0e0cec368", - "src/transformers/models/moshi/configuration_moshi.py": "e33d29255de8681d404faca02cb78233913a287042f67b8bf5a1d1c01ca74fa5", - "src/transformers/models/moshi/modeling_moshi.py": "fe5d4f845dbeb52a6f1876c6595d3f196fa638d8389250a757fbe4460b08e044", - "src/transformers/models/mpnet/configuration_mpnet.py": "58f1cc39d77c499402cfd50b5fac9f2171704791bf24f9b4cd2a1dec8a4acb7f", - "src/transformers/models/mpnet/modeling_mpnet.py": "bddb7d4f855b8bb9ef74fd0b1050adacdb9e9533e5b95462acca7677328a283d", - "src/transformers/models/mpt/configuration_mpt.py": "860fe86ffe0003c7c1cd2a8457329eefbaa0bb63a1ae8d632a28a36f4a455925", - "src/transformers/models/mpt/modeling_mpt.py": "306dcda2d0ed54f8368ebd8cad7d9ea0b4485dbd007b6f0cfa516c3943f88711", - "src/transformers/models/mra/configuration_mra.py": "3fde853d0de420d505f774e989e8e2aba9586e32850bb7f7ca28d6c3cdf0ca60", - "src/transformers/models/mra/modeling_mra.py": "782c9ae454349cfbbd5928606bd9584838b9052ecd6e4bbb33129e51693fe844", - "src/transformers/models/mt5/configuration_mt5.py": "f8eb09b0441e4cb8c2346517c9f146fe3b9f1f409530c62849178d878f2e2d1a", - "src/transformers/models/mt5/modeling_mt5.py": "d150abf4fabac1413f0ef88e1b6ea26b5f9d3fa7ef7bcfad11d987a55db1c695", - "src/transformers/models/musicgen/configuration_musicgen.py": "03f0ff57e1caa4514f6db3da5d4acd7aeea0f0555a6a7fa42bb4969715baad52", - "src/transformers/models/musicgen/modeling_musicgen.py": "d2683d095e73786ad71107b7860cfb27ea7583bda9930e1b8424541c4041e668", - "src/transformers/models/musicgen_melody/configuration_musicgen_melody.py": "ae9db8975eb6178dc5995e9a91a0a32ed0354fe858fee601a8baf034570322c4", - "src/transformers/models/musicgen_melody/modeling_musicgen_melody.py": "9fd2134aaa583b4d6df5280c5bf8cd6586ec84777515946eec0e8bc09aa6ad2d", - "src/transformers/models/mvp/configuration_mvp.py": "95b7acd60185b7eb5c47564b4168dc02837fa8f39eada44afa94c43ae90bc975", - "src/transformers/models/mvp/modeling_mvp.py": "5645e5720ecbc74588c864f6557a3c712da2f6453e383f3447b3ac63ddf632b7", - "src/transformers/models/nanochat/configuration_nanochat.py": "bb4059ee31e86009cc1c710c7ff6250c321ea0934c5bdc18d87cf45285fe6414", - "src/transformers/models/nanochat/modeling_nanochat.py": "5bb913d68b0d344f7742d9e1d3d11952907a8f9a47ea1c2cbf63a4c2043fe81a", - "src/transformers/models/nanochat/modular_nanochat.py": "d4a0ff59e9be383702d049d10c54adcd823b6fa2485cffb92b33a812fa8acc4a", - "src/transformers/models/nemotron/configuration_nemotron.py": "664c096eb39b194c8de8ba3f97f81aed660d3fe66fdcf8637a192ea27f358c8b", - "src/transformers/models/nemotron/modeling_nemotron.py": "6a213b0c5ba853eeeb5947245bf3e11483616d51c1af03d62f84ab6259f9960c", - 
"src/transformers/models/nemotron_h/configuration_nemotron_h.py": "f56b78f77870c390191d85fd98bfa2d424a528a16d5b8b1435b87735c5866935", - "src/transformers/models/nemotron_h/modeling_nemotron_h.py": "881d820a52e35d060febcdb6ef453763a68665ad4fca317a4dd04b27f7d26594", - "src/transformers/models/nemotron_h/modular_nemotron_h.py": "28c97f5c787aea019f3dd1509f64d8f72373c1e62822feed16f430d59c655346", - "src/transformers/models/nllb_moe/configuration_nllb_moe.py": "5ebf7510398cd6012008f7560892247e1a9c8769e85d780202d0974b7de4df36", - "src/transformers/models/nllb_moe/modeling_nllb_moe.py": "fb350865097d11defbfe0f3d399a44f71811d055e941779482742d4b1188366d", - "src/transformers/models/nystromformer/configuration_nystromformer.py": "b5116d57dce7637a8f77167e5eb46eaefdd3af0c4a96286af4a1b5fda6b5e029", - "src/transformers/models/nystromformer/modeling_nystromformer.py": "e691270fc1879dd1e40c5352d2dc01e35b13b665ec7224153cbd06449ae56709", - "src/transformers/models/olmo/configuration_olmo.py": "da39b2a0733385a0beee7baa7b21c4da2700aad12475d87cf90808d65ace1ca4", - "src/transformers/models/olmo/modeling_olmo.py": "a48a2e94daf3861a5649a54ad7bce567f6bda0ceb33062536a179d7534f81fcb", - "src/transformers/models/olmo/modular_olmo.py": "ee0d4ab25072687fad9241cbeb6bcd4b36981bf96622f65f1e91f3536884a977", - "src/transformers/models/olmo2/configuration_olmo2.py": "896ab23279155cf5dac64ee692ee8d8b2ab9182a35d02ce93e6aa44cb4f6f9c9", - "src/transformers/models/olmo2/modeling_olmo2.py": "2c99da55fe6597cb79272971da5112e89c638d5f5d35657e05ab730ad2b9a1a5", - "src/transformers/models/olmo2/modular_olmo2.py": "4577c969451e057972ee829e62cdb7ebd00169026e95d26bb941562f764f1ea6", - "src/transformers/models/olmo3/configuration_olmo3.py": "f518f4076eb958249d903d67762a86b26993988a595fe217f43971e513e10ea5", - "src/transformers/models/olmo3/modeling_olmo3.py": "ebc6fd816956bbd319601e0e6d8b4ee11e6577d3e925ebb0a65a22600ff79bc2", - "src/transformers/models/olmo3/modular_olmo3.py": "c39bed4953bead76fc789d1bb09fd4af17d62ff60997eaaf6b535c6a715e3712", - "src/transformers/models/olmo_hybrid/configuration_olmo_hybrid.py": "26fb338aceb7e1a6d306320c64cdfaa9036312ae7e0b3719cd0509e26f82a1e7", - "src/transformers/models/olmo_hybrid/modeling_olmo_hybrid.py": "68f2dfe750d0c2dbd9833a4ca81d114935c62ac2a870f33462e655e7b14b0a24", - "src/transformers/models/olmo_hybrid/modular_olmo_hybrid.py": "85f7615f191f064ca5f24cee23a9af1ee012501e3133916ca043b86de4af823e", - "src/transformers/models/olmoe/configuration_olmoe.py": "a096aa14dc0fbad1481ce7f62f8f03d0dee860fa66a63fb38d560a43fa510815", - "src/transformers/models/olmoe/modeling_olmoe.py": "c398f44146b562fe71dae1ad6665e77faffc92365f3f8623ada885385d7fd806", - "src/transformers/models/olmoe/modular_olmoe.py": "d941507306e149215537a33aa161727245c0798230a0bde01a5540ce75f4b248", - "src/transformers/models/omdet_turbo/configuration_omdet_turbo.py": "d2d9ab0b744e8f04fc72bc176064a690868696e62e16da3a05a93330561faa81", - "src/transformers/models/omdet_turbo/modeling_omdet_turbo.py": "160cf15a5683cf746125b4b0c3c098b3b441fe2539b1931d89fd96d9ca7cde0d", - "src/transformers/models/oneformer/configuration_oneformer.py": "d7625287c165ee8e3c5ee803fefbf32916f3c58332c75a11732cfac9d7a923ed", - "src/transformers/models/oneformer/modeling_oneformer.py": "26a023e08b66402907e98abd1363fd19e26d204e05cadde570bbea4d90cce089", - "src/transformers/models/openai/configuration_openai.py": "5f480b0293fd3427bda76d56c660dea875983b74c29dc6adb670e42a96b03ce0", - "src/transformers/models/openai/modeling_openai.py": 
"669c67f1a5b18685c8b7a005a5add007a9124c4d00dd1d3d91402dd645033c3a", - "src/transformers/models/opt/configuration_opt.py": "6d4860fc3f7eb75ff719d8d8b46fd317f7e134b6c1de96345ff2d55c0d82aa6d", - "src/transformers/models/opt/modeling_opt.py": "9c39f2a057968e341099236bd49123366f5eeffb30644311d93d5badbf6a9da4", - "src/transformers/models/ovis2/configuration_ovis2.py": "787b8ffe5158bd2fdcd5d9efbed99db4109ed4c94b7e495293e3b3d2661a9455", - "src/transformers/models/ovis2/modeling_ovis2.py": "cb48dd26620072d278ca99652d4fe00ad94c2e056ef83bf122c05bca5f028ab5", - "src/transformers/models/ovis2/modular_ovis2.py": "367abc98dddded798f37b46dc7f643a56bd4554574b2480c88cdb75146135829", - "src/transformers/models/owlv2/configuration_owlv2.py": "6f108c6c5b853db5bca98b1997aee26d89c2b7fed5e5ea65c98334637e8ce71f", - "src/transformers/models/owlv2/modeling_owlv2.py": "a1f37fd7dfc7a97dd318eac8e3084200ac1d7661f7eefa5e5b3468deaf2fe5b0", - "src/transformers/models/owlv2/modular_owlv2.py": "02d5c9649b7011c53ae7abbe85ee3e9c8dc1c9098623798d6afb6d82c267a9ca", - "src/transformers/models/owlvit/configuration_owlvit.py": "05e1b308cecd557782216df969a62e3348bf17be07f83a4422422ac931d9fd8a", - "src/transformers/models/owlvit/modeling_owlvit.py": "5e1c5e6124c328c624f22c0209c874476a38f8377e61455736ea1ee6a434de35", - "src/transformers/models/paddleocr_vl/configuration_paddleocr_vl.py": "a367d2457470c4dca9c970a0aafd078e86ea3d4c143b214eb2d355f1c3b69342", - "src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py": "f6e7405da8639b77ba48d2969d5269b2bb1e690a0de076a0cb4b02cb714d7245", - "src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py": "26c67ed5c926805b101f84eb644e4bb95155f76ab083f83ca586c4aeb8c1a341", - "src/transformers/models/paligemma/configuration_paligemma.py": "e105c7d11088815d0e0f81c87a6027d3fa19d0f94408d88b0d526b7f4cf5a314", - "src/transformers/models/paligemma/modeling_paligemma.py": "8041ed488de7d6e4e729c39dc8e6447b00bed5b133a1bc0ec4fce4efe2a4c21a", - "src/transformers/models/parakeet/configuration_parakeet.py": "8094d2cb6356c9c8d0a750e90817cae05540a593e1b965f380ed1a6582f4debc", - "src/transformers/models/parakeet/modeling_parakeet.py": "d31c77fe4ab12f52aa51d67b6afa87b81eeebc4c44db7da4fd7f79e5faab3eef", - "src/transformers/models/parakeet/modular_parakeet.py": "aac30271ca7db42c545af0331926cc3a6c53cac32a5854af100c8f72da0f8017", - "src/transformers/models/patchtsmixer/configuration_patchtsmixer.py": "a05421a49697b8eeac5f72338f0104039c3d3cd0428c6b47f258d45b60ba3172", - "src/transformers/models/patchtsmixer/modeling_patchtsmixer.py": "4370829b090dae6c11b38723d8d7cdb5e6094f04938a97c846542b2389805c75", - "src/transformers/models/patchtst/configuration_patchtst.py": "aa16fe260718f90bf49cb5acac459299bbfd3c65d016771d81ed19f36c746be0", - "src/transformers/models/patchtst/modeling_patchtst.py": "0905c675fa093ea494e042a653de4232ad752faf94d90db85c365181a14db19d", - "src/transformers/models/pe_audio/configuration_pe_audio.py": "9d64dee70020704520d90a52751a869e1f11ce8a93f3c414663716a713146243", - "src/transformers/models/pe_audio/modeling_pe_audio.py": "f6908fa1e29e1417840cca3218333816cef268919b2430b8b09141ad2aa35ff2", - "src/transformers/models/pe_audio/modular_pe_audio.py": "354a8dfb56cb686f34ef7c63262435be9c9ec67c2dd6ca2b5df5c2e91b2024e9", - "src/transformers/models/pe_audio_video/configuration_pe_audio_video.py": "c70bb7a67921801c0ba6d938d83f9a6d00f90acccc70b031d4bd5702667fc4eb", - "src/transformers/models/pe_audio_video/modeling_pe_audio_video.py": "878197f405a2053d206c9d1fe0c99418e24130ca2e6694b478c738ad0e4cc7a9", - 
"src/transformers/models/pe_audio_video/modular_pe_audio_video.py": "adaa557a9b9685e9c59bbd837c095f4f3e3746613241e7c37ec631e0f0c15295", - "src/transformers/models/pe_video/configuration_pe_video.py": "3d0d8bd4ff4bc9506f8c4c308dcfb8c85436a95662a784a28a6dd72829e2c33e", - "src/transformers/models/pe_video/modeling_pe_video.py": "7d57f7a2e20a76ef52a2508ead16d0bca7f46ca74938c3d903e6f60962f40f2f", - "src/transformers/models/pe_video/modular_pe_video.py": "9ce5def083ff4b6a5fc87006c08f4085500a64fe29beb23c4b5ffcb6de655497", - "src/transformers/models/pegasus/configuration_pegasus.py": "ce8f541ccc72fca3bc9040180ff15cc378985220aab8bf40c423ac2abf322529", - "src/transformers/models/pegasus/modeling_pegasus.py": "d6a69b437518ef2df314fb4911dd222131d03da31833632b76c882af3dc6a1a5", - "src/transformers/models/pegasus_x/configuration_pegasus_x.py": "147586c17c54d9c482e1904f89d26c0855da2882ed173a2200ae1f941b9b8abd", - "src/transformers/models/pegasus_x/modeling_pegasus_x.py": "f553eef33c8a63cee60d20553c6dbcd11e26f87e66d24162b37e7261df3f1934", - "src/transformers/models/perceiver/configuration_perceiver.py": "efa38796d38c916de83b173ec48fdb889b15172e871840f8d29dfcd2f3505bbd", - "src/transformers/models/perceiver/modeling_perceiver.py": "9067785d1003797ab679c7110ecb89be37601d1671430b2191928744012f61c2", - "src/transformers/models/perception_lm/configuration_perception_lm.py": "a62b9baee1e9f6b2419c1fc30d8a3fbb0adb4e8979f7c34cf0663f10dac1ed0c", - "src/transformers/models/perception_lm/modeling_perception_lm.py": "5e265f4aea10581c38261d9243e3271daf9ae5d176e7e80053b2cd8fa9cf3e65", - "src/transformers/models/perception_lm/modular_perception_lm.py": "cb68eb17a00343f049359e5a8edf06771b52e0a1d3e744b9954c482cd63312b5", - "src/transformers/models/persimmon/configuration_persimmon.py": "2a66073d47c5932c85635266c9c00ef032102442490ea430e6c49cc4a0f9e994", - "src/transformers/models/persimmon/modeling_persimmon.py": "be2f54ec07de6cbfb3b33c331081347ec81a000274909393d8e9ec39ad2bf133", - "src/transformers/models/phi/configuration_phi.py": "263005162ce6a676688d54863a31db23503899586b5222772c915eb46b17e65c", - "src/transformers/models/phi/modeling_phi.py": "871d7480a128838ef3acfe6a2ad04393b1c843a596ebdfb5e4a6df8f53f6fbee", - "src/transformers/models/phi/modular_phi.py": "7f211ccc439d8f73e2f18969c8b3df0f5c64f1df010b58a33412d9564f61d9ef", - "src/transformers/models/phi3/configuration_phi3.py": "c328e5f5c8a05cb7d54b16aad04d4614f9cb5cbc6350c14d51efe3f3ef52e3dc", - "src/transformers/models/phi3/modeling_phi3.py": "61b1796c27a4dd42e3090b250ab72d7566c92cb8329ab215da2182fb08f353d4", - "src/transformers/models/phi3/modular_phi3.py": "ae530b68bedcf03f26999c39b86b80b2c6e8e4890bd3a1247509af9f647fd169", - "src/transformers/models/phi4_multimodal/configuration_phi4_multimodal.py": "ee2b699d94191317bc645b4042d8cdb99548a73823fa463c824f2ae90f861f5b", - "src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py": "7c95d1468de82c00aed6e53aa0e4f8af78d86dd345fcb533434621433461c6aa", - "src/transformers/models/phi4_multimodal/modular_phi4_multimodal.py": "97b63063b61266813513266b394afb24f06c6a5ccd76eaaeb34fffba3218a6c3", - "src/transformers/models/phimoe/configuration_phimoe.py": "02a58ada430ce5fd55bf0a304a218d7b99856a40d01db4a459124e427a0dafb9", - "src/transformers/models/phimoe/modeling_phimoe.py": "8d1a92c9415331e6dce0d252491f50ff78ef80b14c54402af79d05fd3dbc7fa9", - "src/transformers/models/phimoe/modular_phimoe.py": "82e8ba6b83af63f74a16c12e4b6b15f04019beefdd0567c3be23daa8728297b5", - "src/transformers/models/pi0/configuration_pi0.py": 
"4f016c235747a8302ac8b63f3ec18044beab4dd5c70a337925682501ef68d719", - "src/transformers/models/pi0/modeling_pi0.py": "94c58c3eba1eeb135cef63b5c6fd34bce3cd442b4b2bcc31c21e31284e694212", - "src/transformers/models/pi0/modular_pi0.py": "682003320af135f1c5f09fbf6960f5d26360e35f3b4f2e930ad678e4ec588e4f", - "src/transformers/models/pix2struct/configuration_pix2struct.py": "bf78a49afb4f65fc80cd18816fd80b72acd409a0531b6f66614546748b98ed4f", - "src/transformers/models/pix2struct/modeling_pix2struct.py": "589e4a4ed8a4a09094e66894a046e63a8cb596552f7269bcf62e87775e1ab1ef", - "src/transformers/models/pixio/configuration_pixio.py": "72cef8841ed4991cd2a5af27211676923bb64bae7fa5a400e159d469edcbee7d", - "src/transformers/models/pixio/modeling_pixio.py": "f5b640a57d463b54fd192dbb9709dfe2eb7812426d77c247af534e41c737db22", - "src/transformers/models/pixio/modular_pixio.py": "eef6f9cc338334c3537d269debf19c85f89c20a004c8fd67a923ef7a5255a8e2", - "src/transformers/models/pixtral/configuration_pixtral.py": "5f52fde8b2e352a2d08c6affc6209d76a67a651f7e8f91bbc970f9e11fb7b80a", - "src/transformers/models/pixtral/modeling_pixtral.py": "1a654b23cc3cd8cd3c4a9e2c5c4f3b323d7cf1f66c45681910ce37cbbd42aa1f", - "src/transformers/models/plbart/configuration_plbart.py": "e8fba9af7c472ec130d6ea35a9947dc1f470510c465f4e26fc4a0e83f67a1c1d", - "src/transformers/models/plbart/modeling_plbart.py": "42d5d0f5cc1c028352f416ba8dde36dc1b5f518e6062854e1a8106134fe8732e", - "src/transformers/models/plbart/modular_plbart.py": "966c40e76a8359136d30ed3c423f03a0505a3946af63b396a7cbbd524963c98b", - "src/transformers/models/poolformer/configuration_poolformer.py": "03f850ecb729e481275878b8f00326c27e40ab42504489ecb842bd2468a6a5d8", - "src/transformers/models/poolformer/modeling_poolformer.py": "1a2d2cba37f9d6814b973d1cb1ae5af562716be0ccfb93b16273cfb2850f4402", - "src/transformers/models/pop2piano/configuration_pop2piano.py": "638cacdca584c45b729005badfb0548911e95a1530d34625a117a6af770d5a52", - "src/transformers/models/pop2piano/modeling_pop2piano.py": "2336faedf16d81f2107053866eb1dea174b09f901def9d028bb9b14d744b351c", - "src/transformers/models/pp_chart2table/configuration_pp_chart2table.py": "0d6ef41dd9027aa167390534426a88996bd2b99821e90f2ddae726cd74fee32b", - "src/transformers/models/pp_chart2table/modular_pp_chart2table.py": "f07f36144fe0c1fc137ec92ed7ae500d76bb13fb7a880b1b386b47bf9dcd4f4b", - "src/transformers/models/pp_doclayout_v2/configuration_pp_doclayout_v2.py": "1d8e6147d0ed203e53b23d6c127525a1a8042b292313e1d9b7452f83b24bf8a8", - "src/transformers/models/pp_doclayout_v2/modeling_pp_doclayout_v2.py": "ec06e14eb3c3810a77d2ae1ea2650edf4bc1e64b1a110071860e22f30e87e06e", - "src/transformers/models/pp_doclayout_v2/modular_pp_doclayout_v2.py": "c8b62e3792bb6e00f63b65756ad57f3de637fee412bbdf6bb048923cbeac3402", - "src/transformers/models/pp_doclayout_v3/configuration_pp_doclayout_v3.py": "5bd3acaa93534c0f7a2e2a39e4ba7f6b9bfc2fcf6fadbc33afcc4384b4e6d3c0", - "src/transformers/models/pp_doclayout_v3/modeling_pp_doclayout_v3.py": "064342c448a177f479bc07df673f7c3924301c0ba754d1ac6660d8e99bd9ee24", - "src/transformers/models/pp_doclayout_v3/modular_pp_doclayout_v3.py": "f29b2707b642c4a07bcccb3dd7a95251d49d21c35136ff8515b60108d71f90d6", - "src/transformers/models/pp_lcnet/configuration_pp_lcnet.py": "3b04bf71a4ca210ef0268c2147421d5e54bbe5b9ac770b558daefe46b9043383", - "src/transformers/models/pp_lcnet/modeling_pp_lcnet.py": "c1624c2f4977d9821d859d9a94f3fe52c1ea00f1dee7ec040a53a2cd6fb8ae60", - "src/transformers/models/pp_lcnet/modular_pp_lcnet.py": 
"ea882bc49fca8720dd370df0f0c181629db31e34692f1aaabad608f4660d70ac", - "src/transformers/models/pp_lcnet_v3/configuration_pp_lcnet_v3.py": "8484e95b379186f3c4ddb791a3f7a78e7cd9dc98b5c7e98254233a99a7551ec6", - "src/transformers/models/pp_lcnet_v3/modeling_pp_lcnet_v3.py": "e0a9c99e026db17b99a57c2e2789f618eebc3c0d7adf9f513f2d56576896c8c4", - "src/transformers/models/pp_lcnet_v3/modular_pp_lcnet_v3.py": "b0e49b53962f43be56ab1a609d1460760daab37ef2e7942e0a357a949d8cf249", - "src/transformers/models/pp_ocrv5_mobile_det/configuration_pp_ocrv5_mobile_det.py": "e7ddce0a384dfe27e30cec3863c4546fa007bed800a72724970303d68742c305", - "src/transformers/models/pp_ocrv5_mobile_det/modeling_pp_ocrv5_mobile_det.py": "b248be15935132fbc75b112a31aa31891e567a9f26b557c068caaa5312dd3a67", - "src/transformers/models/pp_ocrv5_mobile_det/modular_pp_ocrv5_mobile_det.py": "d3d4953fa00063cbb8f5964c3174ba99bb20805d3f8d5934ec035bcc9f54bfd7", - "src/transformers/models/pp_ocrv5_mobile_rec/configuration_pp_ocrv5_mobile_rec.py": "b7b298ba3094690d9e5ce974669622c5397b9ef62a6287d2b2fb80792ccf8bfb", - "src/transformers/models/pp_ocrv5_mobile_rec/modeling_pp_ocrv5_mobile_rec.py": "4f7616efe95db7e86f525b6b91892611f82187ae4f177d4e19f4af75e54fcaaf", - "src/transformers/models/pp_ocrv5_mobile_rec/modular_pp_ocrv5_mobile_rec.py": "99e04cd8d97bc71ae085f8ca00b0a8a953e24f3e6db67bbf222bacda9c6be98e", - "src/transformers/models/pp_ocrv5_server_det/configuration_pp_ocrv5_server_det.py": "c6d42d4abdfdb2a4009d79ed9268dec9377f475337f76cf4d6df2fa9140ad96e", - "src/transformers/models/pp_ocrv5_server_det/modeling_pp_ocrv5_server_det.py": "94bc6c025512d0bf38ba4965f2749351fbc795c47897bfebf3a931581ac34dc5", - "src/transformers/models/pp_ocrv5_server_det/modular_pp_ocrv5_server_det.py": "b52b34ce33108a8ac665a46b8e76146d8b11b72de15411d47071ad6eac9ee162", - "src/transformers/models/pp_ocrv5_server_rec/configuration_pp_ocrv5_server_rec.py": "dd4be03e17cbe2335fdfe7eda64b140c93277e24553dccbaadf6ab18ea143fdb", - "src/transformers/models/pp_ocrv5_server_rec/modeling_pp_ocrv5_server_rec.py": "58432c2ece27e372ab4a5a8e43f76269ab878d5a711ba26f96e542c80f41b4c7", - "src/transformers/models/pp_ocrv5_server_rec/modular_pp_ocrv5_server_rec.py": "9cbde65d7536228eb11fc6a94f96b9a5c4596a95eadcaf94863192dc071ddab3", - "src/transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py": "ff0de4853fd45f7100c742e9656af40443aa69a077b94191f1dab34fb51ce966", - "src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py": "6d687bcbc80bc6bf0829520eef70e6534a655be798d07e3babbfa23454ce8dfe", - "src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py": "a31e08303c06b9c67ffba757c18ee8e96b94b07ad8da8da6489254d65c496155", - "src/transformers/models/prophetnet/configuration_prophetnet.py": "af504369210782aea21d35e793d161073bec8943cb78923dccc04440aa7764ba", - "src/transformers/models/prophetnet/modeling_prophetnet.py": "65c8a552e463a2c114d471bfbfa6bd92e196f4594dce0aa4c2863f6943bf9efb", - "src/transformers/models/pvt/configuration_pvt.py": "ca034732e3305db37178f18aeb142a2ade19f043e1ca9871d9620b2903711de5", - "src/transformers/models/pvt/modeling_pvt.py": "489262f1b8dc6a946d820b4b48b70cdb072a8dabedeaaafbc90aef506ba0923b", - "src/transformers/models/pvt_v2/configuration_pvt_v2.py": "519655c4ba46a21a1ed03d3bb473193c8ba11e051889dd7bec607c87ca6f2d55", - "src/transformers/models/pvt_v2/modeling_pvt_v2.py": "784dc6eb1b7f5b3b01e3b9478bf84a89f26eba1a6300fbdf0f832addd8ac9027", - "src/transformers/models/qwen2/configuration_qwen2.py": 
"a88024a2ff48e088ae76b53fd7eb540b45c992edfb41d332181c2deb1f47809d", - "src/transformers/models/qwen2/modeling_qwen2.py": "611e61ea4f0881a4b4b3cf77355b0a2ce1230d8a8674d1aa4bac7f2d8d76fc0d", - "src/transformers/models/qwen2/modular_qwen2.py": "0b9da277f60da80c7372ab3004d60a9fe4468d2c17d39e4366ac0901a09d3e2d", - "src/transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py": "18a218ced4ced4bab28f3466f5a7c1237c9afccc44c801387b92322fe16f2f25", - "src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py": "286589c8fe5e8a3b8441371c4adfaf21659411b4c6d1862169742928a2d405f3", - "src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py": "191e242c1c73018b7e5b4004ce29d87e72b6651262441a6183ccc463697d7137", - "src/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py": "ad6592cb9138074e4370c74d57ab4d44afaf05fc93f159c34e154e019e5272f8", - "src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py": "a5aeca9101d862571dbb83b52ba2bd885ccc6c79e11bf904149ceb7aa4b7219e", - "src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py": "47322127e2a622f0738dfcedb0a5eb3189cc819b532721b842c2577454d8da90", - "src/transformers/models/qwen2_audio/configuration_qwen2_audio.py": "07a42ed953b74fbf4561ac08625da711f5705efb5fa0ff7ec3f6b399bc13e476", - "src/transformers/models/qwen2_audio/modeling_qwen2_audio.py": "592b2a7584a988ac564dea3d04f35d4e908167db523e105e2b7133a43601db6e", - "src/transformers/models/qwen2_moe/configuration_qwen2_moe.py": "78f97945f5605d79937bc923d32f5c2aa029cd7241a30daa2b24fc7d89e69380", - "src/transformers/models/qwen2_moe/modeling_qwen2_moe.py": "7e0f78b5e4b8e704f5520aa38ace6c71c1ba363aa73376e42ee1c5650457fc3c", - "src/transformers/models/qwen2_moe/modular_qwen2_moe.py": "30f8156bf6dc9d17a88cee36305f6faab485349e5f04cd49a51c72ac2c81559d", - "src/transformers/models/qwen2_vl/configuration_qwen2_vl.py": "2c62d3360f08cea9fa805165f101bed8b0f07b5eda7f06e4213362b2a4f399c8", - "src/transformers/models/qwen2_vl/modeling_qwen2_vl.py": "cabb30fe2fd1937eb3c13bdd05b9e58cd0c6487d3d5d9188b257e049b760e5da", - "src/transformers/models/qwen3/configuration_qwen3.py": "7b902535b5700dacc4ffefc84789809b89655c7b713936d458cf626cc055be8a", - "src/transformers/models/qwen3/modeling_qwen3.py": "94e34d4fab68a3a30072d4772894745f7d621c87e643372da51c279e1538b80a", - "src/transformers/models/qwen3/modular_qwen3.py": "3687b8bf105a41fe81d8ba4825afc25dd5f1475bdf35b64ccd558dc829a8e1b6", - "src/transformers/models/qwen3_5/configuration_qwen3_5.py": "83c4c4000c1ca0af602ee729a392d99b31942d48d2727461592403172289c146", - "src/transformers/models/qwen3_5/modeling_qwen3_5.py": "97e6f3850aac8713d77304ad07fbe9c97ee8226b97639f877511eabf74c37993", - "src/transformers/models/qwen3_5/modular_qwen3_5.py": "56dcd3cbe72f8733c5b8870f0520e6cd2a88a2448cf1dee66943c7eadba6edb8", - "src/transformers/models/qwen3_5_moe/configuration_qwen3_5_moe.py": "f46a65317ab75563341d76cd464a6fd511cc286851bc13c6af9fcc7ae577bd4d", - "src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py": "83e8ea2c8e331ae5ce81abe5cfd13b82694eac0eb7a99e4c7be6f80d1961ecbe", - "src/transformers/models/qwen3_5_moe/modular_qwen3_5_moe.py": "9c13121107af598758ddbe05e226a907e30d23883a040b29ed86d485f87c873d", - "src/transformers/models/qwen3_moe/configuration_qwen3_moe.py": "d5cb43531bde34cde952e440055c1684dd9f71d92cfbfbff3ffe5a8a4f315261", - "src/transformers/models/qwen3_moe/modeling_qwen3_moe.py": "d2fcafdb7f784349d901345fd8f43cdae16215e8836197d085227615e209165c", - "src/transformers/models/qwen3_moe/modular_qwen3_moe.py": 
"43f488cf6c2dc6d74eea85beaf546556120db6a03150958eab50d14bcafa31c5", - "src/transformers/models/qwen3_next/configuration_qwen3_next.py": "7ec2aaf1964c5c563dae05e8caedad76db5b78b8cd7644bfb4bfb0919087d5ca", - "src/transformers/models/qwen3_next/modeling_qwen3_next.py": "12d66b39b73b040ee7cac784598d92de61b55ed9f424679579be1d881c6d14be", - "src/transformers/models/qwen3_next/modular_qwen3_next.py": "945db2183862039da809c85e0617d31b2d3dfb83dafe6ccedc6e94f12f479ddf", - "src/transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py": "6f28a3eb77c4879a061327e45bbbe9c18ebda3365b09ea7c0710239fbdd5e4cf", - "src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py": "9f8d8e5b25d17216e06e4f98bae534ab462b33c5c0d15777ed6a45f1516dac81", - "src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py": "ed35683349628e447cfbbf003cdb2d413a80d8733ecdfb6bf7715a4ced22e0f5", - "src/transformers/models/qwen3_vl/configuration_qwen3_vl.py": "bda6053535dfd5afb6dc5f861d3b0ae60a3f3bdfb0c5a1511baec1f99e38ba77", - "src/transformers/models/qwen3_vl/modeling_qwen3_vl.py": "c6b393ef254833a5a74bb0d3284d374e0b9bf3cc06bf0f9ba0ede6e0d34c108a", - "src/transformers/models/qwen3_vl/modular_qwen3_vl.py": "2d54c8312cf298eda35d97b91a2d1dd598de7f7fcb9de5e7cdf31fb36e45cf91", - "src/transformers/models/qwen3_vl_moe/configuration_qwen3_vl_moe.py": "c296ebf5b2f5c64dcf99f505e2a76d58513b788542cbf3840b41340bc85a2a7a", - "src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py": "463c8ef897fd324fd4d1acb8f3664b83f846125292e4e5f40c975677821465ed", - "src/transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py": "c8203f36a93399cad0a2d0e8fe1e3c2197e63626408e8df963c01e92acc78d64", - "src/transformers/models/rag/configuration_rag.py": "a730529a627ed83af4c439b672e6243a868a6650a8b8d0d26bea5fd824a3db89", - "src/transformers/models/rag/modeling_rag.py": "b98c3ad8895d610274f4f371959c0eafbfb27281b102d783af42fda3af9121c0", - "src/transformers/models/recurrent_gemma/configuration_recurrent_gemma.py": "f43e56390a9025a8e9cd976833f5a2d30a27c67adbf051e6978bb7688ca114dc", - "src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py": "171ae9d240d6aa199d11b03675d0cf236675140a01b7f117001d7b0ae12dc3f6", - "src/transformers/models/reformer/configuration_reformer.py": "b43964a765ff07bd99009106a3cc80d7d4333649dfe4fb1a6f87fcabea16acb5", - "src/transformers/models/reformer/modeling_reformer.py": "2b42c2c13a1025c61ab375960fe6ad277a9adb5febd89f71f7b1707616418b0e", - "src/transformers/models/regnet/configuration_regnet.py": "438bba88e2093027e77456f87369c9510ec905910776494d566f06a2cbd85032", - "src/transformers/models/regnet/modeling_regnet.py": "6bae13c5488e15f2ab600baca82409b683a5b13a7742bd1dd079cdbb8733a123", - "src/transformers/models/rembert/configuration_rembert.py": "a237a0c90c7e307f7405d778f0f2b2b1fb89d7cee9632a3249389daedd3cacf4", - "src/transformers/models/rembert/modeling_rembert.py": "346a04cbd7b81ee907680eed370d92a1b83a186816e426dbaaffd6ca829f3eb9", - "src/transformers/models/resnet/configuration_resnet.py": "833b96ca9efdd457ff5771c68d2f6de3f08ecb61bd8c4d169c6cc905ed5dcfb9", - "src/transformers/models/resnet/modeling_resnet.py": "78771bb35f85d3853e73097244ec11c3b1addc4a4ff251621cac9c36c12b0b94", - "src/transformers/models/roberta/configuration_roberta.py": "e579a667c856b6db5adafe8c9984baa14e587bc5da0dda15ea3a4c8a1e2bfb29", - "src/transformers/models/roberta/modeling_roberta.py": "a8f6586849f29b3b74dc0a37c732ef5b414ad4e23f6cc90f28b913ad2659a86d", - "src/transformers/models/roberta/modular_roberta.py": 
"919750780b4936c2ba5d1d411ee7b92240e2de7643718d7fdf59658b70128a56", - "src/transformers/models/roberta_prelayernorm/configuration_roberta_prelayernorm.py": "10606a5251c4fb27d2b2eaa127f78676601d30be6c2360e22db80823224419eb", - "src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py": "78a61603110029577db51e369badcd141a83194b816ef7d065096983deb01a52", - "src/transformers/models/roc_bert/configuration_roc_bert.py": "3f0b03a86dec3eb2a110c6a4378a3b87d9c909bcd8c89e521203d7ec5f503cda", - "src/transformers/models/roc_bert/modeling_roc_bert.py": "6a683a01a9490b5a0a0679cc105cf2bf8dbcc3c06933befab9c83fdcd9efe99f", - "src/transformers/models/roformer/configuration_roformer.py": "f74e96078410c8a25aa49f667a6e96116da7a46c4a981f464b2a05c7cb674cdb", - "src/transformers/models/roformer/modeling_roformer.py": "7a4f6b71e1064702272ed49901238ff07ce8207eda90f0e15b910550c0953a25", - "src/transformers/models/rt_detr/configuration_rt_detr.py": "b86a17451e5f75d96848ba694ea40491526d2ecc2e8a147688dc7b11a300b6b2", - "src/transformers/models/rt_detr/configuration_rt_detr_resnet.py": "55188876b979401c0ece58f3e5fb8f0c670f7af2acd5575d86750bdd569bf8b7", - "src/transformers/models/rt_detr/modeling_rt_detr.py": "82d696cd7776f1b76ea6b6239640c3f58e073c34df77b6479014d2f7b2c4c619", - "src/transformers/models/rt_detr/modeling_rt_detr_resnet.py": "cda38329575133e986d93788b84440f74a5d7982d57189be056fcfd87b5445c8", - "src/transformers/models/rt_detr/modular_rt_detr.py": "a2ed2bdf2925b480683b969634328bada856f30ab570f1551f4b10a5bea066c2", - "src/transformers/models/rt_detr_v2/configuration_rt_detr_v2.py": "175bf7f1afcd08eadde3102d28b00e64d6217cec59ad2388fc486eb7f90b4523", - "src/transformers/models/rt_detr_v2/modeling_rt_detr_v2.py": "7b2ccea0aaa432d654d30d79a727257e9be356e2868780882d891f978cf97281", - "src/transformers/models/rt_detr_v2/modular_rt_detr_v2.py": "c0ce70c8333e68a8945d1ed4ae81b355c5bcea17916d2126998deec4a5058704", - "src/transformers/models/rwkv/configuration_rwkv.py": "963e5e4595dfa25fcb9cfcfabdf1ac81c1402cdcd23b1b0e14c94f9dbc6c9ccd", - "src/transformers/models/rwkv/modeling_rwkv.py": "ffdce2de84d38fe4478be43aca0220e5ecb129c51155af72a632d50bb5087058", - "src/transformers/models/sam/configuration_sam.py": "de4813d26f76e04824ec7fc50491094d5bad507f71a6276cd4e2e2f423d5fcfa", - "src/transformers/models/sam/modeling_sam.py": "96bd35c10a1d5b66530e745b45129aef67f954078b6ccc43fe54d0cbc06d1920", - "src/transformers/models/sam2/configuration_sam2.py": "b042b6ac1ce51e87ce4dd2f1b62e152cb2595f0c5d88f8d9deb8fcc274a1ca93", - "src/transformers/models/sam2/modeling_sam2.py": "6254a79f9f06bf3eaa5f382ad7499fd3978710b2a73750a740c5e937139add9d", - "src/transformers/models/sam2/modular_sam2.py": "b57ecc1a0bf268f0e0b30e876502860db3b1f854a012b8d8ecaad15f47f54af9", - "src/transformers/models/sam2_video/configuration_sam2_video.py": "abdf32de02ce10a3ff1de2fbe850b9e7dd5bf6f8003b3dbff7500e66485083e0", - "src/transformers/models/sam2_video/modeling_sam2_video.py": "22a2d8c4d99490f4b193b918c5ce90d427769d328d180b42e8cffe79532814f9", - "src/transformers/models/sam2_video/modular_sam2_video.py": "95d9e8875165e0d4fe68a520c6d4e93b9bfbeb7f3a4bca01bdd3777b302e45d2", - "src/transformers/models/sam3/configuration_sam3.py": "23bee427162c2c33833150320f65a96d42006f20feb3791e52d4c8536323fd4d", - "src/transformers/models/sam3/modeling_sam3.py": "d450474b0c9747987cf7f723d3abb728c5967fbe80b793e48fe6178a4fd4cf03", - "src/transformers/models/sam3/modular_sam3.py": "58d1654c4c7a6089ee882f83e8d285e79a151ba51a93fdf45272f3327fbcac0a", - 
"src/transformers/models/sam3_tracker/configuration_sam3_tracker.py": "a82deb999c8374dc96c34139c5175b6c71dcc7ff2c20a6c084154a4402e1fbcc", - "src/transformers/models/sam3_tracker/modeling_sam3_tracker.py": "0c799adf42bb9c54b36c5eb99442c180981d328af3555b4b526497d097a4333a", - "src/transformers/models/sam3_tracker/modular_sam3_tracker.py": "0b039a6b18b87b9eb347a5396ac1a7f1c02b109cb4d184d579b6866d166caa1d", - "src/transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py": "cd2241e9f42767bf711d6e5a1e2cb315fffb0c0becd6cfefd26ab09fb02914ce", - "src/transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py": "5afc119915836d4369dcf53e3e074555b488beb0333dadd9b3206dd917942c03", - "src/transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py": "8dc65910cb8d09ccf9818d46a63e490303867dad11553c6388b9402159037ca9", - "src/transformers/models/sam3_video/configuration_sam3_video.py": "17b7a40f002bee963677665ed7c23b93590ea7af17111753db9a724ae3253f15", - "src/transformers/models/sam3_video/modeling_sam3_video.py": "c573ceb732dea819db1dcff8f242474159bfda7b2484eed77110396adfd10ed3", - "src/transformers/models/sam_hq/configuration_sam_hq.py": "cda2c3b8a62fb757e01c1a555cec49028a9a45db685c76cd4779ea0154e76940", - "src/transformers/models/sam_hq/modeling_sam_hq.py": "75fc1413ab47473b9502179f7dd281efa4b7ef8ec4e9847daa3522844f1cd784", - "src/transformers/models/sam_hq/modular_sam_hq.py": "55f90de4ba901867f96e593525648d1aa5f383e08224bce8662faaec23b214f4", - "src/transformers/models/seamless_m4t/configuration_seamless_m4t.py": "b6166c4f198ed9b7eebad18d3b1024ee8535d1620980ff00fbda43f568a4eb63", - "src/transformers/models/seamless_m4t/modeling_seamless_m4t.py": "b40a4939c4c2ed733fa7e6727641ee300eb935326024b9d672ee314eeb206641", - "src/transformers/models/seamless_m4t_v2/configuration_seamless_m4t_v2.py": "db2835bc60fa9e9fa94a8c2f0b2909dd01603adc4ef8bdf9f14c19b299b614b7", - "src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py": "cb83cac78a10a30313e32c6703f6c2e1bc13050c6951f7ed8c5e205f54729504", - "src/transformers/models/seed_oss/configuration_seed_oss.py": "0a908f8903f79fa3d7bf2857607133d1b86f44d8fd736faaf82fc8270487c932", - "src/transformers/models/seed_oss/modeling_seed_oss.py": "b48bd86899465b19663360ddcbfef1aa7ab05ed73ca3a82cba71c6693674fea3", - "src/transformers/models/seed_oss/modular_seed_oss.py": "5a8f3b555035b7ea87d70d7618c29269901866266acfd83818abfcb5c2e24c44", - "src/transformers/models/segformer/configuration_segformer.py": "7c9ef3137748d7d1afa3297d5cccd51e573888ed9aae67d218f07e7544161bf0", - "src/transformers/models/segformer/modeling_segformer.py": "50b5409817ab3152c8c5b7ef11db97647f8b5491ecd44ae510b91445e5ce7485", - "src/transformers/models/segformer/modular_segformer.py": "1af15e94abfe23da08252b783bdee98501ecaaafdc1ac3d4d58f29ed300fe54f", - "src/transformers/models/seggpt/configuration_seggpt.py": "f153230dea96743f50ac4d4ec5d1ea409b7840d6a32ebbfc21c855d70bf577a1", - "src/transformers/models/seggpt/modeling_seggpt.py": "cbe63676fd866d880d943f3d0df79fdb4527851e37be29482019c82a32d6d77f", - "src/transformers/models/sew/configuration_sew.py": "f7facf01cfb83060307adef7df69592eef44e04ffc6b26a50ac9aceec5819092", - "src/transformers/models/sew/modeling_sew.py": "883b8146429e7aa62a8155bffac12c1b16cdf0f99b6f75ecd1a62d3e843d0944", - "src/transformers/models/sew/modular_sew.py": "386f03666184106d140560663edf18f45ee58db68563b9e6d0d2da5abe0f87e4", - "src/transformers/models/sew_d/configuration_sew_d.py": 
"6c5696296ca4249646524c434ed6b746bc4bead5bd5210a067c8b1e0359b59ce", - "src/transformers/models/sew_d/modeling_sew_d.py": "600f58b34fc0a25c9ffffc17d7833a8095debd3b607a5511ba10ebc4c6ba23a8", - "src/transformers/models/shieldgemma2/configuration_shieldgemma2.py": "5973607a32482e769a38e675bf705ea7c7892a570ffa9625a055529f32513e90", - "src/transformers/models/shieldgemma2/modeling_shieldgemma2.py": "0dee8b7f44496b9e9454c77f191beab71d830ac126c3b9cfadcc55dbe9343b61", - "src/transformers/models/siglip/configuration_siglip.py": "3bff9508fcf4c2ba1a8379a168fe9874e1a93402b211b4f456165a6e53c2c666", - "src/transformers/models/siglip/modeling_siglip.py": "a541f04b950ca33106c52f34f5303adc2882dacd27b77284f61885324ed4ea10", - "src/transformers/models/siglip2/configuration_siglip2.py": "c924e37822921fe7a51d85fc8ba8668fa04d4815cd3e76d53a04b1a3c42b548a", - "src/transformers/models/siglip2/modeling_siglip2.py": "b7ba2b0af87cb31f02ab25cd010009abbd4efd8e4a90dee467c81b35d2156c1b", - "src/transformers/models/siglip2/modular_siglip2.py": "472fae36bb4e62c2ec9e82d0ca39cc7a6622f39f55d52dff757d5bb1b28845da", - "src/transformers/models/slanext/configuration_slanext.py": "87486fa6a600498d14b37b424db7b9967dda52a3e12aab4f1c506bb26393ae9d", - "src/transformers/models/slanext/modeling_slanext.py": "00ebcf239d8e8c82bd8fbdd73035912ecf41bc914db95dd049dec0e1974ec8e7", - "src/transformers/models/slanext/modular_slanext.py": "66528eda541815e67d196bcd5ed2b21129d1371f281e36009f1e9a6e6e368049", - "src/transformers/models/smollm3/configuration_smollm3.py": "2b17fa65fe25ee851963e8530ff36880bf4e29b315706f68856d8fb4c90806e6", - "src/transformers/models/smollm3/modeling_smollm3.py": "451eea89bf67518f9fd6589dfc771ade3adc0572d5a9454f8cedd8e190e04549", - "src/transformers/models/smollm3/modular_smollm3.py": "ced515a3ddbf6a898be9dba866a86aa49ae642b1afead61f2231a3dba09b45af", - "src/transformers/models/smolvlm/configuration_smolvlm.py": "ff218dcef1f1fc3075957b94167d39578a51081fb083a74f47b3fd6b947a87cb", - "src/transformers/models/smolvlm/modeling_smolvlm.py": "de36dee70698809f275aeb7173cf94812b248f1915f6139eb9dd38ba7eddf424", - "src/transformers/models/smolvlm/modular_smolvlm.py": "6d8029d85d12e630a37c0222338cc069745a698ddad7ab4172b1b66cb3cc8d20", - "src/transformers/models/solar_open/configuration_solar_open.py": "ce5ac3768b65df6e89528a903943a7ddae3b77b33e86e2895f841b89d537bb45", - "src/transformers/models/solar_open/modeling_solar_open.py": "0f8ece99c58443212500b47b8402318fcf7923d1cf57dc706023e81d718d0c29", - "src/transformers/models/solar_open/modular_solar_open.py": "7ab97d3960f83c74a3f8b61d78fbe7e176253041cb7edb886e490fd58f6f15a3", - "src/transformers/models/speech_encoder_decoder/configuration_speech_encoder_decoder.py": "d5ae14b909f586af37a1a7b72b4164550bbbc6c28732706c27571c800eb9d001", - "src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py": "ab74eb94be4bfdf72f4ff2c3eb5127fc21269619ddb322a34ec3112b2f18d9de", - "src/transformers/models/speech_to_text/configuration_speech_to_text.py": "531eb38b46bd8e2f31807aae7be4af1cf09ab9f9271e033f31adab6960552718", - "src/transformers/models/speech_to_text/modeling_speech_to_text.py": "044b29b22f5c6bacc133eaaf343306eb3a44373f2071c77a32b2ac7554e4d516", - "src/transformers/models/speecht5/configuration_speecht5.py": "203e0b0951f1948e7d918ab157749effa7b2675d03883fe5c90f69bf9b322e98", - "src/transformers/models/speecht5/modeling_speecht5.py": "d61e0877206480e4ed436072956927f2c4beaf3b036adfca577f68885562659c", - "src/transformers/models/splinter/configuration_splinter.py": 
"86b5b1a97377773cb22bfd16e8fcc1ad7b5e70ee888781fffc9420a7f5e78020", - "src/transformers/models/splinter/modeling_splinter.py": "bb84c61c46388ab1784edc8767aa06117efd6b1c2a6847b6f36a69f6d9b6292c", - "src/transformers/models/squeezebert/configuration_squeezebert.py": "2b177257c3341d8a98c545fd37559229b6690db823649f37cdd9fb919ccc14d8", - "src/transformers/models/squeezebert/modeling_squeezebert.py": "a6cb8f6ffac27cea1a56a91679958fd76b118d3cc5c52572dc2bb499a386e649", - "src/transformers/models/stablelm/configuration_stablelm.py": "6c9dff2cc1ad0880471223039f6bead226b1e9f75e7d9c65dc95eb717910ca9f", - "src/transformers/models/stablelm/modeling_stablelm.py": "020c276ebb0d8912fdeb4d66509c2b25526c9839c582158ccf29057368bfd8da", - "src/transformers/models/starcoder2/configuration_starcoder2.py": "df2b41071c6029b2cb4a04df67aac703d97ac7833e1b96bd099e9a87e2a515f5", - "src/transformers/models/starcoder2/modeling_starcoder2.py": "ac13a5196a749b9019590950f5250211f026850c46184e607774dc76d9a2b24b", - "src/transformers/models/starcoder2/modular_starcoder2.py": "45ce614a43276fb62ebd5f2a55aab6f8a978d446fecc935ad9d4ec7450a7ae7b", - "src/transformers/models/superglue/configuration_superglue.py": "7d3a99fd9dd299bd10912700bb798d09aa552406eb4cf1d12fa082ceda202ff3", - "src/transformers/models/superglue/modeling_superglue.py": "b39a02dab86a28513aa3c5c2bca8c10f1ce2555d37d9a0a0bf7168c63e8647b8", - "src/transformers/models/superpoint/configuration_superpoint.py": "c6ab918b4eb62ad04b2308cc753f4fcf28e13dd91803c0620458a610fdf1e155", - "src/transformers/models/superpoint/modeling_superpoint.py": "1a4dcfcd24292ddb49d0f829071296182c119ef93003b606ecd6eaeccda3752e", - "src/transformers/models/swiftformer/configuration_swiftformer.py": "30b4493452f1076f009043bf284dfba697887f929018e8f31256398a18f51f2f", - "src/transformers/models/swiftformer/modeling_swiftformer.py": "b908e5359b1de228046dd369a553bec6e96ab89376cf7cab00d757d03c9af4a9", - "src/transformers/models/swin/configuration_swin.py": "f644315ba345e8a2ddc1c3bb094f612a3250e8b87d26c1f737238c623fe16378", - "src/transformers/models/swin/modeling_swin.py": "7f1a074e152692bb781439dde56a0cdd77afa2d8292192168b51348051d0ce3c", - "src/transformers/models/swin2sr/configuration_swin2sr.py": "0da941a458fe11d3d558335ca75493adeac53bf2b14564b627696e1e2306e0f3", - "src/transformers/models/swin2sr/modeling_swin2sr.py": "8f9d31c11511119759ca72a1bfde5aaa1d8c19b8a2b3f3890a9aeafccc9ee193", - "src/transformers/models/swinv2/configuration_swinv2.py": "85ff6596e90a2fb2c4eacbecf949dcaef95c0003fc41e728472d48d14a74226b", - "src/transformers/models/swinv2/modeling_swinv2.py": "aef42ed7702fec6ae5ca86f766444a5e9493fa41c4181a3ec868e1ab7fe27256", - "src/transformers/models/switch_transformers/configuration_switch_transformers.py": "e08699e3ea2efe898a5b55c91baa6d23a20509020491c88dc3e491f04e284e00", - "src/transformers/models/switch_transformers/modeling_switch_transformers.py": "e6e26d6f031962ae376df836ca1c1a978f22768b6c8daed1be55a32fb3bba21d", - "src/transformers/models/switch_transformers/modular_switch_transformers.py": "4ab72c726d4a7e1eb762c194c78ce935c3daa62f8510f86166d64c0ea4a172ef", - "src/transformers/models/t5/configuration_t5.py": "258f1fb36f3116f2cb8711eddd68076956ae9455e7ff78ca5b1de695f6d89889", - "src/transformers/models/t5/modeling_t5.py": "4c0373817c224fd38a60337829cd10e31be0cc43aacb958134618a394b675e61", - "src/transformers/models/t5gemma/configuration_t5gemma.py": "6aff9e558005d4696849f4f66f21521e622d03bfc474360c7ccb0df9eeb874cf", - "src/transformers/models/t5gemma/modeling_t5gemma.py": 
"1830bc343bbf842f42d955e8ee7dfebe2ae4468a680de73efd3a2b17f122c966", - "src/transformers/models/t5gemma/modular_t5gemma.py": "b44b54873a282bcd159c5552d3dbd5a0163540315f5e378235b07658e24d648f", - "src/transformers/models/t5gemma2/configuration_t5gemma2.py": "a66102c62f5f7bf0e51c29e173f41b9db6a6b695f5dceb823dc6cb1ca6751a2d", - "src/transformers/models/t5gemma2/modeling_t5gemma2.py": "7865c70886cc526235b2915cb0814962b0eb27df40b15d24ed3435f46c22b6c5", - "src/transformers/models/t5gemma2/modular_t5gemma2.py": "5feee8c51df3513aa3f75f0c7d6c9404f5cb50ba03ca5dd8f4fc92c26cf250ff", - "src/transformers/models/table_transformer/configuration_table_transformer.py": "ee2f3cfa80d7cb195a2392f22ceed3e969b87268ba73266eb18d393e911a4e13", - "src/transformers/models/table_transformer/modeling_table_transformer.py": "eb6be8eb35da0216ac918b714791a11bd8499b4eb301e05529733bdb531a11d3", - "src/transformers/models/tapas/configuration_tapas.py": "a2b106513778866670262009862e55c263b2072e1c4bc16db75761493cdbe880", - "src/transformers/models/tapas/modeling_tapas.py": "24f9efcf6eaf9c1f26c40f0fdc185b4c964adbbf6d1b1d150253b7b6f6f7f29a", - "src/transformers/models/textnet/configuration_textnet.py": "059ba9112d3a8bfc909ffd1ad8f71d6564e43636e01c5e8c5170136605cbff19", - "src/transformers/models/textnet/modeling_textnet.py": "33a6880b52c6155694659d000433eb01c2e495001b0cf02037fd1d0581c11591", - "src/transformers/models/time_series_transformer/configuration_time_series_transformer.py": "7d0fa2ba71464b3b76645cbffc061f8e0805cafe22c81d652ef1963e204709a1", - "src/transformers/models/time_series_transformer/modeling_time_series_transformer.py": "1678789e34a1aeaca9b3eb7c8289b3b0e5c586355c32b1742910acb441da5d18", - "src/transformers/models/timesfm/configuration_timesfm.py": "2abf26d41f2c8be3ca7fe814a65291b925afcd2983b1bb11f2c726dbd5325a62", - "src/transformers/models/timesfm/modeling_timesfm.py": "867ec94db7ff94558a6cf892c9c7909729c686f30a010a04b9fc328f7068247a", - "src/transformers/models/timesfm/modular_timesfm.py": "2f89a6f2b18a615db38e1a95af8db7652ac7590be6e75ee4def2e19ff2c1fd8a", - "src/transformers/models/timesfm2_5/configuration_timesfm2_5.py": "12dfb4fc3a4c51930f9e8eaa9bed7fc4c8c22ce46635d0eee4525cabc24c2068", - "src/transformers/models/timesfm2_5/modeling_timesfm2_5.py": "db9ac2128206101e0e925eefbb054652745d2bb4256b25f4b2bae2f70d8bee19", - "src/transformers/models/timesfm2_5/modular_timesfm2_5.py": "e007a0eeb98cfbf2f31a6439eab64e364af8518e87ce14e6e67b22d5c0c1719a", - "src/transformers/models/timesformer/configuration_timesformer.py": "7e95903fcde64369ff098dc5884b782dc71e48d118518feefb129e1eb3a4bc5c", - "src/transformers/models/timesformer/modeling_timesformer.py": "c4df75c7a4231c3f30be58ce69dcdbaa70ef657e610332f44d4df664b8c59839", - "src/transformers/models/timm_backbone/configuration_timm_backbone.py": "c24ef6bbc33e5f1019badf9e6acd52806789ced26b57c7438c30477ba9791221", - "src/transformers/models/timm_backbone/modeling_timm_backbone.py": "2b1e156de437cd34efcbc6792f3fcb4ac9d8504e8a6fcd3b4cc5e1e42074d3f1", - "src/transformers/models/timm_wrapper/configuration_timm_wrapper.py": "d454a21d942d854f509fc48c805a84d2c7c73f52529bc90a94b0db893f5b9d8a", - "src/transformers/models/timm_wrapper/modeling_timm_wrapper.py": "7e75b6441add311d7fd017843cd0695706ce36831cf6dd60202f670ade87769a", - "src/transformers/models/trocr/configuration_trocr.py": "671de0e15378200e18534270167dbce96c2f0d991145a8a317ac0851d8c192dd", - "src/transformers/models/trocr/modeling_trocr.py": "a229d1e2ee8882e436eb4d1f85f4558b1752e092c37c51d004d83f24bb00c380", - 
"src/transformers/models/tvp/configuration_tvp.py": "7df7b40ce53e32d24db7f6ac6fb22b6883f6889ad7c48c6bd7df8e6c89dd67b9", - "src/transformers/models/tvp/modeling_tvp.py": "fb02e73fad52996f171a2e21c5bcfa72557d17e5bc5b13b17a7ebe02a44f688b", - "src/transformers/models/udop/configuration_udop.py": "3daa483ad3f04060b7594f4f253f07858324b09b226321801d0d98cd09931811", - "src/transformers/models/udop/modeling_udop.py": "7c3380847038a220cc7cf254ae9791e40630c2b00c5d3c0352212806204d9dde", - "src/transformers/models/umt5/configuration_umt5.py": "d3fa5f25a6c79c5937d706887f5c75638566ee4b29c76f58e904b218b0df558b", - "src/transformers/models/umt5/modeling_umt5.py": "659a167590580f7462b78e4eea17967ae22111c956dae9d6073ba0b9b2fd53f3", - "src/transformers/models/unispeech/configuration_unispeech.py": "fc6690fbe0c6a3bea6a274e2f8c5e64a34fc4bf900bcbeb6b6660af8f37a48a6", - "src/transformers/models/unispeech/modeling_unispeech.py": "243489cac13bcb2b311b4aa5b7bdf757ca68bb0b6a9f14a19fd07fdc3ae996ca", - "src/transformers/models/unispeech/modular_unispeech.py": "a6724ad3c95e2902d3f35efe618a9c3f22b70f33d584d8471fc8237662f1debd", - "src/transformers/models/unispeech_sat/configuration_unispeech_sat.py": "f65db34f18861f13c083d6016256c0a3dd0da817504ff6b433d2292c09005a55", - "src/transformers/models/unispeech_sat/modeling_unispeech_sat.py": "80499bea64eab21a1c70d4f5f3444f1850bee15d7aa2c88e0a1d08ad8092c5d4", - "src/transformers/models/unispeech_sat/modular_unispeech_sat.py": "b2185429705c5096cf88d3b2670839a43418a90697aaaef93638cdfabf0a7304", - "src/transformers/models/univnet/configuration_univnet.py": "d1f4f6a3d3b8fc03c961db1bc0312c06d7016af445c51fe0feeb2434fbdc4457", - "src/transformers/models/univnet/modeling_univnet.py": "d9242eaf486d85b4df109f23bca8b0d105b8513be1fb3eb5f2147ed5e0fe28eb", - "src/transformers/models/upernet/configuration_upernet.py": "9dc1c0f1b141826c363a47b38a1d046af86788c3b2fd4d321256ea3501f1d89c", - "src/transformers/models/upernet/modeling_upernet.py": "0b84a17b6e5d141a1b8939e0cdfaac10fcc3fb227ec8876ef6f17db36ec31b01", - "src/transformers/models/uvdoc/configuration_uvdoc.py": "5e220d0899ffdbd8e20f9177524d21d78d189ab90f25f0bceb846bc1a1789531", - "src/transformers/models/uvdoc/modeling_uvdoc.py": "3c68b2f366e3edbc0ac0cc9250e2f4e794ba0f57a40d3ad47e77a25503805792", - "src/transformers/models/uvdoc/modular_uvdoc.py": "d0d619637b45f14892312e6c9856438ce3be718a533381affafc524b0a156ab2", - "src/transformers/models/vaultgemma/configuration_vaultgemma.py": "a1b2e84ab96dc80be8bbbdc511cc95df256d1c26872db3d77ef229a658505a8f", - "src/transformers/models/vaultgemma/modeling_vaultgemma.py": "333d95f12385509b301cefbbb6a6b2c6600c1d749a67c0ff2f0a3c895f2e1a06", - "src/transformers/models/vaultgemma/modular_vaultgemma.py": "1d01e307790ef3d93ba5b72f0ad8cab08b5d7ee91f61be445e700898884a59b2", - "src/transformers/models/vibevoice_acoustic_tokenizer/configuration_vibevoice_acoustic_tokenizer.py": "40079dbafe11652dacb56c028c7e7aaec698f12c863900e0a16f1ef559a6be75", - "src/transformers/models/vibevoice_acoustic_tokenizer/modeling_vibevoice_acoustic_tokenizer.py": "5d69a3c4a17093a3fa18161f49908c445f76f5c22ced1b644686071a295baa17", - "src/transformers/models/vibevoice_acoustic_tokenizer/modular_vibevoice_acoustic_tokenizer.py": "ad785b8322d2957b00a29073ac435a15fa94150499e18ad7ae6ded9a4f06863e", - "src/transformers/models/vibevoice_asr/configuration_vibevoice_asr.py": "979315dcc48aa47921ff6e885af86be6ee65ab2daeaf9f1de41dc44b410efcad", - "src/transformers/models/vibevoice_asr/modeling_vibevoice_asr.py": 
"d5b8cf85c5c6fa07b515c45b95ea9c58760707190912b04760d9eceab66dc9a2", - "src/transformers/models/vibevoice_asr/modular_vibevoice_asr.py": "935029559355546c9dd5064a9b3d3e27580f21047ed3aa1d3d5c7371a6930e84", - "src/transformers/models/video_llama_3/configuration_video_llama_3.py": "53f73bb2c36456c760662d56d8c022478b193dba6e1e68494bb2fad2db0b3a12", - "src/transformers/models/video_llama_3/modeling_video_llama_3.py": "1cfbb936132865b1baf843498d9899667513d305f12819a7a904b989e0a89d50", - "src/transformers/models/video_llama_3/modular_video_llama_3.py": "22d197d63a231e20dd6b397de29f4f901a0b1076cbc57e8085f4f0eb3377a321", - "src/transformers/models/video_llava/configuration_video_llava.py": "4b106d56292d32b5bf7a70ee7913feca24f03519ea44709d57ad310f47e0889a", - "src/transformers/models/video_llava/modeling_video_llava.py": "0c267dc405cdb41e5c8e0a97b21f4f4c70db6eb768fa911fb115c86a6ef6857b", - "src/transformers/models/videomae/configuration_videomae.py": "1adb23d83df6bba5340d2bbf8c0163b982ab3bbe07651ec02245e9f38739a02d", - "src/transformers/models/videomae/modeling_videomae.py": "8f187dca972d3db748149ee5a6ef4da8b9f182f600fe2dea5a2dee1a47aa908c", - "src/transformers/models/videomt/configuration_videomt.py": "149df79fceec422db1ac0ae40a7c94c3c3aa12e96607246098aa2cb5a0c920ad", - "src/transformers/models/videomt/modeling_videomt.py": "3a5a66df21962e758b6d8f3e1c2c20a5c2ee5828d0bdc5d409b2f5593ca44146", - "src/transformers/models/videomt/modular_videomt.py": "3672dd7e37e1a524ab877fb528c17f03d5a209f9700701b2cc2321530ee80399", - "src/transformers/models/vilt/configuration_vilt.py": "59a15b92ee78ab5fc85a88717cf05585fda9569f9da7fa5c11eedbe6bf39a66d", - "src/transformers/models/vilt/modeling_vilt.py": "5a5fd75f22d9182c2673de4f93468f6437cbcc0e342b4a6eeb0ae54f4d2e88b1", - "src/transformers/models/vipllava/configuration_vipllava.py": "91384e23e7c8ce045fe2d96c1cf651c3bdd8393f632c6afbfe8e4bbb576ee026", - "src/transformers/models/vipllava/modeling_vipllava.py": "f2906e27064ff719d3ede8fc8176ee2ca1bfddece5068772a06ad60ba51ba8a9", - "src/transformers/models/vipllava/modular_vipllava.py": "dd364d79983d80bddae2507fa0b2b81df41e1a7f0db1593a259d24de44aeed1a", - "src/transformers/models/vision_encoder_decoder/configuration_vision_encoder_decoder.py": "0a81582a774155b504094aae6e6a44579a1f56a2703edb9bf82e7f3b3ec64dd1", - "src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py": "1d71b5e756458ab5c261701416e14dea7bb9602cb2e1dbba3b6571878df49285", - "src/transformers/models/vision_text_dual_encoder/configuration_vision_text_dual_encoder.py": "a1905873800d646ac3579c6122d81a8632fa58e1f933987a653ee3f113f380fd", - "src/transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py": "c312c21577612a61b88eacc8f130428ec29cdf5a4475817ded3b5aae1092951e", - "src/transformers/models/visual_bert/configuration_visual_bert.py": "4091adcb8c45a5e25293ee7d9df08ab66d3818f92895d144646d372b320a4fec", - "src/transformers/models/visual_bert/modeling_visual_bert.py": "bb0408273c514ba56bf5d3a1246d68972899f2ba77dd7fba2758d7ae7f9e8eb9", - "src/transformers/models/vit/configuration_vit.py": "f15ab198fe308b52fc539ecf3e055de774719dc00c86c833b4aa51d23987179c", - "src/transformers/models/vit/modeling_vit.py": "900b7703dd256170cbb0a8c59a2ef3f54c93768e8bc1b5d6043c957c6771176d", - "src/transformers/models/vit_mae/configuration_vit_mae.py": "5966821656939a20443efcec6c16415be865fee7d12e4423ebb2d85bc8c16953", - "src/transformers/models/vit_mae/modeling_vit_mae.py": 
"ecec97ec8440d4109d13e2161cad6dd76ebc87d96855627217ee4fbdb5929f38", - "src/transformers/models/vit_msn/configuration_vit_msn.py": "445d7ccabcc998562fead5adb02fc8cae615f4506d7188d98d7c32d5952f72fa", - "src/transformers/models/vit_msn/modeling_vit_msn.py": "2bd91b70265ecd0a6ed4dafa8ad90b3e8bd07aa0ffa2c869dc3221067fb586e1", - "src/transformers/models/vitdet/configuration_vitdet.py": "99dfc5960e5f8d3d2f9e4f087b629291260ce622e518735691a3a044e5d0c460", - "src/transformers/models/vitdet/modeling_vitdet.py": "359ba606a369817b4c23449e5f4cc4dea27f565ee72dd925f481ec2655e63648", - "src/transformers/models/vitmatte/configuration_vitmatte.py": "a635ce922f5052037ae7c25ef8292805b6e775a7173607ad633e15f670caeeaf", - "src/transformers/models/vitmatte/modeling_vitmatte.py": "5191f4f733b30c082af4f9285a302449ca12d322967eabe8f1b667568e2279cb", - "src/transformers/models/vitpose/configuration_vitpose.py": "1933e851afacf6640480d623dfa6db57b2028f97c480c0a73a00d26c54229629", - "src/transformers/models/vitpose/modeling_vitpose.py": "4c43b51df9ea97e3cc0f6e8ac1d7a71e2a4b04e8c5b38f413edebcd81a81675c", - "src/transformers/models/vitpose_backbone/configuration_vitpose_backbone.py": "42a054ed8ee6e4376dc2db96bede782f3c8addea99e11967b5ae325d5b423270", - "src/transformers/models/vitpose_backbone/modeling_vitpose_backbone.py": "c78c5dd6fff44ad7b3c912094344d4c25bf436460037c7cc38fd9149a52e4f54", - "src/transformers/models/vits/configuration_vits.py": "85b5c87a878a9a9c80f91cd5034a6a505ef52aaf5a405b2d0bbbde4a92894bc4", - "src/transformers/models/vits/modeling_vits.py": "58cd039119dc3586a70cecf7e642233681ccee047ba240bd6a537c75b34b5cee", - "src/transformers/models/vivit/configuration_vivit.py": "767519befa6bc580f71bb222a1e8c87772c97252a0c345bbe0e86cc122e62f36", - "src/transformers/models/vivit/modeling_vivit.py": "dcee9e2418de62a4566706267e5e10fe17826e52f6dff01da74508b90426d1eb", - "src/transformers/models/vjepa2/configuration_vjepa2.py": "6dbbcb9e885488a99498d1db0a48950cb62139d0cad0d4c7cd09d28e949a6b7d", - "src/transformers/models/vjepa2/modeling_vjepa2.py": "11ec392039a0c290b820956948b35bdd6872878119ce95d7e8daef97578fa098", - "src/transformers/models/voxtral/configuration_voxtral.py": "beefbe93584882eef99d9c7972127108c70fc7d048cd2d88657aedb5a9614995", - "src/transformers/models/voxtral/modeling_voxtral.py": "d3fa5ee2975383473797d79e6f3ad467d3c8d9cd9a5c825de295bce7ac3fff76", - "src/transformers/models/voxtral/modular_voxtral.py": "a124af843b2fe87f44ce0dfe3cc671010b316ed84150b0109b902dff7864d124", - "src/transformers/models/voxtral_realtime/configuration_voxtral_realtime.py": "191b686759d8ad262648c00896899d993d828a5d0833d9bae83bcebeaee22d89", - "src/transformers/models/voxtral_realtime/modeling_voxtral_realtime.py": "9e1f26cc719bca3d76068ea6ab7a9f83763f8b0c38b2bc0ac1191602e314c497", - "src/transformers/models/voxtral_realtime/modular_voxtral_realtime.py": "d6c69d7d6af9e774ea92485faf3cba4813d42a9471f01a38096c8d722e6ba224", - "src/transformers/models/wav2vec2/configuration_wav2vec2.py": "8650bebeea0e71263c64bb82d5c2a52d450247d554b6da18fba8a8fff11efd93", - "src/transformers/models/wav2vec2/modeling_wav2vec2.py": "053cd02abe8cbb63a2f23d2eb486bd33723dab884b1718aa6d0923aa215d6424", - "src/transformers/models/wav2vec2_bert/configuration_wav2vec2_bert.py": "bb037d4710bf7117c8cd9e8aabef59143ad39a9de315d41b4f13d64a76453ca6", - "src/transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py": "d89cb5a2532f9986cbdf96663987b8c71b79cefb6501a701189dbb657a33c48d", - "src/transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py": 
"651a7c381d07e473f8dab7e0598a306650553230868770bfcfa5221fd2be2aa1", - "src/transformers/models/wav2vec2_conformer/configuration_wav2vec2_conformer.py": "2e138af1532aca4c5a1eb922203933309f1f4c12d7ee5007be36d68d593f37db", - "src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py": "1baeb85979dc5237962b8d2689033647c5e56ad76d8a27603edfd615175e1c26", - "src/transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py": "d8d459a537db62a26ad90dcdfad25160f17ebde81269e21db64d310c89fb66a1", - "src/transformers/models/wavlm/configuration_wavlm.py": "6c07717bea00a8c6b622b9cd97123c19d3e0744fa76ab3791676ae9a97d0324d", - "src/transformers/models/wavlm/modeling_wavlm.py": "43cd2c54edadfd5f697f68ad64c462d45707ccfabd9d15811213bd5e409340a3", - "src/transformers/models/wavlm/modular_wavlm.py": "146c101d45ef4442ce295f6ca1e0e2cbaf031038f918e7c54b379eef23c10b35", - "src/transformers/models/whisper/configuration_whisper.py": "91a5cc9d8e2284490628c48f5470d1b89a233b16ab219ad9d300b3ffc978de7c", - "src/transformers/models/whisper/modeling_whisper.py": "0d711e4623341a3da969d5fe90841a9ee28a4ac464ca42a7c8ea8f3e1d1d54da", - "src/transformers/models/x_clip/configuration_x_clip.py": "d39e1d1d73090f322369c5679d18f65ac0b897e76d66620fcf66390e3dffe346", - "src/transformers/models/x_clip/modeling_x_clip.py": "3cea422f2a284b6135e7e0efa09a5c10509bb142861c59bd66ad37ef8f241b4b", - "src/transformers/models/xcodec/configuration_xcodec.py": "d2cd3b2c86368b597476825ad2eac62fe8f0409a2c8b1f12cef97784ca733500", - "src/transformers/models/xcodec/modeling_xcodec.py": "03be2a880f1429d8722f700cdd5479f2da555ee4d241be774a122779c71b734c", - "src/transformers/models/xglm/configuration_xglm.py": "2dc8cea98578cb05cbeaddf4ca1e016860f7365c785d1c1a1af6f3f3eb3fa9d4", - "src/transformers/models/xglm/modeling_xglm.py": "7119c77597c7966720bf8c04228c41755a3e225a7b68f74e5d15f02964e9a023", - "src/transformers/models/xlm/configuration_xlm.py": "f0b5d2b6b9669d845540b2b9df5a2b7951354796a47d1bc985f14b23d631daa2", - "src/transformers/models/xlm/modeling_xlm.py": "baec462f1e4308c31b084ade8618fb2041a6a760a6090a4bd57ceaf4a8ef7dee", - "src/transformers/models/xlm_roberta/configuration_xlm_roberta.py": "7569a207cea00ea0eb1e22bf725e071d409f705c512ff1e33cc30a798665a193", - "src/transformers/models/xlm_roberta/modeling_xlm_roberta.py": "c8394866e7f785a6c2652c8341ff0d649dbe76ca619bf4078f66e2ec689b142b", - "src/transformers/models/xlm_roberta/modular_xlm_roberta.py": "8c8469c3867eb85c2474eec51406d32187b1f8169c6517d77b79cdd8ac41dde0", - "src/transformers/models/xlm_roberta_xl/configuration_xlm_roberta_xl.py": "fff0d189298232584d020e3b0b12184855bd49c3c2821c7385d29ae564be0664", - "src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py": "d9e4579799bfce9a55ea16ca5915033a1e463a6e1f2aa37d6fb852b780e1b7c3", - "src/transformers/models/xlm_roberta_xl/modular_xlm_roberta_xl.py": "0db7040a77c0db9c71adb37374a6fe11655698bdd511c744327a5f90584b4f66", - "src/transformers/models/xlnet/configuration_xlnet.py": "b23ca312b54597df3c9515bd14e2bb8b0bccf7d71722c2030900eeb05f06ac2c", - "src/transformers/models/xlnet/modeling_xlnet.py": "26ef15cc9e19c5f1ad78e7beac1d02620659b4b7811ee07944b082cea346d827", - "src/transformers/models/xlstm/configuration_xlstm.py": "2fa1e1de35a743aaf4faa4399be2f6669990ab5370abf6a5670278ddacbd86ad", - "src/transformers/models/xlstm/modeling_xlstm.py": "658246eb2c90671f0912bc2bae93385a32f8a74b87c21ec334e34a1d6ccbff22", - "src/transformers/models/xmod/configuration_xmod.py": 
"fc8326fb791bf9165dad3af262e172b8c5e1985f4eb96ec6369a47de27ff3ebd", - "src/transformers/models/xmod/modeling_xmod.py": "019b42bd4b6373dddf2b15a16b9c20b3f0915dae92a7633bb5ef9f2ee1bb5ef9", - "src/transformers/models/yolos/configuration_yolos.py": "8970bd7e2ef458e6063b5fccf136c237b063465b4f7fd7be41c0f0b4e1fc1aa4", - "src/transformers/models/yolos/modeling_yolos.py": "32933319a7ef76a7d2d026abb2def66f2576349759836b23ebbe95682024c56f", - "src/transformers/models/yolos/modular_yolos.py": "67b24ac8dd457ab763f88b08a368c43fe32fb1487989dac63cd60c56e98badd1", - "src/transformers/models/yoso/configuration_yoso.py": "18fcac752ecd69aac9cda2ceadb444204a1b8b2536bb25030d2b34a34571279a", - "src/transformers/models/yoso/modeling_yoso.py": "ee3856557776db9294841314d24f39db87253677984766fcad6c78139a3db423", - "src/transformers/models/youtu/configuration_youtu.py": "72981b4a97b2f39069e4d3a74a008ab8acfc600c898c494662ee2f0c7683315d", - "src/transformers/models/youtu/modeling_youtu.py": "fbc8374c50df05da935a1b040aaf0a90e93689804c7b4e35e8d6b8c0813f66ae", - "src/transformers/models/youtu/modular_youtu.py": "ddbfc95d009cf94a6f3487c0d064c7a3aa2ef6c6015898c51988b5dab9c452c8", - "src/transformers/models/zamba/configuration_zamba.py": "23566c8ef76743f6be134cd1d5853d323c87677161ba2b477a57722e0215f376", - "src/transformers/models/zamba/modeling_zamba.py": "205a0b6cc541fe9d05e811226cfa8e0649cdb6fa634cfa888dd79f6b4f5263a6", - "src/transformers/models/zamba2/configuration_zamba2.py": "c98074770b92db6d888aaf530279b2b4f156357a6d4365cfce7298e88452ab12", - "src/transformers/models/zamba2/modeling_zamba2.py": "c4dbed4799ed9a2374a7eb6c55ebda6c7710310ed3b6d8508469a32c7babb999", - "src/transformers/models/zamba2/modular_zamba2.py": "546cc01d3910e159e566a41a8d86ed1c955eac41391871cf3e3cb2aef7ef24be", - "src/transformers/models/zoedepth/configuration_zoedepth.py": "1a2474f62a1d0e91bb183c1c2845dbbd4af92e5aa5ac2e7d1db96499c5b39907", - "src/transformers/models/zoedepth/modeling_zoedepth.py": "97e08f50416ab8ac3b6c52e6b3471aa3573b8619e85765acc6d0001b1bb8f657" -} From 86de2bc8bd6a5435520ff0d14b2a85d8de73f49e Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 31 Mar 2026 11:16:18 +0200 Subject: [PATCH 52/56] more renaming --- src/transformers/cache_utils.py | 38 ++++++++++++++-------------- src/transformers/generation/utils.py | 6 ++--- 2 files changed, 22 insertions(+), 22 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 297059a5b1e2..b73f2658ed80 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -670,9 +670,9 @@ def _dequantize(self, qtensor): class LinearAttentionCacheLayerMixin(ABC): - """Base, abstract class for a mamba single layer's cache.""" + """Base, abstract class for a linear attention single layer's cache.""" - # All shapes are static by essence in a Mamba layer, so it is compileable + # All shapes are static by essence in a LinearAttention layer, so it is compileable is_compileable = True def __init__(self): @@ -727,7 +727,7 @@ def reorder_cache(self, beam_idx: torch.LongTensor): self.recurrent_states = self.recurrent_states.index_select(0, beam_idx.to(self.device)) def crop(self, max_length: int): - # We don't crop the mamba cache, so simply do nothing here + # We don't crop the linear attention cache, so simply do nothing here pass @@ -756,7 +756,7 @@ def lazy_initialization( def update_conv_state(self, conv_states: torch.Tensor, **kwargs) -> torch.Tensor: """ - Update the mamba cache in-place, and return the necessary conv states. 
+        Update the linear attention cache in-place, and return the necessary conv states.
 
         Args:
             conv_states (`torch.Tensor`): The new conv states to cache.
 
         Returns:
             `torch.Tensor`: The updated conv states.
         """
         # Lazy initialization
@@ -789,7 +789,7 @@ def update_conv_state(self, conv_states: torch.Tensor, **kwargs) -> torch.Tensor
 
     def update_recurrent_state(self, recurrent_states: torch.Tensor, **kwargs) -> torch.Tensor:
         """
-        Update the mamba cache in-place, and return the necessary ssm states.
+        Update the linear attention cache in-place, and return the necessary recurrent states.
 
         Args:
             smm_states (`torch.Tensor`): The new ssm states to cache.
@@ -816,7 +816,7 @@ def lazy_initialization(self, *args, **kwargs) -> None:
         # When the Attention cache is used with `update`, `lazy_initialization` is called with 2 positional args
         if len(args) == 2 and len(kwargs) == 0:
             DynamicLayer.lazy_initialization(self, *args)
-        # Otherwise, for the Mamba cache, when it's called in `update_conv_state` or `update_recurrent_state`, it's
+        # Otherwise, for the LinearAttention cache, when it's called in `update_conv_state` or `update_recurrent_state`, it's
         # always called with 1 single kwarg (cause it needs to know if it's for the conv or ssm states)
         if len(args) == 0 and len(kwargs) == 1:
             LinearAttentionLayer.lazy_initialization(self, **kwargs)
@@ -957,7 +957,7 @@ def update_conv_state(self, conv_states: torch.Tensor, layer_idx: int, **kwargs)
         # NOTE: if we slightly break `update` arg order, we could combine this with it, and allow offloading support
         # out of the box
         if not isinstance(self.layers[layer_idx], LinearAttentionCacheLayerMixin):
-            raise ValueError("Cannot call `update_conv_state` on a non-Mamba layer!")
+            raise ValueError("Cannot call `update_conv_state` on a non-LinearAttention layer!")
         conv_states = self.layers[layer_idx].update_conv_state(conv_states, **kwargs)
         return conv_states
 
@@ -977,7 +977,7 @@ def update_recurrent_state(self, recurrent_states: torch.Tensor, layer_idx: int,
         # NOTE: if we slightly break `update` arg order, we could combine this with it, and allow offloading support
         # out of the box
         if not isinstance(self.layers[layer_idx], LinearAttentionCacheLayerMixin):
-            raise ValueError("Cannot call `update_conv_state` on a non-Mamba layer!")
+            raise ValueError("Cannot call `update_recurrent_state` on a non-LinearAttention layer!")
         recurrent_states = self.layers[layer_idx].update_recurrent_state(recurrent_states, **kwargs)
         return recurrent_states
 
@@ -1001,12 +1001,12 @@ def get_seq_length(self, layer_idx: int = 0) -> int:
         if layer_idx >= len(self.layers):
             return 0
 
-        # For alternating attention-mamba caches, `get_seq_length` needs to use attention layer idx when called with default layer_idx
+        # For alternating attention/linear attention caches, `get_seq_length` needs to use attention layer idx when called with default layer_idx
         if not isinstance(self.layers[layer_idx], CacheLayerMixin):
             # If this is called with non-default arg, raise
             if layer_idx != 0:
                 raise ValueError(
-                    f"You called `get_seq_length` on layer index {layer_idx}, but this layer is a Mamba layer, which "
+                    f"You called `get_seq_length` on layer index {layer_idx}, but this layer is a LinearAttention layer, which "
                     "does not track sequence length."
                 )
             try:
@@ -1015,17 +1015,17 @@ def get_seq_length(self, layer_idx: int = 0) -> int:
             except StopIteration:
                 raise ValueError(
                     "`get_seq_length` can only be called on Attention layers, and the current Cache seem to only contain "
-                    "Mamba layers."
+                    "LinearAttention layers."
                 )
 
         return self.layers[layer_idx].get_seq_length()
 
     def has_previous_state(self, layer_idx: int | None = None) -> bool:
-        """Returns whether the Mamba layer at index `layer_idx` has previous state or not."""
+        """Returns whether the LinearAttention layer at index `layer_idx` has previous state or not."""
         if layer_idx is not None and layer_idx >= len(self.layers):
             return False
 
-        # In this case, use last Mamba layer
+        # In this case, use last LinearAttention layer
         if layer_idx is None:
             try:
                 layer_idx = next(
                 )
             except StopIteration:
                 raise ValueError(
-                    "`has_previous_state` can only be called on Mamba layers, and the current Cache seem to only contain "
-                    "Attention layers."
+                    "`has_previous_state` can only be called on LinearAttention layers, and the current Cache seems to "
+                    "only contain Attention layers."
                 )
         elif not isinstance(self.layers[layer_idx], LinearAttentionCacheLayerMixin):
             raise ValueError(
@@ -1057,12 +1057,12 @@ def get_mask_sizes(self, query_length: int, layer_idx: int) -> tuple[int, int]:
         if layer_idx >= len(self.layers):
             return query_length, 0
 
-        # For alternating attention-mamba caches, `get_mask_sizes` needs to use attention layer idx when called with default layer_idx
+        # For alternating attention/linear attention caches, `get_mask_sizes` needs to use attention layer idx when called with default layer_idx
        if not isinstance(self.layers[layer_idx], CacheLayerMixin):
             # If this is called with non-default arg, raise
             if layer_idx != 0:
                 raise ValueError(
-                    f"You called `get_mask_sizes` on layer index {layer_idx}, but this layer is a Mamba layer, which "
+                    f"You called `get_mask_sizes` on layer index {layer_idx}, but this layer is a LinearAttention layer, which "
                     "does not track sequence length."
                 )
             try:
@@ -1071,7 +1071,7 @@ def get_mask_sizes(self, query_length: int, layer_idx: int) -> tuple[int, int]:
             except StopIteration:
                 raise ValueError(
                     "`get_mask_sizes` can only be called on Attention layers, and the current Cache seem to only contain "
-                    "Mamba layers."
+                    "LinearAttention layers."
                 )
 
         return self.layers[layer_idx].get_mask_sizes(query_length)
@@ -1341,7 +1341,7 @@ def __init__(
                     layer = StaticSlidingWindowLayer(
                         max_cache_len=max_cache_len, sliding_window=config.attention_chunk_size
                     )
-                # Mamba layers are static by essence - using `"moe"` as well is a trick, see the comment about it on DynamicCache
+                # LinearAttention layers are static by essence - using `"moe"` as well is a trick, see the comment about it on DynamicCache
                 elif layer_type in ("mamba", "conv", "linear_attention", "moe"):
                     layer = LinearAttentionLayer()
                 else:
diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index ed8e3947fd5f..d28ac6744023 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -1849,12 +1849,12 @@ def _prepare_cache_for_generation(
             generation_config.cache_implementation = "dynamic_full"
 
         dynamic_cache_kwargs = {}
-        # mamba models always need to pass the config, otherwise it will use an Attention cache for the Mamba layers
-        is_mamba = any(
+        # linear attention models always need to pass the config, otherwise it will use an Attention cache for the LinearAttention layers
+        is_linear_attention = any(
            x in ("mamba", "conv", "linear_attention")
            for x in getattr(self.config.get_text_config(decoder=True), "layer_types", [])
         )
-        if generation_config.cache_implementation != "dynamic_full" or is_mamba:
+        if generation_config.cache_implementation != "dynamic_full" or is_linear_attention:
             dynamic_cache_kwargs["config"] = self.config.get_text_config(decoder=True)
 
         if generation_config.cache_implementation == "offloaded":
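A minimal, self-contained sketch of the dispatch pattern the [PATCH 52/56] hunks above give `Cache` for stacks that mix attention and linear-attention layers. The `Toy*` class names, tensor shapes, and the `layer_types` list below are illustrative stand-ins only, not the actual transformers classes or a real model config:

import torch


class ToyAttentionLayer:
    """Stand-in for a regular KV-cache layer; it only tracks a sequence length."""

    def __init__(self):
        self.keys = torch.empty(1, 0, 8)  # (batch, seq_len, head_dim)

    def get_seq_length(self):
        return self.keys.shape[1]


class ToyLinearAttentionLayer:
    """Stand-in for a linear-attention layer holding conv + recurrent states."""

    def __init__(self):
        self.conv_states = None
        self.recurrent_states = None

    def update_conv_state(self, conv_states):
        self.conv_states = conv_states
        return self.conv_states

    def update_recurrent_state(self, recurrent_states):
        self.recurrent_states = recurrent_states
        return self.recurrent_states


class ToyCache:
    def __init__(self, layers):
        self.layers = layers

    def update_conv_state(self, conv_states, layer_idx):
        # Same guard as above: conv states only live on linear-attention layers.
        if not isinstance(self.layers[layer_idx], ToyLinearAttentionLayer):
            raise ValueError("Cannot call `update_conv_state` on a non-LinearAttention layer!")
        return self.layers[layer_idx].update_conv_state(conv_states)

    def get_seq_length(self, layer_idx=0):
        # Same fallback as above: with the default index, resolve the length on an
        # attention layer, since linear-attention layers do not track sequence length.
        if not isinstance(self.layers[layer_idx], ToyAttentionLayer):
            layer_idx = next(i for i, layer in enumerate(self.layers) if isinstance(layer, ToyAttentionLayer))
        return self.layers[layer_idx].get_seq_length()


cache = ToyCache([ToyAttentionLayer(), ToyLinearAttentionLayer()])
cache.update_conv_state(torch.zeros(1, 16, 4), layer_idx=1)  # fine: layer 1 is linear attention
print(cache.get_seq_length())  # 0, resolved on the attention layer at index 0

# The generation-side hunk keys off `layer_types` from the model config, e.g.:
layer_types = ["full_attention", "linear_attention", "full_attention"]
needs_config = any(t in ("mamba", "conv", "linear_attention") for t in layer_types)
print(needs_config)  # True -> the dynamic cache must be built from the config

The guard keeps conv/recurrent updates off attention layers, and the default-index fallback is what lets alternating stacks keep reporting a meaningful sequence length during generation.
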
audio_frame_step (`float`, *optional*, defaults to 0.01): diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index a68cdba04cd9..d4e8ca71a60a 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -2945,7 +2945,7 @@ def test_cpu_offload(self): model.cpu().save_pretrained(tmp_dir) for max_size in max_gpu_sizes: - max_memory = {0: max_size, "cpu": model_size * 3} + max_memory = {0: max_size, "cpu": model_size * 2} new_model = model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory) # Making sure part of the model will actually end up offloaded self.assertSetEqual(set(new_model.hf_device_map.values()), {0, "cpu"}) From 476aaafe4e79fe00d4f766ef8e55901f4e8153e7 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 31 Mar 2026 11:49:39 +0200 Subject: [PATCH 54/56] add offloading skips --- tests/models/zamba/test_modeling_zamba.py | 10 ++++++++-- tests/models/zamba2/test_modeling_zamba2.py | 6 ++++++ 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/tests/models/zamba/test_modeling_zamba.py b/tests/models/zamba/test_modeling_zamba.py index 037f157f79db..48e8bdb64a40 100644 --- a/tests/models/zamba/test_modeling_zamba.py +++ b/tests/models/zamba/test_modeling_zamba.py @@ -300,17 +300,23 @@ def setUp(self): self.config_tester = ConfigTester(self, config_class=ZambaConfig, hidden_size=32) @unittest.skip( - "Same as zamba2 -> investigate, it's probably due to their tied weights that accelerate does not work" + "Same as zamba2 -> investigate, it's probably due to their mixed layer classes or tied weights that accelerate does not work" ) def test_disk_offload_bin(self): pass @unittest.skip( - "Same as zamba2 -> investigate, it's probably due to their tied weights that accelerate does not work" + "Same as zamba2 -> investigate, it's probably due to their mixed layer classes or tied weights that accelerate does not work" ) def test_disk_offload_safetensors(self): pass + @unittest.skip( + "Same as zamba2 -> investigate, it's probably due to their mixed layer classes or tied weights that accelerate does not work" + ) + def test_cpu_offload(self): + pass + def test_config(self): self.config_tester.run_common_tests() diff --git a/tests/models/zamba2/test_modeling_zamba2.py b/tests/models/zamba2/test_modeling_zamba2.py index 1695cc9d5556..a240a4a71334 100644 --- a/tests/models/zamba2/test_modeling_zamba2.py +++ b/tests/models/zamba2/test_modeling_zamba2.py @@ -329,6 +329,12 @@ def test_disk_offload_bin(self): def test_disk_offload_safetensors(self): pass + @unittest.skip( + "Offloading does not work correctly for zamba2 - probably due to their mixed layer classes or tied weights" + ) + def test_cpu_offload(self): + pass + @unittest.skip("position_ids cannot be used to pad due to Mamba2 layers") def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self): pass From 7a6928718bb8c812999e62bf229fa9484fd1401d Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 31 Mar 2026 12:14:31 +0200 Subject: [PATCH 55/56] split shapes for tests --- tests/generation/test_utils.py | 28 ++++++++++++------- tests/models/bamba/test_modeling_bamba.py | 9 +++--- .../falcon_h1/test_modeling_falcon_h1.py | 11 ++++---- .../test_modeling_granitemoehybrid.py | 8 ++++-- tests/models/jamba/test_modeling_jamba.py | 9 +++--- tests/models/lfm2/test_modeling_lfm2.py | 6 ++-- .../models/lfm2_moe/test_modeling_lfm2_moe.py | 6 ++-- tests/models/lfm2_vl/test_modeling_lfm2_vl.py | 6 ++-- tests/models/mamba2/test_modeling_mamba2.py | 8 ++++-- 
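For context on the `max_memory` change in `test_cpu_offload` above: the test caps the GPU budget so that `device_map="auto"` is forced to push part of the checkpoint to CPU RAM, then asserts the resulting device map spans both devices. A minimal sketch of the same idiom, with a placeholder checkpoint id and made-up byte budgets (not values from this patch):

from transformers import AutoModelForCausalLM

# Placeholder repo id and budgets; the point is the device_map/max_memory combination,
# which makes accelerate offload to CPU whatever does not fit under the GPU 0 cap.
model = AutoModelForCausalLM.from_pretrained(
    "some-org/some-checkpoint",              # hypothetical checkpoint
    device_map="auto",
    max_memory={0: "2GiB", "cpu": "8GiB"},    # cap GPU 0, give the remainder to CPU
)
print(set(model.hf_device_map.values()))      # expected to contain both 0 and "cpu"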
.../nemotron_h/test_modeling_nemotron_h.py | 13 +++++---- tests/models/qwen3_5/test_modeling_qwen3_5.py | 26 +++++++++++------ .../qwen3_5_moe/test_modeling_qwen3_5_moe.py | 26 +++++++++++------ .../qwen3_next/test_modeling_qwen3_next.py | 13 ++++++--- tests/models/zamba/test_modeling_zamba.py | 10 ++++--- tests/models/zamba2/test_modeling_zamba2.py | 8 ++++-- 15 files changed, 114 insertions(+), 73 deletions(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index c0d24abda5c7..2184827b0891 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -2537,16 +2537,23 @@ def _check_encoder_hidden_states_for_generate(self, hidden_states, batch_size, c [encoder_expected_shape] * len(hidden_states), ) - def _get_mamba_cache_shapes(self, batch_size: int, config): - # Default mamba cache shape - can vary based on models so this function is convenient to easily check caches - # Note that non-mamba models will NOT have the config fields - it does not matter as they will only use attention cache layers - # so the None default values will not be used + def _get_conv_state_shape(self, batch_size: int, config): + # Default conv state shape, for linear attention models - can vary based on models so this function is convenient + # to easily check caches + # Note that non-mamba models will NOT have the config fields - it does not matter as they will only use attention + # cache layers, so the None default values will not be used intermediate_size = getattr(config, "intermediate_size", None) conv_kernel = getattr(config, "conv_kernel", None) + return (batch_size, intermediate_size, conv_kernel) + + def _get_recurrent_state_shape(self, batch_size: int, config): + # Default recurrent state shape, for linear attention models - can vary based on models so this function is convenient + # to easily check caches + # Note that non-mamba models will NOT have the config fields - it does not matter as they will only use attention + # cache layers, so the None default values will not be used + intermediate_size = getattr(config, "intermediate_size", None) state_size = getattr(config, "state_size", None) - conv_shape = (batch_size, intermediate_size, conv_kernel) - ssm_shape = (batch_size, intermediate_size, state_size) - return conv_shape, ssm_shape + return (batch_size, intermediate_size, state_size) def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config): # Raise a useful error, asking to explicitly override the method @@ -2580,7 +2587,8 @@ def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_l ) # For mamba layers - conv_shape, ssm_shape = self._get_mamba_cache_shapes(batch_size, config) + conv_shape = self._get_conv_state_shape(batch_size, config) + recurrent_shape = self._get_recurrent_state_shape(batch_size, config) # Check the size is coherent num_hidden_layers = config.num_hidden_layers @@ -2600,13 +2608,13 @@ def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_l self.assertEqual(layer.conv_states.shape, conv_shape) # May not be used (e.g. lfm2) if layer.is_recurrent_states_initialized: - self.assertEqual(layer.recurrent_states.shape, ssm_shape) + self.assertEqual(layer.recurrent_states.shape, recurrent_shape) # Mamba only layer cache elif type(layer) is LinearAttentionLayer: self.assertEqual(layer.conv_states.shape, conv_shape) # May not be used (e.g. 
lfm2) if layer.is_recurrent_states_initialized: - self.assertEqual(layer.recurrent_states.shape, ssm_shape) + self.assertEqual(layer.recurrent_states.shape, recurrent_shape) # Attention only layer type else: # Remove the seq_length dim for cross-attention cache (it changes based on the model) diff --git a/tests/models/bamba/test_modeling_bamba.py b/tests/models/bamba/test_modeling_bamba.py index bf180d0fedf0..fd028512f16c 100644 --- a/tests/models/bamba/test_modeling_bamba.py +++ b/tests/models/bamba/test_modeling_bamba.py @@ -279,15 +279,16 @@ class BambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi # This is because we are hitting edge cases with the causal_mask buffer model_split_percents = [0.5, 0.7, 0.8] - def _get_mamba_cache_shapes(self, batch_size: int, config): - # For mamba layers + def _get_conv_state_shape(self, batch_size: int, config): conv_shape = ( batch_size, config.mamba_expand * config.hidden_size + 2 * config.mamba_n_groups * config.mamba_d_state, config.mamba_d_conv, ) - ssm_shape = (batch_size, config.mamba_n_heads, config.mamba_d_head, config.mamba_d_state) - return conv_shape, ssm_shape + return conv_shape + + def _get_recurrent_state_shape(self, batch_size: int, config): + return (batch_size, config.mamba_n_heads, config.mamba_d_head, config.mamba_d_state) def setUp(self): self.model_tester = self.model_tester_class(self) diff --git a/tests/models/falcon_h1/test_modeling_falcon_h1.py b/tests/models/falcon_h1/test_modeling_falcon_h1.py index 2537a76b0847..a6429f2da621 100644 --- a/tests/models/falcon_h1/test_modeling_falcon_h1.py +++ b/tests/models/falcon_h1/test_modeling_falcon_h1.py @@ -253,18 +253,19 @@ class FalconH1ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM {"feature-extraction": FalconH1Model, "text-generation": FalconH1ForCausalLM} if is_torch_available() else {} ) - def _get_mamba_cache_shapes(self, batch_size: int, config): - conv_kernel_size = config.mamba_d_conv + def _get_conv_state_shape(self, batch_size: int, config): intermediate_size = ( config.mamba_d_ssm if config.mamba_d_ssm is not None else int(config.mamba_expand * config.hidden_size) ) conv_shape = ( batch_size, intermediate_size + 2 * config.mamba_n_groups * config.mamba_d_state, - conv_kernel_size, + config.mamba_d_conv, ) - ssm_shape = (batch_size, config.mamba_n_heads, config.mamba_d_head, config.mamba_d_state) - return conv_shape, ssm_shape + return conv_shape + + def _get_recurrent_state_shape(self, batch_size: int, config): + return (batch_size, config.mamba_n_heads, config.mamba_d_head, config.mamba_d_state) def setUp(self): self.model_tester = FalconH1ModelTester(self) diff --git a/tests/models/granitemoehybrid/test_modeling_granitemoehybrid.py b/tests/models/granitemoehybrid/test_modeling_granitemoehybrid.py index 83fdc436b9a9..919bf79deac3 100644 --- a/tests/models/granitemoehybrid/test_modeling_granitemoehybrid.py +++ b/tests/models/granitemoehybrid/test_modeling_granitemoehybrid.py @@ -302,14 +302,16 @@ def test_flash_attention_2_padding_matches_padding_free_with_position_ids_seq_id loss_padfree = res_padfree.loss torch.testing.assert_close(loss_padded, loss_padfree) - def _get_mamba_cache_shapes(self, batch_size: int, config): + def _get_conv_state_shape(self, batch_size: int, config): conv_shape = ( batch_size, config.mamba_expand * config.hidden_size + 2 * config.mamba_n_groups * config.mamba_d_state, config.mamba_d_conv, ) - ssm_shape = (batch_size, config.mamba_n_heads, config.mamba_d_head, config.mamba_d_state) - return 
conv_shape, ssm_shape + return conv_shape + + def _get_recurrent_state_shape(self, batch_size: int, config): + return (batch_size, config.mamba_n_heads, config.mamba_d_head, config.mamba_d_state) def test_config_requires_mamba_or_attention_layers(self): """Ensure we can't create a config with disallowed layers.""" diff --git a/tests/models/jamba/test_modeling_jamba.py b/tests/models/jamba/test_modeling_jamba.py index 77500ad06251..28d4b7a18c61 100644 --- a/tests/models/jamba/test_modeling_jamba.py +++ b/tests/models/jamba/test_modeling_jamba.py @@ -322,10 +322,11 @@ class JambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi else {} ) - def _get_mamba_cache_shapes(self, batch_size: int, config): - conv_shape = (batch_size, config.mamba_expand * config.hidden_size, config.mamba_d_conv) - ssm_shape = (batch_size, config.mamba_expand * config.hidden_size, config.mamba_d_state) - return conv_shape, ssm_shape + def _get_conv_state_shape(self, batch_size: int, config): + return (batch_size, config.mamba_expand * config.hidden_size, config.mamba_d_conv) + + def _get_recurrent_state_shape(self, batch_size: int, config): + return (batch_size, config.mamba_expand * config.hidden_size, config.mamba_d_state) def setUp(self): self.model_tester = JambaModelTester(self) diff --git a/tests/models/lfm2/test_modeling_lfm2.py b/tests/models/lfm2/test_modeling_lfm2.py index a44c303c6c97..67698c564092 100644 --- a/tests/models/lfm2/test_modeling_lfm2.py +++ b/tests/models/lfm2/test_modeling_lfm2.py @@ -51,10 +51,8 @@ class Lfm2ModelTest(CausalLMModelTest, unittest.TestCase): # used in `test_torch_compile_for_training` _torch_compile_train_cls = Lfm2ForCausalLM if is_torch_available() else None - def _get_mamba_cache_shapes(self, batch_size: int, config): - conv_shape = (batch_size, config.hidden_size, config.conv_L_cache) - ssm_shape = (0,) - return conv_shape, ssm_shape + def _get_conv_state_shape(self, batch_size: int, config): + return (batch_size, config.hidden_size, config.conv_L_cache) def test_attention_outputs(self): """Lfm2Moe alternates between attention and short-conv layers.""" diff --git a/tests/models/lfm2_moe/test_modeling_lfm2_moe.py b/tests/models/lfm2_moe/test_modeling_lfm2_moe.py index 8264cf4f6118..2015a4a83e31 100644 --- a/tests/models/lfm2_moe/test_modeling_lfm2_moe.py +++ b/tests/models/lfm2_moe/test_modeling_lfm2_moe.py @@ -69,10 +69,8 @@ class Lfm2MoeModelTest(CausalLMModelTest, unittest.TestCase): # used in `test_torch_compile_for_training` _torch_compile_train_cls = Lfm2MoeForCausalLM if is_torch_available() else None - def _get_mamba_cache_shapes(self, batch_size: int, config): - conv_shape = (batch_size, config.hidden_size, config.conv_L_cache) - ssm_shape = (0,) - return conv_shape, ssm_shape + def _get_conv_state_shape(self, batch_size: int, config): + return (batch_size, config.hidden_size, config.conv_L_cache) def test_attention_outputs(self): """Lfm2Moe alternates between attention and short-conv layers.""" diff --git a/tests/models/lfm2_vl/test_modeling_lfm2_vl.py b/tests/models/lfm2_vl/test_modeling_lfm2_vl.py index f4ef47a97402..c14e3933f77b 100644 --- a/tests/models/lfm2_vl/test_modeling_lfm2_vl.py +++ b/tests/models/lfm2_vl/test_modeling_lfm2_vl.py @@ -171,10 +171,8 @@ def setUp(self): self, config_class=Lfm2VlConfig, has_text_modality=False, common_properties=common_properties ) - def _get_mamba_cache_shapes(self, batch_size: int, config): - conv_shape = (batch_size, config.hidden_size, config.conv_L_cache) - ssm_shape = (0,) - return conv_shape, 
ssm_shape + def _get_conv_state_shape(self, batch_size: int, config): + return (batch_size, config.hidden_size, config.conv_L_cache) def test_config(self): self.config_tester.run_common_tests() diff --git a/tests/models/mamba2/test_modeling_mamba2.py b/tests/models/mamba2/test_modeling_mamba2.py index f64047907acc..6e7116afdc36 100644 --- a/tests/models/mamba2/test_modeling_mamba2.py +++ b/tests/models/mamba2/test_modeling_mamba2.py @@ -244,15 +244,17 @@ def setUp(self): self, config_class=Mamba2Config, n_embd=37, common_properties=["hidden_size", "num_hidden_layers"] ) - def _get_mamba_cache_shapes(self, batch_size: int, config): + def _get_conv_state_shape(self, batch_size: int, config): intermediate_size = config.expand * config.hidden_size conv_shape = ( batch_size, intermediate_size + 2 * config.n_groups * config.state_size, config.conv_kernel, ) - ssm_shape = (batch_size, config.num_heads, config.head_dim, config.state_size) - return conv_shape, ssm_shape + return conv_shape + + def _get_recurrent_state_shape(self, batch_size: int, config): + return (batch_size, config.num_heads, config.head_dim, config.state_size) def test_mamba2_caching(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() diff --git a/tests/models/nemotron_h/test_modeling_nemotron_h.py b/tests/models/nemotron_h/test_modeling_nemotron_h.py index 2c45f4c5485c..19ca9f4f77fd 100644 --- a/tests/models/nemotron_h/test_modeling_nemotron_h.py +++ b/tests/models/nemotron_h/test_modeling_nemotron_h.py @@ -351,15 +351,17 @@ class NemotronHModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester else {} ) - def _get_mamba_cache_shapes(self, batch_size: int, config): + def _get_conv_state_shape(self, batch_size: int, config): intermediate_size = config.mamba_num_heads * config.mamba_head_dim conv_shape = ( batch_size, intermediate_size + 2 * config.n_groups * config.ssm_state_size, config.conv_kernel, ) - ssm_shape = (batch_size, config.mamba_num_heads, config.mamba_head_dim, config.ssm_state_size) - return conv_shape, ssm_shape + return conv_shape + + def _get_recurrent_state_shape(self, batch_size: int, config): + return (batch_size, config.mamba_num_heads, config.mamba_head_dim, config.ssm_state_size) def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config): # Raise a useful error, asking to explicitly override the method @@ -379,7 +381,8 @@ def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_l # For cross attention cache, the seq_length depends on the model, so we remove that dim attention_shape = (batch_size, num_kv_heads, seq_length, head_dim) # For mamba layers - conv_shape, ssm_shape = self._get_mamba_cache_shapes(batch_size, config) + conv_shape = self._get_conv_state_shape(batch_size, config) + recurrent_shape = self._get_recurrent_state_shape(batch_size, config) # Check each layer has the correct shape for layer, layer_type in zip(past_key_values.layers, config.layer_types): @@ -394,7 +397,7 @@ def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_l # Mamba layer cache elif layer_type == "mamba": self.assertEqual(layer.conv_states.shape, conv_shape) - self.assertEqual(layer.recurrent_states.shape, ssm_shape) + self.assertEqual(layer.recurrent_states.shape, recurrent_shape) else: raise ValueError("Unknown layer type.") diff --git a/tests/models/qwen3_5/test_modeling_qwen3_5.py b/tests/models/qwen3_5/test_modeling_qwen3_5.py index 839b4f6c7fc4..f90fb09546d6 100644 --- 
a/tests/models/qwen3_5/test_modeling_qwen3_5.py +++ b/tests/models/qwen3_5/test_modeling_qwen3_5.py @@ -70,16 +70,21 @@ class Qwen3_5TextModelTest(CausalLMModelTest, unittest.TestCase): config_class = Qwen3_5TextConfig model_split_percents = [0.5, 0.8, 0.9] - def _get_mamba_cache_shapes(self, batch_size: int, config): + def _get_conv_state_shape(self, batch_size: int, config): num_v_heads = config.linear_num_value_heads num_k_heads = config.linear_num_key_heads head_k_dim = config.linear_key_head_dim head_v_dim = config.linear_value_head_dim intermediate_size = 2 * num_k_heads * head_k_dim + num_v_heads * head_v_dim - conv_shape = (batch_size, intermediate_size, config.linear_conv_kernel_dim) - ssm_shape = (batch_size, num_v_heads, head_k_dim, head_v_dim) - return conv_shape, ssm_shape + return (batch_size, intermediate_size, config.linear_conv_kernel_dim) + + def _get_recurrent_state_shape(self, batch_size: int, config): + num_v_heads = config.linear_num_value_heads + head_k_dim = config.linear_key_head_dim + head_v_dim = config.linear_value_head_dim + + return (batch_size, num_v_heads, head_k_dim, head_v_dim) def test_attention_outputs(self): "Needs to be overwritten as Qwen3.5 alternates between attention layers and gated deltanet layers." @@ -299,16 +304,21 @@ def setUp(self): def test_config(self): self.config_tester.run_common_tests() - def _get_mamba_cache_shapes(self, batch_size: int, config): + def _get_conv_state_shape(self, batch_size: int, config): num_v_heads = config.linear_num_value_heads num_k_heads = config.linear_num_key_heads head_k_dim = config.linear_key_head_dim head_v_dim = config.linear_value_head_dim intermediate_size = 2 * num_k_heads * head_k_dim + num_v_heads * head_v_dim - conv_shape = (batch_size, intermediate_size, config.linear_conv_kernel_dim) - ssm_shape = (batch_size, num_v_heads, head_k_dim, head_v_dim) - return conv_shape, ssm_shape + return (batch_size, intermediate_size, config.linear_conv_kernel_dim) + + def _get_recurrent_state_shape(self, batch_size: int, config): + num_v_heads = config.linear_num_value_heads + head_k_dim = config.linear_key_head_dim + head_v_dim = config.linear_value_head_dim + + return (batch_size, num_v_heads, head_k_dim, head_v_dim) def test_attention_outputs(self): "Needs to be overwritten as Qwen3.5 alternates between attention layers and gated deltanet layers." 
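The same split is repeated in every hybrid model test touched below: the common generation check asks for the conv state shape and the recurrent state shape through two separate helpers, and each model test overrides only the helper it actually needs. A condensed illustration of that contract, using simplified stand-in classes rather than the actual test mixins:

class CommonShapeChecks:
    # defaults mirror the base helpers in tests/generation/test_utils.py
    def _get_conv_state_shape(self, batch_size, config):
        return (batch_size, getattr(config, "intermediate_size", None), getattr(config, "conv_kernel", None))

    def _get_recurrent_state_shape(self, batch_size, config):
        return (batch_size, getattr(config, "intermediate_size", None), getattr(config, "state_size", None))

    def check_layer(self, layer, batch_size, config):
        assert layer.conv_states.shape == self._get_conv_state_shape(batch_size, config)
        # conv-only layers never initialize a recurrent state, so this branch is skipped for them
        if layer.is_recurrent_states_initialized:
            assert layer.recurrent_states.shape == self._get_recurrent_state_shape(batch_size, config)

class Lfm2StyleChecks(CommonShapeChecks):
    # short-conv models (e.g. lfm2) only need to override the conv shape
    def _get_conv_state_shape(self, batch_size, config):
        return (batch_size, config.hidden_size, config.conv_L_cache)

Splitting the helpers keeps each override minimal and avoids returning a dummy recurrent shape for models that never allocate one.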
diff --git a/tests/models/qwen3_5_moe/test_modeling_qwen3_5_moe.py b/tests/models/qwen3_5_moe/test_modeling_qwen3_5_moe.py index 55b58a1aea34..d949f777f8a4 100644 --- a/tests/models/qwen3_5_moe/test_modeling_qwen3_5_moe.py +++ b/tests/models/qwen3_5_moe/test_modeling_qwen3_5_moe.py @@ -73,16 +73,21 @@ class Qwen3_5MoeTextModelTest(CausalLMModelTest, unittest.TestCase): model_tester_class = Qwen3_5MoeTextModelTester config_class = Qwen3_5MoeTextConfig - def _get_mamba_cache_shapes(self, batch_size: int, config): + def _get_conv_state_shape(self, batch_size: int, config): num_v_heads = config.linear_num_value_heads num_k_heads = config.linear_num_key_heads head_k_dim = config.linear_key_head_dim head_v_dim = config.linear_value_head_dim intermediate_size = 2 * num_k_heads * head_k_dim + num_v_heads * head_v_dim - conv_shape = (batch_size, intermediate_size, config.linear_conv_kernel_dim) - ssm_shape = (batch_size, num_v_heads, head_k_dim, head_v_dim) - return conv_shape, ssm_shape + return (batch_size, intermediate_size, config.linear_conv_kernel_dim) + + def _get_recurrent_state_shape(self, batch_size: int, config): + num_v_heads = config.linear_num_value_heads + head_k_dim = config.linear_key_head_dim + head_v_dim = config.linear_value_head_dim + + return (batch_size, num_v_heads, head_k_dim, head_v_dim) def test_attention_outputs(self): "Needs to be overwritten as Qwen3.5 Moe alternates between attention layers and gated deltanet layers." @@ -381,16 +386,21 @@ def setUp(self): def test_config(self): self.config_tester.run_common_tests() - def _get_mamba_cache_shapes(self, batch_size: int, config): + def _get_conv_state_shape(self, batch_size: int, config): num_v_heads = config.linear_num_value_heads num_k_heads = config.linear_num_key_heads head_k_dim = config.linear_key_head_dim head_v_dim = config.linear_value_head_dim intermediate_size = 2 * num_k_heads * head_k_dim + num_v_heads * head_v_dim - conv_shape = (batch_size, intermediate_size, config.linear_conv_kernel_dim) - ssm_shape = (batch_size, num_v_heads, head_k_dim, head_v_dim) - return conv_shape, ssm_shape + return (batch_size, intermediate_size, config.linear_conv_kernel_dim) + + def _get_recurrent_state_shape(self, batch_size: int, config): + num_v_heads = config.linear_num_value_heads + head_k_dim = config.linear_key_head_dim + head_v_dim = config.linear_value_head_dim + + return (batch_size, num_v_heads, head_k_dim, head_v_dim) def test_attention_outputs(self): "Needs to be overwritten as Qwen3.5 Moe alternates between attention layers and gated deltanet layers." 
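To make the gated-deltanet shapes above concrete, here is a small worked example with made-up config values (illustrative only, not taken from any released checkpoint); the conv state width presumably covers the query and key projections (each `linear_num_key_heads * linear_key_head_dim` wide) plus the value projection:

batch_size = 2
linear_num_value_heads, linear_num_key_heads = 8, 4
linear_key_head_dim, linear_value_head_dim = 16, 32
linear_conv_kernel_dim = 4

# 2 * 4 * 16 (q and k) + 8 * 32 (v) = 384 channels share one causal conv
intermediate_size = 2 * linear_num_key_heads * linear_key_head_dim + linear_num_value_heads * linear_value_head_dim
conv_shape = (batch_size, intermediate_size, linear_conv_kernel_dim)                                  # (2, 384, 4)
recurrent_shape = (batch_size, linear_num_value_heads, linear_key_head_dim, linear_value_head_dim)    # (2, 8, 16, 32)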
diff --git a/tests/models/qwen3_next/test_modeling_qwen3_next.py b/tests/models/qwen3_next/test_modeling_qwen3_next.py index f597b8e28192..4cb53fb6c695 100644 --- a/tests/models/qwen3_next/test_modeling_qwen3_next.py +++ b/tests/models/qwen3_next/test_modeling_qwen3_next.py @@ -57,16 +57,21 @@ def __init__(self, parent): class Qwen3NextModelTest(CausalLMModelTest, unittest.TestCase): model_tester_class = Qwen3NextModelTester - def _get_mamba_cache_shapes(self, batch_size: int, config): + def _get_conv_state_shape(self, batch_size: int, config): num_v_heads = config.linear_num_value_heads num_k_heads = config.linear_num_key_heads head_k_dim = config.linear_key_head_dim head_v_dim = config.linear_value_head_dim intermediate_size = 2 * num_k_heads * head_k_dim + num_v_heads * head_v_dim - conv_shape = (batch_size, intermediate_size, config.linear_conv_kernel_dim) - ssm_shape = (batch_size, num_v_heads, head_k_dim, head_v_dim) - return conv_shape, ssm_shape + return (batch_size, intermediate_size, config.linear_conv_kernel_dim) + + def _get_recurrent_state_shape(self, batch_size: int, config): + num_v_heads = config.linear_num_value_heads + head_k_dim = config.linear_key_head_dim + head_v_dim = config.linear_value_head_dim + + return (batch_size, num_v_heads, head_k_dim, head_v_dim) def test_attention_outputs(self): "Needs to be overwritten as Qwen3-Next alternates between attention layers and gated deltanet layers." diff --git a/tests/models/zamba/test_modeling_zamba.py b/tests/models/zamba/test_modeling_zamba.py index 48e8bdb64a40..9e7ee869b88d 100644 --- a/tests/models/zamba/test_modeling_zamba.py +++ b/tests/models/zamba/test_modeling_zamba.py @@ -289,11 +289,13 @@ class ZambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi ) model_split_percents = [0.5, 0.8, 0.9] - def _get_mamba_cache_shapes(self, batch_size: int, config): + def _get_conv_state_shape(self, batch_size: int, config): intermediate_size = config.mamba_expand * config.hidden_size - conv_shape = (batch_size, intermediate_size, config.mamba_d_conv) - ssm_shape = (batch_size, config.n_mamba_heads, intermediate_size // config.n_mamba_heads, config.mamba_d_state) - return conv_shape, ssm_shape + return (batch_size, intermediate_size, config.mamba_d_conv) + + def _get_recurrent_state_shape(self, batch_size: int, config): + intermediate_size = config.mamba_expand * config.hidden_size + return (batch_size, config.n_mamba_heads, intermediate_size // config.n_mamba_heads, config.mamba_d_state) def setUp(self): self.model_tester = ZambaModelTester(self) diff --git a/tests/models/zamba2/test_modeling_zamba2.py b/tests/models/zamba2/test_modeling_zamba2.py index a240a4a71334..66b9093ee4e8 100644 --- a/tests/models/zamba2/test_modeling_zamba2.py +++ b/tests/models/zamba2/test_modeling_zamba2.py @@ -299,15 +299,17 @@ class Zamba2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix ) model_split_percents = [0.5, 0.8, 0.9] - def _get_mamba_cache_shapes(self, batch_size: int, config): + def _get_conv_state_shape(self, batch_size: int, config): intermediate_size = config.mamba_expand * config.hidden_size conv_shape = ( batch_size, intermediate_size + 2 * config.mamba_ngroups * config.mamba_d_state, config.mamba_d_conv, ) - ssm_shape = (batch_size, config.n_mamba_heads, config.mamba_headdim, config.mamba_d_state) - return conv_shape, ssm_shape + return conv_shape + + def _get_recurrent_state_shape(self, batch_size: int, config): + return (batch_size, config.n_mamba_heads, config.mamba_headdim, 
config.mamba_d_state) def setUp(self): self.model_tester = Zamba2ModelTester(self) From 3600b89ce12991f04ae48bc6bacf58d23c7374ee Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 31 Mar 2026 14:24:24 +0200 Subject: [PATCH 56/56] comments and renaming --- src/transformers/cache_utils.py | 4 ++-- src/transformers/generation/utils.py | 6 +++--- tests/generation/test_utils.py | 6 +++--- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index b73f2658ed80..ac324ebb62b4 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -804,7 +804,7 @@ def update_recurrent_state(self, recurrent_states: torch.Tensor, **kwargs) -> to return self.recurrent_states -class LinearAttentionAndAttentionLayer(LinearAttentionLayer, DynamicLayer): +class LinearAttentionAndFullAttentionLayer(LinearAttentionLayer, DynamicLayer): # The dynamic Attention part makes it non-compileable is_compileable = False @@ -1231,7 +1231,7 @@ def __init__( elif layer_type in ("mamba", "conv", "linear_attention", "moe"): layers.append(LinearAttentionLayer()) elif layer_type == "hybrid": - layers.append(LinearAttentionAndAttentionLayer()) + layers.append(LinearAttentionAndFullAttentionLayer()) else: layers.append(DynamicLayer()) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index d28ac6744023..d07cf05b1ed7 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1781,7 +1781,7 @@ def _supports_default_dynamic_cache(cls: type["GenerativePreTrainedModel"]) -> b "reformer", "minimax", "xlnet", - "olmohybrid", + "olmohybrid", # olmo_hybrid cannot use linear attention cache for now as it uses split k,q,v conv states "rwkv", "xlstm", ) @@ -2001,8 +2001,8 @@ def _valid_auto_compile_criteria( valid_hardware = self.device.type in ["cuda", "xpu"] or bool( generation_config.compile_config is not None and generation_config.compile_config._compile_all_devices ) - # Note: for full mamba models, even a DynamicCache is compileable since all layers are mamba, but we don't want - # to ALWAYS compile when calling `generate`, so we check the type + # Note: for some models that only use linear attention (e.g. 
Mamba), even a DynamicCache is compileable since all + # layers are, but we don't want to ALWAYS compile when calling `generate`, so we check the type using_compilable_cache = cache is not None and cache.is_compileable and type(cache) is not DynamicCache can_compile = valid_hardware and using_compilable_cache diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 2184827b0891..055b852be0b3 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -81,7 +81,7 @@ Cache, DynamicCache, EncoderDecoderCache, - LinearAttentionAndAttentionLayer, + LinearAttentionAndFullAttentionLayer, LinearAttentionLayer, QuantoQuantizedLayer, StaticCache, @@ -2599,7 +2599,7 @@ def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_l # Check each layer has the correct shape for layer in past_key_values.layers: # Mamba + Attention layer cache - if type(layer) is LinearAttentionAndAttentionLayer: + if type(layer) is LinearAttentionAndFullAttentionLayer: # Remove the seq_length dim for cross-attention cache (it changes based on the model) keys = layer.keys if seq_length is not None else layer.keys[:, :, 0, :] values = layer.values if seq_length is not None else layer.values[:, :, 0, :] @@ -2664,7 +2664,7 @@ def _check_caches_are_equal(self, cache1: Cache, cache2: Cache): self.assertEqual(type(cache1.layers[idx]), type(cache2.layers[idx])) # Mamba + Attention layer - if type(cache1.layers[idx]) is LinearAttentionAndAttentionLayer: + if type(cache1.layers[idx]) is LinearAttentionAndFullAttentionLayer: torch.testing.assert_close(cache1.layers[idx].keys, cache2.layers[idx].keys) torch.testing.assert_close(cache1.layers[idx].values, cache2.layers[idx].values) torch.testing.assert_close(cache1.layers[idx].conv_states, cache2.layers[idx].conv_states)
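As a closing illustration of how `_check_caches_are_equal` treats the renamed `LinearAttentionAndFullAttentionLayer`: layer classes must match pairwise, and whatever state a layer type carries (keys/values for attention, conv/recurrent states for linear attention, both for the hybrid layer) must be close. The helper below is a rough, duck-typed sketch of that comparison, not the test suite's actual implementation:

import torch

def caches_match(cache1, cache2):
    # Layer classes must line up one-to-one before comparing any state.
    if len(cache1.layers) != len(cache2.layers):
        return False
    for l1, l2 in zip(cache1.layers, cache2.layers):
        if type(l1) is not type(l2):
            return False
        # Compare whichever state tensors this layer type actually holds.
        for name in ("keys", "values", "conv_states", "recurrent_states"):
            t1, t2 = getattr(l1, name, None), getattr(l2, name, None)
            if isinstance(t1, torch.Tensor) != isinstance(t2, torch.Tensor):
                return False
            if isinstance(t1, torch.Tensor) and not torch.allclose(t1, t2):
                return False
    return True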