diff --git a/docs/source/en/model_doc/falcon_mamba.md b/docs/source/en/model_doc/falcon_mamba.md index 78b6e23a8127..bba9e95b63a2 100644 --- a/docs/source/en/model_doc/falcon_mamba.md +++ b/docs/source/en/model_doc/falcon_mamba.md @@ -111,13 +111,6 @@ outputs = model.generate(**inputs, max_new_tokens=100) print(tokenizer.decode(outputs[0], skip_special_tokens=True)) ``` -## FalconMambaCache - -[[autodoc]] FalconMambaCache - - update_conv_state - - update_ssm_state - - reset - ## FalconMambaConfig [[autodoc]] FalconMambaConfig diff --git a/docs/source/en/model_doc/mamba.md b/docs/source/en/model_doc/mamba.md index 7add263ab4fd..dd2bb2580a1a 100644 --- a/docs/source/en/model_doc/mamba.md +++ b/docs/source/en/model_doc/mamba.md @@ -110,13 +110,6 @@ print(tokenizer.decode(output[0], skip_special_tokens=True)) trainer.train() ``` -## MambaCache - -[[autodoc]] MambaCache - - update_conv_state - - update_ssm_state - - reset - ## MambaConfig [[autodoc]] MambaConfig diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index b32587b1fbc2..9b48d2f5a5de 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -634,7 +634,6 @@ from .modeling_utils import AttentionInterface as AttentionInterface from .modeling_utils import PreTrainedModel as PreTrainedModel from .models import * - from .models.mamba.modeling_mamba import MambaCache as MambaCache from .models.timm_wrapper import TimmWrapperImageProcessor as TimmWrapperImageProcessor # Optimization diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 7dede60a7b27..ac324ebb62b4 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -669,6 +669,168 @@ def _dequantize(self, qtensor): return tensor +class LinearAttentionCacheLayerMixin(ABC): + """Base, abstract class for a linear attention single layer's cache.""" + + # All shapes are static by essence in a LinearAttention layer, so it is compileable + is_compileable = True + + def __init__(self): + self.conv_states: torch.Tensor | None = None + self.recurrent_states: torch.Tensor | None = None + self.is_conv_states_initialized = False + self.is_recurrent_states_initialized = False + self.has_previous_state = False + + def __repr__(self): + return f"{self.__class__.__name__}" + + @abstractmethod + def lazy_initialization( + self, conv_states: torch.Tensor | None = None, recurrent_states: torch.Tensor | None = None + ) -> None: ... + + @abstractmethod + def update_conv_state(self, conv_states: torch.Tensor) -> torch.Tensor: ... + + @abstractmethod + def update_recurrent_state(self, recurrent_states: torch.Tensor) -> torch.Tensor: ... 
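To make the contract above concrete: the roll-and-copy window semantics that implementations of `update_conv_state` maintain can be sketched as below. This is a minimal, self-contained illustration with toy shapes; `roll_window` is a hypothetical stand-in for illustration, not an API from this patch.

```python
import torch

# Toy sketch of the fixed-size conv-state window kept by a linear attention layer:
# a (batch, channels, kernel_size) buffer updated in place with roll + copy_,
# so the tensor's storage address stays stable (useful for cudagraphs).
# `roll_window` is a hypothetical helper, not part of this patch.
def roll_window(conv_states: torch.Tensor, new_states: torch.Tensor) -> torch.Tensor:
    kernel_size = conv_states.shape[-1]
    num_new = new_states.shape[-1]
    if num_new >= kernel_size:
        # Enough new tokens to refill the whole window: keep only the most recent ones
        conv_states.copy_(new_states[..., -kernel_size:])
    else:
        # Shift the window left, then write the new tokens at the end
        rolled = conv_states.roll(shifts=-num_new, dims=-1)
        rolled[..., -num_new:] = new_states
        conv_states.copy_(rolled)
    return conv_states

buffer = torch.zeros(1, 2, 4)  # batch=1, channels=2, kernel_size=4
token = torch.ones(1, 2, 1)    # one decoding step
roll_window(buffer, token)     # window now ends with the new token, oldest entry dropped
```

Updating through `copy_` rather than reassignment is what preserves the static address that the patch later marks for cudagraphs.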
+
+    def offload(self):
+        """Offload this layer's data to the CPU device."""
+        if self.is_conv_states_initialized:
+            self.conv_states = self.conv_states.to("cpu", non_blocking=True)
+        if self.is_recurrent_states_initialized:
+            self.recurrent_states = self.recurrent_states.to("cpu", non_blocking=True)
+
+    def prefetch(self):
+        """In case of layer offloading, this allows moving the data back to the layer's device ahead of time."""
+        if self.is_conv_states_initialized and self.conv_states.device != self.device:
+            self.conv_states = self.conv_states.to(self.device, non_blocking=True)
+        if self.is_recurrent_states_initialized and self.recurrent_states.device != self.device:
+            self.recurrent_states = self.recurrent_states.to(self.device, non_blocking=True)
+
+    def reset(self) -> None:
+        """Resets the cache values while preserving the objects"""
+        if self.is_conv_states_initialized:
+            self.conv_states.zero_()
+        if self.is_recurrent_states_initialized:
+            self.recurrent_states.zero_()
+        self.has_previous_state = False
+
+    def reorder_cache(self, beam_idx: torch.LongTensor):
+        """Reorders the cache for beam search, given the selected beam indices."""
+        if self.is_conv_states_initialized:
+            self.conv_states = self.conv_states.index_select(0, beam_idx.to(self.device))
+        # recurrent_states can stay empty sometimes, see e.g. lfm2 which only uses the conv_states
+        if self.is_recurrent_states_initialized:
+            self.recurrent_states = self.recurrent_states.index_select(0, beam_idx.to(self.device))
+
+    def crop(self, max_length: int):
+        # We don't crop the linear attention cache, so simply do nothing here
+        pass
+
+
+class LinearAttentionLayer(LinearAttentionCacheLayerMixin):
+    def lazy_initialization(
+        self, conv_states: torch.Tensor | None = None, recurrent_states: torch.Tensor | None = None
+    ) -> None:
+        # Here, we will lazy init both states separately, each in their own update function
+        if conv_states is not None:
+            self.dtype, self.device = conv_states.dtype, conv_states.device
+            # Even if the prefill is larger/shorter than the conv kernel size, the tensor is always either padded or truncated
+            self.max_batch_size, self.conv_kernel_size = conv_states.shape[0], conv_states.shape[-1]
+            # The shape is always static, so we init as such
+            self.conv_states = torch.zeros_like(conv_states, dtype=self.dtype, device=self.device)
+            # Mark as static address to be able to use cudagraphs
+            if not is_torchdynamo_compiling():
+                torch._dynamo.mark_static_address(self.conv_states)
+            self.is_conv_states_initialized = True
+        if recurrent_states is not None:
+            # The shape is always static, so we init as such
+            self.recurrent_states = torch.zeros_like(recurrent_states, dtype=self.dtype, device=self.device)
+            # Mark as static address to be able to use cudagraphs
+            if not is_torchdynamo_compiling():
+                torch._dynamo.mark_static_address(self.recurrent_states)
+            self.is_recurrent_states_initialized = True
+
+    def update_conv_state(self, conv_states: torch.Tensor, **kwargs) -> torch.Tensor:
+        """
+        Update the linear attention cache in-place, and return the necessary conv states.
+
+        Args:
+            conv_states (`torch.Tensor`): The new conv states to cache.
+
+        Returns:
+            `torch.Tensor`: The updated conv states.
+        """
+        # Lazy initialization
+        if not self.is_conv_states_initialized:
+            self.lazy_initialization(conv_states=conv_states)
+
+        if not self.has_previous_state:
+            # Note that we copy instead of assigning, to preserve the static address for cudagraphs
+            self.conv_states.copy_(conv_states)
+            self.has_previous_state = True
+            # Technically, this update is not logically correct if the prefill is smaller than `conv_kernel_size`,
+            # as it will `roll` anyway in the first decoding step, even though it should `roll` ONLY if the cache is already full.
+            # But since `conv_kernel_size=4` in practice, it's almost impossible to have a smaller prefill so it's mostly fine for now
+        else:
+            # Note that we copy instead of assigning, to preserve the static address for cudagraphs
+            num_new_tokens = conv_states.shape[-1]
+            if num_new_tokens >= self.conv_kernel_size:
+                self.conv_states.copy_(conv_states[..., -self.conv_kernel_size :])
+            else:
+                new_conv_states = self.conv_states.roll(shifts=-num_new_tokens, dims=-1)
+                new_conv_states[:, :, -num_new_tokens:] = conv_states
+                self.conv_states.copy_(new_conv_states)
+
+        return self.conv_states
+
+    def update_recurrent_state(self, recurrent_states: torch.Tensor, **kwargs) -> torch.Tensor:
+        """
+        Update the linear attention cache in-place, and return the necessary ssm states.
+
+        Args:
+            recurrent_states (`torch.Tensor`): The new ssm states to cache.
+
+        Returns:
+            `torch.Tensor`: The updated ssm states.
+        """
+        if not self.is_recurrent_states_initialized:
+            self.lazy_initialization(recurrent_states=recurrent_states)
+        # Note that we copy instead of assigning, to preserve the static address for cudagraphs
+        self.recurrent_states.copy_(recurrent_states)
+        return self.recurrent_states
+
+
+class LinearAttentionAndFullAttentionLayer(LinearAttentionLayer, DynamicLayer):
+    # The dynamic Attention part makes it non-compileable
+    is_compileable = False
+
+    def __init__(self):
+        DynamicLayer.__init__(self)
+        LinearAttentionLayer.__init__(self)
+
+    def lazy_initialization(self, *args, **kwargs) -> None:
+        # When the Attention cache is used with `update`, `lazy_initialization` is called with 2 positional args
+        if len(args) == 2 and len(kwargs) == 0:
+            DynamicLayer.lazy_initialization(self, *args)
+        # Otherwise, for the LinearAttention cache, when it's called in `update_conv_state` or `update_recurrent_state`, it's
+        # always called with 1 single kwarg (because it needs to know if it's for the conv or ssm states)
+        if len(args) == 0 and len(kwargs) == 1:
+            LinearAttentionLayer.lazy_initialization(self, **kwargs)
+
+    def reset(self) -> None:
+        LinearAttentionLayer.reset(self)
+        DynamicLayer.reset(self)
+
+    def reorder_cache(self, beam_idx: torch.LongTensor):
+        """Reorders the cache for beam search, given the selected beam indices."""
+        LinearAttentionLayer.reorder_cache(self, beam_idx)
+        DynamicLayer.reorder_cache(self, beam_idx)
+
+
 class Cache:
     """
     A `Cache` is mostly a list of `CacheLayerMixin` objects, one per model layer. It serves as a container for
@@ -676,9 +838,9 @@ class Cache:
 
     Args:
         layers (`Optional`, *optional*):
-            A list of pre-created `CacheLayerMixin`. If omitted (`None`), then `layer_class_to_replicate` will
-            be used.
-        layer_class_to_replicate (`type[CacheLayerMixin]`, *optional*):
+            A list of pre-created `CacheLayerMixin` or `LinearAttentionCacheLayerMixin`. If omitted (`None`), then `layer_class_to_replicate`
+            will be used.
+        layer_class_to_replicate (`type[CacheLayerMixin | LinearAttentionCacheLayerMixin]`, *optional*):
             Only used if `layers` is omitted (`None`), in which case it will be used as the base class for each
             layer, and the layers will be added lazily as soon as `update` is called with a `layer_idx` greater
             than the current list of layers.
@@ -691,8 +853,8 @@ class Cache:
 
     def __init__(
         self,
-        layers: list[CacheLayerMixin] | None = None,
-        layer_class_to_replicate: type[CacheLayerMixin] | None = None,
+        layers: list[CacheLayerMixin | LinearAttentionCacheLayerMixin] | None = None,
+        layer_class_to_replicate: type[CacheLayerMixin | LinearAttentionCacheLayerMixin] | None = None,
         offloading: bool = False,
         offload_only_non_sliding: bool = True,
     ):
@@ -779,6 +941,46 @@ def update(
 
         return keys, values
 
+    def update_conv_state(self, conv_states: torch.Tensor, layer_idx: int, **kwargs) -> torch.Tensor:
+        """
+        Updates the cache with the new `conv_states` for the layer `layer_idx`.
+
+        Parameters:
+            conv_states (`torch.Tensor`):
+                The new conv states to cache.
+            layer_idx (`int`):
+                The index of the layer to cache the states for.
+
+        Return:
+            `torch.Tensor`: The updated conv states.
+        """
+        # NOTE: if we slightly break `update` arg order, we could combine this with it, and allow offloading support
+        # out of the box
+        if not isinstance(self.layers[layer_idx], LinearAttentionCacheLayerMixin):
+            raise ValueError("Cannot call `update_conv_state` on a non-LinearAttention layer!")
+        conv_states = self.layers[layer_idx].update_conv_state(conv_states, **kwargs)
+        return conv_states
+
+    def update_recurrent_state(self, recurrent_states: torch.Tensor, layer_idx: int, **kwargs) -> torch.Tensor:
+        """
+        Updates the cache with the new `recurrent_states` for the layer `layer_idx`.
+
+        Parameters:
+            recurrent_states (`torch.Tensor`):
+                The new ssm states to cache.
+            layer_idx (`int`):
+                The index of the layer to cache the states for.
+
+        Return:
+            `torch.Tensor`: The updated ssm states.
+        """
+        # NOTE: if we slightly break `update` arg order, we could combine this with it, and allow offloading support
+        # out of the box
+        if not isinstance(self.layers[layer_idx], LinearAttentionCacheLayerMixin):
+            raise ValueError("Cannot call `update_recurrent_state` on a non-LinearAttention layer!")
+        recurrent_states = self.layers[layer_idx].update_recurrent_state(recurrent_states, **kwargs)
+        return recurrent_states
+
     def early_initialization(
         self, batch_size: int, num_heads: int, head_dim: int, dtype: torch.dtype, device: torch.device
     ):
@@ -798,8 +1000,52 @@ def get_seq_length(self, layer_idx: int = 0) -> int:
         """Returns the sequence length of the cache for the given layer."""
         if layer_idx >= len(self.layers):
             return 0
+
+        # For alternating attention/linear attention caches, `get_seq_length` needs to use attention layer idx when called with default layer_idx
+        if not isinstance(self.layers[layer_idx], CacheLayerMixin):
+            # If this is called with non-default arg, raise
+            if layer_idx != 0:
+                raise ValueError(
+                    f"You called `get_seq_length` on layer index {layer_idx}, but this layer is a LinearAttention layer, which "
+                    "does not track sequence length."
+                )
+            try:
+                # Use the first attention layer
+                layer_idx = next(idx for idx in range(len(self)) if isinstance(self.layers[idx], CacheLayerMixin))
+            except StopIteration:
+                raise ValueError(
+                    "`get_seq_length` can only be called on Attention layers, and the current Cache seems to only contain "
+                    "LinearAttention layers."
+                )
+
         return self.layers[layer_idx].get_seq_length()
 
+    def has_previous_state(self, layer_idx: int | None = None) -> bool:
+        """Returns whether the LinearAttention layer at index `layer_idx` has previous state or not."""
+        if layer_idx is not None and layer_idx >= len(self.layers):
+            return False
+
+        # In this case, use last LinearAttention layer
+        if layer_idx is None:
+            try:
+                layer_idx = next(
+                    idx
+                    for idx in range(len(self) - 1, -1, -1)
+                    if isinstance(self.layers[idx], LinearAttentionCacheLayerMixin)
+                )
+            except StopIteration:
+                raise ValueError(
+                    "`has_previous_state` can only be called on LinearAttention layers, and the current Cache seems to "
+                    "only contain Attention layers."
+                )
+        elif not isinstance(self.layers[layer_idx], LinearAttentionCacheLayerMixin):
+            raise ValueError(
+                f"You called `has_previous_state` on layer index {layer_idx}, but this layer is an Attention layer, which "
+                "does not support this method."
+            )
+
+        return self.layers[layer_idx].has_previous_state
+
     def get_mask_sizes(self, query_length: int, layer_idx: int) -> tuple[int, int]:
         """
         Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for
@@ -810,6 +1056,24 @@ def get_mask_sizes(self, query_length: int, layer_idx: int) -> tuple[int, int]:
         # simply the query_length
         if layer_idx >= len(self.layers):
            return query_length, 0
+
+        # For alternating attention/linear attention caches, `get_mask_sizes` needs to use attention layer idx when called with default layer_idx
+        if not isinstance(self.layers[layer_idx], CacheLayerMixin):
+            # If this is called with non-default arg, raise
+            if layer_idx != 0:
+                raise ValueError(
+                    f"You called `get_mask_sizes` on layer index {layer_idx}, but this layer is a LinearAttention layer, which "
+                    "does not track sequence length."
+                )
+            try:
+                # Use the first attention layer
+                layer_idx = next(idx for idx in range(len(self)) if isinstance(self.layers[idx], CacheLayerMixin))
+            except StopIteration:
+                raise ValueError(
+                    "`get_mask_sizes` can only be called on Attention layers, and the current Cache seems to only contain "
+                    "LinearAttention layers."
+                )
+
         return self.layers[layer_idx].get_mask_sizes(query_length)
 
     def get_max_cache_shape(self, layer_idx: int = 0) -> int:
@@ -945,10 +1209,12 @@ def __init__(
         )
         layer_types = getattr(decoder_config, "layer_types", None)
         if layer_types is None:
-            layer_types = [
-                "sliding_attention" if sliding_window is not None else "full_attention"
-                for _ in range(decoder_config.num_hidden_layers)
-            ]
+            layer_types = []
+            for _ in range(decoder_config.num_hidden_layers):
+                if sliding_window is not None:
+                    layer_types.append("sliding_attention")
+                else:
+                    layer_types.append("full_attention")
         # Some models have shared layers thus no cache is needed for them (e.g. Gemma3n)
         if hasattr(decoder_config, "num_kv_shared_layers"):
             layer_types = layer_types[: -decoder_config.num_kv_shared_layers]
@@ -958,6 +1224,14 @@
         # states they should return - only the mask changes to make them different at the end!
         if layer_type in ("sliding_attention", "chunked_attention"):
             layers.append(DynamicSlidingWindowLayer(sliding_window=sliding_window))
+        # Note: we want moe layers to be LinearAttentionLayer, so that we can correctly grab sequence length etc from attention layers.
+ # Since moe layers will stay empty (they don't need any cache), we don't want them to collide for mask creation etc + # TODO: maybe use a dummy layer in those cases, or a dictionary {idx: Layer} for self.layers, so that we can skip + # the indices we don't need + elif layer_type in ("mamba", "conv", "linear_attention", "moe"): + layers.append(LinearAttentionLayer()) + elif layer_type == "hybrid": + layers.append(LinearAttentionAndFullAttentionLayer()) else: layers.append(DynamicLayer()) @@ -1067,6 +1341,9 @@ def __init__( layer = StaticSlidingWindowLayer( max_cache_len=max_cache_len, sliding_window=config.attention_chunk_size ) + # LinearAttention layers are static by essence - using `"moe"` as well is a trick, see the comment about it on DynamicCache + elif layer_type in ("mamba", "conv", "linear_attention", "moe"): + layer = LinearAttentionLayer() else: layer = StaticLayer(max_cache_len=max_cache_len) layers.append(layer) diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index 5d3d3145ef00..97d6b94b57aa 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -68,6 +68,8 @@ "attention", "sparse", "dense", + "hybrid", # for layers that have both mamba and attention in zamba and zamba2 + "moe", # for nemotron_h, which uses either attention, mamba or moe ) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index e8902257bb3b..d07cf05b1ed7 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1775,19 +1775,19 @@ def _prepare_static_cache( def _supports_default_dynamic_cache(cls: type["GenerativePreTrainedModel"]) -> bool: """ Return `True` if current model can use a `DynamicCache` instance when initializing the `past_key_values`. - This adds exception for some models like `Mamba` models which use their own caches. """ # NOTE: remove xlnet/reformer when the models are deprecated, non-standard model architecture/cache name - return not cls._is_stateful and all( - special_model_name not in cls.__name__.lower() - or "minimaxm2" in cls.__name__.lower() # name clash between minimax and minimax m2 - for special_model_name in [ - "reformer", - "minimax", - "xlnet", - "lfm2", - "lfm2_vl", - ] + unsupported_model_names = ( + "reformer", + "minimax", + "xlnet", + "olmohybrid", # olmo_hybrid cannot use linear attention cache for now as it uses split k,q,v conv states + "rwkv", + "xlstm", + ) + # name clash between minimax and minimax m2, so we add this "or" + return "minimaxm2" in cls.__name__.lower() or all( + unsupported_name not in cls.__name__.lower() for unsupported_name in unsupported_model_names ) def _prepare_cache_for_generation( @@ -1849,7 +1849,12 @@ def _prepare_cache_for_generation( generation_config.cache_implementation = "dynamic_full" dynamic_cache_kwargs = {} - if generation_config.cache_implementation != "dynamic_full": + # linear attention models always need to pass the config, otherwise it will use an Attention cache for the LinearAttention layers + is_linear_attention = any( + x in ("mamba", "conv", "linear_attention") + for x in getattr(self.config.get_text_config(decoder=True), "layer_types", []) + ) + if generation_config.cache_implementation != "dynamic_full" or is_linear_attention: dynamic_cache_kwargs["config"] = self.config.get_text_config(decoder=True) if generation_config.cache_implementation == "offloaded": @@ -1862,7 +1867,7 @@ def _prepare_cache_for_generation( f"and will be removed in v5.13. 
Please only use one of {STATIC_CACHE_IMPLEMENTATIONS}, " "and the layer structure will be inferred automatically." ) - model_kwargs["past_key_values"] = self._prepare_static_cache( + model_kwargs[cache_name] = self._prepare_static_cache( cache_implementation=generation_config.cache_implementation, batch_size=max(generation_config.num_beams, generation_config.num_return_sequences) * batch_size, max_cache_len=max_cache_length, @@ -1878,19 +1883,19 @@ def _prepare_cache_for_generation( cache_config = generation_config.cache_config if generation_config.cache_config is not None else {} cache_config.setdefault("config", self.config.get_text_config(decoder=True)) backend = cache_config.pop("backend", "quanto") - model_kwargs["past_key_values"] = QuantizedCache(backend=backend, **cache_config) + model_kwargs[cache_name] = QuantizedCache(backend=backend, **cache_config) # i.e. `cache_implementation` in [None, "dynamic", "offloaded", "dynamic_full"] # TODO: prepare linear cache from a single API, instead of creating in modeling code else: - model_kwargs["past_key_values"] = DynamicCache(**dynamic_cache_kwargs) + model_kwargs[cache_name] = DynamicCache(**dynamic_cache_kwargs) if ( self.config.is_encoder_decoder - and "past_key_values" in model_kwargs - and not isinstance(model_kwargs["past_key_values"], EncoderDecoderCache) + and cache_name in model_kwargs + and not isinstance(model_kwargs[cache_name], EncoderDecoderCache) ): - model_kwargs["past_key_values"] = EncoderDecoderCache( - model_kwargs["past_key_values"], # self-attention cache + model_kwargs[cache_name] = EncoderDecoderCache( + model_kwargs[cache_name], # self-attention cache DynamicCache(**dynamic_cache_kwargs), # cross-attention cache ) @@ -1990,13 +1995,15 @@ def _valid_auto_compile_criteria( if generation_config.disable_compile: return False + cache = model_kwargs.get("past_key_values", model_kwargs.get("cache_params")) + # Base logic valid_hardware = self.device.type in ["cuda", "xpu"] or bool( generation_config.compile_config is not None and generation_config.compile_config._compile_all_devices ) - using_compilable_cache = ( - isinstance(model_kwargs.get("past_key_values"), Cache) and model_kwargs["past_key_values"].is_compileable - ) + # Note: for some models that only use linear attention (e.g. 
Mamba), even a DynamicCache is compileable since all + # layers are, but we don't want to ALWAYS compile when calling `generate`, so we check the type + using_compilable_cache = cache is not None and cache.is_compileable and type(cache) is not DynamicCache can_compile = valid_hardware and using_compilable_cache # Exception 1: Some quantization methods do not support compilation @@ -3467,10 +3474,9 @@ def _assisted_decoding( # The cache must be dynamic for assisted generation, and the check must happen AFTER preparing cache if not model_kwargs["use_cache"]: raise ValueError("assisted generate requires `use_cache=True`") - if generation_config.cache_implementation in ["static", "hybrid", "sliding_window"] or ( - "past_key_values" in model_kwargs - and hasattr(model_kwargs["past_key_values"], "layers") - and any(getattr(l, "is_compileable", False) for l in model_kwargs["past_key_values"].layers) + if ( + generation_config.cache_implementation in ["static", "hybrid", "sliding_window"] + or type(model_kwargs.get("past_key_values")) is StaticCache ): raise ValueError("assisted generate is not supported with Static cache classes`") # Get the candidate generator, given the parameterization diff --git a/src/transformers/models/bamba/configuration_bamba.py b/src/transformers/models/bamba/configuration_bamba.py index 316c7deedcb1..57da7cbd4e64 100644 --- a/src/transformers/models/bamba/configuration_bamba.py +++ b/src/transformers/models/bamba/configuration_bamba.py @@ -43,6 +43,7 @@ class BambaConfig(PreTrainedConfig): """ model_type = "bamba" + attribute_map = {"layer_types": "layers_block_type"} keys_to_ignore_at_inference = ["past_key_values"] vocab_size: int = 128000 diff --git a/src/transformers/models/bamba/modeling_bamba.py b/src/transformers/models/bamba/modeling_bamba.py index 8888fa5ddbdf..90129fc998b1 100644 --- a/src/transformers/models/bamba/modeling_bamba.py +++ b/src/transformers/models/bamba/modeling_bamba.py @@ -24,15 +24,14 @@ # limitations under the License. from collections.abc import Callable -from typing import Any, Optional, TypedDict +from typing import Optional, TypedDict import torch from torch import nn -from transformers.activations import ACT2FN - from ... import initialization as init -from ...cache_utils import Cache +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub, use_kernelized_func from ...integrations.hub_kernels import lazy_load_kernel @@ -76,112 +75,6 @@ class BambaFlashAttentionKwargs(TypedDict, total=False): seq_idx: torch.IntTensor -class HybridMambaAttentionDynamicCache: - """ - A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache - (which has a constant shape regardless of seq_len). - - This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states` - and `ssm_states` for mamba cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor - For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`, - while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors). 
- For mamba layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors), - while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`, - and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`. - """ - - is_compileable = False - - def __init__(self, config: BambaConfig, batch_size, dtype=torch.float16, device=None): - self.layers_block_type = config.layers_block_type - self.has_previous_state = False # only used by mamba - conv_kernel_size = config.mamba_d_conv - ssm_state_size = config.mamba_d_state - - self.conv_states = [] - self.ssm_states = [] - self.transformer_layers = [] - for i in range(config.num_hidden_layers): - if self.layers_block_type[i] == "mamba": - self.conv_states += [ - torch.zeros( - batch_size, - (config.mamba_expand * config.hidden_size + 2 * config.mamba_n_groups * ssm_state_size), - conv_kernel_size, - device=device, - dtype=dtype, - ) - ] - self.ssm_states += [ - torch.zeros( - batch_size, - config.mamba_n_heads, - config.mamba_d_head, - ssm_state_size, - device=device, - dtype=dtype, - ) - ] - else: - self.conv_states += [torch.tensor([[]] * batch_size, device=device)] - self.ssm_states += [torch.tensor([[]] * batch_size, device=device)] - self.transformer_layers.append(i) - - self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] - self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] - - def __len__(self): - return len(self.key_cache) - - def __getitem__(self, layer_idx): - return self.key_cache[layer_idx], self.value_cache[layer_idx] - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: dict[str, Any] | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - # Update the cache - if self.key_cache[layer_idx].shape[-1] == 0: - self.key_cache[layer_idx] = key_states - self.value_cache[layer_idx] = value_states - else: - self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2) - self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2) - - return self.key_cache[layer_idx], self.value_cache[layer_idx] - - def reorder_cache(self, beam_idx: torch.LongTensor): - """Reorders the cache for beam search, given the selected beam indices.""" - if self.get_seq_length() > 0: - for layer_idx in range(len(self.key_cache)): - device = self.key_cache[layer_idx].device - self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) - device = self.value_cache[layer_idx].device - self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) - - device = self.conv_states[layer_idx].device - self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx.to(device)) - device = self.ssm_states[layer_idx].device - self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device)) - - def get_mask_sizes(self, query_length: int, layer_idx: int) -> tuple[int, int]: - """Return the length and offset of the cache, used to generate the mask""" - kv_offset = 0 - kv_length = self.get_seq_length(layer_idx) + query_length - return kv_length, kv_offset - - def get_seq_length(self, layer_idx: int | None = 0) -> int: - """Returns the sequence length of the cached states. 
A layer index can be optionally passed.""" - # take any layer that contains cache and not empty tensor - layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx - if len(self.key_cache) <= layer_idx or self.key_cache[layer_idx].shape[-1] == 0: - return 0 - return self.key_cache[layer_idx].shape[-2] - - class BambaRotaryEmbedding(nn.Module): inv_freq: torch.Tensor # fix linting for `register_buffer` @@ -592,7 +485,7 @@ def __init__(self, config: BambaConfig, layer_idx: int): def cuda_kernels_forward( self, hidden_states: torch.Tensor, - cache_params: HybridMambaAttentionDynamicCache | None = None, + cache_params: Cache | None = None, attention_mask: torch.Tensor | None = None, seq_idx: torch.IntTensor | None = None, ): @@ -605,12 +498,7 @@ def cuda_kernels_forward( groups_time_state_size = self.n_groups * self.ssm_state_size use_precomputed_states = ( - cache_params is not None - and cache_params.has_previous_state - and seq_len == 1 - and cache_params.conv_states[self.layer_idx].shape[0] - == cache_params.ssm_states[self.layer_idx].shape[0] - == batch_size + cache_params is not None and cache_params.has_previous_state(self.layer_idx) and seq_len == 1 ) # getting projected states from cache if it exists @@ -622,7 +510,7 @@ def cuda_kernels_forward( # 2. Convolution sequence transformation hidden_states_B_C = causal_conv1d_update( hidden_states_B_C, - cache_params.conv_states[self.layer_idx], + cache_params.layers[self.layer_idx].conv_states, self.conv1d.weight.squeeze(1), self.conv1d.bias, self.activation, @@ -644,7 +532,7 @@ def cuda_kernels_forward( C = C.view(batch_size, self.n_groups, C.shape[1] // self.n_groups) hidden_states_reshaped = hidden_states.view(batch_size, self.num_heads, self.head_dim) hidden_states = selective_state_update( - cache_params.ssm_states[self.layer_idx], + cache_params.layers[self.layer_idx].recurrent_states, hidden_states_reshaped, dt, A, @@ -704,7 +592,7 @@ def cuda_kernels_forward( hidden_states_B_C_transposed, (self.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0), ) - cache_params.conv_states[self.layer_idx].copy_(conv_states) + conv_states = cache_params.update_conv_state(conv_states, self.layer_idx) if self.activation not in ["silu", "swish"]: hidden_states_B_C = self.act( @@ -745,7 +633,7 @@ def cuda_kernels_forward( # Init cache if ssm_state is not None and cache_params is not None: - cache_params.ssm_states[self.layer_idx].copy_(ssm_state) + ssm_state = cache_params.update_recurrent_state(ssm_state, self.layer_idx) scan_output = scan_output.view(batch_size, seq_len, -1) # Multiply "gate" branch and apply extra normalization layer @@ -759,7 +647,7 @@ def cuda_kernels_forward( def torch_forward( self, input_states, - cache_params: HybridMambaAttentionDynamicCache | None = None, + cache_params: Cache | None = None, attention_mask: torch.Tensor | None = None, ): batch_size, seq_len, _ = input_states.shape @@ -771,23 +659,13 @@ def torch_forward( gate, hidden_states_B_C, dt = projected_states.split( [self.intermediate_size, self.conv_dim, self.num_heads], dim=-1 ) + hidden_states_B_C = hidden_states_B_C.transpose(1,2) - use_precomputed_states = ( - cache_params is not None - and cache_params.has_previous_state - and seq_len == 1 - and cache_params.conv_states[self.layer_idx].shape[0] - == cache_params.ssm_states[self.layer_idx].shape[0] - == batch_size - ) + use_precomputed_states = cache_params is not None and cache_params.has_previous_state(self.layer_idx) and seq_len == 1 # 2. 
Convolution sequence transformation if use_precomputed_states: - cache_params.conv_states[self.layer_idx] = cache_params.conv_states[self.layer_idx].roll(shifts=-1, dims=-1) - cache_params.conv_states[self.layer_idx][:, :, -1] = hidden_states_B_C[:, 0, :].to(cache_params.conv_states[self.layer_idx].device) - - # We need to guarantee that anything regarding the cache is on the same device - conv_states = cache_params.conv_states[self.layer_idx].to(device=self.conv1d.weight.device) + conv_states = cache_params.update_conv_state(hidden_states_B_C, self.layer_idx) hidden_states_B_C = torch.sum( conv_states * self.conv1d.weight.squeeze(1), dim=-1 @@ -798,13 +676,12 @@ def torch_forward( else: # Init cache if cache_params is not None: - hidden_states_B_C_transposed = hidden_states_B_C.transpose(1, 2) conv_states = nn.functional.pad( - hidden_states_B_C_transposed, (self.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0) + hidden_states_B_C, (self.conv_kernel_size - hidden_states_B_C.shape[-1], 0) ) - cache_params.conv_states[self.layer_idx].copy_(conv_states) + conv_states = cache_params.update_conv_state(conv_states, self.layer_idx) - hidden_states_B_C = self.act(self.conv1d(hidden_states_B_C.transpose(1, 2))[..., :seq_len].transpose(1, 2)) + hidden_states_B_C = self.act(self.conv1d(hidden_states_B_C)[..., :seq_len].transpose(1, 2)) hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask) hidden_states, B, C = torch.split( @@ -817,7 +694,7 @@ def torch_forward( A = -torch.exp(self.A_log.float()) # [num_heads] if use_precomputed_states: # We need to guarantee that anything regarding the cache is on the same device - cache_device = cache_params.ssm_states[self.layer_idx].device + cache_device = cache_params.layers[self.layer_idx].recurrent_states.device # Note: there is no need to pad parameter matrices here, as there is just one new token # for batched generation @@ -847,9 +724,8 @@ def torch_forward( dBx = (dB * hidden_states[..., None]).to(device=cache_device) # State calculation - cache_params.ssm_states[self.layer_idx].copy_( - cache_params.ssm_states[self.layer_idx] * dA + dBx - ) + ssm_states = cache_params.layers[self.layer_idx].recurrent_states * dA + dBx + ssm_states = cache_params.update_recurrent_state(ssm_states, self.layer_idx) # Subsequent output # [bsz, n_groups * state_size] -> [bsz, num_heads, state_size] @@ -858,7 +734,7 @@ def torch_forward( C = C.reshape(batch_size, -1, C.shape[-1]) # [bsz, num_heads, head_dim] - ssm_states = cache_params.ssm_states[self.layer_idx].to(device=C.device, dtype=C.dtype) # Shape: [b, h, d, n] + ssm_states = ssm_states.to(device=C.device, dtype=C.dtype) # Shape: [b, h, d, n] # Reshape ssm_states to merge the first two dimensions ssm_states_reshaped = ssm_states.view(batch_size * self.num_heads, self.head_dim, self.ssm_state_size) # Shape: [b*h, d, n] C_reshaped = C.view(batch_size * self.num_heads, self.ssm_state_size, 1) # Shape: [b*h, n, 1] @@ -919,10 +795,7 @@ def torch_forward( # 3. 
Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries # (middle term of factorization of off-diag blocks; A terms) - if use_precomputed_states: - previous_states = cache_params.ssm_states[self.layer_idx][:, None, ...].to(device=states.device) - else: - previous_states = torch.zeros_like(states[:, :1]) + previous_states = torch.zeros_like(states[:, :1]) states = torch.cat([previous_states, states], dim=1) decay_chunk = torch.exp(segment_sum(nn.functional.pad(A_cumsum[:, :, :, -1], (1, 0)))) decay_chunk = decay_chunk.transpose(1, 3) @@ -949,7 +822,7 @@ def torch_forward( # Init cache if ssm_state is not None and cache_params is not None: - cache_params.ssm_states[self.layer_idx].copy_(ssm_state) + ssm_state = cache_params.update_recurrent_state(ssm_state, self.layer_idx) scan_output = self.norm(y, gate) @@ -963,7 +836,7 @@ def torch_forward( def forward( self, hidden_states, - cache_params: HybridMambaAttentionDynamicCache | None = None, + cache_params: Cache | None = None, attention_mask: torch.Tensor | None = None, seq_idx: torch.IntTensor | None = None, **kwargs, @@ -1042,7 +915,7 @@ def forward( hidden_states: torch.Tensor, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, - past_key_values: HybridMambaAttentionDynamicCache | None = None, + past_key_values: Cache | None = None, use_cache: bool | None = False, position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, **kwargs: Unpack[BambaFlashAttentionKwargs], @@ -1133,7 +1006,7 @@ def forward( input_ids: torch.LongTensor | None = None, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, - past_key_values: HybridMambaAttentionDynamicCache | None = None, + past_key_values: Cache | None = None, inputs_embeds: torch.FloatTensor | None = None, use_cache: bool | None = None, **kwargs: Unpack[BambaFlashAttentionKwargs], @@ -1146,10 +1019,7 @@ def forward( hidden_states = inputs_embeds if use_cache and past_key_values is None: - logger.warning_once( - "Bamba requires an initialized `HybridMambaAttentionDynamicCache` to return a cache. None was " - "provided, so no cache will be returned." - ) + past_key_values = DynamicCache(config=self.config) if position_ids is None: position_ids = torch.arange(hidden_states.shape[1], device=hidden_states.device).unsqueeze(0) @@ -1179,14 +1049,9 @@ def forward( hidden_states = self.final_layernorm(hidden_states) - if past_key_values and not past_key_values.has_previous_state: - past_key_values.has_previous_state = True - - next_cache = None if not use_cache else past_key_values - return BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, ) def _update_mamba_mask(self, attention_mask, past_key_values): @@ -1196,7 +1061,7 @@ def _update_mamba_mask(self, attention_mask, past_key_values): 2. 
Attending to all inputs """ mamba_mask = attention_mask - if (past_key_values is not None and past_key_values.has_previous_state) or ( + if (past_key_values is not None and past_key_values.has_previous_state()) or ( attention_mask is not None and torch.all(attention_mask == 1) ): mamba_mask = None @@ -1226,7 +1091,7 @@ def forward( input_ids: torch.LongTensor | None = None, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, - past_key_values: HybridMambaAttentionDynamicCache | None = None, + past_key_values: Cache | None = None, inputs_embeds: torch.FloatTensor | None = None, labels: torch.LongTensor | None = None, use_cache: bool | None = None, @@ -1295,13 +1160,6 @@ def prepare_inputs_for_generation( is_first_iteration=False, **kwargs, ): - # Overwritten -- has a unique cache type, `HybridMambaAttentionDynamicCache` - - if past_key_values is None: - past_key_values = HybridMambaAttentionDynamicCache( - self.config, input_ids.shape[0], self.dtype, device=self.device - ) - kwargs["logits_to_keep"] = self.config.num_logits_to_keep model_inputs = super().prepare_inputs_for_generation( input_ids, diff --git a/src/transformers/models/bamba/modular_bamba.py b/src/transformers/models/bamba/modular_bamba.py index 4961025f1743..a79b26ff8fe9 100644 --- a/src/transformers/models/bamba/modular_bamba.py +++ b/src/transformers/models/bamba/modular_bamba.py @@ -23,9 +23,20 @@ import torch from torch import nn -from transformers.activations import ACT2FN -from transformers.models.jamba.modeling_jamba import HybridMambaAttentionDynamicCache, JambaAttentionDecoderLayer -from transformers.models.llama.modeling_llama import ( +from ... import initialization as init +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache +from ...integrations.hub_kernels import lazy_load_kernel +from ...masking_utils import create_causal_mask +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_utils import PreTrainedModel +from ...processing_utils import Unpack +from ...utils import auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging +from ...utils.generic import merge_with_config_defaults +from ...utils.import_utils import resolve_internal_import +from ...utils.output_capturing import capture_outputs +from ..jamba.modeling_jamba import JambaAttentionDecoderLayer +from ..llama.modeling_llama import ( LlamaAttention, LlamaForCausalLM, LlamaMLP, @@ -33,24 +44,13 @@ LlamaRotaryEmbedding, rotate_half, ) -from transformers.models.mamba2.modeling_mamba2 import ( +from ..mamba2.modeling_mamba2 import ( MambaRMSNormGated, apply_mask_to_padding_states, pad_tensor_by_size, reshape_into_chunks, segment_sum, ) - -from ... 
import initialization as init -from ...integrations.hub_kernels import lazy_load_kernel -from ...masking_utils import create_causal_mask -from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast -from ...modeling_utils import PreTrainedModel -from ...processing_utils import Unpack -from ...utils import auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging -from ...utils.generic import merge_with_config_defaults -from ...utils.import_utils import resolve_internal_import -from ...utils.output_capturing import capture_outputs from .configuration_bamba import BambaConfig @@ -81,60 +81,6 @@ class BambaFlashAttentionKwargs(TypedDict, total=False): seq_idx: torch.IntTensor -# Adapted from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache for the v2 mixer -class HybridMambaAttentionDynamicCache(HybridMambaAttentionDynamicCache): - """ - A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache - (which has a constant shape regardless of seq_len). - - This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states` - and `ssm_states` for mamba cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor - For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`, - while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors). - For mamba layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors), - while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`, - and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`. 
- """ - - def __init__(self, config: BambaConfig, batch_size, dtype=torch.float16, device=None): - self.layers_block_type = config.layers_block_type - self.has_previous_state = False # only used by mamba - conv_kernel_size = config.mamba_d_conv - ssm_state_size = config.mamba_d_state - - self.conv_states = [] - self.ssm_states = [] - self.transformer_layers = [] - for i in range(config.num_hidden_layers): - if self.layers_block_type[i] == "mamba": - self.conv_states += [ - torch.zeros( - batch_size, - (config.mamba_expand * config.hidden_size + 2 * config.mamba_n_groups * ssm_state_size), - conv_kernel_size, - device=device, - dtype=dtype, - ) - ] - self.ssm_states += [ - torch.zeros( - batch_size, - config.mamba_n_heads, - config.mamba_d_head, - ssm_state_size, - device=device, - dtype=dtype, - ) - ] - else: - self.conv_states += [torch.tensor([[]] * batch_size, device=device)] - self.ssm_states += [torch.tensor([[]] * batch_size, device=device)] - self.transformer_layers.append(i) - - self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] - self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] - - class BambaRotaryEmbedding(LlamaRotaryEmbedding): pass @@ -296,7 +242,7 @@ def __init__(self, config: BambaConfig, layer_idx: int): def cuda_kernels_forward( self, hidden_states: torch.Tensor, - cache_params: HybridMambaAttentionDynamicCache | None = None, + cache_params: Cache | None = None, attention_mask: torch.Tensor | None = None, seq_idx: torch.IntTensor | None = None, ): @@ -309,12 +255,7 @@ def cuda_kernels_forward( groups_time_state_size = self.n_groups * self.ssm_state_size use_precomputed_states = ( - cache_params is not None - and cache_params.has_previous_state - and seq_len == 1 - and cache_params.conv_states[self.layer_idx].shape[0] - == cache_params.ssm_states[self.layer_idx].shape[0] - == batch_size + cache_params is not None and cache_params.has_previous_state(self.layer_idx) and seq_len == 1 ) # getting projected states from cache if it exists @@ -326,7 +267,7 @@ def cuda_kernels_forward( # 2. 
Convolution sequence transformation hidden_states_B_C = causal_conv1d_update( hidden_states_B_C, - cache_params.conv_states[self.layer_idx], + cache_params.layers[self.layer_idx].conv_states, self.conv1d.weight.squeeze(1), self.conv1d.bias, self.activation, @@ -348,7 +289,7 @@ def cuda_kernels_forward( C = C.view(batch_size, self.n_groups, C.shape[1] // self.n_groups) hidden_states_reshaped = hidden_states.view(batch_size, self.num_heads, self.head_dim) hidden_states = selective_state_update( - cache_params.ssm_states[self.layer_idx], + cache_params.layers[self.layer_idx].recurrent_states, hidden_states_reshaped, dt, A, @@ -408,7 +349,7 @@ def cuda_kernels_forward( hidden_states_B_C_transposed, (self.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0), ) - cache_params.conv_states[self.layer_idx].copy_(conv_states) + conv_states = cache_params.update_conv_state(conv_states, self.layer_idx) if self.activation not in ["silu", "swish"]: hidden_states_B_C = self.act( @@ -449,7 +390,7 @@ def cuda_kernels_forward( # Init cache if ssm_state is not None and cache_params is not None: - cache_params.ssm_states[self.layer_idx].copy_(ssm_state) + ssm_state = cache_params.update_recurrent_state(ssm_state, self.layer_idx) scan_output = scan_output.view(batch_size, seq_len, -1) # Multiply "gate" branch and apply extra normalization layer @@ -463,7 +404,7 @@ def cuda_kernels_forward( def torch_forward( self, input_states, - cache_params: HybridMambaAttentionDynamicCache | None = None, + cache_params: Cache | None = None, attention_mask: torch.Tensor | None = None, ): batch_size, seq_len, _ = input_states.shape @@ -475,23 +416,13 @@ def torch_forward( gate, hidden_states_B_C, dt = projected_states.split( [self.intermediate_size, self.conv_dim, self.num_heads], dim=-1 ) + hidden_states_B_C = hidden_states_B_C.transpose(1,2) - use_precomputed_states = ( - cache_params is not None - and cache_params.has_previous_state - and seq_len == 1 - and cache_params.conv_states[self.layer_idx].shape[0] - == cache_params.ssm_states[self.layer_idx].shape[0] - == batch_size - ) + use_precomputed_states = cache_params is not None and cache_params.has_previous_state(self.layer_idx) and seq_len == 1 # 2. 
Convolution sequence transformation if use_precomputed_states: - cache_params.conv_states[self.layer_idx] = cache_params.conv_states[self.layer_idx].roll(shifts=-1, dims=-1) - cache_params.conv_states[self.layer_idx][:, :, -1] = hidden_states_B_C[:, 0, :].to(cache_params.conv_states[self.layer_idx].device) - - # We need to guarantee that anything regarding the cache is on the same device - conv_states = cache_params.conv_states[self.layer_idx].to(device=self.conv1d.weight.device) + conv_states = cache_params.update_conv_state(hidden_states_B_C, self.layer_idx) hidden_states_B_C = torch.sum( conv_states * self.conv1d.weight.squeeze(1), dim=-1 @@ -502,13 +433,12 @@ def torch_forward( else: # Init cache if cache_params is not None: - hidden_states_B_C_transposed = hidden_states_B_C.transpose(1, 2) conv_states = nn.functional.pad( - hidden_states_B_C_transposed, (self.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0) + hidden_states_B_C, (self.conv_kernel_size - hidden_states_B_C.shape[-1], 0) ) - cache_params.conv_states[self.layer_idx].copy_(conv_states) + conv_states = cache_params.update_conv_state(conv_states, self.layer_idx) - hidden_states_B_C = self.act(self.conv1d(hidden_states_B_C.transpose(1, 2))[..., :seq_len].transpose(1, 2)) + hidden_states_B_C = self.act(self.conv1d(hidden_states_B_C)[..., :seq_len].transpose(1, 2)) hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask) hidden_states, B, C = torch.split( @@ -521,7 +451,7 @@ def torch_forward( A = -torch.exp(self.A_log.float()) # [num_heads] if use_precomputed_states: # We need to guarantee that anything regarding the cache is on the same device - cache_device = cache_params.ssm_states[self.layer_idx].device + cache_device = cache_params.layers[self.layer_idx].recurrent_states.device # Note: there is no need to pad parameter matrices here, as there is just one new token # for batched generation @@ -551,9 +481,8 @@ def torch_forward( dBx = (dB * hidden_states[..., None]).to(device=cache_device) # State calculation - cache_params.ssm_states[self.layer_idx].copy_( - cache_params.ssm_states[self.layer_idx] * dA + dBx - ) + ssm_states = cache_params.layers[self.layer_idx].recurrent_states * dA + dBx + ssm_states = cache_params.update_recurrent_state(ssm_states, self.layer_idx) # Subsequent output # [bsz, n_groups * state_size] -> [bsz, num_heads, state_size] @@ -562,7 +491,7 @@ def torch_forward( C = C.reshape(batch_size, -1, C.shape[-1]) # [bsz, num_heads, head_dim] - ssm_states = cache_params.ssm_states[self.layer_idx].to(device=C.device, dtype=C.dtype) # Shape: [b, h, d, n] + ssm_states = ssm_states.to(device=C.device, dtype=C.dtype) # Shape: [b, h, d, n] # Reshape ssm_states to merge the first two dimensions ssm_states_reshaped = ssm_states.view(batch_size * self.num_heads, self.head_dim, self.ssm_state_size) # Shape: [b*h, d, n] C_reshaped = C.view(batch_size * self.num_heads, self.ssm_state_size, 1) # Shape: [b*h, n, 1] @@ -623,10 +552,7 @@ def torch_forward( # 3. 
Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries # (middle term of factorization of off-diag blocks; A terms) - if use_precomputed_states: - previous_states = cache_params.ssm_states[self.layer_idx][:, None, ...].to(device=states.device) - else: - previous_states = torch.zeros_like(states[:, :1]) + previous_states = torch.zeros_like(states[:, :1]) states = torch.cat([previous_states, states], dim=1) decay_chunk = torch.exp(segment_sum(nn.functional.pad(A_cumsum[:, :, :, -1], (1, 0)))) decay_chunk = decay_chunk.transpose(1, 3) @@ -653,7 +579,7 @@ def torch_forward( # Init cache if ssm_state is not None and cache_params is not None: - cache_params.ssm_states[self.layer_idx].copy_(ssm_state) + ssm_state = cache_params.update_recurrent_state(ssm_state, self.layer_idx) scan_output = self.norm(y, gate) @@ -667,7 +593,7 @@ def torch_forward( def forward( self, hidden_states, - cache_params: HybridMambaAttentionDynamicCache | None = None, + cache_params: Cache | None = None, attention_mask: torch.Tensor | None = None, seq_idx: torch.IntTensor | None = None, **kwargs, @@ -717,7 +643,7 @@ def forward( hidden_states: torch.Tensor, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, - past_key_values: HybridMambaAttentionDynamicCache | None = None, + past_key_values: Cache | None = None, use_cache: bool | None = False, position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, **kwargs: Unpack[BambaFlashAttentionKwargs], @@ -808,7 +734,7 @@ def forward( input_ids: torch.LongTensor | None = None, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, - past_key_values: HybridMambaAttentionDynamicCache | None = None, + past_key_values: Cache | None = None, inputs_embeds: torch.FloatTensor | None = None, use_cache: bool | None = None, **kwargs: Unpack[BambaFlashAttentionKwargs], @@ -821,10 +747,7 @@ def forward( hidden_states = inputs_embeds if use_cache and past_key_values is None: - logger.warning_once( - "Bamba requires an initialized `HybridMambaAttentionDynamicCache` to return a cache. None was " - "provided, so no cache will be returned." - ) + past_key_values = DynamicCache(config=self.config) if position_ids is None: position_ids = torch.arange(hidden_states.shape[1], device=hidden_states.device).unsqueeze(0) @@ -854,14 +777,9 @@ def forward( hidden_states = self.final_layernorm(hidden_states) - if past_key_values and not past_key_values.has_previous_state: - past_key_values.has_previous_state = True - - next_cache = None if not use_cache else past_key_values - return BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, ) def _update_mamba_mask(self, attention_mask, past_key_values): @@ -871,7 +789,7 @@ def _update_mamba_mask(self, attention_mask, past_key_values): 2. 
Attending to all inputs """ mamba_mask = attention_mask - if (past_key_values is not None and past_key_values.has_previous_state) or ( + if (past_key_values is not None and past_key_values.has_previous_state()) or ( attention_mask is not None and torch.all(attention_mask == 1) ): mamba_mask = None @@ -893,7 +811,7 @@ def forward( input_ids: torch.LongTensor | None = None, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, - past_key_values: HybridMambaAttentionDynamicCache | None = None, + past_key_values: Cache | None = None, inputs_embeds: torch.FloatTensor | None = None, labels: torch.LongTensor | None = None, use_cache: bool | None = None, @@ -962,13 +880,6 @@ def prepare_inputs_for_generation( is_first_iteration=False, **kwargs, ): - # Overwritten -- has a unique cache type, `HybridMambaAttentionDynamicCache` - - if past_key_values is None: - past_key_values = HybridMambaAttentionDynamicCache( - self.config, input_ids.shape[0], self.dtype, device=self.device - ) - kwargs["logits_to_keep"] = self.config.num_logits_to_keep model_inputs = super().prepare_inputs_for_generation( input_ids, diff --git a/src/transformers/models/falcon_h1/configuration_falcon_h1.py b/src/transformers/models/falcon_h1/configuration_falcon_h1.py index 980bd9376101..4e4be01d9715 100644 --- a/src/transformers/models/falcon_h1/configuration_falcon_h1.py +++ b/src/transformers/models/falcon_h1/configuration_falcon_h1.py @@ -54,6 +54,7 @@ class FalconH1Config(PreTrainedConfig): """ model_type = "falcon_h1" + attribute_map = {"layer_types": "layers_block_type"} keys_to_ignore_at_inference = ["past_key_values"] vocab_size: int = 128000 @@ -132,7 +133,7 @@ def validate_architecture(self): @property def layers_block_type(self): - return ["attention" for i in range(self.num_hidden_layers)] + return ["hybrid" for i in range(self.num_hidden_layers)] __all__ = ["FalconH1Config"] diff --git a/src/transformers/models/falcon_h1/modeling_falcon_h1.py b/src/transformers/models/falcon_h1/modeling_falcon_h1.py index 237c04c8d28d..37b5da9df4b3 100644 --- a/src/transformers/models/falcon_h1/modeling_falcon_h1.py +++ b/src/transformers/models/falcon_h1/modeling_falcon_h1.py @@ -24,16 +24,15 @@ # limitations under the License. from collections.abc import Callable -from typing import Any, Optional +from typing import Optional import torch import torch.nn.functional as F from torch import nn -from transformers.activations import ACT2FN - from ... import initialization as init -from ...cache_utils import Cache +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func from ...integrations.hub_kernels import lazy_load_kernel @@ -54,161 +53,6 @@ logger = logging.get_logger(__name__) -class FalconHybridMambaAttentionDynamicCache: - """ - A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache - (which has a constant shape regardless of seq_len). - - This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states` - and `ssm_states` for mamba cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor - For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`, - while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors). 
- For mamba layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors), - while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`, - and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`. - """ - - is_compileable = False - - def __init__( - self, - config: FalconH1Config, - batch_size: int, - dtype: torch.dtype = torch.float16, - devices: list[str] | None = None, - ): - self.seqlen_offset = 0 - self.dtype = dtype - self.has_previous_state = False - self.conv_kernel_size = config.mamba_d_conv - - self.intermediate_size = ( - config.mamba_d_ssm if config.mamba_d_ssm is not None else int(config.mamba_expand * config.hidden_size) - ) - - self.conv_states = { - i: torch.zeros( - batch_size, - self.intermediate_size + 2 * config.mamba_n_groups * config.mamba_d_state, - self.conv_kernel_size, - device=devices[i], - dtype=dtype, - ) - for i in range(config.num_hidden_layers) - } - self.ssm_states = { - i: torch.zeros( - batch_size, - config.mamba_n_heads, - config.mamba_d_head, - config.mamba_d_state, - device=devices[i], - dtype=dtype, - ) - for i in range(config.num_hidden_layers) - } - - self.transformer_layers = [] - for i in range(config.num_hidden_layers): - self.transformer_layers.append(i) - - self.key_cache: list[torch.Tensor] = [] - self.value_cache: list[torch.Tensor] = [] - - def __len__(self): - return len(self.key_cache) - - def __getitem__(self, layer_idx): - return self.key_cache[layer_idx], self.value_cache[layer_idx] - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: dict[str, Any] | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. - - Parameters: - key_states (`torch.Tensor`): - The new key states to cache. - value_states (`torch.Tensor`): - The new value states to cache. - layer_idx (`int`): - The index of the layer to cache the states for. - cache_kwargs (`dict[str, Any]`, `optional`): - Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`. - - Return: - A tuple containing the updated key and value states. 
- """ - # Update the cache - if len(self.key_cache) <= layer_idx: - # There may be skipped layers, fill them with empty lists - for _ in range(len(self.key_cache), layer_idx): - self.key_cache.append([]) - self.value_cache.append([]) - self.key_cache.append(key_states) - self.value_cache.append(value_states) - elif len(self.key_cache[layer_idx]) == 0: # fills previously skipped layers; checking for tensor causes errors - self.key_cache[layer_idx] = key_states - self.value_cache[layer_idx] = value_states - else: - self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) - self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) - - return self.key_cache[layer_idx], self.value_cache[layer_idx] - - def reorder_cache(self, beam_idx: torch.LongTensor): - """Reorders the cache for beam search, given the selected beam indices.""" - if self.get_seq_length() > 0: - for layer_idx in range(len(self.key_cache)): - device = self.key_cache[layer_idx].device - self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) - device = self.value_cache[layer_idx].device - self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) - - device = self.conv_states[layer_idx].device - self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx.to(device)) - device = self.ssm_states[layer_idx].device - self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device)) - - def get_mask_sizes(self, query_length: int, layer_idx: int) -> tuple[int, int]: - """Return the length and offset of the cache, used to generate the mask""" - kv_offset = 0 - kv_length = self.get_seq_length(layer_idx) + query_length - return kv_length, kv_offset - - def get_seq_length(self, layer_idx: int | None = 0) -> int: - """Returns the sequence length of the cached states. A layer index can be optionally passed.""" - # take any layer that contains cache and not empty tensor - layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx - if len(self.key_cache) <= layer_idx or self.key_cache[layer_idx].shape[-1] == 0: - return 0 - return self.key_cache[layer_idx].shape[-2] - - def update_conv_state( - self, layer_idx: int, new_conv_state: torch.Tensor, cache_init: bool = False - ) -> torch.Tensor: - # Technically, those update are not logically correct if the prefill is smaller than `conv_kernel_size`, - # as it will `roll` anyway in the first decoding step even though it should `roll` ONLY if the cache is already full. 
- # But since `conv_kernel_size=4` in practice, it's almost impossible to have a smaller prefill so it's mostly fine for now - if cache_init: - self.conv_states[layer_idx] = new_conv_state.to(self.conv_states[layer_idx].device) - else: - self.conv_states[layer_idx] = self.conv_states[layer_idx].roll(shifts=-1, dims=-1) - self.conv_states[layer_idx][:, :, -1] = new_conv_state[:, 0, :].to(self.conv_states[layer_idx].device) - - return self.conv_states[layer_idx] - - def reset(self): - self.has_previous_state = False - self.conv_states.zero_() - self.ssm_states.zero_() - - class FalconH1RotaryEmbedding(nn.Module): inv_freq: torch.Tensor # fix linting for `register_buffer` @@ -633,7 +477,7 @@ def __init__(self, config: FalconH1Config, layer_idx: int): def cuda_kernels_forward( self, hidden_states: torch.Tensor, - cache_params: FalconHybridMambaAttentionDynamicCache | None = None, + cache_params: Cache | None = None, attention_mask: torch.Tensor | None = None, ): # 1. Gated MLP's linear projection @@ -649,12 +493,7 @@ def cuda_kernels_forward( groups_time_state_size = self.n_groups * self.ssm_state_size use_precomputed_states = ( - cache_params is not None - and cache_params.has_previous_state - and seq_len == 1 - and cache_params.conv_states[self.layer_idx].shape[0] - == cache_params.ssm_states[self.layer_idx].shape[0] - == batch_size + cache_params is not None and cache_params.has_previous_state(self.layer_idx) and seq_len == 1 ) # getting projected states from cache if it exists @@ -668,7 +507,7 @@ def cuda_kernels_forward( # 2. Convolution sequence transformation hidden_states_B_C = causal_conv1d_update( hidden_states_B_C, - cache_params.conv_states[self.layer_idx], + cache_params.layers[self.layer_idx].conv_states, self.conv1d.weight.squeeze(1), self.conv1d.bias, self.activation, ) @@ -690,7 +529,7 @@ def cuda_kernels_forward( C = C.view(batch_size, self.n_groups, C.shape[1] // self.n_groups) hidden_states_reshaped = hidden_states.view(batch_size, self.num_heads, self.head_dim) hidden_states = selective_state_update( - cache_params.ssm_states[self.layer_idx], + cache_params.layers[self.layer_idx].recurrent_states, hidden_states_reshaped, dt, A, @@ -763,7 +602,7 @@ def cuda_kernels_forward( hidden_states_B_C.permute(0, 2, 1), (self.conv_kernel_size - hidden_states_B_C.shape[-2], 0), ) - cache_params.update_conv_state(self.layer_idx, conv_states, cache_init=True) + conv_states = cache_params.update_conv_state(conv_states, self.layer_idx) time_step = nn.functional.softplus(dt + self.dt_bias) # 1D Convolution @@ -810,7 +649,7 @@ def cuda_kernels_forward( **dt_limit_kwargs, ) if ssm_state is not None and cache_params is not None: - cache_params.ssm_states[self.layer_idx].copy_(ssm_state) + ssm_state = cache_params.update_recurrent_state(ssm_state, self.layer_idx) scan_output = scan_output.view(batch_size, seq_len, -1) # Multiply "gate" branch and apply extra normalization layer if self.mamba_rms_norm: @@ -824,7 +663,7 @@ def cuda_kernels_forward( def torch_forward( self, input_states, - cache_params: FalconHybridMambaAttentionDynamicCache | None = None, + cache_params: Cache | None = None, attention_mask: torch.Tensor | None = None, ): batch_size, seq_len, _ = input_states.shape @@ -839,19 +678,13 @@ def torch_forward( gate, hidden_states_B_C, dt = projected_states.split([ self.intermediate_size, self.conv_dim, self.num_heads ], dim=-1) + hidden_states_B_C = hidden_states_B_C.transpose(1, 2) - use_precomputed_states = ( - cache_params is not None - and cache_params.has_previous_state - and seq_len
== 1 - and cache_params.conv_states[self.layer_idx].shape[0] - == cache_params.ssm_states[self.layer_idx].shape[0] - == batch_size - ) + use_precomputed_states = cache_params is not None and cache_params.has_previous_state(self.layer_idx) and seq_len == 1 # 2. Convolution sequence transformation if use_precomputed_states: - conv_states = cache_params.update_conv_state(self.layer_idx, hidden_states_B_C, cache_init=False) + conv_states = cache_params.update_conv_state(hidden_states_B_C, self.layer_idx) # We need to guarantee that anything regarding the cache is on the same device conv_states = conv_states.to(device=self.conv1d.weight.device) @@ -864,13 +697,12 @@ def torch_forward( else: # Init cache if cache_params is not None: - hidden_states_B_C_transposed = hidden_states_B_C.transpose(1, 2) conv_states = nn.functional.pad( - hidden_states_B_C_transposed, (self.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0) + hidden_states_B_C, (self.conv_kernel_size - hidden_states_B_C.shape[-1], 0) ) - conv_states = cache_params.update_conv_state(self.layer_idx, conv_states, cache_init=True) + conv_states = cache_params.update_conv_state(conv_states, self.layer_idx) - hidden_states_B_C = self.act(self.conv1d(hidden_states_B_C.transpose(1, 2))[..., :seq_len].transpose(1, 2)) + hidden_states_B_C = self.act(self.conv1d(hidden_states_B_C)[..., :seq_len].transpose(1, 2)) hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask) hidden_states, B, C = torch.split( @@ -883,7 +715,7 @@ def torch_forward( A = -torch.exp(self.A_log.float()) # [num_heads] if use_precomputed_states: # We need to guarantee that anything regarding the cache is on the same device - cache_device = cache_params.ssm_states[self.layer_idx].device + cache_device = cache_params.layers[self.layer_idx].recurrent_states.device # Note: there is no need to pad parameter matrices here, as there is just one new token # for batched generation @@ -913,9 +745,8 @@ def torch_forward( dBx = (dB * hidden_states[..., None]).to(device=cache_device) # State calculation - cache_params.ssm_states[self.layer_idx].copy_( - cache_params.ssm_states[self.layer_idx] * dA + dBx - ) + ssm_states = cache_params.layers[self.layer_idx].recurrent_states * dA + dBx + ssm_states = cache_params.update_recurrent_state(ssm_states, self.layer_idx) # Subsequent output # [bsz, n_groups * state_size] -> [bsz, num_heads, state_size] @@ -924,7 +755,7 @@ def torch_forward( C = C.reshape(batch_size, -1, C.shape[-1]) # [bsz, num_heads, head_dim] - ssm_states = cache_params.ssm_states[self.layer_idx].to(device=C.device, dtype=C.dtype) # Shape: [b, h, d, n] + ssm_states = ssm_states.to(device=C.device, dtype=C.dtype) # Shape: [b, h, d, n] # Reshape ssm_states to merge the first two dimensions ssm_states_reshaped = ssm_states.view(batch_size * self.num_heads, self.head_dim, self.ssm_state_size) # Shape: [b*h, d, n] C_reshaped = C.view(batch_size * self.num_heads, self.ssm_state_size, 1) # Shape: [b*h, n, 1] @@ -985,10 +816,7 @@ def torch_forward( # 3. 
Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries # (middle term of factorization of off-diag blocks; A terms) - if use_precomputed_states: - previous_states = cache_params.ssm_states[self.layer_idx][:, None, ...].to(device=states.device) - else: - previous_states = torch.zeros_like(states[:, :1]) + previous_states = torch.zeros_like(states[:, :1]) states = torch.cat([previous_states, states], dim=1) decay_chunk = torch.exp(segment_sum(nn.functional.pad(A_cumsum[:, :, :, -1], (1, 0)))) decay_chunk = decay_chunk.transpose(1, 3) @@ -1015,7 +843,7 @@ def torch_forward( # Init cache if ssm_state is not None and cache_params is not None: - cache_params.ssm_states[self.layer_idx].copy_(ssm_state) + ssm_state = cache_params.update_recurrent_state(ssm_state, self.layer_idx) if self.mamba_rms_norm: scan_output = self.norm(y, gate) @@ -1032,7 +860,7 @@ def torch_forward( def forward( self, hidden_states, - cache_params: FalconHybridMambaAttentionDynamicCache | None = None, + cache_params: Cache | None = None, attention_mask: torch.Tensor | None = None, **kwargs, ): @@ -1110,7 +938,7 @@ def forward( attention_mask: torch.Tensor | None = None, mamba_attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, - past_key_values: FalconHybridMambaAttentionDynamicCache | None = None, + past_key_values: Cache | None = None, use_cache: bool | None = False, position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, **kwargs, @@ -1120,7 +948,7 @@ def forward( hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch, sequence_length)` where padding elements are indicated by 0. - past_key_values (`FalconHybridMambaAttentionDynamicCache`, *optional*): cached past key and value projection states + past_key_values (`Cache`, *optional*): cached past key and value projection states use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). @@ -1271,7 +1099,7 @@ def forward( input_ids: torch.LongTensor | None = None, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, - past_key_values: FalconHybridMambaAttentionDynamicCache | None = None, + past_key_values: Cache | None = None, inputs_embeds: torch.FloatTensor | None = None, use_cache: bool | None = None, **kwargs: Unpack[TransformersKwargs], @@ -1284,10 +1112,7 @@ def forward( hidden_states = inputs_embeds if use_cache and past_key_values is None: - logger.warning_once( - "FalconH1 requires an initialized `FalconHybridMambaAttentionDynamicCache` to return a cache. None was " - "provided, so no cache will be returned." - ) + past_key_values = DynamicCache(config=self.config) if position_ids is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 @@ -1319,14 +1144,9 @@ def forward( hidden_states = self.final_layernorm(hidden_states) - if past_key_values and not past_key_values.has_previous_state: - past_key_values.has_previous_state = True - - next_cache = None if not use_cache else past_key_values - return BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, ) def _update_mamba_mask(self, attention_mask, past_key_values): @@ -1336,7 +1156,7 @@ def _update_mamba_mask(self, attention_mask, past_key_values): 2. 
Attending to all inputs """ mamba_mask = attention_mask - if (past_key_values is not None and past_key_values.has_previous_state) or ( + if (past_key_values is not None and past_key_values.has_previous_state()) or ( attention_mask is not None and torch.all(attention_mask == 1) ): mamba_mask = None @@ -1365,7 +1185,7 @@ def forward( input_ids: torch.LongTensor | None = None, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, - past_key_values: FalconHybridMambaAttentionDynamicCache | None = None, + past_key_values: Cache | None = None, inputs_embeds: torch.FloatTensor | None = None, labels: torch.LongTensor | None = None, use_cache: bool | None = None, @@ -1427,18 +1247,6 @@ def prepare_inputs_for_generation( is_first_iteration=False, **kwargs, ): - # Overwritten -- has a unique cache type, `FalconHybridMambaAttentionDynamicCache` - - if past_key_values is None: - past_key_values = FalconHybridMambaAttentionDynamicCache( - self.config, - input_ids.shape[0], - self.dtype, - devices=[ - self.model.layers[i].mamba.conv1d.weight.device for i in range(self.config.num_hidden_layers) - ], - ) - kwargs["logits_to_keep"] = self.config.num_logits_to_keep model_inputs = super().prepare_inputs_for_generation( input_ids, diff --git a/src/transformers/models/falcon_h1/modular_falcon_h1.py b/src/transformers/models/falcon_h1/modular_falcon_h1.py index 4bd7301cd3fd..75cbed28a646 100644 --- a/src/transformers/models/falcon_h1/modular_falcon_h1.py +++ b/src/transformers/models/falcon_h1/modular_falcon_h1.py @@ -19,15 +19,26 @@ """PyTorch FalconH1 model.""" from collections.abc import Callable -from typing import Any import torch import torch.nn.functional as F from torch import nn -from transformers.activations import ACT2FN -from transformers.models.jamba.modeling_jamba import HybridMambaAttentionDynamicCache -from transformers.models.llama.modeling_llama import ( +from ... import initialization as init +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache +from ...integrations.hub_kernels import lazy_load_kernel +from ...masking_utils import create_causal_mask +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging +from ...utils.generic import merge_with_config_defaults +from ...utils.import_utils import resolve_internal_import +from ...utils.output_capturing import capture_outputs +from ..llama.modeling_llama import ( LlamaAttention, LlamaForCausalLM, LlamaMLP, @@ -36,152 +47,19 @@ apply_rotary_pos_emb, eager_attention_forward, ) -from transformers.models.mamba2.modeling_mamba2 import ( +from ..mamba2.modeling_mamba2 import ( MambaRMSNormGated, apply_mask_to_padding_states, pad_tensor_by_size, reshape_into_chunks, segment_sum, ) - -from ... 
import initialization as init -from ...cache_utils import Cache -from ...integrations.hub_kernels import lazy_load_kernel -from ...masking_utils import create_causal_mask -from ...modeling_flash_attention_utils import FlashAttentionKwargs -from ...modeling_layers import GradientCheckpointingLayer -from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast -from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel -from ...processing_utils import Unpack -from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging -from ...utils.generic import merge_with_config_defaults -from ...utils.import_utils import resolve_internal_import -from ...utils.output_capturing import capture_outputs from .configuration_falcon_h1 import FalconH1Config logger = logging.get_logger(__name__) -class FalconHybridMambaAttentionDynamicCache(HybridMambaAttentionDynamicCache): - """ - A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache - (which has a constant shape regardless of seq_len). - - This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states` - and `ssm_states` for mamba cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor - For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`, - while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors). - For mamba layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors), - while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`, - and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`. - """ - - def __init__( - self, - config: FalconH1Config, - batch_size: int, - dtype: torch.dtype = torch.float16, - devices: list[str] | None = None, - ): - self.seqlen_offset = 0 - self.dtype = dtype - self.has_previous_state = False - self.conv_kernel_size = config.mamba_d_conv - - self.intermediate_size = ( - config.mamba_d_ssm if config.mamba_d_ssm is not None else int(config.mamba_expand * config.hidden_size) - ) - - self.conv_states = { - i: torch.zeros( - batch_size, - self.intermediate_size + 2 * config.mamba_n_groups * config.mamba_d_state, - self.conv_kernel_size, - device=devices[i], - dtype=dtype, - ) - for i in range(config.num_hidden_layers) - } - self.ssm_states = { - i: torch.zeros( - batch_size, - config.mamba_n_heads, - config.mamba_d_head, - config.mamba_d_state, - device=devices[i], - dtype=dtype, - ) - for i in range(config.num_hidden_layers) - } - - self.transformer_layers = [] - for i in range(config.num_hidden_layers): - self.transformer_layers.append(i) - - self.key_cache: list[torch.Tensor] = [] - self.value_cache: list[torch.Tensor] = [] - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: dict[str, Any] | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. - - Parameters: - key_states (`torch.Tensor`): - The new key states to cache. - value_states (`torch.Tensor`): - The new value states to cache. - layer_idx (`int`): - The index of the layer to cache the states for. - cache_kwargs (`dict[str, Any]`, `optional`): - Additional arguments for the cache subclass. 
No additional arguments are used in `DynamicCache`. - - Return: - A tuple containing the updated key and value states. - """ - # Update the cache - if len(self.key_cache) <= layer_idx: - # There may be skipped layers, fill them with empty lists - for _ in range(len(self.key_cache), layer_idx): - self.key_cache.append([]) - self.value_cache.append([]) - self.key_cache.append(key_states) - self.value_cache.append(value_states) - elif len(self.key_cache[layer_idx]) == 0: # fills previously skipped layers; checking for tensor causes errors - self.key_cache[layer_idx] = key_states - self.value_cache[layer_idx] = value_states - else: - self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) - self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) - - return self.key_cache[layer_idx], self.value_cache[layer_idx] - - def update_conv_state( - self, layer_idx: int, new_conv_state: torch.Tensor, cache_init: bool = False - ) -> torch.Tensor: - # Technically, those update are not logically correct if the prefill is smaller than `conv_kernel_size`, - # as it will `roll` anyway in the first decoding step even though it should `roll` ONLY if the cache is already full. - # But since `conv_kernel_size=4` in practice, it's almost impossible to have a smaller prefill so it's mostly fine for now - if cache_init: - self.conv_states[layer_idx] = new_conv_state.to(self.conv_states[layer_idx].device) - else: - self.conv_states[layer_idx] = self.conv_states[layer_idx].roll(shifts=-1, dims=-1) - self.conv_states[layer_idx][:, :, -1] = new_conv_state[:, 0, :].to(self.conv_states[layer_idx].device) - - return self.conv_states[layer_idx] - - def reset(self): - self.has_previous_state = False - self.conv_states.zero_() - self.ssm_states.zero_() - - class FalconH1RotaryEmbedding(LlamaRotaryEmbedding): pass @@ -386,7 +264,7 @@ def __init__(self, config: FalconH1Config, layer_idx: int): def cuda_kernels_forward( self, hidden_states: torch.Tensor, - cache_params: FalconHybridMambaAttentionDynamicCache | None = None, + cache_params: Cache | None = None, attention_mask: torch.Tensor | None = None, ): # 1. Gated MLP's linear projection @@ -402,12 +280,7 @@ def cuda_kernels_forward( groups_time_state_size = self.n_groups * self.ssm_state_size use_precomputed_states = ( - cache_params is not None - and cache_params.has_previous_state - and seq_len == 1 - and cache_params.conv_states[self.layer_idx].shape[0] - == cache_params.ssm_states[self.layer_idx].shape[0] - == batch_size + cache_params is not None and cache_params.has_previous_state(self.layer_idx) and seq_len == 1 ) # getting projected states from cache if it exists @@ -421,7 +294,7 @@ def cuda_kernels_forward( # 2. 
Convolution sequence transformation hidden_states_B_C = causal_conv1d_update( hidden_states_B_C, - cache_params.conv_states[self.layer_idx], + cache_params.layers[self.layer_idx].conv_states, self.conv1d.weight.squeeze(1), self.conv1d.bias, self.activation, ) @@ -443,7 +316,7 @@ def cuda_kernels_forward( C = C.view(batch_size, self.n_groups, C.shape[1] // self.n_groups) hidden_states_reshaped = hidden_states.view(batch_size, self.num_heads, self.head_dim) hidden_states = selective_state_update( - cache_params.ssm_states[self.layer_idx], + cache_params.layers[self.layer_idx].recurrent_states, hidden_states_reshaped, dt, A, @@ -516,7 +389,7 @@ def cuda_kernels_forward( hidden_states_B_C.permute(0, 2, 1), (self.conv_kernel_size - hidden_states_B_C.shape[-2], 0), ) - cache_params.update_conv_state(self.layer_idx, conv_states, cache_init=True) + conv_states = cache_params.update_conv_state(conv_states, self.layer_idx) time_step = nn.functional.softplus(dt + self.dt_bias) # 1D Convolution @@ -563,7 +436,7 @@ def cuda_kernels_forward( **dt_limit_kwargs, ) if ssm_state is not None and cache_params is not None: - cache_params.ssm_states[self.layer_idx].copy_(ssm_state) + ssm_state = cache_params.update_recurrent_state(ssm_state, self.layer_idx) scan_output = scan_output.view(batch_size, seq_len, -1) # Multiply "gate" branch and apply extra normalization layer if self.mamba_rms_norm: @@ -577,7 +450,7 @@ def cuda_kernels_forward( def torch_forward( self, input_states, - cache_params: FalconHybridMambaAttentionDynamicCache | None = None, + cache_params: Cache | None = None, attention_mask: torch.Tensor | None = None, ): batch_size, seq_len, _ = input_states.shape @@ -592,19 +465,13 @@ def torch_forward( gate, hidden_states_B_C, dt = projected_states.split([ self.intermediate_size, self.conv_dim, self.num_heads ], dim=-1) + hidden_states_B_C = hidden_states_B_C.transpose(1, 2) - use_precomputed_states = ( - cache_params is not None - and cache_params.has_previous_state - and seq_len == 1 - and cache_params.conv_states[self.layer_idx].shape[0] - == cache_params.ssm_states[self.layer_idx].shape[0] - == batch_size - ) + use_precomputed_states = cache_params is not None and cache_params.has_previous_state(self.layer_idx) and seq_len == 1 # 2.
Convolution sequence transformation if use_precomputed_states: - conv_states = cache_params.update_conv_state(self.layer_idx, hidden_states_B_C, cache_init=False) + conv_states = cache_params.update_conv_state(hidden_states_B_C, self.layer_idx) # We need to guarantee that anything regarding the cache is on the same device conv_states = conv_states.to(device=self.conv1d.weight.device) @@ -617,13 +484,12 @@ def torch_forward( else: # Init cache if cache_params is not None: - hidden_states_B_C_transposed = hidden_states_B_C.transpose(1, 2) conv_states = nn.functional.pad( - hidden_states_B_C_transposed, (self.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0) + hidden_states_B_C, (self.conv_kernel_size - hidden_states_B_C.shape[-1], 0) ) - conv_states = cache_params.update_conv_state(self.layer_idx, conv_states, cache_init=True) + conv_states = cache_params.update_conv_state(conv_states, self.layer_idx) - hidden_states_B_C = self.act(self.conv1d(hidden_states_B_C.transpose(1, 2))[..., :seq_len].transpose(1, 2)) + hidden_states_B_C = self.act(self.conv1d(hidden_states_B_C)[..., :seq_len].transpose(1, 2)) hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask) hidden_states, B, C = torch.split( @@ -636,7 +502,7 @@ def torch_forward( A = -torch.exp(self.A_log.float()) # [num_heads] if use_precomputed_states: # We need to guarantee that anything regarding the cache is on the same device - cache_device = cache_params.ssm_states[self.layer_idx].device + cache_device = cache_params.layers[self.layer_idx].recurrent_states.device # Note: there is no need to pad parameter matrices here, as there is just one new token # for batched generation @@ -666,9 +532,8 @@ def torch_forward( dBx = (dB * hidden_states[..., None]).to(device=cache_device) # State calculation - cache_params.ssm_states[self.layer_idx].copy_( - cache_params.ssm_states[self.layer_idx] * dA + dBx - ) + ssm_states = cache_params.layers[self.layer_idx].recurrent_states * dA + dBx + ssm_states = cache_params.update_recurrent_state(ssm_states, self.layer_idx) # Subsequent output # [bsz, n_groups * state_size] -> [bsz, num_heads, state_size] @@ -677,7 +542,7 @@ def torch_forward( C = C.reshape(batch_size, -1, C.shape[-1]) # [bsz, num_heads, head_dim] - ssm_states = cache_params.ssm_states[self.layer_idx].to(device=C.device, dtype=C.dtype) # Shape: [b, h, d, n] + ssm_states = ssm_states.to(device=C.device, dtype=C.dtype) # Shape: [b, h, d, n] # Reshape ssm_states to merge the first two dimensions ssm_states_reshaped = ssm_states.view(batch_size * self.num_heads, self.head_dim, self.ssm_state_size) # Shape: [b*h, d, n] C_reshaped = C.view(batch_size * self.num_heads, self.ssm_state_size, 1) # Shape: [b*h, n, 1] @@ -738,10 +603,7 @@ def torch_forward( # 3. 
Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries # (middle term of factorization of off-diag blocks; A terms) - if use_precomputed_states: - previous_states = cache_params.ssm_states[self.layer_idx][:, None, ...].to(device=states.device) - else: - previous_states = torch.zeros_like(states[:, :1]) + previous_states = torch.zeros_like(states[:, :1]) states = torch.cat([previous_states, states], dim=1) decay_chunk = torch.exp(segment_sum(nn.functional.pad(A_cumsum[:, :, :, -1], (1, 0)))) decay_chunk = decay_chunk.transpose(1, 3) @@ -768,7 +630,7 @@ def torch_forward( # Init cache if ssm_state is not None and cache_params is not None: - cache_params.ssm_states[self.layer_idx].copy_(ssm_state) + ssm_state = cache_params.update_recurrent_state(ssm_state, self.layer_idx) if self.mamba_rms_norm: scan_output = self.norm(y, gate) @@ -785,7 +647,7 @@ def torch_forward( def forward( self, hidden_states, - cache_params: FalconHybridMambaAttentionDynamicCache | None = None, + cache_params: Cache | None = None, attention_mask: torch.Tensor | None = None, **kwargs, ): @@ -839,7 +701,7 @@ def forward( attention_mask: torch.Tensor | None = None, mamba_attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, - past_key_values: FalconHybridMambaAttentionDynamicCache | None = None, + past_key_values: Cache | None = None, use_cache: bool | None = False, position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, **kwargs, @@ -849,7 +711,7 @@ def forward( hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch, sequence_length)` where padding elements are indicated by 0. - past_key_values (`FalconHybridMambaAttentionDynamicCache`, *optional*): cached past key and value projection states + past_key_values (`Cache`, *optional*): cached past key and value projection states use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). @@ -1000,7 +862,7 @@ def forward( input_ids: torch.LongTensor | None = None, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, - past_key_values: FalconHybridMambaAttentionDynamicCache | None = None, + past_key_values: Cache | None = None, inputs_embeds: torch.FloatTensor | None = None, use_cache: bool | None = None, **kwargs: Unpack[TransformersKwargs], @@ -1013,10 +875,7 @@ def forward( hidden_states = inputs_embeds if use_cache and past_key_values is None: - logger.warning_once( - "FalconH1 requires an initialized `FalconHybridMambaAttentionDynamicCache` to return a cache. None was " - "provided, so no cache will be returned." - ) + past_key_values = DynamicCache(config=self.config) if position_ids is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 @@ -1048,14 +907,9 @@ def forward( hidden_states = self.final_layernorm(hidden_states) - if past_key_values and not past_key_values.has_previous_state: - past_key_values.has_previous_state = True - - next_cache = None if not use_cache else past_key_values - return BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, ) def _update_mamba_mask(self, attention_mask, past_key_values): @@ -1065,7 +919,7 @@ def _update_mamba_mask(self, attention_mask, past_key_values): 2. 
Attending to all inputs """ mamba_mask = attention_mask - if (past_key_values is not None and past_key_values.has_previous_state) or ( + if (past_key_values is not None and past_key_values.has_previous_state()) or ( attention_mask is not None and torch.all(attention_mask == 1) ): mamba_mask = None @@ -1080,7 +934,7 @@ def forward( input_ids: torch.LongTensor | None = None, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, - past_key_values: FalconHybridMambaAttentionDynamicCache | None = None, + past_key_values: Cache | None = None, inputs_embeds: torch.FloatTensor | None = None, labels: torch.LongTensor | None = None, use_cache: bool | None = None, @@ -1142,18 +996,6 @@ def prepare_inputs_for_generation( is_first_iteration=False, **kwargs, ): - # Overwritten -- has a unique cache type, `FalconHybridMambaAttentionDynamicCache` - - if past_key_values is None: - past_key_values = FalconHybridMambaAttentionDynamicCache( - self.config, - input_ids.shape[0], - self.dtype, - devices=[ - self.model.layers[i].mamba.conv1d.weight.device for i in range(self.config.num_hidden_layers) - ], - ) - kwargs["logits_to_keep"] = self.config.num_logits_to_keep model_inputs = super().prepare_inputs_for_generation( input_ids, diff --git a/src/transformers/models/falcon_mamba/configuration_falcon_mamba.py b/src/transformers/models/falcon_mamba/configuration_falcon_mamba.py index 5946e81b136d..5f314970ce1c 100644 --- a/src/transformers/models/falcon_mamba/configuration_falcon_mamba.py +++ b/src/transformers/models/falcon_mamba/configuration_falcon_mamba.py @@ -105,5 +105,9 @@ def __post_init__(self, **kwargs): ) super().__post_init__(**kwargs) + @property + def layer_types(self): + return ["mamba"] * self.num_hidden_layers + __all__ = ["FalconMambaConfig"] diff --git a/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py b/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py index f80d0f7ca06f..32fd1bf4a358 100644 --- a/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py +++ b/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py @@ -27,7 +27,7 @@ from ... import initialization as init from ...activations import ACT2FN -from ...configuration_utils import PreTrainedConfig +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...integrations import lazy_load_kernel from ...modeling_layers import GradientCheckpointingLayer @@ -57,117 +57,6 @@ logger = logging.get_logger(__name__) -class FalconMambaCache: - """ - Cache for falcon_mamba model which does not have attention mechanism and key value states. - - Arguments: - config (`PreTrainedConfig): - The configuration file defining the shape-related attributes required to initialize the static cache. - max_batch_size (`int`): - The maximum batch size with which the model will be used. Note that a new instance must be instantiated if - a smaller batch size is used. - dtype (`torch.dtype`, *optional*, defaults to `torch.float16`): - The default `dtype` to use when initializing the layer. - device (`torch.device` or `str`, *optional*): - The device on which the cache should be initialized. Should be the same as the layer. 
- - Example: - - ```python - >>> import torch - >>> from transformers import AutoTokenizer, FalconMambaForCausalLM, FalconMambaCache - - >>> model = FalconMambaForCausalLM.from_pretrained("tiiuae/falcon-mamba-7b") - >>> tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-mamba-7b") - - >>> inputs = tokenizer(text="My name is FalconMamba", return_tensors="pt") - - >>> # Prepare a cache class and pass it to model's forward - >>> cache_params = FalconMambaCache(config=model.config, max_batch_size=1, device=model.device, dtype=model.dtype) - >>> outputs = model(**inputs, cache_params=cache_params, use_cache=True) - >>> outputs.cache_params - ``` - """ - - is_compileable = True - - # TODO (joao): add layer_device_map arg and update code in `generate` accordingly - def __init__( - self, - config: PreTrainedConfig, - max_batch_size: int, - dtype: torch.dtype = torch.float16, - device: torch.device | str | None = None, - ): - self.max_batch_size = max_batch_size - self._dtype = dtype - self.intermediate_size = config.intermediate_size - self.ssm_state_size = config.state_size - self.conv_kernel_size = config.conv_kernel - self.has_previous_state = False - - self.conv_states: list[torch.Tensor] = [] - self.ssm_states: list[torch.Tensor] = [] - device = torch.device(device) if device is not None else None - for _ in range(config.num_hidden_layers): - conv_state: torch.Tensor = torch.zeros( - self.max_batch_size, - self.intermediate_size, - self.conv_kernel_size, - device=device, - dtype=self._dtype, - ) - ssm_state: torch.Tensor = torch.zeros( - self.max_batch_size, - self.intermediate_size, - self.ssm_state_size, - device=device, - dtype=self._dtype, - ) - - torch._dynamo.mark_static_address(conv_state) - torch._dynamo.mark_static_address(ssm_state) - self.conv_states.append(conv_state) - self.ssm_states.append(ssm_state) - - def update_conv_state( - self, layer_idx: int, new_conv_state: torch.Tensor, cache_init: bool = False - ) -> torch.Tensor: - # This `if` blocks is only reached in multigpu and if `layer_device_map` is not passed. It is used - # when the cache is initialized in the forward pass (e.g. FalconMamba) - if self.conv_states[layer_idx].device != new_conv_state.device: - self.conv_states[layer_idx] = self.conv_states[layer_idx].to(new_conv_state.device) - - # Technically, those update are not logically correct if the prefill is smaller than `conv_kernel_size`, - # as it will `roll` anyway in the first decoding step even though it should `roll` ONLY if the cache is already full. 
- # But since `conv_kernel_size=4` in practice, it's almost impossible to have a smaller prefill so it's mostly fine for now - if cache_init: - self.conv_states[layer_idx].copy_(new_conv_state) - else: - conv_state = self.conv_states[layer_idx].roll(shifts=-1, dims=-1) - conv_state[:, :, -1:] = new_conv_state - self.conv_states[layer_idx].copy_(conv_state) - - # If last layer is updated, set the flag - if layer_idx == len(self.conv_states) - 1: - self.has_previous_state = True - - return self.conv_states[layer_idx] - - def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor): - self.ssm_states[layer_idx].zero_() - self.ssm_states[layer_idx] += new_ssm_state.to(self.ssm_states[layer_idx].device) - return self.ssm_states[layer_idx] - - def reset(self): - self.has_previous_state = False - for layer_idx in range(len(self.conv_states)): - # In-place ops prevent breaking the static address - self.conv_states[layer_idx].zero_() - self.ssm_states[layer_idx].zero_() - - def rms_forward(hidden_states, variance_epsilon=1e-6): """ Calculates simple RMSNorm with no learnable weights. `MambaRMSNorm` will @@ -310,7 +199,7 @@ def warn_slow_implementation(self): def cuda_kernels_forward( self, hidden_states: torch.Tensor, - cache_params: FalconMambaCache | None = None, + cache_params: Cache | None = None, attention_mask: torch.LongTensor | None = None, ): # 1. Gated MLP's linear projection @@ -342,14 +231,14 @@ def cuda_kernels_forward( if attention_mask is not None: hidden_states = hidden_states * attention_mask.unsqueeze(1) - is_decoding = cache_params is not None and cache_params.has_previous_state + is_decoding = cache_params is not None and cache_params.has_previous_state(self.layer_idx) # 2. Convolution sequence transformation conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)) if is_decoding: hidden_states = causal_conv1d_update( hidden_states.squeeze(-1), - cache_params.conv_states[self.layer_idx], + cache_params.layers[self.layer_idx].conv_states, conv_weights, self.conv1d.bias, self.activation, @@ -360,7 +249,7 @@ def cuda_kernels_forward( conv_states = nn.functional.pad( hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0) ) - cache_params.update_conv_state(self.layer_idx, conv_states, cache_init=True) + cache_params.update_conv_state(conv_states, self.layer_idx) hidden_states = causal_conv1d_fn( hidden_states, conv_weights, self.conv1d.bias, activation=self.activation ) @@ -391,7 +280,7 @@ def cuda_kernels_forward( time_proj_bias = self.dt_proj.bias.float() if hasattr(self.dt_proj, "bias") else None if is_decoding: scan_outputs = selective_state_update( - cache_params.ssm_states[self.layer_idx], + cache_params.layers[self.layer_idx].recurrent_states, hidden_states[..., 0], discrete_time_step[..., 0], A, @@ -416,7 +305,7 @@ def cuda_kernels_forward( return_last_state=True, ) if ssm_state is not None and cache_params is not None: - cache_params.update_ssm_state(self.layer_idx, ssm_state) + cache_params.update_recurrent_state(ssm_state, self.layer_idx) # 4. 
Final linear projection contextualized_states = self.out_proj(scan_outputs.transpose(1, 2)) @@ -425,7 +314,7 @@ def cuda_kernels_forward( # fmt: off def slow_forward(self, input_states, - cache_params: FalconMambaCache | None = None, + cache_params: Cache | None = None, attention_mask: torch.LongTensor | None = None, ): batch_size, seq_len, _ = input_states.shape @@ -437,19 +326,24 @@ def slow_forward(self, if attention_mask is not None: hidden_states = hidden_states * attention_mask.unsqueeze(1) + if cache_params is not None and cache_params.has_previous_state(self.layer_idx): + ssm_state = cache_params.layers[self.layer_idx].recurrent_states.clone() + else: + ssm_state = torch.zeros( + (batch_size, self.intermediate_size, self.ssm_state_size), device=hidden_states.device, dtype=dtype + ) + # 2. Convolution sequence transformation if cache_params is not None: - ssm_state = cache_params.ssm_states[self.layer_idx].clone() - ssm_state = ssm_state.to(hidden_states.device) - if not cache_params.has_previous_state: + if not cache_params.has_previous_state(self.layer_idx): conv_state = nn.functional.pad(hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0)) - cache_params.update_conv_state(self.layer_idx, conv_state, cache_init=True) + cache_params.update_conv_state(conv_state, self.layer_idx) hidden_states = self.act( self.conv1d(hidden_states)[..., :seq_len] ) # [batch, intermediate_size, seq_len] else: - conv_state = cache_params.update_conv_state(self.layer_idx, hidden_states, cache_init=False) + conv_state = cache_params.update_conv_state(hidden_states, self.layer_idx) conv_state = conv_state.to(self.conv1d.weight.device) hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1) if self.use_conv_bias: @@ -458,9 +352,6 @@ def slow_forward(self, self.act(hidden_states).to(dtype).unsqueeze(-1) ) # [batch, intermediate_size, 1] : decoding else: - ssm_state = torch.zeros( - (batch_size, self.intermediate_size, self.ssm_state_size), device=hidden_states.device, dtype=dtype - ) hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # [batch, intermediate_size, seq_len] if attention_mask is not None: @@ -538,7 +429,7 @@ def combine_fn(left, right): scan_output = scan_output * self.act(gate) if cache_params is not None: - cache_params.update_ssm_state(self.layer_idx, ssm_state) + cache_params.update_recurrent_state(ssm_state, self.layer_idx) # 4. Final linear projection contextualized_states = self.out_proj(scan_output.transpose(1, 2)) # [batch, seq_len, hidden_size] @@ -548,7 +439,7 @@ def combine_fn(left, right): def forward( self, hidden_states, - cache_params: FalconMambaCache | None = None, + cache_params: Cache | None = None, attention_mask: torch.LongTensor | None = None, **kwargs, ): @@ -590,7 +481,7 @@ def __init__(self, config, layer_idx): def forward( self, hidden_states, - cache_params: FalconMambaCache | None = None, + cache_params: Cache | None = None, attention_mask: torch.LongTensor | None = None, **kwargs, ): @@ -658,7 +549,7 @@ def _init_weights(self, module): ) class FalconMambaOutput(ModelOutput): r""" - cache_params (`FalconMambaCache`): + cache_params (`Cache`): The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to avoid providing the old `input_ids`. 
@@ -666,7 +557,7 @@ class FalconMambaOutput(ModelOutput): """ last_hidden_state: torch.FloatTensor | None = None - cache_params: FalconMambaCache | None = None + cache_params: Cache | None = None hidden_states: tuple[torch.FloatTensor] | None = None @@ -682,7 +573,7 @@ class FalconMambaCausalLMOutput(ModelOutput): Language modeling loss (for next-token prediction). logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - cache_params (`FalconMambaCache`): + cache_params (`Cache`): The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to avoid providing the old `input_ids`. @@ -691,7 +582,7 @@ class FalconMambaCausalLMOutput(ModelOutput): loss: torch.FloatTensor | None = None logits: torch.FloatTensor | None = None - cache_params: FalconMambaCache | None = None + cache_params: Cache | None = None hidden_states: tuple[torch.FloatTensor] | None = None @@ -721,7 +612,7 @@ def forward( self, input_ids: torch.LongTensor | None = None, inputs_embeds: torch.LongTensor | None = None, - cache_params: FalconMambaCache | None = None, + cache_params: Cache | None = None, use_cache: bool | None = None, output_hidden_states: bool | None = None, return_dict: bool | None = None, @@ -729,7 +620,7 @@ def forward( **kwargs, ) -> tuple | FalconMambaOutput: r""" - cache_params (`FalconMambaCache`, *optional*): + cache_params (`Cache`, *optional*): If passed along, the model uses the previous state in all the blocks (which will give the output for the `input_ids` provided as if the model add `state_input_ids + input_ids` as context). use_cache (`bool`, *optional*): @@ -751,9 +642,7 @@ def forward( use_cache = False if use_cache and cache_params is None: - cache_params = FalconMambaCache( - self.config, inputs_embeds.size(0), device=inputs_embeds.device, dtype=inputs_embeds.dtype - ) + cache_params = DynamicCache(config=self.config) hidden_states = inputs_embeds all_hidden_states = () if output_hidden_states else None @@ -809,12 +698,11 @@ def prepare_inputs_for_generation( input_ids, inputs_embeds=None, use_cache=None, - cache_params: FalconMambaCache | None = None, + cache_params: Cache | None = None, attention_mask: torch.LongTensor | None = None, is_first_iteration: bool | None = False, **kwargs, ): - # Overwritten -- has custom cache class `FalconMambaCache` model_inputs = super().prepare_inputs_for_generation( input_ids, inputs_embeds=inputs_embeds, @@ -825,15 +713,7 @@ def prepare_inputs_for_generation( **kwargs, ) - if use_cache and cache_params is None: - if inputs_embeds is not None: - max_batch_size = inputs_embeds.size(0) - else: - max_batch_size = input_ids.size(0) - model_inputs["cache_params"] = FalconMambaCache( - self.backbone.config, max_batch_size, device=self.device, dtype=self.dtype - ) - elif use_cache and not is_first_iteration: + if use_cache and not is_first_iteration: model_inputs["attention_mask"] = None return model_inputs @@ -844,7 +724,7 @@ def forward( input_ids: torch.LongTensor | None = None, attention_mask: torch.LongTensor | None = None, inputs_embeds: torch.FloatTensor | None = None, - cache_params: FalconMambaCache | None = None, + cache_params: Cache | None = None, labels: torch.LongTensor | None = None, output_hidden_states: bool | None = None, return_dict: bool | None = None, @@ -853,7 +733,7 @@ def forward( **kwargs, # for now we need this for generation ) -> tuple | 
FalconMambaCausalLMOutput: r""" - cache_params (`FalconMambaCache`, *optional*): + cache_params (`Cache`, *optional*): If passed along, the model uses the previous state in all the blocks (which will give the output for the `input_ids` provided as if the model add `state_input_ids + input_ids` as context). labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -903,4 +783,4 @@ def forward( ) -__all__ = ["FalconMambaForCausalLM", "FalconMambaModel", "FalconMambaPreTrainedModel", "FalconMambaCache"] +__all__ = ["FalconMambaForCausalLM", "FalconMambaModel", "FalconMambaPreTrainedModel"] diff --git a/src/transformers/models/falcon_mamba/modular_falcon_mamba.py b/src/transformers/models/falcon_mamba/modular_falcon_mamba.py index ca3f337a940e..6962c66c2973 100644 --- a/src/transformers/models/falcon_mamba/modular_falcon_mamba.py +++ b/src/transformers/models/falcon_mamba/modular_falcon_mamba.py @@ -18,12 +18,12 @@ from torch import nn from ... import initialization as init +from ...cache_utils import Cache from ...utils import auto_docstring, logging from ...utils.import_utils import is_mambapy_available, is_torch_greater_or_equal, is_torchdynamo_compiling, is_tracing from ..mamba.configuration_mamba import MambaConfig from ..mamba.modeling_mamba import ( MambaBlock, - MambaCache, MambaCausalLMOutput, MambaForCausalLM, MambaMixer, @@ -101,39 +101,9 @@ class FalconMambaConfig(MambaConfig): use_associative_scan: bool = True mixer_rms_eps: float = 1e-6 - -class FalconMambaCache(MambaCache): - """ - Cache for falcon_mamba model which does not have attention mechanism and key value states. - - Arguments: - config (`PreTrainedConfig): - The configuration file defining the shape-related attributes required to initialize the static cache. - max_batch_size (`int`): - The maximum batch size with which the model will be used. Note that a new instance must be instantiated if - a smaller batch size is used. - dtype (`torch.dtype`, *optional*, defaults to `torch.float16`): - The default `dtype` to use when initializing the layer. - device (`torch.device` or `str`, *optional*): - The device on which the cache should be initialized. Should be the same as the layer. - - Example: - - ```python - >>> import torch - >>> from transformers import AutoTokenizer, FalconMambaForCausalLM, FalconMambaCache - - >>> model = FalconMambaForCausalLM.from_pretrained("tiiuae/falcon-mamba-7b") - >>> tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-mamba-7b") - - >>> inputs = tokenizer(text="My name is FalconMamba", return_tensors="pt") - - >>> # Prepare a cache class and pass it to model's forward - >>> cache_params = FalconMambaCache(config=model.config, max_batch_size=1, device=model.device, dtype=model.dtype) - >>> outputs = model(**inputs, cache_params=cache_params, use_cache=True) - >>> outputs.cache_params - ``` - """ + @property + def layer_types(self): + return ["mamba"] * self.num_hidden_layers def rms_forward(hidden_states, variance_epsilon=1e-6): @@ -194,7 +164,7 @@ def __init__(self, config: FalconMambaConfig, layer_idx: int, initialize_mixer_w def cuda_kernels_forward( self, hidden_states: torch.Tensor, - cache_params: FalconMambaCache | None = None, + cache_params: Cache | None = None, attention_mask: torch.LongTensor | None = None, ): # 1. 
Gated MLP's linear projection @@ -226,14 +196,14 @@ def cuda_kernels_forward( if attention_mask is not None: hidden_states = hidden_states * attention_mask.unsqueeze(1) - is_decoding = cache_params is not None and cache_params.has_previous_state + is_decoding = cache_params is not None and cache_params.has_previous_state(self.layer_idx) # 2. Convolution sequence transformation conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)) if is_decoding: hidden_states = causal_conv1d_update( hidden_states.squeeze(-1), - cache_params.conv_states[self.layer_idx], + cache_params.layers[self.layer_idx].conv_states, conv_weights, self.conv1d.bias, self.activation, @@ -244,7 +214,7 @@ def cuda_kernels_forward( conv_states = nn.functional.pad( hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0) ) - cache_params.update_conv_state(self.layer_idx, conv_states, cache_init=True) + cache_params.update_conv_state(conv_states, self.layer_idx) hidden_states = causal_conv1d_fn( hidden_states, conv_weights, self.conv1d.bias, activation=self.activation ) @@ -275,7 +245,7 @@ def cuda_kernels_forward( time_proj_bias = self.dt_proj.bias.float() if hasattr(self.dt_proj, "bias") else None if is_decoding: scan_outputs = selective_state_update( - cache_params.ssm_states[self.layer_idx], + cache_params.layers[self.layer_idx].recurrent_states, hidden_states[..., 0], discrete_time_step[..., 0], A, @@ -300,7 +270,7 @@ def cuda_kernels_forward( return_last_state=True, ) if ssm_state is not None and cache_params is not None: - cache_params.update_ssm_state(self.layer_idx, ssm_state) + cache_params.update_recurrent_state(ssm_state, self.layer_idx) # 4. Final linear projection contextualized_states = self.out_proj(scan_outputs.transpose(1, 2)) @@ -309,7 +279,7 @@ def cuda_kernels_forward( def slow_forward( self, input_states, - cache_params: FalconMambaCache | None = None, + cache_params: Cache | None = None, attention_mask: torch.LongTensor | None = None, ): batch_size, seq_len, _ = input_states.shape @@ -321,19 +291,24 @@ def slow_forward( if attention_mask is not None: hidden_states = hidden_states * attention_mask.unsqueeze(1) + if cache_params is not None and cache_params.has_previous_state(self.layer_idx): + ssm_state = cache_params.layers[self.layer_idx].recurrent_states.clone() + else: + ssm_state = torch.zeros( + (batch_size, self.intermediate_size, self.ssm_state_size), device=hidden_states.device, dtype=dtype + ) + # 2. 
Convolution sequence transformation if cache_params is not None: - ssm_state = cache_params.ssm_states[self.layer_idx].clone() - ssm_state = ssm_state.to(hidden_states.device) - if not cache_params.has_previous_state: + if not cache_params.has_previous_state(self.layer_idx): conv_state = nn.functional.pad(hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0)) - cache_params.update_conv_state(self.layer_idx, conv_state, cache_init=True) + cache_params.update_conv_state(conv_state, self.layer_idx) hidden_states = self.act( self.conv1d(hidden_states)[..., :seq_len] ) # [batch, intermediate_size, seq_len] else: - conv_state = cache_params.update_conv_state(self.layer_idx, hidden_states, cache_init=False) + conv_state = cache_params.update_conv_state(hidden_states, self.layer_idx) conv_state = conv_state.to(self.conv1d.weight.device) hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1) if self.use_conv_bias: @@ -342,9 +317,6 @@ def slow_forward( self.act(hidden_states).to(dtype).unsqueeze(-1) ) # [batch, intermediate_size, 1] : decoding else: - ssm_state = torch.zeros( - (batch_size, self.intermediate_size, self.ssm_state_size), device=hidden_states.device, dtype=dtype - ) hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # [batch, intermediate_size, seq_len] if attention_mask is not None: @@ -422,7 +394,7 @@ def combine_fn(left, right): scan_output = scan_output * self.act(gate) if cache_params is not None: - cache_params.update_ssm_state(self.layer_idx, ssm_state) + cache_params.update_recurrent_state(ssm_state, self.layer_idx) # 4. Final linear projection contextualized_states = self.out_proj(scan_output.transpose(1, 2)) # [batch, seq_len, hidden_size] @@ -431,7 +403,7 @@ def combine_fn(left, right): def forward( self, hidden_states, - cache_params: FalconMambaCache | None = None, + cache_params: Cache | None = None, attention_mask: torch.LongTensor | None = None, **kwargs, ): @@ -499,6 +471,5 @@ class FalconMambaForCausalLM(MambaForCausalLM): "FalconMambaForCausalLM", "FalconMambaModel", "FalconMambaPreTrainedModel", - "FalconMambaCache", "FalconMambaConfig", ] diff --git a/src/transformers/models/granitemoehybrid/configuration_granitemoehybrid.py b/src/transformers/models/granitemoehybrid/configuration_granitemoehybrid.py index d17bd4f393ed..843630333f59 100644 --- a/src/transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +++ b/src/transformers/models/granitemoehybrid/configuration_granitemoehybrid.py @@ -51,9 +51,7 @@ class GraniteMoeHybridConfig(PreTrainedConfig): ```""" model_type = "granitemoehybrid" - attribute_map = { - "layers_block_type": "layer_types", - } + attribute_map = {"layers_block_type": "layer_types"} keys_to_ignore_at_inference = ["past_key_values"] vocab_size: int = 32000 @@ -122,10 +120,5 @@ def validate_architecture(self): if self.mamba_d_head * self.mamba_n_heads != mamba_intermediate: raise ValueError("The dimensions for the Mamba head state do not match the model intermediate_size") - # overwrite the function to use in `HybridMambaAttentionDynamicCache` - @property - def layers_block_type(self): - return self.layer_types - __all__ = ["GraniteMoeHybridConfig"] diff --git a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py index 676ec8f93773..dadffaea0072 100644 --- a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +++ 
b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py @@ -19,16 +19,15 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Callable -from typing import Any, Optional, TypedDict +from typing import Optional, TypedDict import torch from torch import nn from torch.nn import functional as F -from transformers.activations import ACT2FN - from ... import initialization as init -from ...cache_utils import Cache +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func from ...integrations.hub_kernels import lazy_load_kernel @@ -187,112 +186,6 @@ def forward( return attn_output, attn_weights -class HybridMambaAttentionDynamicCache: - """ - A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache - (which has a constant shape regardless of seq_len). - - This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states` - and `ssm_states` for mamba cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor - For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`, - while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors). - For mamba layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors), - while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`, - and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`. 
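Under the unified cache that replaces this class, the attention/mamba split described above is carried by per-layer objects rather than parallel lists. A rough, illustrative inspection helper (attribute names follow the rest of this diff; it assumes `cache.layers` holds initialized layer caches):

```python
def describe_hybrid_cache(cache):
    # Mamba layers carry conv/recurrent states; attention layers carry growing key/value tensors.
    for i, layer in enumerate(cache.layers):
        conv = getattr(layer, "conv_states", None)
        if conv is not None:
            rec = getattr(layer, "recurrent_states", None)
            rec_shape = tuple(rec.shape) if rec is not None else None
            print(f"layer {i}: mamba, conv_states {tuple(conv.shape)}, recurrent_states {rec_shape}")
        else:
            print(f"layer {i}: attention (key/value states grow with seq_len)")
```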
- """ - - is_compileable = False - - def __init__(self, config: GraniteMoeHybridConfig, batch_size, dtype=torch.float16, device=None): - self.layers_block_type = config.layers_block_type - self.has_previous_state = False # only used by mamba - conv_kernel_size = config.mamba_d_conv - ssm_state_size = config.mamba_d_state - - self.conv_states = [] - self.ssm_states = [] - self.transformer_layers = [] - for i in range(config.num_hidden_layers): - if self.layers_block_type[i] == "mamba": - self.conv_states += [ - torch.zeros( - batch_size, - (config.mamba_expand * config.hidden_size + 2 * config.mamba_n_groups * ssm_state_size), - conv_kernel_size, - device=device, - dtype=dtype, - ) - ] - self.ssm_states += [ - torch.zeros( - batch_size, - config.mamba_n_heads, - config.mamba_d_head, - ssm_state_size, - device=device, - dtype=dtype, - ) - ] - else: - self.conv_states += [torch.tensor([[]] * batch_size, device=device)] - self.ssm_states += [torch.tensor([[]] * batch_size, device=device)] - self.transformer_layers.append(i) - - self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] - self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] - - def __len__(self): - return len(self.key_cache) - - def __getitem__(self, layer_idx): - return self.key_cache[layer_idx], self.value_cache[layer_idx] - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: dict[str, Any] | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - # Update the cache - if self.key_cache[layer_idx].shape[-1] == 0: - self.key_cache[layer_idx] = key_states - self.value_cache[layer_idx] = value_states - else: - self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2) - self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2) - - return self.key_cache[layer_idx], self.value_cache[layer_idx] - - def reorder_cache(self, beam_idx: torch.LongTensor): - """Reorders the cache for beam search, given the selected beam indices.""" - if self.get_seq_length() > 0: - for layer_idx in range(len(self.key_cache)): - device = self.key_cache[layer_idx].device - self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) - device = self.value_cache[layer_idx].device - self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) - - device = self.conv_states[layer_idx].device - self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx.to(device)) - device = self.ssm_states[layer_idx].device - self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device)) - - def get_mask_sizes(self, query_length: int, layer_idx: int) -> tuple[int, int]: - """Return the length and offset of the cache, used to generate the mask""" - kv_offset = 0 - kv_length = self.get_seq_length(layer_idx) + query_length - return kv_length, kv_offset - - def get_seq_length(self, layer_idx: int | None = 0) -> int: - """Returns the sequence length of the cached states. 
A layer index can be optionally passed.""" - # take any layer that contains cache and not empty tensor - layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx - if len(self.key_cache) <= layer_idx or self.key_cache[layer_idx].shape[-1] == 0: - return 0 - return self.key_cache[layer_idx].shape[-2] - - # Helper methods for segment sum computation @@ -469,7 +362,7 @@ def __init__(self, config: GraniteMoeHybridConfig, layer_idx: int): def cuda_kernels_forward( self, hidden_states: torch.Tensor, - cache_params: HybridMambaAttentionDynamicCache | None = None, + cache_params: Cache | None = None, attention_mask: torch.Tensor | None = None, seq_idx: torch.IntTensor | None = None, ): @@ -482,12 +375,7 @@ def cuda_kernels_forward( groups_time_state_size = self.n_groups * self.ssm_state_size use_precomputed_states = ( - cache_params is not None - and cache_params.has_previous_state - and seq_len == 1 - and cache_params.conv_states[self.layer_idx].shape[0] - == cache_params.ssm_states[self.layer_idx].shape[0] - == batch_size + cache_params is not None and cache_params.has_previous_state(self.layer_idx) and seq_len == 1 ) # getting projected states from cache if it exists @@ -499,7 +387,7 @@ def cuda_kernels_forward( # 2. Convolution sequence transformation hidden_states_B_C = causal_conv1d_update( hidden_states_B_C, - cache_params.conv_states[self.layer_idx], + cache_params.layers[self.layer_idx].conv_states, self.conv1d.weight.squeeze(1), self.conv1d.bias, self.activation, ) @@ -521,7 +409,7 @@ def cuda_kernels_forward( C = C.view(batch_size, self.n_groups, C.shape[1] // self.n_groups) hidden_states_reshaped = hidden_states.view(batch_size, self.num_heads, self.head_dim) hidden_states = selective_state_update( - cache_params.ssm_states[self.layer_idx], + cache_params.layers[self.layer_idx].recurrent_states, hidden_states_reshaped, dt, A, @@ -581,7 +469,7 @@ def cuda_kernels_forward( hidden_states_B_C_transposed, (self.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0), ) - cache_params.conv_states[self.layer_idx].copy_(conv_states) + conv_states = cache_params.update_conv_state(conv_states, self.layer_idx) if self.activation not in ["silu", "swish"]: hidden_states_B_C = self.act( @@ -622,7 +510,7 @@ def cuda_kernels_forward( # Init cache if ssm_state is not None and cache_params is not None: - cache_params.ssm_states[self.layer_idx].copy_(ssm_state) + ssm_state = cache_params.update_recurrent_state(ssm_state, self.layer_idx) scan_output = scan_output.view(batch_size, seq_len, -1) # Multiply "gate" branch and apply extra normalization layer @@ -636,7 +524,7 @@ def cuda_kernels_forward( def torch_forward( self, input_states, - cache_params: HybridMambaAttentionDynamicCache | None = None, + cache_params: Cache | None = None, attention_mask: torch.Tensor | None = None, ): batch_size, seq_len, _ = input_states.shape @@ -648,23 +536,13 @@ def torch_forward( gate, hidden_states_B_C, dt = projected_states.split( [self.intermediate_size, self.conv_dim, self.num_heads], dim=-1 ) + hidden_states_B_C = hidden_states_B_C.transpose(1, 2) - use_precomputed_states = ( - cache_params is not None - and cache_params.has_previous_state - and seq_len == 1 - and cache_params.conv_states[self.layer_idx].shape[0] - == cache_params.ssm_states[self.layer_idx].shape[0] - == batch_size - ) + use_precomputed_states = cache_params is not None and cache_params.has_previous_state(self.layer_idx) and seq_len == 1 # 2.
Convolution sequence transformation if use_precomputed_states: - cache_params.conv_states[self.layer_idx] = cache_params.conv_states[self.layer_idx].roll(shifts=-1, dims=-1) - cache_params.conv_states[self.layer_idx][:, :, -1] = hidden_states_B_C[:, 0, :].to(cache_params.conv_states[self.layer_idx].device) - - # We need to guarantee that anything regarding the cache is on the same device - conv_states = cache_params.conv_states[self.layer_idx].to(device=self.conv1d.weight.device) + conv_states = cache_params.update_conv_state(hidden_states_B_C, self.layer_idx) hidden_states_B_C = torch.sum( conv_states * self.conv1d.weight.squeeze(1), dim=-1 @@ -675,13 +553,12 @@ def torch_forward( else: # Init cache if cache_params is not None: - hidden_states_B_C_transposed = hidden_states_B_C.transpose(1, 2) conv_states = nn.functional.pad( - hidden_states_B_C_transposed, (self.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0) + hidden_states_B_C, (self.conv_kernel_size - hidden_states_B_C.shape[-1], 0) ) - cache_params.conv_states[self.layer_idx].copy_(conv_states) + conv_states = cache_params.update_conv_state(conv_states, self.layer_idx) - hidden_states_B_C = self.act(self.conv1d(hidden_states_B_C.transpose(1, 2))[..., :seq_len].transpose(1, 2)) + hidden_states_B_C = self.act(self.conv1d(hidden_states_B_C)[..., :seq_len].transpose(1, 2)) hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask) hidden_states, B, C = torch.split( @@ -694,7 +571,7 @@ def torch_forward( A = -torch.exp(self.A_log.float()) # [num_heads] if use_precomputed_states: # We need to guarantee that anything regarding the cache is on the same device - cache_device = cache_params.ssm_states[self.layer_idx].device + cache_device = cache_params.layers[self.layer_idx].recurrent_states.device # Note: there is no need to pad parameter matrices here, as there is just one new token # for batched generation @@ -724,9 +601,8 @@ def torch_forward( dBx = (dB * hidden_states[..., None]).to(device=cache_device) # State calculation - cache_params.ssm_states[self.layer_idx].copy_( - cache_params.ssm_states[self.layer_idx] * dA + dBx - ) + ssm_states = cache_params.layers[self.layer_idx].recurrent_states * dA + dBx + ssm_states = cache_params.update_recurrent_state(ssm_states, self.layer_idx) # Subsequent output # [bsz, n_groups * state_size] -> [bsz, num_heads, state_size] @@ -735,7 +611,7 @@ def torch_forward( C = C.reshape(batch_size, -1, C.shape[-1]) # [bsz, num_heads, head_dim] - ssm_states = cache_params.ssm_states[self.layer_idx].to(device=C.device, dtype=C.dtype) # Shape: [b, h, d, n] + ssm_states = ssm_states.to(device=C.device, dtype=C.dtype) # Shape: [b, h, d, n] # Reshape ssm_states to merge the first two dimensions ssm_states_reshaped = ssm_states.view(batch_size * self.num_heads, self.head_dim, self.ssm_state_size) # Shape: [b*h, d, n] C_reshaped = C.view(batch_size * self.num_heads, self.ssm_state_size, 1) # Shape: [b*h, n, 1] @@ -796,10 +672,7 @@ def torch_forward( # 3. 
Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries # (middle term of factorization of off-diag blocks; A terms) - if use_precomputed_states: - previous_states = cache_params.ssm_states[self.layer_idx][:, None, ...].to(device=states.device) - else: - previous_states = torch.zeros_like(states[:, :1]) + previous_states = torch.zeros_like(states[:, :1]) states = torch.cat([previous_states, states], dim=1) decay_chunk = torch.exp(segment_sum(nn.functional.pad(A_cumsum[:, :, :, -1], (1, 0)))) decay_chunk = decay_chunk.transpose(1, 3) @@ -826,7 +699,7 @@ def torch_forward( # Init cache if ssm_state is not None and cache_params is not None: - cache_params.ssm_states[self.layer_idx].copy_(ssm_state) + ssm_state = cache_params.update_recurrent_state(ssm_state, self.layer_idx) scan_output = self.norm(y, gate) @@ -840,7 +713,7 @@ def torch_forward( def forward( self, hidden_states, - cache_params: HybridMambaAttentionDynamicCache | None = None, + cache_params: Cache | None = None, attention_mask: torch.Tensor | None = None, seq_idx: torch.IntTensor | None = None, **kwargs, @@ -1293,6 +1166,9 @@ def forward( inputs_embeds = inputs_embeds * self.embedding_multiplier + if use_cache and past_key_values is None: + past_key_values = DynamicCache(config=self.config) + if position_ids is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens @@ -1326,9 +1202,6 @@ def forward( ) hidden_states = self.norm(hidden_states) - if past_key_values and not past_key_values.has_previous_state: - past_key_values.has_previous_state = True - return MoeModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values, @@ -1341,7 +1214,7 @@ def _update_mamba_mask(self, attention_mask, past_key_values): 2. Attending to all inputs """ mamba_mask = attention_mask - if (past_key_values is not None and past_key_values.has_previous_state) or ( + if (past_key_values is not None and past_key_values.has_previous_state()) or ( attention_mask is not None and torch.all(attention_mask == 1) ): mamba_mask = None @@ -1534,36 +1407,5 @@ def forward( router_logits=outputs.router_logits, ) - def prepare_inputs_for_generation( - self, - input_ids, - past_key_values=None, - attention_mask=None, - inputs_embeds=None, - position_ids=None, - use_cache=True, - is_first_iteration=False, - **kwargs, - ): - # Overwritten -- has a unique cache type, `HybridMambaAttentionDynamicCache` - - if past_key_values is None and use_cache: - past_key_values = HybridMambaAttentionDynamicCache( - self.config, input_ids.shape[0], self.dtype, device=self.device - ) - - model_inputs = super().prepare_inputs_for_generation( - input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - inputs_embeds=inputs_embeds, - position_ids=position_ids, - use_cache=use_cache, - is_first_iteration=is_first_iteration, - **kwargs, - ) - - return model_inputs - __all__ = ["GraniteMoeHybridForCausalLM", "GraniteMoeHybridModel", "GraniteMoeHybridPreTrainedModel"] diff --git a/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py b/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py index 4c72531bddb5..93291275b5b7 100644 --- a/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py +++ b/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py @@ -18,7 +18,7 @@ from torch import nn from ... 
import initialization as init -from ...cache_utils import Cache +from ...cache_utils import Cache, DynamicCache from ...masking_utils import create_causal_mask from ...modeling_outputs import BaseModelOutputWithPast, MoeModelOutputWithPast from ...modeling_utils import ALL_ATTENTION_FUNCTIONS @@ -27,7 +27,7 @@ from ...utils.generic import merge_with_config_defaults from ...utils.output_capturing import capture_outputs from ..bamba.configuration_bamba import BambaConfig -from ..bamba.modeling_bamba import BambaMixer, BambaRMSNormGated, HybridMambaAttentionDynamicCache +from ..bamba.modeling_bamba import BambaMixer, BambaRMSNormGated from ..gemma2.modeling_gemma2 import Gemma2RotaryEmbedding from ..granitemoeshared.modeling_granitemoeshared import ( GraniteFlashAttentionKwargs, @@ -226,6 +226,9 @@ def forward( inputs_embeds = inputs_embeds * self.embedding_multiplier + if use_cache and past_key_values is None: + past_key_values = DynamicCache(config=self.config) + if position_ids is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens @@ -259,9 +262,6 @@ def forward( ) hidden_states = self.norm(hidden_states) - if past_key_values and not past_key_values.has_previous_state: - past_key_values.has_previous_state = True - return MoeModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values, @@ -274,7 +274,7 @@ def _update_mamba_mask(self, attention_mask, past_key_values): 2. Attending to all inputs """ mamba_mask = attention_mask - if (past_key_values is not None and past_key_values.has_previous_state) or ( + if (past_key_values is not None and past_key_values.has_previous_state()) or ( attention_mask is not None and torch.all(attention_mask == 1) ): mamba_mask = None @@ -315,36 +315,5 @@ def forward(self, **super_kwargs): ```""" return super().forward(**super_kwargs) - def prepare_inputs_for_generation( - self, - input_ids, - past_key_values=None, - attention_mask=None, - inputs_embeds=None, - position_ids=None, - use_cache=True, - is_first_iteration=False, - **kwargs, - ): - # Overwritten -- has a unique cache type, `HybridMambaAttentionDynamicCache` - - if past_key_values is None and use_cache: - past_key_values = HybridMambaAttentionDynamicCache( - self.config, input_ids.shape[0], self.dtype, device=self.device - ) - - model_inputs = super().prepare_inputs_for_generation( - input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - inputs_embeds=inputs_embeds, - position_ids=position_ids, - use_cache=use_cache, - is_first_iteration=is_first_iteration, - **kwargs, - ) - - return model_inputs - __all__ = ["GraniteMoeHybridForCausalLM", "GraniteMoeHybridModel", "GraniteMoeHybridPreTrainedModel"] diff --git a/src/transformers/models/jamba/configuration_jamba.py b/src/transformers/models/jamba/configuration_jamba.py index 17d384e1eba6..0e0465b0ed1b 100644 --- a/src/transformers/models/jamba/configuration_jamba.py +++ b/src/transformers/models/jamba/configuration_jamba.py @@ -93,6 +93,12 @@ def layers_block_type(self): for i in range(self.num_hidden_layers) ] + @property + def layer_types(self): + # Follow the `layer_types` conventions + layer_types = self.layers_block_type + return ["full_attention" if x == "attention" else x for x in layer_types] + @property def layers_num_experts(self): return [ diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py 
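The `layer_types` property added to `JambaConfig` above is a pure renaming shim. For example (hypothetical layout), a config whose `layers_block_type` is `["mamba", "attention"]` now also exposes `layer_types == ["mamba", "full_attention"]`, matching the convention used by other hybrid models:

```python
# Illustrative only: the same mapping the new property performs.
layers_block_type = ["mamba", "attention", "mamba", "attention"]  # hypothetical layout
layer_types = ["full_attention" if x == "attention" else x for x in layers_block_type]
assert layer_types == ["mamba", "full_attention", "mamba", "full_attention"]
```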
index 3d7986933395..ae618fb4a2b3 100755 --- a/src/transformers/models/jamba/modeling_jamba.py +++ b/src/transformers/models/jamba/modeling_jamba.py @@ -23,13 +23,13 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Callable -from typing import Any import torch from torch import nn from ... import initialization as init from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...integrations import ( lazy_load_kernel, @@ -74,100 +74,6 @@ def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" -class HybridMambaAttentionDynamicCache: - """ - A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache - (which has a constant shape regardless of seq_len). - - This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states` - and `ssm_states` for mamba cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor - For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`, - while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors). - For mamba layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors), - while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`, - and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`. - """ - - is_compileable = False - - def __init__(self, config, batch_size, dtype=torch.float16, device=None): - self.dtype = dtype - self.layers_block_type = config.layers_block_type - self.has_previous_state = False # only used by mamba - intermediate_size = config.mamba_expand * config.hidden_size - ssm_state_size = config.mamba_d_state - conv_kernel_size = config.mamba_d_conv - self.conv_states = [] - self.ssm_states = [] - self.transformer_layers = [] - for i in range(config.num_hidden_layers): - if self.layers_block_type[i] == "mamba": - self.conv_states += [ - torch.zeros(batch_size, intermediate_size, conv_kernel_size, device=device, dtype=dtype) - ] - self.ssm_states += [ - torch.zeros(batch_size, intermediate_size, ssm_state_size, device=device, dtype=dtype) - ] - else: - self.conv_states += [torch.tensor([[]] * batch_size, device=device)] - self.ssm_states += [torch.tensor([[]] * batch_size, device=device)] - self.transformer_layers.append(i) - - self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] - self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] - - def __len__(self): - return len(self.key_cache) - - def __getitem__(self, layer_idx): - return self.key_cache[layer_idx], self.value_cache[layer_idx] - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: dict[str, Any] | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - # Update the cache - if self.key_cache[layer_idx].shape[-1] == 0: - self.key_cache[layer_idx] = key_states - self.value_cache[layer_idx] = value_states - else: - self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2) - self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2) - - return 
self.key_cache[layer_idx], self.value_cache[layer_idx] - - def reorder_cache(self, beam_idx: torch.LongTensor): - """Reorders the cache for beam search, given the selected beam indices.""" - if self.get_seq_length() > 0: - for layer_idx in range(len(self.key_cache)): - device = self.key_cache[layer_idx].device - self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) - device = self.value_cache[layer_idx].device - self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) - - device = self.conv_states[layer_idx].device - self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx.to(device)) - device = self.ssm_states[layer_idx].device - self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device)) - - def get_mask_sizes(self, query_length: int, layer_idx: int) -> tuple[int, int]: - """Return the length and offset of the cache, used to generate the mask""" - kv_offset = 0 - kv_length = self.get_seq_length(layer_idx) + query_length - return kv_length, kv_offset - - def get_seq_length(self, layer_idx: int | None = 0) -> int: - """Returns the sequence length of the cached states. A layer index can be optionally passed.""" - # take any layer that contains cache and not empty tensor - layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx - if len(self.key_cache) <= layer_idx or self.key_cache[layer_idx].shape[-1] == 0: - return 0 - return self.key_cache[layer_idx].shape[-2] - - def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] @@ -260,7 +166,7 @@ def forward( self, hidden_states: torch.Tensor, attention_mask: torch.Tensor | None = None, - past_key_values: HybridMambaAttentionDynamicCache | None = None, + past_key_values: Cache | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.Tensor, torch.Tensor | None]: input_shape = hidden_states.shape[:-1] @@ -371,17 +277,12 @@ def __init__(self, config: JambaConfig, layer_idx): def cuda_kernels_forward( self, hidden_states: torch.Tensor, - cache_params: HybridMambaAttentionDynamicCache | None = None, + cache_params: Cache | None = None, attention_mask: torch.LongTensor | None = None, ): batch_size, seq_len, _ = hidden_states.shape use_precomputed_states = ( - cache_params is not None - and cache_params.has_previous_state - and seq_len == 1 - and cache_params.conv_states[self.layer_idx].shape[0] - == cache_params.ssm_states[self.layer_idx].shape[0] - == batch_size + cache_params is not None and cache_params.has_previous_state(self.layer_idx) and seq_len == 1 ) # 1. 
Gated MLP's linear projection projected_states = self.in_proj(hidden_states).transpose(1, 2) @@ -398,7 +299,7 @@ def cuda_kernels_forward( if use_precomputed_states: hidden_states = causal_conv1d_update( hidden_states.squeeze(-1), - cache_params.conv_states[self.layer_idx], + cache_params.layers[self.layer_idx].conv_states, conv_weights, self.conv1d.bias, self.activation, ) @@ -407,7 +308,7 @@ else: if cache_params is not None: conv_states = nn.functional.pad(hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0)) - cache_params.conv_states[self.layer_idx].copy_(conv_states) + cache_params.update_conv_state(conv_states, self.layer_idx) hidden_states = causal_conv1d_fn(hidden_states, conv_weights, self.conv1d.bias, activation=self.activation) if attention_mask is not None: @@ -442,7 +343,7 @@ def cuda_kernels_forward( time_proj_bias = time_proj_bias.float() if time_proj_bias is not None else None if use_precomputed_states: scan_outputs = selective_state_update( - cache_params.ssm_states[self.layer_idx], + cache_params.layers[self.layer_idx].recurrent_states, hidden_states[..., 0], discrete_time_step[..., 0], A, @@ -467,7 +368,7 @@ def cuda_kernels_forward( return_last_state=True, ) if ssm_state is not None and cache_params is not None: - cache_params.ssm_states[self.layer_idx].copy_(ssm_state) + cache_params.update_recurrent_state(ssm_state, self.layer_idx) # 4. Final linear projection contextualized_states = self.out_proj(scan_outputs.transpose(1, 2)) @@ -475,7 +376,7 @@ def cuda_kernels_forward( return contextualized_states # fmt: off - def slow_forward(self, input_states, cache_params: HybridMambaAttentionDynamicCache | None = None, attention_mask: torch.LongTensor | None = None): + def slow_forward(self, input_states, cache_params: Cache | None = None, attention_mask: torch.LongTensor | None = None): batch_size, seq_len, _ = input_states.shape dtype = input_states.dtype # 1. Gated MLP's linear projection @@ -485,23 +386,19 @@ def slow_forward(self, input_states, cache_params: HybridMambaAttentionDynamicCa if attention_mask is not None: hidden_states = hidden_states * attention_mask.unsqueeze(1) - use_cache = isinstance(cache_params, HybridMambaAttentionDynamicCache) - # 2. Convolution sequence transformation - if use_cache and cache_params.ssm_states[self.layer_idx].shape[0] == batch_size: - if self.training: - # In training mode, we don't want to perform in-place operations on ssm_state so we can compute the backwards pass - ssm_state = cache_params.ssm_states[self.layer_idx].clone() - else: - ssm_state = cache_params.ssm_states[self.layer_idx] - - ssm_state = ssm_state.to(hidden_states.device) + if cache_params is not None and cache_params.has_previous_state(self.layer_idx): + # Clone the cached state so that in-place operations on it don't break the backward pass when training + ssm_state = cache_params.layers[self.layer_idx].recurrent_states.clone() + else: + ssm_state = torch.zeros( + (batch_size, self.intermediate_size, self.ssm_state_size), + device=hidden_states.device, dtype=dtype + ) - if cache_params.has_previous_state and seq_len == 1 and \ - cache_params.conv_states[self.layer_idx].shape[0] == batch_size: - conv_state = cache_params.conv_states[self.layer_idx] - conv_state = torch.roll(conv_state, shifts=-1, dims=-1) - conv_state[:, :, -1] = hidden_states[:, :, 0] - cache_params.conv_states[self.layer_idx] = conv_state + # 2.
Convolution sequence transformation + if cache_params is not None: + if cache_params.has_previous_state(self.layer_idx) and seq_len == 1: + conv_state = cache_params.update_conv_state(hidden_states, self.layer_idx) hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1) if self.use_conv_bias: hidden_states += self.conv1d.bias @@ -511,13 +408,9 @@ def slow_forward(self, input_states, cache_params: HybridMambaAttentionDynamicCa hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0) ) - cache_params.conv_states[self.layer_idx] = conv_state + conv_state = cache_params.update_conv_state(conv_state, self.layer_idx) hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) else: - ssm_state = torch.zeros( - (batch_size, self.intermediate_size, self.ssm_state_size), - device=hidden_states.device, dtype=dtype - ) hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) if attention_mask is not None: @@ -552,8 +445,8 @@ def slow_forward(self, input_states, cache_params: HybridMambaAttentionDynamicCa scan_output = scan_output + (hidden_states * self.D[None, :, None]) scan_output = (scan_output * self.act(gate)) - if use_cache: - cache_params.ssm_states[self.layer_idx] = ssm_state + if cache_params is not None: + cache_params.update_recurrent_state(ssm_state, self.layer_idx) # 4. Final linear projection contextualized_states = self.out_proj(scan_output.transpose(1, 2)) @@ -563,7 +456,7 @@ def slow_forward(self, input_states, cache_params: HybridMambaAttentionDynamicCa def forward( self, hidden_states, - cache_params: HybridMambaAttentionDynamicCache | None = None, + cache_params: Cache | None = None, attention_mask: torch.LongTensor | None = None, ): if self.config.use_mamba_kernels and ( @@ -690,7 +583,7 @@ def forward( hidden_states: torch.Tensor, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, - past_key_values: HybridMambaAttentionDynamicCache | None = None, + past_key_values: Cache | None = None, use_cache: bool | None = False, **kwargs: Unpack[TransformersKwargs], ) -> torch.FloatTensor: @@ -727,7 +620,7 @@ def forward( hidden_states: torch.Tensor, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, - past_key_values: HybridMambaAttentionDynamicCache | None = None, + past_key_values: Cache | None = None, **kwargs: Unpack[TransformersKwargs], ) -> torch.FloatTensor: residual = hidden_states @@ -804,7 +697,7 @@ def forward( input_ids: torch.LongTensor | None = None, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, - past_key_values: HybridMambaAttentionDynamicCache | None = None, + past_key_values: Cache | None = None, inputs_embeds: torch.FloatTensor | None = None, use_cache: bool | None = None, **kwargs: Unpack[TransformersKwargs], @@ -816,12 +709,7 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) if use_cache and past_key_values is None: - past_key_values = HybridMambaAttentionDynamicCache( - config=self.config, - batch_size=inputs_embeds.shape[0], - dtype=inputs_embeds.dtype, - device=inputs_embeds.device, - ) + past_key_values = DynamicCache(config=self.config) if position_ids is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 @@ -851,9 +739,6 @@ def forward( hidden_states = self.final_layernorm(hidden_states) - if past_key_values and not past_key_values.has_previous_state: - past_key_values.has_previous_state = True - return MoeModelOutputWithPast( 
last_hidden_state=hidden_states, past_key_values=past_key_values, @@ -866,7 +751,7 @@ def _update_mamba_mask(self, attention_mask, past_key_values): 2. Attending to all inputs """ mamba_mask = attention_mask - if (past_key_values is not None and past_key_values.has_previous_state) or ( + if (past_key_values is not None and past_key_values.has_previous_state()) or ( attention_mask is not None and torch.all(attention_mask == 1) ): mamba_mask = None @@ -980,7 +865,7 @@ def forward( input_ids: torch.LongTensor | None = None, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, - past_key_values: HybridMambaAttentionDynamicCache | None = None, + past_key_values: Cache | None = None, inputs_embeds: torch.FloatTensor | None = None, labels: torch.LongTensor | None = None, use_cache: bool | None = None, diff --git a/src/transformers/models/jamba/modular_jamba.py b/src/transformers/models/jamba/modular_jamba.py index da80fbd7187e..21e6623d3296 100644 --- a/src/transformers/models/jamba/modular_jamba.py +++ b/src/transformers/models/jamba/modular_jamba.py @@ -17,13 +17,13 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Callable -from typing import Any import torch from torch import nn from ... import initialization as init from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache from ...integrations import lazy_load_kernel from ...masking_utils import create_causal_mask from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer @@ -47,100 +47,6 @@ class JambaRMSNorm(LlamaRMSNorm): pass -class HybridMambaAttentionDynamicCache: - """ - A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache - (which has a constant shape regardless of seq_len). - - This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states` - and `ssm_states` for mamba cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor - For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`, - while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors). - For mamba layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors), - while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`, - and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`. 
- """ - - is_compileable = False - - def __init__(self, config, batch_size, dtype=torch.float16, device=None): - self.dtype = dtype - self.layers_block_type = config.layers_block_type - self.has_previous_state = False # only used by mamba - intermediate_size = config.mamba_expand * config.hidden_size - ssm_state_size = config.mamba_d_state - conv_kernel_size = config.mamba_d_conv - self.conv_states = [] - self.ssm_states = [] - self.transformer_layers = [] - for i in range(config.num_hidden_layers): - if self.layers_block_type[i] == "mamba": - self.conv_states += [ - torch.zeros(batch_size, intermediate_size, conv_kernel_size, device=device, dtype=dtype) - ] - self.ssm_states += [ - torch.zeros(batch_size, intermediate_size, ssm_state_size, device=device, dtype=dtype) - ] - else: - self.conv_states += [torch.tensor([[]] * batch_size, device=device)] - self.ssm_states += [torch.tensor([[]] * batch_size, device=device)] - self.transformer_layers.append(i) - - self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] - self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] - - def __len__(self): - return len(self.key_cache) - - def __getitem__(self, layer_idx): - return self.key_cache[layer_idx], self.value_cache[layer_idx] - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: dict[str, Any] | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - # Update the cache - if self.key_cache[layer_idx].shape[-1] == 0: - self.key_cache[layer_idx] = key_states - self.value_cache[layer_idx] = value_states - else: - self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2) - self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2) - - return self.key_cache[layer_idx], self.value_cache[layer_idx] - - def reorder_cache(self, beam_idx: torch.LongTensor): - """Reorders the cache for beam search, given the selected beam indices.""" - if self.get_seq_length() > 0: - for layer_idx in range(len(self.key_cache)): - device = self.key_cache[layer_idx].device - self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) - device = self.value_cache[layer_idx].device - self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) - - device = self.conv_states[layer_idx].device - self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx.to(device)) - device = self.ssm_states[layer_idx].device - self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device)) - - def get_mask_sizes(self, query_length: int, layer_idx: int) -> tuple[int, int]: - """Return the length and offset of the cache, used to generate the mask""" - kv_offset = 0 - kv_length = self.get_seq_length(layer_idx) + query_length - return kv_length, kv_offset - - def get_seq_length(self, layer_idx: int | None = 0) -> int: - """Returns the sequence length of the cached states. 
A layer index can be optionally passed.""" - # take any layer that contains cache and not empty tensor - layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx - if len(self.key_cache) <= layer_idx or self.key_cache[layer_idx].shape[-1] == 0: - return 0 - return self.key_cache[layer_idx].shape[-2] - - class JambaAttention(LlamaAttention): def __init__(self, config: JambaConfig, layer_idx: int): super().__init__(config, layer_idx) @@ -153,7 +59,7 @@ def forward( self, hidden_states: torch.Tensor, attention_mask: torch.Tensor | None = None, - past_key_values: HybridMambaAttentionDynamicCache | None = None, + past_key_values: Cache | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.Tensor, torch.Tensor | None]: input_shape = hidden_states.shape[:-1] @@ -264,17 +170,12 @@ def __init__(self, config: JambaConfig, layer_idx): def cuda_kernels_forward( self, hidden_states: torch.Tensor, - cache_params: HybridMambaAttentionDynamicCache | None = None, + cache_params: Cache | None = None, attention_mask: torch.LongTensor | None = None, ): batch_size, seq_len, _ = hidden_states.shape use_precomputed_states = ( - cache_params is not None - and cache_params.has_previous_state - and seq_len == 1 - and cache_params.conv_states[self.layer_idx].shape[0] - == cache_params.ssm_states[self.layer_idx].shape[0] - == batch_size + cache_params is not None and cache_params.has_previous_state(self.layer_idx) and seq_len == 1 ) # 1. Gated MLP's linear projection projected_states = self.in_proj(hidden_states).transpose(1, 2) @@ -291,7 +192,7 @@ def cuda_kernels_forward( if use_precomputed_states: hidden_states = causal_conv1d_update( hidden_states.squeeze(-1), - cache_params.conv_states[self.layer_idx], + cache_params.layers[self.layer_idx].conv_states, conv_weights, self.conv1d.bias, self.activation, @@ -300,7 +201,7 @@ def cuda_kernels_forward( else: if cache_params is not None: conv_states = nn.functional.pad(hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0)) - cache_params.conv_states[self.layer_idx].copy_(conv_states) + cache_params.update_conv_state(conv_states, self.layer_idx) hidden_states = causal_conv1d_fn(hidden_states, conv_weights, self.conv1d.bias, activation=self.activation) if attention_mask is not None: @@ -335,7 +236,7 @@ def cuda_kernels_forward( time_proj_bias = time_proj_bias.float() if time_proj_bias is not None else None if use_precomputed_states: scan_outputs = selective_state_update( - cache_params.ssm_states[self.layer_idx], + cache_params.layers[self.layer_idx].recurrent_states, hidden_states[..., 0], discrete_time_step[..., 0], A, @@ -360,7 +261,7 @@ def cuda_kernels_forward( return_last_state=True, ) if ssm_state is not None and cache_params is not None: - cache_params.ssm_states[self.layer_idx].copy_(ssm_state) + cache_params.update_recurrent_state(ssm_state, self.layer_idx) # 4. Final linear projection contextualized_states = self.out_proj(scan_outputs.transpose(1, 2)) @@ -368,7 +269,7 @@ def cuda_kernels_forward( return contextualized_states # fmt: off - def slow_forward(self, input_states, cache_params: HybridMambaAttentionDynamicCache | None = None, attention_mask: torch.LongTensor | None = None): + def slow_forward(self, input_states, cache_params: Cache | None = None, attention_mask: torch.LongTensor | None = None): batch_size, seq_len, _ = input_states.shape dtype = input_states.dtype # 1. 
Gated MLP's linear projection @@ -378,23 +279,19 @@ def slow_forward(self, input_states, cache_params: HybridMambaAttentionDynamicCa if attention_mask is not None: hidden_states = hidden_states * attention_mask.unsqueeze(1) - use_cache = isinstance(cache_params, HybridMambaAttentionDynamicCache) - # 2. Convolution sequence transformation - if use_cache and cache_params.ssm_states[self.layer_idx].shape[0] == batch_size: - if self.training: - # In training mode, we don't want to perform in-place operations on ssm_state so we can compute the backwards pass - ssm_state = cache_params.ssm_states[self.layer_idx].clone() - else: - ssm_state = cache_params.ssm_states[self.layer_idx] - - ssm_state = ssm_state.to(hidden_states.device) + if cache_params is not None and cache_params.has_previous_state(self.layer_idx): + # Clone the cached state so that in-place operations on it don't break the backward pass when training + ssm_state = cache_params.layers[self.layer_idx].recurrent_states.clone() + else: + ssm_state = torch.zeros( + (batch_size, self.intermediate_size, self.ssm_state_size), + device=hidden_states.device, dtype=dtype + ) - if cache_params.has_previous_state and seq_len == 1 and \ - cache_params.conv_states[self.layer_idx].shape[0] == batch_size: - conv_state = cache_params.conv_states[self.layer_idx] - conv_state = torch.roll(conv_state, shifts=-1, dims=-1) - conv_state[:, :, -1] = hidden_states[:, :, 0] - cache_params.conv_states[self.layer_idx] = conv_state + # 2. Convolution sequence transformation + if cache_params is not None: + if cache_params.has_previous_state(self.layer_idx) and seq_len == 1: + conv_state = cache_params.update_conv_state(hidden_states, self.layer_idx) hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1) if self.use_conv_bias: hidden_states += self.conv1d.bias @@ -404,13 +301,9 @@ def slow_forward(self, input_states, cache_params: HybridMambaAttentionDynamicCa hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0) ) - cache_params.conv_states[self.layer_idx] = conv_state + conv_state = cache_params.update_conv_state(conv_state, self.layer_idx) hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) else: - ssm_state = torch.zeros( - (batch_size, self.intermediate_size, self.ssm_state_size), - device=hidden_states.device, dtype=dtype - ) hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) if attention_mask is not None: @@ -445,8 +338,8 @@ def slow_forward(self, input_states, cache_params: HybridMambaAttentionDynamicCa scan_output = scan_output + (hidden_states * self.D[None, :, None]) scan_output = (scan_output * self.act(gate)) - if use_cache: - cache_params.ssm_states[self.layer_idx] = ssm_state + if cache_params is not None: + cache_params.update_recurrent_state(ssm_state, self.layer_idx) # 4.
Final linear projection contextualized_states = self.out_proj(scan_output.transpose(1, 2)) @@ -456,7 +349,7 @@ def slow_forward(self, input_states, cache_params: HybridMambaAttentionDynamicCa def forward( self, hidden_states, - cache_params: HybridMambaAttentionDynamicCache | None = None, + cache_params: Cache | None = None, attention_mask: torch.LongTensor | None = None, ): if self.config.use_mamba_kernels and ( @@ -535,7 +428,7 @@ def forward( hidden_states: torch.Tensor, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, - past_key_values: HybridMambaAttentionDynamicCache | None = None, + past_key_values: Cache | None = None, use_cache: bool | None = False, **kwargs: Unpack[TransformersKwargs], ) -> torch.FloatTensor: @@ -572,7 +465,7 @@ def forward( hidden_states: torch.Tensor, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, - past_key_values: HybridMambaAttentionDynamicCache | None = None, + past_key_values: Cache | None = None, **kwargs: Unpack[TransformersKwargs], ) -> torch.FloatTensor: residual = hidden_states @@ -649,7 +542,7 @@ def forward( input_ids: torch.LongTensor | None = None, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, - past_key_values: HybridMambaAttentionDynamicCache | None = None, + past_key_values: Cache | None = None, inputs_embeds: torch.FloatTensor | None = None, use_cache: bool | None = None, **kwargs: Unpack[TransformersKwargs], @@ -661,12 +554,7 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) if use_cache and past_key_values is None: - past_key_values = HybridMambaAttentionDynamicCache( - config=self.config, - batch_size=inputs_embeds.shape[0], - dtype=inputs_embeds.dtype, - device=inputs_embeds.device, - ) + past_key_values = DynamicCache(config=self.config) if position_ids is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 @@ -696,9 +584,6 @@ def forward( hidden_states = self.final_layernorm(hidden_states) - if past_key_values and not past_key_values.has_previous_state: - past_key_values.has_previous_state = True - return MoeModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values, @@ -711,7 +596,7 @@ def _update_mamba_mask(self, attention_mask, past_key_values): 2. Attending to all inputs """ mamba_mask = attention_mask - if (past_key_values is not None and past_key_values.has_previous_state) or ( + if (past_key_values is not None and past_key_values.has_previous_state()) or ( attention_mask is not None and torch.all(attention_mask == 1) ): mamba_mask = None @@ -728,7 +613,7 @@ def forward( input_ids: torch.LongTensor | None = None, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, - past_key_values: HybridMambaAttentionDynamicCache | None = None, + past_key_values: Cache | None = None, inputs_embeds: torch.FloatTensor | None = None, labels: torch.LongTensor | None = None, use_cache: bool | None = None, diff --git a/src/transformers/models/lfm2/modeling_lfm2.py b/src/transformers/models/lfm2/modeling_lfm2.py index ad3d154c2d06..ef753e3b2893 100644 --- a/src/transformers/models/lfm2/modeling_lfm2.py +++ b/src/transformers/models/lfm2/modeling_lfm2.py @@ -18,13 +18,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
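The cache-creation hunks in these models all converge on the same pattern: instead of building a bespoke `HybridMambaAttentionDynamicCache`, the model lazily instantiates a config-aware `DynamicCache` on first use. A minimal toy sketch of that dispatch idea (assumed behavior, not the library implementation; we take it that `DynamicCache` reads something like `config.layer_types` to pick a per-layer cache type):

```python
import torch

class AttentionLayerCache:
    """KV tensors that grow along the sequence dimension."""
    def __init__(self):
        self.keys = None
        self.values = None

    def update(self, k: torch.Tensor, v: torch.Tensor):
        self.keys = k if self.keys is None else torch.cat([self.keys, k], dim=-2)
        self.values = v if self.values is None else torch.cat([self.values, v], dim=-2)
        return self.keys, self.values

class LinearAttentionLayerCache:
    """Fixed-shape conv/recurrent states for mamba- or conv-style layers."""
    def __init__(self):
        self.conv_states = None
        self.recurrent_states = None

def build_layer_caches(layer_types: list[str]):
    # One cache object per layer, chosen from the config's layer types.
    return [
        AttentionLayerCache() if t == "full_attention" else LinearAttentionLayerCache()
        for t in layer_types
    ]

layers = build_layer_caches(["mamba", "full_attention"])  # hypothetical layout
```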
from collections.abc import Callable -from typing import Any, Optional +from typing import Optional import torch import torch.nn.functional as F from torch import nn -from ...cache_utils import Cache +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func from ...masking_utils import create_causal_mask @@ -152,160 +152,6 @@ def forward(self, x): return self.w2(F.silu(self.w1(x)) * self.w3(x)) -class Lfm2HybridConvCache: - """ - Attention and conv cache for Lfm2. - - It stores the Key and Value states as a list of tensors, one for each layer. - Attention layer cache shape: `[batch_size, num_heads, seq_len, head_dim]`. - Conv layer cache shape: `[batch_size, hidden_size, L_cache-1]`. - """ - - # Override @property existing in Cache - max_batch_size = None - is_compileable = False - key_cache = None - value_cache = None - - def __init__( - self, - config: Lfm2Config, - max_batch_size: int, - dtype: torch.dtype = torch.float32, - device: torch.device | str | None = None, - ): - self.key_cache = [] - self.value_cache = [] - self.max_batch_size = max_batch_size - self.layer_types = config.layer_types - self.first_attention_layer = self.layer_types.index("full_attention") - self.last_conv_layer = len(self.layer_types) - self.layer_types[::-1].index("conv") - 1 - self.conv_L_cache = config.conv_L_cache - self._dtype = dtype - self.has_previous_state = False - - self.conv_cache: list[torch.Tensor] = [] - device = torch.device(device) if device is not None else None - - for _ in range(config.num_hidden_layers): - conv_state = torch.zeros( - self.max_batch_size, - config.hidden_size, - self.conv_L_cache, - dtype=self._dtype, - device=device, - ) - self.conv_cache.append(conv_state) - self.key_cache.append(torch.tensor([])) - self.value_cache.append(torch.tensor([])) - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: dict[str, Any] | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. - - Parameters: - key_states (`torch.Tensor`): - The new key states to cache. - value_states (`torch.Tensor`): - The new value states to cache. - layer_idx (`int`): - The index of the layer to cache the states for. - cache_kwargs (`Dict[str, Any]`, `optional`): - Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`. - - Return: - A tuple containing the updated key and value states. - """ - # Update the cache - if self.key_cache[layer_idx].numel() == 0: - self.key_cache[layer_idx] = key_states - self.value_cache[layer_idx] = value_states - else: - self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) - self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) - - return self.key_cache[layer_idx], self.value_cache[layer_idx] - - def update_conv_state( - self, layer_idx: int, new_conv_state: torch.Tensor, cache_init: bool = False - ) -> torch.Tensor: - # Technically, those update are not logically correct if the prefill is smaller than `conv_kernel_size`, - # as it will `roll` anyway in the first decoding step even though it should `roll` ONLY if the cache is already full. 
- # But since `conv_kernel_size=4` in practice, it's almost impossible to have a smaller prefill so it's mostly fine for now - if cache_init: - self.conv_cache[layer_idx] = new_conv_state.to(self.conv_cache[layer_idx].device) - else: - self.conv_cache[layer_idx] = self.conv_cache[layer_idx].roll(shifts=-1, dims=-1) - self.conv_cache[layer_idx][:, :, -1] = new_conv_state[:, :, -1].to(self.conv_cache[layer_idx].device) - - # If last layer is updated, set the flag - if layer_idx == self.last_conv_layer: - self.has_previous_state = True - - return self.conv_cache[layer_idx] - - def reorder_cache(self, beam_idx: torch.LongTensor): - """Reorders the cache for beam search, given the selected beam indices.""" - for layer_idx in range(len(self.key_cache)): - if self.key_cache[layer_idx].numel(): - device = self.key_cache[layer_idx].device - self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) - device = self.value_cache[layer_idx].device - self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) - - if self.conv_cache[layer_idx].numel(): - device = self.conv_cache[layer_idx].device - self.conv_cache[layer_idx] = self.conv_cache[layer_idx].index_select(0, beam_idx.to(device)) - - def get_seq_length(self, layer_idx: int | None = 0) -> int: - """Returns the sequence length of the cached states. A layer index can be optionally passed.""" - # take any layer that contains cache and not empty tensor - layer_idx = self.first_attention_layer if self.layer_types[layer_idx] != "full_attention" else layer_idx - if len(self.key_cache) <= layer_idx or self.key_cache[layer_idx].numel() == 0: - return 0 - return self.key_cache[layer_idx].shape[-2] - - def get_mask_sizes(self, query_length: int, layer_idx: int) -> tuple[int, int]: - """ - Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for - the given layer at `layer_idx`. - The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size), - for each layer. 
- """ - full_mask_kv_offset = 0 - past_seen_tokens = self.get_seq_length() - kv_length = query_length + past_seen_tokens - return kv_length, full_mask_kv_offset - - def crop(self, max_length: int): - """Crop the cache to the given length""" - if max_length < 0: - max_length = self.get_seq_length() - abs(max_length) - - if self.get_seq_length() <= max_length: - return - - for idx in range(len(self.key_cache)): - if self.key_cache[idx].numel(): - self.key_cache[idx] = self.key_cache[idx][..., :max_length, :] - self.value_cache[idx] = self.value_cache[idx][..., :max_length, :] - - def __len__(self) -> int: - return len(self.key_cache) - - def reset(self): - self.has_previous_state = False - for layer_idx in range(len(self.conv_cache)): - # In-place ops prevent breaking the static address - self.conv_cache[layer_idx].zero_() - - def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] @@ -400,7 +246,7 @@ def forward( hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: torch.Tensor | None, - past_key_values: Lfm2HybridConvCache | None = None, + past_key_values: Cache | None = None, **kwargs, ) -> tuple[torch.Tensor, torch.Tensor | None]: input_shape = hidden_states.shape[:-1] @@ -477,7 +323,7 @@ def __init__( def cuda_kernels_forward( self, x: torch.Tensor, - past_key_values: Lfm2HybridConvCache | None = None, + past_key_values: Cache | None = None, attention_mask: torch.Tensor | None = None, ): x = apply_mask_to_padding_states(x, attention_mask) @@ -487,10 +333,10 @@ def cuda_kernels_forward( Bx = B * x conv_weights = self.conv.weight.view(self.conv.weight.size(0), self.conv.weight.size(2)) - if past_key_values is not None and past_key_values.has_previous_state: + if past_key_values is not None and past_key_values.has_previous_state(self.layer_idx): conv_out = causal_conv1d_update( Bx.squeeze(-1), - past_key_values.conv_cache[self.layer_idx], + past_key_values.layers[self.layer_idx].conv_states, conv_weights, self.conv.bias, None, @@ -499,7 +345,7 @@ def cuda_kernels_forward( else: if past_key_values is not None: conv_state = nn.functional.pad(Bx, (self.L_cache - Bx.shape[-1], 0)) - past_key_values.update_conv_state(self.layer_idx, conv_state, cache_init=True) + conv_state = past_key_values.update_conv_state(conv_state, self.layer_idx) conv_out = causal_conv1d_fn(Bx, conv_weights, self.conv.bias, activation=None) @@ -510,7 +356,7 @@ def cuda_kernels_forward( def slow_forward( self, x: torch.Tensor, - past_key_values: Lfm2HybridConvCache | None = None, + past_key_values: Cache | None = None, attention_mask: torch.Tensor | None = None, ): seqlen = x.shape[1] @@ -521,8 +367,8 @@ def slow_forward( Bx = B * x - if past_key_values is not None and past_key_values.has_previous_state: - conv_state = past_key_values.update_conv_state(self.layer_idx, Bx, cache_init=False) + if past_key_values is not None and past_key_values.has_previous_state(self.layer_idx): + conv_state = past_key_values.update_conv_state(Bx, self.layer_idx) conv_out = torch.sum(conv_state.to(Bx.device) * self.conv.weight[:, 0, :], dim=-1) if self.bias: conv_out += self.conv.bias @@ -531,7 +377,7 @@ def slow_forward( else: if past_key_values is not None: conv_state = nn.functional.pad(Bx, (self.L_cache - Bx.shape[-1], 0)) - conv_state = past_key_values.update_conv_state(self.layer_idx, conv_state, cache_init=True) + conv_state = past_key_values.update_conv_state(conv_state, self.layer_idx) conv_out = self.conv(Bx)[..., :seqlen] @@ -543,7 +389,7 
@@ def slow_forward( def forward( self, hidden_states: torch.Tensor, - past_key_values: Lfm2HybridConvCache | None = None, + past_key_values: Cache | None = None, attention_mask: torch.Tensor | None = None, ): if is_fast_path_available and "cuda" in hidden_states.device.type and not is_torchdynamo_compiling(): @@ -570,7 +416,7 @@ def forward( position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, - past_key_values: Lfm2HybridConvCache | None = None, + past_key_values: Cache | None = None, **kwargs, ) -> torch.Tensor: residual = hidden_states @@ -639,7 +485,7 @@ def forward( input_ids: torch.LongTensor | None = None, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, - past_key_values: Lfm2HybridConvCache | None = None, + past_key_values: Cache | None = None, inputs_embeds: torch.FloatTensor | None = None, use_cache: bool | None = None, **kwargs: Unpack[TransformersKwargs], @@ -651,10 +497,7 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) if use_cache and past_key_values is None: - batch_size = inputs_embeds.shape[0] - past_key_values = Lfm2HybridConvCache( - config=self.config, max_batch_size=batch_size, dtype=self.dtype, device=self.device - ) + past_key_values = DynamicCache(config=self.config) if position_ids is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 diff --git a/src/transformers/models/lfm2/modular_lfm2.py b/src/transformers/models/lfm2/modular_lfm2.py index 6f3a754d69ae..c08b61956081 100644 --- a/src/transformers/models/lfm2/modular_lfm2.py +++ b/src/transformers/models/lfm2/modular_lfm2.py @@ -12,12 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Callable -from typing import Any import torch import torch.nn.functional as F from torch import nn +from ...cache_utils import Cache, DynamicCache from ...masking_utils import create_causal_mask from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast @@ -80,160 +80,6 @@ def forward(self, x): return self.w2(F.silu(self.w1(x)) * self.w3(x)) -class Lfm2HybridConvCache: - """ - Attention and conv cache for Lfm2. - - It stores the Key and Value states as a list of tensors, one for each layer. - Attention layer cache shape: `[batch_size, num_heads, seq_len, head_dim]`. - Conv layer cache shape: `[batch_size, hidden_size, L_cache-1]`. 
- """ - - # Override @property existing in Cache - max_batch_size = None - is_compileable = False - key_cache = None - value_cache = None - - def __init__( - self, - config: Lfm2Config, - max_batch_size: int, - dtype: torch.dtype = torch.float32, - device: torch.device | str | None = None, - ): - self.key_cache = [] - self.value_cache = [] - self.max_batch_size = max_batch_size - self.layer_types = config.layer_types - self.first_attention_layer = self.layer_types.index("full_attention") - self.last_conv_layer = len(self.layer_types) - self.layer_types[::-1].index("conv") - 1 - self.conv_L_cache = config.conv_L_cache - self._dtype = dtype - self.has_previous_state = False - - self.conv_cache: list[torch.Tensor] = [] - device = torch.device(device) if device is not None else None - - for _ in range(config.num_hidden_layers): - conv_state = torch.zeros( - self.max_batch_size, - config.hidden_size, - self.conv_L_cache, - dtype=self._dtype, - device=device, - ) - self.conv_cache.append(conv_state) - self.key_cache.append(torch.tensor([])) - self.value_cache.append(torch.tensor([])) - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: dict[str, Any] | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. - - Parameters: - key_states (`torch.Tensor`): - The new key states to cache. - value_states (`torch.Tensor`): - The new value states to cache. - layer_idx (`int`): - The index of the layer to cache the states for. - cache_kwargs (`Dict[str, Any]`, `optional`): - Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`. - - Return: - A tuple containing the updated key and value states. - """ - # Update the cache - if self.key_cache[layer_idx].numel() == 0: - self.key_cache[layer_idx] = key_states - self.value_cache[layer_idx] = value_states - else: - self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) - self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) - - return self.key_cache[layer_idx], self.value_cache[layer_idx] - - def update_conv_state( - self, layer_idx: int, new_conv_state: torch.Tensor, cache_init: bool = False - ) -> torch.Tensor: - # Technically, those update are not logically correct if the prefill is smaller than `conv_kernel_size`, - # as it will `roll` anyway in the first decoding step even though it should `roll` ONLY if the cache is already full. 
- # But since `conv_kernel_size=4` in practice, it's almost impossible to have a smaller prefill so it's mostly fine for now - if cache_init: - self.conv_cache[layer_idx] = new_conv_state.to(self.conv_cache[layer_idx].device) - else: - self.conv_cache[layer_idx] = self.conv_cache[layer_idx].roll(shifts=-1, dims=-1) - self.conv_cache[layer_idx][:, :, -1] = new_conv_state[:, :, -1].to(self.conv_cache[layer_idx].device) - - # If last layer is updated, set the flag - if layer_idx == self.last_conv_layer: - self.has_previous_state = True - - return self.conv_cache[layer_idx] - - def reorder_cache(self, beam_idx: torch.LongTensor): - """Reorders the cache for beam search, given the selected beam indices.""" - for layer_idx in range(len(self.key_cache)): - if self.key_cache[layer_idx].numel(): - device = self.key_cache[layer_idx].device - self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) - device = self.value_cache[layer_idx].device - self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) - - if self.conv_cache[layer_idx].numel(): - device = self.conv_cache[layer_idx].device - self.conv_cache[layer_idx] = self.conv_cache[layer_idx].index_select(0, beam_idx.to(device)) - - def get_seq_length(self, layer_idx: int | None = 0) -> int: - """Returns the sequence length of the cached states. A layer index can be optionally passed.""" - # take any layer that contains cache and not empty tensor - layer_idx = self.first_attention_layer if self.layer_types[layer_idx] != "full_attention" else layer_idx - if len(self.key_cache) <= layer_idx or self.key_cache[layer_idx].numel() == 0: - return 0 - return self.key_cache[layer_idx].shape[-2] - - def get_mask_sizes(self, query_length: int, layer_idx: int) -> tuple[int, int]: - """ - Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for - the given layer at `layer_idx`. - The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size), - for each layer. 
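Stepping back to `reorder_cache` a few lines up: a single `index_select` per tensor suffices because dim 0 of every cached state is the batch/beam dimension, so gathering along it re-aligns conv and KV states with the surviving beams. A toy illustration:

```python
import torch

conv_state = torch.arange(6.0).view(3, 1, 2)  # 3 beams: [beam, channels, kernel]
beam_idx = torch.tensor([2, 0, 0])            # beam 2 won; beam 0 was duplicated
reordered = conv_state.index_select(0, beam_idx)
assert torch.equal(reordered[0], conv_state[2])
assert torch.equal(reordered[1], reordered[2])
```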
- """ - full_mask_kv_offset = 0 - past_seen_tokens = self.get_seq_length() - kv_length = query_length + past_seen_tokens - return kv_length, full_mask_kv_offset - - def crop(self, max_length: int): - """Crop the cache to the given length""" - if max_length < 0: - max_length = self.get_seq_length() - abs(max_length) - - if self.get_seq_length() <= max_length: - return - - for idx in range(len(self.key_cache)): - if self.key_cache[idx].numel(): - self.key_cache[idx] = self.key_cache[idx][..., :max_length, :] - self.value_cache[idx] = self.value_cache[idx][..., :max_length, :] - - def __len__(self) -> int: - return len(self.key_cache) - - def reset(self): - self.has_previous_state = False - for layer_idx in range(len(self.conv_cache)): - # In-place ops prevent breaking the static address - self.conv_cache[layer_idx].zero_() - - class Lfm2Attention(LlamaAttention): def __init__(self, config: Lfm2Config, layer_idx: int): super().__init__(config, layer_idx) @@ -251,7 +97,7 @@ def forward( hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: torch.Tensor | None, - past_key_values: Lfm2HybridConvCache | None = None, + past_key_values: Cache | None = None, **kwargs, ) -> tuple[torch.Tensor, torch.Tensor | None]: input_shape = hidden_states.shape[:-1] @@ -312,7 +158,7 @@ def __init__( def cuda_kernels_forward( self, x: torch.Tensor, - past_key_values: Lfm2HybridConvCache | None = None, + past_key_values: Cache | None = None, attention_mask: torch.Tensor | None = None, ): x = apply_mask_to_padding_states(x, attention_mask) @@ -322,10 +168,10 @@ def cuda_kernels_forward( Bx = B * x conv_weights = self.conv.weight.view(self.conv.weight.size(0), self.conv.weight.size(2)) - if past_key_values is not None and past_key_values.has_previous_state: + if past_key_values is not None and past_key_values.has_previous_state(self.layer_idx): conv_out = causal_conv1d_update( Bx.squeeze(-1), - past_key_values.conv_cache[self.layer_idx], + past_key_values.layers[self.layer_idx].conv_states, conv_weights, self.conv.bias, None, @@ -334,7 +180,7 @@ def cuda_kernels_forward( else: if past_key_values is not None: conv_state = nn.functional.pad(Bx, (self.L_cache - Bx.shape[-1], 0)) - past_key_values.update_conv_state(self.layer_idx, conv_state, cache_init=True) + conv_state = past_key_values.update_conv_state(conv_state, self.layer_idx) conv_out = causal_conv1d_fn(Bx, conv_weights, self.conv.bias, activation=None) @@ -345,7 +191,7 @@ def cuda_kernels_forward( def slow_forward( self, x: torch.Tensor, - past_key_values: Lfm2HybridConvCache | None = None, + past_key_values: Cache | None = None, attention_mask: torch.Tensor | None = None, ): seqlen = x.shape[1] @@ -356,8 +202,8 @@ def slow_forward( Bx = B * x - if past_key_values is not None and past_key_values.has_previous_state: - conv_state = past_key_values.update_conv_state(self.layer_idx, Bx, cache_init=False) + if past_key_values is not None and past_key_values.has_previous_state(self.layer_idx): + conv_state = past_key_values.update_conv_state(Bx, self.layer_idx) conv_out = torch.sum(conv_state.to(Bx.device) * self.conv.weight[:, 0, :], dim=-1) if self.bias: conv_out += self.conv.bias @@ -366,7 +212,7 @@ def slow_forward( else: if past_key_values is not None: conv_state = nn.functional.pad(Bx, (self.L_cache - Bx.shape[-1], 0)) - conv_state = past_key_values.update_conv_state(self.layer_idx, conv_state, cache_init=True) + conv_state = past_key_values.update_conv_state(conv_state, self.layer_idx) conv_out = 
self.conv(Bx)[..., :seqlen] @@ -378,7 +224,7 @@ def slow_forward( def forward( self, hidden_states: torch.Tensor, - past_key_values: Lfm2HybridConvCache | None = None, + past_key_values: Cache | None = None, attention_mask: torch.Tensor | None = None, ): if is_fast_path_available and "cuda" in hidden_states.device.type and not is_torchdynamo_compiling(): @@ -405,7 +251,7 @@ def forward( position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, - past_key_values: Lfm2HybridConvCache | None = None, + past_key_values: Cache | None = None, **kwargs, ) -> torch.Tensor: residual = hidden_states @@ -445,7 +291,7 @@ def forward( input_ids: torch.LongTensor | None = None, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, - past_key_values: Lfm2HybridConvCache | None = None, + past_key_values: Cache | None = None, inputs_embeds: torch.FloatTensor | None = None, use_cache: bool | None = None, **kwargs: Unpack[TransformersKwargs], @@ -457,10 +303,7 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) if use_cache and past_key_values is None: - batch_size = inputs_embeds.shape[0] - past_key_values = Lfm2HybridConvCache( - config=self.config, max_batch_size=batch_size, dtype=self.dtype, device=self.device - ) + past_key_values = DynamicCache(config=self.config) if position_ids is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 diff --git a/src/transformers/models/lfm2_moe/modeling_lfm2_moe.py b/src/transformers/models/lfm2_moe/modeling_lfm2_moe.py index 2a5d4564a1e1..0369ae31b8ae 100644 --- a/src/transformers/models/lfm2_moe/modeling_lfm2_moe.py +++ b/src/transformers/models/lfm2_moe/modeling_lfm2_moe.py @@ -19,14 +19,14 @@ # limitations under the License. from collections.abc import Callable -from typing import Any, Optional +from typing import Optional import torch import torch.nn.functional as F from torch import nn from ... import initialization as init -from ...cache_utils import Cache +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...integrations import ( use_experts_implementation, @@ -228,160 +228,6 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens return final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) -class Lfm2MoeHybridConvCache: - """ - Attention and conv cache for Lfm2Moe. - - It stores the Key and Value states as a list of tensors, one for each layer. - Attention layer cache shape: `[batch_size, num_heads, seq_len, head_dim]`. - Conv layer cache shape: `[batch_size, hidden_size, L_cache-1]`. 
- """ - - # Override @property existing in Cache - max_batch_size = None - is_compileable = False - key_cache = None - value_cache = None - - def __init__( - self, - config: Lfm2MoeConfig, - max_batch_size: int, - dtype: torch.dtype = torch.float32, - device: torch.device | str | None = None, - ): - self.key_cache = [] - self.value_cache = [] - self.max_batch_size = max_batch_size - self.layer_types = config.layer_types - self.first_attention_layer = self.layer_types.index("full_attention") - self.last_conv_layer = len(self.layer_types) - self.layer_types[::-1].index("conv") - 1 - self.conv_L_cache = config.conv_L_cache - self._dtype = dtype - self.has_previous_state = False - - self.conv_cache: list[torch.Tensor] = [] - device = torch.device(device) if device is not None else None - - for _ in range(config.num_hidden_layers): - conv_state = torch.zeros( - self.max_batch_size, - config.hidden_size, - self.conv_L_cache, - dtype=self._dtype, - device=device, - ) - self.conv_cache.append(conv_state) - self.key_cache.append(torch.tensor([])) - self.value_cache.append(torch.tensor([])) - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: dict[str, Any] | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. - - Parameters: - key_states (`torch.Tensor`): - The new key states to cache. - value_states (`torch.Tensor`): - The new value states to cache. - layer_idx (`int`): - The index of the layer to cache the states for. - cache_kwargs (`Dict[str, Any]`, `optional`): - Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`. - - Return: - A tuple containing the updated key and value states. - """ - # Update the cache - if self.key_cache[layer_idx].numel() == 0: - self.key_cache[layer_idx] = key_states - self.value_cache[layer_idx] = value_states - else: - self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) - self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) - - return self.key_cache[layer_idx], self.value_cache[layer_idx] - - def update_conv_state( - self, layer_idx: int, new_conv_state: torch.Tensor, cache_init: bool = False - ) -> torch.Tensor: - # Technically, those update are not logically correct if the prefill is smaller than `conv_kernel_size`, - # as it will `roll` anyway in the first decoding step even though it should `roll` ONLY if the cache is already full. 
- # But since `conv_kernel_size=4` in practice, it's almost impossible to have a smaller prefill so it's mostly fine for now - if cache_init: - self.conv_cache[layer_idx] = new_conv_state.to(self.conv_cache[layer_idx].device) - else: - self.conv_cache[layer_idx] = self.conv_cache[layer_idx].roll(shifts=-1, dims=-1) - self.conv_cache[layer_idx][:, :, -1] = new_conv_state[:, :, -1].to(self.conv_cache[layer_idx].device) - - # If last layer is updated, set the flag - if layer_idx == self.last_conv_layer: - self.has_previous_state = True - - return self.conv_cache[layer_idx] - - def reorder_cache(self, beam_idx: torch.LongTensor): - """Reorders the cache for beam search, given the selected beam indices.""" - for layer_idx in range(len(self.key_cache)): - if self.key_cache[layer_idx].numel(): - device = self.key_cache[layer_idx].device - self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) - device = self.value_cache[layer_idx].device - self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) - - if self.conv_cache[layer_idx].numel(): - device = self.conv_cache[layer_idx].device - self.conv_cache[layer_idx] = self.conv_cache[layer_idx].index_select(0, beam_idx.to(device)) - - def get_seq_length(self, layer_idx: int | None = 0) -> int: - """Returns the sequence length of the cached states. A layer index can be optionally passed.""" - # take any layer that contains cache and not empty tensor - layer_idx = self.first_attention_layer if self.layer_types[layer_idx] != "full_attention" else layer_idx - if len(self.key_cache) <= layer_idx or self.key_cache[layer_idx].numel() == 0: - return 0 - return self.key_cache[layer_idx].shape[-2] - - def get_mask_sizes(self, query_length: int, layer_idx: int) -> tuple[int, int]: - """ - Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for - the given layer at `layer_idx`. - The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size), - for each layer. 
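The arithmetic behind `get_mask_sizes` is worth spelling out: for full-attention layers with no sliding window or chunking, the mask spans every cached token plus the new query block, at offset zero. The same computation as a standalone function (hypothetical, for illustration):

```python
def get_mask_sizes(query_length: int, past_seen_tokens: int) -> tuple[int, int]:
    kv_offset = 0                                # full attention starts at position 0
    kv_length = query_length + past_seen_tokens  # every cached token plus the new block
    return kv_length, kv_offset

# decoding one token with 7 tokens already cached -> mask over 8 key positions
assert get_mask_sizes(query_length=1, past_seen_tokens=7) == (8, 0)
```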
- """ - full_mask_kv_offset = 0 - past_seen_tokens = self.get_seq_length() - kv_length = query_length + past_seen_tokens - return kv_length, full_mask_kv_offset - - def crop(self, max_length: int): - """Crop the cache to the given length""" - if max_length < 0: - max_length = self.get_seq_length() - abs(max_length) - - if self.get_seq_length() <= max_length: - return - - for idx in range(len(self.key_cache)): - if self.key_cache[idx].numel(): - self.key_cache[idx] = self.key_cache[idx][..., :max_length, :] - self.value_cache[idx] = self.value_cache[idx][..., :max_length, :] - - def __len__(self) -> int: - return len(self.key_cache) - - def reset(self): - self.has_previous_state = False - for layer_idx in range(len(self.conv_cache)): - # In-place ops prevent breaking the static address - self.conv_cache[layer_idx].zero_() - - def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] @@ -476,7 +322,7 @@ def forward( hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: torch.Tensor | None, - past_key_values: Lfm2MoeHybridConvCache | None = None, + past_key_values: Cache | None = None, **kwargs, ) -> tuple[torch.Tensor, torch.Tensor | None]: input_shape = hidden_states.shape[:-1] @@ -553,7 +399,7 @@ def __init__( def cuda_kernels_forward( self, x: torch.Tensor, - past_key_values: Lfm2MoeHybridConvCache | None = None, + past_key_values: Cache | None = None, attention_mask: torch.Tensor | None = None, ): x = apply_mask_to_padding_states(x, attention_mask) @@ -563,10 +409,10 @@ def cuda_kernels_forward( Bx = B * x conv_weights = self.conv.weight.view(self.conv.weight.size(0), self.conv.weight.size(2)) - if past_key_values is not None and past_key_values.has_previous_state: + if past_key_values is not None and past_key_values.has_previous_state(self.layer_idx): conv_out = causal_conv1d_update( Bx.squeeze(-1), - past_key_values.conv_cache[self.layer_idx], + past_key_values.layers[self.layer_idx].conv_states, conv_weights, self.conv.bias, None, @@ -575,7 +421,7 @@ def cuda_kernels_forward( else: if past_key_values is not None: conv_state = nn.functional.pad(Bx, (self.L_cache - Bx.shape[-1], 0)) - past_key_values.update_conv_state(self.layer_idx, conv_state, cache_init=True) + conv_state = past_key_values.update_conv_state(conv_state, self.layer_idx) conv_out = causal_conv1d_fn(Bx, conv_weights, self.conv.bias, activation=None) @@ -586,7 +432,7 @@ def cuda_kernels_forward( def slow_forward( self, x: torch.Tensor, - past_key_values: Lfm2MoeHybridConvCache | None = None, + past_key_values: Cache | None = None, attention_mask: torch.Tensor | None = None, ): seqlen = x.shape[1] @@ -597,8 +443,8 @@ def slow_forward( Bx = B * x - if past_key_values is not None and past_key_values.has_previous_state: - conv_state = past_key_values.update_conv_state(self.layer_idx, Bx, cache_init=False) + if past_key_values is not None and past_key_values.has_previous_state(self.layer_idx): + conv_state = past_key_values.update_conv_state(Bx, self.layer_idx) conv_out = torch.sum(conv_state.to(Bx.device) * self.conv.weight[:, 0, :], dim=-1) if self.bias: conv_out += self.conv.bias @@ -607,7 +453,7 @@ def slow_forward( else: if past_key_values is not None: conv_state = nn.functional.pad(Bx, (self.L_cache - Bx.shape[-1], 0)) - conv_state = past_key_values.update_conv_state(self.layer_idx, conv_state, cache_init=True) + conv_state = past_key_values.update_conv_state(conv_state, self.layer_idx) conv_out = self.conv(Bx)[..., :seqlen] @@ 
-619,7 +465,7 @@ def slow_forward( def forward( self, hidden_states: torch.Tensor, - past_key_values: Lfm2MoeHybridConvCache | None = None, + past_key_values: Cache | None = None, attention_mask: torch.Tensor | None = None, ): if is_fast_path_available and "cuda" in hidden_states.device.type and not is_torchdynamo_compiling(): @@ -650,7 +496,7 @@ def forward( position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, - past_key_values: Lfm2MoeHybridConvCache | None = None, + past_key_values: Cache | None = None, **kwargs, ) -> torch.Tensor: residual = hidden_states @@ -685,7 +531,7 @@ class Lfm2MoePreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True - _can_compile_fullgraph = False # uses a non-compilable custom cache class Lfm2MoeHybridConvCache + _can_compile_fullgraph = False # uses a non-compilable cache class _supports_attention_backend = True _can_record_outputs = { "hidden_states": Lfm2MoeDecoderLayer, @@ -729,7 +575,7 @@ def forward( input_ids: torch.LongTensor | None = None, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, - past_key_values: Lfm2MoeHybridConvCache | None = None, + past_key_values: Cache | None = None, inputs_embeds: torch.FloatTensor | None = None, use_cache: bool | None = None, **kwargs: Unpack[TransformersKwargs], @@ -741,10 +587,7 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) if use_cache and past_key_values is None: - batch_size = inputs_embeds.shape[0] - past_key_values = Lfm2MoeHybridConvCache( - config=self.config, max_batch_size=batch_size, dtype=self.dtype, device=self.device - ) + past_key_values = DynamicCache(config=self.config) if position_ids is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 diff --git a/src/transformers/models/lfm2_moe/modular_lfm2_moe.py b/src/transformers/models/lfm2_moe/modular_lfm2_moe.py index 7d97d8b70dd5..a1b2799d2bae 100644 --- a/src/transformers/models/lfm2_moe/modular_lfm2_moe.py +++ b/src/transformers/models/lfm2_moe/modular_lfm2_moe.py @@ -17,6 +17,7 @@ from torch import nn from ... 
import initialization as init +from ...cache_utils import Cache, DynamicCache from ...masking_utils import create_causal_mask from ...modeling_outputs import MoeModelOutputWithPast from ...modeling_utils import PreTrainedModel @@ -26,7 +27,6 @@ from ..lfm2.modeling_lfm2 import ( Lfm2Attention, Lfm2DecoderLayer, - Lfm2HybridConvCache, Lfm2MLP, Lfm2RotaryEmbedding, Lfm2ShortConv, @@ -110,10 +110,6 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens return final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) -class Lfm2MoeHybridConvCache(Lfm2HybridConvCache): - pass - - class Lfm2MoeAttention(Lfm2Attention): pass @@ -133,7 +129,7 @@ def __init__(self, config: Lfm2MoeConfig, layer_idx: int): class Lfm2MoePreTrainedModel(LlamaPreTrainedModel): - _can_compile_fullgraph = False # uses a non-compilable custom cache class Lfm2MoeHybridConvCache + _can_compile_fullgraph = False # uses a non-compilable cache class @torch.no_grad() def _init_weights(self, module): @@ -159,7 +155,7 @@ def forward( input_ids: torch.LongTensor | None = None, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, - past_key_values: Lfm2MoeHybridConvCache | None = None, + past_key_values: Cache | None = None, inputs_embeds: torch.FloatTensor | None = None, use_cache: bool | None = None, **kwargs: Unpack[TransformersKwargs], @@ -171,10 +167,7 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) if use_cache and past_key_values is None: - batch_size = inputs_embeds.shape[0] - past_key_values = Lfm2MoeHybridConvCache( - config=self.config, max_batch_size=batch_size, dtype=self.dtype, device=self.device - ) + past_key_values = DynamicCache(config=self.config) if position_ids is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 diff --git a/src/transformers/models/mamba/configuration_mamba.py b/src/transformers/models/mamba/configuration_mamba.py index 63fec59c5c88..bc91025ead2e 100644 --- a/src/transformers/models/mamba/configuration_mamba.py +++ b/src/transformers/models/mamba/configuration_mamba.py @@ -98,5 +98,9 @@ def __post_init__(self, **kwargs): ) super().__post_init__(**kwargs) + @property + def layer_types(self): + return ["mamba"] * self.num_hidden_layers + __all__ = ["MambaConfig"] diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index 69002f50ab78..2e5695c1d4fa 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -22,7 +22,7 @@ from ... import initialization as init from ...activations import ACT2FN -from ...configuration_utils import PreTrainedConfig +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...integrations import lazy_load_kernel from ...modeling_layers import GradientCheckpointingLayer @@ -55,117 +55,6 @@ pscan = None -class MambaCache: - """ - Cache for mamba model which does not have attention mechanism and key value states. - - Arguments: - config (`PreTrainedConfig): - The configuration file defining the shape-related attributes required to initialize the static cache. - max_batch_size (`int`): - The maximum batch size with which the model will be used. Note that a new instance must be instantiated if - a smaller batch size is used. - dtype (`torch.dtype`, *optional*, defaults to `torch.float16`): - The default `dtype` to use when initializing the layer. 
- device (`torch.device` or `str`, *optional*): - The device on which the cache should be initialized. Should be the same as the layer. - - Example: - - ```python - >>> import torch - >>> from transformers import AutoTokenizer, MambaForCausalLM, MambaCache - - >>> model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf") - >>> tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf") - - >>> inputs = tokenizer(text="My name is Mamba", return_tensors="pt") - - >>> # Prepare a cache class and pass it to model's forward - >>> cache_params = MambaCache(config=model.config, max_batch_size=1, device=model.device, dtype=model.dtype) - >>> outputs = model(**inputs, cache_params=cache_params, use_cache=True) - >>> outputs.cache_params - ``` - """ - - is_compileable = True - - # TODO (joao): add layer_device_map arg and update code in `generate` accordingly - def __init__( - self, - config: PreTrainedConfig, - max_batch_size: int, - dtype: torch.dtype = torch.float16, - device: torch.device | str | None = None, - ): - self.max_batch_size = max_batch_size - self._dtype = dtype - self.intermediate_size = config.intermediate_size - self.ssm_state_size = config.state_size - self.conv_kernel_size = config.conv_kernel - self.has_previous_state = False - - self.conv_states: list[torch.Tensor] = [] - self.ssm_states: list[torch.Tensor] = [] - device = torch.device(device) if device is not None else None - for _ in range(config.num_hidden_layers): - conv_state: torch.Tensor = torch.zeros( - self.max_batch_size, - self.intermediate_size, - self.conv_kernel_size, - device=device, - dtype=self._dtype, - ) - ssm_state: torch.Tensor = torch.zeros( - self.max_batch_size, - self.intermediate_size, - self.ssm_state_size, - device=device, - dtype=self._dtype, - ) - - torch._dynamo.mark_static_address(conv_state) - torch._dynamo.mark_static_address(ssm_state) - self.conv_states.append(conv_state) - self.ssm_states.append(ssm_state) - - def update_conv_state( - self, layer_idx: int, new_conv_state: torch.Tensor, cache_init: bool = False - ) -> torch.Tensor: - # This `if` blocks is only reached in multigpu and if `layer_device_map` is not passed. It is used - # when the cache is initialized in the forward pass (e.g. Mamba) - if self.conv_states[layer_idx].device != new_conv_state.device: - self.conv_states[layer_idx] = self.conv_states[layer_idx].to(new_conv_state.device) - - # Technically, those update are not logically correct if the prefill is smaller than `conv_kernel_size`, - # as it will `roll` anyway in the first decoding step even though it should `roll` ONLY if the cache is already full. 
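For contrast, the prefill branch (`cache_init=True`, just below) is pure shape bookkeeping: the conv state is a fixed `[batch, channels, kernel_size]` window, so the prefill activations are left-padded to exactly `kernel_size` columns before being copied in. A sketch with illustrative shapes:

```python
import torch
import torch.nn.functional as F

kernel_size = 4
prefill = torch.randn(2, 8, 3)  # [batch, channels, seq_len], seq_len < kernel_size
conv_state = F.pad(prefill, (kernel_size - prefill.shape[-1], 0))  # zero-pad on the left
assert conv_state.shape == (2, 8, 4)
```

A prefill longer than `kernel_size` yields a negative pad amount, which `F.pad` treats as a left crop, matching the pad-or-truncate behavior the caches rely on.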
- # But since `conv_kernel_size=4` in practice, it's almost impossible to have a smaller prefill so it's mostly fine for now - if cache_init: - self.conv_states[layer_idx].copy_(new_conv_state) - else: - conv_state = self.conv_states[layer_idx].roll(shifts=-1, dims=-1) - conv_state[:, :, -1:] = new_conv_state - self.conv_states[layer_idx].copy_(conv_state) - - # If last layer is updated, set the flag - if layer_idx == len(self.conv_states) - 1: - self.has_previous_state = True - - return self.conv_states[layer_idx] - - def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor): - self.ssm_states[layer_idx].zero_() - self.ssm_states[layer_idx] += new_ssm_state.to(self.ssm_states[layer_idx].device) - return self.ssm_states[layer_idx] - - def reset(self): - self.has_previous_state = False - for layer_idx in range(len(self.conv_states)): - # In-place ops prevent breaking the static address - self.conv_states[layer_idx].zero_() - self.ssm_states[layer_idx].zero_() - - class MambaMixer(nn.Module): """ Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`. @@ -278,7 +167,7 @@ def warn_slow_implementation(self): def cuda_kernels_forward( self, hidden_states: torch.Tensor, - cache_params: MambaCache | None = None, + cache_params: Cache | None = None, attention_mask: torch.LongTensor | None = None, ): # 1. Gated MLP's linear projection @@ -307,14 +196,14 @@ def cuda_kernels_forward( if attention_mask is not None: hidden_states = hidden_states * attention_mask.unsqueeze(1) - is_decoding = cache_params is not None and cache_params.has_previous_state + is_decoding = cache_params is not None and cache_params.has_previous_state(self.layer_idx) # 2. Convolution sequence transformation conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)) if is_decoding: hidden_states = causal_conv1d_update( hidden_states.squeeze(-1), - cache_params.conv_states[self.layer_idx], + cache_params.layers[self.layer_idx].conv_states, conv_weights, self.conv1d.bias, self.activation, @@ -325,7 +214,7 @@ def cuda_kernels_forward( conv_states = nn.functional.pad( hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0) ) - cache_params.update_conv_state(self.layer_idx, conv_states, cache_init=True) + cache_params.update_conv_state(conv_states, self.layer_idx) hidden_states = causal_conv1d_fn( hidden_states, conv_weights, self.conv1d.bias, activation=self.activation ) @@ -346,7 +235,7 @@ def cuda_kernels_forward( time_proj_bias = self.dt_proj.bias.float() if hasattr(self.dt_proj, "bias") else None if is_decoding: scan_outputs = selective_state_update( - cache_params.ssm_states[self.layer_idx], + cache_params.layers[self.layer_idx].recurrent_states, hidden_states[..., 0], discrete_time_step[..., 0], A, @@ -371,14 +260,14 @@ def cuda_kernels_forward( return_last_state=True, ) if ssm_state is not None and cache_params is not None: - cache_params.update_ssm_state(self.layer_idx, ssm_state) + cache_params.update_recurrent_state(ssm_state, self.layer_idx) # 4. Final linear projection contextualized_states = self.out_proj(scan_outputs.transpose(1, 2)) return contextualized_states # fmt: off - def slow_forward(self, input_states, cache_params: MambaCache | None=None, attention_mask: torch.LongTensor | None = None): + def slow_forward(self, input_states, cache_params: Cache | None=None, attention_mask: torch.LongTensor | None = None): batch_size, seq_len, _ = input_states.shape dtype = input_states.dtype # 1. 
Gated MLP's linear projection @@ -388,30 +277,32 @@ def slow_forward(self, input_states, cache_params: MambaCache | None=None, atten if attention_mask is not None: hidden_states = hidden_states * attention_mask.unsqueeze(1) + if cache_params is not None and cache_params.has_previous_state(self.layer_idx): + ssm_state = cache_params.layers[self.layer_idx].recurrent_states.clone() + else: + ssm_state = torch.zeros( + (batch_size, self.intermediate_size, self.ssm_state_size), + device=hidden_states.device, dtype=dtype + ) + # 2. Convolution sequence transformation if cache_params is not None: - ssm_state = cache_params.ssm_states[self.layer_idx].clone() - ssm_state = ssm_state.to(hidden_states.device) - if not cache_params.has_previous_state: + if not cache_params.has_previous_state(self.layer_idx): conv_state = nn.functional.pad( hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0) ) - cache_params.update_conv_state(self.layer_idx, conv_state, cache_init=True) + cache_params.update_conv_state(conv_state, self.layer_idx) hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # [batch, intermediate_size, seq_len] else: - conv_state = cache_params.update_conv_state(self.layer_idx, hidden_states, cache_init=False) + conv_state = cache_params.update_conv_state(hidden_states, self.layer_idx) conv_state = conv_state.to(self.conv1d.weight.device) hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1) if self.use_conv_bias: hidden_states += self.conv1d.bias hidden_states = self.act(hidden_states).to(dtype).unsqueeze(-1) # [batch, intermediate_size, 1] : decoding else: - ssm_state = torch.zeros( - (batch_size, self.intermediate_size, self.ssm_state_size), - device=hidden_states.device, dtype=dtype - ) hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # [batch, intermediate_size, seq_len] if attention_mask is not None: @@ -465,7 +356,7 @@ def combine_fn(left, right): scan_output = (scan_output * self.act(gate)) if cache_params is not None: - cache_params.ssm_states[self.layer_idx].copy_(ssm_state) + cache_params.update_recurrent_state(ssm_state, self.layer_idx) # 4. Final linear projection contextualized_states = self.out_proj(scan_output.transpose(1, 2)) # [batch, seq_len, hidden_size] @@ -475,7 +366,7 @@ def combine_fn(left, right): def forward( self, hidden_states, - cache_params: MambaCache | None = None, + cache_params: Cache | None = None, attention_mask: torch.LongTensor | None = None, **kwargs, ): @@ -519,7 +410,7 @@ def __init__(self, config, layer_idx): def forward( self, hidden_states, - cache_params: MambaCache | None = None, + cache_params: Cache | None = None, attention_mask: torch.LongTensor | None = None, **kwargs, ): @@ -587,7 +478,7 @@ def _init_weights(self, module): ) class MambaOutput(ModelOutput): r""" - cache_params (`MambaCache`): + cache_params (`Cache`): The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to avoid providing the old `input_ids`. @@ -595,7 +486,7 @@ class MambaOutput(ModelOutput): """ last_hidden_state: torch.FloatTensor | None = None - cache_params: MambaCache | None = None + cache_params: Cache | None = None hidden_states: tuple[torch.FloatTensor] | None = None @@ -611,7 +502,7 @@ class MambaCausalLMOutput(ModelOutput): Language modeling loss (for next-token prediction). 
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - cache_params (`MambaCache`): + cache_params (`Cache`): The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to avoid providing the old `input_ids`. @@ -620,7 +511,7 @@ class MambaCausalLMOutput(ModelOutput): loss: torch.FloatTensor | None = None logits: torch.FloatTensor | None = None - cache_params: MambaCache | None = None + cache_params: Cache | None = None hidden_states: tuple[torch.FloatTensor] | None = None @@ -655,7 +546,7 @@ def forward( self, input_ids: torch.LongTensor | None = None, inputs_embeds: torch.LongTensor | None = None, - cache_params: MambaCache | None = None, + cache_params: Cache | None = None, use_cache: bool | None = None, output_hidden_states: bool | None = None, return_dict: bool | None = None, @@ -663,7 +554,7 @@ def forward( **kwargs, ) -> tuple | MambaOutput: r""" - cache_params (`MambaCache`, *optional*): + cache_params (`Cache`, *optional*): If passed along, the model uses the previous state in all the blocks (which will give the output for the `input_ids` provided as if the model add `state_input_ids + input_ids` as context). use_cache (`bool`, *optional*): @@ -685,9 +576,7 @@ def forward( use_cache = False if use_cache and cache_params is None: - cache_params = MambaCache( - self.config, inputs_embeds.size(0), device=inputs_embeds.device, dtype=inputs_embeds.dtype - ) + cache_params = DynamicCache(config=self.config) hidden_states = inputs_embeds all_hidden_states = () if output_hidden_states else None @@ -743,12 +632,11 @@ def prepare_inputs_for_generation( input_ids, inputs_embeds=None, use_cache=None, - cache_params: MambaCache | None = None, + cache_params: Cache | None = None, attention_mask: torch.LongTensor | None = None, is_first_iteration: bool | None = False, **kwargs, ): - # Overwritten -- has custom cache class `MambaCache` model_inputs = super().prepare_inputs_for_generation( input_ids, inputs_embeds=inputs_embeds, @@ -759,15 +647,7 @@ def prepare_inputs_for_generation( **kwargs, ) - if use_cache and cache_params is None: - if inputs_embeds is not None: - max_batch_size = inputs_embeds.size(0) - else: - max_batch_size = input_ids.size(0) - model_inputs["cache_params"] = MambaCache( - self.backbone.config, max_batch_size, device=self.device, dtype=self.dtype - ) - elif use_cache and not is_first_iteration: + if use_cache and not is_first_iteration: model_inputs["attention_mask"] = None return model_inputs @@ -778,7 +658,7 @@ def forward( input_ids: torch.LongTensor | None = None, attention_mask: torch.LongTensor | None = None, inputs_embeds: torch.FloatTensor | None = None, - cache_params: MambaCache | None = None, + cache_params: Cache | None = None, labels: torch.LongTensor | None = None, output_hidden_states: bool | None = None, return_dict: bool | None = None, @@ -787,7 +667,7 @@ def forward( **kwargs, # for now we need this for generation ) -> tuple | MambaCausalLMOutput: r""" - cache_params (`MambaCache`, *optional*): + cache_params (`Cache`, *optional*): If passed along, the model uses the previous state in all the blocks (which will give the output for the `input_ids` provided as if the model add `state_input_ids + input_ids` as context). 
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -837,4 +717,4 @@ def forward( ) -__all__ = ["MambaForCausalLM", "MambaModel", "MambaPreTrainedModel", "MambaCache"] +__all__ = ["MambaForCausalLM", "MambaModel", "MambaPreTrainedModel"] diff --git a/src/transformers/models/mamba2/configuration_mamba2.py b/src/transformers/models/mamba2/configuration_mamba2.py index 31af61ab99d8..d60fa5776422 100644 --- a/src/transformers/models/mamba2/configuration_mamba2.py +++ b/src/transformers/models/mamba2/configuration_mamba2.py @@ -103,5 +103,9 @@ def validate_architecture(self): f"({self.num_heads * self.head_dim})." ) + @property + def layer_types(self): + return ["mamba"] * self.num_hidden_layers + __all__ = ["Mamba2Config"] diff --git a/src/transformers/models/mamba2/modeling_mamba2.py b/src/transformers/models/mamba2/modeling_mamba2.py index 3a70ec35dfd9..d0a47ef9dc63 100644 --- a/src/transformers/models/mamba2/modeling_mamba2.py +++ b/src/transformers/models/mamba2/modeling_mamba2.py @@ -21,6 +21,7 @@ from ... import initialization as init from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...integrations import lazy_load_kernel from ...modeling_layers import GradientCheckpointingLayer @@ -99,94 +100,6 @@ def apply_mask_to_padding_states(hidden_states, attention_mask): return hidden_states -class Mamba2Cache: - """ - Arguments: - config: Mamba2Config - batch_size: int - dtype: torch.dtype - device: torch.device - - Attributes: - dtype: (`torch.dtype`): - The default `dtype` used to initializing the cache. - conv_kernel_size: (`int`): - Model's convolution kernel size taken from config. - n_groups: (`int`): - Model's number of groups taken from the config - similar to tensor parallel in Transformer. - state_size: (`int`): - Model's SSM state size taken from config. - num_heads: (`int`): - The number of heads used in the linear attention / SSM. - head_dim: (`int`): - The respective dimension of the heads used in the linear attention / SSM. - intermediate_size: (`int`): - Model's intermediate_size based on (expand * hidden_dim) from config. - conv_states: (`torch.Tensor`): - A tensor of shape `[num_layers, batch_size, conv_kernel_size, intermediate_size + 2 * n_groups * state_size]` that holds convolutional states. - ssm_states: (`torch.Tensor`): - A tensor of shape `[num_layers, batch_size, num_heads, head_dim, state_size]` that holds ssm states. 
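This stacked layout is exactly what the migration trades away: one `[num_layers, ...]` tensor indexed by layer becomes one small cache object per layer, addressed as `cache_params.layers[layer_idx].conv_states` / `.recurrent_states` in the hunks below. A toy comparison (the `_Layer` class is a stand-in, not the real per-layer cache):

```python
import torch

num_layers, batch, heads, head_dim, state = 2, 1, 4, 8, 16

# Old: every layer's SSM state lives in one stacked tensor, indexed by layer.
stacked_ssm = torch.zeros(num_layers, batch, heads, head_dim, state)
old_layer0 = stacked_ssm[0]

# New: each layer owns its states on a dedicated object.
class _Layer:
    def __init__(self):
        self.recurrent_states = torch.zeros(batch, heads, head_dim, state)

layers = [_Layer() for _ in range(num_layers)]
assert layers[0].recurrent_states.shape == old_layer0.shape
```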
- """ - - def __init__( - self, config: Mamba2Config, batch_size: int, dtype: torch.dtype = torch.float16, device: str | None = None - ): - self.dtype = dtype - self.conv_kernel_size = config.conv_kernel - self.n_groups = config.n_groups - self.state_size = config.state_size - self.num_heads = config.num_heads - self.head_dim = config.head_dim - self.intermediate_size = int(config.expand * config.hidden_size) - - self.conv_states = torch.zeros( - config.num_hidden_layers, - batch_size, - self.intermediate_size + 2 * self.n_groups * self.state_size, - self.conv_kernel_size, - device=device, - dtype=dtype, - ) - self.ssm_states = torch.zeros( - config.num_hidden_layers, - batch_size, - self.num_heads, - self.head_dim, - self.state_size, - device=device, - dtype=dtype, - ) - self.has_previous_state = False - - def update_conv_state( - self, layer_idx: int, new_conv_state: torch.Tensor, cache_init: bool = False - ) -> torch.Tensor: - # Technically, those update are not logically correct if the prefill is smaller than `conv_kernel_size`, - # as it will `roll` anyway in the first decoding step even though it should `roll` ONLY if the cache is already full. - # But since `conv_kernel_size=4` in practice, it's almost impossible to have a smaller prefill so it's mostly fine for now - if cache_init: - self.conv_states[layer_idx] = new_conv_state.to(self.conv_states.device) - else: - self.conv_states[layer_idx] = self.conv_states[layer_idx].roll(shifts=-1, dims=-1) - self.conv_states[layer_idx][:, :, -1] = new_conv_state[:, 0, :].to(self.conv_states.device) - - return self.conv_states[layer_idx] - - def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor): - self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states.device) - - # If last layer is updated, set the flag - if layer_idx == self.conv_states.shape[0] - 1: - self.has_previous_state = True - - return self.ssm_states[layer_idx] - - def reset(self): - self.has_previous_state = False - self.conv_states.zero_() - self.ssm_states.zero_() - - class MambaRMSNormGated(torch.nn.Module): def __init__(self, hidden_size, eps=1e-6): super().__init__() @@ -325,7 +238,7 @@ def init_mamba2_weights(self): def cuda_kernels_forward( self, hidden_states: torch.Tensor, - cache_params: Mamba2Cache | None = None, + cache_params: Cache | None = None, attention_mask: torch.Tensor | None = None, ): # 1. Gated MLP's linear projection @@ -343,7 +256,7 @@ def cuda_kernels_forward( ) // 2 # Single step calculations via cache - if cache_params is not None and cache_params.has_previous_state: + if cache_params is not None and cache_params.has_previous_state(self.layer_idx): _, _, gate, hidden_states_B_C, dt = projected_states.squeeze(1).split( [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1 ) @@ -351,7 +264,7 @@ def cuda_kernels_forward( # 2. 
Convolution sequence transformation hidden_states_B_C = causal_conv1d_update( hidden_states_B_C, - cache_params.conv_states[self.layer_idx], + cache_params.layers[self.layer_idx].conv_states, self.conv1d.weight.squeeze(1), self.conv1d.bias, self.activation, @@ -373,7 +286,7 @@ def cuda_kernels_forward( C = C.view(batch_size, self.n_groups, C.shape[1] // self.n_groups) hidden_states_reshaped = hidden_states.view(batch_size, self.num_heads, self.head_dim) hidden_states = selective_state_update( - cache_params.ssm_states[self.layer_idx], + cache_params.layers[self.layer_idx].recurrent_states, hidden_states_reshaped, dt, A, @@ -429,11 +342,9 @@ def cuda_kernels_forward( hidden_states_B_C_transposed = hidden_states_B_C.transpose(1, 2) conv_states = nn.functional.pad( hidden_states_B_C_transposed, - (cache_params.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0), - ) - cache_params.update_conv_state( - layer_idx=self.layer_idx, new_conv_state=conv_states, cache_init=True + (self.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0), ) + conv_states = cache_params.update_conv_state(conv_states, layer_idx=self.layer_idx) if self.activation not in ["silu", "swish"]: hidden_states_B_C = self.act( @@ -473,7 +384,7 @@ def cuda_kernels_forward( # Init cache if ssm_state is not None and cache_params is not None: - cache_params.update_ssm_state(layer_idx=self.layer_idx, new_ssm_state=ssm_state) + cache_params.update_recurrent_state(ssm_state, layer_idx=self.layer_idx) scan_output = scan_output.view(batch_size, seq_len, -1) # Multiply "gate" branch and apply extra normalization layer @@ -487,8 +398,8 @@ def cuda_kernels_forward( def torch_forward( self, hidden_states: torch.Tensor, - cache_params: Mamba2Cache | None=None, - attention_mask: torch.Tensor | None=None + cache_params: Cache | None = None, + attention_mask: torch.Tensor | None = None ): batch_size, seq_len, _ = hidden_states.shape dtype = hidden_states.dtype @@ -500,13 +411,13 @@ def torch_forward( _, _, gate, hidden_states_B_C, dt = projected_states.split( [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1 ) + hidden_states_B_C = hidden_states_B_C.transpose(1,2) - # 2. Convolution sequence transformation - if cache_params is not None and cache_params.has_previous_state: - cache_params.update_conv_state(layer_idx=self.layer_idx, new_conv_state=hidden_states_B_C, cache_init=False) + is_decoding = cache_params is not None and cache_params.has_previous_state(self.layer_idx) - # We need to guarantee that anything regarding the cache is on the same device - conv_states = cache_params.conv_states[self.layer_idx].to(device=self.conv1d.weight.device) + # 2. 
Convolution sequence transformation + if is_decoding: + conv_states = cache_params.update_conv_state(hidden_states_B_C, layer_idx=self.layer_idx) hidden_states_B_C = torch.sum( conv_states * self.conv1d.weight.squeeze(1), dim=-1 @@ -517,13 +428,12 @@ def torch_forward( else: # Init cache if cache_params is not None: - hidden_states_B_C_transposed = hidden_states_B_C.transpose(1, 2) conv_states = nn.functional.pad( - hidden_states_B_C_transposed, (cache_params.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0) + hidden_states_B_C, (self.conv_kernel_size - hidden_states_B_C.shape[-1], 0) ) - cache_params.update_conv_state(layer_idx=self.layer_idx, new_conv_state=conv_states, cache_init=True) + cache_params.update_conv_state(conv_states, layer_idx=self.layer_idx) - hidden_states_B_C = self.act(self.conv1d(hidden_states_B_C.transpose(1, 2))[..., :seq_len].transpose(1, 2)) + hidden_states_B_C = self.act(self.conv1d(hidden_states_B_C)[..., :seq_len].transpose(1, 2)) hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask) hidden_states, B, C = torch.split( @@ -534,9 +444,9 @@ def torch_forward( # 3. SSM transformation A = -torch.exp(self.A_log.float()) # [num_heads] - if cache_params is not None and cache_params.has_previous_state: + if is_decoding: # We need to guarantee that anything regarding the cache is on the same device - cache_device = cache_params.ssm_states.device + cache_device = cache_params.layers[self.layer_idx].device # Note: there is no need to pad parameter matrices here, as there is just one new token # for batched generation @@ -566,10 +476,8 @@ def torch_forward( dBx = (dB * hidden_states[..., None]).to(device=cache_device) # State calculation - cache_params.update_ssm_state( - layer_idx=self.layer_idx, - new_ssm_state=cache_params.ssm_states[self.layer_idx] * dA + dBx - ) + ssm_states = cache_params.layers[self.layer_idx].recurrent_states * dA + dBx + ssm_states = cache_params.update_recurrent_state(ssm_states, layer_idx=self.layer_idx) # Subsequent output # [bsz, n_groups * state_size] -> [bsz, num_heads, state_size] @@ -578,8 +486,8 @@ def torch_forward( C = C.reshape(batch_size, -1, C.shape[-1]) # [bsz, num_heads, head_dim] - ssm_states = cache_params.ssm_states[self.layer_idx].to(device=C.device, dtype=C.dtype) # Shape: [b, h, d, n] # Reshape ssm_states to merge the first two dimensions + ssm_states = ssm_states.to(device=C.device, dtype=C.dtype) ssm_states_reshaped = ssm_states.view(batch_size * self.num_heads, self.head_dim, self.ssm_state_size) # Shape: [b*h, d, n] C_reshaped = C.view(batch_size * self.num_heads, self.ssm_state_size, 1) # Shape: [b*h, n, 1] y = torch.bmm(ssm_states_reshaped, C_reshaped) @@ -639,10 +547,7 @@ def torch_forward( # 3. 
Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries # (middle term of factorization of off-diag blocks; A terms) - if cache_params is not None and cache_params.has_previous_state: - previous_states = cache_params.ssm_states[self.layer_idx][:, None, ...].to(device=states.device) - else: - previous_states = torch.zeros_like(states[:, :1]) + previous_states = torch.zeros_like(states[:, :1]) states = torch.cat([previous_states, states], dim=1) decay_chunk = torch.exp(segment_sum(nn.functional.pad(A_cumsum[:, :, :, -1], (1, 0)))) decay_chunk = decay_chunk.transpose(1, 3) @@ -669,7 +574,7 @@ def torch_forward( # Init cache if ssm_state is not None and cache_params is not None: - cache_params.update_ssm_state(layer_idx=self.layer_idx, new_ssm_state=ssm_state) + cache_params.update_recurrent_state(ssm_state, layer_idx=self.layer_idx) scan_output = self.norm(y, gate) @@ -683,7 +588,7 @@ def torch_forward( def forward( self, hidden_states, - cache_params: Mamba2Cache | None = None, + cache_params: Cache | None = None, attention_mask: torch.Tensor | None = None, **kwargs, ): @@ -721,7 +626,7 @@ def __init__(self, config, layer_idx): def forward( self, hidden_states, - cache_params: Mamba2Cache | None = None, + cache_params: Cache | None = None, attention_mask: torch.Tensor | None = None, **kwargs, ): @@ -790,7 +695,7 @@ def _init_weights(self, module): # Copied from transformers.models.mamba.modeling_mamba.MambaOutput with MAMBA->MAMBA2,Mamba->Mamba2 class Mamba2Output(ModelOutput): r""" - cache_params (`Mamba2Cache`): + cache_params (`Cache`): The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to avoid providing the old `input_ids`. @@ -798,7 +703,7 @@ class Mamba2Output(ModelOutput): """ last_hidden_state: torch.FloatTensor | None = None - cache_params: Mamba2Cache | None = None + cache_params: Cache | None = None hidden_states: tuple[torch.FloatTensor] | None = None @@ -815,7 +720,7 @@ class Mamba2CausalLMOutput(ModelOutput): Language modeling loss (for next-token prediction). logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - cache_params (`Mamba2Cache`): + cache_params (`Cache`): The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to avoid providing the old `input_ids`. @@ -824,7 +729,7 @@ class Mamba2CausalLMOutput(ModelOutput): loss: torch.FloatTensor | None = None logits: torch.FloatTensor | None = None - cache_params: Mamba2Cache | None = None + cache_params: Cache | None = None hidden_states: tuple[torch.FloatTensor] | None = None @@ -859,7 +764,7 @@ def forward( self, input_ids: torch.LongTensor | None = None, inputs_embeds: torch.LongTensor | None = None, - cache_params: Mamba2Cache | None = None, + cache_params: Cache | None = None, use_cache: bool | None = None, output_hidden_states: bool | None = None, return_dict: bool | None = None, @@ -867,7 +772,7 @@ def forward( **kwargs, ) -> tuple | Mamba2Output: r""" - cache_params (`Mamba2Cache`, *optional*): + cache_params (`Cache`, *optional*): If passed along, the model uses the previous state in all the blocks (which will give the output for the `input_ids` provided as if the model add `state_input_ids + input_ids` as context). 
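With `Mamba2Cache` gone, a caller who previously constructed the cache by hand would now build a `DynamicCache` from the config alone, mirroring the auto-init path this diff adds to the model's forward. A hedged usage sketch (the checkpoint name is illustrative, and I am assuming `DynamicCache(config=...)` selects the per-layer cache type from the newly exposed `config.layer_types`):

```python
from transformers import AutoTokenizer, DynamicCache, Mamba2ForCausalLM

model = Mamba2ForCausalLM.from_pretrained("mistralai/Mamba-Codestral-7B-v0.1")
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mamba-Codestral-7B-v0.1")

inputs = tokenizer("My name is", return_tensors="pt")
cache_params = DynamicCache(config=model.config)  # replaces Mamba2Cache(config, batch_size, ...)
outputs = model(**inputs, cache_params=cache_params, use_cache=True)
# outputs.cache_params carries the conv/SSM states into the next forward call
```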
use_cache (`bool`, *optional*): @@ -889,9 +794,7 @@ def forward( use_cache = False if use_cache and cache_params is None: - cache_params = Mamba2Cache( - self.config, inputs_embeds.size(0), device=inputs_embeds.device, dtype=inputs_embeds.dtype - ) + cache_params = DynamicCache(config=self.config) hidden_states = inputs_embeds all_hidden_states = () if output_hidden_states else None @@ -947,13 +850,11 @@ def prepare_inputs_for_generation( input_ids, inputs_embeds=None, use_cache=None, - cache_params: Mamba2Cache | None = None, + cache_params: Cache | None = None, attention_mask: torch.Tensor | None = None, is_first_iteration: bool | None = False, **kwargs, ): - # Overwritten -- has custom cache class `Mamba2Cache` - model_inputs = super().prepare_inputs_for_generation( input_ids, inputs_embeds=inputs_embeds, @@ -964,15 +865,7 @@ def prepare_inputs_for_generation( **kwargs, ) - if use_cache and cache_params is None: - if inputs_embeds is not None: - max_batch_size = inputs_embeds.size(0) - else: - max_batch_size = input_ids.size(0) - model_inputs["cache_params"] = Mamba2Cache( - self.backbone.config, max_batch_size, device=self.device, dtype=self.dtype - ) - elif use_cache and not is_first_iteration: + if use_cache and not is_first_iteration: model_inputs["attention_mask"] = None return model_inputs @@ -982,7 +875,7 @@ def forward( self, input_ids: torch.LongTensor | None = None, inputs_embeds: torch.FloatTensor | None = None, - cache_params: Mamba2Cache | None = None, + cache_params: Cache | None = None, labels: torch.LongTensor | None = None, output_hidden_states: bool | None = None, return_dict: bool | None = None, @@ -992,7 +885,7 @@ def forward( **kwargs, # for now we need this for generation and loss_function ) -> tuple | Mamba2CausalLMOutput: r""" - cache_params (`Mamba2Cache`, *optional*): + cache_params (`Cache`, *optional*): If passed along, the model uses the previous state in all the blocks (which will give the output for the `input_ids` provided as if the model add `state_input_ids + input_ids` as context). labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): diff --git a/src/transformers/models/musicflamingo/configuration_musicflamingo.py b/src/transformers/models/musicflamingo/configuration_musicflamingo.py index e7e7a15dcde1..562a3bf13071 100644 --- a/src/transformers/models/musicflamingo/configuration_musicflamingo.py +++ b/src/transformers/models/musicflamingo/configuration_musicflamingo.py @@ -32,7 +32,7 @@ class MusicFlamingoConfig(PreTrainedConfig): r""" audio_bos_token_id (`int`, *optional*, defaults to 151670): - The beginning-of-audio token index used to mark the start of audio spans. + The beginning-of-audio token index used to mark the start of audio spans. audio_eos_token_id (`int`, *optional*, defaults to 151671): The end-of-audio token index used to mark the end of audio spans. audio_frame_step (`float`, *optional*, defaults to 0.01): diff --git a/src/transformers/models/musicflamingo/modular_musicflamingo.py b/src/transformers/models/musicflamingo/modular_musicflamingo.py index 5da937bab052..2a9735f78ce0 100644 --- a/src/transformers/models/musicflamingo/modular_musicflamingo.py +++ b/src/transformers/models/musicflamingo/modular_musicflamingo.py @@ -43,7 +43,7 @@ class MusicFlamingoConfig(AudioFlamingo3Config): r""" audio_bos_token_id (`int`, *optional*, defaults to 151670): - The beginning-of-audio token index used to mark the start of audio spans. + The beginning-of-audio token index used to mark the start of audio spans. 
audio_eos_token_id (`int`, *optional*, defaults to 151671): The end-of-audio token index used to mark the end of audio spans. audio_frame_step (`float`, *optional*, defaults to 0.01): diff --git a/src/transformers/models/nemotron_h/configuration_nemotron_h.py b/src/transformers/models/nemotron_h/configuration_nemotron_h.py index efa95cf00b61..df613f0e5bea 100644 --- a/src/transformers/models/nemotron_h/configuration_nemotron_h.py +++ b/src/transformers/models/nemotron_h/configuration_nemotron_h.py @@ -81,6 +81,7 @@ class NemotronHConfig(PreTrainedConfig): """ model_type = "nemotron_h" + attribute_map = {"layer_types": "layers_block_type"} keys_to_ignore_at_inference = ["past_key_values"] vocab_size: int = 131072 diff --git a/src/transformers/models/nemotron_h/modeling_nemotron_h.py b/src/transformers/models/nemotron_h/modeling_nemotron_h.py index 4455c9ba49b7..9e264e5cfdcc 100644 --- a/src/transformers/models/nemotron_h/modeling_nemotron_h.py +++ b/src/transformers/models/nemotron_h/modeling_nemotron_h.py @@ -22,7 +22,6 @@ import math from collections.abc import Callable -from typing import Any import torch import torch.nn.functional as F @@ -30,6 +29,7 @@ from ... import initialization as init from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...integrations import ( lazy_load_kernel, @@ -54,119 +54,6 @@ logger = logging.get_logger(__name__) -class NemotronHHybridDynamicCache: - """ - A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache - (which has a constant shape regardless of seq_len). - - This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states` - and `ssm_states` for mamba cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor - For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`, - while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors). - For mamba layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors), - while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`, - and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`. 
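Note how the one-line `attribute_map` addition to `NemotronHConfig` above lets the generic cache machinery read `config.layer_types` without renaming the existing `layers_block_type` field. A toy model of the aliasing (simplified; the real lookup lives in `PreTrainedConfig`):

```python
class TinyConfig:
    attribute_map = {"layer_types": "layers_block_type"}

    def __init__(self):
        self.layers_block_type = ["mamba", "mamba", "attention"]

    def __getattr__(self, name):  # only invoked when normal lookup fails
        if name in self.attribute_map:
            return getattr(self, self.attribute_map[name])
        raise AttributeError(name)

cfg = TinyConfig()
assert cfg.layer_types == ["mamba", "mamba", "attention"]
```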
- """ - - is_compileable = False - - def __init__( - self, config: NemotronHConfig, batch_size: int, dtype: torch.dtype = torch.float16, device: str | None = None - ): - self.dtype = dtype - self.layers_block_type = config.layers_block_type - self.has_previous_state = False - self.intermediate_size = int(config.mamba_num_heads * config.mamba_head_dim) - self.ssm_state_size = config.ssm_state_size - self.conv_kernel_size = config.conv_kernel - self.n_mamba_heads = config.mamba_num_heads - self.transformer_layers = [] - self._modules = {} - self._parameters = {} - self._buffers = {} - self.conv_states = {} - self.ssm_states = {} - for i in range(config.num_hidden_layers): - if self.layers_block_type[i] == "mamba": - # Only allocate mamba cache for mamba layers - self.conv_states[i] = torch.zeros( - batch_size, - self.intermediate_size + 2 * config.n_groups * self.ssm_state_size, - self.conv_kernel_size, - device=device, - dtype=dtype, - ) - self.ssm_states[i] = torch.zeros( - batch_size, - self.n_mamba_heads, - config.mamba_head_dim, - self.ssm_state_size, - device=device, - dtype=dtype, - ) - else: - # For attention and moe layers, use empty tensors - self.conv_states[i] = torch.tensor([[]] * batch_size, device=device) - self.ssm_states[i] = torch.tensor([[]] * batch_size, device=device) - - if self.layers_block_type[i] == "attention": - self.transformer_layers.append(i) - self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] - self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] - - def __len__(self): - return len(self.key_cache) - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: dict[str, Any] | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - # Update the cache - if self.key_cache[layer_idx].shape[-1] == 0: - self.key_cache[layer_idx] = key_states - self.value_cache[layer_idx] = value_states - else: - self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2) - self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2) - - return self.key_cache[layer_idx], self.value_cache[layer_idx] - - def reorder_cache(self, beam_idx: torch.LongTensor): - """Reorders the cache for beam search, given the selected beam indices.""" - if self.get_seq_length() > 0: - for layer_idx in range(len(self.key_cache)): - device = self.key_cache[layer_idx].device - self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) - device = self.value_cache[layer_idx].device - self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) - - device = self.conv_states[layer_idx].device - self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx.to(device)) - device = self.ssm_states[layer_idx].device - self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device)) - - def get_seq_length(self, layer_idx: int | None = 0) -> int: - """Returns the sequence length of the cached states. 
A layer index can be optionally passed.""" - # take any layer that contains cache and not empty tensor - layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx - if len(self.key_cache) <= layer_idx or self.key_cache[layer_idx].numel() == 0: - return 0 - return self.key_cache[layer_idx].shape[-2] - - def get_mask_sizes(self, query_length: int, layer_idx: int) -> tuple[int, int]: - """Return the length and offset of the cache, used to generate the mask""" - kv_offset = 0 - kv_length = self.get_seq_length(layer_idx) + query_length - return kv_length, kv_offset - - def reset(self): - self.conv_states.zero_() - self.ssm_states.zero_() - - # Helper methods for segment sum computation @@ -329,7 +216,7 @@ def __init__(self, config: NemotronHConfig, layer_idx: int | None = None): def cuda_kernels_forward( self, hidden_states: torch.Tensor, - cache_params: NemotronHHybridDynamicCache | None = None, + cache_params: Cache | None = None, attention_mask: torch.Tensor | None = None, ): # set up dimensions for reshapes later @@ -339,7 +226,7 @@ def cuda_kernels_forward( d_to_remove = 2 * self.intermediate_size + 2 * self.n_groups * self.ssm_state_size + self.num_heads # getting projected states from cache if it exists - if cache_params is not None and cache_params.has_previous_state: + if cache_params is not None and cache_params.has_previous_state(self.layer_idx): in_projected_states = self.in_proj(hidden_states.squeeze(1)) # (B 2D) d_mlp = (in_projected_states.shape[-1] - d_to_remove) // 2 split_projection_dim = [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads] @@ -347,7 +234,7 @@ def cuda_kernels_forward( hidden_states_B_C = causal_conv1d_update( hidden_states_B_C, - cache_params.conv_states[self.layer_idx], + cache_params.layers[self.layer_idx].conv_states, self.conv1d.weight.squeeze(1), self.conv1d.bias, self.activation, @@ -368,7 +255,7 @@ def cuda_kernels_forward( C = C.view(batch_size, self.n_groups, C.shape[1] // self.n_groups) hidden_states_reshaped = hidden_states.view(batch_size, self.num_heads, self.head_dim) hidden_states = selective_state_update( - cache_params.ssm_states[self.layer_idx], + cache_params.layers[self.layer_idx].recurrent_states, hidden_states_reshaped, dt, A, @@ -432,7 +319,7 @@ def cuda_kernels_forward( conv_state = nn.functional.pad( hidden_states_B_C_t, (self.conv_kernel_size - hidden_states_B_C_t.shape[-1], 0) ) - cache_params.conv_states[self.layer_idx].copy_(conv_state) + conv_state = cache_params.update_conv_state(conv_state, self.layer_idx) if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]: hidden_states_B_C = self.act( self.conv1d(hidden_states_B_C.transpose(1, 2)).transpose(1, 2)[:, :seq_len] @@ -469,7 +356,7 @@ def cuda_kernels_forward( **dt_limit_kwargs, ) if ssm_state is not None and cache_params is not None: - cache_params.ssm_states[self.layer_idx].copy_(ssm_state) + cache_params.update_recurrent_state(ssm_state, self.layer_idx) scan_output = scan_output.view(batch_size, seq_len, -1) # Multiply "gate" branch and apply extra normalization layer scan_output = self.norm(scan_output, gate) @@ -477,12 +364,12 @@ def cuda_kernels_forward( return out # fmt: off - def torch_forward(self, input_states, cache_params: NemotronHHybridDynamicCache | None=None, attention_mask: torch.Tensor | None=None): + def torch_forward(self, input_states, cache_params: Cache | None=None, attention_mask: torch.Tensor | None = None): batch_size, seq_len, _ = input_states.shape dtype = input_states.dtype # 
Gated MLP's linear projection - if cache_params is not None and cache_params.has_previous_state: - projected_states = self.in_proj(input_states.squeeze(1)) + if cache_params is not None and cache_params.has_previous_state(self.layer_idx): + projected_states = self.in_proj(input_states) else: if attention_mask is not None: # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66 @@ -492,43 +379,34 @@ def torch_forward(self, input_states, cache_params: NemotronHHybridDynamicCache _, _, gate, hidden_states, dt = projected_states.split( [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1 ) + hidden_states = hidden_states.transpose(1, 2) + + use_precomputed_state = cache_params is not None and cache_params.has_previous_state(self.layer_idx) # Convolution sequence transformation - if cache_params is not None: - ssm_state = cache_params.ssm_states[self.layer_idx].clone() - ssm_state = ssm_state.to(hidden_states.device) - if cache_params.has_previous_state: - gate = gate.unsqueeze(1) - conv_state = cache_params.conv_states[self.layer_idx] # [batch, intermediate_size, conv_kernel_size] - conv_state = torch.roll(conv_state, shifts=-1, dims=-1) - # handle batched generation - states are copied through - conv_state[:, :, -1] = hidden_states[:, 0, :] if hidden_states.ndim == 3 else hidden_states - cache_params.conv_states[self.layer_idx].copy_(conv_state) - hidden_states = torch.sum(conv_state.to(projected_states.device) * self.conv1d.weight[:, 0, :], dim=-1) - if self.use_conv_bias: - hidden_states += self.conv1d.bias - hidden_states = self.act(hidden_states).to(dtype)[:, None, ...] # [batch, 1, intermediate_size] : decoding - else: - hidden_states = hidden_states.transpose(1,2) + if use_precomputed_state: + conv_state = cache_params.update_conv_state(hidden_states, self.layer_idx) + hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1) + if self.use_conv_bias: + hidden_states += self.conv1d.bias + hidden_states = self.act(hidden_states).to(dtype)[:, None, ...] 
# [batch, 1, intermediate_size] : decoding + else: + if cache_params is not None: conv_state = nn.functional.pad( hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0) ) - cache_params.conv_states[self.layer_idx].copy_(conv_state) - hidden_states = self.act(self.conv1d(hidden_states).transpose(1,2))[:, :seq_len, :] # [batch, intermediate_size, seq_len] - if attention_mask is not None: - dtype = hidden_states.dtype - # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66 - hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) - else: - ssm_state = torch.zeros( - (batch_size, self.num_heads, self.head_dim, self.ssm_state_size), - device=hidden_states.device, dtype=dtype - ) - hidden_states = self.act(self.conv1d(hidden_states.transpose(1, 2))[..., :seq_len].transpose(1, 2)) + conv_state = cache_params.update_conv_state(conv_state, self.layer_idx) + + hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len].transpose(1, 2)) + if attention_mask is not None: + dtype = hidden_states.dtype + # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66 + hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) + hidden_states, B, C = torch.split(hidden_states, [self.intermediate_size, self.n_groups * self.ssm_state_size, self.n_groups * self.ssm_state_size], dim=-1) A = -torch.exp(self.A_log.float()) # [num_heads] - if cache_params is not None and cache_params.has_previous_state: + if use_precomputed_state: # Note: there is no need to pad parameter matrices here, as there is just one new token # for batched generation dt = dt[:, None, ...] if dt.ndim == 2 else dt[:, 0, :][:, None, ...] @@ -557,9 +435,9 @@ def torch_forward(self, input_states, cache_params: NemotronHHybridDynamicCache dBx = dB * hidden_states[..., None] # State calculation - cache_params.ssm_states[self.layer_idx].copy_( - cache_params.ssm_states[self.layer_idx] * dA + dBx - ) + ssm_states = cache_params.layers[self.layer_idx].recurrent_states.clone() + ssm_states = ssm_states * dA + dBx + ssm_states = cache_params.update_recurrent_state(ssm_states, self.layer_idx) # Subsequent output # [bsz, n_groups * state_size] -> [bsz, num_heads, state_size] @@ -568,7 +446,7 @@ def torch_forward(self, input_states, cache_params: NemotronHHybridDynamicCache C = C.reshape(batch_size, -1, C.shape[-1]) # [bsz, num_heads, head_dim] - ssm_states = cache_params.ssm_states[self.layer_idx].to(C.dtype) # Shape: [b, h, d, n] + ssm_states = ssm_states.to(C.dtype) # Shape: [b, h, d, n] # Reshape ssm_states to merge the first two dimensions ssm_states_reshaped = ssm_states.view(batch_size * self.num_heads, self.head_dim, self.ssm_state_size) # Shape: [b*h, d, n] C_reshaped = C.view(batch_size * self.num_heads, self.ssm_state_size, 1) # Shape: [b*h, n, 1] @@ -629,10 +507,7 @@ def torch_forward(self, input_states, cache_params: NemotronHHybridDynamicCache B_decay_contraction = B * decay_states.permute(0, 2, 3, 1)[..., None] # permute back B * decay states states = (B_decay_contraction.permute(0, 1, 3, 2, 4)[..., None] * hidden_states.permute(0, 1, 3, 2, 4)[..., None, :]).sum(dim=3).permute(0, 1, 2, 4, 3) - if cache_params is not None and cache_params.has_previous_state: - previous_states = cache_params.ssm_states[self.layer_idx][:, None, ...] 
- else: - previous_states = torch.zeros_like(states[:, :1]) + previous_states = torch.zeros_like(states[:, :1]) states = torch.cat([previous_states, states], dim=1) decay_chunk = torch.exp(segment_sum(nn.functional.pad(A_cumsum[:, :, :, -1], (1, 0)))) @@ -660,7 +535,7 @@ def torch_forward(self, input_states, cache_params: NemotronHHybridDynamicCache y = y[:, :seq_len, :, :] y = y.reshape(batch_size, seq_len, -1) if ssm_state is not None and cache_params is not None: - cache_params.ssm_states[self.layer_idx].copy_(ssm_state) + cache_params.update_recurrent_state(ssm_state, self.layer_idx) scan_output = self.norm(y, gate) @@ -674,7 +549,7 @@ def torch_forward(self, input_states, cache_params: NemotronHHybridDynamicCache def forward( self, hidden_states, - cache_params: NemotronHHybridDynamicCache | None = None, + cache_params: Cache | None = None, attention_mask: torch.Tensor | None = None, **kwargs, ): @@ -968,7 +843,7 @@ def forward( self, hidden_states: torch.Tensor, attention_mask: torch.Tensor | None = None, - past_key_values: NemotronHHybridDynamicCache | None = None, + past_key_values: Cache | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.Tensor, torch.Tensor | None]: input_shape = hidden_states.shape[:-1] @@ -1036,7 +911,7 @@ def __init__(self, config, layer_idx): def forward( self, hidden_states, - past_key_values: NemotronHHybridDynamicCache | None = None, + past_key_values: Cache | None = None, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, use_cache: bool | None = False, @@ -1162,7 +1037,7 @@ def forward( input_ids: torch.LongTensor | None = None, inputs_embeds: torch.LongTensor | None = None, position_ids: torch.LongTensor | None = None, - past_key_values: NemotronHHybridDynamicCache | None = None, + past_key_values: Cache | None = None, use_cache: bool | None = None, attention_mask: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], @@ -1174,12 +1049,7 @@ def forward( inputs_embeds = self.embeddings(input_ids) if use_cache and past_key_values is None: - past_key_values = NemotronHHybridDynamicCache( - config=self.config, - batch_size=inputs_embeds.shape[0], - dtype=inputs_embeds.dtype, - device=inputs_embeds.device, - ) + past_key_values = DynamicCache(config=self.config) hidden_states = inputs_embeds @@ -1218,9 +1088,6 @@ def forward( hidden_states = self.norm_f(hidden_states) - if past_key_values is not None and not past_key_values.has_previous_state: - past_key_values.has_previous_state = True - return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values if use_cache else None, @@ -1233,7 +1100,7 @@ def _update_mamba_mask(self, attention_mask, past_key_values): 2. 
Attending to all inputs """ mamba_mask = attention_mask - if (past_key_values is not None and past_key_values.has_previous_state) or ( + if (past_key_values is not None and past_key_values.has_previous_state()) or ( attention_mask is not None and torch.all(attention_mask == 1) ): mamba_mask = None @@ -1260,7 +1127,7 @@ def forward( input_ids: torch.LongTensor | None = None, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, - past_key_values: NemotronHHybridDynamicCache | None = None, + past_key_values: Cache | None = None, inputs_embeds: torch.FloatTensor | None = None, labels: torch.LongTensor | None = None, use_cache: bool | None = None, @@ -1327,13 +1194,6 @@ def prepare_inputs_for_generation( is_first_iteration=False, **kwargs, ): - # Overwritten -- has a unique cache type, `NemotronHHybridDynamicCache` - - if past_key_values is None: - past_key_values = NemotronHHybridDynamicCache( - self.config, input_ids.shape[0], dtype=self.dtype, device=self.device - ) - kwargs["logits_to_keep"] = self.config.num_logits_to_keep model_inputs = super().prepare_inputs_for_generation( input_ids, diff --git a/src/transformers/models/nemotron_h/modular_nemotron_h.py b/src/transformers/models/nemotron_h/modular_nemotron_h.py index 8412be1dc5f9..a7433a982f1c 100644 --- a/src/transformers/models/nemotron_h/modular_nemotron_h.py +++ b/src/transformers/models/nemotron_h/modular_nemotron_h.py @@ -22,6 +22,7 @@ from ... import initialization as init from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache from ...integrations import use_experts_implementation from ...masking_utils import create_causal_mask from ...modeling_layers import GradientCheckpointingLayer @@ -32,7 +33,7 @@ from ...models.llama.modeling_llama import LlamaRMSNorm from ...models.nemotron.modeling_nemotron import NemotronMLP from ...models.zamba.modeling_zamba import ZambaForCausalLM -from ...models.zamba2.modeling_zamba2 import Zamba2HybridDynamicCache, Zamba2MambaMixer, Zamba2RMSNormGated +from ...models.zamba2.modeling_zamba2 import Zamba2MambaMixer, Zamba2RMSNormGated from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging from ...utils.generic import merge_with_config_defaults @@ -45,52 +46,6 @@ is_fast_path_available = False -class NemotronHHybridDynamicCache(Zamba2HybridDynamicCache): - def __init__( - self, config: NemotronHConfig, batch_size: int, dtype: torch.dtype = torch.float16, device: str | None = None - ): - self.dtype = dtype - self.layers_block_type = config.layers_block_type - self.has_previous_state = False - self.intermediate_size = int(config.mamba_num_heads * config.mamba_head_dim) - self.ssm_state_size = config.ssm_state_size - self.conv_kernel_size = config.conv_kernel - self.n_mamba_heads = config.mamba_num_heads - self.transformer_layers = [] - self._modules = {} - self._parameters = {} - self._buffers = {} - self.conv_states = {} - self.ssm_states = {} - for i in range(config.num_hidden_layers): - if self.layers_block_type[i] == "mamba": - # Only allocate mamba cache for mamba layers - self.conv_states[i] = torch.zeros( - batch_size, - self.intermediate_size + 2 * config.n_groups * self.ssm_state_size, - self.conv_kernel_size, - device=device, - dtype=dtype, - ) - self.ssm_states[i] = torch.zeros( - batch_size, - self.n_mamba_heads, - config.mamba_head_dim, - self.ssm_state_size, - device=device, - dtype=dtype, - ) - else: - # For attention and moe layers, 
use empty tensors - self.conv_states[i] = torch.tensor([[]] * batch_size, device=device) - self.ssm_states[i] = torch.tensor([[]] * batch_size, device=device) - - if self.layers_block_type[i] == "attention": - self.transformer_layers.append(i) - self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] - self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] - - class NemotronHMamba2Mixer(Zamba2MambaMixer): def __init__(self, config: NemotronHConfig, layer_idx: int | None = None): super().__init__(config, layer_idx) @@ -133,7 +88,7 @@ def __init__(self, config: NemotronHConfig, layer_idx: int | None = None): def forward( self, hidden_states, - cache_params: NemotronHHybridDynamicCache | None = None, + cache_params: Cache | None = None, attention_mask: torch.Tensor | None = None, **kwargs, ): @@ -277,7 +232,7 @@ def forward( self, hidden_states: torch.Tensor, attention_mask: torch.Tensor | None = None, - past_key_values: NemotronHHybridDynamicCache | None = None, + past_key_values: Cache | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.Tensor, torch.Tensor | None]: return super().forward(hidden_states, attention_mask, past_key_values, **kwargs) @@ -318,7 +273,7 @@ def __init__(self, config, layer_idx): def forward( self, hidden_states, - past_key_values: NemotronHHybridDynamicCache | None = None, + past_key_values: Cache | None = None, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, use_cache: bool | None = False, @@ -444,7 +399,7 @@ def forward( input_ids: torch.LongTensor | None = None, inputs_embeds: torch.LongTensor | None = None, position_ids: torch.LongTensor | None = None, - past_key_values: NemotronHHybridDynamicCache | None = None, + past_key_values: Cache | None = None, use_cache: bool | None = None, attention_mask: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], @@ -456,12 +411,7 @@ def forward( inputs_embeds = self.embeddings(input_ids) if use_cache and past_key_values is None: - past_key_values = NemotronHHybridDynamicCache( - config=self.config, - batch_size=inputs_embeds.shape[0], - dtype=inputs_embeds.dtype, - device=inputs_embeds.device, - ) + past_key_values = DynamicCache(config=self.config) hidden_states = inputs_embeds @@ -500,9 +450,6 @@ def forward( hidden_states = self.norm_f(hidden_states) - if past_key_values is not None and not past_key_values.has_previous_state: - past_key_values.has_previous_state = True - return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values if use_cache else None, @@ -515,7 +462,7 @@ def _update_mamba_mask(self, attention_mask, past_key_values): 2. 
Attending to all inputs """ mamba_mask = attention_mask - if (past_key_values is not None and past_key_values.has_previous_state) or ( + if (past_key_values is not None and past_key_values.has_previous_state()) or ( attention_mask is not None and torch.all(attention_mask == 1) ): mamba_mask = None @@ -532,7 +479,7 @@ def forward( input_ids: torch.LongTensor | None = None, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, - past_key_values: NemotronHHybridDynamicCache | None = None, + past_key_values: Cache | None = None, inputs_embeds: torch.FloatTensor | None = None, labels: torch.LongTensor | None = None, use_cache: bool | None = None, diff --git a/src/transformers/models/olmo_hybrid/modeling_olmo_hybrid.py b/src/transformers/models/olmo_hybrid/modeling_olmo_hybrid.py index e3b83bff770a..bf72e06d94eb 100644 --- a/src/transformers/models/olmo_hybrid/modeling_olmo_hybrid.py +++ b/src/transformers/models/olmo_hybrid/modeling_olmo_hybrid.py @@ -61,8 +61,7 @@ class OlmoHybridDynamicCache: """ Cache for hybrid model supporting both attention KV cache and linear attention state. - Inherits from Qwen3NextDynamicCache. The main difference is that this cache - stores separate conv states for q, k, v (instead of a single conv_states list). + The main difference is that this cache stores separate conv states for q, k, v (instead of a single conv_states). """ is_compileable = False @@ -155,7 +154,6 @@ def get_mask_sizes(self, query_length: int, layer_idx: int) -> tuple[int, int]: kv_length = query_length + past_seen_tokens return kv_length, kv_offset - @property def has_previous_state(self): """We have a previous state if the last linear (conv) layer was already updated.""" return self.conv_states_q[self.last_linear_layer] is not None @@ -725,7 +723,7 @@ def forward( batch_size, seq_len, _ = hidden_states.shape use_cache = cache_params is not None - use_precomputed = use_cache and getattr(cache_params, "has_previous_state", False) and seq_len == 1 + use_precomputed = use_cache and cache_params.has_previous_state() and seq_len == 1 conv_state_q = cache_params.conv_states_q[self.layer_idx] if cache_params else None conv_state_k = cache_params.conv_states_k[self.layer_idx] if cache_params else None @@ -1024,7 +1022,7 @@ def _update_linear_attn_mask(self, attention_mask, past_key_values): 2. 
Attending to all inputs """ linear_attn_mask = attention_mask - if (past_key_values is not None and past_key_values.has_previous_state) or ( + if (past_key_values is not None and past_key_values.has_previous_state()) or ( attention_mask is not None and torch.all(attention_mask == 1) ): linear_attn_mask = None diff --git a/src/transformers/models/olmo_hybrid/modular_olmo_hybrid.py b/src/transformers/models/olmo_hybrid/modular_olmo_hybrid.py index 74f6a70ce6de..089f29309007 100644 --- a/src/transformers/models/olmo_hybrid/modular_olmo_hybrid.py +++ b/src/transformers/models/olmo_hybrid/modular_olmo_hybrid.py @@ -16,6 +16,7 @@ import math from collections.abc import Callable +from typing import Any import torch import torch.nn as nn @@ -48,7 +49,6 @@ eager_attention_forward, ) from ..qwen3_next.modeling_qwen3_next import ( - Qwen3NextDynamicCache, Qwen3NextModel, Qwen3NextPreTrainedModel, Qwen3NextRMSNormGated, @@ -189,22 +189,49 @@ def validate_architecture(self): raise ValueError("OLMoHybrid expects at least one attention layer.") -class OlmoHybridDynamicCache(Qwen3NextDynamicCache): +class OlmoHybridDynamicCache: """ Cache for hybrid model supporting both attention KV cache and linear attention state. - Inherits from Qwen3NextDynamicCache. The main difference is that this cache - stores separate conv states for q, k, v (instead of a single conv_states list). + The main difference is that this cache stores separate conv states for q, k, v (instead of a single conv_states). """ + is_compileable = False + def __init__(self, config: OlmoHybridConfig): - super().__init__(config) - del self.conv_states + super().__init__() + self.layer_types = config.layer_types + self.transformer_layers = [ + i for i in range(config.num_hidden_layers) if self.layer_types[i] == "full_attention" + ] + self.last_linear_layer = len(self.layer_types) - 1 - self.layer_types[::-1].index("linear_attention") + self.recurrent_states = [None for _ in range(config.num_hidden_layers)] + self.key_cache = [None for _ in range(config.num_hidden_layers)] + self.value_cache = [None for _ in range(config.num_hidden_layers)] # Replace single conv_states with separate q, k, v conv states self.conv_states_q = [None for _ in range(config.num_hidden_layers)] self.conv_states_k = [None for _ in range(config.num_hidden_layers)] self.conv_states_v = [None for _ in range(config.num_hidden_layers)] + def __len__(self): + return len(self.layer_types) + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: dict[str, Any] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + if self.key_cache[layer_idx] is None: + self.key_cache[layer_idx] = key_states + self.value_cache[layer_idx] = value_states + else: + self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2) + self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2) + + return self.key_cache[layer_idx], self.value_cache[layer_idx] + def reorder_cache(self, beam_idx: torch.LongTensor): """Reorders the cache for beam search, given the selected beam indices.""" batch_size = beam_idx.shape[0] @@ -240,8 +267,27 @@ def reorder_cache(self, beam_idx: torch.LongTensor): 0, beam_idx.to(device) ) - @property + def get_seq_length(self, layer_idx: int | None = 0) -> int: + """Returns the sequence length of the cached states. 
A layer index can be optionally passed.""" + # take any layer that contains cache and not empty tensor + layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx + if len(self.key_cache) <= layer_idx or self.key_cache[layer_idx] is None: + return 0 + return self.key_cache[layer_idx].shape[-2] + + def get_mask_sizes(self, query_length: int, layer_idx: int) -> tuple[int, int]: + """ + Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for + the given layer at `layer_idx`. + The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns for each layer. + """ + kv_offset = 0 + past_seen_tokens = self.get_seq_length(layer_idx) + kv_length = query_length + past_seen_tokens + return kv_length, kv_offset + def has_previous_state(self): + """We have a previous state if the last linear (conv) layer was already updated.""" return self.conv_states_q[self.last_linear_layer] is not None @@ -495,7 +541,7 @@ def forward( batch_size, seq_len, _ = hidden_states.shape use_cache = cache_params is not None - use_precomputed = use_cache and getattr(cache_params, "has_previous_state", False) and seq_len == 1 + use_precomputed = use_cache and cache_params.has_previous_state() and seq_len == 1 conv_state_q = cache_params.conv_states_q[self.layer_idx] if cache_params else None conv_state_k = cache_params.conv_states_k[self.layer_idx] if cache_params else None diff --git a/src/transformers/models/qwen3_5/modeling_qwen3_5.py b/src/transformers/models/qwen3_5/modeling_qwen3_5.py index 57589b70b94f..eba3eec02fdd 100644 --- a/src/transformers/models/qwen3_5/modeling_qwen3_5.py +++ b/src/transformers/models/qwen3_5/modeling_qwen3_5.py @@ -29,7 +29,7 @@ from ... import initialization as init from ...activations import ACT2FN -from ...cache_utils import Cache +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...integrations import use_kernelized_func from ...masking_utils import create_causal_mask @@ -66,95 +66,6 @@ logger = logging.get_logger(__name__) -class Qwen3_5DynamicCache: - """ - A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the linear attention - cache (which has a constant shape regardless of seq_len). - - This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states` - and `ssm_states` for gated deltanet cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor - For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`, - while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors). - For linear attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors), - while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`, - and `recurrent_states` represents the recurrent state and has a shape of `(batch_size, d_inner, d_state)`. 
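(The same conv-state protocol recurs in every mixer this PR touches; a condensed sketch of it for review. The function and variable names are mine, mirroring the call sites in these hunks.)

```python
import torch
import torch.nn.functional as F
from transformers.cache_utils import Cache


def conv_cache_step(cache: Cache | None, layer_idx: int, mixed_qkv: torch.Tensor, kernel_size: int):
    """mixed_qkv is (batch, channels, seq_len), as in the mixers above."""
    if cache is not None and cache.has_previous_state(layer_idx) and mixed_qkv.shape[-1] == 1:
        # Decoding: read the rolling window kept on the per-layer cache object.
        return cache.layers[layer_idx].conv_states
    if cache is not None:
        # Prefill: left-pad to the kernel size, then hand the window to the cache, which
        # stores it in place and returns the tensor the conv kernel should keep using.
        conv_state = F.pad(mixed_qkv, (kernel_size - mixed_qkv.shape[-1], 0))
        return cache.update_conv_state(conv_state, layer_idx)
    return None
```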
- """ - - is_compileable = False - - def __init__(self, config: Qwen3_5Config): - super().__init__() - self.layer_types = config.layer_types - self.transformer_layers = [ - i for i in range(config.num_hidden_layers) if self.layer_types[i] == "full_attention" - ] - self.last_linear_layer = len(self.layer_types) - 1 - self.layer_types[::-1].index("linear_attention") - - # Initialize everything to None -> will be lazy initialized to allow multi-gpu (device_map) inference - self.conv_states = [None for _ in range(config.num_hidden_layers)] - self.recurrent_states = [None for _ in range(config.num_hidden_layers)] - self.key_cache = [None for _ in range(config.num_hidden_layers)] - self.value_cache = [None for _ in range(config.num_hidden_layers)] - - def __len__(self): - return len(self.layer_types) - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: dict[str, Any] | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - if self.key_cache[layer_idx] is None: - self.key_cache[layer_idx] = key_states - self.value_cache[layer_idx] = value_states - else: - self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2) - self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2) - - return self.key_cache[layer_idx], self.value_cache[layer_idx] - - def reorder_cache(self, beam_idx: torch.LongTensor): - """Reorders the cache for beam search, given the selected beam indices.""" - for layer_idx in range(len(self.key_cache)): - if self.key_cache[layer_idx] is not None: - device = self.key_cache[layer_idx].device - beam_idx = beam_idx.to(device) - self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx) - self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx) - - if self.conv_states[layer_idx] is not None: - device = self.conv_states[layer_idx].device - beam_idx = beam_idx.to(device) - self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx) - self.recurrent_states[layer_idx] = self.recurrent_states[layer_idx].index_select(0, beam_idx) - - def get_seq_length(self, layer_idx: int | None = 0) -> int: - """Returns the sequence length of the cached states. A layer index can be optionally passed.""" - # take any layer that contains cache and not empty tensor - layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx - if len(self.key_cache) <= layer_idx or self.key_cache[layer_idx] is None: - return 0 - return self.key_cache[layer_idx].shape[-2] - - def get_mask_sizes(self, query_length: int, layer_idx: int) -> tuple[int, int]: - """ - Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for - the given layer at `layer_idx`. - The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns for each layer. 
- """ - kv_offset = 0 - past_seen_tokens = self.get_seq_length(layer_idx) - kv_length = query_length + past_seen_tokens - return kv_length, kv_offset - - @property - def has_previous_state(self): - """We have a previous state if the last linear (conv) layer was already updated.""" - return self.conv_states[self.last_linear_layer] is not None - - class Qwen3_5VisionRotaryEmbedding(nn.Module): inv_freq: torch.Tensor # fix linting for `register_buffer` @@ -511,7 +422,7 @@ def __init__(self, config: Qwen3_5Config, layer_idx: int): def forward( self, hidden_states: torch.Tensor, - cache_params: Qwen3_5DynamicCache | None = None, + cache_params: Cache | None = None, attention_mask: torch.Tensor | None = None, ): hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask) @@ -519,12 +430,14 @@ def forward( # Set up dimensions for reshapes later batch_size, seq_len, _ = hidden_states.shape - use_precomputed_states = cache_params is not None and cache_params.has_previous_state and seq_len == 1 + use_precomputed_states = ( + cache_params is not None and cache_params.has_previous_state(self.layer_idx) and seq_len == 1 + ) # getting projected states from cache if it exists - if cache_params is not None: - conv_state = cache_params.conv_states[self.layer_idx] - recurrent_state = cache_params.recurrent_states[self.layer_idx] + if use_precomputed_states: + conv_state = cache_params.layers[self.layer_idx].conv_states + recurrent_state = cache_params.layers[self.layer_idx].recurrent_states mixed_qkv = self.in_proj_qkv(hidden_states) mixed_qkv = mixed_qkv.transpose(1, 2) @@ -548,7 +461,7 @@ def forward( else: if cache_params is not None: conv_state = F.pad(mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0)) - cache_params.conv_states[self.layer_idx] = conv_state + conv_state = cache_params.update_conv_state(conv_state, self.layer_idx) if self.causal_conv1d_fn is not None: mixed_qkv = self.causal_conv1d_fn( x=mixed_qkv, @@ -608,7 +521,7 @@ def forward( # Update cache if cache_params is not None: - cache_params.recurrent_states[self.layer_idx] = last_recurrent_state + cache_params.update_recurrent_state(last_recurrent_state, self.layer_idx) # reshape input data into 2D tensor core_attn_out = core_attn_out.reshape(-1, self.head_v_dim) @@ -1328,7 +1241,7 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) if use_cache and past_key_values is None: - past_key_values = Qwen3_5DynamicCache(config=self.config) + past_key_values = DynamicCache(config=self.config) # the hard coded `4` is for text, temporal, height and width. if position_ids is None: @@ -1384,7 +1297,7 @@ def _update_linear_attn_mask(self, attention_mask, past_key_values): 2. Attending to all inputs """ linear_attn_mask = attention_mask - if (past_key_values is not None and past_key_values.has_previous_state) or ( + if (past_key_values is not None and past_key_values.has_previous_state()) or ( attention_mask is not None and torch.all(attention_mask == 1) ): linear_attn_mask = None diff --git a/src/transformers/models/qwen3_5/modular_qwen3_5.py b/src/transformers/models/qwen3_5/modular_qwen3_5.py index b76991426f13..8fddbc6115c1 100644 --- a/src/transformers/models/qwen3_5/modular_qwen3_5.py +++ b/src/transformers/models/qwen3_5/modular_qwen3_5.py @@ -21,7 +21,7 @@ from torch import nn from ... 
import initialization as init -from ...cache_utils import Cache +from ...cache_utils import Cache, DynamicCache from ...masking_utils import create_causal_mask from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling @@ -34,7 +34,6 @@ from ..qwen3_next.configuration_qwen3_next import Qwen3NextConfig from ..qwen3_next.modeling_qwen3_next import ( Qwen3NextAttention, - Qwen3NextDynamicCache, Qwen3NextGatedDeltaNet, Qwen3NextMLP, Qwen3NextModel, @@ -160,10 +159,6 @@ class Qwen3_5Config(Qwen3VLConfig): vision_end_token_id: int = 248054 -class Qwen3_5DynamicCache(Qwen3NextDynamicCache): - pass - - class Qwen3_5VisionRotaryEmbedding(Qwen3VLVisionRotaryEmbedding): pass @@ -212,7 +207,7 @@ def fix_query_key_value_ordering(self): def forward( self, hidden_states: torch.Tensor, - cache_params: Qwen3_5DynamicCache | None = None, + cache_params: Cache | None = None, attention_mask: torch.Tensor | None = None, ): hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask) @@ -220,12 +215,14 @@ def forward( # Set up dimensions for reshapes later batch_size, seq_len, _ = hidden_states.shape - use_precomputed_states = cache_params is not None and cache_params.has_previous_state and seq_len == 1 + use_precomputed_states = ( + cache_params is not None and cache_params.has_previous_state(self.layer_idx) and seq_len == 1 + ) # getting projected states from cache if it exists - if cache_params is not None: - conv_state = cache_params.conv_states[self.layer_idx] - recurrent_state = cache_params.recurrent_states[self.layer_idx] + if use_precomputed_states: + conv_state = cache_params.layers[self.layer_idx].conv_states + recurrent_state = cache_params.layers[self.layer_idx].recurrent_states mixed_qkv = self.in_proj_qkv(hidden_states) mixed_qkv = mixed_qkv.transpose(1, 2) @@ -249,7 +246,7 @@ def forward( else: if cache_params is not None: conv_state = F.pad(mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0)) - cache_params.conv_states[self.layer_idx] = conv_state + conv_state = cache_params.update_conv_state(conv_state, self.layer_idx) if self.causal_conv1d_fn is not None: mixed_qkv = self.causal_conv1d_fn( x=mixed_qkv, @@ -309,7 +306,7 @@ def forward( # Update cache if cache_params is not None: - cache_params.recurrent_states[self.layer_idx] = last_recurrent_state + cache_params.update_recurrent_state(last_recurrent_state, self.layer_idx) # reshape input data into 2D tensor core_attn_out = core_attn_out.reshape(-1, self.head_v_dim) @@ -501,7 +498,7 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) if use_cache and past_key_values is None: - past_key_values = Qwen3_5DynamicCache(config=self.config) + past_key_values = DynamicCache(config=self.config) # the hard coded `4` is for text, temporal, height and width. if position_ids is None: diff --git a/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py b/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py index 801156d236c3..be4501d34903 100644 --- a/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py +++ b/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py @@ -29,7 +29,7 @@ from ... 
import initialization as init from ...activations import ACT2FN -from ...cache_utils import Cache +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...integrations import use_experts_implementation, use_kernelized_func from ...masking_utils import create_causal_mask @@ -173,95 +173,6 @@ def apply_interleaved_mrope(self, freqs, mrope_section): return freqs_t -class Qwen3_5MoeDynamicCache: - """ - A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the linear attention - cache (which has a constant shape regardless of seq_len). - - This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states` - and `ssm_states` for gated deltanet cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor - For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`, - while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors). - For linear attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors), - while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`, - and `recurrent_states` represents the recurrent state and has a shape of `(batch_size, d_inner, d_state)`. - """ - - is_compileable = False - - def __init__(self, config: Qwen3_5MoeConfig): - super().__init__() - self.layer_types = config.layer_types - self.transformer_layers = [ - i for i in range(config.num_hidden_layers) if self.layer_types[i] == "full_attention" - ] - self.last_linear_layer = len(self.layer_types) - 1 - self.layer_types[::-1].index("linear_attention") - - # Initialize everything to None -> will be lazy initialized to allow multi-gpu (device_map) inference - self.conv_states = [None for _ in range(config.num_hidden_layers)] - self.recurrent_states = [None for _ in range(config.num_hidden_layers)] - self.key_cache = [None for _ in range(config.num_hidden_layers)] - self.value_cache = [None for _ in range(config.num_hidden_layers)] - - def __len__(self): - return len(self.layer_types) - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: dict[str, Any] | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - if self.key_cache[layer_idx] is None: - self.key_cache[layer_idx] = key_states - self.value_cache[layer_idx] = value_states - else: - self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2) - self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2) - - return self.key_cache[layer_idx], self.value_cache[layer_idx] - - def reorder_cache(self, beam_idx: torch.LongTensor): - """Reorders the cache for beam search, given the selected beam indices.""" - for layer_idx in range(len(self.key_cache)): - if self.key_cache[layer_idx] is not None: - device = self.key_cache[layer_idx].device - beam_idx = beam_idx.to(device) - self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx) - self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx) - - if self.conv_states[layer_idx] is not None: - device = self.conv_states[layer_idx].device - beam_idx = beam_idx.to(device) - self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx) - self.recurrent_states[layer_idx] = self.recurrent_states[layer_idx].index_select(0, beam_idx) - 
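(All of these `reorder_cache` bodies were beam-search plumbing, now handled once by the shared per-layer cache. A smoke-test sketch; the checkpoint id is illustrative and the model is large, so treat this as the shape of the call rather than a quick local run.)

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo = "Qwen/Qwen3-Next-80B-A3B-Instruct"  # illustrative hybrid checkpoint
model = AutoModelForCausalLM.from_pretrained(repo, dtype=torch.bfloat16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(repo)

inputs = tokenizer("Hybrid caches and beam search:", return_tensors="pt").to(model.device)
# Beam search reorders the cache along the batch dimension at every step; the attention
# KV tensors and the linear-attention conv/recurrent states must all follow the beams.
out = model.generate(**inputs, num_beams=4, max_new_tokens=16)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```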
- def get_seq_length(self, layer_idx: int | None = 0) -> int: - """Returns the sequence length of the cached states. A layer index can be optionally passed.""" - # take any layer that contains cache and not empty tensor - layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx - if len(self.key_cache) <= layer_idx or self.key_cache[layer_idx] is None: - return 0 - return self.key_cache[layer_idx].shape[-2] - - def get_mask_sizes(self, query_length: int, layer_idx: int) -> tuple[int, int]: - """ - Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for - the given layer at `layer_idx`. - The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns for each layer. - """ - kv_offset = 0 - past_seen_tokens = self.get_seq_length(layer_idx) - kv_length = query_length + past_seen_tokens - return kv_length, kv_offset - - @property - def has_previous_state(self): - """We have a previous state if the last linear (conv) layer was already updated.""" - return self.conv_states[self.last_linear_layer] is not None - - class Qwen3_5MoeRMSNormGated(nn.Module): def __init__(self, hidden_size, eps=1e-6, **kwargs): super().__init__() @@ -512,7 +423,7 @@ def __init__(self, config: Qwen3_5MoeConfig, layer_idx: int): def forward( self, hidden_states: torch.Tensor, - cache_params: Qwen3_5MoeDynamicCache | None = None, + cache_params: Cache | None = None, attention_mask: torch.Tensor | None = None, ): hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask) @@ -520,12 +431,14 @@ def forward( # Set up dimensions for reshapes later batch_size, seq_len, _ = hidden_states.shape - use_precomputed_states = cache_params is not None and cache_params.has_previous_state and seq_len == 1 + use_precomputed_states = ( + cache_params is not None and cache_params.has_previous_state(self.layer_idx) and seq_len == 1 + ) # getting projected states from cache if it exists - if cache_params is not None: - conv_state = cache_params.conv_states[self.layer_idx] - recurrent_state = cache_params.recurrent_states[self.layer_idx] + if use_precomputed_states: + conv_state = cache_params.layers[self.layer_idx].conv_states + recurrent_state = cache_params.layers[self.layer_idx].recurrent_states mixed_qkv = self.in_proj_qkv(hidden_states) mixed_qkv = mixed_qkv.transpose(1, 2) @@ -549,7 +462,7 @@ def forward( else: if cache_params is not None: conv_state = F.pad(mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0)) - cache_params.conv_states[self.layer_idx] = conv_state + conv_state = cache_params.update_conv_state(conv_state, self.layer_idx) if self.causal_conv1d_fn is not None: mixed_qkv = self.causal_conv1d_fn( x=mixed_qkv, @@ -609,7 +522,7 @@ def forward( # Update cache if cache_params is not None: - cache_params.recurrent_states[self.layer_idx] = last_recurrent_state + cache_params.update_recurrent_state(last_recurrent_state, self.layer_idx) # reshape input data into 2D tensor core_attn_out = core_attn_out.reshape(-1, self.head_v_dim) @@ -1453,7 +1366,7 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) if use_cache and past_key_values is None: - past_key_values = Qwen3_5MoeDynamicCache(config=self.config) + past_key_values = DynamicCache(config=self.config) # the hard coded `4` is for text, temporal, height and width. if position_ids is None: @@ -1509,7 +1422,7 @@ def _update_linear_attn_mask(self, attention_mask, past_key_values): 2. 
Attending to all inputs """ linear_attn_mask = attention_mask - if (past_key_values is not None and past_key_values.has_previous_state) or ( + if (past_key_values is not None and past_key_values.has_previous_state()) or ( attention_mask is not None and torch.all(attention_mask == 1) ): linear_attn_mask = None @@ -2005,7 +1918,7 @@ def forward( input_ids: torch.LongTensor | None = None, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, - past_key_values: Qwen3_5MoeDynamicCache | None = None, + past_key_values: Cache | None = None, inputs_embeds: torch.FloatTensor | None = None, labels: torch.LongTensor | None = None, use_cache: bool | None = None, diff --git a/src/transformers/models/qwen3_5_moe/modular_qwen3_5_moe.py b/src/transformers/models/qwen3_5_moe/modular_qwen3_5_moe.py index f8684ddd83db..312b22bc88ed 100644 --- a/src/transformers/models/qwen3_5_moe/modular_qwen3_5_moe.py +++ b/src/transformers/models/qwen3_5_moe/modular_qwen3_5_moe.py @@ -35,7 +35,6 @@ from ..qwen3_next.modeling_qwen3_next import ( Qwen3NextAttention, Qwen3NextDecoderLayer, - Qwen3NextDynamicCache, Qwen3NextExperts, Qwen3NextForCausalLM, Qwen3NextPreTrainedModel, @@ -156,10 +155,6 @@ class Qwen3_5MoeTextRotaryEmbedding(Qwen3_5TextRotaryEmbedding): pass -class Qwen3_5MoeDynamicCache(Qwen3NextDynamicCache): - pass - - class Qwen3_5MoeGatedDeltaNet(Qwen3_5GatedDeltaNet): pass diff --git a/src/transformers/models/qwen3_next/modeling_qwen3_next.py b/src/transformers/models/qwen3_next/modeling_qwen3_next.py index 7b45f0ea4838..9e7fa7e01c69 100644 --- a/src/transformers/models/qwen3_next/modeling_qwen3_next.py +++ b/src/transformers/models/qwen3_next/modeling_qwen3_next.py @@ -19,7 +19,7 @@ # limitations under the License. from collections.abc import Callable -from typing import Any, Optional +from typing import Optional import torch import torch.nn.functional as F @@ -27,7 +27,7 @@ from ... import initialization as init from ...activations import ACT2FN -from ...cache_utils import Cache +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...integrations import use_experts_implementation, use_kernelized_func from ...masking_utils import create_causal_mask @@ -82,95 +82,6 @@ def forward(self, hidden_states, gate=None): return hidden_states.to(input_dtype) -class Qwen3NextDynamicCache: - """ - A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the linear attention - cache (which has a constant shape regardless of seq_len). - - This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states` - and `ssm_states` for gated deltanet cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor - For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`, - while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors). - For linear attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors), - while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`, - and `recurrent_states` represents the recurrent state and has a shape of `(batch_size, d_inner, d_state)`. 
- """ - - is_compileable = False - - def __init__(self, config: Qwen3NextConfig): - super().__init__() - self.layer_types = config.layer_types - self.transformer_layers = [ - i for i in range(config.num_hidden_layers) if self.layer_types[i] == "full_attention" - ] - self.last_linear_layer = len(self.layer_types) - 1 - self.layer_types[::-1].index("linear_attention") - - # Initialize everything to None -> will be lazy initialized to allow multi-gpu (device_map) inference - self.conv_states = [None for _ in range(config.num_hidden_layers)] - self.recurrent_states = [None for _ in range(config.num_hidden_layers)] - self.key_cache = [None for _ in range(config.num_hidden_layers)] - self.value_cache = [None for _ in range(config.num_hidden_layers)] - - def __len__(self): - return len(self.layer_types) - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: dict[str, Any] | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - if self.key_cache[layer_idx] is None: - self.key_cache[layer_idx] = key_states - self.value_cache[layer_idx] = value_states - else: - self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2) - self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2) - - return self.key_cache[layer_idx], self.value_cache[layer_idx] - - def reorder_cache(self, beam_idx: torch.LongTensor): - """Reorders the cache for beam search, given the selected beam indices.""" - for layer_idx in range(len(self.key_cache)): - if self.key_cache[layer_idx] is not None: - device = self.key_cache[layer_idx].device - beam_idx = beam_idx.to(device) - self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx) - self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx) - - if self.conv_states[layer_idx] is not None: - device = self.conv_states[layer_idx].device - beam_idx = beam_idx.to(device) - self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx) - self.recurrent_states[layer_idx] = self.recurrent_states[layer_idx].index_select(0, beam_idx) - - def get_seq_length(self, layer_idx: int | None = 0) -> int: - """Returns the sequence length of the cached states. A layer index can be optionally passed.""" - # take any layer that contains cache and not empty tensor - layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx - if len(self.key_cache) <= layer_idx or self.key_cache[layer_idx] is None: - return 0 - return self.key_cache[layer_idx].shape[-2] - - def get_mask_sizes(self, query_length: int, layer_idx: int) -> tuple[int, int]: - """ - Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for - the given layer at `layer_idx`. - The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns for each layer. 
- """ - kv_offset = 0 - past_seen_tokens = self.get_seq_length(layer_idx) - kv_length = query_length + past_seen_tokens - return kv_length, kv_offset - - @property - def has_previous_state(self): - """We have a previous state if the last linear (conv) layer was already updated.""" - return self.conv_states[self.last_linear_layer] is not None - - class Qwen3NextRotaryEmbedding(nn.Module): inv_freq: torch.Tensor # fix linting for `register_buffer` @@ -681,7 +592,7 @@ def fix_query_key_value_ordering(self, mixed_qkvz, mixed_ba): def forward( self, hidden_states: torch.Tensor, - cache_params: Qwen3NextDynamicCache | None = None, + cache_params: Cache | None = None, attention_mask: torch.Tensor | None = None, ): hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask) @@ -689,12 +600,14 @@ def forward( # Set up dimensions for reshapes later batch_size, seq_len, _ = hidden_states.shape - use_precomputed_states = cache_params is not None and cache_params.has_previous_state and seq_len == 1 + use_precomputed_states = ( + cache_params is not None and cache_params.has_previous_state(self.layer_idx) and seq_len == 1 + ) # getting projected states from cache if it exists - if cache_params is not None: - conv_state = cache_params.conv_states[self.layer_idx] - recurrent_state = cache_params.recurrent_states[self.layer_idx] + if use_precomputed_states: + conv_state = cache_params.layers[self.layer_idx].conv_states + recurrent_state = cache_params.layers[self.layer_idx].recurrent_states projected_states_qkvz = self.in_proj_qkvz(hidden_states) projected_states_ba = self.in_proj_ba(hidden_states) @@ -717,7 +630,7 @@ def forward( else: if cache_params is not None: conv_state = F.pad(mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0)) - cache_params.conv_states[self.layer_idx] = conv_state + conv_state = cache_params.update_conv_state(conv_state, self.layer_idx) if self.causal_conv1d_fn is not None: mixed_qkv = self.causal_conv1d_fn( x=mixed_qkv, @@ -761,7 +674,6 @@ def forward( output_final_state=cache_params is not None, use_qk_l2norm_in_kernel=True, ) - else: core_attn_out, last_recurrent_state = self.recurrent_gated_delta_rule( query, @@ -776,7 +688,7 @@ def forward( # Update cache if cache_params is not None: - cache_params.recurrent_states[self.layer_idx] = last_recurrent_state + cache_params.update_recurrent_state(last_recurrent_state, self.layer_idx) z_shape_og = z.shape # reshape input data into 2D tensor @@ -1021,7 +933,7 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) if use_cache and past_key_values is None: - past_key_values = Qwen3NextDynamicCache(config=self.config) + past_key_values = DynamicCache(config=self.config) if position_ids is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 @@ -1068,7 +980,7 @@ def _update_linear_attn_mask(self, attention_mask, past_key_values): 2. 
Attending to all inputs """ linear_attn_mask = attention_mask - if (past_key_values is not None and past_key_values.has_previous_state) or ( + if (past_key_values is not None and past_key_values.has_previous_state()) or ( attention_mask is not None and torch.all(attention_mask == 1) ): linear_attn_mask = None @@ -1182,7 +1094,7 @@ def forward( input_ids: torch.LongTensor | None = None, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, - past_key_values: Qwen3NextDynamicCache | None = None, + past_key_values: Cache | None = None, inputs_embeds: torch.FloatTensor | None = None, labels: torch.LongTensor | None = None, use_cache: bool | None = None, diff --git a/src/transformers/models/qwen3_next/modular_qwen3_next.py b/src/transformers/models/qwen3_next/modular_qwen3_next.py index a22b85bf9278..417a9a59cf8b 100644 --- a/src/transformers/models/qwen3_next/modular_qwen3_next.py +++ b/src/transformers/models/qwen3_next/modular_qwen3_next.py @@ -14,7 +14,7 @@ """PyTorch Qwen3-Next model.""" from collections.abc import Callable -from typing import Any, Optional +from typing import Optional import torch import torch.nn.functional as F @@ -22,7 +22,7 @@ from ... import initialization as init from ...activations import ACT2FN -from ...cache_utils import Cache +from ...cache_utils import Cache, DynamicCache from ...masking_utils import create_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast @@ -92,95 +92,6 @@ def forward(self, hidden_states, gate=None): return hidden_states.to(input_dtype) -class Qwen3NextDynamicCache: - """ - A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the linear attention - cache (which has a constant shape regardless of seq_len). - - This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states` - and `ssm_states` for gated deltanet cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor - For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`, - while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors). - For linear attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors), - while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`, - and `recurrent_states` represents the recurrent state and has a shape of `(batch_size, d_inner, d_state)`. 
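(The `_update_linear_attn_mask` hunks above recur in every model in this PR; extracted here for clarity, the post-change logic reads as below, with `has_previous_state` now a zero-argument call on the shared cache.)

```python
import torch


def update_linear_attn_mask(attention_mask, past_key_values):
    # The mask can be dropped when 1. decoding with an already-populated state, or
    # 2. attending to all inputs (no padding anywhere in the batch).
    if (past_key_values is not None and past_key_values.has_previous_state()) or (
        attention_mask is not None and torch.all(attention_mask == 1)
    ):
        return None
    return attention_mask
```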
- """ - - is_compileable = False - - def __init__(self, config: Qwen3NextConfig): - super().__init__() - self.layer_types = config.layer_types - self.transformer_layers = [ - i for i in range(config.num_hidden_layers) if self.layer_types[i] == "full_attention" - ] - self.last_linear_layer = len(self.layer_types) - 1 - self.layer_types[::-1].index("linear_attention") - - # Initialize everything to None -> will be lazy initialized to allow multi-gpu (device_map) inference - self.conv_states = [None for _ in range(config.num_hidden_layers)] - self.recurrent_states = [None for _ in range(config.num_hidden_layers)] - self.key_cache = [None for _ in range(config.num_hidden_layers)] - self.value_cache = [None for _ in range(config.num_hidden_layers)] - - def __len__(self): - return len(self.layer_types) - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: dict[str, Any] | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - if self.key_cache[layer_idx] is None: - self.key_cache[layer_idx] = key_states - self.value_cache[layer_idx] = value_states - else: - self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2) - self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2) - - return self.key_cache[layer_idx], self.value_cache[layer_idx] - - def reorder_cache(self, beam_idx: torch.LongTensor): - """Reorders the cache for beam search, given the selected beam indices.""" - for layer_idx in range(len(self.key_cache)): - if self.key_cache[layer_idx] is not None: - device = self.key_cache[layer_idx].device - beam_idx = beam_idx.to(device) - self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx) - self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx) - - if self.conv_states[layer_idx] is not None: - device = self.conv_states[layer_idx].device - beam_idx = beam_idx.to(device) - self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx) - self.recurrent_states[layer_idx] = self.recurrent_states[layer_idx].index_select(0, beam_idx) - - def get_seq_length(self, layer_idx: int | None = 0) -> int: - """Returns the sequence length of the cached states. A layer index can be optionally passed.""" - # take any layer that contains cache and not empty tensor - layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx - if len(self.key_cache) <= layer_idx or self.key_cache[layer_idx] is None: - return 0 - return self.key_cache[layer_idx].shape[-2] - - def get_mask_sizes(self, query_length: int, layer_idx: int) -> tuple[int, int]: - """ - Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for - the given layer at `layer_idx`. - The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns for each layer. 
- """ - kv_offset = 0 - past_seen_tokens = self.get_seq_length(layer_idx) - kv_length = query_length + past_seen_tokens - return kv_length, kv_offset - - @property - def has_previous_state(self): - """We have a previous state if the last linear (conv) layer was already updated.""" - return self.conv_states[self.last_linear_layer] is not None - - class Qwen3NextRotaryEmbedding(Gemma2RotaryEmbedding): @staticmethod def compute_default_rope_parameters( @@ -520,7 +431,7 @@ def fix_query_key_value_ordering(self, mixed_qkvz, mixed_ba): def forward( self, hidden_states: torch.Tensor, - cache_params: Qwen3NextDynamicCache | None = None, + cache_params: Cache | None = None, attention_mask: torch.Tensor | None = None, ): hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask) @@ -528,12 +439,14 @@ def forward( # Set up dimensions for reshapes later batch_size, seq_len, _ = hidden_states.shape - use_precomputed_states = cache_params is not None and cache_params.has_previous_state and seq_len == 1 + use_precomputed_states = ( + cache_params is not None and cache_params.has_previous_state(self.layer_idx) and seq_len == 1 + ) # getting projected states from cache if it exists - if cache_params is not None: - conv_state = cache_params.conv_states[self.layer_idx] - recurrent_state = cache_params.recurrent_states[self.layer_idx] + if use_precomputed_states: + conv_state = cache_params.layers[self.layer_idx].conv_states + recurrent_state = cache_params.layers[self.layer_idx].recurrent_states projected_states_qkvz = self.in_proj_qkvz(hidden_states) projected_states_ba = self.in_proj_ba(hidden_states) @@ -556,7 +469,7 @@ def forward( else: if cache_params is not None: conv_state = F.pad(mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0)) - cache_params.conv_states[self.layer_idx] = conv_state + conv_state = cache_params.update_conv_state(conv_state, self.layer_idx) if self.causal_conv1d_fn is not None: mixed_qkv = self.causal_conv1d_fn( x=mixed_qkv, @@ -600,7 +513,6 @@ def forward( output_final_state=cache_params is not None, use_qk_l2norm_in_kernel=True, ) - else: core_attn_out, last_recurrent_state = self.recurrent_gated_delta_rule( query, @@ -615,7 +527,7 @@ def forward( # Update cache if cache_params is not None: - cache_params.recurrent_states[self.layer_idx] = last_recurrent_state + cache_params.update_recurrent_state(last_recurrent_state, self.layer_idx) z_shape_og = z.shape # reshape input data into 2D tensor @@ -777,7 +689,7 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) if use_cache and past_key_values is None: - past_key_values = Qwen3NextDynamicCache(config=self.config) + past_key_values = DynamicCache(config=self.config) if position_ids is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 @@ -824,7 +736,7 @@ def _update_linear_attn_mask(self, attention_mask, past_key_values): 2. 
Attending to all inputs """ linear_attn_mask = attention_mask - if (past_key_values is not None and past_key_values.has_previous_state) or ( + if (past_key_values is not None and past_key_values.has_previous_state()) or ( attention_mask is not None and torch.all(attention_mask == 1) ): linear_attn_mask = None @@ -841,7 +753,7 @@ def forward( input_ids: torch.LongTensor | None = None, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, - past_key_values: Qwen3NextDynamicCache | None = None, + past_key_values: Cache | None = None, inputs_embeds: torch.FloatTensor | None = None, labels: torch.LongTensor | None = None, use_cache: bool | None = None, diff --git a/src/transformers/models/zamba/configuration_zamba.py b/src/transformers/models/zamba/configuration_zamba.py index 6ac44a5419bb..5432da30b90f 100644 --- a/src/transformers/models/zamba/configuration_zamba.py +++ b/src/transformers/models/zamba/configuration_zamba.py @@ -53,6 +53,7 @@ class ZambaConfig(PreTrainedConfig): model_type = "zamba" keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = {"layer_types": "layers_block_type", "head_dim": "attention_head_dim"} vocab_size: int = 32000 tie_word_embeddings: bool = True diff --git a/src/transformers/models/zamba/modeling_zamba.py b/src/transformers/models/zamba/modeling_zamba.py index c985236ff0f7..3e891f9a3baf 100644 --- a/src/transformers/models/zamba/modeling_zamba.py +++ b/src/transformers/models/zamba/modeling_zamba.py @@ -20,7 +20,6 @@ import math from collections.abc import Callable -from typing import Any import torch from torch import nn @@ -28,7 +27,7 @@ from ... import initialization as init from ...activations import ACT2FN -from ...cache_utils import Cache +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...integrations.hub_kernels import lazy_load_kernel from ...masking_utils import create_causal_mask @@ -80,107 +79,6 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) -class ZambaHybridDynamicCache: - """ - A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache - (which has a constant shape regardless of seq_len). - - This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states` - and `ssm_states` for mamba cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor - For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`, - while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors). - For mamba layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors), - while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`, - and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`. 
- """ - - is_compileable = False - - def __init__(self, config, batch_size, dtype=torch.float16, device=None): - self.dtype = dtype - self.is_compileable = False - self.layers_block_type = config.layers_block_type - self.has_previous_state = False # only used by mamba - self.intermediate_size = config.mamba_expand * config.hidden_size - self.ssm_state_size = config.mamba_d_state - self.conv_kernel_size = config.mamba_d_conv - self.n_mamba_heads = config.n_mamba_heads - self.conv_states = [] - self.ssm_states = [] - self.transformer_layers = [] - self._modules = {} - self._parameters = {} - self._buffers = {} - for i in range(config.num_hidden_layers): - self.conv_states += [ - torch.zeros(batch_size, self.intermediate_size, self.conv_kernel_size, device=device, dtype=dtype) - ] - cache_shape = ( - batch_size, - self.n_mamba_heads, - self.intermediate_size // self.n_mamba_heads, - self.ssm_state_size, - ) - self.ssm_states += [torch.zeros(cache_shape, device=device, dtype=dtype)] - if self.layers_block_type[i] == "hybrid": - self.transformer_layers.append(i) - - self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] - self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] - - def __len__(self): - return len(self.key_cache) - - # Copied from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache.update - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: dict[str, Any] | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - # Update the cache - if self.key_cache[layer_idx].shape[-1] == 0: - self.key_cache[layer_idx] = key_states - self.value_cache[layer_idx] = value_states - else: - self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2) - self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2) - - return self.key_cache[layer_idx], self.value_cache[layer_idx] - - # Copied from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache.reorder_cache - def reorder_cache(self, beam_idx: torch.LongTensor): - """Reorders the cache for beam search, given the selected beam indices.""" - if self.get_seq_length() > 0: - for layer_idx in range(len(self.key_cache)): - device = self.key_cache[layer_idx].device - self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) - device = self.value_cache[layer_idx].device - self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) - - device = self.conv_states[layer_idx].device - self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx.to(device)) - device = self.ssm_states[layer_idx].device - self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device)) - - # Copied from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache.get_seq_length - def get_seq_length(self, layer_idx: int | None = 0) -> int: - """Returns the sequence length of the cached states. 
A layer index can be optionally passed.""" - # take any layer that contains cache and not empty tensor - layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx - if len(self.key_cache) <= layer_idx or self.key_cache[layer_idx].shape[-1] == 0: - return 0 - return self.key_cache[layer_idx].shape[-2] - - # Copied from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache.get_mask_sizes - def get_mask_sizes(self, query_length: int, layer_idx: int) -> tuple[int, int]: - """Return the length and offset of the cache, used to generate the mask""" - kv_offset = 0 - kv_length = self.get_seq_length(layer_idx) + query_length - return kv_length, kv_offset - - def eager_attention_forward( module: nn.Module, query: torch.Tensor, @@ -243,7 +141,7 @@ def forward( hidden_states: torch.Tensor, layer_idx: int, attention_mask: torch.Tensor | None, - past_key_values: ZambaHybridDynamicCache | None = None, + past_key_values: Cache | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]: input_shape = hidden_states.shape[:-1] @@ -369,7 +267,7 @@ def __init__(self, config: ZambaConfig, layer_idx): ) def cuda_kernels_forward( - self, hidden_states: torch.Tensor, cache_params: ZambaHybridDynamicCache = None, attention_mask=None + self, hidden_states: torch.Tensor, cache_params: Cache | None = None, attention_mask=None ): batch_size, seq_len, _ = hidden_states.shape use_precomputed_states = cache_params is not None and cache_params.has_previous_state and seq_len == 1 @@ -387,7 +285,7 @@ def cuda_kernels_forward( if use_precomputed_states: hidden_states = causal_conv1d_update( hidden_states.squeeze(-1), - cache_params.conv_states[self.layer_idx], + cache_params.layers[self.layer_idx].conv_states, conv_weights, self.conv1d.bias, self.activation, @@ -398,7 +296,7 @@ def cuda_kernels_forward( hidden_states = hidden_states * attention_mask.unsqueeze(1) if cache_params is not None: conv_states = nn.functional.pad(hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0)) - cache_params.conv_states[self.layer_idx].copy_(conv_states) + conv_states = cache_params.update_conv_state(conv_states, self.layer_idx) hidden_states = causal_conv1d_fn(hidden_states, conv_weights, self.conv1d.bias, activation=self.activation) if attention_mask is not None and not torch.all(attention_mask == 1): hidden_states = hidden_states * attention_mask.unsqueeze(1) @@ -424,7 +322,7 @@ def cuda_kernels_forward( if use_precomputed_states: for n in range(self.n_mamba_heads): scan_outputs_ = selective_state_update( - cache_params.ssm_states[self.layer_idx][:, n], + cache_params.layers[self.layer_idx].recurrent_states[:, n], hidden_states[n, ..., 0], discrete_time_step[n, ..., 0], A[n], @@ -459,13 +357,13 @@ def cuda_kernels_forward( scan_outputs = torch.cat((scan_outputs, scan_outputs_), dim=1).contiguous() ssm_state = torch.cat((ssm_state, ssm_state_.unsqueeze(1)), dim=1) if ssm_state is not None and cache_params is not None: - cache_params.ssm_states[self.layer_idx].copy_(ssm_state) + cache_params.update_recurrent_state(ssm_state, self.layer_idx) # 4. 
Final linear projection contextualized_states = self.out_proj(scan_outputs.transpose(1, 2)) return contextualized_states - def slow_forward(self, input_states, cache_params: ZambaHybridDynamicCache = None, attention_mask=None): + def slow_forward(self, input_states, cache_params: Cache | None = None, attention_mask=None): batch_size, seq_len, _ = input_states.shape dtype = input_states.dtype # 1. Gated linear projection @@ -476,26 +374,20 @@ def slow_forward(self, input_states, cache_params: ZambaHybridDynamicCache = Non gate = gate.squeeze(2) gate = gate.reshape(batch_size, self.n_mamba_heads, -1, seq_len).transpose(0, 1) - use_cache = isinstance(cache_params, ZambaHybridDynamicCache) + if cache_params is not None and cache_params.has_previous_state(self.layer_idx): + # In training mode, we don't want to perform in-place operations on ssm_state so we can compute the backwards pass + ssm_state = cache_params.layers[self.layer_idx].recurrent_states.clone() + else: + ssm_state = torch.zeros( + (batch_size, self.n_mamba_heads, self.mamba_head_dim, self.ssm_state_size), + device=hidden_states.device, + dtype=dtype, + ) + # 2. Convolution sequence transformation - if use_cache and cache_params.ssm_states[self.layer_idx].shape[0] == batch_size: - if self.training: - # In training mode, we don't want to perform in-place operations on ssm_state so we can compute the backwards pass - ssm_state = cache_params.ssm_states[self.layer_idx].clone() - else: - ssm_state = cache_params.ssm_states[self.layer_idx] - - ssm_state = ssm_state.to(hidden_states.device) - - if ( - cache_params.has_previous_state - and seq_len == 1 - and cache_params.conv_states[self.layer_idx].shape[0] == batch_size - ): - conv_state = cache_params.conv_states[self.layer_idx] - conv_state = torch.roll(conv_state, shifts=-1, dims=-1) - conv_state[:, :, -1] = hidden_states[:, :, 0] - cache_params.conv_states[self.layer_idx] = conv_state + if cache_params is not None: + if cache_params.has_previous_state(self.layer_idx) and seq_len == 1: + conv_state = cache_params.update_conv_state(hidden_states, self.layer_idx) hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1) if self.use_conv_bias: hidden_states += self.conv1d.bias @@ -504,16 +396,11 @@ def slow_forward(self, input_states, cache_params: ZambaHybridDynamicCache = Non if attention_mask is not None: hidden_states = hidden_states * attention_mask[:, -hidden_states.shape[-1] :].unsqueeze(1) conv_state = nn.functional.pad(hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0)) - cache_params.conv_states[self.layer_idx] = conv_state + conv_state = cache_params.update_conv_state(conv_state, self.layer_idx) hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) if attention_mask is not None: hidden_states = hidden_states * attention_mask[:, -hidden_states.shape[-1] :].unsqueeze(1) else: - ssm_state = torch.zeros( - (batch_size, self.n_mamba_heads, self.mamba_head_dim, self.ssm_state_size), - device=hidden_states.device, - dtype=dtype, - ) if attention_mask is not None: hidden_states = hidden_states * attention_mask.unsqueeze(1) hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) @@ -549,8 +436,8 @@ def slow_forward(self, input_states, cache_params: ZambaHybridDynamicCache = Non scan_output = scan_output + (hidden_states * self.D[:, None, :, None]) scan_output = scan_output * self.act(gate) - if use_cache: - cache_params.ssm_states[self.layer_idx] = ssm_state + if cache_params is not None: + 
cache_params.update_recurrent_state(ssm_state, self.layer_idx) # 4. Final linear projection contextualized_states = self.out_proj( @@ -558,7 +445,7 @@ def slow_forward(self, input_states, cache_params: ZambaHybridDynamicCache = Non ) return contextualized_states - def forward(self, hidden_states, cache_params: ZambaHybridDynamicCache = None, attention_mask=None, **kwargs): + def forward(self, hidden_states, cache_params: Cache | None = None, attention_mask=None, **kwargs): is_fast_path_available = all( (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn) ) @@ -606,7 +493,7 @@ def forward( original_hidden_states: torch.Tensor, layer_idx: int, attention_mask: torch.Tensor | None = None, - past_key_values: ZambaHybridDynamicCache | None = None, + past_key_values: Cache | None = None, use_cache: bool | None = False, **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]: @@ -620,7 +507,7 @@ def forward( layer_idx (`int`): layer_idx in the forward pass. Used to distinguish Zamba's tied transformer layers. attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch, sequence_length)` where padding elements are indicated by 0. - past_key_values (`ZambaHybridDynamicCache`, *optional*): cached past key and value projection states + past_key_values (`Cache`, *optional*): cached past key and value projection states use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). @@ -656,7 +543,7 @@ def forward( layer_idx: int | None = None, attention_mask: torch.Tensor | None = None, causal_mask: torch.Tensor | None = None, - past_key_values: ZambaHybridDynamicCache | None = None, + past_key_values: Cache | None = None, use_cache: bool | None = False, position_ids: torch.LongTensor | None = None, transformer_hidden_states: torch.Tensor | None = None, @@ -667,7 +554,7 @@ def forward( hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch, sequence_length)` where padding elements are indicated by 0. - past_key_values (`ZambaHybridDynamicCache`, *optional*): cached past key and value projection states + past_key_values (`Cache`, *optional*): cached past key and value projection states use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). @@ -708,7 +595,7 @@ def forward( layer_idx: int | None = None, attention_mask: torch.Tensor | None = None, causal_mask: torch.Tensor | None = None, - past_key_values: ZambaHybridDynamicCache | None = None, + past_key_values: Cache | None = None, use_cache: bool | None = False, **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]: @@ -720,7 +607,7 @@ def forward( layer_idx (`int`): layer number. attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch, sequence_length)` where padding elements are indicated by 0. 
- past_key_values (`ZambaHybridDynamicCache`, *optional*): cached past key and value projection states + past_key_values (`Cache`, *optional*): cached past key and value projection states use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). @@ -754,11 +641,10 @@ class ZambaPreTrainedModel(PreTrainedModel): config: ZambaConfig base_model_prefix = "model" supports_gradient_checkpointing = True - _no_split_modules = ["ZambaAttentionDecoderLayer", "ZambaMambaDecoderLayer"] + _no_split_modules = ["ZambaHybridLayer", "ZambaMambaDecoderLayer"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn = False _supports_sdpa = False - # Note: only supports ZambaHybridDynamicCache _is_stateful = True _can_record_outputs = { "hidden_states": ZambaMambaDecoderLayer, @@ -835,7 +721,7 @@ def forward( input_ids: torch.LongTensor | None = None, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, - past_key_values: ZambaHybridDynamicCache | None = None, + past_key_values: Cache | None = None, inputs_embeds: torch.FloatTensor | None = None, use_cache: bool | None = None, **kwargs: Unpack[TransformersKwargs], @@ -854,10 +740,7 @@ def forward( # original_hidden_states: word embedding output that will be concatenated with hidden activations to form the input of the shared transformer layer if use_cache and past_key_values is None: - logger.warning_once( - "Zamba requires an initialized `ZambaHybridDynamicCache` to return a cache. None was " - "provided, so no cache will be returned." - ) + past_key_values = DynamicCache(config=self.config) if position_ids is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 @@ -886,9 +769,6 @@ def forward( hidden_states = self.final_layernorm(hidden_states) - if past_key_values and not past_key_values.has_previous_state: - past_key_values.has_previous_state = True - return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values if use_cache else None, @@ -915,7 +795,7 @@ def forward( input_ids: torch.LongTensor | None = None, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, - past_key_values: ZambaHybridDynamicCache | None = None, + past_key_values: Cache | None = None, inputs_embeds: torch.FloatTensor | None = None, labels: torch.LongTensor | None = None, use_cache: bool | None = None, @@ -987,13 +867,6 @@ def prepare_inputs_for_generation( is_first_iteration=False, **kwargs, ): - # Overwritten -- has a unique cache type, `ZambaHybridDynamicCache` - - if past_key_values is None: - past_key_values = ZambaHybridDynamicCache( - self.config, input_ids.shape[0], dtype=self.dtype, device=self.device - ) - kwargs["logits_to_keep"] = self.config.num_logits_to_keep model_inputs = super().prepare_inputs_for_generation( input_ids, diff --git a/src/transformers/models/zamba2/configuration_zamba2.py b/src/transformers/models/zamba2/configuration_zamba2.py index a9888babe74d..6d1af578e087 100644 --- a/src/transformers/models/zamba2/configuration_zamba2.py +++ b/src/transformers/models/zamba2/configuration_zamba2.py @@ -66,7 +66,7 @@ class Zamba2Config(PreTrainedConfig): ```""" model_type = "zamba2" - attribute_map = {"head_dim": "attention_head_dim"} + attribute_map = {"layer_types": "layers_block_type", "head_dim": "attention_head_dim"} keys_to_ignore_at_inference = ["past_key_values"] vocab_size: int = 32000 diff 
--git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index 775d0d45d009..6e4ea7dcf2d8 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -21,13 +21,14 @@ import math from collections.abc import Callable from itertools import cycle -from typing import Any, Optional +from typing import Optional import torch from torch import nn from ... import initialization as init from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...integrations import use_kernel_func_from_hub from ...integrations.hub_kernels import lazy_load_kernel @@ -88,107 +89,6 @@ def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" -class Zamba2HybridDynamicCache: - """ - A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache - (which has a constant shape regardless of seq_len). - - This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states` - and `ssm_states` for mamba cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor - For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`, - while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors). - For mamba layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors), - while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`, - and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`. 
- """ - - is_compileable = False - - def __init__( - self, config: Zamba2Config, batch_size: int, dtype: torch.dtype = torch.float16, device: str | None = None - ): - self.dtype = dtype - self.layers_block_type = config.layers_block_type - self.has_previous_state = False - self.intermediate_size = int(config.mamba_expand * config.hidden_size) - self.ssm_state_size = config.mamba_d_state - self.conv_kernel_size = config.mamba_d_conv - self.n_mamba_heads = config.n_mamba_heads - self.transformer_layers = [] - self._modules = {} - self._parameters = {} - self._buffers = {} - self.conv_states = {} - self.ssm_states = {} - for i in range(config.num_hidden_layers): - self.conv_states[i] = torch.zeros( - batch_size, - self.intermediate_size + 2 * config.mamba_ngroups * config.mamba_d_state, - self.conv_kernel_size, - device=device, - dtype=dtype, - ) - self.ssm_states[i] = torch.zeros( - batch_size, self.n_mamba_heads, config.mamba_headdim, self.ssm_state_size, device=device, dtype=dtype - ) - if self.layers_block_type[i] == "hybrid": - self.transformer_layers.append(i) - self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] - self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] - - def __len__(self): - return len(self.key_cache) - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: dict[str, Any] | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - # Update the cache - if self.key_cache[layer_idx].shape[-1] == 0: - self.key_cache[layer_idx] = key_states - self.value_cache[layer_idx] = value_states - else: - self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2) - self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2) - - return self.key_cache[layer_idx], self.value_cache[layer_idx] - - def reorder_cache(self, beam_idx: torch.LongTensor): - """Reorders the cache for beam search, given the selected beam indices.""" - if self.get_seq_length() > 0: - for layer_idx in range(len(self.key_cache)): - device = self.key_cache[layer_idx].device - self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) - device = self.value_cache[layer_idx].device - self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) - - device = self.conv_states[layer_idx].device - self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx.to(device)) - device = self.ssm_states[layer_idx].device - self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device)) - - def get_seq_length(self, layer_idx: int | None = 0) -> int: - """Returns the sequence length of the cached states. 
A layer index can be optionally passed.""" - # take any layer that contains cache and not empty tensor - layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx - if len(self.key_cache) <= layer_idx or self.key_cache[layer_idx].numel() == 0: - return 0 - return self.key_cache[layer_idx].shape[-2] - - def get_mask_sizes(self, query_length: int, layer_idx: int) -> tuple[int, int]: - """Return the length and offset of the cache, used to generate the mask""" - kv_offset = 0 - kv_length = self.get_seq_length(layer_idx) + query_length - return kv_length, kv_offset - - def reset(self): - self.conv_states.zero_() - self.ssm_states.zero_() - - class Zamba2RotaryEmbedding(nn.Module): inv_freq: torch.Tensor # fix linting for `register_buffer` @@ -401,7 +301,7 @@ def forward( hidden_states: torch.Tensor, layer_idx: int, attention_mask: torch.Tensor | None = None, - past_key_values: Zamba2HybridDynamicCache | None = None, + past_key_values: Cache | None = None, position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]: @@ -604,7 +504,7 @@ def __init__(self, config: Zamba2Config, layer_idx: int | None = None): def cuda_kernels_forward( self, hidden_states: torch.Tensor, - cache_params: Zamba2HybridDynamicCache | None = None, + cache_params: Cache | None = None, attention_mask: torch.Tensor | None = None, ): # set up dimensions for reshapes later @@ -614,7 +514,7 @@ def cuda_kernels_forward( d_to_remove = 2 * self.intermediate_size + 2 * self.n_groups * self.ssm_state_size + self.num_heads # getting projected states from cache if it exists - if cache_params is not None and cache_params.has_previous_state: + if cache_params is not None and cache_params.has_previous_state(self.layer_idx): in_projected_states = self.in_proj(hidden_states.squeeze(1)) # (B 2D) d_mlp = (in_projected_states.shape[-1] - d_to_remove) // 2 split_projection_dim = [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads] @@ -622,7 +522,7 @@ def cuda_kernels_forward( hidden_states_B_C = causal_conv1d_update( hidden_states_B_C, - cache_params.conv_states[self.layer_idx], + cache_params.layers[self.layer_idx].conv_states, self.conv1d.weight.squeeze(1), self.conv1d.bias, self.activation, @@ -643,7 +543,7 @@ def cuda_kernels_forward( C = C.view(batch_size, self.n_groups, C.shape[1] // self.n_groups) hidden_states_reshaped = hidden_states.view(batch_size, self.num_heads, self.head_dim) hidden_states = selective_state_update( - cache_params.ssm_states[self.layer_idx], + cache_params.layers[self.layer_idx].recurrent_states, hidden_states_reshaped, dt, A, @@ -707,7 +607,7 @@ def cuda_kernels_forward( conv_state = nn.functional.pad( hidden_states_B_C_t, (self.conv_kernel_size - hidden_states_B_C_t.shape[-1], 0) ) - cache_params.conv_states[self.layer_idx].copy_(conv_state) + conv_state = cache_params.update_conv_state(conv_state, self.layer_idx) if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]: hidden_states_B_C = self.act( self.conv1d(hidden_states_B_C.transpose(1, 2)).transpose(1, 2)[:, :seq_len] @@ -744,7 +644,7 @@ def cuda_kernels_forward( **dt_limit_kwargs, ) if ssm_state is not None and cache_params is not None: - cache_params.ssm_states[self.layer_idx].copy_(ssm_state) + cache_params.update_recurrent_state(ssm_state, self.layer_idx) scan_output = scan_output.view(batch_size, seq_len, -1) # Multiply "gate" branch and apply extra 
normalization layer scan_output = self.norm(scan_output, gate) @@ -752,12 +652,12 @@ def cuda_kernels_forward( return out # fmt: off - def torch_forward(self, input_states, cache_params: Zamba2HybridDynamicCache | None=None, attention_mask: torch.Tensor | None=None): + def torch_forward(self, input_states, cache_params: Cache | None=None, attention_mask: torch.Tensor | None = None): batch_size, seq_len, _ = input_states.shape dtype = input_states.dtype # Gated MLP's linear projection - if cache_params is not None and cache_params.has_previous_state: - projected_states = self.in_proj(input_states.squeeze(1)) + if cache_params is not None and cache_params.has_previous_state(self.layer_idx): + projected_states = self.in_proj(input_states) else: if attention_mask is not None: # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66 @@ -767,43 +667,34 @@ def torch_forward(self, input_states, cache_params: Zamba2HybridDynamicCache | N _, _, gate, hidden_states, dt = projected_states.split( [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1 ) + hidden_states = hidden_states.transpose(1, 2) + + use_precomputed_state = cache_params is not None and cache_params.has_previous_state(self.layer_idx) # Convolution sequence transformation - if cache_params is not None: - ssm_state = cache_params.ssm_states[self.layer_idx].clone() - ssm_state = ssm_state.to(hidden_states.device) - if cache_params.has_previous_state: - gate = gate.unsqueeze(1) - conv_state = cache_params.conv_states[self.layer_idx] # [batch, intermediate_size, conv_kernel_size] - conv_state = torch.roll(conv_state, shifts=-1, dims=-1) - # handle batched generation - states are copied through - conv_state[:, :, -1] = hidden_states[:, 0, :] if hidden_states.ndim == 3 else hidden_states - cache_params.conv_states[self.layer_idx].copy_(conv_state) - hidden_states = torch.sum(conv_state.to(projected_states.device) * self.conv1d.weight[:, 0, :], dim=-1) - if self.use_conv_bias: - hidden_states += self.conv1d.bias - hidden_states = self.act(hidden_states).to(dtype)[:, None, ...] # [batch, 1, intermediate_size] : decoding - else: - hidden_states = hidden_states.transpose(1,2) + if use_precomputed_state: + conv_state = cache_params.update_conv_state(hidden_states, self.layer_idx) + hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1) + if self.use_conv_bias: + hidden_states += self.conv1d.bias + hidden_states = self.act(hidden_states).to(dtype)[:, None, ...] 
# [batch, 1, intermediate_size] : decoding + else: + if cache_params is not None: conv_state = nn.functional.pad( hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0) ) - cache_params.conv_states[self.layer_idx].copy_(conv_state) - hidden_states = self.act(self.conv1d(hidden_states).transpose(1,2))[:, :seq_len, :] # [batch, intermediate_size, seq_len] - if attention_mask is not None: - dtype = hidden_states.dtype - # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66 - hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) - else: - ssm_state = torch.zeros( - (batch_size, self.num_heads, self.head_dim, self.ssm_state_size), - device=hidden_states.device, dtype=dtype - ) - hidden_states = self.act(self.conv1d(hidden_states.transpose(1, 2))[..., :seq_len].transpose(1, 2)) + conv_state = cache_params.update_conv_state(conv_state, self.layer_idx) + + hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len].transpose(1, 2)) + if attention_mask is not None: + dtype = hidden_states.dtype + # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66 + hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) + hidden_states, B, C = torch.split(hidden_states, [self.intermediate_size, self.n_groups * self.ssm_state_size, self.n_groups * self.ssm_state_size], dim=-1) A = -torch.exp(self.A_log.float()) # [num_heads] - if cache_params is not None and cache_params.has_previous_state: + if use_precomputed_state: # Note: there is no need to pad parameter matrices here, as there is just one new token # for batched generation dt = dt[:, None, ...] if dt.ndim == 2 else dt[:, 0, :][:, None, ...] @@ -832,9 +723,9 @@ def torch_forward(self, input_states, cache_params: Zamba2HybridDynamicCache | N dBx = dB * hidden_states[..., None] # State calculation - cache_params.ssm_states[self.layer_idx].copy_( - cache_params.ssm_states[self.layer_idx] * dA + dBx - ) + ssm_states = cache_params.layers[self.layer_idx].recurrent_states.clone() + ssm_states = ssm_states * dA + dBx + ssm_states = cache_params.update_recurrent_state(ssm_states, self.layer_idx) # Subsequent output # [bsz, n_groups * state_size] -> [bsz, num_heads, state_size] @@ -843,7 +734,7 @@ def torch_forward(self, input_states, cache_params: Zamba2HybridDynamicCache | N C = C.reshape(batch_size, -1, C.shape[-1]) # [bsz, num_heads, head_dim] - ssm_states = cache_params.ssm_states[self.layer_idx].to(C.dtype) # Shape: [b, h, d, n] + ssm_states = ssm_states.to(C.dtype) # Shape: [b, h, d, n] # Reshape ssm_states to merge the first two dimensions ssm_states_reshaped = ssm_states.view(batch_size * self.num_heads, self.head_dim, self.ssm_state_size) # Shape: [b*h, d, n] C_reshaped = C.view(batch_size * self.num_heads, self.ssm_state_size, 1) # Shape: [b*h, n, 1] @@ -904,10 +795,7 @@ def torch_forward(self, input_states, cache_params: Zamba2HybridDynamicCache | N B_decay_contraction = B * decay_states.permute(0, 2, 3, 1)[..., None] # permute back B * decay states states = (B_decay_contraction.permute(0, 1, 3, 2, 4)[..., None] * hidden_states.permute(0, 1, 3, 2, 4)[..., None, :]).sum(dim=3).permute(0, 1, 2, 4, 3) - if cache_params is not None and cache_params.has_previous_state: - previous_states = cache_params.ssm_states[self.layer_idx][:, None, ...] 
- else: - previous_states = torch.zeros_like(states[:, :1]) + previous_states = torch.zeros_like(states[:, :1]) states = torch.cat([previous_states, states], dim=1) decay_chunk = torch.exp(segment_sum(nn.functional.pad(A_cumsum[:, :, :, -1], (1, 0)))) @@ -935,7 +823,7 @@ def torch_forward(self, input_states, cache_params: Zamba2HybridDynamicCache | N y = y[:, :seq_len, :, :] y = y.reshape(batch_size, seq_len, -1) if ssm_state is not None and cache_params is not None: - cache_params.ssm_states[self.layer_idx].copy_(ssm_state) + cache_params.update_recurrent_state(ssm_state, self.layer_idx) scan_output = self.norm(y, gate) @@ -949,7 +837,7 @@ def torch_forward(self, input_states, cache_params: Zamba2HybridDynamicCache | N def forward( self, hidden_states, - cache_params: Zamba2HybridDynamicCache | None = None, + cache_params: Cache | None = None, attention_mask: torch.Tensor | None = None, **kwargs, ): @@ -1017,7 +905,7 @@ def forward( original_hidden_states: torch.Tensor, layer_idx: int, attention_mask: torch.Tensor | None = None, - past_key_values: Zamba2HybridDynamicCache | None = None, + past_key_values: Cache | None = None, position_embeddings: torch.LongTensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.FloatTensor]: @@ -1030,7 +918,7 @@ def forward( (see fig. 2 in https://huggingface.co/papers/2405.16712). attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch, sequence_length)` where padding elements are indicated by 0. - past_key_values (`Zamba2HybridDynamicCache`, *optional*): cached past key and value projection states + past_key_values (`Cache`, *optional*): cached past key and value projection states use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). @@ -1069,7 +957,7 @@ def forward( layer_idx: int | None = None, attention_mask: torch.Tensor | None = None, causal_mask: torch.Tensor | None = None, - past_key_values: Zamba2HybridDynamicCache | None = None, + past_key_values: Cache | None = None, use_cache: bool | None = False, position_ids: torch.LongTensor | None = None, transformer_hidden_states: torch.Tensor | None = None, @@ -1080,7 +968,7 @@ def forward( hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch, sequence_length)` where padding elements are indicated by 0. - past_key_values (`Zamba2HybridDynamicCache`, *optional*): cached past key and value projection states + past_key_values (`Cache`, *optional*): cached past key and value projection states use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). @@ -1123,7 +1011,7 @@ def forward( layer_idx: int | None = None, attention_mask: torch.Tensor | None = None, causal_mask: torch.Tensor | None = None, - past_key_values: Zamba2HybridDynamicCache | None = None, + past_key_values: Cache | None = None, use_cache: bool | None = False, position_embeddings: torch.LongTensor | None = None, position_ids: torch.LongTensor | None = None, @@ -1137,7 +1025,7 @@ def forward( layer_idx (`int`): layer number. attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch, sequence_length)` where padding elements are indicated by 0. 
- past_key_values (`Zamba2HybridDynamicCache`, *optional*): cached past key and value projection states + past_key_values (`Cache`, *optional*): cached past key and value projection states use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). @@ -1176,7 +1064,7 @@ class Zamba2PreTrainedModel(PreTrainedModel): config: Zamba2Config base_model_prefix = "model" supports_gradient_checkpointing = True - _no_split_modules = ["Zamba2AttentionDecoderLayer", "Zamba2MambaDecoderLayer"] + _no_split_modules = ["Zamba2HybridLayer", "Zamba2MambaDecoderLayer"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn = True _supports_flex_attn = True @@ -1245,7 +1133,7 @@ def forward( input_ids: torch.LongTensor | None = None, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, - past_key_values: Zamba2HybridDynamicCache | None = None, + past_key_values: Cache | None = None, inputs_embeds: torch.FloatTensor | None = None, use_cache: bool | None = None, **kwargs: Unpack[TransformersKwargs], @@ -1264,8 +1152,7 @@ def forward( # original_hidden_states: word embedding output that will be concatenated with hidden activations to form the input of the shared transformer layer if use_cache and past_key_values is None: - batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0] - past_key_values = Zamba2HybridDynamicCache(self.config, batch_size, dtype=self.dtype, device=self.device) + past_key_values = DynamicCache(config=self.config) if position_ids is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 @@ -1302,9 +1189,6 @@ def forward( hidden_states = self.final_layernorm(hidden_states) - if past_key_values is not None and not past_key_values.has_previous_state: - past_key_values.has_previous_state = True - return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values if use_cache else None, @@ -1365,7 +1249,7 @@ def forward( input_ids: torch.LongTensor | None = None, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, - past_key_values: Zamba2HybridDynamicCache | None = None, + past_key_values: Cache | None = None, inputs_embeds: torch.FloatTensor | None = None, labels: torch.LongTensor | None = None, use_cache: bool | None = None, @@ -1437,13 +1321,6 @@ def prepare_inputs_for_generation( is_first_iteration=False, **kwargs, ): - # Overwritten -- has a unique cache type, `Zamba2HybridDynamicCache` - - if past_key_values is None: - past_key_values = Zamba2HybridDynamicCache( - self.config, input_ids.shape[0], dtype=self.dtype, device=self.device - ) - kwargs["logits_to_keep"] = self.config.num_logits_to_keep model_inputs = super().prepare_inputs_for_generation( input_ids, @@ -1490,7 +1367,7 @@ def forward( input_ids: torch.LongTensor | None = None, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, - past_key_values: Zamba2HybridDynamicCache | None = None, + past_key_values: Cache | None = None, inputs_embeds: torch.FloatTensor | None = None, labels: torch.LongTensor | None = None, use_cache: bool | None = None, diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py index 2d57dc94046c..d7716301ad4a 100644 --- a/src/transformers/models/zamba2/modular_zamba2.py +++ b/src/transformers/models/zamba2/modular_zamba2.py @@ -21,6 
+21,7 @@ from ... import initialization as init from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache from ...integrations.hub_kernels import lazy_load_kernel from ...masking_utils import create_causal_mask from ...modeling_outputs import BaseModelOutputWithPast, SequenceClassifierOutputWithPast @@ -37,7 +38,6 @@ ZambaAttentionDecoderLayer, ZambaForCausalLM, ZambaForSequenceClassification, - ZambaHybridDynamicCache, ZambaHybridLayer, ZambaMambaDecoderLayer, ZambaModel, @@ -77,71 +77,6 @@ class Zamba2RMSNorm(ZambaRMSNorm): pass -class Zamba2HybridDynamicCache(ZambaHybridDynamicCache): - """ - A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache - (which has a constant shape regardless of seq_len). - - This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states` - and `ssm_states` for mamba cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor - For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`, - while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors). - For mamba layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors), - while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`, - and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`. - """ - - def __init__( - self, config: Zamba2Config, batch_size: int, dtype: torch.dtype = torch.float16, device: str | None = None - ): - self.dtype = dtype - self.layers_block_type = config.layers_block_type - self.has_previous_state = False - self.intermediate_size = int(config.mamba_expand * config.hidden_size) - self.ssm_state_size = config.mamba_d_state - self.conv_kernel_size = config.mamba_d_conv - self.n_mamba_heads = config.n_mamba_heads - self.transformer_layers = [] - self._modules = {} - self._parameters = {} - self._buffers = {} - self.conv_states = {} - self.ssm_states = {} - for i in range(config.num_hidden_layers): - self.conv_states[i] = torch.zeros( - batch_size, - self.intermediate_size + 2 * config.mamba_ngroups * config.mamba_d_state, - self.conv_kernel_size, - device=device, - dtype=dtype, - ) - self.ssm_states[i] = torch.zeros( - batch_size, self.n_mamba_heads, config.mamba_headdim, self.ssm_state_size, device=device, dtype=dtype - ) - if self.layers_block_type[i] == "hybrid": - self.transformer_layers.append(i) - self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] - self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] - - def reset(self): - self.conv_states.zero_() - self.ssm_states.zero_() - - def get_seq_length(self, layer_idx: int | None = 0) -> int: - """Returns the sequence length of the cached states. 
A layer index can be optionally passed.""" - # take any layer that contains cache and not empty tensor - layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx - if len(self.key_cache) <= layer_idx or self.key_cache[layer_idx].numel() == 0: - return 0 - return self.key_cache[layer_idx].shape[-2] - - def get_mask_sizes(self, query_length: int, layer_idx: int) -> tuple[int, int]: - """Return the length and offset of the cache, used to generate the mask""" - kv_offset = 0 - kv_length = self.get_seq_length(layer_idx) + query_length - return kv_length, kv_offset - - class Zamba2RotaryEmbedding(LlamaRotaryEmbedding): pass @@ -208,7 +143,7 @@ def forward( hidden_states: torch.Tensor, layer_idx: int, attention_mask: torch.Tensor | None = None, - past_key_values: Zamba2HybridDynamicCache | None = None, + past_key_values: Cache | None = None, position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]: @@ -357,7 +292,7 @@ def __init__(self, config: Zamba2Config, layer_idx: int | None = None): def cuda_kernels_forward( self, hidden_states: torch.Tensor, - cache_params: Zamba2HybridDynamicCache | None = None, + cache_params: Cache | None = None, attention_mask: torch.Tensor | None = None, ): # set up dimensions for reshapes later @@ -367,7 +302,7 @@ def cuda_kernels_forward( d_to_remove = 2 * self.intermediate_size + 2 * self.n_groups * self.ssm_state_size + self.num_heads # getting projected states from cache if it exists - if cache_params is not None and cache_params.has_previous_state: + if cache_params is not None and cache_params.has_previous_state(self.layer_idx): in_projected_states = self.in_proj(hidden_states.squeeze(1)) # (B 2D) d_mlp = (in_projected_states.shape[-1] - d_to_remove) // 2 split_projection_dim = [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads] @@ -375,7 +310,7 @@ def cuda_kernels_forward( hidden_states_B_C = causal_conv1d_update( hidden_states_B_C, - cache_params.conv_states[self.layer_idx], + cache_params.layers[self.layer_idx].conv_states, self.conv1d.weight.squeeze(1), self.conv1d.bias, self.activation, @@ -396,7 +331,7 @@ def cuda_kernels_forward( C = C.view(batch_size, self.n_groups, C.shape[1] // self.n_groups) hidden_states_reshaped = hidden_states.view(batch_size, self.num_heads, self.head_dim) hidden_states = selective_state_update( - cache_params.ssm_states[self.layer_idx], + cache_params.layers[self.layer_idx].recurrent_states, hidden_states_reshaped, dt, A, @@ -460,7 +395,7 @@ def cuda_kernels_forward( conv_state = nn.functional.pad( hidden_states_B_C_t, (self.conv_kernel_size - hidden_states_B_C_t.shape[-1], 0) ) - cache_params.conv_states[self.layer_idx].copy_(conv_state) + conv_state = cache_params.update_conv_state(conv_state, self.layer_idx) if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]: hidden_states_B_C = self.act( self.conv1d(hidden_states_B_C.transpose(1, 2)).transpose(1, 2)[:, :seq_len] @@ -497,7 +432,7 @@ def cuda_kernels_forward( **dt_limit_kwargs, ) if ssm_state is not None and cache_params is not None: - cache_params.ssm_states[self.layer_idx].copy_(ssm_state) + cache_params.update_recurrent_state(ssm_state, self.layer_idx) scan_output = scan_output.view(batch_size, seq_len, -1) # Multiply "gate" branch and apply extra normalization layer scan_output = self.norm(scan_output, gate) @@ -505,12 +440,12 @@ def cuda_kernels_forward( return out # 
fmt: off - def torch_forward(self, input_states, cache_params: Zamba2HybridDynamicCache | None=None, attention_mask: torch.Tensor | None=None): + def torch_forward(self, input_states, cache_params: Cache | None=None, attention_mask: torch.Tensor | None = None): batch_size, seq_len, _ = input_states.shape dtype = input_states.dtype # Gated MLP's linear projection - if cache_params is not None and cache_params.has_previous_state: - projected_states = self.in_proj(input_states.squeeze(1)) + if cache_params is not None and cache_params.has_previous_state(self.layer_idx): + projected_states = self.in_proj(input_states) else: if attention_mask is not None: # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66 @@ -520,43 +455,34 @@ def torch_forward(self, input_states, cache_params: Zamba2HybridDynamicCache | N _, _, gate, hidden_states, dt = projected_states.split( [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1 ) + hidden_states = hidden_states.transpose(1, 2) + + use_precomputed_state = cache_params is not None and cache_params.has_previous_state(self.layer_idx) # Convolution sequence transformation - if cache_params is not None: - ssm_state = cache_params.ssm_states[self.layer_idx].clone() - ssm_state = ssm_state.to(hidden_states.device) - if cache_params.has_previous_state: - gate = gate.unsqueeze(1) - conv_state = cache_params.conv_states[self.layer_idx] # [batch, intermediate_size, conv_kernel_size] - conv_state = torch.roll(conv_state, shifts=-1, dims=-1) - # handle batched generation - states are copied through - conv_state[:, :, -1] = hidden_states[:, 0, :] if hidden_states.ndim == 3 else hidden_states - cache_params.conv_states[self.layer_idx].copy_(conv_state) - hidden_states = torch.sum(conv_state.to(projected_states.device) * self.conv1d.weight[:, 0, :], dim=-1) - if self.use_conv_bias: - hidden_states += self.conv1d.bias - hidden_states = self.act(hidden_states).to(dtype)[:, None, ...] # [batch, 1, intermediate_size] : decoding - else: - hidden_states = hidden_states.transpose(1,2) + if use_precomputed_state: + conv_state = cache_params.update_conv_state(hidden_states, self.layer_idx) + hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1) + if self.use_conv_bias: + hidden_states += self.conv1d.bias + hidden_states = self.act(hidden_states).to(dtype)[:, None, ...] 
# [batch, 1, intermediate_size] : decoding + else: + if cache_params is not None: conv_state = nn.functional.pad( hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0) ) - cache_params.conv_states[self.layer_idx].copy_(conv_state) - hidden_states = self.act(self.conv1d(hidden_states).transpose(1,2))[:, :seq_len, :] # [batch, intermediate_size, seq_len] - if attention_mask is not None: - dtype = hidden_states.dtype - # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66 - hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) - else: - ssm_state = torch.zeros( - (batch_size, self.num_heads, self.head_dim, self.ssm_state_size), - device=hidden_states.device, dtype=dtype - ) - hidden_states = self.act(self.conv1d(hidden_states.transpose(1, 2))[..., :seq_len].transpose(1, 2)) + conv_state = cache_params.update_conv_state(conv_state, self.layer_idx) + + hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len].transpose(1, 2)) + if attention_mask is not None: + dtype = hidden_states.dtype + # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66 + hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) + hidden_states, B, C = torch.split(hidden_states, [self.intermediate_size, self.n_groups * self.ssm_state_size, self.n_groups * self.ssm_state_size], dim=-1) A = -torch.exp(self.A_log.float()) # [num_heads] - if cache_params is not None and cache_params.has_previous_state: + if use_precomputed_state: # Note: there is no need to pad parameter matrices here, as there is just one new token # for batched generation dt = dt[:, None, ...] if dt.ndim == 2 else dt[:, 0, :][:, None, ...] @@ -585,9 +511,9 @@ def torch_forward(self, input_states, cache_params: Zamba2HybridDynamicCache | N dBx = dB * hidden_states[..., None] # State calculation - cache_params.ssm_states[self.layer_idx].copy_( - cache_params.ssm_states[self.layer_idx] * dA + dBx - ) + ssm_states = cache_params.layers[self.layer_idx].recurrent_states.clone() + ssm_states = ssm_states * dA + dBx + ssm_states = cache_params.update_recurrent_state(ssm_states, self.layer_idx) # Subsequent output # [bsz, n_groups * state_size] -> [bsz, num_heads, state_size] @@ -596,7 +522,7 @@ def torch_forward(self, input_states, cache_params: Zamba2HybridDynamicCache | N C = C.reshape(batch_size, -1, C.shape[-1]) # [bsz, num_heads, head_dim] - ssm_states = cache_params.ssm_states[self.layer_idx].to(C.dtype) # Shape: [b, h, d, n] + ssm_states = ssm_states.to(C.dtype) # Shape: [b, h, d, n] # Reshape ssm_states to merge the first two dimensions ssm_states_reshaped = ssm_states.view(batch_size * self.num_heads, self.head_dim, self.ssm_state_size) # Shape: [b*h, d, n] C_reshaped = C.view(batch_size * self.num_heads, self.ssm_state_size, 1) # Shape: [b*h, n, 1] @@ -657,10 +583,7 @@ def torch_forward(self, input_states, cache_params: Zamba2HybridDynamicCache | N B_decay_contraction = B * decay_states.permute(0, 2, 3, 1)[..., None] # permute back B * decay states states = (B_decay_contraction.permute(0, 1, 3, 2, 4)[..., None] * hidden_states.permute(0, 1, 3, 2, 4)[..., None, :]).sum(dim=3).permute(0, 1, 2, 4, 3) - if cache_params is not None and cache_params.has_previous_state: - previous_states = cache_params.ssm_states[self.layer_idx][:, None, ...] 
- else: - previous_states = torch.zeros_like(states[:, :1]) + previous_states = torch.zeros_like(states[:, :1]) states = torch.cat([previous_states, states], dim=1) decay_chunk = torch.exp(segment_sum(nn.functional.pad(A_cumsum[:, :, :, -1], (1, 0)))) @@ -688,7 +611,7 @@ def torch_forward(self, input_states, cache_params: Zamba2HybridDynamicCache | N y = y[:, :seq_len, :, :] y = y.reshape(batch_size, seq_len, -1) if ssm_state is not None and cache_params is not None: - cache_params.ssm_states[self.layer_idx].copy_(ssm_state) + cache_params.update_recurrent_state(ssm_state, self.layer_idx) scan_output = self.norm(y, gate) @@ -702,7 +625,7 @@ def torch_forward(self, input_states, cache_params: Zamba2HybridDynamicCache | N def forward( self, hidden_states, - cache_params: Zamba2HybridDynamicCache | None = None, + cache_params: Cache | None = None, attention_mask: torch.Tensor | None = None, **kwargs, ): @@ -768,7 +691,7 @@ def forward( original_hidden_states: torch.Tensor, layer_idx: int, attention_mask: torch.Tensor | None = None, - past_key_values: Zamba2HybridDynamicCache | None = None, + past_key_values: Cache | None = None, position_embeddings: torch.LongTensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.FloatTensor]: @@ -781,7 +704,7 @@ def forward( (see fig. 2 in https://huggingface.co/papers/2405.16712). attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch, sequence_length)` where padding elements are indicated by 0. - past_key_values (`Zamba2HybridDynamicCache`, *optional*): cached past key and value projection states + past_key_values (`Cache`, *optional*): cached past key and value projection states use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). @@ -828,7 +751,7 @@ def forward( layer_idx: int | None = None, attention_mask: torch.Tensor | None = None, causal_mask: torch.Tensor | None = None, - past_key_values: Zamba2HybridDynamicCache | None = None, + past_key_values: Cache | None = None, use_cache: bool | None = False, position_embeddings: torch.LongTensor | None = None, position_ids: torch.LongTensor | None = None, @@ -842,7 +765,7 @@ def forward( layer_idx (`int`): layer number. attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch, sequence_length)` where padding elements are indicated by 0. - past_key_values (`Zamba2HybridDynamicCache`, *optional*): cached past key and value projection states + past_key_values (`Cache`, *optional*): cached past key and value projection states use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). 
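The Zamba2 hunks above fold the hand-rolled decode path into the shared cache API: the roll-and-overwrite window update that previously lived inline in `torch_forward` is now performed by `update_conv_state`, and the per-layer `has_previous_state(layer_idx)` query replaces the global flag. A minimal standalone sketch of that rolling window update (plain PyTorch, not the library implementation):

```python
import torch

# Sketch of the rolling conv-state update that `update_conv_state` centralizes:
# shift the window of the last `kernel` inputs left by one slot, then write the
# new token's features into the freed last slot.
def roll_conv_state(conv_state: torch.Tensor, new_token: torch.Tensor) -> torch.Tensor:
    # conv_state: [batch, conv_dim, kernel], new_token: [batch, conv_dim]
    conv_state = torch.roll(conv_state, shifts=-1, dims=-1)
    conv_state[:, :, -1] = new_token
    return conv_state

batch, conv_dim, kernel = 2, 8, 4
state = torch.zeros(batch, conv_dim, kernel)
for _ in range(6):  # decode six tokens one at a time
    state = roll_conv_state(state, torch.randn(batch, conv_dim))
# `state` now holds the last `kernel` tokens, oldest first, ready for the
# depthwise convolution, i.e. torch.sum(state * conv1d.weight[:, 0, :], dim=-1)
```

Because the window tensor never changes shape during decoding, the conv cache can be updated fully in place throughout generation.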
@@ -881,7 +804,7 @@ class Zamba2PreTrainedModel(PreTrainedModel): config: Zamba2Config base_model_prefix = "model" supports_gradient_checkpointing = True - _no_split_modules = ["Zamba2AttentionDecoderLayer", "Zamba2MambaDecoderLayer"] + _no_split_modules = ["Zamba2HybridLayer", "Zamba2MambaDecoderLayer"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn = True _supports_flex_attn = True @@ -983,7 +906,7 @@ def forward( input_ids: torch.LongTensor | None = None, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, - past_key_values: Zamba2HybridDynamicCache | None = None, + past_key_values: Cache | None = None, inputs_embeds: torch.FloatTensor | None = None, use_cache: bool | None = None, **kwargs: Unpack[TransformersKwargs], @@ -1002,8 +925,7 @@ def forward( # original_hidden_states: word embedding output that will be concatenated with hidden activations to form the input of the shared transformer layer if use_cache and past_key_values is None: - batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0] - past_key_values = Zamba2HybridDynamicCache(self.config, batch_size, dtype=self.dtype, device=self.device) + past_key_values = DynamicCache(config=self.config) if position_ids is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 @@ -1040,9 +962,6 @@ def forward( hidden_states = self.final_layernorm(hidden_states) - if past_key_values is not None and not past_key_values.has_previous_state: - past_key_values.has_previous_state = True - return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values if use_cache else None, @@ -1069,7 +988,7 @@ def forward( input_ids: torch.LongTensor | None = None, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, - past_key_values: Zamba2HybridDynamicCache | None = None, + past_key_values: Cache | None = None, inputs_embeds: torch.FloatTensor | None = None, labels: torch.LongTensor | None = None, use_cache: bool | None = None, diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 7d764b55d008..055b852be0b3 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -81,6 +81,8 @@ Cache, DynamicCache, EncoderDecoderCache, + LinearAttentionAndFullAttentionLayer, + LinearAttentionLayer, QuantoQuantizedLayer, StaticCache, ) @@ -2535,6 +2537,24 @@ def _check_encoder_hidden_states_for_generate(self, hidden_states, batch_size, c [encoder_expected_shape] * len(hidden_states), ) + def _get_conv_state_shape(self, batch_size: int, config): + # Default conv state shape for linear attention models - it varies across models, so this helper makes + # the cache checks easy to override per model + # Note that non-mamba models will NOT have these config fields - that is fine, as they only use attention + # cache layers, so the None default values are never used + intermediate_size = getattr(config, "intermediate_size", None) + conv_kernel = getattr(config, "conv_kernel", None) + return (batch_size, intermediate_size, conv_kernel) + + def _get_recurrent_state_shape(self, batch_size: int, config): + # Default recurrent state shape for linear attention models - it varies across models, so this helper makes + # the cache checks easy to override per model + # Note that non-mamba models will NOT have these config fields - that is fine, as they only use attention + # cache layers, so the None default values are never used +
intermediate_size = getattr(config, "intermediate_size", None) + state_size = getattr(config, "state_size", None) + return (batch_size, intermediate_size, state_size) + def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config): # Raise a useful error, asking to explicitly override the method if not isinstance(past_key_values, Cache): @@ -2553,17 +2573,23 @@ def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_l config = config.get_text_config(decoder=True) # (batch, kv heads, seq_length, head_dim) - num_heads = getattr(config, "num_key_value_heads", config.num_attention_heads) + # Only pure mamba models do not have num_attention_heads defined in config, so it can never be 1 in practice for attention models + num_attention_heads = getattr(config, "num_attention_heads", 1) + num_kv_heads = getattr(config, "num_key_value_heads", num_attention_heads) hidden_size = getattr(config, "d_model", config.hidden_size) - head_dim = getattr(config, "head_dim", hidden_size // config.num_attention_heads) + head_dim = getattr(config, "head_dim", hidden_size // num_attention_heads) # For cross attention cache, the seq_length depends on the model, so we remove that dim - expected_shape = ( - (batch_size, num_heads, seq_length, head_dim) + attention_shape = ( + (batch_size, num_kv_heads, seq_length, head_dim) if seq_length is not None - else (batch_size, num_heads, head_dim) + else (batch_size, num_kv_heads, head_dim) ) + # For mamba layers + conv_shape = self._get_conv_state_shape(batch_size, config) + recurrent_shape = self._get_recurrent_state_shape(batch_size, config) + # Check the size is coherent num_hidden_layers = config.num_hidden_layers if getattr(config, "num_kv_shared_layers", None) is not None: @@ -2572,11 +2598,30 @@ def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_l # Check each layer has the correct shape for layer in past_key_values.layers: - # Remove the seq_length dim for cross-attention cache (it changes based on the model) - keys = layer.keys if seq_length is not None else layer.keys[:, :, 0, :] - values = layer.values if seq_length is not None else layer.values[:, :, 0, :] - self.assertEqual(keys.shape, expected_shape) - self.assertEqual(values.shape, expected_shape) + # Mamba + Attention layer cache + if type(layer) is LinearAttentionAndFullAttentionLayer: + # Remove the seq_length dim for cross-attention cache (it changes based on the model) + keys = layer.keys if seq_length is not None else layer.keys[:, :, 0, :] + values = layer.values if seq_length is not None else layer.values[:, :, 0, :] + self.assertEqual(keys.shape, attention_shape) + self.assertEqual(values.shape, attention_shape) + self.assertEqual(layer.conv_states.shape, conv_shape) + # May not be used (e.g. lfm2) + if layer.is_recurrent_states_initialized: + self.assertEqual(layer.recurrent_states.shape, recurrent_shape) + # Mamba only layer cache + elif type(layer) is LinearAttentionLayer: + self.assertEqual(layer.conv_states.shape, conv_shape) + # May not be used (e.g. 
lfm2) + if layer.is_recurrent_states_initialized: + self.assertEqual(layer.recurrent_states.shape, recurrent_shape) + # Attention only layer type + else: + # Remove the seq_length dim for cross-attention cache (it changes based on the model) + keys = layer.keys if seq_length is not None else layer.keys[:, :, 0, :] + values = layer.values if seq_length is not None else layer.values[:, :, 0, :] + self.assertEqual(keys.shape, attention_shape) + self.assertEqual(values.shape, attention_shape) def _check_sequence_inside_sequence(self, tensor_1, tensor_2): # check if tensor_1 inside tensor_2 or tensor_2 inside tensor_1. @@ -2616,8 +2661,30 @@ def _check_caches_are_equal(self, cache1: Cache, cache2: Cache): num_layers = len(cache1) for idx in range(num_layers): - torch.testing.assert_close(cache1.layers[idx].keys, cache2.layers[idx].keys) - torch.testing.assert_close(cache1.layers[idx].values, cache2.layers[idx].values) + self.assertEqual(type(cache1.layers[idx]), type(cache2.layers[idx])) + + # Mamba + Attention layer + if type(cache1.layers[idx]) is LinearAttentionAndFullAttentionLayer: + torch.testing.assert_close(cache1.layers[idx].keys, cache2.layers[idx].keys) + torch.testing.assert_close(cache1.layers[idx].values, cache2.layers[idx].values) + torch.testing.assert_close(cache1.layers[idx].conv_states, cache2.layers[idx].conv_states) + # May not be used (e.g. lfm2) + if cache1.layers[idx].is_recurrent_states_initialized: + torch.testing.assert_close( + cache1.layers[idx].recurrent_states, cache2.layers[idx].recurrent_states + ) + # Mamba layer + elif type(cache1.layers[idx]) is LinearAttentionLayer: + torch.testing.assert_close(cache1.layers[idx].conv_states, cache2.layers[idx].conv_states) + # May not be used (e.g. lfm2) + if cache1.layers[idx].is_recurrent_states_initialized: + torch.testing.assert_close( + cache1.layers[idx].recurrent_states, cache2.layers[idx].recurrent_states + ) + # Attention layer + else: + torch.testing.assert_close(cache1.layers[idx].keys, cache2.layers[idx].keys) + torch.testing.assert_close(cache1.layers[idx].values, cache2.layers[idx].values) @require_torch diff --git a/tests/models/bamba/test_modeling_bamba.py b/tests/models/bamba/test_modeling_bamba.py index ee8143f31c14..fd028512f16c 100644 --- a/tests/models/bamba/test_modeling_bamba.py +++ b/tests/models/bamba/test_modeling_bamba.py @@ -47,11 +47,7 @@ if is_torch_available(): import torch - from transformers import ( - BambaForCausalLM, - BambaModel, - ) - from transformers.models.bamba.modeling_bamba import HybridMambaAttentionDynamicCache + from transformers import BambaForCausalLM, BambaModel class BambaModelTester: @@ -228,14 +224,9 @@ def create_and_check_decoder_model_past_large_inputs( model.eval() # first forward pass - # Attention: Jamba needs the cache to be initialized to return a cache! 
- past_key_values = HybridMambaAttentionDynamicCache( - config, input_ids.shape[0], model.dtype, device=model.device - ) outputs = model( input_ids, attention_mask=input_mask, - past_key_values=past_key_values, use_cache=True, ) past_key_values = outputs.past_key_values @@ -288,48 +279,16 @@ class BambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi # This is because we are hitting edge cases with the causal_mask buffer model_split_percents = [0.5, 0.7, 0.8] - def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config): - self.assertIsInstance(past_key_values, HybridMambaAttentionDynamicCache) - - # (batch, kv heads, seq_length, head_dim) - num_heads = getattr(config, "num_key_value_heads", config.num_attention_heads) - head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) - attention_shape = (batch_size, num_heads, seq_length, head_dim) - + def _get_conv_state_shape(self, batch_size: int, config): conv_shape = ( batch_size, config.mamba_expand * config.hidden_size + 2 * config.mamba_n_groups * config.mamba_d_state, config.mamba_d_conv, ) - ssm_shape = (batch_size, config.mamba_n_heads, config.mamba_d_head, config.mamba_d_state) - - self.assertTrue(config.num_hidden_layers, len(past_key_values)) + return conv_shape - for idx in range(len(past_key_values)): - if config.layers_block_type[idx] == "mamba": - self.assertEqual(past_key_values.conv_states[idx].shape, conv_shape) - self.assertEqual(past_key_values.ssm_states[idx].shape, ssm_shape) - else: - self.assertEqual(past_key_values.key_cache[idx].shape, attention_shape) - self.assertEqual(past_key_values.value_cache[idx].shape, attention_shape) - - def _check_caches_are_equal( - self, cache1: HybridMambaAttentionDynamicCache, cache2: HybridMambaAttentionDynamicCache - ): - if not isinstance(cache1, HybridMambaAttentionDynamicCache) or not isinstance( - cache2, HybridMambaAttentionDynamicCache - ): - raise ValueError("The wrong cache is being used!") - - if not len(cache1) == len(cache2): - raise ValueError("Both caches do not have the same number of layers.") - - num_layers = len(cache1) - for idx in range(num_layers): - torch.testing.assert_close(cache1.key_cache[idx], cache2.key_cache[idx]) - torch.testing.assert_close(cache1.value_cache[idx], cache2.value_cache[idx]) - torch.testing.assert_close(cache1.conv_states[idx], cache2.conv_states[idx]) - torch.testing.assert_close(cache1.ssm_states[idx], cache2.ssm_states[idx]) + def _get_recurrent_state_shape(self, batch_size: int, config): + return (batch_size, config.mamba_n_heads, config.mamba_d_head, config.mamba_d_state) def setUp(self): self.model_tester = self.model_tester_class(self) diff --git a/tests/models/falcon_h1/test_modeling_falcon_h1.py b/tests/models/falcon_h1/test_modeling_falcon_h1.py index 88c656dce1e0..a6429f2da621 100644 --- a/tests/models/falcon_h1/test_modeling_falcon_h1.py +++ b/tests/models/falcon_h1/test_modeling_falcon_h1.py @@ -37,9 +37,6 @@ import torch from transformers import AutoTokenizer, FalconH1ForCausalLM, FalconH1Model - from transformers.models.falcon_h1.modeling_falcon_h1 import ( - FalconHybridMambaAttentionDynamicCache, - ) class FalconH1ModelTester: @@ -206,17 +203,9 @@ def create_and_check_decoder_model_past_large_inputs( model.eval() # first forward pass - # Attention: Jamba needs the cache to be initialized to return a cache! 
- past_key_values = FalconHybridMambaAttentionDynamicCache( - config, - input_ids.shape[0], - model.dtype, - devices=[model.device for _ in range(model.config.num_hidden_layers)], - ) outputs = model( input_ids, attention_mask=input_mask, - past_key_values=past_key_values, use_cache=True, ) past_key_values = outputs.past_key_values @@ -264,38 +253,19 @@ class FalconH1ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM {"feature-extraction": FalconH1Model, "text-generation": FalconH1ForCausalLM} if is_torch_available() else {} ) - def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config): - self.assertIsInstance(past_key_values, FalconHybridMambaAttentionDynamicCache) - - # (batch, kv heads, seq_length, head_dim) - num_heads = getattr(config, "num_key_value_heads", config.num_attention_heads) - head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) - expected_shape = (batch_size, num_heads, seq_length, head_dim) - - self.assertListEqual( - [key_tensor.shape for key_tensor in past_key_values.key_cache], - [expected_shape] * len(past_key_values.key_cache), + def _get_conv_state_shape(self, batch_size: int, config): + intermediate_size = ( + config.mamba_d_ssm if config.mamba_d_ssm is not None else int(config.mamba_expand * config.hidden_size) ) - self.assertListEqual( - [value_cache.shape for value_cache in past_key_values.value_cache], - [expected_shape] * len(past_key_values.value_cache), + conv_shape = ( + batch_size, + intermediate_size + 2 * config.mamba_n_groups * config.mamba_d_state, + config.mamba_d_conv, ) + return conv_shape - def _check_caches_are_equal(self, cache1, cache2): - if not isinstance(cache1, FalconHybridMambaAttentionDynamicCache) or not isinstance( - cache2, FalconHybridMambaAttentionDynamicCache - ): - raise ValueError("The wrong cache is being used!") - - if not len(cache1) == len(cache2): - raise ValueError("Both caches do not have the same number of layers.") - - num_layers = len(cache1) - for idx in range(num_layers): - torch.testing.assert_close(cache1.key_cache[idx], cache2.key_cache[idx]) - torch.testing.assert_close(cache1.value_cache[idx], cache2.value_cache[idx]) - torch.testing.assert_close(cache1.conv_states[idx], cache2.conv_states[idx]) - torch.testing.assert_close(cache1.ssm_states[idx], cache2.ssm_states[idx]) + def _get_recurrent_state_shape(self, batch_size: int, config): + return (batch_size, config.mamba_n_heads, config.mamba_d_head, config.mamba_d_state) def setUp(self): self.model_tester = FalconH1ModelTester(self) diff --git a/tests/models/falcon_mamba/test_modeling_falcon_mamba.py b/tests/models/falcon_mamba/test_modeling_falcon_mamba.py index 9e4ce3e5cb65..b14bc4227efe 100644 --- a/tests/models/falcon_mamba/test_modeling_falcon_mamba.py +++ b/tests/models/falcon_mamba/test_modeling_falcon_mamba.py @@ -41,8 +41,7 @@ if is_torch_available(): import torch - from transformers import FalconMambaForCausalLM, FalconMambaModel - from transformers.models.falcon_mamba.modeling_falcon_mamba import FalconMambaCache + from transformers import DynamicCache, FalconMambaForCausalLM, FalconMambaModel # Copied from transformers.tests.models.mamba.MambaModelTester with Mamba->FalconMamba,mamba->falcon_mamba @@ -255,7 +254,6 @@ def prepare_config_and_inputs_for_common(self): @require_torch -# Copied from transformers.tests.models.mamba.MambaModelTest with Mamba->Falcon,mamba->falcon_mamba,FalconMambaCache->MambaCache class FalconMambaModelTest(ModelTesterMixin, 
GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): all_model_classes = (FalconMambaModel, FalconMambaForCausalLM) if is_torch_available() else () has_attentions = False # FalconMamba does not support attentions @@ -276,18 +274,6 @@ def setUp(self): self, config_class=FalconMambaConfig, n_embd=37, common_properties=["hidden_size", "num_hidden_layers"] ) - def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config): - self.assertIsInstance(past_key_values, FalconMambaCache) - - conv_shape = (batch_size, config.intermediate_size, config.conv_kernel) - ssm_shape = (batch_size, config.intermediate_size, config.state_size) - - self.assertTrue(config.num_hidden_layers, len(past_key_values.conv_states)) - - for idx in range(len(past_key_values.conv_states)): - self.assertEqual(past_key_values.conv_states[idx].shape, conv_shape) - self.assertEqual(past_key_values.ssm_states[idx].shape, ssm_shape) - def assertInterval(self, member, container, msg=None): r""" Simple utility function to check if a member is inside an interval. @@ -348,9 +334,12 @@ def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}): dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs).to_tuple() def recursive_check(tuple_object, dict_object): - if isinstance(tuple_object, FalconMambaCache): # MODIFIED PART START - recursive_check(tuple_object.conv_states, dict_object.conv_states) - recursive_check(tuple_object.ssm_states, dict_object.ssm_states) + if isinstance(tuple_object, DynamicCache): # MODIFIED PART START + for idx in range(len(tuple_object)): + recursive_check(tuple_object.layers[idx].conv_states, dict_object.layers[idx].conv_states) + recursive_check( + tuple_object.layers[idx].recurrent_states, dict_object.layers[idx].recurrent_states + ) elif isinstance(tuple_object, (list, tuple)): # MODIFIED PART END for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object): recursive_check(tuple_iterable_value, dict_iterable_value) diff --git a/tests/models/granitemoehybrid/test_modeling_granitemoehybrid.py b/tests/models/granitemoehybrid/test_modeling_granitemoehybrid.py index 8cb946d0aa2e..919bf79deac3 100644 --- a/tests/models/granitemoehybrid/test_modeling_granitemoehybrid.py +++ b/tests/models/granitemoehybrid/test_modeling_granitemoehybrid.py @@ -46,11 +46,7 @@ if is_torch_available(): import torch - from transformers import ( - GraniteMoeHybridForCausalLM, - GraniteMoeHybridModel, - ) - from transformers.models.granitemoehybrid.modeling_granitemoehybrid import HybridMambaAttentionDynamicCache + from transformers import GraniteMoeHybridForCausalLM, GraniteMoeHybridModel class GraniteMoeHybridModelTester(BambaModelTester): @@ -109,24 +105,6 @@ class GraniteMoeHybridModelTest(ModelTesterMixin, GenerationTesterMixin, Pipelin # This is because we are hitting edge cases with the causal_mask buffer model_split_percents = [0.5, 0.7, 0.8] - def _check_caches_are_equal( - self, cache1: HybridMambaAttentionDynamicCache, cache2: HybridMambaAttentionDynamicCache - ): - if not isinstance(cache1, HybridMambaAttentionDynamicCache) or not isinstance( - cache2, HybridMambaAttentionDynamicCache - ): - raise ValueError("The wrong cache is being used!") - - if not len(cache1) == len(cache2): - raise ValueError("Both caches do not have the same number of layers.") - - num_layers = len(cache1) - for idx in range(num_layers): - torch.testing.assert_close(cache1.key_cache[idx], cache2.key_cache[idx]) - 
torch.testing.assert_close(cache1.value_cache[idx], cache2.value_cache[idx]) - torch.testing.assert_close(cache1.conv_states[idx], cache2.conv_states[idx]) - torch.testing.assert_close(cache1.ssm_states[idx], cache2.ssm_states[idx]) - def setUp(self): self.model_tester = self.model_tester_class(self) self.config_tester = ConfigTester(self, config_class=self.model_tester.config_class, hidden_size=64) @@ -324,30 +302,16 @@ def test_flash_attention_2_padding_matches_padding_free_with_position_ids_seq_id loss_padfree = res_padfree.loss torch.testing.assert_close(loss_padded, loss_padfree) - def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config): - self.assertIsInstance(past_key_values, HybridMambaAttentionDynamicCache) - - # (batch, kv heads, seq_length, head_dim) - num_heads = getattr(config, "num_key_value_heads", config.num_attention_heads) - head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) - attention_shape = (batch_size, num_heads, seq_length, head_dim) - + def _get_conv_state_shape(self, batch_size: int, config): conv_shape = ( batch_size, config.mamba_expand * config.hidden_size + 2 * config.mamba_n_groups * config.mamba_d_state, config.mamba_d_conv, ) - ssm_shape = (batch_size, config.mamba_n_heads, config.mamba_d_head, config.mamba_d_state) - - self.assertTrue(config.num_hidden_layers, len(past_key_values)) + return conv_shape - for idx in range(len(past_key_values)): - if config.layers_block_type[idx] == "mamba": - self.assertEqual(past_key_values.conv_states[idx].shape, conv_shape) - self.assertEqual(past_key_values.ssm_states[idx].shape, ssm_shape) - else: - self.assertEqual(past_key_values.key_cache[idx].shape, attention_shape) - self.assertEqual(past_key_values.value_cache[idx].shape, attention_shape) + def _get_recurrent_state_shape(self, batch_size: int, config): + return (batch_size, config.mamba_n_heads, config.mamba_d_head, config.mamba_d_state) def test_config_requires_mamba_or_attention_layers(self): """Ensure we can't create a config with disallowed layers.""" diff --git a/tests/models/jamba/test_modeling_jamba.py b/tests/models/jamba/test_modeling_jamba.py index a71271dd3cbe..28d4b7a18c61 100644 --- a/tests/models/jamba/test_modeling_jamba.py +++ b/tests/models/jamba/test_modeling_jamba.py @@ -43,14 +43,7 @@ if is_torch_available(): import torch - from transformers import ( - JambaForCausalLM, - JambaForSequenceClassification, - JambaModel, - ) - from transformers.models.jamba.modeling_jamba import ( - HybridMambaAttentionDynamicCache, - ) + from transformers import JambaForCausalLM, JambaForSequenceClassification, JambaModel class JambaConfigTester(ConfigTester): @@ -250,17 +243,7 @@ def create_and_check_decoder_model_past_large_inputs( model.to(torch_device) model.eval() - # first forward pass - # Attention: Jamba needs the cache to be initialized to return a cache! 
- past_key_values = HybridMambaAttentionDynamicCache( - config, input_ids.shape[0], model.dtype, device=model.device - ) - outputs = model( - input_ids, - attention_mask=input_mask, - past_key_values=past_key_values, - use_cache=True, - ) + outputs = model(input_ids, attention_mask=input_mask, use_cache=True) past_key_values = outputs.past_key_values # create hypothetical multiple next token and extent to next_input_ids @@ -339,43 +322,11 @@ class JambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi else {} ) - def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config): - self.assertIsInstance(past_key_values, HybridMambaAttentionDynamicCache) - - # (batch, kv heads, seq_length, head_dim) - num_heads = getattr(config, "num_key_value_heads", config.num_attention_heads) - head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) - attention_shape = (batch_size, num_heads, seq_length, head_dim) - conv_shape = (batch_size, config.mamba_expand * config.hidden_size, config.mamba_d_conv) - ssm_shape = (batch_size, config.mamba_expand * config.hidden_size, config.mamba_d_state) - - self.assertTrue(config.num_hidden_layers, len(past_key_values)) + def _get_conv_state_shape(self, batch_size: int, config): + return (batch_size, config.mamba_expand * config.hidden_size, config.mamba_d_conv) - for idx in range(len(past_key_values)): - if config.layers_block_type[idx] == "mamba": - self.assertEqual(past_key_values.conv_states[idx].shape, conv_shape) - self.assertEqual(past_key_values.ssm_states[idx].shape, ssm_shape) - else: - self.assertEqual(past_key_values.key_cache[idx].shape, attention_shape) - self.assertEqual(past_key_values.value_cache[idx].shape, attention_shape) - - def _check_caches_are_equal( - self, cache1: HybridMambaAttentionDynamicCache, cache2: HybridMambaAttentionDynamicCache - ): - if not isinstance(cache1, HybridMambaAttentionDynamicCache) or not isinstance( - cache2, HybridMambaAttentionDynamicCache - ): - raise ValueError("The wrong cache is being used!") - - if not len(cache1) == len(cache2): - raise ValueError("Both caches do not have the same number of layers.") - - num_layers = len(cache1) - for idx in range(num_layers): - torch.testing.assert_close(cache1.key_cache[idx], cache2.key_cache[idx]) - torch.testing.assert_close(cache1.value_cache[idx], cache2.value_cache[idx]) - torch.testing.assert_close(cache1.conv_states[idx], cache2.conv_states[idx]) - torch.testing.assert_close(cache1.ssm_states[idx], cache2.ssm_states[idx]) + def _get_recurrent_state_shape(self, batch_size: int, config): + return (batch_size, config.mamba_expand * config.hidden_size, config.mamba_d_state) def setUp(self): self.model_tester = JambaModelTester(self) diff --git a/tests/models/lfm2/test_modeling_lfm2.py b/tests/models/lfm2/test_modeling_lfm2.py index 13afd1c2726b..67698c564092 100644 --- a/tests/models/lfm2/test_modeling_lfm2.py +++ b/tests/models/lfm2/test_modeling_lfm2.py @@ -30,7 +30,6 @@ import torch from transformers import Lfm2ForCausalLM, Lfm2Model - from transformers.models.lfm2.modeling_lfm2 import Lfm2HybridConvCache class Lfm2ModelTester(CausalLMModelTester): @@ -52,34 +51,8 @@ class Lfm2ModelTest(CausalLMModelTest, unittest.TestCase): # used in `test_torch_compile_for_training` _torch_compile_train_cls = Lfm2ForCausalLM if is_torch_available() else None - def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config): - self.assertIsInstance(past_key_values, 
Lfm2HybridConvCache) - - # (batch, kv heads, seq_length, head_dim) - num_heads = getattr(config, "num_key_value_heads", config.num_attention_heads) - head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) - attention_shape = (batch_size, num_heads, seq_length, head_dim) - conv_shape = (batch_size, config.hidden_size, config.conv_L_cache) - - for i in range(config.num_hidden_layers): - if config.layer_types[i] == "full_attention": - self.assertEqual(past_key_values.key_cache[i].shape, attention_shape) - self.assertEqual(past_key_values.value_cache[i].shape, attention_shape) - else: - self.assertEqual(past_key_values.conv_cache[i].shape, conv_shape) - - def _check_caches_are_equal(self, cache1: Lfm2HybridConvCache, cache2: Lfm2HybridConvCache): - if not isinstance(cache1, Lfm2HybridConvCache) or not isinstance(cache2, Lfm2HybridConvCache): - raise ValueError("The wrong cache is being used!") - - if not len(cache1) == len(cache2): - raise ValueError("Both caches do not have the same number of layers.") - - num_layers = len(cache1) - for idx in range(num_layers): - torch.testing.assert_close(cache1.key_cache[idx], cache2.key_cache[idx]) - torch.testing.assert_close(cache1.value_cache[idx], cache2.value_cache[idx]) - torch.testing.assert_close(cache1.conv_cache[idx], cache2.conv_cache[idx]) + def _get_conv_state_shape(self, batch_size: int, config): + return (batch_size, config.hidden_size, config.conv_L_cache) def test_attention_outputs(self): """Lfm2Moe alternates between attention and short-conv layers.""" diff --git a/tests/models/lfm2_moe/test_modeling_lfm2_moe.py b/tests/models/lfm2_moe/test_modeling_lfm2_moe.py index fa8aecc99707..2015a4a83e31 100644 --- a/tests/models/lfm2_moe/test_modeling_lfm2_moe.py +++ b/tests/models/lfm2_moe/test_modeling_lfm2_moe.py @@ -33,7 +33,6 @@ import torch from transformers import Lfm2MoeConfig, Lfm2MoeForCausalLM, Lfm2MoeModel - from transformers.models.lfm2_moe.modeling_lfm2_moe import Lfm2MoeHybridConvCache class Lfm2MoeModelTester(CausalLMModelTester): @@ -70,34 +69,8 @@ class Lfm2MoeModelTest(CausalLMModelTest, unittest.TestCase): # used in `test_torch_compile_for_training` _torch_compile_train_cls = Lfm2MoeForCausalLM if is_torch_available() else None - def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config): - self.assertIsInstance(past_key_values, Lfm2MoeHybridConvCache) - - # (batch, kv heads, seq_length, head_dim) - num_heads = getattr(config, "num_key_value_heads", config.num_attention_heads) - head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) - attention_shape = (batch_size, num_heads, seq_length, head_dim) - conv_shape = (batch_size, config.hidden_size, config.conv_L_cache) - - for i in range(config.num_hidden_layers): - if config.layer_types[i] == "full_attention": - self.assertEqual(past_key_values.key_cache[i].shape, attention_shape) - self.assertEqual(past_key_values.value_cache[i].shape, attention_shape) - else: - self.assertEqual(past_key_values.conv_cache[i].shape, conv_shape) - - def _check_caches_are_equal(self, cache1: Lfm2MoeHybridConvCache, cache2: Lfm2MoeHybridConvCache): - if not isinstance(cache1, Lfm2MoeHybridConvCache) or not isinstance(cache2, Lfm2MoeHybridConvCache): - raise ValueError("The wrong cache is being used!") - - if not len(cache1) == len(cache2): - raise ValueError("Both caches do not have the same number of layers.") - - num_layers = len(cache1) - for idx in range(num_layers): - 
torch.testing.assert_close(cache1.key_cache[idx], cache2.key_cache[idx]) - torch.testing.assert_close(cache1.value_cache[idx], cache2.value_cache[idx]) - torch.testing.assert_close(cache1.conv_cache[idx], cache2.conv_cache[idx]) + def _get_conv_state_shape(self, batch_size: int, config): + return (batch_size, config.hidden_size, config.conv_L_cache) def test_attention_outputs(self): """Lfm2Moe alternates between attention and short-conv layers.""" diff --git a/tests/models/lfm2_vl/test_modeling_lfm2_vl.py b/tests/models/lfm2_vl/test_modeling_lfm2_vl.py index aa153563e4f8..c14e3933f77b 100644 --- a/tests/models/lfm2_vl/test_modeling_lfm2_vl.py +++ b/tests/models/lfm2_vl/test_modeling_lfm2_vl.py @@ -44,7 +44,6 @@ import torch from transformers import Lfm2VlConfig, Lfm2VlForConditionalGeneration, Lfm2VlModel - from transformers.models.lfm2.modeling_lfm2 import Lfm2HybridConvCache class Lfm2VlModelTester(CausalLMModelTester): @@ -172,35 +171,8 @@ def setUp(self): self, config_class=Lfm2VlConfig, has_text_modality=False, common_properties=common_properties ) - def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config): - self.assertIsInstance(past_key_values, Lfm2HybridConvCache) - - # (batch, kv heads, seq_length, head_dim) - num_heads = getattr(config, "num_key_value_heads", config.num_attention_heads) - head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) - attention_shape = (batch_size, num_heads, seq_length, head_dim) - conv_shape = (batch_size, config.hidden_size, config.conv_L_cache) - - for i in range(config.num_hidden_layers): - if config.layer_types[i] == "full_attention": - self.assertEqual(past_key_values.key_cache[i].shape, attention_shape) - self.assertEqual(past_key_values.value_cache[i].shape, attention_shape) - else: - self.assertEqual(past_key_values.conv_cache[i].shape, conv_shape) - - def _check_caches_are_equal(self, cache1: Lfm2HybridConvCache, cache2: Lfm2HybridConvCache): - """Text model uses lfm2, which has non-standard cache""" - if not isinstance(cache1, Lfm2HybridConvCache) or not isinstance(cache2, Lfm2HybridConvCache): - raise ValueError("The wrong cache is being used!") - - if not len(cache1) == len(cache2): - raise ValueError("Both caches do not have the same number of layers.") - - num_layers = len(cache1) - for idx in range(num_layers): - torch.testing.assert_close(cache1.key_cache[idx], cache2.key_cache[idx]) - torch.testing.assert_close(cache1.value_cache[idx], cache2.value_cache[idx]) - torch.testing.assert_close(cache1.conv_cache[idx], cache2.conv_cache[idx]) + def _get_conv_state_shape(self, batch_size: int, config): + return (batch_size, config.hidden_size, config.conv_L_cache) def test_config(self): self.config_tester.run_common_tests() diff --git a/tests/models/mamba/test_modeling_mamba.py b/tests/models/mamba/test_modeling_mamba.py index 6430a014ec4f..276c03d65099 100644 --- a/tests/models/mamba/test_modeling_mamba.py +++ b/tests/models/mamba/test_modeling_mamba.py @@ -31,11 +31,7 @@ if is_torch_available(): import torch - from transformers import ( - MambaForCausalLM, - MambaModel, - ) - from transformers.models.mamba.modeling_mamba import MambaCache + from transformers import CompileConfig, DynamicCache, MambaForCausalLM, MambaModel class MambaModelTester: @@ -246,18 +242,6 @@ def setUp(self): def test_enable_input_require_grads(self): self.skipTest("Mamba currently requires CUDA/Metal/XPU to run enable_input_require_grads.") - def _check_past_key_values_for_generate(self, 
batch_size, past_key_values, seq_length, config): - self.assertIsInstance(past_key_values, MambaCache) - - conv_shape = (batch_size, config.intermediate_size, config.conv_kernel) - ssm_shape = (batch_size, config.intermediate_size, config.state_size) - - self.assertTrue(config.num_hidden_layers, len(past_key_values.conv_states)) - - for idx in range(len(past_key_values.conv_states)): - self.assertEqual(past_key_values.conv_states[idx].shape, conv_shape) - self.assertEqual(past_key_values.ssm_states[idx].shape, ssm_shape) - def assertInterval(self, member, container, msg=None): r""" Simple utility function to check if a member is inside an interval. @@ -317,9 +301,12 @@ def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}): dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs).to_tuple() def recursive_check(tuple_object, dict_object): - if isinstance(tuple_object, MambaCache): # MODIFIED PART START - recursive_check(tuple_object.conv_states, dict_object.conv_states) - recursive_check(tuple_object.ssm_states, dict_object.ssm_states) + if isinstance(tuple_object, DynamicCache): # MODIFIED PART START + for idx in range(len(tuple_object)): + recursive_check(tuple_object.layers[idx].conv_states, dict_object.layers[idx].conv_states) + recursive_check( + tuple_object.layers[idx].recurrent_states, dict_object.layers[idx].recurrent_states + ) elif isinstance(tuple_object, (list, tuple)): # MODIFIED PART END for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object): recursive_check(tuple_iterable_value, dict_iterable_value) @@ -368,29 +355,6 @@ def recursive_check(tuple_object, dict_object): def test_beam_sample_generate(self): pass - def test_dtype_mismatch_handled_in_cache(self): - config, input_ids, *args = self.model_tester.prepare_config_and_inputs() - model = MambaModel(config) - model.to(torch_device).to(torch.float16) - model.eval() - - # Create cache with float32 dtype - cache_params = MambaCache(config, max_batch_size=input_ids.size(0), dtype=torch.float32, device=torch_device) - - # If code is correct, no error occurs and test passes - outputs = model( - input_ids, - cache_params=cache_params, - use_cache=True, - ) - - self.assertIsNotNone(outputs) - self.assertIsNotNone(outputs.last_hidden_state) - self.assertEqual( - outputs.last_hidden_state.shape, - (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.hidden_size), - ) - @unittest.skip("Mamba models do not support DDP.") def test_multi_gpu_data_parallel_forward(self): pass @@ -490,8 +454,10 @@ def test_compile_mamba_cache(self): output_sentence = self.tokenizer.decode(output[0].tolist()) self.assertEqual(output_sentence, expected_output) - model.forward = torch.compile(model.forward, fullgraph=True, mode="reduce-overhead") - output = model.generate(input_ids, max_new_tokens=20) + compile_config = CompileConfig(fullgraph=True, mode="reduce-overhead") + output = model.generate( + input_ids, max_new_tokens=20, cache_implementation="static", compile_config=compile_config + ) output_sentence = self.tokenizer.decode(output[0].tolist()) self.assertEqual(output_sentence, expected_output) diff --git a/tests/models/mamba2/test_modeling_mamba2.py b/tests/models/mamba2/test_modeling_mamba2.py index aef487d46351..6e7116afdc36 100644 --- a/tests/models/mamba2/test_modeling_mamba2.py +++ b/tests/models/mamba2/test_modeling_mamba2.py @@ -34,11 +34,8 @@ if is_torch_available(): import torch - from transformers import ( - Mamba2ForCausalLM, - Mamba2Model, - ) - 
from transformers.models.mamba2.modeling_mamba2 import Mamba2Cache, Mamba2Mixer + from transformers import DynamicCache, Mamba2ForCausalLM, Mamba2Model + from transformers.models.mamba2.modeling_mamba2 import Mamba2Mixer class Mamba2ConfigTester(ConfigTester): @@ -247,20 +244,17 @@ def setUp(self): self, config_class=Mamba2Config, n_embd=37, common_properties=["hidden_size", "num_hidden_layers"] ) - def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config): - self.assertIsInstance(past_key_values, Mamba2Cache) - + def _get_conv_state_shape(self, batch_size: int, config): intermediate_size = config.expand * config.hidden_size conv_shape = ( - config.num_hidden_layers, batch_size, intermediate_size + 2 * config.n_groups * config.state_size, config.conv_kernel, ) - ssm_shape = (config.num_hidden_layers, batch_size, config.num_heads, config.head_dim, config.state_size) + return conv_shape - self.assertEqual(past_key_values.conv_states.shape, conv_shape) - self.assertEqual(past_key_values.ssm_states.shape, ssm_shape) + def _get_recurrent_state_shape(self, batch_size: int, config): + return (batch_size, config.num_heads, config.head_dim, config.state_size) def test_mamba2_caching(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() @@ -291,9 +285,12 @@ def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}): dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs).to_tuple() def recursive_check(tuple_object, dict_object): - if isinstance(tuple_object, Mamba2Cache): # MODIFIED PART START - recursive_check(tuple_object.conv_states, dict_object.conv_states) - recursive_check(tuple_object.ssm_states, dict_object.ssm_states) + if isinstance(tuple_object, DynamicCache): # MODIFIED PART START + for idx in range(len(tuple_object)): + recursive_check(tuple_object.layers[idx].conv_states, dict_object.layers[idx].conv_states) + recursive_check( + tuple_object.layers[idx].recurrent_states, dict_object.layers[idx].recurrent_states + ) elif isinstance(tuple_object, (list, tuple)): # MODIFIED PART END for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object): recursive_check(tuple_iterable_value, dict_iterable_value) diff --git a/tests/models/nemotron_h/test_modeling_nemotron_h.py b/tests/models/nemotron_h/test_modeling_nemotron_h.py index 467971673065..19ca9f4f77fd 100644 --- a/tests/models/nemotron_h/test_modeling_nemotron_h.py +++ b/tests/models/nemotron_h/test_modeling_nemotron_h.py @@ -18,7 +18,6 @@ import pytest from huggingface_hub.errors import StrictDataclassClassValidationError -from parameterized import parameterized from transformers import AutoTokenizer, NemotronHConfig, NemotronHForCausalLM, is_torch_available from transformers.testing_utils import ( @@ -40,13 +39,7 @@ if is_torch_available(): import torch - from transformers import ( - NemotronHForCausalLM, - NemotronHModel, - ) - from transformers.models.nemotron_h.modeling_nemotron_h import ( - NemotronHHybridDynamicCache, - ) + from transformers import DynamicCache, NemotronHForCausalLM, NemotronHModel class NemotronHModelTester: @@ -237,12 +230,9 @@ def create_and_check_decoder_model_past_large_inputs( model.eval() # first forward pass - # Attention: NemotronH needs the cache to be initialized to return a cache! 
- past_key_values = NemotronHHybridDynamicCache(config, input_ids.shape[0], model.dtype, device=model.device) outputs = model( input_ids, attention_mask=input_mask, - past_key_values=past_key_values, use_cache=True, ) past_key_values = outputs.past_key_values @@ -319,17 +309,11 @@ def create_and_check_mamba2_slow_vs_fast_forward(self, config, input_ids, *args) self.parent.assertTrue(torch.allclose(outputs_fast, outputs_slow, atol=1e-3, rtol=1e-3)) # Test with cache - batch_size = input_ids.shape[0] - cache_params = NemotronHHybridDynamicCache( - config=config, batch_size=batch_size, dtype=token_emb.dtype, device=torch_device - ) - + cache_params = DynamicCache(config=config) outputs_fast_cached = mamba_mixer.cuda_kernels_forward(token_emb, cache_params=cache_params) # Reset cache for fair comparison - cache_params_slow = NemotronHHybridDynamicCache( - config=config, batch_size=batch_size, dtype=token_emb.dtype, device=torch_device - ) + cache_params_slow = DynamicCache(config=config) outputs_slow_cached = mamba_mixer.torch_forward(token_emb, cache_params=cache_params_slow) self.parent.assertTrue(torch.allclose(outputs_fast_cached, outputs_slow_cached, atol=1e-3, rtol=1e-3)) @@ -367,46 +351,55 @@ class NemotronHModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester else {} ) - def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config): - self.assertIsInstance(past_key_values, NemotronHHybridDynamicCache) - - # (batch, kv heads, seq_length, head_dim) - num_heads = getattr(config, "num_key_value_heads", config.num_attention_heads) - head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) - attention_shape = (batch_size, num_heads, seq_length, head_dim) - - # Mamba cache shapes + def _get_conv_state_shape(self, batch_size: int, config): intermediate_size = config.mamba_num_heads * config.mamba_head_dim conv_shape = ( batch_size, intermediate_size + 2 * config.n_groups * config.ssm_state_size, config.conv_kernel, ) - ssm_shape = (batch_size, config.mamba_num_heads, config.mamba_head_dim, config.ssm_state_size) + return conv_shape - self.assertTrue(config.num_hidden_layers, len(past_key_values)) + def _get_recurrent_state_shape(self, batch_size: int, config): + return (batch_size, config.mamba_num_heads, config.mamba_head_dim, config.ssm_state_size) - for idx in range(len(past_key_values)): - if config.layers_block_type[idx] == "mamba": - self.assertEqual(past_key_values.conv_states[idx].shape, conv_shape) - self.assertEqual(past_key_values.ssm_states[idx].shape, ssm_shape) - elif config.layers_block_type[idx] == "attention": - self.assertEqual(past_key_values.key_cache[idx].shape, attention_shape) - self.assertEqual(past_key_values.value_cache[idx].shape, attention_shape) - - def _check_caches_are_equal(self, cache1: NemotronHHybridDynamicCache, cache2: NemotronHHybridDynamicCache): - if not isinstance(cache1, NemotronHHybridDynamicCache) or not isinstance(cache2, NemotronHHybridDynamicCache): - raise ValueError("The wrong cache is being used!") + def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config): + # Raise a useful error, asking to explicitly override the method + if not isinstance(past_key_values, DynamicCache): + raise ValueError("The cache does not use the correct Cache") - if not len(cache1) == len(cache2): - raise ValueError("Both caches do not have the same number of layers.") + # Use the correct config + config = config.get_text_config(decoder=True) - num_layers = 
len(cache1) - for idx in range(num_layers): - torch.testing.assert_close(cache1.key_cache[idx], cache2.key_cache[idx]) - torch.testing.assert_close(cache1.value_cache[idx], cache2.value_cache[idx]) - torch.testing.assert_close(cache1.conv_states[idx], cache2.conv_states[idx]) - torch.testing.assert_close(cache1.ssm_states[idx], cache2.ssm_states[idx]) + # (batch, kv heads, seq_length, head_dim) + # Only pure mamba models do not have num_attention_heads defined in config, so it can never be 1 in practice for attention models + num_attention_heads = getattr(config, "num_attention_heads", 1) + num_kv_heads = getattr(config, "num_key_value_heads", num_attention_heads) + hidden_size = getattr(config, "d_model", config.hidden_size) + head_dim = getattr(config, "head_dim", hidden_size // num_attention_heads) + + # For cross attention cache, the seq_length depends on the model, so we remove that dim + attention_shape = (batch_size, num_kv_heads, seq_length, head_dim) + # For mamba layers + conv_shape = self._get_conv_state_shape(batch_size, config) + recurrent_shape = self._get_recurrent_state_shape(batch_size, config) + + # Check each layer has the correct shape + for layer, layer_type in zip(past_key_values.layers, config.layer_types): + # Moe layers have a default mamba cache instantiated, but it stays empty as the layer does not use it + if layer_type == "moe": + self.assertEqual(layer.conv_states, None) + self.assertEqual(layer.recurrent_states, None) + # Attention layer cache + elif layer_type == "attention": + self.assertEqual(layer.keys.shape, attention_shape) + self.assertEqual(layer.values.shape, attention_shape) + # Mamba layer cache + elif layer_type == "mamba": + self.assertEqual(layer.conv_states.shape, conv_shape) + self.assertEqual(layer.recurrent_states.shape, recurrent_shape) + else: + raise ValueError("Unknown layer type.") def setUp(self): self.model_tester = NemotronHModelTester(self) @@ -586,11 +579,6 @@ def test_flash_attn_2_fp32_ln(self): # with attention mask _ = model(dummy_input, attention_mask=dummy_attention_mask) - @unittest.skip(reason="NemotronH has its own special cache type") - @parameterized.expand([(1, False), (1, True), (4, False)]) - def test_new_cache_format(self, num_beams, do_sample): - pass - @require_torch_accelerator def test_flex_attention_with_grads(self): """ diff --git a/tests/models/qwen3_5/test_modeling_qwen3_5.py b/tests/models/qwen3_5/test_modeling_qwen3_5.py index 191a8cf788e4..f90fb09546d6 100644 --- a/tests/models/qwen3_5/test_modeling_qwen3_5.py +++ b/tests/models/qwen3_5/test_modeling_qwen3_5.py @@ -46,7 +46,6 @@ Qwen3_5TextConfig, Qwen3_5TextModel, ) - from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5DynamicCache class Qwen3_5TextModelTester(CausalLMModelTester): @@ -71,35 +70,21 @@ class Qwen3_5TextModelTest(CausalLMModelTest, unittest.TestCase): config_class = Qwen3_5TextConfig model_split_percents = [0.5, 0.8, 0.9] - def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config): - "Qwen3.5 has a special Cache as it alternates with gated deltanet layers" - self.assertIsInstance(past_key_values, Qwen3_5DynamicCache) + def _get_conv_state_shape(self, batch_size: int, config): + num_v_heads = config.linear_num_value_heads + num_k_heads = config.linear_num_key_heads + head_k_dim = config.linear_key_head_dim + head_v_dim = config.linear_value_head_dim + intermediate_size = 2 * num_k_heads * head_k_dim + num_v_heads * head_v_dim - # (batch, kv heads, seq_length, head_dim) - num_heads = 
getattr(config, "num_key_value_heads", config.num_attention_heads) - head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) - expected_shape = (batch_size, num_heads, seq_length, head_dim) + return (batch_size, intermediate_size, config.linear_conv_kernel_dim) - attention_layer_indices = past_key_values.transformer_layers - self.assertListEqual( - [past_key_values.key_cache[idx].shape for idx in attention_layer_indices], - [expected_shape] * len(attention_layer_indices), - ) - self.assertListEqual( - [past_key_values.value_cache[idx].shape for idx in attention_layer_indices], - [expected_shape] * len(attention_layer_indices), - ) - - def _check_caches_are_equal(self, cache1, cache2): - "Qwen3.5 has a special Cache as it alternates with gated deltanet layers" - if not len(cache1) == len(cache2): - raise ValueError("Both caches do not have the same number of layers.") + def _get_recurrent_state_shape(self, batch_size: int, config): + num_v_heads = config.linear_num_value_heads + head_k_dim = config.linear_key_head_dim + head_v_dim = config.linear_value_head_dim - num_layers = len(cache1) - for idx in range(num_layers): - if cache1.key_cache[idx] is not None: - torch.testing.assert_close(cache1.key_cache[idx], cache2.key_cache[idx]) - torch.testing.assert_close(cache1.value_cache[idx], cache2.value_cache[idx]) + return (batch_size, num_v_heads, head_k_dim, head_v_dim) def test_attention_outputs(self): "Needs to be overwritten as Qwen3.5 alternates between attention layers and gated deltanet layers." @@ -319,35 +304,21 @@ def setUp(self): def test_config(self): self.config_tester.run_common_tests() - def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config): - "Qwen3.5 has a special Cache as it alternates with gated deltanet layers" - self.assertIsInstance(past_key_values, Qwen3_5DynamicCache) + def _get_conv_state_shape(self, batch_size: int, config): + num_v_heads = config.linear_num_value_heads + num_k_heads = config.linear_num_key_heads + head_k_dim = config.linear_key_head_dim + head_v_dim = config.linear_value_head_dim + intermediate_size = 2 * num_k_heads * head_k_dim + num_v_heads * head_v_dim - # (batch, kv heads, seq_length, head_dim) - num_heads = getattr(config, "num_key_value_heads", config.num_attention_heads) - head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) - expected_shape = (batch_size, num_heads, seq_length, head_dim) - - attention_layer_indices = past_key_values.transformer_layers - self.assertListEqual( - [past_key_values.key_cache[idx].shape for idx in attention_layer_indices], - [expected_shape] * len(attention_layer_indices), - ) - self.assertListEqual( - [past_key_values.value_cache[idx].shape for idx in attention_layer_indices], - [expected_shape] * len(attention_layer_indices), - ) + return (batch_size, intermediate_size, config.linear_conv_kernel_dim) - def _check_caches_are_equal(self, cache1, cache2): - "Qwen3.5 has a special Cache as it alternates with gated deltanet layers" - if not len(cache1) == len(cache2): - raise ValueError("Both caches do not have the same number of layers.") + def _get_recurrent_state_shape(self, batch_size: int, config): + num_v_heads = config.linear_num_value_heads + head_k_dim = config.linear_key_head_dim + head_v_dim = config.linear_value_head_dim - num_layers = len(cache1) - for idx in range(num_layers): - if cache1.key_cache[idx] is not None: - torch.testing.assert_close(cache1.key_cache[idx], cache2.key_cache[idx]) - 
torch.testing.assert_close(cache1.value_cache[idx], cache2.value_cache[idx]) + return (batch_size, num_v_heads, head_k_dim, head_v_dim) def test_attention_outputs(self): "Needs to be overwritten as Qwen3.5 alternates between attention layers and gated deltanet layers." diff --git a/tests/models/qwen3_5_moe/test_modeling_qwen3_5_moe.py b/tests/models/qwen3_5_moe/test_modeling_qwen3_5_moe.py index 27cef6196313..d949f777f8a4 100644 --- a/tests/models/qwen3_5_moe/test_modeling_qwen3_5_moe.py +++ b/tests/models/qwen3_5_moe/test_modeling_qwen3_5_moe.py @@ -51,7 +51,6 @@ Qwen3_5MoeTextConfig, Qwen3_5MoeTextModel, ) - from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import Qwen3_5MoeDynamicCache class Qwen3_5MoeTextModelTester(CausalLMModelTester): @@ -74,35 +73,21 @@ class Qwen3_5MoeTextModelTest(CausalLMModelTest, unittest.TestCase): model_tester_class = Qwen3_5MoeTextModelTester config_class = Qwen3_5MoeTextConfig - def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config): - "Qwen3.5 Moe has a special Cache as it alternates with gated deltanet layers" - self.assertIsInstance(past_key_values, Qwen3_5MoeDynamicCache) + def _get_conv_state_shape(self, batch_size: int, config): + num_v_heads = config.linear_num_value_heads + num_k_heads = config.linear_num_key_heads + head_k_dim = config.linear_key_head_dim + head_v_dim = config.linear_value_head_dim + intermediate_size = 2 * num_k_heads * head_k_dim + num_v_heads * head_v_dim - # (batch, kv heads, seq_length, head_dim) - num_heads = getattr(config, "num_key_value_heads", config.num_attention_heads) - head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) - expected_shape = (batch_size, num_heads, seq_length, head_dim) + return (batch_size, intermediate_size, config.linear_conv_kernel_dim) - attention_layer_indices = past_key_values.transformer_layers - self.assertListEqual( - [past_key_values.key_cache[idx].shape for idx in attention_layer_indices], - [expected_shape] * len(attention_layer_indices), - ) - self.assertListEqual( - [past_key_values.value_cache[idx].shape for idx in attention_layer_indices], - [expected_shape] * len(attention_layer_indices), - ) - - def _check_caches_are_equal(self, cache1, cache2): - "Qwen3.5 Moe has a special Cache as it alternates with gated deltanet layers" - if not len(cache1) == len(cache2): - raise ValueError("Both caches do not have the same number of layers.") + def _get_recurrent_state_shape(self, batch_size: int, config): + num_v_heads = config.linear_num_value_heads + head_k_dim = config.linear_key_head_dim + head_v_dim = config.linear_value_head_dim - num_layers = len(cache1) - for idx in range(num_layers): - if cache1.key_cache[idx] is not None: - torch.testing.assert_close(cache1.key_cache[idx], cache2.key_cache[idx]) - torch.testing.assert_close(cache1.value_cache[idx], cache2.value_cache[idx]) + return (batch_size, num_v_heads, head_k_dim, head_v_dim) def test_attention_outputs(self): "Needs to be overwritten as Qwen3.5 Moe alternates between attention layers and gated deltanet layers." 
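The Qwen3.5, Qwen3.5-Moe, and Qwen3-Next testers all derive the same two cache shapes from the gated-deltanet config. A small sketch with hypothetical config values (not taken from any released checkpoint) makes the arithmetic concrete:

```python
# Hypothetical config values, chosen only to illustrate the shape arithmetic in
# the `_get_conv_state_shape` / `_get_recurrent_state_shape` overrides above.
batch_size = 2
linear_num_value_heads = 4
linear_num_key_heads = 2
linear_key_head_dim = 16
linear_value_head_dim = 32
linear_conv_kernel_dim = 4

# the factor 2 covers the q and k projections, which share the key head dim,
# while v contributes its own head dim
intermediate_size = 2 * linear_num_key_heads * linear_key_head_dim + linear_num_value_heads * linear_value_head_dim

conv_shape = (batch_size, intermediate_size, linear_conv_kernel_dim)
recurrent_shape = (batch_size, linear_num_value_heads, linear_key_head_dim, linear_value_head_dim)
assert conv_shape == (2, 192, 4)  # 2 * 2 * 16 + 4 * 32 = 192
assert recurrent_shape == (2, 4, 16, 32)
```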
@@ -401,35 +386,21 @@ def setUp(self):
     def test_config(self):
         self.config_tester.run_common_tests()
 
-    def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config):
-        "Qwen3.5 Moe has a special Cache as it alternates with gated deltanet layers"
-        self.assertIsInstance(past_key_values, Qwen3_5MoeDynamicCache)
+    def _get_conv_state_shape(self, batch_size: int, config):
+        num_v_heads = config.linear_num_value_heads
+        num_k_heads = config.linear_num_key_heads
+        head_k_dim = config.linear_key_head_dim
+        head_v_dim = config.linear_value_head_dim
+        intermediate_size = 2 * num_k_heads * head_k_dim + num_v_heads * head_v_dim
 
-        # (batch, kv heads, seq_length, head_dim)
-        num_heads = getattr(config, "num_key_value_heads", config.num_attention_heads)
-        head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
-        expected_shape = (batch_size, num_heads, seq_length, head_dim)
-
-        attention_layer_indices = past_key_values.transformer_layers
-        self.assertListEqual(
-            [past_key_values.key_cache[idx].shape for idx in attention_layer_indices],
-            [expected_shape] * len(attention_layer_indices),
-        )
-        self.assertListEqual(
-            [past_key_values.value_cache[idx].shape for idx in attention_layer_indices],
-            [expected_shape] * len(attention_layer_indices),
-        )
+        return (batch_size, intermediate_size, config.linear_conv_kernel_dim)
 
-    def _check_caches_are_equal(self, cache1, cache2):
-        "Qwen3.5 Moe has a special Cache as it alternates with gated deltanet layers"
-        if not len(cache1) == len(cache2):
-            raise ValueError("Both caches do not have the same number of layers.")
+    def _get_recurrent_state_shape(self, batch_size: int, config):
+        num_v_heads = config.linear_num_value_heads
+        head_k_dim = config.linear_key_head_dim
+        head_v_dim = config.linear_value_head_dim
 
-        num_layers = len(cache1)
-        for idx in range(num_layers):
-            if cache1.key_cache[idx] is not None:
-                torch.testing.assert_close(cache1.key_cache[idx], cache2.key_cache[idx])
-                torch.testing.assert_close(cache1.value_cache[idx], cache2.value_cache[idx])
+        return (batch_size, num_v_heads, head_k_dim, head_v_dim)
 
     def test_attention_outputs(self):
         "Needs to be overwritten as Qwen3.5 Moe alternates between attention layers and gated deltanet layers."
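The `intermediate_size` arithmetic is identical in all of these hunks: the conv state spans the concatenated query, key, and value projections, with the key head dim counted twice (presumably once for queries and once for keys). A small self-contained check of that arithmetic, using hypothetical config values chosen only to make the numbers concrete:

```python
from types import SimpleNamespace

# Hypothetical config values, for illustration only.
config = SimpleNamespace(
    linear_num_value_heads=4,
    linear_num_key_heads=2,
    linear_key_head_dim=16,
    linear_value_head_dim=32,
    linear_conv_kernel_dim=4,
)
batch_size = 3

# Keys contribute twice (query and key projections share the key head dim).
intermediate_size = (
    2 * config.linear_num_key_heads * config.linear_key_head_dim
    + config.linear_num_value_heads * config.linear_value_head_dim
)
# conv state: (batch, 2*2*16 + 4*32, kernel) = (3, 192, 4)
assert (batch_size, intermediate_size, config.linear_conv_kernel_dim) == (3, 192, 4)

# recurrent state: one (head_k_dim x head_v_dim) matrix per value head
assert (
    batch_size,
    config.linear_num_value_heads,
    config.linear_key_head_dim,
    config.linear_value_head_dim,
) == (3, 4, 16, 32)
```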
diff --git a/tests/models/qwen3_next/test_modeling_qwen3_next.py b/tests/models/qwen3_next/test_modeling_qwen3_next.py
index 29e5f51705de..4cb53fb6c695 100644
--- a/tests/models/qwen3_next/test_modeling_qwen3_next.py
+++ b/tests/models/qwen3_next/test_modeling_qwen3_next.py
@@ -25,10 +25,8 @@
     import torch
 
     from transformers import (
-        Cache,
         Qwen3NextModel,
     )
-    from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextDynamicCache
 
 from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester
 from ...test_modeling_common import (
@@ -59,35 +57,21 @@ def __init__(self, parent):
 class Qwen3NextModelTest(CausalLMModelTest, unittest.TestCase):
     model_tester_class = Qwen3NextModelTester
 
-    def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config):
-        "Qwen3-Next has a special Cache as it alternates with gated deltanet layers"
-        self.assertIsInstance(past_key_values, Qwen3NextDynamicCache)
+    def _get_conv_state_shape(self, batch_size: int, config):
+        num_v_heads = config.linear_num_value_heads
+        num_k_heads = config.linear_num_key_heads
+        head_k_dim = config.linear_key_head_dim
+        head_v_dim = config.linear_value_head_dim
+        intermediate_size = 2 * num_k_heads * head_k_dim + num_v_heads * head_v_dim
 
-        # (batch, kv heads, seq_length, head_dim)
-        num_heads = getattr(config, "num_key_value_heads", config.num_attention_heads)
-        head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
-        expected_shape = (batch_size, num_heads, seq_length, head_dim)
+        return (batch_size, intermediate_size, config.linear_conv_kernel_dim)
 
-        attention_layer_indices = past_key_values.transformer_layers
-        self.assertListEqual(
-            [past_key_values.key_cache[idx].shape for idx in attention_layer_indices],
-            [expected_shape] * len(attention_layer_indices),
-        )
-        self.assertListEqual(
-            [past_key_values.value_cache[idx].shape for idx in attention_layer_indices],
-            [expected_shape] * len(attention_layer_indices),
-        )
-
-    def _check_caches_are_equal(self, cache1: Cache, cache2: Cache):
-        "Qwen3-Next has a special Cache as it alternates with gated deltanet layers"
-        if not len(cache1) == len(cache2):
-            raise ValueError("Both caches do not have the same number of layers.")
+    def _get_recurrent_state_shape(self, batch_size: int, config):
+        num_v_heads = config.linear_num_value_heads
+        head_k_dim = config.linear_key_head_dim
+        head_v_dim = config.linear_value_head_dim
 
-        num_layers = len(cache1)
-        for idx in range(num_layers):
-            if cache1.key_cache[idx] is not None:
-                torch.testing.assert_close(cache1.key_cache[idx], cache2.key_cache[idx])
-                torch.testing.assert_close(cache1.value_cache[idx], cache2.value_cache[idx])
+        return (batch_size, num_v_heads, head_k_dim, head_v_dim)
 
     def test_attention_outputs(self):
         "Needs to be overwritten as Qwen3-Next alternates between attention layers and gated deltanet layers."
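The deleted `_check_caches_are_equal` bodies were near-identical across all of these models. With a layered cache, an equality check no longer needs model-specific cache classes; a hedged sketch follows, where `layers`, `keys`, and `values` are assumed attribute names and the state tensors are guarded for lazy initialization:

```python
import torch


def assert_caches_equal(cache1, cache2):
    """Hedged sketch of a model-agnostic cache comparison; attribute names
    (`layers`, `keys`, `values`) are assumptions, not confirmed API."""
    assert len(cache1.layers) == len(cache2.layers), "layer count mismatch"
    for layer1, layer2 in zip(cache1.layers, cache2.layers):
        if hasattr(layer1, "conv_states"):  # linear attention / mamba-style layer
            # States may be lazily initialized, so guard for None before comparing.
            if layer1.conv_states is not None:
                torch.testing.assert_close(layer1.conv_states, layer2.conv_states)
            if layer1.recurrent_states is not None:
                torch.testing.assert_close(layer1.recurrent_states, layer2.recurrent_states)
        else:  # regular attention layer
            torch.testing.assert_close(layer1.keys, layer2.keys)
            torch.testing.assert_close(layer1.values, layer2.values)
```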
diff --git a/tests/models/zamba/test_modeling_zamba.py b/tests/models/zamba/test_modeling_zamba.py
index c5d0c81be98c..9e7ee869b88d 100644
--- a/tests/models/zamba/test_modeling_zamba.py
+++ b/tests/models/zamba/test_modeling_zamba.py
@@ -38,14 +38,7 @@
 if is_torch_available():
     import torch
 
-    from transformers import (
-        ZambaForCausalLM,
-        ZambaForSequenceClassification,
-        ZambaModel,
-    )
-    from transformers.models.zamba.modeling_zamba import (
-        ZambaHybridDynamicCache,
-    )
+    from transformers import ZambaForCausalLM, ZambaForSequenceClassification, ZambaModel
 
 
 class ZambaModelTester:
@@ -212,12 +205,9 @@ def create_and_check_decoder_model_past_large_inputs(
         model.eval()
 
         # first forward pass
-        # Attention: Zamba needs the cache to be initialized to return a cache!
-        past_key_values = ZambaHybridDynamicCache(config, input_ids.shape[0], model.dtype, device=model.device)
         outputs = model(
             input_ids,
             attention_mask=input_mask,
-            past_key_values=past_key_values,
             use_cache=True,
         )
         past_key_values = outputs.past_key_values
@@ -297,47 +287,38 @@ class ZambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
         if is_torch_available()
         else {}
     )
+    model_split_percents = [0.5, 0.8, 0.9]
 
-    def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config):
-        self.assertIsInstance(past_key_values, ZambaHybridDynamicCache)
-
-        # (batch, kv heads, seq_length, head_dim)
-        num_heads = getattr(config, "num_key_value_heads", config.num_attention_heads)
-        head_dim = getattr(config, "attention_head_dim")
-        attention_shape = (batch_size, num_heads, seq_length, head_dim)
-
+    def _get_conv_state_shape(self, batch_size: int, config):
         intermediate_size = config.mamba_expand * config.hidden_size
-        conv_shape = (batch_size, intermediate_size, config.mamba_d_conv)
-        ssm_shape = (batch_size, config.n_mamba_heads, intermediate_size // config.n_mamba_heads, config.mamba_d_state)
-
-        self.assertTrue(config.num_hidden_layers, len(past_key_values))
-
-        for idx in range(len(past_key_values)):
-            if config.layers_block_type[idx] == "mamba":
-                self.assertEqual(past_key_values.conv_states[idx].shape, conv_shape)
-                self.assertEqual(past_key_values.ssm_states[idx].shape, ssm_shape)
-            else:
-                self.assertEqual(past_key_values.key_cache[idx].shape, attention_shape)
-                self.assertEqual(past_key_values.value_cache[idx].shape, attention_shape)
-
-    def _check_caches_are_equal(self, cache1: ZambaHybridDynamicCache, cache2: ZambaHybridDynamicCache):
-        if not isinstance(cache1, ZambaHybridDynamicCache) or not isinstance(cache2, ZambaHybridDynamicCache):
-            raise ValueError("The wrong cache is being used!")
+        return (batch_size, intermediate_size, config.mamba_d_conv)
 
-        if not len(cache1) == len(cache2):
-            raise ValueError("Both caches do not have the same number of layers.")
-
-        num_layers = len(cache1)
-        for idx in range(num_layers):
-            torch.testing.assert_close(cache1.key_cache[idx], cache2.key_cache[idx])
-            torch.testing.assert_close(cache1.value_cache[idx], cache2.value_cache[idx])
-            torch.testing.assert_close(cache1.conv_states[idx], cache2.conv_states[idx])
-            torch.testing.assert_close(cache1.ssm_states[idx], cache2.ssm_states[idx])
+    def _get_recurrent_state_shape(self, batch_size: int, config):
+        intermediate_size = config.mamba_expand * config.hidden_size
+        return (batch_size, config.n_mamba_heads, intermediate_size // config.n_mamba_heads, config.mamba_d_state)
 
     def setUp(self):
         self.model_tester = ZambaModelTester(self)
         self.config_tester = ConfigTester(self, config_class=ZambaConfig, hidden_size=32)
 
+    @unittest.skip(
+        "Same as zamba2 -> investigate; probably due to their mixed layer classes or tied weights, with which accelerate does not work"
+    )
+    def test_disk_offload_bin(self):
+        pass
+
+    @unittest.skip(
+        "Same as zamba2 -> investigate; probably due to their mixed layer classes or tied weights, with which accelerate does not work"
+    )
+    def test_disk_offload_safetensors(self):
+        pass
+
+    @unittest.skip(
+        "Same as zamba2 -> investigate; probably due to their mixed layer classes or tied weights, with which accelerate does not work"
+    )
+    def test_cpu_offload(self):
+        pass
+
     def test_config(self):
         self.config_tester.run_common_tests()
diff --git a/tests/models/zamba2/test_modeling_zamba2.py b/tests/models/zamba2/test_modeling_zamba2.py
index 8a9d168fe0c5..66b9093ee4e8 100644
--- a/tests/models/zamba2/test_modeling_zamba2.py
+++ b/tests/models/zamba2/test_modeling_zamba2.py
@@ -39,14 +39,7 @@
 if is_torch_available():
     import torch
 
-    from transformers import (
-        Zamba2ForCausalLM,
-        Zamba2ForSequenceClassification,
-        Zamba2Model,
-    )
-    from transformers.models.zamba2.modeling_zamba2 import (
-        Zamba2HybridDynamicCache,
-    )
+    from transformers import Zamba2ForCausalLM, Zamba2ForSequenceClassification, Zamba2Model
 
 
 class Zamba2ModelTester:
@@ -222,12 +215,9 @@ def create_and_check_decoder_model_past_large_inputs(
         model.eval()
 
         # first forward pass
-        # Attention: Zamba2 needs the cache to be initialized to return a cache!
-        past_key_values = Zamba2HybridDynamicCache(config, input_ids.shape[0], model.dtype, device=model.device)
         outputs = model(
             input_ids,
             attention_mask=input_mask,
-            past_key_values=past_key_values,
             use_cache=True,
         )
         past_key_values = outputs.past_key_values
@@ -307,46 +297,19 @@ class Zamba2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
         if is_torch_available()
         else {}
     )
+    model_split_percents = [0.5, 0.8, 0.9]
 
-    def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config):
-        self.assertIsInstance(past_key_values, Zamba2HybridDynamicCache)
-
-        # (batch, kv heads, seq_length, head_dim)
-        num_heads = getattr(config, "num_key_value_heads", config.num_attention_heads)
-        head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
-        attention_shape = (batch_size, num_heads, seq_length, head_dim)
-
+    def _get_conv_state_shape(self, batch_size: int, config):
         intermediate_size = config.mamba_expand * config.hidden_size
         conv_shape = (
             batch_size,
             intermediate_size + 2 * config.mamba_ngroups * config.mamba_d_state,
             config.mamba_d_conv,
         )
-        ssm_shape = (batch_size, config.n_mamba_heads, config.mamba_headdim, config.mamba_d_state)
-
-        self.assertTrue(config.num_hidden_layers, len(past_key_values))
-
-        for idx in range(len(past_key_values)):
-            if config.layers_block_type[idx] == "mamba":
-                self.assertEqual(past_key_values.conv_states[idx].shape, conv_shape)
-                self.assertEqual(past_key_values.ssm_states[idx].shape, ssm_shape)
-            else:
-                self.assertEqual(past_key_values.key_cache[idx].shape, attention_shape)
-                self.assertEqual(past_key_values.value_cache[idx].shape, attention_shape)
-
-    def _check_caches_are_equal(self, cache1: Zamba2HybridDynamicCache, cache2: Zamba2HybridDynamicCache):
-        if not isinstance(cache1, Zamba2HybridDynamicCache) or not isinstance(cache2, Zamba2HybridDynamicCache):
-            raise ValueError("The wrong cache is being used!")
+        return conv_shape
 
-        if not len(cache1) == len(cache2):
-            raise ValueError("Both caches do not have the same number of layers.")
-
-        num_layers = len(cache1)
-        for idx in range(num_layers):
-            torch.testing.assert_close(cache1.key_cache[idx], cache2.key_cache[idx])
-            torch.testing.assert_close(cache1.value_cache[idx], cache2.value_cache[idx])
-            torch.testing.assert_close(cache1.conv_states[idx], cache2.conv_states[idx])
-            torch.testing.assert_close(cache1.ssm_states[idx], cache2.ssm_states[idx])
+    def _get_recurrent_state_shape(self, batch_size: int, config):
+        return (batch_size, config.n_mamba_heads, config.mamba_headdim, config.mamba_d_state)
 
     def setUp(self):
         self.model_tester = Zamba2ModelTester(self)
@@ -368,6 +331,12 @@ def test_disk_offload_bin(self):
     def test_disk_offload_safetensors(self):
         pass
 
+    @unittest.skip(
+        "Offloading does not work correctly for zamba2 - probably due to their mixed layer classes or tied weights"
+    )
+    def test_cpu_offload(self):
+        pass
+
     @unittest.skip("position_ids cannot be used to pad due to Mamba2 layers")
     def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
         pass
@@ -508,11 +477,6 @@ def test_flash_attn_2_fp32_ln(self):
         # with attention mask
         _ = model(dummy_input, attention_mask=dummy_attention_mask)
 
-    @unittest.skip(reason="Zamba2 has its own special cache type")
-    @parameterized.expand([(1, False), (1, True), (4, False)])
-    def test_new_cache_format(self, num_beams, do_sample):
-        pass
-
     @require_torch_accelerator
     def test_flex_attention_with_grads(self):
         """
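Note that the Zamba hunks also drop the up-front cache construction in `create_and_check_decoder_model_past_large_inputs`: passing `use_cache=True` now suffices for the model to allocate and return its hybrid cache. A minimal sketch of the simplified call pattern; the checkpoint name is illustrative only:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative checkpoint name, not an endorsement of a specific revision.
repo = "Zyphra/Zamba2-1.2B"
tokenizer = AutoTokenizer.from_pretrained(repo)
model = AutoModelForCausalLM.from_pretrained(repo)

inputs = tokenizer("Hello", return_tensors="pt")
with torch.no_grad():
    # No hybrid cache needs to be built by hand anymore; the model
    # returns an initialized cache from the first forward pass.
    outputs = model(**inputs, use_cache=True)
past_key_values = outputs.past_key_values
```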